예제 #1
0
def mnist_input_fn(params):
    """Builds the batched MNIST input pipeline.

    Args:
      params: Subscriptable object; `params['batch_size']` is used as the
        batch size.

    Returns:
      The dataset from `mnist.load_mnist_as_dataset(flatten_images=True)`
      after shuffling, repeating, batching with `drop_remainder=True`, and
      prefetching with `tf.data.experimental.AUTOTUNE`.
    """
    examples, num_examples = mnist.load_mnist_as_dataset(flatten_images=True)

    # Shuffling ahead of repeat() is the right order unless repeated cases
    # within a single batch are acceptable.
    shuffled = examples.shuffle(num_examples)
    repeated = shuffled.repeat()
    batched = repeated.batch(params['batch_size'], drop_remainder=True)
    return batched.prefetch(tf.data.experimental.AUTOTUNE)
예제 #2
0
파일: rnn_mnist.py 프로젝트: phymucs/kfac
def load_mnist():
    """Loads MNIST and wraps it in a cached data reader.

    `FLAGS.use_alt_data_reader` selects which reader implementation is used.

    Returns:
      A `(cached_reader, num_examples)` tuple.
      cached_reader: `data_reader_alt.CachedDataReader` when
        `FLAGS.use_alt_data_reader` is set, otherwise
        `data_reader.CachedDataReader`.
      num_examples: the example count reported by the MNIST loader.
    """
    if FLAGS.use_alt_data_reader:
        # Version 2, built on data_reader_alt.py (the faster path).
        images, labels, num_examples = mnist.load_mnist_as_tensors(
            flatten_images=False)

        # This CachedDataReader variant requires data that is NOT shuffled,
        # so the raw (images, labels) pair is handed over as-is.
        alt_reader = data_reader_alt.CachedDataReader((images, labels),
                                                      num_examples)
        return alt_reader, num_examples

    # Version 1, built on data_reader.py (slow!). The cached reader provides
    # variable sized training batches and caches the batch it has read.
    examples, num_examples = mnist.load_mnist_as_dataset(flatten_images=False)

    # With a batch size schedule the largest possible batch is the whole
    # dataset; otherwise it is the fixed batch size from the flags.
    max_batch_size = (
        num_examples if FLAGS.use_batch_size_schedule else FLAGS.batch_size)

    # Shuffling ahead of repeat() is the right order unless repeated cases
    # within a single batch are acceptable.
    pipeline = examples.shuffle(num_examples)
    pipeline = pipeline.repeat()
    pipeline = pipeline.batch(max_batch_size)
    pipeline = pipeline.prefetch(5)
    next_batch = tf.compat.v1.data.make_one_shot_iterator(pipeline).get_next()

    # This CachedDataReader variant requires the dataset to be shuffled.
    reader = data_reader.CachedDataReader(next_batch, max_batch_size)
    return reader, num_examples