예제 #1
0
def _update_known_state(
    d: DefaultElement,
    group_to_choice: DictConfig,
    delete_groups: Dict[DeleteKey, int],
) -> None:
    fqgn = d.fully_qualified_group_name()
    if fqgn is None:
        return

    is_overridden = fqgn in group_to_choice

    if d.config_group is not None:
        if (
            fqgn not in group_to_choice
            and not d.is_delete
            and d.config_name not in ("_self_", "_keep_")
            and not is_matching_deletion(delete_groups=delete_groups, d=d)
        ):
            group_to_choice[fqgn] = d.config_name

    if d.is_delete:
        if is_overridden:
            d.is_delete = False
            d.config_name = group_to_choice[fqgn]
        else:
            delete_key = DeleteKey(
                fqgn,
                d.config_name if d.config_name != "_delete_" else None,
                must_delete=d.from_override,
            )
            if delete_key not in delete_groups:
                delete_groups[delete_key] = 0
예제 #2
0
def _compute_element_defaults_list_impl(
    element: DefaultElement,
    group_to_choice: DictConfig,
    delete_groups: Dict[DeleteKey, int],
    skip_missing: bool,
    repo: IConfigRepository,
) -> List[DefaultElement]:
    """Expand ``element`` into the list of defaults it contributes.

    * Returns ``[]`` when ``delete_if_matching`` reports a pending deletion
      for ``element``.
    * A mandatory element (config name ``"???"``) is returned as-is, marked
      with skip reason ``"missing_skipped"``, when ``skip_missing`` is set.
      Otherwise a ``ConfigCompositionException`` is raised (listing the group's
      options when the element belongs to a group).
    * If the config cannot be loaded, an optional element is returned marked
      ``"missing_optional_config"``; otherwise ``missing_config_error`` reports it.
    * Otherwise the loaded config's defaults list is validated against
      ``element`` and recursively expanded.
    """
    if delete_if_matching(delete_groups, element):
        return []

    if element.config_name == "???":
        if skip_missing:
            element.set_skip_load("missing_skipped")
            return [element]

        group = element.config_group
        msg = f"You must specify '{group}', e.g, {group}=<OPTION>"
        if group is not None:
            # Show the available choices to make the error actionable.
            choices = repo.get_group_options(
                group, results_filter=ObjectType.CONFIG
            )
            listing = "\n".join("\t" + choice for choice in choices)
            msg += f"\nAvailable options:\n{listing}"
        raise ConfigCompositionException(msg)

    loaded = repo.load_config(
        config_path=element.config_path(),
        is_primary_config=element.primary,
    )
    if loaded is None:
        if element.optional:
            element.set_skip_load("missing_optional_config")
            return [element]
        # missing_config_error (defined in this module) always raises.
        missing_config_error(repo=repo, element=element)

    # Two independent copies: one kept as loaded, one to be modified.
    defaults = DefaultsList(
        original=copy.deepcopy(loaded.defaults_list),
        effective=copy.deepcopy(loaded.defaults_list),
    )
    _validate_self(element, defaults)

    return _expand_defaults_list_impl(
        self_element=element,
        defaults_list=defaults,
        group_to_choice=group_to_choice,
        delete_groups=delete_groups,
        skip_missing=skip_missing,
        repo=repo,
    )
예제 #3
0
def missing_config_error(repo: IConfigRepository, element: DefaultElement) -> None:
    """Raise ``MissingConfigException`` for a config that could not be loaded.

    If the element belongs to a config group, the message lists that group's
    configs, and the same list is attached as ``options``.  Otherwise the
    message names the config path and ``options`` is None.  In both cases the
    repository's sources are appended as the config search path.  This
    function never returns normally.
    """
    group = element.config_group
    if group is None:
        options = None
        msg = dedent(
            f"""\
            Could not load {element.config_path()}.
            """
        )
    else:
        options = repo.get_group_options(group, ObjectType.CONFIG)
        listing = "\n".join("\t" + option for option in options)
        msg = (
            f"Could not find '{element.config_name}' in the config group '{group}'"
            f"\nAvailable options:\n{listing}\n"
        )

    search_path = "\n".join(f"\t{source!r}" for source in repo.get_sources())
    msg += "\nConfig search path:" + f"\n{search_path}"

    raise MissingConfigException(
        missing_cfg_file=element.config_path(),
        message=msg,
        options=options,
    )
예제 #4
0
def _find_match_before(
    defaults: List[DefaultElement], like: DefaultElement
) -> Optional[DefaultElement]:
    fqgn = like.fully_qualified_group_name()
    for d2 in defaults:
        if d2 == like:
            break
        if d2.fully_qualified_group_name() == fqgn:
            return d2
    return None
예제 #5
0
 def test_load_defaults_list(self, type_: Type[ConfigSource],
                             path: str) -> None:
     """Loading a primary config parses its defaults list into elements."""
     source = type_(provider="foo", path=path)
     result = source.load_config(
         config_path="config_with_defaults_list",
         is_primary_config=True,
     )
     expected = [
         DefaultElement(
             config_group="dataset",
             config_name="imagenet",
             parent="config_with_defaults_list",
         )
     ]
     assert result.defaults_list == expected
예제 #6
0
def convert_overrides_to_defaults(
    parsed_overrides: List[Override],
) -> List[DefaultElement]:
    """Turn parsed config-group overrides into defaults list elements.

    Each override becomes a ``DefaultElement`` with ``from_override=True`` and
    ``parent="overrides"``.  The element's ``is_delete``/``is_add`` flags are set
    when the override is a delete (``~``) or add (``+``) override.  A delete
    override without a value uses the placeholder config name ``"_delete_"``.

    :raises ConfigCompositionException: if an add override also renames a
        package, or if the override value is not a string.
    """

    def to_default(override: Override) -> DefaultElement:
        # Package renames cannot be combined with the add (+) syntax.
        if override.is_add() and override.is_package_rename():
            raise ConfigCompositionException(
                "Add syntax does not support package rename, remove + prefix"
            )

        config_name = override.value()
        if override.is_delete() and config_name is None:
            # A bare '~group' has no value; use the deletion placeholder.
            config_name = "_delete_"

        if not isinstance(config_name, str):
            raise ConfigCompositionException(
                "Defaults list supported delete syntax is in the form"
                " ~group and ~group=value, where value is a group name (string)"
            )

        if override.is_package_rename():
            element = DefaultElement(
                config_group=override.key_or_group,
                config_name=config_name,
                package=override.pkg1,
                rename_package_to=override.pkg2,
                from_override=True,
                parent="overrides",
            )
        else:
            element = DefaultElement(
                config_group=override.key_or_group,
                config_name=config_name,
                package=override.get_subject_package(),
                from_override=True,
                parent="overrides",
            )

        if override.is_delete():
            element.is_delete = True
        if override.is_add():
            element.is_add = True
        return element

    return [to_default(override) for override in parsed_overrides]
예제 #7
0
def is_matching_deletion(
    delete_groups: Dict[DeleteKey, int],
    d: DefaultElement,
    mark_item_as_deleted: bool = False,
) -> bool:
    """Return True if any pending deletion in ``delete_groups`` targets ``d``.

    A key targets ``d`` when both conditions hold:

    * its ``fqgn`` equals ``d.fully_qualified_group_name()``;
    * its ``config_name`` is None (any config of the group) or equals
      ``d.config_name``.

    Every key is examined, not only the first match.  When
    ``mark_item_as_deleted`` is True, each matching key's counter is
    incremented, ``d.is_deleted`` is set, and ``d`` is given the skip reason
    ``"deleted_from_list"``.
    """
    found = False
    for key in delete_groups:
        if key.fqgn != d.fully_qualified_group_name():
            continue
        # A key with a specific config name only matches that config.
        if key.config_name is not None and key.config_name != d.config_name:
            continue
        found = True
        if mark_item_as_deleted:
            delete_groups[key] += 1
            d.is_deleted = True
            d.set_skip_load("deleted_from_list")
    return found
예제 #8
0
    def _load_single_config(
        self,
        default: DefaultElement,
        repo: IConfigRepository,
    ) -> Tuple[ConfigResult, LoadTrace]:
        """Load the config referenced by ``default`` and merge it with its schema.

        The config is loaded from ``repo`` using ``default``'s path, primary
        flag and package.  If the loaded config is not itself a schema source,
        a schema with the same (normalized) name is looked up in the repo's
        schema source and, when found, merged under the config.  A top-level
        ``hydra`` node is kept out of that merge (and re-attached afterwards)
        unless ``default`` belongs to a ``hydra/`` config group.

        Side effect: sets ``default.search_path`` to the loaded config's path.

        :return: the (possibly schema-merged) config result and a LoadTrace
            describing where it came from.
        :raises ValueError: if the loaded config is not a ``DictConfig``.
        :raises ConfigCompositionException: if merging with the schema fails.
        """
        config_path = default.config_path()
        package_override = default.package
        ret = repo.load_config(
            config_path=config_path,
            is_primary_config=default.primary,
            package_override=package_override,
        )
        # NOTE(review): callers are expected to only request configs that exist.
        assert ret is not None

        if not isinstance(ret.config, DictConfig):
            # Report the type of the offending config, not of the result wrapper.
            raise ValueError(
                f"Config {config_path} must be a Dictionary, got {type(ret.config).__name__}"
            )

        default.search_path = ret.path

        schema_provider = None
        if not ret.is_schema_source:
            schema = None
            try:
                schema_source = repo.get_schema_source()
                schema = schema_source.load_config(
                    ConfigSource._normalize_file_name(filename=config_path),
                    is_primary_config=default.primary,
                    package_override=package_override,
                )
            except ConfigLoadError:
                # schema not found, ignore
                pass

            if schema is not None:
                try:
                    # if config has a hydra node, remove it during validation and add it back.
                    # This allows overriding Hydra's configuration without declaring this node
                    # in every program
                    hydra = None
                    hydra_config_group = (
                        default.config_group is not None
                        and default.config_group.startswith("hydra/")
                    )
                    if "hydra" in ret.config and not hydra_config_group:
                        hydra = ret.config.pop("hydra")
                    schema_provider = schema.provider
                    merged = OmegaConf.merge(schema.config, ret.config)
                    assert isinstance(merged, DictConfig)

                    if hydra is not None:
                        with open_dict(merged):
                            merged.hydra = hydra
                    ret.config = merged
                except OmegaConfBaseException as e:
                    raise ConfigCompositionException(
                        f"Error merging '{config_path}' with schema"
                    ) from e

        trace = LoadTrace(
            config_group=default.config_group,
            config_name=default.config_name,
            package=default.get_subject_package(),
            search_path=ret.path,
            parent=default.parent,
            provider=ret.provider,
            schema_provider=schema_provider,
        )
        return ret, trace
예제 #9
0
    def _load_configuration_impl(
        self,
        config_name: Optional[str],
        overrides: List[str],
        run_mode: RunMode,
        strict: Optional[bool] = None,
        from_shell: bool = True,
    ) -> DictConfig:
        """Compose the full configuration for ``config_name`` and ``overrides``.

        Steps, in order:
        1. Ensure the main config source is available and, if a primary config
           name is given, that it exists (otherwise report it with the search path).
        2. Parse the overrides and split them into config-group overrides
           (which extend the defaults list) and plain config overrides.
        3. Expand the defaults list (``hydra_config``, the primary config, then
           the group overrides) and compose the config from it.
        4. Apply the struct flag, then apply the plain config overrides.
        5. Record the overrides and runtime/job metadata under ``cfg.hydra``.

        :param config_name: primary config name, or None for no primary config.
        :param overrides: override strings, e.g. as given on the command line.
        :param run_mode: run mode; ``skip_missing`` is enabled for the defaults
            list expansion only when this is ``RunMode.MULTIRUN``.
        :param strict: struct flag for the result; ``self.default_strict`` if None.
        :param from_shell: forwarded to ``ConfigLoaderImpl.parse_overrides``.
        :return: the composed config.
        """
        self.ensure_main_config_source_available()
        # Cache repository lookups for the duration of this composition.
        caching_repo = CachingConfigRepository(self.repository)
        if config_name is not None and not caching_repo.config_exists(
                config_name):
            self._missing_config_error(
                config_name=config_name,
                msg=
                f"Cannot find primary config : {config_name}, check that it's in your config search path",
                with_search_path=True,
            )

        if strict is None:
            strict = self.default_strict

        # Overrides are parsed twice: `parsed_overrides` is used below to record the
        # override strings and compute override_dirname, while `config_overrides`
        # drives composition and value application.
        parser = OverridesParser.create()
        parsed_overrides = parser.parse_overrides(overrides=overrides)
        config_overrides = ConfigLoaderImpl.parse_overrides(
            overrides=overrides, run_mode=run_mode, from_shell=from_shell)

        split_res = self.split_by_override_type(config_overrides)
        config_group_overrides = split_res.config_group_overrides
        config_overrides = split_res.config_overrides

        # The Hydra config always comes first, then the primary config (if any).
        input_defaults = [DefaultElement(config_name="hydra_config")]

        if config_name is not None:
            input_defaults.append(
                DefaultElement(config_name=config_name, primary=True))

        # Config group overrides extend the defaults list.
        for default in convert_overrides_to_defaults(config_group_overrides):
            input_defaults.append(default)

        skip_missing = run_mode == RunMode.MULTIRUN
        defaults = expand_defaults_list(
            defaults=input_defaults,
            skip_missing=skip_missing,
            repo=caching_repo,
        )

        cfg, composition_trace = self._compose_config_from_defaults_list(
            defaults=defaults, repo=caching_repo)

        OmegaConf.set_struct(cfg, strict)
        # cfg.hydra is written to below, so it must not be read-only.
        OmegaConf.set_readonly(cfg.hydra, False)

        # Apply command line overrides after enabling strict flag
        ConfigLoaderImpl._apply_overrides_to_config(config_overrides, cfg)
        # Record each override under hydra.overrides.{hydra,task}; only the
        # task (application) overrides contribute to override_dirname.
        app_overrides = []
        for override in parsed_overrides:
            if override.is_hydra_override():
                cfg.hydra.overrides.hydra.append(override.input_line)
            else:
                cfg.hydra.overrides.task.append(override.input_line)
                app_overrides.append(override)

        # open_dict allows adding keys to cfg.hydra even when struct mode is on.
        with open_dict(cfg.hydra):
            from hydra import __version__

            cfg.hydra.runtime.version = __version__
            cfg.hydra.runtime.cwd = os.getcwd()

            cfg.hydra.composition_trace = composition_trace
            if "name" not in cfg.hydra.job:
                cfg.hydra.job.name = JobRuntime().get("name")
            cfg.hydra.job.override_dirname = get_overrides_dirname(
                overrides=app_overrides,
                kv_sep=cfg.hydra.job.config.override_dirname.kv_sep,
                item_sep=cfg.hydra.job.config.override_dirname.item_sep,
                exclude_keys=cfg.hydra.job.config.override_dirname.
                exclude_keys,
            )
            cfg.hydra.job.config_name = config_name

            # os.environ[key] raises KeyError if a listed variable is not set.
            for key in cfg.hydra.job.env_copy:
                cfg.hydra.job.env_set[key] = os.environ[key]

        return cfg
예제 #10
0
class ConfigSourceTestSuite:
    """Shared behavioral tests that every ConfigSource implementation must pass.

    Each test receives ``type_`` (the ConfigSource subclass under test) and
    ``path`` (where that source finds the shared test configs).  These are
    presumably supplied by the concrete plugin test module via pytest
    parametrization/fixtures — confirm in the subclasses.
    """

    def skip_overlap_config_path_name(self) -> bool:
        """
        Some config source plugins do not support config name and path overlap.
        For example the following may not be allowed:
        (dataset exists both as a config object and a config group)
        /dataset.yaml
        /dataset/cifar.yaml

        Overriding and returning True here will disable testing of this scenario
        by assuming the dataset config (dataset.yaml) is not present.
        """
        return False

    def test_not_available(self, type_: Type[ConfigSource], path: str) -> None:
        scheme = type_(provider="foo", path=path).scheme()
        # Test is meaningless for StructuredConfigSource
        if scheme == "structured":
            return
        src = type_(provider="foo", path=f"{scheme}://___NOT_FOUND___")
        assert not src.available()

    @mark.parametrize(  # type: ignore
        "config_path, expected",
        [
            pytest.param("", True, id="empty"),
            pytest.param("dataset", True, id="dataset"),
            pytest.param("optimizer", True, id="optimizer"),
            pytest.param(
                "configs_with_defaults_list",
                True,
                id="configs_with_defaults_list",
            ),
            pytest.param("dataset/imagenet", False, id="dataset/imagenet"),
            pytest.param("level1", True, id="level1"),
            pytest.param("level1/level2", True, id="level1/level2"),
            pytest.param(
                "level1/level2/nested1", False, id="level1/level2/nested1"),
            pytest.param("not_found", False, id="not_found"),
        ],
    )
    def test_is_group(self, type_: Type[ConfigSource], path: str,
                      config_path: str, expected: bool) -> None:
        src = type_(provider="foo", path=path)
        ret = src.is_group(config_path=config_path)
        assert ret == expected

    @mark.parametrize(  # type: ignore
        "config_path, expected",
        [
            ("", False),
            ("optimizer", False),
            ("dataset/imagenet", True),
            ("dataset/imagenet.yaml", True),
            ("dataset/imagenet.foobar", False),
            ("configs_with_defaults_list/global_package", True),
            ("configs_with_defaults_list/group_package", True),
            ("level1", False),
            ("level1/level2", False),
            ("level1/level2/nested1", True),
            ("not_found", False),
        ],
    )
    def test_is_config(self, type_: Type[ConfigSource], path: str,
                       config_path: str, expected: bool) -> None:
        src = type_(provider="foo", path=path)
        ret = src.is_config(config_path=config_path)
        assert ret == expected

    @mark.parametrize(  # type: ignore
        "config_path, expected",
        [
            ("dataset", True),
        ],
    )
    def test_is_config_with_overlap_name(self, type_: Type[ConfigSource],
                                         path: str, config_path: str,
                                         expected: bool) -> None:
        if self.skip_overlap_config_path_name():
            pytest.skip(
                f"ConfigSourcePlugin {type_.__name__} does not support config objects and config groups "
                f"with overlapping names.")
        src = type_(provider="foo", path=path)
        ret = src.is_config(config_path=config_path)
        assert ret == expected

    @mark.parametrize(  # type: ignore
        "config_path,results_filter,expected",
        [
            # groups
            ("", ObjectType.GROUP, ["dataset", "level1", "optimizer"]),
            ("dataset", ObjectType.GROUP, []),
            ("optimizer", ObjectType.GROUP, []),
            ("level1", ObjectType.GROUP, ["level2"]),
            ("level1/level2", ObjectType.GROUP, []),
            # Configs
            ("", ObjectType.CONFIG, ["config_without_group"]),
            ("dataset", ObjectType.CONFIG, ["cifar10", "imagenet"]),
            ("optimizer", ObjectType.CONFIG, ["adam", "nesterov"]),
            ("level1", ObjectType.CONFIG, []),
            ("level1/level2", ObjectType.CONFIG, ["nested1", "nested2"]),
            # both
            ("", None,
             ["config_without_group", "dataset", "level1", "optimizer"]),
            ("dataset", None, ["cifar10", "imagenet"]),
            ("optimizer", None, ["adam", "nesterov"]),
            ("level1", None, ["level2"]),
            ("level1/level2", None, ["nested1", "nested2"]),
        ],
    )
    def test_list(
        self,
        type_: Type[ConfigSource],
        path: str,
        config_path: str,
        results_filter: Optional[ObjectType],
        expected: List[str],
    ) -> None:
        src = type_(provider="foo", path=path)
        ret = src.list(config_path=config_path, results_filter=results_filter)
        # Sources may list extra entries; only the expected ones are required.
        for x in expected:
            assert x in ret
        assert ret == sorted(ret)

    @mark.parametrize(  # type: ignore
        "config_path,results_filter,expected",
        [
            # Configs
            ("", ObjectType.CONFIG, ["dataset"]),
        ],
    )
    def test_list_with_overlap_name(
        self,
        type_: Type[ConfigSource],
        path: str,
        config_path: str,
        results_filter: Optional[ObjectType],
        expected: List[str],
    ) -> None:
        if self.skip_overlap_config_path_name():
            pytest.skip(
                f"ConfigSourcePlugin {type_.__name__} does not support config objects and config groups "
                f"with overlapping names.")
        src = type_(provider="foo", path=path)
        ret = src.list(config_path=config_path, results_filter=results_filter)
        for x in expected:
            assert x in ret
        assert ret == sorted(ret)

    @mark.parametrize(  # type: ignore
        "config_path,expected_config,expected_defaults_list",
        [
            param(
                "config_without_group",
                {"group": False},
                [],
                id="config_without_group",
            ),
            param(
                "config_with_unicode",
                {"group": "数据库"},
                [],
                id="config_with_unicode",
            ),
            param(
                "dataset/imagenet",
                {
                    "dataset": {
                        "name": "imagenet",
                        "path": "/datasets/imagenet"
                    }
                },
                [],
                id="dataset/imagenet",
            ),
            param(
                "dataset/cifar10",
                {"dataset": {
                    "name": "cifar10",
                    "path": "/datasets/cifar10"
                }},
                [],
                id="dataset/cifar10",
            ),
            param(
                "dataset/not_found",
                raises(ConfigLoadError),
                [],
                id="dataset/not_found",
            ),
            param(
                "level1/level2/nested1",
                {"l1_l2_n1": True},
                [],
                id="level1/level2/nested1",
            ),
            param(
                "level1/level2/nested2",
                {"l1_l2_n2": True},
                [],
                id="level1/level2/nested2",
            ),
            param(
                "config_with_defaults_list",
                {"key": "value"},
                [
                    DefaultElement(
                        config_group="dataset",
                        config_name="imagenet",
                        parent="config_with_defaults_list",
                    )
                ],
                id="config_with_defaults_list",
            ),
            param(
                "configs_with_defaults_list/global_package",
                {"configs_with_defaults_list": {
                    "x": 10
                }},
                [
                    DefaultElement(
                        config_group="foo",
                        config_name="bar",
                        parent="configs_with_defaults_list/global_package",
                    )
                ],
                id="configs_with_defaults_list/global_package",
            ),
            param(
                "configs_with_defaults_list/group_package",
                {"configs_with_defaults_list": {
                    "x": 10
                }},
                [
                    DefaultElement(
                        config_group="foo",
                        config_name="bar",
                        parent="configs_with_defaults_list/group_package",
                    )
                ],
                id="configs_with_defaults_list/group_package",
            ),
        ],
    )
    def test_source_load_config(
        self,
        type_: Type[ConfigSource],
        path: str,
        config_path: str,
        expected_defaults_list: List[DefaultElement],
        expected_config: Any,
        recwarn: Any,
    ) -> None:
        assert issubclass(type_, ConfigSource)
        src = type_(provider="foo", path=path)
        # A dict is the expected config; anything else is an expected-error
        # context manager (e.g. raises(ConfigLoadError)).
        if isinstance(expected_config, dict):
            ret = src.load_config(config_path=config_path,
                                  is_primary_config=False)
            assert ret.config == expected_config
            assert ret.defaults_list == expected_defaults_list
        else:
            with expected_config:
                src.load_config(config_path=config_path,
                                is_primary_config=False)

    @mark.parametrize(  # type: ignore
        "config_path, expected_result, expected_package",
        [
            param("package_test/none", {"foo": "bar"}, "", id="none"),
            param(
                "package_test/explicit",
                {"a": {
                    "b": {
                        "foo": "bar"
                    }
                }},
                "a.b",
                id="explicit",
            ),
            param("package_test/global", {"foo": "bar"}, "", id="global"),
            param(
                "package_test/group",
                {"package_test": {
                    "foo": "bar"
                }},
                "package_test",
                id="group",
            ),
            param(
                "package_test/group_name",
                {"foo": {
                    "package_test": {
                        "group_name": {
                            "foo": "bar"
                        }
                    }
                }},
                "foo.package_test.group_name",
                id="group_name",
            ),
            param(
                "package_test/name",
                {"name": {
                    "foo": "bar"
                }},
                "name",
                id="name",
            ),
        ],
    )
    def test_package_behavior(
        self,
        type_: Type[ConfigSource],
        path: str,
        config_path: str,
        expected_result: Any,
        expected_package: str,
        recwarn: Any,
    ) -> None:
        src = type_(provider="foo", path=path)
        cfg = src.load_config(config_path=config_path, is_primary_config=False)
        assert cfg.header["package"] == expected_package
        assert cfg.config == expected_result

    def test_default_package_for_primary_config(self,
                                                type_: Type[ConfigSource],
                                                path: str) -> None:
        src = type_(provider="foo", path=path)
        cfg = src.load_config(config_path="primary_config",
                              is_primary_config=True)
        assert cfg.header["package"] == ""

    def test_primary_config_with_non_global_package_errors(
            self, type_: Type[ConfigSource], path: str) -> None:
        src = type_(provider="foo", path=path)
        with raises(
                HydraException,
                match=re.escape(
                    "Primary config 'primary_config_with_non_global_package' must be in the _global_ package; "
                    "effective package : 'foo'"),
        ):
            src.load_config(
                config_path="primary_config_with_non_global_package",
                is_primary_config=True,
            )

    def test_load_defaults_list(self, type_: Type[ConfigSource],
                                path: str) -> None:
        src = type_(provider="foo", path=path)
        ret = src.load_config(config_path="config_with_defaults_list",
                              is_primary_config=True)
        assert ret.defaults_list == [
            DefaultElement(
                config_group="dataset",
                config_name="imagenet",
                parent="config_with_defaults_list",
            )
        ]
예제 #11
0
    def _create_defaults_list(
        config_path: Optional[str],
        defaults: ListConfig,
    ) -> List[DefaultElement]:
        """Convert a raw ``defaults`` list of config ``config_path`` into elements.

        Supported item forms:

        * ``{group[@pkg[:pkg2]]: name}`` optionally with ``optional: true``.
          Here ``pkg`` is the package and ``pkg2`` the package to rename to.
        * ``{~group[@pkg[:pkg2]]: name_or_null}``: delete the group, either a
          specific config name or, with a null value, any config.
        * ``{group: null}``: deprecated deletion form (emits a UserWarning).
        * ``"~group"``: delete the group (config name ``"_delete_"``).
        * ``"name"``: a config without a group.

        Note: the ``optional`` key is popped from dict items, so the input items
        are modified.

        :raises ValueError: if ``defaults`` is not a list, a dict item has zero
            or more than one group key, or an item has an unsupported type.
        """

        def _split_group(
            group_with_package: str,
        ) -> Tuple[str, Optional[str], Optional[str]]:
            # Split "group[@package[:rename_to]]"; empty packages become None.
            idx = group_with_package.find("@")
            if idx == -1:
                # group
                group = group_with_package
                package = None
            else:
                # group@package
                group = group_with_package[0:idx]
                package = group_with_package[idx + 1 :]

            package2 = None
            if package is not None:
                # if we have a package, break it down if it's a rename
                idx = package.find(":")
                if idx != -1:
                    package2 = package[idx + 1 :]
                    package = package[0:idx]

            if package == "":
                package = None

            if package2 == "":
                package2 = None

            return group, package, package2

        if not isinstance(defaults, MutableSequence):
            raise ValueError(
                dedent(
                    f"""\
                    Invalid defaults list in '{config_path}', defaults must be a list.
                    Example of a valid defaults:
                    defaults:
                      - dataset: imagenet
                      - model: alexnet
                        optional: true
                      - optimizer: nesterov
                    """
                )
            )

        res: List[DefaultElement] = []
        for item in defaults:
            if isinstance(item, DictConfig):
                optional = False
                if "optional" in item:
                    optional = item.pop("optional")
                keys = list(item.keys())
                if len(keys) > 1:
                    raise ValueError(f"Too many keys in default item {item}")
                if len(keys) == 0:
                    raise ValueError(f"Missing group name in {item}")
                key = keys[0]
                config_group, package, package2 = _split_group(key)
                node = item._get_node(key)
                assert node is not None
                config_name = node._value()

                is_delete = False
                if config_group.startswith("~"):
                    # '~group: value' deletes a specific config; '~group: null'
                    # deletes any config of the group.  The '~' must be stripped
                    # in both cases.
                    is_delete = True
                    config_group = config_group[1:]
                elif config_name is None:
                    warnings.warn(
                        category=UserWarning,
                        message=dedent(
                            f"""
                    Deprecated form of deletion used in the defaults list of '{config_path}'.
                    'group: null' is deprecated, use '~group' instead.
                    You can also delete group with a specific value with '~group: value'.
                    Support for the 'group: null' form will be removed in Hydra 1.2.
                    """
                        ),
                    )
                    is_delete = True

                default = DefaultElement(
                    config_group=config_group,
                    config_name=config_name,
                    package=package,
                    rename_package_to=package2,
                    optional=optional,
                    is_delete=is_delete,
                    parent=config_path,
                )
            elif isinstance(item, str):
                if item.startswith("~"):
                    item = item[1:]
                    default = DefaultElement(
                        config_group=item,
                        config_name="_delete_",
                        is_delete=True,
                        parent=config_path,
                    )
                else:
                    default = DefaultElement(
                        config_group=None,
                        config_name=item,
                        parent=config_path,
                    )
            else:
                raise ValueError(
                    f"Unsupported type in defaults : {type(item).__name__}"
                )
            res.append(default)
        return res