Beispiel #1
0
    def set_splits(self, split_dict: splits_lib.SplitDict) -> None:
        """Split setter (private method).

        Stores the splits from `split_dict` on this object and mirrors them
        into `self.as_proto.splits`. When an incoming split carries no
        statistics but the currently stored split of the same name does (and
        both have equal `shard_lengths`), the stored statistics are carried
        over to the new split.

        Args:
          split_dict: The new splits. Its dataset name must equal
            `self._builder.name`.

        Raises:
          AssertionError: If the dataset name of `split_dict` differs from
            `self._builder.name`.
        """
        # pylint: disable=protected-access
        if self._builder.name != split_dict._dataset_name:
            raise AssertionError(
                "SplitDict dataset_name does not seems to match dataset_info. "
                f"{self._builder.name} != {split_dict._dataset_name}")
        # pylint: enable=protected-access

        # If the statistics have been pre-loaded, forward the statistics
        # into the new split_dict
        new_split_infos = []
        for split_info in split_dict.values():
            old_split_info = self._splits.get(split_info.name)
            if (not split_info.statistics.ByteSize() and old_split_info
                    and old_split_info.statistics.ByteSize() and
                    old_split_info.shard_lengths == split_info.shard_lengths):
                # Equal shard lengths are the guard for reusing the old stats.
                split_info = split_info.replace(
                    statistics=old_split_info.statistics)
            new_split_infos.append(split_info)

        # Update the dictionary representation.
        self._splits = splits_lib.SplitDict(
            new_split_infos,
            dataset_name=self._builder.name,
        )

        # Update the proto from the stored splits (not from the incoming
        # `split_dict`), so that forwarded statistics are serialized too and
        # the proto stays in sync with `self._splits`.
        del self.as_proto.splits[:]  # Clear previous
        for split_proto in self._splits.to_proto():
            self.as_proto.splits.add().CopyFrom(split_proto)
Beispiel #2
0
    def set_splits(self, split_dict: splits_lib.SplitDict) -> None:
        """Split setter (private method).

        Validates the incoming splits, stores them on this object and mirrors
        them into `self.as_proto.splits`.

        For every split that is not a `MultiSplitInfo`:
          * its filename template, when set, must name the same dataset as
            `self.name`;
          * statistics already stored for a split of the same name are reused
            when the incoming split has none and the shard lengths are equal;
          * a filename template is filled in when the split has none.

        `MultiSplitInfo` entries are stored unchanged.

        Args:
          split_dict: The new splits.

        Raises:
          AssertionError: If a split's filename template refers to a dataset
            name different from `self.name`.
        """
        # Validate everything up front, before any state is modified.
        for split_name, info in split_dict.items():
            # When splits are from multiple folders, the dataset can be
            # different, so those entries are not checked.
            if isinstance(info, splits_lib.MultiSplitInfo):
                continue
            template = info.filename_template
            if not template:
                continue
            if self.name != template.dataset_name:
                raise AssertionError(
                    f"SplitDict contains SplitInfo for split {split_name} whose "
                    "dataset_name does not match to the dataset name in dataset_info. "
                    f"{self.name} != {template.dataset_name}"
                )

        # Template shared by all splits that do not bring their own; the
        # split name is filled in per split below.
        base_template = naming.ShardedFileTemplate(
            dataset_name=self.name,
            data_dir=self.data_dir,
            filetype_suffix=(self.as_proto.file_format
                             or file_adapters.DEFAULT_FILE_FORMAT.value),
        )

        updated_infos = []
        for info in split_dict.values():
            if not isinstance(info, splits_lib.MultiSplitInfo):
                previous = self._splits.get(info.name)
                stats_missing = not info.statistics.ByteSize()
                # Forward pre-loaded statistics into the new split.
                if (stats_missing and previous
                        and previous.statistics.ByteSize()
                        and previous.shard_lengths == info.shard_lengths):
                    info = info.replace(statistics=previous.statistics)
                # Add the filename template if it's not set.
                if not info.filename_template:
                    info = info.replace(
                        filename_template=base_template.replace(
                            split=info.name))
            updated_infos.append(info)

        # Update the dictionary representation.
        self._splits = splits_lib.SplitDict(updated_infos)

        # Update the proto.
        # Note that the proto should not be saved or used for multi-folder
        # datasets.
        del self.as_proto.splits[:]  # Clear previous
        for info in self._splits.values():
            # A multi-split contributes one proto entry per member split.
            if isinstance(info, splits_lib.MultiSplitInfo):
                members = info.split_infos
            else:
                members = (info,)
            for member in members:
                self.as_proto.splits.add().CopyFrom(member.to_proto())