# Example #1
# 0
 def test_cached_batch(self):
     """Reading a batch should leave an identical copy in the reader's cache."""
     batch_limit = 100
     dataset = tf.random_uniform(shape=(batch_limit, 784), maxval=1.)
     reader = data_reader.CachedDataReader((dataset,), batch_limit)
     # Scalar placeholder so the batch size can vary at run time.
     size_ph = tf.placeholder(shape=(), dtype=tf.int32, name='cur_batch_size')
     # Build the read op up front so its variables exist before initialization.
     batch_op = reader(size_ph)[0]
     with self.test_session() as sess:
         sess.run(tf.global_variables_initializer())
         coord = tf.train.Coordinator()
         tf.train.start_queue_runners(sess=sess, coord=coord)
         fetched = sess.run(batch_op, feed_dict={size_ph: 25})
         cached = sess.run(reader.cached_batch)[0]
         # The cached copy must match what was just read (spot-checks row 1).
         self.assertListEqual(list(fetched[1]), list(cached[1]))
# Example #2
# 0
 def test_read_batch(self):
     """Reads batches of several sizes and checks the returned shapes."""
     batch_limit = 10
     schedule = [2, 4, 6, 8]
     dataset = tf.random_uniform(shape=(batch_limit, 784), maxval=1.)
     reader = data_reader.CachedDataReader((dataset,), batch_limit)
     # Scalar placeholder so the batch size can vary between session runs.
     size_ph = tf.placeholder(shape=(), dtype=tf.int32, name='cur_batch_size')
     # Build the read op up front so its variables exist before initialization.
     batch_op = reader(size_ph)[0]
     with self.test_session() as sess:
         sess.run(tf.global_variables_initializer())
         coord = tf.train.Coordinator()
         tf.train.start_queue_runners(sess=sess, coord=coord)
         for requested in schedule:
             fetched = sess.run(batch_op, feed_dict={size_ph: requested})
             # Row count follows the requested size; feature width stays 784.
             self.assertEqual(len(fetched), requested)
             self.assertEqual(len(fetched[0]), 784)
# Example #3
# 0
def load_mnist():
    """Creates MNIST dataset and wraps it inside cached data reader.

    Returns:
      cached_reader: `data_reader.CachedReader` instance which wraps MNIST
        dataset.
      num_examples: int. The number of training examples.
    """
    # Wrap the data set into a cached reader, which provides variable-sized
    # reads and caches the batch it last served.
    if FLAGS.use_alt_data_reader:
        # Version 2 using data_reader_alt.py (faster).
        images, labels, num_examples = mnist.load_mnist_as_tensors(
            flatten_images=False)
        # This version of CachedDataReader requires the dataset to NOT be
        # shuffled.
        reader = data_reader_alt.CachedDataReader((images, labels),
                                                  num_examples)
        return reader, num_examples

    # Version 1 using data_reader.py (slow!).
    dataset, num_examples = mnist.load_mnist_as_dataset(flatten_images=False)
    max_batch_size = (num_examples
                      if FLAGS.use_batch_size_schedule else FLAGS.batch_size)

    # Shuffle before repeat is correct unless you want repeat cases in the
    # same batch.
    dataset = (dataset.shuffle(num_examples).repeat()
               .batch(max_batch_size).prefetch(5))
    batch = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()

    # This version of CachedDataReader requires the dataset to be shuffled.
    return data_reader.CachedDataReader(batch, max_batch_size), num_examples
def load_mnist(batch_size):
    """Creates MNIST dataset and wraps it inside cached data reader.

    Args:
      batch_size: Scalar placeholder variable which needs to fed to read
        variable sized training data.

    Returns:
      cached_reader: `data_reader.CachedReader` instance which wraps MNIST
        dataset.
      training_batch: Tensor of shape `[batch_size, 784]`, MNIST training
        images.
    """
    # Read a fixed maximum-size (_BATCH_SIZE) training batch from MNIST.
    data_set = mnist.load_mnist(FLAGS.data_dir,
                                num_epochs=FLAGS.num_epochs,
                                batch_size=_BATCH_SIZE,
                                flatten_images=True)
    # The cached reader serves variable-sized slices of that batch and
    # remembers the batch it last served.
    cached_reader = data_reader.CachedDataReader(data_set, _BATCH_SIZE)
    return cached_reader, cached_reader(batch_size)[0]