def __getitem__(self, index):
        """Build one training sample: image pair, pose-B label, parsings and warp thetas."""
        # NOTE: the ground-truth parsing is used for the B label for now; a
        # label generated from parsing A + pose B may replace it later.
        (a_jpg_path, b_jpg_path, a_parsing_path, b_parsing_path,
         a_json_path, b_json_path, theta_pair_key) = self.get_paths(index)

        b_label_tensor, b_label_show_tensor = get_label_tensor(
            b_json_path, b_jpg_path, self.opt)

        sample = {
            'a_image_tensor': get_image_tensor(a_jpg_path, self.opt),
            'b_image_tensor': get_image_tensor(b_jpg_path, self.opt),
            'b_label_tensor': b_label_tensor,
            'a_parsing_tensor': get_parsing_label_tensor(a_parsing_path, self.opt),
            'b_parsing_tensor': get_parsing_label_tensor(b_parsing_path, self.opt),
            'b_label_show_tensor': b_label_show_tensor,
        }

        # [1,18] theta vectors -> 1*256*256 sampling grids.
        (sample['theta_aff_tensor'],
         sample['theta_tps_tensor'],
         sample['theta_aff_tps_tensor']) = get_thetas_affgrid_tensor(
            self.affTnf, self.tpsTnf, self.theta_json_data, theta_pair_key)

        sample['a_jpg_path'] = a_jpg_path
        sample['b_jpg_path'] = b_jpg_path
        return sample
# Beispiel #2 (scrape-artifact snippet separator; commented out so the file parses)
    def discriminator_loss(self, forged_with_input: torch.Tensor,
                           reference_with_input: torch.Tensor) -> torch.Tensor:
        """Compute the discriminator loss: GAN terms on real and forged pairs
        plus a hinged perceptual-adversarial (PAN) term.

        Args:
            forged_with_input: generator output concatenated with its input.
            reference_with_input: reference image concatenated with the input.

        Returns:
            Scalar loss tensor for the discriminator optimization step.
        """
        # Real pass: the discriminator should label references as real (1.0).
        output = self.discriminator(reference_with_input)
        label = get_label_tensor(output, 1.0, self.device)
        # Intermediate activations captured during the preceding forward pass.
        real_intermediate_outputs = self.discriminator.get_intermediate_output()
        loss_D_real = self.hparams.gan_criterion(output, label)

        # Forged pass: generated images should be labeled fake (0.0).
        output = self.discriminator(forged_with_input)
        label = get_label_tensor(output, 0.0, self.device)
        forged_intermediate_outputs = self.discriminator.get_intermediate_output()
        loss_D_forged = self.hparams.gan_criterion(output, label)

        # PAN loss: lambda-weighted sum of per-layer feature distances.
        loss_D_PAN = torch.zeros(1, device=self.device)
        for forged_i, real_i, lam in zip(forged_intermediate_outputs,
                                         real_intermediate_outputs,
                                         self.hparams.lambdas_pan):
            loss_D_PAN += self.hparams.pan_criterion(forged_i, real_i) * lam

        # Stash the raw (un-hinged) PAN value for the generator step.
        self.loss_PAN = loss_D_PAN.detach()

        # Hinge: penalize only while the PAN loss is below the margin.
        # BUGFIX: builtin max() relied on the implicit bool() of a 1-element
        # tensor and would raise for any other shape; torch.clamp is the
        # shape-safe, differentiable equivalent.
        loss_D_PAN = torch.clamp(self.hparams.pan_margin - loss_D_PAN, min=0)

        self.log('d_PAN_loss', loss_D_PAN.detach(), on_step=True, on_epoch=False)
        self.log('d_real_loss', loss_D_real.detach(), on_step=True, on_epoch=False)
        self.log('d_forged_loss', loss_D_forged.detach(), on_step=True, on_epoch=False)

        return (loss_D_real + loss_D_forged) * 0.5 + loss_D_PAN
    def __getitem__(self, index):
        """Build one minimal training sample: image pair plus pose-B label tensors."""
        # NOTE: the ground-truth parsing is used for the B label for now; a
        # label generated from parsing A + pose B may replace it later.
        (a_jpg_path, b_jpg_path, _a_parsing_path, _b_parsing_path,
         _a_json_path, b_json_path, _theta_pair_key) = self.get_paths(index)

        b_label_tensor, b_label_show_tensor = get_label_tensor(
            b_json_path, b_jpg_path, self.opt)

        return {
            'a_image_tensor': get_image_tensor(a_jpg_path, self.opt),
            'b_image_tensor': get_image_tensor(b_jpg_path, self.opt),
            'b_label_tensor': b_label_tensor,
            'b_label_show_tensor': b_label_show_tensor,
            'a_jpg_path': a_jpg_path,
            'b_jpg_path': b_jpg_path,
        }
# Beispiel #4 (scrape-artifact snippet separator; commented out so the file parses)
    def __getitem__(self, index):
        """Build one training sample including warp thetas and a warp-block policy."""
        # NOTE: the ground-truth parsing is used for the B label for now; a
        # label generated from parsing A + pose B may replace it later.
        (a_jpg_path, b_jpg_path, a_parsing_path, b_parsing_path,
         a_json_path, b_json_path, theta_pair_key) = self.get_paths(index)
        b_label_tensor, b_label_show_tensor = get_label_tensor(
            b_json_path, b_jpg_path, self.opt)

        a_parsing_tensor = get_parsing_label_tensor(a_parsing_path, self.opt)
        b_parsing_tensor = get_parsing_label_tensor(b_parsing_path, self.opt)
        a_image_tensor = get_image_tensor(a_jpg_path, self.opt)
        b_image_tensor = get_image_tensor(b_jpg_path, self.opt)

        # [1,18] theta vectors -> 1*256*256 sampling grids.
        theta_aff_tensor, theta_tps_tensor, theta_aff_tps_tensor = \
            get_thetas_affgrid_tensor(self.affTnf, self.tpsTnf,
                                      self.theta_json_data, theta_pair_key)

        # Policy tensor: entries [0, 0] and [0, 1] of the single channel encode
        # whether each of the two warp blocks is enabled (1) or skipped (0).
        # BUGFIX: policy_binary was previously left undefined on the
        # dynamic-policy branch, raising NameError when building input_dict.
        policy_binary = torch.zeros(1, a_parsing_tensor.size(1),
                                    a_parsing_tensor.size(2))
        if not self.opt.no_dynamic_policy:
            # TODO: restore the dynamic policy, e.g.
            # policy_binary = get_policy_tensor(self.policy_json_data, theta_pair_key)
            pass
        else:
            # Fixed policies for the four ablation models; unknown values fall
            # back to all-zero, matching the original elif chain.
            on_off = {
                'policy1': (0, 0),  # 00: both warp blocks disabled
                'policy2': (0, 1),  # 01
                'policy3': (1, 0),  # 10
                'policy4': (1, 1),  # 11: both warp blocks enabled
            }.get(self.opt.which_policy, (0, 0))
            policy_binary[:, 0, 0] = on_off[0]
            policy_binary[:, 0, 1] = on_off[1]

        input_dict = {
            'a_image_tensor': a_image_tensor,
            'b_image_tensor': b_image_tensor,
            'b_label_tensor': b_label_tensor,
            'a_parsing_tensor': a_parsing_tensor,
            'b_parsing_tensor': b_parsing_tensor,
            'b_label_show_tensor': b_label_show_tensor,
            'theta_aff_tensor': theta_aff_tensor,
            'theta_tps_tensor': theta_tps_tensor,
            'theta_aff_tps_tensor': theta_aff_tps_tensor,
            'policy_binary': policy_binary,
            'a_jpg_path': a_jpg_path,
            'b_jpg_path': b_jpg_path}

        return input_dict
    def __getitem__(self, index):
        """Build one training sample: image pair, pose-B label and parsing A."""
        a_jpg_path, b_jpg_path, a_parsing_path, b_json_path = self.get_paths(index)

        b_label_tensor, b_label_show_tensor = get_label_tensor(
            b_json_path, b_jpg_path, self.opt)

        sample = {
            'a_image_tensor': get_image_tensor(a_jpg_path, self.opt),
            'b_image_tensor': get_image_tensor(b_jpg_path, self.opt),
            'b_label_tensor': b_label_tensor,
            'a_parsing_tensor': get_parsing_label_tensor(a_parsing_path, self.opt),
            'b_label_show_tensor': b_label_show_tensor,
            'a_jpg_path': a_jpg_path,
            'b_jpg_path': b_jpg_path,
        }
        return sample
# Beispiel #6 (scrape-artifact snippet separator; commented out so the file parses)
    def generator_loss(self, input_images: torch.Tensor,
                       reference_images: torch.Tensor) -> torch.Tensor:
        """GAN + weighted reconstruction + cached PAN loss for the generator step."""
        forged_images = self(input_images)

        # Discriminator sees the first input channel concatenated with the forgery.
        disc_input = torch.cat((input_images[:, 0:1], forged_images), 1)
        disc_output = self.discriminator(disc_input)
        real_label = get_label_tensor(disc_output, 1.0, self.device)

        # The generator wants forged images to be judged real.
        loss_G_GAN = self.hparams.gan_criterion(disc_output, real_label)

        # Reconstruction loss against the first reference channel, scaled by lambda_l.
        loss_G_L = self.hparams.lambda_l * self.hparams.l_criterion(
            forged_images, reference_images[:, 0:1])

        self.log('g_GAN_loss', loss_G_GAN.detach(), on_step=True, on_epoch=False)
        self.log('g_L_loss', loss_G_L.detach(), on_step=True, on_epoch=False)

        # self.loss_PAN is the detached PAN term cached by discriminator_loss.
        return loss_G_GAN + loss_G_L + self.loss_PAN