def __init__(self, sess, args):
    """Configure the AnimeGANv2 model from parsed command-line options.

    Copies every hyper-parameter from ``args`` onto the instance and creates
    the sample directory. It also creates the input placeholders, the three
    image generators (real photos, anime style, smoothed anime style) and the
    ``Vgg19`` wrapper. Finally it prints a human-readable summary of the
    configuration.

    Args:
        sess: session object, stored as ``self.sess``; it is not used
            directly in this method (presumably a ``tf.Session``).
        args: parsed options object; all attributes read below must exist on
            it.
    """
    self.model_name = 'AnimeGANv2'
    self.sess = sess

    # Directories and dataset selection.
    self.checkpoint_dir = args.checkpoint_dir
    self.result_dir = args.result_dir
    self.log_dir = args.log_dir
    self.dataset_name = args.dataset
    self.data_mean = args.data_mean

    # Training schedule and learning rates.
    self.light = args.light
    self.epoch = args.epoch
    self.init_epoch = args.init_epoch  # args.epoch // 20
    self.gan_type = args.gan_type
    self.batch_size = args.batch_size
    self.save_freq = args.save_freq
    self.init_lr = args.init_lr
    self.d_lr = args.d_lr
    self.g_lr = args.g_lr

    # Loss weights.
    self.g_adv_weight = args.g_adv_weight
    self.d_adv_weight = args.d_adv_weight
    self.con_weight = args.con_weight
    self.sty_weight = args.sty_weight
    self.color_weight = args.color_weight
    self.tv_weight = args.tv_weight

    self.training_rate = args.training_rate
    self.ld = args.ld

    self.img_size = args.img_size
    self.img_ch = args.img_ch

    # Discriminator configuration.
    self.n_dis = args.n_dis
    self.ch = args.ch
    self.sn = args.sn

    # ``self.model_dir`` is defined outside this method; it is read only
    # after all the attributes above have been set.
    self.sample_dir = os.path.join(args.sample_dir, self.model_dir)
    check_folder(self.sample_dir)

    # Fixed-size batch shape [batch, img_size[0], img_size[1], channels]
    # shared by all training placeholders.
    batch_shape = [self.batch_size, self.img_size[0], self.img_size[1], self.img_ch]
    self.real = tf.placeholder(tf.float32, batch_shape, name='real_A')
    self.anime = tf.placeholder(tf.float32, batch_shape, name='anime_A')
    self.anime_smooth = tf.placeholder(tf.float32, batch_shape, name='anime_smooth_A')
    # The inference input accepts a single image of any height/width.
    self.test_real = tf.placeholder(tf.float32, [1, None, None, self.img_ch], name='test_input')
    self.anime_gray = tf.placeholder(tf.float32, batch_shape, name='anime_B')

    style_root = './dataset/' + self.dataset_name
    self.real_image_generator = ImageGenerator(
        './dataset/train_photo', self.img_size, self.batch_size, self.data_mean)
    self.anime_image_generator = ImageGenerator(
        style_root + '/style', self.img_size, self.batch_size, self.data_mean)
    self.anime_smooth_generator = ImageGenerator(
        style_root + '/smooth', self.img_size, self.batch_size, self.data_mean)

    # One "epoch" spans the larger of the photo and anime-style datasets.
    self.dataset_num = max(
        self.real_image_generator.num_images,
        self.anime_image_generator.num_images,
    )

    self.vgg = Vgg19()

    # Each row is unpacked into print(), so the default ' ' separator is
    # kept exactly as in separate print calls.
    summary_rows = [
        ("# gan type : ", self.gan_type),
        ("# light : ", self.light),
        ("# dataset : ", self.dataset_name),
        ("# max dataset number : ", self.dataset_num),
        ("# batch_size : ", self.batch_size),
        ("# epoch : ", self.epoch),
        ("# init_epoch : ", self.init_epoch),
        ("# training image size [H, W] : ", self.img_size),
        ("# g_adv_weight,d_adv_weight,con_weight,sty_weight,color_weight,tv_weight : ",
         self.g_adv_weight, self.d_adv_weight, self.con_weight,
         self.sty_weight, self.color_weight, self.tv_weight),
        ("# init_lr,g_lr,d_lr : ", self.init_lr, self.g_lr, self.d_lr),
    ]
    print()
    print("##### Information #####")
    for row in summary_rows:
        print(*row)
    print(f"# training_rate G -- D: {self.training_rate} : 1")
    print()
def build_model(self):
    """Build the generator/discriminator graph, losses, optimizers and summaries.

    Consumes the input placeholders ``self.real``, ``self.anime``,
    ``self.anime_smooth``, ``self.anime_gray`` and ``self.test_real``. It
    also uses the ``self.vgg`` feature extractor. All of these are created in
    ``__init__``.

    Sets on the instance:
        generated, test_generated: generator outputs for the training batch
            and for the single test image.
        init_loss: weighted content loss used for generator pre-training.
        Generator_loss, Discriminator_loss: full GAN training losses.
        init_optim, G_optim, D_optim: Adam ``minimize`` ops restricted to
            generator / discriminator variables (selected by name substring).
        G_init_loss, G_loss, D_loss, G_gan, G_vgg: scalar summaries.
        V_loss_merge, G_loss_merge, D_loss_merge: merged summary ops.
    """
    # Placeholders and self.vgg are NOT re-created here. Creating ops with
    # the same names again makes TensorFlow rename them ('real_A_1',
    # 'test_input_1', ...). That would leave the tensors under the declared
    # names disconnected from the graph, and would also build a second Vgg19.

    # Generator: the test branch passes reuse=True, presumably so it shares
    # the training generator's variables.
    self.generated = self.generator(self.real)
    self.test_generated = self.generator(self.test_real, reuse=True)

    # Discriminator applied to each input; only the first call creates it.
    anime_logit = self.discriminator(self.anime)
    anime_gray_logit = self.discriminator(self.anime_gray, reuse=True)
    generated_logit = self.discriminator(self.generated, reuse=True)
    smooth_logit = self.discriminator(self.anime_smooth, reuse=True)

    # Gradient penalty only for GAN variants whose name mentions it.
    if any(tag in self.gan_type for tag in ('gp', 'lp', 'dragan')):
        GP = self.gradient_penalty(real=self.anime, fake=self.generated)
    else:
        GP = 0.0

    # Pre-training loss: weighted con_loss between input and generated image.
    init_c_loss = con_loss(self.vgg, self.real, self.generated)
    self.init_loss = self.con_weight * init_c_loss

    # GAN training: every generator term except the adversarial one.
    c_loss, s_loss = con_sty_loss(self.vgg, self.real, self.anime_gray, self.generated)
    tv_loss = self.tv_weight * total_variation_loss(self.generated)
    t_loss = (self.con_weight * c_loss
              + self.sty_weight * s_loss
              + self.color_weight * color_loss(self.real, self.generated)
              + tv_loss)

    g_loss = self.g_adv_weight * generator_loss(self.gan_type, generated_logit)
    # GP is added unweighted, outside the d_adv_weight scaling.
    d_loss = self.d_adv_weight * discriminator_loss(
        self.gan_type, anime_logit, anime_gray_logit, generated_logit, smooth_logit,
        dataset=self.dataset_name) + GP

    self.Generator_loss = t_loss + g_loss
    self.Discriminator_loss = d_loss

    # Optimizers: split trainable variables by scope-name substring.
    t_vars = tf.trainable_variables()
    G_vars = [var for var in t_vars if 'generator' in var.name]
    D_vars = [var for var in t_vars if 'discriminator' in var.name]

    self.init_optim = tf.train.AdamOptimizer(self.init_lr, beta1=0.5, beta2=0.999) \
        .minimize(self.init_loss, var_list=G_vars)
    self.G_optim = tf.train.AdamOptimizer(self.g_lr, beta1=0.5, beta2=0.999) \
        .minimize(self.Generator_loss, var_list=G_vars)
    self.D_optim = tf.train.AdamOptimizer(self.d_lr, beta1=0.5, beta2=0.999) \
        .minimize(self.Discriminator_loss, var_list=D_vars)

    # Summaries ("G_vgg" records t_loss, which also includes color and TV terms).
    self.G_init_loss = tf.summary.scalar("G_init", self.init_loss)
    self.G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss)
    self.D_loss = tf.summary.scalar("Discriminator_loss", self.Discriminator_loss)
    self.G_gan = tf.summary.scalar("G_gan", g_loss)
    self.G_vgg = tf.summary.scalar("G_vgg", t_loss)

    self.V_loss_merge = tf.summary.merge([self.G_init_loss])
    self.G_loss_merge = tf.summary.merge([self.G_loss, self.G_gan, self.G_vgg, self.G_init_loss])
    self.D_loss_merge = tf.summary.merge([self.D_loss])