# Coverage report residue (coverage.py v7.13.5, created at 2026-04-15 10:11 +0000):
# packages/dqm-ml-job/src/dqm_ml_job/job.py — 77% of 158 statements covered.
"""Dataset job orchestrator for end-to-end data quality assessment.

This module contains the DatasetJob class that orchestrates the complete
pipeline: data loading, metric computation, and result persistence.
"""
# Standard library
import itertools
import logging
from typing import Any

# Third-party
import numpy as np
import pyarrow as pa
from tqdm import tqdm

# Project
from dqm_ml_core.api.data_processor import DatametricProcessor
from dqm_ml_job.dataloaders import DataLoader, DataSelection
from dqm_ml_job.outputwriter import OutputWriter

# Module-level logger, named after this module per the logging convention.
logger = logging.getLogger(__name__)
class DatasetJob:
    """
    Orchestrates the end-to-end data quality assessment process.

    The job handles:
    1. Plugin discovery and component initialization.
    2. Data selection discovery via DataLoaders.
    3. Streaming execution: iterating over selections and batches to compute features and metrics.
    4. Result persistence via OutputWriters.
    5. Comparison metrics (deltas) between discovered datasets.
    """

    def __init__(
        self,
        dataloaders: dict[str, DataLoader],
        metrics: dict[str, DatametricProcessor],
        features_output: OutputWriter | None,
        progress_bar: bool = True,
    ) -> None:
        """
        Initialize the pipeline components.

        Args:
            dataloaders: Map of initialized DataLoader instances.
            metrics: Map of initialized DatametricProcessor instances.
            features_output: Optional writer for persisting per-sample features.
            progress_bar: Whether to display execution progress in the terminal.
        """
        # Store the plugin components loaded by the caller.
        self.dataloaders = dataloaders
        self.metrics = metrics
        self.features_output = features_output
        self.progress_bar = progress_bar

        # Determine needed input/generated columns by querying every processor.
        self.needed_input_columns: list[str] = []
        self.generated_features: list[str] = []
        self.generated_metrics: list[str] = []
        for metric in self.metrics.values():
            self.needed_input_columns.extend(metric.needed_columns())
            self.generated_features.extend(metric.generated_features())
            self.generated_metrics.extend(metric.generated_metrics())

        # Deduplicate columns while preserving declaration order.
        self.needed_input_columns = list(dict.fromkeys(self.needed_input_columns))
        self.generated_features = list(dict.fromkeys(self.generated_features))
        self.generated_metrics = list(dict.fromkeys(self.generated_metrics))

        # Ensure output columns are included in needed input columns. Columns
        # that a metric generates, or that are already requested, are skipped so
        # the deduplication above is not undone.
        if self.features_output:
            for col in self.features_output.columns:
                if col not in self.generated_features and col not in self.needed_input_columns:
                    logger.info(f"Adding required output column '{col}' to input columns")
                    self.needed_input_columns.insert(0, col)

        logger.info(
            f"DQM job pipeline initialized; will process {len(self.dataloaders)} dataloaders, "  # noqa: E501
            f"{len(self.metrics)} metrics processors, "
            f"outputting features to '{self.features_output.name if self.features_output else 'None'}' "
        )

    def get_ordered_metrics(self) -> list[DatametricProcessor]:
        """
        Return the list of metrics processors in the order they should be executed.

        Currently returns processors in the order they were defined in the config.
        """
        # TODO: Implement proper ordering based on dependencies
        return list(self.metrics.values())

    def describe(self, selections: list[DataSelection]) -> None:
        """
        Log a summary of the execution plan, including discovered selections and metrics.

        Args:
            selections: The data selections discovered by the loaders.
        """
        logger.info(f"Executing dqm-ml-job on {len(selections)} selections, using {len(self.metrics)} metrics ")
        for selection in selections:
            logger.info(f" Selection: {selection.name} -> {selection}")

        for metric_name, metric in self.metrics.items():
            logger.info(f" Metric: {metric_name} -> {metric}")
            logger.info(f" Needed columns: {metric.needed_columns()}")
            logger.info(f" Generated features: {metric.generated_features()}")
            logger.info(f" Generated metrics: {metric.generated_metrics()}")

    def run(self) -> tuple[dict[Any, dict[str, Any]], dict[str, Any] | None]:
        """
        Execute the job on all discovered data selections.

        This is the main entry point for execution. It iterates through every
        selection found by the loaders, computes statistics, and finally
        calculates deltas between datasets.

        Returns:
            A tuple containing:
            - Mapping of selection names to their final metric dictionaries.
            - pyarrow-compatible dict of arrays containing all computed deltas
              (None when no deltas were produced).
        """
        # TODO: Check with needed input order of metric computation
        metrics_processors = self.get_ordered_metrics()

        columns_list = self.needed_input_columns

        # Discover all selections across every configured loader.
        all_selections: list[DataSelection] = []
        for loader in self.dataloaders.values():
            all_selections.extend(loader.get_selections())

        dataselection_metrics_list: dict[Any, dict[str, Any]] = {}

        job_iter = tqdm(all_selections, desc="selection", position=0) if self.progress_bar else all_selections  # noqa: E501

        # TODO : add as a specific command line argument
        self.describe(all_selections)

        for selection in job_iter:
            selection_name = selection.name
            logger.info(f"Processing selection '{selection_name}'")

            # Restrict the selection to the columns the metrics actually need.
            selection.bootstrap(columns_list)

            # Compute features and batch-level metrics for all batches.
            batches_metrics_array = self._compute_batches_metrics(selection_name, selection, metrics_processors)

            # Reduce batch statistics into dataset-level metrics.
            dataset_metrics: dict[str, Any] = {}

            metrics_iter = (
                tqdm(metrics_processors, desc="metrics", position=1, leave=False)
                if self.progress_bar
                else metrics_processors
            )

            for metric in metrics_iter:
                # isEnabledFor() honours this logger's *effective* level; the
                # previous exact-equality check on the root logger's level
                # missed DEBUG enabled via propagation or a lower level.
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"Metric computation {metric.__class__.__name__} for dataselection {selection_name}")
                dataset_metrics.update(metric.compute(batch_metrics=batches_metrics_array))
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"Available metrics {list(dataset_metrics.keys())}")

            dataselection_metrics_list[selection_name] = dataset_metrics

        # If we have to compute delta metrics
        delta_metrics_table = self._compute_delta_metrics(metrics_processors, dataselection_metrics_list)

        return dataselection_metrics_list, delta_metrics_table

    @staticmethod
    def _to_pa_array(value: Any, key: str) -> pa.Array:
        """Convert a delta metric value to a PyArrow array.

        Args:
            value: The value to convert (float, int, str, np.ndarray, or pa.Array).
            key: The metric name for error logging.

        Returns:
            PyArrow array containing the value (returned as-is when already a
            pa.Array).

        Raises:
            TypeError: If the value type is not supported.
        """
        if isinstance(value, pa.Array):
            return value
        if isinstance(value, (int, float, np.number)):
            return pa.array([float(value)])
        if isinstance(value, str):
            return pa.array([value])
        if isinstance(value, np.ndarray):
            # Wrap as a single list element so the array stays one row long.
            return pa.array([value.tolist()])
        logger.error(f"Cannot convert delta metric '{key}' to pa.Array: type={type(value)}")
        raise TypeError(f"Unsupported delta metric type: {type(value)} for key '{key}'")

    def _compute_delta_metrics(
        self, metrics_processors: list[DatametricProcessor], dataselection_metrics_list: dict[str, dict[str, Any]]
    ) -> dict[str, Any] | None:
        """Compute comparison metrics between every unique pair of data selections.

        Args:
            metrics_processors: List of processors capable of computing deltas.
            dataselection_metrics_list: Map of selection names to their metrics.

        Returns:
            A pyarrow-compatible dictionary (column name -> pa.Array) with one
            row per selection pair, or None when no processor produced deltas.
        """
        # Per-column list of single-row chunks; concatenated once at the end.
        accumulated: dict[str, list[pa.Array]] = {}

        for src_name, target_name in itertools.combinations(dataselection_metrics_list, 2):
            src_metrics = dataselection_metrics_list[src_name]
            target_metrics = dataselection_metrics_list[target_name]

            # Merge the deltas of every processor for this pair so the pair
            # contributes exactly one row. (The previous implementation
            # appended the selection_source/target columns once *per
            # processor*, desynchronizing column lengths as soon as more than
            # one processor emitted deltas for the same pair.)
            pair_deltas: dict[str, Any] = {}
            for metric in metrics_processors:
                pair_deltas.update(metric.compute_delta(src_metrics, target_metrics))

            # TODO : check format of classical metrics / delta metrics for combinaison of format
            if not pair_deltas:
                continue

            pair_deltas["selection_source"] = src_name
            pair_deltas["selection_target"] = target_name
            for key, value in pair_deltas.items():
                accumulated.setdefault(key, []).append(self._to_pa_array(value, key))

            logger.debug(f"Writing delta metrics for dataloader {src_name}_{target_name}")

        if not accumulated:
            return None
        # NOTE(review): assumes every pair yields the same delta keys; ragged
        # key sets would still produce columns of differing lengths — confirm
        # with the processors' compute_delta contract.
        return {key: pa.concat_arrays(chunks) for key, chunks in accumulated.items()}

    def _compute_batches_metrics(
        self, selection_name: str, selection: DataSelection, metrics_processors: list[DatametricProcessor]
    ) -> dict[str, Any]:
        """Process all batches in a selection to compute intermediate statistics and features.

        Memory Management:
            - Batch-level statistics (`batch_metrics`) are accumulated in lists and
              concatenated once the selection is complete.
            - Per-sample features are accumulated in memory and flushed to the
              OutputWriter as a numbered chunk whenever the tracked buffer size
              exceeds the 512MB threshold; any remainder is written after the
              last batch.

        Args:
            selection_name: Name of the current data selection.
            selection: The selection iterator.
            metrics_processors: List of processors to apply to each batch.

        Returns:
            Dictionary of concatenated intermediate statistics arrays.
        """
        # Use lists for O(1) appending, then concat once at the end.
        batch_metrics_accumulator: dict[str, list[Any]] = {}
        features_accumulator: dict[str, list[Any]] = {}

        # Track memory size for chunked flushing of features.
        feature_array_size = 0
        part_index = 0
        memory_threshold = 512 * 1024 * 1024  # 512MB threshold for flushing features

        # Hoisted: the columns the writer wants (empty when no writer), so the
        # per-column loops below need no repeated `is None` checks.
        output_columns = self.features_output.columns if self.features_output else []

        dataloader_iter = (
            tqdm(selection, desc="batches", position=1, leave=False, total=selection.get_nb_batches())
            if self.progress_bar
            else selection
        )

        for batch in dataloader_iter:
            batch_features: dict[str, Any] = {}
            batch_metrics: dict[str, Any] = {}

            # Compute features and batch-level metrics; later processors see
            # the features produced by earlier ones via `prev_features`.
            for metric in metrics_processors:
                batch_features.update(metric.compute_features(batch, prev_features=batch_features))
                batch_metrics.update(metric.compute_batch_metric(batch_features))
                if logger.isEnabledFor(logging.DEBUG):
                    m_keys, m_features = list(batch_metrics.keys()), list(batch_features.keys())
                    logger.debug(f"{metric.name} - Available batch_metrics {m_keys} - features {m_features}")

            # Accumulate batch metrics
            for k, v in batch_metrics.items():
                batch_metrics_accumulator.setdefault(k, []).append(v)

            if self.features_output is not None:
                # Accumulate requested columns coming from the source dataset.
                for i, col_name in enumerate(batch.column_names):
                    if col_name not in output_columns:
                        continue
                    col_data = batch.column(i)
                    features_accumulator.setdefault(col_name, []).append(col_data)
                    feature_array_size += col_data.get_total_buffer_size()

                # Accumulate generated features; skip names that are also batch
                # metrics (or not requested) to avoid duplication.
                for k, v in batch_features.items():
                    if k not in output_columns or k in batch_metrics:
                        continue
                    features_accumulator.setdefault(k, []).append(v)
                    feature_array_size += v.get_total_buffer_size()

                # Flush features to disk if the memory threshold is reached.
                if feature_array_size > memory_threshold:
                    logger.info(
                        f"Memory threshold reached ({feature_array_size / 1024**2:.1f}MB). Flushing chunk {part_index}"
                    )
                    features_chunk: dict[str, Any] = {
                        k: pa.concat_arrays(v_list) for k, v_list in features_accumulator.items()
                    }
                    self.features_output.write_table(selection_name, features_chunk, part_index)

                    # Reset features accumulator for the next chunk.
                    features_accumulator = {}
                    feature_array_size = 0
                    part_index += 1

        # Concatenate all accumulated batch statistics.
        batches_metrics_array: dict[str, Any] = {
            k: pa.concat_arrays(v_list) for k, v_list in batch_metrics_accumulator.items()
        }

        # Write remaining features to disk.
        if self.features_output and features_accumulator:
            features_array: dict[str, Any] = {
                k: pa.concat_arrays(v_list) for k, v_list in features_accumulator.items()
            }
            self.features_output.write_table(selection_name, features_array, part_index)

        return batches_metrics_array