Code example #1
0
    def compute_losses(self) -> None:
        """Build all loss tensors, optimizers, train ops and summaries.

        TF1 graph-construction step: reads tensors previously attached to
        ``self`` (generator outputs, discriminator probabilities, ground
        truth) and attaches the resulting ``*_trainer`` and ``*_summ`` ops
        back onto ``self``. Returns nothing.
        """

        # Weighted cycle-consistency (reconstruction) losses for each domain.
        cycle_consistency_loss_a = \
            self._lambda_a * losses.cycle_consistency_loss(
                real_images=self.input_a, generated_images=self.cycle_images_a,
            )
        cycle_consistency_loss_b = \
            self._lambda_b * losses.cycle_consistency_loss(
                real_images=self.input_b, generated_images=self.cycle_images_b,
            )

        # LSGAN generator-side losses: push each "fake ... is real"
        # probability toward the real label. The "_p"/"_p_ll" terms act on
        # the predicted masks (main and low-level outputs), "_a_aux" on the
        # auxiliary fake-A image.
        lsgan_loss_a = losses.lsgan_loss_generator(self.prob_fake_a_is_real)
        lsgan_loss_b = losses.lsgan_loss_generator(self.prob_fake_b_is_real)
        lsgan_loss_p = losses.lsgan_loss_generator(
            self.prob_pred_mask_b_is_real)
        lsgan_loss_p_ll = losses.lsgan_loss_generator(
            self.prob_pred_mask_b_ll_is_real)
        lsgan_loss_a_aux = losses.lsgan_loss_generator(
            self.prob_fake_a_aux_is_real)

        # Supervised segmentation losses (cross-entropy + Dice per
        # losses.task_loss) against gt_a, for both the main and the
        # low-level ("_ll") mask predictions on translated images.
        ce_loss_b, dice_loss_b = losses.task_loss(self.pred_mask_fake_b,
                                                  self.gt_a)
        ce_loss_b_ll, dice_loss_b_ll = losses.task_loss(
            self.pred_mask_fake_b_ll, self.gt_a)
        # L2 weight decay (coeff 1e-4), restricted by scope name to the
        # encoder/segmenter variables only.
        l2_loss_b = tf.add_n([
            0.0001 * tf.nn.l2_loss(v) for v in tf.trainable_variables()
            if '/s_B/' in v.name or '/s_B_ll/' in v.name or '/e_B/' in v.name
        ])

        # Generator objectives: both cycle terms plus the adversarial term
        # from the opposite domain's discriminator.
        g_loss_A = cycle_consistency_loss_a + cycle_consistency_loss_b + lsgan_loss_b
        g_loss_B = cycle_consistency_loss_b + cycle_consistency_loss_a + lsgan_loss_a

        # Total segmenter objective: task losses + weight decay, plus the
        # low-level task terms (x0.1), the generator objective (x0.1) and the
        # mask/auxiliary adversarial terms (x0.1 / x0.01 / x0.1).
        seg_loss_B = ce_loss_b + dice_loss_b + l2_loss_b + 0.1 * (
            ce_loss_b_ll + dice_loss_b_ll
        ) + 0.1 * g_loss_B + 0.1 * lsgan_loss_p + 0.01 * lsgan_loss_p_ll + 0.1 * lsgan_loss_a_aux

        # Discriminator objectives (real samples vs. pooled fakes).
        d_loss_A = losses.lsgan_loss_discriminator(
            prob_real_is_real=self.prob_real_a_is_real,
            prob_fake_is_real=self.prob_fake_pool_a_is_real,
        )
        d_loss_A_aux = losses.lsgan_loss_discriminator(
            prob_real_is_real=self.prob_cycle_a_aux_is_real,
            prob_fake_is_real=self.prob_fake_pool_a_aux_is_real,
        )
        # Fold the auxiliary discriminator loss into d_loss_A so both are
        # minimized by the single d_A trainer below.
        d_loss_A = d_loss_A + d_loss_A_aux
        d_loss_B = losses.lsgan_loss_discriminator(
            prob_real_is_real=self.prob_real_b_is_real,
            prob_fake_is_real=self.prob_fake_pool_b_is_real,
        )
        # Mask discriminators: "real" = mask predicted on the translated
        # (fake-B) image, "fake" = mask predicted on the target image.
        d_loss_P = losses.lsgan_loss_discriminator(
            prob_real_is_real=self.prob_pred_mask_fake_b_is_real,
            prob_fake_is_real=self.prob_pred_mask_b_is_real,
        )
        d_loss_P_ll = losses.lsgan_loss_discriminator(
            prob_real_is_real=self.prob_pred_mask_fake_b_ll_is_real,
            prob_fake_is_real=self.prob_pred_mask_b_ll_is_real,
        )

        # Separate optimizers with separate learning rates: the GAN parts use
        # Adam with beta1=0.5; the segmenter uses Adam defaults.
        optimizer_gan = tf.train.AdamOptimizer(self.learning_rate_gan,
                                               beta1=0.5)
        optimizer_seg = tf.train.AdamOptimizer(self.learning_rate_seg)

        self.model_vars = tf.trainable_variables()

        # Partition trainable variables by variable-scope substring so each
        # loss only ever updates its own subnetwork.
        d_A_vars = [var for var in self.model_vars if '/d_A/' in var.name]
        d_B_vars = [var for var in self.model_vars if '/d_B/' in var.name]
        g_A_vars = [var for var in self.model_vars if '/g_A/' in var.name]
        e_B_vars = [var for var in self.model_vars if '/e_B/' in var.name]
        de_B_vars = [var for var in self.model_vars if '/de_B/' in var.name]
        s_B_vars = [var for var in self.model_vars if '/s_B/' in var.name]
        s_B_ll_vars = [
            var for var in self.model_vars if '/s_B_ll/' in var.name
        ]
        d_P_vars = [var for var in self.model_vars if '/d_P/' in var.name]
        d_P_ll_vars = [
            var for var in self.model_vars if '/d_P_ll/' in var.name
        ]

        # One minimize op per loss/subnetwork pairing. Note g_loss_B updates
        # only the decoder (de_B) variables; the encoder/segmenter are
        # trained through seg_loss_B instead.
        self.d_A_trainer = optimizer_gan.minimize(d_loss_A, var_list=d_A_vars)
        self.d_B_trainer = optimizer_gan.minimize(d_loss_B, var_list=d_B_vars)
        self.g_A_trainer = optimizer_gan.minimize(g_loss_A, var_list=g_A_vars)
        self.g_B_trainer = optimizer_gan.minimize(g_loss_B, var_list=de_B_vars)
        self.d_P_trainer = optimizer_gan.minimize(d_loss_P, var_list=d_P_vars)
        self.d_P_ll_trainer = optimizer_gan.minimize(d_loss_P_ll,
                                                     var_list=d_P_ll_vars)
        self.s_B_trainer = optimizer_seg.minimize(seg_loss_B,
                                                  var_list=e_B_vars +
                                                  s_B_vars + s_B_ll_vars)

        # Debug aid: dump every trainable variable name to stdout.
        for var in self.model_vars:
            print(var.name)

        # Summary variables for tensorboard
        self.g_A_loss_summ = tf.summary.scalar("g_A_loss", g_loss_A)
        self.g_B_loss_summ = tf.summary.scalar("g_B_loss", g_loss_B)
        self.d_A_loss_summ = tf.summary.scalar("d_A_loss", d_loss_A)
        self.d_B_loss_summ = tf.summary.scalar("d_B_loss", d_loss_B)
        self.ce_B_loss_summ = tf.summary.scalar("ce_B_loss", ce_loss_b)
        self.dice_B_loss_summ = tf.summary.scalar("dice_B_loss", dice_loss_b)
        self.l2_B_loss_summ = tf.summary.scalar("l2_B_loss", l2_loss_b)
        self.s_B_loss_summ = tf.summary.scalar("s_B_loss", seg_loss_B)
        self.s_B_loss_merge_summ = tf.summary.merge([
            self.ce_B_loss_summ, self.dice_B_loss_summ, self.l2_B_loss_summ,
            self.s_B_loss_summ
        ])
        self.d_P_loss_summ = tf.summary.scalar("d_P_loss", d_loss_P)
        # NOTE(review): tag is "d_P_loss_ll" while the attribute is named
        # d_P_ll_loss_summ — confirm the tag spelling is intentional.
        self.d_P_ll_loss_summ = tf.summary.scalar("d_P_loss_ll", d_loss_P_ll)
        self.d_P_loss_merge_summ = tf.summary.merge(
            [self.d_P_loss_summ, self.d_P_ll_loss_summ])
Code example #2
0
File: main.py  Project: ziyangwang007/SIFA
    def compute_losses(self) -> None:
        """Build losses, optimizers, train ops and summaries (TF1 graph mode).

        Variant with a feature-level discriminator ("F"): in addition to the
        image-space CycleGAN losses, an adversarial loss on encoder features
        is added to the segmenter objective with a runtime-fed weight
        (``self.loss_f_weight`` placeholder). All trainer and summary ops are
        attached to ``self``; returns nothing.
        """

        # Weighted cycle-consistency losses. Only channel 1 of the input is
        # compared against the single-channel reconstruction — presumably the
        # center slice of a channel-stacked volume; confirm with the loader.
        cycle_consistency_loss_a = \
            self._lambda_a * losses.cycle_consistency_loss(
                real_images=tf.expand_dims(self.input_a[:,:,:,1], axis=3), generated_images=self.cycle_images_a,
            )
        cycle_consistency_loss_b = \
            self._lambda_b * losses.cycle_consistency_loss(
                real_images=tf.expand_dims(self.input_b[:,:,:,1], axis=3), generated_images=self.cycle_images_b,
            )

        # LSGAN generator-side losses; "_f" acts on encoder features,
        # "_a_aux" on the auxiliary fake-A image.
        lsgan_loss_a = losses.lsgan_loss_generator(self.prob_fake_a_is_real)
        lsgan_loss_b = losses.lsgan_loss_generator(self.prob_fake_b_is_real)
        lsgan_loss_f = losses.lsgan_loss_generator(self.prob_fea_b_is_real)
        lsgan_loss_a_aux = losses.lsgan_loss_generator(
            self.prob_fake_a_aux_is_real)

        # Supervised segmentation loss (cross-entropy + Dice) on the
        # translated image versus gt_a.
        ce_loss_b, dice_loss_b = losses.task_loss(self.pred_mask_fake_b,
                                                  self.gt_a)

        # L2 weight decay (coeff 1e-4) on encoder/segmenter variables only.
        l2_loss_b = tf.add_n([
            0.0001 * tf.nn.l2_loss(v) for v in tf.trainable_variables()
            if '/s_B/' in v.name or '/e_B/' in v.name
        ])

        # Generator objectives: both cycle terms plus the adversarial term
        # from the opposite domain's discriminator.
        g_loss_A = cycle_consistency_loss_a + cycle_consistency_loss_b + lsgan_loss_b
        g_loss_B = cycle_consistency_loss_b + cycle_consistency_loss_a + lsgan_loss_a

        # Scalar placeholder so the feature-adversarial weight can be
        # scheduled/ramped from the training loop at session run time.
        self.loss_f_weight = tf.placeholder(tf.float32,
                                            shape=[],
                                            name="loss_f_weight")
        self.loss_f_weight_summ = tf.summary.scalar("loss_f_weight",
                                                    self.loss_f_weight)
        # Segmenter objective: task + weight decay + 0.1*generator objective
        # + weighted feature-adversarial term + 0.1*auxiliary term.
        seg_loss_B = ce_loss_b + dice_loss_b + l2_loss_b + 0.1 * g_loss_B + self.loss_f_weight * lsgan_loss_f + 0.1 * lsgan_loss_a_aux

        # Discriminator objectives (real samples vs. pooled fakes).
        d_loss_A = losses.lsgan_loss_discriminator(
            prob_real_is_real=self.prob_real_a_is_real,
            prob_fake_is_real=self.prob_fake_pool_a_is_real,
        )
        d_loss_A_aux = losses.lsgan_loss_discriminator(
            prob_real_is_real=self.prob_cycle_a_aux_is_real,
            prob_fake_is_real=self.prob_fake_pool_a_aux_is_real,
        )
        # Fold the auxiliary discriminator loss into d_loss_A so both are
        # minimized by the single d_A trainer below.
        d_loss_A = d_loss_A + d_loss_A_aux
        d_loss_B = losses.lsgan_loss_discriminator(
            prob_real_is_real=self.prob_real_b_is_real,
            prob_fake_is_real=self.prob_fake_pool_b_is_real,
        )
        # Feature discriminator: "real" = features of the translated (fake-B)
        # image, "fake" = features of the target-domain image.
        d_loss_F = losses.lsgan_loss_discriminator(
            prob_real_is_real=self.prob_fea_fake_b_is_real,
            prob_fake_is_real=self.prob_fea_b_is_real,
        )

        # Adam with beta1=0.5 for the GAN parts; Adam defaults (separate
        # learning rate) for the segmenter.
        optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5)
        optimizer_seg = tf.train.AdamOptimizer(self.learning_rate_seg)

        self.model_vars = tf.trainable_variables()

        # Partition trainable variables by variable-scope substring so each
        # loss only ever updates its own subnetwork.
        d_A_vars = [var for var in self.model_vars if '/d_A/' in var.name]
        d_B_vars = [var for var in self.model_vars if '/d_B/' in var.name]
        g_A_vars = [var for var in self.model_vars if '/g_A/' in var.name]
        e_B_vars = [var for var in self.model_vars if '/e_B/' in var.name]
        de_B_vars = [var for var in self.model_vars if '/de_B/' in var.name]
        s_B_vars = [var for var in self.model_vars if '/s_B/' in var.name]
        d_F_vars = [var for var in self.model_vars if '/d_F/' in var.name]

        # One minimize op per loss/subnetwork pairing. g_loss_B updates only
        # the decoder (de_B); encoder/segmenter train through seg_loss_B.
        self.d_A_trainer = optimizer.minimize(d_loss_A, var_list=d_A_vars)
        self.d_B_trainer = optimizer.minimize(d_loss_B, var_list=d_B_vars)
        self.g_A_trainer = optimizer.minimize(g_loss_A, var_list=g_A_vars)
        self.g_B_trainer = optimizer.minimize(g_loss_B, var_list=de_B_vars)
        self.s_B_trainer = optimizer_seg.minimize(seg_loss_B,
                                                  var_list=e_B_vars + s_B_vars)
        self.d_F_trainer = optimizer.minimize(d_loss_F, var_list=d_F_vars)

        # Debug aid: dump every trainable variable name to stdout.
        for var in self.model_vars:
            print(var.name)

        # Summary variables for tensorboard
        self.g_A_loss_summ = tf.summary.scalar("g_A_loss", g_loss_A)
        self.g_B_loss_summ = tf.summary.scalar("g_B_loss", g_loss_B)
        self.d_A_loss_summ = tf.summary.scalar("d_A_loss", d_loss_A)
        self.d_B_loss_summ = tf.summary.scalar("d_B_loss", d_loss_B)
        self.ce_B_loss_summ = tf.summary.scalar("ce_B_loss", ce_loss_b)
        self.dice_B_loss_summ = tf.summary.scalar("dice_B_loss", dice_loss_b)
        self.l2_B_loss_summ = tf.summary.scalar("l2_B_loss", l2_loss_b)
        self.s_B_loss_summ = tf.summary.scalar("s_B_loss", seg_loss_B)
        self.s_B_loss_merge_summ = tf.summary.merge([
            self.ce_B_loss_summ, self.dice_B_loss_summ, self.l2_B_loss_summ,
            self.s_B_loss_summ
        ])
        self.d_F_loss_summ = tf.summary.scalar("d_F_loss", d_loss_F)
Code example #3
0
    def build_model(self) -> None:
        """Construct the TF1 graph: placeholders, generator, discriminators,
        losses, image-sample op, variable groups and the checkpoint saver.

        Which branches exist is controlled by ``self.flag_L1``,
        ``self.flag_task``, ``self.flag_d_intra`` and ``self.flag_d_inter``.
        Everything the training loop needs is attached to ``self``.
        """
        # (batch, H, W, C) shapes taken from the config dict.
        self.input_shape = (self.batchsize, self.args['input']['size'],
                            self.args['input']['size'],
                            self.args['input']['channel'])
        self.output_shape = (self.batchsize, self.args['output']['size'],
                             self.args['output']['size'],
                             self.args['output']['channel'])
        # set placeholder
        ## s for source
        #        if self.flag_L1 or self.flag_d_intra or self.flag_task:
        self.s_gm = tf.placeholder(dtype=tf.float32, shape=self.input_shape)
        #        if self.flag_L1 or self.flag_d_intra or self.flag_d_inter:
        self.s_color = tf.placeholder(dtype=tf.float32,
                                      shape=self.output_shape)
        if self.flag_task:
            # (batch, num_classes) labels for the auxiliary task network.
            temp = (self.args['batchsize'],
                    self.args['model']['tasknet']['num_classes'])
            self.s_label = tf.placeholder(dtype=tf.float32, shape=temp)
        ## t for target
        if self.flag_d_inter:
            self.t_color = tf.placeholder(dtype=tf.float32,
                                          shape=self.output_shape)
            self.t_gm = tf.placeholder(dtype=tf.float32,
                                       shape=self.input_shape)

        # generate images
        if self.flag_L1 or self.flag_d_intra or self.flag_task:
            self.fake_s_color = generator(self.s_gm,
                                          gf_dim=self.gf_dim,
                                          o_c=self.output_shape[-1])
        if self.flag_d_inter:
            # Same generator applied to the target-domain input; whether its
            # variables are shared depends on generator()'s internal scope
            # handling (not visible here — confirm).
            self.fake_t_color = generator(self.t_gm,
                                          gf_dim=self.gf_dim,
                                          o_c=self.output_shape[-1])

        # compute loss
        ## intra-domain
        self.loss_dict = {}
        if self.flag_d_intra:
            # Intra-domain discriminator: real source image vs. its
            # generated counterpart (shared 'intra_discriminator' scope).
            d_intra_logits_real = discriminator(self.s_color,
                                                df_dim=self.df_dim,
                                                name='intra_discriminator')
            d_intra_logits_fake = discriminator(self.fake_s_color,
                                                df_dim=self.df_dim,
                                                name='intra_discriminator')
            d_intra_loss_real = losses.nsgan_loss(d_intra_logits_real,
                                                  is_real=True)
            d_intra_loss_fake = losses.nsgan_loss(d_intra_logits_fake,
                                                  is_real=False)
            self.d_intra_loss = d_intra_loss_real + d_intra_loss_fake
            tf.summary.scalar("d_intra_loss", self.d_intra_loss)
            self.loss_dict.update({'d_intra_loss': self.d_intra_loss})
        ## inter-domain
        if self.flag_d_inter:
            # Inter-domain discriminator: real source image vs. the image
            # generated from the target domain.
            d_inter_logits_real = discriminator(self.s_color,
                                                df_dim=self.df_dim,
                                                name='inter_discriminator')
            d_inter_logits_fake = discriminator(self.fake_t_color,
                                                df_dim=self.df_dim,
                                                name='inter_discriminator')
            d_inter_loss_real = losses.nsgan_loss(d_inter_logits_real,
                                                  is_real=True)
            d_inter_loss_fake = losses.nsgan_loss(d_inter_logits_fake,
                                                  is_real=False)
            self.d_inter_loss = d_inter_loss_real + d_inter_loss_fake
            tf.summary.scalar("d_inter_loss", self.d_inter_loss)
            self.loss_dict.update({'d_inter_loss': self.d_inter_loss})
        ## Generator loss
        # `flag` tracks whether self.g_loss has been initialized yet, so the
        # enabled terms below can be accumulated in any combination.
        # NOTE(review): if ALL of flag_L1/flag_task/flag_d_intra/flag_d_inter
        # are False, self.g_loss is never created and the
        # tf.summary.scalar("g_loss", ...) call below raises AttributeError.
        flag = False
        self.g_loss_dict = {}
        ### l1 loss
        if self.flag_L1:
            self.l1_loss = losses.l1_loss(self.fake_s_color, self.s_color)
            self.g_loss_dict.update({'g_l1': self.l1_loss})
            _lambda = self.args['model']['lambda_L1']
            self.g_loss = _lambda * self.l1_loss
            flag = True
            tf.summary.scalar("l1_loss", self.l1_loss)
        ### task loss
        if self.flag_task:
            num_classes = self.args['model']['tasknet']['num_classes']
            # phase=False presumably puts the task net in inference mode
            # (e.g. frozen batch-norm) — confirm against ResNet50.
            task_net = ResNet50(self.fake_s_color, num_classes, phase=False)
            task_pred_logits = task_net.outputs
            self.task_loss = losses.task_loss(task_pred_logits, self.s_label)
            self.g_loss_dict.update({'g_loss_task': self.task_loss})
            _lambda = self.args['model']['tasknet']['lambda_L_task']
            if flag:
                self.g_loss += _lambda * self.task_loss
            else:
                self.g_loss = _lambda * self.task_loss
                flag = True
            tf.summary.scalar("g_task_loss", self.task_loss)
        ### d-intra loss
        if self.flag_d_intra:
            # Generator side of the intra-domain GAN: label fakes as real.
            self.g_loss_intra = losses.nsgan_loss(d_intra_logits_fake, True)
            self.g_loss_dict.update({'g_d_intra': self.g_loss_intra})
            _lambda = self.args['model']['discriminator_intra'][
                'lambda_L_d_intra']
            if flag:
                self.g_loss += _lambda * self.g_loss_intra
            else:
                self.g_loss = _lambda * self.g_loss_intra
                flag = True
            tf.summary.scalar("g_loss_intra", self.g_loss_intra)
        ### d-inter loss
        if self.flag_d_inter:
            # Generator side of the inter-domain GAN: label fakes as real.
            self.g_loss_inter = losses.nsgan_loss(d_inter_logits_fake, True)
            self.g_loss_dict.update({'g_d_inter': self.g_loss_inter})
            _lambda = self.args['model']['discriminator_inter'][
                'lambda_L_d_inter']
            if flag:
                self.g_loss += _lambda * self.g_loss_inter
            else:
                self.g_loss = _lambda * self.g_loss_inter
                flag = True
            tf.summary.scalar("g_loss_inter", self.g_loss_inter)
        tf.summary.scalar("g_loss", self.g_loss)
        self.loss_dict.update(self.g_loss_dict)

        #log
        # Side-by-side montage along the width axis (axis 2):
        # [generated | input channels 1..C | ground truth].
        self.sample = tf.concat(
            [self.fake_s_color, self.s_gm[:, :, :, 1:], self.s_color], 2)
        if self.flag_d_inter:
            # Stack the target-domain row under the source row (axis 1).
            sample_t = tf.concat(
                [self.fake_t_color, self.t_gm[:, :, :, 1:], self.t_color], 2)
            self.sample = tf.concat([self.sample, sample_t], 1)
        # Map from [-1, 1] to [0, 255] for image logging.
        self.sample = (self.sample + 1) * 127.5

        #divide variable group
        # Group variables by name so each optimizer/saver touches only its
        # own subnetwork; *_global lists include non-trainable variables.
        t_vars = tf.trainable_variables()
        global_vars = tf.global_variables()
        self.normnet_vars_global = []
        if self.flag_d_intra:
            self.d_intra_vars = [
                var for var in t_vars if 'intra_discriminator' in var.name
            ]
            self.d_intra_vars_global = [
                var for var in global_vars if 'intra_discriminator' in var.name
            ]
            self.normnet_vars_global += self.d_intra_vars_global
        if self.flag_d_inter:
            self.d_inter_vars = [
                var for var in t_vars if 'inter_discriminator' in var.name
            ]
            self.d_inter_vars_global = [
                var for var in global_vars if 'inter_discriminator' in var.name
            ]
            self.normnet_vars_global += self.d_inter_vars_global
        self.g_vars = [var for var in t_vars if 'generator' in var.name]
        self.g_vars_global = [
            var for var in global_vars if 'generator' in var.name
        ]
        self.normnet_vars_global += self.g_vars_global
        if self.flag_task:
            # Task-net variables = everything not already claimed by the
            # generator/discriminator groups above.
            self.tasknet_vars = [
                var for var in t_vars if var not in self.normnet_vars_global
            ]
            #            self.tasknet_vars_tainable = self.tasknet_vars[44:]
            self.tasknet_vars_global = [
                var for var in global_vars
                if var not in self.normnet_vars_global
            ]

        #saver
        # Only generator + discriminator variables are checkpointed; the
        # task network's weights are deliberately excluded.
        vars_save = self.normnet_vars_global
        self.saver = tf.train.Saver(var_list=vars_save, max_to_keep=20)