def test_simple_forward_pass_nll_top(self):
    """Check the summed 'nll' loss with top_n=3 on a single score vector.

    The expected value 13.6171 is a hard-coded reference; presumably it was
    precomputed independently of ExplodedLogitLoss (TODO confirm source).
    """
    criterion = ExplodedLogitLoss(loss_type='nll', reduction='sum', top_n=3)
    logits = torch.tensor([1.2, 4.8, 0.2, 5.6, 7.4, 0.], dtype=torch.float64)
    # NOTE(review): presumably a 1-based rank per item; confirm against
    # ExplodedLogitLoss.forward.
    ranking = torch.tensor([6, 5, 3, 4, 2, 1], dtype=torch.long)
    expected = torch.tensor(13.6171, dtype=torch.float64)

    actual = criterion.forward(logits, ranking)

    failure_msg = "Forward pass not valid: {0} != {1}".format(actual, expected)
    self.assertTrue(torch.isclose(actual, expected, atol=1e-4), failure_msg)
def test_batch_forward_pass_bce(self):
    """Check the summed 'bce' loss over a batch of identical rows.

    Every row in the batch is the same, so with reduction='sum' the expected
    total is the single-row reference value 17.9922 times the batch size.
    """
    single_scores = [1.2, 4.8, 0.2, 5.6, 7.4, 0.]
    single_order = [6, 5, 3, 4, 2, 1]
    batch_size = 2

    criterion = ExplodedLogitLoss(loss_type='bce', reduction='sum')
    logits = torch.tensor([single_scores] * batch_size, dtype=torch.float64)
    ranking = torch.tensor([single_order] * batch_size, dtype=torch.long)
    expected = torch.tensor(17.9922 * batch_size, dtype=torch.float64)

    actual = criterion.forward(logits, ranking)

    failure_msg = "Forward pass not valid: {0} != {1}".format(actual, expected)
    self.assertTrue(torch.isclose(actual, expected, atol=1e-4), failure_msg)
def test_simple_backward_pass(self):
    """Check the gradient of the summed 'bce' loss w.r.t. the input scores.

    The expected gradient is a hard-coded reference (tolerance 1e-4);
    presumably precomputed independently of ExplodedLogitLoss (TODO confirm).
    """
    criterion = ExplodedLogitLoss(loss_type='bce', reduction='sum')
    scores = torch.tensor([1.2, 4.8, 0.2, 5.6, 7.4, 0.],
                          dtype=torch.float64, requires_grad=True)
    order = torch.tensor([6, 5, 3, 4, 2, 1], dtype=torch.long)

    # Keep the criterion and the resulting loss tensor under distinct names
    # instead of rebinding `loss` from the module to its output.
    loss_value = criterion.forward(scores, order)
    loss_value.backward()

    grad_expected = torch.tensor(
        [0.0604, 0.4864, -1.0052, 0.3989, 1.0611, -1.0016],
        dtype=torch.float64)
    grad_actual = scores.grad
    # If backward() never reached `scores`, .grad stays None and
    # torch.allclose would raise TypeError rather than fail the test cleanly.
    self.assertIsNotNone(grad_actual, "Gradient was not populated for scores")
    self.assertTrue(
        torch.allclose(grad_actual, grad_expected, atol=1e-4),
        "Gradient is not valid:\n{0}\n{1}".format(grad_actual, grad_expected))