def model(self, batch, lr, wd, wu, wclr, mom, confidence, balance, delT,
          uratio, clrratio, temperature, ema=0.999, **kwargs):
    hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
    xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt')  # Training labeled
    x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')  # Eval images
    y_in = tf.placeholder(tf.float32, [batch * uratio, 2] + hwc, 'y')  # Training unlabeled (weak, strong)
    l_in = tf.placeholder(tf.int32, [batch], 'labels')  # Labels
    wclr_in = tf.placeholder(tf.int32, [1], 'wclr')  # wclr

    lrate = tf.clip_by_value(
        tf.to_float(self.step) / (FLAGS.train_kimg << 10), 0, 1)
    lr *= tf.cos(lrate * (7 * np.pi) / (2 * 8))
    tf.summary.scalar('monitors/lr', lr)

    # Compute logits for xt_in and y_in
    classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
    skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    x = utils.interleave(
        tf.concat([xt_in, y_in[:, 0], y_in[:, 1]], 0), 2 * uratio + 1)
    logits = utils.para_cat(lambda x: classifier(x, training=True), x)
    logits = utils.de_interleave(logits, 2 * uratio + 1)
    post_ops = [
        v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if v not in skip_ops
    ]
    logits_x = logits[:batch]
    logits_weak, logits_strong = tf.split(logits[batch:], 2)
    del logits, skip_ops

    # Labeled cross-entropy
    loss_xe = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=l_in, logits=logits_x)
    loss_xe = tf.reduce_mean(loss_xe)
    tf.summary.scalar('losses/xe', loss_xe)

    # Pseudo-label cross-entropy for unlabeled data
    pseudo_labels = tf.stop_gradient(tf.nn.softmax(logits_weak))
    loss_xeu = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.argmax(pseudo_labels, axis=1), logits=logits_strong)
    # pseudo_mask = tf.to_float(tf.reduce_max(pseudo_labels, axis=1) >= confidence)
    pseudo_mask = self.class_balancing(pseudo_labels, balance, confidence, delT)
    tf.summary.scalar('monitors/mask', tf.reduce_mean(pseudo_mask))
    loss_xeu = tf.reduce_mean(loss_xeu * pseudo_mask)
    tf.summary.scalar('losses/xeu', loss_xeu)

    ####################### Modification
    # Contrastive loss term
    contrast_loss = 0
    if wclr > 0 and wclr_in == 0:
        ratio = min(uratio, clrratio)
        if FLAGS.clrDataAug == 1:
            preprocess_fn = functools.partial(
                data_util.preprocess_for_train,
                height=self.dataset.height,
                width=self.dataset.width)
            # Two independently augmented views of the first (weak) copy of each
            # unlabeled image; preprocess_for_train is per-image, so map over the batch.
            x = tf.concat([
                tf.map_fn(preprocess_fn, y_in[:batch * ratio, 0]),
                tf.map_fn(preprocess_fn, y_in[:batch * ratio, 0]),
            ], 0)
            embeds = lambda x, **kw: self.classifier(x, **kw, **kwargs).embeds
            hidden = utils.para_cat(lambda x: embeds(x, training=True), x)
        else:
            embeds = lambda x, **kw: self.classifier(x, **kw, **kwargs).embeds
            hiddens = utils.para_cat(lambda x: embeds(x, training=True), x)
            hiddens = utils.de_interleave(hiddens, 2 * uratio + 1)
            hiddens_weak, hiddens_strong = tf.split(hiddens[batch:], 2, 0)
            hidden = tf.concat(
                [hiddens_weak[:batch * ratio], hiddens_strong[:batch * ratio]],
                axis=0)
            del hiddens, hiddens_weak, hiddens_strong
        contrast_loss, _, _ = obj_lib.add_contrastive_loss(
            hidden,
            hidden_norm=True,  # FLAGS.hidden_norm,
            temperature=temperature,
            tpu_context=None)
        tf.summary.scalar('losses/contrast', contrast_loss)
        del embeds, hidden
    ###################### End

    # L2 regularization
    loss_wd = sum(
        tf.nn.l2_loss(v) for v in utils.model_vars('classify')
        if 'kernel' in v.name)
    tf.summary.scalar('losses/wd', loss_wd)

    ema = tf.train.ExponentialMovingAverage(decay=ema)
    ema_op = ema.apply(utils.model_vars())
    ema_getter = functools.partial(utils.getter_ema, ema)
    post_ops.append(ema_op)

    # train_op = tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True).minimize(
    train_op = tf.train.MomentumOptimizer(lr, mom, use_nesterov=True).minimize(
        loss_xe + wu * loss_xeu + wclr * contrast_loss + wd * loss_wd,
        colocate_gradients_with_ops=True)
    with tf.control_dependencies([train_op]):
        train_op = tf.group(*post_ops)

    return utils.EasyDict(
        xt=xt_in, x=x_in, y=y_in, label=l_in, wclr=wclr_in, train_op=train_op,
        classify_raw=tf.nn.softmax(classifier(x_in, training=False)),  # No EMA, for debugging.
        classify_op=tf.nn.softmax(
            classifier(x_in, getter=ema_getter, training=False)))
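# Usage sketch (an assumption for illustration, not code from the original
# training loop): one way to drive the ops returned by `model()` above from a
# TF1 session. `ops` is the utils.EasyDict built by `model()`; `labeled_x`,
# `labeled_y`, and `unlabeled_pairs` are hypothetical numpy batches from the
# data pipeline.
def example_train_step(sess, ops, labeled_x, labeled_y, unlabeled_pairs):
    """Run one FixMatch-style update, then return EMA-smoothed predictions."""
    sess.run(ops.train_op, feed_dict={
        ops.xt: labeled_x,        # [batch, h, w, c] labeled images
        ops.label: labeled_y,     # [batch] integer labels
        ops.y: unlabeled_pairs,   # [batch * uratio, 2, h, w, c] weak/strong pairs
        ops.wclr: [0],            # value fed to the `wclr` placeholder
    })
    # classify_op runs the classifier under the EMA variable getter.
    return sess.run(ops.classify_op, feed_dict={ops.x: labeled_x})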
def model_fn(features, labels, mode, params=None):
    """Build model and optimizer."""
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    # Check training mode.
    if FLAGS.train_mode == 'pretrain':
        num_transforms = 2
        if FLAGS.fine_tune_after_block > -1:
            raise ValueError('Does not support layer freezing during pretraining, '
                             'should set fine_tune_after_block<=-1 for safety.')
    elif FLAGS.train_mode == 'finetune':
        num_transforms = 1
    else:
        raise ValueError('Unknown train_mode {}'.format(FLAGS.train_mode))

    # Split channels, and optionally apply extra batched augmentation.
    features_list = tf.split(features, num_or_size_splits=num_transforms, axis=-1)
    if FLAGS.use_blur and is_training and FLAGS.train_mode == 'pretrain':
        features_list = data_util.batch_random_blur(
            features_list, FLAGS.image_size, FLAGS.image_size)
    features = tf.concat(features_list, 0)  # (num_transforms * bsz, h, w, c)

    # Base network forward pass.
    with tf.variable_scope('base_model'):
        if FLAGS.train_mode == 'finetune' and FLAGS.fine_tune_after_block >= 4:
            # Finetuning just the supervised (linear) head does not update BN stats.
            model_train_mode = False
        else:
            # Pretraining, or finetuning anything else, updates BN stats.
            model_train_mode = is_training
        hiddens = model(features, is_training=model_train_mode)

    # Add head and loss.
    if FLAGS.train_mode == 'pretrain':
        tpu_context = params['context'] if 'context' in params else None
        hiddens_proj = model_util.projection_head(hiddens, is_training)
        contrast_loss, logits_con, labels_con = obj_lib.add_contrastive_loss(
            hiddens_proj,
            hidden_norm=FLAGS.hidden_norm,
            temperature=FLAGS.temperature,
            tpu_context=tpu_context if is_training else None)
        logits_sup = tf.zeros([params['batch_size'], num_classes])
    else:
        contrast_loss = tf.zeros([])
        logits_con = tf.zeros([params['batch_size'], 10])
        labels_con = tf.zeros([params['batch_size'], 10])
        logits_sup = model_util.supervised_head(hiddens, num_classes, is_training)
        obj_lib.add_supervised_loss(
            labels=labels['labels'], logits=logits_sup, weights=labels['mask'])

    # Add weight decay to loss, for non-LARS optimizers.
    model_util.add_weight_decay(adjust_per_optimizer=True)
    loss = tf.losses.get_total_loss()

    if FLAGS.train_mode == 'pretrain':
        variables_to_train = tf.trainable_variables()
    else:
        collection_prefix = 'trainable_variables_inblock_'
        variables_to_train = []
        for j in range(FLAGS.fine_tune_after_block + 1, 6):
            variables_to_train += tf.get_collection(collection_prefix + str(j))
        assert variables_to_train, 'variables_to_train shouldn\'t be empty!'

    tf.logging.info('===============Variables to train (begin)===============')
    tf.logging.info(variables_to_train)
    tf.logging.info('================Variables to train (end)================')

    learning_rate = model_util.learning_rate_schedule(
        FLAGS.learning_rate, num_train_examples)

    if is_training:
        if FLAGS.train_summary_steps > 0:
            # Compute stats for the summary.
            prob_con = tf.nn.softmax(logits_con)
            entropy_con = -tf.reduce_mean(
                tf.reduce_sum(prob_con * tf.math.log(prob_con + 1e-8), -1))

            summary_writer = tf2.summary.create_file_writer(FLAGS.model_dir)
            # TODO(iamtingchen): remove this control_dependencies in the future.
            with tf.control_dependencies([summary_writer.init()]):
                with summary_writer.as_default():
                    should_record = tf.math.equal(
                        tf.math.floormod(tf.train.get_global_step(),
                                         FLAGS.train_summary_steps), 0)
                    with tf2.summary.record_if(should_record):
                        contrast_acc = tf.equal(
                            tf.argmax(labels_con, 1), tf.argmax(logits_con, axis=1))
                        contrast_acc = tf.reduce_mean(tf.cast(contrast_acc, tf.float32))
                        label_acc = tf.equal(
                            tf.argmax(labels['labels'], 1), tf.argmax(logits_sup, axis=1))
                        label_acc = tf.reduce_mean(tf.cast(label_acc, tf.float32))
                        tf2.summary.scalar('train_contrast_loss', contrast_loss,
                                           step=tf.train.get_global_step())
                        tf2.summary.scalar('train_contrast_acc', contrast_acc,
                                           step=tf.train.get_global_step())
                        tf2.summary.scalar('train_label_accuracy', label_acc,
                                           step=tf.train.get_global_step())
                        tf2.summary.scalar('contrast_entropy', entropy_con,
                                           step=tf.train.get_global_step())
                        tf2.summary.scalar('learning_rate', learning_rate,
                                           step=tf.train.get_global_step())
                        tf2.summary.scalar('input_mean', tf.reduce_mean(features),
                                           step=tf.train.get_global_step())
                        tf2.summary.scalar('input_max', tf.reduce_max(features),
                                           step=tf.train.get_global_step())
                        tf2.summary.scalar('input_min', tf.reduce_min(features),
                                           step=tf.train.get_global_step())
                        tf2.summary.scalar('num_labels',
                                           tf.reduce_mean(tf.reduce_sum(labels['labels'], -1)),
                                           step=tf.train.get_global_step())

        if FLAGS.optimizer == 'momentum':
            optimizer = tf.train.MomentumOptimizer(
                learning_rate, FLAGS.momentum, use_nesterov=True)
        elif FLAGS.optimizer == 'adam':
            optimizer = tf.train.AdamOptimizer(learning_rate)
        elif FLAGS.optimizer == 'lars':
            optimizer = LARSOptimizer(
                learning_rate,
                momentum=FLAGS.momentum,
                weight_decay=FLAGS.weight_decay,
                exclude_from_weight_decay=['batch_normalization', 'bias'])
        else:
            raise ValueError('Unknown optimizer {}'.format(FLAGS.optimizer))

        if FLAGS.use_tpu:
            optimizer = tf.tpu.CrossShardOptimizer(optimizer)

        control_deps = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if FLAGS.train_summary_steps > 0:
            control_deps.extend(tf.summary.all_v2_summary_ops())
        with tf.control_dependencies(control_deps):
            train_op = optimizer.minimize(
                loss,
                global_step=tf.train.get_or_create_global_step(),
                var_list=variables_to_train)

        if FLAGS.checkpoint:
            def scaffold_fn():
                """Scaffold function to restore non-logits vars from checkpoint."""
                for v in tf.global_variables(FLAGS.variable_schema):
                    print(v.op.name)
                tf.train.init_from_checkpoint(
                    FLAGS.checkpoint,
                    {v.op.name: v.op.name
                     for v in tf.global_variables(FLAGS.variable_schema)})
                if FLAGS.zero_init_logits_layer:
                    # Init op that initializes output layer parameters to zeros.
                    output_layer_parameters = [
                        var for var in tf.trainable_variables()
                        if var.name.startswith('head_supervised')
                    ]
                    tf.logging.info('Initializing output layer parameters %s to zero',
                                    [x.op.name for x in output_layer_parameters])
                    with tf.control_dependencies([tf.global_variables_initializer()]):
                        init_op = tf.group([
                            tf.assign(x, tf.zeros_like(x))
                            for x in output_layer_parameters
                        ])
                    return tf.train.Scaffold(init_op=init_op)
                else:
                    return tf.train.Scaffold()
        else:
            scaffold_fn = None

        return tf.estimator.tpu.TPUEstimatorSpec(
            mode=mode, train_op=train_op, loss=loss, scaffold_fn=scaffold_fn)

    elif mode == tf.estimator.ModeKeys.PREDICT:
        _, top_5 = tf.nn.top_k(logits_sup, k=5)
        predictions = {
            'label': tf.argmax(labels['labels'], 1),
            'top_5': top_5,
        }
        return tf.estimator.tpu.TPUEstimatorSpec(mode=mode, predictions=predictions)

    else:
        def metric_fn(logits_sup, labels_sup, logits_con, labels_con, mask, **kws):
            """Inner metric function."""
            metrics = {
                k: tf.metrics.mean(v, weights=mask) for k, v in kws.items()
            }
            metrics['label_top_1_accuracy'] = tf.metrics.accuracy(
                tf.argmax(labels_sup, 1), tf.argmax(logits_sup, axis=1), weights=mask)
            metrics['label_top_5_accuracy'] = tf.metrics.recall_at_k(
                tf.argmax(labels_sup, 1), logits_sup, k=5, weights=mask)
            metrics['contrastive_top_1_accuracy'] = tf.metrics.accuracy(
                tf.argmax(labels_con, 1), tf.argmax(logits_con, axis=1), weights=mask)
            metrics['contrastive_top_5_accuracy'] = tf.metrics.recall_at_k(
                tf.argmax(labels_con, 1), logits_con, k=5, weights=mask)
            metrics['mean_class_accuracy'] = tf.metrics.mean_per_class_accuracy(
                tf.argmax(labels_sup, 1), tf.argmax(logits_sup, axis=1),
                num_classes, weights=mask, name='mca')
            running_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope='mca')
            metrics['mean_class_accuracy_total'] = running_vars[0]
            metrics['mean_class_accuracy_count'] = running_vars[1]
            return metrics

        metrics = {
            'logits_sup': logits_sup,
            'labels_sup': labels['labels'],
            'logits_con': logits_con,
            'labels_con': labels_con,
            'mask': labels['mask'],
            'contrast_loss': tf.fill((params['batch_size'],), contrast_loss),
            'regularization_loss': tf.fill((params['batch_size'],),
                                           tf.losses.get_regularization_loss()),
        }
        return tf.estimator.tpu.TPUEstimatorSpec(
            mode=mode, loss=loss, eval_metrics=(metric_fn, metrics), scaffold_fn=None)
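# Wiring sketch (an assumption following the standard TPUEstimator pattern,
# not code taken from this file): `model_fn` above is intended to be handed to
# a tf.estimator.tpu.TPUEstimator. Flag names other than `model_dir` and
# `use_tpu` are hypothetical placeholders for whatever the repo defines.
def build_estimator():
    run_config = tf.estimator.tpu.RunConfig(
        model_dir=FLAGS.model_dir,
        tpu_config=tf.estimator.tpu.TPUConfig(
            iterations_per_loop=FLAGS.checkpoint_steps),
        save_checkpoints_steps=FLAGS.checkpoint_steps)
    return tf.estimator.tpu.TPUEstimator(
        model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        use_tpu=FLAGS.use_tpu)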
def single_step(features, labels):
    with tf.GradientTape() as tape:
        # Log summaries on the last step of the training loop to match
        # logging frequency of other scalar summaries.
        #
        # Notes:
        # 1. Summary ops on TPUs get outside compiled so they do not affect
        #    performance.
        # 2. Summaries are recorded only on replica 0. So effectively this
        #    summary would be written once per host when should_record == True.
        # 3. optimizer.iterations is incremented in the call to apply_gradients.
        #    So we use `iterations + 1` here so that the step number matches
        #    those of scalar summaries.
        # 4. We intentionally run the summary op before the actual model
        #    training so that it can run in parallel.
        should_record = tf.equal((optimizer.iterations + 1) % steps_per_loop, 0)
        with tf.summary.record_if(should_record):
            # Only log augmented images for the first tower.
            tf.summary.image(
                'image', features[:, :, :, :3], step=optimizer.iterations + 1)

        projection_head_outputs, supervised_head_outputs = model(
            features, training=True)
        loss = None
        if projection_head_outputs is not None:
            outputs = projection_head_outputs
            con_loss, logits_con, labels_con = obj_lib.add_contrastive_loss(
                outputs,
                hidden_norm=FLAGS.hidden_norm,
                temperature=FLAGS.temperature,
                strategy=strategy)
            if loss is None:
                loss = con_loss
            else:
                loss += con_loss
            metrics.update_pretrain_metrics_train(contrast_loss_metric,
                                                  contrast_acc_metric,
                                                  contrast_entropy_metric,
                                                  con_loss, logits_con,
                                                  labels_con)
        if supervised_head_outputs is not None:
            outputs = supervised_head_outputs
            l = labels['labels']
            if FLAGS.train_mode == 'pretrain' and FLAGS.lineareval_while_pretraining:
                l = tf.concat([l, l], 0)
            sup_loss = obj_lib.add_supervised_loss(labels=l, logits=outputs)
            if loss is None:
                loss = sup_loss
            else:
                loss += sup_loss
            metrics.update_finetune_metrics_train(supervised_loss_metric,
                                                  supervised_acc_metric,
                                                  sup_loss, l, outputs)
        weight_decay = model_lib.add_weight_decay(model, adjust_per_optimizer=True)
        weight_decay_metric.update_state(weight_decay)
        loss += weight_decay
        total_loss_metric.update_state(loss)
        # The default behavior of `apply_gradients` is to sum gradients from all
        # replicas, so we divide the loss by the number of replicas so that the
        # mean gradient is applied.
        loss = loss / strategy.num_replicas_in_sync
        logging.info('Trainable variables:')
        for var in model.trainable_variables:
            logging.info(var.name)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
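# Loop sketch (an assumption following the usual tf.distribute pattern, not
# the repo's exact outer loop): `single_step` above is meant to be replicated
# across devices with `strategy.run` from a @tf.function that consumes a
# distributed dataset iterator; `steps_per_loop` and `strategy` come from the
# enclosing scope, while `iterator` is hypothetical.
@tf.function
def train_multiple_steps(iterator):
    # tf.range keeps this as a graph-level while_loop instead of unrolling it.
    for _ in tf.range(steps_per_loop):
        images, labels = next(iterator)
        strategy.run(single_step, (images, {'labels': labels}))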