def convert_variables_to_constants(sess, input_graph_def, output_node_names, variable_names_whitelist=None, variable_names_blacklist=None): """Replaces all the variables in a graph with constants of the same values. If you have a trained graph containing Variable ops, it can be convenient to convert them all to Const ops holding the same values. This makes it possible to describe the network fully with a single GraphDef file, and allows the removal of a lot of ops related to loading and saving the variables. Args: sess: Active TensorFlow session containing the variables. input_graph_def: GraphDef object holding the network. output_node_names: List of name strings for the result nodes of the graph. variable_names_whitelist: The set of variable names to convert (by default, all variables are converted). variable_names_blacklist: The set of variable names to omit converting to constants. Returns: GraphDef containing a simplified version of the original. """ def has_variable_as_input(node): """Checks if the input node has a variable in `variables_data_map`.""" for name in node.input: if name in variables_data_map or\ (name in identity_ops_input_map and identity_ops_input_map[name] in variables_data_map): return True return False def dfs_find_variable(origin_name, name_to_nodes): if origin_name in variables_data_map: return origin_name, set() nodes_in_path = set() found_variables = set() def dfs(name): node = name_to_nodes[name] if node.op == "Switch": inputs = [node.input[0]] else: inputs = node.input for name in inputs: name = _node_name(name) if name in nodes_in_path: continue elif name in variables_data_map: found_variables.add(name) continue else: nodes_in_path.add(name) dfs(name) nodes_in_path.add(origin_name) dfs(origin_name) if len(found_variables) > 1: raise ValueError("found variables %s" % found_variables) variable = None for v in found_variables: variable = v return variable, nodes_in_path def create_const_op(node_name, dtype, data, data_shape=None): """Creates a Const op.""" output_node = node_def_pb2.NodeDef() output_node.op = "Const" output_node.name = node_name output_node.attr["dtype"].CopyFrom(dtype) output_node.attr["value"].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( data, dtype=dtype.type, shape=data_shape))) return output_node # This graph only includes the nodes needed to evaluate the output nodes, and # removes unneeded nodes like those involved in saving and assignment. inference_graph = extract_sub_graph(input_graph_def, output_node_names) # Get list of variables. variable_names = [] variable_dict_names = [] identity_ops_input_map = {} name_to_node = {} for node in inference_graph.node: name_to_node[node.name] = node if node.op in ["Variable", "VariableV2", "VarHandleOp"]: variable_name = node.name if ((variable_names_whitelist is not None and variable_name not in variable_names_whitelist) or (variable_names_blacklist is not None and variable_name in variable_names_blacklist)): continue variable_dict_names.append(variable_name) if node.op == "VarHandleOp": variable_names.append(variable_name + "/Read/ReadVariableOp:0") else: variable_names.append(variable_name + ":0") elif node.op == "Identity": # TODO(nupurgarg): Move and reuse get_name from lite/convert.py. # Creates a map of Identity node names to the input names. input_info = node.input[0].split(":") if (len(input_info) == 1 or (len(input_info) == 2 and int(input_info[1]) == 0)): identity_ops_input_map[node.name] = input_info[0] # Gets map of variables and the associated data. 
if variable_names: returned_variables = sess.run(variable_names) else: returned_variables = [] variables_data_map = dict(zip(variable_dict_names, returned_variables)) logging.info("Froze %d variables.", len(returned_variables)) # Reconstruct the graph with constants in place of variables. path_node_to_variables = {} output_graph_def = graph_pb2.GraphDef() how_many_converted = 0 for input_node in inference_graph.node: output_node = node_def_pb2.NodeDef() if input_node.name in variables_data_map: data = variables_data_map[input_node.name] output_node = create_const_op(input_node.name, input_node.attr["dtype"], data, data.shape) how_many_converted += 1 elif input_node.op == "ReadVariableOp": variable, nodes_in_path = dfs_find_variable( input_node.input[0], name_to_node) if variable is not None: # The first branch converts all VarHandleOps of ResourceVariables to # constants, so we need to convert the associated ReadVariableOps to # Identity ops. # # Handles the following cases: # Variable --> ReadVariableOp # Variable --> Identity --> ReadVariableOp output_node.op = "Identity" output_node.name = input_node.name output_node.input.extend([input_node.input[0]]) output_node.attr["T"].CopyFrom(input_node.attr["dtype"]) if "_class" in input_node.attr: output_node.attr["_class"].CopyFrom( input_node.attr["_class"]) for name in nodes_in_path: path_node_to_variables[name] = variable else: raise ValueError("Cannot find variable for %s" % input_node.name) elif input_node.op == "ResourceGather": variable, nodes_in_path = dfs_find_variable( input_node.input[0], name_to_node) if variable is not None: # The first branch converts all VarHandleOps of ResourceGather to # constants, so we need to convert the associated ResourceGather to Gather # ops with a Const axis feeding into it. 
if input_node.attr["batch_dims"].i != 0: raise ValueError( "batch_dims != 0 is not supported by freeze_graph.") axis_data = input_node.attr["batch_dims"].i axis_node_name = input_node.name + "/axis" axis_dtype = input_node.attr["Tindices"] output_axis_node = create_const_op(axis_node_name, axis_dtype, axis_data) output_graph_def.node.extend([output_axis_node]) output_node.op = "GatherV2" output_node.name = input_node.name output_node.input.extend( [input_node.input[0], input_node.input[1], axis_node_name]) output_node.attr["Tparams"].CopyFrom(input_node.attr["dtype"]) output_node.attr["Tindices"].CopyFrom( input_node.attr["Tindices"]) output_node.attr["Taxis"].CopyFrom(axis_dtype) if "_class" in input_node.attr: output_node.attr["_class"].CopyFrom( input_node.attr["_class"]) for name in nodes_in_path: path_node_to_variables[name] = variable else: raise ValueError("Cannot find variable for %s" % input_node.name) elif input_node.op == "VariableShape": variable, nodes_in_path = dfs_find_variable( input_node.input[0], name_to_node) if variable is not None: input_variable = name_to_node[variable] output_node.op = "Shape" output_node.name = input_node.name output_node.input.extend([input_node.input[0]]) output_node.attr["T"].CopyFrom(input_variable.attr["dtype"]) output_node.attr["out_type"].CopyFrom( input_node.attr["out_type"]) for name in nodes_in_path: path_node_to_variables[name] = variable else: raise ValueError("Cannot find variable for %s" % input_node.name) else: output_node.CopyFrom(input_node) output_graph_def.node.extend([output_node]) output_graph_def.library.CopyFrom(inference_graph.library) inference_graph = output_graph_def output_graph_def = graph_pb2.GraphDef() for input_node in inference_graph.node: output_node = node_def_pb2.NodeDef() if input_node.name in path_node_to_variables: input_variable = path_node_to_variables[input_node.name] input_variable = name_to_node[input_variable] output_node.op = input_node.op output_node.name = input_node.name if input_node.op == "Enter": output_node.input.extend([input_node.input[0]]) output_node.attr["T"].CopyFrom(input_variable.attr["dtype"]) output_node.attr["frame_name"].CopyFrom( input_node.attr["frame_name"]) output_node.attr["is_constant"].CopyFrom( input_node.attr["is_constant"]) output_node.attr["parallel_iterations"]\ .CopyFrom(input_node.attr["parallel_iterations"]) elif input_node.op == "Switch": output_node.input.extend(input_node.input) output_node.attr["T"].CopyFrom(input_variable.attr["dtype"]) else: raise ValueError("cannot do type: %s" % input_node.op) else: output_node.CopyFrom(input_node) output_graph_def.node.extend([output_node]) output_graph_def.library.CopyFrom(inference_graph.library) logging.info("Converted %d variables to const ops.", how_many_converted) return output_graph_def
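# Illustrative usage of convert_variables_to_constants above. This is a hedged
# sketch added for clarity, not part of the original module: it assumes the
# caller supplies a live session whose variables are already initialized and
# the names of the graph's result ops.
def _example_freeze_graph(sess, output_node_names):
    """Returns a GraphDef for `sess.graph` with its variables replaced by constants."""
    return convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), output_node_names)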
def fuse_resize_and_conv(input_graph_def, output_node_names): """Merges preceding resize and mirror pad ops into a specialized convolution. There's a common pattern of enlarging the input to a convolution using a resize operation, and also using MirrorPad to extend the boundaries so that zero edge pixels don't bleed inwards when convolving. This routine looks for that pattern of operations, and fuses them together into a single FusedResizeAndPadConv2D (or FusedPadConv2D) op. Args: input_graph_def: A GraphDef containing a model. output_node_names: A list of names of the nodes that produce the final results. Returns: Modified graph with resize and pad ops merged. Raises: ValueError: If the graph is badly formed with duplicate node names. """ input_node_map = {} for node in input_graph_def.node: if node.name not in input_node_map: input_node_map[node.name] = node else: raise ValueError("Duplicate node names detected for %s" % node.name) node_reference_count = collections.defaultdict(int) for node in input_graph_def.node: for input_name in node.input: stripped_name = node_name_from_input(input_name) node_reference_count[stripped_name] += 1 for output_name in output_node_names: node_reference_count[output_name] += 1 new_ops = [] for node in input_graph_def.node: if node.op != "Conv2D": continue conv_op = node input_op = node_from_map(input_node_map, conv_op.input[0]) if input_op.op == "MirrorPad": mirror_pad_op = input_op resize_op = node_from_map(input_node_map, mirror_pad_op.input[0]) if resize_op.op != "ResizeBilinear": resize_op = None else: mirror_pad_op = None if input_op.op == "ResizeBilinear": resize_op = input_op else: resize_op = None # There are no ops to be fused into the conv, so skip replacing this one. if not mirror_pad_op and not resize_op: continue # We're replacing this node, so make sure the old one is removed. node_reference_count[conv_op.name] = 0 if mirror_pad_op: node_reference_count[mirror_pad_op.name] -= 1 if resize_op: node_reference_count[resize_op.name] -= 1 fused_conv_op = node_def_pb2.NodeDef() if resize_op: fused_conv_op.op = "FusedResizeAndPadConv2D" else: fused_conv_op.op = "FusedPadConv2D" fused_conv_op.name = conv_op.name if mirror_pad_op: mirror_paddings_name = mirror_pad_op.input[1] mirror_paddings_mode = mirror_pad_op.attr["mode"] else: # If there was no MirrorPad op, then create settings that make the padding # stage of the fused operation a no-op.
paddings_op = node_def_pb2.NodeDef() paddings_op.op = "Const" paddings_op.name = conv_op.name + "_dummy_paddings" paddings_op.attr["dtype"].CopyFrom( attr_value_pb2.AttrValue(type=dtypes.int32.as_datatype_enum)) paddings_op.attr["value"].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( [0, 0, 0, 0, 0, 0, 0, 0], dtypes.int32, [4, 2]))) new_ops.extend([paddings_op]) mirror_paddings_name = paddings_op.name mirror_paddings_mode = attr_value_pb2.AttrValue(s=b"REFLECT") if resize_op: fused_conv_op.input.extend([ resize_op.input[0], resize_op.input[1], mirror_paddings_name, conv_op.input[1] ]) fused_conv_op.attr["resize_align_corners"].CopyFrom( resize_op.attr["align_corners"]) else: fused_conv_op.input.extend( [mirror_pad_op.input[0], mirror_paddings_name, conv_op.input[1]]) fused_conv_op.attr["T"].CopyFrom(conv_op.attr["T"]) fused_conv_op.attr["mode"].CopyFrom(mirror_paddings_mode) fused_conv_op.attr["strides"].CopyFrom(conv_op.attr["strides"]) fused_conv_op.attr["padding"].CopyFrom(conv_op.attr["padding"]) new_ops.extend([fused_conv_op]) result_graph_def = graph_pb2.GraphDef() for node in input_graph_def.node: if node_reference_count[node.name] < 1: continue new_node = node_def_pb2.NodeDef() new_node.CopyFrom(node) result_graph_def.node.extend([new_node]) result_graph_def.node.extend(new_ops) return result_graph_def
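# Hedged sketch of where the fusion above typically runs: after freezing, as one
# pass over the inference GraphDef. The resulting graph contains
# FusedResizeAndPadConv2D / FusedPadConv2D nodes in place of the matched
# ResizeBilinear / MirrorPad / Conv2D chains. `frozen_graph_def` and the output
# names are assumed to be supplied by the caller.
def _example_run_resize_conv_fusion(frozen_graph_def, output_node_names):
    """Applies the resize/pad/conv fusion to an already-frozen GraphDef."""
    return fuse_resize_and_conv(frozen_graph_def, output_node_names)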
def create_subgraph(tf_graph, node_list, sess, dst_scope=None): """ Create a tf subgraph from the node list. :param tf_graph: the source tf.Graph to extract the subgraph from. :param node_list: list of tf.Operation objects to keep in the subgraph. :param sess: active session used to read the values of any variables. :param dst_scope: optional name scope under which the subgraph is imported. :return: tuple (sub_graph, replacement), where sub_graph is the new tf.Graph and replacement maps the names of missing input ops to the placeholder names created for them. """ variable_dict_names = [] variable_names = [] tensor_op_names = [] for n_ in node_list: # type: tf.Operation tensor_op_names.extend([ts_.op.name for ts_ in n_.inputs]) if n_.type in ["Variable", "VariableV2", "VarHandleOp"]: variable_name = n_.name variable_dict_names.append(variable_name) if n_.type == "VarHandleOp": variable_names.append(variable_name + "/Read/ReadVariableOp:0") else: variable_names.append(variable_name + ":0") if variable_names: returned_variables = sess.run(variable_names) else: returned_variables = [] found_variables = dict(zip(variable_dict_names, returned_variables)) all_op_names = set([n_.name for n_ in node_list]) missing_ops = set(tensor_op_names) - all_op_names replacement = {} tf_graph_def = tf_graph.as_graph_def() subgraph_def = _extract_sub_graph(tf_graph_def, [n_.name for n_ in node_list], missing_ops) output_graph_def = graph_pb2.GraphDef() how_many_converted = 0 for input_node in subgraph_def.node: output_node = node_def_pb2.NodeDef() if input_node.name in found_variables: output_node.op = "Const" output_node.name = input_node.name dtype = input_node.attr["dtype"] data = found_variables[input_node.name] output_node.attr["dtype"].CopyFrom(dtype) output_node.attr["value"].CopyFrom( attr_value_pb2.AttrValue( tensor=tensor_util.make_tensor_proto( data, dtype=dtype.type, shape=data.shape))) how_many_converted += 1 elif input_node.op == "ReadVariableOp" and ( input_node.input[0] in found_variables): # The preceding branch converts all VarHandleOps of ResourceVariables to # constants, so we need to convert the associated ReadVariableOps to # Identity ops. output_node.op = "Identity" output_node.name = input_node.name output_node.input.extend([input_node.input[0]]) output_node.attr["T"].CopyFrom(input_node.attr["dtype"]) if "_class" in input_node.attr: output_node.attr["_class"].CopyFrom(input_node.attr["_class"]) elif input_node.name not in missing_ops: output_node.CopyFrom(input_node) else: output_node = None if output_node is not None: output_graph_def.node.extend([output_node]) for input_node in tf_graph_def.node: if input_node.name in missing_ops: output_node = node_def_pb2.NodeDef() output_node.op = "Placeholder" output_node.name = input_node.name replacement[input_node.name] = input_node.name if str(input_node.attr["dtype"]): output_node.attr["dtype"].CopyFrom(input_node.attr["dtype"]) elif str(input_node.attr["T"]): output_node.attr["dtype"].CopyFrom(input_node.attr["T"]) else: if input_node.op == 'All': output_node.attr["dtype"].CopyFrom(attr_value_pb2.AttrValue(type="DT_BOOL")) elif input_node.op == 'Cast': output_node.attr["dtype"].CopyFrom(input_node.attr["DstT"]) else: raise RuntimeError("Can't get the node data type for %s" % input_node.name) ts_shape = tf.graph_util.tensor_shape_from_node_def_name(tf_graph, input_node.name) output_node.attr["shape"].CopyFrom( attr_value_pb2.AttrValue(shape=ts_shape.as_proto())) output_graph_def.node.extend([output_node]) output_graph_def.library.CopyFrom(subgraph_def.library) with tf.Graph().as_default() as sub_graph: im_scope = "" if dst_scope is None else dst_scope tf.import_graph_def(output_graph_def, name=im_scope) if im_scope: replacement = {k_: im_scope + '/' + k_ for k_ in replacement} return sub_graph, replacement
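# Hedged usage sketch for create_subgraph above: extract the ops feeding one
# tensor into a standalone graph. The tensor name "dense/BiasAdd:0" is purely
# hypothetical; pass whatever endpoint the caller actually needs.
def _example_extract_subgraph(sess, tensor_name="dense/BiasAdd:0"):
    """Returns (sub_graph, replacement) for the op producing `tensor_name`."""
    end_op = sess.graph.get_tensor_by_name(tensor_name).op
    return create_subgraph(sess.graph, [end_op], sess, dst_scope=None)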
def create_node_def(self, op, name, inputs): new_node = node_def_pb2.NodeDef() new_node.op = op new_node.name = name new_node.input.extend(inputs) return new_node
def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes, output_quantized, op_name, op_type): """Fuse subgraph between input_nodes and output_nodes into a single custom op. Args: graph_def: A graph_pb2.GraphDef proto. input_nodes: input nodes to the subgraph to be fused. output_nodes: output nodes to the subgraph to be fused. output_dtypes: A list of output datatypes for the custom op output_quantized: A boolean flag that indicates if output is quantized op_name: fused op name. op_type: fused op type. Returns: The GraphDef of the new graph. Raises: TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto. """ if not isinstance(graph_def, graph_pb2.GraphDef): raise TypeError("graph_def must be a graph_pb2.GraphDef proto.") if isinstance(input_nodes, six.string_types): raise TypeError("input_nodes must be a list.") if isinstance(output_nodes, six.string_types): raise TypeError("output_nodes must be a list.") name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary( graph_def) _assert_nodes_are_present(name_to_node, input_nodes + output_nodes) # Nodes upto and including input_nodes reachable_by_input = _bfs_for_reachable_nodes(input_nodes, name_to_input_name) # Nodes upto and including output_nodes reachable_by_output = _bfs_for_reachable_nodes(output_nodes, name_to_input_name) # Set of nodes in the list input_nodes input_nodes_set = set(input_nodes) # Set of nodes in the list output_nodes output_nodes_set = set(output_nodes) nodes_post_output = [] for node in graph_def.node: n = _node_name(node.name) if n in reachable_by_output: if n not in reachable_by_input and n not in output_nodes_set: # n is between input and output, i.e., part of the fused op next_to_visit = [n] while next_to_visit: cur_node = next_to_visit[0] del next_to_visit[0] if cur_node in reachable_by_input and cur_node not in input_nodes_set: raise TypeError( "Node %s uses input %s not in input_nodes." % (n, cur_node)) if cur_node not in input_nodes_set: next_to_visit += name_to_input_name[cur_node] elif n not in reachable_by_input: nodes_post_output.append(n) # Add all nodes upto the input nodes out = graph_pb2.GraphDef() reachable_by_input_sorted = sorted(list(reachable_by_input), key=lambda n: name_to_seq_num[n]) for node in reachable_by_input_sorted: out.node.extend([copy.deepcopy(name_to_node[node])]) # Add the custom op new_node = node_def_pb2.NodeDef() for node in input_nodes: new_node.input.append(node) new_node.attr["_output_types"].list.type[:] = output_dtypes new_node.attr["_output_quantized"].b = output_quantized new_node.op = op_type new_node.name = op_name out.node.extend([new_node]) # Add the nodes in the output of the custom op for index, n in enumerate(output_nodes): assert len(name_to_node[n].input) == 1 new_node = copy.deepcopy(name_to_node[n]) del new_node.input[:] new_node.input.append(op_name + (":" + str(index) if index != 0 else "")) out.node.extend([new_node]) # Add the nodes post output_nodes for n in nodes_post_output: out.node.extend([copy.deepcopy(name_to_node[n])]) out.library.CopyFrom(graph_def.library) out.versions.CopyFrom(graph_def.versions) return out
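# Hedged example for fuse_op above: collapse everything between two named nodes
# into one custom op. Every name here ("embedding_lookup", "logits",
# "MyCustomOp", "fused_block") is hypothetical, and the output dtype is passed
# as the DataType enum value the `_output_types` attr expects.
def _example_fuse_custom_op(graph_def):
    """Fuses the subgraph between "embedding_lookup" and "logits" into a stub op."""
    return fuse_op(
        graph_def,
        input_nodes=["embedding_lookup"],
        output_nodes=["logits"],
        output_dtypes=[dtypes.float32.as_datatype_enum],
        output_quantized=False,
        op_name="fused_block",
        op_type="MyCustomOp")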
def convert_variables_to_constants(sess, input_graph_def, output_node_names, variable_names_whitelist=None, variable_names_blacklist=None): """Replaces all the variables in a graph with constants of the same values. If you have a trained graph containing Variable ops, it can be convenient to convert them all to Const ops holding the same values. This makes it possible to describe the network fully with a single GraphDef file, and allows the removal of a lot of ops related to loading and saving the variables. Args: sess: Active TensorFlow session containing the variables. input_graph_def: GraphDef object holding the network. output_node_names: List of name strings for the result nodes of the graph. variable_names_whitelist: The set of variable names to convert (by default, all variables are converted). variable_names_blacklist: The set of variable names to omit converting to constants. Returns: GraphDef containing a simplified version of the original. """ # This graph only includes the nodes needed to evaluate the output nodes, and # removes unneeded nodes like those involved in saving and assignment. inference_graph = extract_sub_graph(input_graph_def, output_node_names) found_variables = {} variable_names = [] variable_dict_names = [] for node in inference_graph.node: if node.op in ["Variable", "VariableV2", "VarHandleOp"]: variable_name = node.name if ((variable_names_whitelist is not None and variable_name not in variable_names_whitelist) or (variable_names_blacklist is not None and variable_name in variable_names_blacklist)): continue variable_dict_names.append(variable_name) if node.op == "VarHandleOp": variable_names.append(variable_name + "/Read/ReadVariableOp:0") else: variable_names.append(variable_name + ":0") if variable_names: returned_variables = sess.run(variable_names) else: returned_variables = [] found_variables = dict(zip(variable_dict_names, returned_variables)) logging.info("Froze %d variables.", len(returned_variables)) output_graph_def = graph_pb2.GraphDef() how_many_converted = 0 for input_node in inference_graph.node: output_node = node_def_pb2.NodeDef() if input_node.name in found_variables: output_node.op = "Const" output_node.name = input_node.name dtype = input_node.attr["dtype"] data = found_variables[input_node.name] output_node.attr["dtype"].CopyFrom(dtype) output_node.attr["value"].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( data, dtype=dtype.type, shape=data.shape))) how_many_converted += 1 elif input_node.op == "ReadVariableOp" and (input_node.input[0] in found_variables): # The preceding branch converts all VarHandleOps of ResourceVariables to # constants, so we need to convert the associated ReadVariableOps to # Identity ops. output_node.op = "Identity" output_node.name = input_node.name output_node.input.extend([input_node.input[0]]) output_node.attr["T"].CopyFrom(input_node.attr["dtype"]) if "_class" in input_node.attr: output_node.attr["_class"].CopyFrom(input_node.attr["_class"]) else: output_node.CopyFrom(input_node) output_graph_def.node.extend([output_node]) output_graph_def.library.CopyFrom(inference_graph.library) logging.info("Converted %d variables to const ops.", how_many_converted) return output_graph_def
def apply_matmul_biasadd_relu_fusion(self, match_node_name): skip_node_name = match_node_name[1:] matched_node = self.node_name_mapping[match_node_name[0]] control_inputs, normal_inputs = self._get_node_input( matched_node.node.name) weight_name = normal_inputs[1] weight_node = self.node_name_mapping[helper.node_name_from_input( weight_name)].node # FIXME: We only quantize the MatMul op whose second input is a Const node. This is a # workaround for RNN models like LSTM. if weight_node.op != 'Const': self.output_graph = self.input_graph return for i in self.node_name_mapping: if weight_node.name in self.node_name_mapping[i].output: self.output_graph = self.input_graph return q_weights_name, q_weights_min_name, q_weights_max_name = \ self._intel_cpu_quantize_weight_eightbit( matched_node.node.op, self.node_name_mapping[weight_name].node, self.per_channel) skip_node_name.append(weight_name) for _, node in enumerate(self.input_graph.node): if node.name in skip_node_name: pass elif node.name == match_node_name[0]: self.logger.debug("matched node {} with input {}".format( node.name, node.input)) self.logger.debug("apply_matmul_biasadd_relu_fusion") quantized_node_name = node.name + "_eightbit_quantized_mat_mul" bias_node_name = self.node_name_mapping[ match_node_name[1]].node.input[1] relu_node_name = match_node_name[2] all_input_names = self._add_eightbit_prologue_nodes( matched_node.node.name) all_input_names = all_input_names[:1] + [ q_weights_name ] + all_input_names[1:] all_input_names.append(q_weights_min_name) all_input_names.append(q_weights_max_name) quantized_node_input_names = all_input_names[:2] + [ bias_node_name ] + all_input_names[2:] + control_inputs quantized_matmul_node = helper.create_node( "QuantizedMatMulWithBiasAndRelu", quantized_node_name, quantized_node_input_names) helper.copy_attr(quantized_matmul_node, "transpose_a", node.attr["transpose_a"]) helper.copy_attr(quantized_matmul_node, "transpose_b", node.attr["transpose_b"]) helper.set_attr_dtype(quantized_matmul_node, "T1", dtypes.quint8) helper.set_attr_dtype(quantized_matmul_node, "T2", dtypes.qint8) helper.set_attr_dtype(quantized_matmul_node, "Toutput", dtypes.qint32) self.add_output_graph_node(quantized_matmul_node) quantize_down_name = self._add_quantize_down_nodes( node, quantized_node_name, dtypes.quint8, False) self._intel_cpu_add_dequantize_result_node( quantize_down_name, relu_node_name) else: new_node = node_def_pb2.NodeDef() new_node.CopyFrom(node) self.add_output_graph_node(new_node)
def do_transformation(self): cur_graph = GraphAnalyzer() # according to https://github.com/onnx/tensorflow-onnx/issues/77 for node in self.model.node: if node.op == 'RefSwitch': node.op = 'Switch' for index in range(len(node.input)): if 'moving_' in node.input[index]: node.input[index] = node.input[index] + '/read' elif node.op == 'AssignSub': node.op = 'Sub' if 'use_locking' in node.attr: del node.attr['use_locking'] elif node.op == 'AssignAdd': node.op = 'Add' if 'use_locking' in node.attr: del node.attr['use_locking'] elif node.op == 'Assign': node.op = 'Identity' if 'use_locking' in node.attr: del node.attr['use_locking'] if 'validate_shape' in node.attr: del node.attr['validate_shape'] if len(node.input) == 2: # input0: ref: Should be from a Variable node. May be uninitialized. # input1: value: The value to be assigned to the variable. node.input[0] = node.input[1] del node.input[1] cur_graph.graph = self.model graph_info = cur_graph.parse_graph() for name in self.input_node_names: if ':' in name: self.logger.debug("Name {} appears to refer to a Tensor, " "not an Operation.".format(name)) return False type_attr = {"Sub": "T"} not_found = {name for name in self.input_node_names} for node_name, _ in graph_info.items(): if node_name in not_found: not_found.remove(node_name) node = graph_info[node_name].node # Skip the conversion to Placeholder for nodes whose type is a list (component_types). if 'component_types' in node.attr: continue original_output = graph_info[node_name].outputs placeholder_node = node_def_pb2.NodeDef() placeholder_node.op = "Placeholder" placeholder_node.name = node.name if "dtype" in node.attr: placeholder_node.attr["dtype"].CopyFrom( attr_value_pb2.AttrValue(type=node.attr["dtype"].type)) elif node.op in type_attr.keys(): placeholder_node.attr["dtype"].CopyFrom( attr_value_pb2.AttrValue( type=node.attr[type_attr[node.op]].type)) else: raise KeyError("%s op's type attribute is not found, " "you should add it to type_attr dict" % node.op) if "_output_shapes" in node.attr: placeholder_node.attr["_output_shapes"].CopyFrom( node.attr["_output_shapes"]) if "shape" in node.attr: placeholder_node.attr["shape"].CopyFrom(node.attr["shape"]) cur_graph.remove_node(node_name) cur_graph.replace_const_node(placeholder_node, [node_name], original_output) import tensorflow as tf return tf.compat.v1.graph_util.extract_sub_graph( cur_graph.dump_graph(), self.output_node_names)
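# Hedged, free-standing sketch of the Assign rewrite performed above, to make
# its effect concrete: Assign(ref, value) becomes an Identity over `value`, and
# the training-only attrs are dropped. The node names are illustrative only.
def _example_rewrite_assign_node():
    node = node_def_pb2.NodeDef()
    node.op = "Assign"
    node.name = "v/Assign"
    node.input.extend(["v", "v/initial_value"])
    node.attr["use_locking"].b = True
    node.attr["validate_shape"].b = True
    # Apply the same rewrite as the loop above.
    node.op = "Identity"
    del node.attr["use_locking"]
    del node.attr["validate_shape"]
    node.input[0] = node.input[1]
    del node.input[1]
    return node  # An Identity node whose single input is "v/initial_value".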
def test_freeze_then_sparsify(self, freeze_mock, graph_transform_mock): tag_name = 'tag' input_nodes = 'input_nodes' output_nodes = 'output_nodes' freeze_transform = 'freeze_graph' sparsify_transform = 'sparsify_gather' base_meta_graph_def = meta_graph_pb2.MetaGraphDef() # Add a table initializer. table_init_name = 'table_init' node_def = node_def_pb2.NodeDef(name=table_init_name, op='InitializeTableV2') base_meta_graph_def.graph_def.node.extend([node_def]) # Add a group_deps node. group_deps_name = 'group_deps' node_def = node_def_pb2.NodeDef(name=group_deps_name, op='NoOp') node_def.input.extend(['^table_init']) base_meta_graph_def.graph_def.node.extend([node_def]) base_meta_graph_def.collection_def[ ops.GraphKeys.TABLE_INITIALIZERS].node_list.value.extend( [table_init_name]) base_meta_graph_def.collection_def[ saved_model_constants.LEGACY_INIT_OP_KEY].node_list.value.extend( [group_deps_name]) # Expected metagraphdef. expected_meta_graph_def = meta_graph_pb2.MetaGraphDef() expected_meta_graph_def.CopyFrom(base_meta_graph_def) expected_meta_graph_def.meta_info_def.tags.append(tag_name) transformed_graph_def = graph_pb2.GraphDef() transformed_graph_def.CopyFrom(expected_meta_graph_def.graph_def) freeze_mock.return_value = transformed_graph_def graph_transform_mock.return_value = transformed_graph_def # Add unsaved init node. unsaved_init_name = 'unsaved_node' node_def = node_def_pb2.NodeDef(name=unsaved_init_name, op='NoOp') base_meta_graph_def.graph_def.node.extend([node_def]) # Add a saver. base_meta_graph_def.saver_def.filename_tensor_name = 'node1' base_meta_graph_def.saver_def.save_tensor_name = 'node3' base_meta_graph_def.saver_def.restore_op_name = 'node6' transformed_meta_graph_def = meta_graph_transform.meta_graph_transform( base_meta_graph_def, [input_nodes], [output_nodes], [freeze_transform, sparsify_transform], [tag_name]) self.assertEqual(expected_meta_graph_def, transformed_meta_graph_def) freeze_mock.assert_called_once_with(base_meta_graph_def.graph_def, [output_nodes], [table_init_name], group_deps_name, base_meta_graph_def.saver_def, None) graph_transform_mock.assert_called_once_with(transformed_graph_def, [ input_nodes ], [output_nodes, group_deps_name, table_init_name], [ sparsify_transform + '(group_init_node="sparify_gather_init_op")' ])
def do_transformation(self): float32_type = dtypes.float32.as_datatype_enum qint32_type = dtypes.qint32.as_datatype_enum target_nodes = self.graph_analyzer.query_fusion_pattern_nodes( self.fuse_patterns[self.version]) for i in target_nodes: # TODO Remove below checker once the TF's limitation removed. if len(i) == 5: continue quantized_node_name = i[0] quantized_node = self.graph_info[quantized_node_name].node requantize_node_name = i[1] requantize_node = self.graph_info[requantize_node_name].node requested_output_min_name = requantize_node.input[3] requested_output_max_name = requantize_node.input[4] deq_node_name = i[2] quantized_node_op = i[-1][0] new_node = node_def_pb2.NodeDef() new_node.op = quantized_node_op + "AndDequantize" new_node.name = requantize_node_name for _, value in enumerate(quantized_node.input): new_node.input.append(value) new_node.input.append(requested_output_min_name) new_node.input.append(requested_output_max_name) if 'T1' in quantized_node.attr: new_node.attr["T1"].CopyFrom(quantized_node.attr['T1']) if 'T2' in quantized_node.attr: new_node.attr["T2"].CopyFrom(quantized_node.attr['T2']) top_node_name = Helper.node_name_from_input(quantized_node.input[0]) max_filter_node = self.graph_info[new_node.input[6]].node min_filter_node = self.graph_info[new_node.input[5]].node last_node = self.graph_info[new_node.input[0]].node bias_node = self.graph_info[new_node.input[2]].node max_input_node = self.graph_info[last_node.input[-1]].node min_input_node = self.graph_info[last_node.input[-2]].node min_input_value = (min_input_node.attr['value'].tensor.float_val)[0] max_input_value = (max_input_node.attr['value'].tensor.float_val)[0] max_filter_value = (max_filter_node.attr['value'].tensor.float_val)[0] min_filter_value = (min_filter_node.attr['value'].tensor.float_val)[0] weights_tensor = tensor_util.MakeNdarray( self.graph_info[new_node.input[1]].node.attr['value'].tensor) bias_tensor = tensor_util.MakeNdarray( self.graph_info[new_node.input[2]].node.attr['value'].tensor) bias_scale = 255.0 * 127.0 / ( (max_input_value -min_input_value) * max(abs(max_filter_value), abs(min_filter_value))) relative_scale = 255 * min_input_value / (max_input_value - min_input_value) int32_bias = [] for bias_index, value in enumerate( np.sum(np.array(weights_tensor, dtype=np.int32), axis=0, dtype=np.int32)): int32_bias.append(int(bias_tensor[bias_index] * bias_scale + value * relative_scale)) bias_node.attr['dtype'].CopyFrom( attr_value_pb2.AttrValue( type=float32_type if self.device == 'gpu' else qint32_type)) bias_node.attr['value'].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( bias_tensor if self.device == 'gpu' else int32_bias, dtypes. 
float32 if self.device == 'gpu' else dtypes.int32, bias_tensor.shape))) bias_node.attr['value'].tensor.dtype = float32_type \ if self.device == 'gpu' else qint32_type new_node.attr["Tbias"].CopyFrom(attr_value_pb2.AttrValue(type=float32_type \ if self.device == 'gpu' else qint32_type)) new_node.attr["Toutput"].CopyFrom(attr_value_pb2.AttrValue(type=float32_type)) self.graph_analyzer.remove_node(requantize_node_name) if self.graph_info[deq_node_name].outputs: self.graph_analyzer.replace_single_node( new_node, [top_node_name], quantized_node_name, self.graph_info[deq_node_name].outputs, deq_node_name) self.graph_analyzer.remove_node(deq_node_name) else: self.graph_analyzer.remove_node(deq_node_name) new_node.name = deq_node_name self.graph_analyzer.replace_single_node( new_node, [top_node_name], quantized_node_name, [], deq_node_name) self.graph_analyzer.remove_node(quantized_node_name) return self.graph_analyzer.dump_graph()
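# Hedged numeric sketch of the bias conversion used above: with MIN_FIRST input
# quantization, the float bias is rescaled into the qint32 domain and corrected
# by a term proportional to the per-column sums of the quantized weights. All
# ranges passed in are caller-supplied example values; `np` is the module's
# existing numpy import.
def _example_int32_bias(weights, bias, min_input, max_input, min_filter, max_filter):
    bias_scale = 255.0 * 127.0 / (
        (max_input - min_input) * max(abs(max_filter), abs(min_filter)))
    relative_scale = 255.0 * min_input / (max_input - min_input)
    col_sums = np.sum(np.array(weights, dtype=np.int32), axis=0, dtype=np.int32)
    return [int(b * bias_scale + s * relative_scale)
            for b, s in zip(bias, col_sums)]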
def _convert_single_op_hint_to_stub(call, graph_def): """Given a graph_def, converts `call` into a stub and returns a new graph_def. Args: call: A single function call to be converted. graph_def: A graph_def to use as input (which contains `call`). Returns: A new transformed graph-def that has call as a stub (single op). Note: after this process, the graph_def can no longer be loaded into the tensorflow runtime, so all future manipulations are done at the graph_def level. """ name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary( graph_def) input_names, output_names = call.flattened_inputs_and_outputs() reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name) reachable_by_output = _bfs_for_reachable_nodes(output_names, name_to_input_name) input_nodes_set = set(input_names) output_nodes_set = set(output_names) nodes_after_fuse = [] nodes_deleted_by_fuse = set() # Classify each node. We want to keep everything reachable by input; nodes # reachable by output but not by input are either part of the fused op (and # get deleted) or come after it (and are kept). for node in graph_def.node: n = _tensor_name_base(node.name) if n in reachable_by_output: if n not in reachable_by_input and n not in output_nodes_set: # n is an internal node. Check to make sure it is really internal. # TODO(aselle): this could be done more efficiently by flooding # the graph first. _check_subgraph_closed(n, reachable_by_input, input_nodes_set, name_to_input_name) nodes_deleted_by_fuse.add(n) elif n not in reachable_by_input: # n is a node that comes after all the fusings, so keep it. nodes_after_fuse.append(n) else: # n is a node that is randomly in the graph but not connected to # the chain of dependencies. pass # Make a new graphdef with all the pre-input and input nodes out = _graph_pb2.GraphDef() reachable_by_input_sorted = sorted(list(reachable_by_input), key=lambda n: name_to_seq_num[n]) for node in reachable_by_input_sorted: out.node.extend([_copy.deepcopy(name_to_node[node])]) # Create any stacks to aggregate arguments into a single input # i.e. for static_rnn's. # TODO(aselle): Check that the inputs are complete i.e. 0 to n-1 sorted_input_indices = list(call.inputs.keys()) sorted_input_indices.sort() sorted_output_indices = list(call.outputs.keys()) sorted_output_indices.sort() new_node = _node_def_pb2.NodeDef() # Delegate to each operand to produce the proper new input for this stub node. # In particular, an aggregate input will now be a Pack of some previously # non-fused things. for input_index in sorted_input_indices: inputs = call.inputs[input_index] new_node.input.append(inputs.aggregate_and_return_name_for_input(out)) new_node.attr[OpHint.TFLITE_INPUT_INDICES].list.i.extend( sorted_input_indices) # Create the function new_node.op = call.function_name new_node.name = call.uuid out.node.extend([new_node]) # Now call each output argument to give them a chance to make the proper # output type and add it to our new_node. output_dtypes = [] for output_index in sorted_output_indices: output = call.outputs[output_index] output_dtype = (output.aggregate_and_return_name_for_output( new_node.name, output_index, out)) output_dtypes.append(output_dtype) new_node.attr["_output_types"].list.type[:] = output_dtypes # TODO(aselle): what is right here?
new_node.attr["_output_quantized"].b = False # Add post output nodes that do not depend on the outputs for n in nodes_after_fuse: should_keep = True for input_name in name_to_input_name[n]: if input_name in nodes_deleted_by_fuse: should_keep = False if should_keep: out.node.extend([_copy.deepcopy(name_to_node[n])]) # Misc. graph_def data that needs copying. out.library.CopyFrom(graph_def.library) out.versions.CopyFrom(graph_def.versions) return out
def do_transformation(self): """Fuse the quantized op with the following requantize op. Returns: [graphdef]: the optimized graphdef object """ uint8_type = dtypes.quint8.as_datatype_enum float32_type = dtypes.float32.as_datatype_enum qint32_type = dtypes.qint32.as_datatype_enum while True: target_nodes = self.graph_analyzer.query_fusion_pattern_nodes( self.fuse_patterns['default']) if len(target_nodes) == 0: break i = target_nodes[0] quantized_node_name = i[0] quantized_node = self.graph_info[quantized_node_name].node requantize_node_name = i[1] requantize_node = self.graph_info[requantize_node_name].node requested_output_min_name = requantize_node.input[3] requested_output_max_name = requantize_node.input[4] quantized_node_op = i[-1][0] new_node = node_def_pb2.NodeDef() new_node.op = quantized_node_op + "AndRequantize" new_node.name = requantize_node_name for _, value in enumerate(quantized_node.input): new_node.input.append(value) new_node.input.append(requested_output_min_name) new_node.input.append(requested_output_max_name) if 'T1' in quantized_node.attr: new_node.attr["T1"].CopyFrom(quantized_node.attr['T1']) if 'T2' in quantized_node.attr: new_node.attr["T2"].CopyFrom(quantized_node.attr['T2']) parent_node_name = Helper.node_name_from_input(quantized_node.input[0]) max_filter_node = self.graph_info[new_node.input[6]].node min_filter_node = self.graph_info[new_node.input[5]].node last_node = self.graph_info[new_node.input[0]].node is_min_first = bool(quantized_node.attr['input_quant_mode'] == 'MIN_FIRST') if last_node.op.find('Requantize') != -1 or last_node.op.find('QuantizeV2') != -1: bias_node = self.graph_info[new_node.input[2]].node max_input_node = self.graph_info[last_node.input[-1]].node min_input_node = self.graph_info[last_node.input[-2]].node min_input_value = (min_input_node.attr['value'].tensor.float_val)[0] max_input_value = (max_input_node.attr['value'].tensor.float_val)[0] max_filter_value = (max_filter_node.attr['value'].tensor.float_val)[0] min_filter_value = (min_filter_node.attr['value'].tensor.float_val)[0] weights_tensor = tensor_util.MakeNdarray( self.graph_info[new_node.input[1]].node.attr['value'].tensor) bias_tensor = tensor_util.MakeNdarray( self.graph_info[new_node.input[2]].node.attr['value'].tensor) input_range = max_input_value - \ min_input_value if is_min_first else max( abs(max_input_value), abs(min_input_value)) bias_scale = 255.0 * 127.0 / ( input_range * max(abs(max_filter_value), abs(min_filter_value))) relative_scale = 255 * min_input_value / (max_input_value - min_input_value) int32_bias = [] for bias_index, value in enumerate( np.sum(np.array(weights_tensor, dtype=np.int32), axis=0, dtype=np.int32)): int32_bias.append(int(bias_tensor[bias_index] * bias_scale + value * relative_scale)) bias_node.attr['dtype'].CopyFrom( attr_value_pb2.AttrValue( type=float32_type if self.device == 'gpu' else qint32_type)) bias_node.attr['value'].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( bias_tensor if self.device == 'gpu' else int32_bias, dtypes. float32 if self.device == 'gpu' else dtypes.int32, bias_tensor.shape))) bias_node.attr['value'].tensor.dtype = float32_type \ if self.device == 'gpu' else qint32_type new_node.attr["Tbias"].CopyFrom(attr_value_pb2.AttrValue(type=float32_type \ if self.device == 'gpu' else qint32_type)) new_node.attr["Toutput"].CopyFrom( attr_value_pb2.AttrValue(type=uint8_type)) #TODO enabled below commit once the graph refactor pre_optimize commmitted. 
if quantized_node_op.find('Relu') == -1: deq_node_name = self.graph_info[requantize_node_name].outputs[0] deq_node = self.graph_info[deq_node_name].node deq_node.attr['T'].CopyFrom(attr_value_pb2.AttrValue(type=uint8_type)) else: new_node.attr["Tbias"].CopyFrom(attr_value_pb2.AttrValue(type=float32_type)) self.graph_analyzer.replace_single_node( new_node, [parent_node_name], quantized_node_name, [self.graph_info[requantize_node_name].outputs[0]], requantize_node_name) self.graph_analyzer.remove_node(quantized_node_name) return self.graph_analyzer.dump_graph()
def bn_fold(input_graph_def, conv_name, weight_name, mean_name, var_name, beta_name, gamma_name, epsilon_name, add_name): input_node_map = get_input_node_map(input_graph_def) skip_ops = [conv_name, weight_name, mean_name, var_name, beta_name, gamma_name, epsilon_name, add_name] skip_ops.extend([]) try: conv_op = input_node_map[conv_name] weights_op = input_node_map[weight_name] mean_op = input_node_map[mean_name] var_op = input_node_map[var_name] beta_op = input_node_map[beta_name] gamma_op = input_node_map[gamma_name] epsilon_op = input_node_map[epsilon_name] add_op = input_node_map[add_name] except KeyError as e: print("node %s not in graph"%e) return [],[] weights = values_from_const(weights_op) mean_value = values_from_const(mean_op) var_value = values_from_const(var_op) beta_value = values_from_const(beta_op) gamma_value = values_from_const(gamma_op) variance_epsilon_value = values_from_const(epsilon_op) new_ops = [] scale_value = ( (1.0 / np.vectorize(math.sqrt)(var_value + variance_epsilon_value)) * gamma_value) offset_value = (-mean_value * scale_value) + beta_value scaled_weights = np.copy(weights) it = np.nditer( scaled_weights, flags=["multi_index"], op_flags=["readwrite"]) while not it.finished: if conv_op.op == "DepthwiseConv2dNative": current_scale = scale_value[it.multi_index[2]] else: current_scale = scale_value[it.multi_index[3]] it[0] *= current_scale it.iternext() scaled_weights_op = node_def_pb2.NodeDef() scaled_weights_op.op = "Const" scaled_weights_op.name = weights_op.name scaled_weights_op.attr["dtype"].CopyFrom(weights_op.attr["dtype"]) scaled_weights_op.attr["value"].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( scaled_weights, weights.dtype.type, weights.shape))) new_conv_op = node_def_pb2.NodeDef() new_conv_op.CopyFrom(conv_op) offset_op = node_def_pb2.NodeDef() offset_op.op = "Const" offset_op.name = conv_op.name + "_bn_offset" offset_op.attr["dtype"].CopyFrom(mean_op.attr["dtype"]) offset_op.attr["value"].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( offset_value, mean_value.dtype.type, offset_value.shape))) new_add_op = node_def_pb2.NodeDef() new_add_op.CopyFrom(add_op) del new_add_op.input[:] new_add_op.input.extend([new_conv_op.name, offset_op.name]) new_ops.extend([scaled_weights_op, new_conv_op, offset_op, new_add_op]) return skip_ops,new_ops
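# Hedged numeric sketch of the folding math used by bn_fold above: the batch
# norm (mean, var, gamma, beta, epsilon) is absorbed into the convolution by
# scaling its weights per output channel and emitting a constant offset. The
# broadcast below assumes standard Conv2D HWIO weights; DepthwiseConv2dNative
# scales along axis 2 instead, as handled in bn_fold itself. `np` is the
# module's existing numpy import.
def _example_fold_bn_constants(weights, mean, var, gamma, beta, epsilon):
    scale = gamma / np.sqrt(var + epsilon)
    offset = -mean * scale + beta
    folded_weights = weights * scale  # broadcasts over the last (output-channel) axis
    return folded_weights, offset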
] for conv_name, weight_name, mean_name,var_name, beta_name, gamma_name, epsilon_name, add_name in zip(conv_names, weight_names, mean_names,var_names, beta_names, gamma_names, epsilon_names, add_names): skip_op, new_op = bn_fold(output_graph_def, conv_name, weight_name, mean_name, var_name, beta_name, gamma_name, epsilon_name, add_name) skip_ops.extend(skip_op) new_ops.extend(new_op) result_graph_def = graph_pb2.GraphDef() for node in output_graph_def.node: if node.name in skip_ops: continue new_node = node_def_pb2.NodeDef() new_node.CopyFrom(node) result_graph_def.node.extend([new_node]) result_graph_def.node.extend(new_ops) output_graph_def = result_graph_def output_graph_def = strip_unused_lib.strip_unused( output_graph_def, input_node_names=input_node_names, output_node_names=output_node_names, placeholder_type_enum=dtypes.uint8.as_datatype_enum) with open(output_pb_file,'wb') as f: f.write(output_graph_def.SerializeToString())
def convert_variables_to_constants(sess, input_graph_def, output_node_names, variable_names_whitelist=None, variable_names_blacklist=None): """Replaces all the variables in a graph with constants of the same values. If you have a trained graph containing Variable ops, it can be convenient to convert them all to Const ops holding the same values. This makes it possible to describe the network fully with a single GraphDef file, and allows the removal of a lot of ops related to loading and saving the variables. Args: sess: Active TensorFlow session containing the variables. input_graph_def: GraphDef object holding the network. output_node_names: List of name strings for the result nodes of the graph. variable_names_whitelist: The set of variable names to convert (by default, all variables are converted). variable_names_blacklist: The set of variable names to omit converting to constants. Returns: GraphDef containing a simplified version of the original. """ def get_input_name(node): """Gets the name of the first input. Errors if suffix is not :0.""" details = node.input[0].split(":") if len(details) == 1 or int(details[1]) == 0: return details[0] # While it is valid for input tensors to have a suffix that is not :0, this # method is used to find the associated ops, not tensors, and therefore it # is not valid. raise ValueError("Tensor name '{0}' is invalid.".format(node.input[0])) def create_const_op(node_name, dtype, data, data_shape=None): """Creates a Const op.""" output_node = node_def_pb2.NodeDef() output_node.op = "Const" output_node.name = node_name output_node.attr["dtype"].CopyFrom(dtype) output_node.attr["value"].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( data, dtype=dtype.type, shape=data_shape))) return output_node # This graph only includes the nodes needed to evaluate the output nodes, and # removes unneeded nodes like those involved in saving and assignment. inference_graph = extract_sub_graph(input_graph_def, output_node_names) # Identify the ops in the graph. map_name_to_node = {node.name: node for node in inference_graph.node} # Get list of variables. variable_names = [] variable_dict_names = [] resource_identity_types = {} for node in inference_graph.node: if node.op in ["Variable", "VariableV2", "VarHandleOp"]: variable_name = node.name if ((variable_names_whitelist is not None and variable_name not in variable_names_whitelist) or (variable_names_blacklist is not None and variable_name in variable_names_blacklist)): continue variable_dict_names.append(variable_name) if node.op == "VarHandleOp": variable_names.append(variable_name + "/Read/ReadVariableOp:0") else: variable_names.append(variable_name + ":0") elif node.op in ["ReadVariableOp", "ResourceGather"]: # There can be one or more Identity ops in between the ReadVariableOp and # VarHandleOp. Store the Identity ops with the associated dtypes. source_op_name = get_input_name(node) while map_name_to_node[source_op_name].op == "Identity": resource_identity_types[source_op_name] = node.attr["dtype"] source_op_name = get_input_name( map_name_to_node[source_op_name]) if map_name_to_node[source_op_name].op != "VarHandleOp": raise ValueError("Cannot find the variable that is an input " "to the ReadVariableOp.") # Gets map of variables and the associated data. 
if variable_names: returned_variables = sess.run(variable_names) else: returned_variables = [] variables_data_map = dict(zip(variable_dict_names, returned_variables)) logging.info("Froze %d variables.", len(returned_variables)) # Reconstruct the graph with constants in place of variables. output_graph_def = graph_pb2.GraphDef() how_many_converted = 0 for input_node in inference_graph.node: output_node = node_def_pb2.NodeDef() if input_node.name in variables_data_map: data = variables_data_map[input_node.name] output_node = create_const_op(input_node.name, input_node.attr["dtype"], data, data.shape) how_many_converted += 1 elif input_node.name in resource_identity_types: # Converts the Identities of type RESOURCE_DT to the appropriate type # based on the input they are referencing. output_node.CopyFrom(input_node) output_node.attr["T"].CopyFrom( resource_identity_types[input_node.name]) elif input_node.op == "ReadVariableOp": # The first branch converts all VarHandleOps of ResourceVariables to # constants, so we need to convert the associated ReadVariableOps to # Identity ops. output_node.op = "Identity" output_node.name = input_node.name output_node.input.extend([input_node.input[0]]) output_node.attr["T"].CopyFrom(input_node.attr["dtype"]) if "_class" in input_node.attr: output_node.attr["_class"].CopyFrom(input_node.attr["_class"]) elif input_node.op == "ResourceGather": # The first branch converts all VarHandleOps of ResourceGather to # constants, so we need to convert the associated ResourceGather to Gather # ops with a Const axis feeding into it. if input_node.attr["batch_dims"].i != 0: raise ValueError( "batch_dims != 0 is not supported by freeze_graph.") axis_data = input_node.attr["batch_dims"].i axis_node_name = input_node.name + "/axis" axis_dtype = input_node.attr["Tindices"] output_axis_node = create_const_op(axis_node_name, axis_dtype, axis_data) output_graph_def.node.extend([output_axis_node]) output_node.op = "GatherV2" output_node.name = input_node.name output_node.input.extend( [input_node.input[0], input_node.input[1], axis_node_name]) output_node.attr["Tparams"].CopyFrom(input_node.attr["dtype"]) output_node.attr["Tindices"].CopyFrom(input_node.attr["Tindices"]) output_node.attr["Taxis"].CopyFrom(axis_dtype) if "_class" in input_node.attr: output_node.attr["_class"].CopyFrom(input_node.attr["_class"]) else: output_node.CopyFrom(input_node) output_graph_def.node.extend([output_node]) output_graph_def.library.CopyFrom(inference_graph.library) logging.info("Converted %d variables to const ops.", how_many_converted) return output_graph_def
class TestFoldConstant(unittest.TestCase): x_node = node_def_pb2.NodeDef() x_node.name = "placeholder" x_node.op = "Placeholder" input0_node = node_def_pb2.NodeDef() input0_node.name = "input0" input0_node.op = "Const" input0_value = np.float32(np.abs(np.random.randn(4, 3, 2))) input0_node.attr["value"].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( input0_value, input0_value.dtype.type, input0_value.shape))) input1_node = node_def_pb2.NodeDef() input1_node.name = "input1" input1_node.op = "Const" input1_value = np.float32(np.abs(np.random.randn(4, 1, 1))) input1_node.attr["value"].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( input1_value, input1_value.dtype.type, input1_value.shape))) add_node = node_def_pb2.NodeDef() add_node.op = "Add" add_node.name = "add" add_node.input.extend([input0_node.name, input1_node.name]) input2_node = node_def_pb2.NodeDef() input2_node.name = "input2" input2_node.op = "Const" input2_value = np.float32(np.abs(np.random.randn(1))) input2_node.attr["value"].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( input2_value, input2_value.dtype.type, input2_value.shape))) input3_node = node_def_pb2.NodeDef() input3_node.name = "input3" input3_node.op = "Const" input3_value = np.float32(np.abs(np.random.randn(1))) input3_node.attr["value"].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( input3_value, input3_value.dtype.type, input3_value.shape))) switch_node = node_def_pb2.NodeDef() switch_node.name = "switch" switch_node.op = "Switch" input4_node = node_def_pb2.NodeDef() input4_node.name = "input4" input4_node.op = "Const" input4_value = np.float32(np.abs(np.random.randn(1))) input4_node.attr["value"].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( input4_value, input4_value.dtype.type, input4_value.shape))) input4_node.input.extend([switch_node.name]) input5_node = node_def_pb2.NodeDef() input5_node.name = "input5" input5_node.op = "Const" input5_value = np.float32(np.abs(np.random.randn(1))) input5_node.attr["value"].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( input5_value, input5_value.dtype.type, input5_value.shape))) input5_node.input.extend([switch_node.name]) cond_end = node_def_pb2.NodeDef() cond_end.name = "cond" cond_end.op = "Add" cond_end.input.extend([input4_node.name, input5_node.name]) mul_node = node_def_pb2.NodeDef() mul_node.op = "Mul" mul_node.name = "mul" mul_node.input.extend([add_node.name, input3_node.name]) sqrt_node = node_def_pb2.NodeDef() sqrt_node.name = "rsqrt" sqrt_node.op = "Rsqrt" sqrt_node.input.extend([mul_node.name]) relu_node = node_def_pb2.NodeDef() relu_node.op = "Relu" relu_node.name = "relu" relu_node.input.extend([sqrt_node.name]) block_node = node_def_pb2.NodeDef() block_node.name = "block_output" block_node.op = "Add" block_node.input.extend([x_node.name, relu_node.name]) res_node = node_def_pb2.NodeDef() res_node.name = "res_add" res_node.op = "Add" res_node.input.extend([sqrt_node.name, input2_node.name]) end_node = node_def_pb2.NodeDef() end_node.name = "end" end_node.op = "Add" end_node.input.extend([block_node.name, res_node.name]) graph_def = graph_pb2.GraphDef() graph_def.node.extend([ x_node, input0_node, input1_node, input2_node, input3_node, add_node, mul_node, sqrt_node, relu_node, block_node, res_node, end_node ]) def test_fold_constant(self): graph = self.graph_def rewriter = GraphFoldConstantOptimizer(graph) new_graph = rewriter.do_transformation() for node 
in new_graph.node: assert node.name in [ "placeholder", "block_output", "rsqrt_const", "relu", "res_add_const", "end" ] def test_condition_fold_constant(self): graph_def = graph_pb2.GraphDef() graph_def.node.extend([ self.cond_end, self.input4_node, self.input5_node, self.switch_node ]) rewriter = GraphFoldConstantOptimizer(graph_def) new_graph = rewriter.do_transformation() for node in new_graph.node: assert node.name in ["switch", "cond", "input4", "input5"] def test_slice_int_input(self): graph_def = graph_pb2.GraphDef() index0_node = node_def_pb2.NodeDef() index0_node.name = "index0" index0_node.op = "Const" index0_value = np.array(3).astype(np.int32).reshape(()) index0_node.attr["value"].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( index0_value, index0_value.dtype.type, index0_value.shape))) index1_node = node_def_pb2.NodeDef() index1_node.name = "index1" index1_node.op = "Const" index1_value = np.array(1).astype(np.int32).reshape(()) index1_node.attr["value"].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( index1_value, index1_value.dtype.type, index1_value.shape))) minus_node = node_def_pb2.NodeDef() minus_node.name = "sub" minus_node.op = "Sub" minus_node.input.extend([index0_node.name, index1_node.name]) graph_def.node.extend([index0_node, index1_node, minus_node]) rewriter = GraphFoldConstantOptimizer(graph_def) new_graph = rewriter.do_transformation() with tf.compat.v1.Session() as sess: tf.compat.v1.import_graph_def(new_graph)
def generate_output_graph(input_graph_def, input_node_map, output_node_map, fuse_op_list, fuse_op_deq_list): output_graph_def = graph_pb2.GraphDef() skip_list = [] skip_node_name = [] int8_type = dtypes.qint8.as_datatype_enum uint8_type = dtypes.quint8.as_datatype_enum float32_type = dtypes.float32.as_datatype_enum qint32_type = dtypes.qint32.as_datatype_enum for index, node in enumerate(input_graph_def.node): if index in fuse_op_list: const_node_1 = input_graph_def.node[index + 1] const_node_2 = input_graph_def.node[index + 2] requantize_node = input_graph_def.node[index + 3] new_node = node_def_pb2.NodeDef() new_node.op = node.op + "AndRequantize" new_node.name = requantize_node.name for _, value in enumerate(node.input): new_node.input.append(value) new_node.input.append(const_node_1.name) new_node.input.append(const_node_2.name) new_node.attr["Tinput"].CopyFrom(node.attr['Tinput']) new_node.attr["Tfilter"].CopyFrom(node.attr['Tfilter']) new_node.attr["strides"].CopyFrom(node.attr['strides']) new_node.attr["padding"].CopyFrom(node.attr['padding']) if input_node_map[new_node.input[0]].op.find("Requantize") != -1: bias_node = input_node_map[new_node.input[2]] last_node = input_node_map[new_node.input[0]] max_input_node = (input_node_map[last_node.input[4][:-2]]) min_input_node = (input_node_map[last_node.input[3][:-2]]) max_filter = input_node_map[new_node.input[6]] min_filter = input_node_map[new_node.input[5]] min_input = (min_input_node.attr['value'].tensor.float_val)[0] max_input = (max_input_node.attr['value'].tensor.float_val)[0] if 'Depthwise' in node.op or "RequantizePerChannel" in [ node.op for node in output_node_map[node.name] ]: channel_size = max_filter.attr[ 'value'].tensor.tensor_shape.dim[0].size max_filter_tensor = tensor_util.MakeNdarray( max_filter.attr['value'].tensor) min_filter_tensor = tensor_util.MakeNdarray( min_filter.attr['value'].tensor) else: channel_size = 1 max_filter_tensor = [] min_filter_tensor = [] max_filter_tensor.append( (max_filter.attr['value'].tensor.float_val)[0]) min_filter_tensor.append( (min_filter.attr['value'].tensor.float_val)[0]) bias_tensor = tensor_util.MakeNdarray( input_node_map[new_node.input[2]].attr['value'].tensor) bias_length = bias_tensor.shape[0] scales = [] for i in range(channel_size): scales.append(255.0 * 127.0 / (max(abs(max_input), abs(min_input)) * max(abs(max_filter_tensor[i]), abs(min_filter_tensor[i])))) int32_bias = [] if channel_size > 1: for i in range(bias_length): int32_bias.append((int)(bias_tensor[i] * scales[i])) else: for i in range(bias_length): int32_bias.append((int)(bias_tensor[i] * scales[0])) bias_node.attr['dtype'].CopyFrom( attr_value_pb2.AttrValue(type=qint32_type)) bias_node.attr['value'].CopyFrom( attr_value_pb2.AttrValue( tensor=tensor_util.make_tensor_proto( int32_bias, dtypes.int32, bias_tensor.shape))) bias_node.attr['value'].tensor.dtype = qint32_type skip_node_name.append(bias_node.name) output_graph_def.node.extend([bias_node]) new_node.attr["Tbias"].CopyFrom( attr_value_pb2.AttrValue(type=qint32_type)) else: new_node.attr["Tbias"].CopyFrom( attr_value_pb2.AttrValue(type=float32_type)) if "padding_list" in node.attr: new_node.attr["padding_list"].CopyFrom( node.attr['padding_list']) if "dilations" in node.attr: new_node.attr["dilations"].CopyFrom(node.attr['dilations']) if node.op == "QuantizedConv2D" or node.op == "QuantizedConv2DWithBias": new_node.attr["out_type"].CopyFrom( attr_value_pb2.AttrValue(type=int8_type)) else: new_node.attr["out_type"].CopyFrom( 
attr_value_pb2.AttrValue(type=uint8_type)) skip_list.append(index + 1) skip_list.append(index + 2) skip_list.append(index + 3) output_graph_def.node.extend( [new_node, const_node_1, const_node_2]) elif index in skip_list or node.name in skip_node_name: continue elif node.op == "Dequantize": new_node = node_def_pb2.NodeDef() new_node.CopyFrom(node) new_node.attr["mode"].s = b"SCALED" p_node = input_node_map[new_node.input[0]] pp_node = input_node_map[p_node.name].input[0] if input_node_map[pp_node].op.find("Relu") != -1 or p_node.op in ( "QuantizedAvgPool", "QuantizedMaxPool", "QuantizedConcatV2"): new_node.attr["T"].CopyFrom( attr_value_pb2.AttrValue(type=uint8_type)) elif input_node_map[pp_node].op.find( "QuantizedMatMulWithBias") != -1 and p_node.op.find( "Requantize") != -1: new_node.attr["mode"].s = node.attr["mode"].s new_node.attr["T"].CopyFrom( attr_value_pb2.AttrValue(type=node.attr["T"].type)) else: new_node.attr["T"].CopyFrom( attr_value_pb2.AttrValue(type=int8_type)) output_graph_def.node.extend([new_node]) elif index in fuse_op_deq_list: original_summand_node = input_node_map[ input_graph_def.node[index].input[-1]] sum_const_node_1 = input_graph_def.node[index + 1] sum_const_node_2 = input_graph_def.node[index + 2] sum_requantize_node = input_graph_def.node[index + 3] new_node = node_def_pb2.NodeDef() new_node.op = node.op + "AndRequantize" new_node.name = sum_requantize_node.name for _, value in enumerate(node.input[:-1]): new_node.input.append(value) new_node.input.append(sum_const_node_1.name) new_node.input.append(sum_const_node_2.name) new_node.input.append( input_node_map[original_summand_node.name].input[0]) new_node.input.append( input_node_map[original_summand_node.name].input[0] + ":1") new_node.input.append( input_node_map[original_summand_node.name].input[0] + ":2") # skip_list.append(index + 1) # skip_list.append(index + 2) skip_list.append(index + 3) new_node.attr["Tinput"].CopyFrom(node.attr['Tinput']) new_node.attr["Tfilter"].CopyFrom(node.attr['Tfilter']) new_node.attr["strides"].CopyFrom(node.attr['strides']) new_node.attr["padding"].CopyFrom(node.attr['padding']) if input_node_map[new_node.input[0]].op.find("Requantize") != -1: bias_node = input_node_map[new_node.input[2]] last_node = input_node_map[new_node.input[0]] max_input_node = (input_node_map[last_node.input[4][:-2]]) min_input_node = (input_node_map[last_node.input[3][:-2]]) max_filter = input_node_map[new_node.input[6]] min_filter = input_node_map[new_node.input[5]] min_input = (min_input_node.attr['value'].tensor.float_val)[0] max_input = (max_input_node.attr['value'].tensor.float_val)[0] if "RequantizePerChannel" in [ node.op for node in output_node_map[node.name] ]: channel_size = max_filter.attr[ 'value'].tensor.tensor_shape.dim[0].size max_filter_tensor = tensor_util.MakeNdarray( max_filter.attr['value'].tensor) min_filter_tensor = tensor_util.MakeNdarray( min_filter.attr['value'].tensor) else: channel_size = 1 max_filter_tensor = [] min_filter_tensor = [] max_filter_tensor.append( (max_filter.attr['value'].tensor.float_val)[0]) min_filter_tensor.append( (min_filter.attr['value'].tensor.float_val)[0]) bias_tensor = (tensor_util.MakeNdarray( input_node_map[new_node.input[2]].attr['value'].tensor)) bias_length = bias_tensor.shape[0] scales = [] for i in range(channel_size): scales.append(255.0 * 127.0 / (max(abs(max_input), abs(min_input)) * max(abs(max_filter_tensor[i]), abs(min_filter_tensor[i])))) int32_bias = [] if channel_size > 1: for i in range(bias_length): 
int32_bias.append(int(bias_tensor[i] * scales[i])) else: for i in range(bias_length): int32_bias.append(int(bias_tensor[i] * scales[0])) bias_node.attr['dtype'].CopyFrom( attr_value_pb2.AttrValue(type=qint32_type)) bias_node.attr['value'].CopyFrom( attr_value_pb2.AttrValue( tensor=tensor_util.make_tensor_proto( int32_bias, dtypes.int32, bias_tensor.shape))) bias_node.attr['value'].tensor.dtype = qint32_type new_node.attr["Tbias"].CopyFrom( attr_value_pb2.AttrValue(type=qint32_type)) skip_node_name.append(bias_node.name) output_graph_def.node.extend([bias_node]) else: new_node.attr["Tbias"].CopyFrom( attr_value_pb2.AttrValue(type=float32_type)) if "padding_list" in node.attr: new_node.attr["padding_list"].CopyFrom( node.attr['padding_list']) if "dilations" in node.attr: new_node.attr["dilations"].CopyFrom(node.attr['dilations']) new_node.attr["out_type"].CopyFrom( attr_value_pb2.AttrValue(type=uint8_type)) summand_op_type = uint8_type if dtypes.as_dtype( original_summand_node.attr["T"].type ) == uint8_type else int8_type if summand_op_type == int8_type: new_node.op = "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize" new_node.attr["Tsummand"].CopyFrom( attr_value_pb2.AttrValue(type=summand_op_type)) output_graph_def.node.extend([new_node]) else: new_node = node_def_pb2.NodeDef() new_node.CopyFrom(node) output_graph_def.node.extend([new_node]) return output_graph_def
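# Hedged numeric sketch (not part of the fusion pass above) of the int32 bias
# rescaling that generate_output_graph performs when it fuses a quantized conv
# with its Requantize node: scale = 255.0 * 127.0 / (input_range * filter_range),
# where input_range = max(|min_input|, |max_input|) and filter_range is per
# output channel (or a single scalar in the non-per-channel case). All values
# below are made up for illustration; the real pass reads them from the frozen
# min/max Const nodes.
import numpy as np


def rescale_bias_to_int32(bias, min_input, max_input, min_filter, max_filter):
    bias = np.asarray(bias, dtype=np.float64)
    min_filter = np.atleast_1d(np.asarray(min_filter, dtype=np.float64))
    max_filter = np.atleast_1d(np.asarray(max_filter, dtype=np.float64))
    input_range = max(abs(min_input), abs(max_input))
    filter_range = np.maximum(np.abs(min_filter), np.abs(max_filter))
    scales = 255.0 * 127.0 / (input_range * filter_range)
    if scales.size == 1:
        # Single scale applied to every bias element.
        return (bias * scales[0]).astype(np.int32)
    # Per-channel scales, one per bias element.
    return (bias * scales).astype(np.int32)


# Example: three output channels with per-channel filter ranges.
example_int32_bias = rescale_bias_to_int32(
    bias=[0.1, -0.2, 0.05],
    min_input=-1.0, max_input=1.0,
    min_filter=[-0.5, -0.25, -1.0], max_filter=[0.5, 0.25, 1.0])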
def do_transformation(self): """Removes batch normalization ops by folding them into convolutions. Batch normalization during training has multiple dynamic parameters that are updated, but once the graph is finalized these become constants. That means there's an opportunity to reduce the computations down to a scale and addition, rather than the more expensive multiple ops, and even bake the scaling into the convolution weights. This function identifies the typical pattern of batch normalization subgraphs, and performs the transformation to fold the computations down into a simpler form. It currently only spots batch normalization that's performed by the BatchNormWithGlobalNormalization and FusedBatchNorm ops, and will need to be extended in the future to handle the newer style. Returns: Modified graph with BN ops removed, and modified weights. Raises: ValueError: If the graph is badly formed with duplicate node names. """ cur_graph = GraphAnalyzer() cur_graph.graph = self.model graph_info = cur_graph.parse_graph() target_nodes = cur_graph.query_fusion_pattern_nodes( [["Conv2D", "DepthwiseConv2dNative"], ("BiasAdd", "Add", "AddV2"), ["BatchNormWithGlobalNormalization", "FusedBatchNorm", "FusedBatchNormV3"]]) for node_combination in target_nodes: matched_node = node_combination[:-1] has_add_op = True if len(node_combination[-1]) == 3 else False conv_node = graph_info[Helper.node_name_from_input(matched_node[0])].node weights_node_name = graph_info[Helper.node_name_from_input( matched_node[0])].node.input[1] weights_node = graph_info[Helper.node_name_from_input(weights_node_name)].node bn_node = graph_info[Helper.node_name_from_input(matched_node[-1])].node if weights_node.op != "Const": self.logger.warning("Didn't find expected conv Constant input to '%s'," " found %s instead. Maybe because freeze_graph wasn't" " run first?" 
% (bn_node.name, weights_node_name)) continue weights = Helper.values_from_const(weights_node) if conv_node.op == "Conv2D": channel_count = weights.shape[3] elif conv_node.op == "DepthwiseConv2dNative": channel_count = weights.shape[2] * weights.shape[3] mean_node_name = Helper.node_name_from_input( bn_node.input[self.INPUT_ORDER[bn_node.op].index("mean_op")]) mean_node = graph_info[mean_node_name].node if mean_node.op != "Const": continue mean_value = Helper.values_from_const(mean_node) if has_add_op: bias_node_name = graph_info[Helper.node_name_from_input( matched_node[1])].node.input[1] bias_node = graph_info[Helper.node_name_from_input(bias_node_name)].node if bias_node.op != "Const": continue if mean_value.shape != (channel_count, ): continue mean_value = mean_value - Helper.values_from_const(bias_node) cur_graph.remove_node(bias_node.name) cur_graph.remove_node(matched_node[1]) if mean_value.shape != (channel_count, ): self.logger.warning("Incorrect shape for mean, found %s, expected %s," " for node %s" % (str(mean_value.shape), str( (channel_count, )), conv_node.name)) continue var_node_name = Helper.node_name_from_input( bn_node.input[self.INPUT_ORDER[bn_node.op].index("var_op")]) var_node = graph_info[var_node_name].node if var_node.op != "Const": continue var_value = Helper.values_from_const(var_node) if var_value.shape != (channel_count, ): continue beta_node_name = Helper.node_name_from_input( bn_node.input[self.INPUT_ORDER[bn_node.op].index("beta_op")]) beta_node = graph_info[beta_node_name].node if beta_node.op != "Const": continue beta_value = Helper.values_from_const(beta_node) if beta_value.shape != (channel_count, ): continue gamma_node_name = Helper.node_name_from_input( bn_node.input[self.INPUT_ORDER[bn_node.op].index("gamma_op")]) gamma_node = graph_info[gamma_node_name].node if gamma_node.op != "Const": continue gamma_value = Helper.values_from_const(gamma_node) if gamma_value.shape != (channel_count, ): continue variance_epsilon_value = bn_node.attr[self.EPSILON_ATTR[bn_node.op]].f if self.scale_after_normalization(bn_node): scale_value = ( (1.0 / np.vectorize(math.sqrt)(var_value + variance_epsilon_value)) * gamma_value) else: scale_value = (1.0 / np.vectorize(math.sqrt)(var_value + variance_epsilon_value)) offset_value = (-mean_value * scale_value) + beta_value if conv_node.op == "Conv2D": original_shape =weights.shape tmp_shape = (original_shape[-1], int(weights.size/original_shape[-1])) tmp_order = [weights.ndim - 1] + [i for i in range(weights.ndim - 1)] scaled_weights = np.copy(weights).transpose(tmp_order).ravel().reshape(tmp_shape) reshape_scale = np.array(scale_value).reshape(len(scale_value), 1) scaled_weights = np.multiply( scaled_weights, reshape_scale).transpose().reshape(original_shape) elif conv_node.op == "DepthwiseConv2dNative": scaled_weights = np.copy(weights) it = np.nditer(scaled_weights, flags=["multi_index"], op_flags=["readwrite"]) channel_multiplier = weights.shape[3] while not it.finished: current_scale = scale_value[it.multi_index[2] * channel_multiplier + it.multi_index[3]] it[0] *= current_scale it.iternext() scaled_weights_node = node_def_pb2.NodeDef() scaled_weights_node.op = "Const" scaled_weights_node.name = weights_node_name + "_bn_offset" scaled_weights_node.attr["dtype"].CopyFrom(weights_node.attr["dtype"]) scaled_weights_node.attr["value"].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( scaled_weights, weights.dtype.type, weights.shape))) cur_graph.replace_const_node(scaled_weights_node, [conv_node.name], 
weights_node_name) offset_node = node_def_pb2.NodeDef() offset_node.op = "Const" offset_node.name = conv_node.name + "_bn_offset" offset_node.attr["dtype"].CopyFrom(mean_node.attr["dtype"]) offset_node.attr["value"].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( offset_value, mean_value.dtype.type, offset_value.shape))) bias_add_node = node_def_pb2.NodeDef() bias_add_node.op = "BiasAdd" bias_add_node.name = bn_node.name bias_add_node.attr["T"].CopyFrom(conv_node.attr["T"]) bias_add_node.attr["data_format"].CopyFrom(conv_node.attr["data_format"]) bias_add_node.input.extend([conv_node.name, offset_node.name]) cur_graph.add_node(offset_node, [], [bias_add_node.name]) cur_graph.add_node(bias_add_node, conv_node.name, graph_info[Helper.node_name_from_input(matched_node[-1])].outputs) cur_graph.replace_const_node(scaled_weights_node, [conv_node.name], weights_node_name) cur_graph.remove_node(weights_node_name) cur_graph.remove_node(mean_node_name) cur_graph.remove_node(var_node_name) cur_graph.remove_node(beta_node_name) cur_graph.remove_node(gamma_node_name) return cur_graph.dump_graph()
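# Illustrative sketch (not part of the transformation above) of the constants
# that do_transformation folds into the convolution. With
# scale_after_normalization enabled:
#   scale[c]  = gamma[c] / sqrt(var[c] + epsilon)
#   offset[c] = -mean[c] * scale[c] + beta[c]
# so BatchNorm(Conv(x, W)) == Conv(x, W * scale) + offset. The numbers here are
# made up purely to show the shapes involved.
import numpy as np

epsilon = 0.001
gamma = np.array([1.0, 0.5], dtype=np.float32)
beta = np.array([0.1, -0.1], dtype=np.float32)
mean = np.array([0.2, 0.3], dtype=np.float32)
var = np.array([1.5, 2.0], dtype=np.float32)

scale = gamma / np.sqrt(var + epsilon)
offset = -mean * scale + beta

# For a Conv2D weight tensor of shape [H, W, in_channels, out_channels] the
# scale broadcasts over the last (output-channel) axis, which is what the
# transpose/reshape steps above compute for the general case.
weights = np.ones((3, 3, 4, 2), dtype=np.float32)
scaled_weights = weights * scale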
def remove_training_nodes(input_graph, protected_nodes=None): """Prunes out nodes that aren't needed for inference. There are nodes like Identity and CheckNumerics that are only useful during training, and can be removed in graphs that will be used for nothing but inference. Here we identify and remove them, returning an equivalent graph. To be specific, CheckNumerics nodes are always removed, and Identity nodes that aren't involved in control edges are spliced out so that their input and outputs are directly connected. Args: input_graph: Model to analyze and prune. protected_nodes: An optional list of names of nodes to be kept unconditionally. This is for example useful to preserve Identity output nodes. Returns: A list of nodes with the unnecessary ones removed. """ if not protected_nodes: protected_nodes = [] types_to_remove = {"CheckNumerics": True} input_nodes = input_graph.node names_to_remove = {} for node in input_nodes: if node.op in types_to_remove and node.name not in protected_nodes: names_to_remove[node.name] = True nodes_after_removal = [] for node in input_nodes: if node.name in names_to_remove: continue new_node = node_def_pb2.NodeDef() new_node.CopyFrom(node) input_before_removal = node.input del new_node.input[:] for full_input_name in input_before_removal: input_name = re.sub(r"^\^", "", full_input_name) if input_name in names_to_remove: continue new_node.input.append(full_input_name) nodes_after_removal.append(new_node) types_to_splice = {"Identity": True} control_input_names = set() node_names_with_control_input = set() for node in nodes_after_removal: for node_input in node.input: if "^" in node_input: control_input_names.add(node_input.replace("^", "")) node_names_with_control_input.add(node.name) names_to_splice = {} for node in nodes_after_removal: if node.op in types_to_splice and node.name not in protected_nodes: # We don't want to remove nodes that have control edge inputs, because # they might be involved in subtle dependency issues that removing them # will jeopardize. if node.name not in node_names_with_control_input: names_to_splice[node.name] = node.input[0] # We also don't want to remove nodes which are used as control edge inputs. names_to_splice = { name: value for name, value in names_to_splice.items() if name not in control_input_names } nodes_after_splicing = [] for node in nodes_after_removal: if node.name in names_to_splice: continue new_node = node_def_pb2.NodeDef() new_node.CopyFrom(node) input_before_removal = node.input del new_node.input[:] for full_input_name in input_before_removal: input_name = re.sub(r"^\^", "", full_input_name) while input_name in names_to_splice: full_input_name = names_to_splice[input_name] input_name = re.sub(r"^\^", "", full_input_name) new_node.input.append(full_input_name) nodes_after_splicing.append(new_node) output_graph = graph_pb2.GraphDef() output_graph.node.extend(nodes_after_splicing) return output_graph
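# Hedged usage sketch for the remove_training_nodes variant above (the one with
# protected_nodes). The tiny graph and node names are made up: a Placeholder
# feeds an Identity that feeds a Relu, plus a dangling CheckNumerics debug node
# hanging off the Placeholder.
from tensorflow.core.framework import graph_pb2

_example_graph = graph_pb2.GraphDef()
_example_graph.node.add(name="x", op="Placeholder")
_example_graph.node.add(name="x_check", op="CheckNumerics", input=["x"])
_example_graph.node.add(name="x_id", op="Identity", input=["x"])
_example_graph.node.add(name="y", op="Relu", input=["x_id"])

_pruned = remove_training_nodes(_example_graph)
# Expected result: "x_check" is removed entirely and "x_id" is spliced out, so
# "y" now reads directly from "x". Passing protected_nodes=["x_id"] instead
# would keep the Identity node in place.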
def strip_unused(input_graph_def, input_tensor_names, output_tensor_names, placeholder_type_enum): """Removes unused nodes from a GraphDef. Args: input_graph_def: A graph with nodes we want to prune. input_tensor_names: A list of the nodes we use as inputs. output_tensor_names: A list of the output nodes. placeholder_type_enum: The AttrValue enum for the placeholder data type, or a list that specifies one value per input node name. Returns: A `GraphDef` with all unnecessary ops removed. and a map containing the old input names to the new input names Raises: ValueError: If any element in `input_node_names` refers to a tensor instead of an operation. KeyError: If any element in `input_node_names` is not found in the graph. """ for name in input_tensor_names: if ":" not in name: raise ValueError("Input '%s' appears to refer to a Operation, " "not a Tensor." % name) old2new = {} # Here we replace the nodes we're going to override as inputs with # placeholders so that any unused nodes that are inputs to them are # automatically stripped out by extract_sub_graph(). not_found = {name for name in input_tensor_names} input_node_names = {name.split(":")[0] for name in input_tensor_names} output_node_names = list( {name.split(":")[0] for name in output_tensor_names}) inputs_replaced_graph_def = graph_pb2.GraphDef() for node in input_graph_def.node: if node.name not in input_node_names: for i in range(len(node.input)): if _append_port(node.input[i]) in input_tensor_names: old_name = _append_port(node.input[i]) not_found.remove(old_name) new_input_name = node.input[i].replace(":", "_") placeholder_node = node_def_pb2.NodeDef() placeholder_node.op = "Placeholder" placeholder_node.name = new_input_name if isinstance(placeholder_type_enum, list): input_node_index = input_tensor_names.index(old_name) placeholder_node.attr["dtype"].CopyFrom( attr_value_pb2.AttrValue( type=placeholder_type_enum[input_node_index])) else: placeholder_node.attr["dtype"].CopyFrom( attr_value_pb2.AttrValue( type=placeholder_type_enum)) if "_output_shapes" in node.attr: placeholder_node.attr["_output_shapes"].CopyFrom( node.attr["_output_shapes"]) node.input[i] = new_input_name old2new[old_name] = new_input_name + ":0" inputs_replaced_graph_def.node.extend([placeholder_node]) inputs_replaced_graph_def.node.extend([copy.deepcopy(node)]) if not_found: raise KeyError("The following input nodes were not found: %s\n" % not_found) output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def, output_node_names) return output_graph_def, old2new
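# Hedged usage sketch for strip_unused above, reusing the _pruned GraphDef from
# the remove_training_nodes sketch (any graph with a node "x" feeding a node
# "y" would do). The requested input tensor is replaced by a float32
# Placeholder and everything not needed to compute "y:0" is pruned.
from tensorflow.core.framework import types_pb2

stripped_graph_def, old_to_new_inputs = strip_unused(
    _pruned,
    input_tensor_names=["x:0"],
    output_tensor_names=["y:0"],
    placeholder_type_enum=types_pb2.DT_FLOAT)
# old_to_new_inputs maps each requested input tensor name to the tensor name of
# the Placeholder that replaced it.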
def generate_output_graph(self, input_graph_def, input_node_map, fuse_op_name): output_graph_def = graph_pb2.GraphDef() skip_list = [] skip_node_name = [] for index, node in enumerate(input_graph_def.node): if node.name in fuse_op_name: skip_list.append(index + 1) original_node = input_node_map[node.name] mul_node = input_node_map[fuse_op_name[node.name]] weights_node_name = original_node.input[1] weights_node = input_node_map[weights_node_name] mul_value_node_name = mul_node.input[1] mul_value_node = input_node_map[mul_value_node_name] new_node = node_def_pb2.NodeDef() new_node.op = original_node.op new_node.name = mul_node.name for _, value in enumerate(node.input): new_node.input.append(value) if original_node.op == "DepthwiseConv2dNative": weights_col = weights_node.attr[ 'value'].tensor.tensor_shape.dim[ 2].size * weights_node.attr[ 'value'].tensor.tensor_shape.dim[3].size elif original_node.op == "Conv2D": weights_col = weights_node.attr[ 'value'].tensor.tensor_shape.dim[3].size else: weights_col = weights_node.attr[ 'value'].tensor.tensor_shape.dim[1].size mul_value_node_tensor = mul_value_node.attr['value'].tensor weights_node_tensor = weights_node.attr['value'].tensor if len(mul_value_node_tensor.tensor_shape.dim ) != 1 or mul_value_node_tensor.tensor_shape.dim[ 0].size != weights_col: print("Invalid Mul OP fusion.") mul_value_node_list = [ i for i in tensor_util.MakeNdarray( mul_value_node_tensor).flat ] new_weights = [] for index, i in enumerate( tensor_util.MakeNdarray(weights_node_tensor).flat): new_weights_value = i * mul_value_node_list[ index % len(mul_value_node_list)] new_weights.append(new_weights_value) weights_node.attr['value'].CopyFrom( attr_value_pb2. AttrValue(tensor=tensor_util.make_tensor_proto( new_weights, dtypes.float32, tensor_util.MakeNdarray(weights_node_tensor).shape))) skip_node_name.append(weights_node.name) output_graph_def.node.extend([weights_node]) for key in original_node.attr: new_node.attr[key].CopyFrom(original_node.attr[key]) output_graph_def.node.extend([new_node]) elif index in skip_list or node.name in skip_node_name: continue else: new_node = node_def_pb2.NodeDef() new_node.CopyFrom(node) output_graph_def.node.extend([new_node]) return output_graph_def
def eightbitize_nodes_recursively(self, current_node): if current_node.name in self.state.already_visited: if (self.should_merge_with_fake_quant_node() or current_node.name in self.state.merged_with_fake_quant): raise ValueError( "Unsupported graph structure: output of node %s " "is processed by a FakeQuant* node and should have " "no other outputs.", current_node.name) return self.state.already_visited[current_node.name] = True for i, input_node_name in enumerate(current_node.input): quantize_input = False if current_node.op in ("MatMul", "Conv2D", "BiasAdd", "MaxPool", "AvgPool", "Relu", "Relu6", "BatchNormWithGlobalNormalization"): quantize_input = True elif current_node.op == "Concat" and i > 0: quantize_input = (dtypes.as_dtype( current_node.attr["T"].type) == dtypes.float32) elif current_node.op == "Reshape" and i == 0: quantize_input = (dtypes.as_dtype( current_node.attr["T"].type) == dtypes.float32) self.state.output_node_stack.append( (current_node, i, quantize_input)) input_node_name = node_name_from_input(input_node_name) input_node = self.nodes_map[input_node_name] self.eightbitize_nodes_recursively(input_node) self.state.output_node_stack.pop() if current_node.op == "MatMul": self.eightbitize_mat_mul_node(current_node) elif current_node.op == "Conv2D": self.eightbitize_conv_node(current_node) elif current_node.op == "BiasAdd": self.eightbitize_bias_add_node(current_node) elif current_node.op == "MaxPool" or current_node.op == "AvgPool": self.eightbitize_single_input_tensor_node(current_node, self.add_pool_function) elif current_node.op == "Relu" or current_node.op == "Relu6": self.eightbitize_single_input_tensor_node(current_node, self.add_relu_function) elif (current_node.op == "Concat" and dtypes.as_dtype( current_node.attr["T"].type) == dtypes.float32): self.eightbitize_concat_node(current_node) elif current_node.op == "BatchNormWithGlobalNormalization": self.eightbitize_batch_norm_node(current_node) elif (current_node.op == "Reshape" and dtypes.as_dtype( current_node.attr["T"].type) == dtypes.float32): self.eightbitize_reshape_node(current_node) elif (self.input_range and current_node.op in ("Placeholder", "PlaceholderV2")): self.eightbitize_placeholder_node(current_node) elif current_node.op == "FakeQuantWithMinMaxVars": pass elif current_node.op == "Const": if self.should_quantize_const(current_node): for n in quantize_weight_eightbit(current_node, b"MIN_FIRST"): self.add_output_graph_node(n) else: new_node = node_def_pb2.NodeDef() new_node.CopyFrom(current_node) self.add_output_graph_node(new_node) else: new_node = node_def_pb2.NodeDef() new_node.CopyFrom(current_node) self.add_output_graph_node(new_node) if (self.should_merge_with_fake_quant_node() and current_node.name not in self.state.merged_with_fake_quant): raise ValueError( "FakeQuant* node %s failed to merge with node %s of type %s" % (self.state.output_node_stack[-1][0], current_node.name, current_node.op))
class TestGraph_util(unittest.TestCase): x_node = node_def_pb2.NodeDef() x_node.name = "placeholder" x_node.op = "Placeholder" input0_node = node_def_pb2.NodeDef() input0_node.name = "input0" input0_node.op = "Const" input0_value = np.float32(np.abs(np.random.randn(4, 3, 2))) input0_node.attr["value"].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( input0_value, input0_value.dtype.type, input0_value.shape))) input1_node = node_def_pb2.NodeDef() input1_node.name = "input1" input1_node.op = "Const" input1_value = np.float32(np.abs(np.random.randn(4, 1, 1))) input1_node.attr["value"].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( input1_value, input1_value.dtype.type, input1_value.shape))) add_node = node_def_pb2.NodeDef() add_node.op = "Add" add_node.name = "add" add_node.input.extend([input0_node.name, input1_node.name]) input2_node = node_def_pb2.NodeDef() input2_node.name = "input2" input2_node.op = "Const" input2_value = np.float32(np.abs(np.random.randn(1))) input2_node.attr["value"].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( input2_value, input2_value.dtype.type, input2_value.shape))) input3_node = node_def_pb2.NodeDef() input3_node.name = "input3" input3_node.op = "Const" input3_value = np.float32(np.abs(np.random.randn(1))) input3_node.attr["value"].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( input3_value, input3_value.dtype.type, input3_value.shape))) mul_node = node_def_pb2.NodeDef() mul_node.op = "Mul" mul_node.name = "mul" mul_node.input.extend([add_node.name, input3_node.name]) sqrt_node = node_def_pb2.NodeDef() sqrt_node.name = "rsqrt" sqrt_node.op = "Rsqrt" sqrt_node.input.extend([mul_node.name]) sqrt1_node = node_def_pb2.NodeDef() sqrt1_node.op = "Relu" sqrt1_node.name = "sqrt1" sqrt1_node.input.extend([sqrt_node.name]) block_node = node_def_pb2.NodeDef() block_node.name = "block_output" block_node.op = "Add" block_node.input.extend([x_node.name, sqrt1_node.name]) res_node = node_def_pb2.NodeDef() res_node.name = "res_add" res_node.op = "Add" res_node.input.extend([sqrt_node.name, input2_node.name]) end_node = node_def_pb2.NodeDef() end_node.name = "end" end_node.op = "Add" end_node.input.extend([block_node.name, res_node.name]) graph_def = graph_pb2.GraphDef() graph_def.node.extend([ x_node, input0_node, input1_node, input2_node, input3_node, add_node, mul_node, sqrt_node, sqrt1_node, block_node, res_node, end_node ]) def test_replace_constant_graph_with_constant_node(self): graph_analyzer = GraphAnalyzer() graph_analyzer.graph = copy.deepcopy(self.graph_def) graph_analyzer.parse_graph() new_constant_value = np.random.random([4, 1]) new_constant_type = tf.as_dtype(np.float32(new_constant_value).dtype) new_constant_node = GraphRewriterHelper.create_constant_node( self.add_node.name + "_const", new_constant_value, new_constant_type) assert graph_analyzer.replace_constant_graph_with_constant_node( new_constant_node, self.add_node.name) result_graph = graph_analyzer.dump_graph() assert len(list(result_graph.node)) == 10 new_constant_value = np.random.random([4, 1]) new_constant_type = tf.as_dtype(np.float32(new_constant_value).dtype) new_constant_node = GraphRewriterHelper.create_constant_node( self.mul_node.name + "_const", new_constant_value, new_constant_type) assert graph_analyzer.replace_constant_graph_with_constant_node( new_constant_node, self.mul_node.name) result_graph = graph_analyzer.dump_graph() assert len(list(result_graph.node)) == 8 new_constant_value = 
np.random.random([4, 1]) new_constant_type = tf.as_dtype(np.float32(new_constant_value).dtype) new_constant_node = GraphRewriterHelper.create_constant_node( self.sqrt_node.name + "_const", new_constant_value, new_constant_type) assert graph_analyzer.replace_constant_graph_with_constant_node( new_constant_node, self.sqrt_node.name) result_graph = graph_analyzer.dump_graph() assert len(list(result_graph.node)) == 7 new_constant_value = np.random.random([4, 1]) new_constant_type = tf.as_dtype(np.float32(new_constant_value).dtype) new_constant_node = GraphRewriterHelper.create_constant_node( self.block_node.name + "_const", new_constant_value, new_constant_type) assert not graph_analyzer.replace_constant_graph_with_constant_node( new_constant_node, self.block_node.name) def test_replace_node(self): graph_analyzer = GraphAnalyzer() graph_analyzer.graph = copy.deepcopy(self.graph_def) graph_analyzer.parse_graph() new_add_node = node_def_pb2.NodeDef() new_add_node.op = "Add" new_add_node.name = "add1" new_add_node.input.extend( [self.input0_node.name, self.input1_node.name]) graph_analyzer.replace_node(new_add_node, self.add_node.name, [self.mul_node.name]) result_graph = graph_analyzer.dump_graph() assert self.add_node not in list(result_graph.node) assert new_add_node in list(result_graph.node) def test_freeze_value_regrex(self): sample_str_1 = ';efficientnet-b3/model/blocks_14/se/conv2d/Conv2D_eightbit_requant_range__print__;__requant_min_max:[-2.35420851e+09][2.59383834e+09]' sample_str_2 = ';efficientnet-b3/model/blocks_15/se/conv2d/Conv2D_eightbit_requant_range__print__;__requant_min_max:[-1.254][2.59383834]' print_suffix = '__print__' postfix = '__requant_min_max' res_1 = re.search( r"{};{}:\[\-?\d+\.?\d*e?\+?\d*\]".format(print_suffix, postfix), sample_str_1) res_2 = re.search( r"{};{}:\[\-?\d+\.?\d*e?\+?\d*\]".format(print_suffix, postfix), sample_str_2) self.assertNotEqual(res_1, None) self.assertNotEqual(res_2, None)
def remove_training_nodes(input_graph): """Prunes out nodes that aren't needed for inference. There are nodes like Identity and CheckNumerics that are only useful during training, and can be removed in graphs that will be used for nothing but inference. Here we identify and remove them, returning an equivalent graph. To be specific, CheckNumerics nodes are always removed, and Identity nodes that aren't involved in control edges are spliced out so that their input and outputs are directly connected. Args: input_graph: Model to analyze and prune. Returns: A list of nodes with the unnecessary ones removed. """ types_to_remove = {"CheckNumerics": True} input_nodes = input_graph.node names_to_remove = {} for node in input_nodes: if node.op in types_to_remove: names_to_remove[node.name] = True nodes_after_removal = [] for node in input_nodes: if node.name in names_to_remove: continue new_node = node_def_pb2.NodeDef() new_node.CopyFrom(node) input_before_removal = node.input del new_node.input[:] for full_input_name in input_before_removal: input_name = re.sub(r"^\^", "", full_input_name) if input_name in names_to_remove: continue new_node.input.append(full_input_name) nodes_after_removal.append(new_node) types_to_splice = {"Identity": True} names_to_splice = {} for node in nodes_after_removal: if node.op in types_to_splice: # We don't want to remove nodes that have control edge inputs, because # they might be involved in subtle dependency issues that removing them # will jeopardize. has_control_edge = False for input_name in node.input: if re.match(r"^\^", input_name): has_control_edge = True if not has_control_edge: names_to_splice[node.name] = node.input[0] nodes_after_splicing = [] for node in nodes_after_removal: if node.name in names_to_splice: continue new_node = node_def_pb2.NodeDef() new_node.CopyFrom(node) input_before_removal = node.input del new_node.input[:] for full_input_name in input_before_removal: input_name = re.sub(r"^\^", "", full_input_name) if input_name in names_to_splice: new_node.input.append(names_to_splice[input_name]) else: new_node.input.append(full_input_name) nodes_after_splicing.append(new_node) output_graph = graph_pb2.GraphDef() output_graph.node.extend(nodes_after_splicing) return output_graph
def fold_batch_norms(input_graph_def): """Removes batch normalization ops by folding them into convolutions. Batch normalization during training has multiple dynamic parameters that are updated, but once the graph is finalized these become constants. That means there's an opportunity to reduce the computations down to a scale and addition, rather than the more expensive multiple ops, and even bake the scaling into the convolution weights. This function identifies the typical pattern of batch normalization subgraphs, and performs the transformation to fold the computations down into a simpler form. It currently only spots batch normalization that's performed by the BatchNormWithGlobalNormalization and FusedBatchNorm ops, and will need to be extended in the future to handle the newer style. Args: input_graph_def: A GraphDef containing a model. Returns: Modified graph with BN ops removed, and modified weights. Raises: ValueError: If the graph is badly formed with duplicate node names. """ input_node_map = {} for node in input_graph_def.node: if node.name not in input_node_map: input_node_map[node.name] = node else: raise ValueError("Duplicate node names detected for ", node.name) nodes_to_skip = {} new_ops = [] for node in input_graph_def.node: if node.op not in ("BatchNormWithGlobalNormalization", "FusedBatchNorm"): continue conv_op = node_from_map(input_node_map, node.input[INPUT_ORDER[node.op].index("conv_op")]) if conv_op.op != "Conv2D": tf_logging.warning( "Didn't find expected Conv2D input to '%s'" % node.name) continue weights_op = node_from_map(input_node_map, conv_op.input[1]) if weights_op.op != "Const": tf_logging.warning("Didn't find expected conv Constant input to '%s'," " found %s instead. Maybe because freeze_graph wasn't" " run first?" % (conv_op.name, weights_op)) continue weights = values_from_const(weights_op) channel_count = weights.shape[3] mean_op = node_from_map(input_node_map, node.input[INPUT_ORDER[node.op].index("mean_op")]) if mean_op.op != "Const": tf_logging.warning("Didn't find expected mean Constant input to '%s'," " found %s instead. Maybe because freeze_graph wasn't" " run first?" % (node.name, mean_op)) continue mean_value = values_from_const(mean_op) if mean_value.shape != (channel_count,): tf_logging.warning("Incorrect shape for mean, found %s, expected %s," " for node %s" % (str(mean_value.shape), str( (channel_count,)), node.name)) continue var_op = node_from_map(input_node_map, node.input[INPUT_ORDER[node.op].index("var_op")]) if var_op.op != "Const": tf_logging.warning("Didn't find expected var Constant input to '%s'," " found %s instead. Maybe because freeze_graph wasn't" " run first?" % (node.name, var_op)) continue var_value = values_from_const(var_op) if var_value.shape != (channel_count,): tf_logging.warning("Incorrect shape for var, found %s, expected %s," " for node %s" % (str(var_value.shape), str( (channel_count,)), node.name)) continue beta_op = node_from_map(input_node_map, node.input[INPUT_ORDER[node.op].index("beta_op")]) if beta_op.op != "Const": tf_logging.warning("Didn't find expected beta Constant input to '%s'," " found %s instead. Maybe because freeze_graph wasn't" " run first?" 
% (node.name, beta_op)) continue beta_value = values_from_const(beta_op) if beta_value.shape != (channel_count,): tf_logging.warning("Incorrect shape for beta, found %s, expected %s," " for node %s" % (str(beta_value.shape), str( (channel_count,)), node.name)) continue gamma_op = node_from_map(input_node_map, node.input[INPUT_ORDER[node.op].index("gamma_op")]) if gamma_op.op != "Const": tf_logging.warning("Didn't find expected gamma Constant input to '%s'," " found %s instead. Maybe because freeze_graph wasn't" " run first?" % (node.name, gamma_op)) continue gamma_value = values_from_const(gamma_op) if gamma_value.shape != (channel_count,): tf_logging.warning("Incorrect shape for gamma, found %s, expected %s," " for node %s" % (str(gamma_value.shape), str( (channel_count,)), node.name)) continue variance_epsilon_value = node.attr[EPSILON_ATTR[node.op]].f nodes_to_skip[node.name] = True nodes_to_skip[weights_op.name] = True nodes_to_skip[mean_op.name] = True nodes_to_skip[var_op.name] = True nodes_to_skip[beta_op.name] = True nodes_to_skip[gamma_op.name] = True nodes_to_skip[conv_op.name] = True if scale_after_normalization(node): scale_value = ( (1.0 / np.vectorize(math.sqrt)(var_value + variance_epsilon_value)) * gamma_value) else: scale_value = ( 1.0 / np.vectorize(math.sqrt)(var_value + variance_epsilon_value)) offset_value = (-mean_value * scale_value) + beta_value scaled_weights = np.copy(weights) it = np.nditer( scaled_weights, flags=["multi_index"], op_flags=["readwrite"]) while not it.finished: current_scale = scale_value[it.multi_index[3]] it[0] *= current_scale it.iternext() scaled_weights_op = node_def_pb2.NodeDef() scaled_weights_op.op = "Const" scaled_weights_op.name = weights_op.name scaled_weights_op.attr["dtype"].CopyFrom(weights_op.attr["dtype"]) scaled_weights_op.attr["value"].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( scaled_weights, weights.dtype.type, weights.shape))) new_conv_op = node_def_pb2.NodeDef() new_conv_op.CopyFrom(conv_op) offset_op = node_def_pb2.NodeDef() offset_op.op = "Const" offset_op.name = conv_op.name + "_bn_offset" offset_op.attr["dtype"].CopyFrom(mean_op.attr["dtype"]) offset_op.attr["value"].CopyFrom( attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( offset_value, mean_value.dtype.type, offset_value.shape))) bias_add_op = node_def_pb2.NodeDef() bias_add_op.op = "BiasAdd" bias_add_op.name = node.name bias_add_op.attr["T"].CopyFrom(conv_op.attr["T"]) bias_add_op.attr["data_format"].CopyFrom(conv_op.attr["data_format"]) bias_add_op.input.extend([new_conv_op.name, offset_op.name]) new_ops.extend([scaled_weights_op, new_conv_op, offset_op, bias_add_op]) result_graph_def = graph_pb2.GraphDef() for node in input_graph_def.node: if node.name in nodes_to_skip: continue new_node = node_def_pb2.NodeDef() new_node.CopyFrom(node) result_graph_def.node.extend([new_node]) result_graph_def.node.extend(new_ops) return result_graph_def
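# Small numeric check (illustrative only) of the identity that fold_batch_norms
# relies on: batch-normalizing a convolution output equals convolving with
# scaled weights and adding a constant offset. A 1x1 convolution over a single
# pixel reduces to a matrix multiply, which keeps the check tiny. All values
# are made up.
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 4)).astype(np.float32)   # one pixel, 4 input channels
w = rng.standard_normal((4, 2)).astype(np.float32)   # 1x1 conv == matmul, 2 output channels
gamma = np.array([1.2, 0.7], dtype=np.float32)
beta = np.array([0.1, -0.3], dtype=np.float32)
mean = np.array([0.05, -0.02], dtype=np.float32)
var = np.array([0.9, 1.1], dtype=np.float32)
eps = 0.001

conv = x @ w
bn_out = (conv - mean) / np.sqrt(var + eps) * gamma + beta

scale = gamma / np.sqrt(var + eps)
offset = -mean * scale + beta
folded_out = x @ (w * scale) + offset

assert np.allclose(bn_out, folded_out, atol=1e-5)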
def generate_output_graph(self, input_graph_def, input_node_map, fuse_op_list): output_graph_def = graph_pb2.GraphDef() skip_list = [] skip_node_name = [] float32_type = dtypes.float32.as_datatype_enum for index, node in enumerate(input_graph_def.node): if index in fuse_op_list: input_node = input_node_map[node.input[0]] if input_node.op == 'QuantizeV2': new_node = node_def_pb2.NodeDef() new_node.op = node.op + "AndDequantize" for _, value in enumerate(node.input): new_node.input.append(value) dequantize_node = input_graph_def.node[index + 4] frozen_max_node = input_graph_def.node[index + 2] frozen_min_node = input_graph_def.node[index + 1] new_node.name = dequantize_node.name new_node.input.append(frozen_min_node.name) new_node.input.append(frozen_max_node.name) new_node.attr["T1"].CopyFrom(node.attr['T1']) new_node.attr["T2"].CopyFrom(node.attr['T2']) new_node.attr["Tbias"].CopyFrom( attr_value_pb2.AttrValue(type=float32_type)) new_node.attr["Toutput"].CopyFrom( attr_value_pb2.AttrValue(type=float32_type)) skip_list.append(index + 1) skip_list.append(index + 2) skip_list.append(index + 3) skip_list.append(index + 4) output_graph_def.node.extend( [new_node, frozen_max_node, frozen_min_node]) elif input_node.op == "Requantize": new_node = node_def_pb2.NodeDef() new_node.op = node.op + "AndDequantize" for _, value in enumerate(node.input): new_node.input.append(value) dequantize_node = input_graph_def.node[index + 4] frozen_max_node = input_graph_def.node[index + 2] frozen_min_node = input_graph_def.node[index + 1] new_node.name = dequantize_node.name skip_list.append(index + 1) skip_list.append(index + 2) skip_list.append(index + 3) skip_list.append(index + 4) new_node.input.append(frozen_min_node.name) new_node.input.append(frozen_max_node.name) new_node.attr["T1"].CopyFrom(node.attr['T1']) new_node.attr["T2"].CopyFrom(node.attr['T2']) new_node.attr["Tbias"].CopyFrom( attr_value_pb2.AttrValue(type=float32_type)) new_node.attr["Toutput"].CopyFrom( attr_value_pb2.AttrValue(type=float32_type)) output_graph_def.node.extend( [new_node, frozen_max_node, frozen_min_node]) else: new_node = node_def_pb2.NodeDef() new_node.CopyFrom(node) output_graph_def.node.extend([new_node]) elif index in skip_list or node.name in skip_node_name: continue else: new_node = node_def_pb2.NodeDef() new_node.CopyFrom(node) output_graph_def.node.extend([new_node]) return output_graph_def
graph_filename = 'mrt_graph_1.pb'
graph_filename_converted = 'mrt_graph_2.pb'

# Load the frozen graph from the .pb file into a GraphDef object.
f = gfile.FastGFile(graph_filename, 'rb')
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
f.close()

# Define a new, empty graph and an empty image placeholder node for later use.
modified_graph_def = graph_pb2.GraphDef()
image_placeholder_node = node_def_pb2.NodeDef()

# Iterate through all nodes in the graph and print the 'vars/Cast' node for inspection.
for node in graph_def.node:
    if node.name == 'vars/Cast':
        print(node)

# Iterate through all nodes in the graph and set the dtype attribute of the
# 'imagePlaceholder' node to int32.
for node in graph_def.node:
    if node.name == 'imagePlaceholder':
        # print("found image placeholder")
        node.attr["dtype"].CopyFrom(
            attr_value_pb2.AttrValue(type=dtypes.int32.as_datatype_enum))
def convert_variables_to_constants(sess, input_graph_def, output_node_names, variable_names_whitelist=None, variable_names_blacklist=None, use_fp16=False): from tensorflow.python.framework.graph_util_impl import extract_sub_graph from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import node_def_pb2 from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import types_pb2 from tensorflow.python.framework import tensor_util def patch_dtype(input_node, field_name, output_node): if use_fp16 and (field_name in input_node.attr) and (input_node.attr[field_name].type == types_pb2.DT_FLOAT): output_node.attr[field_name].CopyFrom(attr_value_pb2.AttrValue(type=types_pb2.DT_HALF)) inference_graph = extract_sub_graph(input_graph_def, output_node_names) variable_names = [] variable_dict_names = [] for node in inference_graph.node: if node.op in ["Variable", "VariableV2", "VarHandleOp"]: variable_name = node.name if ((variable_names_whitelist is not None and variable_name not in variable_names_whitelist) or (variable_names_blacklist is not None and variable_name in variable_names_blacklist)): continue variable_dict_names.append(variable_name) if node.op == "VarHandleOp": variable_names.append(variable_name + "/Read/ReadVariableOp:0") else: variable_names.append(variable_name + ":0") if variable_names: returned_variables = sess.run(variable_names) else: returned_variables = [] found_variables = dict(zip(variable_dict_names, returned_variables)) output_graph_def = graph_pb2.GraphDef() how_many_converted = 0 for input_node in inference_graph.node: output_node = node_def_pb2.NodeDef() if input_node.name in found_variables: output_node.op = "Const" output_node.name = input_node.name dtype = input_node.attr["dtype"] data = found_variables[input_node.name] if use_fp16 and dtype.type == types_pb2.DT_FLOAT: output_node.attr["value"].CopyFrom( attr_value_pb2.AttrValue( tensor=tensor_util.make_tensor_proto(data.astype('float16'), dtype=types_pb2.DT_HALF, shape=data.shape))) else: output_node.attr["dtype"].CopyFrom(dtype) output_node.attr["value"].CopyFrom(attr_value_pb2.AttrValue( tensor=tensor_util.make_tensor_proto(data, dtype=dtype.type, shape=data.shape))) how_many_converted += 1 elif input_node.op == "ReadVariableOp" and (input_node.input[0] in found_variables): # placeholder nodes # print('- %s | %s ' % (input_node.name, input_node.attr["dtype"])) output_node.op = "Identity" output_node.name = input_node.name output_node.input.extend([input_node.input[0]]) output_node.attr["T"].CopyFrom(input_node.attr["dtype"]) if "_class" in input_node.attr: output_node.attr["_class"].CopyFrom(input_node.attr["_class"]) else: # mostly op nodes output_node.CopyFrom(input_node) patch_dtype(input_node, 'dtype', output_node) patch_dtype(input_node, 'T', output_node) patch_dtype(input_node, 'DstT', output_node) patch_dtype(input_node, 'SrcT', output_node) patch_dtype(input_node, 'Tparams', output_node) if use_fp16 and ('value' in output_node.attr) and ( output_node.attr['value'].tensor.dtype == types_pb2.DT_FLOAT): # hard-coded value need to be converted as well output_node.attr['value'].CopyFrom(attr_value_pb2.AttrValue( tensor=tensor_util.make_tensor_proto( output_node.attr['value'].tensor.float_val[0], dtype=types_pb2.DT_HALF))) output_graph_def.node.extend([output_node]) output_graph_def.library.CopyFrom(inference_graph.library) return output_graph_def
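# Hedged usage sketch for the fp16-aware convert_variables_to_constants above.
# The small model here is hypothetical; any TF1-style graph with variables and
# a named output node would do.
import tensorflow as tf

_g = tf.Graph()
with _g.as_default():
    x = tf.compat.v1.placeholder(tf.float32, [None, 4], name="x")
    w = tf.compat.v1.get_variable("w", shape=[4, 2], dtype=tf.float32)
    y = tf.identity(tf.matmul(x, w), name="y")
    init_op = tf.compat.v1.global_variables_initializer()

with tf.compat.v1.Session(graph=_g) as sess:
    sess.run(init_op)
    frozen_graph = convert_variables_to_constants(
        sess, _g.as_graph_def(), ["y"], use_fp16=True)
# frozen_graph now contains Const nodes in place of the variables; with
# use_fp16=True, float32 weight values and float dtype attrs are rewritten to
# DT_HALF by the patch_dtype helper.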
def _StripNode(self, nd): snode = node_def_pb2.NodeDef(name=nd.name, op=nd.op, input=nd.input) if nd.device: snode.device = nd.device return snode
def apply_conv_single_fusion(self, match_node_name): skip_node_name = match_node_name[1:] matched_node = self.node_name_mapping[match_node_name[0]] _, normal_inputs = self._get_node_input(matched_node.node.name) weight_name = normal_inputs[1] # TODO this is workaround as the tf 2.1 doesn't support depthwise/conv s8 # feature. if self.enable_s8 and not self._find_relu_node(matched_node.node): self.output_graph = self.input_graph return self._intel_cpu_quantize_weight_eightbit( matched_node.node.op, self.node_name_mapping[weight_name].node, self.per_channel) all_input_names = self._add_eightbit_prologue_nodes( matched_node.node.name) skip_node_name.append(weight_name) for _, node in enumerate(self.input_graph.node): if node.name in skip_node_name: self.logger.debug("skip node {}".format(node.name)) elif node.name == match_node_name[0]: postfix = "_eightbit_quantized_conv" if node.op == "Conv2D" else "_eightbit_quantized_depthwise_conv" quantized_node_name = node.name + postfix if node.op == "Conv2D": quantized_conv_node = helper.create_node( "QuantizedConv2DPerChannel" if self.per_channel else "QuantizedConv2D", quantized_node_name, all_input_names) elif node.op == "DepthwiseConv2dNative": quantized_conv_node = helper.create_node( "QuantizedDepthwiseConv2D", quantized_node_name, all_input_names) helper.copy_attr(quantized_conv_node, "strides", node.attr["strides"]) helper.copy_attr(quantized_conv_node, "padding", node.attr["padding"]) if node.op != 'DepthwiseConv2dNative' and "padding_list" in node.attr: helper.copy_attr(quantized_conv_node, "padding_list", node.attr["padding_list"]) helper.copy_attr(quantized_conv_node, "dilations", node.attr["dilations"]) input_data_type = dtypes.quint8 if self._find_relu_node( node) else dtypes.qint8 helper.set_attr_dtype(quantized_conv_node, "Tinput", input_data_type) helper.set_attr_dtype(quantized_conv_node, "Tfilter", dtypes.qint8) helper.set_attr_dtype(quantized_conv_node, "out_type", dtypes.qint32) self.add_output_graph_node(quantized_conv_node) quantize_down_name = self._add_quantize_down_nodes( node, quantized_node_name, dtypes.qint8) self._intel_cpu_add_dequantize_result_node( quantize_down_name, node.name, dtypes.qint8) else: new_node = node_def_pb2.NodeDef() new_node.CopyFrom(node) self.add_output_graph_node(new_node)