Code example #1
def save_gradients(model, epoch, model_dir, data_dir):
    # Calculate number of parameters
    num_params = sum(_size(variable) for variable in model.variables)
    num_inputs = 600

    # Memory mapped file to store gradients in
    filename = os.path.join(model_dir, 'gradients_{}.dat'.format(epoch))
    gradients = np.memmap(filename=filename,
                          dtype=np.float32,
                          mode='w+',
                          shape=(num_inputs, num_params))

    with tf.device('/cpu:0'):
        train_ds = mnist_dataset.train(data_dir).batch(1)
    for i, (image, label) in enumerate(train_ds):
        with tf.GradientTape() as tape:
            logit = model(image, training=False)
            loss_value = loss(logit, label)
        grads = tape.gradient(loss_value, model.variables)
        j = 0
        for grad in grads:
            s = _size(grad)
            gradients[i, j:j + s] = tf.reshape(grad, [-1])
            j += s
        assert j == num_params
    assert i + 1 == num_inputs
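The `_size` and `loss` helpers used above are defined elsewhere and not shown in this excerpt; a plausible minimal sketch (the exact definitions are assumptions) might look like this:

# Assumed helpers for the snippet above; not part of the original excerpt.
import numpy as np
import tensorflow as tf

def _size(tensor):
    # Total number of scalar elements in a variable or gradient tensor.
    return int(np.prod(tensor.shape.as_list()))

def loss(logits, labels):
    # Average sparse softmax cross-entropy over the batch.
    return tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=labels))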
Code example #2
File: mnist_eager.py  Project: zzusunjs/models
def main(_):
    tfe.enable_eager_execution()

    # Automatically determine device and data_format
    (device, data_format) = ('/gpu:0', 'channels_first')
    if FLAGS.no_gpu or tfe.num_gpus() <= 0:
        (device, data_format) = ('/cpu:0', 'channels_last')
    # If data_format is defined in FLAGS, overwrite automatically set value.
    if FLAGS.data_format is not None:
        data_format = FLAGS.data_format
    print('Using device %s, and data format %s.' % (device, data_format))

    # Load the datasets
    train_ds = mnist_dataset.train(FLAGS.data_dir).shuffle(60000).batch(
        FLAGS.batch_size)
    test_ds = mnist_dataset.test(FLAGS.data_dir).batch(FLAGS.batch_size)

    # Create the model and optimizer
    model = mnist.Model(data_format)
    optimizer = tf.train.MomentumOptimizer(FLAGS.lr, FLAGS.momentum)

    # Create file writers for writing TensorBoard summaries.
    if FLAGS.output_dir:
        # Create directories to which summaries will be written
        # tensorboard --logdir=<output_dir>
        # can then be used to see the recorded summaries.
        train_dir = os.path.join(FLAGS.output_dir, 'train')
        test_dir = os.path.join(FLAGS.output_dir, 'eval')
        tf.gfile.MakeDirs(FLAGS.output_dir)
    else:
        train_dir = None
        test_dir = None
    summary_writer = tf.contrib.summary.create_file_writer(train_dir,
                                                           flush_millis=10000)
    test_summary_writer = tf.contrib.summary.create_file_writer(
        test_dir, flush_millis=10000, name='test')

    # Create and restore checkpoint (if one exists on the path)
    checkpoint_prefix = os.path.join(FLAGS.model_dir, 'ckpt')
    step_counter = tf.train.get_or_create_global_step()
    checkpoint = tfe.Checkpoint(model=model,
                                optimizer=optimizer,
                                step_counter=step_counter)
    # Restore variables on creation if a checkpoint exists.
    checkpoint.restore(tf.train.latest_checkpoint(FLAGS.model_dir))

    # Train and evaluate for a set number of epochs.
    with tf.device(device):
        for _ in range(FLAGS.train_epochs):
            start = time.time()
            with summary_writer.as_default():
                train(model, optimizer, train_ds, step_counter,
                      FLAGS.log_interval)
            end = time.time()
            print('\nTrain time for epoch #%d (%d total steps): %f' %
                  (checkpoint.save_counter.numpy() + 1, step_counter.numpy(),
                   end - start))
            with test_summary_writer.as_default():
                test(model, test_ds)
            checkpoint.save(checkpoint_prefix)
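The `train` and `test` helpers called inside the device scope are not part of this excerpt; a minimal eager-mode sketch of what they typically do (metric names and logging details are assumptions):

def train(model, optimizer, dataset, step_counter, log_interval=None):
    # One pass over the training dataset with eager gradient updates.
    for batch, (images, labels) in enumerate(dataset):
        with tf.GradientTape() as tape:
            logits = model(images, training=True)
            loss_value = loss(logits, labels)
        grads = tape.gradient(loss_value, model.variables)
        optimizer.apply_gradients(zip(grads, model.variables),
                                  global_step=step_counter)
        if log_interval and batch % log_interval == 0:
            print('Batch #%d\tLoss: %.6f' % (batch, loss_value.numpy()))


def test(model, dataset):
    # Average loss and accuracy over the eval dataset.
    avg_loss = tfe.metrics.Mean('loss')
    accuracy = tfe.metrics.Accuracy('accuracy')
    for images, labels in dataset:
        logits = model(images, training=False)
        avg_loss(loss(logits, labels))
        accuracy(tf.argmax(logits, axis=1, output_type=tf.int64),
                 tf.cast(labels, tf.int64))
    print('Eval loss: %.6f, accuracy: %.2f%%' %
          (avg_loss.result().numpy(), 100 * accuracy.result().numpy()))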
Code example #3
def get_inputs(mode, batch_size=64):
    """
    Get batched (features, labels) from mnist.

    Args:
        `mode`: string representing mode of inputs.
            Should be one of {"train", "eval", "predict", "infer"}
        `batch_size`: number of examples in each returned batch.

    Returns:
        `features`: float32 tensor of shape (batch_size, 28, 28, 1) with
            grayscale values between 0 and 1.
        `labels`: int32 tensor of shape (batch_size,) with labels indicating
            the digit shown in `features`.
    """
    # Get the base dataset
    if mode == ModeKeys.TRAIN:
        dataset = ds.train('/tmp/mnist_data')
    elif mode in {ModeKeys.PREDICT, ModeKeys.EVAL}:
        dataset = ds.test('/tmp/mnist_data')
    else:
        raise ValueError('mode must be one in ModeKeys')

    # repeat and shuffle if training
    if mode == ModeKeys.TRAIN:
        dataset = dataset.repeat()  # repeat indefinitely
        dataset = dataset.shuffle(buffer_size=10000)

    dataset = dataset.batch(batch_size)

    image, labels = dataset.make_one_shot_iterator().get_next()
    image = tf.cast(tf.reshape(image, (-1, 28, 28, 1)), tf.float32)
    return image, labels
Code example #4
 def train_input_fn():
     # When choosing shuffle buffer sizes, larger sizes result in better
     # randomness, while smaller sizes use less memory. MNIST is a small
     # enough dataset that we can easily shuffle the full epoch.
     ds = dataset.train(FLAGS.data_dir)
     ds = ds.cache().shuffle(buffer_size=50000).batch(
         FLAGS.batch_size).repeat(FLAGS.train_epochs)
     return ds
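An input_fn like this is normally handed to an Estimator; a hedged usage sketch (the model_fn and model_dir below are illustrative assumptions, not part of the snippet):

# Illustrative usage only; mnist.model_fn and FLAGS.model_dir are assumptions.
mnist_classifier = tf.estimator.Estimator(
    model_fn=mnist.model_fn, model_dir=FLAGS.model_dir)
mnist_classifier.train(input_fn=train_input_fn)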
Code example #5
File: mnist.py  Project: icemansina/models
 def train_input_fn():
   # When choosing shuffle buffer sizes, larger sizes result in better
   # randomness, while smaller sizes use less memory. MNIST is a small
   # enough dataset that we can easily shuffle the full epoch.
   ds = dataset.train(FLAGS.data_dir)
   ds = ds.cache().shuffle(buffer_size=50000).batch(FLAGS.batch_size).repeat(
       FLAGS.train_epochs)
   return ds
Code example #6
File: mnist.py  Project: scpwais/au2018
    def train_input_fn():
        from official.mnist import dataset as mnist_dataset

        # Load the datasets
        train_ds = mnist_dataset.train(params.DATA_BASEDIR)
        if params.LIMIT >= 0:
            train_ds = train_ds.take(params.LIMIT)
        train_ds = train_ds.shuffle(60000).batch(params.BATCH_SIZE)
        return train_ds
Code example #7
File: mnist_tpu.py  Project: Exscotticus/models
def train_input_fn(params):
  """train_input_fn defines the input pipeline used for training."""
  batch_size = params["batch_size"]
  data_dir = params["data_dir"]
  # Retrieves the batch size for the current shard. The # of shards is
  # computed according to the input pipeline deployment. See
  # `tf.contrib.tpu.RunConfig` for details.
  ds = dataset.train(data_dir).cache().repeat().shuffle(
      buffer_size=50000).batch(batch_size, drop_remainder=True)
  return ds
Code example #8
def train_input_fn(params):
    """train_input_fn defines the input pipeline used for training."""
    batch_size = params["batch_size"]
    data_dir = params["data_dir"]
    # Retrieves the batch size for the current shard. The # of shards is
    # computed according to the input pipeline deployment. See
    # `tf.contrib.tpu.RunConfig` for details.
    ds = dataset.train(data_dir).cache().repeat().shuffle(
        buffer_size=50000).batch(batch_size, drop_remainder=True)
    return ds
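On TPU the `params` dict (including the per-shard `batch_size`) is filled in by `tf.contrib.tpu.TPUEstimator` rather than by the caller; a hedged sketch of that wiring (every constructor argument below is illustrative):

# Illustrative only: TPUEstimator injects params["batch_size"] into the
# input_fn; the concrete values here are assumptions.
run_config = tf.contrib.tpu.RunConfig(
    tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=100))
estimator = tf.contrib.tpu.TPUEstimator(
    model_fn=model_fn,
    use_tpu=True,
    train_batch_size=1024,
    config=run_config,
    params={"data_dir": data_dir})
estimator.train(input_fn=train_input_fn, max_steps=1000)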
Code example #9
File: mnist_1.7.0.py  Project: laurafdeza/training
    def train_input_fn():
        # When choosing shuffle buffer sizes, larger sizes result in better
        # randomness, while smaller sizes use less memory. MNIST is a small
        # enough dataset that we can easily shuffle the full epoch.
        ds = dataset.train(FLAGS.data_dir)
        ds = ds.cache().shuffle(buffer_size=50000).batch(FLAGS.batch_size)

        # Iterate through the dataset a set number (`epochs_between_evals`) of times
        # during each training session.
        ds = ds.repeat(FLAGS.epochs_between_evals)
        return ds
Code example #10
File: mnist.py  Project: jostosh/gan
def prepare_input_pipeline(flags_obj):
    with tf.name_scope("InputPipeline"):
        ds = mnist.train("/home/jos/datasets/mnist")
        ds = ds.cache() \
            .shuffle(buffer_size=50000) \
            .batch(flags_obj.batch_size) \
            .repeat()\
            .make_one_shot_iterator()
        images, _ = ds.get_next()
        # Reshape and rescale to [-1, 1]
        return tf.reshape(images, (-1, 28, 28, 1)) * 2.0 - 1.0
Code example #11
def train_input_fn(params):
    """train_input_fn defines the input pipeline used for training."""
    batch_size = params["batch_size"]
    data_dir = params["data_dir"]
    # Retrieves the batch size for the current shard. The # of shards is
    # computed according to the input pipeline deployment. See
    # `tf.contrib.tpu.RunConfig` for details.
    ds = dataset.train(data_dir).cache().repeat().shuffle(
        buffer_size=50000).apply(
            tf.contrib.data.batch_and_drop_remainder(batch_size))
    images, labels = ds.make_one_shot_iterator().get_next()
    return images, labels
Code example #12
File: mnist_tpu.py  Project: Toyben/models
def train_input_fn(params):
  """train_input_fn defines the input pipeline used for training."""
  batch_size = params["batch_size"]
  data_dir = params["data_dir"]
  # Retrieves the batch size for the current shard. The # of shards is
  # computed according to the input pipeline deployment. See
  # `tf.contrib.tpu.RunConfig` for details.
  ds = dataset.train(data_dir).cache().repeat().shuffle(
      buffer_size=50000).apply(
          tf.contrib.data.batch_and_drop_remainder(batch_size))
  images, labels = ds.make_one_shot_iterator().get_next()
  return images, labels
Code example #13
File: mnist.py  Project: 812864539/models
  def train_input_fn():
    """Prepare data for training."""

    # When choosing shuffle buffer sizes, larger sizes result in better
    # randomness, while smaller sizes use less memory. MNIST is a small
    # enough dataset that we can easily shuffle the full epoch.
    ds = dataset.train(flags_obj.data_dir)
    ds = ds.cache().shuffle(buffer_size=50000).batch(flags_obj.batch_size)

    # Iterate through the dataset a set number (`epochs_between_evals`) of times
    # during each training session.
    ds = ds.repeat(flags_obj.epochs_between_evals)
    return ds
Code example #14
File: mnist_eager.py  Project: icemansina/models
def main(_):
  tfe.enable_eager_execution()

  (device, data_format) = ('/gpu:0', 'channels_first')
  if FLAGS.no_gpu or tfe.num_gpus() <= 0:
    (device, data_format) = ('/cpu:0', 'channels_last')
  print('Using device %s, and data format %s.' % (device, data_format))

  # Load the datasets
  train_ds = dataset.train(FLAGS.data_dir).shuffle(60000).batch(
      FLAGS.batch_size)
  test_ds = dataset.test(FLAGS.data_dir).batch(FLAGS.batch_size)

  # Create the model and optimizer
  model = mnist.Model(data_format)
  optimizer = tf.train.MomentumOptimizer(FLAGS.lr, FLAGS.momentum)

  if FLAGS.output_dir:
    # Create directories to which summaries will be written
    # tensorboard --logdir=<output_dir>
    # can then be used to see the recorded summaries.
    train_dir = os.path.join(FLAGS.output_dir, 'train')
    test_dir = os.path.join(FLAGS.output_dir, 'eval')
    tf.gfile.MakeDirs(FLAGS.output_dir)
  else:
    train_dir = None
    test_dir = None
  summary_writer = tf.contrib.summary.create_file_writer(
      train_dir, flush_millis=10000)
  test_summary_writer = tf.contrib.summary.create_file_writer(
      test_dir, flush_millis=10000, name='test')
  checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt')
  step_counter = tf.train.get_or_create_global_step()
  checkpoint = tfe.Checkpoint(
      model=model, optimizer=optimizer, step_counter=step_counter)
  # Restore variables on creation if a checkpoint exists.
  checkpoint.restore(tf.train.latest_checkpoint(FLAGS.checkpoint_dir))
  # Train and evaluate for 10 epochs.
  with tf.device(device):
    for _ in range(10):
      start = time.time()
      with summary_writer.as_default():
        train(model, optimizer, train_ds, step_counter, FLAGS.log_interval)
      end = time.time()
      print('\nTrain time for epoch #%d (%d total steps): %f' %
            (checkpoint.save_counter.numpy() + 1,
             step_counter.numpy(),
             end - start))
      with test_summary_writer.as_default():
        test(model, test_ds)
      checkpoint.save(checkpoint_prefix)
Code example #15
File: mnist.py  Project: jsntsay/models
  def train_input_fn():
    """Prepare data for training."""

    # When choosing shuffle buffer sizes, larger sizes result in better
    # randomness, while smaller sizes use less memory. MNIST is a small
    # enough dataset that we can easily shuffle the full epoch.
    ds = dataset.train(flags_obj.data_dir)
    ##### HYPERPARAMETER
    ##### buffer_size here seems to be a hyperparameter given the above comment
    ds = ds.cache().shuffle(buffer_size=50000).batch(flags_obj.batch_size)

    # Iterate through the dataset a set number (`epochs_between_evals`) of times
    # during each training session.
    ds = ds.repeat(flags_obj.epochs_between_evals)
    return ds
Code example #16
File: mnist.py  Project: scpwais/au2018
    def datasets_iter_image_rows(cls, params=None):
        params = params or MNIST.Params()

        log = util.create_log()

        def gen_dataset(ds, split):
            import imageio
            import numpy as np

            n = 0
            with util.tf_data_session(ds) as (sess, iter_dataset):
                for image, label in iter_dataset():
                    image = np.reshape(image * 255.,
                                       (28, 28, 1)).astype(np.uint8)
                    label = int(label)
                    row = dataset.ImageRow.from_np_img_labels(
                        image,
                        label,
                        dataset=cls.TABLE_NAME,
                        split=split,
                        uri='mnist_%s_%s' % (split, n))
                    yield row

                    if params.LIMIT >= 0 and n == params.LIMIT:
                        break
                    n += 1
                    if n % 100 == 0:
                        log.info("Read %s records from tf.Dataset" % n)

        from official.mnist import dataset as mnist_dataset

        # Keep our dataset ops in an isolated graph
        g = tf.Graph()
        with g.as_default():
            gens = itertools.chain(
                gen_dataset(mnist_dataset.train(params.DATA_BASEDIR), 'train'),
                gen_dataset(mnist_dataset.test(params.DATA_BASEDIR), 'test'))
            for row in gens:
                yield row
Code example #17
# pylint: disable=g-bad-import-order
from absl import app as absl_app
from absl import flags
import tensorflow as tf
# pylint: enable=g-bad-import-order

from official.mnist import dataset as mnist_dataset
from official.mnist import mnist
from official.utils.flags import core as flags_core
from official.utils.misc import model_helpers

# Switch to eager mode
tf.enable_eager_execution()
tfe = tf.contrib.eager

train_ds = mnist_dataset.train(".").shuffle(60000).batch(128)
test_ds = mnist_dataset.test(".").batch(128)

import ipdb
ipdb.set_trace()

model = mnist.create_model('channels_last')
optimizer = tf.train.MomentumOptimizer(0.01, 0.01)

train_dir = "train"
test_dir = "test"


def loss(logits, labels):
    return tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=labels))
Code example #18
def run_mnist_eager(flags_obj):
    """Run MNIST training and eval loop in eager mode.

  Args:
    flags_obj: An object containing parsed flag values.
  """
    tf.enable_eager_execution()
    model_helpers.apply_clean(flags.FLAGS)

    # Automatically determine device and data_format
    (device, data_format) = ('/gpu:0', 'channels_first')
    if flags_obj.no_gpu or not tf.test.is_gpu_available():
        (device, data_format) = ('/cpu:0', 'channels_last')
    # If data_format is defined in FLAGS, overwrite automatically set value.
    if flags_obj.data_format is not None:
        data_format = flags_obj.data_format
    print('Using device %s, and data format %s.' % (device, data_format))

    # Load the datasets
    train_ds = mnist_dataset.train(flags_obj.data_dir).shuffle(60000).batch(
        flags_obj.batch_size)
    test_ds = mnist_dataset.test(flags_obj.data_dir).batch(
        flags_obj.batch_size)

    # Create the model and optimizer
    model = mnist.create_model(data_format)
    optimizer = tf.train.MomentumOptimizer(flags_obj.lr, flags_obj.momentum)

    # Create file writers for writing TensorBoard summaries.
    if flags_obj.output_dir:
        # Create directories to which summaries will be written
        # tensorboard --logdir=<output_dir>
        # can then be used to see the recorded summaries.
        train_dir = os.path.join(flags_obj.output_dir, 'train')
        test_dir = os.path.join(flags_obj.output_dir, 'eval')
        tf.gfile.MakeDirs(flags_obj.output_dir)
    else:
        train_dir = None
        test_dir = None
    summary_writer = tf.contrib.summary.create_file_writer(train_dir,
                                                           flush_millis=10000)
    test_summary_writer = tf.contrib.summary.create_file_writer(
        test_dir, flush_millis=10000, name='test')

    # Create and restore checkpoint (if one exists on the path)
    checkpoint_prefix = os.path.join(flags_obj.model_dir, 'ckpt')
    step_counter = tf.train.get_or_create_global_step()
    checkpoint = tf.train.Checkpoint(model=model,
                                     optimizer=optimizer,
                                     step_counter=step_counter)
    # Restore variables on creation if a checkpoint exists.
    checkpoint.restore(tf.train.latest_checkpoint(flags_obj.model_dir))

    # Train and evaluate for a set number of epochs.
    with tf.device(device):
        for epoch in range(flags_obj.train_epochs):
            if epoch == flags_obj.save_gradients_epoch:
                print('Dumping gradient matrix to {}'.format(
                    flags_obj.model_dir))
                save_gradients(model, epoch, flags_obj.model_dir,
                               flags_obj.data_dir)
            start = time.time()
            with summary_writer.as_default():
                train(model, optimizer, train_ds, step_counter,
                      flags_obj.log_interval)
            end = time.time()
            print('\nTrain time for epoch #%d (%d total steps): %f' %
                  (checkpoint.save_counter.numpy() + 1, step_counter.numpy(),
                   end - start))
            with test_summary_writer.as_default():
                test(model, test_ds)
            checkpoint.save(checkpoint_prefix)
Code example #19
 def train_input_fn():
     ds = dataset.train(flags_obj.data_dir)
     ds = ds.cache().shuffle(buffer_size=50000).batch(flags_obj.batch_size)
     # `repeat` replays the whole dataset multiple times; it is mainly used to implement training epochs
     ds = ds.repeat(flags_obj.epochs_between_evals)
     return ds
Code example #20
File: mnist_eager.py  Project: cybermaster/reference
def main(argv):
  parser = MNISTEagerArgParser()
  flags = parser.parse_args(args=argv[1:])

  tfe.enable_eager_execution()

  # Automatically determine device and data_format
  (device, data_format) = ('/gpu:0', 'channels_first')
  if flags.no_gpu or tfe.num_gpus() <= 0:
    (device, data_format) = ('/cpu:0', 'channels_last')
  # If data_format is defined in FLAGS, overwrite automatically set value.
  if flags.data_format is not None:
    data_format = flags.data_format
  print('Using device %s, and data format %s.' % (device, data_format))

  # Load the datasets
  train_ds = mnist_dataset.train(flags.data_dir).shuffle(60000).batch(
      flags.batch_size)
  test_ds = mnist_dataset.test(flags.data_dir).batch(flags.batch_size)

  # Create the model and optimizer
  model = mnist.create_model(data_format)
  optimizer = tf.train.MomentumOptimizer(flags.lr, flags.momentum)

  # Create file writers for writing TensorBoard summaries.
  if flags.output_dir:
    # Create directories to which summaries will be written
    # tensorboard --logdir=<output_dir>
    # can then be used to see the recorded summaries.
    train_dir = os.path.join(flags.output_dir, 'train')
    test_dir = os.path.join(flags.output_dir, 'eval')
    tf.gfile.MakeDirs(flags.output_dir)
  else:
    train_dir = None
    test_dir = None
  summary_writer = tf.contrib.summary.create_file_writer(
      train_dir, flush_millis=10000)
  test_summary_writer = tf.contrib.summary.create_file_writer(
      test_dir, flush_millis=10000, name='test')

  # Create and restore checkpoint (if one exists on the path)
  checkpoint_prefix = os.path.join(flags.model_dir, 'ckpt')
  step_counter = tf.train.get_or_create_global_step()
  checkpoint = tfe.Checkpoint(
      model=model, optimizer=optimizer, step_counter=step_counter)
  # Restore variables on creation if a checkpoint exists.
  checkpoint.restore(tf.train.latest_checkpoint(flags.model_dir))

  # Train and evaluate for a set number of epochs.
  with tf.device(device):
    for _ in range(flags.train_epochs):
      start = time.time()
      with summary_writer.as_default():
        train(model, optimizer, train_ds, step_counter, flags.log_interval)
      end = time.time()
      print('\nTrain time for epoch #%d (%d total steps): %f' %
            (checkpoint.save_counter.numpy() + 1,
             step_counter.numpy(),
             end - start))
      with test_summary_writer.as_default():
        test(model, test_ds)
      checkpoint.save(checkpoint_prefix)
Code example #21
import os
import sys
import tensorflow as tf
import numpy as np

models_path = r'E:\machine-learning\models'
sys.path.append(models_path)

from official.mnist import dataset

# hyperparameters
LEARNING_RATE = 1e-4
TRAINING_EPOCHS = 20
BATCH_SIZE = 100

mnist_train = dataset.train(r"E:\machine-learning\machine_learning\MNIST_data")
mnist_test = dataset.test(r"E:\machine-learning\machine_learning\MNIST_data")


def train_input_fn(features, labels, batch_size):
    pass


def cnn_model_fn(features, labels, mode):
    """
    Input Layer
    Reshape X to 4-D tensor: [batch_size, width, height, channels]
    MNIST images are 28x28 pixels, and have one color channel
    """
    input_layer = tf.reshape(features["x"], [-1, 28, 28, 1])
    conv1 = tf.layers.conv2d(inputs=input_layer,
                             # The remaining arguments were truncated in the
                             # original snippet; these values are illustrative.
                             filters=32,
                             kernel_size=[5, 5],
                             padding='same',
                             activation=tf.nn.relu)
Code example #22
def save(ar, file_name):
    w, h = 28, 28
    data = np.zeros((h, w), dtype=np.uint8)
    for i in range(0, 28):
        for j in range(0, 28):
            data[i, j] = ar[28 * i + j] * 255  # scale [0, 1] up to [0, 255]
    img = Image.fromarray(data, 'L')
    img.save(file_name)


def load():
    img = Image.open('my.png').load()
    m = np.zeros(784, np.float32)
    for i in range(0, 28):
        for j in range(0, 28):
            m[28 * i + j] = img[j, i] / 255.0  # rescale [0, 255] down to [0, 1]
    return m


it = dataset.train('mnist_data').batch(3).make_one_shot_iterator()
next_batch = it.get_next()
with tf.Session() as sess:
    # Run the initializer
    sess.run(tf.global_variables_initializer())
    n = sess.run(next_batch)
    save(n[0][2], 'my.png')
    reloaded = load()
    save(reloaded, 'yours.png')
    print(n[1][2])
Code example #23
def get_inputs(
        mode, batch_size=64, repeat=None, shuffle=None,
        data_dir='/tmp/mnist_data', corruption_stddev=5e-2):
    """
    Get optionally corrupted MNIST batches.

    Args:
        mode: `'train'`, or one of `{'eval', 'predict', 'infer'}`
        batch_size: size of returned batches
        repeat: bool indicating whether or not to repeat indefinitely. If None,
            repeats if `mode` is `'train'`
        shuffle: bool indicating whether or not to shuffle each epoch. If None,
            shuffles if `mode` is `'train'`
        data_dir: where to load/download data to
        corruption_stddev: if training, normally distributed noise is added
            to each pixel of the image.

    Returns:
        `image`, `labels` tensors, shape (?, 28, 28, 1) and (?) respectively.
        First dimension is batch_size except possibly on final batches.
    """
    # get the original dataset from `tensorflow/models/official`
    # https://github.com/tensorflow/models
    if mode == 'train':
        dataset = ds.train(data_dir)
    elif mode in {'eval', 'predict', 'infer'}:
        dataset = ds.test(data_dir)
    else:
        raise ValueError('mode "%s" not recognized' % mode)

    training = mode == 'train'

    # repeating before shuffling is better for performance, though shuffled
    # examples may then cross epoch boundaries
    if repeat or repeat is None and training:
        dataset = dataset.repeat()

    if shuffle or shuffle is None and training:
        # A larger buffer size requires more memory but gives better shuffling
        dataset = dataset.shuffle(buffer_size=10000)

    def map_fn(image, labels):
        image += tf.random_normal(
            shape=image.shape, dtype=tf.float32, stddev=corruption_stddev)
        return image, labels

    # num_parallel_calls defaults to None, but is included here to draw
    # attention to it; for datasets with heavier preprocessing it can
    # significantly speed things up
    if training:
        dataset = dataset.map(map_fn, num_parallel_calls=None)
    dataset = dataset.batch(batch_size)

    # prefetching allows the CPU to preprocess/load data while the GPU is busy
    # prefetch_to_device should be faster, but likely won't make a difference
    # at this scale.
    dataset = dataset.prefetch(1)
    # dataset = dataset.apply(tf.contrib.data.prefetch_to_device('/gpu:0'))

    image, labels = dataset.make_one_shot_iterator().get_next()
    image = tf.reshape(image, (-1, 28, 28, 1))  # could also go in map_fn
    return image, labels
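A quick graph-mode usage sketch for `get_inputs` (the batch size and session loop are illustrative):

# Illustrative usage; pulls a single batch through a session.
images, labels = get_inputs('train', batch_size=32)
with tf.Session() as sess:
    batch_images, batch_labels = sess.run([images, labels])
    print(batch_images.shape, batch_labels.shape)  # (32, 28, 28, 1) (32,)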
Code example #24
 def train_input_fn():
     ds = dataset.train(flags.data_dir)
     ds = ds.cache().shuffle(buffer_size=50000).batch(flags.batch_size)
     return ds