Example #1
    def __init__(self, train_dir, dataset, **kwargs):
        self.train_dir = os.path.join(train_dir,
                                      self.experiment_name(**kwargs))
        self.params = utils.EasyDict(kwargs)
        self.dataset = dataset
        self.session = None
        self.tmp = utils.EasyDict(print_queue=[], cache=utils.EasyDict())
        self.step = tf.train.get_or_create_global_step()
        self.ops = self.model(**kwargs)
        self.ops.update_step = tf.assign_add(self.step, FLAGS.batch)
        self.add_summaries(**kwargs)

        print(' Config '.center(80, '-'))
        print('train_dir', self.train_dir)
        print('%-32s %s' % ('Model', self.__class__.__name__))
        print('%-32s %s' % ('Dataset', dataset.name))
        for k, v in sorted(kwargs.items()):
            print('%-32s %s' % (k, v))
        print(' Model '.center(80, '-'))
        to_print = [
            tuple(['%s' % x for x in (v.name, np.prod(v.shape), v.shape)])
            for v in utils.model_vars(None)
        ]
        to_print.append(('Total', str(sum(int(x[1]) for x in to_print)), ''))
        sizes = [max([len(x[i]) for x in to_print]) for i in range(3)]
        fmt = '%%-%ds  %%%ds  %%%ds' % tuple(sizes)
        for x in to_print[:-1]:
            print(fmt % x)
        print()
        print(fmt % to_print[-1])
        print('-' * 80)
        self._create_initial_files()
        self.work_unit = None
        self.measurement = {}
Example #2
    def model(self, batch, lr, wd, wu, confidence, uratio, ema=0.999, **kwargs):
        hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
        xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt')  # Training labeled
        x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')  # Eval images
        y_in = tf.placeholder(tf.float32, [batch * uratio, 2] + hwc, 'y')  # Training unlabeled (weak, strong)
        l_in = tf.placeholder(tf.int32, [batch], 'labels')  # Labels

        lrate = tf.clip_by_value(tf.to_float(self.step) / (FLAGS.train_kimg << 10), 0, 1)
        lr *= tf.cos(lrate * (7 * np.pi) / (2 * 8))
        tf.summary.scalar('monitors/lr', lr)

        # Compute logits for xt_in and y_in
        classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
        skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        x = utils.interleave(tf.concat([xt_in, y_in[:, 0], y_in[:, 1]], 0), 2 * uratio + 1)
        logits = utils.para_cat(lambda x: classifier(x, training=True), x)
        logits = utils.de_interleave(logits, 2 * uratio + 1)
        post_ops = [v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops]
        logits_x = logits[:batch]
        logits_weak, logits_strong = tf.split(logits[batch:], 2)
        del logits, skip_ops

        # Labeled cross-entropy
        loss_xe = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=l_in, logits=logits_x)
        loss_xe = tf.reduce_mean(loss_xe)
        tf.summary.scalar('losses/xe', loss_xe)

        # Pseudo-label cross entropy for unlabeled data
        pseudo_labels = tf.stop_gradient(tf.nn.softmax(logits_weak))
        loss_xeu = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(pseudo_labels, axis=1),
                                                                  logits=logits_strong)
        pseudo_mask = tf.to_float(tf.reduce_max(pseudo_labels, axis=1) >= confidence)
        tf.summary.scalar('monitors/mask', tf.reduce_mean(pseudo_mask))
        loss_xeu = tf.reduce_mean(loss_xeu * pseudo_mask)
        tf.summary.scalar('losses/xeu', loss_xeu)

        # L2 regularization
        loss_wd = sum(tf.nn.l2_loss(v) for v in utils.model_vars('classify') if 'kernel' in v.name)
        tf.summary.scalar('losses/wd', loss_wd)

        ema = tf.train.ExponentialMovingAverage(decay=ema)
        ema_op = ema.apply(utils.model_vars())
        ema_getter = functools.partial(utils.getter_ema, ema)
        post_ops.append(ema_op)

        train_op = tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True).minimize(
            loss_xe + wu * loss_xeu + wd * loss_wd, colocate_gradients_with_ops=True)
        with tf.control_dependencies([train_op]):
            train_op = tf.group(*post_ops)

        return utils.EasyDict(
            xt=xt_in, x=x_in, y=y_in, label=l_in, train_op=train_op,
            classify_raw=tf.nn.softmax(classifier(x_in, training=False)),  # No EMA, for debugging.
            classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))
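
A minimal sketch of how the ops returned above could be driven for one training step, assuming a running tf.Session and numpy batches shaped like the placeholders. The loader and session names here (labeled_loader, unlabeled_loader, train_session) are hypothetical, not part of the codebase:

# Hypothetical driver for the EasyDict returned by model(); self.ops and
# self.ops.update_step come from the constructor in Example #1.
x_labeled, labels = labeled_loader.next()     # [batch] + hwc, [batch]
y_weak_strong = unlabeled_loader.next()       # [batch * uratio, 2] + hwc
train_session.run([self.ops.train_op, self.ops.update_step],
                  feed_dict={self.ops.xt: x_labeled,
                             self.ops.y: y_weak_strong,
                             self.ops.label: labels})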
Example #3
 def model_val(self, batch, uratio, **kwargs):
     hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
     # validation labeled
     x = tf.placeholder(tf.float32, [batch * uratio] + hwc, 'v')
     l_in = tf.placeholder(tf.int32, [batch * uratio], 'labels')  # Labels
     classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
     logits = utils.para_cat(lambda x: classifier(x, training=True), x)
     loss_xe = tf.reduce_mean(
         tf.nn.sparse_softmax_cross_entropy_with_logits(
             labels=l_in, logits=logits))  # loss
     grad_val_loss_op = tf.gradients(loss_xe, self.model_params)
     return utils.EasyDict(v=x, l=l_in, grad_val_loss_op=grad_val_loss_op)
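
The gradient op returned by model_val would typically be evaluated on a validation batch; a small usage sketch, where sess, val_images and val_labels are hypothetical names:

# Evaluate the gradient of the validation loss w.r.t. self.model_params on
# one validation batch; shapes follow the placeholders above.
val_ops = self.model_val(batch, uratio, **kwargs)
grads = sess.run(val_ops.grad_val_loss_op,
                 feed_dict={val_ops.v: val_images,   # [batch * uratio] + hwc
                            val_ops.l: val_labels})  # [batch * uratio]
# grads is a list of numpy arrays, one entry per tensor in self.model_params.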
Example #4
 def guess_label(self, y, classifier, p_data, p_model, T, **kwargs):
     del kwargs
     logits_y = [classifier(yi, training=True) for yi in y]
     logits_y = tf.concat(logits_y, 0)
     # Compute predicted probability distribution py.
     p_model_y = tf.reshape(tf.nn.softmax(logits_y),
                            [len(y), -1, self.nclass])
     p_model_y = tf.reduce_mean(p_model_y, axis=0)
     # Compute the target distribution.
     p_target = tf.pow(p_model_y, 1. / T)
     p_target /= tf.reduce_sum(p_target, axis=1, keep_dims=True)
     return utils.EasyDict(p_target=p_target, p_model=p_model_y)
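
The p_target computation above is the standard MixMatch temperature sharpening; a standalone numpy sketch of the same operation:

import numpy as np

def sharpen(p, T):
    # p ** (1/T), renormalized per row, as in guess_label above.
    p = np.power(p, 1.0 / T)
    return p / p.sum(axis=1, keepdims=True)

# T < 1 pushes the averaged prediction towards a one-hot label.
print(sharpen(np.array([[0.6, 0.3, 0.1]]), T=0.5))  # ~[[0.78, 0.20, 0.02]]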
Example #5
 def model_per_ex(self, nclass, batch, confidence, uratio, **kwargs):
     hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
     y_in = tf.placeholder(tf.float32, [batch * uratio, 2] + hwc, 'y')
     # weights for unlabeled data
     w_match = tf.placeholder(tf.float32, [batch * uratio], 'w_match')
     # forward
     classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
     x = tf.concat([y_in[:, 0], y_in[:, 1]], 0)
     logits = classifier(x, training=True)
     logits_weak, logits_strong = tf.split(logits, 2)
     # Pseudo-label cross entropy for unlabeled data
     pseudo_labels = tf.stop_gradient(tf.nn.softmax(logits_weak))
     pseudo_labels_hard = tf.one_hot(tf.argmax(pseudo_labels, axis=1),
                                     nclass)
     loss_xeu_all = tf.nn.softmax_cross_entropy_with_logits_v2(
         labels=pseudo_labels_hard, logits=logits_strong)
     pseudo_mask = tf.to_float(
         tf.reduce_max(pseudo_labels, axis=1) >= confidence)
     loss_xeu = tf.reduce_mean(loss_xeu_all * pseudo_mask * w_match)
     # per-ex-grad_wrt_unlabeled_loss
     grads_train_per_ex = custom_gradient(loss_xeu_all, self.model_params)
     return utils.EasyDict(y=y_in,
                           w_match=w_match,
                           grads_train_per_ex=grads_train_per_ex)
Example #6
    def model(self,
              batch,
              lr,
              wd,
              wu,
              mom,
              delT,
              confidence,
              balance,
              uratio,
              ema=0.999,
              **kwargs):
        hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]

        xt_in = tf.placeholder(tf.float32, [batch] + hwc,
                               'xt')  # Training labeled
        x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')  # Eval images
        y_in = tf.placeholder(tf.float32, [batch * uratio, 2] + hwc,
                              'y')  # Training unlabeled (weak, strong)
        l_in = tf.placeholder(tf.int32, [batch], 'labels')  # Labels

        lrate = tf.clip_by_value(
            tf.to_float(self.step) / (FLAGS.train_kimg << 10), 0, 1)
        lr *= tf.cos(lrate * (7 * np.pi) / (2 * 8))
        tf.summary.scalar('monitors/lr', lr)

        # Compute logits for xt_in and y_in
        classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
        skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        x = utils.interleave(tf.concat([xt_in, y_in[:, 0], y_in[:, 1]], 0),
                             2 * uratio + 1)
        logits = utils.para_cat(lambda x: classifier(x, training=True), x)
        logits = utils.de_interleave(logits, 2 * uratio + 1)
        post_ops = [
            v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            if v not in skip_ops
        ]
        logits_x = logits[:batch]
        logits_weak, logits_strong = tf.split(logits[batch:], 2)
        del logits, skip_ops

        # Labeled cross-entropy
        loss_xe = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=l_in, logits=logits_x)
        loss_xe = tf.reduce_mean(loss_xe)
        tf.summary.scalar('losses/xe', loss_xe)

        # Pseudo-label cross entropy for unlabeled data
        pseudo_labels = tf.stop_gradient(tf.nn.softmax(logits_weak))
        pLabels = tf.math.argmax(pseudo_labels, axis=1)
        loss_xeu = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=pLabels, logits=logits_strong)
        ####################### Modification
        pLabels = tf.cast(pLabels, dtype=tf.float32)
        classes, idx, counts = tf.unique_with_counts(pLabels)
        shape = tf.constant([self.dataset.nclass])
        classes = tf.cast(classes, dtype=tf.int32)
        class_count = tf.scatter_nd(tf.reshape(classes, [tf.size(classes), 1]),
                                    counts, shape)
        # Debug op: prints the per-class pseudo-label counts; it is built here
        # but never wired into the training graph via control dependencies.
        print_cc = tf.print("class_count ",
                            class_count,
                            output_stream=sys.stdout)
        class_count = tf.cast(class_count, dtype=tf.float32)
        mxCount = tf.reduce_max(class_count, axis=0)

        if balance > 0:
            pLabels = tf.cast(pLabels, dtype=tf.int32)
            if balance == 1 or balance == 4:
                confidences = tf.zeros_like(pLabels, dtype=tf.float32)
                ratios = 1.0 - tf.math.divide_no_nan(class_count, mxCount)
                ratios = confidence - delT * ratios
                confidences = tf.gather_nd(
                    ratios, tf.reshape(pLabels, [tf.size(pLabels), 1]))
                pseudo_mask = tf.reduce_max(pseudo_labels,
                                            axis=1) >= confidences
            else:
                pseudo_mask = tf.reduce_max(pseudo_labels,
                                            axis=1) >= confidence

            if balance == 3 or balance == 4:
                classes, idx, counts = tf.unique_with_counts(
                    tf.boolean_mask(pLabels, pseudo_mask))
                shape = tf.constant([self.dataset.nclass])
                classes = tf.cast(classes, dtype=tf.int32)
                class_count = tf.scatter_nd(
                    tf.reshape(classes, [tf.size(classes), 1]), counts, shape)
                class_count = tf.cast(class_count, dtype=tf.float32)
            pseudo_mask = tf.cast(pseudo_mask, dtype=tf.float32)

            if balance > 1:
                ratios = tf.math.divide_no_nan(
                    tf.ones_like(class_count, dtype=tf.float32), class_count)
                ratio = tf.gather_nd(
                    ratios, tf.reshape(pLabels, [tf.size(pLabels), 1]))
                # ratio = sum(pseudo_mask) * ratio / sum(ratio)
                Z = tf.reduce_sum(pseudo_mask)
                pseudo_mask = tf.math.multiply(
                    pseudo_mask, tf.cast(ratio, dtype=tf.float32))
                pseudo_mask = tf.math.divide_no_nan(
                    tf.scalar_mul(Z, pseudo_mask), tf.reduce_sum(pseudo_mask))
        else:
            pseudo_mask = tf.cast(
                tf.reduce_max(pseudo_labels, axis=1) >= confidence,
                dtype=tf.float32)


        ###################### End

        # tf.print(" class_count= ", class_count)
        tf.summary.scalar('monitors/mask', tf.reduce_mean(pseudo_mask))
        loss_xeu = tf.reduce_mean(loss_xeu * pseudo_mask)
        tf.summary.scalar('losses/xeu', loss_xeu)

        # L2 regularization
        loss_wd = sum(
            tf.nn.l2_loss(v) for v in utils.model_vars('classify')
            if 'kernel' in v.name)
        tf.summary.scalar('losses/wd', loss_wd)

        ema = tf.train.ExponentialMovingAverage(decay=ema)
        ema_op = ema.apply(utils.model_vars())
        ema_getter = functools.partial(utils.getter_ema, ema)
        post_ops.append(ema_op)

        #        train_op = tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True).minimize(
        train_op = tf.train.MomentumOptimizer(
            lr, mom,
            use_nesterov=True).minimize(loss_xe + wu * loss_xeu + wd * loss_wd,
                                        colocate_gradients_with_ops=True)
        with tf.control_dependencies([train_op]):
            train_op = tf.group(*post_ops)

        return utils.EasyDict(
            xt=xt_in,
            x=x_in,
            y=y_in,
            label=l_in,
            train_op=train_op,
            classify_raw=tf.nn.softmax(classifier(
                x_in, training=False)),  # No EMA, for debugging.
            classify_op=tf.nn.softmax(
                classifier(x_in, getter=ema_getter, training=False)))
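
In the balance == 1 (and 4) branch above, each class gets its own confidence threshold: classes that are rare among the current pseudo-labels are thresholded more leniently. A numpy sketch of the per-class threshold, mirroring the graph code:

import numpy as np

def per_class_thresholds(class_count, confidence, delT):
    # Threshold for class c: confidence - delT * (1 - count_c / max_count);
    # the max(..., 1.0) mirrors divide_no_nan in the graph code.
    class_count = np.asarray(class_count, dtype=np.float32)
    ratios = 1.0 - class_count / max(class_count.max(), 1.0)
    return confidence - delT * ratios

# A majority class keeps the full threshold, rarer classes get lower ones.
print(per_class_thresholds([100, 10, 50], confidence=0.95, delT=0.2))
# -> [0.95, 0.77, 0.85]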
Example #7
 def __init__(self, train_dir: str, **kwargs):
     self.train_dir = os.path.join(
         train_dir)  # , self.experiment_name(**kwargs))
     self.params = utils.EasyDict(kwargs)
     self.session = None
     self.tmp = utils.EasyDict(print_queue=[], cache=utils.EasyDict())
Example #8
    def model(self,
              batch,
              lr,
              wd,
              wu,
              wclr,
              mom,
              confidence,
              balance,
              delT,
              uratio,
              clrratio,
              temperature,
              ema=0.999,
              **kwargs):
        hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
        xt_in = tf.placeholder(tf.float32, [batch] + hwc,
                               'xt')  # Training labeled
        x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')  # Eval images
        y_in = tf.placeholder(tf.float32, [batch * uratio, 2] + hwc,
                              'y')  # Training unlabeled (weak, strong)
        l_in = tf.placeholder(tf.int32, [batch], 'labels')  # Labels
        wclr_in = tf.placeholder(tf.int32, [1], 'wclr')  # wclr

        lrate = tf.clip_by_value(
            tf.to_float(self.step) / (FLAGS.train_kimg << 10), 0, 1)
        lr *= tf.cos(lrate * (7 * np.pi) / (2 * 8))
        tf.summary.scalar('monitors/lr', lr)

        # Compute logits for xt_in and y_in
        classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
        skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        x = utils.interleave(tf.concat([xt_in, y_in[:, 0], y_in[:, 1]], 0),
                             2 * uratio + 1)
        logits = utils.para_cat(lambda x: classifier(x, training=True), x)
        logits = utils.de_interleave(logits, 2 * uratio + 1)
        post_ops = [
            v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            if v not in skip_ops
        ]
        logits_x = logits[:batch]
        logits_weak, logits_strong = tf.split(logits[batch:], 2)
        del logits, skip_ops

        # Labeled cross-entropy
        loss_xe = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=l_in, logits=logits_x)
        loss_xe = tf.reduce_mean(loss_xe)
        tf.summary.scalar('losses/xe', loss_xe)

        # Pseudo-label cross entropy for unlabeled data
        pseudo_labels = tf.stop_gradient(tf.nn.softmax(logits_weak))
        loss_xeu = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.argmax(pseudo_labels, axis=1), logits=logits_strong)
        #        pseudo_mask = tf.to_float(tf.reduce_max(pseudo_labels, axis=1) >= confidence)
        pseudo_mask = self.class_balancing(pseudo_labels, balance, confidence,
                                           delT)
        tf.summary.scalar('monitors/mask', tf.reduce_mean(pseudo_mask))
        loss_xeu = tf.reduce_mean(loss_xeu * pseudo_mask)
        tf.summary.scalar('losses/xeu', loss_xeu)

        ####################### Modification
        # Contrastive loss term
        contrast_loss = 0
        # wclr_in is a runtime placeholder, so it cannot gate Python-level
        # graph construction; only the scalar weight wclr is checked here.
        if wclr > 0:
            ratio = min(uratio, clrratio)
            if FLAGS.clrDataAug == 1:
                preprocess_fn = functools.partial(
                    data_util.preprocess_for_train,
                    height=self.dataset.height,
                    width=self.dataset.width)
                # Assumed two-view construction: map the SimCLR preprocessing
                # over the weak unlabeled images twice and concatenate.
                views = y_in[:batch * ratio, 0]
                x = tf.concat([tf.map_fn(preprocess_fn, views),
                               tf.map_fn(preprocess_fn, views)], 0)
                embeds = lambda x, **kw: self.classifier(x, **kw, **kwargs
                                                         ).embeds
                hidden = utils.para_cat(lambda x: embeds(x, training=True), x)
            else:
                embeds = lambda x, **kw: self.classifier(x, **kw, **kwargs
                                                         ).embeds
                hiddens = utils.para_cat(lambda x: embeds(x, training=True), x)
                hiddens = utils.de_interleave(hiddens, 2 * uratio + 1)
                hiddens_weak, hiddens_strong = tf.split(hiddens[batch:], 2, 0)
                hidden = tf.concat([
                    hiddens_weak[:batch * ratio],
                    hiddens_strong[:batch * ratio]
                ],
                                   axis=0)
                del hiddens, hiddens_weak, hiddens_strong

            contrast_loss, _, _ = obj_lib.add_contrastive_loss(
                hidden,
                hidden_norm=True,  # FLAGS.hidden_norm,
                temperature=temperature,
                tpu_context=None)

            tf.summary.scalar('losses/contrast', contrast_loss)
            del embeds, hidden
        ###################### End

        # L2 regularization
        loss_wd = sum(
            tf.nn.l2_loss(v) for v in utils.model_vars('classify')
            if 'kernel' in v.name)
        tf.summary.scalar('losses/wd', loss_wd)

        ema = tf.train.ExponentialMovingAverage(decay=ema)
        ema_op = ema.apply(utils.model_vars())
        ema_getter = functools.partial(utils.getter_ema, ema)
        post_ops.append(ema_op)

        #        train_op = tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True).minimize(
        train_op = tf.train.MomentumOptimizer(
            lr, mom, use_nesterov=True).minimize(
                loss_xe + wu * loss_xeu + wclr * contrast_loss + wd * loss_wd,
                colocate_gradients_with_ops=True)
        with tf.control_dependencies([train_op]):
            train_op = tf.group(*post_ops)

        return utils.EasyDict(
            xt=xt_in,
            x=x_in,
            y=y_in,
            label=l_in,
            wclr=wclr_in,
            train_op=train_op,
            classify_raw=tf.nn.softmax(classifier(
                x_in, training=False)),  # No EMA, for debugging.
            classify_op=tf.nn.softmax(
                classifier(x_in, getter=ema_getter, training=False)))
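
obj_lib.add_contrastive_loss appears to come from the SimCLR code; assuming it computes an NT-Xent objective over the two views stacked in hidden, a simplified sketch of that loss (the SimCLR implementation additionally uses in-view negatives with a masked diagonal):

import tensorflow as tf

def nt_xent_sketch(hidden, temperature):
    # hidden stacks two views as [h1; h2]; row i of h1 and row i of h2 form a
    # positive pair, every other row acts as a negative.
    h = tf.nn.l2_normalize(hidden, axis=-1)
    h1, h2 = tf.split(h, 2, axis=0)
    n = tf.shape(h1)[0]
    logits = tf.matmul(h1, h2, transpose_b=True) / temperature  # [n, n]
    labels = tf.range(n)  # positives sit on the diagonal
    loss_a = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
                                                            logits=logits)
    loss_b = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=tf.transpose(logits))
    return tf.reduce_mean(loss_a + loss_b)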
Example #9
    def model(self,
              batch,
              lr,
              wd,
              wu,
              we,
              confidence,
              uratio,
              temperature=1.0,
              tsa='no',
              tsa_pos=10,
              ema=0.999,
              **kwargs):
        hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
        xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt')  # For training
        x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
        y_in = tf.placeholder(tf.float32, [batch * uratio, 2] + hwc, 'y')
        l_in = tf.placeholder(tf.int32, [batch], 'labels')
        l = tf.one_hot(l_in, self.nclass)

        lrate = tf.clip_by_value(
            tf.to_float(self.step) / (FLAGS.train_kimg << 10), 0, 1)
        lr *= tf.cos(lrate * (7 * np.pi) / (2 * 8))
        tf.summary.scalar('monitors/lr', lr)

        # Compute logits for xt_in and y_in
        classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
        skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        x = utils.interleave(tf.concat([xt_in, y_in[:, 0], y_in[:, 1]], 0),
                             2 * uratio + 1)
        logits = utils.para_cat(lambda x: classifier(x, training=True), x)
        logits = utils.de_interleave(logits, 2 * uratio + 1)
        post_ops = [
            v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            if v not in skip_ops
        ]
        logits_x = logits[:batch]
        logits_weak, logits_strong = tf.split(logits[batch:], 2)
        del logits, skip_ops

        # softmax temperature control
        logits_weak_tgt = self.softmax_temperature_controlling(logits_weak,
                                                               T=temperature)
        # generate confidence mask based on sharpened distribution
        pseudo_labels = tf.stop_gradient(tf.nn.softmax(logits_weak))
        pseudo_mask = self.confidence_based_masking(logits=None,
                                                    p_class=pseudo_labels,
                                                    thresh=confidence)
        tf.summary.scalar('monitors/mask', tf.reduce_mean(pseudo_mask))
        tf.summary.scalar(
            'monitors/conf_weak',
            tf.reduce_mean(tf.reduce_max(tf.nn.softmax(logits_weak), axis=1)))
        tf.summary.scalar(
            'monitors/conf_strong',
            tf.reduce_mean(tf.reduce_max(tf.nn.softmax(logits_strong),
                                         axis=1)))

        kld = self.kl_divergence_from_logits(logits_weak_tgt, logits_strong)
        entropy = self.entropy_from_logits(logits_weak)
        loss_xeu = tf.reduce_mean(kld * pseudo_mask)
        tf.summary.scalar('losses/xeu', loss_xeu)
        loss_ent = tf.reduce_mean(entropy)
        tf.summary.scalar('losses/entropy', loss_ent)

        # supervised loss with TSA
        loss_mask = self.tsa_loss_mask(tsa=tsa,
                                       logits=logits_x,
                                       labels=l,
                                       tsa_pos=tsa_pos)
        loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=l,
                                                             logits=logits_x)
        loss_xe = tf.reduce_sum(loss_xe * loss_mask) / tf.math.maximum(
            tf.reduce_sum(loss_mask), 1.0)
        tf.summary.scalar('losses/xe', loss_xe)
        tf.summary.scalar('losses/mask_sup', tf.reduce_mean(loss_mask))

        # L2 regularization
        loss_wd = sum(
            tf.nn.l2_loss(v) for v in utils.model_vars('classify')
            if 'kernel' in v.name)
        tf.summary.scalar('losses/wd', loss_wd)

        ema = tf.train.ExponentialMovingAverage(decay=ema)
        ema_op = ema.apply(utils.model_vars())
        ema_getter = functools.partial(utils.getter_ema, ema)
        post_ops.append(ema_op)

        train_op = tf.train.MomentumOptimizer(
            lr, 0.9, use_nesterov=True).minimize(
                loss_xe + loss_xeu * wu + loss_ent * we + loss_wd * wd,
                colocate_gradients_with_ops=True)
        with tf.control_dependencies([train_op]):
            train_op = tf.group(*post_ops)

        return utils.EasyDict(
            xt=xt_in,
            x=x_in,
            y=y_in,
            label=l_in,
            train_op=train_op,
            classify_raw=tf.nn.softmax(classifier(
                x_in, training=False)),  # No EMA, for debugging.
            classify_op=tf.nn.softmax(
                classifier(x_in, getter=ema_getter, training=False)))
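
kl_divergence_from_logits is a helper from the surrounding codebase; assuming it returns the per-example KL divergence between the two softmax distributions, a minimal equivalent sketch is:

import tensorflow as tf

def kl_divergence_from_logits(p_logits, q_logits):
    # Per-example KL(softmax(p_logits) || softmax(q_logits)); assumed to match
    # the helper used in the unlabeled loss above.
    p = tf.nn.softmax(p_logits)
    log_p = tf.nn.log_softmax(p_logits)
    log_q = tf.nn.log_softmax(q_logits)
    return tf.reduce_sum(p * (log_p - log_q), axis=-1)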
Example #10
    def model(self, nu, w_match, warmup_kimg, batch, lr, wd, ema, dbuf, beta,
              mixmode, logit_norm, **kwargs):
        def classifier(x, logit_norm=logit_norm, **kw):
            v = self.classifier(x, **kw, **kwargs)[0]
            if not logit_norm:
                return v
            return v * tf.rsqrt(tf.reduce_mean(tf.square(v)) + 1e-8)

        def embedding(x, **kw):
            return self.classifier(x, **kw, **kwargs)[1]

        label_index = tf.Variable(utils.idx_to_fixlen(
            self.dataset.labeled_indices, self.dataset.ntrain),
                                  trainable=False,
                                  name='label_index',
                                  dtype=tf.int32)
        label_index_input = tf.placeholder(tf.int32, self.dataset.ntrain,
                                           'label_index_input')
        update_label_index = tf.assign(label_index, label_index_input)

        hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
        x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
        y_in = tf.placeholder(tf.float32, [None, nu] + hwc, 'y')
        l_in = tf.placeholder(tf.int32, [None], 'labels')
        wd *= lr
        w_match *= tf.clip_by_value(
            tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)
        augment = MixMode(mixmode)

        # Moving average of the current estimated label distribution
        p_model = layers.PMovingAverage('p_model', self.nclass, dbuf)
        p_target = layers.PMovingAverage(
            'p_target', self.nclass,
            dbuf)  # Rectified distribution (only for plotting)

        # Known (or inferred) true unlabeled distribution
        p_data = layers.PData(self.dataset)

        y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)
        guess = self.guess_label(tf.split(y, nu), classifier, p_data(),
                                 p_model(), **kwargs)
        ly = tf.stop_gradient(guess.p_target)
        lx = tf.one_hot(l_in, self.nclass)
        xy, labels_xy = augment([x_in] + tf.split(y, nu), [lx] + [ly] * nu,
                                [beta, beta])
        x, y = xy[0], xy[1:]
        labels_x, labels_y = labels_xy[0], tf.concat(labels_xy[1:], 0)
        del xy, labels_xy

        batches = layers.interleave([x] + y, batch)
        skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        logits = [classifier(batches[0], training=True)]
        post_ops = [
            v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            if v not in skip_ops
        ]
        for batchi in batches[1:]:
            logits.append(classifier(batchi, training=True))
        logits = layers.interleave(logits, batch)
        logits_x = logits[0]
        logits_y = tf.concat(logits[1:], 0)

        loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x,
                                                             logits=logits_x)
        loss_xe = tf.reduce_mean(loss_xe)
        loss_l2u = tf.square(labels_y - tf.nn.softmax(logits_y))
        loss_l2u = tf.reduce_mean(loss_l2u)
        tf.summary.scalar('losses/xe', loss_xe)
        tf.summary.scalar('losses/l2u', loss_l2u)
        self.distribution_summary(p_data(), p_model(), p_target())

        ema = tf.train.ExponentialMovingAverage(decay=ema)
        ema_op = ema.apply(utils.model_vars())
        ema_getter = functools.partial(utils.getter_ema, ema)
        post_ops.extend([
            ema_op,
            p_model.update(guess.p_model),
            p_target.update(guess.p_target)
        ])
        if p_data.has_update:
            post_ops.append(p_data.update(lx))
        post_ops.extend([
            tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify')
            if 'kernel' in v.name
        ])

        train_op = tf.train.AdamOptimizer(lr).minimize(
            loss_xe + w_match * loss_l2u, colocate_gradients_with_ops=True)
        with tf.control_dependencies([train_op]):
            train_op = tf.group(*post_ops)

        # Tuning op: only retrain batch norm.
        skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        classifier(batches[0], training=True)
        train_bn = tf.group(*[
            v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            if v not in skip_ops
        ])

        return utils.EasyDict(
            x=x_in,
            y=y_in,
            label=l_in,
            train_op=train_op,
            tune_op=train_bn,
            classify_raw=tf.nn.softmax(
                classifier(x_in, logit_norm=False,
                           training=False)),  # No EMA, for debugging.
            classify_op=tf.nn.softmax(
                classifier(x_in,
                           logit_norm=False,
                           getter=ema_getter,
                           training=False)),
            embedding_op=embedding(x_in, training=False),
            label_index=label_index,
            update_label_index=update_label_index,
            label_index_input=label_index_input,
        )
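
MixMode(mixmode) comes from the MixMatch codebase; assuming the augment(...) call performs MixUp-style convex mixing with Beta(beta, beta) coefficients, the core mixing step looks roughly like this numpy sketch:

import numpy as np

def mixup(x1, y1, x2, y2, beta):
    # Draw one mixing coefficient per example and keep the result closer to
    # the first input (lam >= 0.5), as MixMatch does.
    lam = np.random.beta(beta, beta, size=(x1.shape[0],) + (1,) * (x1.ndim - 1))
    lam = np.maximum(lam, 1.0 - lam)
    x = lam * x1 + (1.0 - lam) * x2
    y = lam.reshape(-1, 1) * y1 + (1.0 - lam).reshape(-1, 1) * y2
    return x, y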