def main(unused_argv):
  """Trains (and optionally evaluates) Inception v2 on a TPU.

  Alternates training cycles of ``FLAGS.train_steps_per_eval`` steps with a
  pass over the validation split when ``--eval_enabled`` is set.

  Args:
    unused_argv: Unused command-line arguments (flags are read from FLAGS).

  Raises:
    RuntimeError: If --input_layout is not one of NCHW/NHWC.
  """
  del unused_argv  # Unused.

  if FLAGS.input_layout not in ['NCHW', 'NHWC']:
    raise RuntimeError('--input_layout must be one of [NCHW, NHWC]')

  run_config = tpu_config.RunConfig(
      master=FLAGS.master,
      evaluation_master=FLAGS.master,
      model_dir=FLAGS.model_dir,
      save_checkpoints_secs=FLAGS.save_checkpoints_secs,
      save_summary_steps=FLAGS.save_summary_steps,
      session_config=tf.ConfigProto(
          allow_soft_placement=True,
          log_device_placement=FLAGS.log_device_placement),
      tpu_config=tpu_config.TPUConfig(
          iterations_per_loop=FLAGS.iterations,
          num_shards=FLAGS.num_shards))

  inception_classifier = tpu_estimator.TPUEstimator(
      model_fn=inception_model_fn,
      use_tpu=FLAGS.use_tpu,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size,
      # Batch axis for the per-shard (not global) training batch size.
      batch_axis=(get_batch_axis(
          FLAGS.train_batch_size // FLAGS.num_shards), 0))

  for cycle in range(FLAGS.train_steps // FLAGS.train_steps_per_eval):
    # Lazy %-args so the message is only formatted when actually logged.
    tf.logging.info('Starting training cycle %d.', cycle)
    inception_classifier.train(
        input_fn=ImageNetInput(True), steps=FLAGS.train_steps_per_eval)

    if FLAGS.eval_enabled:
      # Cover the whole validation split, rounded down to a multiple of
      # the eval batch size.
      eval_steps = (imagenet.get_split_size('validation') //
                    FLAGS.eval_batch_size)
      tf.logging.info('Starting evaluation cycle %d .', cycle)
      eval_results = inception_classifier.evaluate(
          input_fn=ImageNetInput(False), steps=eval_steps)
      tf.logging.info('Evaluation results: %s', eval_results)
def inception_model_fn(features, labels, mode, params):
  """Inception v2 model using Estimator API.

  Args:
    features: Input image batch; layout is adapted via TensorTranspose.
    labels: One-hot class labels (used by the softmax cross-entropy loss
      and the eval accuracy metric).
    mode: A tf.estimator.ModeKeys value.
    params: Unused (TPUEstimator-provided parameters).

  Returns:
    A tf.estimator.EstimatorSpec for PREDICT mode, otherwise a
    tpu_estimator.TPUEstimatorSpec with loss/train_op/eval_metrics.

  Raises:
    RuntimeError: If --optimizer is not one of [sgd, momentum].
  """
  del params  # Unused.
  num_classes = FLAGS.num_classes
  training_active = (mode == tf.estimator.ModeKeys.TRAIN)
  eval_active = (mode == tf.estimator.ModeKeys.EVAL)

  # Per-shard batch size during training; full batch size for eval/predict.
  if training_active:
    size = FLAGS.train_batch_size // FLAGS.num_shards
  else:
    size = FLAGS.eval_batch_size
  input_transform_fn = TensorTranspose(size, is_input=True)
  features = input_transform_fn(features)

  with slim.arg_scope(
      inception.inception_v2_arg_scope(
          use_fused_batchnorm=FLAGS.use_fused_batchnorm)):
    logits, _ = inception.inception_v2(
        features,
        num_classes,
        is_training=training_active,
        depth_multiplier=FLAGS.depth_multiplier,
        replace_separable_convolution=True)

  predictions = {
      'classes': tf.argmax(input=logits, axis=1),
      'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
  }

  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

  prediction_loss = tf.losses.softmax_cross_entropy(
      onehot_labels=labels, logits=logits, weights=1.0, label_smoothing=0.1)
  tf.losses.add_loss(prediction_loss)
  loss = tf.losses.get_total_loss(add_regularization_losses=True)

  # Linear-scaling rule: scale the base learning rate with the global
  # batch size, and floor the decayed rate at 1% of the initial rate.
  initial_learning_rate = FLAGS.learning_rate * FLAGS.train_batch_size / 256
  final_learning_rate = 0.01 * initial_learning_rate

  train_op = None
  if training_active:
    # Multiply the learning rate by _LEARNING_RATE_DECAY every
    # _LEARNING_RATE_DECAY_EPOCHS epochs (staircase schedule).
    training_set_len = imagenet.get_split_size('train')
    batches_per_epoch = training_set_len // FLAGS.train_batch_size
    learning_rate = tf.train.exponential_decay(
        learning_rate=initial_learning_rate,
        global_step=tf.train.get_global_step(),
        decay_steps=_LEARNING_RATE_DECAY_EPOCHS * batches_per_epoch,
        decay_rate=_LEARNING_RATE_DECAY,
        staircase=True)

    # Set a minimum boundary for the learning rate.
    learning_rate = tf.maximum(
        learning_rate, final_learning_rate, name='learning_rate')

    # BUG FIX: the optimizers previously received the constant
    # FLAGS.learning_rate, silently discarding the decay schedule and
    # minimum boundary computed above. Use the decayed tensor instead.
    if FLAGS.optimizer == 'sgd':
      tf.logging.info('Using SGD optimizer')
      optimizer = tf.train.GradientDescentOptimizer(
          learning_rate=learning_rate)
    elif FLAGS.optimizer == 'momentum':
      tf.logging.info('Using Momentum optimizer')
      optimizer = tf.train.MomentumOptimizer(
          learning_rate=learning_rate, momentum=0.9)
    else:
      # tf.logging.fatal does not halt execution; without the raise the
      # code below would fail with an opaque NameError on `optimizer`.
      tf.logging.fatal('Unknown optimizer: %s', FLAGS.optimizer)
      raise RuntimeError('Unknown optimizer: %s' % FLAGS.optimizer)

    if FLAGS.use_tpu:
      optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

    # Run UPDATE_OPS (e.g. batch-norm moving averages) with the train step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(
          loss, global_step=tf.train.get_or_create_global_step())

  eval_metrics = None
  if eval_active:
    def metric_fn(labels, logits):
      """Computes top-1 accuracy from one-hot labels and logits."""
      predictions = tf.argmax(input=logits, axis=1)
      accuracy = tf.metrics.accuracy(
          tf.argmax(input=labels, axis=1), predictions)
      return {'accuracy': accuracy}

    eval_metrics = (metric_fn, [labels, logits])

  return tpu_estimator.TPUEstimatorSpec(
      mode=mode, loss=loss, train_op=train_op, eval_metrics=eval_metrics)