Code Example #1
File: sselfi_main.py  Project: EiffL/SSELFI
def resnet_model_fn(features, labels, mode, params):
  """The model_fn for ResNet to be used with TPUEstimator.

  Args:
    features: `Tensor` of batched images. If transpose_input is enabled, it
        is transposed to device layout and reshaped to 1D tensor.
    labels: `Tensor` of labels for the data samples
    mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`
    params: `dict` of parameters passed to the model from the TPUEstimator,
        `params['batch_size']` is always provided and should be used as the
        effective batch size.

  Returns:
    A `TPUEstimatorSpec` for the model
  """
  if isinstance(features, dict):
    features = features['feature']

  # In most cases, the default data format NCHW instead of NHWC should be
  # used for a significant performance boost on GPU/TPU. NHWC should be used
  # only if the network needs to be run on CPU since the pooling operations
  # are only supported on NHWC.
  if params['data_format'] == 'channels_first':
    assert not params['transpose_input']    # channels_first only for GPU
    features = tf.transpose(features, [0, 3, 1, 2])

  if params['transpose_input'] and mode != tf.estimator.ModeKeys.PREDICT:
    image_size = params['image_size']
    features = tf.reshape(features, [image_size, image_size, 1, -1])
    features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

  # DropBlock keep_prob for the 4 block groups of ResNet architecture.
  # None means applying no DropBlock at the corresponding block group.
  dropblock_keep_probs = [None] * 4
  if params['dropblock_groups']:
    # Scheduled keep_prob for DropBlock.
    train_steps = tf.cast(params['train_steps'], tf.float32)
    current_step = tf.cast(tf.train.get_global_step(), tf.float32)
    current_ratio = current_step / train_steps
    dropblock_keep_prob = (1 - current_ratio * (
        1 - params['dropblock_keep_prob']))

    # Computes DropBlock keep_prob for different block groups of ResNet.
    dropblock_groups = [int(x) for x in params['dropblock_groups'].split(',')]
    for block_group in dropblock_groups:
      if block_group < 1 or block_group > 4:
        raise ValueError(
            'dropblock_groups should be a comma separated list of integers '
            'between 1 and 4 (dropblock_groups: {}).'
            .format(params['dropblock_groups']))
      dropblock_keep_probs[block_group - 1] = 1 - (
          (1 - dropblock_keep_prob) / 4.0**(4 - block_group))
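    # For example, with dropblock_keep_prob=0.9 and dropblock_groups='3,4', at
    # the end of training (current_ratio=1) block group 4 uses keep_prob=0.9,
    # block group 3 uses keep_prob = 1 - 0.1/4 = 0.975, and groups 1-2 stay None.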

  # This nested function allows us to avoid duplicating the logic which
  # builds the network, for different values of --precision.
  def build_network():
    network = resnet_model.resnet_v1(
        resnet_depth=params['resnet_depth'],
        num_classes=params['num_label_classes'],
        dropblock_size=params['dropblock_size'],
        dropblock_keep_probs=dropblock_keep_probs,
        data_format=params['data_format'])
    return network(
        inputs=features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))

  # Compute the summary statistic
  if params['precision'] == 'bfloat16':
    with tf.tpu.bfloat16_scope():
      sum_stat = build_network()
    sum_stat = tf.cast(sum_stat, tf.float32)
  elif params['precision'] == 'float32':
    sum_stat = build_network()

  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {
        'summary': sum_stat,
    }
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        export_outputs={
            'inference': tf.estimator.export.PredictOutput(predictions)
        })

  n = params['num_label_classes']

  # If necessary, in the model_fn, use params['batch_size'] instead of the
  # batch size flags (--train_batch_size or --eval_batch_size).
  batch_size = params['batch_size']   # pylint: disable=unused-variable

  # Add a little bit of scatter to the labels to smooth out the distribution
  if (params['label_smoothing'] > 0.) and (mode == tf.estimator.ModeKeys.TRAIN):
    labels += params['label_smoothing']*tf.random_normal(shape=[batch_size, n])


  # Now build a conditional density estimator from this density
  # Defines the chain of bijective transforms
  if params['training_loss'] == 'VMIM':
    net = sum_stat
    # Below is the chain for a MAF
    chain = [ tfp.bijectors.MaskedAutoregressiveFlow(
                 shift_and_log_scale_fn=masked_autoregressive_conditional_template(hidden_layers=[128,128],
                                                                                   conditional_tensor=net,
                                                                                   shift_only=False)),
              tfb.Permute(np.arange(n)[::-1]),
              tfp.bijectors.MaskedAutoregressiveFlow(
                 shift_and_log_scale_fn=masked_autoregressive_conditional_template(hidden_layers=[128,128],
                                                                                   conditional_tensor=net,
                                                                                   shift_only=False)),
              tfb.Permute(np.arange(n)[::-1]),
              tfp.bijectors.MaskedAutoregressiveFlow(
                 shift_and_log_scale_fn=masked_autoregressive_conditional_template(hidden_layers=[128,128],
                                                                                   conditional_tensor=net,
                                                                                   shift_only=True)),
              tfb.Permute(np.arange(n)[::-1]),
              tfp.bijectors.MaskedAutoregressiveFlow(
                 shift_and_log_scale_fn=masked_autoregressive_conditional_template(hidden_layers=[128,128],
                                                                                   conditional_tensor=net,
                                                                                   shift_only=True)),
            ]

    bij = tfb.Chain(chain)
    prior  = tfd.MultivariateNormalDiag(loc=tf.zeros(n), scale_identity_multiplier=1.0)
    distribution = tfd.TransformedDistribution(prior, bijector=bij)
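    # The chain alternates MAF layers (each conditioned on the network output
    # `net`) with order-reversing permutations so that no single autoregressive
    # ordering is privileged. Minimizing the negative mean log-probability of
    # `labels` under this conditional density (the 'VMIM' loss below) maximizes
    # a variational lower bound on the mutual information between the summary
    # statistic and the labels.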

    # Negative mean conditional log-likelihood (weight decay is added further below)
    loss = - tf.reduce_mean(distribution.log_prob(labels),axis=0)
  elif params['training_loss'] == 'MAE':
    loss = tf.reduce_mean(tf.keras.losses.mae(labels, sum_stat),axis=0)
  elif params['training_loss'] == 'MSE':
    loss = tf.reduce_mean(tf.keras.losses.mse(labels, sum_stat),axis=0)
  else:
    raise NotImplementedError

  # Add weight decay to the loss for non-batch-normalization variables.
  if params['enable_lars']:
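    # When LARS is enabled, weight decay is assumed to be applied inside the
    # LARS optimizer (via lars_util) rather than added to the loss here.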
    loss = loss
  else:
    loss = loss + params['weight_decay'] * tf.add_n([
        tf.nn.l2_loss(v)
        for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])

  host_call = None
  if mode == tf.estimator.ModeKeys.TRAIN:
    # Compute the current epoch and associated learning rate from global_step.
    global_step = tf.train.get_global_step()
    steps_per_epoch = params['num_train_images'] / params['train_batch_size']
    current_epoch = (tf.cast(global_step, tf.float32) /
                     steps_per_epoch)
    # LARS is a large-batch optimizer. LARS enables higher accuracy at batch
    # sizes of 16K and larger.
    if params['enable_lars']:
      learning_rate = 0.0
      optimizer = lars_util.init_lars_optimizer(current_epoch, params)
    else:
      learning_rate = learning_rate_schedule(params, current_epoch)
      optimizer = tf.train.MomentumOptimizer(
          learning_rate=learning_rate,
          momentum=params['momentum'],
          use_nesterov=True)
    if params['use_tpu']:
      # When using TPU, wrap the optimizer with CrossShardOptimizer which
      # handles synchronization details between different TPU cores. To the
      # user, this should look like regular synchronous training.
      optimizer = tf.tpu.CrossShardOptimizer(optimizer)

    # Batch normalization requires UPDATE_OPS to be added as a dependency to
    # the train operation.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss, global_step)

    if not params['skip_host_call']:
      def host_call_fn(gs, loss, lr, ce):
        """Training host call. Creates scalar summaries for training metrics.

        This function is executed on the CPU and should not directly reference
        any Tensors in the rest of the `model_fn`. To pass Tensors from the
        model to the `host_call_fn`, provide them as part of the `host_call`. See
        https://www.tensorflow.org/api_docs/python/tf/estimator/tpu/TPUEstimatorSpec
        for more information.

        Arguments should match the list of `Tensor` objects passed as the second
        element in the tuple passed to `host_call`.

        Args:
          gs: `Tensor` with shape `[batch]` for the global_step.
          loss: `Tensor` with shape `[batch]` for the training loss.
          lr: `Tensor` with shape `[batch]` for the learning_rate.
          ce: `Tensor` with shape `[batch]` for the current_epoch.

        Returns:
          List of summary ops to run on the CPU host.
        """
        gs = gs[0]
        # Host call fns are executed params['iterations_per_loop'] times after
        # one TPU loop is finished, setting max_queue value to the same as
        # number of iterations will make the summary writer only flush the data
        # to storage once per loop.
        with tf2.summary.create_file_writer(
            FLAGS.model_dir,
            max_queue=params['iterations_per_loop']).as_default():
          with tf2.summary.record_if(True):
            tf2.summary.scalar('loss', loss[0], step=gs)
            tf2.summary.scalar('learning_rate', lr[0], step=gs)
            tf2.summary.scalar('current_epoch', ce[0], step=gs)

          return tf.summary.all_v2_summary_ops()

      # To log the loss, current learning rate, and epoch for Tensorboard, the
      # summary op needs to be run on the host CPU via host_call. host_call
      # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
      # dimension. These Tensors are implicitly concatenated to
      # [params['batch_size']].
      gs_t = tf.reshape(global_step, [1])
      loss_t = tf.reshape(loss, [1])
      lr_t = tf.reshape(learning_rate, [1])
      ce_t = tf.reshape(current_epoch, [1])

      host_call = (host_call_fn, [gs_t, loss_t, lr_t, ce_t])

  else:
    train_op = None

  eval_metrics = None
  return tf.estimator.tpu.TPUEstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      host_call=host_call,
      eval_metrics=eval_metrics)
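
Usage note (not part of the file above): a minimal sketch of how a model_fn like this one is typically handed to a TPUEstimator, assuming the TF 1.x-style tf.estimator.tpu API that the file itself uses. The parameter values and train_input_fn below are illustrative assumptions, not taken from the SSELFI project.

import tensorflow.compat.v1 as tf

# Illustrative hyperparameters covering the keys resnet_model_fn reads.
model_params = {
    'data_format': 'channels_last', 'transpose_input': False, 'image_size': 128,
    'dropblock_groups': '', 'dropblock_keep_prob': 0.9, 'dropblock_size': 7,
    'train_steps': 10000, 'resnet_depth': 50, 'num_label_classes': 2,
    'precision': 'float32', 'label_smoothing': 0.0, 'training_loss': 'MSE',
    'num_train_images': 100000, 'train_batch_size': 64, 'enable_lars': False,
    'weight_decay': 1e-4, 'momentum': 0.9, 'use_tpu': False,
    'skip_host_call': True, 'iterations_per_loop': 100,
}

run_config = tf.estimator.tpu.RunConfig(
    model_dir='/tmp/sselfi_model',
    tpu_config=tf.estimator.tpu.TPUConfig(iterations_per_loop=100))

estimator = tf.estimator.tpu.TPUEstimator(
    model_fn=resnet_model_fn,
    use_tpu=False,               # run the same graph on CPU/GPU for local tests
    train_batch_size=64,         # surfaced to the model_fn as params['batch_size']
    eval_batch_size=64,
    predict_batch_size=64,
    params=model_params,         # 'batch_size' itself is injected by TPUEstimator
    config=run_config)

# train_input_fn is assumed to be defined elsewhere; it receives the same
# params dict, including the injected 'batch_size'.
estimator.train(input_fn=train_input_fn, max_steps=model_params['train_steps'])
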
Code Example #2
File: resnet_main.py  Project: mypolarbear/tpu
def resnet_model_fn(features, labels, mode, params):
    """The model_fn for ResNet to be used with TPUEstimator.

  Args:
    features: `Tensor` of batched images. If transpose_input is enabled, it
        is transposed to device layout and reshaped to 1D tensor.
    labels: `Tensor` of labels for the data samples
    mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`
    params: `dict` of parameters passed to the model from the TPUEstimator,
        `params['batch_size']` is always provided and should be used as the
        effective batch size.

  Returns:
    A `TPUEstimatorSpec` for the model
  """
    if isinstance(features, dict):
        features = features['feature']

    # In most cases, the default data format NCHW instead of NHWC should be
    # used for a significant performance boost on GPU/TPU. NHWC should be used
    # only if the network needs to be run on CPU since the pooling operations
    # are only supported on NHWC.
    if params['data_format'] == 'channels_first':
        assert not params['transpose_input']  # channels_first only for GPU
        features = tf.transpose(features, [0, 3, 1, 2])

    if params['transpose_input'] and mode != tf.estimator.ModeKeys.PREDICT:
        image_size = tf.cast(
            tf.sqrt(tf.shape(features)[0] / (3 * tf.shape(labels)[0])), tf.int32)
        features = tf.reshape(features, [image_size, image_size, 3, -1])
        features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

    # Normalize the image to zero mean and unit variance.
    features -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=features.dtype)
    features /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=features.dtype)

    # DropBlock keep_prob for the 4 block groups of ResNet architecture.
    # None means applying no DropBlock at the corresponding block group.
    dropblock_keep_probs = [None] * 4
    if params['dropblock_groups']:
        # Scheduled keep_prob for DropBlock.
        train_steps = tf.cast(params['train_steps'], tf.float32)
        current_step = tf.cast(tf.train.get_global_step(), tf.float32)
        current_ratio = current_step / train_steps
        dropblock_keep_prob = (1 - current_ratio *
                               (1 - params['dropblock_keep_prob']))

        # Computes DropBlock keep_prob for different block groups of ResNet.
        dropblock_groups = [
            int(x) for x in params['dropblock_groups'].split(',')
        ]
        for block_group in dropblock_groups:
            if block_group < 1 or block_group > 4:
                raise ValueError(
                    'dropblock_groups should be a comma separated list of integers '
                    'between 1 and 4 (dropblock_groups: {}).'.format(
                        params['dropblock_groups']))
            dropblock_keep_probs[block_group - 1] = 1 - (
                (1 - dropblock_keep_prob) / 4.0**(4 - block_group))

    # This nested function allows us to avoid duplicating the logic which
    # builds the network, for different values of --precision.
    def build_network():
        network = resnet_model.resnet_v1(
            resnet_depth=params['resnet_depth'],
            num_classes=params['num_label_classes'],
            dropblock_size=params['dropblock_size'],
            dropblock_keep_probs=dropblock_keep_probs,
            data_format=params['data_format'])
        return network(inputs=features,
                       is_training=(mode == tf.estimator.ModeKeys.TRAIN))

    if params['precision'] == 'bfloat16':
        with tf.contrib.tpu.bfloat16_scope():
            logits = build_network()
        logits = tf.cast(logits, tf.float32)
    elif params['precision'] == 'float32':
        logits = build_network()

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'classify': tf.estimator.export.PredictOutput(predictions)
            })

    # If necessary, in the model_fn, use params['batch_size'] instead of the
    # batch size flags (--train_batch_size or --eval_batch_size).
    batch_size = params['batch_size']  # pylint: disable=unused-variable

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    one_hot_labels = tf.one_hot(labels, params['num_label_classes'])
    cross_entropy = tf.losses.softmax_cross_entropy(
        logits=logits,
        onehot_labels=one_hot_labels,
        label_smoothing=params['label_smoothing'])
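    # With label_smoothing = eps, the targets used above become
    # one_hot_labels * (1 - eps) + eps / num_label_classes.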

    # Add weight decay to the loss for non-batch-normalization variables.
    loss = cross_entropy + params['weight_decay'] * tf.add_n([
        tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])

    host_call = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        # Compute the current epoch and associated learning rate from global_step.
        global_step = tf.train.get_global_step()
        steps_per_epoch = params['num_train_images'] / params[
            'train_batch_size']
        current_epoch = (tf.cast(global_step, tf.float32) / steps_per_epoch)
        # LARS is a large-batch optimizer. LARS enables higher accuracy at
        # batch sizes of 16K and larger.
        if params['train_batch_size'] >= 16384 and params['enable_lars']:
            learning_rate = 0.0
            optimizer = lars_util.init_lars_optimizer(current_epoch, params)
        else:
            learning_rate = learning_rate_schedule(params, current_epoch)
            optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                                   momentum=params['momentum'],
                                                   use_nesterov=True)
        if params['use_tpu']:
            # When using TPU, wrap the optimizer with CrossShardOptimizer which
            # handles synchronization details between different TPU cores. To the
            # user, this should look like regular synchronous training.
            optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

        # Batch normalization requires UPDATE_OPS to be added as a dependency to
        # the train operation.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step)

        if not params['skip_host_call']:

            def host_call_fn(gs, loss, lr, ce):
                """Training host call. Creates scalar summaries for training metrics.

        This function is executed on the CPU and should not directly reference
        any Tensors in the rest of the `model_fn`. To pass Tensors from the
        model to the `host_call_fn`, provide them as part of the `host_call`. See
        https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
        for more information.

        Arguments should match the list of `Tensor` objects passed as the second
        element in the tuple passed to `host_call`.

        Args:
          gs: `Tensor` with shape `[batch]` for the global_step.
          loss: `Tensor` with shape `[batch]` for the training loss.
          lr: `Tensor` with shape `[batch]` for the learning_rate.
          ce: `Tensor` with shape `[batch]` for the current_epoch.

        Returns:
          List of summary ops to run on the CPU host.
        """
                gs = gs[0]
                # Host call fns are executed params['iterations_per_loop'] times after
                # one TPU loop is finished, setting max_queue value to the same as
                # number of iterations will make the summary writer only flush the data
                # to storage once per loop.
                with summary.create_file_writer(
                        FLAGS.model_dir,
                        max_queue=params['iterations_per_loop']).as_default():
                    with summary.always_record_summaries():
                        summary.scalar('loss', loss[0], step=gs)
                        summary.scalar('learning_rate', lr[0], step=gs)
                        summary.scalar('current_epoch', ce[0], step=gs)

                        return summary.all_summary_ops()

            # To log the loss, current learning rate, and epoch for Tensorboard, the
            # summary op needs to be run on the host CPU via host_call. host_call
            # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
            # dimension. These Tensors are implicitly concatenated to
            # [params['batch_size']].
            gs_t = tf.reshape(global_step, [1])
            loss_t = tf.reshape(loss, [1])
            lr_t = tf.reshape(learning_rate, [1])
            ce_t = tf.reshape(current_epoch, [1])

            host_call = (host_call_fn, [gs_t, loss_t, lr_t, ce_t])

    else:
        train_op = None

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(labels, logits):
            """Evaluation metric function. Evaluates accuracy.

      This function is executed on the CPU and should not directly reference
      any Tensors in the rest of the `model_fn`. To pass Tensors from the model
      to the `metric_fn`, provide as part of the `eval_metrics`. See
      https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
      for more information.

      Arguments should match the list of `Tensor` objects passed as the second
      element in the tuple passed to `eval_metrics`.

      Args:
        labels: `Tensor` with shape `[batch]`.
        logits: `Tensor` with shape `[batch, num_classes]`.

      Returns:
        A dict of the metrics to return from evaluation.
      """
            predictions = tf.argmax(logits, axis=1)
            top_1_accuracy = tf.metrics.accuracy(labels, predictions)
            in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
            top_5_accuracy = tf.metrics.mean(in_top_5)

            return {
                'top_1_accuracy': top_1_accuracy,
                'top_5_accuracy': top_5_accuracy,
            }

        eval_metrics = (metric_fn, [labels, logits])

    return tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           train_op=train_op,
                                           host_call=host_call,
                                           eval_metrics=eval_metrics)
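
Side note (an illustration, not part of the file above): the transpose_input branch undoes the layout transformation applied by the input pipeline, which, per the docstring, transposes each batch to device layout (HWCN) and flattens it to a 1-D tensor. A small NumPy sketch of the round trip:

import numpy as np

nhwc = np.arange(2 * 4 * 4 * 3, dtype=np.float32).reshape([2, 4, 4, 3])
flat_hwcn = np.transpose(nhwc, [1, 2, 3, 0]).reshape(-1)    # what the pipeline emits
recovered = np.transpose(flat_hwcn.reshape(4, 4, 3, -1), [3, 0, 1, 2])
print(np.array_equal(recovered, nhwc))                      # True
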
Code Example #3
def resnet_model_fn(features, labels, mode, params):
  """The model_fn for ResNet to be used with TPUEstimator.

  Args:
    features: `Tensor` of batched images.
    labels: `Tensor` of labels for the data samples
    mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`
    params: `dict` of parameters passed to the model from the TPUEstimator,
        `params['batch_size']` is always provided and should be used as the
        effective batch size.

  Returns:
    A `TPUEstimatorSpec` for the model
  """
  if isinstance(features, dict):
    features = features['feature']

  # In most cases, the default data format NCHW instead of NHWC should be
  # used for a significant performance boost on GPU/TPU. NHWC should be used
  # only if the network needs to be run on CPU since the pooling operations
  # are only supported on NHWC.
  if FLAGS.data_format == 'channels_first':
    assert not FLAGS.transpose_input    # channels_first only for GPU
    features = tf.transpose(features, [0, 3, 1, 2])

  if FLAGS.transpose_input and mode != tf.estimator.ModeKeys.PREDICT:
    features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

  # Normalize the image to zero mean and unit variance.
  features -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=features.dtype)
  features /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=features.dtype)

  # This nested function allows us to avoid duplicating the logic which
  # builds the network, for different values of --precision.
  def build_network():
    network = resnet_model.resnet_v1(
        resnet_depth=FLAGS.resnet_depth,
        num_classes=FLAGS.num_label_classes,
        data_format=FLAGS.data_format)
    return network(
        inputs=features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))

  if FLAGS.precision == 'bfloat16':
    with tf.contrib.tpu.bfloat16_scope():
      logits = build_network()
    logits = tf.cast(logits, tf.float32)
  elif FLAGS.precision == 'float32':
    logits = build_network()

  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.sigmoid(logits, name='sigmoid_tensor')
    }
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        export_outputs={
            'classify': tf.estimator.export.PredictOutput(predictions)
        })

  # If necessary, in the model_fn, use params['batch_size'] instead of the
  # batch size flags (--train_batch_size or --eval_batch_size).
  batch_size = params['batch_size']   # pylint: disable=unused-variable

  # Calculate loss, which includes classification loss and L2 regularization.
  one_hot_labels = tf.one_hot(labels, FLAGS.num_label_classes)

  # Class-balanced weights based on the inverse effective number of samples
  # per class (a small worked example follows this listing).
  img_num_per_cls = [int(line.strip()) for line in open(
        FLAGS.img_num_per_cls_file, 'r')]
  effective_num = 1.0 - np.power(FLAGS.beta, img_num_per_cls)
  weights = (1.0 - FLAGS.beta) / np.array(effective_num)
  weights = weights / np.sum(weights) * FLAGS.num_label_classes

  weights = tf.cast(weights, dtype=tf.float32)
  weights = tf.expand_dims(weights, 0)
  weights = tf.tile(weights, [tf.shape(one_hot_labels)[0], 1]) * one_hot_labels
  weights = tf.reduce_sum(weights, axis=1)
  weights = tf.expand_dims(weights, 1)
  weights = tf.tile(weights, [1, FLAGS.num_label_classes])

  classification_loss = focal_loss(
      one_hot_labels, logits, weights, FLAGS.gamma)
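  # focal_loss is defined elsewhere in this project; given the arguments above,
  # it presumably applies the class-balanced `weights` and down-weights
  # well-classified examples with a modulating factor of the form (1 - p)**gamma.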

  # Add weight decay to the loss for non-batch-normalization variables.
  loss = classification_loss + FLAGS.weight_decay * tf.add_n(
      [tf.nn.l2_loss(v) for v in tf.trainable_variables()
       if 'batch_normalization' not in v.name and 'dense/bias' not in v.name])

  host_call = None
  if mode == tf.estimator.ModeKeys.TRAIN:
    # Compute the current epoch and associated learning rate from global_step.
    global_step = tf.train.get_global_step()
    steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
    current_epoch = (tf.cast(global_step, tf.float32) /
                     steps_per_epoch)
    # LARS is a large-batch optimizer. LARS enables higher accuracy at batch
    # sizes of 16K and larger.
    if FLAGS.train_batch_size >= 16384 and FLAGS.enable_lars:
      learning_rate = 0.0
      optimizer = lars_util.init_lars_optimizer(current_epoch)
    else:
      learning_rate = learning_rate_schedule(current_epoch)
      optimizer = tf.train.MomentumOptimizer(
          learning_rate=learning_rate,
          momentum=FLAGS.momentum,
          use_nesterov=True)
    if FLAGS.use_tpu:
      # When using TPU, wrap the optimizer with CrossShardOptimizer which
      # handles synchronization details between different TPU cores. To the
      # user, this should look like regular synchronous training.
      optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

    # Batch normalization requires UPDATE_OPS to be added as a dependency to
    # the train operation.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss, global_step)

    if not FLAGS.skip_host_call:
      def host_call_fn(gs, fl_loss, loss, lr, ce):
        """Training host call. Creates scalar summaries for training metrics.

        This function is executed on the CPU and should not directly reference
        any Tensors in the rest of the `model_fn`. To pass Tensors from the
        model to the `host_call_fn`, provide them as part of the `host_call`. See
        https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
        for more information.

        Arguments should match the list of `Tensor` objects passed as the second
        element in the tuple passed to `host_call`.

        Args:
          gs: `Tensor` with shape `[batch]` for the global_step.
          fl_loss: `Tensor` with shape `[batch]` for the training focal loss.
          loss: `Tensor` with shape `[batch]` for the training loss.
          lr: `Tensor` with shape `[batch]` for the learning_rate.
          ce: `Tensor` with shape `[batch]` for the current_epoch.

        Returns:
          List of summary ops to run on the CPU host.
        """
        gs = gs[0]
        # Host call fns are executed FLAGS.iterations_per_loop times after one
        # TPU loop is finished, setting max_queue value to the same as number of
        # iterations will make the summary writer only flush the data to storage
        # once per loop.
        with summary.create_file_writer(
            FLAGS.model_dir, max_queue=FLAGS.iterations_per_loop).as_default():
          with summary.always_record_summaries():
            summary.scalar('focal_loss', fl_loss[0], step=gs)
            summary.scalar('loss', loss[0], step=gs)
            summary.scalar('learning_rate', lr[0], step=gs)
            summary.scalar('current_epoch', ce[0], step=gs)

            return summary.all_summary_ops()

      # To log the loss, current learning rate, and epoch for Tensorboard, the
      # summary op needs to be run on the host CPU via host_call. host_call
      # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
      # dimension. These Tensors are implicitly concatenated to
      # [params['batch_size']].
      gs_t = tf.reshape(global_step, [1])
      fl_loss_t = tf.reshape(classification_loss, [1])
      loss_t = tf.reshape(loss, [1])
      lr_t = tf.reshape(learning_rate, [1])
      ce_t = tf.reshape(current_epoch, [1])

      host_call = (host_call_fn, [gs_t, fl_loss_t, loss_t, lr_t, ce_t])

  else:
    train_op = None

  eval_metrics = None
  if mode == tf.estimator.ModeKeys.EVAL:
    def metric_fn(labels, logits):
      """Evaluation metric function. Evaluates accuracy.

      This function is executed on the CPU and should not directly reference
      any Tensors in the rest of the `model_fn`. To pass Tensors from the model
      to the `metric_fn`, provide as part of the `eval_metrics`. See
      https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
      for more information.

      Arguments should match the list of `Tensor` objects passed as the second
      element in the tuple passed to `eval_metrics`.

      Args:
        labels: `Tensor` with shape `[batch]`.
        logits: `Tensor` with shape `[batch, num_classes]`.

      Returns:
        A dict of the metrics to return from evaluation.
      """
      predictions = tf.argmax(logits, axis=1)
      top_1_accuracy = tf.metrics.accuracy(labels, predictions)
      in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
      top_5_accuracy = tf.metrics.mean(in_top_5)

      return {
          'top_1_accuracy': top_1_accuracy,
          'top_5_accuracy': top_5_accuracy,
      }

    eval_metrics = (metric_fn, [labels, logits])

  return tf.contrib.tpu.TPUEstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      host_call=host_call,
      eval_metrics=eval_metrics)
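
Worked example (an illustration, not taken from the file above): the class weights computed from img_num_per_cls follow the "effective number of samples" scheme, weighting each class by (1 - beta) / (1 - beta**n_c) and rescaling so the weights sum to num_label_classes. A small NumPy sketch with made-up class counts:

import numpy as np

beta = 0.999
img_num_per_cls = np.array([1000, 100, 10])            # long-tailed class counts
effective_num = 1.0 - np.power(beta, img_num_per_cls)
weights = (1.0 - beta) / effective_num
weights = weights / np.sum(weights) * len(img_num_per_cls)
print(weights)   # approx. [0.04, 0.28, 2.68]: the rarest class gets the largest weight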