def __init__(self, train_dir, dataset, **kwargs):
    """Build the model graph for ``dataset`` and print a configuration report.

    Args:
        train_dir: base output directory; ``self.experiment_name(**kwargs)``
            is appended to it to form ``self.train_dir``.
        dataset: dataset object, stored on ``self.dataset``; its ``name`` is
            printed in the report.
        **kwargs: hyper-parameters. They are stored as ``self.params`` and
            forwarded to ``self.model`` and ``self.add_summaries``.
    """
    self.train_dir = os.path.join(train_dir, self.experiment_name(**kwargs))
    self.params = utils.EasyDict(kwargs)
    self.dataset = dataset
    self.session = None
    self.tmp = utils.EasyDict(print_queue=[], cache=utils.EasyDict())
    # The global step has to exist before the graph is built, because the
    # model builders read self.step.
    self.step = tf.train.get_or_create_global_step()
    self.ops = self.model(**kwargs)
    # update_step advances the global step by FLAGS.batch.
    self.ops.update_step = tf.assign_add(self.step, FLAGS.batch)
    self.add_summaries(**kwargs)

    # ---- configuration report ----
    print(' Config '.center(80, '-'))
    print('train_dir', self.train_dir)
    print('%-32s %s' % ('Model', self.__class__.__name__))
    print('%-32s %s' % ('Dataset', dataset.name))
    for key in sorted(kwargs):
        print('%-32s %s' % (key, kwargs[key]))

    # ---- per-variable size table, followed by a total row ----
    print(' Model '.center(80, '-'))
    rows = []
    for var in utils.model_vars(None):
        rows.append(tuple('%s' % cell for cell in (var.name, np.prod(var.shape), var.shape)))
    grand_total = sum(int(row[1]) for row in rows)
    rows.append(('Total', str(grand_total), ''))
    widths = tuple(max(len(cell) for cell in column) for column in zip(*rows))
    row_format = '%%-%ds %%%ds %%%ds' % widths
    *body_rows, total_row = rows
    for row in body_rows:
        print(row_format % row)
    print()
    print(row_format % total_row)
    print('-' * 80)

    self._create_initial_files()
    self.work_unit = None
    self.measurement = {}
def model(self, batch, lr, wd, wu, confidence, uratio, ema=0.999, **kwargs):
    """Build the weak/strong pseudo-labeling training graph and eval ops.

    Labeled images (``xt``) and the weak/strong views of unlabeled images
    (``y``) go through the classifier in a single training-mode pass.
    The optimized objective is
        loss_xe + wu * loss_xeu + wd * loss_wd
    where ``loss_xeu`` is the cross-entropy of the strong-view logits against
    hard pseudo-labels taken from the weak view. It is masked to samples
    whose weak-view max probability is >= ``confidence``.

    Args:
        batch: labeled batch size; the unlabeled batch is ``batch * uratio``.
        lr: base learning rate. It is multiplied by cos(7*pi/16 * progress),
            where progress = step / (FLAGS.train_kimg << 10) clipped to [0, 1].
        wd: weight of the L2 penalty on 'classify' variables whose name
            contains 'kernel'.
        wu: weight of the unlabeled pseudo-label loss.
        confidence: threshold on the weak-view max softmax probability.
        uratio: number of unlabeled images per labeled image.
        ema: decay of the exponential moving average of ``utils.model_vars()``.
        **kwargs: forwarded to ``self.classifier``.

    Returns:
        utils.EasyDict with placeholders ``xt``, ``x``, ``y``, ``label``, the
        ``train_op``, and softmax outputs ``classify_raw`` (raw weights) and
        ``classify_op`` (EMA weights).
    """
    hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
    xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt')  # Training labeled
    x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')  # Eval images
    y_in = tf.placeholder(tf.float32, [batch * uratio, 2] + hwc, 'y')  # Training unlabeled (weak, strong)
    l_in = tf.placeholder(tf.int32, [batch], 'labels')  # Labels

    # Training progress in [0, 1]. The lr factor goes from cos(0) = 1 down to
    # cos(7*pi/16) ~= 0.2 at the end of training.
    lrate = tf.clip_by_value(tf.to_float(self.step) / (FLAGS.train_kimg << 10), 0, 1)
    lr *= tf.cos(lrate * (7 * np.pi) / (2 * 8))
    tf.summary.scalar('monitors/lr', lr)

    # Compute logits for xt_in and y_in
    classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
    # Snapshot UPDATE_OPS so that post_ops only collects the ops added by the
    # training forward pass below (presumably batch-norm statistic updates).
    skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # Labeled, weak and strong images are concatenated, interleaved into
    # 2*uratio+1 groups, run through one training-mode pass, then
    # de-interleaved back to [labeled | weak | strong] order.
    x = utils.interleave(tf.concat([xt_in, y_in[:, 0], y_in[:, 1]], 0), 2 * uratio + 1)
    logits = utils.para_cat(lambda x: classifier(x, training=True), x)
    logits = utils.de_interleave(logits, 2 * uratio+1)
    post_ops = [v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops]
    logits_x = logits[:batch]
    logits_weak, logits_strong = tf.split(logits[batch:], 2)
    del logits, skip_ops

    # Labeled cross-entropy
    loss_xe = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=l_in, logits=logits_x)
    loss_xe = tf.reduce_mean(loss_xe)
    tf.summary.scalar('losses/xe', loss_xe)

    # Pseudo-label cross entropy for unlabeled data: hard labels come from the
    # weak view (no gradient through them) and supervise the strong view.
    pseudo_labels = tf.stop_gradient(tf.nn.softmax(logits_weak))
    loss_xeu = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(pseudo_labels, axis=1), logits=logits_strong)
    pseudo_mask = tf.to_float(tf.reduce_max(pseudo_labels, axis=1) >= confidence)
    tf.summary.scalar('monitors/mask', tf.reduce_mean(pseudo_mask))
    # Mean over all unlabeled samples, including the masked-out ones.
    loss_xeu = tf.reduce_mean(loss_xeu * pseudo_mask)
    tf.summary.scalar('losses/xeu', loss_xeu)

    # L2 regularization
    loss_wd = sum(tf.nn.l2_loss(v) for v in utils.model_vars('classify') if 'kernel' in v.name)
    tf.summary.scalar('losses/wd', loss_wd)

    ema = tf.train.ExponentialMovingAverage(decay=ema)
    ema_op = ema.apply(utils.model_vars())
    ema_getter = functools.partial(utils.getter_ema, ema)
    post_ops.append(ema_op)

    train_op = tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True).minimize(
        loss_xe + wu * loss_xeu + wd * loss_wd, colocate_gradients_with_ops=True)
    # The returned train_op runs the collected update ops and the EMA update,
    # each only after the gradient step has been applied.
    with tf.control_dependencies([train_op]):
        train_op = tf.group(*post_ops)

    return utils.EasyDict(
        xt=xt_in, x=x_in, y=y_in, label=l_in, train_op=train_op,
        classify_raw=tf.nn.softmax(classifier(x_in, training=False)),  # No EMA, for debugging.
        classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))
def model_val(self, batch, uratio, **kwargs):
    """Build the gradient of a labeled validation loss w.r.t. ``self.model_params``.

    Args:
        batch: base batch size; the validation batch holds ``batch * uratio``
            images.
        uratio: multiplier applied to ``batch``.
        **kwargs: forwarded to ``self.classifier``.

    Returns:
        utils.EasyDict with the image placeholder ``v``, the label placeholder
        ``l``, and ``grad_val_loss_op`` (the result of ``tf.gradients`` of the
        mean cross-entropy w.r.t. ``self.model_params``).
    """
    image_shape = [self.dataset.height, self.dataset.width, self.dataset.colors]
    n_val = batch * uratio

    # Validation images and their integer labels.
    val_images = tf.placeholder(tf.float32, [n_val] + image_shape, 'v')
    val_labels = tf.placeholder(tf.int32, [n_val], 'labels')

    def logits_fn(images, **kw):
        return self.classifier(images, **kw, **kwargs).logits

    # The forward pass is built with training=True, as in the training graph.
    val_logits = utils.para_cat(lambda images: logits_fn(images, training=True), val_images)
    per_example = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=val_labels, logits=val_logits)
    val_loss = tf.reduce_mean(per_example)

    return utils.EasyDict(
        v=val_images,
        l=val_labels,
        grad_val_loss_op=tf.gradients(val_loss, self.model_params))
def guess_label(self, y, classifier, p_data, p_model, T, **kwargs):
    """Average the classifier's predictions over several views and sharpen them.

    Args:
        y: sequence of image tensors, one per augmented view. Each view is
            classified in training mode. The reshape below assumes every view
            has the same batch size (TODO confirm at call sites).
        classifier: callable ``classifier(images, training=...)`` returning
            logits over ``self.nclass`` classes.
        p_data: unused here; accepted for interface compatibility.
        p_model: unused here; accepted for interface compatibility.
        T: sharpening temperature; the target is ``p ** (1 / T)``
            renormalized over classes.
        **kwargs: ignored.

    Returns:
        utils.EasyDict with ``p_target`` (sharpened distribution) and
        ``p_model`` (view-averaged softmax), each of shape [per-view batch, nclass].
    """
    del kwargs
    logits_y = [classifier(yi, training=True) for yi in y]
    logits_y = tf.concat(logits_y, 0)
    # Compute the predicted probability distribution py, averaged over views.
    p_model_y = tf.reshape(tf.nn.softmax(logits_y), [len(y), -1, self.nclass])
    p_model_y = tf.reduce_mean(p_model_y, axis=0)
    # Compute the target distribution by temperature sharpening.
    p_target = tf.pow(p_model_y, 1. / T)
    # `keepdims` replaces the deprecated `keep_dims` argument (same semantics).
    p_target /= tf.reduce_sum(p_target, axis=1, keepdims=True)
    return utils.EasyDict(p_target=p_target, p_model=p_model_y)
def model_per_ex(self, nclass, batch, confidence, uratio, **kwargs):
    """Build per-example gradients of the unlabeled pseudo-label loss.

    Hard pseudo-labels come from the weak view (softmax, no gradient). The
    per-example cross-entropy of the strong view against them is passed to
    ``custom_gradient`` together with ``self.model_params``.

    Args:
        nclass: number of classes (depth of the one-hot pseudo-labels).
        batch: base batch size; the unlabeled batch holds ``batch * uratio``
            (weak, strong) pairs.
        confidence: threshold on the weak-view max probability (only used for
            the weighted loss noted below).
        uratio: multiplier applied to ``batch``.
        **kwargs: forwarded to ``self.classifier``.

    Returns:
        utils.EasyDict with placeholders ``y`` and ``w_match`` and
        ``grads_train_per_ex``.
    """
    image_shape = [self.dataset.height, self.dataset.width, self.dataset.colors]
    n_unlabeled = batch * uratio
    unlabeled = tf.placeholder(tf.float32, [n_unlabeled, 2] + image_shape, 'y')
    # Per-sample weights for the unlabeled data.
    sample_weights = tf.placeholder(tf.float32, [n_unlabeled], 'w_match')

    def logits_fn(images, **kw):
        return self.classifier(images, **kw, **kwargs).logits

    # One training-mode pass over [weak views | strong views].
    both_views = tf.concat([unlabeled[:, 0], unlabeled[:, 1]], 0)
    weak_logits, strong_logits = tf.split(logits_fn(both_views, training=True), 2)

    weak_probs = tf.stop_gradient(tf.nn.softmax(weak_logits))
    hard_targets = tf.one_hot(tf.argmax(weak_probs, axis=1), nclass)
    per_example_loss = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=hard_targets, logits=strong_logits)
    confident = tf.to_float(tf.reduce_max(weak_probs, axis=1) >= confidence)
    # NOTE(review): this masked, weighted loss is built but never used. The
    # per-example gradients below use the unmasked, unweighted losses;
    # confirm that is intended.
    _weighted_loss = tf.reduce_mean(per_example_loss * confident * sample_weights)

    per_example_grads = custom_gradient(per_example_loss, self.model_params)
    return utils.EasyDict(y=unlabeled, w_match=sample_weights,
                          grads_train_per_ex=per_example_grads)
def model(self, batch, lr, wd, wu, mom, delT, confidence, balance, uratio, ema=0.999, **kwargs):
    """Build a pseudo-labeling training graph with a class-balanced mask.

    This matches the weak/strong pseudo-label setup (labeled CE + wu * masked
    unlabeled CE + wd * L2). The unlabeled mask depends on ``balance``, and
    the class counts are histograms of hard pseudo-labels within the current
    unlabeled batch:
      * 0 (or <= 0): binary mask, max prob >= ``confidence``.
      * 1: per-class threshold ``confidence - delT * (1 - count_c / max_count)``.
        Classes with fewer pseudo-labels get a lower threshold.
      * 2: fixed threshold. Each kept sample is then weighted by
        1 / count of its pseudo-class (counts over all unlabeled samples), and
        the weights are rescaled so they sum to the number of kept samples.
      * 3: like 2, but counts are taken over the thresholded samples only.
      * 4: per-class threshold as in 1, then weighting as in 3.
      * any other value > 1 (e.g. 5) behaves like 2.

    Args:
        batch: labeled batch size; the unlabeled batch is ``batch * uratio``.
        lr: base learning rate (cosine-decayed, see below).
        wd: L2 weight on 'classify' kernels.
        wu: weight of the unlabeled loss.
        mom: Nesterov momentum coefficient.
        delT: maximum threshold reduction for balance modes 1 and 4.
        confidence: base confidence threshold.
        balance: mask mode, see above.
        uratio: unlabeled-to-labeled ratio.
        ema: EMA decay for the evaluation weights.
        **kwargs: forwarded to ``self.classifier``.

    Returns:
        utils.EasyDict with placeholders ``xt``, ``x``, ``y``, ``label``,
        ``train_op``, ``classify_raw`` (no EMA) and ``classify_op`` (EMA).
    """
    hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
    xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt')  # Training labeled
    x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')  # Eval images
    y_in = tf.placeholder(tf.float32, [batch * uratio, 2] + hwc, 'y')  # Training unlabeled (weak, strong)
    l_in = tf.placeholder(tf.int32, [batch], 'labels')  # Labels

    # lr *= cos(7*pi/16 * progress), progress = step / (train_kimg * 1024) in [0, 1].
    lrate = tf.clip_by_value(
        tf.to_float(self.step) / (FLAGS.train_kimg << 10), 0, 1)
    lr *= tf.cos(lrate * (7 * np.pi) / (2 * 8))
    tf.summary.scalar('monitors/lr', lr)

    # Compute logits for xt_in and y_in
    classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
    # Snapshot so post_ops only gets the update ops created by this pass.
    skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    x = utils.interleave(tf.concat([xt_in, y_in[:, 0], y_in[:, 1]], 0), 2 * uratio + 1)
    logits = utils.para_cat(lambda x: classifier(x, training=True), x)
    logits = utils.de_interleave(logits, 2 * uratio + 1)
    post_ops = [
        v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if v not in skip_ops
    ]
    logits_x = logits[:batch]
    logits_weak, logits_strong = tf.split(logits[batch:], 2)
    del logits, skip_ops

    # Labeled cross-entropy
    loss_xe = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=l_in, logits=logits_x)
    loss_xe = tf.reduce_mean(loss_xe)
    tf.summary.scalar('losses/xe', loss_xe)

    # Pseudo-label cross entropy for unlabeled data (weak-view hard labels
    # supervise the strong view; no gradient through the labels).
    pseudo_labels = tf.stop_gradient(tf.nn.softmax(logits_weak))
    pLabels = tf.math.argmax(pseudo_labels, axis=1)
    loss_xeu = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=pLabels, logits=logits_strong)

    ####################### Modification: class-balanced pseudo-label mask
    # Histogram of hard pseudo-labels in this unlabeled batch:
    # class_count[c] = number of samples whose pseudo-label is c
    # (classes that never occur get 0 via scatter_nd).
    pLabels = tf.cast(pLabels, dtype=tf.float32)
    classes, idx, counts = tf.unique_with_counts(pLabels)
    shape = tf.constant([self.dataset.nclass])
    classes = tf.cast(classes, dtype=tf.int32)
    class_count = tf.scatter_nd(tf.reshape(classes, [tf.size(classes), 1]), counts, shape)
    # NOTE(review): print_cc is never run or attached as a control
    # dependency, so this print op never executes.
    print_cc = tf.print("class_count ", class_count, output_stream=sys.stdout)
    class_count = tf.cast(class_count, dtype=tf.float32)
    mxCount = tf.reduce_max(class_count, axis=0)
    if balance > 0:
        pLabels = tf.cast(pLabels, dtype=tf.int32)
        if balance == 1 or balance == 4:
            # Per-class thresholds: the most frequent class keeps `confidence`;
            # a class with count 0 would get `confidence - delT`.
            confidences = tf.zeros_like(pLabels, dtype=tf.float32)  # dead: overwritten below
            ratios = 1.0 - tf.math.divide_no_nan(class_count, mxCount)
            ratios = confidence - delT * ratios
            # Look up each sample's threshold from its pseudo-label.
            confidences = tf.gather_nd(
                ratios, tf.reshape(pLabels, [tf.size(pLabels), 1]))
            pseudo_mask = tf.reduce_max(pseudo_labels, axis=1) >= confidences
        else:
            pseudo_mask = tf.reduce_max(pseudo_labels, axis=1) >= confidence
        if balance == 3 or balance == 4:
            # Recount classes using only the samples that passed the threshold.
            classes, idx, counts = tf.unique_with_counts(
                tf.boolean_mask(pLabels, pseudo_mask))
            shape = tf.constant([self.dataset.nclass])
            classes = tf.cast(classes, dtype=tf.int32)
            class_count = tf.scatter_nd(
                tf.reshape(classes, [tf.size(classes), 1]), counts, shape)
            class_count = tf.cast(class_count, dtype=tf.float32)
        pseudo_mask = tf.cast(pseudo_mask, dtype=tf.float32)
        if balance > 1:
            # Weight each sample by 1 / class_count[pseudo-label] (0 where the
            # count is 0, via divide_no_nan).
            ratios = tf.math.divide_no_nan(
                tf.ones_like(class_count, dtype=tf.float32), class_count)
            ratio = tf.gather_nd(
                ratios, tf.reshape(pLabels, [tf.size(pLabels), 1]))
            # Rescale so the weighted mask sums to Z, the number of samples
            # that passed the threshold: mask = Z * mask * ratio / sum(mask * ratio).
            Z = tf.reduce_sum(pseudo_mask)
            pseudo_mask = tf.math.multiply(
                pseudo_mask, tf.cast(ratio, dtype=tf.float32))
            pseudo_mask = tf.math.divide_no_nan(
                tf.scalar_mul(Z, pseudo_mask), tf.reduce_sum(pseudo_mask))
    else:
        pseudo_mask = tf.cast(
            tf.reduce_max(pseudo_labels, axis=1) >= confidence, dtype=tf.float32)
    ###################### End
    tf.summary.scalar('monitors/mask', tf.reduce_mean(pseudo_mask))
    loss_xeu = tf.reduce_mean(loss_xeu * pseudo_mask)
    tf.summary.scalar('losses/xeu', loss_xeu)

    # L2 regularization
    loss_wd = sum(
        tf.nn.l2_loss(v) for v in utils.model_vars('classify') if 'kernel' in v.name)
    tf.summary.scalar('losses/wd', loss_wd)

    ema = tf.train.ExponentialMovingAverage(decay=ema)
    ema_op = ema.apply(utils.model_vars())
    ema_getter = functools.partial(utils.getter_ema, ema)
    post_ops.append(ema_op)

    # Nesterov momentum with a configurable coefficient `mom`.
    train_op = tf.train.MomentumOptimizer(
        lr, mom, use_nesterov=True).minimize(loss_xe + wu * loss_xeu + wd * loss_wd,
                                             colocate_gradients_with_ops=True)
    # Update ops and the EMA update run after the gradient step.
    with tf.control_dependencies([train_op]):
        train_op = tf.group(*post_ops)

    return utils.EasyDict(
        xt=xt_in, x=x_in, y=y_in, label=l_in, train_op=train_op,
        classify_raw=tf.nn.softmax(classifier(
            x_in, training=False)),  # No EMA, for debugging.
        classify_op=tf.nn.softmax(
            classifier(x_in, getter=ema_getter, training=False)))
def __init__(self, train_dir: str, **kwargs):
    """Record the output directory and hyper-parameters; no graph is built.

    Args:
        train_dir: output directory, used as given (passed through
            ``os.path.join`` with no extra components).
        **kwargs: hyper-parameters, stored as ``self.params``.
    """
    # No experiment-name sub-directory is appended here.
    self.train_dir = os.path.join(train_dir)
    self.params = utils.EasyDict(kwargs)

    # Runtime state: no session yet, plus an empty scratch area.
    self.session = None
    scratch_cache = utils.EasyDict()
    self.tmp = utils.EasyDict(print_queue=[], cache=scratch_cache)
def model(self, batch, lr, wd, wu, wclr, mom, confidence, balance, delT, uratio,
          clrratio, temperature, ema=0.999, **kwargs):
    """Build a pseudo-labeling training graph with an optional contrastive term.

    The optimized objective is
        loss_xe + wu * loss_xeu + wclr * contrast_loss + wd * loss_wd
    where ``loss_xeu`` is the strong-view cross-entropy against weak-view hard
    pseudo-labels, weighted by ``self.class_balancing(...)``.
    ``contrast_loss`` is built only when ``wclr > 0``. It is then gated at run
    time by the ``wclr`` placeholder: it contributes only while the fed value
    is 0, which is also the default when the placeholder is not fed.

    Args:
        batch: labeled batch size; the unlabeled batch is ``batch * uratio``.
        lr: base learning rate, multiplied by cos(7*pi/16 * progress) with
            progress = step / (FLAGS.train_kimg << 10) clipped to [0, 1].
        wd: L2 weight on 'classify' variables whose name contains 'kernel'.
        wu: weight of the unlabeled pseudo-label loss.
        wclr: weight of the contrastive loss; 0 disables it entirely.
        mom: Nesterov momentum coefficient.
        confidence, balance, delT: forwarded to ``self.class_balancing``.
        uratio: unlabeled-to-labeled ratio.
        clrratio: cap on the unlabeled-to-labeled ratio used for the
            contrastive term (``min(uratio, clrratio) * batch`` samples per view).
        temperature: temperature passed to ``obj_lib.add_contrastive_loss``.
        ema: EMA decay for the evaluation weights.
        **kwargs: forwarded to ``self.classifier``.

    Returns:
        utils.EasyDict with placeholders ``xt``, ``x``, ``y``, ``label``,
        ``wclr``, plus ``train_op``, ``classify_raw`` (no EMA) and
        ``classify_op`` (EMA).
    """
    hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
    xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt')  # Training labeled
    x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')  # Eval images
    y_in = tf.placeholder(tf.float32, [batch * uratio, 2] + hwc, 'y')  # Training unlabeled (weak, strong)
    l_in = tf.placeholder(tf.int32, [batch], 'labels')  # Labels
    # Run-time switch for the contrastive term: active while the fed value is 0.
    # It defaults to [0], so callers that never feed it still get the term
    # whenever wclr > 0.
    wclr_in = tf.placeholder_with_default(tf.zeros([1], tf.int32), [1], 'wclr')

    lrate = tf.clip_by_value(
        tf.to_float(self.step) / (FLAGS.train_kimg << 10), 0, 1)
    lr *= tf.cos(lrate * (7 * np.pi) / (2 * 8))
    tf.summary.scalar('monitors/lr', lr)

    # Compute logits for xt_in and y_in in one interleaved training pass.
    classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
    skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    x = utils.interleave(tf.concat([xt_in, y_in[:, 0], y_in[:, 1]], 0), 2 * uratio + 1)
    logits = utils.para_cat(lambda x: classifier(x, training=True), x)
    logits = utils.de_interleave(logits, 2 * uratio + 1)
    post_ops = [
        v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if v not in skip_ops
    ]
    logits_x = logits[:batch]
    logits_weak, logits_strong = tf.split(logits[batch:], 2)
    del logits, skip_ops

    # Labeled cross-entropy
    loss_xe = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=l_in, logits=logits_x)
    loss_xe = tf.reduce_mean(loss_xe)
    tf.summary.scalar('losses/xe', loss_xe)

    # Pseudo-label cross entropy for unlabeled data, weighted by the
    # class-balancing mask.
    pseudo_labels = tf.stop_gradient(tf.nn.softmax(logits_weak))
    loss_xeu = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.argmax(pseudo_labels, axis=1), logits=logits_strong)
    pseudo_mask = self.class_balancing(pseudo_labels, balance, confidence, delT)
    tf.summary.scalar('monitors/mask', tf.reduce_mean(pseudo_mask))
    loss_xeu = tf.reduce_mean(loss_xeu * pseudo_mask)
    tf.summary.scalar('losses/xeu', loss_xeu)

    # Contrastive loss term
    contrast_loss = 0
    if wclr > 0:
        ratio = min(uratio, clrratio)
        n_clr = batch * ratio
        embeds = lambda x, **kw: self.classifier(x, **kw, **kwargs).embeds
        if FLAGS.clrDataAug == 1:
            # Two independently augmented views of the first n_clr weak
            # unlabeled images. preprocess_for_train is applied per image.
            # NOTE(review): confirm data_util.preprocess_for_train accepts
            # this dataset's pixel range and channel count.
            preprocess_fn = functools.partial(
                data_util.preprocess_for_train,
                height=self.dataset.height,
                width=self.dataset.width)
            clr_images = y_in[:n_clr, 0]
            views = tf.concat([tf.map_fn(preprocess_fn, clr_images),
                               tf.map_fn(preprocess_fn, clr_images)], 0)
            hidden = utils.para_cat(lambda v: embeds(v, training=True), views)
        else:
            # Pair the embeddings of the weak and strong views of the same
            # unlabeled images (first n_clr of each).
            hiddens = utils.para_cat(lambda v: embeds(v, training=True), x)
            hiddens = utils.de_interleave(hiddens, 2 * uratio + 1)
            hiddens_weak, hiddens_strong = tf.split(hiddens[batch:], 2, 0)
            hidden = tf.concat(
                [hiddens_weak[:n_clr], hiddens_strong[:n_clr]], axis=0)
        contrast_loss, _, _ = obj_lib.add_contrastive_loss(
            hidden,
            hidden_norm=True,
            temperature=temperature,
            tpu_context=None)
        # The switch must be evaluated inside the graph: a Python `if` on a
        # Tensor comparison is resolved once at graph-construction time and
        # never sees the value fed at each step.
        clr_active = tf.cast(tf.equal(wclr_in[0], 0), tf.float32)
        contrast_loss = clr_active * contrast_loss
        tf.summary.scalar('losses/contrast', contrast_loss)

    # L2 regularization
    loss_wd = sum(
        tf.nn.l2_loss(v) for v in utils.model_vars('classify') if 'kernel' in v.name)
    tf.summary.scalar('losses/wd', loss_wd)

    ema = tf.train.ExponentialMovingAverage(decay=ema)
    ema_op = ema.apply(utils.model_vars())
    ema_getter = functools.partial(utils.getter_ema, ema)
    post_ops.append(ema_op)

    train_op = tf.train.MomentumOptimizer(
        lr, mom, use_nesterov=True).minimize(
            loss_xe + wu * loss_xeu + wclr * contrast_loss + wd * loss_wd,
            colocate_gradients_with_ops=True)
    # Update ops and the EMA update run after the gradient step.
    with tf.control_dependencies([train_op]):
        train_op = tf.group(*post_ops)

    return utils.EasyDict(
        xt=xt_in, x=x_in, y=y_in, label=l_in, wclr=wclr_in, train_op=train_op,
        classify_raw=tf.nn.softmax(classifier(x_in, training=False)),  # No EMA, for debugging.
        classify_op=tf.nn.softmax(
            classifier(x_in, getter=ema_getter, training=False)))
def model(self, batch, lr, wd, wu, we, confidence, uratio, temperature=1.0, tsa='no', tsa_pos=10, ema=0.999, **kwargs):
    """Build a KL-consistency training graph with entropy and a masked supervised loss.

    The optimized objective is
        loss_xe + wu * loss_xeu + we * loss_ent + wd * loss_wd
    where
      * loss_xeu = mean(KL(target from weak logits || strong) * confidence mask),
        using ``self.kl_divergence_from_logits`` and a target produced by
        ``self.softmax_temperature_controlling(logits_weak, T=temperature)``;
      * loss_ent = mean of ``self.entropy_from_logits(logits_weak)``;
      * loss_xe = labeled softmax cross-entropy, summed under
        ``self.tsa_loss_mask(...)`` and divided by max(sum(mask), 1).

    Args:
        batch: labeled batch size; the unlabeled batch is ``batch * uratio``.
        lr: base learning rate, multiplied by cos(7*pi/16 * progress).
        wd, wu, we: weights of the L2, consistency and entropy terms.
        confidence: threshold passed to ``self.confidence_based_masking``.
        uratio: unlabeled-to-labeled ratio.
        temperature: ``T`` passed to ``self.softmax_temperature_controlling``.
        tsa, tsa_pos: passed to ``self.tsa_loss_mask`` (presumably
            training-signal-annealing settings; see that method).
        ema: EMA decay for the evaluation weights.
        **kwargs: forwarded to ``self.classifier``.

    Returns:
        utils.EasyDict with placeholders ``xt``, ``x``, ``y``, ``label``,
        ``train_op``, ``classify_raw`` (no EMA) and ``classify_op`` (EMA).
    """
    hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
    xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt')  # For training
    x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
    y_in = tf.placeholder(tf.float32, [batch * uratio, 2] + hwc, 'y')
    l_in = tf.placeholder(tf.int32, [batch], 'labels')
    l = tf.one_hot(l_in, self.nclass)

    lrate = tf.clip_by_value(
        tf.to_float(self.step) / (FLAGS.train_kimg << 10), 0, 1)
    lr *= tf.cos(lrate * (7 * np.pi) / (2 * 8))
    tf.summary.scalar('monitors/lr', lr)

    # Compute logits for xt_in and y_in
    classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
    # Snapshot so post_ops only gets the update ops created by this pass.
    skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    x = utils.interleave(tf.concat([xt_in, y_in[:, 0], y_in[:, 1]], 0), 2 * uratio + 1)
    logits = utils.para_cat(lambda x: classifier(x, training=True), x)
    logits = utils.de_interleave(logits, 2 * uratio + 1)
    post_ops = [
        v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if v not in skip_ops
    ]
    logits_x = logits[:batch]
    logits_weak, logits_strong = tf.split(logits[batch:], 2)
    del logits, skip_ops

    # Consistency target: weak logits passed through the temperature
    # control helper (its exact behavior is defined in that method).
    logits_weak_tgt = self.softmax_temperature_controlling(logits_weak, T=temperature)
    # Confidence mask from the plain weak-view softmax (temperature 1), not
    # from the temperature-controlled target above.
    pseudo_labels = tf.stop_gradient(tf.nn.softmax(logits_weak))
    pseudo_mask = self.confidence_based_masking(logits=None, p_class=pseudo_labels, thresh=confidence)
    tf.summary.scalar('monitors/mask', tf.reduce_mean(pseudo_mask))
    tf.summary.scalar(
        'monitors/conf_weak',
        tf.reduce_mean(tf.reduce_max(tf.nn.softmax(logits_weak), axis=1)))
    tf.summary.scalar(
        'monitors/conf_strong',
        tf.reduce_mean(tf.reduce_max(tf.nn.softmax(logits_strong), axis=1)))

    kld = self.kl_divergence_from_logits(logits_weak_tgt, logits_strong)
    entropy = self.entropy_from_logits(logits_weak)
    loss_xeu = tf.reduce_mean(kld * pseudo_mask)
    tf.summary.scalar('losses/xeu', loss_xeu)
    loss_ent = tf.reduce_mean(entropy)
    tf.summary.scalar('losses/entropy', loss_ent)

    # Supervised loss with TSA: the mask from tsa_loss_mask selects labeled
    # samples; normalize by the number kept (at least 1 to avoid dividing by 0).
    loss_mask = self.tsa_loss_mask(tsa=tsa, logits=logits_x, labels=l, tsa_pos=tsa_pos)
    loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=l, logits=logits_x)
    loss_xe = tf.reduce_sum(loss_xe * loss_mask) / tf.math.maximum(
        tf.reduce_sum(loss_mask), 1.0)
    tf.summary.scalar('losses/xe', loss_xe)
    tf.summary.scalar('losses/mask_sup', tf.reduce_mean(loss_mask))

    # L2 regularization
    loss_wd = sum(
        tf.nn.l2_loss(v) for v in utils.model_vars('classify') if 'kernel' in v.name)
    tf.summary.scalar('losses/wd', loss_wd)

    ema = tf.train.ExponentialMovingAverage(decay=ema)
    ema_op = ema.apply(utils.model_vars())
    ema_getter = functools.partial(utils.getter_ema, ema)
    post_ops.append(ema_op)

    train_op = tf.train.MomentumOptimizer(
        lr, 0.9, use_nesterov=True).minimize(
            loss_xe + loss_xeu * wu + loss_ent * we + loss_wd * wd,
            colocate_gradients_with_ops=True)
    # Update ops and the EMA update run after the gradient step.
    with tf.control_dependencies([train_op]):
        train_op = tf.group(*post_ops)

    return utils.EasyDict(
        xt=xt_in, x=x_in, y=y_in, label=l_in, train_op=train_op,
        classify_raw=tf.nn.softmax(classifier(
            x_in, training=False)),  # No EMA, for debugging.
        classify_op=tf.nn.softmax(
            classifier(x_in, getter=ema_getter, training=False)))
def model(self, nu, w_match, warmup_kimg, batch, lr, wd, ema, dbuf, beta, mixmode, logit_norm, **kwargs):
    """Build a label-guessing + mixing training graph with distribution tracking.

    Each unlabeled sample comes with ``nu`` views. Their averaged, processed
    prediction from ``self.guess_label`` becomes a no-gradient target. Labeled
    and unlabeled batches are mixed by ``MixMode(mixmode)``. The objective is
        loss_xe + w_match * loss_l2u
    (mixed-label cross-entropy on labeled data + mean squared error between
    mixed targets and softmax on unlabeled data), optimized with Adam.
    Weight decay is applied separately after each step as
    ``v <- v * (1 - wd * lr)`` on 'classify' kernels.

    Args:
        nu: number of views per unlabeled sample (``y`` is [N, nu, H, W, C]).
        w_match: weight of the unlabeled loss, ramped linearly from 0 over
            ``warmup_kimg << 10`` steps of ``self.step``.
        warmup_kimg: ramp-up length for ``w_match``.
        batch: group size passed to ``layers.interleave``.
        lr: Adam learning rate (also scales ``wd``).
        wd: decoupled weight-decay coefficient.
        ema: EMA decay for the evaluation weights.
        dbuf: buffer size for the ``p_model`` / ``p_target`` moving averages.
        beta: mixing parameter passed to ``augment`` (as ``[beta, beta]``).
        mixmode: mixing mode spec for ``MixMode``.
        logit_norm: if truthy, training logits are divided by their global
            RMS (``rsqrt(mean(v**2) + 1e-8)``).
        **kwargs: forwarded to ``self.classifier`` and ``self.guess_label``.

    Returns:
        utils.EasyDict with placeholders, ``train_op``, ``tune_op``
        (batch-norm-only update ops), ``classify_raw`` / ``classify_op``
        (without / with EMA, never logit-normalized), ``embedding_op``, and
        the ``label_index`` variable with its placeholder and assign op.
    """
    def classifier(x, logit_norm=logit_norm, **kw):
        # Element [0] of self.classifier's output is used as the logits.
        v = self.classifier(x, **kw, **kwargs)[0]
        if not logit_norm:
            return v
        # Divide by the RMS over the whole tensor (reduce_mean has no axis).
        return v * tf.rsqrt(tf.reduce_mean(tf.square(v)) + 1e-8)

    def embedding(x, **kw):
        # Element [1] of self.classifier's output is used as the embedding.
        return self.classifier(x, **kw, **kwargs)[1]

    # Non-trainable int32 record of labeled indices, initialized from
    # utils.idx_to_fixlen(...) and overwritable via update_label_index.
    label_index = tf.Variable(utils.idx_to_fixlen(
        self.dataset.labeled_indices, self.dataset.ntrain), trainable=False, name='label_index', dtype=tf.int32)
    label_index_input = tf.placeholder(tf.int32, self.dataset.ntrain, 'label_index_input')
    update_label_index = tf.assign(label_index, label_index_input)

    hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
    x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
    y_in = tf.placeholder(tf.float32, [None, nu] + hwc, 'y')
    l_in = tf.placeholder(tf.int32, [None], 'labels')
    # Decoupled weight decay: the per-step shrink factor scales with lr.
    wd *= lr
    w_match *= tf.clip_by_value(
        tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)
    augment = MixMode(mixmode)

    # Moving average of the current estimated label distribution
    p_model = layers.PMovingAverage('p_model', self.nclass, dbuf)
    p_target = layers.PMovingAverage(
        'p_target', self.nclass, dbuf)  # Rectified distribution (only for plotting)

    # Known (or inferred) true unlabeled distribution
    p_data = layers.PData(self.dataset)

    # [N, nu, H, W, C] -> [nu * N, H, W, C] with views contiguous, so
    # tf.split(y, nu) yields one tensor per view.
    y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)
    guess = self.guess_label(tf.split(y, nu), classifier, p_data(), p_model(), **kwargs)
    ly = tf.stop_gradient(guess.p_target)
    lx = tf.one_hot(l_in, self.nclass)
    # Mix the labeled batch and each unlabeled view together with their labels.
    xy, labels_xy = augment([x_in] + tf.split(y, nu), [lx] + [ly] * nu, [beta, beta])
    x, y = xy[0], xy[1:]
    labels_x, labels_y = labels_xy[0], tf.concat(labels_xy[1:], 0)
    del xy, labels_xy

    batches = layers.interleave([x] + y, batch)
    # Only the update ops created by the FIRST batch's forward pass are
    # collected into post_ops; later passes' update ops are not run.
    skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    logits = [classifier(batches[0], training=True)]
    post_ops = [
        v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if v not in skip_ops
    ]
    for batchi in batches[1:]:
        logits.append(classifier(batchi, training=True))
    logits = layers.interleave(logits, batch)
    logits_x = logits[0]
    logits_y = tf.concat(logits[1:], 0)

    loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x, logits=logits_x)
    loss_xe = tf.reduce_mean(loss_xe)
    # Squared error between mixed unlabeled targets and predicted probabilities.
    loss_l2u = tf.square(labels_y - tf.nn.softmax(logits_y))
    loss_l2u = tf.reduce_mean(loss_l2u)
    tf.summary.scalar('losses/xe', loss_xe)
    tf.summary.scalar('losses/l2u', loss_l2u)
    self.distribution_summary(p_data(), p_model(), p_target())

    ema = tf.train.ExponentialMovingAverage(decay=ema)
    ema_op = ema.apply(utils.model_vars())
    ema_getter = functools.partial(utils.getter_ema, ema)
    # After each step: EMA update, distribution moving-average updates,
    # optional p_data update, then weight decay on 'classify' kernels.
    post_ops.extend([
        ema_op,
        p_model.update(guess.p_model),
        p_target.update(guess.p_target)
    ])
    if p_data.has_update:
        post_ops.append(p_data.update(lx))
    post_ops.extend([
        tf.assign(v, v * (1 - wd))
        for v in utils.model_vars('classify')
        if 'kernel' in v.name
    ])
    train_op = tf.train.AdamOptimizer(lr).minimize(
        loss_xe + w_match * loss_l2u, colocate_gradients_with_ops=True)
    with tf.control_dependencies([train_op]):
        train_op = tf.group(*post_ops)

    # Tuning op: only retrain batch norm. A fresh training-mode pass on the
    # first batch; its newly created update ops are grouped into tune_op.
    skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    classifier(batches[0], training=True)
    train_bn = tf.group(*[
        v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if v not in skip_ops
    ])

    return utils.EasyDict(
        x=x_in,
        y=y_in,
        label=l_in,
        train_op=train_op,
        tune_op=train_bn,
        classify_raw=tf.nn.softmax(
            classifier(x_in, logit_norm=False, training=False)),  # No EMA, for debugging.
        classify_op=tf.nn.softmax(
            classifier(x_in, logit_norm=False, getter=ema_getter, training=False)),
        embedding_op=embedding(x_in, training=False),
        label_index=label_index,
        update_label_index=update_label_index,
        label_index_input=label_index_input,
    )