Code example #1
def _skip_input_quantize(model: ModelProto) -> Optional[str]:
    """
    Attempt to fold the initial QuantizeLinear node(s) into the model input so
    the graph accepts uint8 directly instead of FP32.

    On success, the QuantizeLinear node(s) that consume the graph input are
    deleted, their consumers are rewired to read the input tensor directly,
    and the input's elem_type is switched from FP32 to uint8.

    :param model: ONNX model to modify in place
    :return: None on success, otherwise a string describing why the graph was
        left unmodified
    """
    if (
        len(model.graph.input) != 1
        or model.graph.input[0].type.tensor_type.elem_type != 1
    ):
        # more than 1 input or input is not FP32 (elem_type 1 == FLOAT)
        return (
            "Not modifying ONNX graph inputs - either graph has more than one "
            "input or input type is not FP32"
        )

    input_node = model.graph.input[0]
    input_children = [
        node for node in model.graph.node if input_node.name in node.input
    ]
    if not all(node.op_type == "QuantizeLinear" for node in input_children):
        # NOTE: fixed missing space in the concatenated message
        # (previously rendered as "...may followthe FP32...")
        return (
            "Not modifying ONNX graph inputs - only QuantizeLinear nodes may "
            "follow the FP32 input tensor in original graph, prior to "
            "converting to uint8"
        )

    graph = ONNXGraph(model)
    for quantize_node in input_children:
        quantize_children = graph.get_node_children(quantize_node)
        quantize_node_id = quantize_node.output[0]
        for child_node in quantize_children:
            # first input index of child_node that reads the quantize output,
            # or None if this child does not consume it
            input_idx = next(
                (
                    idx
                    for idx, inp in enumerate(child_node.input)
                    if inp == quantize_node_id
                ),
                None,
            )
            if input_idx is None:
                continue
            # rewire the child to read the graph input directly
            graph.update_node_input(child_node, input_node.name, input_idx)
            _LOGGER.debug(
                f"set node with output id {child_node.output[0]} as initial node in "
                "graph"
            )

    _LOGGER.debug(
        f"deleting QuantizeLinear node(s) with output id(s): "
        f"{[n.output for n in input_children]}"
    )
    graph.delete_nodes(input_children)  # only contains references to the Quantize nodes
    graph.delete_unused_initializers()  # cleanup
    input_node.type.tensor_type.elem_type = 2  # fp32 -> uint8
    _LOGGER.info("Model initial QuantizeLinear node(s) deleted and inputs set to uint8")

    return None
Code example #2
def _get_model_first_prunable_nodes(model: ModelProto) -> List[NodeProto]:
    """
    Traverse the graph from its inputs and collect the first prunable node
    (Gemm, MatMul, or Conv) encountered along each path.

    :param model: ONNX model to search
    :return: list of the first prunable nodes reachable from the graph inputs
    """
    graph = ONNXGraph(model)
    graph_input_names = {inp.name for inp in model.graph.input}

    # seed the traversal with every node that reads a graph input directly
    frontier = []
    for node in model.graph.node:
        if any(name in graph_input_names for name in node.input):
            frontier.append(node)

    # track visited nodes by their output tensor ids
    visited_output_ids = set()
    for node in frontier:
        visited_output_ids.update(node.output)

    first_prunable_nodes = []
    while frontier:
        current = frontier.pop()
        if current.op_type in ("Gemm", "MatMul", "Conv"):
            # prunable node found - stop exploring this path
            first_prunable_nodes.append(current)
        else:
            for child in graph.get_node_children(current):
                unvisited = all(
                    out_id not in visited_output_ids for out_id in child.output
                )
                if unvisited:
                    visited_output_ids.update(child.output)
                    frontier.append(child)
    return first_prunable_nodes
Code example #3
def _cleanup_unused_quants(model: ModelProto):
    """
    A pass for removing unused Quantize->Dequantize blocks.
    This should be called at the end of conversion, once all of the conversions
    to quantized operators has been tried.
    Example:
    op -> QuantizeLinear -> DequantizeLinear -> non-quantized op
    => op -> non-quantized operator

    :param model: ONNX model to modify in place
    """
    graph = ONNXGraph(model)
    nodes_to_delete = []
    quant_nodes = [
        n for n in model.graph.node if n.op_type == "QuantizeLinear"
    ]
    for quant_node in quant_nodes:
        dequant_node = graph.get_node_single_child(quant_node)
        if not dequant_node or dequant_node.op_type != "DequantizeLinear":
            continue

        # the block is only removable when no consumer is a QLinear op
        # (those require the quantized tensor produced by this block)
        dequant_children = graph.get_node_children(dequant_node)
        if any(
            isinstance(child, onnx.NodeProto)
            and child.op_type in _QLINEAR_OP_NAMES
            for child in dequant_children
        ):
            continue

        # Forward QuantizeLinear input to DequantizeLinear output.
        # NOTE: fixed redundant work - the original looped over every child
        # and made the identical model-wide replacement call each time; one
        # call rewires all consumers
        if dequant_children:
            _replace_input_id_model(
                model, dequant_node.output[0], quant_node.input[0]
            )

        # Remove QuantizeLinear->DequantizeLinear block
        nodes_to_delete.append(quant_node)
        nodes_to_delete.append(dequant_node)

    for node in nodes_to_delete:
        delete_quant_node(model, node)
Code example #4
def _get_next_layer_deps(graph: ONNXGraph, node: onnx.NodeProto,
                         structure_type: str) -> List[onnx.NodeProto]:
    """
    Return the nodes that depend on ``node`` for the given pruning structure:
    NodeProto parents for "channel" structure, children otherwise.

    :param graph: graph helper wrapping the model
    :param node: node whose dependencies are collected
    :param structure_type: "channel" selects parent nodes; any other value
        selects child nodes
    :return: list of dependent nodes
    """
    if structure_type == "channel":
        # keep only real nodes; parents may include initializers/inputs
        return [
            parent for parent in graph.get_node_parents(node)
            if isinstance(parent, onnx.NodeProto)
        ]
    return graph.get_node_children(node)