Example #1
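The original `Trainer.forward` method from a WGAN-GP image-inpainting setup: a single pass computes the discriminator losses on both a local patch (around the hole) and the full image, and optionally the generator losses.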
    def forward(self, x, bboxes, masks, ground_truth, compute_loss_g=False):
        self.train()
        l1_loss = nn.L1Loss()
        losses = {}

        x1, x2, offset_flow = self.netG(x, masks)
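        # x1 is the coarse prediction, x2 the refined one (note that only
        # the x1 losses are scaled by coarse_l1_alpha below); offset_flow
        # visualizes the contextual-attention offsets.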
        local_patch_gt = local_patch(ground_truth, bboxes)
        x1_inpaint = x1 * masks + x * (1. - masks)
        x2_inpaint = x2 * masks + x * (1. - masks)
        local_patch_x1_inpaint = local_patch(x1_inpaint, bboxes)
        local_patch_x2_inpaint = local_patch(x2_inpaint, bboxes)

        # D part
        # wgan d loss
        local_patch_real_pred, local_patch_fake_pred = self.dis_forward(
            self.localD, local_patch_gt, local_patch_x2_inpaint.detach())
        global_real_pred, global_fake_pred = self.dis_forward(
            self.globalD, ground_truth, x2_inpaint.detach())
        losses['wgan_d'] = torch.mean(local_patch_fake_pred - local_patch_real_pred) + \
            torch.mean(global_fake_pred - global_real_pred) * \
            self.config['global_wgan_loss_alpha']
        # gradient penalty loss (WGAN-GP: keep the critic's gradient norm
        # near 1 on interpolates between real and fake samples)
        local_penalty = self.calc_gradient_penalty(
            self.localD, local_patch_gt, local_patch_x2_inpaint.detach())
        global_penalty = self.calc_gradient_penalty(self.globalD, ground_truth,
                                                    x2_inpaint.detach())
        losses['wgan_gp'] = local_penalty + global_penalty

        # G part
        if compute_loss_g:
            sd_mask = spatial_discounting_mask(self.config)
            losses['l1'] = l1_loss(local_patch_x1_inpaint * sd_mask, local_patch_gt * sd_mask) * \
                self.config['coarse_l1_alpha'] + \
                l1_loss(local_patch_x2_inpaint * sd_mask,
                        local_patch_gt * sd_mask)
            losses['ae'] = l1_loss(x1 * (1. - masks), ground_truth * (1. - masks)) * \
                self.config['coarse_l1_alpha'] + \
                l1_loss(x2 * (1. - masks), ground_truth * (1. - masks))

            # wgan g loss
            local_patch_real_pred, local_patch_fake_pred = self.dis_forward(
                self.localD, local_patch_gt, local_patch_x2_inpaint)
            global_real_pred, global_fake_pred = self.dis_forward(
                self.globalD, ground_truth, x2_inpaint)
            losses['wgan_g'] = - torch.mean(local_patch_fake_pred) - \
                torch.mean(global_fake_pred) * \
                self.config['global_wgan_loss_alpha']

        return losses, x2_inpaint, offset_flow
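Below, the same logic extracted from the Trainer class into standalone functions: the generator, both discriminators, and the config are passed in explicitly, and `dis_forward` becomes the free function `discriminator_pred`, which scores real and fake inputs in one batched discriminator pass.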
import torch
import torch.nn as nn

# Assumes the repository's helpers are in scope, e.g.
# from utils.tools import local_patch, spatial_discounting_mask
# plus a free-function variant of calc_gradient_penalty that
# accepts local_rank.


def discriminator_pred(netD, ground_truth, generated):
    # Free-function replacement for Trainer.dis_forward: run the real and
    # fake samples through the discriminator in a single batched pass,
    # then split the output back into the two predictions.
    batch_size = ground_truth.size(0)
    batch_output = netD(torch.cat([ground_truth, generated], dim=0))
    real_pred, fake_pred = torch.split(batch_output, batch_size, dim=0)
    return real_pred, fake_pred


def forward(config, x, bboxes, masks, ground_truth,
            netG, localD, globalD,
            local_rank, compute_loss_g=False):
    l1_loss = nn.L1Loss().cuda(local_rank)
    losses = {}

    x1, x2, offset_flow = netG(x, masks)
    local_patch_gt = local_patch(ground_truth, bboxes)
    x1_inpaint = x1 * masks + x * (1. - masks)
    x2_inpaint = x2 * masks + x * (1. - masks)
    local_patch_x1_inpaint = local_patch(x1_inpaint, bboxes)
    local_patch_x2_inpaint = local_patch(x2_inpaint, bboxes)

    # D part
    # wgan d loss (real and fake scored in one batched pass; the fake
    # branch is detached so only the discriminators receive gradients)
    local_patch_real_pred, local_patch_fake_pred = discriminator_pred(
        localD, local_patch_gt, local_patch_x2_inpaint.detach())
    global_real_pred, global_fake_pred = discriminator_pred(
        globalD, ground_truth, x2_inpaint.detach())

    losses['wgan_d'] = torch.mean(local_patch_fake_pred - local_patch_real_pred) + \
        torch.mean(global_fake_pred - global_real_pred) * config['global_wgan_loss_alpha']

    # gradient penalty loss
    local_penalty = calc_gradient_penalty(
        localD, local_patch_gt, local_patch_x2_inpaint.detach(), local_rank)
    global_penalty = calc_gradient_penalty(
        globalD, ground_truth, x2_inpaint.detach(), local_rank)
    losses['wgan_gp'] = local_penalty + global_penalty

    # G part
    if compute_loss_g:
        # spatial_discounting_mask down-weights the reconstruction error
        # for pixels far from the hole boundary (spatially discounted L1).
        sd_mask = spatial_discounting_mask(config)
        losses['l1'] = l1_loss(local_patch_x1_inpaint * sd_mask, local_patch_gt * sd_mask) * \
            config['coarse_l1_alpha'] + \
            l1_loss(local_patch_x2_inpaint * sd_mask, local_patch_gt * sd_mask)

        losses['ae'] = l1_loss(x1 * (1. - masks), ground_truth * (1. - masks)) * \
            config['coarse_l1_alpha'] + \
            l1_loss(x2 * (1. - masks), ground_truth * (1. - masks))

        # wgan g loss (no detach here: gradients must reach the generator)
        local_patch_real_pred, local_patch_fake_pred = discriminator_pred(
            localD, local_patch_gt, local_patch_x2_inpaint)
        global_real_pred, global_fake_pred = discriminator_pred(
            globalD, ground_truth, x2_inpaint)

        losses['wgan_g'] = -torch.mean(local_patch_fake_pred) - \
            torch.mean(global_fake_pred) * config['global_wgan_loss_alpha']

    return losses, x2_inpaint, offset_flow
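A minimal sketch of the batched real/fake pass in isolation, using a toy critic; the module and tensor shapes below are illustrative, not from the repository:

import torch
import torch.nn as nn

# Toy critic standing in for localD/globalD; any module that maps
# images to per-sample scores fits the pattern.
critic = nn.Sequential(nn.Flatten(), nn.Linear(3 * 64 * 64, 1))

real = torch.randn(4, 3, 64, 64)
fake = torch.randn(4, 3, 64, 64)

# One forward pass instead of two: concatenate along the batch
# dimension, score everything, then split the result back apart.
scores = critic(torch.cat([real, fake], dim=0))
real_pred, fake_pred = torch.split(scores, real.size(0), dim=0)

# WGAN critic loss, as in losses['wgan_d'] above.
wgan_d = torch.mean(fake_pred - real_pred)

calc_gradient_penalty itself is not shown above; a generic WGAN-GP penalty usually looks like the following sketch (not the repository's implementation; the device argument stands in for local_rank):

def gradient_penalty_sketch(netD, real, fake, device):
    # Sample random interpolation points between real and fake images.
    alpha = torch.rand(real.size(0), 1, 1, 1, device=device)
    interpolates = (alpha * real + (1. - alpha) * fake).requires_grad_(True)
    d_out = netD(interpolates)
    # Gradient of the critic output w.r.t. the interpolated inputs.
    grads = torch.autograd.grad(
        outputs=d_out, inputs=interpolates,
        grad_outputs=torch.ones_like(d_out),
        create_graph=True, retain_graph=True)[0]
    grads = grads.view(grads.size(0), -1)
    # Penalize deviation of the gradient norm from 1 (Gulrajani et al., 2017).
    return ((grads.norm(2, dim=1) - 1.) ** 2).mean()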