Exemplo n.º 1
0
    def infer_type(name, item):
        """Infer the TorchScript type of attribute ``name`` whose value is ``item``.

        Resolution order:
          1. The entry for ``name`` in ``class_annotations`` (a mapping taken
             from the enclosing scope, not shown here), unless that entry is the
             annotation ``torch.nn.Module`` declares for ``forward``.
          2. If ``item`` is a ``torch.jit.Attribute``, its ``type``.
          3. Otherwise, ask the JIT to infer a type from the value itself.

        Returns:
            A ``(attr_type, inferred)`` pair. In cases 1 and 2 ``attr_type`` is a
            ``torch._C.InferredType`` wrapping the converted annotation; in case 3
            it is whatever ``torch._C._jit_try_infer_type`` returns.
            ``inferred`` is True only when case 3 was taken.

        Raises:
            RuntimeError: if any of the type-resolution calls raises
                RuntimeError. The message names the attribute and its value, and
                the original error is chained as ``__cause__``.
        """
        # The forward function from Module is special: never use its annotation;
        # we need to infer its type directly using the JIT. Writing this test as
        # isinstance(class_annotations[name], Callable) would be more direct, but
        # isinstance on typing constructs doesn't work reliably:
        # isinstance(list, Callable) is also true!
        inferred = False
        try:
            if name in class_annotations and class_annotations[
                    name] != torch.nn.Module.__annotations__["forward"]:
                ann_to_type = torch.jit.annotations.ann_to_type(
                    class_annotations[name], fake_range())
                attr_type = torch._C.InferredType(ann_to_type)
            elif isinstance(item, torch.jit.Attribute):
                ann_to_type = torch.jit.annotations.ann_to_type(
                    item.type, fake_range())
                attr_type = torch._C.InferredType(ann_to_type)
            else:
                attr_type = torch._C._jit_try_infer_type(item)
                inferred = True
        except RuntimeError as err:
            # Chain the original error so its traceback survives for debugging.
            # The local is not named ``re`` to avoid shadowing the regex module.
            raise RuntimeError(
                "Error inferring type for {name}: {item}: {re}".format(
                    name=name, item=item, re=err)) from err

        return attr_type, inferred
Exemplo n.º 2
0
def _get_named_tuple_properties(obj):
    assert issubclass(obj, tuple) and hasattr(obj, "_fields")
    if hasattr(obj, "_field_defaults"):
        defaults = [
            obj._field_defaults[field]
            for field in obj._fields
            if field in obj._field_defaults
        ]
    else:
        defaults = []
    # In 3.10 recommended way to get annotations is to call `inspect.get_annotations` function
    # Also, annotations from base class are not inherited so they need to be queried explicitly
    if sys.version_info[:2] < (3, 10):
        obj_annotations = getattr(obj, "__annotations__", {})
    else:
        obj_annotations = inspect.get_annotations(obj)
        if len(obj_annotations) == 0 and hasattr(obj, "__base__"):
            obj_annotations = inspect.get_annotations(obj.__base__)

    annotations = []
    for field in obj._fields:
        if field in obj_annotations:
            the_type = torch.jit.annotations.ann_to_type(
                obj_annotations[field], fake_range()
            )
            annotations.append(the_type)
        else:
            annotations.append(torch._C.TensorType.getInferred())
    return type(obj).__name__, obj._fields, annotations, defaults
Exemplo n.º 3
0
def _check_no_signature(func):
    """Raise if TorchScript cannot find a type signature for ``func``.

    Looks up ``func``'s signature via ``torch.jit.annotations.get_signature``.
    The error message targets overloaded functions, which must carry explicit
    type annotations.

    Raises:
        RuntimeError: when ``get_signature`` returns None. The message includes
            ``func``'s qualified name.
    """
    loc = fake_range()
    is_method = inspect.ismethod(func)
    found = torch.jit.annotations.get_signature(func, None, loc, is_method)
    if found is not None:
        return

    raise RuntimeError(
        "Must explicitly add type annotations to overloaded functions: {}".format(
            _jit_internal._qualified_name(func)))
Exemplo n.º 4
0
def _get_named_tuple_properties(obj):
    assert issubclass(obj, tuple) and hasattr(obj, '_fields')
    if hasattr(obj, "_field_defaults"):
        defaults = [obj._field_defaults[field]
                    for field in obj._fields
                    if field in obj._field_defaults]
    else:
        defaults = []
    annotations = []
    has_annotations = hasattr(obj, '__annotations__')
    for field in obj._fields:
        if has_annotations and field in obj.__annotations__:
            the_type = torch.jit.annotations.ann_to_type(obj.__annotations__[field], fake_range())
            annotations.append(the_type)
        else:
            annotations.append(torch._C.TensorType.getInferred())
    return type(obj).__name__, obj._fields, annotations, defaults