Beispiel #1
0
    def test_add_to_structured_config(self,
                                      hydra_restore_singletons: Any) -> None:
        """Compose a structured config stored under the ``nested`` package
        with ``+`` and ``++`` overrides and compare each result to a dict."""

        @dataclass
        class Config:
            a: int = 10

        store = ConfigStore.instance()
        store.store(name="config", node=Config, package="nested")

        # "+" with a key the dataclass does not declare.
        appended = compose("config", overrides=["+nested.b=20"])
        assert appended == {"nested": {"a": 10, "b": 20}}

        # "++" on a declared key and on an undeclared key in the same call.
        forced = compose("config",
                         overrides=["++nested.a=30", "++nested.b=20"])
        assert forced == {"nested": {"a": 30, "b": 20}}

        # "+" with a dotted path: the expected result nests "c" under "b".
        deep = compose("config", overrides=["+nested.b.c=20"])
        assert deep == {"nested": {"a": 10, "b": {"c": 20}}}
Beispiel #2
0
    def test_force_add(self) -> None:
        """Check ``++`` overrides against a stored ``{"key": 0}`` config."""
        ConfigStore.instance().store(name="config", node={"key": 0})

        # "++" on a key that is already present.
        replaced = compose(config_name="config", overrides=["++key=1"])
        assert replaced == {"key": 1}

        # "++" on a key that is not present yet.
        added = compose(config_name="config", overrides=["++key2=1"])
        assert added == {"key": 0, "key2": 1}
Beispiel #3
0
def test_load_schema_as_config(restore_singletons: Any) -> None:
    """
    Load a structured config stored in the ConfigStore as a configuration
    """
    store = ConfigStore.instance()
    store.store(group="db",
                name="mysql",
                node=MySQLConfig,
                provider="test_provider")

    search_path = create_config_search_path(None)
    config_loader = ConfigLoaderImpl(config_search_path=search_path)
    cfg = config_loader.load_configuration(config_name="db/mysql",
                                           overrides=[])

    # Remove the "hydra" node so only the remaining content is compared.
    with open_dict(cfg):
        del cfg["hydra"]
    db_keys = ("driver", "host", "port", "user", "password")
    assert cfg == {"db": {key: MISSING for key in db_keys}}

    # The load history is the shared baseline plus one entry for this config.
    expected = hydra_load_list.copy()
    expected.append(
        LoadTrace("db/mysql", "structured://", "test_provider", None))
    assert config_loader.get_load_history() == expected
def test_load_schema_as_config(hydra_restore_singletons: Any) -> None:
    """
    Load a structured config stored in the ConfigStore as a configuration
    """
    store = ConfigStore.instance()
    store.store(name="config", node=TopLevelConfig, provider="this_test")
    store.store(name="db/mysql", node=MySQLConfig, provider="this_test")

    search_path = create_config_search_path(None)
    config_loader = ConfigLoaderImpl(config_search_path=search_path)
    cfg = config_loader.load_configuration(
        config_name="config",
        overrides=[],
        run_mode=RunMode.RUN,
    )

    # Remove the "hydra" node so only the remaining content is compared.
    with open_dict(cfg):
        del cfg["hydra"]
    db_keys = ("driver", "host", "port", "user", "password")
    expected_cfg = {
        "normal_yaml_config": "???",
        "db": {key: MISSING for key in db_keys},
    }
    assert cfg == expected_cfg

    # The load history is the shared baseline plus one entry for "config".
    expected = hydra_load_list.copy()
    expected.append(LoadTrace("config", "structured://", "this_test", None))
    assert config_loader.get_load_history() == expected
Beispiel #5
0
    def test_load_config_with_schema(self, hydra_restore_singletons: Any,
                                     path: str) -> None:
        """Load ``config`` with ``+db=mysql`` after storing two schemas, then
        check the composition trace, the loaded values and type validation."""
        store = ConfigStore.instance()
        store.store(name="config", node=TopLevelConfig, provider="this_test")
        store.store(group="db",
                    name="mysql",
                    node=MySQLConfig,
                    provider="this_test")

        search_path = create_config_search_path(path)
        config_loader = ConfigLoaderImpl(config_search_path=search_path)
        cfg = config_loader.load_configuration(
            config_name="config",
            overrides=["+db=mysql"],
            run_mode=RunMode.RUN,
        )

        # Expected trace: the shared baseline, then the primary config, then
        # the db config whose parent is the overrides list.
        root_trace = LoadTrace(
            config_name="config",
            search_path=path,
            provider="main",
            schema_provider="this_test",
        )
        db_trace = LoadTrace(
            config_group="db",
            config_name="mysql",
            search_path=path,
            provider="main",
            schema_provider="this_test",
            parent="overrides",
        )
        expected = deepcopy(hydra_load_list)
        expected.extend([root_trace, db_trace])
        assert_same_composition_trace(cfg.hydra.composition_trace, expected)

        with open_dict(cfg):
            del cfg["hydra"]
        expected_cfg = {
            "normal_yaml_config": True,
            "db": {
                "driver": "mysql",
                "host": "???",
                "port": "???",
                "user": "******",
                "password": "******",
            },
        }
        assert cfg == expected_cfg

        # verify illegal modification is rejected at runtime
        with pytest.raises(ValidationError):
            cfg.db.port = "fail"

        # verify illegal override is rejected during load
        with pytest.raises(HydraException):
            config_loader.load_configuration(
                config_name="db/mysql",
                overrides=["db.port=fail"],
                run_mode=RunMode.RUN,
            )
Beispiel #6
0
    def test_force_add(self, hydra_restore_singletons: Any) -> None:
        """Check ``++`` overrides inside an ``initialize()`` context against a
        stored ``{"key": 0}`` config."""
        store = ConfigStore.instance()
        store.store(name="config", node={"key": 0})

        with initialize():
            # "++" on a key that is already present.
            replaced = compose(config_name="config", overrides=["++key=1"])
            assert replaced == {"key": 1}

            # "++" on a key that is not present yet.
            added = compose(config_name="config", overrides=["++key2=1"])
            assert added == {"key": 0, "key2": 1}
Beispiel #7
0
    def test_add(self) -> None:
        """``+key`` on an existing key must raise; plain ``key=1`` must work."""
        ConfigStore.instance().store(name="config", node={"key": 0})

        # The pattern is handed to ``match`` unescaped, as a regular expression.
        expected_error = "Could not append to config. An item is already at 'key'"
        with raises(ConfigCompositionException, match=expected_error):
            compose(config_name="config", overrides=["+key=value"])

        overridden = compose(config_name="config", overrides=["key=1"])
        assert overridden == {"key": 1}
Beispiel #8
0
def test_adding_to_sc_dict(hydra_restore_singletons: Any, overrides: List[str],
                           expected: Any) -> None:
    """Compose a structured config holding a ``Dict[str, str]`` field with the
    given overrides and compare the result to ``expected``."""

    @dataclass
    class Config:
        map: Dict[str, str] = field(default_factory=dict)

    store = ConfigStore.instance()
    store.store(name="config", node=Config)

    composed = compose(config_name="config", overrides=overrides)
    assert composed == expected
Beispiel #9
0
def register_configs() -> None:
    """Store ``DynamicsConfig`` as ``dynamics`` in the ``dynamics`` group."""
    ConfigStore.instance().store(group="dynamics",
                                 name="dynamics",
                                 node=DynamicsConfig)
Beispiel #10
0
    def register_model_cls(cls):
        """Register ``cls`` under ``name`` and publish its dataclass, if any.

        All checks run before any registry is modified, so a rejected
        registration leaves ``MODEL_REGISTRY`` untouched.

        Args:
            cls: the model class to register.

        Returns:
            ``cls`` itself.

        Raises:
            ValueError: if ``name`` is already in ``MODEL_REGISTRY``, if
                ``cls`` does not subclass ``BaseFairseqModel``, or if
                ``dataclass`` is given but does not subclass
                ``FairseqDataclass``.
        """
        if name in MODEL_REGISTRY:
            raise ValueError(f"Cannot register duplicate model ({name})")
        if not issubclass(cls, BaseFairseqModel):
            raise ValueError(
                f"Model ({name}: {cls.__name__}) must extend BaseFairseqModel")
        # Validate the dataclass before touching MODEL_REGISTRY so a failure
        # here cannot leave a half-registered model behind.
        if dataclass is not None and not issubclass(dataclass,
                                                    FairseqDataclass):
            raise ValueError(
                f"Dataclass {dataclass} must extend FairseqDataclass")

        MODEL_REGISTRY[name] = cls
        cls.__dataclass = dataclass
        if dataclass is not None:
            MODEL_DATACLASS_REGISTRY[name] = dataclass

            # Store a default instance of the dataclass in the "model" group.
            cs = ConfigStore.instance()
            node = dataclass()
            node._name = name
            cs.store(name=name, group="model", node=node, provider="fairseq")

            # Register an architecture with the same name as the model; the
            # decorated function does nothing.
            @register_model_architecture(name, name)
            def noop(_):
                pass

        return cls
Beispiel #11
0
def test_overlapping_schemas(hydra_restore_singletons: Any) -> None:
    """Load ``config`` with and without ``+plugin=concrete`` and check both
    the values and the object types reported by OmegaConf."""
    store = ConfigStore.instance()
    store.store(name="config", node=Config)
    store.store(group="plugin", name="concrete", node=ConcretePlugin)

    search_path = create_config_search_path(None)
    config_loader = ConfigLoaderImpl(config_search_path=search_path)

    def load_without_hydra(overrides):
        # Load "config" and remove the "hydra" node from the result.
        loaded = config_loader.load_configuration(config_name="config",
                                                  overrides=overrides,
                                                  run_mode=RunMode.RUN)
        with open_dict(loaded):
            del loaded["hydra"]
        return loaded

    cfg = load_without_hydra([])
    assert cfg == {"plugin": {"name": "???", "params": "???"}}
    assert OmegaConf.get_type(cfg.plugin) == Plugin

    cfg = load_without_hydra(["+plugin=concrete"])
    assert cfg == {"plugin": {"name": "foobar_plugin", "params": {"foo": 10}}}
    assert OmegaConf.get_type(cfg.plugin) == ConcretePlugin
    assert OmegaConf.get_type(cfg.plugin.params) == ConcretePlugin.FoobarParams

    # Assigning an int to the plugin node must be rejected.
    with pytest.raises(ValidationError):
        cfg.plugin = 10
Beispiel #12
0
        def register_x_cls(cls):
            """Validate ``cls`` and ``dataclass``, then add ``cls`` to ``REGISTRY``.

            Raises ``ValueError`` when ``name`` is already registered, when the
            class name is in ``REGISTRY_CLASS_NAMES``, when ``cls`` does not
            extend ``base_class`` (if one is set), or when ``dataclass`` is
            given but does not extend ``FairseqDataclass``. Returns ``cls``.
            """
            # NOTE(review): nothing in this function adds cls.__name__ to
            # REGISTRY_CLASS_NAMES - confirm it is populated elsewhere,
            # otherwise the duplicate-class-name check cannot trigger.
            class_name = cls.__name__
            if name in REGISTRY:
                raise ValueError(
                    f"Cannot register duplicate {registry_name} ({name})")
            if class_name in REGISTRY_CLASS_NAMES:
                raise ValueError(
                    f"Cannot register {registry_name} with duplicate class name ({class_name})"
                )
            if base_class is not None and not issubclass(cls, base_class):
                raise ValueError(
                    f"{class_name} must extend {base_class.__name__}")

            has_dataclass = dataclass is not None
            if has_dataclass and not issubclass(dataclass, FairseqDataclass):
                raise ValueError(
                    f"Dataclass {dataclass} must extend FairseqDataclass")

            cls.__dataclass = dataclass
            if cls.__dataclass is not None:
                DATACLASS_REGISTRY[name] = cls.__dataclass

                # Store a default instance of the dataclass under this
                # registry's group.
                default_node = dataclass()
                default_node._name = name
                ConfigStore.instance().store(name=name,
                                             group=registry_name,
                                             node=default_node,
                                             provider="fairseq")

            # The class itself is registered last, after every check passed.
            REGISTRY[name] = cls

            return cls
Beispiel #13
0
    def register_task_cls(cls):
        """Register ``cls`` under ``name`` and publish its dataclass, if any.

        All checks run before any registry is modified, so a rejected
        registration leaves ``TASK_REGISTRY`` and ``TASK_CLASS_NAMES``
        untouched.

        Args:
            cls: the task class to register.

        Returns:
            ``cls`` itself.

        Raises:
            ValueError: if ``name`` is already in ``TASK_REGISTRY``, if ``cls``
                does not subclass ``FairseqTask``, if the class name is already
                in ``TASK_CLASS_NAMES``, or if ``dataclass`` is given but does
                not subclass ``FairseqDataclass``.
        """
        if name in TASK_REGISTRY:
            raise ValueError(f"Cannot register duplicate task ({name})")
        if not issubclass(cls, FairseqTask):
            raise ValueError(
                f"Task ({name}: {cls.__name__}) must extend FairseqTask")
        if cls.__name__ in TASK_CLASS_NAMES:
            raise ValueError(
                f"Cannot register task with duplicate class name ({cls.__name__})"
            )
        # Validate the dataclass before touching the registries so a failure
        # here cannot leave a half-registered task behind.
        if dataclass is not None and not issubclass(dataclass,
                                                    FairseqDataclass):
            raise ValueError(
                f"Dataclass {dataclass} must extend FairseqDataclass")

        TASK_REGISTRY[name] = cls
        TASK_CLASS_NAMES.add(cls.__name__)

        cls.__dataclass = dataclass
        if dataclass is not None:
            TASK_DATACLASS_REGISTRY[name] = dataclass

            # Store a default instance of the dataclass in the "task" group.
            cs = ConfigStore.instance()
            node = dataclass()
            node._name = name
            cs.store(name=name, group="task", node=node, provider="fairseq")

        return cls
Beispiel #14
0
def register_train_config() -> None:
    """Store an RNN ``ModelArgs`` instance as ``rnn`` in ``train/model``."""
    rnn_node = ModelArgs(model_type=EModelType.RNN, model_args=ModelRNNArgs())
    ConfigStore.instance().store(group="train/model",
                                 name="rnn",
                                 node=rnn_node)
Beispiel #15
0
def test_invalid_plugin_merge(restore_singletons: Any) -> Any:  # noqa: F811
    """Loading ``config`` with ``plugin=invalid`` must raise ``ValidationError``."""
    store = ConfigStore.instance()
    store.store(name="config", node=Config)
    store.store(group="plugin",
                name="invalid",
                node=InvalidPlugin,
                path="plugin")

    search_path = create_config_search_path(None)
    loader = ConfigLoaderImpl(config_search_path=search_path)
    with pytest.raises(ValidationError):
        loader.load_configuration(config_name="config",
                                  overrides=["plugin=invalid"])
Beispiel #16
0
def init_config_store():
    """Store the root config, the two connection configs and one node per
    entry of ``Task.registered_tasks`` (in the ``task`` group)."""
    store = ConfigStore.instance()
    store.store(name="config", node=Config)
    store.store(group="aquarium", name="default", node=AquariumConnection)
    store.store(group="neo", name="default", node=NeoConnetion)
    # cs.store(group='job', name='default', node=Job)
    for task_name, task_node in Task.registered_tasks.items():
        logger.info(f"Registering task {task_name} ({task_node.__name__})")
        store.store(group="task", name=task_name, node=task_node)
def test_invalid_plugin_merge(hydra_restore_singletons: Any) -> Any:
    """Loading ``config`` with ``plugin=invalid`` must raise ``HydraException``."""
    store = ConfigStore.instance()
    store.store(name="config", node=Config)
    store.store(group="plugin", name="invalid", node=InvalidPlugin)

    search_path = create_config_search_path(None)
    loader = ConfigLoaderImpl(config_search_path=search_path)
    with pytest.raises(HydraException):
        loader.load_configuration(config_name="config",
                                  overrides=["plugin=invalid"])
Beispiel #18
0
def get_config_store(name="config", node=Config):
    """Store ``node`` under ``name`` and return the ``ConfigStore`` instance."""
    store = ConfigStore.instance()
    # Disabled group registrations, kept for reference:
    #     for name, node in get_datasets().items():
    #         cs.store(group="data", name=name, node=node)
    #     for name, node in get_optimizers().items():
    #         cs.store(group="train.optimizer", name=name, node=node)
    #     for name, node in get_lr_schedulers().items():
    #         cs.store(group="train.lr_scheduler", name=name, node=node)
    store.store(name=name, node=node)
    return store
Beispiel #19
0
    def test_add_config_group(self) -> None:
        """Exercise plain, ``+`` and ``++`` overrides of the stored ``group``
        config group."""
        store = ConfigStore.instance()
        for index in (0, 1):
            store.store(group="group", name=f"a{index}", node={"key": index})

        # overriding non existing group throws
        with raises(ConfigCompositionException):
            compose(overrides=["group=a0"])

        # appending a new group
        appended = compose(overrides=["+group=a0"])
        assert appended == {"group": {"key": 0}}

        # force adding is not supported for config groups.
        expected_error = re.escape(
            "force-add of config groups is not supported: '++group=a1'")
        with raises(ConfigCompositionException, match=expected_error):
            compose(overrides=["++group=a1"])
    def test_load_config_with_schema(
        self, hydra_restore_singletons: Any, path: str
    ) -> None:
        """Load ``config`` with ``+db=validated_mysql`` after storing two
        schemas, then check the loaded values and type validation."""
        store = ConfigStore.instance()
        store.store(
            name="config_with_schema", node=TopLevelConfig, provider="this_test"
        )
        store.store(
            group="db", name="base_mysql", node=MySQLConfig, provider="this_test"
        )

        search_path = create_config_search_path(path)
        config_loader = ConfigLoaderImpl(config_search_path=search_path)
        cfg = config_loader.load_configuration(
            config_name="config",
            overrides=["+db=validated_mysql"],
            run_mode=RunMode.RUN,
        )

        # Remove the "hydra" node so only the remaining content is compared.
        with open_dict(cfg):
            del cfg["hydra"]
        expected_cfg = {
            "normal_yaml_config": True,
            "db": {
                "driver": "mysql",
                "host": "???",
                "port": "???",
                "user": "******",
                "password": "******",
            },
        }
        assert cfg == expected_cfg

        # verify illegal modification is rejected at runtime
        with raises(ValidationError):
            cfg.db.port = "fail"

        # verify illegal override is rejected during load
        with raises(HydraException):
            config_loader.load_configuration(
                config_name="db/mysql",
                overrides=["db.port=fail"],
                run_mode=RunMode.RUN,
            )
Beispiel #21
0
def before_hydra(config_class):
    """Store the root config, every model config and the optimizer configs.

    Args:
        config_class: node stored under the name ``config``.

    Returns:
        The ``ConfigStore`` instance the nodes were stored in.
    """
    cs = ConfigStore.instance()
    cs.store(name='config', node=config_class)
    # A plain loop: the stores are side effects, so no result list is built.
    for model_name, model_cfg in model_list:
        cs.store(group='train.model', name=model_name, node=model_cfg)
    cs.store(group='train.model.optim', name='sgd', node=SGDConfig)
    cs.store(group='train.model.optim', name='adam', node=AdamConfig)
    return cs
Beispiel #22
0
    def test_load_config_with_schema(
            self,
            restore_singletons: Any,
            path: str  # noqa: F811
    ) -> None:
        """Load ``db/mysql`` after storing ``MySQLConfig`` for it, then check
        the loaded values, the load history and type validation."""
        store = ConfigStore.instance()
        store.store(group="db",
                    name="mysql",
                    node=MySQLConfig,
                    path="db",
                    provider="test_provider")

        search_path = create_config_search_path(path)
        config_loader = ConfigLoaderImpl(config_search_path=search_path)
        cfg = config_loader.load_configuration(config_name="db/mysql",
                                               overrides=[])

        # Remove the "hydra" node so only the remaining content is compared.
        del cfg["hydra"]
        expected_cfg = {
            "db": {
                "driver": "mysql",
                "host": "???",
                "port": "???",
                "user": "******",
                "password": "******",
            }
        }
        assert cfg == expected_cfg

        # The load history is the shared baseline plus one tuple for db/mysql.
        history_entry = ("db/mysql", path, "main", "test_provider")
        expected = hydra_load_list.copy()
        expected.append(history_entry)
        assert config_loader.get_load_history() == expected

        # verify illegal modification is rejected at runtime
        with pytest.raises(ValidationError):
            cfg.db.port = "fail"

        # verify illegal override is rejected during load
        with pytest.raises(ValidationError):
            config_loader.load_configuration(config_name="db/mysql",
                                             overrides=["db.port=fail"])
Beispiel #23
0
    def test_load_config_with_key_error(self, hydra_restore_singletons: Any,
                                        path: str) -> None:
        """Loading ``schema_key_error`` must raise
        ``ConfigCompositionException`` with the expected message."""
        store = ConfigStore.instance()
        store.store(name="base_mysql", node=MySQLConfig, provider="this_test")

        search_path = create_config_search_path(path)
        config_loader = ConfigLoaderImpl(config_search_path=search_path)

        msg = dedent("""\
            In 'schema_key_error': ConfigKeyError raised while composing config:
            Key 'foo' not in 'MySQLConfig'
                full_key: foo
                object_type=MySQLConfig""")
        # Escaped so the message is matched literally, not as a pattern.
        expected_error = re.escape(msg)
        with raises(ConfigCompositionException, match=expected_error):
            config_loader.load_configuration(config_name="schema_key_error",
                                             overrides=[],
                                             run_mode=RunMode.RUN)
Beispiel #24
0
def test_load_schema_as_config(hydra_restore_singletons: Any) -> None:
    """
    Load a structured config stored in the ConfigStore as a configuration
    """
    store = ConfigStore.instance()
    store.store(name="config", node=TopLevelConfig, provider="this_test")
    store.store(name="db/mysql", node=MySQLConfig, provider="this_test")

    search_path = create_config_search_path(None)
    config_loader = ConfigLoaderImpl(config_search_path=search_path)
    cfg = config_loader.load_configuration(
        config_name="config", overrides=[], run_mode=RunMode.RUN
    )

    # Expected trace: the shared baseline plus one entry for "config".
    config_trace = LoadTrace(
        config_path="config",
        package="",
        parent="<root>",
        search_path="structured://",
        provider="this_test",
    )
    expected = deepcopy(hydra_load_list)
    expected.append(config_trace)
    assert_same_composition_trace(cfg.hydra.composition_trace, expected)

    # Remove the "hydra" node so only the remaining content is compared.
    with open_dict(cfg):
        del cfg["hydra"]
    db_keys = ("driver", "host", "port", "user", "password")
    expected_cfg = {
        "normal_yaml_config": "???",
        "db": {key: MISSING for key in db_keys},
    }
    assert cfg == expected_cfg
Beispiel #25
0
    def test_load_config_with_validation_error(self,
                                               hydra_restore_singletons: Any,
                                               path: str) -> None:
        """Loading ``schema_validation_error`` must raise
        ``ConfigCompositionException`` with the expected message."""
        store = ConfigStore.instance()
        store.store(name="base_mysql", node=MySQLConfig, provider="this_test")

        search_path = create_config_search_path(path)
        config_loader = ConfigLoaderImpl(config_search_path=search_path)

        msg = dedent("""\
            In 'schema_validation_error': ValidationError raised while composing config:
            Value 'not_an_int' could not be converted to Integer
                full_key: port
                object_type=MySQLConfig""")
        # Escaped so the message is matched literally, not as a pattern.
        expected_error = re.escape(msg)
        with raises(ConfigCompositionException, match=expected_error):
            config_loader.load_configuration(
                config_name="schema_validation_error",
                overrides=[],
                run_mode=RunMode.RUN)
Beispiel #26
0
def cli_main():
    """Store ``UnsupGenerateConfig`` under the config name and call ``hydra_main``.

    The config name is read from Hydra's parsed arguments and falls back to
    ``"config"`` when it is unset or when reading it raises an ``Exception``.
    """
    try:
        # Imported inside the try so a failing import also hits the fallback.
        from hydra._internal.utils import get_args

        cfg_name = get_args().config_name or "config"
    except Exception:
        # ``Exception`` instead of a bare ``except`` so KeyboardInterrupt and
        # SystemExit propagate rather than being silently swallowed here.
        logger.warning("Failed to get config name from hydra args")
        cfg_name = "config"

    cs = ConfigStore.instance()
    cs.store(name=cfg_name, node=UnsupGenerateConfig)
    hydra_main()
Beispiel #27
0
def register_configs() -> None:
    """Store the MySQL and PostgreSQL configs in the ``database_lib/db`` group."""
    store = ConfigStore.instance()
    db_nodes = (
        ("mysql", MySQLConfig),
        ("postgresql", PostGreSQLConfig),
    )
    for db_name, db_node in db_nodes:
        store.store(group="database_lib/db", name=db_name, node=db_node)
 def __init__(self, provider: str, path: str) -> None:
     """Initialize the base class, keep the ConfigStore and import ``path``.

     When ``self.path`` is non-empty it is imported as a module. An import
     failure emits a warning and the original exception is re-raised.
     """
     super().__init__(provider=provider, path=path)
     self.store = ConfigStore.instance()
     if self.path == "":
         return

     # Import the module, the __init__ there is expected to register the configs.
     try:
         importlib.import_module(self.path)
     except Exception as e:
         message = f"Error importing {self.path} : some configs may not be available\n\n\tRoot cause: {e}\n"
         warnings.warn(message)
         raise e
Beispiel #29
0
def hydra_init(cfg_name="config") -> None:
    """Store ``FairseqConfig`` under ``cfg_name`` and each field default by name.

    Every entry of ``FairseqConfig.__dataclass_fields__`` has its ``default``
    stored under the field's name. A failing store is logged and re-raised.
    """
    store = ConfigStore.instance()
    store.store(name=cfg_name, node=FairseqConfig)

    for k, field_spec in FairseqConfig.__dataclass_fields__.items():
        v = field_spec.default
        try:
            store.store(name=k, node=v)
        except BaseException:
            # Log which field failed, then let the original error propagate.
            logger.error(f"{k} - {v}")
            raise
Beispiel #30
0
    def setUpClass(cls):
        """Store the PSQL config and attach fixture-backed accessors to ``cls``.

        The two accessors are assigned as class attributes and return elements
        0 and 1 of the ``idx``-th entry of the loaded fixture.
        """
        ConfigStore.instance().store(name="config", node=datasets.PSQLConfig)

        fixture_path = os.path.join(fixture_dir, 'soc_seq_3_raw_df.pt')
        loaded = torch.load(fixture_path)

        def _get_states_from_db_se_f(self, idx: int) -> pd.DataFrame:
            return loaded[idx][0]

        cls._get_states_from_db_se_f = _get_states_from_db_se_f

        def _get_actions_from_db_se_f(self, idx: int) -> pd.DataFrame:
            return loaded[idx][1]

        cls._get_actions_from_db_se_f = _get_actions_from_db_se_f