def setUp(self):
  super(ReadInstructionTest, self).setUp()
  self.splits = {
      'train':
          splits.SplitInfo(
              name='train',
              shard_lengths=[200],
              num_bytes=0,
              filename_template=_filename_template('train')),
      'test':
          splits.SplitInfo(
              name='test',
              shard_lengths=[101],
              num_bytes=0,
              filename_template=_filename_template('test')),
      'validation':
          splits.SplitInfo(
              name='validation',
              shard_lengths=[30],
              num_bytes=0,
              filename_template=_filename_template('validation')),
      'dev-train':
          splits.SplitInfo(
              name='dev-train',
              shard_lengths=[5, 5],
              num_bytes=0,
              filename_template=_filename_template('dev-train')),
  }
def test_even_splits_subsplit():
  split_infos = splits_lib.SplitDict(split_infos=[
      splits_lib.SplitInfo(
          name='train',
          shard_lengths=[2, 3, 2, 3],  # 10
          num_bytes=0,
          filename_template=_filename_template('train'),
      ),
      splits_lib.SplitInfo(
          name='test',
          shard_lengths=[8],
          num_bytes=0,
          filename_template=_filename_template('test'),
      ),
  ])
  # Test splitting a compound split across multiple workers.
  subsplits = subsplits_utils.even_splits('train+test[50%:]', 3)
  expected = [
      'train[:4]+test[4:6]',
      'train[4:7]+test[6:7]',
      'train[7:]+test[7:8]',
  ]
  file_instructions = [split_infos[s].file_instructions for s in subsplits]
  expected_file_instructions = [
      split_infos[s].file_instructions for s in expected
  ]
  assert file_instructions == expected_file_instructions
def test_set_splits_normal(self):
  info = dataset_info.DatasetInfo(builder=self._builder)
  split_info1 = splits_lib.SplitInfo(
      name="train", shard_lengths=[1, 2], num_bytes=0)
  split_info2 = splits_lib.SplitInfo(
      name="test", shard_lengths=[1], num_bytes=0)
  split_dict = splits_lib.SplitDict(split_infos=[split_info1, split_info2])
  info.set_splits(split_dict)
  self.assertEqual(str(info.splits), str(split_dict))
  self.assertEqual(
      str(info.as_proto.splits),
      str([split_info1.to_proto(), split_info2.to_proto()]))
def __init__(
    self,
    root_dir: str,
    *,
    shape: Optional[type_utils.Shape] = None,
    dtype: Optional[tf.DType] = None,
):
  """Construct the `DatasetBuilder`.

  Args:
    root_dir: Path to the directory containing the images.
    shape: Image shape forwarded to `tfds.features.Image`.
    dtype: Image dtype forwarded to `tfds.features.Image`.
  """
  self._image_shape = shape
  self._image_dtype = dtype
  super(ImageFolder, self).__init__()
  self._data_dir = root_dir  # Set data_dir to the existing dir.

  # Extract the splits, examples, labels
  root_dir = os.path.expanduser(root_dir)
  self._split_examples, labels = _get_split_label_images(root_dir)

  # Update DatasetInfo labels
  self.info.features['label'].names = sorted(labels)

  # Update DatasetInfo splits
  split_dict = split_lib.SplitDict(self.name)
  for split_name, examples in self._split_examples.items():
    split_dict.add(split_lib.SplitInfo(
        name=split_name,
        shard_lengths=[len(examples)],
    ))
  self.info.update_splits_if_different(split_dict)
def _initialize_split(split_name: str,
                      data_directory: Any,
                      ds_name: str,
                      filetype_suffix: str,
                      shard_lengths: Optional[List[int]] = None,
                      num_bytes: int = 0) -> Split:
  """Initializes a split.

  Args:
    split_name: name of the split.
    data_directory: directory where the split data will be located.
    ds_name: name of the dataset.
    filetype_suffix: file format.
    shard_lengths: if the split already has shards, the list of shard
      lengths. If None, the split is assumed to be empty.
    num_bytes: number of bytes that have been written already.

  Returns:
    A Split.
  """
  if not shard_lengths:
    shard_lengths = []
  filename_template = naming.ShardedFileTemplate(
      dataset_name=ds_name,
      data_dir=data_directory,
      split=split_name,
      filetype_suffix=filetype_suffix,
      template='{DATASET}-{SPLIT}.{FILEFORMAT}-{SHARD_INDEX}',
  )
  return Split(
      info=splits_lib.SplitInfo(
          name=split_name,
          shard_lengths=shard_lengths,
          num_bytes=num_bytes,
          filename_template=filename_template),
      complete_shards=len(shard_lengths),
      ds_name=ds_name)
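# A minimal, hypothetical illustration of how the
# '{DATASET}-{SPLIT}.{FILEFORMAT}-{SHARD_INDEX}' template above could expand
# into shard filenames. `expand_template` is NOT the `ShardedFileTemplate`
# API, and the zero-padded index format is an assumption for illustration.
def expand_template(template: str, dataset: str, split: str,
                    fileformat: str, num_shards: int) -> List[str]:
  return [
      template.replace('{DATASET}', dataset)
      .replace('{SPLIT}', split)
      .replace('{FILEFORMAT}', fileformat)
      .replace('{SHARD_INDEX}', '%05d' % i)
      for i in range(num_shards)
  ]

assert expand_template(
    '{DATASET}-{SPLIT}.{FILEFORMAT}-{SHARD_INDEX}',
    dataset='mnist', split='train', fileformat='tfrecord', num_shards=2,
) == ['mnist-train.tfrecord-00000', 'mnist-train.tfrecord-00001']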
def test_missing_shard_lengths(self):
  with self.assertRaisesWithPredicateMatch(ValueError, 'Shard empty.'):
    split_info = [
        splits.SplitInfo(name='train', shard_lengths=[]),
    ]
    tfrecords_reader.make_file_instructions('mnist', split_info, 'train')
def test_nodata_instruction(self):
  # Given instruction corresponds to no data.
  with self.assertRaisesWithPredicateMatch(AssertionError,
                                           'corresponds to no data!'):
    train_info = splits.SplitInfo(
        name='train', shard_lengths=[2, 3, 2, 3, 2])
    self.reader.read('mnist', 'train[0:0]', [train_info])
def _build_from_generator(
    self,
    split_name: str,
    generator: Iterable[KeyExample],
    path: type_utils.PathLike,
) -> _SplitInfoFuture:
  """Split generator for example generators.

  Args:
    split_name: Name of the split to generate.
    generator: Iterable yielding `(key, example)` pairs.
    path: Path where the shards should be written.

  Returns:
    future: The future containing the `tfds.core.SplitInfo`.
  """
  if self._max_examples_per_split is not None:
    logging.warning('Splits capped at %s examples max.',
                    self._max_examples_per_split)
    generator = itertools.islice(generator, self._max_examples_per_split)
    total_num_examples = self._max_examples_per_split
  else:
    # If dataset info has been pre-downloaded from the internet,
    # we can use the pre-computed number of examples for the progress bar.
    split_info = self._split_dict.get(split_name)
    if split_info and split_info.num_examples:
      total_num_examples = split_info.num_examples
    else:
      total_num_examples = None

  writer = tfrecords_writer.Writer(
      example_specs=self._features.get_serialized_info(),
      path=path,
      hash_salt=split_name,
      file_format=self._file_format,
  )
  for key, example in utils.tqdm(
      generator,
      desc=f'Generating {split_name} examples...',
      unit=' examples',
      total=total_num_examples,
      leave=False,
  ):
    try:
      example = self._features.encode_example(example)
    except Exception as e:  # pylint: disable=broad-except
      utils.reraise(e, prefix=f'Failed to encode example:\n{example}\n')
    writer.write(key, example)
  shard_lengths, total_size = writer.finalize()

  split_info = splits_lib.SplitInfo(
      name=split_name,
      shard_lengths=shard_lengths,
      num_bytes=total_size,
  )
  return _SplitInfoFuture(lambda: split_info)
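# `_SplitInfoFuture` itself is not shown in these snippets. Judging from the
# `_SplitInfoFuture(lambda: split_info)` construction above, a thin lazy
# wrapper would suffice; this reconstruction is an assumption, not the actual
# TFDS class.
class _SplitInfoFutureSketch:
  """Defers computing the `SplitInfo` until `result()` is called."""

  def __init__(self, callback):
    self._callback = callback

  def result(self):
    return self._callback()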
def test_even_splits_add():
  # Compatibility of even_splits with other splits
  split_infos = splits_lib.SplitDict(split_infos=[
      splits_lib.SplitInfo(
          name='train',
          shard_lengths=[2, 3, 2, 3],  # 10
          num_bytes=0,
          filename_template=_filename_template('train'),
      ),
      splits_lib.SplitInfo(
          name='test',
          shard_lengths=[8],
          num_bytes=0,
          filename_template=_filename_template('test'),
      ),
      splits_lib.SplitInfo(
          name='validation',
          shard_lengths=[8],
          num_bytes=0,
          filename_template=_filename_template('validation'),
      ),
  ])
  # Test combining an even_splits subsplit with another split.
  split = subsplits_utils.even_splits('train', 3, drop_remainder=True)[0]
  split = split + 'test'
  expected = 'train[:3]+test'
  file_instructions = split_infos[split].file_instructions
  expected_file_instructions = split_infos[expected].file_instructions
  assert file_instructions == expected_file_instructions

  # Test nested `even_splits`
  splits = subsplits_utils.even_splits('validation', n=2)
  splits = subsplits_utils.even_splits(splits[1], n=2)
  assert (split_infos[splits[0]].file_instructions ==
          split_infos['validation[4:6]'].file_instructions)
  assert (split_infos[splits[1]].file_instructions ==
          split_infos['validation[6:8]'].file_instructions)
def _get_files(self, instruction):
  split_infos = [
      splits.SplitInfo(
          name='train',
          shard_lengths=[3, 2, 3, 2, 3],  # 13 examples.
          num_bytes=0,
          filename_template=_filename_template(
              dataset_name='mnist', split='train'),
      )
  ]
  splits_dict = splits.SplitDict(split_infos=split_infos)
  return splits_dict[instruction].file_instructions
def _resolve_future():
  if self._in_contextmanager:
    raise AssertionError('`future.result()` should be called after the '
                         '`maybe_beam_pipeline` contextmanager.')
  logging.info('Retrieving split info for %s...', split_name)
  shard_lengths, total_size = beam_writer.finalize()
  return splits_lib.SplitInfo(
      name=split_name,
      shard_lengths=shard_lengths,
      num_bytes=total_size,
      filename_template=filename_template,
  )
def test_missing_shard_lengths(self):
  with self.assertRaisesWithPredicateMatch(ValueError, 'Shard empty.'):
    filename_template = _filename_template(
        split='train', dataset_name='mnist')
    split_infos = [
        splits.SplitInfo(
            name='train',
            shard_lengths=[],
            num_bytes=0,
            filename_template=filename_template),
    ]
    splits_dict = splits.SplitDict(split_infos=split_infos)
    _ = splits_dict['train'].file_instructions
def close_shard(self) -> None:
  """Finalizes a shard and updates the split metadata accordingly."""
  if not self.current_shard:
    return
  self.current_shard.close_writer()
  self.info = splits_lib.SplitInfo(
      name=self.info.name,
      shard_lengths=self.info.shard_lengths +
      [self.current_shard.num_examples],
      num_bytes=self.info.num_bytes + self.current_shard.num_bytes,
      filename_template=self.info.filename_template)
  self.complete_shards += 1
  self.current_shard = None
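# `close_shard` rebuilds `self.info` instead of appending in place, which is
# consistent with `SplitInfo` being an immutable value object (an assumption
# based on this pattern). A self-contained sketch of the same
# accumulate-by-replace idiom; `SimpleSplitInfo` is illustrative, not the
# TFDS class.
import dataclasses
from typing import List


@dataclasses.dataclass(frozen=True)
class SimpleSplitInfo:
  name: str
  shard_lengths: List[int]
  num_bytes: int


def with_closed_shard(info: SimpleSplitInfo, shard_examples: int,
                      shard_bytes: int) -> SimpleSplitInfo:
  # Frozen dataclasses cannot be mutated, so return a new instance.
  return dataclasses.replace(
      info,
      shard_lengths=info.shard_lengths + [shard_examples],
      num_bytes=info.num_bytes + shard_bytes)


info = SimpleSplitInfo(name='train', shard_lengths=[10], num_bytes=100)
info = with_closed_shard(info, shard_examples=5, shard_bytes=50)
assert info.shard_lengths == [10, 5] and info.num_bytes == 150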
def test_nodata_instruction(self):
  # Given instruction corresponds to no data.
  with self.assertRaisesWithPredicateMatch(ValueError,
                                           'corresponds to no data!'):
    train_info = splits.SplitInfo(
        name='train',
        shard_lengths=[2, 3, 2, 3, 2],
        num_bytes=0,
        filename_template=self._filename_template(split='train'),
    )
    self.reader.read(
        instructions='train[0:0]',
        split_infos=[train_info],
    )
def test_set_splits_incorrect_dataset_name(self):
  info = dataset_info.DatasetInfo(builder=self._builder)
  split_info1 = splits_lib.SplitInfo(
      name="train",
      shard_lengths=[1, 2],
      num_bytes=0,
      filename_template=naming.ShardedFileTemplate(
          dataset_name="some_other_dataset",
          split="train",
          data_dir=info.data_dir,
          filetype_suffix="tfrecord"))
  split_dict = splits_lib.SplitDict(split_infos=[split_info1])
  with pytest.raises(
      AssertionError, match="SplitDict contains SplitInfo for split"):
    info.set_splits(split_dict)
def _write_tfrecord(self, split_name, shards_number, records):
  path = os.path.join(self.tmp_dir, 'mnist-%s.tfrecord' % split_name)
  num_examples = len(records)
  with absltest.mock.patch.object(tfrecords_writer, '_get_number_shards',
                                  return_value=shards_number):
    shard_specs = tfrecords_writer._get_shard_specs(
        num_examples, 0, [num_examples], path)
  serialized_records = [six.b(rec) for rec in records]
  for shard_spec in shard_specs:
    _write_tfrecord_from_shard_spec(
        shard_spec, lambda unused_i: iter(serialized_records))
  return splits.SplitInfo(
      name=split_name,
      shard_lengths=[int(s.examples_number) for s in shard_specs],
  )
def _merge_shard_info(shard_infos: List[_ShardInfo]) -> split_lib.SplitInfo:
  """Merges all shard infos from one split into a single `SplitInfo`.

  Args:
    shard_infos: The shard infos from all shards.

  Returns:
    The merged `SplitInfo`.
  """
  split_name, = {s.file_info.split for s in shard_infos}
  shard_infos = sorted(shard_infos, key=lambda s: s.file_info.shard_index)
  return split_lib.SplitInfo(
      name=split_name,
      shard_lengths=[s.num_examples for s in shard_infos],
      num_bytes=sum(s.bytes_size for s in shard_infos),
  )
def __init__(self, root_dir: str):
  # Extract the splits, examples
  root_dir = os.path.expanduser(root_dir)
  self._split_examples, self._languages = _get_split_language_examples(
      root_dir)

  super(TranslateFolder, self).__init__()
  # Reset `_data_dir` as it should not change to DATA_DIR/Version
  self._data_dir = root_dir

  # Update DatasetInfo splits
  split_dict = split_lib.SplitDict(self.name)
  for split_name, examples in self._split_examples.items():
    split_dict.add(
        split_lib.SplitInfo(
            name=split_name,
            shard_lengths=[len(next(iter(examples.values())))],
        ))
  self.info.update_splits_if_different(split_dict)
def _write_tfrecord(self, split_name, shards_number, records):
  filename_template = self._filename_template(split=split_name)
  num_examples = len(records)
  with mock.patch.object(
      tfrecords_writer, '_get_number_shards', return_value=shards_number):
    shard_specs = tfrecords_writer._get_shard_specs(
        num_examples=num_examples,
        total_size=0,
        bucket_lengths=[num_examples],
        filename_template=filename_template)
  serialized_records = [(key, six.b(rec)) for key, rec in enumerate(records)]
  for shard_spec in shard_specs:
    _write_tfrecord_from_shard_spec(
        shard_spec, lambda unused_i: iter(serialized_records))
  return splits.SplitInfo(
      name=split_name,
      shard_lengths=[int(s.examples_number) for s in shard_specs],
      num_bytes=0,
      filename_template=filename_template)
def test_even_splits(num_examples, n, drop_remainder, expected):
  split_infos = splits_lib.SplitDict(split_infos=[
      splits_lib.SplitInfo(
          name='train',
          shard_lengths=[num_examples],
          num_bytes=0,
          filename_template=_filename_template('train'),
      ),
  ])
  subsplits = subsplits_utils.even_splits(
      'train', n, drop_remainder=drop_remainder)
  file_instructions = [split_infos[s].file_instructions for s in subsplits]
  expected_file_instructions = [
      split_infos[f'train{s}'].file_instructions for s in expected
  ]
  assert file_instructions == expected_file_instructions
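# `even_splits` is exposed publicly as `tfds.even_splits`. A typical use is
# sharding one split across workers; the dataset name and worker index below
# are illustrative.
import tensorflow_datasets as tfds

# Partition 'train' into 3 near-equal subsplits, one per worker.
subsplits = tfds.even_splits('train', n=3)

worker_index = 0  # Normally derived from the environment.
ds = tfds.load('mnist', split=subsplits[worker_index])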
def __init__(self, root_dir: str):
  # Extract the splits, examples
  root_dir = os.path.expanduser(root_dir)
  self._split_examples, self._languages = _get_split_language_examples(
      root_dir)

  super(TranslateFolder, self).__init__()
  # Reset `_data_dir` as it should not change to DATA_DIR/Version
  self._data_dir = root_dir

  # Update DatasetInfo splits
  split_infos = [
      split_lib.SplitInfo(  # pylint: disable=g-complex-comprehension
          name=split_name,
          shard_lengths=[len(next(iter(examples.values())))],
          num_bytes=0,
      ) for split_name, examples in self._split_examples.items()
  ]
  split_dict = split_lib.SplitDict(split_infos, dataset_name=self.name)
  self.info.set_splits(split_dict)
def __init__(self, root_dir: str):
  super(ImageFolder, self).__init__()
  self._data_dir = root_dir  # Set data_dir to the existing dir.

  # Extract the splits, examples, labels
  root_dir = os.path.expanduser(root_dir)
  self._split_examples, labels = _get_split_label_images(root_dir)

  # Update DatasetInfo labels
  self.info.features['label'].names = sorted(labels)

  # Update DatasetInfo splits
  split_dict = split_lib.SplitDict(self.name)
  for split_name, examples in self._split_examples.items():
    split_dict.add(
        split_lib.SplitInfo(
            name=split_name,
            shard_lengths=[len(examples)],
        ))
  self.info.update_splits_if_different(split_dict)
def _merge_shard_info(
    shard_infos: List[_ShardInfo],
    filename_template: naming.ShardedFileTemplate,
) -> split_lib.SplitInfo:
  """Merges all shard infos from one split into a single `SplitInfo`.

  Args:
    shard_infos: The shard infos from all shards.
    filename_template: The filename template of the split.

  Returns:
    The merged `SplitInfo`.
  """
  split_name, = {s.file_info.split for s in shard_infos}
  shard_infos = sorted(shard_infos, key=lambda s: s.file_info.shard_index)
  filename_template = filename_template.replace(split=split_name)
  return split_lib.SplitInfo(
      name=split_name,
      shard_lengths=[s.num_examples for s in shard_infos],
      num_bytes=sum(s.bytes_size for s in shard_infos),
      filename_template=filename_template,
  )
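# The merge itself is simple to reproduce standalone: sort by shard index,
# collect per-shard example counts, and sum byte sizes. `ShardInfo` below is
# an illustrative stand-in for the TFDS `_ShardInfo`, not the real type.
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class ShardInfo:
  shard_index: int
  num_examples: int
  bytes_size: int


def merge_shards(shard_infos: List[ShardInfo]) -> Tuple[List[int], int]:
  """Returns (shard_lengths, num_bytes) for the merged split."""
  shard_infos = sorted(shard_infos, key=lambda s: s.shard_index)
  shard_lengths = [s.num_examples for s in shard_infos]
  num_bytes = sum(s.bytes_size for s in shard_infos)
  return shard_lengths, num_bytes


# Shards may arrive out of order (e.g. from parallel workers).
shards = [ShardInfo(1, 3, 30), ShardInfo(0, 2, 20), ShardInfo(2, 2, 25)]
assert merge_shards(shards) == ([2, 3, 2], 75)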
class GetDatasetFilesTest(testing.TestCase):

  SPLIT_INFOS = {
      'train':
          splits.SplitInfo(
              name='train',
              shard_lengths=[3, 2, 3, 2, 3],  # 13 examples.
              num_bytes=0,
          ),
  }

  PATH_PATTERN = 'mnist-train.tfrecord-0000%d-of-00005'

  def _get_files(self, instruction):
    file_instructions = tfrecords_reader._make_file_instructions_from_absolutes(
        name='mnist',
        split_infos=self.SPLIT_INFOS,
        absolute_instructions=[instruction],
    )
    return file_instructions

  def test_no_skip_no_take(self):
    instruction = tfrecords_reader._AbsoluteInstruction('train', None, None)
    files = self._get_files(instruction)
    self.assertEqual(files, [
        shard_utils.FileInstruction(
            filename=self.PATH_PATTERN % i, skip=0, take=-1, num_examples=n)
        for i, n in enumerate([3, 2, 3, 2, 3])
    ])

  def test_skip(self):
    # One file is not taken, one file is partially taken.
    instruction = tfrecords_reader._AbsoluteInstruction('train', 4, None)
    files = self._get_files(instruction)
    self.assertEqual(files, [
        shard_utils.FileInstruction(
            filename=self.PATH_PATTERN % 1, skip=1, take=-1, num_examples=1),
        shard_utils.FileInstruction(
            filename=self.PATH_PATTERN % 2, skip=0, take=-1, num_examples=3),
        shard_utils.FileInstruction(
            filename=self.PATH_PATTERN % 3, skip=0, take=-1, num_examples=2),
        shard_utils.FileInstruction(
            filename=self.PATH_PATTERN % 4, skip=0, take=-1, num_examples=3),
    ])

  def test_take(self):
    # Two files are not taken, one file is partially taken.
    instruction = tfrecords_reader._AbsoluteInstruction('train', None, 6)
    files = self._get_files(instruction)
    self.assertEqual(files, [
        shard_utils.FileInstruction(
            filename=self.PATH_PATTERN % 0, skip=0, take=-1, num_examples=3),
        shard_utils.FileInstruction(
            filename=self.PATH_PATTERN % 1, skip=0, take=-1, num_examples=2),
        shard_utils.FileInstruction(
            filename=self.PATH_PATTERN % 2, skip=0, take=1, num_examples=1),
    ])

  def test_skip_take1(self):
    # A single shard with both skip and take.
    instruction = tfrecords_reader._AbsoluteInstruction('train', 1, 2)
    files = self._get_files(instruction)
    self.assertEqual(files, [
        shard_utils.FileInstruction(
            filename=self.PATH_PATTERN % 0, skip=1, take=1, num_examples=1),
    ])

  def test_skip_take2(self):
    # Two elements across two shards are taken in the middle.
    instruction = tfrecords_reader._AbsoluteInstruction('train', 7, 9)
    files = self._get_files(instruction)
    self.assertEqual(files, [
        shard_utils.FileInstruction(
            filename=self.PATH_PATTERN % 2, skip=2, take=-1, num_examples=1),
        shard_utils.FileInstruction(
            filename=self.PATH_PATTERN % 3, skip=0, take=1, num_examples=1),
    ])

  def test_touching_boundaries(self):
    # Nothing to read.
    instruction = tfrecords_reader._AbsoluteInstruction('train', 0, 0)
    files = self._get_files(instruction)
    self.assertEqual(files, [])

    instruction = tfrecords_reader._AbsoluteInstruction('train', None, 0)
    files = self._get_files(instruction)
    self.assertEqual(files, [])

    instruction = tfrecords_reader._AbsoluteInstruction('train', 3, 3)
    files = self._get_files(instruction)
    self.assertEqual(files, [])

    instruction = tfrecords_reader._AbsoluteInstruction('train', 13, None)
    files = self._get_files(instruction)
    self.assertEqual(files, [])

  def test_missing_shard_lengths(self):
    with self.assertRaisesWithPredicateMatch(ValueError, 'Shard empty.'):
      split_info = [
          splits.SplitInfo(name='train', shard_lengths=[], num_bytes=0),
      ]
      tfrecords_reader.make_file_instructions('mnist', split_info, 'train')
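# The skip/take bookkeeping exercised above can be reproduced standalone:
# walk the shards with a running offset and clip the [start, end) window to
# each shard. The function name and tuple layout below are illustrative, not
# the TFDS internals.
from typing import List, Optional, Tuple


def file_instructions_sketch(
    shard_lengths: List[int],
    start: Optional[int] = None,
    end: Optional[int] = None,
) -> List[Tuple[int, int, int, int]]:
  """Returns (shard_index, skip, take, num_examples) per shard to read."""
  total = sum(shard_lengths)
  start = 0 if start is None else start
  end = total if end is None else end
  instructions = []
  offset = 0
  for i, length in enumerate(shard_lengths):
    lo = max(start, offset)          # First example to read in this shard.
    hi = min(end, offset + length)   # One past the last example to read.
    if lo < hi:
      take = -1 if hi == offset + length else hi - lo  # -1: read to the end.
      instructions.append((i, lo - offset, take, hi - lo))
    offset += length
  return instructions


# Mirrors test_skip above: train[4:] over shards [3, 2, 3, 2, 3].
assert file_instructions_sketch([3, 2, 3, 2, 3], start=4) == [
    (1, 1, -1, 1), (2, 0, -1, 3), (3, 0, -1, 2), (4, 0, -1, 3)]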
def _build_from_generator(
    self,
    split_name: str,
    generator: Iterable[KeyExample],
    filename_template: naming.ShardedFileTemplate,
    disable_shuffling: bool,
) -> _SplitInfoFuture:
  """Split generator for example generators.

  Args:
    split_name: Name of the split to generate.
    generator: Iterable yielding `(key, example)` pairs.
    filename_template: Template to format the filename for a shard.
    disable_shuffling: Specifies whether to shuffle the examples.

  Returns:
    future: The future containing the `tfds.core.SplitInfo`.
  """
  if self._max_examples_per_split is not None:
    logging.warning('Splits capped at %s examples max.',
                    self._max_examples_per_split)
    generator = itertools.islice(generator, self._max_examples_per_split)
    total_num_examples = self._max_examples_per_split
  else:
    # If dataset info has been pre-downloaded from the internet,
    # we can use the pre-computed number of examples for the progress bar.
    split_info = self._split_dict.get(split_name)
    if split_info and split_info.num_examples:
      total_num_examples = split_info.num_examples
    else:
      total_num_examples = None

  writer = writer_lib.Writer(
      serializer=example_serializer.ExampleSerializer(
          self._features.get_serialized_info()),
      filename_template=filename_template,
      hash_salt=split_name,
      disable_shuffling=disable_shuffling,
      # TODO(weide) remove this because it's already in filename_template?
      file_format=self._file_format,
  )
  for key, example in utils.tqdm(
      generator,
      desc=f'Generating {split_name} examples...',
      unit=' examples',
      total=total_num_examples,
      leave=False,
  ):
    try:
      example = self._features.encode_example(example)
    except Exception as e:  # pylint: disable=broad-except
      utils.reraise(e, prefix=f'Failed to encode example:\n{example}\n')
    writer.write(key, example)
  shard_lengths, total_size = writer.finalize()

  split_info = splits_lib.SplitInfo(
      name=split_name,
      shard_lengths=shard_lengths,
      num_bytes=total_size,
      filename_template=filename_template,
  )
  return _SplitInfoFuture(lambda: split_info)
class ReaderTest(testing.TestCase):

  SPLIT_INFOS = [
      splits.SplitInfo(name='train', shard_lengths=[2, 3, 2, 3, 2]),  # 12 ex.
      splits.SplitInfo(name='test', shard_lengths=[2, 3, 2]),  # 7 ex.
  ]

  def setUp(self):
    super(ReaderTest, self).setUp()
    with absltest.mock.patch.object(example_parser, 'ExampleParser',
                                    testing.DummyParser):
      self.reader = tfrecords_reader.Reader(self.tmp_dir, 'some_spec')

  def _write_tfrecord(self, split_name, shards_number, records):
    path = os.path.join(self.tmp_dir, 'mnist-%s.tfrecord' % split_name)
    writer = tfrecords_writer._TFRecordWriter(path, len(records),
                                              shards_number)
    for rec in records:
      writer.write(six.b(rec))
    with absltest.mock.patch.object(tfrecords_writer, '_get_number_shards',
                                    return_value=shards_number):
      writer.finalize()

  def _write_tfrecords(self):
    self._write_tfrecord('train', 5, 'abcdefghijkl')
    self._write_tfrecord('test', 3, 'mnopqrs')

  def test_nodata_instruction(self):
    # Given instruction corresponds to no data.
    with self.assertRaisesWithPredicateMatch(AssertionError,
                                             'corresponds to no data!'):
      self.reader.read('mnist', 'train[0:0]', self.SPLIT_INFOS)

  def test_noskip_notake(self):
    self._write_tfrecord('train', 5, 'abcdefghijkl')
    ds = self.reader.read('mnist', 'train', self.SPLIT_INFOS)
    read_data = list(tfds.as_numpy(ds))
    self.assertEqual(read_data, [six.b(l) for l in 'abcdefghijkl'])

  def test_overlap(self):
    self._write_tfrecord('train', 5, 'abcdefghijkl')
    ds = self.reader.read('mnist', 'train+train[:2]', self.SPLIT_INFOS)
    read_data = list(tfds.as_numpy(ds))
    self.assertEqual(read_data, [six.b(l) for l in 'abcdefghijklab'])

  def test_complex(self):
    self._write_tfrecord('train', 5, 'abcdefghijkl')
    self._write_tfrecord('test', 3, 'mnopqrs')
    ds = self.reader.read('mnist', 'train[1:-1]+test[:-50%]', self.SPLIT_INFOS)
    read_data = list(tfds.as_numpy(ds))
    self.assertEqual(read_data, [six.b(l) for l in 'bcdefghijkmno'])

  def test_4fold(self):
    self._write_tfrecord('train', 5, 'abcdefghijkl')
    instructions = [
        tfrecords_reader.ReadInstruction('train', from_=k, to=k + 25, unit='%')
        for k in range(0, 100, 25)
    ]
    tests = self.reader.read('mnist', instructions, self.SPLIT_INFOS)
    instructions = [
        (tfrecords_reader.ReadInstruction('train', to=k, unit='%') +
         tfrecords_reader.ReadInstruction('train', from_=k + 25, unit='%'))
        for k in range(0, 100, 25)
    ]
    trains = self.reader.read('mnist', instructions, self.SPLIT_INFOS)
    read_tests = [list(r) for r in tfds.as_numpy(tests)]
    read_trains = [list(r) for r in tfds.as_numpy(trains)]
    self.assertEqual(read_tests, [[b'a', b'b', b'c'], [b'd', b'e', b'f'],
                                  [b'g', b'h', b'i'], [b'j', b'k', b'l']])
    self.assertEqual(
        read_trains,
        [[b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l'],
         [b'a', b'b', b'c', b'g', b'h', b'i', b'j', b'k', b'l'],
         [b'a', b'b', b'c', b'd', b'e', b'f', b'j', b'k', b'l'],
         [b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i']])
class ReaderTest(testing.TestCase):

  SPLIT_INFOS = [
      splits.SplitInfo(name='train', shard_lengths=[2, 3, 2, 3, 2]),  # 12 ex.
      splits.SplitInfo(name='test', shard_lengths=[2, 3, 2]),  # 7 ex.
  ]

  def setUp(self):
    super(ReaderTest, self).setUp()
    with absltest.mock.patch.object(example_parser, 'ExampleParser',
                                    testing.DummyParser):
      self.reader = tfrecords_reader.Reader(self.tmp_dir, 'some_spec')
    self.reader.read = functools.partial(
        self.reader.read,
        read_config=read_config_lib.ReadConfig(),
        shuffle_files=False,
    )

  def _write_tfrecord(self, split_name, shards_number, records):
    path = os.path.join(self.tmp_dir, 'mnist-%s.tfrecord' % split_name)
    num_examples = len(records)
    with absltest.mock.patch.object(tfrecords_writer, '_get_number_shards',
                                    return_value=shards_number):
      shard_specs = tfrecords_writer._get_shard_specs(
          num_examples, 0, [num_examples], path)
    serialized_records = [six.b(rec) for rec in records]
    for shard_spec in shard_specs:
      tfrecords_writer._write_tfrecord_from_shard_spec(
          shard_spec, lambda unused_i: iter(serialized_records))

  def _write_tfrecords(self):
    self._write_tfrecord('train', 5, 'abcdefghijkl')
    self._write_tfrecord('test', 3, 'mnopqrs')

  def test_nodata_instruction(self):
    # Given instruction corresponds to no data.
    with self.assertRaisesWithPredicateMatch(AssertionError,
                                             'corresponds to no data!'):
      self.reader.read('mnist', 'train[0:0]', self.SPLIT_INFOS)

  def test_noskip_notake(self):
    self._write_tfrecord('train', 5, 'abcdefghijkl')
    ds = self.reader.read('mnist', 'train', self.SPLIT_INFOS)
    read_data = list(tfds.as_numpy(ds))
    self.assertEqual(read_data, [six.b(l) for l in 'abcdefghijkl'])

  def test_overlap(self):
    self._write_tfrecord('train', 5, 'abcdefghijkl')
    ds = self.reader.read('mnist', 'train+train[:2]', self.SPLIT_INFOS)
    read_data = list(tfds.as_numpy(ds))
    self.assertEqual(read_data, [six.b(l) for l in 'abcdefghijklab'])

  def test_complex(self):
    self._write_tfrecord('train', 5, 'abcdefghijkl')
    self._write_tfrecord('test', 3, 'mnopqrs')
    ds = self.reader.read('mnist', 'train[1:-1]+test[:-50%]', self.SPLIT_INFOS)
    read_data = list(tfds.as_numpy(ds))
    self.assertEqual(read_data, [six.b(l) for l in 'bcdefghijkmno'])

  def test_shuffle_files(self):
    self._write_tfrecord('train', 5, 'abcdefghijkl')
    ds = self.reader.read('mnist', 'train', self.SPLIT_INFOS,
                          shuffle_files=True)
    shards = [  # The shards of the dataset:
        [b'a', b'b'],
        [b'c', b'd', b'e'],
        [b'f', b'g'],
        [b'h', b'i', b'j'],
        [b'k', b'l'],
    ]
    # The various orders in which the dataset can be read:
    expected_permutations = [
        tuple(sum(shard, [])) for shard in itertools.permutations(shards)
    ]
    ds = ds.batch(12).repeat(100)
    read_data = set(tuple(e) for e in tfds.as_numpy(ds))
    for batch in read_data:
      self.assertIn(batch, expected_permutations)
    # There are theoretically 5! (=120) different arrangements, but we would
    # need too many repeats to be sure to get them all.
    self.assertGreater(len(set(read_data)), 10)

  def test_shuffle_deterministic(self):
    self._write_tfrecord('train', 5, 'abcdefghijkl')
    read_config = read_config_lib.ReadConfig(shuffle_seed=123)
    ds = self.reader.read('mnist', 'train', self.SPLIT_INFOS,
                          read_config=read_config, shuffle_files=True)
    ds_values = list(tfds.as_numpy(ds))
    # Check that shuffle=True with a seed provides deterministic results.
    self.assertEqual(ds_values, [
        b'a', b'b', b'k', b'l', b'h', b'i', b'j', b'c', b'd', b'e', b'f', b'g'
    ])

  def test_4fold(self):
    self._write_tfrecord('train', 5, 'abcdefghijkl')
    instructions = [
        tfrecords_reader.ReadInstruction('train', from_=k, to=k + 25, unit='%')
        for k in range(0, 100, 25)
    ]
    tests = self.reader.read('mnist', instructions, self.SPLIT_INFOS)
    instructions = [
        (tfrecords_reader.ReadInstruction('train', to=k, unit='%') +
         tfrecords_reader.ReadInstruction('train', from_=k + 25, unit='%'))
        for k in range(0, 100, 25)
    ]
    trains = self.reader.read('mnist', instructions, self.SPLIT_INFOS)
    read_tests = [list(r) for r in tfds.as_numpy(tests)]
    read_trains = [list(r) for r in tfds.as_numpy(trains)]
    self.assertEqual(read_tests, [[b'a', b'b', b'c'], [b'd', b'e', b'f'],
                                  [b'g', b'h', b'i'], [b'j', b'k', b'l']])
    self.assertEqual(
        read_trains,
        [[b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l'],
         [b'a', b'b', b'c', b'g', b'h', b'i', b'j', b'k', b'l'],
         [b'a', b'b', b'c', b'd', b'e', b'f', b'j', b'k', b'l'],
         [b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i']])
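# test_complex above relies on TFDS percent slicing: 'test[:-50%]' over the
# 7-example test split keeps 3 examples ('mno'), i.e. 50% of 7 rounds to 4
# examples dropped from the end. The sketch below assumes round-to-closest
# semantics as exhibited by that test, not the exact TFDS implementation.
def pct_boundary(pct, num_examples):
  # 'closest' rounding: a percent boundary maps to the nearest example index.
  return round(pct * num_examples / 100)


num_test = 7  # len('mnopqrs')
# 'test[:-50%]' -> 'test[:7 - pct_boundary(50, 7)]' -> 'test[:3]' -> 'mno'
assert num_test - pct_boundary(50, num_test) == 3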
def split_info_for(name: str, shard_lengths, template) -> splits.SplitInfo:
  return splits.SplitInfo(
      name=name,
      shard_lengths=shard_lengths,
      num_bytes=0,
      filename_template=template)