def infer_type(name, item):
    """Infer the TorchScript type of the attribute ``name`` whose value is ``item``.

    Resolution order:
      1. An entry for ``name`` in ``class_annotations`` (a mapping from the
         enclosing scope, not shown here), unless that entry equals the
         annotation ``torch.nn.Module`` declares for ``forward``.
      2. The declared ``.type`` of a ``torch.jit.Attribute`` wrapper.
      3. JIT inference from the runtime value via ``_jit_try_infer_type``.

    Args:
        name: Attribute name, used to look up an explicit class annotation.
        item: The attribute's current value.

    Returns:
        A ``(attr_type, inferred)`` tuple. ``attr_type`` is a
        ``torch._C.InferredType`` in the annotated branches, or whatever
        ``_jit_try_infer_type`` returns (presumably also an ``InferredType``).
        ``inferred`` is True only when the type came from step 3.

    Raises:
        RuntimeError: if any lookup/conversion above raises ``RuntimeError``;
            the original error is chained as ``__cause__``.
    """
    # The forward function from Module is special; never use its annotation.
    # Such attributes fall through to the later branches so the JIT handles
    # them directly. I originally wanted to write this test as
    # isinstance(class_annotations[name], Callable), but isinstance on typing
    # constructs doesn't seem to work: isinstance(list, Callable) is also true!
    inferred = False
    try:
        if name in class_annotations and class_annotations[
                name] != torch.nn.Module.__annotations__["forward"]:
            resolved = torch.jit.annotations.ann_to_type(
                class_annotations[name], _jit_internal.fake_range())
            attr_type = torch._C.InferredType(resolved)
        elif isinstance(item, torch.jit.Attribute):
            resolved = torch.jit.annotations.ann_to_type(
                item.type, _jit_internal.fake_range())
            attr_type = torch._C.InferredType(resolved)
        else:
            attr_type = torch._C._jit_try_infer_type(item)
            inferred = True
    except RuntimeError as err:
        # Bound as ``err`` (not ``re``) so the stdlib ``re`` module name is not
        # shadowed; ``from err`` keeps the root cause explicit in tracebacks.
        raise RuntimeError(
            "Error inferring type for {name}: {item}: {re}".format(
                name=name, item=item, re=err)) from err
    return attr_type, inferred
def infer_type(name, item):
    """Return the TorchScript type of attribute ``name`` holding ``item``.

    Precedence: an entry in ``class_annotations`` (a mapping from the
    enclosing scope, not shown here), then the declared ``.type`` of a
    ``torch.jit.Attribute`` wrapper, and finally JIT inference on the value.
    """
    if name in class_annotations:
        annotation = class_annotations[name]
    elif isinstance(item, torch.jit.Attribute):
        annotation = item.type
    else:
        # No declared type anywhere: let the JIT infer one from the value.
        return torch._C._jit_try_infer_type(item)
    return torch.jit.annotations.ann_to_type(annotation, _jit_internal.fake_range())
def infer_type(name, item):
    """Return the TorchScript type of attribute ``name`` holding ``item``.

    An entry in ``class_annotations`` (a mapping from the enclosing scope,
    not shown here) wins, except when that entry equals the annotation
    ``torch.nn.Module`` declares for ``forward``. Next comes the declared
    ``.type`` of a ``torch.jit.Attribute`` wrapper; otherwise the JIT infers
    the type from the value itself.
    """
    # Module's own ``forward`` annotation must never be used; such attributes
    # fall through to the later checks. An isinstance(annotation, Callable)
    # test would not work here because isinstance on typing constructs is
    # unreliable: isinstance(list, Callable) is also true.
    if name in class_annotations:
        declared = class_annotations[name]
        if declared != torch.nn.Module.__annotations__["forward"]:
            return torch.jit.annotations.ann_to_type(
                declared, _jit_internal.fake_range())
    if isinstance(item, torch.jit.Attribute):
        return torch.jit.annotations.ann_to_type(
            item.type, _jit_internal.fake_range())
    return torch._C._jit_try_infer_type(item)
def _check_no_signature(func):
    """Raise unless a type signature can be recovered for ``func``.

    The signature is looked up with ``torch.jit.annotations.get_signature``.
    A ``None`` result is treated as missing type annotations, which is
    reported as an error for overloaded functions.

    Raises:
        RuntimeError: naming ``func`` by its qualified name when no
            signature is found.
    """
    loc = _jit_internal.fake_range()
    is_method = inspect.ismethod(func)
    if torch.jit.annotations.get_signature(func, None, loc, is_method) is not None:
        return
    raise RuntimeError(
        "Must explicitly add type annotations to overloaded functions: {}".format(
            _jit_internal._qualified_name(func)))
def _get_named_tuple_properties(obj):
    """Describe a namedtuple class for the JIT.

    Args:
        obj: A namedtuple *class* (a ``tuple`` subclass with ``_fields``),
            not an instance.

    Returns:
        ``(name, fields, annotations)``: the class's own name, the list of
        field names, and one JIT type per field. Annotated fields are
        converted with ``torch.jit.annotations.ann_to_type``; unannotated
        fields default to ``torch._C.TensorType.get()``.
    """
    assert issubclass(obj, tuple) and hasattr(obj, '_fields')
    fields = list(obj._fields)
    # Fall back to an empty mapping when the class has no ``__annotations__``.
    field_annotations = getattr(obj, '__annotations__', {})
    annotations = []
    for field in fields:
        if field in field_annotations:
            annotations.append(torch.jit.annotations.ann_to_type(
                field_annotations[field], _jit_internal.fake_range()))
        else:
            annotations.append(torch._C.TensorType.get())
    # ``obj`` is the class itself, so ``type(obj)`` would be its metaclass
    # (``type``); the namedtuple's own name is ``obj.__name__``.
    return obj.__name__, fields, annotations