def testExtractSubGraph(self):
  graph_def = tf.GraphDef()
  n1 = graph_def.node.add()
  n1.name = "n1"
  n1.input.extend(["n5"])
  n2 = graph_def.node.add()
  n2.name = "n2"
  # Take the first output of the n1 node as the input.
  n2.input.extend(["n1:0"])
  n3 = graph_def.node.add()
  n3.name = "n3"
  # Add a control input (which isn't really needed by the kernel, but
  # rather to enforce execution order between nodes).
  n3.input.extend(["^n2"])
  n4 = graph_def.node.add()
  n4.name = "n4"
  # It is fine to have a loop in the graph as well.
  n5 = graph_def.node.add()
  n5.name = "n5"
  n5.input.extend(["n1"])
  sub_graph = graph_util.extract_sub_graph(graph_def, ["n3"])
  self.assertEqual("n1", sub_graph.node[0].name)
  self.assertEqual("n2", sub_graph.node[1].name)
  self.assertEqual("n3", sub_graph.node[2].name)
  self.assertEqual("n5", sub_graph.node[3].name)
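# A hedged sketch (not the library source) of the traversal that
# graph_util.extract_sub_graph performs, as exercised by the test above: a
# backwards walk from the requested output nodes that follows regular inputs
# ("n1", "n1:0") and control inputs ("^n2") alike. Assumes
# `import tensorflow as tf` at module level, like the surrounding snippets;
# the helper names below are made up for illustration.
def _strip_input_name(input_name):
  # "^n2" (control input) -> "n2"; "n1:0" (output slot) -> "n1".
  if input_name.startswith("^"):
    input_name = input_name[1:]
  return input_name.split(":")[0]


def _extract_sub_graph_sketch(graph_def, dest_nodes):
  name_to_node = {node.name: node for node in graph_def.node}
  nodes_to_keep = set()
  queue = list(dest_nodes)
  while queue:
    name = queue.pop()
    if name in nodes_to_keep:
      continue
    nodes_to_keep.add(name)
    queue.extend(_strip_input_name(i) for i in name_to_node[name].input)
  out = tf.GraphDef()
  # Keep the original node order; the assertions in the test above rely on it.
  out.node.extend(node for node in graph_def.node
                  if node.name in nodes_to_keep)
  return out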
def convert_variables_to_constants(sess, input_graph_def, output_node_names):
  """Replaces all the variables in a graph with constants of the same values."""
  variable_names = []
  variable_dict_names = []
  for node in input_graph_def.node:
    if node.op == "Assign":
      variable_name = node.input[0]
      variable_dict_names.append(variable_name)
      variable_names.append(variable_name + ":0")
  returned_variables = sess.run(variable_names)
  found_variables = dict(zip(variable_dict_names, returned_variables))
  print("Froze %d variables." % len(returned_variables))
  inference_graph = extract_sub_graph(input_graph_def, output_node_names)
  output_graph_def = graph_pb2.GraphDef()
  how_many_converted = 0
  for input_node in inference_graph.node:
    output_node = graph_pb2.NodeDef()
    if input_node.name in found_variables:
      output_node.op = "Const"
      output_node.name = input_node.name
      dtype = input_node.attr["dtype"]
      data = found_variables[input_node.name]
      output_node.attr["dtype"].CopyFrom(dtype)
      output_node.attr["value"].CopyFrom(attr_value_pb2.AttrValue(
          tensor=tensor_util.make_tensor_proto(data,
                                               dtype=dtype.type,
                                               shape=data.shape)))
      how_many_converted += 1
    else:
      output_node.CopyFrom(input_node)
    output_graph_def.node.extend([output_node])
  print("Converted %d variables to const ops." % how_many_converted)
  return output_graph_def
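# Hedged usage sketch for convert_variables_to_constants above. The graph,
# the variable name "w", and the output name "out" are made up for
# illustration; assumes the TF 1.x-era APIs (tf.Session,
# tf.initialize_all_variables) that the surrounding snippets use.
def example_convert_variables_to_constants():
  with tf.Session() as sess:
    w = tf.Variable([[1.0, 2.0]], name="w")
    tf.matmul(w, [[3.0], [4.0]], name="out")
    sess.run(tf.initialize_all_variables())
    frozen = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  # "frozen" now holds Const nodes in place of the variable, so it can be
  # serialized and re-imported later without a checkpoint.
  return frozen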
def test_remove_unneeded_nodes(self):
  a_constant_name = "a_constant"
  b_constant_name = "b_constant"
  a_check_name = "a_check"
  b_check_name = "b_check"
  a_identity_name = "a_identity"
  b_identity_name = "b_identity"
  add_name = "add"
  graph_def = tf.GraphDef()
  a_constant = quantize_graph.create_constant_node(a_constant_name, value=1,
                                                   dtype=tf.float32, shape=[])
  graph_def.node.extend([a_constant])
  a_check_node = quantize_graph.create_node("CheckNumerics", a_check_name,
                                            [a_constant_name])
  graph_def.node.extend([a_check_node])
  a_identity_node = quantize_graph.create_node("Identity", a_identity_name,
                                               [a_constant_name,
                                                "^" + a_check_name])
  graph_def.node.extend([a_identity_node])
  b_constant = quantize_graph.create_constant_node(b_constant_name, value=1,
                                                   dtype=tf.float32, shape=[])
  graph_def.node.extend([b_constant])
  b_check_node = quantize_graph.create_node("CheckNumerics", b_check_name,
                                            [b_constant_name])
  graph_def.node.extend([b_check_node])
  b_identity_node = quantize_graph.create_node("Identity", b_identity_name,
                                               [b_constant_name,
                                                "^" + b_check_name])
  graph_def.node.extend([b_identity_node])
  add_node = quantize_graph.create_node("Add", add_name,
                                        [a_identity_name, b_identity_name])
  quantize_graph.set_attr_dtype(add_node, "T", tf.float32)
  graph_def.node.extend([add_node])
  expected_output = tf.GraphDef()
  a_constant = quantize_graph.create_constant_node(a_constant_name, value=1,
                                                   dtype=tf.float32, shape=[])
  expected_output.node.extend([a_constant])
  b_constant = quantize_graph.create_constant_node(b_constant_name, value=1,
                                                   dtype=tf.float32, shape=[])
  expected_output.node.extend([b_constant])
  add_node = quantize_graph.create_node("Add", add_name,
                                        [a_constant_name, b_constant_name])
  quantize_graph.set_attr_dtype(add_node, "T", tf.float32)
  expected_output.node.extend([add_node])
  rewriter = quantize_graph.GraphRewriter(graph_def, [add_name])
  output = rewriter.remove_unneeded_nodes(graph_def)
  stripped_output = graph_util.extract_sub_graph(output, [add_name])
  self.assertProtoEquals(expected_output, stripped_output)
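# Hedged reconstruction, from the call sites above, of the quantize_graph
# helpers these tests rely on; this is not the library source and the real
# implementations may differ. Assumes `import tensorflow as tf` and
# `from tensorflow.python.framework import tensor_util`.
def create_node(op, name, inputs):
  new_node = tf.NodeDef()
  new_node.op = op
  new_node.name = name
  new_node.input.extend(inputs)
  return new_node


def set_attr_dtype(node, key, value):
  node.attr[key].CopyFrom(tf.AttrValue(type=value.as_datatype_enum))


def create_constant_node(name, value, dtype, shape=None):
  node = create_node("Const", name, [])
  set_attr_dtype(node, "dtype", dtype)
  node.attr["value"].CopyFrom(tf.AttrValue(
      tensor=tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape)))
  return node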
def strip_unused(input_graph, input_binary, output_graph, input_node_names,
                 output_node_names, placeholder_type_enum):
  """Removes unused nodes from a graph."""
  if not tf.gfile.Exists(input_graph):
    print("Input graph file '" + input_graph + "' does not exist!")
    return -1
  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1
  input_graph_def = tf.GraphDef()
  mode = "rb" if input_binary else "r"
  with tf.gfile.FastGFile(input_graph, mode) as f:
    if input_binary:
      input_graph_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), input_graph_def)
  # Here we replace the nodes we're going to override as inputs with
  # placeholders so that any unused nodes that are inputs to them are
  # automatically stripped out by extract_sub_graph().
  input_node_names_list = input_node_names.split(",")
  inputs_replaced_graph_def = tf.GraphDef()
  for node in input_graph_def.node:
    if node.name in input_node_names_list:
      placeholder_node = tf.NodeDef()
      placeholder_node.op = "Placeholder"
      placeholder_node.name = node.name
      placeholder_node.attr["dtype"].CopyFrom(
          tf.AttrValue(type=placeholder_type_enum))
      inputs_replaced_graph_def.node.extend([placeholder_node])
    else:
      inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])
  output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
                                                  output_node_names.split(","))
  with tf.gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
  print("%d ops in the final graph." % len(output_graph_def.node))
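# Hedged usage sketch for strip_unused above; every path and node name below
# is a made-up placeholder. placeholder_type_enum expects a DataType enum
# value, e.g. tf.float32.as_datatype_enum.
def example_strip_unused():
  strip_unused(input_graph="model.pbtxt",
               input_binary=False,
               output_graph="stripped_model.pb",
               input_node_names="input_image",
               output_node_names="softmax",
               placeholder_type_enum=tf.float32.as_datatype_enum)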
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
                 output_node_names, restore_op_name, filename_tensor_name,
                 output_graph, clear_devices):
  """Converts all variables in a graph and checkpoint into constants."""
  if not gfile.Exists(input_graph):
    print("Input graph file '" + input_graph + "' does not exist!")
    return -1
  if input_saver and not gfile.Exists(input_saver):
    print("Input saver file '" + input_saver + "' does not exist!")
    return -1
  if not gfile.Exists(input_checkpoint):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1
  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1
  input_graph_def = tf.GraphDef()
  mode = "rb" if input_binary else "r"
  with open(input_graph, mode) as f:
    if input_binary:
      input_graph_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), input_graph_def)
  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    for node in input_graph_def.node:
      node.device = ""
  _ = tf.import_graph_def(input_graph_def, name="")
  with tf.Session() as sess:
    if input_saver:
      with open(input_saver, mode) as f:
        saver_def = tf.train.SaverDef()
        if input_binary:
          saver_def.ParseFromString(f.read())
        else:
          text_format.Merge(f.read(), saver_def)
        saver = tf.train.Saver(saver_def=saver_def)
        saver.restore(sess, input_checkpoint)
    else:
      sess.run([restore_op_name], {filename_tensor_name: input_checkpoint})
    found_variables = {}
    for node in input_graph_def.node:
      if node.op == "Assign":
        variable_name = node.input[0]
        found_variables[variable_name] = sess.run(variable_name + ":0")
  # This graph only includes the nodes needed to evaluate the output nodes, and
  # removes unneeded nodes like those involved in saving and assignment.
  inference_graph = graph_util.extract_sub_graph(input_graph_def,
                                                 output_node_names.split(","))
  output_graph_def = tf.GraphDef()
  how_many_converted = 0
  for input_node in inference_graph.node:
    output_node = tf.NodeDef()
    if input_node.name in found_variables:
      output_node.op = "Const"
      output_node.name = input_node.name
      dtype = input_node.attr["dtype"]
      data = found_variables[input_node.name]
      set_attr_dtype(output_node, "dtype", dtype)
      set_attr_tensor(output_node, "value", data, dtype.type, data.shape)
      how_many_converted += 1
    else:
      output_node.CopyFrom(input_node)
    output_graph_def.node.extend([output_node])
  with gfile.FastGFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
  print("Converted %d variables to const ops." % how_many_converted)
  print("%d ops in the final graph." % len(output_graph_def.node))
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
                 output_node_names, restore_op_name, filename_tensor_name,
                 output_graph, clear_devices):
  """Converts all variables in a graph and checkpoint into constants."""
  if not tf.gfile.Exists(input_graph):
    print("Input graph file '" + input_graph + "' does not exist!")
    return -1
  if input_saver and not tf.gfile.Exists(input_saver):
    print("Input saver file '" + input_saver + "' does not exist!")
    return -1
  if not tf.gfile.Exists(input_checkpoint):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1
  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1
  input_graph_def = tf.GraphDef()
  mode = "rb" if input_binary else "r"
  with open(input_graph, mode) as f:
    if input_binary:
      input_graph_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), input_graph_def)
  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    for node in input_graph_def.node:
      node.device = ""
  _ = tf.import_graph_def(input_graph_def, name="")
  with tf.Session() as sess:
    if input_saver:
      with open(input_saver, mode) as f:
        saver_def = tf.train.SaverDef()
        if input_binary:
          saver_def.ParseFromString(f.read())
        else:
          text_format.Merge(f.read(), saver_def)
        saver = tf.train.Saver(saver_def=saver_def)
        saver.restore(sess, input_checkpoint)
    else:
      sess.run([restore_op_name], {filename_tensor_name: input_checkpoint})
    found_variables = {}
    for node in input_graph_def.node:
      if node.op == "Assign":
        variable_name = node.input[0]
        found_variables[variable_name] = sess.run(variable_name + ":0")
  # This graph only includes the nodes needed to evaluate the output nodes, and
  # removes unneeded nodes like those involved in saving and assignment.
  inference_graph = graph_util.extract_sub_graph(input_graph_def,
                                                 output_node_names.split(","))
  output_graph_def = tf.GraphDef()
  how_many_converted = 0
  for input_node in inference_graph.node:
    output_node = tf.NodeDef()
    if input_node.name in found_variables:
      output_node.op = "Const"
      output_node.name = input_node.name
      dtype = input_node.attr["dtype"]
      data = found_variables[input_node.name]
      set_attr_dtype(output_node, "dtype", dtype)
      set_attr_tensor(output_node, "value", data, dtype.type, data.shape)
      how_many_converted += 1
    else:
      output_node.CopyFrom(input_node)
    output_graph_def.node.extend([output_node])
  with tf.gfile.FastGFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
  print("Converted %d variables to const ops." % how_many_converted)
  print("%d ops in the final graph." % len(output_graph_def.node))
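# Hedged usage sketch for freeze_graph above; all paths and node names are
# made-up placeholders. With an empty input_saver, restore_op_name and
# filename_tensor_name should name the ops a standard tf.train.Saver adds
# (conventionally "save/restore_all" and "save/Const:0").
def example_freeze_graph():
  freeze_graph(input_graph="model.pbtxt",
               input_saver="",
               input_binary=False,
               input_checkpoint="model.ckpt",
               output_node_names="softmax",
               restore_op_name="save/restore_all",
               filename_tensor_name="save/Const:0",
               output_graph="frozen_model.pb",
               clear_devices=True)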
def remove_dead_nodes(self, output_names):
  """Removes nodes that are no longer needed for inference from the graph."""
  old_output_graph = self.output_graph
  self.output_graph = graph_util.extract_sub_graph(old_output_graph,
                                                   output_names)
def convert_variables_to_constants(sess, input_graph_def, output_node_names):
  """Replaces all the variables in a graph with constants of the same values.

  If you have a trained graph containing Variable ops, it can be convenient to
  convert them all to Const ops holding the same values. This makes it possible
  to describe the network fully with a single GraphDef file, and allows the
  removal of a lot of ops related to loading and saving the variables.

  Args:
    sess: Active TensorFlow session containing the variables.
    input_graph_def: GraphDef object holding the network.
    output_node_names: List of name strings for the result nodes of the graph.

  Returns:
    GraphDef containing a simplified version of the original.
  """
  print('call convert_variables')
  found_variables = {}
  variable_name_list = []
  found_variables_list = []
  print('search nodes...')
  for i, node in enumerate(input_graph_def.node):
    if node.op == "Assign":
      variable_name_list.append(node.input[0])
    sys.stdout.write(
        "\r%s" % "node: {0}/{1}".format(i + 1, len(input_graph_def.node)))
    sys.stdout.flush()
  print('')
  print('{0} variables found'.format(len(variable_name_list)))
  print('evaluate nodes...')
  found_variables_list = sess.run([v + ":0" for v in variable_name_list])
  print('insert values...')
  for i, v in enumerate(variable_name_list):
    found_variables[v] = found_variables_list[i]
    sys.stdout.write(
        "\r%s" % "node: {0}/{1}".format(i + 1, len(variable_name_list)))
    sys.stdout.flush()
  print('')
  # This graph only includes the nodes needed to evaluate the output nodes, and
  # removes unneeded nodes like those involved in saving and assignment.
  inference_graph = graph_util.extract_sub_graph(input_graph_def,
                                                 output_node_names)
  output_graph_def = graph_pb2.GraphDef()
  how_many_converted = 0
  for input_node in inference_graph.node:
    output_node = graph_pb2.NodeDef()
    if input_node.name in found_variables:
      output_node.op = "Const"
      output_node.name = input_node.name
      dtype = input_node.attr["dtype"]
      data = found_variables[input_node.name]
      output_node.attr["dtype"].CopyFrom(dtype)
      output_node.attr["value"].CopyFrom(
          attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
              data, dtype=dtype.type, shape=data.shape)))
      how_many_converted += 1
    else:
      output_node.CopyFrom(input_node)
    output_graph_def.node.extend([output_node])
  print("Converted %d variables to const ops." % how_many_converted)
  return output_graph_def
def test_remove_redundant_quantization(self):
  a_constant_name = "a_constant"
  a_constant_min_name = "a_constant_min"
  a_constant_max_name = "a_constant_max"
  a_dequantize_name = "a_dequantize"
  a_quantize_name = "a_quantize"
  b_constant_name = "b_constant"
  b_constant_min_name = "b_constant_min"
  b_constant_max_name = "b_constant_max"
  b_dequantize_name = "b_dequantize"
  b_quantize_name = "b_quantize"
  mat_mul_name = "mat_mul"
  graph_def = tf.GraphDef()
  a_constant = quantize_graph.create_constant_node(a_constant_name,
                                                   value=(0,),
                                                   dtype=tf.quint8,
                                                   shape=[])
  graph_def.node.extend([a_constant])
  a_constant_min = quantize_graph.create_constant_node(a_constant_min_name,
                                                       value=2,
                                                       dtype=tf.float32,
                                                       shape=[])
  graph_def.node.extend([a_constant_min])
  a_constant_max = quantize_graph.create_constant_node(a_constant_max_name,
                                                       value=2,
                                                       dtype=tf.float32,
                                                       shape=[])
  graph_def.node.extend([a_constant_max])
  a_dequantize_node = quantize_graph.create_node(
      "Dequantize", a_dequantize_name,
      [a_constant_name, a_constant_min_name, a_constant_max_name])
  quantize_graph.set_attr_dtype(a_dequantize_node, "T", tf.uint8)
  graph_def.node.extend([a_dequantize_node])
  a_quantize_node = quantize_graph.create_node(
      "QuantizeV2", a_quantize_name,
      [a_dequantize_name, a_dequantize_name + ":1", a_dequantize_name + ":2"])
  quantize_graph.set_attr_dtype(a_quantize_node, "T", tf.uint8)
  graph_def.node.extend([a_quantize_node])
  b_constant = quantize_graph.create_constant_node(b_constant_name,
                                                   value=(0,),
                                                   dtype=tf.quint8,
                                                   shape=[])
  graph_def.node.extend([b_constant])
  b_constant_min = quantize_graph.create_constant_node(b_constant_min_name,
                                                       value=3,
                                                       dtype=tf.float32,
                                                       shape=[])
  graph_def.node.extend([b_constant_min])
  b_constant_max = quantize_graph.create_constant_node(b_constant_max_name,
                                                       value=3,
                                                       dtype=tf.float32,
                                                       shape=[])
  graph_def.node.extend([b_constant_max])
  b_dequantize_node = quantize_graph.create_node(
      "Dequantize", b_dequantize_name,
      [b_constant_name, b_constant_min_name, b_constant_max_name])
  quantize_graph.set_attr_dtype(b_dequantize_node, "T", tf.uint8)
  graph_def.node.extend([b_dequantize_node])
  b_quantize_node = quantize_graph.create_node(
      "QuantizeV2", b_quantize_name,
      [b_dequantize_name, b_dequantize_name + ":1", b_dequantize_name + ":2"])
  quantize_graph.set_attr_dtype(b_quantize_node, "T", tf.uint8)
  graph_def.node.extend([b_quantize_node])
  mat_mul_node = quantize_graph.create_node(
      "QuantizedMatMul", mat_mul_name,
      [a_quantize_name, b_quantize_name, a_quantize_name + ":1",
       a_quantize_name + ":2", b_quantize_name + ":1",
       b_quantize_name + ":2"])
  quantize_graph.set_attr_dtype(mat_mul_node, "T1", tf.uint8)
  quantize_graph.set_attr_dtype(mat_mul_node, "T2", tf.int32)
  graph_def.node.extend([mat_mul_node])
  expected_output = tf.GraphDef()
  a_constant = quantize_graph.create_constant_node(a_constant_name,
                                                   value=(0,),
                                                   dtype=tf.quint8,
                                                   shape=[])
  expected_output.node.extend([a_constant])
  a_constant_min = quantize_graph.create_constant_node(a_constant_min_name,
                                                       value=2,
                                                       dtype=tf.float32,
                                                       shape=[])
  expected_output.node.extend([a_constant_min])
  a_constant_max = quantize_graph.create_constant_node(a_constant_max_name,
                                                       value=2,
                                                       dtype=tf.float32,
                                                       shape=[])
  expected_output.node.extend([a_constant_max])
  b_constant = quantize_graph.create_constant_node(b_constant_name,
                                                   value=(0,),
                                                   dtype=tf.quint8,
                                                   shape=[])
  expected_output.node.extend([b_constant])
  b_constant_min = quantize_graph.create_constant_node(b_constant_min_name,
                                                       value=3,
                                                       dtype=tf.float32,
                                                       shape=[])
  expected_output.node.extend([b_constant_min])
  b_constant_max = quantize_graph.create_constant_node(b_constant_max_name,
                                                       value=3,
                                                       dtype=tf.float32,
                                                       shape=[])
  expected_output.node.extend([b_constant_max])
  mat_mul_node = quantize_graph.create_node(
      "QuantizedMatMul", mat_mul_name,
      [a_constant_name, b_constant_name, a_constant_min_name,
       a_constant_max_name, b_constant_min_name, b_constant_max_name])
  quantize_graph.set_attr_dtype(mat_mul_node, "T1", tf.uint8)
  quantize_graph.set_attr_dtype(mat_mul_node, "T2", tf.int32)
  expected_output.node.extend([mat_mul_node])
  rewriter = quantize_graph.GraphRewriter(graph_def, [mat_mul_name])
  output = rewriter.remove_redundant_quantization(graph_def)
  stripped_output = graph_util.extract_sub_graph(output, [mat_mul_name])
  self.assertProtoEquals(expected_output, stripped_output)