def get_class_properties(cls, self_name):
    """
    Collect the properties declared on a class as Property TreeView objects.

    Args:
        cls: The class whose properties are inspected.
        self_name: The name of the class that the properties should belong to;
            forwarded to `get_jit_def` for each getter and setter.

    Returns:
        A list with one Property object per compiled property of `cls`.
        Property here refers to the subclass of TreeView, not the builtin.
    """
    members = inspect.getmembers(cls, predicate=lambda m: isinstance(m, property))
    # A class opts properties out of compilation by naming them in
    # `__jit_unused_properties__`.
    skipped_names = getattr(cls, "__jit_unused_properties__", [])

    result = []
    for name, py_prop in members:
        if name in skipped_names:
            continue
        if should_drop(py_prop.fget):
            continue

        getter = get_jit_def(py_prop.fget, f"__{name}_getter", self_name=self_name)
        if py_prop.fset:
            setter = get_jit_def(py_prop.fset, f"__{name}_setter", self_name=self_name)
        else:
            # Property without a setter.
            setter = None

        prop_range = getter.range()
        ident = Ident(getter.range(), name)
        result.append(Property(prop_range, ident, getter, setter))

    return result
def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
    """
    Build a JIT AST (TreeView) from the given function.

    Args:
        fn: A function object to compile, or an already parsed ParsedDef
            object (in which case it is used as-is instead of being parsed).
        def_name: The name to give to the resulting AST object. This is not
            always the same as `fn.__name__`, for example:
                def _forward(self):
                    ...
                forward = _forward
            In this case, the `__name__` attribute of the function object is
            "_forward", but we want the result AST to have the name "forward".
        self_name: If this function is a method, what the type name of `self` is.
        is_classmethod: When true, a statement assigning `self_name` to the
            function's first argument is inserted at the top of the body.

    Returns:
        The result of `build_def` for the (possibly rewritten) function AST.
    """
    if isinstance(fn, ParsedDef):
        parsed_def = fn
    else:
        parsed_def = parse_def(fn)
    type_line = torch.jit.annotations.get_type_line(parsed_def.source)
    fn_def = parsed_def.ast.body[0]

    if is_classmethod:
        # Bind the first argument to the class by prepending an assignment.
        first_arg = fn_def.args.args[0].arg
        binding = ast.parse(f"{first_arg} = {self_name}").body[0]
        fn_def.body.insert(0, binding)

    if should_drop(fn):
        # The function is unused: replace both its body and its signature
        # details with those of a stub that raises when called.
        stub_module = ast.parse(
            "def unused_fn(self: Any):\n\traise RuntimeError(\"Cannot call @unused methods\")"
        )
        stub_body = stub_module.body
        if not (len(stub_body) == 1 and isinstance(stub_body[0], ast.FunctionDef)):
            raise RuntimeError(
                f"Expected a single top-level function: {parsed_def.filename}:{parsed_def.file_lineno}"
            )
        stub = stub_body[0]
        fn_def.body = stub.body
        # kwarg/vararg not supported by `build_def`
        fn_def.args.kwarg = fn_def.args.vararg = None
        # Replace potentially unsupported type annotations by "Any"
        any_annotation = stub.args.args[0].annotation
        for arg in fn_def.args.args + fn_def.args.kwonlyargs:
            arg.annotation = any_annotation

    # If MonkeyType is installed, get all the consolidated type traces
    # for the arguments from type_trace_db
    type_trace_db = torch.jit._script._get_type_trace_db()
    if monkeytype_trace and not isinstance(fn, ParsedDef):
        pdt_arg_types = type_trace_db.get_args_types(get_qualified_name(fn))
    else:
        pdt_arg_types = None

    return build_def(
        parsed_def.ctx,
        fn_def,
        type_line,
        def_name,
        self_name=self_name,
        pdt_arg_types=pdt_arg_types,
    )
def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
    """
    Build a JIT AST (TreeView) from the given function.

    Args:
        fn: A function object to compile.
        def_name: The name to give to the resulting AST object. This is not
            always the same as `fn.__name__`, for example:
                def _forward(self):
                    ...
                forward = _forward
            In this case, the `__name__` attribute of the function object is
            "_forward", but we want the result AST to have the name "forward".
        self_name: If this function is a method, what the type name of `self` is.
        is_classmethod: When true, a statement assigning `self_name` to the
            function's first argument is inserted at the top of the body.

    Returns:
        The result of `build_def` for the (possibly rewritten) function AST.

    Raises:
        RuntimeError: If the parsed source is not exactly one top-level
            function definition.
    """
    raw_lines, file_lineno, filename = get_source_lines_and_file(
        fn, torch._C.ErrorReport.call_stack())
    source = ''.join(normalize_source_lines(raw_lines))
    dedent_src = dedent(source)
    module_ast = ast.parse(dedent_src)

    top_level = module_ast.body
    if not (len(top_level) == 1 and isinstance(top_level[0], ast.FunctionDef)):
        raise RuntimeError(
            f"Expected a single top-level function: {filename}:{file_lineno}")

    # Length difference between the first line before and after dedenting.
    first_raw_line = source.partition('\n')[0]
    first_dedented_line = dedent_src.partition('\n')[0]
    leading_whitespace_len = len(first_raw_line) - len(first_dedented_line)

    type_line = torch.jit.annotations.get_type_line(source)
    ctx = SourceContext(source, filename, file_lineno, leading_whitespace_len, True)
    fn_def = top_level[0]

    if is_classmethod:
        # Bind the first argument to the class by prepending an assignment.
        first_arg = fn_def.args.args[0].arg
        binding = ast.parse(f"{first_arg} = {self_name}").body[0]
        fn_def.body.insert(0, binding)

    if should_drop(fn):
        # The function is unused: replace both its body and its signature
        # details with those of a stub that raises when called.
        stub_module = ast.parse(
            "def unused_fn(self: Any):\n\traise RuntimeError(\"Cannot call @unused methods\")"
        )
        stub_body = stub_module.body
        if not (len(stub_body) == 1 and isinstance(stub_body[0], ast.FunctionDef)):
            raise RuntimeError(
                f"Expected a single top-level function: {filename}:{file_lineno}"
            )
        stub = stub_body[0]
        fn_def.body = stub.body
        # kwarg/vararg not supported by `build_def`
        fn_def.args.kwarg = fn_def.args.vararg = None
        # Replace potentially unsupported type annotations by "Any"
        any_annotation = stub.args.args[0].annotation
        for arg in fn_def.args.args + fn_def.args.kwonlyargs:
            arg.annotation = any_annotation

    return build_def(ctx, fn_def, type_line, def_name, self_name=self_name)
def get_jit_def(fn, def_name, self_name=None):
    """
    Build a JIT AST (TreeView) from the given function.

    Args:
        fn: A function object to compile.
        def_name: The name to give to the resulting AST object. This is not
            always the same as `fn.__name__`, for example:
                def _forward(self):
                    ...
                forward = _forward
            In this case, the `__name__` attribute of the function object is
            "_forward", but we want the result AST to have the name "forward".
        self_name: If this function is a method, what the type name of `self` is.

    Returns:
        The result of `build_def` for the (possibly rewritten) function AST.

    Raises:
        RuntimeError: If the source of `fn` does not parse to exactly one
            top-level function definition. The message includes the file
            name and line number the source was read from.
    """
    sourcelines, file_lineno, filename = get_source_lines_and_file(
        fn, torch._C.ErrorReport.call_stack())
    source = ''.join(sourcelines)
    dedent_src = dedent(source)
    py_ast = ast.parse(dedent_src)
    if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
        # Report where the offending source came from so the failure is
        # actionable.
        raise RuntimeError(
            f"Expected a single top-level function: {filename}:{file_lineno}")

    # Length difference between the first line before and after dedenting.
    leading_whitespace_len = len(source.split('\n', 1)[0]) - len(
        dedent_src.split('\n', 1)[0])
    type_line = torch.jit.annotations.get_type_line(source)
    ctx = SourceContext(source, filename, file_lineno, leading_whitespace_len, True)
    fn_def = py_ast.body[0]

    # Swap out the function signature and body if it is unused. Only
    # replacing the body would leave *args/**kwargs and arbitrary type
    # annotations in the signature, which would still be handed to
    # `build_def` even though the function is never meant to be compiled.
    if should_drop(fn):
        unused_def = ast.parse(
            "def unused_fn(self: Any):\n\traise RuntimeError(\"Cannot call @unused methods\")"
        ).body[0]
        fn_def.body = unused_def.body
        # kwarg/vararg not supported by `build_def`
        fn_def.args.kwarg = fn_def.args.vararg = None
        # Replace potentially unsupported type annotations by "Any"
        any_annotation = unused_def.args.args[0].annotation
        for arg in fn_def.args.args + fn_def.args.kwonlyargs:
            arg.annotation = any_annotation

    return build_def(ctx, fn_def, type_line, def_name, self_name=self_name)