def _update_known_state(
    d: DefaultElement,
    group_to_choice: DictConfig,
    delete_groups: Dict[DeleteKey, int],
) -> None:
    """Record what a defaults-list element tells us about group choices and deletions.

    :param d: the element being visited; may be mutated (a delete that targets
        an already chosen group is turned back into a regular element).
    :param group_to_choice: maps a fully qualified group name to the config
        name chosen for it; updated in place.
    :param delete_groups: maps a pending deletion to a counter; a new
        deletion is registered with a counter of 0.
    """
    group_name = d.fully_qualified_group_name()
    if group_name is None:
        return

    # Snapshot taken before this call can register a choice of its own.
    already_chosen = group_name in group_to_choice

    registers_choice = (
        d.config_group is not None
        and not already_chosen
        and not d.is_delete
        and d.config_name not in ("_self_", "_keep_")
        and not is_matching_deletion(delete_groups=delete_groups, d=d)
    )
    if registers_choice:
        group_to_choice[group_name] = d.config_name

    if not d.is_delete:
        return

    if already_chosen:
        # The group already has a choice: the element stops being a delete
        # and takes over the chosen config name.
        d.is_delete = False
        d.config_name = group_to_choice[group_name]
        return

    # "_delete_" means "delete the group whatever its value", stored as None.
    target_name = None if d.config_name == "_delete_" else d.config_name
    pending = DeleteKey(group_name, target_name, must_delete=d.from_override)
    if pending not in delete_groups:
        delete_groups[pending] = 0
def _compute_element_defaults_list_impl(
    element: DefaultElement,
    group_to_choice: DictConfig,
    delete_groups: Dict[DeleteKey, int],
    skip_missing: bool,
    repo: IConfigRepository,
) -> List[DefaultElement]:
    """Expand a single defaults-list element into its effective defaults list.

    :param element: the element to expand.
    :param group_to_choice: group name to chosen config name, forwarded to the expansion.
    :param delete_groups: pending deletions, consulted first and forwarded to the expansion.
    :param skip_missing: when True, a mandatory-missing ("???") element is marked
        as skipped instead of raising.
    :param repo: repository used to load configs and list group options.
    :return: [] if the element is deleted, [element] if its load is skipped,
        otherwise the result of expanding the loaded config's defaults list.
    :raises ConfigCompositionException: the element is "???" and skip_missing is False.
    """
    if delete_if_matching(delete_groups, element):
        return []

    if element.config_name == "???":
        if skip_missing:
            element.set_skip_load("missing_skipped")
            return [element]

        msg = f"You must specify '{element.config_group}', e.g, {element.config_group}=<OPTION>"
        if element.config_group is not None:
            choices = repo.get_group_options(
                element.config_group, results_filter=ObjectType.CONFIG
            )
            listing = "\n".join("\t" + choice for choice in choices)
            msg += f"\nAvailable options:\n{listing}"
        raise ConfigCompositionException(msg)

    loaded = repo.load_config(
        config_path=element.config_path(),
        is_primary_config=element.primary,
    )

    if loaded is None:
        if element.optional:
            element.set_skip_load("missing_optional_config")
            return [element]
        missing_config_error(repo=repo, element=element)
    else:
        # Two independent copies so edits to the effective list never leak
        # into the original one.
        defaults = DefaultsList(
            original=copy.deepcopy(loaded.defaults_list),
            effective=copy.deepcopy(loaded.defaults_list),
        )
        _validate_self(element, defaults)

    return _expand_defaults_list_impl(
        self_element=element,
        defaults_list=defaults,
        group_to_choice=group_to_choice,
        delete_groups=delete_groups,
        skip_missing=skip_missing,
        repo=repo,
    )
def missing_config_error(repo: IConfigRepository, element: DefaultElement) -> None:
    """Raise a MissingConfigException describing why `element` could not be loaded.

    The message lists the available options when the element belongs to a
    config group, and always ends with the config search path.

    :param repo: repository queried for group options and sources.
    :param element: the element whose config is missing.
    :raises MissingConfigException: always.
    """
    group = element.config_group
    if group is None:
        options = None
        msg = dedent(
            f"""\
            Could not load {element.config_path()}.
            """
        )
    else:
        options = repo.get_group_options(group, ObjectType.CONFIG)
        opt_list = "\n".join("\t" + option for option in options)
        msg = (
            f"Could not find '{element.config_name}' in the config group '{element.config_group}'"
            f"\nAvailable options:\n{opt_list}\n"
        )

    search_path = "\n".join([f"\t{source!r}" for source in repo.get_sources()])
    msg += f"\nConfig search path:\n{search_path}"

    raise MissingConfigException(
        missing_cfg_file=element.config_path(),
        message=msg,
        options=options,
    )
def _find_match_before(
    defaults: List[DefaultElement], like: DefaultElement
) -> Optional[DefaultElement]:
    """Return the first element listed before `like` that shares its group name.

    The scan stops as soon as an element equal to `like` is reached.

    :param defaults: the list to scan, in order.
    :param like: the reference element.
    :return: the earlier element with the same fully qualified group name, or None.
    """
    wanted = like.fully_qualified_group_name()
    for candidate in defaults:
        if candidate == like:
            # Reached the reference element: nothing before it matched.
            return None
        if candidate.fully_qualified_group_name() == wanted:
            return candidate
    return None
def test_load_defaults_list(self, type_: Type[ConfigSource], path: str) -> None:
    """Loading 'config_with_defaults_list' as primary yields its single default."""
    source = type_(provider="foo", path=path)
    loaded = source.load_config(
        config_path="config_with_defaults_list", is_primary_config=True
    )
    expected = [
        DefaultElement(
            config_group="dataset",
            config_name="imagenet",
            parent="config_with_defaults_list",
        )
    ]
    assert loaded.defaults_list == expected
def convert_overrides_to_defaults(
    parsed_overrides: List[Override],
) -> List[DefaultElement]:
    """Turn parsed config-group overrides into defaults-list elements.

    Every produced element has from_override=True and parent="overrides".

    :param parsed_overrides: overrides to convert, in order.
    :return: one DefaultElement per override, in the same order.
    :raises ConfigCompositionException: an override combines add with package
        rename, or its value is not a string.
    """
    add_with_rename_msg = "Add syntax does not support package rename, remove + prefix"
    bad_value_msg = (
        "Defaults list supported delete syntax is in the form"
        " ~group and ~group=value, where value is a group name (string)"
    )

    elements = []
    for override in parsed_overrides:
        if override.is_add() and override.is_package_rename():
            raise ConfigCompositionException(add_with_rename_msg)

        name = override.value()
        if override.is_delete() and name is None:
            # A bare "~group" carries no value; use the delete marker instead.
            name = "_delete_"

        if not isinstance(name, str):
            raise ConfigCompositionException(bad_value_msg)

        if override.is_package_rename():
            element = DefaultElement(
                config_group=override.key_or_group,
                config_name=name,
                package=override.pkg1,
                rename_package_to=override.pkg2,
                from_override=True,
                parent="overrides",
            )
        else:
            element = DefaultElement(
                config_group=override.key_or_group,
                config_name=name,
                package=override.get_subject_package(),
                from_override=True,
                parent="overrides",
            )

        if override.is_delete():
            element.is_delete = True

        if override.is_add():
            element.is_add = True

        elements.append(element)

    return elements
def is_matching_deletion(
    delete_groups: Dict[DeleteKey, int],
    d: DefaultElement,
    mark_item_as_deleted: bool = False,
) -> bool:
    """Tell whether any pending deletion applies to `d`.

    A deletion applies when its group name equals the element's fully
    qualified group name and it either names no config (group-wide delete)
    or names the element's config.

    :param delete_groups: pending deletions mapped to a hit counter.
    :param d: the element to test.
    :param mark_item_as_deleted: when True, every applicable deletion has its
        counter incremented and `d` is flagged as deleted and skipped.
    :return: True if at least one deletion applies.
    """
    found = False
    for key in delete_groups:
        if key.fqgn != d.fully_qualified_group_name():
            continue
        # A key without a config name deletes the whole group; otherwise the
        # config names must agree.
        if key.config_name is not None and key.config_name != d.config_name:
            continue

        found = True
        if mark_item_as_deleted:
            delete_groups[key] += 1
            d.is_deleted = True
            d.set_skip_load("deleted_from_list")
    return found
def _load_single_config(
    self,
    default: DefaultElement,
    repo: IConfigRepository,
) -> Tuple[ConfigResult, LoadTrace]:
    """Load the config for one defaults-list element and merge it with its schema.

    :param default: the element to load; its search_path is updated with the
        path the config was loaded from.
    :param repo: repository used to load the config and to look up the schema source.
    :return: the load result (with its config replaced by the schema-merged
        config when a schema was found) and a LoadTrace describing the load.
    :raises ValueError: the loaded config is not a DictConfig.
    :raises ConfigCompositionException: merging the config with its schema failed.
    """
    config_path = default.config_path()
    package_override = default.package
    ret = repo.load_config(
        config_path=config_path,
        is_primary_config=default.primary,
        package_override=package_override,
    )
    assert ret is not None

    if not isinstance(ret.config, DictConfig):
        # Report the type of the offending config object, not of the result
        # wrapper that carries it.
        raise ValueError(
            f"Config {config_path} must be a Dictionary, got {type(ret.config).__name__}"
        )

    default.search_path = ret.path

    schema_provider = None
    if not ret.is_schema_source:
        schema = None
        try:
            schema_source = repo.get_schema_source()
            schema = schema_source.load_config(
                ConfigSource._normalize_file_name(filename=config_path),
                is_primary_config=default.primary,
                package_override=package_override,
            )
        except ConfigLoadError:
            # schema not found, ignore
            pass

        if schema is not None:
            try:
                # if config has a hydra node, remove it during validation and add it back.
                # This allows overriding Hydra's configuration without declaring this node
                # in every program
                hydra = None
                hydra_config_group = (
                    default.config_group is not None
                    and default.config_group.startswith("hydra/")
                )
                if "hydra" in ret.config and not hydra_config_group:
                    hydra = ret.config.pop("hydra")
                schema_provider = schema.provider
                merged = OmegaConf.merge(schema.config, ret.config)
                assert isinstance(merged, DictConfig)

                if hydra is not None:
                    with open_dict(merged):
                        merged.hydra = hydra
                ret.config = merged
            except OmegaConfBaseException as e:
                raise ConfigCompositionException(
                    f"Error merging '{config_path}' with schema"
                ) from e

    trace = LoadTrace(
        config_group=default.config_group,
        config_name=default.config_name,
        package=default.get_subject_package(),
        search_path=ret.path,
        parent=default.parent,
        provider=ret.provider,
        schema_provider=schema_provider,
    )
    return ret, trace
def _load_configuration_impl(
    self,
    config_name: Optional[str],
    overrides: List[str],
    run_mode: RunMode,
    strict: Optional[bool] = None,
    from_shell: bool = True,
) -> DictConfig:
    """Compose the final configuration from the primary config and the overrides.

    :param config_name: name of the primary config, or None for no primary config.
    :param overrides: raw override strings.
    :param run_mode: the run mode; missing mandatory defaults are skipped only
        for RunMode.MULTIRUN.
    :param strict: struct flag applied to the composed config; None falls back
        to self.default_strict.
    :param from_shell: forwarded to ConfigLoaderImpl.parse_overrides.
    :return: the composed config, with the hydra node populated with the
        overrides, runtime and job information.
    """
    self.ensure_main_config_source_available()
    # All lookups below go through a caching wrapper around self.repository.
    caching_repo = CachingConfigRepository(self.repository)

    # A primary config that was requested but does not exist is reported
    # before any composition work starts.
    if config_name is not None and not caching_repo.config_exists(
        config_name):
        self._missing_config_error(
            config_name=config_name,
            msg=
            f"Cannot find primary config : {config_name}, check that it's in your config search path",
            with_search_path=True,
        )

    if strict is None:
        strict = self.default_strict

    # The overrides are parsed twice: parsed_overrides is only used further
    # down to record the raw input lines; config_overrides drives composition.
    parser = OverridesParser.create()
    parsed_overrides = parser.parse_overrides(overrides=overrides)
    config_overrides = ConfigLoaderImpl.parse_overrides(
        overrides=overrides, run_mode=run_mode, from_shell=from_shell)
    split_res = self.split_by_override_type(config_overrides)
    config_group_overrides = split_res.config_group_overrides
    config_overrides = split_res.config_overrides

    # Input defaults, in order: "hydra_config", the primary config (if any),
    # then the config-group overrides converted to defaults.
    input_defaults = [DefaultElement(config_name="hydra_config")]
    if config_name is not None:
        input_defaults.append(
            DefaultElement(config_name=config_name, primary=True))

    for default in convert_overrides_to_defaults(config_group_overrides):
        input_defaults.append(default)

    skip_missing = run_mode == RunMode.MULTIRUN
    defaults = expand_defaults_list(
        defaults=input_defaults,
        skip_missing=skip_missing,
        repo=caching_repo,
    )

    cfg, composition_trace = self._compose_config_from_defaults_list(
        defaults=defaults, repo=caching_repo)

    OmegaConf.set_struct(cfg, strict)
    OmegaConf.set_readonly(cfg.hydra, False)

    # Apply command line overrides after enabling strict flag
    ConfigLoaderImpl._apply_overrides_to_config(config_overrides, cfg)
    # Record every raw override line under cfg.hydra.overrides; non-hydra
    # overrides are also collected for the override dirname computed below.
    app_overrides = []
    for override in parsed_overrides:
        if override.is_hydra_override():
            cfg.hydra.overrides.hydra.append(override.input_line)
        else:
            cfg.hydra.overrides.task.append(override.input_line)
            app_overrides.append(override)

    # NOTE(review): open_dict presumably lifts the struct flag on cfg.hydra
    # for the duration of the block so new keys can be added — confirm
    # against the OmegaConf documentation.
    with open_dict(cfg.hydra):
        from hydra import __version__

        cfg.hydra.runtime.version = __version__
        cfg.hydra.runtime.cwd = os.getcwd()
        cfg.hydra.composition_trace = composition_trace
        # Only fill in the job name when the config did not provide one.
        if "name" not in cfg.hydra.job:
            cfg.hydra.job.name = JobRuntime().get("name")
        cfg.hydra.job.override_dirname = get_overrides_dirname(
            overrides=app_overrides,
            kv_sep=cfg.hydra.job.config.override_dirname.kv_sep,
            item_sep=cfg.hydra.job.config.override_dirname.item_sep,
            exclude_keys=cfg.hydra.job.config.override_dirname.
            exclude_keys,
        )
        cfg.hydra.job.config_name = config_name

        # Copy the listed variables from the current environment into
        # env_set; os.environ[key] raises KeyError for a variable that is
        # not set.
        for key in cfg.hydra.job.env_copy:
            cfg.hydra.job.env_set[key] = os.environ[key]

    return cfg
class ConfigSourceTestSuite:
    """Reusable pytest suite exercising a ConfigSource implementation.

    Every test receives the source class under test as `type_` and the
    location of the test configs as `path`.
    """

    def skip_overlap_config_path_name(self) -> bool:
        """
        Some config source plugins do not support config name and path overlap.
        For example the following may not be allowed:
        (dataset exists both as a config object and a config group)
        /dataset.yaml
        /dataset/cifar.yaml

        Overriding and returning True here will disable testing of this scenario
        by assuming the dataset config (dataset.yaml) is not present.
        """
        return False

    def test_not_available(self, type_: Type[ConfigSource], path: str) -> None:
        """A source pointing at a location that does not exist is not available."""
        scheme = type_(provider="foo", path=path).scheme()
        # Test is meaningless for StructuredConfigSource
        if scheme != "structured":
            missing = type_(provider="foo", path=f"{scheme}://___NOT_FOUND___")
            assert not missing.available()

    @mark.parametrize(  # type: ignore
        "config_path, expected",
        [
            pytest.param("", True, id="empty"),
            pytest.param("dataset", True, id="dataset"),
            pytest.param("optimizer", True, id="optimizer"),
            pytest.param(
                "configs_with_defaults_list",
                True,
                id="configs_with_defaults_list",
            ),
            pytest.param("dataset/imagenet", False, id="dataset/imagenet"),
            pytest.param("level1", True, id="level1"),
            pytest.param("level1/level2", True, id="level1/level2"),
            pytest.param(
                "level1/level2/nested1", False, id="level1/level2/nested1"
            ),
            pytest.param("not_found", False, id="not_found"),
        ],
    )
    def test_is_group(
        self,
        type_: Type[ConfigSource],
        path: str,
        config_path: str,
        expected: bool,
    ) -> None:
        """is_group() reports whether config_path names a config group."""
        source = type_(provider="foo", path=path)
        assert source.is_group(config_path=config_path) == expected

    @mark.parametrize(  # type: ignore
        "config_path, expected",
        [
            ("", False),
            ("optimizer", False),
            ("dataset/imagenet", True),
            ("dataset/imagenet.yaml", True),
            ("dataset/imagenet.foobar", False),
            ("configs_with_defaults_list/global_package", True),
            ("configs_with_defaults_list/group_package", True),
            ("level1", False),
            ("level1/level2", False),
            ("level1/level2/nested1", True),
            ("not_found", False),
        ],
    )
    def test_is_config(
        self,
        type_: Type[ConfigSource],
        path: str,
        config_path: str,
        expected: bool,
    ) -> None:
        """is_config() reports whether config_path names a config object."""
        source = type_(provider="foo", path=path)
        assert source.is_config(config_path=config_path) == expected

    @mark.parametrize(  # type: ignore
        "config_path, expected",
        [
            ("dataset", True),
        ],
    )
    def test_is_config_with_overlap_name(
        self,
        type_: Type[ConfigSource],
        path: str,
        config_path: str,
        expected: bool,
    ) -> None:
        """is_config() on a name that is both a config and a group (skippable)."""
        if self.skip_overlap_config_path_name():
            pytest.skip(
                f"ConfigSourcePlugin {type_.__name__} does not support config objects and config groups "
                f"with overlapping names."
            )
        source = type_(provider="foo", path=path)
        assert source.is_config(config_path=config_path) == expected

    @mark.parametrize(  # type: ignore
        "config_path,results_filter,expected",
        [
            # groups
            ("", ObjectType.GROUP, ["dataset", "level1", "optimizer"]),
            ("dataset", ObjectType.GROUP, []),
            ("optimizer", ObjectType.GROUP, []),
            ("level1", ObjectType.GROUP, ["level2"]),
            ("level1/level2", ObjectType.GROUP, []),
            # Configs
            ("", ObjectType.CONFIG, ["config_without_group"]),
            ("dataset", ObjectType.CONFIG, ["cifar10", "imagenet"]),
            ("optimizer", ObjectType.CONFIG, ["adam", "nesterov"]),
            ("level1", ObjectType.CONFIG, []),
            ("level1/level2", ObjectType.CONFIG, ["nested1", "nested2"]),
            # both
            ("", None, ["config_without_group", "dataset", "level1", "optimizer"]),
            ("dataset", None, ["cifar10", "imagenet"]),
            ("optimizer", None, ["adam", "nesterov"]),
            ("level1", None, ["level2"]),
            ("level1/level2", None, ["nested1", "nested2"]),
            ("", None, ["config_without_group", "dataset", "level1", "optimizer"]),
        ],
    )
    def test_list(
        self,
        type_: Type[ConfigSource],
        path: str,
        config_path: str,
        results_filter: Optional[ObjectType],
        expected: List[str],
    ) -> None:
        """list() contains every expected entry and returns them sorted."""
        source = type_(provider="foo", path=path)
        listing = source.list(config_path=config_path, results_filter=results_filter)
        for name in expected:
            assert name in listing
        assert listing == sorted(listing)

    @mark.parametrize(  # type: ignore
        "config_path,results_filter,expected",
        [
            # Configs
            ("", ObjectType.CONFIG, ["dataset"]),
        ],
    )
    def test_list_with_overlap_name(
        self,
        type_: Type[ConfigSource],
        path: str,
        config_path: str,
        results_filter: Optional[ObjectType],
        expected: List[str],
    ) -> None:
        """list() when a config and a group share a name (skippable)."""
        if self.skip_overlap_config_path_name():
            pytest.skip(
                f"ConfigSourcePlugin {type_.__name__} does not support config objects and config groups "
                f"with overlapping names."
            )
        source = type_(provider="foo", path=path)
        listing = source.list(config_path=config_path, results_filter=results_filter)
        for name in expected:
            assert name in listing
        assert listing == sorted(listing)

    @mark.parametrize(  # type: ignore
        "config_path,expected_config,expected_defaults_list",
        [
            param(
                "config_without_group",
                {"group": False},
                [],
                id="config_without_group",
            ),
            param(
                "config_with_unicode",
                {"group": "数据库"},
                [],
                id="config_with_unicode",
            ),
            param(
                "dataset/imagenet",
                {"dataset": {"name": "imagenet", "path": "/datasets/imagenet"}},
                [],
                id="dataset/imagenet",
            ),
            param(
                "dataset/cifar10",
                {"dataset": {"name": "cifar10", "path": "/datasets/cifar10"}},
                [],
                id="dataset/cifar10",
            ),
            param(
                "dataset/not_found",
                raises(ConfigLoadError),
                [],
                id="dataset/not_found",
            ),
            param(
                "level1/level2/nested1",
                {"l1_l2_n1": True},
                [],
                id="level1/level2/nested1",
            ),
            param(
                "level1/level2/nested2",
                {"l1_l2_n2": True},
                [],
                id="level1/level2/nested2",
            ),
            param(
                "config_with_defaults_list",
                {"key": "value"},
                [
                    DefaultElement(
                        config_group="dataset",
                        config_name="imagenet",
                        parent="config_with_defaults_list",
                    )
                ],
                id="config_with_defaults_list",
            ),
            param(
                "configs_with_defaults_list/global_package",
                {"configs_with_defaults_list": {"x": 10}},
                [
                    DefaultElement(
                        config_group="foo",
                        config_name="bar",
                        parent="configs_with_defaults_list/global_package",
                    )
                ],
                id="configs_with_defaults_list/global_package",
            ),
            param(
                "configs_with_defaults_list/group_package",
                {"configs_with_defaults_list": {"x": 10}},
                [
                    DefaultElement(
                        config_group="foo",
                        config_name="bar",
                        parent="configs_with_defaults_list/group_package",
                    )
                ],
                id="configs_with_defaults_list/group_package",
            ),
        ],
    )
    def test_source_load_config(
        self,
        type_: Type[ConfigSource],
        path: str,
        config_path: str,
        expected_defaults_list: List[DefaultElement],
        expected_config: Any,
        recwarn: Any,
    ) -> None:
        """load_config() returns the expected config and defaults list, or raises."""
        assert issubclass(type_, ConfigSource)
        source = type_(provider="foo", path=path)

        if not isinstance(expected_config, dict):
            # Anything that is not a dict is used as a context manager
            # wrapped around the load (the `raises(...)` cases above).
            with expected_config:
                source.load_config(config_path=config_path, is_primary_config=False)
            return

        loaded = source.load_config(config_path=config_path, is_primary_config=False)
        assert loaded.config == expected_config
        assert loaded.defaults_list == expected_defaults_list

    @mark.parametrize(  # type: ignore
        "config_path, expected_result, expected_package",
        [
            param("package_test/none", {"foo": "bar"}, "", id="none"),
            param(
                "package_test/explicit",
                {"a": {"b": {"foo": "bar"}}},
                "a.b",
                id="explicit",
            ),
            param("package_test/global", {"foo": "bar"}, "", id="global"),
            param(
                "package_test/group",
                {"package_test": {"foo": "bar"}},
                "package_test",
                id="group",
            ),
            param(
                "package_test/group_name",
                {"foo": {"package_test": {"group_name": {"foo": "bar"}}}},
                "foo.package_test.group_name",
                id="group_name",
            ),
            param("package_test/name", {"name": {"foo": "bar"}}, "name", id="name"),
        ],
    )
    def test_package_behavior(
        self,
        type_: Type[ConfigSource],
        path: str,
        config_path: str,
        expected_result: Any,
        expected_package: str,
        recwarn: Any,
    ) -> None:
        """The package header and the packaged config content match expectations."""
        source = type_(provider="foo", path=path)
        loaded = source.load_config(config_path=config_path, is_primary_config=False)
        assert loaded.header["package"] == expected_package
        assert loaded.config == expected_result

    def test_default_package_for_primary_config(
        self, type_: Type[ConfigSource], path: str
    ) -> None:
        """'primary_config' loaded as primary has an empty package header."""
        source = type_(provider="foo", path=path)
        loaded = source.load_config(
            config_path="primary_config", is_primary_config=True
        )
        assert loaded.header["package"] == ""

    def test_primary_config_with_non_global_package_errors(
        self, type_: Type[ConfigSource], path: str
    ) -> None:
        """Loading a primary config with a non-global package raises HydraException."""
        source = type_(provider="foo", path=path)
        expected_message = re.escape(
            "Primary config 'primary_config_with_non_global_package' must be in the _global_ package; "
            "effective package : 'foo'"
        )
        with raises(HydraException, match=expected_message):
            source.load_config(
                config_path="primary_config_with_non_global_package",
                is_primary_config=True,
            )

    def test_load_defaults_list(self, type_: Type[ConfigSource], path: str) -> None:
        """Loading 'config_with_defaults_list' as primary yields its single default."""
        source = type_(provider="foo", path=path)
        loaded = source.load_config(
            config_path="config_with_defaults_list", is_primary_config=True
        )
        expected = [
            DefaultElement(
                config_group="dataset",
                config_name="imagenet",
                parent="config_with_defaults_list",
            )
        ]
        assert loaded.defaults_list == expected
def _create_defaults_list(
    config_path: Optional[str],
    defaults: ListConfig,
) -> List[DefaultElement]:
    """Convert the raw `defaults` list of a config into DefaultElement objects.

    Two item forms are accepted: a single-key mapping
    ``group[@package[:package2]]: name`` (with an optional ``optional`` key),
    and a plain string (a config name, or ``~group`` for a deletion).

    :param config_path: path of the config owning the list; used in messages
        and stored as each element's parent.
    :param defaults: the raw defaults list. An ``optional`` key is popped
        from mapping items, so those items are modified.
    :return: one DefaultElement per item, in order.
    :raises ValueError: `defaults` is not a list, a mapping item has zero or
        several keys, or an item is neither a mapping nor a string.
    """

    def _split_group(
        group_with_package: str,
    ) -> Tuple[str, Optional[str], Optional[str]]:
        # "group", "group@package" or "group@package:package2"; empty
        # package parts are reported as None.
        group, at_sign, package_spec = group_with_package.partition("@")
        if not at_sign:
            return group, None, None
        package, _, renamed = package_spec.partition(":")
        return group, package or None, renamed or None

    if not isinstance(defaults, MutableSequence):
        raise ValueError(
            dedent(
                f"""\
                Invalid defaults list in '{config_path}', defaults must be a list.
                Example of a valid defaults:
                defaults:
                  - dataset: imagenet
                  - model: alexnet
                    optional: true
                  - optimizer: nesterov
                """
            )
        )

    elements: List[DefaultElement] = []
    for item in defaults:
        if isinstance(item, DictConfig):
            optional = item.pop("optional") if "optional" in item else False

            keys = list(item.keys())
            if len(keys) > 1:
                raise ValueError(f"Too many keys in default item {item}")
            if not keys:
                raise ValueError(f"Missing group name in {item}")

            key = keys[0]
            config_group, package, package2 = _split_group(key)
            node = item._get_node(key)
            assert node is not None
            config_name = node._value()

            if config_name is None:
                # Legacy "group: null" deletion form.
                warnings.warn(
                    category=UserWarning,
                    message=dedent(
                        f"""
                        Deprecated form of deletion used in the defaults list of '{config_path}'.
                        'group: null' is deprecated, use '~group' instead.
                        You can also delete group with a specific value with '~group: value'.
                        Support for the 'group: null' form will be removed in Hydra 1.2.
                        """
                    ),
                )
                is_delete = True
            elif config_group.startswith("~"):
                is_delete = True
                config_group = config_group[1:]
            else:
                is_delete = False

            element = DefaultElement(
                config_group=config_group,
                config_name=config_name,
                package=package,
                rename_package_to=package2,
                optional=optional,
                is_delete=is_delete,
                parent=config_path,
            )
        elif isinstance(item, str):
            if item.startswith("~"):
                element = DefaultElement(
                    config_group=item[1:],
                    config_name="_delete_",
                    is_delete=True,
                    parent=config_path,
                )
            else:
                element = DefaultElement(
                    config_group=None,
                    config_name=item,
                    parent=config_path,
                )
        else:
            raise ValueError(
                f"Unsupported type in defaults : {type(item).__name__}"
            )
        elements.append(element)

    return elements