def tensorflow_graph_to_onnx_graph(cls,
                                   graph_def,
                                   output,
                                   opset=((defs.ONNX_DOMAIN,
                                           defs.onnx_opset_version()), ),
                                   name="graph",
                                   ignore_unimplemented=False):
    """Converts a Tensorflow Graph Proto to an ONNX graph

    This function converts a Tensorflow Graph proto to an equivalent
    representation of ONNX graph.

    Placeholder nodes become graph inputs, Const nodes become constants
    (registered both as const protos and as inputs), and every other node is
    converted by the frontend handler registered for its domain/op type. When
    no handler exists (or the handler returns None), a node proto is built
    directly from the TF node without checking.

    :param graph_def: Tensorflow Graph Proto object.
    :param output: List of Tensorflow NodeDef object specifying which nodes
      to be taken as outputs of the ONNX graph.
    :param opset: Opset, which should be ((str domain: int version number),).
      The domain "ai.onnx" is treated as an alias of defs.ONNX_DOMAIN.
    :param name: The name of the output ONNX Graph.
    :param ignore_unimplemented: Convert to ONNX model and ignore all the
      operators that are not currently supported by onnx-tensorflow.
      This is an experimental feature. By enabling this feature,
      the graph would not be guaranteed to match the ONNX specifications.

    :returns: The equivalent ONNX Graph Proto object.
    """
    onnx_graph = OnnxGraph(name)
    # NOTE(review): this assigns an attribute on the `exception` object, so
    # the setting presumably persists beyond this call — confirm.
    exception.IGNORE_UNIMPLEMENTED = ignore_unimplemented

    opset_dict = {}
    for domain, version in opset:
        # Accept "ai.onnx" as an alias for the default ONNX domain.
        if domain == "ai.onnx":
            domain = defs.ONNX_DOMAIN
        opset_dict[domain] = version
    handlers = get_all_frontend_handlers(opset_dict)

    node_tup = [(node.name, TensorflowNode(node)) for node in graph_def.node]
    # Build the name -> node lookup once rather than once per handled node
    # (the original rebuilt it inside the loop, which is O(n^2) overall).
    # NOTE(review): assumes handlers only read node_dict; previously each
    # handler received its own fresh copy.
    node_dict = dict(node_tup)

    for _, node in node_tup:
        if node.op_type == "Placeholder":
            onnx_graph.add_input_proto(node)
        elif node.op_type == "Const":
            onnx_graph.add_const(node)
            onnx_graph.add_const_proto(node)
            onnx_graph.add_input_proto(node)
        else:
            onnx_graph.add_value_info_proto(node)
            handler = handlers.get(node.domain, {}).get(node.op_type, None)
            node_proto = None
            if handler:
                node_proto = handler.handle(
                    node,
                    consts=onnx_graph.consts,
                    node_dict=node_dict,
                    data_type_cast_map=onnx_graph.data_type_cast_map)
            else:
                # Report the domain only when it is one we have no handlers
                # for at all; otherwise only the op type is missing.
                exception.OP_UNIMPLEMENTED_EXCEPT(
                    node.op_type,
                    domain=None if node.domain in handlers else node.domain)

            # Fall back to a direct, unchecked translation of the TF node.
            if node_proto is None:
                node_proto = FrontendHandler.make_node_from_tf_node(
                    node, op_type=node.op_type, should_check=False)
            onnx_graph.add_node_proto(node_proto)

    for o in output:
        output_node = TensorflowNode(o)
        onnx_graph.add_output_proto(output_node)

    return onnx_graph.make_graph_proto()
def tensorflow_graph_to_onnx_graph(cls,
                                   tf_graph,
                                   opset=((defs.ONNX_DOMAIN,
                                           defs.onnx_opset_version()), ),
                                   ignore_unimplemented=False):
    """Converts a TensorflowGraph to an ONNX graph

    This function converts a TensorflowGraph to an equivalent
    representation of ONNX graph.

    Placeholder nodes become graph inputs, Const nodes become constants
    (registered both as const protos and as inputs), training-only ops listed
    in ``training_ops_to_remove`` are dropped (with an info log), and each
    output of a QueueDequeueManyV2 node becomes an explicit graph input named
    "<node name>:<index>". Every other node is converted by the frontend
    handler registered for its domain/op type, falling back to an unchecked
    direct translation when no handler produces a node.

    :param tf_graph: TensorflowGraph object.
    :param opset: Opset, which should be ((str domain: int version number),).
      The domain "ai.onnx" is treated as an alias of defs.ONNX_DOMAIN.
    :param ignore_unimplemented: Convert to ONNX model and ignore all the
      operators that are not currently supported by onnx-tensorflow.
      This is an experimental feature. By enabling this feature,
      the graph would not be guaranteed to match the ONNX specifications.

    :returns: The equivalent ONNX Graph Proto object.
    """
    onnx_graph = OnnxGraph(tf_graph.graph_name)
    # NOTE(review): this assigns an attribute on the `exception` object, so
    # the setting presumably persists beyond this call — confirm.
    exception.IGNORE_UNIMPLEMENTED = ignore_unimplemented
    # Op types that are skipped entirely instead of being converted.
    training_ops_to_remove = ("RandomShuffleQueueV2",)

    opset_dict = {}
    for domain, version in opset:
        # Accept "ai.onnx" as an alias for the default ONNX domain.
        if domain == "ai.onnx":
            domain = defs.ONNX_DOMAIN
        opset_dict[domain] = version
    handlers = get_all_frontend_handlers(opset_dict)

    node_tup = [(n.name, n) for n in tf_graph.nodes]
    # Build the name -> node lookup once rather than once per handled node
    # (the original rebuilt it inside the loop, which is O(n^2) overall).
    # NOTE(review): assumes handlers only read node_dict; previously each
    # handler received its own fresh copy.
    node_dict = dict(node_tup)

    for _, node in node_tup:
        if node.op_type == "Placeholder":
            onnx_graph.add_input_proto(node)
        elif node.op_type == "Const":
            onnx_graph.add_const(node)
            onnx_graph.add_const_proto(node)
            onnx_graph.add_input_proto(node)
        elif node.op_type in training_ops_to_remove:
            logger.info("A training op with name %s type %s has been removed.",
                        node.name, node.op_type)
        elif node.op_type == "QueueDequeueManyV2":
            # Expose every dequeued component as its own graph input,
            # pairing each output shape with its component type.
            output_shapes = node.attr["_output_shapes"]
            component_types = node.attr["component_types"]
            for index, (shape, onnx_type) in enumerate(
                    zip(output_shapes, component_types)):
                onnx_graph.add_input_proto_explicit(
                    node.name + ":" + str(index), shape, onnx_dtype=onnx_type)
        else:
            onnx_graph.add_value_info_proto(node)
            handler = handlers.get(node.domain, {}).get(node.op_type, None)
            node_proto = None
            if handler:
                node_proto = handler.handle(
                    node,
                    consts=onnx_graph.consts,
                    node_dict=node_dict,
                    data_type_cast_map=onnx_graph.data_type_cast_map)
            else:
                # Report the domain only when it is one we have no handlers
                # for at all; otherwise only the op type is missing.
                exception.OP_UNIMPLEMENTED_EXCEPT(
                    node.op_type,
                    domain=None if node.domain in handlers else node.domain)

            # Fall back to a direct, unchecked translation of the TF node.
            if node_proto is None:
                node_proto = FrontendHandler.make_node_from_tf_node(
                    node, op_type=node.op_type, should_check=False)
            onnx_graph.add_node_proto(node_proto)

    for o in tf_graph.outputs:
        output_node = tf_graph.get_node_by_name(o)
        onnx_graph.add_output_proto(output_node)

    return onnx_graph.make_graph_proto()