Example #1
    def model(self, batch, lr, wd, wu, confidence, uratio, ema=0.999, **kwargs):
        hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
        xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt')  # For training
        x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
        y_in = tf.placeholder(tf.float32, [batch * uratio, 2] + hwc, 'y')
        l_in = tf.placeholder(tf.int32, [batch], 'labels')

        lrate = tf.clip_by_value(tf.to_float(self.step) / (FLAGS.train_kimg << 10), 0, 1)
        lr *= tf.cos(lrate * (7 * np.pi) / (2 * 8))
        tf.summary.scalar('monitors/lr', lr)

        classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
        skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        x = utils.interleave(tf.concat([xt_in, y_in[:, 0], y_in[:, 1]], 0), 2 * uratio + 1)
        logits = utils.para_cat(lambda x: classifier(x, training=True), x)
        logits = utils.de_interleave(logits, 2 * uratio + 1)
        post_ops = [v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops]
        logits_x = logits[:batch]
        # Only the strongly-augmented half of the unlabeled logits is trained
        # against pseudo-labels; the weak images are re-run through the EMA
        # teacher below to produce those labels.
        _, logits_strong = tf.split(logits[batch:], 2)
        del logits, skip_ops

        ema = tf.train.ExponentialMovingAverage(decay=ema)
        ema_op = ema.apply(utils.model_vars())
        ema_getter = functools.partial(utils.getter_ema, ema)
        post_ops.append(ema_op)

        loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=tf.one_hot(l_in, self.nclass), logits=logits_x)
        loss_xe = tf.reduce_mean(loss_xe)
        tf.summary.scalar('losses/xe', loss_xe)

        logits_weak_mt = utils.para_cat(lambda x: classifier(x, getter=ema_getter, training=True), y_in[:, 0])
        pseudo_labels = tf.stop_gradient(tf.nn.softmax(logits_weak_mt))
        loss_xeu = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(pseudo_labels, axis=1),
                                                                  logits=logits_strong)
        pseudo_mask = tf.to_float(tf.reduce_max(pseudo_labels, axis=1) >= confidence)
        tf.summary.scalar('monitors/mask', tf.reduce_mean(pseudo_mask))
        loss_xeu = tf.reduce_mean(loss_xeu * pseudo_mask)
        tf.summary.scalar('losses/xeu', loss_xeu)

        loss_wd = sum(tf.nn.l2_loss(v) for v in utils.model_vars('classify') if 'kernel' in v.name)
        tf.summary.scalar('losses/wd', loss_wd)

        train_op = tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True).minimize(
            loss_xe + wu * loss_xeu + wd * loss_wd, colocate_gradients_with_ops=True)
        with tf.control_dependencies([train_op]):
            train_op = tf.group(*post_ops)

        return utils.EasyDict(
            xt=xt_in, x=x_in, y=y_in, label=l_in, train_op=train_op,
            classify_raw=tf.nn.softmax(classifier(x_in, training=False)),  # No EMA, for debugging.
            classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))
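Each of these variants anneals the learning rate with the same cosine schedule (the lrate / tf.cos lines above). A minimal NumPy sketch of that schedule, assuming FLAGS.train_kimg << 10 is the total number of training steps:

import numpy as np

def cosine_lr(step, base_lr, total_steps):
    # Progress ramps linearly from 0 to 1 over training; the learning rate is
    # scaled by cos(progress * 7*pi/16), as in the models above.
    progress = np.clip(step / float(total_steps), 0.0, 1.0)
    return base_lr * np.cos(progress * (7 * np.pi) / (2 * 8))

# With base_lr=0.03, the rate starts at 0.03 and ends near
# 0.03 * cos(7*pi/16) ~= 0.0059; it decays smoothly but never reaches zero.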
Example #2
    def model_val(self, batch, uratio, **kwargs):
        hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
        # validation labeled
        x = tf.placeholder(tf.float32, [batch * uratio] + hwc, 'v')
        l_in = tf.placeholder(tf.int32, [batch * uratio], 'labels')  # Labels
        classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
        logits = utils.para_cat(lambda x: classifier(x, training=True), x)
        loss_xe = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=l_in, logits=logits))  # loss
        grad_val_loss_op = tf.gradients(loss_xe, self.model_params)
        return utils.EasyDict(v=x, l=l_in, grad_val_loss_op=grad_val_loss_op)
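For orientation, a minimal usage sketch of the ops returned by model_val in a TF1 session; sess, val_images and val_labels are assumptions standing in for whatever the surrounding training loop provides:

# Hypothetical usage, not part of the snippet above.
ops = model.model_val(batch=64, uratio=7)
grads = sess.run(ops.grad_val_loss_op,
                 feed_dict={ops.v: val_images, ops.l: val_labels})
# grads holds the gradient of the validation cross-entropy w.r.t. self.model_params.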
Example #3
    def model(self,
              batch,
              lr,
              wd,
              wu,
              confidence,
              uratio,
              ema=0.999,
              **kwargs):
        hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
        xt_in = tf.placeholder(tf.float32, [batch] + hwc,
                               'xt')  # Training labeled
        x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')  # Eval images
        # Training unlabeled (weak, strong)
        y_in = tf.placeholder(tf.float32, [batch * uratio, 2] + hwc, 'y')
        l_in = tf.placeholder(tf.int32, [batch], 'labels')  # Labels
        # weights for unlabeled data
        w_match = tf.placeholder(tf.float32, [batch * uratio], 'w_match')

        lrate = tf.clip_by_value(
            tf.to_float(self.step) / (FLAGS.train_kimg << 10), 0, 1)
        lr *= tf.cos(lrate * (7 * np.pi) / (2 * 8))
        tf.summary.scalar('monitors/lr', lr)

        # Compute logits for xt_in and y_in
        classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
        skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        x = utils.interleave(tf.concat([xt_in, y_in[:, 0], y_in[:, 1]], 0),
                             2 * uratio + 1)
        logits = utils.para_cat(lambda x: classifier(x, training=True), x)
        logits = utils.de_interleave(logits, 2 * uratio + 1)
        post_ops = [
            v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            if v not in skip_ops
        ]
        logits_x = logits[:batch]
        logits_weak, logits_strong = tf.split(logits[batch:], 2)
        del logits, skip_ops

        # Labeled cross-entropy
        l_in_1hot = tf.one_hot(l_in, kwargs['nclass'])
        loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=l_in_1hot,
                                                             logits=logits_x)
        loss_xe = tf.reduce_mean(loss_xe)
        tf.summary.scalar('losses/xe', loss_xe)

        # Pseudo-label cross entropy for unlabeled data
        pseudo_labels = tf.stop_gradient(tf.nn.softmax(logits_weak))
        pseudo_labels_hard = tf.one_hot(tf.argmax(pseudo_labels, axis=1),
                                        kwargs['nclass'])
        loss_xeu_all = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=pseudo_labels_hard, logits=logits_strong)
        pseudo_mask = tf.to_float(
            tf.reduce_max(pseudo_labels, axis=1) >= confidence)
        tf.summary.scalar('monitors/mask', tf.reduce_mean(pseudo_mask))
        tf.summary.scalar('monitors/lambdas', tf.reduce_mean(w_match))
        loss_xeu = tf.reduce_mean(loss_xeu_all * pseudo_mask * w_match)
        tf.summary.scalar('losses/xeu', loss_xeu)

        # L2 regularization
        loss_wd = sum(
            tf.nn.l2_loss(v) for v in utils.model_vars('classify')
            if 'kernel' in v.name)
        tf.summary.scalar('losses/wd', loss_wd)

        # inverse hessian
        self.model_params = self.get_all_params()[-2]
        total_loss = loss_wd + loss_xeu + loss_xe
        hessian = tf.hessians(total_loss, self.model_params)
        # TODO: remove the hard-coded 128
        _dim = 128 * kwargs['nclass']
        hessian = tf.reshape(hessian, [_dim, _dim])
        inv_H_op = tf.linalg.inv(hessian)

        ema = tf.train.ExponentialMovingAverage(decay=ema)
        ema_op = ema.apply(utils.model_vars())
        ema_getter = functools.partial(utils.getter_ema, ema)
        post_ops.append(ema_op)

        train_op = tf.train.MomentumOptimizer(
            lr, 0.9,
            use_nesterov=True).minimize(loss_xe + wu * loss_xeu + wd * loss_wd,
                                        colocate_gradients_with_ops=True)
        with tf.control_dependencies([train_op]):
            train_op = tf.group(*post_ops)

        return utils.EasyDict(xt=xt_in,
                              x=x_in,
                              y=y_in,
                              label=l_in,
                              train_op=train_op,
                              w_match=w_match,
                              inv_H_op=inv_H_op,
                              classify_raw=tf.nn.softmax(
                                  classifier(x_in, training=False)),
                              classify_op=tf.nn.softmax(
                                  classifier(x_in,
                                             getter=ema_getter,
                                             training=False)))
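The unlabeled objective in this variant is FixMatch's confidence-masked pseudo-label loss with an additional per-example weight w_match. A small NumPy sketch of that computation, assuming the weak/strong logits and weights are given:

import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def weighted_pseudo_label_loss(logits_weak, logits_strong, w_match, confidence=0.95):
    p = softmax(logits_weak)                    # teacher distribution (weak view)
    hard = p.argmax(axis=1)                     # hard pseudo-labels
    mask = p.max(axis=1) >= confidence          # keep only confident predictions
    log_q = np.log(softmax(logits_strong))      # student log-probs (strong view)
    xe = -log_q[np.arange(len(hard)), hard]     # per-example cross-entropy
    return (xe * mask * w_match).mean()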
Example #4
    def model(self,
              batch,
              lr,
              wd,
              wu,
              mom,
              delT,
              confidence,
              balance,
              uratio,
              ema=0.999,
              **kwargs):
        hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]

        xt_in = tf.placeholder(tf.float32, [batch] + hwc,
                               'xt')  # Training labeled
        x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')  # Eval images
        y_in = tf.placeholder(tf.float32, [batch * uratio, 2] + hwc,
                              'y')  # Training unlabeled (weak, strong)
        l_in = tf.placeholder(tf.int32, [batch], 'labels')  # Labels

        lrate = tf.clip_by_value(
            tf.to_float(self.step) / (FLAGS.train_kimg << 10), 0, 1)
        lr *= tf.cos(lrate * (7 * np.pi) / (2 * 8))
        tf.summary.scalar('monitors/lr', lr)

        # Compute logits for xt_in and y_in
        classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
        skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        x = utils.interleave(tf.concat([xt_in, y_in[:, 0], y_in[:, 1]], 0),
                             2 * uratio + 1)
        logits = utils.para_cat(lambda x: classifier(x, training=True), x)
        logits = utils.de_interleave(logits, 2 * uratio + 1)
        post_ops = [
            v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            if v not in skip_ops
        ]
        logits_x = logits[:batch]
        logits_weak, logits_strong = tf.split(logits[batch:], 2)
        del logits, skip_ops

        # Labeled cross-entropy
        loss_xe = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=l_in, logits=logits_x)
        loss_xe = tf.reduce_mean(loss_xe)
        tf.summary.scalar('losses/xe', loss_xe)

        # Pseudo-label cross entropy for unlabeled data
        pseudo_labels = tf.stop_gradient(tf.nn.softmax(logits_weak))
        pLabels = tf.math.argmax(pseudo_labels, axis=1)
        loss_xeu = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=pLabels, logits=logits_strong)
        ####################### Modification
        pLabels = tf.cast(pLabels, dtype=tf.float32)
        classes, idx, counts = tf.unique_with_counts(pLabels)
        shape = tf.constant([self.dataset.nclass])
        classes = tf.cast(classes, dtype=tf.int32)
        class_count = tf.scatter_nd(tf.reshape(classes, [tf.size(classes), 1]),
                                    counts, shape)
        print_cc = tf.print("class_count ",
                            class_count,
                            output_stream=sys.stdout)
        class_count = tf.cast(class_count, dtype=tf.float32)
        mxCount = tf.reduce_max(class_count, axis=0)

        if balance > 0:
            pLabels = tf.cast(pLabels, dtype=tf.int32)
            if balance == 1 or balance == 4:
                confidences = tf.zeros_like(pLabels, dtype=tf.float32)
                ratios = 1.0 - tf.math.divide_no_nan(class_count, mxCount)
                ratios = confidence - delT * ratios
                confidences = tf.gather_nd(
                    ratios, tf.reshape(pLabels, [tf.size(pLabels), 1]))
                pseudo_mask = tf.reduce_max(pseudo_labels,
                                            axis=1) >= confidences
            else:
                pseudo_mask = tf.reduce_max(pseudo_labels,
                                            axis=1) >= confidence

            if balance == 3 or balance == 4:
                classes, idx, counts = tf.unique_with_counts(
                    tf.boolean_mask(pLabels, pseudo_mask))
                shape = tf.constant([self.dataset.nclass])
                classes = tf.cast(classes, dtype=tf.int32)
                class_count = tf.scatter_nd(
                    tf.reshape(classes, [tf.size(classes), 1]), counts, shape)
                class_count = tf.cast(class_count, dtype=tf.float32)
            pseudo_mask = tf.cast(pseudo_mask, dtype=tf.float32)

            if balance > 1:
                ratios = tf.math.divide_no_nan(
                    tf.ones_like(class_count, dtype=tf.float32), class_count)
                ratio = tf.gather_nd(
                    ratios, tf.reshape(pLabels, [tf.size(pLabels), 1]))
                # ratio = sum(pseudo_mask) * ratio / sum(ratio)
                Z = tf.reduce_sum(pseudo_mask)
                pseudo_mask = tf.math.multiply(
                    pseudo_mask, tf.cast(ratio, dtype=tf.float32))
                pseudo_mask = tf.math.divide_no_nan(
                    tf.scalar_mul(Z, pseudo_mask), tf.reduce_sum(pseudo_mask))
        else:
            pseudo_mask = tf.cast(
                tf.reduce_max(pseudo_labels, axis=1) >= confidence,
                dtype=tf.float32)


        ###################### End

        # tf.print(" class_count= ", class_count)
        tf.summary.scalar('monitors/mask', tf.reduce_mean(pseudo_mask))
        loss_xeu = tf.reduce_mean(loss_xeu * pseudo_mask)
        tf.summary.scalar('losses/xeu', loss_xeu)

        # L2 regularization
        loss_wd = sum(
            tf.nn.l2_loss(v) for v in utils.model_vars('classify')
            if 'kernel' in v.name)
        tf.summary.scalar('losses/wd', loss_wd)

        ema = tf.train.ExponentialMovingAverage(decay=ema)
        ema_op = ema.apply(utils.model_vars())
        ema_getter = functools.partial(utils.getter_ema, ema)
        post_ops.append(ema_op)

        #        train_op = tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True).minimize(
        train_op = tf.train.MomentumOptimizer(
            lr, mom,
            use_nesterov=True).minimize(loss_xe + wu * loss_xeu + wd * loss_wd,
                                        colocate_gradients_with_ops=True)
        with tf.control_dependencies([train_op]):
            train_op = tf.group(*post_ops)

        return utils.EasyDict(
            xt=xt_in,
            x=x_in,
            y=y_in,
            label=l_in,
            train_op=train_op,
            classify_raw=tf.nn.softmax(classifier(
                x_in, training=False)),  # No EMA, for debugging.
            classify_op=tf.nn.softmax(
                classifier(x_in, getter=ema_getter, training=False)))
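The balance == 1 branch above lowers the confidence threshold for classes that currently receive few pseudo-labels. A NumPy sketch of the per-class thresholds it computes (delT is the maximum reduction):

import numpy as np

def class_balanced_thresholds(pseudo_hard, nclass, confidence, delT):
    # Classes with fewer pseudo-labels than the most frequent class get a
    # proportionally lower threshold; the most frequent class keeps the full
    # confidence value.
    counts = np.bincount(pseudo_hard, minlength=nclass).astype(np.float32)
    ratios = 1.0 - counts / max(counts.max(), 1.0)
    return confidence - delT * ratios

# Example: confidence=0.95, delT=0.2 -> a class with no pseudo-labels yet is
# accepted at 0.75, while the most frequent class still requires 0.95.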
Example #5
    def model(self,
              batch,
              lr,
              wd,
              wu,
              wclr,
              mom,
              confidence,
              balance,
              delT,
              uratio,
              clrratio,
              temperature,
              ema=0.999,
              **kwargs):
        hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
        xt_in = tf.placeholder(tf.float32, [batch] + hwc,
                               'xt')  # Training labeled
        x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')  # Eval images
        y_in = tf.placeholder(tf.float32, [batch * uratio, 2] + hwc,
                              'y')  # Training unlabeled (weak, strong)
        l_in = tf.placeholder(tf.int32, [batch], 'labels')  # Labels
        wclr_in = tf.placeholder(tf.int32, [1], 'wclr')  # wclr

        lrate = tf.clip_by_value(
            tf.to_float(self.step) / (FLAGS.train_kimg << 10), 0, 1)
        lr *= tf.cos(lrate * (7 * np.pi) / (2 * 8))
        tf.summary.scalar('monitors/lr', lr)

        # Compute logits for xt_in and y_in
        classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
        skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        x = utils.interleave(tf.concat([xt_in, y_in[:, 0], y_in[:, 1]], 0),
                             2 * uratio + 1)
        logits = utils.para_cat(lambda x: classifier(x, training=True), x)
        logits = utils.de_interleave(logits, 2 * uratio + 1)
        post_ops = [
            v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            if v not in skip_ops
        ]
        logits_x = logits[:batch]
        logits_weak, logits_strong = tf.split(logits[batch:], 2)
        del logits, skip_ops

        # Labeled cross-entropy
        loss_xe = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=l_in, logits=logits_x)
        loss_xe = tf.reduce_mean(loss_xe)
        tf.summary.scalar('losses/xe', loss_xe)

        # Pseudo-label cross entropy for unlabeled data
        pseudo_labels = tf.stop_gradient(tf.nn.softmax(logits_weak))
        loss_xeu = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.argmax(pseudo_labels, axis=1), logits=logits_strong)
        #        pseudo_mask = tf.to_float(tf.reduce_max(pseudo_labels, axis=1) >= confidence)
        pseudo_mask = self.class_balancing(pseudo_labels, balance, confidence,
                                           delT)
        tf.summary.scalar('monitors/mask', tf.reduce_mean(pseudo_mask))
        loss_xeu = tf.reduce_mean(loss_xeu * pseudo_mask)
        tf.summary.scalar('losses/xeu', loss_xeu)

        ####################### Modification
        # Contrastive loss term
        contrast_loss = 0
        if wclr > 0 and wclr_in == 0:
            ratio = min(uratio, clrratio)
            if FLAGS.clrDataAug == 1:
                preprocess_fn = functools.partial(
                    data_util.preprocess_for_train,
                    height=self.dataset.height,
                    width=self.dataset.width)
                # Re-augment each unlabeled image (weak and strong slots) with
                # the SimCLR-style preprocessing before computing embeddings.
                x = tf.concat([tf.map_fn(preprocess_fn, y_in[:, 0]),
                               tf.map_fn(preprocess_fn, y_in[:, 1])], 0)
                embeds = lambda x, **kw: self.classifier(x, **kw, **kwargs
                                                         ).embeds
                hidden = utils.para_cat(lambda x: embeds(x, training=True), x)
            else:
                embeds = lambda x, **kw: self.classifier(x, **kw, **kwargs
                                                         ).embeds
                hiddens = utils.para_cat(lambda x: embeds(x, training=True), x)
                hiddens = utils.de_interleave(hiddens, 2 * uratio + 1)
                hiddens_weak, hiddens_strong = tf.split(hiddens[batch:], 2, 0)
                hidden = tf.concat([
                    hiddens_weak[:batch * ratio],
                    hiddens_strong[:batch * ratio]
                ],
                                   axis=0)
                del hiddens, hiddens_weak, hiddens_strong

            contrast_loss, _, _ = obj_lib.add_contrastive_loss(
                hidden,
                hidden_norm=True,  # FLAGS.hidden_norm,
                temperature=temperature,
                tpu_context=None)

            tf.summary.scalar('losses/contrast', contrast_loss)
            del embeds, hidden
        ###################### End

        # L2 regularization
        loss_wd = sum(
            tf.nn.l2_loss(v) for v in utils.model_vars('classify')
            if 'kernel' in v.name)
        tf.summary.scalar('losses/wd', loss_wd)

        ema = tf.train.ExponentialMovingAverage(decay=ema)
        ema_op = ema.apply(utils.model_vars())
        ema_getter = functools.partial(utils.getter_ema, ema)
        post_ops.append(ema_op)

        #        train_op = tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True).minimize(
        train_op = tf.train.MomentumOptimizer(
            lr, mom, use_nesterov=True).minimize(
                loss_xe + wu * loss_xeu + wclr * contrast_loss + wd * loss_wd,
                colocate_gradients_with_ops=True)
        with tf.control_dependencies([train_op]):
            train_op = tf.group(*post_ops)

        return utils.EasyDict(
            xt=xt_in,
            x=x_in,
            y=y_in,
            label=l_in,
            wclr=wclr_in,
            train_op=train_op,
            classify_raw=tf.nn.softmax(classifier(
                x_in, training=False)),  # No EMA, for debugging.
            classify_op=tf.nn.softmax(
                classifier(x_in, getter=ema_getter, training=False)))
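obj_lib.add_contrastive_loss is the SimCLR objective; a simplified NumPy rendition of the NT-Xent loss it computes when hidden_norm=True, shown for reference only (the real implementation additionally supports TPU cross-replica negatives via tpu_context):

import numpy as np

def nt_xent(hidden, temperature=0.5):
    # hidden: [2N, d]; rows i and i + N are two augmented views of the same image.
    h = hidden / np.linalg.norm(hidden, axis=1, keepdims=True)
    n = h.shape[0] // 2
    sim = h @ h.T / temperature               # pairwise cosine similarities
    np.fill_diagonal(sim, -np.inf)            # a view is never its own negative
    pos = np.concatenate([np.arange(n, 2 * n), np.arange(n)])
    log_prob = sim - np.log(np.exp(sim).sum(axis=1, keepdims=True))
    return -log_prob[np.arange(2 * n), pos].mean()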
Example #6
    def model(self,
              batch,
              lr,
              wd,
              wu,
              we,
              confidence,
              uratio,
              temperature=1.0,
              tsa='no',
              tsa_pos=10,
              ema=0.999,
              **kwargs):
        hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
        xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt')  # For training
        x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
        y_in = tf.placeholder(tf.float32, [batch * uratio, 2] + hwc, 'y')
        l_in = tf.placeholder(tf.int32, [batch], 'labels')
        l = tf.one_hot(l_in, self.nclass)

        lrate = tf.clip_by_value(
            tf.to_float(self.step) / (FLAGS.train_kimg << 10), 0, 1)
        lr *= tf.cos(lrate * (7 * np.pi) / (2 * 8))
        tf.summary.scalar('monitors/lr', lr)

        # Compute logits for xt_in and y_in
        classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
        skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        x = utils.interleave(tf.concat([xt_in, y_in[:, 0], y_in[:, 1]], 0),
                             2 * uratio + 1)
        logits = utils.para_cat(lambda x: classifier(x, training=True), x)
        logits = utils.de_interleave(logits, 2 * uratio + 1)
        post_ops = [
            v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            if v not in skip_ops
        ]
        logits_x = logits[:batch]
        logits_weak, logits_strong = tf.split(logits[batch:], 2)
        del logits, skip_ops

        # softmax temperature control
        logits_weak_tgt = self.softmax_temperature_controlling(logits_weak,
                                                               T=temperature)
        # generate confidence mask from the un-sharpened weak predictions
        pseudo_labels = tf.stop_gradient(tf.nn.softmax(logits_weak))
        pseudo_mask = self.confidence_based_masking(logits=None,
                                                    p_class=pseudo_labels,
                                                    thresh=confidence)
        tf.summary.scalar('monitors/mask', tf.reduce_mean(pseudo_mask))
        tf.summary.scalar(
            'monitors/conf_weak',
            tf.reduce_mean(tf.reduce_max(tf.nn.softmax(logits_weak), axis=1)))
        tf.summary.scalar(
            'monitors/conf_strong',
            tf.reduce_mean(tf.reduce_max(tf.nn.softmax(logits_strong),
                                         axis=1)))

        kld = self.kl_divergence_from_logits(logits_weak_tgt, logits_strong)
        entropy = self.entropy_from_logits(logits_weak)
        loss_xeu = tf.reduce_mean(kld * pseudo_mask)
        tf.summary.scalar('losses/xeu', loss_xeu)
        loss_ent = tf.reduce_mean(entropy)
        tf.summary.scalar('losses/entropy', loss_ent)

        # supervised loss with TSA
        loss_mask = self.tsa_loss_mask(tsa=tsa,
                                       logits=logits_x,
                                       labels=l,
                                       tsa_pos=tsa_pos)
        loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=l,
                                                             logits=logits_x)
        loss_xe = tf.reduce_sum(loss_xe * loss_mask) / tf.math.maximum(
            tf.reduce_sum(loss_mask), 1.0)
        tf.summary.scalar('losses/xe', loss_xe)
        tf.summary.scalar('losses/mask_sup', tf.reduce_mean(loss_mask))

        # L2 regularization
        loss_wd = sum(
            tf.nn.l2_loss(v) for v in utils.model_vars('classify')
            if 'kernel' in v.name)
        tf.summary.scalar('losses/wd', loss_wd)

        ema = tf.train.ExponentialMovingAverage(decay=ema)
        ema_op = ema.apply(utils.model_vars())
        ema_getter = functools.partial(utils.getter_ema, ema)
        post_ops.append(ema_op)

        train_op = tf.train.MomentumOptimizer(
            lr, 0.9, use_nesterov=True).minimize(
                loss_xe + loss_xeu * wu + loss_ent * we + loss_wd * wd,
                colocate_gradients_with_ops=True)
        with tf.control_dependencies([train_op]):
            train_op = tf.group(*post_ops)

        return utils.EasyDict(
            xt=xt_in,
            x=x_in,
            y=y_in,
            label=l_in,
            train_op=train_op,
            classify_raw=tf.nn.softmax(classifier(
                x_in, training=False)),  # No EMA, for debugging.
            classify_op=tf.nn.softmax(
                classifier(x_in, getter=ema_getter, training=False)))
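Unlike the other examples, this variant uses a UDA-style consistency loss: the target is the weak-view distribution sharpened with a temperature, and the penalty is the KL divergence to the strong-view prediction. The helpers softmax_temperature_controlling and kl_divergence_from_logits are not shown in the snippet; a plausible NumPy reading of the resulting unlabeled loss:

import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def uda_consistency(logits_weak, logits_strong, T=0.4, confidence=0.8):
    target = softmax(logits_weak / T)               # sharpened teacher target
    p_weak = softmax(logits_weak)                   # un-sharpened, for masking
    mask = p_weak.max(axis=1) >= confidence
    kld = (target * (np.log(target + 1e-12) -
                     np.log(softmax(logits_strong) + 1e-12))).sum(axis=1)
    return (kld * mask).mean()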