def test_ls_celoss(self):
    """Label-smoothing cross-entropy: common loss checks plus value checks
    at the two analytic limits (eps=0 -> plain CE, eps=1 -> uniform target).
    """
    batches, classes = 2, 4
    # Common verification shared by all loss tests
    self._test_loss_function('ls_cross_entropy', 0.1 / classes * (classes - 1) * 9)
    logits = torch.rand(batches, classes, 20, 20)
    labels = (classes * torch.rand(batches, 20, 20)).to(torch.long)
    # eps=0 must collapse to the standard cross-entropy
    self.assertAlmostEqual(
        F.ls_cross_entropy(logits, labels, eps=0).item(),
        nn.functional.cross_entropy(logits, labels).item(),
        places=5)
    # eps=1 is the mean negative log-softmax averaged uniformly over classes
    uniform_ref = -nn.functional.log_softmax(logits, dim=1).sum(dim=1).mean().item() / classes
    self.assertAlmostEqual(
        F.ls_cross_entropy(logits, labels, eps=1).item(),
        uniform_ref,
        places=5)
def test_ls_celoss():
    """Label-smoothing cross-entropy: common loss checks, value checks at the
    eps=0 (plain CE) and eps=1 (uniform target) limits, and the module repr.
    """
    batches, classes = 2, 4
    # Common verification shared by all loss tests
    _test_loss_function(F.ls_cross_entropy, 0.1 / classes * (classes - 1) * 9)
    logits = torch.rand(batches, classes, 20, 20)
    labels = (classes * torch.rand(batches, 20, 20)).to(torch.long)
    # eps=0 must collapse to the standard cross-entropy
    assert torch.allclose(
        F.ls_cross_entropy(logits, labels, eps=0),
        cross_entropy(logits, labels),
        atol=1e-5)
    # eps=1 is the mean negative log-softmax averaged uniformly over classes
    uniform_ref = -log_softmax(logits, dim=1).sum(dim=1).mean() / classes
    assert torch.allclose(F.ls_cross_entropy(logits, labels, eps=1), uniform_ref, atol=1e-5)
    # Module wrapper advertises its hyper-parameters via repr
    assert repr(nn.LabelSmoothingCrossEntropy()) == "LabelSmoothingCrossEntropy(eps=0.1, reduction='mean')"