def model_fn(features, labels, mode, params):
  """Mobilenet v1 model using Estimator API.

  Builds the Mobilenet v1 graph inside a bfloat16 scope, casts the outputs
  back to float32, and wires up the prediction, loss, training and evaluation
  pieces required for the given `mode`.

  Args:
    features: Input batch. It is passed through `tensor_transform_fn` together
      with `params['input_perm']` before being fed to the network.
    labels: Labels for the batch. They are handed to `tf.one_hot`, so they are
      presumably integer class indices -- TODO confirm against the input
      pipeline.
    mode: One of the `tf.estimator.ModeKeys` values.
    params: Dict of parameters. Only `params['input_perm']` is read here; all
      other settings come from `FLAGS`.

  Returns:
    A `tf.estimator.EstimatorSpec` carrying only predictions in PREDICT mode,
    otherwise a `tpu_estimator.TPUEstimatorSpec` with the loss, the train op
    (TRAIN mode only, else None) and the eval metrics (EVAL mode only, else
    None).

  Raises:
    ValueError: If training is requested and `FLAGS.optimizer` is not one of
      'sgd', 'momentum' or 'RMS'.
  """
  num_classes = FLAGS.num_classes
  training_active = (mode == tf.estimator.ModeKeys.TRAIN)
  eval_active = (mode == tf.estimator.ModeKeys.EVAL)

  features = tensor_transform_fn(features, params['input_perm'])

  with bfloat16.bfloat16_scope():
    # NOTE(review): this code used to branch on FLAGS.clear_update_collections
    # with a comment saying updates_collections must be set to None in order
    # to use fused batchnorm, but both branches built exactly the same graph.
    # The branches are merged here; the flag has no effect in this function.
    with arg_scope(mobilenet_v1.mobilenet_v1_arg_scope()):
      logits, end_points = mobilenet_v1.mobilenet_v1(
          features,
          num_classes,
          is_training=training_active,
          depth_multiplier=FLAGS.depth_multiplier)

  # The network was built under the bfloat16 scope; everything downstream
  # (softmax, loss, metrics) works on float32 tensors.
  logits = tf.cast(logits, tf.float32)
  for name in list(end_points):
    end_points[name] = tf.cast(end_points[name], tf.float32)

  predictions = {
      'classes': tf.argmax(input=logits, axis=1),
      'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
  }

  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

  if eval_active and FLAGS.display_tensors and not FLAGS.use_tpu:
    # Debug aid (skipped when running on TPU): print predicted classes and
    # labels. The prediction print is attached as a control dependency of the
    # label print so that both run whenever `labels` is evaluated.
    print_predictions = tf.Print(
        predictions['classes'], [predictions['classes']],
        summarize=FLAGS.eval_batch_size,
        message='prediction: ')
    with tf.control_dependencies([print_predictions]):
      labels = tf.Print(
          labels, [labels],
          summarize=FLAGS.eval_batch_size,
          message='label: ')

  one_hot_labels = tf.one_hot(labels, FLAGS.num_classes, dtype=tf.int32)

  loss = tf.losses.softmax_cross_entropy(
      onehot_labels=one_hot_labels,
      logits=logits,
      weights=1.0,
      label_smoothing=0.1)
  # Explicit L2 weight decay over every trainable variable whose name does not
  # contain 'batch_normalization'.
  loss += WEIGHT_DECAY * tf.add_n([
      tf.nn.l2_loss(v)
      for v in tf.trainable_variables()
      if 'batch_normalization' not in v.name
  ])

  # The base learning rate is scaled linearly with the batch size (relative
  # to a batch of 256); the decayed rate is never allowed below 1e-4 of it.
  initial_learning_rate = FLAGS.learning_rate * FLAGS.train_batch_size / 256
  final_learning_rate = 0.0001 * initial_learning_rate

  train_op = None
  if training_active:
    batches_per_epoch = _NUM_TRAIN_IMAGES // FLAGS.train_batch_size
    global_step = tf.train.get_or_create_global_step()
    learning_rate = tf.train.exponential_decay(
        learning_rate=initial_learning_rate,
        global_step=global_step,
        decay_steps=FLAGS.learning_rate_decay_epochs * batches_per_epoch,
        decay_rate=FLAGS.learning_rate_decay,
        staircase=True)

    # Set a minimum boundary for the learning rate.
    learning_rate = tf.maximum(
        learning_rate, final_learning_rate, name='learning_rate')

    if FLAGS.optimizer == 'sgd':
      tf.logging.info('Using SGD optimizer')
      optimizer = tf.train.GradientDescentOptimizer(
          learning_rate=learning_rate)
    elif FLAGS.optimizer == 'momentum':
      tf.logging.info('Using Momentum optimizer')
      optimizer = tf.train.MomentumOptimizer(
          learning_rate=learning_rate, momentum=0.9)
    elif FLAGS.optimizer == 'RMS':
      tf.logging.info('Using RMS optimizer')
      optimizer = tf.train.RMSPropOptimizer(
          learning_rate,
          RMSPROP_DECAY,
          momentum=RMSPROP_MOMENTUM,
          epsilon=RMSPROP_EPSILON)
    else:
      # The message needs a %s placeholder for its argument, and we must not
      # fall through: `optimizer` would be unbound below.
      tf.logging.fatal('Unknown optimizer: %s', FLAGS.optimizer)
      raise ValueError('Unknown optimizer: %s' % FLAGS.optimizer)

    if FLAGS.use_tpu:
      optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

    # Run the ops registered in UPDATE_OPS before each optimizer step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss, global_step=global_step)

    if FLAGS.moving_average:
      ema = tf.train.ExponentialMovingAverage(
          decay=MOVING_AVERAGE_DECAY, num_updates=global_step)
      variables_to_average = (
          tf.trainable_variables() + tf.moving_average_variables())
      # The averages are updated only after the optimizer step has run.
      with tf.control_dependencies([train_op]), tf.name_scope(
          'moving_average'):
        train_op = ema.apply(variables_to_average)

  eval_metrics = None
  if eval_active:

    def metric_fn(labels, predictions):
      """Returns a dict with the accuracy of argmax(predictions) vs labels."""
      accuracy = tf.metrics.accuracy(
          labels, tf.argmax(input=predictions, axis=1))
      return {'accuracy': accuracy}

    if FLAGS.use_logits:
      eval_predictions = logits
    else:
      eval_predictions = end_points['Predictions']

    eval_metrics = (metric_fn, [labels, eval_predictions])

  return tpu_estimator.TPUEstimatorSpec(
      mode=mode, loss=loss, train_op=train_op, eval_metrics=eval_metrics)
def model_fn(features, labels, mode, params):
  """Mobilenet v1 model using Estimator API.

  Builds the Mobilenet v1 graph and wires up the prediction, loss, training
  and evaluation pieces required for the given `mode`. Apart from
  `FLAGS.display_tensors`, every setting is read from `params`.

  Args:
    features: Input batch, or a dict holding the batch under the key
      'feature'. It is passed through `supervised_images.tensor_transform_fn`
      together with `params['input_perm']` before being fed to the network.
    labels: Labels for the batch. They are handed to `tf.one_hot`, so they are
      presumably integer class indices -- TODO confirm against the input
      pipeline.
    mode: One of the `tf.estimator.ModeKeys` values.
    params: Dict of parameters. Keys read here: 'num_classes', 'input_perm',
      'depth_multiplier', 'use_tpu', 'eval_batch_size', 'learning_rate',
      'train_batch_size', 'num_train_images', 'learning_rate_decay_epochs',
      'learning_rate_decay', 'optimizer', 'moving_average' and 'use_logits'.

  Returns:
    A `tf.estimator.EstimatorSpec` with predictions and a 'classify' export
    output in PREDICT mode, otherwise a `tf.contrib.tpu.TPUEstimatorSpec` with
    the loss, the train op (TRAIN mode only, else None) and the eval metrics
    (EVAL mode only, else None).

  Raises:
    ValueError: If training is requested and `params['optimizer']` is not one
      of 'sgd', 'momentum' or 'RMS'.
  """
  num_classes = params['num_classes']
  training_active = (mode == tf.estimator.ModeKeys.TRAIN)
  eval_active = (mode == tf.estimator.ModeKeys.EVAL)

  if isinstance(features, dict):
    features = features['feature']

  features = supervised_images.tensor_transform_fn(features,
                                                   params['input_perm'])

  # NOTE(review): this code used to branch on
  # params['clear_update_collections'] with a comment saying
  # updates_collections must be set to None in order to use fused batchnorm,
  # but both branches built exactly the same graph. The branches are merged
  # here; the parameter has no effect in this function.
  with arg_scope(mobilenet_v1.mobilenet_v1_arg_scope()):
    logits, end_points = mobilenet_v1.mobilenet_v1(
        features,
        num_classes,
        is_training=training_active,
        depth_multiplier=params['depth_multiplier'])

  predictions = {
      'classes': tf.argmax(input=logits, axis=1),
      'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
  }

  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        export_outputs={
            'classify': tf.estimator.export.PredictOutput(predictions)
        })

  if eval_active and FLAGS.display_tensors and not params['use_tpu']:
    # Debug aid (skipped when running on TPU): print predicted classes and
    # labels. The prediction print is attached as a control dependency of the
    # label print so that both run whenever `labels` is evaluated.
    print_predictions = tf.Print(
        predictions['classes'], [predictions['classes']],
        summarize=params['eval_batch_size'],
        message='prediction: ')
    with tf.control_dependencies([print_predictions]):
      labels = tf.Print(
          labels, [labels],
          summarize=params['eval_batch_size'],
          message='label: ')

  one_hot_labels = tf.one_hot(labels, params['num_classes'], dtype=tf.int32)

  # The return value is intentionally unused: the total loss (including
  # regularization losses) is fetched from the losses collection just below.
  tf.losses.softmax_cross_entropy(
      onehot_labels=one_hot_labels,
      logits=logits,
      weights=1.0,
      label_smoothing=0.1)
  loss = tf.losses.get_total_loss(add_regularization_losses=True)

  # The base learning rate is scaled linearly with the batch size (relative
  # to a batch of 256); the decayed rate is never allowed below 1e-4 of it.
  initial_learning_rate = (
      params['learning_rate'] * params['train_batch_size'] / 256)
  final_learning_rate = 0.0001 * initial_learning_rate

  train_op = None
  if training_active:
    batches_per_epoch = (
        params['num_train_images'] // params['train_batch_size'])
    global_step = tf.train.get_or_create_global_step()
    learning_rate = tf.train.exponential_decay(
        learning_rate=initial_learning_rate,
        global_step=global_step,
        decay_steps=params['learning_rate_decay_epochs'] * batches_per_epoch,
        decay_rate=params['learning_rate_decay'],
        staircase=True)

    # Set a minimum boundary for the learning rate.
    learning_rate = tf.maximum(
        learning_rate, final_learning_rate, name='learning_rate')

    if params['optimizer'] == 'sgd':
      tf.logging.info('Using SGD optimizer')
      optimizer = tf.train.GradientDescentOptimizer(
          learning_rate=learning_rate)
    elif params['optimizer'] == 'momentum':
      tf.logging.info('Using Momentum optimizer')
      optimizer = tf.train.MomentumOptimizer(
          learning_rate=learning_rate, momentum=0.9)
    elif params['optimizer'] == 'RMS':
      tf.logging.info('Using RMS optimizer')
      optimizer = tf.train.RMSPropOptimizer(
          learning_rate,
          RMSPROP_DECAY,
          momentum=RMSPROP_MOMENTUM,
          epsilon=RMSPROP_EPSILON)
    else:
      # The message needs a %s placeholder for its argument, and we must not
      # fall through: `optimizer` would be unbound below.
      tf.logging.fatal('Unknown optimizer: %s', params['optimizer'])
      raise ValueError('Unknown optimizer: %s' % params['optimizer'])

    if params['use_tpu']:
      optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

    # Run the ops registered in UPDATE_OPS before each optimizer step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss, global_step=global_step)

    if params['moving_average']:
      ema = tf.train.ExponentialMovingAverage(
          decay=MOVING_AVERAGE_DECAY, num_updates=global_step)
      variables_to_average = (
          tf.trainable_variables() + tf.moving_average_variables())
      # The averages are updated only after the optimizer step has run.
      with tf.control_dependencies([train_op]), tf.name_scope(
          'moving_average'):
        train_op = ema.apply(variables_to_average)

  eval_metrics = None
  if eval_active:

    def metric_fn(labels, predictions):
      """Returns a dict with the accuracy of argmax(predictions) vs labels."""
      accuracy = tf.metrics.accuracy(
          labels, tf.argmax(input=predictions, axis=1))
      return {'accuracy': accuracy}

    if params['use_logits']:
      eval_predictions = logits
    else:
      eval_predictions = end_points['Predictions']

    eval_metrics = (metric_fn, [labels, eval_predictions])

  return tf.contrib.tpu.TPUEstimatorSpec(
      mode=mode, loss=loss, train_op=train_op, eval_metrics=eval_metrics)