Example 1
def _run_symbolic_function(g, n, inputs, env, operator_export_type=OperatorExportTypes.ONNX):
    # NB: Returning None means the node gets cloned as is into
    # the new graph
    try:
        import torch
        from torch.onnx.symbolic_helper import _export_onnx_opset_version as opset_version
        import torch.onnx.symbolic_registry as sym_registry

        sym_registry.register_version('', opset_version)

        # Quantized op symbolics are registered for opset 9 only.
        if operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK and opset_version == 9:
            import torch.onnx.symbolic_caffe2
            torch.onnx.symbolic_caffe2.register_quantized_ops('caffe2', opset_version)

        # See Note [Export inplace]
        # TODO: I think this is not necessary anymore
        if n.kind().endswith('_'):
            ns_op_name = n.kind()[:-1]
        else:
            ns_op_name = n.kind()
        ns, op_name = ns_op_name.split("::")
        if ns == "onnx":
            # Use the original node directly
            return None

        elif ns == "aten":
            is_exportable_aten_op = sym_registry.is_registered_op(op_name, '', opset_version)
            is_onnx_aten_export = operator_export_type == OperatorExportTypes.ONNX_ATEN
            is_aten_fallback_export = operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK
            if is_onnx_aten_export or (not is_exportable_aten_op and is_aten_fallback_export):
                # Direct ATen export requested
                attrs = {k + "_" + n.kindOf(k)[0]: n[k] for k in n.attributeNames()}
                outputs = n.outputsSize()
                attrs["outputs"] = outputs
                return _graph_at(g, op_name, *inputs, aten=True, **attrs)
            else:
                # Export it regularly
                domain = ''
                symbolic_fn = _find_symbolic_in_registry(domain, op_name, opset_version, operator_export_type)
                if symbolic_fn is None:
                    return None
                attrs = {k: n[k] for k in n.attributeNames()}
                return symbolic_fn(g, *inputs, **attrs)

        elif ns == "prim":
            if op_name == "Constant" and not n.mustBeNone():
                if n.kindOf("value") == "t":
                    return g.op("Constant", value_t=n["value"])
                elif n.kindOf("value") == "s":
                    return g.op("Constant", value_s=n["value"])
                elif n.output().type().isSubtypeOf(ListType.ofInts()) or n.output().type().isSubtypeOf(ListType.ofFloats()):
                    vals = n.output().toIValue()
                    value = torch.stack([torch.tensor(v) for v in vals]) if len(vals) else []
                    return g.op("Constant", value_t=value)
                elif n.output().type().kind() == "DeviceObjType":
                    return None
                else:
                    raise RuntimeError("Unsupported prim::Constant kind: `{}`. Send a bug report.".format(
                        n.kindOf("value")))
            elif n.mustBeNone() or op_name == "ListConstruct" or op_name == "ListUnpack":
                # None is not an ONNX operator; keep it as None
                # Let the exporter handle and finally eliminate these ops
                # ListConstruct and ListUnpack will be erased in the ONNX peephole pass
                return None
            elif op_name == 'Loop' or op_name == 'If':
                new_op_outputs = g.op(op_name, *inputs, outputs=n.outputsSize())
                new_node = new_op_outputs[0].node() if n.outputsSize() > 1 else new_op_outputs.node()
                for b in n.blocks():
                    new_block = new_node.addBlock()
                    torch._C._jit_pass_onnx_block(b, new_block, operator_export_type, env)
                new_op_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node(new_node, opset_version)
                return new_op_outputs
            else:
                symbolic_name = 'prim_' + op_name
                domain = ''
                symbolic_fn = _find_symbolic_in_registry(domain, symbolic_name, opset_version,
                                                         operator_export_type)
                if symbolic_fn is None:
                    return None
                attrs = {k: n[k] for k in n.attributeNames()}
                return symbolic_fn(g, *inputs, **attrs)

        elif ns == "quantized":
            domain = ''
            if operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK:
                domain = 'caffe2'
            symbolic_fn = _find_symbolic_in_registry(domain, op_name, opset_version, operator_export_type)
            if symbolic_fn is None:
                return None
            attrs = {k: n[k] for k in n.attributeNames()}
            return symbolic_fn(g, *inputs, **attrs)

        # custom ops
        elif sym_registry.is_registered_version(ns, opset_version):
            domain = ns
            symbolic_fn = _find_symbolic_in_registry(domain, op_name, opset_version, operator_export_type)
            if symbolic_fn is None:
                return None
            attrs = {k: n[k] for k in n.attributeNames()}
            return symbolic_fn(g, *inputs, **attrs)
        else:
            raise RuntimeError("ONNX export failed on an operator with unrecognized namespace {}::{}. "
                               "If you are trying to export a custom operator, make sure you registered "
                               "it with the right domain and version. "
                               "Otherwise, please report a bug.".format(ns, op_name))
    except RuntimeError:
        if operator_export_type == OperatorExportTypes.ONNX_FALLTHROUGH:
            return None
        raise
    except TypeError as e:
        # Handle the specific case where we didn't successfully dispatch.
        # Otherwise, the backtrace will have the clues you need.
        e.args = ("{} \n(Occurred when translating {}).".format(e.args[0], op_name),)
        raise
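
The "custom ops" branch in the listing above only fires once a symbolic function has been registered under the op's namespace for the current opset. A minimal sketch of that registration from user code, using the public torch.onnx.register_custom_op_symbolic API; the mylib::relu6 op name and the Clip lowering are illustrative assumptions, not part of the original:

import torch
import torch.onnx

def relu6_symbolic(g, input):
    # Illustrative lowering of the hypothetical mylib::relu6 onto the
    # standard ONNX Clip op (attribute form, valid for opset <= 10).
    return g.op("Clip", input, min_f=0.0, max_f=6.0)

# After this call, sym_registry.is_registered_version("mylib", 9) holds and
# the custom-ops branch can dispatch mylib::relu6 to relu6_symbolic.
torch.onnx.register_custom_op_symbolic("mylib::relu6", relu6_symbolic, 9)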
Example 2
def _run_symbolic_function(g,
                           n,
                           inputs,
                           env,
                           operator_export_type=OperatorExportTypes.ONNX):
    # NB: Returning None means the node gets cloned as is into
    # the new graph
    try:
        import torch
        from torch.onnx.symbolic_helper import _export_onnx_opset_version as opset_version
        import torch.onnx.symbolic_registry as sym_registry

        sym_registry.register_version('', opset_version)

        # See Note [Export inplace]
        # TODO: I think this is not necessary anymore
        if n.kind().endswith('_'):
            ns_op_name = n.kind()[:-1]
        else:
            ns_op_name = n.kind()
        ns, op_name = ns_op_name.split("::")

        if ns == "onnx":
            # Use the original node directly
            return None

        elif ns == "aten":
            is_exportable_aten_op = sym_registry.is_registered_op(
                op_name, '', opset_version)
            is_onnx_aten_export = operator_export_type == OperatorExportTypes.ONNX_ATEN
            is_aten_fallback_export = operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK
            if is_onnx_aten_export or (not is_exportable_aten_op
                                       and is_aten_fallback_export):
                # Direct ATen export requested
                attrs = {
                    k + "_" + n.kindOf(k)[0]: n[k]
                    for k in n.attributeNames()
                }
                outputs = n.outputsSize()
                attrs["outputs"] = outputs
                return _graph_at(g, op_name, *inputs, aten=True, **attrs)

            else:
                # Export it regularly
                attrs = {k: n[k] for k in n.attributeNames()}
                if not is_exportable_aten_op:
                    warnings.warn(
                        "ONNX export failed on ATen operator {} because "
                        "torch.onnx.symbolic_opset{}.{} does not exist".format(
                            op_name, opset_version, op_name))
                op_fn = sym_registry.get_registered_op(op_name, '',
                                                       opset_version)
                return op_fn(g, *inputs, **attrs)

        elif ns == "prim":
            if op_name == "Constant" and not n.mustBeNone():
                if n.kindOf("value") == "t":
                    return g.op("Constant", value_t=n["value"])
                elif n.kindOf("value") == "is":
                    value = (torch.stack([torch.tensor(v) for v in n["value"]])
                             if n["value"] else [])
                    return g.op("Constant", value_t=value)
                elif n.output().type().kind() == "DeviceObjType":
                    return None
                else:
                    raise RuntimeError(
                        "Unsupported prim::Constant kind: `{}`. Send a bug report."
                        .format(n.kindOf("value")))
            elif (n.mustBeNone() or op_name == "ListConstruct"
                  or op_name == "ListUnpack"):
                # None is not an ONNX operator; keep it as None
                # let the exporter handle finally eliminating these

                # For ListConstruct/ListUnpack, it will be erased in the ONNX peephole pass
                return None
            elif op_name == 'Loop' or op_name == 'If':
                new_op_outputs = g.op(op_name,
                                      *inputs,
                                      outputs=n.outputsSize())
                new_node = (new_op_outputs[0].node()
                            if n.outputsSize() > 1 else new_op_outputs.node())
                for b in n.blocks():
                    new_block = new_node.addBlock()
                    torch._C._jit_pass_onnx_block(b, new_block,
                                                  operator_export_type, env)
                return new_op_outputs
            else:
                # TODO: we should lift prim's symbolic out
                symbolic_name = 'prim_' + op_name
                is_exportable = sym_registry.is_registered_op(
                    symbolic_name, '', opset_version)
                if not is_exportable:
                    warnings.warn(
                        "ONNX export failed on primitive operator {}; please report a bug"
                        .format(op_name))
                symbolic_fn = sym_registry.get_registered_op(
                    symbolic_name, '', opset_version)
                attrs = {k: n[k] for k in n.attributeNames()}
                return symbolic_fn(g, *inputs, **attrs)

        # custom ops
        elif sym_registry.is_registered_version(ns, opset_version):
            if not sym_registry.is_registered_op(op_name, ns, opset_version):
                warnings.warn(
                    "ONNX export failed on custom operator {}::{} because "
                    "torch.onnx.symbolic_opset{}.{} does not exist. "
                    "Have you registered your symbolic function with "
                    "torch.onnx.register_custom_op_symbolic(symbolic_name, symbolic_fn)?"
                    .format(ns, op_name, opset_version, op_name))
            symbolic_fn = sym_registry.get_registered_op(
                op_name, ns, opset_version)
            attrs = {k: n[k] for k in n.attributeNames()}
            return symbolic_fn(g, *inputs, **attrs)

        else:
            warnings.warn(
                "ONNX export failed on an operator with unrecognized namespace {}::{}; "
                "If you are trying to export a custom operator, make sure you registered "
                "it with the right domain and version."
                "Otherwise please report a bug".format(ns, op_name))
            return None

    except TypeError as e:
        # Handle the specific case where we didn't successfully dispatch.
        # Otherwise, the backtrace will have the clues you need.
        e.args = ("{} (occurred when translating {})".format(
            e.args[0], op_name), )
        raise
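
The operator_export_type that steers the aten branch above is chosen by the caller at export time. A minimal sketch (the Linear model and output filename are placeholders) of requesting the ATen fallback behavior this function handles:

import torch

model = torch.nn.Linear(4, 2)
dummy_input = torch.randn(1, 4)

# With ONNX_ATEN_FALLBACK, ATen ops that lack a registered symbolic are
# emitted as ATen nodes (the _graph_at path) instead of aborting the export.
torch.onnx.export(
    model, dummy_input, "model.onnx",
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
    opset_version=9,
)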
Example 3
def _run_symbolic_function_with_in_place(
        g, n, inputs, env, operator_export_type=OperatorExportTypes.ONNX):
    """
    Monkey-patched version of `_run_symbolic_function` function in `torch.onnx.utils`.
    The only change is that trailing '_' is no longer removed from `ns_op_name` for
    the dropout function.
    """
    try:
        import torch
        from torch.onnx.symbolic_helper import (
            _export_onnx_opset_version as opset_version, )
        import torch.onnx.symbolic_registry as sym_registry

        sym_registry.register_version("", opset_version)
        if operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK:
            import torch.onnx.symbolic_caffe2

            torch.onnx.symbolic_caffe2.register_quantized_ops(
                "caffe2", opset_version)

        ns_op_name = n.kind()
        ns, op_name = ns_op_name.split("::")
        if n.kind().endswith("_"):
            if op_name not in ["dropout_", "feature_dropout_"]:
                op_name = op_name[:-1]

        if ns == "onnx":
            # Use the original node directly
            return None

        elif ns == "aten":
            is_exportable_aten_op = sym_registry.is_registered_op(
                op_name, "", opset_version)
            is_onnx_aten_export = operator_export_type == OperatorExportTypes.ONNX_ATEN
            is_aten_fallback_export = (
                operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK)
            if is_onnx_aten_export or (not is_exportable_aten_op
                                       and is_aten_fallback_export):
                # Direct ATen export requested
                attrs = {
                    k + "_" + n.kindOf(k)[0]: n[k]
                    for k in n.attributeNames()
                }
                outputs = n.outputsSize()
                attrs["outputs"] = outputs
                return torch.onnx.utils._graph_at(g,
                                                  op_name,
                                                  *inputs,
                                                  aten=True,
                                                  **attrs)

            else:
                # Export it regularly
                attrs = {k: n[k] for k in n.attributeNames()}
                if not is_exportable_aten_op:
                    warnings.warn(
                        "ONNX export failed on ATen operator {} because "
                        "torch.onnx.symbolic_opset{}.{} does not exist".format(
                            op_name, opset_version, op_name))
                op_fn = sym_registry.get_registered_op(op_name, "",
                                                       opset_version)
                return op_fn(g, *inputs, **attrs)

        elif ns == "prim":
            if op_name == "Constant" and not n.mustBeNone():
                if n.kindOf("value") == "t":
                    return g.op("Constant", value_t=n["value"])
                if n.kindOf("value") == "s":
                    return g.op("Constant", value_s=n["value"])
                elif n.kindOf("value") == "is":
                    value = (torch.stack([torch.tensor(v) for v in n["value"]])
                             if n["value"] else [])
                    return g.op("Constant", value_t=value)
                elif n.output().type().kind() == "DeviceObjType":
                    return None
                else:
                    raise RuntimeError(
                        "Unsupported prim::Constant kind: `{}`".format(
                            n.kindOf("value")))
            elif (n.mustBeNone() or op_name == "ListConstruct"
                  or op_name == "ListUnpack"):
                # None is not an ONNX operator; keep it as None
                # let the exporter handle finally eliminating these

                # For ListConstruct/ListUnpack, it will be erased in the
                # ONNX peephole pass
                return None
            elif op_name == "Loop" or op_name == "If":
                new_op_outputs = g.op(op_name,
                                      *inputs,
                                      outputs=n.outputsSize())
                new_node = (new_op_outputs[0].node()
                            if n.outputsSize() > 1 else new_op_outputs.node())
                for b in n.blocks():
                    new_block = new_node.addBlock()
                    torch._C._jit_pass_onnx_block(b, new_block,
                                                  operator_export_type, env)
                return new_op_outputs
            else:
                symbolic_name = "prim_" + op_name
                is_exportable = sym_registry.is_registered_op(
                    symbolic_name, "", opset_version)
                if not is_exportable:
                    warnings.warn(
                        "ONNX export failed on primitive operator {}".format(
                            op_name))
                symbolic_fn = sym_registry.get_registered_op(
                    symbolic_name, "", opset_version)
                attrs = {k: n[k] for k in n.attributeNames()}
                return symbolic_fn(g, *inputs, **attrs)

        elif ns == "quantized":
            domain = ""
            if operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK:
                domain = "caffe2"
            attrs = {k: n[k] for k in n.attributeNames()}

            if not sym_registry.is_registered_op(op_name, domain,
                                                 opset_version):
                warnings.warn(
                    "ONNX export failed on quantized operator {}::{} because "
                    "torch.onnx.symbolic_opset{}.{} does not exist. ".format(
                        ns, op_name, opset_version, op_name))
            op_fn = sym_registry.get_registered_op(op_name, domain,
                                                   opset_version)
            return op_fn(g, *inputs, **attrs)

        # custom ops
        elif sym_registry.is_registered_version(ns, opset_version):
            if not sym_registry.is_registered_op(op_name, ns, opset_version):
                warnings.warn(
                    "ONNX export failed on custom operator {}::{} because "
                    "torch.onnx.symbolic_opset{}.{} does not exist.".format(
                        ns, op_name, opset_version, op_name))
            symbolic_fn = sym_registry.get_registered_op(
                op_name, ns, opset_version)
            attrs = {k: n[k] for k in n.attributeNames()}
            return symbolic_fn(g, *inputs, **attrs)

        else:
            warnings.warn(
                "ONNX export failed on an operator with unrecognized namespace "
                "{}::{}; If you are trying to export a custom operator, "
                "make sure you registered it with the right domain and version."
                "Otherwise please report a bug".format(ns, op_name))
            return None

    except TypeError as e:
        # Handle the specific case where we didn't successfully dispatch.
        # Otherwise, the backtrace will have the clues you need.
        e.args = ("{} (occurred when translating {})".format(
            e.args[0], op_name), )
        raise
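
Since the docstring describes this as a monkey-patched replacement for torch.onnx.utils._run_symbolic_function, installing it is a plain attribute assignment done before calling the exporter; a minimal sketch:

import torch.onnx.utils

# Swap in the patched dispatcher so dropout_/feature_dropout_ keep their
# trailing underscore during symbolic dispatch (see the docstring above).
torch.onnx.utils._run_symbolic_function = _run_symbolic_function_with_in_place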