def test_basic(self, sess, test_list, e, step):
    """Save input, ground-truth and generated sample images for one batch.

    Three PNG files are written under ``self.sample_path``; their names
    embed ``e`` and ``step``:
    ``train_<e>_<step>_in.png`` (input images),
    ``train_<e>_<step>_r.png`` (ground-truth images) and
    ``train_<e>_<step>.png`` (output of ``self.x_tilde``).

    Args:
        sess: session used to evaluate ``self.x_tilde``.
        test_list: batch descriptor handed to
            ``CelebA.getTrainImages_lm_embed``.
        e: integer formatted into the file names as ``{:02d}``.
        step: integer formatted into the file names as ``{:04d}``.
    """
    # The loader returns seven values; the labels and the reference
    # landmarks/labels are not needed for sampling.
    inputs, targets, input_lm, target_lm, _, _, _ = \
        CelebA.getTrainImages_lm_embed(test_list)

    # Only the first 64 images are written, with an [8, 8] size argument
    # (presumably an 8x8 tiling -- confirm against save_images).
    grid = [8, 8]
    name_args = (self.sample_path, e, step)

    # input image
    in_path = '{}/train_{:02d}_{:04d}_in.png'.format(*name_args)
    save_images(inputs[0:64], grid, in_path)

    # ground truth
    real_path = '{}/train_{:02d}_{:04d}_r.png'.format(*name_args)
    save_images(targets[0:64], grid, real_path)

    # generated image
    feed = {
        self.images: inputs,
        self.images_lm: input_lm,
        self.isTrain: False,
        self.emotion_images_lm: target_lm,
    }
    generated = sess.run(self.x_tilde, feed_dict=feed)
    out_path = '{}/train_{:02d}_{:04d}.png'.format(*name_args)
    save_images(generated[0:64], grid, out_path)
def test_landmark_interpolation(self, sess, test_list, e, step):
    """Render a landmark-space interpolation for the first image of a batch.

    The first input image of the batch is repeated ``self.batch_size``
    times and paired with landmarks blended linearly between its own
    landmarks and the first emotion landmarks of the batch. Frame 0 uses
    the emotion landmarks only; the last frame uses the input image's own
    landmarks only. The repeated input and the output of ``self.x_tilde``
    are written under ``self.lm_interpolation_path``.

    Args:
        sess: session used to evaluate ``self.x_tilde``.
        test_list: batch descriptor handed to
            ``CelebA.getTrainImages_lm_embed``.
        e: accepted for signature parity with ``test_basic`` but not used;
            the file names always carry ``00`` in the first numeric slot.
        step: integer formatted into the file names as ``{:04d}``.
    """
    ne_imgs, _, ne_lm, emo_lm, _, _, _ = \
        CelebA.getTrainImages_lm_embed(test_list)

    # First test interpolation in landmark space (latent-space
    # interpolation can be tried later); the first image of the batch is
    # used as the base.
    batch_size = self.batch_size
    img_width = int(np.sqrt(batch_size))

    test_img = ne_imgs[0]
    test_img_lm = ne_lm[0]
    test_img_emotion_lm = emo_lm[0]

    # One frame per batch slot so the fed batch, the saved slice and the
    # grid all agree (the frame count used to be hard-coded to 64).
    num_frames = batch_size
    # Float denominator: under Python 2, `i / 63` is integer division and
    # would yield 0 for every frame but the last. max(..., 1) avoids a
    # division by zero when there is a single frame.
    last_index = float(max(num_frames - 1, 1))

    test_lm_interpolation = []
    for i in range(num_frames):
        factor = i / last_index
        test_lm_interpolation.append(
            factor * test_img_lm + (1 - factor) * test_img_emotion_lm)
    test_lm_interpolation = np.asarray(test_lm_interpolation,
                                       dtype=np.float32)

    test_imgs = np.repeat(np.expand_dims(test_img, axis=0),
                          num_frames, axis=0)

    # input img
    save_images(
        test_imgs[0:batch_size], [img_width, img_width],
        '{}/train_{:02d}_{:04d}_in.png'.format(
            self.lm_interpolation_path, 0, step))

    # generate img
    # NOTE(review): unlike test_basic, self.images_lm is not fed here;
    # this presumably works because x_tilde does not depend on it --
    # confirm against the graph definition.
    sample_images = sess.run(self.x_tilde,
                             feed_dict={
                                 self.images: test_imgs,
                                 self.isTrain: False,
                                 self.emotion_images_lm:
                                     test_lm_interpolation
                             })
    save_images(
        sample_images[0:batch_size], [img_width, img_width],
        '{}/train_{:02d}_{:04d}.png'.format(
            self.lm_interpolation_path, 0, step))
def train(self):
    """Build the optimizers and run the training loop.

    Four Adam optimizers (D, G, E, Embed) share one exponentially decayed
    learning rate. Inside a session the variables are initialised, the
    checkpoint at ``self.saved_model_path`` is restored, and batches are
    processed while ``e <= self.max_epoch``. Per batch:

    * D is updated five times, then Embed, E and G once each;
    * the merged summaries are written to ``self.log_dir``;
    * the decay step is advanced only while the learning rate is above
      5e-5;
    * when ``step % 20 == 0`` the current losses are printed;
    * when ``step % 20 == 1`` the test set is reshuffled, sample images
      are written via the ``test_*`` methods and a checkpoint is saved.

    A final checkpoint is saved after the last epoch. The summary writer
    is closed on every exit path so buffered events are flushed.
    """
    global_step = tf.Variable(0, trainable=False)
    add_global = global_step.assign_add(1)
    new_learning_rate = tf.train.exponential_decay(
        self.learn_rate_init,
        global_step=global_step,
        decay_steps=20000,
        decay_rate=0.98)

    def build_train_op(loss, var_list):
        # Every sub-network gets its own Adam optimizer with identical
        # hyper-parameters; gradients are applied unclipped.
        trainer = tf.train.AdamOptimizer(
            learning_rate=new_learning_rate, beta1=0.5)
        gradients = trainer.compute_gradients(loss, var_list=var_list)
        return trainer.apply_gradients(gradients)

    # Creation order (D, G, E, Embed) is kept so that the optimizer
    # variables are created in the same order as before.
    opti_D = build_train_op(self.D_loss, self.d_vars)
    opti_G = build_train_op(self.G_loss, self.g_vars)
    opti_E = build_train_op(self.encode_loss, self.e_vars)
    opti_Embed = build_train_op(self.Embed_loss, self.embed_vars)

    init = tf.global_variables_initializer()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # D updates per G/E/Embed update (labelled "WGAN-GP" in the original).
    critic_iters = 5
    # The decay step stops advancing once the rate drops to this value.
    min_learning_rate = 0.00005

    with tf.Session(config=config) as sess:
        sess.run(init)

        # Resume training from the checkpoint.
        # NOTE(review): this is unconditional, so a checkpoint is expected
        # to exist at self.saved_model_path before train() is called.
        self.saver.restore(sess, self.saved_model_path)

        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(self.log_dir, sess.graph)

        try:
            batch_num = 0
            e = 0
            step = 0

            # Shuffle once before training starts.
            self.ds_train = self.shuffle_train(self.ds_train)

            while e <= self.max_epoch:
                max_iter = len(self.ds_train) / self.batch_size - 1

                while batch_num < max_iter:
                    step = step + 1

                    # Reshuffle when the end of the epoch is reached.
                    if batch_num >= max_iter - 1:
                        self.ds_train = self.shuffle_train(self.ds_train)

                    train_list = CelebA.getNextBatch(
                        self.ds_train, batch_num, self.batch_size)
                    (ne_imgs, emo_imgs, ne_lm, emo_lm, emo_label,
                     emo_lm_ref, emo_label_ref) = \
                        CelebA.getTrainImages_lm_embed(train_list)

                    sample_z = np.random.normal(
                        size=[self.batch_size, self.latent_dim])

                    # D and G are fed exactly the same placeholders.
                    # TODO: revisit self.images_lm later.
                    gan_feed = {
                        self.images: ne_imgs,
                        self.z_p: sample_z,
                        self.images_emotion: emo_imgs,
                        self.images_lm: ne_lm,
                        self.emotion_label: emo_label,
                        self.emotion_images_lm: emo_lm,
                        self.isTrain: True
                    }
                    embed_feed = {
                        self.emotion_label: emo_label,
                        self.emotion_images_lm: emo_lm,
                        self.isTrain: True,
                        self.emotion_label_reference: emo_label_ref,
                        self.emotion_images_lm_reference: emo_lm_ref
                    }
                    # Same as gan_feed but without z_p.
                    encoder_feed = {
                        self.images: ne_imgs,
                        self.images_emotion: emo_imgs,
                        self.images_lm: ne_lm,
                        self.isTrain: True,
                        self.emotion_label: emo_label,
                        self.emotion_images_lm: emo_lm
                    }
                    # Used for both summaries and loss logging
                    # (isTrain is False for evaluation).
                    eval_feed = {
                        self.images: ne_imgs,
                        self.z_p: sample_z,
                        self.images_emotion: emo_imgs,
                        self.images_lm: ne_lm,
                        self.emotion_label: emo_label,
                        self.emotion_images_lm: emo_lm,
                        self.emotion_label_reference: emo_label_ref,
                        self.isTrain: False,
                        self.emotion_images_lm_reference: emo_lm_ref
                    }

                    # optimization D
                    for _ in range(critic_iters):
                        sess.run(opti_D, feed_dict=gan_feed)

                    # optimization Embed
                    sess.run(opti_Embed, feed_dict=embed_feed)
                    # optimization E
                    sess.run(opti_E, feed_dict=encoder_feed)
                    # optimization G
                    sess.run(opti_G, feed_dict=gan_feed)

                    summary_str = sess.run(summary_op, feed_dict=eval_feed)
                    summary_writer.add_summary(summary_str, step)

                    batch_num += 1

                    new_learn_rate = sess.run(new_learning_rate)
                    if new_learn_rate > min_learning_rate:
                        sess.run(add_global)

                    if step % 20 == 0:
                        # NOTE(review): the fetch list below binds
                        # self.D_loss to LL_loss ("identity loss") and
                        # self.LL_loss to kl_loss ("KL"). This looks like
                        # a misalignment, but the intended KL tensor is
                        # not visible here, so it is kept as-is -- confirm
                        # and fix against the model definition.
                        (D_loss, fake_loss, encode_loss, LL_loss, kl_loss,
                         recon_loss, positive_loss, negtive_loss,
                         lm_recon_loss, Embed_loss, real_cls,
                         fake_cls) = sess.run(
                            [
                                self.D_loss, self.G_loss,
                                self.encode_loss, self.D_loss,
                                self.LL_loss, self.recon_loss,
                                self.positive_loss, self.negative_loss,
                                self.lm_recon_loss, self.Embed_loss,
                                self.real_emotion_cls_loss,
                                self.fake_emotion_cls_loss
                            ],
                            feed_dict=eval_feed)
                        print(
                            "EPOCH %d step %d: D: loss = %.7f G: loss=%.7f Encode: loss=%.7f identity loss=%.7f KL=%.7f recon_loss=%.7f "
                            "positive_loss=%.7f negtive_loss=%.7f lm_recon_loss=%.7f Embed_loss==%.7f real_cls=%.7f fake_cls=%.7f"
                            % (e, step, D_loss, fake_loss, encode_loss,
                               LL_loss, kl_loss, recon_loss, positive_loss,
                               negtive_loss, lm_recon_loss, Embed_loss,
                               real_cls, fake_cls))

                    # Periodic sampling and checkpointing. The epoch slot
                    # passed to the test_* methods is always 0.
                    if np.mod(step, 20) == 1:
                        self.ds_test = self.shuffle_train(self.ds_test)
                        test_list = CelebA.getNextBatch(
                            self.ds_test, 0, self.batch_size)

                        self.test_basic(sess, test_list, 0, step)
                        self.test_landmark_interpolation(
                            sess, test_list, 0, step)
                        self.test_expression_transfer(sess, 0, step)

                        self.saver.save(sess, self.saved_model_path)

                e += 1
                batch_num = 0
        finally:
            # Flush pending events and release the event file even when
            # training is interrupted by an exception.
            summary_writer.close()

        save_path = self.saver.save(sess, self.saved_model_path)
        print("Model saved in file: %s" % save_path)