Ejemplo n.º 1
0
 def apply_final_node_renames(self):
     """Rebuilds self.output_graph with self.final_node_renames applied.

     The current output graph is set aside and replaced by an empty GraphDef.
     Each node of the previous graph then has its own name, and the node-name
     part of every one of its inputs, mapped through self.final_node_renames
     before being handed to self.add_output_graph_node. Nodes are modified in
     place rather than copied.

     Returns:
       self.output_graph, after all nodes have been re-added.
     """
     previous_graph = self.output_graph
     self.output_graph = graph_pb2.GraphDef()
     for node in previous_graph.node:
         node.name = self.final_node_renames.get(node.name, node.name)
         for position, raw_input in enumerate(node.input):
             source_name = helper.node_name_from_input(raw_input)
             qualified_input = helper.ensure_tensor_name_has_port(raw_input)
             if source_name not in self.final_node_renames:
                 continue
             replacement = self.final_node_renames[source_name]
             # Everything after the first len(source_name) characters of the
             # port-qualified input is carried over onto the new name.
             # NOTE(review): this presumes source_name is a literal prefix of
             # the input string; confirm for inputs where
             # helper.node_name_from_input strips leading characters.
             suffix = qualified_input[len(source_name):]
             node.input[position] = "%s%s" % (replacement, suffix)
         self.add_output_graph_node(node)
     return self.output_graph
Ejemplo n.º 2
0
    def remove_redundant_quantization(self, old_graph):
        """Rewires consumers around Dequantize -> Quantize pairs that cancel out.

        A "Quantize" or "QuantizeV2" node is treated as a no-op when all of the
        following hold:
          * its first input is a "Dequantize" node with an equal "T" attribute;
          * its second input is a "Min" or "Dequantize" node and its third input
            is a "Max" or "Dequantize" node;
          * the first inputs of those two range nodes name the same node, or the
            min-side first input is a "Concat" whose second input is a "Min"
            whose first input names the same node as the max side.

        For every match, references to the Quantize node's outputs are
        redirected: the port-qualified node name (as returned by
        helper.ensure_tensor_name_has_port) to the node feeding the Dequantize,
        and the ":1" / ":2" outputs to the Dequantize's second / third inputs.
        The matched nodes themselves are still passed to
        self.add_output_graph_node; only input references are rewritten.

        Args:
          old_graph: Graph whose `node` entries are scanned. Node inputs are
            rewritten in place.

        Returns:
          self.output_graph, a fresh GraphDef populated with every node of
          old_graph after rewiring.

        Raises:
          ValueError: If the first input of a Quantize node names a node that
            is not present in the map built by self.create_nodes_map. Lookups
            of the other referenced nodes are not guarded.
        """
        nodes_by_name = self.create_nodes_map(old_graph)
        self.output_graph = graph_pb2.GraphDef()
        rewired_inputs = {}

        # Pass 1: collect the tensor names that should be routed elsewhere.
        for quantize_node in old_graph.node:
            # Every recognised pattern is anchored on a Quantize node.
            if quantize_node.op not in ("Quantize", "QuantizeV2"):
                continue

            dequantize_name = helper.node_name_from_input(
                quantize_node.input[0])
            if dequantize_name not in nodes_by_name:
                raise ValueError("Input node name '" + dequantize_name +
                                 "' not found in node '" + quantize_node.name +
                                 "'")
            dequantize_node = nodes_by_name[dequantize_name]

            # The producer must be a Dequantize of the same type as the Quantize.
            if (dequantize_node.op != "Dequantize" or
                    quantize_node.attr["T"] != dequantize_node.attr["T"]):
                continue

            # The range inputs have to come from Min/Max (or Dequantize) nodes.
            min_name = helper.node_name_from_input(quantize_node.input[1])
            max_name = helper.node_name_from_input(quantize_node.input[2])
            min_node, max_node = nodes_by_name[min_name], nodes_by_name[max_name]
            ops_look_right = (min_node.op in ("Min", "Dequantize") and
                              max_node.op in ("Max", "Dequantize"))
            if not ops_look_right:
                print("Didn't find expected types on inputs : %s, %s." %
                      (min_node.op, max_node.op))
                continue

            min_source = helper.node_name_from_input(min_node.input[0])
            max_source = helper.node_name_from_input(max_node.input[0])

            # Either both range nodes read the same node directly, or the min
            # side goes through Concat -> Min before reaching that node.
            feeds_match = min_source == max_source
            if not feeds_match:
                upstream = nodes_by_name[min_source]
                if upstream.op == "Concat":
                    inner_name = helper.node_name_from_input(upstream.input[1])
                    inner_node = nodes_by_name[inner_name]
                    if inner_node.op == "Min":
                        inner_source = helper.node_name_from_input(
                            inner_node.input[0])
                        feeds_match = (inner_source == max_source)
            if not feeds_match:
                print("Different min/max inputs: " + min_source)
                continue

            # Pattern recognised: record how each output of the Quantize node
            # should be bypassed.
            bypass_source = helper.node_name_from_input(
                dequantize_node.input[0])
            output_tensor = helper.ensure_tensor_name_has_port(
                quantize_node.name)
            rewired_inputs[output_tensor] = bypass_source
            rewired_inputs[quantize_node.name + ":1"] = dequantize_node.input[1]
            rewired_inputs[quantize_node.name + ":2"] = dequantize_node.input[2]

        # Pass 2: apply the recorded rewiring and copy every node across.
        for node in old_graph.node:
            for position, raw_input in enumerate(node.input):
                qualified = helper.ensure_tensor_name_has_port(raw_input)
                if qualified in rewired_inputs:
                    node.input[position] = rewired_inputs[qualified]
            self.add_output_graph_node(node)
        return self.output_graph