Example #1
from torch.jit._state import (  # snippet starts mid-import; these helpers live in torch.jit._state
    _enabled,
    _try_get_jit_cached_function,
    _try_get_jit_cached_overloads,
    _set_jit_function_cache,
    _set_jit_overload_cache,
)
from torch.overrides import (has_torch_function, has_torch_function_unary,
                             has_torch_function_variadic)

torch._C.ScriptMethod.graph_for = _graph_for  # type: ignore
torch._C.ScriptFunction.graph_for = _graph_for  # type: ignore
ScriptFunction = torch._C.ScriptFunction
ScriptFunction.__doc__ = """
Functionally equivalent to a :class:`ScriptModule`, but represents a single
function and does not have any attributes or Parameters.
"""
set_module(ScriptFunction, "torch.jit")

if _enabled:
    Attribute = collections.namedtuple("Attribute", ["value", "type"])
else:

    def Attribute(value, type):  # type: ignore
        return value


# ScriptClasses must be new-style classes because we construct them using their
# __new__ method.
def _is_new_style_class(cls):
    if hasattr(cls, "__class__"):
        return "__dict__" in dir(cls) or hasattr(cls, "__slots__")
Example #2
"""Async API
This module contains the API for parallelism in TorchScript, notably:
    * torch.jit.fork
    * torch.jit.wait

This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""

import torch

from torch.utils import set_module
from torch.jit._builtins import _register_builtin
from torch._jit_internal import Future

set_module(Future, "torch.jit")


def fork(func, *args, **kwargs):
    r"""
    Creates an asynchronous task executing `func` and a reference to the value
    of the result of this execution. `fork` will return immediately,
    so the return value of `func` may not have been computed yet. To force completion
    of the task and access the return value invoke `torch.jit.wait` on the Future. `fork` invoked
    with a `func` which returns `T` is typed as `torch.jit.Future[T]`. `fork` calls can be arbitrarily
    nested, and may be invoked with positional and keyword arguments.
    Asynchronous execution will only occur when run in TorchScript. If run in pure python,
    `fork` will not execute in parallel. `fork` will also not execute in parallel when invoked
    while tracing, however the `fork` and `wait` calls will be captured in the exported IR Graph.
    Warning:
        `fork` tasks will execute non-deterministically. We recommend only spawning
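
The snippet cuts off above; as a hedged sketch of the behavior the docstring describes (names are my own, not from the snippet):

import torch

def add_one(x: torch.Tensor) -> torch.Tensor:
    return x + 1

@torch.jit.script
def run(x: torch.Tensor) -> torch.Tensor:
    # fork returns a Future[Tensor] immediately; the forked task may run in
    # parallel with the code below only when executed as TorchScript.
    fut = torch.jit.fork(add_one, x)
    y = add_one(x)
    return y + torch.jit.wait(fut)

print(run(torch.ones(2)))  # tensor([4., 4.])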
Example #3
from torch.jit._state import (  # snippet starts mid-import; these helpers live in torch.jit._state
    _enabled,
    _try_get_jit_cached_function,
    _try_get_jit_cached_overloads,
    _set_jit_function_cache,
    _set_jit_overload_cache,
)
from torch.overrides import (has_torch_function, has_torch_function_unary,
                             has_torch_function_variadic)

torch._C.ScriptMethod.graph_for = _graph_for  # type: ignore
torch._C.ScriptFunction.graph_for = _graph_for  # type: ignore
ScriptFunction = torch._C.ScriptFunction
ScriptFunction.__doc__ = """
Functionally equivalent to a :class:`ScriptModule`, but represents a single
function and does not have any attributes or Parameters.
"""
set_module(ScriptFunction, "torch.jit")

if _enabled:
    Attribute = collections.namedtuple("Attribute", ["value", "type"])
else:

    def Attribute(value, type):  # type: ignore
        return value


Attribute.__doc__ = """
    This method is a pass-through function that returns `value`, mostly
    used to indicate to the TorchScript compiler that the left-hand side
    expression is a class instance attribute with type of `type`. Note that
    `torch.jit.Attribute` should only be used in `__init__` method of `nn.Module`
    subclasses.
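
A minimal usage sketch of `torch.jit.Attribute` (my own example, using the `ScriptModule`/`script_method` API that appears elsewhere in these snippets):

import torch
from typing import List

class M(torch.jit.ScriptModule):
    def __init__(self):
        super().__init__()
        # Without the Attribute wrapper, the empty list would be inferred
        # as List[Tensor] rather than List[str].
        self.names = torch.jit.Attribute([], List[str])

    @torch.jit.script_method
    def forward(self) -> int:
        return len(self.names)

m = M()
print(m())  # 0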
Example #4
_wait = wait


def export_opnames(m):
    r"""
        Generates new bytecode for a Script module and returns what the op list
        would be for a Script Module based off the current code base. If you
        have a LiteScriptModule and want to get the currently present
        list of ops call _export_operator_list instead.
    """
    return torch._C._export_opnames(m._c)
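

# Minimal usage sketch (hypothetical `MyModule`; kept as comments so this
# module-level snippet stays runnable):
#
#     scripted = torch.jit.script(MyModule())
#     ops = torch.jit.export_opnames(scripted)  # e.g. ["aten::add", ...]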


# torch.jit.Error
Error = torch._C.JITException
set_module(Error, "torch.jit")
# This is not perfect but works in common cases
Error.__name__ = "Error"
Error.__qualname__ = "Error"


# for use in python if using annotate
def annotate(the_type, the_value):
    """
    This method is a pass-through function that returns `the_value`, used to hint TorchScript
    compiler the type of `the_value`. It is a no-op when running outside of TorchScript.

    Though TorchScript can infer correct type for most Python expressions, there are some cases where
    type infernece can be wrong, including:
    - Empty containers like `[]` and `{}`, which TorchScript assumes to be container of `Tensor`s
    - Optional types like `Optional[T]` but assigned a valid value of type `T`, TorchScript would assume
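
A short sketch of `annotate` fixing the empty-container case above (my own example):

import torch
from typing import Dict, List

@torch.jit.script
def index_of(keys: List[str]) -> Dict[str, int]:
    # Without the annotation, `{}` would be inferred as Dict[str, Tensor].
    table = torch.jit.annotate(Dict[str, int], {})
    for i in range(len(keys)):
        table[keys[i]] = i
    return table

print(index_of(["a", "b"]))  # {'a': 0, 'b': 1}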
Example #5
import contextlib
import functools
import os
import pathlib

import torch
from torch.utils import set_module
# These are imported so users can access them from the `torch.jit` module
from torch._jit_internal import Final, Future, _overload, _overload_method
from torch._jit_internal import ignore, export, unused
from torch.jit._script import script, Attribute, ScriptModule, is_scripting, script_method, \
    RecursiveScriptModule, ScriptWarning, interface
from torch.jit._trace import trace, trace_module, TracedModule, TracerWarning, TracingCheckError, \
    is_tracing, ONNXTracedModule, _unique_state_dict, _flatten, TopLevelTracedModule
from torch.jit._async import fork, wait
from torch.jit._serialization import save, load

set_module(Future, "torch.jit")

# For backwards compatibility
_fork = fork
_wait = wait

@contextlib.contextmanager
def optimized_execution(should_optimize):
    """
    A context manager that controls whether the JIT's executor will run
    optimizations before executing a function.
    """
    stored_flag = torch._C._get_graph_executor_optimize()
    torch._C._set_graph_executor_optimize(should_optimize)
    try:
        yield
    finally:
        # Restore the previous optimization setting on exit.
        torch._C._set_graph_executor_optimize(stored_flag)
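
For completeness, a hedged usage sketch of the context manager (my own example):

import torch

@torch.jit.script
def f(x: torch.Tensor) -> torch.Tensor:
    return x + x

# Run once with executor optimizations turned off, e.g. while debugging;
# the previous setting is restored when the block exits.
with torch.jit.optimized_execution(False):
    f(torch.ones(2))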