# Example 1
def recursive_script(mod):
    """
    Makes a ScriptModule from an nn.Module.

    `forward` (unless it is marked as ignored) plus every callable attribute
    whose TorchScript modifier is `EXPORT` are turned into script-method
    stubs. Methods accessed in forward are scripted on demand.

    Args:
        mod: The module to convert.

    Returns:
        `mod` itself when it already is a `torch.jit.ScriptModule`; the result
        of `create_constant_iterable_module` for `ModuleList` / `Sequential`;
        otherwise the result of `copy_to_script_module(mod, stubs)`.

    Raises:
        RuntimeError: If `mod.forward` is still `torch.nn.Module.forward`,
            i.e. the module never defined its own forward method.
    """
    if isinstance(mod, torch.jit.ScriptModule):
        return mod

    if isinstance(mod, (torch.nn.ModuleList, torch.nn.Sequential)):
        # Create constant versions for the iterable modules
        return create_constant_iterable_module(mod)

    methods = ()
    if hasattr(mod, 'forward'):
        if mod.forward.__func__ == torch.nn.Module.forward:
            raise RuntimeError("No forward method was defined on {}".format(mod))
        if not _jit_internal.is_ignored_fn(mod.forward):
            methods = ('forward',)

    # Collect the names of all callables explicitly marked for export so they
    # are compiled in addition to `forward`.
    exported = []
    for name in dir(mod):
        item = getattr(mod, name)
        if callable(item):
            if _jit_internal.get_torchscript_modifier(item) is _jit_internal.FunctionModifiers.EXPORT:
                exported.append(name)
    methods = methods + tuple(exported)

    def make_stub(method):
        # Look the function up on the class (not the instance) and pair it
        # with a resolution callback built from its closure.
        func = get_function_from_type(type(mod), method)
        return torch.jit.script_method(func, _jit_internal.createResolutionCallbackFromClosure(func))

    stubs = list(map(make_stub, methods))
    return copy_to_script_module(mod, stubs)
# Example 2
def infer_methods_to_compile(nn_module):
    """
    Implements the default rules for which methods should act as starting
    points for compilation (TODO add a link when the rules are published).

    Args:
        nn_module: The module whose entry-point methods are inferred. Its
            `__overloads__` attribute is (re)assigned as a side effect.

    Returns:
        A list holding the overload stubs followed by one stub per selected
        method. Method order is deterministic: `forward` first (when
        selected), then exported names in `dir()` order, without duplicates.
    """
    ignored_properties = jit_ignored_properties(nn_module)

    methods: List[str] = []
    if hasattr(nn_module, 'forward') and not _jit_internal.is_ignored_fn(nn_module.forward):
        forward_func = getattr(nn_module.forward, "__func__", None)
        module_forward = getattr(torch.nn.Module, "forward", None)
        # Only treat `forward` as an entry point when the module overrides
        # the one found on torch.nn.Module.
        if forward_func != module_forward:
            methods = ['forward']

    exported = []
    for name in dir(nn_module):
        if name in ignored_properties:
            # Skip ignored properties before touching them with getattr.
            continue
        item = getattr(nn_module, name, None)
        if _jit_internal.get_torchscript_modifier(
                item) is _jit_internal.FunctionModifiers.EXPORT:
            exported.append(name)

    methods = methods + exported

    overload_name_mappings = dict(getattr(nn_module, "__overloads__", {}))
    overload_info = get_overload_annotations(nn_module, ignored_properties)
    overload_stubs = make_stubs_for_overloads(overload_info)

    # NOTE(review): nothing visible in this function adds entries derived from
    # `overload_info` to `overload_name_mappings`; only pre-existing
    # `__overloads__` entries are filtered out below. Confirm this is intended.
    nn_module.__overloads__ = overload_name_mappings

    # we shouldn't directly compile overloaded methods, just its overloads
    def ignore_overloaded(method_name):
        return method_name not in overload_name_mappings

    filtered_methods = filter(ignore_overloaded, methods)

    # Unique the methods. We don't want to use a set to store the methods because it
    # introduces non-determinism to compile order.
    uniquer: Set[str] = set()
    uniqued_methods = []
    for name in filtered_methods:
        if name in uniquer:
            continue
        uniqued_methods.append(name)
        uniquer.add(name)

    stubs = []
    for method in uniqued_methods:
        stubs.append(make_stub_from_method(nn_module, method))
    return overload_stubs + stubs
# Example 3
def make_stubs_from_exported_methods(mod):
    """
    Build a stub for every attribute of `mod` that is marked for export.

    Args:
        mod: The object whose attributes (as listed by `dir`) are inspected.

    Returns:
        A list with one `make_stub_from_method(mod, name)` result per
        attribute whose TorchScript modifier is `EXPORT`, in `dir()` order.
        The list is empty when nothing is exported.
    """
    stubs = []
    for name in dir(mod):
        # Default to None so attributes that fail to resolve are skipped
        # instead of raising.
        item = getattr(mod, name, None)
        if (_jit_internal.get_torchscript_modifier(item) is
                _jit_internal.FunctionModifiers.EXPORT):
            stubs.append(make_stub_from_method(mod, name))

    return stubs
# Example 4
def create_script_module_impl(nn_module, concrete_type, stubs_fn):
    """
    Convert an nn.Module to a RecursiveScriptModule.

    Args:
        nn_module:  The original Python nn.Module that we are creating a ScriptModule for.
        concrete_type:  The fully initialized ConcreteType of the module.
        stubs_fn:  Lambda that takes an nn.Module and generates a list of ScriptMethodStubs to compile.

    Returns:
        The RecursiveScriptModule constructed around a new C++ module of
        `concrete_type.jit_type`, with attributes, submodules, ignored
        methods, compiled methods, properties and hooks attached.
    """
    cpp_module = torch._C._create_module_with_type(concrete_type.jit_type)
    method_stubs = stubs_fn(nn_module)
    property_stubs = get_property_stubs(nn_module)
    hook_stubs, pre_hook_stubs = get_hook_stubs(nn_module)

    ignored_properties = jit_ignored_properties(nn_module)

    def init_fn(script_module):
        # Initialize the ScriptModule:
        # 1. Copy the attributes/parameters/buffers from the original `nn_module` to the new ScriptModule.
        for name, (attr_type,
                   is_param) in concrete_type.get_attributes().items():
            orig_value = getattr(nn_module, name)
            # Unwrap torch.jit.Attribute so the raw value is stored.
            orig_value = orig_value.value if isinstance(
                orig_value, torch.jit.Attribute) else orig_value
            cpp_module.setattr(name, orig_value)

        # 2. Copy the submodules from the original `nn_module` to the new ScriptModule,
        #    recursively scripting them.
        for name, sub_concrete_type in concrete_type.get_modules():
            orig_value = getattr(nn_module, name)
            assert isinstance(orig_value,
                              Module), "Expected Module but got {}".format(
                                  type(orig_value))
            module_type = sub_concrete_type.jit_type
            if isinstance(module_type, torch._C.InterfaceType):
                # use the interface inference rule to compile the module
                scripted = interface_script(module_type, orig_value)
            elif isinstance(orig_value, torch.jit.ScriptModule):
                scripted = orig_value
            else:
                # always reuse the provided stubs_fn to infer the methods to compile
                scripted = create_script_module_impl(orig_value,
                                                     sub_concrete_type,
                                                     stubs_fn)

            cpp_module.setattr(name, scripted)
            script_module._modules[name] = scripted

        # 3. Copy @ignored/@unused methods and attrs from the original `nn_module` to the new ScriptModule.
        #    This ensures we can access these Python methods on the ScriptModule.
        for name in dir(nn_module):
            if name in ignored_properties:
                continue
            item = getattr(nn_module, name, None)
            if inspect.ismethod(item) and _jit_internal.is_ignored_fn(item):
                # Rebind the class-level function to the new script module so
                # `self` inside it refers to the ScriptModule.
                unbound_function = getattr(type(nn_module), name)
                bound_method = unbound_function.__get__(script_module)
                setattr(script_module, name, bound_method)
            elif concrete_type.is_ignored_attribute(name):
                setattr(script_module, name, item)

        # For convenience, attach the concrete type to the new ScriptModule
        script_module._concrete_type = concrete_type

    # Actually create the ScriptModule, initializing it with the function we just defined
    script_module = torch.jit.RecursiveScriptModule._construct(
        cpp_module, init_fn)

    # Compile methods if necessary
    # NOTE(review): nothing visible in this block records `concrete_type` in
    # `concrete_type_store.methods_compiled`; confirm that one of the callees
    # below does so, otherwise this guard never becomes False.
    if concrete_type not in concrete_type_store.methods_compiled:
        create_methods_and_properties_from_stubs(concrete_type, method_stubs,
                                                 property_stubs)
        # Create hooks after methods to ensure no name collisions between hooks and methods.
        # If done before, hooks can overshadow methods that aren't exported.
        create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs)

    # Copy the forward hooks and pre-hooks to the new ScriptModule
    # to allow the hooks to be run from eager as ScriptFunctions
    for idx, fn in enumerate(script_module._c._get_forward_pre_hooks()):
        script_module._forward_pre_hooks[idx] = fn
    for idx, fn in enumerate(script_module._c._get_forward_hooks()):
        script_module._forward_hooks[idx] = fn

    # Special handling so methods like __len__ work in script methods on classes derived from containers
    if isinstance(nn_module, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)) and \
            '__len__' not in cpp_module._method_names():
        script_module.define("def __len__(self):\n   return {}\n".format(
            len(nn_module)))
    if isinstance(nn_module, torch.nn.ModuleDict) and \
            '__contains__' not in cpp_module._method_names():
        if len(nn_module.keys()):
            keys = repr(list(nn_module.keys()))
            script_module.define(
                "def __contains__(self, key: str):\n   return key in {}\n".
                format(keys))
        else:
            script_module.define(
                "def __contains__(self, key: str):\n   return False\n")

    # Make the compiled methods available to the Python ScriptModule class.
    for method_stub in method_stubs:
        if method_stub.original_method is None:
            # define()'d methods don't have an Python original_method, so we
            # don't need to do any Python re-wrapping stuff
            continue

        name = method_stub.original_method.__name__
        if name != method_stub.def_.name().name:
            # TODO: Why skip this? Because @torch.jit._overload_method will
            # mangle the name of the function.
            continue
        script_method = cpp_module._get_method(name)

        # Wrap the original to propagate docstrings and such.
        # TODO: we don't currently do this functions that are recursively
        # compiled, we should.
        wrapped_script_method = functools.wraps(method_stub.original_method)(
            script_method)  # type: ignore

        # Add the methods to the script_module directly. This ensures they will
        # be found first when `name` is looked up (as opposed to the stubs or
        # nn.Module.forward)
        script_module.__dict__[name] = wrapped_script_method

    # Make module properties available on the Python ScriptModule class.
    for property_stub in property_stubs:
        property_name = property_stub.def_.name().name
        fget = cpp_module._get_method(property_stub.def_.getter_name().name)
        # Setter is optional, so it may not exist.
        setter_name = property_stub.def_.setter_name()
        fset = cpp_module._get_method(
            setter_name.name) if setter_name else None
        # NOTE(review): the builtin `property` takes (fget, fset, fdel, doc)
        # positionally, so `property_name` lands in the fget slot here.
        # Left unchanged; confirm whether this is intended.
        script_module.__dict__[property_name] = property(
            property_name, fget, fset)  # type: ignore

    # copy over python methods to script module if they aren't defined on the script module
    # this is currently an internal api used only on module containers
    for name in dir(nn_module):
        if name in ignored_properties:
            continue
        item = getattr(nn_module, name, None)
        if _jit_internal.get_torchscript_modifier(
                item
        ) is _jit_internal.FunctionModifiers.COPY_TO_SCRIPT_WRAPPER:
            add_python_attr_to_scripted_model(script_module, nn_module, name)

    return script_module
# Example 5
def create_script_module_impl(nn_module, concrete_type, stubs_fn):
    """
    Convert an nn.Module to a RecursiveScriptModule.

    Args:
        nn_module:  The original Python nn.Module that we are creating a ScriptModule for.
        concrete_type:  The fully initialized ConcreteType of the module.
        stubs_fn:  Lambda that takes an nn.Module and generates a list of ScriptMethodStubs to compile.

    Returns:
        The RecursiveScriptModule constructed around a new C++ module of
        `concrete_type.jit_type`, with attributes, submodules, ignored
        methods and compiled methods attached.
    """
    cpp_module = torch._C._create_module_with_type(concrete_type.jit_type)
    stubs = stubs_fn(nn_module)

    def init_fn(script_module):
        # Initialize the ScriptModule:
        # 1. Copy the attributes/parameters/buffers from the original `nn_module` to the new ScriptModule.
        for name, (attr_type,
                   is_param) in concrete_type.get_attributes().items():
            orig_value = getattr(nn_module, name)
            # Unwrap torch.jit.Attribute so the raw value is stored.
            orig_value = orig_value.value if isinstance(
                orig_value, torch.jit.Attribute) else orig_value
            cpp_module.setattr(name, orig_value)

        # 2. Copy the submodules from the original `nn_module` to the new ScriptModule,
        #    recursively scripting them.
        for name, sub_concrete_type in concrete_type.get_modules():
            orig_value = getattr(nn_module, name)
            assert isinstance(orig_value,
                              Module), "Expected Module but got {}".format(
                                  type(orig_value))
            module_type = sub_concrete_type.jit_type
            if isinstance(module_type, torch._C.InterfaceType):
                # use the interface inference rule to compile the module
                scripted = interface_script(module_type, orig_value)
            elif isinstance(orig_value, torch.jit.ScriptModule):
                scripted = orig_value
            else:
                # use the default recursive rule to compile the module
                scripted = create_script_module_impl(orig_value,
                                                     sub_concrete_type,
                                                     infer_methods_to_compile)
            cpp_module.setattr(name, scripted)
            script_module._modules[name] = scripted

        # 3. Copy @ignored/@unused methods from the original `nn_module` to the new ScriptModule.
        #    This ensures we can access these Python methods on the ScriptModule.
        for name in dir(nn_module):
            item = getattr(nn_module, name, None)
            if not inspect.ismethod(item):
                continue
            if _jit_internal.is_ignored_fn(item):
                # Rebind the class-level function to the new script module so
                # `self` inside it refers to the ScriptModule.
                unbound_function = getattr(type(nn_module), name)
                bound_method = unbound_function.__get__(script_module)
                setattr(script_module, name, bound_method)

        # For convenience, attach the concrete type to the new ScriptModule
        script_module._concrete_type = concrete_type

    # Actually create the ScriptModule, initializing it with the function we just defined
    script_module = torch.jit.RecursiveScriptModule._construct(
        cpp_module, init_fn)

    # Compile methods if necessary
    # NOTE(review): nothing visible in this block records `concrete_type` in
    # `concrete_type_store.methods_compiled`; confirm that
    # `create_methods_from_stubs` does so, otherwise this guard never becomes
    # False.
    if concrete_type not in concrete_type_store.methods_compiled:
        create_methods_from_stubs(concrete_type, stubs)

    # Make the compiled methods available to the Python ScriptModule class.
    for stub in stubs:
        if stub.original_method is None:
            # define()'d methods don't have an Python original_method, so we
            # don't need to do any Python re-wrapping stuff
            continue

        name = stub.original_method.__name__
        if name != stub.def_.name().name:
            # TODO: Why skip this? Because @torch.jit._overload_method will
            # mangle the name of the function.
            continue
        script_method = cpp_module._get_method(name)

        # Wrap the original to propagate docstrings and such.
        # TODO: we don't currently do this functions that are recursively
        # compiled, we should.
        script_method = functools.wraps(stub.original_method)(script_method)

        # Add the methods to the script_module directly. This ensures they will
        # be found first when `name` is looked up (as opposed to the stubs or
        # nn.Module.forward)
        script_module.__dict__[name] = script_method

    # copy over python methods to script module if they aren't defined on the script module
    # this is currently an internal api used only on module containers
    for name in dir(nn_module):
        item = getattr(nn_module, name, None)
        if _jit_internal.get_torchscript_modifier(
                item
        ) is _jit_internal.FunctionModifiers.COPY_TO_SCRIPT_WRAPPER:
            add_python_attr_to_scripted_model(script_module, nn_module, name)

    return script_module
# Example 6
def recursive_script(mod, exclude_methods=()):
    """
    Makes a ScriptModule from an nn.Module.

    `forward` (unless it is marked as ignored) plus every callable attribute
    whose TorchScript modifier is `EXPORT` are turned into script-method
    stubs, minus anything listed in `exclude_methods`. Methods that have
    registered overloads are not compiled directly; one stub per overload is
    generated instead. Methods accessed in forward are scripted on demand.

    Args:
        mod: The module to convert. Its `__overloads__` attribute is
            (re)assigned as a side effect.
        exclude_methods: Names of methods that must not be compiled.

    Returns:
        `mod` itself when it already is a `torch.jit.ScriptModule`; the result
        of `create_constant_iterable_module` for `ModuleList` / `Sequential` /
        `ModuleDict`; otherwise the result of `copy_to_script_module`.

    Raises:
        RuntimeError: If `mod` has no `_parameters` attribute, or if
            `mod.forward` is still `torch.nn.Module.forward`.
    """
    if isinstance(mod, torch.jit.ScriptModule):
        return mod

    if isinstance(mod, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)):
        # Create constant versions for the iterable modules
        return create_constant_iterable_module(mod)

    if not hasattr(mod, '_parameters'):
        raise RuntimeError("'{}' has not been initialized, did you forget to call 'super()'?"
                           .format(type(mod).__name__))

    methods = ()
    if hasattr(mod, 'forward'):
        if mod.forward.__func__ == torch.nn.Module.forward:
            raise RuntimeError("No forward method was defined on {}".format(mod))
        if not _jit_internal.is_ignored_fn(mod.forward):
            methods = ('forward',)
    exported = []
    overloads = []
    for name in dir(mod):
        item = getattr(mod, name)
        if callable(item):
            if _jit_internal.get_torchscript_modifier(item) is _jit_internal.FunctionModifiers.EXPORT:
                exported.append(name)

            # builtin functions like repr() in python 2 do not have __module__ defined
            if hasattr(item, "__module__") and item.__module__ is not None:
                method_overloads = _jit_internal._get_overloaded_methods(item, mod.__class__)

                if method_overloads is not None:
                    overloads.append((item, method_overloads))

    methods = methods + tuple(exported)

    methods = tuple(name for name in methods if name not in exclude_methods)

    overload_name_mappings = dict(getattr(mod, "__overloads__", {}))
    overload_stubs = []

    for orig_fn, overload_fns in overloads:
        orig_ast = torch.jit.get_jit_def(orig_fn, self_name="ScriptModule")
        orig_name = orig_ast.name().name
        # Each overload gets a mangled name "<original>__<index>".
        names = [orig_name + "__" + str(i) for i in range(len(overload_fns))]
        overload_name_mappings[orig_name] = names
        for overload_fn, name in zip(overload_fns, names):
            over_ast = torch.jit.get_jit_def(overload_fn, self_name="ScriptModule")
            # Combine the overload's declaration with the original body under
            # the mangled name.
            new_ast = torch._C._replace_overloaded_method_decl(over_ast.decl(), orig_ast, name)
            _rcb = _jit_internal.createResolutionCallbackFromClosure(orig_fn)
            overload_stubs.append(torch.jit.ScriptMethodStub(_rcb, new_ast, overload_fn))

    mod.__overloads__ = overload_name_mappings

    # we shouldn't directly compile overloaded methods, just its overloads
    def ignore_overloaded(method_name):
        return method_name not in overload_name_mappings

    def make_stub(method):
        # Look the function up on the class (not the instance) and pair it
        # with a resolution callback built from its closure.
        func = get_function_from_type(type(mod), method)
        return torch.jit.script_method(func, _jit_internal.createResolutionCallbackFromClosure(func))

    filtered_methods = filter(ignore_overloaded, methods)
    stubs = list(map(make_stub, filtered_methods))
    return copy_to_script_module(mod, overload_stubs + stubs)
# Example 7
def _get_modifier_wrapper(fn):
    """
    Return the TorchScript modifier for `fn`, skipping `Injector` subclasses.

    Args:
        fn: The object to inspect.

    Returns:
        None when `fn` is a class derived from `Injector`; otherwise whatever
        `get_torchscript_modifier(fn)` returns.
    """
    if inspect.isclass(fn) and issubclass(fn, Injector):
        return None
    # Everything else is delegated to the regular modifier lookup. This return
    # must sit outside the `if`, otherwise it is unreachable.
    return get_torchscript_modifier(fn)