Example #1
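A driver that wires an Inception v2 model_fn into a TPUEstimator and alternates training and evaluation cycles. It assumes the surrounding module (TF 1.x) provides the referenced FLAGS, ImageNetInput, get_batch_axis, imagenet, and inception_model_fn, and imports tpu_config and tpu_estimator from tf.contrib.tpu.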
def main(unused_argv):
  del unused_argv  # Unused

  if FLAGS.input_layout not in ['NCHW', 'NHWC']:
    raise RuntimeError('--input_layout must be one of [NCHW, NHWC]')

  run_config = tpu_config.RunConfig(
      master=FLAGS.master,
      evaluation_master=FLAGS.master,
      model_dir=FLAGS.model_dir,
      save_checkpoints_secs=FLAGS.save_checkpoints_secs,
      save_summary_steps=FLAGS.save_summary_steps,
      session_config=tf.ConfigProto(
          allow_soft_placement=True,
          log_device_placement=FLAGS.log_device_placement),
      tpu_config=tpu_config.TPUConfig(
          iterations_per_loop=FLAGS.iterations,
          num_shards=FLAGS.num_shards))

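  # train_batch_size and eval_batch_size are global batch sizes; TPUEstimator
  # splits each batch across the num_shards TPU cores.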
  inception_classifier = tpu_estimator.TPUEstimator(
      model_fn=inception_model_fn,
      use_tpu=FLAGS.use_tpu,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size,
      batch_axis=(get_batch_axis(
          FLAGS.train_batch_size // FLAGS.num_shards), 0))

  for cycle in range(FLAGS.train_steps // FLAGS.train_steps_per_eval):
    # tensors_to_log = {
    #     'learning_rate': 'learning_rate',
    #     'prediction_loss': 'prediction_loss',
    #     'train_accuracy': 'train_accuracy'
    # }

    # logging_hook = tf.train.LoggingTensorHook(
    #     tensors=tensors_to_log, every_n_iter=100)

    tf.logging.info('Starting training cycle %d.' % cycle)
    inception_classifier.train(
        input_fn=ImageNetInput(True), steps=FLAGS.train_steps_per_eval)

    if FLAGS.eval_enabled:
      eval_steps = (imagenet.get_split_size('validation') //
                    FLAGS.eval_batch_size)
      tf.logging.info('Starting evaluation cycle %d.' % cycle)
      eval_results = inception_classifier.evaluate(
          input_fn=ImageNetInput(False), steps=eval_steps)
      tf.logging.info('Evaluation results: %s' % eval_results)
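The loop above is configured entirely through flags defined elsewhere in the module. A minimal sketch of how the flags it references might be declared (the flag names come from the code; the defaults and help strings here are assumptions):

import tensorflow as tf  # TF 1.x

flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_string('master', '', 'Address of the TPU master (e.g. grpc://...).')
flags.DEFINE_string('model_dir', None, 'Directory for checkpoints and summaries.')
flags.DEFINE_string('input_layout', 'NHWC', "Image layout, 'NCHW' or 'NHWC'.")
flags.DEFINE_bool('use_tpu', True, 'Run on TPU if true, otherwise on CPU/GPU.')
flags.DEFINE_bool('eval_enabled', True, 'Run evaluation after each training cycle.')
flags.DEFINE_bool('log_device_placement', False, 'Log device placement of ops.')
flags.DEFINE_integer('num_shards', 8, 'Number of TPU shards (cores).')
flags.DEFINE_integer('iterations', 100, 'Iterations per TPU training loop.')
flags.DEFINE_integer('train_batch_size', 1024, 'Global training batch size.')
flags.DEFINE_integer('eval_batch_size', 1024, 'Global evaluation batch size.')
flags.DEFINE_integer('train_steps', 100000, 'Total number of training steps.')
flags.DEFINE_integer('train_steps_per_eval', 2000, 'Training steps between evaluations.')
flags.DEFINE_integer('save_checkpoints_secs', 600, 'Seconds between checkpoint saves.')
flags.DEFINE_integer('save_summary_steps', 100, 'Steps between summary writes.')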
Example #2
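The model_fn used by the estimator in Example #1: it builds Inception v2 with TF-Slim, adds a label-smoothed softmax cross-entropy loss, sets up an exponentially decayed learning rate with an SGD or Momentum optimizer (wrapped in CrossShardOptimizer on TPU), and reports top-1 accuracy during evaluation. It assumes the surrounding module provides slim, inception, imagenet, TensorTranspose, tpu_optimizer, tpu_estimator, and the _LEARNING_RATE_DECAY* constants.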
def inception_model_fn(features, labels, mode, params):
    """Inception v2 model using Estimator API."""
    del params

    num_classes = FLAGS.num_classes
    training_active = (mode == tf.estimator.ModeKeys.TRAIN)
    eval_active = (mode == tf.estimator.ModeKeys.EVAL)

    if training_active:
        size = FLAGS.train_batch_size // FLAGS.num_shards
    else:
        size = FLAGS.eval_batch_size
    input_transform_fn = TensorTranspose(size, is_input=True)
    features = input_transform_fn(features)

    with slim.arg_scope(
            inception.inception_v2_arg_scope(
                use_fused_batchnorm=FLAGS.use_fused_batchnorm)):
        logits, _ = inception.inception_v2(
            features,
            num_classes,
            is_training=training_active,
            depth_multiplier=FLAGS.depth_multiplier,
            replace_separable_convolution=True)

    predictions = {
        'classes': tf.argmax(input=logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

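    # Label-smoothed cross-entropy; regularization losses registered by the
    # slim arg_scope are folded in below via get_total_loss().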
    prediction_loss = tf.losses.softmax_cross_entropy(onehot_labels=labels,
                                                      logits=logits,
                                                      weights=1.0,
                                                      label_smoothing=0.1)
    tf.losses.add_loss(prediction_loss)
    loss = tf.losses.get_total_loss(add_regularization_losses=True)

    initial_learning_rate = FLAGS.learning_rate * FLAGS.train_batch_size / 256
    final_learning_rate = 0.01 * initial_learning_rate

    train_op = None
    if training_active:
        # Multiply the learning rate by 0.1 every 30 epochs.
        training_set_len = imagenet.get_split_size('train')
        batches_per_epoch = training_set_len // FLAGS.train_batch_size
        learning_rate = tf.train.exponential_decay(
            learning_rate=initial_learning_rate,
            global_step=tf.train.get_global_step(),
            decay_steps=_LEARNING_RATE_DECAY_EPOCHS * batches_per_epoch,
            decay_rate=_LEARNING_RATE_DECAY,
            staircase=True)

        # Set a minimum boundary for the learning rate.
        learning_rate = tf.maximum(learning_rate,
                                   final_learning_rate,
                                   name='learning_rate')

        # tf.summary.scalar('learning_rate', learning_rate)

        if FLAGS.optimizer == 'sgd':
            tf.logging.info('Using SGD optimizer')
            optimizer = tf.train.GradientDescentOptimizer(
                learning_rate=learning_rate)
        elif FLAGS.optimizer == 'momentum':
            tf.logging.info('Using Momentum optimizer')
            optimizer = tf.train.MomentumOptimizer(
                learning_rate=learning_rate, momentum=0.9)
        else:
            tf.logging.fatal('Unknown optimizer: %s', FLAGS.optimizer)

        if FLAGS.use_tpu:
            optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

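        # Run the batch-norm moving-average updates collected in UPDATE_OPS
        # together with every training step.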
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(
                loss, global_step=tf.train.get_or_create_global_step())

    eval_metrics = None
    if eval_active:

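        # TPUEstimatorSpec expects eval_metrics as a (metric_fn, tensors) pair;
        # the tensors are gathered from the TPU shards and metric_fn runs on
        # the host CPU.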
        def metric_fn(labels, logits):
            predictions = tf.argmax(input=logits, axis=1)
            accuracy = tf.metrics.accuracy(tf.argmax(input=labels, axis=1),
                                           predictions)
            return {'accuracy': accuracy}

        eval_metrics = (metric_fn, [labels, logits])

    return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                          loss=loss,
                                          train_op=train_op,
                                          eval_metrics=eval_metrics)
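For reference, the learning-rate schedule in the training branch scales the base rate linearly with the global batch size and then applies a staircase exponential decay, floored at 1% of the initial rate. A standalone sketch of the same arithmetic (the base rate of 0.045, the 0.1-per-30-epochs decay taken from the comment, and the ImageNet train-split size are assumptions):

def decayed_learning_rate(step,
                          base_learning_rate=0.045,
                          train_batch_size=1024,
                          train_split_size=1281167,
                          decay_epochs=30,
                          decay_rate=0.1):
    """Mirrors the staircase schedule used in inception_model_fn."""
    initial_lr = base_learning_rate * train_batch_size / 256   # linear batch scaling
    final_lr = 0.01 * initial_lr                                # floor applied via tf.maximum
    batches_per_epoch = train_split_size // train_batch_size
    decay_steps = decay_epochs * batches_per_epoch
    lr = initial_lr * decay_rate ** (step // decay_steps)       # staircase decay
    return max(lr, final_lr)

At step 0 this returns the scaled initial rate; after each block of roughly 30 epochs it drops by 10x until it reaches the 0.01 * initial floor.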