def collect_imports(imports: Set[Type], type_: Type) -> None:
    """Recursively add ``type_`` and every type nested inside it to ``imports``.

    The type arguments reported by ``get_args`` are visited first, skipping
    ``...`` placeholders. ``get_args`` reports the parameter list of
    ``Callable[[A, B], R]`` as a plain list (``([A, B], R)``). Its elements are
    therefore visited one by one; adding the list itself would fail because
    lists are unhashable. Finally ``type_`` itself is added. If
    ``_resolve_optional`` reports it as optional (and it is not ``Any``), it is
    recorded as ``Optional`` instead.

    Args:
        imports: Set that is updated in place.
        type_: The annotation to walk.
    """
    for arg in get_args(type_):
        # Flatten Callable parameter lists so each parameter type is visited.
        nested = arg if isinstance(arg, list) else [arg]
        for item in nested:
            if item is not ...:
                collect_imports(imports, item)
    if _resolve_optional(type_)[0] and type_ is not Any:
        type_ = Optional
    imports.add(type_)
def collect_imports(imports: Set[Any], type_: Any) -> None:
    """Add to ``imports`` the typing objects referenced by ``type_``.

    * A list annotation adds ``List`` and, recursively, whatever its element
      type contributes.
    * A dict annotation adds ``Dict`` and, recursively, whatever its key and
      value types contribute.
    * Any other annotation is added as-is. If ``_resolve_optional`` reports it
      as optional, ``Optional`` is added instead, unless ``type_`` is ``Any``
      itself.
    """
    if is_list_annotation(type_):
        element_type = get_list_element_type(type_)
        collect_imports(imports, element_type)
        imports.add(List)
        return

    if is_dict_annotation(type_):
        key_value = get_dict_key_value_types(type_)
        for part in (key_value[0], key_value[1]):
            collect_imports(imports, part)
        imports.add(Dict)
        return

    optional = _resolve_optional(type_)[0]
    imports.add(Optional if optional and type_ is not Any else type_)
def _issubclass_or_false(cls: Any, classinfo: Any) -> bool:
    """Like ``issubclass``, but return False instead of raising for non-classes.

    ``issubclass`` raises ``TypeError`` when ``cls`` is not a class. Examples
    are a subscripted typing construct such as ``Set[int]``, or ``typing.Any``
    on Python < 3.11.
    """
    try:
        return issubclass(cls, classinfo)
    except TypeError:
        return False


def is_incompatible(type_: Type[Any]) -> bool:
    """Return True if ``type_`` is an annotation treated as unsupported.

    Rules, applied in order:

    * unions other than ``Optional[...]`` are incompatible;
    * ``NoneType``, ``tuple``, ``list`` and ``dict`` are compatible;
    * list, dict, tuple and ``Callable`` annotations are checked recursively
      (element / value / argument / return types); dict keys must be
      ``str`` or ``Enum`` subclasses;
    * ``Any`` and subclasses of ``int``, ``float``, ``str``, ``bool`` or
      ``Enum`` are compatible;
    * structured configs are compatible only if ``OmegaConf.structured``
      accepts them;
    * everything else is incompatible.

    A ``ValidationError`` raised while inspecting container annotations also
    marks the type as incompatible.
    """
    opt = _resolve_optional(type_)
    # Unions are not supported (Except Optional)
    if not opt[0] and _is_union(type_):
        return True

    type_ = opt[1]
    if type_ in (type(None), tuple, list, dict):
        return False

    try:
        if is_list_annotation(type_):
            return is_incompatible(get_list_element_type(type_))
        if is_dict_annotation(type_):
            kvt = get_dict_key_value_types(type_)
            # Non-class key types (e.g. Any before Python 3.11) would make a
            # plain issubclass() raise TypeError; report them as incompatible.
            if not _issubclass_or_false(kvt[0], (str, Enum)):
                return True
            return is_incompatible(kvt[1])
        if is_tuple_annotation(type_):
            return any(
                arg is not ... and is_incompatible(arg) for arg in type_.__args__
            )
        if get_origin(type_) is Callable:
            args = get_args(type_)
            # Callable[..., R] exposes its parameters as a bare Ellipsis rather
            # than a list (iterating it would raise TypeError); only R is
            # checked in that case.
            params = () if args[0] is ... else args[0]
            if any(is_incompatible(arg) for arg in params if arg is not ...):
                return True
            return is_incompatible(args[1])
    except ValidationError:
        return True

    # Non-class annotations (e.g. Set[int]) no longer crash issubclass(); they
    # continue to the structured-config check and otherwise end up incompatible.
    if type_ is Any or _issubclass_or_false(type_, (int, float, str, bool, Enum)):
        return False
    if is_structured_config(type_):
        try:
            OmegaConf.structured(type_)  # verify it's actually legal
        except ValidationError as e:
            log.debug(
                f"Failed to create DictConfig from ({type_.__name__}) : {e}, flagging as incompatible"
            )
            return True
        return False
    return True
def check_node_metadata(
    node: Container,
    type_hint: Any,
    key_type: Any,
    elt_type: Any,
    obj_type: Any,
) -> None:
    """Assert that ``node``'s metadata matches what ``type_hint`` implies.

    The ``optional`` and ``ref_type`` metadata fields must equal the pair
    returned by ``_resolve_optional(type_hint)``. The key, element and object
    types must equal the explicitly supplied expectations. Dict annotations and
    structured configs must produce a ``DictConfig``; list annotations must
    produce a ``ListConfig``.
    """
    is_opt, ref_type = _resolve_optional(type_hint)
    meta = node._metadata
    expectations = (
        ("optional", is_opt),
        ("ref_type", ref_type),
        ("key_type", key_type),
        ("element_type", elt_type),
        ("object_type", obj_type),
    )
    for field_name, expected in expectations:
        assert getattr(meta, field_name) == expected, field_name

    container_cls = None
    if is_dict_annotation(ref_type) or is_structured_config(ref_type):
        container_cls = DictConfig
    elif is_list_annotation(ref_type):
        container_cls = ListConfig
    if container_cls is not None:
        assert isinstance(node, container_cls)
def test_resolve_optional_support_pep_604() -> None:
    """``_resolve_optional`` understands PEP 604 ``X | Y`` unions (Python 3.10+).

    Each case pairs an annotation with its expected ``(is_optional, type)``
    result. ``None``/``NoneType`` members make the annotation optional and are
    dropped from the resolved union, and nested unions are flattened.
    """
    if sys.version_info >= (3, 10):  # this if-statement is for mypy's benefit
        cases = [
            (int | str, (False, Union[int, str])),
            (Optional[int | str], (True, Union[int, str])),
            (int | Optional[str], (True, Union[int, str])),
            (int | Union[str, float], (False, Union[int, str, float])),
            (int | Union[str, Optional[float]], (True, Union[int, str, float])),
            (int | str | None, (True, Union[int, str])),
            (int | str | NoneType, (True, Union[int, str])),
        ]
        for annotation, expected in cases:
            assert _resolve_optional(annotation) == expected
def type_str(t: Any) -> str:
    """Render a type annotation as a short, human-readable string.

    Examples of the output format: ``int``, ``List[int]``,
    ``Optional[Dict[str, int]]``. If ``_resolve_optional`` reports the
    annotation as optional, the rendered type is wrapped in ``Optional[...]``.
    Subscripted generics are rendered recursively from their ``__args__``.

    Annotations that expose none of ``__name__``, ``_name`` or ``__origin__``
    fall back to their ``str()`` form with any ``typing.`` prefix removed.
    """
    is_optional, t = _resolve_optional(t)
    if t is None:
        return type(t).__name__
    if t is Any:
        return "Any"
    if t is ...:
        return "..."

    if hasattr(t, "__name__"):
        name = str(t.__name__)
    elif sys.version_info >= (3, 7, 0) and getattr(t, "_name", None) is not None:
        # On Python >= 3.7, typing aliases without __name__ expose it as _name.
        name = str(t._name)
    elif getattr(t, "__origin__", None) is not None:
        name = type_str(t.__origin__)
    else:
        # Nothing better is available: use the repr minus the "typing." prefix.
        # (The former 3.7+ path left `name` unbound here.)
        name = str(t)
        if name.startswith("typing."):
            name = name[len("typing.") :]

    type_args = getattr(t, "__args__", None)
    if type_args is not None:
        rendered_args = ", ".join([type_str(arg) for arg in type_args])
        ret = f"{name}[{rendered_args}]"
    else:
        ret = name
    if is_optional:
        return f"Optional[{ret}]"
    return ret
def test_resolve_optional(type_: Any, expected_optional: bool, expected_type: Any) -> None:
    """Check that ``_resolve_optional`` splits ``type_`` into the expected pair.

    The pair is ``(is_optional, underlying_type)``. The optional flag and the
    underlying type are asserted separately, in that order.
    """
    result = _resolve_optional(type_)
    is_opt, inner = result
    assert is_opt == expected_optional
    assert inner == expected_type