def test_instantiate_adam_conf() -> None:
    """Instantiate ``tests.Adam`` from config and compare it with a direct construction.

    A bare ``{"_target_": "tests.Adam"}`` config must fail because ``params``
    is not supplied. Once ``params`` is passed, instantiating ``AdamConf`` must
    produce an object whose attributes match an ``Adam`` built directly with
    the same arguments.
    """
    with pytest.raises(Exception):
        # can't instantiate without passing params
        utils.instantiate({"_target_": "tests.Adam"})

    adam_params = Parameters([1, 2, 3])
    res = utils.instantiate(AdamConf(lr=0.123), params=adam_params)
    expected = Adam(lr=0.123, params=adam_params)
    # Compare attribute by attribute instead of ``res == expected``: the config
    # round-trip through OmegaConf turns the tuple-valued ``betas`` into a
    # list, so whole-object equality is not a reliable check.
    assert res.params == expected.params
    assert res.lr == expected.lr
    assert list(res.betas) == list(expected.betas)  # OmegaConf converts tuples to lists
    assert res.eps == expected.eps
    assert res.weight_decay == expected.weight_decay
    assert res.amsgrad == expected.amsgrad
def test_instantiate_adam() -> None:
    """Check that an ``ObjectConf`` targeting ``tests.Adam`` needs ``params``.

    Instantiating without ``params`` must raise. With ``params`` supplied, the
    result must equal an ``Adam`` constructed directly from the same ``params``.
    """

    def adam_conf() -> ObjectConf:
        # A fresh config object for every instantiate call.
        return ObjectConf(target="tests.Adam")

    with pytest.raises(Exception):
        utils.instantiate(adam_conf())  # params deliberately omitted

    params = Parameters([1, 2, 3])
    assert utils.instantiate(adam_conf(), params=params) == Adam(params=params)
def test_instantiate_adam(instantiate_func: Any, config: Any) -> None:
    """Check that ``config`` can only be instantiated once ``params`` is provided.

    ``instantiate_func`` and ``config`` are injected by the test harness
    (presumably pytest parametrization or fixtures — not visible here).
    Calling without ``params`` must raise. Calling with ``params`` must return
    an object equal to ``Adam(params=...)``.
    """
    with raises(Exception):
        instantiate_func(config)  # missing params: expected to fail

    params = Parameters([1, 2, 3])
    assert instantiate_func(config, params=params) == Adam(params=params)
def test_instantiate_adam_objectconf() -> None:
    """Check ``ObjectConf``-based instantiation of ``tests.Adam`` and its warning.

    Using ``ObjectConf`` is expected to emit a ``UserWarning`` matching
    ``objectconf_depreacted``. Without ``params``, instantiation must raise.
    With ``params``, the result must equal ``Adam(params=...)``.

    Both instantiate calls run inside the ``pytest.warns`` block. The second
    ``ObjectConf`` use presumably triggers the same warning, and it should be
    captured here rather than escape the test as an unexpected warning.
    """
    with pytest.warns(expected_warning=UserWarning, match=objectconf_depreacted):
        with pytest.raises(Exception):
            # can't instantiate without passing params
            utils.instantiate(ObjectConf(target="tests.Adam"))

        adam_params = Parameters([1, 2, 3])
        res = utils.instantiate(ObjectConf(target="tests.Adam"), params=adam_params)
        assert res == Adam(params=adam_params)
def test_instantiate_adam_conf() -> None:
    """Instantiate ``AdamConf`` and compare the result with ``Adam`` attribute by attribute.

    A bare ``{"_target_": "tests.Adam"}`` config must fail because ``params``
    is missing. ``betas`` is compared after converting both sides to ``list``
    because OmegaConf converts tuples to lists.
    """
    with pytest.raises(Exception):
        utils.instantiate({"_target_": "tests.Adam"})  # no params: must fail

    params = Parameters([1, 2, 3])
    res = utils.instantiate(AdamConf(lr=0.123), params=params)
    expected = Adam(lr=0.123, params=params)

    for attr in ("params", "lr"):
        assert getattr(res, attr) == getattr(expected, attr), attr
    assert list(res.betas) == list(expected.betas)  # OmegaConf converts tuples to lists
    for attr in ("eps", "weight_decay", "amsgrad"):
        assert getattr(res, attr) == getattr(expected, attr), attr
def test_instantiate_adam_conf_with_convert(instantiate_func: Any) -> None:
    """Instantiate ``AdamConf`` with ``_convert_="all"`` and check the result.

    ``betas`` must come back as a real ``list`` (OmegaConf converts tuples to
    lists). Every other checked attribute must match an ``Adam`` constructed
    directly with the same arguments.
    """
    params = Parameters([1, 2, 3])
    res = instantiate_func(AdamConf(lr=0.123), params=params, _convert_="all")
    expected = Adam(lr=0.123, params=params)

    for attr in ("params", "lr"):
        assert getattr(res, attr) == getattr(expected, attr), attr
    assert isinstance(res.betas, list)
    assert list(res.betas) == list(expected.betas)  # OmegaConf converts tuples to lists
    for attr in ("eps", "weight_decay", "amsgrad"):
        assert getattr(res, attr) == getattr(expected, attr), attr