def recursive_script(mod):
    """
    Makes a ScriptModule from an nn.Module.

    ``forward`` (unless it is @ignore'd) and every method marked with
    ``@torch.jit.export`` are used as the starting points for compilation.
    Methods accessed from those are scripted on demand.

    Args:
        mod: the module to script. A ``torch.jit.ScriptModule`` is returned
            unchanged; ``ModuleList``/``Sequential`` containers are handed to
            ``create_constant_iterable_module``.

    Returns:
        The result of ``copy_to_script_module(mod, stubs)``, or one of the
        early returns described above.

    Raises:
        RuntimeError: if ``mod`` does not override ``torch.nn.Module.forward``.
    """
    if isinstance(mod, torch.jit.ScriptModule):
        return mod

    if isinstance(mod, (torch.nn.ModuleList, torch.nn.Sequential)):
        # Create constant versions for the iterable modules
        return create_constant_iterable_module(mod)

    methods = ()
    if hasattr(mod, 'forward'):
        # `forward` is not necessarily a bound method (it may have been
        # assigned on the instance), in which case it has no `__func__`.
        forward_func = getattr(mod.forward, '__func__', None)
        if forward_func == torch.nn.Module.forward:
            raise RuntimeError("No forward method was defined on {}".format(mod))
        if not _jit_internal.is_ignored_fn(mod.forward):
            methods = ('forward',)

    exported = []
    for name in dir(mod):
        # dir() can list attributes whose access raises AttributeError (e.g. a
        # failing property); treat those as non-callable instead of aborting.
        item = getattr(mod, name, None)
        if not callable(item):
            continue
        if _jit_internal.get_torchscript_modifier(item) is _jit_internal.FunctionModifiers.EXPORT:
            exported.append(name)
    methods = methods + tuple(exported)

    def make_stub(method):
        # Build a script-method stub from the function defined on the class.
        func = get_function_from_type(type(mod), method)
        return torch.jit.script_method(func, _jit_internal.createResolutionCallbackFromClosure(func))

    stubs = [make_stub(method) for method in methods]
    return copy_to_script_module(mod, stubs)
def infer_methods_to_compile(nn_module):
    """
    Implements the default rules for which methods should act as starting
    points for compilation (TODO add a link when the rules are published).

    Candidates are ``forward`` (when it is not @ignore'd and differs from
    ``torch.nn.Module.forward``) followed by every attribute marked with
    ``@torch.jit.export``. Names listed by ``jit_ignored_properties`` are
    skipped. Methods that appear in the overload name mapping are not compiled
    directly; only their overload stubs are.

    Side effect: assigns the merged overload name mapping to
    ``nn_module.__overloads__``.

    Returns:
        The overload stubs followed by one stub per remaining method, in
        first-seen order with duplicates removed.
    """
    check_module_initialized(nn_module)
    ignored_properties = jit_ignored_properties(nn_module)

    methods: List[str] = []
    if hasattr(nn_module, 'forward') and not _jit_internal.is_ignored_fn(nn_module.forward):
        forward_func = getattr(nn_module.forward, "__func__", None)
        module_forward = getattr(torch.nn.Module, "forward", None)
        # Only treat `forward` as an entry point if the subclass overrides it.
        if forward_func != module_forward:
            methods = ['forward']

    exported = []
    for name in dir(nn_module):
        if name in ignored_properties:
            continue
        item = getattr(nn_module, name, None)
        if _jit_internal.get_torchscript_modifier(item) is _jit_internal.FunctionModifiers.EXPORT:
            exported.append(name)

    methods = methods + exported

    overload_name_mappings = dict(getattr(nn_module, "__overloads__", {}))
    overload_info = get_overload_annotations(nn_module, ignored_properties)
    overload_name_mappings.update(get_overload_name_mapping(overload_info))
    overload_stubs = make_stubs_for_overloads(overload_info)

    nn_module.__overloads__ = overload_name_mappings

    # We shouldn't directly compile overloaded methods, just their overloads.
    # Deduplicate with dict.fromkeys rather than a set: dicts preserve
    # insertion order, so the compile order stays deterministic.
    uniqued_methods = dict.fromkeys(
        name for name in methods if name not in overload_name_mappings
    )

    stubs = [make_stub_from_method(nn_module, method) for method in uniqued_methods]
    return overload_stubs + stubs
def compile_unbound_method(concrete_type, fn):
    """
    Compile the free function ``fn`` for ``concrete_type``.

    A stub named after ``fn`` is created with ``make_stub`` and passed to
    ``create_methods_and_properties_from_stubs`` (with no property stubs).

    Returns:
        ``None`` when ``fn`` is @ignore'd/@unused, otherwise the created stub.
    """
    if not _jit_internal.is_ignored_fn(fn):
        method_stub = make_stub(fn, fn.__name__)
        # Emit hooks stay disabled: the graph that is calling this function
        # has not been completed yet.
        with torch._jit_internal._disable_emit_hooks():
            create_methods_and_properties_from_stubs(concrete_type, (method_stub,), ())
        return method_stub
    return None
def create_method_from_fn(module, fn):
    """
    Turn the bound method ``fn`` into a script method of ``module``.

    Returns:
        ``None`` if ``fn`` is @ignore'd/@unused or is not a bound method;
        otherwise the ``torch.jit.script_method`` stub that was passed to
        ``torch.jit._create_methods_from_stubs`` for ``module``.
    """
    # `or` short-circuits, so ismethod is only consulted for non-ignored fns.
    if _jit_internal.is_ignored_fn(fn) or not inspect.ismethod(fn):
        return None

    resolution_cb = _jit_internal.createResolutionCallbackFromClosure(fn)
    method_stub = torch.jit.script_method(fn, resolution_cb)
    # Emit hooks stay disabled: the graph that is calling this function is
    # not complete yet.
    with torch.jit._disable_emit_hooks():
        torch.jit._create_methods_from_stubs(module, (method_stub,))
    return method_stub
def init_fn(script_module):
    """
    Copy Python-side members of the original class onto ``script_module``.

    ``concrete_type`` is a free variable here (presumably bound by an
    enclosing function — confirm at the definition site). Every attribute of
    ``concrete_type.py_class`` that ``_jit_internal.is_ignored_fn`` accepts,
    and every entry of ``concrete_type.get_constants()``, is set on
    ``script_module`` so it is available during execution.
    """
    source_class = concrete_type.py_class

    # @ignored/@unused methods: copied so they can still be called at run time.
    for attr_name in dir(source_class):
        candidate = getattr(source_class, attr_name, None)
        if not _jit_internal.is_ignored_fn(candidate):
            continue
        setattr(script_module, attr_name, candidate)

    # Constants: copied so they can be read at run time.
    for const_name, const_value in concrete_type.get_constants().items():
        setattr(script_module, const_name, const_value)
def init_fn(script_module):
    """
    Initialize a new ScriptModule from the original ``nn_module``.

    Free variables (presumably bound by an enclosing function — confirm at the
    definition site): ``concrete_type``, ``nn_module``, ``cpp_module``,
    ``ignored_properties`` and ``stubs_fn``.

    Steps:
      1. copy attributes/parameters/buffers onto ``cpp_module`` (unwrapping
         ``torch.jit.Attribute`` values);
      2. script every submodule and attach it to ``cpp_module`` and to
         ``script_module._modules``;
      3. copy @ignored/@unused methods (re-bound to ``script_module``) and
         attributes that ``concrete_type`` reports as ignored;
      4. store ``concrete_type`` on ``script_module._concrete_type``.
    """

    def _script_submodule(child_concrete_type, child):
        # Decide how one submodule gets compiled.
        child_jit_type = child_concrete_type.jit_type
        if isinstance(child_jit_type, torch._C.InterfaceType):
            # Interface-typed slots use the interface inference rule.
            return interface_script(child_jit_type, child)
        if isinstance(child, torch.jit.ScriptModule):
            # Already scripted: reuse it unchanged.
            return child
        # Otherwise recurse, always reusing the provided stubs_fn to infer
        # which methods to compile.
        return create_script_module_impl(child, child_concrete_type, stubs_fn)

    # Step 1: attributes, parameters and buffers.
    for attr_name, (_attr_type, _is_param) in concrete_type.get_attributes().items():
        value = getattr(nn_module, attr_name)
        if isinstance(value, torch.jit.Attribute):
            value = value.value
        cpp_module.setattr(attr_name, value)

    # Step 2: submodules, scripted recursively.
    for child_name, child_concrete_type in concrete_type.get_modules():
        child = getattr(nn_module, child_name)
        assert isinstance(child, Module), "Expected Module but got {}".format(
            type(child))
        scripted_child = _script_submodule(child_concrete_type, child)
        cpp_module.setattr(child_name, scripted_child)
        script_module._modules[child_name] = scripted_child

    # Step 3: Python-only members, so they stay reachable on the ScriptModule.
    for member_name in dir(nn_module):
        if member_name in ignored_properties:
            continue
        member = getattr(nn_module, member_name, None)
        if inspect.ismethod(member) and _jit_internal.is_ignored_fn(member):
            # Re-bind the class-level function so `self` is the ScriptModule
            # rather than the original nn_module.
            unbound = getattr(type(nn_module), member_name)
            setattr(script_module, member_name, unbound.__get__(script_module))
        elif concrete_type.is_ignored_attribute(member_name):
            setattr(script_module, member_name, member)

    # Step 4: attach the concrete type for convenience.
    script_module._concrete_type = concrete_type
def try_compile_fn(fn, loc):
    """
    Script a Python callable encountered during recursive compilation.

    Args:
        fn: the callable to compile.
        loc: not used by this implementation.

    Returns:
        ``None`` for @ignore'd functions and for ``torch.nn.Module``
        instances; otherwise the result of ``torch.jit.script(fn, _rcb=...)``.

    Raises:
        RuntimeError: if ``fn`` is neither a Python function nor a method.
    """
    # Nothing to do for @ignore'd functions, nor for modules (pybind treats
    # them as functions only because they are callable).
    if _jit_internal.is_ignored_fn(fn) or isinstance(fn, torch.nn.Module):
        return None

    is_python_callable = inspect.isfunction(fn) or inspect.ismethod(fn)
    if not is_python_callable:
        raise RuntimeError("`{}` is not a function. Recursive scripting only supports "
                           "Python functions or methods currently.\n"
                           "Consider manually annotating `{}` with @torch.jit.script.".format(fn, fn))

    # The defining scope is not available here, so resolve names from the
    # variables closed over by the function object instead.
    resolution_cb = _jit_internal.createResolutionCallbackFromClosure(fn)
    return torch.jit.script(fn, _rcb=resolution_cb)
def init_fn(script_module):
    """
    Initialize a new ScriptModule from the original ``nn_module``.

    Free variables (presumably bound by an enclosing function — confirm at the
    definition site): ``concrete_type``, ``nn_module`` and ``cpp_module``.

    Steps:
      1. register parameters (``_register_parameter``) and other attributes
         (``_register_attribute``, unwrapping ``torch.jit.Attribute``) on
         ``cpp_module``;
      2. script every submodule, register its ``_c`` on ``cpp_module`` and
         store it in ``script_module._modules``;
      3. copy @ignored/@unused bound methods onto ``script_module``;
      4. store ``concrete_type`` on ``script_module._concrete_type``.
    """
    # Step 1: attributes, parameters and buffers.
    for attr_name, (attr_type, is_param) in concrete_type.get_attributes().items():
        value = getattr(nn_module, attr_name)
        if not is_param:
            if isinstance(value, torch.jit.Attribute):
                value = value.value
            cpp_module._register_attribute(attr_name, attr_type, value)
        else:
            cpp_module._register_parameter(attr_name, value, False)

    # Step 2: submodules. Interface-typed slots use the interface inference
    # rule; everything else goes through the default recursive rule.
    for child_name, child_type in concrete_type.get_modules():
        child = getattr(nn_module, child_name)
        assert isinstance(child, Module)
        scripted_child = (interface_script(child_type, child)
                          if isinstance(child_type, torch._C.InterfaceType)
                          else recursive_script(child))
        cpp_module._register_module(child_name, scripted_child._c)
        script_module._modules[child_name] = scripted_child

    # Step 3: @ignored/@unused methods, so they stay callable from Python.
    # NOTE(review): these are copied as-is, so they remain bound to the object
    # they were fetched from (normally the original nn_module), not to
    # script_module.
    for member_name in dir(nn_module):
        member = getattr(nn_module, member_name, None)
        if inspect.ismethod(member) and _jit_internal.is_ignored_fn(member):
            setattr(script_module, member_name, member)

    # Step 4: attach the concrete type for convenience.
    script_module._concrete_type = concrete_type
def recursive_script(mod, exclude_methods=()):
    """
    Makes a ScriptModule from an nn.Module.

    ``forward`` (unless it is @ignore'd) and every method marked with
    ``@torch.jit.export`` are used as the starting points for compilation.
    Methods with ``@overload`` declarations are not compiled directly; one
    stub per overload is compiled instead, under the mangled names
    ``<name>__0``, ``<name>__1``, and so on. Methods accessed from the entry
    points are scripted on demand.

    Args:
        mod: the module to script. A ``torch.jit.ScriptModule`` is returned
            unchanged; ``ModuleList``/``Sequential``/``ModuleDict`` containers
            are handed to ``create_constant_iterable_module``.
        exclude_methods: names that must not be used as entry points, even if
            the rules above would otherwise select them.

    Returns:
        The result of ``copy_to_script_module(mod, overload_stubs + stubs)``,
        or one of the early returns described above.

    Side effects:
        Assigns the overload name mapping to ``mod.__overloads__``.

    Raises:
        RuntimeError: if ``mod`` has no ``_parameters`` attribute (it looks
            uninitialized), or if it does not override
            ``torch.nn.Module.forward``.
    """
    if isinstance(mod, torch.jit.ScriptModule):
        return mod

    if isinstance(mod, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)):
        # Create constant versions for the iterable modules
        return create_constant_iterable_module(mod)

    if not hasattr(mod, '_parameters'):
        raise RuntimeError("'{}' has not been initialized, did you forget to call 'super()'?"
                           .format(type(mod).__name__))

    methods = ()
    if hasattr(mod, 'forward'):
        # `forward` is not necessarily a bound method (it may have been
        # assigned on the instance), in which case it has no `__func__`.
        forward_func = getattr(mod.forward, '__func__', None)
        if forward_func == torch.nn.Module.forward:
            raise RuntimeError("No forward method was defined on {}".format(mod))
        if not _jit_internal.is_ignored_fn(mod.forward):
            methods = ('forward',)

    exported = []
    overloads = []
    for name in dir(mod):
        # dir() can list attributes whose access raises AttributeError (e.g. a
        # failing property); treat those as non-callable instead of aborting.
        item = getattr(mod, name, None)
        if not callable(item):
            continue
        if _jit_internal.get_torchscript_modifier(item) is _jit_internal.FunctionModifiers.EXPORT:
            exported.append(name)

        # builtin functions like repr() in python 2 do not have __module__ defined
        if hasattr(item, "__module__") and item.__module__ is not None:
            method_overloads = _jit_internal._get_overloaded_methods(item, mod.__class__)
            if method_overloads is not None:
                overloads.append((item, method_overloads))

    methods = methods + tuple(exported)
    methods = tuple(name for name in methods if name not in exclude_methods)

    overload_name_mappings = dict(getattr(mod, "__overloads__", {}))
    overload_stubs = []

    for orig_fn, overload_fns in overloads:
        orig_ast = torch.jit.get_jit_def(orig_fn, self_name="ScriptModule")
        orig_name = orig_ast.name().name
        # Each overload is compiled under a mangled name: <name>__<index>.
        names = [orig_name + "__" + str(i) for i in range(len(overload_fns))]
        overload_name_mappings[orig_name] = names
        for overload_fn, name in zip(overload_fns, names):
            torch.jit._check_no_signature(overload_fn)
            over_ast = torch.jit.get_jit_def(overload_fn, self_name="ScriptModule")
            # Graft the overload's declaration onto the original body.
            new_ast = torch._C._replace_overloaded_method_decl(over_ast.decl(), orig_ast, name)
            _rcb = _jit_internal.createResolutionCallbackFromClosure(orig_fn)
            overload_stubs.append(torch.jit.ScriptMethodStub(_rcb, new_ast, overload_fn))

    mod.__overloads__ = overload_name_mappings

    def make_stub(method):
        # Build a script-method stub from the function defined on the class.
        func = get_function_from_type(type(mod), method)
        return torch.jit.script_method(func, _jit_internal.createResolutionCallbackFromClosure(func))

    # We shouldn't directly compile overloaded methods, just their overloads.
    stubs = [make_stub(method) for method in methods if method not in overload_name_mappings]
    return copy_to_script_module(mod, overload_stubs + stubs)