예제 #1
0
    def _slice_split_info_to_instruction_dicts(self, list_sliced_split_info):
        """Return the list of files and reading mask of the files to read.

        For every sliced split, the split's files are resolved (sorted by
        name), the slice value is converted into a percent mask, and each
        file is paired with the mask offset at which reading should start.

        Args:
          list_sliced_split_info: iterable of objects exposing `slice_value`
            and `split_info`; `split_info` must expose `num_shards` and
            `num_examples`.

        Returns:
          A flat list of dicts, one per (file, offset) pair across all the
          given slices, each with the keys "filepath", "mask" and
          "mask_offset". All dicts produced for the same slice share the same
          `mask` object.
        """
        instruction_dicts = []
        for sliced_split_info in list_sliced_split_info:
            split_info = sliced_split_info.split_info
            mask = splits_lib.slice_to_percent_mask(
                sliced_split_info.slice_value)

            # Compute filenames from the given split. `sorted` already returns
            # a fresh list, so no extra `list()` wrapper is needed.
            filepaths = sorted(self._build_split_filenames(split_info))

            # Compute the offsets
            if split_info.num_examples:
                shard_id2num_examples = splits_lib.get_shard_id2num_examples(
                    split_info.num_shards,
                    split_info.num_examples,
                )
                mask_offsets = splits_lib.compute_mask_offsets(
                    shard_id2num_examples)
            else:
                # Without per-split statistics the per-shard offsets cannot be
                # derived, so every file starts at offset 0.
                logging.warning(
                    "Statistics not present in the dataset. TFDS is not able to load "
                    "the total number of examples, so using the subsplit API may not "
                    "provide precise subsplits.")
                mask_offsets = [0] * len(filepaths)

            # NOTE(review): zip() stops at the shorter sequence; this assumes
            # the number of files matches `num_shards` when statistics are
            # present — TODO confirm, otherwise trailing files are dropped.
            instruction_dicts.extend(
                {
                    "filepath": filepath,
                    "mask": mask,
                    "mask_offset": mask_offset,
                }
                for filepath, mask_offset in zip(filepaths, mask_offsets)
            )
        return instruction_dicts
 def test_compute_mask_offsets(self):
     """Check per-shard mask offsets for several shard-size layouts."""
     # Each case: (number of examples per shard, expected mask offsets).
     cases = (
         ([1100, 500, 1100, 110], [0, 0, 0, 0]),
         ([1101, 500, 1100, 110], [0, 1, 1, 1]),
         ([87], [0]),
         ([1101, 501, 1113, 110], [0, 1, 2, 15]),
     )
     for shard_sizes, expected_offsets in cases:
         actual_offsets = splits.compute_mask_offsets(shard_sizes)
         self.assertEqual(actual_offsets, expected_offsets)