def apply_final_node_renames(self):
    """Rebuild self.output_graph with self.final_node_renames applied.

    The current output graph is set aside and replaced with an empty
    GraphDef. Every node of the previous graph is then renamed (when its
    name is a key of self.final_node_renames), has each input that refers
    to a renamed node rewritten, and is handed to
    self.add_output_graph_node.

    Note that the nodes of the previous graph are modified in place before
    being passed on.

    Returns:
      self.output_graph after all nodes have been re-added.
    """
    previous_graph = self.output_graph
    self.output_graph = graph_pb2.GraphDef()

    for node in previous_graph.node:
        # Falls back to the current name when no rename is registered.
        node.name = self.final_node_renames.get(node.name, node.name)

        for position, raw_input in enumerate(node.input):
            producer_name = helper.node_name_from_input(raw_input)
            port_qualified = helper.ensure_tensor_name_has_port(raw_input)
            if producer_name not in self.final_node_renames:
                continue
            # Keep whatever follows the first len(producer_name) characters
            # of the port-qualified input and attach it to the new name.
            # NOTE(review): this assumes the port-qualified input starts
            # with the bare node name - confirm against the helper module.
            trailing = port_qualified[len(producer_name):]
            node.input[position] = "%s%s" % (
                self.final_node_renames[producer_name], trailing)

        self.add_output_graph_node(node)

    return self.output_graph
def remove_redundant_quantization(self, old_graph):
    """Rewire consumers around Dequantize -> Quantize pairs that cancel out.

    A Quantize/QuantizeV2 node is treated as redundant when:
      * its first input is produced by a Dequantize node with an equal "T"
        attribute,
      * its second and third inputs come from a Min (or Dequantize) and a
        Max (or Dequantize) node respectively, and
      * those two nodes read from the same source, either directly or with
        the min side going through a Concat whose second input is a Min
        reading from that source.

    For each such node, consumers of its three outputs are redirected to
    the three inputs of the Dequantize node. All nodes of old_graph,
    including the bypassed ones, are still passed to
    self.add_output_graph_node.

    Args:
      old_graph: Graph whose nodes are scanned. Its nodes' inputs are
        rewritten in place.

    Returns:
      self.output_graph, freshly rebuilt from the rewired nodes.

    Raises:
      ValueError: If the node feeding the first input of a Quantize node
        is missing from the node map. The other node-map lookups are not
        guarded.
    """
    nodes_by_name = self.create_nodes_map(old_graph)
    self.output_graph = graph_pb2.GraphDef()
    rewires = {}

    # Pass 1: find every Quantize node that merely undoes a Dequantize.
    for node in old_graph.node:
        if node.op not in ("Quantize", "QuantizeV2"):
            continue

        producer_name = helper.node_name_from_input(node.input[0])
        if producer_name not in nodes_by_name:
            raise ValueError("Input node name '" + producer_name +
                             "' not found in node '" + node.name + "'")
        producer = nodes_by_name[producer_name]

        # The feeding node has to be a Dequantize of the same type.
        is_dequantize = producer.op == "Dequantize"
        if not is_dequantize or node.attr["T"] != producer.attr["T"]:
            continue

        # The range inputs must come from Min/Max (or Dequantize) nodes.
        low_name = helper.node_name_from_input(node.input[1])
        high_name = helper.node_name_from_input(node.input[2])
        low_node = nodes_by_name[low_name]
        high_node = nodes_by_name[high_name]
        low_ok = low_node.op in ("Min", "Dequantize")
        high_ok = high_node.op in ("Max", "Dequantize")
        if not (low_ok and high_ok):
            print("Didn't find expected types on inputs : %s, %s." %
                  (low_node.op, high_node.op))
            continue

        low_source = helper.node_name_from_input(low_node.input[0])
        high_source = helper.node_name_from_input(high_node.input[0])

        # Either both range nodes read the same source directly, or the
        # min side reads a Concat whose second input is a Min on that
        # source.
        shares_source = low_source == high_source
        if not shares_source:
            upstream = nodes_by_name[low_source]
            if upstream.op == "Concat":
                inner_name = helper.node_name_from_input(upstream.input[1])
                inner = nodes_by_name[inner_name]
                if inner.op == "Min":
                    inner_source = helper.node_name_from_input(
                        inner.input[0])
                    shares_source = inner_source == high_source
        if not shares_source:
            print("Different min/max inputs: " + low_source)
            continue

        # Pattern matched: the pair is a no-op, so record where each of
        # the Quantize node's three outputs should be read from instead.
        bypass_source = helper.node_name_from_input(producer.input[0])
        quantized_output = helper.ensure_tensor_name_has_port(node.name)
        rewires[quantized_output] = bypass_source
        rewires[node.name + ":1"] = producer.input[1]
        rewires[node.name + ":2"] = producer.input[2]

    # Pass 2: apply the recorded rewiring and emit every node.
    for node in old_graph.node:
        for slot, raw_input in enumerate(node.input):
            lookup_key = helper.ensure_tensor_name_has_port(raw_input)
            if lookup_key not in rewires:
                continue
            node.input[slot] = rewires[lookup_key]
        self.add_output_graph_node(node)

    return self.output_graph