def _repeat_batch(batch_sizes: Sequence[int],
                  ds: tf.data.Dataset,
                  repeat: int = 1) -> tf.data.Dataset:
  """Tiles the innermost batch dimension."""
  if repeat <= 1:
    return ds
  if batch_sizes[-1] % repeat != 0:
    raise ValueError(
        f'The last element of `batch_sizes` ({batch_sizes}) must '
        f'be divisible by `repeat` ({repeat}).')
  # Perform regular batching with a reduced number of elements.
  for i, batch_size in enumerate(reversed(batch_sizes)):
    ds = ds.batch(batch_size // repeat if i == 0 else batch_size,
                  drop_remainder=True)
  # Repeat the innermost batch.
  fn = lambda x: tf.repeat(x, repeats=repeat, axis=len(batch_sizes) - 1)

  def repeat_inner_batch(example):
    return jax.tree_map(fn, example)

  ds = ds.map(repeat_inner_batch, num_parallel_calls=tf.data.AUTOTUNE)
  # Unbatch back to the original element structure.
  for _ in batch_sizes:
    ds = ds.unbatch()
  return ds
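# A minimal usage sketch for `_repeat_batch` above. The demo function below is
# illustrative and not part of the original code; it assumes TensorFlow 2.x
# eager execution. With batch_sizes=[2, 4] and repeat=2, the data is first
# batched to shape [2, 2], the inner batch is tiled back to 4 by repeating
# each example, and the result is unbatched into a flat stream in which every
# element appears `repeat` times in a row.
def _demo_repeat_batch():
  ds = tf.data.Dataset.range(8)
  ds = _repeat_batch(batch_sizes=[2, 4], ds=ds, repeat=2)
  # Yields 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7.
  return list(ds.as_numpy_iterator())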
def _batch(self,
           split: Split,
           dataset: tf.data.Dataset,
           drop_remainder: bool = True) -> tf.data.Dataset:
  """Get the batched version of `dataset`."""
  # `uneven_datasets` is a list of datasets with a number of validation
  # and/or test examples that is not evenly divisible by commonly used
  # batch sizes.
  uneven_datasets = ['criteo', 'svhn']
  if self._is_training(split):
    batch_size = self.batch_size
  elif split == Split.VAL:
    batch_size = self.eval_batch_size
    if (self._num_validation_examples % batch_size != 0 and
        self.name not in uneven_datasets):
      logging.warning(
          'Batch size does not evenly divide the number of validation '
          'examples, cannot ensure static shapes on TPU. Batch size: %d, '
          'validation examples: %d',
          batch_size, self._num_validation_examples)
  else:
    batch_size = self.eval_batch_size
    if (self._num_test_examples % batch_size != 0 and
        self.name not in uneven_datasets):
      logging.warning(
          'Batch size does not evenly divide the number of test examples, '
          'cannot ensure static shapes on TPU. Batch size: %d, test '
          'examples: %d',
          batch_size, self._num_test_examples)
  # Note that we always drop the last batch when the batch size does not
  # evenly divide the number of examples.
  return dataset.batch(batch_size, drop_remainder=drop_remainder)
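# Standalone illustration of the static-shape concern `_batch` warns about;
# this sketch does not depend on the surrounding class and uses plain tf.data.
# With drop_remainder=True, a batch size that does not evenly divide the
# number of examples silently discards the trailing partial batch, which is
# what keeps every batch statically shaped for TPU execution.
def _demo_drop_remainder():
  ds = tf.data.Dataset.range(10).batch(4, drop_remainder=True)
  # Two full batches of 4; the last 2 examples are dropped.
  return [batch.numpy() for batch in ds]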
def pipeline(
    self,
    dataset: tf.data.Dataset,
    input_context: tf.distribute.InputContext = None) -> tf.data.Dataset:
  """Build a pipeline fetching, shuffling, and preprocessing the dataset.

  Args:
    dataset: A `tf.data.Dataset` that loads raw files.
    input_context: An optional context provided by `tf.distribute` for
      cross-replica training. This isn't necessary if using Keras
      compile/fit.

  Returns:
    A TensorFlow dataset outputting batched images and labels.
  """
  if input_context and input_context.num_input_pipelines > 1:
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)

  if self.is_training and not self.config.cache:
    dataset = dataset.repeat()

  if self.config.builder == 'records':
    # Read the data from disk in parallel.
    buffer_size = 8 * 1024 * 1024  # Use 8 MiB per file.
    dataset = dataset.interleave(
        lambda name: tf.data.TFRecordDataset(name, buffer_size=buffer_size),
        cycle_length=16,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

  dataset = dataset.prefetch(self.global_batch_size)

  if self.config.cache:
    dataset = dataset.cache()

  if self.is_training:
    dataset = dataset.shuffle(self.config.shuffle_buffer_size)
    dataset = dataset.repeat()

  # Parse, pre-process, and batch the data in parallel.
  if self.config.builder == 'records':
    preprocess = self.parse_record
  else:
    preprocess = self.preprocess
  dataset = dataset.map(preprocess,
                        num_parallel_calls=tf.data.experimental.AUTOTUNE)
  dataset = dataset.batch(self.batch_size, drop_remainder=self.is_training)

  # Note: we could do image normalization here, but we defer it to the model,
  # which can perform it much faster on a GPU/TPU.
  # TODO(dankondratyuk): if we fix prefetching, we can do it here.
  if self.is_training and self.config.deterministic_train is not None:
    options = tf.data.Options()
    options.experimental_deterministic = self.config.deterministic_train
    options.experimental_slack = self.config.use_slack
    options.experimental_optimization.parallel_batch = True
    options.experimental_optimization.map_fusion = True
    options.experimental_optimization.map_vectorization.enabled = True
    options.experimental_optimization.map_parallelization = True
    dataset = dataset.with_options(options)

  # Prefetch overlaps in-feed with training.
  # Note: autotune here is not recommended, as it can lead to memory leaks.
  # Instead, use a constant prefetch size such as the number of devices.
  dataset = dataset.prefetch(self.config.num_devices)

  return dataset
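# Hedged sketch of how a pipeline like the one above is typically consumed
# with tf.distribute; the names below are illustrative and not part of the
# original code. The strategy invokes the dataset function once per input
# pipeline and supplies a tf.distribute.InputContext, which is what drives the
# `dataset.shard(...)` call at the top of `pipeline`.
def _demo_distributed_input():
  strategy = tf.distribute.MirroredStrategy()

  def dataset_fn(ctx: tf.distribute.InputContext) -> tf.data.Dataset:
    ds = tf.data.Dataset.range(64)
    if ctx.num_input_pipelines > 1:
      ds = ds.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
    return ds.batch(ctx.get_per_replica_batch_size(16), drop_remainder=True)

  return strategy.distribute_datasets_from_function(dataset_fn)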