def eval():
    Gab = models.get_G()
    Gba = models.get_G()
    Gab.eval()
    Gba.eval()
    Gab.load_weights(flags.model_dir + '/Gab.h5')
    Gba.load_weights(flags.model_dir + '/Gba.h5')

    for i, (x, _) in enumerate(
            tl.iterate.minibatches(inputs=im_test_A, targets=im_test_A, batch_size=5, shuffle=False)):
        o = Gab(x)
        tl.vis.save_images(x, [1, 5], flags.sample_dir + '/eval_{}_a.png'.format(i))
        tl.vis.save_images(o.numpy(), [1, 5], flags.sample_dir + '/eval_{}_a2b.png'.format(i))

    for i, (x, _) in enumerate(
            tl.iterate.minibatches(inputs=im_test_B, targets=im_test_B, batch_size=5, shuffle=False)):
        o = Gba(x)
        tl.vis.save_images(x, [1, 5], flags.sample_dir + '/eval_{}_b.png'.format(i))
        tl.vis.save_images(o.numpy(), [1, 5], flags.sample_dir + '/eval_{}_b2a.png'.format(i))
def disentangle_test():
    print('Start disentangle_test!')
    ds, _ = get_dataset_train()
    E = get_E([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
    E.load_weights('{}/{}/E.npz'.format(flags.checkpoint_dir, flags.param_dir), format='npz')
    E.eval()
    G = get_G([None, flags.z_dim])
    G.load_weights('{}/{}/G.npz'.format(flags.checkpoint_dir, flags.param_dir), format='npz')
    G.eval()
    for step, batch_img in enumerate(ds):
        if step > flags.disentangle_step_num:
            break
        z_real = E(batch_img)
        # binarize the code: map z in [0, 1] to {0, 1} via the sign of (2z - 1)
        hash_real = ((tf.sign(z_real * 2 - 1) + 1) / 2).numpy()
        epsilon = flags.scale * np.random.normal(
            loc=0.0,
            scale=flags.sigma * math.sqrt(flags.z_dim) * 0.0625,
            size=[flags.batch_size_train, flags.z_dim]).astype(np.float32)
        # z_fake = hash_real + epsilon
        z_fake = z_real + epsilon
        fake_imgs = G(z_fake)
        tl.visualize.save_images(
            batch_img.numpy(), [8, 8],
            '{}/{}/disentangle/real_{:02d}.png'.format(flags.test_dir, flags.param_dir, step))
        tl.visualize.save_images(
            fake_imgs.numpy(), [8, 8],
            '{}/{}/disentangle/fake_{:02d}.png'.format(flags.test_dir, flags.param_dir, step))
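# Standalone illustration (not from the original script) of the binarization
# used in disentangle_test() above: (sign(2z - 1) + 1) / 2 thresholds a
# [0, 1]-valued code at 0.5. Note the boundary case: tf.sign(0) == 0, so an
# entry exactly at 0.5 maps to 0.5 rather than to a clean bit.
def _binarize_demo():
    z = tf.constant([[0.1, 0.5, 0.9]])
    h = (tf.sign(z * 2 - 1) + 1) / 2
    print(h.numpy())  # [[0.  0.5 1. ]]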
def train(parallel, kungfu_option):
    Gab = models.get_G(name='Gab')
    Gba = models.get_G(name='Gba')
    Da = models.get_D(name='Da')
    Db = models.get_D(name='Db')
    Gab.train()
    Gba.train()
    Da.train()
    Db.train()

    lr_v = tf.Variable(flags.lr_init)
    # optimizer_Gab_Db = tf.optimizers.Adam(lr_v, beta_1=flags.beta_1)
    # optimizer_Gba_Da = tf.optimizers.Adam(lr_v, beta_1=flags.beta_1)
    # optimizer_G = tf.optimizers.Adam(lr_v, beta_1=flags.beta_1)
    # optimizer_D = tf.optimizers.Adam(lr_v, beta_1=flags.beta_1)
    optimizer = tf.optimizers.Adam(lr_v, beta_1=flags.beta_1)  # use only one optimizer; if your GPU memory is large, use separate ones
    use_ident = False  # set True to add the identity-mapping loss

    # KungFu: wrap the optimizer
    if parallel:
        from kungfu.tensorflow.optimizers import SynchronousSGDOptimizer, SynchronousAveragingOptimizer, PairAveragingOptimizer
        if kungfu_option == 'sync-sgd':
            opt_fn = SynchronousSGDOptimizer
        elif kungfu_option == 'async-sgd':
            opt_fn = PairAveragingOptimizer
        elif kungfu_option == 'sma':
            opt_fn = SynchronousAveragingOptimizer
        else:
            raise RuntimeError('Unknown distributed training optimizer.')
        # fixed: the original wrapped optimizer_Gab_Db / optimizer_Gba_Da,
        # which are commented out above and never created
        optimizer = opt_fn(optimizer)

    # Gab.load_weights(flags.model_dir + '/Gab.h5')  # restore params?
    # Gba.load_weights(flags.model_dir + '/Gba.h5')
    # Da.load_weights(flags.model_dir + '/Da.h5')
    # Db.load_weights(flags.model_dir + '/Db.h5')

    # KungFu: shard the data
    if parallel:
        from kungfu import current_cluster_size, current_rank
        data_A_shard = []
        data_B_shard = []
        for step, (image_A, image_B) in enumerate(zip(data_A, data_B)):
            if step % current_cluster_size() == current_rank():
                data_A_shard.append(image_A)
                data_B_shard.append(image_B)
    else:
        data_A_shard = data_A
        data_B_shard = data_B

    @tf.function
    def train_step(image_A, image_B):
        fake_B = Gab(image_A)
        fake_A = Gba(image_B)
        cycle_A = Gba(fake_B)
        cycle_B = Gab(fake_A)
        if use_ident:
            iden_A = Gba(image_A)
            iden_B = Gab(image_B)
        logits_fake_B = Db(fake_B)  # TODO: missing image buffer (pool)
        logits_real_B = Db(image_B)
        logits_fake_A = Da(fake_A)
        logits_real_A = Da(image_A)

        # LSGAN discriminator losses
        # loss_Da = (tl.cost.mean_squared_error(logits_real_A, tf.ones_like(logits_real_A), is_mean=True) +
        #            tl.cost.mean_squared_error(logits_fake_A, tf.ones_like(logits_fake_A), is_mean=True)) / 2.
        loss_Da = tf.reduce_mean(tf.math.squared_difference(logits_fake_A, tf.zeros_like(logits_fake_A))) + \
                  tf.reduce_mean(tf.math.squared_difference(logits_real_A, tf.ones_like(logits_real_A)))
        # loss_Da = tl.cost.sigmoid_cross_entropy(logits_fake_A, tf.zeros_like(logits_fake_A)) + \
        #           tl.cost.sigmoid_cross_entropy(logits_real_A, tf.ones_like(logits_real_A))
        # loss_Db = (tl.cost.mean_squared_error(logits_real_B, tf.ones_like(logits_real_B), is_mean=True) +
        #            tl.cost.mean_squared_error(logits_fake_B, tf.ones_like(logits_fake_B), is_mean=True)) / 2.
        loss_Db = tf.reduce_mean(tf.math.squared_difference(logits_fake_B, tf.zeros_like(logits_fake_B))) + \
                  tf.reduce_mean(tf.math.squared_difference(logits_real_B, tf.ones_like(logits_real_B)))
        # loss_Db = tl.cost.sigmoid_cross_entropy(logits_fake_B, tf.zeros_like(logits_fake_B)) + \
        #           tl.cost.sigmoid_cross_entropy(logits_real_B, tf.ones_like(logits_real_B))

        # LSGAN generator losses
        # loss_Gab = tl.cost.mean_squared_error(logits_fake_B, tf.ones_like(logits_fake_B), is_mean=True)
        loss_Gab = tf.reduce_mean(tf.math.squared_difference(logits_fake_B, tf.ones_like(logits_fake_B)))
        # loss_Gab = tl.cost.sigmoid_cross_entropy(logits_fake_B, tf.ones_like(logits_fake_B))
        # loss_Gba = tl.cost.mean_squared_error(logits_fake_A, tf.ones_like(logits_fake_A), is_mean=True)
        loss_Gba = tf.reduce_mean(tf.math.squared_difference(logits_fake_A, tf.ones_like(logits_fake_A)))
        # loss_Gba = tl.cost.sigmoid_cross_entropy(logits_fake_A, tf.ones_like(logits_fake_A))

        # cycle-consistency loss
        # loss_cyc = 10 * (tl.cost.absolute_difference_error(image_A, cycle_A, is_mean=True) +
        #                  tl.cost.absolute_difference_error(image_B, cycle_B, is_mean=True))
        loss_cyc = 10. * (tf.reduce_mean(tf.abs(image_A - cycle_A)) + tf.reduce_mean(tf.abs(image_B - cycle_B)))
        if use_ident:
            loss_iden = 5. * (tf.reduce_mean(tf.abs(image_A - iden_A)) + tf.reduce_mean(tf.abs(image_B - iden_B)))
        else:
            loss_iden = 0.
        loss_G = loss_Gab + loss_Gba + loss_cyc + loss_iden
        loss_D = loss_Da + loss_Db
        return loss_G, loss_D, loss_Gab, loss_Gba, loss_cyc, loss_iden, loss_Da, loss_Db, loss_D + loss_G

    for epoch in range(0, flags.n_epoch):
        # reduce lr linearly after 100 epochs, from lr_init to 0
        if epoch >= 100:
            new_lr = flags.lr_init - flags.lr_init * (epoch - 100) / 100
            lr_v.assign(new_lr)  # fixed: tf.Variable.assign takes only the new value
            print("New learning rate %f" % new_lr)

        # train 1 epoch
        for step, (image_A, image_B) in enumerate(zip(data_A_shard, data_B_shard)):
            if image_A.shape[0] != flags.batch_size or image_B.shape[0] != flags.batch_size:
                break  # the remaining data in this epoch < batch_size
            step_time = time.time()
            with tf.GradientTape(persistent=True) as tape:
                # print(image_A.numpy().max())
                loss_G, loss_D, loss_Gab, loss_Gba, loss_cyc, loss_iden, loss_Da, loss_Db, loss_DG = train_step(image_A, image_B)
            grad = tape.gradient(
                loss_DG,
                Gba.trainable_weights + Gab.trainable_weights + Da.trainable_weights + Db.trainable_weights)
            optimizer.apply_gradients(
                zip(grad,
                    Gba.trainable_weights + Gab.trainable_weights + Da.trainable_weights + Db.trainable_weights))
            # grad = tape.gradient(loss_G, Gba.trainable_weights + Gab.trainable_weights)
            # optimizer_G.apply_gradients(zip(grad, Gba.trainable_weights + Gab.trainable_weights))
            # grad = tape.gradient(loss_D, Da.trainable_weights + Db.trainable_weights)
            # optimizer_D.apply_gradients(zip(grad, Da.trainable_weights + Db.trainable_weights))
            # del tape
            print("Epoch[{}/{}] step[{}/{}] time:{:.3f} Gab:{:.3f} Gba:{:.3f} cyc:{:.3f} iden:{:.3f} Da:{:.3f} Db:{:.3f}".format(
                epoch, flags.n_epoch, step, n_step_per_epoch, time.time() - step_time,
                loss_Gab, loss_Gba, loss_cyc, loss_iden, loss_Da, loss_Db))

            if parallel and step == 0:
                # KungFu: broadcast is done after the first gradient step to ensure optimizer initialization.
                from kungfu.tensorflow.initializer import broadcast_variables
                # Broadcast model variables
                broadcast_variables(Gab.trainable_weights)
                broadcast_variables(Gba.trainable_weights)
                broadcast_variables(Da.trainable_weights)
                broadcast_variables(Db.trainable_weights)
                # Broadcast optimizer variables (fixed: the original referenced four
                # per-network optimizers that are never created; broadcast the single
                # shared optimizer instead)
                broadcast_variables(optimizer.variables())

        if parallel:
            from kungfu import current_rank
            is_chief = current_rank() == 0
        else:
            is_chief = True

        # Let the chief worker do visualization and checkpoints.
        if is_chief:
            # visualization
            # outb = Gab(sample_A)
            # outa = Gba(sample_B)
            # tl.vis.save_images(outb.numpy(), [1, 5], flags.sample_dir + '/{}_a2b.png'.format(epoch))
            # tl.vis.save_images(outa.numpy(), [1, 5], flags.sample_dir + '/{}_b2a.png'.format(epoch))
            outb_list = []  # do it one by one in case your GPU memory is low
            for i in range(len(sample_A)):
                outb = Gab(sample_A[i][np.newaxis, :, :, :])
                outb_list.append(outb.numpy()[0])
            outa_list = []
            for i in range(len(sample_B)):
                outa = Gba(sample_B[i][np.newaxis, :, :, :])
                outa_list.append(outa.numpy()[0])
            tl.vis.save_images(np.asarray(outb_list), [1, 5], flags.sample_dir + '/{}_a2b.png'.format(epoch))
            tl.vis.save_images(np.asarray(outa_list), [1, 5], flags.sample_dir + '/{}_b2a.png'.format(epoch))

            # save models every 5 epochs (fixed: the original `if epoch % 5:`
            # saved on every epoch *not* divisible by 5)
            if epoch % 5 == 0:
                Gab.save_weights(flags.model_dir + '/Gab.h5')
                Gba.save_weights(flags.model_dir + '/Gba.h5')
                Da.save_weights(flags.model_dir + '/Da.h5')
                Db.save_weights(flags.model_dir + '/Db.h5')
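# Hypothetical entry point (not part of the original file) showing how the
# CycleGAN train() above is meant to be driven. The flag names here are
# illustrative, but the kungfu_option values match the branches train() checks.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--parallel', action='store_true', help='enable KungFu distributed training')
    parser.add_argument('--kf-optimizer', type=str, default='sma',
                        help="one of 'sync-sgd', 'async-sgd', 'sma'")
    args = parser.parse_args()
    train(args.parallel, args.kf_optimizer)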
import random
import argparse
import math

import numpy as np  # needed by KStest below; missing from this excerpt
import scipy.stats as stats
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
from sklearn import manifold, datasets
import ipdb

from config import flags  # assumed, mirroring the sibling script's imports
from models import get_G, get_E, get_z_D, get_z_G

# import sys
# f = open('a.log', 'a')
# sys.stdout = f
# sys.stderr = f  # redirect stderr, if necessary

G = get_G([None, 8, 8, 128])
E = get_E([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
D_z = get_z_D([None, 8, 8, 128])
G_z = get_z_G([None, flags.z_dim])


def KStest(real_z, fake_z):
    p_list = []
    for i in range(flags.batch_size_train):
        _, tmp_p = stats.ks_2samp(fake_z[i], real_z[i])
        p_list.append(tmp_p)
    return np.min(p_list), np.mean(p_list)


def TSNE(dataset, E, n_step_epoch):
    total_tensor = None
temp_out = sys.stdout

parser = argparse.ArgumentParser()
# fixed: `type=bool` does not parse strings the way one might expect
# (bool('False') is True), so use a store_true flag instead
parser.add_argument('--is_continue', action='store_true',
                    help='load weights from checkpoints?')
args = parser.parse_args()

E_x_a = get_Ea([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
E_y_a = get_Ea([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
E_c = get_Ec([None, flags.img_size_h, flags.img_size_w, flags.c_dim],
             [None, flags.img_size_h, flags.img_size_w, flags.c_dim])
G = get_G([None, flags.za_dim],
          [None, flags.c_shape[0], flags.c_shape[1], flags.c_shape[2]],
          [None, flags.za_dim],
          [None, flags.c_shape[0], flags.c_shape[1], flags.c_shape[2]])
D_x = get_D([None, flags.img_size_h, flags.img_size_w, flags.c_dim])  # fixed: the original passed img_size_h twice
D_y = get_D([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
D_c = get_D_content([None, flags.c_shape[0], flags.c_shape[1], flags.c_shape[2]])

E_x_a.train()
E_y_a.train()
E_c.train()
G.train()
D_x.train()
D_y.train()
D_c.train()
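# KL_loss is called inside train() below but is not defined in this excerpt.
# A minimal sketch, assuming the standard closed-form KL divergence between
# N(mu, exp(logvar)) and N(0, I) used for VAE-style appearance codes (an
# assumption, not necessarily this repo's exact implementation):
def KL_loss(mu, logvar):
    # 0.5 * sum(exp(logvar) + mu^2 - 1 - logvar) per sample, averaged over the batch
    kl = 0.5 * tf.reduce_sum(tf.exp(logvar) + tf.square(mu) - 1. - logvar, axis=-1)
    return tf.reduce_mean(kl)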
def train(con=False):
    dataset = get_cat2dog_train()
    len_dataset = flags.len_dataset

    E_x_a = get_Ea([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
    E_x_c = get_Ec([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
    E_y_a = get_Ea([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
    E_y_c = get_Ec([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
    G_x = get_G([None, flags.za_dim],
                [None, flags.c_shape[0], flags.c_shape[1], flags.c_shape[2]])
    G_y = get_G([None, flags.za_dim],
                [None, flags.c_shape[0], flags.c_shape[1], flags.c_shape[2]])
    G_z_c = get_G_zc([None, flags.zc_dim])
    D_x = get_D([None, flags.img_size_h, flags.img_size_w, flags.c_dim])  # fixed: img_size_h was passed twice
    D_y = get_D([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
    D_c = get_D_content([None, flags.c_shape[0], flags.c_shape[1], flags.c_shape[2]])
    E_y_zc = get_E_x2zc([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
    E_y_za = get_E_x2za([None, flags.img_size_h, flags.img_size_w, flags.c_dim])

    if con:
        E_x_a.load_weights('./checkpoint/E_x_a.npz')
        E_x_c.load_weights('./checkpoint/E_x_c.npz')
        E_y_a.load_weights('./checkpoint/E_y_a.npz')
        E_y_c.load_weights('./checkpoint/E_y_c.npz')
        G_x.load_weights('./checkpoint/G_x.npz')
        G_y.load_weights('./checkpoint/G_y.npz')
        G_z_c.load_weights('./checkpoint/G_z_c.npz')
        D_x.load_weights('./checkpoint/D_x.npz')
        D_y.load_weights('./checkpoint/D_y.npz')
        D_c.load_weights('./checkpoint/D_c.npz')
        E_y_zc.load_weights('./checkpoint/E_y_zc.npz')
        E_y_za.load_weights('./checkpoint/E_y_za.npz')

    E_x_a.train()
    E_x_c.train()
    E_y_a.train()
    E_y_c.train()
    G_x.train()
    G_y.train()
    G_z_c.train()
    E_y_zc.train()
    E_y_za.train()
    D_x.train()
    D_y.train()
    D_c.train()

    n_step_epoch = int(len_dataset // flags.batch_size_train)
    n_epoch = flags.n_epoch

    # one Adam optimizer per network, all with the same hyper-parameters
    lr_share = flags.lr

    def make_adam():
        return tf.optimizers.Adam(lr_share, beta_1=flags.beta1, beta_2=flags.beta2)

    E_x_a_optimizer = make_adam()
    E_x_c_optimizer = make_adam()
    E_y_a_optimizer = make_adam()
    E_y_c_optimizer = make_adam()
    G_x_optimizer = make_adam()
    G_y_optimizer = make_adam()
    G_z_c_optimizer = make_adam()
    E_y_zc_optimizer = make_adam()
    E_y_za_optimizer = make_adam()
    D_x_optimizer = make_adam()
    D_y_optimizer = make_adam()
    D_c_optimizer = make_adam()

    tfd = tfp.distributions
    dist = tfd.Normal(loc=0., scale=1.)
    for step, cat_and_dog in enumerate(dataset):
        '''
        log = " ** new learning rate: %f (for GAN)" % (lr_v.tolist()[0])
        print(log)
        '''
        cat_img = cat_and_dog[0]  # (1, 256, 256, 3)
        dog_img = cat_and_dog[1]  # (1, 256, 256, 3)
        epoch_num = step // n_step_epoch
        with tf.GradientTape(persistent=True) as tape:
            z_a = dist.sample([flags.batch_size_train, flags.za_dim])
            z_c = dist.sample([flags.batch_size_train, flags.zc_dim])

            # dog_app_vec = E_x_a(dog_img)
            dog_app_vec, dog_app_mu, dog_app_logvar = E_x_a(dog_img)  # Ea now returns (vec, mu, logvar)
            dog_cont_vec = E_x_c(dog_img)
            z_cat_cont_vec = G_z_c(z_c)
            dog_cont_vec_logit = D_c(dog_cont_vec)
            z_cat_cont_vec_logit = D_c(z_cat_cont_vec)
            # print(dog_app_vec.shape)       # (1, 8)
            # print(z_cat_cont_vec.shape)    # (1, 64, 64, 256)
            fake_dog = G_x([dog_app_vec, z_cat_cont_vec])
            fake_cat = G_y([z_a, dog_cont_vec])
            real_dog_logit = D_x(dog_img)
            fake_dog_logit = D_x(fake_dog)
            real_cat_logit = D_y(cat_img)
            fake_cat_logit = D_y(fake_cat)
            # fake_dog_app_vec = E_x_a(fake_dog)
            fake_dog_app_vec, fake_dog_app_mu, fake_dog_app_logvar = E_x_a(fake_dog)
            fake_dog_cont_vec = E_x_c(fake_dog)
            # fake_cat_app_vec = E_y_a(fake_cat)
            fake_cat_app_vec, fake_cat_app_mu, fake_cat_app_logvar = E_y_a(fake_cat)  # fixed: the original called E_x_a on the cat image
            fake_cat_cont_vec = E_y_c(fake_cat)
            recon_dog = G_x([fake_dog_app_vec, fake_cat_cont_vec])
            recon_cat = G_y([fake_cat_app_vec, fake_dog_cont_vec])
            recon_z_a = E_y_za(recon_cat)
            recon_z_c = E_y_zc(recon_cat)

            # content adversarial loss, to update D_c, E_x_c, E_y_c
            # (targets average to 0.5, pushing content codes toward being
            # indistinguishable across domains)
            cont_adv_loss = flags.lambda_content * (
                1 / 2 * tl.cost.sigmoid_cross_entropy(dog_cont_vec_logit, tf.ones_like(dog_cont_vec_logit)) +
                1 / 2 * tl.cost.sigmoid_cross_entropy(dog_cont_vec_logit, tf.zeros_like(dog_cont_vec_logit)) +
                1 / 2 * tl.cost.sigmoid_cross_entropy(z_cat_cont_vec_logit, tf.zeros_like(z_cat_cont_vec_logit)) +
                1 / 2 * tl.cost.sigmoid_cross_entropy(z_cat_cont_vec_logit, tf.ones_like(z_cat_cont_vec_logit)))

            # cross-identity loss, to update all Es and Gs
            cross_identity_loss = flags.lambda_corss * (
                tl.cost.absolute_difference_error(recon_dog, dog_img, is_mean=True) +
                tl.cost.absolute_difference_error(recon_z_a, z_a, is_mean=True) +
                tl.cost.absolute_difference_error(recon_z_c, z_c, is_mean=True))

            # domain adversarial loss
            dog_adv_loss = flags.lambda_domain * (
                tl.cost.sigmoid_cross_entropy(real_dog_logit, tf.ones_like(real_dog_logit)) +
                tl.cost.sigmoid_cross_entropy(fake_dog_logit, tf.zeros_like(fake_dog_logit)))
            cat_adv_loss = flags.lambda_domain * (
                tl.cost.sigmoid_cross_entropy(real_cat_logit, tf.ones_like(real_cat_logit)) +
                tl.cost.sigmoid_cross_entropy(fake_cat_logit, tf.zeros_like(fake_cat_logit)))

            # self-reconstruction loss
            self_recon_dog = G_x([dog_app_vec, dog_cont_vec])
            self_recon_cat_z = G_y([E_y_za(cat_img), G_z_c(E_y_zc(cat_img))])
            self_recon_cat_t = G_y([E_y_a(cat_img)[0], E_y_c(cat_img)])
            cat_self_recon_loss_z = flags.lambda_srecon * tl.cost.absolute_difference_error(
                self_recon_cat_z, cat_img, is_mean=True)
            cat_self_recon_loss_t = flags.lambda_srecon * tl.cost.absolute_difference_error(
                self_recon_cat_t, cat_img, is_mean=True)
            dog_self_recon_loss = flags.lambda_srecon * tl.cost.absolute_difference_error(
                self_recon_dog, dog_img, is_mean=True)

            # latent regression loss
            fake_cat_za = E_y_za(fake_cat)
            z_a_dog = dist.sample([flags.batch_size_train, flags.za_dim])
            fake_dog_za = E_x_a(G_x([z_a_dog, E_x_c(dog_img)]))[0]
            latent_regre_loss_cat = tl.cost.absolute_difference_error(
                fake_cat_za, z_a, is_mean=True)
            latent_regre_loss_dog = tl.cost.absolute_difference_error(
                fake_dog_za,
                z_a_dog, is_mean=True)

            # latent generation loss
            gen_cat = G_y([z_a, G_z_c(z_c)])
            gen_cat_logit = D_y(gen_cat)
            latent_gen_loss_cat = flags.lambda_latent * (
                tl.cost.sigmoid_cross_entropy(real_cat_logit, tf.ones_like(real_cat_logit)) +
                tl.cost.sigmoid_cross_entropy(gen_cat_logit, tf.zeros_like(gen_cat_logit)))

            # KL loss (z_kl is only used by the commented-out moment-matching variant below)
            z_kl = dist.sample([flags.KL_batch, flags.za_dim])
            # kl_mean, kl_variance = tf.nn.moments(x=dog_app_vec, axes=[1])
            # print(dog_app_vec)
            # print(str(kl_variance) + ' ' + str(kl_mean))  # [1,] [1,]
            # kl_sigma = tf.math.sqrt(kl_variance)
            # z_calc = z_kl * kl_sigma + tf.ones_like(z_kl) * kl_mean
            kl_loss_dog = flags.lambda_KL * KL_loss(dog_app_mu, dog_app_logvar)

            # mode-seeking regularization: push images generated from two
            # different codes apart, relative to the distance between the codes
            z_1 = dist.sample([flags.batch_size_train, flags.za_dim])
            z_2 = dist.sample([flags.batch_size_train, flags.za_dim])
            dog_1 = G_x([z_1, E_x_c(dog_img)])
            dog_2 = G_x([z_2, E_x_c(dog_img)])
            dog_norm = tl.cost.mean_squared_error(dog_1, dog_2)
            cat_norm = tl.cost.mean_squared_error(G_y([z_1, E_y_c(cat_img)]),
                                                  G_y([z_2, E_y_c(cat_img)]))
            z_norm = tl.cost.mean_squared_error(z_1, z_2)
            dog_ms_loss = -dog_norm / z_norm
            cat_ms_loss = -cat_norm / z_norm

            # sum up total losses
            content_adv_loss = cont_adv_loss
            domain_adv_loss = dog_adv_loss + cat_adv_loss
            self_recon_loss = cat_self_recon_loss_t + cat_self_recon_loss_z + dog_self_recon_loss
            latent_regre_loss = latent_regre_loss_cat + latent_regre_loss_dog  # fixed: the original summed latent_gen_loss_cat here
            latent_gen_loss = latent_gen_loss_cat
            kl_loss = kl_loss_dog
            ms_loss = dog_ms_loss + cat_ms_loss

            E_x_a_total_loss = cross_identity_loss + dog_self_recon_loss + latent_regre_loss_dog + kl_loss_dog
            E_x_c_total_loss = cont_adv_loss + cross_identity_loss + dog_self_recon_loss + latent_regre_loss_dog + \
                               dog_ms_loss
            E_y_a_total_loss = cross_identity_loss + cat_self_recon_loss_t + latent_regre_loss_cat
            E_y_c_total_loss = cont_adv_loss + cross_identity_loss + cat_self_recon_loss_t + latent_regre_loss_cat + \
                               cat_ms_loss
            E_y_zc_total_loss = cross_identity_loss + cat_self_recon_loss_z
            E_y_za_total_loss = cross_identity_loss + cat_self_recon_loss_z
            G_x_total_loss = cross_identity_loss + dog_adv_loss + dog_self_recon_loss + latent_regre_loss_dog + \
                             dog_ms_loss
            G_y_total_loss = cross_identity_loss + cat_adv_loss + cat_self_recon_loss_z + cat_self_recon_loss_t + \
                             latent_regre_loss_cat + latent_gen_loss_cat + cat_ms_loss
            G_z_c_total_loss = cross_identity_loss + cat_self_recon_loss_z
            D_x_total_loss = dog_adv_loss
            D_y_total_loss = cat_adv_loss + latent_gen_loss_cat
            D_c_total_loss = cont_adv_loss

        # release memory
        E_x_a.release_memory()
        E_x_c.release_memory()
        E_y_a.release_memory()
        E_y_c.release_memory()
        E_y_zc.release_memory()
        E_y_za.release_memory()
        G_x.release_memory()
        G_y.release_memory()
        G_z_c.release_memory()
        D_x.release_memory()
        D_y.release_memory()
        D_c.release_memory()

        # update encoders
        grad = tape.gradient(E_x_a_total_loss, E_x_a.trainable_weights)
        E_x_a_optimizer.apply_gradients(zip(grad, E_x_a.trainable_weights))
        grad = tape.gradient(E_x_c_total_loss, E_x_c.trainable_weights)
        E_x_c_optimizer.apply_gradients(zip(grad, E_x_c.trainable_weights))
        grad = tape.gradient(E_y_a_total_loss, E_y_a.trainable_weights)
        E_y_a_optimizer.apply_gradients(zip(grad, E_y_a.trainable_weights))
        grad = tape.gradient(E_y_c_total_loss, E_y_c.trainable_weights)
        E_y_c_optimizer.apply_gradients(zip(grad, E_y_c.trainable_weights))
        grad = tape.gradient(E_y_zc_total_loss, E_y_zc.trainable_weights)
        E_y_zc_optimizer.apply_gradients(zip(grad, E_y_zc.trainable_weights))
        grad = tape.gradient(E_y_za_total_loss,
                             E_y_za.trainable_weights)
        E_y_za_optimizer.apply_gradients(zip(grad, E_y_za.trainable_weights))

        # update generators
        grad = tape.gradient(G_x_total_loss, G_x.trainable_weights)
        G_x_optimizer.apply_gradients(zip(grad, G_x.trainable_weights))
        grad = tape.gradient(G_y_total_loss, G_y.trainable_weights)
        G_y_optimizer.apply_gradients(zip(grad, G_y.trainable_weights))
        grad = tape.gradient(G_z_c_total_loss, G_z_c.trainable_weights)
        G_z_c_optimizer.apply_gradients(zip(grad, G_z_c.trainable_weights))

        # update discriminators
        grad = tape.gradient(D_x_total_loss, D_x.trainable_weights)
        D_x_optimizer.apply_gradients(zip(grad, D_x.trainable_weights))
        grad = tape.gradient(D_y_total_loss, D_y.trainable_weights)
        D_y_optimizer.apply_gradients(zip(grad, D_y.trainable_weights))
        grad = tape.gradient(D_c_total_loss, D_c.trainable_weights)
        D_c_optimizer.apply_gradients(zip(grad, D_c.trainable_weights))
        del tape

        # show current state (printed once to log.txt and once to the console)
        if np.mod(step, flags.show_every_step) == 0:
            log_str = (
                "Epoch: [{}/{}] [{}/{}] content_adv_loss:{:5f}, cross_identity_loss:{:5f}, "
                "domain_adv_loss:{:5f}, self_recon_loss:{:5f}, latent_regre_loss:{:5f}, latent_gen_loss:{:5f}, "
                "kl_loss:{:5f}, ms_loss:{:10f}".format(
                    epoch_num, flags.n_epoch, step - (epoch_num * n_step_epoch), n_step_epoch,
                    content_adv_loss, cross_identity_loss, domain_adv_loss, self_recon_loss,
                    latent_regre_loss, latent_gen_loss, kl_loss, ms_loss))
            with open("log.txt", "a+") as f:
                sys.stdout = f  # redirect stdout to the log file
                print(log_str)
                sys.stdout = temp_out  # restore stdout to the console
            print(log_str)

        if np.mod(step, flags.save_step) == 0 and step != 0:
            E_x_a.save_weights('{}/{}/E_x_a.npz'.format(flags.checkpoint_dir, flags.param_dir), format='npz')
            E_x_c.save_weights('{}/{}/E_x_c.npz'.format(flags.checkpoint_dir, flags.param_dir), format='npz')
            E_y_a.save_weights('{}/{}/E_y_a.npz'.format(flags.checkpoint_dir, flags.param_dir), format='npz')
            E_y_c.save_weights('{}/{}/E_y_c.npz'.format(flags.checkpoint_dir, flags.param_dir), format='npz')
            G_x.save_weights('{}/{}/G_x.npz'.format(flags.checkpoint_dir, flags.param_dir), format='npz')
            G_y.save_weights('{}/{}/G_y.npz'.format(flags.checkpoint_dir, flags.param_dir), format='npz')
            G_z_c.save_weights('{}/{}/G_z_c.npz'.format(flags.checkpoint_dir, flags.param_dir), format='npz')
            E_y_zc.save_weights('{}/{}/E_y_zc.npz'.format(flags.checkpoint_dir, flags.param_dir), format='npz')
            E_y_za.save_weights('{}/{}/E_y_za.npz'.format(flags.checkpoint_dir, flags.param_dir), format='npz')
            D_x.save_weights('{}/{}/D_x.npz'.format(flags.checkpoint_dir, flags.param_dir), format='npz')
            D_y.save_weights('{}/{}/D_y.npz'.format(flags.checkpoint_dir, flags.param_dir), format='npz')
            D_c.save_weights('{}/{}/D_c.npz'.format(flags.checkpoint_dir, flags.param_dir), format='npz')

        # G.train()
        if np.mod(step, flags.eval_step) == 0:
            z = dist.sample([flags.batch_size_train, flags.za_dim])
            E_y_c.eval()
            G_y.eval()
            eval_cat_cont_vec = E_y_c(cat_img)
            sys_cat_img = G_y([z, eval_cat_cont_vec])
            sys_cat_img = tf.concat([sys_cat_img, cat_img], 0)
            E_y_c.train()
            G_y.train()
            tl.visualize.save_images(
                sys_cat_img.numpy(), [1, 2],
                '{}/{}/train_{:02d}_{:04d}.png'.format(flags.sample_dir, flags.param_dir,
                                                       step // n_step_epoch, step))
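# Hypothetical driver (not in the original excerpt), wiring the --is_continue
# flag parsed at module level above to the train() entry point:
if __name__ == '__main__':
    train(con=args.is_continue)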
import os, time, multiprocessing
import sys  # needed by the stdout redirection in train(); missing from this excerpt
import random
import argparse
import math

import numpy as np
import scipy.stats as stats
import tensorflow as tf
import tensorflow_probability as tfp
import tensorlayer as tl
import ipdb

from config import flags
from data import get_dataset_train
from models import get_G, get_E, get_z_D, get_trans_func, get_img_D

G = get_G([None, flags.z_dim])
D = get_img_D([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
E = get_E([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
D_z = get_z_D([None, flags.z_dim])
C = get_z_D([None, flags.z_dim])
f_ab = get_trans_func([None, flags.z_dim])
f_ba = get_trans_func([None, flags.z_dim])
D_zA = get_z_D([None, flags.z_dim])
D_zB = get_z_D([None, flags.z_dim])

temp_out = sys.stdout  # handle on the console stdout (assumed; the log-redirect code in train() relies on it)


def KStest(real_z, fake_z):
    p_list = []
    for i in range(flags.batch_size_train):
        _, tmp_p = stats.ks_2samp(fake_z[i], real_z[i])
        p_list.append(tmp_p)
    return np.min(p_list), np.mean(p_list)  # truncated in the original; restored from the identical helper above
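# Toy usage of KStest (illustration only, not part of the original script):
# feed two [batch, z_dim] arrays; low p-values mean the two-sample KS test can
# distinguish the per-sample latent distributions.
def _kstest_demo():
    real = np.random.normal(size=[flags.batch_size_train, flags.z_dim]).astype(np.float32)
    fake = np.random.uniform(size=[flags.batch_size_train, flags.z_dim]).astype(np.float32)
    p_min, p_avg = KStest(real, fake)
    print('kstest: min:{}, avg:{}'.format(p_min, p_avg))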
def train(con=False):
    dataset, len_dataset = get_dataset_train()
    len_dataset = flags.len_dataset  # override the dataset's own length with the configured one

    G = get_G([None, flags.z_dim])
    D = get_img_D([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
    E = get_E([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
    D_z = get_z_D([None, flags.z_dim])
    if con:
        G.load_weights('./checkpoint/G.npz')
        D.load_weights('./checkpoint/D.npz')
        E.load_weights('./checkpoint/E.npz')
        D_z.load_weights('./checkpoint/D_z.npz')
    G.train()
    D.train()
    E.train()
    D_z.train()

    n_step_epoch = int(len_dataset // flags.batch_size_train)
    n_epoch = flags.n_epoch

    # lr_G = flags.lr_G * flags.initial_scale
    # lr_E = flags.lr_E * flags.initial_scale
    # lr_D = flags.lr_D * flags.initial_scale
    # lr_Dz = flags.lr_Dz * flags.initial_scale
    lr_G = flags.lr_G
    lr_E = flags.lr_E
    lr_D = flags.lr_D
    lr_Dz = flags.lr_Dz
    # total_step = n_epoch * n_step_epoch
    # lr_decay_G = flags.lr_G * (flags.ending_scale - flags.initial_scale) / total_step
    # lr_decay_E = flags.lr_G * (flags.ending_scale - flags.initial_scale) / total_step
    # lr_decay_D = flags.lr_G * (flags.ending_scale - flags.initial_scale) / total_step
    # lr_decay_Dz = flags.lr_G * (flags.ending_scale - flags.initial_scale) / total_step
    d_optimizer = tf.optimizers.Adam(lr_D, beta_1=flags.beta1, beta_2=flags.beta2)
    g_optimizer = tf.optimizers.Adam(lr_G, beta_1=flags.beta1, beta_2=flags.beta2)
    e_optimizer = tf.optimizers.Adam(lr_E, beta_1=flags.beta1, beta_2=flags.beta2)
    dz_optimizer = tf.optimizers.Adam(lr_Dz, beta_1=flags.beta1, beta_2=flags.beta2)

    curr_lambda = flags.lambda_recon
    for step, batch_imgs_labels in enumerate(dataset):
        batch_imgs = batch_imgs_labels[0]    # (64, 64, 64, 3)
        batch_labels = batch_imgs_labels[1]  # (64,)
        epoch_num = step // n_step_epoch
        # for i in range(flags.batch_size_train):
        #     tl.visualize.save_image(batch_imgs[i].numpy(), 'train_{:02d}.png'.format(i))
        # # Updating recon lambda
        # if epoch_num <= 5:     # 50 --> 25
        #     curr_lambda -= 5
        # elif epoch_num <= 40:  # stay at 25
        #     curr_lambda = 25
        # else:                  # 25 --> 10
        #     curr_lambda -= 0.25
        with tf.GradientTape(persistent=True) as tape:
            # latent prior: Gaussian noise plus a binary (Bernoulli) offset
            z = flags.scale * np.random.normal(
                loc=0.0, scale=flags.sigma * math.sqrt(flags.z_dim),
                size=[flags.batch_size_train, flags.z_dim]).astype(np.float32)
            z += flags.scale * np.random.binomial(
                n=1, p=0.5, size=[flags.batch_size_train, flags.z_dim]).astype(np.float32)

            fake_z = E(batch_imgs)
            fake_imgs = G(fake_z)
            fake_logits = D(fake_imgs)
            real_logits = D(batch_imgs)
            fake_logits_z = D(G(z))
            real_z_logits = D_z(z)
            fake_z_logits = D_z(fake_z)

            # non-saturating adversarial losses for E and G: subtract the
            # "fake"-target term and add the "real"-target term
            e_loss_z = - tl.cost.sigmoid_cross_entropy(fake_z_logits, tf.zeros_like(fake_z_logits)) + \
                       tl.cost.sigmoid_cross_entropy(fake_z_logits, tf.ones_like(fake_z_logits))
            recon_loss = curr_lambda * tl.cost.absolute_difference_error(batch_imgs, fake_imgs)
            g_loss_x = - tl.cost.sigmoid_cross_entropy(fake_logits, tf.zeros_like(fake_logits)) + \
                       tl.cost.sigmoid_cross_entropy(fake_logits, tf.ones_like(fake_logits))
            g_loss_z = - tl.cost.sigmoid_cross_entropy(fake_logits_z, tf.zeros_like(fake_logits_z)) + \
                       tl.cost.sigmoid_cross_entropy(fake_logits_z, tf.ones_like(fake_logits_z))
            e_loss = recon_loss + e_loss_z
            g_loss = recon_loss + g_loss_x + g_loss_z
            d_loss = tl.cost.sigmoid_cross_entropy(real_logits, tf.ones_like(real_logits)) + \
                     tl.cost.sigmoid_cross_entropy(fake_logits, tf.zeros_like(fake_logits)) + \
                     tl.cost.sigmoid_cross_entropy(fake_logits_z, tf.zeros_like(fake_logits_z))
            dz_loss = tl.cost.sigmoid_cross_entropy(fake_z_logits, tf.zeros_like(fake_z_logits)) + \
                      tl.cost.sigmoid_cross_entropy(real_z_logits, tf.ones_like(real_z_logits))

        # Updating Encoder
        grad = tape.gradient(e_loss, E.trainable_weights)
        e_optimizer.apply_gradients(zip(grad, E.trainable_weights))
        # Updating Generator
        grad = tape.gradient(g_loss, G.trainable_weights)
        g_optimizer.apply_gradients(zip(grad, G.trainable_weights))
        # Updating Discriminator
        grad = tape.gradient(d_loss, D.trainable_weights)
        d_optimizer.apply_gradients(zip(grad, D.trainable_weights))
        # Updating D_z & D_h
        grad = tape.gradient(dz_loss, D_z.trainable_weights)
        dz_optimizer.apply_gradients(zip(grad, D_z.trainable_weights))

        # # Updating lr
        # lr_G -= lr_decay_G
        # lr_E -= lr_decay_E
        # lr_D -= lr_decay_D
        # lr_Dz -= lr_decay_Dz

        # show current state (printed once to log.txt and once to the console)
        if np.mod(step, flags.show_every_step) == 0:
            p_min, p_avg = KStest(z, fake_z)
            log_str = (
                "Epoch: [{}/{}] [{}/{}] curr_lambda: {:.5f}, recon_loss: {:.5f}, g_loss: {:.5f}, d_loss: {:.5f}, "
                "e_loss: {:.5f}, dz_loss: {:.5f}, g_loss_x: {:.5f}, g_loss_z: {:.5f}, e_loss_z: {:.5f}".format(
                    epoch_num, flags.n_epoch, step - (epoch_num * n_step_epoch), n_step_epoch,
                    curr_lambda, recon_loss, g_loss, d_loss, e_loss, dz_loss, g_loss_x, g_loss_z, e_loss_z))
            with open("log.txt", "a+") as f:
                sys.stdout = f  # redirect stdout to the log file
                print(log_str)
                print("kstest: min:{}, avg:{}".format(p_min, p_avg))
                sys.stdout = temp_out  # restore stdout to the console
            print(log_str)
            print("kstest: min:{}, avg:{}".format(p_min, p_avg))

        if np.mod(step, n_step_epoch) == 0 and step != 0:
            G.save_weights('{}/{}/G.npz'.format(flags.checkpoint_dir, flags.param_dir), format='npz')
            D.save_weights('{}/{}/D.npz'.format(flags.checkpoint_dir, flags.param_dir), format='npz')
            E.save_weights('{}/{}/E.npz'.format(flags.checkpoint_dir, flags.param_dir), format='npz')
            D_z.save_weights('{}/{}/Dz.npz'.format(flags.checkpoint_dir, flags.param_dir), format='npz')

        # G.train()
        if np.mod(step, flags.eval_step) == 0:
            z = np.random.normal(loc=0.0, scale=1,
                                 size=[flags.batch_size_train, flags.z_dim]).astype(np.float32)
            G.eval()
            result = G(z)
            G.train()
            tl.visualize.save_images(
                result.numpy(), [8, 8],
                '{}/{}/train_{:02d}_{:04d}.png'.format(flags.sample_dir, flags.param_dir,
                                                       step // n_step_epoch, step))
        del tape
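# Hypothetical driver (not in the original excerpt): start training from
# scratch, or pass con=True to resume from ./checkpoint.
if __name__ == '__main__':
    train(con=False)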