def infer_type(name, item):
    """Determine the TorchScript type of attribute ``name`` whose value is ``item``.

    Resolution order:
      1. An entry for ``name`` in ``class_annotations``, unless that entry equals
         ``torch.nn.Module.__annotations__["forward"]``.
      2. ``item.type`` when ``item`` is a ``torch.jit.Attribute``.
      3. Otherwise, inference from the value via ``torch._C._jit_try_infer_type``.

    ``class_annotations`` and ``fake_range`` are free names resolved from the
    enclosing scope (not visible here). NOTE(review): ``class_annotations`` is
    presumably the annotation dict of the class being scripted — confirm.

    Returns:
        ``(attr_type, inferred)``. ``inferred`` is True only when the type came
        from value inference (step 3), False when it came from an annotation.

    Raises:
        RuntimeError: if any of the steps above raises ``RuntimeError``. The new
        error names the attribute and value and is chained to the original.
    """
    # The `forward` function from Module is special: never use its annotation,
    # because its type must be inferred directly by the JIT. This was originally
    # meant to be `isinstance(class_annotations[name], Callable)`, but isinstance
    # does not work reliably on typing constructs: isinstance(list, Callable) is
    # also true!
    inferred = False
    try:
        if (
            name in class_annotations
            and class_annotations[name] != torch.nn.Module.__annotations__["forward"]
        ):
            ann_type = torch.jit.annotations.ann_to_type(
                class_annotations[name], fake_range()
            )
            attr_type = torch._C.InferredType(ann_type)
        elif isinstance(item, torch.jit.Attribute):
            ann_type = torch.jit.annotations.ann_to_type(item.type, fake_range())
            attr_type = torch._C.InferredType(ann_type)
        else:
            attr_type = torch._C._jit_try_infer_type(item)
            inferred = True
    except RuntimeError as err:
        # Chain explicitly so the original JIT error is reported as the cause.
        raise RuntimeError(
            "Error inferring type for {name}: {item}: {re}".format(
                name=name, item=item, re=err
            )
        ) from err
    return attr_type, inferred
def _get_named_tuple_properties(obj):
    """Extract the metadata TorchScript needs from a NamedTuple class.

    Args:
        obj: a class that must subclass ``tuple`` and expose ``_fields``
            (checked with ``assert``).

    Returns:
        A 4-tuple ``(name, fields, annotations, defaults)``:

        * ``name`` is ``type(obj).__name__``, the name of obj's metaclass.
          NOTE(review): for ordinary NamedTuples this is e.g. ``"type"``, not
          the tuple's own name. Confirm that callers do not rely on it.
        * ``fields`` is ``obj._fields`` unchanged.
        * ``annotations`` holds one JIT type per field. Fields without an
          annotation get ``torch._C.TensorType.getInferred()``.
        * ``defaults`` holds the values from ``obj._field_defaults``, in field
          order, for only the fields that have a default.
    """
    assert issubclass(obj, tuple) and hasattr(obj, "_fields")

    defaults = []
    if hasattr(obj, "_field_defaults"):
        defaults = [
            obj._field_defaults[fname]
            for fname in obj._fields
            if fname in obj._field_defaults
        ]

    # From 3.10 onward, inspect.get_annotations is the recommended accessor. It
    # does not return inherited annotations, so an empty result falls back to
    # the direct base class (e.g. `class Sub(SomeNamedTuple): pass`).
    if sys.version_info >= (3, 10):
        ann_map = inspect.get_annotations(obj)
        if not ann_map and hasattr(obj, "__base__"):
            ann_map = inspect.get_annotations(obj.__base__)
    else:
        ann_map = getattr(obj, "__annotations__", {})

    field_types = [
        torch.jit.annotations.ann_to_type(ann_map[fname], fake_range())
        if fname in ann_map
        else torch._C.TensorType.getInferred()
        for fname in obj._fields
    ]
    return type(obj).__name__, obj._fields, field_types, defaults
def _check_no_signature(func):
    """Require that ``func`` has a signature visible to TorchScript.

    The function asks ``torch.jit.annotations.get_signature`` for ``func``'s
    signature, passing whether ``func`` is a bound method. If the result is
    ``None`` (presumably meaning ``func`` carries no usable type annotations),
    it raises.

    Raises:
        RuntimeError: naming ``func`` by its qualified name and stating that
            overloaded functions need explicit type annotations.
    """
    sig = torch.jit.annotations.get_signature(
        func, None, fake_range(), inspect.ismethod(func)
    )
    if sig is not None:
        return
    qualified = _jit_internal._qualified_name(func)
    raise RuntimeError(
        "Must explicitly add type annotations to overloaded functions: {}".format(
            qualified
        )
    )
def _get_named_tuple_properties(obj):
    """Extract the metadata TorchScript needs from a NamedTuple class.

    Args:
        obj: a class that must subclass ``tuple`` and expose ``_fields``
            (checked with ``assert``).

    Returns:
        A 4-tuple ``(name, fields, annotations, defaults)``:

        * ``name`` is ``type(obj).__name__``, the name of obj's metaclass.
          NOTE(review): for ordinary NamedTuples this is e.g. ``"type"``, not
          the tuple's own name. It is kept unchanged for compatibility.
        * ``fields`` is ``obj._fields`` unchanged.
        * ``annotations`` holds one JIT type per field. Fields without an
          annotation get ``torch._C.TensorType.getInferred()``.
        * ``defaults`` holds the values from ``obj._field_defaults``, in field
          order, for only the fields that have a default.
    """
    import inspect
    import sys

    assert issubclass(obj, tuple) and hasattr(obj, '_fields')
    if hasattr(obj, "_field_defaults"):
        defaults = [obj._field_defaults[field] for field in obj._fields
                    if field in obj._field_defaults]
    else:
        defaults = []

    # On Python >= 3.10, `cls.__annotations__` no longer falls through to the
    # base class, so a subclass of an annotated NamedTuple would lose all its
    # field types. inspect.get_annotations is the recommended accessor there.
    # Because it is not inherited either, an empty result falls back to the
    # direct base class.
    if sys.version_info[:2] < (3, 10):
        obj_annotations = getattr(obj, '__annotations__', {})
    else:
        obj_annotations = inspect.get_annotations(obj)
        if not obj_annotations and hasattr(obj, '__base__'):
            obj_annotations = inspect.get_annotations(obj.__base__)

    annotations = []
    for field in obj._fields:
        if field in obj_annotations:
            the_type = torch.jit.annotations.ann_to_type(obj_annotations[field], fake_range())
            annotations.append(the_type)
        else:
            annotations.append(torch._C.TensorType.getInferred())
    return type(obj).__name__, obj._fields, annotations, defaults