Esempio n. 1
0
def _get_kwargs(
    config: Union[DictConfig, ListConfig],
    root: bool = True,
    **kwargs: Any,
) -> Any:
    """Build the keyword arguments used to call ``config``'s target.

    For a ListConfig, returns a plain Python list in which config elements are
    processed recursively (with ``root=False``) and all other elements are kept
    as-is. ``kwargs`` is not applied on this path.

    For a DictConfig, ``kwargs`` is first merged into ``config`` IN PLACE (via
    ``merge_with``). A new DictConfig allowing arbitrary objects is then built:
    when ``_is_recursive`` says so, nested nodes that carry a target are
    instantiated with ``hydra.utils.instantiate``, and nested dicts/lists are
    rebuilt element by element; otherwise every value is copied unchanged.

    :param config: the node to process; must be an OmegaConf config.
    :param root: True for the top-level call. Non-root results copy the source
                 node's ``_metadata.object_type``.
    :param kwargs: overrides merged into ``config`` before processing.
    :return: a list for ListConfig input, otherwise a DictConfig.
    """
    # Function-level import; presumably avoids a circular import with
    # hydra.utils — TODO confirm.
    from hydra.utils import instantiate

    assert OmegaConf.is_config(config)

    if OmegaConf.is_list(config):
        assert isinstance(config, ListConfig)
        return [
            _get_kwargs(x, root=False) if OmegaConf.is_config(x) else x for x in config
        ]

    assert OmegaConf.is_dict(config), "Input config is not an OmegaConf DictConfig"

    recursive = _is_recursive(config, kwargs)
    # Merge the caller's overrides into the config; this mutates `config`.
    overrides = OmegaConf.create(kwargs, flags={"allow_objects": True})
    config.merge_with(overrides)

    # Result container. It shares `config`'s parent (presumably so that
    # interpolations resolve against the same root — TODO confirm) and is made
    # writable and non-struct while it is being filled.
    final_kwargs = OmegaConf.create(flags={"allow_objects": True})
    final_kwargs._set_parent(config._get_parent())
    final_kwargs._set_flag("readonly", False)
    final_kwargs._set_flag("struct", False)
    if recursive:
        # resolve=False keeps interpolations unresolved while copying.
        for k, v in config.items_ex(resolve=False):
            if OmegaConf.is_none(v):
                final_kwargs[k] = v
            elif _is_target(v):
                final_kwargs[k] = instantiate(v)
            elif OmegaConf.is_dict(v):
                # Rebuild the nested dict one level deep, instantiating target
                # children and recursing into other config children.
                d = OmegaConf.create({}, flags={"allow_objects": True})
                for key, value in v.items_ex(resolve=False):
                    if _is_target(value):
                        d[key] = instantiate(value)
                    elif OmegaConf.is_config(value):
                        d[key] = _get_kwargs(value, root=False)
                    else:
                        d[key] = value
                # Keep the source dict's structured type on the rebuilt copy.
                d._metadata.object_type = v._metadata.object_type
                final_kwargs[k] = d
            elif OmegaConf.is_list(v):
                lst = OmegaConf.create([], flags={"allow_objects": True})
                for x in v:
                    if _is_target(x):
                        lst.append(instantiate(x))
                    elif OmegaConf.is_config(x):
                        lst.append(_get_kwargs(x, root=False))
                        # Carry the element's object type onto the appended node.
                        lst[-1]._metadata.object_type = x._metadata.object_type
                    else:
                        lst.append(x)
                final_kwargs[k] = lst
            else:
                final_kwargs[k] = v
    else:
        # Non-recursive: copy values without instantiating nested targets.
        for k, v in config.items_ex(resolve=False):
            final_kwargs[k] = v

    # Unset (None) the flags so they are no longer set explicitly on this node.
    final_kwargs._set_flag("readonly", None)
    final_kwargs._set_flag("struct", None)
    final_kwargs._set_flag("allow_objects", None)
    if not root:
        # This is tricky: since the root kwargs are exploded anyway, we can treat them as an untyped dict.
        # The motivation is that the object type is used as an indicator to treat the object differently during
        # conversion to a primitive container in some cases.
        final_kwargs._metadata.object_type = config._metadata.object_type
    return final_kwargs
Esempio n. 2
0
def test_is_config(cfg: Any, expected: bool) -> None:
    """Check that ``OmegaConf.is_config`` classifies ``cfg`` as ``expected``."""
    actual = OmegaConf.is_config(cfg)
    assert actual == expected
Esempio n. 3
0
def instantiate_node(
    node: Any,
    *args: Any,
    convert: Union[str, ConvertMode] = ConvertMode.NONE,
    recursive: bool = True,
    partial: bool = False,
) -> Any:
    """Instantiate (or rebuild) a single config node.

    :param node: the node to process. ``None``, or a config whose value is
                 None, yields None; a non-config value is returned unchanged.
    :param args: positional arguments forwarded to the target call when
                 ``node`` is a dict carrying a ``_target_``.
    :param convert: container conversion mode; a dict node may override it
                    with a ``_convert_`` key.
    :param recursive: whether a target node's values are instantiated before
                      the call; a dict node may override it with ``_recursive_``.
    :param partial: forwarded to ``_call_target``; a dict node may override it
                    with ``_partial_``.
    :return: the target call's result for target nodes; otherwise a plain
             list/dict or a ListConfig/DictConfig, depending on ``convert``.
    :raises TypeError: if the effective ``_recursive_`` or ``_partial_`` flag is
                       not a bool.
    """
    # Return None if config is None
    if node is None or (OmegaConf.is_config(node) and node._is_none()):
        return None

    if not OmegaConf.is_config(node):
        return node

    # Flags set on this node override the ones inherited from the parent.
    if OmegaConf.is_dict(node):
        # using getitem instead of get(key, default) because OmegaConf will raise an exception
        # if the key type is incompatible on get.
        convert = node[_Keys.CONVERT] if _Keys.CONVERT in node else convert
        recursive = node[_Keys.RECURSIVE] if _Keys.RECURSIVE in node else recursive
        partial = node[_Keys.PARTIAL] if _Keys.PARTIAL in node else partial

    if not isinstance(recursive, bool):
        raise TypeError(f"_recursive_ flag must be a bool, got {type(recursive)}")

    if not isinstance(partial, bool):
        raise TypeError(f"_partial_ flag must be a bool, got {type(partial)}")

    if OmegaConf.is_list(node):
        # List elements are always visited (with interpolations resolved);
        # `convert` and `recursive` are forwarded, `partial` is not.
        items = [
            instantiate_node(item, convert=convert, recursive=recursive)
            for item in node._iter_ex(resolve=True)
        ]

        if convert in (ConvertMode.ALL, ConvertMode.PARTIAL):
            # If ALL or PARTIAL, use plain list as container
            return items
        else:
            # Otherwise, use ListConfig as container
            lst = OmegaConf.create(items, flags={"allow_objects": True})
            lst._set_parent(node)
            return lst

    elif OmegaConf.is_dict(node):
        if _is_target(node):
            # Reserved instantiation keys are consumed here and never passed
            # to the target as keyword arguments.
            exclude_keys = {"_target_", "_convert_", "_recursive_", "_partial_"}
            _target_ = _resolve_target(node.get(_Keys.TARGET))
            kwargs = {}
            for key, value in node.items():
                if key in exclude_keys:
                    continue
                if recursive:
                    value = instantiate_node(
                        value, convert=convert, recursive=recursive
                    )
                kwargs[key] = _convert_node(value, convert)

            return _call_target(_target_, partial, *args, **kwargs)
        else:
            # If ALL or PARTIAL non structured, instantiate in dict and resolve interpolations eagerly.
            if convert == ConvertMode.ALL or (
                convert == ConvertMode.PARTIAL and node._metadata.object_type is None
            ):
                dict_items = {}
                for key, value in node.items():
                    # Values inherit the convert/recursive flags from the containing dict.
                    dict_items[key] = instantiate_node(
                        value, convert=convert, recursive=recursive
                    )
                return dict_items
            else:
                # Otherwise use DictConfig and resolve interpolations lazily.
                cfg = OmegaConf.create({}, flags={"allow_objects": True})
                for key, value in node.items():
                    cfg[key] = instantiate_node(
                        value, convert=convert, recursive=recursive
                    )
                cfg._set_parent(node)
                # Preserve the structured type of the source node.
                cfg._metadata.object_type = node._metadata.object_type
                return cfg

    else:
        assert False, f"Unexpected config type : {type(node).__name__}"
Esempio n. 4
0
 def str_rep(in_key: Union[str, int], in_value: Any) -> str:
     """Render a completion token: ``key.`` for a nested config, ``key=`` otherwise."""
     suffix = "." if OmegaConf.is_config(in_value) else "="
     return "{}{}".format(in_key, suffix)
Esempio n. 5
0
 def str_rep(in_key: Any, in_value: Any) -> str:
     """Return ``key=`` for a leaf value and ``key.`` for a nested config."""
     if not OmegaConf.is_config(in_value):
         return f"{in_key}="
     return f"{in_key}."
Esempio n. 6
0
 def recursive_is_struct(node: Any) -> None:
     """Call ``OmegaConf.is_struct`` on ``node`` and on every nested config node.

     Non-config values are ignored. The result of ``is_struct`` is discarded,
     so this only checks that the query does not raise anywhere in the tree.
     Both DictConfig and ListConfig nodes are traversed.
     """
     if not OmegaConf.is_config(node):
         return
     # NOTE(review): the return value is discarded; if the intent is to require
     # every node to be in struct mode, this should be an `assert` — confirm.
     OmegaConf.is_struct(node)
     # DictConfig exposes values(); ListConfig does not, so iterate it directly
     # (calling node.values() on a ListConfig raised AttributeError before).
     children = node.values() if OmegaConf.is_dict(node) else node
     for child in children:
         recursive_is_struct(child)
Esempio n. 7
0
    def validate_and_convert(self, value: Any) -> Optional[str]:
        """Coerce ``value`` to ``str``, passing ``None`` through unchanged.

        :raises ValidationError: if ``value`` is an OmegaConf config or a
                                 primitive container.
        """
        from omegaconf import OmegaConf

        is_container = OmegaConf.is_config(value) or is_primitive_container(value)
        if is_container:
            raise ValidationError("Cannot convert '$VALUE_TYPE' to string : '$VALUE'")
        if value is None:
            return None
        return str(value)
Esempio n. 8
0
def test_is_config(cfg: Any, is_conf: bool, is_list: bool, is_dict: bool,
                   type_: Type[Any]) -> None:
    """Check OmegaConf's type predicates and ``get_type`` against expectations."""
    # Checked in order; the first mismatch fails the test.
    expectations = [
        (OmegaConf.is_config, is_conf),
        (OmegaConf.is_list, is_list),
        (OmegaConf.is_dict, is_dict),
        (OmegaConf.get_type, type_),
    ]
    for query, wanted in expectations:
        assert query(cfg) == wanted
Esempio n. 9
0
    def _get_matches(config: Container, word: str) -> List[str]:
        """Return shell-completion candidates for ``word`` within ``config``.

        - A ``None`` config yields no matches.
        - If ``word`` ends in ``.`` or ``=``, complete the node selected by
          ``word[:-1]``. A config node yields its child keys, a primitive yields
          its value (bools lower-cased), a missing mandatory value yields an
          empty suggestion, and a ``None`` selection yields nothing.
        - If ``word`` contains a dot, recurse into the node selected by the part
          before the last dot and complete the part after it.
        - Otherwise, list the keys (or list indices) whose text starts with
          ``word``, suffixed with ``.`` for config values and ``=`` otherwise.

        :raises AssertionError: if ``config`` is neither None nor an OmegaConf config.
        """
        def str_rep(in_key: Union[str, int], in_value: Any) -> str:
            # Configs get "." so completion can continue into them; leaves get "=".
            if OmegaConf.is_config(in_value):
                return f"{in_key}."
            else:
                return f"{in_key}="

        if config is None:
            return []
        elif OmegaConf.is_config(config):
            matches = []
            if word.endswith((".", "=")):
                exact_key = word[0:-1]
                try:
                    conf_node = OmegaConf.select(config,
                                                 exact_key,
                                                 throw_on_missing=True)
                except MissingMandatoryValue:
                    # A missing value completes to an empty suggestion.
                    conf_node = ""
                if conf_node is not None:
                    if OmegaConf.is_config(conf_node):
                        key_matches = CompletionPlugin._get_matches(
                            conf_node, "")
                    else:
                        # primitive
                        if isinstance(conf_node, bool):
                            conf_node = str(conf_node).lower()
                        key_matches = [conf_node]
                else:
                    key_matches = []

                matches.extend(f"{word}{match}" for match in key_matches)
            else:
                last_dot = word.rfind(".")
                if last_dot != -1:
                    base_key = word[0:last_dot]
                    partial_key = word[last_dot + 1:]
                    conf_node = OmegaConf.select(config, base_key)
                    key_matches = CompletionPlugin._get_matches(
                        conf_node, partial_key)
                    matches.extend(
                        f"{base_key}.{match}" for match in key_matches)
                else:
                    if isinstance(config, DictConfig):
                        for key, value in config.items_ex(resolve=False):
                            # Keys are not guaranteed to be strings; compare
                            # their text form, as the list branch does.
                            if str(key).startswith(word):
                                matches.append(str_rep(key, value))
                    elif OmegaConf.is_list(config):
                        assert isinstance(config, ListConfig)
                        for idx in range(len(config)):
                            # Filter by prefix first, so missing values at
                            # non-matching indices are no longer suggested.
                            if not str(idx).startswith(word):
                                continue
                            try:
                                value = config[idx]
                            except MissingMandatoryValue:
                                value = ""
                            matches.append(str_rep(idx, value))

        else:
            assert False, f"Object is not an instance of config : {type(config)}"

        return matches
Esempio n. 10
0
def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
    """
    :param config: A config object describing what to call and what params to use.
                   In addition to the parameters, the config must contain:
                   _target_ : target class or callable name (str)
                   _recursive_: Construct nested objects as well (bool).
                                True by default.
                                may be overridden via a _recursive_ key in
                                the kwargs
    :param args: Optional positional parameters pass-through
    :param kwargs: Optional named parameters to override
                   parameters in the config object. Parameters not present
                   in the config objects are being passed as is to the target.
    :return: if _target_ is a class name: the instantiated object
             if _target_ is a callable: the return value of the call
    :raises InstantiationException: if config is a TargetConf whose _target_
             is still "???".
    :raises HydraException: if config is not a dict, an OmegaConf config or a
             structured config.
    Any exception raised while building the kwargs or calling the target is
    re-raised with the same type, with the target name prefixed to its message
    and the original exception chained as the cause.
    """

    if OmegaConf.is_none(config):
        return None

    if isinstance(config, TargetConf) and config._target_ == "???":
        # Specific check to give a good warning about failure to annotate _target_ as a string.
        raise InstantiationException(
            f"Missing value for {type(config).__name__}._target_. Check that it's properly annotated and overridden."
            f"\nA common problem is forgetting to annotate _target_ as a string : '_target_: str = ...'"
        )

    if not (
        isinstance(config, dict)
        or OmegaConf.is_config(config)
        or is_structured_config(config)
    ):
        raise HydraException(f"Unsupported config type : {type(config).__name__}")

    if isinstance(config, dict):
        # NOTE(review): shallow copy; if _convert_container_targets_to_strings
        # recurses into nested dicts, the caller's nested dicts are still
        # mutated — confirm.
        configc = config.copy()
        _convert_container_targets_to_strings(configc)
        config = configc

    kwargsc = kwargs.copy()
    _convert_container_targets_to_strings(kwargsc)
    kwargs = kwargsc

    # make a copy to ensure we do not change the provided object
    config_copy = OmegaConf.structured(config, flags={"allow_objects": True})
    if OmegaConf.is_config(config):
        # Keep the original parent so interpolations can resolve as before.
        config_copy._set_parent(config._get_parent())
    config = config_copy
    assert OmegaConf.is_config(config)
    OmegaConf.set_readonly(config, False)
    OmegaConf.set_struct(config, False)

    target = _get_target_type(config, kwargs)
    try:
        config._set_flag("allow_objects", True)
        final_kwargs = _get_kwargs(config, **kwargs)
        return target(*args, **final_kwargs)
    except Exception as e:
        # Same exception type so callers' except clauses still match; chain the
        # original so its traceback and cause are not lost.
        raise type(e)(
            f"Error instantiating/calling '{_convert_target_to_string(target)}' : {e}"
        ) from e
Esempio n. 11
0
    def _load_single_config(
        self, default: ResultDefault, repo: IConfigRepository
    ) -> ConfigResult:
        """Load the config referenced by ``default`` from ``repo``.

        If the loaded config is not itself a schema source, and the repo's
        schema source has a config with the same normalized name, that schema
        is used as the merge base and the loaded config is merged on top. A
        deprecation warning is emitted, and the result replaces ``ret.config``.
        The config is then embedded under ``default.package``.

        :param default: the defaults-list entry to load; its ``config_path``
                        must not be None.
        :param repo: the repository to load the config (and schema) from.
        :return: the embedded ConfigResult.
        :raises ValueError: if the loaded object is not an OmegaConf config.
        :raises ConfigCompositionException: if the matched schema contains a
            defaults list, if merging with the schema fails, or if a
            non-primary config (other than ``hydra/config``) sets
            ``hydra.searchpath``.
        """
        config_path = default.config_path

        assert config_path is not None
        ret = repo.load_config(config_path=config_path)
        assert ret is not None

        if not OmegaConf.is_config(ret.config):
            raise ValueError(
                f"Config {config_path} must be an OmegaConf config, got {type(ret.config).__name__}"
            )

        if not ret.is_schema_source:
            # Look for a ConfigStore schema with the same (normalized) name.
            schema = None
            try:
                schema_source = repo.get_schema_source()
                cname = ConfigSource._normalize_file_name(filename=config_path)
                schema = schema_source.load_config(cname)
            except ConfigLoadError:
                # schema not found, ignore
                pass

            if schema is not None:
                try:
                    url = "https://hydra.cc/docs/next/upgrades/1.0_to_1.1/automatic_schema_matching"
                    if "defaults" in schema.config:
                        raise ConfigCompositionException(
                            dedent(
                                f"""\
                            '{config_path}' is validated against ConfigStore schema with the same name.
                            This behavior is deprecated in Hydra 1.1 and will be removed in Hydra 1.2.
                            In addition, the automatically matched schema contains a defaults list.
                            This combination is no longer supported.
                            See {url} for migration instructions."""
                            )
                        )
                    else:
                        deprecation_warning(
                            dedent(
                                f"""\

                                '{config_path}' is validated against ConfigStore schema with the same name.
                                This behavior is deprecated in Hydra 1.1 and will be removed in Hydra 1.2.
                                See {url} for migration instructions."""
                            ),
                            # Presumably chosen so the warning points at the
                            # user's call site; depends on call depth — TODO confirm.
                            stacklevel=11,
                        )

                    # if primary config has a hydra node, remove it during validation and add it back.
                    # This allows overriding Hydra's configuration without declaring its node
                    # in the schema of every primary config
                    hydra = None
                    hydra_config_group = (
                        default.config_path is not None
                        and default.config_path.startswith("hydra/")
                    )
                    config = ret.config
                    if (
                        default.primary
                        and isinstance(config, DictConfig)
                        and "hydra" in config
                        and not hydra_config_group
                    ):
                        hydra = config.pop("hydra")

                    # Schema is the base; the loaded config's values override it.
                    merged = OmegaConf.merge(schema.config, config)
                    assert isinstance(merged, DictConfig)

                    if hydra is not None:
                        # open_dict lets the hydra node be re-added to the merged config.
                        with open_dict(merged):
                            merged.hydra = hydra
                    ret.config = merged
                except OmegaConfBaseException as e:
                    raise ConfigCompositionException(
                        f"Error merging '{config_path}' with schema"
                    ) from e

                # Reached only after a successful merge, so `merged` is bound.
                assert isinstance(merged, DictConfig)

        res = self._embed_result_config(ret, default.package)
        # Only the primary config (or hydra/config itself) may set hydra.searchpath.
        if (
            not default.primary
            and config_path != "hydra/config"
            and isinstance(res.config, DictConfig)
            and OmegaConf.select(res.config, "hydra.searchpath") is not None
        ):
            raise ConfigCompositionException(
                f"In '{config_path}': Overriding hydra.searchpath is only supported from the primary config"
            )

        return res