def test_shard_api():
  """`[Nshard]` sub-split syntax resolves to the matching file instructions."""
  split_info = tfds.core.SplitInfo(
      name='train',
      shard_lengths=[10, 20, 13],
      num_bytes=0,
  )
  # One instruction per shard, each covering its file entirely (take=-1).
  instructions = [
      shard_utils.FileInstruction(
          filename=f'ds_name-train.tfrecord-{index:05d}-of-00003',
          skip=0,
          take=-1,
          num_examples=length,
      )
      for index, length in enumerate([10, 20, 13])
  ]
  split_dict = splits.SplitDict([split_info], dataset_name='ds_name')
  # Single-shard selection, positive and negative indices.
  assert split_dict['train[0shard]'].file_instructions == [instructions[0]]
  assert split_dict['train[1shard]'].file_instructions == [instructions[1]]
  assert split_dict['train[-1shard]'].file_instructions == [instructions[-1]]
  assert split_dict['train[-2shard]'].file_instructions == [instructions[-2]]
  # Shard ranges behave like Python slices.
  assert split_dict['train[:2shard]'].file_instructions == instructions[:2]
  assert split_dict['train[1shard:]'].file_instructions == instructions[1:]
  assert split_dict['train[-1shard:]'].file_instructions == instructions[-1:]
  assert split_dict['train[1:-1shard]'].file_instructions == instructions[1:-1]
def test_read_files(self):
  """read_files() honors skip/take and keeps the given file order."""
  self._write_tfrecord('train', 4, 'abcdefghijkl')
  filename_template = self._filename_template(split='train')

  def _shard_path(shard_index):
    # Absolute path of one of the 4 written train shards.
    return os.fspath(
        filename_template.sharded_filepath(
            shard_index=shard_index, num_shards=4))

  ds = self.reader.read_files(
      [
          # All of shard 1 ('def').
          shard_utils.FileInstruction(
              filename=_shard_path(1), skip=0, take=-1, num_examples=3),
          # One example out of shard 3 ('jkl' -> 'k').
          shard_utils.FileInstruction(
              filename=_shard_path(3), skip=1, take=1, num_examples=1),
      ],
      read_config=read_config_lib.ReadConfig(),
      shuffle_files=False,
  )
  read_data = list(tfds.as_numpy(ds))
  self.assertEqual(read_data, [six.b(char) for char in 'defk'])
def test_4buckets_2shards(self):
  """8 examples in buckets [2, 3, 0, 3] are packed into 2 shards of 4."""
  shard_specs = writer_lib._get_shard_specs(
      num_examples=8,
      total_size=16,
      bucket_lengths=[2, 3, 0, 3],
      filename_template=naming.ShardedFileTemplate(
          dataset_name='bar',
          split='train',
          data_dir='/',
          filetype_suffix='tfrecord'))
  expected = [
      # Shard 0: all of bucket 0 plus the first 2 examples of bucket 1.
      _ShardSpec(0, '/bar-train.tfrecord-00000-of-00002',
                 '/bar-train.tfrecord-00000-of-00002_index.json', 4, [
                     shard_utils.FileInstruction(
                         filename='0', skip=0, take=-1, num_examples=2),
                     shard_utils.FileInstruction(
                         filename='1', skip=0, take=2, num_examples=2),
                 ]),
      # Shard 1: the rest of bucket 1 and all of bucket 3 (bucket 2 is empty).
      _ShardSpec(1, '/bar-train.tfrecord-00001-of-00002',
                 '/bar-train.tfrecord-00001-of-00002_index.json', 4, [
                     shard_utils.FileInstruction(
                         filename='1', skip=2, take=-1, num_examples=1),
                     shard_utils.FileInstruction(
                         filename='3', skip=0, take=-1, num_examples=3),
                 ]),
  ]
  self.assertEqual(shard_specs, expected)
def test_4buckets_2shards(self):
  """8 examples in buckets [2, 3, 0, 3] are packed into 2 shards of 4."""
  shard_specs = tfrecords_writer._get_shard_specs(
      num_examples=8,
      total_size=16,
      bucket_lengths=[2, 3, 0, 3],
      path='/bar.tfrecord')
  expected = [
      # Shard 0: all of bucket 0 plus the first 2 examples of bucket 1.
      _ShardSpec(0, '/bar.tfrecord-00000-of-00002',
                 '/bar.tfrecord-00000-of-00002_index.json', 4, [
                     shard_utils.FileInstruction(
                         filename='0', skip=0, take=-1, num_examples=2),
                     shard_utils.FileInstruction(
                         filename='1', skip=0, take=2, num_examples=2),
                 ]),
      # Shard 1: the rest of bucket 1 and all of bucket 3 (bucket 2 is empty).
      _ShardSpec(1, '/bar.tfrecord-00001-of-00002',
                 '/bar.tfrecord-00001-of-00002_index.json', 4, [
                     shard_utils.FileInstruction(
                         filename='1', skip=2, take=-1, num_examples=1),
                     shard_utils.FileInstruction(
                         filename='3', skip=0, take=-1, num_examples=3),
                 ]),
  ]
  self.assertEqual(shard_specs, expected)
def test_skip_take2(self):
  """Taking 2 middle examples spans the tail and head of adjacent shards."""
  instruction = tfrecords_reader._AbsoluteInstruction('train', 7, 9)
  file_instructions = self._get_files(instruction)
  expected = [
      # Last example of shard 2...
      shard_utils.FileInstruction(
          filename=self.PATH_PATTERN % 2, skip=2, take=-1, num_examples=1),
      # ...followed by the first example of shard 3.
      shard_utils.FileInstruction(
          filename=self.PATH_PATTERN % 3, skip=0, take=1, num_examples=1),
  ]
  self.assertEqual(file_instructions, expected)
def test_from1_to10(self):
  """Range [1, 10) skips 1 in f1, spans f2, drops empty f3, truncates f4."""
  instructions = shard_utils.get_file_instructions(
      1, 10, ['f1', 'f2', 'f3', 'f4'], [4, 4, 0, 4])
  expected = [
      shard_utils.FileInstruction(
          filename='f1', skip=1, take=-1, num_examples=3),
      shard_utils.FileInstruction(
          filename='f2', skip=0, take=-1, num_examples=4),
      # f3 holds no examples so it never appears in the instructions.
      shard_utils.FileInstruction(
          filename='f4', skip=0, take=2, num_examples=2),
  ]
  self.assertEqual(instructions, expected)
def test_read_all_empty_shard(self):
  """Reading the full range silently drops the empty shard (f3)."""
  instructions = shard_utils.get_file_instructions(
      0, 12, ['f1', 'f2', 'f3', 'f4'], [4, 4, 0, 4])
  # Each non-empty file is read in full; f3 is omitted entirely.
  expected = [
      shard_utils.FileInstruction(
          filename=name, skip=0, take=-1, num_examples=4)
      for name in ('f1', 'f2', 'f4')
  ]
  self.assertEqual(instructions, expected)
def test_take(self):
  """take=6 reads the first two shards fully and one example of the third."""
  instruction = tfrecords_reader._AbsoluteInstruction('train', None, 6)
  file_instructions = self._get_files(instruction)
  expected = [
      shard_utils.FileInstruction(
          filename=self.PATH_PATTERN % 0, skip=0, take=-1, num_examples=3),
      shard_utils.FileInstruction(
          filename=self.PATH_PATTERN % 1, skip=0, take=-1, num_examples=2),
      # Third shard is only partially consumed; later shards are dropped.
      shard_utils.FileInstruction(
          filename=self.PATH_PATTERN % 2, skip=0, take=1, num_examples=1),
  ]
  self.assertEqual(file_instructions, expected)
def test_skip(self):
  """skip=4 drops shard 0 entirely and the first example of shard 1."""
  instruction = splits._AbsoluteInstruction('train', 4, None)
  file_instructions = self._get_files(instruction)
  expected = [
      # Shard 1 is partially skipped; the remaining shards are read in full.
      shard_utils.FileInstruction(
          filename=self.PATH_PATTERN % 1, skip=1, take=-1, num_examples=1),
      shard_utils.FileInstruction(
          filename=self.PATH_PATTERN % 2, skip=0, take=-1, num_examples=3),
      shard_utils.FileInstruction(
          filename=self.PATH_PATTERN % 3, skip=0, take=-1, num_examples=2),
      shard_utils.FileInstruction(
          filename=self.PATH_PATTERN % 4, skip=0, take=-1, num_examples=3),
  ]
  self.assertEqual(file_instructions, expected)
def test_read_files(self):
  """read_files() honors skip/take across explicitly listed shard files."""
  self._write_tfrecord('train', 4, 'abcdefghijkl')
  shard_name = 'mnist-train.tfrecord-0000%d-of-00004'
  ds = self.reader.read_files(
      [
          # All of shard 1 ('def').
          shard_utils.FileInstruction(
              filename=shard_name % 1, skip=0, take=-1, num_examples=3),
          # One example out of shard 3 ('jkl' -> 'k').
          shard_utils.FileInstruction(
              filename=shard_name % 3, skip=1, take=1, num_examples=1),
      ],
      read_config=read_config_lib.ReadConfig(),
      shuffle_files=False,
  )
  read_data = list(tfds.as_numpy(ds))
  self.assertEqual(read_data, [six.b(char) for char in 'defk'])
def test_cycle_length_must_be_one(self):
  """Ordered reads require interleave_cycle_length == 1."""
  self._write_tfrecord('train', 4, 'abcdefghijkl')
  filename_template = self._filename_template(split='train')
  instructions = [
      shard_utils.FileInstruction(
          filename=os.fspath(
              filename_template.sharded_filepath(shard_index=1, num_shards=4)),
          skip=0,
          take=-1,
          num_examples=3),
  ]
  # Default ReadConfig uses interleave_cycle_length=1 for ordered datasets,
  # so this call must succeed.
  self.reader.read_files(
      instructions,
      read_config=read_config_lib.ReadConfig(),
      shuffle_files=False,
      disable_shuffling=True,
  )
  # An explicit cycle length > 1 with shuffling disabled must be rejected.
  with self.assertRaisesWithPredicateMatch(ValueError,
                                           _CYCLE_LENGTH_ERROR_MESSAGE):
    self.reader.read_files(
        instructions,
        read_config=read_config_lib.ReadConfig(interleave_cycle_length=16),
        shuffle_files=False,
        disable_shuffling=True,
    )
def test_cycle_length_must_be_one(self):
  """Ordered reads require interleave_cycle_length == 1."""
  self._write_tfrecord('train', 4, 'abcdefghijkl')
  shard_name = 'mnist-train.tfrecord-0000%d-of-00004'
  instructions = [
      shard_utils.FileInstruction(
          filename=shard_name % 1, skip=0, take=-1, num_examples=3),
  ]
  # Default ReadConfig uses interleave_cycle_length=1 for ordered datasets,
  # so this call must succeed.
  self.reader.read_files(
      instructions,
      read_config=read_config_lib.ReadConfig(),
      shuffle_files=False,
      disable_shuffling=True,
  )
  # An explicit cycle length > 1 with shuffling disabled must be rejected.
  with self.assertRaisesWithPredicateMatch(ValueError,
                                           _CYCLE_LENGTH_ERROR_MESSAGE):
    self.reader.read_files(
        instructions,
        read_config=read_config_lib.ReadConfig(interleave_cycle_length=16),
        shuffle_files=False,
        disable_shuffling=True,
    )
def test_no_skip_no_take(self):
  """Without skip/take every shard is read in full, in order."""
  instruction = tfrecords_reader._AbsoluteInstruction('train', None, None)
  file_instructions = self._get_files(instruction)
  shard_lengths = [3, 2, 3, 2, 3]
  expected = [
      shard_utils.FileInstruction(
          filename=self.PATH_PATTERN % shard_index,
          skip=0,
          take=-1,
          num_examples=num_examples)
      for shard_index, num_examples in enumerate(shard_lengths)
  ]
  self.assertEqual(file_instructions, expected)
def test_split_file_instructions(self):
  """The 'train' split maps to a single whole-file instruction."""
  file_instructions = self._builder.info.splits["train"].file_instructions
  expected = shard_utils.FileInstruction(
      filename="dummy_dataset_shared_generator-train.tfrecord-00000-of-00001",
      skip=0,
      take=-1,
      num_examples=20,
  )
  self.assertEqual(file_instructions, [expected])
def test_skip_take1(self):
  """skip and take that both land inside the first shard yield one file."""
  instruction = tfrecords_reader._AbsoluteInstruction('train', 1, 2)
  file_instructions = self._get_files(instruction)
  expected = shard_utils.FileInstruction(
      filename=self.PATH_PATTERN % 0, skip=1, take=1, num_examples=1)
  self.assertEqual(file_instructions, [expected])
def test_sub_split_file_instructions(self):
  """'train[75%:]' skips the first 15 of 20 examples in the single shard."""
  file_instructions = (
      self._builder.info.splits['train[75%:]'].file_instructions)
  expected = shard_utils.FileInstruction(
      filename=f'{self._builder.data_dir}/dummy_dataset_shared_generator-train.tfrecord-00000-of-00001',
      skip=15,
      take=-1,
      num_examples=5,
  )
  self.assertEqual(file_instructions, [expected])
def test_1bucket_6shards(self):
  """One bucket of 8 examples splits into 6 shards of sizes 1,2,1,1,2,1."""
  shard_specs = tfrecords_writer._get_shard_specs(
      num_examples=8, total_size=16, bucket_lengths=[8], path='/bar.tfrecord')
  # Build the expected specs data-driven: (num_examples, take) per shard.
  # skip accumulates since every instruction reads from the same bucket '0';
  # the last shard uses take=-1 (read to the end).
  expected = []
  skip = 0
  for shard_index, (count, take) in enumerate(
      [(1, 1), (2, 2), (1, 1), (1, 1), (2, 2), (1, -1)]):
    expected.append(
        _ShardSpec(shard_index, '/bar.tfrecord-%05d-of-00006' % shard_index,
                   count, [
                       shard_utils.FileInstruction(
                           filename='0',
                           skip=skip,
                           take=take,
                           num_examples=count),
                   ]))
    skip += count
  self.assertEqual(shard_specs, expected)
def test_multi_split_infos():
  """MultiSplitInfo aggregates stats and instructions over its sub-splits."""
  split = 'train'
  split_infos = [
      tfds.core.SplitInfo(
          name=split,
          shard_lengths=[10, 10],
          num_bytes=100,
          filename_template=_filename_template(split=split, data_dir='/abc')),
      tfds.core.SplitInfo(
          name=split,
          shard_lengths=[1],
          num_bytes=20,
          filename_template=_filename_template(split=split, data_dir='/xyz')),
  ]
  multi = splits.MultiSplitInfo(name=split, split_infos=split_infos)
  # Aggregated statistics are sums over the wrapped SplitInfos.
  assert multi.num_bytes == 120
  assert multi.num_examples == 21
  assert multi.num_shards == 3
  assert multi.shard_lengths == [10, 10, 1]
  # File instructions are concatenated in sub-split order.
  assert multi.file_instructions == [
      shard_utils.FileInstruction(
          filename='/abc/ds_name-train.tfrecord-00000-of-00002',
          skip=0,
          take=-1,
          num_examples=10),
      shard_utils.FileInstruction(
          filename='/abc/ds_name-train.tfrecord-00001-of-00002',
          skip=0,
          take=-1,
          num_examples=10),
      shard_utils.FileInstruction(
          filename='/xyz/ds_name-train.tfrecord-00000-of-00001',
          skip=0,
          take=-1,
          num_examples=1)
  ]
  assert str(multi) == (
      'MultiSplitInfo(name=\'train\', split_infos=[<SplitInfo num_examples=20, '
      'num_shards=2>, <SplitInfo num_examples=1, num_shards=1>])')
def test_shard_api():
  """`[Nshard]` sub-split syntax resolves against the filename template."""
  split_info = tfds.core.SplitInfo(
      name='train',
      shard_lengths=[10, 20, 13],
      num_bytes=0,
      filename_template=naming.ShardedFileTemplate(
          dataset_name='ds_name',
          split='train',
          filetype_suffix='tfrecord',
          data_dir='/path'))
  # One instruction per shard, each covering its file entirely (take=-1).
  instructions = [
      shard_utils.FileInstruction(
          filename=f'/path/ds_name-train.tfrecord-{index:05d}-of-00003',
          skip=0,
          take=-1,
          num_examples=length,
      )
      for index, length in enumerate([10, 20, 13])
  ]
  split_dict = splits.SplitDict([split_info])
  # Single-shard selection, positive and negative indices.
  assert split_dict['train[0shard]'].file_instructions == [instructions[0]]
  assert split_dict['train[1shard]'].file_instructions == [instructions[1]]
  assert split_dict['train[-1shard]'].file_instructions == [instructions[-1]]
  assert split_dict['train[-2shard]'].file_instructions == [instructions[-2]]
  # Shard ranges behave like Python slices.
  assert split_dict['train[:2shard]'].file_instructions == instructions[:2]
  assert split_dict['train[1shard:]'].file_instructions == instructions[1:]
  assert split_dict['train[-1shard:]'].file_instructions == instructions[-1:]
  assert split_dict['train[1:-1shard]'].file_instructions == instructions[1:-1]
def test_shuffle_files_should_be_disabled(self):
  """shuffle_files=True is rejected when disable_shuffling=True."""
  self._write_tfrecord('train', 4, 'abcdefghijkl')
  shard_name = 'mnist-train.tfrecord-0000%d-of-00004'
  with self.assertRaisesWithPredicateMatch(ValueError,
                                           _SHUFFLE_FILES_ERROR_MESSAGE):
    self.reader.read_files(
        [
            shard_utils.FileInstruction(
                filename=shard_name % 1, skip=0, take=-1, num_examples=3),
        ],
        read_config=read_config_lib.ReadConfig(),
        shuffle_files=True,
        disable_shuffling=True,
    )
def test_shuffle_files_should_be_disabled(self):
  """shuffle_files=True is rejected when disable_shuffling=True."""
  self._write_tfrecord('train', 4, 'abcdefghijkl')
  filename_template = self._filename_template(split='train')
  shard_path = os.fspath(
      filename_template.sharded_filepath(shard_index=1, num_shards=4))
  with self.assertRaisesWithPredicateMatch(ValueError,
                                           _SHUFFLE_FILES_ERROR_MESSAGE):
    self.reader.read_files(
        [
            shard_utils.FileInstruction(
                filename=shard_path, skip=0, take=-1, num_examples=3),
        ],
        read_config=read_config_lib.ReadConfig(),
        shuffle_files=True,
        disable_shuffling=True,
    )
def test_ordering_guard(self):
  """With enable_ordering_guard=False, misuse only logs warnings."""
  self._write_tfrecord('train', 4, 'abcdefghijkl')
  shard_name = 'mnist-train.tfrecord-0000%d-of-00004'
  instructions = [
      shard_utils.FileInstruction(
          filename=shard_name % 1, skip=0, take=-1, num_examples=3),
  ]
  # Capture absl warnings instead of letting them reach the log.
  reported_warnings = []
  with mock.patch('absl.logging.warning', reported_warnings.append):
    # Both shuffle_files=True and cycle_length>1 conflict with
    # disable_shuffling, but the guard is off so no error is raised.
    self.reader.read_files(
        instructions,
        read_config=read_config_lib.ReadConfig(
            interleave_cycle_length=16, enable_ordering_guard=False),
        shuffle_files=True,
        disable_shuffling=True,
    )
  expected_warning = (
      _SHUFFLE_FILES_ERROR_MESSAGE + '\n' + _CYCLE_LENGTH_ERROR_MESSAGE)
  self.assertIn(expected_warning, reported_warnings)
def test_1bucket_6shards(self):
  """One bucket of 8 examples splits into 6 shards of sizes 1,2,1,1,2,1."""
  shard_specs = tfrecords_writer._get_shard_specs(
      num_examples=8,
      total_size=16,
      bucket_lengths=[8],
      filename_template=naming.ShardedFileTemplate(
          dataset_name='bar',
          split='train',
          data_dir='/',
          filetype_suffix='tfrecord'))
  # Build the expected specs data-driven: (num_examples, take) per shard.
  # skip accumulates since every instruction reads from the same bucket '0';
  # the last shard uses take=-1 (read to the end).
  expected = []
  skip = 0
  for shard_index, (count, take) in enumerate(
      [(1, 1), (2, 2), (1, 1), (1, 1), (2, 2), (1, -1)]):
    shard_path = '/bar-train.tfrecord-%05d-of-00006' % shard_index
    expected.append(
        _ShardSpec(shard_index, shard_path, shard_path + '_index.json', count,
                   [
                       shard_utils.FileInstruction(
                           filename='0',
                           skip=skip,
                           take=take,
                           num_examples=count),
                   ]))
    skip += count
  self.assertEqual(shard_specs, expected)