Пример #1
0
def graph_defs_equal(graph_def_1: graph_pb2.GraphDef,
                     graph_def_2: graph_pb2.GraphDef,
                     treat_nan_as_equal: bool = False) -> bool:
    """Returns True iff the graph def arguments are structurally equivalent.

    Equivalence here means the NodeDefs of the main graph body and of the
    function library are identical, and the functions in the function library
    are equal as sets. The comparison itself is delegated to the native
    ``_proto_comparators.EqualsGraphDef`` on the serialized protos.

    Args:
      graph_def_1: Instance of `graph_pb2.GraphDef` to compare.
      graph_def_2: Instance of `graph_pb2.GraphDef` to compare.
      treat_nan_as_equal: Whether NaN floating-point values compare equal.
        Needed for any equivalence relation over GraphDefs to be symmetric.

    Returns:
      Boolean indicating structural equivalence as described above.

    Raises:
      TypeError: If either argument is not a `graph_pb2.GraphDef`.
    """
    # Validate both arguments in order, each with its own message.
    type_checks = (
        (graph_def_1, "graph_def_1 must be a graph_pb2.GraphDef proto."),
        (graph_def_2, "graph_def_2 must be a graph_pb2.GraphDef proto."),
    )
    for candidate, error_message in type_checks:
        if not isinstance(candidate, graph_pb2.GraphDef):
            raise TypeError(error_message)

    comparison_options = _proto_comparators.ProtoComparisonOptions(
        treat_nan_as_equal)
    serialized_first = graph_def_1.SerializeToString()
    serialized_second = graph_def_2.SerializeToString()
    return _proto_comparators.EqualsGraphDef(serialized_first,
                                             serialized_second,
                                             comparison_options)
Пример #2
0
def main(unused_args):
    """Quantize the GraphDef named by --input and write it to --output.

    The input is parsed as text format when --text_proto is set, otherwise as
    binary. Nodes listed in --skip (turned into a set) and the --output_node
    list are forwarded to ``quantize_graph_def``. The result is written as a
    binary serialized proto.

    Raises:
      RuntimeError: if --input or --output is missing or empty.
    """
    # Read every flag up front.
    source_path = FLAGS.input  # type: str
    parse_as_text = FLAGS.text_proto  # type: bool
    destination_path = FLAGS.output  # type: str
    skipped_nodes = FLAGS.skip  # type: list
    requested_outputs = FLAGS.output_node  # type: list

    if not source_path:
        raise RuntimeError("in_path must be provided")
    if not destination_path:
        raise RuntimeError("output must be provided")

    source_graph = GraphDef()
    if parse_as_text:
        with open(source_path, "r") as handle:
            Parse(handle.read(), source_graph)
    else:
        with open(source_path, "rb") as handle:
            source_graph.ParseFromString(handle.read())

    quantized_graph = quantize_graph_def(source_graph, set(skipped_nodes),
                                         requested_outputs)

    with open(destination_path, "wb") as handle:
        handle.write(quantized_graph.SerializeToString())
def from_tensorflow_frozen_model(frozen_file,
                                 output_nodes=None,
                                 preprocessor=None,
                                 **kwargs):
    """
    Converts a TensorFlow frozen graph to a UFF model.

    Reads ``frozen_file`` as a binary GraphDef and forwards it, together with
    ``output_nodes``, ``preprocessor`` and any extra keyword arguments, to
    ``from_tensorflow``.

    Args:
        frozen_file (str): The path to the frozen TensorFlow graph to convert.
        output_nodes (list(str)): The names of the outputs of the graph. If not provided (None or empty), graphsurgeon is used to automatically deduce output nodes.
        output_filename (str): The UFF file to write.
        preprocessor (str): The path to a preprocessing script that will be executed before the converter. This script should define a ``preprocess`` function which accepts a graphsurgeon DynamicGraph and modifies it in place.
        write_preprocessed (bool): If set to True, the converter will write out the preprocessed graph as well as a TensorBoard visualization. Must be used in conjunction with output_filename.
        text (bool): If set to True, the converter will also write out a human readable UFF file. Must be used in conjunction with output_filename.
        quiet (bool): If set to True, suppresses informational messages. Errors may still be printed.
        list_nodes (bool): If set to True, the converter displays a list of all nodes present in the graph.
        debug_mode (bool): If set to True, the converter prints verbose debug messages.
        return_graph_info (bool): If set to True, this function returns the graph input and output nodes in addition to the serialized UFF graph.

    Returns:
        serialized UFF MetaGraph (str)

        OR, if return_graph_info is set to True,

        serialized UFF MetaGraph (str), graph inputs (list(tensorflow.NodeDef)), graph outputs (list(tensorflow.NodeDef))
    """
    # A fresh list per call: a shared `[]` default would leak any mutation made
    # downstream into every later call that relies on the default.
    if output_nodes is None:
        output_nodes = []
    graphdef = GraphDef()
    with tf.io.gfile.GFile(frozen_file, "rb") as frozen_pb:
        graphdef.ParseFromString(frozen_pb.read())
    return from_tensorflow(graphdef, output_nodes, preprocessor, **kwargs)
def load_graphdef_from_pb(pb_file):
    """Load a binary-serialized GraphDef from ``pb_file``.

    Args:
        pb_file: path of the binary protobuf file to read.

    Returns:
        The parsed ``GraphDef``.

    Raises:
        IOError: if the file contents cannot be parsed as a GraphDef. The
            underlying parse error is chained as ``__cause__``.
    """
    # Read first so the file is closed before (possibly slow) parsing.
    with open(pb_file, 'rb') as f:
        content = f.read()
    graph = GraphDef()
    try:
        graph.ParseFromString(content)
    except Exception as e:
        raise IOError("Can't parse file {}: {}".format(pb_file, str(e))) from e
    return graph
def do_quantize_training_on_graphdef(input_graph, num_bits):
    """Return a quantized-training rewrite of ``input_graph``.

    ``input_graph`` is serialized and handed, with ``num_bits`` and a status
    object, to the native ``DoQuantizeTrainingOnGraphDefHelper``. Its
    serialized output is parsed into a new ``GraphDef``.

    Args:
        input_graph: the GraphDef to rewrite.
        num_bits: bit width forwarded to the native helper.

    Returns:
        A new ``GraphDef`` parsed from the helper's output.
    """
    from tensorflow.core.framework.graph_pb2 import GraphDef
    from tensorflow.python.framework import errors

    # Everything stays inside the status context so that a non-OK status set
    # by the helper is surfaced when the context exits, as before.
    with errors.raise_exception_on_not_ok_status() as status:
        serialized_input = input_graph.SerializeToString()
        serialized_output = DoQuantizeTrainingOnGraphDefHelper(
            serialized_input, num_bits, status)
        rewritten_graph = GraphDef()
        rewritten_graph.ParseFromString(serialized_output)
    return rewritten_graph
def merge_partitioned_graphs_from_pb(pb_files):
    """Load binary GraphDef partitions and merge them.

    Args:
        pb_files: iterable of paths to binary-serialized GraphDef files, one
            per partition.

    Returns:
        The GraphDef produced by ``merge_partitioned_graphs``.

    Raises:
        IOError: if any file cannot be parsed; the parse error is chained as
            ``__cause__``.
    """
    graphs = []
    for pb_file in pb_files:
        with open(pb_file, 'rb') as f:
            content = f.read()
        graph = GraphDef()
        # Only the parse can fail here; keep the try body to just that call.
        try:
            graph.ParseFromString(content)
        except Exception as e:
            raise IOError("Can't parse file {}: {}.".format(pb_file, str(e))) from e
        graphs.append(graph)

    return merge_partitioned_graphs(graphs)
Пример #7
0
 def deserialize(cls, path):
     """Rebuild an instance from a serialized ``GraphItem`` proto file.

     The file at ``path`` is read as a binary ``graphitem_pb2.GraphItem``.
     Its packed ``graph_def`` is unpacked into a ``GraphDef`` and passed to
     ``cls(graph_def=...)``. The gradient/target pairs and the variable,
     saver and table-initializer info are then restored onto the new object.

     Args:
         path: filesystem path of the serialized GraphItem.

     Returns:
         The reconstructed instance of ``cls``.
     """
     with open(path, "rb") as serialized_file:
         payload = serialized_file.read()
     item_proto = graphitem_pb2.GraphItem()
     item_proto.ParseFromString(payload)

     # Main graph.
     graph_def = GraphDef()
     item_proto.graph_def.Unpack(graph_def)
     item = cls(graph_def=graph_def)

     # Gradient/target pairs: keys are stored as ';'-joined strings. A single
     # component is restored as a plain string, several as a tuple.
     for flat_key, target in item_proto.grad_target_pairs.items():
         components = flat_key.split(';')
         restored_key = tuple(components) if len(components) > 1 else components[0]
         item._grad_target_pairs[restored_key] = target

     # Variables and savers are packed messages. Unpack each one and register
     # it individually with replace=False.
     for packed_variable in item_proto.info.variables:
         variable_def = VariableDef()
         packed_variable.Unpack(variable_def)
         item.info.update_variables([variable_def], replace=False)
     for packed_saver in item_proto.info.savers:
         saver_def = SaverDef()
         packed_saver.Unpack(saver_def)
         item.info.update_savers([saver_def], replace=False)
     item.info.update_table_initializers(item_proto.info.table_initializers)
     return item
Пример #8
0
def _replace_graph_node_names(graph, mapping):
    """Return a new GraphDef whose node names are rewritten via ``mapping``.

    ``graph`` is rendered to protobuf text format. Every span captured by the
    ``name`` group of ``_node_name_regex_tpl`` (filled with an alternation of
    the mapping keys) is replaced by ``mapping[name]``. The edited text is
    then parsed into a fresh ``GraphDef``. ``graph`` itself is not modified.

    Args:
        graph: the graph proto whose node names should be renamed.
        mapping: dict of old node name -> new node name.

    Returns:
        A new ``GraphDef`` with the renamed nodes.
    """
    # Escape names so regex metacharacters in them (TF node names may contain
    # '.') match literally. Try longer names first so a name that is a prefix
    # of another cannot win the alternation. With an empty mapping, use a
    # pattern that never matches instead of an empty alternation, which could
    # yield zero-width matches and never advance the scan.
    alternation = '|'.join(
        re.escape(name) for name in sorted(mapping, key=len, reverse=True))
    all_nodes_regex = re.compile(
        _node_name_regex_tpl.format(alternation or '(?!)'))

    graph_text = MessageToString(graph)

    # Rebuild the text from untouched slices plus substituted names.
    pieces = []
    last_match_end = 0
    while True:
        match = all_nodes_regex.search(graph_text, last_match_end)
        if match is None:
            break
        match_beg, match_end = match.span('name')
        pieces.append(graph_text[last_match_end:match_beg])
        node_name = match.group('name')
        pieces.append(mapping.get(node_name, node_name))
        last_match_end = match_end
    pieces.append(graph_text[last_match_end:])

    obfuscated_graph = GraphDef()
    Parse(''.join(pieces), obfuscated_graph)
    return obfuscated_graph
def load_graphdef_from_pbtxt(pbtxt_file):
    """Load a text-format (UTF-8) GraphDef from ``pbtxt_file``.

    Unknown extensions in the text are accepted (``allow_unknown_extension``).

    Args:
        pbtxt_file: path of the text-format protobuf file to read.

    Returns:
        The parsed ``GraphDef``.

    Raises:
        IOError: if the contents are not valid UTF-8 or cannot be parsed as a
            GraphDef. The underlying error is chained as ``__cause__``.
    """
    from google.protobuf import text_format

    # Read first so the file is closed before decoding/parsing.
    with open(pbtxt_file, 'rb') as f:
        content = f.read()
    graph = GraphDef()
    try:
        text_format.Parse(content.decode('UTF-8'),
                          graph,
                          allow_unknown_extension=True)
    except Exception as e:
        raise IOError("Can't parse file {}: {}".format(pbtxt_file, str(e))) from e
    return graph
Пример #10
0
    def __init__(self, pb_path):
        """
        Load the frozen graph at ``pb_path`` and wrap it in a ``tf.function``.

        The resulting function takes five tensors, feeds them to the graph's
        ``Placeholder`` .. ``Placeholder_4`` inputs, and returns the two
        tensors named by ``OUTPUT_NAMES`` as ``(alpha, background)``. It is
        stored as ``self._network``.
        """
        graph_def = GraphDef()
        with open(pb_path, "rb") as pb_file:
            serialized_graph = pb_file.read()
        graph_def.ParseFromString(serialized_graph)

        # Graph input tensors, in the positional order of the function args.
        placeholder_names = (
            "Placeholder:0",
            "Placeholder_1:0",
            "Placeholder_2:0",
            "Placeholder_3:0",
            "Placeholder_4:0",
        )

        @tf.function
        def network_function(I0, I1, I2, I3, I4):
            input_map = dict(zip(placeholder_names, (I0, I1, I2, I3, I4)))
            imported_outputs = tf.graph_util.import_graph_def(
                graph_def, input_map=input_map, return_elements=OUTPUT_NAMES)
            alpha, background = imported_outputs
            return alpha, background

        self._network = network_function
Пример #11
0
def main(unused_args):
    """Obfuscate node names of the graph named by --input.

    The input is a ``QuantizedGraph`` when --quantized is set, otherwise a
    ``GraphDef``. It is parsed as text when --text_proto is set, otherwise as
    binary. Entries of --keep containing ':' are split into tuples before
    being passed on. The obfuscated graph is written in binary to --output,
    and the name mapping is written to --output_mapping as ``key:value``
    lines.

    Raises:
      RuntimeError: if --input, --output or --output_mapping is missing/empty.
    """
    source_path = FLAGS.input  # type: str
    parse_as_text = FLAGS.text_proto  # type: bool
    is_quantized = FLAGS.quantized  # type: bool
    graph_out_path = FLAGS.output  # type: str
    mapping_out_path = FLAGS.output_mapping  # type: str
    keeps = [tuple(entry.split(':')) if ':' in entry else entry
             for entry in FLAGS.keep]

    required_paths = (
        (source_path, "in_path must be provided"),
        (graph_out_path, "output must be provided"),
        (mapping_out_path, "output_mapping must be provided"),
    )
    for path_value, error_message in required_paths:
        if not path_value:
            raise RuntimeError(error_message)

    source_graph = QuantizedGraph() if is_quantized else GraphDef()
    if parse_as_text:
        with open(source_path, "r") as handle:
            Parse(handle.read(), source_graph)
    else:
        with open(source_path, "rb") as handle:
            source_graph.ParseFromString(handle.read())

    obfuscate = obfuscate_quantized_graph if is_quantized else obfuscate_graph_def
    obfuscated, mapping = obfuscate(source_graph, keeps)

    with open(graph_out_path, "wb") as handle:
        handle.write(obfuscated.SerializeToString())

    with open(mapping_out_path, "w") as handle:
        handle.writelines("{}:{}\n".format(old_name, new_name)
                          for old_name, new_name in mapping.items())
def merge_partitioned_graphs_from_pbtxt(pbtxt_files):
    """Load text-format (UTF-8) GraphDef partitions and merge them.

    Args:
        pbtxt_files: iterable of paths to text-format GraphDef files, one per
            partition.

    Returns:
        The GraphDef produced by ``merge_partitioned_graphs``.

    Raises:
        IOError: if any file is not valid UTF-8 or cannot be parsed; the
            underlying error is chained as ``__cause__``.
    """
    # Loop-invariant import, hoisted out of the per-file loop.
    from google.protobuf import text_format

    graphs = []
    for pbtxt_file in pbtxt_files:
        with open(pbtxt_file, 'rb') as f:
            content = f.read()
        graph = GraphDef()
        try:
            text_format.Parse(content.decode('UTF-8'),
                              graph,
                              allow_unknown_extension=True)
        except Exception as e:
            raise IOError("Can't parse file {}: {}.".format(
                pbtxt_file, str(e))) from e
        graphs.append(graph)

    return merge_partitioned_graphs(graphs)
    def __init__(self, tarball_path):
        """Create and load a pretrained DeepLab model from a tar archive.

        The first archive member whose basename contains
        ``self.FROZEN_GRAPH_NAME`` and which can be extracted as a file is
        parsed as a binary ``GraphDef``. It is imported into a new
        ``tf.Graph`` stored as ``self.graph``, and a ``Session`` on that graph
        is stored as ``self.sess``.

        Args:
            tarball_path: path of the tar archive containing the frozen graph.

        Raises:
            RuntimeError: if no matching extractable member is found.
        """
        self.graph = tf.Graph()
        graph_def = None
        # The context managers close the archive and member stream even if
        # reading or parsing raises part-way through.
        with tarfile.open(tarball_path) as tar_file:
            for tar_info in tar_file.getmembers():
                if self.FROZEN_GRAPH_NAME not in os.path.basename(tar_info.name):
                    continue
                file_handle = tar_file.extractfile(tar_info)
                if file_handle is None:
                    # extractfile() returns None for members that are not
                    # regular files or links (e.g. a matching directory name).
                    continue
                with file_handle:
                    graph_def = GraphDef.FromString(file_handle.read())
                break

        if graph_def is None:
            raise RuntimeError('Cannot find inference graph in tar archive.')

        with self.graph.as_default():
            tf.import_graph_def(graph_def, name='')

        self.sess = Session(graph=self.graph)
Пример #14
0
def update_graph_def(input_graph_def: GraphDef,
                     nodes_to_remap: Dict[Text, List[NodeDef]],
                     inputs_to_replace: Dict[Text, Text]) -> GraphDef:
    """
    Build a copy of a TF graph_def with nodes replaced and node inputs renamed.

    No consistency check is performed. Callers must make sure the given
    remappings and input replacements produce a valid graph.

    Args:
        input_graph_def: TF graph_def with nodes or node inputs to replace.
        nodes_to_remap: `dict` mapping node names to a list of replacement
            nodes. A node whose name maps to an empty list is dropped from
            the result. Entries naming nodes that are absent from the input
            graph_def are ignored. Replacement nodes are inserted as given;
            their inputs are not passed through `inputs_to_replace`.
        inputs_to_replace: `dict` mapping input names to replacement names.
            Applied to the inputs of every node copied from the input graph,
            e.g. to redirect references to removed nodes.

    Returns:
        An updated copy of the input graph_def. Only `node` and `versions`
        are populated. The original input remains unchanged.
    """
    updated_graph_def = GraphDef()
    for source_node in input_graph_def.node:
        if source_node.name in nodes_to_remap:
            replacements = nodes_to_remap[source_node.name]
            if replacements:
                updated_graph_def.node.extend(replacements)
            continue

        # Copy the node straight into the result, then patch its inputs.
        copied_node = updated_graph_def.node.add()
        copied_node.CopyFrom(source_node)
        for position, input_name in enumerate(copied_node.input):
            if input_name in inputs_to_replace:
                copied_node.input[position] = inputs_to_replace[input_name]

    updated_graph_def.versions.CopyFrom(input_graph_def.versions)
    return updated_graph_def
Пример #15
0
    executor.benchmark_layout_transform(min_exec_num=exec_num)
    executor.run()
    executor.write_opt_sch2record_file(opt_sch_file)


if __name__ == '__main__':
    # Script entry: load a test image and a frozen TF graph, then convert the
    # graph to a TVM Relay module and set up tuning metadata.
    # NOTE(review): several names below (x, data_shape, output_shape,
    # batch_size, dtype, log_file, graph_opt_sch_file, input_name,
    # GRAPH_PB_PATH) are not used in the visible span. They are presumably
    # consumed by code that follows this truncated fragment — confirm before
    # removing or renaming.
    # logging.getLogger('autotvm').setLevel(logging.DEBUG)
    # 1. Load the input image, resized to 1024x1024.
    img_path = '/opt/dataset/tr2_cropped/data/1.png'
    image = Image.open(img_path).resize((1024, 1024))
    x = np.array(image)
    # 2. Load the frozen graph.
    GRAPH_PB_PATH = './frozen'
    # graph_def = tf.compat.v1.get_default_graph().as_graph_def(add_shapes=True)
    with tf.io.gfile.GFile('./frozen/frozen_model_fixed.pb', 'rb') as f:
        graph_def = GraphDef.FromString(f.read())
        # NOTE(review): an earlier comment said this imports the graph into
        # the default graph. ProcessGraphDefParam appears to only
        # type-check/normalize the GraphDef — confirm against TVM's
        # tvm.relay.testing.tf.
        graph_def = tf_testing.ProcessGraphDefParam(graph_def)

    # 3. TVM frontend: convert the TF graph to a Relay module + params.
    shape_dict = {"input_1": (1, 1024, 1024, 3)}  # change shape
    mod, params = relay.frontend.from_tensorflow(graph_def, shape_dict)
    # Metadata section.
    data_shape = (1, 1024, 1024, 3)  # maybe
    output_shape = (1, 1024, 1024, 4)
    batch_size = 1
    dtype = "float32"
    model_name = "unet_cpu_12_thread"
    log_file = "%s.log" % model_name
    graph_opt_sch_file = "%s_graph_opt_1000.log" % model_name
    input_name = "x"  # must match the input name used later in the script
def _split_input_name(owner_name, input_full_name):
    """Split a NodeDef input reference ``"name"`` or ``"name:port"``.

    Args:
        owner_name: name of the node owning the input (used in the error).
        input_full_name: the input string to split.

    Returns:
        ``(node_name, port)``. ``port`` is the text after ``':'``, or ``None``
        when there is no ``':'``.

    Raises:
        RuntimeError: if the string contains more than one ``':'``.
    """
    parts = input_full_name.split(":")
    if len(parts) == 1:
        return parts[0], None
    if len(parts) == 2:
        return parts[0], parts[1]
    raise RuntimeError(
        "Node '{}' input '{}' does not match the proper format.".format(
            owner_name, input_full_name))


def merge_partitioned_graphs(partitioned_graphs):
    """Merge partitioned GraphDefs into one, removing _Send/_Recv pairs.

    Every node other than ``_Send``/``_HostSend``/``_Recv``/``_HostRecv`` is
    copied into the result. Each send node is paired with the first recv node
    whose ``tensor_name`` attr is equal. Every input that references that
    recv node is rewritten to reference the send node's input instead,
    keeping the consumer's original slot. ``versions`` and ``library`` are
    taken from the first partition only.

    Input names are compared verbatim, so a control-dependency input written
    as ``"^name"`` never matches a recv node and is left untouched.

    Args:
        partitioned_graphs: non-empty sequence of GraphDef partitions.

    Returns:
        The merged ``GraphDef``.

    Raises:
        ValueError: if ``partitioned_graphs`` is empty.
        RuntimeError: if a send/recv node lacks ``tensor_name``, a send node
            has no matching recv node, an input name is malformed, a send
            node's input is not among the merged nodes, or a recv node has no
            consumers.
    """
    if not partitioned_graphs:
        raise ValueError("partitioned_graphs must contain at least one GraphDef")

    merged_graph = GraphDef()
    # TODO: for now we use first partitioned graph for version
    merged_graph.versions.CopyFrom(partitioned_graphs[0].versions)
    merged_graph.library.CopyFrom(partitioned_graphs[0].library)

    send_nodes = []
    recv_nodes = []
    for pg in partitioned_graphs:
        for node in pg.node:
            if node.op in ("_Send", "_HostSend"):
                send_nodes.append(node)
            elif node.op in ("_Recv", "_HostRecv"):
                recv_nodes.append(node)
            else:
                merged_graph.node.extend([node])

    # Validate every send/recv node once, independent of where a match sits.
    for snode in send_nodes:
        if "tensor_name" not in snode.attr:
            raise RuntimeError(
                "_Send node {} must have tensor_name".format(snode.name))
    for rnode in recv_nodes:
        if "tensor_name" not in rnode.attr:
            raise RuntimeError(
                "_Recv node {} must have tensor_name".format(rnode.name))

    # Build _Send/_Recv pairs: first _Recv with an equal tensor_name wins.
    send_recv_pairs = []
    for snode in send_nodes:
        tensor_name = snode.attr["tensor_name"]
        rnode = next(
            (r for r in recv_nodes if r.attr["tensor_name"] == tensor_name),
            None)
        if rnode is None:
            raise RuntimeError(
                "_Send node '{}' does not match any _Recv node (tensor_name={})"
                .format(snode.name, tensor_name))
        send_recv_pairs.append((snode, rnode))

    if not send_recv_pairs:
        return merged_graph

    # Parse each merged node's inputs once. Inputs are not modified until the
    # rewrite phase, so this view stays valid for every pair.
    parsed_nodes = [
        (node, [_split_input_name(node.name, full_name)
                for full_name in node.input])
        for node in merged_graph.node
    ]

    # Build source/destination node pairs.
    rewrite_node_pairs = []
    for snode, rnode in send_recv_pairs:
        send_inputs = [_split_input_name(snode.name, full_name)
                       for full_name in snode.input]
        src_node_and_port = None
        dst_node_and_port_list = []
        for node, node_inputs in parsed_nodes:
            # If several merged nodes match, the last one scanned is kept.
            for i, (input_node_name, input_port) in enumerate(send_inputs):
                if input_node_name == node.name:
                    src_node_and_port = {
                        "node": node,
                        "port": input_port,
                        "index": i
                    }
            for i, (input_node_name, input_port) in enumerate(node_inputs):
                if input_node_name == rnode.name:
                    dst_node_and_port_list.append({
                        "node": node,
                        "port": input_port,
                        "index": i
                    })

        if src_node_and_port is None:
            raise RuntimeError(
                "_Send input node '{}' is not found. (Node name: {})".format(
                    snode.input, snode.name))
        if not dst_node_and_port_list:
            raise RuntimeError(
                "_Recv output is not found. (Node name: {})".format(
                    rnode.name))

        for dst in dst_node_and_port_list:
            rewrite_node_pairs.append({"src": src_node_and_port, "dst": dst})

    # Rewrite each destination input to point at the _Send's source.
    for pair in rewrite_node_pairs:
        src = pair["src"]
        dst = pair["dst"]
        if src["port"] is not None:
            dst["node"].input[dst["index"]] = "{}:{}".format(
                src["node"].name, src["port"])
        else:
            dst["node"].input[dst["index"]] = src["node"].name

    return merged_graph