def register_quantized_ops(domain: str, version: int):
    """Register the Caffe2 quantized-op symbolics for ``domain`` at ``version``.

    Steps:

    1. Register the default-domain (``""``) symbolics for ``version``.
    2. Import ``torch.onnx.symbolic_caffe2`` and store it in the registry's
       version table under the key ``"caffe2"``.
    3. For every Python function attribute of that module (as reported by
       ``inspect.getmembers`` / ``inspect.isfunction``) that is not yet
       registered for ``(domain, version)``, register it under ``domain``.
       Functions whose names are in ``aten_q_ops`` are additionally
       registered under the default ``""`` domain.

    Args:
        domain: Domain to register the quantized symbolics under
            (e.g. ``"caffe2"``).
        version: ONNX opset version to register for.
    """
    # Register all the non-quantized ops.
    symbolic_registry.register_version("", version)

    # Register all quantized ops.
    module = importlib.import_module("torch.onnx.symbolic_caffe2")
    symbolic_registry._symbolic_versions["caffe2"] = module

    # ATen op names whose Caffe2 symbolic is also registered for the default
    # domain. This is loop-invariant, so it is built once rather than per member.
    aten_q_ops = frozenset({
        "relu",
        "_empty_affine_quantized",
        "dequantize",
        "quantize_per_tensor",
        "upsample_nearest2d",
        "avg_pool2d",
        "reshape",
        "slice",
        "cat",
        "max_pool2d",
        "sigmoid",
    })

    for op_name, op_fn in inspect.getmembers(module):
        if not inspect.isfunction(op_fn):
            continue
        # Skip members already registered for (domain, version). The
        # default-domain registration below is skipped for them as well.
        if symbolic_registry.is_registered_op(op_name, domain, version):
            continue
        if op_name in aten_q_ops:
            symbolic_registry.register_op(op_name, op_fn, "", version)
        symbolic_registry.register_op(op_name, op_fn, domain, version)
def register_quantized_ops(domain, version):
    """Register the Caffe2 quantized-op symbolics for ``domain`` at ``version``.

    First registers the default-domain (``''``) symbolics for ``version``.
    Then imports ``torch.onnx.symbolic_caffe2`` and stores it in the
    registry's version table under ``'caffe2'``. Finally, every Python
    function attribute of that module not yet registered for
    ``(domain, version)`` is registered under ``domain``. Names listed in
    ``aten_q_ops`` are also registered under the default ``''`` domain.

    Args:
        domain: Domain to register the quantized symbolics under
            (e.g. ``'caffe2'``).
        version: ONNX opset version to register for.
    """
    # Register all the non-quantized ops.
    sym_registry.register_version('', version)

    # Register all quantized ops.
    module = importlib.import_module('torch.onnx.symbolic_caffe2')
    sym_registry._symbolic_versions['caffe2'] = module

    # ATen op names whose Caffe2 symbolic is also registered for the default
    # domain. This is loop-invariant, so it is built once rather than per member.
    aten_q_ops = frozenset({'relu', '_empty_affine_quantized', 'dequantize',
                            'quantize_per_tensor', 'upsample_nearest2d'})

    for op_name, op_fn in getmembers(module):
        if not isfunction(op_fn):
            continue
        # Skip members already registered for (domain, version). The
        # default-domain registration below is skipped for them as well.
        if sym_registry.is_registered_op(op_name, domain, version):
            continue
        if op_name in aten_q_ops:
            sym_registry.register_op(op_name, op_fn, '', version)
        sym_registry.register_op(op_name, op_fn, domain, version)
def __init__(self, model: Callable, inputs: Any, **opts: Any):
    """Initialise the exporter for ``model`` and run the conversion right away.

    Args:
        model: Callable to export; kept as ``self.original_model``.
        inputs: Example inputs; normalised with ``_to_tuple_if_not_sequence``
            and kept as ``self.inputs``.
        **opts: Forwarded unchanged to the base-class ``__init__``.
    """
    super().__init__(**opts)

    # A missing dynamic-axes spec is normalised to an empty mapping.
    if self.dynamic_axes is None:
        self.dynamic_axes = dict()

    # Load the default-domain symbolic functions for the requested opset.
    assert self.opset_version is not None
    sym_reg.register_version("", self.opset_version)  # type: ignore[no-untyped-call]

    self.original_model = model
    self.inputs = _to_tuple_if_not_sequence(inputs)

    # Bookkeeping maps, created empty before ``_convert()`` runs.
    self.attrs: Dict[TorchValueID, ONNXValueID] = dict()
    self.node_doc_string: Dict[torch._C.Node, str] = dict()
    self.node_scope: Dict[torch._C.Node, str] = dict()

    self._convert()
from inspect import signature, _empty # type: ignore[attr-defined] from torch._C import _jit_get_all_schemas, FunctionSchema from torch.onnx.symbolic_registry import _registry, register_version from torch.onnx.symbolic_helper import _onnx_main_opset, _onnx_stable_opsets from typing import Dict, List, Union for v in _onnx_stable_opsets + [_onnx_main_opset]: register_version("", v) class _TorchSchema: def __init__(self, schema: Union[FunctionSchema, str]) -> None: if isinstance(schema, FunctionSchema): self.name: str = schema.name self.overload_name: str = schema.overload_name self.arguments: List[str] = [arg.name for arg in schema.arguments] self.optional_arguments: List[str] = [] self.returns: List[str] = [ret.name for ret in schema.returns] self.opsets: List[int] = [] else: self.name = schema self.overload_name = "" self.arguments = [] self.optional_arguments = [] self.returns = [] self.opsets = [] def __str__(self) -> str: s = f"{self.name}.{self.overload_name}(" s += ", ".join(self.arguments) s += ") -> ("
def _run_symbolic_function(g, n, inputs, env, operator_export_type=OperatorExportTypes.ONNX):
    """Translate TorchScript node ``n`` into ONNX, adding nodes to ``g``.

    A trailing ``_`` (in-place marker) is dropped from the node kind. The kind
    is then split into a namespace and an op name, and dispatch happens on
    the namespace:

    * ``onnx``: return ``None`` (the original node is used directly).
    * ``aten``: for ``ONNX_ATEN`` export, or ``ONNX_ATEN_FALLBACK`` export of
      an op with no registered symbolic, emit an ATen op through
      ``_graph_at``. Otherwise call the symbolic registered for the default
      (``''``) domain.
    * ``prim``: ``Constant`` is lowered inline. ``None`` constants,
      ``ListConstruct`` and ``ListUnpack`` return ``None``. ``Loop``/``If``
      are re-created with their sub-blocks translated by
      ``torch._C._jit_pass_onnx_block``. Anything else uses the
      ``prim_<op>`` symbolic.
    * a namespace registered as a custom domain for this opset: call the
      symbolic registered in that domain.
    * anything else: warn and return ``None``.

    When a symbolic is missing, a warning is issued but the registry lookup
    is still attempted. NOTE(review): what that lookup does for a missing op
    is not visible here.

    Args:
        g: Graph being built; new nodes are created with ``g.op(...)``.
        n: Source node being translated.
        inputs: Inputs for the translated node, passed positionally to the
            symbolic function (presumably already-translated values).
        env: Passed through to ``torch._C._jit_pass_onnx_block`` for
            ``Loop``/``If`` sub-blocks.
        operator_export_type: ``OperatorExportTypes`` member selecting ATen
            export / fallback behaviour.

    Returns:
        The result of the chosen symbolic function or ``g.op``, or ``None``
        meaning the node is cloned as is into the new graph.

    Raises:
        RuntimeError: For an unsupported ``prim::Constant`` kind.
        TypeError: Any TypeError raised during translation is re-raised with
            the op name appended to its message.
    """
    # NB: Returning None means the node gets cloned as is into
    # the new graph

    # Bound before the ``try`` so the TypeError handler can always name the
    # op, even if the error is raised before the node kind has been parsed.
    op_name = "<unknown>"
    try:
        from torch.onnx.symbolic_helper import _export_onnx_opset_version as opset_version
        import torch.onnx.symbolic_registry as sym_registry

        sym_registry.register_version('', opset_version)

        # See Note [Export inplace]
        # TODO: I think this is not necessary anymore
        if n.kind().endswith('_'):
            ns_op_name = n.kind()[:-1]
        else:
            ns_op_name = n.kind()
        ns, op_name = ns_op_name.split("::")
        if ns == "onnx":
            # Use the original node directly
            return None

        elif ns == "aten":
            is_exportable_aten_op = sym_registry.is_registered_op(
                op_name, '', opset_version)
            is_onnx_aten_export = operator_export_type == OperatorExportTypes.ONNX_ATEN
            is_aten_fallback_export = operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK
            if is_onnx_aten_export or (not is_exportable_aten_op and is_aten_fallback_export):
                # Direct ATen export requested. Attribute names get a suffix
                # built from the first character of their kind.
                attrs = {
                    k + "_" + n.kindOf(k)[0]: n[k]
                    for k in n.attributeNames()
                }
                outputs = n.outputsSize()
                attrs["outputs"] = outputs
                return _graph_at(g, op_name, *inputs, aten=True, **attrs)

            else:
                # Export it regularly
                attrs = {k: n[k] for k in n.attributeNames()}
                if not is_exportable_aten_op:
                    warnings.warn(
                        "ONNX export failed on ATen operator {} because "
                        "torch.onnx.symbolic_opset{}.{} does not exist".format(
                            op_name, opset_version, op_name))
                op_fn = sym_registry.get_registered_op(op_name, '', opset_version)
                return op_fn(g, *inputs, **attrs)

        elif ns == "prim":
            if op_name == "Constant" and not n.mustBeNone():
                if n.kindOf("value") == "t":
                    return g.op("Constant", value_t=n["value"])
                elif n.kindOf("value") == "is":
                    # NOTE(review): for an empty list a plain ``[]`` (not a
                    # tensor) is passed as ``value_t`` - confirm g.op accepts it.
                    value = torch.stack([torch.tensor(v) for v in n["value"]]) if n["value"] else []
                    return g.op("Constant", value_t=value)
                elif n.output().type().kind() == "DeviceObjType":
                    return None
                else:
                    raise RuntimeError(
                        "Unsupported prim::Constant kind: `{}`. Send a bug report."
                        .format(n.kindOf("value")))
            elif n.mustBeNone() or op_name == "ListConstruct" or op_name == "ListUnpack":
                # None is not an ONNX operator; keep it as None and
                # let the exporter handle finally eliminating these.
                # ListConstruct/ListUnpack are erased in the ONNX peephole pass.
                return None
            elif op_name == 'Loop' or op_name == 'If':
                new_op_outputs = g.op(op_name, *inputs, outputs=n.outputsSize())
                new_node = new_op_outputs[0].node() if n.outputsSize() > 1 else new_op_outputs.node()
                # Translate each sub-block into a fresh block on the new node.
                for b in n.blocks():
                    new_block = new_node.addBlock()
                    torch._C._jit_pass_onnx_block(b, new_block, operator_export_type, env)
                return new_op_outputs
            else:
                # TODO: we should lift prim's symbolic out
                symbolic_name = 'prim_' + op_name
                is_exportable = sym_registry.is_registered_op(
                    symbolic_name, '', opset_version)
                if not is_exportable:
                    warnings.warn(
                        "ONNX export failed on primitive operator {}; please report a bug"
                        .format(op_name))
                symbolic_fn = sym_registry.get_registered_op(
                    symbolic_name, '', opset_version)
                attrs = {k: n[k] for k in n.attributeNames()}
                return symbolic_fn(g, *inputs, **attrs)

        # custom ops
        elif sym_registry.is_registered_version(ns, opset_version):
            if not sym_registry.is_registered_op(op_name, ns, opset_version):
                warnings.warn(
                    "ONNX export failed on custom operator {}::{} because "
                    "torch.onnx.symbolic_opset{}.{} does not exist. "
                    "Have you registered your symbolic function with "
                    "torch.onnx.register_custom_op_symbolic(symbolic_name, symbolic_fn)?"
                    .format(ns, op_name, opset_version, op_name))
            symbolic_fn = sym_registry.get_registered_op(
                op_name, ns, opset_version)
            attrs = {k: n[k] for k in n.attributeNames()}
            return symbolic_fn(g, *inputs, **attrs)

        else:
            warnings.warn(
                "ONNX export failed on an operator with unrecognized namespace {}::{}; "
                "If you are trying to export a custom operator, make sure you registered "
                "it with the right domain and version."
                "Otherwise please report a bug".format(ns, op_name))
            return None
    except TypeError as e:
        # Handle the specific case where we didn't successfully dispatch.
        # Otherwise, the backtrace will have the clues you need.
        # A TypeError may carry no arguments; don't fail on e.args[0] then.
        message = e.args[0] if e.args else ""
        e.args = ("{} (occurred when translating {})".format(
            message, op_name), )
        raise
def setup(self, model, opset_ver):
    """Store ``model`` and ``opset_ver`` and register the opset's default-domain symbolics.

    Args:
        model: Kept on ``self.model``.
        opset_ver: ONNX opset version. It is passed to
            ``sym_reg.register_version`` for the default (``''``) domain and
            kept on ``self.opset_ver``.
    """
    self.model = model
    # Dry run so that every aten op symbolic for this opset gets registered.
    default_domain = ''
    sym_reg.register_version(default_domain, opset_ver)
    self.opset_ver = opset_ver
import inspect from typing import Dict, List, Union import torch._C from torch.onnx import _constants, symbolic_registry for v in _constants.onnx_stable_opsets: symbolic_registry.register_version("", v) symbolic_registry.register_version("", _constants.onnx_main_opset) class _TorchSchema: def __init__(self, schema: Union[torch._C.FunctionSchema, str]) -> None: if isinstance(schema, torch._C.FunctionSchema): self.name: str = schema.name self.overload_name: str = schema.overload_name self.arguments: List[str] = [arg.name for arg in schema.arguments] self.optional_arguments: List[str] = [] self.returns: List[str] = [ret.name for ret in schema.returns] self.opsets: List[int] = [] else: self.name = schema self.overload_name = "" self.arguments = [] self.optional_arguments = [] self.returns = [] self.opsets = [] def __str__(self) -> str: s = f"{self.name}.{self.overload_name}(" s += ", ".join(self.arguments)
import inspect from typing import Dict, List, Union import torch._C from torch.onnx import symbolic_helper, symbolic_registry for v in symbolic_helper._onnx_stable_opsets + [ symbolic_helper._onnx_main_opset ]: symbolic_registry.register_version("", v) class _TorchSchema: def __init__(self, schema: Union[torch._C.FunctionSchema, str]) -> None: if isinstance(schema, torch._C.FunctionSchema): self.name: str = schema.name self.overload_name: str = schema.overload_name self.arguments: List[str] = [arg.name for arg in schema.arguments] self.optional_arguments: List[str] = [] self.returns: List[str] = [ret.name for ret in schema.returns] self.opsets: List[int] = [] else: self.name = schema self.overload_name = "" self.arguments = [] self.optional_arguments = [] self.returns = [] self.opsets = [] def __str__(self) -> str: s = f"{self.name}.{self.overload_name}("
def _run_symbolic_function(g, n, inputs, env, operator_export_type=OperatorExportTypes.ONNX):
    """Translate TorchScript node ``n`` into ONNX, adding nodes to ``g``.

    Under ``ONNX_ATEN_FALLBACK`` at opset 9, the Caffe2 quantized symbolics
    are registered first. A trailing ``_`` (in-place marker) is then dropped
    from the node kind. The kind is split into a namespace and an op name,
    and dispatch happens on the namespace:

    * ``onnx``: return ``None`` (the original node is used directly).
    * ``aten``: for ``ONNX_ATEN`` export, or ``ONNX_ATEN_FALLBACK`` export of
      an op with no registered symbolic, emit an ATen op through
      ``_graph_at``. Otherwise look up a default-domain symbolic with
      ``_find_symbolic_in_registry``.
    * ``prim``: ``Constant`` is lowered inline. ``None`` constants,
      ``ListConstruct`` and ``ListUnpack`` return ``None``. ``Loop``/``If``
      are re-created with their sub-blocks translated, then passed through
      ``torch._C._jit_pass_fixup_onnx_controlflow_node``. Anything else uses
      the ``prim_<op>`` symbolic.
    * ``quantized``: symbolic from the ``'caffe2'`` domain under
      ``ONNX_ATEN_FALLBACK``, otherwise from the default domain.
    * a namespace registered as a custom domain for this opset: symbolic
      from that domain.
    * anything else: raise ``RuntimeError``.

    Whenever ``_find_symbolic_in_registry`` returns ``None``, this function
    returns ``None`` too.

    Args:
        g: Graph being built; new nodes are created with ``g.op(...)``.
        n: Source node being translated.
        inputs: Inputs for the translated node, passed positionally to the
            symbolic function (presumably already-translated values).
        env: Passed through to ``torch._C._jit_pass_onnx_block`` for
            ``Loop``/``If`` sub-blocks.
        operator_export_type: ``OperatorExportTypes`` member selecting ATen
            export / fallback / fall-through behaviour.

    Returns:
        The result of the chosen symbolic function or ``g.op``, or ``None``
        meaning the node is cloned as is into the new graph.

    Raises:
        RuntimeError: Any RuntimeError raised during translation (including
            an unsupported ``prim::Constant`` kind or an unknown namespace)
            propagates, unless exporting with ``ONNX_FALLTHROUGH``. In that
            case ``None`` is returned instead.
        TypeError: Re-raised with the op name appended to its message.
    """
    # NB: Returning None means the node gets cloned as is into
    # the new graph

    # Bound before the ``try`` so the TypeError handler can always name the
    # op, even if the error is raised before the node kind has been parsed.
    op_name = "<unknown>"
    try:
        import torch
        from torch.onnx.symbolic_helper import _export_onnx_opset_version as opset_version
        import torch.onnx.symbolic_registry as sym_registry

        sym_registry.register_version('', opset_version)

        # Quantized op symbolics are registered for opset 9 only.
        if operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK and opset_version == 9:
            import torch.onnx.symbolic_caffe2
            torch.onnx.symbolic_caffe2.register_quantized_ops('caffe2', opset_version)

        # See Note [Export inplace]
        # TODO: I think this is not necessary anymore
        if n.kind().endswith('_'):
            ns_op_name = n.kind()[:-1]
        else:
            ns_op_name = n.kind()
        ns, op_name = ns_op_name.split("::")
        if ns == "onnx":
            # Use the original node directly
            return None

        elif ns == "aten":
            is_exportable_aten_op = sym_registry.is_registered_op(op_name, '', opset_version)
            is_onnx_aten_export = operator_export_type == OperatorExportTypes.ONNX_ATEN
            is_aten_fallback_export = operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK
            if is_onnx_aten_export or (not is_exportable_aten_op and is_aten_fallback_export):
                # Direct ATen export requested. Attribute names get a suffix
                # built from the first character of their kind.
                attrs = {k + "_" + n.kindOf(k)[0]: n[k] for k in n.attributeNames()}
                outputs = n.outputsSize()
                attrs["outputs"] = outputs
                return _graph_at(g, op_name, *inputs, aten=True, **attrs)
            else:
                # Export it regularly
                domain = ''
                symbolic_fn = _find_symbolic_in_registry(domain, op_name, opset_version, operator_export_type)
                if symbolic_fn is None:
                    return None
                attrs = {k: n[k] for k in n.attributeNames()}
                return symbolic_fn(g, *inputs, **attrs)

        elif ns == "prim":
            if op_name == "Constant" and not n.mustBeNone():
                if n.kindOf("value") == "t":
                    return g.op("Constant", value_t=n["value"])
                if n.kindOf("value") == "s":
                    return g.op("Constant", value_s=n["value"])
                elif (n.output().type().isSubtypeOf(ListType.ofInts()) or
                        n.output().type().isSubtypeOf(ListType.ofFloats())):
                    vals = n.output().toIValue()
                    # NOTE(review): for an empty list a plain ``[]`` (not a
                    # tensor) is passed as ``value_t`` - confirm g.op accepts it.
                    value = torch.stack([torch.tensor(v) for v in vals]) if len(vals) else []
                    return g.op("Constant", value_t=value)
                elif n.output().type().kind() == "DeviceObjType":
                    return None
                else:
                    raise RuntimeError("Unsupported prim::Constant kind: `{}`. Send a bug report.".format(
                        n.kindOf("value")))
            elif n.mustBeNone() or op_name == "ListConstruct" or op_name == "ListUnpack":
                # None is not an ONNX operator; keep it as None
                # Let the exporter handle and finally eliminate these ops
                # ListConstruct and ListUnpack will be erased in the ONNX peephole pass
                return None
            elif op_name == 'Loop' or op_name == 'If':
                new_op_outputs = g.op(op_name, *inputs, outputs=n.outputsSize())
                new_node = new_op_outputs[0].node() if n.outputsSize() > 1 else new_op_outputs.node()
                # Translate each sub-block into a fresh block on the new node.
                for b in n.blocks():
                    new_block = new_node.addBlock()
                    torch._C._jit_pass_onnx_block(b, new_block, operator_export_type, env)
                new_op_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node(new_node, opset_version)
                return new_op_outputs
            else:
                symbolic_name = 'prim_' + op_name
                domain = ''
                symbolic_fn = _find_symbolic_in_registry(domain, symbolic_name, opset_version, operator_export_type)
                if symbolic_fn is None:
                    return None
                attrs = {k: n[k] for k in n.attributeNames()}
                return symbolic_fn(g, *inputs, **attrs)

        elif ns == "quantized":
            domain = ''
            if operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK:
                domain = 'caffe2'
            symbolic_fn = _find_symbolic_in_registry(domain, op_name, opset_version, operator_export_type)
            if symbolic_fn is None:
                return None
            attrs = {k: n[k] for k in n.attributeNames()}
            return symbolic_fn(g, *inputs, **attrs)

        # custom ops
        elif sym_registry.is_registered_version(ns, opset_version):
            domain = ns
            symbolic_fn = _find_symbolic_in_registry(domain, op_name, opset_version, operator_export_type)
            if symbolic_fn is None:
                return None
            attrs = {k: n[k] for k in n.attributeNames()}
            return symbolic_fn(g, *inputs, **attrs)
        else:
            raise RuntimeError("ONNX export failed on an operator with unrecognized namespace {}::{}. "
                               "If you are trying to export a custom operator, make sure you registered "
                               "it with the right domain and version. "
                               "Otherwise, please report a bug.".format(ns, op_name))
    except RuntimeError:
        # With ONNX_FALLTHROUGH, any translation failure leaves the node as is.
        if operator_export_type == OperatorExportTypes.ONNX_FALLTHROUGH:
            return None
        raise
    except TypeError as e:
        # Handle the specific case where we didn't successfully dispatch.
        # Otherwise, the backtrace will have the clues you need.
        # A TypeError may carry no arguments; don't fail on e.args[0] then.
        message = e.args[0] if e.args else ""
        e.args = ("{} \n(Occurred when translating {}).".format(message, op_name),)
        raise
def setup(self, model: nn.Module, opset_ver: int) -> None:
    """Store ``model`` and ``opset_ver`` and register the opset's default-domain symbolics.

    Args:
        model: Kept on ``self._model``.
        opset_ver: ONNX opset version. It is passed to
            ``sym_reg.register_version`` for the default (``''``) domain and
            kept on ``self.opset_ver``.
    """
    self._model: Optional[nn.Module] = model
    # Dry run so that every aten op symbolic for this opset gets registered.
    default_domain = ''
    sym_reg.register_version(default_domain, opset_ver)  # type: ignore[no-untyped-call]
    self.opset_ver = opset_ver
def _run_symbolic_function_with_in_place(
        g, n, inputs, env, operator_export_type=OperatorExportTypes.ONNX):
    """
    Monkey-patched version of `_run_symbolic_function` function in `torch.onnx.utils`.
    The only change is that trailing '_' is no longer removed from `ns_op_name` for
    the dropout function.

    Concretely, a trailing ``_`` is stripped from the op name unless the op is
    ``dropout_`` or ``feature_dropout_``. Those two are looked up under their
    in-place names. Under ``ONNX_ATEN_FALLBACK`` the Caffe2 quantized
    symbolics are registered first (for any opset).

    Dispatch is on the namespace of ``n.kind()``:

    * ``onnx``: return ``None`` (the original node is used directly).
    * ``aten``: ATen export through ``torch.onnx.utils._graph_at`` for
      ``ONNX_ATEN``, or for ``ONNX_ATEN_FALLBACK`` when no symbolic is
      registered. Otherwise the default-domain symbolic is called.
    * ``prim``: ``Constant`` lowered inline. ``None`` constants,
      ``ListConstruct`` and ``ListUnpack`` return ``None``. ``Loop``/``If``
      re-created with their sub-blocks translated. Anything else uses
      ``prim_<op>``.
    * ``quantized``: symbolic from ``"caffe2"`` under ``ONNX_ATEN_FALLBACK``,
      otherwise from the default domain.
    * a namespace registered as a custom domain for this opset: symbolic from
      that domain.
    * anything else: warn and return ``None``.

    When a symbolic is missing, a warning is issued but the registry lookup
    is still attempted. NOTE(review): what that lookup does for a missing op
    is not visible here.

    Args:
        g: Graph being built; new nodes are created with ``g.op(...)``.
        n: Source node being translated.
        inputs: Inputs for the translated node, passed positionally to the
            symbolic function (presumably already-translated values).
        env: Passed through to ``torch._C._jit_pass_onnx_block`` for
            ``Loop``/``If`` sub-blocks.
        operator_export_type: ``OperatorExportTypes`` member selecting ATen
            export / fallback behaviour.

    Returns:
        The result of the chosen symbolic function or ``g.op``, or ``None``
        meaning the node is kept as is.

    Raises:
        RuntimeError: For an unsupported ``prim::Constant`` kind.
        TypeError: Re-raised with the op name appended to its message.
    """
    # Bound before the ``try`` so the TypeError handler can always name the
    # op, even if the error is raised before the node kind has been parsed.
    op_name = "<unknown>"
    try:
        import torch
        from torch.onnx.symbolic_helper import (
            _export_onnx_opset_version as opset_version,
        )
        import torch.onnx.symbolic_registry as sym_registry

        sym_registry.register_version("", opset_version)
        if operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK:
            import torch.onnx.symbolic_caffe2

            torch.onnx.symbolic_caffe2.register_quantized_ops(
                "caffe2", opset_version)

        ns_op_name = n.kind()
        ns, op_name = ns_op_name.split("::")
        # Strip the in-place marker, except for the dropout variants.
        if n.kind().endswith("_"):
            if op_name not in ["dropout_", "feature_dropout_"]:
                op_name = op_name[:-1]

        if ns == "onnx":
            # Use the original node directly
            return None

        elif ns == "aten":
            is_exportable_aten_op = sym_registry.is_registered_op(
                op_name, "", opset_version)
            is_onnx_aten_export = operator_export_type == OperatorExportTypes.ONNX_ATEN
            is_aten_fallback_export = (
                operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK)
            if is_onnx_aten_export or (not is_exportable_aten_op
                                       and is_aten_fallback_export):
                # Direct ATen export requested. Attribute names get a suffix
                # built from the first character of their kind.
                attrs = {
                    k + "_" + n.kindOf(k)[0]: n[k]
                    for k in n.attributeNames()
                }
                outputs = n.outputsSize()
                attrs["outputs"] = outputs
                return torch.onnx.utils._graph_at(g, op_name, *inputs, aten=True, **attrs)

            else:
                # Export it regularly
                attrs = {k: n[k] for k in n.attributeNames()}
                if not is_exportable_aten_op:
                    warnings.warn(
                        "ONNX export failed on ATen operator {} because "
                        "torch.onnx.symbolic_opset{}.{} does not exist".format(
                            op_name, opset_version, op_name))
                op_fn = sym_registry.get_registered_op(op_name, "", opset_version)
                return op_fn(g, *inputs, **attrs)

        elif ns == "prim":
            if op_name == "Constant" and not n.mustBeNone():
                if n.kindOf("value") == "t":
                    return g.op("Constant", value_t=n["value"])
                if n.kindOf("value") == "s":
                    return g.op("Constant", value_s=n["value"])
                elif n.kindOf("value") == "is":
                    # NOTE(review): for an empty list a plain ``[]`` (not a
                    # tensor) is passed as ``value_t`` - confirm g.op accepts it.
                    value = (torch.stack([torch.tensor(v) for v in n["value"]])
                             if n["value"] else [])
                    return g.op("Constant", value_t=value)
                elif n.output().type().kind() == "DeviceObjType":
                    return None
                else:
                    raise RuntimeError(
                        "Unsupported prim::Constant kind: `{}`".format(
                            n.kindOf("value")))
            elif (n.mustBeNone() or op_name == "ListConstruct"
                  or op_name == "ListUnpack"):
                # None is not an ONNX operator; keep it as None and
                # let the exporter handle finally eliminating these.
                # ListConstruct/ListUnpack are erased in the
                # ONNX peephole pass.
                return None
            elif op_name == "Loop" or op_name == "If":
                new_op_outputs = g.op(op_name, *inputs, outputs=n.outputsSize())
                new_node = (new_op_outputs[0].node()
                            if n.outputsSize() > 1 else new_op_outputs.node())
                # Translate each sub-block into a fresh block on the new node.
                for b in n.blocks():
                    new_block = new_node.addBlock()
                    torch._C._jit_pass_onnx_block(b, new_block, operator_export_type, env)
                return new_op_outputs
            else:
                symbolic_name = "prim_" + op_name
                is_exportable = sym_registry.is_registered_op(
                    symbolic_name, "", opset_version)
                if not is_exportable:
                    warnings.warn(
                        "ONNX export failed on primitive operator {}".format(
                            op_name))
                symbolic_fn = sym_registry.get_registered_op(
                    symbolic_name, "", opset_version)
                attrs = {k: n[k] for k in n.attributeNames()}
                return symbolic_fn(g, *inputs, **attrs)

        elif ns == "quantized":
            domain = ""
            if operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK:
                domain = "caffe2"
            attrs = {k: n[k] for k in n.attributeNames()}

            if not sym_registry.is_registered_op(op_name, domain, opset_version):
                warnings.warn(
                    "ONNX export failed on quantized operator {}::{} because "
                    "torch.onnx.symbolic_opset{}.{} does not exist. ".format(
                        ns, op_name, opset_version, op_name))

            op_fn = sym_registry.get_registered_op(op_name, domain, opset_version)
            return op_fn(g, *inputs, **attrs)

        # custom ops
        elif sym_registry.is_registered_version(ns, opset_version):
            if not sym_registry.is_registered_op(op_name, ns, opset_version):
                warnings.warn(
                    "ONNX export failed on custom operator {}::{} because "
                    "torch.onnx.symbolic_opset{}.{} does not exist.".format(
                        ns, op_name, opset_version, op_name))
            symbolic_fn = sym_registry.get_registered_op(
                op_name, ns, opset_version)
            attrs = {k: n[k] for k in n.attributeNames()}
            return symbolic_fn(g, *inputs, **attrs)

        else:
            warnings.warn(
                "ONNX export failed on an operator with unrecognized namespace "
                "{}::{}; If you are trying to export a custom operator, "
                "make sure you registered it with the right domain and version."
                "Otherwise please report a bug".format(ns, op_name))
            return None
    except TypeError as e:
        # Handle the specific case where we didn't successfully dispatch.
        # Otherwise, the backtrace will have the clues you need.
        # A TypeError may carry no arguments; don't fail on e.args[0] then.
        message = e.args[0] if e.args else ""
        e.args = ("{} (occurred when translating {})".format(
            message, op_name), )
        raise