Skip to content

dqm_ml_pytorch.domain_gap

Domain gap processor for measuring distribution distance between datasets.

This module contains the DomainGapProcessor class that computes statistical distances (KL divergence, MMD, FID, Wasserstein) between source and target datasets using image embeddings.

logger = logging.getLogger(__name__) module-attribute

DomainGapProcessor

Bases: DatametricProcessor

Computes statistical distances between source and target dataselections using image embeddings.

This processor works in two stages: 1. Dataset Summary: Aggregates high-dimensional embeddings into compact statistics (mean, variance, outer products, histograms). 2. Delta Computation: Uses these summaries to calculate distance metrics between a source and a target dataset.

Supported Delta Metrics
  • klmvn_diag: Kullback-Leibler Divergence assuming a multivariate Normal distribution with a diagonal covariance matrix.
  • mmd_linear: Maximum Mean Discrepancy with a linear kernel.
  • fid: Frechet Inception Distance. Measures distance between two Gaussians fitted to feature representations (requires full covariance calculation).
  • wasserstein_1d: Average 1D Wasserstein distance across embedding dimensions, approximated via histograms.
Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/domain_gap.py
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
class DomainGapProcessor(DatametricProcessor):
    """
    Computes statistical distances between source and target dataselections using image embeddings.

    This processor works in two stages:
    1. Dataset Summary: Aggregates high-dimensional embeddings into compact statistics
       (mean, variance, outer products, histograms).
    2. Delta Computation: Uses these summaries to calculate distance metrics between
       a source and a target dataset.

    Supported Delta Metrics:
      - `klmvn_diag`: Kullback-Leibler Divergence assuming a multivariate Normal
        distribution with a diagonal covariance matrix.
      - `mmd_linear`: Maximum Mean Discrepancy with a linear kernel.
      - `fid`: Frechet Inception Distance. Measures distance between two Gaussians
        fitted to feature representations (requires full covariance calculation).
      - `wasserstein_1d`: Average 1D Wasserstein distance across embedding dimensions,
        approximated via histograms.
    """

    def __init__(
        self,
        name: str = "image_embedding",
        config: dict[str, Any] | None = None,
    ):
        """
        Initialize the domain gap processor.

        Args:
            name: Unique name of the processor instance.
            config: Configuration dictionary containing:
                - INPUT:
                    - embedding_col: Column name containing embeddings (default: "embedding").
                - SUMMARY:
                    - collect_sum_outer: Whether to compute outer products (needed for FID).
                    - collect_hist_1d: Whether to compute histograms (needed for Wasserstein).
                    - hist_dims: Number of dimensions to histogram.
                    - hist_bins: Number of bins per histogram.
                - DELTA:
                    - metric: Target metric ("klmvn_diag", "mmd_linear", "fid", "wasserstein_1d").
        """
        super().__init__(name, config)
        # Configuration is parsed lazily by check_config().
        self._checked = False

    # ---------------- API ----------------
    def check_config(self) -> None:
        """Validate and configure the domain gap processor.

        This method parses the configuration dictionary and sets:
        - INPUT: Embedding column name
        - SUMMARY: Options for collecting summary statistics (outer products, histograms)
        - DELTA: The target metric to compute (klmvn_diag, mmd_linear, fid, wasserstein_1d)
        """
        cfg = self.config or {}
        icfg = cfg.get("INPUT", {})
        self.embedding_col: str = icfg.get("embedding_col", "embedding")

        dcfg = cfg.get("DELTA", {})
        self.delta_metric: str = str(dcfg.get("metric", "klmvn_diag")).lower()
        scfg = cfg.get("SUMMARY", {})

        # Defaults for which summaries to collect are derived from the metric:
        # FID needs outer products, Wasserstein-1D needs per-dimension histograms.
        if self.delta_metric == "fid":
            auto_sum_outer = True
            auto_hist_1d = False
        elif self.delta_metric == "wasserstein_1d":
            auto_sum_outer = False
            auto_hist_1d = True
        else:  # klmvn_diag, mmd_linear
            auto_sum_outer = False
            auto_hist_1d = False

        self.collect_sum_outer: bool = bool(scfg.get("collect_sum_outer", auto_sum_outer))
        self.collect_hist_1d: bool = bool(scfg.get("collect_hist_1d", auto_hist_1d))

        # Wasserstein-1D parameters
        self.hist_dims: int = int(scfg.get("hist_dims", 64))
        self.hist_bins: int = int(scfg.get("hist_bins", 32))
        rng = scfg.get("hist_range", [-3.0, 3.0])
        self.hist_range: tuple[float, float] = (float(rng[0]), float(rng[1]))

        self._checked = True

    @override
    def needed_columns(self) -> list[str]:
        """Return the input columns required by this processor (the embedding column)."""
        if not getattr(self, "_checked", False):
            self.check_config()
        return [self.embedding_col]

    def generated_columns(self) -> list[str]:
        """Return the list of columns generated by this processor.

        Returns:
            Empty list as this processor computes deltas rather than features.
        """
        return []

    @override
    def compute_batch_metric(self, features: dict[str, pa.Array]) -> dict[str, pa.Array]:
        """
        Reduce a batch of embeddings into summary statistics.

        Returns a dictionary containing:
            - count: Number of samples.
            - sum: Element-wise sum of embeddings.
            - sum_sq: Element-wise sum of squared embeddings.
            - sum_outer: (Optional) Sum of outer products (for FID).
            - hist_counts: (Optional) Flattened histogram counts (for Wasserstein).
        """
        if not getattr(self, "_checked", False):
            self.check_config()

        emb = features.get(self.embedding_col)
        if emb is None or not isinstance(emb, pa.FixedSizeListArray):
            return {}

        n = len(emb)
        if n == 0:
            # An empty batch carries no information; previously emb[0] below
            # would raise IndexError for this case.
            return {}
        # d = emb.list_size
        d = len(emb[0])
        child = emb.values  # flat
        arr = np.asarray(child.to_numpy()).reshape(n, d)

        out: dict[str, pa.Array] = {}
        out["count"] = pa.array([n], type=pa.int64())
        out["sum"] = pa.FixedSizeListArray.from_arrays(pa.array(arr.sum(axis=0).astype(np.float64)), d)
        out["sum_sq"] = pa.FixedSizeListArray.from_arrays(pa.array((arr * arr).sum(axis=0).astype(np.float64)), d)

        # optional: sum_outer for FID
        if self.collect_sum_outer:
            s = (arr.T @ arr).reshape(-1).astype(np.float64)
            out["sum_outer"] = pa.FixedSizeListArray.from_arrays(pa.array(s), d * d)

        # optional: histograms for Wasserstein-1D
        if self.collect_hist_1d:
            # Only the first min(d, hist_dims) dimensions are histogrammed.
            use_dims = min(d, self.hist_dims)
            low, high = self.hist_range
            hist_list: list[np.ndarray] = []
            for j in range(use_dims):
                h, _ = np.histogram(arr[:, j], bins=self.hist_bins, range=(low, high))
                hist_list.append(h.astype(np.int64))
            h = np.stack(hist_list, axis=0).reshape(-1)
            out["hist_counts"] = pa.FixedSizeListArray.from_arrays(pa.array(h), self.hist_bins * use_dims)

        return out

    @override
    def compute(self, batch_metrics: dict[str, pa.Array]) -> dict[str, pa.Array]:
        """Aggregate batch-level summary statistics into global dataselection statistics.

        Args:
            batch_metrics: Dictionary containing batch-level statistics (count, sum, sum_sq, etc.).

        Returns:
            Dictionary containing aggregated dataset-level statistics.
        """
        if not batch_metrics:
            return {}

        def _sum_scalar(a: pa.Array) -> int:
            """Sum all elements of a scalar array into a single int."""
            return int(np.asarray(a.to_numpy()).sum())

        def _sum_fixed(v: pa.FixedSizeListArray) -> tuple[np.ndarray, int]:
            """Sum a FixedSizeListArray row-wise into one vector; return (vector, width)."""
            vals = np.asarray(v.values.to_numpy(), dtype=np.float64)
            d = len(v[0])
            # d = v.list_size
            return vals.reshape(-1, d).sum(axis=0), d

        out: dict[str, pa.Array] = {}

        # count
        if "count" not in batch_metrics:
            return {}
        total_n = _sum_scalar(batch_metrics["count"])
        out["count"] = pa.array([total_n], type=pa.int64())

        # sum / sum_sq
        if "sum" in batch_metrics:
            s, d = _sum_fixed(batch_metrics["sum"])
            out["sum"] = pa.FixedSizeListArray.from_arrays(pa.array(s), d)
        if "sum_sq" in batch_metrics:
            s2, d2 = _sum_fixed(batch_metrics["sum_sq"])
            out["sum_sq"] = pa.FixedSizeListArray.from_arrays(pa.array(s2), d2)

        # optional sum_outer
        if "sum_outer" in batch_metrics:
            so_vals = np.asarray(batch_metrics["sum_outer"].values.to_numpy(), dtype=np.float64)
            dd = len(batch_metrics["sum_outer"][0])
            # dd = batch_metrics["sum_outer"].list_size
            out["sum_outer"] = pa.FixedSizeListArray.from_arrays(pa.array(so_vals.reshape(-1, dd).sum(axis=0)), dd)

        # optional hist_counts
        if "hist_counts" in batch_metrics:
            h_vals = np.asarray(batch_metrics["hist_counts"].values.to_numpy(), dtype=np.int64)
            h_len = len(batch_metrics["hist_counts"][0])
            # h_len = batch_metrics["hist_counts"].list_size
            out["hist_counts"] = pa.FixedSizeListArray.from_arrays(
                pa.array(h_vals.reshape(-1, h_len).sum(axis=0)), h_len
            )

        return out

    @override
    def compute_delta(self, source: dict[str, pa.Array], target: dict[str, pa.Array]) -> dict[str, pa.Array]:
        """
        Calculate the domain gap metric between source and target dataselection statistics.

        Args:
            source: Dataselection statistics from the source dataset (computed via `compute`).
            target: Dataselection statistics from the target dataset (computed via `compute`).

        Returns:
            Dictionary containing the calculated metric value, or a
            {"metric", "note"} pair describing why it could not be computed.
        """
        if not getattr(self, "_checked", False):
            self.check_config()
        # TODO : check config and available metrics outside of computation
        # TODO : add a return code error in the API

        def vec(
            a: pa.FixedSizeListArray,
        ) -> Any:  # TODO : check type error np.ndarray
            """Aggregate a FixedSizeListArray into a single numpy vector by summing all lists."""
            len_a = len(a[0])
            # len_a = a.list_size
            array = np.asarray(a.values.to_numpy(), dtype=np.float64).reshape(-1, len_a).sum(axis=0)
            return array  # type: ignore[no-any-return]

        def scalar(a: pa.Array) -> float:
            """Sum all elements in a pyarrow Array and return as a float."""
            return float(np.asarray(a.to_numpy()).sum())

        metric = self.delta_metric

        if metric in {"klmvn_diag", "mmd_linear", "fid"}:
            need = {"count", "sum"}
            if metric in {"klmvn_diag", "fid"}:
                need |= {"sum_sq"}
            if metric == "fid":
                need |= {"sum_outer"}
            for side, name in ((source, "source"), (target, "target")):
                if not need.issubset(side.keys()):
                    return {
                        "metric": pa.array([metric]),
                        "note": pa.array([f"missing keys in {name}: {sorted(need)}"]),
                    }

            n1, n2 = scalar(source["count"]), scalar(target["count"])
            if n1 <= 0 or n2 <= 0:
                return {
                    "metric": pa.array([metric]),
                    "note": pa.array(["empty summaries"]),
                }

            mu1 = vec(source["sum"]) / n1
            mu2 = vec(target["sum"]) / n2

            if metric == "mmd_linear":
                # Linear-kernel MMD^2 reduces to the squared distance between means.
                diff = mu1 - mu2
                val = float(np.dot(diff, diff))
                return {"mmd_linear": pa.array([val], type=pa.float64())}

            # Per-dimension variances, clamped away from zero for numerical stability.
            v1 = np.maximum(vec(source["sum_sq"]) / n1 - mu1 * mu1, 1e-9)
            v2 = np.maximum(vec(target["sum_sq"]) / n2 - mu2 * mu2, 1e-9)

            if metric == "klmvn_diag":
                term_var = np.sum(v1 / v2 - 1.0 - np.log(v1 / v2))
                term_mean = np.sum((mu2 - mu1) ** 2 / v2)
                val = 0.5 * (term_var + term_mean)
                return {"klmvn_diag": pa.array([float(val)], type=pa.float64())}

            if metric == "fid":
                so1 = vec(source["sum_outer"])
                so2 = vec(target["sum_outer"])
                d = int(np.sqrt(so1.size))
                s1 = (so1.reshape(d, d) / n1) - np.outer(mu1, mu1)
                s2 = (so2.reshape(d, d) / n2) - np.outer(mu2, mu2)
                from scipy.linalg import sqrtm

                diff = mu1 - mu2
                # The `disp` argument is deprecated and will be
                # removed in SciPy 1.18.0. The previously returned error estimate
                # can be computed as ``norm(X @ X - A, 'fro')**2 / norm(A, 'fro')``
                # covmean, _ = sqrtm(s1.dot(s2), disp=False)
                covmean = sqrtm(s1.dot(s2))

                if np.iscomplexobj(covmean):
                    # Numerical noise can yield tiny imaginary components; drop them.
                    covmean = covmean.real
                fid = diff.dot(diff) + np.trace(s1) + np.trace(s2) - 2 * np.trace(covmean)
                return {"fid": pa.array([float(abs(fid))], type=pa.float64())}

        if metric == "wasserstein_1d":
            if "hist_counts" not in source or "hist_counts" not in target:
                return {
                    "metric": pa.array([metric]),
                    "note": pa.array(["missing hist_counts"]),
                }
            h1 = np.asarray(source["hist_counts"].values.to_numpy(), dtype=np.int64)
            h2 = np.asarray(target["hist_counts"].values.to_numpy(), dtype=np.int64)
            bins = self.hist_bins
            # The summary stage histograms min(embedding_dim, hist_dims) dimensions,
            # so derive the actual dimension count from the data instead of assuming
            # exactly self.hist_dims (which falsely rejected smaller embeddings).
            if h1.size != h2.size or h1.size == 0 or h1.size % bins != 0:
                return {
                    "metric": pa.array([metric]),
                    "note": pa.array(["hist_counts length mismatch"]),
                }
            use_dims = h1.size // bins
            width = (self.hist_range[1] - self.hist_range[0]) / bins
            total = 0.0
            used = 0
            for j in range(use_dims):
                # Slice into fresh locals. The previous code rebound h1/h2 here,
                # so every dimension after the first indexed an already-sliced
                # bins-long array, produced empty slices, and silently contributed
                # nothing to the distance.
                c1 = h1[j * bins : (j + 1) * bins].astype(np.float64)
                c2 = h2[j * bins : (j + 1) * bins].astype(np.float64)
                if c1.sum() == 0 and c2.sum() == 0:
                    continue
                p = c1 / max(1.0, c1.sum())
                q = c2 / max(1.0, c2.sum())
                cdf_p = np.cumsum(p)
                cdf_q = np.cumsum(q)
                # 1D Wasserstein-1 distance equals the integrated |CDF difference|.
                total += float(np.sum(np.abs(cdf_p - cdf_q)) * width)
                used += 1
            val = total / max(1, used)
            return {"wasserstein_1d": pa.array([val], type=pa.float64())}

        return {
            "metric": pa.array([metric]),
            "note": pa.array(["unsupported metric or invalid inputs"]),
        }

__init__(name: str = 'image_embedding', config: dict[str, Any] | None = None)

Initialize the domain gap processor.

Parameters:

Name Type Description Default
name str

Unique name of the processor instance.

'image_embedding'
config dict[str, Any] | None

Configuration dictionary containing: - INPUT: - embedding_col: Column name containing embeddings (default: "embedding"). - SUMMARY: - collect_sum_outer: Whether to compute outer products (needed for FID). - collect_hist_1d: Whether to compute histograms (needed for Wasserstein). - hist_dims: Number of dimensions to histogram. - hist_bins: Number of bins per histogram. - DELTA: - metric: Target metric ("klmvn_diag", "mmd_linear", "fid", "wasserstein_1d").

None
Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/domain_gap.py
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
def __init__(
    self,
    name: str = "image_embedding",
    config: dict[str, Any] | None = None,
):
    """
    Initialize the domain gap processor.

    Args:
        name: Unique name of the processor instance.
        config: Configuration dictionary with optional sections:
            - INPUT: embedding_col (column holding embeddings; default "embedding").
            - SUMMARY: collect_sum_outer (outer products, needed for FID),
              collect_hist_1d (histograms, needed for Wasserstein),
              hist_dims (dimensions to histogram), hist_bins (bins per histogram).
            - DELTA: metric, one of "klmvn_diag", "mmd_linear", "fid",
              "wasserstein_1d".
    """
    super().__init__(name, config)
    # Configuration is parsed lazily; check_config() flips this flag to True.
    self._checked = False

check_config() -> None

Validate and configure the domain gap processor.

This method parses the configuration dictionary and sets: - INPUT: Embedding column name - SUMMARY: Options for collecting summary statistics (outer products, histograms) - DELTA: The target metric to compute (klmvn_diag, mmd_linear, fid, wasserstein_1d)

Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/domain_gap.py
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
def check_config(self) -> None:
    """Validate and configure the domain gap processor.

    Parses the configuration dictionary and sets:
    - INPUT: Embedding column name
    - SUMMARY: Options for collecting summary statistics (outer products, histograms)
    - DELTA: The target metric to compute (klmvn_diag, mmd_linear, fid, wasserstein_1d)
    """
    config = self.config or {}
    input_cfg = config.get("INPUT", {})
    delta_cfg = config.get("DELTA", {})
    summary_cfg = config.get("SUMMARY", {})

    self.embedding_col: str = input_cfg.get("embedding_col", "embedding")
    self.delta_metric: str = str(delta_cfg.get("metric", "klmvn_diag")).lower()

    # Defaults for which summaries to collect follow from the chosen metric:
    # FID requires outer products, Wasserstein-1D requires histograms, and the
    # remaining metrics (klmvn_diag, mmd_linear) need neither.
    default_outer = self.delta_metric == "fid"
    default_hist = self.delta_metric == "wasserstein_1d"
    self.collect_sum_outer: bool = bool(summary_cfg.get("collect_sum_outer", default_outer))
    self.collect_hist_1d: bool = bool(summary_cfg.get("collect_hist_1d", default_hist))

    # Histogram parameters used by the Wasserstein-1D approximation.
    self.hist_dims: int = int(summary_cfg.get("hist_dims", 64))
    self.hist_bins: int = int(summary_cfg.get("hist_bins", 32))
    bounds = summary_cfg.get("hist_range", [-3.0, 3.0])
    self.hist_range: tuple[float, float] = (float(bounds[0]), float(bounds[1]))

    self._checked = True

compute(batch_metrics: dict[str, pa.Array]) -> dict[str, pa.Array]

Aggregate batch-level summary statistics into global dataselection statistics.

Parameters:

Name Type Description Default
batch_metrics dict[str, Array]

Dictionary containing batch-level statistics (count, sum, sum_sq, etc.).

required

Returns:

Type Description
dict[str, Array]

Dictionary containing aggregated dataset-level statistics.

Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/domain_gap.py
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
@override
def compute(self, batch_metrics: dict[str, pa.Array]) -> dict[str, pa.Array]:
    """Aggregate batch-level summary statistics into global dataselection statistics.

    Args:
        batch_metrics: Dictionary of per-batch statistics (count, sum, sum_sq,
            and optionally sum_outer / hist_counts).

    Returns:
        Dictionary of dataset-level totals, or an empty dict when no batch
        metrics (or no "count" entry) are available.
    """
    if not batch_metrics or "count" not in batch_metrics:
        return {}

    def _reduce_rows(arr: pa.FixedSizeListArray, dtype: Any) -> tuple[np.ndarray, int]:
        """Sum a FixedSizeListArray row-wise into one vector; return (vector, width)."""
        width = len(arr[0])
        flat = np.asarray(arr.values.to_numpy(), dtype=dtype)
        return flat.reshape(-1, width).sum(axis=0), width

    total_count = int(np.asarray(batch_metrics["count"].to_numpy()).sum())
    result: dict[str, pa.Array] = {"count": pa.array([total_count], type=pa.int64())}

    # Float-valued vector statistics share the same reduction.
    for key in ("sum", "sum_sq", "sum_outer"):
        if key in batch_metrics:
            vector, width = _reduce_rows(batch_metrics[key], np.float64)
            result[key] = pa.FixedSizeListArray.from_arrays(pa.array(vector), width)

    # Histogram counts are integers and are reduced the same way.
    if "hist_counts" in batch_metrics:
        vector, width = _reduce_rows(batch_metrics["hist_counts"], np.int64)
        result["hist_counts"] = pa.FixedSizeListArray.from_arrays(pa.array(vector), width)

    return result

compute_batch_metric(features: dict[str, pa.Array]) -> dict[str, pa.Array]

Reduce a batch of embeddings into summary statistics.

Returns a dictionary containing
  • count: Number of samples.
  • sum: Element-wise sum of embeddings.
  • sum_sq: Element-wise sum of squared embeddings.
  • sum_outer: (Optional) Sum of outer products (for FID).
  • hist_counts: (Optional) Flattened histogram counts (for Wasserstein).
Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/domain_gap.py
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
@override
def compute_batch_metric(self, features: dict[str, pa.Array]) -> dict[str, pa.Array]:
    """
    Reduce a batch of embeddings into summary statistics.

    Returns a dictionary containing:
        - count: Number of samples.
        - sum: Element-wise sum of embeddings.
        - sum_sq: Element-wise sum of squared embeddings.
        - sum_outer: (Optional) Sum of outer products (for FID).
        - hist_counts: (Optional) Flattened histogram counts (for Wasserstein).
    """
    if not getattr(self, "_checked", False):
        self.check_config()

    embeddings = features.get(self.embedding_col)
    if not isinstance(embeddings, pa.FixedSizeListArray):
        return {}

    n_rows = len(embeddings)
    dim = len(embeddings[0])
    # The child array holds all values flattened; reshape into (rows, dim).
    matrix = np.asarray(embeddings.values.to_numpy()).reshape(n_rows, dim)

    summary: dict[str, pa.Array] = {
        "count": pa.array([n_rows], type=pa.int64()),
        "sum": pa.FixedSizeListArray.from_arrays(pa.array(matrix.sum(axis=0).astype(np.float64)), dim),
        "sum_sq": pa.FixedSizeListArray.from_arrays(
            pa.array((matrix * matrix).sum(axis=0).astype(np.float64)), dim
        ),
    }

    # Sum of outer products (flattened dim*dim matrix), needed for FID.
    if self.collect_sum_outer:
        outer_flat = (matrix.T @ matrix).reshape(-1).astype(np.float64)
        summary["sum_outer"] = pa.FixedSizeListArray.from_arrays(pa.array(outer_flat), dim * dim)

    # Per-dimension histograms (flattened), needed for Wasserstein-1D.
    if self.collect_hist_1d:
        n_dims = min(dim, self.hist_dims)
        counts = np.stack(
            [
                np.histogram(matrix[:, j], bins=self.hist_bins, range=self.hist_range)[0].astype(np.int64)
                for j in range(n_dims)
            ],
            axis=0,
        ).reshape(-1)
        summary["hist_counts"] = pa.FixedSizeListArray.from_arrays(pa.array(counts), self.hist_bins * n_dims)

    return summary

compute_delta(source: dict[str, pa.Array], target: dict[str, pa.Array]) -> dict[str, pa.Array]

Calculate the domain gap metric between source and target dataselection statistics.

Parameters:

Name Type Description Default
source dict[str, Array]

Dataselection statistics from the source dataset (computed via compute).

required
target dict[str, Array]

Dataselection statistics from the target dataset (computed via compute).

required

Returns:

Type Description
dict[str, Array]

Dictionary containing the calculated metric value.

Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/domain_gap.py
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
@override
def compute_delta(self, source: dict[str, pa.Array], target: dict[str, pa.Array]) -> dict[str, pa.Array]:
    """
    Calculate the domain gap metric between source and target dataselection statistics.

    Args:
        source: Dataselection statistics from the source dataset (computed via `compute`).
        target: Dataselection statistics from the target dataset (computed via `compute`).

    Returns:
        Dictionary containing the calculated metric value, or a
        {"metric", "note"} pair describing why it could not be computed.
    """
    if not getattr(self, "_checked", False):
        self.check_config()
    # TODO : check config and available metrics outside of computation
    # TODO : add a return code error in the API

    def vec(
        a: pa.FixedSizeListArray,
    ) -> Any:  # TODO : check type error np.ndarray
        """Aggregate a FixedSizeListArray into a single numpy vector by summing all lists."""
        len_a = len(a[0])
        array = np.asarray(a.values.to_numpy(), dtype=np.float64).reshape(-1, len_a).sum(axis=0)
        return array  # type: ignore[no-any-return]

    def scalar(a: pa.Array) -> float:
        """Sum all elements in a pyarrow Array and return as a float."""
        return float(np.asarray(a.to_numpy()).sum())

    metric = self.delta_metric

    if metric in {"klmvn_diag", "mmd_linear", "fid"}:
        need = {"count", "sum"}
        if metric in {"klmvn_diag", "fid"}:
            need |= {"sum_sq"}
        if metric == "fid":
            need |= {"sum_outer"}
        for side, name in ((source, "source"), (target, "target")):
            if not need.issubset(side.keys()):
                return {
                    "metric": pa.array([metric]),
                    "note": pa.array([f"missing keys in {name}: {sorted(need)}"]),
                }

        n1, n2 = scalar(source["count"]), scalar(target["count"])
        if n1 <= 0 or n2 <= 0:
            return {
                "metric": pa.array([metric]),
                "note": pa.array(["empty summaries"]),
            }

        mu1 = vec(source["sum"]) / n1
        mu2 = vec(target["sum"]) / n2

        if metric == "mmd_linear":
            # Linear-kernel MMD^2 reduces to the squared distance between means.
            diff = mu1 - mu2
            val = float(np.dot(diff, diff))
            return {"mmd_linear": pa.array([val], type=pa.float64())}

        # Per-dimension variances, clamped away from zero for numerical stability.
        v1 = np.maximum(vec(source["sum_sq"]) / n1 - mu1 * mu1, 1e-9)
        v2 = np.maximum(vec(target["sum_sq"]) / n2 - mu2 * mu2, 1e-9)

        if metric == "klmvn_diag":
            term_var = np.sum(v1 / v2 - 1.0 - np.log(v1 / v2))
            term_mean = np.sum((mu2 - mu1) ** 2 / v2)
            val = 0.5 * (term_var + term_mean)
            return {"klmvn_diag": pa.array([float(val)], type=pa.float64())}

        if metric == "fid":
            so1 = vec(source["sum_outer"])
            so2 = vec(target["sum_outer"])
            d = int(np.sqrt(so1.size))
            s1 = (so1.reshape(d, d) / n1) - np.outer(mu1, mu1)
            s2 = (so2.reshape(d, d) / n2) - np.outer(mu2, mu2)
            from scipy.linalg import sqrtm

            diff = mu1 - mu2
            # The `disp` argument is deprecated and will be
            # removed in SciPy 1.18.0. The previously returned error estimate
            # can be computed as ``norm(X @ X - A, 'fro')**2 / norm(A, 'fro')``
            # covmean, _ = sqrtm(s1.dot(s2), disp=False)
            covmean = sqrtm(s1.dot(s2))

            if np.iscomplexobj(covmean):
                # Numerical noise can yield tiny imaginary components; drop them.
                covmean = covmean.real
            fid = diff.dot(diff) + np.trace(s1) + np.trace(s2) - 2 * np.trace(covmean)
            return {"fid": pa.array([float(abs(fid))], type=pa.float64())}

    if metric == "wasserstein_1d":
        if "hist_counts" not in source or "hist_counts" not in target:
            return {
                "metric": pa.array([metric]),
                "note": pa.array(["missing hist_counts"]),
            }
        h1 = np.asarray(source["hist_counts"].values.to_numpy(), dtype=np.int64)
        h2 = np.asarray(target["hist_counts"].values.to_numpy(), dtype=np.int64)
        bins = self.hist_bins
        # The summary stage histograms min(embedding_dim, hist_dims) dimensions,
        # so derive the actual dimension count from the data instead of assuming
        # exactly self.hist_dims (which falsely rejected smaller embeddings).
        if h1.size != h2.size or h1.size == 0 or h1.size % bins != 0:
            return {
                "metric": pa.array([metric]),
                "note": pa.array(["hist_counts length mismatch"]),
            }
        use_dims = h1.size // bins
        width = (self.hist_range[1] - self.hist_range[0]) / bins
        total = 0.0
        used = 0
        for j in range(use_dims):
            # Slice into fresh locals. The previous code rebound h1/h2 here,
            # so every dimension after the first indexed an already-sliced
            # bins-long array, produced empty slices, and silently contributed
            # nothing to the distance.
            c1 = h1[j * bins : (j + 1) * bins].astype(np.float64)
            c2 = h2[j * bins : (j + 1) * bins].astype(np.float64)
            if c1.sum() == 0 and c2.sum() == 0:
                continue
            p = c1 / max(1.0, c1.sum())
            q = c2 / max(1.0, c2.sum())
            cdf_p = np.cumsum(p)
            cdf_q = np.cumsum(q)
            # 1D Wasserstein-1 distance equals the integrated |CDF difference|.
            total += float(np.sum(np.abs(cdf_p - cdf_q)) * width)
            used += 1
        val = total / max(1, used)
        return {"wasserstein_1d": pa.array([val], type=pa.float64())}

    return {
        "metric": pa.array([metric]),
        "note": pa.array(["unsupported metric or invalid inputs"]),
    }

generated_columns() -> list[str]

Return the list of columns generated by this processor.

Returns:

Type Description
list[str]

Empty list as this processor computes deltas rather than features.

Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/domain_gap.py
112
113
114
115
116
117
118
def generated_columns(self) -> list[str]:
    """Return the columns this processor adds to the dataset.

    Returns:
        An empty list: the processor produces delta metrics between datasets,
        not per-row feature columns.
    """
    return []

needed_columns() -> list[str]

Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/domain_gap.py
106
107
108
109
110
@override
def needed_columns(self) -> list[str]:
    if not getattr(self, "_checked", False):
        self.check_config()
    return [self.embedding_col]