def get_default_args_for_class(cls):
    """
    Collect the default argument values of every method defined on a class.

    Static methods are skipped: they are compiled separately, as if they
    were independent script functions.

    Args:
        cls: type - The class whose methods should be inspected.

    Returns:
        A Dict[str, Dict[str, Any]] mapping each method name to a
        Dict[str, Any] of argument name -> default value.
    """
    def _wanted(member):
        # Keep plain functions and bound methods whose name appears in the
        # class's own __dict__ (i.e. not inherited) and that are not static.
        if not (inspect.ismethod(member) or inspect.isfunction(member)):
            return False
        if is_static_fn(cls, member.__name__):
            return False
        return member.__name__ in cls.__dict__

    # Property defaults are not collected: a setter can never be invoked
    # without a value, so its defaults never matter.
    result = {}
    for name, impl in inspect.getmembers(cls, predicate=_wanted):
        result[name] = get_default_args(impl)
    return result
def get_jit_class_def(cls, self_name):
    """
    Build a TorchScript class definition from the Python source of ``cls``.

    NOTE(review): ``get_jit_class_def`` is defined more than once in this
    file; at import time the last definition wins, so this one is shadowed.

    Args:
        cls: the Python class to compile.
        self_name: forwarded as ``self_name`` to ``get_jit_def``,
            ``get_class_properties`` and ``build_class_def``.

    Returns:
        The result of ``build_class_def`` for the parsed class.
    """
    # Compile each method defined on this class independently.
    # TODO: proper overriding analysis when implementing class inheritance
    def _is_own_compilable_method(member):
        return ((inspect.ismethod(member) or inspect.isfunction(member))
                and not is_static_fn(cls, member.__name__)
                and member.__name__ in cls.__dict__)

    method_defs = []
    for name, fn in inspect.getmembers(cls, predicate=_is_own_compilable_method):
        method_defs.append(get_jit_def(fn, name, self_name=self_name))

    properties = get_class_properties(cls, self_name)

    # Parse the class source. It is dedented first so that a class whose
    # source is indented (e.g. nested) can still be handed to ast.parse.
    sourcelines, file_lineno, filename = get_source_lines_and_file(
        cls, torch._C.ErrorReport.call_stack())
    source = ''.join(sourcelines)
    dedent_src = dedent(source)
    py_ast = ast.parse(dedent_src)

    # Width of the indentation removed from the first line by dedent.
    first_raw_line = source.split('\n', 1)[0]
    first_dedented_line = dedent_src.split('\n', 1)[0]
    leading_whitespace_len = len(first_raw_line) - len(first_dedented_line)

    ctx = SourceContext(source, filename, file_lineno, leading_whitespace_len, False)
    return build_class_def(ctx, py_ast.body[0], method_defs, properties, self_name)
def get_jit_class_def(cls, self_name):
    """
    Build a TorchScript class definition from the Python source of ``cls``.

    NOTE(review): ``get_jit_class_def`` is defined more than once in this
    file; at import time the last definition wins, so this one is shadowed.

    Args:
        cls: the Python class to compile.
        self_name: forwarded as ``self_name`` to ``get_jit_def``,
            ``get_class_properties`` and ``build_class_def``.

    Returns:
        The result of ``build_class_def`` for the parsed class.
    """
    # Get defs for each method within the current class independently
    # TODO: proper overriding analysis when implementing class inheritance
    methods = inspect.getmembers(
        cls,
        predicate=lambda m: (inspect.ismethod(m) or inspect.isfunction(m))
        and not is_static_fn(cls, m.__name__)
        and m.__name__ in cls.__dict__
    )

    def is_classmethod(fn):
        # A classmethod looked up on its class is a bound method whose
        # __self__ is the class itself. Compare by identity: ``==`` would
        # invoke the __eq__ of whatever object is bound as __self__, which
        # may return a non-bool or raise.
        return inspect.ismethod(fn) and getattr(fn, "__self__", None) is cls

    methods = [
        get_jit_def(obj, name, self_name=self_name, is_classmethod=is_classmethod(obj))
        for (name, obj) in methods
    ]
    properties = get_class_properties(cls, self_name)

    # Parse the class source; dedent so indented class source still parses.
    sourcelines, file_lineno, filename = get_source_lines_and_file(
        cls, torch._C.ErrorReport.call_stack())
    source = ''.join(sourcelines)
    dedent_src = dedent(source)
    py_ast = ast.parse(dedent_src)
    # Width of the indentation removed from the first line by dedent.
    leading_whitespace_len = len(source.split('\n', 1)[0]) - len(dedent_src.split('\n', 1)[0])
    ctx = make_source_context(source, filename, file_lineno, leading_whitespace_len, False)

    class_ast = py_ast.body[0]
    assert isinstance(class_ast, ast.ClassDef)
    assigns = get_class_assigns(ctx, class_ast)
    return build_class_def(ctx, class_ast, methods, properties, self_name, assigns)
def get_jit_class_def(cls, self_name):
    """
    Build a TorchScript class definition from the Python source of ``cls``.

    Methods defined directly on ``cls`` (static methods excluded) are
    converted with ``get_jit_def``. If ``cls`` is a dataclass, the magic
    methods listed in ``DATACLASS_MAGIC_METHODS`` that the user did not write
    by hand are replaced with synthesized implementations, because the
    runtime-generated dataclass methods have no source code to compile.

    Args:
        cls: the Python class to compile.
        self_name: forwarded as ``self_name`` to ``get_jit_def``,
            ``get_class_properties`` and ``build_class_def``.

    Returns:
        The result of ``build_class_def`` for the parsed class.
    """
    # Get defs for each method within the current class independently
    # TODO: proper overriding analysis when implementing class inheritance
    methods = inspect.getmembers(
        cls,
        predicate=lambda m: (inspect.ismethod(m) or inspect.isfunction(m))
        and not is_static_fn(cls, m.__name__)
        and m.__name__ in cls.__dict__
    )

    def is_classmethod(fn):
        # A classmethod looked up on its class is a bound method whose
        # __self__ is the class itself. Compare by identity: ``==`` would
        # invoke the __eq__ of whatever object is bound as __self__, which
        # may return a non-bool or raise.
        return inspect.ismethod(fn) and getattr(fn, "__self__", None) is cls

    # Get and parse the source code for this class
    sourcelines, file_lineno, filename = get_source_lines_and_file(
        cls, torch._C.ErrorReport.call_stack())
    source = ''.join(sourcelines)
    dedent_src = dedent(source)
    py_ast = ast.parse(dedent_src)

    class_ast = py_ast.body[0]
    assert isinstance(class_ast, ast.ClassDef)

    # Special case for dataclasses. In general we need access to the source code for
    # an object in order to JIT compile it. But the dataclasses module dynamically synthesizes
    # magic methods for classes, and we can't get the source code for these methods. As a
    # workaround, we synthesize TorchScript-friendly implementations ourselves.
    if dataclasses.is_dataclass(cls):
        # Detect whether the user manually implemented any of the magic methods. If they did,
        # we don't want to synthesize/override them.
        overrides = {
            method.name
            for method in class_ast.body
            if isinstance(method, ast.FunctionDef) and method.name in DATACLASS_MAGIC_METHODS
        }
        for i, (name, _) in enumerate(methods):
            # Is this a magic method we can synthesize?
            synthesizer_fn = DATACLASS_MAGIC_METHODS.get(name)
            if synthesizer_fn and name not in overrides:
                # Replacing an element in place does not disturb enumerate.
                # NOTE(review): the synthesized object is passed to get_jit_def
                # below in place of a function; presumably get_jit_def accepts
                # it directly — confirm against its definition.
                methods[i] = name, synthesizer_fn(cls)

    method_defs = [
        get_jit_def(obj, name, self_name=self_name, is_classmethod=is_classmethod(obj))
        for (name, obj) in methods
    ]
    properties = get_class_properties(cls, self_name)

    # Width of the indentation removed from the first line by dedent.
    leading_whitespace_len = len(source.split('\n', 1)[0]) - len(
        dedent_src.split('\n', 1)[0])
    ctx = make_source_context(source, filename, file_lineno, leading_whitespace_len, False)
    assigns = get_class_assigns(ctx, class_ast)
    return build_class_def(ctx, class_ast, method_defs, properties, self_name, assigns)