def G_train_graph():
    # ======================================
    # =                graph               =
    # ======================================

    # placeholders & inputs
    lr = tf.placeholder(dtype=tf.float32, shape=[])

    xa, a = train_iter.get_next()       # input images and their binary attribute labels
    b = tf.random_shuffle(a)            # target attributes: shuffle the labels within the batch
    a_ = a * 2 - 1                      # rescale labels from {0, 1} to {-1, 1}
    b_ = b * 2 - 1

    # generate
    z = Genc(xa)                        # encode the input image
    xa_ = Gdec(z, a_)                   # decode with the original attributes (reconstruction)
    xb_ = Gdec(z, b_)                   # decode with the target attributes (editing)

    # discriminate
    xb__logit_gan, xb__logit_att = D(xb_)

    # generator losses
    xb__loss_gan = g_loss_fn(xb__logit_gan)
    xb__loss_att = tf.losses.sigmoid_cross_entropy(b, xb__logit_att)
    xa__loss_rec = tf.losses.absolute_difference(xa, xa_)
    reg_loss = tf.reduce_sum(Genc.func.reg_losses + Gdec.func.reg_losses)

    loss = (xb__loss_gan +
            xb__loss_att * args.g_attribute_loss_weight +
            xa__loss_rec * args.g_reconstruction_loss_weight +
            reg_loss)

    # optim
    step_cnt, _ = tl.counter()
    step = tf.train.AdamOptimizer(lr, beta1=args.beta_1).minimize(
        loss, global_step=step_cnt,
        var_list=Genc.func.trainable_variables + Gdec.func.trainable_variables)

    # summary
    with tf.contrib.summary.create_file_writer('./output/%s/summaries/G' % args.experiment_name).as_default(),\
            tf.contrib.summary.record_summaries_every_n_global_steps(10, global_step=step_cnt):
        summary = tl.summary_v2({
            'xb__loss_gan': xb__loss_gan,
            'xb__loss_att': xb__loss_att,
            'xa__loss_rec': xa__loss_rec,
            'reg_loss': reg_loss
        }, step=step_cnt, name='G')

    # ======================================
    # =            generator size          =
    # ======================================

    n_params, n_bytes = tl.count_parameters(Genc.func.variables + Gdec.func.variables)
    print('Generator Size: n_parameters = %d = %.2fMB' % (n_params, n_bytes / 1024 / 1024))

    # ======================================
    # =            run function            =
    # ======================================

    def run(**pl_ipts):
        sess.run([step, summary], feed_dict={lr: pl_ipts['lr']})

    return run
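

# Hypothetical, self-contained NumPy illustration of how the attribute inputs
# above are prepared: the target labels `b` are a within-batch shuffle of the
# source labels `a`, and both are rescaled from {0, 1} to {-1, 1} before being
# fed to the decoder. The arrays below are made-up example values and are not
# part of the training code.
import numpy as np

a = np.array([[0, 1, 0],                 # source attribute labels, one row per image
              [1, 0, 1]], dtype=np.float32)
b = a[np.random.permutation(len(a))]     # target labels: rows of `a` in shuffled order

a_ = a * 2 - 1                           # {0, 1} -> {-1, 1}, mirroring `a * 2 - 1` above
b_ = b * 2 - 1
print(a_)                                # e.g. [[-1.  1. -1.]
                                         #       [ 1. -1.  1.]]
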
def G_train_graph():
    # ======================================
    # =                graph               =
    # ======================================

    # placeholders & inputs
    lr = tf.placeholder(dtype=tf.float32, shape=[])

    xa, a = train_iter.get_next()         # input images and their binary attribute labels
    b = tf.random_shuffle(a)              # target attributes: shuffle the labels within the batch
    a_ = a * 2 - 1                        # rescale labels from {0, 1} to {-1, 1}
    b_ = b * 2 - 1

    # generate
    xb, _, ms, ms_multi = G(xa, b_ - a_)  # edit xa according to the attribute difference; ms/ms_multi are the generator's masks

    # discriminate
    xb_logit_gan, xb_logit_att = D(xb)

    # generator losses
    xb_loss_gan = g_loss_fn(xb_logit_gan)
    xb_loss_att = tf.losses.sigmoid_cross_entropy(b, xb_logit_att)
    spasity_loss = tf.reduce_sum([        # weighted mean activation of each mask
        tf.reduce_mean(m) * w
        for m, w in zip(ms, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
    ])
    full_overlap_mask_pair_loss, non_overlap_mask_pair_loss = module.overlap_loss_fn(
        ms_multi, args.att_names)
    reg_loss = tf.reduce_sum(G.func.reg_losses)

    loss = (xb_loss_gan +
            xb_loss_att * args.g_attribute_loss_weight +
            spasity_loss * args.g_spasity_loss_weight +
            full_overlap_mask_pair_loss * args.g_full_overlap_mask_pair_loss_weight +
            non_overlap_mask_pair_loss * args.g_non_overlap_mask_pair_loss_weight +
            reg_loss)

    # optim
    step_cnt, _ = tl.counter()
    step = tf.train.AdamOptimizer(lr, beta1=args.beta_1).minimize(
        loss, global_step=step_cnt, var_list=G.func.trainable_variables)

    # summary
    with tf.contrib.summary.create_file_writer('./output/%s/summaries/G' % args.experiment_name).as_default(),\
            tf.contrib.summary.record_summaries_every_n_global_steps(10, global_step=step_cnt):
        summary = tl.summary_v2({
            'xb_loss_gan': xb_loss_gan,
            'xb_loss_att': xb_loss_att,
            'spasity_loss': spasity_loss,
            'full_overlap_mask_pair_loss': full_overlap_mask_pair_loss,
            'non_overlap_mask_pair_loss': non_overlap_mask_pair_loss,
            'reg_loss': reg_loss
        }, step=step_cnt, name='G')

    # ======================================
    # =            generator size          =
    # ======================================

    n_params, n_bytes = tl.count_parameters(G.func.variables)
    print('Generator Size: n_parameters = %d = %.2fMB' % (n_params, n_bytes / 1024 / 1024))

    # ======================================
    # =            run function            =
    # ======================================

    def run(**pl_ipts):
        sess.run([step, summary], feed_dict={lr: pl_ipts['lr']})

    return run
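

# Hypothetical NumPy illustration of the generator input `b_ - a_` used above:
# after the {0, 1} -> {-1, 1} rescaling, each entry of the difference is +2 where
# an attribute is switched on, -2 where it is switched off, and 0 where it is
# left unchanged. The values are made up for illustration only.
import numpy as np

a_ = np.array([[-1., 1., -1.]])          # source attributes in {-1, 1}
b_ = np.array([[1., 1., 1.]])            # target attributes in {-1, 1}
print(b_ - a_)                           # [[2. 0. 2.]] -> attributes 0 and 2 are edited
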
def G_train_graph():
    # ======================================
    # =                graph               =
    # ======================================

    # placeholders & inputs
    lr = tf.placeholder(dtype=tf.float32, shape=[])

    zs = [tf.random.normal([args.batch_size, z_dim]) for z_dim in args.z_dims]  # one latent code per entry of args.z_dims
    eps = tf.random.normal([args.batch_size, args.eps_dim])                     # extra noise input

    # counter
    step_cnt, _ = tl.counter()

    # optimizer
    optimizer = tf.train.AdamOptimizer(lr, beta1=args.beta_1)

    def graph_per_gpu(zs, eps):
        # per-GPU forward pass, losses, and gradients

        # generate
        x_f = G(zs, eps)

        # discriminate
        x_f_logit = D(x_f)

        # loss
        x_f_loss = g_loss_fn(x_f_logit)
        orth_loss = tf.reduce_sum(tl.tensors_filter(G.func.reg_losses, 'orthogonal_regularizer'))
        reg_loss = tf.reduce_sum(tl.tensors_filter(G.func.reg_losses, 'l2_regularizer'))

        loss = (x_f_loss * args.g_loss_weight_x_gan +
                orth_loss * args.g_loss_weight_orth_loss +
                reg_loss * args.weight_decay)

        # optim
        grads = optimizer.compute_gradients(loss, var_list=G.func.trainable_variables)

        return grads, x_f_loss, orth_loss, reg_loss

    # build one tower per GPU, then merge the per-GPU results
    split_grads, split_x_f_loss, split_orth_loss, split_reg_loss = zip(
        *tl.parellel_run(tl.gpus(), graph_per_gpu, tl.split_nest((zs, eps), len(tl.gpus()))))
    # split_grads, split_x_f_loss, split_orth_loss, split_reg_loss = zip(*tl.parellel_run(['cpu:0'], graph_per_gpu, tl.split_nest((zs, eps), 1)))
    grads = tl.average_gradients(split_grads)
    x_f_loss, orth_loss, reg_loss = [tf.reduce_mean(t) for t in [split_x_f_loss, split_orth_loss, split_reg_loss]]
    step = optimizer.apply_gradients(grads, global_step=step_cnt)

    # moving average: update the EMA copy of the generator weights after each optimizer step
    with tf.control_dependencies([step]):
        step = G_ema.apply(G.func.trainable_variables)

    # summary
    summary_dict = {
        'x_f_loss': x_f_loss,
        'orth_loss': orth_loss,
        'reg_loss': reg_loss
    }
    summary_dict.update({
        'L_%d' % i: t
        for i, t in enumerate(tl.tensors_filter(G.func.variables, 'L'))
    })
    summary_loss = tl.create_summary_statistic_v2(
        summary_dict,
        './output/%s/summaries/G' % args.experiment_name,
        step=step_cnt, n_steps_per_record=10, name='G_loss')
    summary_image = tl.create_summary_image_v2(
        {
            'orth_U_%d' % i: t[None, :, :, None]
            for i, t in enumerate(tf.get_collection('orth', G.func.scope + '/'))
        },
        './output/%s/summaries/G' % args.experiment_name,
        step=step_cnt, n_steps_per_record=10, name='G_image')

    # ======================================
    # =              model size            =
    # ======================================

    n_params, n_bytes = tl.count_parameters(G.func.variables)
    print('Model Size: n_parameters = %d = %.2fMB' % (n_params, n_bytes / 1024 / 1024))

    # ======================================
    # =            run function            =
    # ======================================

    def run(**pl_ipts):
        sess.run([step, summary_loss, summary_image], feed_dict={lr: pl_ipts['lr']})

    return run
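

# `tl.average_gradients` is a helper from the `tl` library; its actual
# implementation is not shown here. The sketch below is only an assumption of
# what such a helper typically does in TF-1.x multi-GPU training (the standard
# pattern from the classic multi-GPU CIFAR-10 example): average each variable's
# gradient element-wise across the per-GPU (gradient, variable) lists.
import tensorflow as tf

def average_gradients_sketch(split_grads):
    """split_grads: one [(grad, var), ...] list per GPU, all in the same variable order."""
    averaged = []
    for grads_and_var in zip(*split_grads):              # group the same variable across GPUs
        grads = [g for g, _ in grads_and_var if g is not None]
        var = grads_and_var[0][1]
        avg = tf.reduce_mean(tf.stack(grads, axis=0), axis=0) if grads else None
        averaged.append((avg, var))                      # one averaged (grad, var) per variable
    return averaged
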