Example #1
0
 def test_update(self):
     """Calling update() after a step should apply the scheduler's lr decay."""
     weights = [torch.nn.Parameter(torch.randn(2, 3, 4))]
     wrapped = Optimizer(torch.optim.Adam(weights, lr=1), max_grad_norm=5)
     # Attach a StepLR that decays the learning rate by 10x every step.
     wrapped.set_scheduler(StepLR(wrapped.optimizer, 1, gamma=0.1))
     wrapped.step()
     wrapped.update(10, 1)
     # lr started at 1; one scheduler step with gamma=0.1 should yield 0.1.
     self.assertEqual(0.1, wrapped.optimizer.param_groups[0]['lr'])
Example #2
0
 def test_step(self, mock_clip_grad_norm):
     """step() must invoke gradient-norm clipping exactly once."""
     weight = torch.nn.Parameter(torch.randn(2, 3, 4))
     wrapped = Optimizer(torch.optim.Adam([weight]), max_grad_norm=5)
     wrapped.step()
     # The patched clip_grad_norm (injected by the test decorator, not
     # visible here) is expected to fire once per optimizer step.
     mock_clip_grad_norm.assert_called_once()