Example #1
from typing import Sequence

import jax
import tensorflow as tf


def _repeat_batch(batch_sizes: Sequence[int],
                  ds: tf.data.Dataset,
                  repeat: int = 1) -> tf.data.Dataset:
    """Tiles the inner most batch dimension."""
    if repeat <= 1:
        return ds
    if batch_sizes[-1] % repeat != 0:
        raise ValueError(
            f'The last element of `batch_sizes` ({batch_sizes}) must '
            f'be divisible by `repeat` ({repeat}).')
    # Perform regular batching with reduced number of elements.
    for i, batch_size in enumerate(reversed(batch_sizes)):
        ds = ds.batch(batch_size // repeat if i == 0 else batch_size,
                      drop_remainder=True)
    # Repeat batch.
    fn = lambda x: tf.repeat(x, repeats=repeat, axis=len(batch_sizes) - 1)

    def repeat_inner_batch(example):
        return jax.tree_map(fn, example)

    ds = ds.map(repeat_inner_batch, num_parallel_calls=tf.data.AUTOTUNE)
    # Unbatch.
    for _ in batch_sizes:
        ds = ds.unbatch()
    return ds
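
A minimal usage sketch (illustrative, not from the original source): with batch_sizes=[2, 4] and repeat=2, the helper batches to shape [2, 2], repeats each element twice along the innermost batch axis, then unbatches again, so every example appears `repeat` times in a row.

ds = tf.data.Dataset.range(8)
ds = _repeat_batch(batch_sizes=[2, 4], ds=ds, repeat=2)
# Elements come out as 0, 0, 1, 1, 2, 2, 3, 3, ...
print(list(ds.as_numpy_iterator()))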
Example #2
  def _batch(self, split: Split,
             dataset: tf.data.Dataset,
             drop_remainder: bool = True) -> tf.data.Dataset:
    """Get the batched version of `dataset`."""
    # `uneven_datasets` is a list of datasets with a number of validation and/or
    # test examples that is not evenly divisible by commonly used batch sizes.
    uneven_datasets = ['criteo', 'svhn']
    if self._is_training(split):
      batch_size = self.batch_size
    elif split == Split.VAL:
      batch_size = self.eval_batch_size
      if (self._num_validation_examples % batch_size != 0 and
          self.name not in uneven_datasets):
        logging.warning(
            'Batch size does not evenly divide the number of validation '
            'examples, cannot ensure static shapes on TPU. Batch size: %d, '
            'validation examples: %d',
            batch_size,
            self._num_validation_examples)
    else:
      batch_size = self.eval_batch_size
      if (self._num_test_examples % batch_size != 0 and
          self.name not in uneven_datasets):
        logging.warning(
            'Batch size does not evenly divide the number of test examples, '
            'cannot ensure static shapes on TPU. Batch size: %d, test '
            'examples: %d', batch_size, self._num_test_examples)
    # Note that we always drop the last batch when the batch size does not
    # evenly divide the number of examples.
    return dataset.batch(batch_size, drop_remainder=drop_remainder)
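
A short sketch (not from the original source) of why the warning matters: with drop_remainder=True, any remainder batch is silently dropped, so an eval or test split whose size is not a multiple of the batch size loses examples.

ds = tf.data.Dataset.range(10)
# drop_remainder=True keeps only full batches: examples 8 and 9 are lost.
print(len(list(ds.batch(4, drop_remainder=True))))   # 2
# drop_remainder=False keeps a final partial batch of two examples.
print(len(list(ds.batch(4, drop_remainder=False))))  # 3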
Example #3
    def pipeline(
            self,
            dataset: tf.data.Dataset,
            input_context: tf.distribute.InputContext = None
    ) -> tf.data.Dataset:
        """Build a pipeline fetching, shuffling, and preprocessing the dataset.

    Args:
      dataset: A `tf.data.Dataset` that loads raw files.
      input_context: An optional context provided by `tf.distribute` for
        cross-replica training. This isn't necessary if using Keras
        compile/fit.

    Returns:
      A TensorFlow dataset outputting batched images and labels.
    """
        if input_context and input_context.num_input_pipelines > 1:
            dataset = dataset.shard(input_context.num_input_pipelines,
                                    input_context.input_pipeline_id)

        if self.is_training and not self.config.cache:
            dataset = dataset.repeat()

        if self.config.builder == 'records':
            # Read the data from disk in parallel
            buffer_size = 8 * 1024 * 1024  # Use 8 MiB per file
            dataset = dataset.interleave(
                lambda name: tf.data.TFRecordDataset(name,
                                                     buffer_size=buffer_size),
                cycle_length=16,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

        dataset = dataset.prefetch(self.global_batch_size)

        if self.config.cache:
            dataset = dataset.cache()

        if self.is_training:
            dataset = dataset.shuffle(self.config.shuffle_buffer_size)
            dataset = dataset.repeat()

        # Parse, pre-process, and batch the data in parallel
        if self.config.builder == 'records':
            preprocess = self.parse_record
        else:
            preprocess = self.preprocess
        dataset = dataset.map(preprocess,
                              num_parallel_calls=tf.data.experimental.AUTOTUNE)

        dataset = dataset.batch(self.batch_size,
                                drop_remainder=self.is_training)

        # Note: we could do image normalization here, but we defer it to the model,
        # which can perform it much faster on a GPU/TPU
        # TODO(dankondratyuk): if we fix prefetching, we can do it here

        if self.is_training and self.config.deterministic_train is not None:
            options = tf.data.Options()
            options.experimental_deterministic = self.config.deterministic_train
            options.experimental_slack = self.config.use_slack
            options.experimental_optimization.parallel_batch = True
            options.experimental_optimization.map_fusion = True
            options.experimental_optimization.map_vectorization.enabled = True
            options.experimental_optimization.map_parallelization = True
            dataset = dataset.with_options(options)

        # Prefetch overlaps in-feed with training
        # Note: autotune here is not recommended, as this can lead to memory leaks.
        # Instead, use a constant prefetch size like the number of devices.
        dataset = dataset.prefetch(self.config.num_devices)

        return dataset
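
A minimal usage sketch, assuming a `loader` instance of this class and a hypothetical `make_source_dataset()` that yields the raw files; with `tf.distribute`, the strategy supplies the `InputContext` that drives the sharding branch at the top of `pipeline`.

strategy = tf.distribute.MirroredStrategy()

def dataset_fn(input_context):
    # Hypothetical helper: a tf.data.Dataset of raw files/records for this loader.
    raw = make_source_dataset()
    return loader.pipeline(raw, input_context=input_context)

dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)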