Example #1
    def model(self,
              batch,
              lr,
              wd,
              wu,
              wclr,
              mom,
              confidence,
              balance,
              delT,
              uratio,
              clrratio,
              temperature,
              ema=0.999,
              **kwargs):
        hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
        xt_in = tf.placeholder(tf.float32, [batch] + hwc,
                               'xt')  # Training labeled
        x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')  # Eval images
        y_in = tf.placeholder(tf.float32, [batch * uratio, 2] + hwc,
                              'y')  # Training unlabeled (weak, strong)
        l_in = tf.placeholder(tf.int32, [batch], 'labels')  # Labels
        wclr_in = tf.placeholder(tf.int32, [1], 'wclr')  # wclr

        lrate = tf.clip_by_value(
            tf.to_float(self.step) / (FLAGS.train_kimg << 10), 0, 1)
        lr *= tf.cos(lrate * (7 * np.pi) / (2 * 8))
        tf.summary.scalar('monitors/lr', lr)

        # Compute logits for xt_in and y_in
        classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
        skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        x = utils.interleave(tf.concat([xt_in, y_in[:, 0], y_in[:, 1]], 0),
                             2 * uratio + 1)
        logits = utils.para_cat(lambda x: classifier(x, training=True), x)
        logits = utils.de_interleave(logits, 2 * uratio + 1)
        post_ops = [
            v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            if v not in skip_ops
        ]
        logits_x = logits[:batch]
        logits_weak, logits_strong = tf.split(logits[batch:], 2)
        del logits, skip_ops

        # Labeled cross-entropy
        loss_xe = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=l_in, logits=logits_x)
        loss_xe = tf.reduce_mean(loss_xe)
        tf.summary.scalar('losses/xe', loss_xe)

        # Pseudo-label cross entropy for unlabeled data
        pseudo_labels = tf.stop_gradient(tf.nn.softmax(logits_weak))
        loss_xeu = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.argmax(pseudo_labels, axis=1), logits=logits_strong)
        # Original confidence-threshold mask:
        # pseudo_mask = tf.to_float(tf.reduce_max(pseudo_labels, axis=1) >= confidence)
        pseudo_mask = self.class_balancing(pseudo_labels, balance, confidence,
                                           delT)
        tf.summary.scalar('monitors/mask', tf.reduce_mean(pseudo_mask))
        loss_xeu = tf.reduce_mean(loss_xeu * pseudo_mask)
        tf.summary.scalar('losses/xeu', loss_xeu)

        ####################### Modification
        # Contrastive loss term
        contrast_loss = 0
        # Note: wclr_in is a graph placeholder, so `wclr_in == 0` is evaluated
        # at graph-construction time on the tensor object, not on the value
        # fed at run time; a run-time switch would need tf.cond instead.
        if wclr > 0 and wclr_in == 0:
            ratio = min(uratio, clrratio)
            if FLAGS.clrDataAug == 1:
                preprocess_fn = functools.partial(
                    data_util.preprocess_for_train,
                    height=self.dataset.height,
                    width=self.dataset.width)
                # Build two SimCLR-style augmented views of the unlabeled
                # images. preprocess_for_train expects a single image, so it
                # is mapped over the batch; taking the weak view y_in[:, 0]
                # as the source of both views is an assumption.
                unlabeled = y_in[:batch * ratio, 0]
                x = tf.concat([
                    tf.map_fn(preprocess_fn, unlabeled),
                    tf.map_fn(preprocess_fn, unlabeled)
                ], 0)
                embeds = lambda x, **kw: self.classifier(
                    x, **kw, **kwargs).embeds
                hidden = utils.para_cat(lambda x: embeds(x, training=True), x)
            else:
                embeds = lambda x, **kw: self.classifier(
                    x, **kw, **kwargs).embeds
                hiddens = utils.para_cat(lambda x: embeds(x, training=True), x)
                hiddens = utils.de_interleave(hiddens, 2 * uratio + 1)
                hiddens_weak, hiddens_strong = tf.split(hiddens[batch:], 2, 0)
                hidden = tf.concat([
                    hiddens_weak[:batch * ratio],
                    hiddens_strong[:batch * ratio]
                ], axis=0)
                del hiddens, hiddens_weak, hiddens_strong

            contrast_loss, _, _ = obj_lib.add_contrastive_loss(
                hidden,
                hidden_norm=True,  # FLAGS.hidden_norm,
                temperature=temperature,
                tpu_context=None)

            tf.summary.scalar('losses/contrast', contrast_loss)
            del embeds, hidden
        ###################### End

        # L2 regularization
        loss_wd = sum(
            tf.nn.l2_loss(v) for v in utils.model_vars('classify')
            if 'kernel' in v.name)
        tf.summary.scalar('losses/wd', loss_wd)

        ema = tf.train.ExponentialMovingAverage(decay=ema)
        ema_op = ema.apply(utils.model_vars())
        ema_getter = functools.partial(utils.getter_ema, ema)
        post_ops.append(ema_op)

        # Original: tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True).minimize(...)
        train_op = tf.train.MomentumOptimizer(
            lr, mom, use_nesterov=True).minimize(
                loss_xe + wu * loss_xeu + wclr * contrast_loss + wd * loss_wd,
                colocate_gradients_with_ops=True)
        with tf.control_dependencies([train_op]):
            train_op = tf.group(*post_ops)

        return utils.EasyDict(
            xt=xt_in,
            x=x_in,
            y=y_in,
            label=l_in,
            wclr=wclr_in,
            train_op=train_op,
            classify_raw=tf.nn.softmax(classifier(
                x_in, training=False)),  # No EMA, for debugging.
            classify_op=tf.nn.softmax(
                classifier(x_in, getter=ema_getter, training=False)))
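
All of the examples on this page call obj_lib.add_contrastive_loss from the SimCLR code base. As a point of reference, below is a minimal single-worker sketch of the NT-Xent objective it computes. The helper name simple_contrastive_loss is ours, and the sketch deliberately omits the cross-replica gathering the library performs under a tpu_context / strategy, so it is an illustration, not the library's exact implementation.

import tensorflow.compat.v1 as tf

def simple_contrastive_loss(hidden, temperature=0.1, hidden_norm=True):
    """NT-Xent over `hidden` of shape [2 * batch, dim], two views stacked on axis 0."""
    if hidden_norm:
        hidden = tf.math.l2_normalize(hidden, axis=-1)
    h1, h2 = tf.split(hidden, 2, axis=0)
    batch = tf.shape(h1)[0]
    # Positive for row i is column i of the [cross-view, same-view] concat.
    labels = tf.one_hot(tf.range(batch), batch * 2)
    masks = tf.one_hot(tf.range(batch), batch)  # drops self-similarity terms
    logits_ab = tf.matmul(h1, h2, transpose_b=True) / temperature
    logits_ba = tf.matmul(h2, h1, transpose_b=True) / temperature
    logits_aa = tf.matmul(h1, h1, transpose_b=True) / temperature - masks * 1e9
    logits_bb = tf.matmul(h2, h2, transpose_b=True) / temperature - masks * 1e9
    loss_a = tf.losses.softmax_cross_entropy(
        labels, tf.concat([logits_ab, logits_aa], 1))
    loss_b = tf.losses.softmax_cross_entropy(
        labels, tf.concat([logits_ba, logits_bb], 1))
    return loss_a + loss_b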
Example #2
    def model_fn(features, labels, mode, params=None):
        """Build model and optimizer."""
        is_training = mode == tf.estimator.ModeKeys.TRAIN

        # Check training mode.
        if FLAGS.train_mode == 'pretrain':
            num_transforms = 2
            if FLAGS.fine_tune_after_block > -1:
                raise ValueError(
                    'Does not support layer freezing during pretraining, '
                    'should set fine_tune_after_block<=-1 for safety.')
        elif FLAGS.train_mode == 'finetune':
            num_transforms = 1
        else:
            raise ValueError('Unknown train_mode {}'.format(FLAGS.train_mode))

        # Split channels, and optionally apply extra batched augmentation.
        features_list = tf.split(features,
                                 num_or_size_splits=num_transforms,
                                 axis=-1)
        if FLAGS.use_blur and is_training and FLAGS.train_mode == 'pretrain':
            features_list = data_util.batch_random_blur(
                features_list, FLAGS.image_size, FLAGS.image_size)
        features = tf.concat(features_list,
                             0)  # (num_transforms * bsz, h, w, c)

        # Base network forward pass.
        with tf.variable_scope('base_model'):
            if FLAGS.train_mode == 'finetune' and FLAGS.fine_tune_after_block >= 4:
                # Fine-tuning only the supervised (linear) head does not
                # update BN stats.
                model_train_mode = False
            else:
                # Pretraining, or fine-tuning anything else, updates BN stats.
                model_train_mode = is_training
            hiddens = model(features, is_training=model_train_mode)

        # Add head and loss.
        if FLAGS.train_mode == 'pretrain':
            tpu_context = params['context'] if 'context' in params else None
            hiddens_proj = model_util.projection_head(hiddens, is_training)
            contrast_loss, logits_con, labels_con = obj_lib.add_contrastive_loss(
                hiddens_proj,
                hidden_norm=FLAGS.hidden_norm,
                temperature=FLAGS.temperature,
                tpu_context=tpu_context if is_training else None)
            logits_sup = tf.zeros([params['batch_size'], num_classes])
        else:
            contrast_loss = tf.zeros([])
            logits_con = tf.zeros([params['batch_size'], 10])
            labels_con = tf.zeros([params['batch_size'], 10])
            logits_sup = model_util.supervised_head(hiddens, num_classes,
                                                    is_training)
            obj_lib.add_supervised_loss(labels=labels['labels'],
                                        logits=logits_sup,
                                        weights=labels['mask'])

        # Add weight decay to loss, for non-LARS optimizers.
        model_util.add_weight_decay(adjust_per_optimizer=True)
        loss = tf.losses.get_total_loss()

        if FLAGS.train_mode == 'pretrain':
            variables_to_train = tf.trainable_variables()
        else:
            collection_prefix = 'trainable_variables_inblock_'
            variables_to_train = []
            for j in range(FLAGS.fine_tune_after_block + 1, 6):
                variables_to_train += tf.get_collection(collection_prefix +
                                                        str(j))
            assert variables_to_train, 'variables_to_train shouldn\'t be empty!'

        tf.logging.info(
            '===============Variables to train (begin)===============')
        tf.logging.info(variables_to_train)
        tf.logging.info(
            '================Variables to train (end)================')

        learning_rate = model_util.learning_rate_schedule(
            FLAGS.learning_rate, num_train_examples)

        if is_training:
            if FLAGS.train_summary_steps > 0:
                # Compute stats for the summary.
                prob_con = tf.nn.softmax(logits_con)
                entropy_con = -tf.reduce_mean(
                    tf.reduce_sum(prob_con * tf.math.log(prob_con + 1e-8), -1))

                summary_writer = tf2.summary.create_file_writer(
                    FLAGS.model_dir)
                # TODO(iamtingchen): remove this control_dependencies in the future.
                with tf.control_dependencies([summary_writer.init()]):
                    with summary_writer.as_default():
                        should_record = tf.math.equal(
                            tf.math.floormod(tf.train.get_global_step(),
                                             FLAGS.train_summary_steps), 0)
                        with tf2.summary.record_if(should_record):
                            contrast_acc = tf.equal(
                                tf.argmax(labels_con, 1),
                                tf.argmax(logits_con, axis=1))
                            contrast_acc = tf.reduce_mean(
                                tf.cast(contrast_acc, tf.float32))
                            label_acc = tf.equal(
                                tf.argmax(labels['labels'], 1),
                                tf.argmax(logits_sup, axis=1))
                            label_acc = tf.reduce_mean(
                                tf.cast(label_acc, tf.float32))
                            tf2.summary.scalar('train_contrast_loss',
                                               contrast_loss,
                                               step=tf.train.get_global_step())
                            tf2.summary.scalar('train_contrast_acc',
                                               contrast_acc,
                                               step=tf.train.get_global_step())
                            tf2.summary.scalar('train_label_accuracy',
                                               label_acc,
                                               step=tf.train.get_global_step())
                            tf2.summary.scalar('contrast_entropy',
                                               entropy_con,
                                               step=tf.train.get_global_step())
                            tf2.summary.scalar('learning_rate',
                                               learning_rate,
                                               step=tf.train.get_global_step())
                            tf2.summary.scalar('input_mean',
                                               tf.reduce_mean(features),
                                               step=tf.train.get_global_step())
                            tf2.summary.scalar('input_max',
                                               tf.reduce_max(features),
                                               step=tf.train.get_global_step())
                            tf2.summary.scalar('input_min',
                                               tf.reduce_min(features),
                                               step=tf.train.get_global_step())
                            tf2.summary.scalar('num_labels',
                                               tf.reduce_mean(
                                                   tf.reduce_sum(
                                                       labels['labels'], -1)),
                                               step=tf.train.get_global_step())

            if FLAGS.optimizer == 'momentum':
                optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                       FLAGS.momentum,
                                                       use_nesterov=True)
            elif FLAGS.optimizer == 'adam':
                optimizer = tf.train.AdamOptimizer(learning_rate)
            elif FLAGS.optimizer == 'lars':
                optimizer = LARSOptimizer(
                    learning_rate,
                    momentum=FLAGS.momentum,
                    weight_decay=FLAGS.weight_decay,
                    exclude_from_weight_decay=['batch_normalization', 'bias'])
            else:
                raise ValueError('Unknown optimizer {}'.format(
                    FLAGS.optimizer))

            if FLAGS.use_tpu:
                optimizer = tf.tpu.CrossShardOptimizer(optimizer)

            control_deps = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            if FLAGS.train_summary_steps > 0:
                control_deps.extend(tf.summary.all_v2_summary_ops())
            with tf.control_dependencies(control_deps):
                train_op = optimizer.minimize(
                    loss,
                    global_step=tf.train.get_or_create_global_step(),
                    var_list=variables_to_train)

            if FLAGS.checkpoint:

                def scaffold_fn():
                    """Scaffold function to restore non-logits vars from checkpoint."""
                    for v in tf.global_variables(FLAGS.variable_schema):
                        print(v.op.name)
                    tf.train.init_from_checkpoint(
                        FLAGS.checkpoint, {
                            v.op.name: v.op.name
                            for v in tf.global_variables(FLAGS.variable_schema)
                        })

                    if FLAGS.zero_init_logits_layer:
                        # Init op that initializes output layer parameters to zeros.
                        output_layer_parameters = [
                            var for var in tf.trainable_variables()
                            if var.name.startswith('head_supervised')
                        ]
                        tf.logging.info(
                            'Initializing output layer parameters %s to zero',
                            [x.op.name for x in output_layer_parameters])
                        with tf.control_dependencies(
                            [tf.global_variables_initializer()]):
                            init_op = tf.group([
                                tf.assign(x, tf.zeros_like(x))
                                for x in output_layer_parameters
                            ])
                        return tf.train.Scaffold(init_op=init_op)
                    else:
                        return tf.train.Scaffold()
            else:
                scaffold_fn = None

            return tf.estimator.tpu.TPUEstimatorSpec(mode=mode,
                                                     train_op=train_op,
                                                     loss=loss,
                                                     scaffold_fn=scaffold_fn)

        elif mode == tf.estimator.ModeKeys.PREDICT:
            _, top_5 = tf.nn.top_k(logits_sup, k=5)
            predictions = {
                'label': tf.argmax(labels['labels'], 1),
                'top_5': top_5,
            }
            return tf.estimator.tpu.TPUEstimatorSpec(mode=mode,
                                                     predictions=predictions)
        else:

            def metric_fn(logits_sup, labels_sup, logits_con, labels_con, mask,
                          **kws):
                """Inner metric function."""
                metrics = {
                    k: tf.metrics.mean(v, weights=mask)
                    for k, v in kws.items()
                }
                metrics['label_top_1_accuracy'] = tf.metrics.accuracy(
                    tf.argmax(labels_sup, 1),
                    tf.argmax(logits_sup, axis=1),
                    weights=mask)
                metrics['label_top_5_accuracy'] = tf.metrics.recall_at_k(
                    tf.argmax(labels_sup, 1), logits_sup, k=5, weights=mask)
                metrics['contrastive_top_1_accuracy'] = tf.metrics.accuracy(
                    tf.argmax(labels_con, 1),
                    tf.argmax(logits_con, axis=1),
                    weights=mask)
                metrics['contrastive_top_5_accuracy'] = tf.metrics.recall_at_k(
                    tf.argmax(labels_con, 1), logits_con, k=5, weights=mask)

                metrics[
                    'mean_class_accuracy'] = tf.metrics.mean_per_class_accuracy(
                        tf.argmax(labels_sup, 1),
                        tf.argmax(logits_sup, axis=1),
                        num_classes,
                        weights=mask,
                        name='mca')

                running_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                                 scope="mca")
                metrics['mean_class_accuracy_total'] = running_vars[0]
                metrics['mean_class_accuracy_count'] = running_vars[1]

                return metrics

            metrics = {
                'logits_sup': logits_sup,
                'labels_sup': labels['labels'],
                'logits_con': logits_con,
                'labels_con': labels_con,
                'mask': labels['mask'],
                'contrast_loss': tf.fill((params['batch_size'],),
                                         contrast_loss),
                'regularization_loss': tf.fill(
                    (params['batch_size'],),
                    tf.losses.get_regularization_loss()),
            }

            return tf.estimator.tpu.TPUEstimatorSpec(mode=mode,
                                                     loss=loss,
                                                     eval_metrics=(metric_fn,
                                                                   metrics),
                                                     scaffold_fn=None)
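
For context, a model_fn like the one above is handed to tf.estimator.tpu.TPUEstimator, which injects the params['batch_size'] value the function reads. A rough wiring sketch follows; every config value is illustrative, model_fn is assumed to be in scope, and input_fn is assumed rather than shown.

import tensorflow.compat.v1 as tf

run_config = tf.estimator.tpu.RunConfig(
    model_dir='/tmp/simclr_model',          # illustrative path
    save_checkpoints_steps=1000,
    tpu_config=tf.estimator.tpu.TPUConfig(iterations_per_loop=100))

estimator = tf.estimator.tpu.TPUEstimator(
    model_fn=model_fn,                      # the model_fn defined above
    config=run_config,
    use_tpu=False,                          # CPU/GPU fallback for the sketch
    train_batch_size=256,
    eval_batch_size=256)

# input_fn is expected to yield features with the augmented views stacked
# along the channel axis and labels as
# {'labels': one_hot_labels, 'mask': per_example_mask}.
# estimator.train(input_fn=input_fn, max_steps=total_train_steps)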
Example #3
 def single_step(features, labels):
   with tf.GradientTape() as tape:
     # Log summaries on the last step of the training loop to match
     # logging frequency of other scalar summaries.
     #
     # Notes:
     # 1. Summary ops on TPUs get outside compiled so they do not affect
     #    performance.
     # 2. Summaries are recorded only on replica 0. So effectively this
     #    summary would be written once per host when should_record == True.
     # 3. optimizer.iterations is incremented in the call to apply_gradients.
     #    So we use  `iterations + 1` here so that the step number matches
     #    those of scalar summaries.
     # 4. We intentionally run the summary op before the actual model
     #    training so that it can run in parallel.
     should_record = tf.equal((optimizer.iterations + 1) % steps_per_loop, 0)
     with tf.summary.record_if(should_record):
       # Only log augmented images for the first tower.
       tf.summary.image(
           'image', features[:, :, :, :3], step=optimizer.iterations + 1)
     projection_head_outputs, supervised_head_outputs = model(
         features, training=True)
     loss = None
     if projection_head_outputs is not None:
       outputs = projection_head_outputs
       con_loss, logits_con, labels_con = obj_lib.add_contrastive_loss(
           outputs,
           hidden_norm=FLAGS.hidden_norm,
           temperature=FLAGS.temperature,
           strategy=strategy)
       if loss is None:
         loss = con_loss
       else:
         loss += con_loss
       metrics.update_pretrain_metrics_train(contrast_loss_metric,
                                             contrast_acc_metric,
                                             contrast_entropy_metric,
                                             con_loss, logits_con,
                                             labels_con)
     if supervised_head_outputs is not None:
       outputs = supervised_head_outputs
       l = labels['labels']
       if FLAGS.train_mode == 'pretrain' and FLAGS.lineareval_while_pretraining:
         l = tf.concat([l, l], 0)
       sup_loss = obj_lib.add_supervised_loss(labels=l, logits=outputs)
       if loss is None:
         loss = sup_loss
       else:
         loss += sup_loss
       metrics.update_finetune_metrics_train(supervised_loss_metric,
                                             supervised_acc_metric, sup_loss,
                                             l, outputs)
     weight_decay = model_lib.add_weight_decay(
         model, adjust_per_optimizer=True)
     weight_decay_metric.update_state(weight_decay)
     loss += weight_decay
     total_loss_metric.update_state(loss)
     # The default behavior of `apply_gradients` is to sum gradients from all
     # replicas so we divide the loss by the number of replicas so that the
     # mean gradient is applied.
     loss = loss / strategy.num_replicas_in_sync
     logging.info('Trainable variables:')
     for var in model.trainable_variables:
       logging.info(var.name)
     grads = tape.gradient(loss, model.trainable_variables)
     optimizer.apply_gradients(zip(grads, model.trainable_variables))
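
single_step is meant to be called once per replica from an outer distributed training loop. Below is a hedged sketch of that loop, assuming the surrounding names the snippet takes for granted (strategy, optimizer, steps_per_loop, a strategy-distributed dataset ds, total_steps).

import tensorflow as tf

@tf.function
def train_multiple_steps(iterator):
  # tf.range keeps the loop as a graph-level while_loop instead of unrolling.
  for _ in tf.range(steps_per_loop):
    images, labels = next(iterator)
    # labels is packed into the dict shape single_step expects.
    strategy.run(single_step, (images, {'labels': labels}))

iterator = iter(ds)  # `ds` is assumed to be a strategy-distributed dataset
while optimizer.iterations.numpy() < total_steps:
  train_multiple_steps(iterator)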