def convert(infile, outfile, convert_to, **kwargs):
  """Convert pb between ONNX and TensorFlow formats.

  Args:
    infile: Input path.
    outfile: Output path.
    convert_to: Format converted to ("tf" or "onnx").
    **kwargs: Other args for converting.

  Returns:
    None.

  Raises:
    ValueError: If `convert_to` is "onnx" and `infile` is neither a
      .pb nor a .ckpt file.
  """
  if convert_to == "tf":
    logger.info("Start converting onnx pb to tf pb:")
    onnx_model = onnx.load(infile)
    tf_rep = backend.prepare(onnx_model, **kwargs)
    tf_rep.export_graph(outfile)
  elif convert_to == "onnx":
    ext = os.path.splitext(infile)[1]
    logger.info("Start converting tf pb to onnx pb:")
    if ext == ".pb":
      # Frozen GraphDef: parse it straight from disk, then deduce outputs.
      with open(infile, "rb") as f:
        graph_def = graph_pb2.GraphDef()
        graph_def.ParseFromString(f.read())
      output_node_names = get_output_node_names(graph_def)
    elif ext == ".ckpt":
      # Checkpoint: restore the latest checkpoint in the same directory,
      # then freeze variables into constants.
      latest_ckpt = tf.train.latest_checkpoint(os.path.dirname(infile))
      saver = tf.train.import_meta_graph(latest_ckpt + ".meta")
      with tf.Session() as sess:
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ])
        saver.restore(sess, latest_ckpt)
        # Compute output node names once and reuse them below — the
        # original recomputed them for the frontend call.
        output_node_names = get_output_node_names(sess.graph.as_graph_def())
        graph_def = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(add_shapes=True), output_node_names)
    else:
      raise ValueError(
          "Input file is not supported. Should be .pb or .ckpt, but get {}".
          format(ext))
    onnx_model = frontend.tensorflow_graph_to_onnx_model(
        graph_def, output_node_names, **kwargs)
    onnx.save(onnx_model, outfile)
  logger.info("Converting completes successfully.")
def load_graph_from_ckpt(ckpt_file):
  """ Load graph from a checkpoint
  :param ckpt_file: the checkpoint file
  :return: graph of operations extracted from the checkpoint
  """
  ckpt_dir = os.path.dirname(ckpt_file)
  newest_ckpt = tf.train.latest_checkpoint(ckpt_dir)
  meta_saver = tf.train.import_meta_graph(newest_ckpt + ".meta")
  with tf.Session() as session:
    # Initialize all variables before restoring checkpoint values.
    init_ops = [
        tf.global_variables_initializer(),
        tf.local_variables_initializer()
    ]
    session.run(init_ops)
    meta_saver.restore(session, newest_ckpt)
    # Deduce the output nodes, then bake variables into constants.
    outputs = get_output_node_names(session.graph.as_graph_def())
    frozen_def = tf.graph_util.convert_variables_to_constants(
        session, session.graph.as_graph_def(add_shapes=True), outputs)
    return frozen_def
from tensorflow.core.framework import graph_pb2
from onnx_tf.common import get_output_node_names
from onnx_tf.frontend import tensorflow_graph_to_onnx_model

# Load the serialized TensorFlow GraphDef from disk.
graph_def = graph_pb2.GraphDef()
with open("input_path", "rb") as f:
  graph_def.ParseFromString(f.read())

# Deduce the graph's output node names.
output_nodes = get_output_node_names(graph_def)

# Convert the TensorFlow graph into an ONNX model proto.
model = tensorflow_graph_to_onnx_model(graph_def, output_nodes)

# Serialize the ONNX model to disk.
with open("output_path", "wb") as f:
  f.write(model.SerializeToString())
def convert(infile, outfile, convert_to, graph, **kwargs):
  """Convert pb.

  Args:
    infile: Input path.
    outfile: Output path.
    convert_to: Format converted to ("tf" or "onnx").
    graph: Inference graph; when given, it is passed to the freezing
      step instead of the graph written from the restored session.
    **kwargs: Other args for converting. For tf -> onnx, "output" may
      name the output nodes; otherwise they are deduced from the graph.

  Returns:
    None.

  Raises:
    ValueError: If `convert_to` is "onnx" and `infile` is neither a
      .pb nor a .ckpt file.
  """
  if convert_to == "tf":
    logger.info("Start converting onnx pb to tf pb:")
    onnx_model = onnx.load(infile)
    tf_rep = backend.prepare(onnx_model, **kwargs)
    tf_rep.export_graph(outfile)
  elif convert_to == "onnx":
    ext = os.path.splitext(infile)[1]
    logger.info("Start converting tf pb to onnx pb:")
    if ext == ".pb":
      # Frozen GraphDef: parse it straight from disk.
      with open(infile, "rb") as f:
        graph_def = graph_pb2.GraphDef()
        graph_def.ParseFromString(f.read())
    elif ext == ".ckpt":
      latest_ckpt = tf.train.latest_checkpoint(os.path.dirname(infile))
      saver = tf.train.import_meta_graph(latest_ckpt + ".meta")
      # Unique scratch directory so concurrent conversions don't collide.
      temp_file_suffix = get_unique_suffix()
      workdir = 'onnx-tf_workdir_{}'.format(temp_file_suffix)
      input_model_path = os.path.join(workdir, "input_model.pb")
      frozen_model_path = os.path.join(workdir, "frozen_model.pb")
      try:
        with tf.Session() as sess:
          sess.run([
              tf.global_variables_initializer(),
              tf.local_variables_initializer()
          ])
          saver.restore(sess, latest_ckpt)
          # Take users' hint or deduce output node automatically.
          # NOTE(review): assumes kwargs["output"] is a sequence of node
          # names — confirm callers never pass one comma-separated string.
          kwargs["output"] = kwargs.get(
              "output", None) or get_output_node_names(
                  sess.graph.as_graph_def())
          # Save the graph to disk for freezing.
          tf.train.write_graph(
              sess.graph.as_graph_def(add_shapes=True),
              workdir,
              "input_model.pb",
              as_text=False)
          # Freeze graph:
          freeze_graph.freeze_graph(
              input_graph=graph or input_model_path,
              input_saver="",
              input_binary=True,
              input_checkpoint=latest_ckpt,
              output_node_names=",".join(kwargs["output"]),
              restore_op_name="",
              filename_tensor_name="",
              output_graph=frozen_model_path,
              clear_devices=True,
              initializer_nodes="")
          # Load back the frozen graph.
          with open(frozen_model_path, "rb") as f:
            graph_def = graph_pb2.GraphDef()
            graph_def.ParseFromString(f.read())
      finally:
        # Always remove the work directory, even when restoring or
        # freezing raises (the original leaked it on failure);
        # ignore_errors covers the case where it was never created.
        shutil.rmtree(workdir, ignore_errors=True)
    else:
      raise ValueError(
          "Input file is not supported. Should be .pb or .ckpt, but get {}".
          format(ext))
    onnx_model = frontend.tensorflow_graph_to_onnx_model(graph_def, **kwargs)
    onnx.save(onnx_model, outfile)
  logger.info("Converting completes successfully.")