Exemplo n.º 1
0
    def backward_G(self):
        """Compute the generator objective and back-propagate through it.

        Every weighted loss term is stored on ``self`` as ``loss_<name>``.
        The attributes named by ``self.loss_names`` are then summed, skipping
        the discriminator entries ``'dis_img_rec'`` and ``'dis_img_gen'``,
        and ``backward()`` is called on the total.
        """
        opt = self.opt

        # Reconstruction (L1) and sampling-correctness terms.
        app_term = self.L1loss(self.img_gen, self.input_P2)
        correctness_term = self.Correctness(self.input_P2, self.input_P1,
                                            self.flow_fields, opt.attn_layer)
        self.loss_correctness_gen = correctness_term * opt.lambda_correct
        self.loss_app_gen = app_term * opt.lambda_rec

        # Adversarial term. net_D is passed to base_function._freeze before
        # scoring the generated image (presumably so this loss does not
        # update the discriminator's weights).
        base_function._freeze(self.net_D)
        fake_score = self.net_D(self.img_gen)
        self.loss_ad_gen = self.GANloss(fake_score, True, False) * opt.lambda_g

        # Regularization term computed from the flow fields.
        reg_term = self.Regularization(self.flow_fields)
        self.loss_regularization = reg_term * opt.lambda_regularization

        # Perceptual (VGG) content and style terms.
        content_term, style_term = self.Vggloss(self.img_gen, self.input_P2)
        self.loss_style_gen = style_term * opt.lambda_style
        self.loss_content_gen = content_term * opt.lambda_content

        skipped = ('dis_img_rec', 'dis_img_gen')
        total = sum(getattr(self, "loss_" + name)
                    for name in self.loss_names if name not in skipped)
        total.backward()
Exemplo n.º 2
0
    def backward_G(self):
        """Build the generator loss and run ``backward()`` on it.

        Weighted terms are stored on ``self`` as ``loss_<name>`` attributes.
        Every name in ``self.loss_names`` except ``'dis_img_gen'`` is summed
        into the final objective.
        """
        opt = self.opt

        # Regularization term: scales the precomputed ``self.loss_reg``
        # (intent: keep the transformed feature and the target image feature
        # in the same latent space).
        self.loss_reg_gen = self.loss_reg * opt.lambda_regularization

        # L1 reconstruction term.
        self.loss_app_gen = self.L1loss(self.img_gen, self.input_P2) * opt.lambda_rec

        # Parsing terms: parLoss (unweighted) against label_P2 with dim 1
        # squeezed and cast to long, plus an L1 term against input_SPL2
        # with a fixed weight of 100.
        target_labels = self.label_P2.squeeze(1).long()
        self.loss_par = self.parLoss(self.parsav, target_labels)
        self.loss_par1 = self.L1loss(self.parsav, self.input_SPL2) * 100

        # Adversarial term (net_D is frozen via base_function._freeze first).
        base_function._freeze(self.net_D)
        fake_score = self.net_D(self.img_gen)
        self.loss_ad_gen = self.GANloss(fake_score, True, False) * opt.lambda_g

        # Perceptual (VGG) content and style terms.
        content_term, style_term = self.Vggloss(self.img_gen, self.input_P2)
        self.loss_style_gen = style_term * opt.lambda_style
        self.loss_content_gen = content_term * opt.lambda_content

        total = sum(getattr(self, "loss_" + name)
                    for name in self.loss_names if name != 'dis_img_gen')
        total.backward()
    def backward_G(self):
        """Compute the generator objective and back-propagate through it.

        Stores each weighted term on ``self`` as ``loss_<name>``, sums every
        entry of ``self.loss_names`` except ``'dis_img_gen'``, and calls
        ``backward()`` on the result.
        """
        opt = self.opt

        # L1 reconstruction term.
        app_term = self.L1loss(self.img_gen, self.input_P2)
        self.loss_app_gen = app_term * opt.lambda_rec

        # Sampling-correctness term between target/source images and flows.
        correctness_term = self.Correctness(self.input_P2, self.input_P1,
                                            self.flow_fields, opt.attn_layer)
        self.loss_correctness_gen = correctness_term * opt.lambda_correct

        # Adversarial term (net_D is frozen via base_function._freeze first).
        base_function._freeze(self.net_D)
        fake_score = self.net_D(self.img_gen)
        self.loss_ad_gen = self.GANloss(fake_score, True, False) * opt.lambda_g

        # Regularization term computed from the flow fields.
        reg_term = self.Regularization(self.flow_fields)
        self.loss_regularization = reg_term * opt.lambda_regularization

        # Perceptual (VGG) content and style terms.
        content_term, style_term = self.Vggloss(self.img_gen, self.input_P2)
        self.loss_style_gen = style_term * opt.lambda_style
        self.loss_content_gen = content_term * opt.lambda_content

        total = sum(getattr(self, "loss_" + name)
                    for name in self.loss_names if name != 'dis_img_gen')
        total.backward()
    def backward_G(self):
        """Compute the multi-frame generator loss and back-propagate it.

        Per-frame terms are accumulated over the generated frames:
          * L1 reconstruction against ``P_frame_step``;
          * VGG content/style, only when ``opt.use_vgg_loss`` is set;
          * sampling correctness against the reference image (``*_r``) and
            the detached previous frame (``*_p``);
          * flow regularization, only when ``opt.use_affine_regularization``
            is set.
        When ``opt.use_gan`` is set, a spatial adversarial term (``net_D`` on
        one random frame) and a temporal one (``net_D_V`` on frame
        differences from a random window of ``opt.frames_D_V`` frames) are
        added.

        Every weighted term is stored as ``self.loss_<name>``. The entries of
        ``self.loss_names`` other than ``'dis_img_gen_v'`` and
        ``'dis_img_gen'`` are summed and ``backward()`` is called on the total.

        NOTE(review): with ``opt.use_gan`` false, ``loss_ad_gen`` and
        ``loss_ad_gen_v`` are not assigned here; ``loss_names`` presumably
        omits them in that configuration -- confirm.
        """
        opt = self.opt

        # Reconstruction and (optional) perceptual terms, per frame.
        app_sum, style_sum, content_sum = 0, 0, 0
        for idx, frame in enumerate(self.img_gen):
            target = self.P_frame_step[:, idx, ...]
            app_sum += self.L1loss(frame, target)
            if opt.use_vgg_loss:
                content_term, style_term = self.Vggloss(frame, target)
                style_sum += style_term
                content_sum += content_term

        self.loss_style_gen = style_sum * opt.lambda_style
        self.loss_content_gen = content_sum * opt.lambda_content
        self.loss_app_gen = app_sum * opt.lambda_rec

        # Flow terms. Even entries of each flow list are scored against the
        # previous frame, odd entries against the reference image.
        corr_prev, reg_prev = 0, 0
        corr_ref, reg_ref = 0, 0
        for idx, flows in enumerate(self.flow_fields):
            prev_flows, ref_flows = [], []
            for k in range(0, len(flows), 2):
                prev_flows.append(flows[k])
                ref_flows.append(flows[k + 1])

            target = self.P_frame_step[:, idx, ...]
            ref_corr = self.Correctness(target, self.P_reference, ref_flows,
                                        opt.attn_layer)
            prev_corr = self.Correctness(target,
                                         self.P_previous_recoder[idx].detach(),
                                         prev_flows, opt.attn_layer)
            corr_prev += prev_corr
            corr_ref += ref_corr
            if opt.use_affine_regularization:
                reg_prev += self.Regularization(prev_flows)
                reg_ref += self.Regularization(ref_flows)

        self.loss_correctness_p = corr_prev * opt.lambda_correct
        self.loss_correctness_r = corr_ref * opt.lambda_correct
        self.loss_regularization_p = reg_prev * opt.lambda_regularization
        self.loss_regularization_r = reg_ref * opt.lambda_regularization

        if opt.use_gan:
            # Spatial adversarial term on one randomly chosen frame.
            base_function._freeze(self.net_D)
            pick = np.random.randint(len(self.img_gen))
            frame_score = self.net_D(self.img_gen[pick])
            self.loss_ad_gen = self.GANloss(frame_score, True, False) * opt.lambda_g

            # Temporal adversarial term: differences between consecutive
            # frames of a random window, concatenated along dim 1.
            base_function._freeze(self.net_D_V)
            window = opt.frames_D_V
            start = np.random.randint(len(self.img_gen) - window + 1)
            diffs = [self.img_gen[start + k] - self.img_gen[start + k + 1]
                     for k in range(window - 1)]
            video_score = self.net_D_V(torch.cat(diffs, dim=1))
            self.loss_ad_gen_v = self.GANloss(video_score, True, False) * opt.lambda_g

        excluded = ('dis_img_gen_v', 'dis_img_gen')
        total = sum(getattr(self, "loss_" + name)
                    for name in self.loss_names if name not in excluded)
        total.backward()
Exemplo n.º 5
0
    def backward_G(self):
        """Calculate training loss for the generator and back-propagate it.

        Terms accumulated over the generated frames:
          * L1 reconstruction against ``P_step`` (``loss_app_gen``), plus a
            cloth-mask L1 term weighted by ``opt.lambda_cloth_mask`` that is
            currently always zero (see the TODO in the loop);
          * VGG content/style terms (``loss_content_gen``/``loss_style_gen``);
          * sampling-correctness and regularization terms for the flows,
            against the previous ground-truth frame (``*_p``) and the
            reference image (``*_r``); ``mask_step`` is passed to
            ``Correctness`` only when ``opt.use_mask`` is set.
        Adversarial terms: a spatial one from ``net_D`` on one random frame,
        and a temporal one from ``net_D_V`` on a random window of
        ``opt.frames_D_V`` consecutive frames stacked along a new dim 2.

        Each weighted term is stored as ``self.loss_<name>``; the entries of
        ``self.loss_names`` other than ``'dis_img_gen_v'`` and
        ``'dis_img_gen'`` are summed and ``backward()`` is called on the total.
        """
        loss_style_gen, loss_content_gen, loss_app_gen = 0, 0, 0
        loss_mask_app_gen = 0

        # Calculate the Reconstruction Loss
        for i in range(len(self.img_gen)):
            gen = self.img_gen[i]
            gt = self.P_step[:, i, ...]
            loss_app_gen += self.L1loss(gen, gt)
            # TODO: populate self.cloth_mask with the binary mask corresponding
            # to each frame, then enable the cloth-mask L1 term, e.g.:
            #     loss_mask_app_gen += self.L1loss(gen * self.cloth_mask,
            #                                      gt * self.cloth_mask)
            # Note: torch.nn.L1Loss is a module class and must be instantiated
            # before use; reuse self.L1loss rather than calling it directly.

            content_gen, style_gen = self.Vggloss(gen, gt)
            loss_style_gen += style_gen
            loss_content_gen += content_gen

        self.loss_style_gen = loss_style_gen * self.opt.lambda_style
        self.loss_content_gen = loss_content_gen * self.opt.lambda_content
        self.loss_app_gen = loss_app_gen * self.opt.lambda_rec
        # loss_mask_app_gen stays 0 until the TODO above is implemented.
        self.loss_app_gen += loss_mask_app_gen * self.opt.lambda_cloth_mask

        loss_correctness_p, loss_regularization_p = 0, 0
        loss_correctness_r, loss_regularization_r = 0, 0

        # Calculate the Sampling Correctness Loss and Regularization Loss.
        # Even entries of each flow list are scored against the previous
        # ground-truth frame, odd entries against the reference image.
        for i in range(len(self.flow_fields)):
            flow_field_i = self.flow_fields[i]
            flow_p, flow_r = [], []
            for j in range(0, len(flow_field_i), 2):
                flow_p.append(flow_field_i[j])
                flow_r.append(flow_field_i[j + 1])

            mask = self.mask_step[:, i, ...] if self.opt.use_mask else None
            correctness_r = self.Correctness(self.P_step[:, i, ...], self.ref_image,
                                             flow_r, self.opt.attn_layer, mask)
            correctness_p = self.Correctness(self.P_step[:, i, ...],
                                             self.P_gt_previous_recoder[:, i, ...],
                                             flow_p, self.opt.attn_layer, mask)
            loss_correctness_p += correctness_p
            loss_correctness_r += correctness_r

            loss_regularization_p += self.Regularization(flow_p)
            loss_regularization_r += self.Regularization(flow_r)

        self.loss_correctness_p = loss_correctness_p * self.opt.lambda_correct
        self.loss_correctness_r = loss_correctness_r * self.opt.lambda_correct
        self.loss_regularization_p = loss_regularization_p * self.opt.lambda_regularization
        self.loss_regularization_r = loss_regularization_r * self.opt.lambda_regularization

        # Spatial GAN Loss on one randomly chosen frame.
        base_function._freeze(self.net_D)
        i = np.random.randint(len(self.img_gen))
        fake = self.img_gen[i]
        D_fake = self.net_D(fake)
        self.loss_ad_gen = self.GANloss(D_fake, True, False) * self.opt.lambda_g

        # Temporal GAN Loss: frames_D_V consecutive frames, each unsqueezed
        # at dim 2 and concatenated along it.
        base_function._freeze(self.net_D_V)
        i = np.random.randint(len(self.img_gen) - self.opt.frames_D_V + 1)
        fake = []
        for frame in range(self.opt.frames_D_V):
            fake.append(self.img_gen[i + frame].unsqueeze(2))
        fake = torch.cat(fake, dim=2)
        D_fake = self.net_D_V(fake)
        self.loss_ad_gen_v = self.GANloss(D_fake, True, False) * self.opt.lambda_g

        total_loss = 0
        for name in self.loss_names:
            if name != 'dis_img_gen_v' and name != 'dis_img_gen':
                total_loss += getattr(self, "loss_" + name)
        total_loss.backward()