    def test_simple_forward_pass_nll_top(self):
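        """Forward value of the NLL variant with top_n=3 against a precomputed reference."""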
        loss = ExplodedLogitLoss(loss_type='nll', reduction='sum', top_n=3)

        scores = torch.tensor([1.2, 4.8, 0.2, 5.6, 7.4, 0.],
                              dtype=torch.float64)
        order = torch.tensor([6, 5, 3, 4, 2, 1], dtype=torch.long)

        loss_expected = torch.tensor(13.6171, dtype=torch.float64)
        loss_actual = loss.forward(scores, order)
        self.assertTrue(
            torch.isclose(loss_actual, loss_expected, atol=1e-4),
            "Forward pass not valid: {0} != {1}".format(
                loss_actual, loss_expected))

    def test_batch_forward_pass_bce(self):
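        """Forward value of the BCE variant on a batch of two identical inputs."""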
        loss = ExplodedLogitLoss(loss_type='bce', reduction='sum')

        scores = torch.tensor(
            [[1.2, 4.8, 0.2, 5.6, 7.4, 0.], [1.2, 4.8, 0.2, 5.6, 7.4, 0.]],
            dtype=torch.float64)
        order = torch.tensor([[6, 5, 3, 4, 2, 1], [6, 5, 3, 4, 2, 1]],
                             dtype=torch.long)

        loss_expected = torch.tensor(17.9922 * 2, dtype=torch.float64)
        loss_actual = loss.forward(scores, order)
        self.assertTrue(
            torch.isclose(loss_actual, loss_expected, atol=1e-4),
            "Forward pass not valid: {0} != {1}".format(
                loss_actual, loss_expected))

    def test_simple_backward_pass(self):
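        """Gradient of the BCE variant against precomputed reference values."""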
        loss = ExplodedLogitLoss(loss_type='bce', reduction='sum')

        scores = torch.tensor([1.2, 4.8, 0.2, 5.6, 7.4, 0.],
                              dtype=torch.float64,
                              requires_grad=True)
        order = torch.tensor([6, 5, 3, 4, 2, 1], dtype=torch.long)

        loss_value = loss.forward(scores, order)
        loss_value.backward()

        grad_expected = torch.tensor(
            [0.0604, 0.4864, -1.0052, 0.3989, 1.0611, -1.0016],
            dtype=torch.float64)
        grad_actual = scores.grad

        self.assertTrue(
            torch.allclose(grad_actual, grad_expected, atol=1e-4),
            "Gradient is not valid:\n{0}\n{1}".format(grad_actual,
                                                      grad_expected))
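
    def test_gradcheck_bce(self):
        """Numerical gradient check for the BCE variant.

        A minimal sketch, not part of the original suite: it assumes the
        same ExplodedLogitLoss(scores, order) forward signature used in the
        tests above and relies on torch.autograd.gradcheck with float64
        inputs.
        """
        loss = ExplodedLogitLoss(loss_type='bce', reduction='sum')

        scores = torch.tensor([1.2, 4.8, 0.2, 5.6, 7.4, 0.],
                              dtype=torch.float64,
                              requires_grad=True)
        order = torch.tensor([6, 5, 3, 4, 2, 1], dtype=torch.long)

        # gradcheck compares the analytical backward pass against finite
        # differences; float64 inputs keep the numerical error small.
        self.assertTrue(
            torch.autograd.gradcheck(
                lambda s: loss.forward(s, order), (scores,), atol=1e-4))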