Example #1
    def model_fn(features, labels, mode, params=None):
        is_training = mode == tf.estimator.ModeKeys.TRAIN

        # Check training mode
        if FLAGS.train_mode == 'pretrain':
            raise ValueError('Pretraining not possible in multihead config, '
                             'set train_mode to finetune.')
        elif FLAGS.train_mode == 'finetune':
            # Base network forward pass.
            with tf.variable_scope('base_model'):
                if FLAGS.fine_tune_after_block >= 4:
                    # Finetuning just the supervised (linear) head will not
                    # update BN stats.
                    model_train_mode = False
                else:
                    model_train_mode = is_training
                hiddens = model(features, is_training=model_train_mode)

            logits_animal = model_util.supervised_head(
                hiddens,
                animal_num_classes,
                is_training,
                name='head_supervised_animals')
            obj_lib.add_supervised_loss(labels=labels['animal_label'],
                                        logits=logits_animal,
                                        weights=labels['animal_mask'])

            logits_plant = model_util.supervised_head(
                hiddens,
                plant_num_classes,
                is_training,
                name='head_supervised_plants')
            obj_lib.add_supervised_loss(labels=labels['plant_label'],
                                        logits=logits_plant,
                                        weights=labels['plant_mask'])
            model_util.add_weight_decay(adjust_per_optimizer=True)
            loss = tf.losses.get_total_loss()

            collection_prefix = 'trainable_variables_inblock_'
            variables_to_train = []
            for j in range(FLAGS.fine_tune_after_block + 1, 6):
                variables_to_train += tf.get_collection(collection_prefix +
                                                        str(j))
            assert variables_to_train, 'variables_to_train shouldn\'t be empty'
            tf.logging.info(
                '===============Variables to train (begin)===============')
            tf.logging.info(variables_to_train)
            tf.logging.info(
                '================Variables to train (end)================')

            learning_rate = model_util.learning_rate_schedule(
                FLAGS.learning_rate, num_train_examples)

            if is_training:
                if FLAGS.train_summary_steps > 0:
                    summary_writer = tf2.summary.create_file_writer(
                        FLAGS.model_dir)
                    with tf.control_dependencies([summary_writer.init()]):
                        with summary_writer.as_default():
                            should_record = tf.math.equal(
                                tf.math.floormod(tf.train.get_global_step(),
                                                 FLAGS.train_summary_steps), 0)
                            with tf2.summary.record_if(should_record):
                                label_acc_animal = tf.equal(
                                    tf.argmax(labels['animal_label'], 1),
                                    tf.argmax(logits_animal, axis=1))
                                label_acc_plant = tf.equal(
                                    tf.argmax(labels['plant_label'], 1),
                                    tf.argmax(logits_plant, axis=1))

                                # Count an example as correct if either head
                                # predicts its label.
                                label_acc = tf.math.logical_or(
                                    label_acc_animal, label_acc_plant)
                                label_acc = tf.reduce_mean(
                                    tf.cast(label_acc, tf.float32))
                                tf2.summary.scalar(
                                    'train_label_accuracy',
                                    label_acc,
                                    step=tf.train.get_global_step())
                                tf2.summary.scalar(
                                    'learning_rate',
                                    learning_rate,
                                    step=tf.train.get_global_step())
                if FLAGS.optimizer == 'momentum':
                    optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                           FLAGS.momentum,
                                                           use_nesterov=True)
                elif FLAGS.optimizer == 'adam':
                    optimizer = tf.train.AdamOptimizer(learning_rate)
                elif FLAGS.optimizer == 'lars':
                    optimizer = LARSOptimizer(learning_rate,
                                              momentum=FLAGS.momentum,
                                              weight_decay=FLAGS.weight_decay,
                                              exclude_from_weight_decay=[
                                                  'batch_normalization', 'bias'
                                              ])
                else:
                    raise ValueError('Unknown optimizer {}'.format(
                        FLAGS.optimizer))

                if FLAGS.use_tpu:
                    optimizer = tf.tpu.CrossShardOptimizer(optimizer)
                control_deps = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

                if FLAGS.train_summary_steps > 0:
                    control_deps.extend(tf.summary.all_v2_summary_ops())
                with tf.control_dependencies(control_deps):
                    train_op = optimizer.minimize(
                        loss,
                        global_step=tf.train.get_or_create_global_step(),
                        var_list=variables_to_train)

                if FLAGS.checkpoint:

                    def scaffold_fn():
                        """
                        Scaffold function to restore non-logits vars from
                        checkpoint.
                        """
                        tf.train.init_from_checkpoint(
                            FLAGS.checkpoint, {
                                v.op.name: v.op.name
                                for v in tf.global_variables(
                                    FLAGS.variable_schema)
                            })

                        if FLAGS.zero_init_logits_layer:
                            # Init op that initializes output layer parameters
                            # to zeros.
                            output_layer_parameters = [
                                var for var in tf.trainable_variables()
                                if var.name.startswith('head_supervised')
                            ]
                            tf.logging.info(
                                'Initializing output layer parameters %s to 0',
                                [x.op.name for x in output_layer_parameters])
                            with tf.control_dependencies(
                                [tf.global_variables_initializer()]):
                                init_op = tf.group([
                                    tf.assign(x, tf.zeros_like(x))
                                    for x in output_layer_parameters
                                ])
                            return tf.train.Scaffold(init_op=init_op)
                        else:
                            return tf.train.Scaffold()
                else:
                    scaffold_fn = None

                return tf.estimator.tpu.TPUEstimatorSpec(
                    mode=mode,
                    train_op=train_op,
                    loss=loss,
                    scaffold_fn=scaffold_fn)

            elif mode == tf.estimator.ModeKeys.PREDICT:
                _, animal_top_5 = tf.nn.top_k(logits_animal, k=5)
                _, plant_top_5 = tf.nn.top_k(logits_plant, k=5)

                predictions = {
                    'label': tf.argmax(labels['label'], 1),
                    'animal_label': tf.argmax(labels['animal_label'], 1),
                    'plant_label': tf.argmax(labels['plant_label'], 1),
                    'animal_top_5': animal_top_5,
                    'plant_top_5': plant_top_5,
                }

                return tf.estimator.tpu.TPUEstimatorSpec(
                    mode=mode, predictions=predictions)
            elif mode == tf.estimator.ModeKeys.EVAL:

                def metric_fn(logits_animal, logits_plant, labels_animal,
                              labels_plant, mask_animal, mask_plant, **kws):
                    """Inner metric function."""
                    metrics = {}
                    metrics[
                        'label_animal_top_1_accuracy'] = tf.metrics.accuracy(
                            tf.argmax(labels_animal, 1),
                            tf.argmax(logits_animal, axis=1),
                            weights=mask_animal)
                    metrics[
                        'label_animal_top_5_accuracy'] = tf.metrics.recall_at_k(
                            tf.argmax(labels_animal, 1),
                            logits_animal,
                            k=5,
                            weights=mask_animal)
                    metrics[
                        'label_plant_top_1_accuracy'] = tf.metrics.accuracy(
                            tf.argmax(labels_plant, 1),
                            tf.argmax(logits_plant, axis=1),
                            weights=mask_plant)
                    metrics[
                        'label_plant_top_5_accuracy'] = tf.metrics.recall_at_k(
                            tf.argmax(labels_plant, 1),
                            logits_plant,
                            k=5,
                            weights=mask_plant)
                    return metrics

                metrics = {
                    'logits_animal': logits_animal,
                    'logits_plant': logits_plant,
                    'labels_animal': labels['animal_label'],
                    'labels_plant': labels['plant_label'],
                    'mask_animal': labels['animal_mask'],
                    'mask_plant': labels['plant_mask'],
                }

                return tf.estimator.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    eval_metrics=(metric_fn, metrics),
                    scaffold_fn=None)
            else:
                raise ValueError('Invalid mode: {}'.format(mode))
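A model_fn like the one above is normally handed to a TPUEstimator and driven by its train/eval/predict calls. A minimal wiring sketch; the run-config values, build_input_fn, and total_train_steps are hypothetical placeholders, not part of the example:

    # Hypothetical wiring; build_input_fn and total_train_steps are placeholders.
    run_config = tf.estimator.tpu.RunConfig(
        model_dir=FLAGS.model_dir,
        tpu_config=tf.estimator.tpu.TPUConfig(iterations_per_loop=100),
        cluster=None)  # pass a TPUClusterResolver here when FLAGS.use_tpu is set
    estimator = tf.estimator.tpu.TPUEstimator(
        model_fn=model_fn,
        config=run_config,
        use_tpu=FLAGS.use_tpu,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size)
    estimator.train(input_fn=build_input_fn(is_training=True),
                    max_steps=total_train_steps)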
Example #2
 def single_step(features, labels):
   with tf.GradientTape() as tape:
     # Log summaries on the last step of the training loop to match
     # logging frequency of other scalar summaries.
     #
     # Notes:
     # 1. Summary ops on TPUs get outside compiled so they do not affect
     #    performance.
     # 2. Summaries are recorded only on replica 0. So effectively this
     #    summary would be written once per host when should_record == True.
     # 3. optimizer.iterations is incremented in the call to apply_gradients.
     #    So we use  `iterations + 1` here so that the step number matches
     #    those of scalar summaries.
     # 4. We intentionally run the summary op before the actual model
     #    training so that it can run in parallel.
     should_record = tf.equal((optimizer.iterations + 1) % steps_per_loop, 0)
     with tf.summary.record_if(should_record):
       # Only log augmented images for the first tower.
       tf.summary.image(
           'image', features[:, :, :, :3], step=optimizer.iterations + 1)
     projection_head_outputs, supervised_head_outputs = model(
         features, training=True)
     loss = None
     if projection_head_outputs is not None:
       outputs = projection_head_outputs
       con_loss, logits_con, labels_con = obj_lib.add_contrastive_loss(
           outputs,
           hidden_norm=FLAGS.hidden_norm,
           temperature=FLAGS.temperature,
           strategy=strategy)
       if loss is None:
         loss = con_loss
       else:
         loss += con_loss
       metrics.update_pretrain_metrics_train(contrast_loss_metric,
                                             contrast_acc_metric,
                                             contrast_entropy_metric,
                                             con_loss, logits_con,
                                             labels_con)
     if supervised_head_outputs is not None:
       outputs = supervised_head_outputs
       l = labels['labels']
       if FLAGS.train_mode == 'pretrain' and FLAGS.lineareval_while_pretraining:
         l = tf.concat([l, l], 0)
       sup_loss = obj_lib.add_supervised_loss(labels=l, logits=outputs)
       if loss is None:
         loss = sup_loss
       else:
         loss += sup_loss
       metrics.update_finetune_metrics_train(supervised_loss_metric,
                                             supervised_acc_metric, sup_loss,
                                             l, outputs)
     weight_decay = model_lib.add_weight_decay(
         model, adjust_per_optimizer=True)
     weight_decay_metric.update_state(weight_decay)
     loss += weight_decay
     total_loss_metric.update_state(loss)
     # The default behavior of `apply_gradients` is to sum gradients from all
     # replicas so we divide the loss by the number of replicas so that the
     # mean gradient is applied.
     loss = loss / strategy.num_replicas_in_sync
     logging.info('Trainable variables:')
     for var in model.trainable_variables:
       logging.info(var.name)
     grads = tape.gradient(loss, model.trainable_variables)
     optimizer.apply_gradients(zip(grads, model.trainable_variables))
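single_step above is the per-replica training body. A minimal sketch of the outer distributed loop that would invoke it; steps_per_loop, total_steps, ds, and checkpointing are assumed to come from the surrounding code and are hypothetical here:

 @tf.function
 def train_multiple_steps(iterator):
   # tf.range keeps this loop as a tf.while_loop instead of unrolling it.
   for _ in tf.range(steps_per_loop):
     images, labels = next(iterator)
     features, labels = images, {'labels': labels}
     strategy.run(single_step, args=(features, labels))

 iterator = iter(strategy.experimental_distribute_dataset(ds))
 while optimizer.iterations.numpy() < total_steps:
   train_multiple_steps(iterator)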
Example #3
    def model_fn(features, labels, mode, params=None):
        """Build model and optimizer."""
        is_training = mode == tf.estimator.ModeKeys.TRAIN

        # Check training mode.
        if FLAGS.train_mode == 'pretrain':
            num_transforms = 2
            if FLAGS.fine_tune_after_block > -1:
                raise ValueError(
                    'Does not support layer freezing during pretraining, '
                    'should set fine_tune_after_block <= -1 for safety.')
        elif FLAGS.train_mode == 'finetune':
            num_transforms = 1
        else:
            raise ValueError('Unknown train_mode {}'.format(FLAGS.train_mode))

        # Split channels, and optionally apply extra batched augmentation.
        features_list = tf.split(features,
                                 num_or_size_splits=num_transforms,
                                 axis=-1)
        if FLAGS.use_blur and is_training and FLAGS.train_mode == 'pretrain':
            features_list = data_util.batch_random_blur(
                features_list, FLAGS.image_size, FLAGS.image_size)
        features = tf.concat(features_list,
                             0)  # (num_transforms * bsz, h, w, c)

        # Base network forward pass.
        with tf.variable_scope('base_model'):
            if FLAGS.train_mode == 'finetune' and FLAGS.fine_tune_after_block >= 4:
                # Finetuning just the supervised (linear) head will not update
                # BN stats.
                model_train_mode = False
            else:
                # Pretraining or finetuning anything else will update BN stats.
                model_train_mode = is_training
            hiddens = model(features, is_training=model_train_mode)

        # Add head and loss.
        if FLAGS.train_mode == 'pretrain':
            tpu_context = params['context'] if 'context' in params else None
            hiddens_proj = model_util.projection_head(hiddens, is_training)
            contrast_loss, logits_con, labels_con = obj_lib.add_contrastive_loss(
                hiddens_proj,
                hidden_norm=FLAGS.hidden_norm,
                temperature=FLAGS.temperature,
                tpu_context=tpu_context if is_training else None)
            logits_sup = tf.zeros([params['batch_size'], num_classes])
        else:
            contrast_loss = tf.zeros([])
            logits_con = tf.zeros([params['batch_size'], 10])
            labels_con = tf.zeros([params['batch_size'], 10])
            logits_sup = model_util.supervised_head(hiddens, num_classes,
                                                    is_training)
            obj_lib.add_supervised_loss(labels=labels['labels'],
                                        logits=logits_sup,
                                        weights=labels['mask'])

        # Add weight decay to loss, for non-LARS optimizers.
        model_util.add_weight_decay(adjust_per_optimizer=True)
        loss = tf.losses.get_total_loss()

        if FLAGS.train_mode == 'pretrain':
            variables_to_train = tf.trainable_variables()
        else:
            collection_prefix = 'trainable_variables_inblock_'
            variables_to_train = []
            for j in range(FLAGS.fine_tune_after_block + 1, 6):
                variables_to_train += tf.get_collection(collection_prefix +
                                                        str(j))
            assert variables_to_train, 'variables_to_train shouldn\'t be empty!'

        tf.logging.info(
            '===============Variables to train (begin)===============')
        tf.logging.info(variables_to_train)
        tf.logging.info(
            '================Variables to train (end)================')

        learning_rate = model_util.learning_rate_schedule(
            FLAGS.learning_rate, num_train_examples)

        if is_training:
            if FLAGS.train_summary_steps > 0:
                # Compute stats for the summary.
                prob_con = tf.nn.softmax(logits_con)
                entropy_con = -tf.reduce_mean(
                    tf.reduce_sum(prob_con * tf.math.log(prob_con + 1e-8), -1))

                summary_writer = tf2.summary.create_file_writer(
                    FLAGS.model_dir)
                # TODO(iamtingchen): remove this control_dependencies in the future.
                with tf.control_dependencies([summary_writer.init()]):
                    with summary_writer.as_default():
                        should_record = tf.math.equal(
                            tf.math.floormod(tf.train.get_global_step(),
                                             FLAGS.train_summary_steps), 0)
                        with tf2.summary.record_if(should_record):
                            contrast_acc = tf.equal(
                                tf.argmax(labels_con, 1),
                                tf.argmax(logits_con, axis=1))
                            contrast_acc = tf.reduce_mean(
                                tf.cast(contrast_acc, tf.float32))
                            label_acc = tf.equal(
                                tf.argmax(labels['labels'], 1),
                                tf.argmax(logits_sup, axis=1))
                            label_acc = tf.reduce_mean(
                                tf.cast(label_acc, tf.float32))
                            tf2.summary.scalar('train_contrast_loss',
                                               contrast_loss,
                                               step=tf.train.get_global_step())
                            tf2.summary.scalar('train_contrast_acc',
                                               contrast_acc,
                                               step=tf.train.get_global_step())
                            tf2.summary.scalar('train_label_accuracy',
                                               label_acc,
                                               step=tf.train.get_global_step())
                            tf2.summary.scalar('contrast_entropy',
                                               entropy_con,
                                               step=tf.train.get_global_step())
                            tf2.summary.scalar('learning_rate',
                                               learning_rate,
                                               step=tf.train.get_global_step())
                            tf2.summary.scalar('input_mean',
                                               tf.reduce_mean(features),
                                               step=tf.train.get_global_step())
                            tf2.summary.scalar('input_max',
                                               tf.reduce_max(features),
                                               step=tf.train.get_global_step())
                            tf2.summary.scalar('input_min',
                                               tf.reduce_min(features),
                                               step=tf.train.get_global_step())
                            tf2.summary.scalar('num_labels',
                                               tf.reduce_mean(
                                                   tf.reduce_sum(
                                                       labels['labels'], -1)),
                                               step=tf.train.get_global_step())

            if FLAGS.optimizer == 'momentum':
                optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                       FLAGS.momentum,
                                                       use_nesterov=True)
            elif FLAGS.optimizer == 'adam':
                optimizer = tf.train.AdamOptimizer(learning_rate)
            elif FLAGS.optimizer == 'lars':
                optimizer = LARSOptimizer(
                    learning_rate,
                    momentum=FLAGS.momentum,
                    weight_decay=FLAGS.weight_decay,
                    exclude_from_weight_decay=['batch_normalization', 'bias'])
            else:
                raise ValueError('Unknown optimizer {}'.format(
                    FLAGS.optimizer))

            if FLAGS.use_tpu:
                optimizer = tf.tpu.CrossShardOptimizer(optimizer)

            control_deps = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            if FLAGS.train_summary_steps > 0:
                control_deps.extend(tf.summary.all_v2_summary_ops())
            with tf.control_dependencies(control_deps):
                train_op = optimizer.minimize(
                    loss,
                    global_step=tf.train.get_or_create_global_step(),
                    var_list=variables_to_train)

            if FLAGS.checkpoint:

                def scaffold_fn():
                    """Scaffold function to restore non-logits vars from checkpoint."""
                    for v in tf.global_variables(FLAGS.variable_schema):
                        print(v.op.name)
                    tf.train.init_from_checkpoint(
                        FLAGS.checkpoint, {
                            v.op.name: v.op.name
                            for v in tf.global_variables(FLAGS.variable_schema)
                        })

                    if FLAGS.zero_init_logits_layer:
                        # Init op that initializes output layer parameters to zeros.
                        output_layer_parameters = [
                            var for var in tf.trainable_variables()
                            if var.name.startswith('head_supervised')
                        ]
                        tf.logging.info(
                            'Initializing output layer parameters %s to zero',
                            [x.op.name for x in output_layer_parameters])
                        with tf.control_dependencies(
                            [tf.global_variables_initializer()]):
                            init_op = tf.group([
                                tf.assign(x, tf.zeros_like(x))
                                for x in output_layer_parameters
                            ])
                        return tf.train.Scaffold(init_op=init_op)
                    else:
                        return tf.train.Scaffold()
            else:
                scaffold_fn = None

            return tf.estimator.tpu.TPUEstimatorSpec(mode=mode,
                                                     train_op=train_op,
                                                     loss=loss,
                                                     scaffold_fn=scaffold_fn)

        elif mode == tf.estimator.ModeKeys.PREDICT:
            _, top_5 = tf.nn.top_k(logits_sup, k=5)
            predictions = {
                'label': tf.argmax(labels['labels'], 1),
                'top_5': top_5,
            }
            return tf.estimator.tpu.TPUEstimatorSpec(mode=mode,
                                                     predictions=predictions)
        else:

            def metric_fn(logits_sup, labels_sup, logits_con, labels_con, mask,
                          **kws):
                """Inner metric function."""
                metrics = {
                    k: tf.metrics.mean(v, weights=mask)
                    for k, v in kws.items()
                }
                metrics['label_top_1_accuracy'] = tf.metrics.accuracy(
                    tf.argmax(labels_sup, 1),
                    tf.argmax(logits_sup, axis=1),
                    weights=mask)
                metrics['label_top_5_accuracy'] = tf.metrics.recall_at_k(
                    tf.argmax(labels_sup, 1), logits_sup, k=5, weights=mask)
                metrics['contrastive_top_1_accuracy'] = tf.metrics.accuracy(
                    tf.argmax(labels_con, 1),
                    tf.argmax(logits_con, axis=1),
                    weights=mask)
                metrics['contrastive_top_5_accuracy'] = tf.metrics.recall_at_k(
                    tf.argmax(labels_con, 1), logits_con, k=5, weights=mask)

                metrics[
                    'mean_class_accuracy'] = tf.metrics.mean_per_class_accuracy(
                        tf.argmax(labels_sup, 1),
                        tf.argmax(logits_sup, axis=1),
                        num_classes,
                        weights=mask,
                        name='mca')

                running_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                                 scope="mca")
                metrics['mean_class_accuracy_total'] = running_vars[0]
                metrics['mean_class_accuracy_count'] = running_vars[1]

                return metrics

            metrics = {
                'logits_sup': logits_sup,
                'labels_sup': labels['labels'],
                'logits_con': logits_con,
                'labels_con': labels_con,
                'mask': labels['mask'],
                'contrast_loss':
                    tf.fill((params['batch_size'],), contrast_loss),
                'regularization_loss':
                    tf.fill((params['batch_size'],),
                            tf.losses.get_regularization_loss()),
            }

            return tf.estimator.tpu.TPUEstimatorSpec(mode=mode,
                                                     loss=loss,
                                                     eval_metrics=(metric_fn,
                                                                   metrics),
                                                     scaffold_fn=None)
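In the eval branch above, eval_metrics follows the TPUEstimatorSpec contract: the dict of tensors is gathered across shards and passed to metric_fn as keyword arguments, and metric_fn must return a dict mapping metric names to (value, update_op) pairs from tf.metrics. A stripped-down sketch of that contract, with illustrative names only:

    def tiny_metric_fn(logits_sup, labels_sup, **unused_kws):
        return {
            'top_1_accuracy': tf.metrics.accuracy(
                tf.argmax(labels_sup, 1), tf.argmax(logits_sup, axis=1)),
        }

    spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.EVAL,
        loss=loss,
        eval_metrics=(tiny_metric_fn,
                      {'logits_sup': logits_sup,
                       'labels_sup': labels['labels']}))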
Example #4
    def model_fn(features, labels, mode, params=None):
        """Build model and optimizer."""
        is_training = mode == tf.estimator.ModeKeys.TRAIN

        # Check training mode.
        if FLAGS.train_mode == 'pretrain':
            num_transforms = 2
            if FLAGS.fine_tune_after_block > -1:
                raise ValueError(
                    'Does not support layer freezing during pretraining, '
                    'should set fine_tune_after_block <= -1 for safety.')
        elif FLAGS.train_mode == 'finetune':
            num_transforms = 1
        else:
            raise ValueError('Unknown train_mode {}'.format(FLAGS.train_mode))

        # Split channels, and optionally apply extra batched augmentation.
        # Splits into 2 tensors with 3 channels each (during pretraining)
        # instead of one 6-channel tensor.
        features_list = tf.split(
            features, num_or_size_splits=num_transforms, axis=-1)
        if FLAGS.use_blur and is_training and FLAGS.train_mode == 'pretrain':
            features_list = data_util.batch_random_blur(
                features_list, FLAGS.image_size, FLAGS.image_size)
        # crop images now that all other preprocessing has finished
        if is_training and FLAGS.train_mode == 'pretrain':
            features_list = data_util.batch_random_crop(
                features_list, FLAGS.crop_size,
                FLAGS.crop_size)  # cropped -> hiddens1
        # features = tf.concat(features_list, 0)  # (num_transforms * bsz, h, w, c)
        # Concatenating is no longer needed, since the list elements are used
        # separately from here on.
        features = features_list
        if FLAGS.train_mode == 'finetune':
            # num_transforms is 1, so features were never split into a list.
            # Only one network's output is used in eval, so the two entries
            # are never compared.
            features = tf.concat(features_list, 0)
            features = [features, features]

        # Base network forward pass.
        with tf.variable_scope('base_model'):
            # train_mode is either 'pretrain' or 'finetune'; 'finetune' is
            # used when training just the linear head.
            if FLAGS.train_mode == 'finetune' and FLAGS.fine_tune_after_block >= 4:
                # Finetuning just the supervised (linear) head will not update
                # BN stats.
                model_train_mode = False
            else:
                # Pretraining or finetuning anything else will update BN stats.
                model_train_mode = is_training
            # hiddens = model(features, is_training=model_train_mode)  # model_train_mode=True if fine_tune_after_block < 4, bug??
            hiddens_f = model['model_full'](
                features[0],
                is_training=model_train_mode)  # output of full model
            hiddens_c = model['model_cropped'](
                features[1],
                is_training=model_train_mode)  # output of cropped model

        # Add head and loss.
        if FLAGS.train_mode == 'pretrain':
            tpu_context = params['context'] if 'context' in params else None
            # hiddens_proj = model_util.projection_head(hiddens, is_training)  # by default adds 3 nonlinear layers, paper claims 2 only
            hiddens_proj_f = model_util.projection_head(hiddens_f, is_training)
            hiddens_proj_c = model_util.projection_head(hiddens_c, is_training)

            # calculate attention mask
            attn_mask = model_util.attn_mask_head(
                10 * hiddens_proj_c, is_training,
                name='attn_network')  # 10* helps converge faster
            if FLAGS.attention_output == 'hard':
                attn_mask = tf.cast(
                    attn_mask >= 0.5, tf.float32
                )  # use softmax instead? L2 norm? alter 10* also?
            elif FLAGS.attention_output == 'softmax':
                attn_mask = tf.nn.softmax(
                    attn_mask)  # performed along last dim
            elif FLAGS.attention_output == 'L2':
                attn_mask = tf.math.l2_normalize(attn_mask)
            else:
                raise ValueError('Unknown attention_output {}'.format(
                    FLAGS.attention_output))
            # apply attention mask
            hiddens_proj_f = hiddens_proj_f * attn_mask
            hiddens_proj_c = hiddens_proj_c * attn_mask

            contrast_loss, logits_con, labels_con = obj_lib.add_contrastive_loss_2(
                hiddens_proj_f,
                hiddens_proj_c,
                hidden_norm=FLAGS.hidden_norm,
                temperature=FLAGS.temperature,
                tpu_context=tpu_context if is_training else None)
            logits_sup = tf.zeros([params['batch_size'], num_classes])
        else:
            contrast_loss = tf.zeros([])
            logits_con = tf.zeros([params['batch_size'], 10])
            labels_con = tf.zeros([params['batch_size'], 10])
            # hiddens = model_util.projection_head(hiddens, is_training)  # adds 3 nonlinear layers by default
            hiddens_f = model_util.projection_head(hiddens_f, is_training)
            hiddens_c = model_util.projection_head(hiddens_c, is_training)
            # The supervised head is just one linear layer, but 3 nonlinear
            # layers were already added above. Only the output of model_full
            # is evaluated (otherwise another flag would be needed to choose).
            logits_sup = model_util.supervised_head(
                hiddens_f, num_classes, is_training)
            # add_supervised_loss is just softmax cross-entropy;
            # labels['mask'] provides per-example loss weights.
            obj_lib.add_supervised_loss(
                labels=labels['labels'],
                logits=logits_sup,
                weights=labels['mask'])

        # Add weight decay to loss, for non-LARS optimizers.
        model_util.add_weight_decay(adjust_per_optimizer=True)
        loss = tf.losses.get_total_loss()

        if FLAGS.train_mode == 'pretrain':
            variables_to_train = tf.trainable_variables()
        else:
            collection_prefix = 'trainable_variables_inblock_'
            variables_to_train = []
            for j in range(FLAGS.fine_tune_after_block + 1, 6):
                variables_to_train += tf.get_collection(collection_prefix +
                                                        str(j))
            assert variables_to_train, 'variables_to_train shouldn\'t be empty!'

        tf.logging.info(
            '===============Variables to train (begin)===============')
        tf.logging.info(variables_to_train)
        tf.logging.info(
            '================Variables to train (end)================')

        learning_rate = model_util.learning_rate_schedule(
            FLAGS.learning_rate, num_train_examples)

        if is_training:
            if FLAGS.train_summary_steps > 0:
                # Compute stats for the summary.
                prob_con = tf.nn.softmax(logits_con)
                entropy_con = -tf.reduce_mean(
                    tf.reduce_sum(prob_con * tf.math.log(prob_con + 1e-8), -1))

                summary_writer = tf2.summary.create_file_writer(
                    FLAGS.model_dir)
                # TODO(iamtingchen): remove this control_dependencies in the future.
                with tf.control_dependencies([summary_writer.init()]):
                    with summary_writer.as_default():
                        should_record = tf.math.equal(
                            tf.math.floormod(tf.train.get_global_step(),
                                             FLAGS.train_summary_steps), 0)
                        with tf2.summary.record_if(should_record):
                            contrast_acc = tf.equal(
                                tf.argmax(labels_con, 1),
                                tf.argmax(logits_con, axis=1))
                            contrast_acc = tf.reduce_mean(
                                tf.cast(contrast_acc, tf.float32))
                            label_acc = tf.equal(
                                tf.argmax(labels['labels'], 1),
                                tf.argmax(logits_sup, axis=1))
                            label_acc = tf.reduce_mean(
                                tf.cast(label_acc, tf.float32))
                            tf2.summary.scalar('train_contrast_loss',
                                               contrast_loss,
                                               step=tf.train.get_global_step())
                            tf2.summary.scalar('train_contrast_acc',
                                               contrast_acc,
                                               step=tf.train.get_global_step())
                            tf2.summary.scalar('train_label_accuracy',
                                               label_acc,
                                               step=tf.train.get_global_step())
                            tf2.summary.scalar('contrast_entropy',
                                               entropy_con,
                                               step=tf.train.get_global_step())
                            tf2.summary.scalar('learning_rate',
                                               learning_rate,
                                               step=tf.train.get_global_step())

                            if FLAGS.train_mode == 'pretrain':
                                tf2.summary.histogram(
                                    'mask_hist',
                                    attn_mask,
                                    step=tf.train.get_global_step())

            optimizer = model_util.get_optimizer(learning_rate)
            control_deps = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            if FLAGS.train_summary_steps > 0:
                control_deps.extend(tf.summary.all_v2_summary_ops())
            with tf.control_dependencies(control_deps):
                train_op = optimizer.minimize(
                    loss,
                    global_step=tf.train.get_or_create_global_step(),
                    var_list=variables_to_train)

            if FLAGS.checkpoint:

                def scaffold_fn():
                    """Scaffold function to restore non-logits vars from checkpoint."""
                    tf.train.init_from_checkpoint(
                        FLAGS.checkpoint, {
                            v.op.name: v.op.name
                            for v in tf.global_variables(FLAGS.variable_schema)
                        })

                    if FLAGS.zero_init_logits_layer:
                        # Init op that initializes output layer parameters to zeros.
                        output_layer_parameters = [
                            var for var in tf.trainable_variables()
                            if var.name.startswith('head_supervised')
                        ]
                        tf.logging.info(
                            'Initializing output layer parameters %s to zero',
                            [x.op.name for x in output_layer_parameters])
                        with tf.control_dependencies(
                            [tf.global_variables_initializer()]):
                            init_op = tf.group([
                                tf.assign(x, tf.zeros_like(x))
                                for x in output_layer_parameters
                            ])
                        return tf.train.Scaffold(init_op=init_op)
                    else:
                        return tf.train.Scaffold()
            else:
                scaffold_fn = None

            return tf.estimator.tpu.TPUEstimatorSpec(mode=mode,
                                                     train_op=train_op,
                                                     loss=loss,
                                                     scaffold_fn=scaffold_fn)
        else:

            def metric_fn(logits_sup, labels_sup, logits_con, labels_con, mask,
                          **kws):
                """Inner metric function."""
                metrics = {
                    k: tf.metrics.mean(v, weights=mask)
                    for k, v in kws.items()
                }
                metrics['label_top_1_accuracy'] = tf.metrics.accuracy(
                    tf.argmax(labels_sup, 1),
                    tf.argmax(logits_sup, axis=1),
                    weights=mask)
                metrics['label_top_5_accuracy'] = tf.metrics.recall_at_k(
                    tf.argmax(labels_sup, 1), logits_sup, k=5, weights=mask)
                metrics['contrastive_top_1_accuracy'] = tf.metrics.accuracy(
                    tf.argmax(labels_con, 1),
                    tf.argmax(logits_con, axis=1),
                    weights=mask)
                metrics['contrastive_top_5_accuracy'] = tf.metrics.recall_at_k(
                    tf.argmax(labels_con, 1), logits_con, k=5, weights=mask)
                return metrics

            metrics = {
                'logits_sup': logits_sup,
                'labels_sup': labels['labels'],
                'logits_con': logits_con,
                'labels_con': labels_con,
                'mask': labels['mask'],
                'contrast_loss':
                    tf.fill((params['batch_size'],), contrast_loss),
                'regularization_loss':
                    tf.fill((params['batch_size'],),
                            tf.losses.get_regularization_loss()),
            }

            return tf.estimator.tpu.TPUEstimatorSpec(mode=mode,
                                                     loss=loss,
                                                     eval_metrics=(metric_fn,
                                                                   metrics),
                                                     scaffold_fn=None)
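For reference, a tiny standalone sketch (hypothetical values, written for TF2 eager purely for illustration) of the three FLAGS.attention_output normalizations applied to attn_mask in Example #4:

    import tensorflow as tf

    attn = tf.constant([[0.2, 0.7, 0.9]])
    hard = tf.cast(attn >= 0.5, tf.float32)  # 'hard'    -> [[0., 1., 1.]]
    soft = tf.nn.softmax(attn)               # 'softmax' -> each row sums to 1
    l2 = tf.math.l2_normalize(attn)          # 'L2'      -> unit norm over the whole tensor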