def test_focal_loss(self):
    """Validate focal loss against cross-entropy and its gamma=1 scaling.

    Runs the shared loss-function checks first. It then verifies two numeric
    identities:

    * with ``gamma=0`` focal loss matches ``cross_entropy`` on random logits;
    * with uniform logits, ``gamma=1`` equals ``(1 - 1/C)`` times the
      ``gamma=0`` loss.
    """
    # Shared verification performed by the test-case helper
    self._test_loss_function('focal_loss')

    batch_size, n_classes = 2, 4
    height = width = 20

    # Draw logits first, then labels: same RNG consumption order as before
    logits = torch.rand(batch_size, n_classes, height, width)
    # Scaling rand() by n_classes and truncating yields labels in [0, n_classes - 1]
    labels = (n_classes * torch.rand(batch_size, height, width)).to(torch.long)

    # Without the modulating term (gamma=0) focal loss should equal cross-entropy
    focal_plain = F.focal_loss(logits, labels, gamma=0).item()
    ce_reference = nn.functional.cross_entropy(logits, labels).item()
    self.assertAlmostEqual(focal_plain, ce_reference, places=5)

    # Uniform logits -> softmax gives 1 / n_classes for every class, so the
    # gamma=1 modulating factor is the constant (1 - 1 / n_classes)
    logits = torch.ones(batch_size, n_classes, height, width)
    scale = 1 - 1 / n_classes
    scaled_plain = scale * F.focal_loss(logits, labels, gamma=0).item()
    focal_gamma1 = F.focal_loss(logits, labels, gamma=1).item()
    self.assertAlmostEqual(scaled_plain, focal_gamma1, places=5)
def test_focal_loss():
    """Validate focal loss numerically and check the ``FocalLoss`` module repr.

    Runs the shared loss-function checks on ``F.focal_loss``, then asserts:

    * ``gamma=0`` focal loss is close to ``cross_entropy`` on random logits;
    * on uniform logits, ``gamma=1`` equals ``(1 - 1/C)`` times the
      ``gamma=0`` loss;
    * the default ``nn.FocalLoss()`` repr string.
    """
    # Shared verification helper applied to the functional form
    _test_loss_function(F.focal_loss)

    batch_size, n_classes = 2, 4
    height = width = 20

    # Same RNG consumption order as before: logits first, then labels
    logits = torch.rand(batch_size, n_classes, height, width)
    # Truncating n_classes * U[0, 1) yields integer labels in [0, n_classes - 1]
    labels = (n_classes * torch.rand(batch_size, height, width)).to(torch.long)

    # gamma=0 disables the modulating term, so the result should match cross-entropy
    focal_plain = F.focal_loss(logits, labels, gamma=0)
    ce_reference = cross_entropy(logits, labels)
    assert torch.allclose(focal_plain, ce_reference, atol=1e-5)

    # Constant logits make every class probability 1 / n_classes, so gamma=1
    # multiplies the per-element loss by the constant (1 - 1 / n_classes)
    logits = torch.ones(batch_size, n_classes, height, width)
    scale = 1 - 1 / n_classes
    scaled_plain = scale * F.focal_loss(logits, labels, gamma=0)
    focal_gamma1 = F.focal_loss(logits, labels, gamma=1)
    assert torch.allclose(scaled_plain, focal_gamma1, atol=1e-5)

    # Module repr with default hyper-parameters
    assert repr(nn.FocalLoss()) == "FocalLoss(gamma=2.0, reduction='mean')"