def get_jit_def(fn, def_name, self_name=None):
    """Build a JIT AST (TreeView) from the given function.

    Arguments:
        fn: A function object to compile.
        def_name: The name to give to the resulting AST object. This is not
            always the same as `fn.__name__`, for example:

                def _forward(self):
                    ...

                forward = _forward

            In this case, the `__name__` attribute of the function object is
            "_forward", but we want the resulting AST to have the name
            "forward".
        self_name: If this function is a method, what the type name of
            `self` is.

    Raises:
        RuntimeError: if the source of ``fn`` does not parse to exactly one
            top-level function definition.
    """
    lines, first_lineno, path = get_source_lines_and_file(
        fn, torch._C.ErrorReport.call_stack())
    raw_src = ''.join(lines)
    dedented = dedent(raw_src)
    module_ast = ast.parse(dedented)

    top_level = module_ast.body
    if not (len(top_level) == 1 and isinstance(top_level[0], ast.FunctionDef)):
        raise RuntimeError("Expected a single top-level function")

    # Number of characters dedent() stripped from the first line; handed to
    # SourceContext (presumably to relate dedented columns back to the file).
    indent_width = (len(raw_src.split('\n', 1)[0])
                    - len(dedented.split('\n', 1)[0]))
    type_comment = torch.jit.annotations.get_type_line(raw_src)
    ctx = SourceContext(raw_src, path, first_lineno, indent_width,
                        _uses_true_division(fn))
    return build_def(ctx, top_level[0], type_comment, def_name,
                     self_name=self_name)
def get_jit_class_def(cls, self_name):
    """Build the JIT class definition AST for ``cls``.

    Every function or bound method of ``cls`` that is not a static function
    and whose ``__name__`` is a key of ``cls.__dict__`` is compiled on its
    own with ``get_jit_def``. Properties come from ``get_class_properties``.
    The class source is parsed to obtain the class node and the
    ``SourceContext`` handed to ``build_class_def``.
    """
    # TODO: proper overriding analysis when implementing class inheritance
    def _defined_here(member):
        # Keep functions/methods defined on this exact class that are not
        # static functions.
        if not (inspect.ismethod(member) or inspect.isfunction(member)):
            return False
        member_name = member.__name__
        return (not is_static_fn(cls, member_name)
                and member_name in cls.__dict__)

    method_defs = []
    for attr_name, member in inspect.getmembers(cls, predicate=_defined_here):
        # Each method is compiled independently of the others.
        method_defs.append(get_jit_def(member, attr_name, self_name=self_name))
    properties = get_class_properties(cls, self_name)

    class_lines, class_lineno, class_file = get_source_lines_and_file(
        cls, torch._C.ErrorReport.call_stack())
    raw_src = ''.join(class_lines)
    dedented = dedent(raw_src)
    module_ast = ast.parse(dedented)
    # Number of characters dedent() removed from the first line.
    indent_width = (len(raw_src.split('\n', 1)[0])
                    - len(dedented.split('\n', 1)[0]))
    ctx = SourceContext(raw_src, class_file, class_lineno, indent_width, False)
    return build_class_def(ctx, module_ast.body[0], method_defs, properties,
                           self_name)
def get_num_params(fn, loc):
    """Return the number of positional parameters declared by ``fn``.

    Returns ``None`` when the source of ``fn`` cannot be obtained (a
    ``TypeError``/``IOError`` while reading it), or when the signature has
    ``*args`` or keyword-only parameters. For bound methods
    (``inspect.ismethod``) the first parameter is not counted.

    Raises:
        torch.jit.frontend.FrontendError: if the source is a single class
            definition, or is not exactly one top-level function.
    """
    try:
        src = dedent(''.join(get_source_lines_and_file(fn)[0]))
    except (TypeError, IOError):
        return None

    tree = ast.parse(src)
    top = tree.body[0] if len(tree.body) == 1 else None
    if isinstance(top, ast.ClassDef):
        raise torch.jit.frontend.FrontendError(
            loc, "Cannot instantiate class '{}' in a script function".format(
                top.name))
    if not isinstance(top, ast.FunctionDef):
        raise torch.jit.frontend.FrontendError(
            loc, "Expected a single top-level function")

    arguments = top.args
    # Variadic or keyword-only signatures have no fixed positional count.
    if arguments.vararg is not None or getattr(arguments, 'kwonlyargs', None):
        return None
    positional = len(arguments.args)
    return positional - 1 if inspect.ismethod(fn) else positional
def get_signature(fn, rcb, loc, is_method):
    """Return the ``(param_types, return_type)`` signature of ``fn``, or None.

    Python 3 annotations are tried first (only when ``PY35``). If none are
    found, the ``# type:`` comment in ``fn``'s source is parsed with
    ``parse_type_line``.

    Args:
        fn: function whose signature is requested.
        rcb: forwarded unchanged to ``parse_type_line``.
        loc: source location forwarded for error reporting.
        is_method: whether ``fn`` is a method. If so, the first parameter
            type taken from real annotations (``self``) is dropped, because
            type comments never include ``self``. ``inspect.ismethod`` can't
            be used since ``fn`` is still unbound at this point.
    """
    annotated = try_real_annotations(fn, loc) if PY35 else None
    if annotated is not None:
        if is_method:
            param_types, return_type = annotated
            annotated = (param_types[1:], return_type)
        return annotated

    type_line = None
    try:
        type_line = get_type_line(
            dedent(''.join(get_source_lines_and_file(fn)[0])))
    except TypeError:
        pass
    # type_line is None either because the source of fn could not be
    # fetched or because it carries no type comment.
    if type_line is None:
        return None
    return parse_type_line(type_line, rcb, loc)
def persistent_id(obj):
    """Return a persistent-ID tuple for ``obj``, or ``None``.

    NOTE(review): presumably used as a pickler ``persistent_id`` hook by the
    enclosing save routine; ``serialized_container_types`` and
    ``serialized_storages`` come from the surrounding scope.

    * ``nn.Module`` subclasses: the first time a class is seen, it is
      recorded and ``('module', cls, source_file, source)`` is returned.
      Later sightings return ``None``. The source is optional; if it cannot
      be retrieved a warning is issued and ``None`` is stored instead.
    * storages: the storage is recorded under ``str(obj._cdata)`` and
      ``('storage', storage_type, key, location, size, view_metadata)`` is
      returned.
    * anything else: ``None``.
    """
    if isinstance(obj, type) and issubclass(obj, nn.Module):
        if obj in serialized_container_types:
            return None
        serialized_container_types[obj] = True
        src_path = None
        src_text = None
        try:
            src_lines, _, src_path = get_source_lines_and_file(obj)
            src_text = ''.join(src_lines)
        except Exception:
            # saving the source is optional, so we can ignore any errors
            warnings.warn("Couldn't retrieve source code for container of "
                          "type " + obj.__name__ + ". It won't be checked "
                          "for correctness upon loading.")
        return ('module', obj, src_path, src_text)

    if not torch.is_storage(obj):
        return None

    storage_type = normalize_storage_type(type(obj))
    # Offset is always 0, but we keep it for backwards compatibility
    # with the old serialization format (which supported storage views)
    offset = 0
    storage_key = str(obj._cdata)
    location = location_tag(obj)
    serialized_storages[storage_key] = obj
    # _cdata is compared with itself, so for an ordinary handle this is never
    # a view; the branch only mirrors the old view-aware format.
    view_metadata = (
        (str(obj._cdata), offset, obj.size())
        if obj._cdata != obj._cdata
        else None
    )
    return ('storage', storage_type, storage_key, location, obj.size(),
            view_metadata)
def _get_overloaded_methods(method, mod_class):
    """Return the overloads registered for ``method`` on ``mod_class``.

    Looks up ``_overloaded_methods[_qualified_name(method)][mod_class.__name__]``.
    Returns ``None`` if ``method`` has no ``__name__`` or nothing is
    registered for it.

    Raises:
        Exception: if overloads are registered but the first source line of
            ``method`` lies outside the source lines of ``mod_class``. That
            happens when a module class is redeclared within the same file.
    """
    # TODO: __name__ not set for submodules in recursive script
    if not hasattr(method, "__name__"):
        return None
    qual_name = _qualified_name(method)
    class_name_map = _overloaded_methods.get(qual_name, None)
    if class_name_map is None:
        return None
    overloads = class_name_map.get(mod_class.__name__, None)
    if overloads is None:
        return None

    method_line_no = get_source_lines_and_file(method)[1]
    # Retrieve the class source once (it used to be fetched twice).
    class_lines, class_start, _ = get_source_lines_and_file(mod_class)
    # The class spans lines [class_start, class_start + len(class_lines) - 1];
    # the previous inclusive bound admitted the line right after the class.
    class_end = class_start + len(class_lines) - 1
    if not class_start <= method_line_no <= class_end:
        raise Exception("Overloads are not useable when a module is redeclared within the same file: " + str(method))
    return overloads
def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
    """
    Build a JIT AST (TreeView) from the given function.

    Args:
        fn: A function object to compile
        def_name: The name to give to the resulting AST object. This is not
            always the same as `fn.__name__`, for example:
                def _forward(self):
                    ...
                forward = _forward
            In this case, the `__name__` attribute of the function object is
            "_forward", but we want the result AST to have the name "forward".
        self_name: If this function is a method, what the type name of
            `self` is.
        is_classmethod: If True, the statement
            ``<first positional arg> = <self_name>`` is inserted at the top
            of the function body, so the class argument refers to
            ``self_name``.

    If ``should_drop(fn)`` is true, the body is replaced by one that raises
    ``RuntimeError("Cannot call @unused methods")``. ``*args``/``**kwargs``
    are removed and every positional/keyword-only annotation becomes ``Any``.

    Raises:
        RuntimeError: if the source does not parse to exactly one top-level
            function definition.
    """
    raw_lines, first_lineno, path = get_source_lines_and_file(
        fn, torch._C.ErrorReport.call_stack())
    raw_lines = normalize_source_lines(raw_lines)
    src = ''.join(raw_lines)
    dedented = dedent(src)
    module_ast = ast.parse(dedented)
    if not (len(module_ast.body) == 1
            and isinstance(module_ast.body[0], ast.FunctionDef)):
        raise RuntimeError(
            f"Expected a single top-level function: {path}:{first_lineno}")

    indent_width = (len(src.split('\n', 1)[0])
                    - len(dedented.split('\n', 1)[0]))
    type_line = torch.jit.annotations.get_type_line(src)
    ctx = SourceContext(src, path, first_lineno, indent_width, True)
    func_ast = module_ast.body[0]

    if is_classmethod:
        # Bind the first argument (e.g. `cls`) to the class itself.
        cls_arg_name = func_ast.args.args[0].arg
        func_ast.body.insert(
            0, ast.parse(f"{cls_arg_name} = {self_name}").body[0])

    # Swap out the function signature and body if it is unused
    if should_drop(fn):
        stub_module = ast.parse(
            "def unused_fn(self: Any):\n\traise RuntimeError(\"Cannot call @unused methods\")"
        )
        if not (len(stub_module.body) == 1
                and isinstance(stub_module.body[0], ast.FunctionDef)):
            raise RuntimeError(
                f"Expected a single top-level function: {path}:{first_lineno}"
            )
        stub = stub_module.body[0]
        func_ast.body = stub.body
        # kwarg/vararg not supported by `build_def`
        func_ast.args.kwarg = None
        func_ast.args.vararg = None
        # Replace potentially unsupported type annotations by "Any"
        any_annotation = stub.args.args[0].annotation
        for param in func_ast.args.args + func_ast.args.kwonlyargs:
            param.annotation = any_annotation
    return build_def(ctx, func_ast, type_line, def_name, self_name=self_name)
def get_jit_def(fn, self_name=None):
    """Parse the source of ``fn`` and build its JIT definition AST.

    Raises:
        RuntimeError: if the source is not exactly one top-level function.
    """
    lines, start_line, path = get_source_lines_and_file(fn)
    src = ''.join(lines)
    dedented = dedent(src)
    tree = ast.parse(dedented)
    if not (len(tree.body) == 1 and isinstance(tree.body[0], ast.FunctionDef)):
        raise RuntimeError("Expected a single top-level function")

    first_raw = src.split('\n', 1)[0]
    first_dedented = dedented.split('\n', 1)[0]
    # Characters stripped from the first line by dedent().
    indent_width = len(first_raw) - len(first_dedented)
    type_line = torch.jit.annotations.get_type_line(src)
    ctx = SourceContext(src, path, start_line, indent_width,
                        _uses_true_division(fn))
    return build_def(ctx, tree.body[0], type_line, self_name)
def get_jit_class_def(cls, self_name):
    """Build the JIT class definition AST for ``cls``.

    Every member that ``inspect.getmembers`` reports as a method or function
    (inherited ones included) is compiled independently with
    ``get_jit_def``. The class source supplies the class node and the
    ``SourceContext`` passed to ``build_class_def``.
    """
    is_callable_member = (
        lambda m: inspect.ismethod(m) or inspect.isfunction(m))
    method_defs = [
        get_jit_def(member, self_name=self_name)
        for _, member in inspect.getmembers(cls, predicate=is_callable_member)
    ]

    lines, start_line, path = get_source_lines_and_file(cls)
    src = ''.join(lines)
    dedented = dedent(src)
    tree = ast.parse(dedented)
    # Characters stripped from the first line by dedent().
    indent_width = (len(src.split('\n', 1)[0])
                    - len(dedented.split('\n', 1)[0]))
    ctx = SourceContext(src, path, start_line, indent_width, False)
    return build_class_def(ctx, tree.body[0], method_defs, self_name)
def check_fn(fn, loc):
    """Make sure ``fn``'s source is a single top-level function definition.

    Does nothing if the source cannot be read (``TypeError``/``IOError``).

    Raises:
        torch.jit.frontend.FrontendError: if the source is a class definition
            (i.e. ``fn`` is a class being instantiated), or is anything other
            than exactly one top-level function.
    """
    try:
        src = dedent(''.join(get_source_lines_and_file(fn)[0]))
    except (TypeError, IOError):
        return

    tree = ast.parse(src)
    top = tree.body[0] if len(tree.body) == 1 else None
    if isinstance(top, ast.ClassDef):
        raise torch.jit.frontend.FrontendError(
            loc, f"Cannot instantiate class '{top.name}' in a script function")
    if not isinstance(top, ast.FunctionDef):
        raise torch.jit.frontend.FrontendError(loc, "Expected a single top-level function")
def _check_container_source(container_type, source_file, original_source):
    """Warn when the current source of ``container_type`` differs from a saved one.

    Args:
        container_type: class whose stored source is being checked.
        source_file: file name used in the headers of the generated diff.
        original_source: source text recorded when the object was saved.

    On mismatch a ``SourceChangeWarning`` is issued. If
    ``container_type.dump_patches`` is set, a reverse unified diff (current
    -> original) is written to ``<ClassName>.patch`` in the working
    directory when that file is empty. An existing non-empty file is
    accepted only if it already contains exactly this patch. If the current
    source cannot be retrieved, a plain warning is issued and nothing is
    compared.
    """
    try:
        current_source = "".join(
            get_source_lines_and_file(container_type)[0])
    except Exception:  # saving the source is optional, so we can ignore any errors
        warnings.warn("Couldn't retrieve source code for container of "
                      "type " + container_type.__name__ + ". It won't be checked "
                      "for correctness upon loading.")
        return
    if original_source == current_source:
        return

    if container_type.dump_patches:
        file_name = container_type.__name__ + ".patch"
        diff = difflib.unified_diff(
            current_source.split("\n"),
            original_source.split("\n"),
            source_file,
            source_file,
            lineterm="",
        )
        lines = "\n".join(diff)
        try:
            with open(file_name, "a+") as f:
                file_size = f.seek(0, 2)
                f.seek(0)
                if file_size == 0:
                    f.write(lines)
                # Compare decoded text only: in text mode the seek position
                # is an opaque cookie (bytes, after newline translation), so
                # comparing it with len(lines) (characters) misreported
                # identical non-ASCII or CRLF patches as foreign files.
                elif f.read() != lines:
                    raise IOError
            msg = ("Saved a reverse patch to " + file_name + ". "
                   "Run `patch -p0 < " + file_name + "` to revert your "
                   "changes.")
        except IOError:
            msg = ("Tried to save a patch, but couldn't create a "
                   "writable file " + file_name + ". Make sure it "
                   "doesn't exist and your working directory is "
                   "writable.")
    else:
        msg = ("you can retrieve the original source code by "
               "accessing the object's source attribute or set "
               "`torch.nn.Module.dump_patches = True` and use the "
               "patch tool to revert the changes.")
    msg = "source code of class '{container_type}' has changed. {msg}".format(
        container_type=torch.typename(container_type), msg=msg)
    warnings.warn(msg, SourceChangeWarning)
def get_signature(fn, rcb, loc):
    """Return the signature of ``fn`` from annotations or a type comment.

    Python 3 annotations (via ``try_real_annotations``) are tried first when
    ``PY35``. Otherwise the ``# type:`` comment in ``fn``'s source is parsed
    with ``parse_type_line``. Returns ``None`` if neither is available.
    """
    if PY35:
        annotated = try_real_annotations(fn)
        if annotated is not None:
            return annotated

    try:
        type_line = get_type_line(
            dedent(''.join(get_source_lines_and_file(fn)[0])))
    except TypeError:
        type_line = None
    # No type line: either the source could not be fetched or fn has no
    # type comment.
    return None if type_line is None else parse_type_line(type_line, rcb, loc)
def persistent_id(obj: Any) -> Optional[Tuple]:
    """Return a persistent-ID tuple for ``obj``, or ``None``.

    * ``nn.Module`` subclasses: recorded in ``serialized_container_types``
      on first sight and returned as ``('module', cls, source_file, source)``.
      Repeat sightings return ``None``. Source retrieval is optional and only
      warns on failure.
    * storages: recorded in ``serialized_storages`` under ``str(obj._cdata)``
      and returned as
      ``('storage', storage_type, key, location, size, view_metadata)``.
    * anything else: ``None``.
    """
    # FIXME: the docs say that persistent_id should only return a string
    # but torch store returns tuples. This works only in the binary protocol
    # see
    # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
    # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
    if isinstance(obj, type) and issubclass(obj, nn.Module):
        if obj in serialized_container_types:
            return None
        serialized_container_types[obj] = True
        src_path = None
        src_text = None
        try:
            src_lines, _, src_path = get_source_lines_and_file(obj)
            src_text = ''.join(src_lines)
        except Exception:
            # saving the source is optional, so we can ignore any errors
            warnings.warn("Couldn't retrieve source code for container of "
                          "type " + obj.__name__ + ". It won't be checked "
                          "for correctness upon loading.")
        return ('module', obj, src_path, src_text)

    if not torch.is_storage(obj):
        return None

    view_metadata: Optional[Tuple[str, int, int]]
    obj = cast(Storage, obj)
    storage_type = normalize_storage_type(type(obj))
    # Offset is always 0, but we keep it for backwards compatibility
    # with the old serialization format (which supported storage views)
    offset = 0
    storage_key = str(obj._cdata)
    location = location_tag(obj)
    serialized_storages[storage_key] = obj
    # _cdata is compared with itself, so for an ordinary handle this is never
    # a view; the branch only mirrors the old view-aware format.
    view_metadata = (
        (str(obj._cdata), offset, obj.size())
        if obj._cdata != obj._cdata
        else None
    )
    return ('storage', storage_type, storage_key, location, obj.size(),
            view_metadata)
def get_jit_class_def(cls, self_name):
    """Build the JIT class definition AST for ``cls``.

    Every function or bound method of ``cls`` that is not a static function
    and whose ``__name__`` is a key of ``cls.__dict__`` is compiled
    independently with ``get_jit_def``. Methods bound to ``cls`` itself
    (classmethods) are passed with ``is_classmethod=True``. Properties
    (``get_class_properties``) and class-level assignments
    (``get_class_assigns``) are collected too, and everything is combined by
    ``build_class_def``.
    """
    # Get defs for each method within the current class independently
    # TODO: proper overriding analysis when implementing class inheritance
    methods = inspect.getmembers(
        cls,
        predicate=lambda m: (inspect.ismethod(m) or inspect.isfunction(m))
        and not is_static_fn(cls, m.__name__)
        and m.__name__ in cls.__dict__
    )

    def is_classmethod(fn):
        # A classmethod fetched from its class is a method bound to the class
        # object. Compare by identity: `==` would invoke a metaclass __eq__.
        return inspect.ismethod(fn) and getattr(fn, "__self__", None) is cls

    methods = [
        get_jit_def(obj, name, self_name=self_name,
                    is_classmethod=is_classmethod(obj))
        for (name, obj) in methods
    ]
    properties = get_class_properties(cls, self_name)

    sourcelines, file_lineno, filename = get_source_lines_and_file(
        cls, torch._C.ErrorReport.call_stack())
    source = ''.join(sourcelines)
    dedent_src = dedent(source)
    py_ast = ast.parse(dedent_src)
    # Number of characters dedent() removed from the first line.
    leading_whitespace_len = len(source.split('\n', 1)[0]) - len(
        dedent_src.split('\n', 1)[0])
    ctx = make_source_context(source, filename, file_lineno,
                              leading_whitespace_len, False)
    class_ast = py_ast.body[0]
    assert isinstance(class_ast, ast.ClassDef)
    assigns = get_class_assigns(ctx, class_ast)

    return build_class_def(ctx, class_ast, methods, properties, self_name,
                           assigns)