def testmake_optimizer_with_tuple(self):
    """Test make_optimizer function with tuple as first argument.

    The tuple has the form ``(optimizer_class, kwargs)``. The kwargs in the
    tuple are expected to reach the constructed optimizer, as checked via
    its ``defaults``.
    """
    optimizer_type = (torch.optim.Adam, {'lr': 0.1})
    module = torch.nn.Linear(2, 1)
    optimizer = torch_algo_utils.make_optimizer(optimizer_type, module)
    # Unpack the spec. Passing the whole tuple to isinstance() would treat
    # the kwargs dict as a classinfo entry. That only "works" because
    # isinstance short-circuits on the first matching type; a non-Adam
    # result would raise TypeError instead of failing the assertion.
    optimizer_cls, optimizer_kwargs = optimizer_type
    assert isinstance(optimizer, optimizer_cls)
    assert optimizer.defaults['lr'] == optimizer_kwargs['lr']
def testmake_optimizer_raise_value_error(self):
    """Check that make_optimizer rejects a tuple spec plus extra kwargs.

    Passing a ``(optimizer_class, kwargs)`` tuple together with an
    additional keyword argument (here ``lr``) is expected to raise
    ValueError.
    """
    spec = (torch.optim.Adam, {'lr': 0.1})
    net = torch.nn.Linear(2, 1)
    with pytest.raises(ValueError):
        torch_algo_utils.make_optimizer(spec, net, lr=0.123)
def testmake_optimizer_with_type(self):
    """Check make_optimizer when given a bare optimizer class.

    Keyword arguments passed alongside the class are expected to reach the
    constructed optimizer, as reflected in its ``defaults``.
    """
    learning_rate = 0.123
    net = torch.nn.Linear(2, 1)
    opt = torch_algo_utils.make_optimizer(torch.optim.Adam, net,
                                          lr=learning_rate)
    assert isinstance(opt, torch.optim.Adam)
    assert opt.defaults['lr'] == learning_rate