def resnet_model_fn(features, labels, mode, params):
  """The model_fn for ResNet to be used with TPUEstimator.

  The ResNet trunk maps each image to a summary statistic. In PREDICT mode
  that statistic is returned as-is; otherwise it is trained against `labels`
  with the loss named by `params['training_loss']`:

    * 'VMIM': negative log-likelihood of `labels` under a conditional masked
      autoregressive flow whose networks are conditioned on the statistic.
    * 'MAE' / 'MSE': mean absolute / squared error between `labels` and the
      statistic.

  Args:
    features: `Tensor` of batched images. If transpose_input is enabled, it
      is transposed to device layout and reshaped to 1D tensor.
    labels: `Tensor` of labels for the data samples.
    mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`
    params: `dict` of parameters passed to the model from the TPUEstimator,
      `params['batch_size']` is always provided and should be used as the
      effective batch size.

  Returns:
    An `EstimatorSpec` carrying the summary statistic in PREDICT mode,
    otherwise a `TPUEstimatorSpec` for the model.

  Raises:
    ValueError: if `params['dropblock_groups']` names a group outside [1, 4]
      or `params['precision']` is neither 'bfloat16' nor 'float32'.
    NotImplementedError: if `params['training_loss']` is not 'VMIM', 'MAE'
      or 'MSE'.
  """
  if isinstance(features, dict):
    features = features['feature']

  # In most cases, the default data format NCHW instead of NHWC should be
  # used for a significant performance boost on GPU/TPU. NHWC should be used
  # only if the network needs to be run on CPU since the pooling operations
  # are only supported on NHWC.
  if params['data_format'] == 'channels_first':
    assert not params['transpose_input']  # channels_first only for GPU
    features = tf.transpose(features, [0, 3, 1, 2])

  if params['transpose_input'] and mode != tf.estimator.ModeKeys.PREDICT:
    image_size = params['image_size']
    # Flattened input in HWCN layout with a single channel.
    features = tf.reshape(features, [image_size, image_size, 1, -1])
    features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

  # DropBlock keep_prob for the 4 block groups of ResNet architecture.
  # None means applying no DropBlock at the corresponding block group.
  dropblock_keep_probs = [None] * 4
  if params['dropblock_groups']:
    # Scheduled keep_prob for DropBlock: decays linearly from 1 at step 0 to
    # params['dropblock_keep_prob'] at params['train_steps'].
    train_steps = tf.cast(params['train_steps'], tf.float32)
    current_step = tf.cast(tf.train.get_global_step(), tf.float32)
    current_ratio = current_step / train_steps
    dropblock_keep_prob = (
        1 - current_ratio * (1 - params['dropblock_keep_prob']))

    # Computes DropBlock keep_prob for different block groups of ResNet.
    # Earlier groups drop 4x less per group than the next one.
    dropblock_groups = [int(x) for x in params['dropblock_groups'].split(',')]
    for block_group in dropblock_groups:
      if block_group < 1 or block_group > 4:
        raise ValueError(
            'dropblock_groups should be a comma separated list of integers '
            'between 1 and 4 (dropblock_groups: {}).'.format(
                params['dropblock_groups']))
      dropblock_keep_probs[block_group - 1] = 1 - (
          (1 - dropblock_keep_prob) / 4.0**(4 - block_group))

  # This nested function allows us to avoid duplicating the logic which
  # builds the network, for different values of --precision.
  def build_network():
    network = resnet_model.resnet_v1(
        resnet_depth=params['resnet_depth'],
        num_classes=params['num_label_classes'],
        dropblock_size=params['dropblock_size'],
        dropblock_keep_probs=dropblock_keep_probs,
        data_format=params['data_format'])
    return network(
        inputs=features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))

  # Compute the summary statistic.
  if params['precision'] == 'bfloat16':
    with tf.tpu.bfloat16_scope():
      sum_stat = build_network()
    sum_stat = tf.cast(sum_stat, tf.float32)
  elif params['precision'] == 'float32':
    sum_stat = build_network()
  else:
    # Without this branch an unknown precision surfaced later as an
    # UnboundLocalError on `sum_stat`.
    raise ValueError(
        'Unsupported precision: {} (expected bfloat16 or float32).'.format(
            params['precision']))

  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {
        'summary': sum_stat,
    }
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        export_outputs={
            'inference': tf.estimator.export.PredictOutput(predictions)
        })

  n = params['num_label_classes']

  # If necessary, in the model_fn, use params['batch_size'] instead the batch
  # size flags (--train_batch_size or --eval_batch_size).
  batch_size = params['batch_size']  # pylint: disable=unused-variable

  # During training, add a little bit of Gaussian scatter to the labels to
  # smooth out the distribution. The noise is shaped and typed like `labels`,
  # so a smaller final batch (CPU/GPU) or non-float32 labels also work.
  if params['label_smoothing'] > 0. and mode == tf.estimator.ModeKeys.TRAIN:
    labels += params['label_smoothing'] * tf.random_normal(
        shape=tf.shape(labels), dtype=labels.dtype)

  # Now build a conditional density estimator from this summary statistic.
  if params['training_loss'] == 'VMIM':
    net = sum_stat
    # Chain of masked autoregressive flows (MAF), each conditioned on `net`,
    # with the label dimensions reversed between consecutive flows. The first
    # two flows learn shift and log-scale; the last two use shift_only=True.
    chain = []
    for flow_index, shift_only in enumerate([False, False, True, True]):
      if flow_index > 0:
        chain.append(tfb.Permute(np.arange(n)[::-1]))
      chain.append(
          tfp.bijectors.MaskedAutoregressiveFlow(
              shift_and_log_scale_fn=masked_autoregressive_conditional_template(
                  hidden_layers=[128, 128],
                  conditional_tensor=net,
                  shift_only=shift_only)))
    bij = tfb.Chain(chain)
    prior = tfd.MultivariateNormalDiag(
        loc=tf.zeros(n), scale_identity_multiplier=1.0)
    distribution = tfd.TransformedDistribution(prior, bijector=bij)

    # Mean negative log-likelihood of the labels (L2 weight decay is added
    # further below).
    loss = -tf.reduce_mean(distribution.log_prob(labels), axis=0)
  elif params['training_loss'] == 'MAE':
    loss = tf.reduce_mean(tf.keras.losses.mae(labels, sum_stat), axis=0)
  elif params['training_loss'] == 'MSE':
    loss = tf.reduce_mean(tf.keras.losses.mse(labels, sum_stat), axis=0)
  else:
    raise NotImplementedError(
        'Unsupported training_loss: {} (expected VMIM, MAE or MSE).'.format(
            params['training_loss']))

  # Add weight decay to the loss for non-batch-normalization variables. No L2
  # term is added when LARS is enabled (presumably the LARS optimizer applies
  # weight decay itself -- confirm in lars_util).
  if not params['enable_lars']:
    loss = loss + params['weight_decay'] * tf.add_n([
        tf.nn.l2_loss(v)
        for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])

  host_call = None
  if mode == tf.estimator.ModeKeys.TRAIN:
    # Compute the current epoch and associated learning rate from global_step.
    global_step = tf.train.get_global_step()
    steps_per_epoch = params['num_train_images'] / params['train_batch_size']
    current_epoch = (tf.cast(global_step, tf.float32) / steps_per_epoch)
    # LARS is a large batch optimizer. LARS enables higher accuracy at batch 16K
    # and larger batch sizes.
    if params['enable_lars']:
      learning_rate = 0.0
      optimizer = lars_util.init_lars_optimizer(current_epoch, params)
    else:
      learning_rate = learning_rate_schedule(params, current_epoch)
      optimizer = tf.train.MomentumOptimizer(
          learning_rate=learning_rate,
          momentum=params['momentum'],
          use_nesterov=True)
    if params['use_tpu']:
      # When using TPU, wrap the optimizer with CrossShardOptimizer which
      # handles synchronization details between different TPU cores. To the
      # user, this should look like regular synchronous training.
      optimizer = tf.tpu.CrossShardOptimizer(optimizer)

    # Batch normalization requires UPDATE_OPS to be added as a dependency to
    # the train operation.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss, global_step)

    if not params['skip_host_call']:
      def host_call_fn(gs, loss, lr, ce):
        """Training host call. Creates scalar summaries for training metrics.

        This function is executed on the CPU and should not directly reference
        any Tensors in the rest of the `model_fn`. To pass Tensors from the
        model to the `host_call_fn`, provide as part of the `host_call`. See
        https://www.tensorflow.org/api_docs/python/tf/estimator/tpu/TPUEstimatorSpec
        for more information.

        Arguments should match the list of `Tensor` objects passed as the
        second element in the tuple passed to `host_call`.

        Args:
          gs: `Tensor` with shape `[batch]` for the global_step
          loss: `Tensor` with shape `[batch]` for the training loss.
          lr: `Tensor` with shape `[batch]` for the learning_rate.
          ce: `Tensor` with shape `[batch]` for the current_epoch.

        Returns:
          List of summary ops to run on the CPU host.
        """
        gs = gs[0]
        # Host call fns are executed params['iterations_per_loop'] times after
        # one TPU loop is finished, setting max_queue value to the same as
        # number of iterations will make the summary writer only flush the data
        # to storage once per loop.
        with tf2.summary.create_file_writer(
            FLAGS.model_dir,
            max_queue=params['iterations_per_loop']).as_default():
          with tf2.summary.record_if(True):
            tf2.summary.scalar('loss', loss[0], step=gs)
            tf2.summary.scalar('learning_rate', lr[0], step=gs)
            tf2.summary.scalar('current_epoch', ce[0], step=gs)

          return tf.summary.all_v2_summary_ops()

      # To log the loss, current learning rate, and epoch for Tensorboard, the
      # summary op needs to be run on the host CPU via host_call. host_call
      # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
      # dimension. These Tensors are implicitly concatenated to
      # [params['batch_size']].
      gs_t = tf.reshape(global_step, [1])
      loss_t = tf.reshape(loss, [1])
      lr_t = tf.reshape(learning_rate, [1])
      ce_t = tf.reshape(current_epoch, [1])

      host_call = (host_call_fn, [gs_t, loss_t, lr_t, ce_t])

  else:
    train_op = None

  eval_metrics = None

  return tf.estimator.tpu.TPUEstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      host_call=host_call,
      eval_metrics=eval_metrics)
def resnet_model_fn(features, labels, mode, params):
  """The model_fn for ResNet to be used with TPUEstimator.

  Args:
    features: `Tensor` of batched images. If transpose_input is enabled, it
      is transposed to device layout and reshaped to 1D tensor.
    labels: `Tensor` of labels for the data samples
    mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`
    params: `dict` of parameters passed to the model from the TPUEstimator,
      `params['batch_size']` is always provided and should be used as the
      effective batch size.

  Returns:
    A `TPUEstimatorSpec` for the model (an `EstimatorSpec` in PREDICT mode).

  Raises:
    ValueError: if `params['dropblock_groups']` names a group outside [1, 4]
      or `params['precision']` is neither 'bfloat16' nor 'float32'.
  """
  if isinstance(features, dict):
    features = features['feature']

  # In most cases, the default data format NCHW instead of NHWC should be
  # used for a significant performance boost on GPU/TPU. NHWC should be used
  # only if the network needs to be run on CPU since the pooling operations
  # are only supported on NHWC.
  channels_first = params['data_format'] == 'channels_first'
  if channels_first:
    assert not params['transpose_input']  # channels_first only for GPU

  if params['transpose_input'] and mode != tf.estimator.ModeKeys.PREDICT:
    # The input is a flat 1D tensor in HWCN layout with 3 channels; recover
    # the (square) spatial size from the element count and the batch size.
    # `int32 / int32` yields a floating tensor and `tf.reshape` only accepts
    # integer shapes, so compute in float32 and round back to int32.
    num_elements = tf.cast(tf.size(features), tf.float32)
    num_images = tf.cast(tf.shape(labels)[0], tf.float32)
    image_size = tf.cast(
        tf.round(tf.sqrt(num_elements / (3.0 * num_images))), tf.int32)
    features = tf.reshape(features, [image_size, image_size, 3, -1])
    features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

  # Normalize the image to zero mean and unit variance. This is done while the
  # channel axis is still last, so the [1, 1, 3] constants broadcast over the
  # channels (after an NCHW transpose they would broadcast over width).
  features -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=features.dtype)
  features /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=features.dtype)

  if channels_first:
    features = tf.transpose(features, [0, 3, 1, 2])  # NHWC to NCHW

  # DropBlock keep_prob for the 4 block groups of ResNet architecture.
  # None means applying no DropBlock at the corresponding block group.
  dropblock_keep_probs = [None] * 4
  if params['dropblock_groups']:
    # Scheduled keep_prob for DropBlock: decays linearly from 1 at step 0 to
    # params['dropblock_keep_prob'] at params['train_steps'].
    train_steps = tf.cast(params['train_steps'], tf.float32)
    current_step = tf.cast(tf.train.get_global_step(), tf.float32)
    current_ratio = current_step / train_steps
    dropblock_keep_prob = (1 - current_ratio *
                           (1 - params['dropblock_keep_prob']))

    # Computes DropBlock keep_prob for different block groups of ResNet.
    dropblock_groups = [
        int(x) for x in params['dropblock_groups'].split(',')
    ]
    for block_group in dropblock_groups:
      if block_group < 1 or block_group > 4:
        raise ValueError(
            'dropblock_groups should be a comma separated list of integers '
            'between 1 and 4 (dropblock_groups: {}).'.format(
                params['dropblock_groups']))
      dropblock_keep_probs[block_group - 1] = 1 - (
          (1 - dropblock_keep_prob) / 4.0**(4 - block_group))

  # This nested function allows us to avoid duplicating the logic which
  # builds the network, for different values of --precision.
  def build_network():
    network = resnet_model.resnet_v1(
        resnet_depth=params['resnet_depth'],
        num_classes=params['num_label_classes'],
        dropblock_size=params['dropblock_size'],
        dropblock_keep_probs=dropblock_keep_probs,
        data_format=params['data_format'])
    return network(inputs=features,
                   is_training=(mode == tf.estimator.ModeKeys.TRAIN))

  if params['precision'] == 'bfloat16':
    with tf.contrib.tpu.bfloat16_scope():
      logits = build_network()
    logits = tf.cast(logits, tf.float32)
  elif params['precision'] == 'float32':
    logits = build_network()
  else:
    # Without this branch an unknown precision surfaced later as an
    # UnboundLocalError on `logits`.
    raise ValueError(
        'Unsupported precision: {} (expected bfloat16 or float32).'.format(
            params['precision']))

  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        export_outputs={
            'classify': tf.estimator.export.PredictOutput(predictions)
        })

  # If necessary, in the model_fn, use params['batch_size'] instead the batch
  # size flags (--train_batch_size or --eval_batch_size).
  batch_size = params['batch_size']  # pylint: disable=unused-variable

  # Calculate loss, which includes softmax cross entropy and L2 regularization.
  one_hot_labels = tf.one_hot(labels, params['num_label_classes'])
  cross_entropy = tf.losses.softmax_cross_entropy(
      logits=logits,
      onehot_labels=one_hot_labels,
      label_smoothing=params['label_smoothing'])

  # Add weight decay to the loss for non-batch-normalization variables.
  loss = cross_entropy + params['weight_decay'] * tf.add_n([
      tf.nn.l2_loss(v)
      for v in tf.trainable_variables()
      if 'batch_normalization' not in v.name
  ])

  host_call = None
  if mode == tf.estimator.ModeKeys.TRAIN:
    # Compute the current epoch and associated learning rate from global_step.
    global_step = tf.train.get_global_step()
    steps_per_epoch = params['num_train_images'] / params['train_batch_size']
    current_epoch = (tf.cast(global_step, tf.float32) / steps_per_epoch)
    # LARS is a large batch optimizer. LARS enables higher accuracy at batch 16K
    # and larger batch sizes.
    if params['train_batch_size'] >= 16384 and params['enable_lars']:
      learning_rate = 0.0
      optimizer = lars_util.init_lars_optimizer(current_epoch, params)
    else:
      learning_rate = learning_rate_schedule(params, current_epoch)
      optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                             momentum=params['momentum'],
                                             use_nesterov=True)
    if params['use_tpu']:
      # When using TPU, wrap the optimizer with CrossShardOptimizer which
      # handles synchronization details between different TPU cores. To the
      # user, this should look like regular synchronous training.
      optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

    # Batch normalization requires UPDATE_OPS to be added as a dependency to
    # the train operation.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss, global_step)

    if not params['skip_host_call']:
      def host_call_fn(gs, loss, lr, ce):
        """Training host call. Creates scalar summaries for training metrics.

        This function is executed on the CPU and should not directly reference
        any Tensors in the rest of the `model_fn`. To pass Tensors from the
        model to the `host_call_fn`, provide as part of the `host_call`. See
        https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
        for more information.

        Arguments should match the list of `Tensor` objects passed as the
        second element in the tuple passed to `host_call`.

        Args:
          gs: `Tensor` with shape `[batch]` for the global_step
          loss: `Tensor` with shape `[batch]` for the training loss.
          lr: `Tensor` with shape `[batch]` for the learning_rate.
          ce: `Tensor` with shape `[batch]` for the current_epoch.

        Returns:
          List of summary ops to run on the CPU host.
        """
        gs = gs[0]
        # Host call fns are executed params['iterations_per_loop'] times after
        # one TPU loop is finished, setting max_queue value to the same as
        # number of iterations will make the summary writer only flush the data
        # to storage once per loop.
        with summary.create_file_writer(
            FLAGS.model_dir,
            max_queue=params['iterations_per_loop']).as_default():
          with summary.always_record_summaries():
            summary.scalar('loss', loss[0], step=gs)
            summary.scalar('learning_rate', lr[0], step=gs)
            summary.scalar('current_epoch', ce[0], step=gs)

            return summary.all_summary_ops()

      # To log the loss, current learning rate, and epoch for Tensorboard, the
      # summary op needs to be run on the host CPU via host_call. host_call
      # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
      # dimension. These Tensors are implicitly concatenated to
      # [params['batch_size']].
      gs_t = tf.reshape(global_step, [1])
      loss_t = tf.reshape(loss, [1])
      lr_t = tf.reshape(learning_rate, [1])
      ce_t = tf.reshape(current_epoch, [1])

      host_call = (host_call_fn, [gs_t, loss_t, lr_t, ce_t])

  else:
    train_op = None

  eval_metrics = None
  if mode == tf.estimator.ModeKeys.EVAL:
    def metric_fn(labels, logits):
      """Evaluation metric function. Evaluates accuracy.

      This function is executed on the CPU and should not directly reference
      any Tensors in the rest of the `model_fn`. To pass Tensors from the model
      to the `metric_fn`, provide as part of the `eval_metrics`. See
      https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
      for more information.

      Arguments should match the list of `Tensor` objects passed as the second
      element in the tuple passed to `eval_metrics`.

      Args:
        labels: `Tensor` with shape `[batch]`.
        logits: `Tensor` with shape `[batch, num_classes]`.

      Returns:
        A dict of the metrics to return from evaluation.
      """
      predictions = tf.argmax(logits, axis=1)
      top_1_accuracy = tf.metrics.accuracy(labels, predictions)
      in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
      top_5_accuracy = tf.metrics.mean(in_top_5)

      return {
          'top_1_accuracy': top_1_accuracy,
          'top_5_accuracy': top_5_accuracy,
      }

    eval_metrics = (metric_fn, [labels, logits])

  return tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
                                         loss=loss,
                                         train_op=train_op,
                                         host_call=host_call,
                                         eval_metrics=eval_metrics)
def resnet_model_fn(features, labels, mode, params):
  """The model_fn for ResNet to be used with TPUEstimator.

  Trains with a class-balanced focal loss: each example's loss is weighted by
  its class weight, derived from the per-class image counts listed in
  FLAGS.img_num_per_cls_file and the FLAGS.beta hyperparameter.

  Args:
    features: `Tensor` of batched images.
    labels: `Tensor` of labels for the data samples
    mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`
    params: `dict` of parameters passed to the model from the TPUEstimator,
      `params['batch_size']` is always provided and should be used as the
      effective batch size.

  Returns:
    A `TPUEstimatorSpec` for the model (an `EstimatorSpec` in PREDICT mode).

  Raises:
    ValueError: if FLAGS.precision is neither 'bfloat16' nor 'float32', or if
      the class-count file does not contain exactly FLAGS.num_label_classes
      positive integers.
  """
  if isinstance(features, dict):
    features = features['feature']

  # In most cases, the default data format NCHW instead of NHWC should be
  # used for a significant performance boost on GPU/TPU. NHWC should be used
  # only if the network needs to be run on CPU since the pooling operations
  # are only supported on NHWC.
  channels_first = FLAGS.data_format == 'channels_first'
  if channels_first:
    assert not FLAGS.transpose_input  # channels_first only for GPU

  if FLAGS.transpose_input and mode != tf.estimator.ModeKeys.PREDICT:
    features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

  # Normalize the image to zero mean and unit variance. This is done while the
  # channel axis is still last, so the [1, 1, 3] constants broadcast over the
  # channels (after an NCHW transpose they would broadcast over width).
  features -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=features.dtype)
  features /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=features.dtype)

  if channels_first:
    features = tf.transpose(features, [0, 3, 1, 2])  # NHWC to NCHW

  # This nested function allows us to avoid duplicating the logic which
  # builds the network, for different values of --precision.
  def build_network():
    network = resnet_model.resnet_v1(
        resnet_depth=FLAGS.resnet_depth,
        num_classes=FLAGS.num_label_classes,
        data_format=FLAGS.data_format)
    return network(
        inputs=features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))

  if FLAGS.precision == 'bfloat16':
    with tf.contrib.tpu.bfloat16_scope():
      logits = build_network()
    logits = tf.cast(logits, tf.float32)
  elif FLAGS.precision == 'float32':
    logits = build_network()
  else:
    # Without this branch an unknown precision surfaced later as an
    # UnboundLocalError on `logits`.
    raise ValueError(
        'Unsupported precision: {} (expected bfloat16 or float32).'.format(
            FLAGS.precision))

  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.sigmoid(logits, name='sigmoid_tensor')
    }
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        export_outputs={
            'classify': tf.estimator.export.PredictOutput(predictions)
        })

  # If necessary, in the model_fn, use params['batch_size'] instead the batch
  # size flags (--train_batch_size or --eval_batch_size).
  batch_size = params['batch_size']  # pylint: disable=unused-variable

  # Calculate loss, which includes classification loss and L2 regularization.
  one_hot_labels = tf.one_hot(labels, FLAGS.num_label_classes)

  # Normalized weights based on inverse number of effective data per class:
  # w_c ~ (1 - beta) / (1 - beta**n_c), rescaled to sum to num_label_classes.
  # The file holds one integer image count per line, one line per class.
  # GFile (unlike open()) also reads gs:// paths and is closed by `with`.
  with tf.gfile.GFile(FLAGS.img_num_per_cls_file, 'r') as count_file:
    img_num_per_cls = [
        int(line.strip()) for line in count_file if line.strip()
    ]
  if len(img_num_per_cls) != FLAGS.num_label_classes:
    raise ValueError(
        '{} lists {} class counts but num_label_classes is {}.'.format(
            FLAGS.img_num_per_cls_file, len(img_num_per_cls),
            FLAGS.num_label_classes))
  if min(img_num_per_cls) <= 0:
    # A zero count makes 1 - beta**0 == 0, i.e. an infinite class weight,
    # which would turn every example weight into NaN.
    raise ValueError('{} must contain only positive class counts.'.format(
        FLAGS.img_num_per_cls_file))
  effective_num = 1.0 - np.power(FLAGS.beta, img_num_per_cls)
  class_weights = (1.0 - FLAGS.beta) / np.array(effective_num)
  class_weights = (
      class_weights / np.sum(class_weights) * FLAGS.num_label_classes)
  class_weights = tf.cast(class_weights, dtype=tf.float32)
  # Select each example's class weight through its one-hot label (broadcast
  # over the batch), then replicate it across the class axis so `weights`
  # has shape [batch, num_label_classes].
  weights = tf.reduce_sum(
      tf.expand_dims(class_weights, 0) * one_hot_labels, axis=1)
  weights = tf.expand_dims(weights, 1)
  weights = tf.tile(weights, [1, FLAGS.num_label_classes])

  classification_loss = focal_loss(
      one_hot_labels, logits, weights, FLAGS.gamma)

  # Add weight decay to the loss for non-batch-normalization variables
  # (dense-layer biases are excluded as well).
  loss = classification_loss + FLAGS.weight_decay * tf.add_n([
      tf.nn.l2_loss(v)
      for v in tf.trainable_variables()
      if 'batch_normalization' not in v.name and 'dense/bias' not in v.name
  ])

  host_call = None
  if mode == tf.estimator.ModeKeys.TRAIN:
    # Compute the current epoch and associated learning rate from global_step.
    global_step = tf.train.get_global_step()
    steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
    current_epoch = (tf.cast(global_step, tf.float32) / steps_per_epoch)
    # LARS is a large batch optimizer. LARS enables higher accuracy at batch 16K
    # and larger batch sizes.
    if FLAGS.train_batch_size >= 16384 and FLAGS.enable_lars:
      learning_rate = 0.0
      optimizer = lars_util.init_lars_optimizer(current_epoch)
    else:
      learning_rate = learning_rate_schedule(current_epoch)
      optimizer = tf.train.MomentumOptimizer(
          learning_rate=learning_rate,
          momentum=FLAGS.momentum,
          use_nesterov=True)
    if FLAGS.use_tpu:
      # When using TPU, wrap the optimizer with CrossShardOptimizer which
      # handles synchronization details between different TPU cores. To the
      # user, this should look like regular synchronous training.
      optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

    # Batch normalization requires UPDATE_OPS to be added as a dependency to
    # the train operation.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss, global_step)

    if not FLAGS.skip_host_call:
      def host_call_fn(gs, fl_loss, loss, lr, ce):
        """Training host call. Creates scalar summaries for training metrics.

        This function is executed on the CPU and should not directly reference
        any Tensors in the rest of the `model_fn`. To pass Tensors from the
        model to the `host_call_fn`, provide as part of the `host_call`. See
        https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
        for more information.

        Arguments should match the list of `Tensor` objects passed as the
        second element in the tuple passed to `host_call`.

        Args:
          gs: `Tensor` with shape `[batch]` for the global_step.
          fl_loss: `Tensor` with shape `[batch]` for the training focal loss.
          loss: `Tensor` with shape `[batch]` for the training loss.
          lr: `Tensor` with shape `[batch]` for the learning_rate.
          ce: `Tensor` with shape `[batch]` for the current_epoch.

        Returns:
          List of summary ops to run on the CPU host.
        """
        gs = gs[0]
        # Host call fns are executed FLAGS.iterations_per_loop times after one
        # TPU loop is finished, setting max_queue value to the same as number of
        # iterations will make the summary writer only flush the data to storage
        # once per loop.
        with summary.create_file_writer(
            FLAGS.model_dir, max_queue=FLAGS.iterations_per_loop).as_default():
          with summary.always_record_summaries():
            summary.scalar('focal_loss', fl_loss[0], step=gs)
            summary.scalar('loss', loss[0], step=gs)
            summary.scalar('learning_rate', lr[0], step=gs)
            summary.scalar('current_epoch', ce[0], step=gs)

            return summary.all_summary_ops()

      # To log the loss, current learning rate, and epoch for Tensorboard, the
      # summary op needs to be run on the host CPU via host_call. host_call
      # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
      # dimension. These Tensors are implicitly concatenated to
      # [params['batch_size']].
      gs_t = tf.reshape(global_step, [1])
      fl_loss_t = tf.reshape(classification_loss, [1])
      loss_t = tf.reshape(loss, [1])
      lr_t = tf.reshape(learning_rate, [1])
      ce_t = tf.reshape(current_epoch, [1])

      host_call = (host_call_fn, [gs_t, fl_loss_t, loss_t, lr_t, ce_t])

  else:
    train_op = None

  eval_metrics = None
  if mode == tf.estimator.ModeKeys.EVAL:
    def metric_fn(labels, logits):
      """Evaluation metric function. Evaluates accuracy.

      This function is executed on the CPU and should not directly reference
      any Tensors in the rest of the `model_fn`. To pass Tensors from the model
      to the `metric_fn`, provide as part of the `eval_metrics`. See
      https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
      for more information.

      Arguments should match the list of `Tensor` objects passed as the second
      element in the tuple passed to `eval_metrics`.

      Args:
        labels: `Tensor` with shape `[batch]`.
        logits: `Tensor` with shape `[batch, num_classes]`.

      Returns:
        A dict of the metrics to return from evaluation.
      """
      predictions = tf.argmax(logits, axis=1)
      top_1_accuracy = tf.metrics.accuracy(labels, predictions)
      in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
      top_5_accuracy = tf.metrics.mean(in_top_5)

      return {
          'top_1_accuracy': top_1_accuracy,
          'top_5_accuracy': top_5_accuracy,
      }

    eval_metrics = (metric_fn, [labels, logits])

  return tf.contrib.tpu.TPUEstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      host_call=host_call,
      eval_metrics=eval_metrics)