def set_splits(self, split_dict: splits_lib.SplitDict) -> None:
    """Split setter (private method).

    Replaces the stored splits with the ones from `split_dict` and mirrors
    them into `self.as_proto.splits`. A new split that carries no statistics
    inherits the statistics of the previously stored split with the same
    name, provided both have identical shard lengths.

    Args:
      split_dict: The new splits. Its dataset name must be equal to
        `self._builder.name`.

    Raises:
      AssertionError: If the dataset name of `split_dict` differs from
        `self._builder.name`.
    """
    if self._builder.name != split_dict._dataset_name:  # pylint: disable=protected-access
        raise AssertionError(
            "SplitDict dataset_name does not seems to match dataset_info. "
            f"{self._builder.name} != {split_dict._dataset_name}")  # pylint: disable=protected-access

    # If the statistics have been pre-loaded, forward the statistics
    # into the new split_dict.
    new_split_infos = []
    for split_info in split_dict.values():
        old_split_info = self._splits.get(split_info.name)
        # Only reuse the old statistics when the new split has none of its
        # own and the shard layout is unchanged.
        if (not split_info.statistics.ByteSize() and old_split_info and
                old_split_info.statistics.ByteSize() and
                old_split_info.shard_lengths == split_info.shard_lengths):
            split_info = split_info.replace(
                statistics=old_split_info.statistics)
        new_split_infos.append(split_info)

    # Update the dictionary representation.
    self._splits = splits_lib.SplitDict(
        new_split_infos,
        dataset_name=self._builder.name,
    )

    # Update the proto from the merged splits rather than from the raw
    # `split_dict`, so that the forwarded statistics are serialized too and
    # the proto stays in sync with `self._splits`.
    del self.as_proto.splits[:]  # Clear previous
    for split_proto in self._splits.to_proto():
        self.as_proto.splits.add().CopyFrom(split_proto)
def set_splits(self, split_dict: splits_lib.SplitDict) -> None:
    """Install `split_dict` as the splits of this object (private method).

    The splits are first validated, then completed (statistics and filename
    template), and finally written to both `self._splits` and the `splits`
    field of `self.as_proto`.

    Args:
      split_dict: The splits to install.

    Raises:
      AssertionError: If a split has a filename template whose dataset name
        differs from `self.name`. `MultiSplitInfo` entries are not checked.
    """
    # Step 1: every regular split with a template must name this dataset.
    for split, split_info in split_dict.items():
        # Splits gathered from multiple folders may belong to another dataset.
        if isinstance(split_info, splits_lib.MultiSplitInfo):
            continue
        if not split_info.filename_template:
            continue
        if self.name != split_info.filename_template.dataset_name:
            raise AssertionError(
                f"SplitDict contains SplitInfo for split {split} whose "
                "dataset_name does not match to the dataset name in dataset_info. "
                f"{self.name} != {split_info.filename_template.dataset_name}"
            )

    # Step 2: complete each split. The template below lacks only the split
    # name, which is filled in per split when a split has no template.
    base_template = naming.ShardedFileTemplate(
        dataset_name=self.name,
        data_dir=self.data_dir,
        filetype_suffix=(self.as_proto.file_format or
                         file_adapters.DEFAULT_FILE_FORMAT.value))

    def _completed(info):
        """Returns `info` with statistics and filename template filled in."""
        if isinstance(info, splits_lib.MultiSplitInfo):
            return info
        previous = self._splits.get(info.name)
        # Reuse pre-loaded statistics only if the new split has none and
        # the shard lengths are unchanged.
        if not info.statistics.ByteSize() and previous:
            if (previous.statistics.ByteSize() and
                    previous.shard_lengths == info.shard_lengths):
                info = info.replace(statistics=previous.statistics)
        if not info.filename_template:
            split_template = base_template.replace(split=info.name)
            info = info.replace(filename_template=split_template)
        return info

    completed_infos = [_completed(info) for info in split_dict.values()]

    # Step 3: refresh the dictionary representation.
    self._splits = splits_lib.SplitDict(completed_infos)

    # Step 4: refresh the proto.
    # Note that the proto should not be saved or used for multi-folder datasets.
    del self.as_proto.splits[:]  # Drop the previous entries.
    for stored in self._splits.values():
        if isinstance(stored, splits_lib.MultiSplitInfo):
            parts = stored.split_infos
        else:
            parts = [stored]
        for part in parts:
            self.as_proto.splits.add().CopyFrom(part.to_proto())