Example #1
def classifier():
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    T = tb.utils.TensorDict(dict(
        sess = tf.Session(config=config),
        src_x = placeholder((None, 32, 32, 3),  name='source_x'),
        src_y = placeholder((None, 10),         name='source_y'),
        test_x = placeholder((None, 32, 32, 3), name='test_x'),
        test_y = placeholder((None, 10),        name='test_y'),
        phase = placeholder((), tf.bool,        name='phase')
    ))

    # Classification
    src_y = des.classifier(T.src_x, T.phase, internal_update=True)
    loss_class = tf.reduce_mean(softmax_xent(labels=T.src_y, logits=src_y))
    src_acc = basic_accuracy(T.src_y, src_y)

    # Evaluation (non-EMA)
    test_y = des.classifier(T.test_x, phase=False, reuse=True)
    test_acc = basic_accuracy(T.test_y, test_y)
    fn_test_acc = tb.function(T.sess, [T.test_x, T.test_y], test_acc)

    # Evaluation (EMA)
    ema = tf.train.ExponentialMovingAverage(decay=0.998)
    var_class = tf.get_collection('trainable_variables', 'class')
    ema_op = ema.apply(var_class)
    ema_y = des.classifier(T.test_x, phase=False, reuse=True, getter=get_getter(ema))
    ema_acc = basic_accuracy(T.test_y, ema_y)
    fn_ema_acc = tb.function(T.sess, [T.test_x, T.test_y], ema_acc)

    # Optimizer
    loss_main = loss_class
    var_main = var_class
    train_main = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_main, var_list=var_main)
    train_main = tf.group(train_main, ema_op)

    # Summarizations
    summary_main = [
        tf.summary.scalar('class/loss_class', loss_class),
        tf.summary.scalar('acc/src_acc', src_acc),
    ]
    summary_main = tf.summary.merge(summary_main)

    # Saved ops
    c = tf.constant
    T.ops_print = [
        c('class'), loss_class,
        c('src'), src_acc,
    ]

    T.ops_main = [summary_main, train_main]
    T.fn_test_acc = fn_test_acc
    T.fn_ema_acc = fn_ema_acc

    return T
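
Both evaluation paths above rely on `getter=get_getter(ema)` so that the reused classifier reads the exponential-moving-average shadow variables rather than the live weights. The helper itself is not part of this listing; a minimal sketch of the usual TF1 custom-getter idiom it names (assuming `ema` has already been `apply()`d to the classifier variables) is:

def get_getter(ema):
    # Custom variable getter: return the EMA shadow variable when one
    # exists, otherwise fall back to the raw variable.
    def ema_getter(getter, name, *args, **kwargs):
        var = getter(name, *args, **kwargs)
        ema_var = ema.average(var)
        return ema_var if ema_var is not None else var
    return ema_getter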
Example #2
def classifier():
    T = tb.utils.TensorDict(
        dict(sess=tf.Session(config=tb.growth_config()),
             src_x=placeholder((None, 32, 32, 3)),
             src_y=placeholder((None, 10)),
             trg_x=placeholder((None, 32, 32, 3)),
             trg_y=placeholder((None, 10)),
             test_x=placeholder((None, 32, 32, 3)),
             test_y=placeholder((None, 10)),
             phase=placeholder((), tf.bool)))

    # Supervised and conditional entropy minimization
    src_y = net.classifier(T.src_x, phase=True, internal_update=False)
    trg_y = net.classifier(T.trg_x,
                           phase=True,
                           internal_update=True,
                           reuse=True)

    loss_class = tf.reduce_mean(softmax_xent(labels=T.src_y, logits=src_y))

    # Evaluation (non-EMA)
    test_y = net.classifier(T.test_x, phase=False, scope='class', reuse=True)

    # Evaluation (EMA)
    ema = tf.train.ExponentialMovingAverage(decay=0.998)
    ema_op = ema.apply(tf.get_collection('trainable_variables', 'class/'))
    T.ema_y = net.classifier(T.test_x,
                             phase=False,
                             reuse=True,
                             getter=get_getter(ema))

    src_acc = basic_accuracy(T.src_y, src_y)
    trg_acc = basic_accuracy(T.trg_y, trg_y)
    ema_acc = basic_accuracy(T.test_y, T.ema_y)
    fn_ema_acc = tb.function(T.sess, [T.test_x, T.test_y], ema_acc)

    # Optimizer
    loss_main = loss_class
    var_main = tf.get_collection('trainable_variables', 'class')
    train_main = tf.train.AdamOptimizer(args.lr,
                                        0.5).minimize(loss_main,
                                                      var_list=var_main)
    train_main = tf.group(train_main, ema_op)

    # Summarizations
    summary_main = [
        tf.summary.scalar('class/loss_class', loss_class),
        tf.summary.scalar('acc/src_acc', src_acc),
        tf.summary.scalar('acc/trg_acc', trg_acc)
    ]
    summary_main = tf.summary.merge(summary_main)

    # Saved ops
    c = tf.constant
    T.ops_print = [c('class'), loss_class]
    T.ops_main = [summary_main, train_main]
    T.fn_ema_acc = fn_ema_acc

    return T
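
Once built, the returned `TensorDict` is driven by an outer loop that feeds the placeholders and runs the saved ops. A hypothetical driver sketch, with the data iterators (`src`, `trg`), batch size, iteration count, and test arrays all assumptions rather than part of this listing:

T = classifier()
T.sess.run(tf.global_variables_initializer())
for i in range(40000):
    sx, sy = src.next_batch(64)   # hypothetical iterators
    tx, ty = trg.next_batch(64)
    feed = {T.src_x: sx, T.src_y: sy, T.trg_x: tx, T.trg_y: ty}
    summary, _ = T.sess.run(T.ops_main, feed)  # merged summaries + train step
    if i % 1000 == 0:
        print(T.fn_ema_acc(test_images, test_labels))  # EMA test accuracy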
Example #3
    def __init__(self, domain, datasets, M):
        print("Constructing pseudodata")
        sys.stdout.flush()

        cast = domain not in {'mnist28', 'mnist32', 'mnistm28', 'mnistm32'}
        print "Casting:", cast
        labeler = tb.function(M.sess, [M.test_x], M.back_y)

        self.train = Data(datasets.train.images, labeler=labeler, cast=cast)
        self.test = Data(datasets.test.images, labeler=labeler, cast=cast)
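
`tb.function` (tensorbayes) is what turns the graph output `M.back_y` into the plain `labeler` callable used by `Data`. Roughly, and as a sketch only, it closes over a session and a list of input placeholders:

def function(sess, inputs, outputs):
    # Bind placeholders to positional arguments and run `outputs`
    # in the captured session.
    def fn(*args):
        return sess.run(outputs, feed_dict=dict(zip(inputs, args)))
    return fn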
Example #4
def dann_embed():
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # config.gpu_options.per_process_gpu_memory_fraction = 0.45
    T = tb.utils.TensorDict(dict(
        sess = tf.Session(config=config),
        src_x = placeholder((None, H, W, 3)),
        src_y = placeholder((None, Y)),
        trg_x = placeholder((None, H, W, 3)),
        trg_y = placeholder((None, Y)),
        fake_z = placeholder((None, 100)),
        fake_y = placeholder((None, Y)),
        test_x = placeholder((None, H, W, 3)),
        test_y = placeholder((None, Y)),
        phase = placeholder((), tf.bool)
    ))

    # Schedules
    start, end = args.pivot - 1, args.pivot
    global_step = tf.Variable(0., trainable=False)
    # Ramp down dw
    ramp_dw = conditional_ramp_weight(args.dwdn, global_step, args.dw, 0, start, end)
    # Ramp down src
    ramp_class = conditional_ramp_weight(args.dn, global_step, 1, args.dcval, start, end)
    ramp_sbw = conditional_ramp_weight(args.dn, global_step, args.sbw, 0, start, end)
    # Ramp up trg (never more than src)
    ramp_cw = conditional_ramp_weight(args.up, global_step, args.cw, args.uval, start, end)
    ramp_gw = conditional_ramp_weight(args.up, global_step, args.gw, args.uval, start, end)
    ramp_gbw = conditional_ramp_weight(args.up, global_step, args.gbw, args.uval, start, end)
    ramp_tbw = conditional_ramp_weight(args.up, global_step, args.tbw, args.uval, start, end)

    # Supervised and conditional entropy minimization
    src_e = des.classifier(T.src_x, T.phase, enc_phase=1, trim=args.trim, scope='class', internal_update=False)
    trg_e = des.classifier(T.trg_x, T.phase, enc_phase=1, trim=args.trim, scope='class', reuse=True, internal_update=True)
    src_y = des.classifier(src_e, T.phase, enc_phase=0, trim=args.trim, scope='class', internal_update=False)
    trg_y = des.classifier(trg_e, T.phase, enc_phase=0, trim=args.trim, scope='class', reuse=True, internal_update=True)

    loss_class = tf.reduce_mean(softmax_xent(labels=T.src_y, logits=src_y))
    loss_cent = tf.reduce_mean(softmax_xent_two(labels=trg_y, logits=trg_y))

    # Image generation
    if args.gw > 0:
        fake_x = des.generator(T.fake_z, T.fake_y, T.phase)
        fake_logit = des.discriminator(fake_x, T.phase)
        real_logit = des.discriminator(T.trg_x, T.phase, reuse=True)
        fake_e = des.classifier(fake_x, T.phase, enc_phase=1, trim=args.trim, scope='class', reuse=True)
        fake_y = des.classifier(fake_e, T.phase, enc_phase=0, trim=args.trim, scope='class', reuse=True)

        loss_gdisc = 0.5 * tf.reduce_mean(
            sigmoid_xent(labels=tf.ones_like(real_logit), logits=real_logit) +
            sigmoid_xent(labels=tf.zeros_like(fake_logit), logits=fake_logit))
        loss_gen = tf.reduce_mean(sigmoid_xent(labels=tf.ones_like(fake_logit), logits=fake_logit))
        loss_info = tf.reduce_mean(softmax_xent(labels=T.fake_y, logits=fake_y))

    else:
        loss_gdisc = constant(0)
        loss_gen = constant(0)
        loss_info = constant(0)

    # Domain confusion
    if args.dw > 0 and args.phase == 0:
        real_logit = des.feature_discriminator(src_e, T.phase)
        fake_logit = des.feature_discriminator(trg_e, T.phase, reuse=True)

        loss_ddisc = 0.5 * tf.reduce_mean(
            sigmoid_xent(labels=tf.ones_like(real_logit), logits=real_logit) +
            sigmoid_xent(labels=tf.zeros_like(fake_logit), logits=fake_logit))
        loss_domain = 0.5 * tf.reduce_mean(
            sigmoid_xent(labels=tf.zeros_like(real_logit), logits=real_logit) +
            sigmoid_xent(labels=tf.ones_like(fake_logit), logits=fake_logit))

    else:
        loss_ddisc = constant(0)
        loss_domain = constant(0)

    # Smoothing
    loss_t_ball = constant(0) if args.tbw == 0 else smoothing_loss(T.trg_x, trg_y, T.phase)
    loss_s_ball = constant(0) if args.sbw == 0 or args.phase == 1 else smoothing_loss(T.src_x, src_y, T.phase)
    loss_g_ball = constant(0) if args.gbw == 0 else smoothing_loss(fake_x, fake_y, T.phase)

    loss_t_emb = constant(0) if args.te == 0 else smoothing_loss(T.trg_x, trg_e, T.phase, is_embedding=True)
    loss_s_emb = constant(0) if args.se == 0 else smoothing_loss(T.src_x, src_e, T.phase, is_embedding=True)

    # Evaluation (non-EMA)
    test_y = des.classifier(T.test_x, False, enc_phase=1, trim=0, scope='class', reuse=True)

    # Evaluation (EMA)
    ema = tf.train.ExponentialMovingAverage(decay=0.998)
    var_class = tf.get_collection('trainable_variables', 'class/')
    ema_op = ema.apply(var_class)
    T.ema_e = des.classifier(T.test_x, False, enc_phase=1, trim=args.trim, scope='class', reuse=True, getter=get_getter(ema))
    ema_y = des.classifier(T.ema_e, False, enc_phase=0, trim=args.trim, scope='class', reuse=True, getter=get_getter(ema))

    # Back-up (teacher) model
    back_y = des.classifier(T.test_x, False, enc_phase=1, trim=0, scope='back')
    var_main = tf.get_collection('variables', 'class/(?!.*ExponentialMovingAverage:0)')
    var_back = tf.get_collection('variables', 'back/(?!.*ExponentialMovingAverage:0)')
    back_assigns = []
    init_assigns = []
    for b, m in zip(var_back, var_main):
        ave = ema.average(m)
        target = ave if ave else m
        back_assigns += [tf.assign(b, target)]
        init_assigns += [tf.assign(m, target)]
        # print "Assign {} -> {}, {}".format(target.name, b.name, m.name)
    back_update = tf.group(*back_assigns)
    init_update = tf.group(*init_assigns)

    src_acc = basic_accuracy(T.src_y, src_y)
    trg_acc = basic_accuracy(T.trg_y, trg_y)
    test_acc = basic_accuracy(T.test_y, test_y)
    ema_acc = basic_accuracy(T.test_y, ema_y)
    fn_test_acc = tb.function(T.sess, [T.test_x, T.test_y], test_acc)
    fn_ema_acc = tb.function(T.sess, [T.test_x, T.test_y], ema_acc)

    # Optimizer
    loss_main = (ramp_class * loss_class +
                 ramp_dw * loss_domain +
                 ramp_cw * loss_cent +
                 ramp_tbw * loss_t_ball +
                 ramp_gbw * loss_g_ball +
                 ramp_sbw * loss_s_ball +
                 args.te * loss_t_emb +
                 args.se * loss_s_emb +
                 ramp_gw * loss_gen +
                 ramp_gw * loss_info)
    var_main = tf.get_collection('trainable_variables', 'class')
    var_main += tf.get_collection('trainable_variables', 'gen')
    train_main = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_main,
                                                               var_list=var_main,
                                                               global_step=global_step)
    train_main = tf.group(train_main, ema_op)

    if (args.dw > 0 and args.phase == 0) or args.gw > 0:
        loss_disc = loss_ddisc + loss_gdisc
        var_disc = tf.get_collection('trainable_variables', 'disc')
        train_disc = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_disc,
                                                                   var_list=var_disc)
    else:
        train_disc = constant(0)

    # Summarizations
#    embedding = tf.Variable(tf.zeros([1000,12800]),name='embedding')
#    embedding = tf.reshape(trg_e[:1000], [-1,12800])

    summary_disc = [tf.summary.scalar('domain/loss_ddisc', loss_ddisc),
                    tf.summary.scalar('gen/loss_gdisc', loss_gdisc)]

    summary_main = [tf.summary.scalar('class/loss_class', loss_class),
                    tf.summary.scalar('class/loss_cent', loss_cent),
                    tf.summary.scalar('domain/loss_domain', loss_domain),
                    tf.summary.scalar('lipschitz/loss_t_ball', loss_t_ball),
                    tf.summary.scalar('lipschitz/loss_g_ball', loss_g_ball),
                    tf.summary.scalar('lipschitz/loss_s_ball', loss_s_ball),
                    tf.summary.scalar('embedding/loss_t_emb', loss_t_emb),
                    tf.summary.scalar('embedding/loss_s_emb', loss_s_emb),
                    tf.summary.scalar('gen/loss_gen', loss_gen),
                    tf.summary.scalar('gen/loss_info', loss_info),
                    tf.summary.scalar('ramp/ramp_class', ramp_class),
                    tf.summary.scalar('ramp/ramp_dw', ramp_dw),
                    tf.summary.scalar('ramp/ramp_cw', ramp_cw),
                    tf.summary.scalar('ramp/ramp_gw', ramp_gw),
                    tf.summary.scalar('ramp/ramp_tbw', ramp_tbw),
                    tf.summary.scalar('ramp/ramp_sbw', ramp_sbw),
                    tf.summary.scalar('ramp/ramp_gbw', ramp_gbw),
                    tf.summary.scalar('acc/src_acc', src_acc),
                    tf.summary.scalar('acc/trg_acc', trg_acc)]

    summary_disc = tf.summary.merge(summary_disc)
    summary_main = tf.summary.merge(summary_main)

    # Saved ops
    c = tf.constant
    T.ops_print = [c('ddisc'), loss_ddisc,
                   c('domain'), loss_domain,
                   c('gdisc'), loss_gdisc,
                   c('gen'), loss_gen,
                   c('info'), loss_info,
                   c('class'), loss_class,
                   c('cent'), loss_cent,
                   c('t_ball'), loss_t_ball,
                   c('g_ball'), loss_g_ball,
                   c('s_ball'), loss_s_ball,
                   c('t_emb'), loss_t_emb,
                   c('s_emb'), loss_s_emb,
                   c('src'), src_acc,
                   c('trg'), trg_acc]

    T.ops_disc = [summary_disc, train_disc]
    T.ops_main = [summary_main, train_main]
    T.fn_test_acc = fn_test_acc
    T.fn_ema_acc = fn_ema_acc
    T.back_y = tf.nn.softmax(back_y)  # Access to backed-up eval model softmax
    T.back_update = back_update       # Update op eval -> backed-up eval model
    T.init_update = init_update       # Update op eval -> student eval model
    T.global_step = global_step
    T.ramp_class = ramp_class
    if args.gw > 0:
        summary_image = tf.summary.image('image/gen', generate_img())
        T.ops_image = summary_image

    return T
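
`conditional_ramp_weight` is defined elsewhere in the codebase. From the call sites above ("ramp down" from a value to 0, "ramp up" between `start` and `end`), a plausible reading is a flag-gated linear ramp; the sketch below is an assumption about its semantics (including which of the two values is the pre-pivot one), not its actual definition:

def conditional_ramp_weight(enabled, step, init, final, start, end):
    # Assumed semantics: hold `init` when the flag is off; otherwise move
    # linearly from `init` (at step `start`) to `final` (at step `end`),
    # clamping outside that window.
    if not enabled:
        return tf.constant(float(init))
    frac = tf.clip_by_value((step - start) / float(end - start), 0., 1.)
    return (1. - frac) * float(init) + frac * float(final)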
Example #5
def vae():
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    T = tb.utils.TensorDict(dict(
        sess = tf.Session(config=config),
        src_x = placeholder((None, 32, 32, 3),  name='source_x'),
        src_y = placeholder((None, 10),         name='source_y'),
        trg_x = placeholder((None, 32, 32, 3),  name='target_x'),
        trg_y = placeholder((None, 10),         name='target_y'),
        test_x = placeholder((None, 32, 32, 3), name='test_x'),
        test_y = placeholder((None, 10),        name='test_y'),
        fake_z = placeholder((None, 100),       name='fake_z'),
        fake_y = placeholder((None, 10),        name='fake_y'),
        tau = placeholder((),                   name='tau'),
        phase = placeholder((), tf.bool,        name='phase'),
    ))

    if args.gw > 0:
        # Variational inference
        y_logit = des.classifier(T.trg_x, T.phase, internal_update=True)
        y = gumbel_softmax(y_logit, T.tau)
        z, z_post = des.encoder(T.trg_x, y, T.phase, internal_update=True)

        # Generation
        x = des.generator(z, y, T.phase, internal_update=True)

        # Loss
        z_prior = (0., 1.)
        kl_z = tf.reduce_mean(log_normal(z, *z_post) - log_normal(z, *z_prior))

        y_q = tf.nn.softmax(y_logit)
        log_y_q = tf.nn.log_softmax(y_logit)
        kl_y = tf.reduce_mean(tf.reduce_sum(y_q * (log_y_q - tf.log(0.1)), axis=1))

        loss_kl = kl_z + kl_y
        loss_rec = args.rw * tf.reduce_mean(tf.reduce_sum(tf.square(T.trg_x - x), axis=[1,2,3]))
        loss_gen = loss_rec + loss_kl
        trg_acc = basic_accuracy(T.trg_y, y_logit)

    else:
        loss_kl = constant(0)
        loss_rec = constant(0)
        loss_gen = constant(0)
        trg_acc = constant(0)

    # Posterior regularization (labeled classification)
    src_y = des.classifier(T.src_x, T.phase, reuse=True)
    loss_class = tf.reduce_mean(softmax_xent(labels=T.src_y, logits=src_y))
    src_acc = basic_accuracy(T.src_y, src_y)

    # Evaluation (classification)
    test_y = des.classifier(T.test_x, phase=False, reuse=True)
    test_acc = basic_accuracy(T.test_y, test_y)
    fn_test_acc = tb.function(T.sess, [T.test_x, T.test_y], test_acc)

    # Evaluation (generation)
    if args.gw > 0:
        fake_x = des.generator(T.fake_z, T.fake_y, phase=False, reuse=True)
        fn_fake_x = tb.function(T.sess, [T.fake_z, T.fake_y], fake_x)

    # Optimizer
    var_main = tf.get_collection('trainable_variables', 'gen/')
    var_main += tf.get_collection('trainable_variables', 'enc/')
    var_main += tf.get_collection('trainable_variables', 'class/')
    loss_main = args.gw * loss_gen + loss_class
    train_main = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_main, var_list=var_main)

    # Summarizations
    summary_main = [
        tf.summary.scalar('gen/loss_gen', loss_gen),
        tf.summary.scalar('gen/loss_rec', loss_rec),
        tf.summary.scalar('gen/loss_kl', loss_kl),
        tf.summary.scalar('class/loss_class', loss_class),
        tf.summary.scalar('acc/src_acc', src_acc),
        tf.summary.scalar('acc/trg_acc', trg_acc),
    ]
    summary_main = tf.summary.merge(summary_main)

    if args.gw > 0:
        summary_image = tf.summary.image('image/gen', generate_img())

    # Saved ops
    c = tf.constant
    T.ops_print = [
        c('tau'), tf.identity(T.tau),
        c('gen'), loss_gen,
        c('rec'), loss_rec,
        c('kl'), loss_kl,
        c('class'), loss_class,
    ]

    T.ops_main = [summary_main, train_main]
    T.fn_test_acc = fn_test_acc

    if args.gw > 0:
        T.fn_fake_x = fn_fake_x
        T.ops_image = summary_image

    return T
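
The `gumbel_softmax` used for the discrete latent `y` is the standard Gumbel-softmax construction (Jang et al. / Maddison et al.), with the temperature fed through `T.tau`. A minimal TF1 sketch:

def gumbel_softmax(logits, tau, eps=1e-20):
    # Add Gumbel(0, 1) noise to the logits, then soften with a
    # temperature-tau softmax; low tau approaches one-hot samples.
    u = tf.random_uniform(tf.shape(logits), minval=0., maxval=1.)
    gumbel = -tf.log(-tf.log(u + eps) + eps)
    return tf.nn.softmax((logits + gumbel) / tau)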
Example #6
def dirtt():
    T = tb.utils.TensorDict(dict(
        sess = tf.Session(config=tb.growth_config()),
        src_x = placeholder((None, 32, 32, 3)),
        src_y = placeholder((None, args.Y)),
        trg_x = placeholder((None, 32, 32, 3)),
        trg_y = placeholder((None, args.Y)),
        test_x = placeholder((None, 32, 32, 3)),
        test_y = placeholder((None, args.Y)),
    ))
    # Supervised and conditional entropy minimization
    src_e = nn.classifier(T.src_x, phase=True, enc_phase=1, trim=args.trim)
    trg_e = nn.classifier(T.trg_x, phase=True, enc_phase=1, trim=args.trim, reuse=True, internal_update=True)
    src_p = nn.classifier(src_e, phase=True, enc_phase=0, trim=args.trim)
    trg_p = nn.classifier(trg_e, phase=True, enc_phase=0, trim=args.trim, reuse=True, internal_update=True)

    loss_src_class = tf.reduce_mean(softmax_xent(labels=T.src_y, logits=src_p))
    loss_trg_cent = tf.reduce_mean(softmax_xent_two(labels=trg_p, logits=trg_p))

    # Domain confusion
    if args.dw > 0 and args.dirt == 0:
        real_logit = nn.feature_discriminator(src_e, phase=True)
        fake_logit = nn.feature_discriminator(trg_e, phase=True, reuse=True)

        loss_disc = 0.5 * tf.reduce_mean(
            sigmoid_xent(labels=tf.ones_like(real_logit), logits=real_logit) +
            sigmoid_xent(labels=tf.zeros_like(fake_logit), logits=fake_logit))
        loss_domain = 0.5 * tf.reduce_mean(
            sigmoid_xent(labels=tf.zeros_like(real_logit), logits=real_logit) +
            sigmoid_xent(labels=tf.ones_like(fake_logit), logits=fake_logit))

    else:
        loss_disc = constant(0)
        loss_domain = constant(0)

    # Virtual adversarial training (turn off src in non-VADA phase)
    loss_src_vat = vat_loss(T.src_x, src_p, nn.classifier) if args.sw > 0 and args.dirt == 0 else constant(0)
    loss_trg_vat = vat_loss(T.trg_x, trg_p, nn.classifier) if args.tw > 0 else constant(0)

    # Evaluation (EMA)
    ema = tf.train.ExponentialMovingAverage(decay=0.998)
    var_class = tf.get_collection('trainable_variables', 'class/')
    ema_op = ema.apply(var_class)
    ema_p = nn.classifier(T.test_x, phase=False, reuse=True, getter=tb.tfutils.get_getter(ema))

    # Teacher model (a back-up of EMA model)
    teacher_p = nn.classifier(T.test_x, phase=False, scope='teacher')
    var_main = tf.get_collection('variables', 'class/(?!.*ExponentialMovingAverage:0)')
    var_teacher = tf.get_collection('variables', 'teacher/(?!.*ExponentialMovingAverage:0)')
    teacher_assign_ops = []
    for t, m in zip(var_teacher, var_main):
        ave = ema.average(m)
        ave = ave if ave else m
        teacher_assign_ops += [tf.assign(t, ave)]
    update_teacher = tf.group(*teacher_assign_ops)
    teacher = tb.function(T.sess, [T.test_x], tf.nn.softmax(teacher_p))

    # Accuracies
    src_acc = basic_accuracy(T.src_y, src_p)
    trg_acc = basic_accuracy(T.trg_y, trg_p)
    ema_acc = basic_accuracy(T.test_y, ema_p)
    fn_ema_acc = tb.function(T.sess, [T.test_x, T.test_y], ema_acc)

    # Optimizer
    dw = constant(args.dw) if args.dirt == 0 else constant(0)
    cw = constant(1)       if args.dirt == 0 else constant(args.bw)
    sw = constant(args.sw) if args.dirt == 0 else constant(0)
    tw = constant(args.tw)
    loss_main = (dw * loss_domain +
                 cw * loss_src_class +
                 sw * loss_src_vat +
                 tw * loss_trg_cent +
                 tw * loss_trg_vat)
    var_main = tf.get_collection('trainable_variables', 'class')
    train_main = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_main, var_list=var_main)
    train_main = tf.group(train_main, ema_op)

    if args.dw > 0 and args.dirt == 0:
        var_disc = tf.get_collection('trainable_variables', 'disc')
        train_disc = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_disc, var_list=var_disc)
    else:
        train_disc = constant(0)

    # Summarizations
    summary_disc = [tf.summary.scalar('domain/loss_disc', loss_disc),]
    summary_main = [tf.summary.scalar('domain/loss_domain', loss_domain),
                    tf.summary.scalar('class/loss_src_class', loss_src_class),
                    tf.summary.scalar('class/loss_trg_cent', loss_trg_cent),
                    tf.summary.scalar('lipschitz/loss_trg_vat', loss_trg_vat),
                    tf.summary.scalar('lipschitz/loss_src_vat', loss_src_vat),
                    tf.summary.scalar('hyper/dw', dw),
                    tf.summary.scalar('hyper/cw', cw),
                    tf.summary.scalar('hyper/sw', sw),
                    tf.summary.scalar('hyper/tw', tw),
                    tf.summary.scalar('acc/src_acc', src_acc),
                    tf.summary.scalar('acc/trg_acc', trg_acc)]

    # Merge summaries
    summary_disc = tf.summary.merge(summary_disc)
    summary_main = tf.summary.merge(summary_main)

    # Saved ops
    c = tf.constant
    T.ops_print = [c('disc'), loss_disc,
                   c('domain'), loss_domain,
                   c('class'), loss_src_class,
                   c('cent'), loss_trg_cent,
                   c('trg_vat'), loss_trg_vat,
                   c('src_vat'), loss_src_vat,
                   c('src'), src_acc,
                   c('trg'), trg_acc]
    T.ops_disc = [summary_disc, train_disc]
    T.ops_main = [summary_main, train_main]
    T.fn_ema_acc = fn_ema_acc
    T.teacher = teacher
    T.update_teacher = update_teacher

    return T
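
`vat_loss` is not shown in this listing. VADA/DIRT-T use one power-iteration step of virtual adversarial training; the sketch below follows that recipe, with the default `radius`, the `xi` scale, and the exact classifier call signature all assumptions:

def vat_loss(x, p_logit, classifier, radius=3.5, xi=1e-6):
    # Estimate the adversarial direction with one power iteration, then
    # penalize the divergence between clean and perturbed predictions.
    d = xi * tf.nn.l2_normalize(tf.random_normal(tf.shape(x)), axis=[1, 2, 3])
    p_d_logit = classifier(x + d, phase=True, reuse=True)
    dist = tf.reduce_mean(softmax_xent_two(
        labels=tf.stop_gradient(p_logit), logits=p_d_logit))
    grad = tf.gradients(dist, [d], aggregation_method=2)[0]
    r_adv = radius * tf.nn.l2_normalize(tf.stop_gradient(grad), axis=[1, 2, 3])
    p_adv_logit = classifier(x + r_adv, phase=True, reuse=True)
    return tf.reduce_mean(softmax_xent_two(
        labels=tf.stop_gradient(p_logit), logits=p_adv_logit))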
Example #7
def vae():
    T = tb.utils.TensorDict(
        dict(
            sess=tf.Session(config=tb.growth_config()),
            trg_x=placeholder((None, 32, 32, 3), name='target_x'),
            fake_z=placeholder((None, args.Z), name='fake_z'),
        ))

    # Inference
    z, z_post = nn.encoder(T.trg_x, phase=True, internal_update=True)

    # Generation
    x = nn.generator(z, phase=True, internal_update=True)

    # Loss
    loss_rec, loss_kl, loss_gen = vae_loss(x, T.trg_x, z, z_post)

    # Evaluation (embedding, reconstruction, loss)
    test_z, test_z_post = nn.encoder(T.trg_x, phase=False, reuse=True)
    test_x = nn.generator(test_z, phase=False, reuse=True)
    _, _, test_loss = vae_loss(test_x, T.trg_x, test_z, test_z_post)
    fn_embed = tb.function(T.sess, [T.trg_x], test_z_post)
    fn_recon = tb.function(T.sess, [T.trg_x], test_x)
    fn_loss = tb.function(T.sess, [T.trg_x], test_loss)

    # Evaluation (generation)
    fake_x = nn.generator(T.fake_z, phase=False, reuse=True)
    fn_generate = tb.function(T.sess, [T.fake_z], fake_x)

    # Optimizer
    var_main = tf.get_collection('trainable_variables', 'gen/')
    var_main += tf.get_collection('trainable_variables', 'enc/')
    loss_main = loss_gen
    train_main = tf.train.AdamOptimizer(args.lr,
                                        0.5).minimize(loss_main,
                                                      var_list=var_main)

    # Summarizations
    summary_main = [
        tf.summary.scalar('gen/loss_gen', loss_gen),
        tf.summary.scalar('gen/loss_kl', loss_kl),
        tf.summary.scalar('gen/loss_rec', loss_rec),
    ]
    summary_image = [tf.summary.image('gen/gen', generate_image(nn.generator))]

    # Merge summaries
    summary_main = tf.summary.merge(summary_main)
    summary_image = tf.summary.merge(summary_image)

    # Saved ops
    c = tf.constant
    T.ops_print = [
        c('gen'),
        loss_gen,
        c('kl'),
        loss_kl,
        c('rec'),
        loss_rec,
    ]

    T.ops_main = [summary_main, train_main]
    T.ops_image = summary_image
    T.fn_embed = fn_embed
    T.fn_recon = fn_recon
    T.fn_loss = fn_loss
    T.fn_generate = fn_generate

    return T
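
`vae_loss` is shared by the training and evaluation paths above. A sketch consistent with the inline VAE loss of Example #5 (the `(mean, var)` shape of `z_post` and the use of `log_normal` are assumptions carried over from there):

def vae_loss(x_hat, x, z, z_post):
    # Pixel-wise squared-error reconstruction plus a single-sample
    # Monte Carlo estimate of KL(q(z|x) || N(0, 1)).
    loss_rec = tf.reduce_mean(tf.reduce_sum(tf.square(x - x_hat), axis=[1, 2, 3]))
    loss_kl = tf.reduce_mean(log_normal(z, *z_post) - log_normal(z, 0., 1.))
    return loss_rec, loss_kl, loss_rec + loss_kl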
Example #8
    def _build_model(self):
        self.x_src_lst = []
        self.y_src_lst = []
        for i in range(self.data_loader.num_src_domain):
            x_src = tf.placeholder(dtype=tf.float32, shape=tuple([None]) + self.dim_src,
                                   name='x_src_{}_input'.format(i))
            y_src = tf.placeholder(dtype=tf.float32, shape=(None, self.num_classes),
                                   name='y_src_{}_input'.format(i))
            self.x_src_lst.append(x_src)
            self.y_src_lst.append(y_src)

        self.x_trg = tf.placeholder(dtype=tf.float32, shape=tuple([None]) + self.dim_trg, name='x_trg_input')
        self.y_trg = tf.placeholder(dtype=tf.float32, shape=(None, self.num_classes),
                                    name='y_trg_input')
        self.y_src_domain = tf.placeholder(dtype=tf.float32, shape=(None, self.data_loader.num_src_domain),
                                           name='y_src_domain_input')

        T = tb.utils.TensorDict(dict(
            x_tmp=tf.placeholder(dtype=tf.float32, shape=tuple([None]) + self.dim_src),
            y_tmp=tf.placeholder(dtype=tf.float32, shape=(None, self.num_classes))
        ))

        self.is_training = tf.placeholder(tf.bool, shape=(), name='is_training')

        self.x_src_mid_lst = []
        for i in range(self.data_loader.num_src_domain):
            x_src_mid = self._build_source_middle(self.x_src_lst[i], is_reused=i)
            self.x_src_mid_lst.append(x_src_mid)
        self.x_trg_mid = self._build_target_middle(self.x_trg, reuse=True)

        # <editor-fold desc="Classifier-logits">
        self.y_src_logit_lst = []
        for i in range(self.data_loader.num_src_domain):
            y_src_logit = self._build_class_src_discriminator(self.x_src_mid_lst[i], self.num_classes, i)
            self.y_src_logit_lst.append(y_src_logit)

        self.y_trg_logit = self._build_class_trg_discriminator(self.x_trg_mid,
                                                               self.num_classes)
        # </editor-fold>

        # <editor-fold desc="Classification">
        self.src_loss_class_lst = []
        self.src_loss_class_sum = tf.constant(0.0)
        for i in range(self.data_loader.num_src_domain):
            src_loss_class_detail = tf.nn.softmax_cross_entropy_with_logits_v2(
                logits=self.y_src_logit_lst[i], labels=self.y_src_lst[i])
            src_loss_class = tf.reduce_mean(src_loss_class_detail)
            self.src_loss_class_lst.append(self.src_domain_trade_off[i] * src_loss_class)
            self.src_loss_class_sum += self.src_domain_trade_off[i] * src_loss_class
        self.trg_loss_class_detail = tf.nn.softmax_cross_entropy_with_logits_v2(
            logits=self.y_trg_logit, labels=self.y_trg)
        self.trg_loss_class = tf.reduce_mean(self.trg_loss_class_detail)
        # </editor-fold>

        # <editor-fold desc="Source domain discriminator">
        self.x_src_mid_all = tf.concat(self.x_src_mid_lst, axis=0)
        self.y_src_discriminator_logit = self._build_domain_discriminator(self.x_src_mid_all)

        self.src_loss_discriminator_detail = tf.nn.softmax_cross_entropy_with_logits_v2(
            logits=self.y_src_discriminator_logit, labels=self.y_src_domain)
        self.src_loss_discriminator = tf.reduce_mean(self.src_loss_discriminator_detail)
        # </editor-fold>

        # <editor-fold desc="Compute teacher hS(xS)">
        self.y_src_teacher_all = []
        for i, bs in zip(range(self.data_loader.num_src_domain),
                         range(0, self.batch_size_src * self.data_loader.num_src_domain, self.batch_size_src)):
            y_src_logit_each_h_lst = []
            for j in range(self.data_loader.num_src_domain):
                y_src_logit_each_h = self._build_class_src_discriminator(self.x_src_mid_lst[i], self.num_classes,
                                                                         j, reuse=True)
                y_src_logit_each_h_lst.append(y_src_logit_each_h)
            y_src_logit_each_h_lst = tf.nn.softmax(tf.convert_to_tensor(y_src_logit_each_h_lst))
            y_src_discriminator_prob = tf.nn.softmax(tf.gather(self.y_src_discriminator_logit,
                                                               tf.range(bs, bs + self.batch_size_src,
                                                                        dtype=tf.int32), axis=0))
            y_src_teacher = self._compute_teacher_hs(y_src_logit_each_h_lst, y_src_discriminator_prob)
            self.y_src_teacher_all.append(y_src_teacher)
        self.y_src_teacher_all = tf.concat(self.y_src_teacher_all, axis=0)
        # </editor-fold>

        # <editor-fold desc="Compute teacher hS(xt)">
        y_trg_logit_each_h_lst = []
        for j in range(self.data_loader.num_src_domain):
            y_trg_logit_each_h = self._build_class_src_discriminator(self.x_trg_mid, self.num_classes,
                                                                     j, reuse=True)
            y_trg_logit_each_h_lst.append(y_trg_logit_each_h)
        y_trg_logit_each_h_lst = tf.nn.softmax(tf.convert_to_tensor(y_trg_logit_each_h_lst))
        self.y_trg_src_domains_logit = self._build_domain_discriminator(self.x_trg_mid, reuse=True)
        y_trg_discriminator_prob = tf.nn.softmax(self.y_trg_src_domains_logit)
        self.y_trg_teacher = self._compute_teacher_hs(y_trg_logit_each_h_lst, y_trg_discriminator_prob)
        # </editor-fold>

        # <editor-fold desc="Compute pseudo-label loss">
        self.ht_g_xs = build_class_discriminator_template(
            self.x_src_mid_all, training_phase=self.is_training, scope='c-trg', num_classes=self.num_classes,
            reuse=True, internal_update=True, class_discriminator_layout=self.classify_layout, cnn_size=self.cnn_size
        )
        self.mimic_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
            logits=self.ht_g_xs, labels=self.y_src_teacher_all)) + \
                          tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
                              logits=self.y_trg_logit, labels=self.y_trg_teacher))
        # </editor-fold>

        # <editor-fold desc="Compute WS loss">
        self.data_shift_loss = self._compute_cosine_similarity(self.x_trg_mid, self.x_src_mid_all)
        self.label_shift_loss = self._compute_label_shift_loss(self.y_trg_logit, self.ht_g_xs)
        self.data_label_shift_loss = self.data_shift_troff * self.data_shift_loss + self.lbl_shift_troff * self.label_shift_loss
        self.g_network = tf.reshape(self._build_phi_network(self.x_trg_mid), [-1])
        self.exp_term = (- self.data_label_shift_loss + self.g_network) / self.theta
        self.g_network_loss = tf.reduce_mean(self.g_network)
        self.OT_loss = tf.reduce_mean(
            -self.theta * (tf.log(1.0 / self.batch_size) +
                           tf.reduce_logsumexp(self.exp_term, axis=1))
        ) + self.g_network_trade_off * self.g_network_loss
        # </editor-fold>

        # <editor-fold desc="Compute VAT loss">
        self.trg_loss_vat = self._build_vat_loss(
            self.x_trg, self.y_trg_logit, self.num_classes,
            scope_encode=self._get_scope('generator', 'trg'), scope_classify='c-trg'
        )
        # </editor-fold>

        # <editor-fold desc="Compute conditional entropy loss w.r.t target distribution">
        self.trg_loss_cond_entropy = tf.reduce_mean(softmax_x_entropy_two(labels=self.y_trg_logit,
                                                                          logits=self.y_trg_logit))
        # </editor-fold>

        # <editor-fold desc="Accuracy">
        self.src_accuracy_lst = []
        for i in range(self.data_loader.num_src_domain):
            y_src_pred = tf.argmax(self.y_src_logit_lst[i], 1, output_type=tf.int32)
            y_src_sparse = tf.argmax(self.y_src_lst[i], 1, output_type=tf.int32)
            src_accuracy = tf.reduce_mean(tf.cast(tf.equal(y_src_sparse, y_src_pred), 'float32'))
            self.src_accuracy_lst.append(src_accuracy)
        # compute acc for target domain
        self.y_trg_pred = tf.argmax(self.y_trg_logit, 1, output_type=tf.int32)
        self.y_trg_sparse = tf.argmax(self.y_trg, 1, output_type=tf.int32)
        self.trg_accuracy = tf.reduce_mean(tf.cast(tf.equal(self.y_trg_sparse, self.y_trg_pred), 'float32'))
        # compute acc for src domain disc
        self.y_src_domain_pred = tf.argmax(self.y_src_discriminator_logit, 1, output_type=tf.int32)
        self.y_src_domain_sparse = tf.argmax(self.y_src_domain, 1, output_type=tf.int32)
        self.src_domain_acc = tf.reduce_mean(
            tf.cast(tf.equal(self.y_src_domain_sparse, self.y_src_domain_pred), 'float32'))
        # </editor-fold>

        # <editor-fold desc="Put it all together">
        lst_phase1_losses = [
            (self.src_class_trade_off, self.src_loss_class_sum),
            (self.domain_trade_off, self.src_loss_discriminator),
        ]
        self.phase1_loss = tf.constant(0.0)
        for trade_off, loss in lst_phase1_losses:
            # if trade_off != 0:
            self.phase1_loss += trade_off * loss

        lst_phase2_losses = [
            (self.src_class_trade_off, self.src_loss_class_sum),
            (self.ot_trade_off, self.OT_loss),
            (self.domain_trade_off, self.src_loss_discriminator),
            (self.trg_vat_troff, self.trg_loss_vat),
            (self.trg_ent_troff, self.trg_loss_cond_entropy),
            (self.mimic_trade_off, self.mimic_loss)
        ]
        self.phase2_loss = tf.constant(0.0)
        for trade_off, loss in lst_phase2_losses:
            # if trade_off != 0:
            self.phase2_loss += trade_off * loss
        # </editor-fold>

        # <editor-fold desc="Evaluation">
        primary_student_variables = self._get_variables(self._get_student_primary_scopes())
        ema = tf.train.ExponentialMovingAverage(decay=0.998)
        var_list_for_ema = primary_student_variables[0] + primary_student_variables[1]
        ema_op = ema.apply(var_list=var_list_for_ema)
        self.ema_p = self._build_classifier(T.x_tmp, self.num_classes, ema)
        self.batch_ema_acc = batch_ema_acc(T.y_tmp, self.ema_p)
        self.fn_batch_ema_acc = tb.function(self.tf_session, [T.x_tmp, T.y_tmp], self.batch_ema_acc)
        # </editor-fold>

        teacher_variables = self._get_variables(self._get_teacher_scopes())
        self.train_teacher = \
            tf.train.AdamOptimizer(self.learning_rate, 0.5).minimize(self.phase1_loss,
                                                                     var_list=teacher_variables)
        self.train_student_main = \
            tf.train.AdamOptimizer(self.learning_rate, 0.5).minimize(
                self.phase2_loss,
                # primary_student_variables[1] is already a list of variables
                # (see the EMA var_list above), so concatenate, don't nest it.
                var_list=teacher_variables + primary_student_variables[1])
        self.primary_train_student_op = tf.group(self.train_student_main, ema_op)

        # <editor-fold desc="Construct secondary loss">
        secondary_variables = self._get_variables(self._get_student_secondary_scopes())
        self.secondary_train_student_op = \
            tf.train.AdamOptimizer(self.learning_rate, 0.5).minimize(-self.OT_loss,
                                                                     var_list=secondary_variables)
        # </editor-fold>

        # <editor-fold desc="Summaries">
        tf.summary.scalar('loss/phase1_loss', self.phase1_loss)
        tf.summary.scalar('loss/phase2_loss', self.phase2_loss)
        tf.summary.scalar('loss/W_distance', self.OT_loss)
        tf.summary.scalar('loss/src_loss_discriminator', self.src_loss_discriminator)

        for i in range(self.data_loader.num_src_domain):
            tf.summary.scalar('loss/src_loss_class_{}'.format(i), self.src_loss_class_lst[i])
            tf.summary.scalar('acc/src_acc_{}'.format(i), self.src_accuracy_lst[i])
        tf.summary.scalar('acc/src_domain_acc', self.src_domain_acc)
        tf.summary.scalar('acc/trg_acc', self.trg_accuracy)

        tf.summary.scalar('trg_loss_class', self.trg_loss_class)
        tf.summary.scalar('hyperparameters/learning_rate', self.learning_rate)
        tf.summary.scalar('hyperparameters/src_class_trade_off', self.src_class_trade_off)
        tf.summary.scalar('hyperparameters/g_network_trade_off', self.g_network_trade_off)
        tf.summary.scalar('hyperparameters/domain_trade_off', self.domain_trade_off)
        tf.summary.scalar('hyperparameters/trg_vat_troff', self.trg_vat_troff)
        tf.summary.scalar('hyperparameters/trg_ent_troff', self.trg_ent_troff)
        self.tf_merged_summaries = tf.summary.merge_all()
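
`batch_ema_acc` and the wrapped `fn_batch_ema_acc` evaluate the EMA classifier on a fed batch; presumably a standard one-hot accuracy op, sketched here under that assumption:

def batch_ema_acc(y_true, y_logit):
    # Fraction of the batch whose argmax prediction matches the
    # one-hot label.
    correct = tf.equal(tf.argmax(y_true, 1), tf.argmax(y_logit, 1))
    return tf.reduce_mean(tf.cast(correct, tf.float32))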
Example #9
g_train = tf.train.AdamOptimizer(lr, 0.5).minimize(g_loss, var_list=g_var)

# Logger
base_dir = 'results/gamma={:.1f}_run={:d}'.format(args.gamma, args.run)
writer = tb.FileWriter(os.path.join(base_dir, 'log.csv'), args=args, overwrite=args.run >= 999)
writer.add_var('d_real', '{:8.4f}', d_real_loss)
writer.add_var('d_fake', '{:8.4f}', d_fake_loss)
writer.add_var('k', '{:8.4f}', k * 1)
writer.add_var('M', '{:8.4f}', m_global)
writer.add_var('lr', '{:8.6f}', lr * 1)
writer.add_var('iter', '{:>8d}')
writer.initialize()

sess = tf.Session()
load_model(sess)
f_gen = tb.function(sess, [z], x_fake)
f_rec = tb.function(sess, [x_real], d_real)
celeba = CelebA(args.data)

# Alternatively try grouping d_train/g_train together
all_tensors = [d_train, g_train, d_real_loss, d_fake_loss]
# d_tensors = [d_train, d_real_loss]
# g_tensors = [g_train, d_fake_loss]
for i in range(args.max_iter):
    x = celeba.next_batch(args.bs)
    z = np.random.uniform(-1, 1, (args.bs, args.e_size))

    feed_dict = {'x:0': x, 'z:0': z, 'k:0': args.k, 'lr:0': args.lr, 'g:0': args.gamma}
    _, _, d_real_loss, d_fake_loss = sess.run(all_tensors, feed_dict)
    # _, d_real_loss = sess.run(d_tensors, feed_dict)
    # _, d_fake_loss = sess.run(g_tensors, feed_dict)
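
The `k` and `M` columns logged above are BEGAN's control variables (Berthelot et al. 2017). Since `k` enters the graph as the `'k:0'` placeholder, its standard update would live host-side in the loop; a sketch, with `lambda_k` assumed at the paper's 0.001:

lambda_k = 0.001
k_t = args.k
# ...after each sess.run that returns d_real_loss and d_fake_loss:
balance = args.gamma * d_real_loss - d_fake_loss
k_t = float(np.clip(k_t + lambda_k * balance, 0., 1.))  # feed as 'k:0' next step
m_global_val = d_real_loss + abs(balance)               # convergence measure M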
Example #10
def dirtt():
    T = tb.utils.TensorDict(dict(
        sess = tf.Session(config=tb.growth_config()),
        src_x = placeholder((None, 500, 60, 1)),
        src_y = placeholder((None, args.Y)),
        trg_x = placeholder((None, 500, 60, 1)),
        trg_y = placeholder((None, args.Y)),
        test_x = placeholder((None, 500, 60, 1)),
        test_y = placeholder((None, args.Y)),
    ))
    # Supervised and conditional entropy minimization
    src_e = nn.classifier(T.src_x, phase=True, enc_phase=1, trim=args.trim)
    trg_e = nn.classifier(T.trg_x, phase=True, enc_phase=1, trim=args.trim, reuse=True, internal_update=True)
    src_p = nn.classifier(src_e, phase=True, enc_phase=0, trim=args.trim)
    trg_p = nn.classifier(trg_e, phase=True, enc_phase=0, trim=args.trim, reuse=True, internal_update=True)

    loss_src_class = tf.reduce_mean(softmax_xent(labels=T.src_y, logits=src_p))
    loss_trg_cent = tf.reduce_mean(softmax_xent_two(labels=trg_p, logits=trg_p))

    # Domain confusion
    if args.dw > 0 and args.dirt == 0:
        real_logit = nn.feature_discriminator(src_e, phase=True)
        fake_logit = nn.feature_discriminator(trg_e, phase=True, reuse=True)

        loss_disc = 0.5 * tf.reduce_mean(
            sigmoid_xent(labels=tf.ones_like(real_logit), logits=real_logit) +
            sigmoid_xent(labels=tf.zeros_like(fake_logit), logits=fake_logit))
        loss_domain = 0.5 * tf.reduce_mean(
            sigmoid_xent(labels=tf.zeros_like(real_logit), logits=real_logit) +
            sigmoid_xent(labels=tf.ones_like(fake_logit), logits=fake_logit))

    else:
        loss_disc = constant(0)
        loss_domain = constant(0)

    # Virtual adversarial training (turn off src in non-VADA phase)
    loss_src_vat = vat_loss(T.src_x, src_p, nn.classifier) if args.sw > 0 and args.dirt == 0 else constant(0)
    loss_trg_vat = vat_loss(T.trg_x, trg_p, nn.classifier) if args.tw > 0 else constant(0)

    # Evaluation (EMA)
    ema = tf.train.ExponentialMovingAverage(decay=0.998)
    var_class = tf.get_collection('trainable_variables', 'class/')
    ema_op = ema.apply(var_class)
    ema_p = nn.classifier(T.test_x, phase=False, reuse=True, getter=tb.tfutils.get_getter(ema))

    # Teacher model (a back-up of EMA model)
    teacher_p = nn.classifier(T.test_x, phase=False, scope='teacher')
    var_main = tf.get_collection('variables', 'class/(?!.*ExponentialMovingAverage:0)')
    var_teacher = tf.get_collection('variables', 'teacher/(?!.*ExponentialMovingAverage:0)')
    teacher_assign_ops = []
    for t, m in zip(var_teacher, var_main):
        ave = ema.average(m)
        ave = ave if ave else m
        teacher_assign_ops += [tf.assign(t, ave)]
    update_teacher = tf.group(*teacher_assign_ops)
    teacher = tb.function(T.sess, [T.test_x], tf.nn.softmax(teacher_p))

    # Accuracies
    src_acc = basic_accuracy(T.src_y, src_p)
    trg_acc = basic_accuracy(T.trg_y, trg_p)
    ema_acc = basic_accuracy(T.test_y, ema_p)
    fn_ema_acc = tb.function(T.sess, [T.test_x, T.test_y], ema_acc)

    # Optimizer
    dw = constant(args.dw) if args.dirt == 0 else constant(0)
    cw = constant(1)       if args.dirt == 0 else constant(args.bw)
    sw = constant(args.sw) if args.dirt == 0 else constant(0)
    tw = constant(args.tw)
    loss_main = (dw * loss_domain +
                 cw * loss_src_class +
                 sw * loss_src_vat +
                 tw * loss_trg_cent +
                 tw * loss_trg_vat)
    var_main = tf.get_collection('trainable_variables', 'class')
    train_main = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_main, var_list=var_main)
    train_main = tf.group(train_main, ema_op)

    if args.dw > 0 and args.dirt == 0:
        var_disc = tf.get_collection('trainable_variables', 'disc')
        train_disc = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_disc, var_list=var_disc)
    else:
        train_disc = constant(0)

    # Summarizations
    summary_disc = [tf.summary.scalar('domain/loss_disc', loss_disc),]
    summary_main = [tf.summary.scalar('domain/loss_domain', loss_domain),
                    tf.summary.scalar('class/loss_src_class', loss_src_class),
                    tf.summary.scalar('class/loss_trg_cent', loss_trg_cent),
                    tf.summary.scalar('lipschitz/loss_trg_vat', loss_trg_vat),
                    tf.summary.scalar('lipschitz/loss_src_vat', loss_src_vat),
                    tf.summary.scalar('hyper/dw', dw),
                    tf.summary.scalar('hyper/cw', cw),
                    tf.summary.scalar('hyper/sw', sw),
                    tf.summary.scalar('hyper/tw', tw),
                    tf.summary.scalar('acc/src_acc', src_acc),
                    tf.summary.scalar('acc/trg_acc', trg_acc)]

    # Merge summaries
    summary_disc = tf.summary.merge(summary_disc)
    summary_main = tf.summary.merge(summary_main)

    # Saved ops
    c = tf.constant
    T.ops_print = [c('disc'), loss_disc,
                   c('domain'), loss_domain,
                   c('class'), loss_src_class,
                   c('cent'), loss_trg_cent,
                   c('trg_vat'), loss_trg_vat,
                   c('src_vat'), loss_src_vat,
                   c('src'), src_acc,
                   c('trg'), trg_acc]
    T.ops_disc = [summary_disc, train_disc]
    T.ops_main = [summary_main, train_main]
    T.fn_ema_acc = fn_ema_acc
    T.teacher = teacher
    T.update_teacher = update_teacher

    return T
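
`update_teacher` copies the current EMA weights into the frozen `teacher` scope, and `T.teacher` queries that copy. A hypothetical usage during the DIRT-T phase (the refresh interval, iteration count, and `batch_x` are assumptions):

T = dirtt()
T.sess.run(tf.global_variables_initializer())
for i in range(1, n_iter + 1):
    # ...feed placeholders and run T.ops_main / T.ops_disc as usual...
    if args.dirt > 0 and i % args.dirt == 0:
        T.sess.run(T.update_teacher)   # snapshot EMA weights -> teacher scope
soft_labels = T.teacher(batch_x)       # softmax outputs of the frozen teacher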
Example #11
def gada():
    T = tb.utils.TensorDict(dict(
        sess = tf.Session(config=tb.growth_config()),
        src_x = placeholder((None, 32, 32, 3)),
        src_y = placeholder((None, args.Y)),
        trg_x = placeholder((None, 32, 32, 3)),
        trg_y = placeholder((None, args.Y)),
        trg_z = placeholder((None, 100)),
        test_x = placeholder((None, 32, 32, 3)),
        test_y = placeholder((None, args.Y)),
    ))

    # Supervised and conditional entropy minimization
    src_e = nn.classifier(T.src_x, phase=True, enc_phase=1, enc_trim=args.etrim)
    src_g = nn.classifier(src_e, phase=True, gen_trim=args.gtrim, gen_phase=1, enc_trim=args.etrim)
    src_p = nn.classifier(src_g, phase=True, gen_trim=args.gtrim)
    trg_e = nn.classifier(T.trg_x, phase=True, enc_phase=1, enc_trim=args.etrim, reuse=True, internal_update=True)
    trg_g = nn.classifier(trg_e, phase=True, gen_trim=args.gtrim, gen_phase=1, enc_trim=args.etrim, reuse=True, internal_update=True)
    trg_p = nn.classifier(trg_g, phase=True, gen_trim=args.gtrim, reuse=True, internal_update=True)

    loss_src_class = tf.reduce_mean(softmax_xent(labels=T.src_y, logits=src_p))
    loss_trg_cent = tf.reduce_mean(softmax_xent_two(labels=trg_p, logits=trg_p)) if args.tw > 0 else constant(0)

    # Domain confusion
    if args.dw > 0 and args.dirt == 0:
        real_logit = nn.real_feature_discriminator(src_e, phase=True)
        fake_logit = nn.real_feature_discriminator(trg_e, phase=True, reuse=True)

        loss_disc = 0.5 * tf.reduce_mean(
            sigmoid_xent(labels=tf.ones_like(real_logit), logits=real_logit) +
            sigmoid_xent(labels=tf.zeros_like(fake_logit), logits=fake_logit))
        loss_domain = 0.5 * tf.reduce_mean(
            sigmoid_xent(labels=tf.zeros_like(real_logit), logits=real_logit) +
            sigmoid_xent(labels=tf.ones_like(fake_logit), logits=fake_logit))

    else:
        loss_disc = constant(0)
        loss_domain = constant(0)

    # Virtual adversarial training (turn off src in non-VADA phase)
    loss_src_vat = vat_loss(T.src_x, src_p, nn.classifier) if args.sw > 0 and args.dirt == 0 else constant(0)
    loss_trg_vat = vat_loss(T.trg_x, trg_p, nn.classifier) if args.tw > 0 else constant(0)

    # Generate images and process generated images
    trg_gen_x = nn.trg_generator(T.trg_z)
    trg_gen_e = nn.classifier(trg_gen_x, phase=True, enc_phase=1, enc_trim=args.etrim, reuse=True, internal_update=True)
    trg_gen_g = nn.classifier(trg_gen_e, phase=True, gen_trim=args.gtrim, gen_phase=1, enc_trim=args.etrim, reuse=True, internal_update=True)
    trg_gen_p = nn.classifier(trg_gen_g, phase=True, gen_trim=args.gtrim, reuse=True, internal_update=True)

    # Feature matching loss function for generator
    loss_trg_gen_fm = tf.reduce_mean(tf.square(tf.reduce_mean(trg_g, axis=0) - tf.reduce_mean(trg_gen_g, axis=0))) if args.dirt == 0 else constant(0)

    # Unsupervised loss function
    if args.dirt == 0:
        logit_real = tf.reduce_logsumexp(trg_p, axis=1)
        logit_fake = tf.reduce_logsumexp(trg_gen_p, axis=1)
        dis_loss_real = -0.5*tf.reduce_mean(logit_real) + 0.5*tf.reduce_mean(tf.nn.softplus(logit_real))
        dis_loss_fake = 0.5*tf.reduce_mean(tf.nn.softplus(logit_fake))
        loss_trg_usv = dis_loss_real + dis_loss_fake    # UnSuperVised loss function
    else:
        loss_trg_usv = constant(0)

    # Evaluation (EMA)
    ema = tf.train.ExponentialMovingAverage(decay=0.998)
    var_class = tf.get_collection('trainable_variables', 'class/')
    ema_op = ema.apply(var_class)
    ema_p = nn.classifier(T.test_x, enc_phase=1, enc_trim=0, phase=False, reuse=True, getter=tb.tfutils.get_getter(ema))

    # Teacher model (a back-up of EMA model)
    teacher_p = nn.classifier(T.test_x, enc_phase=1, enc_trim=0, phase=False, scope='teacher')
    var_main = tf.get_collection('variables', 'class/(?!.*ExponentialMovingAverage:0)')
    var_teacher = tf.get_collection('variables', 'teacher/(?!.*ExponentialMovingAverage:0)')
    teacher_assign_ops = []
    for t, m in zip(var_teacher, var_main):
        ave = ema.average(m)
        ave = ave if ave else m
        teacher_assign_ops += [tf.assign(t, ave)]
    update_teacher = tf.group(*teacher_assign_ops)
    teacher = tb.function(T.sess, [T.test_x], tf.nn.softmax(teacher_p))

    # Accuracies
    src_acc = basic_accuracy(T.src_y, src_p)
    trg_acc = basic_accuracy(T.trg_y, trg_p)
    ema_acc = basic_accuracy(T.test_y, ema_p)
    fn_ema_acc = tb.function(T.sess, [T.test_x, T.test_y], ema_acc)

    # Optimizer
    dw = constant(args.dw) if args.dirt == 0 else constant(0)
    cw = constant(1)       if args.dirt == 0 else constant(args.bw)
    sw = constant(args.sw) if args.dirt == 0 else constant(0)
    tw = constant(args.tw)
    uw = constant(args.uw) if args.dirt == 0 else constant(0)
    loss_main = (dw * loss_domain +
                 cw * loss_src_class +
                 sw * loss_src_vat +
                 tw * loss_trg_cent +
                 tw * loss_trg_vat +
                 uw * loss_trg_usv)
    var_main = tf.get_collection('trainable_variables', 'class')
    train_main = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_main, var_list=var_main)
    train_main = tf.group(train_main, ema_op)

    # Optimizer for feature discriminator
    if args.dw > 0 and args.dirt == 0:
        var_disc = tf.get_collection('trainable_variables', 'disc_real')
        train_disc = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_disc, var_list=var_disc)
    else:
        train_disc = constant(0)

    # Optimizer for generators
    if args.dirt == 0:
        fmw = constant(1)
        loss_trg_gen = (fmw * loss_trg_gen_fm)
        var_trg_gen = tf.get_collection('trainable_variables', 'trg_gen')
        trg_gen_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='trg_gen')
        with tf.control_dependencies(trg_gen_update_ops):
            train_trg_gen = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_trg_gen, var_list=var_trg_gen)
        train_gen = train_trg_gen
    else:
        fmw = constant(0)
        train_gen = constant(0)

    # Summarizations
    summary_disc = [tf.summary.scalar('domain/loss_disc', loss_disc),]
    summary_main = [tf.summary.scalar('domain/loss_domain', loss_domain),
                    tf.summary.scalar('class/loss_src_class', loss_src_class),
                    tf.summary.scalar('class/loss_trg_cent', loss_trg_cent),
                    tf.summary.scalar('class/loss_trg_usv', loss_trg_usv),
                    tf.summary.scalar('lipschitz/loss_trg_vat', loss_trg_vat),
                    tf.summary.scalar('lipschitz/loss_src_vat', loss_src_vat),
                    tf.summary.scalar('hyper/dw', dw),
                    tf.summary.scalar('hyper/cw', cw),
                    tf.summary.scalar('hyper/sw', sw),
                    tf.summary.scalar('hyper/tw', tw),
                    tf.summary.scalar('hyper/uw', uw),
                    tf.summary.scalar('hyper/fmw', fmw),
                    tf.summary.scalar('acc/src_acc', src_acc),
                    tf.summary.scalar('acc/trg_acc', trg_acc)]
    summary_gen  = [tf.summary.scalar('gen/loss_trg_gen_fm', loss_trg_gen_fm),
                    tf.summary.image('gen/trg_gen_img', trg_gen_x),]

    # Merge summaries
    summary_disc = tf.summary.merge(summary_disc)
    summary_main = tf.summary.merge(summary_main)
    summary_gen  = tf.summary.merge(summary_gen)

    # Saved ops
    c = tf.constant
    T.ops_print = [c('disc'), loss_disc,
                   c('domain'), loss_domain,
                   c('class'), loss_src_class,
                   c('cent'), loss_trg_cent,
                   c('trg_vat'), loss_trg_vat,
                   c('src_vat'), loss_src_vat,
                   c('src'), src_acc,
                   c('trg'), trg_acc]
    T.ops_disc = [summary_disc, train_disc]
    T.ops_main = [summary_main, train_main]
    T.ops_gen  = [summary_gen , train_gen]
    T.fn_ema_acc = fn_ema_acc
    T.teacher = teacher
    T.update_teacher = update_teacher
    T.trg_gen_x = trg_gen_x
    T.trg_gen_p = trg_gen_p
    T.src_p = src_p
    T.trg_p = trg_p
    T.ema_p = ema_p

    return T
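
The `dis_loss_real`/`dis_loss_fake` pair in `gada` is the semi-supervised GAN discriminator of Salimans et al. (2016): the K class logits act as an implicit binary discriminator through their log-sum-exp, D(x) = e^l / (e^l + 1) with l = logsumexp(logits), so that -log D = softplus(l) - l and -log(1 - D) = softplus(l). A quick numeric check of that identity:

import numpy as np

l = np.random.randn(5)                   # stand-in for logsumexp outputs
D = np.exp(l) / (np.exp(l) + 1.)
softplus = np.log1p(np.exp(l))
assert np.allclose(-np.log(D), softplus - l)    # the dis_loss_real term
assert np.allclose(-np.log(1. - D), softplus)   # the dis_loss_fake term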