def _get_kwargs(
    config: Union[DictConfig, ListConfig],
    root: bool = True,
    **kwargs: Any,
) -> Any:
    """Collect the call arguments described by an OmegaConf config.

    ``kwargs`` are merged into ``config`` (in place, via ``merge_with``) before
    the arguments are collected. When ``_is_recursive`` reports the config as
    recursive, nested ``_target_`` nodes are instantiated and nested dict/list
    nodes are processed; otherwise every top-level node is copied as-is.

    :param config: a DictConfig or ListConfig.
    :param root: True for the top-level call. Nested calls pass False, which
        makes the returned container carry the source node's object type.
    :param kwargs: overrides merged into ``config`` before collection.
    :return: a Python list for a ListConfig input; otherwise an OmegaConf
        container created with ``OmegaConf.create`` holding the arguments.
    """
    from hydra.utils import instantiate

    assert OmegaConf.is_config(config)

    if OmegaConf.is_list(config):
        assert isinstance(config, ListConfig)
        # kwargs are not forwarded here, and nested configs are only recursed
        # into (not instantiated) on this path.
        return [
            _get_kwargs(x, root=False) if OmegaConf.is_config(x) else x for x in config
        ]

    assert OmegaConf.is_dict(config), "Input config is not an OmegaConf DictConfig"

    recursive = _is_recursive(config, kwargs)
    overrides = OmegaConf.create(kwargs, flags={"allow_objects": True})
    # NOTE: merges into the caller's config object in place.
    config.merge_with(overrides)

    final_kwargs = OmegaConf.create(flags={"allow_objects": True})
    # Share the original parent, presumably so interpolations keep resolving
    # against the same root — TODO confirm.
    final_kwargs._set_parent(config._get_parent())
    final_kwargs._set_flag("readonly", False)
    final_kwargs._set_flag("struct", False)
    if recursive:
        for k, v in config.items_ex(resolve=False):
            if OmegaConf.is_none(v):
                final_kwargs[k] = v
            elif _is_target(v):
                final_kwargs[k] = instantiate(v)
            elif OmegaConf.is_dict(v):
                # Nested dict: instantiate target children, recurse into other configs.
                d = OmegaConf.create({}, flags={"allow_objects": True})
                for key, value in v.items_ex(resolve=False):
                    if _is_target(value):
                        d[key] = instantiate(value)
                    elif OmegaConf.is_config(value):
                        d[key] = _get_kwargs(value, root=False)
                    else:
                        d[key] = value
                d._metadata.object_type = v._metadata.object_type
                final_kwargs[k] = d
            elif OmegaConf.is_list(v):
                # Nested list: same treatment per element.
                lst = OmegaConf.create([], flags={"allow_objects": True})
                for x in v:
                    if _is_target(x):
                        lst.append(instantiate(x))
                    elif OmegaConf.is_config(x):
                        lst.append(_get_kwargs(x, root=False))
                        lst[-1]._metadata.object_type = x._metadata.object_type
                    else:
                        lst.append(x)
                final_kwargs[k] = lst
            else:
                final_kwargs[k] = v
    else:
        for k, v in config.items_ex(resolve=False):
            final_kwargs[k] = v

    # Clear the explicitly set flags; presumably None lets them fall back to
    # the parent's values — confirm against OmegaConf flag semantics.
    final_kwargs._set_flag("readonly", None)
    final_kwargs._set_flag("struct", None)
    final_kwargs._set_flag("allow_objects", None)
    if not root:
        # This is tricky: since the root kwargs are exploded anyway, we can treat
        # them as an untyped dict. The motivation is that the object type is used
        # as an indicator to treat the object differently during conversion to a
        # primitive container in some cases.
        final_kwargs._metadata.object_type = config._metadata.object_type
    return final_kwargs
def test_is_config(cfg: Any, expected: bool) -> None:
    """Check that ``OmegaConf.is_config`` classifies ``cfg`` as ``expected``."""
    actual = OmegaConf.is_config(cfg)
    assert actual == expected
def instantiate_node(
    node: Any,
    *args: Any,
    convert: Union[str, ConvertMode] = ConvertMode.NONE,
    recursive: bool = True,
    partial: bool = False,
) -> Any:
    """Instantiate an OmegaConf node, recursing into children as configured.

    - ``None`` or a config node holding ``None`` -> ``None``.
    - Non-config values are returned unchanged.
    - A dict node may override the inherited ``convert`` / ``recursive`` /
      ``partial`` modes via its ``_convert_`` / ``_recursive_`` / ``_partial_``
      keys.
    - A list node: each item is instantiated. The result is a plain list under
      ALL/PARTIAL conversion, otherwise a ListConfig parented to ``node``.
    - A dict node with ``_target_``: the target is resolved and called with the
      non-special keys as keyword arguments (instantiated first when
      ``recursive``) and with ``*args`` passed positionally.
    - Any other dict node: items are instantiated into a plain dict under ALL
      conversion, or under PARTIAL conversion when the node is not a structured
      config; otherwise into a DictConfig that keeps the node's object type.

    Only ``convert`` and ``recursive`` are forwarded to child nodes; ``partial``
    applies to this node alone.

    :raises TypeError: if the ``_recursive_`` or ``_partial_`` flag is not a bool.
    :raises AssertionError: if ``node`` is a config of an unexpected kind.
    """
    # Return None if config is None
    if node is None or (OmegaConf.is_config(node) and node._is_none()):
        return None

    if not OmegaConf.is_config(node):
        return node

    # Override parent modes from config if specified
    if OmegaConf.is_dict(node):
        # using getitem instead of get(key, default) because OmegaConf will raise an exception
        # if the key type is incompatible on get.
        convert = node[_Keys.CONVERT] if _Keys.CONVERT in node else convert
        recursive = node[_Keys.RECURSIVE] if _Keys.RECURSIVE in node else recursive
        partial = node[_Keys.PARTIAL] if _Keys.PARTIAL in node else partial

    if not isinstance(recursive, bool):
        raise TypeError(f"_recursive_ flag must be a bool, got {type(recursive)}")

    if not isinstance(partial, bool):
        raise TypeError(f"_partial_ flag must be a bool, got {type(partial)}")

    if OmegaConf.is_list(node):
        items = [
            instantiate_node(item, convert=convert, recursive=recursive)
            for item in node._iter_ex(resolve=True)
        ]

        if convert in (ConvertMode.ALL, ConvertMode.PARTIAL):
            # If ALL or PARTIAL, use plain list as container
            return items

        # Otherwise, use ListConfig as container
        lst = OmegaConf.create(items, flags={"allow_objects": True})
        lst._set_parent(node)
        return lst

    if OmegaConf.is_dict(node):
        if _is_target(node):
            # Special keys configure instantiation and are not passed to the target.
            exclude_keys = frozenset({"_target_", "_convert_", "_recursive_", "_partial_"})
            _target_ = _resolve_target(node.get(_Keys.TARGET))
            kwargs = {}
            for key, value in node.items():
                if key in exclude_keys:
                    continue
                if recursive:
                    value = instantiate_node(value, convert=convert, recursive=recursive)
                kwargs[key] = _convert_node(value, convert)

            return _call_target(_target_, partial, *args, **kwargs)

        # If ALL, or PARTIAL on a non-structured node, instantiate into a plain
        # dict and resolve interpolations eagerly. A plain (non-structured)
        # DictConfig may report its object type as either None or dict.
        if convert == ConvertMode.ALL or (
            convert == ConvertMode.PARTIAL
            and node._metadata.object_type in (None, dict)
        ):
            dict_items = {}
            for key, value in node.items():
                # dict items inherit the recursive flag from the containing dict.
                dict_items[key] = instantiate_node(
                    value, convert=convert, recursive=recursive
                )
            return dict_items

        # Otherwise use DictConfig and resolve interpolations lazily.
        cfg = OmegaConf.create({}, flags={"allow_objects": True})
        for key, value in node.items():
            cfg[key] = instantiate_node(value, convert=convert, recursive=recursive)
        cfg._set_parent(node)
        cfg._metadata.object_type = node._metadata.object_type
        return cfg

    # Explicit raise (not `assert`) so this is not stripped under `python -O`.
    raise AssertionError(f"Unexpected config type : {type(node).__name__}")
def str_rep(in_key: Union[str, int], in_value: Any) -> str:
    """Format ``in_key`` as a completion fragment.

    Config values are followed by ``.`` (more keys can be typed); anything
    else is followed by ``=`` (a value can be typed).
    """
    suffix = "." if OmegaConf.is_config(in_value) else "="
    return f"{in_key}{suffix}"
def str_rep(in_key: Any, in_value: Any) -> str:
    """Return ``in_key`` followed by ``=`` for leaf values, or ``.`` for configs."""
    if not OmegaConf.is_config(in_value):
        return f"{in_key}="
    return f"{in_key}."
def recursive_is_struct(node: Any) -> None:
    """Call ``OmegaConf.is_struct`` on ``node`` and on every nested config node.

    Non-config values are ignored. Both DictConfig and ListConfig children are
    visited.

    NOTE(review): the result of ``is_struct`` is discarded; presumably this
    helper only checks that the call does not raise anywhere in the tree —
    confirm with callers.
    """
    if not OmegaConf.is_config(node):
        return
    OmegaConf.is_struct(node)
    # ListConfig has no .values(); iterate its elements directly instead.
    children = node.values() if OmegaConf.is_dict(node) else node
    for child in children:
        recursive_is_struct(child)
def validate_and_convert(self, value: Any) -> Optional[str]:
    """Convert ``value`` to a string, passing ``None`` through unchanged.

    :raises ValidationError: if ``value`` is an OmegaConf config or a container
        accepted by ``is_primitive_container``.
    """
    from omegaconf import OmegaConf

    is_container = OmegaConf.is_config(value) or is_primitive_container(value)
    if is_container:
        raise ValidationError("Cannot convert '$VALUE_TYPE' to string : '$VALUE'")
    if value is None:
        return None
    return str(value)
def test_is_config(
    cfg: Any, is_conf: bool, is_list: bool, is_dict: bool, type_: Type[Any]
) -> None:
    """Check OmegaConf's classification helpers against the expectations for ``cfg``."""
    expectations = (
        (OmegaConf.is_config, is_conf),
        (OmegaConf.is_list, is_list),
        (OmegaConf.is_dict, is_dict),
        (OmegaConf.get_type, type_),
    )
    for probe, expected in expectations:
        assert probe(cfg) == expected
def _get_matches(config: Container, word: str) -> List[str]:
    """Return shell-completion candidates for ``word`` against ``config``.

    - ``word`` ending in ``.`` or ``=``: ``word[:-1]`` is selected as an exact
      key. A config node yields its child keys; a primitive yields its value
      (booleans lower-cased); a missing (``???``) value yields ``word`` itself;
      a ``None`` value yields nothing.
    - ``word`` containing a dot: completion recurses into the node selected by
      the part before the last dot, using the part after it as the prefix.
    - otherwise: keys (or list indices) of ``config`` that start with ``word``,
      suffixed with ``.`` for config children and ``=`` for everything else.

    :param config: config to complete against; ``None`` yields no matches.
    :param word: the partially typed override.
    :return: the list of completion candidates.
    :raises AssertionError: if ``config`` is neither ``None`` nor a config.
    """

    def str_rep(in_key: Union[str, int], in_value: Any) -> str:
        if OmegaConf.is_config(in_value):
            return f"{in_key}."
        else:
            return f"{in_key}="

    if config is None:
        return []
    if not OmegaConf.is_config(config):
        # Explicit raise (not `assert`) so it is not stripped under `python -O`.
        raise AssertionError(f"Object is not an instance of config : {type(config)}")

    matches: List[str] = []
    if word.endswith((".", "=")):
        exact_key = word[:-1]
        try:
            conf_node = OmegaConf.select(config, exact_key, throw_on_missing=True)
        except MissingMandatoryValue:
            # Missing value: complete to the word itself.
            conf_node = ""
        if conf_node is None:
            key_matches = []
        elif OmegaConf.is_config(conf_node):
            key_matches = CompletionPlugin._get_matches(conf_node, "")
        else:
            # primitive
            if isinstance(conf_node, bool):
                conf_node = str(conf_node).lower()
            key_matches = [conf_node]
        matches.extend(f"{word}{match}" for match in key_matches)
        return matches

    last_dot = word.rfind(".")
    if last_dot != -1:
        base_key = word[:last_dot]
        partial_key = word[last_dot + 1 :]
        conf_node = OmegaConf.select(config, base_key)
        key_matches = CompletionPlugin._get_matches(conf_node, partial_key)
        matches.extend(f"{base_key}.{match}" for match in key_matches)
    elif isinstance(config, DictConfig):
        for key, value in config.items_ex(resolve=False):
            # str() so that non-string keys do not break the prefix check.
            if str(key).startswith(word):
                matches.append(str_rep(key, value))
    elif OmegaConf.is_list(config):
        assert isinstance(config, ListConfig)
        for idx in range(len(config)):
            # Filter on the index first: previously a missing ('???') item was
            # offered even when its index did not match the typed prefix.
            if not str(idx).startswith(word):
                continue
            try:
                value = config[idx]
            except MissingMandatoryValue:
                value = ""
            matches.append(str_rep(idx, value))
    return matches
def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
    """
    :param config: A config object describing what to call and what params to use.
                   In addition to the parameters, the config must contain:
                   _target_ : target class or callable name (str)
                   _recursive_: Construct nested objects as well (bool).
                                True by default.
                                may be overridden via a _recursive_ key in
                                the kwargs
    :param args: Optional positional parameters pass-through
    :param kwargs: Optional named parameters to override
                   parameters in the config object. Parameters not present
                   in the config objects are being passed as is to the target.
    :return: if _target_ is a class name: the instantiated object
             if _target_ is a callable: the return value of the call
    :raises InstantiationException: if a TargetConf's _target_ was left as "???".
    :raises HydraException: if ``config`` is not a dict, config or structured config.
    :raises Exception: errors from building the arguments or calling the target are
             re-raised as the same exception type with the target name prepended
             (chained to the original); if that type cannot be rebuilt from a single
             message argument, the original exception is re-raised unchanged.
    """
    if OmegaConf.is_none(config):
        return None

    if isinstance(config, TargetConf) and config._target_ == "???":
        # Specific check to give a good warning about failure to annotate _target_ as a string.
        raise InstantiationException(
            f"Missing value for {type(config).__name__}._target_. Check that it's properly annotated and overridden."
            f"\nA common problem is forgetting to annotate _target_ as a string : '_target_: str = ...'"
        )

    if not (
        isinstance(config, dict)
        or OmegaConf.is_config(config)
        or is_structured_config(config)
    ):
        raise HydraException(f"Unsupported config type : {type(config).__name__}")

    if isinstance(config, dict):
        # NOTE(review): shallow copy — nested containers are still shared with the
        # caller; presumably _convert_container_targets_to_strings only rewrites
        # top-level entries — confirm.
        configc = config.copy()
        _convert_container_targets_to_strings(configc)
        config = configc

    kwargsc = kwargs.copy()
    _convert_container_targets_to_strings(kwargsc)
    kwargs = kwargsc

    # make a copy to ensure we do not change the provided object
    config_copy = OmegaConf.structured(config, flags={"allow_objects": True})
    if OmegaConf.is_config(config):
        config_copy._set_parent(config._get_parent())
    config = config_copy
    assert OmegaConf.is_config(config)
    OmegaConf.set_readonly(config, False)
    OmegaConf.set_struct(config, False)
    target = _get_target_type(config, kwargs)
    try:
        config._set_flag("allow_objects", True)
        final_kwargs = _get_kwargs(config, **kwargs)
        return target(*args, **final_kwargs)
    except Exception as e:
        message = f"Error instantiating/calling '{_convert_target_to_string(target)}' : {e}"
        # Rebuild the same exception type with added context. Some exception types
        # need more than one constructor argument; building those would raise a
        # TypeError that hides the real error, so fall back to the original.
        try:
            wrapped = type(e)(message)
        except Exception:
            wrapped = None
        if wrapped is None:
            raise  # re-raises `e`, the exception still being handled here
        raise wrapped from e
def _load_single_config(
    self, default: ResultDefault, repo: IConfigRepository
) -> ConfigResult:
    """Load the config referenced by ``default`` and embed it at its package.

    If the loaded config is not itself a schema source, a schema with the same
    normalized name is looked up in the repository's schema source. When one is
    found, the config is merged on top of it. This automatic matching emits a
    deprecation warning, and is an error when the schema contains ``defaults``.

    :param default: defaults-list entry to load; its ``config_path`` must be set.
    :param repo: repository the config (and any matching schema) is loaded from.
    :return: the result of ``self._embed_result_config`` for the loaded config.
    :raises ValueError: if the loaded object is not an OmegaConf config.
    :raises ConfigCompositionException: if the matched schema contains
        ``defaults``, if merging with the schema fails, or if a non-primary config
        (other than ``hydra/config``) sets ``hydra.searchpath``.
    """
    config_path = default.config_path

    assert config_path is not None
    ret = repo.load_config(config_path=config_path)
    assert ret is not None

    if not OmegaConf.is_config(ret.config):
        raise ValueError(
            f"Config {config_path} must be an OmegaConf config, got {type(ret.config).__name__}"
        )

    if not ret.is_schema_source:
        schema = None
        try:
            schema_source = repo.get_schema_source()
            cname = ConfigSource._normalize_file_name(filename=config_path)
            schema = schema_source.load_config(cname)
        except ConfigLoadError:
            # schema not found, ignore
            pass

        if schema is not None:
            try:
                url = "https://hydra.cc/docs/next/upgrades/1.0_to_1.1/automatic_schema_matching"
                if "defaults" in schema.config:
                    raise ConfigCompositionException(
                        dedent(
                            f"""\
                    '{config_path}' is validated against ConfigStore schema with the same name.
                    This behavior is deprecated in Hydra 1.1 and will be removed in Hydra 1.2.
                    In addition, the automatically matched schema contains a defaults list.
                    This combination is no longer supported.
                    See {url} for migration instructions."""
                        )
                    )
                else:
                    # NOTE(review): stacklevel depends on the exact call depth above this
                    # frame; presumably chosen so the warning points at user code —
                    # re-check it whenever the call chain changes.
                    deprecation_warning(
                        dedent(
                            f"""\
                    '{config_path}' is validated against ConfigStore schema with the same name.
                    This behavior is deprecated in Hydra 1.1 and will be removed in Hydra 1.2.
                    See {url} for migration instructions."""
                        ),
                        stacklevel=11,
                    )

                # if primary config has a hydra node, remove it during validation and add it back.
                # This allows overriding Hydra's configuration without declaring its node
                # in the schema of every primary config
                hydra = None
                hydra_config_group = (
                    default.config_path is not None
                    and default.config_path.startswith("hydra/")
                )
                config = ret.config
                if (
                    default.primary
                    and isinstance(config, DictConfig)
                    and "hydra" in config
                    and not hydra_config_group
                ):
                    # Note: pop removes the node from ret.config itself.
                    hydra = config.pop("hydra")

                merged = OmegaConf.merge(schema.config, config)
                assert isinstance(merged, DictConfig)

                if hydra is not None:
                    with open_dict(merged):
                        merged.hydra = hydra
                ret.config = merged
            except OmegaConfBaseException as e:
                raise ConfigCompositionException(
                    f"Error merging '{config_path}' with schema"
                ) from e

            assert isinstance(merged, DictConfig)

    res = self._embed_result_config(ret, default.package)
    # Only the primary config (or hydra/config) may set hydra.searchpath.
    if (
        not default.primary
        and config_path != "hydra/config"
        and isinstance(res.config, DictConfig)
        and OmegaConf.select(res.config, "hydra.searchpath") is not None
    ):
        raise ConfigCompositionException(
            f"In '{config_path}': Overriding hydra.searchpath is only supported from the primary config"
        )

    return res