def test_on_criterion_validation(self):
        """on_criterion_validation adds the reduced divergence to the loss.

        ``compute`` is mocked to return ``ones((2, 2))``. With sum-sum reduction,
        the loss in ``state`` is expected to go from 0 to 4. ``compute`` must be
        called exactly once, with the ``'test'`` keyword bound to ``state[key]``.
        """
        divergence = DivergenceBase({'test': key}).with_sum_sum_reduction()
        divergence.compute = Mock(
            return_value=torch.ones((2, 2), requires_grad=True))

        state = {torchbearer.LOSS: torch.zeros(1, requires_grad=True), key: 1}

        divergence.on_criterion_validation(state)
        # assertEqual (unlike assertTrue(x == 4)) reports the actual loss on failure.
        self.assertEqual(state[torchbearer.LOSS].item(), 4)
        divergence.compute.assert_called_once_with(test=1)

    def test_with_linear_capacity(self):
        """Linear capacity changes the divergence loss term on each training step.

        ``compute`` is mocked to return ``ones((2, 2))``, so the sum-sum reduced
        divergence is constant. Between checks the accumulated loss is reset to
        zero and ``on_step_training`` is called once.

        NOTE(review): the expected values are consistent with
        ``gamma * |4 - c|``, where the capacity ``c`` starts at ``min_c`` and grows
        by ``(max_c - min_c) / steps == 1`` per ``on_step_training`` call. This is
        inferred from the numbers; confirm against
        ``DivergenceBase.with_linear_capacity``.
        """
        divergence = DivergenceBase({
            'test': key
        }).with_sum_sum_reduction().with_linear_capacity(min_c=0,
                                                         max_c=6,
                                                         steps=6,
                                                         gamma=2)
        divergence.compute = Mock(
            return_value=torch.ones((2, 2), requires_grad=True))

        state = {torchbearer.LOSS: torch.zeros(1, requires_grad=True), key: 1}

        # The loss falls to 0 at step 4 and then rises again.
        expected_losses = [8, 6, 4, 2, 0, 2, 4]
        for step, expected in enumerate(expected_losses):
            if step > 0:
                # Clear the accumulated loss so each check sees only this step's
                # term, then advance the schedule by one training step.
                state[torchbearer.LOSS] = torch.zeros(1)
                divergence.on_step_training(state)
            divergence.on_criterion(state)
            self.assertEqual(state[torchbearer.LOSS].item(), expected,
                             'unexpected loss at step %d' % step)

 def test_empty_compute(self):
     """A bare DivergenceBase does not implement ``compute``.

     Calling ``compute`` with no arguments must raise ``NotImplementedError``.
     """
     divergence = DivergenceBase({'test': key})
     with self.assertRaises(NotImplementedError):
         divergence.compute()