コード例 #1
0
    def infer_type(name, item):
        """Determine the JIT type for attribute ``name`` whose value is ``item``.

        Sources are tried in this order:

        1. The class-level annotation for ``name``, unless it equals the
           annotation ``torch.nn.Module`` declares for ``forward``.
        2. The ``type`` carried by ``item`` when it is a ``torch.jit.Attribute``.
        3. ``torch._C._jit_try_infer_type(item)``, i.e. inference from the value.

        Returns:
            A ``(attr_type, inferred)`` tuple. ``inferred`` is True only when
            the result came from step 3; for steps 1 and 2 ``attr_type`` is a
            ``torch._C.InferredType`` wrapping the converted annotation.

        Raises:
            RuntimeError: if any of the steps raises ``RuntimeError``. The new
                error names the attribute and value and is explicitly chained
                to the original one.
        """
        # The forward function from Module is special; never use its
        # annotation; we need to infer the type directly using JIT.  Testing
        # isinstance(class_annotations[name], Callable) was considered, but
        # isinstance on typing things doesn't seem to work:
        # isinstance(list, Callable) is also true!
        inferred = False
        try:
            if name in class_annotations and class_annotations[
                    name] != torch.nn.Module.__annotations__["forward"]:
                ann_to_type = torch.jit.annotations.ann_to_type(
                    class_annotations[name], _jit_internal.fake_range())
                attr_type = torch._C.InferredType(ann_to_type)
            elif isinstance(item, torch.jit.Attribute):
                ann_to_type = torch.jit.annotations.ann_to_type(
                    item.type, _jit_internal.fake_range())
                attr_type = torch._C.InferredType(ann_to_type)
            else:
                attr_type = torch._C._jit_try_infer_type(item)
                inferred = True
        except RuntimeError as err:
            # Bind as ``err`` (not ``re``, which shadows the stdlib module
            # name) and chain explicitly so the traceback marks the original
            # failure as the direct cause.
            raise RuntimeError(
                "Error inferring type for {name}: {item}: {re}".format(
                    name=name, item=item, re=err)) from err

        return attr_type, inferred
コード例 #2
0
 def infer_type(name, item):
     """Pick the JIT type for attribute ``name`` whose value is ``item``.

     A class-level annotation for ``name`` takes priority, then the ``type``
     carried by a ``torch.jit.Attribute`` wrapper; when neither exists the
     type is inferred from the value itself.
     """
     if name in class_annotations:
         declared = class_annotations[name]
     elif isinstance(item, torch.jit.Attribute):
         declared = item.type
     else:
         # Nothing was declared for this attribute: infer from the value.
         return torch._C._jit_try_infer_type(item)

     # Both declared cases go through the same annotation-to-type conversion.
     return torch.jit.annotations.ann_to_type(declared, _jit_internal.fake_range())
コード例 #3
0
 def infer_type(name, item):
     """Pick the JIT type for attribute ``name`` whose value is ``item``.

     A class-level annotation is used unless it equals the annotation that
     ``torch.nn.Module`` declares for ``forward``; next comes the ``type`` of
     a ``torch.jit.Attribute`` wrapper; otherwise the type is inferred from
     the value.
     """
     # Module's own ``forward`` annotation is special and must never be used;
     # such attributes fall through to the remaining strategies.  (An
     # isinstance(..., Callable) test was considered, but isinstance on typing
     # things doesn't seem to work: isinstance(list, Callable) is also true.)
     if name in class_annotations:
         declared = class_annotations[name]
         if declared != torch.nn.Module.__annotations__["forward"]:
             return torch.jit.annotations.ann_to_type(declared, _jit_internal.fake_range())

     if isinstance(item, torch.jit.Attribute):
         return torch.jit.annotations.ann_to_type(item.type, _jit_internal.fake_range())

     return torch._C._jit_try_infer_type(item)
コード例 #4
0
def _check_no_signature(func):
    """Raise if no signature can be obtained for ``func``.

    Asks ``torch.jit.annotations.get_signature`` for the signature of
    ``func`` (passing whether it is a bound method) and does nothing when one
    is returned.

    Raises:
        RuntimeError: when ``get_signature`` returns ``None``; the message
            includes the qualified name of ``func``.
    """
    loc = _jit_internal.fake_range()
    is_method = inspect.ismethod(func)
    found = torch.jit.annotations.get_signature(func, None, loc, is_method)
    if found is not None:
        return

    message = "Must explicitly add type annotations to overloaded functions: {}"
    raise RuntimeError(message.format(_jit_internal._qualified_name(func)))
コード例 #5
0
def _get_named_tuple_properties(obj):
    """Describe a named-tuple class as ``(name, fields, annotations)``.

    Args:
        obj: a class that subclasses ``tuple`` and defines ``_fields``
            (asserted below).

    Returns:
        A 3-tuple of:
          * the class's ``__name__``,
          * the field names as a list, in ``_fields`` order,
          * one JIT type per field: the converted annotation when the class
            annotates that field, otherwise ``torch._C.TensorType.get()``.
    """
    assert issubclass(obj, tuple) and hasattr(obj, '_fields')
    fields = list(obj._fields)
    # Classes without annotations are treated as having none declared.
    declared = getattr(obj, '__annotations__', {})
    annotations = [
        torch.jit.annotations.ann_to_type(declared[field], _jit_internal.fake_range())
        if field in declared
        else torch._C.TensorType.get()  # unannotated fields default to Tensor
        for field in fields
    ]
    # ``obj`` is the class itself, so ``type(obj)`` would be its metaclass and
    # ``type(obj).__name__`` would yield "type" rather than the tuple's name.
    return obj.__name__, fields, annotations