def _skip_input_quantize(model: ModelProto) -> Optional[str]:
    """
    Delete initial QuantizeLinear node(s) and re-type the graph input to uint8.

    Applies only when the model has exactly one input, that input is FP32,
    and every node consuming it is a QuantizeLinear. The quantize nodes are
    removed, their consumers are rewired to read the graph input directly,
    and the input tensor's elem_type is switched to uint8.

    :param model: ONNX model to modify in place
    :return: None if the transform was applied, otherwise a message
        explaining why the graph was left unmodified
    """
    if (
        len(model.graph.input) != 1
        or model.graph.input[0].type.tensor_type.elem_type != 1
    ):
        # more than 1 input or input is not FP32
        return (
            "Not modifying ONNX graph inputs - either graph has more than one "
            "input or input type is not FP32"
        )

    input_node = model.graph.input[0]
    input_children = [
        node for node in model.graph.node if input_node.name in node.input
    ]
    if not all(node.op_type == "QuantizeLinear" for node in input_children):
        # fix: the original implicit string concatenation was missing the
        # trailing space after "follow", emitting the word "followthe"
        return (
            "Not modifying ONNX graph inputs - only QuantizeLinear nodes may follow "
            "the FP32 input tensor in original graph, prior to converting to uint8"
        )

    graph = ONNXGraph(model)
    for quantize_node in input_children:
        quantize_children = graph.get_node_children(quantize_node)
        quantize_node_id = quantize_node.output[0]
        for child_node in quantize_children:
            # locate which of the child's inputs reads the quantize output
            input_idx = [
                idx
                for idx, inp in enumerate(child_node.input)
                if inp == quantize_node_id
            ]
            if not input_idx:
                continue
            input_idx = input_idx[0]
            # rewire the child to read the graph input directly
            graph.update_node_input(child_node, input_node.name, input_idx)
            _LOGGER.debug(
                f"set node with output id {child_node.output[0]} as initial node in "
                "graph"
            )

    _LOGGER.debug(
        f"deleting QuantizeLinear node(s) with output id(s): "
        f"{[n.output for n in input_children]}"
    )
    # input_children only contains references to the Quantize nodes here,
    # since the all(...) check above would have returned early otherwise
    graph.delete_nodes(input_children)
    graph.delete_unused_initializers()  # cleanup
    input_node.type.tensor_type.elem_type = 2  # fp32 -> uint8
    _LOGGER.info("Model initial QuantizeLinear node(s) deleted and inputs set to uint8")
    return None
def _get_model_first_prunable_nodes(model: ModelProto) -> List[NodeProto]:
    """
    Find the first prunable nodes (Gemm, MatMul, or Conv) reachable from the
    model's graph inputs via a depth-first traversal.

    :param model: ONNX model to search
    :return: list of the first Gemm/MatMul/Conv nodes encountered on each
        path from a graph input
    """
    graph = ONNXGraph(model)
    graph_input_names = {graph_input.name for graph_input in model.graph.input}

    # seed the traversal with every node fed directly by a graph input
    to_visit = []
    for node in model.graph.node:
        if any(name in graph_input_names for name in node.input):
            to_visit.append(node)

    # track output ids of nodes already queued so each node is visited once
    visited_output_ids = set()
    for node in to_visit:
        visited_output_ids.update(node.output)

    prunable_nodes = []
    while to_visit:
        current = to_visit.pop()
        if current.op_type in ["Gemm", "MatMul", "Conv"]:
            # stop descending this path once a prunable node is found
            prunable_nodes.append(current)
        else:
            for child in graph.get_node_children(current):
                if all(
                    out_id not in visited_output_ids for out_id in child.output
                ):
                    to_visit.append(child)
                    visited_output_ids.update(child.output)
    return prunable_nodes
def _cleanup_unused_quants(model: ModelProto):
    """
    A pass for removing unused Quantize->Dequantize blocks.
    This should be called at the end of conversion, once all of the conversions
    to quantized operators has been tried.
    Example:
    op -> QuantizeLinear -> DequantizeLinear -> non-quantized op
    => op -> non-quantized operator

    :param model: ONNX model to modify in place
    """
    graph = ONNXGraph(model)
    nodes_to_delete = []
    quant_nodes = [n for n in model.graph.node if n.op_type == "QuantizeLinear"]
    for quant_node in quant_nodes:
        # only consider Quantize nodes whose sole consumer is a Dequantize
        dequant_node = graph.get_node_single_child(quant_node)
        if not dequant_node or dequant_node.op_type != "DequantizeLinear":
            continue

        # the pair is only removable if no consumer is a quantized (QLinear) op
        dequant_children = graph.get_node_children(dequant_node)
        removable = not any(
            isinstance(child, onnx.NodeProto) and child.op_type in _QLINEAR_OP_NAMES
            for child in dequant_children
        )
        if not removable:
            continue

        # Forward QuantizeLinear input to DequantizeLinear output.
        # fix: _replace_input_id_model rewrites the whole model, so the
        # original per-child loop repeated the identical call redundantly;
        # one call (when any consumer exists) has the same effect
        if dequant_children:
            _replace_input_id_model(
                model, dequant_node.output[0], quant_node.input[0]
            )

        # Remove QuantizeLinear->DequantizeLinear block
        nodes_to_delete.append(quant_node)
        nodes_to_delete.append(dequant_node)

    for node in nodes_to_delete:
        delete_quant_node(model, node)
def _get_next_layer_deps(
    graph: ONNXGraph, node: onnx.NodeProto, structure_type: str
) -> List[onnx.NodeProto]:
    """
    Return the adjacent nodes to traverse next for the given pruning
    structure type.

    :param graph: ONNXGraph wrapper of the model
    :param node: node whose neighbors are requested
    :param structure_type: "channel" selects the node's parent nodes
        (filtered to actual NodeProto entries); any other value selects
        the node's children
    :return: list of neighboring nodes
    """
    if structure_type == "channel":
        # upstream dependencies; drop non-node parents (e.g. initializers)
        return [
            parent
            for parent in graph.get_node_parents(node)
            if isinstance(parent, onnx.NodeProto)
        ]
    return graph.get_node_children(node)