コード例 #1
0
def test_shard_api():
  """Checks that `[...shard]` split specs select whole shard files.

  A three-shard 'train' split is sliced with single-shard, negative-index
  and range shard specs; each spec must yield exactly the matching file
  instructions, every one reading its shard in full (skip=0, take=-1).
  """
  shards = (
      ('ds_name-train.tfrecord-00000-of-00003', 10),
      ('ds_name-train.tfrecord-00001-of-00003', 20),
      ('ds_name-train.tfrecord-00002-of-00003', 13),
  )
  split_info = tfds.core.SplitInfo(
      name='train',
      shard_lengths=[size for _, size in shards],
      num_bytes=0,
  )
  all_files = [
      shard_utils.FileInstruction(
          filename=path,
          skip=0,
          take=-1,
          num_examples=size,
      )
      for path, size in shards
  ]
  split_dict = splits.SplitDict([split_info], dataset_name='ds_name')
  # Each shard spec must map onto the same slice of `all_files`.
  expectations = {
      'train[0shard]': all_files[0:1],
      'train[1shard]': all_files[1:2],
      'train[-1shard]': all_files[-1:],
      'train[-2shard]': all_files[-2:-1],
      'train[:2shard]': all_files[:2],
      'train[1shard:]': all_files[1:],
      'train[-1shard:]': all_files[-1:],
      'train[1:-1shard]': all_files[1:-1],
  }
  for spec, expected in expectations.items():
    assert split_dict[spec].file_instructions == expected
コード例 #2
0
 def test_read_files(self):
   """Reads one whole shard followed by a one-record slice of another shard.

   The expected output 'defk' assumes `_write_tfrecord` spreads the 12
   letters over 4 shards of 3 records each (TODO confirm against fixture).
   """
   self._write_tfrecord('train', 4, 'abcdefghijkl')
   template = self._filename_template(split='train')

   def shard_path(index):
     return os.fspath(
         template.sharded_filepath(shard_index=index, num_shards=4))

   instructions = [
       # Shard 1, read in full.
       shard_utils.FileInstruction(
           filename=shard_path(1), skip=0, take=-1, num_examples=3),
       # Shard 3: skip its first record and keep exactly one.
       shard_utils.FileInstruction(
           filename=shard_path(3), skip=1, take=1, num_examples=1),
   ]
   dataset = self.reader.read_files(
       instructions,
       read_config=read_config_lib.ReadConfig(),
       shuffle_files=False,
   )
   records = list(tfds.as_numpy(dataset))
   self.assertEqual(records, [six.b(letter) for letter in 'defk'])
コード例 #3
0
ファイル: writer_test.py プロジェクト: suvarnak/datasets
 def test_4buckets_2shards(self):
   """Packs 8 examples from 4 buckets (one empty) into 2 shards of 4."""
   specs = writer_lib._get_shard_specs(
       num_examples=8,
       total_size=16,
       bucket_lengths=[2, 3, 0, 3],
       filename_template=naming.ShardedFileTemplate(
           dataset_name='bar',
           split='train',
           data_dir='/',
           filetype_suffix='tfrecord'))
   self.assertEqual(
       specs,
       [
           # Shard#, path, index path, examples_number, reading instructions.
           # Instruction filenames are bucket indices: bucket 1 (3 examples)
           # is split across both shards, and the empty bucket 2 produces no
           # instruction at all.
           _ShardSpec(0, '/bar-train.tfrecord-00000-of-00002',
                      '/bar-train.tfrecord-00000-of-00002_index.json', 4, [
                          shard_utils.FileInstruction(
                              filename='0', skip=0, take=-1, num_examples=2),
                          shard_utils.FileInstruction(
                              filename='1', skip=0, take=2, num_examples=2),
                      ]),
           _ShardSpec(1, '/bar-train.tfrecord-00001-of-00002',
                      '/bar-train.tfrecord-00001-of-00002_index.json', 4, [
                          shard_utils.FileInstruction(
                              filename='1', skip=2, take=-1, num_examples=1),
                          shard_utils.FileInstruction(
                              filename='3', skip=0, take=-1, num_examples=3),
                      ]),
       ])
コード例 #4
0
 def test_4buckets_2shards(self):
     """Packs 8 examples from 4 buckets (one empty) into 2 shards of 4."""
     specs = tfrecords_writer._get_shard_specs(num_examples=8,
                                               total_size=16,
                                               bucket_lengths=[2, 3, 0, 3],
                                               path='/bar.tfrecord')
     self.assertEqual(
         specs,
         [
             # Shard#, path, index path, examples_number, reading instructions.
             # Instruction filenames are bucket indices: bucket 1 (3 examples)
             # is split across both shards, and the empty bucket 2 produces
             # no instruction at all.
             _ShardSpec(
                 0, '/bar.tfrecord-00000-of-00002',
                 '/bar.tfrecord-00000-of-00002_index.json', 4, [
                     shard_utils.FileInstruction(
                         filename='0', skip=0, take=-1, num_examples=2),
                     shard_utils.FileInstruction(
                         filename='1', skip=0, take=2, num_examples=2),
                 ]),
             _ShardSpec(
                 1, '/bar.tfrecord-00001-of-00002',
                 '/bar.tfrecord-00001-of-00002_index.json', 4, [
                     shard_utils.FileInstruction(
                         filename='1', skip=2, take=-1, num_examples=1),
                     shard_utils.FileInstruction(
                         filename='3', skip=0, take=-1, num_examples=3),
                 ]),
         ])
コード例 #5
0
 def test_skip_take2(self):
   # The absolute range (7, 9) covers two examples that straddle a shard
   # boundary: the last example of shard 2 and the first example of shard 3.
   instruction = tfrecords_reader._AbsoluteInstruction('train', 7, 9)
   files = self._get_files(instruction)
   expected = [
       shard_utils.FileInstruction(
           filename=self.PATH_PATTERN % shard, skip=skip, take=take,
           num_examples=1)
       for shard, skip, take in ((2, 2, -1), (3, 0, 1))
   ]
   self.assertEqual(files, expected)
コード例 #6
0
 def test_from1_to10(self):
   """Reading examples [1, 10) over shard sizes 4, 4, 0, 4 skips empty 'f3'."""
   res = shard_utils.get_file_instructions(
       1, 10, ['f1', 'f2', 'f3', 'f4'], [4, 4, 0, 4])
   expected = [
       shard_utils.FileInstruction(
           filename=name, skip=skip, take=take, num_examples=count)
       for name, skip, take, count in (
           ('f1', 1, -1, 3),  # Drop the first example, keep the rest.
           ('f2', 0, -1, 4),  # Whole shard.
           ('f4', 0, 2, 2),   # First two examples only.
       )
   ]
   self.assertEqual(res, expected)
コード例 #7
0
 def test_read_all_empty_shard(self):
   """Reading all 12 examples fully reads every non-empty shard."""
   res = shard_utils.get_file_instructions(
       0, 12, ['f1', 'f2', 'f3', 'f4'], [4, 4, 0, 4])
   # 'f3' holds no examples, so no instruction is expected for it.
   expected = [
       shard_utils.FileInstruction(
           filename=name, skip=0, take=-1, num_examples=size)
       for name, size in (('f1', 4), ('f2', 4), ('f4', 4))
   ]
   self.assertEqual(res, expected)
コード例 #8
0
 def test_take(self):
   # Shards 0 and 1 are read whole and shard 2 only partially; the two
   # remaining shards are not read at all.
   files = self._get_files(
       tfrecords_reader._AbsoluteInstruction('train', None, 6))
   expected = [
       shard_utils.FileInstruction(
           filename=self.PATH_PATTERN % shard, skip=0, take=take,
           num_examples=count)
       for shard, take, count in ((0, -1, 3), (1, -1, 2), (2, 1, 1))
   ]
   self.assertEqual(files, expected)
コード例 #9
0
 def test_skip(self):
   # Starting at absolute index 4 leaves shard 0 unread and skips the first
   # example of shard 1; every later shard is read in full.
   files = self._get_files(splits._AbsoluteInstruction('train', 4, None))
   expected = [
       shard_utils.FileInstruction(
           filename=self.PATH_PATTERN % shard, skip=skip, take=-1,
           num_examples=count)
       for shard, skip, count in ((1, 1, 1), (2, 0, 3), (3, 0, 2), (4, 0, 3))
   ]
   self.assertEqual(files, expected)
コード例 #10
0
 def test_read_files(self):
   """Reads one whole shard followed by a one-record slice of another shard.

   The expected output 'defk' assumes `_write_tfrecord` spreads the 12
   letters over 4 shards of 3 records each (TODO confirm against fixture).
   """
   self._write_tfrecord('train', 4, 'abcdefghijkl')
   pattern = 'mnist-train.tfrecord-0000%d-of-00004'
   whole_shard_1 = shard_utils.FileInstruction(
       filename=pattern % 1, skip=0, take=-1, num_examples=3)
   one_from_shard_3 = shard_utils.FileInstruction(
       filename=pattern % 3, skip=1, take=1, num_examples=1)
   dataset = self.reader.read_files(
       [whole_shard_1, one_from_shard_3],
       read_config=read_config_lib.ReadConfig(),
       shuffle_files=False,
   )
   self.assertEqual(
       list(tfds.as_numpy(dataset)), [six.b(letter) for letter in 'defk'])
コード例 #11
0
ファイル: reader_test.py プロジェクト: suvarnak/datasets
 def test_cycle_length_must_be_one(self):
     """Ordered reads reject an interleave cycle length other than 1."""
     self._write_tfrecord('train', 4, 'abcdefghijkl')
     template = self._filename_template(split='train')
     shard_1 = os.fspath(
         template.sharded_filepath(shard_index=1, num_shards=4))
     instructions = [
         shard_utils.FileInstruction(
             filename=shard_1, skip=0, take=-1, num_examples=3),
     ]

     def read(config):
         return self.reader.read_files(
             instructions,
             read_config=config,
             shuffle_files=False,
             disable_shuffling=True,
         )

     # Ordered datasets default interleave_cycle_length to 1, so the
     # default config must be accepted.
     read(read_config_lib.ReadConfig())
     with self.assertRaisesWithPredicateMatch(ValueError,
                                              _CYCLE_LENGTH_ERROR_MESSAGE):
         read(read_config_lib.ReadConfig(interleave_cycle_length=16))
コード例 #12
0
 def test_cycle_length_must_be_one(self):
     """Ordered reads reject an interleave cycle length other than 1."""
     self._write_tfrecord('train', 4, 'abcdefghijkl')
     shard_1 = 'mnist-train.tfrecord-0000%d-of-00004' % 1
     instructions = [
         shard_utils.FileInstruction(
             filename=shard_1, skip=0, take=-1, num_examples=3),
     ]

     def read(config):
         return self.reader.read_files(
             instructions,
             read_config=config,
             shuffle_files=False,
             disable_shuffling=True,
         )

     # Ordered datasets default interleave_cycle_length to 1, so the
     # default config must be accepted.
     read(read_config_lib.ReadConfig())
     with self.assertRaisesWithPredicateMatch(ValueError,
                                              _CYCLE_LENGTH_ERROR_MESSAGE):
         read(read_config_lib.ReadConfig(interleave_cycle_length=16))
コード例 #13
0
 def test_no_skip_no_take(self):
   """With neither skip nor take, every shard is read in full."""
   files = self._get_files(
       tfrecords_reader._AbsoluteInstruction('train', None, None))
   expected = []
   for shard, size in enumerate((3, 2, 3, 2, 3)):
     expected.append(
         shard_utils.FileInstruction(
             filename=self.PATH_PATTERN % shard,
             skip=0,
             take=-1,
             num_examples=size))
   self.assertEqual(files, expected)
コード例 #14
0
 def test_split_file_instructions(self):
   """The full 'train' split maps to its single shard, read in full."""
   instructions = self._builder.info.splits["train"].file_instructions
   only_shard = shard_utils.FileInstruction(
       filename="dummy_dataset_shared_generator-train.tfrecord-00000-of-00001",
       skip=0,
       take=-1,
       num_examples=20,
   )
   self.assertEqual(instructions, [only_shard])
コード例 #15
0
 def test_skip_take1(self):
   # Both the skip and the take fall inside shard 0.
   files = self._get_files(
       tfrecords_reader._AbsoluteInstruction('train', 1, 2))
   only_shard = shard_utils.FileInstruction(
       filename=self.PATH_PATTERN % 0, skip=1, take=1, num_examples=1)
   self.assertEqual(files, [only_shard])
コード例 #16
0
 def test_sub_split_file_instructions(self):
   """'train[75%:]' reads only the last 5 of the shard's 20 examples."""
   instructions = self._builder.info.splits['train[75%:]'].file_instructions
   data_dir = self._builder.data_dir
   tail = shard_utils.FileInstruction(
       filename=f'{data_dir}/dummy_dataset_shared_generator-train.tfrecord-00000-of-00001',
       skip=15,
       take=-1,
       num_examples=5,
   )
   self.assertEqual(instructions, [tail])
コード例 #17
0
 def test_1bucket_6shards(self):
   """Spreads a single 8-example bucket over 6 shards.

   Every shard reads a contiguous slice of bucket '0'; the last shard uses
   take=-1 to read through the end of the bucket.
   """
   specs = tfrecords_writer._get_shard_specs(
       num_examples=8, total_size=16, bucket_lengths=[8],
       path='/bar.tfrecord')
   self.assertEqual(specs, [
       # Shard#, path, examples_number, reading instructions.
       _ShardSpec(0, '/bar.tfrecord-00000-of-00006', 1, [
           shard_utils.FileInstruction(
               filename='0', skip=0, take=1, num_examples=1),
       ]),
       _ShardSpec(1, '/bar.tfrecord-00001-of-00006', 2, [
           shard_utils.FileInstruction(
               filename='0', skip=1, take=2, num_examples=2),
       ]),
       _ShardSpec(2, '/bar.tfrecord-00002-of-00006', 1, [
           shard_utils.FileInstruction(
               filename='0', skip=3, take=1, num_examples=1),
       ]),
       _ShardSpec(3, '/bar.tfrecord-00003-of-00006', 1, [
           shard_utils.FileInstruction(
               filename='0', skip=4, take=1, num_examples=1),
       ]),
       _ShardSpec(4, '/bar.tfrecord-00004-of-00006', 2, [
           shard_utils.FileInstruction(
               filename='0', skip=5, take=2, num_examples=2),
       ]),
       _ShardSpec(5, '/bar.tfrecord-00005-of-00006', 1, [
           shard_utils.FileInstruction(
               filename='0', skip=7, take=-1, num_examples=1),
       ]),
   ])
コード例 #18
0
def test_multi_split_infos():
  """MultiSplitInfo sums sizes and concatenates shards/files in input order."""
  split = 'train'
  split_infos = [
      tfds.core.SplitInfo(
          name=split,
          shard_lengths=lengths,
          num_bytes=num_bytes,
          filename_template=_filename_template(split=split, data_dir=data_dir))
      for lengths, num_bytes, data_dir in (
          ([10, 10], 100, '/abc'),
          ([1], 20, '/xyz'),
      )
  ]
  multi_split_info = splits.MultiSplitInfo(name=split, split_infos=split_infos)
  expected_files = [
      shard_utils.FileInstruction(
          filename=path, skip=0, take=-1, num_examples=count)
      for path, count in (
          ('/abc/ds_name-train.tfrecord-00000-of-00002', 10),
          ('/abc/ds_name-train.tfrecord-00001-of-00002', 10),
          ('/xyz/ds_name-train.tfrecord-00000-of-00001', 1),
      )
  ]
  assert multi_split_info.num_bytes == 120
  assert multi_split_info.num_examples == 21
  assert multi_split_info.num_shards == 3
  assert multi_split_info.shard_lengths == [10, 10, 1]
  assert multi_split_info.file_instructions == expected_files
  assert str(multi_split_info) == (
      'MultiSplitInfo(name=\'train\', split_infos=[<SplitInfo num_examples=20, '
      'num_shards=2>, <SplitInfo num_examples=1, num_shards=1>])')
コード例 #19
0
def test_shard_api():
  """Checks that `[...shard]` split specs select whole shard files.

  A three-shard 'train' split with an explicit filename template under
  '/path' is sliced with single-shard, negative-index and range shard specs;
  each spec must yield exactly the matching full-shard file instructions.
  """
  template = naming.ShardedFileTemplate(
      dataset_name='ds_name',
      split='train',
      filetype_suffix='tfrecord',
      data_dir='/path')
  shards = (
      ('/path/ds_name-train.tfrecord-00000-of-00003', 10),
      ('/path/ds_name-train.tfrecord-00001-of-00003', 20),
      ('/path/ds_name-train.tfrecord-00002-of-00003', 13),
  )
  split_info = tfds.core.SplitInfo(
      name='train',
      shard_lengths=[size for _, size in shards],
      num_bytes=0,
      filename_template=template)
  all_files = [
      shard_utils.FileInstruction(
          filename=path,
          skip=0,
          take=-1,
          num_examples=size,
      )
      for path, size in shards
  ]
  split_dict = splits.SplitDict([split_info])
  # Each shard spec must map onto the same slice of `all_files`.
  expectations = {
      'train[0shard]': all_files[0:1],
      'train[1shard]': all_files[1:2],
      'train[-1shard]': all_files[-1:],
      'train[-2shard]': all_files[-2:-1],
      'train[:2shard]': all_files[:2],
      'train[1shard:]': all_files[1:],
      'train[-1shard:]': all_files[-1:],
      'train[1:-1shard]': all_files[1:-1],
  }
  for spec, expected in expectations.items():
    assert split_dict[spec].file_instructions == expected
コード例 #20
0
 def test_shuffle_files_should_be_disabled(self):
   """shuffle_files=True is rejected when disable_shuffling=True."""
   self._write_tfrecord('train', 4, 'abcdefghijkl')
   shard_1 = 'mnist-train.tfrecord-0000%d-of-00004' % 1
   instruction = shard_utils.FileInstruction(
       filename=shard_1, skip=0, take=-1, num_examples=3)
   with self.assertRaisesWithPredicateMatch(ValueError,
                                            _SHUFFLE_FILES_ERROR_MESSAGE):
     self.reader.read_files(
         [instruction],
         read_config=read_config_lib.ReadConfig(),
         shuffle_files=True,
         disable_shuffling=True,
     )
コード例 #21
0
ファイル: reader_test.py プロジェクト: suvarnak/datasets
 def test_shuffle_files_should_be_disabled(self):
     """shuffle_files=True is rejected when disable_shuffling=True."""
     self._write_tfrecord('train', 4, 'abcdefghijkl')
     template = self._filename_template(split='train')
     shard_1 = os.fspath(
         template.sharded_filepath(shard_index=1, num_shards=4))
     instruction = shard_utils.FileInstruction(
         filename=shard_1, skip=0, take=-1, num_examples=3)
     with self.assertRaisesWithPredicateMatch(ValueError,
                                              _SHUFFLE_FILES_ERROR_MESSAGE):
         self.reader.read_files(
             [instruction],
             read_config=read_config_lib.ReadConfig(),
             shuffle_files=True,
             disable_shuffling=True,
         )
コード例 #22
0
 def test_ordering_guard(self):
   """With the ordering guard disabled, violations are logged, not raised."""
   self._write_tfrecord('train', 4, 'abcdefghijkl')
   shard_1 = 'mnist-train.tfrecord-0000%d-of-00004' % 1
   instructions = [
       shard_utils.FileInstruction(
           filename=shard_1, skip=0, take=-1, num_examples=3),
   ]
   # Both violations are expected together in a single warning message.
   expected_warning = '\n'.join(
       [_SHUFFLE_FILES_ERROR_MESSAGE, _CYCLE_LENGTH_ERROR_MESSAGE])
   logged = []
   with mock.patch('absl.logging.warning', logged.append):
     self.reader.read_files(
         instructions,
         read_config=read_config_lib.ReadConfig(
             interleave_cycle_length=16, enable_ordering_guard=False),
         shuffle_files=True,
         disable_shuffling=True,
     )
     self.assertIn(expected_warning, logged)
コード例 #23
0
 def test_1bucket_6shards(self):
     """Spreads a single 8-example bucket over 6 shards.

     Every shard reads a contiguous slice of bucket '0'; the last shard uses
     take=-1 to read through the end of the bucket.
     """
     specs = tfrecords_writer._get_shard_specs(
         num_examples=8,
         total_size=16,
         bucket_lengths=[8],
         filename_template=naming.ShardedFileTemplate(
             dataset_name='bar',
             split='train',
             data_dir='/',
             filetype_suffix='tfrecord'))
     self.assertEqual(
         specs,
         [
             # Shard#, path, index path, examples_number, reading instructions.
             _ShardSpec(
                 0, '/bar-train.tfrecord-00000-of-00006',
                 '/bar-train.tfrecord-00000-of-00006_index.json', 1, [
                     shard_utils.FileInstruction(
                         filename='0', skip=0, take=1, num_examples=1),
                 ]),
             _ShardSpec(
                 1, '/bar-train.tfrecord-00001-of-00006',
                 '/bar-train.tfrecord-00001-of-00006_index.json', 2, [
                     shard_utils.FileInstruction(
                         filename='0', skip=1, take=2, num_examples=2),
                 ]),
             _ShardSpec(
                 2, '/bar-train.tfrecord-00002-of-00006',
                 '/bar-train.tfrecord-00002-of-00006_index.json', 1, [
                     shard_utils.FileInstruction(
                         filename='0', skip=3, take=1, num_examples=1),
                 ]),
             _ShardSpec(
                 3, '/bar-train.tfrecord-00003-of-00006',
                 '/bar-train.tfrecord-00003-of-00006_index.json', 1, [
                     shard_utils.FileInstruction(
                         filename='0', skip=4, take=1, num_examples=1),
                 ]),
             _ShardSpec(
                 4, '/bar-train.tfrecord-00004-of-00006',
                 '/bar-train.tfrecord-00004-of-00006_index.json', 2, [
                     shard_utils.FileInstruction(
                         filename='0', skip=5, take=2, num_examples=2),
                 ]),
             _ShardSpec(
                 5, '/bar-train.tfrecord-00005-of-00006',
                 '/bar-train.tfrecord-00005-of-00006_index.json', 1, [
                     shard_utils.FileInstruction(
                         filename='0', skip=7, take=-1, num_examples=1),
                 ]),
         ])