Coverage for packages/dqm-ml-job/src/dqm_ml_job/dataloaders/parquet.py: 98%
83 statements
« prev ^ index » next — coverage.py v7.13.5, created at 2026-04-15 10:11 +0000
"""Parquet data loader for reading Parquet files.

This module contains the ParquetDataLoader and ParquetDataSelection classes
for loading and iterating over Parquet file data.
"""
7import logging
8from typing import Any
10import pyarrow.compute as pc
11import pyarrow.parquet as pq
13# COMPATIBILITY : from typing import Any, override # When support of 3.10 and 3.11 will be removed
14from typing_extensions import override
16from dqm_ml_job.dataloaders.proto import DataSelection
18logger = logging.getLogger(__name__)
class ParquetDataSelection(DataSelection):
    """A specific selection of data from a Parquet dataset.

    This class represents a filtered subset of a Parquet dataset
    and provides an iterator over PyArrow RecordBatches.

    Attributes:
        name: Name identifier for this selection.
        path: Path to the Parquet file or directory.
        batch_size: Number of rows per batch.
        threads: Number of threads for parallel reading.
        filters_dict: Optional dictionary of column filters to apply.
    """

    def __init__(
        self,
        name: str,
        path: str,
        batch_size: int = 100_000,
        threads: int = 4,
        filters_dict: dict[str, Any] | None = None,
    ):
        """Initialize a Parquet data selection.

        Args:
            name: Name identifier for this selection.
            path: Path to the Parquet file or directory.
            batch_size: Number of rows per batch (default: 100000).
            threads: Number of threads for parallel reading (default: 4).
            filters_dict: Optional dictionary of column filters to apply.
        """
        self.name = name
        self.path = path
        self.batch_size = batch_size
        self.threads = threads
        self.filters_dict = filters_dict
        self.columns_list: list[str] | None = None
        self.dataset: pq.ParquetDataset | None = None
        # BUGFIX: filter_expr was previously assigned only in bootstrap(),
        # so the attribute did not exist on a freshly constructed instance.
        self.filter_expr: pc.Expression | None = None
        self.samples_count: int = 0

    @override
    def bootstrap(self, columns_list: list[str]) -> None:
        """Prepare the selection: build the filter expression, open the dataset, count rows.

        Args:
            columns_list: Columns to read. Columns referenced by filters_dict are
                appended if missing, since RecordBatch.filter() needs them present.
        """
        self.columns_list = columns_list
        filter_expr = None
        if self.filters_dict is not None:
            for col, val in self.filters_dict.items():
                # The filtered column must be read for the row-level filter to apply.
                if col not in self.columns_list:
                    self.columns_list.append(col)
                col_expr = pc.equal(pc.field(col), val)
                filter_expr = col_expr if filter_expr is None else pc.and_(filter_expr, col_expr)
        self.filter_expr = filter_expr
        self.dataset = pq.ParquetDataset(self.path, filters=filter_expr)
        # sum() over an empty fragment list is already 0, so no branch is needed.
        # NOTE(review): count_rows() reflects fragment-level pruning; row-level
        # filtering during iteration may yield fewer rows — confirm if exactness matters.
        self.samples_count = sum(frag.count_rows() for frag in self.dataset.fragments)

    def __len__(self) -> int:
        """Return the number of samples counted at bootstrap time (0 before bootstrap)."""
        return int(self.samples_count)

    @override
    def get_nb_batches(self) -> int:
        """Return the number of batches needed to cover all samples.

        Uses pure integer ceiling division: the previous float-based
        int(len(self) / self.batch_size) could round incorrectly for
        very large row counts.
        """
        return -(-len(self) // self.batch_size)

    @override
    def __iter__(self) -> Any:
        """Yield non-empty pyarrow RecordBatches, applying the row filter if any."""
        if self.dataset is None:
            # bootstrap() has not been called: nothing to iterate.
            return
        for file in self.dataset.files:
            parquet_file = pq.ParquetFile(file)
            batch_iterator = parquet_file.iter_batches(
                batch_size=self.batch_size,
                columns=self.columns_list,
                # iter_batches' use_threads is a boolean flag, not a thread
                # count; any positive `threads` value enables threading.
                use_threads=bool(self.threads),
            )
            for batch in batch_iterator:
                if self.filter_expr is not None:
                    batch = batch.filter(self.filter_expr)
                if len(batch) == 0:
                    continue
                yield batch

    @override
    def __repr__(self) -> str:
        return f"ParquetSelection(name='{self.name}', path='{self.path}', filters={self.filters_dict})"
class ParquetDataLoader:
    """Data loader for Parquet files that generates one or more DataSelections.

    This loader can read from a single Parquet file or a directory of Parquet
    files, optionally splitting the data by a column value to create multiple
    selections.

    Attributes:
        type: The loader type identifier ("parquet").
    """

    type: str = "parquet"

    def __init__(self, name: str, config: dict[str, Any] | None = None):
        """Initialize the Parquet data loader.

        Args:
            name: Unique name for this loader instance.
            config: Configuration dictionary containing:
                - path: Path to Parquet file or directory (required)
                - batch_size: Rows per batch (default: 100000)
                - threads: Number of threads (default: 4)
                - split_by: Column name to split selections by
                - split_values: Specific values to split on
                - filter: Dictionary of column filters

        Raises:
            ValueError: If required config keys are missing.
        """
        if not config or "path" not in config:
            raise ValueError(f"Configuration for dataloader '{name}' must contain 'path'")

        self.name = name
        self.config = config
        self.path = config["path"]
        self.batch_size = config.get("batch_size", 100_000)
        self.threads = config.get("threads", 4)
        self.split_by = config.get("split_by")
        self.split_values = config.get("split_values")
        self.filters_dict = config.get("filter", None)

    def get_selections(self) -> list[DataSelection]:
        """Create one or more ParquetDataSelection instances based on configuration.

        Returns:
            A list of DataSelection instances. If split_by is configured,
            returns one selection per unique value. Otherwise, returns a
            single selection for the entire dataset.
        """
        if not self.split_by:
            # No split requested: a single selection covers the whole dataset.
            return [
                ParquetDataSelection(
                    name=self.name,
                    path=self.path,
                    batch_size=self.batch_size,
                    threads=self.threads,
                    filters_dict=self.filters_dict,
                )
            ]

        values = self.split_values
        if values is None:
            # Automatic discovery when split_values is not provided: read only
            # the split column and take its distinct non-null values.
            # Lazy %-style args avoid formatting when INFO is disabled.
            logger.info("Discovering unique values for split_by='%s' in %s", self.split_by, self.path)
            table = pq.read_table(self.path, columns=[self.split_by])
            # NOTE(review): values are coerced to str; if the split column is
            # not a string type, the pushed-down equality filter may not match
            # any rows — confirm against actual schemas.
            values = [str(v) for v in pc.unique(table.column(self.split_by)).to_pylist() if v is not None]

        selections: list[DataSelection] = []
        for val in values:
            selection_name = f"{self.name}_{val}"
            # Merge existing filters with the split filter; .copy() keeps
            # self.filters_dict untouched across iterations.
            merged_filters = (self.filters_dict or {}).copy()
            merged_filters[self.split_by] = val

            selections.append(
                ParquetDataSelection(
                    name=selection_name,
                    path=self.path,
                    batch_size=self.batch_size,
                    threads=self.threads,
                    filters_dict=merged_filters,
                )
            )
        return selections