Ejemplo n.º 1
0
def get_property_stubs(nn_module):
    """
    Create property stubs for the properties defined on the module's type.

    Every ``property`` found on ``type(nn_module)`` gets a resolution callback
    built from the closure of its getter; each property AST returned by
    ``get_class_properties`` is then paired with the callback of the same name.

    Args:
        nn_module: the module instance whose type is inspected.

    Returns:
        A list of ``PropertyStub`` objects, one per property AST.

    Raises:
        RuntimeError: if a property on the module's type has no getter.
    """
    module_ty = type(nn_module)
    properties_asts = get_class_properties(module_ty,
                                           self_name="RecursiveScriptModule")
    rcbs = {}

    for name in dir(module_ty):
        item = getattr(module_ty, name, None)
        if not isinstance(item, property):
            continue

        if not item.fget:
            # Use the class for the name: `nn_module` is an instance, and
            # instances normally have no `__name__`, so formatting the message
            # with `nn_module.__name__` would raise AttributeError and hide
            # this error.
            raise RuntimeError(
                f'Property {name} of {module_ty.__name__} must have a getter'
            )

        rcbs[name] = _jit_internal.createResolutionCallbackFromClosure(
            item.fget)

    return [
        PropertyStub(rcbs[ast.name().name], ast) for ast in properties_asts
    ]
Ejemplo n.º 2
0
def create_method_from_fn(module, fn):
    """
    Compile ``fn`` as a script method and attach it to ``module``.

    Returns ``None`` without compiling anything when ``fn`` is marked as
    ignored or is not a bound method; otherwise returns the stub that was
    created for ``fn``.
    """
    if _jit_internal.is_ignored_fn(fn) or not inspect.ismethod(fn):
        return None

    resolution_cb = _jit_internal.createResolutionCallbackFromClosure(fn)
    stub = torch.jit.script_method(fn, resolution_cb)

    # Emit hooks are disabled while the method is created because the graph
    # that triggered this call is not complete yet.
    with torch.jit._disable_emit_hooks():
        torch.jit._create_methods_from_stubs(module, (stub,))

    return stub
Ejemplo n.º 3
0
def make_stubs_for_overloads(overload_info):
    """
    Build a ``ScriptMethodStub`` for every overload listed in ``overload_info``.

    ``overload_info`` maps an original function to an iterable of
    ``(overload_name, overload_fn)`` pairs. For each pair, the declaration of
    the overload is combined with the AST of the original function under
    ``overload_name``, and the resolution callback is taken from the closure
    of the original function.

    Returns the stubs as a list, in iteration order.
    """
    stubs = []
    for original, named_overloads in overload_info.items():
        original_ast = get_jit_def(original, original.__name__, self_name="RecursiveScriptModule")
        for mangled_name, candidate in named_overloads:
            _check_no_signature(candidate)
            candidate_ast = get_jit_def(candidate, candidate.__name__, self_name="RecursiveScriptModule")
            merged_ast = torch._C._replace_overloaded_method_decl(
                candidate_ast.decl(), original_ast, mangled_name)
            stubs.append(
                ScriptMethodStub(
                    _jit_internal.createResolutionCallbackFromClosure(original),
                    merged_ast,
                    candidate,
                )
            )
    return stubs
Ejemplo n.º 4
0
    def test_sc(obj, optimize=None, _frames_up=0, _rcb=None):
        """
        Compile ``obj`` and, for functions, print every intermediate artifact.

        For a class, the class is validated, compiled and registered, and the
        class itself is returned. For anything else, the AST, the resolution
        callback and the compiled function (plus its ``code``, ``schema``,
        ``graph`` and ``name``) are printed, and the compiled function is
        returned. No overload check and no compiled-function cache
        lookup/store is performed here. ``optimize`` is accepted but unused.
        """
        qualified_name = _qualified_name(obj)

        if not inspect.isclass(obj):
            fn_def = get_jit_def(obj, obj.__name__)
            print("---ast---")
            print(fn_def)

            rcb = _rcb
            if rcb is None:
                rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
            print("---rcb---")
            print(rcb)

            fn = torch._C._jit_script_compile(qualified_name, fn_def, rcb, get_default_args(obj))
            # Carry the Python docstring over to the compiled function
            fn.__doc__ = obj.__doc__

            print("---scripted_fn---")
            print(fn)
            # Attributes are looked up one at a time, right before printing
            for banner, attr in (
                ("---scripted_fn.code---", "code"),
                ("---scripted_fn.schema---", "schema"),
                ("---scripted_fn.graph---", "graph"),
                ("---scripted_fn.name---", "name"),
            ):
                print(banner)
                print(getattr(fn, attr))
            return fn

        # A `nn.Module` subclass was passed as a type; the caller most likely
        # meant to pass an instance of it instead
        if issubclass(obj, torch.nn.Module):
            raise RuntimeError("Type '{}' cannot be compiled since it inherits"
                               " from nn.Module,"
                               " pass an instance instead".format(obj))

        if not _is_new_style_class(obj):
            raise RuntimeError("TorchScript classes must be new-style classes. "
                               "Please inherit from 'object'.")

        if len(obj.mro()) > 2:
            raise RuntimeError("TorchScript classes does not support inheritance yet. "
                               "Please directly inherit from 'object'.")

        class_rcb = _rcb
        if class_rcb is None:
            class_rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up + 1)
        _compile_and_register_class(obj, class_rcb, qualified_name)
        return obj
Ejemplo n.º 5
0
def _compile_function_with_overload(overload_fn, qual_name, impl_fn):
    """
    Compile ``impl_fn`` under the declaration of ``overload_fn``.

    The declaration and signature come from ``overload_fn``; the body, the
    default arguments handed to the compiler and the resolution callback come
    from ``impl_fn``. The defaults of both functions are checked against each
    other before compiling. Returns the compiled function.
    """
    decl = get_jit_def(overload_fn, overload_fn.__name__).decl()
    signature = torch.jit.annotations.get_signature(
        overload_fn, None, None, inspect.ismethod(overload_fn))
    body_ast = get_jit_def(impl_fn, impl_fn.__name__)

    defaults_of_overload = get_default_args(overload_fn)
    defaults_of_impl = get_default_args(impl_fn)
    resolution_cb = _jit_internal.createResolutionCallbackFromClosure(impl_fn)

    _check_overload_defaults(defaults_of_impl, defaults_of_overload, decl.range())

    return torch._C._jit_script_compile_overload(
        qual_name, decl, body_ast, resolution_cb, defaults_of_impl, signature)
Ejemplo n.º 6
0
def try_compile_fn(fn, loc):
    """
    Script ``fn`` if it is a plain Python function or method.

    Returns ``None`` for functions marked as ignored and for ``nn.Module``
    instances (modules are callable, but are not compiled here). Raises
    ``RuntimeError`` for any other object that is neither a function nor a
    method. ``loc`` is accepted but not used.
    """
    if _jit_internal.is_ignored_fn(fn) or isinstance(fn, torch.nn.Module):
        return None

    if not (inspect.isfunction(fn) or inspect.ismethod(fn)):
        raise RuntimeError("`{}` is not a function. Recursive scripting only supports "
                           "Python functions or methods currently.\n"
                           "Consider manually annotating `{}` with @torch.jit.script.".format(fn, fn))

    # The scope where the function was defined is not available here, so the
    # resolution callback is built from the variables the function closes over.
    closure_rcb = _jit_internal.createResolutionCallbackFromClosure(fn)
    return torch.jit.script(fn, _rcb=closure_rcb)
Ejemplo n.º 7
0
def script(obj, optimize=None, _frames_up=0, _rcb=None):
    r"""
    Scripting a function or ``nn.Module`` will inspect the source code, compile
    it as TorchScript code using the TorchScript compiler, and return a :class:`ScriptModule` or
    :class:`ScriptFunction`. TorchScript itself is a subset of the Python language, so not all
    features in Python work, but we provide enough functionality to compute on
    tensors and do control-dependent operations. For a complete guide, see the
    :ref:`language-reference`.

    ``torch.jit.script`` can be used as a function for modules and functions, and as a decorator
    ``@torch.jit.script`` for :ref:`torchscript-classes` and functions.

    Args:
        obj (callable, class, or ``nn.Module``):  The ``nn.Module``, function, or class type to
                                                  compile.
        optimize: Deprecated and has no effect; passing any value other than ``None`` only
                  emits a warning.
        _frames_up (int): Used only when ``obj`` is a class and ``_rcb`` is ``None``:
                          ``_frames_up + 1`` is passed to
                          ``createResolutionCallbackFromFrame``.
        _rcb: Resolution callback to use when compiling a class or a function. When ``None``,
              one is created from the calling frame (classes) or from the closure of ``obj``
              (functions).

    Returns:
        If ``obj`` is ``nn.Module``, ``script`` returns
        a :class:`ScriptModule` object. The returned :class:`ScriptModule` will
        have the same set of sub-modules and parameters as the
        original ``nn.Module``. If ``obj`` is a standalone function,
        a :class:`ScriptFunction` will be returned.
        If ``obj`` is already a :class:`ScriptModule`, or scripting is disabled, ``obj``
        itself is returned. If ``obj`` is a class, it is compiled and registered and the
        class itself is returned.

    Raises:
        RuntimeError: If ``obj`` is a class that inherits from ``nn.Module``, is not a
                      new-style class, or has more than two entries in its MRO (i.e. it
                      does not inherit directly from ``object``).

    **Scripting a function**
        The ``@torch.jit.script`` decorator will construct a :class:`ScriptFunction`
        by compiling the body of the function.

        Example (scripting a function):

        .. testcode::

            import torch

            @torch.jit.script
            def foo(x, y):
                if x.max() > y.max():
                    r = x
                else:
                    r = y
                return r

            print(type(foo))  # torch.jit.ScriptFunction

            # See the compiled graph as Python code
            print(foo.code)

            # Call the function using the TorchScript interpreter
            foo(torch.ones(2, 2), torch.ones(2, 2))

        .. testoutput::
            :hide:

            ...

    **Scripting an nn.Module**
        Scripting an ``nn.Module`` by default will compile the ``forward`` method and recursively
        compile any methods, submodules, and functions called by ``forward``. If a ``nn.Module`` only uses
        features supported in TorchScript, no changes to the original module code should be necessary. ``script``
        will construct :class:`ScriptModule` that has copies of the attributes, parameters, and methods of
        the original module.

        Example (scripting a simple module with a Parameter):

        .. testcode::

            import torch

            class MyModule(torch.nn.Module):
                def __init__(self, N, M):
                    super(MyModule, self).__init__()
                    # This parameter will be copied to the new ScriptModule
                    self.weight = torch.nn.Parameter(torch.rand(N, M))

                    # When this submodule is used, it will be compiled
                    self.linear = torch.nn.Linear(N, M)

                def forward(self, input):
                    output = self.weight.mv(input)

                    # This calls the `forward` method of the `nn.Linear` module, which will
                    # cause the `self.linear` submodule to be compiled to a `ScriptModule` here
                    output = self.linear(output)
                    return output

            scripted_module = torch.jit.script(MyModule(2, 3))

        Example (scripting a module with traced submodules):

        .. testcode::

            import torch
            import torch.nn as nn
            import torch.nn.functional as F

            class MyModule(nn.Module):
                def __init__(self):
                    super(MyModule, self).__init__()
                    # torch.jit.trace produces a ScriptModule's conv1 and conv2
                    self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
                    self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))

                def forward(self, input):
                    input = F.relu(self.conv1(input))
                    input = F.relu(self.conv2(input))
                    return input

            scripted_module = torch.jit.script(MyModule())

        To compile a method other than ``forward`` (and recursively compile anything it calls), add
        the :func:`@torch.jit.export <torch.jit.export>` decorator to the method. To opt out of compilation
        use :func:`@torch.jit.ignore <torch.jit.ignore>` or :func:`@torch.jit.unused <torch.jit.unused>`.

        Example (an exported and ignored method in a module)::

            import torch
            import torch.nn as nn

            class MyModule(nn.Module):
                def __init__(self):
                    super(MyModule, self).__init__()

                @torch.jit.export
                def some_entry_point(self, input):
                    return input + 10

                @torch.jit.ignore
                def python_only_fn(self, input):
                    # This function won't be compiled, so any
                    # Python APIs can be used
                    import pdb
                    pdb.set_trace()

                def forward(self, input):
                    if self.training:
                        self.python_only_fn(input)
                    return input * 99

            scripted_module = torch.jit.script(MyModule())
            print(scripted_module.some_entry_point(torch.randn(2, 2)))
            print(scripted_module(torch.randn(2, 2)))
    """
    # When scripting is disabled, hand the object back untouched
    if not _enabled:
        return obj

    if optimize is not None:
        warnings.warn(
            "`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead"
        )
    # Already scripted: nothing to do
    if isinstance(obj, ScriptModule):
        return obj

    # Module instances are handled by the recursive scripting machinery
    if isinstance(obj, torch.nn.Module):
        obj = call_prepare_scriptable_func(obj)
        return torch.jit._recursive.create_script_module(
            obj, torch.jit._recursive.infer_methods_to_compile)

    qualified_name = _qualified_name(obj)
    if inspect.isclass(obj):
        # If this type is a `nn.Module` subclass, they probably meant to pass
        # an instance instead of a Module
        if issubclass(obj, torch.nn.Module):
            raise RuntimeError("Type '{}' cannot be compiled since it inherits"
                               " from nn.Module,"
                               " pass an instance instead".format(obj))

        if not _is_new_style_class(obj):
            raise RuntimeError(
                "TorchScript classes must be new-style classes. "
                "Please inherit from 'object'.")
        # An MRO longer than [cls, object] means the class has other bases
        if len(obj.mro()) > 2:
            raise RuntimeError(
                "TorchScript classes does not support inheritance yet. "
                "Please directly inherit from 'object'.")
        # NOTE: the `+ 1` is passed along with the caller-supplied frame count;
        # this call must stay directly inside `script` (not in a helper),
        # presumably because the callback is tied to call-stack depth.
        if _rcb is None:
            _rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up +
                                                                   1)
        _compile_and_register_class(obj, _rcb, qualified_name)
        return obj
    else:
        # this is a decorated fn, and we need to get the underlying fn and its rcb
        if hasattr(obj, "__script_if_tracing_wrapper"):
            obj = obj.__original_fn
            _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)

        _check_directly_compile_overloaded(obj)
        # Reuse a previously compiled version of this function if one is cached
        maybe_already_compiled_fn = _try_get_jit_cached_function(obj)
        if maybe_already_compiled_fn:
            return maybe_already_compiled_fn
        ast = get_jit_def(obj, obj.__name__)
        if _rcb is None:
            _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
        fn = torch._C._jit_script_compile(qualified_name, ast, _rcb,
                                          get_default_args(obj))
        # Forward docstrings
        fn.__doc__ = obj.__doc__
        # Store the compiled function so later calls hit the cache lookup above
        _set_jit_function_cache(obj, fn)
        return fn
Ejemplo n.º 8
0
def make_stub(func, name):
    """
    Build a ``ScriptMethodStub`` for ``func`` under the method name ``name``.

    The stub bundles a resolution callback created from the closure of
    ``func``, the AST obtained from ``get_jit_def``, and ``func`` itself.
    """
    return ScriptMethodStub(
        _jit_internal.createResolutionCallbackFromClosure(func),
        get_jit_def(func, name, self_name="RecursiveScriptModule"),
        func,
    )
Ejemplo n.º 9
0
def recursive_script(mod, exclude_methods=()):
    """
    Makes a ScriptModule from an nn.Module.

    ``forward`` is compiled unless it is marked as ignored, together with
    every callable attribute of ``mod`` whose TorchScript modifier is EXPORT.
    Names listed in ``exclude_methods`` are left out. Overloaded methods are
    not compiled directly; a stub is created for each of their overloads
    instead, and the overload name mapping is stored on ``mod.__overloads__``.

    A ``ScriptModule`` is returned as-is, and ``ModuleList`` / ``Sequential`` /
    ``ModuleDict`` are turned into their constant iterable counterparts.

    Raises ``RuntimeError`` if ``mod`` has no ``_parameters`` attribute or if
    its ``forward`` is the one from ``torch.nn.Module``.
    """
    if isinstance(mod, torch.jit.ScriptModule):
        return mod

    iterable_containers = (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)
    if isinstance(mod, iterable_containers):
        # Iterable modules get a constant version instead
        return create_constant_iterable_module(mod)

    if not hasattr(mod, '_parameters'):
        raise RuntimeError("'{}' has not been initialized, did you forget to call 'super()'?"
                           .format(type(mod).__name__))

    methods = ()
    if hasattr(mod, 'forward'):
        if mod.forward.__func__ == torch.nn.Module.forward:
            raise RuntimeError("No forward method was defined on {}".format(mod))
        if not _jit_internal.is_ignored_fn(mod.forward):
            methods = ('forward',)

    # Collect exported method names and overloaded methods in a single pass
    exported = []
    overloads = []
    for attr_name in dir(mod):
        attr = getattr(mod, attr_name)
        if not callable(attr):
            continue

        if _jit_internal.get_torchscript_modifier(attr) is _jit_internal.FunctionModifiers.EXPORT:
            exported.append(attr_name)

        # builtin functions like repr() in python 2 do not have __module__ defined
        if getattr(attr, "__module__", None) is not None:
            found = _jit_internal._get_overloaded_methods(attr, mod.__class__)
            if found is not None:
                overloads.append((attr, found))

    methods = tuple(
        candidate for candidate in methods + tuple(exported)
        if candidate not in exclude_methods
    )

    overload_name_mappings = dict(getattr(mod, "__overloads__", {}))
    overload_stubs = []

    for orig_fn, overload_fns in overloads:
        orig_ast = torch.jit.get_jit_def(orig_fn, self_name="ScriptModule")
        base_name = orig_ast.name().name
        names = [base_name + "__" + str(index) for index in range(len(overload_fns))]
        overload_name_mappings[base_name] = names
        for overload_fn, mangled in zip(overload_fns, names):
            torch.jit._check_no_signature(overload_fn)
            over_ast = torch.jit.get_jit_def(overload_fn, self_name="ScriptModule")
            new_ast = torch._C._replace_overloaded_method_decl(over_ast.decl(), orig_ast, mangled)
            _rcb = _jit_internal.createResolutionCallbackFromClosure(orig_fn)
            overload_stubs.append(torch.jit.ScriptMethodStub(_rcb, new_ast, overload_fn))

    mod.__overloads__ = overload_name_mappings

    # we shouldn't directly compile overloaded methods, just its overloads
    stubs = []
    for method_name in methods:
        if method_name in overload_name_mappings:
            continue
        func = get_function_from_type(type(mod), method_name)
        stubs.append(
            torch.jit.script_method(func, _jit_internal.createResolutionCallbackFromClosure(func)))

    return copy_to_script_module(mod, overload_stubs + stubs)
Ejemplo n.º 10
0
 def make_stub(method):
     """
     Create a script-method stub for the method named ``method``.

     The function is looked up on the type of ``mod`` (taken from the
     enclosing scope), and its closure supplies the resolution callback.
     """
     target = get_function_from_type(type(mod), method)
     closure_rcb = _jit_internal.createResolutionCallbackFromClosure(target)
     return torch.jit.script_method(target, closure_rcb)