Ejemplo n.º 1
0
def get_jit_def(fn, def_name, self_name=None):
    """
    Build a JIT AST (TreeView) from the given function.

    The source of `fn` is fetched, dedented and parsed with `ast`; the
    single function definition found there is handed to `build_def`.

    Arguments:
        fn: A function object to compile
        def_name: The name to give to the resulting AST object. It can differ
            from `fn.__name__`, e.g. when a method is aliased:
                def _forward(self):
                    ...
                forward = _forward
            Here `fn.__name__` is "_forward", but the resulting AST should be
            named "forward".
        self_name: If this function is a method, what the type name of `self` is.

    Returns:
        Whatever `build_def` returns for the parsed function definition.

    Raises:
        RuntimeError: if the parsed source is not exactly one top-level
            function definition.
    """
    call_stack = torch._C.ErrorReport.call_stack()
    lines, first_lineno, src_file = get_source_lines_and_file(fn, call_stack)
    raw_src = ''.join(lines)
    flat_src = dedent(raw_src)

    top_level = ast.parse(flat_src).body
    is_single_def = (len(top_level) == 1
                     and isinstance(top_level[0], ast.FunctionDef))
    if not is_single_def:
        raise RuntimeError("Expected a single top-level function")

    # Width of the indentation that `dedent` stripped, measured on the
    # first line of the source.
    raw_first_line = raw_src.partition('\n')[0]
    flat_first_line = flat_src.partition('\n')[0]
    indent_width = len(raw_first_line) - len(flat_first_line)

    type_line = torch.jit.annotations.get_type_line(raw_src)
    true_division = _uses_true_division(fn)
    ctx = SourceContext(raw_src, src_file, first_lineno, indent_width,
                        true_division)
    return build_def(ctx, top_level[0], type_line, def_name,
                     self_name=self_name)
Ejemplo n.º 2
0
def get_jit_class_def(cls, self_name):
    """
    Build the JIT class definition for `cls`.

    Each method is converted on its own with `get_jit_def`, the class
    properties are collected with `get_class_properties`, and the class
    source is then parsed so everything can be passed to `build_class_def`.

    Arguments:
        cls: The class object to compile.
        self_name: Forwarded to `get_jit_def`, `get_class_properties` and
            `build_class_def`.

    Returns:
        Whatever `build_class_def` returns.
    """
    # TODO: proper overriding analysis when implementing class inheritance

    def _is_own_instance_method(member):
        # Only plain functions / bound methods qualify ...
        if not (inspect.ismethod(member) or inspect.isfunction(member)):
            return False
        # ... that are not static ...
        if is_static_fn(cls, member.__name__):
            return False
        # ... and whose name appears directly in this class's __dict__.
        return member.__name__ in cls.__dict__

    # Get defs for each method within the current class independently
    method_defs = [
        get_jit_def(member, name, self_name=self_name)
        for name, member in inspect.getmembers(
            cls, predicate=_is_own_instance_method)
    ]

    properties = get_class_properties(cls, self_name)

    call_stack = torch._C.ErrorReport.call_stack()
    lines, first_lineno, src_file = get_source_lines_and_file(cls, call_stack)
    raw_src = ''.join(lines)
    flat_src = dedent(raw_src)
    class_ast = ast.parse(flat_src).body[0]

    # Width of the indentation that `dedent` stripped, measured on the
    # first line of the source.
    raw_first_line = raw_src.partition('\n')[0]
    flat_first_line = flat_src.partition('\n')[0]
    indent_width = len(raw_first_line) - len(flat_first_line)

    ctx = SourceContext(raw_src, src_file, first_lineno, indent_width, False)
    return build_class_def(ctx, class_ast, method_defs, properties, self_name)
Ejemplo n.º 3
0
def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
    """
    Build a JIT AST (TreeView) from the given function.

    Args:
        fn: A function object to compile
        def_name: The name to give to the resulting AST object. It can differ
            from `fn.__name__`, e.g. when a method is aliased:
                def _forward(self):
                    ...
                forward = _forward
            Here `fn.__name__` is "_forward", but the resulting AST should be
            named "forward".
        self_name: If this function is a method, what the type name of `self` is.
        is_classmethod: When true, a statement assigning `self_name` to the
            function's first argument is inserted at the top of the body
            before the definition is built.

    Returns:
        Whatever `build_def` returns for the (possibly rewritten) definition.

    Raises:
        RuntimeError: if the parsed source is not exactly one top-level
            function definition.
    """
    call_stack = torch._C.ErrorReport.call_stack()
    raw_lines, file_lineno, filename = get_source_lines_and_file(
        fn, call_stack)
    source = ''.join(normalize_source_lines(raw_lines))
    flat_src = dedent(source)

    top_level = ast.parse(flat_src).body
    if not (len(top_level) == 1
            and isinstance(top_level[0], ast.FunctionDef)):
        raise RuntimeError(
            f"Expected a single top-level function: {filename}:{file_lineno}")

    # Width of the indentation that `dedent` stripped, measured on the
    # first line of the source.
    indent_width = (len(source.partition('\n')[0])
                    - len(flat_src.partition('\n')[0]))
    type_line = torch.jit.annotations.get_type_line(source)
    ctx = SourceContext(source, filename, file_lineno, indent_width, True)
    fn_def = top_level[0]

    if is_classmethod:
        # Prepend "<first arg> = <self_name>" so the first argument is
        # bound to the class inside the body.
        first_arg = fn_def.args.args[0].arg
        bind_cls = ast.parse(f"{first_arg} = {self_name}").body[0]
        fn_def.body.insert(0, bind_cls)

    if should_drop(fn):
        # The function is unused: replace its body with one that raises,
        # and simplify its signature.
        stub_module = ast.parse(
            "def unused_fn(self: Any):\n\traise RuntimeError(\"Cannot call @unused methods\")"
        )
        stub_body = stub_module.body
        if not (len(stub_body) == 1
                and isinstance(stub_body[0], ast.FunctionDef)):
            raise RuntimeError(
                f"Expected a single top-level function: {filename}:{file_lineno}"
            )
        stub_def = stub_body[0]
        fn_def.body = stub_def.body
        # kwarg/vararg not supported by `build_def`
        fn_def.args.kwarg = None
        fn_def.args.vararg = None
        # Replace potentially unsupported type annotations by "Any"
        any_annotation = stub_def.args.args[0].annotation
        for param in fn_def.args.args + fn_def.args.kwonlyargs:
            param.annotation = any_annotation

    return build_def(ctx, fn_def, type_line, def_name, self_name=self_name)
Ejemplo n.º 4
0
def get_jit_def(fn, self_name=None):
    """
    Build a JIT AST from the given function.

    The source of `fn` is fetched, dedented and parsed with `ast`; the
    single function definition found there is handed to `build_def`.

    Arguments:
        fn: A function object to compile
        self_name: Passed through to `build_def` as its fourth positional
            argument.

    Returns:
        Whatever `build_def` returns for the parsed function definition.

    Raises:
        RuntimeError: if the parsed source is not exactly one top-level
            function definition.
    """
    lines, first_lineno, src_file = get_source_lines_and_file(fn)
    raw_src = ''.join(lines)
    flat_src = dedent(raw_src)

    top_level = ast.parse(flat_src).body
    is_single_def = (len(top_level) == 1
                     and isinstance(top_level[0], ast.FunctionDef))
    if not is_single_def:
        raise RuntimeError("Expected a single top-level function")

    # Width of the indentation that `dedent` stripped, measured on the
    # first line of the source.
    raw_first_line = raw_src.partition('\n')[0]
    flat_first_line = flat_src.partition('\n')[0]
    indent_width = len(raw_first_line) - len(flat_first_line)

    type_line = torch.jit.annotations.get_type_line(raw_src)
    true_division = _uses_true_division(fn)
    ctx = SourceContext(raw_src, src_file, first_lineno, indent_width,
                        true_division)
    return build_def(ctx, top_level[0], type_line, self_name)