def test_add_to_structured_config(self, hydra_restore_singletons: Any) -> None:
    """'+' and '++' overrides on a structured config stored under package 'nested'."""

    @dataclass
    class Config:
        a: int = 10

    store = ConfigStore.instance()
    store.store(name="config", node=Config, package="nested")

    # '+' appends a key that the schema does not declare.
    appended = compose("config", overrides=["+nested.b=20"])
    assert appended == {"nested": {"a": 10, "b": 20}}

    # '++' overrides the existing 'a' and adds the missing 'b'.
    forced = compose("config", overrides=["++nested.a=30", "++nested.b=20"])
    assert forced == {"nested": {"a": 30, "b": 20}}

    # '+' can introduce a brand-new nested mapping.
    deep = compose("config", overrides=["+nested.b.c=20"])
    assert deep == {"nested": {"a": 10, "b": {"c": 20}}}
def test_force_add(self) -> None:
    """'++' overrides a key when present and adds it when absent."""
    ConfigStore.instance().store(name="config", node={"key": 0})

    overridden = compose(config_name="config", overrides=["++key=1"])
    assert overridden == {"key": 1}

    added = compose(config_name="config", overrides=["++key2=1"])
    assert added == {"key": 0, "key2": 1}
def test_load_schema_as_config(restore_singletons: Any) -> None:
    """Load the structured config db/mysql directly as the primary configuration."""
    store = ConfigStore.instance()
    store.store(
        group="db",
        name="mysql",
        node=MySQLConfig,
        provider="test_provider",
    )

    search_path = create_config_search_path(None)
    config_loader = ConfigLoaderImpl(config_search_path=search_path)
    cfg = config_loader.load_configuration(config_name="db/mysql", overrides=[])

    # The 'hydra' node is not under test; drop it before comparing.
    with open_dict(cfg):
        del cfg["hydra"]

    expected_cfg = {
        "db": {
            "driver": MISSING,
            "host": MISSING,
            "port": MISSING,
            "user": MISSING,
            "password": MISSING,
        }
    }
    assert cfg == expected_cfg

    expected_history = hydra_load_list.copy()
    expected_history.append(
        LoadTrace("db/mysql", "structured://", "test_provider", None)
    )
    assert config_loader.get_load_history() == expected_history
def test_load_schema_as_config(hydra_restore_singletons: Any) -> None:
    """Load the structured config 'config' (which nests db) as the primary configuration."""
    store = ConfigStore.instance()
    store.store(name="config", node=TopLevelConfig, provider="this_test")
    store.store(name="db/mysql", node=MySQLConfig, provider="this_test")

    search_path = create_config_search_path(None)
    config_loader = ConfigLoaderImpl(config_search_path=search_path)
    cfg = config_loader.load_configuration(
        config_name="config", overrides=[], run_mode=RunMode.RUN
    )

    # The 'hydra' node is not under test; drop it before comparing.
    with open_dict(cfg):
        del cfg["hydra"]

    expected_cfg = {
        "normal_yaml_config": "???",
        "db": {
            "driver": MISSING,
            "host": MISSING,
            "port": MISSING,
            "user": MISSING,
            "password": MISSING,
        },
    }
    assert cfg == expected_cfg

    expected_history = hydra_load_list.copy()
    expected_history.append(LoadTrace("config", "structured://", "this_test", None))
    assert config_loader.get_load_history() == expected_history
def test_load_config_with_schema(self, hydra_restore_singletons: Any, path: str) -> None:
    """YAML configs from `path` are composed against schemas registered in the ConfigStore."""
    store = ConfigStore.instance()
    store.store(name="config", node=TopLevelConfig, provider="this_test")
    store.store(group="db", name="mysql", node=MySQLConfig, provider="this_test")

    search_path = create_config_search_path(path)
    config_loader = ConfigLoaderImpl(config_search_path=search_path)
    cfg = config_loader.load_configuration(
        config_name="config", overrides=["+db=mysql"], run_mode=RunMode.RUN
    )

    expected = deepcopy(hydra_load_list)
    expected.extend(
        [
            LoadTrace(
                config_name="config",
                search_path=path,
                provider="main",
                schema_provider="this_test",
            ),
            LoadTrace(
                config_group="db",
                config_name="mysql",
                search_path=path,
                provider="main",
                schema_provider="this_test",
                parent="overrides",
            ),
        ]
    )
    assert_same_composition_trace(cfg.hydra.composition_trace, expected)

    # The 'hydra' node is not under test; drop it before comparing.
    with open_dict(cfg):
        del cfg["hydra"]

    expected_cfg = {
        "normal_yaml_config": True,
        "db": {
            "driver": "mysql",
            "host": "???",
            "port": "???",
            "user": "******",
            "password": "******",
        },
    }
    assert cfg == expected_cfg

    # verify illegal modification is rejected at runtime
    with pytest.raises(ValidationError):
        cfg.db.port = "fail"

    # verify illegal override is rejected during load
    with pytest.raises(HydraException):
        config_loader.load_configuration(
            config_name="db/mysql",
            overrides=["db.port=fail"],
            run_mode=RunMode.RUN,
        )
def test_force_add(self, hydra_restore_singletons: Any) -> None:
    """'++' overrides a key when present and adds it when absent (inside initialize())."""
    ConfigStore.instance().store(name="config", node={"key": 0})

    with initialize():
        overridden = compose(config_name="config", overrides=["++key=1"])
        assert overridden == {"key": 1}

        added = compose(config_name="config", overrides=["++key2=1"])
        assert added == {"key": 0, "key2": 1}
def test_add(self) -> None:
    """'+' on an existing key is rejected; a plain override of it succeeds."""
    ConfigStore.instance().store(name="config", node={"key": 0})

    duplicate_msg = "Could not append to config. An item is already at 'key'"
    with raises(ConfigCompositionException, match=duplicate_msg):
        compose(config_name="config", overrides=["+key=value"])

    result = compose(config_name="config", overrides=["key=1"])
    assert result == {"key": 1}
def test_adding_to_sc_dict(hydra_restore_singletons: Any, overrides: List[str], expected: Any) -> None:
    """Overrides applied to a Dict[str, str] field of a structured config."""

    @dataclass
    class Config:
        map: Dict[str, str] = field(default_factory=dict)

    ConfigStore.instance().store(name="config", node=Config)
    composed = compose(config_name="config", overrides=overrides)
    assert composed == expected
def register_configs() -> None:
    """Store DynamicsConfig as option 'dynamics' of the 'dynamics' config group."""
    store = ConfigStore.instance()
    store.store(group="dynamics", name="dynamics", node=DynamicsConfig)
def register_model_cls(cls):
    """Validate ``cls`` and register it as model ``name``.

    ``name`` and ``dataclass`` are free variables, presumably bound by an
    enclosing decorator factory that is not visible here.

    Returns:
        ``cls`` itself, with ``cls.__dataclass`` set to ``dataclass``.

    Raises:
        ValueError: if ``name`` is already registered, if ``cls`` does not
            subclass ``BaseFairseqModel``, or if ``dataclass`` is given but
            does not subclass ``FairseqDataclass``.
    """
    if name in MODEL_REGISTRY:
        raise ValueError(
            "Cannot register duplicate model ({})".format(name))
    if not issubclass(cls, BaseFairseqModel):
        raise ValueError(
            "Model ({}: {}) must extend BaseFairseqModel".format(
                name, cls.__name__))
    # Finish every validation before mutating MODEL_REGISTRY. Otherwise a
    # rejected dataclass leaves the model half-registered, and a corrected
    # retry trips the duplicate-model check above.
    if dataclass is not None and not issubclass(dataclass, FairseqDataclass):
        raise ValueError(
            "Dataclass {} must extend FairseqDataclass".format(dataclass))

    MODEL_REGISTRY[name] = cls
    cls.__dataclass = dataclass
    if dataclass is not None:
        MODEL_DATACLASS_REGISTRY[name] = dataclass

        cs = ConfigStore.instance()
        node = dataclass()
        node._name = name
        cs.store(name=name, group="model", node=node, provider="fairseq")

        # NOTE(review): the flattened source does not show whether this
        # same-named default architecture belongs inside the dataclass branch.
        # It is placed here; confirm against version control. It must run after
        # MODEL_REGISTRY[name] is set.
        @register_model_architecture(name, name)
        def noop(_):
            pass

    return cls
def test_overlapping_schemas(hydra_restore_singletons: Any) -> None:
    """Selecting plugin=concrete swaps the Plugin schema for ConcretePlugin."""
    store = ConfigStore.instance()
    store.store(name="config", node=Config)
    store.store(group="plugin", name="concrete", node=ConcretePlugin)

    loader = ConfigLoaderImpl(config_search_path=create_config_search_path(None))

    base_cfg = loader.load_configuration(
        config_name="config", overrides=[], run_mode=RunMode.RUN
    )
    with open_dict(base_cfg):
        del base_cfg["hydra"]
    assert base_cfg == {"plugin": {"name": "???", "params": "???"}}
    assert OmegaConf.get_type(base_cfg.plugin) == Plugin

    concrete_cfg = loader.load_configuration(
        config_name="config", overrides=["+plugin=concrete"], run_mode=RunMode.RUN
    )
    with open_dict(concrete_cfg):
        del concrete_cfg["hydra"]
    assert concrete_cfg == {"plugin": {"name": "foobar_plugin", "params": {"foo": 10}}}
    assert OmegaConf.get_type(concrete_cfg.plugin) == ConcretePlugin
    assert OmegaConf.get_type(concrete_cfg.plugin.params) == ConcretePlugin.FoobarParams

    # Replacing a typed node with an int must be rejected.
    with pytest.raises(ValidationError):
        concrete_cfg.plugin = 10
def register_x_cls(cls):
    """Validate ``cls`` and register it in this registry under ``name``.

    ``name``, ``dataclass``, ``base_class``, ``registry_name`` and the
    ``REGISTRY*`` containers are free variables, presumably bound by an
    enclosing registry/decorator factory that is not visible here.

    Returns:
        ``cls`` itself, with ``cls.__dataclass`` set to ``dataclass``.

    Raises:
        ValueError: on a duplicate ``name`` or duplicate class name, if
            ``cls`` does not extend ``base_class`` (when given), or if
            ``dataclass`` is given but does not extend ``FairseqDataclass``.
    """
    if name in REGISTRY:
        raise ValueError("Cannot register duplicate {} ({})".format(
            registry_name, name))
    if cls.__name__ in REGISTRY_CLASS_NAMES:
        raise ValueError(
            "Cannot register {} with duplicate class name ({})".format(
                registry_name, cls.__name__))
    if base_class is not None and not issubclass(cls, base_class):
        raise ValueError("{} must extend {}".format(
            cls.__name__, base_class.__name__))
    if dataclass is not None and not issubclass(dataclass, FairseqDataclass):
        raise ValueError(
            "Dataclass {} must extend FairseqDataclass".format(dataclass))

    cls.__dataclass = dataclass
    if cls.__dataclass is not None:
        DATACLASS_REGISTRY[name] = cls.__dataclass

        cs = ConfigStore.instance()
        node = dataclass()
        node._name = name
        cs.store(name=name, group=registry_name, node=node, provider="fairseq")

    REGISTRY[name] = cls
    # Record the class name so the duplicate-class-name check above can
    # actually fire; before this line the function only ever read the
    # container. Assumes REGISTRY_CLASS_NAMES is a set (TODO: confirm at its
    # definition).
    REGISTRY_CLASS_NAMES.add(cls.__name__)

    return cls
def register_task_cls(cls):
    """Validate ``cls`` and register it as task ``name``.

    ``name`` and ``dataclass`` are free variables, presumably bound by an
    enclosing decorator factory that is not visible here.

    Returns:
        ``cls`` itself, with ``cls.__dataclass`` set to ``dataclass``.

    Raises:
        ValueError: if ``name`` or ``cls.__name__`` is already registered, if
            ``cls`` does not subclass ``FairseqTask``, or if ``dataclass`` is
            given but does not subclass ``FairseqDataclass``.
    """
    if name in TASK_REGISTRY:
        raise ValueError(
            "Cannot register duplicate task ({})".format(name))
    if not issubclass(cls, FairseqTask):
        raise ValueError("Task ({}: {}) must extend FairseqTask".format(
            name, cls.__name__))
    if cls.__name__ in TASK_CLASS_NAMES:
        raise ValueError(
            "Cannot register task with duplicate class name ({})".format(
                cls.__name__))
    # Finish every validation before touching the registries. Otherwise a
    # rejected dataclass leaves the task half-registered, and a corrected
    # retry trips the duplicate checks above.
    if dataclass is not None and not issubclass(dataclass, FairseqDataclass):
        raise ValueError(
            "Dataclass {} must extend FairseqDataclass".format(dataclass))

    TASK_REGISTRY[name] = cls
    TASK_CLASS_NAMES.add(cls.__name__)

    cls.__dataclass = dataclass
    if dataclass is not None:
        TASK_DATACLASS_REGISTRY[name] = dataclass

        cs = ConfigStore.instance()
        node = dataclass()
        node._name = name
        cs.store(name=name, group="task", node=node, provider="fairseq")

    return cls
def register_train_config() -> None:
    """Store an RNN ModelArgs as option 'rnn' of the 'train/model' group."""
    store = ConfigStore.instance()
    rnn_args = ModelArgs(model_type=EModelType.RNN, model_args=ModelRNNArgs())
    store.store(group="train/model", name="rnn", node=rnn_args)
def test_invalid_plugin_merge(restore_singletons: Any) -> Any:  # noqa: F811
    """Selecting a plugin whose schema conflicts with Config raises ValidationError."""
    store = ConfigStore.instance()
    store.store(name="config", node=Config)
    store.store(group="plugin", name="invalid", node=InvalidPlugin, path="plugin")

    loader = ConfigLoaderImpl(config_search_path=create_config_search_path(None))
    with pytest.raises(ValidationError):
        loader.load_configuration(config_name="config", overrides=["plugin=invalid"])
def init_config_store():
    """Register the primary config, the connection defaults and every entry of Task.registered_tasks."""
    store = ConfigStore.instance()
    store.store(name="config", node=Config)
    store.store(group="aquarium", name="default", node=AquariumConnection)
    store.store(group="neo", name="default", node=NeoConnetion)

    for task_name, task_cls in Task.registered_tasks.items():
        logger.info("Registering task {} ({})".format(task_name, task_cls.__name__))
        store.store(group="task", name=task_name, node=task_cls)
def test_invalid_plugin_merge(hydra_restore_singletons: Any) -> Any:
    """Selecting a plugin whose schema conflicts with Config raises HydraException."""
    store = ConfigStore.instance()
    store.store(name="config", node=Config)
    store.store(group="plugin", name="invalid", node=InvalidPlugin)

    loader = ConfigLoaderImpl(config_search_path=create_config_search_path(None))
    with pytest.raises(HydraException):
        loader.load_configuration(config_name="config", overrides=["plugin=invalid"])
def get_config_store(name="config", node=Config):
    """Store ``node`` as the primary config ``name`` and return ``ConfigStore.instance()``."""
    store = ConfigStore.instance()
    store.store(name=name, node=node)
    return store
def test_add_config_group(self) -> None:
    """Config groups support '+' (append) but not a plain override or '++'."""
    store = ConfigStore.instance()
    for option, value in (("a0", 0), ("a1", 1)):
        store.store(group="group", name=option, node={"key": value})

    # Overriding a group that has not been added yet raises.
    with raises(ConfigCompositionException):
        compose(overrides=["group=a0"])

    # '+' appends the group.
    appended = compose(overrides=["+group=a0"])
    assert appended == {"group": {"key": 0}}

    # '++' (force-add) is rejected for config groups.
    force_add_msg = re.escape(
        "force-add of config groups is not supported: '++group=a1'"
    )
    with raises(ConfigCompositionException, match=force_add_msg):
        compose(overrides=["++group=a1"])
def test_load_config_with_schema(
    self, hydra_restore_singletons: Any, path: str
) -> None:
    """Compose 'config' plus db=validated_mysql from `path` with schemas registered in the ConfigStore."""
    store = ConfigStore.instance()
    store.store(name="config_with_schema", node=TopLevelConfig, provider="this_test")
    store.store(group="db", name="base_mysql", node=MySQLConfig, provider="this_test")

    search_path = create_config_search_path(path)
    config_loader = ConfigLoaderImpl(config_search_path=search_path)
    cfg = config_loader.load_configuration(
        config_name="config",
        overrides=["+db=validated_mysql"],
        run_mode=RunMode.RUN,
    )

    # The 'hydra' node is not under test; drop it before comparing.
    with open_dict(cfg):
        del cfg["hydra"]

    expected_cfg = {
        "normal_yaml_config": True,
        "db": {
            "driver": "mysql",
            "host": "???",
            "port": "???",
            "user": "******",
            "password": "******",
        },
    }
    assert cfg == expected_cfg

    # verify illegal modification is rejected at runtime
    with raises(ValidationError):
        cfg.db.port = "fail"

    # verify illegal override is rejected during load
    with raises(HydraException):
        config_loader.load_configuration(
            config_name="db/mysql",
            overrides=["db.port=fail"],
            run_mode=RunMode.RUN,
        )
def before_hydra(config_class):
    """Register the primary config, model options and optimizer options.

    Args:
        config_class: node stored as the primary config named 'config'.

    Returns:
        The ``ConfigStore.instance()`` object the configs were stored in.
    """
    cs = ConfigStore.instance()
    cs.store(name='config', node=config_class)
    # A plain loop: storing is done only for its side effect, so building a
    # throwaway list (a comprehension) is the wrong tool.
    for model_name, model_cfg in model_list:
        cs.store(group='train.model', name=model_name, node=model_cfg)
    cs.store(group='train.model.optim', name='sgd', node=SGDConfig)
    cs.store(group='train.model.optim', name='adam', node=AdamConfig)
    return cs
def test_load_config_with_schema(
    self, restore_singletons: Any, path: str  # noqa: F811
) -> None:
    """Load db/mysql from `path`, validated against the stored MySQLConfig schema."""
    store = ConfigStore.instance()
    store.store(
        group="db",
        name="mysql",
        node=MySQLConfig,
        path="db",
        provider="test_provider",
    )

    search_path = create_config_search_path(path)
    config_loader = ConfigLoaderImpl(config_search_path=search_path)
    cfg = config_loader.load_configuration(config_name="db/mysql", overrides=[])
    del cfg["hydra"]

    expected_cfg = {
        "db": {
            "driver": "mysql",
            "host": "???",
            "port": "???",
            "user": "******",
            "password": "******",
        }
    }
    assert cfg == expected_cfg

    expected_history = hydra_load_list.copy()
    expected_history.append(("db/mysql", path, "main", "test_provider"))
    assert config_loader.get_load_history() == expected_history

    # verify illegal modification is rejected at runtime
    with pytest.raises(ValidationError):
        cfg.db.port = "fail"

    # verify illegal override is rejected during load
    with pytest.raises(ValidationError):
        config_loader.load_configuration(
            config_name="db/mysql", overrides=["db.port=fail"]
        )
def test_load_config_with_key_error(self, hydra_restore_singletons: Any, path: str) -> None:
    """Composing 'schema_key_error' must raise ConfigCompositionException.

    The message is expected to report a ConfigKeyError for key 'foo' on
    MySQLConfig. MySQLConfig is stored as 'base_mysql' below, presumably so the
    YAML config can extend it (the YAML is not visible here; confirm).
    """
    ConfigStore.instance().store(name="base_mysql", node=MySQLConfig, provider="this_test")
    config_loader = ConfigLoaderImpl(
        config_search_path=create_config_search_path(path))
    # NOTE(review): in this copy the expected message is a single line that
    # starts with a backslash-space after the opening quotes. The original very
    # likely had line breaks there (the backslash-newline idiom used with
    # dedent). If the real error spans several lines, re.escape(msg) will not
    # match it. Restore the line breaks from version control.
    msg = dedent("""\ In 'schema_key_error': ConfigKeyError raised while composing config: Key 'foo' not in 'MySQLConfig' full_key: foo object_type=MySQLConfig""")
    with raises(ConfigCompositionException, match=re.escape(msg)):
        config_loader.load_configuration(
            config_name="schema_key_error",
            overrides=[],
            run_mode=RunMode.RUN,
        )
def test_load_schema_as_config(hydra_restore_singletons: Any) -> None:
    """Load the structured config 'config' as the primary config and check its trace."""
    store = ConfigStore.instance()
    store.store(name="config", node=TopLevelConfig, provider="this_test")
    store.store(name="db/mysql", node=MySQLConfig, provider="this_test")

    search_path = create_config_search_path(None)
    config_loader = ConfigLoaderImpl(config_search_path=search_path)
    cfg = config_loader.load_configuration(
        config_name="config", overrides=[], run_mode=RunMode.RUN
    )

    expected = deepcopy(hydra_load_list)
    expected.append(
        LoadTrace(
            config_path="config",
            package="",
            parent="<root>",
            search_path="structured://",
            provider="this_test",
        )
    )
    # The trace lives under cfg.hydra, so check it before that node is removed.
    assert_same_composition_trace(cfg.hydra.composition_trace, expected)

    with open_dict(cfg):
        del cfg["hydra"]

    expected_cfg = {
        "normal_yaml_config": "???",
        "db": {
            "driver": MISSING,
            "host": MISSING,
            "port": MISSING,
            "user": MISSING,
            "password": MISSING,
        },
    }
    assert cfg == expected_cfg
def test_load_config_with_validation_error(self, hydra_restore_singletons: Any, path: str) -> None:
    """Composing 'schema_validation_error' must raise ConfigCompositionException.

    The message is expected to report a ValidationError because 'not_an_int'
    cannot be converted to Integer for field 'port' of MySQLConfig. MySQLConfig
    is stored as 'base_mysql' below, presumably so the YAML config can extend it
    (the YAML is not visible here; confirm).
    """
    ConfigStore.instance().store(name="base_mysql", node=MySQLConfig, provider="this_test")
    config_loader = ConfigLoaderImpl(
        config_search_path=create_config_search_path(path))
    # NOTE(review): in this copy the expected message is a single line that
    # starts with a backslash-space after the opening quotes. The original very
    # likely had line breaks there (the backslash-newline idiom used with
    # dedent). If the real error spans several lines, re.escape(msg) will not
    # match it. Restore the line breaks from version control.
    msg = dedent("""\ In 'schema_validation_error': ValidationError raised while composing config: Value 'not_an_int' could not be converted to Integer full_key: port object_type=MySQLConfig""")
    with raises(ConfigCompositionException, match=re.escape(msg)):
        config_loader.load_configuration(
            config_name="schema_validation_error",
            overrides=[],
            run_mode=RunMode.RUN,
        )
def cli_main():
    """Store UnsupGenerateConfig under the CLI-selected config name, then run hydra_main().

    The config name comes from Hydra's parsed CLI args. It falls back to
    "config" when the args give no name or cannot be obtained.
    """
    try:
        from hydra._internal.utils import get_args

        cfg_name = get_args().config_name or "config"
    # Best-effort fallback. SystemExit is included so an argument-parsing exit
    # still falls back as before. This is deliberately not a bare ``except:``,
    # so that KeyboardInterrupt (Ctrl-C) still propagates.
    except (Exception, SystemExit):
        logger.warning("Failed to get config name from hydra args")
        cfg_name = "config"

    cs = ConfigStore.instance()
    cs.store(name=cfg_name, node=UnsupGenerateConfig)
    hydra_main()
def register_configs() -> None:
    """Store the MySQL and PostgreSQL schemas as options of the 'database_lib/db' group."""
    store = ConfigStore.instance()
    for option, schema in (("mysql", MySQLConfig), ("postgresql", PostGreSQLConfig)):
        store.store(group="database_lib/db", name=option, node=schema)
def __init__(self, provider: str, path: str) -> None:
    """Initialise the source and import the module named by ``path``, if any.

    Args:
        provider: forwarded to the base-class initialiser.
        path: forwarded to the base-class initialiser. When ``self.path`` is
            non-empty, it is imported as a module.

    Raises:
        Exception: whatever the import raised. A warning naming the module and
            the root cause is emitted first, then the same exception is
            re-raised.
    """
    super().__init__(provider=provider, path=path)
    # Importing the module is expected to make its __init__ register the configs.
    self.store = ConfigStore.instance()
    if self.path != "":
        try:
            importlib.import_module(self.path)
        except Exception as e:
            warnings.warn(
                f"Error importing {self.path} : some configs may not be available\n\n\tRoot cause: {e}\n"
            )
            # Bare ``raise`` re-raises the active exception without adding this
            # line to its traceback.
            raise
def hydra_init(cfg_name="config") -> None:
    """Store FairseqConfig as ``cfg_name``, then store each field's default under the field's name.

    If storing a field default fails, the field name and value are logged and
    the exception is re-raised.
    """
    store = ConfigStore.instance()
    store.store(name=cfg_name, node=FairseqConfig)
    for field_name, field_def in FairseqConfig.__dataclass_fields__.items():
        default = field_def.default
        try:
            store.store(name=field_name, node=default)
        except BaseException:
            logger.error(f"{field_name} - {default}")
            raise
def setUpClass(cls):
    """Store the PSQL config and replace the class's two DB accessors with fixture-backed versions."""
    ConfigStore.instance().store(name="config", node=datasets.PSQLConfig)
    frames = torch.load(os.path.join(fixture_dir, 'soc_seq_3_raw_df.pt'))

    def _get_states_from_db_se_f(self, idx: int) -> pd.DataFrame:
        # Position 0 of each fixture entry is served as the states.
        return frames[idx][0]

    def _get_actions_from_db_se_f(self, idx: int) -> pd.DataFrame:
        # Position 1 of each fixture entry is served as the actions.
        return frames[idx][1]

    cls._get_states_from_db_se_f = _get_states_from_db_se_f
    cls._get_actions_from_db_se_f = _get_actions_from_db_se_f