Exemplo n.º 1
0
def train_input_fn(params):
    """Build the training input pipeline (estimator-style ``input_fn``).

    Args:
      params: Mapping supplied by the caller; must contain the keys
        "batch_size" and "data_dir".

    Returns:
      The training dataset: cached, repeated indefinitely, shuffled with a
      50000-element buffer, and batched with any final partial batch dropped.
    """
    # The batch size found in params is the batch size for the current shard;
    # the number of shards is computed according to the input pipeline
    # deployment. See `tf.compat.v1.estimator.tpu.RunConfig` for details.
    per_shard_batch = params["batch_size"]
    source_dir = params["data_dir"]

    records = dataset.train(source_dir)
    records = records.cache()
    records = records.repeat()
    records = records.shuffle(buffer_size=50000)
    # drop_remainder=True keeps every batch the same size — presumably so the
    # batch dimension is static, as TPU execution expects (confirm).
    return records.batch(per_shard_batch, drop_remainder=True)
Exemplo n.º 2
0
    def train_input_fn():
        """Build the dataset consumed by one training session.

        Reads `data_dir`, `batch_size` and `epochs_between_evals` from the
        enclosing scope's `flags_obj`.

        Returns:
          The training dataset, cached, shuffled with a 50000-element buffer,
          batched by `flags_obj.batch_size`, and repeated
          `flags_obj.epochs_between_evals` times.
        """
        # Larger shuffle buffers randomize better but use more memory; MNIST
        # is small enough that a large buffer is affordable.
        # NOTE(review): the original comment claimed the full epoch is
        # shuffled; if the training split holds more than 50000 examples
        # (the eager variant shuffles with 60000), this buffer only covers
        # part of it — confirm.
        shuffled = dataset.train(flags_obj.data_dir).cache().shuffle(
            buffer_size=50000)
        batched = shuffled.batch(flags_obj.batch_size)

        # One training session covers `epochs_between_evals` passes.
        return batched.repeat(flags_obj.epochs_between_evals)
Exemplo n.º 3
0
def run_mnist_eager(flags_obj):
    """Run the MNIST training and evaluation loop in eager mode.

    For each of `flags_obj.train_epochs` epochs: train on the training set,
    print the epoch's wall-clock training time, evaluate on the test set, and
    save a checkpoint under `flags_obj.model_dir`.

    Args:
      flags_obj: An object containing parsed flag values. Attributes read
        directly here: no_gpu, data_format, data_dir, batch_size, lr,
        momentum, output_dir, model_dir, train_epochs and log_interval
        (plus whatever `model_helpers.apply_clean` reads).
    """
    tf.enable_eager_execution()
    # Clean using the flags object we were handed rather than the global
    # flags.FLAGS, so this function consistently honours its argument.
    model_helpers.apply_clean(flags_obj)

    # Prefer the GPU with channels_first; fall back to the CPU with
    # channels_last when the GPU is disabled or not available.
    if flags_obj.no_gpu or not tf.test.is_gpu_available():
        device, data_format = '/cpu:0', 'channels_last'
    else:
        device, data_format = '/gpu:0', 'channels_first'
    # An explicitly configured data_format overrides the automatic choice.
    if flags_obj.data_format is not None:
        data_format = flags_obj.data_format
    print('Using device %s, and data format %s.' % (device, data_format))

    # Load the datasets; only the training set is shuffled.
    train_ds = mnist_dataset.train(flags_obj.data_dir).shuffle(60000).batch(
        flags_obj.batch_size)
    test_ds = mnist_dataset.test(flags_obj.data_dir).batch(
        flags_obj.batch_size)

    # Create the model and optimizer.
    model = mnist.create_model(data_format)
    optimizer = tf.train.MomentumOptimizer(flags_obj.lr, flags_obj.momentum)

    # TensorBoard summaries go to <output_dir>/train and <output_dir>/eval;
    # `tensorboard --logdir=<output_dir>` can then show them. Without an
    # output_dir the writers are created with a None logdir (presumably a
    # no-op writer).
    if flags_obj.output_dir:
        train_dir = os.path.join(flags_obj.output_dir, 'train')
        test_dir = os.path.join(flags_obj.output_dir, 'eval')
        tf.gfile.MakeDirs(flags_obj.output_dir)
    else:
        train_dir = None
        test_dir = None
    summary_writer = tf.contrib.summary.create_file_writer(
        train_dir, flush_millis=10000)
    test_summary_writer = tf.contrib.summary.create_file_writer(
        test_dir, flush_millis=10000, name='test')

    # Track model, optimizer and global step in one checkpoint, and resume
    # from the latest checkpoint in model_dir if one exists.
    checkpoint_prefix = os.path.join(flags_obj.model_dir, 'ckpt')
    step_counter = tf.train.get_or_create_global_step()
    checkpoint = tf.train.Checkpoint(model=model,
                                     optimizer=optimizer,
                                     step_counter=step_counter)
    checkpoint.restore(tf.train.latest_checkpoint(flags_obj.model_dir))

    # Train and evaluate for a set number of epochs.
    with tf.device(device):
        for _ in range(flags_obj.train_epochs):
            start = time.time()
            with summary_writer.as_default():
                train(model, optimizer, train_ds, step_counter,
                      flags_obj.log_interval)
            end = time.time()
            # One checkpoint is saved per epoch, so the save counter (restored
            # along with the checkpoint) + 1 is the current epoch number.
            print('\nTrain time for epoch #%d (%d total steps): %f' %
                  (checkpoint.save_counter.numpy() + 1, step_counter.numpy(),
                   end - start))
            with test_summary_writer.as_default():
                test(model, test_ds)
            checkpoint.save(checkpoint_prefix)