Coverage for packages / dqm-ml-pytorch / src / dqm_ml_pytorch / image_embedding.py: 80%
135 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-15 10:11 +0000
1"""Image embedding processor using pre-trained deep learning models.
3This module contains the ImageEmbeddingProcessor class that extracts
4high-dimensional embeddings from images using PyTorch and torchvision
5pre-trained models.
6"""
8from __future__ import annotations
10import io
11import logging
12from pathlib import Path
13from typing import Any
15import numpy as np
16from PIL import Image
17import pyarrow as pa
18import torch
19import torchvision
20from torchvision import transforms
21from torchvision.models.feature_extraction import create_feature_extractor
23# COMPATIBILITY : from typing import Any, override # When support of 3.10 and 3.11 will be removed
24from typing_extensions import override
26from dqm_ml_core import DatametricProcessor
28logger = logging.getLogger(__name__)
class ImageEmbeddingProcessor(DatametricProcessor):
    """
    Computes high-dimensional latent vectors (embeddings) for images using deep learning models.

    This processor uses PyTorch and Torchvision to:
    1. Load images from bytes or file paths.
    2. Preprocess images (resize, normalize) for the selected model.
    3. Run batch inference using a pre-trained model (e.g., ResNet, ViT).
    4. Extract features from a specific layer (e.g., 'avgpool').

    The resulting embeddings are stored as a `FixedSizeListArray` in the features.
    """

    def __init__(
        self,
        name: str = "image_embedding",
        config: dict[str, Any] | None = None,
    ):
        """
        Initialize the image embedding processor.

        Args:
            name: Unique name of the processor instance.
            config: Configuration dictionary containing:
                - DATA:
                    - image_column: Column name containing image data (default: "image_bytes").
                    - mode: Source type, "bytes" or "path" (default: "bytes").
                - INFER:
                    - width, height: Input resolution for the model (default: 224x224).
                    - batch_size: Number of images per inference pass (default: 32).
                    - norm_mean, norm_std: Preprocessing normalization stats.
                - MODEL:
                    - arch: Torchvision model name (default: "resnet18").
                    - n_layer_feature: Target layer for feature extraction (default: "avgpool").
                    - device: Execution device, "cpu" or "cuda" (default: "cpu").
        """
        super().__init__(name, config)
        # Lazy initialization flag: model/transforms are built on first use via check_config().
        self._checked = False

    # ---------------- API ----------------
    def check_config(self) -> None:
        """Validate and initialize model/transforms from configuration.

        This method parses the configuration dictionary and initializes:
        - Image loading parameters (column name, mode, dataset root path)
        - Inference parameters (image size, batch size, normalization)
        - Model parameters (architecture, feature extraction layer, device)
        - Loads the pre-trained model and creates the feature extractor.

        Raises:
            ValueError: If DATA.mode is neither "bytes" nor "path".
        """
        cfg = self.config or {}

        dcfg = cfg.get("DATA", {})
        self.image_column: str = dcfg.get("image_column", "image_bytes")
        self.mode: str = dcfg.get("mode", "bytes")  # "bytes" or "path"
        if self.mode not in {"bytes", "path"}:
            raise ValueError(f"[{self.name}] DATA.mode must be 'bytes' or 'path'")

        # Handle relative paths in parquet to a dataset located at dataset_root_path.
        # "undefined" is the sentinel meaning "paths in the parquet are already usable as-is".
        self.dataset_root_path = str(cfg.get("dataset_root_path", "undefined"))
        logger.info("[ImageEmbeddingProcessor] dataset_root_path = '%s'", self.dataset_root_path)

        icfg = cfg.get("INFER", {})
        self.size: tuple[int, int] = (
            int(icfg.get("width", 224)),
            int(icfg.get("height", 224)),
        )
        # Defaults are the standard ImageNet normalization statistics.
        mean = icfg.get("norm_mean", [0.485, 0.456, 0.406])
        std = icfg.get("norm_std", [0.229, 0.224, 0.225])
        self.batch_size: int = int(icfg.get("batch_size", 32))

        mcfg = cfg.get("MODEL", {})
        self.arch: str = mcfg.get("arch", "resnet18")
        self.nodes = mcfg.get("n_layer_feature", "avgpool")
        self.device: str = mcfg.get("device", "cpu")

        # Build once: transform, model and extractor are reused across all batches.
        # NOTE(review): self.size is (width, height) but torchvision Resize expects
        # (h, w) — harmless for the square 224x224 default, but non-square configs
        # would be transposed. TODO confirm intended order before swapping.
        self.transform = transforms.Compose(
            [
                transforms.Resize(self.size),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std),
            ]
        )
        self.model = self._load_model(self.arch, self.device)
        self.fx = self._make_extractor(self.model, self.nodes)
        # Embedding dimensionality is unknown until the first successful inference.
        self._embed_dim: int | None = None

        self._checked = True

    @override
    def needed_columns(self) -> list[str]:
        """Return the input columns this processor reads (triggers lazy config check)."""
        if not getattr(self, "_checked", False):
            self.check_config()
        return [self.image_column]

    @override
    def generated_columns(self) -> list[str]:
        """Return the list of columns generated by this processor.

        Returns:
            A list containing 'embedding'.
        """
        return ["embedding"]

    @override
    def compute_features(self, batch: pa.RecordBatch, prev_features: pa.Array | None = None) -> dict[str, pa.Array]:
        """
        Extract image embeddings for all samples in the batch.

        1. Images are loaded and transformed.
        2. Model inference is performed in sub-batches defined by `INFER.batch_size`.
        3. Results are aggregated into a pyarrow `FixedSizeListArray`.

        Args:
            batch: Raw pyarrow batch.
            prev_features: Pre-computed features (not used).

        Returns:
            Dictionary mapping 'embedding' to the calculated feature vectors,
            or an empty dict when the image column is missing or no image loads.
        """
        if not getattr(self, "_checked", False):
            self.check_config()
        if self.image_column not in batch.schema.names:
            logger.warning("[ImageEmbeddingProcessor] missing column '%s'", self.image_column)
            return {}

        # 1. Load images; None placeholders keep row alignment for failed loads.
        imgs = self._load_images(batch.column(self.image_column).to_pylist())

        # 2. Inference in windows, preserving input order.
        embs = self._infer_embeddings(imgs)

        # 3. Infer embedding dim from the first successful embedding (cached across batches).
        if self._embed_dim is None:
            self._embed_dim = next((int(e.size) for e in embs if e is not None), None)
            if self._embed_dim is None:
                # Nothing was embeddable in this batch and no dim is known yet.
                return {}

        # 4. Build FixedSizeListArray.
        return {"embedding": self._to_fixed_size_list(embs, self._embed_dim)}

    def _load_images(self, vals: list[Any]) -> list[torch.Tensor | None]:
        """Decode raw column values into transformed tensors.

        Args:
            vals: Per-row values — image bytes ("bytes" mode) or file paths ("path" mode).

        Returns:
            One entry per input row: a transformed tensor, or None when the value
            is null or the image fails to load (logged, not raised).
        """
        imgs: list[torch.Tensor | None] = []
        for v in vals:
            if v is None:
                imgs.append(None)
                continue
            try:
                if self.mode == "bytes":
                    img = Image.open(io.BytesIO(v)).convert("RGB")
                else:
                    # Resolve against dataset root unless the sentinel "undefined" is set.
                    img_path = Path(self.dataset_root_path) / v if self.dataset_root_path != "undefined" else Path(v)
                    img = Image.open(img_path).convert("RGB")
                imgs.append(self.transform(img))
            except Exception as e:
                # Best-effort: a corrupt/missing image must not abort the whole batch.
                logger.warning("[ImageEmbeddingProcessor] failed to load image: %s", e)
                imgs.append(None)
        return imgs

    def _infer_embeddings(self, imgs: list[torch.Tensor | None]) -> list[np.ndarray | None]:
        """Run windowed model inference over loaded tensors, preserving order.

        Args:
            imgs: Tensors (or None placeholders) aligned with the batch rows.

        Returns:
            One float32 numpy vector per row, None where the image was not loaded.
        """
        embs: list[np.ndarray | None] = []
        self.fx.eval()
        with torch.no_grad():
            for i in range(0, len(imgs), self.batch_size):
                window = imgs[i : i + self.batch_size]
                valid = [t for t in window if t is not None]
                if not valid:
                    embs.extend([None] * len(window))
                    continue
                bt = torch.stack(valid).to(self.device)
                out = self.fx(bt)
                if isinstance(out, dict):
                    # create_feature_extractor returns {node_name: tensor};
                    # flatten each node output and concatenate along features.
                    feats = torch.cat([t.flatten(1) for t in out.values()], dim=1)
                else:
                    feats = out.flatten(1) if out.dim() > 2 else out
                arr = feats.detach().cpu().numpy().astype("float32")
                # Re-interleave results with None placeholders to preserve row order.
                p = 0
                for t in window:
                    if t is None:
                        embs.append(None)
                    else:
                        embs.append(arr[p])
                        p += 1
        return embs

    def _to_fixed_size_list(self, embs: list[np.ndarray | None], d: int) -> pa.FixedSizeListArray:
        """Pack per-row embeddings into a pyarrow FixedSizeListArray of width ``d``.

        Failed rows become zero vectors; mismatched sizes are truncated or zero-padded
        so every row has exactly ``d`` values.
        """
        flat: list[float] = []
        for emb in embs:
            if emb is None:
                flat.extend([0.0] * d)
            else:
                v = emb.ravel()
                if v.size != d:
                    v = v[:d] if v.size > d else np.pad(v, (0, d - v.size))
                flat.extend(v.tolist())
        child = pa.array(np.asarray(flat, dtype=np.float32))
        return pa.FixedSizeListArray.from_arrays(child, d)

    @override
    def compute_batch_metric(self, features: dict[str, pa.Array]) -> dict[str, pa.Array]:
        """
        Return an empty dictionary as embeddings are stored as features, we do not compute metrics.
        """
        return {}

    @override
    def compute(self, batch_metrics: dict[str, pa.Array]) -> dict[str, pa.Array]:
        """Compute final dataset-level metrics (not used for embeddings).

        Returns:
            Empty dictionary as embeddings are computed at feature level.
        """
        return {}

    @override
    def compute_delta(self, source: dict[str, pa.Array], target: dict[str, pa.Array]) -> dict[str, pa.Array]:
        """Compute delta between source and target embeddings (not used).

        Args:
            source: Source embeddings (not used).
            target: Target embeddings (not used).

        Returns:
            Empty dictionary as delta computation is handled by DomainGapProcessor.
        """
        return {}

    # utils functions
    def _load_model(self, arch: str, device: str) -> Any:
        """Load a pre-trained torchvision model.

        Args:
            arch: Model architecture name (e.g., 'resnet18', 'resnet50').
            device: Device to load the model on ('cpu' or 'cuda').

        Returns:
            The loaded PyTorch model.
        """
        try:
            m = torchvision.models.get_model(arch, weights="DEFAULT")
        except Exception:
            # Fallback for older torchvision versions without the get_model registry.
            m = getattr(torchvision.models, arch)(pretrained=True)
        return m.to(device)

    def _make_extractor(self, model: torch.nn.Module, nodes: Any) -> Any:
        """Create a feature extractor from a model.

        Args:
            model: The PyTorch model to extract features from.
            nodes: Layer name (str), index (int), or list of names to extract.

        Returns:
            A feature extractor that returns the requested layer outputs.
        """
        names = list(dict(model.named_modules()).keys())
        if isinstance(nodes, list):
            return create_feature_extractor(model, return_nodes={n: n for n in nodes})
        if isinstance(nodes, int):
            # Support negative indexing into the ordered module-name list.
            idx = nodes if nodes >= 0 else len(names) + nodes
            nodes = names[idx]
        return create_feature_extractor(model, return_nodes={nodes: "features"})