def _find_symbolic_in_registry(domain, op_name, opset_version, operator_export_type):
    import torch.onnx.symbolic_registry as sym_registry

    if not sym_registry.is_registered_op(op_name, domain, opset_version):
        if operator_export_type == OperatorExportTypes.ONNX_FALLTHROUGH:
            # Use the original node directly
            return None
    return sym_registry.get_registered_op(op_name, domain, opset_version)
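# A minimal sketch of how a caller might use the helper above during node
# dispatch. The dispatcher below is illustrative, not the actual exporter
# code path (which lives in torch.onnx.utils); `g`, `n`, and `inputs`
# follow the same conventions as `_run_symbolic_function` further down.
def _dispatch_node_example(g, n, inputs, opset_version, operator_export_type):
    ns, op_name = n.kind().split("::")
    symbolic_fn = _find_symbolic_in_registry(
        ns, op_name, opset_version, operator_export_type)
    if symbolic_fn is None:
        # ONNX_FALLTHROUGH: keep the original node as-is
        return None
    attrs = {k: n[k] for k in n.attributeNames()}
    return symbolic_fn(g, *inputs, **attrs)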
def symbolic_function(self, n: torch._C.Node) -> Optional[Callable]:
    ns, op = n.kind().split("::")
    if op.endswith("_"):  # For inplace op
        op = op[:-1]
    if ns == "prim" and op == "PythonOp":
        pyobj = n.pyobj()
        if issubclass(pyobj.__self__, torch.autograd.Function):
            pyobj = pyobj.__self__
        assert issubclass(pyobj, torch.autograd.Function)
        assert hasattr(pyobj, "symbolic"), f"symbolic method not supported in {pyobj}"
        # TODO(twata): Use repr(pyobj) in scope name or doc_string
        return cast(Callable, pyobj.symbolic)
    else:
        if ns == "prim":
            op = f"prim_{op}"
        if sym_reg.is_registered_op(op, "", self.opset_version):  # type: ignore[no-untyped-call]
            return cast(
                Callable,
                sym_reg.get_registered_op(op, "", self.opset_version),  # type: ignore[no-untyped-call]
            )
        else:
            return None
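# The prim::PythonOp branch above resolves to the `symbolic` staticmethod of
# a custom torch.autograd.Function (`sym_reg` is the symbolic registry;
# `cast`/`Callable`/`Optional` come from typing). A minimal sketch of such a
# Function — the class name and op choice are illustrative:
import torch


class MyRelu(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * (x > 0).type_as(grad_output)

    @staticmethod
    def symbolic(g, x):
        # Called at export time with the ONNX graph context; emits a single
        # ONNX Relu node in place of the traced PythonOp.
        return g.op("Relu", x)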
def _run_symbolic_function(g, n, inputs, env,
                           operator_export_type=OperatorExportTypes.ONNX):
    # NB: Returning None means the node gets cloned as is into
    # the new graph
    try:
        from torch.onnx.symbolic_helper import _export_onnx_opset_version as opset_version
        import torch.onnx.symbolic_registry as sym_registry

        sym_registry.register_version('', opset_version)

        # See Note [Export inplace]
        # TODO: I think this is not necessary anymore
        if n.kind().endswith('_'):
            ns_op_name = n.kind()[:-1]
        else:
            ns_op_name = n.kind()
        ns, op_name = ns_op_name.split("::")

        if ns == "onnx":
            # Use the original node directly
            return None
        elif ns == "aten":
            is_exportable_aten_op = sym_registry.is_registered_op(
                op_name, '', opset_version)
            is_onnx_aten_export = operator_export_type == OperatorExportTypes.ONNX_ATEN
            is_aten_fallback_export = (
                operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK)
            if is_onnx_aten_export or (not is_exportable_aten_op and is_aten_fallback_export):
                # Direct ATen export requested
                attrs = {k + "_" + n.kindOf(k)[0]: n[k] for k in n.attributeNames()}
                outputs = n.outputsSize()
                attrs["outputs"] = outputs
                return _graph_at(g, op_name, *inputs, aten=True, **attrs)
            else:
                # Export it regularly
                attrs = {k: n[k] for k in n.attributeNames()}
                if not is_exportable_aten_op:
                    warnings.warn(
                        "ONNX export failed on ATen operator {} because "
                        "torch.onnx.symbolic_opset{}.{} does not exist".format(
                            op_name, opset_version, op_name))
                op_fn = sym_registry.get_registered_op(op_name, '', opset_version)
                return op_fn(g, *inputs, **attrs)
        elif ns == "prim":
            if op_name == "Constant" and not n.mustBeNone():
                if n.kindOf("value") == "t":
                    return g.op("Constant", value_t=n["value"])
                elif n.kindOf("value") == "is":
                    value = (torch.stack([torch.tensor(v) for v in n["value"]])
                             if n["value"] else [])
                    return g.op("Constant", value_t=value)
                elif n.output().type().kind() == "DeviceObjType":
                    return None
                else:
                    raise RuntimeError(
                        "Unsupported prim::Constant kind: `{}`. Send a bug report.".format(
                            n.kindOf("value")))
            elif n.mustBeNone() or op_name == "ListConstruct" or op_name == "ListUnpack":
                # None is not an ONNX operator; keep it as None and let the
                # exporter handle finally eliminating these. For
                # ListConstruct/ListUnpack, it will be erased in the ONNX
                # peephole pass.
                return None
            elif op_name == 'Loop' or op_name == 'If':
                new_op_outputs = g.op(op_name, *inputs, outputs=n.outputsSize())
                new_node = (new_op_outputs[0].node()
                            if n.outputsSize() > 1 else new_op_outputs.node())
                for b in n.blocks():
                    new_block = new_node.addBlock()
                    torch._C._jit_pass_onnx_block(b, new_block,
                                                  operator_export_type, env)
                return new_op_outputs
            else:
                # TODO: we should lift prim's symbolic out
                symbolic_name = 'prim_' + op_name
                is_exportable = sym_registry.is_registered_op(
                    symbolic_name, '', opset_version)
                if not is_exportable:
                    warnings.warn(
                        "ONNX export failed on primitive operator {}; "
                        "please report a bug".format(op_name))
                symbolic_fn = sym_registry.get_registered_op(
                    symbolic_name, '', opset_version)
                attrs = {k: n[k] for k in n.attributeNames()}
                return symbolic_fn(g, *inputs, **attrs)
        # custom ops
        elif sym_registry.is_registered_version(ns, opset_version):
            if not sym_registry.is_registered_op(op_name, ns, opset_version):
                warnings.warn(
                    "ONNX export failed on custom operator {}::{} because "
                    "torch.onnx.symbolic_opset{}.{} does not exist. "
                    "Have you registered your symbolic function with "
                    "torch.onnx.register_custom_op_symbolic(symbolic_name, "
                    "symbolic_fn)?".format(ns, op_name, opset_version, op_name))
            symbolic_fn = sym_registry.get_registered_op(op_name, ns, opset_version)
            attrs = {k: n[k] for k in n.attributeNames()}
            return symbolic_fn(g, *inputs, **attrs)
        else:
            warnings.warn(
                "ONNX export failed on an operator with unrecognized namespace {}::{}; "
                "If you are trying to export a custom operator, make sure you registered "
                "it with the right domain and version. "
                "Otherwise please report a bug.".format(ns, op_name))
            return None
    except TypeError as e:
        # Handle the specific case where we didn't successfully dispatch.
        # Otherwise, the backtrace will have the clues you need.
        e.args = ("{} (occurred when translating {})".format(e.args[0], op_name),)
        raise
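# The custom-op branch above looks the op up in the per-namespace registry
# that torch.onnx.register_custom_op_symbolic populates, as the warning
# message suggests. A minimal registration sketch — the "custom_ops::gelu"
# name and the erf-based decomposition are illustrative:
import torch
import torch.onnx


def custom_gelu_symbolic(g, x):
    # Decompose onto plain ONNX ops: 0.5 * x * (1 + erf(x / sqrt(2)))
    half = g.op("Constant", value_t=torch.tensor(0.5))
    one = g.op("Constant", value_t=torch.tensor(1.0))
    sqrt2 = g.op("Constant", value_t=torch.tensor(1.4142135623730951))
    return g.op("Mul", g.op("Mul", half, x),
                g.op("Add", one, g.op("Erf", g.op("Div", x, sqrt2))))


torch.onnx.register_custom_op_symbolic("custom_ops::gelu", custom_gelu_symbolic, 9)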
def symbolic(g, *xargs):
    symb = get_registered_op(name, namespace, opset)
    return symb(g, *xargs)
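# The wrapper above closes over `name`, `namespace`, and `opset`, so it is
# presumably produced by a small factory along these lines (a sketch; the
# factory name is hypothetical):
def _make_symbolic(name, namespace, opset):
    def symbolic(g, *xargs):
        symb = get_registered_op(name, namespace, opset)
        return symb(g, *xargs)

    return symbolic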
def _run_symbolic_function_with_in_place(
        g, n, inputs, env, operator_export_type=OperatorExportTypes.ONNX):
    """
    Monkey-patched version of the `_run_symbolic_function` function in
    `torch.onnx.utils`. The only change is that the trailing '_' is no
    longer removed from the op name for the dropout functions.
    """
    try:
        import torch
        from torch.onnx.symbolic_helper import (
            _export_onnx_opset_version as opset_version,
        )
        import torch.onnx.symbolic_registry as sym_registry

        sym_registry.register_version("", opset_version)

        if operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK:
            import torch.onnx.symbolic_caffe2

            torch.onnx.symbolic_caffe2.register_quantized_ops("caffe2", opset_version)

        ns_op_name = n.kind()
        ns, op_name = ns_op_name.split("::")
        if n.kind().endswith("_"):
            if op_name not in ["dropout_", "feature_dropout_"]:
                op_name = op_name[:-1]

        if ns == "onnx":
            # Use the original node directly
            return None
        elif ns == "aten":
            is_exportable_aten_op = sym_registry.is_registered_op(
                op_name, "", opset_version)
            is_onnx_aten_export = operator_export_type == OperatorExportTypes.ONNX_ATEN
            is_aten_fallback_export = (
                operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK)
            if is_onnx_aten_export or (not is_exportable_aten_op and is_aten_fallback_export):
                # Direct ATen export requested
                attrs = {k + "_" + n.kindOf(k)[0]: n[k] for k in n.attributeNames()}
                outputs = n.outputsSize()
                attrs["outputs"] = outputs
                return torch.onnx.utils._graph_at(g, op_name, *inputs,
                                                  aten=True, **attrs)
            else:
                # Export it regularly
                attrs = {k: n[k] for k in n.attributeNames()}
                if not is_exportable_aten_op:
                    warnings.warn(
                        "ONNX export failed on ATen operator {} because "
                        "torch.onnx.symbolic_opset{}.{} does not exist".format(
                            op_name, opset_version, op_name))
                op_fn = sym_registry.get_registered_op(op_name, "", opset_version)
                return op_fn(g, *inputs, **attrs)
        elif ns == "prim":
            if op_name == "Constant" and not n.mustBeNone():
                if n.kindOf("value") == "t":
                    return g.op("Constant", value_t=n["value"])
                if n.kindOf("value") == "s":
                    return g.op("Constant", value_s=n["value"])
                elif n.kindOf("value") == "is":
                    value = (torch.stack([torch.tensor(v) for v in n["value"]])
                             if n["value"] else [])
                    return g.op("Constant", value_t=value)
                elif n.output().type().kind() == "DeviceObjType":
                    return None
                else:
                    raise RuntimeError(
                        "Unsupported prim::Constant kind: `{}`".format(
                            n.kindOf("value")))
            elif (n.mustBeNone() or op_name == "ListConstruct"
                  or op_name == "ListUnpack"):
                # None is not an ONNX operator; keep it as None and let the
                # exporter handle finally eliminating these. For
                # ListConstruct/ListUnpack, it will be erased in the ONNX
                # peephole pass.
                return None
            elif op_name == "Loop" or op_name == "If":
                new_op_outputs = g.op(op_name, *inputs, outputs=n.outputsSize())
                new_node = (new_op_outputs[0].node()
                            if n.outputsSize() > 1 else new_op_outputs.node())
                for b in n.blocks():
                    new_block = new_node.addBlock()
                    torch._C._jit_pass_onnx_block(b, new_block,
                                                  operator_export_type, env)
                return new_op_outputs
            else:
                symbolic_name = "prim_" + op_name
                is_exportable = sym_registry.is_registered_op(
                    symbolic_name, "", opset_version)
                if not is_exportable:
                    warnings.warn(
                        "ONNX export failed on primitive operator {}".format(op_name))
                symbolic_fn = sym_registry.get_registered_op(
                    symbolic_name, "", opset_version)
                attrs = {k: n[k] for k in n.attributeNames()}
                return symbolic_fn(g, *inputs, **attrs)
        elif ns == "quantized":
            domain = ""
            if operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK:
                domain = "caffe2"
            attrs = {k: n[k] for k in n.attributeNames()}
            if not sym_registry.is_registered_op(op_name, domain, opset_version):
                warnings.warn(
                    "ONNX export failed on quantized operator {}::{} because "
                    "torch.onnx.symbolic_opset{}.{} does not exist.".format(
                        ns, op_name, opset_version, op_name))
            op_fn = sym_registry.get_registered_op(op_name, domain, opset_version)
            return op_fn(g, *inputs, **attrs)
        # custom ops
        elif sym_registry.is_registered_version(ns, opset_version):
            if not sym_registry.is_registered_op(op_name, ns, opset_version):
                warnings.warn(
                    "ONNX export failed on custom operator {}::{} because "
                    "torch.onnx.symbolic_opset{}.{} does not exist.".format(
                        ns, op_name, opset_version, op_name))
            symbolic_fn = sym_registry.get_registered_op(op_name, ns, opset_version)
            attrs = {k: n[k] for k in n.attributeNames()}
            return symbolic_fn(g, *inputs, **attrs)
        else:
            warnings.warn(
                "ONNX export failed on an operator with unrecognized namespace "
                "{}::{}; If you are trying to export a custom operator, "
                "make sure you registered it with the right domain and version. "
                "Otherwise please report a bug.".format(ns, op_name))
            return None
    except TypeError as e:
        # Handle the specific case where we didn't successfully dispatch.
        # Otherwise, the backtrace will have the clues you need.
        e.args = ("{} (occurred when translating {})".format(e.args[0], op_name),)
        raise
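# A sketch of how the monkey patch would be installed, assuming a PyTorch
# version whose exporter still dispatches through
# torch.onnx.utils._run_symbolic_function:
import torch.onnx.utils

torch.onnx.utils._run_symbolic_function = _run_symbolic_function_with_in_place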
def symbolic(g, *xargs):
    symb = get_registered_op(name, namespace, opset)
    if adapter is not None:
        return symb(g, *xargs, **adapter_kwargs)
    return symb(g, *xargs)
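# This variant additionally closes over `adapter` and `adapter_kwargs`; a
# hypothetical factory that would build it (names are illustrative):
def _make_adapted_symbolic(name, namespace, opset, adapter=None, adapter_kwargs=None):
    adapter_kwargs = adapter_kwargs or {}

    def symbolic(g, *xargs):
        symb = get_registered_op(name, namespace, opset)
        if adapter is not None:
            return symb(g, *xargs, **adapter_kwargs)
        return symb(g, *xargs)

    return symbolic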