Example #1
0
 def test_missing_shard_lengths(self):
     """A split with an empty `shard_lengths` list must be rejected."""
     with self.assertRaisesWithPredicateMatch(ValueError, 'Shard empty.'):
         # Construction happens inside the assertion context so any error
         # raised while building the split is also caught by the predicate.
         empty_split = splits.SplitInfo(name='train', shard_lengths=[])
         tfrecords_reader.make_file_instructions('mnist', [empty_split],
                                                 'train')
Example #2
0
 def file_instructions(self):
   """Returns the list of dict(filename, take, skip)."""
   # `self._dataset_name` is assigned in `SplitDict.add()`.
   # Build the instructions for this single split and hand back only the
   # per-shard file instructions the caller cares about.
   return tfrecords_reader.make_file_instructions(
       name=self._dataset_name,
       split_infos=[self],
       instruction=str(self.name),
   ).file_instructions
Example #3
0
 def __getitem__(self, key):
     """Returns the `SplitInfo` for `key`, or sub-split info for instructions.

     Args:
       key: Split name (e.g. 'train') or slicing instruction
         (e.g. 'train[50%]').

     Returns:
       The stored `SplitInfo`, or a `SubSplitInfo` built from the
       instruction string.

     Raises:
       KeyError: If no splits have been generated yet.
     """
     # Fail early with a clear message instead of letting
     # `make_file_instructions` raise a confusing error on an empty dict.
     if not self:
         raise KeyError(
             f"Trying to access `splits[{key!r}]` but `splits` is empty. "
             "This likely indicates the dataset has not been generated yet.")
     # 1st case: The key exists: `info.splits['train']`
     elif str(key) in self:
         return super(SplitDict, self).__getitem__(str(key))
     # 2nd case: Uses instructions: `info.splits['train[50%]']`
     else:
         instructions = tfrecords_reader.make_file_instructions(
             name=self._dataset_name,
             split_infos=self.values(),
             instruction=key,
         )
         return SubSplitInfo(instructions)
Example #4
0
 def __getitem__(self, key):
     """Returns the `SplitInfo` for `key`, or sub-split info for instructions.

     Args:
       key: Split name (e.g. 'train') or slicing instruction
         (e.g. 'train[50%]').

     Returns:
       The stored `SplitInfo`, or a `SubSplitInfo` built from the
       instruction string.

     Raises:
       KeyError: If no splits have been generated yet.
     """
     if not self:
         # Fixed grammar in the user-facing message: "indicate" -> "indicates".
         raise KeyError(
             f"Trying to access `splits[{key!r}]` but `splits` is empty. "
             "This likely indicates the dataset has not been generated yet.")
     # 1st case: The key exists: `info.splits['train']`
     elif str(key) in self:
         return super(SplitDict, self).__getitem__(str(key))
     # 2nd case: Uses instructions: `info.splits['train[50%]']`
     else:
         instructions = tfrecords_reader.make_file_instructions(
             name=self._dataset_name,
             split_infos=self.values(),
             instruction=key,
         )
         return SubSplitInfo(instructions)
Example #5
0
    def file_instructions(self):
        """Returns the list of dict(filename, take, skip).

        This allows for creating your own `tf.data.Dataset` using the low-level
        TFDS values.

        Example:

        ```
        file_instructions = info.splits['train[75%:]'].file_instructions
        instruction_ds = tf.data.Dataset.from_generator(
            lambda: file_instructions,
            output_types={
                'filename': tf.string,
                'take': tf.int64,
                'skip': tf.int64,
            },
        )
        ds = instruction_ds.interleave(
            lambda f: tf.data.TFRecordDataset(
                f['filename']).skip(f['skip']).take(f['take'])
        )
        ```

        When `skip=0` and `take=-1`, the full shard will be read, so the `ds.skip`
        and `ds.take` could be skipped.

        Returns:
          A list of `dict(filename, take, skip)`, one entry per shard to read.
        """
        # `self._dataset_name` is assigned in `SplitDict.add()`.
        instructions = tfrecords_reader.make_file_instructions(
            name=self._dataset_name,
            split_infos=[self],
            instruction=str(self.name),
        )
        return instructions.file_instructions