예제 #1
0
def test_instantiate_adam() -> None:
    """``tests.Adam`` can only be instantiated when ``params`` is supplied.

    Without ``params`` instantiation must raise; with them the result must
    compare equal to an ``Adam`` built directly from the same parameters.
    """
    adam_target = "tests.Adam"

    # can't instantiate without passing params
    with pytest.raises(Exception):
        utils.instantiate(ObjectConf(target=adam_target))

    adam_params = Parameters([1, 2, 3])
    instantiated = utils.instantiate(ObjectConf(target=adam_target), params=adam_params)
    assert instantiated == Adam(params=adam_params)
예제 #2
0
파일: test_utils.py 프로젝트: zivzone/hydra
def test_class_instantiate_pass_omegaconf_node() -> None:
    """Instantiate ``tests.test_utils.Bar`` from a structured config with a nested node.

    The ``c`` parameter is a nested mapping whose ``y`` entry is the
    interpolation ``${params.b}``. It must resolve to 200 in the result, and
    ``obj.c`` must still be an OmegaConf config object after instantiation.
    Extra keyword arguments (``a`` and ``d``) are passed to ``instantiate``.
    """
    pc = ObjectConf()
    # This is a bit clunky because it exposes a problem with the backport of dataclass on Python 3.6
    # see: https://github.com/ericvsmith/dataclasses/issues/155
    pc.cls = "tests.test_utils.Bar"
    pc.params = {"b": 200, "c": {"x": 10, "y": "${params.b}"}}
    conf = OmegaConf.structured(pc)
    obj = utils.instantiate(conf, a=10, d=Foo(99))
    assert obj == Bar(10, 200, {"x": 10, "y": 200}, Foo(99))
    assert OmegaConf.is_config(obj.c)
예제 #3
0
파일: test_utils.py 프로젝트: vporta/hydra
def test_instantiate_adam_objectconf() -> None:
    """Instantiate ``tests.Adam`` via ``ObjectConf`` and expect the deprecation warning.

    Inside the warning check, instantiation without ``params`` must raise.
    With ``params`` the result must equal a directly built ``Adam``.
    """
    adam_target = "tests.Adam"
    with pytest.warns(expected_warning=UserWarning, match=objectconf_depreacted):
        # can't instantiate without passing params
        with pytest.raises(Exception):
            utils.instantiate(ObjectConf(target=adam_target))

        adam_params = Parameters([1, 2, 3])
        instantiated = utils.instantiate(
            ObjectConf(target=adam_target), params=adam_params
        )
        assert instantiated == Adam(params=adam_params)
예제 #4
0
def test_instantiate_with_missing_module() -> None:
    """Instantiating ``tests.ClassWithMissingModule`` raises ``HydraException``.

    The exception message must contain " No module named 'some_missing_module'".
    The text is escaped so it is matched literally.
    """
    expected_message = re.escape(" No module named 'some_missing_module'")
    with pytest.raises(HydraException, match=expected_message):
        # can't instantiate when importing a missing module
        utils.instantiate(ObjectConf(target="tests.ClassWithMissingModule"))
예제 #5
0
파일: test_utils.py 프로젝트: vporta/hydra
def test_object_conf_deprecated() -> None:
    """Using ``ObjectConf(target=...)`` still instantiates, but warns twice.

    The first recorded warning is the ObjectConf deprecation message. The second
    is the ``target`` field deprecation message.
    """
    with pytest.warns(UserWarning) as caught:
        conf = ObjectConf(target="tests.AClass", params={"a": 10, "b": 20})
        instance = utils.instantiate(conf, c=30)
    assert instance == AClass(a=10, b=20, c=30)

    assert caught[0].message.args[0] == objectconf_depreacted
    assert caught[1].message.args[0] == target_field_deprecated.format(
        field="target"
    )
예제 #6
0
def test_object_conf_deprecated() -> None:
    """Constructing an ``ObjectConf`` emits a UserWarning with the exact deprecation text."""
    import re

    msg = (
        "\nObjectConf is deprecated in favor of TargetConf since Hydra 1.0.0rc3 and will be removed in Hydra 1.1."
        "\nSee https://hydra.cc/docs/next/upgrades/0.11_to_1.0/object_instantiation_changes"
    )

    # `match` is treated as a regular expression and searched against the
    # warning text. Escape the literal message so that the '.' characters in
    # the version numbers and URL only match a literal dot.
    with pytest.warns(expected_warning=UserWarning, match=re.escape(msg)):
        ObjectConf(target="foo")
예제 #7
0
def test_class_instantiate_pass_omegaconf_node() -> None:
    """Instantiate ``tests.test_utils.Bar`` from a structured ``ObjectConf``.

    The ``c`` parameter is a nested mapping whose ``y`` entry is the
    interpolation ``${params.b}``. It must resolve to 200 in the result, and
    ``obj.c`` must still be an OmegaConf config object after instantiation.
    Extra keyword arguments (``a`` and ``d``) are passed to ``instantiate``.
    """
    conf = OmegaConf.structured(
        ObjectConf(
            target="tests.test_utils.Bar",
            params={"b": 200, "c": {"x": 10, "y": "${params.b}"}},
        )
    )
    obj = utils.instantiate(conf, a=10, d=Foo(99))
    assert obj == Bar(10, 200, {"x": 10, "y": 200}, Foo(99))
    assert OmegaConf.is_config(obj.c)
예제 #8
0
파일: test_utils.py 프로젝트: vporta/hydra
def test_class_instantiate_objectconf_pass_omegaconf_node() -> None:
    """Instantiate ``tests.AClass`` from a structured ``ObjectConf`` and check the warnings.

    The nested ``c`` node's ``${params.b}`` interpolation must resolve to 200.
    ``obj.c`` must remain an OmegaConf config object. The first two recorded
    warnings must be the ObjectConf deprecation message followed by the
    ``target`` field deprecation message.
    """
    with pytest.warns(expected_warning=UserWarning) as recwarn:
        conf = OmegaConf.structured(
            ObjectConf(
                target="tests.AClass",
                params={
                    "b": 200,
                    "c": {
                        "x": 10,
                        "y": "${params.b}"
                    }
                },
            ))
        obj = utils.instantiate(conf, a=10, d=AnotherClass(99))
    assert obj == AClass(10, 200, {"x": 10, "y": 200}, AnotherClass(99))
    assert OmegaConf.is_config(obj.c)

    assert recwarn[0].message.args[0] == objectconf_depreacted
    assert recwarn[1].message.args[0] == target_field_deprecated.format(
        field="target")
예제 #9
0
파일: my_app.py 프로젝트: austinserif/hydra
    {"db": "mysql"}
]


@dataclass
class Config(DictConfig):
    # Schema for the primary config, stored below under the name "config".
    # NOTE(review): subclassing DictConfig from a dataclass is unusual --
    # presumably so that `cfg` can be annotated with this type; confirm.

    # Defaults list for composition. dataclasses reject a list as a plain
    # default value, so a factory is used. The lambda returns the same
    # module-level `defaults` list object on every call (not a copy).
    defaults: List[Any] = field(default_factory=lambda: defaults)
    # No default here (MISSING). The value comes from the `db` config group,
    # selected through the `defaults` list.
    db: ObjectConf = MISSING


# Register the primary config schema and the two `db` options with the
# ConfigStore instance so they can be composed by name.
cs = ConfigStore.instance()
# Primary config; its name matches config_name="config" in @hydra.main.
cs.store(name="config", node=Config)
# Option `db=mysql`: an ObjectConf targeting my_app.MySQLConnection, with
# MySQLConfig as its params.
cs.store(
    group="db",
    name="mysql",
    node=ObjectConf(target="my_app.MySQLConnection", params=MySQLConfig),
)
# Option `db=postgresql`: an ObjectConf targeting my_app.PostgreSQLConnection,
# with PostGreSQLConfig as its params.
cs.store(
    group="db",
    name="postgresql",
    node=ObjectConf(target="my_app.PostgreSQLConnection", params=PostGreSQLConfig),
)


@hydra.main(config_name="config")
def my_app(cfg: DictConfig) -> None:
    """Instantiate the object described by ``cfg.db`` and call its ``connect()``."""
    instantiate(cfg.db).connect()


if __name__ == "__main__":
예제 #10
0
def test_pass_extra_variables() -> None:
    """Keyword arguments passed to ``instantiate`` are combined with the config's params."""
    conf = ObjectConf(target="tests.AClass", params={"a": 10, "b": 20})
    created = utils.instantiate(conf, c=30)
    assert created == AClass(a=10, b=20, c=30)
예제 #11
0
def test_cls() -> None:
    """``_get_cls_name`` returns ``cls`` alone, but ``target`` when both are set.

    Each case is expected to emit a UserWarning.
    """
    with pytest.warns(expected_warning=UserWarning):
        only_cls = ObjectConf(cls="foo")
        assert utils._get_cls_name(only_cls) == "foo"
    with pytest.warns(expected_warning=UserWarning):
        cls_and_target = ObjectConf(cls="foo", target="bar")
        assert utils._get_cls_name(cls_and_target) == "bar"
예제 #12
0
        ([["a"], ["aa", "b"]], [2, 1]),
        ([["a", "aa"], ["bb"]], [2, 2]),
        ([["a"]], [1]),
        ([["a"]], [1]),
        ([["a"]], [1]),
    ],
)
def test_get_column_widths(matrix: Any, expected: Any) -> None:
    """``utils.get_column_widths`` returns the expected per-column widths for ``matrix``."""
    widths = utils.get_column_widths(matrix)
    assert widths == expected


@pytest.mark.parametrize(  # type: ignore
    "config, expected, warning",
    [
        pytest.param(
            ObjectConf(target="foo"), "foo", False, id="ObjectConf:target"),
        pytest.param(OmegaConf.create({"cls": "foo"}),
                     "foo",
                     "cls",
                     id="DictConfig:cls"),
        pytest.param(OmegaConf.create({"class": "foo"}),
                     "foo",
                     "class",
                     id="DictConfig:class"),
        pytest.param(
            OmegaConf.create({"target": "foo"}),
            "foo",
            False,
            id="DictConfig:target",
        ),
        pytest.param(
예제 #13
0
파일: config.py 프로젝트: sawravchy/hydra
    # check the following for more info on slurm_max_num_timeout
    # https://github.com/facebookincubator/submitit/blob/master/docs/checkpointing.md
    max_num_timeout: int = 0


@dataclass
class LocalParams(BaseParams):
    """Params for the ``submitit_local`` launcher; adds no fields beyond BaseParams."""


# finally, register two different choices:
# both go into the `hydra/launcher` group under provider "submitit_launcher".
# Each node is an ObjectConf whose `target` names the launcher class and whose
# `params` is the matching params dataclass instance.

# `hydra/launcher=submitit_local` -> LocalLauncher with default LocalParams.
ConfigStore.instance().store(
    group="hydra/launcher",
    name="submitit_local",
    node=ObjectConf(
        target="hydra_plugins.hydra_submitit_launcher.submitit_launcher.LocalLauncher",
        params=LocalParams(),
    ),
    provider="submitit_launcher",
)


# `hydra/launcher=submitit_slurm` -> SlurmLauncher with default SlurmParams.
ConfigStore.instance().store(
    group="hydra/launcher",
    name="submitit_slurm",
    node=ObjectConf(
        target="hydra_plugins.hydra_submitit_launcher.submitit_launcher.SlurmLauncher",
        params=SlurmParams(),
    ),
    provider="submitit_launcher",
)
예제 #14
0
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from dataclasses import dataclass
from typing import Optional

from hydra.core.config_store import ConfigStore
from hydra.types import ObjectConf

# Register a single launcher choice, `hydra/launcher=tsp`, provided by the
# "tsp_launcher" plugin and pointing at TaskSpoolerLauncher with empty params.
# NOTE(review): this uses ObjectConf's `cls` field. In Hydra versions whose
# ObjectConf also has `target`, `cls` is deprecated and emits a UserWarning.
# Consider switching to `target=` once the supported Hydra version is confirmed.
ConfigStore.instance().store(
    group="hydra/launcher",
    name="tsp",
    node=ObjectConf(
        cls="hydra_plugins.hydra_tsp_launcher.tsp_launcher.TaskSpoolerLauncher",
        params={},
    ),
    provider="tsp_launcher",
)
예제 #15
0
        ([["a", "bb"], ["aa", "b"]], [2, 2]),
        ([["a"], ["aa", "b"]], [2, 1]),
        ([["a", "aa"], ["bb"]], [2, 2]),
        ([["a"]], [1]),
        ([["a"]], [1]),
        ([["a"]], [1]),
    ],
)
def test_get_column_widths(matrix: Any, expected: Any) -> None:
    """``utils.get_column_widths`` returns the expected per-column widths for ``matrix``."""
    widths = utils.get_column_widths(matrix)
    assert widths == expected


@pytest.mark.parametrize(  # type: ignore
    "config, expected, warning",
    [
        pytest.param(ObjectConf(target="foo"), "foo", False, id="ObjectConf:target"),
        pytest.param(
            OmegaConf.create({"cls": "foo"}), "foo", "cls", id="DictConfig:cls"
        ),
        pytest.param(
            OmegaConf.create({"class": "foo"}), "foo", "class", id="DictConfig:class"
        ),
        pytest.param(
            OmegaConf.create({"target": "foo"}), "foo", False, id="DictConfig:target",
        ),
        pytest.param(
            OmegaConf.create({"cls": "foo", "target": "bar"}),
            "bar",
            False,
            id="DictConfig:cls_target",
        ),