def build_model(self):
    """Build the training or the inference graph, depending on ``self.phase``.

    Train phase: for every resolution from ``self.start_res`` onwards a
    multi-GPU sub-graph is built (one tower per GPU, variables reused for
    every tower after the first). The resulting optimizers, averaged losses,
    concatenated generated images and scalar summaries are stored on ``self``
    in dicts keyed by resolution.

    Any other phase: only ``self.fake_images`` is built, from random noise at
    ``self.img_size`` with ``alpha`` fixed to 0.
    """
    if self.phase == 'train':
        # Per-resolution registries filled in by the loop below.
        self.d_loss_per_res = {}
        self.g_loss_per_res = {}
        self.generator_optim = {}
        self.discriminator_optim = {}
        self.alpha_summary_per_res = {}
        self.d_summary_per_res = {}
        self.g_summary_per_res = {}
        self.train_fake_images = {}

        first_res_idx = self.resolutions.index(self.start_res)
        for res in self.resolutions[first_res_idx:]:
            tower_g_losses = []
            tower_d_losses = []
            tower_fake_images = []

            batch_size = self.batch_sizes.get(res, self.batch_size_base)

            # One step counter per resolution; it is advanced by the
            # generator optimizer further down.
            global_step = tf.get_variable(
                'global_step_{}'.format(res),
                shape=[],
                dtype=tf.float32,
                initializer=tf.initializers.zeros(),
                trainable=False,
                aggregation=tf.VariableAggregation.ONLY_FIRST_TOWER)
            alpha_const, zero_constant = get_alpha_const(
                self.iteration // 2, batch_size * self.gpu_num, global_step)

            # Smooth-transition variable: starts at 1 when this resolution is
            # trained with a transition, otherwise at 0.
            do_train_trans = self.train_with_trans[res]
            alpha = tf.get_variable(
                'alpha_{}'.format(res),
                shape=[],
                dtype=tf.float32,
                initializer=(tf.initializers.ones() if do_train_trans
                             else tf.initializers.zeros()),
                trainable=False,
                aggregation=tf.VariableAggregation.ONLY_FIRST_TOWER)
            alpha_target = alpha_const if do_train_trans else zero_constant
            alpha_assign_op = tf.assign(alpha, alpha_target)

            # Shared input pipeline; each tower reads its own shard of it.
            image_class = ImageData(res)
            main_ds = tf.data.Dataset.from_tensor_slices(self.dataset)
            main_ds = main_ds.apply(shuffle_and_repeat(self.dataset_num))
            main_ds = main_ds.apply(
                map_and_batch(image_class.image_processing,
                              batch_size,
                              num_parallel_batches=32,
                              drop_remainder=True))
            main_ds = main_ds.prefetch(self.gpu_num * 256)

            # Every tower op is created under a dependency on the alpha
            # assignment.
            with tf.control_dependencies([alpha_assign_op]):
                for gpu_id in range(self.gpu_num):
                    tower_device = tf.DeviceSpec(device_type="GPU",
                                                 device_index=gpu_id)
                    with tf.device(tower_device), \
                            tf.variable_scope(tf.get_variable_scope(),
                                              reuse=(gpu_id > 0)):
                        # images
                        gpu_device = '/gpu:{}'.format(gpu_id)
                        inputs = main_ds.shard(self.gpu_num, gpu_id)
                        # buffer_size=None lets the prefetch pick its own
                        # buffer size.
                        inputs = inputs.apply(
                            prefetch_to_device(gpu_device, None))

                        inputs_iterator = inputs.make_one_shot_iterator()
                        real_img = inputs_iterator.get_next()
                        z = tf.random_normal(shape=[batch_size, self.z_dim])

                        fake_img = self.generator(z, alpha, res)
                        real_img = smooth_crossfade(real_img, alpha)

                        real_logit = self.discriminator(real_img, alpha, res)
                        fake_logit = self.discriminator(fake_img, alpha, res)

                        # compute loss
                        tower_d_loss, tower_g_loss = compute_loss(
                            real_img, real_logit, fake_logit)

                        tower_d_losses.append(tower_d_loss)
                        tower_g_losses.append(tower_g_loss)
                        tower_fake_images.append(fake_img)

            print("Create graph for {} resolution".format(res))

            # Variables to train at this resolution.
            d_vars, g_vars = filter_trainable_variables(res)

            # Average the tower losses.
            d_loss = tf.reduce_mean(tower_d_losses)
            g_loss = tf.reduce_mean(tower_g_losses)

            d_lr = self.d_learning_rates.get(res, self.learning_rate_base)
            g_lr = self.g_learning_rates.get(res, self.learning_rate_base)

            # Gradient colocation is only requested with more than one GPU.
            colocate_grad = self.gpu_num != 1

            d_adam = tf.train.AdamOptimizer(d_lr,
                                            beta1=0.0,
                                            beta2=0.99,
                                            epsilon=1e-8)
            d_optim = d_adam.minimize(
                d_loss,
                var_list=d_vars,
                colocate_gradients_with_ops=colocate_grad)

            g_adam = tf.train.AdamOptimizer(g_lr,
                                            beta1=0.0,
                                            beta2=0.99,
                                            epsilon=1e-8)
            g_optim = g_adam.minimize(
                g_loss,
                var_list=g_vars,
                global_step=global_step,
                colocate_gradients_with_ops=colocate_grad)

            self.discriminator_optim[res] = d_optim
            self.generator_optim[res] = g_optim

            self.d_loss_per_res[res] = d_loss
            self.g_loss_per_res[res] = g_loss

            self.train_fake_images[res] = tf.concat(tower_fake_images,
                                                    axis=0)

            # Summary
            self.alpha_summary_per_res[res] = tf.summary.scalar(
                "alpha_{}".format(res), alpha)
            self.d_summary_per_res[res] = tf.summary.scalar(
                "d_loss_{}".format(res), self.d_loss_per_res[res])
            self.g_summary_per_res[res] = tf.summary.scalar(
                "g_loss_{}".format(res), self.g_loss_per_res[res])
    else:
        # Testing
        test_z = tf.random_normal(shape=[self.batch_size, self.z_dim])
        alpha = tf.constant(0.0, dtype=tf.float32, shape=[])
        self.fake_images = self.generator(test_z,
                                          alpha=alpha,
                                          target_img_size=self.img_size,
                                          is_training=False)
def build_model(self):
    """Build the two-domain translation graph for training or testing.

    Train phase: creates the learning-rate placeholder, the two input
    pipelines, both translation directions (with cycle and identity
    passes), the discriminator outputs, all loss terms, the two Adam
    training ops and the merged summaries ``self.G_loss`` / ``self.D_loss``.

    Any other phase: creates one single-image placeholder per domain and the
    corresponding translated outputs.
    """
    if self.phase != 'train':
        # Test
        single_image_shape = [1, self.img_size, self.img_size, self.img_ch]
        self.test_domain_A = tf.placeholder(tf.float32,
                                            single_image_shape,
                                            name='test_domain_A')
        self.test_domain_B = tf.placeholder(tf.float32,
                                            single_image_shape,
                                            name='test_domain_B')

        self.test_fake_B, _ = self.generate_a2b(self.test_domain_A)
        self.test_fake_A, _ = self.generate_b2a(self.test_domain_B)
        return

    self.lr = tf.placeholder(tf.float32, name='learning_rate')

    # Input Image
    Image_Data_Class = ImageData(self.img_size, self.img_ch,
                                 self.augment_flag)

    trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
    trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)

    gpu_device = '/gpu:0'

    trainA = trainA.apply(shuffle_and_repeat(self.dataset_num))
    trainA = trainA.apply(
        map_and_batch(Image_Data_Class.image_processing,
                      self.batch_size,
                      num_parallel_batches=16,
                      drop_remainder=True))
    trainA = trainA.apply(prefetch_to_device(gpu_device, None))

    trainB = trainB.apply(shuffle_and_repeat(self.dataset_num))
    trainB = trainB.apply(
        map_and_batch(Image_Data_Class.image_processing,
                      self.batch_size,
                      num_parallel_batches=16,
                      drop_remainder=True))
    trainB = trainB.apply(prefetch_to_device(gpu_device, None))

    trainA_iterator = trainA.make_one_shot_iterator()
    trainB_iterator = trainB.make_one_shot_iterator()

    self.domain_A = trainA_iterator.get_next()
    self.domain_B = trainB_iterator.get_next()

    # Define Generator, Discriminator
    # Translations of the real batches.
    x_ab, cam_ab = self.generate_a2b(self.domain_A)
    x_ba, cam_ba = self.generate_b2a(self.domain_B)

    # Translated batches sent back through the opposite generator.
    x_aba, _ = self.generate_b2a(x_ab, reuse=True)
    x_bab, _ = self.generate_a2b(x_ba, reuse=True)

    # Each generator applied to a batch from the other pipeline's domain
    # pairing (used by the identity and CAM terms below).
    x_aa, cam_aa = self.generate_b2a(self.domain_A, reuse=True)
    x_bb, cam_bb = self.generate_a2b(self.domain_B, reuse=True)

    (real_A_logit, real_A_cam_logit,
     real_B_logit, real_B_cam_logit) = self.discriminate_real(
         self.domain_A, self.domain_B)
    (fake_A_logit, fake_A_cam_logit,
     fake_B_logit, fake_B_cam_logit) = self.discriminate_fake(x_ba, x_ab)

    # Define Loss
    needs_gradient_penalty = ('wgan' in self.gan_type
                              or self.gan_type == 'dragan')
    if needs_gradient_penalty:
        GP_A, GP_CAM_A = self.gradient_panalty(real=self.domain_A,
                                               fake=x_ba,
                                               scope="discriminator_A")
        GP_B, GP_CAM_B = self.gradient_panalty(real=self.domain_B,
                                               fake=x_ab,
                                               scope="discriminator_B")
    else:
        GP_A, GP_CAM_A = 0, 0
        GP_B, GP_CAM_B = 0, 0

    G_ad_loss_A = (generator_loss(self.gan_type, fake_A_logit)
                   + generator_loss(self.gan_type, fake_A_cam_logit))
    G_ad_loss_B = (generator_loss(self.gan_type, fake_B_logit)
                   + generator_loss(self.gan_type, fake_B_cam_logit))

    D_ad_loss_A = (
        discriminator_loss(self.gan_type, real_A_logit, fake_A_logit)
        + discriminator_loss(self.gan_type, real_A_cam_logit,
                             fake_A_cam_logit)
        + GP_A + GP_CAM_A)
    D_ad_loss_B = (
        discriminator_loss(self.gan_type, real_B_logit, fake_B_logit)
        + discriminator_loss(self.gan_type, real_B_cam_logit,
                             fake_B_cam_logit)
        + GP_B + GP_CAM_B)

    # L1 distance between each round-trip output and its source batch.
    reconstruction_A = L1_loss(x_aba, self.domain_A)
    reconstruction_B = L1_loss(x_bab, self.domain_B)

    identity_A = L1_loss(x_aa, self.domain_A)
    identity_B = L1_loss(x_bb, self.domain_B)

    cam_A = cam_loss(source=cam_ba, non_source=cam_aa)
    cam_B = cam_loss(source=cam_ab, non_source=cam_bb)

    # Weighted generator terms. Note that the "A" cycle term uses
    # reconstruction_B and the "B" cycle term uses reconstruction_A.
    Generator_A_gan = self.adv_weight * G_ad_loss_A
    Generator_A_cycle = self.cycle_weight * reconstruction_B
    Generator_A_identity = self.identity_weight * identity_A
    Generator_A_cam = self.cam_weight * cam_A

    Generator_B_gan = self.adv_weight * G_ad_loss_B
    Generator_B_cycle = self.cycle_weight * reconstruction_A
    Generator_B_identity = self.identity_weight * identity_B
    Generator_B_cam = self.cam_weight * cam_B

    Generator_A_loss = (Generator_A_gan + Generator_A_cycle
                        + Generator_A_identity + Generator_A_cam)
    Generator_B_loss = (Generator_B_gan + Generator_B_cycle
                        + Generator_B_identity + Generator_B_cam)

    Discriminator_A_loss = self.adv_weight * D_ad_loss_A
    Discriminator_B_loss = self.adv_weight * D_ad_loss_B

    self.Generator_loss = (Generator_A_loss + Generator_B_loss
                           + regularization_loss('generator'))
    self.Discriminator_loss = (Discriminator_A_loss + Discriminator_B_loss
                               + regularization_loss('discriminator'))

    # Result Image
    self.fake_A = x_ba
    self.fake_B = x_ab

    self.real_A = self.domain_A
    self.real_B = self.domain_B

    # Training: variables are split by substring match on their names.
    t_vars = tf.trainable_variables()
    G_vars = [v for v in t_vars if 'generator' in v.name]
    D_vars = [v for v in t_vars if 'discriminator' in v.name]

    g_adam = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
    self.G_optim = g_adam.minimize(self.Generator_loss, var_list=G_vars)
    d_adam = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
    self.D_optim = d_adam.minimize(self.Discriminator_loss, var_list=D_vars)

    # Summary
    self.all_G_loss = tf.summary.scalar("Generator_loss",
                                        self.Generator_loss)
    self.all_D_loss = tf.summary.scalar("Discriminator_loss",
                                        self.Discriminator_loss)

    self.G_A_loss = tf.summary.scalar("G_A_loss", Generator_A_loss)
    self.G_A_gan = tf.summary.scalar("G_A_gan", Generator_A_gan)
    self.G_A_cycle = tf.summary.scalar("G_A_cycle", Generator_A_cycle)
    self.G_A_identity = tf.summary.scalar("G_A_identity",
                                          Generator_A_identity)
    self.G_A_cam = tf.summary.scalar("G_A_cam", Generator_A_cam)

    self.G_B_loss = tf.summary.scalar("G_B_loss", Generator_B_loss)
    self.G_B_gan = tf.summary.scalar("G_B_gan", Generator_B_gan)
    self.G_B_cycle = tf.summary.scalar("G_B_cycle", Generator_B_cycle)
    self.G_B_identity = tf.summary.scalar("G_B_identity",
                                          Generator_B_identity)
    self.G_B_cam = tf.summary.scalar("G_B_cam", Generator_B_cam)

    self.D_A_loss = tf.summary.scalar("D_A_loss", Discriminator_A_loss)
    self.D_B_loss = tf.summary.scalar("D_B_loss", Discriminator_B_loss)

    # Histogram plus min/max/mean scalars for every trainable variable
    # whose name contains 'rho'.
    self.rho_var = []
    for rho in tf.trainable_variables():
        if 'rho' not in rho.name:
            continue
        self.rho_var.append(tf.summary.histogram(rho.name, rho))
        self.rho_var.append(
            tf.summary.scalar(rho.name + "_min", tf.reduce_min(rho)))
        self.rho_var.append(
            tf.summary.scalar(rho.name + "_max", tf.reduce_max(rho)))
        self.rho_var.append(
            tf.summary.scalar(rho.name + "_mean", tf.reduce_mean(rho)))

    g_summary_list = [
        self.G_A_loss, self.G_A_gan, self.G_A_cycle, self.G_A_identity,
        self.G_A_cam,
        self.G_B_loss, self.G_B_gan, self.G_B_cycle, self.G_B_identity,
        self.G_B_cam,
        self.all_G_loss,
    ] + self.rho_var

    d_summary_list = [self.D_A_loss, self.D_B_loss, self.all_D_loss]

    self.G_loss = tf.summary.merge(g_summary_list)
    self.D_loss = tf.summary.merge(d_summary_list)
def build_model(self):
    """Build the GAN graph: inputs, losses, train ops, sampler, summaries.

    Real images come from a ``tf.data`` pipeline when ``self.custom_dataset``
    is truthy and from a ``real_images`` placeholder otherwise; the noise
    ``self.z`` is always a placeholder. Results are stored on ``self`` as
    ``d_loss``/``g_loss``, ``d_optim``/``g_optim``, ``fake_images`` and
    ``d_sum``/``g_sum``.
    """
    v1 = tf.compat.v1

    # Graph Input
    # images
    if self.custom_dataset:
        Image_Data_Class = ImageData(self.img_size,
                                     self.c_dim,
                                     crop_pos=self.crop_pos,
                                     zoom_range=self.zoom_range)
        inputs = tf.data.Dataset.from_tensor_slices(self.data)

        gpu_device = '/gpu:0'
        inputs = inputs.apply(shuffle_and_repeat(self.dataset_num))
        inputs = inputs.apply(
            map_and_batch(Image_Data_Class.image_processing,
                          self.batch_size,
                          num_parallel_batches=16,
                          drop_remainder=True))
        inputs = inputs.apply(
            prefetch_to_device(gpu_device, self.batch_size))

        inputs_iterator = v1.data.make_one_shot_iterator(inputs)
        self.inputs = inputs_iterator.get_next()
    else:
        image_shape = [self.batch_size, self.img_size, self.img_size,
                       self.c_dim]
        self.inputs = v1.placeholder(tf.float32,
                                     image_shape,
                                     name='real_images')

    # noises
    self.z = v1.placeholder(tf.float32,
                            [self.batch_size, 1, 1, self.z_dim],
                            name='z')

    # Loss Function
    # output of D for real images
    real_logits = self.discriminator(self.inputs)

    # output of D for fake images
    fake_images = self.generator(self.z)
    fake_logits = self.discriminator(fake_images, reuse=True)

    # A gradient penalty is only added for wgan-style and dragan losses.
    if 'wgan' in self.gan_type or self.gan_type == 'dragan':
        GP = self.gradient_penalty(real=self.inputs, fake=fake_images)
    else:
        GP = 0

    # get loss for discriminator
    self.d_loss = discriminator_loss(self.gan_type,
                                     real=real_logits,
                                     fake=fake_logits) + GP

    # get loss for generator
    self.g_loss = generator_loss(self.gan_type, fake=fake_logits)

    # Training
    # Trainable variables are split into D and G groups by name substring.
    t_vars = v1.trainable_variables()
    d_vars = [v for v in t_vars if 'discriminator' in v.name]
    g_vars = [v for v in t_vars if 'generator' in v.name]

    # optimizers
    d_adam = v1.train.AdamOptimizer(self.d_learning_rate,
                                    beta1=self.beta1,
                                    beta2=self.beta2)
    self.d_optim = d_adam.minimize(self.d_loss, var_list=d_vars)
    g_adam = v1.train.AdamOptimizer(self.g_learning_rate,
                                    beta1=self.beta1,
                                    beta2=self.beta2)
    self.g_optim = g_adam.minimize(self.g_loss, var_list=g_vars)

    # Testing: generator rebuilt with reused variables and is_training=False.
    self.fake_images = self.generator(self.z, is_training=False, reuse=True)

    # Summary
    self.d_sum = v1.summary.scalar("d_loss", self.d_loss)
    self.g_sum = v1.summary.scalar("g_loss", self.g_loss)
def build_model(self):
    """Assemble the GAN graph and store its tensors and ops on ``self``.

    Sets ``self.inputs`` (dataset iterator output when ``self.custom_dataset``
    is truthy, otherwise a ``real_images`` placeholder), the noise
    placeholder ``self.z``, the losses ``self.d_loss`` / ``self.g_loss``, the
    Adam train ops ``self.d_optim`` / ``self.g_optim``, the inference output
    ``self.fake_images`` and the scalar summaries ``self.d_sum`` /
    ``self.g_sum``.
    """
    # ---- real-image input ----
    if self.custom_dataset:
        Image_Data_Class = ImageData(self.img_size, self.c_dim)
        inputs = tf.data.Dataset.from_tensor_slices(self.data)

        gpu_device = '/gpu:0'
        shuffled = inputs.apply(shuffle_and_repeat(self.dataset_num))
        batched = shuffled.apply(
            map_and_batch(Image_Data_Class.image_processing,
                          self.batch_size,
                          num_parallel_batches=16,
                          drop_remainder=True))
        inputs = batched.apply(
            prefetch_to_device(gpu_device, self.batch_size))

        inputs_iterator = inputs.make_one_shot_iterator()
        self.inputs = inputs_iterator.get_next()
    else:
        self.inputs = tf.placeholder(
            tf.float32,
            [self.batch_size, self.img_size, self.img_size, self.c_dim],
            name='real_images')

    # ---- noise input ----
    self.z = tf.placeholder(tf.float32,
                            [self.batch_size, 1, 1, self.z_dim],
                            name='z')

    # ---- discriminator / generator outputs ----
    real_logits = self.discriminator(self.inputs)
    fake_images = self.generator(self.z)
    fake_logits = self.discriminator(fake_images, reuse=True)

    # ---- losses ----
    use_gp = 'wgan' in self.gan_type or self.gan_type == 'dragan'
    GP = (self.gradient_penalty(real=self.inputs, fake=fake_images)
          if use_gp else 0)

    self.d_loss = discriminator_loss(
        self.gan_type, real=real_logits, fake=fake_logits) + GP
    self.g_loss = generator_loss(self.gan_type, fake=fake_logits)

    # ---- training ops (variables selected by name substring) ----
    t_vars = tf.trainable_variables()
    d_vars = [v for v in t_vars if 'discriminator' in v.name]
    g_vars = [v for v in t_vars if 'generator' in v.name]

    d_adam = tf.train.AdamOptimizer(self.d_learning_rate,
                                    beta1=self.beta1,
                                    beta2=self.beta2)
    self.d_optim = d_adam.minimize(self.d_loss, var_list=d_vars)

    g_adam = tf.train.AdamOptimizer(self.g_learning_rate,
                                    beta1=self.beta1,
                                    beta2=self.beta2)
    self.g_optim = g_adam.minimize(self.g_loss, var_list=g_vars)

    # ---- inference output (reuses the generator variables) ----
    self.fake_images = self.generator(self.z, is_training=False, reuse=True)

    # ---- summaries ----
    self.d_sum = tf.summary.scalar("d_loss", self.d_loss)
    self.g_sum = tf.summary.scalar("g_loss", self.g_loss)
def build_model(self):
    """Build the identity/shape generator graph, its losses and summaries.

    Creates the learning-rate placeholder, two input pipelines, three
    independent batches drawn from pipeline A and one from pipeline B, two
    single-image test placeholders, four generator passes, two discriminator
    passes, the generator and discriminator losses, the Adam train ops, the
    test output ``self.test_fake`` and the merged summaries
    ``self.G_loss_merge`` / ``self.D_loss_merge``.
    """
    self.lr = tf.placeholder(tf.float32, name='learning_rate')

    # Input Image
    Image_Data_Class = ImageData(self.img_size, self.img_ch,
                                 self.augment_flag)

    trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
    trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)

    gpu_device = '/gpu:0'

    trainA = trainA.apply(shuffle_and_repeat(self.dataset_num))
    trainA = trainA.apply(
        map_and_batch(Image_Data_Class.image_processing,
                      self.batch_size,
                      num_parallel_batches=16,
                      drop_remainder=True))
    trainA = trainA.apply(prefetch_to_device(gpu_device, self.batch_size))

    trainB = trainB.apply(shuffle_and_repeat(self.dataset_num))
    trainB = trainB.apply(
        map_and_batch(Image_Data_Class.image_processing,
                      self.batch_size,
                      num_parallel_batches=16,
                      drop_remainder=True))
    trainB = trainB.apply(prefetch_to_device(gpu_device, self.batch_size))

    trainA_iterator = trainA.make_one_shot_iterator()
    trainB_iterator = trainB.make_one_shot_iterator()

    # Three separate get_next() calls on the A iterator, one on the B
    # iterator.
    self.identity_A = trainA_iterator.get_next()
    self.shape_A = trainA_iterator.get_next()
    self.other_A = trainA_iterator.get_next()

    self.shape_B = trainB_iterator.get_next()

    single_image_shape = [1, self.img_size, self.img_size, self.img_ch]
    self.test_identity_A = tf.placeholder(tf.float32,
                                          single_image_shape,
                                          name='test_identity_A')
    self.test_shape_B = tf.placeholder(tf.float32,
                                       single_image_shape,
                                       name='test_shape_B')

    # Define Generator, Discriminator
    self.fake_same = self.generator(x_identity=self.identity_A,
                                    x_shape=self.shape_A)
    self.fake_diff = self.generator(x_identity=self.identity_A,
                                    x_shape=self.shape_B,
                                    reuse=True)
    fake_diff_shape = self.generator(x_identity=self.shape_B,
                                     x_shape=self.fake_diff,
                                     reuse=True)
    fake_diff_identity = self.generator(x_identity=self.fake_diff,
                                        x_shape=self.shape_B,
                                        reuse=True)

    real_logit = self.discriminator(x_identity=self.identity_A,
                                    x=self.other_A)
    fake_logit = self.discriminator(x_identity=self.identity_A,
                                    x=self.fake_diff,
                                    reuse=True)

    # Define Loss
    pooled_fake_logit = minpool(fake_logit)
    g_identity_loss = self.adv_weight * generator_loss(
        self.gan_type, fake=pooled_fake_logit) * 64
    g_shape_loss_same = self.L1_weight * L1_loss(self.fake_same,
                                                 self.shape_A)
    g_shape_loss_diff_shape = self.L1_weight * L1_loss(fake_diff_shape,
                                                       self.shape_B)
    g_shape_loss_diff_identity = self.L1_weight * L1_loss(
        fake_diff_identity, self.fake_diff)

    self.Generator_loss = (g_identity_loss + g_shape_loss_same
                           + g_shape_loss_diff_shape
                           + g_shape_loss_diff_identity)
    self.Discriminator_loss = self.adv_weight * discriminator_loss(
        self.gan_type, real=real_logit, fake=fake_logit)

    # Result Image
    self.test_fake = self.generator(x_identity=self.test_identity_A,
                                    x_shape=self.test_shape_B,
                                    reuse=True)

    # Training: variables are split by substring match on their names.
    t_vars = tf.trainable_variables()
    G_vars = [v for v in t_vars if 'generator' in v.name]
    D_vars = [v for v in t_vars if 'discriminator' in v.name]

    g_adam = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
    self.G_optim = g_adam.minimize(self.Generator_loss, var_list=G_vars)
    d_adam = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
    self.D_optim = d_adam.minimize(self.Discriminator_loss, var_list=D_vars)

    # Summary
    self.G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss)
    self.D_loss = tf.summary.scalar("Discriminator_loss",
                                    self.Discriminator_loss)

    self.G_identity = tf.summary.scalar("G_identity", g_identity_loss)
    self.G_shape_loss_same = tf.summary.scalar("G_shape_loss_same",
                                               g_shape_loss_same)
    self.G_shape_loss_diff_shape = tf.summary.scalar(
        "G_shape_loss_diff_shape", g_shape_loss_diff_shape)
    self.G_shape_loss_diff_identity = tf.summary.scalar(
        "G_shape_loss_diff_identity", g_shape_loss_diff_identity)

    g_summaries = [
        self.G_loss,
        self.G_identity,
        self.G_shape_loss_same,
        self.G_shape_loss_diff_shape,
        self.G_shape_loss_diff_identity,
    ]
    self.G_loss_merge = tf.summary.merge(g_summaries)
    self.D_loss_merge = tf.summary.merge([self.D_loss])
def build_model(self):
    """Build the segmap-to-image graph: inputs, losses, train ops, summaries.

    Creates the learning-rate placeholder, preprocesses the dataset through
    ``Image_data`` and records ``self.dataset_num`` /
    ``self.test_dataset_num``, builds the training (image + segmap) and test
    (segmap only) pipelines, runs the translator and discriminator, assembles
    ``self.g_loss`` / ``self.d_loss``, creates the test-time placeholders and
    outputs, the Adam train ops ``self.G_optim`` / ``self.D_optim`` and the
    merged summaries ``self.G_loss`` / ``self.D_loss``.
    """
    self.lr = tf.placeholder(tf.float32, name='learning_rate')

    # Input Image
    img_class = Image_data(self.img_height, self.img_width, self.img_ch,
                           self.segmap_ch, self.dataset_path,
                           self.augment_flag)
    img_class.preprocess()

    self.dataset_num = len(img_class.image)
    self.test_dataset_num = len(img_class.segmap_test)

    img_and_segmap = tf.data.Dataset.from_tensor_slices(
        (img_class.image, img_class.segmap))
    segmap_test = tf.data.Dataset.from_tensor_slices(img_class.segmap_test)

    gpu_device = '/gpu:0'
    img_and_segmap = img_and_segmap.apply(
        shuffle_and_repeat(self.dataset_num)).apply(
            map_and_batch(img_class.image_processing,
                          self.batch_size,
                          num_parallel_batches=16,
                          drop_remainder=True)).apply(
                              prefetch_to_device(gpu_device,
                                                 self.batch_size))

    # The shuffle buffer is sized by the test set's own element count; the
    # training count (self.dataset_num) belongs to the other pipeline.
    segmap_test = segmap_test.apply(
        shuffle_and_repeat(self.test_dataset_num)).apply(
            map_and_batch(img_class.test_image_processing,
                          batch_size=self.batch_size,
                          num_parallel_batches=16,
                          drop_remainder=True)).apply(
                              prefetch_to_device(gpu_device,
                                                 self.batch_size))

    img_and_segmap_iterator = img_and_segmap.make_one_shot_iterator()
    segmap_test_iterator = segmap_test.make_one_shot_iterator()

    (self.real_x, self.real_x_segmap,
     self.real_x_segmap_onehot) = img_and_segmap_iterator.get_next()
    (self.real_x_segmap_test,
     self.real_x_segmap_test_onehot) = segmap_test_iterator.get_next()

    # Define Generator, Discriminator
    fake_x, x_mean, x_var = self.image_translate(
        segmap_img=self.real_x_segmap_onehot, x_img=self.real_x)
    real_logit, fake_logit = self.image_discriminate(
        segmap_img=self.real_x_segmap_onehot,
        real_img=self.real_x,
        fake_img=fake_x)

    # A gradient penalty is only added for wgan-style and dragan losses.
    if 'wgan' in self.gan_type or self.gan_type == 'dragan':
        GP = self.gradient_penalty(real=self.real_x,
                                   segmap=self.real_x_segmap_onehot,
                                   fake=fake_x)
    else:
        GP = 0

    # Define Loss
    g_adv_loss = self.adv_weight * generator_loss(self.gan_type, fake_logit)
    g_kl_loss = self.kl_weight * kl_loss(x_mean, x_var)
    g_vgg_loss = self.vgg_weight * VGGLoss()(self.real_x, fake_x)
    g_feature_loss = self.feature_weight * feature_loss(real_logit,
                                                        fake_logit)
    g_reg_loss = (regularization_loss('generator')
                  + regularization_loss('encoder'))

    d_adv_loss = self.adv_weight * (
        discriminator_loss(self.gan_type, real_logit, fake_logit) + GP)
    d_reg_loss = regularization_loss('discriminator')

    self.g_loss = (g_adv_loss + g_kl_loss + g_vgg_loss + g_feature_loss
                   + g_reg_loss)
    self.d_loss = d_adv_loss + d_reg_loss

    # Result Image
    self.fake_x = fake_x
    self.random_fake_x, _, _ = self.image_translate(
        segmap_img=self.real_x_segmap_onehot, random_style=True, reuse=True)

    # Test: single-image placeholders. The segmap placeholder has one
    # channel per entry of img_class.color_value_dict.
    self.test_segmap_image = tf.placeholder(
        tf.float32,
        [1, self.img_height, self.img_width,
         len(img_class.color_value_dict)])
    self.random_test_fake_x, _, _ = self.image_translate(
        segmap_img=self.test_segmap_image, random_style=True, reuse=True)

    self.test_guide_image = tf.placeholder(
        tf.float32, [1, self.img_height, self.img_width, self.img_ch])
    self.guide_test_fake_x, _, _ = self.image_translate(
        segmap_img=self.test_segmap_image,
        x_img=self.test_guide_image,
        reuse=True)

    # Training: generator group covers 'encoder' and 'generator' variables.
    t_vars = tf.trainable_variables()
    G_vars = [var for var in t_vars
              if 'encoder' in var.name or 'generator' in var.name]
    D_vars = [var for var in t_vars if 'discriminator' in var.name]

    if self.TTUR:
        # With TTUR the generator gets half and the discriminator twice the
        # base learning rate, with fixed betas.
        beta1 = 0.0
        beta2 = 0.9

        g_lr = self.lr / 2
        d_lr = self.lr * 2
    else:
        beta1 = self.beta1
        beta2 = self.beta2
        g_lr = self.lr
        d_lr = self.lr

    self.G_optim = tf.train.AdamOptimizer(
        g_lr, beta1=beta1, beta2=beta2).minimize(self.g_loss,
                                                 var_list=G_vars)
    self.D_optim = tf.train.AdamOptimizer(
        d_lr, beta1=beta1, beta2=beta2).minimize(self.d_loss,
                                                 var_list=D_vars)

    # Summary
    self.summary_g_loss = tf.summary.scalar("g_loss", self.g_loss)
    self.summary_d_loss = tf.summary.scalar("d_loss", self.d_loss)

    self.summary_g_adv_loss = tf.summary.scalar("g_adv_loss", g_adv_loss)
    self.summary_g_kl_loss = tf.summary.scalar("g_kl_loss", g_kl_loss)
    self.summary_g_vgg_loss = tf.summary.scalar("g_vgg_loss", g_vgg_loss)
    self.summary_g_feature_loss = tf.summary.scalar("g_feature_loss",
                                                    g_feature_loss)

    g_summary_list = [
        self.summary_g_loss,
        self.summary_g_adv_loss,
        self.summary_g_kl_loss,
        self.summary_g_vgg_loss,
        self.summary_g_feature_loss,
    ]
    d_summary_list = [self.summary_d_loss]

    self.G_loss = tf.summary.merge(g_summary_list)
    self.D_loss = tf.summary.merge(d_summary_list)
def build_model(self):
    """Build the label-conditioned translation graph, losses and summaries.

    Creates the learning-rate placeholder, preprocesses the dataset through
    ``ImageData``, builds separate train and test pipelines (each shuffled
    with its own element count), the custom-image placeholder, the generator
    and discriminator passes, ``self.d_loss`` / ``self.g_loss``, the
    per-fixed-label outputs for train, test and custom images, the Adam
    train ops and the merged summaries ``self.g_summary_loss`` /
    ``self.d_summary_loss``.
    """
    self.lr = tf.placeholder(tf.float32, name='learning_rate')

    # Input Image
    Image_data_class = ImageData(load_size=self.img_size,
                                 channels=self.img_ch,
                                 data_path=self.dataset_path,
                                 selected_attrs=self.selected_attrs,
                                 augment_flag=self.augment_flag)
    Image_data_class.preprocess()

    train_dataset_num = len(Image_data_class.train_dataset)
    test_dataset_num = len(Image_data_class.test_dataset)

    train_dataset = tf.data.Dataset.from_tensor_slices(
        (Image_data_class.train_dataset,
         Image_data_class.train_dataset_label,
         Image_data_class.train_dataset_fix_label))
    test_dataset = tf.data.Dataset.from_tensor_slices(
        (Image_data_class.test_dataset,
         Image_data_class.test_dataset_label,
         Image_data_class.test_dataset_fix_label))

    gpu_device = '/gpu:0'

    train_dataset = train_dataset.apply(
        shuffle_and_repeat(train_dataset_num))
    train_dataset = train_dataset.apply(
        map_and_batch(Image_data_class.image_processing,
                      self.batch_size,
                      num_parallel_batches=8,
                      drop_remainder=True))
    train_dataset = train_dataset.apply(
        prefetch_to_device(gpu_device, self.batch_size))

    test_dataset = test_dataset.apply(shuffle_and_repeat(test_dataset_num))
    test_dataset = test_dataset.apply(
        map_and_batch(Image_data_class.image_processing,
                      self.batch_size,
                      num_parallel_batches=8,
                      drop_remainder=True))
    test_dataset = test_dataset.apply(
        prefetch_to_device(gpu_device, self.batch_size))

    train_dataset_iterator = train_dataset.make_one_shot_iterator()
    test_dataset_iterator = test_dataset.make_one_shot_iterator()

    # Input image / original domain labels / fixed label list.
    self.x_real, label_org, label_fix_list = \
        train_dataset_iterator.get_next()
    # Target domain labels: the original labels shuffled.
    label_trg = tf.random_shuffle(label_org)
    # Swap the first two axes so tf.map_fn below iterates over the fixed
    # labels.
    label_fix_list = tf.transpose(label_fix_list, perm=[1, 0, 2])

    self.x_test, test_label_org, test_label_fix_list = \
        test_dataset_iterator.get_next()
    test_label_fix_list = tf.transpose(test_label_fix_list, perm=[1, 0, 2])

    # Custom Image
    self.custom_image = tf.placeholder(
        tf.float32,
        [1, self.img_size, self.img_size, self.img_ch],
        name='custom_image')
    custom_label_fix_list = tf.transpose(
        create_labels(self.custom_label, self.selected_attrs),
        perm=[1, 0, 2])

    # Define Generator, Discriminator
    # Translate to the target labels, then back to the original labels.
    x_fake = self.generator(self.x_real, label_trg)
    x_recon = self.generator(x_fake, label_org, reuse=True)

    real_logit, real_cls = self.discriminator(self.x_real)
    fake_logit, fake_cls = self.discriminator(x_fake, reuse=True)

    # Define Loss
    if 'wgan' in self.gan_type or self.gan_type == 'dragan':
        GP = self.gradient_panalty(real=self.x_real, fake=x_fake)
    else:
        GP = 0

    g_adv_loss = generator_loss(loss_func=self.gan_type, fake=fake_logit)
    g_cls_loss = classification_loss(logit=fake_cls, label=label_trg)
    g_rec_loss = L1_loss(self.x_real, x_recon)

    d_adv_loss = discriminator_loss(loss_func=self.gan_type,
                                    real=real_logit,
                                    fake=fake_logit) + GP
    d_cls_loss = classification_loss(logit=real_cls, label=label_org)

    self.d_loss = (self.adv_weight * d_adv_loss
                   + self.cls_weight * d_cls_loss)
    self.g_loss = (self.adv_weight * g_adv_loss
                   + self.cls_weight * g_cls_loss
                   + self.rec_weight * g_rec_loss)

    # Result Image: one generator pass per fixed label.
    self.x_fake_list = tf.map_fn(
        lambda lbl: self.generator(self.x_real, lbl, reuse=True),
        label_fix_list,
        dtype=tf.float32)

    # Test Image
    self.x_test_fake_list = tf.map_fn(
        lambda lbl: self.generator(self.x_test, lbl, reuse=True),
        test_label_fix_list,
        dtype=tf.float32)
    self.custom_fake_image = tf.map_fn(
        lambda lbl: self.generator(self.custom_image, lbl, reuse=True),
        custom_label_fix_list,
        dtype=tf.float32)

    # Training: variables are split by substring match on their names.
    t_vars = tf.trainable_variables()
    G_vars = [v for v in t_vars if 'generator' in v.name]
    D_vars = [v for v in t_vars if 'discriminator' in v.name]

    g_adam = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
    self.g_optimizer = g_adam.minimize(self.g_loss, var_list=G_vars)
    d_adam = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
    self.d_optimizer = d_adam.minimize(self.d_loss, var_list=D_vars)

    # Summary
    self.Generator_loss = tf.summary.scalar("Generator_loss", self.g_loss)
    self.Discriminator_loss = tf.summary.scalar("Discriminator_loss",
                                                self.d_loss)

    self.g_adv_loss = tf.summary.scalar("g_adv_loss", g_adv_loss)
    self.g_cls_loss = tf.summary.scalar("g_cls_loss", g_cls_loss)
    self.g_rec_loss = tf.summary.scalar("g_rec_loss", g_rec_loss)

    self.d_adv_loss = tf.summary.scalar("d_adv_loss", d_adv_loss)
    self.d_cls_loss = tf.summary.scalar("d_cls_loss", d_cls_loss)

    g_summaries = [
        self.Generator_loss,
        self.g_adv_loss,
        self.g_cls_loss,
        self.g_rec_loss,
    ]
    d_summaries = [
        self.Discriminator_loss,
        self.d_adv_loss,
        self.d_cls_loss,
    ]
    self.g_summary_loss = tf.summary.merge(g_summaries)
    self.d_summary_loss = tf.summary.merge(d_summaries)
def build_model(self):
    """Build the label- and style-conditioned graph for ``self.phase``.

    The data class is ``ImageData_celebA`` when ``self.dataset_name`` is
    'celebA-HQ' or 'celebA' and ``Image_data`` otherwise.

    Train phase: builds the input pipeline, generator and discriminator
    passes, ``self.g_loss`` / ``self.d_loss``, ``self.num_style`` lists of
    per-label outputs in ``self.x_fake_list``, the Adam train ops and the
    merged summaries.

    Any other phase: builds the custom-image placeholder and output; for the
    celebA datasets it additionally builds a test pipeline and the
    placeholder-driven ``self.x_test_fake_list``.
    """
    is_celeba = (self.dataset_name == 'celebA-HQ'
                 or self.dataset_name == 'celebA')

    label_fix_onehot_list = []

    # Input Image
    if is_celeba:
        img_class = ImageData_celebA(self.img_height, self.img_width,
                                     self.img_ch, self.dataset_path,
                                     self.label_list, self.augment_flag)
        img_class.preprocess(self.phase)
    else:
        img_class = Image_data(self.img_height, self.img_width, self.img_ch,
                               self.dataset_path, self.label_list,
                               self.augment_flag)
        img_class.preprocess()

        # Insert a batch axis and tile it batch_size times.
        label_fix_onehot_list = img_class.label_onehot_list
        label_fix_onehot_list = tf.tile(
            tf.expand_dims(label_fix_onehot_list, axis=1),
            [1, self.batch_size, 1])

    dataset_num = len(img_class.image)
    print("Dataset number : ", dataset_num)

    if self.phase == 'train':
        self.lr = tf.placeholder(tf.float32, name='learning_rate')

        if is_celeba:
            img_and_label = tf.data.Dataset.from_tensor_slices(
                (img_class.image, img_class.label,
                 img_class.train_label_onehot_list))
        else:
            img_and_label = tf.data.Dataset.from_tensor_slices(
                (img_class.image, img_class.label))

        gpu_device = '/gpu:0'
        img_and_label = img_and_label.apply(shuffle_and_repeat(dataset_num))
        img_and_label = img_and_label.apply(
            map_and_batch(img_class.image_processing,
                          self.batch_size,
                          num_parallel_batches=16,
                          drop_remainder=True))
        img_and_label = img_and_label.apply(
            prefetch_to_device(gpu_device, None))

        img_and_label_iterator = img_and_label.make_one_shot_iterator()

        if is_celeba:
            self.x_real, label_org, label_fix_onehot_list = \
                img_and_label_iterator.get_next()
            # Target domain labels
            label_trg = tf.random_shuffle(label_org)
            label_fix_onehot_list = tf.transpose(label_fix_onehot_list,
                                                 perm=[1, 0, 2])
        else:
            self.x_real, label_org = img_and_label_iterator.get_next()
            # Target domain labels
            label_trg = tf.random_shuffle(label_org)

        # Define Generator, Discriminator
        fake_style_code = tf.random_normal(
            shape=[self.batch_size, self.style_dim])
        x_fake = self.generator(self.x_real, label_trg, fake_style_code)

        recon_style_code = tf.random_normal(
            shape=[self.batch_size, self.style_dim])
        x_recon = self.generator(x_fake, label_org, recon_style_code,
                                 reuse=True)

        real_logit, real_cls, _ = self.discriminator(self.x_real)
        fake_logit, fake_cls, fake_noise = self.discriminator(x_fake,
                                                              reuse=True)

        # Define Loss
        if 'wgan' in self.gan_type or self.gan_type == 'dragan':
            GP = self.gradient_panalty(real=self.x_real, fake=x_fake)
        else:
            GP = 0

        g_adv_loss = self.adv_weight * generator_loss(self.gan_type,
                                                      fake_logit)
        g_cls_loss = self.cls_weight * classification_loss(logit=fake_cls,
                                                           label=label_trg)
        g_rec_loss = self.rec_weight * L1_loss(self.x_real, x_recon)
        g_noise_loss = self.noise_weight * L1_loss(fake_style_code,
                                                   fake_noise)

        d_adv_loss = self.adv_weight * discriminator_loss(
            self.gan_type, real_logit, fake_logit) + GP
        d_cls_loss = self.cls_weight * classification_loss(logit=real_cls,
                                                           label=label_org)
        d_noise_loss = self.noise_weight * L1_loss(fake_style_code,
                                                   fake_noise)

        self.d_loss = d_adv_loss + d_cls_loss + d_noise_loss
        self.g_loss = g_adv_loss + g_cls_loss + g_rec_loss + g_noise_loss

        # Result Image: for each of num_style random style codes, one
        # generator pass per fixed label. The same construction applies to
        # every dataset.
        self.x_fake_list = []
        for _ in range(self.num_style):
            random_style_code = tf.random_normal(
                shape=[self.batch_size, self.style_dim])
            per_label_fakes = tf.map_fn(
                lambda c: self.generator(self.x_real, c, random_style_code,
                                         reuse=True),
                label_fix_onehot_list,
                dtype=tf.float32)
            self.x_fake_list.append(per_label_fakes)

        # Training: variables are split by substring match on their names.
        t_vars = tf.trainable_variables()
        G_vars = [v for v in t_vars if 'generator' in v.name]
        D_vars = [v for v in t_vars if 'discriminator' in v.name]

        g_adam = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
        self.g_optimizer = g_adam.minimize(self.g_loss, var_list=G_vars)
        d_adam = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
        self.d_optimizer = d_adam.minimize(self.d_loss, var_list=D_vars)

        # Summary
        self.Generator_loss = tf.summary.scalar("g_loss", self.g_loss)
        self.Discriminator_loss = tf.summary.scalar("d_loss", self.d_loss)

        self.g_adv_loss = tf.summary.scalar("g_adv_loss", g_adv_loss)
        self.g_cls_loss = tf.summary.scalar("g_cls_loss", g_cls_loss)
        self.g_rec_loss = tf.summary.scalar("g_rec_loss", g_rec_loss)
        self.g_noise_loss = tf.summary.scalar("g_noise_loss", g_noise_loss)

        self.d_adv_loss = tf.summary.scalar("d_adv_loss", d_adv_loss)
        self.d_cls_loss = tf.summary.scalar("d_cls_loss", d_cls_loss)
        self.d_noise_loss = tf.summary.scalar("d_noise_loss", d_noise_loss)

        g_summaries = [
            self.Generator_loss,
            self.g_adv_loss,
            self.g_cls_loss,
            self.g_rec_loss,
            self.g_noise_loss,
        ]
        d_summaries = [
            self.Discriminator_loss,
            self.d_adv_loss,
            self.d_cls_loss,
            self.d_noise_loss,
        ]
        self.g_summary_loss = tf.summary.merge(g_summaries)
        self.d_summary_loss = tf.summary.merge(d_summaries)

    else:
        # Test
        if is_celeba:
            img_and_label = tf.data.Dataset.from_tensor_slices(
                (img_class.test_image, img_class.test_label,
                 img_class.test_label_onehot_list))
            dataset_num = len(img_class.test_image)

            gpu_device = '/gpu:0'
            img_and_label = img_and_label.apply(
                shuffle_and_repeat(dataset_num))
            img_and_label = img_and_label.apply(
                map_and_batch(img_class.image_processing,
                              batch_size=self.batch_size,
                              num_parallel_batches=16,
                              drop_remainder=True))
            img_and_label = img_and_label.apply(
                prefetch_to_device(gpu_device, None))

            img_and_label_iterator = img_and_label.make_one_shot_iterator()

            self.x_test, _, self.test_label_fix_onehot_list = \
                img_and_label_iterator.get_next()

            self.test_img_placeholder = tf.placeholder(
                tf.float32,
                [1, self.img_height, self.img_width, self.img_ch])
            self.test_label_fix_placeholder = tf.placeholder(
                tf.float32, [self.c_dim, 1, self.c_dim])

            # Custom Image
            self.custom_image = tf.placeholder(
                tf.float32,
                [1, self.img_height, self.img_width, self.img_ch],
                name='custom_image')
            # [c_dim, bs, c_dim]
            custom_label_fix_onehot_list = tf.transpose(
                np.expand_dims(label2onehot(self.label_list), axis=0),
                perm=[1, 0, 2])

            # Test Image: the first generator call creates the variables,
            # the second one reuses them.
            test_random_style_code = tf.random_normal(
                shape=[1, self.style_dim])

            self.x_test_fake_list = tf.map_fn(
                lambda c: self.generator(self.test_img_placeholder, c,
                                         test_random_style_code),
                self.test_label_fix_placeholder,
                dtype=tf.float32)

            self.custom_fake_image = tf.map_fn(
                lambda c: self.generator(self.custom_image, c,
                                         test_random_style_code,
                                         reuse=True),
                custom_label_fix_onehot_list,
                dtype=tf.float32)

        else:
            # Custom Image
            self.custom_image = tf.placeholder(
                tf.float32,
                [1, self.img_height, self.img_width, self.img_ch],
                name='custom_image')
            # [c_dim, bs, c_dim]
            custom_label_fix_onehot_list = tf.transpose(
                np.expand_dims(label2onehot(self.label_list), axis=0),
                perm=[1, 0, 2])

            test_random_style_code = tf.random_normal(
                shape=[1, self.style_dim])

            self.custom_fake_image = tf.map_fn(
                lambda c: self.generator(self.custom_image, c,
                                         test_random_style_code),
                custom_label_fix_onehot_list,
                dtype=tf.float32)
def build_model(self):
    """Build the training graph, the sampling tensor and the loss summaries.

    Creates a ``tf.data`` input pipeline from ``self.data``, samples the
    latent batch ``self.z``, runs the discriminator on real and generated
    images, and stores on ``self``: ``inputs``, ``z``, ``d_loss``,
    ``g_loss``, ``d_optim``, ``opt``, ``g_optim``, ``fake_images``,
    ``d_sum`` and ``g_sum``.
    """
    # ---- Graph input ----
    # images
    Image_Data_Class = ImageData(self.img_size, self.c_dim, self.custom_dataset)
    inputs = tf.data.Dataset.from_tensor_slices(self.data)

    gpu_device = '/gpu:0'
    # shuffle/repeat -> preprocess + batch -> prefetch onto the GPU
    inputs = inputs.\
        apply(shuffle_and_repeat(self.dataset_num)).\
        apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size,
                            num_parallel_batches=16, drop_remainder=True)).\
        apply(prefetch_to_device(gpu_device, self.batch_size))

    inputs_iterator = inputs.make_one_shot_iterator()
    self.inputs = inputs_iterator.get_next()

    # noises
    self.z = tf.truncated_normal(shape=[self.batch_size, 1, 1, self.z_dim],
                                 name='random_z')

    # ---- Loss function ----
    # output of D for real images
    real_logits = self.discriminator(self.inputs)

    # output of D for fake images (second call shares the D variables)
    fake_images = self.generator(self.z)
    fake_logits = self.discriminator(fake_images, reuse=True)

    # The gradient penalty is only used for gan types containing 'wgan'
    # and for 'dragan'; every other type contributes 0.
    if 'wgan' in self.gan_type or self.gan_type == 'dragan':
        GP = self.gradient_penalty(real=self.inputs, fake=fake_images)
    else:
        GP = 0

    # get loss for discriminator
    self.d_loss = discriminator_loss(
        self.gan_type, real=real_logits, fake=fake_logits) + GP

    # get loss for generator
    self.g_loss = generator_loss(self.gan_type, fake=fake_logits)

    # ---- Training ----
    # divide trainable variables into a group for D and a group for G
    t_vars = tf.trainable_variables()
    d_vars = [var for var in t_vars if 'discriminator' in var.name]
    g_vars = [var for var in t_vars if 'generator' in var.name]

    # optimizers: every op registered in UPDATE_OPS so far must run before
    # either optimizer step.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        self.d_optim = tf.train.AdamOptimizer(
            self.d_learning_rate, beta1=self.beta1,
            beta2=self.beta2).minimize(self.d_loss, var_list=d_vars)

        # The generator's Adam optimizer is wrapped in MovingAverageOptimizer;
        # the wrapper is kept on self.opt.
        self.opt = MovingAverageOptimizer(
            tf.train.AdamOptimizer(self.g_learning_rate, beta1=self.beta1,
                                   beta2=self.beta2),
            average_decay=self.moving_decay)

        self.g_optim = self.opt.minimize(self.g_loss, var_list=g_vars)

    # ---- Testing ----
    # for test: same generator variables, built with is_training=False
    self.fake_images = self.generator(self.z, is_training=False, reuse=True)

    # ---- Summary ----
    self.d_sum = tf.summary.scalar("d_loss", self.d_loss)
    self.g_sum = tf.summary.scalar("g_loss", self.g_loss)
def build_model(self):
    """Construct the complete graph for two-domain image translation.

    Builds, in order: the learning-rate placeholder, the two input
    pipelines, the encoder/decoder passes, the discriminator passes, the
    weighted generator/discriminator losses, the three optimizer ops, the
    scalar summaries, the sample-image aliases, and the test and guided
    translation heads. All results are stored as attributes on ``self``.
    """
    self.lr = tf.placeholder(tf.float32, name='lr')

    # ---------------- Input images ----------------
    preprocessor = ImageData(self.img_size, self.img_ch, self.augment_flag)

    stream_a = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
    stream_b = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)

    device = '/gpu:0'

    def batched_on_device(stream):
        # shuffle/repeat -> preprocess + batch -> prefetch onto `device`
        stream = stream.apply(shuffle_and_repeat(self.dataset_num))
        stream = stream.apply(
            map_and_batch(preprocessor.image_processing, self.batch_size,
                          num_parallel_batches=16, drop_remainder=True))
        return stream.apply(prefetch_to_device(device, self.batch_size))

    stream_a = batched_on_device(stream_a)
    stream_b = batched_on_device(stream_b)

    iter_a = stream_a.make_one_shot_iterator()
    iter_b = stream_b.make_one_shot_iterator()

    self.domain_A = iter_a.get_next()
    self.domain_B = iter_b.get_next()

    # ---------------- Encoder / Generator / Discriminator ----------------
    z_prior = tf.random_normal(shape=[self.batch_size, self.n_z],
                               mean=0.0, stddev=1.0, dtype=tf.float32)

    # encode
    c_a, attr_a, mu_a, log_var_a = self.Encoder_A(self.domain_A)
    c_b, attr_b, mu_b, log_var_b = self.Encoder_B(self.domain_B)

    # decode: cross-domain (fake), same-domain (identity), prior attribute (random)
    x_fake_a = self.Decoder_A(content_B=c_b, attribute_A=attr_a)
    x_fake_b = self.Decoder_B(content_A=c_a, attribute_B=attr_b)

    x_rec_a = self.Decoder_A(content_B=c_a, attribute_A=attr_a, reuse=True)
    x_rec_b = self.Decoder_B(content_A=c_b, attribute_B=attr_b, reuse=True)

    x_rand_a = self.Decoder_A(content_B=c_b, attribute_A=z_prior, reuse=True)
    x_rand_b = self.Decoder_B(content_A=c_a, attribute_B=z_prior, reuse=True)

    # encode and decode the fakes once more for cycle-consistency
    c_fake_a, attr_fake_a, _, _ = self.Encoder_A(x_fake_a, reuse=True)
    c_fake_b, attr_fake_b, _, _ = self.Encoder_B(x_fake_b, reuse=True)

    x_cyc_a = self.Decoder_A(content_B=c_fake_b, attribute_A=attr_fake_a, reuse=True)
    x_cyc_b = self.Decoder_B(content_A=c_fake_a, attribute_B=attr_fake_b, reuse=True)

    # latent regression: re-encode the prior-sampled fakes
    _, attr_rand_a, _, _ = self.Encoder_A(x_rand_a, random_fake=True, reuse=True)
    _, attr_rand_b, _, _ = self.Encoder_B(x_rand_b, random_fake=True, reuse=True)

    # discriminate
    logit_real_a, logit_real_b = self.discriminate_real(self.domain_A, self.domain_B)
    logit_fake_a, logit_fake_b = self.discriminate_fake(x_fake_a, x_fake_b)
    logit_rand_a, logit_rand_b = self.discriminate_fake(x_rand_a, x_rand_b)
    logit_content_a, logit_content_b = self.discriminate_content(c_a, c_b)

    # ---------------- Losses ----------------
    adv_a = generator_loss(self.gan_type, logit_fake_a) + \
        generator_loss(self.gan_type, logit_rand_a)
    adv_b = generator_loss(self.gan_type, logit_fake_b) + \
        generator_loss(self.gan_type, logit_rand_b)

    con_a = generator_loss(self.gan_type, logit_content_a, content=True)
    con_b = generator_loss(self.gan_type, logit_content_b, content=True)

    cyc_a = L1_loss(x_cyc_a, self.domain_A)
    cyc_b = L1_loss(x_cyc_b, self.domain_B)

    rec_a = L1_loss(x_rec_a, self.domain_A)
    rec_b = L1_loss(x_rec_b, self.domain_B)

    lat_a = L1_loss(attr_rand_a, z_prior)
    lat_b = L1_loss(attr_rand_b, z_prior)

    # With self.concat the attribute term uses kl_loss on (mean, logvar);
    # without it, an L2 regularizer on the attribute code is used instead.
    if not self.concat:
        kl_a = l2_regularize(attr_a) + l2_regularize(c_a)
        kl_b = l2_regularize(attr_b) + l2_regularize(c_b)
    else:
        kl_a = kl_loss(mu_a, log_var_a) + l2_regularize(c_a)
        kl_b = kl_loss(mu_b, log_var_b) + l2_regularize(c_b)

    d_adv_a = discriminator_loss(self.gan_type, logit_real_a, logit_fake_a, logit_rand_a)
    d_adv_b = discriminator_loss(self.gan_type, logit_real_b, logit_fake_b, logit_rand_b)
    d_con = discriminator_loss(self.gan_type, logit_content_a, logit_content_b, content=True)

    # Weighted generator terms. The A-side cycle term uses the B cycle loss
    # and vice versa, exactly as in the original.
    gen_a_domain = self.domain_adv_w * adv_a
    gen_a_content = self.content_adv_w * con_a
    gen_a_cycle = self.cycle_w * cyc_b
    gen_a_recon = self.recon_w * rec_a
    gen_a_latent = self.latent_w * lat_a
    gen_a_kl = self.kl_w * kl_a

    gen_a_total = (gen_a_domain + gen_a_content + gen_a_cycle +
                   gen_a_recon + gen_a_latent + gen_a_kl)

    gen_b_domain = self.domain_adv_w * adv_b
    gen_b_content = self.content_adv_w * con_b
    gen_b_cycle = self.cycle_w * cyc_a
    gen_b_recon = self.recon_w * rec_b
    gen_b_latent = self.latent_w * lat_b
    gen_b_kl = self.kl_w * kl_b

    gen_b_total = (gen_b_domain + gen_b_content + gen_b_cycle +
                   gen_b_recon + gen_b_latent + gen_b_kl)

    disc_a = self.domain_adv_w * d_adv_a
    disc_b = self.domain_adv_w * d_adv_b
    disc_content = self.content_adv_w * d_con

    self.Generator_loss = gen_a_total + gen_b_total
    self.Discriminator_loss = disc_a + disc_b
    self.Discriminator_content_loss = disc_content

    # ---------------- Training ----------------
    trainable = tf.trainable_variables()
    gen_vars = [v for v in trainable
                if 'encoder' in v.name or 'generator' in v.name]
    disc_vars = [v for v in trainable
                 if 'discriminator' in v.name and 'content' not in v.name]
    content_disc_vars = [v for v in trainable
                         if 'content_discriminator' in v.name]

    # Content-discriminator gradients are clipped by global norm before
    # being applied.
    raw_grads = tf.gradients(self.Discriminator_content_loss, content_disc_vars)
    clipped_grads, _ = tf.clip_by_global_norm(raw_grads, clip_norm=5)

    gen_adam = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
    self.G_optim = gen_adam.minimize(self.Generator_loss, var_list=gen_vars)

    disc_adam = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
    self.D_optim = disc_adam.minimize(self.Discriminator_loss, var_list=disc_vars)

    content_adam = tf.train.AdamOptimizer(self.d_content_init_lr, beta1=0.5, beta2=0.999)
    self.D_content_optim = content_adam.apply_gradients(
        zip(clipped_grads, content_disc_vars))

    # ---------------- Summaries ----------------
    self.lr_write = tf.summary.scalar("learning_rate", self.lr)

    self.all_G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss)
    self.all_D_loss = tf.summary.scalar("Discriminator_loss", self.Discriminator_loss)

    self.G_A_loss = tf.summary.scalar("G_A_loss", gen_a_total)
    self.G_A_domain_loss = tf.summary.scalar("G_A_domain_loss", gen_a_domain)
    self.G_A_content_loss = tf.summary.scalar("G_A_content_loss", gen_a_content)
    self.G_A_cycle_loss = tf.summary.scalar("G_A_cycle_loss", gen_a_cycle)
    self.G_A_recon_loss = tf.summary.scalar("G_A_recon_loss", gen_a_recon)
    self.G_A_latent_loss = tf.summary.scalar("G_A_latent_loss", gen_a_latent)
    self.G_A_kl_loss = tf.summary.scalar("G_A_kl_loss", gen_a_kl)

    self.G_B_loss = tf.summary.scalar("G_B_loss", gen_b_total)
    self.G_B_domain_loss = tf.summary.scalar("G_B_domain_loss", gen_b_domain)
    self.G_B_content_loss = tf.summary.scalar("G_B_content_loss", gen_b_content)
    self.G_B_cycle_loss = tf.summary.scalar("G_B_cycle_loss", gen_b_cycle)
    self.G_B_recon_loss = tf.summary.scalar("G_B_recon_loss", gen_b_recon)
    self.G_B_latent_loss = tf.summary.scalar("G_B_latent_loss", gen_b_latent)
    self.G_B_kl_loss = tf.summary.scalar("G_B_kl_loss", gen_b_kl)

    self.D_A_loss = tf.summary.scalar("D_A_loss", disc_a)
    self.D_B_loss = tf.summary.scalar("D_B_loss", disc_b)

    generator_summaries = [
        self.G_A_loss,
        self.G_A_domain_loss, self.G_A_content_loss,
        self.G_A_cycle_loss, self.G_A_recon_loss,
        self.G_A_latent_loss, self.G_A_kl_loss,
        self.G_B_loss,
        self.G_B_domain_loss, self.G_B_content_loss,
        self.G_B_cycle_loss, self.G_B_recon_loss,
        self.G_B_latent_loss, self.G_B_kl_loss,
        self.all_G_loss,
    ]
    self.G_loss = tf.summary.merge(generator_summaries)
    self.D_loss = tf.summary.merge([self.D_A_loss, self.D_B_loss, self.all_D_loss])
    self.D_content_loss = tf.summary.scalar("Discriminator_content_loss",
                                            self.Discriminator_content_loss)

    # ---------------- Images exposed for sampling ----------------
    self.fake_A = x_rand_a
    self.fake_B = x_rand_b

    self.real_A = self.domain_A
    self.real_B = self.domain_B

    # ---------------- Test head ----------------
    single_image_shape = [1, self.img_size, self.img_size, self.img_ch]

    self.test_image = tf.placeholder(tf.float32, single_image_shape, name='test_image')
    self.test_random_z = tf.random_normal(shape=[1, self.n_z],
                                          mean=0.0, stddev=1.0, dtype=tf.float32)

    test_c_a, _, _, _ = self.Encoder_A(self.test_image, is_training=False, reuse=True)
    test_c_b, _, _, _ = self.Encoder_B(self.test_image, is_training=False, reuse=True)

    self.test_fake_A = self.Decoder_A(content_B=test_c_b,
                                      attribute_A=self.test_random_z, reuse=True)
    self.test_fake_B = self.Decoder_B(content_A=test_c_a,
                                      attribute_B=self.test_random_z, reuse=True)

    # ---------------- Guided image translation ----------------
    self.content_image = tf.placeholder(tf.float32, single_image_shape,
                                        name='content_image')
    self.attribute_image = tf.placeholder(tf.float32, single_image_shape,
                                          name='guide_attribute_image')

    # Only the head for the configured direction is built.
    if self.direction == 'a2b':
        guided_c_a, _, _, _ = self.Encoder_A(self.content_image,
                                             is_training=False, reuse=True)
        _, guided_attr_b, _, _ = self.Encoder_B(self.attribute_image,
                                                is_training=False, reuse=True)
        self.guide_fake_B = self.Decoder_B(content_A=guided_c_a,
                                           attribute_B=guided_attr_b, reuse=True)
    else:
        guided_c_b, _, _, _ = self.Encoder_B(self.content_image,
                                             is_training=False, reuse=True)
        _, guided_attr_a, _, _ = self.Encoder_A(self.attribute_image,
                                                is_training=False, reuse=True)
        self.guide_fake_A = self.Decoder_A(content_B=guided_c_b,
                                           attribute_A=guided_attr_a, reuse=True)
def build_model(self):
    """Build the training graph or the test graph, selected by ``self.phase``.

    When ``self.phase == 'train'``: builds the two input pipelines, runs
    both generators (translation, cycle and identity passes) and both
    discriminators, combines the adversarial / cycle / identity / CAM
    losses, and creates the ``self.G_optim`` / ``self.D_optim`` Adam steps.
    Otherwise: creates the ``test_domain_A`` / ``test_domain_B``
    placeholders and the ``test_fake_B`` / ``test_fake_A`` outputs.

    Summary ops are not created here (they were commented out in the
    original), so ``self.G_loss`` / ``self.D_loss`` are not defined by
    this method.
    """
    if self.phase == 'train':
        # Learning rate, fed at run time.
        self.lr = tf.placeholder(tf.float32, name='learning_rate')

        # ---- Input images ----
        Image_Data_Class = ImageData(self.img_size, self.img_ch, self.augment_flag)

        # Original note: slices the lists of image paths into datasets.
        trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
        trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)

        gpu_device = '/gpu:0'
        # Per-domain pipeline:
        #   shuffle_and_repeat(100) - shuffle the image order and repeat;
        #   map_and_batch(fn, batch_size, num_parallel_batches, drop_remainder)
        #     - apply the preprocessing function, group batch_size
        #       consecutive elements, build 16 batches in parallel, and
        #       drop a final batch that is smaller than batch_size;
        #   prefetch_to_device - stage batches on the GPU.
        # Original note: each batch is a 32*64*64*3 randomly augmented array
        # (presumably for batch_size=32, img_size=64, img_ch=3 - confirm).
        trainA = trainA.apply(shuffle_and_repeat(100)).apply(
            map_and_batch(Image_Data_Class.image_processing, self.batch_size,
                          num_parallel_batches=16, drop_remainder=True)).apply(
            prefetch_to_device(gpu_device, None))
        trainB = trainB.apply(shuffle_and_repeat(100)).apply(
            map_and_batch(Image_Data_Class.image_processing, self.batch_size,
                          num_parallel_batches=16, drop_remainder=True)).apply(
            prefetch_to_device(gpu_device, None))

        trainA_iterator = trainA.make_one_shot_iterator()
        trainB_iterator = trainB.make_one_shot_iterator()

        self.domain_A = trainA_iterator.get_next()
        self.domain_B = trainB_iterator.get_next()

        # ---- Define Generator, Discriminator ----
        # Original notes (not verifiable from this method): domain_A holds
        # cartoon images and domain_B real-person photos; generate_a2b and
        # generate_b2a use two separate parameter sets and can be viewed as
        # inverse transforms of each other.
        x_ab, cam_ab = self.generate_a2b(self.domain_A)  # real a
        x_ba, cam_ba = self.generate_b2a(self.domain_B)  # real b

        # Cycle passes: feed each translated image through the opposite
        # generator, reusing its variables.
        x_aba, _ = self.generate_b2a(x_ab, reuse=True)  # real b
        x_bab, _ = self.generate_a2b(x_ba, reuse=True)  # real a

        # Identity passes: feed each generator an image that is already in
        # its output domain. Original note: keeps colour regions unchanged.
        x_aa, cam_aa = self.generate_b2a(self.domain_A, reuse=True)  # fake b
        x_bb, cam_bb = self.generate_a2b(self.domain_B, reuse=True)  # fake a

        # Discriminate the real images of both domains.
        real_A_logit, real_A_cam_logit, real_B_logit, real_B_cam_logit = \
            self.discriminate_real(self.domain_A, self.domain_B)
        # Discriminate the generated images of both domains.
        # Original note: outputs are a 32*8*8*1 conv map plus a (32, 2)
        # concatenation of pooled results - TODO confirm against
        # discriminate_fake.
        fake_A_logit, fake_A_cam_logit, fake_B_logit, fake_B_cam_logit = \
            self.discriminate_fake(x_ba, x_ab)

        # ---- Define Loss ----
        # Gradient penalties are only used for gan types containing 'wgan'
        # and for 'dragan'.
        if 'wgan' in self.gan_type or self.gan_type == 'dragan':
            GP_A, GP_CAM_A = self.gradient_panalty(
                real=self.domain_A, fake=x_ba, scope="discriminator_A")
            GP_B, GP_CAM_B = self.gradient_panalty(
                real=self.domain_B, fake=x_ab, scope="discriminator_B")
        else:
            GP_A, GP_CAM_A = 0, 0
            GP_B, GP_CAM_B = 0, 0

        # 1. Adversarial loss.
        # Original note: a least-squares loss is used (depends on
        # self.gan_type - confirm).
        # Generator side: logit term plus CAM-logit term per domain.
        G_ad_loss_A = (generator_loss(self.gan_type, fake_A_logit) +
                       generator_loss(self.gan_type, fake_A_cam_logit))
        G_ad_loss_B = (generator_loss(self.gan_type, fake_B_logit) +
                       generator_loss(self.gan_type, fake_B_cam_logit))

        # Discriminator side: logit term, CAM-logit term and both penalties.
        D_ad_loss_A = (
            discriminator_loss(self.gan_type, real_A_logit, fake_A_logit) +
            discriminator_loss(self.gan_type, real_A_cam_logit, fake_A_cam_logit) +
            GP_A + GP_CAM_A)
        D_ad_loss_B = (
            discriminator_loss(self.gan_type, real_B_logit, fake_B_logit) +
            discriminator_loss(self.gan_type, real_B_cam_logit, fake_B_cam_logit) +
            GP_B + GP_CAM_B)

        # 2. Cycle loss: an image should translate back to its original
        # domain successfully.
        reconstruction_A = L1_loss(x_aba, self.domain_A)  # reconstruction
        reconstruction_B = L1_loss(x_bab, self.domain_B)  # reconstruction

        # 3. Identity loss.
        # Original note: keeps identity information across A <-> B.
        identity_A = L1_loss(x_aa, self.domain_A)
        identity_B = L1_loss(x_bb, self.domain_B)

        # 4. CAM loss.
        # Original note: the auxiliary classifier tells G and D where to
        # transform most; the heat map should be pronounced for a
        # cross-domain input and absent for a same-domain input.
        cam_A = cam_loss(source=cam_ba, non_source=cam_aa)
        cam_B = cam_loss(source=cam_ab, non_source=cam_bb)

        # Weighted generator terms. The trailing numbers are the weight
        # values recorded in the original comments.
        # TODO(review): the original author asks how these initial weights
        # were chosen.
        # The A-side cycle term uses reconstruction_B and vice versa, as in
        # the original.
        Generator_A_gan = self.adv_weight * G_ad_loss_A  # 1
        Generator_A_cycle = self.cycle_weight * reconstruction_B  # 10
        Generator_A_identity = self.identity_weight * identity_A  # 10
        Generator_A_cam = self.cam_weight * cam_A  # 1000

        Generator_B_gan = self.adv_weight * G_ad_loss_B
        Generator_B_cycle = self.cycle_weight * reconstruction_A
        Generator_B_identity = self.identity_weight * identity_B
        Generator_B_cam = self.cam_weight * cam_B

        Generator_A_loss = (Generator_A_gan + Generator_A_cycle +
                            Generator_A_identity + Generator_A_cam)
        Generator_B_loss = (Generator_B_gan + Generator_B_cycle +
                            Generator_B_identity + Generator_B_cam)

        Discriminator_A_loss = self.adv_weight * D_ad_loss_A
        Discriminator_B_loss = self.adv_weight * D_ad_loss_B

        # Total losses, including the per-scope regularization terms.
        self.Generator_loss = (Generator_A_loss + Generator_B_loss +
                               regularization_loss('generator'))
        self.Discriminator_loss = (Discriminator_A_loss + Discriminator_B_loss +
                                   regularization_loss('discriminator'))

        # ---- Result Image ----
        # Generated images (kept for saving).
        self.fake_A = x_ba
        self.fake_B = x_ab

        # Input real images.
        self.real_A = self.domain_A
        self.real_B = self.domain_B

        # NOTE(review): imgba / imgab are never assigned in this method.
        # Unless they are module-level globals defined elsewhere, these two
        # lines raise NameError - confirm where they are meant to come from.
        self.imgba = imgba
        self.imgab = imgab

        # ---- Training ----
        t_vars = tf.trainable_variables()
        G_vars = [var for var in t_vars if 'generator' in var.name]
        D_vars = [var for var in t_vars if 'discriminator' in var.name]

        # var_list restricts each optimizer to the variables it updates.
        self.G_optim = tf.train.AdamOptimizer(
            self.lr, beta1=0.5, beta2=0.999).minimize(
                self.Generator_loss, var_list=G_vars)
        self.D_optim = tf.train.AdamOptimizer(
            self.lr, beta1=0.5, beta2=0.999).minimize(
                self.Discriminator_loss, var_list=D_vars)

        # ---- Summary ----
        # The per-term scalar summaries, the 'rho' variable histograms and
        # the merged self.G_loss / self.D_loss summaries are disabled here;
        # nothing in this method defines them.
    else:
        # ---- Test ----
        self.test_domain_A = tf.placeholder(
            tf.float32, [1, self.img_size, self.img_size, self.img_ch],
            name='test_domain_A')
        self.test_domain_B = tf.placeholder(
            tf.float32, [1, self.img_size, self.img_size, self.img_ch],
            name='test_domain_B')

        self.test_fake_B, _ = self.generate_a2b(self.test_domain_A)
        self.test_fake_A, _ = self.generate_b2a(self.test_domain_B)
def build_model(self):
    """Build the training graph, the sampling tensor and the loss summaries.

    Real images come from a ``tf.data`` pipeline when
    ``self.custom_dataset`` is truthy and from the ``real_images``
    placeholder otherwise; the latent batch is always the ``z``
    placeholder. Stores on ``self``: ``inputs``, ``z``, ``d_loss``,
    ``g_loss``, ``d_optim``, ``g_optim``, ``fake_images``, ``d_sum`` and
    ``g_sum``.
    """
    # ---- Graph input ----
    # images
    if self.custom_dataset:
        Image_Data_Class = ImageData(self.img_size, self.c_dim)
        inputs = tf.data.Dataset.from_tensor_slices(self.data)

        gpu_device = '/gpu:0'
        # shuffle/repeat -> preprocess + batch -> prefetch onto the GPU
        inputs = inputs.apply(shuffle_and_repeat(self.dataset_num)).apply(
            map_and_batch(Image_Data_Class.image_processing, self.batch_size,
                          num_parallel_batches=8, drop_remainder=True)).apply(
            prefetch_to_device(gpu_device, self.batch_size))

        inputs_iterator = inputs.make_one_shot_iterator()

        self.inputs = inputs_iterator.get_next()
    else:
        # No pipeline: the caller feeds a full batch of real images.
        self.inputs = tf.placeholder(
            tf.float32,
            [self.batch_size, self.img_size, self.img_size, self.c_dim],
            name='real_images')

    # noises
    self.z = tf.placeholder(tf.float32, [self.batch_size, self.z_dim], name='z')

    # ---- Loss function ----
    # output of D for real images
    real_logits = self.discriminator(self.inputs)

    # output of D for fake images (second call shares the D variables)
    fake_images = self.generator(self.z)
    fake_logits = self.discriminator(fake_images, reuse=True)

    # The gradient penalty is only used for gan types containing 'wgan'
    # and for 'dragan'; every other type contributes 0.
    if 'wgan' in self.gan_type or self.gan_type == 'dragan':
        GP = self.gradient_penalty(real=self.inputs, fake=fake_images)
    else:
        GP = 0

    # get loss for discriminator
    self.d_loss = discriminator_loss(
        self.gan_type, real=real_logits, fake=fake_logits) + GP

    # get loss for generator
    self.g_loss = generator_loss(self.gan_type, fake=fake_logits)

    # ---- Training ----
    # divide trainable variables into a group for D and a group for G
    t_vars = tf.trainable_variables()
    d_vars = [var for var in t_vars if 'discriminator' in var.name]
    g_vars = [var for var in t_vars if 'generator' in var.name]

    # optimizers
    self.d_optim = tf.train.AdamOptimizer(
        self.d_learning_rate, beta1=self.beta1,
        beta2=self.beta2).minimize(self.d_loss, var_list=d_vars)
    self.g_optim = tf.train.AdamOptimizer(
        self.g_learning_rate, beta1=self.beta1,
        beta2=self.beta2).minimize(self.g_loss, var_list=g_vars)

    # ---- Testing ----
    # for test: same generator variables, built with is_training=False
    self.fake_images = self.generator(self.z, is_training=False, reuse=True)

    # ---- Summary ----
    self.d_sum = tf.summary.scalar("d_loss", self.d_loss)
    self.g_sum = tf.summary.scalar("g_loss", self.g_loss)
def build_model(self):
    """Build the full graph: inputs, networks, losses, optimizers, summaries.

    Three input pipelines are created (domain A, domain B and the
    "smooth" variant of domain B). The generator maps ``real_A`` to
    ``fake_B``; the discriminator scores ``real_B``, ``fake_B`` and
    ``real_B_smooth``. Stores on ``self``: ``lr``, ``real_A``, ``real_B``,
    ``real_B_smooth``, ``test_real_A``, ``fake_B``, ``test_fake_B``,
    ``Vgg_loss``, ``Generator_loss``, ``Discriminator_loss``,
    ``init_optim``, ``G_optim``, ``D_optim`` and the summary ops.
    """
    # Learning rate, fed at run time.
    self.lr = tf.placeholder(tf.float32, name='learning_rate')

    # ---- Input images ----
    Image_Data_Class = ImageData(self.img_size, self.img_ch, self.augment_flag)

    trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
    trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)
    trainB_smooth = tf.data.Dataset.from_tensor_slices(self.trainB_smooth_dataset)

    gpu_device = '/gpu:0'
    # Each stream: shuffle/repeat -> preprocess + batch -> prefetch onto the GPU
    trainA = trainA.apply(shuffle_and_repeat(self.dataset_num)).apply(
        map_and_batch(Image_Data_Class.image_processing, self.batch_size,
                      num_parallel_batches=16, drop_remainder=True)).apply(
        prefetch_to_device(gpu_device, self.batch_size))
    trainB = trainB.apply(shuffle_and_repeat(self.dataset_num)).apply(
        map_and_batch(Image_Data_Class.image_processing, self.batch_size,
                      num_parallel_batches=16, drop_remainder=True)).apply(
        prefetch_to_device(gpu_device, self.batch_size))
    trainB_smooth = trainB_smooth.apply(shuffle_and_repeat(self.dataset_num)).apply(
        map_and_batch(Image_Data_Class.image_processing, self.batch_size,
                      num_parallel_batches=16, drop_remainder=True)).apply(
        prefetch_to_device(gpu_device, self.batch_size))

    trainA_iterator = trainA.make_one_shot_iterator()
    trainB_iterator = trainB.make_one_shot_iterator()
    trainB_smooth_iterator = trainB_smooth.make_one_shot_iterator()

    self.real_A = trainA_iterator.get_next()
    self.real_B = trainB_iterator.get_next()
    self.real_B_smooth = trainB_smooth_iterator.get_next()

    # Single-image placeholder used by the test head below.
    self.test_real_A = tf.placeholder(
        tf.float32, [1, self.img_size, self.img_size, self.img_ch],
        name='test_real_A')

    # ---- Define Generator, Discriminator ----
    self.fake_B = self.generator(self.real_A)

    # The first discriminator call creates the variables; the rest reuse them.
    real_B_logit = self.discriminator(self.real_B)
    fake_B_logit = self.discriminator(self.fake_B, reuse=True)
    real_B_smooth_logit = self.discriminator(self.real_B_smooth, reuse=True)

    # ---- Define Loss ----
    # A penalty is used when the gan type name contains 'gp', 'lp' or
    # 'dragan'. It is the sum of two terms: real_B vs. fake_B and
    # real_B vs. real_B_smooth.
    if any(tag in self.gan_type for tag in ('gp', 'lp', 'dragan')):
        GP = self.gradient_panalty(real=self.real_B, fake=self.fake_B) + \
            self.gradient_panalty(self.real_B, fake=self.real_B_smooth)
    else:
        GP = 0.0

    v_loss = self.vgg_weight * vgg_loss(self.real_A, self.fake_B)
    g_loss = self.adv_weight * generator_loss(self.gan_type, fake_B_logit)
    # GP is added after, and is not scaled by, adv_weight.
    d_loss = self.adv_weight * discriminator_loss(
        self.gan_type, real_B_logit, fake_B_logit, real_B_smooth_logit) + GP

    self.Vgg_loss = v_loss
    self.Generator_loss = g_loss + v_loss
    self.Discriminator_loss = d_loss

    # ---- Result Image ----
    self.test_fake_B = self.generator(self.test_real_A, reuse=True)

    # ---- Training ----
    t_vars = tf.trainable_variables()
    G_vars = [var for var in t_vars if 'generator' in var.name]
    D_vars = [var for var in t_vars if 'discriminator' in var.name]

    # init_optim updates the generator on the VGG loss alone; G_optim uses
    # the combined generator loss.
    self.init_optim = tf.train.AdamOptimizer(
        self.lr, beta1=0.5, beta2=0.999).minimize(self.Vgg_loss, var_list=G_vars)
    self.G_optim = tf.train.AdamOptimizer(
        self.lr, beta1=0.5, beta2=0.999).minimize(self.Generator_loss, var_list=G_vars)
    self.D_optim = tf.train.AdamOptimizer(
        self.lr, beta1=0.5, beta2=0.999).minimize(self.Discriminator_loss, var_list=D_vars)

    # ---- Summary ----
    self.G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss)
    self.D_loss = tf.summary.scalar("Discriminator_loss", self.Discriminator_loss)

    self.G_gan = tf.summary.scalar("G_gan", g_loss)
    self.G_vgg = tf.summary.scalar("G_vgg", v_loss)

    self.V_loss_merge = tf.summary.merge([self.G_vgg])
    self.G_loss_merge = tf.summary.merge([self.G_loss, self.G_gan, self.G_vgg])
    self.D_loss_merge = tf.summary.merge([self.D_loss])
def build_model(self):
    """Assemble the training graph, the sampling tensor and the summaries.

    Real images are either fed through the ``real_images`` placeholder or
    read from a ``tf.data`` pipeline (when ``self.custom_dataset`` is
    truthy). Real and generated images are encoded, the two encodings are
    subtracted in both directions, and the discriminator scores each
    difference. Stores on ``self``: ``inputs``, ``z``, ``d_loss``,
    ``g_loss``, ``d_optim``, ``g_optim``, ``fake_images``, ``d_sum`` and
    ``g_sum``.
    """
    batch_size = self.batch_size

    # -------- Graph input --------
    if not self.custom_dataset:
        # Caller feeds a full batch of real images.
        self.inputs = tf.placeholder(
            tf.float32,
            [batch_size, self.output_height, self.output_height, self.c_dim],
            name='real_images')
    else:
        loader = ImageData(self.output_height, self.c_dim)
        stream = tf.data.Dataset.from_tensor_slices(self.data)

        device = '/gpu:0'
        # shuffle/repeat -> preprocess + batch -> prefetch onto `device`
        stream = stream.apply(shuffle_and_repeat(self.dataset_num))
        stream = stream.apply(
            map_and_batch(loader.image_processing, batch_size,
                          num_parallel_batches=8, drop_remainder=True))
        stream = stream.apply(prefetch_to_device(device, batch_size))

        self.inputs = stream.make_one_shot_iterator().get_next()

    # latent noise, fed by the caller
    self.z = tf.placeholder(tf.float32, [batch_size, self.z_dim], name='z')

    # -------- Loss function --------
    generated = self.generator(self.z, is_training=True, reuse=False)

    # The first encoder call creates the variables; the second reuses them.
    enc_real = self.encoder(self.inputs, is_training=True, reuse=False, sn=True)
    enc_fake = self.encoder(generated, is_training=True, reuse=True, sn=True)

    # Encoding differences in both directions.
    diff_real_minus_fake = tf.subtract(enc_real, enc_fake)
    diff_fake_minus_real = tf.subtract(enc_fake, enc_real)

    score_real_minus_fake = self.discriminator(diff_real_minus_fake, reuse=False, sn=True)
    score_fake_minus_real = self.discriminator(diff_fake_minus_real, reuse=True, sn=True)

    # discriminator loss
    self.d_loss = discriminator_loss(self.loss_type,
                                     real=score_real_minus_fake,
                                     fake=score_fake_minus_real)

    # generator loss
    self.g_loss = generator_loss(self.loss_type,
                                 real=score_real_minus_fake,
                                 fake=score_fake_minus_real)

    # -------- Training --------
    # The encoder variables are trained together with the discriminator.
    trainable = tf.trainable_variables()
    disc_side_vars = [v for v in trainable
                      if 'discriminator' in v.name or 'encoder' in v.name]
    gen_side_vars = [v for v in trainable if 'generator' in v.name]

    # Every op registered in UPDATE_OPS so far must run before either step.
    pending_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(pending_updates):
        disc_adam = tf.train.AdamOptimizer(self.learning_rate, beta1=self.beta1)
        self.d_optim = disc_adam.minimize(self.d_loss, var_list=disc_side_vars)

        # The generator step uses five times the base learning rate.
        gen_adam = tf.train.AdamOptimizer(self.learning_rate * 5, beta1=self.beta1)
        self.g_optim = gen_adam.minimize(self.g_loss, var_list=gen_side_vars)

    # -------- Testing --------
    # Same generator variables, built with is_training=False.
    self.fake_images = self.generator(self.z, is_training=False, reuse=True)

    # -------- Summary --------
    self.d_sum = tf.summary.scalar("d_loss", self.d_loss)
    self.g_sum = tf.summary.scalar("g_loss", self.g_loss)
def build_model(self):
    """Build the two-domain translation graph.

    Creates the input pipelines for domains A and B, the encoder/decoder
    passes (cross-domain, within-domain and cyclic), the adversarial,
    reconstruction and whitening/coloring losses, both optimizers, the
    TensorBoard summaries, and finally the guided-translation test graph.
    Everything is stored on ``self``; nothing is returned.
    """
    self.lr = tf.placeholder(tf.float32, name='learning_rate')

    # ---- Input pipelines ----
    preprocessor = ImageData(self.img_h, self.img_w, self.img_ch, self.augment_flag)

    def _batched(dataset):
        """Shuffle/repeat, preprocess into batches, then prefetch to GPU 0."""
        dataset = dataset.apply(shuffle_and_repeat(self.dataset_num))
        dataset = dataset.apply(
            map_and_batch(preprocessor.image_processing,
                          self.batch_size,
                          num_parallel_batches=16,
                          drop_remainder=True))
        # buffer_size=None lets the prefetch stage pick its own buffer size.
        return dataset.apply(prefetch_to_device('/gpu:0', None))

    source_a = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
    source_b = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)
    stream_a = _batched(source_a)
    stream_b = _batched(source_b)

    iterator_a = stream_a.make_one_shot_iterator()
    iterator_b = stream_b.make_one_shot_iterator()
    self.domain_A = iterator_a.get_next()
    self.domain_B = iterator_b.get_next()

    # ---- Encoders / decoders / discriminators ----
    # First pass creates the variables; every later call passes reuse=True.
    content_a, style_a = self.encoder_A(self.domain_A)
    content_b, style_b = self.encoder_B(self.domain_B)

    # Cross-domain decoding; the second return value feeds the
    # whitening/coloring regularizer below.
    x_ba, u_a = self.decoder_A(content_B=content_b, style_A=style_a)
    x_ab, u_b = self.decoder_B(content_A=content_a, style_B=style_b)

    # Within-domain decoding (plain reconstruction).
    x_aa, _ = self.decoder_A(content_B=content_a, style_A=style_a, reuse=True)
    x_bb, _ = self.decoder_B(content_A=content_b, style_B=style_b, reuse=True)

    # Re-encode the translated images.
    content_ba, style_ba = self.encoder_A(x_ba, reuse=True)
    content_ab, style_ab = self.encoder_B(x_ab, reuse=True)

    # Decode once more to close the cycle.
    x_aba, _ = self.decoder_A(content_B=content_ab, style_A=style_ba, reuse=True)
    x_bab, _ = self.decoder_B(content_A=content_ba, style_B=style_ab, reuse=True)

    real_a_logit, real_b_logit = self.discriminate_real(self.domain_A, self.domain_B)
    fake_a_logit, fake_b_logit = self.discriminate_fake(x_ba, x_ab)

    # ---- Losses ----
    g_adv_a = self.gan_w * generator_loss(self.gan_type, fake_a_logit)
    g_adv_b = self.gan_w * generator_loss(self.gan_type, fake_b_logit)

    d_adv_a = self.gan_w * discriminator_loss(self.gan_type, real_a_logit, fake_a_logit)
    d_adv_b = self.gan_w * discriminator_loss(self.gan_type, real_b_logit, fake_b_logit)

    style_recon_a = self.recon_s_w * L1_loss(style_ba, style_a)
    style_recon_b = self.recon_s_w * L1_loss(style_ab, style_b)

    content_recon_a = self.recon_c_w * L1_loss(content_ab, content_a)
    content_recon_b = self.recon_c_w * L1_loss(content_ba, content_b)

    cycle_a = self.recon_x_cyc_w * L1_loss(x_aba, self.domain_A)
    cycle_b = self.recon_x_cyc_w * L1_loss(x_bab, self.domain_B)

    identity_a = self.recon_x_w * L1_loss(x_aa, self.domain_A)
    identity_b = self.recon_x_w * L1_loss(x_bb, self.domain_B)

    whiten_a, color_a = group_wise_regularization(
        deep_whitening_transform(content_a), u_a, self.group_num)
    whiten_b, color_b = group_wise_regularization(
        deep_whitening_transform(content_b), u_b, self.group_num)

    whiten_a = self.lambda_w * whiten_a
    whiten_b = self.lambda_w * whiten_b
    color_a = self.lambda_c * color_a
    color_b = self.lambda_c * color_b

    g_reg_a = regularization_loss('decoder_A') + regularization_loss('encoder_A')
    g_reg_b = regularization_loss('decoder_B') + regularization_loss('encoder_B')
    d_reg_a = regularization_loss('discriminator_A')
    d_reg_b = regularization_loss('discriminator_B')

    # Each generator total takes the *other* domain's cycle term.
    g_total_a = (g_adv_a + identity_a + style_recon_a + content_recon_a
                 + cycle_b + whiten_a + color_a + g_reg_a)
    g_total_b = (g_adv_b + identity_b + style_recon_b + content_recon_b
                 + cycle_a + whiten_b + color_b + g_reg_b)

    d_total_a = d_adv_a + d_reg_a
    d_total_b = d_adv_b + d_reg_b

    self.Generator_loss = g_total_a + g_total_b
    self.Discriminator_loss = d_total_a + d_total_b

    # ---- Training ops ----
    trainable = tf.trainable_variables()
    generator_vars = [
        v for v in trainable
        if any(tag in v.name for tag in ('decoder', 'encoder'))
    ]
    discriminator_vars = [v for v in trainable if 'discriminator' in v.name]

    g_optimizer = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
    self.G_optim = g_optimizer.minimize(self.Generator_loss, var_list=generator_vars)
    d_optimizer = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
    self.D_optim = d_optimizer.minimize(self.Discriminator_loss, var_list=discriminator_vars)

    # ---- Summaries ----
    scalar = tf.summary.scalar

    self.all_G_loss = scalar("Generator_loss", self.Generator_loss)
    self.all_D_loss = scalar("Discriminator_loss", self.Discriminator_loss)

    self.G_A_loss = scalar("G_A_loss", g_total_a)
    self.G_B_loss = scalar("G_B_loss", g_total_b)
    self.D_A_loss = scalar("D_A_loss", d_total_a)
    self.D_B_loss = scalar("D_B_loss", d_total_b)

    # NOTE(review): G_A_cyc_loss logs cycle_a although g_total_a sums
    # cycle_b (and vice versa for B) -- confirm this pairing is intended.
    self.G_A_adv_loss = scalar("G_A_adv_loss", g_adv_a)
    self.G_A_style_loss = scalar("G_A_style_loss", style_recon_a)
    self.G_A_content_loss = scalar("G_A_content_loss", content_recon_a)
    self.G_A_cyc_loss = scalar("G_A_cyc_loss", cycle_a)
    self.G_A_identity_loss = scalar("G_A_identity_loss", identity_a)
    self.G_A_whitening_loss = scalar("G_A_whitening_loss", whiten_a)
    self.G_A_coloring_loss = scalar("G_A_coloring_loss", color_a)

    self.G_B_adv_loss = scalar("G_B_adv_loss", g_adv_b)
    self.G_B_style_loss = scalar("G_B_style_loss", style_recon_b)
    self.G_B_content_loss = scalar("G_B_content_loss", content_recon_b)
    self.G_B_cyc_loss = scalar("G_B_cyc_loss", cycle_b)
    self.G_B_identity_loss = scalar("G_B_identity_loss", identity_b)
    self.G_B_whitening_loss = scalar("G_B_whitening_loss", whiten_b)
    self.G_B_coloring_loss = scalar("G_B_coloring_loss", color_b)

    # A histogram plus a max-value scalar for every trainable variable
    # whose name contains 'alpha'.
    self.alpha_var = []
    alpha_summaries = self.alpha_var
    for variable in tf.trainable_variables():
        if 'alpha' not in variable.name:
            continue
        alpha_summaries.append(tf.summary.histogram(variable.name, variable))
        alpha_summaries.append(scalar(variable.name, tf.reduce_max(variable)))

    generator_summaries = [
        self.G_A_adv_loss,
        self.G_A_style_loss,
        self.G_A_content_loss,
        self.G_A_cyc_loss,
        self.G_A_identity_loss,
        self.G_A_whitening_loss,
        self.G_A_coloring_loss,
        self.G_A_loss,
        self.G_B_adv_loss,
        self.G_B_style_loss,
        self.G_B_content_loss,
        self.G_B_cyc_loss,
        self.G_B_identity_loss,
        self.G_B_whitening_loss,
        self.G_B_coloring_loss,
        self.G_B_loss,
        self.all_G_loss,
    ] + self.alpha_var

    self.G_loss = tf.summary.merge(generator_summaries)
    self.D_loss = tf.summary.merge([self.D_A_loss, self.D_B_loss, self.all_D_loss])

    # ---- Images exposed for visualisation ----
    self.fake_A = x_ba
    self.fake_B = x_ab
    self.real_A = self.domain_A
    self.real_B = self.domain_B

    # ---- Test graph: guided image translation ----
    single_image_shape = [1, self.img_h, self.img_w, self.img_ch]
    self.content_image = tf.placeholder(tf.float32, single_image_shape, name='content_image')
    self.style_image = tf.placeholder(tf.float32, single_image_shape, name='guide_style_image')

    if self.direction == 'a2b':
        # Content comes from an A image, style from a B guide image.
        guide_content, _ = self.encoder_A(self.content_image, reuse=True)
        _, guide_style = self.encoder_B(self.style_image, reuse=True)
        self.guide_fake_B, _ = self.decoder_B(content_A=guide_content,
                                              style_B=guide_style,
                                              reuse=True)
    else:
        # Any other direction value: content from B, style from an A guide.
        guide_content, _ = self.encoder_B(self.content_image, reuse=True)
        _, guide_style = self.encoder_A(self.style_image, reuse=True)
        self.guide_fake_A, _ = self.decoder_A(content_B=guide_content,
                                              style_A=guide_style,
                                              reuse=True)
def build_model(self):
    """Build the training graph or, for any other phase, the test graph.

    In the ``'train'`` phase this creates the learning-rate placeholder,
    the image/caption input pipeline, the text encoder, generator and
    discriminator passes, the generator/discriminator losses, both
    optimizers and the merged summaries.  Otherwise it only builds the
    input pipeline, text encoder and generator with ``is_training=False``
    and exposes ``test_real_img`` / ``test_fake_img``.  Results are stored
    on ``self``; nothing is returned.
    """
    data_source = Image_data(self.img_height, self.img_width, self.img_ch,
                             self.dataset_path, self.augment_flag)
    (train_captions, train_images,
     test_captions, test_images,
     idx_to_word, _word_to_idx) = data_source.preprocess()
    # Sizes noted by the original author for their dataset:
    #   train_captions: (8855, 10, 18), test_captions: (2933, 10, 18)
    #   train_images: (8855,), test_images: (2933,)
    #   idx_to_word : 5450 5450

    # Text-encoder settings shared by the train and test graphs.
    encoder_options = dict(n_words=len(idx_to_word),
                           embed_dim=self.embed_dim,
                           drop_rate=0.5,
                           n_hidden=128,
                           n_layers=1,
                           bidirectional=True,
                           rnn_type='lstm')

    def _next_batch(images, captions):
        """Return (image batch, one randomly chosen caption per image)."""
        dataset = tf.data.Dataset.from_tensor_slices((images, captions))
        dataset = dataset.apply(shuffle_and_repeat(self.dataset_num))
        dataset = dataset.apply(
            map_and_batch(data_source.image_processing,
                          batch_size=self.batch_size,
                          num_parallel_batches=16,
                          drop_remainder=True))
        dataset = dataset.apply(prefetch_to_device('/gpu:0', None))
        image_batch, caption_sets = dataset.make_one_shot_iterator().get_next()

        # One random index in [0, 10) is drawn and used to pick a caption
        # along axis 1 for the whole batch.
        sentence_index = tf.random_uniform(shape=[], minval=0, maxval=10, dtype=tf.int32)
        return image_batch, tf.gather(caption_sets, sentence_index, axis=1)

    # ---- Test graph ----
    if self.phase != 'train':
        self.dataset_num = len(test_captions)
        real_img_256, caption = _next_batch(test_images, test_captions)

        word_emb, sent_emb, mask = self.rnn_encoder(caption, is_training=False,
                                                    **encoder_options)

        noise = tf.random_normal(shape=[self.batch_size, self.z_dim], mean=0.0, stddev=1.0)
        fake_imgs, _, _, _ = self.generator(noise, sent_emb, word_emb, mask,
                                            is_training=False)

        self.test_real_img = real_img_256
        self.test_fake_img = fake_imgs[2]
        return

    # ---- Training graph ----
    self.lr = tf.placeholder(tf.float32, name='learning_rate')
    self.dataset_num = len(train_images)
    real_img_256, caption = _next_batch(train_images, train_captions)

    word_emb, sent_emb, mask = self.rnn_encoder(caption, **encoder_options)

    noise = tf.random_normal(shape=[self.batch_size, self.z_dim], mean=0.0, stddev=1.0)
    fake_imgs, _, mu, logvar = self.generator(noise, sent_emb, word_emb, mask)

    # Real images at three scales, matching the three generator outputs.
    real_img_64 = resize(real_img_256, target_size=[64, 64])
    real_img_128 = resize(real_img_256, target_size=[128, 128])
    fake_img_64 = fake_imgs[0]
    fake_img_128 = fake_imgs[1]
    fake_img_256 = fake_imgs[2]
    real_imgs = [real_img_64, real_img_128, real_img_256]

    uncond_real_logits, cond_real_logits, real_emb_features = self.discriminator(
        [real_img_64, real_img_128, real_img_256], sent_emb)

    # "Wrong" pairs: images 0..N-2 are paired with the sentence embeddings
    # of images 1..N-1, i.e. each image gets its neighbour's caption.
    last = self.batch_size - 1
    mismatched_imgs = [real_img_64[:last], real_img_128[:last], real_img_256[:last]]
    _, cond_wrong_logits, _ = self.discriminator(mismatched_imgs,
                                                 sent_emb[1:self.batch_size])

    uncond_fake_logits, cond_fake_logits, _ = self.discriminator(
        [fake_img_64, fake_img_128, fake_img_256], sent_emb)

    # Accumulate the per-scale loss terms over the three resolutions.
    g_adv_total, d_adv_total = 0, 0
    g_vgg_total = 0
    d_word_total = 0
    for i in range(3):
        g_adv_total += self.adv_weight * (
            generator_loss(self.gan_type, uncond_fake_logits[i])
            + generator_loss(self.gan_type, cond_fake_logits[i]))

        # NOTE(review): class_vgg_16 is not defined inside this function;
        # presumably a module-level object -- confirm it is in scope.
        g_vgg_total += self.vgg_weight * (
            vgg16_perceptual_loss(real_imgs[i], fake_imgs[i], class_vgg_16))

        uncond_real_loss, uncond_fake_loss = discriminator_loss(
            self.gan_type, uncond_real_logits[i], uncond_fake_logits[i])
        cond_real_loss, cond_fake_loss = discriminator_loss(
            self.gan_type, cond_real_logits[i], cond_fake_logits[i])
        _, cond_wrong_loss = discriminator_loss(
            self.gan_type, None, cond_wrong_logits[i])

        # Mean of the two "real" terms plus mean of the three "fake" terms.
        real_term = (uncond_real_loss + cond_real_loss) / 2
        fake_term = (uncond_fake_loss + cond_fake_loss + cond_wrong_loss) / 3
        d_adv_total += self.adv_weight * (real_term + fake_term)

        d_word_total += self.word_weight * word_level_correlation_loss(
            real_emb_features[i], word_emb, gamma1=4.0, gamma2=5.0)

    self.g_adv_loss = g_adv_total
    self.d_adv_loss = d_adv_total
    self.d_word_loss = d_word_total
    self.g_kl_loss = self.kl_weight * kl_loss(mu, logvar)
    # The perceptual loss is averaged over the three scales.
    self.g_vgg_loss = g_vgg_total / 3.0

    self.g_loss = self.g_adv_loss + self.g_kl_loss + self.g_vgg_loss
    self.d_loss = self.d_adv_loss + self.d_word_loss

    self.real_img = real_img_256
    self.fake_img = fake_img_256

    # ---- Training ops ----
    trainable = tf.trainable_variables()
    generator_vars = [v for v in trainable if 'generator' in v.name]
    discriminator_vars = [v for v in trainable if 'discriminator' in v.name]

    g_optimizer = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
    self.g_optim = g_optimizer.minimize(self.g_loss, var_list=generator_vars)
    d_optimizer = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
    self.d_optim = d_optimizer.minimize(self.d_loss, var_list=discriminator_vars)

    # ---- Summaries ----
    scalar = tf.summary.scalar
    self.summary_g_loss = scalar("g_loss", self.g_loss)
    self.summary_d_loss = scalar("d_loss", self.d_loss)

    self.summary_g_adv_loss = scalar("g_adv_loss", self.g_adv_loss)
    self.summary_g_kl_loss = scalar("g_kl_loss", self.g_kl_loss)
    self.summary_g_vgg_loss = scalar("g_vgg_loss", self.g_vgg_loss)

    self.summary_d_adv_loss = scalar("d_adv_loss", self.d_adv_loss)
    self.summary_d_word_loss = scalar("d_word_loss", self.d_word_loss)

    self.summary_merge_g_loss = tf.summary.merge([
        self.summary_g_loss,
        self.summary_g_adv_loss,
        self.summary_g_kl_loss,
        self.summary_g_vgg_loss,
    ])
    self.summary_merge_d_loss = tf.summary.merge([
        self.summary_d_loss,
        self.summary_d_adv_loss,
        self.summary_d_word_loss,
    ])