def resnet_model_fn(features, labels, mode, model_class,
                    resnet_size, weight_decay, learning_rate_fn, momentum,
                    data_format, resnet_version, loss_scale,
                    loss_filter_fn=None,
                    model_type=resnet_model.DEFAULT_MODEL_TYPE,
                    dtype=resnet_model.DEFAULT_DTYPE,
                    fine_tune=False, label_smoothing=0.0):
  """Shared functionality for different resnet model_fns.

  Initializes the ResnetModel representing the model layers and uses that
  model to build the necessary EstimatorSpecs for the `mode` in question.
  For training, this means building losses, the optimizer, and the train op
  that get passed into the EstimatorSpec. For evaluation and prediction, the
  EstimatorSpec is returned without a train op, but with the necessary
  parameters for the given mode.

  Args:
    features: tensor representing input images
    labels: tensor representing class labels for all input images
    mode: current estimator mode; should be one of
      `tf.estimator.ModeKeys.TRAIN`, `EVAL`, `PREDICT`
    model_class: a class representing a TensorFlow model that has a __call__
      function. We assume here that this is a subclass of ResnetModel.
    resnet_size: A single integer for the size of the ResNet model.
    weight_decay: weight decay loss rate used to regularize learned variables.
    learning_rate_fn: function that returns the current learning rate given
      the current global_step
    momentum: momentum term used for optimization
    data_format: Input format ('channels_last', 'channels_first', or None).
      If set to None, the format is dependent on whether a GPU is available.
    resnet_version: Integer representing which version of the ResNet network
      to use. See README for details. Valid values: [1, 2]
    loss_scale: The factor to scale the loss for numerical stability. A
      detailed summary is present in the arg parser help text.
    loss_filter_fn: function that takes a string variable name and returns
      True if the var should be included in loss calculation, and False
      otherwise. If None, batch_normalization variables will be excluded
      from the loss.
    model_type: which variant of the model to build; defaults to
      resnet_model.DEFAULT_MODEL_TYPE.
    dtype: the TensorFlow dtype to use for calculations.
    fine_tune: If True, only train the dense (final) layers.
    label_smoothing: If greater than 0, smooth the labels.

  Returns:
    EstimatorSpec parameterized according to the input params and the
    current mode.
  """
  # Uncomment the following lines if you want to write images to the summary;
  # it is turned off for performance reasons.
  # Generate a summary node for the images
  # tf.compat.v1.summary.image(
  #     'images',
  #     (features, tf.cast(features, tf.float32))[features.dtype == tf.bfloat16],
  #     max_outputs=6)

  if features.dtype != tf.bfloat16:
    # Check that features/images have the same data type as the one used for
    # calculations.
    assert features.dtype == dtype

  model = model_class(resnet_size, data_format, resnet_version=resnet_version,
                      model_type=model_type, dtype=dtype)

  logits = model(features, mode == tf.estimator.ModeKeys.TRAIN)

  # This acts as a no-op if the logits are already in fp32 (provided logits are
  # not a SparseTensor). If dtype is low precision, logits must be cast to
  # fp32 for numerical stability.
  logits = tf.cast(logits, tf.float32)

  if flags.FLAGS.is_mlperf_enabled:
    num_examples_metric = tf_mlperf_log.sum_metric(
        tensor=tf.shape(input=logits)[0], name=_NUM_EXAMPLES_NAME)

  predictions = {
      'classes': tf.argmax(input=logits, axis=1),
      'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
  }

  if mode == tf.estimator.ModeKeys.PREDICT:
    # Return the predictions and the specification for serving a SavedModel
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        export_outputs={
            'predict': tf.estimator.export.PredictOutput(predictions)
        })

  # Calculate loss, which includes softmax cross entropy and L2 regularization.
  labels = tf.cast(labels, tf.int32)
  if label_smoothing != 0.0:
    one_hot_labels = tf.one_hot(labels, 1001)
    cross_entropy = tf.compat.v1.losses.softmax_cross_entropy(
        logits=logits, onehot_labels=one_hot_labels,
        label_smoothing=label_smoothing)
  else:
    cross_entropy = tf.compat.v1.losses.sparse_softmax_cross_entropy(
        logits=logits, labels=labels)

  # Create a tensor named cross_entropy for logging purposes.
  tf.identity(cross_entropy, name='cross_entropy')
  tf.compat.v1.summary.scalar('cross_entropy', cross_entropy)

  # If no loss_filter_fn is passed, assume we want the default behavior,
  # which is that batch_normalization variables are excluded from loss.
  def exclude_batch_norm(name):
    return 'batch_normalization' not in name
  loss_filter_fn = loss_filter_fn or exclude_batch_norm

  # Add weight decay to the loss.
  l2_loss = weight_decay * tf.add_n(
      # loss is computed using fp32 for numerical stability.
      [
          tf.nn.l2_loss(tf.cast(v, tf.float32))
          for v in tf.compat.v1.trainable_variables()
          if loss_filter_fn(v.name)
      ])
  tf.compat.v1.summary.scalar('l2_loss', l2_loss)
  loss = cross_entropy + l2_loss

  if mode == tf.estimator.ModeKeys.TRAIN:
    global_step = tf.compat.v1.train.get_or_create_global_step()
    learning_rate = learning_rate_fn(global_step)

    # Create a tensor named learning_rate for logging purposes
    tf.identity(learning_rate, name='learning_rate')
    tf.compat.v1.summary.scalar('learning_rate', learning_rate)

    if flags.FLAGS.enable_lars:
      tf.compat.v1.logging.info('Using LARS Optimizer.')
      optimizer = lars.LARSOptimizer(
          learning_rate,
          momentum=momentum,
          weight_decay=weight_decay,
          skip_list=['batch_normalization', 'bias'])
      if flags.FLAGS.is_mlperf_enabled:
        mllogger.event(key=mllog.constants.OPT_NAME, value=mllog.constants.LARS)
        mllogger.event(key=mllog.constants.LARS_EPSILON, value=0.0)
        mllogger.event(key=mllog.constants.LARS_OPT_WEIGHT_DECAY,
                       value=weight_decay)
    else:
      optimizer = tf.compat.v1.train.MomentumOptimizer(
          learning_rate=learning_rate,
          momentum=momentum)

    fp16_implementation = getattr(flags.FLAGS, 'fp16_implementation', None)
    if fp16_implementation == 'graph_rewrite':
      optimizer = (
          tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite(
              optimizer, loss_scale=loss_scale))

    if horovod_enabled():
      optimizer = hvd.DistributedOptimizer(optimizer)

    def _dense_grad_filter(gvs):
      """Only apply gradient updates to the final layer.

      This function is used for fine tuning.

      Args:
        gvs: list of tuples with gradients and variable info
      Returns:
        filtered gradients so that only the dense layer remains
      """
      return [(g, v) for g, v in gvs if 'dense' in v.name]

    if loss_scale != 1 and fp16_implementation != 'graph_rewrite':
      # When computing fp16 gradients, often intermediate tensor values are
      # so small, they underflow to 0. To avoid this, we multiply the loss by
      # loss_scale to make these tensor values loss_scale times bigger.
      scaled_grad_vars = optimizer.compute_gradients(loss * loss_scale)
      if fine_tune:
        scaled_grad_vars = _dense_grad_filter(scaled_grad_vars)

      # Once the gradient computation is complete we can scale the gradients
      # back to the correct scale before passing them to the optimizer.
      unscaled_grad_vars = [(grad / loss_scale, var)
                            for grad, var in scaled_grad_vars]
      minimize_op = optimizer.apply_gradients(unscaled_grad_vars, global_step)
    else:
      grad_vars = optimizer.compute_gradients(loss * loss_scale)
      if fine_tune:
        grad_vars = _dense_grad_filter(grad_vars)
      grad_vars = [(grad / loss_scale, var) for grad, var in grad_vars]
      minimize_op = optimizer.apply_gradients(grad_vars, global_step)

    update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
    if flags.FLAGS.is_mlperf_enabled:
      train_op = tf.group(minimize_op, update_ops, num_examples_metric[1])
    else:
      train_op = tf.group(minimize_op, update_ops)
  else:
    train_op = None

  accuracy = tf.compat.v1.metrics.accuracy(labels, predictions['classes'])
  accuracy_top_5 = tf.compat.v1.metrics.mean(
      tf.nn.in_top_k(predictions=logits, targets=labels, k=5, name='top_5_op'))
  metrics = {'accuracy': accuracy,
             'accuracy_top_5': accuracy_top_5}
  if flags.FLAGS.is_mlperf_enabled:
    metrics.update({_NUM_EXAMPLES_NAME: num_examples_metric})

  # Create a tensor named train_accuracy for logging purposes
  tf.identity(accuracy[1], name='train_accuracy')
  tf.identity(accuracy_top_5[1], name='train_accuracy_top_5')
  tf.compat.v1.summary.scalar('train_accuracy', accuracy[1])
  tf.compat.v1.summary.scalar('train_accuracy_top_5', accuracy_top_5[1])

  return tf.estimator.EstimatorSpec(
      mode=mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op,
      eval_metric_ops=metrics)
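
# Hedged usage sketch (not part of the original file): tf.estimator.Estimator
# expects a model_fn with the signature (features, labels, mode, params), so
# callers typically wrap resnet_model_fn above in a thin adapter that binds the
# remaining arguments. The wrapper name, the params keys, and the hyperparameter
# values below are illustrative assumptions, not names defined elsewhere here.
def _example_model_fn(features, labels, mode, params):
  """Illustrative adapter binding resnet_model_fn's extra arguments."""
  return resnet_model_fn(
      features=features,
      labels=labels,
      mode=mode,
      model_class=params['model_class'],   # a ResnetModel subclass supplied by the caller
      resnet_size=params.get('resnet_size', 50),
      weight_decay=1e-4,
      learning_rate_fn=params['learning_rate_fn'],  # maps global_step -> learning rate
      momentum=0.9,
      data_format=None,                    # let the model choose based on hardware
      resnet_version=1,
      loss_scale=1,
      loss_filter_fn=None,
      dtype=tf.float32,
      fine_tune=False,
      label_smoothing=0.0)

# The adapter would then be passed to the Estimator, e.g.:
# classifier = tf.estimator.Estimator(model_fn=_example_model_fn,
#                                     model_dir='/tmp/resnet', params={...})
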
def resnet_model_fn(features, labels, mode, model_class,
                    resnet_size, weight_decay, learning_rate_fn, momentum,
                    data_format, version, loss_scale, loss_filter_fn=None,
                    dtype=resnet_model.DEFAULT_DTYPE,
                    label_smoothing=0.0, enable_lars=False):
  """Shared functionality for different resnet model_fns.

  Initializes the ResnetModel representing the model layers and uses that
  model to build the necessary EstimatorSpecs for the `mode` in question.
  For training, this means building losses, the optimizer, and the train op
  that get passed into the EstimatorSpec. For evaluation and prediction, the
  EstimatorSpec is returned without a train op, but with the necessary
  parameters for the given mode.

  Args:
    features: tensor representing input images
    labels: tensor representing class labels for all input images
    mode: current estimator mode; should be one of
      `tf.estimator.ModeKeys.TRAIN`, `EVAL`, `PREDICT`
    model_class: a class representing a TensorFlow model that has a __call__
      function. We assume here that this is a subclass of ResnetModel.
    resnet_size: A single integer for the size of the ResNet model.
    weight_decay: weight decay loss rate used to regularize learned variables.
    learning_rate_fn: function that returns the current learning rate given
      the current global_step
    momentum: momentum term used for optimization
    data_format: Input format ('channels_last', 'channels_first', or None).
      If set to None, the format is dependent on whether a GPU is available.
    version: Integer representing which version of the ResNet network to use.
      See README for details. Valid values: [1, 2]
    loss_scale: The factor to scale the loss for numerical stability. A
      detailed summary is present in the arg parser help text.
    loss_filter_fn: function that takes a string variable name and returns
      True if the var should be included in loss calculation, and False
      otherwise. If None, batch_normalization variables will be excluded
      from the loss.
    dtype: the TensorFlow dtype to use for calculations.
    label_smoothing: If greater than 0, smooth the labels.
    enable_lars: If True, use the LARS optimizer instead of momentum SGD.

  Returns:
    EstimatorSpec parameterized according to the input params and the
    current mode.
  """
  # Generate a summary node for the images
  tf.summary.image('images', features, max_outputs=6)

  # Check that features/images have the same data type as the one used for
  # calculations.
  assert features.dtype == dtype

  features = tf.cast(features, dtype)

  model = model_class(resnet_size, data_format, version=version, dtype=dtype)

  logits = model(features, mode == tf.estimator.ModeKeys.TRAIN)

  # This acts as a no-op if the logits are already in fp32 (provided logits are
  # not a SparseTensor). If dtype is low precision, logits must be cast to
  # fp32 for numerical stability.
  logits = tf.cast(logits, tf.float32)

  num_examples_metric = tf_mlperf_log.sum_metric(tensor=tf.shape(logits)[0],
                                                 name=_NUM_EXAMPLES_NAME)

  predictions = {
      'classes': tf.argmax(logits, axis=1),
      'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
  }

  if mode == tf.estimator.ModeKeys.PREDICT:
    # Return the predictions and the specification for serving a SavedModel
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        export_outputs={
            'predict': tf.estimator.export.PredictOutput(predictions)
        })

  # Calculate loss, which includes softmax cross entropy and L2 regularization.
  mlperf_log.resnet_print(key=mlperf_log.MODEL_HP_LOSS_FN, value=mlperf_log.CCE)

  if label_smoothing != 0.0:
    one_hot_labels = tf.one_hot(labels, 1001)
    cross_entropy = tf.losses.softmax_cross_entropy(
        logits=logits, onehot_labels=one_hot_labels,
        label_smoothing=label_smoothing)
  else:
    cross_entropy = tf.losses.sparse_softmax_cross_entropy(
        logits=logits, labels=labels)

  # Create a tensor named cross_entropy for logging purposes.
  tf.identity(cross_entropy, name='cross_entropy')
  tf.summary.scalar('cross_entropy', cross_entropy)

  # If no loss_filter_fn is passed, assume we want the default behavior,
  # which is that batch_normalization variables are excluded from loss.
  def exclude_batch_norm(name):
    return 'batch_normalization' not in name
  loss_filter_fn = loss_filter_fn or exclude_batch_norm

  mlperf_log.resnet_print(key=mlperf_log.MODEL_EXCLUDE_BN_FROM_L2,
                          value=not loss_filter_fn('batch_normalization'))

  # Add weight decay to the loss.
  mlperf_log.resnet_print(key=mlperf_log.MODEL_L2_REGULARIZATION,
                          value=weight_decay)
  l2_loss = weight_decay * tf.add_n(
      # loss is computed using fp32 for numerical stability.
      [
          tf.nn.l2_loss(tf.cast(v, tf.float32))
          for v in tf.trainable_variables()
          if loss_filter_fn(v.name)
      ])
  tf.summary.scalar('l2_loss', l2_loss)
  loss = cross_entropy + l2_loss

  if mode == tf.estimator.ModeKeys.TRAIN:
    global_step = tf.train.get_or_create_global_step()
    learning_rate = learning_rate_fn(global_step)

    log_id = mlperf_log.resnet_print(key=mlperf_log.OPT_LR, deferred=True)
    learning_rate = tf_mlperf_log.log_deferred(op=learning_rate, log_id=log_id,
                                               every_n=100)

    # Create a tensor named learning_rate for logging purposes
    tf.identity(learning_rate, name='learning_rate')
    tf.summary.scalar('learning_rate', learning_rate)

    mlperf_log.resnet_print(key=mlperf_log.OPT_NAME,
                            value=mlperf_log.SGD_WITH_MOMENTUM)
    mlperf_log.resnet_print(key=mlperf_log.OPT_MOMENTUM, value=momentum)

    if enable_lars:
      optimizer = tf.contrib.opt.LARSOptimizer(
          learning_rate,
          momentum=momentum,
          weight_decay=weight_decay,
          skip_list=['batch_normalization', 'bias'])
    else:
      optimizer = tf.train.MomentumOptimizer(
          learning_rate=learning_rate,
          momentum=momentum)

    if loss_scale != 1:
      # When computing fp16 gradients, often intermediate tensor values are
      # so small, they underflow to 0. To avoid this, we multiply the loss by
      # loss_scale to make these tensor values loss_scale times bigger.
      scaled_grad_vars = optimizer.compute_gradients(loss * loss_scale)

      # Once the gradient computation is complete we can scale the gradients
      # back to the correct scale before passing them to the optimizer.
      unscaled_grad_vars = [(grad / loss_scale, var)
                            if grad is not None else (grad, var)
                            for grad, var in scaled_grad_vars]
      minimize_op = optimizer.apply_gradients(unscaled_grad_vars, global_step)
    else:
      minimize_op = optimizer.minimize(loss, global_step)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    train_op = tf.group(minimize_op, update_ops, num_examples_metric[1])
  else:
    train_op = None

  accuracy = tf.metrics.accuracy(labels, predictions['classes'])
  accuracy_top_5 = tf.metrics.mean(
      tf.nn.in_top_k(predictions=logits, targets=labels, k=5, name='top_5_op'))
  metrics = {
      'accuracy': accuracy,
      'accuracy_top_5': accuracy_top_5,
      _NUM_EXAMPLES_NAME: num_examples_metric
  }

  # Create a tensor named train_accuracy for logging purposes
  tf.identity(accuracy[1], name='train_accuracy')
  tf.identity(accuracy_top_5[1], name='train_accuracy_top_5')
  tf.summary.scalar('train_accuracy', accuracy[1])
  tf.summary.scalar('train_accuracy_top_5', accuracy_top_5[1])

  return tf.estimator.EstimatorSpec(
      mode=mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op,
      eval_metric_ops=metrics)
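
# Hedged example (not part of the original file): loss_filter_fn receives each
# trainable variable's name and returns True if that variable should contribute
# to the L2 weight-decay term. The default above excludes batch_normalization
# variables; a caller could also exclude biases, mirroring the LARS skip_list.
# The function name is illustrative.
def _no_decay_on_bn_or_bias(name):
  """Exclude batch-norm and bias variables from the weight-decay loss."""
  return 'batch_normalization' not in name and 'bias' not in name

# It would be passed through the model_fn like:
#   resnet_model_fn(..., loss_filter_fn=_no_decay_on_bn_or_bias, ...)
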