import numpy as np
import tensorflow as tf


def get_shuffled_batches(dataset: tf.data.Dataset,
                         seed: int = 0,
                         batch_size: int = 64) -> tf.data.Dataset:
  """Returns a Dataset that consists of padded batches when iterated over.

  This shuffles the examples randomly each epoch. The random order is
  deterministic and controlled by the seed.

  Batches are padded because sentences have different lengths. Shorter
  sentences in a batch get 0s appended at the end until all sentences in the
  batch have the same length.

  Args:
    dataset: A TF Dataset with examples to be shuffled and batched.
    seed: The seed that determines the shuffling order, with a different order
      each epoch.
    batch_size: The size of each batch. The remainder is dropped.

  Returns:
    A TF Dataset containing padded batches.
  """
  # For shuffling we need to know how many training examples we have.
  num_examples = dataset.reduce(np.int64(0), lambda x, _: x + 1).numpy()

  # `padded_shapes` says what kind of shapes to expect: [] means a scalar,
  # [-1] means a vector of variable length, and [1] means a vector of size 1.
  return dataset.shuffle(
      num_examples, seed=seed, reshuffle_each_iteration=True).padded_batch(
          batch_size,
          padded_shapes={
              'idx': [],
              'sentence': [-1],
              'label': [1],
              'length': []
          },
          drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
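# A minimal usage sketch on a toy dataset. The feature values below are made
# up for illustration; a real dataset would come from a tokenized corpus with
# the same feature keys ('idx', 'sentence', 'label', 'length').
if __name__ == '__main__':
  toy_examples = {
      'idx': np.arange(4, dtype=np.int64),
      'sentence': tf.ragged.constant(
          [[3, 7, 2], [5, 1], [9, 4, 6, 8], [2, 2]], dtype=tf.int64),
      'label': np.array([[1], [0], [1], [0]], dtype=np.int64),
      'length': np.array([3, 2, 4, 2], dtype=np.int64),
  }
  toy_dataset = tf.data.Dataset.from_tensor_slices(toy_examples)
  for batch in get_shuffled_batches(toy_dataset, seed=42, batch_size=2):
    # 'sentence' is padded with 0s to the longest sentence in its batch,
    # so each batch has shape (2, max_length_in_this_batch).
    print(batch['sentence'].numpy())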
  def pipeline(
      self,
      dataset: tf.data.Dataset,
      input_context: tf.distribute.InputContext = None) -> tf.data.Dataset:
    """Builds a pipeline fetching, shuffling, and preprocessing the dataset.

    Args:
      dataset: A `tf.data.Dataset` that loads raw files.
      input_context: An optional context provided by `tf.distribute` for
        cross-replica training. This isn't necessary if using Keras
        compile/fit.

    Returns:
      A TensorFlow dataset outputting batched images and labels.
    """
    if input_context and input_context.num_input_pipelines > 1:
      dataset = dataset.shard(input_context.num_input_pipelines,
                              input_context.input_pipeline_id)

    if self.is_training and not self.config.cache:
      dataset = dataset.repeat()

    if self.config.builder == 'records':
      # Read the data from disk in parallel.
      buffer_size = 8 * 1024 * 1024  # Use 8 MiB per file.
      dataset = dataset.interleave(
          lambda name: tf.data.TFRecordDataset(name, buffer_size=buffer_size),
          cycle_length=16,
          num_parallel_calls=tf.data.experimental.AUTOTUNE)

    dataset = dataset.prefetch(self.global_batch_size)

    if self.config.cache:
      dataset = dataset.cache()

    if self.is_training:
      dataset = dataset.shuffle(self.config.shuffle_buffer_size)
      dataset = dataset.repeat()

    # Parse, pre-process, and batch the data in parallel.
    if self.config.builder == 'records':
      preprocess = self.parse_record
    else:
      preprocess = self.preprocess
    dataset = dataset.map(
        preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    dataset = dataset.batch(self.batch_size, drop_remainder=self.is_training)

    # Note: we could do image normalization here, but we defer it to the
    # model, which can perform it much faster on a GPU/TPU.
    # TODO(dankondratyuk): if we fix prefetching, we can do it here.

    if self.is_training and self.config.deterministic_train is not None:
      options = tf.data.Options()
      options.experimental_deterministic = self.config.deterministic_train
      options.experimental_slack = self.config.use_slack
      options.experimental_optimization.parallel_batch = True
      options.experimental_optimization.map_fusion = True
      options.experimental_optimization.map_vectorization.enabled = True
      options.experimental_optimization.map_parallelization = True
      dataset = dataset.with_options(options)

    # Prefetch overlaps in-feed with training.
    # Note: autotune here is not recommended, as it can lead to memory leaks.
    # Instead, use a constant prefetch size like the number of devices.
    dataset = dataset.prefetch(self.config.num_devices)

    return dataset
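  # A minimal sketch of wiring this method into a tf.distribute input
  # pipeline. `DatasetBuilder`, its constructor, and `load_files()` are
  # hypothetical names inferred from the attributes used above (self.config,
  # self.is_training, ...), not the actual API of this class:
  #
  #   strategy = tf.distribute.MirroredStrategy()
  #   builder = DatasetBuilder(config, is_training=True)  # hypothetical
  #   dist_dataset = strategy.experimental_distribute_datasets_from_function(
  #       lambda input_context: builder.pipeline(
  #           builder.load_files(), input_context))
  #
  # The strategy passes a `tf.distribute.InputContext` to the callback, which
  # `pipeline` uses above to shard the input across replicas.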