def _recursive_compile_class(obj, loc):
    """Compile and register the class ``obj`` under its qualified name.

    Arguments:
        obj: The Python class to compile.
        loc: Location handed to ``torch._C.CallStack`` together with the
            class's qualified name.
    """
    class_qualname = _qualified_name(obj)
    # A new compilation starts here, so the error call stack is updated in
    # case it fails. The object is never read afterwards, but the binding is
    # kept on purpose -- presumably it must stay alive while compilation runs.
    call_stack_guard = torch._C.CallStack(class_qualname, loc)
    resolver = _jit_internal.createResolutionCallbackForClassMethods(obj)
    _compile_and_register_class(obj, resolver, class_qualname)
def interface(obj):
    """Compile the class ``obj`` as a TorchScript interface.

    Arguments:
        obj: The class describing the interface. It must inherit directly
            from ``object`` or directly from ``nn.Module``.

    Returns:
        ``obj`` itself, with ``__torch_script_interface__`` set to ``True``.

    Raises:
        RuntimeError: If ``obj`` is not a class, is not a new-style class, or
            uses inheritance beyond ``object`` / ``nn.Module``.
    """
    if not inspect.isclass(obj):
        raise RuntimeError("interface must be applied to a class")
    if not _is_new_style_class(obj):
        raise RuntimeError("TorchScript interfaces must inherit from 'object'")

    # A module interface is expected to have exactly this MRO:
    #   User module
    #   torch.nn.modules.module.Module
    #   object
    mro_length = len(obj.mro())
    is_module_interface = issubclass(obj, torch.nn.Module) and mro_length == 3

    if mro_length > 2 and not is_module_interface:
        raise RuntimeError(
            "TorchScript interface does not support inheritance yet. "
            "Please directly inherit from 'object' or 'nn.Module'."
        )

    qualified_name = _qualified_name(obj)
    # The callback is frame-relative, so it has to be created directly in this
    # function body (one frame up is whoever called `interface`).
    resolver = _jit_internal.createResolutionCallbackFromFrame(1)

    # For an `nn.Module` subclass a module interface type is generated instead
    # of a class interface type; a module interface type only compiles the
    # user-provided methods as part of the interface.
    class_def = get_jit_class_def(obj, obj.__name__)
    torch._C._jit_script_interface_compile(
        qualified_name, class_def, resolver, is_module_interface
    )
    obj.__torch_script_interface__ = True
    return obj
def _check_directly_compile_overloaded(obj):
    """Reject direct compilation of a function that has overloads.

    Arguments:
        obj: The function about to be compiled.

    Raises:
        RuntimeError: If overloads are registered under the function's
            qualified name, or cached compiled overloads exist for it.
    """
    qual_name = _qualified_name(obj)
    # `or` keeps the short-circuit: the cache is only consulted when no
    # uncompiled overloads are registered.
    has_overloads = (
        _jit_internal._get_fn_overloads(qual_name)
        or _try_get_jit_cached_overloads(obj)
    )
    if not has_overloads:
        return

    message = (
        "Function {} cannot be directly compiled because it"
        " is overloaded. It must be used in a context of a function"
        " where its inputs can determine which overload to call."
    ).format(qual_name)
    raise RuntimeError(message)
def _check_no_signature(func):
    """Ensure an overloaded function carries explicit type annotations.

    Arguments:
        func: The overloaded function to inspect.

    Raises:
        RuntimeError: If no signature can be obtained for ``func``.
    """
    source_range = fake_range()
    is_method = inspect.ismethod(func)
    signature = torch.jit.annotations.get_signature(
        func, None, source_range, is_method
    )
    if signature is not None:
        return

    qual_name = _jit_internal._qualified_name(func)
    raise RuntimeError(
        "Must explicitly add type annotations to overloaded functions: {}".format(
            qual_name
        )
    )
def test_sc(obj, optimize=None, _frames_up=0, _rcb=None):
    """Debugging variant of scripting that prints every intermediate artifact.

    Classes are validated, compiled and registered. Functions are compiled and
    the AST, the resolution callback and the compiled function's ``code``,
    ``schema``, ``graph`` and ``name`` are printed along the way. No overload
    check is made and no function cache is read or written here.

    Arguments:
        obj: The class or function to compile.
        optimize: Accepted but not used by this function.
        _frames_up: Extra number of frames to walk up when building the
            resolution callback for a class.
        _rcb: Optional pre-built resolution callback.

    Returns:
        ``obj`` for a class, otherwise the compiled function.

    Raises:
        RuntimeError: If ``obj`` is an ``nn.Module`` subclass, is not a
            new-style class, or uses inheritance.
    """
    qualified_name = _qualified_name(obj)

    if inspect.isclass(obj):
        # An `nn.Module` subclass here most likely means the caller meant to
        # pass an instance rather than the type.
        if issubclass(obj, torch.nn.Module):
            raise RuntimeError("Type '{}' cannot be compiled since it inherits"
                               " from nn.Module,"
                               " pass an instance instead".format(obj))
        if not _is_new_style_class(obj):
            raise RuntimeError("TorchScript classes must be new-style classes. "
                               "Please inherit from 'object'.")
        if len(obj.mro()) > 2:
            raise RuntimeError("TorchScript classes does not support inheritance yet. "
                               "Please directly inherit from 'object'.")
        if _rcb is None:
            # Frame-relative: must be created directly in this function body.
            _rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up + 1)
        _compile_and_register_class(obj, _rcb, qualified_name)
        return obj

    fn_def = get_jit_def(obj, obj.__name__)
    print("---ast---")
    print(fn_def)

    if _rcb is None:
        _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
    print("---rcb---")
    print(_rcb)

    compiled = torch._C._jit_script_compile(
        qualified_name, fn_def, _rcb, get_default_args(obj)
    )
    # Forward docstrings
    compiled.__doc__ = obj.__doc__

    print("---scripted_fn---")
    print(compiled)
    # Attributes are read lazily, one per iteration, in the order they are dumped.
    for attr_name in ("code", "schema", "graph", "name"):
        print("---scripted_fn.{}---".format(attr_name))
        print(getattr(compiled, attr_name))
    return compiled
def _get_overloads(obj):
    """Return the compiled overloads of ``obj``, compiling pending ones first.

    Arguments:
        obj: The overloaded function.

    Returns:
        The cached compiled overloads when nothing is pending (this may be
        whatever the cache lookup returned), otherwise a list made of the
        cached overloads followed by the newly compiled ones.
    """
    # Look for overloads that were compiled (and cached) earlier.
    cached = _try_get_jit_cached_overloads(obj)
    qual_name = _qualified_name(obj)
    pending = _jit_internal._get_fn_overloads(qual_name)
    if pending is None:
        return cached

    newly_compiled = [
        _compile_function_with_overload(overload_fn, qual_name, obj)
        for overload_fn in pending
    ]
    all_compiled = cached + newly_compiled if cached else newly_compiled

    # Cache the result and drop the information that was only stored to
    # drive this compilation.
    _set_jit_overload_cache(obj, all_compiled)
    _jit_internal._clear_fn_overloads(qual_name)
    return all_compiled
def create_script_class(obj):
    """
    Create and return a RecursiveScriptClass instance from a Python object.

    The type of ``obj`` is compiled and registered, an empty script object of
    that type is created, and every entry of ``obj.__dict__`` is copied onto it.

    Arguments:
        obj: A Python object.

    Returns:
        The result of wrapping the populated script object with
        ``wrap_cpp_class``.
    """
    py_type = type(obj)
    qualified_class_name = _jit_internal._qualified_name(py_type)
    resolver = _jit_internal.createResolutionCallbackForClassMethods(py_type)

    # Script the type of obj if it hasn't already been scripted.
    _compile_and_register_class(py_type, resolver, qualified_class_name)
    scripted_type = _python_cu.get_class(qualified_class_name)

    # Start from an empty torch._C.ScriptObject of the scripted type ...
    script_object = torch._C._create_object_with_type(scripted_type)
    # ... and mirror every instance attribute of the Python object onto it.
    for attr_name, attr_value in obj.__dict__.items():
        script_object.setattr(attr_name, attr_value)

    # Wrap the torch._C.ScriptObject in a RecursiveScriptClass instance.
    return wrap_cpp_class(script_object)
def _script_pdt(obj, optimize=None, _frames_up=0, _rcb=None,
                example_inputs: Optional[List[Tuple]] = None):
    """Script ``obj`` after profiling it on example inputs (profile-directed typing).

    This is a private API, intended for internal use only. Usage of this API
    is for experimental purposes only and is highly discouraged.

    When MonkeyType is available, ``obj`` is first run eagerly on every entry
    of ``example_inputs`` under a MonkeyType trace that logs into the global
    ``type_trace_db``. The object is then handed to :func:`script`.

    Arguments:
        obj: The callable, class or ``nn.Module`` to script.
        optimize: Deprecated; has no effect. A warning is emitted when it is
            not ``None``.
        _frames_up: Extra number of frames to walk up when resolving names,
            counted from the caller of this function.
        _rcb: Optional pre-built resolution callback, forwarded to ``script``.
        example_inputs: List of argument tuples; ``obj`` is called once with
            each of them. ``None`` (the default) means there is nothing to
            profile.

    Returns:
        ``obj`` unchanged when scripting is disabled or ``obj`` is already a
        ``ScriptModule`` / ``ScriptFunction``; otherwise whatever ``script``
        returns.
    """
    global type_trace_db
    if not _enabled:
        return obj

    if optimize is not None:
        warnings.warn(
            "`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead"
        )

    # No-op for modules and functions that are already scripted
    if isinstance(obj, (ScriptModule, ScriptFunction)):
        return obj

    # If MonkeyType is installed, enable profile directed type annotation:
    # generate call traces by running the eager version of `obj` on the
    # provided example inputs. All traces are logged in `type_trace_db`.
    type_trace_db = JitTypeTraceStore()
    if monkeytype_trace:
        monkeytype_config = JitTypeTraceConfig(type_trace_db)
        with monkeytype_trace(monkeytype_config):
            # `example_inputs` defaults to None; iterating it unconditionally
            # raised TypeError for callers that did not pass any inputs.
            if example_inputs is not None:
                for example_input in example_inputs:
                    obj(*example_input)
    else:
        warnings.warn(
            "Warning: monkeytype is not installed. Please install https://github.com/Instagram/MonkeyType "
            "to enable Profile-Directed Typing in TorchScript. Refer to "
            "https://github.com/Instagram/MonkeyType/blob/master/README.rst to install MonkeyType. "
        )

    # `script` builds its class resolution callback `_frames_up + 1` frames
    # above itself. Going through this wrapper adds one frame, so bump the
    # count to keep resolving names in the frame of *our* caller.
    return script(obj, optimize, _frames_up + 1, _rcb)
def script(obj, optimize=None, _frames_up=0, _rcb=None):
    r"""
    Scripting a function or ``nn.Module`` will inspect the source code, compile
    it as TorchScript code using the TorchScript compiler, and return a :class:`ScriptModule` or
    :class:`ScriptFunction`. TorchScript itself is a subset of the Python language, so not all
    features in Python work, but we provide enough functionality to compute on
    tensors and do control-dependent operations. For a complete guide, see the
    :ref:`language-reference`.

    ``torch.jit.script`` can be used as a function for modules and functions, and as a decorator
    ``@torch.jit.script`` for :ref:`torchscript-classes` and functions.

    Args:
        obj (callable, class, or ``nn.Module``):  The ``nn.Module``, function, or class type to
                                                  compile.

    Returns:
        If ``obj`` is ``nn.Module``, ``script`` returns
        a :class:`ScriptModule` object. The returned :class:`ScriptModule` will
        have the same set of sub-modules and parameters as the
        original ``nn.Module``. If ``obj`` is a standalone function,
        a :class:`ScriptFunction` will be returned.

    **Scripting a function**
        The ``@torch.jit.script`` decorator will construct a :class:`ScriptFunction`
        by compiling the body of the function.

        Example (scripting a function):

        .. testcode::

            import torch

            @torch.jit.script
            def foo(x, y):
                if x.max() > y.max():
                    r = x
                else:
                    r = y
                return r

            print(type(foo))  # torch.jit.ScriptFunction

            # See the compiled graph as Python code
            print(foo.code)

            # Call the function using the TorchScript interpreter
            foo(torch.ones(2, 2), torch.ones(2, 2))

        .. testoutput::
            :hide:

            ...

    **Scripting an nn.Module**
        Scripting an ``nn.Module`` by default will compile the ``forward`` method and recursively
        compile any methods, submodules, and functions called by ``forward``. If a ``nn.Module`` only uses
        features supported in TorchScript, no changes to the original module code should be necessary. ``script``
        will construct :class:`ScriptModule` that has copies of the attributes, parameters, and methods of
        the original module.

        Example (scripting a simple module with a Parameter):

        .. testcode::

            import torch

            class MyModule(torch.nn.Module):
                def __init__(self, N, M):
                    super(MyModule, self).__init__()
                    # This parameter will be copied to the new ScriptModule
                    self.weight = torch.nn.Parameter(torch.rand(N, M))

                    # When this submodule is used, it will be compiled
                    self.linear = torch.nn.Linear(N, M)

                def forward(self, input):
                    output = self.weight.mv(input)

                    # This calls the `forward` method of the `nn.Linear` module, which will
                    # cause the `self.linear` submodule to be compiled to a `ScriptModule` here
                    output = self.linear(output)
                    return output

            scripted_module = torch.jit.script(MyModule(2, 3))

        Example (scripting a module with traced submodules):

        .. testcode::

            import torch
            import torch.nn as nn
            import torch.nn.functional as F

            class MyModule(nn.Module):
                def __init__(self):
                    super(MyModule, self).__init__()
                    # torch.jit.trace produces a ScriptModule's conv1 and conv2
                    self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
                    self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))

                def forward(self, input):
                    input = F.relu(self.conv1(input))
                    input = F.relu(self.conv2(input))
                    return input

            scripted_module = torch.jit.script(MyModule())

        To compile a method other than ``forward`` (and recursively compile anything it calls), add
        the :func:`@torch.jit.export <torch.jit.export>` decorator to the method. To opt out of compilation
        use :func:`@torch.jit.ignore <torch.jit.ignore>` or :func:`@torch.jit.unused <torch.jit.unused>`.

        Example (an exported and ignored method in a module)::

            import torch
            import torch.nn as nn

            class MyModule(nn.Module):
                def __init__(self):
                    super(MyModule, self).__init__()

                @torch.jit.export
                def some_entry_point(self, input):
                    return input + 10

                @torch.jit.ignore
                def python_only_fn(self, input):
                    # This function won't be compiled, so any
                    # Python APIs can be used
                    import pdb
                    pdb.set_trace()

                def forward(self, input):
                    if self.training:
                        self.python_only_fn(input)
                    return input * 99

            scripted_module = torch.jit.script(MyModule())
            print(scripted_module.some_entry_point(torch.randn(2, 2)))
            print(scripted_module(torch.randn(2, 2)))
    """
    if not _enabled:
        return obj

    if optimize is not None:
        warnings.warn(
            "`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead"
        )

    # Modules that are already scripted are handed back untouched.
    if isinstance(obj, ScriptModule):
        return obj

    # Module instances go through the recursive module compiler.
    if isinstance(obj, torch.nn.Module):
        prepared_module = call_prepare_scriptable_func(obj)
        return torch.jit._recursive.create_script_module(
            prepared_module, torch.jit._recursive.infer_methods_to_compile
        )

    qualified_name = _qualified_name(obj)

    # ---- class types ----
    if inspect.isclass(obj):
        # An `nn.Module` subclass here most likely means the caller meant to
        # pass an instance rather than the type.
        if issubclass(obj, torch.nn.Module):
            raise RuntimeError("Type '{}' cannot be compiled since it inherits"
                               " from nn.Module,"
                               " pass an instance instead".format(obj))
        if not _is_new_style_class(obj):
            raise RuntimeError(
                "TorchScript classes must be new-style classes. "
                "Please inherit from 'object'.")
        if len(obj.mro()) > 2:
            raise RuntimeError(
                "TorchScript classes does not support inheritance yet. "
                "Please directly inherit from 'object'.")
        if _rcb is None:
            # Frame-relative: must be created directly in this function body.
            _rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up + 1)
        _compile_and_register_class(obj, _rcb, qualified_name)
        return obj

    # ---- functions ----
    # This is a decorated fn, and we need to get the underlying fn and its rcb.
    if hasattr(obj, "__script_if_tracing_wrapper"):
        obj = obj.__original_fn
        _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)

    _check_directly_compile_overloaded(obj)

    cached_fn = _try_get_jit_cached_function(obj)
    if cached_fn:
        return cached_fn

    fn_def = get_jit_def(obj, obj.__name__)
    if _rcb is None:
        _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
    compiled_fn = torch._C._jit_script_compile(
        qualified_name, fn_def, _rcb, get_default_args(obj)
    )
    # Forward docstrings
    compiled_fn.__doc__ = obj.__doc__
    _set_jit_function_cache(obj, compiled_fn)
    return compiled_fn
def rpc_async(to, func, args=None, kwargs=None):
    r"""
    Make a non-blocking RPC call to run function ``func`` on worker ``to``. RPC
    messages are sent and received in parallel to execution of Python code. This
    method is thread-safe. This method will immediately return a Future that can
    be awaited on.

    Arguments:
        to (str or WorkerInfo): id or name of the destination worker.
        func (callable): any callable function. Python callables, builtins or
                         annotated TorchScript functions (like
                         :meth:`torch.add`) can be sent over RPC more
                         efficiently.
        args (tuple): the argument tuple for the ``func`` invocation.
        kwargs (dict): is a dictionary of keyword arguments for the ``func``
                       invocation.

    Returns:
        Returns a Future object that can be waited on. When completed, the
        return value of ``func`` on ``args`` and ``kwargs`` can be retrieved
        from the Future object.

    Example::
        Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
        API for more details. For example,

        >>> export MASTER_ADDR=localhost
        >>> export MASTER_PORT=5678

        Then run the following code in two different processes:

        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> fut1 = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3))
        >>> fut2 = rpc.rpc_async("worker1", min, args=(1, 2))
        >>> result = fut1.wait() + fut2.wait()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()

        If invoking an annotated TorchScript function, then run the following
        code in two different processes:

        >>> # On worker 0:
        >>> import torch
        >>> @torch.jit.script
        ... def my_script_add(t1, t2):
        ...     return torch.add(t1, t2)
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> fut = rpc.rpc_async("worker1", my_script_add, args=(torch.ones(2), 3))
        >>> ret = fut.wait()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()
    """
    # Annotated TorchScript functions are dispatched by qualified name through
    # the internal API _rpc_async_torchscript().
    if isinstance(func, torch.jit.ScriptFunction):
        return _rpc_async_torchscript(to, _qualified_name(func), args, kwargs)

    # Everything else goes through the generic invocation path.
    return _invoke_rpc(to, func, RPCExecMode.ASYNC, args, kwargs)
def trace(
    func,
    example_inputs,
    optimize=None,
    check_trace=True,
    check_inputs=None,
    check_tolerance=1e-5,
    strict=True,
    _force_outplace=False,
    _module_class=None,
    _compilation_unit=_python_cu,
):
    """
    Trace a function and return an executable or :class:`ScriptFunction`
    that will be optimized using just-in-time compilation. Tracing is ideal for
    code that operates only on ``Tensor``\\s and lists, dictionaries, and
    tuples of ``Tensor``\\s.

    Using `torch.jit.trace` and `torch.jit.trace_module`, you can turn an
    existing module or Python function into a TorchScript
    :class:`ScriptFunction` or :class:`ScriptModule`. You must provide example
    inputs, and we run the function, recording the operations performed on all
    the tensors.

    * The resulting recording of a standalone function produces `ScriptFunction`.
    * The resulting recording of `nn.Module.forward` or `nn.Module` produces
      `ScriptModule`.

    This module also contains any parameters that the original
    module had as well.

    Warning:
        Tracing only correctly records functions and modules which are not data
        dependent (e.g., do not have conditionals on data in tensors) and do not have
        any untracked external dependencies (e.g., perform input/output or
        access global variables). Tracing only records operations done when the given
        function is run on the given tensors. Therefore, the returned
        `ScriptModule` will always run the same traced graph on any input. This
        has some important implications when your module is expected to run
        different sets of operations, depending on the input and/or the module
        state. For example,

        * Tracing will not record any control-flow like if-statements or loops.
          When this control-flow is constant across your module, this is fine
          and it often inlines the control-flow decisions. But sometimes the
          control-flow is actually part of the model itself. For instance, a
          recurrent network is a loop over the (possibly dynamic) length of an
          input sequence.
        * In the returned :class:`ScriptModule`, operations that have different
          behaviors in ``training`` and ``eval`` modes will always behave as if
          it is in the mode it was in during tracing, no matter which mode the
          `ScriptModule` is in.

        In cases like these, tracing would not be appropriate and
        :func:`scripting <torch.jit.script>` is a better choice. If you trace
        such models, you may silently get incorrect results on subsequent
        invocations of the model. The tracer will try to emit warnings when
        doing something that may cause an incorrect trace to be produced.

    Arguments:
        func (callable or torch.nn.Module):  A Python function or `torch.nn.Module`
            that will be run with `example_inputs`. `func` arguments and return
            values must be tensors or (possibly nested) tuples that contain
            tensors. When a module is passed `torch.jit.trace`, only the
            ``forward`` method is run and traced (see
            :func:`torch.jit.trace_module <torch.jit.trace_module>` for details).
        example_inputs (tuple or torch.Tensor):  A tuple of example inputs that
            will be passed to the function while tracing. The resulting trace
            can be run with inputs of different types and shapes assuming the
            traced operations support those types and shapes. `example_inputs`
            may also be a single Tensor in which case it is automatically
            wrapped in a tuple.

    Keyword arguments:
        optimize: Deprecated; has no effect. A warning is emitted when it is
            not ``None``.
        check_trace (``bool``, optional): Check if the same inputs run through
            traced code produce the same outputs. Default: ``True``. You might want
            to disable this if, for example, your network contains non-
            deterministic ops or if you are sure that the network is correct despite
            a checker failure.
        check_inputs (list of tuples, optional): A list of tuples of input
            arguments that should be used to check the trace against what is
            expected. Each tuple is equivalent to a set of input arguments that
            would be specified in ``example_inputs``. For best results, pass in
            a set of checking inputs representative of the space of shapes and
            types of inputs you expect the network to see. If not specified,
            the original ``example_inputs`` are used for checking.
        check_tolerance (float, optional): Floating-point comparison tolerance
            to use in the checker procedure. This can be used to relax the
            checker strictness in the event that results diverge numerically
            for a known reason, such as operator fusion.
        strict (``bool``, optional): run the tracer in a strict mode or not
            (default: ``True``). Only turn this off when you want the tracer to
            record your mutable container types (currently ``list``/``dict``)
            and you are sure that the container you are using in your
            problem is a ``constant`` structure and does not get used as
            control flow (if, for) conditions.

    Returns:
        If `func` is `nn.Module` or ``forward`` of `nn.Module`, `trace` returns
        a :class:`ScriptModule` object with a single ``forward`` method
        containing the traced code. The returned `ScriptModule` will
        have the same set of sub-modules and parameters as the original
        ``nn.Module``. If ``func`` is a standalone function, ``trace``
        returns `ScriptFunction`.

    Example (tracing a function):

    .. testcode::

        import torch

        def foo(x, y):
            return 2 * x + y

        # Run `foo` with the provided inputs and record the tensor operations
        traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

        # `traced_foo` can now be run with the TorchScript interpreter or saved
        # and loaded in a Python-free environment

    Example (tracing an existing module)::

        import torch
        import torch.nn as nn

        class Net(nn.Module):
            def __init__(self):
                super(Net, self).__init__()
                self.conv = nn.Conv2d(1, 1, 3)

            def forward(self, x):
                return self.conv(x)

        n = Net()
        example_weight = torch.rand(1, 1, 3, 3)
        example_forward_input = torch.rand(1, 1, 3, 3)

        # Trace a specific method and construct `ScriptModule` with
        # a single `forward` method
        module = torch.jit.trace(n.forward, example_forward_input)

        # Trace a module (implicitly traces `forward`) and construct a
        # `ScriptModule` with a single `forward` method
        module = torch.jit.trace(n, example_forward_input)
    """
    if not _enabled:
        return func

    if optimize is not None:
        warnings.warn(
            "`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead"
        )

    if isinstance(func, torch.jit.ScriptModule):
        # It is hard to trace it because the forward method on ScriptModule is
        # already defined, so it would result in an error.
        warnings.warn(
            "The input to trace is already a ScriptModule, tracing it is a no-op. Returning the object as is."
        )
        return func

    # Work out whether this is really a module trace: either a module was
    # passed directly, or the bound `forward` method of one.
    module_to_trace = None
    bound_owner = None
    if isinstance(func, torch.nn.Module):
        module_to_trace = func
    else:
        bound_owner = getattr(func, "__self__", None)
        if isinstance(bound_owner, torch.nn.Module) and func.__name__ == "forward":
            module_to_trace = bound_owner

    if module_to_trace is not None:
        return trace_module(
            module_to_trace,
            {"forward": example_inputs},
            None,
            check_trace,
            wrap_check_inputs(check_inputs),
            check_tolerance,
            strict,
            _force_outplace,
            _module_class,
        )

    # Special case for common case of passing a single Tensor
    if isinstance(example_inputs, (torch.Tensor, dict)):
        example_inputs = (example_inputs,)
    # done primarily so that weird iterables fail here and not pybind11 code
    elif not isinstance(example_inputs, tuple):
        example_inputs = tuple(example_inputs)

    # Created directly in this function body, ahead of the bound-method check
    # below, so the order of possible failures is unchanged.
    var_lookup_fn = _create_interpreter_name_lookup_fn(0)

    # A bound method of a module other than `forward` ends up here.
    if isinstance(bound_owner, torch.nn.Module):
        raise AttributeError(
            "trace doesn't support compiling individual module's functions.\n"
            "Please use trace_module"
        )

    name = _qualified_name(func)
    traced = torch._C._create_function_from_trace(
        name, func, example_inputs, var_lookup_fn, strict, _force_outplace
    )

    # Check the trace against new traces created from user-specified inputs,
    # falling back to the example inputs when none were given.
    if check_trace:
        inputs_for_check = [example_inputs] if check_inputs is None else check_inputs
        _check_trace(
            inputs_for_check,
            func,
            traced,
            check_tolerance,
            strict,
            _force_outplace,
            False,
            _module_class,
        )

    return traced