def try_ann_to_type(ann, loc):
    """Translate a Python type annotation into a TorchScript type object.

    The annotation is tested against a fixed sequence of cases and the first
    match wins: ``None`` / ``torch.Tensor`` subclasses, the ``typing``
    containers (tuple, list, dict, optional, rref, future), the primitive
    Python types, ``torch.device`` / ``torch.dtype``, interface classes,
    enums and finally arbitrary classes.  Container element annotations are
    resolved by recursive calls.

    Args:
        ann: The annotation to translate.
        loc: Location object.  It is forwarded unchanged to the recursive
            calls, to ``_recursive_compile_class`` and to
            ``torch._C._resolve_type_from_object``; its ``highlight()``
            output is embedded in the Dict error message.

    Returns:
        The TorchScript type for ``ann``.  When no case matches, whatever
        ``torch._C._resolve_type_from_object`` returns is passed back as-is;
        the recursive call sites in this function treat a falsy result as
        "could not be resolved".

    Raises:
        ValueError: A ``Dict`` key or value annotation could not be resolved.
        AssertionError: The type contained in an ``Optional`` could not be
            resolved.
    """
    if ann is None:
        return TensorType.get()
    if inspect.isclass(ann) and issubclass(ann, torch.Tensor):
        return TensorType.get()
    if is_tuple(ann):
        return TupleType([try_ann_to_type(a, loc) for a in ann.__args__])
    if is_list(ann):
        elem_type = try_ann_to_type(ann.__args__[0], loc)
        if elem_type:
            return ListType(elem_type)
        # An unresolved element type deliberately falls through to the
        # remaining checks and, ultimately, the generic resolver below.
    if is_dict(ann):
        key = try_ann_to_type(ann.__args__[0], loc)
        value = try_ann_to_type(ann.__args__[1], loc)
        # Fail with a readable error instead of handing an unresolved
        # (None) key/value type to the DictType constructor.
        for resolved, arg in ((key, ann.__args__[0]), (value, ann.__args__[1])):
            if resolved is None:
                raise ValueError(
                    "Unknown type annotation: '{}' at {}".format(
                        arg, loc.highlight()
                    )
                )
        return DictType(key, value)
    if is_optional(ann):
        # Identity test rather than issubclass(): the other Union member may
        # be a typing construct such as List[int], which is not a class and
        # would make issubclass() raise TypeError.  NoneType cannot be
        # subclassed, so for real classes the two tests are equivalent.
        if ann.__args__[1] is type(None):
            contained = ann.__args__[0]
        else:
            contained = ann.__args__[1]
        valid_type = try_ann_to_type(contained, loc)
        msg = "Unsupported annotation {} could not be resolved because {} could not be resolved."
        # Explicit raise (not ``assert``) so the check survives ``python -O``.
        if not valid_type:
            raise AssertionError(msg.format(repr(ann), repr(contained)))
        return OptionalType(valid_type)
    if torch.distributed.rpc.is_available() and is_rref(ann):
        return RRefType(try_ann_to_type(ann.__args__[0], loc))
    if is_future(ann):
        return FutureType(try_ann_to_type(ann.__args__[0], loc))
    if ann is float:
        return FloatType.get()
    if ann is int:
        return IntType.get()
    if ann is str:
        return StringType.get()
    if ann is bool:
        return BoolType.get()
    if ann is Any:
        return AnyType.get()
    if ann is type(None):
        return NoneType.get()
    if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(_qualified_name(ann))
    if ann is torch.device:
        return DeviceObjType.get()
    if ann is torch.dtype:
        return IntType.get()  # dtype not yet bound in as its own type
    if inspect.isclass(ann) and issubclass(ann, enum.Enum):
        # Compile the enum on first use; afterwards it is referred to by its
        # qualified name.
        if not hasattr(ann, "__torch_script_class__"):
            torch.jit._script._recursive_compile_class(ann, loc)
        return EnumType(_qualified_name(ann), get_enum_value_type(ann, loc), list(ann))
    if inspect.isclass(ann):
        if hasattr(ann, "__torch_script_class__"):
            return ClassType(_qualified_name(ann))
        # These base classes are never compiled as script classes here; such
        # annotations fall through to the generic resolver below.
        ignored_builtin_classes = (torch.nn.Module, tuple, list, Exception)
        if torch._jit_internal.can_compile_class(ann) and not issubclass(
            ann, ignored_builtin_classes
        ):
            torch.jit._script._recursive_compile_class(ann, loc)
            return ClassType(_qualified_name(ann))

    # Maybe resolve a NamedTuple to a Tuple Type
    def fake_rcb(key):
        return None

    return torch._C._resolve_type_from_object(ann, loc, fake_rcb)
def try_ann_to_type(ann, loc):
    """Translate a Python type annotation into a TorchScript type object.

    The annotation is tested against a fixed sequence of cases and the first
    match wins: a missing annotation (``inspect.Signature.empty``), ``None``,
    tensor classes, the ``typing`` containers (tuple, list, dict, optional,
    rref, future), the primitive Python types, ``torch.device`` /
    ``torch.Stream`` / ``torch.dtype``, interface classes, enums and finally
    arbitrary classes.  Container element annotations are resolved by
    recursive calls.

    Args:
        ann: The annotation to translate.
        loc: Location object.  It is forwarded unchanged to the recursive
            calls, to ``_recursive_compile_class`` and to
            ``torch._C._resolve_type_from_object``; its ``highlight()``
            output is embedded in the Dict error messages.

    Returns:
        The TorchScript type for ``ann``.  When no case matches, whatever
        ``torch._C._resolve_type_from_object`` returns is passed back as-is;
        the recursive call sites in this function treat a falsy/``None``
        result as "could not be resolved".

    Raises:
        ValueError: A ``Dict`` key or value annotation could not be resolved.
        AssertionError: The type contained in an ``Optional`` could not be
            resolved.
    """
    if ann is inspect.Signature.empty:
        # No annotation at all: use the inferred tensor type.
        return TensorType.getInferred()
    if ann is None:
        return NoneType.get()
    if inspect.isclass(ann) and is_tensor(ann):
        return TensorType.get()
    if is_tuple(ann):
        return TupleType([try_ann_to_type(a, loc) for a in ann.__args__])
    if is_list(ann):
        elem_type = try_ann_to_type(ann.__args__[0], loc)
        if elem_type:
            return ListType(elem_type)
        # An unresolved element type deliberately falls through to the
        # remaining checks and, ultimately, the generic resolver below.
    if is_dict(ann):
        key = try_ann_to_type(ann.__args__[0], loc)
        value = try_ann_to_type(ann.__args__[1], loc)
        # Raise error if key or value is None (key is checked first).
        for resolved, arg in ((key, ann.__args__[0]), (value, ann.__args__[1])):
            if resolved is None:
                raise ValueError(
                    f"Unknown type annotation: '{arg}' at {loc.highlight()}"
                )
        return DictType(key, value)
    if is_optional(ann):
        # Identity test rather than issubclass(): the other Union member may
        # be a typing construct such as List[int], which is not a class and
        # would make issubclass() raise TypeError.  NoneType cannot be
        # subclassed, so for real classes the two tests are equivalent.
        if ann.__args__[1] is type(None):
            contained = ann.__args__[0]
        else:
            contained = ann.__args__[1]
        valid_type = try_ann_to_type(contained, loc)
        msg = "Unsupported annotation {} could not be resolved because {} could not be resolved."
        # Explicit raise (not ``assert``) so the check survives ``python -O``.
        if not valid_type:
            raise AssertionError(msg.format(repr(ann), repr(contained)))
        return OptionalType(valid_type)
    if torch.distributed.rpc.is_available() and is_rref(ann):
        return RRefType(try_ann_to_type(ann.__args__[0], loc))
    if is_future(ann):
        return FutureType(try_ann_to_type(ann.__args__[0], loc))
    if ann is float:
        return FloatType.get()
    if ann is complex:
        return ComplexType.get()
    if ann is int:
        return IntType.get()
    if ann is str:
        return StringType.get()
    if ann is bool:
        return BoolType.get()
    if ann is Any:
        return AnyType.get()
    if ann is type(None):
        return NoneType.get()
    if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(ann.__torch_script_interface__)
    if ann is torch.device:
        return DeviceObjType.get()
    if ann is torch.Stream:
        return StreamObjType.get()
    if ann is torch.dtype:
        return IntType.get()  # dtype not yet bound in as its own type
    if inspect.isclass(ann) and issubclass(ann, enum.Enum):
        if _get_script_class(ann) is None:
            # Not compiled yet: compile now and take the name from the
            # freshly compiled class.
            scripted_class = torch.jit._script._recursive_compile_class(ann, loc)
            name = scripted_class.qualified_name()
        else:
            name = _qualified_name(ann)
        return EnumType(name, get_enum_value_type(ann, loc), list(ann))
    if inspect.isclass(ann):
        # Reuse an already compiled script class when one is registered.
        maybe_script_class = _get_script_class(ann)
        if maybe_script_class is not None:
            return maybe_script_class
        if torch._jit_internal.can_compile_class(ann):
            return torch.jit._script._recursive_compile_class(ann, loc)

    # Maybe resolve a NamedTuple to a Tuple Type
    def fake_rcb(key):
        return None

    return torch._C._resolve_type_from_object(ann, loc, fake_rcb)