Coverage for packages / dqm-ml-pytorch / src / dqm_ml_pytorch / image_embedding.py: 80%

135 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-15 10:11 +0000

1"""Image embedding processor using pre-trained deep learning models. 

2 

3This module contains the ImageEmbeddingProcessor class that extracts 

4high-dimensional embeddings from images using PyTorch and torchvision 

5pre-trained models. 

6""" 

7 

8from __future__ import annotations 

9 

10import io 

11import logging 

12from pathlib import Path 

13from typing import Any 

14 

15import numpy as np 

16from PIL import Image 

17import pyarrow as pa 

18import torch 

19import torchvision 

20from torchvision import transforms 

21from torchvision.models.feature_extraction import create_feature_extractor 

22 

23# COMPATIBILITY : from typing import Any, override # When support of 3.10 and 3.11 will be removed 

24from typing_extensions import override 

25 

26from dqm_ml_core import DatametricProcessor 

27 

28logger = logging.getLogger(__name__) 

29 

30 

class ImageEmbeddingProcessor(DatametricProcessor):
    """
    Computes high-dimensional latent vectors (embeddings) for images using deep learning models.

    This processor uses PyTorch and Torchvision to:
    1. Load images from bytes or file paths.
    2. Preprocess images (resize, normalize) for the selected model.
    3. Run batch inference using a pre-trained model (e.g., ResNet, ViT).
    4. Extract features from a specific layer (e.g., 'avgpool').

    The resulting embeddings are stored as a `FixedSizeListArray` in the features.
    """

    def __init__(
        self,
        name: str = "image_embedding",
        config: dict[str, Any] | None = None,
    ):
        """
        Initialize the image embedding processor.

        Args:
            name: Unique name of the processor instance.
            config: Configuration dictionary containing:
                - DATA:
                    - image_column: Column name containing image data (default: "image_bytes").
                    - mode: Source type, "bytes" or "path" (default: "bytes").
                - INFER:
                    - width, height: Input resolution for the model (default: 224x224).
                    - batch_size: Number of images per inference pass (default: 32).
                    - norm_mean, norm_std: Preprocessing normalization stats.
                - MODEL:
                    - arch: Torchvision model name (default: "resnet18").
                    - n_layer_feature: Target layer for feature extraction (default: "avgpool").
                    - device: Execution device, "cpu" or "cuda" (default: "cpu").
        """
        super().__init__(name, config)
        # Lazy setup: the model and transforms are only built when check_config()
        # runs (triggered on first use by needed_columns()/compute_features()).
        self._checked = False

    # ---------------- API ----------------
    def check_config(self) -> None:
        """Validate and initialize model/transforms from configuration.

        This method parses the configuration dictionary and initializes:
        - Image loading parameters (column name, mode, dataset root path)
        - Inference parameters (image size, batch size, normalization)
        - Model parameters (architecture, feature extraction layer, device)
        - Loads the pre-trained model and creates the feature extractor.

        Raises:
            ValueError: If DATA.mode is neither 'bytes' nor 'path'.
        """
        cfg = self.config or {}

        dcfg = cfg.get("DATA", {})
        self.image_column: str = dcfg.get("image_column", "image_bytes")
        self.mode: str = dcfg.get("mode", "bytes")  # "bytes" or "path"
        if self.mode not in {"bytes", "path"}:
            raise ValueError(f"[{self.name}] DATA.mode must be 'bytes' or 'path'")

        # handle relative paths in parquet to a dataset located at dataset_root_path
        self.dataset_root_path = str(cfg.get("dataset_root_path", "undefined"))
        logger.info("[ImageEmbeddingProcessor] dataset_root_path = '%s'", self.dataset_root_path)

        icfg = cfg.get("INFER", {})
        # BUGFIX: torchvision's transforms.Resize expects (height, width); the
        # previous (width, height) order silently transposed non-square inputs.
        self.size: tuple[int, int] = (
            int(icfg.get("height", 224)),
            int(icfg.get("width", 224)),
        )
        mean = icfg.get("norm_mean", [0.485, 0.456, 0.406])
        std = icfg.get("norm_std", [0.229, 0.224, 0.225])
        self.batch_size: int = int(icfg.get("batch_size", 32))

        mcfg = cfg.get("MODEL", {})
        self.arch: str = mcfg.get("arch", "resnet18")
        self.nodes = mcfg.get("n_layer_feature", "avgpool")
        self.device: str = mcfg.get("device", "cpu")

        # Build once; reused for every batch.
        self.transform = transforms.Compose(
            [
                transforms.Resize(self.size),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std),
            ]
        )
        self.model = self._load_model(self.arch, self.device)
        self.fx = self._make_extractor(self.model, self.nodes)
        # Embedding dimension is unknown until the first successful inference.
        self._embed_dim: int | None = None

        self._checked = True

    @override
    def needed_columns(self) -> list[str]:
        """Return the input columns this processor reads (the image column)."""
        if not getattr(self, "_checked", False):
            self.check_config()
        return [self.image_column]

    @override
    def generated_columns(self) -> list[str]:
        """Return the list of columns generated by this processor.

        Returns:
            A list containing 'embedding'.
        """
        return ["embedding"]

    @override
    def compute_features(self, batch: pa.RecordBatch, prev_features: pa.Array | None = None) -> dict[str, pa.Array]:
        """
        Extract image embeddings for all samples in the batch.

        1. Images are loaded and transformed.
        2. Model inference is performed in sub-batches defined by `INFER.batch_size`.
        3. Results are aggregated into a pyarrow `FixedSizeListArray`.

        Rows whose image is null or fails to decode are zero-filled so the
        output stays row-aligned with the input batch.

        Args:
            batch: Raw pyarrow batch.
            prev_features: Pre-computed features (not used).

        Returns:
            Dictionary mapping 'embedding' to the calculated feature vectors,
            or an empty dict when the image column is absent or no image
            could be embedded (and no dimension is known yet).
        """
        if not getattr(self, "_checked", False):
            self.check_config()
        if self.image_column not in batch.schema.names:
            logger.warning("[ImageEmbeddingProcessor] missing column '%s'", self.image_column)
            return {}

        # 1. Load and preprocess images; failures become None placeholders so
        #    the result keeps the same length/order as the batch.
        imgs: list[torch.Tensor | None] = [self._load_tensor(v) for v in batch.column(self.image_column).to_pylist()]

        # 2. Inference in windows, preserving order.
        embs: list[np.ndarray | None] = []
        self.fx.eval()
        with torch.no_grad():
            for start in range(0, len(imgs), self.batch_size):
                embs.extend(self._embed_window(imgs[start : start + self.batch_size]))

        # 3. Infer embedding dim once, from the first successful embedding;
        #    cached across batches on the instance.
        if self._embed_dim is None:
            self._embed_dim = next((int(e.size) for e in embs if e is not None), None)
            if self._embed_dim is None:
                # Nothing embeddable in this batch and no known dimension yet.
                return {}
        d = self._embed_dim

        # 4. Build FixedSizeListArray; failed rows are zero vectors.
        flat: list[float] = []
        for emb in embs:
            if emb is None:
                flat.extend([0.0] * d)
            else:
                v = emb.ravel()
                if v.size != d:
                    # Defensive: pad/truncate if a later batch yields a different size.
                    v = v[:d] if v.size > d else np.pad(v, (0, d - v.size))
                flat.extend(v.tolist())

        child = pa.array(np.asarray(flat, dtype=np.float32))
        return {"embedding": pa.FixedSizeListArray.from_arrays(child, d)}

    @override
    def compute_batch_metric(self, features: dict[str, pa.Array]) -> dict[str, pa.Array]:
        """
        Return an empty dictionary as embeddings are stored as features, we do not compute metrics.
        """
        return {}

    @override
    def compute(self, batch_metrics: dict[str, pa.Array]) -> dict[str, pa.Array]:
        """Compute final dataset-level metrics (not used for embeddings).

        Returns:
            Empty dictionary as embeddings are computed at feature level.
        """
        return {}

    @override
    def compute_delta(self, source: dict[str, pa.Array], target: dict[str, pa.Array]) -> dict[str, pa.Array]:
        """Compute delta between source and target embeddings (not used).

        Args:
            source: Source embeddings (not used).
            target: Target embeddings (not used).

        Returns:
            Empty dictionary as delta computation is handled by DomainGapProcessor.
        """
        return {}

    # utils functions
    def _load_tensor(self, value: Any) -> torch.Tensor | None:
        """Load and preprocess one image cell from the batch.

        Args:
            value: Raw cell content — encoded image bytes (mode "bytes") or a
                file path, possibly relative to `dataset_root_path` (mode "path").

        Returns:
            The transformed image tensor, or None when the cell is null or the
            image cannot be decoded (a warning is logged in that case).
        """
        if value is None:
            return None
        try:
            if self.mode == "bytes":
                img = Image.open(io.BytesIO(value)).convert("RGB")
            else:
                # Resolve against dataset_root_path unless it is the "undefined" sentinel.
                img_path = Path(self.dataset_root_path) / value if self.dataset_root_path != "undefined" else Path(value)
                img = Image.open(img_path).convert("RGB")
            return self.transform(img)
        except Exception as e:
            # Best-effort loading: a bad image must not abort the whole batch.
            logger.warning("[ImageEmbeddingProcessor] failed to load image: %s", e)
            return None

    def _embed_window(self, window: list[torch.Tensor | None]) -> list[np.ndarray | None]:
        """Run one inference pass over a window of image tensors.

        Must be called inside a `torch.no_grad()` context (the caller provides it).

        Args:
            window: Image tensors, with None placeholders for failed loads.

        Returns:
            One float32 embedding per slot, preserving order; None slots stay None.
        """
        valid = [t for t in window if t is not None]
        if not valid:
            return [None] * len(window)
        bt = torch.stack(valid).to(self.device)
        out = self.fx(bt)
        if isinstance(out, dict):
            # create_feature_extractor returns {node_name: tensor}; flatten each
            # output to (N, -1) and concatenate along the feature axis.
            feats = torch.cat([v.flatten(1) for v in out.values()], dim=1)
        else:
            feats = out.flatten(1) if out.dim() > 2 else out
        arr = feats.detach().cpu().numpy().astype("float32")
        # Re-interleave embeddings with the None placeholders, preserving order.
        result: list[np.ndarray | None] = []
        p = 0
        for t in window:
            if t is None:
                result.append(None)
            else:
                result.append(arr[p])
                p += 1
        return result

    def _load_model(self, arch: str, device: str) -> Any:
        """Load a pre-trained torchvision model.

        Args:
            arch: Model architecture name (e.g., 'resnet18', 'resnet50').
            device: Device to load the model on ('cpu' or 'cuda').

        Returns:
            The loaded PyTorch model, moved to `device`.
        """
        try:
            m = torchvision.models.get_model(arch, weights="DEFAULT")
        except Exception:
            # Fallback for older torchvision versions that predate get_model()
            # and still use the (deprecated) `pretrained` keyword.
            m = getattr(torchvision.models, arch)(pretrained=True)
        return m.to(device)

    def _make_extractor(self, model: torch.nn.Module, nodes: Any) -> Any:
        """Create a feature extractor from a model.

        Args:
            model: The PyTorch model to extract features from.
            nodes: Layer name (str), index (int, may be negative), or list of
                names to extract.

        Returns:
            A feature extractor that returns the requested layer outputs
            (keyed by node name for a list, or as 'features' for a single node).
        """
        if isinstance(nodes, list):
            return create_feature_extractor(model, return_nodes={n: n for n in nodes})
        if isinstance(nodes, int):
            # Translate a (possibly negative) index into a module name.
            names = list(dict(model.named_modules()).keys())
            idx = nodes if nodes >= 0 else len(names) + nodes
            nodes = names[idx]
        return create_feature_extractor(model, return_nodes={nodes: "features"})