def compute_losses(self):
    """Build the training losses, optimizers, and TensorBoard summaries.

    Constructs (in TF1 graph mode):
      * cycle-consistency losses for both image domains,
      * LSGAN generator losses for the fake images, the predicted
        segmentation masks (full + low-level `_ll` branch), and the
        auxiliary A-domain output,
      * the segmentation task loss (cross-entropy + Dice) with L2 weight
        decay over the encoder/segmenter variables,
      * LSGAN discriminator losses (d_A with auxiliary term, d_B, d_P,
        d_P_ll),
      * Adam minimizers per variable group, stored on `self` as
        `*_trainer` ops, plus scalar/merged summaries as `*_summ`.

    NOTE(review): a second `compute_losses` definition appears later in
    this file; if both live in the same class the later one silently
    overrides this one — confirm which variant is intended.
    """
    # Cycle-consistency: reconstruct each domain after a full A->B->A
    # (resp. B->A->B) round trip, weighted by the per-domain lambdas.
    cycle_consistency_loss_a = \
        self._lambda_a * losses.cycle_consistency_loss(
            real_images=self.input_a,
            generated_images=self.cycle_images_a,
        )
    cycle_consistency_loss_b = \
        self._lambda_b * losses.cycle_consistency_loss(
            real_images=self.input_b,
            generated_images=self.cycle_images_b,
        )

    # Generator-side LSGAN terms: push each "fake" probability toward real.
    lsgan_loss_a = losses.lsgan_loss_generator(self.prob_fake_a_is_real)
    lsgan_loss_b = losses.lsgan_loss_generator(self.prob_fake_b_is_real)
    # Adversarial alignment of predicted masks (full and low-level branch).
    lsgan_loss_p = losses.lsgan_loss_generator(
        self.prob_pred_mask_b_is_real)
    lsgan_loss_p_ll = losses.lsgan_loss_generator(
        self.prob_pred_mask_b_ll_is_real)
    lsgan_loss_a_aux = losses.lsgan_loss_generator(
        self.prob_fake_a_aux_is_real)

    # Supervised segmentation losses on translated images, against the
    # A-domain ground truth (cross-entropy + Dice per branch).
    ce_loss_b, dice_loss_b = losses.task_loss(self.pred_mask_fake_b,
                                              self.gt_a)
    ce_loss_b_ll, dice_loss_b_ll = losses.task_loss(
        self.pred_mask_fake_b_ll, self.gt_a)
    # L2 weight decay restricted to the segmenter/encoder scopes.
    l2_loss_b = tf.add_n([
        0.0001 * tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if '/s_B/' in v.name or '/s_B_ll/' in v.name or '/e_B/' in v.name
    ])

    # Total generator losses (both directions share the cycle terms).
    g_loss_A = cycle_consistency_loss_a + cycle_consistency_loss_b + lsgan_loss_b
    g_loss_B = cycle_consistency_loss_b + cycle_consistency_loss_a + lsgan_loss_a
    # Segmentation objective: main branch at full weight, low-level branch,
    # generator B, mask-adversarial and auxiliary terms down-weighted.
    # NOTE(review): the 0.1 / 0.01 weights are hard-coded — presumably
    # tuned hyperparameters; confirm against the training config.
    seg_loss_B = ce_loss_b + dice_loss_b + l2_loss_b + 0.1 * (
        ce_loss_b_ll + dice_loss_b_ll
    ) + 0.1 * g_loss_B + 0.1 * lsgan_loss_p + 0.01 * lsgan_loss_p_ll + 0.1 * lsgan_loss_a_aux

    # Discriminator losses (real should score real, pooled fakes fake).
    d_loss_A = losses.lsgan_loss_discriminator(
        prob_real_is_real=self.prob_real_a_is_real,
        prob_fake_is_real=self.prob_fake_pool_a_is_real,
    )
    d_loss_A_aux = losses.lsgan_loss_discriminator(
        prob_real_is_real=self.prob_cycle_a_aux_is_real,
        prob_fake_is_real=self.prob_fake_pool_a_aux_is_real,
    )
    d_loss_A = d_loss_A + d_loss_A_aux
    d_loss_B = losses.lsgan_loss_discriminator(
        prob_real_is_real=self.prob_real_b_is_real,
        prob_fake_is_real=self.prob_fake_pool_b_is_real,
    )
    # Mask discriminators: masks predicted on translated images count as
    # "real", masks predicted on genuine B images as "fake".
    d_loss_P = losses.lsgan_loss_discriminator(
        prob_real_is_real=self.prob_pred_mask_fake_b_is_real,
        prob_fake_is_real=self.prob_pred_mask_b_is_real,
    )
    d_loss_P_ll = losses.lsgan_loss_discriminator(
        prob_real_is_real=self.prob_pred_mask_fake_b_ll_is_real,
        prob_fake_is_real=self.prob_pred_mask_b_ll_is_real,
    )

    # Separate optimizers: GAN parts use beta1=0.5 (usual GAN setting),
    # the segmenter uses Adam defaults at its own learning rate.
    optimizer_gan = tf.train.AdamOptimizer(self.learning_rate_gan,
                                           beta1=0.5)
    optimizer_seg = tf.train.AdamOptimizer(self.learning_rate_seg)

    self.model_vars = tf.trainable_variables()

    # Partition trainable variables by their name-scope tags.
    d_A_vars = [var for var in self.model_vars if '/d_A/' in var.name]
    d_B_vars = [var for var in self.model_vars if '/d_B/' in var.name]
    g_A_vars = [var for var in self.model_vars if '/g_A/' in var.name]
    e_B_vars = [var for var in self.model_vars if '/e_B/' in var.name]
    de_B_vars = [var for var in self.model_vars if '/de_B/' in var.name]
    s_B_vars = [var for var in self.model_vars if '/s_B/' in var.name]
    s_B_ll_vars = [
        var for var in self.model_vars if '/s_B_ll/' in var.name
    ]
    d_P_vars = [var for var in self.model_vars if '/d_P/' in var.name]
    d_P_ll_vars = [
        var for var in self.model_vars if '/d_P_ll/' in var.name
    ]

    # One training op per sub-network; each op updates only its own group.
    self.d_A_trainer = optimizer_gan.minimize(d_loss_A, var_list=d_A_vars)
    self.d_B_trainer = optimizer_gan.minimize(d_loss_B, var_list=d_B_vars)
    self.g_A_trainer = optimizer_gan.minimize(g_loss_A, var_list=g_A_vars)
    # g_B is optimized through the decoder variables only.
    self.g_B_trainer = optimizer_gan.minimize(g_loss_B, var_list=de_B_vars)
    self.d_P_trainer = optimizer_gan.minimize(d_loss_P, var_list=d_P_vars)
    self.d_P_ll_trainer = optimizer_gan.minimize(d_loss_P_ll,
                                                 var_list=d_P_ll_vars)
    self.s_B_trainer = optimizer_seg.minimize(seg_loss_B,
                                              var_list=e_B_vars + s_B_vars
                                              + s_B_ll_vars)

    for var in self.model_vars:
        print(var.name)

    # Summary variables for tensorboard
    self.g_A_loss_summ = tf.summary.scalar("g_A_loss", g_loss_A)
    self.g_B_loss_summ = tf.summary.scalar("g_B_loss", g_loss_B)
    self.d_A_loss_summ = tf.summary.scalar("d_A_loss", d_loss_A)
    self.d_B_loss_summ = tf.summary.scalar("d_B_loss", d_loss_B)
    self.ce_B_loss_summ = tf.summary.scalar("ce_B_loss", ce_loss_b)
    self.dice_B_loss_summ = tf.summary.scalar("dice_B_loss", dice_loss_b)
    self.l2_B_loss_summ = tf.summary.scalar("l2_B_loss", l2_loss_b)
    self.s_B_loss_summ = tf.summary.scalar("s_B_loss", seg_loss_B)
    self.s_B_loss_merge_summ = tf.summary.merge([
        self.ce_B_loss_summ, self.dice_B_loss_summ, self.l2_B_loss_summ,
        self.s_B_loss_summ
    ])
    self.d_P_loss_summ = tf.summary.scalar("d_P_loss", d_loss_P)
    self.d_P_ll_loss_summ = tf.summary.scalar("d_P_loss_ll", d_loss_P_ll)
    self.d_P_loss_merge_summ = tf.summary.merge(
        [self.d_P_loss_summ, self.d_P_ll_loss_summ])
def compute_losses(self):
    """Build the training losses, optimizers, and TensorBoard summaries.

    Variant with a feature-level discriminator (d_F): in addition to the
    cycle-consistency and image-level LSGAN terms, the B-domain feature
    maps are aligned adversarially, weighted by the `loss_f_weight`
    placeholder that this method creates. All training ops are stored on
    `self` as `*_trainer`, all summaries as `*_summ`.
    """
    # Cycle-consistency on channel 1 of the stacked inputs (the frame
    # being translated — assumes inputs stack neighbouring slices in the
    # channel axis; TODO confirm against the input pipeline).
    real_slice_a = tf.expand_dims(self.input_a[:, :, :, 1], axis=3)
    real_slice_b = tf.expand_dims(self.input_b[:, :, :, 1], axis=3)
    cycle_a = self._lambda_a * losses.cycle_consistency_loss(
        real_images=real_slice_a,
        generated_images=self.cycle_images_a,
    )
    cycle_b = self._lambda_b * losses.cycle_consistency_loss(
        real_images=real_slice_b,
        generated_images=self.cycle_images_b,
    )

    # Generator-side LSGAN terms: push each "fake" probability toward real.
    gen_adv_a = losses.lsgan_loss_generator(self.prob_fake_a_is_real)
    gen_adv_b = losses.lsgan_loss_generator(self.prob_fake_b_is_real)
    gen_adv_fea = losses.lsgan_loss_generator(self.prob_fea_b_is_real)
    gen_adv_a_aux = losses.lsgan_loss_generator(
        self.prob_fake_a_aux_is_real)

    # Supervised segmentation loss (cross-entropy + Dice) on translated
    # images against the A-domain ground truth, plus L2 weight decay over
    # the encoder/segmenter scopes.
    ce_b, dice_b = losses.task_loss(self.pred_mask_fake_b, self.gt_a)
    weight_decay = tf.add_n([
        0.0001 * tf.nn.l2_loss(v)
        for v in tf.trainable_variables()
        if '/s_B/' in v.name or '/e_B/' in v.name
    ])

    # Total generator losses (both directions share the cycle terms).
    g_loss_A = cycle_a + cycle_b + gen_adv_b
    g_loss_B = cycle_b + cycle_a + gen_adv_a

    # Runtime-tunable weight for the feature-adversarial term.
    self.loss_f_weight = tf.placeholder(tf.float32, shape=[],
                                        name="loss_f_weight")
    self.loss_f_weight_summ = tf.summary.scalar("loss_f_weight",
                                                self.loss_f_weight)

    # Segmentation objective: task terms at full weight, generator B and
    # auxiliary adversarial terms down-weighted, feature term scheduled
    # through the placeholder above.
    seg_loss_B = (ce_b + dice_b + weight_decay + 0.1 * g_loss_B
                  + self.loss_f_weight * gen_adv_fea
                  + 0.1 * gen_adv_a_aux)

    # Discriminator losses (real should score real, pooled fakes fake).
    d_loss_A = losses.lsgan_loss_discriminator(
        prob_real_is_real=self.prob_real_a_is_real,
        prob_fake_is_real=self.prob_fake_pool_a_is_real,
    )
    d_loss_A_aux = losses.lsgan_loss_discriminator(
        prob_real_is_real=self.prob_cycle_a_aux_is_real,
        prob_fake_is_real=self.prob_fake_pool_a_aux_is_real,
    )
    d_loss_A = d_loss_A + d_loss_A_aux
    d_loss_B = losses.lsgan_loss_discriminator(
        prob_real_is_real=self.prob_real_b_is_real,
        prob_fake_is_real=self.prob_fake_pool_b_is_real,
    )
    # Feature discriminator: features from translated images are "real",
    # features from genuine B images are "fake".
    d_loss_F = losses.lsgan_loss_discriminator(
        prob_real_is_real=self.prob_fea_fake_b_is_real,
        prob_fake_is_real=self.prob_fea_b_is_real,
    )

    # GAN parts use beta1=0.5 (usual GAN setting); the segmenter uses Adam
    # defaults at its own learning rate.
    optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5)
    optimizer_seg = tf.train.AdamOptimizer(self.learning_rate_seg)

    self.model_vars = tf.trainable_variables()

    def scoped(tag):
        # Trainable variables whose scope path contains `tag`.
        return [v for v in self.model_vars if tag in v.name]

    d_A_vars = scoped('/d_A/')
    d_B_vars = scoped('/d_B/')
    g_A_vars = scoped('/g_A/')
    e_B_vars = scoped('/e_B/')
    de_B_vars = scoped('/de_B/')
    s_B_vars = scoped('/s_B/')
    d_F_vars = scoped('/d_F/')

    # One training op per sub-network; each op updates only its own group.
    self.d_A_trainer = optimizer.minimize(d_loss_A, var_list=d_A_vars)
    self.d_B_trainer = optimizer.minimize(d_loss_B, var_list=d_B_vars)
    self.g_A_trainer = optimizer.minimize(g_loss_A, var_list=g_A_vars)
    # g_B is optimized through the decoder variables only.
    self.g_B_trainer = optimizer.minimize(g_loss_B, var_list=de_B_vars)
    self.s_B_trainer = optimizer_seg.minimize(
        seg_loss_B, var_list=e_B_vars + s_B_vars)
    self.d_F_trainer = optimizer.minimize(d_loss_F, var_list=d_F_vars)

    for v in self.model_vars:
        print(v.name)

    # Summary variables for tensorboard
    self.g_A_loss_summ = tf.summary.scalar("g_A_loss", g_loss_A)
    self.g_B_loss_summ = tf.summary.scalar("g_B_loss", g_loss_B)
    self.d_A_loss_summ = tf.summary.scalar("d_A_loss", d_loss_A)
    self.d_B_loss_summ = tf.summary.scalar("d_B_loss", d_loss_B)
    self.ce_B_loss_summ = tf.summary.scalar("ce_B_loss", ce_b)
    self.dice_B_loss_summ = tf.summary.scalar("dice_B_loss", dice_b)
    self.l2_B_loss_summ = tf.summary.scalar("l2_B_loss", weight_decay)
    self.s_B_loss_summ = tf.summary.scalar("s_B_loss", seg_loss_B)
    self.s_B_loss_merge_summ = tf.summary.merge([
        self.ce_B_loss_summ,
        self.dice_B_loss_summ,
        self.l2_B_loss_summ,
        self.s_B_loss_summ,
    ])
    self.d_F_loss_summ = tf.summary.scalar("d_F_loss", d_loss_F)
def build_model(self):
    """Build the full training graph: placeholders, generators,
    discriminators, losses, variable groups, and the saver.

    Graph construction is gated by the boolean feature flags
    `self.flag_L1`, `self.flag_task`, `self.flag_d_intra`, and
    `self.flag_d_inter`; only the tensors needed by the enabled losses
    are created. Results are exposed as attributes (`self.g_loss`,
    `self.loss_dict`, `self.sample`, `self.*_vars`, `self.saver`).

    NOTE(review): if every one of flag_L1/flag_task/flag_d_intra/
    flag_d_inter is False, `self.g_loss` is never assigned and the
    `tf.summary.scalar("g_loss", ...)` call below would raise
    AttributeError — presumably at least one flag is always set; confirm.
    """
    # Input (grey-matter map) and output (color) tensor shapes: NHWC.
    self.input_shape = (self.batchsize, self.args['input']['size'],
                        self.args['input']['size'],
                        self.args['input']['channel'])
    self.output_shape = (self.batchsize, self.args['output']['size'],
                         self.args['output']['size'],
                         self.args['output']['channel'])

    # set placeholder
    ## s for source
    # if self.flag_L1 or self.flag_d_intra or self.flag_task:
    self.s_gm = tf.placeholder(dtype=tf.float32, shape=self.input_shape)
    # if self.flag_L1 or self.flag_d_intra or self.flag_d_inter:
    self.s_color = tf.placeholder(dtype=tf.float32,
                                  shape=self.output_shape)
    if self.flag_task:
        # NOTE(review): uses self.args['batchsize'] here but
        # self.batchsize above — verify these always agree.
        temp = (self.args['batchsize'],
                self.args['model']['tasknet']['num_classes'])
        self.s_label = tf.placeholder(dtype=tf.float32, shape=temp)
    ## t for target
    if self.flag_d_inter:
        self.t_color = tf.placeholder(dtype=tf.float32,
                                      shape=self.output_shape)
        self.t_gm = tf.placeholder(dtype=tf.float32,
                                   shape=self.input_shape)

    # generate images
    if self.flag_L1 or self.flag_d_intra or self.flag_task:
        self.fake_s_color = generator(self.s_gm, gf_dim=self.gf_dim,
                                      o_c=self.output_shape[-1])
    if self.flag_d_inter:
        self.fake_t_color = generator(self.t_gm, gf_dim=self.gf_dim,
                                      o_c=self.output_shape[-1])

    # compute loss
    ## intra-domain
    self.loss_dict = {}
    if self.flag_d_intra:
        # Both calls share name='intra_discriminator' — assumes the
        # discriminator() helper reuses variables on the second call
        # (e.g. AUTO_REUSE); confirm in its definition.
        d_intra_logits_real = discriminator(self.s_color,
                                            df_dim=self.df_dim,
                                            name='intra_discriminator')
        d_intra_logits_fake = discriminator(self.fake_s_color,
                                            df_dim=self.df_dim,
                                            name='intra_discriminator')
        d_intra_loss_real = losses.nsgan_loss(d_intra_logits_real,
                                              is_real=True)
        d_intra_loss_fake = losses.nsgan_loss(d_intra_logits_fake,
                                              is_real=False)
        self.d_intra_loss = d_intra_loss_real + d_intra_loss_fake
        tf.summary.scalar("d_intra_loss", self.d_intra_loss)
        self.loss_dict.update({'d_intra_loss': self.d_intra_loss})
    ## inter-domain
    if self.flag_d_inter:
        # "Real" examples for the inter-domain discriminator are the
        # source-domain colors; fakes are colors generated for the target.
        d_inter_logits_real = discriminator(self.s_color,
                                            df_dim=self.df_dim,
                                            name='inter_discriminator')
        d_inter_logits_fake = discriminator(self.fake_t_color,
                                            df_dim=self.df_dim,
                                            name='inter_discriminator')
        d_inter_loss_real = losses.nsgan_loss(d_inter_logits_real,
                                              is_real=True)
        d_inter_loss_fake = losses.nsgan_loss(d_inter_logits_fake,
                                              is_real=False)
        self.d_inter_loss = d_inter_loss_real + d_inter_loss_fake
        tf.summary.scalar("d_inter_loss", self.d_inter_loss)
        self.loss_dict.update({'d_inter_loss': self.d_inter_loss})

    ## Generator loss
    # `flag` tracks whether self.g_loss has been initialized yet, so the
    # first enabled term assigns and later terms accumulate.
    flag = False
    self.g_loss_dict = {}
    ### l1 loss
    if self.flag_L1:
        self.l1_loss = losses.l1_loss(self.fake_s_color, self.s_color)
        self.g_loss_dict.update({'g_l1': self.l1_loss})
        _lambda = self.args['model']['lambda_L1']
        self.g_loss = _lambda * self.l1_loss
        flag = True
        tf.summary.scalar("l1_loss", self.l1_loss)
    ### task loss
    if self.flag_task:
        num_classes = self.args['model']['tasknet']['num_classes']
        # phase=False — presumably inference mode (frozen batch norm);
        # confirm against the ResNet50 definition.
        task_net = ResNet50(self.fake_s_color, num_classes, phase=False)
        task_pred_logits = task_net.outputs
        self.task_loss = losses.task_loss(task_pred_logits, self.s_label)
        self.g_loss_dict.update({'g_loss_task': self.task_loss})
        _lambda = self.args['model']['tasknet']['lambda_L_task']
        if flag:
            self.g_loss += _lambda * self.task_loss
        else:
            self.g_loss = _lambda * self.task_loss
        flag = True
        tf.summary.scalar("g_task_loss", self.task_loss)
    ### d-intra loss
    if self.flag_d_intra:
        # Generator side of the intra-domain GAN: fakes scored as real.
        self.g_loss_intra = losses.nsgan_loss(d_intra_logits_fake, True)
        self.g_loss_dict.update({'g_d_intra': self.g_loss_intra})
        _lambda = self.args['model']['discriminator_intra'][
            'lambda_L_d_intra']
        if flag:
            self.g_loss += _lambda * self.g_loss_intra
        else:
            self.g_loss = _lambda * self.g_loss_intra
        flag = True
        tf.summary.scalar("g_loss_intra", self.g_loss_intra)
    ### d-inter loss
    if self.flag_d_inter:
        self.g_loss_inter = losses.nsgan_loss(d_inter_logits_fake, True)
        self.g_loss_dict.update({'g_d_inter': self.g_loss_inter})
        _lambda = self.args['model']['discriminator_inter'][
            'lambda_L_d_inter']
        if flag:
            self.g_loss += _lambda * self.g_loss_inter
        else:
            self.g_loss = _lambda * self.g_loss_inter
        flag = True
        tf.summary.scalar("g_loss_inter", self.g_loss_inter)
    tf.summary.scalar("g_loss", self.g_loss)
    self.loss_dict.update(self.g_loss_dict)

    #log
    # Side-by-side sample image: [fake | input (channels 1:) | real],
    # stacked vertically with the target-domain row when enabled, then
    # rescaled from [-1, 1] to [0, 255].
    self.sample = tf.concat(
        [self.fake_s_color, self.s_gm[:, :, :, 1:], self.s_color], 2)
    if self.flag_d_inter:
        sample_t = tf.concat(
            [self.fake_t_color, self.t_gm[:, :, :, 1:], self.t_color], 2)
        self.sample = tf.concat([self.sample, sample_t], 1)
    self.sample = (self.sample + 1) * 127.5

    #divide variable group
    # Trainable vars feed the optimizers; global vars (incl. optimizer
    # slots / moving stats) feed the saver lists.
    t_vars = tf.trainable_variables()
    global_vars = tf.global_variables()
    self.normnet_vars_global = []
    if self.flag_d_intra:
        self.d_intra_vars = [
            var for var in t_vars if 'intra_discriminator' in var.name
        ]
        self.d_intra_vars_global = [
            var for var in global_vars
            if 'intra_discriminator' in var.name
        ]
        self.normnet_vars_global += self.d_intra_vars_global
    if self.flag_d_inter:
        self.d_inter_vars = [
            var for var in t_vars if 'inter_discriminator' in var.name
        ]
        self.d_inter_vars_global = [
            var for var in global_vars
            if 'inter_discriminator' in var.name
        ]
        self.normnet_vars_global += self.d_inter_vars_global
    self.g_vars = [var for var in t_vars if 'generator' in var.name]
    self.g_vars_global = [
        var for var in global_vars if 'generator' in var.name
    ]
    self.normnet_vars_global += self.g_vars_global
    if self.flag_task:
        # Task-net vars are everything NOT claimed by the norm-net groups.
        self.tasknet_vars = [
            var for var in t_vars if var not in self.normnet_vars_global
        ]
        # self.tasknet_vars_tainable = self.tasknet_vars[44:]
        self.tasknet_vars_global = [
            var for var in global_vars
            if var not in self.normnet_vars_global
        ]

    #saver
    # Only the norm-net (generator + discriminators) is checkpointed here;
    # task-net weights are presumably restored elsewhere — confirm.
    vars_save = self.normnet_vars_global
    self.saver = tf.train.Saver(var_list=vars_save, max_to_keep=20)