Skip to content

dqm_ml_pytorch.image_embedding

Image embedding processor using pre-trained deep learning models.

This module contains the ImageEmbeddingProcessor class that extracts high-dimensional embeddings from images using PyTorch and torchvision pre-trained models.

logger = logging.getLogger(__name__) module-attribute

ImageEmbeddingProcessor

Bases: DatametricProcessor

Computes high-dimensional latent vectors (embeddings) for images using deep learning models.

This processor uses PyTorch and Torchvision to:

1. Load images from bytes or file paths.
2. Preprocess images (resize, normalize) for the selected model.
3. Run batch inference using a pre-trained model (e.g., ResNet, ViT).
4. Extract features from a specific layer (e.g., 'avgpool').

The resulting embeddings are stored as a FixedSizeListArray in the features.

Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/image_embedding.py
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
class ImageEmbeddingProcessor(DatametricProcessor):
    """
    Computes high-dimensional latent vectors (embeddings) for images using deep learning models.

    This processor uses PyTorch and Torchvision to:
    1. Load images from bytes or file paths.
    2. Preprocess images (resize, normalize) for the selected model.
    3. Run batch inference using a pre-trained model (e.g., ResNet, ViT).
    4. Extract features from a specific layer (e.g., 'avgpool').

    The resulting embeddings are stored as a `FixedSizeListArray` in the features.
    """

    def __init__(
        self,
        name: str = "image_embedding",
        config: dict[str, Any] | None = None,
    ):
        """
        Initialize the image embedding processor.

        Args:
            name: Unique name of the processor instance.
            config: Configuration dictionary containing:
                - DATA:
                    - image_column: Column name containing image data (default: "image_bytes").
                    - mode: Source type, "bytes" or "path" (default: "bytes").
                - INFER:
                    - width, height: Input resolution for the model (default: 224x224).
                    - batch_size: Number of images per inference pass (default: 32).
                    - norm_mean, norm_std: Preprocessing normalization stats.
                - MODEL:
                    - arch: Torchvision model name (default: "resnet18").
                    - n_layer_feature: Target layer for feature extraction (default: "avgpool").
                    - device: Execution device, "cpu" or "cuda" (default: "cpu").
        """
        super().__init__(name, config)
        # Model/transform construction is deferred to check_config() (lazy init).
        self._checked = False

    # ---------------- API ----------------
    def check_config(self) -> None:
        """Validate and initialize model/transforms from configuration.

        This method parses the configuration dictionary and initializes:
        - Image loading parameters (column name, mode, dataset root path)
        - Inference parameters (image size, batch size, normalization)
        - Model parameters (architecture, feature extraction layer, device)
        - Loads the pre-trained model and creates the feature extractor.

        Raises:
            ValueError: If DATA.mode is neither 'bytes' nor 'path'.
        """
        cfg = self.config or {}

        dcfg = cfg.get("DATA", {})
        self.image_column: str = dcfg.get("image_column", "image_bytes")
        self.mode: str = dcfg.get("mode", "bytes")  # "bytes" or "path"
        if self.mode not in {"bytes", "path"}:
            raise ValueError(f"[{self.name}] DATA.mode must be 'bytes' or 'path'")

        # handle relative paths in parquet to a dataset located at dataset_root_path
        self.dataset_root_path = str(cfg.get("dataset_root_path", "undefined"))
        logger.info(f"[ImageEmbeddingProcessor] dataset_root_path = '{self.dataset_root_path}'")

        icfg = cfg.get("INFER", {})
        # NOTE(review): stored as (width, height) but torchvision's Resize interprets
        # a 2-sequence as (height, width) — for non-square configs the axes swap; confirm intent.
        self.size: tuple[int, int] = (
            int(icfg.get("width", 224)),
            int(icfg.get("height", 224)),
        )
        # Defaults are the standard ImageNet normalization statistics.
        mean = icfg.get("norm_mean", [0.485, 0.456, 0.406])
        std = icfg.get("norm_std", [0.229, 0.224, 0.225])
        self.batch_size: int = int(icfg.get("batch_size", 32))

        mcfg = cfg.get("MODEL", {})
        self.arch: str = mcfg.get("arch", "resnet18")
        self.nodes = mcfg.get("n_layer_feature", "avgpool")
        self.device: str = mcfg.get("device", "cpu")

        # Build once; reused across all compute_features() calls.
        self.transform = transforms.Compose(
            [
                transforms.Resize(self.size),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std),
            ]
        )
        self.model = self._load_model(self.arch, self.device)
        self.fx = self._make_extractor(self.model, self.nodes)
        # Embedding dimensionality is unknown until the first successful inference.
        self._embed_dim: int | None = None

        self._checked = True

    @override
    def needed_columns(self) -> list[str]:
        """Return the input columns this processor reads (the configured image column)."""
        if not getattr(self, "_checked", False):
            self.check_config()
        return [self.image_column]

    def generated_columns(self) -> list[str]:
        """Return the list of columns generated by this processor.

        Returns:
            A list containing 'embedding'.
        """
        return ["embedding"]

    @override
    def compute_features(self, batch: pa.RecordBatch, prev_features: pa.Array | None = None) -> dict[str, pa.Array]:
        """
        Extract image embeddings for all samples in the batch.

        1. Images are loaded and transformed.
        2. Model inference is performed in sub-batches defined by `INFER.batch_size`.
        3. Results are aggregated into a pyarrow `FixedSizeListArray`.

        Rows whose image is missing or unreadable are emitted as zero vectors.

        Args:
            batch: Raw pyarrow batch.
            prev_features: Pre-computed features (not used).

        Returns:
            Dictionary mapping 'embedding' to the calculated feature vectors.
        """
        if not getattr(self, "_checked", False):
            self.check_config()
        if self.image_column not in batch.schema.names:
            logger.warning(f"[ImageEmbeddingProcessor] missing column '{self.image_column}'")
            return {}

        # 1. Load and preprocess images; None placeholders keep row order aligned.
        vals = batch.column(self.image_column).to_pylist()
        imgs: list[torch.Tensor | None] = []
        for v in vals:
            if v is None:
                imgs.append(None)
                continue
            try:
                if self.mode == "bytes":
                    img = Image.open(io.BytesIO(v)).convert("RGB")
                else:
                    # Resolve relative paths against dataset_root_path when configured.
                    img_path = Path(self.dataset_root_path) / v if self.dataset_root_path != "undefined" else Path(v)
                    img = Image.open(img_path).convert("RGB")
                imgs.append(self.transform(img))
            except Exception as e:
                # Best-effort: an unreadable image becomes a None placeholder, not a failure.
                logger.warning(f"[ImageEmbeddingProcessor] failed to load image: {e}")
                imgs.append(None)

        # 2. Inference in windows of batch_size, preserving row order.
        embs: list[np.ndarray | None] = []
        self.fx.eval()
        with torch.no_grad():
            i = 0
            while i < len(imgs):
                window = imgs[i : i + self.batch_size]
                valid = [t for t in window if t is not None]
                if valid:
                    bt = torch.stack(valid).to(self.device)
                    out = self.fx(bt)
                    if isinstance(out, dict):
                        # Multiple extraction nodes: flatten each and concatenate along features.
                        flat_feats = [v.flatten(1) for v in out.values()]
                        feats = torch.cat(flat_feats, dim=1)
                    else:
                        feats = out.flatten(1) if out.dim() > 2 else out
                    arr = feats.detach().cpu().numpy().astype("float32")
                    # Re-interleave inference results with the None placeholders of this window.
                    p = 0
                    for t in window:
                        if t is None:
                            embs.append(None)
                        else:
                            embs.append(arr[p])
                            p += 1
                else:
                    embs.extend([None] * len(window))
                i += self.batch_size

        # 3. Infer embedding dim
        if self._embed_dim is None:
            for emb in embs:
                if emb is not None:
                    self._embed_dim = int(emb.size)
                    break
            if self._embed_dim is None:
                # No image decoded successfully: nothing to emit for this batch.
                return {}
        d = self._embed_dim

        # 4. Build FixedSizeListArray
        flat: list[float] = []
        for emb in embs:
            if emb is None:
                # Missing images are encoded as zero vectors of the embedding dim.
                flat.extend([0.0] * d)
            else:
                v = emb.ravel()
                if v.size != d:
                    # Defensive: truncate/pad if the model ever yields a different size.
                    v = v[:d] if v.size > d else np.pad(v, (0, d - v.size))
                flat.extend(v.tolist())

        child = pa.array(np.asarray(flat, dtype=np.float32))
        return {"embedding": pa.FixedSizeListArray.from_arrays(child, d)}

    @override
    def compute_batch_metric(self, features: dict[str, pa.Array]) -> dict[str, pa.Array]:
        """
        Return an empty dictionary as embeddings are stored as features, we do not compute metrics.
        """
        return {}

    @override
    def compute(self, batch_metrics: dict[str, pa.Array]) -> dict[str, pa.Array]:
        """Compute final dataset-level metrics (not used for embeddings).

        Returns:
            Empty dictionary as embeddings are computed at feature level.
        """
        return {}

    @override
    def compute_delta(self, source: dict[str, pa.Array], target: dict[str, pa.Array]) -> dict[str, pa.Array]:
        """Compute delta between source and target embeddings (not used).

        Args:
            source: Source embeddings (not used).
            target: Target embeddings (not used).

        Returns:
            Empty dictionary as delta computation is handled by DomainGapProcessor.
        """
        return {}

    # utils functions
    def _load_model(self, arch: str, device: str) -> Any:
        """Load a pre-trained torchvision model.

        Args:
            arch: Model architecture name (e.g., 'resnet18', 'resnet50').
            device: Device to load the model on ('cpu' or 'cuda').

        Returns:
            The loaded PyTorch model.
        """
        try:
            m = torchvision.models.get_model(arch, weights="DEFAULT")
        except Exception:
            # Fallback for older torchvision without models.get_model();
            # 'pretrained=True' is the legacy (deprecated) weights API.
            m = getattr(torchvision.models, arch)(pretrained=True)
        return m.to(device)

    def _make_extractor(self, model: torch.nn.Module, nodes: Any) -> Any:
        """Create a feature extractor from a model.

        Args:
            model: The PyTorch model to extract features from.
            nodes: Layer name (str), index (int), or list of names to extract.

        Returns:
            A feature extractor that returns the requested layer outputs.
        """
        names = list(dict(model.named_modules()).keys())
        if isinstance(nodes, list):
            # Multiple named layers: keep each output under its own name.
            return create_feature_extractor(model, return_nodes={n: n for n in nodes})
        if isinstance(nodes, int):
            # Integer index (possibly negative) into the module name list.
            idx = nodes if nodes >= 0 else len(names) + nodes
            nodes = names[idx]
        return create_feature_extractor(model, return_nodes={nodes: "features"})

__init__(name: str = 'image_embedding', config: dict[str, Any] | None = None)

Initialize the image embedding processor.

Parameters:

Name Type Description Default
name str

Unique name of the processor instance.

'image_embedding'
config dict[str, Any] | None

Configuration dictionary containing: - DATA: - image_column: Column name containing image data (default: "image_bytes"). - mode: Source type, "bytes" or "path" (default: "bytes"). - INFER: - width, height: Input resolution for the model (default: 224x224). - batch_size: Number of images per inference pass (default: 32). - norm_mean, norm_std: Preprocessing normalization stats. - MODEL: - arch: Torchvision model name (default: "resnet18"). - n_layer_feature: Target layer for feature extraction (default: "avgpool"). - device: Execution device, "cpu" or "cuda" (default: "cpu").

None
Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/image_embedding.py
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
def __init__(
    self,
    name: str = "image_embedding",
    config: dict[str, Any] | None = None,
):
    """
    Initialize the image embedding processor.

    Args:
        name: Unique name of the processor instance.
        config: Configuration dictionary containing:
            - DATA:
                - image_column: Column name containing image data (default: "image_bytes").
                - mode: Source type, "bytes" or "path" (default: "bytes").
            - INFER:
                - width, height: Input resolution for the model (default: 224x224).
                - batch_size: Number of images per inference pass (default: 32).
                - norm_mean, norm_std: Preprocessing normalization stats.
            - MODEL:
                - arch: Torchvision model name (default: "resnet18").
                - n_layer_feature: Target layer for feature extraction (default: "avgpool").
                - device: Execution device, "cpu" or "cuda" (default: "cpu").
    """
    super().__init__(name, config)
    # Model/transform construction is deferred to check_config() (lazy init).
    self._checked = False
check_config() -> None

Validate and initialize model/transforms from configuration.

This method parses the configuration dictionary and initializes:

- Image loading parameters (column name, mode, dataset root path)
- Inference parameters (image size, batch size, normalization)
- Model parameters (architecture, feature extraction layer, device)
- Loads the pre-trained model and creates the feature extractor.

Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/image_embedding.py
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
def check_config(self) -> None:
    """Validate and initialize model/transforms from configuration.

    This method parses the configuration dictionary and initializes:
    - Image loading parameters (column name, mode, dataset root path)
    - Inference parameters (image size, batch size, normalization)
    - Model parameters (architecture, feature extraction layer, device)
    - Loads the pre-trained model and creates the feature extractor.

    Raises:
        ValueError: If DATA.mode is neither 'bytes' nor 'path'.
    """
    cfg = self.config or {}

    dcfg = cfg.get("DATA", {})
    self.image_column: str = dcfg.get("image_column", "image_bytes")
    self.mode: str = dcfg.get("mode", "bytes")  # "bytes" or "path"
    if self.mode not in {"bytes", "path"}:
        raise ValueError(f"[{self.name}] DATA.mode must be 'bytes' or 'path'")

    # handle relative paths in parquet to a dataset located at dataset_root_path
    self.dataset_root_path = str(cfg.get("dataset_root_path", "undefined"))
    logger.info(f"[ImageEmbeddingProcessor] dataset_root_path = '{self.dataset_root_path}'")

    icfg = cfg.get("INFER", {})
    # NOTE(review): stored as (width, height) but torchvision's Resize interprets
    # a 2-sequence as (height, width) — for non-square configs the axes swap; confirm intent.
    self.size: tuple[int, int] = (
        int(icfg.get("width", 224)),
        int(icfg.get("height", 224)),
    )
    # Defaults are the standard ImageNet normalization statistics.
    mean = icfg.get("norm_mean", [0.485, 0.456, 0.406])
    std = icfg.get("norm_std", [0.229, 0.224, 0.225])
    self.batch_size: int = int(icfg.get("batch_size", 32))

    mcfg = cfg.get("MODEL", {})
    self.arch: str = mcfg.get("arch", "resnet18")
    self.nodes = mcfg.get("n_layer_feature", "avgpool")
    self.device: str = mcfg.get("device", "cpu")

    # Build once; reused across all compute_features() calls.
    self.transform = transforms.Compose(
        [
            transforms.Resize(self.size),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
    )
    self.model = self._load_model(self.arch, self.device)
    self.fx = self._make_extractor(self.model, self.nodes)
    # Embedding dimensionality is unknown until the first successful inference.
    self._embed_dim: int | None = None

    self._checked = True

compute(batch_metrics: dict[str, pa.Array]) -> dict[str, pa.Array]

Compute final dataset-level metrics (not used for embeddings).

Returns:

Type Description
dict[str, Array]

Empty dictionary as embeddings are computed at feature level.

Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/image_embedding.py
233
234
235
236
237
238
239
240
@override
def compute(self, batch_metrics: dict[str, pa.Array]) -> dict[str, pa.Array]:
    """Compute final dataset-level metrics (not used for embeddings).

    Args:
        batch_metrics: Accumulated batch-level metrics (ignored).

    Returns:
        Empty dictionary as embeddings are computed at feature level.
    """
    return {}

compute_batch_metric(features: dict[str, pa.Array]) -> dict[str, pa.Array]

Return an empty dictionary as embeddings are stored as features, we do not compute metrics.

Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/image_embedding.py
226
227
228
229
230
231
@override
def compute_batch_metric(self, features: dict[str, pa.Array]) -> dict[str, pa.Array]:
    """
    Return an empty dictionary as embeddings are stored as features, we do not compute metrics.

    Args:
        features: Computed feature arrays (ignored).

    Returns:
        Empty dictionary; no batch-level metrics are produced.
    """
    return {}

compute_delta(source: dict[str, pa.Array], target: dict[str, pa.Array]) -> dict[str, pa.Array]

Compute delta between source and target embeddings (not used).

Parameters:

Name Type Description Default
source dict[str, Array]

Source embeddings (not used).

required
target dict[str, Array]

Target embeddings (not used).

required

Returns:

Type Description
dict[str, Array]

Empty dictionary as delta computation is handled by DomainGapProcessor.

Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/image_embedding.py
242
243
244
245
246
247
248
249
250
251
252
253
@override
def compute_delta(self, source: dict[str, pa.Array], target: dict[str, pa.Array]) -> dict[str, pa.Array]:
    """Compute delta between source and target embeddings (not used).

    Part of the DatametricProcessor interface; embedding comparison happens downstream.

    Args:
        source: Source embeddings (not used).
        target: Target embeddings (not used).

    Returns:
        Empty dictionary as delta computation is handled by DomainGapProcessor.
    """
    return {}

compute_features(batch: pa.RecordBatch, prev_features: pa.Array = None) -> dict[str, pa.Array]

Extract image embeddings for all samples in the batch.

  1. Images are loaded and transformed.
  2. Model inference is performed in sub-batches defined by INFER.batch_size.
  3. Results are aggregated into a pyarrow FixedSizeListArray.

Parameters:

Name Type Description Default
batch RecordBatch

Raw pyarrow batch.

required
prev_features Array

Pre-computed features (not used).

None

Returns:

Type Description
dict[str, Array]

Dictionary mapping 'embedding' to the calculated feature vectors.

Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/image_embedding.py
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
@override
def compute_features(self, batch: pa.RecordBatch, prev_features: pa.Array | None = None) -> dict[str, pa.Array]:
    """
    Extract image embeddings for all samples in the batch.

    1. Images are loaded and transformed.
    2. Model inference is performed in sub-batches defined by `INFER.batch_size`.
    3. Results are aggregated into a pyarrow `FixedSizeListArray`.

    Rows whose image is missing or unreadable are emitted as zero vectors.

    Args:
        batch: Raw pyarrow batch.
        prev_features: Pre-computed features (not used).

    Returns:
        Dictionary mapping 'embedding' to the calculated feature vectors.
    """
    if not getattr(self, "_checked", False):
        self.check_config()
    if self.image_column not in batch.schema.names:
        logger.warning(f"[ImageEmbeddingProcessor] missing column '{self.image_column}'")
        return {}

    # 1. Load and preprocess images; None placeholders keep row order aligned.
    vals = batch.column(self.image_column).to_pylist()
    imgs: list[torch.Tensor | None] = []
    for v in vals:
        if v is None:
            imgs.append(None)
            continue
        try:
            if self.mode == "bytes":
                img = Image.open(io.BytesIO(v)).convert("RGB")
            else:
                # Resolve relative paths against dataset_root_path when configured.
                img_path = Path(self.dataset_root_path) / v if self.dataset_root_path != "undefined" else Path(v)
                img = Image.open(img_path).convert("RGB")
            imgs.append(self.transform(img))
        except Exception as e:
            # Best-effort: an unreadable image becomes a None placeholder, not a failure.
            logger.warning(f"[ImageEmbeddingProcessor] failed to load image: {e}")
            imgs.append(None)

    # 2. Inference in windows of batch_size, preserving row order.
    embs: list[np.ndarray | None] = []
    self.fx.eval()
    with torch.no_grad():
        i = 0
        while i < len(imgs):
            window = imgs[i : i + self.batch_size]
            valid = [t for t in window if t is not None]
            if valid:
                bt = torch.stack(valid).to(self.device)
                out = self.fx(bt)
                if isinstance(out, dict):
                    # Multiple extraction nodes: flatten each and concatenate along features.
                    flat_feats = [v.flatten(1) for v in out.values()]
                    feats = torch.cat(flat_feats, dim=1)
                else:
                    feats = out.flatten(1) if out.dim() > 2 else out
                arr = feats.detach().cpu().numpy().astype("float32")
                # Re-interleave inference results with the None placeholders of this window.
                p = 0
                for t in window:
                    if t is None:
                        embs.append(None)
                    else:
                        embs.append(arr[p])
                        p += 1
            else:
                embs.extend([None] * len(window))
            i += self.batch_size

    # 3. Infer embedding dim
    if self._embed_dim is None:
        for emb in embs:
            if emb is not None:
                self._embed_dim = int(emb.size)
                break
        if self._embed_dim is None:
            # No image decoded successfully: nothing to emit for this batch.
            return {}
    d = self._embed_dim

    # 4. Build FixedSizeListArray
    flat: list[float] = []
    for emb in embs:
        if emb is None:
            # Missing images are encoded as zero vectors of the embedding dim.
            flat.extend([0.0] * d)
        else:
            v = emb.ravel()
            if v.size != d:
                # Defensive: truncate/pad if the model ever yields a different size.
                v = v[:d] if v.size > d else np.pad(v, (0, d - v.size))
            flat.extend(v.tolist())

    child = pa.array(np.asarray(flat, dtype=np.float32))
    return {"embedding": pa.FixedSizeListArray.from_arrays(child, d)}

generated_columns() -> list[str]

Return the list of columns generated by this processor.

Returns:

Type Description
list[str]

A list containing 'embedding'.

Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/image_embedding.py
126
127
128
129
130
131
132
def generated_columns(self) -> list[str]:
    """List the output columns this processor adds to the feature set.

    Returns:
        A single-element list containing 'embedding'.
    """
    output_names = ["embedding"]
    return output_names

needed_columns() -> list[str]

Source code in packages/dqm-ml-pytorch/src/dqm_ml_pytorch/image_embedding.py
120
121
122
123
124
@override
def needed_columns(self) -> list[str]:
    """Return the input columns this processor reads (the configured image column).

    Triggers lazy configuration via check_config() if it has not run yet.
    """
    if not getattr(self, "_checked", False):
        self.check_config()
    return [self.image_column]