def unregister(name, opset_version):
    import torch.onnx.symbolic_registry as sym_registry
    from torch.onnx.symbolic_helper import _onnx_stable_opsets

    ns, kind = name.split("::")
    for version in _onnx_stable_opsets:
        if version >= opset_version and sym_registry.is_registered_op(kind, ns, version):
            del sym_registry._registry[(ns, version)][kind]
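A minimal usage sketch, assuming PyTorch < 1.13 (where the private torch.onnx.symbolic_registry module still exists); mydomain::myop is a hypothetical op name used only for illustration:

import torch.onnx.symbolic_registry as sym_registry

# Remove any symbolic registered for the hypothetical op at opset 9 and above.
unregister("mydomain::myop", 9)
assert not sym_registry.is_registered_op("myop", "mydomain", 9)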
Example #2
def register_quantized_ops(domain: str, version: int):
    import importlib
    import inspect

    import torch.onnx.symbolic_registry as symbolic_registry

    # Register all the non-quantized ops
    symbolic_registry.register_version("", version)
    # Register all quantized ops
    module = importlib.import_module("torch.onnx.symbolic_caffe2")
    symbolic_registry._symbolic_versions["caffe2"] = module
    quant_version_ops = inspect.getmembers(
        symbolic_registry._symbolic_versions["caffe2"])
    # ATen ops with quantized variants that must also be registered in the
    # default ("") domain so the exporter can resolve them.
    aten_q_ops = [
        "relu",
        "_empty_affine_quantized",
        "dequantize",
        "quantize_per_tensor",
        "upsample_nearest2d",
        "avg_pool2d",
        "reshape",
        "slice",
        "cat",
        "max_pool2d",
        "sigmoid",
    ]
    for op in quant_version_ops:
        if inspect.isfunction(
                op[1]) and not symbolic_registry.is_registered_op(
                    op[0], domain, version):
            if op[0] in aten_q_ops:
                symbolic_registry.register_op(op[0], op[1], "", version)
            symbolic_registry.register_op(op[0], op[1], domain, version)
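A hedged usage sketch, assuming PyTorch < 1.13 where torch.onnx.symbolic_caffe2 and the old registry still exist; the quantized Caffe2 symbolics are only defined for opset 9 (see the comment in Example #11):

import torch.onnx.symbolic_registry as symbolic_registry

register_quantized_ops("caffe2", 9)
assert symbolic_registry.is_registered_op("relu", "caffe2", 9)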
Example #3
def _find_symbolic_in_registry(domain, op_name, opset_version, operator_export_type):
    """Look up the symbolic function for an op in the registry.

    Returns None (meaning: keep the original node) when the op is not
    registered and ONNX_FALLTHROUGH was requested; otherwise defers to
    get_registered_op, which is expected to raise for missing ops.
    """
    import torch.onnx.symbolic_registry as sym_registry
    if not sym_registry.is_registered_op(op_name, domain, opset_version):
        if operator_export_type == OperatorExportTypes.ONNX_FALLTHROUGH:
            # Use the original node directly
            return None
    return sym_registry.get_registered_op(op_name, domain, opset_version)
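A hedged caller-side sketch (again assuming the pre-1.13 registry): the default domain must be populated before lookups against it succeed, and a None result means "keep the original node":

import torch.onnx.symbolic_registry as sym_registry
from torch.onnx import OperatorExportTypes

sym_registry.register_version("", 9)  # populate the default ("") domain
symbolic_fn = _find_symbolic_in_registry(
    "", "relu", 9, OperatorExportTypes.ONNX_FALLTHROUGH)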
Example #4
        def wrapped_function(*args, **kwargs):
            # Closure over the enclosing decorator's `func`, `op_name`,
            # `namespace`, and `adapter`.
            name = op_name if op_name is not None else func.__name__
            opset = sym_help._export_onnx_opset_version

            if is_registered_op(name, namespace, opset):

                class XFunction(torch.autograd.Function):
                    @staticmethod
                    def forward(ctx, *xargs):
                        # Ignore xargs and re-run the wrapped function, so
                        # eager behavior is unchanged.
                        return func(*args, **kwargs)

                    @staticmethod
                    def symbolic(g, *xargs):
                        symb = get_registered_op(name, namespace, opset)
                        if adapter is not None:
                            # `adapter_kwargs` is assigned below, before the
                            # symbolic can be invoked during export.
                            return symb(g, *xargs, **adapter_kwargs)
                        return symb(g, *xargs)

                if adapter is not None:
                    adapter_args, adapter_kwargs = adapter(*args, **kwargs)
                    return XFunction.apply(*adapter_args)
                return XFunction.apply(*args)
            else:
                # No symbolic registered for this opset: plain call.
                return func(*args, **kwargs)
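Note that forward deliberately ignores its *xargs and re-invokes the captured func, so eager behavior is unchanged; only the symbolic hook changes how the node is rendered during ONNX export. When no symbolic is registered for the current opset, the wrapper degenerates to a plain call of func.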
Example #5
def unregister():
    """Unregister ONNX Runtime's built-in contrib ops."""
    # TODO: replace this once PyTorch supports unregister natively.
    # https://msdata.visualstudio.com/Vienna/_workitems/edit/1342343
    for name in _registered_ops:
        ns, kind = name.split("::")
        for version in sym_help._onnx_stable_opsets:
            if version >= _OPSET_VERSION and sym_registry.is_registered_op(kind, ns, version):
                del sym_registry._registry[(ns, version)][kind]
Example #6
def unregister():
    """Unregister ONNX Runtime's built-in contrib ops."""
    for name in _registered_ops:
        try:
            torch.onnx.unregister_custom_op_symbolic(name, _OPSET_VERSION)
        except AttributeError:
            # unregister_custom_op_symbolic is not available before PyTorch 1.12
            namespace, kind = name.split("::")
            for version in sym_help._onnx_stable_opsets:
                if version >= _OPSET_VERSION and sym_registry.is_registered_op(
                        kind, namespace, version):
                    del sym_registry._registry[(namespace, version)][kind]
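Example #6 refines Example #5: it first tries the public torch.onnx.unregister_custom_op_symbolic API (per the comment, unavailable before PyTorch 1.12) and falls back to the same registry surgery on older releases, where the attribute lookup raises AttributeError.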
Example #7
def register_quantized_ops(domain, version):
    import importlib
    from inspect import getmembers, isfunction

    import torch.onnx.symbolic_registry as sym_registry

    # Register all the non-quantized ops
    sym_registry.register_version('', version)
    # Register all quantized ops
    module = importlib.import_module('torch.onnx.symbolic_caffe2')
    sym_registry._symbolic_versions['caffe2'] = module
    quant_version_ops = getmembers(sym_registry._symbolic_versions['caffe2'])
    # ATen ops with quantized variants that must also live in the default domain.
    aten_q_ops = ['relu', '_empty_affine_quantized', 'dequantize', 'quantize_per_tensor', 'upsample_nearest2d']
    for op in quant_version_ops:
        if isfunction(op[1]) and not sym_registry.is_registered_op(op[0], domain, version):
            if op[0] in aten_q_ops:
                sym_registry.register_op(op[0], op[1], '', version)
            sym_registry.register_op(op[0], op[1], domain, version)
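This is a leaner variant of the function in Example #2: the logic is identical, but the aten_q_ops allow-list stops at upsample_nearest2d and lacks the pooling, reshape, slice, cat, and sigmoid entries.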
Example #8
def symbolic_function(self, n: torch._C.Node) -> Optional[Callable]:
    # Requires: `from typing import Callable, Optional, cast` and
    # `import torch.onnx.symbolic_registry as sym_reg` at module level.
    ns, op = n.kind().split("::")
    if op.endswith("_"):  # For inplace ops, look up the out-of-place symbolic
        op = op[:-1]
    if ns == "prim" and op == "PythonOp":
        # autograd.Function nodes carry their own `symbolic` staticmethod.
        pyobj = n.pyobj()
        if issubclass(pyobj.__self__, torch.autograd.Function):
            pyobj = pyobj.__self__
        assert issubclass(pyobj, torch.autograd.Function)
        assert hasattr(pyobj, "symbolic"), f"symbolic method not supported in {pyobj}"
        # TODO(twata): Use repr(pyobj) in scope name or doc_string
        return cast(Callable, pyobj.symbolic)
    else:
        if ns == "prim":
            op = f"prim_{op}"
        if sym_reg.is_registered_op(op, "", self.opset_version):  # type: ignore[no-untyped-call]
            return cast(
                Callable, sym_reg.get_registered_op(op, "", self.opset_version)  # type: ignore[no-untyped-call]
            )
        else:
            return None
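Here nodes produced by a torch.autograd.Function subclass (prim::PythonOp) dispatch to the class's own symbolic staticmethod, while every other node is resolved through the registry, with remaining prim ops mangled to a prim_<op> name in the default domain; the same prim_ convention appears in the _run_symbolic_function examples below.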
Example #9
    def __init__(self,
                 inner,
                 namespace,
                 name,
                 verbose=False,
                 force_rebuild=True,
                 unique_name=True):
        super().__init__()
        self.inner = inner
        self.namespace = namespace
        self.name = name
        if unique_name:
            # Append a random suffix (from a private tempfile helper) so each
            # instance registers a distinct op name.
            self.name = name + '_' + next(
                tempfile._get_candidate_names()) + 'x'
        self.qualified_name = '{}::{}'.format(self.namespace, self.name)
        self.num_outputs = 1
        self.params = {}
        self.verbose = verbose
        self.force_rebuild = force_rebuild

        # Register the symbolic function for ONNX export (opset 10) exactly
        # once per qualified name.
        if not is_registered_op(self.name, self.namespace, 10):
            register_op(self.name, self.symbolic, self.namespace, 10)
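Because the registered symbolic is the bound self.symbolic method, the optional random suffix keeps one instance's symbolic from shadowing another's when several instances share the same base name.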
Example #10
def _run_symbolic_function(g,
                           n,
                           inputs,
                           env,
                           operator_export_type=OperatorExportTypes.ONNX):
    # NB: Returning None means the node gets cloned as is into
    # the new graph
    try:
        from torch.onnx.symbolic_helper import _export_onnx_opset_version as opset_version
        import torch.onnx.symbolic_registry as sym_registry

        sym_registry.register_version('', opset_version)

        # See Note [Export inplace]
        # TODO: I think this is not necessary anymore
        if n.kind().endswith('_'):
            ns_op_name = n.kind()[:-1]
        else:
            ns_op_name = n.kind()
        ns, op_name = ns_op_name.split("::")

        if ns == "onnx":
            # Use the original node directly
            return None

        elif ns == "aten":
            is_exportable_aten_op = sym_registry.is_registered_op(
                op_name, '', opset_version)
            is_onnx_aten_export = operator_export_type == OperatorExportTypes.ONNX_ATEN
            is_aten_fallback_export = operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK
            if is_onnx_aten_export or (not is_exportable_aten_op
                                       and is_aten_fallback_export):
                # Direct ATen export requested
                attrs = {
                    k + "_" + n.kindOf(k)[0]: n[k]
                    for k in n.attributeNames()
                }
                outputs = n.outputsSize()
                attrs["outputs"] = outputs
                return _graph_at(g, op_name, *inputs, aten=True, **attrs)

            else:
                # Export it regularly
                attrs = {k: n[k] for k in n.attributeNames()}
                if not is_exportable_aten_op:
                    warnings.warn(
                        "ONNX export failed on ATen operator {} because "
                        "torch.onnx.symbolic_opset{}.{} does not exist".format(
                            op_name, opset_version, op_name))
                op_fn = sym_registry.get_registered_op(op_name, '',
                                                       opset_version)
                return op_fn(g, *inputs, **attrs)

        elif ns == "prim":
            if op_name == "Constant" and not n.mustBeNone():
                if n.kindOf("value") == "t":
                    return g.op("Constant", value_t=n["value"])
                elif n.kindOf("value") == "is":
                    value = torch.stack([torch.tensor(v) for v in n["value"]
                                         ]) if n["value"] else []
                    return g.op("Constant", value_t=value)
                elif n.output().type().kind() == "DeviceObjType":
                    return None
                else:
                    raise RuntimeError(
                        "Unsupported prim::Constant kind: `{}`. Send a bug report."
                        .format(n.kindOf("value")))
            elif n.mustBeNone(
            ) or op_name == "ListConstruct" or op_name == "ListUnpack":
                # None is not an ONNX operator; keep it as None
                # let the exporter handle finally eliminating these

                # For ListConstruct/ListUnpack, it will be erased in the ONNX peephole pass
                return None
            elif op_name == 'Loop' or op_name == 'If':
                new_op_outputs = g.op(op_name,
                                      *inputs,
                                      outputs=n.outputsSize())
                new_node = new_op_outputs[0].node(
                ) if n.outputsSize() > 1 else new_op_outputs.node()
                for b in n.blocks():
                    new_block = new_node.addBlock()
                    torch._C._jit_pass_onnx_block(b, new_block,
                                                  operator_export_type, env)
                return new_op_outputs
            else:
                # TODO: we should lift prim's symbolic out
                symbolic_name = 'prim_' + op_name
                is_exportable = sym_registry.is_registered_op(
                    symbolic_name, '', opset_version)
                if not is_exportable:
                    warnings.warn(
                        "ONNX export failed on primitive operator {}; please report a bug"
                        .format(op_name))
                symbolic_fn = sym_registry.get_registered_op(
                    symbolic_name, '', opset_version)
                attrs = {k: n[k] for k in n.attributeNames()}
                return symbolic_fn(g, *inputs, **attrs)

        # custom ops
        elif sym_registry.is_registered_version(ns, opset_version):
            if not sym_registry.is_registered_op(op_name, ns, opset_version):
                warnings.warn(
                    "ONNX export failed on custom operator {}::{} because "
                    "torch.onnx.symbolic_opset{}.{} does not exist. "
                    "Have you registered your symbolic function with "
                    "torch.onnx.register_custom_op_symbolic(symbolic_name, symbolic_fn)?"
                    .format(ns, op_name, opset_version, op_name))
            symbolic_fn = sym_registry.get_registered_op(
                op_name, ns, opset_version)
            attrs = {k: n[k] for k in n.attributeNames()}
            return symbolic_fn(g, *inputs, **attrs)

        else:
            warnings.warn(
                "ONNX export failed on an operator with unrecognized namespace {}::{}; "
                "If you are trying to export a custom operator, make sure you registered "
                "it with the right domain and version."
                "Otherwise please report a bug".format(ns, op_name))
            return None

    except TypeError as e:
        # Handle the specific case where we didn't successfully dispatch.
        # Otherwise, the backtrace will have the clues you need.
        e.args = ("{} (occurred when translating {})".format(
            e.args[0], op_name), )
        raise
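The final elif above resolves ops from custom domains. A minimal sketch of how an op reaches that branch, assuming pre-1.13 PyTorch; mydomain::my_op and my_op_symbolic are hypothetical names used only for illustration:

import torch

def my_op_symbolic(g, x):
    # Illustrative only: render the custom op as an ONNX Relu node.
    return g.op("Relu", x)

torch.onnx.register_custom_op_symbolic("mydomain::my_op", my_op_symbolic, 9)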
Example #11
def _run_symbolic_function(g, n, inputs, env, operator_export_type=OperatorExportTypes.ONNX):
    # NB: Returning None means the node gets cloned as is into
    # the new graph
    try:
        import torch
        from torch.onnx.symbolic_helper import _export_onnx_opset_version as opset_version
        import torch.onnx.symbolic_registry as sym_registry

        sym_registry.register_version('', opset_version)

        # Quantized op symbolics are registered for opset 9 only.
        if operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK and opset_version == 9:
            import torch.onnx.symbolic_caffe2
            torch.onnx.symbolic_caffe2.register_quantized_ops('caffe2', opset_version)

        # See Note [Export inplace]
        # TODO: I think this is not necessary anymore
        if n.kind().endswith('_'):
            ns_op_name = n.kind()[:-1]
        else:
            ns_op_name = n.kind()
        ns, op_name = ns_op_name.split("::")
        if ns == "onnx":
            # Use the original node directly
            return None

        elif ns == "aten":
            is_exportable_aten_op = sym_registry.is_registered_op(op_name, '', opset_version)
            is_onnx_aten_export = operator_export_type == OperatorExportTypes.ONNX_ATEN
            is_aten_fallback_export = operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK
            if is_onnx_aten_export or (not is_exportable_aten_op and is_aten_fallback_export):
                # Direct ATen export requested
                attrs = {k + "_" + n.kindOf(k)[0]: n[k] for k in n.attributeNames()}
                outputs = n.outputsSize()
                attrs["outputs"] = outputs
                return _graph_at(g, op_name, *inputs, aten=True, **attrs)
            else:
                # Export it regularly
                domain = ''
                symbolic_fn = _find_symbolic_in_registry(domain, op_name, opset_version, operator_export_type)
                if symbolic_fn is None:
                    return None
                attrs = {k: n[k] for k in n.attributeNames()}
                return symbolic_fn(g, *inputs, **attrs)

        elif ns == "prim":
            if op_name == "Constant" and not n.mustBeNone():
                if n.kindOf("value") == "t":
                    return g.op("Constant", value_t=n["value"])
                if n.kindOf("value") == "s":
                    return g.op("Constant", value_s=n["value"])
                elif n.output().type().isSubtypeOf(ListType.ofInts()) or n.output().type().isSubtypeOf(ListType.ofFloats()):
                    vals = n.output().toIValue()
                    value = torch.stack([torch.tensor(v) for v in vals]) if len(vals) else []
                    return g.op("Constant", value_t=value)
                elif n.output().type().kind() == "DeviceObjType":
                    return None
                else:
                    raise RuntimeError("Unsupported prim::Constant kind: `{}`. Send a bug report.".format(
                        n.kindOf("value")))
            elif n.mustBeNone() or op_name == "ListConstruct" or op_name == "ListUnpack":
                # None is not an ONNX operator; keep it as None
                # Let the exporter handle and finally eliminate these ops
                # ListConstruct and ListUnpack will be erased in the ONNX peephole pass
                return None
            elif op_name == 'Loop' or op_name == 'If':
                new_op_outputs = g.op(op_name, *inputs, outputs=n.outputsSize())
                new_node = new_op_outputs[0].node() if n.outputsSize() > 1 else new_op_outputs.node()
                for b in n.blocks():
                    new_block = new_node.addBlock()
                    torch._C._jit_pass_onnx_block(b, new_block, operator_export_type, env)
                new_op_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node(new_node, opset_version)
                return new_op_outputs
            else:
                symbolic_name = 'prim_' + op_name
                domain = ''
                symbolic_fn = _find_symbolic_in_registry(domain, symbolic_name, opset_version,
                                                         operator_export_type)
                if symbolic_fn is None:
                    return None
                attrs = {k: n[k] for k in n.attributeNames()}
                return symbolic_fn(g, *inputs, **attrs)

        elif ns == "quantized":
            domain = ''
            if operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK:
                domain = 'caffe2'
            symbolic_fn = _find_symbolic_in_registry(domain, op_name, opset_version, operator_export_type)
            if symbolic_fn is None:
                return None
            attrs = {k: n[k] for k in n.attributeNames()}
            return symbolic_fn(g, *inputs, **attrs)

        # custom ops
        elif sym_registry.is_registered_version(ns, opset_version):
            domain = ns
            symbolic_fn = _find_symbolic_in_registry(domain, op_name, opset_version, operator_export_type)
            if symbolic_fn is None:
                return None
            attrs = {k: n[k] for k in n.attributeNames()}
            return symbolic_fn(g, *inputs, **attrs)
        else:
            raise RuntimeError("ONNX export failed on an operator with unrecognized namespace {}::{}. "
                               "If you are trying to export a custom operator, make sure you registered "
                               "it with the right domain and version. "
                               "Otherwise, please report a bug.".format(ns, op_name))
    except RuntimeError:
        if operator_export_type == OperatorExportTypes.ONNX_FALLTHROUGH:
            return None
        raise
    except TypeError as e:
        # Handle the specific case where we didn't successfully dispatch.
        # Otherwise, the backtrace will have the clues you need.
        e.args = ("{} \n(Occurred when translating {}).".format(e.args[0], op_name),)
        raise
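Compared with Example #10, this revision also handles string constants (value_s) and int/float list constants via toIValue, runs _jit_pass_fixup_onnx_controlflow_node on Loop/If nodes, adds a dedicated quantized namespace branch, raises instead of warning on unknown namespaces, and converts RuntimeErrors into a fallthrough (returning None) when ONNX_FALLTHROUGH is requested.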
Example #12
def _run_symbolic_function_with_in_place(
        g, n, inputs, env, operator_export_type=OperatorExportTypes.ONNX):
    """
    Monkey-patched version of `_run_symbolic_function` function in `torch.onnx.utils`.
    The only change is that trailing '_' is no longer removed from `ns_op_name` for
    the dropout function.
    """
    try:
        import torch
        from torch.onnx.symbolic_helper import (
            _export_onnx_opset_version as opset_version, )
        import torch.onnx.symbolic_registry as sym_registry

        sym_registry.register_version("", opset_version)
        if operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK:
            import torch.onnx.symbolic_caffe2

            torch.onnx.symbolic_caffe2.register_quantized_ops(
                "caffe2", opset_version)

        ns_op_name = n.kind()
        ns, op_name = ns_op_name.split("::")
        if n.kind().endswith("_"):
            # Unlike the stock exporter, keep the trailing '_' for the dropout
            # ops so their in-place symbolics are looked up directly.
            if op_name not in ["dropout_", "feature_dropout_"]:
                op_name = op_name[:-1]

        if ns == "onnx":
            # Use the original node directly
            return None

        elif ns == "aten":
            is_exportable_aten_op = sym_registry.is_registered_op(
                op_name, "", opset_version)
            is_onnx_aten_export = operator_export_type == OperatorExportTypes.ONNX_ATEN
            is_aten_fallback_export = (
                operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK)
            if is_onnx_aten_export or (not is_exportable_aten_op
                                       and is_aten_fallback_export):
                # Direct ATen export requested
                attrs = {
                    k + "_" + n.kindOf(k)[0]: n[k]
                    for k in n.attributeNames()
                }
                outputs = n.outputsSize()
                attrs["outputs"] = outputs
                return torch.onnx.utils._graph_at(g,
                                                  op_name,
                                                  *inputs,
                                                  aten=True,
                                                  **attrs)

            else:
                # Export it regularly
                attrs = {k: n[k] for k in n.attributeNames()}
                if not is_exportable_aten_op:
                    warnings.warn(
                        "ONNX export failed on ATen operator {} because "
                        "torch.onnx.symbolic_opset{}.{} does not exist".format(
                            op_name, opset_version, op_name))
                op_fn = sym_registry.get_registered_op(op_name, "",
                                                       opset_version)
                return op_fn(g, *inputs, **attrs)

        elif ns == "prim":
            if op_name == "Constant" and not n.mustBeNone():
                if n.kindOf("value") == "t":
                    return g.op("Constant", value_t=n["value"])
                if n.kindOf("value") == "s":
                    return g.op("Constant", value_s=n["value"])
                elif n.kindOf("value") == "is":
                    value = (torch.stack([torch.tensor(v) for v in n["value"]])
                             if n["value"] else [])
                    return g.op("Constant", value_t=value)
                elif n.output().type().kind() == "DeviceObjType":
                    return None
                else:
                    raise RuntimeError(
                        "Unsupported prim::Constant kind: `{}`".format(
                            n.kindOf("value")))
            elif (n.mustBeNone() or op_name == "ListConstruct"
                  or op_name == "ListUnpack"):
                # None is not an ONNX operator; keep it as None
                # let the exporter handle finally eliminating these

                # For ListConstruct/ListUnpack, it will be erased in the
                # ONNX peephole pass
                return None
            elif op_name == "Loop" or op_name == "If":
                new_op_outputs = g.op(op_name,
                                      *inputs,
                                      outputs=n.outputsSize())
                new_node = (new_op_outputs[0].node()
                            if n.outputsSize() > 1 else new_op_outputs.node())
                for b in n.blocks():
                    new_block = new_node.addBlock()
                    torch._C._jit_pass_onnx_block(b, new_block,
                                                  operator_export_type, env)
                return new_op_outputs
            else:
                symbolic_name = "prim_" + op_name
                is_exportable = sym_registry.is_registered_op(
                    symbolic_name, "", opset_version)
                if not is_exportable:
                    warnings.warn(
                        "ONNX export failed on primitive operator {}".format(
                            op_name))
                symbolic_fn = sym_registry.get_registered_op(
                    symbolic_name, "", opset_version)
                attrs = {k: n[k] for k in n.attributeNames()}
                return symbolic_fn(g, *inputs, **attrs)

        elif ns == "quantized":
            domain = ""
            if operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK:
                domain = "caffe2"
            attrs = {k: n[k] for k in n.attributeNames()}

            if not sym_registry.is_registered_op(op_name, domain,
                                                 opset_version):
                warnings.warn(
                    "ONNX export failed on quantized operator {}::{} because "
                    "torch.onnx.symbolic_opset{}.{} does not exist. ".format(
                        ns, op_name, opset_version, op_name))
            op_fn = sym_registry.get_registered_op(op_name, domain,
                                                   opset_version)
            return op_fn(g, *inputs, **attrs)

        # custom ops
        elif sym_registry.is_registered_version(ns, opset_version):
            if not sym_registry.is_registered_op(op_name, ns, opset_version):
                warnings.warn(
                    "ONNX export failed on custom operator {}::{} because "
                    "torch.onnx.symbolic_opset{}.{} does not exist.".format(
                        ns, op_name, opset_version, op_name))
            symbolic_fn = sym_registry.get_registered_op(
                op_name, ns, opset_version)
            attrs = {k: n[k] for k in n.attributeNames()}
            return symbolic_fn(g, *inputs, **attrs)

        else:
            warnings.warn(
                "ONNX export failed on an operator with unrecognized namespace "
                "{}::{}; If you are trying to export a custom operator, "
                "make sure you registered it with the right domain and version."
                "Otherwise please report a bug".format(ns, op_name))
            return None

    except TypeError as e:
        # Handle the specific case where we didn't successfully dispatch.
        # Otherwise, the backtrace will have the clues you need.
        e.args = ("{} (occurred when translating {})".format(
            e.args[0], op_name), )
        raise
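Given the docstring, a hedged sketch of installing the patch, assuming pre-1.13 PyTorch where the private torch.onnx.utils._run_symbolic_function hook exists:

import torch.onnx.utils

# Swap in the patched dispatcher so dropout_/feature_dropout_ keep their
# trailing underscore during symbolic lookup.
torch.onnx.utils._run_symbolic_function = _run_symbolic_function_with_in_place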