def classifier():
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    T = tb.utils.TensorDict(dict(
        sess = tf.Session(config=config),
        src_x = placeholder((None, 32, 32, 3), name='source_x'),
        src_y = placeholder((None, 10), name='source_y'),
        test_x = placeholder((None, 32, 32, 3), name='test_x'),
        test_y = placeholder((None, 10), name='test_y'),
        phase = placeholder((), tf.bool, name='phase')
    ))

    # Classification
    src_y = des.classifier(T.src_x, T.phase, internal_update=True)
    loss_class = tf.reduce_mean(softmax_xent(labels=T.src_y, logits=src_y))
    src_acc = basic_accuracy(T.src_y, src_y)

    # Evaluation (non-EMA)
    test_y = des.classifier(T.test_x, phase=False, reuse=True)
    test_acc = basic_accuracy(T.test_y, test_y)
    fn_test_acc = tb.function(T.sess, [T.test_x, T.test_y], test_acc)

    # Evaluation (EMA)
    ema = tf.train.ExponentialMovingAverage(decay=0.998)
    var_class = tf.get_collection('trainable_variables', 'class')
    ema_op = ema.apply(var_class)
    ema_y = des.classifier(T.test_x, phase=False, reuse=True, getter=get_getter(ema))
    ema_acc = basic_accuracy(T.test_y, ema_y)
    fn_ema_acc = tb.function(T.sess, [T.test_x, T.test_y], ema_acc)

    # Optimizer
    loss_main = loss_class
    var_main = var_class
    train_main = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_main, var_list=var_main)
    train_main = tf.group(train_main, ema_op)

    # Summarizations
    summary_main = [
        tf.summary.scalar('class/loss_class', loss_class),
        tf.summary.scalar('acc/src_acc', src_acc),
    ]
    summary_main = tf.summary.merge(summary_main)

    # Saved ops
    c = tf.constant
    T.ops_print = [
        c('class'), loss_class,
        c('src'), src_acc,
    ]
    T.ops_main = [summary_main, train_main]
    T.fn_test_acc = fn_test_acc
    T.fn_ema_acc = fn_ema_acc
    return T
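# basic_accuracy is used throughout this listing but not defined in it. The sketch below
# is an assumption about what it computes (mean argmax agreement between one-hot labels
# and classifier logits); the original helper may differ in detail.
def basic_accuracy(y, y_logit):
    y_true = tf.argmax(y, axis=1)
    y_pred = tf.argmax(y_logit, axis=1)
    return tf.reduce_mean(tf.cast(tf.equal(y_true, y_pred), tf.float32))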
def classifier():
    T = tb.utils.TensorDict(dict(
        sess=tf.Session(config=tb.growth_config()),
        src_x=placeholder((None, 32, 32, 3)),
        src_y=placeholder((None, 10)),
        trg_x=placeholder((None, 32, 32, 3)),
        trg_y=placeholder((None, 10)),
        test_x=placeholder((None, 32, 32, 3)),
        test_y=placeholder((None, 10)),
        phase=placeholder((), tf.bool)))

    # Supervised and conditional entropy minimization
    src_y = net.classifier(T.src_x, phase=True, internal_update=False)
    trg_y = net.classifier(T.trg_x, phase=True, internal_update=True, reuse=True)
    loss_class = tf.reduce_mean(softmax_xent(labels=T.src_y, logits=src_y))

    # Evaluation (non-EMA)
    test_y = net.classifier(T.test_x, phase=False, scope='class', reuse=True)

    # Evaluation (EMA)
    ema = tf.train.ExponentialMovingAverage(decay=0.998)
    ema_op = ema.apply(tf.get_collection('trainable_variables', 'class/'))
    T.ema_y = net.classifier(T.test_x, phase=False, reuse=True, getter=get_getter(ema))

    src_acc = basic_accuracy(T.src_y, src_y)
    trg_acc = basic_accuracy(T.trg_y, trg_y)
    ema_acc = basic_accuracy(T.test_y, T.ema_y)
    fn_ema_acc = tb.function(T.sess, [T.test_x, T.test_y], ema_acc)

    # Optimizer
    loss_main = loss_class
    var_main = tf.get_collection('trainable_variables', 'class')
    train_main = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_main, var_list=var_main)
    train_main = tf.group(train_main, ema_op)

    # Summarizations
    summary_main = [
        tf.summary.scalar('class/loss_class', loss_class),
        tf.summary.scalar('acc/src_acc', src_acc),
        tf.summary.scalar('acc/trg_acc', trg_acc)
    ]
    summary_main = tf.summary.merge(summary_main)

    # Saved ops
    c = tf.constant
    T.ops_print = [c('class'), loss_class]
    T.ops_main = [summary_main, train_main]
    T.fn_ema_acc = fn_ema_acc
    return T
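# get_getter(ema) is referenced above but not defined in this listing. Below is a
# minimal sketch of the usual EMA custom-getter idiom it is assumed to follow (an
# assumption, not necessarily the original implementation): when the classifier graph is
# rebuilt for evaluation with this getter, each variable is swapped for its
# exponential-moving-average shadow if one exists.
def get_getter(ema):
    def ema_getter(getter, name, *args, **kwargs):
        var = getter(name, *args, **kwargs)
        ema_var = ema.average(var)
        return ema_var if ema_var is not None else var
    return ema_getter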
def __init__(self, domain, datasets, M):
    print "Constructing pseudodata"
    sys.stdout.flush()
    cast = domain not in {'mnist28', 'mnist32', 'mnistm28', 'mnistm32'}
    print "Casting:", cast
    labeler = tb.function(M.sess, [M.test_x], M.back_y)

    self.train = Data(datasets.train.images, labeler=labeler, cast=cast)
    self.test = Data(datasets.test.images, labeler=labeler, cast=cast)
def dann_embed():
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # config.gpu_options.per_process_gpu_memory_fraction = 0.45
    T = tb.utils.TensorDict(dict(
        sess = tf.Session(config=config),
        src_x = placeholder((None, H, W, 3)),
        src_y = placeholder((None, Y)),
        trg_x = placeholder((None, H, W, 3)),
        trg_y = placeholder((None, Y)),
        fake_z = placeholder((None, 100)),
        fake_y = placeholder((None, Y)),
        test_x = placeholder((None, H, W, 3)),
        test_y = placeholder((None, Y)),
        phase = placeholder((), tf.bool)
    ))

    # Schedules
    start, end = args.pivot - 1, args.pivot
    global_step = tf.Variable(0., trainable=False)

    # Ramp down dw
    ramp_dw = conditional_ramp_weight(args.dwdn, global_step, args.dw, 0, start, end)

    # Ramp down src
    ramp_class = conditional_ramp_weight(args.dn, global_step, 1, args.dcval, start, end)
    ramp_sbw = conditional_ramp_weight(args.dn, global_step, args.sbw, 0, start, end)

    # Ramp up trg (never more than src)
    ramp_cw = conditional_ramp_weight(args.up, global_step, args.cw, args.uval, start, end)
    ramp_gw = conditional_ramp_weight(args.up, global_step, args.gw, args.uval, start, end)
    ramp_gbw = conditional_ramp_weight(args.up, global_step, args.gbw, args.uval, start, end)
    ramp_tbw = conditional_ramp_weight(args.up, global_step, args.tbw, args.uval, start, end)

    # Supervised and conditional entropy minimization
    src_e = des.classifier(T.src_x, T.phase, enc_phase=1, trim=args.trim, scope='class', internal_update=False)
    trg_e = des.classifier(T.trg_x, T.phase, enc_phase=1, trim=args.trim, scope='class', reuse=True, internal_update=True)
    src_y = des.classifier(src_e, T.phase, enc_phase=0, trim=args.trim, scope='class', internal_update=False)
    trg_y = des.classifier(trg_e, T.phase, enc_phase=0, trim=args.trim, scope='class', reuse=True, internal_update=True)
    loss_class = tf.reduce_mean(softmax_xent(labels=T.src_y, logits=src_y))
    loss_cent = tf.reduce_mean(softmax_xent_two(labels=trg_y, logits=trg_y))

    # Image generation
    if args.gw > 0:
        fake_x = des.generator(T.fake_z, T.fake_y, T.phase)
        fake_logit = des.discriminator(fake_x, T.phase)
        real_logit = des.discriminator(T.trg_x, T.phase, reuse=True)
        fake_e = des.classifier(fake_x, T.phase, enc_phase=1, trim=args.trim, scope='class', reuse=True)
        fake_y = des.classifier(fake_e, T.phase, enc_phase=0, trim=args.trim, scope='class', reuse=True)

        loss_gdisc = 0.5 * tf.reduce_mean(
            sigmoid_xent(labels=tf.ones_like(real_logit), logits=real_logit) +
            sigmoid_xent(labels=tf.zeros_like(fake_logit), logits=fake_logit))
        loss_gen = tf.reduce_mean(sigmoid_xent(labels=tf.ones_like(fake_logit), logits=fake_logit))
        loss_info = tf.reduce_mean(softmax_xent(labels=T.fake_y, logits=fake_y))
    else:
        loss_gdisc = constant(0)
        loss_gen = constant(0)
        loss_info = constant(0)

    # Domain confusion
    if args.dw > 0 and args.phase == 0:
        real_logit = des.feature_discriminator(src_e, T.phase)
        fake_logit = des.feature_discriminator(trg_e, T.phase, reuse=True)

        loss_ddisc = 0.5 * tf.reduce_mean(
            sigmoid_xent(labels=tf.ones_like(real_logit), logits=real_logit) +
            sigmoid_xent(labels=tf.zeros_like(fake_logit), logits=fake_logit))
        loss_domain = 0.5 * tf.reduce_mean(
            sigmoid_xent(labels=tf.zeros_like(real_logit), logits=real_logit) +
            sigmoid_xent(labels=tf.ones_like(fake_logit), logits=fake_logit))
    else:
        loss_ddisc = constant(0)
        loss_domain = constant(0)

    # Smoothing
    loss_t_ball = constant(0) if args.tbw == 0 else smoothing_loss(T.trg_x, trg_y, T.phase)
    loss_s_ball = constant(0) if args.sbw == 0 or args.phase == 1 else smoothing_loss(T.src_x, src_y, T.phase)
    loss_g_ball = constant(0) if args.gbw == 0 else smoothing_loss(fake_x, fake_y, T.phase)
    loss_t_emb = constant(0) if args.te == 0 else smoothing_loss(T.trg_x, trg_e, T.phase, is_embedding=True)
    loss_s_emb = constant(0) if args.se == 0 else smoothing_loss(T.src_x, src_e, T.phase, is_embedding=True)

    # Evaluation (non-EMA)
    test_y = des.classifier(T.test_x, False, enc_phase=1, trim=0, scope='class', reuse=True)

    # Evaluation (EMA)
    ema = tf.train.ExponentialMovingAverage(decay=0.998)
    var_class = tf.get_collection('trainable_variables', 'class/')
    ema_op = ema.apply(var_class)
    T.ema_e = des.classifier(T.test_x, False, enc_phase=1, trim=args.trim, scope='class', reuse=True, getter=get_getter(ema))
    ema_y = des.classifier(T.ema_e, False, enc_phase=0, trim=args.trim, scope='class', reuse=True, getter=get_getter(ema))

    # Back-up (teacher) model
    back_y = des.classifier(T.test_x, False, enc_phase=1, trim=0, scope='back')
    var_main = tf.get_collection('variables', 'class/(?!.*ExponentialMovingAverage:0)')
    var_back = tf.get_collection('variables', 'back/(?!.*ExponentialMovingAverage:0)')
    back_assigns = []
    init_assigns = []
    for b, m in zip(var_back, var_main):
        ave = ema.average(m)
        target = ave if ave else m
        back_assigns += [tf.assign(b, target)]
        init_assigns += [tf.assign(m, target)]
        # print "Assign {} -> {}, {}".format(target.name, b.name, m.name)
    back_update = tf.group(*back_assigns)
    init_update = tf.group(*init_assigns)

    src_acc = basic_accuracy(T.src_y, src_y)
    trg_acc = basic_accuracy(T.trg_y, trg_y)
    test_acc = basic_accuracy(T.test_y, test_y)
    ema_acc = basic_accuracy(T.test_y, ema_y)
    fn_test_acc = tb.function(T.sess, [T.test_x, T.test_y], test_acc)
    fn_ema_acc = tb.function(T.sess, [T.test_x, T.test_y], ema_acc)

    # Optimizer
    loss_main = (ramp_class * loss_class +
                 ramp_dw * loss_domain +
                 ramp_cw * loss_cent +
                 ramp_tbw * loss_t_ball +
                 ramp_gbw * loss_g_ball +
                 ramp_sbw * loss_s_ball +
                 args.te * loss_t_emb +
                 args.se * loss_s_emb +
                 ramp_gw * loss_gen +
                 ramp_gw * loss_info)
    var_main = tf.get_collection('trainable_variables', 'class')
    var_main += tf.get_collection('trainable_variables', 'gen')
    train_main = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_main, var_list=var_main, global_step=global_step)
    train_main = tf.group(train_main, ema_op)

    if (args.dw > 0 and args.phase == 0) or args.gw > 0:
        loss_disc = loss_ddisc + loss_gdisc
        var_disc = tf.get_collection('trainable_variables', 'disc')
        train_disc = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_disc, var_list=var_disc)
    else:
        train_disc = constant(0)

    # Summarizations
    # embedding = tf.Variable(tf.zeros([1000,12800]),name='embedding')
    # embedding = tf.reshape(trg_e[:1000], [-1,12800])
    summary_disc = [tf.summary.scalar('domain/loss_ddisc', loss_ddisc),
                    tf.summary.scalar('gen/loss_gdisc', loss_gdisc)]
    summary_main = [tf.summary.scalar('class/loss_class', loss_class),
                    tf.summary.scalar('class/loss_cent', loss_cent),
                    tf.summary.scalar('domain/loss_domain', loss_domain),
                    tf.summary.scalar('lipschitz/loss_t_ball', loss_t_ball),
                    tf.summary.scalar('lipschitz/loss_g_ball', loss_g_ball),
                    tf.summary.scalar('lipschitz/loss_s_ball', loss_s_ball),
                    tf.summary.scalar('embedding/loss_t_emb', loss_t_emb),
                    tf.summary.scalar('embedding/loss_s_emb', loss_s_emb),
                    tf.summary.scalar('gen/loss_gen', loss_gen),
                    tf.summary.scalar('gen/loss_info', loss_info),
                    tf.summary.scalar('ramp/ramp_class', ramp_class),
                    tf.summary.scalar('ramp/ramp_dw', ramp_dw),
                    tf.summary.scalar('ramp/ramp_cw', ramp_cw),
                    tf.summary.scalar('ramp/ramp_gw', ramp_gw),
                    tf.summary.scalar('ramp/ramp_tbw', ramp_tbw),
                    tf.summary.scalar('ramp/ramp_sbw', ramp_sbw),
                    tf.summary.scalar('ramp/ramp_gbw', ramp_gbw),
                    tf.summary.scalar('acc/src_acc', src_acc),
                    tf.summary.scalar('acc/trg_acc', trg_acc)]
    summary_disc = tf.summary.merge(summary_disc)
    summary_main = tf.summary.merge(summary_main)

    # Saved ops
    c = tf.constant
    T.ops_print = [c('ddisc'), loss_ddisc,
                   c('domain'), loss_domain,
                   c('gdisc'), loss_gdisc,
                   c('gen'), loss_gen,
                   c('info'), loss_info,
                   c('class'), loss_class,
                   c('cent'), loss_cent,
                   c('t_ball'), loss_t_ball,
                   c('g_ball'), loss_g_ball,
                   c('s_ball'), loss_s_ball,
                   c('t_emb'), loss_t_emb,
                   c('s_emb'), loss_s_emb,
                   c('src'), src_acc,
                   c('trg'), trg_acc]
    T.ops_disc = [summary_disc, train_disc]
    T.ops_main = [summary_main, train_main]
    T.fn_test_acc = fn_test_acc
    T.fn_ema_acc = fn_ema_acc
    T.back_y = tf.nn.softmax(back_y)  # Access to backed-up eval model softmax
    T.back_update = back_update       # Update op eval -> backed-up eval model
    T.init_update = init_update       # Update op eval -> student eval model
    T.global_step = global_step
    T.ramp_class = ramp_class

    if args.gw > 0:
        summary_image = tf.summary.image('image/gen', generate_img())
        T.ops_image = summary_image

    return T
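# conditional_ramp_weight is not defined in this listing. The sketch below is an
# assumption inferred from the call sites in dann_embed() above: when `enabled` is
# truthy, the weight interpolates linearly from `init` to `final` as the float global
# step moves from `start` to `end`; otherwise it stays fixed at `init`. The original
# helper may use a different schedule.
def conditional_ramp_weight(enabled, step, init, final, start, end):
    if not enabled:
        return tf.constant(float(init))
    frac = tf.clip_by_value((step - start) / float(end - start), 0., 1.)
    return init + frac * (final - init)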
def vae():
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    T = tb.utils.TensorDict(dict(
        sess = tf.Session(config=config),
        src_x = placeholder((None, 32, 32, 3), name='source_x'),
        src_y = placeholder((None, 10), name='source_y'),
        trg_x = placeholder((None, 32, 32, 3), name='target_x'),
        trg_y = placeholder((None, 10), name='target_y'),
        test_x = placeholder((None, 32, 32, 3), name='test_x'),
        test_y = placeholder((None, 10), name='test_y'),
        fake_z = placeholder((None, 100), name='fake_z'),
        fake_y = placeholder((None, 10), name='fake_y'),
        tau = placeholder((), name='tau'),
        phase = placeholder((), tf.bool, name='phase'),
    ))

    if args.gw > 0:
        # Variational inference
        y_logit = des.classifier(T.trg_x, T.phase, internal_update=True)
        y = gumbel_softmax(y_logit, T.tau)
        z, z_post = des.encoder(T.trg_x, y, T.phase, internal_update=True)

        # Generation
        x = des.generator(z, y, T.phase, internal_update=True)

        # Loss
        z_prior = (0., 1.)
        kl_z = tf.reduce_mean(log_normal(z, *z_post) - log_normal(z, *z_prior))
        y_q = tf.nn.softmax(y_logit)
        log_y_q = tf.nn.log_softmax(y_logit)
        kl_y = tf.reduce_mean(tf.reduce_sum(y_q * (log_y_q - tf.log(0.1)), axis=1))
        loss_kl = kl_z + kl_y
        loss_rec = args.rw * tf.reduce_mean(tf.reduce_sum(tf.square(T.trg_x - x), axis=[1, 2, 3]))
        loss_gen = loss_rec + loss_kl
        trg_acc = basic_accuracy(T.trg_y, y_logit)
    else:
        loss_kl = constant(0)
        loss_rec = constant(0)
        loss_gen = constant(0)
        trg_acc = constant(0)

    # Posterior regularization (labeled classification)
    src_y = des.classifier(T.src_x, T.phase, reuse=True)
    loss_class = tf.reduce_mean(softmax_xent(labels=T.src_y, logits=src_y))
    src_acc = basic_accuracy(T.src_y, src_y)

    # Evaluation (classification)
    test_y = des.classifier(T.test_x, phase=False, reuse=True)
    test_acc = basic_accuracy(T.test_y, test_y)
    fn_test_acc = tb.function(T.sess, [T.test_x, T.test_y], test_acc)

    # Evaluation (generation)
    if args.gw > 0:
        fake_x = des.generator(T.fake_z, T.fake_y, phase=False, reuse=True)
        fn_fake_x = tb.function(T.sess, [T.fake_z, T.fake_y], fake_x)

    # Optimizer
    var_main = tf.get_collection('trainable_variables', 'gen/')
    var_main += tf.get_collection('trainable_variables', 'enc/')
    var_main += tf.get_collection('trainable_variables', 'class/')
    loss_main = args.gw * loss_gen + loss_class
    train_main = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_main, var_list=var_main)

    # Summarizations
    summary_main = [
        tf.summary.scalar('gen/loss_gen', loss_gen),
        tf.summary.scalar('gen/loss_rec', loss_rec),
        tf.summary.scalar('gen/loss_kl', loss_kl),
        tf.summary.scalar('class/loss_class', loss_class),
        tf.summary.scalar('acc/src_acc', src_acc),
        tf.summary.scalar('acc/trg_acc', trg_acc),
    ]
    summary_main = tf.summary.merge(summary_main)
    if args.gw > 0:
        summary_image = tf.summary.image('image/gen', generate_img())

    # Saved ops
    c = tf.constant
    T.ops_print = [
        c('tau'), tf.identity(T.tau),
        c('gen'), loss_gen,
        c('rec'), loss_rec,
        c('kl'), loss_kl,
        c('class'), loss_class,
    ]
    T.ops_main = [summary_main, train_main]
    T.fn_test_acc = fn_test_acc
    if args.gw > 0:
        T.fn_fake_x = fn_fake_x
        T.ops_image = summary_image

    return T
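# gumbel_softmax is referenced in vae() above but not defined in this listing. Below is
# a minimal sketch of the standard Gumbel-Softmax (Concrete) relaxation it is assumed to
# implement: add Gumbel noise to the class logits and take a temperature-scaled softmax,
# so the sample y stays differentiable and approaches a one-hot vector as tau -> 0.
def gumbel_softmax(logits, tau, eps=1e-20):
    u = tf.random_uniform(tf.shape(logits), minval=0., maxval=1.)
    gumbel = -tf.log(-tf.log(u + eps) + eps)
    return tf.nn.softmax((logits + gumbel) / tau)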
def dirtt():
    T = tb.utils.TensorDict(dict(
        sess = tf.Session(config=tb.growth_config()),
        src_x = placeholder((None, 32, 32, 3)),
        src_y = placeholder((None, args.Y)),
        trg_x = placeholder((None, 32, 32, 3)),
        trg_y = placeholder((None, args.Y)),
        test_x = placeholder((None, 32, 32, 3)),
        test_y = placeholder((None, args.Y)),
    ))

    # Supervised and conditional entropy minimization
    src_e = nn.classifier(T.src_x, phase=True, enc_phase=1, trim=args.trim)
    trg_e = nn.classifier(T.trg_x, phase=True, enc_phase=1, trim=args.trim, reuse=True, internal_update=True)
    src_p = nn.classifier(src_e, phase=True, enc_phase=0, trim=args.trim)
    trg_p = nn.classifier(trg_e, phase=True, enc_phase=0, trim=args.trim, reuse=True, internal_update=True)

    loss_src_class = tf.reduce_mean(softmax_xent(labels=T.src_y, logits=src_p))
    loss_trg_cent = tf.reduce_mean(softmax_xent_two(labels=trg_p, logits=trg_p))

    # Domain confusion
    if args.dw > 0 and args.dirt == 0:
        real_logit = nn.feature_discriminator(src_e, phase=True)
        fake_logit = nn.feature_discriminator(trg_e, phase=True, reuse=True)

        loss_disc = 0.5 * tf.reduce_mean(
            sigmoid_xent(labels=tf.ones_like(real_logit), logits=real_logit) +
            sigmoid_xent(labels=tf.zeros_like(fake_logit), logits=fake_logit))
        loss_domain = 0.5 * tf.reduce_mean(
            sigmoid_xent(labels=tf.zeros_like(real_logit), logits=real_logit) +
            sigmoid_xent(labels=tf.ones_like(fake_logit), logits=fake_logit))
    else:
        loss_disc = constant(0)
        loss_domain = constant(0)

    # Virtual adversarial training (turn off src in non-VADA phase)
    loss_src_vat = vat_loss(T.src_x, src_p, nn.classifier) if args.sw > 0 and args.dirt == 0 else constant(0)
    loss_trg_vat = vat_loss(T.trg_x, trg_p, nn.classifier) if args.tw > 0 else constant(0)

    # Evaluation (EMA)
    ema = tf.train.ExponentialMovingAverage(decay=0.998)
    var_class = tf.get_collection('trainable_variables', 'class/')
    ema_op = ema.apply(var_class)
    ema_p = nn.classifier(T.test_x, phase=False, reuse=True, getter=tb.tfutils.get_getter(ema))

    # Teacher model (a back-up of EMA model)
    teacher_p = nn.classifier(T.test_x, phase=False, scope='teacher')
    var_main = tf.get_collection('variables', 'class/(?!.*ExponentialMovingAverage:0)')
    var_teacher = tf.get_collection('variables', 'teacher/(?!.*ExponentialMovingAverage:0)')
    teacher_assign_ops = []
    for t, m in zip(var_teacher, var_main):
        ave = ema.average(m)
        ave = ave if ave else m
        teacher_assign_ops += [tf.assign(t, ave)]
    update_teacher = tf.group(*teacher_assign_ops)
    teacher = tb.function(T.sess, [T.test_x], tf.nn.softmax(teacher_p))

    # Accuracies
    src_acc = basic_accuracy(T.src_y, src_p)
    trg_acc = basic_accuracy(T.trg_y, trg_p)
    ema_acc = basic_accuracy(T.test_y, ema_p)
    fn_ema_acc = tb.function(T.sess, [T.test_x, T.test_y], ema_acc)

    # Optimizer
    dw = constant(args.dw) if args.dirt == 0 else constant(0)
    cw = constant(1) if args.dirt == 0 else constant(args.bw)
    sw = constant(args.sw) if args.dirt == 0 else constant(0)
    tw = constant(args.tw)
    loss_main = (dw * loss_domain +
                 cw * loss_src_class +
                 sw * loss_src_vat +
                 tw * loss_trg_cent +
                 tw * loss_trg_vat)
    var_main = tf.get_collection('trainable_variables', 'class')
    train_main = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_main, var_list=var_main)
    train_main = tf.group(train_main, ema_op)

    if args.dw > 0 and args.dirt == 0:
        var_disc = tf.get_collection('trainable_variables', 'disc')
        train_disc = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_disc, var_list=var_disc)
    else:
        train_disc = constant(0)

    # Summarizations
    summary_disc = [tf.summary.scalar('domain/loss_disc', loss_disc),]
    summary_main = [tf.summary.scalar('domain/loss_domain', loss_domain),
                    tf.summary.scalar('class/loss_src_class', loss_src_class),
                    tf.summary.scalar('class/loss_trg_cent', loss_trg_cent),
                    tf.summary.scalar('lipschitz/loss_trg_vat', loss_trg_vat),
                    tf.summary.scalar('lipschitz/loss_src_vat', loss_src_vat),
                    tf.summary.scalar('hyper/dw', dw),
                    tf.summary.scalar('hyper/cw', cw),
                    tf.summary.scalar('hyper/sw', sw),
                    tf.summary.scalar('hyper/tw', tw),
                    tf.summary.scalar('acc/src_acc', src_acc),
                    tf.summary.scalar('acc/trg_acc', trg_acc)]

    # Merge summaries
    summary_disc = tf.summary.merge(summary_disc)
    summary_main = tf.summary.merge(summary_main)

    # Saved ops
    c = tf.constant
    T.ops_print = [c('disc'), loss_disc,
                   c('domain'), loss_domain,
                   c('class'), loss_src_class,
                   c('cent'), loss_trg_cent,
                   c('trg_vat'), loss_trg_vat,
                   c('src_vat'), loss_src_vat,
                   c('src'), src_acc,
                   c('trg'), trg_acc]
    T.ops_disc = [summary_disc, train_disc]
    T.ops_main = [summary_main, train_main]
    T.fn_ema_acc = fn_ema_acc
    T.teacher = teacher
    T.update_teacher = update_teacher
    return T
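# softmax_xent_two is used above but not defined in this listing. A minimal sketch of
# what it is assumed to compute is given below: cross entropy between two soft
# distributions, both supplied as logits. Passing the same logits for labels and logits
# (as in loss_trg_cent above) then yields the entropy of the target predictions, i.e.
# conditional entropy minimization.
def softmax_xent_two(labels, logits):
    return -tf.reduce_sum(tf.nn.softmax(labels) * tf.nn.log_softmax(logits), axis=1)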
def vae():
    T = tb.utils.TensorDict(dict(
        sess=tf.Session(config=tb.growth_config()),
        trg_x=placeholder((None, 32, 32, 3), name='target_x'),
        fake_z=placeholder((None, args.Z), name='fake_z'),
    ))

    # Inference
    z, z_post = nn.encoder(T.trg_x, phase=True, internal_update=True)

    # Generation
    x = nn.generator(z, phase=True, internal_update=True)

    # Loss
    loss_rec, loss_kl, loss_gen = vae_loss(x, T.trg_x, z, z_post)

    # Evaluation (embedding, reconstruction, loss)
    test_z, test_z_post = nn.encoder(T.trg_x, phase=False, reuse=True)
    test_x = nn.generator(test_z, phase=False, reuse=True)
    _, _, test_loss = vae_loss(test_x, T.trg_x, test_z, test_z_post)
    fn_embed = tb.function(T.sess, [T.trg_x], test_z_post)
    fn_recon = tb.function(T.sess, [T.trg_x], test_x)
    fn_loss = tb.function(T.sess, [T.trg_x], test_loss)

    # Evaluation (generation)
    fake_x = nn.generator(T.fake_z, phase=False, reuse=True)
    fn_generate = tb.function(T.sess, [T.fake_z], fake_x)

    # Optimizer
    var_main = tf.get_collection('trainable_variables', 'gen/')
    var_main += tf.get_collection('trainable_variables', 'enc/')
    loss_main = loss_gen
    train_main = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_main, var_list=var_main)

    # Summarizations
    summary_main = [
        tf.summary.scalar('gen/loss_gen', loss_gen),
        tf.summary.scalar('gen/loss_kl', loss_kl),
        tf.summary.scalar('gen/loss_rec', loss_rec),
    ]
    summary_image = [tf.summary.image('gen/gen', generate_image(nn.generator))]

    # Merge summaries
    summary_main = tf.summary.merge(summary_main)
    summary_image = tf.summary.merge(summary_image)

    # Saved ops
    c = tf.constant
    T.ops_print = [
        c('gen'), loss_gen,
        c('kl'), loss_kl,
        c('rec'), loss_rec,
    ]
    T.ops_main = [summary_main, train_main]
    T.ops_image = summary_image
    T.fn_embed = fn_embed
    T.fn_recon = fn_recon
    T.fn_loss = fn_loss
    T.fn_generate = fn_generate
    return T
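# vae_loss is not defined in this listing. The sketch below is an assumption that
# mirrors the loss terms written out explicitly in the earlier vae() graph: a
# squared-error reconstruction term plus a single-sample Monte Carlo estimate of the KL
# divergence between the Gaussian posterior (z_post holds its parameters) and a standard
# normal prior, with log_normal the same density helper used there.
def vae_loss(x_rec, x, z, z_post):
    loss_rec = tf.reduce_mean(tf.reduce_sum(tf.square(x - x_rec), axis=[1, 2, 3]))
    loss_kl = tf.reduce_mean(log_normal(z, *z_post) - log_normal(z, 0., 1.))
    loss_gen = loss_rec + loss_kl
    return loss_rec, loss_kl, loss_gen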
def _build_model(self):
    self.x_src_lst = []
    self.y_src_lst = []
    for i in range(self.data_loader.num_src_domain):
        x_src = tf.placeholder(dtype=tf.float32, shape=tuple([None]) + self.dim_src,
                               name='x_src_{}_input'.format(i))
        y_src = tf.placeholder(dtype=tf.float32, shape=(None, self.num_classes),
                               name='y_src_{}_input'.format(i))
        self.x_src_lst.append(x_src)
        self.y_src_lst.append(y_src)

    self.x_trg = tf.placeholder(dtype=tf.float32, shape=tuple([None]) + self.dim_trg,
                                name='x_trg_input')
    self.y_trg = tf.placeholder(dtype=tf.float32, shape=(None, self.num_classes),
                                name='y_trg_input')
    self.y_src_domain = tf.placeholder(dtype=tf.float32, shape=(None, self.data_loader.num_src_domain),
                                       name='y_src_domain_input')

    T = tb.utils.TensorDict(dict(
        x_tmp=tf.placeholder(dtype=tf.float32, shape=tuple([None]) + self.dim_src),
        y_tmp=tf.placeholder(dtype=tf.float32, shape=(None, self.num_classes))
    ))

    self.is_training = tf.placeholder(tf.bool, shape=(), name='is_training')

    self.x_src_mid_lst = []
    for i in range(self.data_loader.num_src_domain):
        x_src_mid = self._build_source_middle(self.x_src_lst[i], is_reused=i)
        self.x_src_mid_lst.append(x_src_mid)
    self.x_trg_mid = self._build_target_middle(self.x_trg, reuse=True)

    # <editor-fold desc="Classifier-logits">
    self.y_src_logit_lst = []
    for i in range(self.data_loader.num_src_domain):
        y_src_logit = self._build_class_src_discriminator(self.x_src_mid_lst[i], self.num_classes, i)
        self.y_src_logit_lst.append(y_src_logit)
    self.y_trg_logit = self._build_class_trg_discriminator(self.x_trg_mid, self.num_classes)
    # </editor-fold>

    # <editor-fold desc="Classification">
    self.src_loss_class_lst = []
    self.src_loss_class_sum = tf.constant(0.0)
    for i in range(self.data_loader.num_src_domain):
        src_loss_class_detail = tf.nn.softmax_cross_entropy_with_logits_v2(
            logits=self.y_src_logit_lst[i], labels=self.y_src_lst[i])
        src_loss_class = tf.reduce_mean(src_loss_class_detail)
        self.src_loss_class_lst.append(self.src_domain_trade_off[i] * src_loss_class)
        self.src_loss_class_sum += self.src_domain_trade_off[i] * src_loss_class

    self.trg_loss_class_detail = tf.nn.softmax_cross_entropy_with_logits_v2(
        logits=self.y_trg_logit, labels=self.y_trg)
    self.trg_loss_class = tf.reduce_mean(self.trg_loss_class_detail)
    # </editor-fold>

    # <editor-fold desc="Source domain discriminator">
    self.x_src_mid_all = tf.concat(self.x_src_mid_lst, axis=0)
    self.y_src_discriminator_logit = self._build_domain_discriminator(self.x_src_mid_all)
    self.src_loss_discriminator_detail = tf.nn.softmax_cross_entropy_with_logits_v2(
        logits=self.y_src_discriminator_logit, labels=self.y_src_domain)
    self.src_loss_discriminator = tf.reduce_mean(self.src_loss_discriminator_detail)
    # </editor-fold>

    # <editor-fold desc="Compute teacher hS(xS)">
    self.y_src_teacher_all = []
    for i, bs in zip(range(self.data_loader.num_src_domain),
                     range(0, self.batch_size_src * self.data_loader.num_src_domain, self.batch_size_src)):
        y_src_logit_each_h_lst = []
        for j in range(self.data_loader.num_src_domain):
            y_src_logit_each_h = self._build_class_src_discriminator(
                self.x_src_mid_lst[i], self.num_classes, j, reuse=True)
            y_src_logit_each_h_lst.append(y_src_logit_each_h)
        y_src_logit_each_h_lst = tf.nn.softmax(tf.convert_to_tensor(y_src_logit_each_h_lst))
        y_src_discriminator_prob = tf.nn.softmax(
            tf.gather(self.y_src_discriminator_logit,
                      tf.range(bs, bs + self.batch_size_src, dtype=tf.int32), axis=0))
        y_src_teacher = self._compute_teacher_hs(y_src_logit_each_h_lst, y_src_discriminator_prob)
        self.y_src_teacher_all.append(y_src_teacher)
    self.y_src_teacher_all = tf.concat(self.y_src_teacher_all, axis=0)
    # </editor-fold>

    # <editor-fold desc="Compute teacher hS(xt)">
    y_trg_logit_each_h_lst = []
    for j in range(self.data_loader.num_src_domain):
        y_trg_logit_each_h = self._build_class_src_discriminator(self.x_trg_mid, self.num_classes, j, reuse=True)
        y_trg_logit_each_h_lst.append(y_trg_logit_each_h)
    y_trg_logit_each_h_lst = tf.nn.softmax(tf.convert_to_tensor(y_trg_logit_each_h_lst))
    self.y_trg_src_domains_logit = self._build_domain_discriminator(self.x_trg_mid, reuse=True)
    y_trg_discriminator_prob = tf.nn.softmax(self.y_trg_src_domains_logit)
    self.y_trg_teacher = self._compute_teacher_hs(y_trg_logit_each_h_lst, y_trg_discriminator_prob)
    # </editor-fold>

    # <editor-fold desc="Compute pseudo-label loss">
    self.ht_g_xs = build_class_discriminator_template(
        self.x_src_mid_all, training_phase=self.is_training, scope='c-trg',
        num_classes=self.num_classes, reuse=True, internal_update=True,
        class_discriminator_layout=self.classify_layout, cnn_size=self.cnn_size
    )
    self.mimic_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
        logits=self.ht_g_xs, labels=self.y_src_teacher_all)) + \
        tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
            logits=self.y_trg_logit, labels=self.y_trg_teacher))
    # </editor-fold>

    # <editor-fold desc="Compute WS loss">
    self.data_shift_loss = self._compute_cosine_similarity(self.x_trg_mid, self.x_src_mid_all)
    self.label_shift_loss = self._compute_label_shift_loss(self.y_trg_logit, self.ht_g_xs)
    self.data_label_shift_loss = self.data_shift_troff * self.data_shift_loss + \
        self.lbl_shift_troff * self.label_shift_loss
    self.g_network = tf.reshape(self._build_phi_network(self.x_trg_mid), [-1])
    self.exp_term = (- self.data_label_shift_loss + self.g_network) / self.theta
    self.g_network_loss = tf.reduce_mean(self.g_network)
    self.OT_loss = tf.reduce_mean(
        - self.theta * (tf.log(1.0 / self.batch_size) + tf.reduce_logsumexp(self.exp_term, axis=1))
    ) + self.g_network_trade_off * self.g_network_loss
    # </editor-fold>

    # <editor-fold desc="Compute VAT loss">
    self.trg_loss_vat = self._build_vat_loss(
        self.x_trg, self.y_trg_logit, self.num_classes,
        scope_encode=self._get_scope('generator', 'trg'), scope_classify='c-trg'
    )
    # </editor-fold>

    # <editor-fold desc="Compute conditional entropy loss w.r.t target distribution">
    self.trg_loss_cond_entropy = tf.reduce_mean(
        softmax_x_entropy_two(labels=self.y_trg_logit, logits=self.y_trg_logit))
    # </editor-fold>

    # <editor-fold desc="Accuracy">
    self.src_accuracy_lst = []
    for i in range(self.data_loader.num_src_domain):
        y_src_pred = tf.argmax(self.y_src_logit_lst[i], 1, output_type=tf.int32)
        y_src_sparse = tf.argmax(self.y_src_lst[i], 1, output_type=tf.int32)
        src_accuracy = tf.reduce_mean(tf.cast(tf.equal(y_src_sparse, y_src_pred), 'float32'))
        self.src_accuracy_lst.append(src_accuracy)

    # compute acc for target domain
    self.y_trg_pred = tf.argmax(self.y_trg_logit, 1, output_type=tf.int32)
    self.y_trg_sparse = tf.argmax(self.y_trg, 1, output_type=tf.int32)
    self.trg_accuracy = tf.reduce_mean(tf.cast(tf.equal(self.y_trg_sparse, self.y_trg_pred), 'float32'))

    # compute acc for src domain disc
    self.y_src_domain_pred = tf.argmax(self.y_src_discriminator_logit, 1, output_type=tf.int32)
    self.y_src_domain_sparse = tf.argmax(self.y_src_domain, 1, output_type=tf.int32)
    self.src_domain_acc = tf.reduce_mean(
        tf.cast(tf.equal(self.y_src_domain_sparse, self.y_src_domain_pred), 'float32'))
    # </editor-fold>

    # <editor-fold desc="Put it all together">
    lst_phase1_losses = [
        (self.src_class_trade_off, self.src_loss_class_sum),
        (self.domain_trade_off, self.src_loss_discriminator),
    ]
    self.phase1_loss = tf.constant(0.0)
    for trade_off, loss in lst_phase1_losses:
        # if trade_off != 0:
        self.phase1_loss += trade_off * loss

    lst_phase2_losses = [
        (self.src_class_trade_off, self.src_loss_class_sum),
        (self.ot_trade_off, self.OT_loss),
        (self.domain_trade_off, self.src_loss_discriminator),
        (self.trg_vat_troff, self.trg_loss_vat),
        (self.trg_ent_troff, self.trg_loss_cond_entropy),
        (self.mimic_trade_off, self.mimic_loss)
    ]
    self.phase2_loss = tf.constant(0.0)
    for trade_off, loss in lst_phase2_losses:
        # if trade_off != 0:
        self.phase2_loss += trade_off * loss
    # </editor-fold>

    # <editor-fold desc="Evaluation">
    primary_student_variables = self._get_variables(self._get_student_primary_scopes())
    ema = tf.train.ExponentialMovingAverage(decay=0.998)
    var_list_for_ema = primary_student_variables[0] + primary_student_variables[1]
    ema_op = ema.apply(var_list=var_list_for_ema)
    self.ema_p = self._build_classifier(T.x_tmp, self.num_classes, ema)
    self.batch_ema_acc = batch_ema_acc(T.y_tmp, self.ema_p)
    self.fn_batch_ema_acc = tb.function(self.tf_session, [T.x_tmp, T.y_tmp], self.batch_ema_acc)
    # </editor-fold>

    teacher_variables = self._get_variables(self._get_teacher_scopes())
    self.train_teacher = \
        tf.train.AdamOptimizer(self.learning_rate, 0.5).minimize(self.phase1_loss,
                                                                 var_list=teacher_variables)
    self.train_student_main = \
        tf.train.AdamOptimizer(self.learning_rate, 0.5).minimize(self.phase2_loss,
                                                                 var_list=teacher_variables + [primary_student_variables[1]])
    self.primary_train_student_op = tf.group(self.train_student_main, ema_op)

    # <editor-fold desc="Construct secondary loss">
    secondary_variables = self._get_variables(self._get_student_secondary_scopes())
    self.secondary_train_student_op = \
        tf.train.AdamOptimizer(self.learning_rate, 0.5).minimize(-self.OT_loss,
                                                                 var_list=secondary_variables)
    # </editor-fold>

    # <editor-fold desc="Summaries">
    tf.summary.scalar('loss/phase1_loss', self.phase1_loss)
    tf.summary.scalar('loss/phase2_loss', self.phase2_loss)
    tf.summary.scalar('loss/W_distance', self.OT_loss)
    tf.summary.scalar('loss/src_loss_discriminator', self.src_loss_discriminator)
    for i in range(self.data_loader.num_src_domain):
        tf.summary.scalar('loss/src_loss_class_{}'.format(i), self.src_loss_class_lst[i])
        tf.summary.scalar('acc/src_acc_{}'.format(i), self.src_accuracy_lst[i])
    tf.summary.scalar('acc/src_domain_acc', self.src_domain_acc)
    tf.summary.scalar('acc/trg_acc', self.trg_accuracy)
    tf.summary.scalar('trg_loss_class', self.trg_loss_class)
    tf.summary.scalar('hyperparameters/learning_rate', self.learning_rate)
    tf.summary.scalar('hyperparameters/src_class_trade_off', self.src_class_trade_off)
    tf.summary.scalar('hyperparameters/g_network_trade_off', self.g_network_trade_off)
    tf.summary.scalar('hyperparameters/domain_trade_off', self.domain_trade_off)
    tf.summary.scalar('hyperparameters/trg_vat_troff', self.trg_vat_troff)
    tf.summary.scalar('hyperparameters/trg_ent_troff', self.trg_ent_troff)
    self.tf_merged_summaries = tf.summary.merge_all()
g_train = tf.train.AdamOptimizer(lr, 0.5).minimize(g_loss, var_list=g_var)

# Logger
base_dir = 'results/gamma={:.1f}_run={:d}'.format(args.gamma, args.run)
writer = tb.FileWriter(os.path.join(base_dir, 'log.csv'), args=args, overwrite=args.run >= 999)
writer.add_var('d_real', '{:8.4f}', d_real_loss)
writer.add_var('d_fake', '{:8.4f}', d_fake_loss)
writer.add_var('k', '{:8.4f}', k * 1)
writer.add_var('M', '{:8.4f}', m_global)
writer.add_var('lr', '{:8.6f}', lr * 1)
writer.add_var('iter', '{:>8d}')
writer.initialize()

sess = tf.Session()
load_model(sess)
f_gen = tb.function(sess, [z], x_fake)
f_rec = tb.function(sess, [x_real], d_real)
celeba = CelebA(args.data)

# Alternatively try grouping d_train/g_train together
all_tensors = [d_train, g_train, d_real_loss, d_fake_loss]
# d_tensors = [d_train, d_real_loss]
# g_tensors = [g_train, d_fake_loss]

for i in xrange(args.max_iter):
    x = celeba.next_batch(args.bs)
    z = np.random.uniform(-1, 1, (args.bs, args.e_size))
    feed_dict = {'x:0': x, 'z:0': z, 'k:0': args.k, 'lr:0': args.lr, 'g:0': args.gamma}
    _, _, d_real_loss, d_fake_loss = sess.run(all_tensors, feed_dict)
    # _, d_real_loss = sess.run(d_tensors, feed_dict)
    # _, d_fake_loss = sess.run(g_tensors, feed_dict)
def dirtt():
    T = tb.utils.TensorDict(dict(
        sess = tf.Session(config=tb.growth_config()),
        src_x = placeholder((None, 500, 60, 1)),
        src_y = placeholder((None, args.Y)),
        trg_x = placeholder((None, 500, 60, 1)),
        trg_y = placeholder((None, args.Y)),
        test_x = placeholder((None, 500, 60, 1)),
        test_y = placeholder((None, args.Y)),
    ))

    # Supervised and conditional entropy minimization
    src_e = nn.classifier(T.src_x, phase=True, enc_phase=1, trim=args.trim)
    trg_e = nn.classifier(T.trg_x, phase=True, enc_phase=1, trim=args.trim, reuse=True, internal_update=True)
    src_p = nn.classifier(src_e, phase=True, enc_phase=0, trim=args.trim)
    trg_p = nn.classifier(trg_e, phase=True, enc_phase=0, trim=args.trim, reuse=True, internal_update=True)

    loss_src_class = tf.reduce_mean(softmax_xent(labels=T.src_y, logits=src_p))
    loss_trg_cent = tf.reduce_mean(softmax_xent_two(labels=trg_p, logits=trg_p))

    # Domain confusion
    if args.dw > 0 and args.dirt == 0:
        real_logit = nn.feature_discriminator(src_e, phase=True)
        fake_logit = nn.feature_discriminator(trg_e, phase=True, reuse=True)

        loss_disc = 0.5 * tf.reduce_mean(
            sigmoid_xent(labels=tf.ones_like(real_logit), logits=real_logit) +
            sigmoid_xent(labels=tf.zeros_like(fake_logit), logits=fake_logit))
        loss_domain = 0.5 * tf.reduce_mean(
            sigmoid_xent(labels=tf.zeros_like(real_logit), logits=real_logit) +
            sigmoid_xent(labels=tf.ones_like(fake_logit), logits=fake_logit))
    else:
        loss_disc = constant(0)
        loss_domain = constant(0)

    # Virtual adversarial training (turn off src in non-VADA phase)
    loss_src_vat = vat_loss(T.src_x, src_p, nn.classifier) if args.sw > 0 and args.dirt == 0 else constant(0)
    loss_trg_vat = vat_loss(T.trg_x, trg_p, nn.classifier) if args.tw > 0 else constant(0)

    # Evaluation (EMA)
    ema = tf.train.ExponentialMovingAverage(decay=0.998)
    var_class = tf.get_collection('trainable_variables', 'class/')
    ema_op = ema.apply(var_class)
    ema_p = nn.classifier(T.test_x, phase=False, reuse=True, getter=tb.tfutils.get_getter(ema))

    # Teacher model (a back-up of EMA model)
    teacher_p = nn.classifier(T.test_x, phase=False, scope='teacher')
    var_main = tf.get_collection('variables', 'class/(?!.*ExponentialMovingAverage:0)')
    var_teacher = tf.get_collection('variables', 'teacher/(?!.*ExponentialMovingAverage:0)')
    teacher_assign_ops = []
    for t, m in zip(var_teacher, var_main):
        ave = ema.average(m)
        ave = ave if ave else m
        teacher_assign_ops += [tf.assign(t, ave)]
    update_teacher = tf.group(*teacher_assign_ops)
    teacher = tb.function(T.sess, [T.test_x], tf.nn.softmax(teacher_p))

    # Accuracies
    src_acc = basic_accuracy(T.src_y, src_p)
    trg_acc = basic_accuracy(T.trg_y, trg_p)
    ema_acc = basic_accuracy(T.test_y, ema_p)
    fn_ema_acc = tb.function(T.sess, [T.test_x, T.test_y], ema_acc)

    # Optimizer
    dw = constant(args.dw) if args.dirt == 0 else constant(0)
    cw = constant(1) if args.dirt == 0 else constant(args.bw)
    sw = constant(args.sw) if args.dirt == 0 else constant(0)
    tw = constant(args.tw)
    loss_main = (dw * loss_domain +
                 cw * loss_src_class +
                 sw * loss_src_vat +
                 tw * loss_trg_cent +
                 tw * loss_trg_vat)
    var_main = tf.get_collection('trainable_variables', 'class')
    train_main = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_main, var_list=var_main)
    train_main = tf.group(train_main, ema_op)

    if args.dw > 0 and args.dirt == 0:
        var_disc = tf.get_collection('trainable_variables', 'disc')
        train_disc = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_disc, var_list=var_disc)
    else:
        train_disc = constant(0)

    # Summarizations
    summary_disc = [tf.summary.scalar('domain/loss_disc', loss_disc),]
    summary_main = [tf.summary.scalar('domain/loss_domain', loss_domain),
                    tf.summary.scalar('class/loss_src_class', loss_src_class),
                    tf.summary.scalar('class/loss_trg_cent', loss_trg_cent),
                    tf.summary.scalar('lipschitz/loss_trg_vat', loss_trg_vat),
                    tf.summary.scalar('lipschitz/loss_src_vat', loss_src_vat),
                    tf.summary.scalar('hyper/dw', dw),
                    tf.summary.scalar('hyper/cw', cw),
                    tf.summary.scalar('hyper/sw', sw),
                    tf.summary.scalar('hyper/tw', tw),
                    tf.summary.scalar('acc/src_acc', src_acc),
                    tf.summary.scalar('acc/trg_acc', trg_acc)]

    # Merge summaries
    summary_disc = tf.summary.merge(summary_disc)
    summary_main = tf.summary.merge(summary_main)

    # Saved ops
    c = tf.constant
    T.ops_print = [c('disc'), loss_disc,
                   c('domain'), loss_domain,
                   c('class'), loss_src_class,
                   c('cent'), loss_trg_cent,
                   c('trg_vat'), loss_trg_vat,
                   c('src_vat'), loss_src_vat,
                   c('src'), src_acc,
                   c('trg'), trg_acc]
    T.ops_disc = [summary_disc, train_disc]
    T.ops_main = [summary_main, train_main]
    T.fn_ema_acc = fn_ema_acc
    T.teacher = teacher
    T.update_teacher = update_teacher
    return T
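# vat_loss is not defined in this listing. The sketch below is a hypothetical outline of
# the usual virtual adversarial training loss it is assumed to compute: estimate the
# perturbation direction that most changes the classifier output (one power-iteration
# step), then penalize the divergence between predictions on x and on x + r_adv. The
# radius/xi values and the classifier call signature are assumptions, not the original
# hyperparameters.
def vat_loss(x, p_logit, classifier, radius=3.5, xi=1e-6):
    def kl_from_logits(p_logit, q_logit):
        p = tf.nn.softmax(p_logit)
        return tf.reduce_sum(p * (tf.nn.log_softmax(p_logit) - tf.nn.log_softmax(q_logit)), axis=1)

    # Power iteration: random direction -> gradient of the divergence w.r.t. the perturbation.
    d = tf.random_normal(shape=tf.shape(x))
    d = xi * tf.nn.l2_normalize(d, axis=list(range(1, len(x.get_shape()))))
    p_d_logit = classifier(x + d, phase=True, reuse=True)
    kl = tf.reduce_mean(kl_from_logits(tf.stop_gradient(p_logit), p_d_logit))
    grad = tf.stop_gradient(tf.gradients(kl, [d])[0])
    r_adv = radius * tf.nn.l2_normalize(grad, axis=list(range(1, len(x.get_shape()))))

    # Penalize divergence at the adversarially perturbed point.
    p_adv_logit = classifier(x + r_adv, phase=True, reuse=True)
    return tf.reduce_mean(kl_from_logits(tf.stop_gradient(p_logit), p_adv_logit))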
def gada():
    T = tb.utils.TensorDict(dict(
        sess = tf.Session(config=tb.growth_config()),
        src_x = placeholder((None, 32, 32, 3)),
        src_y = placeholder((None, args.Y)),
        trg_x = placeholder((None, 32, 32, 3)),
        trg_y = placeholder((None, args.Y)),
        trg_z = placeholder((None, 100)),
        test_x = placeholder((None, 32, 32, 3)),
        test_y = placeholder((None, args.Y)),
    ))

    # Supervised and conditional entropy minimization
    src_e = nn.classifier(T.src_x, phase=True, enc_phase=1, enc_trim=args.etrim)
    src_g = nn.classifier(src_e, phase=True, gen_trim=args.gtrim, gen_phase=1, enc_trim=args.etrim)
    src_p = nn.classifier(src_g, phase=True, gen_trim=args.gtrim)
    trg_e = nn.classifier(T.trg_x, phase=True, enc_phase=1, enc_trim=args.etrim, reuse=True, internal_update=True)
    trg_g = nn.classifier(trg_e, phase=True, gen_trim=args.gtrim, gen_phase=1, enc_trim=args.etrim, reuse=True, internal_update=True)
    trg_p = nn.classifier(trg_g, phase=True, gen_trim=args.gtrim, reuse=True, internal_update=True)

    loss_src_class = tf.reduce_mean(softmax_xent(labels=T.src_y, logits=src_p))
    loss_trg_cent = tf.reduce_mean(softmax_xent_two(labels=trg_p, logits=trg_p)) if args.tw > 0 else constant(0)

    # Domain confusion
    if args.dw > 0 and args.dirt == 0:
        real_logit = nn.real_feature_discriminator(src_e, phase=True)
        fake_logit = nn.real_feature_discriminator(trg_e, phase=True, reuse=True)

        loss_disc = 0.5 * tf.reduce_mean(
            sigmoid_xent(labels=tf.ones_like(real_logit), logits=real_logit) +
            sigmoid_xent(labels=tf.zeros_like(fake_logit), logits=fake_logit))
        loss_domain = 0.5 * tf.reduce_mean(
            sigmoid_xent(labels=tf.zeros_like(real_logit), logits=real_logit) +
            sigmoid_xent(labels=tf.ones_like(fake_logit), logits=fake_logit))
    else:
        loss_disc = constant(0)
        loss_domain = constant(0)

    # Virtual adversarial training (turn off src in non-VADA phase)
    loss_src_vat = vat_loss(T.src_x, src_p, nn.classifier) if args.sw > 0 and args.dirt == 0 else constant(0)
    loss_trg_vat = vat_loss(T.trg_x, trg_p, nn.classifier) if args.tw > 0 else constant(0)

    # Generate images and process generated images
    trg_gen_x = nn.trg_generator(T.trg_z)
    trg_gen_e = nn.classifier(trg_gen_x, phase=True, enc_phase=1, enc_trim=args.etrim, reuse=True, internal_update=True)
    trg_gen_g = nn.classifier(trg_gen_e, phase=True, gen_trim=args.gtrim, gen_phase=1, enc_trim=args.etrim, reuse=True, internal_update=True)
    trg_gen_p = nn.classifier(trg_gen_g, phase=True, gen_trim=args.gtrim, reuse=True, internal_update=True)

    # Feature matching loss function for generator
    loss_trg_gen_fm = tf.reduce_mean(tf.square(
        tf.reduce_mean(trg_g, axis=0) - tf.reduce_mean(trg_gen_g, axis=0))) if args.dirt == 0 else constant(0)

    # Unsupervised loss function
    if args.dirt == 0:
        logit_real = tf.reduce_logsumexp(trg_p, axis=1)
        logit_fake = tf.reduce_logsumexp(trg_gen_p, axis=1)
        dis_loss_real = -0.5 * tf.reduce_mean(logit_real) + 0.5 * tf.reduce_mean(tf.nn.softplus(logit_real))
        dis_loss_fake = 0.5 * tf.reduce_mean(tf.nn.softplus(logit_fake))
        loss_trg_usv = dis_loss_real + dis_loss_fake  # UnSuperVised loss function
    else:
        loss_trg_usv = constant(0)

    # Evaluation (EMA)
    ema = tf.train.ExponentialMovingAverage(decay=0.998)
    var_class = tf.get_collection('trainable_variables', 'class/')
    ema_op = ema.apply(var_class)
    ema_p = nn.classifier(T.test_x, enc_phase=1, enc_trim=0, phase=False, reuse=True, getter=tb.tfutils.get_getter(ema))

    # Teacher model (a back-up of EMA model)
    teacher_p = nn.classifier(T.test_x, enc_phase=1, enc_trim=0, phase=False, scope='teacher')
    var_main = tf.get_collection('variables', 'class/(?!.*ExponentialMovingAverage:0)')
    var_teacher = tf.get_collection('variables', 'teacher/(?!.*ExponentialMovingAverage:0)')
    teacher_assign_ops = []
    for t, m in zip(var_teacher, var_main):
        ave = ema.average(m)
        ave = ave if ave else m
        teacher_assign_ops += [tf.assign(t, ave)]
    update_teacher = tf.group(*teacher_assign_ops)
    teacher = tb.function(T.sess, [T.test_x], tf.nn.softmax(teacher_p))

    # Accuracies
    src_acc = basic_accuracy(T.src_y, src_p)
    trg_acc = basic_accuracy(T.trg_y, trg_p)
    ema_acc = basic_accuracy(T.test_y, ema_p)
    fn_ema_acc = tb.function(T.sess, [T.test_x, T.test_y], ema_acc)

    # Optimizer
    dw = constant(args.dw) if args.dirt == 0 else constant(0)
    cw = constant(1) if args.dirt == 0 else constant(args.bw)
    sw = constant(args.sw) if args.dirt == 0 else constant(0)
    tw = constant(args.tw)
    uw = constant(args.uw) if args.dirt == 0 else constant(0)
    loss_main = (dw * loss_domain +
                 cw * loss_src_class +
                 sw * loss_src_vat +
                 tw * loss_trg_cent +
                 tw * loss_trg_vat +
                 uw * loss_trg_usv)
    var_main = tf.get_collection('trainable_variables', 'class')
    train_main = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_main, var_list=var_main)
    train_main = tf.group(train_main, ema_op)

    # Optimizer for feature discriminator
    if args.dw > 0 and args.dirt == 0:
        var_disc = tf.get_collection('trainable_variables', 'disc_real')
        train_disc = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_disc, var_list=var_disc)
    else:
        train_disc = constant(0)

    # Optimizer for generators
    if args.dirt == 0:
        fmw = constant(1)
        loss_trg_gen = (fmw * loss_trg_gen_fm)
        var_trg_gen = tf.get_collection('trainable_variables', 'trg_gen')
        trg_gen_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='trg_gen')
        with tf.control_dependencies(trg_gen_update_ops):
            train_trg_gen = tf.train.AdamOptimizer(args.lr, 0.5).minimize(loss_trg_gen, var_list=var_trg_gen)
        train_gen = train_trg_gen
    else:
        fmw = constant(0)
        train_gen = constant(0)

    # Summarizations
    summary_disc = [tf.summary.scalar('domain/loss_disc', loss_disc),]
    summary_main = [tf.summary.scalar('domain/loss_domain', loss_domain),
                    tf.summary.scalar('class/loss_src_class', loss_src_class),
                    tf.summary.scalar('class/loss_trg_cent', loss_trg_cent),
                    tf.summary.scalar('class/loss_trg_usv', loss_trg_usv),
                    tf.summary.scalar('lipschitz/loss_trg_vat', loss_trg_vat),
                    tf.summary.scalar('lipschitz/loss_src_vat', loss_src_vat),
                    tf.summary.scalar('hyper/dw', dw),
                    tf.summary.scalar('hyper/cw', cw),
                    tf.summary.scalar('hyper/sw', sw),
                    tf.summary.scalar('hyper/tw', tw),
                    tf.summary.scalar('hyper/uw', uw),
                    tf.summary.scalar('hyper/fmw', fmw),
                    tf.summary.scalar('acc/src_acc', src_acc),
                    tf.summary.scalar('acc/trg_acc', trg_acc)]
    summary_gen = [tf.summary.scalar('gen/loss_trg_gen_fm', loss_trg_gen_fm),
                   tf.summary.image('gen/trg_gen_img', trg_gen_x),]

    # Merge summaries
    summary_disc = tf.summary.merge(summary_disc)
    summary_main = tf.summary.merge(summary_main)
    summary_gen = tf.summary.merge(summary_gen)

    # Saved ops
    c = tf.constant
    T.ops_print = [c('disc'), loss_disc,
                   c('domain'), loss_domain,
                   c('class'), loss_src_class,
                   c('cent'), loss_trg_cent,
                   c('trg_vat'), loss_trg_vat,
                   c('src_vat'), loss_src_vat,
                   c('src'), src_acc,
                   c('trg'), trg_acc]
    T.ops_disc = [summary_disc, train_disc]
    T.ops_main = [summary_main, train_main]
    T.ops_gen = [summary_gen, train_gen]
    T.fn_ema_acc = fn_ema_acc
    T.teacher = teacher
    T.update_teacher = update_teacher
    T.trg_gen_x = trg_gen_x
    T.trg_gen_p = trg_gen_p
    T.src_p = src_p
    T.trg_p = trg_p
    T.ema_p = ema_p
    return T