Esempio n. 1
0
def _get_kwargs(
    config: Union[DictConfig, ListConfig],
    **kwargs: Any,
) -> Any:
    """Collect the arguments for the target described by ``config``.

    For a ListConfig, every element that is itself a config is processed by a
    recursive ``_get_kwargs`` call (other elements are kept as-is) and a plain
    ``list`` is returned.

    For a DictConfig, ``kwargs`` are first merged into ``config`` IN PLACE
    (``config.merge_with``), then every entry is copied into a plain ``dict``.
    When ``_is_recursive`` reports the config as recursive, entries accepted by
    ``_is_target`` are replaced by ``instantiate(...)`` results, and dict/list
    entries are rebuilt one level deep with the same treatment, delegating
    deeper config nodes to ``_get_kwargs``.

    :param config: the DictConfig or ListConfig to extract arguments from.
                   A DictConfig is mutated by the merge of ``kwargs``.
    :param kwargs: overrides merged into ``config``. They are NOT forwarded to
                   the recursive calls made for nested nodes.
    :return: a ``dict`` of arguments for DictConfig input, a ``list`` for
             ListConfig input.
    """
    from hydra.utils import instantiate

    assert OmegaConf.is_config(config)

    if OmegaConf.is_list(config):
        assert isinstance(config, ListConfig)
        return [
            _get_kwargs(x) if OmegaConf.is_config(x) else x for x in config
        ]

    assert OmegaConf.is_dict(
        config), "Input config is not an OmegaConf DictConfig"

    final_kwargs = {}

    # Recursion is decided from the config together with the caller's kwargs.
    recursive = _is_recursive(config, kwargs)
    # NOTE(review): allow_objects presumably lets override values be arbitrary
    # Python objects rather than only primitives/configs -- confirm.
    overrides = OmegaConf.create(kwargs, flags={"allow_objects": True})
    config.merge_with(overrides)

    # Non-recursive mode returns these values as-is (no None normalization).
    for k, v in config.items():
        final_kwargs[k] = v

    if recursive:
        # Re-assigning existing keys while iterating is safe here: the size of
        # final_kwargs never changes inside the loop.
        for k, v in final_kwargs.items():
            if _is_target(v):
                final_kwargs[k] = instantiate(v)
            elif OmegaConf.is_dict(v) and not OmegaConf.is_none(v):
                # Rebuild the nested dict (as a DictConfig) one level deep;
                # deeper config values are handled by a recursive call.
                d = OmegaConf.create({}, flags={"allow_objects": True})
                for key, value in v.items():
                    if _is_target(value):
                        d[key] = instantiate(value)
                    elif OmegaConf.is_config(value):
                        d[key] = _get_kwargs(value)
                    else:
                        d[key] = value
                final_kwargs[k] = d
            elif OmegaConf.is_list(v):
                # Same treatment for list entries, collected into a ListConfig.
                lst = OmegaConf.create([], flags={"allow_objects": True})
                for x in v:
                    if _is_target(x):
                        lst.append(instantiate(x))
                    elif OmegaConf.is_config(x):
                        lst.append(_get_kwargs(x))
                    else:
                        lst.append(x)
                final_kwargs[k] = lst
            else:
                # Normalize None-valued entries to a plain Python None.
                if OmegaConf.is_none(v):
                    v = None
                final_kwargs[k] = v

    return final_kwargs
 def test_merge_none_is_none(self, class_type: str) -> None:
     """Merging None over a structured field that has a default yields a None node."""
     mod: Any = import_module(class_type)
     structured = OmegaConf.structured(mod.StructuredOptional)
     assert structured.with_default == mod.Nested()
     merged = OmegaConf.merge(structured, {"with_default": None})
     assert OmegaConf.is_none(merged, "with_default")
Esempio n. 3
0
    def _set_value(self, value: Any) -> None:
        """Replace this DictConfig's content with ``value``.

        ``value`` may be None (only allowed when this node is Optional), an
        interpolation string, the missing marker ``"???"``, a structured
        config, or a mapping whose items are assigned one by one.

        :raises ValidationError: if ``value`` is None and this node is not
                                 Optional.
        """
        from omegaconf import OmegaConf

        # Start from the annotated type; the branches below may override it.
        self._metadata.object_type = self._metadata.annotated_type
        # Fall back to DictConfig when there is no annotation (used only for
        # the error message below).
        type_ = (self._metadata.object_type
                 if self._metadata.object_type is not None else DictConfig)
        if OmegaConf.is_none(value):
            if not self._is_optional():
                # Narrows type_ so that ``__name__`` is valid in the message.
                assert isinstance(type_, type)
                raise ValidationError(
                    f"Cannot assign {type_.__name__}=None (field is not Optional)"
                )
            self.__dict__["_content"] = None
        elif _is_interpolation(value):
            self.__dict__["_content"] = value
        elif value == "???":  # missing
            self.__dict__["_content"] = "???"
        else:
            is_structured = is_structured_config(value)
            if is_structured:
                # Remember the structured type before replacing value with its
                # field data; _type is only read under the same condition below.
                _type = get_type_of(value)
                value = get_structured_config_data(value)

            # Clear the object type while repopulating the content.
            # NOTE(review): presumably so __setitem__ does not validate the
            # new children against the structured type -- confirm.
            self._metadata.object_type = None
            self.__dict__["_content"] = {}

            for k, v in value.items():
                self.__setitem__(k, v)

            if is_structured:
                self._metadata.object_type = _type
Esempio n. 4
0
    def _validate_merge(self, value: Any) -> None:
        """Check that config ``value`` may be merged into this node.

        None is rejected for non-Optional nodes. When this node is missing and
        ``value`` carries an object type, ``value`` is validated as an
        assignment. Finally, merging a non-dict object type into a structured
        config type is only allowed when it is a subclass of that type.

        :raises ValidationError: when the object types are incompatible.
        """
        from omegaconf import OmegaConf

        dest = self
        src = value

        self._validate_non_optional(None, src)

        dest_type = OmegaConf.get_type(dest)
        src_type = OmegaConf.get_type(src)

        if dest._is_missing() and src._metadata.object_type is not None:
            self._validate_set(key=None, value=_get_value(src))

        # Nothing further to check for a missing source, unknown types,
        # non-structured destinations or a None source.
        if src._is_missing():
            return
        if dest_type is None or src_type is None:
            return
        if not is_structured_config(dest_type) or OmegaConf.is_none(src):
            return
        # Dicts and subclasses of the destination type merge fine.
        if is_dict(src_type) or issubclass(src_type, dest_type):
            return

        raise ValidationError(
            f"Merge error : {type_str(src_type)} is not a "
            f"subclass of {type_str(dest_type)}. value: {src}"
        )
Esempio n. 5
0
    def _set_value(self, value: Any) -> None:
        """Replace this DictConfig's content and object type with ``value``.

        ``value`` is validated first via ``_validate_set``. None, an
        interpolation string and the missing marker ``"???"`` are stored
        directly and clear the object type. Otherwise the content is rebuilt
        from a structured config, another DictConfig, or a plain dict.

        :param value: the new value; must not be a ValueNode.
        """
        from omegaconf import OmegaConf

        assert not isinstance(value, ValueNode)
        self._validate_set(key=None, value=value)

        if OmegaConf.is_none(value):
            self.__dict__["_content"] = None
            self._metadata.object_type = None
        elif _is_interpolation(value):
            self.__dict__["_content"] = value
            self._metadata.object_type = None
        elif value == "???":
            self.__dict__["_content"] = "???"
            self._metadata.object_type = None
        else:
            self.__dict__["_content"] = {}
            if is_structured_config(value):
                # Populate with the object type cleared, then record the
                # structured type once all fields are in place.
                self._metadata.object_type = None
                data = get_structured_config_data(value)
                for k, v in data.items():
                    self.__setitem__(k, v)
                self._metadata.object_type = get_type_of(value)
            elif isinstance(value, DictConfig):
                self._metadata.object_type = dict
                # Copy children without resolving interpolations.
                for k, v in value.items_ex(resolve=False):
                    self.__setitem__(k, v)
                # Adopt a deep copy of the source metadata wholesale; this also
                # overwrites the object_type assigned above.
                self.__dict__["_metadata"] = copy.deepcopy(value._metadata)

            elif isinstance(value, dict):
                for k, v in value.items():
                    self.__setitem__(k, v)
            else:
                assert False, f"Unsupported value type : {value}"  # pragma: no cover
Esempio n. 6
0
def verify(
    cfg: Any,
    key: Any,
    none: bool,
    opt: bool,
    missing: bool,
    inter: bool,
    none_public: Optional[bool] = None,
    exp: Any = SKIP,
) -> None:
    """Assert the node-level and public predicates of ``cfg[key]``.

    ``none_public`` is the expected result of the public
    ``OmegaConf.is_none`` (which must emit a UserWarning); it defaults to
    ``none``. ``exp``, when given, is compared with ``cfg.get(key)``.
    """
    expected_public_none = none if none_public is None else none_public

    node = cfg._get_node(key)
    assert node._key() == key
    # Private predicates on the node itself.
    assert node._is_none() == none
    assert node._is_optional() == opt
    assert node._is_missing() == missing
    assert node._is_interpolation() == inter

    if exp is not SKIP:
        assert cfg.get(key) == exp

    # Public/module-level predicates addressed through the parent config.
    assert OmegaConf.is_missing(cfg, key) == missing
    with warns(UserWarning):
        assert OmegaConf.is_none(cfg, key) == expected_public_none
    assert _is_optional(cfg, key) == opt
    assert OmegaConf.is_interpolation(cfg, key) == inter
Esempio n. 7
0
def test_is_none_interpolation(cfg: Any, key: str, is_none: bool) -> None:
    """Public ``OmegaConf.is_none`` and private ``_is_none`` agree for ``key``.

    The public call is expected to emit a UserWarning.
    """
    conf = OmegaConf.create(cfg)
    with warns(UserWarning):
        assert OmegaConf.is_none(conf, key) == is_none
    node = conf._get_node(key)
    resolved_is_none = _is_none(
        node, resolve=True, throw_on_resolution_failure=False
    )
    assert resolved_is_none == is_none
Esempio n. 8
0
    def test_none_construction(self, node_type: Any, values: Any) -> None:
        """A node built from None reports None until a real value is assigned.

        Constructing a non-Optional node from None must raise ValidationError.
        """
        samples = copy.deepcopy(values)

        def check_none_state(target: Any, expect_none: bool) -> None:
            # __eq__ is called directly, and only its truthiness matters.
            assert bool(target.__eq__(None)) is expect_none
            assert bool(OmegaConf.is_none(target)) is expect_none

        optional_node = node_type(value=None, is_optional=True)
        if isinstance(optional_node, ValueNode):
            assert optional_node._value() is None
            assert optional_node._is_optional()
        check_none_state(optional_node, True)

        for sample in samples:
            optional_node._set_value(sample)
            assert optional_node.__eq__(sample)
            check_none_state(optional_node, False)

        with pytest.raises(ValidationError):
            node_type(value=None, is_optional=False)
Esempio n. 9
0
    def _set_value_impl(
        self, value: Any, flags: Optional[Dict[str, bool]] = None
    ) -> None:
        """Replace this DictConfig's content (and possibly metadata) with ``value``.

        ``value`` is validated via ``_validate_set``. None, an interpolation
        string and the missing marker ``"???"`` are stored directly and clear
        the object type. Otherwise the content is rebuilt from a structured
        config, another DictConfig (whose metadata is also copied), or a
        plain dict.

        :param value: the new value; must not be a ValueNode.
        :param flags: only used when ``value`` is a DictConfig. A deep copy of
                      it replaces the flags of the copied metadata.
        :raises ValidationError: for unsupported value types (and whatever
                                 ``_validate_set`` raises).
        """
        from omegaconf import OmegaConf, flag_override

        # Assigning a node to itself is a no-op.
        if id(self) == id(value):
            return

        if flags is None:
            flags = {}

        assert not isinstance(value, ValueNode)
        self._validate_set(key=None, value=value)

        if OmegaConf.is_none(value):
            self.__dict__["_content"] = None
            self._metadata.object_type = None
        elif _is_interpolation(value):
            self.__dict__["_content"] = value
            self._metadata.object_type = None
        elif value == "???":
            self.__dict__["_content"] = "???"
            self._metadata.object_type = None
        else:
            self.__dict__["_content"] = {}
            if is_structured_config(value):
                # Populate with the object type cleared, then record the
                # structured type once all fields are in place.
                self._metadata.object_type = None
                data = get_structured_config_data(
                    value,
                    allow_objects=self._get_flag("allow_objects"),
                )
                for k, v in data.items():
                    self.__setitem__(k, v)
                self._metadata.object_type = get_type_of(value)
            elif isinstance(value, DictConfig):
                # Take over the source metadata, but with the caller's flags.
                self.__dict__["_metadata"] = copy.deepcopy(value._metadata)
                self._metadata.flags = copy.deepcopy(flags)
                # disable struct and readonly for the construction phase
                # retaining other flags like allow_objects. The real flags are restored at the end of this function
                # NOTE(review): the restoring is presumably done by flag_override
                # when the with-blocks below exit -- confirm.
                with flag_override(self, "struct", False):
                    with flag_override(self, "readonly", False):
                        # Copy the raw child nodes (no interpolation resolution).
                        for k, v in value.__dict__["_content"].items():
                            self.__setitem__(k, v)

            elif isinstance(value, dict):
                for k, v in value.items():
                    self.__setitem__(k, v)
            else:  # pragma: no cover
                msg = f"Unsupported value type : {value}"
                raise ValidationError(msg)
Esempio n. 10
0
File: utils.py Progetto: wpc/hydra
def call(config: Any, *args: Any, **kwargs: Any) -> Any:
    """Locate the ``_target_`` named by ``config`` and instantiate or call it.

    :param config: An object describing what to call and what params to use.
                   Must have a _target_ field.
    :param args: optional positional parameters pass-through
    :param kwargs: optional named parameters pass-through
    :return: the return value from the specified class or method, or None
             when ``config`` is None
    :raises InstantiationException: when a TargetConf's _target_ is still "???"
    :raises HydraException: for unsupported config types, or wrapping any other
                            error raised while resolving or calling the target
    """
    if OmegaConf.is_none(config):
        return None

    target_left_missing = isinstance(config, TargetConf) and config._target_ == "???"
    if target_left_missing:
        # Specific check to give a good warning about failure to annotate _target_ as a string.
        raise InstantiationException(
            f"Missing value for {type(config).__name__}._target_. Check that it's properly annotated and overridden."
            f"\nA common problem is forgetting to annotate _target_ as a string : '_target_: str = ...'"
        )

    is_supported = (
        isinstance(config, dict)
        or OmegaConf.is_config(config)
        or is_structured_config(config)
    )
    if not is_supported:
        raise HydraException(
            f"Unsupported config type : {type(config).__name__}")

    # Work on a copy so the provided object is never changed; a config input
    # keeps its original parent.
    original = config
    config = OmegaConf.structured(original)
    if OmegaConf.is_config(original):
        config._set_parent(original._get_parent())

    cls = "<unknown>"
    try:
        assert isinstance(config, DictConfig)
        OmegaConf.set_readonly(config, False)
        OmegaConf.set_struct(config, False)
        cls = _get_cls_name(config)
        located = _locate(cls)
        if not isinstance(located, type):
            assert callable(located)
            return _call_callable(located, config, *args, **kwargs)
        return _instantiate_class(located, config, *args, **kwargs)
    except InstantiationException:
        raise
    except Exception as e:
        raise HydraException(f"Error calling '{cls}' : {e}") from e
Esempio n. 11
0
def call(config: Union[ObjectConf, DictConfig], *args: Any, **kwargs: Any) -> Any:
    """
    :param config: An ObjectConf or DictConfig describing what to call and what params to use
    :param args: optional positional parameters pass-through
    :param kwargs: optional named parameters pass-through
    :return: the return value from the specified class or method, or None
             when ``config`` is None
    :raises HydraException: wrapping any error raised while resolving or
                            calling the target
    """
    if OmegaConf.is_none(config):
        return None
    # Bind a placeholder before the try block. If _get_cls_name() itself fails,
    # the handler below still needs ``cls``; without this an UnboundLocalError
    # would replace the real error.
    cls = "<unknown>"
    try:
        cls = _get_cls_name(config)
        type_or_callable = _locate(cls)
        if isinstance(type_or_callable, type):
            return _instantiate_class(type_or_callable, config, *args, **kwargs)
        else:
            assert callable(type_or_callable)
            return _call_callable(type_or_callable, config, *args, **kwargs)
    except Exception as e:
        raise HydraException(f"Error calling '{cls}' : {e}") from e
Esempio n. 12
0
    def _validate_non_optional(self, key: Any, value: Any) -> None:
        """Reject assigning None to a node that is not Optional.

        With ``key`` given, the existing child at ``key`` is checked (a missing
        child is accepted); with ``key`` None, this node itself is checked.
        Violations are reported through ``_format_and_raise`` with a
        ValidationError cause.
        """
        from omegaconf import OmegaConf

        if not OmegaConf.is_none(value):
            return

        if key is None:
            if not self._is_optional():
                self._format_and_raise(
                    key=None,
                    value=value,
                    cause=ValidationError("field '$FULL_KEY' is not Optional"),
                )
            return

        child = self._get_node(key)
        if child is not None and not child._is_optional():
            self._format_and_raise(
                key=key,
                value=value,
                cause=ValidationError("child '$FULL_KEY' is not Optional"),
            )
Esempio n. 13
0
    def _set_value_impl(self,
                        value: Any,
                        flags: Optional[Dict[str, bool]] = None) -> None:
        """Replace this ListConfig's content (and possibly metadata) with ``value``.

        None (only for Optional nodes), the missing marker and interpolations
        are stored directly. Otherwise ``value`` must be a ListConfig or a
        primitive list/tuple, and its items are appended one by one.

        :param value: the new value.
        :param flags: only used when ``value`` is a ListConfig. A deep copy of
                      it replaces the flags of the copied metadata.
        :raises ValidationError: if ``value`` is None and this node is not
                                 Optional, or if ``value`` is not a
                                 ListConfig, list or tuple.
        """
        from omegaconf import OmegaConf, flag_override

        # Assigning a node to itself is a no-op.
        if id(self) == id(value):
            return

        if flags is None:
            flags = {}

        if OmegaConf.is_none(value):
            if not self._is_optional():
                raise ValidationError(
                    "Non optional ListConfig cannot be constructed from None")
            self.__dict__["_content"] = None
        elif get_value_kind(value) == ValueKind.MANDATORY_MISSING:
            self.__dict__["_content"] = "???"
        elif get_value_kind(value) in (
                ValueKind.INTERPOLATION,
                ValueKind.STR_INTERPOLATION,
        ):
            self.__dict__["_content"] = value
        else:
            if not (is_primitive_list(value) or isinstance(value, ListConfig)):
                type_ = type(value)
                msg = f"Invalid value assigned : {type_.__name__} is not a ListConfig, list or tuple."
                raise ValidationError(msg)

            self.__dict__["_content"] = []
            if isinstance(value, ListConfig):
                # Take over the source metadata, but with the caller's flags.
                self.__dict__["_metadata"] = copy.deepcopy(value._metadata)
                self._metadata.flags = copy.deepcopy(flags)
                # disable struct and readonly for the construction phase
                # retaining other flags like allow_objects. The real flags are restored at the end of this function
                # NOTE(review): the restoring is presumably done by flag_override
                # when the with-blocks below exit -- confirm.
                with flag_override(self, "struct", False):
                    with flag_override(self, "readonly", False):
                        # Copy items without resolving interpolations.
                        for item in value._iter_ex(resolve=False):
                            self.append(item)
            elif is_primitive_list(value):
                for item in value:
                    self.append(item)
Esempio n. 14
0
def call(
    config: Union[ObjectConf, TargetConf, DictConfig, Dict[Any, Any]],
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Locate the ``_target_`` named by ``config`` and instantiate or call it.

    :param config: An object describing what to call and what params to use
    :param args: optional positional parameters pass-through
    :param kwargs: optional named parameters pass-through
    :return: the return value from the specified class or method, or None
             when the (converted) config is None
    :raises InstantiationException: when a TargetConf's _target_ is still "???"
    :raises HydraException: wrapping any other error raised while resolving or
                            calling the target
    """
    if isinstance(config, TargetConf):
        if config._target_ == "???":
            raise InstantiationException(
                f"Missing _target_ value. Check that you specified it in '{type(config).__name__}'"
                f" and that the type annotation is correct: '_target_: str = ...'")

    # Plain dicts and dataclass-style confs are converted to a DictConfig first.
    needs_conversion = isinstance(config, (dict, ObjectConf, TargetConf))
    if needs_conversion:
        config = OmegaConf.structured(config)

    if OmegaConf.is_none(config):
        return None

    cls = "<unknown>"
    try:
        assert isinstance(config, DictConfig)
        # Deep copy so the caller's config is never modified.
        config = copy.deepcopy(config)
        OmegaConf.set_readonly(config, False)
        OmegaConf.set_struct(config, False)
        cls = _get_cls_name(config)
        located = _locate(cls)
        if not isinstance(located, type):
            assert callable(located)
            return _call_callable(located, config, *args, **kwargs)
        return _instantiate_class(located, config, *args, **kwargs)
    except InstantiationException:
        raise
    except Exception as e:
        raise HydraException(f"Error calling '{cls}' : {e}") from e
Esempio n. 15
0
def verify(
    cfg: Any,
    key: Any,
    none: bool,
    opt: bool,
    missing: bool,
    inter: bool,
    exp: Any = SKIP,
) -> None:
    """Assert the node-level and public ``OmegaConf`` predicates of ``cfg[key]``.

    ``exp``, when given, is compared with ``cfg.get(key)``.
    """
    node = cfg._get_node(key)
    assert node._key() == key
    # Private predicates on the node itself.
    assert node._is_none() == none
    assert node._is_optional() == opt
    assert node._is_missing() == missing
    assert node._is_interpolation() == inter

    if exp is not SKIP:
        assert cfg.get(key) == exp

    # Public API, addressed through the parent config and key.
    assert OmegaConf.is_missing(cfg, key) == missing
    assert OmegaConf.is_none(cfg, key) == none
    assert OmegaConf.is_optional(cfg, key) == opt
    assert OmegaConf.is_interpolation(cfg, key) == inter
Esempio n. 16
0
def test_is_none(fac: Any, is_none: bool) -> None:
    """``OmegaConf.is_none`` agrees for a node on its own and nested in a config."""
    node = fac(is_none)
    assert OmegaConf.is_none(node) == is_none

    parent = OmegaConf.create({"node": node})
    assert OmegaConf.is_none(parent, "node") == is_none
Esempio n. 17
0
def _get_kwargs(
    config: Union[DictConfig, ListConfig],
    root: bool = True,
    **kwargs: Any,
) -> Any:
    """Collect the arguments for the target described by ``config``.

    For a ListConfig, config elements are processed by recursive calls
    (``root=False``) and a plain ``list`` is returned. For a DictConfig,
    ``kwargs`` are first merged into ``config`` IN PLACE, and the entries are
    copied (unresolved) into a new DictConfig. When ``_is_recursive`` reports
    the config as recursive, entries accepted by ``_is_target`` are
    instantiated and dict/list entries are rebuilt one level deep with the
    same treatment.

    :param config: the DictConfig or ListConfig to extract arguments from.
    :param root: True for the top-level call. Only non-root results copy
                 ``config``'s object type onto the returned DictConfig.
    :param kwargs: overrides merged into ``config``. They are not forwarded to
                   the recursive calls for nested nodes.
    :return: a DictConfig of arguments (DictConfig input) or a ``list``
             (ListConfig input).
    """
    from hydra.utils import instantiate

    assert OmegaConf.is_config(config)

    if OmegaConf.is_list(config):
        assert isinstance(config, ListConfig)
        return [
            _get_kwargs(x, root=False) if OmegaConf.is_config(x) else x
            for x in config
        ]

    assert OmegaConf.is_dict(
        config), "Input config is not an OmegaConf DictConfig"

    recursive = _is_recursive(config, kwargs)
    overrides = OmegaConf.create(kwargs, flags={"allow_objects": True})
    config.merge_with(overrides)

    final_kwargs = OmegaConf.create(flags={"allow_objects": True})
    # NOTE(review): the parent is presumably shared so that interpolations,
    # which are copied unresolved below, still resolve against the original
    # config tree -- confirm.
    final_kwargs._set_parent(config._get_parent())
    # Make the result writable and non-struct while it is being built.
    final_kwargs._set_flag("readonly", False)
    final_kwargs._set_flag("struct", False)
    if recursive:
        # resolve=False: entries (including interpolations) are taken as-is.
        for k, v in config.items_ex(resolve=False):
            if OmegaConf.is_none(v):
                final_kwargs[k] = v
            elif _is_target(v):
                final_kwargs[k] = instantiate(v)
            elif OmegaConf.is_dict(v):
                d = OmegaConf.create({}, flags={"allow_objects": True})
                for key, value in v.items_ex(resolve=False):
                    if _is_target(value):
                        d[key] = instantiate(value)
                    elif OmegaConf.is_config(value):
                        d[key] = _get_kwargs(value, root=False)
                    else:
                        d[key] = value
                # Keep the nested dict's original object type.
                d._metadata.object_type = v._metadata.object_type
                final_kwargs[k] = d
            elif OmegaConf.is_list(v):
                lst = OmegaConf.create([], flags={"allow_objects": True})
                for x in v:
                    if _is_target(x):
                        lst.append(instantiate(x))
                    elif OmegaConf.is_config(x):
                        lst.append(_get_kwargs(x, root=False))
                        # Keep each nested config element's original object type.
                        lst[-1]._metadata.object_type = x._metadata.object_type
                    else:
                        lst.append(x)
                final_kwargs[k] = lst
            else:
                final_kwargs[k] = v
    else:
        for k, v in config.items_ex(resolve=False):
            final_kwargs[k] = v

    # Reset the construction-time flags to None (unset) once building is done.
    final_kwargs._set_flag("readonly", None)
    final_kwargs._set_flag("struct", None)
    final_kwargs._set_flag("allow_objects", None)
    if not root:
        # This is tricky: since the root kwargs are exploded anyway, we can treat it as an untyped dict.
        # The motivation is that the object type is used as an indicator to treat the object differently during
        # conversion to a primitive container in some cases.
        final_kwargs._metadata.object_type = config._metadata.object_type
    return final_kwargs
Esempio n. 18
0
    def _validate_set_merge_impl(self, key: Any, value: Any,
                                 is_assign: bool) -> None:
        """Validate assigning (``is_assign=True``) or merging ``value`` into this config.

        The target is the child at ``key``, or this node itself when ``key``
        is None. The method returns silently when the operation is acceptable.

        :param key: child key, or None to validate against this node.
        :param value: the value being assigned or merged.
        :param is_assign: True for assignment; False for merge, which is more
                          permissive for dict values (deeper checks happen
                          during the merge).
        :raises ValidationError: on None into a non-Optional target (via
                                 ``_format_and_raise``), or when the value's
                                 type is not a subclass of the target's
                                 reference type.
        """
        from omegaconf import OmegaConf

        # Interpolations are not type-checked here.
        vk = get_value_kind(value)
        if vk in (ValueKind.INTERPOLATION, ValueKind.STR_INTERPOLATION):
            return

        if OmegaConf.is_none(value):
            if key is not None:
                child = self._get_node(key)
                if child is not None and not child._is_optional():
                    self._format_and_raise(
                        key=key,
                        value=value,
                        cause=ValidationError(
                            "child '$FULL_KEY' is not Optional"),
                    )
            else:
                if not self._is_optional():
                    self._format_and_raise(
                        key=None,
                        value=value,
                        cause=ValidationError(
                            "field '$FULL_KEY' is not Optional"),
                    )

        # The missing marker is always accepted.
        if value == "???":
            return

        # Assigning a ValueNode into a container with an element type is
        # checked by a dedicated helper.
        if is_assign and isinstance(value,
                                    ValueNode) and self._has_element_type():
            self._check_assign_value_node(key, value)
            return

        target: Optional[Node]
        if key is None:
            target = self
        else:
            target = self._get_node(key)

        # A key that does not exist yet has no type to validate against.
        if target is None:
            return

        def is_typed(c: Any) -> bool:
            # A DictConfig whose reference type is neither Any nor dict.
            return isinstance(
                c, DictConfig) and c._metadata.ref_type not in (Any, dict)

        if not is_typed(target):
            return

        # target must be optional by now. no need to check the type of value if None.
        if value is None:
            return

        target_type = target._metadata.ref_type
        value_type = OmegaConf.get_type(value)

        if is_dict(value_type) and is_dict(target_type):
            return

        # Generic List[...]/Dict[...] reference types skip the subclass check.
        if is_generic_list(target_type) or is_generic_dict(target_type):
            return
        # TODO: simplify this flow. (if assign validate assign else validate merge)
        # is assignment illegal?
        validation_error = (target_type is not None and value_type is not None
                            and not issubclass(value_type, target_type))

        if not is_assign:
            # merge
            # Merging of a dictionary is allowed even if assignment is illegal (merge would do deeper checks)
            validation_error = not is_dict(value_type) and validation_error

        if validation_error:
            assert value_type is not None
            assert target_type is not None
            msg = (f"Invalid type assigned : {type_str(value_type)} is not a "
                   f"subclass of {type_str(target_type)}. value: {value}")
            raise ValidationError(msg)
Esempio n. 19
0
def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
    """
    :param config: A config object describing what to call and what params to use.
                   In addition to the parameters, the config must contain:
                   _target_ : target class or callable name (str)
                   And may contain:
                   _recursive_: Construct nested objects as well (bool).
                                True by default.
                                may be overridden via a _recursive_ key in
                                the kwargs
                   _convert_: Conversion strategy
                        none    : Passed objects are DictConfig and ListConfig, default
                        partial : Passed objects are converted to dict and list, with
                                  the exception of Structured Configs (and their fields).
                        all     : Passed objects are dicts, lists and primitives without
                                  a trace of OmegaConf containers
    :param args: Optional positional parameters pass-through
    :param kwargs: Optional named parameters to override
                   parameters in the config object. Parameters not present
                   in the config objects are being passed as is to the target.
    :return: if _target_ is a class name: the instantiated object
             if _target_ is a callable: the return value of the call
             None if ``config`` is None
    :raises InstantiationException: if a TargetConf's _target_ is still "???".
    :raises HydraException: if ``config`` is not a dict, OmegaConf config or
                            structured config.

    Errors raised while building the arguments or calling the target are
    re-raised with the same exception type and a message naming the target,
    keeping the original traceback. If that exception type cannot be built from
    a single message string, the original exception is re-raised unchanged.
    """

    if OmegaConf.is_none(config):
        return None

    if isinstance(config, TargetConf) and config._target_ == "???":
        # Specific check to give a good warning about failure to annotate _target_ as a string.
        raise InstantiationException(
            f"Missing value for {type(config).__name__}._target_. Check that it's properly annotated and overridden."
            f"\nA common problem is forgetting to annotate _target_ as a string : '_target_: str = ...'"
        )

    if not (
        isinstance(config, dict)
        or OmegaConf.is_config(config)
        or is_structured_config(config)
    ):
        raise HydraException(f"Unsupported config type : {type(config).__name__}")

    if isinstance(config, dict):
        configc = config.copy()
        _convert_container_targets_to_strings(configc)
        config = configc

    kwargsc = kwargs.copy()
    _convert_container_targets_to_strings(kwargsc)
    kwargs = kwargsc

    # make a copy to ensure we do not change the provided object
    config_copy = OmegaConf.structured(config, flags={"allow_objects": True})
    if OmegaConf.is_config(config):
        config_copy._set_parent(config._get_parent())
    config = config_copy
    assert OmegaConf.is_config(config)
    OmegaConf.set_readonly(config, False)
    OmegaConf.set_struct(config, False)

    target = _get_target_type(config, kwargs)
    try:
        config._set_flag("allow_objects", True)
        final_kwargs = _get_kwargs(config, **kwargs)
        convert = _pop_convert_mode(final_kwargs)
        if convert == ConvertMode.PARTIAL:
            final_kwargs = OmegaConf.to_container(
                final_kwargs, resolve=True, exclude_structured_configs=True
            )
            return target(*args, **final_kwargs)
        elif convert == ConvertMode.ALL:
            final_kwargs = OmegaConf.to_container(final_kwargs, resolve=True)
            return target(*args, **final_kwargs)
        elif convert == ConvertMode.NONE:
            return target(*args, **final_kwargs)
        else:
            assert False
    except Exception as e:
        msg = f"Error instantiating/calling '{_convert_target_to_string(target)}' : {e}"
        # Some exception types cannot be constructed from a single message
        # string (e.g. UnicodeDecodeError needs five arguments). Attempting it
        # would raise a secondary TypeError that hides the real failure.
        try:
            wrapped = type(e)(msg)
        except Exception:
            wrapped = None
        if wrapped is None:
            # Bare raise re-raises ``e`` (the exception being handled) as-is.
            raise
        # preserve the original exception backtrace
        raise wrapped.with_traceback(e.__traceback__)
Esempio n. 20
0
    def _validate_set_merge_impl(self, key: Any, value: Any, is_assign: bool) -> None:
        """Validate assigning (``is_assign=True``) or merging ``value`` into this config.

        The target is the child at ``key``, or this node itself when ``key``
        is None. The method returns silently when the operation is acceptable.

        :param key: child key, or None to validate against this node.
        :param value: the value being assigned or merged.
        :param is_assign: True for assignment; False for merge, which is more
                          permissive for dict values.
        :raises ValidationError: on None into a non-Optional target, or when
                                 the value's type is not a subclass of the
                                 target's reference type.
        :raises ReadonlyConfigError: when the target or this node is read-only.
        """
        from omegaconf import OmegaConf

        # Interpolations are not validated here.
        vk = get_value_kind(value)
        if vk in (ValueKind.INTERPOLATION, ValueKind.STR_INTERPOLATION):
            return

        if OmegaConf.is_none(value):
            if key is not None:
                node = self._get_node(key)
                if node is not None and not node._is_optional():
                    self._format_and_raise(
                        key=key,
                        value=value,
                        cause=ValidationError("field '$FULL_KEY' is not Optional"),
                    )
            else:
                if not self._is_optional():
                    # NOTE(review): unlike the branch above, this raises directly,
                    # so '$FULL_KEY' is only substituted if a caller formats the
                    # exception (e.g. via _format_and_raise) -- confirm.
                    raise ValidationError("field '$FULL_KEY' is not Optional")

        # The missing marker is always accepted.
        if value == "???":
            return

        target: Optional[Node]
        if key is None:
            target = self
        else:
            target = self._get_node(key)

        # Read-only check covers both the target node and this container.
        if (target is not None and target._get_flag("readonly")) or self._get_flag(
            "readonly"
        ):
            if is_assign:
                msg = f"Cannot assign to read-only node : {value}"
            else:
                msg = f"Cannot merge into read-only node : {value}"
            raise ReadonlyConfigError(msg)

        # A key that does not exist yet has no type to validate against.
        if target is None:
            return

        def is_typed(c: Any) -> bool:
            # A DictConfig whose reference type is neither Any nor dict.
            return isinstance(c, DictConfig) and c._metadata.ref_type not in (Any, dict)

        if not is_typed(target):
            return

        # target must be optional by now. no need to check the type of value if None.
        if value is None:
            return

        target_type = target._metadata.ref_type
        value_type = OmegaConf.get_type(value)

        if is_dict(value_type) and is_dict(target_type):
            return

        # is assignment illegal?
        validation_error = (
            target_type is not None
            and value_type is not None
            and not issubclass(value_type, target_type)
        )

        if not is_assign:
            # merge
            # Merging of a dictionary is allowed even if assignment is illegal (merge would do deeper checks)
            validation_error = not is_dict(value_type) and validation_error

        if validation_error:
            assert value_type is not None
            assert target_type is not None
            msg = (
                f"Invalid type assigned : {type_str(value_type)} is not a "
                f"subclass of {type_str(target_type)}. value: {value}"
            )
            raise ValidationError(msg)
Esempio n. 21
0
def format_and_raise(
    node: Any,
    key: Any,
    value: Any,
    msg: str,
    cause: Exception,
    type_override: Any = None,
) -> None:
    """Re-raise ``cause`` with a message enriched by context about ``node``/``key``.

    :param node: The config node in which the error occurred, or None.
    :param key: The key within ``node`` involved in the error, or None.
    :param value: The value being processed when the error occurred.
    :param msg: Message template. It may reference ``$REF_TYPE``, ``$OBJECT_TYPE``,
        ``$KEY``, ``$FULL_KEY``, ``$VALUE``, ``$VALUE_TYPE`` and ``$KEY_TYPE``.
    :param cause: The exception that triggered the error.
    :param type_override: If given, the exception type to raise instead of
        ``type(cause)``.

    The final exception is handed to ``_raise`` together with ``cause``
    (presumably ``_raise`` always raises). A builtin ``TypeError`` or
    ``IndexError`` cause is converted to ``ConfigTypeError`` or
    ``ConfigIndexError``. OmegaConf exceptions get context attributes
    (``full_key``, ``parent_node``, ``child_node``, ...) attached.
    """
    from omegaconf import OmegaConf
    from omegaconf.base import Node

    # The cause was already formatted by an earlier call: re-raise it (optionally
    # as a different type) instead of wrapping it in another layer of context.
    if isinstance(cause, OmegaConfBaseException) and cause._initialized:
        ex = cause
        if type_override is not None:
            ex = type_override(str(cause))
            ex.__dict__ = copy.deepcopy(cause.__dict__)
        _raise(ex, cause)

    object_type: Optional[Type[Any]]
    object_type_str: Optional[str] = None
    ref_type: Optional[Type[Any]]
    ref_type_str: Optional[str]

    child_node: Optional[Node] = None
    if node is None:
        full_key = ""
        object_type = None
        ref_type = None
        ref_type_str = None
    else:
        if key is not None and not OmegaConf.is_none(node):
            child_node = node._get_node(key,
                                        validate_access=False,
                                        disable_warning=True)

        full_key = node._get_full_key(key=key, disable_warning=True)

        object_type = OmegaConf.get_type(node)
        object_type_str = type_str(object_type)

        ref_type = get_ref_type(node)
        ref_type_str = type_str(ref_type)

    # Use safe_substitute: msg may contain literal "$" text that is not one of the
    # placeholders below (e.g. an interpolation such as "${foo}" echoed from a
    # value). substitute() would raise KeyError/ValueError on it and mask the
    # original error; safe_substitute leaves such text untouched.
    msg = string.Template(msg).safe_substitute(
        REF_TYPE=ref_type_str,
        OBJECT_TYPE=object_type_str,
        KEY=key,
        FULL_KEY=full_key,
        VALUE=value,
        VALUE_TYPE=f"{type(value).__name__}",
        KEY_TYPE=f"{type(key).__name__}",
    )

    template = """$MSG
\tfull_key: $FULL_KEY
\treference_type=$REF_TYPE
\tobject_type=$OBJECT_TYPE"""

    s = string.Template(template=template)

    # This template is a fixed literal, and substituted values are not re-scanned,
    # so strict substitute() is safe here.
    message = s.substitute(
        REF_TYPE=ref_type_str,
        OBJECT_TYPE=object_type_str,
        MSG=msg,
        FULL_KEY=full_key,
    )
    exception_type = type(cause) if type_override is None else type_override
    # Map builtin exceptions onto their OmegaConf counterparts.
    if exception_type is TypeError:
        exception_type = ConfigTypeError
    elif exception_type is IndexError:
        exception_type = ConfigIndexError

    ex = exception_type(f"{message}")
    if issubclass(exception_type, OmegaConfBaseException):
        # Mark as formatted so a later format_and_raise re-raises it unchanged.
        ex._initialized = True
        ex.msg = message
        ex.parent_node = node
        ex.child_node = child_node
        ex.key = key
        ex.full_key = full_key
        ex.value = value
        ex.object_type = object_type
        ex.object_type_str = object_type_str
        ex.ref_type = ref_type
        ex.ref_type_str = ref_type_str

    _raise(ex, cause)
Esempio n. 22
0
    def _map_merge(dest: "BaseContainer", src: "BaseContainer") -> None:
        """Merge ``src`` into ``dest`` in place.

        ``dest`` is mutated: its values, its ref/object types and its flags are
        updated from ``src``. Nothing is returned. Both arguments must be
        DictConfig instances.

        If ``src`` is an interpolation or None, ``dest`` simply takes on that
        value. Otherwise each key of ``src`` is merged into ``dest``:
        containers are merged recursively, value nodes are assigned, and keys
        absent from ``dest`` are copied over. ValidationError and
        ReadonlyConfigError raised while assigning a value node are re-raised
        through ``dest._format_and_raise`` with the offending key attached.
        """
        from omegaconf import AnyNode, DictConfig, OmegaConf, ValueNode

        assert isinstance(dest, DictConfig)
        assert isinstance(src, DictConfig)
        src_type = src._metadata.object_type
        src_ref_type = get_ref_type(src)
        assert src_ref_type is not None

        # If source DictConfig is:
        #  - an interpolation => set the destination DictConfig to be the same interpolation
        #  - None => set the destination DictConfig to None
        if src._is_interpolation() or src._is_none():
            dest._set_value(src._value())
            _update_types(node=dest,
                          ref_type=src_ref_type,
                          object_type=src_type)
            return

        dest._validate_merge(value=src)

        # Give a container concrete (empty) content so keys can be merged into it:
        # {} / [] for dict / list annotations, the ref type itself otherwise
        # (e.g. a structured config class), or {} for an untyped DictConfig.
        def expand(node: Container) -> None:
            rt = node._metadata.ref_type
            val: Any
            if rt is not Any:
                if is_dict_annotation(rt):
                    val = {}
                elif is_list_annotation(rt):
                    val = []
                else:
                    val = rt
            elif isinstance(node, DictConfig):
                val = {}
            else:
                assert False

            node._set_value(val)

        if (src._is_missing() and not dest._is_missing()
                and is_structured_config(src_ref_type)):
            # Replace `src` with a prototype of its corresponding structured config
            # whose fields are all missing (to avoid overwriting fields in `dest`).
            src = _create_structured_with_missing_fields(ref_type=src_ref_type,
                                                         object_type=src_type)

        # dest has no concrete content yet (interpolation or "???"): materialize it
        # before merging keys into it, unless src is itself missing.
        if (dest._is_interpolation()
                or dest._is_missing()) and not src._is_missing():
            expand(dest)

        for key, src_value in src.items_ex(resolve=False):
            src_node = src._get_node(key, validate_access=False)
            dest_node = dest._get_node(key, validate_access=False)
            assert src_node is None or isinstance(src_node, Node)
            assert dest_node is None or isinstance(dest_node, Node)

            if isinstance(dest_node, DictConfig):
                dest_node._validate_merge(value=src_node)

            missing_src_value = _is_missing_value(src_value)

            # dest holds None for this key but src brings a real value: expand the
            # destination container so src can be merged into it.
            if (isinstance(dest_node, Container)
                    and OmegaConf.is_none(dest, key) and not missing_src_value
                    and not OmegaConf.is_none(src_value)):
                expand(dest_node)

            # An interpolation in dest that resolves to a container is replaced by
            # that container, so the merge targets the actual content.
            if dest_node is not None and dest_node._is_interpolation():
                target_node = dest_node._dereference_node(
                    throw_on_resolution_failure=False)
                if isinstance(target_node, Container):
                    dest[key] = target_node
                    dest_node = dest._get_node(key)

            if (dest_node is None
                    and is_structured_config(dest._metadata.element_type)
                    and not missing_src_value):
                # merging into a new node. Use element_type as a base
                dest[key] = DictConfig(content=dest._metadata.element_type,
                                       parent=dest)
                dest_node = dest._get_node(key)

            if dest_node is not None:
                if isinstance(dest_node, BaseContainer):
                    # Container in dest: merge recursively, or assign a
                    # non-missing plain value over it.
                    if isinstance(src_value, BaseContainer):
                        dest_node._merge_with(src_value)
                    elif not missing_src_value:
                        dest.__setitem__(key, src_value)
                else:
                    if isinstance(src_value, BaseContainer):
                        dest.__setitem__(key, src_value)
                    else:
                        assert isinstance(dest_node, ValueNode)
                        assert isinstance(src_node, ValueNode)
                        # Compare to literal missing, ignoring interpolation
                        src_node_missing = _is_missing_literal(src_value)
                        try:
                            if isinstance(dest_node, AnyNode):
                                # An untyped dest node is replaced by the src node
                                # (which may carry a stricter type).
                                if src_node_missing:
                                    node = copy.copy(src_node)
                                    # if src node is missing, use the value from the dest_node,
                                    # but validate it against the type of the src node before assignment
                                    node._set_value(dest_node._value())
                                else:
                                    node = src_node
                                dest.__setitem__(key, node)
                            else:
                                # Typed dest node keeps its type; a literal "???"
                                # in src does not overwrite its value.
                                if not src_node_missing:
                                    dest_node._set_value(src_value)

                        except (ValidationError, ReadonlyConfigError) as e:
                            dest._format_and_raise(key=key,
                                                   value=src_value,
                                                   cause=e)
            else:
                # Key not present in dest: copy the src node over.
                from omegaconf import open_dict

                if is_structured_config(src_type):
                    # verified to be compatible above in _validate_merge
                    # open_dict presumably lifts struct mode so the new key can be added.
                    with open_dict(dest):
                        dest[key] = src._get_node(key)
                else:
                    dest[key] = src._get_node(key)

        # dest adopts the ref/object types of src.
        _update_types(node=dest, ref_type=src_ref_type, object_type=src_type)

        # explicit flags on the source config are replacing the flag values in the destination
        flags = src._metadata.flags
        assert flags is not None
        for flag, value in flags.items():
            if value is not None:
                dest._set_flag(flag, value)
Esempio n. 23
0
def instantiate_node(
    node: Any,
    *args: Any,
    convert: Union[str, ConvertMode] = ConvertMode.NONE,
    recursive: bool = True,
) -> Any:
    """Instantiate ``node`` and, where applicable, its children.

    :param node: An OmegaConf node or a plain value. Plain values are returned
                 unchanged; None (or an OmegaConf None) yields None.
    :param args: Positional arguments forwarded to the target when ``node`` is a
                 dict with a ``_target_``.
    :param convert: Conversion mode. A dict node's own ``_convert_`` key overrides it.
    :param recursive: Whether the fields of a ``_target_`` node are instantiated
                      before the call. A dict node's own ``_recursive_`` key
                      overrides it.
    :return: The target's call result for ``_target_`` dicts. Otherwise, a
             container of the instantiated children: a plain list/dict or a
             ListConfig/DictConfig, depending on ``convert``.
    :raises TypeError: if the effective ``_recursive_`` value is not a bool.
    """
    # None in, None out.
    if node is None or OmegaConf.is_none(node):
        return None

    # Anything that is not an OmegaConf container passes through untouched.
    if not OmegaConf.is_config(node):
        return node

    node_is_dict = OmegaConf.is_dict(node)

    # A dict node may override the modes inherited from its parent. Membership is
    # tested before indexing because OmegaConf's get(key, default) raises when the
    # key type is incompatible.
    if node_is_dict:
        if _Keys.CONVERT in node:
            convert = node[_Keys.CONVERT]
        if _Keys.RECURSIVE in node:
            recursive = node[_Keys.RECURSIVE]

    if not isinstance(recursive, bool):
        raise TypeError(
            f"_recursive_ flag must be a bool, got {type(recursive)}")

    if OmegaConf.is_list(node):
        children = [
            instantiate_node(child, convert=convert, recursive=recursive)
            for child in node._iter_ex(resolve=True)
        ]
        if convert not in (ConvertMode.ALL, ConvertMode.PARTIAL):
            # NONE mode keeps an OmegaConf container, attached to the source node.
            wrapped = OmegaConf.create(children, flags={"allow_objects": True})
            wrapped._set_parent(node)
            return wrapped
        # ALL / PARTIAL: a plain list is enough.
        return children

    elif node_is_dict:
        if _is_target(node):
            reserved = {"_target_", "_convert_", "_recursive_"}
            target = _resolve_target(node.get(_Keys.TARGET))
            call_kwargs = {}
            for name, field in node.items_ex(resolve=True):
                if name in reserved:
                    continue
                if recursive:
                    field = instantiate_node(field,
                                             convert=convert,
                                             recursive=recursive)
                call_kwargs[name] = _convert_node(field, convert)
            return _call_target(target, *args, **call_kwargs)

        # ALL, or PARTIAL on a non-structured dict: build a plain dict and resolve
        # interpolations eagerly. Children inherit this dict's recursive flag.
        eager = convert == ConvertMode.ALL or (
            convert == ConvertMode.PARTIAL
            and node._metadata.object_type is None)
        if eager:
            return {
                name: instantiate_node(field,
                                       convert=convert,
                                       recursive=recursive)
                for name, field in node.items_ex(resolve=True)
            }

        # Otherwise keep a DictConfig and resolve interpolations lazily.
        result = OmegaConf.create({}, flags={"allow_objects": True})
        for name, field in node.items_ex(resolve=False):
            result[name] = instantiate_node(field,
                                            convert=convert,
                                            recursive=recursive)
        result._set_parent(node)
        result._metadata.object_type = node._metadata.object_type
        return result

    else:
        assert False, f"Unexpected config type : {type(node).__name__}"
Esempio n. 24
0
def test_is_none_invalid_node() -> None:
    """Asking is_none about a key that does not exist warns and reports True."""
    empty_cfg = OmegaConf.create({})
    with warns(UserWarning):
        missing_is_none = OmegaConf.is_none(empty_cfg, "invalid")
        assert missing_is_none
Esempio n. 25
0
    def _map_merge(dest: "BaseContainer", src: "BaseContainer") -> None:
        """Merge ``src`` into ``dest`` in place.

        ``dest`` is mutated: its values, its object type and its flags are
        updated from ``src``. Nothing is returned. Both arguments must be
        DictConfig instances.

        If ``src`` is missing (``"???"``), ``dest`` becomes missing as well.
        Otherwise each key of ``src`` is merged into ``dest``: containers are
        merged recursively, value nodes are assigned, and keys absent from
        ``dest`` are copied over. ValidationError and ReadonlyConfigError raised
        while assigning a value node are re-raised through
        ``dest._format_and_raise`` with the offending key attached.
        """
        from omegaconf import DictConfig, OmegaConf, ValueNode

        assert isinstance(dest, DictConfig)
        assert isinstance(src, DictConfig)
        src_type = src._metadata.object_type

        # If the source DictConfig is missing, mark the destination as missing too.
        if src._is_missing():
            dest._set_value("???")
            return
        dest._validate_set_merge_impl(key=None, value=src, is_assign=False)

        # Give a container concrete (empty) content derived from its ref type so
        # keys can be merged into it: {} / [] for dict / list annotations, the
        # (non-Optional) ref type itself otherwise.
        def expand(node: Container) -> None:
            type_ = get_ref_type(node)
            if type_ is not None:
                _is_optional, type_ = _resolve_optional(type_)
                if is_dict_annotation(type_):
                    node._set_value({})
                elif is_list_annotation(type_):
                    node._set_value([])
                else:
                    node._set_value(type_)

        if dest._is_missing():
            expand(dest)

        for key, src_value in src.items_ex(resolve=False):

            dest_node = dest._get_node(key, validate_access=False)
            # dest holds None for this key but src brings a real value: expand the
            # destination container so src can be merged into it.
            if isinstance(dest_node, Container) and OmegaConf.is_none(
                    dest, key):
                if not OmegaConf.is_none(src_value):
                    expand(dest_node)

            # An interpolation in dest that resolves to a container is replaced by
            # that container, so the merge targets the actual content.
            if dest_node is not None:
                if dest_node._is_interpolation():
                    target_node = dest_node._dereference_node(
                        throw_on_resolution_failure=False)
                    if isinstance(target_node, Container):
                        dest[key] = target_node
                        dest_node = dest._get_node(key)

            # Merging into a new node: use element_type as a base. Only do this
            # when the key does not exist yet; replacing an existing node would
            # discard the values dest already holds for this key before src is
            # merged in.
            if (dest_node is None
                    and is_structured_config(dest._metadata.element_type)):
                dest[key] = DictConfig(content=dest._metadata.element_type,
                                       parent=dest)
                dest_node = dest._get_node(key)

            if dest_node is not None:
                if isinstance(dest_node, BaseContainer):
                    if isinstance(src_value, BaseContainer):
                        dest._validate_merge(key=key, value=src_value)
                        dest_node._merge_with(src_value)
                    else:
                        dest.__setitem__(key, src_value)
                else:
                    if isinstance(src_value, BaseContainer):
                        dest.__setitem__(key, src_value)
                    else:
                        assert isinstance(dest_node, ValueNode)
                        try:
                            dest_node._set_value(src_value)
                        except (ValidationError, ReadonlyConfigError) as e:
                            dest._format_and_raise(key=key,
                                                   value=src_value,
                                                   cause=e)
            else:
                # Key not present in dest: copy the src node over.
                from omegaconf import open_dict

                if is_structured_config(src_type):
                    # verified to be compatible above in _validate_set_merge_impl
                    # open_dict presumably lifts struct mode so the new key can be added.
                    with open_dict(dest):
                        dest[key] = src._get_node(key)
                else:
                    dest[key] = src._get_node(key)

        # dest adopts src's object type, unless it is a plain dict type.
        if src_type is not None and not is_primitive_dict(src_type):
            dest._metadata.object_type = src_type

        # explicit flags on the source config are replacing the flag values in the destination
        flags = src._metadata.flags
        assert flags is not None
        for flag, value in flags.items():
            if value is not None:
                dest._set_flag(flag, value)
Esempio n. 26
0
def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
    """
    :param config: A config object describing what to call and what params to use.
                   In addition to the parameters, the config must contain:
                   _target_ : target class or callable name (str)
                   And may contain:
                   _args_: List-like of positional arguments to pass to the target
                   _recursive_: Construct nested objects as well (bool).
                                True by default.
                                may be overridden via a _recursive_ key in
                                the kwargs
                   _convert_: Conversion strategy
                        none    : Passed objects are DictConfig and ListConfig, default
                        partial : Passed objects are converted to dict and list, with
                                  the exception of Structured Configs (and their fields).
                        all     : Passed objects are dicts, lists and primitives without
                                  a trace of OmegaConf containers
    :param args: Optional positional parameters pass-through
    :param kwargs: Optional named parameters to override
                   parameters in the config object. Parameters not present
                   in the config objects are being passed as is to the target.
                   IMPORTANT: dataclasses instances in kwargs are interpreted as config
                              and cannot be used as passthrough
    :return: if _target_ is a class name: the instantiated object
             if _target_ is a callable: the return value of the call
    :raises InstantiationException: if a TargetConf has a missing _target_, or if
             the top-level config is not a DictConfig, plain dict or Structured Config
    """
    # None in, None out.
    if config is None or OmegaConf.is_none(config):
        return None

    # Give a targeted hint for the common mistake of not annotating _target_ as str.
    if isinstance(config, TargetConf) and config._target_ == "???":
        raise InstantiationException(
            f"Missing value for {type(config).__name__}._target_. Check that it's properly annotated and overridden."
            f"\nA common problem is forgetting to annotate _target_ as a string : '_target_: str = ...'"
        )

    if isinstance(config, dict):
        config = _prepare_input_dict(config)
    kwargs = _prepare_input_dict(kwargs)

    # Plain dicts and Structured Configs are always converted to OmegaConf first.
    if is_structured_config(config) or isinstance(config, dict):
        config = OmegaConf.structured(config, flags={"allow_objects": True})

    if not OmegaConf.is_dict(config):
        raise InstantiationException(
            "Top level config has to be OmegaConf DictConfig, plain dict, or a Structured Config class or instance"
        )

    # Work on a writable deep copy so the caller's config is left as is, but
    # re-attach it to the original parent (presumably so interpolations that
    # reach outside this node still resolve).
    working = copy.deepcopy(config)
    working._set_flag(flags=["allow_objects", "struct", "readonly"],
                      values=[True, False, False])
    working._set_parent(config._get_parent())

    if kwargs:
        working = OmegaConf.merge(working, kwargs)

    # The top-level mode keys are consumed here rather than passed to the target.
    recursive_flag = working.pop(_Keys.RECURSIVE, True)
    convert_mode = working.pop(_Keys.CONVERT, ConvertMode.NONE)

    return instantiate_node(working,
                            *args,
                            recursive=recursive_flag,
                            convert=convert_mode)
Esempio n. 27
0
 def _param_in_cfg(self, param):
     """Return whether ``param`` holds a non-None value in ``self.cfg``."""
     # open_dict presumably relaxes struct mode so querying an unknown key does
     # not raise; verify against OmegaConf's struct-flag semantics.
     with open_dict(self.cfg):
         param_is_none = OmegaConf.is_none(self.cfg, param)
     return not param_is_none
Esempio n. 28
0
def _is_target(x: Any) -> bool:
    if isinstance(x, dict):
        return "_target_" in x
    if OmegaConf.is_dict(x) and not OmegaConf.is_none(x):
        return "_target_" in x
    return False
Esempio n. 29
0
def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
    """
    :param config: A config object describing what to call and what params to use.
                   In addition to the parameters, the config must contain:
                   _target_ : target class or callable name (str)
                   _recursive_: Construct nested objects as well (bool).
                                True by default.
                                may be overridden via a _recursive_ key in
                                the kwargs
    :param args: Optional positional parameters pass-through
    :param kwargs: Optional named parameters to override
                   parameters in the config object. Parameters not present
                   in the config objects are being passed as is to the target.
    :return: if _target_ is a class name: the instantiated object
             if _target_ is a callable: the return value of the call
    :raises InstantiationException: if config is a TargetConf whose _target_ is missing
    :raises HydraException: if config is not a dict, OmegaConf config or Structured Config
    :raises Exception: any error raised while building the kwargs or calling the
             target is re-raised with the same type, its message prefixed with the
             target name and chained to the original exception
    """

    if OmegaConf.is_none(config):
        return None

    if isinstance(config, TargetConf) and config._target_ == "???":
        # Specific check to give a good warning about failure to annotate _target_ as a string.
        raise InstantiationException(
            f"Missing value for {type(config).__name__}._target_. Check that it's properly annotated and overridden."
            f"\nA common problem is forgetting to annotate _target_ as a string : '_target_: str = ...'"
        )

    if not (
        isinstance(config, dict)
        or OmegaConf.is_config(config)
        or is_structured_config(config)
    ):
        raise HydraException(f"Unsupported config type : {type(config).__name__}")

    if isinstance(config, dict):
        configc = config.copy()
        _convert_container_targets_to_strings(configc)
        config = configc

    kwargsc = kwargs.copy()
    _convert_container_targets_to_strings(kwargsc)
    kwargs = kwargsc

    # make a copy to ensure we do not change the provided object
    config_copy = OmegaConf.structured(config, flags={"allow_objects": True})
    if OmegaConf.is_config(config):
        config_copy._set_parent(config._get_parent())
    config = config_copy
    assert OmegaConf.is_config(config)
    OmegaConf.set_readonly(config, False)
    OmegaConf.set_struct(config, False)

    target = _get_target_type(config, kwargs)
    try:
        config._set_flag("allow_objects", True)
        final_kwargs = _get_kwargs(config, **kwargs)
        return target(*args, **final_kwargs)
    except Exception as e:
        message = (
            f"Error instantiating/calling '{_convert_target_to_string(target)}' : {e}"
        )
        # Not every exception type can be built from a single message argument
        # (e.g. UnicodeDecodeError needs five). In that case re-raise the original
        # error rather than masking it with an unrelated TypeError.
        try:
            wrapped = type(e)(message)
        except Exception:
            wrapped = None
        if wrapped is None:
            raise
        # Keep the original traceback and chain the cause explicitly.
        raise wrapped.with_traceback(e.__traceback__) from e