Exemple #1
0
def get_jit_def(fn, def_name, self_name=None):
    """
    Parse the source of `fn` and convert it into a JIT AST (TreeView).

    Arguments:
        fn: The function object whose source is compiled.
        def_name: Name given to the produced AST object. It can differ from
            `fn.__name__` when the function is bound under another name:
                def _forward(self):
                    ...
                forward = _forward
            Here `fn.__name__` is "_forward", but the AST should be named
            "forward".
        self_name: Type name of `self` when `fn` is a method.

    Raises:
        RuntimeError: if the parsed source is not exactly one top-level
            function definition.
    """
    lines, lineno, path = get_source_lines_and_file(
        fn, torch._C.ErrorReport.call_stack())
    raw_src = ''.join(lines)
    flat_src = dedent(raw_src)
    top_level = ast.parse(flat_src).body

    is_single_def = (len(top_level) == 1
                     and isinstance(top_level[0], ast.FunctionDef))
    if not is_single_def:
        raise RuntimeError("Expected a single top-level function")

    # Difference in first-line length before and after dedenting.
    first_raw_line = raw_src.split('\n', 1)[0]
    first_flat_line = flat_src.split('\n', 1)[0]
    indent_width = len(first_raw_line) - len(first_flat_line)

    type_comment = torch.jit.annotations.get_type_line(raw_src)
    ctx = SourceContext(raw_src, path, lineno, indent_width,
                        _uses_true_division(fn))
    return build_def(ctx, top_level[0], type_comment, def_name,
                     self_name=self_name)
Exemple #2
0
def get_jit_class_def(cls, self_name):
    """
    Build the JIT class definition for `cls`.

    Each non-static method declared directly in `cls.__dict__` is compiled
    on its own with `get_jit_def`; the results, together with the class
    properties and the first node of the parsed class source, are handed
    to `build_class_def`.

    Arguments:
        cls: The class to compile.
        self_name: Type name used for `self` in the compiled methods.
    """
    def _is_own_non_static(member):
        # Only plain functions/methods qualify.
        if not (inspect.ismethod(member) or inspect.isfunction(member)):
            return False
        member_name = member.__name__
        return (not is_static_fn(cls, member_name)
                and member_name in cls.__dict__)

    # Get defs for each method within the current class independently
    # TODO: proper overriding analysis when implementing class inheritance
    method_defs = []
    for name, member in inspect.getmembers(cls, predicate=_is_own_non_static):
        method_defs.append(get_jit_def(member, name, self_name=self_name))

    properties = get_class_properties(cls, self_name)

    lines, lineno, path = get_source_lines_and_file(
        cls, torch._C.ErrorReport.call_stack())
    raw_src = ''.join(lines)
    flat_src = dedent(raw_src)
    class_module = ast.parse(flat_src)

    # Difference in first-line length before and after dedenting.
    first_raw_line = raw_src.split('\n', 1)[0]
    first_flat_line = flat_src.split('\n', 1)[0]
    indent_width = len(first_raw_line) - len(first_flat_line)

    ctx = SourceContext(raw_src, path, lineno, indent_width, False)
    return build_class_def(ctx, class_module.body[0], method_defs, properties,
                           self_name)
Exemple #3
0
def get_num_params(fn, loc):
    """
    Count the positional parameters declared in the source of `fn`.

    Arguments:
        fn: Callable whose source is parsed.
        loc: Location passed to any `FrontendError` raised here.

    Returns:
        `len(args)` of the parsed definition (minus one when
        `inspect.ismethod(fn)` is true), or None when the source cannot be
        retrieved or the definition has `*args` or keyword-only arguments.

    Raises:
        torch.jit.frontend.FrontendError: if the source is a single class
            definition, or is not a single top-level function.
    """
    try:
        lines = get_source_lines_and_file(fn)[0]
        source = dedent(''.join(lines))
    except (TypeError, IOError):
        return None
    if source is None:
        return None

    top_level = ast.parse(source).body
    # The sole top-level node, or None when there is not exactly one.
    only_node = top_level[0] if len(top_level) == 1 else None

    if isinstance(only_node, ast.ClassDef):
        raise torch.jit.frontend.FrontendError(
            loc, "Cannot instantiate class '{}' in a script function".format(
                only_node.name))
    if not isinstance(only_node, ast.FunctionDef):
        raise torch.jit.frontend.FrontendError(
            loc, "Expected a single top-level function")

    arguments = only_node.args
    if arguments.vararg is not None:
        return None
    if hasattr(arguments, 'kwonlyargs') and len(arguments.kwonlyargs) > 0:
        return None

    count = len(arguments.args)
    # Bound methods do not count their implicit first argument.
    return count - 1 if inspect.ismethod(fn) else count
Exemple #4
0
def get_signature(fn, rcb, loc, is_method):
    """
    Resolve the signature of `fn` from annotations or a type comment.

    Real annotations are tried first (when `PY35` is set); if they yield
    nothing, the type line found in the function source is parsed instead.

    Arguments:
        fn: Function whose signature is wanted.
        rcb: Passed through to `parse_type_line`.
        loc: Passed through to `try_real_annotations` / `parse_type_line`.
        is_method: Whether `fn` is a method; if so the leading parameter
            type from real annotations is dropped.

    Returns:
        The signature, or None when neither source provides one.
    """
    # Python 3.5 adds support for the nice annotation syntax, so try that first.
    signature = try_real_annotations(fn, loc) if PY35 else None

    if signature is not None and is_method:
        # The annotation-based signature includes a type for `self`, while
        # type comments do not. Strip it so both paths agree
        # (`inspect.ismethod` does not work here since `fn` is unbound at
        # this point).
        arg_types, ret_type = signature
        signature = (arg_types[1:], ret_type)

    if signature is not None:
        return signature

    type_line = None
    try:
        fn_source = dedent(''.join(get_source_lines_and_file(fn)[0]))
        type_line = get_type_line(fn_source)
    except TypeError:
        pass

    # A missing type line means either the source of fn could not be
    # retrieved or it carried no annotations.
    if type_line is None:
        return None
    return parse_type_line(type_line, rcb, loc)
    def persistent_id(obj):
        """
        Return the persistent-id tuple for `obj`, or None.

        `nn.Module` subclasses yield a ('module', ...) tuple the first time
        they are seen (None afterwards); storages yield a ('storage', ...)
        tuple and are recorded in `serialized_storages`. Anything else
        returns None.
        """
        is_module_class = isinstance(obj, type) and issubclass(obj, nn.Module)
        if is_module_class:
            if obj in serialized_container_types:
                return None
            serialized_container_types[obj] = True
            src_path = src_text = None
            try:
                src_lines, _, src_path = get_source_lines_and_file(obj)
                src_text = ''.join(src_lines)
            except Exception:  # saving the source is optional, so we can ignore any errors
                warnings.warn("Couldn't retrieve source code for container of "
                              "type " + obj.__name__ + ". It won't be checked "
                              "for correctness upon loading.")
            return ('module', obj, src_path, src_text)

        if not torch.is_storage(obj):
            return None

        storage_type = normalize_storage_type(type(obj))
        # Offset is always 0, but we keep it for backwards compatibility
        # with the old serialization format (which supported storage views)
        offset = 0
        obj_key = str(obj._cdata)
        location = location_tag(obj)
        serialized_storages[obj_key] = obj
        # NOTE(review): this compares `obj._cdata` with itself, so the view
        # branch looks unreachable - presumably kept for format
        # compatibility; confirm before removing.
        is_view = obj._cdata != obj._cdata
        view_metadata = (str(obj._cdata), offset, obj.size()) if is_view else None

        return ('storage', storage_type, obj_key, location, obj.size(),
                view_metadata)
Exemple #6
0
def _get_overloaded_methods(method, mod_class):
    """
    Look up the overloads registered for `method` on `mod_class`.

    Arguments:
        method: The method whose overloads are requested.
        mod_class: The class the method is expected to be declared in.

    Returns:
        The overloads stored in `_overloaded_methods` for the method's
        qualified name and `mod_class.__name__`, or None when `method` has
        no `__name__` or nothing is registered for it.

    Raises:
        Exception: if the method's starting line number does not fall within
            the line range covered by the source of `mod_class`.
    """
    # TODO: __name__ not set for submodules in recursive script
    if not hasattr(method, "__name__"):
        return None
    qual_name = _qualified_name(method)
    class_name_map = _overloaded_methods.get(qual_name, None)
    if class_name_map is None:
        return None
    overloads = class_name_map.get(mod_class.__name__, None)
    if overloads is None:
        return None

    method_line_no = get_source_lines_and_file(method)[1]
    # Retrieve the class source once and reuse it for both the starting line
    # and the line count, instead of looking it up twice.
    mod_class_source = get_source_lines_and_file(mod_class)
    mod_class_fileno = mod_class_source[1]
    mod_end_fileno = mod_class_fileno + len(mod_class_source[0])
    if not (mod_class_fileno <= method_line_no <= mod_end_fileno):
        raise Exception("Overloads are not useable when a module is redeclared within the same file: " + str(method))
    return overloads
Exemple #7
0
def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
    """
    Parse the source of `fn` and convert it into a JIT AST (TreeView).

    Args:
        fn: The function object whose source is compiled.
        def_name: Name given to the produced AST object. It can differ from
            `fn.__name__` when the function is bound under another name:
                def _forward(self):
                    ...
                forward = _forward
            Here `fn.__name__` is "_forward", but the AST should be named
            "forward".
        self_name: Type name of `self` when `fn` is a method.
        is_classmethod: When true, a statement assigning `self_name` to the
            function's first argument is prepended to its body.

    Raises:
        RuntimeError: if the parsed source is not exactly one top-level
            function definition.
    """
    lines, lineno, path = get_source_lines_and_file(
        fn, torch._C.ErrorReport.call_stack())
    lines = normalize_source_lines(lines)
    raw_src = ''.join(lines)
    flat_src = dedent(raw_src)
    top_level = ast.parse(flat_src).body

    def _not_single_function():
        # Both structural checks below report the same error.
        return RuntimeError(
            f"Expected a single top-level function: {path}:{lineno}")

    if not (len(top_level) == 1
            and isinstance(top_level[0], ast.FunctionDef)):
        raise _not_single_function()

    # Difference in first-line length before and after dedenting.
    first_raw_line = raw_src.split('\n', 1)[0]
    first_flat_line = flat_src.split('\n', 1)[0]
    indent_width = len(first_raw_line) - len(first_flat_line)

    type_comment = torch.jit.annotations.get_type_line(raw_src)
    ctx = SourceContext(raw_src, path, lineno, indent_width, True)
    fn_def = top_level[0]

    if is_classmethod:
        first_arg = fn_def.args.args[0].arg
        # Insert a statement that assigns the first argument to the class
        binding = ast.parse(f"{first_arg} = {self_name}").body[0]
        fn_def.body.insert(0, binding)

    # Swap out the function signature and body if it is unused
    if should_drop(fn):
        stub_body = ast.parse(
            "def unused_fn(self: Any):\n\traise RuntimeError(\"Cannot call @unused methods\")"
        ).body
        if not (len(stub_body) == 1
                and isinstance(stub_body[0], ast.FunctionDef)):
            raise _not_single_function()
        stub_def = stub_body[0]
        fn_def.body = stub_def.body
        # kwarg/vararg not supported by `build_def`
        fn_def.args.kwarg = fn_def.args.vararg = None
        # Replace potentially unsupported type annotations by "Any"
        any_annotation = stub_def.args.args[0].annotation
        for param in fn_def.args.args + fn_def.args.kwonlyargs:
            param.annotation = any_annotation

    return build_def(ctx, fn_def, type_comment, def_name, self_name=self_name)
Exemple #8
0
def get_jit_def(fn, self_name=None):
    """
    Parse the source of `fn` and convert it into a JIT AST via `build_def`.

    Arguments:
        fn: The function object whose source is compiled.
        self_name: Type name of `self` when `fn` is a method.

    Raises:
        RuntimeError: if the parsed source is not exactly one top-level
            function definition.
    """
    lines, lineno, path = get_source_lines_and_file(fn)
    raw_src = ''.join(lines)
    flat_src = dedent(raw_src)
    top_level = ast.parse(flat_src).body

    is_single_def = (len(top_level) == 1
                     and isinstance(top_level[0], ast.FunctionDef))
    if not is_single_def:
        raise RuntimeError("Expected a single top-level function")

    # Difference in first-line length before and after dedenting.
    first_raw_line = raw_src.split('\n', 1)[0]
    first_flat_line = flat_src.split('\n', 1)[0]
    indent_width = len(first_raw_line) - len(first_flat_line)

    type_comment = torch.jit.annotations.get_type_line(raw_src)
    ctx = SourceContext(raw_src, path, lineno, indent_width,
                        _uses_true_division(fn))
    return build_def(ctx, top_level[0], type_comment, self_name)
Exemple #9
0
def get_jit_class_def(cls, self_name):
    """
    Build the JIT class definition for `cls`.

    Every member of `cls` that is a method or function is compiled on its
    own with `get_jit_def`; the results and the first node of the parsed
    class source are handed to `build_class_def`.

    Arguments:
        cls: The class to compile.
        self_name: Type name used for `self` in the compiled methods.
    """
    def _is_callable_member(member):
        return inspect.ismethod(member) or inspect.isfunction(member)

    # Get defs for each method independently
    method_defs = []
    for _, member in inspect.getmembers(cls, predicate=_is_callable_member):
        method_defs.append(get_jit_def(member, self_name=self_name))

    lines, lineno, path = get_source_lines_and_file(cls)
    raw_src = ''.join(lines)
    flat_src = dedent(raw_src)
    class_module = ast.parse(flat_src)

    # Difference in first-line length before and after dedenting.
    first_raw_line = raw_src.split('\n', 1)[0]
    first_flat_line = flat_src.split('\n', 1)[0]
    indent_width = len(first_raw_line) - len(first_flat_line)

    ctx = SourceContext(raw_src, path, lineno, indent_width, False)
    return build_class_def(ctx, class_module.body[0], method_defs, self_name)
Exemple #10
0
def check_fn(fn, loc):
    """
    Verify that the source of `fn` is a single top-level function.

    Returns silently when the source cannot be retrieved.

    Arguments:
        fn: Callable whose source is parsed.
        loc: Location passed to any `FrontendError` raised here.

    Raises:
        torch.jit.frontend.FrontendError: if the source is a single class
            definition, or is not a single top-level function.
    """
    # Make sure the function definition is not a class instantiation
    try:
        lines = get_source_lines_and_file(fn)[0]
        source = dedent(''.join(lines))
    except (TypeError, IOError):
        return
    if source is None:
        return

    top_level = ast.parse(source).body
    # The sole top-level node, or None when there is not exactly one.
    only_node = top_level[0] if len(top_level) == 1 else None

    if isinstance(only_node, ast.ClassDef):
        raise torch.jit.frontend.FrontendError(
            loc, f"Cannot instantiate class '{only_node.name}' in a script function")
    if not isinstance(only_node, ast.FunctionDef):
        raise torch.jit.frontend.FrontendError(loc, "Expected a single top-level function")
Exemple #11
0
 def _check_container_source(container_type, source_file, original_source):
     """
     Warn when the current source of `container_type` differs from
     `original_source`.

     Arguments:
         container_type: The class whose current source is retrieved.
         source_file: File name used for both sides of the unified diff.
         original_source: The source text to compare against.

     If the current source cannot be retrieved, a warning is issued and the
     check is skipped. If the sources differ, a `SourceChangeWarning` is
     issued; when `container_type.dump_patches` is truthy, a diff from the
     current source to the original one is also written to
     `<container_type.__name__>.patch`.
     """
     try:
         current_source = "".join(
             get_source_lines_and_file(container_type)[0])
     except Exception:  # saving the source is optional, so we can ignore any errors
         warnings.warn("Couldn't retrieve source code for container of "
                       "type " + container_type.__name__ +
                       ". It won't be checked "
                       "for correctness upon loading.")
         return
     if original_source != current_source:
         if container_type.dump_patches:
             file_name = container_type.__name__ + ".patch"
             # Diff goes from the current source to the original one, so
             # applying it reverts the current changes.
             diff = difflib.unified_diff(
                 current_source.split("\n"),
                 original_source.split("\n"),
                 source_file,
                 source_file,
                 lineterm="",
             )
             lines = "\n".join(diff)
             try:
                 with open(file_name, "a+") as f:
                     # Seek to the end to learn the existing file size, then
                     # rewind so a later read starts from the beginning.
                     file_size = f.seek(0, 2)
                     f.seek(0)
                     if file_size == 0:
                         f.write(lines)
                     elif file_size != len(lines) or f.read() != lines:
                         # An existing file with different content is
                         # reported through the IOError handler below.
                         raise IOError
                 msg = ("Saved a reverse patch to " + file_name + ". "
                        "Run `patch -p0 < " + file_name +
                        "` to revert your "
                        "changes.")
             except IOError:
                 msg = ("Tried to save a patch, but couldn't create a "
                        "writable file " + file_name + ". Make sure it "
                        "doesn't exist and your working directory is "
                        "writable.")
         else:
             msg = ("you can retrieve the original source code by "
                    "accessing the object's source attribute or set "
                    "`torch.nn.Module.dump_patches = True` and use the "
                    "patch tool to revert the changes.")
         # Prefix the branch-specific hint with the name of the changed class.
         msg = "source code of class '{container_type}' has changed. {msg}".format(
             container_type=torch.typename(container_type), msg=msg)
         warnings.warn(msg, SourceChangeWarning)
Exemple #12
0
def get_signature(fn, rcb, loc):
    """
    Resolve the signature of `fn` from annotations or a type comment.

    Arguments:
        fn: Function whose signature is wanted.
        rcb: Passed through to `parse_type_line`.
        loc: Passed through to `parse_type_line`.

    Returns:
        The result of `try_real_annotations` when available, otherwise the
        parsed type line from the function source, or None when neither
        exists.
    """
    # Python 3.5 adds support for the nice annotation syntax, so try that first.
    annotated = try_real_annotations(fn) if PY35 else None
    if annotated is not None:
        return annotated

    try:
        fn_source = dedent(''.join(get_source_lines_and_file(fn)[0]))
        type_line = get_type_line(fn_source)
    except TypeError:
        type_line = None

    # A missing type line means either the source of fn could not be
    # retrieved or it carried no annotations.
    return parse_type_line(type_line, rcb, loc) if type_line is not None else None
Exemple #13
0
    def persistent_id(obj: Any) -> Optional[Tuple]:
        """
        Return the persistent-id tuple for `obj`, or None.

        `nn.Module` subclasses yield a ('module', ...) tuple the first time
        they are seen (None afterwards); storages yield a ('storage', ...)
        tuple and are recorded in `serialized_storages`. Anything else
        returns None.
        """
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        is_module_class = isinstance(obj, type) and issubclass(obj, nn.Module)
        if is_module_class:
            if obj in serialized_container_types:
                return None
            serialized_container_types[obj] = True
            src_path = src_text = None
            try:
                src_lines, _, src_path = get_source_lines_and_file(obj)
                src_text = ''.join(src_lines)
            except Exception:  # saving the source is optional, so we can ignore any errors
                warnings.warn("Couldn't retrieve source code for container of "
                              "type " + obj.__name__ + ". It won't be checked "
                              "for correctness upon loading.")
            return ('module', obj, src_path, src_text)

        if not torch.is_storage(obj):
            return None

        view_metadata: Optional[Tuple[str, int, int]]
        obj = cast(Storage, obj)
        storage_type = normalize_storage_type(type(obj))
        # Offset is always 0, but we keep it for backwards compatibility
        # with the old serialization format (which supported storage views)
        offset = 0
        obj_key = str(obj._cdata)
        location = location_tag(obj)
        serialized_storages[obj_key] = obj
        # NOTE(review): this compares `obj._cdata` with itself, so the view
        # branch looks unreachable - presumably kept for format
        # compatibility; confirm before removing.
        is_view = obj._cdata != obj._cdata
        view_metadata = (str(obj._cdata), offset, obj.size()) if is_view else None

        return ('storage',
                storage_type,
                obj_key,
                location,
                obj.size(),
                view_metadata)
Exemple #14
0
def get_jit_class_def(cls, self_name):
    """
    Build the JIT class definition for `cls`.

    Each non-static method declared directly in `cls.__dict__` is compiled
    on its own with `get_jit_def`; the results, together with the class
    properties, the class assignments and the parsed class node, are handed
    to `build_class_def`.

    Arguments:
        cls: The class to compile.
        self_name: Type name used for `self` in the compiled methods.
    """
    def _is_own_non_static(member):
        # Only plain functions/methods qualify.
        if not (inspect.ismethod(member) or inspect.isfunction(member)):
            return False
        member_name = member.__name__
        return (not is_static_fn(cls, member_name)
                and member_name in cls.__dict__)

    def _is_bound_to_cls(member):
        # A bound method whose `__self__` equals the class itself.
        return inspect.ismethod(member) and getattr(member, "__self__", None) == cls

    # Get defs for each method within the current class independently
    # TODO: proper overriding analysis when implementing class inheritance
    method_defs = []
    for name, member in inspect.getmembers(cls, predicate=_is_own_non_static):
        method_defs.append(
            get_jit_def(member,
                        name,
                        self_name=self_name,
                        is_classmethod=_is_bound_to_cls(member)))

    properties = get_class_properties(cls, self_name)

    lines, lineno, path = get_source_lines_and_file(
        cls, torch._C.ErrorReport.call_stack())
    raw_src = ''.join(lines)
    flat_src = dedent(raw_src)
    class_module = ast.parse(flat_src)

    # Difference in first-line length before and after dedenting.
    first_raw_line = raw_src.split('\n', 1)[0]
    first_flat_line = flat_src.split('\n', 1)[0]
    indent_width = len(first_raw_line) - len(first_flat_line)

    ctx = make_source_context(raw_src, path, lineno, indent_width, False)
    class_ast = class_module.body[0]
    assert isinstance(class_ast, ast.ClassDef)
    assigns = get_class_assigns(ctx, class_ast)

    return build_class_def(ctx, class_ast, method_defs, properties, self_name,
                           assigns)