示例#1
0
def convert(infile, outfile, convert_to, **kwargs):
  """Convert a model file between ONNX and TensorFlow protobuf formats.

  Args:
    infile: Input path. For ``convert_to == "tf"`` this is loaded with
      ``onnx.load``. For ``convert_to == "onnx"`` its extension must be
      ``.pb`` (a serialized TensorFlow ``GraphDef``) or ``.ckpt``. For a
      ``.ckpt`` input only the directory is used: the most recent checkpoint
      recorded there is restored together with its ``.meta`` graph.
    outfile: Output path.
    convert_to: Target format, either ``"tf"`` or ``"onnx"``.
    **kwargs: Other args, forwarded to ``backend.prepare`` (tf target) or to
      ``frontend.tensorflow_graph_to_onnx_model`` (onnx target).

  Returns:
    None.

  Raises:
    ValueError: If ``convert_to`` is neither ``"tf"`` nor ``"onnx"``, if the
      input extension is unsupported, or if no checkpoint is found in the
      directory of a ``.ckpt`` input.
  """
  # Reject unknown targets up front. Otherwise they fall through both branches
  # and the function would still log success without producing any output.
  if convert_to not in ("tf", "onnx"):
    raise ValueError(
        "Conversion target is not supported. Should be tf or onnx, but got {}".
        format(convert_to))

  if convert_to == "tf":
    logger.info("Start converting onnx pb to tf pb:")
    onnx_model = onnx.load(infile)
    tf_rep = backend.prepare(onnx_model, **kwargs)
    tf_rep.export_graph(outfile)
  else:
    ext = os.path.splitext(infile)[1]
    logger.info("Start converting tf pb to onnx pb:")
    if ext == ".pb":
      with open(infile, "rb") as f:
        graph_def = graph_pb2.GraphDef()
        graph_def.ParseFromString(f.read())
    elif ext == ".ckpt":
      ckpt_dir = os.path.dirname(infile)
      latest_ckpt = tf.train.latest_checkpoint(ckpt_dir)
      # latest_checkpoint returns None when nothing is found; fail clearly
      # instead of with a TypeError on `None + ".meta"`.
      if latest_ckpt is None:
        raise ValueError(
            "No checkpoint found in directory {!r}".format(ckpt_dir))
      saver = tf.train.import_meta_graph(latest_ckpt + ".meta")
      with tf.Session() as sess:
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ])
        saver.restore(sess, latest_ckpt)
        output_node_names = get_output_node_names(sess.graph.as_graph_def())
        # Fold the restored variable values into constants so the resulting
        # GraphDef is self-contained.
        graph_def = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(add_shapes=True), output_node_names)
    else:
      raise ValueError(
          "Input file is not supported. Should be .pb or .ckpt, but get {}".
          format(ext))
    onnx_model = frontend.tensorflow_graph_to_onnx_model(
        graph_def, get_output_node_names(graph_def), **kwargs)
    onnx.save(onnx_model, outfile)
  logger.info("Converting completes successfully.")
示例#2
0
def load_graph_from_ckpt(ckpt_file):
  """Restore the latest checkpoint next to ``ckpt_file`` and freeze its graph.

  Only the directory of ``ckpt_file`` is used. The most recent checkpoint
  recorded in that directory is restored together with its ``.meta`` graph.
  Its variables are then converted to constants for the output nodes
  reported by ``get_output_node_names``.

  :param ckpt_file: path of a checkpoint file; its directory is searched
  :return: a ``GraphDef`` (built with ``add_shapes=True``) whose variables
    have been replaced by constants
  :raises ValueError: if no checkpoint is found in that directory
  """
  ckpt_dir = os.path.dirname(ckpt_file)
  latest_ckpt = tf.train.latest_checkpoint(ckpt_dir)
  # latest_checkpoint returns None when nothing is found; fail clearly instead
  # of with a TypeError on `None + ".meta"`.
  if latest_ckpt is None:
    raise ValueError("No checkpoint found in directory {!r}".format(ckpt_dir))
  saver = tf.train.import_meta_graph(latest_ckpt + ".meta")
  with tf.Session() as sess:
    sess.run(
        [tf.global_variables_initializer(),
         tf.local_variables_initializer()])
    saver.restore(sess, latest_ckpt)
    output_node_names = get_output_node_names(sess.graph.as_graph_def())
    graph_def = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(add_shapes=True), output_node_names)
  return graph_def
示例#3
0
from tensorflow.core.framework import graph_pb2

from onnx_tf.common import get_output_node_names
from onnx_tf.frontend import tensorflow_graph_to_onnx_model

# Read the serialized TensorFlow GraphDef from disk.
with open("input_path", "rb") as f:
    graph_def = graph_pb2.GraphDef()
    graph_def.ParseFromString(f.read())

# Find the graph's output nodes, then translate the TF graph into ONNX.
output = get_output_node_names(graph_def)
model = tensorflow_graph_to_onnx_model(graph_def, output)

# Store the ONNX model as its serialized protobuf bytes.
with open("output_path", 'wb') as f:
    f.write(model.SerializeToString())
示例#4
0
def convert(infile, outfile, convert_to, graph, **kwargs):
  """Convert a model file between ONNX and TensorFlow protobuf formats.

  Args:
    infile: Input path. For ``convert_to == "tf"`` this is loaded with
      ``onnx.load``. For ``convert_to == "onnx"`` its extension must be
      ``.pb`` (a serialized TensorFlow ``GraphDef``) or ``.ckpt``. For a
      ``.ckpt`` input only the directory is used: the most recent checkpoint
      recorded there is restored together with its ``.meta`` graph.
    outfile: Output path.
    convert_to: Target format, either ``"tf"`` or ``"onnx"``.
    graph: Inference graph. For a ``.ckpt`` input, a truthy value is passed
      to ``freeze_graph`` as ``input_graph`` instead of the graph written
      from the restored session (it is read as binary). Ignored otherwise.
    **kwargs: Other args, forwarded to ``backend.prepare`` (tf target) or
      ``frontend.tensorflow_graph_to_onnx_model`` (onnx target). For a
      ``.ckpt`` input, ``kwargs["output"]`` names the nodes to freeze. When
      it is missing or empty it is deduced with ``get_output_node_names``.

  Returns:
    None.

  Raises:
    ValueError: If ``convert_to`` is neither ``"tf"`` nor ``"onnx"``, if the
      input extension is unsupported, or if no checkpoint is found in the
      directory of a ``.ckpt`` input.
  """
  # Reject unknown targets up front. Otherwise they fall through both branches
  # and the function would still log success without producing any output.
  if convert_to not in ("tf", "onnx"):
    raise ValueError(
        "Conversion target is not supported. Should be tf or onnx, but got {}".
        format(convert_to))

  if convert_to == "tf":
    logger.info("Start converting onnx pb to tf pb:")
    onnx_model = onnx.load(infile)
    tf_rep = backend.prepare(onnx_model, **kwargs)
    tf_rep.export_graph(outfile)
  else:
    ext = os.path.splitext(infile)[1]
    logger.info("Start converting tf pb to onnx pb:")
    if ext == ".pb":
      with open(infile, "rb") as f:
        graph_def = graph_pb2.GraphDef()
        graph_def.ParseFromString(f.read())
    elif ext == ".ckpt":
      ckpt_dir = os.path.dirname(infile)
      latest_ckpt = tf.train.latest_checkpoint(ckpt_dir)
      # latest_checkpoint returns None when nothing is found; fail clearly
      # instead of with a TypeError on `None + ".meta"`.
      if latest_ckpt is None:
        raise ValueError(
            "No checkpoint found in directory {!r}".format(ckpt_dir))
      saver = tf.train.import_meta_graph(latest_ckpt + ".meta")
      workdir = "onnx-tf_workdir_{}".format(get_unique_suffix())
      input_model_path = os.path.join(workdir, "input_model.pb")
      frozen_model_path = os.path.join(workdir, "frozen_model.pb")
      try:
        with tf.Session() as sess:
          sess.run([
              tf.global_variables_initializer(),
              tf.local_variables_initializer()
          ])
          saver.restore(sess, latest_ckpt)
          # Take users' hint or deduce output node automatically.
          kwargs["output"] = kwargs.get("output", None) or get_output_node_names(
              sess.graph.as_graph_def())

          # Save the graph to disk for freezing.
          tf.train.write_graph(
              sess.graph.as_graph_def(add_shapes=True),
              workdir,
              "input_model.pb",
              as_text=False)

        # Freeze graph:
        freeze_graph.freeze_graph(
            input_graph=graph or input_model_path,
            input_saver="",
            input_binary=True,
            input_checkpoint=latest_ckpt,
            output_node_names=",".join(kwargs["output"]),
            restore_op_name="",
            filename_tensor_name="",
            output_graph=frozen_model_path,
            clear_devices=True,
            initializer_nodes="")

        # Load back the frozen graph.
        with open(frozen_model_path, "rb") as f:
          graph_def = graph_pb2.GraphDef()
          graph_def.ParseFromString(f.read())
      finally:
        # Always remove the scratch directory, including when freezing fails.
        # ignore_errors covers a failure before write_graph created it.
        shutil.rmtree(workdir, ignore_errors=True)
    else:
      raise ValueError(
          "Input file is not supported. Should be .pb or .ckpt, but get {}".
          format(ext))
    onnx_model = frontend.tensorflow_graph_to_onnx_model(graph_def, **kwargs)
    onnx.save(onnx_model, outfile)
  logger.info("Converting completes successfully.")