Example 1
    def model(self, batch, lr, wd, ema, beta, w_match, warmup_kimg=1024, nu=2, mixmode='xxy.yxy', **kwargs):
        hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
        x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
        y_in = tf.placeholder(tf.float32, [None, nu] + hwc, 'y')
        l_in = tf.placeholder(tf.int32, [None], 'labels')
        wd *= lr  # Scale weight decay with the learning rate.
        # Warm up the consistency weight w_match linearly from 0 to its full value.
        w_match *= tf.clip_by_value(tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)
        augment = MixMode(mixmode)
        classifier = functools.partial(self.classifier, **kwargs)

        y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)
        # Generate guessed labels for the unlabeled images.
        guess = self.guess_label(tf.split(y, nu), classifier, T=0.5, **kwargs)
        ly = tf.stop_gradient(guess.p_target)
        lx = tf.one_hot(l_in, self.nclass)
        # Apply MixUp across the labeled and unlabeled batches.
        xy, labels_xy = augment([x_in] + tf.split(y, nu), [lx] + [ly] * nu, [beta, beta])
        x, y = xy[0], xy[1:]
        labels_x, labels_y = labels_xy[0], tf.concat(labels_xy[1:], 0)
        del xy, labels_xy

        # Interleave labeled and unlabeled examples across batches
        # (see google-research/mixmatch/issues/5).
        batches = layers.interleave([x] + y, batch)
        skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        logits = [classifier(batches[0], training=True)]
        # Keep only the update ops created by the first classifier call.
        post_ops = [v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops]
        for batchi in batches[1:]:
            logits.append(classifier(batchi, training=True))
        logits = layers.interleave(logits, batch)
        logits_x = logits[0]
        logits_y = tf.concat(logits[1:], 0)

        loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x, logits=logits_x)
        loss_xe = tf.reduce_mean(loss_xe)
        loss_l2u = tf.square(labels_y - tf.nn.softmax(logits_y))
        loss_l2u = tf.reduce_mean(loss_l2u)
        tf.summary.scalar('losses/xe', loss_xe)
        tf.summary.scalar('losses/l2u', loss_l2u)

        ema = tf.train.ExponentialMovingAverage(decay=ema)
        ema_op = ema.apply(utils.model_vars())
        ema_getter = functools.partial(utils.getter_ema, ema)
        post_ops.append(ema_op)
        post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])

        train_op = tf.train.AdamOptimizer(lr).minimize(loss_xe + w_match * loss_l2u, colocate_gradients_with_ops=True)
        with tf.control_dependencies([train_op]):
            train_op = tf.group(*post_ops)

        # Tuning op: only retrain batch norm.
        skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        classifier(batches[0], training=True)
        train_bn = tf.group(*[v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                              if v not in skip_ops])

        return EasyDict(
            x=x_in, y=y_in, label=l_in, train_op=train_op, tune_op=train_bn,
            classify_raw=tf.nn.softmax(classifier(x_in, training=False)),  # No EMA, for debugging.
            classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))

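Note: self.guess_label is not shown in this excerpt. The sketch below is a minimal illustration, assuming it follows the standard MixMatch recipe of averaging the classifier's predictions over the nu augmentations and sharpening them with temperature T (T=0.5 at the call site above); the field names p_target and p_model match how the result is used in these examples, but the body is illustrative rather than the repository's actual code.

    def guess_label(self, y_splits, classifier, T, **kwargs):
        # Average the predicted class distributions over the nu augmentations.
        p = [tf.nn.softmax(classifier(y, training=True)) for y in y_splits]
        p_model = tf.add_n(p) / len(p)
        # Sharpen with temperature T and renormalize so each row sums to 1.
        p_target = tf.pow(p_model, 1. / T)
        p_target /= tf.reduce_sum(p_target, axis=1, keepdims=True)
        return EasyDict(p_target=p_target, p_model=p_model)
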
Example 2
    def model(self,
              batch,
              lr,
              wd,
              ema,
              beta,
              w_match,
              warmup_kimg=1024,
              nu=2,
              mixmode='xxy.yxy',
              dbuf=128,
              **kwargs):
        hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
        xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt')  # For training
        x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
        y_in = tf.placeholder(tf.float32, [batch, nu] + hwc, 'y')
        l_in = tf.placeholder(tf.int32, [batch], 'labels')
        wd *= lr
        w_match *= tf.clip_by_value(
            tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)
        augment = MixMode(mixmode)
        classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits

        # Moving average of the current estimated label distribution.
        p_model = layers.PMovingAverage('p_model', self.nclass, dbuf)
        # Rectified distribution (only used for plotting).
        p_target = layers.PMovingAverage('p_target', self.nclass, dbuf)

        # Known (or inferred) true unlabeled distribution
        p_data = layers.PData(self.dataset)

        y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)
        guess = self.guess_label(tf.split(y, nu), classifier, T=0.5, **kwargs)
        ly = tf.stop_gradient(guess.p_target)
        lx = tf.one_hot(l_in, self.nclass)
        xy, labels_xy = augment([xt_in] + tf.split(y, nu), [lx] + [ly] * nu,
                                [beta, beta])
        x, y = xy[0], xy[1:]
        labels_x, labels_y = labels_xy[0], tf.concat(labels_xy[1:], 0)
        del xy, labels_xy

        batches = layers.interleave([x] + y, batch)
        skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        logits = [classifier(batches[0], training=True)]
        post_ops = [
            v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            if v not in skip_ops
        ]
        for batchi in batches[1:]:
            logits.append(classifier(batchi, training=True))
        logits = layers.interleave(logits, batch)
        logits_x = logits[0]
        logits_y = tf.concat(logits[1:], 0)

        loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x,
                                                             logits=logits_x)
        loss_xe = tf.reduce_mean(loss_xe)
        loss_l2u = tf.square(labels_y - tf.nn.softmax(logits_y))
        loss_l2u = tf.reduce_mean(loss_l2u)
        tf.summary.scalar('losses/xe', loss_xe)
        tf.summary.scalar('losses/l2u', loss_l2u)
        self.distribution_summary(p_data(), p_model(), p_target())

        ema = tf.train.ExponentialMovingAverage(decay=ema)
        ema_op = ema.apply(utils.model_vars())
        ema_getter = functools.partial(utils.getter_ema, ema)
        post_ops.extend([
            ema_op,
            p_model.update(guess.p_model),
            p_target.update(guess.p_target)
        ])
        if p_data.has_update:
            post_ops.append(p_data.update(lx))
        post_ops.extend([
            tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify')
            if 'kernel' in v.name
        ])

        train_op = tf.train.AdamOptimizer(lr).minimize(
            loss_xe + w_match * loss_l2u, colocate_gradients_with_ops=True)
        with tf.control_dependencies([train_op]):
            train_op = tf.group(*post_ops)

        return EasyDict(
            xt=xt_in,
            x=x_in,
            y=y_in,
            label=l_in,
            train_op=train_op,
            classify_raw=tf.nn.softmax(classifier(
                x_in, training=False)),  # No EMA, for debugging.
            classify_op=tf.nn.softmax(
                classifier(x_in, getter=ema_getter, training=False)))
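
Example 2 additionally tracks label distributions with layers.PMovingAverage, which is not reproduced in this excerpt. As a rough sketch (not the repository's implementation), such a tracker can keep a fixed-size buffer of recent batch-mean class distributions in a non-trainable variable; calling the object returns the renormalized mean of the buffer, and update() enqueues a new batch estimate while dropping the oldest one:

class PMovingAverage:
    def __init__(self, name, nclass, buf_size):
        # Non-trainable buffer of the last `buf_size` class distributions,
        # initialized to the uniform distribution.
        self.p = tf.get_variable(
            name, [buf_size, nclass], trainable=False,
            initializer=tf.constant_initializer(1. / nclass))

    def __call__(self):
        # Current estimate: mean over the buffer, renormalized to sum to 1.
        v = tf.reduce_mean(self.p, axis=0)
        return v / tf.reduce_sum(v)

    def update(self, entry):
        # Enqueue the batch-mean distribution and drop the oldest slot.
        entry = tf.reduce_mean(entry, axis=0, keepdims=True)
        return tf.assign(self.p, tf.concat([self.p[1:], entry], axis=0))

In the example above, p_model.update(guess.p_model) and p_target.update(guess.p_target) are grouped into post_ops, so the buffers are refreshed once per training step, after the gradient update.
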
Example 3
    def model(self,
              batch,
              lr,
              wd,
              ema,
              beta,
              w_match,
              warmup_kimg=1024,
              nu=2,
              mixmode='xxy.yxy',
              **kwargs):
        hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]

        # Create placeholders for the labeled images, unlabeled images,
        # and the ground truth supervised labels respectively.
        x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
        y_in = tf.placeholder(tf.float32, [None, nu] + hwc, 'y')
        l_in = tf.placeholder(tf.int32, [None], 'labels')
        wd *= lr
        w_match *= tf.clip_by_value(
            tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)
        augment = MixMode(mixmode)
        classifier = functools.partial(self.classifier, **kwargs)

        y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)
        guess = self.guess_label(tf.split(y, nu), classifier, T=0.5, **kwargs)
        ly = tf.stop_gradient(guess.p_target)
        lx = tf.one_hot(l_in, self.nclass)

        # Create MixUp examples.
        xy, labels_xy = augment([x_in] + tf.split(y, nu), [lx] + [ly] * nu,
                                [beta, beta])
        x, y = xy[0], xy[1:]
        labels_x, labels_y = labels_xy[0], tf.concat(labels_xy[1:], 0)
        del xy, labels_xy

        # Create batches that represent both labeled and unlabeled batches.
        # For more, see google-research/mixmatch/issues/5.
        batches = layers.interleave([x] + y, batch)
        skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        logits = [classifier(batches[0], training=True)]
        post_ops = [
            v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            if v not in skip_ops
        ]
        for batchi in batches[1:]:
            logits.append(classifier(batchi, training=True))
        logits = layers.interleave(logits, batch)
        logits_x = logits[0]
        logits_y = tf.concat(logits[1:], 0)

        # Calculate supervised and unsupervised losses.
        loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x,
                                                             logits=logits_x)
        if FLAGS.tsa != "none":
            print("Using training signal annealing...")
            loss_xe = self.anneal_sup_loss(logits_x, labels_x, loss_xe,
                                           self.step)
        else:
            loss_xe = tf.reduce_mean(loss_xe)

        loss_l2u = tf.square(labels_y - tf.nn.softmax(logits_y))

        if FLAGS.percent_mask > 0:
            print("Using percent-based confidence masking...")
            loss_l2u = self.percent_confidence_mask_unsup(
                logits_y, labels_y, loss_l2u)
        else:
            loss_l2u = tf.reduce_mean(loss_l2u)

        # Track the largest predicted probability for each unlabeled image.
        unsup_prob = tf.nn.softmax(logits_y, axis=-1)
        tf.summary.scalar('losses/min_unsup_prob',
                          tf.reduce_min(tf.reduce_max(unsup_prob, axis=-1)))
        tf.summary.scalar('losses/mean_unsup_prob',
                          tf.reduce_mean(tf.reduce_max(unsup_prob, axis=-1)))
        tf.summary.scalar('losses/max_unsup_prob',
                          tf.reduce_max(tf.reduce_max(unsup_prob, axis=-1)))

        # Log the losses to TensorBoard.
        tf.summary.scalar('losses/xe', loss_xe)
        tf.summary.scalar('losses/l2u', loss_l2u)
        tf.summary.scalar('losses/overall', loss_xe + w_match * loss_l2u)

        # Apply an exponential moving average (EMA) to the model weights
        # (Tarvainen & Valpola, 2017; https://arxiv.org/abs/1703.01780).
        ema = tf.train.ExponentialMovingAverage(decay=ema)
        ema_op = ema.apply(utils.model_vars())
        ema_getter = functools.partial(utils.getter_ema, ema)
        post_ops.append(ema_op)
        post_ops.extend([
            tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify')
            if 'kernel' in v.name
        ])

        train_op = tf.train.AdamOptimizer(lr).minimize(
            loss_xe + w_match * loss_l2u, colocate_gradients_with_ops=True)
        with tf.control_dependencies([train_op]):
            train_op = tf.group(*post_ops)

        # Tuning op: only retrain batch norm.
        skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        classifier(batches[0], training=True)
        train_bn = tf.group(*[
            v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            if v not in skip_ops
        ])

        return EasyDict(
            x=x_in,
            y=y_in,
            label=l_in,
            train_op=train_op,
            tune_op=train_bn,
            classify_raw=tf.nn.softmax(classifier(
                x_in, training=False)),  # No EMA, for debugging.
            classify_op=tf.nn.softmax(
                classifier(x_in, getter=ema_getter, training=False)),
            eval_loss_op=tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits_v2(
                    logits=classifier(x_in, getter=ema_getter, training=False),
                    labels=tf.one_hot(l_in, self.nclass))))
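
Example 3 gates the losses through two helpers that are not included in this excerpt: anneal_sup_loss (training signal annealing) and percent_confidence_mask_unsup (percent-based confidence masking). The sketch below illustrates one plausible form of the latter, under the assumption that FLAGS.percent_mask is the percentage of most-confident unlabeled examples whose squared error is kept; the name and signature mirror the call site, but the body and the summary tag are illustrative only.

    def percent_confidence_mask_unsup(self, logits_y, labels_y, loss_l2u):
        # Confidence of each unlabeled example: its largest predicted probability.
        conf = tf.reduce_max(tf.nn.softmax(logits_y, axis=-1), axis=-1)
        # Number of examples to keep (percent_mask percent of the batch, at least 1).
        n = tf.shape(conf)[0]
        k = tf.maximum(
            tf.cast(tf.cast(n, tf.float32) * FLAGS.percent_mask / 100., tf.int32), 1)
        # Keep examples whose confidence reaches the k-th highest value.
        threshold = tf.reduce_min(tf.nn.top_k(conf, k=k).values)
        mask = tf.cast(conf >= threshold, tf.float32)
        tf.summary.scalar('losses/unsup_mask_rate', tf.reduce_mean(mask))
        # Average the squared error only over the kept examples.
        per_example = tf.reduce_mean(loss_l2u, axis=-1)
        return tf.reduce_sum(per_example * mask) / tf.maximum(tf.reduce_sum(mask), 1.)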