Coverage for packages / dqm-ml-pytorch / src / dqm_ml_pytorch / domain_gap.py: 89%

174 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-15 10:11 +0000

1"""Domain gap processor for measuring distribution distance between datasets. 

2 

3This module contains the DomainGapProcessor class that computes statistical 

4distances (KL divergence, MMD, FID, Wasserstein) between source and target 

5datasets using image embeddings. 

6""" 

7 

8from __future__ import annotations 

9 

10import logging 

11from typing import Any 

12 

13import numpy as np 

14import pyarrow as pa 

15 

16# COMPATIBILITY : from typing import Any, override # When support of 3.10 and 3.11 will be removed 

17from typing_extensions import override 

18 

19from dqm_ml_core import DatametricProcessor 

20 

21logger = logging.getLogger(__name__) 

22 

23 

class DomainGapProcessor(DatametricProcessor):
    """
    Computes statistical distances between source and target dataselections using image embeddings.

    This processor works in two stages:
    1. Dataset Summary: Aggregates high-dimensional embeddings into compact statistics
       (mean, variance, outer products, histograms).
    2. Delta Computation: Uses these summaries to calculate distance metrics between
       a source and a target dataset.

    Supported Delta Metrics:
    - `klmvn_diag`: Kullback-Leibler Divergence assuming a multivariate Normal
      distribution with a diagonal covariance matrix.
    - `mmd_linear`: Maximum Mean Discrepancy with a linear kernel.
    - `fid`: Frechet Inception Distance. Measures distance between two Gaussians
      fitted to feature representations (requires full covariance calculation).
    - `wasserstein_1d`: Average 1D Wasserstein distance across embedding dimensions,
      approximated via histograms.
    """

    def __init__(
        self,
        name: str = "image_embedding",
        config: dict[str, Any] | None = None,
    ):
        """
        Initialize the domain gap processor.

        Args:
            name: Unique name of the processor instance.
            config: Configuration dictionary containing:
                - INPUT:
                    - embedding_col: Column name containing embeddings (default: "embedding").
                - SUMMARY:
                    - collect_sum_outer: Whether to compute outer products (needed for FID).
                    - collect_hist_1d: Whether to compute histograms (needed for Wasserstein).
                    - hist_dims: Number of dimensions to histogram.
                    - hist_bins: Number of bins per histogram.
                - DELTA:
                    - metric: Target metric ("klmvn_diag", "mmd_linear", "fid", "wasserstein_1d").
        """
        super().__init__(name, config)
        # Lazy-validation flag: check_config() runs on first use when the
        # caller has not invoked it explicitly.
        self._checked = False

    # ---------------- API ----------------
    def check_config(self) -> None:
        """Validate and configure the domain gap processor.

        This method parses the configuration dictionary and sets:
        - INPUT: Embedding column name
        - SUMMARY: Options for collecting summary statistics (outer products, histograms)
        - DELTA: The target metric to compute (klmvn_diag, mmd_linear, fid, wasserstein_1d)
        """
        cfg = self.config or {}
        icfg = cfg.get("INPUT", {})
        self.embedding_col: str = icfg.get("embedding_col", "embedding")

        dcfg = cfg.get("DELTA", {})
        self.delta_metric: str = str(dcfg.get("metric", "klmvn_diag")).lower()
        scfg = cfg.get("SUMMARY", {})

        # Auto-enable the summary statistics the selected metric requires;
        # explicit SUMMARY settings below still take precedence.
        if self.delta_metric == "fid":
            auto_sum_outer = True
            auto_hist_1d = False
        elif self.delta_metric == "wasserstein_1d":
            auto_sum_outer = False
            auto_hist_1d = True
        else:  # klmvn_diag, mmd_linear
            auto_sum_outer = False
            auto_hist_1d = False

        self.collect_sum_outer: bool = bool(scfg.get("collect_sum_outer", auto_sum_outer))
        self.collect_hist_1d: bool = bool(scfg.get("collect_hist_1d", auto_hist_1d))

        # Wasserstein-1D parameters
        self.hist_dims: int = int(scfg.get("hist_dims", 64))
        self.hist_bins: int = int(scfg.get("hist_bins", 32))
        rng = scfg.get("hist_range", [-3.0, 3.0])
        self.hist_range: tuple[float, float] = (float(rng[0]), float(rng[1]))

        self._checked = True

    @override
    def needed_columns(self) -> list[str]:
        """Return the input columns required by this processor."""
        if not getattr(self, "_checked", False):
            self.check_config()
        return [self.embedding_col]

    def generated_columns(self) -> list[str]:
        """Return the list of columns generated by this processor.

        Returns:
            Empty list as this processor computes deltas rather than features.
        """
        return []

    @override
    def compute_batch_metric(self, features: dict[str, pa.Array]) -> dict[str, pa.Array]:
        """
        Reduce a batch of embeddings into summary statistics.

        Returns a dictionary containing:
        - count: Number of samples.
        - sum: Element-wise sum of embeddings.
        - sum_sq: Element-wise sum of squared embeddings.
        - sum_outer: (Optional) Sum of outer products (for FID).
        - hist_counts: (Optional) Flattened histogram counts (for Wasserstein).
        """
        if not getattr(self, "_checked", False):
            self.check_config()

        emb = features.get(self.embedding_col)
        if emb is None or not isinstance(emb, pa.FixedSizeListArray):
            return {}

        n = len(emb)
        # Read the list width from the type so an empty batch (n == 0) does
        # not fail on `emb[0]`.
        d = emb.type.list_size
        arr = np.asarray(emb.values.to_numpy()).reshape(n, d)

        out: dict[str, pa.Array] = {}
        out["count"] = pa.array([n], type=pa.int64())
        out["sum"] = pa.FixedSizeListArray.from_arrays(pa.array(arr.sum(axis=0).astype(np.float64)), d)
        out["sum_sq"] = pa.FixedSizeListArray.from_arrays(pa.array((arr * arr).sum(axis=0).astype(np.float64)), d)

        # optional: sum_outer for FID
        if self.collect_sum_outer:
            s = (arr.T @ arr).reshape(-1).astype(np.float64)
            out["sum_outer"] = pa.FixedSizeListArray.from_arrays(pa.array(s), d * d)

        # optional: histograms for Wasserstein-1D
        if self.collect_hist_1d:
            # Histogram at most `hist_dims` dimensions; fewer when the
            # embedding itself is narrower.
            use_dims = min(d, self.hist_dims)
            low, high = self.hist_range
            hist_list: list[np.ndarray] = []
            for j in range(use_dims):
                h, _ = np.histogram(arr[:, j], bins=self.hist_bins, range=(low, high))
                hist_list.append(h.astype(np.int64))
            h = np.stack(hist_list, axis=0).reshape(-1)
            out["hist_counts"] = pa.FixedSizeListArray.from_arrays(pa.array(h), self.hist_bins * use_dims)

        return out

    @override
    def compute(self, batch_metrics: dict[str, pa.Array]) -> dict[str, pa.Array]:
        """Aggregate batch-level summary statistics into global dataselection statistics.

        Args:
            batch_metrics: Dictionary containing batch-level statistics (count, sum, sum_sq, etc.).

        Returns:
            Dictionary containing aggregated dataset-level statistics.
        """
        if not batch_metrics:
            return {}

        def _sum_scalar(a: pa.Array) -> int:
            """Sum a numeric pyarrow array into a Python int."""
            return int(np.asarray(a.to_numpy()).sum())

        def _sum_fixed(v: pa.FixedSizeListArray) -> tuple[np.ndarray, int]:
            """Sum a FixedSizeListArray row-wise; return (summed vector, list width)."""
            # Width comes from the type, so empty arrays are handled safely.
            d = v.type.list_size
            vals = np.asarray(v.values.to_numpy(), dtype=np.float64)
            return vals.reshape(-1, d).sum(axis=0), d

        out: dict[str, pa.Array] = {}

        # count
        if "count" not in batch_metrics:
            return {}
        total_n = _sum_scalar(batch_metrics["count"])
        out["count"] = pa.array([total_n], type=pa.int64())

        # sum / sum_sq
        if "sum" in batch_metrics:
            s, d = _sum_fixed(batch_metrics["sum"])
            out["sum"] = pa.FixedSizeListArray.from_arrays(pa.array(s), d)
        if "sum_sq" in batch_metrics:
            s2, d2 = _sum_fixed(batch_metrics["sum_sq"])
            out["sum_sq"] = pa.FixedSizeListArray.from_arrays(pa.array(s2), d2)

        # optional sum_outer
        if "sum_outer" in batch_metrics:
            so_vals = np.asarray(batch_metrics["sum_outer"].values.to_numpy(), dtype=np.float64)
            dd = batch_metrics["sum_outer"].type.list_size
            out["sum_outer"] = pa.FixedSizeListArray.from_arrays(pa.array(so_vals.reshape(-1, dd).sum(axis=0)), dd)

        # optional hist_counts
        if "hist_counts" in batch_metrics:
            h_vals = np.asarray(batch_metrics["hist_counts"].values.to_numpy(), dtype=np.int64)
            h_len = batch_metrics["hist_counts"].type.list_size
            out["hist_counts"] = pa.FixedSizeListArray.from_arrays(
                pa.array(h_vals.reshape(-1, h_len).sum(axis=0)), h_len
            )

        return out

    @override
    def compute_delta(self, source: dict[str, pa.Array], target: dict[str, pa.Array]) -> dict[str, pa.Array]:
        """
        Calculate the domain gap metric between source and target dataselection statistics.

        Args:
            source: Dataselection statistics from the source dataset (computed via `compute`).
            target: Dataselection statistics from the target dataset (computed via `compute`).

        Returns:
            Dictionary containing the calculated metric value, or a
            {"metric", "note"} pair describing why it could not be computed.
        """
        if not getattr(self, "_checked", False):
            self.check_config()
        # TODO : check config and available metrics outside of computation
        # TODO : add a return code error in th API

        def vec(a: pa.FixedSizeListArray) -> np.ndarray:
            """Aggregate a FixedSizeListArray into a single numpy vector by summing all lists."""
            len_a = a.type.list_size
            return np.asarray(a.values.to_numpy(), dtype=np.float64).reshape(-1, len_a).sum(axis=0)

        def scalar(a: pa.Array) -> float:
            """Sum all elements in a pyarrow Array and return as a float."""
            return float(np.asarray(a.to_numpy()).sum())

        metric = self.delta_metric

        if metric in {"klmvn_diag", "mmd_linear", "fid"}:
            # Each metric requires a specific subset of summary statistics.
            need = {"count", "sum"}
            if metric in {"klmvn_diag", "fid"}:
                need |= {"sum_sq"}
            if metric == "fid":
                need |= {"sum_outer"}
            for side, name in ((source, "source"), (target, "target")):
                if not need.issubset(side.keys()):
                    return {
                        "metric": pa.array([metric]),
                        "note": pa.array([f"missing keys in {name}: {sorted(need)}"]),
                    }

            n1, n2 = scalar(source["count"]), scalar(target["count"])
            if n1 <= 0 or n2 <= 0:
                return {
                    "metric": pa.array([metric]),
                    "note": pa.array(["empty summaries"]),
                }

            mu1 = vec(source["sum"]) / n1
            mu2 = vec(target["sum"]) / n2

            if metric == "mmd_linear":
                # Linear-kernel MMD reduces to the squared distance of means.
                diff = mu1 - mu2
                val = float(np.dot(diff, diff))
                return {"mmd_linear": pa.array([val], type=pa.float64())}

            # Per-dimension variances, floored to avoid division by ~0.
            v1 = np.maximum(vec(source["sum_sq"]) / n1 - mu1 * mu1, 1e-9)
            v2 = np.maximum(vec(target["sum_sq"]) / n2 - mu2 * mu2, 1e-9)

            if metric == "klmvn_diag":
                # Closed-form KL(N1 || N2) for diagonal-covariance Gaussians.
                term_var = np.sum(v1 / v2 - 1.0 - np.log(v1 / v2))
                term_mean = np.sum((mu2 - mu1) ** 2 / v2)
                val = 0.5 * (term_var + term_mean)
                return {"klmvn_diag": pa.array([float(val)], type=pa.float64())}

            if metric == "fid":
                so1 = vec(source["sum_outer"])
                so2 = vec(target["sum_outer"])
                d = int(np.sqrt(so1.size))
                # Full covariance matrices from the outer-product sums.
                s1 = (so1.reshape(d, d) / n1) - np.outer(mu1, mu1)
                s2 = (so2.reshape(d, d) / n2) - np.outer(mu2, mu2)
                from scipy.linalg import sqrtm

                diff = mu1 - mu2
                # The `disp` argument of sqrtm is deprecated (removed in
                # SciPy 1.18.0); the previously returned error estimate can
                # be computed as ``norm(X @ X - A, 'fro')**2 / norm(A, 'fro')``.
                covmean = sqrtm(s1.dot(s2))

                # Numerical noise may yield tiny imaginary parts; drop them.
                if np.iscomplexobj(covmean):
                    covmean = covmean.real
                fid = diff.dot(diff) + np.trace(s1) + np.trace(s2) - 2 * np.trace(covmean)
                return {"fid": pa.array([float(abs(fid))], type=pa.float64())}

        if metric == "wasserstein_1d":
            if "hist_counts" not in source or "hist_counts" not in target:
                return {
                    "metric": pa.array([metric]),
                    "note": pa.array(["missing hist_counts"]),
                }
            h1 = np.asarray(source["hist_counts"].values.to_numpy(), dtype=np.int64)
            h2 = np.asarray(target["hist_counts"].values.to_numpy(), dtype=np.int64)
            bins = self.hist_bins
            # Derive the number of histogrammed dimensions from the stored
            # summary itself: the summary stage uses min(d, hist_dims), which
            # can be smaller than the configured hist_dims.
            if h1.size != h2.size or h1.size == 0 or h1.size % bins != 0:
                return {
                    "metric": pa.array([metric]),
                    "note": pa.array(["hist_counts length mismatch"]),
                }
            use_dims = h1.size // bins
            width = (self.hist_range[1] - self.hist_range[0]) / bins
            total = 0.0
            used = 0
            for j in range(use_dims):
                # BUGFIX: slice into fresh per-dimension variables. The
                # previous code rebound h1/h2 to their own slices, truncating
                # the arrays so every dimension after the first was skipped.
                seg1 = h1[j * bins : (j + 1) * bins].astype(np.float64)
                seg2 = h2[j * bins : (j + 1) * bins].astype(np.float64)
                if seg1.sum() == 0 and seg2.sum() == 0:
                    continue
                p = seg1 / max(1.0, seg1.sum())
                q = seg2 / max(1.0, seg2.sum())
                cdf_p = np.cumsum(p)
                cdf_q = np.cumsum(q)
                # W1 between 1D histograms = integral of |CDF_p - CDF_q|.
                total += float(np.sum(np.abs(cdf_p - cdf_q)) * width)
                used += 1
            val = total / max(1, used)
            return {"wasserstein_1d": pa.array([val], type=pa.float64())}

        return {
            "metric": pa.array([metric]),
            "note": pa.array(["unsupported metric or invalid inputs"]),
        }