示例#1
0
    def __init__(self, sess, args):
        """Configure the AnimeGANv2 model from parsed command-line options.

        Copies the hyper-parameters from ``args`` onto ``self``, creates the
        sample directory, the input placeholders, three ``ImageGenerator``
        loaders and a ``Vgg19`` instance, then prints a configuration summary.

        Args:
            sess: session object; this method only stores it on ``self.sess``.
            args: options namespace. ``args.dataset`` is stored as
                ``self.dataset_name``; every other option keeps its own name.

        NOTE(review): ``self.model_dir`` (used to build ``sample_dir``) is not
        defined in this method. It is presumably a property elsewhere in the
        class that reads the attributes assigned above, so those assignments
        must stay ahead of it.
        """
        self.model_name = 'AnimeGANv2'
        self.sess = sess

        # -- paths and data ------------------------------------------------
        self.checkpoint_dir = args.checkpoint_dir
        self.result_dir = args.result_dir
        self.log_dir = args.log_dir
        self.dataset_name = args.dataset
        self.data_mean = args.data_mean

        # -- schedule ------------------------------------------------------
        self.light = args.light
        self.epoch = args.epoch
        self.init_epoch = args.init_epoch  # cf. args.epoch // 20

        self.gan_type = args.gan_type
        self.batch_size = args.batch_size
        self.save_freq = args.save_freq

        # -- learning rates ------------------------------------------------
        self.init_lr = args.init_lr
        self.d_lr = args.d_lr
        self.g_lr = args.g_lr

        # -- loss weights --------------------------------------------------
        self.g_adv_weight = args.g_adv_weight
        self.d_adv_weight = args.d_adv_weight
        self.con_weight = args.con_weight
        self.sty_weight = args.sty_weight
        self.color_weight = args.color_weight
        self.tv_weight = args.tv_weight

        self.training_rate = args.training_rate
        self.ld = args.ld

        # -- image geometry: img_size is printed below as [H, W] -----------
        self.img_size = args.img_size
        self.img_ch = args.img_ch

        # -- discriminator -------------------------------------------------
        self.n_dis = args.n_dis
        self.ch = args.ch
        self.sn = args.sn

        self.sample_dir = os.path.join(args.sample_dir, self.model_dir)
        check_folder(self.sample_dir)

        def _batch_input(name):
            # Fixed-size training batch: [batch_size, H, W, img_ch].
            shape = [self.batch_size, self.img_size[0], self.img_size[1],
                     self.img_ch]
            return tf.placeholder(tf.float32, shape, name=name)

        # Creation order is kept identical so graph op names do not shift.
        self.real = _batch_input('real_A')
        self.anime = _batch_input('anime_A')
        self.anime_smooth = _batch_input('anime_smooth_A')
        # Inference input: one image, spatial size left unspecified.
        self.test_real = tf.placeholder(tf.float32,
                                        [1, None, None, self.img_ch],
                                        name='test_input')
        self.anime_gray = _batch_input('anime_B')

        def _loader(folder):
            return ImageGenerator(folder, self.img_size, self.batch_size,
                                  self.data_mean)

        self.real_image_generator = _loader('./dataset/train_photo')
        self.anime_image_generator = _loader(
            './dataset/' + self.dataset_name + '/style')
        self.anime_smooth_generator = _loader(
            './dataset/' + self.dataset_name + '/smooth')
        # Size of the larger of the photo / style sets (smooth set not counted).
        self.dataset_num = max(self.real_image_generator.num_images,
                               self.anime_image_generator.num_images)

        self.vgg = Vgg19()

        summary_rows = (
            ("# gan type : ", self.gan_type),
            ("# light : ", self.light),
            ("# dataset : ", self.dataset_name),
            ("# max dataset number : ", self.dataset_num),
            ("# batch_size : ", self.batch_size),
            ("# epoch : ", self.epoch),
            ("# init_epoch : ", self.init_epoch),
            ("# training image size [H, W] : ", self.img_size),
            ("# g_adv_weight,d_adv_weight,con_weight,sty_weight,color_weight,tv_weight : ",
             self.g_adv_weight, self.d_adv_weight, self.con_weight,
             self.sty_weight, self.color_weight, self.tv_weight),
            ("# init_lr,g_lr,d_lr : ", self.init_lr, self.g_lr, self.d_lr),
        )
        print()
        print("##### Information #####")
        for row in summary_rows:
            print(*row)
        print(f"# training_rate G -- D: {self.training_rate} : 1")
        print()
示例#2
0
    def build_model(self):
        """Build the AnimeGAN training and inference graph.

        The method builds, in order:

        - the input placeholders;
        - the generator, whose training branch and test branch share weights
          through ``reuse=True``;
        - the discriminator, applied with shared weights to four inputs;
        - the pretraining, generator and discriminator losses;
        - one Adam optimizer for each of them;
        - the ``tf.summary`` ops.

        Everything is stored on ``self``; nothing is returned.

        NOTE(review): placeholders and ``self.vgg`` are created here
        unconditionally. If the constructor also creates them, TensorFlow
        will uniquify these names (e.g. ``real_A_1``). Confirm that only one
        definition is in use, especially if tensors are looked up by name.
        """
        # ---------------------------------------------------------------------
        # Define placeholders
        # ---------------------------------------------------------------------
        # All training inputs share one fixed shape: [batch_size, H, W, img_ch].
        train_shape = [self.batch_size, self.img_size[0],
                       self.img_size[1], self.img_ch]

        self.real = tf.placeholder(tf.float32, train_shape, name='real_A')
        self.anime = tf.placeholder(tf.float32, train_shape, name='anime_A')
        self.anime_smooth = tf.placeholder(tf.float32, train_shape,
                                           name='anime_smooth_A')
        self.anime_gray = tf.placeholder(tf.float32, train_shape,
                                         name='anime_B')

        # Inference input: a single image with unspecified spatial size.
        self.test_real = tf.placeholder(tf.float32,
                                        [1, None, None, self.img_ch],
                                        name='test_input')

        self.vgg = Vgg19()

        # ---------------------------------------------------------------------
        # Create graph
        # ---------------------------------------------------------------------
        # Training and test branches share the generator's weights.
        self.generated = self.generator(self.real)
        self.test_generated = self.generator(self.test_real, reuse=True)

        # A single discriminator, reused for every input it scores.
        anime_logit = self.discriminator(self.anime)
        anime_gray_logit = self.discriminator(self.anime_gray, reuse=True)
        generated_logit = self.discriminator(self.generated, reuse=True)
        smooth_logit = self.discriminator(self.anime_smooth, reuse=True)

        # ---------------------------------------------------------------------
        # Define loss
        # ---------------------------------------------------------------------
        # Gradient penalty only for GAN types whose name contains one of these.
        if any(tag in self.gan_type for tag in ('gp', 'lp', 'dragan')):
            GP = self.gradient_penalty(real=self.anime, fake=self.generated)
        else:
            GP = 0.0

        # Pretraining loss: weighted content loss only.
        init_c_loss = con_loss(self.vgg, self.real, self.generated)
        self.init_loss = self.con_weight * init_c_loss

        # GAN training loss: content / style terms come from con_sty_loss,
        # with the style reference taken from self.anime_gray.
        c_loss, s_loss = con_sty_loss(self.vgg,
                                      self.real,
                                      self.anime_gray,
                                      self.generated)
        tv_loss = self.tv_weight * total_variation_loss(self.generated)
        # Every generator term except the adversarial one.
        t_loss = (self.con_weight * c_loss
                  + self.sty_weight * s_loss
                  + self.color_weight * color_loss(self.real, self.generated)
                  + tv_loss)

        g_loss = self.g_adv_weight * generator_loss(self.gan_type,
                                                    generated_logit)
        # GP is added after (not scaled by) the adversarial weight.
        d_loss = self.d_adv_weight * discriminator_loss(self.gan_type,
                                                        anime_logit,
                                                        anime_gray_logit,
                                                        generated_logit,
                                                        smooth_logit,
                                                        dataset=self.dataset_name) + GP

        self.Generator_loss = t_loss + g_loss
        self.Discriminator_loss = d_loss

        # ---------------------------------------------------------------------
        # Optimizer ops
        # ---------------------------------------------------------------------
        # Variables are split by substring of their name (scope prefix).
        t_vars = tf.trainable_variables()
        G_vars = [var for var in t_vars if 'generator' in var.name]
        D_vars = [var for var in t_vars if 'discriminator' in var.name]
        # Keep this creation order: Adam creates its own variables, and their
        # generated names may matter when restoring existing checkpoints.
        self.init_optim = tf.train.AdamOptimizer(
            self.init_lr, beta1=0.5, beta2=0.999
        ).minimize(self.init_loss, var_list=G_vars)
        self.G_optim = tf.train.AdamOptimizer(
            self.g_lr, beta1=0.5, beta2=0.999
        ).minimize(self.Generator_loss, var_list=G_vars)
        self.D_optim = tf.train.AdamOptimizer(
            self.d_lr, beta1=0.5, beta2=0.999
        ).minimize(self.Discriminator_loss, var_list=D_vars)

        # ---------------------------------------------------------------------
        # Summary ops
        # ---------------------------------------------------------------------
        self.G_init_loss = tf.summary.scalar("G_init", self.init_loss)
        self.G_loss = tf.summary.scalar("Generator_loss",
                                        self.Generator_loss)
        self.D_loss = tf.summary.scalar("Discriminator_loss",
                                        self.Discriminator_loss)
        self.G_gan = tf.summary.scalar("G_gan", g_loss)
        # "G_vgg" logs the full non-adversarial generator loss t_loss.
        self.G_vgg = tf.summary.scalar("G_vgg", t_loss)
        self.V_loss_merge = tf.summary.merge([self.G_init_loss])
        self.G_loss_merge = tf.summary.merge([self.G_loss,
                                              self.G_gan,
                                              self.G_vgg,
                                              self.G_init_loss])
        self.D_loss_merge = tf.summary.merge([self.D_loss])