def _find_relu_node(self, node):
    """Return True when `node` produces Relu-activated (non-negative) output.

    A node counts as Relu-activated if it is itself a Relu/Relu6 (or a fused
    "...AndRelu" op), or if — for ops `self._need_to_check` says are
    pass-through — its first data input is transitively Relu-activated.
    Quantized conv ops without a fused Relu terminate the search with False.
    """
    op_name = node.op
    if op_name in ("Relu", "Relu6") or "AndRelu" in op_name:
        return True
    looks_like_quantized_conv = ("QuantizedConv" in op_name
                                 or "QuantizedDepthwiseConv" in op_name)
    if looks_like_quantized_conv and "Relu" not in op_name:
        return False
    if self._need_to_check(op_name):
        # Pass-through op: the answer is determined by whatever feeds it.
        producer = self.node_name_mapping[helper.node_name_from_input(
            node.input[0])]
        return self._find_relu_node(producer.node)
    return False
def apply_final_node_renames(self):
    """Applies node renames in self.final_node_renames to self.output_graph.

    Rebuilds self.output_graph node by node, renaming each node and rewriting
    every input reference that points at a renamed node (preserving any
    ":port" suffix on the reference). Returns the rewritten GraphDef.
    """
    renames = self.final_node_renames
    previous_graph = self.output_graph
    self.output_graph = graph_pb2.GraphDef()
    for node in previous_graph.node:
        if node.name in renames:
            node.name = renames[node.name]
        for idx, original_input in enumerate(node.input):
            bare_name = helper.node_name_from_input(original_input)
            # Keep the tensor's port suffix while swapping the node name.
            with_port = helper.ensure_tensor_name_has_port(original_input)
            if bare_name in renames:
                node.input[idx] = renames[bare_name] + with_port[len(bare_name):]
        self.add_output_graph_node(node)
    return self.output_graph
def _add_eightbit_prologue_nodes(self, original_node):
    """Build the quantized-input prologue for `original_node`.

    For each data input of the named node, emits quantize/min/max nodes (via
    self._eightbitize_input_to_node) and returns the full replacement input
    list: quantized inputs first, then all min/max names, then any control
    inputs ("^name") forwarded unchanged.

    Args:
      original_node: name (str) of the node being eightbitized; must be a key
        of self.node_name_mapping.

    Returns:
      list of input-name strings for the quantized replacement node.
    """
    namespace_prefix = original_node + "_eightbit"
    reshape_dims_name, reduction_dims_name = self._add_common_quantization_nodes(
        namespace_prefix, self.node_name_mapping[original_node].node.input[0])
    input_names = []
    min_max_names = []
    # The relu walk is recursive over the graph and does not depend on which
    # input we are processing, so compute it at most once (lazily) instead of
    # once per input.
    relu_dtype = None
    for each_input_name in self.node_name_mapping[original_node].node.input:
        # startswith() is safe on an empty name, unlike indexing [0].
        if each_input_name.startswith('^'):
            continue
        input_node_name = helper.node_name_from_input(each_input_name)
        if self.intel_cpu_eightbitize and input_node_name in self.output_node_maps:
            producer = self.output_node_maps[input_node_name]
            # A Dequantize producer tells us its exact quantized dtype;
            # anything else defaults to unsigned 8-bit.
            dtype = (dtypes.DType(producer.attr["T"].type)
                     if producer.op == "Dequantize" else dtypes.quint8)
        else:
            if relu_dtype is None:
                # Relu-activated outputs are non-negative, so unsigned
                # quantization is used; otherwise signed.
                relu_dtype = (dtypes.quint8 if self._find_relu_node(
                    self.node_name_mapping[original_node].node)
                    else dtypes.qint8)
            dtype = relu_dtype
        quantize_input_name, min_input_name, max_input_name = (
            self._eightbitize_input_to_node(namespace_prefix, each_input_name,
                                            reshape_dims_name,
                                            reduction_dims_name, dtype=dtype))
        input_names.append(quantize_input_name)
        min_max_names.append(min_input_name)
        min_max_names.append(max_input_name)
    all_input_names = []
    all_input_names.extend(input_names)
    all_input_names.extend(min_max_names)
    # Control inputs are forwarded unchanged, after the data/min/max inputs.
    for original_input_name in self.node_name_mapping[original_node].node.input:
        if original_input_name.startswith('^'):
            all_input_names.append(original_input_name)
    return all_input_names
def _parse_graph(self, input_graph=None):
    """Parse the graph and get the input node and output node name details.

    Builds self.node_name_mapping as an OrderedDict keyed by node name; each
    value is a self.node_details record whose `output` list is filled with the
    names of consumer nodes in a second pass.

    Args:
      input_graph: optional GraphDef to parse; defaults to self.input_graph.

    Raises:
      ValueError: if two nodes in the graph share a name.
    """
    logging.debug("start parsing graph")
    self.node_name_mapping = OrderedDict()
    graph = self.input_graph if input_graph is None else input_graph
    for node in graph.node:
        each_node = self.node_details(node=node, input_node=[], output=[])
        if node.name in self.node_name_mapping:
            raise ValueError(
                "Duplicate Node Found when _parse_graph, the node name is {}"
                .format(node.name))
        self.node_name_mapping[node.name] = each_node
    # Second pass: record, for every node, which nodes consume its output.
    for node in graph.node:
        # Renamed from `input`, which shadowed the builtin.
        for input_name in node.input:
            self.node_name_mapping[helper.node_name_from_input(
                input_name)].output.append(node.name)
def _apply_concatv2_transform(self, original_node):
    """Replace a ConcatV2 node with a QuantizedConcatV2 equivalent.

    Quantizes every data input (all inputs except the trailing axis input) to
    quint8, emits a QuantizedConcatV2 wired as
    [quantized inputs..., axis, mins..., maxes...], and appends the matching
    dequantize node so downstream consumers see float output again.
    """
    namespace_prefix = original_node.name + "_eightbit"
    quantized_concat_name = namespace_prefix + "_quantized_concatv2"
    reshape_dims_name, reduction_dims_name = self._add_common_quantization_nodes(
        namespace_prefix, helper.node_name_from_input(original_node.input[-1]))
    # Last input is the concat axis; everything before it is data.
    data_input_names = original_node.input[:-1]
    axis_input_name = original_node.input[-1]
    quantized_names = []
    min_names = []
    max_names = []
    for data_input_name in data_input_names:
        q_name, lo_name, hi_name = self._eightbitize_input_to_node(
            namespace_prefix, data_input_name, reshape_dims_name,
            reduction_dims_name, dtype=dtypes.quint8)
        quantized_names.append(q_name)
        min_names.append(lo_name)
        max_names.append(hi_name)
    all_input_names = (quantized_names + [axis_input_name]
                       + min_names + max_names)
    quantized_concat_node = helper.create_node("QuantizedConcatV2",
                                               quantized_concat_name,
                                               all_input_names)
    helper.set_attr_int(quantized_concat_node, "N", len(data_input_names))
    helper.set_attr_dtype(quantized_concat_node, "T", dtypes.quint8)
    self.add_output_graph_node(quantized_concat_node)
    if self.intel_cpu_eightbitize:
        self._intel_cpu_add_dequantize_result_node(quantized_concat_name,
                                                   original_node.name)
    else:
        self._add_dequantize_result_node(quantized_concat_name,
                                         original_node.name)
def remove_redundant_quantization(self, old_graph):
    """Remove Quantize nodes that re-quantize an already-quantized tensor.

    Scans `old_graph` for Quantize/QuantizeV2 nodes fed by a Dequantize of the
    same dtype (a quantize->dequantize->quantize round trip, which is a no-op)
    and rewires consumers to use the original quantized tensor and its
    min/max outputs directly. Returns the rewritten GraphDef in
    self.output_graph.

    Args:
      old_graph: GraphDef to optimize.

    Raises:
      ValueError: if a Quantize node's input name is absent from the graph.
    """
    old_nodes_map = self.create_nodes_map(old_graph)
    self.output_graph = graph_pb2.GraphDef()
    inputs_to_rename = {}
    # We go through all the nodes, looking for any that match the patterns we
    # know how to optimize away.
    for node in old_graph.node:
        # We always start with a Quantize node, and examine its inputs to see if
        # they are in a form that can be removed.
        if node.op not in ["Quantize", "QuantizeV2"]:
            continue
        dequantize_node_name = helper.node_name_from_input(node.input[0])
        if dequantize_node_name not in old_nodes_map:
            raise ValueError("Input node name '" + dequantize_node_name +
                             "' not found in node '" + node.name + "'")
        dequantize_node = old_nodes_map[dequantize_node_name]
        # Do we have a Dequantize feeding in, with the same type as the
        # Quantize? Only then is the pair a round-trip no-op.
        if dequantize_node.op != "Dequantize":
            continue
        if node.attr["T"] != dequantize_node.attr["T"]:
            continue
        # Now look at the other inputs, and ensure they're Min/Max nodes.
        min_node_name = helper.node_name_from_input(node.input[1])
        max_node_name = helper.node_name_from_input(node.input[2])
        min_node = old_nodes_map[min_node_name]
        max_node = old_nodes_map[max_node_name]
        is_min_right_type = (min_node.op in ["Min", "Dequantize"])
        is_max_right_type = (max_node.op in ["Max", "Dequantize"])
        if not is_min_right_type or not is_max_right_type:
            print("Didn't find expected types on inputs : %s, %s." %
                  (min_node.op, max_node.op))
            continue
        min_node_input_name = helper.node_name_from_input(min_node.input[0])
        max_node_input_name = helper.node_name_from_input(max_node.input[0])
        # There are two different patterns for Min nodes we can recognize, one
        # where the input comes directly from the same one as the Max, and
        # another where we run it through another Min first, so check for both.
        is_same_input = False
        if min_node_input_name == max_node_input_name:
            is_same_input = True
        else:
            # Second pattern: Min fed by a Concat whose second input is itself
            # a Min over the same source as the Max.
            first_min_node_input = old_nodes_map[min_node_input_name]
            if first_min_node_input.op == "Concat":
                second_min_node_name = helper.node_name_from_input(
                    first_min_node_input.input[1])
                second_min_node = old_nodes_map[second_min_node_name]
                if second_min_node.op == "Min":
                    second_min_node_input_name = helper.node_name_from_input(
                        second_min_node.input[0])
                    is_same_input = (
                        second_min_node_input_name == max_node_input_name)
        if not is_same_input:
            print("Different min/max inputs: " + min_node_input_name)
            continue
        # We recognize this pattern, so mark the graph edges to be rewired to
        # route around it entirely, since we know it's a no-op.
        dequantize_source_name = helper.node_name_from_input(
            dequantize_node.input[0])
        node_tensor_name = helper.ensure_tensor_name_has_port(node.name)
        # Outputs :1 and :2 of a Quantize node are its min/max tensors.
        min_tensor_name = node.name + ":1"
        max_tensor_name = node.name + ":2"
        inputs_to_rename[node_tensor_name] = dequantize_source_name
        inputs_to_rename[min_tensor_name] = dequantize_node.input[1]
        inputs_to_rename[max_tensor_name] = dequantize_node.input[2]
    # Finally we apply all the rewiring we've marked to the graph.
    for node in old_graph.node:
        for index, input_full_name in enumerate(node.input):
            input_name = helper.ensure_tensor_name_has_port(input_full_name)
            if input_name in inputs_to_rename:
                node.input[index] = inputs_to_rename[input_name]
        self.add_output_graph_node(node)
    return self.output_graph