def inception_model_fn(features, labels, mode, params):
  """Inception v4 model using Estimator API."""
  num_classes = FLAGS.num_classes
  is_training = (mode == tf.estimator.ModeKeys.TRAIN)

  features = tensor_transform_fn(features, params['model_transpose_dims'])

  if FLAGS.clear_update_collections:
    with arg_scope(
        inception.inception_v4_arg_scope(
            batch_norm_decay=BATCH_NORM_DECAY,
            batch_norm_epsilon=BATCH_NORM_EPSILON,
            updates_collections=None)):
      logits, end_points = inception.inception_v4(
          features, num_classes, is_training=is_training)
  else:
    with arg_scope(
        inception.inception_v4_arg_scope(
            batch_norm_decay=BATCH_NORM_DECAY,
            batch_norm_epsilon=BATCH_NORM_EPSILON)):
      logits, end_points = inception.inception_v4(
          features, num_classes, is_training=is_training)

  predictions = {
      'classes': tf.argmax(input=logits, axis=1),
      'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
  }

  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

  if mode == tf.estimator.ModeKeys.EVAL and FLAGS.display_tensors and (
      not FLAGS.use_tpu):
    with tf.control_dependencies([
        tf.Print(
            predictions['classes'], [predictions['classes']],
            summarize=FLAGS.eval_batch_size,
            message='prediction: ')
    ]):
      labels = tf.Print(
          labels, [labels], summarize=FLAGS.eval_batch_size, message='label: ')

  one_hot_labels = tf.one_hot(labels, FLAGS.num_classes, dtype=tf.int32)

  if 'AuxLogits' in end_points:
    tf.losses.softmax_cross_entropy(
        onehot_labels=one_hot_labels,
        logits=end_points['AuxLogits'],
        weights=0.4,
        label_smoothing=0.1,
        scope='aux_loss')

  tf.losses.softmax_cross_entropy(
      onehot_labels=one_hot_labels,
      logits=logits,
      weights=1.0,
      label_smoothing=0.1)

  loss = tf.losses.get_total_loss(add_regularization_losses=True)

  initial_learning_rate = FLAGS.learning_rate * FLAGS.train_batch_size / 256
  # Adjust the initial learning rate for warmup.
  initial_learning_rate /= (
      FLAGS.learning_rate_decay**((FLAGS.warmup_epochs + FLAGS.cold_epochs) /
                                  FLAGS.learning_rate_decay_epochs))
  final_learning_rate = 0.0001 * initial_learning_rate

  train_op = None
  if is_training:
    batches_per_epoch = _NUM_TRAIN_IMAGES // FLAGS.train_batch_size
    global_step = tf.train.get_or_create_global_step()
    cur_epoch = tf.cast(
        (tf.cast(global_step, tf.float32) / batches_per_epoch), tf.int32)

    clr = FLAGS.cold_learning_rate
    wlr = initial_learning_rate / (FLAGS.warmup_epochs + FLAGS.cold_epochs)
    learning_rate = tf.where(
        tf.greater_equal(cur_epoch, FLAGS.cold_epochs),
        (tf.where(
            tf.greater_equal(cur_epoch,
                             FLAGS.warmup_epochs + FLAGS.cold_epochs),
            tf.train.exponential_decay(
                learning_rate=initial_learning_rate,
                global_step=global_step,
                decay_steps=FLAGS.learning_rate_decay_epochs *
                batches_per_epoch,
                decay_rate=FLAGS.learning_rate_decay,
                staircase=True),
            tf.multiply(tf.cast(cur_epoch, tf.float32), wlr))), clr)

    # Set a minimum boundary for the learning rate.
    learning_rate = tf.maximum(
        learning_rate, final_learning_rate, name='learning_rate')

    if FLAGS.optimizer == 'sgd':
      tf.logging.info('Using SGD optimizer')
      optimizer = tf.train.GradientDescentOptimizer(
          learning_rate=learning_rate)
    elif FLAGS.optimizer == 'momentum':
      tf.logging.info('Using Momentum optimizer')
      optimizer = tf.train.MomentumOptimizer(
          learning_rate=learning_rate, momentum=0.9)
    elif FLAGS.optimizer == 'RMS':
      tf.logging.info('Using RMS optimizer')
      optimizer = tf.train.RMSPropOptimizer(
          learning_rate,
          RMSPROP_DECAY,
          momentum=RMSPROP_MOMENTUM,
          epsilon=RMSPROP_EPSILON)
    else:
      tf.logging.fatal('Unknown optimizer:', FLAGS.optimizer)

    if FLAGS.use_tpu:
      optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss, global_step=global_step)

    if FLAGS.moving_average:
      ema = tf.train.ExponentialMovingAverage(
          decay=MOVING_AVERAGE_DECAY, num_updates=global_step)
      variables_to_average = (
          tf.trainable_variables() + tf.moving_average_variables())
      with tf.control_dependencies([train_op
                                   ]), tf.name_scope('moving_average'):
        train_op = ema.apply(variables_to_average)

  eval_metrics = None
  if mode == tf.estimator.ModeKeys.EVAL:

    def metric_fn(labels, predictions):
      accuracy = tf.metrics.accuracy(
          labels, tf.argmax(input=predictions, axis=1))
      return {'accuracy': accuracy}

    if FLAGS.use_logits:
      eval_predictions = logits
    else:
      eval_predictions = end_points['Predictions']

    eval_metrics = (metric_fn, [labels, eval_predictions])

  return tpu_estimator.TPUEstimatorSpec(
      mode=mode, loss=loss, train_op=train_op, eval_metrics=eval_metrics)
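# The nested tf.where above implements a three-phase schedule: a constant
# "cold" rate for the first FLAGS.cold_epochs, a linear warmup for the next
# FLAGS.warmup_epochs, then staircase exponential decay with a floor applied
# by tf.maximum. A minimal plain-Python sketch of that schedule follows; the
# numeric defaults are purely illustrative, not the script's flag defaults,
# and the pre-division of initial_learning_rate above (which offsets the fact
# that exponential_decay counts from global step 0) is omitted here.

def schedule_sketch(epoch,
                    cold_epochs=2,
                    warmup_epochs=5,
                    cold_lr=0.01,
                    initial_lr=0.165,
                    decay=0.94,
                    decay_epochs=3,
                    min_lr_factor=0.0001):
  """Returns the per-epoch learning rate under the cold/warmup/decay plan."""
  warmup_lr_step = initial_lr / (warmup_epochs + cold_epochs)
  if epoch < cold_epochs:
    lr = cold_lr                                        # constant cold phase
  elif epoch < cold_epochs + warmup_epochs:
    lr = epoch * warmup_lr_step                         # linear warmup
  else:
    lr = initial_lr * decay**(epoch // decay_epochs)    # staircase decay
  return max(lr, min_lr_factor * initial_lr)            # same floor as tf.maximum

# e.g. schedule_sketch(0) == 0.01, schedule_sketch(4) == 4 * 0.165 / 7, and
# from epoch 7 onward the rate decays by a factor of 0.94 every 3 epochs.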
def inception_model_fn(features, labels, mode, params):
  """Inception v4 model using Estimator API."""
  num_classes = FLAGS.num_classes
  is_training = (mode == tf.estimator.ModeKeys.TRAIN)
  is_eval = (mode == tf.estimator.ModeKeys.EVAL)

  if isinstance(features, dict):
    features = features['feature']

  features = tensor_transform_fn(features, params['model_transpose_dims'])

  # This nested function allows us to avoid duplicating the logic which
  # builds the network, for different values of --precision.
  def build_network():
    if FLAGS.precision == 'bfloat16':
      with tf.contrib.tpu.bfloat16_scope():
        logits, end_points = inception.inception_v4(
            features, num_classes, is_training=is_training)
        logits = tf.cast(logits, tf.float32)
    elif FLAGS.precision == 'float32':
      logits, end_points = inception.inception_v4(
          features, num_classes, is_training=is_training)
    return logits, end_points

  if FLAGS.clear_update_collections:
    with arg_scope(
        inception.inception_v4_arg_scope(
            weight_decay=0.0,
            batch_norm_decay=BATCH_NORM_DECAY,
            batch_norm_epsilon=BATCH_NORM_EPSILON,
            updates_collections=None)):
      logits, end_points = build_network()
  else:
    with arg_scope(
        inception.inception_v4_arg_scope(
            batch_norm_decay=BATCH_NORM_DECAY,
            batch_norm_epsilon=BATCH_NORM_EPSILON)):
      logits, end_points = build_network()

  predictions = {
      'classes': tf.argmax(input=logits, axis=1),
      'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
  }

  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        export_outputs={
            'classify': tf.estimator.export.PredictOutput(predictions)
        })

  if mode == tf.estimator.ModeKeys.EVAL and FLAGS.display_tensors and (
      not FLAGS.use_tpu):
    with tf.control_dependencies([
        tf.Print(
            predictions['classes'], [predictions['classes']],
            summarize=FLAGS.eval_batch_size,
            message='prediction: ')
    ]):
      labels = tf.Print(
          labels, [labels], summarize=FLAGS.eval_batch_size, message='label: ')

  one_hot_labels = tf.one_hot(labels, FLAGS.num_classes, dtype=tf.int32)

  if 'AuxLogits' in end_points:
    tf.compat.v1.losses.softmax_cross_entropy(
        onehot_labels=one_hot_labels,
        logits=tf.cast(end_points['AuxLogits'], tf.float32),
        weights=0.4,
        label_smoothing=0.1,
        scope='aux_loss')

  tf.compat.v1.losses.softmax_cross_entropy(
      onehot_labels=one_hot_labels,
      logits=logits,
      weights=1.0,
      label_smoothing=0.1)

  losses = tf.add_n(tf.losses.get_losses())
  l2_loss = []
  for v in tf.trainable_variables():
    tf.logging.info(v.name)
    if 'BatchNorm' not in v.name and 'weights' in v.name:
      l2_loss.append(tf.nn.l2_loss(v))
  tf.logging.info(len(l2_loss))
  loss = losses + WEIGHT_DECAY * tf.add_n(l2_loss)

  initial_learning_rate = FLAGS.learning_rate * FLAGS.train_batch_size / 256
  # Adjust the initial learning rate for warmup.
  initial_learning_rate /= (
      FLAGS.learning_rate_decay**((FLAGS.warmup_epochs + FLAGS.cold_epochs) /
                                  FLAGS.learning_rate_decay_epochs))
  final_learning_rate = 0.0001 * initial_learning_rate

  host_call = None
  train_op = None
  if is_training:
    batches_per_epoch = _NUM_TRAIN_IMAGES / FLAGS.train_batch_size
    global_step = tf.compat.v1.train.get_or_create_global_step()
    current_epoch = tf.cast(
        (tf.cast(global_step, tf.float32) / batches_per_epoch), tf.int32)

    clr = FLAGS.cold_learning_rate
    wlr = initial_learning_rate / (FLAGS.warmup_epochs + FLAGS.cold_epochs)
    learning_rate = tf.where(
        tf.greater_equal(current_epoch, FLAGS.cold_epochs),
        (tf.where(
            tf.greater_equal(current_epoch,
                             FLAGS.warmup_epochs + FLAGS.cold_epochs),
            tf.compat.v1.train.exponential_decay(
                learning_rate=initial_learning_rate,
                global_step=global_step,
                decay_steps=int(
                    FLAGS.learning_rate_decay_epochs * batches_per_epoch),
                decay_rate=FLAGS.learning_rate_decay,
                staircase=True),
            tf.multiply(tf.cast(current_epoch, tf.float32), wlr))), clr)

    # Set a minimum boundary for the learning rate.
    learning_rate = tf.maximum(
        learning_rate, final_learning_rate, name='learning_rate')

    if FLAGS.optimizer == 'sgd':
      tf.logging.info('Using SGD optimizer')
      optimizer = tf.train.GradientDescentOptimizer(
          learning_rate=learning_rate)
    elif FLAGS.optimizer == 'momentum':
      tf.logging.info('Using Momentum optimizer')
      optimizer = tf.train.MomentumOptimizer(
          learning_rate=learning_rate, momentum=0.9)
    elif FLAGS.optimizer == 'RMS':
      tf.logging.info('Using RMS optimizer')
      optimizer = tf.compat.v1.train.RMSPropOptimizer(
          learning_rate,
          RMSPROP_DECAY,
          momentum=RMSPROP_MOMENTUM,
          epsilon=RMSPROP_EPSILON)
    else:
      tf.logging.fatal('Unknown optimizer:', FLAGS.optimizer)

    if FLAGS.use_tpu:
      optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss, global_step=global_step)

    if FLAGS.moving_average:
      ema = tf.train.ExponentialMovingAverage(
          decay=MOVING_AVERAGE_DECAY, num_updates=global_step)
      variables_to_average = (
          tf.trainable_variables() + tf.moving_average_variables())
      with tf.control_dependencies([train_op
                                   ]), tf.name_scope('moving_average'):
        train_op = ema.apply(variables_to_average)

    # To log the loss, current learning rate, and epoch for Tensorboard, the
    # summary op needs to be run on the host CPU via host_call. host_call
    # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
    # dimension. These Tensors are implicitly concatenated to
    # [params['batch_size']].
    gs_t = tf.reshape(global_step, [1])
    loss_t = tf.reshape(loss, [1])
    lr_t = tf.reshape(learning_rate, [1])
    ce_t = tf.reshape(current_epoch, [1])

    if not FLAGS.skip_host_call:

      def host_call_fn(gs, loss, lr, ce):
        """Training host call. Creates scalar summaries for training metrics.

        This function is executed on the CPU and should not directly reference
        any Tensors in the rest of the `model_fn`. To pass Tensors from the
        model to the `metric_fn`, provide as part of the `host_call`. See
        https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
        for more information.

        Arguments should match the list of `Tensor` objects passed as the
        second element in the tuple passed to `host_call`.

        Args:
          gs: `Tensor` with shape `[batch]` for the global_step.
          loss: `Tensor` with shape `[batch]` for the training loss.
          lr: `Tensor` with shape `[batch]` for the learning_rate.
          ce: `Tensor` with shape `[batch]` for the current_epoch.

        Returns:
          List of summary ops to run on the CPU host.
        """
        gs = gs[0]
        with summary.create_file_writer(FLAGS.model_dir).as_default():
          with summary.always_record_summaries():
            summary.scalar('loss', tf.reduce_mean(loss), step=gs)
            summary.scalar('learning_rate', tf.reduce_mean(lr), step=gs)
            summary.scalar('current_epoch', tf.reduce_mean(ce), step=gs)

            return summary.all_summary_ops()

      host_call = (host_call_fn, [gs_t, loss_t, lr_t, ce_t])

  eval_metrics = None
  if is_eval:

    def metric_fn(labels, logits):
      """Evaluation metric function. Evaluates accuracy.

      This function is executed on the CPU and should not directly reference
      any Tensors in the rest of the `model_fn`. To pass Tensors from the
      model to the `metric_fn`, provide as part of the `eval_metrics`. See
      https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
      for more information.

      Arguments should match the list of `Tensor` objects passed as the second
      element in the tuple passed to `eval_metrics`.

      Args:
        labels: `Tensor` with shape `[batch, ]`.
        logits: `Tensor` with shape `[batch, num_classes]`.

      Returns:
        A dict of the metrics to return from evaluation.
      """
      predictions = tf.argmax(logits, axis=1)
      top_1_accuracy = tf.metrics.accuracy(labels, predictions)
      in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
      top_5_accuracy = tf.metrics.mean(in_top_5)

      return {
          'accuracy': top_1_accuracy,
          'accuracy@5': top_5_accuracy,
      }

    eval_metrics = (metric_fn, [labels, logits])

  return tf.contrib.tpu.TPUEstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      host_call=host_call,
      eval_metrics=eval_metrics)
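# A minimal sketch of how a model_fn like the one above is typically handed to
# the TF 1.x contrib TPUEstimator. The TPU name, model_dir, batch sizes, and
# step counts below are illustrative placeholders, not the script's flag
# defaults; the full script also passes params (e.g. 'model_transpose_dims')
# consumed by tensor_transform_fn, which is omitted here.

import tensorflow as tf

def make_tpu_estimator(model_fn, train_input_fn, use_tpu=True):
  """Wires a model_fn into TPUEstimator; all concrete values are illustrative."""
  tpu_cluster_resolver = (
      tf.contrib.cluster_resolver.TPUClusterResolver(tpu='my-tpu')  # hypothetical TPU name
      if use_tpu else None)
  run_config = tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir='/tmp/inception_v4',                 # hypothetical model_dir
      tpu_config=tf.contrib.tpu.TPUConfig(
          iterations_per_loop=100,                   # illustrative
          num_shards=8))                             # illustrative
  estimator = tf.contrib.tpu.TPUEstimator(
      model_fn=model_fn,
      use_tpu=use_tpu,
      config=run_config,
      train_batch_size=1024,                         # illustrative
      eval_batch_size=1024)                          # illustrative
  # host_call and eval_metrics from the TPUEstimatorSpec are invoked on the
  # host CPU during train() and evaluate() respectively.
  estimator.train(input_fn=train_input_fn, max_steps=1000)  # illustrative
  return estimator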
def inception_model_fn(features, labels, mode, params):
  """Inception v4 model using Estimator API."""
  num_classes = FLAGS.num_classes
  is_training = (mode == tf.estimator.ModeKeys.TRAIN)
  is_eval = (mode == tf.estimator.ModeKeys.EVAL)

  if isinstance(features, dict):
    features = features['feature']

  # features = tensor_transform_fn(features, params['model_transpose_dims'])

  # This nested function allows us to avoid duplicating the logic which
  # builds the network, for different values of --precision.
  def build_network():
    if FLAGS.precision == 'bfloat16':
      with tf.contrib.tpu.bfloat16_scope():
        logits, end_points = inception.inception_v4(
            features, num_classes, is_training=is_training)
        logits = tf.cast(logits, tf.float32)
    elif FLAGS.precision == 'float32':
      logits, end_points = inception.inception_v4(
          features, num_classes, is_training=is_training)
    return logits, end_points

  if FLAGS.clear_update_collections:
    with arg_scope(
        inception.inception_v4_arg_scope(
            weight_decay=0.0,
            batch_norm_decay=BATCH_NORM_DECAY,
            batch_norm_epsilon=BATCH_NORM_EPSILON,
            updates_collections=None)):
      logits, end_points = build_network()
  else:
    with arg_scope(
        inception.inception_v4_arg_scope(
            batch_norm_decay=BATCH_NORM_DECAY,
            batch_norm_epsilon=BATCH_NORM_EPSILON)):
      logits, end_points = build_network()

  predictions = {
      'classes': tf.argmax(input=logits, axis=1),
      'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
  }

  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        export_outputs={
            'classify': tf.estimator.export.PredictOutput(predictions)
        })

  if mode == tf.estimator.ModeKeys.EVAL and FLAGS.display_tensors:
    with tf.control_dependencies([
        tf.Print(
            predictions['classes'], [predictions['classes']],
            summarize=FLAGS.eval_batch_size,
            message='prediction: ')
    ]):
      labels = tf.Print(
          labels, [labels], summarize=FLAGS.eval_batch_size, message='label: ')

  one_hot_labels = tf.one_hot(labels, FLAGS.num_classes, dtype=tf.int32)

  if 'AuxLogits' in end_points:
    tf.losses.softmax_cross_entropy(
        onehot_labels=one_hot_labels,
        logits=tf.cast(end_points['AuxLogits'], tf.float32),
        weights=0.4,
        label_smoothing=0.1,
        scope='aux_loss')

  tf.losses.softmax_cross_entropy(
      onehot_labels=one_hot_labels,
      logits=logits,
      weights=1.0,
      label_smoothing=0.1)

  losses = tf.add_n(tf.losses.get_losses())
  l2_loss = []
  for v in tf.trainable_variables():
    tf.logging.info(v.name)
    if 'BatchNorm' not in v.name and 'weights' in v.name:
      l2_loss.append(tf.nn.l2_loss(v))
  tf.logging.info(len(l2_loss))
  loss = losses + WEIGHT_DECAY * tf.add_n(l2_loss)

  initial_learning_rate = FLAGS.learning_rate * FLAGS.train_batch_size / 256
  # Adjust the initial learning rate for warmup.
  initial_learning_rate /= (
      FLAGS.learning_rate_decay**((FLAGS.warmup_epochs + FLAGS.cold_epochs) /
                                  FLAGS.learning_rate_decay_epochs))
  final_learning_rate = 0.0001 * initial_learning_rate

  train_op = None
  if is_training:
    batches_per_epoch = _NUM_TRAIN_IMAGES / FLAGS.train_batch_size
    global_step = tf.train.get_or_create_global_step()
    current_epoch = tf.cast(
        (tf.cast(global_step, tf.float32) / batches_per_epoch), tf.int32)

    clr = FLAGS.cold_learning_rate
    wlr = initial_learning_rate / (FLAGS.warmup_epochs + FLAGS.cold_epochs)
    learning_rate = tf.where(
        tf.greater_equal(current_epoch, FLAGS.cold_epochs),
        (tf.where(
            tf.greater_equal(current_epoch,
                             FLAGS.warmup_epochs + FLAGS.cold_epochs),
            tf.train.exponential_decay(
                learning_rate=initial_learning_rate,
                global_step=global_step,
                decay_steps=int(
                    FLAGS.learning_rate_decay_epochs * batches_per_epoch),
                decay_rate=FLAGS.learning_rate_decay,
                staircase=True),
            tf.multiply(tf.cast(current_epoch, tf.float32), wlr))), clr)

    # Set a minimum boundary for the learning rate.
    learning_rate = tf.maximum(
        learning_rate, final_learning_rate, name='learning_rate')

    if FLAGS.optimizer == 'sgd':
      tf.logging.info('Using SGD optimizer')
      optimizer = tf.train.GradientDescentOptimizer(
          learning_rate=learning_rate)
    elif FLAGS.optimizer == 'momentum':
      tf.logging.info('Using Momentum optimizer')
      optimizer = tf.train.MomentumOptimizer(
          learning_rate=learning_rate, momentum=0.9)
    elif FLAGS.optimizer == 'RMS':
      tf.logging.info('Using RMS optimizer')
      optimizer = tf.train.RMSPropOptimizer(
          learning_rate,
          RMSPROP_DECAY,
          momentum=RMSPROP_MOMENTUM,
          epsilon=RMSPROP_EPSILON)
    else:
      tf.logging.fatal('Unknown optimizer:', FLAGS.optimizer)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss, global_step=global_step)

    if FLAGS.moving_average:
      ema = tf.train.ExponentialMovingAverage(
          decay=MOVING_AVERAGE_DECAY, num_updates=global_step)
      variables_to_average = (
          tf.trainable_variables() + tf.moving_average_variables())
      with tf.control_dependencies([train_op
                                   ]), tf.name_scope('moving_average'):
        train_op = ema.apply(variables_to_average)

  return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
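# This variant returns a plain tf.estimator.EstimatorSpec and never wraps the
# optimizer in CrossShardOptimizer, so it can be driven by the standard
# Estimator on GPU/CPU. A minimal wiring sketch follows; the model_dir, step
# counts, and params dict are illustrative placeholders, not the script's
# actual flags, and evaluate() reports only the loss because this variant
# defines no eval_metric_ops.

import tensorflow as tf

def run_standard_estimator(model_fn, train_input_fn, eval_input_fn):
  """Trains and evaluates with the non-TPU Estimator; values are illustrative."""
  run_config = tf.estimator.RunConfig(
      model_dir='/tmp/inception_v4_gpu',     # hypothetical model_dir
      save_checkpoints_steps=1000)           # illustrative
  estimator = tf.estimator.Estimator(
      model_fn=model_fn,
      config=run_config,
      params={'model_transpose_dims': None})  # placeholder; the transpose call is commented out above
  estimator.train(input_fn=train_input_fn, max_steps=10000)     # illustrative
  return estimator.evaluate(input_fn=eval_input_fn, steps=100)  # illustrative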