예제 #1
0
 def _get_value_element_as_str(value: ParsedElementType,
                               space_after_sep: bool = False) -> str:
     """
     Render a parsed override value element as a string.

     :param value: the element to render: None, a QuotedString, a list, a dict,
                   a str/int/bool/float, or a structured config.
     :param space_after_sep: when True, a space follows the "," and ":" separators
                             used inside lists and dicts.
     :return: the string form of the value ("null" for None, "[...]" for lists,
              "{...}" for dicts).
     :raises AssertionError: if the value is of none of the supported types.
     """
     # str, QuotedString, int, bool, float, List[Any], Dict[str, Any]
     comma = ", " if space_after_sep else ","
     colon = ": " if space_after_sep else ":"
     if value is None:
         return "null"
     elif isinstance(value, QuotedString):
         return value.with_quotes()
     elif isinstance(value, list):
         s = comma.join([
             Override._get_value_element_as_str(
                 x, space_after_sep=space_after_sep) for x in value
         ])
         return "[" + s + "]"
     elif isinstance(value, dict):
         s = comma.join([
             f"{k}{colon}{Override._get_value_element_as_str(v, space_after_sep=space_after_sep)}"
             for k, v in value.items()
         ])
         return "{" + s + "}"
     elif isinstance(value, (str, int, bool, float)):
         return str(value)
     elif is_structured_config(value):
         # Convert the structured config to a plain container and render that.
         # NOTE(review): space_after_sep is not forwarded here, so the nested
         # rendering always uses the compact separators - confirm this is intended.
         return Override._get_value_element_as_str(
             OmegaConf.to_container(OmegaConf.structured(value)))
     else:
         # Explicit raise instead of `assert False`, which is stripped under -O
         # and would let the function fall through and return None.
         raise AssertionError(
             f"Unsupported value type: {type(value).__name__}")
예제 #2
0
 def _get_value_element_as_str(value: ParsedElementType,
                               space_after_sep: bool = False) -> str:
     """
     Render a parsed override value element as a string.

     :param value: the element to render: None, a QuotedString, a list, a dict,
                   a str, an int/bool/float, or a structured config.
     :param space_after_sep: when True, a space follows the "," and ":" separators
                             used inside lists and dicts.
     :return: the string form of the value ("null" for None, "[...]" for lists,
              "{...}" for dicts, escaped text for plain strings).
     :raises AssertionError: if the value is of none of the supported types.
     """
     # str, QuotedString, int, bool, float, List[Any], Dict[str, Any]
     comma = ", " if space_after_sep else ","
     colon = ": " if space_after_sep else ":"
     if value is None:
         return "null"
     elif isinstance(value, QuotedString):
         return value.with_quotes()
     elif isinstance(value, list):
         s = comma.join([
             Override._get_value_element_as_str(
                 x, space_after_sep=space_after_sep) for x in value
         ])
         return "[" + s + "]"
     elif isinstance(value, dict):
         str_items = []
         for k, v in value.items():
             # Keys are rendered through the same function (default separators).
             str_key = Override._get_value_element_as_str(k)
             str_value = Override._get_value_element_as_str(
                 v, space_after_sep=space_after_sep)
             str_items.append(f"{str_key}{colon}{str_value}")
         return "{" + comma.join(str_items) + "}"
     elif isinstance(value, str):
         return escape_special_characters(value)
     elif isinstance(value, (int, bool, float)):
         return str(value)
     elif is_structured_config(value):
         # Convert the structured config to a plain container and render that.
         return Override._get_value_element_as_str(
             OmegaConf.to_container(OmegaConf.structured(value)))
     else:
         # Explicit raise instead of `assert False`, which is stripped under -O
         # and would let the function fall through and return None.
         raise AssertionError(
             f"Unsupported value type: {type(value).__name__}")
예제 #3
0
def _is_passthrough(type_: Type[Any]) -> bool:
    """
    Decide whether a type should be flagged as passthrough.

    :param type_: the type to inspect.
    :return: False for Any, for subclasses of int/float/str/bool/Enum, and for
             structured configs that OmegaConf.structured() accepts;
             True for everything else, including structured configs that fail
             with a ValidationError.
    """
    if type_ is Any:
        return False
    if issubclass(type_, (int, float, str, bool, Enum)):
        return False

    if not is_structured_config(type_):
        return True

    try:
        # Building a DictConfig is the check that the structured config is legal.
        OmegaConf.structured(type_)
    except ValidationError as e:
        log.debug(
            f"Failed to create DictConfig from ({type_.__name__}) : {e}, flagging as passthrough"
        )
        return True
    else:
        return False
예제 #4
0
def is_incompatible(type_: Type[Any]) -> bool:
    """
    Decide whether a type annotation is flagged as incompatible.

    :param type_: the annotation to inspect. An Optional wrapper is resolved
                  first and the wrapped type is the one inspected.
    :return: True for unions other than Optional, for dict annotations whose key
             type is not a str/Enum subclass, for containers/callables with an
             incompatible element, parameter or return type, for structured
             configs rejected by OmegaConf.structured(), and for any type not
             otherwise recognized. False for NoneType, bare tuple/list/dict, Any,
             int/float/str/bool/Enum subclasses and valid structured configs.
    """
    opt = _resolve_optional(type_)
    # Unions are not supported (Except Optional)
    if not opt[0] and _is_union(type_):
        return True

    type_ = opt[1]
    if type_ in (type(None), tuple, list, dict):
        return False

    try:
        if is_list_annotation(type_):
            lt = get_list_element_type(type_)
            return is_incompatible(lt)
        if is_dict_annotation(type_):
            kvt = get_dict_key_value_types(type_)
            if not issubclass(kvt[0], (str, Enum)):
                return True
            return is_incompatible(kvt[1])
        if is_tuple_annotation(type_):
            # `...` (as in Tuple[int, ...]) is a marker, not a type to check.
            return any(arg is not ... and is_incompatible(arg)
                       for arg in type_.__args__)
        if get_origin(type_) is Callable:
            args = get_args(type_)
            params = args[0]
            # For Callable[..., R] the standard typing.get_args returns the
            # Ellipsis object itself in place of the parameter list; it is not
            # iterable, so only iterate when a real parameter list is present.
            if params is not ...:
                for arg in params:
                    if arg is not ... and is_incompatible(arg):
                        return True
            return is_incompatible(args[1])

    except ValidationError:
        # A validation failure while inspecting the annotation means incompatible.
        return True

    if type_ is Any or issubclass(type_, (int, float, str, bool, Enum)):
        return False
    if is_structured_config(type_):
        try:
            OmegaConf.structured(type_)  # verify it's actually legal
        except ValidationError as e:
            log.debug(
                f"Failed to create DictConfig from ({type_.__name__}) : {e}, flagging as incompatible"
            )
            return True
        return False
    return True
예제 #5
0
def test_merge(inputs: Any, expected: Any) -> None:
    """
    Merge the given inputs and compare against the expectation.

    :param inputs: iterable of objects, each turned into a config with OmegaConf.create.
    :param expected: either the expected merge result (dict, list or structured
                     config), or an object used as a context manager around the
                     merge (presumably an exception expectation - confirm with
                     the parametrization).
    """
    configs = [OmegaConf.create(c) for c in inputs]

    expects_value = (isinstance(expected, (dict, list))
                     or is_structured_config(expected))
    if not expects_value:
        with expected:
            OmegaConf.merge(*configs)
        return

    merged = OmegaConf.merge(*configs)
    assert merged == expected

    # The merge must leave every input config untouched.
    # Containers are compared unresolved to avoid resolution errors while comparing.
    for idx, cfg in enumerate(configs):
        pristine = OmegaConf.create(inputs[idx])
        before = OmegaConf.to_container(pristine, resolve=False)
        after = OmegaConf.to_container(cfg, resolve=False)
        assert before == after
예제 #6
0
파일: utils.py 프로젝트: wpc/hydra
def call(config: Any, *args: Any, **kwargs: Any) -> Any:
    """
    Call or instantiate the target described by a config.

    :param config: object describing what to call and with which parameters;
                   must have a _target_ field. May be a dict, an OmegaConf
                   config or a structured config.
    :param args: optional positional parameters, passed through
    :param kwargs: optional named parameters, passed through
    :return: None if the config is none, otherwise the return value of the
             specified class or callable
    """
    if OmegaConf.is_none(config):
        return None

    # Dedicated check so that a _target_ left unannotated gets a helpful message.
    if isinstance(config, TargetConf) and config._target_ == "???":
        raise InstantiationException(
            f"Missing value for {type(config).__name__}._target_. Check that it's properly annotated and overridden."
            f"\nA common problem is forgetting to annotate _target_ as a string : '_target_: str = ...'"
        )

    supported = (isinstance(config, dict) or OmegaConf.is_config(config)
                 or is_structured_config(config))
    if not supported:
        raise HydraException(
            f"Unsupported config type : {type(config).__name__}")

    # Work on a copy so the caller's object is never modified.
    original = config
    config = OmegaConf.structured(original)
    if OmegaConf.is_config(original):
        config._set_parent(original._get_parent())

    cls = "<unknown>"
    try:
        assert isinstance(config, DictConfig)
        OmegaConf.set_readonly(config, False)
        OmegaConf.set_struct(config, False)
        cls = _get_cls_name(config)
        target = _locate(cls)
        if not isinstance(target, type):
            assert callable(target)
            return _call_callable(target, config, *args, **kwargs)
        return _instantiate_class(target, config, *args, **kwargs)
    except InstantiationException:
        # Already carries a meaningful message; propagate unchanged.
        raise
    except Exception as e:
        raise HydraException(f"Error calling '{cls}' : {e}") from e
예제 #7
0
def check_node_metadata(
    node: Container,
    type_hint: Any,
    key_type: Any,
    elt_type: Any,
    obj_type: Any,
) -> None:
    """
    Assert that a node's metadata matches the expected typing information.

    :param node: the container node whose _metadata is checked.
    :param type_hint: type hint; resolved with _resolve_optional into the
                      expected optional flag and ref_type.
    :param key_type: expected metadata key_type.
    :param elt_type: expected metadata element_type.
    :param obj_type: expected metadata object_type.
    """
    expected_optional, expected_ref_type = _resolve_optional(type_hint)

    meta = node._metadata
    assert meta.optional == expected_optional
    assert meta.ref_type == expected_ref_type
    assert meta.key_type == key_type
    assert meta.element_type == elt_type
    assert meta.object_type == obj_type

    # The node class must agree with the kind of the resolved ref type.
    expects_dict = (is_dict_annotation(expected_ref_type)
                    or is_structured_config(expected_ref_type))
    if expects_dict:
        assert isinstance(node, DictConfig)
    elif is_list_annotation(expected_ref_type):
        assert isinstance(node, ListConfig)
예제 #8
0
def test_merge(
    inputs: Any,
    expected: Any,
    merge_function: Any,
    input_unchanged: bool,
) -> None:
    """
    Merge the given inputs with merge_function and compare against the expectation.

    :param inputs: iterable of objects, each turned into a config with OmegaConf.create.
    :param expected: either the expected merge result (mutable mapping, mutable
                     sequence or structured config), or an object used as a
                     context manager around the merge (presumably an exception
                     expectation - confirm with the parametrization).
    :param merge_function: callable invoked as merge_function(*configs).
    :param input_unchanged: when True, additionally assert that the merge left
                            the input configs unmodified.
    """
    configs = [OmegaConf.create(c) for c in inputs]

    expects_value = (isinstance(expected, (MutableMapping, MutableSequence))
                     or is_structured_config(expected))
    if not expects_value:
        with expected:
            merge_function(*configs)
        return

    merged = merge_function(*configs)
    assert merged == expected

    if not input_unchanged:
        return

    # The merge must leave every input config untouched.
    # Containers are compared unresolved to avoid resolution errors while comparing.
    for idx, cfg in enumerate(configs):
        pristine = OmegaConf.create(inputs[idx])
        before = OmegaConf.to_container(pristine, resolve=False)
        after = OmegaConf.to_container(cfg, resolve=False)
        assert before == after
예제 #9
0
def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
    """
    Instantiate or call the target described by a config.

    :param config: A config object describing what to call and what params to use.
                   In addition to the parameters, the config must contain:
                   _target_ : target class or callable name (str)
                   And may contain:
                   _args_: List-like of positional arguments to pass to the target
                   _recursive_: Construct nested objects as well (bool).
                                True by default.
                                may be overridden via a _recursive_ key in
                                the kwargs
                   _convert_: Conversion strategy
                        none    : Passed objects are DictConfig and ListConfig, default
                        partial : Passed objects are converted to dict and list, with
                                  the exception of Structured Configs (and their fields).
                        all     : Passed objects are dicts, lists and primitives without
                                  a trace of OmegaConf containers
    :param args: Optional positional parameters pass-through
    :param kwargs: Optional named parameters to override
                   parameters in the config object. Parameters not present
                   in the config objects are being passed as is to the target.
                   IMPORTANT: dataclasses instances in kwargs are interpreted as config
                              and cannot be used as passthrough
    :return: None if config is None;
             if _target_ is a class name: the instantiated object
             if _target_ is a callable: the return value of the call
    """

    # Nothing to instantiate
    if config is None or OmegaConf.is_none(config):
        return None

    # Dedicated check so that a _target_ left unannotated gets a helpful message.
    if isinstance(config, TargetConf) and config._target_ == "???":
        raise InstantiationException(
            f"Missing value for {type(config).__name__}._target_. Check that it's properly annotated and overridden."
            f"\nA common problem is forgetting to annotate _target_ as a string : '_target_: str = ...'"
        )

    if isinstance(config, dict):
        config = _prepare_input_dict(config)

    kwargs = _prepare_input_dict(kwargs)

    # Plain dicts and Structured Configs are first turned into OmegaConf objects
    if is_structured_config(config) or isinstance(config, dict):
        config = OmegaConf.structured(config, flags={"allow_objects": True})

    if not OmegaConf.is_dict(config):
        raise InstantiationException(
            "Top level config has to be OmegaConf DictConfig, plain dict, or a Structured Config class or instance"
        )

    # Work on a writable deep copy that keeps the original parent, so the
    # caller's config is left untouched.
    working = copy.deepcopy(config)
    working._set_flag(flags=["allow_objects", "struct", "readonly"],
                      values=[True, False, False])
    working._set_parent(config._get_parent())

    if kwargs:
        working = OmegaConf.merge(working, kwargs)

    recursive = working.pop(_Keys.RECURSIVE, True)
    convert = working.pop(_Keys.CONVERT, ConvertMode.NONE)

    return instantiate_node(working,
                            *args,
                            recursive=recursive,
                            convert=convert)
예제 #10
0
파일: utils.py 프로젝트: zhaodan2000/hydra
def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
    """
    :param config: An config object describing what to call and what params to use.
                   In addition to the parameters, the config must contain:
                   _target_ : target class or callable name (str)
                   And may contain:
                   _recursive_: Construct nested objects as well (bool).
                                True by default.
                                may be overridden via a _recursive_ key in
                                the kwargs
                   _convert_: Conversion strategy
                        none    : Passed objects are DictConfig and ListConfig, default
                        partial : Passed objects are converted to dict and list, with
                                  the exception of Structured Configs (and their fields).
                        all     : Passed objects are dicts, lists and primitives without
                                  a trace of OmegaConf containers
    :param args: Optional positional parameters pass-through
    :param kwargs: Optional named parameters to override
                   parameters in the config object. Parameters not present
                   in the config objects are being passed as is to the target.
    :return: if _target_ is a class name: the instantiated object
             if _target_ is a callable: the return value of the call
    :raises InstantiationException: if a TargetConf has a missing _target_
    :raises HydraException: if the config type is not supported
    :raises Exception: any error raised while building the arguments or calling
                       the target is re-raised as the same exception type with
                       the target name prepended to the message (or unchanged if
                       that type cannot be built from a single message)
    """

    if OmegaConf.is_none(config):
        return None

    if isinstance(config, TargetConf) and config._target_ == "???":
        # Specific check to give a good warning about failure to annotate _target_ as a string.
        raise InstantiationException(
            f"Missing value for {type(config).__name__}._target_. Check that it's properly annotated and overridden."
            f"\nA common problem is forgetting to annotate _target_ as a string : '_target_: str = ...'"
        )

    if not (
        isinstance(config, dict)
        or OmegaConf.is_config(config)
        or is_structured_config(config)
    ):
        raise HydraException(f"Unsupported config type : {type(config).__name__}")

    # Shallow-copy plain dicts before converting targets so the caller's dict
    # object is not the one handed to the conversion helper.
    if isinstance(config, dict):
        configc = config.copy()
        _convert_container_targets_to_strings(configc)
        config = configc

    kwargsc = kwargs.copy()
    _convert_container_targets_to_strings(kwargsc)
    kwargs = kwargsc

    # make a copy to ensure we do not change the provided object
    config_copy = OmegaConf.structured(config, flags={"allow_objects": True})
    if OmegaConf.is_config(config):
        config_copy._set_parent(config._get_parent())
    config = config_copy
    assert OmegaConf.is_config(config)
    OmegaConf.set_readonly(config, False)
    OmegaConf.set_struct(config, False)

    target = _get_target_type(config, kwargs)
    try:
        config._set_flag("allow_objects", True)
        final_kwargs = _get_kwargs(config, **kwargs)
        convert = _pop_convert_mode(final_kwargs)
        if convert == ConvertMode.PARTIAL:
            final_kwargs = OmegaConf.to_container(
                final_kwargs, resolve=True, exclude_structured_configs=True
            )
            return target(*args, **final_kwargs)
        elif convert == ConvertMode.ALL:
            final_kwargs = OmegaConf.to_container(final_kwargs, resolve=True)
            return target(*args, **final_kwargs)
        elif convert == ConvertMode.NONE:
            return target(*args, **final_kwargs)
        else:
            # Explicit raise instead of `assert False`, which is stripped under
            # -O and would make the function silently return None.
            raise AssertionError(f"Unexpected convert mode: {convert}")
    except Exception as e:
        message = (
            f"Error instantiating/calling '{_convert_target_to_string(target)}' : {e}"
        )
        # Not every exception type can be built from a single message argument
        # (e.g. UnicodeDecodeError). Attempting it blindly would replace the real
        # error with a TypeError, so fall back to the original exception.
        try:
            wrapped = type(e)(message)
        except Exception:
            wrapped = None
        if wrapped is None:
            raise
        # preserve the original exception backtrace and chain the cause
        raise wrapped.with_traceback(e.__traceback__) from e
예제 #11
0
def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
    """
    Instantiate or call the target described by a config.

    :param config: A config object describing what to call and what params to use.
                   In addition to the parameters, the config must contain:
                   _target_ : target class or callable name (str)
                   And may contain:
                   _args_: List-like of positional arguments to pass to the target
                   _recursive_: Construct nested objects as well (bool).
                                True by default.
                                may be overridden via a _recursive_ key in
                                the kwargs
                   _convert_: Conversion strategy
                        none    : Passed objects are DictConfig and ListConfig, default
                        partial : Passed objects are converted to dict and list, with
                                  the exception of Structured Configs (and their fields).
                        all     : Passed objects are dicts, lists and primitives without
                                  a trace of OmegaConf containers
                   _partial_: If True, return functools.partial wrapped method or object
                              False by default. Configure per target.
    :param args: Optional positional parameters pass-through
    :param kwargs: Optional named parameters to override
                   parameters in the config object. Parameters not present
                   in the config objects are being passed as is to the target.
                   IMPORTANT: dataclasses instances in kwargs are interpreted as config
                              and cannot be used as passthrough
    :return: None if config is None;
             if _target_ is a class name: the instantiated object
             if _target_ is a callable: the return value of the call
    """

    # Nothing to instantiate
    if config is None:
        return None

    # Dedicated check so that a _target_ left unannotated gets a helpful message.
    if isinstance(config, TargetConf) and config._target_ == "???":
        missing_target_msg = dedent(f"""\
                Config has missing value for key `_target_`, cannot instantiate.
                Config type: {type(config).__name__}
                Check that the `_target_` key in your dataclass is properly annotated and overridden.
                A common problem is forgetting to annotate _target_ as a string : '_target_: str = ...'"""
                                    )
        raise InstantiationException(missing_target_msg)

    if isinstance(config, (dict, list)):
        config = _prepare_input_dict_or_list(config)

    kwargs = _prepare_input_dict_or_list(kwargs)

    # Plain containers and Structured Configs are first turned into OmegaConf objects
    if is_structured_config(config) or isinstance(config, (dict, list)):
        config = OmegaConf.structured(config, flags={"allow_objects": True})

    config_is_dict = OmegaConf.is_dict(config)
    if not (config_is_dict or OmegaConf.is_list(config)):
        raise InstantiationException(
            dedent(f"""\
                Cannot instantiate config of type {type(config).__name__}.
                Top level config must be an OmegaConf DictConfig/ListConfig object,
                a plain dict/list, or a Structured Config class or instance."""
                   ))

    # Work on a writable deep copy that keeps the original parent, so the
    # caller's config is left untouched.
    source = config
    config = copy.deepcopy(source)
    config._set_flag(flags=["allow_objects", "struct", "readonly"],
                     values=[True, False, False])
    config._set_parent(source._get_parent())

    if config_is_dict:
        # Dict configs absorb kwargs as overrides; the options come from the config.
        if kwargs:
            config = OmegaConf.merge(config, kwargs)

        OmegaConf.resolve(config)

        recursive = config.pop(_Keys.RECURSIVE, True)
        convert = config.pop(_Keys.CONVERT, ConvertMode.NONE)
        want_partial = config.pop(_Keys.PARTIAL, False)
    else:
        # List configs cannot hold the options themselves; they come from kwargs.
        OmegaConf.resolve(config)

        recursive = kwargs.pop(_Keys.RECURSIVE, True)
        convert = kwargs.pop(_Keys.CONVERT, ConvertMode.NONE)
        want_partial = kwargs.pop(_Keys.PARTIAL, False)

        if want_partial:
            raise InstantiationException(
                "The _partial_ keyword is not compatible with top-level list instantiation"
            )

    return instantiate_node(config,
                            *args,
                            recursive=recursive,
                            convert=convert,
                            partial=want_partial)
예제 #12
0
파일: utils.py 프로젝트: EspenHa/hydra
def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
    """
    :param config: An config object describing what to call and what params to use.
                   In addition to the parameters, the config must contain:
                   _target_ : target class or callable name (str)
                   _recursive_: Construct nested objects as well (bool).
                                True by default.
                                may be overridden via a _recursive_ key in
                                the kwargs
    :param args: Optional positional parameters pass-through
    :param kwargs: Optional named parameters to override
                   parameters in the config object. Parameters not present
                   in the config objects are being passed as is to the target.
    :return: if _target_ is a class name: the instantiated object
             if _target_ is a callable: the return value of the call
    :raises InstantiationException: if a TargetConf has a missing _target_
    :raises HydraException: if the config type is not supported
    :raises Exception: any error raised while building the arguments or calling
                       the target is re-raised as the same exception type with
                       the target name prepended to the message (or unchanged if
                       that type cannot be built from a single message)
    """

    if OmegaConf.is_none(config):
        return None

    if isinstance(config, TargetConf) and config._target_ == "???":
        # Specific check to give a good warning about failure to annotate _target_ as a string.
        raise InstantiationException(
            f"Missing value for {type(config).__name__}._target_. Check that it's properly annotated and overridden."
            f"\nA common problem is forgetting to annotate _target_ as a string : '_target_: str = ...'"
        )

    if not (
        isinstance(config, dict)
        or OmegaConf.is_config(config)
        or is_structured_config(config)
    ):
        raise HydraException(f"Unsupported config type : {type(config).__name__}")

    # Shallow-copy plain dicts before converting targets so the caller's dict
    # object is not the one handed to the conversion helper.
    if isinstance(config, dict):
        configc = config.copy()
        _convert_container_targets_to_strings(configc)
        config = configc

    kwargsc = kwargs.copy()
    _convert_container_targets_to_strings(kwargsc)
    kwargs = kwargsc

    # make a copy to ensure we do not change the provided object
    config_copy = OmegaConf.structured(config, flags={"allow_objects": True})
    if OmegaConf.is_config(config):
        config_copy._set_parent(config._get_parent())
    config = config_copy
    assert OmegaConf.is_config(config)
    OmegaConf.set_readonly(config, False)
    OmegaConf.set_struct(config, False)

    target = _get_target_type(config, kwargs)
    try:
        config._set_flag("allow_objects", True)
        final_kwargs = _get_kwargs(config, **kwargs)
        return target(*args, **final_kwargs)
    except Exception as e:
        message = (
            f"Error instantiating/calling '{_convert_target_to_string(target)}' : {e}"
        )
        # Not every exception type can be built from a single message argument
        # (e.g. UnicodeDecodeError). Attempting it blindly would replace the real
        # error with a TypeError, so fall back to the original exception.
        try:
            wrapped = type(e)(message)
        except Exception:
            wrapped = None
        if wrapped is None:
            raise
        # Keep the original backtrace and chain the cause for easier debugging.
        raise wrapped.with_traceback(e.__traceback__) from e