def test_4buckets_2shards(self):
     """Check shard specs for 8 examples held in buckets of 2, 3, 0 and 3.

     The expectation is two shards of four examples each; the empty bucket
     (index 2) appears in no reading instruction.
     """

     def read(bucket_index, skip, take):
          # One reading instruction. take=-1 presumably means "read to the end
          # of the bucket" -- confirm against _get_shard_specs.
          return {'bucket_index': bucket_index, 'skip': skip, 'take': take}

     actual = tfrecords_writer._get_shard_specs(
          num_examples=8,
          total_size=16,
          bucket_lengths=[2, 3, 0, 3],
          path='/bar.tfrecord')

     # _ShardSpec positional arguments as used here:
     # shard number, path, examples_number, reading instructions.
     first_shard = _ShardSpec(
          0, '/bar.tfrecord-00000-of-00002', 4,
          [read(0, 0, -1), read(1, 0, 2)])
     second_shard = _ShardSpec(
          1, '/bar.tfrecord-00001-of-00002', 4,
          [read(1, 2, -1), read(3, 0, -1)])
     self.assertEqual(actual, [first_shard, second_shard])
 def test_1bucket_6shards(self):
   """Check shard specs for a single bucket of 8 examples cut into 6 shards.

   Every expected instruction reads from file '0'; the shards hold
   1, 2, 1, 1, 2 and 1 examples respectively.
   """
   actual = tfrecords_writer._get_shard_specs(
       num_examples=8,
       total_size=16,
       bucket_lengths=[8],
       path='/bar.tfrecord')

   # One row per expected shard:
   # (shard number, shard path, examples in shard, skip, take).
   # The examples count is passed both to _ShardSpec and as num_examples of
   # the shard's single FileInstruction.
   rows = [
       (0, '/bar.tfrecord-00000-of-00006', 1, 0, 1),
       (1, '/bar.tfrecord-00001-of-00006', 2, 1, 2),
       (2, '/bar.tfrecord-00002-of-00006', 1, 3, 1),
       (3, '/bar.tfrecord-00003-of-00006', 1, 4, 1),
       (4, '/bar.tfrecord-00004-of-00006', 2, 5, 2),
       (5, '/bar.tfrecord-00005-of-00006', 1, 7, -1),
   ]
   expected = [
       _ShardSpec(number, shard_path, count, [
           shard_utils.FileInstruction(
               filename='0', skip=skip, take=take, num_examples=count),
       ])
       for number, shard_path, count, skip, take in rows
   ]
   self.assertEqual(actual, expected)
 def test_4buckets_2shards(self):
     """Check shard specs for 8 examples held in buckets of 2, 3, 0 and 3.

     Two shards of four examples are expected, each with a shard path and a
     matching '_index.json' path; the empty bucket (file '2') is never read.
     """

     def read(filename, skip, take, num_examples):
         # Thin keyword wrapper so the expectations below stay on one line.
         return shard_utils.FileInstruction(
             filename=filename, skip=skip, take=take,
             num_examples=num_examples)

     actual = tfrecords_writer._get_shard_specs(
         num_examples=8,
         total_size=16,
         bucket_lengths=[2, 3, 0, 3],
         path='/bar.tfrecord')

     # _ShardSpec positional arguments as used here:
     # shard number, path, index path, examples_number, reading instructions.
     first_shard = _ShardSpec(
         0,
         '/bar.tfrecord-00000-of-00002',
         '/bar.tfrecord-00000-of-00002_index.json',
         4,
         [read('0', 0, -1, 2), read('1', 0, 2, 2)])
     second_shard = _ShardSpec(
         1,
         '/bar.tfrecord-00001-of-00002',
         '/bar.tfrecord-00001-of-00002_index.json',
         4,
         [read('1', 2, -1, 1), read('3', 0, -1, 3)])
     self.assertEqual(actual, [first_shard, second_shard])
Beispiel #4
0
 def test_4buckets_2shards(self):
     """Check shard specs built from a ShardedFileTemplate with four buckets.

     8 examples in buckets of 2, 3, 0 and 3 are expected to yield two shards
     of four examples, with paths of the form
     '/bar-train.tfrecord-0000N-of-00002' plus an '_index.json' companion.
     """

     def read(filename, skip, take, num_examples):
         # Thin keyword wrapper so the expectations below stay on one line.
         return shard_utils.FileInstruction(
             filename=filename, skip=skip, take=take,
             num_examples=num_examples)

     template = naming.ShardedFileTemplate(
         dataset_name='bar',
         split='train',
         data_dir='/',
         filetype_suffix='tfrecord')
     actual = tfrecords_writer._get_shard_specs(
         num_examples=8,
         total_size=16,
         bucket_lengths=[2, 3, 0, 3],
         filename_template=template)

     # _ShardSpec positional arguments as used here:
     # shard number, path, index path, examples_number, reading instructions.
     first_shard = _ShardSpec(
         0,
         '/bar-train.tfrecord-00000-of-00002',
         '/bar-train.tfrecord-00000-of-00002_index.json',
         4,
         [read('0', 0, -1, 2), read('1', 0, 2, 2)])
     second_shard = _ShardSpec(
         1,
         '/bar-train.tfrecord-00001-of-00002',
         '/bar-train.tfrecord-00001-of-00002_index.json',
         4,
         [read('1', 2, -1, 1), read('3', 0, -1, 3)])
     self.assertEqual(actual, [first_shard, second_shard])
Beispiel #5
0
 def _write_tfrecord(self, split_name, shards_number, records):
     """Write `records` through the tfrecords_writer shard machinery.

     Args:
       split_name: split name, interpolated into 'mnist-<split>.tfrecord'
         under `self.tmp_dir`.
       shards_number: value that the patched `_get_number_shards` returns
         while the shard specs are computed.
       records: sequence of strings; each is converted with `six.b`.
     """
     path = os.path.join(self.tmp_dir, 'mnist-%s.tfrecord' % split_name)
     num_examples = len(records)

     # Force the shard count instead of letting the writer derive it.
     forced_shard_count = absltest.mock.patch.object(
         tfrecords_writer, '_get_number_shards', return_value=shards_number)
     with forced_shard_count:
         shard_specs = tfrecords_writer._get_shard_specs(
             num_examples, 0, [num_examples], path)

     encoded_records = [six.b(text) for text in records]

     def read_bucket(unused_i):
         # All records sit in one bucket, so the bucket index is ignored.
         return iter(encoded_records)

     for spec in shard_specs:
         tfrecords_writer._write_tfrecord_from_shard_spec(spec, read_bucket)
Beispiel #6
0
 def test_1bucket_6shards(self):
     """Check shard specs for one bucket of 8 examples cut into 6 shards.

     Shard paths come from a ShardedFileTemplate; every expected instruction
     reads from file '0' and the shards hold 1, 2, 1, 1, 2 and 1 examples.
     """
     template = naming.ShardedFileTemplate(
         dataset_name='bar',
         split='train',
         data_dir='/',
         filetype_suffix='tfrecord')
     actual = tfrecords_writer._get_shard_specs(
         num_examples=8,
         total_size=16,
         bucket_lengths=[8],
         filename_template=template)

     # One row per expected shard:
     # (shard number, shard path, index path, examples in shard, skip, take).
     # The examples count is passed both to _ShardSpec and as num_examples of
     # the shard's single FileInstruction.
     rows = [
         (0, '/bar-train.tfrecord-00000-of-00006',
          '/bar-train.tfrecord-00000-of-00006_index.json', 1, 0, 1),
         (1, '/bar-train.tfrecord-00001-of-00006',
          '/bar-train.tfrecord-00001-of-00006_index.json', 2, 1, 2),
         (2, '/bar-train.tfrecord-00002-of-00006',
          '/bar-train.tfrecord-00002-of-00006_index.json', 1, 3, 1),
         (3, '/bar-train.tfrecord-00003-of-00006',
          '/bar-train.tfrecord-00003-of-00006_index.json', 1, 4, 1),
         (4, '/bar-train.tfrecord-00004-of-00006',
          '/bar-train.tfrecord-00004-of-00006_index.json', 2, 5, 2),
         (5, '/bar-train.tfrecord-00005-of-00006',
          '/bar-train.tfrecord-00005-of-00006_index.json', 1, 7, -1),
     ]
     expected = [
         _ShardSpec(number, shard_path, index_path, count, [
             shard_utils.FileInstruction(
                 filename='0', skip=skip, take=take, num_examples=count),
         ])
         for number, shard_path, index_path, count, skip, take in rows
     ]
     self.assertEqual(actual, expected)
 def _write_tfrecord(self, split_name, shards_number, records):
   """Write `records` via shard specs and describe the resulting split.

   Args:
     split_name: split name, interpolated into 'mnist-<split>.tfrecord'
       under `self.tmp_dir` and used as the SplitInfo name.
     shards_number: value that the patched `_get_number_shards` returns
       while the shard specs are computed.
     records: sequence of strings; each is converted with `six.b` and keyed
       by its position in the sequence.

   Returns:
     A `splits.SplitInfo` carrying the per-shard example counts and
     `num_bytes=0`.
   """
   path = os.path.join(self.tmp_dir, 'mnist-%s.tfrecord' % split_name)
   num_examples = len(records)

   # Force the shard count instead of letting the writer derive it.
   forced_shard_count = mock.patch.object(
       tfrecords_writer, '_get_number_shards', return_value=shards_number)
   with forced_shard_count:
     shard_specs = tfrecords_writer._get_shard_specs(
         num_examples, 0, [num_examples], path)

   keyed_records = [
       (position, six.b(text)) for position, text in enumerate(records)
   ]

   def read_bucket(unused_i):
     # All records sit in one bucket, so the bucket index is ignored.
     return iter(keyed_records)

   for spec in shard_specs:
     _write_tfrecord_from_shard_spec(spec, read_bucket)

   shard_lengths = [int(spec.examples_number) for spec in shard_specs]
   return splits.SplitInfo(
       name=split_name,
       shard_lengths=shard_lengths,
       num_bytes=0,
   )
 def test_1bucket_6shards(self):
     """Check shard specs for one bucket of 8 examples cut into 6 shards.

     Every expected reading instruction targets bucket 0; the shards hold
     1, 2, 1, 1, 2 and 1 examples respectively.
     """
     actual = tfrecords_writer._get_shard_specs(
         num_examples=8,
         total_size=16,
         bucket_lengths=[8],
         path='/bar.tfrecord')

     # One row per expected shard:
     # (shard number, shard path, examples in shard, skip, take).
     # take=-1 on the last shard presumably means "read to the end of the
     # bucket" -- confirm against _get_shard_specs.
     rows = [
         (0, '/bar.tfrecord-00000-of-00006', 1, 0, 1),
         (1, '/bar.tfrecord-00001-of-00006', 2, 1, 2),
         (2, '/bar.tfrecord-00002-of-00006', 1, 3, 1),
         (3, '/bar.tfrecord-00003-of-00006', 1, 4, 1),
         (4, '/bar.tfrecord-00004-of-00006', 2, 5, 2),
         (5, '/bar.tfrecord-00005-of-00006', 1, 7, -1),
     ]
     expected = [
         _ShardSpec(number, shard_path, count, [
             {'bucket_index': 0, 'skip': skip, 'take': take},
         ])
         for number, shard_path, count, skip, take in rows
     ]
     self.assertEqual(actual, expected)
Beispiel #9
0
 def _write_tfrecord(self, split_name, shards_number, records):
   """Write `records` via shard specs and describe the resulting split.

   Args:
     split_name: split name, passed to `self._filename_template` and used
       as the SplitInfo name.
     shards_number: value that the patched `_get_number_shards` returns
       while the shard specs are computed.
     records: sequence of strings; each is converted with `six.b` and keyed
       by its position in the sequence.

   Returns:
     A `splits.SplitInfo` carrying the per-shard example counts,
     `num_bytes=0` and the filename template used for the shards.
   """
   filename_template = self._filename_template(split=split_name)
   num_examples = len(records)

   # Force the shard count instead of letting the writer derive it.
   forced_shard_count = mock.patch.object(
       tfrecords_writer, '_get_number_shards', return_value=shards_number)
   with forced_shard_count:
     shard_specs = tfrecords_writer._get_shard_specs(
         num_examples=num_examples,
         total_size=0,
         bucket_lengths=[num_examples],
         filename_template=filename_template)

   keyed_records = [
       (position, six.b(text)) for position, text in enumerate(records)
   ]

   def read_bucket(unused_i):
     # All records sit in one bucket, so the bucket index is ignored.
     return iter(keyed_records)

   for spec in shard_specs:
     _write_tfrecord_from_shard_spec(spec, read_bucket)

   shard_lengths = [int(spec.examples_number) for spec in shard_specs]
   return splits.SplitInfo(
       name=split_name,
       shard_lengths=shard_lengths,
       num_bytes=0,
       filename_template=filename_template)