コード例 #1
0
ファイル: overrides_visitor.py プロジェクト: wpc/hydra
    def visitFunction(self, ctx: OverrideParser.FunctionContext) -> Any:
        """Evaluate a function call node such as ``name(a, b, key=value)``.

        The children of ``ctx`` are consumed in order: the function name, the
        opening parenthesis, then positional elements, ``ArgNameContext``
        nodes (each immediately followed by its value element) and commas,
        until the closing parenthesis. The collected call is evaluated with
        ``self.functions.eval``.

        :param ctx: the function parse-tree node.
        :return: the result of evaluating the function.
        :raises HydraException: if a positional argument follows a keyword
            argument, if the same keyword argument is given twice, or if
            evaluation fails (the original exception is chained as the cause).
        """
        args = []
        kwargs = {}
        children = ctx.getChildren()
        func_name = next(children).getText()
        assert self.is_matching_terminal(next(children), OverrideLexer.POPEN)
        in_kwargs = False
        while True:
            cur = next(children)
            if self.is_matching_terminal(cur, OverrideLexer.PCLOSE):
                break

            if isinstance(cur, OverrideParser.ArgNameContext):
                in_kwargs = True
                name = cur.getChild(0).getText()
                if name in kwargs:
                    # Python rejects f(a=1, a=2); silently keeping only the
                    # last value would drop user input without notice.
                    raise HydraException(f"keyword argument repeated: {name}")
                cur = next(children)
                kwargs[name] = self.visitElement(cur)
            else:
                if self.is_matching_terminal(cur, OverrideLexer.COMMA):
                    continue
                if in_kwargs:
                    raise HydraException("positional argument follows keyword argument")
                args.append(self.visitElement(cur))

        function = FunctionCall(name=func_name, args=args, kwargs=kwargs)
        try:
            return self.functions.eval(function)
        except Exception as e:
            raise HydraException(
                f"{type(e).__name__} while evaluating '{ctx.getText()}': {e}"
            ) from e
コード例 #2
0
    def _apply_overrides_to_config(overrides: List[str],
                                   cfg: DictConfig) -> None:
        """Apply override strings to ``cfg`` in place.

        Each string is parsed with ``ConfigLoaderImpl._parse_config_override``
        and its value (if any) is parsed as YAML. Then:

        * delete (``override.is_delete()``): remove the key; when a value is
          given it must equal the key's current value.
        * add (``override.is_add()``): set the key; fails if
          ``OmegaConf.select`` already finds a non-None value there.
        * otherwise: update the key, which must already exist.

        :param overrides: raw override strings, applied in order.
        :param cfg: the config to modify.
        :raises HydraException: if a delete target does not exist or holds a
            different value, if an add target already exists, if an update
            target has no match, or wrapping any OmegaConfBaseException.
        """
        loader = _utils.get_yaml_loader()

        # NOTE(review): yaml.load on command-line supplied text; confirm that
        # _utils.get_yaml_loader() returns a safe (non-object-constructing) loader.
        def get_value(val_: Optional[str]) -> Any:
            return yaml.load(val_, Loader=loader) if val_ is not None else None

        for line in overrides:
            override = ConfigLoaderImpl._parse_config_override(line)
            try:
                # YAML parse errors are not OmegaConfBaseException, so they
                # escape the handler below unwrapped.
                value = get_value(override.value)
                if override.is_delete():
                    val = OmegaConf.select(cfg,
                                           override.key,
                                           throw_on_missing=False)
                    # NOTE(review): presumably select() also returns None for a
                    # key that exists with a null value, so such keys cannot be
                    # deleted and are reported as missing — confirm.
                    if val is None:
                        raise HydraException(
                            f"Could not delete from config. '{override.key}' does not exist."
                        )
                    elif value is not None and value != val:
                        raise HydraException(
                            f"Could not delete from config."
                            f" The value of '{override.key}' is {val} and not {override.value}."
                        )

                    key = override.key
                    # A dotted key is deleted from its parent node, selected by
                    # everything before the last dot.
                    last_dot = key.rfind(".")
                    with open_dict(cfg):
                        if last_dot == -1:
                            del cfg[key]
                        else:
                            node = OmegaConf.select(cfg, key[0:last_dot])
                            del node[key[last_dot + 1:]]

                elif override.is_add():
                    if (OmegaConf.select(cfg,
                                         override.key,
                                         throw_on_missing=False) is None):
                        # open_dict allows new keys even if cfg is in struct mode.
                        with open_dict(cfg):
                            OmegaConf.update(cfg, override.key, value)
                    else:
                        raise HydraException(
                            f"Could not append to config. An item is already at '{override.key}'."
                        )
                else:
                    # Plain update: without open_dict, a missing key raises,
                    # and that is translated into a hint to use '+'.
                    try:
                        OmegaConf.update(cfg, override.key, value)
                    except (ConfigAttributeError, ConfigKeyError) as ex:
                        raise HydraException(
                            f"Could not override '{override.key}'. No match in config."
                            f"\nTo append to your config use +{line}") from ex
            except OmegaConfBaseException as ex:
                raise HydraException(f"Error merging override {line}") from ex
コード例 #3
0
ファイル: overrides_visitor.py プロジェクト: wpc/hydra
 def syntaxError(
     self,
     recognizer: Any,
     offending_symbol: Any,
     line: Any,
     column: Any,
     msg: Any,
     e: Any,
 ) -> None:
     """Error-listener callback: report a syntax error as a HydraException.

     The message is ``msg`` when it is provided, otherwise ``str(e)``.
     ``e`` is chained as the cause in both cases.

     :raises HydraException: always.
     """
     text = str(e) if msg is None else msg
     raise HydraException(text) from e
コード例 #4
0
ファイル: initialize.py プロジェクト: zhaodan2000/hydra
    def __init__(
        self,
        config_path: Optional[str] = None,
        job_name: Optional[str] = None,
        strict: Optional[bool] = None,
        caller_stack_depth: int = 1,
    ) -> None:
        """Set up Hydra for the file or module that called this initializer.

        :param config_path: config directory; must be relative (or None).
        :param job_name: job name; when None it is detected from the calling
            file/module with ``detect_task_name``.
        :param strict: passed through unchanged to
            ``Hydra.create_main_hydra_file_or_module``.
        :param caller_stack_depth: stack depth used to find the caller; it is
            incremented by one before frame detection.
        :raises HydraException: if ``config_path`` is an absolute path.
        """
        # NOTE(review): backup of global Hydra state, presumably restored by
        # whoever owns this object — the restore is not visible here.
        self._gh_backup = get_gh_backup()

        path_is_absolute = config_path is not None and os.path.isabs(config_path)
        if path_is_absolute:
            raise HydraException(
                "config_path in initialize() must be relative")

        calling_file, calling_module = detect_calling_file_or_module_from_stack_frame(
            caller_stack_depth + 1)
        task_name = job_name if job_name is not None else detect_task_name(
            calling_file=calling_file, calling_module=calling_module)

        Hydra.create_main_hydra_file_or_module(
            calling_file=calling_file,
            calling_module=calling_module,
            config_path=config_path,
            job_name=task_name,
            strict=strict,
        )
コード例 #5
0
    def split_arguments(
            overrides: List[Override],
            max_batch_size: Optional[int]) -> List[List[List[str]]]:
        """Expand overrides into the cartesian product of their values.

        A non-sweep override contributes exactly one ``key=value`` string; a
        discrete sweep contributes one ``key=value`` string per swept value.

        :param overrides: parsed overrides.
        :param max_batch_size: None to return all jobs as a single batch;
            otherwise the jobs are chunked by
            ``BasicSweeper.split_overrides_to_chunks`` (presumably into batches
            of at most ``max_batch_size`` jobs).
        :return: a list of batches; each batch is a list of jobs; each job is
            a list of ``key=value`` strings.
        :raises HydraException: for sweep types other than discrete sweeps.
        """
        per_key_options: List[List[str]] = []
        for ov in overrides:
            if not ov.is_sweep_override():
                per_key_options.append(
                    [f"{ov.get_key_element()}={ov.get_value_element_as_str()}"])
                continue
            if not ov.is_discrete_sweep():
                assert ov.value_type is not None
                raise HydraException(
                    f"{BasicSweeper.__name__} does not support sweep type : {ov.value_type.name}"
                )
            sweep_key = ov.get_key_element()
            per_key_options.append(
                [f"{sweep_key}={choice}" for choice in ov.sweep_string_iterator()])

        jobs = [list(combo) for combo in itertools.product(*per_key_options)]
        assert max_batch_size is None or max_batch_size > 0
        if max_batch_size is None:
            return [jobs]
        return list(BasicSweeper.split_overrides_to_chunks(jobs, max_batch_size))
コード例 #6
0
    def register(self, name: str, func: Callable[..., Any]) -> None:
        """Register ``func`` under ``name``.

        Stores the callable in ``self.functions`` and its
        ``inspect.signature`` in ``self.definitions``.

        :raises HydraException: if ``name`` is already registered.
        """
        if name not in self.definitions:
            self.definitions[name] = inspect.signature(func)
            self.functions[name] = func
            return
        raise HydraException(
            f"Function named '{name}' is already registered")
コード例 #7
0
    def __init__(
        self,
        config_path: Optional[str] = _UNSPECIFIED_,
        job_name: Optional[str] = None,
        caller_stack_depth: int = 1,
    ) -> None:
        """Set up Hydra for the file or module that called this initializer.

        :param config_path: config directory; must be relative (or None).
            Leaving it unspecified is deprecated: a deprecation warning is
            issued and "." is used.
        :param job_name: job name; when None it is detected from the calling
            file/module with ``detect_task_name``.
        :param caller_stack_depth: stack depth used to find the caller; it is
            incremented by one before frame detection.
        :raises HydraException: if ``config_path`` is an absolute path.
        """
        self._gh_backup = get_gh_backup()

        # DEPRECATED: remove in 1.2
        # In 1.2 the default config_path should become None; until then an
        # unspecified config_path warns and falls back to ".".
        if config_path is _UNSPECIFIED_:
            upgrade_url = "https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_hydra_main_config_path"
            deprecation_warning(
                message=dedent(f"""\
                config_path is not specified in hydra.initialize().
                See {upgrade_url} for more information."""),
                stacklevel=2,
            )
            config_path = "."

        path_is_absolute = config_path is not None and os.path.isabs(config_path)
        if path_is_absolute:
            raise HydraException(
                "config_path in initialize() must be relative")

        calling_file, calling_module = detect_calling_file_or_module_from_stack_frame(
            caller_stack_depth + 1)
        task_name = job_name if job_name is not None else detect_task_name(
            calling_file=calling_file, calling_module=calling_module)

        Hydra.create_main_hydra_file_or_module(
            calling_file=calling_file,
            calling_module=calling_module,
            config_path=config_path,
            job_name=task_name,
        )
コード例 #8
0
    def _parse_config_override(override: str) -> ParsedConfigOverride:
        """Parse one config override string.

        Accepted forms:
          update: key=value
          append: +key=value
          delete: ~key=value | ~key
        (regex code and tests: https://regex101.com/r/JAPVdx/9)

        :param override: the raw override string.
        :return: the prefix (None, "+" or "~"), the key, and the value (None
            when there is no "=").
        :raises HydraException: if the string matches none of the accepted
            forms; update and append require a value.
        """
        regex = r"^(?P<prefix>[+~])?(?P<key>[\w\.@]+)(?:=(?P<value>.*))?$"
        found = re.search(regex, override)
        if found:
            prefix, key, value = found.group("prefix", "key", "value")
            # Update and append need "=value"; delete may omit it.
            needs_value = prefix is None or prefix == "+"
            is_valid = key is not None and (value is not None or not needs_value)
            if is_valid:
                assert key is not None
                return ParsedConfigOverride(prefix=prefix, key=key, value=value)
        raise HydraException(
            f"Error parsing config override : '{override}'"
            f"\nAccepted forms:"
            f"\n\tOverride:  key=value"
            f"\n\tAppend:   +key=value"
            f"\n\tDelete:   ~key=value, ~key"
        )
コード例 #9
0
    def sweep_iterator(
        self,
        transformer: TransformerType = Transformer.identity
    ) -> Iterator[ElementType]:
        """
        Converts CHOICE_SWEEP, SIMPLE_CHOICE_SWEEP, GLOB_CHOICE_SWEEP and
        RANGE_SWEEP to an iterator of elements that can be used in the value
        component of overrides (the part after the =). A transformer may be
        provided for converting each element to support the needs of different
        sweepers; it is applied to every element, including the config names
        matched by a glob sweep.

        Shuffled choice and range sweeps are shuffled on every call, on a copy,
        so the stored sweep is not reordered.

        :param transformer: callable applied to each element (identity by
            default).
        :return: a ``map`` iterator over the transformed elements.
        :raises HydraException: if the value type is not one of the sweep
            types above, or if a glob sweep is enumerated while
            ``self.config_loader`` is None.
        """
        if self.value_type not in (
                ValueType.CHOICE_SWEEP,
                ValueType.SIMPLE_CHOICE_SWEEP,
                ValueType.GLOB_CHOICE_SWEEP,
                ValueType.RANGE_SWEEP,
        ):
            raise HydraException(
                f"Can only enumerate CHOICE and RANGE sweeps, type is {self.value_type}"
            )

        lst: Any
        if isinstance(self._value, list):
            lst = self._value
        elif isinstance(self._value, ChoiceSweep):
            if self._value.shuffle:
                # Shuffle a copy so the stored choice list keeps its order.
                lst = copy(self._value.list)
                shuffle(lst)
            else:
                lst = self._value.list
        elif isinstance(self._value, RangeSweep):
            if self._value.shuffle:
                # Shuffling needs the whole range materialized.
                lst = list(self._value.range())
                shuffle(lst)
            else:
                lst = self._value.range()
        elif isinstance(self._value, Glob):
            if self.config_loader is None:
                raise HydraException("ConfigLoader is not set")

            ret = self.config_loader.get_group_options(
                self.key_or_group, results_filter=ObjectType.CONFIG)
            # Matched config names go through the transformer like every other
            # sweep element instead of bypassing it with an early return.
            lst = self._value.filter(ret)
        else:
            assert False

        return map(transformer, lst)
コード例 #10
0
ファイル: utils.py プロジェクト: wpc/hydra
def call(config: Any, *args: Any, **kwargs: Any) -> Any:
    """
    Locate the target named by ``config`` and call it.

    :param config: An object describing what to call and what params to use.
                   Must have a _target_ field. Accepted types: dict, OmegaConf
                   config, or structured config.
    :param args: optional positional parameters pass-through
    :param kwargs: optional named parameters pass-through
    :return: the return value from the specified class or method, or None
             when ``OmegaConf.is_none(config)``
    :raises InstantiationException: if a TargetConf's _target_ is still "???",
             or re-raised unchanged from the instantiation helpers
    :raises HydraException: for unsupported config types, and wrapping any
             other error raised while locating or calling the target
    """
    if OmegaConf.is_none(config):
        return None

    if isinstance(config, TargetConf) and config._target_ == "???":
        # Specific check to give a good warning about failure to annotate _target_ as a string.
        raise InstantiationException(
            f"Missing value for {type(config).__name__}._target_. Check that it's properly annotated and overridden."
            f"\nA common problem is forgetting to annotate _target_ as a string : '_target_: str = ...'"
        )

    supported = (isinstance(config, dict) or OmegaConf.is_config(config)
                 or is_structured_config(config))
    if not supported:
        raise HydraException(
            f"Unsupported config type : {type(config).__name__}")

    # Work on a copy so the caller's object is never changed; an OmegaConf
    # input keeps its original parent on the copy.
    copied = OmegaConf.structured(config)
    if OmegaConf.is_config(config):
        copied._set_parent(config._get_parent())
    config = copied

    cls = "<unknown>"
    try:
        assert isinstance(config, DictConfig)
        OmegaConf.set_readonly(config, False)
        OmegaConf.set_struct(config, False)
        cls = _get_cls_name(config)
        type_or_callable = _locate(cls)
        if not isinstance(type_or_callable, type):
            assert callable(type_or_callable)
            return _call_callable(type_or_callable, config, *args, **kwargs)
        return _instantiate_class(type_or_callable, config, *args, **kwargs)
    except InstantiationException:
        raise
    except Exception as e:
        raise HydraException(f"Error calling '{cls}' : {e}") from e
コード例 #11
0
 def _raise_parse_override_error(override: str) -> None:
     """Raise a HydraException describing the accepted config group override forms.

     The function never returns normally, so it has no meaningful return
     value (the annotation is ``None``, not ``str``).

     :param override: the override that failed to parse; it is quoted in the
         error message.
     :raises HydraException: always.
     """
     msg = (
         f"Error parsing config group override : '{override}'"
         f"\nAccepted forms:"
         f"\n\tOverride: key=value, key@package=value, key@src_pkg:dest_pkg=value, key@src_pkg:dest_pkg"
         f"\n\tAppend:  +key=value, +key@package=value"
         f"\n\tDelete:  ~key, ~key@pkg, ~key=value, ~key@pkg=value")
     raise HydraException(msg)
コード例 #12
0
 def syntaxError(
     self,
     recognizer: Any,
     offending_symbol: Any,
     line: Any,
     column: Any,
     msg: Any,
     e: Any,
 ) -> None:
     """Error-listener callback: report a syntax error as a HydraException.

     The message is ``msg`` when it is provided, otherwise ``str(e)``, so the
     exception never carries a bare None message. ``e`` is chained as the
     cause so the underlying recognition error is not lost.

     :raises HydraException: always.
     """
     text = msg if msg is not None else str(e)
     raise HydraException(text) from e
コード例 #13
0
    def sweep_string_iterator(self) -> Iterator[str]:
        """
        Converts the choices of a CHOICE_SWEEP, SIMPLE_CHOICE_SWEEP,
        GLOB_CHOICE_SWEEP or RANGE_SWEEP into an iterator of strings that can
        be used in the value component of overrides (the part after the =).

        Choice and range elements are formatted with
        ``Override._get_value_element_as_str``. Config names matched by a glob
        sweep are yielded as returned by ``Glob.filter``. Shuffled sweeps are
        shuffled on every call, on a copy.

        :raises HydraException: for non-sweep value types, or for a glob sweep
            when ``self.config_loader`` is None.
        """
        sweep_types = (
            ValueType.CHOICE_SWEEP,
            ValueType.SIMPLE_CHOICE_SWEEP,
            ValueType.GLOB_CHOICE_SWEEP,
            ValueType.RANGE_SWEEP,
        )
        if self.value_type not in sweep_types:
            raise HydraException(
                f"Can only enumerate CHOICE and RANGE sweeps, type is {self.value_type}"
            )

        sweep = self._value
        items: Any
        if isinstance(sweep, list):
            items = sweep
        elif isinstance(sweep, ChoiceSweep):
            if not sweep.shuffle:
                items = sweep.list
            else:
                items = copy(sweep.list)
                shuffle(items)
        elif isinstance(sweep, RangeSweep):
            if not sweep.shuffle:
                items = sweep.range()
            else:
                items = list(sweep.range())
                shuffle(items)
        elif isinstance(sweep, Glob):
            if self.config_loader is None:
                raise HydraException("ConfigLoader is not set")
            options = self.config_loader.get_group_options(
                self.key_or_group, results_filter=ObjectType.CONFIG)
            # Glob matches are returned as-is, without string formatting.
            return iter(sweep.filter(options))
        else:
            assert False

        return map(Override._get_value_element_as_str, items)
コード例 #14
0
ファイル: config_loader_impl.py プロジェクト: tchaton/hydra
 def _raise_parse_override_error(override: Optional[str]) -> None:
     """Raise a HydraException listing the accepted config group override forms.

     :param override: the override that failed to parse (may be None); it is
         quoted in the message, which also links to the command line syntax docs.
     :raises HydraException: always.
     """
     message = "\n".join([
         f"Error parsing config group override : '{override}'",
         "Accepted forms:",
         "\tOverride: key=value, key@package=value, key@src_pkg:dest_pkg=value, key@src_pkg:dest_pkg",
         "\tAppend:  +key=value, +key@package=value",
         "\tDelete:  ~key, ~key@pkg, ~key=value, ~key@pkg=value",
         "",
         "See https://hydra.cc/docs/next/advanced/command_line_syntax for details",
     ])
     raise HydraException(message)
コード例 #15
0
 def __init__(self, config_dir: str, job_name: str = "app") -> None:
     """Initialize Hydra with an absolute primary config directory.

     :param config_dir: absolute path of the config directory.
     :param job_name: task name passed to ``Hydra.create_main_hydra2``.
     :raises HydraException: if ``config_dir`` is not an absolute path.
     """
     self._gh_backup = get_gh_backup()
     # A relative path would be resolved against the current working
     # directory, which depends on when this runs, so only absolute paths are
     # accepted. hydra.utils.to_absolute_path() could be used to convert
     # relative paths at a future point if there is demand.
     if os.path.isabs(config_dir):
         search_path = create_config_search_path(search_path_dir=config_dir)
         Hydra.create_main_hydra2(task_name=job_name, config_search_path=search_path)
         return
     raise HydraException(
         "initialize_config_dir() requires an absolute config_dir as input"
     )
コード例 #16
0
ファイル: overrides_visitor.py プロジェクト: wpc/hydra
 def visitPrimitive(
     self, ctx: OverrideParser.PrimitiveContext
 ) -> Optional[Union[QuotedString, int, bool, float, str]]:
     """Convert a primitive parse-tree node into a Python value.

     A leading and a trailing whitespace child are ignored. If more than one
     child remains, the node's full text (stripped) is returned as a str.
     Otherwise the single token is converted by its lexer type:

     * QUOTED_VALUE -> QuotedString without the surrounding quotes; only an
       escaped quote of the same kind (\\' or \\") is unescaped
     * ID, INTERPOLATION -> the token text
     * INT -> int, FLOAT -> float
     * NULL -> None
     * BOOL -> True/False (compared case-insensitively)
     * any other token -> its text

     :raises HydraException: if the node is a single whitespace child.
     """
     ret: Optional[Union[int, bool, float, str]]
     first_idx = 0
     last_idx = ctx.getChildCount()  # exclusive end index
     # skip first if whitespace
     if self.is_ws(ctx.getChild(0)):
         if last_idx == 1:
             # Only whitespaces => this is not allowed.
             raise HydraException(
                 "Trying to parse a primitive that is all whitespaces"
             )
         first_idx = 1
     # skip last if whitespace
     if self.is_ws(ctx.getChild(-1)):
         last_idx = last_idx - 1
     # Number of children left after trimming.
     # NOTE(review): two whitespace children give num == 0 and fall through to
     # returning the whitespace text; presumably the grammar rules this out.
     num = last_idx - first_idx
     if num > 1:
         # Several tokens: keep the raw source text, trimmed of outer whitespace.
         ret = ctx.getText().strip()
     else:
         node = ctx.getChild(first_idx)
         if node.symbol.type == OverrideLexer.QUOTED_VALUE:
             text = node.getText()
             qc = text[0]  # opening quote character
             text = text[1:-1]  # strip the surrounding quotes
             if qc == "'":
                 quote = Quote.single
                 text = text.replace("\\'", "'")
             elif qc == '"':
                 quote = Quote.double
                 text = text.replace('\\"', '"')
             else:
                 assert False
             return QuotedString(text=text, quote=quote)
         elif node.symbol.type in (OverrideLexer.ID, OverrideLexer.INTERPOLATION):
             ret = node.symbol.text
         elif node.symbol.type == OverrideLexer.INT:
             ret = int(node.symbol.text)
         elif node.symbol.type == OverrideLexer.FLOAT:
             ret = float(node.symbol.text)
         elif node.symbol.type == OverrideLexer.NULL:
             ret = None
         elif node.symbol.type == OverrideLexer.BOOL:
             text = node.getText().lower()
             if text == "true":
                 ret = True
             elif text == "false":
                 ret = False
             else:
                 assert False
         else:
             return node.getText()  # type: ignore
     return ret
コード例 #17
0
 def get_value_element_as_str(self, space_after_sep: bool = False) -> str:
     """
     Returns a string representation of the value in this override
     (similar to the part after the = in the input string).

     :param space_after_sep: True to append space after commas and colons
     :return: the value formatted by ``Override._get_value_element_as_str``
     :raises HydraException: if the value is a sweep
     """
     value = self._value
     if not isinstance(value, Sweep):
         return Override._get_value_element_as_str(
             value, space_after_sep=space_after_sep)
     # This should not be called for sweeps
     raise HydraException("Cannot convert sweep to str")
コード例 #18
0
 def parse_overrides(self, overrides: List[str]) -> List[Override]:
     """Parse each override string with the ``override`` grammar rule.

     :param overrides: raw override strings.
     :return: the parsed overrides, in input order.
     :raises HydraException: if any string fails to parse. The message names
         the offending override, and the original error is chained as the
         cause so its traceback is preserved.
     """
     ret: List[Override] = []
     for override in overrides:
         try:
             parsed = self.parse_rule(override, "override")
         except HydraException as e:
             raise HydraException(
                 f"Error parsing override '{override}' : {e}"
                 f"\nSee https://hydra.cc/docs/next/advanced/command_line_syntax for details"
             ) from e
         assert isinstance(parsed, Override)
         ret.append(parsed)
     return ret
コード例 #19
0
ファイル: utils.py プロジェクト: iamdarshshah/hydra
def _get_cls_name(config: Union[ObjectConf, DictConfig]) -> str:
    """Return the class/callable name a config refers to.

    The name is read from ``target``, falling back to the deprecated ``cls``
    and then ``class`` fields. A field counts as present only if it exists
    and its value is not the "???" placeholder. A UserWarning is emitted for
    ``class`` if present, otherwise for ``cls`` if present, even when
    ``target`` is also set.

    :raises HydraException: if none of the three fields is present.
    """
    is_dict_config = isinstance(config, DictConfig)

    def _has_field(field: str) -> bool:
        present = field in config if is_dict_config else hasattr(config, field)
        if not present:
            return False
        current = config[field] if is_dict_config else getattr(config, field)
        ret = current != "???"
        assert isinstance(ret, bool)
        return ret

    def _warn(field: str) -> None:
        if is_dict_config:
            message = (
                f"Config key '{config._get_full_key(field)}' is deprecated since Hydra 1.0 and will be removed in Hydra 1.1.\n"
                f"Use 'target' instead of '{field}'."
            )
        else:
            message = (
                f"\nObjectConf field '{field}' is deprecated since Hydra 1.0 and will be removed in Hydra 1.1.\n"
                f"Use 'target' instead of '{field}'."
            )
        warnings.warn(message, category=UserWarning)

    # Warn about at most one deprecated field: "class" takes precedence.
    for deprecated in ("class", "cls"):
        if _has_field(deprecated):
            _warn(deprecated)
            break

    for candidate in ("target", "cls", "class"):
        if _has_field(candidate):
            classname = getattr(config, candidate)
            assert isinstance(classname, str)
            return classname
    raise HydraException("Input config does not have a `target` field")
コード例 #20
0
ファイル: config_loader_impl.py プロジェクト: tchaton/hydra
    def _apply_overrides_to_config(overrides: List[Override],
                                   cfg: DictConfig) -> None:
        """Apply parsed overrides to ``cfg`` in place.

        For each override:

        * delete (``override.is_delete()``): remove the key; if the override
          carries a value it must equal the key's current value.
        * add (``override.is_add()``): set the key; fails if
          ``OmegaConf.select`` already finds a non-None value there.
        * otherwise: update the key, which must already exist.

        :param overrides: parsed overrides, applied in order.
        :param cfg: the config to modify.
        :raises HydraException: if an override names a package (it looks like
            a group override for a group that does not exist), if a delete
            target is missing or holds a different value, if an add target
            already exists, if an update target has no match, or wrapping any
            OmegaConfBaseException.
        """
        for override in overrides:
            if override.get_subject_package() is not None:
                raise HydraException(
                    f"Override {override.input_line} looks like a config group override, "
                    f"but config group '{override.key_or_group}' does not exist."
                )

            key = override.key_or_group
            # Evaluated outside the try: errors raised here are not wrapped.
            value = override.value()
            try:
                if override.is_delete():
                    config_val = OmegaConf.select(cfg,
                                                  key,
                                                  throw_on_missing=False)
                    # NOTE(review): presumably select() also returns None for a
                    # key that exists with a null value, so such keys cannot be
                    # deleted and are reported as missing — confirm.
                    if config_val is None:
                        raise HydraException(
                            f"Could not delete from config. '{override.key_or_group}' does not exist."
                        )
                    elif value is not None and value != config_val:
                        raise HydraException(
                            f"Could not delete from config."
                            f" The value of '{override.key_or_group}' is {config_val} and not {value}."
                        )

                    # A dotted key is deleted from its parent node, selected by
                    # everything before the last dot.
                    last_dot = key.rfind(".")
                    with open_dict(cfg):
                        if last_dot == -1:
                            del cfg[key]
                        else:
                            node = OmegaConf.select(cfg, key[0:last_dot])
                            del node[key[last_dot + 1:]]

                elif override.is_add():
                    if OmegaConf.select(cfg, key,
                                        throw_on_missing=False) is None:
                        # open_dict allows new keys even if cfg is in struct mode.
                        with open_dict(cfg):
                            OmegaConf.update(cfg, key, value)
                    else:
                        raise HydraException(
                            f"Could not append to config. An item is already at '{override.key_or_group}'."
                        )
                else:
                    # Plain update: without open_dict, a missing key raises,
                    # and that is translated into a hint to use '+'.
                    try:
                        OmegaConf.update(cfg, key, value)
                    except (ConfigAttributeError, ConfigKeyError) as ex:
                        raise HydraException(
                            f"Could not override '{override.key_or_group}'. No match in config."
                            f"\nTo append to your config use +{override.input_line}"
                        ) from ex
            except OmegaConfBaseException as ex:
                raise HydraException(
                    f"Error merging override {override.input_line}") from ex
コード例 #21
0
ファイル: config_loader_impl.py プロジェクト: tchaton/hydra
    def _merge_config(
        self,
        cfg: DictConfig,
        config_group: str,
        name: str,
        required: bool,
        is_primary_config: bool,
        package_override: Optional[str],
    ) -> DictConfig:
        """Load ``config_group/name`` (just ``name`` for the root group) and merge it into ``cfg``.

        :return: the merged config, or ``cfg`` unchanged when nothing was
            loaded and ``required`` is False.
        :raises MissingConfigException: if nothing was loaded and ``required``
            is True; for a non-root group the message lists the group's
            available options when there are any.
        :raises HydraException: wrapping any OmegaConfBaseException raised
            while loading or merging.
        """
        try:
            new_cfg = f"{config_group}/{name}" if config_group != "" else name
            loaded_cfg, _ = self._load_config_impl(
                new_cfg,
                is_primary_config=is_primary_config,
                package_override=package_override,
            )
            if loaded_cfg is not None:
                merged = OmegaConf.merge(cfg, loaded_cfg)
                assert isinstance(merged, DictConfig)
                return merged
            if not required:
                return cfg
            if config_group == "":
                raise MissingConfigException(f"Could not load {new_cfg}", new_cfg)
            options = self.get_group_options(config_group)
            msg = f"Could not load {new_cfg}"
            if options:
                listing = "\n".join("\t" + option for option in options)
                msg = f"{msg}.\nAvailable options:\n{listing}"
            raise MissingConfigException(msg, new_cfg, options)
        except OmegaConfBaseException as ex:
            raise HydraException(
                f"Error merging {config_group}={name}") from ex
コード例 #22
0
    def eval(self, func: FunctionCall) -> Any:
        """Type-check and invoke a registered function.

        QuotedString arguments (positional and keyword) are unwrapped to their
        text. The arguments are bound to the registered signature, and each
        bound value is checked against its parameter annotation with
        ``is_type_matching``. For ``*args`` and ``**kwargs`` parameters, every
        collected element is checked individually against the annotation.

        :param func: the parsed call (name, positional args, keyword args).
        :return: whatever the registered function returns.
        :raises HydraException: if no function named ``func.name`` is
            registered.
        :raises TypeError: if the arguments do not bind to the signature, or
            a value does not match its parameter's annotation.
        """
        if func.name not in self.definitions:
            raise HydraException(
                f"Unknown function '{func.name}'"
                f"\nAvailable: {','.join(sorted(self.definitions.keys()))}\n")
        sig = self.definitions[func.name]

        def unquote(value: Any) -> Any:
            return value.text if isinstance(value, QuotedString) else value

        args = [unquote(arg) for arg in func.args]
        kwargs = {key: unquote(val) for key, val in func.kwargs.items()}

        bound = sig.bind(*args, **kwargs)

        for name, value in bound.arguments.items():
            param = sig.parameters[name]
            expected_type = param.annotation
            if param.kind == inspect.Parameter.VAR_POSITIONAL:
                # value is the tuple collected by *args.
                for iidx, v in enumerate(value):
                    if not is_type_matching(v, expected_type):
                        raise TypeError(
                            f"mismatch type argument {name}[{iidx}]:"
                            f" {type_str(type(v))} is incompatible with {type_str(expected_type)}"
                        )
            elif param.kind == inspect.Parameter.VAR_KEYWORD:
                # value is the dict collected by **kwargs. The annotation
                # describes each value, so the dict itself must not be
                # compared against it.
                for key, v in value.items():
                    if not is_type_matching(v, expected_type):
                        raise TypeError(
                            f"mismatch type argument {name}[{key}]:"
                            f" {type_str(type(v))} is incompatible with {type_str(expected_type)}"
                        )
            else:
                if not is_type_matching(value, expected_type):
                    raise TypeError(
                        f"mismatch type argument {name}:"
                        f" {type_str(type(value))} is incompatible with {type_str(expected_type)}"
                    )

        return self.functions[func.name](*bound.args, **bound.kwargs)
コード例 #23
0
ファイル: utils.py プロジェクト: sawravchy/hydra
def call(config: Union[ObjectConf, DictConfig], *args: Any, **kwargs: Any) -> Any:
    """
    :param config: An ObjectConf or DictConfig describing what to call and what params to use
    :param args: optional positional parameters pass-through
    :param kwargs: optional named parameters pass-through
    :return: the return value from the specified class or method,
             or None when ``OmegaConf.is_none(config)``
    :raises HydraException: wrapping any exception raised while resolving the
             target name, locating it, or calling it
    """
    if OmegaConf.is_none(config):
        return None
    # Bound before the try so the handler below can always format its message.
    # Without this, an exception from _get_cls_name left `cls` unassigned and
    # the handler raised UnboundLocalError, masking the real error.
    cls = "<unknown>"
    try:
        cls = _get_cls_name(config)
        type_or_callable = _locate(cls)
        if isinstance(type_or_callable, type):
            return _instantiate_class(type_or_callable, config, *args, **kwargs)
        else:
            assert callable(type_or_callable)
            return _call_callable(type_or_callable, config, *args, **kwargs)
    except Exception as e:
        raise HydraException(f"Error calling '{cls}' : {e}") from e
コード例 #24
0
def call(
    config: Union[ObjectConf, TargetConf, DictConfig, Dict[Any, Any]],
    *args: Any,
    **kwargs: Any,
) -> Any:
    """
    Resolve the target named in ``config`` and call it.

    :param config: An object describing what to call and what params to use;
                   dict, ObjectConf and TargetConf inputs are first converted
                   with ``OmegaConf.structured``
    :param args: optional positional parameters pass-through
    :param kwargs: optional named parameters pass-through
    :return: the return value from the specified class or method, or None
             when the (converted) config is None per ``OmegaConf.is_none``
    :raises InstantiationException: if a TargetConf's _target_ is still "???",
             or re-raised unchanged from the helpers
    :raises HydraException: wrapping any other error raised while resolving
             or calling the target
    """
    if isinstance(config, TargetConf) and config._target_ == "???":
        raise InstantiationException(
            f"Missing _target_ value. Check that you specified it in '{type(config).__name__}'"
            f" and that the type annotation is correct: '_target_: str = ...'")
    if isinstance(config, (dict, ObjectConf, TargetConf)):
        config = OmegaConf.structured(config)
    if OmegaConf.is_none(config):
        return None

    cls = "<unknown>"
    try:
        assert isinstance(config, DictConfig)
        # Deep copy, so that making the config writable and non-struct below
        # never affects the caller's object.
        work_cfg = copy.deepcopy(config)
        OmegaConf.set_readonly(work_cfg, False)
        OmegaConf.set_struct(work_cfg, False)
        cls = _get_cls_name(work_cfg)
        target = _locate(cls)
        if not isinstance(target, type):
            assert callable(target)
            return _call_callable(target, work_cfg, *args, **kwargs)
        return _instantiate_class(target, work_cfg, *args, **kwargs)
    except InstantiationException:
        raise
    except Exception as e:
        raise HydraException(f"Error calling '{cls}' : {e}") from e
コード例 #25
0
 def __next__(self) -> float:
     """Return the current range value as a float and advance it by ``step``.

     start/stop/step must be Decimals (asserted); arithmetic stays in
     Decimal and only the returned value is converted to float. The range is
     half-open: iteration stops once ``start`` reaches or passes ``stop`` in
     the direction of ``step``.

     :raises StopIteration: when the range is exhausted.
     :raises HydraException: if ``step`` is zero.
     """
     assert isinstance(self.start, decimal.Decimal)
     assert isinstance(self.stop, decimal.Decimal)
     assert isinstance(self.step, decimal.Decimal)
     if self.step == 0:
         raise HydraException(
             f"Invalid range values (start:{self.start}, stop:{self.stop}, step:{self.step})"
         )
     if self.step > 0:
         has_more = self.start < self.stop
     else:
         has_more = self.start > self.stop
     if not has_more:
         raise StopIteration
     current = float(self.start)
     self.start += self.step
     return current
コード例 #26
0
ファイル: config_source.py プロジェクト: sumitsethtest/hydra
    def _update_package_in_header(
        self,
        header: Dict[str, str],
        normalized_config_path: str,
        is_primary_config: bool,
        package_override: Optional[str],
    ) -> None:
        """Resolve the package of a config and store it in ``header["package"]``.

        Emits a UserWarning when a non-primary config has no "package" entry
        in its header.

        :param header: the config's header dict; mutated in place.
        :param normalized_config_path: path of the config. Its last
            ``len(".yaml")`` characters are dropped unconditionally, so it is
            assumed to end with ".yaml".
        :param is_primary_config: whether this is the primary config.
        :param package_override: forwarded to ``ConfigSource._resolve_package``.
        :raises HydraException: if a primary config whose header has a
            "package" entry resolves to a package other than "" (the message
            calls "" the _global_ package).
        """
        config_without_ext = normalized_config_path[0:-len(".yaml")]

        package = ConfigSource._resolve_package(
            config_without_ext=config_without_ext,
            header=header,
            package_override=package_override,
        )

        if is_primary_config:
            if "package" not in header:
                # NOTE(review): dead store; always overwritten by the final
                # header["package"] = package below.
                header["package"] = "_global_"
            else:
                if package != "":
                    raise HydraException(
                        f"Primary config '{config_without_ext}' must be "
                        f"in the _global_ package; effective package : '{package}'"
                    )
        else:
            if "package" not in header:
                # Loading a config group option.
                # Hydra 1.0: default to _global_ and warn.
                # Hydra 1.1: default will change to _package_ and the warning will be removed.
                # NOTE(review): this value is overwritten by the final
                # assignment below; it only survives if warnings.warn raises
                # (e.g. when warnings are configured as errors).
                header["package"] = "_global_"
                msg = (
                    f"\nMissing @package directive {normalized_config_path} in {self.full_path()}.\n"
                    f"See https://hydra.cc/docs/next/upgrades/0.11_to_1.0/adding_a_package_directive"
                )
                warnings.warn(message=msg, category=UserWarning)

        # The resolved package always ends up in the header.
        header["package"] = package
コード例 #27
0
ファイル: utils.py プロジェクト: zhaodan2000/hydra
def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
    """
    Instantiate a class, or call a callable, as described by a config object.

    :param config: A config object describing what to call and what params to use.
                   In addition to the parameters, the config must contain:
                   _target_ : target class or callable name (str)
                   And may contain:
                   _recursive_: Construct nested objects as well (bool).
                                True by default.
                                may be overridden via a _recursive_ key in
                                the kwargs
                   _convert_: Conversion strategy
                        none    : Passed objects are DictConfig and ListConfig, default
                        partial : Passed objects are converted to dict and list, with
                                  the exception of Structured Configs (and their fields).
                        all     : Passed objects are dicts, lists and primitives without
                                  a trace of OmegaConf containers
    :param args: Optional positional parameters pass-through
    :param kwargs: Optional named parameters to override
                   parameters in the config object. Parameters not present
                   in the config object are passed as is to the target.
    :return: if _target_ is a class name: the instantiated object
             if _target_ is a callable: the return value of the call
             None if the config is None
    :raises InstantiationException: if config is a TargetConf whose _target_ is still "???".
    :raises HydraException: if config is not a dict, an OmegaConf config or a structured config.
    :raises Exception: an error raised while building the arguments or calling the target
                       is re-raised as the same exception type, with a message naming the
                       target and the original traceback. If that exception type cannot be
                       built from a single message, InstantiationException (chained to the
                       original error) is raised instead.
    """

    if OmegaConf.is_none(config):
        return None

    if isinstance(config, TargetConf) and config._target_ == "???":
        # Specific check to give a good warning about failure to annotate _target_ as a string.
        raise InstantiationException(
            f"Missing value for {type(config).__name__}._target_. Check that it's properly annotated and overridden."
            f"\nA common problem is forgetting to annotate _target_ as a string : '_target_: str = ...'"
        )

    if not (
        isinstance(config, dict)
        or OmegaConf.is_config(config)
        or is_structured_config(config)
    ):
        raise HydraException(f"Unsupported config type : {type(config).__name__}")

    # Convert on shallow copies so the caller's top-level dict / kwargs are not modified.
    # NOTE(review): nested containers are still shared with the caller; confirm that
    # _convert_container_targets_to_strings does not mutate them in place.
    if isinstance(config, dict):
        configc = config.copy()
        _convert_container_targets_to_strings(configc)
        config = configc

    kwargsc = kwargs.copy()
    _convert_container_targets_to_strings(kwargsc)
    kwargs = kwargsc

    # make a copy to ensure we do not change the provided object
    config_copy = OmegaConf.structured(config, flags={"allow_objects": True})
    if OmegaConf.is_config(config):
        # Re-attach the original parent, presumably so interpolations that reach
        # outside this node still resolve on the copy.
        config_copy._set_parent(config._get_parent())
    config = config_copy
    assert OmegaConf.is_config(config)
    OmegaConf.set_readonly(config, False)
    OmegaConf.set_struct(config, False)

    target = _get_target_type(config, kwargs)
    try:
        config._set_flag("allow_objects", True)
        final_kwargs = _get_kwargs(config, **kwargs)
        convert = _pop_convert_mode(final_kwargs)
        if convert == ConvertMode.PARTIAL:
            final_kwargs = OmegaConf.to_container(
                final_kwargs, resolve=True, exclude_structured_configs=True
            )
            return target(*args, **final_kwargs)
        elif convert == ConvertMode.ALL:
            final_kwargs = OmegaConf.to_container(final_kwargs, resolve=True)
            return target(*args, **final_kwargs)
        elif convert == ConvertMode.NONE:
            return target(*args, **final_kwargs)
        else:
            assert False
    except Exception as e:
        msg = f"Error instantiating/calling '{_convert_target_to_string(target)}' : {e}"
        try:
            # Keep the original exception type so callers catching it still work.
            wrapped = type(e)(msg)
        except Exception:
            # Some exception types (e.g. UnicodeDecodeError, or custom exceptions
            # with required constructor arguments) cannot be built from a single
            # message. Rebuilding them would raise a TypeError that hides the
            # real failure, so wrap it instead.
            raise InstantiationException(msg) from e
        # preserve the original exception backtrace
        raise wrapped.with_traceback(e.__traceback__)
コード例 #28
0
ファイル: config_loader_impl.py プロジェクト: tchaton/hydra
    def _apply_overrides_to_defaults(
        overrides: List[Override],
        defaults: List[DefaultElement],
    ) -> None:
        """Apply config-group overrides to ``defaults``, modifying the list in place.

        Overrides are applied in order. A delete removes the matching entries, an
        add appends a new entry, and any other override changes the config name
        (and, for a package rename, the package) of the matching entries.

        :param overrides: parsed overrides to apply.
        :param defaults: the defaults list to modify.
        :raises HydraException: if an override value is a list or dict, if add is
            combined with a package rename, or if a delete/add/override finds no
            suitable match. Malformed overrides are reported through
            ``ConfigLoaderImpl._raise_parse_override_error``.
        """

        # Group the entries that have a config group; entries without one cannot match.
        key_to_defaults: Dict[str,
                              List[IndexedDefaultElement]] = defaultdict(list)

        for idx, default in enumerate(defaults):
            if default.config_group is not None:
                key_to_defaults[default.config_group].append(
                    IndexedDefaultElement(idx=idx, default=default))
        for override in overrides:
            value = override.value()
            if value is None:
                if override.is_add():
                    ConfigLoaderImpl._raise_parse_override_error(
                        override.input_line)

                if not override.is_delete():
                    # Deprecated: assigning null is converted into a delete.
                    override.type = OverrideType.DEL
                    msg = (
                        "\nRemoving from the defaults list by assigning 'null' "
                        "is deprecated and will be removed in Hydra 1.1."
                        f"\nUse ~{override.key_or_group}")
                    warnings.warn(category=UserWarning, message=msg)

            if (not (override.is_delete() or override.is_package_rename())
                    and value is None):
                ConfigLoaderImpl._raise_parse_override_error(
                    override.input_line)

            if override.is_add() and override.is_package_rename():
                raise HydraException(
                    "Add syntax does not support package rename, remove + prefix"
                )

            matches = ConfigLoaderImpl.find_matches(key_to_defaults, override)

            if isinstance(value, (list, dict)):
                raise HydraException(
                    f"Config group override value type cannot be a {type(value).__name__}"
                )

            if override.is_delete():
                src = override.get_source_item()
                if len(matches) == 0:
                    raise HydraException(
                        f"Could not delete. No match for '{src}' in the defaults list."
                    )
                # Check every match before removing anything, so a failed check
                # leaves ``defaults`` unchanged.
                for pair in matches:
                    if value is not None and value != pair.default.config_name:
                        raise HydraException(
                            f"Could not delete. No match for '{src}={value}' in the defaults list."
                        )
                for pair in matches:
                    # ``pair.idx`` was recorded before any deletion and is stale once
                    # an earlier deletion shifts the list, so find the element by identity.
                    position = next(
                        i for i, d in enumerate(defaults) if d is pair.default)
                    del defaults[position]
                    # Forget the removed entry so later overrides cannot match it.
                    group_entries = key_to_defaults[pair.default.config_group]
                    group_entries[:] = [
                        p for p in group_entries if p is not pair
                    ]
            elif override.is_add():
                if len(matches) > 0:
                    src = override.get_source_item()
                    raise HydraException(
                        f"Could not add. An item matching '{src}' is already in the defaults list."
                    )
                assert value is not None
                defaults.append(
                    DefaultElement(
                        config_group=override.key_or_group,
                        config_name=str(value),
                        package=override.get_subject_package(),
                    ))
            else:
                assert value is not None
                # override
                for match in matches:
                    default = match.default
                    default.config_name = str(value)
                    if override.is_package_rename():
                        default.package = override.get_subject_package()

                if len(matches) == 0:
                    src = override.get_source_item()
                    if override.is_package_rename():
                        msg = f"Could not rename package. No match for '{src}' in the defaults list."
                    else:
                        msg = (
                            f"Could not override '{src}'. No match in the defaults list."
                            f"\nTo append to your default list use +{override.input_line}"
                        )

                    raise HydraException(msg)
コード例 #29
0
 def _createPrimitive(
     self, ctx: ParserRuleContext
 ) -> Optional[Union[QuotedString, int, bool, float, str]]:
     """Convert a primitive parse-tree context into a Python value.

     A whitespace token at either end of the primitive is ignored. If more
     than one token remains, all non-trimmed tokens are joined into one
     string, and escape tokens are un-escaped by keeping every second
     character. A single remaining token is converted according to its lexer
     type.

     :param ctx: parser context of the primitive.
     :return: a QuotedString for a quoted value; otherwise a str, int, float,
              bool or None.
     :raises HydraException: if the primitive is a lone whitespace token.
     """
     ret: Optional[Union[int, bool, float, str]]
     child_count = ctx.getChildCount()
     begin, end = 0, child_count
     if self.is_ws(ctx.getChild(0)):
         if child_count == 1:
             # Whitespace alone is not a valid primitive.
             raise HydraException(
                 "Trying to parse a primitive that is all whitespaces"
             )
         begin = 1
     if self.is_ws(ctx.getChild(-1)):
         end -= 1

     if end - begin > 1:
         # Join every token except the trimmed leading/trailing whitespace.
         symbols = (child.symbol for child in ctx.getChildren())
         ret = "".join(
             sym.text[1::2] if sym.type == OverrideLexer.ESC else sym.text
             for pos, sym in enumerate(symbols)
             if not (sym.type == OverrideLexer.WS and (pos < begin or pos >= end))
         )
         return ret

     node = ctx.getChild(begin)
     kind = node.symbol.type
     if kind == OverrideLexer.QUOTED_VALUE:
         raw = node.getText()
         delimiter = raw[0]
         inner = raw[1:-1]
         # Only the escaped form of the enclosing quote character is un-escaped.
         if delimiter == "'":
             quote = Quote.single
             inner = inner.replace("\\'", "'")
         elif delimiter == '"':
             quote = Quote.double
             inner = inner.replace('\\"', '"')
         else:
             assert False
         return QuotedString(text=inner, quote=quote)

     if kind in (OverrideLexer.ID, OverrideLexer.INTERPOLATION):
         ret = node.symbol.text
     elif kind == OverrideLexer.INT:
         ret = int(node.symbol.text)
     elif kind == OverrideLexer.FLOAT:
         ret = float(node.symbol.text)
     elif kind == OverrideLexer.NULL:
         ret = None
     elif kind == OverrideLexer.BOOL:
         lowered = node.getText().lower()
         if lowered == "true":
             ret = True
         elif lowered == "false":
             ret = False
         else:
             assert False
     elif kind == OverrideLexer.ESC:
         ret = node.symbol.text[1::2]
     else:
         return node.getText()  # type: ignore
     return ret
コード例 #30
0
    def _apply_overrides_to_defaults(
        overrides: List[ParsedOverrideWithLine], defaults: List[DefaultElement],
    ) -> None:
        """Apply config-group overrides to ``defaults``, modifying the list in place.

        Overrides are applied in order. A delete removes the matching entries, an
        add appends a new entry, and any other override changes the config name
        and/or package of the matching entries. Comma-separated (multirun) values
        are replaced by ``"_SKIP_"`` so the option is not loaded yet.

        :param overrides: parsed overrides (each paired with its input line).
        :param defaults: the defaults list to modify.
        :raises HydraException: if add is combined with a package rename, or if a
            delete/add/override finds no suitable match. Malformed overrides are
            reported through ``ConfigLoaderImpl._raise_parse_override_error``.
        """

        # Group the entries that have a config group; entries without one cannot match.
        key_to_defaults: Dict[str, List[IndexedDefaultElement]] = defaultdict(list)

        for idx, default in enumerate(defaults):
            if default.config_group is not None:
                key_to_defaults[default.config_group].append(
                    IndexedDefaultElement(idx=idx, default=default)
                )
        for owl in overrides:
            override = owl.override
            if override.value == "null":
                # Deprecated: "group=null" is rewritten into a delete ("~group").
                if override.prefix not in (None, "~"):
                    ConfigLoaderImpl._raise_parse_override_error(owl.input_line)
                override.prefix = "~"
                override.value = None

                msg = (
                    "\nRemoving from the defaults list by assigning 'null' "
                    "is deprecated and will be removed in Hydra 1.1."
                    f"\nUse ~{override.key}"
                )
                warnings.warn(category=UserWarning, message=msg)

            if (
                not (override.is_delete() or override.is_package_rename())
                and override.value is None
            ):
                ConfigLoaderImpl._raise_parse_override_error(owl.input_line)

            if override.is_add() and override.is_package_rename():
                raise HydraException(
                    "Add syntax does not support package rename, remove + prefix"
                )

            if override.value is not None and "," in override.value:
                # If this is a multirun config (comma separated list), flag the default to prevent it from being
                # loaded until we are constructing the config for individual jobs.
                override.value = "_SKIP_"

            matches = ConfigLoaderImpl.find_matches(key_to_defaults, override)

            if override.is_delete():
                src = override.get_source_item()
                if len(matches) == 0:
                    raise HydraException(
                        f"Could not delete. No match for '{src}' in the defaults list."
                    )
                # Check every match before removing anything, so a failed check
                # leaves ``defaults`` unchanged.
                for pair in matches:
                    if (
                        override.value is not None
                        and override.value != pair.default.config_name
                    ):
                        raise HydraException(
                            f"Could not delete. No match for '{src}={override.value}' in the defaults list."
                        )
                for pair in matches:
                    # ``pair.idx`` was recorded before any deletion and is stale once
                    # an earlier deletion shifts the list, so find the element by identity.
                    position = next(
                        i for i, d in enumerate(defaults) if d is pair.default
                    )
                    del defaults[position]
                    # Forget the removed entry so later overrides cannot match it.
                    group_entries = key_to_defaults[pair.default.config_group]
                    group_entries[:] = [p for p in group_entries if p is not pair]
            elif override.is_add():
                if len(matches) > 0:
                    src = override.get_source_item()
                    raise HydraException(
                        f"Could not add. An item matching '{src}' is already in the defaults list."
                    )
                assert override.value is not None
                defaults.append(
                    DefaultElement(
                        config_group=override.key,
                        config_name=override.value,
                        package=override.get_subject_package(),
                    )
                )
            else:
                # override
                for match in matches:
                    default = match.default
                    if override.value is not None:
                        default.config_name = override.value
                    if override.pkg1 is not None:
                        default.package = override.get_subject_package()

                if len(matches) == 0:
                    src = override.get_source_item()
                    if override.is_package_rename():
                        msg = f"Could not rename package. No match for '{src}' in the defaults list."
                    else:
                        msg = (
                            f"Could not override '{src}'. No match in the defaults list."
                            f"\nTo append to your default list use +{owl.input_line}"
                        )

                    raise HydraException(msg)