Example #1
def classifier():
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    T = tb.utils.TensorDict(dict(
        sess = tf.Session(config=config),
        src_x = placeholder((None, 32, 32, 3),  name='source_x'),
        src_y = placeholder((None, 10),         name='source_y'),
        test_x = placeholder((None, 32, 32, 3), name='test_x'),
        test_y = placeholder((None, 10),        name='test_y'),
        phase = placeholder((), tf.bool,        name='phase')
    ))

    # Classification
    src_y = des.classifier(T.src_x, T.phase, internal_update=True)
    loss_class = tf.reduce_mean(softmax_xent(labels=T.src_y, logits=src_y))
    src_acc = basic_accuracy(T.src_y, src_y)

    # Evaluation (non-EMA)
    test_y = des.classifier(T.test_x, phase=False, reuse=True)
    test_acc = basic_accuracy(T.test_y, test_y)
    fn_test_acc = tb.function(T.sess, [T.test_x, T.test_y], test_acc)

    # Evaluation (EMA)
    ema = tf.train.ExponentialMovingAverage(decay=0.998)
    var_class = tf.get_collection('trainable_variables', 'class')
    ema_op = ema.apply(var_class)
    ema_y = des.classifier(T.test_x, phase=False, reuse=True, getter=get_getter(ema))
    ema_acc = basic_accuracy(T.test_y, ema_y)
    fn_ema_acc = tb.function(T.sess, [T.test_x, T.test_y], ema_acc)

    # Optimizer
    loss_main = loss_class
    var_main = var_class
    train_main = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_main, var_list=var_main)
    train_main = tf.group(train_main, ema_op)

    # Summarizations
    summary_main = [
        tf.summary.scalar('class/loss_class', loss_class),
        tf.summary.scalar('acc/src_acc', src_acc),
    ]
    summary_main = tf.summary.merge(summary_main)

    # Saved ops
    c = tf.constant
    T.ops_print = [
        c('class'), loss_class,
        c('src'), src_acc,
    ]

    T.ops_main = [summary_main, train_main]
    T.fn_test_acc = fn_test_acc
    T.fn_ema_acc = fn_ema_acc

    return T
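Each example returns a TensorDict whose saved ops are meant to be driven by an outer training loop. A minimal usage sketch for this classifier, assuming a hypothetical next_batch(bs) iterator and x_test/y_test arrays prepared by the caller (everything else comes from the function above):

# Minimal training-loop sketch; next_batch, x_test, y_test are assumptions.
T = classifier()
T.sess.run(tf.global_variables_initializer())

for step in range(10000):
    x, y = next_batch(64)
    # phase=True switches batch norm / dropout into training mode
    summary, _ = T.sess.run(T.ops_main,
                            feed_dict={T.src_x: x, T.src_y: y, T.phase: True})
    # summary can be written to a tf.summary.FileWriter if desired
    if step % 1000 == 0:
        print('test acc', T.fn_test_acc(x_test, y_test))
        print('ema acc ', T.fn_ema_acc(x_test, y_test))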
Example #2
def model(FLAGS):
    print(colored("Model is called.", "blue"))

    T = tb.utils.TensorDict(
        dict(sess=tf.Session(config=tb.growth_config()),
             sv=placeholder((FLAGS.bs, 369539)),
             ts=placeholder((FLAGS.bs, )),
             test_sv=placeholder((FLAGS.bs, 369539)),
             test_ts=placeholder((FLAGS.bs, ))))

    # h = dense(T.sv, FLAGS.d, scope='hidden', bn=False, phase=True, reuse=tf.AUTO_REUSE)
    # o = dense(h, 1, scope='out', bn=False, phase=True, reuse=tf.AUTO_REUSE)
    # test_h = dense(T.test_sv, FLAGS.d, scope='hidden', bn=False, phase=False, reuse=tf.AUTO_REUSE)
    # test_o = dense(test_h, 1, scope='out', bn=False, phase=False, reuse=tf.AUTO_REUSE)

    hidden1 = tf.get_variable("hidden1", [369539, FLAGS.d])
    hidden2 = tf.get_variable("hidden2", [FLAGS.d, 1])

    h = tf.matmul(T.sv, hidden1)
    o = tf.matmul(h, hidden2)

    test_h = tf.matmul(T.test_sv, hidden1)
    test_o = tf.matmul(test_h, hidden2)

    loss = tf.reduce_mean(tf.squared_difference(o, T.ts))

    test_o_mean = tf.reduce_mean(test_o)
    test_ts_mean = tf.reduce_mean(T.test_ts)
    test_error = tf.reduce_mean((test_o - T.test_ts) / T.test_ts) * 100.0
    error = tf.reduce_mean((o - T.ts) / T.ts) * 100.0
    # optimizer = tf.train.AdagradOptimizer(FLAGS.lr).minimize(loss)
    optimizer = tf.train.AdamOptimizer(FLAGS.lr).minimize(loss)

    summary = [
        tf.summary.scalar('loss', loss),
        tf.summary.scalar('error', error)
    ]
    summary = tf.summary.merge(summary)

    c = tf.constant
    T.ops_print = [
        c('loss'), loss,
        c('error'), error,
        c('test_error'), test_error,
        c('test_o_mean'), test_o_mean,
        c('test_ts_mean'), test_ts_mean
    ]

    T.ops = [summary, optimizer]

    print(colored("Model is initialized.", "blue"))

    return T
Example #3
    def fit(self, x_input, epochs = 1000, learning_rate = 0.001, batch_size = 100, print_size = 50, train=True):
        # training setting
        self.DO_SHARE = False
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.print_size = print_size

        self.g = tf.Graph()
        # inference process
        x_ = placeholder((None, self.input_dim))
        x = x_
        depth_inf = len(self.encoding_dims)
        for i in range(depth_inf):
            x = dense(x, self.encoding_dims[i], scope="enc_layer"+"%s" %i, activation=tf.nn.sigmoid)
        h_encode = x
        z_mu = dense(h_encode, self.z_dim, scope="mu_layer")
        z_log_sigma_sq = dense(h_encode, self.z_dim, scope = "sigma_layer")
        e = tf.random_normal(tf.shape(z_mu))
        z = z_mu + tf.sqrt(tf.maximum(tf.exp(z_log_sigma_sq), self.eps)) * e

        # generative process
        if self.useTranse == False:
            depth_gen = len(self.decoding_dims)

            for i in range(depth_gen):
                y = dense(z, self.decoding_dims[i], scope="dec_layer"+"%s" %i, activation=tf.nn.sigmoid)
                # if last_layer_nonelinear: depth_gen -1

        else:
            depth_gen = depth_inf
            ## haven't finished yet...

        x_recons = y

        if self.loss == "cross_entropy":
            loss_recons = tf.reduce_mean(tf.reduce_sum(binary_crossentropy(x_recons, x_, self.eps), axis=1))
            loss_kl = 0.5 * tf.reduce_mean(tf.reduce_sum(tf.square(z_mu) + tf.exp(z_log_sigma_sq) - z_log_sigma_sq - 1, 1))
            loss = loss_recons + loss_kl
        # other cases not finished yet
        train_op = tf.train.AdamOptimizer(self.learning_rate).minimize(loss)

        sess = tf.Session()
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        ckpt_dir = "pre_model/" + "vae.ckpt"
        if train == True:
            # num_turn = x_input.shape[0] / self.batch_size
            for i in range(epochs):
                idx = np.random.choice(x_input.shape[0], batch_size, replace=False)
                x_batch = x_input[idx]
                _, l = sess.run((train_op, loss), feed_dict={x_:x_batch})
                if i % self.print_size == 0:
                    print("{:>10s}{:>10s}".format("epochs", "loss"))
                    print("{:10.2e}{:10.2e}".format(i, l))
            saver.save(sess, ckpt_dir)
        else:
            saver.restore(sess, ckpt_dir)
Example #4
def classifier():
    T = tb.utils.TensorDict(
        dict(sess=tf.Session(config=tb.growth_config()),
             src_x=placeholder((None, 32, 32, 3)),
             src_y=placeholder((None, 10)),
             trg_x=placeholder((None, 32, 32, 3)),
             trg_y=placeholder((None, 10)),
             test_x=placeholder((None, 32, 32, 3)),
             test_y=placeholder((None, 10)),
             phase=placeholder((), tf.bool)))

    # Supervised and conditional entropy minimization
    src_y = net.classifier(T.src_x, phase=True, internal_update=False)
    trg_y = net.classifier(T.trg_x,
                           phase=True,
                           internal_update=True,
                           reuse=True)

    loss_class = tf.reduce_mean(softmax_xent(labels=T.src_y, logits=src_y))

    # Evaluation (non-EMA)
    test_y = net.classifier(T.test_x, phase=False, scope='class', reuse=True)

    # Evaluation (EMA)
    ema = tf.train.ExponentialMovingAverage(decay=0.998)
    ema_op = ema.apply(tf.get_collection('trainable_variables', 'class/'))
    T.ema_y = net.classifier(T.test_x,
                             phase=False,
                             reuse=True,
                             getter=get_getter(ema))

    src_acc = basic_accuracy(T.src_y, src_y)
    trg_acc = basic_accuracy(T.trg_y, trg_y)
    ema_acc = basic_accuracy(T.test_y, T.ema_y)
    fn_ema_acc = tb.function(T.sess, [T.test_x, T.test_y], ema_acc)

    # Optimizer
    loss_main = loss_class
    var_main = tf.get_collection('trainable_variables', 'class')
    train_main = tf.train.AdamOptimizer(args.lr,
                                        0.5).minimize(loss_main,
                                                      var_list=var_main)
    train_main = tf.group(train_main, ema_op)

    # Summarizations
    summary_main = [
        tf.summary.scalar('class/loss_class', loss_class),
        tf.summary.scalar('acc/src_acc', src_acc),
        tf.summary.scalar('acc/trg_acc', trg_acc)
    ]
    summary_main = tf.summary.merge(summary_main)

    # Saved ops
    c = tf.constant
    T.ops_print = [c('class'), loss_class]
    T.ops_main = [summary_main, train_main]
    T.fn_ema_acc = fn_ema_acc

    return T
Example #5
    def __init__(self, k=10, n_x=784, n_z=64):
        self.k = k
        self.n_x = n_x
        self.n_z = n_z
        tf.reset_default_graph()
        x = placeholder((None, n_x), name='x')
        phase = tf.placeholder(tf.bool, name='phase')

        # create a y "placeholder"
        with tf.name_scope('y_'):
            y_ = tf.fill(tf.stack([tf.shape(x)[0], k]), 0.0)

        # propose distribution over y
        self.qy_logit, self.qy = qy_graph(x, k, phase)

        # for each proposed y, infer z and reconstruct x
        self.z, \
        self.zm, \
        self.zv, \
        self.zm_prior, \
        self.zv_prior, \
        self.xm, \
        self.xv, \
        self.y = [[None] * k for i in range(8)]
        for i in range(k):
            with tf.name_scope('graphs/hot_at{:d}'.format(i)):
                y = tf.add(
                    y_, constant(np.eye(k)[i], name='hot_at_{:d}'.format(i)))
                self.z[i], self.zm[i], self.zv[i] = qz_graph(x, y, n_z, phase)
                self.y[i], \
                self.zm_prior[i], \
                self.zv_prior[i] = pz_graph(y, n_z, phase)
                self.xm[i], self.xv[i] = px_graph(self.z[i], n_x, phase)

        # Aggressive name scoping for pretty graph visualization :P
        with tf.name_scope('loss'):
            with tf.name_scope('neg_entropy'):
                self.nent = -tf.nn.softmax_cross_entropy_with_logits_v2(
                    labels=self.qy, logits=self.qy_logit)
            losses = [None] * k
            for i in range(k):
                with tf.name_scope('loss_at{:d}'.format(i)):
                    losses[i] = labeled_loss(x, self.xm[i], self.xv[i],
                                             self.z[i], self.zm[i], self.zv[i],
                                             self.zm_prior[i],
                                             self.zv_prior[i])
            with tf.name_scope('final_loss'):
                self.loss = tf.add_n(
                    #[self.nent] +
                    [self.qy[:, i] * losses[i] for i in range(k)])

        self.train_step = tf.train.AdamOptimizer(0.00001).minimize(self.loss)

        show_default_graph()
Example #6
def dann_embed():
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # config.gpu_options.per_process_gpu_memory_fraction = 0.45
    T = tb.utils.TensorDict(dict(
        sess = tf.Session(config=config),
        src_x = placeholder((None, H, W, 3)),
        src_y = placeholder((None, Y)),
        trg_x = placeholder((None, H, W, 3)),
        trg_y = placeholder((None, Y)),
        fake_z = placeholder((None, 100)),
        fake_y = placeholder((None, Y)),
        test_x = placeholder((None, H, W, 3)),
        test_y = placeholder((None, Y)),
        phase = placeholder((), tf.bool)
    ))

    # Schedules
    start, end = args.pivot - 1, args.pivot
    global_step = tf.Variable(0., trainable=False)
    # Ramp down dw
    ramp_dw = conditional_ramp_weight(args.dwdn, global_step, args.dw, 0, start, end)
    # Ramp down src
    ramp_class = conditional_ramp_weight(args.dn, global_step, 1, args.dcval, start, end)
    ramp_sbw = conditional_ramp_weight(args.dn, global_step, args.sbw, 0, start, end)
    # Ramp up trg (never more than src)
    ramp_cw = conditional_ramp_weight(args.up, global_step, args.cw, args.uval, start, end)
    ramp_gw = conditional_ramp_weight(args.up, global_step, args.gw, args.uval, start, end)
    ramp_gbw = conditional_ramp_weight(args.up, global_step, args.gbw, args.uval, start, end)
    ramp_tbw = conditional_ramp_weight(args.up, global_step, args.tbw, args.uval, start, end)

    # Supervised and conditional entropy minimization
    src_e = des.classifier(T.src_x, T.phase, enc_phase=1, trim=args.trim, scope='class', internal_update=False)
    trg_e = des.classifier(T.trg_x, T.phase, enc_phase=1, trim=args.trim, scope='class', reuse=True, internal_update=True)
    src_y = des.classifier(src_e, T.phase, enc_phase=0, trim=args.trim, scope='class', internal_update=False)
    trg_y = des.classifier(trg_e, T.phase, enc_phase=0, trim=args.trim, scope='class', reuse=True, internal_update=True)

    loss_class = tf.reduce_mean(softmax_xent(labels=T.src_y, logits=src_y))
    loss_cent = tf.reduce_mean(softmax_xent_two(labels=trg_y, logits=trg_y))

    # Image generation
    if args.gw > 0:
        fake_x = des.generator(T.fake_z, T.fake_y, T.phase)
        fake_logit = des.discriminator(fake_x, T.phase)
        real_logit = des.discriminator(T.trg_x, T.phase, reuse=True)
        fake_e = des.classifier(fake_x, T.phase, enc_phase=1, trim=args.trim, scope='class', reuse=True)
        fake_y = des.classifier(fake_e, T.phase, enc_phase=0, trim=args.trim, scope='class', reuse=True)

        loss_gdisc = 0.5 * tf.reduce_mean(
            sigmoid_xent(labels=tf.ones_like(real_logit), logits=real_logit) +
            sigmoid_xent(labels=tf.zeros_like(fake_logit), logits=fake_logit))
        loss_gen = tf.reduce_mean(sigmoid_xent(labels=tf.ones_like(fake_logit), logits=fake_logit))
        loss_info = tf.reduce_mean(softmax_xent(labels=T.fake_y, logits=fake_y))

    else:
        loss_gdisc = constant(0)
        loss_gen = constant(0)
        loss_info = constant(0)

    # Domain confusion
    if args.dw > 0 and args.phase == 0:
        real_logit = des.feature_discriminator(src_e, T.phase)
        fake_logit = des.feature_discriminator(trg_e, T.phase, reuse=True)

        loss_ddisc = 0.5 * tf.reduce_mean(
            sigmoid_xent(labels=tf.ones_like(real_logit), logits=real_logit) +
            sigmoid_xent(labels=tf.zeros_like(fake_logit), logits=fake_logit))
        loss_domain = 0.5 * tf.reduce_mean(
            sigmoid_xent(labels=tf.zeros_like(real_logit), logits=real_logit) +
            sigmoid_xent(labels=tf.ones_like(fake_logit), logits=fake_logit))

    else:
        loss_ddisc = constant(0)
        loss_domain = constant(0)

    # Smoothing
    loss_t_ball = constant(0) if args.tbw == 0 else smoothing_loss(T.trg_x, trg_y, T.phase)
    loss_s_ball = constant(0) if args.sbw == 0 or args.phase == 1 else smoothing_loss(T.src_x, src_y, T.phase)
    loss_g_ball = constant(0) if args.gbw == 0 else smoothing_loss(fake_x, fake_y, T.phase)

    loss_t_emb = constant(0) if args.te == 0 else smoothing_loss(T.trg_x, trg_e, T.phase, is_embedding=True)
    loss_s_emb = constant(0) if args.se == 0 else smoothing_loss(T.src_x, src_e, T.phase, is_embedding=True)

    # Evaluation (non-EMA)
    test_y = des.classifier(T.test_x, False, enc_phase=1, trim=0, scope='class', reuse=True)

    # Evaluation (EMA)
    ema = tf.train.ExponentialMovingAverage(decay=0.998)
    var_class = tf.get_collection('trainable_variables', 'class/')
    ema_op = ema.apply(var_class)
    T.ema_e = des.classifier(T.test_x, False, enc_phase=1, trim=args.trim, scope='class', reuse=True, getter=get_getter(ema))
    ema_y = des.classifier(T.ema_e, False, enc_phase=0, trim=args.trim, scope='class', reuse=True, getter=get_getter(ema))

    # Back-up (teacher) model
    back_y = des.classifier(T.test_x, False, enc_phase=1, trim=0, scope='back')
    var_main = tf.get_collection('variables', 'class/(?!.*ExponentialMovingAverage:0)')
    var_back = tf.get_collection('variables', 'back/(?!.*ExponentialMovingAverage:0)')
    back_assigns = []
    init_assigns = []
    for b, m in zip(var_back, var_main):
        ave = ema.average(m)
        target = ave if ave else m
        back_assigns += [tf.assign(b, target)]
        init_assigns += [tf.assign(m, target)]
        # print "Assign {} -> {}, {}".format(target.name, b.name, m.name)
    back_update = tf.group(*back_assigns)
    init_update = tf.group(*init_assigns)

    src_acc = basic_accuracy(T.src_y, src_y)
    trg_acc = basic_accuracy(T.trg_y, trg_y)
    test_acc = basic_accuracy(T.test_y, test_y)
    ema_acc = basic_accuracy(T.test_y, ema_y)
    fn_test_acc = tb.function(T.sess, [T.test_x, T.test_y], test_acc)
    fn_ema_acc = tb.function(T.sess, [T.test_x, T.test_y], ema_acc)

    # Optimizer
    loss_main = (ramp_class * loss_class +
                 ramp_dw * loss_domain +
                 ramp_cw * loss_cent +
                 ramp_tbw * loss_t_ball +
                 ramp_gbw * loss_g_ball +
                 ramp_sbw * loss_s_ball +
                 args.te * loss_t_emb +
                 args.se * loss_s_emb +
                 ramp_gw * loss_gen +
                 ramp_gw * loss_info)
    var_main = tf.get_collection('trainable_variables', 'class')
    var_main += tf.get_collection('trainable_variables', 'gen')
    train_main = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_main,
                                                               var_list=var_main,
                                                               global_step=global_step)
    train_main = tf.group(train_main, ema_op)

    if (args.dw > 0 and args.phase == 0) or args.gw > 0:
        loss_disc = loss_ddisc + loss_gdisc
        var_disc = tf.get_collection('trainable_variables', 'disc')
        train_disc = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_disc,
                                                                   var_list=var_disc)
    else:
        train_disc = constant(0)

    # Summarizations
#    embedding = tf.Variable(tf.zeros([1000,12800]),name='embedding')
#    embedding = tf.reshape(trg_e[:1000], [-1,12800])

    summary_disc = [tf.summary.scalar('domain/loss_ddisc', loss_ddisc),
                    tf.summary.scalar('gen/loss_gdisc', loss_gdisc)]

    summary_main = [tf.summary.scalar('class/loss_class', loss_class),
                    tf.summary.scalar('class/loss_cent', loss_cent),
                    tf.summary.scalar('domain/loss_domain', loss_domain),
                    tf.summary.scalar('lipschitz/loss_t_ball', loss_t_ball),
                    tf.summary.scalar('lipschitz/loss_g_ball', loss_g_ball),
                    tf.summary.scalar('lipschitz/loss_s_ball', loss_s_ball),
                    tf.summary.scalar('embedding/loss_t_emb', loss_t_emb),
                    tf.summary.scalar('embedding/loss_s_emb', loss_s_emb),
                    tf.summary.scalar('gen/loss_gen', loss_gen),
                    tf.summary.scalar('gen/loss_info', loss_info),
                    tf.summary.scalar('ramp/ramp_class', ramp_class),
                    tf.summary.scalar('ramp/ramp_dw', ramp_dw),
                    tf.summary.scalar('ramp/ramp_cw', ramp_cw),
                    tf.summary.scalar('ramp/ramp_gw', ramp_gw),
                    tf.summary.scalar('ramp/ramp_tbw', ramp_tbw),
                    tf.summary.scalar('ramp/ramp_sbw', ramp_sbw),
                    tf.summary.scalar('ramp/ramp_gbw', ramp_gbw),
                    tf.summary.scalar('acc/src_acc', src_acc),
                    tf.summary.scalar('acc/trg_acc', trg_acc)]

    summary_disc = tf.summary.merge(summary_disc)
    summary_main = tf.summary.merge(summary_main)

    # Saved ops
    c = tf.constant
    T.ops_print = [c('ddisc'), loss_ddisc,
                   c('domain'), loss_domain,
                   c('gdisc'), loss_gdisc,
                   c('gen'), loss_gen,
                   c('info'), loss_info,
                   c('class'), loss_class,
                   c('cent'), loss_cent,
                   c('t_ball'), loss_t_ball,
                   c('g_ball'), loss_g_ball,
                   c('s_ball'), loss_s_ball,
                   c('t_emb'), loss_t_emb,
                   c('s_emb'), loss_s_emb,
                   c('src'), src_acc,
                   c('trg'), trg_acc]

    T.ops_disc = [summary_disc, train_disc]
    T.ops_main = [summary_main, train_main]
    T.fn_test_acc = fn_test_acc
    T.fn_ema_acc = fn_ema_acc
    T.back_y = tf.nn.softmax(back_y)  # Access to backed-up eval model softmax
    T.back_update = back_update       # Update op eval -> backed-up eval model
    T.init_update = init_update       # Update op eval -> student eval model
    T.global_step = global_step
    T.ramp_class = ramp_class
    if args.gw > 0:
        summary_image = tf.summary.image('image/gen', generate_img())
        T.ops_image = summary_image

    return T
Example #7
def vae():
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    T = tb.utils.TensorDict(dict(
        sess = tf.Session(config=config),
        src_x = placeholder((None, 32, 32, 3),  name='source_x'),
        src_y = placeholder((None, 10),         name='source_y'),
        trg_x = placeholder((None, 32, 32, 3),  name='target_x'),
        trg_y = placeholder((None, 10),         name='target_y'),
        test_x = placeholder((None, 32, 32, 3), name='test_x'),
        test_y = placeholder((None, 10),        name='test_y'),
        fake_z = placeholder((None, 100),       name='fake_z'),
        fake_y = placeholder((None, 10),        name='fake_y'),
        tau = placeholder((),                   name='tau'),
        phase = placeholder((), tf.bool,        name='phase'),
    ))

    if args.gw > 0:
        # Variational inference
        y_logit = des.classifier(T.trg_x, T.phase, internal_update=True)
        y = gumbel_softmax(y_logit, T.tau)
        z, z_post = des.encoder(T.trg_x, y, T.phase, internal_update=True)

        # Generation
        x = des.generator(z, y, T.phase, internal_update=True)

        # Loss
        z_prior = (0., 1.)
        kl_z = tf.reduce_mean(log_normal(z, *z_post) - log_normal(z, *z_prior))

        y_q = tf.nn.softmax(y_logit)
        log_y_q = tf.nn.log_softmax(y_logit)
        kl_y = tf.reduce_mean(tf.reduce_sum(y_q * (log_y_q - tf.log(0.1)), axis=1))

        loss_kl = kl_z + kl_y
        loss_rec = args.rw * tf.reduce_mean(tf.reduce_sum(tf.square(T.trg_x - x), axis=[1,2,3]))
        loss_gen = loss_rec + loss_kl
        trg_acc = basic_accuracy(T.trg_y, y_logit)

    else:
        loss_kl = constant(0)
        loss_rec = constant(0)
        loss_gen = constant(0)
        trg_acc = constant(0)

    # Posterior regularization (labeled classification)
    src_y = des.classifier(T.src_x, T.phase, reuse=True)
    loss_class = tf.reduce_mean(softmax_xent(labels=T.src_y, logits=src_y))
    src_acc = basic_accuracy(T.src_y, src_y)

    # Evaluation (classification)
    test_y = des.classifier(T.test_x, phase=False, reuse=True)
    test_acc = basic_accuracy(T.test_y, test_y)
    fn_test_acc = tb.function(T.sess, [T.test_x, T.test_y], test_acc)

    # Evaluation (generation)
    if args.gw > 0:
        fake_x = des.generator(T.fake_z, T.fake_y, phase=False, reuse=True)
        fn_fake_x = tb.function(T.sess, [T.fake_z, T.fake_y], fake_x)

    # Optimizer
    var_main = tf.get_collection('trainable_variables', 'gen/')
    var_main += tf.get_collection('trainable_variables', 'enc/')
    var_main += tf.get_collection('trainable_variables', 'class/')
    loss_main = args.gw * loss_gen + loss_class
    train_main = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_main, var_list=var_main)

    # Summarizations
    summary_main = [
        tf.summary.scalar('gen/loss_gen', loss_gen),
        tf.summary.scalar('gen/loss_rec', loss_rec),
        tf.summary.scalar('gen/loss_kl', loss_kl),
        tf.summary.scalar('class/loss_class', loss_class),
        tf.summary.scalar('acc/src_acc', src_acc),
        tf.summary.scalar('acc/trg_acc', trg_acc),
    ]
    summary_main = tf.summary.merge(summary_main)

    if args.gw > 0:
        summary_image = tf.summary.image('image/gen', generate_img())

    # Saved ops
    c = tf.constant
    T.ops_print = [
        c('tau'), tf.identity(T.tau),
        c('gen'), loss_gen,
        c('rec'), loss_rec,
        c('kl'), loss_kl,
        c('class'), loss_class,
    ]

    T.ops_main = [summary_main, train_main]
    T.fn_test_acc = fn_test_acc

    if args.gw > 0:
        T.fn_fake_x = fn_fake_x
        T.ops_image = summary_image

    return T
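The tau placeholder is the Gumbel-softmax temperature, so the training loop has to feed a (typically annealed) value with every batch. A hedged sketch, assuming an exponential annealing schedule and hypothetical next_src_batch/next_trg_batch helpers:

import numpy as np

# Usage sketch; the schedule constants and batch helpers are assumptions.
T = vae()
T.sess.run(tf.global_variables_initializer())
tau0, tau_min, anneal_rate = 1.0, 0.5, 1e-5

for step in range(50000):
    sx, sy = next_src_batch(64)   # labeled source batch
    tx, ty = next_trg_batch(64)   # target batch (labels only feed the accuracy summary)
    tau = max(tau_min, tau0 * np.exp(-anneal_rate * step))
    T.sess.run(T.ops_main, feed_dict={
        T.src_x: sx, T.src_y: sy,
        T.trg_x: tx, T.trg_y: ty,
        T.tau: tau, T.phase: True,
    })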
Example #8
def dirtt():
    T = tb.utils.TensorDict(dict(
        sess = tf.Session(config=tb.growth_config()),
        src_x = placeholder((None, 32, 32, 3)),
        src_y = placeholder((None, args.Y)),
        trg_x = placeholder((None, 32, 32, 3)),
        trg_y = placeholder((None, args.Y)),
        test_x = placeholder((None, 32, 32, 3)),
        test_y = placeholder((None, args.Y)),
    ))
    # Supervised and conditional entropy minimization
    src_e = nn.classifier(T.src_x, phase=True, enc_phase=1, trim=args.trim)
    trg_e = nn.classifier(T.trg_x, phase=True, enc_phase=1, trim=args.trim, reuse=True, internal_update=True)
    src_p = nn.classifier(src_e, phase=True, enc_phase=0, trim=args.trim)
    trg_p = nn.classifier(trg_e, phase=True, enc_phase=0, trim=args.trim, reuse=True, internal_update=True)

    loss_src_class = tf.reduce_mean(softmax_xent(labels=T.src_y, logits=src_p))
    loss_trg_cent = tf.reduce_mean(softmax_xent_two(labels=trg_p, logits=trg_p))

    # Domain confusion
    if args.dw > 0 and args.dirt == 0:
        real_logit = nn.feature_discriminator(src_e, phase=True)
        fake_logit = nn.feature_discriminator(trg_e, phase=True, reuse=True)

        loss_disc = 0.5 * tf.reduce_mean(
            sigmoid_xent(labels=tf.ones_like(real_logit), logits=real_logit) +
            sigmoid_xent(labels=tf.zeros_like(fake_logit), logits=fake_logit))
        loss_domain = 0.5 * tf.reduce_mean(
            sigmoid_xent(labels=tf.zeros_like(real_logit), logits=real_logit) +
            sigmoid_xent(labels=tf.ones_like(fake_logit), logits=fake_logit))

    else:
        loss_disc = constant(0)
        loss_domain = constant(0)

    # Virtual adversarial training (turn off src in non-VADA phase)
    loss_src_vat = vat_loss(T.src_x, src_p, nn.classifier) if args.sw > 0 and args.dirt == 0 else constant(0)
    loss_trg_vat = vat_loss(T.trg_x, trg_p, nn.classifier) if args.tw > 0 else constant(0)

    # Evaluation (EMA)
    ema = tf.train.ExponentialMovingAverage(decay=0.998)
    var_class = tf.get_collection('trainable_variables', 'class/')
    ema_op = ema.apply(var_class)
    ema_p = nn.classifier(T.test_x, phase=False, reuse=True, getter=tb.tfutils.get_getter(ema))

    # Teacher model (a back-up of EMA model)
    teacher_p = nn.classifier(T.test_x, phase=False, scope='teacher')
    var_main = tf.get_collection('variables', 'class/(?!.*ExponentialMovingAverage:0)')
    var_teacher = tf.get_collection('variables', 'teacher/(?!.*ExponentialMovingAverage:0)')
    teacher_assign_ops = []
    for t, m in zip(var_teacher, var_main):
        ave = ema.average(m)
        ave = ave if ave else m
        teacher_assign_ops += [tf.assign(t, ave)]
    update_teacher = tf.group(*teacher_assign_ops)
    teacher = tb.function(T.sess, [T.test_x], tf.nn.softmax(teacher_p))

    # Accuracies
    src_acc = basic_accuracy(T.src_y, src_p)
    trg_acc = basic_accuracy(T.trg_y, trg_p)
    ema_acc = basic_accuracy(T.test_y, ema_p)
    fn_ema_acc = tb.function(T.sess, [T.test_x, T.test_y], ema_acc)

    # Optimizer
    dw = constant(args.dw) if args.dirt == 0 else constant(0)
    cw = constant(1)       if args.dirt == 0 else constant(args.bw)
    sw = constant(args.sw) if args.dirt == 0 else constant(0)
    tw = constant(args.tw)
    loss_main = (dw * loss_domain +
                 cw * loss_src_class +
                 sw * loss_src_vat +
                 tw * loss_trg_cent +
                 tw * loss_trg_vat)
    var_main = tf.get_collection('trainable_variables', 'class')
    train_main = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_main, var_list=var_main)
    train_main = tf.group(train_main, ema_op)

    if args.dw > 0 and args.dirt == 0:
        var_disc = tf.get_collection('trainable_variables', 'disc')
        train_disc = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_disc, var_list=var_disc)
    else:
        train_disc = constant(0)

    # Summarizations
    summary_disc = [tf.summary.scalar('domain/loss_disc', loss_disc),]
    summary_main = [tf.summary.scalar('domain/loss_domain', loss_domain),
                    tf.summary.scalar('class/loss_src_class', loss_src_class),
                    tf.summary.scalar('class/loss_trg_cent', loss_trg_cent),
                    tf.summary.scalar('lipschitz/loss_trg_vat', loss_trg_vat),
                    tf.summary.scalar('lipschitz/loss_src_vat', loss_src_vat),
                    tf.summary.scalar('hyper/dw', dw),
                    tf.summary.scalar('hyper/cw', cw),
                    tf.summary.scalar('hyper/sw', sw),
                    tf.summary.scalar('hyper/tw', tw),
                    tf.summary.scalar('acc/src_acc', src_acc),
                    tf.summary.scalar('acc/trg_acc', trg_acc)]

    # Merge summaries
    summary_disc = tf.summary.merge(summary_disc)
    summary_main = tf.summary.merge(summary_main)

    # Saved ops
    c = tf.constant
    T.ops_print = [c('disc'), loss_disc,
                   c('domain'), loss_domain,
                   c('class'), loss_src_class,
                   c('cent'), loss_trg_cent,
                   c('trg_vat'), loss_trg_vat,
                   c('src_vat'), loss_src_vat,
                   c('src'), src_acc,
                   c('trg'), trg_acc]
    T.ops_disc = [summary_disc, train_disc]
    T.ops_main = [summary_main, train_main]
    T.fn_ema_acc = fn_ema_acc
    T.teacher = teacher
    T.update_teacher = update_teacher

    return T
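In the DIRT-T phase the EMA weights are copied into the separate 'teacher' scope, and the frozen teacher then serves as a numpy-in/numpy-out predictor. A minimal sketch of that hand-off, assuming a trained model has already been restored into the session and trg_batch() is a hypothetical target-data iterator:

T = dirtt()
T.sess.run(tf.global_variables_initializer())
# ... restore a trained checkpoint here ...

T.sess.run(T.update_teacher)   # snapshot the EMA weights into the 'teacher' scope
x = trg_batch(64)
pseudo_y = T.teacher(x)        # softmax predictions of the frozen teacher
# How pseudo_y is consumed (monitoring, pseudo-labeling, etc.) is left to the outer script.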
Example #9
def model(FLAGS, gpu_config):
    """
    :param FLAGS: Contains the experiment info
    :param gpu_config: GPU session configuration (unused here; the session is created with tb.growth_config())
    :return: (TensorDict) the model
    """

    print(colored("Model initialization started", "blue"))

    nn = network(FLAGS)
    sz = FLAGS.sz
    ch = FLAGS.ch
    bs = FLAGS.bs
    sbs = FLAGS.sbs

    alpha = constant(FLAGS.alpha)
    beta = constant(FLAGS.beta)
    theta = constant(FLAGS.theta)
    delta = constant(FLAGS.delta)

    T = tb.utils.TensorDict(
        dict(sess=tf.Session(config=tb.growth_config()),
             x=placeholder((bs, sz, sz, ch)),
             z=placeholder((bs, FLAGS.nz)),
             pos=placeholder((bs * FLAGS.jcb, FLAGS.nz)),
             iorth=placeholder((bs, FLAGS.jcb, FLAGS.jcb)),
             lrD=placeholder(None),
             lrG=placeholder(None),
             seq_in=placeholder((10, sbs, sz, sz, ch)),
             seq_out=placeholder((10, sbs, sz, sz, ch)),
             val_seq_in=placeholder((10, 10, sz, sz, ch)),
             val_seq_out=placeholder((10, 10, sz, sz, ch)),
             test_seq_in=placeholder((10, 10, sz, sz, ch)),
             lr=placeholder(None)))

    # Compute G(x, z) and G(x, 0)
    fake_x = nn.generator(T.x, T.z, phase=True)
    # T.fake_x0 = fake_x0 = nn.generator(T.x, tf.zeros_like(T.z), phase=True)
    fake_x0 = nn.generator(T.x, tf.zeros_like(T.z), phase=True)

    # Compute discriminator logits
    real_logit = nn.discriminator(T.x, phase=True)
    fake_logit = nn.discriminator(fake_x, phase=True)
    fake0_logit = nn.discriminator(fake_x0, phase=True)

    # Adversarial generator
    loss_disc = tf.reduce_mean(
        sigmoid_xent(labels=tf.ones_like(real_logit), logits=real_logit) +
        sigmoid_xent(labels=tf.zeros_like(fake_logit), logits=fake_logit) +
        theta *
        sigmoid_xent(labels=tf.zeros_like(fake0_logit), logits=fake0_logit))

    loss_fake = tf.reduce_mean(
        sigmoid_xent(labels=tf.ones_like(fake_logit), logits=fake_logit) +
        theta *
        sigmoid_xent(labels=tf.ones_like(fake0_logit), logits=fake0_logit))

    # Locality
    loss_local = tf.reduce_mean(abs_diff(labels=T.x, predictions=fake_x0))

    # Orthogonality
    pos = T.pos * delta
    tiled_real_x = tf.tile(T.x, [FLAGS.jcb, 1, 1, 1])
    pos_fake_x = nn.generator(tiled_real_x, pos, phase=True)
    neg_fake_x = nn.generator(tiled_real_x, -pos, phase=True)

    jx = (pos_fake_x - neg_fake_x) / (2 * delta)
    jx = tf.reshape(jx, [bs, FLAGS.jcb, -1])
    jx_t = tf.transpose(jx, [0, 2, 1])
    loss_orth = tf.reduce_mean(abs_diff(tf.matmul(jx, jx_t), T.iorth))

    loss_gen = loss_fake + alpha * loss_local + beta * loss_orth

    # Optimizer
    var_disc = tf.get_collection('trainable_variables', 'lgan/dsc')
    train_disc = tf.train.AdamOptimizer(T.lrD, 0.5).minimize(loss_disc,
                                                             var_list=var_disc)

    if FLAGS.clip:
        clip_disc = [
            p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in var_disc
        ]

    var_gen = tf.get_collection('trainable_variables', 'lgan/gen')
    train_gen = tf.train.AdamOptimizer(T.lrG, 0.5).minimize(loss_gen,
                                                            var_list=var_gen)

    # Summarizations
    summary_disc = [
        tf.summary.scalar('disc/loss_disc', loss_disc),
    ]
    summary_gen = [
        tf.summary.scalar('gen/loss_gen', loss_gen),
        tf.summary.scalar('gen/loss_fake', loss_fake),
        tf.summary.scalar('gen/loss_local', loss_local),
        tf.summary.scalar('gen/loss_orth', loss_orth),
        tf.summary.scalar('hyper/alpha', alpha),
        tf.summary.scalar('hyper/beta', beta),
        tf.summary.scalar('hyper/theta', theta),
        tf.summary.scalar('hyper/delta', delta),
        tf.summary.scalar('hyper/lrD', T.lrD),
        tf.summary.scalar('hyper/lrG', T.lrG),
        tf.summary.scalar('hyper/var', FLAGS.var)
    ]
    summary_image = [
        tf.summary.image('image/x', T.x),
        tf.summary.image('image/fake_x', fake_x),
        tf.summary.image('image/fake_x0', fake_x0)
    ]
    # Merge summaries
    summary_disc = tf.summary.merge(summary_disc)
    summary_gen = tf.summary.merge(summary_gen)
    summary_image = tf.summary.merge(summary_image)

    # Saved ops
    c = tf.constant
    T.ops_print = [
        c('disc'), loss_disc,
        c('gen'), loss_gen,
        c('fake'), loss_fake,
        c('local'), loss_local,
        c('orth'), loss_orth
    ]
    # T.ops_disc = [summary_disc, train_disc]

    if FLAGS.clip:
        T.ops_disc = [summary_disc, train_disc, clip_disc]
    else:
        T.ops_disc = [summary_disc, train_disc]

    T.ops_gen = [summary_gen, train_gen]
    T.ops_image = summary_image

    if FLAGS.phase:
        # LSTM initialization
        seq_in = tf.reshape(T.seq_in, [-1, sz, sz, ch])
        seq_out = tf.reshape(T.seq_out, [-1, sz, sz, ch])
        val_seq_in = tf.reshape(T.val_seq_in, [-1, sz, sz, ch])
        test_seq_in = tf.reshape(T.test_seq_in, [-1, sz, sz, ch])
        enc_in = nn.generator(seq_in,
                              tf.zeros((10 * sbs, FLAGS.nz)),
                              phase=True,
                              enc=True)
        enc_out = nn.generator(seq_out,
                               tf.zeros((10 * sbs, FLAGS.nz)),
                               phase=True,
                               enc=True)
        val_enc_in = nn.generator(val_seq_in,
                                  tf.zeros((10 * 10, FLAGS.nz)),
                                  phase=True,
                                  enc=True)
        test_enc_in = nn.generator(test_seq_in,
                                   tf.zeros((10 * 10, FLAGS.nz)),
                                   phase=True,
                                   enc=True)
        enc_in = tf.stop_gradient(enc_in)
        enc_out = tf.stop_gradient(enc_out)
        val_enc_in = tf.stop_gradient(val_enc_in)
        test_enc_in = tf.stop_gradient(test_enc_in)
        enc_in = tf.squeeze(enc_in)
        enc_out = tf.squeeze(enc_out)
        val_enc_in = tf.squeeze(val_enc_in)
        test_enc_in = tf.squeeze(test_enc_in)
        enc_in = tf.reshape(enc_in, [-1, sbs, 3 * FLAGS.nz])
        enc_out = tf.reshape(enc_out, [-1, sbs, 3 * FLAGS.nz])
        val_enc_in = tf.reshape(val_enc_in, [-1, 10, 3 * FLAGS.nz])
        test_enc_in = tf.reshape(test_enc_in, [-1, 10, 3 * FLAGS.nz])

        with tf.variable_scope('lstm/in'):
            in_cell = tf.contrib.cudnn_rnn.CudnnLSTM(FLAGS.nhl,
                                                     FLAGS.nhw,
                                                     dropout=0.5)

            _, in_states = in_cell(enc_in, initial_state=None, training=True)
            _, val_in_states = in_cell(val_enc_in,
                                       initial_state=None,
                                       training=False)
            _, test_in_states = in_cell(test_enc_in,
                                        initial_state=None,
                                        training=False)

        with tf.variable_scope('lstm/out'):
            out_cell = tf.contrib.cudnn_rnn.CudnnLSTM(FLAGS.nhl,
                                                      FLAGS.nhw,
                                                      dropout=0.5)

            outputs, _ = out_cell(tf.zeros_like(enc_out),
                                  initial_state=in_states,
                                  training=True)
            val_outputs, _ = out_cell(tf.zeros_like(val_enc_in),
                                      initial_state=val_in_states,
                                      training=False)
            test_outputs, _ = out_cell(tf.zeros_like(test_enc_in),
                                       initial_state=test_in_states,
                                       training=False)

            enc_out_pred = tf.layers.dense(outputs,
                                           3 * FLAGS.nz,
                                           activation=None,
                                           name='lstm_dense',
                                           reuse=tf.AUTO_REUSE)
            val_enc_out_pred = tf.layers.dense(val_outputs,
                                               3 * FLAGS.nz,
                                               activation=None,
                                               name='lstm_dense',
                                               reuse=tf.AUTO_REUSE)
            test_enc_out_pred = tf.layers.dense(test_outputs,
                                                3 * FLAGS.nz,
                                                activation=None,
                                                name='lstm_dense',
                                                reuse=tf.AUTO_REUSE)

        enc_out_pred_reshape = tf.reshape(enc_out_pred, [-1, 3 * FLAGS.nz])
        enc_out_pred_reshape = tf.expand_dims(
            tf.expand_dims(enc_out_pred_reshape, 1), 1)
        val_enc_out_pred_reshape = tf.reshape(val_enc_out_pred,
                                              [-1, 3 * FLAGS.nz])
        val_enc_out_pred_reshape = tf.expand_dims(
            tf.expand_dims(val_enc_out_pred_reshape, 1), 1)
        test_enc_out_pred_reshape = tf.reshape(test_enc_out_pred,
                                               [-1, 3 * FLAGS.nz])
        test_enc_out_pred_reshape = tf.expand_dims(
            tf.expand_dims(test_enc_out_pred_reshape, 1), 1)

        seq_out_pred = nn.generator(enc_out_pred_reshape,
                                    tf.zeros((10 * sbs, FLAGS.nz)),
                                    phase=True,
                                    dec=True)
        seq_out_pred = tf.reshape(seq_out_pred, [10, sbs, sz, sz, ch])
        val_seq_out_pred = nn.generator(val_enc_out_pred_reshape,
                                        tf.zeros((10 * 10, FLAGS.nz)),
                                        phase=True,
                                        dec=True)
        val_seq_out_pred = tf.reshape(val_seq_out_pred, [10, 10, sz, sz, ch])
        test_seq_out_pred = nn.generator(test_enc_out_pred_reshape,
                                         tf.zeros((10 * 10, FLAGS.nz)),
                                         phase=True,
                                         dec=True)
        T.test_seq_out_pred = tf.reshape(test_seq_out_pred,
                                         [10, 10, sz, sz, ch])

        T.val_mae = tf.reduce_mean(
            abs_diff(labels=T.val_seq_out, predictions=val_seq_out_pred))
        loss_lstm = tf.reduce_mean(
            abs_diff(labels=enc_out, predictions=enc_out_pred))
        var_lstm = tf.get_collection('trainable_variables', 'lstm')
        # train_lstm = tf.train.AdamOptimizer(FLAGS.lr, 0.5).minimize(loss_lstm, var_list=var_lstm)
        train_lstm = tf.train.AdamOptimizer(T.lr,
                                            0.5).minimize(loss_lstm,
                                                          var_list=var_lstm)

        summary_lstm = [tf.summary.scalar('lstm/loss_lstm', loss_lstm)]
        summary_lstm_image = [
            tf.summary.image('lstm/seq_out', T.seq_out[:, 0, :, :, :]),
            tf.summary.image('lstm/seq_out_pred', seq_out_pred[:, 0, :, :, :])
        ]
        summary_lstm = tf.summary.merge(summary_lstm)
        summary_lstm_image = tf.summary.merge(summary_lstm_image)

        T.ops_lstm_print = [c('loss_lstm'), loss_lstm]
        T.ops_lstm = [summary_lstm, train_lstm]
        T.ops_lstm_image = summary_lstm_image

        # T.test1 = seq_out_pred

    print(colored("Model initialization ended", "blue"))

    return T
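The returned ops suggest an alternating schedule: T.ops_disc and T.ops_gen are run per image batch, and T.ops_lstm only exists when FLAGS.phase is set. A hedged sketch of one outer iteration, assuming hypothetical image_batch()/seq_batch() helpers, Gaussian z/pos draws, an identity target for iorth, and fixed learning rates:

import numpy as np

T = model(FLAGS, gpu_config)
T.sess.run(tf.global_variables_initializer())

feed = {
    T.x: image_batch(FLAGS.bs),
    T.z: np.random.normal(size=(FLAGS.bs, FLAGS.nz)),
    T.pos: np.random.normal(size=(FLAGS.bs * FLAGS.jcb, FLAGS.nz)),
    T.iorth: np.tile(np.eye(FLAGS.jcb), (FLAGS.bs, 1, 1)),  # assumed orthogonality target
    T.lrD: 1e-4, T.lrG: 1e-4,
}
T.sess.run(T.ops_disc, feed)   # discriminator step (plus weight clipping if FLAGS.clip)
T.sess.run(T.ops_gen, feed)    # generator step: adversarial + locality + orthogonality

if FLAGS.phase:                # LSTM stage on encoded sequences
    seq_in, seq_out = seq_batch(FLAGS.sbs)
    T.sess.run(T.ops_lstm, {T.seq_in: seq_in, T.seq_out: seq_out, T.lr: 1e-4})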
Example #10
def vae():
    T = tb.utils.TensorDict(
        dict(
            sess=tf.Session(config=tb.growth_config()),
            trg_x=placeholder((None, 32, 32, 3), name='target_x'),
            fake_z=placeholder((None, args.Z), name='fake_z'),
        ))

    # Inference
    z, z_post = nn.encoder(T.trg_x, phase=True, internal_update=True)

    # Generation
    x = nn.generator(z, phase=True, internal_update=True)

    # Loss
    loss_rec, loss_kl, loss_gen = vae_loss(x, T.trg_x, z, z_post)

    # Evaluation (embedding, reconstruction, loss)
    test_z, test_z_post = nn.encoder(T.trg_x, phase=False, reuse=True)
    test_x = nn.generator(test_z, phase=False, reuse=True)
    _, _, test_loss = vae_loss(test_x, T.trg_x, test_z, test_z_post)
    fn_embed = tb.function(T.sess, [T.trg_x], test_z_post)
    fn_recon = tb.function(T.sess, [T.trg_x], test_x)
    fn_loss = tb.function(T.sess, [T.trg_x], test_loss)

    # Evaluation (generation)
    fake_x = nn.generator(T.fake_z, phase=False, reuse=True)
    fn_generate = tb.function(T.sess, [T.fake_z], fake_x)

    # Optimizer
    var_main = tf.get_collection('trainable_variables', 'gen/')
    var_main += tf.get_collection('trainable_variables', 'enc/')
    loss_main = loss_gen
    train_main = tf.train.AdamOptimizer(args.lr,
                                        0.5).minimize(loss_main,
                                                      var_list=var_main)

    # Summarizations
    summary_main = [
        tf.summary.scalar('gen/loss_gen', loss_gen),
        tf.summary.scalar('gen/loss_kl', loss_kl),
        tf.summary.scalar('gen/loss_rec', loss_rec),
    ]
    summary_image = [tf.summary.image('gen/gen', generate_image(nn.generator))]

    # Merge summaries
    summary_main = tf.summary.merge(summary_main)
    summary_image = tf.summary.merge(summary_image)

    # Saved ops
    c = tf.constant
    T.ops_print = [
        c('gen'),
        loss_gen,
        c('kl'),
        loss_kl,
        c('rec'),
        loss_rec,
    ]

    T.ops_main = [summary_main, train_main]
    T.ops_image = summary_image
    T.fn_embed = fn_embed
    T.fn_recon = fn_recon
    T.fn_loss = fn_loss
    T.fn_generate = fn_generate

    return T
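Once trained, the returned tb.function closures make the VAE usable as plain numpy-in/numpy-out helpers. A short sketch, assuming images is a (N, 32, 32, 3) float array prepared by the caller:

import numpy as np

T = vae()
T.sess.run(tf.global_variables_initializer())
# ... train with T.ops_main or restore a checkpoint ...

z_post = T.fn_embed(images)                                   # posterior parameters from the encoder
recon = T.fn_recon(images)                                    # reconstructions G(E(x))
samples = T.fn_generate(np.random.normal(size=(16, args.Z)))  # decode draws from the prior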
Example #11
from config import args, get_log_dir
import numpy as np
from data import Mnist, Svhn
import tensorflow as tf
import tensorbayes as tb
from tensorbayes.layers import placeholder

################
# Set up model #
################
T = tb.utils.TensorDict(dict(
    sess = tf.Session(),
    trg_x = placeholder((None, 32, 32, 1), name='trg_x'),
    trg_d = placeholder((None, 32, 32, 1), name='trg_d'),
    src_x = placeholder((None, 32, 32, 3), name='src_x'),
    src_y = placeholder((None, 10), name='src_y'),
    test_x = placeholder((None, 32, 32, 3), name='test_x'),
    test_y = placeholder((None, 10), name='test_y'),
    phase = placeholder((), tf.bool, name='phase')
))

exec("from {0:s} import {0:s}".format(args.model))
exec("T = {:s}(T)".format(args.model))
T.sess.run(tf.global_variables_initializer())

if args.model != 'classifier':
    path = tf.train.latest_checkpoint('save')
    restorer = tf.train.Saver(tf.get_collection('trainable_variables', 'enc'))
    restorer.restore(T.sess, path)

#############
Example #12
def model(FLAGS, gpu_config):
    """
    :param FLAGS: Contains the experiment info
    :param gpu_config: GPU session configuration (unused here; the session is created with tb.growth_config())
    :return: (TensorDict) the model
    """

    print(colored("Model initialization started", "blue"))

    nn = network(FLAGS)
    sz = FLAGS.sz
    ch = FLAGS.ch
    bs = FLAGS.bs
    sbs = FLAGS.sbs

    T = tb.utils.TensorDict(dict(
        sess=tf.Session(config=tb.growth_config()),
        x=placeholder((bs, sz, sz, ch)),
        lrD=placeholder(None),
        lrG=placeholder(None),
        seq_in=placeholder((10, sbs, sz, sz, ch)),
        seq_out=placeholder((10, sbs, sz, sz, ch)),
        val_seq_in=placeholder((10, 10, sz, sz, ch)),
        val_seq_out=placeholder((10, 10, sz, sz, ch)),
        test_seq_in=placeholder((10, 10, sz, sz, ch)),
        lr=placeholder(None)
    ))

    recon_x = nn.generator(T.x, phase=True)

    # Compute discriminator logits
    real_logit = nn.discriminator(T.x, phase=True)
    fake_logit = nn.discriminator(recon_x, phase=True)

    # Adversarial generator
    loss_disc = tf.reduce_mean(
        sigmoid_xent(labels=tf.ones_like(real_logit), logits=real_logit) +
        sigmoid_xent(labels=tf.zeros_like(fake_logit), logits=fake_logit))
    loss_fake = tf.reduce_mean(
        sigmoid_xent(labels=tf.ones_like(fake_logit), logits=fake_logit))

    loss_local = tf.reduce_mean(abs_diff(labels=T.x, predictions=recon_x))

    loss_gen = loss_fake + FLAGS.alpha * loss_local

    var_gen = tf.get_collection('trainable_variables', 'lgan/gen')
    train_gen = tf.train.AdamOptimizer(T.lrG, 0.5).minimize(loss_gen, var_list=var_gen)

    var_disc = tf.get_collection('trainable_variables', 'lgan/dsc')
    train_disc = tf.train.AdamOptimizer(T.lrD, 0.5).minimize(loss_disc, var_list=var_disc)

    # Summarizations
    summary_disc = [tf.summary.scalar('disc/loss_disc', loss_disc)]
    summary_gen = [tf.summary.scalar('gen/loss_gen', loss_gen),
                   tf.summary.scalar('gen/loss_local', loss_local),
                   tf.summary.scalar('gen/loss_fake', loss_fake),
                   tf.summary.scalar('hyper/lrD', T.lrD),
                   tf.summary.scalar('hyper/lrG', T.lrG)]
    summary_image = [tf.summary.image('image/x', T.x),
                     tf.summary.image('image/recon_x', recon_x)]

    # Merge summaries
    summary_disc = tf.summary.merge(summary_disc)
    summary_gen = tf.summary.merge(summary_gen)
    summary_image = tf.summary.merge(summary_image)

    # Saved ops
    c = tf.constant
    T.ops_print = [c('disc'), loss_disc,
                   c('gen'), loss_gen,
                   c('local'), loss_local,
                   c('fake'), loss_fake]
    T.ops_disc = [summary_disc, train_disc]
    T.ops_gen = [summary_gen, train_gen]
    T.ops_image = summary_image

    if FLAGS.phase:
        # LSTM initialization
        seq_in = tf.reshape(T.seq_in, [-1, sz, sz, ch])
        seq_out = tf.reshape(T.seq_out, [-1, sz, sz, ch])
        val_seq_in = tf.reshape(T.val_seq_in, [-1, sz, sz, ch])
        test_seq_in = tf.reshape(T.test_seq_in, [-1, sz, sz, ch])
        enc_in = nn.generator(seq_in, phase=True, enc=True)
        enc_out = nn.generator(seq_out, phase=True, enc=True)
        val_enc_in = nn.generator(val_seq_in, phase=True, enc=True)
        test_enc_in = nn.generator(test_seq_in, phase=True, enc=True)
        enc_in = tf.stop_gradient(enc_in)
        enc_out = tf.stop_gradient(enc_out)
        val_enc_in = tf.stop_gradient(val_enc_in)
        test_enc_in = tf.stop_gradient(test_enc_in)
        enc_in = tf.squeeze(enc_in)
        enc_out = tf.squeeze(enc_out)
        val_enc_in = tf.squeeze(val_enc_in)
        test_enc_in = tf.squeeze(test_enc_in)
        enc_in = tf.reshape(enc_in, [-1, sbs, FLAGS.nz])
        enc_out = tf.reshape(enc_out, [-1, sbs, FLAGS.nz])
        val_enc_in = tf.reshape(val_enc_in, [-1, 10, FLAGS.nz])
        test_enc_in = tf.reshape(test_enc_in, [-1, 10, FLAGS.nz])

        with tf.variable_scope('lstm/in'):
            in_cell = tf.contrib.cudnn_rnn.CudnnLSTM(FLAGS.nhl, FLAGS.nhw, dropout=0.5)

            _, in_states = in_cell(enc_in, initial_state=None, training=True)
            _, val_in_states = in_cell(val_enc_in, initial_state=None, training=False)
            _, test_in_states = in_cell(test_enc_in, initial_state=None, training=False)

        with tf.variable_scope('lstm/out'):
            out_cell = tf.contrib.cudnn_rnn.CudnnLSTM(FLAGS.nhl, FLAGS.nhw, dropout=0.5)

            outputs, _ = out_cell(tf.zeros_like(enc_out), initial_state=in_states, training=True)
            val_outputs, _ = out_cell(tf.zeros_like(val_enc_in), initial_state=val_in_states, training=False)
            test_outputs, _ = out_cell(tf.zeros_like(test_enc_in), initial_state=test_in_states, training=False)

            enc_out_pred = tf.layers.dense(outputs, FLAGS.nz, activation=None, name='lstm_dense', reuse=tf.AUTO_REUSE)
            val_enc_out_pred = tf.layers.dense(val_outputs, FLAGS.nz, activation=None, name='lstm_dense',
                                               reuse=tf.AUTO_REUSE)
            test_enc_out_pred = tf.layers.dense(test_outputs, FLAGS.nz, activation=None, name='lstm_dense',
                                               reuse=tf.AUTO_REUSE)

        enc_out_pred_reshape = tf.reshape(enc_out_pred, [-1, FLAGS.nz])
        enc_out_pred_reshape = tf.expand_dims(tf.expand_dims(enc_out_pred_reshape, 1), 1)
        val_enc_out_pred_reshape = tf.reshape(val_enc_out_pred, [-1, FLAGS.nz])
        val_enc_out_pred_reshape = tf.expand_dims(tf.expand_dims(val_enc_out_pred_reshape, 1), 1)
        test_enc_out_pred_reshape = tf.reshape(test_enc_out_pred, [-1, FLAGS.nz])
        test_enc_out_pred_reshape = tf.expand_dims(tf.expand_dims(test_enc_out_pred_reshape, 1), 1)

        seq_out_pred = nn.generator(enc_out_pred_reshape, phase=True, dec=True)
        seq_out_pred = tf.reshape(seq_out_pred, [10, sbs, sz, sz, ch])
        val_seq_out_pred = nn.generator(val_enc_out_pred_reshape, phase=True, dec=True)
        val_seq_out_pred = tf.reshape(val_seq_out_pred, [10, 10, sz, sz, ch])
        test_seq_out_pred = nn.generator(test_enc_out_pred_reshape, phase=True, dec=True)
        T.test_seq_out_pred = tf.reshape(test_seq_out_pred, [10, 10, sz, sz, ch])

        T.val_mae = tf.reduce_mean(abs_diff(labels=T.val_seq_out, predictions=val_seq_out_pred))
        loss_lstm = tf.reduce_mean(abs_diff(labels=enc_out, predictions=enc_out_pred))
        var_lstm = tf.get_collection('trainable_variables', 'lstm')
        # train_lstm = tf.train.AdamOptimizer(FLAGS.lr, 0.5).minimize(loss_lstm, var_list=var_lstm)
        train_lstm = tf.train.AdamOptimizer(T.lr, 0.5).minimize(loss_lstm, var_list=var_lstm)

        summary_lstm = [tf.summary.scalar('lstm/loss_lstm', loss_lstm)]
        summary_lstm_image = [tf.summary.image('lstm/seq_out', T.seq_out[:, 0, :, :, :]),
                              tf.summary.image('lstm/seq_out_pred', seq_out_pred[:, 0, :, :, :])]
        summary_lstm = tf.summary.merge(summary_lstm)
        summary_lstm_image = tf.summary.merge(summary_lstm_image)

        T.ops_lstm_print = [c('loss_lstm'), loss_lstm]
        T.ops_lstm = [summary_lstm, train_lstm]
        T.ops_lstm_image = summary_lstm_image

    print(colored("Model initialization ended", "blue"))

    return T
Example #13
def dirtt():
    T = tb.utils.TensorDict(dict(
        sess = tf.Session(config=tb.growth_config()),
        src_x = placeholder((None, 500, 60, 1)),
        src_y = placeholder((None, args.Y)),
        trg_x = placeholder((None, 500, 60, 1)),
        trg_y = placeholder((None, args.Y)),
        test_x = placeholder((None, 500, 60, 1)),
        test_y = placeholder((None, args.Y)),
    ))
    # Supervised and conditional entropy minimization
    src_e = nn.classifier(T.src_x, phase=True, enc_phase=1, trim=args.trim)
    trg_e = nn.classifier(T.trg_x, phase=True, enc_phase=1, trim=args.trim, reuse=True, internal_update=True)
    src_p = nn.classifier(src_e, phase=True, enc_phase=0, trim=args.trim)
    trg_p = nn.classifier(trg_e, phase=True, enc_phase=0, trim=args.trim, reuse=True, internal_update=True)

    loss_src_class = tf.reduce_mean(softmax_xent(labels=T.src_y, logits=src_p))
    loss_trg_cent = tf.reduce_mean(softmax_xent_two(labels=trg_p, logits=trg_p))

    # Domain confusion
    if args.dw > 0 and args.dirt == 0:
        real_logit = nn.feature_discriminator(src_e, phase=True)
        fake_logit = nn.feature_discriminator(trg_e, phase=True, reuse=True)

        loss_disc = 0.5 * tf.reduce_mean(
            sigmoid_xent(labels=tf.ones_like(real_logit), logits=real_logit) +
            sigmoid_xent(labels=tf.zeros_like(fake_logit), logits=fake_logit))
        loss_domain = 0.5 * tf.reduce_mean(
            sigmoid_xent(labels=tf.zeros_like(real_logit), logits=real_logit) +
            sigmoid_xent(labels=tf.ones_like(fake_logit), logits=fake_logit))

    else:
        loss_disc = constant(0)
        loss_domain = constant(0)

    # Virtual adversarial training (turn off src in non-VADA phase)
    loss_src_vat = vat_loss(T.src_x, src_p, nn.classifier) if args.sw > 0 and args.dirt == 0 else constant(0)
    loss_trg_vat = vat_loss(T.trg_x, trg_p, nn.classifier) if args.tw > 0 else constant(0)

    # Evaluation (EMA)
    ema = tf.train.ExponentialMovingAverage(decay=0.998)
    var_class = tf.get_collection('trainable_variables', 'class/')
    ema_op = ema.apply(var_class)
    ema_p = nn.classifier(T.test_x, phase=False, reuse=True, getter=tb.tfutils.get_getter(ema))

    # Teacher model (a back-up of EMA model)
    teacher_p = nn.classifier(T.test_x, phase=False, scope='teacher')
    var_main = tf.get_collection('variables', 'class/(?!.*ExponentialMovingAverage:0)')
    var_teacher = tf.get_collection('variables', 'teacher/(?!.*ExponentialMovingAverage:0)')
    teacher_assign_ops = []
    for t, m in zip(var_teacher, var_main):
        ave = ema.average(m)
        ave = ave if ave else m
        teacher_assign_ops += [tf.assign(t, ave)]
    update_teacher = tf.group(*teacher_assign_ops)
    teacher = tb.function(T.sess, [T.test_x], tf.nn.softmax(teacher_p))

    # Accuracies
    src_acc = basic_accuracy(T.src_y, src_p)
    trg_acc = basic_accuracy(T.trg_y, trg_p)
    ema_acc = basic_accuracy(T.test_y, ema_p)
    fn_ema_acc = tb.function(T.sess, [T.test_x, T.test_y], ema_acc)

    # Optimizer
    dw = constant(args.dw) if args.dirt == 0 else constant(0)
    cw = constant(1)       if args.dirt == 0 else constant(args.bw)
    sw = constant(args.sw) if args.dirt == 0 else constant(0)
    tw = constant(args.tw)
    loss_main = (dw * loss_domain +
                 cw * loss_src_class +
                 sw * loss_src_vat +
                 tw * loss_trg_cent +
                 tw * loss_trg_vat)
    var_main = tf.get_collection('trainable_variables', 'class')
    train_main = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_main, var_list=var_main)
    train_main = tf.group(train_main, ema_op)

    if args.dw > 0 and args.dirt == 0:
        var_disc = tf.get_collection('trainable_variables', 'disc')
        train_disc = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_disc, var_list=var_disc)
    else:
        train_disc = constant(0)

    # Summarizations
    summary_disc = [tf.summary.scalar('domain/loss_disc', loss_disc),]
    summary_main = [tf.summary.scalar('domain/loss_domain', loss_domain),
                    tf.summary.scalar('class/loss_src_class', loss_src_class),
                    tf.summary.scalar('class/loss_trg_cent', loss_trg_cent),
                    tf.summary.scalar('lipschitz/loss_trg_vat', loss_trg_vat),
                    tf.summary.scalar('lipschitz/loss_src_vat', loss_src_vat),
                    tf.summary.scalar('hyper/dw', dw),
                    tf.summary.scalar('hyper/cw', cw),
                    tf.summary.scalar('hyper/sw', sw),
                    tf.summary.scalar('hyper/tw', tw),
                    tf.summary.scalar('acc/src_acc', src_acc),
                    tf.summary.scalar('acc/trg_acc', trg_acc)]

    # Merge summaries
    summary_disc = tf.summary.merge(summary_disc)
    summary_main = tf.summary.merge(summary_main)

    # Saved ops
    c = tf.constant
    T.ops_print = [c('disc'), loss_disc,
                   c('domain'), loss_domain,
                   c('class'), loss_src_class,
                   c('cent'), loss_trg_cent,
                   c('trg_vat'), loss_trg_vat,
                   c('src_vat'), loss_src_vat,
                   c('src'), src_acc,
                   c('trg'), trg_acc]
    T.ops_disc = [summary_disc, train_disc]
    T.ops_main = [summary_main, train_main]
    T.fn_ema_acc = fn_ema_acc
    T.teacher = teacher
    T.update_teacher = update_teacher

    return T
Example #14
0
    def e_step(self, x_data):
        print "e_step finetuning"
        tf.reset_default_graph()
        self.x_ = placeholder(
            (None, self.input_dim))  # we need these global nodes
        self.v_ = placeholder((None, self.num_factors))

        self.sess = tf.Session()
        # NOTE: the graph is rebuilt below, so variable initialization is
        # deferred until after the inference/generative layers and the
        # optimizer have been constructed.

        # inference process

        x = self.x_
        depth_inf = len(self.encoding_dims)
        for i in range(depth_inf):
            x = dense(x,
                      self.encoding_dims[i],
                      scope="enc_layer" + "%s" % i,
                      activation=tf.nn.sigmoid)
            # print("enc_layer0/weights:0".graph)
        h_encode = x
        z_mu = dense(h_encode, self.z_dim, scope="mu_layer")
        z_log_sigma_sq = dense(h_encode, self.z_dim, scope="sigma_layer")
        e = tf.random_normal(tf.shape(z_mu))
        z = z_mu + tf.sqrt(tf.maximum(tf.exp(z_log_sigma_sq), self.eps)) * e
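        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I);
        # the maximum() keeps the variance bounded away from zero by self.eps.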

        # generative process
        depth_gen = len(self.decoding_dims)
        y = z
        for i in range(depth_gen):
            # feed the previous layer's output forward (not z every time)
            y = dense(y,
                      self.decoding_dims[i],
                      scope="dec_layer" + "%s" % i,
                      activation=tf.nn.sigmoid)
            # if last_layer_nonelinear: depth_gen -1

        x_recons = y

        if self.loss_type == "cross_entropy":
            loss_recons = tf.reduce_mean(
                tf.reduce_sum(binary_crossentropy(x_recons, self.x_, self.eps),
                              axis=1))
            loss_kl = 0.5 * tf.reduce_mean(
                tf.reduce_sum(
                    tf.square(z_mu) + tf.exp(z_log_sigma_sq) - z_log_sigma_sq -
                    1, 1))
            loss_v = 1.0 * self.params.lambda_v / self.params.lambda_r * tf.reduce_mean(
                tf.reduce_sum(tf.square(self.v_ - z), 1))
            # reg_loss: we do not use a regularization loss for now
        self.loss_e_step = loss_recons + loss_kl + loss_v
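        # loss_e_step (assumes loss_type == "cross_entropy"):
        #   reconstruction + KL(q(z|x) || N(0, I)) + a quadratic term tying
        #   the latent code z to the supplied factors self.v_
        #   (weighted by lambda_v / lambda_r).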
        train_op = tf.train.AdamOptimizer(self.params.learning_rate).minimize(
            self.loss_e_step)

        ckpt_file = "pre_model/" + "vae.ckpt"
        self.saver = tf.train.Saver()
        # if init == True:
        self.saver.restore(self.sess, ckpt_file)
        for i in range(self.params.num_iter):
            idx = np.random.choice(self.num_items,
                                   self.params.batch_size,
                                   replace=False)
            x_batch = x_data[idx]
            v_batch = self.V[idx]
            _, l = self.sess.run((train_op, self.loss_e_step),
                                 feed_dict={
                                     self.x_: x_batch,
                                     self.v_: v_batch
                                 })
            if i % 50 == 0:
                print "{:>10s}{:>10s}".format("epochs", "loss_e_step")
                print "{:>10d}{:>10.2e}".format(i, l)

        self.z_mu = z_mu
        self.x_recons = x_recons
        self.saver.save(self.sess, ckpt_file)
        return None
Example #15
0
def gada():
    T = tb.utils.TensorDict(dict(
        sess = tf.Session(config=tb.growth_config()),
        src_x = placeholder((None, 32, 32, 3)),
        src_y = placeholder((None, args.Y)),
        trg_x = placeholder((None, 32, 32, 3)),
        trg_y = placeholder((None, args.Y)),
        trg_z = placeholder((None, 100)),
        test_x = placeholder((None, 32, 32, 3)),
        test_y = placeholder((None, args.Y)),
    ))

    # Supervised and conditional entropy minimization
    src_e = nn.classifier(T.src_x, phase=True, enc_phase=1, enc_trim=args.etrim)
    src_g = nn.classifier(src_e, phase=True, gen_trim=args.gtrim, gen_phase=1, enc_trim=args.etrim)
    src_p = nn.classifier(src_g, phase=True, gen_trim=args.gtrim)
    trg_e = nn.classifier(T.trg_x, phase=True, enc_phase=1, enc_trim=args.etrim, reuse=True, internal_update=True)
    trg_g = nn.classifier(trg_e, phase=True, gen_trim=args.gtrim, gen_phase=1, enc_trim=args.etrim, reuse=True, internal_update=True)
    trg_p = nn.classifier(trg_g, phase=True, gen_trim=args.gtrim, reuse=True, internal_update=True)

    loss_src_class = tf.reduce_mean(softmax_xent(labels=T.src_y, logits=src_p))
    loss_trg_cent = tf.reduce_mean(softmax_xent_two(labels=trg_p, logits=trg_p)) if args.tw > 0 else constant(0)
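    # Using the target predictions as both labels and logits gives a
    # conditional-entropy-style confidence penalty on the target domain
    # (only built when the target weight args.tw is positive).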

    # Domain confusion
    if args.dw > 0 and args.dirt == 0:
        real_logit = nn.real_feature_discriminator(src_e, phase=True)
        fake_logit = nn.real_feature_discriminator(trg_e, phase=True, reuse=True)

        loss_disc = 0.5 * tf.reduce_mean(
            sigmoid_xent(labels=tf.ones_like(real_logit), logits=real_logit) +
            sigmoid_xent(labels=tf.zeros_like(fake_logit), logits=fake_logit))
        loss_domain = 0.5 * tf.reduce_mean(
            sigmoid_xent(labels=tf.zeros_like(real_logit), logits=real_logit) +
            sigmoid_xent(labels=tf.ones_like(fake_logit), logits=fake_logit))

    else:
        loss_disc = constant(0)
        loss_domain = constant(0)

    # Virtual adversarial training (turn off src in non-VADA phase)
    loss_src_vat = vat_loss(T.src_x, src_p, nn.classifier) if args.sw > 0 and args.dirt == 0 else constant(0)
    loss_trg_vat = vat_loss(T.trg_x, trg_p, nn.classifier) if args.tw > 0 else constant(0)

    # Generate images and process generated images
    trg_gen_x = nn.trg_generator(T.trg_z)
    trg_gen_e = nn.classifier(trg_gen_x, phase=True, enc_phase=1, enc_trim=args.etrim, reuse=True, internal_update=True)
    trg_gen_g = nn.classifier(trg_gen_e, phase=True, gen_trim=args.gtrim, gen_phase=1, enc_trim=args.etrim, reuse=True, internal_update=True)
    trg_gen_p = nn.classifier(trg_gen_g, phase=True, gen_trim=args.gtrim, reuse=True, internal_update=True)

    # Feature matching loss function for generator
    loss_trg_gen_fm = tf.reduce_mean(tf.square(tf.reduce_mean(trg_g, axis=0) - tf.reduce_mean(trg_gen_g, axis=0))) if args.dirt == 0 else constant(0)
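    # Feature matching: the generator is trained to match the mean
    # intermediate activations (trg_g vs. trg_gen_g) of real and generated
    # target batches, rather than fooling a separate image discriminator.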

    # Unsupervised loss function
    if args.dirt == 0:
        logit_real = tf.reduce_logsumexp(trg_p, axis=1)
        logit_fake = tf.reduce_logsumexp(trg_gen_p, axis=1)
        dis_loss_real = -0.5*tf.reduce_mean(logit_real) + 0.5*tf.reduce_mean(tf.nn.softplus(logit_real))
        dis_loss_fake = 0.5*tf.reduce_mean(tf.nn.softplus(logit_fake))
        loss_trg_usv = dis_loss_real + dis_loss_fake    # UnSuperVised loss function
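        # logsumexp over the K class logits plays the role of the "real"
        # logit of a semi-supervised GAN discriminator: -logit + softplus(logit)
        # pushes real target images toward high values, while
        # softplus(logit_fake) pushes generated images toward low values.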
    else:
        loss_trg_usv = constant(0)

    # Evaluation (EMA)
    ema = tf.train.ExponentialMovingAverage(decay=0.998)
    var_class = tf.get_collection('trainable_variables', 'class/')
    ema_op = ema.apply(var_class)
    ema_p = nn.classifier(T.test_x, enc_phase=1, enc_trim=0, phase=False, reuse=True, getter=tb.tfutils.get_getter(ema))

    # Teacher model (a back-up of EMA model)
    teacher_p = nn.classifier(T.test_x, enc_phase=1, enc_trim=0, phase=False, scope='teacher')
    var_main = tf.get_collection('variables', 'class/(?!.*ExponentialMovingAverage:0)')
    var_teacher = tf.get_collection('variables', 'teacher/(?!.*ExponentialMovingAverage:0)')
    teacher_assign_ops = []
    for t, m in zip(var_teacher, var_main):
        ave = ema.average(m)
        ave = ave if ave is not None else m
        teacher_assign_ops += [tf.assign(t, ave)]
    update_teacher = tf.group(*teacher_assign_ops)
    teacher = tb.function(T.sess, [T.test_x], tf.nn.softmax(teacher_p))

    # Accuracies
    src_acc = basic_accuracy(T.src_y, src_p)
    trg_acc = basic_accuracy(T.trg_y, trg_p)
    ema_acc = basic_accuracy(T.test_y, ema_p)
    fn_ema_acc = tb.function(T.sess, [T.test_x, T.test_y], ema_acc)

    # Optimizer
    dw = constant(args.dw) if args.dirt == 0 else constant(0)
    cw = constant(1)       if args.dirt == 0 else constant(args.bw)
    sw = constant(args.sw) if args.dirt == 0 else constant(0)
    tw = constant(args.tw)
    uw = constant(args.uw) if args.dirt == 0 else constant(0)
    loss_main = (dw * loss_domain +
                 cw * loss_src_class +
                 sw * loss_src_vat +
                 tw * loss_trg_cent +
                 tw * loss_trg_vat +
                 uw * loss_trg_usv)
    var_main = tf.get_collection('trainable_variables', 'class')
    train_main = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_main, var_list=var_main)
    train_main = tf.group(train_main, ema_op)

    # Optimizer for feature discriminator
    if args.dw > 0 and args.dirt == 0:
        var_disc = tf.get_collection('trainable_variables', 'disc_real')
        train_disc = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_disc, var_list=var_disc)
    else:
        train_disc = constant(0)

    # Optimizer for generators
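    # (feature matching is the only generator objective in the VADA phase;
    # during DIRT-T refinement the generator is not updated)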
    if args.dirt == 0:
        fmw = constant(1)
        loss_trg_gen = (fmw * loss_trg_gen_fm)
        var_trg_gen = tf.get_collection('trainable_variables', 'trg_gen')
        trg_gen_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='trg_gen')
        with tf.control_dependencies(trg_gen_update_ops):
            train_trg_gen = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_trg_gen, var_list=var_trg_gen)
        train_gen = train_trg_gen
    else:
        fmw = constant(0)
        train_gen = constant(0)

    # Summarizations
    summary_disc = [tf.summary.scalar('domain/loss_disc', loss_disc),]
    summary_main = [tf.summary.scalar('domain/loss_domain', loss_domain),
                    tf.summary.scalar('class/loss_src_class', loss_src_class),
                    tf.summary.scalar('class/loss_trg_cent', loss_trg_cent),
                    tf.summary.scalar('class/loss_trg_usv', loss_trg_usv),
                    tf.summary.scalar('lipschitz/loss_trg_vat', loss_trg_vat),
                    tf.summary.scalar('lipschitz/loss_src_vat', loss_src_vat),
                    tf.summary.scalar('hyper/dw', dw),
                    tf.summary.scalar('hyper/cw', cw),
                    tf.summary.scalar('hyper/sw', sw),
                    tf.summary.scalar('hyper/tw', tw),
                    tf.summary.scalar('hyper/uw', uw),
                    tf.summary.scalar('hyper/fmw', fmw),
                    tf.summary.scalar('acc/src_acc', src_acc),
                    tf.summary.scalar('acc/trg_acc', trg_acc)]
    summary_gen  = [tf.summary.scalar('gen/loss_trg_gen_fm', loss_trg_gen_fm),
                    tf.summary.image('gen/trg_gen_img', trg_gen_x),]

    # Merge summaries
    summary_disc = tf.summary.merge(summary_disc)
    summary_main = tf.summary.merge(summary_main)
    summary_gen  = tf.summary.merge(summary_gen)

    # Saved ops
    c = tf.constant
    T.ops_print = [c('disc'), loss_disc,
                   c('domain'), loss_domain,
                   c('class'), loss_src_class,
                   c('cent'), loss_trg_cent,
                   c('trg_vat'), loss_trg_vat,
                   c('src_vat'), loss_src_vat,
                   c('src'), src_acc,
                   c('trg'), trg_acc]
    T.ops_disc = [summary_disc, train_disc]
    T.ops_main = [summary_main, train_main]
    T.ops_gen  = [summary_gen , train_gen]
    T.fn_ema_acc = fn_ema_acc
    T.teacher = teacher
    T.update_teacher = update_teacher
    T.trg_gen_x = trg_gen_x
    T.trg_gen_p = trg_gen_p
    T.src_p = src_p
    T.trg_p = trg_p
    T.ema_p = ema_p

    return T
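A minimal usage sketch (not part of the original example), assuming `args` has already been parsed, that `src_gen` and `trg_gen` are hypothetical iterators yielding batches shaped like the placeholders above, that `test_x`/`test_y` are held-out arrays, and that `numpy` is imported as `np`:

# Hypothetical training loop for the TensorDict returned by gada().
# src_gen / trg_gen / test_x / test_y / n_steps are assumed names, not part
# of the original code.
T = gada()
T.sess.run(tf.global_variables_initializer())

for step in range(n_steps):
    sx, sy = next(src_gen)                       # source batch
    tx, ty = next(trg_gen)                       # target batch
    tz = np.random.normal(size=(len(tx), 100))   # generator noise

    T.sess.run(T.ops_disc, {T.src_x: sx, T.trg_x: tx})
    T.sess.run(T.ops_gen,  {T.trg_x: tx, T.trg_z: tz})
    T.sess.run(T.ops_main, {T.src_x: sx, T.src_y: sy,
                            T.trg_x: tx, T.trg_y: ty, T.trg_z: tz})

    if step % 1000 == 0:
        T.sess.run(T.update_teacher)             # sync teacher with EMA
        print(T.fn_ema_acc(test_x, test_y))      # EMA test accuracy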