def test_instantiate_adam(instantiate_func: Any, config: Any) -> None:
    """Check that ``config`` needs ``params`` and otherwise builds ``Adam``.

    Instantiating ``config`` (presumably an Adam target config supplied by a
    fixture; confirm against the fixture definition) without ``params`` must
    raise ``TypeError``. Once ``params`` is supplied, the result must compare
    equal to an ``Adam`` constructed directly with the same parameters.
    """
    # Adam cannot be constructed without its ``params`` argument.
    with raises(TypeError):
        instantiate_func(config)

    params = Parameters([1, 2, 3])
    assert instantiate_func(config, params=params) == Adam(params=params)
def test_instantiate_adam_conf_with_convert(instantiate_func: Any) -> None:
    """Instantiate ``AdamConf`` with ``_convert_="all"`` and compare to ``Adam``.

    Every hyperparameter of the instantiated object must match an ``Adam``
    built directly with the same ``lr`` and ``params``. ``betas`` must also
    come back as a plain ``list``.
    """
    shared_params = Parameters([1, 2, 3])
    built = instantiate_func(
        AdamConf(lr=0.123), params=shared_params, _convert_="all"
    )
    reference = Adam(lr=0.123, params=shared_params)

    assert built.params == reference.params
    assert built.lr == reference.lr
    # OmegaConf converts tuples to lists. With _convert_="all" the value
    # must be a real ``list``, so compare ``betas`` as lists on both sides.
    assert isinstance(built.betas, list)
    assert list(built.betas) == list(reference.betas)
    for attr_name in ("eps", "weight_decay", "amsgrad"):
        assert getattr(built, attr_name) == getattr(reference, attr_name)
def test_instantiate_adam_conf(instantiate_func: Any) -> None:
    """Check that ``AdamConf`` needs ``params`` and otherwise matches ``Adam``.

    A bare ``AdamConf()`` must fail with ``TypeError`` when instantiated
    without ``params``. With ``params`` supplied, the instantiated object's
    hyperparameters must equal those of an ``Adam`` built directly.
    """
    # Adam cannot be constructed without its ``params`` argument.
    with raises(TypeError):
        instantiate_func(AdamConf())

    shared_params = Parameters([1, 2, 3])
    built = instantiate_func(AdamConf(lr=0.123), params=shared_params)
    reference = Adam(lr=0.123, params=shared_params)

    assert built.params == reference.params
    assert built.lr == reference.lr
    # OmegaConf converts tuples to lists, so compare ``betas`` as lists.
    assert list(built.betas) == list(reference.betas)
    for attr_name in ("eps", "weight_decay", "amsgrad"):
        assert getattr(built, attr_name) == getattr(reference, attr_name)