Example #1
import tensorflow as tf  # TF 1.x API (tf.estimator, tf.contrib.tpu)

# The `mnist` module and `FLAGS` are defined elsewhere in the original example.
def model_fn(features, labels, mode, params):
    del params  # unused; hyperparameters are read from FLAGS instead
    if mode == tf.estimator.ModeKeys.PREDICT:
        raise RuntimeError("mode {} is not supported yet".format(mode))
    image = features
    if isinstance(image, dict):
        image = features["image"]

    model = mnist.Model("channels_last")
    logits = model(image, training=(mode == tf.estimator.ModeKeys.TRAIN))
    # softmax_cross_entropy expects one-hot labels.
    loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)

    if mode == tf.estimator.ModeKeys.TRAIN:
        learning_rate = tf.train.exponential_decay(FLAGS.learning_rate,
                                                   tf.train.get_global_step(),
                                                   decay_steps=100000,
                                                   decay_rate=0.96)
        optimizer = tf.train.GradientDescentOptimizer(
            learning_rate=learning_rate)
        if FLAGS.use_tpu:
            # Aggregate gradients across TPU shards before applying them.
            optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
        return tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
                                               loss=loss,
                                               train_op=optimizer.minimize(
                                                   loss,
                                                   tf.train.get_global_step()))

    if mode == tf.estimator.ModeKeys.EVAL:
        # metric_fn is referenced but not shown here; a plausible sketch
        # follows this example.
        return tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
                                               loss=loss,
                                               eval_metrics=(metric_fn,
                                                             [labels, logits]))
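
The EVAL branch above hands TPUEstimatorSpec a (metric_fn, tensors) pair, but metric_fn itself is not part of the snippet. A minimal sketch of what it might look like, assuming one-hot labels as in the loss above (the body is an assumption, not the original):

def metric_fn(labels, logits):
    # Accuracy computed from one-hot labels and raw logits; tf.metrics
    # returns the (value, update_op) pair that TPUEstimatorSpec expects.
    predictions = tf.argmax(logits, axis=1)
    return {
        "accuracy": tf.metrics.accuracy(
            labels=tf.argmax(labels, axis=1), predictions=predictions),
    }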
Example #2
# Assumes `tfe = tf.contrib.eager`, the `mnist` and `mnist_eager` modules,
# and the device()/data_format()/random_dataset() helpers sketched after
# Example #3.
def evaluate(defun=False):
    model = mnist.Model(data_format())
    dataset = random_dataset()
    if defun:
        model.call = tfe.defun(model.call)
    with tf.device(device()):
        mnist_eager.test(model, dataset)
Example #3
def train(defun=False):
    model = mnist.Model(data_format())
    if defun:
        model.call = tfe.defun(model.call)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
    dataset = random_dataset()
    with tf.device(device()):
        mnist_eager.train(model, optimizer, dataset)
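
Examples #2 and #3 lean on helpers that are not part of the snippet. A plausible sketch, assuming MNIST-shaped random batches; the names match the calls above, but the bodies are assumptions:

def device():
    # Run on GPU when one is available, otherwise on CPU.
    return '/gpu:0' if tfe.num_gpus() > 0 else '/cpu:0'

def data_format():
    # channels_first is generally faster on GPUs.
    return 'channels_first' if tfe.num_gpus() > 0 else 'channels_last'

def random_dataset():
    # Random images/labels shaped to match data_format(), enough to
    # exercise the model.
    shape = ([100, 1, 28, 28] if data_format() == 'channels_first'
             else [100, 28, 28, 1])
    images = tf.random_uniform(shape)
    labels = tf.random_uniform([100], maxval=10, dtype=tf.int32)
    return tf.data.Dataset.from_tensor_slices((images, labels)).batch(32)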
Example #4
# Assumes `import os`, `import time`, `import tensorflow as tf`,
# `tfe = tf.contrib.eager`, and the `dataset` and `mnist` modules plus
# FLAGS from the original program.
def main(_):
    tfe.enable_eager_execution()

    (device, data_format) = ('/gpu:0', 'channels_first')
    if FLAGS.no_gpu or tfe.num_gpus() <= 0:
        (device, data_format) = ('/cpu:0', 'channels_last')
    print('Using device %s, and data format %s.' % (device, data_format))

    # Load the datasets
    train_ds = dataset.train(FLAGS.data_dir).shuffle(60000).batch(
        FLAGS.batch_size)
    test_ds = dataset.test(FLAGS.data_dir).batch(FLAGS.batch_size)

    # Create the model and optimizer
    model = mnist.Model(data_format)
    optimizer = tf.train.MomentumOptimizer(FLAGS.lr, FLAGS.momentum)

    if FLAGS.output_dir:
        # Create directories to which summaries will be written
        # tensorboard --logdir=<output_dir>
        # can then be used to see the recorded summaries.
        train_dir = os.path.join(FLAGS.output_dir, 'train')
        test_dir = os.path.join(FLAGS.output_dir, 'eval')
        tf.gfile.MakeDirs(FLAGS.output_dir)
    else:
        train_dir = None
        test_dir = None
    summary_writer = tf.contrib.summary.create_file_writer(train_dir,
                                                           flush_millis=10000)
    test_summary_writer = tf.contrib.summary.create_file_writer(
        test_dir, flush_millis=10000, name='test')
    checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt')
    step_counter = tf.train.get_or_create_global_step()
    checkpoint = tfe.Checkpoint(model=model,
                                optimizer=optimizer,
                                step_counter=step_counter)
    # Restore variables on creation if a checkpoint exists.
    checkpoint.restore(tf.train.latest_checkpoint(FLAGS.checkpoint_dir))
    # Train and evaluate for 10 epochs.
    with tf.device(device):
        for _ in range(10):
            start = time.time()
            with summary_writer.as_default():
                train(model, optimizer, train_ds, step_counter,
                      FLAGS.log_interval)
            end = time.time()
            print('\nTrain time for epoch #%d (%d total steps): %f' %
                  (checkpoint.save_counter.numpy() + 1, step_counter.numpy(),
                   end - start))
            with test_summary_writer.as_default():
                test(model, test_ds)
            checkpoint.save(checkpoint_prefix)
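
The train and test helpers called in the loop above are defined elsewhere in the original program. A minimal sketch of the train loop under eager execution, assuming sparse integer labels (the loop body is an assumption, not the original code):

def train(model, optimizer, dataset, step_counter, log_interval=None):
    # One epoch: forward pass, gradient step, and periodic loss summaries.
    for (batch, (images, labels)) in enumerate(dataset):
        with tf.contrib.summary.record_summaries_every_n_global_steps(
                10, global_step=step_counter):
            with tf.GradientTape() as tape:
                logits = model(images, training=True)
                loss = tf.losses.sparse_softmax_cross_entropy(
                    labels=labels, logits=logits)
                tf.contrib.summary.scalar('loss', loss)
            grads = tape.gradient(loss, model.variables)
            optimizer.apply_gradients(
                zip(grads, model.variables), global_step=step_counter)
            if log_interval and batch % log_interval == 0:
                print('Step #%d\tLoss: %.6f' % (batch, loss))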