def test_4buckets_2shards(self):
  """8 examples across 4 buckets (one empty) split into 2 shards of 4."""
  specs = tfrecords_writer._get_shard_specs(
      num_examples=8,
      total_size=16,
      bucket_lengths=[2, 3, 0, 3],
      path='/bar.tfrecord')
  # Each spec: shard index, shard path, examples_number, reading instructions
  # (plain dicts of bucket_index/skip/take; take=-1 means "to the end").
  expected = [
      _ShardSpec(0, '/bar.tfrecord-00000-of-00002', 4, [
          {'bucket_index': 0, 'skip': 0, 'take': -1},
          {'bucket_index': 1, 'skip': 0, 'take': 2},
      ]),
      _ShardSpec(1, '/bar.tfrecord-00001-of-00002', 4, [
          {'bucket_index': 1, 'skip': 2, 'take': -1},
          {'bucket_index': 3, 'skip': 0, 'take': -1},
      ]),
  ]
  self.assertEqual(specs, expected)
def test_1bucket_6shards(self):
  """8 examples drawn from one bucket and spread over 6 shards."""
  specs = tfrecords_writer._get_shard_specs(
      num_examples=8,
      total_size=16,
      bucket_lengths=[8],
      path='/bar.tfrecord')

  def instruction(skip, take, num_examples):
    # Every read targets the single bucket, named '0'; take=-1 reads to end.
    return shard_utils.FileInstruction(
        filename='0', skip=skip, take=take, num_examples=num_examples)

  # Shard#, path, examples_number, reading instructions.
  expected = [
      _ShardSpec(0, '/bar.tfrecord-00000-of-00006', 1, [instruction(0, 1, 1)]),
      _ShardSpec(1, '/bar.tfrecord-00001-of-00006', 2, [instruction(1, 2, 2)]),
      _ShardSpec(2, '/bar.tfrecord-00002-of-00006', 1, [instruction(3, 1, 1)]),
      _ShardSpec(3, '/bar.tfrecord-00003-of-00006', 1, [instruction(4, 1, 1)]),
      _ShardSpec(4, '/bar.tfrecord-00004-of-00006', 2, [instruction(5, 2, 2)]),
      _ShardSpec(5, '/bar.tfrecord-00005-of-00006', 1, [instruction(7, -1, 1)]),
  ]
  self.assertEqual(specs, expected)
def test_4buckets_2shards(self):
  """8 examples across 4 buckets (one empty) split into 2 indexed shards."""
  specs = tfrecords_writer._get_shard_specs(
      num_examples=8,
      total_size=16,
      bucket_lengths=[2, 3, 0, 3],
      path='/bar.tfrecord')

  def instruction(filename, skip, take, num_examples):
    # `filename` is the bucket index as a string; take=-1 reads to the end.
    return shard_utils.FileInstruction(
        filename=filename, skip=skip, take=take, num_examples=num_examples)

  # Shard#, shard path, index path, examples_number, reading instructions.
  expected = [
      _ShardSpec(0, '/bar.tfrecord-00000-of-00002',
                 '/bar.tfrecord-00000-of-00002_index.json', 4, [
                     instruction('0', 0, -1, 2),
                     instruction('1', 0, 2, 2),
                 ]),
      _ShardSpec(1, '/bar.tfrecord-00001-of-00002',
                 '/bar.tfrecord-00001-of-00002_index.json', 4, [
                     instruction('1', 2, -1, 1),
                     instruction('3', 0, -1, 3),
                 ]),
  ]
  self.assertEqual(specs, expected)
def test_4buckets_2shards(self):
  """Shard specs computed with a ShardedFileTemplate for file naming."""
  template = naming.ShardedFileTemplate(
      dataset_name='bar',
      split='train',
      data_dir='/',
      filetype_suffix='tfrecord')
  specs = tfrecords_writer._get_shard_specs(
      num_examples=8,
      total_size=16,
      bucket_lengths=[2, 3, 0, 3],
      filename_template=template)

  def instruction(filename, skip, take, num_examples):
    # `filename` is the bucket index as a string; take=-1 reads to the end.
    return shard_utils.FileInstruction(
        filename=filename, skip=skip, take=take, num_examples=num_examples)

  # Shard#, shard path, index path, examples_number, reading instructions.
  expected = [
      _ShardSpec(0, '/bar-train.tfrecord-00000-of-00002',
                 '/bar-train.tfrecord-00000-of-00002_index.json', 4, [
                     instruction('0', 0, -1, 2),
                     instruction('1', 0, 2, 2),
                 ]),
      _ShardSpec(1, '/bar-train.tfrecord-00001-of-00002',
                 '/bar-train.tfrecord-00001-of-00002_index.json', 4, [
                     instruction('1', 2, -1, 1),
                     instruction('3', 0, -1, 3),
                 ]),
  ]
  self.assertEqual(specs, expected)
def _write_tfrecord(self, split_name, shards_number, records):
  """Serializes `records` and writes them out as `shards_number` shards."""
  path = os.path.join(self.tmp_dir, 'mnist-%s.tfrecord' % split_name)
  num_examples = len(records)
  # Pin the shard count so the test fully controls the shard layout.
  with absltest.mock.patch.object(
      tfrecords_writer, '_get_number_shards', return_value=shards_number):
    shard_specs = tfrecords_writer._get_shard_specs(
        num_examples, 0, [num_examples], path)
  serialized_records = [six.b(record) for record in records]
  for spec in shard_specs:
    # Each shard re-reads the full record iterator per its own instructions.
    tfrecords_writer._write_tfrecord_from_shard_spec(
        spec, lambda unused_i: iter(serialized_records))
def test_1bucket_6shards(self):
  """8 examples from one bucket, template-named, split into 6 shards."""
  template = naming.ShardedFileTemplate(
      dataset_name='bar',
      split='train',
      data_dir='/',
      filetype_suffix='tfrecord')
  specs = tfrecords_writer._get_shard_specs(
      num_examples=8,
      total_size=16,
      bucket_lengths=[8],
      filename_template=template)

  def instruction(skip, take, num_examples):
    # Every read targets the single bucket, named '0'; take=-1 reads to end.
    return shard_utils.FileInstruction(
        filename='0', skip=skip, take=take, num_examples=num_examples)

  # Shard#, shard path, index path, examples_number, reading instructions.
  expected = [
      _ShardSpec(0, '/bar-train.tfrecord-00000-of-00006',
                 '/bar-train.tfrecord-00000-of-00006_index.json', 1,
                 [instruction(0, 1, 1)]),
      _ShardSpec(1, '/bar-train.tfrecord-00001-of-00006',
                 '/bar-train.tfrecord-00001-of-00006_index.json', 2,
                 [instruction(1, 2, 2)]),
      _ShardSpec(2, '/bar-train.tfrecord-00002-of-00006',
                 '/bar-train.tfrecord-00002-of-00006_index.json', 1,
                 [instruction(3, 1, 1)]),
      _ShardSpec(3, '/bar-train.tfrecord-00003-of-00006',
                 '/bar-train.tfrecord-00003-of-00006_index.json', 1,
                 [instruction(4, 1, 1)]),
      _ShardSpec(4, '/bar-train.tfrecord-00004-of-00006',
                 '/bar-train.tfrecord-00004-of-00006_index.json', 2,
                 [instruction(5, 2, 2)]),
      _ShardSpec(5, '/bar-train.tfrecord-00005-of-00006',
                 '/bar-train.tfrecord-00005-of-00006_index.json', 1,
                 [instruction(7, -1, 1)]),
  ]
  self.assertEqual(specs, expected)
def _write_tfrecord(self, split_name, shards_number, records):
  """Writes `records` into `shards_number` shards; returns the SplitInfo."""
  path = os.path.join(self.tmp_dir, 'mnist-%s.tfrecord' % split_name)
  num_examples = len(records)
  # Pin the shard count so the test fully controls the shard layout.
  with mock.patch.object(
      tfrecords_writer, '_get_number_shards', return_value=shards_number):
    shard_specs = tfrecords_writer._get_shard_specs(
        num_examples, 0, [num_examples], path)
  serialized_records = [(index, six.b(record))
                        for index, record in enumerate(records)]
  for spec in shard_specs:
    # Each shard re-reads the full keyed iterator per its own instructions.
    _write_tfrecord_from_shard_spec(
        spec, lambda unused_i: iter(serialized_records))
  return splits.SplitInfo(
      name=split_name,
      shard_lengths=[int(spec.examples_number) for spec in shard_specs],
      num_bytes=0,
  )
def test_1bucket_6shards(self):
  """8 examples drawn from a single bucket and spread over 6 shards."""
  specs = tfrecords_writer._get_shard_specs(
      num_examples=8,
      total_size=16,
      bucket_lengths=[8],
      path='/bar.tfrecord')

  def instruction(skip, take):
    # All instructions read from the only bucket, index 0; take=-1 = to end.
    return {'bucket_index': 0, 'skip': skip, 'take': take}

  # Shard#, path, examples_number, reading instructions.
  expected = [
      _ShardSpec(0, '/bar.tfrecord-00000-of-00006', 1, [instruction(0, 1)]),
      _ShardSpec(1, '/bar.tfrecord-00001-of-00006', 2, [instruction(1, 2)]),
      _ShardSpec(2, '/bar.tfrecord-00002-of-00006', 1, [instruction(3, 1)]),
      _ShardSpec(3, '/bar.tfrecord-00003-of-00006', 1, [instruction(4, 1)]),
      _ShardSpec(4, '/bar.tfrecord-00004-of-00006', 2, [instruction(5, 2)]),
      _ShardSpec(5, '/bar.tfrecord-00005-of-00006', 1, [instruction(7, -1)]),
  ]
  self.assertEqual(specs, expected)
def _write_tfrecord(self, split_name, shards_number, records):
  """Writes `records` as `shards_number` shards; returns the SplitInfo."""
  filename_template = self._filename_template(split=split_name)
  num_examples = len(records)
  # Pin the shard count so the test fully controls the shard layout.
  with mock.patch.object(
      tfrecords_writer, '_get_number_shards', return_value=shards_number):
    shard_specs = tfrecords_writer._get_shard_specs(
        num_examples=num_examples,
        total_size=0,
        bucket_lengths=[num_examples],
        filename_template=filename_template)
  serialized_records = [(index, six.b(record))
                        for index, record in enumerate(records)]
  for spec in shard_specs:
    # Each shard re-reads the full keyed iterator per its own instructions.
    _write_tfrecord_from_shard_spec(
        spec, lambda unused_i: iter(serialized_records))
  return splits.SplitInfo(
      name=split_name,
      shard_lengths=[int(spec.examples_number) for spec in shard_specs],
      num_bytes=0,
      filename_template=filename_template)