def _get_kwargs(
    config: Union[DictConfig, ListConfig],
    **kwargs: Any,
) -> Any:
    """Build the keyword arguments used to instantiate ``config``'s target.

    ``kwargs`` are merged into ``config`` first. Note that ``config.merge_with``
    mutates ``config`` in place. When ``_is_recursive`` says so, nested values
    that carry a target are instantiated via ``hydra.utils.instantiate``. Nested
    dicts and lists are rebuilt with their target-bearing children instantiated.

    :param config: a DictConfig; a ListConfig is handled element-wise.
    :param kwargs: overrides merged into ``config`` before extraction.
    :return: a plain ``dict`` of keyword arguments for a DictConfig input, or a
        plain ``list`` for a ListConfig input.
    """
    from hydra.utils import instantiate

    assert OmegaConf.is_config(config)
    if OmegaConf.is_list(config):
        assert isinstance(config, ListConfig)
        # Lists recurse element-wise; nested configs get no kwargs.
        return [
            _get_kwargs(x) if OmegaConf.is_config(x) else x for x in config
        ]

    assert OmegaConf.is_dict(
        config), "Input config is not an OmegaConf DictConfig"

    final_kwargs = {}

    recursive = _is_recursive(config, kwargs)
    # allow_objects lets arbitrary Python objects live in the override config.
    overrides = OmegaConf.create(kwargs, flags={"allow_objects": True})
    config.merge_with(overrides)  # mutates the caller's config

    for k, v in config.items():
        final_kwargs[k] = v

    if recursive:
        # Re-assigning existing keys while iterating is safe (dict size is unchanged).
        for k, v in final_kwargs.items():
            if _is_target(v):
                final_kwargs[k] = instantiate(v)
            elif OmegaConf.is_dict(v) and not OmegaConf.is_none(v):
                d = OmegaConf.create({}, flags={"allow_objects": True})
                for key, value in v.items():
                    if _is_target(value):
                        d[key] = instantiate(value)
                    elif OmegaConf.is_config(value):
                        d[key] = _get_kwargs(value)
                    else:
                        d[key] = value
                final_kwargs[k] = d
            elif OmegaConf.is_list(v):
                lst = OmegaConf.create([], flags={"allow_objects": True})
                for x in v:
                    if _is_target(x):
                        lst.append(instantiate(x))
                    elif OmegaConf.is_config(x):
                        lst.append(_get_kwargs(x))
                    else:
                        lst.append(x)
                final_kwargs[k] = lst
            else:
                # Normalize None-valued config nodes to a real Python None.
                if OmegaConf.is_none(v):
                    v = None
                final_kwargs[k] = v

    return final_kwargs
def test_merge_none_is_none(self, class_type: str) -> None:
    """Merging ``{"with_default": None}`` over a structured config yields a None field."""
    mod: Any = import_module(class_type)
    base = OmegaConf.structured(mod.StructuredOptional)
    # Sanity check: before the merge the field holds its default Nested instance.
    assert base.with_default == mod.Nested()

    merged = OmegaConf.merge(base, {"with_default": None})
    assert OmegaConf.is_none(merged, "with_default")
def _set_value(self, value: Any) -> None:
    """Replace this container's content with ``value``.

    ``object_type`` is first reset to the node's ``annotated_type``.
    ``value`` may be:

    - None: stored as None, but only if the node is Optional.
    - an interpolation string: stored verbatim.
    - ``"???"``: stored as the missing marker.
    - a structured config or a mapping: each item is assigned via ``__setitem__``.

    :raises ValidationError: if ``value`` is None and the node is not Optional.
    """
    from omegaconf import OmegaConf

    self._metadata.object_type = self._metadata.annotated_type
    # Type name used only in the error message below; falls back to DictConfig.
    type_ = (self._metadata.object_type
             if self._metadata.object_type is not None else DictConfig)
    if OmegaConf.is_none(value):
        if not self._is_optional():
            assert isinstance(type_, type)
            raise ValidationError(
                f"Cannot assign {type_.__name__}=None (field is not Optional)"
            )
        self.__dict__["_content"] = None
    elif _is_interpolation(value):
        self.__dict__["_content"] = value
    elif value == "???":  # missing
        self.__dict__["_content"] = "???"
    else:
        is_structured = is_structured_config(value)
        if is_structured:
            _type = get_type_of(value)
            value = get_structured_config_data(value)
            # Cleared while populating children; restored to the value's type below.
            self._metadata.object_type = None
        self.__dict__["_content"] = {}
        for k, v in value.items():
            self.__setitem__(k, v)
        if is_structured:
            self._metadata.object_type = _type
def _validate_merge(self, value: Any) -> None:
    """Check that config node ``value`` may be merged into this node.

    Validation is skipped once ``value`` is missing (``"???"``).

    :raises ValidationError: when this node has a structured-config type and
        the source's type is neither a dict nor a subclass of it.
        ``_validate_non_optional`` may also raise, presumably for a None source
        merged into a non-Optional node.
    """
    from omegaconf import OmegaConf

    dest = self
    src = value

    self._validate_non_optional(None, src)

    dest_obj_type = OmegaConf.get_type(dest)
    src_obj_type = OmegaConf.get_type(src)

    # A missing destination is about to be replaced by a typed source:
    # validate it as if the source's value were assigned directly.
    if dest._is_missing() and src._metadata.object_type is not None:
        self._validate_set(key=None, value=_get_value(src))

    if src._is_missing():
        return

    validation_error = (dest_obj_type is not None
                        and src_obj_type is not None
                        and is_structured_config(dest_obj_type)
                        and not OmegaConf.is_none(src)
                        and not is_dict(src_obj_type)
                        and not issubclass(src_obj_type, dest_obj_type))
    if validation_error:
        msg = (f"Merge error : {type_str(src_obj_type)} is not a "
               f"subclass of {type_str(dest_obj_type)}. value: {src}")
        raise ValidationError(msg)
def _set_value(self, value: Any) -> None:
    """Validate ``value`` via ``_validate_set`` and store it as this node's content.

    None, interpolation strings and ``"???"`` are stored directly and clear
    ``object_type``. Structured configs, DictConfigs and plain dicts are
    copied item by item through ``__setitem__``.
    """
    from omegaconf import OmegaConf

    assert not isinstance(value, ValueNode)
    self._validate_set(key=None, value=value)

    if OmegaConf.is_none(value):
        self.__dict__["_content"] = None
        self._metadata.object_type = None
    elif _is_interpolation(value):
        self.__dict__["_content"] = value
        self._metadata.object_type = None
    elif value == "???":
        self.__dict__["_content"] = "???"
        self._metadata.object_type = None
    else:
        self.__dict__["_content"] = {}
        if is_structured_config(value):
            # Untyped while children are inserted; typed afterwards.
            self._metadata.object_type = None
            data = get_structured_config_data(value)
            for k, v in data.items():
                self.__setitem__(k, v)
            self._metadata.object_type = get_type_of(value)
        elif isinstance(value, DictConfig):
            self._metadata.object_type = dict
            for k, v in value.items_ex(resolve=False):
                self.__setitem__(k, v)
            # Finally adopt a deep copy of the source's metadata (flags, types, ...).
            self.__dict__["_metadata"] = copy.deepcopy(value._metadata)
        elif isinstance(value, dict):
            for k, v in value.items():
                self.__setitem__(k, v)
        else:
            # NOTE(review): assert is stripped under `python -O`, which would leave
            # `_content` as an empty dict; consider raising ValidationError instead.
            assert False, f"Unsupported value type : {value}"  # pragma: no cover
def verify(
    cfg: Any,
    key: Any,
    none: bool,
    opt: bool,
    missing: bool,
    inter: bool,
    none_public: Optional[bool] = None,
    exp: Any = SKIP,
) -> None:
    """Assert the state of ``cfg[key]`` through node internals and the public API.

    ``none_public`` is the expected result of ``OmegaConf.is_none`` and
    defaults to ``none``; that call is expected to emit a ``UserWarning``.
    ``exp``, when supplied, is compared with ``cfg.get(key)``.
    """
    public_none = none if none_public is None else none_public

    node = cfg._get_node(key)
    node_checks = (
        (node._key, key),
        (node._is_none, none),
        (node._is_optional, opt),
        (node._is_missing, missing),
        (node._is_interpolation, inter),
    )
    for probe, expected in node_checks:
        assert probe() == expected

    if exp is not SKIP:
        assert cfg.get(key) == exp

    assert OmegaConf.is_missing(cfg, key) == missing
    with warns(UserWarning):
        assert OmegaConf.is_none(cfg, key) == public_none
    assert _is_optional(cfg, key) == opt
    assert OmegaConf.is_interpolation(cfg, key) == inter
def test_is_none_interpolation(cfg: Any, key: str, is_none: bool) -> None:
    """``OmegaConf.is_none`` and resolving ``_is_none`` agree for ``cfg[key]``."""
    conf = OmegaConf.create(cfg)

    with warns(UserWarning):
        assert OmegaConf.is_none(conf, key) == is_none

    resolved_is_none = _is_none(
        conf._get_node(key),
        resolve=True,
        throw_on_resolution_failure=False,
    )
    assert resolved_is_none == is_none
def test_none_construction(self, node_type: Any, values: Any) -> None:
    """An Optional node created from None equals None until a real value is set."""
    candidates = copy.deepcopy(values)
    node = node_type(value=None, is_optional=True)

    if isinstance(node, ValueNode):
        assert node._value() is None
        assert node._is_optional()
    assert node.__eq__(None)
    assert OmegaConf.is_none(node)

    for candidate in candidates:
        node._set_value(candidate)
        assert node.__eq__(candidate)
        assert not node.__eq__(None)
        assert not OmegaConf.is_none(node)

    # A non-Optional node must reject None at construction time.
    with pytest.raises(ValidationError):
        node_type(value=None, is_optional=False)
def _set_value_impl(
    self, value: Any, flags: Optional[Dict[str, bool]] = None
) -> None:
    """Validate ``value`` via ``_validate_set`` and store it as this node's content.

    Assigning a node to itself is a no-op. When ``value`` is a DictConfig,
    its metadata is deep-copied and its flags are replaced by a deep copy of
    ``flags`` (``{}`` when not given).

    :raises ValidationError: for an unsupported ``value`` type.
    """
    from omegaconf import OmegaConf, flag_override

    if id(self) == id(value):
        return

    if flags is None:
        flags = {}

    assert not isinstance(value, ValueNode)
    self._validate_set(key=None, value=value)

    if OmegaConf.is_none(value):
        self.__dict__["_content"] = None
        self._metadata.object_type = None
    elif _is_interpolation(value):
        self.__dict__["_content"] = value
        self._metadata.object_type = None
    elif value == "???":
        self.__dict__["_content"] = "???"
        self._metadata.object_type = None
    else:
        self.__dict__["_content"] = {}
        if is_structured_config(value):
            # Untyped while children are inserted; typed afterwards.
            self._metadata.object_type = None
            data = get_structured_config_data(
                value,
                allow_objects=self._get_flag("allow_objects"),
            )
            for k, v in data.items():
                self.__setitem__(k, v)
            self._metadata.object_type = get_type_of(value)
        elif isinstance(value, DictConfig):
            self.__dict__["_metadata"] = copy.deepcopy(value._metadata)
            self._metadata.flags = copy.deepcopy(flags)
            # disable struct and readonly for the construction phase
            # retaining other flags like allow_objects. The real flags are restored at the end of this function
            # NOTE(review): nothing in this body restores flags explicitly; presumably
            # flag_override restores the prior values on exit — confirm.
            with flag_override(self, "struct", False):
                with flag_override(self, "readonly", False):
                    # Iterate raw _content so interpolations are copied unresolved.
                    for k, v in value.__dict__["_content"].items():
                        self.__setitem__(k, v)
        elif isinstance(value, dict):
            for k, v in value.items():
                self.__setitem__(k, v)
        else:  # pragma: no cover
            msg = f"Unsupported value type : {value}"
            raise ValidationError(msg)
def call(config: Any, *args: Any, **kwargs: Any) -> Any:
    """
    :param config: An object describing what to call and what params to use.
                   Must have a _target_ field.
    :param args: optional positional parameters pass-through
    :param kwargs: optional named parameters pass-through
    :return: the return value from the specified class or method,
             or None when ``config`` is None
    :raises InstantiationException: if a TargetConf's ``_target_`` is still "???";
             InstantiationExceptions raised while calling are re-raised unchanged
    :raises HydraException: for an unsupported config type, or wrapping any
             other error raised while locating/calling the target
    """
    if OmegaConf.is_none(config):
        return None

    if isinstance(config, TargetConf) and config._target_ == "???":
        # Specific check to give a good warning about failure to annotate _target_ as a string.
        raise InstantiationException(
            f"Missing value for {type(config).__name__}._target_. Check that it's properly annotated and overridden."
            f"\nA common problem is forgetting to annotate _target_ as a string : '_target_: str = ...'"
        )

    if not (isinstance(config, dict) or OmegaConf.is_config(config)
            or is_structured_config(config)):
        raise HydraException(
            f"Unsupported config type : {type(config).__name__}")

    # make a copy to ensure we do not change the provided object
    config_copy = OmegaConf.structured(config)
    if OmegaConf.is_config(config):
        # Keep the original parent so interpolations relative to it still resolve.
        config_copy._set_parent(config._get_parent())
    config = config_copy

    # Pre-bound so the error message below works even if _get_cls_name fails.
    cls = "<unknown>"
    try:
        assert isinstance(config, DictConfig)
        # Safe to unlock: this is the private copy made above.
        OmegaConf.set_readonly(config, False)
        OmegaConf.set_struct(config, False)
        cls = _get_cls_name(config)
        type_or_callable = _locate(cls)
        if isinstance(type_or_callable, type):
            return _instantiate_class(type_or_callable, config, *args, **kwargs)
        else:
            assert callable(type_or_callable)
            return _call_callable(type_or_callable, config, *args, **kwargs)
    except InstantiationException as e:
        raise e
    except Exception as e:
        raise HydraException(f"Error calling '{cls}' : {e}") from e
def call(config: Union[ObjectConf, DictConfig], *args: Any, **kwargs: Any) -> Any:
    """
    :param config: An ObjectConf or DictConfig describing what to call and what params to use
    :param args: optional positional parameters pass-through
    :param kwargs: optional named parameters pass-through
    :return: the return value from the specified class or method,
             or None when ``config`` is None
    :raises HydraException: if resolving the target name, locating it, or
             calling it fails; the original exception is chained as ``__cause__``
    """
    if OmegaConf.is_none(config):
        return None

    # Bound before the try so the handler below can always format its message.
    # Previously a failure inside _get_cls_name left `cls` unbound and the
    # handler raised UnboundLocalError, masking the real error.
    cls = "<unknown>"
    try:
        cls = _get_cls_name(config)
        type_or_callable = _locate(cls)
        if isinstance(type_or_callable, type):
            return _instantiate_class(type_or_callable, config, *args, **kwargs)
        else:
            assert callable(type_or_callable)
            return _call_callable(type_or_callable, config, *args, **kwargs)
    except Exception as e:
        raise HydraException(f"Error calling '{cls}' : {e}") from e
def _validate_non_optional(self, key: Any, value: Any) -> None:
    """Report an error when ``value`` is None but the target node is not Optional.

    With a ``key``, the existing child under that key is checked (an absent
    child is accepted). With ``key=None``, this node itself is checked.
    Non-None values are always accepted. Errors are reported through
    ``self._format_and_raise``.
    """
    from omegaconf import OmegaConf

    if not OmegaConf.is_none(value):
        return

    if key is None:
        if not self._is_optional():
            self._format_and_raise(
                key=None,
                value=value,
                cause=ValidationError("field '$FULL_KEY' is not Optional"),
            )
        return

    child = self._get_node(key)
    if child is not None and not child._is_optional():
        self._format_and_raise(
            key=key,
            value=value,
            cause=ValidationError("child '$FULL_KEY' is not Optional"),
        )
def _set_value_impl(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None:
    """Store ``value`` as this list node's content.

    Accepts None (only when Optional), ``"???"``, interpolation strings, a
    ListConfig, or a primitive list. Assigning a node to itself is a no-op.
    When ``value`` is a ListConfig, its metadata is deep-copied and its flags
    are replaced by a deep copy of ``flags`` (``{}`` when not given).

    :raises ValidationError: for None on a non-Optional node, or for a value
        that is neither a ListConfig nor a primitive list.
    """
    from omegaconf import OmegaConf, flag_override

    if id(self) == id(value):
        return

    if flags is None:
        flags = {}

    if OmegaConf.is_none(value):
        if not self._is_optional():
            raise ValidationError(
                "Non optional ListConfig cannot be constructed from None")
        self.__dict__["_content"] = None
    elif get_value_kind(value) == ValueKind.MANDATORY_MISSING:
        self.__dict__["_content"] = "???"
    elif get_value_kind(value) in (
            ValueKind.INTERPOLATION,
            ValueKind.STR_INTERPOLATION,
    ):
        self.__dict__["_content"] = value
    else:
        if not (is_primitive_list(value) or isinstance(value, ListConfig)):
            type_ = type(value)
            msg = f"Invalid value assigned : {type_.__name__} is not a ListConfig, list or tuple."
            raise ValidationError(msg)

        self.__dict__["_content"] = []
        if isinstance(value, ListConfig):
            self.__dict__["_metadata"] = copy.deepcopy(value._metadata)
            self._metadata.flags = copy.deepcopy(flags)
            # disable struct and readonly for the construction phase
            # retaining other flags like allow_objects. The real flags are restored at the end of this function
            # NOTE(review): nothing here restores flags explicitly; presumably
            # flag_override restores the prior values on exit — confirm.
            with flag_override(self, "struct", False):
                with flag_override(self, "readonly", False):
                    # resolve=False keeps interpolations unresolved in the copy.
                    for item in value._iter_ex(resolve=False):
                        self.append(item)
        elif is_primitive_list(value):
            for item in value:
                self.append(item)
def call(
    config: Union[ObjectConf, TargetConf, DictConfig, Dict[Any, Any]],
    *args: Any,
    **kwargs: Any,
) -> Any:
    """
    :param config: An object describing what to call and what params to use
    :param args: optional positional parameters pass-through
    :param kwargs: optional named parameters pass-through
    :return: the return value from the specified class or method,
             or None when the (converted) config is None
    :raises InstantiationException: if a TargetConf's ``_target_`` is still "???";
             InstantiationExceptions raised while calling are re-raised unchanged
    :raises HydraException: wrapping any other error raised while copying,
             locating or calling the target (including a non-DictConfig input)
    """
    if isinstance(config, TargetConf) and config._target_ == "???":
        raise InstantiationException(
            f"Missing _target_ value. Check that you specified it in '{type(config).__name__}'"
            f" and that the type annotation is correct: '_target_: str = ...'")
    # Plain dicts and dataclass-style configs are converted to a DictConfig first.
    if isinstance(config, (dict, ObjectConf, TargetConf)):
        config = OmegaConf.structured(config)

    if OmegaConf.is_none(config):
        return None

    # Pre-bound so the error message below works even if _get_cls_name fails.
    cls = "<unknown>"
    try:
        assert isinstance(config, DictConfig)
        # make a copy to ensure we do not change the provided object
        config = copy.deepcopy(config)
        OmegaConf.set_readonly(config, False)
        OmegaConf.set_struct(config, False)
        cls = _get_cls_name(config)
        type_or_callable = _locate(cls)
        if isinstance(type_or_callable, type):
            return _instantiate_class(type_or_callable, config, *args, **kwargs)
        else:
            assert callable(type_or_callable)
            return _call_callable(type_or_callable, config, *args, **kwargs)
    except InstantiationException as e:
        raise e
    except Exception as e:
        raise HydraException(f"Error calling '{cls}' : {e}") from e
def verify(
    cfg: Any,
    key: Any,
    none: bool,
    opt: bool,
    missing: bool,
    inter: bool,
    exp: Any = SKIP,
) -> None:
    """Assert the state of ``cfg[key]`` through node internals and the OmegaConf API.

    ``exp``, when supplied, is compared with ``cfg.get(key)``.
    """
    node = cfg._get_node(key)
    for probe, expected in (
        (node._key, key),
        (node._is_none, none),
        (node._is_optional, opt),
        (node._is_missing, missing),
        (node._is_interpolation, inter),
    ):
        assert probe() == expected

    if exp is not SKIP:
        assert cfg.get(key) == exp

    assert OmegaConf.is_missing(cfg, key) == missing
    assert OmegaConf.is_none(cfg, key) == none
    assert OmegaConf.is_optional(cfg, key) == opt
    assert OmegaConf.is_interpolation(cfg, key) == inter
def test_is_none(fac: Any, is_none: bool) -> None:
    """``OmegaConf.is_none`` agrees for a bare node and the same node stored under a key."""
    node = fac(is_none)
    assert OmegaConf.is_none(node) == is_none

    wrapper = OmegaConf.create({"node": node})
    assert OmegaConf.is_none(wrapper, "node") == is_none
def _get_kwargs(
    config: Union[DictConfig, ListConfig],
    root: bool = True,
    **kwargs: Any,
) -> Any:
    """Build the keyword arguments used to instantiate ``config``'s target.

    ``kwargs`` are merged into ``config`` first. Note that ``config.merge_with``
    mutates ``config`` in place. When ``_is_recursive`` says so, nested values
    carrying a target are instantiated via ``hydra.utils.instantiate``.
    Values are read with ``resolve=False``, so interpolations are carried over
    unresolved.

    :param config: a DictConfig; a ListConfig is handled element-wise.
    :param root: False for nested calls. Non-root results keep ``config``'s
        ``object_type``.
    :param kwargs: overrides merged into ``config`` before extraction.
    :return: a DictConfig (flags reset to inherit, parent = ``config``'s parent)
        for a DictConfig input, or a plain ``list`` for a ListConfig input.
    """
    from hydra.utils import instantiate

    assert OmegaConf.is_config(config)

    if OmegaConf.is_list(config):
        assert isinstance(config, ListConfig)
        return [
            _get_kwargs(x, root=False) if OmegaConf.is_config(x) else x
            for x in config
        ]

    assert OmegaConf.is_dict(
        config), "Input config is not an OmegaConf DictConfig"

    recursive = _is_recursive(config, kwargs)
    overrides = OmegaConf.create(kwargs, flags={"allow_objects": True})
    config.merge_with(overrides)  # mutates the caller's config

    final_kwargs = OmegaConf.create(flags={"allow_objects": True})
    # Same parent as the source, so relative interpolations still resolve.
    final_kwargs._set_parent(config._get_parent())
    # Unlocked while populating; flags are reset to None (inherit) at the end.
    final_kwargs._set_flag("readonly", False)
    final_kwargs._set_flag("struct", False)
    if recursive:
        for k, v in config.items_ex(resolve=False):
            if OmegaConf.is_none(v):
                final_kwargs[k] = v
            elif _is_target(v):
                final_kwargs[k] = instantiate(v)
            elif OmegaConf.is_dict(v):
                d = OmegaConf.create({}, flags={"allow_objects": True})
                for key, value in v.items_ex(resolve=False):
                    if _is_target(value):
                        d[key] = instantiate(value)
                    elif OmegaConf.is_config(value):
                        d[key] = _get_kwargs(value, root=False)
                    else:
                        d[key] = value
                # Preserve the structured type of the original nested dict.
                d._metadata.object_type = v._metadata.object_type
                final_kwargs[k] = d
            elif OmegaConf.is_list(v):
                lst = OmegaConf.create([], flags={"allow_objects": True})
                for x in v:
                    if _is_target(x):
                        lst.append(instantiate(x))
                    elif OmegaConf.is_config(x):
                        lst.append(_get_kwargs(x, root=False))
                        lst[-1]._metadata.object_type = x._metadata.object_type
                    else:
                        lst.append(x)
                final_kwargs[k] = lst
            else:
                final_kwargs[k] = v
    else:
        for k, v in config.items_ex(resolve=False):
            final_kwargs[k] = v

    final_kwargs._set_flag("readonly", None)
    final_kwargs._set_flag("struct", None)
    final_kwargs._set_flag("allow_objects", None)
    if not root:
        # This is tricky, since the root kwargs is exploded anyway we can treat is as an untyped dict
        # the motivation is that the object type is used as an indicator to treat the object differently during
        # conversion to a primitive container in some cases
        final_kwargs._metadata.object_type = config._metadata.object_type
    return final_kwargs
def _validate_set_merge_impl(self, key: Any, value: Any, is_assign: bool) -> None:
    """Validate assigning (``is_assign=True``) or merging ``value`` into this node.

    ``key`` selects a child; ``key=None`` validates against this node itself.
    Interpolations and ``"???"`` are always accepted. None is rejected for
    non-Optional targets via ``self._format_and_raise``.

    :raises ValidationError: when the value's type is not a subclass of a typed
        DictConfig target's ref_type. For merges, dict-typed values are exempt
        because the merge performs deeper checks.
    """
    from omegaconf import OmegaConf

    vk = get_value_kind(value)
    if vk in (ValueKind.INTERPOLATION, ValueKind.STR_INTERPOLATION):
        return

    if OmegaConf.is_none(value):
        if key is not None:
            child = self._get_node(key)
            if child is not None and not child._is_optional():
                self._format_and_raise(
                    key=key,
                    value=value,
                    cause=ValidationError(
                        "child '$FULL_KEY' is not Optional"),
                )
        else:
            if not self._is_optional():
                self._format_and_raise(
                    key=None,
                    value=value,
                    cause=ValidationError(
                        "field '$FULL_KEY' is not Optional"),
                )

    if value == "???":
        return

    # Value nodes assigned into containers with an element type get a dedicated check.
    if is_assign and isinstance(value, ValueNode) and self._has_element_type():
        self._check_assign_value_node(key, value)
        return

    target: Optional[Node]
    if key is None:
        target = self
    else:
        target = self._get_node(key)

    if target is None:
        return

    def is_typed(c: Any) -> bool:
        return isinstance(
            c, DictConfig) and c._metadata.ref_type not in (Any, dict)

    if not is_typed(target):
        return

    # target must be optional by now. no need to check the type of value if None.
    if value is None:
        return

    target_type = target._metadata.ref_type
    value_type = OmegaConf.get_type(value)

    if is_dict(value_type) and is_dict(target_type):
        return
    if is_generic_list(target_type) or is_generic_dict(target_type):
        return
    # TODO: simplify this flow. (if assign validate assign else validate merge)
    # is assignment illegal?
    validation_error = (target_type is not None and value_type is not None
                        and not issubclass(value_type, target_type))

    if not is_assign:
        # merge
        # Merging of a dictionary is allowed even if assignment is illegal (merge would do deeper checks)
        validation_error = not is_dict(value_type) and validation_error

    if validation_error:
        assert value_type is not None
        assert target_type is not None
        msg = (f"Invalid type assigned : {type_str(value_type)} is not a "
               f"subclass of {type_str(target_type)}. value: {value}")
        raise ValidationError(msg)
def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
    """
    :param config: A config object describing what to call and what params to use.
                   In addition to the parameters, the config must contain:
                   _target_ : target class or callable name (str)
                   And may contain:
                   _recursive_: Construct nested objects as well (bool).
                                True by default.
                                may be overridden via a _recursive_ key in
                                the kwargs
                   _convert_: Conversion strategy
                        none    : Passed objects are DictConfig and ListConfig, default
                        partial : Passed objects are converted to dict and list, with
                                  the exception of Structured Configs (and their fields).
                        all     : Passed objects are dicts, lists and primitives without
                                  a trace of OmegaConf containers
    :param args: Optional positional parameters pass-through
    :param kwargs: Optional named parameters to override
                   parameters in the config object. Parameters not present
                   in the config objects are being passed as is to the target.
    :return: if _target_ is a class name: the instantiated object
             if _target_ is a callable: the return value of the call
             None if ``config`` is None
    :raises InstantiationException: if a TargetConf's ``_target_`` is still "???"
    :raises HydraException: for an unsupported config type
    """
    if OmegaConf.is_none(config):
        return None

    if isinstance(config, TargetConf) and config._target_ == "???":
        # Specific check to give a good warning about failure to annotate _target_ as a string.
        raise InstantiationException(
            f"Missing value for {type(config).__name__}._target_. Check that it's properly annotated and overridden."
            f"\nA common problem is forgetting to annotate _target_ as a string : '_target_: str = ...'"
        )

    if not (
        isinstance(config, dict)
        or OmegaConf.is_config(config)
        or is_structured_config(config)
    ):
        raise HydraException(f"Unsupported config type : {type(config).__name__}")

    # Shallow copies, so the top-level caller containers are not modified.
    # NOTE(review): nested containers are shared with the caller; whether
    # _convert_container_targets_to_strings mutates them is not visible here.
    if isinstance(config, dict):
        configc = config.copy()
        _convert_container_targets_to_strings(configc)
        config = configc

    kwargsc = kwargs.copy()
    _convert_container_targets_to_strings(kwargsc)
    kwargs = kwargsc

    # make a copy to ensure we do not change the provided object
    config_copy = OmegaConf.structured(config, flags={"allow_objects": True})
    if OmegaConf.is_config(config):
        config_copy._set_parent(config._get_parent())
    config = config_copy
    assert OmegaConf.is_config(config)
    OmegaConf.set_readonly(config, False)
    OmegaConf.set_struct(config, False)
    target = _get_target_type(config, kwargs)
    try:
        config._set_flag("allow_objects", True)
        final_kwargs = _get_kwargs(config, **kwargs)
        convert = _pop_convert_mode(final_kwargs)
        if convert == ConvertMode.PARTIAL:
            final_kwargs = OmegaConf.to_container(
                final_kwargs, resolve=True, exclude_structured_configs=True
            )
            return target(*args, **final_kwargs)
        elif convert == ConvertMode.ALL:
            final_kwargs = OmegaConf.to_container(final_kwargs, resolve=True)
            return target(*args, **final_kwargs)
        elif convert == ConvertMode.NONE:
            return target(*args, **final_kwargs)
        else:
            assert False
    except Exception as e:
        # preserve the original exception backtrace
        # NOTE(review): type(e)(msg) assumes the exception class accepts a single
        # message argument; classes with other __init__ signatures will raise a
        # TypeError here instead of the intended message.
        raise type(e)(
            f"Error instantiating/calling '{_convert_target_to_string(target)}' : {e}"
        ).with_traceback(sys.exc_info()[2])
def _validate_set_merge_impl(self, key: Any, value: Any, is_assign: bool) -> None:
    """Validate assigning (``is_assign=True``) or merging ``value`` into this node.

    ``key`` selects a child; ``key=None`` validates against this node itself.
    Interpolations and ``"???"`` are always accepted.

    :raises ValidationError: (via ``_format_and_raise``) when None is given for
        a non-Optional target. Also raised directly when the value's type is not
        a subclass of a typed DictConfig target's ref_type. For merges,
        dict-typed values are exempt.
    :raises ReadonlyConfigError: when the target or this node is read-only.
    """
    from omegaconf import OmegaConf

    vk = get_value_kind(value)
    if vk in (ValueKind.INTERPOLATION, ValueKind.STR_INTERPOLATION):
        return

    if OmegaConf.is_none(value):
        if key is not None:
            node = self._get_node(key)
            if node is not None and not node._is_optional():
                self._format_and_raise(
                    key=key,
                    value=value,
                    cause=ValidationError("field '$FULL_KEY' is not Optional"),
                )
        else:
            if not self._is_optional():
                # Route through _format_and_raise (like the key branch above) so the
                # $FULL_KEY placeholder is substituted with this node's context instead
                # of leaking into the message verbatim.
                self._format_and_raise(
                    key=None,
                    value=value,
                    cause=ValidationError("field '$FULL_KEY' is not Optional"),
                )

    if value == "???":
        return

    target: Optional[Node]
    if key is None:
        target = self
    else:
        target = self._get_node(key)

    if (target is not None and target._get_flag("readonly")) or self._get_flag(
        "readonly"
    ):
        if is_assign:
            msg = f"Cannot assign to read-only node : {value}"
        else:
            msg = f"Cannot merge into read-only node : {value}"
        raise ReadonlyConfigError(msg)

    if target is None:
        return

    def is_typed(c: Any) -> bool:
        return isinstance(c, DictConfig) and c._metadata.ref_type not in (Any, dict)

    if not is_typed(target):
        return

    # target must be optional by now. no need to check the type of value if None.
    if value is None:
        return

    target_type = target._metadata.ref_type
    value_type = OmegaConf.get_type(value)

    if is_dict(value_type) and is_dict(target_type):
        return

    # is assignment illegal?
    validation_error = (
        target_type is not None
        and value_type is not None
        and not issubclass(value_type, target_type)
    )

    if not is_assign:
        # merge
        # Merging of a dictionary is allowed even if assignment is illegal (merge would do deeper checks)
        validation_error = not is_dict(value_type) and validation_error

    if validation_error:
        assert value_type is not None
        assert target_type is not None
        msg = (
            f"Invalid type assigned : {type_str(value_type)} is not a "
            f"subclass of {type_str(target_type)}. value: {value}"
        )
        raise ValidationError(msg)
def format_and_raise(
    node: Any,
    key: Any,
    value: Any,
    msg: str,
    cause: Exception,
    type_override: Any = None,
) -> None:
    """Raise an exception enriched with the context of ``node``/``key``.

    If ``cause`` is an already-initialized OmegaConfBaseException, it is re-raised
    unchanged, or re-typed to ``type_override`` with its attributes copied.
    Otherwise ``msg`` is treated as a ``string.Template``. Its placeholders
    ($REF_TYPE, $OBJECT_TYPE, $KEY, $FULL_KEY, $VALUE, $VALUE_TYPE, $KEY_TYPE)
    are filled in, and a new exception of ``type(cause)`` (or ``type_override``)
    is built. Plain TypeError/IndexError are mapped to
    ConfigTypeError/ConfigIndexError.

    Note that ``Template.substitute`` raises KeyError/ValueError if ``msg``
    contains an unknown placeholder or a stray ``$``.
    The actual raising is delegated to ``_raise``, presumably chaining ``cause``.
    """
    from omegaconf import OmegaConf
    from omegaconf.base import Node

    if isinstance(cause, OmegaConfBaseException) and cause._initialized:
        ex = cause
        if type_override is not None:
            ex = type_override(str(cause))
            ex.__dict__ = copy.deepcopy(cause.__dict__)
        _raise(ex, cause)

    object_type: Optional[Type[Any]]
    object_type_str: Optional[str] = None
    ref_type: Optional[Type[Any]]
    ref_type_str: Optional[str]

    child_node: Optional[Node] = None
    if node is None:
        full_key = ""
        object_type = None
        ref_type = None
        ref_type_str = None
    else:
        # A None-valued container has no children to look up.
        if key is not None and not OmegaConf.is_none(node):
            child_node = node._get_node(key, validate_access=False, disable_warning=True)

        full_key = node._get_full_key(key=key, disable_warning=True)

        object_type = OmegaConf.get_type(node)
        object_type_str = type_str(object_type)

        ref_type = get_ref_type(node)
        ref_type_str = type_str(ref_type)

    msg = string.Template(msg).substitute(
        REF_TYPE=ref_type_str,
        OBJECT_TYPE=object_type_str,
        KEY=key,
        FULL_KEY=full_key,
        VALUE=value,
        VALUE_TYPE=f"{type(value).__name__}",
        KEY_TYPE=f"{type(key).__name__}",
    )

    template = """$MSG \tfull_key: $FULL_KEY \treference_type=$REF_TYPE \tobject_type=$OBJECT_TYPE"""
    s = string.Template(template=template)

    message = s.substitute(
        REF_TYPE=ref_type_str, OBJECT_TYPE=object_type_str, MSG=msg, FULL_KEY=full_key,
    )
    exception_type = type(cause) if type_override is None else type_override
    if exception_type == TypeError:
        exception_type = ConfigTypeError
    elif exception_type == IndexError:
        exception_type = ConfigIndexError

    ex = exception_type(f"{message}")
    # Attach structured context; _initialized marks it so nested calls re-raise as-is.
    if issubclass(exception_type, OmegaConfBaseException):
        ex._initialized = True
        ex.msg = message
        ex.parent_node = node
        ex.child_node = child_node
        ex.key = key
        ex.full_key = full_key
        ex.value = value
        ex.object_type = object_type
        ex.object_type_str = object_type_str
        ex.ref_type = ref_type
        ex.ref_type_str = ref_type_str

    _raise(ex, cause)
def _map_merge(dest: "BaseContainer", src: "BaseContainer") -> None:
    """Merge DictConfig ``src`` into DictConfig ``dest`` in place; returns None.

    ``dest`` is mutated, including its ref/object types and any flags that are
    explicitly set on ``src``. An interpolation or None ``src`` simply replaces
    ``dest``'s value. Child containers are merged recursively via
    ``_merge_with``; value nodes are assigned or re-validated.
    """
    from omegaconf import AnyNode, DictConfig, OmegaConf, ValueNode

    assert isinstance(dest, DictConfig)
    assert isinstance(src, DictConfig)
    src_type = src._metadata.object_type
    src_ref_type = get_ref_type(src)
    assert src_ref_type is not None

    # If source DictConfig is:
    #  - an interpolation => set the destination DictConfig to be the same interpolation
    #  - None => set the destination DictConfig to None
    if src._is_interpolation() or src._is_none():
        dest._set_value(src._value())
        _update_types(node=dest, ref_type=src_ref_type, object_type=src_type)
        return

    dest._validate_merge(value=src)

    def expand(node: Container) -> None:
        # Replace a None/missing/interpolated container with an empty value of
        # its ref_type ({} / [] / the structured type itself) so it can be merged into.
        rt = node._metadata.ref_type
        val: Any
        if rt is not Any:
            if is_dict_annotation(rt):
                val = {}
            elif is_list_annotation(rt):
                val = []
            else:
                val = rt
        elif isinstance(node, DictConfig):
            val = {}
        else:
            assert False

        node._set_value(val)

    if (
        src._is_missing()
        and not dest._is_missing()
        and is_structured_config(src_ref_type)
    ):
        # Replace `src` with a prototype of its corresponding structured config
        # whose fields are all missing (to avoid overwriting fields in `dest`).
        src = _create_structured_with_missing_fields(
            ref_type=src_ref_type, object_type=src_type
        )

    if (dest._is_interpolation() or dest._is_missing()) and not src._is_missing():
        expand(dest)

    for key, src_value in src.items_ex(resolve=False):
        src_node = src._get_node(key, validate_access=False)
        dest_node = dest._get_node(key, validate_access=False)
        assert src_node is None or isinstance(src_node, Node)
        assert dest_node is None or isinstance(dest_node, Node)

        if isinstance(dest_node, DictConfig):
            dest_node._validate_merge(value=src_node)

        missing_src_value = _is_missing_value(src_value)

        if (
            isinstance(dest_node, Container)
            and OmegaConf.is_none(dest, key)
            and not missing_src_value
            and not OmegaConf.is_none(src_value)
        ):
            expand(dest_node)

        # An interpolated container in dest is replaced by the node it points to.
        if dest_node is not None and dest_node._is_interpolation():
            target_node = dest_node._dereference_node(
                throw_on_resolution_failure=False
            )
            if isinstance(target_node, Container):
                dest[key] = target_node
                dest_node = dest._get_node(key)

        if (
            dest_node is None
            and is_structured_config(dest._metadata.element_type)
            and not missing_src_value
        ):
            # merging into a new node. Use element_type as a base
            dest[key] = DictConfig(content=dest._metadata.element_type, parent=dest)
            dest_node = dest._get_node(key)

        if dest_node is not None:
            if isinstance(dest_node, BaseContainer):
                if isinstance(src_value, BaseContainer):
                    dest_node._merge_with(src_value)
                elif not missing_src_value:
                    dest.__setitem__(key, src_value)
            else:
                if isinstance(src_value, BaseContainer):
                    dest.__setitem__(key, src_value)
                else:
                    assert isinstance(dest_node, ValueNode)
                    assert isinstance(src_node, ValueNode)
                    # Compare to literal missing, ignoring interpolation
                    src_node_missing = _is_missing_literal(src_value)
                    try:
                        if isinstance(dest_node, AnyNode):
                            if src_node_missing:
                                node = copy.copy(src_node)
                                # if src node is missing, use the value from the dest_node,
                                # but validate it against the type of the src node before assigment
                                node._set_value(dest_node._value())
                            else:
                                node = src_node
                            dest.__setitem__(key, node)
                        else:
                            if not src_node_missing:
                                dest_node._set_value(src_value)

                    except (ValidationError, ReadonlyConfigError) as e:
                        dest._format_and_raise(key=key, value=src_value, cause=e)
        else:
            from omegaconf import open_dict

            if is_structured_config(src_type):
                # verified to be compatible above in _validate_merge
                with open_dict(dest):
                    dest[key] = src._get_node(key)
            else:
                dest[key] = src._get_node(key)

    _update_types(node=dest, ref_type=src_ref_type, object_type=src_type)

    # explicit flags on the source config are replacing the flag values in the destination
    flags = src._metadata.flags
    assert flags is not None
    for flag, value in flags.items():
        if value is not None:
            dest._set_flag(flag, value)
def instantiate_node(
    node: Any,
    *args: Any,
    convert: Union[str, ConvertMode] = ConvertMode.NONE,
    recursive: bool = True,
) -> Any:
    """Instantiate ``node`` and, depending on ``recursive``, its children.

    - None, or a None-valued config: returns None.
    - Non-config values: returned unchanged.
    - Lists: every item is passed through ``instantiate_node``. The result is a
      plain list for ALL/PARTIAL conversion, otherwise a ListConfig parented
      to ``node``.
    - Dicts with a target: the target is called with the remaining keys as
      kwargs. ``args`` are forwarded only here.
    - Other dicts: a plain dict for ALL (or PARTIAL on a non-structured dict),
      otherwise a DictConfig that keeps ``node``'s object_type.

    ``_convert_`` / ``_recursive_`` keys on a dict override the inherited modes.

    :raises TypeError: if the effective ``_recursive_`` value is not a bool.
    """
    # Return None if config is None
    if node is None or OmegaConf.is_none(node):
        return None

    if not OmegaConf.is_config(node):
        return node

    # Override parent modes from config if specified
    if OmegaConf.is_dict(node):
        # using getitem instead of get(key, default) because OmegaConf will raise an exception
        # if the key type is incompatible on get.
        convert = node[_Keys.CONVERT] if _Keys.CONVERT in node else convert
        recursive = node[_Keys.RECURSIVE] if _Keys.RECURSIVE in node else recursive

    if not isinstance(recursive, bool):
        raise TypeError(f"_recursive_ flag must be a bool, got {type(recursive)}")

    # If OmegaConf list, create new list of instances if recursive
    # NOTE(review): the list branch itself does not check `recursive`; items are
    # always passed through instantiate_node.
    if OmegaConf.is_list(node):
        items = [
            instantiate_node(item, convert=convert, recursive=recursive)
            for item in node._iter_ex(resolve=True)
        ]

        if convert in (ConvertMode.ALL, ConvertMode.PARTIAL):
            # If ALL or PARTIAL, use plain list as container
            return items
        else:
            # Otherwise, use ListConfig as container
            lst = OmegaConf.create(items, flags={"allow_objects": True})
            lst._set_parent(node)
            return lst

    elif OmegaConf.is_dict(node):
        exclude_keys = set({"_target_", "_convert_", "_recursive_"})
        if _is_target(node):
            target = _resolve_target(node.get(_Keys.TARGET))
            kwargs = {}
            for key, value in node.items_ex(resolve=True):
                if key not in exclude_keys:
                    if recursive:
                        value = instantiate_node(value, convert=convert, recursive=recursive)
                    kwargs[key] = _convert_node(value, convert)

            return _call_target(target, *args, **kwargs)
        else:
            # If ALL or PARTIAL non structured, instantiate in dict and resolve interpolations eagerly.
            if convert == ConvertMode.ALL or (
                convert == ConvertMode.PARTIAL and node._metadata.object_type is None
            ):
                dict_items = {}
                for key, value in node.items_ex(resolve=True):
                    # list items inherits recursive flag from the containing dict.
                    dict_items[key] = instantiate_node(value, convert=convert, recursive=recursive)
                return dict_items
            else:
                # Otherwise use DictConfig and resolve interpolations lazily.
                cfg = OmegaConf.create({}, flags={"allow_objects": True})
                for key, value in node.items_ex(resolve=False):
                    cfg[key] = instantiate_node(value, convert=convert, recursive=recursive)
                cfg._set_parent(node)
                cfg._metadata.object_type = node._metadata.object_type
                return cfg

    else:
        assert False, f"Unexpected config type : {type(node).__name__}"
def test_is_none_invalid_node() -> None:
    """Querying ``is_none`` for an absent key warns and reports True."""
    empty_cfg = OmegaConf.create({})
    with warns(UserWarning):
        absent_is_none = OmegaConf.is_none(empty_cfg, "invalid")
        assert absent_is_none
def _map_merge(dest: "BaseContainer", src: "BaseContainer") -> None:
    """Merge ``src`` into ``dest`` in place; returns None.

    ``dest`` is mutated: its values, its ``object_type`` metadata (unless
    ``src``'s type is a primitive dict) and every flag explicitly set on
    ``src``. Both arguments must be DictConfigs (asserted below).

    :param dest: the DictConfig being merged into.
    :param src: the DictConfig whose entries are merged in.
    Errors of type ValidationError / ReadonlyConfigError raised while assigning
    a value node are handed to ``dest._format_and_raise`` (presumably re-raised
    with key/value context).
    """
    from omegaconf import DictConfig, OmegaConf, ValueNode

    assert isinstance(dest, DictConfig)
    assert isinstance(src, DictConfig)
    src_type = src._metadata.object_type

    # if source DictConfig is missing set the DictConfig one to be missing too.
    if src._is_missing():
        dest._set_value("???")
        return
    # Compatibility check for merging src into dest as a whole (relied on by the
    # structured-config branch near the end of the loop).
    dest._validate_set_merge_impl(key=None, value=src, is_assign=False)

    def expand(node: Container) -> None:
        # Give a missing/None container an empty value of its declared ref type
        # so src values can be merged into it: {} for dict annotations, [] for
        # list annotations, otherwise the type itself (presumably a structured
        # config class). Does nothing when the node has no ref type.
        type_ = get_ref_type(node)
        if type_ is not None:
            _is_optional, type_ = _resolve_optional(type_)
            if is_dict_annotation(type_):
                node._set_value({})
            elif is_list_annotation(type_):
                node._set_value([])
            else:
                node._set_value(type_)

    if dest._is_missing():
        expand(dest)

    for key, src_value in src.items_ex(resolve=False):

        dest_node = dest._get_node(key, validate_access=False)
        # A None container in dest is expanded before a non-None src value is merged into it.
        if isinstance(dest_node, Container) and OmegaConf.is_none(
                dest, key):
            if not OmegaConf.is_none(src_value):
                expand(dest_node)

        if dest_node is not None:
            # If dest holds an interpolation that resolves to a container, copy
            # that container into dest[key] so the merge below targets a
            # concrete node instead of the interpolation.
            if dest_node._is_interpolation():
                target_node = dest_node._dereference_node(
                    throw_on_resolution_failure=False)
                if isinstance(target_node, Container):
                    dest[key] = target_node
                    dest_node = dest._get_node(key)

        # NOTE(review): this branch does not look at dest_node, so when dest
        # already has a node for `key` it is replaced by a fresh element_type
        # DictConfig before src_value is merged in. Existing values in dest[key]
        # appear to be discarded; confirm whether a `dest_node is None` guard
        # was intended.
        if is_structured_config(dest._metadata.element_type):
            dest[key] = DictConfig(content=dest._metadata.element_type, parent=dest)
            dest_node = dest._get_node(key)

        if dest_node is not None:
            if isinstance(dest_node, BaseContainer):
                if isinstance(src_value, BaseContainer):
                    # container <- container: validate, then merge recursively.
                    dest._validate_merge(key=key, value=src_value)
                    dest_node._merge_with(src_value)
                else:
                    dest.__setitem__(key, src_value)
            else:
                if isinstance(src_value, BaseContainer):
                    dest.__setitem__(key, src_value)
                else:
                    # value node <- plain value: assign in place so the node's
                    # own validation applies.
                    assert isinstance(dest_node, ValueNode)
                    try:
                        dest_node._set_value(src_value)
                    except (ValidationError, ReadonlyConfigError) as e:
                        dest._format_and_raise(key=key, value=src_value, cause=e)
        else:
            # dest has no node for this key: copy src's node over.
            from omegaconf import open_dict

            if is_structured_config(src_type):
                # verified to be compatible above in _validate_set_merge_impl
                # (open_dict presumably allows adding the new key in struct mode)
                with open_dict(dest):
                    dest[key] = src._get_node(key)
            else:
                dest[key] = src._get_node(key)

    # Adopt src's object type unless it is a primitive dict type.
    if src_type is not None and not is_primitive_dict(src_type):
        dest._metadata.object_type = src_type

    # explicit flags on the source config are replacing the flag values in the destination
    flags = src._metadata.flags
    assert flags is not None
    for flag, value in flags.items():
        if value is not None:
            dest._set_flag(flag, value)
def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
    """Instantiate the object (or call the function) described by ``config``.

    :param config: Description of what to call and with which parameters: an
        OmegaConf DictConfig, a plain dict, or a Structured Config class or
        instance. It must contain:
            _target_: name of the target class or callable (str)
        and may contain:
            _args_: list-like of positional arguments to pass to the target
            _recursive_: construct nested objects as well (bool, True by
                default); may be overridden via a ``_recursive_`` key in kwargs
            _convert_: conversion strategy for objects passed to the target
                none    : DictConfig and ListConfig are passed (default)
                partial : converted to dict and list, except Structured Configs
                          (and their fields)
                all     : dicts, lists and primitives with no trace of
                          OmegaConf containers
    :param args: Optional positional parameters passed through to the target.
    :param kwargs: Optional named parameters overriding those in the config.
        Parameters absent from the config are passed to the target as-is.
        IMPORTANT: dataclass instances in kwargs are interpreted as configs and
        cannot be used as pass-through values.
    :return: the instantiated object if ``_target_`` names a class, the return
        value of the call if it names a callable, or ``None`` when ``config``
        is None.
    :raises InstantiationException: when a TargetConf's ``_target_`` was never
        set, or when the top-level config is not dict-like.
    """
    # A None config (or a None-valued config node) instantiates to None.
    if config is None or OmegaConf.is_none(config):
        return None

    # TargetConf edge case: point the user at a missing `_target_: str` annotation.
    if isinstance(config, TargetConf) and config._target_ == "???":
        raise InstantiationException(
            f"Missing value for {type(config).__name__}._target_. Check that it's properly annotated and overridden."
            f"\nA common problem is forgetting to annotate _target_ as a string : '_target_: str = ...'"
        )

    if isinstance(config, dict):
        config = _prepare_input_dict(config)
    kwargs = _prepare_input_dict(kwargs)

    # Plain dicts and Structured Configs are always turned into OmegaConf first.
    if is_structured_config(config) or isinstance(config, dict):
        config = OmegaConf.structured(config, flags={"allow_objects": True})

    if not OmegaConf.is_dict(config):
        raise InstantiationException(
            "Top level config has to be OmegaConf DictConfig, plain dict, or a Structured Config class or instance"
        )

    # Work on a writable, non-struct deep copy (re-attached to the original's
    # parent) so the caller's config is not mutated by the merge/pops below.
    working = copy.deepcopy(config)
    working._set_flag(
        flags=["allow_objects", "struct", "readonly"],
        values=[True, False, False],
    )
    working._set_parent(config._get_parent())

    if kwargs:
        working = OmegaConf.merge(working, kwargs)

    # Top-level mode keys are consumed here and handed to instantiate_node.
    recursive_mode = working.pop(_Keys.RECURSIVE, True)
    convert_mode = working.pop(_Keys.CONVERT, ConvertMode.NONE)
    return instantiate_node(
        working, *args, recursive=recursive_mode, convert=convert_mode
    )
def _param_in_cfg(self, param):
    """Return whether ``param`` holds a non-None value in ``self.cfg``.

    The check runs inside ``open_dict`` (presumably so a key absent from a
    struct-mode config can be queried without raising). ``OmegaConf.is_none``
    reports an absent key as None, so an absent param yields False.
    """
    cfg = self.cfg
    with open_dict(cfg):
        missing_or_none = OmegaConf.is_none(cfg, param)
    return not missing_or_none
def _is_target(x: Any) -> bool: if isinstance(x, dict): return "_target_" in x if OmegaConf.is_dict(x) and not OmegaConf.is_none(x): return "_target_" in x return False
def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
    """Instantiate the object (or call the function) described by ``config``.

    :param config: A config object describing what to call and what params to
        use: a plain dict, an OmegaConf config, or a Structured Config. In
        addition to the parameters, it must contain:
            _target_ : target class or callable name (str)
            _recursive_: construct nested objects as well (bool). True by
                default; may be overridden via a ``_recursive_`` key in kwargs.
    :param args: Optional positional parameters passed through to the target.
    :param kwargs: Optional named parameters overriding parameters in the
        config. Parameters not present in the config are passed to the target
        as-is.
    :return: if ``_target_`` is a class name, the instantiated object; if it is
        a callable, the return value of the call; ``None`` if ``config`` is None.
    :raises InstantiationException: if a TargetConf's ``_target_`` is missing.
    :raises HydraException: if ``config`` is not a dict, OmegaConf config or
        Structured Config.
    :raises Exception: errors raised while building kwargs or calling the
        target are re-raised as the same exception type, with the target name
        prepended to the message, the original traceback kept, and the
        original exception attached as ``__cause__``. If that exception type
        cannot be rebuilt from a single message string, the original exception
        is re-raised unchanged.
    """
    if OmegaConf.is_none(config):
        return None

    if isinstance(config, TargetConf) and config._target_ == "???":
        # Specific check to give a good warning about failure to annotate _target_ as a string.
        raise InstantiationException(
            f"Missing value for {type(config).__name__}._target_. Check that it's properly annotated and overridden."
            f"\nA common problem is forgetting to annotate _target_ as a string : '_target_: str = ...'"
        )

    if not (
        isinstance(config, dict)
        or OmegaConf.is_config(config)
        or is_structured_config(config)
    ):
        raise HydraException(f"Unsupported config type : {type(config).__name__}")

    if isinstance(config, dict):
        # NOTE(review): shallow copy; nested dicts are still shared with the
        # caller if _convert_container_targets_to_strings mutates them in place.
        configc = config.copy()
        _convert_container_targets_to_strings(configc)
        config = configc

    kwargsc = kwargs.copy()
    _convert_container_targets_to_strings(kwargsc)
    kwargs = kwargsc

    # make a copy to ensure we do not change the provided object
    config_copy = OmegaConf.structured(config, flags={"allow_objects": True})
    if OmegaConf.is_config(config):
        config_copy._set_parent(config._get_parent())
    config = config_copy
    assert OmegaConf.is_config(config)
    OmegaConf.set_readonly(config, False)
    OmegaConf.set_struct(config, False)

    target = _get_target_type(config, kwargs)
    try:
        config._set_flag("allow_objects", True)
        final_kwargs = _get_kwargs(config, **kwargs)
        return target(*args, **final_kwargs)
    except Exception as e:
        message = f"Error instantiating/calling '{_convert_target_to_string(target)}' : {e}"
        try:
            wrapped = type(e)(message)
        except Exception:
            # Some exception types cannot be built from one message string
            # (e.g. UnicodeDecodeError needs five arguments); re-raise the
            # original below instead of masking it with an unrelated TypeError.
            wrapped = None
        if wrapped is None:
            raise
        # Keep the original traceback and chain the original as the cause.
        raise wrapped.with_traceback(e.__traceback__) from e