Exemple #1
0
def get_default_args_for_class(cls):
    """
    Collect default argument values for the non-static methods defined on a class.

    Args:
        cls: type - The class whose methods are inspected.
    Returns:
        A Dict[str, Dict[str, Any]] keyed by method name; each value is the
        result of ``get_default_args`` for that method (argument name ->
        default value).
    """
    def _is_own_non_static_method(member):
        # Static methods are left out because they are compiled separately,
        # as if they were independent script functions.
        if not (inspect.ismethod(member) or inspect.isfunction(member)):
            return False
        if is_static_fn(cls, member.__name__):
            return False
        # Keep only members whose name appears in this class's own __dict__.
        return member.__name__ in cls.__dict__

    # Property defaults are deliberately ignored: a setter cannot be invoked
    # without a value, so it has no default worth recording.
    return {
        name: get_default_args(impl)
        for name, impl in inspect.getmembers(
            cls, predicate=_is_own_non_static_method)
    }
Exemple #2
0
def get_jit_class_def(cls, self_name):
    """
    Build the class definition for ``cls`` from its methods, properties and source.

    Args:
        cls: The class to process.
        self_name: Forwarded unchanged to ``get_jit_def``,
            ``get_class_properties`` and ``build_class_def``.
    Returns:
        The value produced by ``build_class_def``.
    """
    # Each method of the current class gets its own def, independently.
    # TODO: proper overriding analysis when implementing class inheritance
    def _is_own_non_static_method(member):
        if not (inspect.ismethod(member) or inspect.isfunction(member)):
            return False
        if is_static_fn(cls, member.__name__):
            return False
        return member.__name__ in cls.__dict__

    method_defs = [
        get_jit_def(fn, name, self_name=self_name)
        for name, fn in inspect.getmembers(
            cls, predicate=_is_own_non_static_method)
    ]

    prop_defs = get_class_properties(cls, self_name)

    # Fetch the class source and parse a dedented copy of it.
    lines, first_lineno, src_file = get_source_lines_and_file(
        cls, torch._C.ErrorReport.call_stack())
    raw_src = ''.join(lines)
    flat_src = dedent(raw_src)
    tree = ast.parse(flat_src)

    # Width of the indentation stripped by dedent, measured on the first line.
    indent_width = (len(raw_src.partition('\n')[0])
                    - len(flat_src.partition('\n')[0]))
    ctx = SourceContext(raw_src, src_file, first_lineno, indent_width, False)
    return build_class_def(ctx, tree.body[0], method_defs, prop_defs,
                           self_name)
Exemple #3
0
def get_jit_class_def(cls, self_name):
    """
    Build the class definition for ``cls`` from its methods, properties,
    class-level assignments and source.

    Args:
        cls: The class to process.
        self_name: Forwarded unchanged to ``get_jit_def``,
            ``get_class_properties`` and ``build_class_def``.
    Returns:
        The value produced by ``build_class_def``.
    """
    # Each method of the current class gets its own def, independently.
    # TODO: proper overriding analysis when implementing class inheritance
    def _is_own_non_static_method(member):
        if not (inspect.ismethod(member) or inspect.isfunction(member)):
            return False
        if is_static_fn(cls, member.__name__):
            return False
        return member.__name__ in cls.__dict__

    def _bound_to_cls(fn):
        # A bound method whose __self__ compares equal to cls is treated as
        # a classmethod.
        if not inspect.ismethod(fn):
            return False
        return getattr(fn, "__self__", None) == cls

    members = inspect.getmembers(cls, predicate=_is_own_non_static_method)
    method_defs = [
        get_jit_def(
            fn,
            name,
            self_name=self_name,
            is_classmethod=_bound_to_cls(fn),
        )
        for name, fn in members
    ]

    prop_defs = get_class_properties(cls, self_name)

    # Fetch the class source and parse a dedented copy of it.
    lines, first_lineno, src_file = get_source_lines_and_file(
        cls, torch._C.ErrorReport.call_stack())
    raw_src = ''.join(lines)
    flat_src = dedent(raw_src)
    tree = ast.parse(flat_src)

    # Width of the indentation stripped by dedent, measured on the first line.
    indent_width = (len(raw_src.partition('\n')[0])
                    - len(flat_src.partition('\n')[0]))
    ctx = make_source_context(raw_src, src_file, first_lineno, indent_width,
                              False)

    class_node = tree.body[0]
    assert isinstance(class_node, ast.ClassDef)
    class_assigns = get_class_assigns(ctx, class_node)

    return build_class_def(ctx, class_node, method_defs, prop_defs, self_name,
                           class_assigns)
Exemple #4
0
def get_jit_class_def(cls, self_name):
    """
    Build the class definition for ``cls`` from its methods, properties,
    class-level assignments and source, with special handling for dataclasses.

    Args:
        cls: The class to process.
        self_name: Forwarded unchanged to ``get_jit_def``,
            ``get_class_properties`` and ``build_class_def``.
    Returns:
        The value produced by ``build_class_def``.
    """
    # Each method of the current class gets its own def, independently.
    # TODO: proper overriding analysis when implementing class inheritance
    def _is_own_non_static_method(member):
        if not (inspect.ismethod(member) or inspect.isfunction(member)):
            return False
        if is_static_fn(cls, member.__name__):
            return False
        return member.__name__ in cls.__dict__

    def _bound_to_cls(fn):
        # A bound method whose __self__ compares equal to cls is treated as
        # a classmethod.
        if not inspect.ismethod(fn):
            return False
        return getattr(fn, "__self__", None) == cls

    members = inspect.getmembers(cls, predicate=_is_own_non_static_method)

    # Fetch the class source and parse a dedented copy of it.
    lines, first_lineno, src_file = get_source_lines_and_file(
        cls, torch._C.ErrorReport.call_stack())
    raw_src = ''.join(lines)
    flat_src = dedent(raw_src)
    tree = ast.parse(flat_src)

    class_node = tree.body[0]
    assert isinstance(class_node, ast.ClassDef)

    # Dataclasses need special treatment. JIT compilation normally relies on
    # having an object's source code, but the dataclasses module synthesizes
    # its magic methods dynamically, so no source exists for them. The
    # workaround is to substitute TorchScript-friendly implementations that
    # we synthesize ourselves.
    if dataclasses.is_dataclass(cls):
        # Magic methods the user wrote by hand must not be replaced, so
        # record which ones appear as function definitions in the class body.
        user_defined = set()
        for node in class_node.body:
            if (isinstance(node, ast.FunctionDef)
                    and node.name in DATACLASS_MAGIC_METHODS):
                user_defined.add(node.name)

        for idx, (name, _) in enumerate(members):
            # Skip anything we cannot synthesize or that the user overrode.
            synthesize = DATACLASS_MAGIC_METHODS.get(name)
            if not synthesize or name in user_defined:
                continue
            members[idx] = (name, synthesize(cls))

    method_defs = [
        get_jit_def(
            fn,
            name,
            self_name=self_name,
            is_classmethod=_bound_to_cls(fn),
        )
        for name, fn in members
    ]
    prop_defs = get_class_properties(cls, self_name)

    # Width of the indentation stripped by dedent, measured on the first line.
    indent_width = (len(raw_src.partition('\n')[0])
                    - len(flat_src.partition('\n')[0]))
    ctx = make_source_context(raw_src, src_file, first_lineno, indent_width,
                              False)
    class_assigns = get_class_assigns(ctx, class_node)

    return build_class_def(ctx, class_node, method_defs, prop_defs, self_name,
                           class_assigns)