def _delete_repeated_qat_blocks(model: ModelProto):
    # removes repeated qat quant/dequant blocks with the same parameters
    # (Quant -> Dequant -> Quant -> Dequant) -> (Quant -> Dequant)
    graph = ONNXGraph(model)

    nodes_to_delete = []
    quant_nodes = [n for n in model.graph.node if n.op_type == "QuantizeLinear"]
    for quant_node_1 in quant_nodes:
        dequant_node_1 = graph.get_node_single_child(quant_node_1)
        if not dequant_node_1 or dequant_node_1.op_type != "DequantizeLinear":
            continue
        quant_node_2 = graph.get_node_single_child(dequant_node_1)
        if not quant_node_2 or quant_node_2.op_type != "QuantizeLinear":
            continue
        dequant_node_2 = graph.get_node_single_child(quant_node_2)
        if not dequant_node_2 or dequant_node_2.op_type != "DequantizeLinear":
            continue

        # forward first qat block input to that of the second
        quant_node_2.input[0] = quant_node_1.input[0]

        # remove repeated quant/dequant block
        nodes_to_delete.append(quant_node_1)
        nodes_to_delete.append(dequant_node_1)

    for n in nodes_to_delete:
        delete_quant_node(model, n)

def get_named_prunable_params(self, model: Any) -> Dict[str, numpy.ndarray]:
    """
    loads the prunable parameters in a standardized way so that weight magnitude
    analysis may be run on each

    :param model: model to load the prunable parameters from
    :return: dictionary of prunable parameter name as listed in the ModelInfo to
        a numpy array of the values of the parameter
    """
    graph = ONNXGraph(model)
    return {
        layer_name: numpy_helper.to_array(graph.get_init_by_name(layer_name, False))
        for layer_name, layer_info in self._model_info.layer_info.items()
        if layer_info.prunable
    }

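# The dictionary returned by get_named_prunable_params is shaped for per-parameter
# magnitude analysis. A minimal sketch of such an analysis, assuming only numpy and a
# name-to-array mapping like the one returned above; the threshold and the sample
# parameter name below are illustrative, not part of this module:
from typing import Dict

import numpy


def magnitude_sparsity_summary(
    named_params: Dict[str, numpy.ndarray], threshold: float = 1e-4
) -> Dict[str, float]:
    # fraction of near-zero weights per prunable parameter
    return {
        name: float((numpy.abs(values) <= threshold).mean())
        for name, values in named_params.items()
    }


# usage with a stand-in parameter dict; real values come from get_named_prunable_params
print(magnitude_sparsity_summary({"conv1.weight": numpy.random.randn(64, 3, 3, 3)}))
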
def quantize_torch_qat_export(
    model: Union[ModelProto, str],
    output_file_path: Union[str, None] = None,
    inplace: bool = True,
) -> ModelProto:
    """
    :param model: The model to convert, or a file path to it
    :param output_file_path: File path to save the converted model to
    :param inplace: If true, does conversion of model in place. Default is true
    :return: Converts a model exported from a torch QAT session from a QAT graph with
        fake quantize ops surrounding operations to a quantized graph with quantized
        operations. All quantized Convs and FC inputs and outputs will be surrounded
        by fake quantize ops
    """
    if isinstance(model, str):
        model = onnx.load(model)

    if not inplace:
        model = deepcopy(model)

    _fold_qat_conv_bns(model)
    _fold_relu_quants(model)
    _convert_single_constants_to_initializers(model)
    _delete_repeated_qat_blocks(model)
    _convert_quantizable_ops(model)
    quantize_resnet_identity_add_inputs(model)
    quantized_residual_add_optim(model)
    _remove_duplicate_quantize__ops(model)
    ONNXGraph(model).sort_nodes_topologically()

    if output_file_path:
        onnx.save(model, output_file_path)

    return model

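# A typical invocation of quantize_torch_qat_export, assuming a QAT-exported ONNX
# model on disk; the file names are placeholders:
import onnx

qat_model = onnx.load("model_qat.onnx")
quantized_model = quantize_torch_qat_export(
    qat_model,
    output_file_path="model_quantized.onnx",
    inplace=False,  # convert a copy rather than mutating qat_model
)
onnx.checker.check_model(quantized_model)
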
def _skip_input_quantize(model: ModelProto) -> Optional[str]:
    if (
        len(model.graph.input) != 1
        or model.graph.input[0].type.tensor_type.elem_type != 1
    ):
        # more than 1 input or input is not FP32
        return (
            "Not modifying ONNX graph inputs - either graph has more than one "
            "input or input type is not FP32"
        )

    input_node = model.graph.input[0]
    input_children = [
        node for node in model.graph.node if input_node.name in node.input
    ]
    if not all(node.op_type == "QuantizeLinear" for node in input_children):
        return (
            "Not modifying ONNX graph inputs - only QuantizeLinear nodes may follow "
            "the FP32 input tensor in original graph, prior to converting to uint8"
        )

    graph = ONNXGraph(model)
    for quantize_node in input_children:
        quantize_children = graph.get_node_children(quantize_node)
        quantize_node_id = quantize_node.output[0]
        for child_node in quantize_children:
            input_idx = [
                idx
                for idx, inp in enumerate(child_node.input)
                if inp == quantize_node_id
            ]
            if not input_idx:
                continue
            input_idx = input_idx[0]
            graph.update_node_input(child_node, input_node.name, input_idx)
            _LOGGER.debug(
                f"set node with output id {child_node.output[0]} as initial node in "
                "graph"
            )

    _LOGGER.debug(
        f"deleting QuantizeLinear node(s) with output id(s): "
        f"{[n.output for n in input_children]}"
    )
    graph.delete_nodes(input_children)  # only contains references to the Quantize nodes
    graph.delete_unused_initializers()  # cleanup
    input_node.type.tensor_type.elem_type = 2  # fp32 -> uint8
    _LOGGER.info(
        "Model initial QuantizeLinear node(s) deleted and inputs set to uint8"
    )

    return None

def _get_model_last_prunable_nodes(model: ModelProto) -> List[NodeProto]:
    graph = ONNXGraph(model)
    output_names = {tens.name for tens in model.graph.output}
    stack = [
        node
        for node in model.graph.node
        if any(out in output_names for out in node.output)
    ]
    seen_node_ids = {output_id for node in stack for output_id in node.output}
    last_prunable_nodes = []
    while stack:
        node = stack.pop()
        if node.op_type in ["Gemm", "MatMul", "Conv"]:
            last_prunable_nodes.append(node)
            continue
        for parent in graph.get_node_parents(node):
            if any(output_id in seen_node_ids for output_id in parent.output):
                continue
            stack.append(parent)
            seen_node_ids.update(set(parent.output))
    return last_prunable_nodes

def _get_model_first_prunable_nodes(model: ModelProto) -> List[NodeProto]:
    graph = ONNXGraph(model)
    input_names = {tens.name for tens in model.graph.input}
    stack = [
        node
        for node in model.graph.node
        if any(inp in input_names for inp in node.input)
    ]
    seen_node_ids = {output_id for node in stack for output_id in node.output}
    first_prunable_nodes = []
    while stack:
        node = stack.pop()
        if node.op_type in ["Gemm", "MatMul", "Conv"]:
            first_prunable_nodes.append(node)
            continue
        for child in graph.get_node_children(node):
            if any(output_id in seen_node_ids for output_id in child.output):
                continue
            stack.append(child)
            seen_node_ids.update(set(child.output))
    return first_prunable_nodes

def _cleanup_unused_quants(model: ModelProto):
    """
    A pass for removing unused Quantize->Dequantize blocks. This should be called
    at the end of conversion, once all of the conversions to quantized operators
    have been tried.

    Example:
        op -> QuantizeLinear -> DequantizeLinear -> non-quantized op
        => op -> non-quantized op
    """
    graph = ONNXGraph(model)
    nodes_to_delete = []
    quant_nodes = [n for n in model.graph.node if n.op_type == "QuantizeLinear"]
    for quant_node in quant_nodes:
        dequant_node = graph.get_node_single_child(quant_node)
        if not dequant_node or dequant_node.op_type != "DequantizeLinear":
            continue

        removable = True
        dequant_children = graph.get_node_children(dequant_node)
        for child in dequant_children:
            if (
                isinstance(child, onnx.NodeProto)
                and child.op_type in _QLINEAR_OP_NAMES
            ):
                removable = False
        if not removable:
            continue

        # Forward QuantizeLinear input to DequantizeLinear output
        for child in dequant_children:
            _replace_input_id_model(model, dequant_node.output[0], quant_node.input[0])

        # Remove QuantizeLinear->DequantizeLinear block
        nodes_to_delete.append(quant_node)
        nodes_to_delete.append(dequant_node)

    for n in nodes_to_delete:
        delete_quant_node(model, n)

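# _replace_input_id_model is referenced above but not shown in this section. A minimal
# sketch of the behavior assumed here (rewriting every node input that matches the old
# tensor id); the _sketch name is a stand-in for the real helper. Because the rewrite
# is model-wide, the per-child loop above repeats the same call with identical
# arguments, which is redundant but harmless:
from onnx import ModelProto


def _replace_input_id_model_sketch(model: ModelProto, old_id: str, new_id: str):
    # point every node input that references old_id at new_id instead
    for node in model.graph.node:
        for idx, input_id in enumerate(node.input):
            if input_id == old_id:
                node.input[idx] = new_id
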
def _get_node_param_array(
    node: NodeProto,
    graph: ONNXGraph,
    param_idx: int = 1,
) -> Optional[numpy.ndarray]:
    if len(node.input) <= param_idx:
        # no such param exists
        return None
    param = graph.get_init_by_name(node.input[param_idx])
    if param is None:
        # input is not a param stored as an initializer in the graph
        return None
    return numpy_helper.to_array(param)

def extract_layer_info(self, model: ModelProto) -> "OrderedDict[str, LayerInfo]":
    """
    :param model: ONNX model to extract LayerInfo of
    :return: ordered dictionary of layer name to LayerInfo object for the prunable
        model layers
    """
    layers = OrderedDict()

    graph = ONNXGraph(model)
    graph.sort_nodes_topologically()  # for execution order

    first_prunable_nodes = _get_model_first_prunable_nodes(model)
    last_prunable_nodes = _get_model_last_prunable_nodes(model)

    for node in graph.nodes:
        layer_info = None
        if node.op_type == "Conv":
            layer_info = self._make_conv_layer_info(node, graph, len(layers))
        elif node.op_type == "Gemm":
            layer_info = self._make_gemm_layer_info(node, graph, len(layers))
        elif node.op_type == "MatMul":
            layer_info = self._make_matmul_layer_info(node, graph, len(layers))

        if layer_info is not None:
            if node.name:
                layer_info.attributes["node_name"] = node.name
            if node.output:
                layer_info.attributes["node_output_id"] = node.output[0]
            if node in first_prunable_nodes:
                layer_info.attributes["first_prunable_layer"] = True
            if node in last_prunable_nodes:
                layer_info.attributes["last_prunable_layer"] = True
            layers[layer_info.name] = layer_info

    return layers

def _get_node_dependency_names(
    graph: ONNXGraph, node: onnx.NodeProto, structure_type: str
) -> Set[str]:
    # returns the set of parameters that should be pruned to match
    # the target dimensions of this node
    unchecked_nodes = _get_next_layer_deps(graph, node, structure_type)
    seen_output_ids = _get_node_output_ids(unchecked_nodes)
    dependent_params = set()

    if structure_type == "filter" and len(node.input) > 2:
        # node bias depends on num filters
        dependent_params.add(node.input[2])

    while unchecked_nodes:
        current_node = unchecked_nodes.pop(0)
        if not isinstance(current_node, onnx.NodeProto):
            continue

        if current_node.op_type in _OUTPUT_CHANNEL_OP_TYPES:
            prunable = current_node.op_type in _PRUNABLE_OP_TYPES
            params = (
                list(current_node.input[1:])  # skip layer input tensor
                if not (prunable and structure_type != "filter")
                else [current_node.input[1]]  # bias not dependent on prev filter
            )
            for param in params:
                if graph.get_init_by_name(param) is not None:
                    dependent_params.add(param)
            if prunable and not _is_group_conv(current_node):
                # continue on other branches, do not go past prunable nodes
                continue

        dep_nodes = _get_next_layer_deps(graph, current_node, structure_type)
        for dep_node in dep_nodes:
            dep_node_ids = _get_node_output_ids(dep_node)
            if dep_node_ids.isdisjoint(seen_output_ids):
                unchecked_nodes.append(dep_node)
                seen_output_ids.update(dep_node_ids)

    return dependent_params

def _convert_quantizable_ops(model: ModelProto):
    quantizable_nodes = [
        n for n in model.graph.node if n.op_type in ["Conv", "Gemm"]
    ]
    for quantizable_node in quantizable_nodes:
        graph = ONNXGraph(model)

        weight_dequant = graph.get_node_single_parent(quantizable_node, 1)
        if not weight_dequant or weight_dequant.op_type != "DequantizeLinear":
            continue
        weight_quant = graph.get_node_single_parent(weight_dequant, 0)
        if not weight_quant or weight_quant.op_type != "QuantizeLinear":
            continue

        input_quant = graph.get_node_single_parent(quantizable_node, 0)
        if not input_quant or input_quant.op_type not in _QUANTIZE_OP_NAMES:
            continue

        output_quant = graph.get_node_single_child(quantizable_node)
        if not output_quant or output_quant.op_type not in _QUANTIZE_OP_NAMES:
            continue

        if quantizable_node.op_type == "Conv":
            _convert_quantizable_conv(
                model,
                quantizable_node,
                input_quant,
                weight_dequant,
                weight_quant,
                output_quant,
            )

        if quantizable_node.op_type == "Gemm":
            _convert_quantizable_gemm(
                model,
                quantizable_node,
                input_quant,
                weight_dequant,
                weight_quant,
                output_quant,
            )

def _get_next_layer_deps(
    graph: ONNXGraph, node: onnx.NodeProto, structure_type: str
) -> List[onnx.NodeProto]:
    return (
        [
            parent_node
            for parent_node in graph.get_node_parents(node)
            if isinstance(parent_node, onnx.NodeProto)
        ]
        if structure_type == "channel"
        else graph.get_node_children(node)
    )

def get_param_structured_pruning_group_dependencies(
    model: Union[onnx.ModelProto, str],
    structure_type: str = "filter",
) -> Dict[str, List[str]]:
    """
    :param model: model to generate pruning groups and dependencies for
    :param structure_type: valid options are 'filter' and 'channel'. Generates
        dependency map for corresponding pruning scheme. Default is 'filter'
    :return: dictionary of parameter names that should be grouped during structured
        pruning to a list of parameter names whose parameters should be updated
        accordingly to the param group pruning results. prunable parameter names
        will be represented as a comma separated string
    """
    if structure_type not in ["filter", "channel"]:
        raise ValueError(
            f"invalid structure_type {structure_type}. not in ['filter', 'channel']"
        )

    if isinstance(model, str):
        model = onnx.load(model)

    graph = ONNXGraph(model)
    param_name_to_dependents = {}  # Dict[str, Set[str]]
    for node in model.graph.node:
        if node.op_type not in _PRUNABLE_OP_TYPES or (
            graph.get_init_by_name(node.input[1]) is None
        ):
            # main param not found or not prunable
            continue
        param_name_to_dependents[node.input[1]] = _get_node_dependency_names(
            graph, node, structure_type
        )

    # merge disjoint sets of dependencies (could improve with union-find)
    prunable_param_group_to_dep_params = []  # List[Tuple[List, Set]]
    for prunable_param_name, dep_params in param_name_to_dependents.items():
        intersected_group_idxs = {
            idx
            for idx, (_, group_dep_params) in enumerate(
                prunable_param_group_to_dep_params
            )
            if not dep_params.isdisjoint(group_dep_params)
        }
        new_group_val = ([prunable_param_name], dep_params)
        if not intersected_group_idxs:
            prunable_param_group_to_dep_params.append(new_group_val)
        else:
            non_intersected_vals = []
            for idx, (prunable_param_group, group_dep_params) in enumerate(
                prunable_param_group_to_dep_params
            ):
                if idx not in intersected_group_idxs:
                    non_intersected_vals.append(
                        (prunable_param_group, group_dep_params)
                    )
                else:
                    new_group_val = (
                        new_group_val[0] + prunable_param_group,
                        new_group_val[1].union(group_dep_params),
                    )
            prunable_param_group_to_dep_params = non_intersected_vals + [new_group_val]

    return {
        ",".join(prunable_param_group): list(dependent_params)
        for prunable_param_group, dependent_params in prunable_param_group_to_dep_params
    }

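# The merge loop above is quadratic in the number of groups; as the inline comment
# notes, the same disjoint-set merging could be done with union-find. A sketch of that
# alternative (a hypothetical helper, not part of this module), taking the same
# param_name_to_dependents mapping and producing the same comma-joined group keys:
def _merge_dependency_groups_union_find(
    param_name_to_dependents: Dict[str, Set[str]],
) -> Dict[str, List[str]]:
    parent = {name: name for name in param_name_to_dependents}

    def find(name: str) -> str:
        while parent[name] != name:
            parent[name] = parent[parent[name]]  # path compression
            name = parent[name]
        return name

    def union(name_a: str, name_b: str):
        parent[find(name_a)] = find(name_b)

    # prunable params that share any dependent param belong to the same group
    dep_to_owner = {}
    for name, deps in param_name_to_dependents.items():
        for dep in deps:
            if dep in dep_to_owner:
                union(name, dep_to_owner[dep])
            else:
                dep_to_owner[dep] = name

    groups = {}  # root -> (prunable param names, union of dependent params)
    for name, deps in param_name_to_dependents.items():
        group_names, group_deps = groups.setdefault(find(name), ([], set()))
        group_names.append(name)
        group_deps.update(deps)

    return {",".join(names): list(deps) for names, deps in groups.values()}
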
def _convert_quantizable_matmul_and_add(model: ModelProto):
    """
    A pass for converting a MatMul with kernel and bias into a quantized
    representation

    | Starting with:
    |          INPUT         QuantizeLinear (with constant kernel)
    |            |               |
    |     QuantizeLinear     DequantizeLinear
    |            |               |
    |     DequantizeLinear   Transpose
    |                  |      |
    |                   MatMul
    |                     |
    |                  Add (with constant bias)
    |                     |
    |               QuantizeLinear
    |                     |
    |              DequantizeLinear
    |                     |
    |                   OUTPUT
    | We end up converting to:
    |          INPUT
    |            |
    |     QuantizeLinear
    |            |
    |     QLinearMatMul (with constant kernel)
    |            |
    |     QLinearAdd (with constant bias)
    |            |
    |     DequantizeLinear
    |            |
    |          OUTPUT
    """
    conversion_count = 0
    matmul_nodes = [n for n in model.graph.node if n.op_type in ["MatMul"]]
    for matmul_node in matmul_nodes:
        graph = ONNXGraph(model)

        #############
        # Matching
        #############
        weight_transpose_node = graph.get_node_single_parent(matmul_node, 1)
        if not weight_transpose_node or weight_transpose_node.op_type != "Transpose":
            continue
        weight_dequantize_node = graph.get_node_single_parent(weight_transpose_node, 0)
        if (
            not weight_dequantize_node
            or weight_dequantize_node.op_type != "DequantizeLinear"
        ):
            continue
        weight_quantize_node = graph.get_node_single_parent(weight_dequantize_node, 0)
        if not weight_quantize_node or weight_quantize_node.op_type != "QuantizeLinear":
            continue

        input_quantize_node = graph.get_node_single_parent(matmul_node, 0)
        if (
            not input_quantize_node
            or input_quantize_node.op_type not in _QUANTIZE_OP_NAMES
        ):
            continue

        bias_add_node = graph.get_node_single_child(matmul_node)
        if not bias_add_node or bias_add_node.op_type != "Add":
            continue
        output_quantize_node = graph.get_node_single_child(bias_add_node)
        if (
            not output_quantize_node
            or output_quantize_node.op_type not in _QUANTIZE_OP_NAMES
        ):
            continue

        input_quantize_params = get_quantization_params(
            model, input_quantize_node, include_target=False
        )
        weight_quantize_params = get_quantization_params(
            model, weight_quantize_node, include_target=True
        )
        if weight_quantize_params.target is None:
            # weight initializer not included
            continue
        if input_quantize_node.op_type != "DequantizeLinear":
            continue
        if output_quantize_node.op_type != "QuantizeLinear":
            continue
        bias_initializer = get_init_by_name(model, bias_add_node.input[1])
        if bias_initializer is None:
            continue

        _LOGGER.debug(
            f"Matched quantizable MatMul weight and bias: {matmul_node.name}"
        )

        #############
        # Conversion
        #############
        # quantize weight
        quantized_weight = _quantize_array(
            weight_quantize_params.target,
            weight_quantize_params.scale,
            weight_quantize_params.zero_point,
        )
        # fold the preceding Transpose node into the quantized weight
        quantized_weight = quantized_weight.transpose()
        quantized_weight_name = "{}.weight_quantized".format(matmul_node.name)
        quantized_weight_initializer = numpy_helper.from_array(
            quantized_weight, name=quantized_weight_name
        )
        model.graph.initializer.append(quantized_weight_initializer)

        # QLinearMatMul
        # get qmatmul inputs and outputs
        qmatmul_input = input_quantize_node.input[0]
        qmatmul_inputs = [
            qmatmul_input,  # x
            input_quantize_node.input[1],  # x_scale
            input_quantize_node.input[2],  # x_zero_point
            quantized_weight_name,  # w
            weight_quantize_node.input[1],  # w_scale
            weight_quantize_node.input[2],  # w_zero_point
            output_quantize_node.input[1],  # y_scale
            output_quantize_node.input[2],  # y_zero_point
        ]
        qmatmul_output = matmul_node.output[0]
        qmatmul_name = "{}_quant".format(matmul_node.name)

        # create qmatmul node and add it to graph
        qmatmul_node = onnx.helper.make_node(
            "QLinearMatMul",
            qmatmul_inputs,
            [qmatmul_output],
            qmatmul_name,
        )
        model.graph.node.append(qmatmul_node)

        # QLinearAdd
        # quantize bias
        bias_initializer = numpy_helper.to_array(bias_initializer)
        bias_scale = input_quantize_params.scale * weight_quantize_params.scale
        bias_zero_point = 0
        quantized_bias = _quantize_array(bias_initializer, bias_scale, bias_zero_point)
        quantized_bias_name = "{}.bias_quantized".format(bias_add_node.name)
        quantized_bias_initializer = numpy_helper.from_array(
            quantized_bias, name=quantized_bias_name
        )
        model.graph.initializer.append(quantized_bias_initializer)
        quantized_bias_scale_name = "{}.scale".format(quantized_bias_name)
        model.graph.initializer.append(
            numpy_helper.from_array(
                numpy.asarray(bias_scale), name=quantized_bias_scale_name
            )
        )
        quantized_bias_zero_point_name = "{}.zero_point".format(quantized_bias_name)
        model.graph.initializer.append(
            numpy_helper.from_array(
                numpy.asarray(bias_zero_point, dtype=numpy.uint8),
                name=quantized_bias_zero_point_name,
            )
        )

        # get qadd inputs and outputs
        qadd_input = qmatmul_output
        qadd_inputs = [
            qadd_input,  # x
            output_quantize_node.input[1],  # x_scale
            output_quantize_node.input[2],  # x_zero_point
            quantized_bias_name,  # b
            quantized_bias_scale_name,  # b_scale
            quantized_bias_zero_point_name,  # b_zero_point
            output_quantize_node.input[1],  # y_scale
            output_quantize_node.input[2],  # y_zero_point
        ]
        qadd_output = output_quantize_node.output[0]
        qadd_name = "{}_quant".format(bias_add_node.name)
        kwargs = {"domain": "com.microsoft"}

        # create qlinearadd node and add it to graph
        qadd_node = onnx.helper.make_node(
            "QLinearAdd",
            qadd_inputs,
            [qadd_output],
            qadd_name,
            **kwargs,
        )
        model.graph.node.append(qadd_node)

        # Cleanup
        # delete folded quantization ops
        delete_quant_node(model, weight_dequantize_node, keep_params=False)
        delete_quant_node(model, weight_quantize_node, keep_params=True)
        remove_node_and_params_from_graph(model, weight_transpose_node)
        delete_quant_node(model, input_quantize_node, keep_params=True)
        delete_quant_node(model, output_quantize_node, keep_params=True)

        # delete original MatMul node
        remove_node_and_params_from_graph(model, matmul_node, keep_params=None)
        # delete original Add node
        remove_node_and_params_from_graph(model, bias_add_node, keep_params=None)

        conversion_count += 1

    if matmul_nodes:
        _LOGGER.info(
            f"Converted {conversion_count} quantizable MatMul ops with weight and bias "
            "to QLinearMatMul and QLinearAdd"
        )

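# _quantize_array is used above for both the weight and the bias but is not defined in
# this section. A minimal sketch of the standard affine quantization it is assumed to
# perform; the _sketch name is a stand-in, not the real helper:
import numpy


def _quantize_array_sketch(
    array: numpy.ndarray, scale, zero_point, dtype=numpy.uint8
) -> numpy.ndarray:
    # q = clip(round(x / scale) + zero_point, dtype_min, dtype_max)
    dtype_info = numpy.iinfo(dtype)
    quantized = numpy.round(array / scale) + zero_point
    return numpy.clip(quantized, dtype_info.min, dtype_info.max).astype(dtype)
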