Coverage for packages / dqm-ml-job / src / dqm_ml_job / job.py: 77%

158 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-15 10:11 +0000

1"""Dataset job orchestrator for end-to-end data quality assessment. 

2 

3This module contains the DatasetJob class that orchestrates the complete 

4pipeline: data loading, metric computation, and result persistence. 

5""" 

6 

7import itertools 

8import logging 

9from typing import Any 

10 

11import numpy as np 

12import pyarrow as pa 

13from tqdm import tqdm 

14 

15from dqm_ml_core.api.data_processor import DatametricProcessor 

16from dqm_ml_job.dataloaders import DataLoader, DataSelection 

17from dqm_ml_job.outputwriter import OutputWriter 

18 

logger = logging.getLogger(__name__)


class DatasetJob:
    """
    Orchestrates the end-to-end data quality assessment process.

    The job handles:
    1. Plugin discovery and component initialization.
    2. Data selection discovery via DataLoaders.
    3. Streaming execution: iterating over selections and batches to compute
       features and metrics.
    4. Result persistence via OutputWriters.
    5. Comparison metrics (deltas) between discovered datasets.
    """

    def __init__(
        self,
        dataloaders: dict[str, DataLoader],
        metrics: dict[str, DatametricProcessor],
        features_output: OutputWriter | None,
        progress_bar: bool = True,
    ) -> None:
        """
        Initialize the pipeline components.

        Args:
            dataloaders: Map of initialized DataLoader instances.
            metrics: Map of initialized DatametricProcessor instances.
            features_output: Optional writer for persisting per-sample features.
            progress_bar: Whether to display execution progress in the terminal.
        """
        # Store the plugin-loaded components.
        self.dataloaders = dataloaders
        self.metrics = metrics
        self.features_output = features_output
        self.progress_bar = progress_bar

        # Aggregate the input/generated column requirements of every metric.
        self.needed_input_columns: list[str] = []
        self.generated_features: list[str] = []
        self.generated_metrics: list[str] = []
        for metric in self.metrics.values():
            self.needed_input_columns.extend(metric.needed_columns())
            self.generated_features.extend(metric.generated_features())
            self.generated_metrics.extend(metric.generated_metrics())

        # Deduplicate while preserving first-seen order.
        self.needed_input_columns = list(dict.fromkeys(self.needed_input_columns))
        self.generated_features = list(dict.fromkeys(self.generated_features))
        self.generated_metrics = list(dict.fromkeys(self.generated_metrics))

        # Any output column that no metric generates must come straight from
        # the source dataset, so promote it to the needed input columns.
        # Guard against re-inserting a column that is already required.
        if self.features_output:
            for col in self.features_output.columns:
                if col not in self.generated_features and col not in self.needed_input_columns:
                    logger.info("Adding required output column '%s' to input columns", col)
                    self.needed_input_columns.insert(0, col)

        logger.info(
            "DQM job pipeline initialized will process %d dataloaders, %d metrics processors, "
            "outputting features to '%s' ",
            len(self.dataloaders),
            len(self.metrics),
            self.features_output.name if self.features_output else "None",
        )

    def get_ordered_metrics(self) -> list[DatametricProcessor]:
        """
        Return the list of metrics processors in the order they should be executed.

        Currently returns processors in the order they were defined in the config.
        """
        # TODO: Implement proper ordering based on dependencies
        return list(self.metrics.values())

    def describe(self, selections: list[DataSelection]) -> None:
        """
        Log a summary of the execution plan, including discovered selections and metrics.
        """
        logger.info(
            "Executing dqm-ml-job on %d selections, using %d metrics ", len(selections), len(self.metrics)
        )
        for selection in selections:
            logger.info(" Selection: %s -> %s", selection.name, selection)

        for metric_name, metric in self.metrics.items():
            logger.info(" Metric: %s -> %s", metric_name, metric)
            logger.info("   Needed columns: %s", metric.needed_columns())
            logger.info("   Generated features: %s", metric.generated_features())
            logger.info("   Generated metrics: %s", metric.generated_metrics())

    def run(self) -> tuple[dict[Any, dict[str, Any]], dict[str, Any] | None]:
        """
        Execute the job on all discovered data selections.

        This is the main entry point for execution. It iterates through every
        selection found by the loaders, computes statistics, and finally
        calculates deltas between datasets.

        Returns:
            A tuple containing:
                - Mapping of selection names to their final metric dictionaries.
                - pyarrow-compatible dict of arrays containing all computed
                  deltas (None when no delta metric was produced).
        """
        # TODO: Check with needed input order of metric computation
        metrics_processors = self.get_ordered_metrics()

        columns_list = self.needed_input_columns

        # Discover all selections across every configured loader.
        all_selections: list[DataSelection] = []
        for loader in self.dataloaders.values():
            all_selections.extend(loader.get_selections())

        dataselection_metrics_list: dict[str, dict[str, Any]] = {}

        job_iter = tqdm(all_selections, desc="selection", position=0) if self.progress_bar else all_selections

        # TODO : add as a specific command line argument
        self.describe(all_selections)

        for selection in job_iter:
            selection_name = selection.name
            logger.info("Processing selection '%s'", selection_name)

            selection.bootstrap(columns_list)

            # Compute features and per-batch statistics for all batches.
            batches_metrics_array = self._compute_batches_metrics(selection_name, selection, metrics_processors)

            # Reduce batch-level statistics into dataset-level metrics.
            dataset_metrics: dict[str, Any] = {}

            metrics_iter = (
                tqdm(metrics_processors, desc="metrics", position=1, leave=False)
                if self.progress_bar
                else metrics_processors
            )

            for metric in metrics_iter:
                # isEnabledFor honors the logger's *effective* level (the
                # previous exact-equality check on the root logger missed
                # module-level or below-DEBUG configurations).
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(
                        "Metric computation %s for dataselection %s", metric.__class__.__name__, selection_name
                    )
                dataset_metrics.update(metric.compute(batch_metrics=batches_metrics_array))
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug("Available metrics %s", list(dataset_metrics.keys()))

            dataselection_metrics_list[selection_name] = dataset_metrics

        # Compute comparison (delta) metrics between every pair of selections.
        delta_metrics_table = self._compute_delta_metrics(metrics_processors, dataselection_metrics_list)

        return dataselection_metrics_list, delta_metrics_table

    @staticmethod
    def _to_pa_array(value: Any, key: str) -> pa.Array:
        """Convert a delta metric value to a PyArrow array.

        Args:
            value: The value to convert (float, int, str, np.ndarray, or pa.Array).
            key: The metric name, used for error reporting.

        Returns:
            PyArrow array containing the value.

        Raises:
            TypeError: If the value type is not supported.
        """
        if isinstance(value, pa.Array):
            return value
        # NOTE: bool is a subclass of int, so booleans are stored as 0.0/1.0.
        if isinstance(value, (int, float, np.number)):
            return pa.array([float(value)])
        if isinstance(value, str):
            return pa.array([value])
        if isinstance(value, np.ndarray):
            return pa.array([value.tolist()])
        logger.error("Cannot convert delta metric '%s' to pa.Array: type=%s", key, type(value))
        raise TypeError(f"Unsupported delta metric type: {type(value)} for key '{key}'")

    def _compute_delta_metrics(
        self, metrics_processors: list[DatametricProcessor], dataselection_metrics_list: dict[str, dict[str, Any]]
    ) -> dict[str, Any] | None:
        """Compute comparison metrics between every unique pair of data selections.

        Args:
            metrics_processors: List of processors capable of computing deltas.
            dataselection_metrics_list: Map of selection names to their metrics.

        Returns:
            A pyarrow-compatible dictionary representing the delta table, or
            None when no processor produced any delta metric.
        """
        selection_pairs = itertools.combinations(dataselection_metrics_list, 2)

        delta_metrics_table = None
        for source_name, target_name in selection_pairs:
            src_metrics = dataselection_metrics_list[source_name]
            target_metrics = dataselection_metrics_list[target_name]

            for metric in metrics_processors:
                delta_metrics = metric.compute_delta(src_metrics, target_metrics)

                # TODO : check format of classical metrics / delta metrics for combinaison of format
                # NOTE(review): if a *later* processor emits a delta key not
                # present in the table, the append below raises KeyError, and
                # source/target rows are appended once per processor — verify
                # row alignment when several processors produce deltas.
                if len(delta_metrics) == 0:
                    continue

                if delta_metrics_table is None:
                    # First pair producing deltas: build the table columns.
                    delta_metrics_table = {
                        key: self._to_pa_array(value, key) for key, value in delta_metrics.items()
                    }
                    delta_metrics_table["selection_source"] = pa.array([source_name])
                    delta_metrics_table["selection_target"] = pa.array([target_name])
                else:
                    # Subsequent pairs: append one row to every column.
                    for m_name, value in delta_metrics.items():
                        delta_metrics_table[m_name] = pa.concat_arrays(
                            [delta_metrics_table[m_name], self._to_pa_array(value, m_name)]
                        )

                    delta_metrics_table["selection_source"] = pa.concat_arrays(
                        [delta_metrics_table["selection_source"], pa.array([source_name])]
                    )
                    delta_metrics_table["selection_target"] = pa.concat_arrays(
                        [delta_metrics_table["selection_target"], pa.array([target_name])]
                    )
                logger.debug("Writing delta metrics for dataloader %s", "_".join((source_name, target_name)))

        return delta_metrics_table

    def _compute_batches_metrics(
        self, selection_name: str, selection: DataSelection, metrics_processors: list[DatametricProcessor]
    ) -> dict[str, Any]:
        """Process all batches in a selection to compute intermediate statistics and features.

        Memory Management:
            - Batch-level statistics (`batch_metrics`) are accumulated in lists and
              concatenated once the selection is complete.
            - Per-sample features are accumulated in memory and flushed to the
              OutputWriter in chunks whenever the buffered size estimate exceeds
              the 512MB threshold.

        Args:
            selection_name: Name of the current data selection.
            selection: The selection iterator.
            metrics_processors: List of processors to apply to each batch.

        Returns:
            Dictionary of concatenated intermediate statistics arrays.
        """
        # Use lists for O(1) appending, then concat once at the end.
        batch_metrics_accumulator: dict[str, list[Any]] = {}
        features_accumulator: dict[str, list[Any]] = {}

        # Track buffered size so features can be flushed to disk in chunks.
        feature_array_size = 0
        part_index = 0
        memory_threshold = 512 * 1024 * 1024  # 512MB threshold for flushing features

        dataloader_iter = (
            tqdm(selection, desc="batches", position=1, leave=False, total=selection.get_nb_batches())
            if self.progress_bar
            else selection
        )

        for batch in dataloader_iter:
            batch_features: dict[str, Any] = {}
            batch_metrics: dict[str, Any] = {}

            # Compute features and batch-level metrics; later processors can
            # reuse features produced by earlier ones via `prev_features`.
            for metric in metrics_processors:
                batch_features.update(metric.compute_features(batch, prev_features=batch_features))
                batch_metrics.update(metric.compute_batch_metric(batch_features))
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(
                        "%s - Available batch_metrics %s - features %s",
                        metric.name,
                        list(batch_metrics.keys()),
                        list(batch_features.keys()),
                    )

            # Accumulate batch metrics.
            for k, v in batch_metrics.items():
                batch_metrics_accumulator.setdefault(k, []).append(v)

            # Feature accumulation only matters when a writer is configured
            # (check hoisted out of the per-column loops).
            if self.features_output is not None:
                # Accumulate requested raw columns from the source dataset.
                for i, col_name in enumerate(batch.column_names):
                    if col_name not in self.features_output.columns:
                        continue
                    col_data = batch.column(i)
                    features_accumulator.setdefault(col_name, []).append(col_data)
                    feature_array_size += col_data.get_total_buffer_size()

                # Accumulate generated features, avoiding duplication when a
                # feature is also a batch metric or is not requested.
                for k, v in batch_features.items():
                    if k not in self.features_output.columns or k in batch_metrics:
                        continue
                    features_accumulator.setdefault(k, []).append(v)
                    feature_array_size += v.get_total_buffer_size()

                # Flush features to disk once the memory threshold is reached.
                if feature_array_size > memory_threshold:
                    logger.info(
                        "Memory threshold reached (%.1fMB). Flushing chunk %d",
                        feature_array_size / 1024**2,
                        part_index,
                    )
                    features_chunk: dict[str, Any] = {
                        k: pa.concat_arrays(v_list) for k, v_list in features_accumulator.items()
                    }
                    self.features_output.write_table(selection_name, features_chunk, part_index)

                    # Reset the accumulator for the next chunk.
                    features_accumulator = {}
                    feature_array_size = 0
                    part_index += 1

        # Concatenate all accumulated per-batch statistic arrays.
        batches_metrics_array: dict[str, Any] = {
            k: pa.concat_arrays(v_list) for k, v_list in batch_metrics_accumulator.items()
        }

        # Write any remaining (un-flushed) features to disk.
        if self.features_output and features_accumulator:
            features_array = {k: pa.concat_arrays(v_list) for k, v_list in features_accumulator.items()}
            self.features_output.write_table(selection_name, features_array, part_index)

        return batches_metrics_array