def test_on_criterion_validation(self):
    """Check that the validation hook adds the reduced divergence to the loss.

    ``compute`` is mocked to return a 2x2 tensor of ones; with the sum/sum
    reduction the contribution is 4, added onto a zero loss. Also checks
    that ``compute`` is called once with ``test=state[key]``.
    """
    divergence = DivergenceBase({'test': key}).with_sum_sum_reduction()
    divergence.compute = Mock(
        return_value=torch.ones((2, 2), requires_grad=True))
    state = {torchbearer.LOSS: torch.zeros(1, requires_grad=True), key: 1}

    divergence.on_criterion_validation(state)

    # assertEqual (rather than assertTrue(a == b)) reports the actual value on failure.
    self.assertEqual(state[torchbearer.LOSS].item(), 4)
    divergence.compute.assert_called_once_with(test=1)
def test_with_linear_capacity(self):
    """Check the loss sequence produced by ``with_linear_capacity`` over training steps.

    ``compute`` is mocked to return a 2x2 tensor of ones, so the sum/sum
    reduced divergence is 4. The first ``on_criterion`` call happens before
    any ``on_step_training``. Each later iteration resets the loss to zero,
    advances one training step, then re-evaluates the criterion.
    """
    divergence = DivergenceBase({
        'test': key
    }).with_sum_sum_reduction().with_linear_capacity(min_c=0, max_c=6, steps=6, gamma=2)
    divergence.compute = Mock(
        return_value=torch.ones((2, 2), requires_grad=True))
    state = {torchbearer.LOSS: torch.zeros(1, requires_grad=True), key: 1}

    # NOTE(review): these values are consistent with gamma * |4 - c|, where the
    # capacity c presumably rises by 1 per training step from min_c=0 to max_c=6.
    # Confirm against the DivergenceBase implementation.
    expected_losses = [8, 6, 4, 2, 0, 2, 4]

    # Step 0: criterion evaluated before any training step.
    divergence.on_criterion(state)
    self.assertEqual(state[torchbearer.LOSS].item(), expected_losses[0],
                     msg='step 0')

    for step, expected in enumerate(expected_losses[1:], start=1):
        # Reset the loss so each step's value is checked in isolation.
        state[torchbearer.LOSS] = torch.zeros(1)
        divergence.on_step_training(state)
        divergence.on_criterion(state)
        self.assertEqual(state[torchbearer.LOSS].item(), expected,
                         msg='step {}'.format(step))
def test_empty_compute(self):
    """The base class leaves ``compute`` unimplemented; calling it must raise."""
    base_divergence = DivergenceBase({'test': key})
    with self.assertRaises(NotImplementedError):
        base_divergence.compute()