Esempio n. 1
0
    def test(self, *args, **kwargs):
        """Restore the generator from a checkpoint and save a sample image grid.

        A fresh generator is built from ``kwargs``, only its weights are
        restored from the checkpoint chosen by ``load_checkpoint``, and the
        result of ``self.plot_images`` is written to ``Test.png`` inside the
        model's image folder.

        Args:
            *args: accepted for call compatibility; unused.
            **kwargs: forwarded to ``make_generator`` and ``load_checkpoint``;
                ``kwargs['folder']`` selects the model folder tree.
        """
        self.generator = self.make_generator(**kwargs)
        checkpoint = tf.train.Checkpoint(g_model=self.generator)

        ckpt_file, _ = load_checkpoint(**kwargs)
        print("\nCheckpoint File : {}\n".format(ckpt_file))

        # Only the generator is needed here. expect_partial() silences the
        # warnings about checkpointed values that are not matched (e.g. any
        # discriminator/optimizer state a training checkpoint may hold).
        # NOTE(review): assigning `mapped` adds an extra tracked dependency
        # named "mapped" rather than renaming anything; kept for parity.
        checkpoint.mapped = {"g_model": self.generator}
        checkpoint.restore(ckpt_file).expect_partial()

        _ckpt_dir, images_dir, _logs_dir, _result_file = make_folders_for_model(
            kwargs['folder'])
        output_png = os.path.join(images_dir, "Test.png")

        self.plot_images(output_png)
Esempio n. 2
0
    def train(self, **kwargs):
        """Train the generator and discriminator, checkpointing every epoch.

        Builds fresh generator/discriminator models and Adam optimizers.
        If a checkpoint path is given, it restores models and optimizers from
        that checkpoint. It then runs 50000 epochs starting at
        ``self.initial_epoch``. Each epoch does the following:

        * one generator step and one discriminator step per dataset batch;
        * save a checkpoint whose prefix embeds the mean losses;
        * print the mean losses and append a summary to the result file;
        * every ``interval`` epochs, write a sample image grid.

        Keyword Args:
            interval: number of epochs between sample-image dumps.
            folder: root folder handed to ``make_folders_for_model``.
            ckpt_path: checkpoint to resume from; ``None`` (or absent)
                starts from scratch.
            The full ``kwargs`` are also forwarded to
            ``save_initial_model_info`` and ``load_checkpoint``.
        """
        interval = kwargs["interval"]
        (model_ckpt_path, model_images_path, model_logs_path,
         model_result_file) = make_folders_for_model(kwargs['folder'])

        self.generator = self.make_generator()
        self.discriminator = self.make_discriminator()

        train_dataset = self.get_dataset()
        training_progbar = tf.keras.utils.Progbar(target=self.num_train)

        save_initial_model_info(
            {
                'generator': self.generator,
                'discriminator': self.discriminator
            }, model_logs_path, model_ckpt_path, **kwargs)

        # Use `learning_rate`, not the legacy `lr` alias. Newer Keras optimizers
        # either ignore `lr` (silently training at the default 1e-3) or
        # reject it, so the configured rates would not be applied.
        self.g_opt = tf.keras.optimizers.Adam(learning_rate=self.g_lr,
                                              beta_1=0.5)
        self.d_opt = tf.keras.optimizers.Adam(learning_rate=self.d_lr,
                                              beta_1=0.5)

        self.BC_function = tf.keras.losses.BinaryCrossentropy(from_logits=True)
        self.SCC_function = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True)

        ckpt = tf.train.Checkpoint(g_opt=self.g_opt,
                                   d_opt=self.d_opt,
                                   g_model=self.generator,
                                   d_model=self.discriminator)

        if kwargs.get("ckpt_path") is not None:
            fname, self.initial_epoch = load_checkpoint(**kwargs)
            print("\nCheckpoint File : {}\n".format(fname))
            # NOTE(review): this adds an extra tracked dependency named
            # "mapped" (also written by later saves); kept for compatibility
            # with checkpoints produced by this code.
            ckpt.mapped = {
                "g_opt": self.g_opt,
                "d_opt": self.d_opt,
                "g_model": self.generator,
                "d_model": self.discriminator
            }
            ckpt.restore(fname)

            # Report the learning rates the optimizers are actually using.
            self.g_lr = self.g_opt.get_config()["learning_rate"]
            self.d_lr = self.d_opt.get_config()["learning_rate"]

        count = 0  # epochs since the last sample-image dump
        for epoch in range(self.initial_epoch, self.initial_epoch + 50000):
            count += 1
            d_epoch_loss = []
            d_epoch_aux_loss = []
            g_epoch_loss = []

            start_time = korea_time()

            for real_images, real_labels in train_dataset:
                num_images = K.int_shape(real_labels)[0]
                g_loss = self.train_G(num_images).numpy()
                d_BC_loss, d_SCC_loss = self.train_D(real_images, real_labels)

                d_epoch_loss.append(d_BC_loss.numpy())
                d_epoch_aux_loss.append(d_SCC_loss.numpy())
                g_epoch_loss.append(g_loss)

                training_progbar.add(num_images)

            end_time = korea_time()
            training_progbar.update(0)  # reset the progress bar for the next epoch

            d_mean_loss = np.mean(d_epoch_loss, axis=0)
            d_mean_aux_loss = np.mean(d_epoch_aux_loss, axis=0)
            g_mean_loss = np.mean(g_epoch_loss, axis=0)

            ckpt_prefix = os.path.join(
                model_ckpt_path, "Epoch-{}_G-Loss-{:.6f}_D-Loss-{:.6f}".format(
                    epoch, g_mean_loss, d_mean_loss + d_mean_aux_loss))
            ckpt.save(file_prefix=ckpt_prefix)

            str_ = ("Epoch = [{:5d}]\tG Loss = [{:8.6f}]\t".format(
                epoch, g_mean_loss) +
                    "D Loss = [{:8.6f}]\tD AUX Loss = [{:8.6f}]\n".format(
                        d_mean_loss, d_mean_aux_loss))
            print(str_)

            # Append this epoch's summary to the model result file.
            str_ = "Epoch = [{:5d}] - End Time [ {} ]\n".format(
                epoch, str(end_time.strftime("%Y / %m / %d   %H:%M:%S")))
            str_ += "Elapsed Time = {}\n".format(end_time - start_time)
            str_ += "G Learning Rate = [{:.6f}] - D Learning Rate = [{:.6f}]\n".format(
                self.g_lr, self.d_lr)
            str_ += "G Loss : [{:8.6f}] - D Loss : [{:8.6f}] - D AUX Loss : [{:8.6f}] - Sum : [{:8.6f}]\n".format(
                g_mean_loss, d_mean_loss, d_mean_aux_loss,
                g_mean_loss + d_mean_loss + d_mean_aux_loss)
            str_ += " - " * 15 + "\n\n"

            with open(model_result_file, "a+", encoding='utf-8') as fp:
                fp.write(str_)

            if count == interval:
                fname = os.path.join(model_images_path, "{}.png".format(epoch))
                self.plot_images(fname)

                count = 0
Esempio n. 3
0
File: WGAN.py Progetto: leesc912/GAN
    def train(self, **kwargs):
        """Train the WGAN generator and critic, checkpointing every epoch.

        Builds fresh generator/critic models and Adam optimizers
        (beta_1=0, beta_2=0.9). If a checkpoint path is given, it restores
        models and optimizers from that checkpoint. It then runs 50000 epochs
        starting at ``self.initial_epoch``.

        Within an epoch, batches are buffered until ``self.n_critic *
        self.batch_size`` images have been collected, or the whole training
        set (``self.num_train`` images) has been seen. For each buffered
        group, the code runs one critic step per batch followed by one
        generator step. A short trailing group is still trained. Batches left
        over when the dataset runs out are also trained rather than discarded.

        After each epoch, the method saves a checkpoint whose prefix embeds
        the mean losses and appends a summary to the result file. Every
        ``interval`` epochs it also writes a sample image grid.

        Keyword Args:
            interval: number of epochs between sample-image dumps.
            folder: root folder handed to ``make_folders_for_model``.
            ckpt_path: checkpoint to resume from; ``None`` (or absent)
                starts from scratch.
            The full ``kwargs`` are also forwarded to
            ``save_initial_model_info`` and ``load_checkpoint``.
        """
        interval = kwargs["interval"]
        (model_ckpt_path, model_images_path, model_logs_path,
         model_result_file) = make_folders_for_model(kwargs['folder'])

        self.generator = self.make_generator()
        self.critic = self.make_critic()

        train_dataset = self.get_dataset()
        training_progbar = tf.keras.utils.Progbar(target=self.num_train)

        save_initial_model_info(
            {
                'generator': self.generator,
                'critic': self.critic
            }, model_logs_path, model_ckpt_path, **kwargs)

        # Use `learning_rate`, not the legacy `lr` alias. Newer Keras optimizers
        # either ignore `lr` (silently training at the default 1e-3) or
        # reject it, so the configured rates would not be applied.
        self.g_opt = tf.keras.optimizers.Adam(learning_rate=self.g_lr,
                                              beta_1=0,
                                              beta_2=0.9)
        self.c_opt = tf.keras.optimizers.Adam(learning_rate=self.c_lr,
                                              beta_1=0,
                                              beta_2=0.9)

        ckpt = tf.train.Checkpoint(g_opt=self.g_opt,
                                   c_opt=self.c_opt,
                                   g_model=self.generator,
                                   c_model=self.critic)

        if kwargs.get("ckpt_path") is not None:
            fname, self.initial_epoch = load_checkpoint(**kwargs)
            print("\nCheckpoint File : {}\n".format(fname))
            # NOTE(review): this adds an extra tracked dependency named
            # "mapped" (also written by later saves); kept for compatibility
            # with checkpoints produced by this code.
            ckpt.mapped = {
                "g_opt": self.g_opt,
                "c_opt": self.c_opt,
                "g_model": self.generator,
                "c_model": self.critic
            }
            ckpt.restore(fname)

            # Report the learning rates the optimizers are actually using.
            self.g_lr = self.g_opt.get_config()["learning_rate"]
            self.c_lr = self.c_opt.get_config()["learning_rate"]

        # Images consumed per generator update (e.g. 64 * 5 = 320).
        images_per_step = self.n_critic * self.batch_size

        def train_group(batches, num_images, c_losses, g_losses):
            """Run one critic step per batch, then one generator step; record losses."""
            c_losses.extend([self.train_D(batch).numpy() for batch in batches])
            g_losses.append(self.train_G().numpy())
            training_progbar.add(num_images)

        count = 0  # epochs since the last sample-image dump
        for epoch in range(self.initial_epoch, self.initial_epoch + 50000):
            count += 1
            c_epoch_loss = []
            g_epoch_loss = []

            start_time = korea_time()

            pending = []         # batches buffered for the next critic/generator round
            pending_images = 0   # images in `pending`
            seen_images = 0      # images drawn from the dataset this epoch

            for real_images in train_dataset:
                pending.append(real_images)
                num_images = K.int_shape(real_images)[0]
                pending_images += num_images
                seen_images += num_images

                # `>=` rather than `==`: an exact match is not guaranteed
                # (e.g. a dataset yielding more than num_train images would
                # otherwise never stop).
                epoch_done = seen_images >= self.num_train
                if pending_images >= images_per_step or epoch_done:
                    train_group(pending, pending_images,
                                c_epoch_loss, g_epoch_loss)
                    pending, pending_images = [], 0
                    if epoch_done:
                        break

            # The dataset ended before either trigger fired (e.g. it holds
            # fewer than num_train images): train on the leftover batches
            # instead of silently dropping them.
            if pending:
                train_group(pending, pending_images, c_epoch_loss, g_epoch_loss)

            end_time = korea_time()
            training_progbar.update(0)  # reset the progress bar for the next epoch

            c_mean_loss = np.mean(c_epoch_loss, axis=0)
            g_mean_loss = np.mean(g_epoch_loss, axis=0)

            ckpt_prefix = os.path.join(
                model_ckpt_path, "Epoch-{}_G-Loss-{:.6f}_C-Loss-{:.6f}".format(
                    epoch, g_mean_loss, c_mean_loss))
            ckpt.save(file_prefix=ckpt_prefix)

            print(
                "Epoch = [{:5d}]\tGenerator Loss = [{:8.6f}]\tCritic Loss = [{:8.6f}]\n"
                .format(epoch, g_mean_loss, c_mean_loss))

            # Append this epoch's summary to the model result file.
            str_ = "Epoch = [{:5d}] - End Time [ {} ]\n".format(
                epoch, str(end_time.strftime("%Y / %m / %d   %H:%M:%S")))
            str_ += "Elapsed Time = {}\n".format(end_time - start_time)
            str_ += "Generator Learning Rate = [{:.6f}] - Critic Learning Rate = [{:.6f}]\n".format(
                self.g_lr, self.c_lr)
            str_ += "Generator Loss : [{:8.6f}] - Critic Loss : [{:8.6f}] - Sum : [{:8.6f}]\n".format(
                g_mean_loss, c_mean_loss, g_mean_loss + c_mean_loss)
            str_ += " - " * 15 + "\n\n"

            with open(model_result_file, "a+", encoding='utf-8') as fp:
                fp.write(str_)

            if count == interval:
                fname = os.path.join(model_images_path, "{}.png".format(epoch))
                self.plot_images(fname)

                count = 0