Coverage for packages / dqm-ml-job / src / dqm_ml_job / dataloaders / parquet.py: 98%

83 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-15 10:11 +0000

1"""Parquet data loader for reading Parquet files. 

2 

3This module contains the ParquetDataLoader and ParquetDataSelection classes 

4for loading and iterating over Parquet file data. 

5""" 

6 

7import logging 

8from typing import Any 

9 

10import pyarrow.compute as pc 

11import pyarrow.parquet as pq 

12 

13# COMPATIBILITY : from typing import Any, override # When support of 3.10 and 3.11 will be removed 

14from typing_extensions import override 

15 

16from dqm_ml_job.dataloaders.proto import DataSelection 

17 

18logger = logging.getLogger(__name__) 

19 

20 

class ParquetDataSelection(DataSelection):
    """A specific selection of data from a Parquet dataset.

    This class represents a filtered subset of a Parquet dataset
    and provides an iterator over PyArrow RecordBatches.

    Attributes:
        name: Name identifier for this selection.
        path: Path to the Parquet file or directory.
        batch_size: Number of rows per batch.
        threads: Number of threads for parallel reading.
        filters_dict: Optional dictionary of column filters to apply.
    """

    def __init__(
        self,
        name: str,
        path: str,
        batch_size: int = 100_000,
        threads: int = 4,
        filters_dict: dict[str, Any] | None = None,
    ):
        """Initialize a Parquet data selection.

        Args:
            name: Name identifier for this selection.
            path: Path to the Parquet file or directory.
            batch_size: Number of rows per batch (default: 100000).
            threads: Number of threads for parallel reading (default: 4).
            filters_dict: Optional dictionary of column filters to apply.
        """
        self.name = name
        self.path = path
        self.batch_size = batch_size
        self.threads = threads
        self.filters_dict = filters_dict
        self.columns_list: list[str] | None = None
        self.dataset: pq.ParquetDataset | None = None
        self.samples_count: int = 0
        # Fix: previously only assigned inside bootstrap(), leaving a latent
        # AttributeError if __iter__ were reached on a non-bootstrapped
        # instance. Initialize it here so the object is always consistent.
        self.filter_expr: Any = None

    @override
    def bootstrap(self, columns_list: list[str]) -> None:
        """Open the dataset, build the filter expression, and count rows.

        Builds a conjunction of per-column equality predicates from
        ``filters_dict``. Each filtered column is appended to
        ``columns_list`` (mutated in place) so the filter can be re-applied
        batch-by-batch in ``__iter__``.

        Args:
            columns_list: Columns to read; extended with filter columns.
        """
        self.columns_list = columns_list
        filter_expr = None
        if self.filters_dict is not None:
            expr = None
            for col, val in self.filters_dict.items():
                if col not in self.columns_list:
                    # Filter columns must be read for batch.filter() to work.
                    self.columns_list.append(col)
                col_expr = pc.equal(pc.field(col), val)
                # AND all equality predicates together.
                expr = col_expr if expr is None else pc.and_(expr, col_expr)
            filter_expr = expr
        self.filter_expr = filter_expr
        self.dataset = pq.ParquetDataset(self.path, filters=filter_expr)
        # sum() over an empty fragment list is already 0, so the previous
        # explicit empty-check was redundant.
        # NOTE(review): fragment.count_rows() may not reflect row-level
        # filtering, so samples_count can overcount when filters_dict is
        # set — confirm against actual iteration counts.
        self.samples_count = sum(fragment.count_rows() for fragment in self.dataset.fragments)

    def __len__(self) -> int:
        """Return the total number of rows in this selection."""
        return int(self.samples_count)

    @override
    def get_nb_batches(self) -> int:
        """Return how many batches are needed to cover all rows.

        Uses pure-integer ceiling division. The previous float-based
        ``int(len / batch_size)`` could round incorrectly for very large
        row counts due to float precision.
        """
        return -(-len(self) // self.batch_size)

    @override
    def __iter__(self) -> Any:
        """Yield non-empty PyArrow RecordBatches from every dataset file.

        Yields nothing if ``bootstrap`` has not been called (dataset unset).
        """
        if self.dataset is None:
            return
        for file in self.dataset.files:
            parquet_file = pq.ParquetFile(file)
            batch_iterator = parquet_file.iter_batches(
                # NOTE(review): iter_batches' use_threads is documented as a
                # bool; passing a positive thread count is merely truthy here
                # — confirm intent.
                batch_size=self.batch_size, columns=self.columns_list, use_threads=self.threads
            )
            for batch in batch_iterator:
                if self.filter_expr is not None:
                    batch = batch.filter(self.filter_expr)
                if len(batch) == 0:
                    # Skip batches fully removed by the filter.
                    continue
                yield batch

    @override
    def __repr__(self) -> str:
        """Return a concise debug representation of this selection."""
        return f"ParquetSelection(name='{self.name}', path='{self.path}', filters={self.filters_dict})"

106 

107 

class ParquetDataLoader:
    """Data loader for Parquet files that generates one or more DataSelections.

    This loader can read from a single Parquet file or a directory of Parquet
    files, optionally splitting the data by a column value to create multiple
    selections.

    Attributes:
        type: The loader type identifier ("parquet").
    """

    type: str = "parquet"

    def __init__(self, name: str, config: dict[str, Any] | None = None):
        """Initialize the Parquet data loader.

        Args:
            name: Unique name for this loader instance.
            config: Configuration dictionary containing:
                - path: Path to Parquet file or directory (required)
                - batch_size: Rows per batch (default: 100000)
                - threads: Number of threads (default: 4)
                - split_by: Column name to split selections by
                - split_values: Specific values to split on
                - filter: Dictionary of column filters

        Raises:
            ValueError: If required config keys are missing.
        """
        if not config or "path" not in config:
            raise ValueError(f"Configuration for dataloader '{name}' must contain 'path'")

        self.name = name
        self.config = config
        self.path = config["path"]
        self.batch_size = config.get("batch_size", 100_000)
        self.threads = config.get("threads", 4)
        self.split_by = config.get("split_by")
        self.split_values = config.get("split_values")
        self.filters_dict = config.get("filter", None)

    def get_selections(self) -> list[DataSelection]:
        """Create one or more ParquetDataSelection instances based on configuration.

        Returns:
            A list of DataSelection instances. If split_by is configured,
            returns one selection per unique value. Otherwise, returns a
            single selection for the entire dataset.
        """
        if not self.split_by:
            # Single selection covering the whole dataset.
            return [
                ParquetDataSelection(
                    name=self.name,
                    path=self.path,
                    batch_size=self.batch_size,
                    threads=self.threads,
                    filters_dict=self.filters_dict,
                )
            ]

        # Splitting logic: one selection per split value.
        values = self.split_values
        if values is None:
            # Automatic discovery if split_values not provided.
            # Fix: use lazy %-style logging args instead of an f-string so
            # formatting cost is only paid when the message is emitted.
            logger.info("Discovering unique values for split_by='%s' in %s", self.split_by, self.path)
            table = pq.read_table(self.path, columns=[self.split_by])
            # NOTE(review): str() coercion assumes discovered values compare
            # equal to their string form in the equality filter (i.e. the
            # split column is string-typed) — confirm for numeric columns.
            values = [str(v) for v in pc.unique(table.column(0)).to_pylist() if v is not None]

        selections: list[DataSelection] = []
        for val in values:
            selection_name = f"{self.name}_{val}"
            # Merge existing filters with the split filter; copy so the
            # shared filters_dict is never mutated across selections.
            merged_filters = (self.filters_dict or {}).copy()
            merged_filters[self.split_by] = val

            selections.append(
                ParquetDataSelection(
                    name=selection_name,
                    path=self.path,
                    batch_size=self.batch_size,
                    threads=self.threads,
                    filters_dict=merged_filters,
                )
            )
        return selections