def test_nothing_to_read(self):
  """An empty range [p, p) yields no read instructions, wherever p falls.

  Covers p at the very start (0), exactly at the total example count (4),
  and past the end (5) of shards with lengths [0, 3, 0, 2].
  """
  filenames = ["f1", "f2", "f3", "f4"]
  shard_lengths = [0, 3, 0, 2]
  # Same three checks as before: start-of-data, end-of-data, past-the-end.
  for position in (0, 4, 5):
    instructions = _sharded_files.get_read_instructions(
        position, position, filenames, shard_lengths)
    self.assertEqual(instructions, [])
def _make_file_instructions_from_absolutes(
    name,
    name2shard_lengths,
    absolute_instructions,
):
  """Returns the files instructions from the absolute instructions list.

  For each absolute instruction (one per split slice), this resolves the
  shard filenames for the split and computes the per-file skip/take
  instructions covering the requested [from_, to) example range.

  Args:
    name: str, name of the dataset (used to build shard filenames).
    name2shard_lengths: dict mapping split name to the list of per-shard
      example counts for that split.
    absolute_instructions: list of absolute instructions; each has
      `splitname`, `from_` and `to` attributes, where `from_`/`to` may be
      None to mean "start of split" / "end of split" respectively.

  Returns:
    A `FileInstructions` instance holding the flattened per-file skip/take
    instructions and the number of examples requested per instruction.

  Raises:
    ValueError: if a referenced split has no shard lengths recorded (e.g.
      dataset not generated yet, or info not restored).
  """
  # For each split, return the files instruction (skip/take)
  file_instructions = []
  num_examples_per_shard = []
  for abs_instr in absolute_instructions:
    shard_lengths = name2shard_lengths[abs_instr.splitname]
    if not shard_lengths:
      # Fix: "might means" -> "might mean" in the user-facing error message.
      raise ValueError(
          'Shard empty. This might mean that dataset hasn\'t been generated '
          'yet and info not restored from GCS, or that legacy dataset is used.'
      )
    filenames = naming.filenames_for_dataset_split(
        dataset_name=name,
        split=abs_instr.splitname,
        num_shards=len(shard_lengths),
        filetype_suffix='tfrecord')
    # None bounds default to the full split: [0, total number of examples).
    from_ = 0 if abs_instr.from_ is None else abs_instr.from_
    to = sum(shard_lengths) if abs_instr.to is None else abs_instr.to
    num_examples_per_shard.append(to - from_)
    single_file_instructions = _sharded_files.get_read_instructions(
        from_, to, filenames, shard_lengths)
    file_instructions.extend(single_file_instructions)
  return FileInstructions(
      num_examples_per_shard=num_examples_per_shard,
      file_instructions=file_instructions,
  )
def _get_shard_specs(num_examples, total_size, bucket_lengths, path):
  """Returns list of _ShardSpec instances, corresponding to shards to write.

  Args:
    num_examples: int, number of examples in split.
    total_size: int (bytes), sum of example sizes.
    bucket_lengths: list of ints, number of examples in each bucket.
    path: string, path to tfrecord. `-xxxxx-of-xxxxx` will be added.
  """
  num_shards = _get_number_shards(total_size, num_examples)
  boundaries = _get_shard_boundaries(num_examples, num_shards)
  # Buckets are addressed by index; hoisted out of the loop since it is
  # the same list for every shard.
  bucket_indexes = list(range(len(bucket_lengths)))
  specs = []
  start = 0
  for index, end in enumerate(boundaries):
    # Which bucket slices feed examples [start, end) of this shard.
    reading_instructions = _sharded_files.get_read_instructions(
        start, end, bucket_indexes, bucket_lengths,
        shardref_name="bucket_index")
    spec = _ShardSpec(
        shard_index=index,
        path="%s-%05d-of-%05d" % (path, index, num_shards),
        examples_number=end - start,
        reading_instructions=reading_instructions,
    )
    specs.append(spec)
    start = end
  return specs
def _get_dataset_files(name, path, instruction, name2shard_lengths):
  """Returns a list of files (+skip/take) corresponding to given instruction.

  Args:
    name: Name of the dataset.
    path: path to tfrecords.
    instruction: _AbsoluteInstruction instance.
    name2shard_lengths: dict associating split names to shard lengths.

  Returns:
    list of dict(filename, skip, take).
  """
  shard_lengths = name2shard_lengths[instruction.splitname]
  if not shard_lengths:
    raise AssertionError(
        '`DatasetInfo.SplitInfo.num_shards` is empty. S3 tfrecords_reader '
        'cannot be used. Make sure the data you are trying to read was '
        'generated using tfrecords_writer module (S3).')
  filenames = naming.filepaths_for_dataset_split(
      dataset_name=name,
      split=instruction.splitname,
      num_shards=len(shard_lengths),
      data_dir=path,
      filetype_suffix='tfrecord')
  # None bounds mean "start of split" / "end of split" respectively.
  start = instruction.from_
  if start is None:
    start = 0
  end = instruction.to
  if end is None:
    end = sum(shard_lengths)
  return _sharded_files.get_read_instructions(
      start, end, filenames, shard_lengths)
def test_from1_to10(self):
  """Range [1, 10) over shards [4, 4, 0, 4] spans f1 (partial) to f4."""
  instructions = _sharded_files.get_read_instructions(
      1, 10, ["f1", "f2", "f3", "f4"], [4, 4, 0, 4])
  # f1: skip the first example; f2 fully; f3 is empty; f4: first 2 only.
  expected = [
      {"filename": "f1", "skip": 1, "take": -1},
      {"filename": "f2", "skip": 0, "take": -1},
      {"filename": "f4", "skip": 0, "take": 2},
  ]
  self.assertEqual(instructions, expected)
def test_read_all_empty_shard(self):
  """Reading everything skips the zero-length shard (f3) entirely."""
  instructions = _sharded_files.get_read_instructions(
      0, 12, ["f1", "f2", "f3", "f4"], [4, 4, 0, 4])
  expected = [
      {"filename": "f1", "skip": 0, "take": -1},
      {"filename": "f2", "skip": 0, "take": -1},
      {"filename": "f4", "skip": 0, "take": -1},
  ]
  self.assertEqual(instructions, expected)
def test_read_all_even_sharding(self):
  """Reading everything over evenly sized shards reads each shard fully."""
  instructions = _sharded_files.get_read_instructions(
      0, 12, ["f1", "f2", "f3"], [4, 4, 4])
  expected = [
      {"filename": name, "skip": 0, "take": -1}
      for name in ("f1", "f2", "f3")
  ]
  self.assertEqual(instructions, expected)