Beispiel #1
0
    def test_basic(self, sess, test_list, e, step):
        """Save input, ground-truth and generated image grids for one batch.

        Loads the batch described by ``test_list`` through
        ``CelebA.getTrainImages_lm_embed``, runs ``self.x_tilde`` with
        ``self.isTrain`` fed as ``False`` and writes three 8x8 grids (the
        first 64 images of each array) under ``self.sample_path``.

        Args:
            sess: session used to evaluate ``self.x_tilde``.
            test_list: batch descriptor handed to
                ``CelebA.getTrainImages_lm_embed``.
            e: integer formatted as the two-digit field of the file names.
            step: integer formatted as the four-digit field of the file names.
        """
        # Only the first four of the seven returned arrays are needed here.
        (source_imgs, target_imgs, source_lm, target_lm,
         _, _, _) = CelebA.getTrainImages_lm_embed(test_list)

        def dump_grid(batch, name_template):
            # Write the first 64 images of `batch` as an 8x8 grid.
            save_images(batch[0:64], [8, 8],
                        name_template.format(self.sample_path, e, step))

        # input image
        dump_grid(source_imgs, '{}/train_{:02d}_{:04d}_in.png')

        # ground truth
        dump_grid(target_imgs, '{}/train_{:02d}_{:04d}_r.png')

        # generated image
        inference_feed = {
            self.images: source_imgs,
            self.images_lm: source_lm,
            self.isTrain: False,
            self.emotion_images_lm: target_lm
        }
        generated = sess.run(self.x_tilde, feed_dict=inference_feed)
        dump_grid(generated, '{}/train_{:02d}_{:04d}.png')
Beispiel #2
0
    def test_landmark_interpolation(self, sess, test_list, e, step):
        """Render one image under 64 linearly interpolated landmark inputs.

        Takes the first sample of the batch loaded from ``test_list``, blends
        its ``ne_lm`` and ``emo_lm`` landmark arrays over 64 evenly spaced
        weights and feeds the blends to ``self.x_tilde`` together with 64
        copies of that sample's image.  The repeated input and the generated
        images are saved as grids under ``self.lm_interpolation_path``.

        Args:
            sess: session used to evaluate ``self.x_tilde``.
            test_list: batch descriptor handed to
                ``CelebA.getTrainImages_lm_embed``.
            e: accepted but not used; the file names always carry ``0`` in the
                two-digit field.
            step: integer formatted as the four-digit field of the file names.
        """
        (source_imgs, _, source_lm, target_lm,
         _, _, _) = CelebA.getTrainImages_lm_embed(test_list)

        # Interpolate in landmark space first (latent-space interpolation can
        # be tried later), using the first image of the batch as the anchor.
        grid_side = int(np.sqrt(self.batch_size))
        tile_shape = [grid_side, grid_side]
        n_saved = self.batch_size

        anchor_img = source_imgs[0]
        anchor_lm = source_lm[0]
        anchor_target_lm = target_lm[0]

        # Weight 0 yields the target landmarks, weight 1 the anchor's own.
        blended_lm = np.asarray(
            [w * anchor_lm + (1 - w) * anchor_target_lm
             for w in (k / 63 for k in range(64))],
            dtype=np.float32)

        anchor_batch = np.expand_dims(anchor_img, axis=0)
        repeated_imgs = np.repeat(anchor_batch, 64, axis=0)

        # input img
        input_path = '{}/train_{:02d}_{:04d}_in.png'.format(
            self.lm_interpolation_path, 0, step)
        save_images(repeated_imgs[0:n_saved], tile_shape, input_path)

        # generated img
        inference_feed = {
            self.images: repeated_imgs,
            self.isTrain: False,
            self.emotion_images_lm: blended_lm
        }
        generated = sess.run(self.x_tilde, feed_dict=inference_feed)

        output_path = '{}/train_{:02d}_{:04d}.png'.format(
            self.lm_interpolation_path, 0, step)
        save_images(generated[0:n_saved], tile_shape, output_path)
Beispiel #3
0
    def train(self):
        """Run the adversarial training loop, resuming from the checkpoint.

        Builds one Adam optimizer (``beta1=0.5``) for each of
        ``D_loss``/``d_vars``, ``G_loss``/``g_vars``, ``encode_loss``/``e_vars``
        and ``Embed_loss``/``embed_vars``, all sharing one exponentially
        decayed learning rate.  It then opens a session, initializes the
        variables, restores ``self.saved_model_path`` and iterates over
        ``self.ds_train`` in batches of ``self.batch_size``.

        Per batch: five D updates followed by one Embed, one E and one G
        update.  Merged summaries are written every step, losses are printed
        every 20 steps, and whenever ``step % 20 == 1`` sample images are
        produced through the ``test_*`` helpers and a checkpoint is saved.
        A final checkpoint is written once the epoch loop finishes.

        The restore is unconditional; there is no fresh-start branch.  The
        summary writer is always closed, including when an exception aborts
        the loop, so buffered events are flushed to disk.
        """
        global_step = tf.Variable(0, trainable=False)
        add_global = global_step.assign_add(1)
        new_learning_rate = tf.train.exponential_decay(self.learn_rate_init,
                                                       global_step=global_step,
                                                       decay_steps=20000,
                                                       decay_rate=0.98)

        def build_train_op(loss, var_list):
            # One unclipped Adam update restricted to `var_list`.
            optimizer = tf.train.AdamOptimizer(
                learning_rate=new_learning_rate, beta1=0.5)
            gradients = optimizer.compute_gradients(loss, var_list=var_list)
            return optimizer.apply_gradients(gradients)

        # Created in the same order as before (D, G, E, Embed) so the graph
        # ops and optimizer slot variables keep their names.
        opti_D = build_train_op(self.D_loss, self.d_vars)
        opti_G = build_train_op(self.G_loss, self.g_vars)
        opti_E = build_train_op(self.encode_loss, self.e_vars)
        opti_Embed = build_train_op(self.Embed_loss, self.embed_vars)

        init = tf.global_variables_initializer()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        # Number of D updates per Embed/E/G update (original note: WGAN-GP).
        critic_updates = 5

        with tf.Session(config=config) as sess:
            sess.run(init)
            # Continue training from the saved checkpoint.
            self.saver.restore(sess, self.saved_model_path)
            summary_op = tf.summary.merge_all()
            summary_writer = tf.summary.FileWriter(self.log_dir, sess.graph)

            try:
                batch_num = 0
                e = 0
                step = 0
                # Shuffle once before the first epoch.
                self.ds_train = self.shuffle_train(self.ds_train)

                # NOTE(review): `<=` makes this run max_epoch + 1 epochs.
                while e <= self.max_epoch:
                    max_iter = len(self.ds_train) / self.batch_size - 1
                    while batch_num < max_iter:
                        step = step + 1

                        # Reshuffle on the last iteration of the epoch; this
                        # happens before that iteration's batch is fetched.
                        if batch_num >= max_iter - 1:
                            self.ds_train = self.shuffle_train(self.ds_train)

                        train_list = CelebA.getNextBatch(
                            self.ds_train, batch_num, self.batch_size)
                        (ne_imgs, emo_imgs, ne_lm, emo_lm, emo_label,
                         emo_lm_ref, emo_label_ref) = \
                            CelebA.getTrainImages_lm_embed(train_list)

                        sample_z = np.random.normal(
                            size=[self.batch_size, self.latent_dim])

                        # Feed dicts are built once per batch; each holds
                        # exactly the entries its run call received before.
                        encoder_feed = {
                            self.images: ne_imgs,
                            self.images_emotion: emo_imgs,
                            self.images_lm: ne_lm,
                            self.emotion_label: emo_label,
                            self.emotion_images_lm: emo_lm,
                            self.isTrain: True,
                        }
                        # D and G take the encoder inputs plus the noise.
                        adversarial_feed = dict(encoder_feed)
                        adversarial_feed[self.z_p] = sample_z
                        embed_feed = {
                            self.emotion_label: emo_label,
                            self.emotion_images_lm: emo_lm,
                            self.emotion_label_reference: emo_label_ref,
                            self.emotion_images_lm_reference: emo_lm_ref,
                            self.isTrain: True,
                        }
                        # Summaries and logged losses run with isTrain=False
                        # and additionally need the reference inputs.
                        eval_feed = dict(adversarial_feed)
                        eval_feed[self.emotion_label_reference] = emo_label_ref
                        eval_feed[self.emotion_images_lm_reference] = \
                            emo_lm_ref
                        eval_feed[self.isTrain] = False

                        # optimization D
                        for _ in range(critic_updates):
                            sess.run(opti_D, feed_dict=adversarial_feed)

                        # optimization Embed
                        sess.run(opti_Embed, feed_dict=embed_feed)
                        # optimization E
                        sess.run(opti_E, feed_dict=encoder_feed)
                        # optimization G
                        sess.run(opti_G, feed_dict=adversarial_feed)

                        summary_str = sess.run(summary_op,
                                               feed_dict=eval_feed)
                        summary_writer.add_summary(summary_str, step)

                        batch_num += 1

                        # Stop advancing the decay counter once the learning
                        # rate has dropped to 5e-5 or below.
                        new_learn_rate = sess.run(new_learning_rate)
                        if new_learn_rate > 0.00005:
                            sess.run(add_global)

                        if step % 20 == 0:
                            # NOTE(review): the fourth fetch is self.D_loss
                            # again, so the value printed as "identity loss"
                            # is the discriminator loss and the one printed
                            # as "KL" is self.LL_loss. Left unchanged because
                            # the intended tensors cannot be confirmed here.
                            (D_loss, fake_loss, encode_loss, LL_loss, kl_loss,
                             recon_loss, positive_loss, negtive_loss,
                             lm_recon_loss, Embed_loss, real_cls,
                             fake_cls) = sess.run(
                                 [
                                     self.D_loss, self.G_loss,
                                     self.encode_loss, self.D_loss,
                                     self.LL_loss, self.recon_loss,
                                     self.positive_loss, self.negative_loss,
                                     self.lm_recon_loss, self.Embed_loss,
                                     self.real_emotion_cls_loss,
                                     self.fake_emotion_cls_loss
                                 ],
                                 feed_dict=eval_feed)
                            print(
                                "EPOCH %d step %d: D: loss = %.7f G: loss=%.7f Encode: loss=%.7f identity loss=%.7f KL=%.7f recon_loss=%.7f "
                                "positive_loss=%.7f negtive_loss=%.7f lm_recon_loss=%.7f Embed_loss==%.7f real_cls=%.7f fake_cls=%.7f"
                                % (e, step, D_loss, fake_loss, encode_loss,
                                   LL_loss, kl_loss, recon_loss,
                                   positive_loss, negtive_loss, lm_recon_loss,
                                   Embed_loss, real_cls, fake_cls))

                        # Periodic sampling and checkpointing. The helpers
                        # are called with 0 in place of the epoch number.
                        if step % 20 == 1:
                            self.ds_test = self.shuffle_train(self.ds_test)
                            test_list = CelebA.getNextBatch(
                                self.ds_test, 0, self.batch_size)

                            self.test_basic(sess, test_list, 0, step)
                            self.test_landmark_interpolation(
                                sess, test_list, 0, step)
                            self.test_expression_transfer(sess, 0, step)
                            self.saver.save(sess, self.saved_model_path)

                    e += 1
                    batch_num = 0
            finally:
                # Flush and release the event file even if training aborts.
                summary_writer.close()

            save_path = self.saver.save(sess, self.saved_model_path)
            print("Model saved in file: %s" % save_path)