def test_register():
    """Register a custom optimizer class and retrieve it by name.

    Also checks that registering ``optimizers.Adam`` raises ``ValueError``
    (presumably because that name is already taken in the registry).
    """

    class Custom(optim.Optimizer):
        # Only the class object is registered and looked up; it is never
        # instantiated here. NOTE(review): ``super().__init__()`` is called
        # without ``params``/``defaults``, which the Optimizer base constructor
        # presumably requires, so instantiating ``Custom`` would likely fail.
        def __init__(self):
            super().__init__()

    # NOTE(review): the "Custom" registration is not undone at the end of the
    # test, so it presumably stays visible to later tests in the session.
    optimizers.register_optimizer(Custom)
    cls = optimizers.get("Custom")
    # `get` must hand back the very class that was registered, so compare
    # by identity rather than equality.
    assert cls is Custom

    with pytest.raises(ValueError):
        optimizers.register_optimizer(optimizers.Adam)
def test_get_none():
    """``optimizers.get(None)`` must return ``None`` rather than raising."""
    resolved = optimizers.get(None)
    assert resolved is None
def test_get_errors(wrong):
    """``optimizers.get`` raises ``ValueError`` for an invalid identifier.

    ``wrong`` is supplied by pytest (presumably through a ``parametrize``
    decorator outside this view). Its values are meant to be neither an
    ``Optimizer`` instance nor a known optimizer name.
    """
    with pytest.raises(ValueError):
        # Should raise for anything that is not an Optimizer instance, and
        # for unknown strings.
        optimizers.get(wrong)
def test_get_instance_returns_instance(opt):
    """Passing an already-built optimizer to ``get`` returns that same object.

    ``opt`` is supplied by pytest parametrization (decorator not visible here).
    It is called like a constructor, presumably as an optimizer class.
    """
    torch_optim = opt(global_model.parameters(), lr=1e-3)
    asteroid_optim = optimizers.get(torch_optim)
    # Identity, not mere equality: `get` must neither copy nor wrap the instance.
    assert asteroid_optim is torch_optim
def test_get_str_returns_instance(opt_tuple):
    """Looking an optimizer up by name matches constructing it directly.

    ``opt_tuple`` is supplied by pytest parametrization. Element 0 is called
    as a constructor and element 1 is passed to ``optimizers.get``, so it is
    presumably a ``(optimizer_class, name)`` pair. Both optimizers are built on
    ``global_model``'s parameters with ``lr=1e-3``, so their concrete types and
    ``param_groups`` must agree.
    """
    torch_optim = opt_tuple[0](global_model.parameters(), lr=1e-3)
    asteroid_optim = optimizers.get(opt_tuple[1])(global_model.parameters(), lr=1e-3)
    # Compare types by identity; `==` on types is discouraged (pycodestyle E721).
    assert type(torch_optim) is type(asteroid_optim)
    assert torch_optim.param_groups == asteroid_optim.param_groups
def test_all_get(opt):
    """Whatever ``optimizers.get(opt)`` returns can be constructed.

    The result is called with ``global_model``'s parameters and ``lr=1e-3``.
    The test passes as long as that call does not raise. ``opt`` is supplied
    by pytest parametrization (decorator not visible here).
    """
    optimizer_factory = optimizers.get(opt)
    optimizer_factory(global_model.parameters(), lr=1e-3)