def test_cached_batch(self):
  max_batch_size = 100
  data_set = tf.random_uniform(shape=(max_batch_size, 784), maxval=1.)
  var_data = data_reader.CachedDataReader((data_set,), max_batch_size)
  cur_batch_size = tf.placeholder(
      shape=(), dtype=tf.int32, name='cur_batch_size')

  # Force create the ops.
  data = var_data(cur_batch_size)[0]
  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    tf.train.start_queue_runners(sess=sess, coord=coord)
    data_ = sess.run(data, feed_dict={cur_batch_size: 25})
    stored_data_ = sess.run(var_data.cached_batch)[0]
    # The cached batch should hold exactly the examples just read above.
    self.assertListEqual(list(data_[1]), list(stored_data_[1]))
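
# The point of `cached_batch`, as the test above exercises, is to let a second
# op re-read exactly the batch that the most recent `var_data(...)` run
# returned, without consuming new data. A minimal sketch of that pattern;
# `build_loss` is a hypothetical model function, not part of the data_reader
# API:


def sketch_reuse_cached_batch(var_data, cur_batch_size, build_loss):
  # Loss on a freshly read batch; running this advances the reader and
  # refreshes the cache.
  fresh_loss = build_loss(var_data(cur_batch_size)[0])
  # Loss on whatever batch is currently cached; running this reads no new data.
  cached_loss = build_loss(var_data.cached_batch[0])

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    tf.train.start_queue_runners(sess=sess, coord=coord)
    sess.run(fresh_loss, feed_dict={cur_batch_size: 25})
    sess.run(cached_loss)  # Operates on the same 25 examples as the run above.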
def test_read_batch(self):
  max_batch_size = 10
  batch_size_schedule = [2, 4, 6, 8]
  data_set = tf.random_uniform(shape=(max_batch_size, 784), maxval=1.)
  var_data = data_reader.CachedDataReader((data_set,), max_batch_size)
  cur_batch_size = tf.placeholder(
      shape=(), dtype=tf.int32, name='cur_batch_size')

  # Force create the ops.
  data = var_data(cur_batch_size)[0]
  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    tf.train.start_queue_runners(sess=sess, coord=coord)
    for batch_size in batch_size_schedule:
      data_ = sess.run(data, feed_dict={cur_batch_size: batch_size})
      self.assertEqual(len(data_), batch_size)
      self.assertEqual(len(data_[0]), 784)
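
# In real training code the value fed for `cur_batch_size` usually comes from
# a schedule keyed to the global step rather than a list walked once. An
# illustrative sketch: the stage length of 100 steps is an assumption of this
# sketch, not something the reader prescribes.


def sketch_batch_size_schedule(sess, data, cur_batch_size):
  batch_size_schedule = [2, 4, 6, 8]
  steps_per_stage = 100  # Hypothetical: move to the next size every 100 steps.
  for step in range(400):
    stage = min(step // steps_per_stage, len(batch_size_schedule) - 1)
    sess.run(data, feed_dict={cur_batch_size: batch_size_schedule[stage]})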
def load_mnist():
  """Creates MNIST dataset and wraps it inside cached data reader.

  Returns:
    cached_reader: `data_reader.CachedReader` instance which wraps MNIST
      dataset.
    num_examples: int. The number of training examples.
  """
  # Wrap the data set into cached_reader which provides variable sized training
  # and caches the read train batch.
  if not FLAGS.use_alt_data_reader:
    # Version 1 using data_reader.py (slow!)
    dataset, num_examples = mnist.load_mnist_as_dataset(flatten_images=False)
    if FLAGS.use_batch_size_schedule:
      max_batch_size = num_examples
    else:
      max_batch_size = FLAGS.batch_size

    # Shuffling before repeating is correct unless you want repeated examples
    # in the same batch.
    dataset = (dataset.shuffle(num_examples).repeat()
               .batch(max_batch_size).prefetch(5))
    dataset = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()

    # This version of CachedDataReader requires the dataset to be shuffled.
    return data_reader.CachedDataReader(dataset, max_batch_size), num_examples
  else:
    # Version 2 using data_reader_alt.py (faster).
    images, labels, num_examples = mnist.load_mnist_as_tensors(
        flatten_images=False)
    dataset = (images, labels)

    # This version of CachedDataReader requires the dataset to NOT be shuffled.
    return (data_reader_alt.CachedDataReader(dataset, num_examples),
            num_examples)
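
# Either branch hands back a reader with the same calling convention, so
# downstream code only needs the reader and the example count. A sketch of how
# the return values might be consumed; the placeholder name and session setup
# here are illustrative, not part of this module:


def sketch_consume_cached_reader():
  cached_reader, num_examples = load_mnist()
  batch_size = tf.placeholder(shape=(), dtype=tf.int32, name='batch_size')
  images, labels = cached_reader(batch_size)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Draw one minibatch of FLAGS.batch_size examples.
    images_, labels_ = sess.run((images, labels),
                                feed_dict={batch_size: FLAGS.batch_size})
    print('Read %d of %d examples.' % (len(images_), num_examples))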
def load_mnist(batch_size):
  """Creates MNIST dataset and wraps it inside cached data reader.

  Args:
    batch_size: Scalar placeholder which needs to be fed to read variable
      sized training data.

  Returns:
    cached_reader: `data_reader.CachedReader` instance which wraps MNIST
      dataset.
    training_batch: Tensor of shape `[batch_size, 784]`, MNIST training
      images.
  """
  # Create an MNIST data batch with the maximum training batch size.
  # data_set = datasets.Mnist(batch_size=_BATCH_SIZE, mode='train')()
  data_set = mnist.load_mnist(FLAGS.data_dir,
                              num_epochs=FLAGS.num_epochs,
                              batch_size=_BATCH_SIZE,
                              flatten_images=True)
  # Wrap the data set into cached_reader which provides variable sized training
  # and caches the read train batch.
  cached_reader = data_reader.CachedDataReader(data_set, _BATCH_SIZE)
  return cached_reader, cached_reader(batch_size)[0]
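
# Usage then reduces to creating the placeholder, building the reader once,
# and feeding the desired size each step. A sketch under those assumptions;
# the size 32 is arbitrary (it must not exceed _BATCH_SIZE), and the local
# variables initializer is included because TF1 input pipelines built with
# `num_epochs` keep their epoch counter in a local variable.


def sketch_feed_batch_size():
  batch_size = tf.placeholder(shape=(), dtype=tf.int32, name='batch_size')
  cached_reader, training_batch = load_mnist(batch_size)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    tf.train.start_queue_runners(sess=sess, coord=coord)
    # Read 32 of the up-to-_BATCH_SIZE queued examples this step.
    images_ = sess.run(training_batch, feed_dict={batch_size: 32})
    # Re-read the same 32 examples from the cache, consuming no new data.
    stored_ = sess.run(cached_reader.cached_batch)[0]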