Example #1
import importlib
import inspect

from torch.onnx import symbolic_registry


def register_quantized_ops(domain: str, version: int):
    # Register all the non-quantized ops
    symbolic_registry.register_version("", version)
    # Register all quantized ops
    module = importlib.import_module("torch.onnx.symbolic_caffe2")
    symbolic_registry._symbolic_versions["caffe2"] = module
    quant_version_ops = inspect.getmembers(
        symbolic_registry._symbolic_versions["caffe2"])
    for op in quant_version_ops:
        if inspect.isfunction(
                op[1]) and not symbolic_registry.is_registered_op(
                    op[0], domain, version):
            aten_q_ops = [
                "relu",
                "_empty_affine_quantized",
                "dequantize",
                "quantize_per_tensor",
                "upsample_nearest2d",
                "avg_pool2d",
                "reshape",
                "slice",
                "cat",
                "max_pool2d",
                "sigmoid",
            ]
            if op[0] in aten_q_ops:
                symbolic_registry.register_op(op[0], op[1], "", version)
            symbolic_registry.register_op(op[0], op[1], domain, version)
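
As a usage note: Example #9 below invokes this helper just before symbolic dispatch. A minimal sketch of that call (opset 9 is the only version these quantized symbolics are registered for, per the comment in Example #9):

import torch.onnx.symbolic_caffe2

# Register Caffe2 quantized-op symbolics before export (opset 9 only).
torch.onnx.symbolic_caffe2.register_quantized_ops("caffe2", 9)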
Example #2
import importlib
from inspect import getmembers, isfunction

import torch.onnx.symbolic_registry as sym_registry


def register_quantized_ops(domain, version):
    # Register all the non-quantized ops
    sym_registry.register_version('', version)
    # Register all quantized ops
    module = importlib.import_module('torch.onnx.symbolic_caffe2')
    sym_registry._symbolic_versions['caffe2'] = module
    quant_version_ops = getmembers(sym_registry._symbolic_versions['caffe2'])
    for op in quant_version_ops:
        if isfunction(op[1]) and not sym_registry.is_registered_op(op[0], domain, version):
            aten_q_ops = ['relu', '_empty_affine_quantized', 'dequantize', 'quantize_per_tensor', 'upsample_nearest2d']
            if op[0] in aten_q_ops:
                sym_registry.register_op(op[0], op[1], '', version)
            sym_registry.register_op(op[0], op[1], domain, version)
Example #3
    def __init__(self, model: Callable, inputs: Any, **opts: Any):
        super().__init__(**opts)

        if self.dynamic_axes is None:
            self.dynamic_axes = {}

        # Load symbolic opset
        assert self.opset_version is not None
        sym_reg.register_version("", self.opset_version)  # type: ignore[no-untyped-call]

        self.original_model = model
        self.inputs = _to_tuple_if_not_sequence(inputs)

        self.attrs: Dict[TorchValueID, ONNXValueID] = {}
        self.node_doc_string: Dict[torch._C.Node, str] = {}
        self.node_scope: Dict[torch._C.Node, str] = {}

        self._convert()
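
The _to_tuple_if_not_sequence helper is not part of this excerpt. A plausible reconstruction, offered as an assumption rather than the library's actual code:

from collections.abc import Sequence
from typing import Any, Tuple


def _to_tuple_if_not_sequence(inputs: Any) -> Tuple[Any, ...]:
    # Hypothetical helper: normalize sequences to tuples and wrap a lone input
    # (e.g. a single tensor) in a 1-tuple so the exporter can iterate uniformly.
    if isinstance(inputs, Sequence) and not isinstance(inputs, (str, bytes)):
        return tuple(inputs)
    return (inputs,)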
Example #4
from inspect import signature, _empty  # type: ignore[attr-defined]
from torch._C import _jit_get_all_schemas, FunctionSchema
from torch.onnx.symbolic_registry import _registry, register_version
from torch.onnx.symbolic_helper import _onnx_main_opset, _onnx_stable_opsets
from typing import Dict, List, Union

for v in _onnx_stable_opsets + [_onnx_main_opset]:
    register_version("", v)


class _TorchSchema:
    def __init__(self, schema: Union[FunctionSchema, str]) -> None:
        if isinstance(schema, FunctionSchema):
            self.name: str = schema.name
            self.overload_name: str = schema.overload_name
            self.arguments: List[str] = [arg.name for arg in schema.arguments]
            self.optional_arguments: List[str] = []
            self.returns: List[str] = [ret.name for ret in schema.returns]
            self.opsets: List[int] = []
        else:
            self.name = schema
            self.overload_name = ""
            self.arguments = []
            self.optional_arguments = []
            self.returns = []
            self.opsets = []

    def __str__(self) -> str:
        s = f"{self.name}.{self.overload_name}("
        s += ", ".join(self.arguments)
        s += ") -> ("
Example #5
def _run_symbolic_function(g,
                           n,
                           inputs,
                           env,
                           operator_export_type=OperatorExportTypes.ONNX):
    # NB: Returning None means the node gets cloned as is into
    # the new graph
    try:
        from torch.onnx.symbolic_helper import _export_onnx_opset_version as opset_version
        import torch.onnx.symbolic_registry as sym_registry

        sym_registry.register_version('', opset_version)

        # See Note [Export inplace]
        # TODO: I think this is not necessary anymore
        if n.kind().endswith('_'):
            ns_op_name = n.kind()[:-1]
        else:
            ns_op_name = n.kind()
        ns, op_name = ns_op_name.split("::")

        if ns == "onnx":
            # Use the original node directly
            return None

        elif ns == "aten":
            is_exportable_aten_op = sym_registry.is_registered_op(
                op_name, '', opset_version)
            is_onnx_aten_export = operator_export_type == OperatorExportTypes.ONNX_ATEN
            is_aten_fallback_export = operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK
            if is_onnx_aten_export or (not is_exportable_aten_op
                                       and is_aten_fallback_export):
                # Direct ATen export requested
                attrs = {
                    k + "_" + n.kindOf(k)[0]: n[k]
                    for k in n.attributeNames()
                }
                outputs = n.outputsSize()
                attrs["outputs"] = outputs
                return _graph_at(g, op_name, *inputs, aten=True, **attrs)

            else:
                # Export it regularly
                attrs = {k: n[k] for k in n.attributeNames()}
                if not is_exportable_aten_op:
                    warnings.warn(
                        "ONNX export failed on ATen operator {} because "
                        "torch.onnx.symbolic_opset{}.{} does not exist".format(
                            op_name, opset_version, op_name))
                op_fn = sym_registry.get_registered_op(op_name, '',
                                                       opset_version)
                return op_fn(g, *inputs, **attrs)

        elif ns == "prim":
            if op_name == "Constant" and not n.mustBeNone():
                if n.kindOf("value") == "t":
                    return g.op("Constant", value_t=n["value"])
                elif n.kindOf("value") == "is":
                    value = torch.stack([torch.tensor(v) for v in n["value"]
                                         ]) if n["value"] else []
                    return g.op("Constant", value_t=value)
                elif n.output().type().kind() == "DeviceObjType":
                    return None
                else:
                    raise RuntimeError(
                        "Unsupported prim::Constant kind: `{}`. Send a bug report."
                        .format(n.kindOf("value")))
            elif (n.mustBeNone() or op_name == "ListConstruct"
                  or op_name == "ListUnpack"):
                # None is not an ONNX operator; keep it as None
                # let the exporter handle finally eliminating these

                # For ListConstruct/ListUnpack, it will be erased in the ONNX peephole pass
                return None
            elif op_name == 'Loop' or op_name == 'If':
                new_op_outputs = g.op(op_name,
                                      *inputs,
                                      outputs=n.outputsSize())
                new_node = (new_op_outputs[0].node()
                            if n.outputsSize() > 1 else new_op_outputs.node())
                for b in n.blocks():
                    new_block = new_node.addBlock()
                    torch._C._jit_pass_onnx_block(b, new_block,
                                                  operator_export_type, env)
                return new_op_outputs
            else:
                # TODO: we should lift prim's symbolic out
                symbolic_name = 'prim_' + op_name
                is_exportable = sym_registry.is_registered_op(
                    symbolic_name, '', opset_version)
                if not is_exportable:
                    warnings.warn(
                        "ONNX export failed on primitive operator {}; please report a bug"
                        .format(op_name))
                symbolic_fn = sym_registry.get_registered_op(
                    symbolic_name, '', opset_version)
                attrs = {k: n[k] for k in n.attributeNames()}
                return symbolic_fn(g, *inputs, **attrs)

        # custom ops
        elif sym_registry.is_registered_version(ns, opset_version):
            if not sym_registry.is_registered_op(op_name, ns, opset_version):
                warnings.warn(
                    "ONNX export failed on custom operator {}::{} because "
                    "torch.onnx.symbolic_opset{}.{} does not exist. "
                    "Have you registered your symbolic function with "
                    "torch.onnx.register_custom_op_symbolic(symbolic_name, symbolic_fn)?"
                    .format(ns, op_name, opset_version, op_name))
            symbolic_fn = sym_registry.get_registered_op(
                op_name, ns, opset_version)
            attrs = {k: n[k] for k in n.attributeNames()}
            return symbolic_fn(g, *inputs, **attrs)

        else:
            warnings.warn(
                "ONNX export failed on an operator with unrecognized namespace {}::{}; "
                "If you are trying to export a custom operator, make sure you registered "
                "it with the right domain and version."
                "Otherwise please report a bug".format(ns, op_name))
            return None

    except TypeError as e:
        # Handle the specific case where we didn't successfully dispatch.
        # Otherwise, the backtrace will have the clues you need.
        e.args = ("{} (occurred when translating {})".format(
            e.args[0], op_name), )
        raise
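
The warning above refers to torch.onnx.register_custom_op_symbolic. A minimal sketch of registering a symbolic so this custom-op branch can find it; the namespace and op name here are invented for illustration:

import torch.onnx


# Hypothetical custom op "mydomain::my_relu"; the symbolic lowers it to ONNX Relu.
def my_relu_symbolic(g, input):
    return g.op("Relu", input)


torch.onnx.register_custom_op_symbolic("mydomain::my_relu", my_relu_symbolic, 9)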
Example #6
    def setup(self, model, opset_ver):
        self.model = model
        # dry run to register every aten op
        sym_reg.register_version('', opset_ver)
        self.opset_ver = opset_ver
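
For reference, a small sketch of the registry round trip these examples rely on; the op name "relu" is chosen arbitrarily:

import torch.onnx.symbolic_registry as sym_reg

# Populate the default ("") domain for opset 11, then look up a symbolic.
sym_reg.register_version("", 11)
if sym_reg.is_registered_op("relu", "", 11):
    relu_symbolic = sym_reg.get_registered_op("relu", "", 11)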
Example #7
import inspect
from typing import Dict, List, Union

import torch._C
from torch.onnx import _constants, symbolic_registry

for v in _constants.onnx_stable_opsets:
    symbolic_registry.register_version("", v)
symbolic_registry.register_version("", _constants.onnx_main_opset)


class _TorchSchema:
    def __init__(self, schema: Union[torch._C.FunctionSchema, str]) -> None:
        if isinstance(schema, torch._C.FunctionSchema):
            self.name: str = schema.name
            self.overload_name: str = schema.overload_name
            self.arguments: List[str] = [arg.name for arg in schema.arguments]
            self.optional_arguments: List[str] = []
            self.returns: List[str] = [ret.name for ret in schema.returns]
            self.opsets: List[int] = []
        else:
            self.name = schema
            self.overload_name = ""
            self.arguments = []
            self.optional_arguments = []
            self.returns = []
            self.opsets = []

    def __str__(self) -> str:
        s = f"{self.name}.{self.overload_name}("
        s += ", ".join(self.arguments)
Example #8
import inspect
from typing import Dict, List, Union

import torch._C
from torch.onnx import symbolic_helper, symbolic_registry

for v in symbolic_helper._onnx_stable_opsets + [
        symbolic_helper._onnx_main_opset
]:
    symbolic_registry.register_version("", v)


class _TorchSchema:
    def __init__(self, schema: Union[torch._C.FunctionSchema, str]) -> None:
        if isinstance(schema, torch._C.FunctionSchema):
            self.name: str = schema.name
            self.overload_name: str = schema.overload_name
            self.arguments: List[str] = [arg.name for arg in schema.arguments]
            self.optional_arguments: List[str] = []
            self.returns: List[str] = [ret.name for ret in schema.returns]
            self.opsets: List[int] = []
        else:
            self.name = schema
            self.overload_name = ""
            self.arguments = []
            self.optional_arguments = []
            self.returns = []
            self.opsets = []

    def __str__(self) -> str:
        s = f"{self.name}.{self.overload_name}("
Example #9
def _run_symbolic_function(g, n, inputs, env, operator_export_type=OperatorExportTypes.ONNX):
    # NB: Returning None means the node gets cloned as is into
    # the new graph
    try:
        import torch
        from torch.onnx.symbolic_helper import _export_onnx_opset_version as opset_version
        import torch.onnx.symbolic_registry as sym_registry

        sym_registry.register_version('', opset_version)

        # Quantized op symbolics are registered for opset 9 only.
        if operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK and opset_version == 9:
            import torch.onnx.symbolic_caffe2
            torch.onnx.symbolic_caffe2.register_quantized_ops('caffe2', opset_version)

        # See Note [Export inplace]
        # TODO: I think this is not necessary anymore
        if n.kind().endswith('_'):
            ns_op_name = n.kind()[:-1]
        else:
            ns_op_name = n.kind()
        ns, op_name = ns_op_name.split("::")
        if ns == "onnx":
            # Use the original node directly
            return None

        elif ns == "aten":
            is_exportable_aten_op = sym_registry.is_registered_op(op_name, '', opset_version)
            is_onnx_aten_export = operator_export_type == OperatorExportTypes.ONNX_ATEN
            is_aten_fallback_export = operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK
            if is_onnx_aten_export or (not is_exportable_aten_op and is_aten_fallback_export):
                # Direct ATen export requested
                attrs = {k + "_" + n.kindOf(k)[0]: n[k] for k in n.attributeNames()}
                outputs = n.outputsSize()
                attrs["outputs"] = outputs
                return _graph_at(g, op_name, *inputs, aten=True, **attrs)
            else:
                # Export it regularly
                domain = ''
                symbolic_fn = _find_symbolic_in_registry(domain, op_name, opset_version, operator_export_type)
                if symbolic_fn is None:
                    return None
                attrs = {k: n[k] for k in n.attributeNames()}
                return symbolic_fn(g, *inputs, **attrs)

        elif ns == "prim":
            if op_name == "Constant" and not n.mustBeNone():
                if n.kindOf("value") == "t":
                    return g.op("Constant", value_t=n["value"])
                if n.kindOf("value") == "s":
                    return g.op("Constant", value_s=n["value"])
                elif n.output().type().isSubtypeOf(ListType.ofInts()) or n.output().type().isSubtypeOf(ListType.ofFloats()):
                    vals = n.output().toIValue()
                    value = torch.stack([torch.tensor(v) for v in vals]) if len(vals) else []
                    return g.op("Constant", value_t=value)
                elif n.output().type().kind() == "DeviceObjType":
                    return None
                else:
                    raise RuntimeError("Unsupported prim::Constant kind: `{}`. Send a bug report.".format(
                        n.kindOf("value")))
            elif n.mustBeNone() or op_name == "ListConstruct" or op_name == "ListUnpack":
                # None is not an ONNX operator; keep it as None
                # Let the exporter handle and finally eliminate these ops
                # ListConstruct and ListUnpack will be erased in the ONNX peephole pass
                return None
            elif op_name == 'Loop' or op_name == 'If':
                new_op_outputs = g.op(op_name, *inputs, outputs=n.outputsSize())
                new_node = new_op_outputs[0].node() if n.outputsSize() > 1 else new_op_outputs.node()
                for b in n.blocks():
                    new_block = new_node.addBlock()
                    torch._C._jit_pass_onnx_block(b, new_block, operator_export_type, env)
                new_op_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node(new_node, opset_version)
                return new_op_outputs
            else:
                symbolic_name = 'prim_' + op_name
                domain = ''
                symbolic_fn = _find_symbolic_in_registry(domain, symbolic_name, opset_version,
                                                         operator_export_type)
                if symbolic_fn is None:
                    return None
                attrs = {k: n[k] for k in n.attributeNames()}
                return symbolic_fn(g, *inputs, **attrs)

        elif ns == "quantized":
            domain = ''
            if operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK:
                domain = 'caffe2'
            symbolic_fn = _find_symbolic_in_registry(domain, op_name, opset_version, operator_export_type)
            if symbolic_fn is None:
                return None
            attrs = {k: n[k] for k in n.attributeNames()}
            return symbolic_fn(g, *inputs, **attrs)

        # custom ops
        elif sym_registry.is_registered_version(ns, opset_version):
            domain = ns
            symbolic_fn = _find_symbolic_in_registry(domain, op_name, opset_version, operator_export_type)
            if symbolic_fn is None:
                return None
            attrs = {k: n[k] for k in n.attributeNames()}
            return symbolic_fn(g, *inputs, **attrs)
        else:
            raise RuntimeError("ONNX export failed on an operator with unrecognized namespace {}::{}. "
                               "If you are trying to export a custom operator, make sure you registered "
                               "it with the right domain and version. "
                               "Otherwise, please report a bug.".format(ns, op_name))
    except RuntimeError:
        if operator_export_type == OperatorExportTypes.ONNX_FALLTHROUGH:
            return None
        raise
    except TypeError as e:
        # Handle the specific case where we didn't successfully dispatch.
        # Otherwise, the backtrace will have the clues you need.
        e.args = ("{} \n(Occurred when translating {}).".format(e.args[0], op_name),)
        raise
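
This dispatcher runs inside torch.onnx.export. A hedged sketch of an export call that would exercise the Caffe2 quantized path registered above; the model and input are placeholders:

import torch

model = torch.nn.ReLU()          # placeholder model
dummy_input = torch.randn(1, 3)  # placeholder input

torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    opset_version=9,  # quantized symbolics are registered for opset 9 only
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
)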
Example #10
    def setup(self, model: nn.Module, opset_ver: int) -> None:
        self._model: Optional[nn.Module] = model
        # dry run to register every aten op
        sym_reg.register_version('', opset_ver)  # type: ignore[no-untyped-call]
        self.opset_ver = opset_ver
Example #11
def _run_symbolic_function_with_in_place(
        g, n, inputs, env, operator_export_type=OperatorExportTypes.ONNX):
    """
    Monkey-patched version of `_run_symbolic_function` function in `torch.onnx.utils`.
    The only change is that trailing '_' is no longer removed from `ns_op_name` for
    the dropout function.
    """
    try:
        import torch
        from torch.onnx.symbolic_helper import (
            _export_onnx_opset_version as opset_version, )
        import torch.onnx.symbolic_registry as sym_registry

        sym_registry.register_version("", opset_version)
        if operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK:
            import torch.onnx.symbolic_caffe2

            torch.onnx.symbolic_caffe2.register_quantized_ops(
                "caffe2", opset_version)

        ns_op_name = n.kind()
        ns, op_name = ns_op_name.split("::")
        if n.kind().endswith("_"):
            if op_name not in ["dropout_", "feature_dropout_"]:
                op_name = op_name[:-1]

        if ns == "onnx":
            # Use the original node directly
            return None

        elif ns == "aten":
            is_exportable_aten_op = sym_registry.is_registered_op(
                op_name, "", opset_version)
            is_onnx_aten_export = operator_export_type == OperatorExportTypes.ONNX_ATEN
            is_aten_fallback_export = (
                operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK)
            if is_onnx_aten_export or (not is_exportable_aten_op
                                       and is_aten_fallback_export):
                # Direct ATen export requested
                attrs = {
                    k + "_" + n.kindOf(k)[0]: n[k]
                    for k in n.attributeNames()
                }
                outputs = n.outputsSize()
                attrs["outputs"] = outputs
                return torch.onnx.utils._graph_at(g,
                                                  op_name,
                                                  *inputs,
                                                  aten=True,
                                                  **attrs)

            else:
                # Export it regularly
                attrs = {k: n[k] for k in n.attributeNames()}
                if not is_exportable_aten_op:
                    warnings.warn(
                        "ONNX export failed on ATen operator {} because "
                        "torch.onnx.symbolic_opset{}.{} does not exist".format(
                            op_name, opset_version, op_name))
                op_fn = sym_registry.get_registered_op(op_name, "",
                                                       opset_version)
                return op_fn(g, *inputs, **attrs)

        elif ns == "prim":
            if op_name == "Constant" and not n.mustBeNone():
                if n.kindOf("value") == "t":
                    return g.op("Constant", value_t=n["value"])
                if n.kindOf("value") == "s":
                    return g.op("Constant", value_s=n["value"])
                elif n.kindOf("value") == "is":
                    value = (torch.stack([torch.tensor(v) for v in n["value"]])
                             if n["value"] else [])
                    return g.op("Constant", value_t=value)
                elif n.output().type().kind() == "DeviceObjType":
                    return None
                else:
                    raise RuntimeError(
                        "Unsupported prim::Constant kind: `{}`".format(
                            n.kindOf("value")))
            elif (n.mustBeNone() or op_name == "ListConstruct"
                  or op_name == "ListUnpack"):
                # None is not an ONNX operator; keep it as None
                # let the exporter handle finally eliminating these

                # For ListConstruct/ListUnpack, it will be erased in the
                # ONNX peephole pass
                return None
            elif op_name == "Loop" or op_name == "If":
                new_op_outputs = g.op(op_name,
                                      *inputs,
                                      outputs=n.outputsSize())
                new_node = (new_op_outputs[0].node()
                            if n.outputsSize() > 1 else new_op_outputs.node())
                for b in n.blocks():
                    new_block = new_node.addBlock()
                    torch._C._jit_pass_onnx_block(b, new_block,
                                                  operator_export_type, env)
                return new_op_outputs
            else:
                symbolic_name = "prim_" + op_name
                is_exportable = sym_registry.is_registered_op(
                    symbolic_name, "", opset_version)
                if not is_exportable:
                    warnings.warn(
                        "ONNX export failed on primitive operator {}".format(
                            op_name))
                symbolic_fn = sym_registry.get_registered_op(
                    symbolic_name, "", opset_version)
                attrs = {k: n[k] for k in n.attributeNames()}
                return symbolic_fn(g, *inputs, **attrs)

        elif ns == "quantized":
            domain = ""
            if operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK:
                domain = "caffe2"
            attrs = {k: n[k] for k in n.attributeNames()}

            if not sym_registry.is_registered_op(op_name, domain,
                                                 opset_version):
                warnings.warn(
                    "ONNX export failed on quantized operator {}::{} because "
                    "torch.onnx.symbolic_opset{}.{} does not exist. ".format(
                        ns, op_name, opset_version, op_name))
            op_fn = sym_registry.get_registered_op(op_name, domain,
                                                   opset_version)
            return op_fn(g, *inputs, **attrs)

        # custom ops
        elif sym_registry.is_registered_version(ns, opset_version):
            if not sym_registry.is_registered_op(op_name, ns, opset_version):
                warnings.warn(
                    "ONNX export failed on custom operator {}::{} because "
                    "torch.onnx.symbolic_opset{}.{} does not exist.".format(
                        ns, op_name, opset_version, op_name))
            symbolic_fn = sym_registry.get_registered_op(
                op_name, ns, opset_version)
            attrs = {k: n[k] for k in n.attributeNames()}
            return symbolic_fn(g, *inputs, **attrs)

        else:
            warnings.warn(
                "ONNX export failed on an operator with unrecognized namespace "
                "{}::{}; If you are trying to export a custom operator, "
                "make sure you registered it with the right domain and version."
                "Otherwise please report a bug".format(ns, op_name))
            return None

    except TypeError as e:
        # Handle the specific case where we didn't successfully dispatch.
        # Otherwise, the backtrace will have the clues you need.
        e.args = ("{} (occurred when translating {})".format(
            e.args[0], op_name), )
        raise
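
Given the docstring, installing the patch presumably looks like the following; this application step is an assumption, as it is not shown in the excerpt:

import torch.onnx.utils

# Assumption: swap in the patched dispatcher before calling torch.onnx.export,
# so the trailing '_' of dropout_/feature_dropout_ survives symbolic dispatch.
torch.onnx.utils._run_symbolic_function = _run_symbolic_function_with_in_place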