Example #1
 def _find_relu_node(self, node):
     """Return True if `node` is, fuses, or is transitively fed by a Relu/Relu6."""
     if node.op in ("Relu", "Relu6") or node.op.find("AndRelu") != -1:
         return True
     elif (node.op.find("QuantizedConv") != -1
           or node.op.find("QuantizedDepthwiseConv") != -1
           ) and node.op.find("Relu") == -1:
         return False
     elif self._need_to_check(node.op):
         input_node = self.node_name_mapping[helper.node_name_from_input(
             node.input[0])]
         return self._find_relu_node(input_node.node)
     else:
         return False
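The recursion above terminates through `_need_to_check`, which is not shown in this example. It presumably whitelists pass-through ops whose upstream activation pattern still applies; a minimal sketch under that assumption (the op list is hypothetical, not from the source):

def _need_to_check(self, op):
    # Hypothetical whitelist of ops that pass values through closely enough
    # that a Relu further upstream still guarantees non-negative outputs.
    return op in ("MaxPool", "AvgPool", "Reshape", "Identity")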
Example #2
 def apply_final_node_renames(self):
     """Applies node renames in self.final_node_renames to self.output_graph."""
     old_graph = self.output_graph
     self.output_graph = graph_pb2.GraphDef()
     for node in old_graph.node:
         node.name = self.final_node_renames.get(node.name, node.name)
         for index, input_name in enumerate(node.input):
             node_name = helper.node_name_from_input(input_name)
             input_full_name = helper.ensure_tensor_name_has_port(
                 input_name)
             if node_name in self.final_node_renames:
                 node.input[index] = "%s%s" % (
                     self.final_node_renames[node_name],
                     input_full_name[len(node_name):])
         self.add_output_graph_node(node)
     return self.output_graph
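This rename pass, and the examples below, lean on two `helper` functions for separating node names from tensor ports. Their bodies are not shown; a standalone sketch of the semantics they are assumed to have (mirroring the helpers in TensorFlow's quantize_graph tooling):

def node_name_from_input(input_name):
    # Strip a leading control-dependency marker ("^node") and any ":port".
    if input_name.startswith("^"):
        input_name = input_name[1:]
    return input_name.split(":")[0]

def ensure_tensor_name_has_port(input_name):
    # "node" -> "node:0"; "node:1" is returned unchanged.
    return input_name if ":" in input_name else input_name + ":0"

Under these semantics, renaming node "old" to "new" rewrites an input "old:1" as "new:1": the slicing in apply_final_node_renames preserves the port suffix.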
Example #3
    def _add_eightbit_prologue_nodes(self, original_node):
        """Quantizes each non-control input of `original_node`; returns the new input list."""
        namespace_prefix = original_node + "_eightbit"
        reshape_dims_name, reduction_dims_name = self._add_common_quantization_nodes(
            namespace_prefix,
            self.node_name_mapping[original_node].node.input[0])
        input_names = []
        min_max_names = []
        for each_input_name in self.node_name_mapping[
                original_node].node.input:
            if each_input_name[0] == '^':
                continue
            input_node_name = helper.node_name_from_input(each_input_name)
            if self.intel_cpu_eightbitize and input_node_name in self.output_node_maps:
                # Reuse the upstream Dequantize's output dtype when one feeds
                # this input; otherwise default to quint8.
                if self.output_node_maps[input_node_name].op == "Dequantize":
                    dtype = dtypes.DType(
                        self.output_node_maps[input_node_name].attr["T"].type)
                else:
                    dtype = dtypes.quint8
            else:
                dtype = dtypes.quint8 if self._find_relu_node(
                    self.node_name_mapping[original_node].node
                ) else dtypes.qint8

            quantize_input_name, min_input_name, max_input_name = (
                self._eightbitize_input_to_node(namespace_prefix,
                                                each_input_name,
                                                reshape_dims_name,
                                                reduction_dims_name,
                                                dtype=dtype))
            input_names.append(quantize_input_name)
            min_max_names.append(min_input_name)
            min_max_names.append(max_input_name)
        all_input_names = []
        all_input_names.extend(input_names)
        all_input_names.extend(min_max_names)

        for original_input_name in self.node_name_mapping[
                original_node].node.input:
            if original_input_name[0] == '^':
                all_input_names.append(original_input_name)
        return all_input_names
Example #4
    def _parse_graph(self, input_graph=None):
        """
        Parse the graph and record each node's input and output relationships.
        """
        logging.debug("start parsing graph")
        self.node_name_mapping = OrderedDict()

        graph = self.input_graph if input_graph is None else input_graph
        for node in graph.node:
            each_node = self.node_details(node=node, input_node=[], output=[])

            if node.name in self.node_name_mapping:
                raise ValueError(
                    "Duplicate Node Found when _parse_graph, the node name is {}"
                    .format(node.name))

            self.node_name_mapping[node.name] = each_node

        for node in graph.node:
            for input_name in node.input:
                self.node_name_mapping[helper.node_name_from_input(
                    input_name)].output.append(node.name)
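`self.node_details` is used here as a small record bundling a NodeDef with its inputs and consumers. A plausible definition, assumed rather than taken from the source:

from collections import namedtuple

node_details = namedtuple("node_details", ["node", "input_node", "output"])

With it, `self.node_name_mapping[name].output` ends up holding the name of every node that consumes `name`'s output.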
Example #5
 def _apply_concatv2_transform(self, original_node):
     """Replaces a ConcatV2 node with QuantizedConcatV2 plus quantize/dequantize glue."""
     namespace_prefix = original_node.name + "_eightbit"
     quantized_concat_name = namespace_prefix + "_quantized_concatv2"
     reshape_dims_name, reduction_dims_name = self._add_common_quantization_nodes(
         namespace_prefix,
         helper.node_name_from_input(original_node.input[-1]))
     num_input = len(original_node.input)
     shape_input_name = original_node.input[num_input - 1]
     original_inputs = original_node.input[0:num_input - 1]
     input_names = []
     min_names = []
     max_names = []
     for original_input_name in original_inputs:
         quantize_input_name, min_input_name, max_input_name = (
             self._eightbitize_input_to_node(namespace_prefix,
                                             original_input_name,
                                             reshape_dims_name,
                                             reduction_dims_name,
                                             dtype=dtypes.quint8))
         input_names.append(quantize_input_name)
         min_names.append(min_input_name)
         max_names.append(max_input_name)
     all_input_names = list(input_names)
     all_input_names.append(shape_input_name)
     all_input_names.extend(min_names)
     all_input_names.extend(max_names)
     quantized_concat_node = helper.create_node("QuantizedConcatV2",
                                                quantized_concat_name,
                                                all_input_names)
     helper.set_attr_int(quantized_concat_node, "N", len(original_inputs))
     helper.set_attr_dtype(quantized_concat_node, "T", dtypes.quint8)
     self.add_output_graph_node(quantized_concat_node)
     if self.intel_cpu_eightbitize:
         self._intel_cpu_add_dequantize_result_node(quantized_concat_name,
                                                    original_node.name)
     else:
         self._add_dequantize_result_node(quantized_concat_name,
                                          original_node.name)
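The input list assembled above matches QuantizedConcatV2's expected ordering: the N quantized values, then the concat axis (ConcatV2 keeps the axis as its last input), then the N minimums and the N maximums. Purely illustrative, with hypothetical names, for a two-input concat:

values = ["a_eightbit_quantize", "b_eightbit_quantize"]
mins = ["a_eightbit_min", "b_eightbit_min"]
maxes = ["a_eightbit_max", "b_eightbit_max"]
all_inputs = values + ["axis"] + mins + maxes
# -> values first, then the axis, then all mins, then all maxes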
Example #6
    def remove_redundant_quantization(self, old_graph):
        """Rewires the graph around no-op Dequantize -> Quantize pairs."""
        old_nodes_map = self.create_nodes_map(old_graph)
        self.output_graph = graph_pb2.GraphDef()
        inputs_to_rename = {}
        # We go through all the nodes, looking for any that match the patterns we
        # know how to optimize away.
        for node in old_graph.node:
            # We always start with a Quantize node, and examine its inputs to see if
            # they are in a form that can be removed.
            if node.op not in ["Quantize", "QuantizeV2"]:
                continue

            dequantize_node_name = helper.node_name_from_input(node.input[0])
            if dequantize_node_name not in old_nodes_map:
                raise ValueError("Input node name '" + dequantize_node_name +
                                 "' not found in node '" + node.name + "'")
            dequantize_node = old_nodes_map[dequantize_node_name]
            # Do we have a Dequantize feeding in, with the same type as the Quantize?
            if dequantize_node.op != "Dequantize":
                continue

            if node.attr["T"] != dequantize_node.attr["T"]:
                continue

            # Now look at the other inputs, and ensure they're Min/Max nodes.
            min_node_name = helper.node_name_from_input(node.input[1])
            max_node_name = helper.node_name_from_input(node.input[2])
            min_node = old_nodes_map[min_node_name]
            max_node = old_nodes_map[max_node_name]
            is_min_right_type = (min_node.op in ["Min", "Dequantize"])
            is_max_right_type = (max_node.op in ["Max", "Dequantize"])
            if not is_min_right_type or not is_max_right_type:
                print("Didn't find expected types on inputs : %s, %s." %
                      (min_node.op, max_node.op))
                continue
            min_node_input_name = helper.node_name_from_input(
                min_node.input[0])
            max_node_input_name = helper.node_name_from_input(
                max_node.input[0])
            # There are two different patterns for Min nodes we can recognize, one
            # where the input comes directly from the same one as the Max, and
            # another where we run it through another Min first, so check for both.
            is_same_input = False
            if min_node_input_name == max_node_input_name:
                is_same_input = True
            else:
                first_min_node_input = old_nodes_map[min_node_input_name]
                if first_min_node_input.op == "Concat":
                    second_min_node_name = helper.node_name_from_input(
                        first_min_node_input.input[1])
                    second_min_node = old_nodes_map[second_min_node_name]
                    if second_min_node.op == "Min":
                        second_min_node_input_name = helper.node_name_from_input(
                            second_min_node.input[0])
                        is_same_input = (
                            second_min_node_input_name == max_node_input_name)
            if not is_same_input:
                print("Different min/max inputs: " + min_node_input_name)
                continue
            # We recognize this pattern, so mark the graph edges to be rewired to
            # route around it entirely, since we know it's a no-op.
            dequantize_source_name = helper.node_name_from_input(
                dequantize_node.input[0])
            node_tensor_name = helper.ensure_tensor_name_has_port(node.name)
            min_tensor_name = node.name + ":1"
            max_tensor_name = node.name + ":2"

            inputs_to_rename[node_tensor_name] = dequantize_source_name
            inputs_to_rename[min_tensor_name] = dequantize_node.input[1]
            inputs_to_rename[max_tensor_name] = dequantize_node.input[2]
        # Finally we apply all the rewiring we've marked to the graph.
        for node in old_graph.node:
            for index, input_full_name in enumerate(node.input):
                input_name = helper.ensure_tensor_name_has_port(
                    input_full_name)
                if input_name in inputs_to_rename:
                    node.input[index] = inputs_to_rename[input_name]
            self.add_output_graph_node(node)
        return self.output_graph
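To make the rewiring concrete: if a QuantizeV2 node "q" reads from a Dequantize whose inputs are ("conv_quantized", "conv_min", "conv_max"), the map built above would be, schematically (names hypothetical):

inputs_to_rename = {
    "q:0": "conv_quantized",  # consumers of q's tensor read the original instead
    "q:1": "conv_min",        # q's min output becomes the Dequantize's min input
    "q:2": "conv_max",        # likewise for max
}

so every consumer routes around the no-op Dequantize -> Quantize pair.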