def _add_pool_function(self, original_node, quantized_op_node):
    helper.set_attr_dtype(quantized_op_node, "T", dtypes.quint8)
    helper.copy_attr(quantized_op_node, "ksize", original_node.attr["ksize"])
    helper.copy_attr(quantized_op_node, "strides",
                     original_node.attr["strides"])
    helper.copy_attr(quantized_op_node, "padding",
                     original_node.attr["padding"])
def apply_final_node_renames(self):
    """Applies node renames in self.final_node_renames to self.output_graph."""
    old_graph = self.output_graph
    self.output_graph = graph_pb2.GraphDef()
    for node in old_graph.node:
        node.name = self.final_node_renames.get(node.name, node.name)
        for index, input_name in enumerate(node.input):
            node_name = helper.node_name_from_input(input_name)
            input_full_name = helper.ensure_tensor_name_has_port(input_name)
            if node_name in self.final_node_renames:
                node.input[index] = "%s%s" % (
                    self.final_node_renames[node_name],
                    input_full_name[len(node_name):])
        self.add_output_graph_node(node)
    return self.output_graph
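# A minimal sketch of the port-preserving rename above (names hypothetical):
# the tensor port suffix (":1", ":2", ...) survives the rename untouched.
# >>> final_node_renames = {"conv1": "conv1_quant"}
# >>> node_name, input_full_name = "conv1", "conv1:1"
# >>> "%s%s" % (final_node_renames[node_name], input_full_name[len(node_name):])
# 'conv1_quant:1'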
def add_dequantize_result_node(self,
                               quantized_output_name,
                               original_node_name,
                               min_tensor_index=1):
    min_max_inputs = [
        "%s:%s" % (quantized_output_name, min_tensor_index),
        "%s:%s" % (quantized_output_name, (min_tensor_index + 1))
    ]
    dequantize_name = original_node_name
    dequantize_node = helper.create_node(
        "Dequantize", dequantize_name,
        [quantized_output_name, min_max_inputs[0], min_max_inputs[1]])
    helper.set_attr_dtype(dequantize_node, "T", dtypes.quint8)
    helper.set_attr_string(
        dequantize_node, "mode",
        b"SCALED" if self.intel_cpu_eightbitize else b"MIN_FIRST")
    self.add_output_graph_node(dequantize_node)
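# Tensor indices illustrated (op name hypothetical): with the default
# min_tensor_index of 1, a quantized op "q_op" feeds Dequantize its value,
# min, and max outputs in order:
# >>> quantized_output_name, min_tensor_index = "q_op", 1
# >>> ["%s:%s" % (quantized_output_name, min_tensor_index + i) for i in (0, 1)]
# ['q_op:1', 'q_op:2']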
def _add_common_quantization_nodes(self,
                                   namespace_prefix,
                                   control_input_names=None):
    """Builds constant nodes needed for quantization of inputs."""
    reshape_dims_name = namespace_prefix + "_reshape_dims"
    reduction_dims_name = namespace_prefix + "_reduction_dims"
    reshape_dims_node = helper.create_constant_node(reshape_dims_name, -1,
                                                    dtypes.int32, [1])
    if control_input_names:
        reshape_dims_node.input.append("^" + control_input_names)
    self.add_output_graph_node(reshape_dims_node)
    reduction_dims_node = helper.create_constant_node(reduction_dims_name, 0,
                                                      dtypes.int32, [1])
    if control_input_names:
        reduction_dims_node.input.append("^" + control_input_names)
    self.add_output_graph_node(reduction_dims_node)
    return reshape_dims_name, reduction_dims_name
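# What the two constants mean (a numpy analogy, values hypothetical): a
# reshape dims constant of [-1] flattens the tensor, and a reduction dim of
# [0] then reduces over all of its elements, yielding a scalar min and max:
# >>> import numpy as np
# >>> t = np.array([[1., -3.], [2., 0.5]], dtype=np.float32)
# >>> flat = t.reshape(-1)
# >>> print(flat.min(axis=0), flat.max(axis=0))
# -3.0 2.0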
def __init__(self,
             input_graph,
             output_node_names,
             perchannel=False,
             excluded_ops=None,
             excluded_nodes=None):
    """Quantize graph for Intel CPU.

    Arguments:
        input_graph {GraphDef or str} -- the input GraphDef, or a path to
            a serialized GraphDef file.
        output_node_names {list} -- names of the graph's output nodes.

    Keyword Arguments:
        perchannel {bool} -- enable per-channel quantization
            (default: {False})
        excluded_ops {list} -- op types to exclude from quantization
            (default: {None})
        excluded_nodes {list} -- node names to exclude from quantization
            (default: {None})
    """
    super(QuantizeGraphForIntel, self).__init__(output_node_names)
    self.perchannel = perchannel
    if isinstance(input_graph, graph_pb2.GraphDef):
        self.input_graph = input_graph
    else:
        self.input_graph = graph_pb2.GraphDef()
        with gfile.Open(input_graph, 'rb') as f:
            self.input_graph.ParseFromString(f.read())

    self.input_graph = graph_util.remove_training_nodes(
        self.input_graph, protected_nodes=self.output_node_names)
    self.input_graph = QuantizeGraphHelper().get_sorted_graph(
        self.input_graph, output_node_names)
    # Avoid mutable default arguments; empty lists mean "exclude nothing".
    self.excluded_ops = excluded_ops if excluded_ops is not None else []
    self.excluded_nodes = excluded_nodes if excluded_nodes is not None else []
    self.register_transformer("MaxPool", FuseNodeStartWithPooling)
    self.register_transformer("Conv2D", FuseNodeStartWithConv2d)
    self.register_transformer("DepthwiseConv2dNative", FuseNodeStartWithConv2d)
    self.register_transformer("AvgPool", FuseNodeStartWithPooling)
    self.register_transformer("ConcatV2", FuseNodeStartWithConcatV2)
    self.register_transformer("Pad", FuseNodeStartWithPad)
    # self.register_transformer("MatMul", FuseNodeStartWithMatmul)
    self.input_graph = QuantizeGraphHelper.split_shared_inputs(
        self.input_graph, self.transformers.keys())
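# A minimal usage sketch (path and node names hypothetical):
# >>> rewriter = QuantizeGraphForIntel("model.pb", ["predictions"],
# ...                                  perchannel=True)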
def eightbitize_single_input_tensor_node(self, original_node,
                                         add_op_function):
    quantized_op_name = original_node.name + "_eightbit_quantized"
    quantized_op_type = "Quantized" + original_node.op
    all_input_names = self._add_eightbit_prologue_nodes(original_node.name)
    quantized_op_node = helper.create_node(quantized_op_type,
                                           quantized_op_name,
                                           all_input_names)
    add_op_function(original_node, quantized_op_node)
    self.add_output_graph_node(quantized_op_node)
    self._intel_cpu_add_dequantize_result_node(quantized_op_name,
                                               original_node.name)
def _find_relu_node(self, node):
    if node.op in ("Relu", "Relu6") or node.op.find("AndRelu") != -1:
        return True
    elif (node.op.find("QuantizedConv") != -1
          or node.op.find("QuantizedDepthwiseConv") != -1
          ) and node.op.find("Relu") == -1:
        return False
    elif self._need_to_check(node.op):
        input_node = self.node_name_mapping[helper.node_name_from_input(
            node.input[0])]
        return self._find_relu_node(input_node.node)
    else:
        return False
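# A few illustrative op-name checks (op strings hypothetical): fused ops
# whose names contain "AndRelu" count as Relu sources, while a quantized
# conv without "Relu" in its name terminates the recursion with False:
# >>> "QuantizedConv2DWithBiasAndRelu".find("AndRelu") != -1
# True
# >>> op = "QuantizedConv2D"
# >>> (op.find("QuantizedConv") != -1) and op.find("Relu") == -1
# True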
def _apply_pad_conv_fusion(self):
    for _, value in self.node_name_mapping.items():
        if value.node.op == "Pad" and self.node_name_mapping[
                value.output[0]].node.op == "Conv2D" and self._find_relu_node(
                    value.node):
            paddings_tensor = tensor_util.MakeNdarray(
                self.node_name_mapping[value.node.input[1]].node.
                attr["value"].tensor).flatten()
            if any(paddings_tensor):
                new_node = node_def_pb2.NodeDef()
                new_node.CopyFrom(value.node)
                self.add_output_graph_node(new_node)
            else:
                self.node_name_mapping[
                    value.output[0]].node.input[0] = value.node.input[0]
                helper.set_attr_int_list(
                    self.node_name_mapping[value.output[0]].node,
                    "padding_list", paddings_tensor)
        else:
            new_node = node_def_pb2.NodeDef()
            new_node.CopyFrom(value.node)
            self.add_output_graph_node(new_node)
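# paddings_tensor illustrated (values hypothetical): the Pad op's second
# input is a [4, 2] constant which flattens to eight ints; any() is truthy
# only when at least one pad amount is non-zero:
# >>> import numpy as np
# >>> paddings = np.array([[0, 0], [1, 1], [1, 1], [0, 0]]).flatten()
# >>> any(paddings)
# True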
def _add_eightbit_prologue_nodes(self, original_node):
    namespace_prefix = original_node + "_eightbit"
    reshape_dims_name, reduction_dims_name = self._add_common_quantization_nodes(
        namespace_prefix,
        self.node_name_mapping[original_node].node.input[0])
    input_names = []
    min_max_names = []
    for each_input_name in self.node_name_mapping[original_node].node.input:
        if each_input_name[0] == '^':
            continue
        input_node_name = helper.node_name_from_input(each_input_name)
        if self.intel_cpu_eightbitize and input_node_name in self.output_node_maps:
            dtype = dtypes.DType(
                self.output_node_maps[input_node_name].attr["T"].type
            ) if self.output_node_maps[
                input_node_name].op == "Dequantize" else dtypes.quint8
        else:
            dtype = dtypes.quint8 if self._find_relu_node(
                self.node_name_mapping[original_node].node) else dtypes.qint8
        quantize_input_name, min_input_name, max_input_name = (
            self._eightbitize_input_to_node(namespace_prefix,
                                            each_input_name,
                                            reshape_dims_name,
                                            reduction_dims_name,
                                            dtype=dtype))
        input_names.append(quantize_input_name)
        min_max_names.append(min_input_name)
        min_max_names.append(max_input_name)
    all_input_names = []
    all_input_names.extend(input_names)
    all_input_names.extend(min_max_names)
    for original_input_name in self.node_name_mapping[
            original_node].node.input:
        if original_input_name[0] == '^':
            all_input_names.append(original_input_name)
    return all_input_names
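# Resulting input order (a sketch with hypothetical names): quantized data
# inputs first, then the min/max pairs, then any control edges:
# >>> input_names = ["in_q", "w_q"]
# >>> min_max_names = ["in_min", "in_max", "w_min", "w_max"]
# >>> input_names + min_max_names + ["^ctrl"]
# ['in_q', 'w_q', 'in_min', 'in_max', 'w_min', 'w_max', '^ctrl']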
def _parse_graph(self, input_graph=None):
    """Parse the graph and get the input/output node name details."""
    logging.debug("start parsing graph")
    self.node_name_mapping = OrderedDict()
    graph = self.input_graph if input_graph is None else input_graph
    for node in graph.node:
        each_node = self.node_details(node=node, input_node=[], output=[])
        if node.name in self.node_name_mapping:
            raise ValueError(
                "Duplicate node found when parsing graph; the node name is {}"
                .format(node.name))
        self.node_name_mapping[node.name] = each_node

    for node in graph.node:
        for input_name in node.input:
            self.node_name_mapping[helper.node_name_from_input(
                input_name)].output.append(node.name)
def _apply_concatv2_transform(self, original_node):
    namespace_prefix = original_node.name + "_eightbit"
    quantized_concat_name = namespace_prefix + "_quantized_concatv2"
    reshape_dims_name, reduction_dims_name = self._add_common_quantization_nodes(
        namespace_prefix,
        helper.node_name_from_input(original_node.input[-1]))
    num_input = len(original_node.input)
    shape_input_name = original_node.input[num_input - 1]
    original_inputs = original_node.input[0:num_input - 1]
    input_names = []
    min_names = []
    max_names = []
    for original_input_name in original_inputs:
        quantize_input_name, min_input_name, max_input_name = (
            self._eightbitize_input_to_node(namespace_prefix,
                                            original_input_name,
                                            reshape_dims_name,
                                            reduction_dims_name,
                                            dtype=dtypes.quint8))
        input_names.append(quantize_input_name)
        min_names.append(min_input_name)
        max_names.append(max_input_name)
    all_input_names = input_names
    all_input_names.append(shape_input_name)
    all_input_names.extend(min_names)
    all_input_names.extend(max_names)
    quantized_concat_node = helper.create_node("QuantizedConcatV2",
                                               quantized_concat_name,
                                               all_input_names)
    helper.set_attr_int(quantized_concat_node, "N", len(original_inputs))
    helper.set_attr_dtype(quantized_concat_node, "T", dtypes.quint8)
    self.add_output_graph_node(quantized_concat_node)
    if self.intel_cpu_eightbitize:
        self._intel_cpu_add_dequantize_result_node(quantized_concat_name,
                                                   original_node.name)
    else:
        self._add_dequantize_result_node(quantized_concat_name,
                                         original_node.name)
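# Input layout for QuantizedConcatV2 (names hypothetical): data inputs, then
# the axis tensor, then all mins, then all maxes:
# >>> input_names = ["a_q", "b_q"]
# >>> mins, maxes = ["a_min", "b_min"], ["a_max", "b_max"]
# >>> input_names + ["axis"] + mins + maxes
# ['a_q', 'b_q', 'axis', 'a_min', 'b_min', 'a_max', 'b_max']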
def apply_matmul_biasadd_fusion(self, match_node_name):
    skip_node_name = match_node_name[1:]
    matched_node = self.node_name_mapping[match_node_name[0]]
    control_inputs, normal_inputs = self._get_node_input(
        matched_node.node.name)
    weight_name = normal_inputs[1]
    self._intel_cpu_quantize_weight_eightbit(
        matched_node.node.op, self.node_name_mapping[weight_name].node,
        self.per_channel)
    skip_node_name.append(weight_name)

    for _, node in enumerate(self.input_graph.node):
        if node.name in skip_node_name:
            logging.debug("skip node {}".format(node.name))
        elif node.name == match_node_name[0]:
            logging.debug("matched node {} with input {}".format(
                node.name, node.input))
            logging.debug("apply_matmul_biasadd_fusion")
            quantized_node_name = node.name + "_eightbit_quantized_mat_mul"
            bias_node_name = self.node_name_mapping[
                match_node_name[1]].node.input[1]
            all_input_names = self._add_eightbit_prologue_nodes(
                matched_node.node.name)
            quantized_node_input_names = all_input_names[:2] + [
                bias_node_name
            ] + all_input_names[2:] + control_inputs
            quantized_matmul_node = helper.create_node(
                "QuantizedMatMulWithBias", quantized_node_name,
                quantized_node_input_names)
            helper.copy_attr(quantized_matmul_node, "transpose_a",
                             node.attr["transpose_a"])
            helper.copy_attr(quantized_matmul_node, "transpose_b",
                             node.attr["transpose_b"])
            helper.set_attr_dtype(quantized_matmul_node, "T1", dtypes.quint8)
            helper.set_attr_dtype(quantized_matmul_node, "T2", dtypes.qint8)
            helper.set_attr_dtype(quantized_matmul_node, "Toutput",
                                  dtypes.qint32)
            helper.set_attr_dtype(quantized_matmul_node, "Tbias",
                                  dtypes.float32)
            self.add_output_graph_node(quantized_matmul_node)
            requantize_type = dtypes.qint8
            quantize_down_name = self._add_quantize_down_nodes(
                node, quantized_node_name, requantize_type, False)
            self._intel_cpu_add_dequantize_result_node(
                quantize_down_name, match_node_name[1], requantize_type)
        else:
            new_node = node_def_pb2.NodeDef()
            new_node.CopyFrom(node)
            self.add_output_graph_node(new_node)
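# The input splice illustrated (names hypothetical): the bias is inserted
# after the two data inputs and before their min/max tensors:
# >>> all_input_names = ["x_q", "w_q", "x_min", "x_max", "w_min", "w_max"]
# >>> all_input_names[:2] + ["bias"] + all_input_names[2:]
# ['x_q', 'w_q', 'bias', 'x_min', 'x_max', 'w_min', 'w_max']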
def apply_conv_single_fusion(self, match_node_name):
    skip_node_name = match_node_name[1:]
    matched_node = self.node_name_mapping[match_node_name[0]]
    _, normal_inputs = self._get_node_input(matched_node.node.name)
    weight_name = normal_inputs[1]
    # TODO: this is a workaround; TF 2.1 doesn't support the depthwise s8
    # feature.
    if self.enable_s8 and matched_node.node.op == "DepthwiseConv2dNative" \
            and not self._find_relu_node(matched_node.node):
        self.output_graph = self.input_graph
        return

    self._intel_cpu_quantize_weight_eightbit(
        matched_node.node.op, self.node_name_mapping[weight_name].node,
        self.per_channel)
    all_input_names = self._add_eightbit_prologue_nodes(
        matched_node.node.name)
    skip_node_name.append(weight_name)

    for _, node in enumerate(self.input_graph.node):
        if node.name in skip_node_name:
            logging.debug("skip node {}".format(node.name))
        elif node.name == match_node_name[0]:
            postfix = "_eightbit_quantized_conv" if node.op == "Conv2D" \
                else "_eightbit_quantized_depthwise_conv"
            quantized_node_name = node.name + postfix
            if node.op == "Conv2D":
                quantized_conv_node = helper.create_node(
                    "QuantizedConv2DPerChannel"
                    if self.per_channel else "QuantizedConv2D",
                    quantized_node_name, all_input_names)
            elif node.op == "DepthwiseConv2dNative":
                quantized_conv_node = helper.create_node(
                    "QuantizedDepthwiseConv2D", quantized_node_name,
                    all_input_names)
            helper.copy_attr(quantized_conv_node, "strides",
                             node.attr["strides"])
            helper.copy_attr(quantized_conv_node, "padding",
                             node.attr["padding"])
            if node.op != 'DepthwiseConv2dNative' and "padding_list" in node.attr:
                helper.copy_attr(quantized_conv_node, "padding_list",
                                 node.attr["padding_list"])
            helper.copy_attr(quantized_conv_node, "dilations",
                             node.attr["dilations"])
            input_data_type = dtypes.quint8 if self._find_relu_node(
                node) else dtypes.qint8
            helper.set_attr_dtype(quantized_conv_node, "Tinput",
                                  input_data_type)
            helper.set_attr_dtype(quantized_conv_node, "Tfilter",
                                  dtypes.qint8)
            helper.set_attr_dtype(quantized_conv_node, "out_type",
                                  dtypes.qint32)
            self.add_output_graph_node(quantized_conv_node)
            quantize_down_name = self._add_quantize_down_nodes(
                node, quantized_node_name, dtypes.qint8)
            self._intel_cpu_add_dequantize_result_node(
                quantize_down_name, node.name, dtypes.qint8)
        else:
            new_node = node_def_pb2.NodeDef()
            new_node.CopyFrom(node)
            self.add_output_graph_node(new_node)
def apply_conv_biasadd_addn_relu_fusion(self, match_node_name):
    skip_node_name = match_node_name[1:]
    matched_node = self.node_name_mapping[match_node_name[0]]
    control_inputs, normal_inputs = self._get_node_input(
        matched_node.node.name)
    weight_name = normal_inputs[1]
    self._intel_cpu_quantize_weight_eightbit(
        matched_node.node.op, self.node_name_mapping[weight_name].node,
        self.per_channel)
    all_input_names = self._add_eightbit_prologue_nodes(
        matched_node.node.name)
    skip_node_name.append(weight_name)

    for _, node in enumerate(self.input_graph.node):
        if node.name in skip_node_name:
            logging.debug("skip node {}".format(node.name))
        elif node.name == match_node_name[0]:
            logging.debug("matched node {} with input {}".format(
                node.name, node.input))
            logging.debug("apply_conv_biasadd_addn_relu_fusion")
            quantized_node_name = node.name + "_eightbit_quantized_conv"
            bias_node_name = self.node_name_mapping[
                match_node_name[1]].node.input[1]
            relu_node_name = match_node_name[3]
            is_relu6 = self.node_name_mapping[
                relu_node_name].node.op == "Relu6"
            sum_index = 1 if match_node_name[1] == self.node_name_mapping[
                match_node_name[2]].node.input[0] else 0
            quantized_node_input_names = all_input_names[:2] + [
                bias_node_name
            ] + all_input_names[2:] + [
                self.node_name_mapping[
                    match_node_name[2]].node.input[sum_index]
            ] + control_inputs
            quantized_conv_node = helper.create_node(
                "QuantizedConv2DWithBiasSumAndRelu", quantized_node_name,
                quantized_node_input_names)
            helper.copy_attr(quantized_conv_node, "strides",
                             node.attr["strides"])
            helper.copy_attr(quantized_conv_node, "padding",
                             node.attr["padding"])
            if "padding_list" in node.attr:
                helper.copy_attr(quantized_conv_node, "padding_list",
                                 node.attr["padding_list"])
            helper.copy_attr(quantized_conv_node, "dilations",
                             node.attr["dilations"])
            input_data_type = dtypes.quint8 if self._find_relu_node(
                node) else dtypes.qint8
            helper.set_attr_dtype(quantized_conv_node, "Tinput",
                                  input_data_type)
            helper.set_attr_dtype(quantized_conv_node, "Tfilter",
                                  dtypes.qint8)
            helper.set_attr_dtype(quantized_conv_node, "out_type",
                                  dtypes.qint32)
            self.add_output_graph_node(quantized_conv_node)
            quantize_down_name = self._add_quantize_down_nodes(
                node, quantized_node_name, dtypes.quint8, is_relu6)
            self._intel_cpu_add_dequantize_result_node(
                quantize_down_name, relu_node_name)
        else:
            new_node = node_def_pb2.NodeDef()
            new_node.CopyFrom(node)
            self.add_output_graph_node(new_node)
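# sum_index selection illustrated (names hypothetical): the Add/AddN node
# has two inputs, and the one that is NOT the BiasAdd output is the summand
# that gets fused into the quantized conv:
# >>> add_inputs = ["biasadd_out", "residual"]
# >>> sum_index = 1 if "biasadd_out" == add_inputs[0] else 0
# >>> add_inputs[sum_index]
# 'residual'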
def apply_conv_biasadd_relu_fusion(self, match_node_name):
    """Fuse the conv/biasadd/relu pattern into a single quantized op.

    Arguments:
        match_node_name {list} -- the matched node names, ordered as
            [conv, biasadd, relu].
    """
    skip_node_name = match_node_name[1:]
    matched_node = self.node_name_mapping[match_node_name[0]]
    control_inputs, normal_inputs = self._get_node_input(
        matched_node.node.name)
    weight_name = normal_inputs[1]
    self._intel_cpu_quantize_weight_eightbit(
        matched_node.node.op, self.node_name_mapping[weight_name].node,
        self.per_channel)
    all_input_names = self._add_eightbit_prologue_nodes(
        matched_node.node.name)
    skip_node_name.append(weight_name)

    for _, node in enumerate(self.input_graph.node):
        if node.name in skip_node_name:
            logging.debug("skip node {}".format(node.name))
        elif node.name == match_node_name[0]:
            logging.debug("apply_conv_biasadd_relu_fusion")
            postfix = "_eightbit_quantized_conv" if node.op == "Conv2D" \
                else "_eightbit_quantized_depthwise_conv"
            quantized_node_name = node.name + postfix
            bias_node_name = self.node_name_mapping[
                match_node_name[1]].node.input[1]
            relu_node_name = match_node_name[2]
            is_relu6 = self.node_name_mapping[
                relu_node_name].node.op == "Relu6"
            quantized_node_input_names = all_input_names[:2] + [
                bias_node_name
            ] + all_input_names[2:] + control_inputs
            quantized_conv_node = helper.create_node(
                "QuantizedConv2DWithBiasAndRelu" if node.op == "Conv2D"
                else "QuantizedDepthwiseConv2DWithBiasAndRelu",
                quantized_node_name, quantized_node_input_names)
            helper.copy_attr(quantized_conv_node, "strides",
                             node.attr["strides"])
            helper.copy_attr(quantized_conv_node, "padding",
                             node.attr["padding"])
            if node.op != 'DepthwiseConv2dNative' and "padding_list" in node.attr:
                helper.copy_attr(quantized_conv_node, "padding_list",
                                 node.attr["padding_list"])
            helper.copy_attr(quantized_conv_node, "dilations",
                             node.attr["dilations"])
            input_data_type = dtypes.quint8 if self._find_relu_node(
                node) else dtypes.qint8
            helper.set_attr_dtype(quantized_conv_node, "Tinput",
                                  input_data_type)
            helper.set_attr_dtype(quantized_conv_node, "Tfilter",
                                  dtypes.qint8)
            helper.set_attr_dtype(quantized_conv_node, "out_type",
                                  dtypes.qint32)
            self.add_output_graph_node(quantized_conv_node)
            quantize_down_name = self._add_quantize_down_nodes(
                node, quantized_node_name, dtypes.quint8, is_relu6)
            self._intel_cpu_add_dequantize_result_node(
                quantize_down_name, relu_node_name)
        else:
            new_node = node_def_pb2.NodeDef()
            new_node.CopyFrom(node)
            self.add_output_graph_node(new_node)
def _intel_cpu_quantize_weight_eightbit(self,
                                        parent,
                                        input_node,
                                        per_channel,
                                        quantization_mode=b"SCALED"):
    base_name = input_node.name + "_"
    qint8_const_name = base_name + "qint8_const"
    min_name = base_name + "min"
    max_name = base_name + "max"
    float_tensor = tensor_util.MakeNdarray(input_node.attr["value"].tensor)
    epsilon = 1e-4  # Needs to be set empirically if accuracy is not satisfactory
    if parent in ("Conv2D", "MatMul"):
        if per_channel:
            ranges = np.abs(float_tensor).max(axis=(0, 1, 2))
            min_value = -ranges
            max_value = ranges
            # Nudge min-max values outside the epsilon radius around zero.
            ranges[ranges < epsilon] = epsilon
            min_value[np.abs(min_value) < epsilon] = -epsilon
            max_value[np.abs(max_value) < epsilon] = epsilon
            qint8_tensor = (float_tensor * 127.0 / ranges).astype(np.int8)
        else:
            min_value = np.min(float_tensor.flatten())
            max_value = np.max(float_tensor.flatten())
            # Same processing of min-max as in the quantize_weight_eightbit
            # function.
            if min_value > 0.0:
                min_value = 0.0
            if min_value == max_value:
                if abs(min_value) < 0.000001:
                    max_value = min_value + 1.0
                elif min_value > 0:
                    max_value = 2 * min_value
                else:
                    max_value = min_value / 2.0

            sess = session.Session()
            with sess.as_default():
                quantize_op = array_ops.quantize_v2(
                    float_tensor,
                    min_value,
                    max_value,
                    dtypes.qint8,
                    mode=quantization_mode,
                    round_mode="HALF_TO_EVEN")
                qint8_tensor = quantize_op[0].eval()
                # Updated min-max values should be passed to the next
                # feeding node.
                min_value = quantize_op[1].eval()
                max_value = quantize_op[2].eval()
    elif parent == "DepthwiseConv2dNative":
        # Get the max values based on dims 0 and 1 for depthwise conv,
        # since the output channel will be dim 2 * dim 3.
        ranges = np.abs(float_tensor).max(axis=(0, 1))
        ranges = ranges.flatten()
        min_value = -ranges
        max_value = ranges
        # Nudge min-max values outside the epsilon radius around zero.
        ranges[ranges < epsilon] = epsilon
        min_value[np.abs(min_value) < epsilon] = -epsilon
        max_value[np.abs(max_value) < epsilon] = epsilon
        # Since the output channel is the single dim 2 * dim 3, the tensor
        # needs to be reshaped to 3 dims before dividing by range, with the
        # 3rd dim matching the dim of ranges.
        a, b, c, d = float_tensor.shape
        qint8_tensor = (float_tensor.reshape(a, b, c * d) * 127.0 /
                        ranges).astype(np.int8)
        # Restore the original 4-dim shape.
        qint8_tensor = qint8_tensor.reshape(a, b, c, d)
    shape = tensor_util.TensorShapeProtoToList(
        input_node.attr["value"].tensor.tensor_shape)
    qint8_const_node = helper.create_constant_node(qint8_const_name,
                                                   qint8_tensor,
                                                   dtypes.qint8,
                                                   shape=shape)
    min_node = helper.create_constant_node(min_name, min_value,
                                           dtypes.float32)
    max_node = helper.create_constant_node(max_name, max_value,
                                           dtypes.float32)
    dequantize_node = helper.create_node(
        "Dequantize", input_node.name,
        [qint8_const_name, min_name, max_name])
    helper.set_attr_dtype(dequantize_node, "T", dtypes.qint8)
    helper.set_attr_string(dequantize_node, "mode", b"SCALED")
    self.add_output_graph_node(qint8_const_node)
    self.add_output_graph_node(min_node)
    self.add_output_graph_node(max_node)
    self.add_output_graph_node(dequantize_node)
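# Numeric sketch of the SCALED mapping above (values hypothetical; shown
# per-tensor for brevity, while the code computes ranges per-channel):
# >>> import numpy as np
# >>> float_tensor = np.array([0.5, -1.0, 2.0], dtype=np.float32)
# >>> ranges = np.abs(float_tensor).max()
# >>> (float_tensor * 127.0 / ranges).astype(np.int8)
# array([ 31, -63, 127], dtype=int8)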
def _eightbitize_input_to_node(self,
                               namespace_prefix,
                               original_input_name,
                               reshape_dims_name,
                               reduction_dims_name,
                               dtype=dtypes.quint8):
    """Takes one float input to an op, and converts it to quantized form."""
    unique_input_name = helper.unique_node_name_from_input(
        original_input_name)
    if unique_input_name in self.quantized_node_dict:
        quantized_tuple = self.quantized_node_dict[unique_input_name]
        return quantized_tuple[0], quantized_tuple[1], quantized_tuple[2]

    reshape_input_name = namespace_prefix + "_reshape_" + unique_input_name
    min_input_name = namespace_prefix + "_min_" + unique_input_name
    max_input_name = namespace_prefix + "_max_" + unique_input_name
    quantize_input_name = namespace_prefix + "_quantize_" + unique_input_name
    reshape_input_node = helper.create_node(
        "Reshape", reshape_input_name,
        [original_input_name, reshape_dims_name])
    helper.set_attr_dtype(reshape_input_node, "T", dtypes.float32)
    self.add_output_graph_node(reshape_input_node)
    min_input_node = helper.create_node(
        "Min", min_input_name, [reshape_input_name, reduction_dims_name])
    helper.set_attr_dtype(min_input_node, "T", dtypes.float32)
    helper.set_attr_dtype(min_input_node, "Tidx", dtypes.int32)
    helper.set_attr_bool(min_input_node, "keep_dims", False)
    self.add_output_graph_node(min_input_node)
    max_input_node = helper.create_node(
        "Max", max_input_name, [reshape_input_name, reduction_dims_name])
    helper.set_attr_dtype(max_input_node, "T", dtypes.float32)
    helper.set_attr_dtype(max_input_node, "Tidx", dtypes.int32)
    helper.set_attr_bool(max_input_node, "keep_dims", False)
    self.add_output_graph_node(max_input_node)
    quantize_input_node = helper.create_node(
        "QuantizeV2", quantize_input_name,
        [original_input_name, min_input_name, max_input_name])
    helper.set_attr_dtype(quantize_input_node, "T", dtype)
    helper.set_attr_string(quantize_input_node, "mode", b"SCALED")
    helper.set_attr_string(quantize_input_node, "round_mode",
                           b"HALF_TO_EVEN")
    # if FLAGS.model_name in ["wide_deep_large_ds"]:
    #     set_attr_string(quantize_input_node, "mode", b"MIN_FIRST")
    # else:
    #     set_attr_string(quantize_input_node, "mode",
    #                     b"SCALED" if self.intel_cpu_eightbitize
    #                     else b"MIN_FIRST")
    # set_attr_string(quantize_input_node, "round_mode",
    #                 b"HALF_TO_EVEN" if self.intel_cpu_eightbitize
    #                 else b"HALF_AWAY_FROM_ZERO")
    self.add_output_graph_node(quantize_input_node)
    min_output_name = quantize_input_name + ":1"
    max_output_name = quantize_input_name + ":2"
    self.quantized_node_dict[unique_input_name] = (quantize_input_name,
                                                   min_output_name,
                                                   max_output_name)
    return quantize_input_name, min_output_name, max_output_name
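# Shape of the built subgraph (a sketch): Reshape flattens the float input,
# Min/Max compute its range, and QuantizeV2 emits three tensors whose min
# and max are then referenced by port:
# >>> quantize_input_name = "ns_quantize_input0"  # hypothetical
# >>> (quantize_input_name + ":1", quantize_input_name + ":2")
# ('ns_quantize_input0:1', 'ns_quantize_input0:2')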
def _add_quantize_down_nodes(self,
                             original_node,
                             quantized_output_name,
                             requantize_type=dtypes.quint8,
                             is_relu6=False):
    quantized_outputs = [
        quantized_output_name, quantized_output_name + ":1",
        quantized_output_name + ":2"
    ]
    # Add a RequantizationRange node for finding the min and max values.
    requant_range_node = helper.create_node(
        "RequantizationRangePerChannel"
        if self.per_channel else "RequantizationRange",
        original_node.name + "_eightbit_requant_range", quantized_outputs)

    if self.per_channel:
        helper.set_attr_dtype(requant_range_node, "T", dtypes.qint32)
        if is_relu6:
            helper.set_attr_float(requant_range_node, "clip_value_max", 6.0)
        else:
            helper.set_attr_float(requant_range_node, "clip_value_max",
                                  1e30)
    else:
        helper.set_attr_dtype(requant_range_node, "Tinput", dtypes.qint32)

    self.add_output_graph_node(requant_range_node)
    min_max_inputs = [
        requant_range_node.name + ":0", requant_range_node.name + ":1"
    ]
    requantize_node = helper.create_node(
        "RequantizePerChannel" if self.per_channel else "Requantize",
        original_node.name + "_eightbit_requantize",
        quantized_outputs + min_max_inputs)

    if self.per_channel:
        helper.set_attr_dtype(requantize_node, "T", dtypes.qint32)
    else:
        helper.set_attr_dtype(requantize_node, "Tinput", dtypes.qint32)

    helper.set_attr_dtype(requantize_node, "out_type", requantize_type)
    self.add_output_graph_node(requantize_node)
    return requantize_node.name
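# Wiring sketch (names hypothetical): the qint32 output plus its min/max
# feed RequantizationRange, whose two outputs then complete the Requantize
# inputs:
# >>> quantized_outputs = ["conv_q", "conv_q:1", "conv_q:2"]
# >>> min_max_inputs = ["requant_range:0", "requant_range:1"]
# >>> quantized_outputs + min_max_inputs
# ['conv_q', 'conv_q:1', 'conv_q:2', 'requant_range:0', 'requant_range:1']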
def remove_redundant_quantization(self, old_graph):
    old_nodes_map = self.create_nodes_map(old_graph)
    self.output_graph = graph_pb2.GraphDef()
    inputs_to_rename = {}
    # We go through all the nodes, looking for any that match the patterns
    # we know how to optimize away.
    for node in old_graph.node:
        # We always start with a Quantize node, and examine its inputs to
        # see if they are in a form that can be removed.
        if node.op not in ["Quantize", "QuantizeV2"]:
            continue

        dequantize_node_name = helper.node_name_from_input(node.input[0])
        if dequantize_node_name not in old_nodes_map:
            raise ValueError("Input node name '" + dequantize_node_name +
                             "' not found in node '" + node.name + "'")

        dequantize_node = old_nodes_map[dequantize_node_name]
        # Do we have a Dequantize feeding in, with the same type as the
        # Quantize?
        if dequantize_node.op != "Dequantize":
            continue

        if node.attr["T"] != dequantize_node.attr["T"]:
            continue

        # Now look at the other inputs, and ensure they're Min/Max nodes.
        min_node_name = helper.node_name_from_input(node.input[1])
        max_node_name = helper.node_name_from_input(node.input[2])
        min_node = old_nodes_map[min_node_name]
        max_node = old_nodes_map[max_node_name]
        is_min_right_type = (min_node.op in ["Min", "Dequantize"])
        is_max_right_type = (max_node.op in ["Max", "Dequantize"])
        if not is_min_right_type or not is_max_right_type:
            print("Didn't find expected types on inputs: %s, %s." %
                  (min_node.op, max_node.op))
            continue

        min_node_input_name = helper.node_name_from_input(min_node.input[0])
        max_node_input_name = helper.node_name_from_input(max_node.input[0])
        # There are two different patterns for Min nodes we can recognize,
        # one where the input comes directly from the same one as the Max,
        # and another where we run it through another Min first, so check
        # for both.
        is_same_input = False
        if min_node_input_name == max_node_input_name:
            is_same_input = True
        else:
            first_min_node_input = old_nodes_map[min_node_input_name]
            if first_min_node_input.op == "Concat":
                second_min_node_name = helper.node_name_from_input(
                    first_min_node_input.input[1])
                second_min_node = old_nodes_map[second_min_node_name]
                if second_min_node.op == "Min":
                    second_min_node_input_name = helper.node_name_from_input(
                        second_min_node.input[0])
                    is_same_input = (
                        second_min_node_input_name == max_node_input_name)
        if not is_same_input:
            print("Different min/max inputs: " + min_node_input_name)
            continue

        # We recognize this pattern, so mark the graph edges to be rewired
        # to route around it entirely, since we know it's a no-op.
        dequantize_source_name = helper.node_name_from_input(
            dequantize_node.input[0])
        node_tensor_name = helper.ensure_tensor_name_has_port(node.name)
        min_tensor_name = node.name + ":1"
        max_tensor_name = node.name + ":2"
        inputs_to_rename[node_tensor_name] = dequantize_source_name
        inputs_to_rename[min_tensor_name] = dequantize_node.input[1]
        inputs_to_rename[max_tensor_name] = dequantize_node.input[2]

    # Finally we apply all the rewiring we've marked to the graph.
    for node in old_graph.node:
        for index, input_full_name in enumerate(node.input):
            input_name = helper.ensure_tensor_name_has_port(input_full_name)
            if input_name in inputs_to_rename:
                node.input[index] = inputs_to_rename[input_name]
        self.add_output_graph_node(node)

    return self.output_graph
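# Rewiring sketch (names hypothetical): for a redundant Dequantize ("deq")
# feeding a QuantizeV2 ("quant"), consumers of quant's tensors are routed
# straight back to the original quantized source and its min/max:
# >>> inputs_to_rename = {"quant:0": "conv_q", "quant:1": "deq_min",
# ...                     "quant:2": "deq_max"}
# >>> inputs_to_rename.get("quant:1", "quant:1")
# 'deq_min'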