def criterion(self, logit, logit_pixel, logit_image, truth_pixel, is_average=True):
    """Compute the segmentation, per-branch pixel and image-classification losses.

    Args:
        logit: Segmentation logits scored on their own against
            ``truth_pixel`` (returned as ``loss_seg``).
        logit_pixel: Non-empty indexable collection of segmentation logits;
            each entry is scored against ``truth_pixel`` and the results
            are averaged over the entries.
        logit_image: Image-level logits for the "mask is non-empty"
            classifier. Assumed to be shape-compatible with a
            ``(batch_size,)`` target -- TODO confirm against the model head.
        truth_pixel: Ground-truth mask tensor whose first dimension is the
            batch (``(batch, c, h, w)`` in the original code).
        is_average: When true, reduce the masked pixel loss to a scalar by
            dividing its sum by the number of non-empty images.

    Returns:
        Tuple ``(loss_seg, loss_pixel, loss_image)``.

    Raises:
        ValueError: If ``self.loss_type`` is neither ``'focal'`` nor
            ``'lovasz'``.
    """
    # Fail fast with a clear message instead of an unbound-variable error
    # further down when the configured loss is not supported.
    if self.loss_type == 'focal':
        def seg_loss(prediction):
            return RobustFocalLoss2d()(prediction, truth_pixel, type='sigmoid')
    elif self.loss_type == 'lovasz':
        def seg_loss(prediction):
            return L.lovasz_hinge(prediction, truth_pixel)
    else:
        raise ValueError(
            "Unsupported loss_type %r; expected 'focal' or 'lovasz'"
            % (self.loss_type,)
        )

    ## image classification loss
    # An image is labelled positive when its mask sums to more than 0.5.
    # Computed fully in torch (no numpy round-trip) and placed on the same
    # device as the logits it is compared with, rather than a fixed 'cuda:0'.
    batch_size = truth_pixel.shape[0]
    truth_image = (
        truth_pixel.reshape(batch_size, -1).sum(dim=1) > 0.5
    ).to(dtype=torch.float32, device=logit_image.device)
    loss_image = F.binary_cross_entropy_with_logits(logit_image, truth_image)

    ## segmentation loss
    k = len(logit_pixel)
    loss_pixel = seg_loss(logit_pixel[0])
    for prediction in logit_pixel[1:]:
        # Out-of-place add keeps autograd away from in-place modification.
        loss_pixel = loss_pixel + seg_loss(prediction)
    loss_pixel = loss_pixel * (1.0 / k)
    loss_seg = seg_loss(logit)

    ## non-empty image seg loss
    loss_pixel = loss_pixel * truth_image  # loss for empty image is weighted 0

    if is_average:
        # With no non-empty image in the batch the numerator is already 0;
        # clamping the count to 1 yields 0 instead of NaN from 0 / 0.
        loss_pixel = loss_pixel.sum() / truth_image.sum().clamp(min=1.0)

    return loss_seg, loss_pixel, loss_image
예제 #2
0
    def criterion(self, logit, truth):
        """Return the Lovasz hinge loss of ``logit`` against ``truth``."""
        # Other candidate losses were left disabled here in the original:
        # PseudoBCELoss2d, FocalLoss2d and RobustFocalLoss2d (the latter two
        # called with type='sigmoid').
        return L.lovasz_hinge(logit, truth)
예제 #3
0
 def forward(self, logit, truth):
     """Return ``lovasz_losses.lovasz_hinge`` evaluated on ``logit`` and ``truth``."""
     return lovasz_losses.lovasz_hinge(logit, truth)