Example #1
def model_fn(features, mode, params):
    '''The model_fn to be used with TPUEstimator.

  Args:
    features: `dict` of batched input `Tensor`s; `features['image']` holds the
        images and `features['label']` the labels (this variant takes no separate
        `labels` argument).
    mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`
    params: `dict` of parameters passed to the model from the TPUEstimator,
        `params['batch_size']` is always provided and should be used as the
        effective batch size.

  Returns:
    A `TPUEstimatorSpec` for the model
  '''
    def preprocess_image(image):
        # In most cases, the default data format NCHW instead of NHWC should be
        # used for a significant performance boost on GPU. NHWC should be used
        # only if the network needs to be run on CPU since the pooling operations
        # are only supported on NHWC. TPU uses XLA compiler to figure out best layout.
        if FLAGS.data_format == 'channels_first':
            assert not FLAGS.transpose_input  # channels_first only for GPU
            image = tf.transpose(image, [0, 3, 1, 2])

        if FLAGS.transpose_input and mode == tf.estimator.ModeKeys.TRAIN:
            image = tf.transpose(image, [3, 0, 1, 2])  # HWCN to NHWC
        return image

    def normalize_image(image):
        # Normalize the image to zero mean and unit variance.
        if FLAGS.data_format == 'channels_first':
            stats_shape = [3, 1, 1]
        else:
            stats_shape = [1, 1, 3]
        mean, std = task_info.get_mean_std(FLAGS.task_name)
        image -= tf.constant(mean, shape=stats_shape, dtype=image.dtype)
        image /= tf.constant(std, shape=stats_shape, dtype=image.dtype)
        return image

    image = features['image']
    image = preprocess_image(image)

    image_shape = image.get_shape().as_list()
    tf.logging.info('image shape: {}'.format(image_shape))
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    if mode != tf.estimator.ModeKeys.PREDICT:
        labels = features['label']
    else:
        labels = None

    # If necessary, in the model_fn, use params['batch_size'] instead of the batch
    # size flags (--train_batch_size or --eval_batch_size).
    batch_size = params['batch_size']

    if FLAGS.unlabel_ratio and is_training:
        unl_bsz = features['unl_probs'].shape[0]
    else:
        unl_bsz = 0

    lab_bsz = image.shape[0] - unl_bsz
    assert lab_bsz == batch_size

    metric_dict = {}
    global_step = tf.train.get_global_step()

    has_moving_average_decay = (FLAGS.moving_average_decay > 0)
    # This is essential, if using a keras-derived model.
    tf.keras.backend.set_learning_phase(is_training)
    tf.logging.info('Using open-source implementation.')
    override_params = {}
    if FLAGS.dropout_rate is not None:
        override_params['dropout_rate'] = FLAGS.dropout_rate
    if FLAGS.stochastic_depth_rate is not None:
        override_params['stochastic_depth_rate'] = FLAGS.stochastic_depth_rate
    if FLAGS.data_format:
        override_params['data_format'] = FLAGS.data_format
    if FLAGS.num_label_classes:
        override_params['num_classes'] = FLAGS.num_label_classes
    if FLAGS.depth_coefficient:
        override_params['depth_coefficient'] = FLAGS.depth_coefficient
    if FLAGS.width_coefficient:
        override_params['width_coefficient'] = FLAGS.width_coefficient

    def build_model(scope=None,
                    reuse=tf.AUTO_REUSE,
                    model_name=None,
                    model_is_training=None,
                    input_image=None,
                    use_adv_bn=False,
                    is_teacher=False):
        model_name = model_name or FLAGS.model_name
        if model_is_training is None:
            model_is_training = is_training
        if input_image is None:
            input_image = image
        input_image = normalize_image(input_image)

        scope_model_name = model_name

        if scope:
            scope = scope + '/'
        else:
            scope = ''
        with tf.variable_scope(scope + scope_model_name, reuse=reuse):
            if model_name.startswith('efficientnet'):
                logits, _ = efficientnet_builder.build_model(
                    input_image,
                    model_name=model_name,
                    training=model_is_training,
                    override_params=override_params,
                    model_dir=FLAGS.model_dir,
                    use_adv_bn=use_adv_bn,
                    is_teacher=is_teacher)
            else:
                assert False, 'model {} not implemented'.format(model_name)
        return logits

    if params['use_bfloat16']:
        with tf.tpu.bfloat16_scope():
            logits = tf.cast(build_model(), tf.float32)
    else:
        logits = build_model()

    if FLAGS.teacher_model_name:
        teacher_image = preprocess_image(features['teacher_image'])
        if params['use_bfloat16']:
            with tf.tpu.bfloat16_scope():
                teacher_logits = tf.cast(
                    build_model(scope='teacher_model',
                                model_name=FLAGS.teacher_model_name,
                                model_is_training=False,
                                input_image=teacher_image,
                                is_teacher=True), tf.float32)
        else:
            teacher_logits = build_model(scope='teacher_model',
                                         model_name=FLAGS.teacher_model_name,
                                         model_is_training=False,
                                         input_image=teacher_image,
                                         is_teacher=True)
        teacher_logits = tf.stop_gradient(teacher_logits)
        if FLAGS.teacher_softmax_temp != -1:
            teacher_prob = tf.nn.softmax(teacher_logits /
                                         FLAGS.teacher_softmax_temp)
        else:
            teacher_prob = None
            teacher_one_hot_pred = tf.argmax(teacher_logits,
                                             axis=1,
                                             output_type=labels.dtype)

    if mode == tf.estimator.ModeKeys.PREDICT:
        if has_moving_average_decay:
            ema = tf.train.ExponentialMovingAverage(
                decay=FLAGS.moving_average_decay)
            ema_vars = utils.get_all_variable()
            restore_vars_dict = ema.variables_to_restore(ema_vars)
            tf.logging.info(
                'restored variables:\n%s',
                json.dumps(sorted(restore_vars_dict.keys()), indent=4))

        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }
        return tf.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            predictions=predictions,
            scaffold_fn=functools.partial(_scaffold_fn,
                                          restore_vars_dict=restore_vars_dict)
            if has_moving_average_decay else None)

    if has_moving_average_decay:
        ema_step = global_step
        ema = tf.train.ExponentialMovingAverage(
            decay=FLAGS.moving_average_decay, num_updates=ema_step)
        ema_vars = utils.get_all_variable()

    lab_labels = labels[:lab_bsz]
    lab_logits = logits[:lab_bsz]
    lab_pred = tf.argmax(lab_logits, axis=-1, output_type=labels.dtype)
    lab_prob = tf.nn.softmax(lab_logits)
    lab_acc = tf.to_float(tf.equal(lab_pred, lab_labels))
    metric_dict['lab/acc'] = tf.reduce_mean(lab_acc)
    metric_dict['lab/pred_prob'] = tf.reduce_mean(
        tf.reduce_max(lab_prob, axis=-1))

    if FLAGS.unlabel_ratio:
        unl_labels = labels[lab_bsz:]
        unl_logits = logits[lab_bsz:]
        unl_pred = tf.argmax(unl_logits, axis=-1, output_type=labels.dtype)
        unl_prob = tf.nn.softmax(unl_logits)
        unl_acc = tf.to_float(tf.equal(unl_pred, unl_labels))
        metric_dict['unl/acc_to_dump'] = tf.reduce_mean(unl_acc)
        metric_dict['unl/pred_prob'] = tf.reduce_mean(
            tf.reduce_max(unl_prob, axis=-1))

    # compute lab_loss
    one_hot_labels = tf.one_hot(lab_labels, FLAGS.num_label_classes)
    lab_loss = tf.losses.softmax_cross_entropy(
        logits=lab_logits,
        onehot_labels=one_hot_labels,
        label_smoothing=FLAGS.label_smoothing,
        reduction=tf.losses.Reduction.NONE)
    if FLAGS.label_data_sample_prob != 1:
        # Mask out part of the labeled data: floor(p + U[0, 1)) is a Bernoulli(p)
        # 0/1 mask, so each example is kept with probability label_data_sample_prob.
        random_mask = tf.floor(
            FLAGS.label_data_sample_prob +
            tf.random_uniform(tf.shape(lab_loss), dtype=lab_loss.dtype))
        lab_loss = tf.reduce_mean(lab_loss * random_mask)
    else:
        lab_loss = tf.reduce_mean(lab_loss)
    metric_dict['lab/loss'] = lab_loss

    if FLAGS.unlabel_ratio:
        if FLAGS.teacher_softmax_temp == -1:  # Hard labels
            # Get one-hot labels
            if FLAGS.teacher_model_name:
                ext_teacher_pred = teacher_one_hot_pred[lab_bsz:]
                one_hot_labels = tf.one_hot(ext_teacher_pred,
                                            FLAGS.num_label_classes)
            else:
                one_hot_labels = tf.one_hot(unl_labels,
                                            FLAGS.num_label_classes)
            # Compute cross entropy
            unl_loss = tf.losses.softmax_cross_entropy(
                logits=unl_logits,
                onehot_labels=one_hot_labels,
                label_smoothing=FLAGS.label_smoothing)
        else:  # Soft labels
            # Get teacher prob
            if FLAGS.teacher_model_name:
                unl_teacher_prob = teacher_prob[lab_bsz:]
            else:
                scaled_prob = tf.pow(features['unl_probs'],
                                     1 / FLAGS.teacher_softmax_temp)
                unl_teacher_prob = scaled_prob / tf.reduce_sum(
                    scaled_prob, axis=-1, keepdims=True)
            metric_dict['unl/target_prob'] = tf.reduce_mean(
                tf.reduce_max(unl_teacher_prob, axis=-1))
            unl_loss = cross_entropy(unl_teacher_prob,
                                     unl_logits,
                                     return_mean=True)

        metric_dict['ext/loss'] = unl_loss
    else:
        unl_loss = 0

    real_lab_bsz = tf.to_float(lab_bsz) * FLAGS.label_data_sample_prob
    real_unl_bsz = batch_size * FLAGS.label_data_sample_prob * FLAGS.unlabel_ratio
    data_loss = lab_loss * real_lab_bsz + unl_loss * real_unl_bsz
    data_loss = data_loss / real_lab_bsz

    # Add weight decay to the loss for non-batch-normalization variables.
    loss = data_loss + FLAGS.weight_decay * tf.add_n([
        tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])
    metric_dict['train/data_loss'] = data_loss
    metric_dict['train/loss'] = loss

    host_call = None
    restore_vars_dict = None

    if is_training:
        # Compute the current epoch and associated learning rate from global_step.
        current_epoch = (tf.cast(global_step, tf.float32) /
                         params['steps_per_epoch'])
        real_train_batch_size = FLAGS.train_batch_size
        real_train_batch_size *= FLAGS.label_data_sample_prob
        scaled_lr = FLAGS.base_learning_rate * (real_train_batch_size / 256.0)
        if FLAGS.final_base_lr:
            # total number of training epochs
            total_epochs = (FLAGS.train_steps * FLAGS.train_batch_size * 1. /
                            FLAGS.num_train_images - 5)
            decay_times = math.log(FLAGS.final_base_lr /
                                   FLAGS.base_learning_rate) / math.log(0.97)
            decay_epochs = total_epochs / decay_times
            tf.logging.info(
                'setting decay_epochs to {:.2f}'.format(decay_epochs) +
                '\n' * 3)
        else:
            decay_epochs = 2.4 * FLAGS.train_ratio
        learning_rate = utils.build_learning_rate(
            scaled_lr,
            global_step,
            params['steps_per_epoch'],
            decay_epochs=decay_epochs,
            start_from_step=FLAGS.train_steps - FLAGS.train_last_step_num,
            warmup_epochs=5,
        )
        metric_dict['train/lr'] = learning_rate
        metric_dict['train/epoch'] = current_epoch
        optimizer = utils.build_optimizer(learning_rate)
        if FLAGS.use_tpu:
            # When using TPU, wrap the optimizer with CrossShardOptimizer which
            # handles synchronization details between different TPU cores. To the
            # user, this should look like regular synchronous training.
            optimizer = tf.tpu.CrossShardOptimizer(optimizer)

        # Batch normalization requires UPDATE_OPS to be added as a dependency to
        # the train operation.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        tvars = tf.trainable_variables()
        g_vars = []
        tvars = sorted(tvars, key=lambda var: var.name)
        for var in tvars:
            if 'teacher_model' not in var.name:
                g_vars += [var]
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step, var_list=g_vars)

        if has_moving_average_decay:
            with tf.control_dependencies([train_op]):
                train_op = ema.apply(ema_vars)

        if not FLAGS.skip_host_call:
            host_call = utils.construct_scalar_host_call(metric_dict)
        scaffold_fn = None
        if FLAGS.teacher_model_name or FLAGS.init_model:
            scaffold_fn = utils.init_from_ckpt(scaffold_fn)
    else:
        train_op = None
        if has_moving_average_decay:
            # Load moving average variables for eval.
            restore_vars_dict = ema.variables_to_restore(ema_vars)

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:
        scaffold_fn = functools.partial(_scaffold_fn,
                                        restore_vars_dict=restore_vars_dict
                                        ) if has_moving_average_decay else None

        def metric_fn(labels, logits):
            '''Evaluation metric function. Evaluates accuracy.

      This function is executed on the CPU and should not directly reference
      any Tensors in the rest of the `model_fn`. To pass Tensors from the model
      to the `metric_fn`, provide as part of the `eval_metrics`. See
      https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
      for more information.

      Arguments should match the list of `Tensor` objects passed as the second
      element in the tuple passed to `eval_metrics`.

      Args:
        labels: `Tensor` with shape `[batch]`.
        logits: `Tensor` with shape `[batch, num_classes]`.

      Returns:
        A dict of the metrics to return from evaluation.
      '''

            predictions = tf.argmax(logits, axis=1)
            top_1_accuracy = tf.metrics.accuracy(labels, predictions)
            in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
            top_5_accuracy = tf.metrics.mean(in_top_5)

            result_dict = {
                'top_1_accuracy': top_1_accuracy,
                'top_5_accuracy': top_5_accuracy,
            }

            return result_dict

        eval_metrics = (metric_fn, [labels, logits])

    num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
    tf.logging.info('number of trainable parameters: {}'.format(num_params))

    return tf.estimator.tpu.TPUEstimatorSpec(mode=mode,
                                             loss=loss,
                                             train_op=train_op,
                                             host_call=host_call,
                                             eval_metrics=eval_metrics,
                                             scaffold_fn=scaffold_fn)
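
Example #1 calls a `cross_entropy(target_prob, logits, return_mean=True)` helper for its soft-label (teacher-distillation) loss that is not included in the snippet. A minimal sketch of what such a helper typically computes, offered as an assumption rather than the repository's actual implementation:

import tensorflow.compat.v1 as tf

def cross_entropy(target_prob, logits, return_mean=True):
    """Soft-label cross entropy: -sum_c target_prob[:, c] * log_softmax(logits)[:, c]."""
    log_prob = tf.nn.log_softmax(logits, axis=-1)
    per_example_loss = -tf.reduce_sum(target_prob * log_prob, axis=-1)
    if return_mean:
        return tf.reduce_mean(per_example_loss)
    return per_example_loss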
Example #2
def model_fn(features, labels, mode, params):
    """The model_fn to be used with TPUEstimator.

  Args:
    features: `Tensor` of batched images.
    labels: `Tensor` of one hot labels for the data samples
    mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`
    params: `dict` of parameters passed to the model from the TPUEstimator,
        `params['batch_size']` is always provided and should be used as the
        effective batch size.

  Returns:
    A `TPUEstimatorSpec` for the model
  """
    if isinstance(features, dict):
        features = features['feature']

    # In most cases, the default data format NCHW instead of NHWC should be
    # used for a significant performance boost on GPU. NHWC should be used
    # only if the network needs to be run on CPU since the pooling operations
    # are only supported on NHWC. TPU uses XLA compiler to figure out best layout.
    if FLAGS.data_format == 'channels_first':
        assert not FLAGS.transpose_input  # channels_first only for GPU
        features = tf.transpose(features, [0, 3, 1, 2])
        stats_shape = [3, 1, 1]
    else:
        stats_shape = [1, 1, 3]

    if FLAGS.transpose_input and mode != tf.estimator.ModeKeys.PREDICT:
        features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    has_moving_average_decay = (FLAGS.moving_average_decay > 0)
    # This is essential, if using a keras-derived model.
    tf.keras.backend.set_learning_phase(is_training)
    logging.info('Using open-source implementation.')
    override_params = {}
    if FLAGS.batch_norm_momentum is not None:
        override_params['batch_norm_momentum'] = FLAGS.batch_norm_momentum
    if FLAGS.batch_norm_epsilon is not None:
        override_params['batch_norm_epsilon'] = FLAGS.batch_norm_epsilon
    if FLAGS.dropout_rate is not None:
        override_params['dropout_rate'] = FLAGS.dropout_rate
    if FLAGS.survival_prob is not None:
        override_params['survival_prob'] = FLAGS.survival_prob
    if FLAGS.data_format:
        override_params['data_format'] = FLAGS.data_format
    if FLAGS.num_label_classes:
        override_params['num_classes'] = FLAGS.num_label_classes
    if FLAGS.depth_coefficient:
        override_params['depth_coefficient'] = FLAGS.depth_coefficient
    if FLAGS.width_coefficient:
        override_params['width_coefficient'] = FLAGS.width_coefficient

    def normalize_features(features, mean_rgb, stddev_rgb):
        """Normalize the image given the means and stddevs."""
        features -= tf.constant(mean_rgb,
                                shape=stats_shape,
                                dtype=features.dtype)
        features /= tf.constant(stddev_rgb,
                                shape=stats_shape,
                                dtype=features.dtype)
        return features

    def build_model():
        """Build model using the model_name given through the command line."""
        model_builder = model_builder_factory.get_model_builder(
            FLAGS.model_name)
        normalized_features = normalize_features(features,
                                                 model_builder.MEAN_RGB,
                                                 model_builder.STDDEV_RGB)
        logits, _ = model_builder.build_model(normalized_features,
                                              model_name=FLAGS.model_name,
                                              training=is_training,
                                              override_params=override_params,
                                              model_dir=FLAGS.model_dir)
        return logits

    if params['use_bfloat16']:
        with tf.tpu.bfloat16_scope():
            logits = tf.cast(build_model(), tf.float32)
    else:
        logits = build_model()

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'classify': tf.estimator.export.PredictOutput(predictions)
            })

    # If necessary, in the model_fn, use params['batch_size'] instead of the batch
    # size flags (--train_batch_size or --eval_batch_size).
    batch_size = params['batch_size']  # pylint: disable=unused-variable

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    cross_entropy = tf.losses.softmax_cross_entropy(
        logits=logits,
        onehot_labels=labels,
        label_smoothing=FLAGS.label_smoothing)

    # Add weight decay to the loss for non-batch-normalization variables.
    loss = cross_entropy + FLAGS.weight_decay * tf.add_n([
        tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])

    global_step = tf.train.get_global_step()
    if has_moving_average_decay:
        ema = tf.train.ExponentialMovingAverage(
            decay=FLAGS.moving_average_decay, num_updates=global_step)
        ema_vars = utils.get_ema_vars()

    host_call = None
    restore_vars_dict = None
    if is_training:
        # Compute the current epoch and associated learning rate from global_step.
        current_epoch = (tf.cast(global_step, tf.float32) /
                         params['steps_per_epoch'])

        scaled_lr = FLAGS.base_learning_rate * (FLAGS.train_batch_size / 256.0)
        logging.info('base_learning_rate = %f', FLAGS.base_learning_rate)
        learning_rate = utils.build_learning_rate(
            scaled_lr,
            global_step,
            params['steps_per_epoch'],
            decay_epochs=FLAGS.lr_decay_epoch)
        optimizer = utils.build_optimizer(learning_rate)
        if FLAGS.use_tpu:
            # When using TPU, wrap the optimizer with CrossShardOptimizer which
            # handles synchronization details between different TPU cores. To the
            # user, this should look like regular synchronous training.
            optimizer = tf.tpu.CrossShardOptimizer(optimizer)

        # Batch normalization requires UPDATE_OPS to be added as a dependency to
        # the train operation.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step)

        if has_moving_average_decay:
            with tf.control_dependencies([train_op]):
                train_op = ema.apply(ema_vars)

        if not FLAGS.skip_host_call:

            def host_call_fn(gs, lr, ce):
                """Training host call. Creates scalar summaries for training metrics.

        This function is executed on the CPU and should not directly reference
        any Tensors in the rest of the `model_fn`. To pass Tensors from the
        model to the `metric_fn`, provide as part of the `host_call`. See
        https://www.tensorflow.org/api_docs/python/tf/estimator/tpu/TPUEstimatorSpec
        for more information.

        Arguments should match the list of `Tensor` objects passed as the second
        element in the tuple passed to `host_call`.

        Args:
          gs: `Tensor` with shape `[batch]` for the global_step.
          lr: `Tensor` with shape `[batch]` for the learning_rate.
          ce: `Tensor` with shape `[batch]` for the current_epoch.

        Returns:
          List of summary ops to run on the CPU host.
        """
                gs = gs[0]
                # Host call fns are executed FLAGS.iterations_per_loop times after one
                # TPU loop is finished, setting max_queue value to the same as number of
                # iterations will make the summary writer only flush the data to storage
                # once per loop.
                with tf2.summary.create_file_writer(
                        FLAGS.model_dir,
                        max_queue=FLAGS.iterations_per_loop).as_default():
                    with tf2.summary.record_if(True):
                        tf2.summary.scalar('learning_rate', lr[0], step=gs)
                        tf2.summary.scalar('current_epoch', ce[0], step=gs)

                        return tf.summary.all_v2_summary_ops()

            # To log the loss, current learning rate, and epoch for Tensorboard, the
            # summary op needs to be run on the host CPU via host_call. host_call
            # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
            # dimension. These Tensors are implicitly concatenated to
            # [params['batch_size']].
            gs_t = tf.reshape(global_step, [1])
            lr_t = tf.reshape(learning_rate, [1])
            ce_t = tf.reshape(current_epoch, [1])

            host_call = (host_call_fn, [gs_t, lr_t, ce_t])

    else:
        train_op = None
        if has_moving_average_decay:
            # Load moving average variables for eval.
            restore_vars_dict = ema.variables_to_restore(ema_vars)

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(labels, logits):
            """Evaluation metric function. Evaluates accuracy.

      This function is executed on the CPU and should not directly reference
      any Tensors in the rest of the `model_fn`. To pass Tensors from the model
      to the `metric_fn`, provide as part of the `eval_metrics`. See
      https://www.tensorflow.org/api_docs/python/tf/estimator/tpu/TPUEstimatorSpec
      for more information.

      Arguments should match the list of `Tensor` objects passed as the second
      element in the tuple passed to `eval_metrics`.

      Args:
        labels: `Tensor` with shape `[batch, num_classes]`.
        logits: `Tensor` with shape `[batch, num_classes]`.

      Returns:
        A dict of the metrics to return from evaluation.
      """
            labels = tf.argmax(labels, axis=1)
            predictions = tf.argmax(logits, axis=1)
            top_1_accuracy = tf.metrics.accuracy(labels, predictions)
            in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
            top_5_accuracy = tf.metrics.mean(in_top_5)

            return {
                'top_1_accuracy': top_1_accuracy,
                'top_5_accuracy': top_5_accuracy,
            }

        eval_metrics = (metric_fn, [labels, logits])

    num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
    logging.info('number of trainable parameters: %d', num_params)

    def _scaffold_fn():
        saver = tf.train.Saver(restore_vars_dict)
        return tf.train.Scaffold(saver=saver)

    if has_moving_average_decay and not is_training:
        # Only apply scaffold for eval jobs.
        scaffold_fn = _scaffold_fn
    else:
        scaffold_fn = None

    return tf.estimator.tpu.TPUEstimatorSpec(mode=mode,
                                             loss=loss,
                                             train_op=train_op,
                                             host_call=host_call,
                                             eval_metrics=eval_metrics,
                                             scaffold_fn=scaffold_fn)
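
Examples #1 and #2 both defer the schedule to `utils.build_learning_rate(scaled_lr, global_step, steps_per_epoch, ...)`. A simplified, hedged sketch of the usual schedule in these TPU training repos (exponential decay with a linear warmup; the real helper also handles extra arguments such as `start_from_step` and may differ in detail):

import tensorflow.compat.v1 as tf

def build_learning_rate(initial_lr, global_step, steps_per_epoch,
                        decay_epochs=2.4, decay_factor=0.97, warmup_epochs=5):
    """Exponentially decayed learning rate with a linear warmup ramp."""
    decayed_lr = tf.train.exponential_decay(
        initial_lr,
        global_step,
        decay_steps=int(decay_epochs * steps_per_epoch),
        decay_rate=decay_factor,
        staircase=True)
    warmup_steps = max(int(warmup_epochs * steps_per_epoch), 1)
    warmup_lr = initial_lr * tf.cast(global_step, tf.float32) / warmup_steps
    return tf.where(global_step < warmup_steps, warmup_lr, decayed_lr)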
Example #3
def mnasnet_model_fn(features, labels, mode, params):
    """The model_fn for MnasNet to be used with TPUEstimator.

  Args:
    features: `Tensor` of batched images.
    labels: `Tensor` of labels for the data samples
    mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`
    params: `dict` of parameters passed to the model from the TPUEstimator,
      `params['batch_size']` is always provided and should be used as the
      effective batch size.

  Returns:
    A `TPUEstimatorSpec` for the model
  """
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    # This is essential, if using a keras-derived model.
    K.set_learning_phase(is_training)

    if isinstance(features, dict):
        features = features['feature']

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Adds an identity node to help TFLite export.
        features = tf.identity(features, 'float_image_input')

    # In most cases, the default data format NCHW instead of NHWC should be
    # used for a significant performance boost on GPU. NHWC should be used
    # only if the network needs to be run on CPU since the pooling operations
    # are only supported on NHWC. TPU uses XLA compiler to figure out best layout.
    if params['data_format'] == 'channels_first':
        assert not params['transpose_input']  # channels_first only for GPU
        features = tf.transpose(features, [0, 3, 1, 2])
        stats_shape = [3, 1, 1]
    else:
        stats_shape = [1, 1, 3]

    if params['transpose_input'] and mode != tf.estimator.ModeKeys.PREDICT:
        features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

    # Normalize the image to zero mean and unit variance.
    features -= tf.constant(imagenet_input.MEAN_RGB,
                            shape=stats_shape,
                            dtype=features.dtype)
    features /= tf.constant(imagenet_input.STDDEV_RGB,
                            shape=stats_shape,
                            dtype=features.dtype)

    has_moving_average_decay = (params['moving_average_decay'] > 0)

    tf.logging.info('Using open-source implementation for MnasNet definition.')
    override_params = {}
    if params['batch_norm_momentum']:
        override_params['batch_norm_momentum'] = params['batch_norm_momentum']
    if params['batch_norm_epsilon']:
        override_params['batch_norm_epsilon'] = params['batch_norm_epsilon']
    if params['dropout_rate']:
        override_params['dropout_rate'] = params['dropout_rate']
    if params['data_format']:
        override_params['data_format'] = params['data_format']
    if params['num_label_classes']:
        override_params['num_classes'] = params['num_label_classes']
    if params['depth_multiplier']:
        override_params['depth_multiplier'] = params['depth_multiplier']
    if params['depth_divisor']:
        override_params['depth_divisor'] = params['depth_divisor']
    if params['min_depth']:
        override_params['min_depth'] = params['min_depth']
    override_params['use_keras'] = params['use_keras']

    if params['precision'] == 'bfloat16':
        with tf.contrib.tpu.bfloat16_scope():
            logits, _ = mnasnet_models.build_mnasnet_model(
                features,
                model_name=params['model_name'],
                training=is_training,
                override_params=override_params)
        logits = tf.cast(logits, tf.float32)
    else:  # params['precision'] == 'float32'
        logits, _ = mnasnet_models.build_mnasnet_model(
            features,
            model_name=params['model_name'],
            training=is_training,
            override_params=override_params)

    if params['quantized_training']:
        if is_training:
            tf.logging.info('Adding fake quantization ops for training.')
            tf.contrib.quantize.create_training_graph(
                quant_delay=int(params['steps_per_epoch'] *
                                FLAGS.quantization_delay_epochs))
        else:
            tf.logging.info('Adding fake quantization ops for evaluation.')
            tf.contrib.quantize.create_eval_graph()

    if mode == tf.estimator.ModeKeys.PREDICT:
        scaffold_fn = None
        if FLAGS.export_moving_average:
            # If the model is trained with moving average decay, to match evaluation
            # metrics, we need to export the model using moving average variables.
            restore_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
            variables_to_restore = get_pretrained_variables_to_restore(
                restore_checkpoint, load_moving_average=True)
            tf.logging.info('Restoring from the latest checkpoint: %s',
                            restore_checkpoint)
            tf.logging.info(str(variables_to_restore))

            def restore_scaffold():
                saver = tf.train.Saver(variables_to_restore)
                return tf.train.Scaffold(saver=saver)

            scaffold_fn = restore_scaffold

        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }
        return tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'classify': tf.estimator.export.PredictOutput(predictions)
            },
            scaffold_fn=scaffold_fn)

    # If necessary, in the model_fn, use params['batch_size'] instead of the batch
    # size flags (--train_batch_size or --eval_batch_size).
    batch_size = params['batch_size']  # pylint: disable=unused-variable

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    one_hot_labels = tf.one_hot(labels, params['num_label_classes'])
    cross_entropy = tf.losses.softmax_cross_entropy(
        logits=logits,
        onehot_labels=one_hot_labels,
        label_smoothing=params['label_smoothing'])

    # Add weight decay to the loss for non-batch-normalization variables.
    loss = cross_entropy + params['weight_decay'] * tf.add_n([
        tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])

    global_step = tf.train.get_global_step()
    if has_moving_average_decay:
        ema = tf.train.ExponentialMovingAverage(
            decay=params['moving_average_decay'], num_updates=global_step)
        ema_vars = utils.get_ema_vars()

    host_call = None
    if is_training:
        # Compute the current epoch and associated learning rate from global_step.
        current_epoch = (tf.cast(global_step, tf.float32) /
                         params['steps_per_epoch'])

        scaled_lr = params['base_learning_rate'] * (params['train_batch_size'] / 256.0)  # pylint: disable=line-too-long
        learning_rate = utils.build_learning_rate(scaled_lr, global_step,
                                                  params['steps_per_epoch'])
        optimizer = utils.build_optimizer(learning_rate)
        if params['use_tpu']:
            # When using TPU, wrap the optimizer with CrossShardOptimizer which
            # handles synchronization details between different TPU cores. To the
            # user, this should look like regular synchronous training.
            optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

        # Batch normalization requires UPDATE_OPS to be added as a dependency to
        # the train operation.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step)

        if has_moving_average_decay:
            with tf.control_dependencies([train_op]):
                train_op = ema.apply(ema_vars)

        if not params['skip_host_call']:

            def host_call_fn(gs, loss, lr, ce):
                """Training host call.

        Creates scalar summaries for training metrics.

        This function is executed on the CPU and should not directly reference
        any Tensors in the rest of the `model_fn`. To pass Tensors from the
        model to the `metric_fn`, provide as part of the `host_call`. See
        https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
        for more information.

        Arguments should match the list of `Tensor` objects passed as the second
        element in the tuple passed to `host_call`.

        Args:
          gs: `Tensor` with shape `[batch]` for the global_step.
          loss: `Tensor` with shape `[batch]` for the training loss.
          lr: `Tensor` with shape `[batch]` for the learning_rate.
          ce: `Tensor` with shape `[batch]` for the current_epoch.

        Returns:
          List of summary ops to run on the CPU host.
        """
                gs = gs[0]
                # Host call fns are executed params['iterations_per_loop'] times after
                # one TPU loop is finished, setting max_queue value to the same as
                # number of iterations will make the summary writer only flush the
                # data to storage once per loop.
                with tf.contrib.summary.create_file_writer(
                        FLAGS.model_dir,
                        max_queue=params['iterations_per_loop']).as_default():
                    with tf.contrib.summary.always_record_summaries():
                        tf.contrib.summary.scalar('loss', loss[0], step=gs)
                        tf.contrib.summary.scalar('learning_rate',
                                                  lr[0],
                                                  step=gs)
                        tf.contrib.summary.scalar('current_epoch',
                                                  ce[0],
                                                  step=gs)

                        return tf.contrib.summary.all_summary_ops()

            # To log the loss, current learning rate, and epoch for Tensorboard, the
            # summary op needs to be run on the host CPU via host_call. host_call
            # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
            # dimension. These Tensors are implicitly concatenated to
            # [params['batch_size']].
            gs_t = tf.reshape(global_step, [1])
            loss_t = tf.reshape(loss, [1])
            lr_t = tf.reshape(learning_rate, [1])
            ce_t = tf.reshape(current_epoch, [1])

            host_call = (host_call_fn, [gs_t, loss_t, lr_t, ce_t])

    else:
        train_op = None

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(labels, logits):
            """Evaluation metric function.

      Evaluates accuracy.

      This function is executed on the CPU and should not directly reference
      any Tensors in the rest of the `model_fn`. To pass Tensors from the model
      to the `metric_fn`, provide as part of the `eval_metrics`. See
      https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
      for more information.

      Arguments should match the list of `Tensor` objects passed as the second
      element in the tuple passed to `eval_metrics`.

      Args:
        labels: `Tensor` with shape `[batch]`.
        logits: `Tensor` with shape `[batch, num_classes]`.

      Returns:
        A dict of the metrics to return from evaluation.
      """
            predictions = tf.argmax(logits, axis=1)
            top_1_accuracy = tf.metrics.accuracy(labels, predictions)
            in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
            top_5_accuracy = tf.metrics.mean(in_top_5)

            return {
                'top_1_accuracy': top_1_accuracy,
                'top_5_accuracy': top_5_accuracy,
            }

        eval_metrics = (metric_fn, [labels, logits])

    num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
    tf.logging.info('number of trainable parameters: {}'.format(num_params))

    # Prepares scaffold_fn if needed.
    scaffold_fn = None
    if is_training and FLAGS.init_checkpoint:
        variables_to_restore = get_pretrained_variables_to_restore(
            FLAGS.init_checkpoint, has_moving_average_decay)
        tf.logging.info('Initializing from pretrained checkpoint: %s',
                        FLAGS.init_checkpoint)
        if FLAGS.use_tpu:

            def init_scaffold():
                tf.train.init_from_checkpoint(FLAGS.init_checkpoint,
                                              variables_to_restore)
                return tf.train.Scaffold()

            scaffold_fn = init_scaffold
        else:
            tf.train.init_from_checkpoint(FLAGS.init_checkpoint,
                                          variables_to_restore)

    restore_vars_dict = None
    if not is_training and has_moving_average_decay:
        # Load moving average variables for eval.
        restore_vars_dict = ema.variables_to_restore(ema_vars)

        def eval_scaffold():
            saver = tf.train.Saver(restore_vars_dict)
            return tf.train.Scaffold(saver=saver)

        scaffold_fn = eval_scaffold

    return tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           train_op=train_op,
                                           host_call=host_call,
                                           eval_metrics=eval_metrics,
                                           scaffold_fn=scaffold_fn)
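
Example #3 relies on a `get_pretrained_variables_to_restore(checkpoint, load_moving_average)` helper defined elsewhere in its module. A rough sketch of the usual pattern (an assumption, not the project's actual code): restrict restoration to variables present in the checkpoint and, when requested, point them at their ExponentialMovingAverage shadow values.

import tensorflow.compat.v1 as tf

def get_pretrained_variables_to_restore(checkpoint_path, load_moving_average=False):
    """Maps checkpoint variable names to graph variables for a tf.train.Saver."""
    ckpt_var_names = {name for name, _ in tf.train.list_variables(checkpoint_path)}
    variables_to_restore = {}
    for var in tf.global_variables():
        name = var.op.name
        if load_moving_average and name + '/ExponentialMovingAverage' in ckpt_var_names:
            name = name + '/ExponentialMovingAverage'
        if name in ckpt_var_names:
            variables_to_restore[name] = var
    return variables_to_restore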
Example #4
def final_model_fn(features, labels, mode, params):
    """The model_fn for ConvNet to be used with TPUEstimator.
    Args:
      features: `Tensor` of batched images.
      labels: `Tensor` of labels for the data samples
      mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`
      params: `dict` of parameters passed to the model from the TPUEstimator,
          `params['batch_size']` is always provided and should be used as the
          effective batch size.
    Returns:
      A `TPUEstimatorSpec` for the model
    """
    if isinstance(features, dict):
        features = features['feature']

    # In most cases, the default data format NCHW instead of NHWC should be
    # used for a significant performance boost on GPU/TPU. NHWC should be used
    # only if the network needs to be run on CPU since the pooling operations
    # are only supported on NHWC.
    if FLAGS.data_format == 'channels_first':
        if FLAGS.transpose_input:
            # channels_first is only used on GPU, which takes untransposed input.
            raise ValueError(
                'transpose_input must be False when data_format is channels_first')
        features = tf.transpose(features, [0, 3, 1, 2])

    if FLAGS.transpose_input and mode != tf.estimator.ModeKeys.PREDICT:
        features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

    # Normalize the image to zero mean and unit variance.
    features -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=features.dtype)
    features /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=features.dtype)

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    has_moving_average_decay = (FLAGS.moving_average_decay > 0)
    # This is essential, if using a keras-derived model.
    K.set_learning_phase(is_training)
    tf.logging.info('Using open-source implementation for MnasNet definition.')

    # Override params when necessary
    override_params = utils.get_override_params_dict(FLAGS)

    logits, _ = models.build_model(features,
                                   model_name=FLAGS.model_name,
                                   training=is_training,
                                   override_params=override_params)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'classify': tf.estimator.export.PredictOutput(predictions)
            })

    # If necessary, in the model_fn, use params['batch_size'] instead of the batch
    # size flags (--train_batch_size or --eval_batch_size).
    batch_size = params['batch_size']  # pylint: disable=unused-variable

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    one_hot_labels = tf.one_hot(labels, FLAGS.num_label_classes)
    cross_entropy = tf.losses.softmax_cross_entropy(
        logits=logits,
        onehot_labels=one_hot_labels,
        label_smoothing=FLAGS.label_smoothing)

    # Add weight decay to the loss for non-batch-normalization variables.
    loss = cross_entropy + FLAGS.weight_decay * tf.add_n([
        tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])

    global_step = tf.train.get_global_step()
    if has_moving_average_decay:
        ema = tf.train.ExponentialMovingAverage(
            decay=FLAGS.moving_average_decay, num_updates=global_step)
        ema_vars = tf.trainable_variables() + tf.get_collection('moving_vars')
        for v in tf.global_variables():
            # We maintain mva for batch norm moving mean and variance as well.
            if 'moving_mean' in v.name or 'moving_variance' in v.name:
                ema_vars.append(v)
        ema_vars = list(set(ema_vars))

    host_call = None
    restore_vars_dict = None
    if is_training:
        # Compute the current epoch and associated learning rate from global_step.
        current_epoch = (tf.cast(global_step, tf.float32) /
                         params['steps_per_epoch'])

        scaled_lr = FLAGS.base_learning_rate * (FLAGS.train_batch_size / 256.0)
        learning_rate = utils.build_learning_rate(scaled_lr, global_step,
                                                  params['steps_per_epoch'])
        optimizer = utils.build_optimizer(learning_rate)
        if FLAGS.use_tpu:
            # When using TPU, wrap the optimizer with CrossShardOptimizer which
            # handles synchronization details between different TPU cores. To the
            # user, this should look like regular synchronous training.
            optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

        # Batch normalization requires UPDATE_OPS to be added as a dependency to
        # the train operation.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step)

        if has_moving_average_decay:
            with tf.control_dependencies([train_op]):
                train_op = ema.apply(ema_vars)

        if not FLAGS.skip_host_call:
            # To log the loss, current learning rate, and epoch for Tensorboard, the
            # summary op needs to be run on the host CPU via host_call. host_call
            # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
            # dimension. These Tensors are implicitly concatenated to
            # [params['batch_size']].
            gs_t = tf.reshape(global_step, [1])
            loss_t = tf.reshape(loss, [1])
            lr_t = tf.reshape(learning_rate, [1])
            ce_t = tf.reshape(current_epoch, [1])

            host_call = (train_host_call_fn, [gs_t, loss_t, lr_t, ce_t])

    else:
        train_op = None
        if has_moving_average_decay:
            # Load moving average variables for eval.
            restore_vars_dict = ema.variables_to_restore(ema_vars)

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics = (eval_metric_fn, [labels, logits])

    num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
    tf.logging.info('number of trainable parameters: {}'.format(num_params))

    def _scaffold_fn():
        saver = tf.train.Saver(restore_vars_dict)
        return tf.train.Scaffold(saver=saver)

    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        host_call=host_call,
        eval_metrics=eval_metrics,
        scaffold_fn=_scaffold_fn if has_moving_average_decay else None)
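
Example #4 references `train_host_call_fn` and `eval_metric_fn` defined elsewhere in its module; they presumably mirror the `host_call_fn` and `metric_fn` bodies shown in Examples #2 and #3. A minimal sketch of `eval_metric_fn` under that assumption:

import tensorflow.compat.v1 as tf

def eval_metric_fn(labels, logits):
    """Computes top-1 / top-5 accuracy on the CPU host."""
    predictions = tf.argmax(logits, axis=1)
    in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
    return {
        'top_1_accuracy': tf.metrics.accuracy(labels, predictions),
        'top_5_accuracy': tf.metrics.mean(in_top_5),
    }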
Example #5
File: train.py Project: yichenj/tpu
def model_fn(features, labels, mode, params):
    """The model_fn to be used with TPUEstimator.

  Args:
    features: `Tensor` of batched images.
    labels: `Tensor` of labels for the data samples
    mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`
    params: `dict` of parameters passed to the model from the TPUEstimator,
        `params['batch_size']` is always provided and should be used as the
        effective batch size.

  Returns:
    A `TPUEstimatorSpec` for the model
  """
    if isinstance(features, dict):
        features = features['feature']

    stats_shape = [1, 1, 3]

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    has_moving_average_decay = (FLAGS.moving_average_decay > 0)
    # This is essential, if using a keras-derived model.
    tf.keras.backend.set_learning_phase(is_training)
    tf.logging.info('Using open-source implementation.')
    override_params = {}
    if FLAGS.batch_norm_momentum is not None:
        override_params['batch_norm_momentum'] = FLAGS.batch_norm_momentum
    if FLAGS.batch_norm_epsilon is not None:
        override_params['batch_norm_epsilon'] = FLAGS.batch_norm_epsilon
    if FLAGS.dropout_rate is not None:
        override_params['dropout_rate'] = FLAGS.dropout_rate
    if FLAGS.drop_connect_rate is not None:
        override_params['drop_connect_rate'] = FLAGS.drop_connect_rate
    if FLAGS.num_label_classes:
        override_params['num_classes'] = FLAGS.num_label_classes
    if FLAGS.depth_coefficient:
        override_params['depth_coefficient'] = FLAGS.depth_coefficient
    if FLAGS.width_coefficient:
        override_params['width_coefficient'] = FLAGS.width_coefficient

    def normalize_features(features, mean_rgb, stddev_rgb):
        """Normalize the image given the means and stddevs."""
        features -= tf.constant(mean_rgb,
                                shape=stats_shape,
                                dtype=features.dtype)
        features /= tf.constant(stddev_rgb,
                                shape=stats_shape,
                                dtype=features.dtype)
        return features

    def build_model():
        """Build model using the model_name given through the command line."""
        model_builder = None
        if FLAGS.model_name.startswith('efficientnet'):
            model_builder = efficientnet_builder
        else:
            raise ValueError('Model name must start with "efficientnet".')

        normalized_features = normalize_features(features,
                                                 model_builder.MEAN_RGB,
                                                 model_builder.STDDEV_RGB)
        logits, _ = model_builder.build_model(normalized_features,
                                              model_name=FLAGS.model_name,
                                              training=is_training,
                                              override_params=override_params,
                                              model_dir=FLAGS.model_dir)
        return logits

    logits = build_model()

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'classify': tf.estimator.export.PredictOutput(predictions)
            })

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    one_hot_labels = tf.one_hot(labels, FLAGS.num_label_classes)
    cross_entropy = tf.losses.softmax_cross_entropy(
        logits=logits,
        onehot_labels=one_hot_labels,
        label_smoothing=FLAGS.label_smoothing)

    # Add weight decay to the loss for non-batch-normalization variables.
    loss = cross_entropy + FLAGS.weight_decay * tf.add_n([
        tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])

    global_step = tf.train.get_global_step()
    if has_moving_average_decay:
        ema = tf.train.ExponentialMovingAverage(
            decay=FLAGS.moving_average_decay, num_updates=global_step)
        ema_vars = utils.get_ema_vars()

    train_op = None
    restore_vars_dict = None
    training_hooks = []
    if is_training:
        # Compute the current epoch and associated learning rate from global_step.
        current_epoch = (tf.cast(global_step, tf.float32) /
                         params['steps_per_epoch'])

        scaled_lr = FLAGS.base_learning_rate * (FLAGS.train_batch_size / 256.0)
        learning_rate = utils.build_learning_rate(scaled_lr, global_step,
                                                  params['steps_per_epoch'])
        optimizer = utils.build_optimizer(learning_rate, optimizer_name='adam')

        # Batch normalization requires UPDATE_OPS to be added as a dependency to
        # the train operation.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step)

        if has_moving_average_decay:
            with tf.control_dependencies([train_op]):
                train_op = ema.apply(ema_vars)

        predictions = tf.argmax(logits, axis=1)
        top1_accuracy = tf.metrics.accuracy(labels, predictions)
        logging_hook = tf.train.LoggingTensorHook(
            {
                "loss": loss,
                "accuracy": top1_accuracy[1],
                "step": global_step
            },
            every_n_iter=1)
        training_hooks.append(logging_hook)

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:
        predictions = tf.argmax(logits, axis=1)
        top1_accuracy = tf.metrics.accuracy(labels, predictions)
        eval_metrics = {'val_accuracy': top1_accuracy}

    num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
    tf.logging.info('number of trainable parameters: {}'.format(num_params))

    scaffold = None
    if has_moving_average_decay and not is_training:
        # Only apply scaffold for eval jobs: restore the moving-average (EMA)
        # shadow values in place of the raw variable values.
        restore_vars_dict = ema.variables_to_restore(ema_vars)
        saver = tf.train.Saver(restore_vars_dict)
        scaffold = tf.train.Scaffold(saver=saver)

    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=loss,
                                      train_op=train_op,
                                      training_hooks=training_hooks,
                                      eval_metric_ops=eval_metrics,
                                      scaffold=scaffold)
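
Most of these examples build their optimizer through `utils.build_optimizer`; Example #5 requests `optimizer_name='adam'` explicitly. A hedged sketch of the typical dispatch (the real helper in the EfficientNet/TPU repos also threads flags such as the RMSProp decay and epsilon, and Example #6's variant takes a `context` argument):

import tensorflow.compat.v1 as tf

def build_optimizer(learning_rate, optimizer_name='rmsprop',
                    decay=0.9, epsilon=0.001, momentum=0.9):
    """Returns a TF1 optimizer selected by name (sketch; the actual helper may differ)."""
    if optimizer_name == 'sgd':
        return tf.train.GradientDescentOptimizer(learning_rate)
    if optimizer_name == 'momentum':
        return tf.train.MomentumOptimizer(learning_rate, momentum=momentum)
    if optimizer_name == 'rmsprop':
        return tf.train.RMSPropOptimizer(learning_rate, decay, momentum, epsilon)
    if optimizer_name == 'adam':
        return tf.train.AdamOptimizer(learning_rate)
    raise ValueError('Unknown optimizer: %s' % optimizer_name)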
Example #6
0
    def model_fn(features, labels, mode, params=None):
        """The model_fn to be used with TPUEstimator.

        Args:
            features: `Tensor` of batched images.
            labels: `Tensor` of one hot labels for the data samples
            mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`
            params: `dict` of parameters passed to the model from the TPUEstimator.

        Returns:
            A `TPUEstimatorSpec` for the model
        """
        if isinstance(features, dict):
            features = features["feature"]

        # In most cases, the default data format NCHW instead of NHWC should be
        # used for a significant performance boost on GPU. NHWC should be used
        # only if the network needs to be run on CPU since the pooling operations
        # are only supported on NHWC. TPU uses XLA compiler to figure out best layout.
        if context.get_hparam("data_format") == "channels_first":
            assert not context.get_hparam("transpose_input")  # channels_first only for GPU
            features = tf.transpose(features, [0, 3, 1, 2])
            stats_shape = [3, 1, 1]
        else:
            stats_shape = [1, 1, 3]

        #if context.get_hparam("transpose_input") and mode != tf.estimator.ModeKeys.PREDICT:
        #    features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

        is_training = mode == tf.estimator.ModeKeys.TRAIN
        has_moving_average_decay = context.get_hparam("moving_average_decay") > 0
        # This is essential, if using a keras-derived model.
        tf.keras.backend.set_learning_phase(is_training)
        logging.info("Using open-source implementation.")
        override_params = {}
        #if context.get_hparam("batch_norm_momentum") is not None:
        #    override_params["batch_norm_momentum"] = context.get_hparam("batch_norm_momentum")
        #if context.get_hparam("batch_norm_epsilon") is not None:
        #    override_params["batch_norm_epsilon"] = context.get_hparam("batch_norm_epsilon")
        # if context.get_hparam("dropout_rate") is not None:
        #     override_params["dropout_rate"] = context.get_hparam("dropout_rate")
        # if context.get_hparam("survival_prob") is not None:
        #     override_params["survival_prob"] = context.get_hparam("survival_prob")
        # if context.get_hparam("data_format"):
        #     override_params["data_format"] = context.get_hparam("data_format")
        # if context.get_hparam("num_label_classes"):
        #     override_params["num_classes"] = context.get_hparam("num_label_classes")
        # if context.get_hparam("depth_coefficient"):
        #     override_params["depth_coefficient"] = context.get_hparam("depth_coefficient")
        # if context.get_hparam("width_coefficient"):
        #     override_params["width_coefficient"] = context.get_hparam("width_coefficient")

        def normalize_features(features, mean_rgb, stddev_rgb):
            """Normalize the image given the means and stddevs."""
            features -= tf.constant(mean_rgb, shape=stats_shape, dtype=features.dtype)
            features /= tf.constant(stddev_rgb, shape=stats_shape, dtype=features.dtype)
            return features

        def build_model():
            """Build model using the model_name given through the command line."""
            model_builder = model_builder_factory.get_model_builder(
                context.get_hparam("model_name"),
            )
            normalized_features = normalize_features(
                features, model_builder.MEAN_RGB, model_builder.STDDEV_RGB
            )
            logits, _ = model_builder.build_model(
                normalized_features,
                model_name=context.get_hparam("model_name"),
                training=is_training,
                override_params=override_params,
                #model_dir=context.get_hparam("model_dir"),
            )
            return logits

        logits = build_model()

        # Calculate loss, which includes softmax cross entropy and L2 regularization.
        cross_entropy = tf.losses.softmax_cross_entropy(
            logits=logits, onehot_labels=labels, label_smoothing=context.get_hparam("label_smoothing")
        )

        # Add weight decay to the loss for non-batch-normalization variables.
        loss = cross_entropy + context.get_hparam("weight_decay") * tf.add_n(
            [
                tf.nn.l2_loss(v)
                for v in tf.trainable_variables()
                if "batch_normalization" not in v.name
            ]
        )

        global_step = tf.train.get_global_step()
        if has_moving_average_decay:
            ema = tf.train.ExponentialMovingAverage(
                decay=context.get_hparam("moving_average_decay"), num_updates=global_step
            )
            ema_vars = utils.get_ema_vars()

        restore_vars_dict = None
        train_op = None
        if is_training:
            # Compute the current epoch and associated learning rate from global_step.
            current_epoch = tf.cast(global_step, tf.float32) / context.get_hparam("steps_per_epoch")

            scaled_lr = context.get_hparam("base_learning_rate") * (context.get_hparam("train_batch_size") / 256.0)
            logging.info("base_learning_rate = %f", context.get_hparam("base_learning_rate"))
            learning_rate = utils.build_learning_rate(
                scaled_lr, global_step, context.get_hparam("steps_per_epoch"),
            )
            optimizer = utils.build_optimizer(context, learning_rate)

            # Batch normalization requires UPDATE_OPS to be added as a dependency to
            # the train operation.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer.minimize(loss, global_step)

            if has_moving_average_decay:
                with tf.control_dependencies([train_op]):
                    train_op = ema.apply(ema_vars)

        if has_moving_average_decay:
            # Load moving average variables for eval.
            restore_vars_dict = ema.variables_to_restore(ema_vars)

        eval_metrics = None
        if mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(labels, logits):
                """Evaluation metric function. Evaluates accuracy.

                This function is executed on the CPU and should not directly reference
                any Tensors in the rest of the `model_fn`. To pass Tensors from the model
                to the `metric_fn`, provide as part of the `eval_metrics`. See
                https://www.tensorflow.org/api_docs/python/tf/estimator/tpu/TPUEstimatorSpec
                for more information.

                Arguments should match the list of `Tensor` objects passed as the second
                element in the tuple passed to `eval_metrics`.

                Args:
                    labels: `Tensor` with shape `[batch, num_classes]`.
                    logits: `Tensor` with shape `[batch, num_classes]`.

                Returns:
                    A dict of the metrics to return from evaluation.
                """
                labels = tf.argmax(labels, axis=1)
                predictions = tf.argmax(logits, axis=1)
                top_1_accuracy = tf.metrics.accuracy(labels, predictions)
                in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
                top_5_accuracy = tf.metrics.mean(in_top_5)

                return {
                    "top_1_accuracy": top_1_accuracy,
                    "top_5_accuracy": top_5_accuracy,
                }

            eval_metrics = metric_fn(labels, logits)

        num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
        logging.info("number of trainable parameters: %d", num_params)


        return tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, train_op=train_op, eval_metric_ops=eval_metrics,
        )
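Example #6 above and the FLAGS-based snippet below share two conventions worth calling out: the learning rate is scaled linearly with batch size (for instance, an illustrative base rate of 0.1 at batch size 512 becomes 0.1 * 512 / 256 = 0.2), and weight decay is added as an explicit L2 penalty over every trainable variable except batch-normalization parameters. A minimal sketch of that loss term, assuming the TF 1.x API; the helper name and the default weight_decay value are illustrative:

import tensorflow.compat.v1 as tf

def add_weight_decay(cross_entropy, weight_decay=1e-5):
    # weight_decay value is illustrative, not taken from either example.
    l2 = tf.add_n([
        tf.nn.l2_loss(v)
        for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name  # skip BN gammas/betas
    ])
    return cross_entropy + weight_decay * l2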
    def build_model():
        logits, _ = efficientnet_builder.build_model(
            features,
            model_name=FLAGS.model_name,
            training=is_training,
            override_params=override_params,
            model_dir=FLAGS.model_dir)
        return logits

    if params['use_bfloat16']:
        with tf.contrib.tpu.bfloat16_scope():
            logits = tf.cast(build_model(), tf.float32)
    else:
        logits = build_model()

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'classify': tf.estimator.export.PredictOutput(predictions)
            })

    # If necessary, in the model_fn, use params['batch_size'] instead of the batch
    # size flags (--train_batch_size or --eval_batch_size).
    batch_size = params['batch_size']  # pylint: disable=unused-variable

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    one_hot_labels = tf.one_hot(labels, FLAGS.num_label_classes)
    cross_entropy = tf.losses.softmax_cross_entropy(
        logits=logits,
        onehot_labels=one_hot_labels,
        label_smoothing=FLAGS.label_smoothing)

    # Add weight decay to the loss for non-batch-normalization variables.
    loss = cross_entropy + FLAGS.weight_decay * tf.add_n([
        tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])

    global_step = tf.train.get_global_step()
    if has_moving_average_decay:
        ema = tf.train.ExponentialMovingAverage(
            decay=FLAGS.moving_average_decay, num_updates=global_step)
        ema_vars = tf.trainable_variables() + tf.get_collection('moving_vars')
        for v in tf.global_variables():
            # We maintain mva for batch norm moving mean and variance as well.
            if 'moving_mean' in v.name or 'moving_variance' in v.name:
                ema_vars.append(v)
        ema_vars = list(set(ema_vars))

    host_call = None
    restore_vars_dict = None
    if is_training:
        # Compute the current epoch and associated learning rate from global_step.
        current_epoch = (tf.cast(global_step, tf.float32) /
                         params['steps_per_epoch'])

        scaled_lr = FLAGS.base_learning_rate * (FLAGS.train_batch_size / 256.0)
        learning_rate = utils.build_learning_rate(scaled_lr, global_step,
                                                  params['steps_per_epoch'])
        optimizer = utils.build_optimizer(learning_rate)
        if FLAGS.use_tpu:
            # When using TPU, wrap the optimizer with CrossShardOptimizer which
            # handles synchronization details between different TPU cores. To the
            # user, this should look like regular synchronous training.
            optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

        # Batch normalization requires UPDATE_OPS to be added as a dependency to
        # the train operation.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step)

        if has_moving_average_decay:
            with tf.control_dependencies([train_op]):
                train_op = ema.apply(ema_vars)

        if not FLAGS.skip_host_call:

            def host_call_fn(gs, lr, ce):
                """Training host call. Creates scalar summaries for training metrics.
                This function is executed on the CPU and should not directly reference
                any Tensors in the rest of the `model_fn`. To pass Tensors from the
                model to the `metric_fn`, provide as part of the `host_call`. See
                https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
                for more information.
                Arguments should match the list of `Tensor` objects passed as the second
                element in the tuple passed to `host_call`.
                Args:
                    gs: `Tensor` with shape `[batch]` for the global_step.
                    lr: `Tensor` with shape `[batch]` for the learning_rate.
                    ce: `Tensor` with shape `[batch]` for the current_epoch.
                Returns:
                    List of summary ops to run on the CPU host.
                """
                gs = gs[0]
                # Host call fns are executed FLAGS.iterations_per_loop times after one
                # TPU loop is finished, setting max_queue value to the same as number of
                # iterations will make the summary writer only flush the data to storage
                # once per loop.
                with tf.contrib.summary.create_file_writer(
                        FLAGS.model_dir,
                        max_queue=FLAGS.iterations_per_loop).as_default():
                    with tf.contrib.summary.always_record_summaries():
                        tf.contrib.summary.scalar('learning_rate',
                                                  lr[0],
                                                  step=gs)
                        tf.contrib.summary.scalar('current_epoch',
                                                  ce[0],
                                                  step=gs)

                        return tf.contrib.summary.all_summary_ops()

            # To log the loss, current learning rate, and epoch for TensorBoard, the
            # summary op needs to be run on the host CPU via host_call. host_call
            # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
            # dimension. These Tensors are implicitly concatenated to
            # [params['batch_size']].
            gs_t = tf.reshape(global_step, [1])
            lr_t = tf.reshape(learning_rate, [1])
            ce_t = tf.reshape(current_epoch, [1])

            host_call = (host_call_fn, [gs_t, lr_t, ce_t])

    else:
        train_op = None
        if has_moving_average_decay:
            # Load moving average variables for eval.
            restore_vars_dict = ema.variables_to_restore(ema_vars)

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(labels, logits):
            """Evaluation metric function. Evaluates accuracy.
            This function is executed on the CPU and should not directly reference
            any Tensors in the rest of the `model_fn`. To pass Tensors from the model
            to the `metric_fn`, provide as part of the `eval_metrics`. See
            https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
            for more information.
            Arguments should match the list of `Tensor` objects passed as the second
            element in the tuple passed to `eval_metrics`.
            Args:
                labels: `Tensor` with shape `[batch]`.
                logits: `Tensor` with shape `[batch, num_classes]`.
            Returns:
                A dict of the metrics to return from evaluation.
            """
            predictions = tf.argmax(logits, axis=1)
            top_1_accuracy = tf.metrics.accuracy(labels, predictions)
            in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
            top_5_accuracy = tf.metrics.mean(in_top_5)

            return {
                'top_1_accuracy': top_1_accuracy,
                'top_5_accuracy': top_5_accuracy,
            }

        eval_metrics = (metric_fn, [labels, logits])

    num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
    tf.logging.info('number of trainable parameters: {}'.format(num_params))

    def _scaffold_fn():
        saver = tf.train.Saver(restore_vars_dict)
        return tf.train.Scaffold(saver=saver)

    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        host_call=host_call,
        eval_metrics=eval_metrics,
        scaffold_fn=_scaffold_fn if has_moving_average_decay else None)
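Each of these model_fns is meant to be handed to an Estimator; the TPU variants go through TPUEstimator, which supplies params['batch_size'] and runs host_call and eval_metrics on the CPU host. A minimal wiring sketch, assuming the TF 1.x tf.estimator.tpu API; the model_dir, batch sizes, and input_fn are placeholders rather than values from the examples:

import tensorflow.compat.v1 as tf

run_config = tf.estimator.tpu.RunConfig(
    model_dir='/tmp/effnet_model',  # placeholder directory
    tpu_config=tf.estimator.tpu.TPUConfig(iterations_per_loop=100))

estimator = tf.estimator.tpu.TPUEstimator(
    model_fn=model_fn,   # one of the model_fns above
    config=run_config,
    use_tpu=False,       # set True (and provide a cluster) on a TPU worker
    train_batch_size=64,
    eval_batch_size=64)

# estimator.train(input_fn=train_input_fn, max_steps=1000)  # input_fn not shown here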