def _onnx_node_to_tensorflow_op(cls, node, tensor_dict, handlers=None, opset=None, strict=True): """ Convert onnx node to tensorflow op. Args: node: Onnx node object. tensor_dict: Tensor dict of graph. opset: Opset version of the operator set. Default 0 means using latest version. strict: whether to enforce semantic equivalence between the original model and the converted tensorflow model, defaults to True (yes, enforce semantic equivalence). Changing to False is strongly discouraged. Returns: Tensorflow op """ handlers = handlers or cls._get_handlers(opset) handler = handlers[node.domain].get(node.op_type, None) if handler: return handler.handle(node, tensor_dict=tensor_dict, strict=strict) else: exception.OP_UNIMPLEMENTED_EXCEPT(node.op_type)
def handle(cls, node, **kwargs):
    """
    Dispatch ``node`` to the versioned handle method of this handler class.

    The method looked up is named ``version_<cls.SINCE_VERSION>``. Because of
    this naming scheme, the prefix ``version_`` is reserved in onnx-tensorflow;
    DON'T use it for other purposes.

    :param node: NodeProto for backend or TensorflowNode for frontend.
    :param kwargs: Other args, forwarded unchanged to ``cls.args_check`` and to
        the versioned method.
    :return: Whatever the versioned method returns. If no such method exists
        and ``exception.OP_UNIMPLEMENTED_EXCEPT`` returns instead of raising,
        None is returned.
    """
    method_name = "version_{}".format(cls.SINCE_VERSION)
    versioned_impl = getattr(cls, method_name, None)

    if not versioned_impl:
        # No implementation for this opset version: report it, then fall
        # through to None if the reporter does not raise.
        exception.OP_UNIMPLEMENTED_EXCEPT(node.op_type, cls.SINCE_VERSION)
        return None

    # Validate arguments before invoking the version-specific conversion.
    cls.args_check(node, **kwargs)
    return versioned_impl(node, **kwargs)
def tensorflow_graph_to_onnx_graph(cls,
                                   graph_def,
                                   output,
                                   opset=((defs.ONNX_DOMAIN,
                                           defs.onnx_opset_version()), ),
                                   name="graph",
                                   ignore_unimplemented=False):
    """Converts a Tensorflow Graph Proto to an ONNX graph

    This function converts a Tensorflow Graph proto to an equivalent
    representation of ONNX graph.

    :param graph_def: Tensorflow Graph Proto object.
    :param output: List of Tensorflow NodeDef object specifying which nodes
      to be taken as outputs of the ONNX graph.
    :param opset: Opset, which should be ((str domain: int version number),).
      The domain "ai.onnx" is normalized to ``defs.ONNX_DOMAIN``.
    :param name: The name of the output ONNX Graph.
    :param ignore_unimplemented: Convert to ONNX model and ignore all the
      operators that are not currently supported by onnx-tensorflow.
      This is an experimental feature. By enabling this feature,
      the graph would not be guaranteed to match the ONNX specifications.
      Note: this is stored in the module-level flag
      ``exception.IGNORE_UNIMPLEMENTED`` and is not reset afterwards.

    :returns: The equivalent ONNX Graph Proto object.
    """
    onnx_graph = OnnxGraph(name)
    exception.IGNORE_UNIMPLEMENTED = ignore_unimplemented

    opset_dict = {}
    for domain, version in opset:
        if domain == "ai.onnx":
            domain = defs.ONNX_DOMAIN
        opset_dict[domain] = version

    handlers = get_all_frontend_handlers(opset_dict)

    node_tup = [(node.name, TensorflowNode(node)) for node in graph_def.node]
    # Build the name -> node lookup once. It used to be rebuilt for every
    # handled node, which made conversion quadratic in the graph size.
    node_dict = dict(node_tup)

    for _, node in node_tup:
        if node.op_type == "Placeholder":
            onnx_graph.add_input_proto(node)
        elif node.op_type == "Const":
            onnx_graph.add_const(node)
            onnx_graph.add_const_proto(node)
            onnx_graph.add_input_proto(node)
        else:
            onnx_graph.add_value_info_proto(node)
            handler = handlers.get(node.domain, {}).get(node.op_type, None)
            node_proto = None
            if handler:
                node_proto = handler.handle(
                    node,
                    consts=onnx_graph.consts,
                    node_dict=node_dict,
                    data_type_cast_map=onnx_graph.data_type_cast_map)
            else:
                # Name the domain itself when no handlers exist for it.
                exception.OP_UNIMPLEMENTED_EXCEPT(
                    node.op_type,
                    domain=None if node.domain in handlers else node.domain)
            if node_proto is None:
                # No handler output (e.g. an unimplemented op that was not
                # raised): emit a generic node carrying the TF op type.
                node_proto = FrontendHandler.make_node_from_tf_node(
                    node, op_type=node.op_type, should_check=False)
            onnx_graph.add_node_proto(node_proto)

    for o in output:
        output_node = TensorflowNode(o)
        onnx_graph.add_output_proto(output_node)

    return onnx_graph.make_graph_proto()
def tensorflow_graph_to_onnx_graph(cls,
                                   tf_graph,
                                   opset=((defs.ONNX_DOMAIN,
                                           defs.onnx_opset_version()), ),
                                   ignore_unimplemented=False):
    """Converts a TensorflowGraph to an ONNX graph

    This function converts a TensorflowGraph to an equivalent
    representation of ONNX graph.

    :param tf_graph: TensorflowGraph object. Its ``graph_name``, ``nodes``,
      ``outputs`` and ``get_node_by_name`` members are used.
    :param opset: Opset, which should be ((str domain: int version number),).
      The domain "ai.onnx" is normalized to ``defs.ONNX_DOMAIN``.
    :param ignore_unimplemented: Convert to ONNX model and ignore all the
      operators that are not currently supported by onnx-tensorflow.
      This is an experimental feature. By enabling this feature,
      the graph would not be guaranteed to match the ONNX specifications.
      Note: this is stored in the module-level flag
      ``exception.IGNORE_UNIMPLEMENTED`` and is not reset afterwards.

    :returns: The equivalent ONNX Graph Proto object.
    """
    onnx_graph = OnnxGraph(tf_graph.graph_name)
    exception.IGNORE_UNIMPLEMENTED = ignore_unimplemented

    # Op types that are dropped from the ONNX graph (only logged, never
    # converted).
    training_ops_to_remove = ("RandomShuffleQueueV2", )

    opset_dict = {}
    for domain, version in opset:
        if domain == "ai.onnx":
            domain = defs.ONNX_DOMAIN
        opset_dict[domain] = version

    handlers = get_all_frontend_handlers(opset_dict)

    node_tup = [(n.name, n) for n in tf_graph.nodes]
    # Build the name -> node lookup once. It used to be rebuilt for every
    # handled node, which made conversion quadratic in the graph size.
    node_dict = dict(node_tup)

    for _, node in node_tup:
        if node.op_type == "Placeholder":
            onnx_graph.add_input_proto(node)
        elif node.op_type == "Const":
            onnx_graph.add_const(node)
            onnx_graph.add_const_proto(node)
            onnx_graph.add_input_proto(node)
        elif node.op_type in training_ops_to_remove:
            logger.info(
                "A training op with name {} type {} has been removed.".format(
                    node.name, node.op_type))
        elif node.op_type == "QueueDequeueManyV2":
            # Each dequeued component becomes an explicit graph input named
            # "<node name>:<output index>".
            output_shapes = node.attr["_output_shapes"]
            component_types = node.attr["component_types"]
            for index, (shape, onnx_type) in enumerate(
                    zip(output_shapes, component_types)):
                onnx_graph.add_input_proto_explicit(
                    node.name + ":" + str(index), shape, onnx_dtype=onnx_type)
        else:
            onnx_graph.add_value_info_proto(node)
            handler = handlers.get(node.domain, {}).get(node.op_type, None)
            node_proto = None
            if handler:
                node_proto = handler.handle(
                    node,
                    consts=onnx_graph.consts,
                    node_dict=node_dict,
                    data_type_cast_map=onnx_graph.data_type_cast_map)
            else:
                # Name the domain itself when no handlers exist for it.
                exception.OP_UNIMPLEMENTED_EXCEPT(
                    node.op_type,
                    domain=None if node.domain in handlers else node.domain)
            if node_proto is None:
                # No handler output (e.g. an unimplemented op that was not
                # raised): emit a generic node carrying the TF op type.
                node_proto = FrontendHandler.make_node_from_tf_node(
                    node, op_type=node.op_type, should_check=False)
            onnx_graph.add_node_proto(node_proto)

    for o in tf_graph.outputs:
        output_node = tf_graph.get_node_by_name(o)
        onnx_graph.add_output_proto(output_node)

    return onnx_graph.make_graph_proto()