Exemplo n.º 1
0
  def setUp(self):
    """Creates the per-split `SplitInfo` fixtures used by these tests.

    `self.splits` maps each split name to a zero-byte `SplitInfo` whose shard
    lengths are listed below and whose filename template comes from
    `_filename_template(split_name)`.
    """
    super(ReadInstructionTest, self).setUp()

    shards_per_split = (
        ('train', [200]),
        ('test', [101]),
        ('validation', [30]),
        ('dev-train', [5, 5]),
    )
    self.splits = {
        split_name: splits.SplitInfo(
            name=split_name,
            shard_lengths=lengths,
            num_bytes=0,
            filename_template=_filename_template(split_name),
        )
        for split_name, lengths in shards_per_split
    }
Exemplo n.º 2
0
def test_even_splits_subsplit():
    """`even_splits` on a multi-split spec matches hand-written slices."""
    train_info = splits_lib.SplitInfo(
        name='train',
        shard_lengths=[2, 3, 2, 3],  # 10 examples in total.
        num_bytes=0,
        filename_template=_filename_template('train'),
    )
    test_info = splits_lib.SplitInfo(
        name='test',
        shard_lengths=[8],
        num_bytes=0,
        filename_template=_filename_template('test'),
    )
    split_infos = splits_lib.SplitDict(split_infos=[train_info, test_info])

    # Divide `train` + second half of `test` into 3 subsplits.
    subsplits = subsplits_utils.even_splits('train+test[50%:]', 3)

    expected = (
        'train[:4]+test[4:6]',
        'train[4:7]+test[6:7]',
        'train[7:]+test[7:8]',
    )

    def instructions_for(specs):
        return [split_infos[spec].file_instructions for spec in specs]

    assert instructions_for(subsplits) == instructions_for(expected)
Exemplo n.º 3
0
 def test_set_splits_normal(self):
     """`set_splits` stores the dict and mirrors it into `as_proto.splits`."""
     info = dataset_info.DatasetInfo(builder=self._builder)
     train_info = splits_lib.SplitInfo(
         name="train", shard_lengths=[1, 2], num_bytes=0)
     test_info = splits_lib.SplitInfo(
         name="test", shard_lengths=[1], num_bytes=0)
     split_dict = splits_lib.SplitDict(split_infos=[train_info, test_info])

     info.set_splits(split_dict)

     self.assertEqual(str(info.splits), str(split_dict))
     self.assertEqual(
         str(info.as_proto.splits),
         str([one.to_proto() for one in (train_info, test_info)]))
Exemplo n.º 4
0
  def __init__(
      self,
      root_dir: str,
      *,
      shape: Optional[type_utils.Shape] = None,
      dtype: Optional[tf.DType] = None,
  ):
    """Construct the `DatasetBuilder`.

    Args:
      root_dir: Path to the directory containing the images. A leading `~` is
        expanded, and the expanded path is used both for scanning the images
        and as the builder's `data_dir`.
      shape: Image shape forwarded to `tfds.features.Image`.
      dtype: Image dtype forwarded to `tfds.features.Image`.
    """
    self._image_shape = shape
    self._image_dtype = dtype
    # Expand `~` once, before it is stored, so that `data_dir` and the
    # directory scan below refer to the same path (previously `data_dir`
    # kept the unexpanded string).
    root_dir = os.path.expanduser(root_dir)
    super(ImageFolder, self).__init__()
    self._data_dir = root_dir  # Set data_dir to the existing dir.

    # Extract the splits, examples, labels
    self._split_examples, labels = _get_split_label_images(root_dir)

    # Update DatasetInfo labels (sorted so the label order is deterministic).
    self.info.features['label'].names = sorted(labels)

    # Update DatasetInfo splits: each split is recorded as a single shard
    # holding all of its examples.
    split_dict = split_lib.SplitDict(self.name)
    for split_name, examples in self._split_examples.items():
      split_dict.add(split_lib.SplitInfo(
          name=split_name,
          shard_lengths=[len(examples)],
      ))
    self.info.update_splits_if_different(split_dict)
Exemplo n.º 5
0
def _initialize_split(split_name: str,
                      data_directory: Any,
                      ds_name: str,
                      filetype_suffix: str,
                      shard_lengths: Optional[List[int]] = None,
                      num_bytes: int = 0) -> Split:
    """Initializes a split.

    Args:
      split_name: name of the split.
      data_directory: directory where the split data will be located.
      ds_name: name of the dataset.
      filetype_suffix: file format.
      shard_lengths: lengths of the shards the split already has, if any. If
        None (or empty), the split is assumed to be empty.
      num_bytes: number of bytes that have been written already.

    Returns:
      A Split whose `complete_shards` equals the number of existing shards.
    """
    lengths = shard_lengths or []
    template = naming.ShardedFileTemplate(
        dataset_name=ds_name,
        data_dir=data_directory,
        split=split_name,
        filetype_suffix=filetype_suffix,
        template='{DATASET}-{SPLIT}.{FILEFORMAT}-{SHARD_INDEX}',
    )
    info = splits_lib.SplitInfo(
        name=split_name,
        shard_lengths=lengths,
        num_bytes=num_bytes,
        filename_template=template,
    )
    return Split(info=info, complete_shards=len(lengths), ds_name=ds_name)
Exemplo n.º 6
0
 def test_missing_shard_lengths(self):
     """A split declared with no shards cannot produce file instructions."""
     with self.assertRaisesWithPredicateMatch(ValueError, 'Shard empty.'):
         empty_split_infos = [splits.SplitInfo(name='train', shard_lengths=[])]
         tfrecords_reader.make_file_instructions(
             'mnist', empty_split_infos, 'train')
Exemplo n.º 7
0
 def test_nodata_instruction(self):
     """Reading the empty slice `train[0:0]` is rejected."""
     with self.assertRaisesWithPredicateMatch(AssertionError,
                                              'corresponds to no data!'):
         train_infos = [
             splits.SplitInfo(name='train', shard_lengths=[2, 3, 2, 3, 2]),
         ]
         self.reader.read('mnist', 'train[0:0]', train_infos)
Exemplo n.º 8
0
    def _build_from_generator(
        self,
        split_name: str,
        generator: Iterable[KeyExample],
        path: type_utils.PathLike,
    ) -> _SplitInfoFuture:
        """Encodes and writes one split's examples, returning its info future.

        Args:
          split_name: name of the split; also used as the writer's hash salt.
          generator: iterable of `(key, example)` pairs to encode and write.
          path: destination path passed to `tfrecords_writer.Writer`.

        Returns:
          future: The future containing the `tfds.core.SplitInfo`.
        """
        cap = self._max_examples_per_split
        if cap is None:
            # Pre-downloaded dataset info (if any) supplies the expected number
            # of examples, which is only used to size the progress bar.
            known = self._split_dict.get(split_name)
            total_num_examples = (
                known.num_examples if known and known.num_examples else None)
        else:
            logging.warning('Splits capped at %s examples max.', cap)
            generator = itertools.islice(generator, cap)
            total_num_examples = cap

        writer = tfrecords_writer.Writer(
            example_specs=self._features.get_serialized_info(),
            path=path,
            hash_salt=split_name,
            file_format=self._file_format,
        )
        progress_bar = utils.tqdm(
            generator,
            desc=f'Generating {split_name} examples...',
            unit=' examples',
            total=total_num_examples,
            leave=False,
        )
        for key, example in progress_bar:
            try:
                example = self._features.encode_example(example)
            except Exception as e:  # pylint: disable=broad-except
                utils.reraise(
                    e, prefix=f'Failed to encode example:\n{example}\n')
            writer.write(key, example)
        shard_lengths, total_size = writer.finalize()

        result = splits_lib.SplitInfo(
            name=split_name,
            shard_lengths=shard_lengths,
            num_bytes=total_size,
        )
        return _SplitInfoFuture(lambda: result)
Exemplo n.º 9
0
def test_even_splits_add():
    """`even_splits` results compose with `+` and with nested `even_splits`."""
    layout = (
        ('train', [2, 3, 2, 3]),  # 10 examples.
        ('test', [8]),
        ('validation', [8]),
    )
    split_infos = splits_lib.SplitDict(split_infos=[
        splits_lib.SplitInfo(  # pylint: disable=g-complex-comprehension
            name=name,
            shard_lengths=lengths,
            num_bytes=0,
            filename_template=_filename_template(name),
        ) for name, lengths in layout
    ])

    def instructions(spec):
        return split_infos[spec].file_instructions

    # First third of train (remainder dropped) concatenated with all of test.
    first_third = subsplits_utils.even_splits('train', 3,
                                              drop_remainder=True)[0]
    combined = first_third + 'test'
    assert instructions(combined) == instructions('train[:3]+test')

    # Halving the second half of validation yields its last two quarters.
    halves = subsplits_utils.even_splits('validation', n=2)
    quarters = subsplits_utils.even_splits(halves[1], n=2)
    assert instructions(quarters[0]) == instructions('validation[4:6]')
    assert instructions(quarters[1]) == instructions('validation[6:8]')
Exemplo n.º 10
0
 def _get_files(self, instruction):
   """Resolves `instruction` against a 13-example, 5-shard mnist `train`."""
   train_info = splits.SplitInfo(
       name='train',
       shard_lengths=[3, 2, 3, 2, 3],  # 13 examples.
       num_bytes=0,
       filename_template=_filename_template(
           dataset_name='mnist', split='train'),
   )
   split_dict = splits.SplitDict(split_infos=[train_info])
   return split_dict[instruction].file_instructions
Exemplo n.º 11
0
 def _resolve_future():
   """Finalizes the Beam writer and builds the split's `SplitInfo`.

   Raises:
     AssertionError: if called while still inside the `maybe_beam_pipeline`
       context manager.
   """
   if self._in_contextmanager:
     raise AssertionError('`future.result()` should be called after the '
                          '`maybe_beam_pipeline` contextmanager.')
   logging.info('Retrieving split info for %s...', split_name)
   lengths_per_shard, num_bytes = beam_writer.finalize()
   split_info = splits_lib.SplitInfo(
       name=split_name,
       shard_lengths=lengths_per_shard,
       num_bytes=num_bytes,
       filename_template=filename_template,
   )
   return split_info
Exemplo n.º 12
0
 def test_missing_shard_lengths(self):
   """Asking for the file instructions of a shard-less split raises."""
   with self.assertRaisesWithPredicateMatch(ValueError, 'Shard empty.'):
     train_info = splits.SplitInfo(
         name='train',
         shard_lengths=[],
         num_bytes=0,
         filename_template=_filename_template(
             split='train', dataset_name='mnist'),
     )
     split_dict = splits.SplitDict(split_infos=[train_info])
     _ = split_dict['train'].file_instructions
Exemplo n.º 13
0
 def close_shard(self) -> None:
     """Closes the open shard, if any, and folds it into `self.info`.

     The shard's example count is appended to `shard_lengths`, its byte size
     is added to `num_bytes`, `complete_shards` is incremented and
     `current_shard` is reset to None. Does nothing when no shard is open.
     """
     shard = self.current_shard
     if shard:
         shard.close_writer()
         previous = self.info
         self.info = splits_lib.SplitInfo(
             name=previous.name,
             shard_lengths=previous.shard_lengths + [shard.num_examples],
             num_bytes=previous.num_bytes + shard.num_bytes,
             filename_template=previous.filename_template)
         self.complete_shards += 1
         self.current_shard = None
Exemplo n.º 14
0
 def test_nodata_instruction(self):
     """Reading the empty slice `train[0:0]` raises a ValueError."""
     with self.assertRaisesWithPredicateMatch(ValueError,
                                              'corresponds to no data!'):
         split_infos = [
             splits.SplitInfo(
                 name='train',
                 shard_lengths=[2, 3, 2, 3, 2],
                 num_bytes=0,
                 filename_template=self._filename_template(split='train'),
             ),
         ]
         self.reader.read(instructions='train[0:0]', split_infos=split_infos)
Exemplo n.º 15
0
 def test_set_splits_incorrect_dataset_name(self):
     """`set_splits` rejects a SplitInfo templated for a different dataset."""
     info = dataset_info.DatasetInfo(builder=self._builder)
     foreign_template = naming.ShardedFileTemplate(
         dataset_name="some_other_dataset",
         split="train",
         data_dir=info.data_dir,
         filetype_suffix="tfrecord")
     foreign_split = splits_lib.SplitInfo(
         name="train",
         shard_lengths=[1, 2],
         num_bytes=0,
         filename_template=foreign_template)
     split_dict = splits_lib.SplitDict(split_infos=[foreign_split])
     expected_message = "SplitDict contains SplitInfo for split"
     with pytest.raises(AssertionError, match=expected_message):
         info.set_splits(split_dict)
Exemplo n.º 16
0
 def _write_tfrecord(self, split_name, shards_number, records):
   """Writes `records` as `shards_number` mnist shards; returns the SplitInfo."""
   path = os.path.join(self.tmp_dir, 'mnist-%s.tfrecord' % split_name)
   total = len(records)
   force_num_shards = absltest.mock.patch.object(
       tfrecords_writer, '_get_number_shards', return_value=shards_number)
   with force_num_shards:
     shard_specs = tfrecords_writer._get_shard_specs(total, 0, [total], path)
   payload = [six.b(record) for record in records]
   for spec in shard_specs:
     _write_tfrecord_from_shard_spec(spec, lambda unused_i: iter(payload))
   return splits.SplitInfo(
       name=split_name,
       shard_lengths=[int(spec.examples_number) for spec in shard_specs],
   )
Exemplo n.º 17
0
def _merge_shard_info(shard_infos: List[_ShardInfo]) -> split_lib.SplitInfo:
    """Merges the shard infos of a single split into one `SplitInfo`.

    Args:
      shard_infos: The shard infos from all shards of one split.

    Returns:
      A `SplitInfo` whose `shard_lengths` are ordered by shard index and whose
      `num_bytes` is the total size of all shards.

    Raises:
      ValueError: if `shard_infos` is empty or covers more than one split.
    """
    split_names = {s.file_info.split for s in shard_infos}
    if len(split_names) != 1:
        # A bare tuple-unpack would also raise ValueError, but its message
        # ("too many values to unpack") hides which splits were mixed.
        raise ValueError(
            'Expected shard infos from exactly one split, got: '
            f'{split_names!r}')
    split_name, = split_names
    shard_infos = sorted(shard_infos, key=lambda s: s.file_info.shard_index)
    return split_lib.SplitInfo(
        name=split_name,
        shard_lengths=[s.num_examples for s in shard_infos],
        num_bytes=sum(s.bytes_size for s in shard_infos),
    )
Exemplo n.º 18
0
    def __init__(self, root_dir: str):
        """Builds the dataset from the translation files under `root_dir`.

        Args:
          root_dir: directory holding the split data; `~` is expanded.
        """
        data_root = os.path.expanduser(root_dir)
        self._split_examples, self._languages = (
            _get_split_language_examples(data_root))

        super(TranslateFolder, self).__init__()
        # `_data_dir` must stay on the user folder rather than moving to
        # DATA_DIR/Version.
        self._data_dir = data_root

        # Each split is a single shard. Its length is taken from the first
        # value of the split's examples mapping (presumably one entry per
        # language, all of equal length -- TODO confirm).
        split_dict = split_lib.SplitDict(self.name)
        for split_name, examples in self._split_examples.items():
            first_column = next(iter(examples.values()))
            split_dict.add(
                split_lib.SplitInfo(
                    name=split_name,
                    shard_lengths=[len(first_column)],
                ))
        self.info.update_splits_if_different(split_dict)
Exemplo n.º 19
0
 def _write_tfrecord(self, split_name, shards_number, records):
   """Writes keyed `records` across `shards_number` shards; returns SplitInfo."""
   template = self._filename_template(split=split_name)
   count = len(records)
   with mock.patch.object(
       tfrecords_writer, '_get_number_shards', return_value=shards_number):
     shard_specs = tfrecords_writer._get_shard_specs(
         num_examples=count,
         total_size=0,
         bucket_lengths=[count],
         filename_template=template)
   # Each record is keyed by its position in `records`.
   keyed_records = list(enumerate(six.b(record) for record in records))
   for spec in shard_specs:
     _write_tfrecord_from_shard_spec(spec,
                                     lambda unused_i: iter(keyed_records))
   return splits.SplitInfo(
       name=split_name,
       shard_lengths=[int(spec.examples_number) for spec in shard_specs],
       num_bytes=0,
       filename_template=template)
Exemplo n.º 20
0
def test_even_splits(num_examples, n, drop_remainder, expected):
    """`even_splits('train', n)` resolves to the expected `train` slices."""
    train_info = splits_lib.SplitInfo(
        name='train',
        shard_lengths=[num_examples],
        num_bytes=0,
        filename_template=_filename_template('train'),
    )
    split_infos = splits_lib.SplitDict(split_infos=[train_info])

    subsplits = subsplits_utils.even_splits(
        'train', n, drop_remainder=drop_remainder)

    actual = [split_infos[spec].file_instructions for spec in subsplits]
    wanted = [
        split_infos[f'train{suffix}'].file_instructions for suffix in expected
    ]
    assert actual == wanted
Exemplo n.º 21
0
    def __init__(self, root_dir: str):
        """Builds the dataset from the translation files under `root_dir`.

        Args:
          root_dir: directory holding the split data; `~` is expanded.
        """
        root_dir = os.path.expanduser(root_dir)
        self._split_examples, self._languages = (
            _get_split_language_examples(root_dir))

        super(TranslateFolder, self).__init__()
        # `_data_dir` must stay on the user folder rather than moving to
        # DATA_DIR/Version.
        self._data_dir = root_dir

        # Each split is a single zero-byte shard whose length is taken from
        # the first value of the split's examples mapping.
        split_infos = []
        for split_name, examples in self._split_examples.items():
            num_examples = len(next(iter(examples.values())))
            split_infos.append(
                split_lib.SplitInfo(
                    name=split_name,
                    shard_lengths=[num_examples],
                    num_bytes=0,
                ))
        self.info.set_splits(
            split_lib.SplitDict(split_infos, dataset_name=self.name))
Exemplo n.º 22
0
    def __init__(self, root_dir: str):
        """Builds the dataset from the image folder at `root_dir`.

        Args:
          root_dir: Path to the directory containing the images. A leading `~`
            is expanded, and the expanded path is used both for scanning the
            images and as the builder's `data_dir`.
        """
        # Expand `~` before storing the path so that `data_dir` and the scan
        # below refer to the same directory (previously `data_dir` kept the
        # unexpanded string).
        root_dir = os.path.expanduser(root_dir)
        super(ImageFolder, self).__init__()
        self._data_dir = root_dir  # Set data_dir to the existing dir.

        # Extract the splits, examples, labels
        self._split_examples, labels = _get_split_label_images(root_dir)

        # Update DatasetInfo labels (sorted so the label order is
        # deterministic).
        self.info.features['label'].names = sorted(labels)

        # Update DatasetInfo splits: each split is recorded as a single shard
        # holding all of its examples.
        split_dict = split_lib.SplitDict(self.name)
        for split_name, examples in self._split_examples.items():
            split_dict.add(
                split_lib.SplitInfo(
                    name=split_name,
                    shard_lengths=[len(examples)],
                ))
        self.info.update_splits_if_different(split_dict)
Exemplo n.º 23
0
def _merge_shard_info(
    shard_infos: List[_ShardInfo],
    filename_template: naming.ShardedFileTemplate,
) -> split_lib.SplitInfo:
    """Merges the shard infos of a single split into one `SplitInfo`.

    Args:
      shard_infos: The shard infos from all shards of one split.
      filename_template: filename template of the splits; a copy with its
        `split` set to the merged split's name is attached to the result.

    Returns:
      A `SplitInfo` whose `shard_lengths` are ordered by shard index, whose
      `num_bytes` is the total size of all shards, and which carries the
      split-specific filename template.

    Raises:
      ValueError: if `shard_infos` is empty or covers more than one split.
    """
    split_names = {s.file_info.split for s in shard_infos}
    if len(split_names) != 1:
        # A bare tuple-unpack would also raise ValueError, but its message
        # ("too many values to unpack") hides which splits were mixed.
        raise ValueError(
            'Expected shard infos from exactly one split, got: '
            f'{split_names!r}')
    split_name, = split_names
    shard_infos = sorted(shard_infos, key=lambda s: s.file_info.shard_index)
    filename_template = filename_template.replace(split=split_name)
    return split_lib.SplitInfo(
        name=split_name,
        shard_lengths=[s.num_examples for s in shard_infos],
        num_bytes=sum(s.bytes_size for s in shard_infos),
        filename_template=filename_template,
    )
Exemplo n.º 24
0
class GetDatasetFilesTest(testing.TestCase):
  """Tests `tfrecords_reader._make_file_instructions_from_absolutes`."""

  # A single 13-example `train` split spread over 5 shards.
  SPLIT_INFOS = {
      'train':
          splits.SplitInfo(
              name='train',
              shard_lengths=[3, 2, 3, 2, 3],  # 13 examples.
              num_bytes=0,
          ),
  }

  PATH_PATTERN = 'mnist-train.tfrecord-0000%d-of-00005'

  def _get_files(self, instruction):
    return tfrecords_reader._make_file_instructions_from_absolutes(
        name='mnist',
        split_infos=self.SPLIT_INFOS,
        absolute_instructions=[instruction],
    )

  def _file(self, shard_index, skip, take, num_examples):
    """Expected `FileInstruction` for the shard numbered `shard_index`."""
    return shard_utils.FileInstruction(
        filename=self.PATH_PATTERN % shard_index,
        skip=skip,
        take=take,
        num_examples=num_examples)

  def _assert_files(self, from_, to, expected):
    """Asserts `train[from_:to]` resolves to exactly `expected`."""
    instruction = tfrecords_reader._AbsoluteInstruction('train', from_, to)
    self.assertEqual(self._get_files(instruction), expected)

  def test_no_skip_no_take(self):
    self._assert_files(None, None, [
        self._file(index, skip=0, take=-1, num_examples=length)
        for index, length in enumerate([3, 2, 3, 2, 3])
    ])

  def test_skip(self):
    # One file is not taken, one file is partially taken.
    self._assert_files(4, None, [
        self._file(1, skip=1, take=-1, num_examples=1),
        self._file(2, skip=0, take=-1, num_examples=3),
        self._file(3, skip=0, take=-1, num_examples=2),
        self._file(4, skip=0, take=-1, num_examples=3),
    ])

  def test_take(self):
    # Two files are not taken, one file is partially taken.
    self._assert_files(None, 6, [
        self._file(0, skip=0, take=-1, num_examples=3),
        self._file(1, skip=0, take=-1, num_examples=2),
        self._file(2, skip=0, take=1, num_examples=1),
    ])

  def test_skip_take1(self):
    # A single shard with both skip and take.
    self._assert_files(1, 2, [
        self._file(0, skip=1, take=1, num_examples=1),
    ])

  def test_skip_take2(self):
    # 2 elements in across two shards are taken in middle.
    self._assert_files(7, 9, [
        self._file(2, skip=2, take=-1, num_examples=1),
        self._file(3, skip=0, take=1, num_examples=1),
    ])

  def test_touching_boundaries(self):
    # Nothing to read.
    for from_, to in ((0, 0), (None, 0), (3, 3), (13, None)):
      self._assert_files(from_, to, [])

  def test_missing_shard_lengths(self):
    with self.assertRaisesWithPredicateMatch(ValueError, 'Shard empty.'):
      empty_split_infos = [
          splits.SplitInfo(name='train', shard_lengths=[], num_bytes=0),
      ]
      tfrecords_reader.make_file_instructions('mnist', empty_split_infos,
                                              'train')
Exemplo n.º 25
0
  def _build_from_generator(
      self,
      split_name: str,
      generator: Iterable[KeyExample],
      filename_template: naming.ShardedFileTemplate,
      disable_shuffling: bool,
  ) -> _SplitInfoFuture:
    """Encodes and writes one split's examples, returning its info future.

    Args:
      split_name: name of the split; also used as the writer's hash salt.
      generator: iterable of `(key, example)` pairs to encode and write.
      filename_template: Template to format the filename for a shard.
      disable_shuffling: forwarded to the writer; when True the examples are
        not shuffled.

    Returns:
      future: The future containing the `tfds.core.SplitInfo`.
    """
    cap = self._max_examples_per_split
    if cap is None:
      # Pre-downloaded dataset info (if any) supplies the expected number of
      # examples, which is only used to size the progress bar.
      known = self._split_dict.get(split_name)
      total_num_examples = (
          known.num_examples if known and known.num_examples else None)
    else:
      logging.warning('Splits capped at %s examples max.', cap)
      generator = itertools.islice(generator, cap)
      total_num_examples = cap

    serializer = example_serializer.ExampleSerializer(
        self._features.get_serialized_info())
    writer = writer_lib.Writer(
        serializer=serializer,
        filename_template=filename_template,
        hash_salt=split_name,
        disable_shuffling=disable_shuffling,
        # TODO(weide) remove this because it's already in filename_template?
        file_format=self._file_format,
    )
    progress_bar = utils.tqdm(
        generator,
        desc=f'Generating {split_name} examples...',
        unit=' examples',
        total=total_num_examples,
        leave=False,
    )
    for key, example in progress_bar:
      try:
        example = self._features.encode_example(example)
      except Exception as e:  # pylint: disable=broad-except
        utils.reraise(e, prefix=f'Failed to encode example:\n{example}\n')
      writer.write(key, example)
    shard_lengths, total_size = writer.finalize()

    result = splits_lib.SplitInfo(
        name=split_name,
        shard_lengths=shard_lengths,
        num_bytes=total_size,
        filename_template=filename_template,
    )
    return _SplitInfoFuture(lambda: result)
Exemplo n.º 26
0
class ReaderTest(testing.TestCase):
    """End-to-end tests of `tfrecords_reader.Reader` on tiny mnist records."""

    SPLIT_INFOS = [
        splits.SplitInfo(name='train',
                         shard_lengths=[2, 3, 2, 3, 2]),  # 12 ex.
        splits.SplitInfo(name='test', shard_lengths=[2, 3, 2]),  # 7 ex.
    ]

    def setUp(self):
        super(ReaderTest, self).setUp()
        # The reader is built with `testing.DummyParser` in place of the real
        # `ExampleParser` (presumably so raw records are returned unparsed).
        with absltest.mock.patch.object(example_parser, 'ExampleParser',
                                        testing.DummyParser):
            self.reader = tfrecords_reader.Reader(self.tmp_dir, 'some_spec')

    @staticmethod
    def _letters(text):
        """Each character of `text` as a `six.b` byte string."""
        return [six.b(letter) for letter in text]

    def _write_tfrecord(self, split_name, shards_number, records):
        """Writes each element of `records` as a record, in `shards_number` shards."""
        path = os.path.join(self.tmp_dir, 'mnist-%s.tfrecord' % split_name)
        writer = tfrecords_writer._TFRecordWriter(path, len(records),
                                                  shards_number)
        for record in records:
            writer.write(six.b(record))
        with absltest.mock.patch.object(tfrecords_writer,
                                        '_get_number_shards',
                                        return_value=shards_number):
            writer.finalize()

    def _write_tfrecords(self):
        for split_name, shards_number, records in (
                ('train', 5, 'abcdefghijkl'),
                ('test', 3, 'mnopqrs'),
        ):
            self._write_tfrecord(split_name, shards_number, records)

    def _read_list(self, instructions):
        """Reads `instructions` from mnist and returns the records as a list."""
        ds = self.reader.read('mnist', instructions, self.SPLIT_INFOS)
        return list(tfds.as_numpy(ds))

    def test_nodata_instruction(self):
        # Given instruction corresponds to no data.
        with self.assertRaisesWithPredicateMatch(AssertionError,
                                                 'corresponds to no data!'):
            self.reader.read('mnist', 'train[0:0]', self.SPLIT_INFOS)

    def test_noskip_notake(self):
        self._write_tfrecord('train', 5, 'abcdefghijkl')
        self.assertEqual(self._read_list('train'),
                         self._letters('abcdefghijkl'))

    def test_overlap(self):
        self._write_tfrecord('train', 5, 'abcdefghijkl')
        self.assertEqual(self._read_list('train+train[:2]'),
                         self._letters('abcdefghijklab'))

    def test_complex(self):
        self._write_tfrecord('train', 5, 'abcdefghijkl')
        self._write_tfrecord('test', 3, 'mnopqrs')
        self.assertEqual(self._read_list('train[1:-1]+test[:-50%]'),
                         self._letters('bcdefghijkmno'))

    def test_4fold(self):
        self._write_tfrecord('train', 5, 'abcdefghijkl')
        make = tfrecords_reader.ReadInstruction
        starts = range(0, 100, 25)
        held_out = [
            make('train', from_=k, to=k + 25, unit='%') for k in starts
        ]
        tests = self.reader.read('mnist', held_out, self.SPLIT_INFOS)
        remainders = [
            make('train', to=k, unit='%') +
            make('train', from_=k + 25, unit='%') for k in starts
        ]
        trains = self.reader.read('mnist', remainders, self.SPLIT_INFOS)
        read_tests = [list(r) for r in tfds.as_numpy(tests)]
        read_trains = [list(r) for r in tfds.as_numpy(trains)]

        folds = ['abc', 'def', 'ghi', 'jkl']
        self.assertEqual(read_tests, [self._letters(fold) for fold in folds])
        # Training set i is every fold except fold i, in order.
        self.assertEqual(read_trains, [
            self._letters(''.join(
                fold for j, fold in enumerate(folds) if j != i))
            for i in range(len(folds))
        ])
Exemplo n.º 27
0
class ReaderTest(testing.TestCase):
    """End-to-end tests of `tfrecords_reader.Reader` on tiny mnist records."""

    SPLIT_INFOS = [
        splits.SplitInfo(name='train', shard_lengths=[2, 3, 2, 3,
                                                      2]),  # 12 ex.
        splits.SplitInfo(name='test', shard_lengths=[2, 3, 2]),  # 7 ex.
    ]

    def setUp(self):
        super(ReaderTest, self).setUp()
        with absltest.mock.patch.object(example_parser, 'ExampleParser',
                                        testing.DummyParser):
            self.reader = tfrecords_reader.Reader(self.tmp_dir, 'some_spec')
            # Default every read to a default `ReadConfig` without file
            # shuffling; individual tests override these keywords as needed.
            self.reader.read = functools.partial(
                self.reader.read,
                read_config=read_config_lib.ReadConfig(),
                shuffle_files=False,
            )

    def _write_tfrecord(self, split_name, shards_number, records):
        """Writes each element of `records` as a record, in `shards_number` shards."""
        path = os.path.join(self.tmp_dir, 'mnist-%s.tfrecord' % split_name)
        num_examples = len(records)
        with absltest.mock.patch.object(tfrecords_writer,
                                        '_get_number_shards',
                                        return_value=shards_number):
            shard_specs = tfrecords_writer._get_shard_specs(
                num_examples, 0, [num_examples], path)
        serialized_records = [six.b(rec) for rec in records]
        for shard_spec in shard_specs:
            tfrecords_writer._write_tfrecord_from_shard_spec(
                shard_spec, lambda unused_i: iter(serialized_records))

    def _write_tfrecords(self):
        self._write_tfrecord('train', 5, 'abcdefghijkl')
        self._write_tfrecord('test', 3, 'mnopqrs')

    def test_nodata_instruction(self):
        # Given instruction corresponds to no data.
        with self.assertRaisesWithPredicateMatch(AssertionError,
                                                 'corresponds to no data!'):
            self.reader.read('mnist', 'train[0:0]', self.SPLIT_INFOS)

    def test_noskip_notake(self):
        self._write_tfrecord('train', 5, 'abcdefghijkl')
        ds = self.reader.read('mnist', 'train', self.SPLIT_INFOS)
        read_data = list(tfds.as_numpy(ds))
        self.assertEqual(read_data, [six.b(l) for l in 'abcdefghijkl'])

    def test_overlap(self):
        self._write_tfrecord('train', 5, 'abcdefghijkl')
        ds = self.reader.read('mnist', 'train+train[:2]', self.SPLIT_INFOS)
        read_data = list(tfds.as_numpy(ds))
        self.assertEqual(read_data, [six.b(l) for l in 'abcdefghijklab'])

    def test_complex(self):
        self._write_tfrecord('train', 5, 'abcdefghijkl')
        self._write_tfrecord('test', 3, 'mnopqrs')
        ds = self.reader.read('mnist', 'train[1:-1]+test[:-50%]',
                              self.SPLIT_INFOS)
        read_data = list(tfds.as_numpy(ds))
        self.assertEqual(read_data, [six.b(l) for l in 'bcdefghijkmno'])

    def test_shuffle_files(self):
        self._write_tfrecord('train', 5, 'abcdefghijkl')
        ds = self.reader.read('mnist',
                              'train',
                              self.SPLIT_INFOS,
                              shuffle_files=True)
        shards = [  # The shards of the dataset:
            [b'a', b'b'],
            [b'c', b'd', b'e'],
            [b'f', b'g'],
            [b'h', b'i', b'j'],
            [b'k', b'l'],
        ]
        # The various orders in which the dataset can be read (whole shards
        # permuted, then flattened). A set gives O(1) membership checks, and
        # chain.from_iterable avoids the quadratic `sum(lists, [])`.
        expected_permutations = {
            tuple(itertools.chain.from_iterable(order))
            for order in itertools.permutations(shards)
        }
        ds = ds.batch(12).repeat(100)
        read_data = {tuple(e) for e in tfds.as_numpy(ds)}
        for batch in read_data:
            self.assertIn(batch, expected_permutations)
        # There are theoretically 5! (=120) different arrangements, but we
        # would need too many repeats to be sure to get them all.
        self.assertGreater(len(read_data), 10)

    def test_shuffle_deterministic(self):
        self._write_tfrecord('train', 5, 'abcdefghijkl')
        read_config = read_config_lib.ReadConfig(shuffle_seed=123)
        ds = self.reader.read('mnist',
                              'train',
                              self.SPLIT_INFOS,
                              read_config=read_config,
                              shuffle_files=True)
        ds_values = list(tfds.as_numpy(ds))

        # Check that shuffle=True with a seed provides deterministic results.
        self.assertEqual(ds_values, [
            b'a', b'b', b'k', b'l', b'h', b'i', b'j', b'c', b'd', b'e', b'f',
            b'g'
        ])

    def test_4fold(self):
        self._write_tfrecord('train', 5, 'abcdefghijkl')
        instructions = [
            tfrecords_reader.ReadInstruction('train',
                                             from_=k,
                                             to=k + 25,
                                             unit='%')
            for k in range(0, 100, 25)
        ]
        tests = self.reader.read('mnist', instructions, self.SPLIT_INFOS)
        instructions = [
            (tfrecords_reader.ReadInstruction('train', to=k, unit='%') +
             tfrecords_reader.ReadInstruction('train', from_=k + 25, unit='%'))
            for k in range(0, 100, 25)
        ]
        trains = self.reader.read('mnist', instructions, self.SPLIT_INFOS)
        read_tests = [list(r) for r in tfds.as_numpy(tests)]
        read_trains = [list(r) for r in tfds.as_numpy(trains)]
        self.assertEqual(read_tests, [[b'a', b'b', b'c'], [b'd', b'e', b'f'],
                                      [b'g', b'h', b'i'], [b'j', b'k', b'l']])
        self.assertEqual(
            read_trains,
            [[b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l'],
             [b'a', b'b', b'c', b'g', b'h', b'i', b'j', b'k', b'l'],
             [b'a', b'b', b'c', b'd', b'e', b'f', b'j', b'k', b'l'],
             [b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i']])
Exemplo n.º 28
0
 def split_info_for(name: str, shard_lengths, template) -> splits.SplitInfo:
   """Returns a zero-byte `SplitInfo` named `name` that uses `template`."""
   fields = dict(
       name=name,
       shard_lengths=shard_lengths,
       num_bytes=0,
       filename_template=template,
   )
   return splits.SplitInfo(**fields)