def test_4buckets_2shards(self):
  """Checks the two shard specs produced for buckets of length [2, 3, 0, 3]."""
  actual = tfrecords_writer._get_shard_specs(
      num_examples=8,
      total_size=16,
      bucket_lengths=[2, 3, 0, 3],
      path='/bar.tfrecord',
  )

  # _ShardSpec positional args: shard#, path, examples_number,
  # reading instructions.
  first_shard = _ShardSpec(
      0,
      '/bar.tfrecord-00000-of-00002',
      4,
      [
          {'bucket_index': 0, 'skip': 0, 'take': -1},
          {'bucket_index': 1, 'skip': 0, 'take': 2},
      ],
  )
  # The zero-length bucket (index 2) is expected to appear in no instruction.
  second_shard = _ShardSpec(
      1,
      '/bar.tfrecord-00001-of-00002',
      4,
      [
          {'bucket_index': 1, 'skip': 2, 'take': -1},
          {'bucket_index': 3, 'skip': 0, 'take': -1},
      ],
  )

  self.assertEqual(actual, [first_shard, second_shard])
def test_4buckets_2shards(self):
  """Checks the two shard specs produced for buckets of length [2, 3, 0, 3]."""
  actual = tfrecords_writer._get_shard_specs(
      num_examples=8,
      total_size=16,
      bucket_lengths=[2, 3, 0, 3],
      path='/bar.tfrecord',
  )

  # _ShardSpec positional args: shard#, path, index-file path,
  # examples_number, reading instructions.
  first_shard = _ShardSpec(
      0,
      '/bar.tfrecord-00000-of-00002',
      '/bar.tfrecord-00000-of-00002_index.json',
      4,
      [
          shard_utils.FileInstruction(
              filename='0', skip=0, take=-1, num_examples=2),
          shard_utils.FileInstruction(
              filename='1', skip=0, take=2, num_examples=2),
      ],
  )
  # No instruction is expected to reference filename '2' (the empty bucket).
  second_shard = _ShardSpec(
      1,
      '/bar.tfrecord-00001-of-00002',
      '/bar.tfrecord-00001-of-00002_index.json',
      4,
      [
          shard_utils.FileInstruction(
              filename='1', skip=2, take=-1, num_examples=1),
          shard_utils.FileInstruction(
              filename='3', skip=0, take=-1, num_examples=3),
      ],
  )

  self.assertEqual(actual, [first_shard, second_shard])
def test_4buckets_2shards(self):
  """Checks the two shard specs produced for buckets of length [2, 3, 0, 3].

  Shard paths are derived from a `ShardedFileTemplate` rather than a raw path.
  """
  template = naming.ShardedFileTemplate(
      dataset_name='bar',
      split='train',
      data_dir='/',
      filetype_suffix='tfrecord',
  )
  actual = tfrecords_writer._get_shard_specs(
      num_examples=8,
      total_size=16,
      bucket_lengths=[2, 3, 0, 3],
      filename_template=template,
  )

  # _ShardSpec positional args: shard#, path, index-file path,
  # examples_number, reading instructions.
  first_shard = _ShardSpec(
      0,
      '/bar-train.tfrecord-00000-of-00002',
      '/bar-train.tfrecord-00000-of-00002_index.json',
      4,
      [
          shard_utils.FileInstruction(
              filename='0', skip=0, take=-1, num_examples=2),
          shard_utils.FileInstruction(
              filename='1', skip=0, take=2, num_examples=2),
      ],
  )
  # No instruction is expected to reference filename '2' (the empty bucket).
  second_shard = _ShardSpec(
      1,
      '/bar-train.tfrecord-00001-of-00002',
      '/bar-train.tfrecord-00001-of-00002_index.json',
      4,
      [
          shard_utils.FileInstruction(
              filename='1', skip=2, take=-1, num_examples=1),
          shard_utils.FileInstruction(
              filename='3', skip=0, take=-1, num_examples=3),
      ],
  )

  self.assertEqual(actual, [first_shard, second_shard])
def test_1bucket_6shards(self):
  """Checks the six shard specs produced for a single bucket of 8 examples."""
  actual = tfrecords_writer._get_shard_specs(
      num_examples=8,
      total_size=16,
      bucket_lengths=[8],
      path='/bar.tfrecord',
  )

  # One row per expected shard, in shard order:
  # (path, examples_number, skip, take). Every shard reads from filename '0'.
  layout = [
      ('/bar.tfrecord-00000-of-00006', 1, 0, 1),
      ('/bar.tfrecord-00001-of-00006', 2, 1, 2),
      ('/bar.tfrecord-00002-of-00006', 1, 3, 1),
      ('/bar.tfrecord-00003-of-00006', 1, 4, 1),
      ('/bar.tfrecord-00004-of-00006', 2, 5, 2),
      ('/bar.tfrecord-00005-of-00006', 1, 7, -1),
  ]
  # _ShardSpec positional args: shard#, path, examples_number,
  # reading instructions.
  expected = [
      _ShardSpec(
          shard_index,
          shard_path,
          count,
          [
              shard_utils.FileInstruction(
                  filename='0', skip=skip, take=take, num_examples=count),
          ],
      )
      for shard_index, (shard_path, count, skip, take) in enumerate(layout)
  ]

  self.assertEqual(actual, expected)
def test_1bucket_6shards(self):
  """Checks the six shard specs produced for a single bucket of 8 examples.

  Shard paths are derived from a `ShardedFileTemplate` rather than a raw path.
  """
  template = naming.ShardedFileTemplate(
      dataset_name='bar',
      split='train',
      data_dir='/',
      filetype_suffix='tfrecord',
  )
  actual = tfrecords_writer._get_shard_specs(
      num_examples=8,
      total_size=16,
      bucket_lengths=[8],
      filename_template=template,
  )

  # One row per expected shard, in shard order:
  # (path, index-file path, examples_number, skip, take).
  # Every shard reads from filename '0'.
  layout = [
      ('/bar-train.tfrecord-00000-of-00006',
       '/bar-train.tfrecord-00000-of-00006_index.json', 1, 0, 1),
      ('/bar-train.tfrecord-00001-of-00006',
       '/bar-train.tfrecord-00001-of-00006_index.json', 2, 1, 2),
      ('/bar-train.tfrecord-00002-of-00006',
       '/bar-train.tfrecord-00002-of-00006_index.json', 1, 3, 1),
      ('/bar-train.tfrecord-00003-of-00006',
       '/bar-train.tfrecord-00003-of-00006_index.json', 1, 4, 1),
      ('/bar-train.tfrecord-00004-of-00006',
       '/bar-train.tfrecord-00004-of-00006_index.json', 2, 5, 2),
      ('/bar-train.tfrecord-00005-of-00006',
       '/bar-train.tfrecord-00005-of-00006_index.json', 1, 7, -1),
  ]
  # _ShardSpec positional args: shard#, path, index-file path,
  # examples_number, reading instructions.
  expected = [
      _ShardSpec(
          shard_index,
          shard_path,
          index_path,
          count,
          [
              shard_utils.FileInstruction(
                  filename='0', skip=skip, take=take, num_examples=count),
          ],
      )
      for shard_index, (shard_path, index_path, count, skip, take)
      in enumerate(layout)
  ]

  self.assertEqual(actual, expected)
def test_1bucket_6shards(self):
  """Checks the six shard specs produced for a single bucket of 8 examples."""
  actual = tfrecords_writer._get_shard_specs(
      num_examples=8,
      total_size=16,
      bucket_lengths=[8],
      path='/bar.tfrecord',
  )

  # One row per expected shard, in shard order:
  # (path, examples_number, skip, take). Every shard reads bucket_index 0.
  layout = [
      ('/bar.tfrecord-00000-of-00006', 1, 0, 1),
      ('/bar.tfrecord-00001-of-00006', 2, 1, 2),
      ('/bar.tfrecord-00002-of-00006', 1, 3, 1),
      ('/bar.tfrecord-00003-of-00006', 1, 4, 1),
      ('/bar.tfrecord-00004-of-00006', 2, 5, 2),
      ('/bar.tfrecord-00005-of-00006', 1, 7, -1),
  ]
  # _ShardSpec positional args: shard#, path, examples_number,
  # reading instructions.
  expected = [
      _ShardSpec(
          shard_index,
          shard_path,
          count,
          [{'bucket_index': 0, 'skip': skip, 'take': take}],
      )
      for shard_index, (shard_path, count, skip, take) in enumerate(layout)
  ]

  self.assertEqual(actual, expected)