Beispiel #1
0
 def test_nothing_to_read(self):
     # A range whose start equals its end is empty, so no file instruction may
     # be produced -- whether it sits at the very beginning (0), inside the
     # last shard (4), or at the end of the data (5 == sum of shard lengths).
     for position in (0, 4, 5):
         instructions = _sharded_files.get_read_instructions(
             position, position, ["f1", "f2", "f3", "f4"], [0, 3, 0, 2])
         self.assertEqual(instructions, [])
Beispiel #2
0
def _make_file_instructions_from_absolutes(
    name,
    name2shard_lengths,
    absolute_instructions,
):
    """Translates absolute instructions into per-file skip/take instructions.

    Args:
      name: dataset name, used to derive the shard filenames of each split.
      name2shard_lengths: dict mapping a split name to its list of per-shard
        example counts.
      absolute_instructions: iterable of absolute instructions, each exposing
        `splitname`, `from_` and `to`; a `None` bound stands for the start
        (resp. end) of the split.

    Returns:
      A `FileInstructions` whose `file_instructions` is the concatenation of
      the read instructions of every absolute instruction, in order, and whose
      `num_examples_per_shard` holds one `to - from_` count per absolute
      instruction (despite the field name, not one per shard).

    Raises:
      ValueError: if the shard lengths recorded for a requested split are
        empty.
    """
    collected_instructions = []
    example_counts = []
    for instruction in absolute_instructions:
        split = instruction.splitname
        lengths = name2shard_lengths[split]
        if not lengths:
            raise ValueError(
                'Shard empty. This might means that dataset hasn\'t been generated '
                'yet and info not restored from GCS, or that legacy dataset is used.'
            )
        shard_files = naming.filenames_for_dataset_split(
            dataset_name=name,
            split=split,
            num_shards=len(lengths),
            filetype_suffix='tfrecord')
        # Missing bounds default to the whole split.
        start = instruction.from_ if instruction.from_ is not None else 0
        end = instruction.to if instruction.to is not None else sum(lengths)
        example_counts.append(end - start)
        collected_instructions.extend(
            _sharded_files.get_read_instructions(start, end, shard_files, lengths))
    return FileInstructions(
        num_examples_per_shard=example_counts,
        file_instructions=collected_instructions,
    )
def _get_shard_specs(num_examples, total_size, bucket_lengths, path):
    """Plans the output shards and the bucket slices that feed each of them.

    Args:
      num_examples: int, number of examples in split.
      total_size: int (bytes), sum of example sizes.
      bucket_lengths: list of ints, number of examples in each bucket.
      path: string, path to tfrecord. `-xxxxx-of-xxxxx` will be added.

    Returns:
      list of _ShardSpec, one per shard in shard order. Each spec carries the
      shard's output path, its number of examples and the read instructions
      (referring to buckets via the "bucket_index" key) covering its range.
    """
    num_shards = _get_number_shards(total_size, num_examples)
    boundaries = _get_shard_boundaries(num_examples, num_shards)
    bucket_ids = list(range(len(bucket_lengths)))
    specs = []
    start = 0
    # Each boundary is the (exclusive) end offset of one shard; the next shard
    # starts exactly where the previous one stopped.
    for index, end in enumerate(boundaries):
        reading = _sharded_files.get_read_instructions(
            start,
            end,
            bucket_ids,
            bucket_lengths,
            shardref_name="bucket_index")
        shard_path = "%s-%05d-of-%05d" % (path, index, num_shards)
        spec = _ShardSpec(
            shard_index=index,
            path=shard_path,
            examples_number=end - start,
            reading_instructions=reading,
        )
        specs.append(spec)
        start = end
    return specs
Beispiel #4
0
def _get_dataset_files(name, path, instruction, name2shard_lengths):
    """Resolves one absolute instruction into per-file read instructions.

    Args:
      name: Name of the dataset.
      path: path to the tfrecords, forwarded as `data_dir` when building the
        shard file paths.
      instruction: _AbsoluteInstruction instance; a `None` `from_`/`to` means
        the start/end of the split.
      name2shard_lengths: dict associating split names to shard lengths.

    Returns:
      list of dict(filename, skip, take), as produced by
      `_sharded_files.get_read_instructions`.

    Raises:
      AssertionError: if no shard lengths are recorded for the split.
    """
    split = instruction.splitname
    lengths = name2shard_lengths[split]
    if not lengths:
        raise AssertionError(
            '`DatasetInfo.SplitInfo.num_shards` is empty. S3 tfrecords_reader '
            'cannot be used. Make sure the data you are trying to read was '
            'generated using tfrecords_writer module (S3).')
    filepaths = naming.filepaths_for_dataset_split(
        dataset_name=name,
        split=split,
        num_shards=len(lengths),
        data_dir=path,
        filetype_suffix='tfrecord')
    # Missing bounds default to the whole split.
    start = instruction.from_ if instruction.from_ is not None else 0
    end = instruction.to if instruction.to is not None else sum(lengths)
    return _sharded_files.get_read_instructions(start, end, filepaths, lengths)
Beispiel #5
0
 def test_from1_to10(self):
     instructions = _sharded_files.get_read_instructions(
         1, 10, ["f1", "f2", "f3", "f4"], [4, 4, 0, 4])
     # f1 is read from offset 1 through its end, f2 entirely, the empty f3 is
     # skipped altogether, and only the first 2 examples of f4 are taken.
     # take=-1 is the value expected when a shard is read through to its end.
     expected = [
         {"filename": filename, "skip": skip, "take": take}
         for filename, skip, take in (
             ("f1", 1, -1),
             ("f2", 0, -1),
             ("f4", 0, 2),
         )
     ]
     self.assertEqual(instructions, expected)
Beispiel #6
0
 def test_read_all_empty_shard(self):
     instructions = _sharded_files.get_read_instructions(
         0, 12, ["f1", "f2", "f3", "f4"], [4, 4, 0, 4])
     # Reading every example yields one whole-file instruction per non-empty
     # shard; the zero-length shard f3 must not appear at all.
     expected = [{"filename": filename, "skip": 0, "take": -1}
                 for filename in ("f1", "f2", "f4")]
     self.assertEqual(instructions, expected)
Beispiel #7
0
 def test_read_all_even_sharding(self):
     # Three shards of equal size (4 examples each), read from 0 to 12: every
     # shard should be read whole, without skipping anything.
     instructions = _sharded_files.get_read_instructions(
         0, 12, ["f1", "f2", "f3"], [4, 4, 4])
     expected = [{"filename": filename, "skip": 0, "take": -1}
                 for filename in ("f1", "f2", "f3")]
     self.assertEqual(instructions, expected)