def model(self, batch, lr, wd, ema, beta, w_match, warmup_kimg=1024, nu=2, mixmode='xxy.yxy', **kwargs): hwc = [self.dataset.height, self.dataset.width, self.dataset.colors] x_in = tf.placeholder(tf.float32, [None] + hwc, 'x') y_in = tf.placeholder(tf.float32, [None, nu] + hwc, 'y') l_in = tf.placeholder(tf.int32, [None], 'labels') wd *= lr w_match *= tf.clip_by_value(tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1) augment = MixMode(mixmode) classifier = functools.partial(self.classifier, **kwargs) y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc) # generate guessed label guess = self.guess_label(tf.split(y, nu), classifier, T=0.5, **kwargs) ly = tf.stop_gradient(guess.p_target) lx = tf.one_hot(l_in, self.nclass) # apply mixup xy, labels_xy = augment([x_in] + tf.split(y, nu), [lx] + [ly] * nu, [beta, beta]) x, y = xy[0], xy[1:] labels_x, labels_y = labels_xy[0], tf.concat(labels_xy[1:], 0) del xy, labels_xy batches = layers.interleave([x] + y, batch) skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) logits = [classifier(batches[0], training=True)] post_ops = [v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops] for batchi in batches[1:]: logits.append(classifier(batchi, training=True)) logits = layers.interleave(logits, batch) logits_x = logits[0] logits_y = tf.concat(logits[1:], 0) loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x, logits=logits_x) loss_xe = tf.reduce_mean(loss_xe) loss_l2u = tf.square(labels_y - tf.nn.softmax(logits_y)) loss_l2u = tf.reduce_mean(loss_l2u) tf.summary.scalar('losses/xe', loss_xe) tf.summary.scalar('losses/l2u', loss_l2u) ema = tf.train.ExponentialMovingAverage(decay=ema) ema_op = ema.apply(utils.model_vars()) ema_getter = functools.partial(utils.getter_ema, ema) post_ops.append(ema_op) post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name]) train_op = tf.train.AdamOptimizer(lr).minimize(loss_xe + w_match * 
loss_l2u, colocate_gradients_with_ops=True) with tf.control_dependencies([train_op]): train_op = tf.group(*post_ops) # Tuning op: only retrain batch norm. skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) classifier(batches[0], training=True) train_bn = tf.group(*[v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops]) return EasyDict( x=x_in, y=y_in, label=l_in, train_op=train_op, tune_op=train_bn, classify_raw=tf.nn.softmax(classifier(x_in, training=False)), # No EMA, for debugging. classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False))) def cutmix():
def model(self, batch, lr, wd, beta, w_kl, w_match, w_rot, K, use_xe,
          warmup_kimg=1024, T=0.5, mixmode='xxy.yxy', dbuf=128, ema=0.999,
          **kwargs):
    """Assemble the training/inference graph and return its endpoints.

    Args:
        batch: leading dimension of the 'xt', 'y' and 'labels' placeholders;
            also forwarded to `layers.interleave`.
        lr: learning rate for `tf.train.AdamOptimizer`; also scales `wd`.
        wd: weight-decay factor. After being multiplied by `lr` it shrinks
            every 'classify' variable whose name contains 'kernel'.
        beta: passed as `[beta, beta]` to the `layers.MixMode` augmenter.
        w_kl: weight of the 'losses/kl' term; the term is built only when
            `w_kl > 0`, and the weight is then ramped like `w_match`.
        w_match: weight of the unlabeled loss term; multiplied by
            `self.step / (warmup_kimg << 10)` clipped to [0, 1].
        w_rot: weight of the 'losses/rot' term, built only when `w_rot > 0`.
        K: the 'y' placeholder has `K + 1` entries along its second axis.
        use_xe: if true the unlabeled loss is a softmax cross-entropy,
            otherwise a squared error against the softmax output.
        warmup_kimg: sets the length of the weight ramps.
        T: forwarded to `self.guess_label`.
        mixmode: argument for the `layers.MixMode` constructor.
        dbuf: third argument of both `layers.PMovingAverage` trackers.
        ema: decay for `tf.train.ExponentialMovingAverage`.
        **kwargs: forwarded to `self.classifier` and `self.guess_label`;
            must contain the key 'redux'.

    Returns:
        EasyDict holding the placeholders (`xt`, `x`, `y`, `label`) and the
        ops `train_op`, `classify_op` and `classify_raw`.
    """
    image_shape = [
        self.dataset.height, self.dataset.width, self.dataset.colors
    ]
    train_x_ph = tf.placeholder(
        tf.float32, [batch] + image_shape, 'xt')  # For training
    eval_x_ph = tf.placeholder(tf.float32, [None] + image_shape, 'x')
    unlabeled_ph = tf.placeholder(
        tf.float32, [batch, K + 1] + image_shape, 'y')
    label_ph = tf.placeholder(tf.int32, [batch], 'labels')

    wd *= lr
    w_match *= tf.clip_by_value(
        tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)
    augment = layers.MixMode(mixmode)

    # Every `next(devices)` hands out the device for one forward pass, so
    # the ORDER of the classifier calls below decides their placement.
    devices = utils.get_gpu()

    def classifier_to_gpu(x, **kw):
        """Run `self.classifier` on the next device and return its logits."""
        device = next(devices)
        with tf.device(device):
            return self.classifier(x, **kw, **kwargs).logits

    def random_rotate(x):
        """Apply one fixed transform per quarter of `x`; return (x, labels).

        Nothing here is random. The four quarters get, in order: no change,
        a flip of axes 1 and 2, a transpose of axes 1/2 followed by a flip
        of axis 1, and the same transpose followed by a flip of axis 2.
        They are labelled 0, 1, 2 and 3 respectively.
        """
        quarter = batch // 4
        plain = x[:2 * quarter]
        swapped = tf.transpose(x[2 * quarter:], [0, 2, 1, 3])
        zeros = np.zeros(quarter, np.int32)
        rot_labels = tf.constant(
            np.concatenate([zeros, zeros + 1, zeros + 2, zeros + 3], axis=0))
        pieces = [
            plain[:quarter],
            plain[quarter:, ::-1, ::-1],
            swapped[:quarter, ::-1],
            swapped[quarter:, :, ::-1],
        ]
        return tf.concat(pieces, axis=0), rot_labels

    # Moving average of the current estimated label distribution
    p_model = layers.PMovingAverage('p_model', self.nclass, dbuf)
    # Rectified distribution (only for plotting)
    p_target = layers.PMovingAverage('p_target', self.nclass, dbuf)
    # Known (or inferred) true unlabeled distribution
    p_data = layers.PData(self.dataset)

    # Optional rotation-prediction loss on the second entry of `y`.
    loss_rot = 0
    if w_rot > 0:
        rot_y, rot_l = random_rotate(unlabeled_ph[:, 1])
        with tf.device(next(devices)):
            rot_embeds = self.classifier(
                rot_y, training=True, **kwargs).embeds
            rot_logits = self.classifier_rot(rot_embeds)
        loss_rot = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=tf.one_hot(rot_l, 4), logits=rot_logits))
        tf.summary.scalar('losses/rot', loss_rot)

    def guess_logits(index):
        """Logits of entry `index` of the 'y' placeholder (training mode)."""
        return classifier_to_gpu(unlabeled_ph[:, index], training=True)

    # With redux == '1st' only entry 0 (plus entry 1 when the KL term is
    # active) gets its own forward pass; entry 0's logits fill the rest.
    if kwargs['redux'] != '1st':
        logits_y = [guess_logits(i) for i in range(K + 1)]
    elif w_kl <= 0:
        logits_y = [guess_logits(0)] * (K + 1)
    else:
        logits_y = [guess_logits(i) for i in range(2)]
        logits_y += logits_y[:1] * (K - 1)

    guess = self.guess_label(logits_y, p_data(), p_model(), T=T, **kwargs)
    guessed_labels = tf.stop_gradient(guess.p_target)

    loss_kl = 0
    if w_kl > 0:
        w_kl *= tf.clip_by_value(
            tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)
        loss_kl = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=guessed_labels[:batch], logits=logits_y[1]))
        tf.summary.scalar('losses/kl', loss_kl)

    onehot_labels = tf.one_hot(label_ph, self.nclass)
    mixed_inputs, mixed_labels = augment(
        [train_x_ph] + [unlabeled_ph[:, i] for i in range(K + 1)],
        [onehot_labels] + tf.split(guessed_labels, K + 1),
        [beta, beta])
    labels_x = mixed_labels[0]
    labels_y = tf.concat(mixed_labels[1:], 0)

    batches = layers.interleave([mixed_inputs[0]] + mixed_inputs[1:], batch)

    # Only update ops created by the LAST classifier call go into post_ops.
    mixed_logits = [
        classifier_to_gpu(part, training=True) for part in batches[:-1]
    ]
    ops_before = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    mixed_logits.append(classifier_to_gpu(batches[-1], training=True))
    post_ops = [
        op for op in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if op not in ops_before
    ]
    mixed_logits = layers.interleave(mixed_logits, batch)
    logits_x = mixed_logits[0]
    logits_y = tf.concat(mixed_logits[1:], 0)

    loss_xe = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x,
                                                   logits=logits_x))
    if use_xe:
        unsup_tag = 'xeu'
        loss_xeu = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=labels_y, logits=logits_y)
    else:
        unsup_tag = 'l2u'
        loss_xeu = tf.square(labels_y - tf.nn.softmax(logits_y))
    loss_xeu = tf.reduce_mean(loss_xeu)
    tf.summary.scalar('losses/xe', loss_xe)
    tf.summary.scalar('losses/%s' % unsup_tag, loss_xeu)
    self.distribution_summary(p_data(), p_model(), p_target())

    averager = tf.train.ExponentialMovingAverage(decay=ema)
    post_ops.append(averager.apply(utils.model_vars()))
    ema_getter = functools.partial(utils.getter_ema, averager)
    post_ops.append(p_model.update(guess.p_model))
    post_ops.append(p_target.update(guess.p_target))
    if p_data.has_update:
        post_ops.append(p_data.update(onehot_labels))

    # Weight decay: multiply each selected variable by (1 - wd).
    for var in utils.model_vars('classify'):
        if 'kernel' in var.name:
            post_ops.append(tf.assign(var, var * (1 - wd)))

    optimizer = tf.train.AdamOptimizer(lr)
    minimize_op = optimizer.minimize(
        loss_xe + w_kl * loss_kl + w_match * loss_xeu + w_rot * loss_rot,
        colocate_gradients_with_ops=True)
    # post_ops are grouped under a control dependency on the optimizer step.
    with tf.control_dependencies([minimize_op]):
        train_op = tf.group(*post_ops)

    # The EMA tower is built before the raw one, so each keeps the device it
    # would get from the round-robin iterator in that order.
    ema_probs = tf.nn.softmax(
        classifier_to_gpu(eval_x_ph, getter=ema_getter, training=False))
    # No EMA, for debugging.
    raw_probs = tf.nn.softmax(classifier_to_gpu(eval_x_ph, training=False))

    return EasyDict(
        xt=train_x_ph,
        x=eval_x_ph,
        y=unlabeled_ph,
        label=label_ph,
        train_op=train_op,
        classify_op=ema_probs,
        classify_raw=raw_probs)
def model(self, batch, lr, wd, ema, beta, w_match, warmup_kimg=1024, nu=2,
          mixmode='xxy.yxy', dbuf=128, **kwargs):
    """Assemble the training/inference graph and return its endpoints.

    Args:
        batch: leading dimension of the 'xt', 'y' and 'labels' placeholders;
            also forwarded to `layers.interleave`.
        lr: learning rate for `tf.train.AdamOptimizer`; also scales `wd`.
        wd: weight-decay factor. After being multiplied by `lr` it shrinks
            every 'classify' variable whose name contains 'kernel'.
        ema: decay for `tf.train.ExponentialMovingAverage`.
        beta: passed as `[beta, beta]` to the `MixMode` augmenter.
        w_match: weight of the squared-error term on the unlabeled inputs;
            multiplied by `self.step / (warmup_kimg << 10)` clipped to
            [0, 1].
        warmup_kimg: sets the length of that ramp (see `w_match`).
        nu: size of the second axis of the 'y' placeholder.
        mixmode: argument for the `MixMode` constructor.
        dbuf: third argument of both `layers.PMovingAverage` trackers.
        **kwargs: forwarded to `self.classifier` and `self.guess_label`.

    Returns:
        EasyDict holding the placeholders (`xt`, `x`, `y`, `label`) and the
        ops `train_op`, `classify_raw` and `classify_op`.
    """
    image_shape = [
        self.dataset.height, self.dataset.width, self.dataset.colors
    ]
    train_x_ph = tf.placeholder(
        tf.float32, [batch] + image_shape, 'xt')  # For training
    eval_x_ph = tf.placeholder(tf.float32, [None] + image_shape, 'x')
    unlabeled_ph = tf.placeholder(
        tf.float32, [batch, nu] + image_shape, 'y')
    label_ph = tf.placeholder(tf.int32, [batch], 'labels')

    wd *= lr
    warmup = tf.clip_by_value(
        tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)
    w_match *= warmup
    augment = MixMode(mixmode)

    def classifier(x, **kw):
        """Call `self.classifier` with the shared kwargs; return `.logits`."""
        return self.classifier(x, **kw, **kwargs).logits

    # Moving average of the current estimated label distribution
    p_model = layers.PMovingAverage('p_model', self.nclass, dbuf)
    # Rectified distribution (only for plotting)
    p_target = layers.PMovingAverage('p_target', self.nclass, dbuf)
    # Known (or inferred) true unlabeled distribution
    p_data = layers.PData(self.dataset)

    # Move the `nu` axis to the front, then fold it into the leading axis.
    unlabeled = tf.reshape(
        tf.transpose(unlabeled_ph, [1, 0, 2, 3, 4]), [-1] + image_shape)
    guess = self.guess_label(
        tf.split(unlabeled, nu), classifier, T=0.5, **kwargs)
    guessed_labels = tf.stop_gradient(guess.p_target)
    onehot_labels = tf.one_hot(label_ph, self.nclass)

    mixed_inputs, mixed_labels = augment(
        [train_x_ph] + tf.split(unlabeled, nu),
        [onehot_labels] + [guessed_labels] * nu,
        [beta, beta])
    labels_x = mixed_labels[0]
    labels_y = tf.concat(mixed_labels[1:], 0)

    batches = layers.interleave([mixed_inputs[0]] + mixed_inputs[1:], batch)

    # Only update ops created by the FIRST classifier call go into post_ops;
    # the calls on the remaining batches happen after the snapshot is taken.
    ops_before = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    logits = [classifier(batches[0], training=True)]
    post_ops = [
        op for op in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if op not in ops_before
    ]
    logits += [classifier(part, training=True) for part in batches[1:]]
    logits = layers.interleave(logits, batch)
    logits_x = logits[0]
    logits_y = tf.concat(logits[1:], 0)

    loss_xe = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x,
                                                   logits=logits_x))
    loss_l2u = tf.reduce_mean(
        tf.square(labels_y - tf.nn.softmax(logits_y)))
    tf.summary.scalar('losses/xe', loss_xe)
    tf.summary.scalar('losses/l2u', loss_l2u)
    self.distribution_summary(p_data(), p_model(), p_target())

    averager = tf.train.ExponentialMovingAverage(decay=ema)
    post_ops.append(averager.apply(utils.model_vars()))
    ema_getter = functools.partial(utils.getter_ema, averager)
    post_ops.append(p_model.update(guess.p_model))
    post_ops.append(p_target.update(guess.p_target))
    if p_data.has_update:
        post_ops.append(p_data.update(onehot_labels))

    # Weight decay: multiply each selected variable by (1 - wd).
    for var in utils.model_vars('classify'):
        if 'kernel' in var.name:
            post_ops.append(tf.assign(var, var * (1 - wd)))

    optimizer = tf.train.AdamOptimizer(lr)
    minimize_op = optimizer.minimize(loss_xe + w_match * loss_l2u,
                                     colocate_gradients_with_ops=True)
    # post_ops are grouped under a control dependency on the optimizer step.
    with tf.control_dependencies([minimize_op]):
        train_op = tf.group(*post_ops)

    # No EMA, for debugging.
    raw_probs = tf.nn.softmax(classifier(eval_x_ph, training=False))
    ema_probs = tf.nn.softmax(
        classifier(eval_x_ph, getter=ema_getter, training=False))

    return EasyDict(
        xt=train_x_ph,
        x=eval_x_ph,
        y=unlabeled_ph,
        label=label_ph,
        train_op=train_op,
        classify_raw=raw_probs,
        classify_op=ema_probs)
def model(self, batch, lr, wd, ema, beta, w_match, warmup_kimg=1024, nu=2,
          mixmode='xxy.yxy', **kwargs):
    """Assemble the training/inference graph and return its endpoints.

    Sub-graphs that several consumers need (the softmax over the unlabeled
    logits, its per-example maximum, the total loss and the EMA inference
    logits) are each built once and shared, rather than being rebuilt for
    every consumer.

    Args:
        batch: forwarded to `layers.interleave` as its second argument.
        lr: learning rate for `tf.train.AdamOptimizer`; also scales `wd`.
        wd: weight-decay factor. After being multiplied by `lr` it shrinks
            every 'classify' variable whose name contains 'kernel'.
        ema: decay for `tf.train.ExponentialMovingAverage`.
        beta: passed as `[beta, beta]` to the `MixMode` augmenter.
        w_match: weight of the unlabeled loss term; multiplied by
            `self.step / (warmup_kimg << 10)` clipped to [0, 1].
        warmup_kimg: sets the length of that ramp (see `w_match`).
        nu: size of the second axis of the 'y' placeholder.
        mixmode: argument for the `MixMode` constructor.
        **kwargs: forwarded to `self.classifier` and `self.guess_label`.

    Returns:
        EasyDict holding the placeholders (`x`, `y`, `label`) and the ops
        `train_op`, `tune_op`, `classify_raw`, `classify_op` and
        `eval_loss_op` (mean softmax cross-entropy of the EMA classifier's
        logits on `x` against the one-hot `label`).
    """
    hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]

    # Create placeholders for the labeled images, unlabeled images,
    # and the ground truth supervised labels respectively.
    x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
    y_in = tf.placeholder(tf.float32, [None, nu] + hwc, 'y')
    l_in = tf.placeholder(tf.int32, [None], 'labels')

    wd *= lr
    w_match *= tf.clip_by_value(
        tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)
    augment = MixMode(mixmode)
    classifier = functools.partial(self.classifier, **kwargs)

    # Move the `nu` axis to the front, then fold it into the leading axis.
    y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)
    guess = self.guess_label(tf.split(y, nu), classifier, T=0.5, **kwargs)
    ly = tf.stop_gradient(guess.p_target)
    lx = tf.one_hot(l_in, self.nclass)

    # Create MixUp examples.
    xy, labels_xy = augment([x_in] + tf.split(y, nu), [lx] + [ly] * nu,
                            [beta, beta])
    x, y = xy[0], xy[1:]
    labels_x, labels_y = labels_xy[0], tf.concat(labels_xy[1:], 0)
    del xy, labels_xy

    # Create batches that represent both labeled and unlabeled batches.
    # For more, see google-research/mixmatch/issues/5.
    batches = layers.interleave([x] + y, batch)

    # Only update ops created by the FIRST classifier call go into post_ops;
    # the calls on the remaining batches happen after the snapshot is taken.
    skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    logits = [classifier(batches[0], training=True)]
    post_ops = [
        v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if v not in skip_ops
    ]
    for batchi in batches[1:]:
        logits.append(classifier(batchi, training=True))
    logits = layers.interleave(logits, batch)
    logits_x = logits[0]
    logits_y = tf.concat(logits[1:], 0)

    # Calculate supervised and unsupervised losses.
    loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x,
                                                         logits=logits_x)
    if FLAGS.tsa != "none":
        print("Using training signal annealing...")
        loss_xe = self.anneal_sup_loss(logits_x, labels_x, loss_xe, self.step)
    else:
        loss_xe = tf.reduce_mean(loss_xe)

    # One softmax over the unlabeled logits serves both the loss and the
    # confidence summaries below.
    unsup_prob = tf.nn.softmax(logits_y, axis=-1)
    loss_l2u = tf.square(labels_y - unsup_prob)
    if FLAGS.percent_mask > 0:
        print("Using percent-based confidence masking...")
        loss_l2u = self.percent_confidence_mask_unsup(
            logits_y, labels_y, loss_l2u)
    else:
        loss_l2u = tf.reduce_mean(loss_l2u)

    # Calculate largest predicted probability for each image (built once).
    max_unsup_prob = tf.reduce_max(unsup_prob, axis=-1)
    tf.summary.scalar('losses/min_unsup_prob', tf.reduce_min(max_unsup_prob))
    tf.summary.scalar('losses/mean_unsup_prob',
                      tf.reduce_mean(max_unsup_prob))
    tf.summary.scalar('losses/max_unsup_prob', tf.reduce_max(max_unsup_prob))

    # Print losses to tensorboard. The same total-loss tensor is minimized.
    loss_total = loss_xe + w_match * loss_l2u
    tf.summary.scalar('losses/xe', loss_xe)
    tf.summary.scalar('losses/l2u', loss_l2u)
    tf.summary.scalar('losses/overall', loss_total)

    # Applying EMA weights to model. Conceptualized by Tarvainen & Valpola, 2017
    # See https://arxiv.org/abs/1703.01780 for more.
    ema = tf.train.ExponentialMovingAverage(decay=ema)
    ema_op = ema.apply(utils.model_vars())
    ema_getter = functools.partial(utils.getter_ema, ema)
    post_ops.append(ema_op)

    # Weight decay: multiply each selected variable by (1 - wd).
    post_ops.extend([
        tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify')
        if 'kernel' in v.name
    ])

    train_op = tf.train.AdamOptimizer(lr).minimize(
        loss_total, colocate_gradients_with_ops=True)
    # post_ops are grouped under a control dependency on the optimizer step.
    with tf.control_dependencies([train_op]):
        train_op = tf.group(*post_ops)

    # Tuning op: only retrain batch norm.
    skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    classifier(batches[0], training=True)
    train_bn = tf.group(*[
        v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if v not in skip_ops
    ])

    # Inference towers: the raw one first, then a single EMA tower whose
    # logits feed both classify_op and eval_loss_op.
    raw_logits = classifier(x_in, training=False)
    ema_logits = classifier(x_in, getter=ema_getter, training=False)

    return EasyDict(
        x=x_in,
        y=y_in,
        label=l_in,
        train_op=train_op,
        tune_op=train_bn,
        classify_raw=tf.nn.softmax(raw_logits),  # No EMA, for debugging.
        classify_op=tf.nn.softmax(ema_logits),
        eval_loss_op=tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(
                logits=ema_logits, labels=tf.one_hot(l_in, self.nclass))))
def model(self, nu, w_match, warmup_kimg, batch, lr, wd, ema, dbuf, beta,
          mixmode, logit_norm, **kwargs):
    """Assemble the training/inference graph and return its endpoints.

    Args:
        nu: size of the second axis of the 'y' placeholder.
        w_match: weight of the squared-error term on the unlabeled inputs;
            multiplied by `self.step / (warmup_kimg << 10)` clipped to
            [0, 1].
        warmup_kimg: sets the length of that ramp (see `w_match`).
        batch: forwarded to `layers.interleave` as its second argument.
        lr: learning rate for `tf.train.AdamOptimizer`; also scales `wd`.
        wd: weight-decay factor. After being multiplied by `lr` it shrinks
            every 'classify' variable whose name contains 'kernel'.
        ema: decay for `tf.train.ExponentialMovingAverage`.
        dbuf: third argument of both `layers.PMovingAverage` trackers.
        beta: passed as `[beta, beta]` to the `MixMode` augmenter.
        mixmode: argument for the `MixMode` constructor.
        logit_norm: default for the nested classifier's `logit_norm` flag;
            the two inference ops always pass `logit_norm=False`.
        **kwargs: forwarded to `self.classifier` and `self.guess_label`.

    Returns:
        utils.EasyDict holding the placeholders (`x`, `y`, `label`,
        `label_index_input`), the `label_index` variable and the ops
        `train_op`, `tune_op`, `classify_raw`, `classify_op`,
        `embedding_op` and `update_label_index`.
    """

    def classifier(x, logit_norm=logit_norm, **kw):
        """Return output [0] of `self.classifier`, optionally rescaled."""
        raw = self.classifier(x, **kw, **kwargs)[0]
        if logit_norm:
            # Divide by the root-mean-square taken over all entries; the
            # 1e-8 keeps rsqrt finite when every entry is zero.
            return raw * tf.rsqrt(tf.reduce_mean(tf.square(raw)) + 1e-8)
        return raw

    def embedding(x, **kw):
        """Return output [1] of `self.classifier`."""
        outputs = self.classifier(x, **kw, **kwargs)
        return outputs[1]

    # Non-trainable int32 variable plus a placeholder/assign pair that lets
    # the caller overwrite it.
    initial_index = utils.idx_to_fixlen(self.dataset.labeled_indices,
                                        self.dataset.ntrain)
    label_index = tf.Variable(initial_index,
                              trainable=False,
                              name='label_index',
                              dtype=tf.int32)
    label_index_input = tf.placeholder(tf.int32, self.dataset.ntrain,
                                       'label_index_input')
    update_label_index = tf.assign(label_index, label_index_input)

    image_shape = [
        self.dataset.height, self.dataset.width, self.dataset.colors
    ]
    labeled_ph = tf.placeholder(tf.float32, [None] + image_shape, 'x')
    unlabeled_ph = tf.placeholder(tf.float32, [None, nu] + image_shape, 'y')
    label_ph = tf.placeholder(tf.int32, [None], 'labels')

    wd *= lr
    warmup = tf.clip_by_value(
        tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)
    w_match *= warmup
    augment = MixMode(mixmode)

    # Moving average of the current estimated label distribution
    p_model = layers.PMovingAverage('p_model', self.nclass, dbuf)
    # Rectified distribution (only for plotting)
    p_target = layers.PMovingAverage('p_target', self.nclass, dbuf)
    # Known (or inferred) true unlabeled distribution
    p_data = layers.PData(self.dataset)

    def fresh_update_ops(seen):
        """Return the UPDATE_OPS entries that are not already in `seen`."""
        return [
            op for op in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            if op not in seen
        ]

    # Move the `nu` axis to the front, then fold it into the leading axis.
    unlabeled = tf.reshape(
        tf.transpose(unlabeled_ph, [1, 0, 2, 3, 4]), [-1] + image_shape)
    guess = self.guess_label(tf.split(unlabeled, nu), classifier, p_data(),
                             p_model(), **kwargs)
    guessed_labels = tf.stop_gradient(guess.p_target)
    onehot_labels = tf.one_hot(label_ph, self.nclass)

    mixed_inputs, mixed_labels = augment(
        [labeled_ph] + tf.split(unlabeled, nu),
        [onehot_labels] + [guessed_labels] * nu,
        [beta, beta])
    labels_x = mixed_labels[0]
    labels_y = tf.concat(mixed_labels[1:], 0)

    batches = layers.interleave([mixed_inputs[0]] + mixed_inputs[1:], batch)

    # Only update ops created by the FIRST classifier call go into post_ops;
    # the calls on the remaining batches happen after the snapshot is taken.
    ops_before = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    logits = [classifier(batches[0], training=True)]
    post_ops = fresh_update_ops(ops_before)
    logits += [classifier(part, training=True) for part in batches[1:]]
    logits = layers.interleave(logits, batch)
    logits_x = logits[0]
    logits_y = tf.concat(logits[1:], 0)

    loss_xe = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x,
                                                   logits=logits_x))
    loss_l2u = tf.reduce_mean(
        tf.square(labels_y - tf.nn.softmax(logits_y)))
    tf.summary.scalar('losses/xe', loss_xe)
    tf.summary.scalar('losses/l2u', loss_l2u)
    self.distribution_summary(p_data(), p_model(), p_target())

    averager = tf.train.ExponentialMovingAverage(decay=ema)
    post_ops.append(averager.apply(utils.model_vars()))
    ema_getter = functools.partial(utils.getter_ema, averager)
    post_ops.append(p_model.update(guess.p_model))
    post_ops.append(p_target.update(guess.p_target))
    if p_data.has_update:
        post_ops.append(p_data.update(onehot_labels))

    # Weight decay: multiply each selected variable by (1 - wd).
    for var in utils.model_vars('classify'):
        if 'kernel' in var.name:
            post_ops.append(tf.assign(var, var * (1 - wd)))

    optimizer = tf.train.AdamOptimizer(lr)
    minimize_op = optimizer.minimize(loss_xe + w_match * loss_l2u,
                                     colocate_gradients_with_ops=True)
    # post_ops are grouped under a control dependency on the optimizer step.
    with tf.control_dependencies([minimize_op]):
        train_op = tf.group(*post_ops)

    # Tuning op: only retrain batch norm.
    ops_before = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    classifier(batches[0], training=True)
    tune_op = tf.group(*fresh_update_ops(ops_before))

    # No EMA, for debugging.
    raw_probs = tf.nn.softmax(
        classifier(labeled_ph, logit_norm=False, training=False))
    ema_probs = tf.nn.softmax(
        classifier(labeled_ph,
                   logit_norm=False,
                   getter=ema_getter,
                   training=False))
    embedding_op = embedding(labeled_ph, training=False)

    return utils.EasyDict(
        x=labeled_ph,
        y=unlabeled_ph,
        label=label_ph,
        train_op=train_op,
        tune_op=tune_op,
        classify_raw=raw_probs,
        classify_op=ema_probs,
        embedding_op=embedding_op,
        label_index=label_index,
        update_label_index=update_label_index,
        label_index_input=label_index_input,
    )