def model_fn(features, labels, mode, params):
  """Constructs the model used to classify MNIST digits, for use with TPUEstimator."""
  del params
  if mode == tf.estimator.ModeKeys.PREDICT:
    raise RuntimeError("mode {} is not supported yet".format(mode))

  image = features
  if isinstance(image, dict):
    image = features["image"]

  model = mnist.Model("channels_last")
  logits = model(image, training=(mode == tf.estimator.ModeKeys.TRAIN))
  loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)

  if mode == tf.estimator.ModeKeys.TRAIN:
    learning_rate = tf.train.exponential_decay(
        FLAGS.learning_rate,
        tf.train.get_global_step(),
        decay_steps=100000,
        decay_rate=0.96)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
    if FLAGS.use_tpu:
      # Sum gradients across TPU shards before applying them.
      optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=optimizer.minimize(loss, tf.train.get_global_step()))

  if mode == tf.estimator.ModeKeys.EVAL:
    # metric_fn is defined elsewhere in the original script.
    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode, loss=loss, eval_metrics=(metric_fn, [labels, logits]))
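# A minimal sketch of how model_fn above could be handed to a TPUEstimator.
# The flags FLAGS.tpu, FLAGS.model_dir, FLAGS.iterations, FLAGS.num_shards and
# FLAGS.train_steps, as well as the train_input_fn helper, are assumed here for
# illustration only; they are not defined in this excerpt.
def run_tpu_training():
  tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
      FLAGS.tpu) if FLAGS.use_tpu else None
  run_config = tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir=FLAGS.model_dir,
      tpu_config=tf.contrib.tpu.TPUConfig(FLAGS.iterations, FLAGS.num_shards))
  estimator = tf.contrib.tpu.TPUEstimator(
      model_fn=model_fn,
      use_tpu=FLAGS.use_tpu,
      train_batch_size=FLAGS.batch_size,
      eval_batch_size=FLAGS.batch_size,
      config=run_config)
  estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)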
def evaluate(defun=False):
  model = mnist.Model(data_format())
  dataset = random_dataset()
  if defun:
    model.call = tfe.defun(model.call)
  with tf.device(device()):
    mnist_eager.test(model, dataset)
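# For context: tfe.defun (tf.contrib.eager.defun) traces the wrapped Python
# function into a TensorFlow graph on its first call, so later calls skip
# per-op Python overhead. A standalone sketch (the function and values below
# are illustrative, not part of the benchmark):
def defun_example():
  def square_sum(x, y):
    return tf.reduce_sum(tf.square(x) + tf.square(y))

  compiled = tfe.defun(square_sum)  # first call traces and caches a graph
  return compiled(tf.constant([1.0, 2.0]), tf.constant([3.0, 4.0]))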
def train(defun=False):
  model = mnist.Model(data_format())
  if defun:
    model.call = tfe.defun(model.call)
  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
  dataset = random_dataset()
  with tf.device(device()):
    mnist_eager.train(model, optimizer, dataset)
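# The benchmark wrappers above rely on three helpers that are not shown in this
# excerpt: device(), data_format(), and random_dataset(). One plausible sketch;
# the batch size and tensor shapes below are assumptions, not the original values.
def device():
  return '/gpu:0' if tfe.num_gpus() else '/cpu:0'


def data_format():
  return 'channels_first' if tfe.num_gpus() else 'channels_last'


def random_dataset():
  batch_size = 64
  images = tf.random_normal([batch_size, 784])
  labels = tf.random_uniform([batch_size], minval=0, maxval=10, dtype=tf.int32)
  return tf.data.Dataset.from_tensors((images, labels))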
def main(_):
  tfe.enable_eager_execution()

  (device, data_format) = ('/gpu:0', 'channels_first')
  if FLAGS.no_gpu or tfe.num_gpus() <= 0:
    (device, data_format) = ('/cpu:0', 'channels_last')
  print('Using device %s, and data format %s.' % (device, data_format))

  # Load the datasets
  train_ds = dataset.train(FLAGS.data_dir).shuffle(60000).batch(
      FLAGS.batch_size)
  test_ds = dataset.test(FLAGS.data_dir).batch(FLAGS.batch_size)

  # Create the model and optimizer
  model = mnist.Model(data_format)
  optimizer = tf.train.MomentumOptimizer(FLAGS.lr, FLAGS.momentum)

  if FLAGS.output_dir:
    # Create directories to which summaries will be written
    # tensorboard --logdir=<output_dir>
    # can then be used to see the recorded summaries.
    train_dir = os.path.join(FLAGS.output_dir, 'train')
    test_dir = os.path.join(FLAGS.output_dir, 'eval')
    tf.gfile.MakeDirs(FLAGS.output_dir)
  else:
    train_dir = None
    test_dir = None
  summary_writer = tf.contrib.summary.create_file_writer(
      train_dir, flush_millis=10000)
  test_summary_writer = tf.contrib.summary.create_file_writer(
      test_dir, flush_millis=10000, name='test')

  checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt')
  step_counter = tf.train.get_or_create_global_step()
  checkpoint = tfe.Checkpoint(
      model=model, optimizer=optimizer, step_counter=step_counter)
  # Restore variables on creation if a checkpoint exists.
  checkpoint.restore(tf.train.latest_checkpoint(FLAGS.checkpoint_dir))

  # Train and evaluate for 10 epochs.
  with tf.device(device):
    for _ in range(10):
      start = time.time()
      with summary_writer.as_default():
        train(model, optimizer, train_ds, step_counter, FLAGS.log_interval)
      end = time.time()
      print('\nTrain time for epoch #%d (%d total steps): %f' %
            (checkpoint.save_counter.numpy() + 1,
             step_counter.numpy(),
             end - start))
      with test_summary_writer.as_default():
        test(model, test_ds)
      checkpoint.save(checkpoint_prefix)
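# main() above reads several flags that are declared elsewhere in the original
# script. A minimal sketch of how they could be declared with tf.app.flags;
# every default value below is an illustrative assumption.
tf.app.flags.DEFINE_string('data_dir', '/tmp/tensorflow/mnist/input_data',
                           'Directory for the MNIST data.')
tf.app.flags.DEFINE_string('output_dir', '',
                           'Directory for TensorBoard summaries.')
tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/tensorflow/mnist/checkpoints',
                           'Directory for checkpoints.')
tf.app.flags.DEFINE_integer('batch_size', 100, 'Batch size.')
tf.app.flags.DEFINE_integer('log_interval', 10, 'Steps between logging.')
tf.app.flags.DEFINE_float('lr', 0.01, 'Learning rate.')
tf.app.flags.DEFINE_float('momentum', 0.5, 'SGD momentum.')
tf.app.flags.DEFINE_boolean('no_gpu', False, 'Disable GPU even if available.')
FLAGS = tf.app.flags.FLAGS

if __name__ == '__main__':
  tf.app.run(main=main)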