示例#1
0
    def test_bin_seg_2d(self):
        """Focal loss is ~0 for near-perfect logits on a 2D binary mask.

        The target is a 4x4 mask with a 2x2 foreground square. The prediction
        is the one-hot encoding of that target scaled by 1000, i.e. extremely
        confident logits for the correct class, so the loss is expected to be
        close to zero. The check runs twice: once with a (B, H, W) target and
        once with a (B, 1, H, W) target, to exercise both accepted layouts.
        """
        # 2D binary mask: a 2x2 foreground square inside a 4x4 image
        target = torch.tensor([[0, 0, 0, 0],
                               [0, 1, 1, 0],
                               [0, 1, 1, 0],
                               [0, 0, 0, 0]])
        # add a leading batch dimension (batch size = 1 here)
        target = target.unsqueeze(0)  # shape (1, H, W)
        # one_hot appends the class axis last -> (1, H, W, 2); permute moves it
        # to channel-first -> (1, 2, H, W). Scaling by 1000 makes the logits
        # overwhelmingly favour the correct class at every pixel.
        pred_very_good = 1000 * F.one_hot(target, num_classes=2).permute(
            0, 3, 1, 2).float()

        # initialize the focal loss with its default settings
        loss = FocalLoss()

        # focal loss for pred_very_good should be close to 0
        focal_loss_good = loss.forward(pred_very_good, target).item()
        self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

        # same check, but the target now carries an explicit class dimension
        target = target.unsqueeze(1)  # shape (1, 1, H, W)
        focal_loss_good = loss.forward(pred_very_good, target).item()
        self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
示例#2
0
 def test_consistency_with_cross_entropy_classification(self):
     """With gamma=0, FocalLoss must agree with ``nn.CrossEntropyLoss``.

     For gamma=0 the focal loss reduces to the cross entropy loss, so on 100
     random classification batches the two mean-reduced losses are compared
     and the largest absolute difference must be 0 to 3 decimal places.
     FocalLoss is fed labels of shape (batch_size, 1), while
     CrossEntropyLoss is fed the same labels flattened to (batch_size,).
     """
     focal_loss = FocalLoss(gamma=0.0, reduction="mean")
     ce = nn.CrossEntropyLoss(reduction="mean")
     class_num = 10
     batch_size = 128
     # create tensors directly on the target device instead of copying later
     device = "cuda" if torch.cuda.is_available() else "cpu"
     max_error = 0.0
     for _ in range(100):
         # random scores of shape (batch_size, class_num); no backward pass is
         # run, so gradient tracking is unnecessary
         x = torch.rand(batch_size, class_num, device=device)
         # random class indices of shape (batch_size, 1); torch.randint
         # already returns int64, the dtype CrossEntropyLoss needs for indices
         labels = torch.randint(
             low=0, high=class_num, size=(batch_size, 1), device=device
         )
         output0 = focal_loss.forward(x, labels)
         # CrossEntropyLoss expects class-index targets of shape (batch_size,)
         output1 = ce.forward(x, labels[:, 0])
         max_error = max(max_error, abs(output0.item() - output1.item()))
     self.assertAlmostEqual(max_error, 0.0, places=3)