def unregister(name, opset_version):
    """Drop the symbolic registered for ``name`` from the private registry.

    ``name`` has the form ``"<namespace>::<op>"`` (``str.split`` raises
    ValueError unpacking otherwise). Every stable opset that is
    ``>= opset_version`` and has the op registered gets its entry deleted from
    ``sym_registry._registry``; other opsets are left untouched.
    """
    namespace, op_kind = name.split("::")
    from torch.onnx.symbolic_helper import _onnx_stable_opsets

    # Lazily filtered so the version test and the registry test still run
    # interleaved, one opset at a time, exactly as a plain loop would.
    eligible_opsets = (v for v in _onnx_stable_opsets if v >= opset_version)
    for opset in eligible_opsets:
        if not sym_registry.is_registered_op(op_kind, namespace, opset):
            continue
        del sym_registry._registry[(namespace, opset)][op_kind]
def register_quantized_ops(domain: str, version: int):
    """Register the caffe2 quantized-op symbolics under ``domain`` at opset ``version``.

    The default ("") domain is registered for ``version`` first. Then every
    function object returned by ``inspect.getmembers`` on
    ``torch.onnx.symbolic_caffe2`` that is not yet registered under ``domain``
    is registered there. Functions whose name is in ``aten_q_ops`` are
    additionally registered under the default ("") domain.

    Args:
        domain: registry domain the caffe2 symbolics are registered under.
        version: ONNX opset version to register for.
    """
    # Loop-invariant: built once per call instead of once per loop iteration.
    aten_q_ops = frozenset((
        "relu",
        "_empty_affine_quantized",
        "dequantize",
        "quantize_per_tensor",
        "upsample_nearest2d",
        "avg_pool2d",
        "reshape",
        "slice",
        "cat",
        "max_pool2d",
        "sigmoid",
    ))
    # Register all the non-quantized ops
    symbolic_registry.register_version("", version)
    # Register all quantized ops
    module = importlib.import_module("torch.onnx.symbolic_caffe2")
    symbolic_registry._symbolic_versions["caffe2"] = module
    for op_name, op_fn in inspect.getmembers(module):
        # Only plain functions are symbolics; skip ones already registered.
        if not inspect.isfunction(op_fn):
            continue
        if symbolic_registry.is_registered_op(op_name, domain, version):
            continue
        if op_name in aten_q_ops:
            symbolic_registry.register_op(op_name, op_fn, "", version)
        symbolic_registry.register_op(op_name, op_fn, domain, version)
def _find_symbolic_in_registry(domain, op_name, opset_version, operator_export_type):
    """Return the symbolic registered for ``op_name`` in ``domain`` at ``opset_version``.

    If the op is not registered and ``operator_export_type`` is
    ``OperatorExportTypes.ONNX_FALLTHROUGH``, ``None`` is returned so the caller
    can keep the original node. In every other case the result of
    ``sym_registry.get_registered_op`` is returned. For an unregistered op, what
    happens is up to that call.
    """
    import torch.onnx.symbolic_registry as sym_registry

    is_known = sym_registry.is_registered_op(op_name, domain, opset_version)
    # Short-circuit keeps the export-type comparison on the unregistered path only.
    if not is_known and operator_export_type == OperatorExportTypes.ONNX_FALLTHROUGH:
        # Use the original node directly
        return None
    return sym_registry.get_registered_op(op_name, domain, opset_version)
def wrapped_function(*args, **kwargs):
    # Wrapper produced by an enclosing decorator that is not visible here.
    # ``op_name``, ``func``, ``namespace`` and ``adapter`` are closure variables
    # from that enclosing scope.
    #
    # If a symbolic is registered for the op at the current export opset, the call
    # is routed through a freshly defined torch.autograd.Function whose
    # ``symbolic`` staticmethod delegates to the registered symbolic. Otherwise
    # ``func`` is simply called.
    name = op_name if op_name is not None else func.__name__
    opset = sym_help._export_onnx_opset_version
    if is_registered_op(name, namespace, opset):

        # A new class is created on every call.
        class XFunction(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *xargs):
                # The actual result comes from the wrapper's original
                # ``args``/``kwargs``; the ``xargs`` passed to ``apply`` are
                # not used here.
                return func(*args, **kwargs)

            @staticmethod
            def symbolic(g, *xargs):
                symb = get_registered_op(name, namespace, opset)
                if adapter is not None:
                    # ``adapter_kwargs`` is a late-bound closure over the local
                    # assigned below (before ``XFunction.apply`` is called). It is
                    # only read when ``adapter`` is set, which is also the only
                    # case in which it gets assigned.
                    return symb(g, *xargs, **adapter_kwargs)
                return symb(g, *xargs)

        if adapter is not None:
            # ``adapter`` maps the original call into (positional values for
            # ``apply``, extra keyword args forwarded only to the symbolic).
            adapter_args, adapter_kwargs = adapter(*args, **kwargs)
            return XFunction.apply(*adapter_args)
        return XFunction.apply(*args)
    else:
        return func(*args, **kwargs)
def unregister():
    """Unregister ONNX Runtime's built-in contrib ops.

    Each entry of ``_registered_ops`` is a ``"<namespace>::<op>"`` string. For
    every stable opset at or above ``_OPSET_VERSION`` where that op is
    registered, its entry is deleted from the private ``sym_registry._registry``.
    """
    # TODO: replace this once PyTorch supports unregister natively.
    # https://msdata.visualstudio.com/Vienna/_workitems/edit/1342343
    for qualified_name in _registered_ops:
        domain, op_kind = qualified_name.split("::")
        for opset in sym_help._onnx_stable_opsets:
            if opset < _OPSET_VERSION:
                continue
            if sym_registry.is_registered_op(op_kind, domain, opset):
                del sym_registry._registry[(domain, opset)][op_kind]
def unregister():
    """Unregister ONNX Runtime's built-in contrib ops.

    Each name in ``_registered_ops`` (``"<namespace>::<op>"``) is unregistered at
    ``_OPSET_VERSION`` via ``torch.onnx.unregister_custom_op_symbolic`` when the
    installed PyTorch provides it. Otherwise the entry is deleted directly from the
    private ``sym_registry._registry`` for every stable opset
    ``>= _OPSET_VERSION`` where it is registered.
    """
    # Feature-detect once rather than wrapping the call in
    # ``except AttributeError``. Doing the latter would also route an
    # AttributeError raised *inside* unregister_custom_op_symbolic to the
    # private-registry fallback, hiding the real error.
    unregister_symbolic = getattr(torch.onnx, "unregister_custom_op_symbolic", None)
    for name in _registered_ops:
        if unregister_symbolic is not None:
            unregister_symbolic(name, _OPSET_VERSION)
            continue
        # unregister_custom_op_symbolic is not available before PyTorch 1.12
        namespace, kind = name.split("::")
        for version in sym_help._onnx_stable_opsets:
            if version >= _OPSET_VERSION and sym_registry.is_registered_op(kind, namespace, version):
                del sym_registry._registry[(namespace, version)][kind]
def register_quantized_ops(domain, version):
    """Register the caffe2 quantized-op symbolics under ``domain`` at opset ``version``.

    The default ('') domain is registered for ``version`` first. Then every
    function object returned by ``getmembers`` on ``torch.onnx.symbolic_caffe2``
    that is not yet registered under ``domain`` is registered there. Functions
    whose name is in ``aten_q_ops`` are additionally registered under the
    default ('') domain.

    Args:
        domain: registry domain the caffe2 symbolics are registered under.
        version: ONNX opset version to register for.
    """
    # Loop-invariant: built once per call instead of once per loop iteration.
    aten_q_ops = frozenset(('relu', '_empty_affine_quantized', 'dequantize',
                            'quantize_per_tensor', 'upsample_nearest2d'))
    # Register all the non-quantized ops
    sym_registry.register_version('', version)
    # Register all quantized ops
    module = importlib.import_module('torch.onnx.symbolic_caffe2')
    sym_registry._symbolic_versions['caffe2'] = module
    for op_name, op_fn in getmembers(module):
        # Only plain functions are symbolics; skip ones already registered.
        if not isfunction(op_fn) or sym_registry.is_registered_op(op_name, domain, version):
            continue
        if op_name in aten_q_ops:
            sym_registry.register_op(op_name, op_fn, '', version)
        sym_registry.register_op(op_name, op_fn, domain, version)
def symbolic_function(self, n: torch._C.Node) -> Optional[Callable]:
    """Return the symbolic callable to use for node ``n``, or ``None`` if there is none.

    * ``prim::PythonOp``: the ``symbolic`` staticmethod of the node's
      ``torch.autograd.Function`` class. If ``n.pyobj().__self__`` is such a
      subclass, that class is used. The assertions require the class to define
      ``symbolic``.
    * Anything else: the op name is looked up in the default ("") domain at
      ``self.opset_version``. A trailing ``_`` (in-place variant) is stripped
      first, and ``prim`` ops are looked up as ``prim_<op>``.
    """
    namespace, op_name = n.kind().split("::")
    # In-place variants share the symbolic of the out-of-place op.
    if op_name.endswith("_"):
        op_name = op_name[:-1]

    if namespace == "prim" and op_name == "PythonOp":
        autograd_fn = n.pyobj()
        if issubclass(autograd_fn.__self__, torch.autograd.Function):
            autograd_fn = autograd_fn.__self__
        assert issubclass(autograd_fn, torch.autograd.Function)
        assert hasattr(autograd_fn, "symbolic"), f"symbolic method not supported in {autograd_fn}"
        # TODO(twata): Use repr(pyobj) in scope name or doc_string
        return cast(Callable, autograd_fn.symbolic)

    lookup_name = f"prim_{op_name}" if namespace == "prim" else op_name
    if not sym_reg.is_registered_op(lookup_name, "", self.opset_version):  # type: ignore[no-untyped-call]
        return None
    return cast(
        Callable,
        sym_reg.get_registered_op(lookup_name, "", self.opset_version),  # type: ignore[no-untyped-call]
    )
def __init__(self, inner, namespace, name, verbose=False, force_rebuild=True, unique_name=True):
    """Wrap ``inner`` as the custom op ``namespace::name`` and register its ONNX symbolic.

    Args:
        inner: wrapped object, stored as ``self.inner``.
        namespace: op namespace; combined with the (possibly suffixed) name into
            ``self.qualified_name``.
        name: base op name.
        verbose: stored as ``self.verbose``.
        force_rebuild: stored as ``self.force_rebuild``.
        unique_name: when true, ``'_' + <tempfile candidate name> + 'x'`` is
            appended to ``name`` (presumably so that separate instances get
            distinct op names).
    """
    super().__init__()
    self.inner = inner
    self.namespace = namespace
    self.name = name
    if unique_name:
        # NOTE(review): tempfile._get_candidate_names is a private stdlib helper.
        self.name = name + '_' + next(
            tempfile._get_candidate_names()) + 'x'
    self.qualified_name = '{}::{}'.format(self.namespace, self.name)
    self.num_outputs = 1
    self.params = {}
    self.verbose = verbose
    self.force_rebuild = force_rebuild
    # Register symbolic function for ONNX export (opset 10). A single ``if``
    # is enough. The previous ``while`` loop would spin forever if
    # ``register_op`` ever failed to make ``is_registered_op`` true.
    if not is_registered_op(self.name, self.namespace, 10):
        register_op(self.name, self.symbolic, self.namespace, 10)
def _run_symbolic_function(g, n, inputs, env, operator_export_type=OperatorExportTypes.ONNX):
    """Translate the TorchScript node ``n`` into ONNX by dispatching on its namespace.

    Args:
        g: graph object the ONNX ops are emitted into (via ``g.op`` / ``_graph_at``).
        n: the TorchScript node being translated.
        inputs: values passed positionally to the selected symbolic function.
        env: forwarded to ``torch._C._jit_pass_onnx_block`` for the sub-blocks
            of ``prim::Loop`` / ``prim::If``.
        operator_export_type: ``ONNX_ATEN`` exports every aten op as an ATen
            node. ``ONNX_ATEN_FALLBACK`` does so only for aten ops without a
            registered symbolic.

    Returns:
        The output value(s) of the emitted op(s), or ``None``. Returning ``None``
        means the node gets cloned as-is into the new graph.

    Raises:
        RuntimeError: for an unsupported ``prim::Constant`` kind.
        TypeError: re-raised from a failed symbolic dispatch, with the op being
            translated appended to the message.
    """
    # Label for the TypeError handler below. It is bound before the ``try`` so
    # the handler never hits an unbound ``op_name`` when the failure happens
    # before the node kind has been split.
    op_name = n.kind()
    try:
        from torch.onnx.symbolic_helper import _export_onnx_opset_version as opset_version
        import torch.onnx.symbolic_registry as sym_registry

        sym_registry.register_version('', opset_version)

        # See Note [Export inplace]
        # TODO: I think this is not necessary anymore
        if n.kind().endswith('_'):
            ns_op_name = n.kind()[:-1]
        else:
            ns_op_name = n.kind()
        ns, op_name = ns_op_name.split("::")

        if ns == "onnx":
            # Use the original node directly
            return None

        elif ns == "aten":
            is_exportable_aten_op = sym_registry.is_registered_op(op_name, '', opset_version)
            is_onnx_aten_export = operator_export_type == OperatorExportTypes.ONNX_ATEN
            is_aten_fallback_export = operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK
            if is_onnx_aten_export or (not is_exportable_aten_op and is_aten_fallback_export):
                # Direct ATen export requested. Attribute names get a "_<kind>"
                # suffix built from the first character of n.kindOf(k).
                attrs = {k + "_" + n.kindOf(k)[0]: n[k] for k in n.attributeNames()}
                outputs = n.outputsSize()
                attrs["outputs"] = outputs
                return _graph_at(g, op_name, *inputs, aten=True, **attrs)
            else:
                # Export it regularly
                attrs = {k: n[k] for k in n.attributeNames()}
                if not is_exportable_aten_op:
                    warnings.warn(
                        "ONNX export failed on ATen operator {} because "
                        "torch.onnx.symbolic_opset{}.{} does not exist".format(
                            op_name, opset_version, op_name))
                op_fn = sym_registry.get_registered_op(op_name, '', opset_version)
                return op_fn(g, *inputs, **attrs)

        elif ns == "prim":
            if op_name == "Constant" and not n.mustBeNone():
                if n.kindOf("value") == "t":
                    return g.op("Constant", value_t=n["value"])
                elif n.kindOf("value") == "is":
                    # Integer-list constant: stacked into a tensor (or [] if empty).
                    value = torch.stack([torch.tensor(v) for v in n["value"]]) if n["value"] else []
                    return g.op("Constant", value_t=value)
                elif n.output().type().kind() == "DeviceObjType":
                    return None
                else:
                    raise RuntimeError(
                        "Unsupported prim::Constant kind: `{}`. Send a bug report.".format(
                            n.kindOf("value")))
            elif n.mustBeNone() or op_name == "ListConstruct" or op_name == "ListUnpack":
                # None is not an ONNX operator; keep it as None and let the
                # exporter handle finally eliminating these.
                # For ListConstruct/ListUnpack, it will be erased in the ONNX peephole pass
                return None
            elif op_name == 'Loop' or op_name == 'If':
                new_op_outputs = g.op(op_name, *inputs, outputs=n.outputsSize())
                new_node = new_op_outputs[0].node() if n.outputsSize() > 1 else new_op_outputs.node()
                # Translate each sub-block of the control-flow node recursively.
                for b in n.blocks():
                    new_block = new_node.addBlock()
                    torch._C._jit_pass_onnx_block(b, new_block, operator_export_type, env)
                return new_op_outputs
            else:
                # TODO: we should lift prim's symbolic out
                symbolic_name = 'prim_' + op_name
                is_exportable = sym_registry.is_registered_op(symbolic_name, '', opset_version)
                if not is_exportable:
                    warnings.warn(
                        "ONNX export failed on primitive operator {}; please report a bug".format(op_name))
                symbolic_fn = sym_registry.get_registered_op(symbolic_name, '', opset_version)
                attrs = {k: n[k] for k in n.attributeNames()}
                return symbolic_fn(g, *inputs, **attrs)

        # custom ops
        elif sym_registry.is_registered_version(ns, opset_version):
            if not sym_registry.is_registered_op(op_name, ns, opset_version):
                warnings.warn(
                    "ONNX export failed on custom operator {}::{} because "
                    "torch.onnx.symbolic_opset{}.{} does not exist. "
                    "Have you registered your symbolic function with "
                    "torch.onnx.register_custom_op_symbolic(symbolic_name, symbolic_fn)?"
                    .format(ns, op_name, opset_version, op_name))
            symbolic_fn = sym_registry.get_registered_op(op_name, ns, opset_version)
            attrs = {k: n[k] for k in n.attributeNames()}
            return symbolic_fn(g, *inputs, **attrs)

        else:
            # Separating space after "version." added; the pieces used to run
            # together as "version.Otherwise".
            warnings.warn(
                "ONNX export failed on an operator with unrecognized namespace {}::{}; "
                "If you are trying to export a custom operator, make sure you registered "
                "it with the right domain and version. "
                "Otherwise please report a bug".format(ns, op_name))
            return None
    except TypeError as e:
        # Handle the specific case where we didn't successfully dispatch.
        # Otherwise, the backtrace will have the clues you need.
        detail = e.args[0] if e.args else ""
        e.args = ("{} (occurred when translating {})".format(detail, op_name), )
        raise
def _run_symbolic_function(g, n, inputs, env, operator_export_type=OperatorExportTypes.ONNX):
    """Translate the TorchScript node ``n`` into ONNX by dispatching on its namespace.

    Symbolic lookup for aten, prim, quantized and custom-namespace ops goes
    through ``_find_symbolic_in_registry``. A ``None`` result from that helper
    is returned as-is.

    Returns:
        The output value(s) of the emitted op(s), or ``None``. Returning ``None``
        means the node gets cloned as-is into the new graph.

    Raises:
        RuntimeError: any RuntimeError raised inside the body (including the
            explicit raises below) propagates, except under
            ``OperatorExportTypes.ONNX_FALLTHROUGH``, where it is swallowed
            and ``None`` is returned.
        TypeError: re-raised from a failed dispatch with the op name appended.
    """
    # NB: Returning None means the node gets cloned as is into
    # the new graph
    try:
        import torch
        from torch.onnx.symbolic_helper import _export_onnx_opset_version as opset_version
        import torch.onnx.symbolic_registry as sym_registry

        sym_registry.register_version('', opset_version)

        # Quantized op symbolics are registered for opset 9 only.
        if operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK and opset_version == 9:
            import torch.onnx.symbolic_caffe2
            torch.onnx.symbolic_caffe2.register_quantized_ops('caffe2', opset_version)

        # See Note [Export inplace]
        # TODO: I think this is not necessary anymore
        if n.kind().endswith('_'):
            ns_op_name = n.kind()[:-1]
        else:
            ns_op_name = n.kind()
        ns, op_name = ns_op_name.split("::")
        if ns == "onnx":
            # Use the original node directly
            return None

        elif ns == "aten":
            is_exportable_aten_op = sym_registry.is_registered_op(op_name, '', opset_version)
            is_onnx_aten_export = operator_export_type == OperatorExportTypes.ONNX_ATEN
            is_aten_fallback_export = operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK
            if is_onnx_aten_export or (not is_exportable_aten_op and is_aten_fallback_export):
                # Direct ATen export requested. Attribute names get a "_<kind>"
                # suffix built from the first character of n.kindOf(k).
                attrs = {k + "_" + n.kindOf(k)[0]: n[k] for k in n.attributeNames()}
                outputs = n.outputsSize()
                attrs["outputs"] = outputs
                return _graph_at(g, op_name, *inputs, aten=True, **attrs)
            else:
                # Export it regularly
                domain = ''
                symbolic_fn = _find_symbolic_in_registry(domain, op_name, opset_version, operator_export_type)
                if symbolic_fn is None:
                    return None
                attrs = {k: n[k] for k in n.attributeNames()}
                return symbolic_fn(g, *inputs, **attrs)

        elif ns == "prim":
            if op_name == "Constant" and not n.mustBeNone():
                if n.kindOf("value") == "t":
                    return g.op("Constant", value_t=n["value"])
                if n.kindOf("value") == "s":
                    return g.op("Constant", value_s=n["value"])
                elif n.output().type().isSubtypeOf(ListType.ofInts()) or n.output().type().isSubtypeOf(ListType.ofFloats()):
                    # Int/float list constant: stacked into a tensor (or [] if empty).
                    vals = n.output().toIValue()
                    value = torch.stack([torch.tensor(v) for v in vals]) if len(vals) else []
                    return g.op("Constant", value_t=value)
                elif n.output().type().kind() == "DeviceObjType":
                    return None
                else:
                    raise RuntimeError("Unsupported prim::Constant kind: `{}`. Send a bug report.".format(
                        n.kindOf("value")))
            elif n.mustBeNone() or op_name == "ListConstruct" or op_name == "ListUnpack":
                # None is not an ONNX operator; keep it as None
                # Let the exporter handle and finally eliminate these ops
                # ListConstruct and ListUnpack will be erased in the ONNX peephole pass
                return None
            elif op_name == 'Loop' or op_name == 'If':
                new_op_outputs = g.op(op_name, *inputs, outputs=n.outputsSize())
                new_node = new_op_outputs[0].node() if n.outputsSize() > 1 else new_op_outputs.node()
                # Translate each sub-block of the control-flow node recursively.
                for b in n.blocks():
                    new_block = new_node.addBlock()
                    torch._C._jit_pass_onnx_block(b, new_block, operator_export_type, env)
                # The node's outputs are replaced by whatever the fix-up pass returns.
                new_op_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node(new_node, opset_version)
                return new_op_outputs
            else:
                # Other prim ops are looked up as "prim_<op>" in the default domain.
                symbolic_name = 'prim_' + op_name
                domain = ''
                symbolic_fn = _find_symbolic_in_registry(domain, symbolic_name, opset_version, operator_export_type)
                if symbolic_fn is None:
                    return None
                attrs = {k: n[k] for k in n.attributeNames()}
                return symbolic_fn(g, *inputs, **attrs)

        elif ns == "quantized":
            # With ATen fallback, quantized symbolics come from the "caffe2" domain.
            domain = ''
            if operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK:
                domain = 'caffe2'
            symbolic_fn = _find_symbolic_in_registry(domain, op_name, opset_version, operator_export_type)
            if symbolic_fn is None:
                return None
            attrs = {k: n[k] for k in n.attributeNames()}
            return symbolic_fn(g, *inputs, **attrs)

        # custom ops
        elif sym_registry.is_registered_version(ns, opset_version):
            domain = ns
            symbolic_fn = _find_symbolic_in_registry(domain, op_name, opset_version, operator_export_type)
            if symbolic_fn is None:
                return None
            attrs = {k: n[k] for k in n.attributeNames()}
            return symbolic_fn(g, *inputs, **attrs)
        else:
            raise RuntimeError("ONNX export failed on an operator with unrecognized namespace {}::{}. "
                               "If you are trying to export a custom operator, make sure you registered "
                               "it with the right domain and version. "
                               "Otherwise, please report a bug.".format(ns, op_name))
    except RuntimeError:
        # Catches every RuntimeError from the body, including failures inside
        # symbolic functions, not just the explicit raises above.
        if operator_export_type == OperatorExportTypes.ONNX_FALLTHROUGH:
            return None
        raise
    except TypeError as e:
        # Handle the specific case where we didn't successfully dispatch.
        # Otherwise, the backtrace will have the clues you need.
        # NOTE(review): if the TypeError is raised before ``op_name`` is assigned
        # (e.g. by register_quantized_ops), this line raises UnboundLocalError;
        # it also raises IndexError when ``e.args`` is empty.
        e.args = ("{} \n(Occurred when translating {}).".format(e.args[0], op_name),)
        raise
def _run_symbolic_function_with_in_place(
        g, n, inputs, env, operator_export_type=OperatorExportTypes.ONNX):
    """
    Monkey-patched version of `_run_symbolic_function` function in `torch.onnx.utils`.
    The only change is that trailing '_' is no longer removed from `ns_op_name` for
    the dropout function.

    Concretely, a trailing '_' is stripped from the op name unless the name is
    ``dropout_`` or ``feature_dropout_``. Returns the emitted output value(s), or
    ``None`` when the node should be cloned as-is into the new graph. A TypeError
    from a failed dispatch is re-raised with the op being translated appended.
    """
    # Label for the TypeError handler below. It is bound before the ``try``
    # because register_quantized_ops runs before ``op_name`` is split out, and
    # a TypeError there used to surface as UnboundLocalError instead.
    op_name = n.kind()
    try:
        import torch
        from torch.onnx.symbolic_helper import (
            _export_onnx_opset_version as opset_version,
        )
        import torch.onnx.symbolic_registry as sym_registry

        sym_registry.register_version("", opset_version)

        if operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK:
            import torch.onnx.symbolic_caffe2

            torch.onnx.symbolic_caffe2.register_quantized_ops(
                "caffe2", opset_version)

        ns_op_name = n.kind()
        ns, op_name = ns_op_name.split("::")

        # Strip the in-place suffix, except for the dropout variants.
        if n.kind().endswith("_"):
            if op_name not in ["dropout_", "feature_dropout_"]:
                op_name = op_name[:-1]

        if ns == "onnx":
            # Use the original node directly
            return None

        elif ns == "aten":
            is_exportable_aten_op = sym_registry.is_registered_op(
                op_name, "", opset_version)
            is_onnx_aten_export = operator_export_type == OperatorExportTypes.ONNX_ATEN
            is_aten_fallback_export = (
                operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK)
            if is_onnx_aten_export or (not is_exportable_aten_op and is_aten_fallback_export):
                # Direct ATen export requested. Attribute names get a "_<kind>"
                # suffix built from the first character of n.kindOf(k).
                attrs = {
                    k + "_" + n.kindOf(k)[0]: n[k]
                    for k in n.attributeNames()
                }
                outputs = n.outputsSize()
                attrs["outputs"] = outputs
                return torch.onnx.utils._graph_at(g, op_name, *inputs, aten=True, **attrs)
            else:
                # Export it regularly
                attrs = {k: n[k] for k in n.attributeNames()}
                if not is_exportable_aten_op:
                    warnings.warn(
                        "ONNX export failed on ATen operator {} because "
                        "torch.onnx.symbolic_opset{}.{} does not exist".format(
                            op_name, opset_version, op_name))
                op_fn = sym_registry.get_registered_op(op_name, "", opset_version)
                return op_fn(g, *inputs, **attrs)

        elif ns == "prim":
            if op_name == "Constant" and not n.mustBeNone():
                if n.kindOf("value") == "t":
                    return g.op("Constant", value_t=n["value"])
                if n.kindOf("value") == "s":
                    return g.op("Constant", value_s=n["value"])
                elif n.kindOf("value") == "is":
                    # Integer-list constant: stacked into a tensor (or [] if empty).
                    value = (torch.stack([torch.tensor(v) for v in n["value"]])
                             if n["value"] else [])
                    return g.op("Constant", value_t=value)
                elif n.output().type().kind() == "DeviceObjType":
                    return None
                else:
                    raise RuntimeError(
                        "Unsupported prim::Constant kind: `{}`".format(
                            n.kindOf("value")))
            elif (n.mustBeNone() or op_name == "ListConstruct"
                  or op_name == "ListUnpack"):
                # None is not an ONNX operator; keep it as None
                # let the exporter handle finally eliminating these
                # For ListConstruct/ListUnpack, it will be erased in the
                # ONNX peephole pass
                return None
            elif op_name == "Loop" or op_name == "If":
                new_op_outputs = g.op(op_name, *inputs, outputs=n.outputsSize())
                new_node = (new_op_outputs[0].node() if n.outputsSize() > 1
                            else new_op_outputs.node())
                # Translate each sub-block of the control-flow node recursively.
                for b in n.blocks():
                    new_block = new_node.addBlock()
                    torch._C._jit_pass_onnx_block(b, new_block, operator_export_type, env)
                return new_op_outputs
            else:
                symbolic_name = "prim_" + op_name
                is_exportable = sym_registry.is_registered_op(
                    symbolic_name, "", opset_version)
                if not is_exportable:
                    warnings.warn(
                        "ONNX export failed on primitive operator {}".format(
                            op_name))
                symbolic_fn = sym_registry.get_registered_op(
                    symbolic_name, "", opset_version)
                attrs = {k: n[k] for k in n.attributeNames()}
                return symbolic_fn(g, *inputs, **attrs)

        elif ns == "quantized":
            # With ATen fallback, quantized symbolics come from the "caffe2" domain.
            domain = ""
            if operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK:
                domain = "caffe2"
            attrs = {k: n[k] for k in n.attributeNames()}

            if not sym_registry.is_registered_op(op_name, domain, opset_version):
                warnings.warn(
                    "ONNX export failed on quantized operator {}::{} because "
                    "torch.onnx.symbolic_opset{}.{} does not exist. ".format(
                        ns, op_name, opset_version, op_name))
            op_fn = sym_registry.get_registered_op(op_name, domain, opset_version)
            return op_fn(g, *inputs, **attrs)

        # custom ops
        elif sym_registry.is_registered_version(ns, opset_version):
            if not sym_registry.is_registered_op(op_name, ns, opset_version):
                warnings.warn(
                    "ONNX export failed on custom operator {}::{} because "
                    "torch.onnx.symbolic_opset{}.{} does not exist.".format(
                        ns, op_name, opset_version, op_name))
            symbolic_fn = sym_registry.get_registered_op(
                op_name, ns, opset_version)
            attrs = {k: n[k] for k in n.attributeNames()}
            return symbolic_fn(g, *inputs, **attrs)

        else:
            # Separating space after "version." added; the pieces used to run
            # together as "version.Otherwise".
            warnings.warn(
                "ONNX export failed on an operator with unrecognized namespace "
                "{}::{}; If you are trying to export a custom operator, "
                "make sure you registered it with the right domain and version. "
                "Otherwise please report a bug".format(ns, op_name))
            return None

    except TypeError as e:
        # Handle the specific case where we didn't successfully dispatch.
        # Otherwise, the backtrace will have the clues you need.
        detail = e.args[0] if e.args else ""
        e.args = ("{} (occurred when translating {})".format(
            detail, op_name), )
        raise