    def test_set_parameters(self):
        learning_rate = 1
        optim = Optimizer(torch.optim.SGD, lr=learning_rate)
        params = [torch.nn.Parameter(torch.randn(2, 3, 4))]
        optim.set_parameters(params)

        # set_parameters should construct the wrapped torch optimizer
        # with the configured learning rate
        self.assertIs(type(optim.optimizer), torch.optim.SGD)
        self.assertEqual(optim.optimizer.param_groups[0]['lr'], learning_rate)

    def test_update(self):
        optim = Optimizer(torch.optim.SGD,
                          lr=1,
                          decay_after_epoch=5,
                          lr_decay=0.5)
        params = [torch.nn.Parameter(torch.randn(2, 3, 4))]
        optim.set_parameters(params)
        # epoch 10 is past decay_after_epoch=5, so the learning rate should
        # have been multiplied by lr_decay: 1 * 0.5 = 0.5
        optim.update(0, 10)
        self.assertEqual(optim.optimizer.param_groups[0]['lr'], 0.5)

    # test_step receives a mock argument, so it needs a patch decorator.
    # The target below is an assumption: patch wherever the Optimizer under
    # test actually resolves clip_grad_norm_ from (requires
    # `from unittest import mock` at the top of the file).
    @mock.patch('torch.nn.utils.clip_grad_norm_')
    def test_step(self, mock_clip_grad_norm):
        optim = Optimizer(torch.optim.Adam, max_grad_norm=5)
        params = [torch.nn.Parameter(torch.randn(2, 3, 4))]
        optim.set_parameters(params)
        # with max_grad_norm set, step() should clip gradients exactly once
        optim.step()
        mock_clip_grad_norm.assert_called_once()
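
# A minimal sketch of the interface these tests assume, inferred from the
# calls above. The real Optimizer lives in the package under test; the class
# name, the update(loss, epoch) signature, and the decay arithmetic here are
# assumptions, shown only for orientation, not the actual implementation.
class _ReferenceOptimizer(object):
    def __init__(self, optim_class, max_grad_norm=0, lr_decay=1,
                 decay_after_epoch=0, **kwargs):
        self.optim_class = optim_class
        self.kwargs = kwargs              # forwarded to the torch optimizer, e.g. lr
        self.max_grad_norm = max_grad_norm
        self.lr_decay = lr_decay
        self.decay_after_epoch = decay_after_epoch
        self.optimizer = None

    def set_parameters(self, params):
        # instantiate the wrapped torch optimizer over the model parameters
        self.optimizer = self.optim_class(params, **self.kwargs)

    def update(self, loss, epoch):
        # decay the learning rate once the epoch passes decay_after_epoch
        # (loss is unused in this sketch)
        if self.decay_after_epoch and epoch > self.decay_after_epoch:
            for group in self.optimizer.param_groups:
                group['lr'] *= self.lr_decay

    def step(self):
        # clip gradients before stepping when max_grad_norm is set
        if self.max_grad_norm > 0:
            params = [p for group in self.optimizer.param_groups
                      for p in group['params']]
            torch.nn.utils.clip_grad_norm_(params, self.max_grad_norm)
        self.optimizer.step()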