def recursive_script(mod):
    """
    Build a ScriptModule from an nn.Module.

    A module that is already a ``torch.jit.ScriptModule`` is returned
    unchanged, and ``ModuleList``/``Sequential`` containers are handed to
    ``create_constant_iterable_module``. For any other module the methods
    to compile are ``forward`` (unless it is marked as ignored) followed by
    every callable attribute that carries the EXPORT modifier; a script
    method stub is created for each and passed to ``copy_to_script_module``.

    Raises:
        RuntimeError: if ``mod.forward`` is still ``torch.nn.Module.forward``.
    """
    if isinstance(mod, torch.jit.ScriptModule):
        return mod

    # Create constant versions for the iterable modules
    if isinstance(mod, (torch.nn.ModuleList, torch.nn.Sequential)):
        return create_constant_iterable_module(mod)

    entry_points = []
    if hasattr(mod, 'forward'):
        if mod.forward.__func__ == torch.nn.Module.forward:
            raise RuntimeError("No forward method was defined on {}".format(mod))
        if not _jit_internal.is_ignored_fn(mod.forward):
            entry_points.append('forward')

    # Exported methods are compiled after `forward`, in `dir()` order.
    for attr_name in dir(mod):
        attr = getattr(mod, attr_name)
        if not callable(attr):
            continue
        if _jit_internal.get_torchscript_modifier(attr) is _jit_internal.FunctionModifiers.EXPORT:
            entry_points.append(attr_name)

    def make_stub(method):
        # Look the function up on the class and pair it with a resolution
        # callback built from that function's closure.
        func = get_function_from_type(type(mod), method)
        rcb = _jit_internal.createResolutionCallbackFromClosure(func)
        return torch.jit.script_method(func, rcb)

    stubs = [make_stub(method_name) for method_name in entry_points]
    return copy_to_script_module(mod, stubs)
def infer_methods_to_compile(nn_module):
    """
    Implements the default rules for which methods should act as starting
    points for compilation (TODO add a link when the rules are published).

    The rules, as implemented here:

    * ``forward`` is a starting point when the module has one, it is not
      marked as ignored, and its underlying function is not
      ``torch.nn.Module.forward``.
    * Every attribute (other than the ignored properties) whose TorchScript
      modifier is EXPORT is a starting point.
    * Names that have overloads are not compiled directly; stubs for their
      overloads are used instead.

    Side effect: ``nn_module.__overloads__`` is replaced with the merged
    overload-name mapping.

    Args:
        nn_module: the module to inspect.

    Returns:
        The overload stubs followed by one stub per remaining method name,
        each name appearing once and in first-seen order.
    """
    check_module_initialized(nn_module)
    ignored_properties = jit_ignored_properties(nn_module)

    methods: List[str] = []
    if hasattr(nn_module, 'forward') and not _jit_internal.is_ignored_fn(nn_module.forward):
        # `forward` may not be a bound method, hence the defensive getattr.
        forward_func = getattr(nn_module.forward, "__func__", None)
        module_forward = getattr(torch.nn.Module, "forward", None)
        if forward_func != module_forward:
            methods = ['forward']

    export_modifier = _jit_internal.FunctionModifiers.EXPORT
    for name in dir(nn_module):
        if name in ignored_properties:
            continue
        item = getattr(nn_module, name, None)
        if _jit_internal.get_torchscript_modifier(item) is export_modifier:
            methods.append(name)

    overload_name_mappings = dict(getattr(nn_module, "__overloads__", {}))
    overload_info = get_overload_annotations(nn_module, ignored_properties)
    overload_name_mappings.update(get_overload_name_mapping(overload_info))
    overload_stubs = make_stubs_for_overloads(overload_info)

    nn_module.__overloads__ = overload_name_mappings

    # Unique the methods. We don't want to use a set to store the methods
    # because it introduces non-determinism to compile order, so the set is
    # only used for membership checks while the list keeps the order.
    seen: Set[str] = set()
    unique_methods: List[str] = []
    for name in methods:
        # we shouldn't directly compile overloaded methods, just its overloads
        if name in overload_name_mappings or name in seen:
            continue
        seen.add(name)
        unique_methods.append(name)

    stubs = [make_stub_from_method(nn_module, name) for name in unique_methods]
    return overload_stubs + stubs
def make_stubs_from_exported_methods(mod):
    """
    Collect ``make_stub_from_method(mod, name)`` for every name in
    ``dir(mod)`` whose attribute carries the EXPORT TorchScript modifier.

    Attributes that cannot be fetched are treated as ``None``. The result
    follows ``dir(mod)`` order.
    """
    return [
        make_stub_from_method(mod, attr_name)
        for attr_name in dir(mod)
        if _jit_internal.get_torchscript_modifier(getattr(mod, attr_name, None))
        is _jit_internal.FunctionModifiers.EXPORT
    ]
def create_script_module_impl(nn_module, concrete_type, stubs_fn):
    """
    Convert an nn.Module to a RecursiveScriptModule.

    The work happens in this order: create the C++ module for
    `concrete_type`, gather method/property/hook stubs from `nn_module`,
    construct the Python wrapper (copying attributes, submodules and ignored
    members inside `init_fn`), compile the stubs if this concrete type has
    not been compiled yet, and finally expose the compiled methods and
    properties on the wrapper.

    Args:
        nn_module:  The original Python nn.Module that we are creating a ScriptModule for.
        concrete_type:  The fully initialized ConcreteType of the module.
        stubs_fn:  Lambda that takes an nn.Module and generates a list of ScriptMethodStubs to compile.

    Returns:
        The object produced by `torch.jit.RecursiveScriptModule._construct`.
    """
    cpp_module = torch._C._create_module_with_type(concrete_type.jit_type)
    method_stubs = stubs_fn(nn_module)
    property_stubs = get_property_stubs(nn_module)
    hook_stubs, pre_hook_stubs = get_hook_stubs(nn_module)

    # NOTE(review): this local is assigned but never read in this function.
    user_annotated_ignored_attributes = getattr(nn_module, "__jit_ignored_attributes__", list())
    ignored_properties = jit_ignored_properties(nn_module)

    def init_fn(script_module):
        # Initialize the ScriptModule:
        # 1. Copy the attributes/parameters/buffers from the original `nn_module` to the new ScriptModule.
        for name, (attr_type, is_param) in concrete_type.get_attributes().items():
            orig_value = getattr(nn_module, name)
            # Unwrap `torch.jit.Attribute` so the payload itself is stored.
            orig_value = orig_value.value if isinstance(orig_value, torch.jit.Attribute) else orig_value
            cpp_module.setattr(name, orig_value)

        # 2. Copy the submodules from the original `nn_module` to the new ScriptModule,
        #    recursively scripting them.
        for name, sub_concrete_type in concrete_type.get_modules():
            orig_value = getattr(nn_module, name)
            assert isinstance(orig_value, Module), "Expected Module but got {}".format(type(orig_value))
            module_type = sub_concrete_type.jit_type
            if isinstance(module_type, torch._C.InterfaceType):
                # use the interface inference rule to compile the module
                scripted = interface_script(module_type, orig_value)
            elif isinstance(orig_value, torch.jit.ScriptModule):
                # already scripted: reuse it as-is
                scripted = orig_value
            else:
                # always reuse the provided stubs_fn to infer the methods to compile
                scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)

            # Register the submodule on both the C++ module and the Python wrapper.
            cpp_module.setattr(name, scripted)
            script_module._modules[name] = scripted

        # 3. Copy @ignored/@unused methods and attrs from the original `nn_module` to the new ScriptModule.
        #    This ensures we can access these Python methods on the ScriptModule.
        for name in dir(nn_module):
            if name in ignored_properties:
                continue
            item = getattr(nn_module, name, None)
            if inspect.ismethod(item) and _jit_internal.is_ignored_fn(item):
                # Fetch the function from the class and re-bind it so that
                # `self` is the new script module rather than `nn_module`.
                unbound_function = getattr(type(nn_module), name)
                bound_method = unbound_function.__get__(script_module)
                setattr(script_module, name, bound_method)
            elif concrete_type.is_ignored_attribute(name):
                setattr(script_module, name, item)

        # For convenience, attach the concrete type to the new ScriptModule
        script_module._concrete_type = concrete_type

    # Actually create the ScriptModule, initializing it with the function we just defined
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)

    # Compile methods if necessary (only once per concrete type; the store
    # records which concrete types have already been compiled).
    if concrete_type not in concrete_type_store.methods_compiled:
        create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
        # Create hooks after methods to ensure no name collisions between hooks and methods.
        # If done before, hooks can overshadow methods that aren't exported.
        create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs)
        torch._C._run_emit_module_hook(cpp_module)
        concrete_type_store.methods_compiled.add(concrete_type)

    # Copy the forward hooks and pre-hooks to the new ScriptModule
    # to allow the hooks to be run from eager as ScriptFunctions
    for idx, fn in enumerate(script_module._c._get_forward_pre_hooks()):
        script_module._forward_pre_hooks[idx] = fn
    for idx, fn in enumerate(script_module._c._get_forward_hooks()):
        script_module._forward_hooks[idx] = fn

    # Special handling so methods like __len__ work in script methods on classes derived from containers.
    # The length (and, for ModuleDict, the key list) is baked into the
    # defined source as a literal taken from `nn_module` at this point.
    if isinstance(nn_module, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)) and \
            '__len__' not in cpp_module._method_names():
        script_module.define("def __len__(self):\n return {}\n".format(len(nn_module)))
    if isinstance(nn_module, torch.nn.ModuleDict) and \
            '__contains__' not in cpp_module._method_names():
        if len(nn_module.keys()):
            keys = repr(list(nn_module.keys()))
            script_module.define("def __contains__(self, key: str):\n return key in {}\n".format(keys))
        else:
            # An empty dict gets a constant-False membership test.
            script_module.define("def __contains__(self, key: str):\n return False\n")

    # Make the compiled methods available to the Python ScriptModule class.
    for method_stub in method_stubs:
        if method_stub.original_method is None:
            # define()'d methods don't have an Python original_method, so we
            # don't need to do any Python re-wrapping stuff
            continue

        name = method_stub.original_method.__name__
        if name != method_stub.def_.name().name:
            # TODO: Why skip this? Because @torch.jit._overload_method will
            # mangle the name of the function.
            continue
        script_method = cpp_module._get_method(name)

        # Wrap the original to propagate docstrings and such.
        # TODO: we don't currently do this functions that are recursively
        # compiled, we should.
        wrapped_script_method = functools.wraps(method_stub.original_method)(script_method)  # type: ignore

        # Add the methods to the script_module directly. This ensures they will
        # be found first when `name` is looked up (as opposed to the stubs or
        # nn.Module.forward)
        script_module.__dict__[name] = wrapped_script_method

    # Make module properties available on the Python ScriptModule class.
    for property_stub in property_stubs:
        property_name = property_stub.def_.name().name
        fget = cpp_module._get_method(property_stub.def_.getter_name().name)
        # Setter is optional, so it may not exist.
        setter_name = property_stub.def_.setter_name()
        fset = cpp_module._get_method(setter_name.name) if setter_name else None
        # NOTE(review): the built-in `property` takes (fget, fset, fdel, doc)
        # positionally, so this call puts `property_name` in the fget slot and
        # shifts `fget`/`fset` one position right — confirm this is intended.
        script_module.__dict__[property_name] = property(property_name, fget, fset)  # type: ignore

    # copy over python methods to script module if they aren't defined on the script module
    # this is currently an internal api used only on module containers
    for name in dir(nn_module):
        if name in ignored_properties:
            continue
        item = getattr(nn_module, name, None)
        if _jit_internal.get_torchscript_modifier(item) is _jit_internal.FunctionModifiers.COPY_TO_SCRIPT_WRAPPER:
            add_python_attr_to_scripted_model(script_module, nn_module, name)

    return script_module
def create_script_module_impl(nn_module, concrete_type, stubs_fn):
    """
    Convert an nn.Module to a RecursiveScriptModule.

    Arguments:
        nn_module:  The original Python nn.Module that we are creating a ScriptModule for.
        concrete_type:  The fully initialized ConcreteType of the module.
        stubs_fn:  Lambda that takes an nn.Module and generates a list of ScriptMethodStubs to compile.

    Returns:
        The object produced by `torch.jit.RecursiveScriptModule._construct`,
        with the compiled methods placed in its `__dict__`.
    """
    cpp_module = torch._C._create_module_with_type(concrete_type.jit_type)
    method_stubs = stubs_fn(nn_module)

    def init_fn(script_module):
        # Step 1: mirror attributes/parameters/buffers of `nn_module` onto
        # the C++ module, unwrapping `torch.jit.Attribute` holders.
        for attr_name, (_attr_type, _is_param) in concrete_type.get_attributes().items():
            value = getattr(nn_module, attr_name)
            if isinstance(value, torch.jit.Attribute):
                value = value.value
            cpp_module.setattr(attr_name, value)

        # Step 2: bring over the submodules, scripting them recursively
        # where they are not scripted yet.
        for child_name, child_concrete_type in concrete_type.get_modules():
            child = getattr(nn_module, child_name)
            assert isinstance(child, Module), "Expected Module but got {}".format(type(child))
            child_jit_type = child_concrete_type.jit_type
            if isinstance(child_jit_type, torch._C.InterfaceType):
                # interface-typed submodules use the interface inference rule
                scripted_child = interface_script(child_jit_type, child)
            elif isinstance(child, torch.jit.ScriptModule):
                scripted_child = child
            else:
                # everything else goes through the default recursive rule
                scripted_child = create_script_module_impl(
                    child, child_concrete_type, infer_methods_to_compile)
            cpp_module.setattr(child_name, scripted_child)
            script_module._modules[child_name] = scripted_child

        # Step 3: re-bind @ignored/@unused Python methods to the new
        # ScriptModule so they remain callable on it.
        for member_name in dir(nn_module):
            member = getattr(nn_module, member_name, None)
            if inspect.ismethod(member) and _jit_internal.is_ignored_fn(member):
                unbound_function = getattr(type(nn_module), member_name)
                setattr(script_module, member_name, unbound_function.__get__(script_module))

        # Keep the concrete type reachable from the ScriptModule.
        script_module._concrete_type = concrete_type

    # Build the ScriptModule; `init_fn` above populates it.
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)

    # Each concrete type has its methods compiled at most once.
    if concrete_type not in concrete_type_store.methods_compiled:
        create_methods_from_stubs(concrete_type, method_stubs)
        torch._C._run_emit_module_hook(cpp_module)
        concrete_type_store.methods_compiled.add(concrete_type)

    # Expose the compiled methods on the Python ScriptModule.
    for stub in method_stubs:
        python_method = stub.original_method
        if python_method is None:
            # define()'d methods have no Python original to re-wrap.
            continue

        name = python_method.__name__
        if name != stub.def_.name().name:
            # TODO: Why skip this? Because @torch.jit._overload_method will
            # mangle the name of the function.
            continue

        compiled_method = cpp_module._get_method(name)
        # functools.wraps propagates the docstring and friends.
        # TODO: we don't currently do this functions that are recursively
        # compiled, we should.
        # Storing in the instance __dict__ makes sure `name` resolves to the
        # compiled method first (ahead of the stubs or nn.Module.forward).
        script_module.__dict__[name] = functools.wraps(python_method)(compiled_method)

    # Python methods flagged COPY_TO_SCRIPT_WRAPPER are copied over as well;
    # this is currently an internal api used only on module containers.
    for member_name in dir(nn_module):
        member = getattr(nn_module, member_name, None)
        modifier = _jit_internal.get_torchscript_modifier(member)
        if modifier is _jit_internal.FunctionModifiers.COPY_TO_SCRIPT_WRAPPER:
            add_python_attr_to_scripted_model(script_module, nn_module, member_name)

    return script_module
def recursive_script(mod, exclude_methods=()):
    """
    Build a ScriptModule from an nn.Module.

    A module that is already a ``torch.jit.ScriptModule`` is returned
    unchanged, and ``ModuleList``/``Sequential``/``ModuleDict`` containers
    are handed to ``create_constant_iterable_module``. Otherwise the
    methods to compile are ``forward`` (unless ignored) plus every callable
    attribute carrying the EXPORT modifier, minus the names listed in
    `exclude_methods`. Methods that have overloads are not compiled
    directly; one stub per overload is generated instead, and
    ``mod.__overloads__`` is updated with the generated names.

    Raises:
        RuntimeError: if `mod` has no ``_parameters`` attribute, or if its
            ``forward`` is still ``torch.nn.Module.forward``.
    """
    if isinstance(mod, torch.jit.ScriptModule):
        return mod

    # Create constant versions for the iterable modules
    if isinstance(mod, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)):
        return create_constant_iterable_module(mod)

    if not hasattr(mod, '_parameters'):
        raise RuntimeError("'{}' has not been initialized, did you forget to call 'super()'?"
                           .format(type(mod).__name__))

    entry_points = []
    if hasattr(mod, 'forward'):
        if mod.forward.__func__ == torch.nn.Module.forward:
            raise RuntimeError("No forward method was defined on {}".format(mod))
        if not _jit_internal.is_ignored_fn(mod.forward):
            entry_points.append('forward')

    overloads = []
    for attr_name in dir(mod):
        attr = getattr(mod, attr_name)
        if not callable(attr):
            continue

        if _jit_internal.get_torchscript_modifier(attr) is _jit_internal.FunctionModifiers.EXPORT:
            entry_points.append(attr_name)

        # builtin functions like repr() in python 2 do not have __module__ defined
        if not hasattr(attr, "__module__") or attr.__module__ is None:
            continue
        method_overloads = _jit_internal._get_overloaded_methods(attr, mod.__class__)
        if method_overloads is not None:
            overloads.append((attr, method_overloads))

    entry_points = [name for name in entry_points if name not in exclude_methods]

    overload_name_mappings = dict(getattr(mod, "__overloads__", {}))
    overload_stubs = []

    for orig_fn, overload_fns in overloads:
        orig_ast = torch.jit.get_jit_def(orig_fn, self_name="ScriptModule")
        base_name = orig_ast.name().name
        # Overload i of method `foo` is compiled under the name `foo__i`.
        mangled_names = [base_name + "__" + str(i) for i in range(len(overload_fns))]
        overload_name_mappings[base_name] = mangled_names
        for overload_fn, mangled in zip(overload_fns, mangled_names):
            torch.jit._check_no_signature(overload_fn)
            over_ast = torch.jit.get_jit_def(overload_fn, self_name="ScriptModule")
            new_ast = torch._C._replace_overloaded_method_decl(over_ast.decl(), orig_ast, mangled)
            # Names are resolved through the closure of the original function.
            rcb = _jit_internal.createResolutionCallbackFromClosure(orig_fn)
            overload_stubs.append(torch.jit.ScriptMethodStub(rcb, new_ast, overload_fn))

    mod.__overloads__ = overload_name_mappings

    def make_stub(method):
        func = get_function_from_type(type(mod), method)
        rcb = _jit_internal.createResolutionCallbackFromClosure(func)
        return torch.jit.script_method(func, rcb)

    # we shouldn't directly compile overloaded methods, just its overloads
    stubs = [
        make_stub(method_name)
        for method_name in entry_points
        if method_name not in overload_name_mappings
    ]
    return copy_to_script_module(mod, overload_stubs + stubs)
def _get_modifier_wrapper(fn):
    """
    Return the TorchScript modifier of `fn`, or ``None`` when `fn` is a
    class derived from ``Injector``.

    The class check comes first so that ``issubclass`` is only evaluated
    for actual classes.
    """
    is_injector_class = inspect.isclass(fn) and issubclass(fn, Injector)
    return None if is_injector_class else get_torchscript_modifier(fn)