Code example #1
import itertools

import numpy as np
import tensorflow as tf
from absl import flags

from libml import utils, ctaugment
from libml.utils import EasyDict
from third_party.auto_augment import augmentations, policies

FLAGS = flags.FLAGS
POOL = None
POLICIES = EasyDict(
    cifar10=policies.cifar10_policies(),
    cifar10p=policies.cifar10_policies(),
    # color=policies.color_policies(),
    cifar10imb=policies.cifar10_policies(),
    cifar100=policies.cifar10_policies(),
    svhn=policies.svhn_policies(),
    svhnp=policies.svhn_policies(),
    svhnp_noextra=policies.svhn_policies(),
    svhn_noextra=policies.svhn_policies())

RANDOM_POLICY_OPS = ('Identity', 'AutoContrast', 'Equalize', 'Rotate',
                     'Solarize', 'Color', 'Contrast', 'Brightness',
                     'Sharpness', 'ShearX', 'TranslateX', 'TranslateY',
                     'Posterize', 'ShearY')
AUGMENT_ENUM = 'd x m aa aac aacc ra rac'.split() + [
    'r%d_%d_%d' % (nops, mag, cutout)
    for nops, mag, cutout in itertools.product(range(1, 5), range(1, 16),
                                               range(0, 100, 25))
] + ['rac%d' % (mag) for mag in range(1, 10)]
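
For reference, a minimal sketch of the identifiers this AUGMENT_ENUM comprehension produces (same ranges as above; the printed values are only illustrative):

import itertools

augment_enum = 'd x m aa aac aacc ra rac'.split() + [
    'r%d_%d_%d' % (nops, mag, cutout)
    for nops, mag, cutout in itertools.product(range(1, 5), range(1, 16),
                                               range(0, 100, 25))
] + ['rac%d' % mag for mag in range(1, 10)]
print(len(augment_enum))  # 8 + 4 * 15 * 4 + 9 = 257
print(augment_enum[8])    # 'r1_1_0': 1 random op, magnitude 1, no cutout
print(augment_enum[-1])   # 'rac9'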
Code example #2
    def classifier(self,
                   x,
                   scales,
                   filters,
                   repeat,
                   training,
                   getter=None,
                   dropout=0,
                   **kwargs):
        del kwargs
        leaky_relu = functools.partial(tf.nn.leaky_relu, alpha=0.1)
        bn_args = dict(training=training, momentum=0.999)

        def conv_args(k, f):
            return dict(padding='same',
                        kernel_initializer=tf.random_normal_initializer(
                            stddev=tf.rsqrt(0.5 * k * k * f)))

        def residual(x0, filters, stride=1, activate_before_residual=False):
            x = leaky_relu(tf.layers.batch_normalization(x0, **bn_args))
            if activate_before_residual:
                x0 = x

            x = tf.layers.conv2d(x,
                                 filters,
                                 3,
                                 strides=stride,
                                 **conv_args(3, filters))
            x = leaky_relu(tf.layers.batch_normalization(x, **bn_args))
            x = tf.layers.conv2d(x, filters, 3, **conv_args(3, filters))

            if x0.get_shape()[3] != filters:
                x0 = tf.layers.conv2d(x0,
                                      filters,
                                      1,
                                      strides=stride,
                                      **conv_args(1, filters))

            return x0 + x

        with tf.variable_scope('classify',
                               reuse=tf.AUTO_REUSE,
                               custom_getter=getter):
            y = tf.layers.conv2d((x - self.dataset.mean) / self.dataset.std,
                                 16, 3, **conv_args(3, 16))
            for scale in range(scales):
                y = residual(y,
                             filters << scale,
                             stride=2 if scale else 1,
                             activate_before_residual=scale == 0)
                for i in range(repeat - 1):
                    y = residual(y, filters << scale)

            y = leaky_relu(tf.layers.batch_normalization(y, **bn_args))
            y = embeds = tf.reduce_mean(y, [1, 2])
            if dropout and training:
                y = tf.nn.dropout(y, 1 - dropout)
            logits = tf.layers.dense(
                y,
                self.nclass,
                kernel_initializer=tf.glorot_normal_initializer())
        return EasyDict(logits=logits, embeds=embeds)
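
In the loop above the channel width doubles at each scale (filters << scale) and every scale after the first downsamples once with stride 2. A quick sketch under an assumed configuration (the values below are illustrative, not taken from this snippet):

scales, filters = 3, 32                                 # assumed example values
print([filters << scale for scale in range(scales)])    # [32, 64, 128]
# scale 0 keeps the input resolution; scales 1 and 2 each halve it once.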
Code example #3
File: augment.py  Project: biwana/fixmatch
import itertools
import multiprocessing
import random

import numpy as np
import tensorflow as tf
from absl import flags

from libml import utils, ctaugment
from libml.utils import EasyDict
from third_party.auto_augment import augmentations, policies

FLAGS = flags.FLAGS
POOL = None
POLICIES = EasyDict(cifar10=policies.cifar10_policies(),
                    cifar100=policies.cifar10_policies(),
                    svhn=policies.svhn_policies(),
                    svhn_noextra=policies.svhn_policies())

RANDOM_POLICY_OPS = ('Identity', 'AutoContrast', 'Equalize', 'Rotate',
                     'Solarize', 'Color', 'Contrast', 'Brightness',
                     'Sharpness', 'ShearX', 'TranslateX', 'TranslateY',
                     'Posterize', 'ShearY')
AUGMENT_ENUM = 'd x m aa aac ra rac'.split() + [
    'r%d_%d_%d' % (nops, mag, cutout)
    for nops, mag, cutout in itertools.product(range(1, 5), range(1, 16),
                                               range(0, 100, 25))
] + ['rac%d' % (mag) for mag in range(1, 10)]

flags.DEFINE_integer('K', 1,
                     'Number of strong augmentation for unlabeled data.')
flags.DEFINE_enum(
Code example #4
    def model(self,
              batch,
              lr,
              wd,
              beta,
              w_kl,
              w_match,
              w_rot,
              K,
              use_xe,
              warmup_kimg=1024,
              T=0.5,
              mixmode='xxy.yxy',
              dbuf=128,
              ema=0.999,
              **kwargs):
        hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
        xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt')  # For training
        x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
        y_in = tf.placeholder(tf.float32, [batch, K + 1] + hwc, 'y')
        l_in = tf.placeholder(tf.int32, [batch], 'labels')

        w_match *= tf.clip_by_value(
            tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)
        lrate = tf.clip_by_value(
            tf.to_float(self.step) / (FLAGS.train_kimg << 10), 0, 1)
        lr *= tf.cos(lrate * (7 * np.pi) / (2 * 8))
        tf.summary.scalar('monitors/lr', lr)
        augment = layers.MixMode(mixmode)

        gpu = utils.get_gpu()

        def classifier_to_gpu(x, **kw):
            with tf.device(next(gpu)):
                return self.classifier(x, **kw, **kwargs).logits

        def random_rotate(x):
            b4 = batch // 4
            x, xt = x[:2 * b4], tf.transpose(x[2 * b4:], [0, 2, 1, 3])
            l = np.zeros(b4, np.int32)
            l = tf.constant(np.concatenate([l, l + 1, l + 2, l + 3], axis=0))
            return tf.concat(
                [x[:b4], x[b4:, ::-1, ::-1], xt[:b4, ::-1], xt[b4:, :, ::-1]],
                axis=0), l

        # Moving average of the current estimated label distribution
        p_model = layers.PMovingAverage('p_model', self.nclass, dbuf)
        p_target = layers.PMovingAverage(
            'p_target', self.nclass,
            dbuf)  # Rectified distribution (only for plotting)

        # Known (or inferred) true unlabeled distribution
        p_data = layers.PData(self.dataset)

        if w_rot > 0:
            rot_y, rot_l = random_rotate(y_in[:, 1])
            with tf.device(next(gpu)):
                rot_logits = self.classifier_rot(
                    self.classifier(rot_y, training=True, **kwargs).embeds)
            loss_rot = tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=tf.one_hot(rot_l, 4), logits=rot_logits)
            loss_rot = tf.reduce_mean(loss_rot)
            tf.summary.scalar('losses/rot', loss_rot)
        else:
            loss_rot = 0

        if kwargs['redux'] == '1st' and w_kl <= 0:
            logits_y = [classifier_to_gpu(y_in[:, 0], training=True)] * (K + 1)
        elif kwargs['redux'] == '1st':
            logits_y = [
                classifier_to_gpu(y_in[:, i], training=True) for i in range(2)
            ]
            logits_y += logits_y[:1] * (K - 1)
        else:
            logits_y = [
                classifier_to_gpu(y_in[:, i], training=True)
                for i in range(K + 1)
            ]

        guess = self.guess_label(logits_y, p_data(), p_model(), T=T, **kwargs)
        ly = tf.stop_gradient(guess.p_target)
        if w_kl > 0:
            w_kl *= tf.clip_by_value(
                tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)
            loss_kl = tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=ly[:batch], logits=logits_y[1])
            loss_kl = tf.reduce_mean(loss_kl)
            tf.summary.scalar('losses/kl', loss_kl)
        else:
            loss_kl = 0
        del logits_y

        lx = tf.one_hot(l_in, self.nclass)
        xy, labels_xy = augment([xt_in] + [y_in[:, i] for i in range(K + 1)],
                                [lx] + tf.split(ly, K + 1), [beta, beta])
        x, y = xy[0], xy[1:]
        labels_x, labels_y = labels_xy[0], tf.concat(labels_xy[1:], 0)
        del xy, labels_xy

        batches = layers.interleave([x] + y, batch)
        logits = [classifier_to_gpu(yi, training=True) for yi in batches[:-1]]
        skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        logits.append(classifier_to_gpu(batches[-1], training=True))
        post_ops = [
            v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            if v not in skip_ops
        ]
        logits = layers.interleave(logits, batch)
        logits_x = logits[0]
        logits_y = tf.concat(logits[1:], 0)
        del batches, logits

        loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x,
                                                             logits=logits_x)
        loss_xe = tf.reduce_mean(loss_xe)
        if use_xe:
            loss_xeu = tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=labels_y, logits=logits_y)
        else:
            loss_xeu = tf.square(labels_y - tf.nn.softmax(logits_y))
        loss_xeu = tf.reduce_mean(loss_xeu)
        tf.summary.scalar('losses/xe', loss_xe)
        tf.summary.scalar('losses/%s' % ('xeu' if use_xe else 'l2u'), loss_xeu)
        self.distribution_summary(p_data(), p_model(), p_target())

        # L2 regularization
        loss_wd = sum(
            tf.nn.l2_loss(v) for v in utils.model_vars('classify')
            if 'kernel' in v.name)
        tf.summary.scalar('losses/wd', loss_wd)

        ema = tf.train.ExponentialMovingAverage(decay=ema)
        ema_op = ema.apply(utils.model_vars())
        ema_getter = functools.partial(utils.getter_ema, ema)
        post_ops.extend([
            ema_op,
            p_model.update(guess.p_model),
            p_target.update(guess.p_target)
        ])
        if p_data.has_update:
            post_ops.append(p_data.update(lx))

        train_op = tf.train.MomentumOptimizer(
            lr, 0.9, use_nesterov=True).minimize(
                loss_xe + w_kl * loss_kl + w_match * loss_xeu +
                w_rot * loss_rot + wd * loss_wd,
                colocate_gradients_with_ops=True)
        with tf.control_dependencies([train_op]):
            train_op = tf.group(*post_ops)

        return EasyDict(
            xt=xt_in,
            x=x_in,
            y=y_in,
            label=l_in,
            train_op=train_op,
            classify_op=tf.nn.softmax(
                classifier_to_gpu(x_in, getter=ema_getter, training=False)),
            classify_raw=tf.nn.softmax(classifier_to_gpu(
                x_in, training=False)))  # No EMA, for debugging.
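
A side note on the schedule above: lrate ramps linearly from 0 to 1 over FLAGS.train_kimg << 10 images, so the learning rate follows a cosine from its base value down to cos(7*pi/16), roughly 0.195 of it. A minimal numpy sketch of just the schedule (the step bookkeeping and base_lr are assumed):

import numpy as np

def cosine_lr(step, total_steps, base_lr=0.03):   # base_lr is an assumed value
    progress = np.clip(step / total_steps, 0.0, 1.0)
    return base_lr * np.cos(progress * (7 * np.pi) / (2 * 8))

print(cosine_lr(0, 1000))     # 0.03
print(cosine_lr(1000, 1000))  # ~0.00585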
Code example #5
    def classifier(self,
                   x,
                   scales,
                   filters,
                   repeat,
                   training,
                   getter=None,
                   dropout=0,
                   **kwargs):
        del kwargs
        bn_args = dict(training=training, momentum=0.999)

        def conv_args(k, f):
            return dict(padding='same',
                        use_bias=False,
                        kernel_initializer=tf.random_normal_initializer(
                            stddev=tf.rsqrt(0.5 * k * k * f)))

        def residual(x0, filters, stride=1):
            def branch():
                x = tf.nn.relu(x0)
                x = tf.layers.conv2d(x,
                                     filters,
                                     3,
                                     strides=stride,
                                     **conv_args(3, filters))
                x = tf.nn.relu(tf.layers.batch_normalization(x, **bn_args))
                x = tf.layers.conv2d(x, filters, 3, **conv_args(3, filters))
                x = tf.layers.batch_normalization(x, **bn_args)
                return x

            x = layers.shakeshake(branch(), branch(), training)

            if stride == 2:
                x1 = tf.layers.conv2d(tf.nn.relu(x0[:, ::2, ::2]),
                                      filters >> 1, 1,
                                      **conv_args(1, filters >> 1))
                x2 = tf.layers.conv2d(tf.nn.relu(x0[:, 1::2, 1::2]),
                                      filters >> 1, 1,
                                      **conv_args(1, filters >> 1))
                x0 = tf.concat([x1, x2], axis=3)
                x0 = tf.layers.batch_normalization(x0, **bn_args)
            elif x0.get_shape()[3] != filters:
                x0 = tf.layers.conv2d(x0, filters, 1, **conv_args(1, filters))
                x0 = tf.layers.batch_normalization(x0, **bn_args)

            return x0 + x

        with tf.variable_scope('classify',
                               reuse=tf.AUTO_REUSE,
                               custom_getter=getter):
            y = tf.layers.conv2d((x - self.dataset.mean) / self.dataset.std,
                                 16, 3, **conv_args(3, 16))
            for scale, i in itertools.product(range(scales), range(repeat)):
                with tf.variable_scope('layer%d.%d' % (scale + 1, i)):
                    if i == 0:
                        y = residual(y,
                                     filters << scale,
                                     stride=2 if scale else 1)
                    else:
                        y = residual(y, filters << scale)

            y = embeds = tf.reduce_mean(y, [1, 2])
            if dropout and training:
                y = tf.nn.dropout(y, 1 - dropout)
            logits = tf.layers.dense(
                y,
                self.nclass,
                kernel_initializer=tf.glorot_normal_initializer())
        return EasyDict(logits=logits, embeds=embeds)
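
The stride-2 shortcut above halves the resolution by taking two spatially offset subsamples of x0, reducing each to filters >> 1 channels with a 1x1 conv, and concatenating them on the channel axis. A small numpy sketch of just the subsampling step (the shapes are made up for illustration):

import numpy as np

x0 = np.zeros((8, 32, 32, 16))   # NHWC feature map, assumed shape
a = x0[:, ::2, ::2]              # even rows/columns -> (8, 16, 16, 16)
b = x0[:, 1::2, 1::2]            # odd rows/columns  -> (8, 16, 16, 16)
print(a.shape, b.shape)
# In the model each branch then passes through a 1x1 conv to filters >> 1
# channels before tf.concat(..., axis=3) restores the full channel count.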
Code example #6
    def model(self, batch, lr, wd, ema, warmup_pos, consistency_weight, beta,
              **kwargs):
        hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
        xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt')  # For training
        x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
        y_in = tf.placeholder(tf.float32, [batch, 2] + hwc, 'y')
        l_in = tf.placeholder(tf.int32, [batch], 'labels')
        l = tf.one_hot(l_in, self.nclass)
        wd *= lr
        warmup = tf.clip_by_value(
            tf.to_float(self.step) / (warmup_pos * (FLAGS.train_kimg << 10)),
            0, 1)

        y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)
        y_1, y_2 = tf.split(y, 2)

        mix = tf.distributions.Beta(beta,
                                    beta).sample([tf.shape(x_in)[0], 1, 1, 1])
        mix = tf.maximum(mix, 1 - mix)

        classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
        logits_x = classifier(xt_in, training=True)
        post_ops = tf.get_collection(
            tf.GraphKeys.UPDATE_OPS
        )  # Take only first call to update batch norm.

        ema = tf.train.ExponentialMovingAverage(decay=ema)
        ema_op = ema.apply(utils.model_vars())
        ema_getter = functools.partial(utils.getter_ema, ema)
        logits_teacher = classifier(y_1, training=True, getter=ema_getter)
        labels_teacher = tf.stop_gradient(tf.nn.softmax(logits_teacher))
        labels_teacher = (labels_teacher * mix[:, :, 0, 0] +
                          labels_teacher[::-1] * (1 - mix[:, :, 0, 0]))
        logits_student = classifier(y_1 * mix + y_1[::-1] * (1 - mix),
                                    training=True)
        loss_mt = tf.reduce_mean(
            (labels_teacher - tf.nn.softmax(logits_student))**2, -1)
        loss_mt = tf.reduce_mean(loss_mt)

        loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=l,
                                                          logits=logits_x)
        loss = tf.reduce_mean(loss)
        tf.summary.scalar('losses/xe', loss)
        tf.summary.scalar('losses/mt', loss_mt)

        post_ops.append(ema_op)
        post_ops.extend([
            tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify')
            if 'kernel' in v.name
        ])

        train_op = tf.train.AdamOptimizer(lr).minimize(
            loss + loss_mt * warmup * consistency_weight,
            colocate_gradients_with_ops=True)
        with tf.control_dependencies([train_op]):
            train_op = tf.group(*post_ops)

        return EasyDict(
            xt=xt_in,
            x=x_in,
            y=y_in,
            label=l_in,
            train_op=train_op,
            classify_raw=tf.nn.softmax(classifier(
                x_in, training=False)),  # No EMA, for debugging.
            classify_op=tf.nn.softmax(
                classifier(x_in, getter=ema_getter, training=False)))
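
The tf.maximum(mix, 1 - mix) line keeps every mixing coefficient at 0.5 or above, so each mixed example and its teacher target stay dominated by their own image rather than by the reversed batch. A small numpy sketch with an assumed beta value:

import numpy as np

beta = 0.75                                   # assumed value for illustration
mix = np.random.beta(beta, beta, size=(4, 1, 1, 1))
mix = np.maximum(mix, 1 - mix)                # every entry is now in [0.5, 1]
print(mix.squeeze())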
Code example #7
    def model(self,
              batch,
              lr,
              wd,
              ema,
              beta,
              w_match,
              warmup_kimg=1024,
              nu=2,
              mixmode='xxy.yxy',
              dbuf=128,
              **kwargs):
        hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
        xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt')  # For training
        x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
        y_in = tf.placeholder(tf.float32, [batch, nu] + hwc, 'y')
        l_in = tf.placeholder(tf.int32, [batch], 'labels')

        w_match *= tf.clip_by_value(
            tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)
        lrate = tf.clip_by_value(
            tf.to_float(self.step) / (FLAGS.train_kimg << 10), 0, 1)
        lr *= tf.cos(lrate * (7 * np.pi) / (2 * 8))
        tf.summary.scalar('monitors/lr', lr)
        augment = MixMode(mixmode)
        classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits

        # Moving average of the current estimated label distribution
        p_model = layers.PMovingAverage('p_model', self.nclass, dbuf)
        p_target = layers.PMovingAverage(
            'p_target', self.nclass,
            dbuf)  # Rectified distribution (only for plotting)

        # Known (or inferred) true unlabeled distribution
        p_data = layers.PData(self.dataset)

        y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)
        guess = self.guess_label(tf.split(y, nu), classifier, T=0.5, **kwargs)
        ly = tf.stop_gradient(guess.p_target)
        lx = tf.one_hot(l_in, self.nclass)
        xy, labels_xy = augment([xt_in] + tf.split(y, nu), [lx] + [ly] * nu,
                                [beta, beta])
        x, y = xy[0], xy[1:]
        labels_x, labels_y = labels_xy[0], tf.concat(labels_xy[1:], 0)
        del xy, labels_xy

        batches = layers.interleave([x] + y, batch)
        skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        logits = [classifier(batches[0], training=True)]
        post_ops = [
            v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            if v not in skip_ops
        ]
        for batchi in batches[1:]:
            logits.append(classifier(batchi, training=True))
        logits = layers.interleave(logits, batch)
        logits_x = logits[0]
        logits_y = tf.concat(logits[1:], 0)

        loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x,
                                                             logits=logits_x)
        loss_xe = tf.reduce_mean(loss_xe)
        loss_l2u = tf.square(labels_y - tf.nn.softmax(logits_y))
        loss_l2u = tf.reduce_mean(loss_l2u)
        tf.summary.scalar('losses/xe', loss_xe)
        tf.summary.scalar('losses/l2u', loss_l2u)
        self.distribution_summary(p_data(), p_model(), p_target())

        # L2 regularization
        loss_wd = sum(
            tf.nn.l2_loss(v) for v in utils.model_vars('classify')
            if 'kernel' in v.name)
        tf.summary.scalar('losses/wd', loss_wd)

        ema = tf.train.ExponentialMovingAverage(decay=ema)
        ema_op = ema.apply(utils.model_vars())
        ema_getter = functools.partial(utils.getter_ema, ema)
        post_ops.append(ema_op)

        train_op = tf.train.MomentumOptimizer(
            lr, 0.9, use_nesterov=True).minimize(
                loss_xe + w_match * loss_l2u + wd * loss_wd,
                colocate_gradients_with_ops=True)
        with tf.control_dependencies([train_op]):
            train_op = tf.group(*post_ops)

        return EasyDict(
            xt=xt_in,
            x=x_in,
            y=y_in,
            label=l_in,
            train_op=train_op,
            classify_raw=tf.nn.softmax(classifier(
                x_in, training=False)),  # No EMA, for debugging.
            classify_op=tf.nn.softmax(
                classifier(x_in, getter=ema_getter, training=False)))
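
layers.interleave itself is not shown in these snippets. Judging from how the UPDATE_OPS are collected, only the first interleaved batch updates the batch norm statistics, so the interleave swaps slices between the labeled batch and the unlabeled batches to put samples from every source into that first batch. A simplified numpy sketch of that idea, not the repository's exact implementation:

import numpy as np

def interleave(batches, batch):
    # batches: list of arrays whose leading dimension is `batch`.
    nu = len(batches) - 1
    offsets = np.linspace(0, batch, nu + 2).astype(int)
    chunks = [[b[offsets[p]:offsets[p + 1]] for p in range(nu + 1)]
              for b in batches]
    # Swap the i-th chunk of the first batch with the i-th chunk of batch i.
    for i in range(1, nu + 1):
        chunks[0][i], chunks[i][i] = chunks[i][i], chunks[0][i]
    return [np.concatenate(c, axis=0) for c in chunks]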
Code example #8
    def model(self, batch, lr, wd, ema, warmup_pos, consistency_weight,
              **kwargs):
        hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
        xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt')  # For training
        x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
        y_in = tf.placeholder(tf.float32, [batch, 2] + hwc, 'y')
        l_in = tf.placeholder(tf.int32, [batch], 'labels')
        l = tf.one_hot(l_in, self.nclass)

        warmup = tf.clip_by_value(
            tf.to_float(self.step) / (warmup_pos * (FLAGS.train_kimg << 10)),
            0, 1)
        lrate = tf.clip_by_value(
            tf.to_float(self.step) / (FLAGS.train_kimg << 10), 0, 1)
        lr *= tf.cos(lrate * (7 * np.pi) / (2 * 8))
        tf.summary.scalar('monitors/lr', lr)

        classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
        logits_x = classifier(xt_in, training=True)
        post_ops = tf.get_collection(
            tf.GraphKeys.UPDATE_OPS
        )  # Take only first call to update batch norm.
        y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)
        y_1, y_2 = tf.split(y, 2)
        ema = tf.train.ExponentialMovingAverage(decay=ema)
        ema_op = ema.apply(utils.model_vars())
        ema_getter = functools.partial(utils.getter_ema, ema)
        logits_y = classifier(y_1, training=True, getter=ema_getter)
        logits_teacher = tf.stop_gradient(logits_y)
        logits_student = classifier(y_2, training=True)
        loss_mt = tf.reduce_mean(
            (tf.nn.softmax(logits_teacher) - tf.nn.softmax(logits_student))**2,
            -1)
        loss_mt = tf.reduce_mean(loss_mt)

        loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=l,
                                                          logits=logits_x)
        loss = tf.reduce_mean(loss)
        tf.summary.scalar('losses/xe', loss)
        tf.summary.scalar('losses/mt', loss_mt)

        # L2 regularization
        loss_wd = sum(
            tf.nn.l2_loss(v) for v in utils.model_vars('classify')
            if 'kernel' in v.name)
        tf.summary.scalar('losses/wd', loss_wd)

        post_ops.append(ema_op)
        train_op = tf.train.MomentumOptimizer(
            lr, 0.9, use_nesterov=True).minimize(
                loss + loss_mt * warmup * consistency_weight + wd * loss_wd,
                colocate_gradients_with_ops=True)
        with tf.control_dependencies([train_op]):
            train_op = tf.group(*post_ops)

        return EasyDict(
            xt=xt_in,
            x=x_in,
            y=y_in,
            label=l_in,
            train_op=train_op,
            classify_raw=tf.nn.softmax(classifier(
                x_in, training=False)),  # No EMA, for debugging.
            classify_op=tf.nn.softmax(
                classifier(x_in, getter=ema_getter, training=False)))
Code example #9
    def model(self,
              batch,
              lr,
              wd,
              ema,
              beta,
              w_match,
              warmup_kimg=1024,
              nu=2,
              mixmode='xxy.yxy',
              dbuf=128,
              **kwargs):
        # height, width, colors
        hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
        # labeled data [batch,32,32,3]
        xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt')  # For training
        # labeled data [?,32,32,3]
        x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
        # unlabeled data [batch,2,32,32,3]; each unlabeled sample gets two augmented views (nu=2)
        y_in = tf.placeholder(tf.float32, [batch, nu] + hwc, 'y')
        l_in = tf.placeholder(tf.int32, [batch], 'labels')
        # Weight decay (scaled by the learning rate) to limit overfitting
        wd *= lr
        # Ramp w_match up to its full value over the warmup phase of training
        w_match *= tf.clip_by_value(
            tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)
        augment = MixMode(mixmode)
        classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits

        # Moving average of the current estimated label distribution
        p_model = layers.PMovingAverage('p_model', self.nclass, dbuf)
        p_target = layers.PMovingAverage(
            'p_target', self.nclass,
            dbuf)  # Rectified distribution (only for plotting)

        # Known (or inferred) true unlabeled distribution
        p_data = layers.PData(self.dataset)

        # Reshape y_in from [batch, nu] + hwc to [nu * batch] + hwc
        y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)
        guess = self.guess_label(tf.split(y, nu), classifier, T=0.5, **kwargs)
        # ly: guessed targets for the unlabeled data, treated as ground-truth labels
        ly = tf.stop_gradient(guess.p_target)
        lx = tf.one_hot(l_in, self.nclass)
        xy, labels_xy = augment([xt_in] + tf.split(y, nu), [lx] + [ly] * nu,
                                [beta, beta])
        x, y = xy[0], xy[1:]
        labels_x, labels_y = labels_xy[0], tf.concat(labels_xy[1:], 0)
        del xy, labels_xy

        batches = layers.interleave([x] + y, batch)
        skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        logits = [classifier(batches[0], training=True)]
        post_ops = [
            v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            if v not in skip_ops
        ]
        for batchi in batches[1:]:
            logits.append(classifier(batchi, training=True))
        logits = layers.interleave(logits, batch)
        logits_x = logits[0]
        logits_y = tf.concat(logits[1:], 0)

        loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x,
                                                             logits=logits_x)
        loss_xe = tf.reduce_mean(loss_xe)
        loss_l2u = tf.square(labels_y - tf.nn.softmax(logits_y))
        loss_l2u = tf.reduce_mean(loss_l2u)
        tf.summary.scalar('losses/xe', loss_xe)
        tf.summary.scalar('losses/l2u', loss_l2u)
        self.distribution_summary(p_data(), p_model(), p_target())

        ema = tf.train.ExponentialMovingAverage(decay=ema)
        ema_op = ema.apply(utils.model_vars())
        ema_getter = functools.partial(utils.getter_ema, ema)
        post_ops.extend([
            ema_op,
            p_model.update(guess.p_model),
            p_target.update(guess.p_target)
        ])
        if p_data.has_update:
            post_ops.append(p_data.update(lx))
        post_ops.extend([
            tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify')
            if 'kernel' in v.name
        ])

        train_op = tf.train.AdamOptimizer(lr).minimize(
            loss_xe + w_match * loss_l2u, colocate_gradients_with_ops=True)
        with tf.control_dependencies([train_op]):
            train_op = tf.group(*post_ops)

        return EasyDict(
            xt=xt_in,
            x=x_in,
            y=y_in,
            label=l_in,
            train_op=train_op,
            classify_raw=tf.nn.softmax(classifier(
                x_in, training=False)),  # No EMA, for debugging.
            classify_op=tf.nn.softmax(
                classifier(x_in, getter=ema_getter, training=False)))
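
Because this variant uses Adam, weight decay is applied outside the optimizer by shrinking every kernel after the training step; note that wd was already multiplied by lr near the top of the method. A one-line numpy-free sketch with assumed numbers:

lr, wd = 0.002, 0.02                # assumed values for illustration
effective_wd = wd * lr              # mirrors `wd *= lr` above
kernel = 1.0
kernel *= 1 - effective_wd          # mirrors tf.assign(v, v * (1 - wd))
print(kernel)                       # 0.99996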
Code example #10
    def model(self, batch, lr, wd, ema, warmup_pos, consistency_weight,
              threshold, **kwargs):
        hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
        xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt')  # For training
        x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
        y_in = tf.placeholder(tf.float32, [batch] + hwc, 'y')
        l_in = tf.placeholder(tf.int32, [batch], 'labels')
        l = tf.one_hot(l_in, self.nclass)
        wd *= lr
        warmup = tf.clip_by_value(
            tf.to_float(self.step) / (warmup_pos * (FLAGS.train_kimg << 10)),
            0, 1)

        classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
        logits_x = classifier(xt_in, training=True)
        post_ops = tf.get_collection(
            tf.GraphKeys.UPDATE_OPS
        )  # Take only first call to update batch norm.
        logits_y = classifier(y_in, training=True)
        # Get the pseudo-label loss
        loss_pl = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.argmax(logits_y, axis=-1), logits=logits_y)
        # Masks denoting which data points have high-confidence predictions
        greater_than_thresh = tf.reduce_any(
            tf.greater(tf.nn.softmax(logits_y), threshold),
            axis=-1,
            keepdims=True,
        )
        greater_than_thresh = tf.cast(greater_than_thresh, loss_pl.dtype)
        # Only enforce the loss when the model is confident
        loss_pl *= greater_than_thresh
        # Note that we also average over examples without confident outputs;
        # this is consistent with the realistic evaluation codebase
        loss_pl = tf.reduce_mean(loss_pl)

        loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=l,
                                                          logits=logits_x)
        loss = tf.reduce_mean(loss)
        tf.summary.scalar('losses/xe', loss)
        tf.summary.scalar('losses/pl', loss_pl)

        ema = tf.train.ExponentialMovingAverage(decay=ema)
        ema_op = ema.apply(utils.model_vars())
        ema_getter = functools.partial(utils.getter_ema, ema)
        post_ops.append(ema_op)
        post_ops.extend([
            tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify')
            if 'kernel' in v.name
        ])

        train_op = tf.train.AdamOptimizer(lr).minimize(
            loss + loss_pl * warmup * consistency_weight,
            colocate_gradients_with_ops=True)
        with tf.control_dependencies([train_op]):
            train_op = tf.group(*post_ops)

        return EasyDict(
            xt=xt_in,
            x=x_in,
            y=y_in,
            label=l_in,
            train_op=train_op,
            classify_raw=tf.nn.softmax(classifier(
                x_in, training=False)),  # No EMA, for debugging.
            classify_op=tf.nn.softmax(
                classifier(x_in, getter=ema_getter, training=False)))
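
A minimal numpy sketch of the confidence mask in this last example: only unlabeled examples whose largest softmax probability exceeds the threshold contribute to the pseudo-label loss (the logits and the 0.95 threshold below are made-up values):

import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

logits_y = np.array([[4.0, 0.5, 0.1],   # confident prediction
                     [1.0, 0.9, 0.8]])  # low-confidence prediction
probs = softmax(logits_y)
mask = (probs > 0.95).any(axis=-1).astype(np.float32)
print(mask)                             # [1. 0.]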