def test_update(self):
    """Check that ``Optimizer.update`` drives the attached LR scheduler.

    A StepLR scheduler with ``step_size=1`` and ``gamma=0.1`` is attached to
    an Adam optimizer whose initial learning rate is 1. After one optimizer
    step followed by ``update(10, 1)``, the learning rate is expected to have
    decayed once, to 0.1.
    NOTE(review): this assumes ``update`` advances the scheduler by exactly
    one step; the meaning of the ``(10, 1)`` arguments is defined by
    ``Optimizer.update`` outside this view — confirm there.
    """
    params = [torch.nn.Parameter(torch.randn(2, 3, 4))]
    optimizer = Optimizer(torch.optim.Adam(params, lr=1), max_grad_norm=5)
    scheduler = StepLR(optimizer.optimizer, 1, gamma=0.1)
    optimizer.set_scheduler(scheduler)
    # Step the optimizer before the scheduler so torch's expected call order
    # (optimizer.step() before scheduler.step()) is respected.
    optimizer.step()
    optimizer.update(10, 1)
    # The LR is produced by float arithmetic inside the scheduler, so compare
    # approximately rather than with exact equality.
    self.assertAlmostEqual(0.1, optimizer.optimizer.param_groups[0]['lr'])
def test_step(self, mock_clip_grad_norm):
    """Check that ``Optimizer.step`` calls the gradient-clipping hook once.

    The wrapped optimizer is built with ``max_grad_norm=5``; a single call to
    ``step()`` must trigger exactly one call of ``mock_clip_grad_norm``.
    NOTE(review): ``mock_clip_grad_norm`` is presumably injected by a
    ``mock.patch`` decorator that is not visible here — confirm the patch
    target.
    """
    weight = torch.nn.Parameter(torch.randn(2, 3, 4))
    base_optimizer = torch.optim.Adam([weight])
    wrapped = Optimizer(base_optimizer, max_grad_norm=5)

    wrapped.step()

    mock_clip_grad_norm.assert_called_once()