def mnist_input_fn(params):
  dataset, num_examples = mnist.load_mnist_as_dataset(flatten_images=True)
  # Shuffling before repeating is correct unless you want repeated examples
  # in the same batch.
  dataset = (dataset.shuffle(num_examples)
             .repeat()
             .batch(params['batch_size'], drop_remainder=True)
             .prefetch(tf.data.experimental.AUTOTUNE))
  return dataset
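

# A minimal usage sketch (an assumption, not part of the original code): the
# `params['batch_size']` lookup plus `drop_remainder=True` follow the
# TPUEstimator input_fn convention, in which the estimator injects the
# per-shard batch size through `params`. Outside an estimator, the input fn
# can be driven by hand; the helper below is hypothetical.
def _example_fetch_one_batch():
  """Hypothetical helper: pulls a single (images, labels) batch as tensors."""
  dataset = mnist_input_fn({'batch_size': 64})
  iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
  return iterator.get_next()
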
def load_mnist():
  """Creates MNIST dataset and wraps it inside a cached data reader.

  Returns:
    cached_reader: `data_reader.CachedDataReader` instance which wraps the
      MNIST dataset.
    num_examples: int. The number of training examples.
  """
  # Wrap the dataset in a cached data reader, which provides variable-sized
  # training batches and caches the most recently read batch.
  if not FLAGS.use_alt_data_reader:
    # Version 1 using data_reader.py (slow!)
    dataset, num_examples = mnist.load_mnist_as_dataset(flatten_images=False)
    if FLAGS.use_batch_size_schedule:
      max_batch_size = num_examples
    else:
      max_batch_size = FLAGS.batch_size

    # Shuffling before repeating is correct unless you want repeated examples
    # in the same batch.
    dataset = (dataset.shuffle(num_examples)
               .repeat()
               .batch(max_batch_size)
               .prefetch(5))
    dataset = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()

    # This version of CachedDataReader requires the dataset to be shuffled.
    return data_reader.CachedDataReader(dataset, max_batch_size), num_examples
  else:
    # Version 2 using data_reader_alt.py (faster)
    images, labels, num_examples = mnist.load_mnist_as_tensors(
        flatten_images=False)
    dataset = (images, labels)

    # This version of CachedDataReader requires the dataset to NOT be
    # shuffled.
    return (data_reader_alt.CachedDataReader(dataset, num_examples),
            num_examples)
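

# A sketch of the intended call site (assumed; the callable interface of
# `CachedDataReader` is inferred from its role as a variable-batch-size
# reader and is not defined in this file):
#
#   cached_reader, num_examples = load_mnist()
#   batch_size = tf.compat.v1.placeholder(shape=(), dtype=tf.int32)
#   images, labels = cached_reader(batch_size)
#
# With FLAGS.use_batch_size_schedule enabled, `batch_size` can then be fed a
# growing value each step while the reader serves slices of the cached
# maximal batch.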