def model(self, batch, lr, wd, ema, beta, w_match, warmup_kimg=1024, nu=2, mixmode='xxy.yxy', **kwargs):
    hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
    x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
    y_in = tf.placeholder(tf.float32, [None, nu] + hwc, 'y')
    l_in = tf.placeholder(tf.int32, [None], 'labels')
    wd *= lr
    w_match *= tf.clip_by_value(tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)
    augment = MixMode(mixmode)
    classifier = functools.partial(self.classifier, **kwargs)

    y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)
    # Generate guessed labels for the unlabeled batch.
    guess = self.guess_label(tf.split(y, nu), classifier, T=0.5, **kwargs)
    ly = tf.stop_gradient(guess.p_target)
    lx = tf.one_hot(l_in, self.nclass)
    # Apply mixup.
    xy, labels_xy = augment([x_in] + tf.split(y, nu), [lx] + [ly] * nu, [beta, beta])
    x, y = xy[0], xy[1:]
    labels_x, labels_y = labels_xy[0], tf.concat(labels_xy[1:], 0)
    del xy, labels_xy

    batches = layers.interleave([x] + y, batch)
    skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    logits = [classifier(batches[0], training=True)]
    post_ops = [v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops]
    for batchi in batches[1:]:
        logits.append(classifier(batchi, training=True))
    logits = layers.interleave(logits, batch)
    logits_x = logits[0]
    logits_y = tf.concat(logits[1:], 0)

    loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x, logits=logits_x)
    loss_xe = tf.reduce_mean(loss_xe)
    loss_l2u = tf.square(labels_y - tf.nn.softmax(logits_y))
    loss_l2u = tf.reduce_mean(loss_l2u)
    tf.summary.scalar('losses/xe', loss_xe)
    tf.summary.scalar('losses/l2u', loss_l2u)

    ema = tf.train.ExponentialMovingAverage(decay=ema)
    ema_op = ema.apply(utils.model_vars())
    ema_getter = functools.partial(utils.getter_ema, ema)
    post_ops.append(ema_op)
    post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])

    train_op = tf.train.AdamOptimizer(lr).minimize(loss_xe + w_match * loss_l2u, colocate_gradients_with_ops=True)
    with tf.control_dependencies([train_op]):
        train_op = tf.group(*post_ops)

    # Tuning op: only retrain batch norm.
    skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    classifier(batches[0], training=True)
    train_bn = tf.group(*[v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops])

    return EasyDict(
        x=x_in, y=y_in, label=l_in, train_op=train_op, tune_op=train_bn,
        classify_raw=tf.nn.softmax(classifier(x_in, training=False)),  # No EMA, for debugging.
        classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))
def model(self, batch, lr, wd, ema, warmup_pos, consistency_weight, **kwargs):
    hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
    xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt')  # For training
    x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
    y_in = tf.placeholder(tf.float32, [batch, 2] + hwc, 'y')
    l_in = tf.placeholder(tf.int32, [batch], 'labels')
    l = tf.one_hot(l_in, self.nclass)
    wd *= lr
    warmup = tf.clip_by_value(tf.to_float(self.step) / (warmup_pos * (FLAGS.train_kimg << 10)), 0, 1)

    classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
    logits_x = classifier(xt_in, training=True)
    post_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # Take only first call to update batch norm.
    y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)
    y_1, y_2 = tf.split(y, 2)

    ema = tf.train.ExponentialMovingAverage(decay=ema)
    ema_op = ema.apply(utils.model_vars())
    ema_getter = functools.partial(utils.getter_ema, ema)
    logits_y = classifier(y_1, training=True, getter=ema_getter)
    logits_teacher = tf.stop_gradient(logits_y)
    logits_student = classifier(y_2, training=True)
    loss_mt = tf.reduce_mean((tf.nn.softmax(logits_teacher) - tf.nn.softmax(logits_student)) ** 2, -1)
    loss_mt = tf.reduce_mean(loss_mt)

    loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=l, logits=logits_x)
    loss = tf.reduce_mean(loss)
    tf.summary.scalar('losses/xe', loss)
    tf.summary.scalar('losses/mt', loss_mt)

    post_ops.append(ema_op)
    post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])
    train_op = tf.train.AdamOptimizer(lr).minimize(loss + loss_mt * warmup * consistency_weight,
                                                   colocate_gradients_with_ops=True)
    with tf.control_dependencies([train_op]):
        train_op = tf.group(*post_ops)

    return EasyDict(
        xt=xt_in, x=x_in, y=y_in, label=l_in, train_op=train_op,
        classify_raw=tf.nn.softmax(classifier(x_in, training=False)),  # No EMA, for debugging.
        classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))
def model(self, lr, wd, ema, warmup_pos, consistency_weight, threshold, **kwargs):
    hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
    x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
    y_in = tf.placeholder(tf.float32, [None] + hwc, 'y')
    l_in = tf.placeholder(tf.int32, [None], 'labels')
    l = tf.one_hot(l_in, self.nclass)
    wd *= lr
    warmup = tf.clip_by_value(tf.to_float(self.step) / (warmup_pos * (FLAGS.train_kimg << 10)), 0, 1)

    classifier = functools.partial(self.classifier, **kwargs)
    logits_x = classifier(x_in, training=True)
    post_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # Take only first call to update batch norm.
    logits_y = classifier(y_in, training=True)
    # Get the pseudo-label loss
    loss_pl = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.argmax(logits_y, axis=-1), logits=logits_y)
    # Masks denoting which data points have high-confidence predictions
    greater_than_thresh = tf.reduce_any(
        tf.greater(tf.nn.softmax(logits_y), threshold),
        axis=-1,
        keepdims=True,
    )
    greater_than_thresh = tf.cast(greater_than_thresh, loss_pl.dtype)
    # Only enforce the loss when the model is confident
    loss_pl *= greater_than_thresh
    # Note that we also average over examples without confident outputs;
    # this is consistent with the realistic evaluation codebase
    loss_pl = tf.reduce_mean(loss_pl)

    loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=l, logits=logits_x)
    loss = tf.reduce_mean(loss)
    tf.summary.scalar('losses/xe', loss)
    tf.summary.scalar('losses/pl', loss_pl)

    ema = tf.train.ExponentialMovingAverage(decay=ema)
    ema_op = ema.apply(utils.model_vars())
    ema_getter = functools.partial(utils.getter_ema, ema)
    post_ops.append(ema_op)
    post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])

    train_op = tf.train.AdamOptimizer(lr).minimize(loss + loss_pl * warmup * consistency_weight,
                                                   colocate_gradients_with_ops=True)
    with tf.control_dependencies([train_op]):
        train_op = tf.group(*post_ops)

    # Tuning op: only retrain batch norm.
    skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    classifier(x_in, training=True)
    train_bn = tf.group(*[v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops])

    return EasyDict(
        x=x_in, y=y_in, label=l_in, train_op=train_op, tune_op=train_bn,
        classify_raw=tf.nn.softmax(classifier(x_in, training=False)),  # No EMA, for debugging.
        classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))
def model(self, batch, lr, wd, ema, warmup_pos, consistency_weight, tcr_augment, **kwargs):
    hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
    xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt')  # For training
    x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
    y_in = tf.placeholder(tf.float32, [batch, 2] + hwc, 'y')  # The unlabeled data
    l_in = tf.placeholder(tf.int32, [batch], 'labels')
    l = tf.one_hot(l_in, self.nclass)
    warmup = tf.clip_by_value(tf.to_float(self.step) / (warmup_pos * (FLAGS.train_kimg << 10)), 0, 1)
    lrate = tf.clip_by_value(tf.to_float(self.step) / (FLAGS.train_kimg << 10), 0, 1)
    lr *= tf.cos(lrate * (7 * np.pi) / (2 * 8))
    tf.summary.scalar('monitors/lr', lr)

    # Labeled data.
    classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
    logits_x = classifier(xt_in, training=True)
    post_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # Take only first call to update batch norm.

    # Unlabeled data.
    classifier_embedding = lambda x, **kw: self.classifier(x, **kw, **kwargs).embeds
    y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)
    y_delta = self.augment(y, tcr_augment=tcr_augment)  # Apply tcr_augment
    y_1, y_2 = tf.split(y, 2)
    y_1_delta, y_2_delta = tf.split(y_delta, 2)
    embeds_y_1 = classifier_embedding(y_1, training=True)
    embeds_y_1_delta = classifier_embedding(y_1_delta, training=True)
    embeds_y_2 = classifier_embedding(y_2, training=True)
    embeds_y_2_delta = classifier_embedding(y_2_delta, training=True)
    # Consistency between the embedding shifts the augmentation induces on the two views.
    loss_tcr = tf.losses.mean_squared_error(embeds_y_1_delta - embeds_y_1,
                                            embeds_y_2_delta - embeds_y_2)
    loss_tcr = tf.reduce_mean(loss_tcr)
    tf.summary.scalar('losses/xeu', loss_tcr)

    loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=l, logits=logits_x)
    loss = tf.reduce_mean(loss)
    tf.summary.scalar('losses/xe', loss)

    # L2 regularization
    loss_wd = sum(tf.nn.l2_loss(v) for v in utils.model_vars('classify') if 'kernel' in v.name)
    tf.summary.scalar('losses/wd', loss_wd)

    ema = tf.train.ExponentialMovingAverage(decay=ema)
    ema_op = ema.apply(utils.model_vars())
    ema_getter = functools.partial(utils.getter_ema, ema)
    post_ops.append(ema_op)

    train_op = tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True).minimize(
        loss + loss_tcr * warmup * consistency_weight + wd * loss_wd,
        colocate_gradients_with_ops=True)
    with tf.control_dependencies([train_op]):
        train_op = tf.group(*post_ops)

    return EasyDict(
        xt=xt_in, x=x_in, y=y_in, label=l_in, train_op=train_op,
        classify_raw=tf.nn.softmax(classifier(x_in, training=False)),  # No EMA, for debugging.
        classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))
def model(self, batch, lr, wd, ema, **kwargs):
    hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
    xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt')  # For training
    x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
    y_in = tf.placeholder(tf.float32, [batch] + hwc, 'y')
    l_in = tf.placeholder(tf.int32, [batch], 'labels')
    wd *= lr
    classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits

    def get_logits(x):
        logits = classifier(x, training=True)
        return logits

    x, labels_x = self.augment(xt_in, tf.one_hot(l_in, self.nclass), **kwargs)
    logits_x = get_logits(x)
    post_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    y, labels_y = self.augment(y_in, tf.nn.softmax(get_logits(y_in)), **kwargs)
    labels_y = tf.stop_gradient(labels_y)
    logits_y = get_logits(y)

    loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x, logits=logits_x)
    loss_xe = tf.reduce_mean(loss_xe)
    loss_xeu = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_y, logits=logits_y)
    loss_xeu = tf.reduce_mean(loss_xeu)
    tf.summary.scalar('losses/xe', loss_xe)
    tf.summary.scalar('losses/xeu', loss_xeu)

    ema = tf.train.ExponentialMovingAverage(decay=ema)
    ema_op = ema.apply(utils.model_vars())
    ema_getter = functools.partial(utils.getter_ema, ema)
    post_ops.append(ema_op)
    post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])

    train_op = tf.train.AdamOptimizer(lr).minimize(loss_xe + loss_xeu, colocate_gradients_with_ops=True)
    with tf.control_dependencies([train_op]):
        train_op = tf.group(*post_ops)

    return EasyDict(
        xt=xt_in, x=x_in, y=y_in, label=l_in, train_op=train_op,
        classify_raw=tf.nn.softmax(classifier(x_in, training=False)),  # No EMA, for debugging.
        classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))
def model(self, batch, lr, wd, wu, confidence, uratio, ema=0.999, **kwargs):
    hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
    xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt')  # Training labeled
    x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')  # Eval images
    y_in = tf.placeholder(tf.float32, [batch * uratio, 2] + hwc, 'y')  # Training unlabeled (weak, strong)
    l_in = tf.placeholder(tf.int32, [batch], 'labels')  # Labels

    lrate = tf.clip_by_value(tf.to_float(self.step) / (FLAGS.train_kimg << 10), 0, 1)
    lr *= tf.cos(lrate * (7 * np.pi) / (2 * 8))
    tf.summary.scalar('monitors/lr', lr)

    # Compute logits for xt_in and y_in
    classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
    skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    x = utils.interleave(tf.concat([xt_in, y_in[:, 0], y_in[:, 1]], 0), 2 * uratio + 1)
    logits = utils.para_cat(lambda x: classifier(x, training=True), x)
    logits = utils.de_interleave(logits, 2 * uratio + 1)
    post_ops = [v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops]
    logits_x = logits[:batch]
    logits_weak, logits_strong = tf.split(logits[batch:], 2)
    del logits, skip_ops

    # Labeled cross-entropy
    loss_xe = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=l_in, logits=logits_x)
    loss_xe = tf.reduce_mean(loss_xe)
    tf.summary.scalar('losses/xe', loss_xe)

    # Pseudo-label cross entropy for unlabeled data
    pseudo_labels = tf.stop_gradient(tf.nn.softmax(logits_weak))
    loss_xeu = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(pseudo_labels, axis=1),
                                                              logits=logits_strong)
    pseudo_mask = tf.to_float(tf.reduce_max(pseudo_labels, axis=1) >= confidence)
    tf.summary.scalar('monitors/mask', tf.reduce_mean(pseudo_mask))
    loss_xeu = tf.reduce_mean(loss_xeu * pseudo_mask)
    tf.summary.scalar('losses/xeu', loss_xeu)

    # L2 regularization
    loss_wd = sum(tf.nn.l2_loss(v) for v in utils.model_vars('classify') if 'kernel' in v.name)
    tf.summary.scalar('losses/wd', loss_wd)

    ema = tf.train.ExponentialMovingAverage(decay=ema)
    ema_op = ema.apply(utils.model_vars())
    ema_getter = functools.partial(utils.getter_ema, ema)
    post_ops.append(ema_op)

    train_op = tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True).minimize(
        loss_xe + wu * loss_xeu + wd * loss_wd, colocate_gradients_with_ops=True)
    with tf.control_dependencies([train_op]):
        train_op = tf.group(*post_ops)

    return utils.EasyDict(
        xt=xt_in, x=x_in, y=y_in, label=l_in, train_op=train_op,
        classify_raw=tf.nn.softmax(classifier(x_in, training=False)),  # No EMA, for debugging.
        classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))
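# Illustrative usage sketch (not part of the original code): one way the ops returned by the
# function above could be driven from a TF1 session. The `ops` handle, `sess`, and the numpy
# batch variables (labeled_images, labels, unlabeled_pairs, eval_images) are assumptions made
# for this example; only the placeholder and op names come from the returned EasyDict.
#
#   ops = model_instance.ops
#   sess.run(ops.train_op, feed_dict={ops.xt: labeled_images,     # [batch, h, w, c]
#                                     ops.y: unlabeled_pairs,     # [batch * uratio, 2, h, w, c]
#                                     ops.label: labels})         # [batch]
#   probs = sess.run(ops.classify_op, feed_dict={ops.x: eval_images})   # EMA-averaged weights
#   raw = sess.run(ops.classify_raw, feed_dict={ops.x: eval_images})    # no EMA, for debugging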
def model(self, lr, wd, ema, warmup_pos, consistency_weight, beta, **kwargs):
    hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
    x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
    y_in = tf.placeholder(tf.float32, [None, 2] + hwc, 'y')
    l_in = tf.placeholder(tf.int32, [None], 'labels')
    l = tf.one_hot(l_in, self.nclass)
    wd *= lr
    warmup = tf.clip_by_value(tf.to_float(self.step) / (warmup_pos * (FLAGS.train_kimg << 10)), 0, 1)

    y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)
    y_1, y_2 = tf.split(y, 2)
    mix = tf.distributions.Beta(beta, beta).sample([tf.shape(x_in)[0], 1, 1, 1])
    mix = tf.maximum(mix, 1 - mix)

    classifier = functools.partial(self.classifier, **kwargs)
    logits_x = classifier(x_in, training=True)
    post_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # Take only first call to update batch norm.

    ema = tf.train.ExponentialMovingAverage(decay=ema)
    ema_op = ema.apply(utils.model_vars())
    ema_getter = functools.partial(utils.getter_ema, ema)

    logits_teacher = classifier(y_1, training=True, getter=ema_getter)
    labels_teacher = tf.stop_gradient(tf.nn.softmax(logits_teacher))
    labels_teacher = labels_teacher * mix[:, :, 0, 0] + labels_teacher[::-1] * (1 - mix[:, :, 0, 0])
    logits_student = classifier(y_1 * mix + y_1[::-1] * (1 - mix), training=True)
    loss_mt = tf.reduce_mean((labels_teacher - tf.nn.softmax(logits_student)) ** 2, -1)
    loss_mt = tf.reduce_mean(loss_mt)

    loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=l, logits=logits_x)
    loss = tf.reduce_mean(loss)
    tf.summary.scalar('losses/xe', loss)
    tf.summary.scalar('losses/mt', loss_mt)

    post_ops.append(ema_op)
    post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])
    train_op = tf.train.AdamOptimizer(lr).minimize(loss + loss_mt * warmup * consistency_weight,
                                                   colocate_gradients_with_ops=True)
    with tf.control_dependencies([train_op]):
        train_op = tf.group(*post_ops)

    # Tuning op: only retrain batch norm.
    skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    classifier(x_in, training=True)
    train_bn = tf.group(*[v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops])

    return EasyDict(
        x=x_in, y=y_in, label=l_in, train_op=train_op, tune_op=train_bn,
        classify_raw=tf.nn.softmax(classifier(x_in, training=False)),  # No EMA, for debugging.
        classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))
def model(self, lr, wd, ema, **kwargs):
    hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
    x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
    l_in = tf.placeholder(tf.int32, [None], 'labels')
    wd *= lr
    l = tf.one_hot(l_in, self.nclass)

    x, l = self.augment(x_in, l, **kwargs)
    classifier = functools.partial(self.classifier, **kwargs)
    logits = classifier(x, training=True)
    loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=l, logits=logits)
    loss = tf.reduce_mean(loss)
    tf.summary.scalar('losses/xe', loss)

    ema = tf.train.ExponentialMovingAverage(decay=ema)
    ema_op = ema.apply(utils.model_vars())
    ema_getter = functools.partial(utils.getter_ema, ema)
    post_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) + [ema_op]
    post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])

    train_op = tf.train.AdamOptimizer(lr).minimize(loss, colocate_gradients_with_ops=True)
    with tf.control_dependencies([train_op]):
        train_op = tf.group(*post_ops)

    # Tuning op: only retrain batch norm.
    skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    classifier(x_in, training=True)
    train_bn = tf.group(*[v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops])

    return EasyDict(
        x=x_in, label=l_in, train_op=train_op, tune_op=train_bn,
        classify_raw=tf.nn.softmax(classifier(x_in, training=False)),  # No EMA, for debugging.
        classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)),
        eval_loss_op=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
            logits=classifier(x_in, getter=ema_getter, training=False),
            labels=tf.one_hot(l_in, self.nclass))))
def __init__(self, train_dir: str, dataset: data.DataSet, **kwargs):
    self.train_dir = os.path.join(train_dir, self.experiment_name(**kwargs))
    self.params = EasyDict(kwargs)
    self.dataset = dataset
    self.session = None
    self.tmp = EasyDict(print_queue=[], cache=EasyDict())
    self.step = tf.train.get_or_create_global_step()
    self.ops = self.model(**kwargs)
    self.ops.update_step = tf.assign_add(self.step, FLAGS.batch)
    self.add_summaries(**kwargs)

    print(' Config '.center(80, '-'))
    print('train_dir', self.train_dir)
    print('%-32s %s' % ('Model', self.__class__.__name__))
    print('%-32s %s' % ('Dataset', dataset.name))
    for k, v in sorted(kwargs.items()):
        print('%-32s %s' % (k, v))
    print(' Model '.center(80, '-'))
    to_print = [tuple(['%s' % x for x in (v.name, np.prod(v.shape), v.shape)])
                for v in utils.model_vars(None)]
    to_print.append(('Total', str(sum(int(x[1]) for x in to_print)), ''))
    sizes = [max([len(x[i]) for x in to_print]) for i in range(3)]
    fmt = '%%-%ds %%%ds %%%ds' % tuple(sizes)
    for x in to_print[:-1]:
        print(fmt % x)
    print()
    print(fmt % to_print[-1])
    print('-' * 80)
    self._create_initial_files()
def model(self, batch, lr, wd, ema, warmup_pos, vat, vat_eps, entmin_weight, **kwargs):
    hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
    xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt')  # For training
    x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
    y_in = tf.placeholder(tf.float32, [batch] + hwc, 'y')
    l_in = tf.placeholder(tf.int32, [batch], 'labels')
    wd *= lr
    warmup = tf.clip_by_value(tf.to_float(self.step) / (warmup_pos * (FLAGS.train_kimg << 10)), 0, 1)

    classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
    l = tf.one_hot(l_in, self.nclass)
    logits_x = classifier(xt_in, training=True)
    post_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # Take only first call to update batch norm.
    logits_y = classifier(y_in, training=True)
    delta_y = vat_utils.generate_perturbation(y_in, logits_y, lambda x: classifier(x, training=True), vat_eps)
    logits_student = classifier(y_in + delta_y, training=True)
    logits_teacher = tf.stop_gradient(logits_y)
    loss_vat = layers.kl_divergence_from_logits(logits_student, logits_teacher)
    loss_vat = tf.reduce_mean(loss_vat)
    loss_entmin = tf.reduce_mean(tf.distributions.Categorical(logits=logits_y).entropy())

    loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=l, logits=logits_x)
    loss = tf.reduce_mean(loss)
    tf.summary.scalar('losses/xe', loss)
    tf.summary.scalar('losses/vat', loss_vat)
    tf.summary.scalar('losses/entmin', loss_entmin)

    ema = tf.train.ExponentialMovingAverage(decay=ema)
    ema_op = ema.apply(utils.model_vars())
    ema_getter = functools.partial(utils.getter_ema, ema)
    post_ops.append(ema_op)
    post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])

    train_op = tf.train.AdamOptimizer(lr).minimize(loss + loss_vat * warmup * vat + entmin_weight * loss_entmin,
                                                   colocate_gradients_with_ops=True)
    with tf.control_dependencies([train_op]):
        train_op = tf.group(*post_ops)

    return EasyDict(
        xt=xt_in, x=x_in, y=y_in, label=l_in, train_op=train_op,
        classify_raw=tf.nn.softmax(classifier(x_in, training=False)),  # No EMA, for debugging.
        classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))
def __init__(self, train_dir: str, dataset: data.DataSet, **kwargs):
    self.train_dir = train_dir
    self.params = EasyDict(kwargs)
    self.dataset = dataset
    self.session = None
    self.tmp = EasyDict(print_queue=[], cache=EasyDict())
    self.step = tf.train.get_or_create_global_step()
    self.ops = self.model(**kwargs)
    self.ops.update_step = tf.assign_add(self.step, FLAGS.batch)
    self.add_summaries(**kwargs)

    # Initialize accuracies.txt
    if os.path.exists(self.train_dir + "/accuracies.txt"):
        with open(self.train_dir + "/accuracies.txt", 'r') as infile:
            self.accuracies = json.loads(infile.read())
    else:
        self.accuracies = {}

    # Initialize noise.txt
    if os.path.exists(FLAGS.eval_ckpt[:-23] + "/noise.txt"):
        with open(FLAGS.eval_ckpt[:-23] + "/noise.txt", 'r') as infile:
            self.noise = json.loads(infile.read())
    else:
        self.noise = {}

    # Print model config.
    print()
    print(' Config '.center(80, '-'))
    print('train_dir', self.train_dir)
    print('%-32s %s' % ('Model', self.__class__.__name__))
    print('%-32s %s' % ('Dataset', dataset.name))
    for k, v in sorted(kwargs.items()):
        print('%-32s %s' % (k, v))
    print(' Model '.center(80, '-'))
    to_print = [tuple(['%s' % x for x in (v.name, np.prod(v.shape), v.shape)])
                for v in utils.model_vars(None)]
    to_print.append(('Total', str(sum(int(x[1]) for x in to_print)), ''))
    sizes = [max([len(x[i]) for x in to_print]) for i in range(3)]
    fmt = '%%-%ds %%%ds %%%ds' % tuple(sizes)
    for x in to_print[:-1]:
        print(fmt % x)
    print()
    print(fmt % to_print[-1])
    print('-' * 80)

    if not FLAGS.eval_ckpt:
        self._create_initial_files()
    else:
        print('-' * 50)
        print('Evaluation mode')
        print('-' * 50)
def model(self, batch, lr, wd, wu, wclr, mom, confidence, balance, delT, uratio, clrratio, temperature,
          ema=0.999, **kwargs):
    hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
    xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt')  # Training labeled
    x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')  # Eval images
    y_in = tf.placeholder(tf.float32, [batch * uratio, 2] + hwc, 'y')  # Training unlabeled (weak, strong)
    l_in = tf.placeholder(tf.int32, [batch], 'labels')  # Labels
    wclr_in = tf.placeholder(tf.int32, [1], 'wclr')  # wclr

    lrate = tf.clip_by_value(tf.to_float(self.step) / (FLAGS.train_kimg << 10), 0, 1)
    lr *= tf.cos(lrate * (7 * np.pi) / (2 * 8))
    tf.summary.scalar('monitors/lr', lr)

    # Compute logits for xt_in and y_in
    classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
    skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    x = utils.interleave(tf.concat([xt_in, y_in[:, 0], y_in[:, 1]], 0), 2 * uratio + 1)
    logits = utils.para_cat(lambda x: classifier(x, training=True), x)
    logits = utils.de_interleave(logits, 2 * uratio + 1)
    post_ops = [v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops]
    logits_x = logits[:batch]
    logits_weak, logits_strong = tf.split(logits[batch:], 2)
    del logits, skip_ops

    # Labeled cross-entropy
    loss_xe = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=l_in, logits=logits_x)
    loss_xe = tf.reduce_mean(loss_xe)
    tf.summary.scalar('losses/xe', loss_xe)

    # Pseudo-label cross entropy for unlabeled data
    pseudo_labels = tf.stop_gradient(tf.nn.softmax(logits_weak))
    loss_xeu = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(pseudo_labels, axis=1),
                                                              logits=logits_strong)
    # pseudo_mask = tf.to_float(tf.reduce_max(pseudo_labels, axis=1) >= confidence)
    pseudo_mask = self.class_balancing(pseudo_labels, balance, confidence, delT)
    tf.summary.scalar('monitors/mask', tf.reduce_mean(pseudo_mask))
    loss_xeu = tf.reduce_mean(loss_xeu * pseudo_mask)
    tf.summary.scalar('losses/xeu', loss_xeu)

    ####################### Modification
    # Contrastive loss term
    contrast_loss = 0
    if wclr > 0 and wclr_in == 0:
        ratio = min(uratio, clrratio)
        if FLAGS.clrDataAug == 1:
            preprocess_fn = functools.partial(data_util.preprocess_for_train,
                                              height=self.dataset.height,
                                              width=self.dataset.width)
            x = tf.concat([lambda y: preprocess_fn(y), lambda y: preprocess_fn(y)], 0)
            embeds = lambda x, **kw: self.classifier(x, **kw, **kwargs).embeds
            hidden = utils.para_cat(lambda x: embeds(x, training=True), x)
        else:
            embeds = lambda x, **kw: self.classifier(x, **kw, **kwargs).embeds
            hiddens = utils.para_cat(lambda x: embeds(x, training=True), x)
            hiddens = utils.de_interleave(hiddens, 2 * uratio + 1)
            hiddens_weak, hiddens_strong = tf.split(hiddens[batch:], 2, 0)
            hidden = tf.concat([hiddens_weak[:batch * ratio], hiddens_strong[:batch * ratio]], axis=0)
            del hiddens, hiddens_weak, hiddens_strong
        contrast_loss, _, _ = obj_lib.add_contrastive_loss(
            hidden,
            hidden_norm=True,  # FLAGS.hidden_norm,
            temperature=temperature,
            tpu_context=None)
        tf.summary.scalar('losses/contrast', contrast_loss)
        del embeds, hidden
    ###################### End

    # L2 regularization
    loss_wd = sum(tf.nn.l2_loss(v) for v in utils.model_vars('classify') if 'kernel' in v.name)
    tf.summary.scalar('losses/wd', loss_wd)

    ema = tf.train.ExponentialMovingAverage(decay=ema)
    ema_op = ema.apply(utils.model_vars())
    ema_getter = functools.partial(utils.getter_ema, ema)
    post_ops.append(ema_op)

    # train_op = tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True).minimize(
    train_op = tf.train.MomentumOptimizer(lr, mom, use_nesterov=True).minimize(
        loss_xe + wu * loss_xeu + wclr * contrast_loss + wd * loss_wd,
        colocate_gradients_with_ops=True)
    with tf.control_dependencies([train_op]):
        train_op = tf.group(*post_ops)

    return utils.EasyDict(
        xt=xt_in, x=x_in, y=y_in, label=l_in, wclr=wclr_in, train_op=train_op,
        classify_raw=tf.nn.softmax(classifier(x_in, training=False)),  # No EMA, for debugging.
        classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))
def get_all_params(self):
    all_params = []
    for v in utils.model_vars(None):
        temp_tensor = tf.get_default_graph().get_tensor_by_name(v.name)
        all_params.append(temp_tensor)
    return all_params
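# Illustrative usage sketch (not part of the original code): snapshotting the tensors returned
# by get_all_params() into numpy arrays. `self.session` is assumed to be the live tf.Session
# that this class stores elsewhere.
#
#   params = self.get_all_params()
#   values = self.session.run(params)                  # one numpy array per variable
#   snapshot = {p.name: v for p, v in zip(params, values)}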
def model(self, lr, wd, ema, warmup_pos, consistency_weight, **kwargs):
    hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
    x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
    y_in = tf.placeholder(tf.float32, [None] + hwc, 'y')
    e_in = tf.placeholder(tf.float32, [None] + hwc, 'e')
    l_in = tf.placeholder(tf.int32, [None], 'labels')
    l = tf.one_hot(l_in, self.nclass)
    wd *= lr
    warmup = tf.clip_by_value(tf.to_float(self.step) / (warmup_pos * (FLAGS.train_kimg << 10)), 0, 1)

    classifier = functools.partial(self.classifier, **kwargs)
    logits_x = classifier(x_in, training=True)
    post_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # Take only first call to update batch norm.
    logits_y = classifier(y_in, training=True)
    logits_e = classifier(e_in, training=True)
    loss_nst = tf.losses.mean_squared_error(tf.nn.softmax(logits_e), tf.nn.softmax(logits_y))

    loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=l, logits=logits_x)
    loss = tf.reduce_mean(loss)
    tf.summary.scalar('losses/xe', loss)
    tf.summary.scalar('losses/nst', loss_nst)

    ema = tf.train.ExponentialMovingAverage(decay=ema)
    ema_op = ema.apply(utils.model_vars())
    ema_getter = functools.partial(utils.getter_ema, ema)
    post_ops.append(ema_op)
    post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])

    train_op = tf.train.AdamOptimizer(lr).minimize(loss + loss_nst * warmup * consistency_weight,
                                                   colocate_gradients_with_ops=True)
    with tf.control_dependencies([train_op]):
        train_op = tf.group(*post_ops)

    # Tuning op: only retrain batch norm.
    skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    classifier(x_in, training=True)
    train_bn = tf.group(*[v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops])

    return EasyDict(
        x=x_in, y=y_in, e=e_in, label=l_in, train_op=train_op, tune_op=train_bn,
        classify_raw=tf.nn.softmax(classifier(x_in, training=False)),  # No EMA, for debugging.
        classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))
def model(self, nu, w_match, warmup_kimg, batch, lr, wd, ema, dbuf, beta, mixmode, logit_norm, **kwargs):
    def classifier(x, logit_norm=logit_norm, **kw):
        v = self.classifier(x, **kw, **kwargs)[0]
        if not logit_norm:
            return v
        return v * tf.rsqrt(tf.reduce_mean(tf.square(v)) + 1e-8)

    def embedding(x, **kw):
        return self.classifier(x, **kw, **kwargs)[1]

    label_index = tf.Variable(utils.idx_to_fixlen(self.dataset.labeled_indices, self.dataset.ntrain),
                              trainable=False, name='label_index', dtype=tf.int32)
    label_index_input = tf.placeholder(tf.int32, self.dataset.ntrain, 'label_index_input')
    update_label_index = tf.assign(label_index, label_index_input)

    hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
    x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
    y_in = tf.placeholder(tf.float32, [None, nu] + hwc, 'y')
    l_in = tf.placeholder(tf.int32, [None], 'labels')
    wd *= lr
    w_match *= tf.clip_by_value(tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)
    augment = MixMode(mixmode)

    # Moving average of the current estimated label distribution
    p_model = layers.PMovingAverage('p_model', self.nclass, dbuf)
    p_target = layers.PMovingAverage('p_target', self.nclass, dbuf)  # Rectified distribution (only for plotting)
    # Known (or inferred) true unlabeled distribution
    p_data = layers.PData(self.dataset)

    y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)
    guess = self.guess_label(tf.split(y, nu), classifier, p_data(), p_model(), **kwargs)
    ly = tf.stop_gradient(guess.p_target)
    lx = tf.one_hot(l_in, self.nclass)
    xy, labels_xy = augment([x_in] + tf.split(y, nu), [lx] + [ly] * nu, [beta, beta])
    x, y = xy[0], xy[1:]
    labels_x, labels_y = labels_xy[0], tf.concat(labels_xy[1:], 0)
    del xy, labels_xy

    batches = layers.interleave([x] + y, batch)
    skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    logits = [classifier(batches[0], training=True)]
    post_ops = [v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops]
    for batchi in batches[1:]:
        logits.append(classifier(batchi, training=True))
    logits = layers.interleave(logits, batch)
    logits_x = logits[0]
    logits_y = tf.concat(logits[1:], 0)

    loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x, logits=logits_x)
    loss_xe = tf.reduce_mean(loss_xe)
    loss_l2u = tf.square(labels_y - tf.nn.softmax(logits_y))
    loss_l2u = tf.reduce_mean(loss_l2u)
    tf.summary.scalar('losses/xe', loss_xe)
    tf.summary.scalar('losses/l2u', loss_l2u)
    self.distribution_summary(p_data(), p_model(), p_target())

    ema = tf.train.ExponentialMovingAverage(decay=ema)
    ema_op = ema.apply(utils.model_vars())
    ema_getter = functools.partial(utils.getter_ema, ema)
    post_ops.extend([ema_op, p_model.update(guess.p_model), p_target.update(guess.p_target)])
    if p_data.has_update:
        post_ops.append(p_data.update(lx))
    post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])

    train_op = tf.train.AdamOptimizer(lr).minimize(loss_xe + w_match * loss_l2u,
                                                   colocate_gradients_with_ops=True)
    with tf.control_dependencies([train_op]):
        train_op = tf.group(*post_ops)

    # Tuning op: only retrain batch norm.
    skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    classifier(batches[0], training=True)
    train_bn = tf.group(*[v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops])

    return utils.EasyDict(
        x=x_in, y=y_in, label=l_in, train_op=train_op, tune_op=train_bn,
        classify_raw=tf.nn.softmax(classifier(x_in, logit_norm=False, training=False)),  # No EMA, for debugging.
        classify_op=tf.nn.softmax(classifier(x_in, logit_norm=False, getter=ema_getter, training=False)),
        embedding_op=embedding(x_in, training=False),
        label_index=label_index,
        update_label_index=update_label_index,
        label_index_input=label_index_input,
    )
def model(self, batch, lr, wd, ema, beta, w_match, warmup_kimg=1024, nu=2, mixmode='xxy.yxy', **kwargs):
    hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
    # Create placeholders for the labeled images, unlabeled images,
    # and the ground truth supervised labels respectively.
    x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
    y_in = tf.placeholder(tf.float32, [None, nu] + hwc, 'y')
    l_in = tf.placeholder(tf.int32, [None], 'labels')
    wd *= lr
    w_match *= tf.clip_by_value(tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)
    augment = MixMode(mixmode)
    classifier = functools.partial(self.classifier, **kwargs)

    y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)
    guess = self.guess_label(tf.split(y, nu), classifier, T=0.5, **kwargs)
    ly = tf.stop_gradient(guess.p_target)
    lx = tf.one_hot(l_in, self.nclass)

    # Create MixUp examples.
    xy, labels_xy = augment([x_in] + tf.split(y, nu), [lx] + [ly] * nu, [beta, beta])
    x, y = xy[0], xy[1:]
    labels_x, labels_y = labels_xy[0], tf.concat(labels_xy[1:], 0)
    del xy, labels_xy

    # Create batches that represent both labeled and unlabeled batches.
    # For more, see google-research/mixmatch/issues/5.
    batches = layers.interleave([x] + y, batch)
    skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    logits = [classifier(batches[0], training=True)]
    post_ops = [v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops]
    for batchi in batches[1:]:
        logits.append(classifier(batchi, training=True))
    logits = layers.interleave(logits, batch)
    logits_x = logits[0]
    logits_y = tf.concat(logits[1:], 0)

    # Calculate supervised and unsupervised losses.
    loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x, logits=logits_x)
    if FLAGS.tsa != "none":
        print("Using training signal annealing...")
        loss_xe = self.anneal_sup_loss(logits_x, labels_x, loss_xe, self.step)
    else:
        loss_xe = tf.reduce_mean(loss_xe)
    loss_l2u = tf.square(labels_y - tf.nn.softmax(logits_y))
    if FLAGS.percent_mask > 0:
        print("Using percent-based confidence masking...")
        loss_l2u = self.percent_confidence_mask_unsup(logits_y, labels_y, loss_l2u)
    else:
        loss_l2u = tf.reduce_mean(loss_l2u)

    # Calculate largest predicted probability for each image.
    unsup_prob = tf.nn.softmax(logits_y, axis=-1)
    tf.summary.scalar('losses/min_unsup_prob', tf.reduce_min(tf.reduce_max(unsup_prob, axis=-1)))
    tf.summary.scalar('losses/mean_unsup_prob', tf.reduce_mean(tf.reduce_max(unsup_prob, axis=-1)))
    tf.summary.scalar('losses/max_unsup_prob', tf.reduce_max(tf.reduce_max(unsup_prob, axis=-1)))

    # Print losses to tensorboard.
    tf.summary.scalar('losses/xe', loss_xe)
    tf.summary.scalar('losses/l2u', loss_l2u)
    tf.summary.scalar('losses/overall', loss_xe + w_match * loss_l2u)

    # Applying EMA weights to model. Conceptualized by Tarvainen & Valpola, 2017.
    # See https://arxiv.org/abs/1703.01780 for more.
    ema = tf.train.ExponentialMovingAverage(decay=ema)
    ema_op = ema.apply(utils.model_vars())
    ema_getter = functools.partial(utils.getter_ema, ema)
    post_ops.append(ema_op)
    post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])

    train_op = tf.train.AdamOptimizer(lr).minimize(loss_xe + w_match * loss_l2u,
                                                   colocate_gradients_with_ops=True)
    with tf.control_dependencies([train_op]):
        train_op = tf.group(*post_ops)

    # Tuning op: only retrain batch norm.
    skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    classifier(batches[0], training=True)
    train_bn = tf.group(*[v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops])

    return EasyDict(
        x=x_in, y=y_in, label=l_in, train_op=train_op, tune_op=train_bn,
        classify_raw=tf.nn.softmax(classifier(x_in, training=False)),  # No EMA, for debugging.
        classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)),
        eval_loss_op=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
            logits=classifier(x_in, getter=ema_getter, training=False),
            labels=tf.one_hot(l_in, self.nclass))))
def model(self, lr, wd, ema, gamma, **kwargs):
    hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
    x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
    l_in = tf.placeholder(tf.int32, [None], 'labels')
    wd *= lr
    classifier = functools.partial(self.classifier, **kwargs)

    def get_logits(x):
        logits = classifier(x, training=True)
        return logits

    x, labels_x = self.augment(x_in, tf.one_hot(l_in, self.nclass), **kwargs)
    logits_x = get_logits(x)
    loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x, logits=logits_x)  # shape = (batchsize,)
    gradient = tf.gradients(loss_xe, x)[0]  # output is list (batchsize, height, width, colors)
    loss_main = tf.reduce_mean(loss_xe)
    if gamma is None:
        loss_grad = tf.constant(0.0)
    elif gamma > 0:
        loss_grad = gamma * tf.reduce_sum(tf.square(gradient)) / tf.constant(FLAGS.batch, dtype=tf.float32)
    else:
        assert False, 'Check the penalty parameter gamma'
    tf.summary.scalar('losses/main', loss_main)
    tf.summary.scalar('losses/gradient', loss_grad)
    tf.summary.scalar('gradient/max_gradient', tf.reduce_max(tf.abs(gradient)))
    loss_xe = loss_main + loss_grad

    # sup norm of gradients
    sup_gradients = tf.reduce_max(tf.abs(gradient), axis=[1, 2, 3])  # (batchsize,)

    # EMA part
    ema = tf.train.ExponentialMovingAverage(decay=ema)
    ema_op = ema.apply(utils.model_vars())
    ema_getter = functools.partial(utils.getter_ema, ema)
    post_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) + [ema_op]
    post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])

    train_op = tf.train.AdamOptimizer(lr).minimize(loss_xe, colocate_gradients_with_ops=True)
    with tf.control_dependencies([train_op]):
        train_op = tf.group(*post_ops)

    return EasyDict(
        x=x_in, label=l_in, train_op=train_op,
        classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)),
        sup_gradients=sup_gradients)
def model(self, batch, lr, wd, ema, beta, w_match, warmup_kimg=1024, nu=2, mixmode='xxy.yxy', dbuf=128, **kwargs):
    hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
    xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt')  # For training
    x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
    y_in = tf.placeholder(tf.float32, [batch, nu] + hwc, 'y')
    l_in = tf.placeholder(tf.int32, [batch], 'labels')
    wd *= lr
    w_match *= tf.clip_by_value(tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)
    augment = MixMode(mixmode)
    classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits

    # Moving average of the current estimated label distribution
    p_model = layers.PMovingAverage('p_model', self.nclass, dbuf)
    p_target = layers.PMovingAverage('p_target', self.nclass, dbuf)  # Rectified distribution (only for plotting)
    # Known (or inferred) true unlabeled distribution
    p_data = layers.PData(self.dataset)

    y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)
    guess = self.guess_label(tf.split(y, nu), classifier, T=0.5, **kwargs)
    ly = tf.stop_gradient(guess.p_target)
    lx = tf.one_hot(l_in, self.nclass)
    xy, labels_xy = augment([xt_in] + tf.split(y, nu), [lx] + [ly] * nu, [beta, beta])
    x, y = xy[0], xy[1:]
    labels_x, labels_y = labels_xy[0], tf.concat(labels_xy[1:], 0)
    del xy, labels_xy

    batches = layers.interleave([x] + y, batch)
    skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    logits = [classifier(batches[0], training=True)]
    post_ops = [v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops]
    for batchi in batches[1:]:
        logits.append(classifier(batchi, training=True))
    logits = layers.interleave(logits, batch)
    logits_x = logits[0]
    logits_y = tf.concat(logits[1:], 0)

    loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x, logits=logits_x)
    loss_xe = tf.reduce_mean(loss_xe)
    loss_l2u = tf.square(labels_y - tf.nn.softmax(logits_y))
    loss_l2u = tf.reduce_mean(loss_l2u)
    tf.summary.scalar('losses/xe', loss_xe)
    tf.summary.scalar('losses/l2u', loss_l2u)
    self.distribution_summary(p_data(), p_model(), p_target())

    ema = tf.train.ExponentialMovingAverage(decay=ema)
    ema_op = ema.apply(utils.model_vars())
    ema_getter = functools.partial(utils.getter_ema, ema)
    post_ops.extend([ema_op, p_model.update(guess.p_model), p_target.update(guess.p_target)])
    if p_data.has_update:
        post_ops.append(p_data.update(lx))
    post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])

    train_op = tf.train.AdamOptimizer(lr).minimize(loss_xe + w_match * loss_l2u,
                                                   colocate_gradients_with_ops=True)
    with tf.control_dependencies([train_op]):
        train_op = tf.group(*post_ops)

    return EasyDict(
        xt=xt_in, x=x_in, y=y_in, label=l_in, train_op=train_op,
        classify_raw=tf.nn.softmax(classifier(x_in, training=False)),  # No EMA, for debugging.
        classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))
def model(self, batch, lr, wd, wu, mom, delT, confidence, balance, uratio, ema=0.999, **kwargs):
    hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
    xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt')  # Training labeled
    x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')  # Eval images
    y_in = tf.placeholder(tf.float32, [batch * uratio, 2] + hwc, 'y')  # Training unlabeled (weak, strong)
    l_in = tf.placeholder(tf.int32, [batch], 'labels')  # Labels

    lrate = tf.clip_by_value(tf.to_float(self.step) / (FLAGS.train_kimg << 10), 0, 1)
    lr *= tf.cos(lrate * (7 * np.pi) / (2 * 8))
    tf.summary.scalar('monitors/lr', lr)

    # Compute logits for xt_in and y_in
    classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
    skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    x = utils.interleave(tf.concat([xt_in, y_in[:, 0], y_in[:, 1]], 0), 2 * uratio + 1)
    logits = utils.para_cat(lambda x: classifier(x, training=True), x)
    logits = utils.de_interleave(logits, 2 * uratio + 1)
    post_ops = [v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops]
    logits_x = logits[:batch]
    logits_weak, logits_strong = tf.split(logits[batch:], 2)
    del logits, skip_ops

    # Labeled cross-entropy
    loss_xe = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=l_in, logits=logits_x)
    loss_xe = tf.reduce_mean(loss_xe)
    tf.summary.scalar('losses/xe', loss_xe)

    # Pseudo-label cross entropy for unlabeled data
    pseudo_labels = tf.stop_gradient(tf.nn.softmax(logits_weak))
    pLabels = tf.math.argmax(pseudo_labels, axis=1)
    loss_xeu = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=pLabels, logits=logits_strong)

    ####################### Modification
    pLabels = tf.cast(pLabels, dtype=tf.float32)
    classes, idx, counts = tf.unique_with_counts(pLabels)
    shape = tf.constant([self.dataset.nclass])
    classes = tf.cast(classes, dtype=tf.int32)
    class_count = tf.scatter_nd(tf.reshape(classes, [tf.size(classes), 1]), counts, shape)
    print_cc = tf.print("class_count ", class_count, output_stream=sys.stdout)
    class_count = tf.cast(class_count, dtype=tf.float32)
    mxCount = tf.reduce_max(class_count, axis=0)

    if balance > 0:
        pLabels = tf.cast(pLabels, dtype=tf.int32)
        if balance == 1 or balance == 4:
            confidences = tf.zeros_like(pLabels, dtype=tf.float32)
            ratios = 1.0 - tf.math.divide_no_nan(class_count, mxCount)
            ratios = confidence - delT * ratios
            confidences = tf.gather_nd(ratios, tf.reshape(pLabels, [tf.size(pLabels), 1]))
            pseudo_mask = tf.reduce_max(pseudo_labels, axis=1) >= confidences
        else:
            pseudo_mask = tf.reduce_max(pseudo_labels, axis=1) >= confidence
        if balance == 3 or balance == 4:
            classes, idx, counts = tf.unique_with_counts(tf.boolean_mask(pLabels, pseudo_mask))
            shape = tf.constant([self.dataset.nclass])
            classes = tf.cast(classes, dtype=tf.int32)
            class_count = tf.scatter_nd(tf.reshape(classes, [tf.size(classes), 1]), counts, shape)
            class_count = tf.cast(class_count, dtype=tf.float32)
        pseudo_mask = tf.cast(pseudo_mask, dtype=tf.float32)
        if balance > 1:
            ratios = tf.math.divide_no_nan(tf.ones_like(class_count, dtype=tf.float32), class_count)
            ratio = tf.gather_nd(ratios, tf.reshape(pLabels, [tf.size(pLabels), 1]))
            # ratio = sum(pseudo_mask) * ratio / sum(ratio)
            Z = tf.reduce_sum(pseudo_mask)
            pseudo_mask = tf.math.multiply(pseudo_mask, tf.cast(ratio, dtype=tf.float32))
            pseudo_mask = tf.math.divide_no_nan(tf.scalar_mul(Z, pseudo_mask), tf.reduce_sum(pseudo_mask))
    else:
        pseudo_mask = tf.cast(tf.reduce_max(pseudo_labels, axis=1) >= confidence, dtype=tf.float32)
    ###################### End

    # tf.print("class_count= ", class_count)
    tf.summary.scalar('monitors/mask', tf.reduce_mean(pseudo_mask))
    loss_xeu = tf.reduce_mean(loss_xeu * pseudo_mask)
    tf.summary.scalar('losses/xeu', loss_xeu)

    # L2 regularization
    loss_wd = sum(tf.nn.l2_loss(v) for v in utils.model_vars('classify') if 'kernel' in v.name)
    tf.summary.scalar('losses/wd', loss_wd)

    ema = tf.train.ExponentialMovingAverage(decay=ema)
    ema_op = ema.apply(utils.model_vars())
    ema_getter = functools.partial(utils.getter_ema, ema)
    post_ops.append(ema_op)

    # train_op = tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True).minimize(
    train_op = tf.train.MomentumOptimizer(lr, mom, use_nesterov=True).minimize(
        loss_xe + wu * loss_xeu + wd * loss_wd, colocate_gradients_with_ops=True)
    with tf.control_dependencies([train_op]):
        train_op = tf.group(*post_ops)

    return utils.EasyDict(
        xt=xt_in, x=x_in, y=y_in, label=l_in, train_op=train_op,
        classify_raw=tf.nn.softmax(classifier(x_in, training=False)),  # No EMA, for debugging.
        classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))
def model(self, batch, lr, wd, beta, w_kl, w_match, w_rot, K, use_xe, warmup_kimg=1024, T=0.5,
          mixmode='xxy.yxy', dbuf=128, ema=0.999, **kwargs):
    hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
    xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt')  # For training
    x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
    y_in = tf.placeholder(tf.float32, [batch, K + 1] + hwc, 'y')
    l_in = tf.placeholder(tf.int32, [batch], 'labels')
    wd *= lr
    w_match *= tf.clip_by_value(tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)
    augment = layers.MixMode(mixmode)
    gpu = utils.get_gpu()

    def classifier_to_gpu(x, **kw):
        with tf.device(next(gpu)):
            return self.classifier(x, **kw, **kwargs).logits

    def random_rotate(x):
        b4 = batch // 4
        x, xt = x[:2 * b4], tf.transpose(x[2 * b4:], [0, 2, 1, 3])
        l = np.zeros(b4, np.int32)
        l = tf.constant(np.concatenate([l, l + 1, l + 2, l + 3], axis=0))
        return tf.concat([x[:b4], x[b4:, ::-1, ::-1], xt[:b4, ::-1], xt[b4:, :, ::-1]], axis=0), l

    # Moving average of the current estimated label distribution
    p_model = layers.PMovingAverage('p_model', self.nclass, dbuf)
    p_target = layers.PMovingAverage('p_target', self.nclass, dbuf)  # Rectified distribution (only for plotting)
    # Known (or inferred) true unlabeled distribution
    p_data = layers.PData(self.dataset)

    if w_rot > 0:
        rot_y, rot_l = random_rotate(y_in[:, 1])
        with tf.device(next(gpu)):
            rot_logits = self.classifier_rot(self.classifier(rot_y, training=True, **kwargs).embeds)
        loss_rot = tf.nn.softmax_cross_entropy_with_logits_v2(labels=tf.one_hot(rot_l, 4), logits=rot_logits)
        loss_rot = tf.reduce_mean(loss_rot)
        tf.summary.scalar('losses/rot', loss_rot)
    else:
        loss_rot = 0

    if kwargs['redux'] == '1st' and w_kl <= 0:
        logits_y = [classifier_to_gpu(y_in[:, 0], training=True)] * (K + 1)
    elif kwargs['redux'] == '1st':
        logits_y = [classifier_to_gpu(y_in[:, i], training=True) for i in range(2)]
        logits_y += logits_y[:1] * (K - 1)
    else:
        logits_y = [classifier_to_gpu(y_in[:, i], training=True) for i in range(K + 1)]

    guess = self.guess_label(logits_y, p_data(), p_model(), T=T, **kwargs)
    ly = tf.stop_gradient(guess.p_target)

    if w_kl > 0:
        w_kl *= tf.clip_by_value(tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)
        loss_kl = tf.nn.softmax_cross_entropy_with_logits_v2(labels=ly[:batch], logits=logits_y[1])
        loss_kl = tf.reduce_mean(loss_kl)
        tf.summary.scalar('losses/kl', loss_kl)
    else:
        loss_kl = 0
    del logits_y

    lx = tf.one_hot(l_in, self.nclass)
    xy, labels_xy = augment([xt_in] + [y_in[:, i] for i in range(K + 1)], [lx] + tf.split(ly, K + 1),
                            [beta, beta])
    x, y = xy[0], xy[1:]
    labels_x, labels_y = labels_xy[0], tf.concat(labels_xy[1:], 0)
    del xy, labels_xy

    batches = layers.interleave([x] + y, batch)
    logits = [classifier_to_gpu(yi, training=True) for yi in batches[:-1]]
    skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    logits.append(classifier_to_gpu(batches[-1], training=True))
    post_ops = [v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops]
    logits = layers.interleave(logits, batch)
    logits_x = logits[0]
    logits_y = tf.concat(logits[1:], 0)
    del batches, logits

    loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x, logits=logits_x)
    loss_xe = tf.reduce_mean(loss_xe)
    if use_xe:
        loss_xeu = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_y, logits=logits_y)
    else:
        loss_xeu = tf.square(labels_y - tf.nn.softmax(logits_y))
    loss_xeu = tf.reduce_mean(loss_xeu)
    tf.summary.scalar('losses/xe', loss_xe)
    tf.summary.scalar('losses/%s' % ('xeu' if use_xe else 'l2u'), loss_xeu)
    self.distribution_summary(p_data(), p_model(), p_target())

    ema = tf.train.ExponentialMovingAverage(decay=ema)
    ema_op = ema.apply(utils.model_vars())
    ema_getter = functools.partial(utils.getter_ema, ema)
    post_ops.extend([ema_op, p_model.update(guess.p_model), p_target.update(guess.p_target)])
    if p_data.has_update:
        post_ops.append(p_data.update(lx))
    post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])

    train_op = tf.train.AdamOptimizer(lr).minimize(
        loss_xe + w_kl * loss_kl + w_match * loss_xeu + w_rot * loss_rot,
        colocate_gradients_with_ops=True)
    with tf.control_dependencies([train_op]):
        train_op = tf.group(*post_ops)

    return EasyDict(
        xt=xt_in, x=x_in, y=y_in, label=l_in, train_op=train_op,
        classify_op=tf.nn.softmax(classifier_to_gpu(x_in, getter=ema_getter, training=False)),
        classify_raw=tf.nn.softmax(classifier_to_gpu(x_in, training=False)))  # No EMA, for debugging.
def model(self, dataset, scale, blocks, filters, adv_weight, pcp_weight, layer_name, decay_start, decay_stop, lr_decay, **kwargs):
    del kwargs
    x = tf.placeholder(tf.float32, [None, dataset.colors, dataset.height, dataset.width], 'x')
    y = tf.placeholder(tf.float32, [None, dataset.colors, None, None], 'y')
    cur_lr = tf.cond(tf.train.get_global_step() < decay_start,
                     lambda: FLAGS.lr,
                     lambda: tf.train.exponential_decay(FLAGS.lr, tf.train.get_global_step() - decay_start,
                                                        decay_stop - decay_start, lr_decay))

    def tower(real):
        real = layers.to_nhwc(real)
        lores = self.downscale(real, order=layers.NHWC)
        fake = self.sres(lores, dataset.colors, filters, blocks, train=True)
        disc_real = self.disc(real, dataset.width, filters)
        disc_fake = self.disc(fake, dataset.width, filters)
        with tf.variable_scope('VGG', reuse=tf.AUTO_REUSE):
            vgg19 = vgg.Vgg19()
            real_embed = vgg19.build(layer_name, real, channels_last=True) / 1000
            fake_embed = vgg19.build(layer_name, fake, channels_last=True) / 1000

        loss_gmse = tf.losses.mean_squared_error(fake, real)
        loss_gpcp = tf.losses.mean_squared_error(real_embed, fake_embed)
        loss_ggan = tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake, labels=tf.ones_like(disc_fake))
        loss_dreal = tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_real, labels=tf.ones_like(disc_real))
        loss_dfake = tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake, labels=tf.zeros_like(disc_fake))
        return (loss_gmse, loss_gpcp, tf.reduce_mean(loss_ggan),
                tf.reduce_mean(loss_dreal), tf.reduce_mean(loss_dfake))

    loss_gmse, loss_gpcp, loss_ggan, loss_dreal, loss_dfake = utils.para_mean(tower, x)
    loss_disc = loss_dreal + loss_dfake
    loss_gen = loss_gmse + pcp_weight * loss_gpcp + adv_weight * loss_ggan

    utils.HookReport.log_tensor(cur_lr, 'lr')
    utils.HookReport.log_tensor(loss_dreal, 'dreal')
    utils.HookReport.log_tensor(loss_dfake, 'dfake')
    utils.HookReport.log_tensor(tf.sqrt(loss_gpcp / loss_gmse), 'grat')
    utils.HookReport.log_tensor(loss_gmse, 'gmse')
    utils.HookReport.log_tensor(pcp_weight * loss_gpcp, 'gpcp')
    utils.HookReport.log_tensor(adv_weight * loss_ggan, 'ggan')
    utils.HookReport.log_tensor(loss_gen, 'gen')
    utils.HookReport.log_tensor(tf.sqrt(loss_gmse) * 127.5, 'rmse')

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_d = tf.train.AdamOptimizer(cur_lr, 0.9).minimize(
            loss_disc, var_list=utils.model_vars('disc'), colocate_gradients_with_ops=True)
        train_g = tf.train.AdamOptimizer(cur_lr, 0.9).minimize(
            loss_gen, var_list=utils.model_vars('sres'), colocate_gradients_with_ops=True,
            global_step=tf.train.get_global_step())

    def sres_op(y):
        return self.sres(layers.to_nhwc(y), dataset.colors, filters, blocks, train=False)

    return EasyDict(
        x=x, y=y, train_op=tf.group(train_d, train_g),
        sres_op=layers.to_nchw(sres_op(y)),
        eval_op=layers.to_nchw(sres_op(self.downscale(x))))
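# Illustrative usage sketch (not part of the original code): running the super-resolution ops
# returned above. `ops`, `sess`, `lores_batch` and `hires_batch` (NCHW float arrays) are
# assumptions made for this example; sres_op, eval_op and the x/y placeholders come from the
# returned EasyDict.
#
#   upscaled = sess.run(ops.sres_op, feed_dict={ops.y: lores_batch})   # super-resolve a low-res input
#   roundtrip = sess.run(ops.eval_op, feed_dict={ops.x: hires_batch})  # downscale x, then super-resolve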
def model(self, dataset, lod_min, lod_max, lod_start, lod_stop, scale, blocks, filters, filters_min,
          wass_target, weight_avg, mse_weight, noise_dim, ttur, total_steps, **kwargs):
    assert lod_min == 1
    del kwargs
    x = tf.placeholder(tf.float32, [None, dataset.colors, dataset.height, dataset.width], 'x')
    y = tf.placeholder(tf.float32, [None, dataset.colors, None, None], 'y')
    noise = tf.placeholder(tf.float32, [], 'noise')
    lod = tf.placeholder(tf.float32, [], 'lod')
    lfilters = [max(filters_min, filters >> stage) for stage in range(lod_max + 1)]
    disc = functools.partial(self.disc, lod=lod, lod_min=lod_min, lod_start=lod_start, lod_stop=lod_stop,
                             blocks=blocks, lfilters=lfilters)
    sres = functools.partial(self.sres, lod=lod, lod_min=lod_min, lod_start=lod_start, lod_stop=lod_stop,
                             blocks=blocks, lfilters=lfilters, colors=dataset.colors)
    ema = tf.train.ExponentialMovingAverage(decay=weight_avg) if weight_avg > 0 else None

    def pad_shape(x):
        return [tf.shape(x)[0], noise_dim, tf.shape(x)[2], tf.shape(x)[3]]

    def straight_through_round(x, r=127.5 / 4):
        xr = tf.round(x * r) / r
        return tf.stop_gradient(xr - x) + x

    def sres_op(y, noise):
        eps = tf.random_normal(pad_shape(y), stddev=noise)
        sres_op = sres(tf.concat([y, eps], axis=1), ema=ema)
        sres_op = layers.upscale2d(sres_op, 1 << (lod_max - lod_stop))
        return sres_op

    def tower(x):
        lores = self.downscale(x)
        real = layers.downscale2d(x, 1 << (lod_max - lod_stop))
        if lod_start != lod_stop:
            real = layers.blend_resolution(layers.remove_details2d(real), real, lod - lod_start)
        eps = tf.random_normal(pad_shape(lores))
        fake = sres(tf.concat([lores, tf.zeros_like(eps)], axis=1))
        fake_eps = sres(tf.concat([lores, eps], axis=1))
        lores_fake = self.downscale(layers.upscale2d(fake, 1 << (lod_max - lod_stop)))
        lores_fake_eps = self.downscale(layers.upscale2d(fake_eps, 1 << (lod_max - lod_stop)))
        latent_real = disc(real, straight_through_round(tf.abs(lores - lores)))
        latent_fake = disc(fake, straight_through_round(tf.abs(lores - lores_fake)))
        latent_fake_eps = disc(fake_eps, straight_through_round(tf.abs(lores - lores_fake_eps)))

        # Gradient penalty.
        mix = tf.random_uniform([tf.shape(real)[0], 1, 1, 1], 0., 1.)
        mixed = real + mix * (fake_eps - real)
        mixed = layers.upscale2d(mixed, 1 << (lod_max - lod_stop))
        mixed_round = straight_through_round(tf.abs(lores - self.downscale(mixed)))
        mixdown = layers.downscale2d(mixed, 1 << (lod_max - lod_stop))
        grad = tf.gradients(tf.reduce_sum(tf.reduce_mean(disc(mixdown, mixed_round), 1)), [mixed])[0]
        grad_norm = tf.sqrt(tf.reduce_mean(tf.square(grad), axis=[1, 2, 3]) + 1e-8)

        loss_dreal = -tf.reduce_mean(latent_real)
        loss_dfake = tf.reduce_mean(latent_fake_eps)
        loss_gfake = -tf.reduce_mean(latent_fake_eps)
        loss_gmse = tf.losses.mean_squared_error(latent_real, latent_fake)
        loss_gp = 10 * tf.reduce_mean(tf.square(grad_norm - wass_target)) * wass_target ** -2
        mse_ema = tf.losses.mean_squared_error(sres(tf.concat([lores, tf.zeros_like(eps)], axis=1), ema=ema), real)
        return loss_gmse, loss_gfake, loss_dreal, loss_dfake, loss_gp, mse_ema

    loss_gmse, loss_gfake, loss_dreal, loss_dfake, loss_gp, mse_ema = utils.para_mean(tower, x)
    loss_disc = loss_dreal + loss_dfake + loss_gp
    loss_gen = loss_gfake + mse_weight * loss_gmse

    utils.HookReport.log_tensor(loss_dreal, 'dreal')
    utils.HookReport.log_tensor(loss_dfake, 'dfake')
    utils.HookReport.log_tensor(loss_gp, 'gp')
    utils.HookReport.log_tensor(loss_gfake, 'gfake')
    utils.HookReport.log_tensor(loss_gmse, 'gmse')
    utils.HookReport.log_tensor(tf.sqrt(mse_ema) * 127.5, 'rmse_ema')
    utils.HookReport.log_tensor(lod, 'lod')

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_d, train_g = [], []
        global_arg = dict(global_step=tf.train.get_global_step())
        for stage in range(lod_stop + 1):
            g_arg = global_arg if stage == 0 else {}
            with tf.variable_scope('opt_%d' % stage):
                train_d.append(tf.train.AdamOptimizer(FLAGS.lr, 0, 0.99).minimize(
                    loss_disc * ttur, var_list=utils.model_vars('disc/stage_%d' % stage),
                    colocate_gradients_with_ops=True))
                train_g.append(tf.train.AdamOptimizer(FLAGS.lr, 0, 0.99).minimize(
                    loss_gen, var_list=utils.model_vars('sres/stage_%d' % stage),
                    colocate_gradients_with_ops=True, **g_arg))

    if ema is not None:
        ema_op = ema.apply(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'sres'))
        train_op = tf.group(*train_d, *train_g, ema_op)
    else:
        train_op = tf.group(*train_d, *train_g)

    return EasyDict(
        x=x, y=y, noise=noise, lod=lod, train_op=train_op,
        downscale_op=self.downscale(x),
        upscale_op=layers.upscale2d(y, self.scale, order=layers.NCHW),
        sres_op=sres_op(y, noise),
        eval_op=sres_op(self.downscale(x), 0))
def model(self, dataset, scale, blocks, filters, noise, decay_start, decay_stop, lr_decay, **kwargs):
    del kwargs
    x = tf.placeholder(tf.float32, [None, dataset.colors, dataset.height, dataset.width], 'x')
    y = tf.placeholder(tf.float32, [None, dataset.colors, None, None], 'y')
    cur_lr = tf.cond(
        tf.train.get_global_step() < decay_start,
        lambda: FLAGS.lr,
        lambda: tf.train.exponential_decay(FLAGS.lr, tf.train.get_global_step() - decay_start,
                                           decay_stop - decay_start, lr_decay))

    def tower(real):
        real = layers.to_nhwc(real)
        lores = self.downscale(real, order=layers.NHWC)
        fake = self.sres(lores, noise, filters, blocks, train=True)
        disc_real = self.disc(real, lores, dataset.width, filters)
        disc_fake = self.disc(fake, lores, dataset.width, filters)
        loss_dreal = tf.reduce_mean(tf.nn.relu(1 - disc_real))
        loss_dfake = tf.reduce_mean(tf.nn.relu(1 + disc_fake))
        loss_gfake = -tf.reduce_mean(disc_fake)
        mse_ema = tf.losses.mean_squared_error(fake, real)
        return loss_gfake, loss_dreal, loss_dfake, mse_ema

    loss_gfake, loss_dreal, loss_dfake, mse_ema = utils.para_mean(tower, x)
    loss_disc = loss_dreal + loss_dfake
    loss_gen = loss_gfake

    utils.HookReport.log_tensor(cur_lr, 'lr')
    utils.HookReport.log_tensor(loss_dreal, 'dreal')
    utils.HookReport.log_tensor(loss_dfake, 'dfake')
    utils.HookReport.log_tensor(loss_gfake, 'gfake')
    utils.HookReport.log_tensor(tf.sqrt(mse_ema) * 127.5, 'rmse_ema')

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_d = tf.train.AdamOptimizer(cur_lr, 0, 0.9).minimize(
            loss_disc, var_list=utils.model_vars('disc'),
            colocate_gradients_with_ops=True)
        train_g = tf.train.AdamOptimizer(cur_lr, 0, 0.9).minimize(
            loss_gen, var_list=utils.model_vars('sres'),
            colocate_gradients_with_ops=True,
            global_step=tf.train.get_global_step())

    def sres_op(y):
        return self.sres(layers.to_nhwc(y), noise, filters, blocks, train=False)

    return EasyDict(x=x, y=y, train_d=train_d, train_g=train_g,
                    sres_op=layers.to_nchw(sres_op(y)),
                    eval_op=layers.to_nchw(sres_op(self.downscale(x))))
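# A minimal NumPy sketch of the hinge GAN objective used in the tower above:
# the discriminator is penalized only when real scores fall below +1 or fake
# scores rise above -1, while the generator simply pushes fake scores up.
# The score values below are made up for illustration.
import numpy as np

def hinge_d_loss(disc_real, disc_fake):
    return np.mean(np.maximum(0.0, 1.0 - disc_real)) + np.mean(np.maximum(0.0, 1.0 + disc_fake))

def hinge_g_loss(disc_fake):
    return -np.mean(disc_fake)

disc_real = np.array([1.5, 0.2])    # only the 0.2 score is penalized (margin 0.8)
disc_fake = np.array([-2.0, 0.5])   # only the 0.5 score is penalized (margin 1.5)
print(hinge_d_loss(disc_real, disc_fake), hinge_g_loss(disc_fake))  # 1.15 0.75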
def model(self, dataset, scale, blocks, filters, decay_start, decay_stop, lr_decay,
          adv_weight, pcp_weight, layer_name, **kwargs):
    # SRGAN-style baseline: residual generator with pixel-shuffle upsampling,
    # a strided-conv discriminator, and a VGG-feature (perceptual) loss.
    del kwargs
    x = tf.placeholder(tf.float32, [None, dataset.colors, dataset.height, dataset.width], 'x')
    y = tf.placeholder(tf.float32, [None, dataset.colors, None, None], 'y')
    log_scale = utils.ilog2(scale)
    cur_lr = tf.train.exponential_decay(FLAGS.lr, tf.train.get_global_step() - decay_start,
                                        decay_stop - decay_start, lr_decay)
    utils.HookReport.log_tensor(cur_lr, 'lr')

    def sres(x0, train):
        conv_args = dict(padding='same', data_format='channels_first',
                         kernel_initializer=tf.random_normal_initializer(stddev=0.02))
        with tf.variable_scope("sres", reuse=tf.AUTO_REUSE):
            x1 = x = tf.layers.conv2d(x0, filters, 3, activation=tf.nn.relu, **conv_args)
            # Residuals
            for i in range(blocks):
                xb = tf.layers.conv2d(x, filters, 3, **conv_args)
                xb = tf.layers.batch_normalization(xb, axis=1, training=train)
                xb = tf.nn.relu(xb)
                xb = tf.layers.conv2d(xb, filters, 3, **conv_args)
                xb = tf.layers.batch_normalization(xb, axis=1, training=train)
                x += xb
            x = tf.layers.conv2d(x, filters, 3, **conv_args)
            x = tf.layers.batch_normalization(x, axis=1, training=train)
            x += x1
            # Upsampling
            for _ in range(log_scale):
                x = tf.layers.conv2d(x, filters * 4, 3, activation=tf.nn.relu, **conv_args)
                x = layers.channels_to_space(x)
            x = tf.layers.conv2d(x, x0.shape[1], 1, activation=tf.nn.tanh, **conv_args)
            return x

    def disc(x):
        conv_args = dict(padding='same', data_format='channels_first',
                         kernel_initializer=tf.random_normal_initializer(stddev=0.02))
        with tf.variable_scope('disc', reuse=tf.AUTO_REUSE):
            y = tf.layers.conv2d(x, filters, 4, strides=2, activation=tf.nn.leaky_relu, **conv_args)
            y = tf.layers.conv2d(y, filters * 2, 4, strides=2, **conv_args)
            y = tf.nn.leaky_relu(tf.layers.batch_normalization(y, axis=1, training=True))
            y = tf.layers.conv2d(y, filters * 4, 4, strides=2, **conv_args)
            y = tf.nn.leaky_relu(tf.layers.batch_normalization(y, axis=1, training=True))
            y = tf.layers.conv2d(y, filters * 8, 4, strides=2, **conv_args)
            y = tf.nn.leaky_relu(tf.layers.batch_normalization(y, axis=1, training=True))
            # Extra downsampling for larger inputs, then 1x1 convs bring the
            # channel count back to filters * 8 so the residual below matches.
            if dataset.width > 32:
                y = tf.layers.conv2d(y, filters * 16, 4, strides=2, **conv_args)
                y = tf.nn.leaky_relu(tf.layers.batch_normalization(y, axis=1, training=True))
            if dataset.width > 64:
                y = tf.layers.conv2d(y, filters * 32, 4, strides=2, **conv_args)
                y = tf.nn.leaky_relu(tf.layers.batch_normalization(y, axis=1, training=True))
                y = tf.layers.conv2d(y, filters * 16, 1, **conv_args)
                y = tf.nn.leaky_relu(tf.layers.batch_normalization(y, axis=1, training=True))
            if dataset.width > 32:
                y = tf.layers.conv2d(y, filters * 8, 1, **conv_args)
                y = tf.nn.leaky_relu(tf.layers.batch_normalization(y, axis=1, training=True))
            y7 = y
            y = tf.layers.conv2d(y, filters * 2, 1, **conv_args)
            y = tf.nn.leaky_relu(tf.layers.batch_normalization(y, axis=1, training=True))
            y = tf.layers.conv2d(y, filters * 2, 3, **conv_args)
            y = tf.nn.leaky_relu(tf.layers.batch_normalization(y, axis=1, training=True))
            y = tf.layers.conv2d(y, filters * 8, 3, **conv_args)
            y8 = tf.nn.leaky_relu(y7 + tf.layers.batch_normalization(y, axis=1, training=True))
            logits = tf.layers.conv2d(y8, 1, 3, **conv_args)
            return tf.reshape(logits, [-1, 1])

    def tower(real):
        lores = self.downscale(real)
        fake = sres(lores, True)
        disc_real = disc(real)
        disc_fake = disc(fake)
        with tf.variable_scope('VGG', reuse=tf.AUTO_REUSE):
            vgg19 = vgg.Vgg19()
            real_embed = vgg19.build(layer_name, real, channels_last=False)
            fake_embed = vgg19.build(layer_name, fake, channels_last=False)
        loss_gmse = tf.losses.mean_squared_error(fake, real)
        loss_gpcp = tf.losses.mean_squared_error(real_embed, fake_embed)
        loss_ggan = tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake, labels=tf.ones_like(disc_fake))
        loss_dreal = tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_real, labels=tf.ones_like(disc_real))
        loss_dfake = tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake, labels=tf.zeros_like(disc_fake))
        return (loss_gmse, loss_gpcp, tf.reduce_mean(loss_ggan),
                tf.reduce_mean(loss_dreal), tf.reduce_mean(loss_dfake))

    loss_gmse, loss_gpcp, loss_ggan, loss_dreal, loss_dfake = utils.para_mean(tower, x)
    loss_disc = loss_dreal + loss_dfake
    loss_gen = loss_gmse + pcp_weight * loss_gpcp + adv_weight * loss_ggan

    utils.HookReport.log_tensor(loss_dreal, 'dreal')
    utils.HookReport.log_tensor(loss_dfake, 'dfake')
    utils.HookReport.log_tensor(loss_gmse, 'gmse')
    utils.HookReport.log_tensor(pcp_weight * loss_gpcp, 'gpcp')
    utils.HookReport.log_tensor(adv_weight * loss_ggan, 'ggan')
    utils.HookReport.log_tensor(loss_gen, 'gen')
    utils.HookReport.log_tensor(tf.sqrt(loss_gmse) * 127.5, 'rmse')

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_d = tf.train.AdamOptimizer(cur_lr, 0.9).minimize(
            loss_disc, var_list=utils.model_vars('disc'),
            colocate_gradients_with_ops=True)
        train_g = tf.train.AdamOptimizer(cur_lr, 0.9).minimize(
            loss_gen, var_list=utils.model_vars('sres'),
            colocate_gradients_with_ops=True,
            global_step=tf.train.get_global_step())

    return EasyDict(x=x, y=y,
                    sres_op=sres(y, False),
                    eval_op=sres(self.downscale(x), False),
                    train_op=tf.group(train_d, train_g))
def model(self, batch, lr, wd, ema, warmup_pos, consistency_weight, threshold, **kwargs):
    hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
    xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt')  # For training
    x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
    y_in = tf.placeholder(tf.float32, [batch] + hwc, 'y')
    l_in = tf.placeholder(tf.int32, [batch], 'labels')
    l = tf.one_hot(l_in, self.nclass)

    warmup = tf.clip_by_value(
        tf.to_float(self.step) / (warmup_pos * (FLAGS.train_kimg << 10)), 0, 1)
    lrate = tf.clip_by_value(
        tf.to_float(self.step) / (FLAGS.train_kimg << 10), 0, 1)
    lr *= tf.cos(lrate * (7 * np.pi) / (2 * 8))
    tf.summary.scalar('monitors/lr', lr)

    classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
    logits_x = classifier(xt_in, training=True)
    post_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # Take only first call to update batch norm.
    logits_y = classifier(y_in, training=True)

    # Get the pseudo-label loss
    loss_pl = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.argmax(logits_y, axis=-1), logits=logits_y)
    # Masks denoting which data points have high-confidence predictions
    greater_than_thresh = tf.reduce_any(
        tf.greater(tf.nn.softmax(logits_y), threshold),
        axis=-1,
        keepdims=True,
    )
    greater_than_thresh = tf.cast(greater_than_thresh, loss_pl.dtype)
    # Only enforce the loss when the model is confident
    loss_pl *= greater_than_thresh
    # Note that we also average over examples without confident outputs;
    # this is consistent with the realistic evaluation codebase
    loss_pl = tf.reduce_mean(loss_pl)

    loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=l, logits=logits_x)
    loss = tf.reduce_mean(loss)
    tf.summary.scalar('losses/xe', loss)
    tf.summary.scalar('losses/pl', loss_pl)

    # L2 regularization
    loss_wd = sum(
        tf.nn.l2_loss(v)
        for v in utils.model_vars('classify')
        if 'kernel' in v.name)
    tf.summary.scalar('losses/wd', loss_wd)

    ema = tf.train.ExponentialMovingAverage(decay=ema)
    ema_op = ema.apply(utils.model_vars())
    ema_getter = functools.partial(utils.getter_ema, ema)
    post_ops.append(ema_op)

    train_op = tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True).minimize(
        loss + loss_pl * warmup * consistency_weight + wd * loss_wd,
        colocate_gradients_with_ops=True)
    with tf.control_dependencies([train_op]):
        train_op = tf.group(*post_ops)

    return EasyDict(
        xt=xt_in, x=x_in, y=y_in, label=l_in, train_op=train_op,
        classify_raw=tf.nn.softmax(classifier(x_in, training=False)),  # No EMA, for debugging.
        classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))
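# A minimal NumPy sketch of the cosine learning-rate schedule used above: the
# step counter is normalized to [0, 1] over the full run and the base rate is
# scaled by cos(progress * 7*pi/16), so it decays smoothly from 1.0x at the
# start to cos(7*pi/16) ~= 0.195x at the end (it never reaches zero). The
# base_lr and total_steps values are hypothetical; in the code above,
# total_steps corresponds to FLAGS.train_kimg << 10.
import numpy as np

def cosine_lr(step, base_lr=0.03, total_steps=1 << 16):
    progress = np.clip(step / total_steps, 0.0, 1.0)
    return base_lr * np.cos(progress * (7 * np.pi) / (2 * 8))

# e.g. cosine_lr(0) == 0.03, cosine_lr(1 << 15) ~= 0.0232, cosine_lr(1 << 16) ~= 0.0059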
def model(self, batch, lr, wd, wu, we, confidence, uratio,
          temperature=1.0, tsa='no', tsa_pos=10, ema=0.999, **kwargs):
    hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
    xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt')  # For training
    x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
    y_in = tf.placeholder(tf.float32, [batch * uratio, 2] + hwc, 'y')
    l_in = tf.placeholder(tf.int32, [batch], 'labels')
    l = tf.one_hot(l_in, self.nclass)

    lrate = tf.clip_by_value(
        tf.to_float(self.step) / (FLAGS.train_kimg << 10), 0, 1)
    lr *= tf.cos(lrate * (7 * np.pi) / (2 * 8))
    tf.summary.scalar('monitors/lr', lr)

    # Compute logits for xt_in and y_in
    classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
    skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    x = utils.interleave(tf.concat([xt_in, y_in[:, 0], y_in[:, 1]], 0), 2 * uratio + 1)
    logits = utils.para_cat(lambda x: classifier(x, training=True), x)
    logits = utils.de_interleave(logits, 2 * uratio + 1)
    post_ops = [
        v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if v not in skip_ops
    ]
    logits_x = logits[:batch]
    logits_weak, logits_strong = tf.split(logits[batch:], 2)
    del logits, skip_ops

    # softmax temperature control
    logits_weak_tgt = self.softmax_temperature_controlling(logits_weak, T=temperature)
    # generate confidence mask based on sharpened distribution
    pseudo_labels = tf.stop_gradient(tf.nn.softmax(logits_weak))
    pseudo_mask = self.confidence_based_masking(logits=None, p_class=pseudo_labels, thresh=confidence)
    tf.summary.scalar('monitors/mask', tf.reduce_mean(pseudo_mask))
    tf.summary.scalar(
        'monitors/conf_weak',
        tf.reduce_mean(tf.reduce_max(tf.nn.softmax(logits_weak), axis=1)))
    tf.summary.scalar(
        'monitors/conf_strong',
        tf.reduce_mean(tf.reduce_max(tf.nn.softmax(logits_strong), axis=1)))

    kld = self.kl_divergence_from_logits(logits_weak_tgt, logits_strong)
    entropy = self.entropy_from_logits(logits_weak)
    loss_xeu = tf.reduce_mean(kld * pseudo_mask)
    tf.summary.scalar('losses/xeu', loss_xeu)
    loss_ent = tf.reduce_mean(entropy)
    tf.summary.scalar('losses/entropy', loss_ent)

    # supervised loss with TSA
    loss_mask = self.tsa_loss_mask(tsa=tsa, logits=logits_x, labels=l, tsa_pos=tsa_pos)
    loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=l, logits=logits_x)
    loss_xe = tf.reduce_sum(loss_xe * loss_mask) / tf.math.maximum(
        tf.reduce_sum(loss_mask), 1.0)
    tf.summary.scalar('losses/xe', loss_xe)
    tf.summary.scalar('losses/mask_sup', tf.reduce_mean(loss_mask))

    # L2 regularization
    loss_wd = sum(
        tf.nn.l2_loss(v)
        for v in utils.model_vars('classify')
        if 'kernel' in v.name)
    tf.summary.scalar('losses/wd', loss_wd)

    ema = tf.train.ExponentialMovingAverage(decay=ema)
    ema_op = ema.apply(utils.model_vars())
    ema_getter = functools.partial(utils.getter_ema, ema)
    post_ops.append(ema_op)

    train_op = tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True).minimize(
        loss_xe + loss_xeu * wu + loss_ent * we + loss_wd * wd,
        colocate_gradients_with_ops=True)
    with tf.control_dependencies([train_op]):
        train_op = tf.group(*post_ops)

    return utils.EasyDict(
        xt=xt_in, x=x_in, y=y_in, label=l_in, train_op=train_op,
        classify_raw=tf.nn.softmax(classifier(x_in, training=False)),  # No EMA, for debugging.
        classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))
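# A rough NumPy sketch of the interleave / de-interleave pattern above (the
# exact utils.interleave implementation may differ): the labeled batch and the
# weakly / strongly augmented unlabeled batches are concatenated, then reordered
# in a fixed pattern so that any contiguous slice (such as a per-device chunk)
# contains a mix of all streams, which keeps batch-norm statistics representative.
# de_interleave restores the original order so the logits can be split back apart.
import numpy as np

def interleave(x, groups):
    # [groups * n, ...] -> [groups, n, ...] -> [n, groups, ...] -> flat,
    # so consecutive elements come from different groups.
    s = x.shape
    return x.reshape(groups, -1, *s[1:]).transpose(1, 0, *range(2, x.ndim + 1)).reshape(-1, *s[1:])

def de_interleave(x, groups):
    s = x.shape
    return x.reshape(-1, groups, *s[1:]).transpose(1, 0, *range(2, x.ndim + 1)).reshape(-1, *s[1:])

batch, uratio = 4, 2
data = np.arange(batch * (2 * uratio + 1))          # stand-in for the concatenated images/logits
mixed = interleave(data, 2 * uratio + 1)
assert np.array_equal(de_interleave(mixed, 2 * uratio + 1), data)  # round trip restores order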