Ejemplo n.º 1
0
    def train(self):
        """Run one progressive-growing stage of GAN training.

        Builds Adam optimizers for the discriminator (``self.D_loss`` over
        ``self.d_vars``) and the generator (``self.G_loss`` over
        ``self.g_vars``). Unless ``self.pg`` is 1 or 7, it restores earlier
        weights from ``self.read_model_path``. It then runs steps
        ``0..self.max_iters`` (inclusive). Each step does the following:

        * performs ``n_critic`` discriminator updates and one generator
          update;
        * writes a summary to ``self.log_dir``;
        * sets ``self.alpha_tra`` to ``step / self.max_iters``.

        Every 400 steps the losses are printed, and real/fake sample images
        are written under ``self.sample_path``. Every 4000 steps, and once
        after the loop, a checkpoint is written to ``self.gan_model_path``.
        The default TensorFlow graph is reset before returning.
        """
        # Feeding the current step into this op ramps the fade-in weight
        # self.alpha_tra linearly from 0 to 1 over the stage.
        step_pl = tf.placeholder(tf.float32, shape=None)
        alpha_tra_assign = self.alpha_tra.assign(step_pl / self.max_iters)

        # Each Adam optimizer only updates its own network's variables.
        opti_D = tf.train.AdamOptimizer(learning_rate=self.learning_rate,
                                        beta1=0.0,
                                        beta2=0.99).minimize(
                                            self.D_loss, var_list=self.d_vars)
        opti_G = tf.train.AdamOptimizer(learning_rate=self.learning_rate,
                                        beta1=0.0,
                                        beta2=0.99).minimize(
                                            self.G_loss, var_list=self.g_vars)

        init = tf.global_variables_initializer()
        # Let TensorFlow grow GPU memory on demand instead of reserving it
        # all at session start.
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        # Image-grid layout for saved samples (2 rows). Floor division keeps
        # the column count an int on Python 3, where ``/`` yields a float.
        sample_grid = [2, self.batch_size // 2]
        # Discriminator updates per generator update. (A former
        # `pg == 5 and trans` special case also set this to 1, i.e. a no-op.)
        n_critic = 1

        with tf.Session(config=config) as sess:
            sess.run(init)

            summary_op = tf.summary.merge_all()
            summary_writer = tf.summary.FileWriter(self.log_dir, sess.graph)

            # NOTE(review): stages 1 and 7 train without restoring anything;
            # presumably stage 1 has no earlier checkpoint -- TODO confirm
            # why stage 7 is also excluded.
            if self.pg != 1 and self.pg != 7:
                if self.trans:
                    # Transition stage: per the original author's notes,
                    # r_saver restores the weights trained in the previous
                    # stage and rgb_saver the previous to/from-RGB layers.
                    self.r_saver.restore(sess, self.read_model_path)
                    self.rgb_saver.restore(sess, self.read_model_path)
                else:
                    # Non-transition stage: restore the full saved model.
                    self.saver.restore(sess, self.read_model_path)

            step = 0
            batch_num = 0
            # Inclusive bound: max_iters + 1 steps, so alpha ends at 1.0.
            # (Original note: 16 x 32000 = 512,000 real images are shown to
            # D -- assumes batch_size=16 and max_iters=32000; TODO confirm.)
            while step <= self.max_iters:

                # optimization D
                for _ in range(n_critic):
                    # Latent batch (original note: 512-dim latent vectors).
                    sample_z = np.random.normal(
                        size=[self.batch_size, self.sample_size])
                    # Fetch the next batch of real images and resize it to
                    # this stage's output size (e.g. 4x4 in the first stage).
                    train_list = self.data_In.getNextBatch(
                        batch_num, self.batch_size)
                    realbatch_array = CelebA.getShapeForData(
                        train_list, resize_w=self.output_size)

                    if self.trans and self.pg != 0:
                        # Fade-in: blend each real image with a 2x down- then
                        # up-sampled (lower-resolution) copy, weighting the
                        # full-resolution one by alpha. The builtin float()
                        # replaces np.float, which NumPy 1.24 removed.
                        alpha = float(step) / self.max_iters
                        # The zoom factors rescale axes 1 and 2 only;
                        # presumably the batch is (N, H, W, C) -- TODO confirm.
                        low_realbatch_array = scipy.ndimage.zoom(
                            realbatch_array, zoom=[1, 0.5, 0.5, 1])
                        low_realbatch_array = scipy.ndimage.zoom(
                            low_realbatch_array, zoom=[1, 2, 2, 1])
                        realbatch_array = alpha * realbatch_array + (
                            1 - alpha) * low_realbatch_array

                    sess.run(opti_D,
                             feed_dict={
                                 self.images: realbatch_array,
                                 self.z: sample_z
                             })
                    batch_num += 1

                # optimization G (reuses the last discriminator latents)
                sess.run(opti_G, feed_dict={self.z: sample_z})

                summary_str = sess.run(summary_op,
                                       feed_dict={
                                           self.images: realbatch_array,
                                           self.z: sample_z
                                       })
                summary_writer.add_summary(summary_str, step)
                # Advance the fade-in weight used inside the graph.
                sess.run(alpha_tra_assign, feed_dict={step_pl: step})

                # Every 400 steps: print the losses and save sample images.
                if step % 400 == 0:

                    D_loss, G_loss, D_origin_loss, alpha_tra = sess.run(
                        [
                            self.D_loss, self.G_loss, self.D_origin_loss,
                            self.alpha_tra
                        ],
                        feed_dict={
                            self.images: realbatch_array,
                            self.z: sample_z
                        })
                    print(
                        "PG %d, step %d: D loss=%.7f G loss=%.7f, D_or loss=%.7f, opt_alpha_tra=%.7f"
                        % (self.pg, step, D_loss, G_loss, D_origin_loss,
                           alpha_tra))

                    realbatch_array = np.clip(realbatch_array, -1, 1)
                    save_images(
                        realbatch_array[0:self.batch_size],
                        sample_grid,
                        '{}/{:02d}_real.png'.format(self.sample_path, step))

                    # During a transition, also save the lower-resolution
                    # copies that were blended into the real batch.
                    if self.trans and self.pg != 0:

                        low_realbatch_array = np.clip(low_realbatch_array, -1,
                                                      1)

                        save_images(
                            low_realbatch_array[0:self.batch_size],
                            sample_grid,
                            '{}/{:02d}_real_lower.png'.format(
                                self.sample_path, step))

                    fake_image = sess.run(self.fake_images,
                                          feed_dict={
                                              self.images: realbatch_array,
                                              self.z: sample_z
                                          })
                    fake_image = np.clip(fake_image, -1, 1)
                    save_images(
                        fake_image[0:self.batch_size],
                        sample_grid,
                        '{}/{:02d}_train.png'.format(self.sample_path, step))

                # Every 4000 steps (after the first), checkpoint the weights.
                if step != 0 and step % 4000 == 0:
                    self.saver.save(sess, self.gan_model_path)
                step += 1

            # Final checkpoint once the stage is finished.
            save_path = self.saver.save(sess, self.gan_model_path)
            print("Model saved in file: %s" % save_path)
            # Flush pending summary events and release the event file.
            summary_writer.close()

        tf.reset_default_graph()
Ejemplo n.º 2
0
    def train(self):
        """Run one progressive-growing stage of GAN training.

        Builds Adam optimizers for the discriminator (``self.D_loss`` over
        ``self.d_vars``) and the generator (``self.G_loss`` over
        ``self.g_vars``). Unless ``self.pg`` is 1 or 7, it restores earlier
        weights from ``self.read_model_path``. It then runs steps
        ``0..self.max_iters`` (inclusive). Each step does the following:

        * performs ``n_critic`` discriminator updates and one generator
          update;
        * writes a summary to ``self.log_dir``;
        * sets ``self.alpha_tra`` to ``step / self.max_iters``.

        Every 1000 steps the losses are printed, and real/fake sample images
        are written under ``self.sample_path``. Every 4000 steps, and once
        after the loop, a checkpoint is written to ``self.gan_model_path``.
        The default TensorFlow graph is reset before returning.
        """
        # Feeding the current step into this op ramps the fade-in weight
        # self.alpha_tra linearly from 0 to 1 over the stage.
        step_pl = tf.placeholder(tf.float32, shape=None)
        alpha_tra_assign = self.alpha_tra.assign(step_pl / self.max_iters)

        # Each Adam optimizer only updates its own network's variables.
        opti_D = tf.train.AdamOptimizer(learning_rate=self.learning_rate,
                                        beta1=0.0,
                                        beta2=0.99).minimize(
                                            self.D_loss, var_list=self.d_vars)
        opti_G = tf.train.AdamOptimizer(learning_rate=self.learning_rate,
                                        beta1=0.0,
                                        beta2=0.99).minimize(
                                            self.G_loss, var_list=self.g_vars)

        init = tf.global_variables_initializer()
        # Let TensorFlow grow GPU memory on demand instead of reserving it
        # all at session start.
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        # Image-grid layout for saved samples (2 rows). Floor division keeps
        # the column count an int on Python 3, where ``/`` yields a float.
        sample_grid = [2, self.batch_size // 2]
        # Discriminator updates per generator update. (A former
        # `pg == 5 and trans` special case also set this to 1, i.e. a no-op.)
        n_critic = 1

        with tf.Session(config=config) as sess:

            sess.run(init)
            summary_op = tf.summary.merge_all()
            summary_writer = tf.summary.FileWriter(self.log_dir, sess.graph)

            # NOTE(review): stages 1 and 7 train without restoring anything;
            # presumably stage 1 has no earlier checkpoint -- TODO confirm
            # why stage 7 is also excluded.
            if self.pg != 1 and self.pg != 7:
                if self.trans:
                    # Transition stage: restore through r_saver and
                    # rgb_saver (presumably the previous stage's shared
                    # weights and RGB layers -- TODO confirm).
                    self.r_saver.restore(sess, self.read_model_path)
                    self.rgb_saver.restore(sess, self.read_model_path)
                else:
                    # Non-transition stage: restore the full saved model.
                    self.saver.restore(sess, self.read_model_path)

            step = 0
            batch_num = 0
            # Inclusive bound: max_iters + 1 steps, so alpha ends at 1.0.
            while step <= self.max_iters:

                # optimization D
                for _ in range(n_critic):

                    sample_z = np.random.normal(
                        size=[self.batch_size, self.sample_size])
                    # Fetch the next batch of real images and resize it to
                    # this stage's output size.
                    train_list = self.data_In.getNextBatch(
                        batch_num, self.batch_size)
                    realbatch_array = CelebA.getShapeForData(
                        train_list, resize_w=self.output_size)

                    if self.trans and self.pg != 0:
                        # Fade-in: blend each real image with a 2x down- then
                        # up-sampled (lower-resolution) copy, weighting the
                        # full-resolution one by alpha. The builtin float()
                        # replaces np.float, which NumPy 1.24 removed.
                        alpha = float(step) / self.max_iters
                        # The zoom factors rescale axes 1 and 2 only;
                        # presumably the batch is (N, H, W, C) -- TODO confirm.
                        low_realbatch_array = scipy.ndimage.zoom(
                            realbatch_array, zoom=[1, 0.5, 0.5, 1])
                        low_realbatch_array = scipy.ndimage.zoom(
                            low_realbatch_array, zoom=[1, 2, 2, 1])
                        realbatch_array = alpha * realbatch_array + (
                            1 - alpha) * low_realbatch_array

                    sess.run(opti_D,
                             feed_dict={
                                 self.images: realbatch_array,
                                 self.z: sample_z
                             })
                    batch_num += 1

                # optimization G (reuses the last discriminator latents)
                sess.run(opti_G, feed_dict={self.z: sample_z})

                summary_str = sess.run(summary_op,
                                       feed_dict={
                                           self.images: realbatch_array,
                                           self.z: sample_z
                                       })
                summary_writer.add_summary(summary_str, step)
                # Advance the fade-in weight used inside the graph.
                sess.run(alpha_tra_assign, feed_dict={step_pl: step})

                # Every 1000 steps: print the losses and save sample images.
                if step % 1000 == 0:

                    D_loss, G_loss, D_origin_loss, alpha_tra = sess.run(
                        [
                            self.D_loss, self.G_loss, self.D_origin_loss,
                            self.alpha_tra
                        ],
                        feed_dict={
                            self.images: realbatch_array,
                            self.z: sample_z
                        })
                    print(
                        "PG %d, step %d: D loss=%.7f G loss=%.7f, D_or loss=%.7f, opt_alpha_tra=%.7f"
                        % (self.pg, step, D_loss, G_loss, D_origin_loss,
                           alpha_tra))

                    realbatch_array = np.clip(realbatch_array, -1, 1)
                    save_images(
                        realbatch_array[0:self.batch_size],
                        sample_grid,
                        '{}/{:02d}_real.png'.format(self.sample_path, step))

                    # During a transition, also save the lower-resolution
                    # copies that were blended into the real batch.
                    if self.trans and self.pg != 0:

                        low_realbatch_array = np.clip(low_realbatch_array, -1,
                                                      1)

                        save_images(
                            low_realbatch_array[0:self.batch_size],
                            sample_grid,
                            '{}/{:02d}_real_lower.png'.format(
                                self.sample_path, step))

                    fake_image = sess.run(self.fake_images,
                                          feed_dict={
                                              self.images: realbatch_array,
                                              self.z: sample_z
                                          })
                    fake_image = np.clip(fake_image, -1, 1)
                    save_images(
                        fake_image[0:self.batch_size],
                        sample_grid,
                        '{}/{:02d}_train.png'.format(self.sample_path, step))

                # Every 4000 steps (after the first), checkpoint the weights.
                if step != 0 and step % 4000 == 0:
                    self.saver.save(sess, self.gan_model_path)
                step += 1

            # Final checkpoint once the stage is finished.
            save_path = self.saver.save(sess, self.gan_model_path)
            print("Model saved in file: %s" % save_path)
            # Flush pending summary events and release the event file.
            summary_writer.close()

        tf.reset_default_graph()