def graph_defs_equal(graph_def_1: graph_pb2.GraphDef,
                     graph_def_2: graph_pb2.GraphDef,
                     treat_nan_as_equal: bool = False) -> bool:
  """Check whether two GraphDefs are structurally equivalent.

  Equivalence means the sets of NodeDefs in the function library and in the
  main graph body are identical, and the functions in the function library
  are equal when compared as sets.

  Args:
    graph_def_1: First `graph_pb2.GraphDef` to compare.
    graph_def_2: Second `graph_pb2.GraphDef` to compare.
    treat_nan_as_equal: If True, NaN floating-point values compare as equal.
      Any equivalence relation over GraphDefs needs this for symmetry.

  Returns:
    True iff the two graphs are structurally equivalent as described above.

  Raises:
    TypeError: If either argument is not a `graph_pb2.GraphDef`.
  """
  argument_checks = (
      (graph_def_1, "graph_def_1 must be a graph_pb2.GraphDef proto."),
      (graph_def_2, "graph_def_2 must be a graph_pb2.GraphDef proto."),
  )
  for candidate, error_message in argument_checks:
    if not isinstance(candidate, graph_pb2.GraphDef):
      raise TypeError(error_message)

  comparison_options = _proto_comparators.ProtoComparisonOptions(
      treat_nan_as_equal)
  # The native comparator works on serialized bytes, not on Python protos.
  serialized_1 = graph_def_1.SerializeToString()
  serialized_2 = graph_def_2.SerializeToString()
  return _proto_comparators.EqualsGraphDef(serialized_1, serialized_2,
                                           comparison_options)
def main(unused_args):
    """Command-line entry point: quantize a GraphDef read from FLAGS.input.

    The graph is read as a text proto when FLAGS.text_proto is set, otherwise
    as a binary proto. It is passed to ``quantize_graph_def`` with the node
    names from FLAGS.skip and FLAGS.output_node, and the result is written in
    binary form to FLAGS.output.

    Raises:
        RuntimeError: if FLAGS.input or FLAGS.output is missing or empty.
    """
    # params
    in_path = FLAGS.input  # type: str
    in_is_text = FLAGS.text_proto  # type: bool
    out_path = FLAGS.output  # type: str
    skip = FLAGS.skip  # type: list
    output_nodes = FLAGS.output_node  # type: list

    # validate param
    if not in_path:
        raise RuntimeError("in_path must be provided")
    if not out_path:
        raise RuntimeError("output must be provided")

    # read graph
    in_graph = GraphDef()
    if in_is_text:
        # Decode text protos as UTF-8 regardless of the platform locale.
        with open(in_path, "r", encoding="utf-8") as fp:
            Parse(fp.read(), in_graph)
    else:
        with open(in_path, "rb") as fp:
            in_graph.ParseFromString(fp.read())

    # quantize
    # An unset list flag may be None; treat that as "skip nothing" instead of
    # letting set(None) raise TypeError.
    quantized = quantize_graph_def(in_graph, set(skip or ()), output_nodes)

    # write
    with open(out_path, "wb") as fp:
        fp.write(quantized.SerializeToString())
def from_tensorflow_frozen_model(frozen_file, output_nodes=None, preprocessor=None, **kwargs):
    """
    Converts a TensorFlow frozen graph to a UFF model.

    Args:
        frozen_file (str): The path to the frozen TensorFlow graph to convert.
        output_nodes (list(str)): The names of the outputs of the graph. If not provided (None or empty), graphsurgeon is used to automatically deduce output nodes.
        preprocessor (str): The path to a preprocessing script that will be executed before the converter. This script should define a ``preprocess`` function which accepts a graphsurgeon DynamicGraph and modifies it in place.

    The following options are not named parameters. They are forwarded through
    ``**kwargs`` to ``from_tensorflow``:

        output_filename (str): The UFF file to write.
        write_preprocessed (bool): If set to True, the converter will write out the preprocessed graph as well as a TensorBoard visualization. Must be used in conjunction with output_filename.
        text (bool): If set to True, the converter will also write out a human readable UFF file. Must be used in conjunction with output_filename.
        quiet (bool): If set to True, suppresses informational messages. Errors may still be printed.
        list_nodes (bool): If set to True, the converter displays a list of all nodes present in the graph.
        debug_mode (bool): If set to True, the converter prints verbose debug messages.
        return_graph_info (bool): If set to True, this function returns the graph input and output nodes in addition to the serialized UFF graph.

    Returns:
        serialized UFF MetaGraph (str)

        OR, if return_graph_info is set to True,

        serialized UFF MetaGraph (str), graph inputs (list(tensorflow.NodeDef)), graph outputs (list(tensorflow.NodeDef))
    """
    # Build a fresh list on every call. A shared mutable default would carry
    # any downstream mutation over into later calls.
    if output_nodes is None:
        output_nodes = []
    graphdef = GraphDef()
    with tf.io.gfile.GFile(frozen_file, "rb") as frozen_pb:
        graphdef.ParseFromString(frozen_pb.read())
    return from_tensorflow(graphdef, output_nodes, preprocessor, **kwargs)
def load_graphdef_from_pb(pb_file):
    """Load a binary-serialized GraphDef from ``pb_file``.

    Args:
        pb_file: path to a file containing a serialized GraphDef.

    Returns:
        The parsed ``GraphDef``.

    Raises:
        IOError: if the file contents cannot be parsed. The underlying parse
            error is chained as ``__cause__``.
    """
    graph = GraphDef()
    with open(pb_file, 'rb') as f:
        content = f.read()
    try:
        graph.ParseFromString(content)
    except Exception as e:
        # Protobuf parse errors differ between backends, so catch broadly but
        # keep the original error attached for debugging.
        raise IOError("Can't parse file {}: {}".format(pb_file, str(e))) from e
    return graph
def do_quantize_training_on_graphdef(input_graph, num_bits):
    """Run the native quantize-training rewrite over a GraphDef.

    Args:
        input_graph: proto that is serialized with ``SerializeToString`` and
            handed to ``DoQuantizeTrainingOnGraphDefHelper``.
        num_bits: forwarded unchanged to the native helper.

    Returns:
        A new ``GraphDef`` parsed from the bytes the helper returns.

    Raises:
        Whatever ``errors.raise_exception_on_not_ok_status`` raises when the
        helper leaves a non-OK status (presumably a TensorFlow error; confirm).
    """
    from tensorflow.core.framework.graph_pb2 import GraphDef
    from tensorflow.python.framework import errors

    with errors.raise_exception_on_not_ok_status() as status:
        result_graph = GraphDef()
        rewritten_bytes = DoQuantizeTrainingOnGraphDefHelper(
            input_graph.SerializeToString(), num_bits, status)
        result_graph.ParseFromString(rewritten_bytes)
    return result_graph
def merge_partitioned_graphs_from_pb(pb_files):
    """Load binary GraphDef partitions from ``pb_files`` and merge them.

    Args:
        pb_files: iterable of paths to binary-serialized GraphDef files.

    Returns:
        The result of ``merge_partitioned_graphs`` over the loaded graphs,
        in the order the files were given.

    Raises:
        IOError: if a file's contents cannot be parsed. The parse error is
            chained as ``__cause__``.
    """
    graphs = []
    for pb_file in pb_files:
        graph = GraphDef()
        with open(pb_file, 'rb') as f:
            content = f.read()
        try:
            graph.ParseFromString(content)
        except Exception as e:
            raise IOError("Can't parse file {}: {}.".format(pb_file, str(e))) from e
        # Only successfully parsed graphs reach this point.
        graphs.append(graph)
    return merge_partitioned_graphs(graphs)
def deserialize(cls, path):
    """Deserialize a graph_item serialized proto message from a file path.

    Reads a ``graphitem_pb2.GraphItem`` from ``path`` and builds an instance
    via ``cls(graph_def=...)``. It then restores the gradient/target pairs,
    the unpacked variable and saver definitions, and the table initializers.
    """
    item_def = graphitem_pb2.GraphItem()
    with open(path, "rb") as f:
        item_def.ParseFromString(f.read())

    # The GraphDef is stored as a packed Any field.
    graph_def = GraphDef()
    item_def.graph_def.Unpack(graph_def)
    item = cls(graph_def=graph_def)

    # Keys were serialized by joining with ';'. A single part stays a plain
    # string; several parts become a tuple.
    for joined_key, target in item_def.grad_target_pairs.items():
        parts = joined_key.split(';')
        restored_key = parts[0] if len(parts) == 1 else tuple(parts)
        item._grad_target_pairs[restored_key] = target

    for packed_variable in item_def.info.variables:
        variable_def = VariableDef()
        packed_variable.Unpack(variable_def)
        item.info.update_variables([variable_def], replace=False)

    for packed_saver in item_def.info.savers:
        saver_def = SaverDef()
        packed_saver.Unpack(saver_def)
        item.info.update_savers([saver_def], replace=False)

    item.info.update_table_initializers(item_def.info.table_initializers)
    return item
def _replace_graph_node_names(graph, mapping):
    """Return a new GraphDef in which node names are renamed via ``mapping``.

    The graph is rendered to protobuf text format. Each match of
    ``_node_name_regex_tpl`` (using its ``name`` capture group) whose text is
    a key of ``mapping`` is replaced by the mapped value. The rewritten text is
    then parsed into a fresh ``GraphDef``.

    Args:
        graph: protobuf message to rewrite (presumably a ``GraphDef``).
        mapping: dict from original node name to replacement name.

    Returns:
        A new ``GraphDef`` parsed from the rewritten text.
    """
    # old graph text
    graph_text = MessageToString(graph)
    obfuscated_graph = GraphDef()

    # Empty names can never be real node names. Left in, they would add an
    # empty alternative that matches zero-width and stalls the scan below.
    names = [name for name in mapping if name]
    if not names:
        Parse(graph_text, obfuscated_graph)
        return obfuscated_graph

    # Escape the names so regex metacharacters such as '.' match literally.
    # Try longer names first so a name is not shadowed by one of its prefixes.
    alternatives = '|'.join(
        re.escape(name) for name in sorted(names, key=len, reverse=True))
    all_nodes_regex = re.compile(_node_name_regex_tpl.format(alternatives))

    # replace all node names
    pieces = []
    last_match_end = 0
    while True:
        match = all_nodes_regex.search(graph_text, last_match_end)
        if match is None:
            break
        # Text between the previous name and this one is copied verbatim.
        match_beg, match_end = match.span('name')
        pieces.append(graph_text[last_match_end:match_beg])
        node_name = graph_text[match_beg:match_end]
        pieces.append(mapping.get(node_name, node_name))
        # Resume right after the name group, as the original scan did.
        last_match_end = match_end
    pieces.append(graph_text[last_match_end:])

    Parse(''.join(pieces), obfuscated_graph)
    return obfuscated_graph
def load_graphdef_from_pbtxt(pbtxt_file):
    """Load a text-format GraphDef from ``pbtxt_file``.

    The file is read as bytes, decoded as UTF-8, and parsed with
    ``text_format.Parse(..., allow_unknown_extension=True)``.

    Args:
        pbtxt_file: path to a text-format GraphDef.

    Returns:
        The parsed ``GraphDef``.

    Raises:
        IOError: if decoding or parsing fails. The underlying error is
            chained as ``__cause__``.
    """
    from google.protobuf import text_format

    graph = GraphDef()
    with open(pbtxt_file, 'rb') as f:
        content = f.read()
    try:
        text_format.Parse(content.decode('UTF-8'), graph,
                          allow_unknown_extension=True)
    except Exception as e:
        raise IOError("Can't parse file {}: {}".format(pbtxt_file, str(e))) from e
    return graph
def __init__(self, pb_path):
    """
    Creates tf function for neural network.

    Reads a binary GraphDef from ``pb_path`` and stores a ``tf.function`` as
    ``self._network``. The function imports that graph with its five arguments
    bound to ``Placeholder:0`` .. ``Placeholder_4:0`` and returns the two
    tensors selected by ``OUTPUT_NAMES`` as ``(alpha, background)``.
    """
    graph_def = GraphDef()
    with open(pb_path, "rb") as pb:
        graph_def.ParseFromString(pb.read())

    # Placeholder tensor names, in the same order as the function arguments.
    placeholder_names = (
        "Placeholder:0",
        "Placeholder_1:0",
        "Placeholder_2:0",
        "Placeholder_3:0",
        "Placeholder_4:0",
    )

    @tf.function
    def network_function(I0, I1, I2, I3, I4):
        input_map = dict(zip(placeholder_names, (I0, I1, I2, I3, I4)))
        outputs = tf.graph_util.import_graph_def(
            graph_def, input_map=input_map, return_elements=OUTPUT_NAMES)
        alpha, background = outputs
        return alpha, background

    self._network = network_function
def main(unused_args):
    """Command-line entry point: obfuscate the node names of a graph.

    Reads a ``QuantizedGraph`` (when FLAGS.quantized) or a ``GraphDef`` from
    FLAGS.input, as text or binary proto depending on FLAGS.text_proto. It
    then obfuscates the graph, keeping the entries in FLAGS.keep, and writes
    two outputs:

    * the binary result to FLAGS.output;
    * the returned mapping to FLAGS.output_mapping, one ``"key:value"`` line
      per entry.

    Raises:
        RuntimeError: if FLAGS.input, FLAGS.output or FLAGS.output_mapping is
            missing or empty.
    """
    # params
    in_path = FLAGS.input  # type: str
    in_is_text = FLAGS.text_proto  # type: bool
    quantized = FLAGS.quantized  # type: bool
    out_path = FLAGS.output  # type: str
    out_mapping_path = FLAGS.output_mapping  # type: str
    # Entries containing ':' are split into tuples. An unset flag (None) is
    # treated as "keep nothing" instead of crashing the comprehension.
    keeps = [s if ':' not in s else tuple(s.split(':'))
             for s in FLAGS.keep or ()]

    # validate param
    if not in_path:
        raise RuntimeError("in_path must be provided")
    if not out_path:
        raise RuntimeError("output must be provided")
    if not out_mapping_path:
        raise RuntimeError("output_mapping must be provided")

    # read graph
    if quantized:
        in_graph = QuantizedGraph()
    else:
        in_graph = GraphDef()
    if in_is_text:
        # Decode text protos as UTF-8 regardless of the platform locale.
        with open(in_path, "r", encoding="utf-8") as fp:
            Parse(fp.read(), in_graph)
    else:
        with open(in_path, "rb") as fp:
            in_graph.ParseFromString(fp.read())

    # obfuscate
    if quantized:
        obfuscated, mapping = obfuscate_quantized_graph(in_graph, keeps)
    else:
        obfuscated, mapping = obfuscate_graph_def(in_graph, keeps)

    # write graph
    with open(out_path, "wb") as fp:
        fp.write(obfuscated.SerializeToString())

    # write mapping (UTF-8 so the file is identical on every platform)
    with open(out_mapping_path, "w", encoding="utf-8") as fp:
        fp.writelines("{}:{}\n".format(k, v) for k, v in mapping.items())
def merge_partitioned_graphs_from_pbtxt(pbtxt_files):
    """Load text-format GraphDef partitions from ``pbtxt_files`` and merge them.

    Each file is read as bytes, decoded as UTF-8, and parsed with
    ``text_format.Parse(..., allow_unknown_extension=True)``.

    Args:
        pbtxt_files: iterable of paths to text-format GraphDef files.

    Returns:
        The result of ``merge_partitioned_graphs`` over the loaded graphs,
        in the order the files were given.

    Raises:
        IOError: if a file cannot be decoded or parsed. The underlying error
            is chained as ``__cause__``.
    """
    from google.protobuf import text_format

    graphs = []
    for pbtxt_file in pbtxt_files:
        graph = GraphDef()
        with open(pbtxt_file, 'rb') as f:
            content = f.read()
        try:
            text_format.Parse(content.decode('UTF-8'), graph,
                              allow_unknown_extension=True)
        except Exception as e:
            raise IOError("Can't parse file {}: {}.".format(
                pbtxt_file, str(e))) from e
        # Only successfully parsed graphs reach this point.
        graphs.append(graph)
    return merge_partitioned_graphs(graphs)
def __init__(self, tarball_path):
    """Creates and loads pretrained deeplab model.

    Scans the tar archive at ``tarball_path`` for the first regular member
    whose base name contains ``self.FROZEN_GRAPH_NAME``. That member is parsed
    as a GraphDef and imported into ``self.graph``, and ``self.sess`` is
    opened on that graph.

    Raises:
        RuntimeError: if no matching member is found in the archive.
    """
    self.graph = tf.Graph()
    graph_def = None
    # Extract frozen graph from tar archive. The context managers close the
    # archive and the member handle even if parsing raises.
    with tarfile.open(tarball_path) as tar_file:
        for tar_info in tar_file.getmembers():
            if self.FROZEN_GRAPH_NAME not in os.path.basename(tar_info.name):
                continue
            file_handle = tar_file.extractfile(tar_info)
            if file_handle is None:
                # extractfile returns None for members that are not regular
                # files or links (e.g. a directory whose name matches).
                continue
            with file_handle:
                graph_def = GraphDef.FromString(file_handle.read())
            break

    if graph_def is None:
        raise RuntimeError('Cannot find inference graph in tar archive.')

    with self.graph.as_default():
        tf.import_graph_def(graph_def, name='')

    self.sess = Session(graph=self.graph)
def update_graph_def(input_graph_def: GraphDef,
                     nodes_to_remap: Dict[Text, List[NodeDef]],
                     inputs_to_replace: Dict[Text, Text]) -> GraphDef:
    """
    Update a TF graph_def by replacing nodes and node inputs.
    There will be no consistency check in this function.
    Callers have to make sure the given remappings and input replacements
    result in a valid graph.

    Args:
        input_graph_def: TF graph_def with nodes or node inputs to replace
        nodes_to_remap: `dict` that maps node names to a list of replacement
            nodes. Nodes whose name maps to an empty list will be removed
            from the returned graph. Entries for names that are not in the
            input graph_def are ignored. Replacement nodes are inserted
            as-is; `inputs_to_replace` is not applied to them.
        inputs_to_replace: `dict` that maps input names to replacement names.
            Nodes that have been removed need to be replaced in all
            referenced graph nodes. This mapping can be used to make sure
            this happens.

    Returns:
        An updated copy of the input graph_def, including its versions and
        function library. The original input remains unchanged.
    """
    result_graph_def = GraphDef()
    for node in input_graph_def.node:
        if node.name in nodes_to_remap:
            replacement_nodes = nodes_to_remap[node.name]
            if replacement_nodes:
                result_graph_def.node.extend(replacement_nodes)
            continue
        # Copy directly into the result instead of copying twice via extend.
        new_node = result_graph_def.node.add()
        new_node.CopyFrom(node)
        for i, input_node in enumerate(new_node.input):
            if input_node in inputs_to_replace:
                new_node.input[i] = inputs_to_replace[input_node]
    result_graph_def.versions.CopyFrom(input_graph_def.versions)
    # Without the library, nodes that call library functions would dangle.
    result_graph_def.library.CopyFrom(input_graph_def.library)
    return result_graph_def
executor.benchmark_layout_transform(min_exec_num=exec_num) executor.run() executor.write_opt_sch2record_file(opt_sch_file) if __name__ == '__main__': # logging.getLogger('autotvm').setLevel(logging.DEBUG) # 1.load img img_path = '/opt/dataset/tr2_cropped/data/1.png' image = Image.open(img_path).resize((1024, 1024)) x = np.array(image) # 2.load graph GRAPH_PB_PATH = './frozen' # graph_def = tf.compat.v1.get_default_graph().as_graph_def(add_shapes=True) with tf.io.gfile.GFile('./frozen/frozen_model_fixed.pb', 'rb') as f: graph_def = GraphDef.FromString(f.read()) #call the utility to import the graph definition into default graph. graph_def = tf_testing.ProcessGraphDefParam(graph_def) # 3. tvm frontend shape_dict = {"input_1": (1, 1024, 1024, 3)} # change shape mod, params = relay.frontend.from_tensorflow(graph_def, shape_dict) # meta-data部份 data_shape = (1, 1024, 1024, 3) # maybe output_shape = (1, 1024, 1024, 4) batch_size = 1 dtype = "float32" model_name = "unet_cpu_12_thread" log_file = "%s.log" % model_name graph_opt_sch_file = "%s_graph_opt_1000.log" % model_name input_name = "x" #这是和后面的输入名字一样的
def _split_input_name(input_full_name, owner_name):
    """Split a NodeDef input string into ``(node_name, port, is_control)``.

    ``"n"`` gives ``("n", None, False)``, ``"n:1"`` gives ``("n", "1", False)``
    and ``"^n"`` gives ``("n", None, True)``. Raises RuntimeError when the
    name contains more than one ':'.
    """
    is_control = input_full_name.startswith("^")
    name = input_full_name[1:] if is_control else input_full_name
    str_list = name.split(":")
    if len(str_list) == 1:
        return str_list[0], None, is_control
    if len(str_list) == 2:
        return str_list[0], str_list[1], is_control
    raise RuntimeError(
        "Node '{}' input '{}' does not match the proper format."
        .format(owner_name, input_full_name))


def _required_tensor_name(node, op_label):
    """Return ``node.attr["tensor_name"]``, raising RuntimeError if absent."""
    # Check membership first: indexing a protobuf map inserts a default entry.
    if "tensor_name" not in node.attr:
        raise RuntimeError(
            "{} node {} must have tensor_name".format(op_label, node.name))
    return node.attr["tensor_name"]


def merge_partitioned_graphs(partitioned_graphs):
    """Merge partitioned GraphDefs back into a single GraphDef.

    All nodes except ``_Send``/``_HostSend`` and ``_Recv``/``_HostRecv`` are
    copied into the result. Each send node is paired with the first receive
    node whose ``tensor_name`` attribute is equal. Every input referring to
    that receive node is then rewired to the send node's data input:
    ``"src"``/``"src:port"`` for data inputs and ``"^src"`` for control
    inputs.

    Versions and the function library come from the first partition only.

    Args:
        partitioned_graphs: non-empty sequence of ``GraphDef`` partitions.

    Returns:
        A new ``GraphDef`` holding the merged graph.

    Raises:
        ValueError: if ``partitioned_graphs`` is empty.
        RuntimeError: in any of these cases:
            * a send/receive node lacks ``tensor_name``;
            * a send node has no matching receive node;
            * a node input is malformed;
            * a send node's input node is not found;
            * a receive node has no consumers.
    """
    if not partitioned_graphs:
        raise ValueError("partitioned_graphs must contain at least one graph")

    merged_graph = GraphDef()
    # TODO: for now we use first partitioned graph for version
    merged_graph.versions.CopyFrom(partitioned_graphs[0].versions)
    merged_graph.library.CopyFrom(partitioned_graphs[0].library)

    send_nodes = []
    recv_nodes = []
    for pg in partitioned_graphs:
        for node in pg.node:
            if node.op in ("_Send", "_HostSend"):
                send_nodes.append(node)
            elif node.op in ("_Recv", "_HostRecv"):
                recv_nodes.append(node)
            else:
                merged_graph.node.extend([node])

    # build _Send/_Recv pairs (first receive node with an equal tensor_name)
    send_recv_pairs = []
    for snode in send_nodes:
        send_tensor = _required_tensor_name(snode, "_Send")
        for rnode in recv_nodes:
            if _required_tensor_name(rnode, "_Recv") == send_tensor:
                send_recv_pairs.append((snode, rnode))
                break
        else:
            raise RuntimeError(
                "_Send node '{}' does not match any _Recv node (tensor_name={})"
                .format(snode.name, send_tensor))

    # build source/destination node pairs
    rewrite_node_pairs = []
    for snode, rnode in send_recv_pairs:
        # Parse the send node's inputs once. Only data inputs can be the source.
        send_inputs = [_split_input_name(full_name, snode.name)
                       for full_name in snode.input]
        src_node_and_port = None
        dst_node_and_port_list = []
        for node in merged_graph.node:
            for input_node_name, input_port, is_control in send_inputs:
                if not is_control and input_node_name == node.name:
                    src_node_and_port = {"node": node, "port": input_port}
            for i, input_full_name in enumerate(node.input):
                input_node_name, _, is_control = _split_input_name(
                    input_full_name, node.name)
                if input_node_name == rnode.name:
                    dst_node_and_port_list.append(
                        {"node": node, "index": i, "control": is_control})
        if src_node_and_port is None:
            raise RuntimeError(
                "_Send input node '{}' is not found. (Node name: {})".format(
                    snode.input, snode.name))
        if not dst_node_and_port_list:
            raise RuntimeError(
                "_Recv output is not found. (Node name: {})".format(
                    rnode.name))
        for dst in dst_node_and_port_list:
            rewrite_node_pairs.append((src_node_and_port, dst))

    # rewrite destination node's input (all rewrites are collected first)
    for src, dst in rewrite_node_pairs:
        src_name = src["node"].name
        if dst["control"]:
            new_input = "^" + src_name
        elif src["port"] is not None:
            new_input = "{}:{}".format(src_name, src["port"])
        else:
            new_input = src_name
        dst["node"].input[dst["index"]] = new_input
    return merged_graph