def _compile_function_with_overload(overload_fn, qual_name, impl_fn):
    """Compile ``impl_fn``'s body under the declaration taken from ``overload_fn``.

    The declaration and signature come from the overload stub. The AST, the
    default arguments passed to the compiler, and the resolution callback all
    come from the implementation function.

    Args:
        overload_fn: the overload stub whose declaration/signature is used.
        qual_name: qualified name handed to the compiler for the result.
        impl_fn: the function providing the body that is actually compiled.

    Returns:
        Whatever ``torch._C._jit_script_compile_overload`` produces.
    """
    decl = get_jit_def(overload_fn, overload_fn.__name__).decl()
    signature = torch.jit.annotations.get_signature(
        overload_fn, None, None, inspect.ismethod(overload_fn))
    body_ast = get_jit_def(impl_fn, impl_fn.__name__)

    overload_default_args = get_default_args(overload_fn)
    impl_default_args = get_default_args(impl_fn)

    # Names in the compiled body are resolved against the implementation's
    # closure, not the overload stub's.
    resolution_cb = _jit_internal.createResolutionCallbackFromClosure(impl_fn)

    # Presumably validates the overload's defaults against the implementation's
    # (helper defined elsewhere); the declaration's range is passed, likely for
    # error reporting — TODO confirm.
    _check_overload_defaults(impl_default_args, overload_default_args, decl.range())

    return torch._C._jit_script_compile_overload(
        qual_name,
        decl,
        body_ast,
        resolution_cb,
        impl_default_args,
        signature,
    )
def test_sc(obj, optimize=None, _frames_up=0, _rcb=None):
    """Debugging variant of ``script`` that dumps intermediate compilation state.

    Classes are validated, compiled and registered the same way ``script``
    does it. For a function, the parsed AST, the resolution callback and
    several views of the compiled result (the object itself, then its
    ``code``, ``schema``, ``graph`` and ``name``) are printed to stdout.

    Unlike ``script``, this helper does not check whether the JIT is enabled,
    has no special handling for ``ScriptModule``/``nn.Module`` instances, and
    never consults the compiled-function cache, so every call recompiles
    ``obj``.

    Args:
        obj: the function or class to compile.
        optimize: deprecated and has no effect; a warning is emitted when it
            is set, matching ``script``.
        _frames_up: extra stack frames to skip when building the resolution
            callback for a class.
        _rcb: optional resolution callback; derived from ``obj`` when ``None``.

    Returns:
        ``obj`` itself for a class, otherwise the compiled function.
    """
    if optimize is not None:
        warnings.warn(
            "`optimize` is deprecated and has no effect. "
            "Use `with torch.jit.optimized_execution()` instead",
            stacklevel=2,
        )

    qualified_name = _qualified_name(obj)

    if inspect.isclass(obj):
        # If this type is a `nn.Module` subclass, they probably meant to pass
        # an instance instead of a Module
        if issubclass(obj, torch.nn.Module):
            raise RuntimeError("Type '{}' cannot be compiled since it inherits"
                               " from nn.Module,"
                               " pass an instance instead".format(obj))
        if not _is_new_style_class(obj):
            raise RuntimeError("TorchScript classes must be new-style classes. "
                               "Please inherit from 'object'.")
        if len(obj.mro()) > 2:
            raise RuntimeError("TorchScript classes does not support inheritance yet. "
                               "Please directly inherit from 'object'.")
        if _rcb is None:
            # `+ 1` steps past this function's own frame so names resolve in
            # the scope that called test_sc.
            _rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up + 1)
        _compile_and_register_class(obj, _rcb, qualified_name)
        return obj

    ast = get_jit_def(obj, obj.__name__)
    print("---ast---")
    print(ast)

    if _rcb is None:
        _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
    print("---rcb---")
    print(_rcb)

    fn = torch._C._jit_script_compile(qualified_name, ast, _rcb, get_default_args(obj))
    # Forward docstrings
    fn.__doc__ = obj.__doc__

    print("---scripted_fn---")
    print(fn)
    # Attributes are read one at a time, so the printed order (and the point
    # at which any attribute error surfaces) matches a sequential dump.
    for attr in ("code", "schema", "graph", "name"):
        print("---scripted_fn.{}---".format(attr))
        print(getattr(fn, attr))
    return fn
def script_method(fn):
    """Wrap a method defined on a ``ScriptModule`` subclass in a ``ScriptMethodStub``.

    When the JIT is disabled, ``fn`` is handed back untouched. Otherwise a
    resolution callback for the surrounding scope and the method's parsed
    definition (``self_name="ScriptModule"``) are bundled with ``fn`` into a
    stub. Presumably the ScriptModule metaclass compiles the stub later;
    TODO confirm against the metaclass.
    """
    if _enabled:
        # Two frames must be skipped because the ScriptModule metaclass frame
        # sits between this function and the user's scope, unlike calling
        # @script on a function or define() on a CompilationUnit:
        #
        #   0. createResolutionCallback()
        #   1. script_method()
        #   2. ScriptModule metaclass frame
        #   3. Surrounding scope
        #
        # createResolutionCallback already adds 1 to reach this function (the
        # calling function); frames_up=2 then lands on the surrounding scope.
        # This call must stay directly in this function's body.
        resolution_cb = _jit_internal.createResolutionCallbackFromFrame(frames_up=2)
        method_def = get_jit_def(fn, fn.__name__, self_name="ScriptModule")
        return ScriptMethodStub(resolution_cb, method_def, fn)
    return fn
def script(obj, optimize=None, _frames_up=0, _rcb=None):
    r"""
    Scripting a function or ``nn.Module`` will inspect the source code, compile
    it as TorchScript code using the TorchScript compiler, and return a :class:`ScriptModule` or
    :class:`ScriptFunction`. TorchScript itself is a subset of the Python language, so not all
    features in Python work, but we provide enough functionality to compute on
    tensors and do control-dependent operations. For a complete guide, see the
    :ref:`language-reference`.

    ``torch.jit.script`` can be used as a function for modules and functions, and as a decorator
    ``@torch.jit.script`` for :ref:`torchscript-classes` and functions.

    Args:
        obj (callable, class, or ``nn.Module``):  The ``nn.Module``, function, or class type to
                                                  compile.

    Returns:
        If ``obj`` is ``nn.Module``, ``script`` returns
        a :class:`ScriptModule` object. The returned :class:`ScriptModule` will
        have the same set of sub-modules and parameters as the
        original ``nn.Module``. If ``obj`` is a standalone function,
        a :class:`ScriptFunction` will be returned.

    **Scripting a function**
        The ``@torch.jit.script`` decorator will construct a :class:`ScriptFunction`
        by compiling the body of the function.

        Example (scripting a function):

        .. testcode::

            import torch

            @torch.jit.script
            def foo(x, y):
                if x.max() > y.max():
                    r = x
                else:
                    r = y
                return r

            print(type(foo))  # torch.jit.ScriptFunction

            # See the compiled graph as Python code
            print(foo.code)

            # Call the function using the TorchScript interpreter
            foo(torch.ones(2, 2), torch.ones(2, 2))

        .. testoutput::
            :hide:

            ...

    **Scripting an nn.Module**
        Scripting an ``nn.Module`` by default will compile the ``forward`` method and recursively
        compile any methods, submodules, and functions called by ``forward``. If a ``nn.Module`` only uses
        features supported in TorchScript, no changes to the original module code should be necessary. ``script``
        will construct :class:`ScriptModule` that has copies of the attributes, parameters, and methods of
        the original module.

        Example (scripting a simple module with a Parameter):

        .. testcode::

            import torch

            class MyModule(torch.nn.Module):
                def __init__(self, N, M):
                    super(MyModule, self).__init__()
                    # This parameter will be copied to the new ScriptModule
                    self.weight = torch.nn.Parameter(torch.rand(N, M))

                    # When this submodule is used, it will be compiled
                    self.linear = torch.nn.Linear(N, M)

                def forward(self, input):
                    output = self.weight.mv(input)

                    # This calls the `forward` method of the `nn.Linear` module, which will
                    # cause the `self.linear` submodule to be compiled to a `ScriptModule` here
                    output = self.linear(output)
                    return output

            scripted_module = torch.jit.script(MyModule(2, 3))

        Example (scripting a module with traced submodules):

        .. testcode::

            import torch
            import torch.nn as nn
            import torch.nn.functional as F

            class MyModule(nn.Module):
                def __init__(self):
                    super(MyModule, self).__init__()
                    # torch.jit.trace produces ScriptModules for conv1 and conv2
                    self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
                    self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))

                def forward(self, input):
                    input = F.relu(self.conv1(input))
                    input = F.relu(self.conv2(input))
                    return input

            scripted_module = torch.jit.script(MyModule())

        To compile a method other than ``forward`` (and recursively compile anything it calls), add
        the :func:`@torch.jit.export <torch.jit.export>` decorator to the method. To opt out of compilation
        use :func:`@torch.jit.ignore <torch.jit.ignore>` or :func:`@torch.jit.unused <torch.jit.unused>`.

        Example (an exported and ignored method in a module)::

            import torch
            import torch.nn as nn

            class MyModule(nn.Module):
                def __init__(self):
                    super(MyModule, self).__init__()

                @torch.jit.export
                def some_entry_point(self, input):
                    return input + 10

                @torch.jit.ignore
                def python_only_fn(self, input):
                    # This function won't be compiled, so any
                    # Python APIs can be used
                    import pdb
                    pdb.set_trace()

                def forward(self, input):
                    if self.training:
                        self.python_only_fn(input)
                    return input * 99

            scripted_module = torch.jit.script(MyModule())
            print(scripted_module.some_entry_point(torch.randn(2, 2)))
            print(scripted_module(torch.randn(2, 2)))
    """
    if not _enabled:
        return obj

    if optimize is not None:
        # stacklevel=2 attributes the warning to the caller of script().
        warnings.warn(
            "`optimize` is deprecated and has no effect. "
            "Use `with torch.jit.optimized_execution()` instead",
            stacklevel=2,
        )

    # Already compiled: return as-is.
    if isinstance(obj, ScriptModule):
        return obj

    if isinstance(obj, torch.nn.Module):
        obj = call_prepare_scriptable_func(obj)
        return torch.jit._recursive.create_script_module(
            obj, torch.jit._recursive.infer_methods_to_compile)

    qualified_name = _qualified_name(obj)
    if inspect.isclass(obj):
        # If this type is a `nn.Module` subclass, they probably meant to pass
        # an instance instead of a Module
        if issubclass(obj, torch.nn.Module):
            raise RuntimeError("Type '{}' cannot be compiled since it inherits"
                               " from nn.Module,"
                               " pass an instance instead".format(obj))
        if not _is_new_style_class(obj):
            raise RuntimeError(
                "TorchScript classes must be new-style classes. "
                "Please inherit from 'object'.")
        if len(obj.mro()) > 2:
            raise RuntimeError(
                "TorchScript classes does not support inheritance yet. "
                "Please directly inherit from 'object'.")
        if _rcb is None:
            # Frame-sensitive: must be called directly from this function so
            # that `_frames_up + 1` resolves names in script()'s caller.
            _rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up + 1)
        _compile_and_register_class(obj, _rcb, qualified_name)
        return obj
    else:
        # This is a decorated fn; we need to get the underlying fn and its rcb.
        if hasattr(obj, "__script_if_tracing_wrapper"):
            obj = obj.__original_fn
            _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)

        # Presumably rejects functions that have @overload declarations
        # registered (helper defined elsewhere) — TODO confirm.
        _check_directly_compile_overloaded(obj)
        # Reuse a previously compiled function for `obj` if one was cached
        # (the cache is populated at the end of this branch).
        maybe_already_compiled_fn = _try_get_jit_cached_function(obj)
        if maybe_already_compiled_fn:
            return maybe_already_compiled_fn
        ast = get_jit_def(obj, obj.__name__)
        if _rcb is None:
            _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
        fn = torch._C._jit_script_compile(qualified_name, ast, _rcb, get_default_args(obj))
        # Forward docstrings
        fn.__doc__ = obj.__doc__
        _set_jit_function_cache(obj, fn)
        return fn
def make_stub(func, name):
    """Build a ``ScriptMethodStub`` for ``func`` as a method of a RecursiveScriptModule.

    Args:
        func: the Python callable whose source is parsed.
        name: the name under which the definition is parsed.

    Returns:
        A ``ScriptMethodStub`` holding a resolution callback derived from
        ``func``'s closure, the parsed definition (``self_name`` set to
        ``"RecursiveScriptModule"``), and ``func`` itself.
    """
    closure_cb = _jit_internal.createResolutionCallbackFromClosure(func)
    parsed_def = get_jit_def(func, name, self_name="RecursiveScriptModule")
    stub = ScriptMethodStub(closure_cb, parsed_def, func)
    return stub