Code example #1
def from_tensorflow_frozen_model(frozen_file,
                                 output_nodes=None,
                                 preprocessor=None,
                                 **kwargs):
    """
    Converts a TensorFlow frozen graph to a UFF model.

    Args:
        frozen_file (str): The path to the frozen TensorFlow graph to convert.
        output_nodes (list(str)): The names of the outputs of the graph. If not provided, graphsurgeon is used to automatically deduce output nodes.
        output_filename (str): The UFF file to write.
        preprocessor (str): The path to a preprocessing script that will be executed before the converter. This script should define a ``preprocess`` function which accepts a graphsurgeon DynamicGraph and modifies it in place.
        write_preprocessed (bool): If set to True, the converter will write out the preprocessed graph as well as a TensorBoard visualization. Must be used in conjunction with output_filename.
        text (bool): If set to True, the converter will also write out a human readable UFF file. Must be used in conjunction with output_filename.
        quiet (bool): If set to True, suppresses informational messages. Errors may still be printed.
        list_nodes (bool): If set to True, the converter displays a list of all nodes present in the graph.
        debug_mode (bool): If set to True, the converter prints verbose debug messages.
        return_graph_info (bool): If set to True, this function returns the graph input and output nodes in addition to the serialized UFF graph.

    Returns:
        serialized UFF MetaGraph (str)

        OR, if return_graph_info is set to True,

        serialized UFF MetaGraph (str), graph inputs (list(tensorflow.NodeDef)), graph outputs (list(tensorflow.NodeDef))
    """
    # Avoid a shared mutable default argument ([] would be reused across calls);
    # normalize the None sentinel to a fresh empty list, preserving old behavior.
    if output_nodes is None:
        output_nodes = []
    graphdef = GraphDef()
    # Binary read via TF's GFile so non-local filesystems (e.g. GCS) also work.
    with tf.io.gfile.GFile(frozen_file, "rb") as frozen_pb:
        graphdef.ParseFromString(frozen_pb.read())
    return from_tensorflow(graphdef, output_nodes, preprocessor, **kwargs)
Code example #2
def main(unused_args):
  """Quantize the GraphDef at FLAGS.input and write it to FLAGS.output.

  Reads either a text or binary GraphDef proto (per FLAGS.text_proto),
  quantizes it while skipping FLAGS.skip nodes, and serializes the result.

  Raises:
    RuntimeError: If FLAGS.input or FLAGS.output is missing/empty.
  """
  # params
  in_path = FLAGS.input  # type: str
  in_is_text = FLAGS.text_proto  # type: bool
  out_path = FLAGS.output  # type: str
  skip = FLAGS.skip  # type: list
  output_nodes = FLAGS.output_node  # type: list

  # Validate params: truthiness covers both None and the empty string.
  if not in_path:
    raise RuntimeError("in_path must be provided")

  if not out_path:
    raise RuntimeError("output must be provided")

  # Read the graph, parsing as text proto or binary serialization.
  in_graph = GraphDef()
  if in_is_text:
    with open(in_path, "r") as fp:
      Parse(fp.read(), in_graph)
  else:
    with open(in_path, "rb") as fp:
      in_graph.ParseFromString(fp.read())

  # Quantize, skipping the node names listed in `skip`.
  quantized = quantize_graph_def(in_graph, set(skip), output_nodes)

  # Write the quantized graph as a binary proto.
  with open(out_path, "wb") as fp:
    fp.write(quantized.SerializeToString())
def load_graphdef_from_pb(pb_file):
    """Load a serialized GraphDef from a binary .pb file.

    Args:
        pb_file (str): Path to the binary protobuf file.

    Returns:
        The parsed GraphDef.

    Raises:
        IOError: If the file content cannot be parsed as a GraphDef.
    """
    graph = GraphDef()
    # Read first, parse after the file is closed — parsing needs no open handle.
    with open(pb_file, 'rb') as f:
        content = f.read()
    try:
        graph.ParseFromString(content)
    except Exception as e:
        # Chain the original parse error so the root cause stays in the traceback.
        raise IOError("Can't parse file {}: {}".format(pb_file, str(e))) from e
    return graph
Code example #4
def do_quantize_training_on_graphdef(input_graph, num_bits):
    """Rewrite a GraphDef for quantized training via the native TF helper.

    Args:
        input_graph: A GraphDef proto to rewrite.
        num_bits (int): Bit width passed through to the native helper.

    Returns:
        A new GraphDef parsed from the helper's serialized output.
    """
    # Function-scope imports keep these heavy TF-internal modules lazy.
    from tensorflow.core.framework.graph_pb2 import GraphDef
    from tensorflow.python.framework import errors
    # NOTE(review): raise_exception_on_not_ok_status is a TF-internal context
    # manager that converts the C-API `status` into a Python exception on exit;
    # it is deprecated in newer TF releases — confirm the pinned TF version.
    with errors.raise_exception_on_not_ok_status() as status:
        graph = GraphDef()
        # The helper operates on serialized bytes and returns serialized bytes.
        graph.ParseFromString(
            DoQuantizeTrainingOnGraphDefHelper(input_graph.SerializeToString(),
                                               num_bits, status))
    return graph
def merge_partitioned_graphs_from_pb(pb_files):
    """Parse each .pb file into a GraphDef and merge the partitioned graphs.

    Args:
        pb_files (iterable of str): Paths to binary GraphDef protobuf files.

    Returns:
        The merged graph produced by merge_partitioned_graphs().

    Raises:
        IOError: If any file cannot be parsed as a GraphDef.
    """
    graphs = []
    for pb_file in pb_files:
        graph = GraphDef()
        with open(pb_file, 'rb') as f:
            content = f.read()
        try:
            graph.ParseFromString(content)
        except Exception as e:
            # Chain the original error so the root cause survives in the traceback.
            raise IOError("Can't parse file {}: {}.".format(pb_file, str(e))) from e
        graphs.append(graph)

    return merge_partitioned_graphs(graphs)
Code example #6
    def __init__(self, pb_path):
        """Build this object's network tf.function from a frozen graph file.

        Args:
            pb_path: Path to the serialized GraphDef (.pb) file to import.
        """
        graph_def = GraphDef()
        with open(pb_path, "rb") as frozen:
            graph_def.ParseFromString(frozen.read())

        @tf.function
        def network_function(I0, I1, I2, I3, I4):
            # Feed the five placeholder tensors of the imported frozen graph.
            feed = {
                "Placeholder:0": I0,
                "Placeholder_1:0": I1,
                "Placeholder_2:0": I2,
                "Placeholder_3:0": I3,
                "Placeholder_4:0": I4,
            }
            # Import the graph with the inputs mapped in; fetch the two outputs.
            alpha, background = tf.graph_util.import_graph_def(
                graph_def, input_map=feed, return_elements=OUTPUT_NAMES)
            return alpha, background

        self._network = network_function