Skip to content

dqm_ml_pytorch

DQM ML PyTorch package for deep learning-based data quality metrics.

This package provides metric processors that use PyTorch models for computing image embeddings and domain gap metrics.

Classes:

Name Description
ImageEmbeddingProcessor

Extracts image embeddings using pre-trained CNNs.

DomainGapProcessor

Computes statistical distances between datasets.

__all__ = ['DomainGapProcessor', 'ImageEmbeddingProcessor'] module-attribute

DomainGapProcessor

Bases: DatametricProcessor

Computes statistical distances between source and target dataselections using image embeddings.

This processor works in two stages: 1. Dataset Summary: Aggregates high-dimensional embeddings into compact statistics (mean, variance, outer products, histograms). 2. Delta Computation: Uses these summaries to calculate distance metrics between a source and a target dataset.

Supported Delta Metrics
  • klmvn_diag: Kullback-Leibler Divergence assuming a multivariate Normal distribution with a diagonal covariance matrix.
  • mmd_linear: Maximum Mean Discrepancy with a linear kernel.
  • fid: Frechet Inception Distance. Measures distance between two Gaussians fitted to feature representations (requires full covariance calculation).
  • wasserstein_1d: Average 1D Wasserstein distance across embedding dimensions, approximated via histograms.
Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/domain_gap.py
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
class DomainGapProcessor(DatametricProcessor):
    """
    Computes statistical distances between source and target dataselections using image embeddings.

    This processor works in two stages:
    1. Dataset Summary: Aggregates high-dimensional embeddings into compact statistics
       (mean, variance, outer products, histograms).
    2. Delta Computation: Uses these summaries to calculate distance metrics between
       a source and a target dataset.

    Supported Delta Metrics:
      - `klmvn_diag`: Kullback-Leibler Divergence assuming a multivariate Normal
        distribution with a diagonal covariance matrix.
      - `mmd_linear`: Maximum Mean Discrepancy with a linear kernel.
      - `fid`: Frechet Inception Distance. Measures distance between two Gaussians
        fitted to feature representations (requires full covariance calculation).
      - `wasserstein_1d`: Average 1D Wasserstein distance across embedding dimensions,
        approximated via histograms.
    """

    def __init__(
        self,
        name: str = "image_embedding",
        config: dict[str, Any] | None = None,
    ):
        """
        Initialize the domain gap processor.

        Args:
            name: Unique name of the processor instance.
            config: Configuration dictionary containing:
                - INPUT:
                    - embedding_col: Column name containing embeddings (default: "embedding").
                - SUMMARY:
                    - collect_sum_outer: Whether to compute outer products (needed for FID).
                    - collect_hist_1d: Whether to compute histograms (needed for Wasserstein).
                    - hist_dims: Number of dimensions to histogram.
                    - hist_bins: Number of bins per histogram.
                - DELTA:
                    - metric: Target metric ("klmvn_diag", "mmd_linear", "fid", "wasserstein_1d").
        """
        super().__init__(name, config)
        # Configuration is parsed lazily by check_config() on first use.
        self._checked = False

    # ---------------- API ----------------
    def check_config(self) -> None:
        """Validate and configure the domain gap processor.

        This method parses the configuration dictionary and sets:
        - INPUT: Embedding column name
        - SUMMARY: Options for collecting summary statistics (outer products, histograms)
        - DELTA: The target metric to compute (klmvn_diag, mmd_linear, fid, wasserstein_1d)
        """
        cfg = self.config or {}
        icfg = cfg.get("INPUT", {})
        self.embedding_col: str = icfg.get("embedding_col", "embedding")

        dcfg = cfg.get("DELTA", {})
        self.delta_metric: str = str(dcfg.get("metric", "klmvn_diag")).lower()
        scfg = cfg.get("SUMMARY", {})

        # Default the optional summary statistics to what the chosen metric needs:
        # FID requires full outer products, Wasserstein-1D requires histograms.
        if self.delta_metric == "fid":
            auto_sum_outer = True
            auto_hist_1d = False
        elif self.delta_metric == "wasserstein_1d":
            auto_sum_outer = False
            auto_hist_1d = True
        else:  # klmvn_diag, mmd_linear
            auto_sum_outer = False
            auto_hist_1d = False

        self.collect_sum_outer: bool = bool(scfg.get("collect_sum_outer", auto_sum_outer))
        self.collect_hist_1d: bool = bool(scfg.get("collect_hist_1d", auto_hist_1d))

        # Wasserstein-1D parameters
        self.hist_dims: int = int(scfg.get("hist_dims", 64))
        self.hist_bins: int = int(scfg.get("hist_bins", 32))
        rng = scfg.get("hist_range", [-3.0, 3.0])
        self.hist_range: tuple[float, float] = (float(rng[0]), float(rng[1]))

        self._checked = True

    @override
    def needed_columns(self) -> list[str]:
        """Return the input columns this processor reads (the embedding column)."""
        if not getattr(self, "_checked", False):
            self.check_config()
        return [self.embedding_col]

    def generated_columns(self) -> list[str]:
        """Return the list of columns generated by this processor.

        Returns:
            Empty list as this processor computes deltas rather than features.
        """
        return []

    @override
    def compute_batch_metric(self, features: dict[str, pa.Array]) -> dict[str, pa.Array]:
        """
        Reduce a batch of embeddings into summary statistics.

        Returns a dictionary containing:
            - count: Number of samples.
            - sum: Element-wise sum of embeddings.
            - sum_sq: Element-wise sum of squared embeddings.
            - sum_outer: (Optional) Sum of outer products (for FID).
            - hist_counts: (Optional) Flattened histogram counts (for Wasserstein).
        """
        if not getattr(self, "_checked", False):
            self.check_config()

        emb = features.get(self.embedding_col)
        if emb is None or not isinstance(emb, pa.FixedSizeListArray):
            return {}

        n = len(emb)
        # Read the list width from the type: unlike len(emb[0]) this is O(1)
        # and does not raise IndexError on an empty batch.
        d = emb.type.list_size
        child = emb.values  # flat
        arr = np.asarray(child.to_numpy()).reshape(n, d)

        out: dict[str, pa.Array] = {}
        out["count"] = pa.array([n], type=pa.int64())
        out["sum"] = pa.FixedSizeListArray.from_arrays(pa.array(arr.sum(axis=0).astype(np.float64)), d)
        out["sum_sq"] = pa.FixedSizeListArray.from_arrays(pa.array((arr * arr).sum(axis=0).astype(np.float64)), d)

        # optional: sum_outer for FID
        if self.collect_sum_outer:
            s = (arr.T @ arr).reshape(-1).astype(np.float64)
            out["sum_outer"] = pa.FixedSizeListArray.from_arrays(pa.array(s), d * d)

        # optional: histograms for Wasserstein-1D
        if self.collect_hist_1d:
            use_dims = min(d, self.hist_dims)
            low, high = self.hist_range
            hist_list: list[np.ndarray] = []
            for j in range(use_dims):
                h, _ = np.histogram(arr[:, j], bins=self.hist_bins, range=(low, high))
                hist_list.append(h.astype(np.int64))
            h = np.stack(hist_list, axis=0).reshape(-1)
            out["hist_counts"] = pa.FixedSizeListArray.from_arrays(pa.array(h), self.hist_bins * use_dims)

        return out

    @override
    def compute(self, batch_metrics: dict[str, pa.Array]) -> dict[str, pa.Array]:
        """Aggregate batch-level summary statistics into global dataselection statistics.

        Args:
            batch_metrics: Dictionary containing batch-level statistics (count, sum, sum_sq, etc.).

        Returns:
            Dictionary containing aggregated dataset-level statistics.
        """
        if not batch_metrics:
            return {}

        def _sum_scalar(a: pa.Array) -> int:
            """Sum all entries of a scalar pyarrow array as a Python int."""
            return int(np.asarray(a.to_numpy()).sum())

        def _sum_fixed(v: pa.FixedSizeListArray) -> tuple[np.ndarray, int]:
            """Sum a FixedSizeListArray row-wise; return (summed vector, list width)."""
            vals = np.asarray(v.values.to_numpy(), dtype=np.float64)
            # list_size comes from the type, so this also works for empty arrays.
            d = v.type.list_size
            return vals.reshape(-1, d).sum(axis=0), d

        out: dict[str, pa.Array] = {}

        # count
        if "count" not in batch_metrics:
            return {}
        total_n = _sum_scalar(batch_metrics["count"])
        out["count"] = pa.array([total_n], type=pa.int64())

        # sum / sum_sq
        if "sum" in batch_metrics:
            s, d = _sum_fixed(batch_metrics["sum"])
            out["sum"] = pa.FixedSizeListArray.from_arrays(pa.array(s), d)
        if "sum_sq" in batch_metrics:
            s2, d2 = _sum_fixed(batch_metrics["sum_sq"])
            out["sum_sq"] = pa.FixedSizeListArray.from_arrays(pa.array(s2), d2)

        # optional sum_outer
        if "sum_outer" in batch_metrics:
            so_vals = np.asarray(batch_metrics["sum_outer"].values.to_numpy(), dtype=np.float64)
            dd = batch_metrics["sum_outer"].type.list_size
            out["sum_outer"] = pa.FixedSizeListArray.from_arrays(pa.array(so_vals.reshape(-1, dd).sum(axis=0)), dd)

        # optional hist_counts
        if "hist_counts" in batch_metrics:
            h_vals = np.asarray(batch_metrics["hist_counts"].values.to_numpy(), dtype=np.int64)
            h_len = batch_metrics["hist_counts"].type.list_size
            out["hist_counts"] = pa.FixedSizeListArray.from_arrays(
                pa.array(h_vals.reshape(-1, h_len).sum(axis=0)), h_len
            )

        return out

    @override
    def compute_delta(self, source: dict[str, pa.Array], target: dict[str, pa.Array]) -> dict[str, pa.Array]:
        """
        Calculate the domain gap metric between source and target dataselection statistics.

        Args:
            source: Dataselection statistics from the source dataset (computed via `compute`).
            target: Dataselection statistics from the target dataset (computed via `compute`).

        Returns:
            Dictionary containing the calculated metric value.
        """
        if not getattr(self, "_checked", False):
            self.check_config()
        # TODO : check config and available metrics outside of computation
        # TODO : add a return code error in the API

        def vec(
            a: pa.FixedSizeListArray,
        ) -> Any:  # TODO : check type error np.ndarray
            """Aggregate a FixedSizeListArray into a single numpy vector by summing all lists."""
            len_a = a.type.list_size
            array = np.asarray(a.values.to_numpy(), dtype=np.float64).reshape(-1, len_a).sum(axis=0)
            return array

        def scalar(a: pa.Array) -> float:
            """Sum all elements in a pyarrow Array and return as a float."""
            return float(np.asarray(a.to_numpy()).sum())

        metric = self.delta_metric

        if metric in {"klmvn_diag", "mmd_linear", "fid"}:
            need = {"count", "sum"}
            if metric in {"klmvn_diag", "fid"}:
                need |= {"sum_sq"}
            if metric == "fid":
                need |= {"sum_outer"}
            for side, name in ((source, "source"), (target, "target")):
                if not need.issubset(side.keys()):
                    return {
                        "metric": pa.array([metric]),
                        "note": pa.array([f"missing keys in {name}: {sorted(need)}"]),
                    }

            n1, n2 = scalar(source["count"]), scalar(target["count"])
            if n1 <= 0 or n2 <= 0:
                return {
                    "metric": pa.array([metric]),
                    "note": pa.array(["empty summaries"]),
                }

            mu1 = vec(source["sum"]) / n1
            mu2 = vec(target["sum"]) / n2

            if metric == "mmd_linear":
                # Linear-kernel MMD reduces to the squared distance between means.
                diff = mu1 - mu2
                val = float(np.dot(diff, diff))
                return {"mmd_linear": pa.array([val], type=pa.float64())}

            # Per-dimension variances, floored at 1e-9 to avoid division by zero.
            v1 = np.maximum(vec(source["sum_sq"]) / n1 - mu1 * mu1, 1e-9)
            v2 = np.maximum(vec(target["sum_sq"]) / n2 - mu2 * mu2, 1e-9)

            if metric == "klmvn_diag":
                # KL(N1 || N2) for diagonal Gaussians, summed over dimensions.
                term_var = np.sum(v1 / v2 - 1.0 - np.log(v1 / v2))
                term_mean = np.sum((mu2 - mu1) ** 2 / v2)
                val = 0.5 * (term_var + term_mean)
                return {"klmvn_diag": pa.array([float(val)], type=pa.float64())}

            if metric == "fid":
                so1 = vec(source["sum_outer"])
                so2 = vec(target["sum_outer"])
                d = int(np.sqrt(so1.size))
                s1 = (so1.reshape(d, d) / n1) - np.outer(mu1, mu1)
                s2 = (so2.reshape(d, d) / n2) - np.outer(mu2, mu2)
                from scipy.linalg import sqrtm

                diff = mu1 - mu2
                # The `disp` argument is deprecated and will be
                # removed in SciPy 1.18.0. The previously returned error estimate
                # can be computed as ``norm(X @ X - A, 'fro')**2 / norm(A, 'fro')``
                # covmean, _ = sqrtm(s1.dot(s2), disp=False)
                covmean = sqrtm(s1.dot(s2))

                # sqrtm may return tiny imaginary components from numerical noise.
                if np.iscomplexobj(covmean):
                    covmean = covmean.real
                fid = diff.dot(diff) + np.trace(s1) + np.trace(s2) - 2 * np.trace(covmean)
                return {"fid": pa.array([float(abs(fid))], type=pa.float64())}

        if metric == "wasserstein_1d":
            if "hist_counts" not in source or "hist_counts" not in target:
                return {
                    "metric": pa.array([metric]),
                    "note": pa.array(["missing hist_counts"]),
                }
            h1 = np.asarray(source["hist_counts"].values.to_numpy(), dtype=np.int64)
            h2 = np.asarray(target["hist_counts"].values.to_numpy(), dtype=np.int64)
            # derive dims from summary config
            use_dims = self.hist_dims
            bins = self.hist_bins
            if h1.size != h2.size or h1.size != bins * use_dims:
                return {
                    "metric": pa.array([metric]),
                    "note": pa.array(["hist_counts length mismatch"]),
                }
            width = (self.hist_range[1] - self.hist_range[0]) / bins
            total = 0.0
            used = 0
            for j in range(use_dims):
                # BUG FIX: slice into fresh names instead of rebinding h1/h2.
                # The previous code overwrote the full histogram arrays on the
                # first iteration, so every dimension after j == 0 read an
                # empty/incorrect slice.
                c1 = h1[j * bins : (j + 1) * bins].astype(np.float64)
                c2 = h2[j * bins : (j + 1) * bins].astype(np.float64)
                if c1.sum() == 0 and c2.sum() == 0:
                    continue
                p = c1 / max(1.0, c1.sum())
                q = c2 / max(1.0, c2.sum())
                cdf_p = np.cumsum(p)
                cdf_q = np.cumsum(q)
                # 1D Wasserstein distance = integral of |CDF_p - CDF_q|.
                total += float(np.sum(np.abs(cdf_p - cdf_q)) * width)
                used += 1
            val = total / max(1, used)
            return {"wasserstein_1d": pa.array([val], type=pa.float64())}

        return {
            "metric": pa.array([metric]),
            "note": pa.array(["unsupported metric or invalid inputs"]),
        }

__init__(name: str = 'image_embedding', config: dict[str, Any] | None = None)

Initialize the domain gap processor.

Parameters:

Name Type Description Default
name str

Unique name of the processor instance.

'image_embedding'
config dict[str, Any] | None

Configuration dictionary containing: - INPUT: - embedding_col: Column name containing embeddings (default: "embedding"). - SUMMARY: - collect_sum_outer: Whether to compute outer products (needed for FID). - collect_hist_1d: Whether to compute histograms (needed for Wasserstein). - hist_dims: Number of dimensions to histogram. - hist_bins: Number of bins per histogram. - DELTA: - metric: Target metric ("klmvn_diag", "mmd_linear", "fid", "wasserstein_1d").

None
Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/domain_gap.py
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
def __init__(
    self,
    name: str = "image_embedding",
    config: dict[str, Any] | None = None,
):
    """Create a domain gap processor instance.

    Args:
        name: Unique name of the processor instance.
        config: Optional configuration dictionary with three sections:
            INPUT (``embedding_col``, default "embedding"),
            SUMMARY (``collect_sum_outer``, ``collect_hist_1d``,
            ``hist_dims``, ``hist_bins``) and
            DELTA (``metric``: one of "klmvn_diag", "mmd_linear",
            "fid", "wasserstein_1d").
    """
    super().__init__(name, config)
    # Configuration is parsed lazily by check_config() on first use.
    self._checked = False

check_config() -> None

Validate and configure the domain gap processor.

This method parses the configuration dictionary and sets: - INPUT: Embedding column name - SUMMARY: Options for collecting summary statistics (outer products, histograms) - DELTA: The target metric to compute (klmvn_diag, mmd_linear, fid, wasserstein_1d)

Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/domain_gap.py
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
def check_config(self) -> None:
    """Parse the processor configuration and cache the derived settings.

    Reads three configuration sections and stores their values on the
    instance:
    - INPUT: the embedding column name,
    - SUMMARY: which summary statistics to collect and histogram parameters,
    - DELTA: the target metric (klmvn_diag, mmd_linear, fid, wasserstein_1d).
    """
    cfg = self.config or {}

    self.embedding_col: str = cfg.get("INPUT", {}).get("embedding_col", "embedding")
    self.delta_metric: str = str(cfg.get("DELTA", {}).get("metric", "klmvn_diag")).lower()

    summary_cfg = cfg.get("SUMMARY", {})

    # Defaults for the optional statistics follow the chosen metric:
    # FID needs outer products, Wasserstein-1D needs histograms.
    metric_defaults = {"fid": (True, False), "wasserstein_1d": (False, True)}
    auto_outer, auto_hist = metric_defaults.get(self.delta_metric, (False, False))

    self.collect_sum_outer: bool = bool(summary_cfg.get("collect_sum_outer", auto_outer))
    self.collect_hist_1d: bool = bool(summary_cfg.get("collect_hist_1d", auto_hist))

    # Histogram parameters for the Wasserstein-1D approximation.
    self.hist_dims: int = int(summary_cfg.get("hist_dims", 64))
    self.hist_bins: int = int(summary_cfg.get("hist_bins", 32))
    low_high = summary_cfg.get("hist_range", [-3.0, 3.0])
    self.hist_range: tuple[float, float] = (float(low_high[0]), float(low_high[1]))

    self._checked = True

compute(batch_metrics: dict[str, pa.Array]) -> dict[str, pa.Array]

Aggregate batch-level summary statistics into global dataselection statistics.

Parameters:

Name Type Description Default
batch_metrics dict[str, Array]

Dictionary containing batch-level statistics (count, sum, sum_sq, etc.).

required

Returns:

Type Description
dict[str, Array]

Dictionary containing aggregated dataset-level statistics.

Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/domain_gap.py
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
@override
def compute(self, batch_metrics: dict[str, pa.Array]) -> dict[str, pa.Array]:
    """Merge per-batch summary statistics into dataselection-level totals.

    Args:
        batch_metrics: Per-batch statistics produced by `compute_batch_metric`
            (count, sum, sum_sq and optionally sum_outer / hist_counts).

    Returns:
        Dictionary with the same keys, each reduced over all batches; empty
        if no batch metrics (or no "count") are present.
    """
    if not batch_metrics or "count" not in batch_metrics:
        return {}

    def _reduce_lists(arr: pa.FixedSizeListArray, dtype: type) -> tuple[np.ndarray, int]:
        """Sum a FixedSizeListArray row-wise; return (summed vector, list width)."""
        width = len(arr[0])
        flat = np.asarray(arr.values.to_numpy(), dtype=dtype)
        return flat.reshape(-1, width).sum(axis=0), width

    total_count = int(np.asarray(batch_metrics["count"].to_numpy()).sum())
    aggregated: dict[str, pa.Array] = {
        "count": pa.array([total_count], type=pa.int64()),
    }

    # Reduce every present vector-valued statistic with the appropriate dtype
    # (histogram counts stay integral, everything else accumulates as float64).
    reducible = (
        ("sum", np.float64),
        ("sum_sq", np.float64),
        ("sum_outer", np.float64),
        ("hist_counts", np.int64),
    )
    for key, dtype in reducible:
        if key in batch_metrics:
            vec_sum, width = _reduce_lists(batch_metrics[key], dtype)
            aggregated[key] = pa.FixedSizeListArray.from_arrays(pa.array(vec_sum), width)

    return aggregated

compute_batch_metric(features: dict[str, pa.Array]) -> dict[str, pa.Array]

Reduce a batch of embeddings into summary statistics.

Returns a dictionary containing
  • count: Number of samples.
  • sum: Element-wise sum of embeddings.
  • sum_sq: Element-wise sum of squared embeddings.
  • sum_outer: (Optional) Sum of outer products (for FID).
  • hist_counts: (Optional) Flattened histogram counts (for Wasserstein).
Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/domain_gap.py
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
@override
def compute_batch_metric(self, features: dict[str, pa.Array]) -> dict[str, pa.Array]:
    """
    Summarize one batch of embeddings into compact statistics.

    Returns a dictionary containing:
        - count: number of samples,
        - sum / sum_sq: element-wise sums of embeddings and their squares,
        - sum_outer: (optional) summed outer products, needed for FID,
        - hist_counts: (optional) flattened per-dimension histogram counts,
          needed for Wasserstein-1D.
    """
    if not getattr(self, "_checked", False):
        self.check_config()

    embeddings = features.get(self.embedding_col)
    if not isinstance(embeddings, pa.FixedSizeListArray):
        return {}

    n_rows = len(embeddings)
    dim = len(embeddings[0])
    matrix = np.asarray(embeddings.values.to_numpy()).reshape(n_rows, dim)

    summary: dict[str, pa.Array] = {
        "count": pa.array([n_rows], type=pa.int64()),
        "sum": pa.FixedSizeListArray.from_arrays(
            pa.array(matrix.sum(axis=0).astype(np.float64)), dim
        ),
        "sum_sq": pa.FixedSizeListArray.from_arrays(
            pa.array((matrix * matrix).sum(axis=0).astype(np.float64)), dim
        ),
    }

    # Outer products are only collected when the FID metric needs them.
    if self.collect_sum_outer:
        outer = (matrix.T @ matrix).reshape(-1).astype(np.float64)
        summary["sum_outer"] = pa.FixedSizeListArray.from_arrays(pa.array(outer), dim * dim)

    # Histograms are only collected for the Wasserstein-1D metric.
    if self.collect_hist_1d:
        n_dims = min(dim, self.hist_dims)
        counts = np.concatenate(
            [
                np.histogram(matrix[:, j], bins=self.hist_bins, range=self.hist_range)[0].astype(np.int64)
                for j in range(n_dims)
            ]
        )
        summary["hist_counts"] = pa.FixedSizeListArray.from_arrays(
            pa.array(counts), self.hist_bins * n_dims
        )

    return summary

compute_delta(source: dict[str, pa.Array], target: dict[str, pa.Array]) -> dict[str, pa.Array]

Calculate the domain gap metric between source and target dataselection statistics.

Parameters:

Name Type Description Default
source dict[str, Array]

Dataselection statistics from the source dataset (computed via compute).

required
target dict[str, Array]

Dataselection statistics from the target dataset (computed via compute).

required

Returns:

Type Description
dict[str, Array]

Dictionary containing the calculated metric value.

Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/domain_gap.py
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
@override
def compute_delta(self, source: dict[str, pa.Array], target: dict[str, pa.Array]) -> dict[str, pa.Array]:
    """
    Calculate the domain gap metric between source and target dataselection statistics.

    Args:
        source: Dataselection statistics from the source dataset (computed via `compute`).
        target: Dataselection statistics from the target dataset (computed via `compute`).

    Returns:
        Dictionary containing the calculated metric value.
    """
    if not getattr(self, "_checked", False):
        self.check_config()
    # TODO : check config and available metrics outside of computation
    # TODO : add a return code error in the API

    def vec(
        a: pa.FixedSizeListArray,
    ) -> Any:  # TODO : check type error np.ndarray
        """Aggregate a FixedSizeListArray into a single numpy vector by summing all lists."""
        len_a = len(a[0])
        array = np.asarray(a.values.to_numpy(), dtype=np.float64).reshape(-1, len_a).sum(axis=0)
        return array

    def scalar(a: pa.Array) -> float:
        """Sum all elements in a pyarrow Array and return as a float."""
        return float(np.asarray(a.to_numpy()).sum())

    metric = self.delta_metric

    if metric in {"klmvn_diag", "mmd_linear", "fid"}:
        need = {"count", "sum"}
        if metric in {"klmvn_diag", "fid"}:
            need |= {"sum_sq"}
        if metric == "fid":
            need |= {"sum_outer"}
        for side, name in ((source, "source"), (target, "target")):
            if not need.issubset(side.keys()):
                return {
                    "metric": pa.array([metric]),
                    "note": pa.array([f"missing keys in {name}: {sorted(need)}"]),
                }

        n1, n2 = scalar(source["count"]), scalar(target["count"])
        if n1 <= 0 or n2 <= 0:
            return {
                "metric": pa.array([metric]),
                "note": pa.array(["empty summaries"]),
            }

        mu1 = vec(source["sum"]) / n1
        mu2 = vec(target["sum"]) / n2

        if metric == "mmd_linear":
            # Linear-kernel MMD reduces to the squared distance between means.
            diff = mu1 - mu2
            val = float(np.dot(diff, diff))
            return {"mmd_linear": pa.array([val], type=pa.float64())}

        # Per-dimension variances, floored at 1e-9 to avoid division by zero.
        v1 = np.maximum(vec(source["sum_sq"]) / n1 - mu1 * mu1, 1e-9)
        v2 = np.maximum(vec(target["sum_sq"]) / n2 - mu2 * mu2, 1e-9)

        if metric == "klmvn_diag":
            # KL(N1 || N2) for diagonal Gaussians, summed over dimensions.
            term_var = np.sum(v1 / v2 - 1.0 - np.log(v1 / v2))
            term_mean = np.sum((mu2 - mu1) ** 2 / v2)
            val = 0.5 * (term_var + term_mean)
            return {"klmvn_diag": pa.array([float(val)], type=pa.float64())}

        if metric == "fid":
            so1 = vec(source["sum_outer"])
            so2 = vec(target["sum_outer"])
            d = int(np.sqrt(so1.size))
            s1 = (so1.reshape(d, d) / n1) - np.outer(mu1, mu1)
            s2 = (so2.reshape(d, d) / n2) - np.outer(mu2, mu2)
            from scipy.linalg import sqrtm

            diff = mu1 - mu2
            # The `disp` argument is deprecated and will be
            # removed in SciPy 1.18.0. The previously returned error estimate
            # can be computed as ``norm(X @ X - A, 'fro')**2 / norm(A, 'fro')``
            # covmean, _ = sqrtm(s1.dot(s2), disp=False)
            covmean = sqrtm(s1.dot(s2))

            # sqrtm may return tiny imaginary components from numerical noise.
            if np.iscomplexobj(covmean):
                covmean = covmean.real
            fid = diff.dot(diff) + np.trace(s1) + np.trace(s2) - 2 * np.trace(covmean)
            return {"fid": pa.array([float(abs(fid))], type=pa.float64())}

    if metric == "wasserstein_1d":
        if "hist_counts" not in source or "hist_counts" not in target:
            return {
                "metric": pa.array([metric]),
                "note": pa.array(["missing hist_counts"]),
            }
        h1 = np.asarray(source["hist_counts"].values.to_numpy(), dtype=np.int64)
        h2 = np.asarray(target["hist_counts"].values.to_numpy(), dtype=np.int64)
        # derive dims from summary config
        use_dims = self.hist_dims
        bins = self.hist_bins
        if h1.size != h2.size or h1.size != bins * use_dims:
            return {
                "metric": pa.array([metric]),
                "note": pa.array(["hist_counts length mismatch"]),
            }
        width = (self.hist_range[1] - self.hist_range[0]) / bins
        total = 0.0
        used = 0
        for j in range(use_dims):
            # BUG FIX: slice into fresh names instead of rebinding h1/h2.
            # The previous code overwrote the full histogram arrays on the
            # first iteration, so every dimension after j == 0 read an
            # empty/incorrect slice.
            c1 = h1[j * bins : (j + 1) * bins].astype(np.float64)
            c2 = h2[j * bins : (j + 1) * bins].astype(np.float64)
            if c1.sum() == 0 and c2.sum() == 0:
                continue
            p = c1 / max(1.0, c1.sum())
            q = c2 / max(1.0, c2.sum())
            cdf_p = np.cumsum(p)
            cdf_q = np.cumsum(q)
            # 1D Wasserstein distance = integral of |CDF_p - CDF_q|.
            total += float(np.sum(np.abs(cdf_p - cdf_q)) * width)
            used += 1
        val = total / max(1, used)
        return {"wasserstein_1d": pa.array([val], type=pa.float64())}

    return {
        "metric": pa.array([metric]),
        "note": pa.array(["unsupported metric or invalid inputs"]),
    }

generated_columns() -> list[str]

Return the list of columns generated by this processor.

Returns:

Type Description
list[str]

Empty list as this processor computes deltas rather than features.

Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/domain_gap.py
112
113
114
115
116
117
118
def generated_columns(self) -> list[str]:
    """List the feature columns produced by this processor.

    Returns:
        An empty list: this processor only computes delta metrics and
        adds no feature columns of its own.
    """
    return []

needed_columns() -> list[str]

Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/domain_gap.py
106
107
108
109
110
@override
def needed_columns(self) -> list[str]:
    if not getattr(self, "_checked", False):
        self.check_config()
    return [self.embedding_col]

ImageEmbeddingProcessor

Bases: DatametricProcessor

Computes high-dimensional latent vectors (embeddings) for images using deep learning models.

This processor uses PyTorch and Torchvision to: 1. Load images from bytes or file paths. 2. Preprocess images (resize, normalize) for the selected model. 3. Run batch inference using a pre-trained model (e.g., ResNet, ViT). 4. Extract features from a specific layer (e.g., 'avgpool').

The resulting embeddings are stored as a FixedSizeListArray in the features.

Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/image_embedding.py
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
class ImageEmbeddingProcessor(DatametricProcessor):
    """
    Computes high-dimensional latent vectors (embeddings) for images using deep learning models.

    This processor uses PyTorch and Torchvision to:
    1. Load images from bytes or file paths.
    2. Preprocess images (resize, normalize) for the selected model.
    3. Run batch inference using a pre-trained model (e.g., ResNet, ViT).
    4. Extract features from a specific layer (e.g., 'avgpool').

    The resulting embeddings are stored as a `FixedSizeListArray` in the features.
    """

    def __init__(
        self,
        name: str = "image_embedding",
        config: dict[str, Any] | None = None,
    ):
        """
        Initialize the image embedding processor.

        Args:
            name: Unique name of the processor instance.
            config: Configuration dictionary containing:
                - DATA:
                    - image_column: Column name containing image data (default: "image_bytes").
                    - mode: Source type, "bytes" or "path" (default: "bytes").
                - INFER:
                    - width, height: Input resolution for the model (default: 224x224).
                    - batch_size: Number of images per inference pass (default: 32).
                    - norm_mean, norm_std: Preprocessing normalization stats.
                - MODEL:
                    - arch: Torchvision model name (default: "resnet18").
                    - n_layer_feature: Target layer for feature extraction (default: "avgpool").
                    - device: Execution device, "cpu" or "cuda" (default: "cpu").
        """
        super().__init__(name, config)
        # Lazy-init flag: the (expensive) model/transform construction is
        # deferred to check_config(), which runs on first use.
        self._checked = False

    # ---------------- API ----------------
    def check_config(self) -> None:
        """Validate and initialize model/transforms from configuration.

        This method parses the configuration dictionary and initializes:
        - Image loading parameters (column name, mode, dataset root path)
        - Inference parameters (image size, batch size, normalization)
        - Model parameters (architecture, feature extraction layer, device)
        - Loads the pre-trained model and creates the feature extractor.
        """
        cfg = self.config or {}

        dcfg = cfg.get("DATA", {})
        self.image_column: str = dcfg.get("image_column", "image_bytes")
        self.mode: str = dcfg.get("mode", "bytes")  # "bytes" or "path"
        if self.mode not in {"bytes", "path"}:
            raise ValueError(f"[{self.name}] DATA.mode must be 'bytes' or 'path'")

        # handle relative paths in parquet to a dataset located at dataset_root_path;
        # the literal string "undefined" is a sentinel meaning "use paths as-is"
        self.dataset_root_path = str(cfg.get("dataset_root_path", "undefined"))
        logger.info(f"[ImageEmbeddingProcessor] dataset_root_path = '{self.dataset_root_path}'")

        icfg = cfg.get("INFER", {})
        # NOTE(review): self.size is assembled as (width, height) but
        # torchvision.transforms.Resize interprets a 2-tuple as (height, width).
        # Harmless for the square 224x224 default — confirm intent for
        # non-square configurations.
        self.size: tuple[int, int] = (
            int(icfg.get("width", 224)),
            int(icfg.get("height", 224)),
        )
        # Defaults are the standard ImageNet normalization statistics.
        mean = icfg.get("norm_mean", [0.485, 0.456, 0.406])
        std = icfg.get("norm_std", [0.229, 0.224, 0.225])
        self.batch_size: int = int(icfg.get("batch_size", 32))

        mcfg = cfg.get("MODEL", {})
        self.arch: str = mcfg.get("arch", "resnet18")
        # May be a layer name (str), an index into named_modules (int), or a
        # list of layer names — see _make_extractor.
        self.nodes = mcfg.get("n_layer_feature", "avgpool")
        self.device: str = mcfg.get("device", "cpu")

        # Build once: the preprocessing pipeline applied to every decoded image.
        self.transform = transforms.Compose(
            [
                transforms.Resize(self.size),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std),
            ]
        )
        self.model = self._load_model(self.arch, self.device)
        self.fx = self._make_extractor(self.model, self.nodes)
        # Embedding dimensionality is unknown until the first successful
        # inference; compute_features fills it in and caches it.
        self._embed_dim: int | None = None

        self._checked = True

    @override
    def needed_columns(self) -> list[str]:
        """Return the input columns required (the configured image column)."""
        if not getattr(self, "_checked", False):
            self.check_config()
        return [self.image_column]

    def generated_columns(self) -> list[str]:
        """Return the list of columns generated by this processor.

        Returns:
            A list containing 'embedding'.
        """
        return ["embedding"]

    @override
    def compute_features(self, batch: pa.RecordBatch, prev_features: pa.Array = None) -> dict[str, pa.Array]:
        """
        Extract image embeddings for all samples in the batch.

        1. Images are loaded and transformed.
        2. Model inference is performed in sub-batches defined by `INFER.batch_size`.
        3. Results are aggregated into a pyarrow `FixedSizeListArray`.

        Args:
            batch: Raw pyarrow batch.
            prev_features: Pre-computed features (not used).

        Returns:
            Dictionary mapping 'embedding' to the calculated feature vectors.
        """
        if not getattr(self, "_checked", False):
            self.check_config()
        if self.image_column not in batch.schema.names:
            logger.warning(f"[ImageEmbeddingProcessor] missing column '{self.image_column}'")
            return {}

        # 1. Load images. Rows that are null or fail to decode become None
        # placeholders so the output stays aligned with the input rows.
        vals = batch.column(self.image_column).to_pylist()
        imgs: list[torch.Tensor | None] = []
        for v in vals:
            if v is None:
                imgs.append(None)
                continue
            try:
                if self.mode == "bytes":
                    img = Image.open(io.BytesIO(v)).convert("RGB")
                else:
                    img_path = Path(self.dataset_root_path) / v if self.dataset_root_path != "undefined" else Path(v)
                    img = Image.open(img_path).convert("RGB")
                imgs.append(self.transform(img))
            except Exception as e:
                logger.warning(f"[ImageEmbeddingProcessor] failed to load image: {e}")
                imgs.append(None)

        # 2. Inference in windows of batch_size, preserving row order:
        # only non-None tensors are stacked, and results are mapped back
        # positionally onto the window.
        embs: list[np.ndarray | None] = []
        self.fx.eval()
        with torch.no_grad():
            i = 0
            while i < len(imgs):
                window = imgs[i : i + self.batch_size]
                valid = [t for t in window if t is not None]
                if valid:
                    bt = torch.stack(valid).to(self.device)
                    out = self.fx(bt)
                    if isinstance(out, dict):
                        # Multiple extraction nodes: flatten each output and
                        # concatenate along the feature axis.
                        flat_feats = [v.flatten(1) for v in out.values()]
                        feats = torch.cat(flat_feats, dim=1)  # TODO(review): verify the element type torch.cat receives here; the original carried a malformed "type : ignore"
                    else:
                        feats = out.flatten(1) if out.dim() > 2 else out
                    arr = feats.detach().cpu().numpy().astype("float32")
                    p = 0
                    for t in window:
                        if t is None:
                            embs.append(None)
                        else:
                            embs.append(arr[p])
                            p += 1
                else:
                    embs.extend([None] * len(window))
                i += self.batch_size

        # 3. Infer embedding dim from the first successful embedding; cached
        # across batches, so later batches are truncated/padded to match it.
        if self._embed_dim is None:
            for emb in embs:
                if emb is not None:
                    self._embed_dim = int(emb.size)
                    break
            if self._embed_dim is None:
                return {}
        d = self._embed_dim

        # 4. Build FixedSizeListArray. Failed/missing images become all-zero
        # vectors.
        # NOTE(review): zero vectors are indistinguishable from genuine
        # embeddings downstream — confirm this is acceptable for consumers.
        flat: list[float] = []
        for emb in embs:
            if emb is None:
                flat.extend([0.0] * d)
            else:
                v = emb.ravel()
                if v.size != d:
                    v = v[:d] if v.size > d else np.pad(v, (0, d - v.size))
                flat.extend(v.tolist())

        child = pa.array(np.asarray(flat, dtype=np.float32))
        return {"embedding": pa.FixedSizeListArray.from_arrays(child, d)}

    @override
    def compute_batch_metric(self, features: dict[str, pa.Array]) -> dict[str, pa.Array]:
        """
        Return an empty dictionary as embeddings are stored as features, we do not compute metrics.
        """
        return {}

    @override
    def compute(self, batch_metrics: dict[str, pa.Array]) -> dict[str, pa.Array]:
        """Compute final dataset-level metrics (not used for embeddings).

        Returns:
            Empty dictionary as embeddings are computed at feature level.
        """
        return {}

    @override
    def compute_delta(self, source: dict[str, pa.Array], target: dict[str, pa.Array]) -> dict[str, pa.Array]:
        """Compute delta between source and target embeddings (not used).

        Args:
            source: Source embeddings (not used).
            target: Target embeddings (not used).

        Returns:
            Empty dictionary as delta computation is handled by DomainGapProcessor.
        """
        return {}

    # utils functions
    def _load_model(self, arch: str, device: str) -> Any:
        """Load a pre-trained torchvision model.

        Args:
            arch: Model architecture name (e.g., 'resnet18', 'resnet50').
            device: Device to load the model on ('cpu' or 'cuda').

        Returns:
            The loaded PyTorch model.
        """
        try:
            m = torchvision.models.get_model(arch, weights="DEFAULT")
        except Exception:
            # Fallback for older torchvision versions that lack
            # models.get_model and still use the legacy pretrained= API.
            m = getattr(torchvision.models, arch)(pretrained=True)
        return m.to(device)

    def _make_extractor(self, model: torch.nn.Module, nodes: Any) -> Any:
        """Create a feature extractor from a model.

        Args:
            model: The PyTorch model to extract features from.
            nodes: Layer name (str), index (int), or list of names to extract.

        Returns:
            A feature extractor that returns the requested layer outputs.
        """
        names = list(dict(model.named_modules()).keys())
        if isinstance(nodes, list):
            # Multiple named nodes: each output keeps its own layer name.
            return create_feature_extractor(model, return_nodes={n: n for n in nodes})
        if isinstance(nodes, int):
            # Integer index into named_modules order; negatives count from the end.
            idx = nodes if nodes >= 0 else len(names) + nodes
            nodes = names[idx]
        return create_feature_extractor(model, return_nodes={nodes: "features"})

__init__(name: str = 'image_embedding', config: dict[str, Any] | None = None)

Initialize the image embedding processor.

Parameters:

Name Type Description Default
name str

Unique name of the processor instance.

'image_embedding'
config dict[str, Any] | None

Configuration dictionary containing: - DATA: - image_column: Column name containing image data (default: "image_bytes"). - mode: Source type, "bytes" or "path" (default: "bytes"). - INFER: - width, height: Input resolution for the model (default: 224x224). - batch_size: Number of images per inference pass (default: 32). - norm_mean, norm_std: Preprocessing normalization stats. - MODEL: - arch: Torchvision model name (default: "resnet18"). - n_layer_feature: Target layer for feature extraction (default: "avgpool"). - device: Execution device, "cpu" or "cuda" (default: "cpu").

None
Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/image_embedding.py
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
def __init__(
    self,
    name: str = "image_embedding",
    config: dict[str, Any] | None = None,
):
    """
    Create an image embedding processor.

    Heavy initialization (model download, transform construction) is
    deferred to ``check_config``, which runs lazily on first use.

    Args:
        name: Unique name of the processor instance.
        config: Configuration dictionary containing:
            - DATA:
                - image_column: Column name containing image data (default: "image_bytes").
                - mode: Source type, "bytes" or "path" (default: "bytes").
            - INFER:
                - width, height: Input resolution for the model (default: 224x224).
                - batch_size: Number of images per inference pass (default: 32).
                - norm_mean, norm_std: Preprocessing normalization stats.
            - MODEL:
                - arch: Torchvision model name (default: "resnet18").
                - n_layer_feature: Target layer for feature extraction (default: "avgpool").
                - device: Execution device, "cpu" or "cuda" (default: "cpu").
    """
    super().__init__(name, config)
    # Flag consulted by the lazy check_config() guard.
    self._checked: bool = False

check_config() -> None

Validate and initialize model/transforms from configuration.

This method parses the configuration dictionary and initializes: - Image loading parameters (column name, mode, dataset root path) - Inference parameters (image size, batch size, normalization) - Model parameters (architecture, feature extraction layer, device) - Loads the pre-trained model and creates the feature extractor.

Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/image_embedding.py
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
def check_config(self) -> None:
    """Validate and initialize model/transforms from configuration.

    This method parses the configuration dictionary and initializes:
    - Image loading parameters (column name, mode, dataset root path)
    - Inference parameters (image size, batch size, normalization)
    - Model parameters (architecture, feature extraction layer, device)
    - Loads the pre-trained model and creates the feature extractor.
    """
    cfg = self.config or {}

    dcfg = cfg.get("DATA", {})
    self.image_column: str = dcfg.get("image_column", "image_bytes")
    self.mode: str = dcfg.get("mode", "bytes")  # "bytes" or "path"
    if self.mode not in {"bytes", "path"}:
        raise ValueError(f"[{self.name}] DATA.mode must be 'bytes' or 'path'")

    # handle relative paths in parquet to a dataset located at dataset_root_path;
    # the literal "undefined" is a sentinel meaning "use paths as-is"
    self.dataset_root_path = str(cfg.get("dataset_root_path", "undefined"))
    logger.info(f"[ImageEmbeddingProcessor] dataset_root_path = '{self.dataset_root_path}'")

    icfg = cfg.get("INFER", {})
    # NOTE(review): built as (width, height), but torchvision Resize reads a
    # 2-tuple as (height, width) — harmless when square; confirm otherwise.
    self.size: tuple[int, int] = (
        int(icfg.get("width", 224)),
        int(icfg.get("height", 224)),
    )
    # Defaults are the standard ImageNet normalization statistics.
    mean = icfg.get("norm_mean", [0.485, 0.456, 0.406])
    std = icfg.get("norm_std", [0.229, 0.224, 0.225])
    self.batch_size: int = int(icfg.get("batch_size", 32))

    mcfg = cfg.get("MODEL", {})
    self.arch: str = mcfg.get("arch", "resnet18")
    # Layer name (str), index (int), or list of names — resolved by _make_extractor.
    self.nodes = mcfg.get("n_layer_feature", "avgpool")
    self.device: str = mcfg.get("device", "cpu")

    # Build once: preprocessing pipeline, model, and feature extractor.
    self.transform = transforms.Compose(
        [
            transforms.Resize(self.size),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
    )
    self.model = self._load_model(self.arch, self.device)
    self.fx = self._make_extractor(self.model, self.nodes)
    # Unknown until the first successful inference (set in compute_features).
    self._embed_dim: int | None = None

    self._checked = True

compute(batch_metrics: dict[str, pa.Array]) -> dict[str, pa.Array]

Compute final dataset-level metrics (not used for embeddings).

Returns:

Type Description
dict[str, Array]

Empty dictionary as embeddings are computed at feature level.

Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/image_embedding.py
233
234
235
236
237
238
239
240
@override
def compute(self, batch_metrics: dict[str, pa.Array]) -> dict[str, pa.Array]:
    """Compute final dataset-level metrics (not used for embeddings).

    Returns:
        Empty dictionary as embeddings are computed at feature level.
    """
    return {}

compute_batch_metric(features: dict[str, pa.Array]) -> dict[str, pa.Array]

Return an empty dictionary as embeddings are stored as features, we do not compute metrics.

Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/image_embedding.py
226
227
228
229
230
231
@override
def compute_batch_metric(self, features: dict[str, pa.Array]) -> dict[str, pa.Array]:
    """
    Return an empty dictionary as embeddings are stored as features, we do not compute metrics.
    """
    return {}

compute_delta(source: dict[str, pa.Array], target: dict[str, pa.Array]) -> dict[str, pa.Array]

Compute delta between source and target embeddings (not used).

Parameters:

Name Type Description Default
source dict[str, Array]

Source embeddings (not used).

required
target dict[str, Array]

Target embeddings (not used).

required

Returns:

Type Description
dict[str, Array]

Empty dictionary as delta computation is handled by DomainGapProcessor.

Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/image_embedding.py
242
243
244
245
246
247
248
249
250
251
252
253
@override
def compute_delta(self, source: dict[str, pa.Array], target: dict[str, pa.Array]) -> dict[str, pa.Array]:
    """Compute delta between source and target embeddings (not used).

    Args:
        source: Source embeddings (not used).
        target: Target embeddings (not used).

    Returns:
        Empty dictionary as delta computation is handled by DomainGapProcessor.
    """
    return {}

compute_features(batch: pa.RecordBatch, prev_features: pa.Array = None) -> dict[str, pa.Array]

Extract image embeddings for all samples in the batch.

  1. Images are loaded and transformed.
  2. Model inference is performed in sub-batches defined by INFER.batch_size.
  3. Results are aggregated into a pyarrow FixedSizeListArray.

Parameters:

Name Type Description Default
batch RecordBatch

Raw pyarrow batch.

required
prev_features Array

Pre-computed features (not used).

None

Returns:

Type Description
dict[str, Array]

Dictionary mapping 'embedding' to the calculated feature vectors.

Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/image_embedding.py
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
@override
def compute_features(self, batch: pa.RecordBatch, prev_features: pa.Array = None) -> dict[str, pa.Array]:
    """
    Extract image embeddings for all samples in the batch.

    1. Images are loaded and transformed.
    2. Model inference is performed in sub-batches defined by `INFER.batch_size`.
    3. Results are aggregated into a pyarrow `FixedSizeListArray`.

    Args:
        batch: Raw pyarrow batch.
        prev_features: Pre-computed features (not used).

    Returns:
        Dictionary mapping 'embedding' to the calculated feature vectors.
    """
    if not getattr(self, "_checked", False):
        self.check_config()
    if self.image_column not in batch.schema.names:
        logger.warning(f"[ImageEmbeddingProcessor] missing column '{self.image_column}'")
        return {}

    # 1. Load images; rows that are null or fail to decode become None
    # placeholders so the output stays aligned with the input rows.
    vals = batch.column(self.image_column).to_pylist()
    imgs: list[torch.Tensor | None] = []
    for v in vals:
        if v is None:
            imgs.append(None)
            continue
        try:
            if self.mode == "bytes":
                img = Image.open(io.BytesIO(v)).convert("RGB")
            else:
                img_path = Path(self.dataset_root_path) / v if self.dataset_root_path != "undefined" else Path(v)
                img = Image.open(img_path).convert("RGB")
            imgs.append(self.transform(img))
        except Exception as e:
            logger.warning(f"[ImageEmbeddingProcessor] failed to load image: {e}")
            imgs.append(None)

    # 2. Inference in windows of batch_size, preserving row order: only
    # non-None tensors are stacked; results map back positionally.
    embs: list[np.ndarray | None] = []
    self.fx.eval()
    with torch.no_grad():
        i = 0
        while i < len(imgs):
            window = imgs[i : i + self.batch_size]
            valid = [t for t in window if t is not None]
            if valid:
                bt = torch.stack(valid).to(self.device)
                out = self.fx(bt)
                if isinstance(out, dict):
                    # Multiple extraction nodes: flatten and concatenate
                    # along the feature axis.
                    flat_feats = [v.flatten(1) for v in out.values()]
                    feats = torch.cat(flat_feats, dim=1)  # TODO(review): verify element type for torch.cat; original carried a malformed "type : ignore"
                else:
                    feats = out.flatten(1) if out.dim() > 2 else out
                arr = feats.detach().cpu().numpy().astype("float32")
                p = 0
                for t in window:
                    if t is None:
                        embs.append(None)
                    else:
                        embs.append(arr[p])
                        p += 1
            else:
                embs.extend([None] * len(window))
            i += self.batch_size

    # 3. Infer embedding dim from the first successful embedding; cached
    # across batches, so later batches are truncated/padded to match it.
    if self._embed_dim is None:
        for emb in embs:
            if emb is not None:
                self._embed_dim = int(emb.size)
                break
        if self._embed_dim is None:
            return {}
    d = self._embed_dim

    # 4. Build FixedSizeListArray; failed/missing images become all-zero
    # vectors.
    # NOTE(review): zero vectors are indistinguishable from genuine
    # embeddings downstream — confirm this is acceptable for consumers.
    flat: list[float] = []
    for emb in embs:
        if emb is None:
            flat.extend([0.0] * d)
        else:
            v = emb.ravel()
            if v.size != d:
                v = v[:d] if v.size > d else np.pad(v, (0, d - v.size))
            flat.extend(v.tolist())

    child = pa.array(np.asarray(flat, dtype=np.float32))
    return {"embedding": pa.FixedSizeListArray.from_arrays(child, d)}

generated_columns() -> list[str]

Return the list of columns generated by this processor.

Returns:

Type Description
list[str]

A list containing 'embedding'.

Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/image_embedding.py
126
127
128
129
130
131
132
def generated_columns(self) -> list[str]:
    """Name the feature columns this processor adds.

    Returns:
        ``["embedding"]`` — the single fixed-size embedding column.
    """
    return ["embedding"]

needed_columns() -> list[str]

Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/image_embedding.py
120
121
122
123
124
@override
def needed_columns(self) -> list[str]:
    if not getattr(self, "_checked", False):
        self.check_config()
    return [self.image_column]