Example #1
def _delete_repeated_qat_blocks(model: ModelProto):
    # removes repeated qat quant/dequant blocks with the same parameters
    # (Quant -> Dequant -> Quant -> Dequant) -> (Quant -> Dequant)
    graph = ONNXGraph(model)
    nodes_to_delete = []
    quant_nodes = [
        n for n in model.graph.node if n.op_type == "QuantizeLinear"
    ]
    for quant_node_1 in quant_nodes:
        dequant_node_1 = graph.get_node_single_child(quant_node_1)
        if not dequant_node_1 or dequant_node_1.op_type != "DequantizeLinear":
            continue
        quant_node_2 = graph.get_node_single_child(dequant_node_1)
        if not quant_node_2 or quant_node_2.op_type != "QuantizeLinear":
            continue
        dequant_node_2 = graph.get_node_single_child(quant_node_2)
        if not dequant_node_2 or dequant_node_2.op_type != "DequantizeLinear":
            continue

        # forward first qat block input to that of the second
        quant_node_2.input[0] = quant_node_1.input[0]

        # remove repeated quant/dequant block
        nodes_to_delete.append(quant_node_1)
        nodes_to_delete.append(dequant_node_1)

    for n in nodes_to_delete:
        delete_quant_node(model, n)
Example #2
    def get_named_prunable_params(self, model: Any) -> Dict[str, numpy.ndarray]:
        """
        loads the prunable parameters in a standardized way so that weight magnitude
        analysis may be run on each

        :param model: model to load the prunable parameters from
        :return: dictionary of prunable parameter name as listed in the ModelInfo to
            a numpy array of the values of the parameter
        """
        graph = ONNXGraph(model)
        return {
            layer_name: numpy_helper.to_array(graph.get_init_by_name(layer_name, False))
            for layer_name, layer_info in self._model_info.layer_info.items()
            if layer_info.prunable
        }
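The returned dictionary maps each prunable layer name to its weight values, which makes magnitude/sparsity analysis straightforward. A minimal usage sketch, assuming an object named analyzer that exposes this method and an ONNX ModelProto named model (both names are illustrative):

# hypothetical usage: report the sparsity of every prunable parameter
named_params = analyzer.get_named_prunable_params(model)
for name, values in named_params.items():
    sparsity = float((values == 0).sum()) / values.size
    print(f"{name}: shape={values.shape}, sparsity={sparsity:.2%}")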
Example #3
def quantize_torch_qat_export(
    model: Union[ModelProto, str],
    output_file_path: Union[str, None] = None,
    inplace: bool = True,
) -> ModelProto:
    """
    :param model: The model to convert, or a file path to it
    :param output_file_path: File path to save the converted model to
    :param inplace: If true, does conversion of model in place. Default is true
    :return: Converts a model exported from a torch QAT session from a QAT graph with
        fake quantize ops surrounding operations to a quantized graph with quantized
        operations. All quantized Convs and FC inputs and outputs be surrounded by
        fake quantize ops
    """
    if isinstance(model, str):
        model = onnx.load(model)

    if not inplace:
        model = deepcopy(model)

    _fold_qat_conv_bns(model)
    _fold_relu_quants(model)
    _convert_single_constants_to_initializers(model)
    _delete_repeated_qat_blocks(model)
    _convert_quantizable_ops(model)
    quantize_resnet_identity_add_inputs(model)
    quantized_residual_add_optim(model)
    _remove_duplicate_quantize__ops(model)
    ONNXGraph(model).sort_nodes_topologically()

    if output_file_path:
        onnx.save(model, output_file_path)

    return model
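A typical end-to-end call, sketched with illustrative file paths for a model previously exported to ONNX from a torch QAT session:

import onnx

# hypothetical usage: convert the QAT export and save the quantized graph
qat_model = onnx.load("model_qat.onnx")  # illustrative path
quantized_model = quantize_torch_qat_export(
    qat_model, output_file_path="model_quantized.onnx", inplace=True
)
print(f"quantized model has {len(quantized_model.graph.node)} nodes")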
Example #4
def _skip_input_quantize(model: ModelProto) -> Optional[str]:
    if (
        len(model.graph.input) != 1
        or model.graph.input[0].type.tensor_type.elem_type != 1
    ):
        # more than 1 input or input is not FP32
        return (
            "Not modifying ONNX graph inputs - either graph has more than one "
            "input or input type is not FP32"
        )

    input_node = model.graph.input[0]
    input_children = [
        node for node in model.graph.node if input_node.name in node.input
    ]
    if not all(node.op_type == "QuantizeLinear" for node in input_children):
        return (
            "Not modifying ONNX graph inputs - only QuantizeLinear nodes may follow"
            "the FP32 input tensor in original graph, prior to converting to uint8"
        )

    graph = ONNXGraph(model)
    for quantize_node in input_children:
        quantize_children = graph.get_node_children(quantize_node)
        quantize_node_id = quantize_node.output[0]
        for child_node in quantize_children:
            input_idx = [
                idx
                for idx, inp in enumerate(child_node.input)
                if inp == quantize_node_id
            ]
            if not input_idx:
                continue
            input_idx = input_idx[0]
            graph.update_node_input(child_node, input_node.name, input_idx)
            _LOGGER.debug(
                f"set node with output id {child_node.output[0]} as initial node in "
                "graph"
            )

    _LOGGER.debug(
        f"deleting QuantizeLinear node(s) with output id(s): "
        f"{[n.output for n in input_children]}"
    )
    graph.delete_nodes(input_children)  # only contains references to the Quantize nodes
    graph.delete_unused_initializers()  # cleanup
    input_node.type.tensor_type.elem_type = 2  # fp32 -> uint8
    _LOGGER.info("Model initial QuantizeLinear node(s) deleted and inputs set to uint8")

    return None
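Because the helper reports failure by returning a message instead of raising, the caller is expected to check the return value; a minimal caller sketch (the warning call is illustrative):

# hypothetical caller: None means the graph input was converted to uint8
skip_reason = _skip_input_quantize(model)
if skip_reason is not None:
    _LOGGER.warning(skip_reason)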
Example #5
def _get_model_last_prunable_nodes(model: ModelProto) -> List[NodeProto]:
    graph = ONNXGraph(model)
    output_names = {tens.name for tens in model.graph.output}
    stack = [
        node for node in model.graph.node
        if any(out in output_names for out in node.output)
    ]
    seen_node_ids = {output_id for node in stack for output_id in node.output}
    last_prunable_nodes = []
    while stack:
        node = stack.pop()
        if node.op_type in ["Gemm", "MatMul", "Conv"]:
            last_prunable_nodes.append(node)
            continue
        for parent in graph.get_node_parents(node):
            if any(output_id in seen_node_ids for output_id in parent.output):
                continue
            stack.append(parent)
            seen_node_ids.update(set(parent.output))
    return last_prunable_nodes
Example #6
def _get_model_first_prunable_nodes(model: ModelProto) -> List[NodeProto]:
    graph = ONNXGraph(model)
    input_names = {tens.name for tens in model.graph.input}
    stack = [
        node for node in model.graph.node
        if any(inp in input_names for inp in node.input)
    ]
    seen_node_ids = {output_id for node in stack for output_id in node.output}
    first_prunable_nodes = []
    while stack:
        node = stack.pop()
        if node.op_type in ["Gemm", "MatMul", "Conv"]:
            first_prunable_nodes.append(node)
            continue
        for child in graph.get_node_children(node):
            if any(output_id in seen_node_ids for output_id in child.output):
                continue
            stack.append(child)
            seen_node_ids.update(set(child.output))
    return first_prunable_nodes
Example #7
def _cleanup_unused_quants(model: ModelProto):
    """
    A pass for removing unused Quantize->Dequantize blocks.
    This should be called at the end of conversion, once all of the conversions
    to quantized operators have been tried.
    Example:
    op -> QuantizeLinear -> DequantizeLinear -> non-quantized op
    => op -> non-quantized operator
    """
    graph = ONNXGraph(model)
    nodes_to_delete = []
    quant_nodes = [
        n for n in model.graph.node if n.op_type == "QuantizeLinear"
    ]
    for quant_node in quant_nodes:
        dequant_node = graph.get_node_single_child(quant_node)
        if not dequant_node or dequant_node.op_type != "DequantizeLinear":
            continue

        dequant_children = graph.get_node_children(dequant_node)
        # keep the block if any child already consumes the quantized output directly
        removable = not any(
            isinstance(child, onnx.NodeProto) and child.op_type in _QLINEAR_OP_NAMES
            for child in dequant_children
        )
        if not removable:
            continue

        # Forward QuantizeLinear input to DequantizeLinear output
        for child in dequant_children:
            _replace_input_id_model(model, dequant_node.output[0],
                                    quant_node.input[0])

        # Remove QuantizeLinear->DequantizeLinear block
        nodes_to_delete.append(quant_node)
        nodes_to_delete.append(dequant_node)

    for n in nodes_to_delete:
        delete_quant_node(model, n)
Example #8
def _get_node_param_array(
    node: NodeProto,
    graph: ONNXGraph,
    param_idx: int = 1,
) -> Optional[numpy.ndarray]:
    if len(node.input) <= param_idx:
        # no such param exists
        return None
    param = graph.get_init_by_name(node.input[param_idx])
    if param is None:
        # input is not a param stored as an initializer in the graph
        return None
    return numpy_helper.to_array(param)
Example #9
    def extract_layer_info(self,
                           model: ModelProto) -> "OrderedDict[str, LayerInfo]":
        """
        :param model: ONNX model to extract LayerInfo of
        :return: ordered dictionary of layer name to LayerInfo object for the prunable
            model layers
        """
        layers = OrderedDict()
        graph = ONNXGraph(model)
        graph.sort_nodes_topologically()  # for execution order

        first_prunable_nodes = _get_model_first_prunable_nodes(model)
        last_prunable_nodes = _get_model_last_prunable_nodes(model)

        for node in graph.nodes:
            layer_info = None
            if node.op_type == "Conv":
                layer_info = self._make_conv_layer_info(
                    node, graph, len(layers))
            elif node.op_type == "Gemm":
                layer_info = self._make_gemm_layer_info(
                    node, graph, len(layers))
            elif node.op_type == "MatMul":
                layer_info = self._make_matmul_layer_info(
                    node, graph, len(layers))

            if layer_info is not None:
                if node.name:
                    layer_info.attributes["node_name"] = node.name
                if node.output:
                    layer_info.attributes["node_output_id"] = node.output[0]
                if node in first_prunable_nodes:
                    layer_info.attributes["first_prunable_layer"] = True
                if node in last_prunable_nodes:
                    layer_info.attributes["last_prunable_layer"] = True
                layers[layer_info.name] = layer_info

        return layers
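A usage sketch, assuming model_info is an instance of the class defining this method and that LayerInfo exposes the attributes dictionary used above (names and path are illustrative):

import onnx

# hypothetical usage: list prunable layers in execution order with their output ids
model = onnx.load("model.onnx")  # illustrative path
for name, info in model_info.extract_layer_info(model).items():
    print(name, info.attributes.get("node_output_id"))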
Example #10
def _get_node_dependency_names(graph: ONNXGraph, node: onnx.NodeProto,
                               structure_type: str) -> Set[str]:
    # returns the set of parameter names that should be pruned to match
    # the target dimensions of this node
    unchecked_nodes = _get_next_layer_deps(graph, node, structure_type)
    seen_output_ids = _get_node_output_ids(unchecked_nodes)
    dependent_params = set()

    if structure_type == "filter" and len(node.input) > 2:
        # node bias depends on num filters
        dependent_params.add(node.input[2])

    while unchecked_nodes:
        current_node = unchecked_nodes.pop(0)
        if not isinstance(current_node, onnx.NodeProto):
            continue

        if current_node.op_type in _OUTPUT_CHANNEL_OP_TYPES:
            prunable = current_node.op_type in _PRUNABLE_OP_TYPES
            params = (
                list(current_node.input[1:])  # skip layer input tensor
                if not (prunable and structure_type != "filter") else
                [current_node.input[1]]  # bias not dependent on prev filter
            )

            for param in params:
                if graph.get_init_by_name(param) is not None:
                    dependent_params.add(param)
            if prunable and not _is_group_conv(current_node):
                # continue on other branches, do not go past prunable nodes
                continue
        dep_nodes = _get_next_layer_deps(graph, current_node, structure_type)
        for dep_node in dep_nodes:
            dep_node_ids = _get_node_output_ids(dep_node)
            if dep_node_ids.isdisjoint(seen_output_ids):
                unchecked_nodes.append(dep_node)
                seen_output_ids.update(dep_node_ids)

    return dependent_params
Example #11
def _convert_quantizable_ops(model: ModelProto):
    quantizable_nodes = [
        n for n in model.graph.node if n.op_type in ["Conv", "Gemm"]
    ]
    for quantizable_node in quantizable_nodes:
        graph = ONNXGraph(model)

        weight_dequant = graph.get_node_single_parent(quantizable_node, 1)
        if not weight_dequant or weight_dequant.op_type != "DequantizeLinear":
            continue
        weight_quant = graph.get_node_single_parent(weight_dequant, 0)
        if not weight_quant or weight_quant.op_type != "QuantizeLinear":
            continue

        input_quant = graph.get_node_single_parent(quantizable_node, 0)
        if not input_quant or input_quant.op_type not in _QUANTIZE_OP_NAMES:
            continue

        output_quant = graph.get_node_single_child(quantizable_node)
        if not output_quant or output_quant.op_type not in _QUANTIZE_OP_NAMES:
            continue

        if quantizable_node.op_type == "Conv":
            _convert_quantizable_conv(
                model,
                quantizable_node,
                input_quant,
                weight_dequant,
                weight_quant,
                output_quant,
            )

        if quantizable_node.op_type == "Gemm":
            _convert_quantizable_gemm(
                model,
                quantizable_node,
                input_quant,
                weight_dequant,
                weight_quant,
                output_quant,
            )
Example #12
def _get_next_layer_deps(
    graph: ONNXGraph, node: onnx.NodeProto, structure_type: str
) -> List[onnx.NodeProto]:
    # channel pruning depends on parent layers; filter pruning on child layers
    if structure_type == "channel":
        return [
            parent_node
            for parent_node in graph.get_node_parents(node)
            if isinstance(parent_node, onnx.NodeProto)
        ]
    return graph.get_node_children(node)
Example #13
def get_param_structured_pruning_group_dependencies(
    model: Union[onnx.ModelProto, str],
    structure_type: str = "filter",
) -> Dict[str, List[str]]:
    """
    :param model: model to generate pruning groups and dependencies for
    :param structure_type: valid options are 'filter' and 'channel'. Generates
        dependency map for corresponding pruning scheme. Default is 'filter'
    :return: dictionary mapping parameter names that should be grouped during
        structured pruning to a list of parameter names whose values should be
        updated according to the results of pruning that group. Prunable parameter
        names in each group are joined into a comma-separated string
    """
    if structure_type not in ["filter", "channel"]:
        raise ValueError(
            f"invalid structure_type {structure_type}. not in ['filter', 'channel']"
        )

    if isinstance(model, str):
        model = onnx.load(model)

    graph = ONNXGraph(model)
    param_name_to_dependents = {}  # Dict[str, Set[str]]
    for node in model.graph.node:
        if node.op_type not in _PRUNABLE_OP_TYPES or (graph.get_init_by_name(
                node.input[1]) is None):
            # main param not found or not prunable
            continue

        param_name_to_dependents[node.input[1]] = _get_node_dependency_names(
            graph, node, structure_type)

    # merge disjoint sets of dependencies (could improve with union-find)
    prunable_param_group_to_dep_params = []  # List[Tuple[List, Set]]
    for prunable_param_name, dep_params in param_name_to_dependents.items():
        intersected_group_idxs = {
            idx
            for idx, (_, group_dep_params
                      ) in enumerate(prunable_param_group_to_dep_params)
            if not dep_params.isdisjoint(group_dep_params)
        }
        new_group_val = ([prunable_param_name], dep_params)
        if not intersected_group_idxs:
            prunable_param_group_to_dep_params.append(new_group_val)
        else:
            non_intersected_vals = []
            for idx, (prunable_param_group, group_dep_params
                      ) in enumerate(prunable_param_group_to_dep_params):
                if idx not in intersected_group_idxs:
                    non_intersected_vals.append(
                        (prunable_param_group, group_dep_params))
                else:
                    new_group_val = (
                        new_group_val[0] + prunable_param_group,
                        new_group_val[1].union(group_dep_params),
                    )
            prunable_param_group_to_dep_params = non_intersected_vals + [
                new_group_val
            ]

    return {
        ",".join(prunable_param_group): list(dependent_params)
        for prunable_param_group, dependent_params in
        prunable_param_group_to_dep_params
    }
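A usage sketch with an illustrative model path:

# hypothetical usage: inspect filter-pruning groups and their dependent params
dependency_map = get_param_structured_pruning_group_dependencies(
    "model.onnx", structure_type="filter"
)
for group, dependent_params in dependency_map.items():
    print(f"group [{group}] -> {len(dependent_params)} dependent parameter(s)")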
Example #14
def _convert_quantizable_matmul_and_add(model: ModelProto):
    """
    A pass for converting a MatMul with kernel and bias into a quantized representation

    | Starting with:
    |          INPUT         QuantizeLinear (with constant kernel)
    |            |               |
    |     QuantizeLinear     DequantizeLinear
    |            |               |
    |     DequantizeLinear   Transpose
    |                  |      |
    |                   MatMul
    |                     |
    |                    Add (with constant bias)
    |                     |
    |               QuantizeLinear
    |                     |
    |              DequantizeLinear
    |                     |
    |                  OUTPUT
    | We end up converting to:
    |       INPUT
    |         |
    |     QuantizeLinear
    |         |
    |     QLinearMatMul (with constant kernel)
    |         |
    |     QLinearAdd (with constant bias)
    |         |
    |     DequantizeLinear
    |         |
    |       OUTPUT
    """
    conversion_count = 0
    matmul_nodes = [n for n in model.graph.node if n.op_type in ["MatMul"]]
    for matmul_node in matmul_nodes:
        graph = ONNXGraph(model)
        #############
        # Matching
        #############
        weight_transpose_node = graph.get_node_single_parent(matmul_node, 1)
        if not weight_transpose_node or weight_transpose_node.op_type != "Transpose":
            continue

        weight_dequantize_node = graph.get_node_single_parent(
            weight_transpose_node, 0)
        if (not weight_dequantize_node
                or weight_dequantize_node.op_type != "DequantizeLinear"):
            continue
        weight_quantize_node = graph.get_node_single_parent(
            weight_dequantize_node, 0)
        if not weight_quantize_node or weight_quantize_node.op_type != "QuantizeLinear":
            continue

        input_quantize_node = graph.get_node_single_parent(matmul_node, 0)
        if (not input_quantize_node
                or input_quantize_node.op_type not in _QUANTIZE_OP_NAMES):
            continue

        bias_add_node = graph.get_node_single_child(matmul_node)
        if not bias_add_node or bias_add_node.op_type != "Add":
            continue
        output_quantize_node = graph.get_node_single_child(bias_add_node)
        if (not output_quantize_node
                or output_quantize_node.op_type not in _QUANTIZE_OP_NAMES):
            continue

        input_quantize_params = get_quantization_params(model,
                                                        input_quantize_node,
                                                        include_target=False)
        weight_quantize_params = get_quantization_params(model,
                                                         weight_quantize_node,
                                                         include_target=True)
        if weight_quantize_params.target is None:
            # weight initializer not included
            continue
        if input_quantize_node.op_type != "DequantizeLinear":
            continue
        if output_quantize_node.op_type != "QuantizeLinear":
            continue
        bias_initializer = get_init_by_name(model, bias_add_node.input[1])
        if bias_initializer is None:
            continue

        _LOGGER.debug(
            f"Matched quantizable MatMul weight and bias: {matmul_node.name}")

        #############
        # Conversion
        #############
        # quantize weight
        quantized_weight = _quantize_array(
            weight_quantize_params.target,
            weight_quantize_params.scale,
            weight_quantize_params.zero_point,
        )
        # transpose the kernel to account for the removed Transpose node
        quantized_weight = quantized_weight.transpose()
        quantized_weight_name = "{}.weight_quantized".format(matmul_node.name)
        quantized_weight_initializer = numpy_helper.from_array(
            quantized_weight, name=quantized_weight_name)
        model.graph.initializer.append(quantized_weight_initializer)

        # QLinearMatMul
        # get qmatmul inputs and outputs
        qmatmul_input = input_quantize_node.input[0]
        qmatmul_inputs = [
            qmatmul_input,  # x
            input_quantize_node.input[1],  # x_scale
            input_quantize_node.input[2],  # x_zero_point
            quantized_weight_name,  # w
            weight_quantize_node.input[1],  # w_scale
            weight_quantize_node.input[2],  # w_zero_point
            output_quantize_node.input[1],  # y_scale
            output_quantize_node.input[2],  # y_zero_point
        ]
        qmatmul_output = matmul_node.output[0]
        qmatmul_name = "{}_quant".format(matmul_node.name)

        # create qmatmul node and add it to graph
        qmatmul_node = onnx.helper.make_node(
            "QLinearMatMul",
            qmatmul_inputs,
            [qmatmul_output],
            qmatmul_name,
        )
        model.graph.node.append(qmatmul_node)

        # QLinearAdd
        # quantize bias
        bias_initializer = numpy_helper.to_array(bias_initializer)
        bias_scale = input_quantize_params.scale * weight_quantize_params.scale
        bias_zero_point = 0
        quantized_bias = _quantize_array(bias_initializer, bias_scale,
                                         bias_zero_point)
        quantized_bias_name = "{}.bias_quantized".format(bias_add_node.name)
        quantized_bias_initializer = numpy_helper.from_array(
            quantized_bias, name=quantized_bias_name)
        model.graph.initializer.append(quantized_bias_initializer)
        quantized_bias_scale_name = "{}.scale".format(quantized_bias_name)
        model.graph.initializer.append(
            numpy_helper.from_array(numpy.asarray(bias_scale),
                                    name=quantized_bias_scale_name))
        quantized_bias_zero_point_name = "{}.zero_point".format(
            quantized_bias_name)
        model.graph.initializer.append(
            numpy_helper.from_array(
                numpy.asarray(bias_zero_point, dtype=numpy.uint8),
                name=quantized_bias_zero_point_name,
            ))

        # get qadd inputs and outputs
        qadd_input = qmatmul_output
        qadd_inputs = [
            qadd_input,  # x
            output_quantize_node.input[1],  # x_scale
            output_quantize_node.input[2],  # x_zero_point
            quantized_bias_name,  # b
            quantized_bias_scale_name,  # b_scale
            quantized_bias_zero_point_name,  # b_zero_point
            output_quantize_node.input[1],  # y_scale
            output_quantize_node.input[2],  # y_zero_point
        ]
        qadd_output = output_quantize_node.output[0]
        qadd_name = "{}_quant".format(bias_add_node.name)
        kwargs = {"domain": "com.microsoft"}
        # create qlinearadd node and add it to graph
        qadd_node = onnx.helper.make_node(
            "QLinearAdd",
            qadd_inputs,
            [qadd_output],
            qadd_name,
            **kwargs,
        )
        model.graph.node.append(qadd_node)

        # Cleanup
        # delete folded quantization ops
        delete_quant_node(model, weight_dequantize_node, keep_params=False)
        delete_quant_node(model, weight_quantize_node, keep_params=True)
        remove_node_and_params_from_graph(model, weight_transpose_node)
        delete_quant_node(model, input_quantize_node, keep_params=True)
        delete_quant_node(model, output_quantize_node, keep_params=True)

        # delete original MatMul node
        remove_node_and_params_from_graph(model, matmul_node, keep_params=None)
        # delete original Add node
        remove_node_and_params_from_graph(model,
                                          bias_add_node,
                                          keep_params=None)

        conversion_count += 1

    if matmul_nodes:
        _LOGGER.info(
            f"Converted {conversion_count} quantizable MatMul ops with weight and bias "
            "to QLinearMatMul and QLinearAdd")