def test_nothing_to_read(self):
  # An empty range (from_ == to) yields no instructions, whether it falls at
  # the start, in the middle, or at the very end of the 5 examples.
  res = shard_utils.get_file_instructions(
      0, 0, ['f1', 'f2', 'f3', 'f4'], [0, 3, 0, 2])
  self.assertEqual(res, [])
  res = shard_utils.get_file_instructions(
      4, 4, ['f1', 'f2', 'f3', 'f4'], [0, 3, 0, 2])
  self.assertEqual(res, [])
  res = shard_utils.get_file_instructions(
      5, 5, ['f1', 'f2', 'f3', 'f4'], [0, 3, 0, 2])
  self.assertEqual(res, [])
def _make_file_instructions_from_absolutes(
    name: str,
    name2shard_lengths: Dict[str, List[int]],
    absolute_instructions: 'ReadInstruction',
    file_format: file_adapters.FileFormat = file_adapters.DEFAULT_FILE_FORMAT,
) -> List[shard_utils.FileInstruction]:
  """Returns the file instructions from the absolute instructions list."""
  # For each split, compute the file instructions (skip/take).
  file_instructions = []
  for abs_instr in absolute_instructions:
    shard_lengths = name2shard_lengths[abs_instr.splitname]
    if not shard_lengths:
      raise ValueError(
          'Shard empty. This might mean that the dataset hasn\'t been '
          'generated yet and the info was not restored from GCS, or that a '
          'legacy dataset is used.')
    filenames = naming.filenames_for_dataset_split(
        dataset_name=name,
        split=abs_instr.splitname,
        num_shards=len(shard_lengths),
        filetype_suffix=file_adapters.ADAPTER_FOR_FORMAT[file_format]
        .FILE_SUFFIX)
    from_ = 0 if abs_instr.from_ is None else abs_instr.from_
    to = sum(shard_lengths) if abs_instr.to is None else abs_instr.to
    single_file_instructions = shard_utils.get_file_instructions(
        from_, to, filenames, shard_lengths)
    file_instructions.extend(single_file_instructions)
  return file_instructions
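# Usage sketch (hypothetical values, not from the library): for a 'train'
# split stored as two shards of 5 examples each, an absolute instruction
# covering [2, 10) would produce one FileInstruction per shard:
#
#   _make_file_instructions_from_absolutes(
#       name='my_dataset',
#       name2shard_lengths={'train': [5, 5]},
#       absolute_instructions=[train_2_to_10],  # hypothetical instruction
#   )
#   # -> [FileInstruction(filename=..., skip=2, take=-1, num_examples=3),
#   #     FileInstruction(filename=..., skip=0, take=-1, num_examples=5)]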
def _get_shard_specs(
    num_examples: int,
    total_size: int,
    bucket_lengths: List[int],
    path: str,
) -> List[_ShardSpec]:
  """Returns list of _ShardSpec instances, corresponding to shards to write.

  Args:
    num_examples: int, number of examples in split.
    total_size: int (bytes), sum of example sizes.
    bucket_lengths: list of ints, number of examples in each bucket.
    path: string, path to tfrecord. `-xxxxx-of-xxxxx` will be added.
  """
  num_shards = _get_number_shards(total_size, num_examples)
  shard_boundaries = _get_shard_boundaries(num_examples, num_shards)
  shard_specs = []
  bucket_indexes = [str(i) for i in range(len(bucket_lengths))]
  from_ = 0
  for shard_index, to in enumerate(shard_boundaries):
    # Read the bucket indexes
    file_instructions = shard_utils.get_file_instructions(
        from_, to, bucket_indexes, bucket_lengths)
    shard_path = "%s-%05d-of-%05d" % (path, shard_index, num_shards)
    index_path = _get_index_path(shard_path)
    shard_specs.append(
        _ShardSpec(
            shard_index=shard_index,
            path=shard_path,
            index_path=index_path,
            examples_number=to - from_,
            file_instructions=file_instructions,
        ))
    from_ = to
  return shard_specs
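# `_get_shard_boundaries` is defined elsewhere; a minimal sketch, assuming it
# returns the cumulative end index of each shard with examples spread as
# evenly as possible (an assumption, not the library's actual code):
def _sketch_shard_boundaries(num_examples: int, num_shards: int) -> List[int]:
  # E.g. (10, 3) -> [3, 7, 10], i.e. shards of 3, 4 and 3 examples.
  return [
      round(num_examples * (i + 1) / num_shards) for i in range(num_shards)
  ]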
def test_from1_to10(self):
  res = shard_utils.get_file_instructions(
      1, 10, ['f1', 'f2', 'f3', 'f4'], [4, 4, 0, 4])
  self.assertEqual(res, [
      shard_utils.FileInstruction(
          filename='f1', skip=1, take=-1, num_examples=3),
      shard_utils.FileInstruction(
          filename='f2', skip=0, take=-1, num_examples=4),
      shard_utils.FileInstruction(
          filename='f4', skip=0, take=2, num_examples=2),
  ])
def test_read_all_empty_shard(self):
  res = shard_utils.get_file_instructions(
      0, 12, ['f1', 'f2', 'f3', 'f4'], [4, 4, 0, 4])
  self.assertEqual(res, [
      shard_utils.FileInstruction(
          filename='f1', skip=0, take=-1, num_examples=4),
      shard_utils.FileInstruction(
          filename='f2', skip=0, take=-1, num_examples=4),
      shard_utils.FileInstruction(
          filename='f4', skip=0, take=-1, num_examples=4),
  ])
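# A reference sketch of the skip/take semantics asserted by the tests above;
# an illustration consistent with those tests, not shard_utils' actual code.
def _sketch_get_file_instructions(
    from_: int,
    to: int,
    filenames: List[str],
    shard_lengths: List[int],
) -> List[shard_utils.FileInstruction]:
  """Maps the global range [from_, to) to per-file skip/take pairs."""
  instructions = []
  shard_start = 0  # Global index of the first example in the current shard.
  for filename, length in zip(filenames, shard_lengths):
    shard_end = shard_start + length
    # Intersect [from_, to) with this shard; empty shards and shards fully
    # outside the requested range contribute nothing.
    start, end = max(from_, shard_start), min(to, shard_end)
    if start < end:
      instructions.append(
          shard_utils.FileInstruction(
              filename=filename,
              skip=start - shard_start,
              # take == -1 means "read until the end of the shard".
              take=-1 if end == shard_end else end - start,
              num_examples=end - start))
    shard_start = shard_end
  return instructions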
def _file_instructions_for_split(
    instruction: _AbsoluteInstruction,
    split_info: SplitInfo,
) -> List[shard_utils.FileInstruction]:
  """Returns the file instructions for the given instruction and split info."""
  if not split_info.num_examples:
    raise ValueError(
        'Shard empty. This might mean that the dataset hasn\'t been '
        'generated yet and the info was not restored from GCS, or that a '
        'legacy dataset is used.')
  to = split_info.num_examples if instruction.to is None else instruction.to
  return shard_utils.get_file_instructions(
      from_=instruction.from_ or 0,
      to=to,
      filenames=[os.fspath(fp) for fp in split_info.filepaths],
      shard_lengths=split_info.shard_lengths)
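# Usage sketch (hypothetical values): for a SplitInfo with
# shard_lengths=[4, 4] and an _AbsoluteInstruction covering [2, 6), the
# expected result would be:
#   [FileInstruction(filename=..., skip=2, take=-1, num_examples=2),
#    FileInstruction(filename=..., skip=0, take=2, num_examples=2)]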
def _get_shard_specs(
    num_examples: int,
    total_size: int,
    bucket_lengths: List[int],
    filename_template: naming.ShardedFileTemplate,
) -> List[_ShardSpec]:
  """Returns list of _ShardSpec instances, corresponding to shards to write.

  Args:
    num_examples: int, number of examples in split.
    total_size: int (bytes), sum of example sizes.
    bucket_lengths: list of ints, number of examples in each bucket.
    filename_template: template to format sharded filenames.
  """
  num_shards = _get_number_shards(total_size, num_examples)
  shard_boundaries = _get_shard_boundaries(num_examples, num_shards)
  shard_specs = []
  bucket_indexes = [str(i) for i in range(len(bucket_lengths))]
  from_ = 0
  for shard_index, to in enumerate(shard_boundaries):
    # Read the bucket indexes
    file_instructions = shard_utils.get_file_instructions(
        from_, to, bucket_indexes, bucket_lengths)
    shard_path = filename_template.sharded_filepath(
        shard_index=shard_index, num_shards=num_shards)
    index_path = _get_index_path(os.fspath(shard_path))
    shard_specs.append(
        _ShardSpec(
            shard_index=shard_index,
            path=os.fspath(shard_path),
            index_path=index_path,
            examples_number=to - from_,
            file_instructions=file_instructions,
        ))
    from_ = to
  return shard_specs
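# Worked trace (illustrative values): with buckets of [6, 4] examples,
# num_shards=2 and shard_boundaries=[5, 10], the loop above would emit:
#   shard 0, range [0, 5):  bucket '0' skip=0 take=5  (num_examples=5)
#   shard 1, range [5, 10): bucket '0' skip=5 take=-1 (num_examples=1)
#                           bucket '1' skip=0 take=-1 (num_examples=4)
# i.e. each shard's file_instructions describe which bucket slices to copy
# when that shard is written.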