Example #1
import numpy as np
import tensorflow as tf


def get_shuffled_batches(dataset: tf.data.Dataset,
                         seed: int = 0,
                         batch_size: int = 64) -> tf.data.Dataset:
    """Returns a Dataset that consists of padded batches when iterated over.

  This shuffles the examples randomly each epoch. The random order is
  deterministic and controlled by the seed.

  Batches are padded because sentences have different lengths.
  Sentences that are shorter in a batch will get 0s added at the end, until
  all sentences in the batch have the same length.

  Args:
    dataset: A TF Dataset with examples to be shuffled and batched.
    seed: The seed that determines the shuffling order, with a different order
      each epoch.
    batch_size: The size of each batch. The remainder is dropped.

  Returns:
    A TF Dataset containing padded batches.
  """
    # For shuffling we need to know how many training examples we have.
    num_examples = dataset.reduce(np.int64(0), lambda x, _: x + 1).numpy()

    # `padded_shapes` says what kind of shapes to expect: [] means a scalar, [-1]
    # means a vector of variable length, and [1] means a vector of size 1.
    return dataset.shuffle(num_examples,
                           seed=seed,
                           reshuffle_each_iteration=True).padded_batch(
                               batch_size,
                               padded_shapes={
                                   'idx': [],
                                   'sentence': [-1],
                                   'label': [1],
                                   'length': []
                               },
                               drop_remainder=True).prefetch(
                                   tf.data.experimental.AUTOTUNE)
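
A minimal usage sketch (not from the original source): the toy generator, its values, and the batch size are illustrative assumptions; the feature keys and dtypes mirror the padded_shapes dict above, and output_signature assumes TF 2.4+.

import tensorflow as tf


def _toy_examples():
    # Hypothetical toy data using the same feature keys as `padded_shapes`.
    sentences = [[3, 7, 2], [5, 1], [9, 4, 8, 6]]
    for i, sentence in enumerate(sentences):
        yield {'idx': i,
               'sentence': sentence,
               'label': [i % 2],
               'length': len(sentence)}


toy_dataset = tf.data.Dataset.from_generator(
    _toy_examples,
    output_signature={
        'idx': tf.TensorSpec([], tf.int64),
        'sentence': tf.TensorSpec([None], tf.int64),
        'label': tf.TensorSpec([1], tf.int64),
        'length': tf.TensorSpec([], tf.int64),
    })

batches = get_shuffled_batches(toy_dataset, seed=42, batch_size=2)
for batch in batches.take(1):
    # Shorter sentences are padded with 0s up to the longest one in the batch.
    print(batch['sentence'].shape)  # (2, longest_sentence_in_batch)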
Example #2
    def pipeline(
            self,
            dataset: tf.data.Dataset,
            input_context: tf.distribute.InputContext = None
    ) -> tf.data.Dataset:
        """Build a pipeline fetching, shuffling, and preprocessing the dataset.

    Args:
      dataset: A `tf.data.Dataset` that loads raw files.
      input_context: An optional context provided by `tf.distribute` for
        cross-replica training. This isn't necessary if using Keras
        compile/fit.

    Returns:
      A TensorFlow dataset outputting batched images and labels.
    """
        if input_context and input_context.num_input_pipelines > 1:
            dataset = dataset.shard(input_context.num_input_pipelines,
                                    input_context.input_pipeline_id)

        if self.is_training and not self.config.cache:
            dataset = dataset.repeat()

        if self.config.builder == 'records':
            # Read the data from disk in parallel
            buffer_size = 8 * 1024 * 1024  # Use 8 MiB per file
            dataset = dataset.interleave(
                lambda name: tf.data.TFRecordDataset(name,
                                                     buffer_size=buffer_size),
                cycle_length=16,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

        dataset = dataset.prefetch(self.global_batch_size)

        if self.config.cache:
            dataset = dataset.cache()

        if self.is_training:
            dataset = dataset.shuffle(self.config.shuffle_buffer_size)
            dataset = dataset.repeat()

        # Parse, pre-process, and batch the data in parallel
        if self.config.builder == 'records':
            preprocess = self.parse_record
        else:
            preprocess = self.preprocess
        dataset = dataset.map(preprocess,
                              num_parallel_calls=tf.data.experimental.AUTOTUNE)

        dataset = dataset.batch(self.batch_size,
                                drop_remainder=self.is_training)

        # Note: we could do image normalization here, but we defer it to the model
        # which can perform it much faster on a GPU/TPU
        # TODO(dankondratyuk): if we fix prefetching, we can do it here

        if self.is_training and self.config.deterministic_train is not None:
            options = tf.data.Options()
            options.experimental_deterministic = self.config.deterministic_train
            options.experimental_slack = self.config.use_slack
            options.experimental_optimization.parallel_batch = True
            options.experimental_optimization.map_fusion = True
            options.experimental_optimization.map_vectorization.enabled = True
            options.experimental_optimization.map_parallelization = True
            dataset = dataset.with_options(options)

        # Prefetch overlaps in-feed with training
        # Note: autotune here is not recommended, as this can lead to memory leaks.
        # Instead, use a constant prefetch size like the number of devices.
        dataset = dataset.prefetch(self.config.num_devices)

        return dataset
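
A minimal sketch of how this method might be driven under tf.distribute (assuming TF 2.4+): `builder` stands for a hypothetical instance of the class that defines `pipeline`, and the TFRecord file pattern is made up. The strategy supplies the InputContext, which is what populates `num_input_pipelines` and `input_pipeline_id` in the sharding branch above; with plain Keras compile/fit the context can be omitted, as the docstring notes.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()


def _dataset_fn(input_context: tf.distribute.InputContext) -> tf.data.Dataset:
    # `builder` is a hypothetical object exposing the `pipeline` method above;
    # the file pattern is likewise illustrative.
    file_names = tf.data.Dataset.list_files('train-*.tfrecord', shuffle=False)
    return builder.pipeline(file_names, input_context)


# Each per-replica input pipeline receives its own InputContext, so the files
# are sharded across pipelines before being interleaved and preprocessed.
dist_dataset = strategy.distribute_datasets_from_function(_dataset_fn)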