def disentangle_test():
    print('Start disentangle_test!')
    ds, _ = get_dataset_train()
    E = get_E([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
    E.load_weights('{}/{}/E.npz'.format(flags.checkpoint_dir, flags.param_dir), format='npz')
    E.eval()
    G = get_G([None, flags.z_dim])
    G.load_weights('{}/{}/G.npz'.format(flags.checkpoint_dir, flags.param_dir), format='npz')
    G.eval()
    for step, batch_img in enumerate(ds):
        if step > flags.disentangle_step_num:
            break
        z_real = E(batch_img)
        hash_real = ((tf.sign(z_real * 2 - 1, name=None) + 1) / 2).numpy()
        epsilon = flags.scale * np.random.normal(
            loc=0.0,
            scale=flags.sigma * math.sqrt(flags.z_dim) * 0.0625,
            size=[flags.batch_size_train, flags.z_dim]).astype(np.float32)
        # z_fake = hash_real + epsilon
        z_fake = z_real + epsilon
        fake_imgs = G(z_fake)
        tl.visualize.save_images(
            batch_img.numpy(), [8, 8],
            '{}/{}/disentangle/real_{:02d}.png'.format(flags.test_dir, flags.param_dir, step))
        tl.visualize.save_images(
            fake_imgs.numpy(), [8, 8],
            '{}/{}/disentangle/fake_{:02d}.png'.format(flags.test_dir, flags.param_dir, step))
def ktest():
    print('Start ktest!')
    ds, _ = get_dataset_train()
    E = get_E([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
    E.load_weights('{}/{}/E.npz'.format(flags.checkpoint_dir, flags.param_dir), format='npz')
    E.eval()
def train_F_D():
    dataset, len_dataset = get_dataset_train()
    len_dataset = flags.len_dataset
    G.load_weights('./checkpoint/G.npz')
    E.load_weights('./checkpoint/E.npz')
    f_ab.train()
    f_ba.train()
    D_zA.train()
    D_zB.train()
    G.eval()
    E.eval()
    f_optimizer = tf.optimizers.Adam(flags.lr_F, beta_1=flags.beta1, beta_2=flags.beta2)
    d_optimizer = tf.optimizers.Adam(flags.lr_D, beta_1=flags.beta1, beta_2=flags.beta2)
    n_step_epoch = int(len_dataset // flags.batch_size_train)
    n_epoch = int(flags.step_num // n_step_epoch)
    for step, batch_imgs in enumerate(dataset):
        '''
        log = " ** new learning rate: %f (for GAN)" % (lr_v.tolist()[0])
        print(log)
        '''
        image_a = batch_imgs[0]
        image_b = batch_imgs[1]
        epoch_num = step // n_step_epoch
        with tf.GradientTape(persistent=True) as tape:
            za_real = E()
        # Updating Encoder
        grad = tape.gradient(loss_c, C.trainable_weights)
        c_optimizer.apply_gradients(zip(grad, C.trainable_weights))
        del tape
        # basic
        if np.mod(step, flags.show_freq) == 0 and step != 0:
            print("Epoch: [{}/{}] [{}/{}] batch_acc is {}, loss_c: {:.5f}".
                  format(epoch_num, n_epoch, step, n_step_epoch, batch_acc, loss_c))
        if np.mod(step, n_step_epoch) == 0 and step != 0:
            C.save_weights('{}/G.npz'.format(flags.checkpoint_dir), format='npz')
        if np.mod(step, flags.acc_step) == 0 and step != 0:
            acc = acc_sum / acc_step
            print("The avg step_acc is {:.3f} in {} step".format(acc, acc_step))
            acc_sum = 0
            acc_step = 0
def opt_z():
    dataset, len_dataset = get_dataset_train()
    len_dataset = flags.len_dataset
    G.eval()
    E.eval()
    C.eval()
    z_optimizer = tf.optimizers.Adam(flags.lr_C, beta_1=flags.beta1, beta_2=flags.beta2)
    for step, image_labels in enumerate(dataset):
        '''
        log = " ** new learning rate: %f (for GAN)" % (lr_v.tolist()[0])
        print(log)
        '''
        if step >= flags.sample_cnt:
            break
        print('Now start {} img'.format(str(step)))
        batch_imgs = image_labels[0]
        batch_labels = image_labels[1]
        batch_labels = (batch_labels + 1) / 2
        tran_label = tf.ones_like(batch_labels) - batch_labels
        # Encode once, then iteratively optimise the latent code itself towards the
        # flipped label (the networks stay fixed).
        real_z = E(batch_imgs)
        z = tf.Variable(real_z)
        for step_z in range(flags.step_z):
            with tf.GradientTape() as tape:
                z_logits = C(z)
                loss_z = tl.cost.sigmoid_cross_entropy(z_logits, tran_label)
            # Updating the latent code: a tf.Variable has no trainable_weights attribute,
            # so take the gradient w.r.t. [z] directly.
            grad = tape.gradient(loss_z, [z])
            z_optimizer.apply_gradients(zip(grad, [z]))
            opt_img = G(z)
            del tape
            if np.mod(step_z, 10) == 0:
                tl.visualize.save_images(
                    opt_img.numpy(), [8, 8],
                    '{}/opt_img{:02d}_step{:02d}.png'.format(flags.opt_sample_dir, step, step_z))
        tl.visualize.save_images(
            batch_imgs.numpy(), [8, 8],
            '{}/raw_img{:02d}.png'.format(flags.opt_sample_dir, step))
        tl.visualize.save_images(
            opt_img.numpy(), [8, 8],
            '{}/opt_img{:02d}.png'.format(flags.opt_sample_dir, step))
def train_GE(con=False):
    dataset, len_dataset = get_dataset_train()
    len_dataset = flags.len_dataset
    print(con)
    if con:
        G.load_weights('./checkpoint/G.npz')
        E.load_weights('./checkpoint/E.npz')
    G.train()
    E.train()
    n_step_epoch = int(len_dataset // flags.batch_size_train)
    n_epoch = int(flags.step_num // n_step_epoch)
    lr_G = flags.lr_G
    lr_E = flags.lr_E
    g_optimizer = tf.optimizers.Adam(lr_G, beta_1=flags.beta1, beta_2=flags.beta2)
    e_optimizer = tf.optimizers.Adam(lr_E, beta_1=flags.beta1, beta_2=flags.beta2)
    eval_batch = None
    for step, image_labels in enumerate(dataset):
        '''
        log = " ** new learning rate: %f (for GAN)" % (lr_v.tolist()[0])
        print(log)
        '''
        if step == 0:
            eval_batch = image_labels[0]
            tl.visualize.save_images(
                eval_batch.numpy(), [8, 8],
                '{}/eval_samples.png'.format(flags.sample_dir))
        batch_imgs = image_labels[0]
        # ipdb.set_trace()
        epoch_num = step // n_step_epoch
        with tf.GradientTape(persistent=True) as tape:
            tl.visualize.save_images(
                batch_imgs.numpy(), [8, 8],
                '{}/raw_samples.png'.format(flags.sample_dir))
            fake_z = E(batch_imgs)
            noise = np.random.normal(loc=0.0, scale=flags.noise_scale,
                                     size=fake_z.shape).astype(np.float32)
            recon_x = G(fake_z + noise)
            recon_loss = tl.cost.absolute_difference_error(batch_imgs, recon_x, is_mean=True)
            # reg_loss = tf.math.maximum(tl.cost.mean_squared_error(fake_z, tf.zeros_like(fake_z)), 0.5)
            # mean squared norm of the latent codes ('z_len' to avoid shadowing the builtin len)
            z_len = tl.cost.mean_squared_error(fake_z, tf.zeros_like(fake_z))
            # print(z_len)
            reg_loss = tf.math.maximum((z_len - 1) * (z_len - 1), flags.margin)
            e_loss = flags.lamba_recon * recon_loss + reg_loss
            g_loss = flags.lamba_recon * recon_loss
        # Updating Encoder
        grad = tape.gradient(e_loss, E.trainable_weights)
        e_optimizer.apply_gradients(zip(grad, E.trainable_weights))
        # Updating Generator
        grad = tape.gradient(g_loss, G.trainable_weights)
        g_optimizer.apply_gradients(zip(grad, G.trainable_weights))
        # basic
        if np.mod(step, flags.show_freq) == 0 and step != 0:
            print(
                "Epoch: [{}/{}] [{}/{}] e_loss: {:.5f}, g_loss: {:.5f}, recon_loss: {:.5f}, reg_loss: {:.5f}, len: {:.5f}"
                .format(epoch_num, n_epoch, step, n_step_epoch, e_loss, g_loss, recon_loss, reg_loss, z_len))
        if np.mod(step, n_step_epoch) == 0 and step != 0:
            G.save_weights('{}/G.npz'.format(flags.checkpoint_dir), format='npz')
            E.save_weights('{}/E.npz'.format(flags.checkpoint_dir), format='npz')
            # G.train()
        if np.mod(step, flags.eval_step) == 0 and step != 0:
            # z = np.random.normal(loc=0.0, scale=1, size=[flags.batch_size_train, flags.z_dim]).astype(np.float32)
            G.eval()
            E.eval()
            recon_imgs = G(E(eval_batch))
            G.train()
            E.train()
            tl.visualize.save_images(
                recon_imgs.numpy(), [8, 8],
                '{}/recon_{:02d}_{:04d}.png'.format(flags.sample_dir, step // n_step_epoch, step))
        del tape
def train_Gz():
    dataset, len_dataset = get_dataset_train()
    len_dataset = flags.len_dataset
    G.load_weights('./checkpoint/G.npz')
    E.load_weights('./checkpoint/E.npz')
    G.eval()
    E.eval()
    G_z.train()
    D_z.train()
    n_step_epoch = int(len_dataset // flags.batch_size_train)
    n_epoch = int(flags.step_num // n_step_epoch)
    lr_Dz = flags.lr_Dz
    lr_Gz = flags.lr_Gz
    dt_optimizer = tf.optimizers.Adam(lr_Dz, beta_1=flags.beta1, beta_2=flags.beta2)
    gt_optimizer = tf.optimizers.Adam(lr_Gz, beta_1=flags.beta1, beta_2=flags.beta2)
    for step, image_labels in enumerate(dataset):
        '''
        log = " ** new learning rate: %f (for GAN)" % (lr_v.tolist()[0])
        print(log)
        '''
        batch_imgs = image_labels[0]
        epoch_num = step // n_step_epoch
        with tf.GradientTape(persistent=True) as tape:
            real_tensor = E(batch_imgs)
            z = np.random.normal(loc=0.0, scale=1,
                                 size=[flags.batch_size_train, flags.z_dim]).astype(np.float32)
            fake_tensor = G_z(z)
            fake_tensor_logits = D_z(fake_tensor)
            real_tensor_logits = D_z(real_tensor)
            gt_loss = tl.cost.sigmoid_cross_entropy(
                fake_tensor_logits, tf.ones_like(fake_tensor_logits))
            dt_loss = tl.cost.sigmoid_cross_entropy(real_tensor_logits, tf.ones_like(real_tensor_logits)) + \
                tl.cost.sigmoid_cross_entropy(fake_tensor_logits, tf.zeros_like(fake_tensor_logits))
        # Updating Generator
        grad = tape.gradient(gt_loss, G_z.trainable_weights)
        gt_optimizer.apply_gradients(zip(grad, G_z.trainable_weights))
        # Updating D_z & D_h
        grad = tape.gradient(dt_loss, D_z.trainable_weights)
        dt_optimizer.apply_gradients(zip(grad, D_z.trainable_weights))
        # basic
        if np.mod(step, flags.show_freq) == 0 and step != 0:
            print("Epoch: [{}/{}] [{}/{}] dt_loss: {:.5f}, gt_loss: {:.5f}".
                  format(epoch_num, n_epoch, step, n_step_epoch, dt_loss, gt_loss))
        if np.mod(step, n_step_epoch) == 0 and step != 0:
            G_z.save_weights('{}/G_z.npz'.format(flags.checkpoint_dir), format='npz')
            D_z.save_weights('{}/D_z.npz'.format(flags.checkpoint_dir), format='npz')
            # G.train()
        if np.mod(step, flags.eval_step) == 0 and step != 0:
            z = np.random.normal(loc=0.0, scale=1,
                                 size=[flags.batch_size_train, flags.z_dim]).astype(np.float32)
            G.eval()
            sample_tensor = G_z(z)
            sample_img = G(sample_tensor)
            G.train()
            tl.visualize.save_images(
                sample_img.numpy(), [8, 8],
                '{}/sample_{:02d}_{:04d}.png'.format(flags.sample_dir, step // n_step_epoch, step))
        del tape
def train_GE(con=False):
    dataset, len_dataset = get_dataset_train()
    len_dataset = flags.len_dataset
    print(con)
    if con:
        G.load_weights('./checkpoint/G.npz')
        D.load_weights('./checkpoint/D.npz')
        E.load_weights('./checkpoint/E.npz')
        D_z.load_weights('./checkpoint/Dz.npz')
    G.train()
    D.train()
    E.train()
    D_z.train()
    n_step_epoch = int(len_dataset // flags.batch_size_train)
    n_epoch = int(flags.step_num // n_step_epoch)
    lr_G = flags.lr_G
    lr_E = flags.lr_E
    lr_D = flags.lr_D
    lr_Dz = flags.lr_Dz
    d_optimizer = tf.optimizers.Adam(lr_D, beta_1=flags.beta1, beta_2=flags.beta2)
    g_optimizer = tf.optimizers.Adam(lr_G, beta_1=flags.beta1, beta_2=flags.beta2)
    e_optimizer = tf.optimizers.Adam(lr_E, beta_1=flags.beta1, beta_2=flags.beta2)
    dz_optimizer = tf.optimizers.Adam(lr_Dz, beta_1=flags.beta1, beta_2=flags.beta2)
    # tfd = tfp.distributions
    # dist_normal = tfd.Normal(loc=0., scale=1.)
    # dist_Bernoulli = tfd.Bernoulli(probs=0.5)
    # dist_beta = tfd.Beta(0.5, 0.5)
    eval_batch = None
    for step, image_labels in enumerate(dataset):
        '''
        log = " ** new learning rate: %f (for GAN)" % (lr_v.tolist()[0])
        print(log)
        '''
        if step == 0:
            eval_batch = image_labels[0]
        batch_imgs = image_labels[0]
        # ipdb.set_trace()
        epoch_num = step // n_step_epoch
        with tf.GradientTape(persistent=True) as tape:
            z = np.random.normal(loc=0.0, scale=1,
                                 size=[flags.batch_size_train, flags.z_dim]).astype(np.float32)
            tl.visualize.save_images(
                batch_imgs.numpy(), [8, 8],
                '{}/raw_samples.png'.format(flags.sample_dir))
            fake_z = E(batch_imgs)
            fake_imgs = G(fake_z)
            fake_logits = D(fake_imgs)
            real_logits = D(batch_imgs)
            fake_logits_z = D(G(z))
            real_z_logits = D_z(z)
            fake_z_logits = D_z(fake_z)
            e_loss_z = - tl.cost.sigmoid_cross_entropy(fake_z_logits, tf.zeros_like(fake_z_logits)) + \
                tl.cost.sigmoid_cross_entropy(fake_z_logits, tf.ones_like(fake_z_logits))
            recon_loss = flags.lamba_recon * tl.cost.absolute_difference_error(
                batch_imgs, fake_imgs, is_mean=True)
            g_loss_x = - tl.cost.sigmoid_cross_entropy(fake_logits, tf.zeros_like(fake_logits)) + \
                tl.cost.sigmoid_cross_entropy(fake_logits, tf.ones_like(fake_logits))
            g_loss_z = - tl.cost.sigmoid_cross_entropy(fake_logits_z, tf.zeros_like(fake_logits_z)) + \
                tl.cost.sigmoid_cross_entropy(fake_logits_z, tf.ones_like(fake_logits_z))
            e_loss = recon_loss + e_loss_z
            g_loss = recon_loss + g_loss_x + g_loss_z
            d_loss = tl.cost.sigmoid_cross_entropy(real_logits, tf.ones_like(real_logits)) + \
                tl.cost.sigmoid_cross_entropy(fake_logits, tf.zeros_like(fake_logits)) + \
                tl.cost.sigmoid_cross_entropy(fake_logits_z, tf.zeros_like(fake_logits_z))
            dz_loss = tl.cost.sigmoid_cross_entropy(fake_z_logits, tf.zeros_like(fake_z_logits)) + \
                tl.cost.sigmoid_cross_entropy(real_z_logits, tf.ones_like(real_z_logits))
        # Updating Encoder
        grad = tape.gradient(e_loss, E.trainable_weights)
        e_optimizer.apply_gradients(zip(grad, E.trainable_weights))
        # Updating Generator
        grad = tape.gradient(g_loss, G.trainable_weights)
        g_optimizer.apply_gradients(zip(grad, G.trainable_weights))
        # Updating Discriminator
        grad = tape.gradient(d_loss, D.trainable_weights)
        d_optimizer.apply_gradients(zip(grad, D.trainable_weights))
        # Updating D_z & D_h
        grad = tape.gradient(dz_loss, D_z.trainable_weights)
        dz_optimizer.apply_gradients(zip(grad, D_z.trainable_weights))
        # basic
        if np.mod(step, flags.show_freq) == 0 and step != 0:
            print(
                "Epoch: [{}/{}] [{}/{}] e_loss: {:.5f}, g_loss: {:.5f}, d_loss: {:.5f}, "
                "dz_loss: {:.5f}, recon: {:.5f}".format(
                    epoch_num, n_epoch, step, n_step_epoch, e_loss, g_loss, d_loss, dz_loss, recon_loss))
            # Kstest
            p_min, p_avg = KStest(z, fake_z)
            print("kstest: min:{}, avg:{}".format(p_min, p_avg))
        if np.mod(step, n_step_epoch) == 0 and step != 0:
            G.save_weights('{}/G.npz'.format(flags.checkpoint_dir), format='npz')
            D.save_weights('{}/D.npz'.format(flags.checkpoint_dir), format='npz')
            E.save_weights('{}/E.npz'.format(flags.checkpoint_dir), format='npz')
            D_z.save_weights('{}/Dz.npz'.format(flags.checkpoint_dir), format='npz')
            # G.train()
        if np.mod(step, flags.eval_step) == 0 and step != 0:
            z = np.random.normal(loc=0.0, scale=1,
                                 size=[flags.batch_size_train, flags.z_dim]).astype(np.float32)
            G.eval()
            recon_imgs = G(E(eval_batch))
            result = G(z)
            G.train()
            tl.visualize.save_images(
                result.numpy(), [8, 8],
                '{}/sample_{:02d}_{:04d}.png'.format(flags.sample_dir, step // n_step_epoch, step))
            tl.visualize.save_images(
                recon_imgs.numpy(), [8, 8],
                '{}/recon_{:02d}_{:04d}.png'.format(flags.sample_dir, step // n_step_epoch, step))
        del tape
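# KStest(z, fake_z) is called in the logging code above (and in train() below) but is not
# defined in this snippet. The following is only a minimal sketch of what such a helper
# could look like: a per-dimension two-sample Kolmogorov-Smirnov test between the sampled
# prior z and the encoded codes fake_z, returning the minimum and the average p-value.
# The name, signature and return convention are assumptions made to match the call sites,
# not the original implementation.
def KStest(z, fake_z):
    from scipy import stats  # assumed dependency
    z = np.asarray(z)
    fake_z = np.asarray(fake_z)
    p_values = []
    for dim in range(z.shape[1]):
        # compare the marginal distribution of each latent dimension
        _, p = stats.ks_2samp(z[:, dim], fake_z[:, dim])
        p_values.append(p)
    return np.min(p_values), np.mean(p_values)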
def train_C():
    dataset, len_dataset = get_dataset_train()
    len_dataset = flags.len_dataset
    # G.load_weights('./checkpoint/{}/G.npz'.format(flags.param_dir))
    # E.load_weights('./checkpoint/{}/E.npz'.format(flags.param_dir))
    G.eval()
    E.eval()
    C.train()
    c_optimizer = tf.optimizers.Adam(flags.lr_C, beta_1=flags.beta1, beta_2=flags.beta2)
    n_step_epoch = int(len_dataset // flags.batch_size_train)
    n_epoch = int(flags.step_num // n_step_epoch)
    acc_sum = 0
    acc_step = 0
    for step, image_labels in enumerate(dataset):
        '''
        log = " ** new learning rate: %f (for GAN)" % (lr_v.tolist()[0])
        print(log)
        '''
        batch_imgs = image_labels[0]
        batch_labels = image_labels[1]
        batch_labels = (batch_labels + 1) / 2
        batch_labels = tf.reshape(batch_labels, [flags.batch_size_train, 1])
        batch_labels = tf.cast(batch_labels, tf.float32)
        epoch_num = step // n_step_epoch
        with tf.GradientTape(persistent=True) as tape:
            real_z = E(batch_imgs)
            z_logits = C(real_z)
            z_logits = tf.cast(z_logits, tf.float32)
            loss_c = tl.cost.sigmoid_cross_entropy(z_logits, batch_labels)
            logits = ((tf.sign(z_logits * 2 - 1) + 1) / 2)
            labels = batch_labels
            # Count predictions that match the labels. For binary values the number of
            # matches is N + 2*sum(pred*label) - sum(pred) - sum(label); the factor 2 is
            # needed so that (pred=1, label=1) pairs are counted as correct as well.
            acc_num = flags.batch_size_train + 2 * tf.reduce_sum(logits * labels) - tf.reduce_sum(logits) \
                - tf.reduce_sum(labels)
            batch_acc = acc_num / flags.batch_size_train
            acc_sum += batch_acc
            acc_step += 1
        # Updating Classifier
        grad = tape.gradient(loss_c, C.trainable_weights)
        c_optimizer.apply_gradients(zip(grad, C.trainable_weights))
        del tape
        # basic
        if np.mod(step, flags.show_freq) == 0 and step != 0:
            print("Epoch: [{}/{}] [{}/{}] batch_acc is {}, loss_c: {:.5f}".
                  format(epoch_num, n_epoch, step, n_step_epoch, batch_acc, loss_c))
        if np.mod(step, n_step_epoch) == 0 and step != 0:
            # save the classifier under its own name instead of overwriting G.npz
            C.save_weights('{}/C.npz'.format(flags.checkpoint_dir), format='npz')
        if np.mod(step, flags.acc_step) == 0 and step != 0:
            acc = acc_sum / acc_step
            print("The avg step_acc is {:.3f} in {} step".format(acc, acc_step))
            acc_sum = 0
            acc_step = 0
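# The accuracy count in train_C relies on the identity that, for binary predictions p and
# labels l in {0, 1}, the number of matching entries equals N + 2*sum(p*l) - sum(p) - sum(l),
# since the per-element match indicator is p*l + (1-p)*(1-l). A small self-contained numpy
# check of this identity (example data only, not part of the training code):
def _check_binary_accuracy_identity():
    preds = np.array([1, 0, 1, 1, 0, 0], dtype=np.float32)
    labels = np.array([1, 0, 0, 1, 1, 0], dtype=np.float32)
    n = len(preds)
    matches_direct = np.sum(preds == labels)  # 4 matching positions
    matches_formula = n + 2 * np.sum(preds * labels) - np.sum(preds) - np.sum(labels)
    assert matches_direct == matches_formula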
def train(con=False):
    dataset, len_dataset = get_dataset_train()
    len_dataset = flags.len_dataset
    G = get_G([None, flags.z_dim])
    D = get_img_D([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
    E = get_E([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
    D_z = get_z_D([None, flags.z_dim])
    if con:
        G.load_weights('./checkpoint/G.npz')
        D.load_weights('./checkpoint/D.npz')
        E.load_weights('./checkpoint/E.npz')
        D_z.load_weights('./checkpoint/D_z.npz')
    G.train()
    D.train()
    E.train()
    D_z.train()
    n_step_epoch = int(len_dataset // flags.batch_size_train)
    n_epoch = flags.n_epoch
    # lr_G = flags.lr_G * flags.initial_scale
    # lr_E = flags.lr_E * flags.initial_scale
    # lr_D = flags.lr_D * flags.initial_scale
    # lr_Dz = flags.lr_Dz * flags.initial_scale
    lr_G = flags.lr_G
    lr_E = flags.lr_E
    lr_D = flags.lr_D
    lr_Dz = flags.lr_Dz
    # total_step = n_epoch * n_step_epoch
    # lr_decay_G = flags.lr_G * (flags.ending_scale - flags.initial_scale) / total_step
    # lr_decay_E = flags.lr_G * (flags.ending_scale - flags.initial_scale) / total_step
    # lr_decay_D = flags.lr_G * (flags.ending_scale - flags.initial_scale) / total_step
    # lr_decay_Dz = flags.lr_G * (flags.ending_scale - flags.initial_scale) / total_step
    d_optimizer = tf.optimizers.Adam(lr_D, beta_1=flags.beta1, beta_2=flags.beta2)
    g_optimizer = tf.optimizers.Adam(lr_G, beta_1=flags.beta1, beta_2=flags.beta2)
    e_optimizer = tf.optimizers.Adam(lr_E, beta_1=flags.beta1, beta_2=flags.beta2)
    dz_optimizer = tf.optimizers.Adam(lr_Dz, beta_1=flags.beta1, beta_2=flags.beta2)
    curr_lambda = flags.lambda_recon
    for step, batch_imgs_labels in enumerate(dataset):
        '''
        log = " ** new learning rate: %f (for GAN)" % (lr_v.tolist()[0])
        print(log)
        '''
        batch_imgs = batch_imgs_labels[0]
        # print("batch_imgs shape:")
        # print(batch_imgs.shape)  # (64, 64, 64, 3)
        batch_labels = batch_imgs_labels[1]
        # print("batch_labels shape:")
        # print(batch_labels.shape)  # (64,)
        epoch_num = step // n_step_epoch
        # for i in range(flags.batch_size_train):
        #     tl.visualize.save_image(batch_imgs[i].numpy(), 'train_{:02d}.png'.format(i))
        # # Updating recon lambda
        # if epoch_num <= 5:  # 50 --> 25
        #     curr_lambda -= 5
        # elif epoch_num <= 40:  # stay at 25
        #     curr_lambda = 25
        # else:  # 25 --> 10
        #     curr_lambda -= 0.25
        with tf.GradientTape(persistent=True) as tape:
            z = flags.scale * np.random.normal(
                loc=0.0,
                scale=flags.sigma * math.sqrt(flags.z_dim),
                size=[flags.batch_size_train, flags.z_dim]).astype(np.float32)
            z += flags.scale * np.random.binomial(
                n=1, p=0.5,
                size=[flags.batch_size_train, flags.z_dim]).astype(np.float32)
            fake_z = E(batch_imgs)
            fake_imgs = G(fake_z)
            fake_logits = D(fake_imgs)
            real_logits = D(batch_imgs)
            fake_logits_z = D(G(z))
            real_z_logits = D_z(z)
            fake_z_logits = D_z(fake_z)
            e_loss_z = - tl.cost.sigmoid_cross_entropy(fake_z_logits, tf.zeros_like(fake_z_logits)) + \
                tl.cost.sigmoid_cross_entropy(fake_z_logits, tf.ones_like(fake_z_logits))
            recon_loss = curr_lambda * tl.cost.absolute_difference_error(batch_imgs, fake_imgs)
            g_loss_x = - tl.cost.sigmoid_cross_entropy(fake_logits, tf.zeros_like(fake_logits)) + \
                tl.cost.sigmoid_cross_entropy(fake_logits, tf.ones_like(fake_logits))
            g_loss_z = - tl.cost.sigmoid_cross_entropy(fake_logits_z, tf.zeros_like(fake_logits_z)) + \
                tl.cost.sigmoid_cross_entropy(fake_logits_z, tf.ones_like(fake_logits_z))
            e_loss = recon_loss + e_loss_z
            g_loss = recon_loss + g_loss_x + g_loss_z
            d_loss = tl.cost.sigmoid_cross_entropy(real_logits, tf.ones_like(real_logits)) + \
                tl.cost.sigmoid_cross_entropy(fake_logits, tf.zeros_like(fake_logits)) + \
                tl.cost.sigmoid_cross_entropy(fake_logits_z, tf.zeros_like(fake_logits_z))
            dz_loss = tl.cost.sigmoid_cross_entropy(fake_z_logits, tf.zeros_like(fake_z_logits)) + \
                tl.cost.sigmoid_cross_entropy(real_z_logits, tf.ones_like(real_z_logits))
        # Updating Encoder
        grad = tape.gradient(e_loss, E.trainable_weights)
        e_optimizer.apply_gradients(zip(grad, E.trainable_weights))
        # Updating Generator
        grad = tape.gradient(g_loss, G.trainable_weights)
        g_optimizer.apply_gradients(zip(grad, G.trainable_weights))
        # Updating Discriminator
        grad = tape.gradient(d_loss, D.trainable_weights)
        d_optimizer.apply_gradients(zip(grad, D.trainable_weights))
        # Updating D_z & D_h
        grad = tape.gradient(dz_loss, D_z.trainable_weights)
        dz_optimizer.apply_gradients(zip(grad, D_z.trainable_weights))
        # # Updating lr
        # lr_G -= lr_decay_G
        # lr_E -= lr_decay_E
        # lr_D -= lr_decay_D
        # lr_Dz -= lr_decay_Dz
        # show current state
        if np.mod(step, flags.show_every_step) == 0:
            with open("log.txt", "a+") as f:
                p_min, p_avg = KStest(z, fake_z)
                sys.stdout = f  # redirect stdout to the log file
                print(
                    "Epoch: [{}/{}] [{}/{}] curr_lambda: {:.5f}, recon_loss: {:.5f}, g_loss: {:.5f}, d_loss: {:.5f}, "
                    "e_loss: {:.5f}, dz_loss: {:.5f}, g_loss_x: {:.5f}, g_loss_z: {:.5f}, e_loss_z: {:.5f}"
                    .format(epoch_num, flags.n_epoch, step - (epoch_num * n_step_epoch), n_step_epoch,
                            curr_lambda, recon_loss, g_loss, d_loss, e_loss, dz_loss,
                            g_loss_x, g_loss_z, e_loss_z))
                print("kstest: min:{}, avg:{}".format(p_min, p_avg))
                sys.stdout = temp_out  # redirect stdout back to the console
                print(
                    "Epoch: [{}/{}] [{}/{}] curr_lambda: {:.5f}, recon_loss: {:.5f}, g_loss: {:.5f}, d_loss: {:.5f}, "
                    "e_loss: {:.5f}, dz_loss: {:.5f}, g_loss_x: {:.5f}, g_loss_z: {:.5f}, e_loss_z: {:.5f}"
                    .format(epoch_num, flags.n_epoch, step - (epoch_num * n_step_epoch), n_step_epoch,
                            curr_lambda, recon_loss, g_loss, d_loss, e_loss, dz_loss,
                            g_loss_x, g_loss_z, e_loss_z))
                print("kstest: min:{}, avg:{}".format(p_min, p_avg))
        if np.mod(step, n_step_epoch) == 0 and step != 0:
            G.save_weights('{}/{}/G.npz'.format(flags.checkpoint_dir, flags.param_dir), format='npz')
            D.save_weights('{}/{}/D.npz'.format(flags.checkpoint_dir, flags.param_dir), format='npz')
            E.save_weights('{}/{}/E.npz'.format(flags.checkpoint_dir, flags.param_dir), format='npz')
            D_z.save_weights('{}/{}/Dz.npz'.format(flags.checkpoint_dir, flags.param_dir), format='npz')
            # G.train()
        if np.mod(step, flags.eval_step) == 0:
            z = np.random.normal(loc=0.0, scale=1,
                                 size=[flags.batch_size_train, flags.z_dim]).astype(np.float32)
            G.eval()
            result = G(z)
            G.train()
            tl.visualize.save_images(
                result.numpy(), [8, 8],
                '{}/{}/train_{:02d}_{:04d}.png'.format(flags.sample_dir, flags.param_dir,
                                                       step // n_step_epoch, step))
        del tape
def train():
    dataset, len_dataset = get_dataset_train()
    G = get_dwG([None, flags.z_dim], [None, flags.h_dim])
    D = get_dwD([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
    G.train()
    D.train()
    n_step_epoch = int(len_dataset // flags.batch_size_train)
    lr_v = tf.Variable(flags.max_learning_rate)
    lr_decay = (flags.init_learning_rate - flags.max_learning_rate) / (n_step_epoch * flags.n_epoch)
    d_optimizer = tf.optimizers.Adam(lr_v, beta_1=flags.beta1)
    g_optimizer = tf.optimizers.Adam(lr_v, beta_1=flags.beta1)
    t_hash_loss_total = 10000
    for step, batch_imgs in enumerate(dataset):
        # print(batch_imgs.shape)
        lambda_minEntrpBit = flags.lambda_minEntrpBit
        lambda_Hash = flags.lambda_HashBit
        lambda_L2 = flags.lambda_L2
        # during the first epoch, no L2 & Hash loss
        if step // n_step_epoch == 0:
            lambda_L2 = 0
            lambda_Hash = 0
        # update learning rate
        lr_v.assign(lr_v + lr_decay)
        '''
        log = " ** new learning rate: %f (for GAN)" % (lr_v.tolist()[0])
        print(log)
        '''
        with tf.GradientTape(persistent=True) as tape:
            z = np.random.normal(loc=0.0, scale=1,
                                 size=[flags.batch_size_train, flags.z_dim]).astype(np.float32)
            # z = np.random.uniform(low=0.0, high=1.0, size=[flags.batch_size_train, flags.z_dim]).astype(np.float32)
            # sample binary hash codes in {0, 1}; the exclusive upper bound must be 2,
            # otherwise randint always returns 0
            hash = np.random.randint(low=0, high=2,
                                     size=[flags.batch_size_train, flags.h_dim]).astype(np.float32)
            d_last_hidden_real, d_real_logits, hash_real = D(batch_imgs)
            fake_images = G([z, hash])
            d_last_hidden_fake, d_fake_logits, hash_fake = D(fake_images)
            weights_final = D.get_layer('hash_layer').all_weights[0]
            _, _, hash_aug = D(data_aug(batch_imgs))
            # adv loss
            d_loss_real = tl.cost.sigmoid_cross_entropy(d_real_logits, tf.ones_like(d_real_logits))
            d_loss_fake = tl.cost.sigmoid_cross_entropy(d_fake_logits, tf.zeros_like(d_fake_logits))
            g_loss_fake = tl.cost.sigmoid_cross_entropy(d_fake_logits, tf.ones_like(d_fake_logits))
            hash_l2_loss = tl.cost.mean_squared_error(hash, hash_fake, is_mean=True)
            # hash loss
            if t_hash_loss_total < flags.hash_loss_threshold:
                hash_loss_total = hash_loss(hash_real, hash_aug, weights_final, lambda_minEntrpBit[0])
            else:
                hash_loss_total = hash_loss(hash_real, hash_aug, weights_final, lambda_minEntrpBit[1])
            t_hash_loss_total = hash_loss_total  # Save the new hash loss
            # feature matching loss (for generator)
            feature_matching_loss = tl.cost.mean_squared_error(d_last_hidden_real, d_last_hidden_fake, is_mean=True)
            # loss for discriminator
            d_loss = d_loss_real + d_loss_fake + \
                lambda_L2 * hash_l2_loss + lambda_Hash * hash_loss_total
            g_loss = g_loss_fake
            # g_loss = g_loss_fake + feature_matching_loss
        grad = tape.gradient(d_loss, D.trainable_weights)
        d_optimizer.apply_gradients(zip(grad, D.trainable_weights))
        # grad = tape.gradient(feature_matching_loss, G.trainable_weights)
        grad = tape.gradient(g_loss, G.trainable_weights)
        g_optimizer.apply_gradients(zip(grad, G.trainable_weights))
        print("Epoch: [{}/{}] [{}/{}] L_D: {:.6f}, L_G: {:.6f}, L_Hash: {:.3f}, "
              "L_adv: {:.6f}, L_2: {:.3f}".format(
                  step // n_step_epoch, flags.n_epoch, step, n_step_epoch, d_loss, g_loss,
                  lambda_Hash * hash_loss_total, d_loss_real + d_loss_fake, lambda_L2 * hash_l2_loss))
        if np.mod(step, flags.save_step) == 0:
            G.save_weights('{}/G.npz'.format(flags.checkpoint_dir), format='npz')
            D.save_weights('{}/D.npz'.format(flags.checkpoint_dir), format='npz')
            G.train()
        if np.mod(step, flags.eval_step) == 0:
            z = np.random.normal(loc=0.0, scale=1,
                                 size=[flags.batch_size_train, flags.z_dim]).astype(np.float32)
            # z = np.random.uniform(low=0.0, high=1.0, size=[flags.batch_size_train, flags.z_dim]).astype(np.float32)
            # binary hash codes in {0, 1}
            hash = np.random.randint(low=0, high=2,
                                     size=[flags.batch_size_train, flags.h_dim]).astype(np.float32)
            G.eval()
            result = G([z, hash])
            G.train()
            tl.visualize.save_images(
                result.numpy(), [8, 8],
                '{}/train_{:02d}_{:04d}.png'.format(flags.sample_dir, step // n_step_epoch, step))
        del tape
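# data_aug(batch_imgs) is used above to produce an augmented view of each image for the
# hash consistency term, but it is not defined in this snippet. Below is a minimal sketch
# of what such an augmentation could look like (random horizontal flip plus small
# brightness jitter via tf.image); the transforms actually used originally are unknown,
# so treat this purely as an illustrative assumption.
def data_aug(batch_imgs):
    imgs = tf.image.random_flip_left_right(batch_imgs)
    imgs = tf.image.random_brightness(imgs, max_delta=0.1)
    return imgs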
def Evaluate_mAP():
    ####################### Functions ################
    class Retrival_Obj():
        def __init__(self, hash, label):
            self.label = label
            self.dist = 0
            list1 = [True if hash[i] == 1 else False for i in range(len(hash))]
            # convert bool list to bool array
            self.hash = np.array(list1)

        def __repr__(self):
            return repr((self.hash, self.label, self.dist))

    # calculate the hamming dist between obj1 & obj2
    def hamming(obj1, obj2):
        res = obj1.hash ^ obj2.hash
        ans = 0
        for k in range(len(res)):
            if res[k] == True:
                ans += 1
        obj2.dist = ans

    def take_ele(obj):
        return obj.dist

    # get 'nearest_num' nearest objs from 'image' in 'Gallery'
    def get_nearest(image, Gallery, nearest_num):
        for obj in Gallery:
            hamming(image, obj)
        Gallery.sort(key=take_ele)
        ans = []
        cnt = 0
        for obj in Gallery:
            cnt += 1
            if cnt <= nearest_num:
                ans.append(obj)
            else:
                break
        return ans

    # given retrivial_set, calc AP w.r.t. given label
    def calc_ap(retrivial_set, label):
        total_num = 0
        ac_num = 0
        ans = 0
        result = []
        for obj in retrivial_set:
            total_num += 1
            if obj.label == label:
                ac_num += 1
                ans += ac_num / total_num
                result.append(ac_num / total_num)
        result = np.array(result)
        ans = np.mean(result)
        return ans

    ################ Start eval ##########################
    print('Start Eval!')
    # load images & labels
    ds, _ = get_dataset_train()
    E = get_E([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
    E.load_weights('{}/{}/E.npz'.format(flags.checkpoint_dir, flags.param_dir), format='npz')
    E.eval()
    # create (hash, label) gallery
    Gallery = []
    cnt = 0
    step_time1 = time.time()
    for batch, label in ds:
        cnt += 1
        if cnt % flags.eval_print_freq == 0:
            step_time2 = time.time()
            print("Now {} Imgs done, takes {:.3f} sec".format(cnt, step_time2 - step_time1))
            step_time1 = time.time()
        hash_fake, _ = E(batch)
        hash_fake = hash_fake.numpy()[0]
        hash_fake = ((tf.sign(hash_fake * 2 - 1, name=None) + 1) / 2).numpy()
        label = label.numpy()[0]
        Gallery.append(Retrival_Obj(hash_fake, label))
    print('Hash calc done, start split dataset')
    # sample 'eval_sample' items from Gallery and build the Query set
    random.shuffle(Gallery)
    cnt = 0
    Queryset = []
    G = []
    for obj in Gallery:
        cnt += 1
        if cnt > flags.eval_sample:
            G.append(obj)
        else:
            Queryset.append(obj)
    Gallery = G
    print('split done, start eval')
    # Calculate mAP
    Final_mAP = 0
    step_time1 = time.time()
    for eval_epoch in range(flags.eval_epoch_num):
        result_list = []
        cnt = 0
        for obj in Queryset:
            cnt += 1
            if cnt % flags.retrieval_print_freq == 0:
                step_time2 = time.time()
                print("Now Step {}, {} queries done, takes {:.3f} sec".format(
                    eval_epoch, cnt, step_time2 - step_time1))
                step_time1 = time.time()
            retrivial_set = get_nearest(obj, Gallery, flags.nearest_num)
            result = calc_ap(retrivial_set, obj.label)
            result_list.append(result)
        result_list = np.array(result_list)
        temp_res = np.mean(result_list)
        print("Query_num:{}, Eval_step:{}, Top_k_num:{}, AP:{:.3f}".format(
            flags.eval_sample, eval_epoch, flags.nearest_num, temp_res))
        Final_mAP += temp_res / flags.eval_epoch_num
        print('')
    print("Query_num:{}, Eval_num:{}, Top_k_num:{}, mAP:{:.3f}".format(
        flags.eval_sample, flags.eval_epoch_num, flags.nearest_num, Final_mAP))
    print('')
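# The retrieval loop in Evaluate_mAP computes Hamming distances one gallery object at a
# time in pure Python. The same top-k lookup can be done in a vectorised way by stacking
# all gallery codes into a single 0/1 matrix and using one XOR + sum per query. This is a
# minimal illustrative sketch (names such as gallery_codes and query_code are hypothetical,
# not part of the original code), assuming the codes are numpy arrays of 0/1 values.
def topk_hamming(query_code, gallery_codes, k):
    # gallery_codes: (N, h_dim) array of 0/1; query_code: (h_dim,) array of 0/1
    dists = np.sum(gallery_codes.astype(bool) ^ query_code.astype(bool), axis=1)
    # indices of the k nearest gallery entries by Hamming distance
    return np.argsort(dists, kind='stable')[:k]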