Example #1
# `mnist` and `FLAGS` are module-level names defined elsewhere in the
# source file this example was extracted from.
import tensorflow as tf  # TF 1.x

def model_fn(features, labels, mode, params):
  """model_fn constructs the ML model used to predict handwritten digits."""

  del params  # Unused; hyperparameters come from the module-level FLAGS.
  if mode == tf.estimator.ModeKeys.PREDICT:
    raise RuntimeError("mode {} is not supported yet".format(mode))
  image = features
  if isinstance(image, dict):
    image = features["image"]

  model = mnist.Model("channels_last")
  logits = model(image, training=(mode == tf.estimator.ModeKeys.TRAIN))
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

  if mode == tf.estimator.ModeKeys.TRAIN:
    learning_rate = tf.train.exponential_decay(
        FLAGS.learning_rate,
        tf.train.get_global_step(),
        decay_steps=100000,
        decay_rate=0.96)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
    if FLAGS.use_tpu:
      optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=optimizer.minimize(loss, tf.train.get_global_step()))

  if mode == tf.estimator.ModeKeys.EVAL:
    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode, loss=loss, eval_metrics=(metric_fn, [labels, logits]))
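
The metric_fn referenced in the EVAL branch is defined elsewhere in the
source file. Below is a minimal sketch of what such a metric_fn looks like,
plus hypothetical wiring of model_fn into a tf.contrib.tpu.TPUEstimator;
the batch sizes, directory, and iteration count are illustrative
assumptions, not values from the original.

def metric_fn(labels, logits):
  # Report classification accuracy on the evaluation set.
  accuracy = tf.metrics.accuracy(
      labels=labels, predictions=tf.argmax(logits, axis=1))
  return {"accuracy": accuracy}

run_config = tf.contrib.tpu.RunConfig(
    model_dir="/tmp/mnist_model",  # placeholder path
    tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=100))
estimator = tf.contrib.tpu.TPUEstimator(
    model_fn=model_fn,
    use_tpu=True,
    train_batch_size=1024,
    eval_batch_size=1024,
    config=run_config)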

Example #2

def evaluate(defun=False):
    # data_format(), random_dataset(), and device() are helper functions
    # defined elsewhere in the benchmark module this example comes from.
    model = mnist.Model(data_format())
    dataset = random_dataset()
    if defun:
        # Trace the model's call method into a graph function for speed.
        model.call = tfe.defun(model.call)
    with tf.device(device()):
        mnist_eager.test(model, dataset)
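
For context, here is a minimal, self-contained sketch of what tfe.defun
does on its own (TF 1.x with eager execution enabled; the square function
is a made-up example):

import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()

def square(x):
    return tf.multiply(x, x)

# defun traces the Python function into a graph function; later calls
# execute the compiled graph instead of re-running Python op by op.
compiled_square = tfe.defun(square)
print(compiled_square(tf.constant(3.0)))  # tf.Tensor(9.0, ...)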
Example #3
def main(_):
    # tf, tfe (tf.contrib.eager), mnist, mnist_dataset, os, time, FLAGS, and
    # the train()/test() helpers are module-level names in the source file.
    tfe.enable_eager_execution()

    # Automatically determine device and data_format
    (device, data_format) = ('/gpu:0', 'channels_first')
    if FLAGS.no_gpu or tfe.num_gpus() <= 0:
        (device, data_format) = ('/cpu:0', 'channels_last')
    # If data_format is set explicitly via FLAGS, override the
    # automatically chosen value.
    if FLAGS.data_format is not None:
        data_format = FLAGS.data_format
    print('Using device %s, and data format %s.' % (device, data_format))

    # Load the datasets
    train_ds = mnist_dataset.train(FLAGS.data_dir).shuffle(60000).batch(
        FLAGS.batch_size)
    test_ds = mnist_dataset.test(FLAGS.data_dir).batch(FLAGS.batch_size)

    # Create the model and optimizer
    model = mnist.Model(data_format)
    optimizer = tf.train.MomentumOptimizer(FLAGS.lr, FLAGS.momentum)

    # Create file writers for writing TensorBoard summaries.
    if FLAGS.output_dir:
        # Create directories to which summaries will be written
        # tensorboard --logdir=<output_dir>
        # can then be used to see the recorded summaries.
        train_dir = os.path.join(FLAGS.output_dir, 'train')
        test_dir = os.path.join(FLAGS.output_dir, 'eval')
        tf.gfile.MakeDirs(FLAGS.output_dir)
    else:
        train_dir = None
        test_dir = None
    summary_writer = tf.contrib.summary.create_file_writer(train_dir,
                                                           flush_millis=10000)
    test_summary_writer = tf.contrib.summary.create_file_writer(
        test_dir, flush_millis=10000, name='test')

    # Create and restore checkpoint (if one exists on the path)
    checkpoint_prefix = os.path.join(FLAGS.model_dir, 'ckpt')
    step_counter = tf.train.get_or_create_global_step()
    checkpoint = tfe.Checkpoint(model=model,
                                optimizer=optimizer,
                                step_counter=step_counter)
    # Restore variables on creation if a checkpoint exists.
    checkpoint.restore(tf.train.latest_checkpoint(FLAGS.model_dir))

    # Train and evaluate for a set number of epochs.
    with tf.device(device):
        for _ in range(FLAGS.train_epochs):
            start = time.time()
            with summary_writer.as_default():
                train(model, optimizer, train_ds, step_counter,
                      FLAGS.log_interval)
            end = time.time()
            print('\nTrain time for epoch #%d (%d total steps): %f' %
                  (checkpoint.save_counter.numpy() + 1, step_counter.numpy(),
                   end - start))
            with test_summary_writer.as_default():
                test(model, test_ds)
            checkpoint.save(checkpoint_prefix)
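
The FLAGS used throughout main() are command-line flags defined at module
level in the original script. A hypothetical sketch of their definitions,
inferred from the names used above (defaults are illustrative only):

from absl import flags

flags.DEFINE_string('data_dir', '/tmp/mnist_data',
                    'Directory where MNIST data is stored.')
flags.DEFINE_string('model_dir', '/tmp/mnist_model',
                    'Directory for checkpoints.')
flags.DEFINE_string('output_dir', None,
                    'Directory for TensorBoard summaries (optional).')
flags.DEFINE_string('data_format', None,
                    "'channels_first' or 'channels_last' (None = auto).")
flags.DEFINE_integer('batch_size', 100, 'Batch size for training and eval.')
flags.DEFINE_integer('train_epochs', 10, 'Number of epochs to train.')
flags.DEFINE_integer('log_interval', 10, 'Steps between log messages.')
flags.DEFINE_float('lr', 0.01, 'Learning rate.')
flags.DEFINE_float('momentum', 0.5, 'SGD momentum.')
flags.DEFINE_bool('no_gpu', False, 'Force CPU even if a GPU is available.')

FLAGS = flags.FLAGS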

Example #4

def train(defun=False):
    # As in Example #2, data_format(), random_dataset(), and device() are
    # helpers defined elsewhere in the benchmark module.
    model = mnist.Model(data_format())
    if defun:
        # Trace the model's call method into a graph function for speed.
        model.call = tfe.defun(model.call)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
    dataset = random_dataset()
    with tf.device(device()):
        mnist_eager.train(model,
                          optimizer,
                          dataset,
                          step_counter=tf.train.get_or_create_global_step())
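
A minimal sketch of how this benchmark helper might be compared with and
without defun (the timing harness below is an assumption, not part of the
original):

import time

for use_defun in (False, True):
    start = time.time()
    train(defun=use_defun)
    print('defun=%s took %.3f seconds' % (use_defun, time.time() - start))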