Coverage for packages / dqm-ml-pytorch / src / dqm_ml_pytorch / domain_gap.py: 89%
174 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-15 10:11 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-15 10:11 +0000
1"""Domain gap processor for measuring distribution distance between datasets.
3This module contains the DomainGapProcessor class that computes statistical
4distances (KL divergence, MMD, FID, Wasserstein) between source and target
5datasets using image embeddings.
6"""
8from __future__ import annotations
10import logging
11from typing import Any
13import numpy as np
14import pyarrow as pa
16# COMPATIBILITY : from typing import Any, override # When support of 3.10 and 3.11 will be removed
17from typing_extensions import override
19from dqm_ml_core import DatametricProcessor
21logger = logging.getLogger(__name__)
class DomainGapProcessor(DatametricProcessor):
    """
    Computes statistical distances between source and target dataselections using image embeddings.

    This processor works in two stages:
    1. Dataset Summary: Aggregates high-dimensional embeddings into compact statistics
       (mean, variance, outer products, histograms).
    2. Delta Computation: Uses these summaries to calculate distance metrics between
       a source and a target dataset.

    Supported Delta Metrics:
    - `klmvn_diag`: Kullback-Leibler Divergence assuming a multivariate Normal
      distribution with a diagonal covariance matrix.
    - `mmd_linear`: Maximum Mean Discrepancy with a linear kernel.
    - `fid`: Frechet Inception Distance. Measures distance between two Gaussians
      fitted to feature representations (requires full covariance calculation).
    - `wasserstein_1d`: Average 1D Wasserstein distance across embedding dimensions,
      approximated via histograms.
    """

    def __init__(
        self,
        name: str = "image_embedding",
        config: dict[str, Any] | None = None,
    ):
        """
        Initialize the domain gap processor.

        Args:
            name: Unique name of the processor instance.
            config: Configuration dictionary containing:
                - INPUT:
                    - embedding_col: Column name containing embeddings (default: "embedding").
                - SUMMARY:
                    - collect_sum_outer: Whether to compute outer products (needed for FID).
                    - collect_hist_1d: Whether to compute histograms (needed for Wasserstein).
                    - hist_dims: Number of dimensions to histogram.
                    - hist_bins: Number of bins per histogram.
                - DELTA:
                    - metric: Target metric ("klmvn_diag", "mmd_linear", "fid", "wasserstein_1d").
        """
        super().__init__(name, config)
        # Lazily validated: check_config() runs on first use if not called explicitly.
        self._checked = False

    # ---------------- API ----------------
    def check_config(self) -> None:
        """Validate and configure the domain gap processor.

        This method parses the configuration dictionary and sets:
        - INPUT: Embedding column name
        - SUMMARY: Options for collecting summary statistics (outer products, histograms)
        - DELTA: The target metric to compute (klmvn_diag, mmd_linear, fid, wasserstein_1d)
        """
        cfg = self.config or {}
        icfg = cfg.get("INPUT", {})
        self.embedding_col: str = icfg.get("embedding_col", "embedding")

        dcfg = cfg.get("DELTA", {})
        self.delta_metric: str = str(dcfg.get("metric", "klmvn_diag")).lower()
        scfg = cfg.get("SUMMARY", {})

        # Each metric only needs specific summary statistics; enable them
        # automatically unless explicitly overridden in SUMMARY.
        if self.delta_metric == "fid":
            auto_sum_outer = True
            auto_hist_1d = False
        elif self.delta_metric == "wasserstein_1d":
            auto_sum_outer = False
            auto_hist_1d = True
        else:  # klmvn_diag, mmd_linear
            auto_sum_outer = False
            auto_hist_1d = False

        self.collect_sum_outer: bool = bool(scfg.get("collect_sum_outer", auto_sum_outer))
        self.collect_hist_1d: bool = bool(scfg.get("collect_hist_1d", auto_hist_1d))

        # Wasserstein-1D parameters
        self.hist_dims: int = int(scfg.get("hist_dims", 64))
        self.hist_bins: int = int(scfg.get("hist_bins", 32))
        rng = scfg.get("hist_range", [-3.0, 3.0])
        self.hist_range: tuple[float, float] = (float(rng[0]), float(rng[1]))

        self._checked = True

    @override
    def needed_columns(self) -> list[str]:
        if not getattr(self, "_checked", False):
            self.check_config()
        return [self.embedding_col]

    def generated_columns(self) -> list[str]:
        """Return the list of columns generated by this processor.

        Returns:
            Empty list as this processor computes deltas rather than features.
        """
        return []

    @override
    def compute_batch_metric(self, features: dict[str, pa.Array]) -> dict[str, pa.Array]:
        """
        Reduce a batch of embeddings into summary statistics.

        Returns a dictionary containing:
        - count: Number of samples.
        - sum: Element-wise sum of embeddings.
        - sum_sq: Element-wise sum of squared embeddings.
        - sum_outer: (Optional) Sum of outer products (for FID).
        - hist_counts: (Optional) Flattened histogram counts (for Wasserstein).
        """
        if not getattr(self, "_checked", False):
            self.check_config()

        emb = features.get(self.embedding_col)
        if emb is None or not isinstance(emb, pa.FixedSizeListArray):
            return {}

        n = len(emb)
        if n == 0:
            # Empty batch: nothing to summarize (also avoids emb[0] below raising).
            return {}
        # d = emb.list_size
        d = len(emb[0])
        child = emb.values  # flat
        arr = np.asarray(child.to_numpy()).reshape(n, d)

        out: dict[str, pa.Array] = {}
        out["count"] = pa.array([n], type=pa.int64())
        out["sum"] = pa.FixedSizeListArray.from_arrays(pa.array(arr.sum(axis=0).astype(np.float64)), d)
        out["sum_sq"] = pa.FixedSizeListArray.from_arrays(pa.array((arr * arr).sum(axis=0).astype(np.float64)), d)

        # optional: sum_outer for FID (d x d matrix flattened row-major)
        if self.collect_sum_outer:
            s = (arr.T @ arr).reshape(-1).astype(np.float64)
            out["sum_outer"] = pa.FixedSizeListArray.from_arrays(pa.array(s), d * d)

        # optional: histograms for Wasserstein-1D; only the first
        # min(d, hist_dims) embedding dimensions are histogrammed.
        if self.collect_hist_1d:
            use_dims = min(d, self.hist_dims)
            low, high = self.hist_range
            hist_list: list[np.ndarray] = []
            for j in range(use_dims):
                h, _ = np.histogram(arr[:, j], bins=self.hist_bins, range=(low, high))
                hist_list.append(h.astype(np.int64))
            h = np.stack(hist_list, axis=0).reshape(-1)
            out["hist_counts"] = pa.FixedSizeListArray.from_arrays(pa.array(h), self.hist_bins * use_dims)

        return out

    @override
    def compute(self, batch_metrics: dict[str, pa.Array]) -> dict[str, pa.Array]:
        """Aggregate batch-level summary statistics into global dataselection statistics.

        Args:
            batch_metrics: Dictionary containing batch-level statistics (count, sum, sum_sq, etc.).

        Returns:
            Dictionary containing aggregated dataset-level statistics.
        """
        if not batch_metrics:
            return {}

        def _sum_scalar(a: pa.Array) -> int:
            """Sum a 1-D pyarrow array of per-batch scalars into one int."""
            return int(np.asarray(a.to_numpy()).sum())

        def _sum_fixed(v: pa.FixedSizeListArray) -> tuple[np.ndarray, int]:
            """Sum per-batch fixed-size-list rows element-wise; return (vector, width)."""
            vals = np.asarray(v.values.to_numpy(), dtype=np.float64)
            d = len(v[0])
            # d = v.list_size
            return vals.reshape(-1, d).sum(axis=0), d

        out: dict[str, pa.Array] = {}

        # count
        if "count" not in batch_metrics:
            return {}
        total_n = _sum_scalar(batch_metrics["count"])
        out["count"] = pa.array([total_n], type=pa.int64())

        # sum / sum_sq
        if "sum" in batch_metrics:
            s, d = _sum_fixed(batch_metrics["sum"])
            out["sum"] = pa.FixedSizeListArray.from_arrays(pa.array(s), d)
        if "sum_sq" in batch_metrics:
            s2, d2 = _sum_fixed(batch_metrics["sum_sq"])
            out["sum_sq"] = pa.FixedSizeListArray.from_arrays(pa.array(s2), d2)

        # optional sum_outer
        if "sum_outer" in batch_metrics:
            so_vals = np.asarray(batch_metrics["sum_outer"].values.to_numpy(), dtype=np.float64)
            dd = len(batch_metrics["sum_outer"][0])
            # dd = batch_metrics["sum_outer"].list_size
            out["sum_outer"] = pa.FixedSizeListArray.from_arrays(pa.array(so_vals.reshape(-1, dd).sum(axis=0)), dd)

        # optional hist_counts
        if "hist_counts" in batch_metrics:
            h_vals = np.asarray(batch_metrics["hist_counts"].values.to_numpy(), dtype=np.int64)
            h_len = len(batch_metrics["hist_counts"][0])
            # h_len = batch_metrics["hist_counts"].list_size
            out["hist_counts"] = pa.FixedSizeListArray.from_arrays(
                pa.array(h_vals.reshape(-1, h_len).sum(axis=0)), h_len
            )

        return out

    @override
    def compute_delta(self, source: dict[str, pa.Array], target: dict[str, pa.Array]) -> dict[str, pa.Array]:
        """
        Calculate the domain gap metric between source and target dataselection statistics.

        Args:
            source: Dataselection statistics from the source dataset (computed via `compute`).
            target: Dataselection statistics from the target dataset (computed via `compute`).

        Returns:
            Dictionary containing the calculated metric value.
        """
        if not getattr(self, "_checked", False):
            self.check_config()
        # TODO : check config and available metrics outside of computation
        # TODO : add a return code error in th API

        def vec(
            a: pa.FixedSizeListArray,
        ) -> Any:  # TODO : check type error np.ndarray
            """Aggregate a FixedSizeListArray into a single numpy vector by summing all lists."""
            len_a = len(a[0])
            # len_a = a.list_size
            array = np.asarray(a.values.to_numpy(), dtype=np.float64).reshape(-1, len_a).sum(axis=0)
            return array  # type : ignore[no-any-return]

        def scalar(a: pa.Array) -> float:
            """Sum all elements in a pyarrow Array and return as a float."""
            return float(np.asarray(a.to_numpy()).sum())

        metric = self.delta_metric

        if metric in {"klmvn_diag", "mmd_linear", "fid"}:
            need = {"count", "sum"}
            if metric in {"klmvn_diag", "fid"}:
                need |= {"sum_sq"}
            if metric == "fid":
                need |= {"sum_outer"}
            for side, name in ((source, "source"), (target, "target")):
                if not need.issubset(side.keys()):
                    return {
                        "metric": pa.array([metric]),
                        "note": pa.array([f"missing keys in {name}: {sorted(need)}"]),
                    }

            n1, n2 = scalar(source["count"]), scalar(target["count"])
            if n1 <= 0 or n2 <= 0:
                return {
                    "metric": pa.array([metric]),
                    "note": pa.array(["empty summaries"]),
                }

            mu1 = vec(source["sum"]) / n1
            mu2 = vec(target["sum"]) / n2

            if metric == "mmd_linear":
                # Linear-kernel MMD reduces to the squared distance between means.
                diff = mu1 - mu2
                val = float(np.dot(diff, diff))
                return {"mmd_linear": pa.array([val], type=pa.float64())}

            # Per-dimension variances, floored at 1e-9 to avoid division by zero.
            v1 = np.maximum(vec(source["sum_sq"]) / n1 - mu1 * mu1, 1e-9)
            v2 = np.maximum(vec(target["sum_sq"]) / n2 - mu2 * mu2, 1e-9)

            if metric == "klmvn_diag":
                # Closed-form KL(N(mu1, diag(v1)) || N(mu2, diag(v2))).
                term_var = np.sum(v1 / v2 - 1.0 - np.log(v1 / v2))
                term_mean = np.sum((mu2 - mu1) ** 2 / v2)
                val = 0.5 * (term_var + term_mean)
                return {"klmvn_diag": pa.array([float(val)], type=pa.float64())}

            if metric == "fid":
                so1 = vec(source["sum_outer"])
                so2 = vec(target["sum_outer"])
                d = int(np.sqrt(so1.size))
                # Full covariance matrices from aggregated outer products.
                s1 = (so1.reshape(d, d) / n1) - np.outer(mu1, mu1)
                s2 = (so2.reshape(d, d) / n2) - np.outer(mu2, mu2)
                from scipy.linalg import sqrtm

                diff = mu1 - mu2
                # The `disp` argument is deprecated and will be
                # removed in SciPy 1.18.0. The previously returned error estimate
                # can be computed as ``norm(X @ X - A, 'fro')**2 / norm(A, 'fro')``
                # covmean, _ = sqrtm(s1.dot(s2), disp=False)
                covmean = sqrtm(s1.dot(s2))

                # sqrtm may return a complex result from numerical noise;
                # keep the real part only.
                if np.iscomplexobj(covmean):
                    covmean = covmean.real
                fid = diff.dot(diff) + np.trace(s1) + np.trace(s2) - 2 * np.trace(covmean)
                return {"fid": pa.array([float(abs(fid))], type=pa.float64())}

        if metric == "wasserstein_1d":
            if "hist_counts" not in source or "hist_counts" not in target:
                return {
                    "metric": pa.array([metric]),
                    "note": pa.array(["missing hist_counts"]),
                }
            h1 = np.asarray(source["hist_counts"].values.to_numpy(), dtype=np.int64)
            h2 = np.asarray(target["hist_counts"].values.to_numpy(), dtype=np.int64)
            bins = self.hist_bins
            # Derive dims from the actual histogram length rather than
            # self.hist_dims: the summary stage stores min(d, hist_dims)
            # dimensions, so validating against hist_dims alone would wrongly
            # reject summaries of embeddings with fewer than hist_dims dims.
            if h1.size != h2.size or h1.size % bins != 0:
                return {
                    "metric": pa.array([metric]),
                    "note": pa.array(["hist_counts length mismatch"]),
                }
            use_dims = h1.size // bins
            width = (self.hist_range[1] - self.hist_range[0]) / bins
            total = 0.0
            used = 0
            for j in range(use_dims):
                # BUGFIX: slice into fresh locals instead of reassigning
                # h1/h2. The previous code overwrote the full arrays with
                # their own first slice, so every iteration after j=0 worked
                # on an empty view and only dimension 0 contributed.
                h1j = h1[j * bins : (j + 1) * bins].astype(np.float64)
                h2j = h2[j * bins : (j + 1) * bins].astype(np.float64)
                if h1j.sum() == 0 and h2j.sum() == 0:
                    continue
                # Normalize to probability masses, then W1 = integral of |CDF diff|.
                p = h1j / max(1.0, h1j.sum())
                q = h2j / max(1.0, h2j.sum())
                cdf_p = np.cumsum(p)
                cdf_q = np.cumsum(q)
                total += float(np.sum(np.abs(cdf_p - cdf_q)) * width)
                used += 1
            val = total / max(1, used)
            return {"wasserstein_1d": pa.array([val], type=pa.float64())}

        return {
            "metric": pa.array([metric]),
            "note": pa.array(["unsupported metric or invalid inputs"]),
        }