def variable_device(device, name):
    """Resolve a variable's device string so its ops can be colocated.

    If *device* is a device function, it is invoked with a synthetic
    ``Variable`` NodeDef built from the current variable scope and *name*;
    a ``None`` result is normalized to the empty string.

    Args:
      device: A device string, a device function, or None.
      name: Unscoped name of the variable.

    Returns:
      The resolved device string (possibly empty, never None).
    """
    if callable(device):
        scoped_name = '/'.join([tf.get_variable_scope().name, name])
        fake_var_def = graph_pb2.NodeDef(name=scoped_name, op='Variable')
        device = device(fake_var_def)
    return device if device is not None else ''
def convert_variables_to_constants(sess, input_graph_def, output_node_names):
    """Replaces all the variables in a graph with constants of the same values.

    If you have a trained graph containing Variable ops, it can be convenient
    to convert them all to Const ops holding the same values. This makes it
    possible to describe the network fully with a single GraphDef file, and
    allows the removal of a lot of ops related to loading and saving the
    variables.

    Args:
      sess: Active TensorFlow session containing the variables.
      input_graph_def: GraphDef object holding the network.
      output_node_names: List of name strings for the result nodes of the
        graph.

    Returns:
      GraphDef containing a simplified version of the original.
    """
    # Variables are located through their Assign ops: the first input of an
    # Assign is the variable node itself.
    variable_names = []
    variable_dict_names = []
    for node in input_graph_def.node:
        if node.op == "Assign":
            variable_name = node.input[0]
            variable_dict_names.append(variable_name)
            # ":0" selects the variable's first output tensor for sess.run.
            variable_names.append(variable_name + ":0")
    # Guard the empty case: sess.run raises on an empty fetch list.
    if variable_names:
        returned_variables = sess.run(variable_names)
    else:
        returned_variables = []
    found_variables = dict(zip(variable_dict_names, returned_variables))
    # Lazy %-args so the message is only formatted if the record is emitted.
    logging.info("Frozen %d variables.", len(returned_variables))

    # This graph only includes the nodes needed to evaluate the output nodes,
    # and removes unneeded nodes like those involved in saving and assignment.
    inference_graph = extract_sub_graph(input_graph_def, output_node_names)

    output_graph_def = graph_pb2.GraphDef()
    how_many_converted = 0
    for input_node in inference_graph.node:
        output_node = graph_pb2.NodeDef()
        if input_node.name in found_variables:
            # Replace the Variable op with a Const holding its frozen value.
            output_node.op = "Const"
            output_node.name = input_node.name
            dtype = input_node.attr["dtype"]
            data = found_variables[input_node.name]
            output_node.attr["dtype"].CopyFrom(dtype)
            output_node.attr["value"].CopyFrom(
                attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
                    data, dtype=dtype.type, shape=data.shape)))
            how_many_converted += 1
        else:
            output_node.CopyFrom(input_node)
        output_graph_def.node.extend([output_node])
    print("Converted %d variables to const ops." % how_many_converted)
    return output_graph_def
def _operator_to_node(shapes, op):
    """Translate a Caffe2 operator into a TensorFlow NodeDef.

    Args:
      shapes: Mapping from blob name to shape, or a falsy value to skip
        shape annotation.
      op: Caffe2 OperatorDef to convert.

    Returns:
      A NodeDef mirroring the operator's name, inputs, type, device and args.
    """
    assert op.name, op
    # Check for existence of __version__ for backwards compatibility.
    node = tf.NodeDef() if hasattr(tf, '__version__') else graph_pb2.NodeDef()
    node.name = op.name
    node.input.extend(op.input)
    node.op = op.type
    node.device = _tf_device(op.device_option)
    if shapes:
        # Add shapes in output order, stopping at the first unknown one.
        for output_name in op.output:
            if output_name not in shapes:
                break
            _add_tf_shape(node.attr, shapes[output_name])
    for argument in op.arg:
        _set_tf_attr(node.attr, argument)
    return node
def _blob_to_node(producing_ops, shapes, name):
    """Translate a Caffe2 blob into a TensorFlow NodeDef.

    Args:
      producing_ops: Mapping from blob name to a list of (op, output_index)
        pairs for the ops that produce it.
      shapes: Mapping from blob name to shape, or a falsy value to skip
        shape annotation.
      name: Name of the blob.

    Returns:
      A NodeDef: a 'Blob' node wired to its producers, or a 'Placeholder'
      when the blob has no producer.
    """
    assert name
    # Check for existence of __version__ for backwards compatibility.
    n = tf.NodeDef() if hasattr(tf, '__version__') else graph_pb2.NodeDef()
    n.name = name
    # Renamed from `inputs`/`input` to avoid shadowing the `input` builtin.
    producers = producing_ops.get(name, [])
    n.op = 'Blob' if producers else 'Placeholder'
    n.input.extend('%s:%d' % (op.name, i) for op, i in producers)
    if producers:
        # Only assign a device when every producer agrees on it.
        device = producers[0][0].device_option
        if all(producer[0].device_option == device for producer in producers):
            n.device = _tf_device(device)
    if shapes and name in shapes:
        _add_tf_shape(n.attr, shapes[name])
    return n
def _StripNode(self, nd):
    """Return a copy of *nd* keeping only name, op, inputs and device."""
    stripped = graph_pb2.NodeDef(name=nd.name, op=nd.op, input=nd.input)
    if nd.device:
        stripped.device = nd.device
    return stripped
def convert_variables_to_constants(sess, input_graph_def, output_node_names):
    """Replaces all the variables in a graph with constants of the same values.

    If you have a trained graph containing Variable ops, it can be convenient
    to convert them all to Const ops holding the same values. This makes it
    possible to describe the network fully with a single GraphDef file, and
    allows the removal of a lot of ops related to loading and saving the
    variables.

    Args:
      sess: Active TensorFlow session containing the variables.
      input_graph_def: GraphDef object holding the network.
      output_node_names: List of name strings for the result nodes of the
        graph.

    Returns:
      GraphDef containing a simplified version of the original.
    """
    print('call convert_variables')
    found_variables = {}
    variable_name_list = []
    print('search nodes...')
    total_nodes = len(input_graph_def.node)
    for i, node in enumerate(input_graph_def.node):
        # The first input of an Assign op is the variable node itself.
        if node.op == "Assign":
            variable_name_list.append(node.input[0])
        sys.stdout.write(
            "\r%s" % "node: {0}/{1}".format(i + 1, total_nodes))
        sys.stdout.flush()
    print('')
    print('{0} nodes found'.format(len(variable_name_list)))
    print('evaluate nodes..')
    # Guard the empty case: sess.run raises on an empty fetch list.
    if variable_name_list:
        found_variables_list = sess.run(
            [v + ":0" for v in variable_name_list])
    else:
        found_variables_list = []
    print('insert values..')
    for i, v in enumerate(variable_name_list):
        found_variables[v] = found_variables_list[i]
        sys.stdout.write(
            "\r%s" % "node: {0}/{1}".format(i + 1, len(variable_name_list)))
        sys.stdout.flush()
    print('')

    # This graph only includes the nodes needed to evaluate the output nodes,
    # and removes unneeded nodes like those involved in saving and assignment.
    inference_graph = graph_util.extract_sub_graph(input_graph_def,
                                                   output_node_names)

    output_graph_def = graph_pb2.GraphDef()
    how_many_converted = 0
    for input_node in inference_graph.node:
        output_node = graph_pb2.NodeDef()
        if input_node.name in found_variables:
            # Replace the Variable op with a Const holding its frozen value.
            output_node.op = "Const"
            output_node.name = input_node.name
            dtype = input_node.attr["dtype"]
            data = found_variables[input_node.name]
            output_node.attr["dtype"].CopyFrom(dtype)
            output_node.attr["value"].CopyFrom(
                attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
                    data, dtype=dtype.type, shape=data.shape)))
            how_many_converted += 1
        else:
            output_node.CopyFrom(input_node)
        output_graph_def.node.extend([output_node])
    print("Converted %d variables to const ops." % how_many_converted)
    return output_graph_def
def remove_training_nodes(input_graph):
    """Prunes out nodes that aren't needed for inference.

    There are nodes like Identity and CheckNumerics that are only useful
    during training, and can be removed in graphs that will be used for
    nothing but inference. Here we identify and remove them, returning an
    equivalent graph. To be specific, CheckNumerics nodes are always removed,
    and Identity nodes that aren't involved in control edges are spliced out
    so that their input and outputs are directly connected.

    Args:
      input_graph: Model to analyze and prune.

    Returns:
      A GraphDef with the unnecessary nodes removed.
    """
    types_to_remove = {"CheckNumerics": True}

    input_nodes = input_graph.node
    names_to_remove = {}
    for node in input_nodes:
        if node.op in types_to_remove:
            names_to_remove[node.name] = True

    nodes_after_removal = []
    for node in input_nodes:
        if node.name in names_to_remove:
            continue
        new_node = graph_pb2.NodeDef()
        new_node.CopyFrom(node)
        input_before_removal = node.input
        del new_node.input[:]
        for full_input_name in input_before_removal:
            input_name = re.sub(r"^\^", "", full_input_name)
            if input_name in names_to_remove:
                continue
            new_node.input.append(full_input_name)
        nodes_after_removal.append(new_node)

    types_to_splice = {"Identity": True}
    # We don't want to splice nodes involved in control edges, because they
    # might be tied up in subtle dependency issues that removing them would
    # jeopardize. Track both nodes that HAVE control-edge inputs and names
    # that are USED as control inputs ("^name") somewhere else: a "^name"
    # input cannot be redirected to the Identity's data input without
    # changing semantics.
    control_input_names = set()
    node_names_with_control_input = set()
    for node in nodes_after_removal:
        for node_input in node.input:
            if "^" in node_input:
                control_input_names.add(re.sub(r"^\^", "", node_input))
                node_names_with_control_input.add(node.name)

    names_to_splice = {}
    for node in nodes_after_removal:
        if (node.op in types_to_splice and
                node.name not in node_names_with_control_input):
            names_to_splice[node.name] = node.input[0]
    # Exclude Identity nodes referenced as control inputs elsewhere; splicing
    # them would break those control edges.
    names_to_splice = {name: value
                       for name, value in names_to_splice.items()
                       if name not in control_input_names}

    nodes_after_splicing = []
    for node in nodes_after_removal:
        if node.name in names_to_splice:
            continue
        new_node = graph_pb2.NodeDef()
        new_node.CopyFrom(node)
        input_before_removal = node.input
        del new_node.input[:]
        for full_input_name in input_before_removal:
            input_name = re.sub(r"^\^", "", full_input_name)
            # Follow splice chains to the end so rewritten inputs never
            # reference a removed Identity node (resolving only one level
            # leaves dangling references for chained Identity nodes).
            while input_name in names_to_splice:
                full_input_name = names_to_splice[input_name]
                input_name = re.sub(r"^\^", "", full_input_name)
            new_node.input.append(full_input_name)
        nodes_after_splicing.append(new_node)

    output_graph = graph_pb2.GraphDef()
    output_graph.node.extend(nodes_after_splicing)
    return output_graph