import onnx
from onnxruntime.transformers.onnx_model import OnnxModel


def optimize_fp16_onnx_with_cast(input_onnx_path, optimized_onnx_path,
                                 epsilon):
    m = onnx.load(input_onnx_path)
    onnx_model = OnnxModel(m)

    # Remove every existing node; the graph is rebuilt as Cast -> LayerNormalization (fp32) -> Cast.
    nodes_to_remove = onnx_model.nodes()
    nodes_to_add = [
        onnx.helper.make_node("Cast", ["input"], ["fp32_input"],
                              "cast_input",
                              to=1),
        onnx.helper.make_node("Cast", ["layer_norm.weight"],
                              ["fp32_layer_norm.weight"],
                              "cast_weight",
                              to=1),
        onnx.helper.make_node("Cast", ["layer_norm.bias"],
                              ["fp32_layer_norm.bias"],
                              "cast_bias",
                              to=1),
        onnx.helper.make_node(
            "LayerNormalization",
            ["fp32_input", "fp32_layer_norm.weight", "fp32_layer_norm.bias"],
            ["fp32_output"],
            "layer_norm",
            epsilon=epsilon),  # use fp32 epsilon
        onnx.helper.make_node("Cast", ["fp32_output"], ["output"],
                              "cast_output",
                              to=10)
    ]

    onnx_model.remove_nodes(nodes_to_remove)
    onnx_model.add_nodes(nodes_to_add)
    onnx_model.prune_graph()
    onnx_model.save_model_to_file(optimized_onnx_path)
def optimize_fp16_onnx_no_cast(input_onnx_path, optimized_onnx_path, epsilon):
    m = onnx.load(input_onnx_path)
    onnx_model = OnnxModel(m)

    weight_name = get_weight(onnx_model)
    bias_name = get_bias(onnx_model)
    # Keep the nodes that produce the LayerNormalization weight and bias; remove everything else.
    nodes_to_remove = [
        n for n in onnx_model.nodes()
        if n.output[0] != weight_name and n.output[0] != bias_name
    ]

    node_to_add = onnx.helper.make_node("LayerNormalization", ["input", weight_name, bias_name], ["output"],
                                        "layer_norm",
                                        epsilon=epsilon)

    onnx_model.remove_nodes(nodes_to_remove)
    onnx_model.add_node(node_to_add)
    onnx_model.prune_graph()
    onnx_model.save_model_to_file(optimized_onnx_path)
def optimize_fp16_onnx_no_cast(input_onnx_path, optimized_onnx_path, epsilon):
    m = onnx.load(input_onnx_path)
    onnx_model = OnnxModel(m)

    nodes_to_remove = onnx_model.nodes()
    node_to_add = onnx.helper.make_node(
        "LayerNormalization",
        ["input", "layer_norm.weight", "layer_norm.bias"], ["output"],
        "layer_norm",
        epsilon=epsilon)

    onnx_model.remove_nodes(nodes_to_remove)
    onnx_model.add_node(node_to_add)
    onnx_model.prune_graph()
    onnx_model.save_model_to_file(optimized_onnx_path)
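A minimal usage sketch for the two helpers above; the file paths and the epsilon value below are placeholders, not part of the original examples.

epsilon = 1e-5
optimize_fp16_onnx_no_cast("layer_norm_fp16.onnx",
                           "layer_norm_fp16_no_cast.onnx", epsilon)
optimize_fp16_onnx_with_cast("layer_norm_fp16.onnx",
                             "layer_norm_fp16_with_cast.onnx", epsilon)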
Example #4
    def remove_useless_reshape_nodes(model: OnnxModel):
        """Remove Reshape nodes that are not needed based on symbolic shape inference: the node's input and output have the same shape.
        """
        shape_infer = model.infer_runtime_shape(update=True)
        if shape_infer is None:
            return

        nodes_to_remove = []
        for node in model.nodes():
            if node.op_type == 'Reshape':
                input_shape = shape_infer.get_edge_shape(node.input[0])
                output_shape = shape_infer.get_edge_shape(node.output[0])
                if input_shape and output_shape and input_shape == output_shape:
                    logger.info(
                        f"Remove reshape node {node.name} since its input shape is the same as its output: {input_shape}"
                    )
                    nodes_to_remove.append(node)

        if nodes_to_remove:
            for node in nodes_to_remove:
                model.replace_input_of_all_nodes(node.output[0], node.input[0])
                model.remove_node(node)
            model.prune_graph()
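A hypothetical invocation of the method above, assuming it is exposed as a static helper on FusionUtils in onnxruntime.transformers; the model path is a placeholder.

import onnx
from onnxruntime.transformers.onnx_model import OnnxModel
from onnxruntime.transformers.fusion_utils import FusionUtils

# Load the model, drop no-op Reshape nodes, and save the result.
model = OnnxModel(onnx.load("model.onnx"))
FusionUtils.remove_useless_reshape_nodes(model)
model.save_model_to_file("model_no_reshape.onnx")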
Example #5
def optimize_fp16_onnx_with_cast(input_onnx_path, optimized_onnx_path,
                                 epsilon):
    m = onnx.load(input_onnx_path)
    onnx_model = OnnxModel(m)
    weight_name = get_weight(onnx_model)
    bias_name = get_bias(onnx_model)
    # Keep the nodes that produce the weight and bias; remove everything else.
    nodes_to_remove = [
        n for n in onnx_model.nodes()
        if n.output[0] != weight_name and n.output[0] != bias_name
    ]
    nodes_to_add = [
        onnx.helper.make_node("Cast", ["input"], ["fp32_input"],
                              "cast_input",
                              to=1),
        onnx.helper.make_node("Cast", [weight_name],
                              ["fp32_layer_norm.weight"],
                              "cast_weight",
                              to=1),
        onnx.helper.make_node("Cast", [bias_name], ["fp32_layer_norm.bias"],
                              "cast_bias",
                              to=1),
        onnx.helper.make_node(
            "LayerNormalization",
            ["fp32_input", "fp32_layer_norm.weight", "fp32_layer_norm.bias"],
            ["fp32_output"],
            "layer_norm",
            epsilon=epsilon,
        ),  # use fp32 epsilon
        onnx.helper.make_node("Cast", ["fp32_output"], ["output"],
                              "cast_output",
                              to=10),
    ]

    onnx_model.remove_nodes(nodes_to_remove)
    onnx_model.add_nodes(nodes_to_add)
    onnx_model.prune_graph()
    onnx_model.save_model_to_file(optimized_onnx_path)
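The numeric "to" values passed to the Cast nodes above are ONNX tensor element types; using the symbolic constants makes the intent clearer, as this small check shows.

from onnx import TensorProto

assert TensorProto.FLOAT == 1     # cast up to float32 for LayerNormalization
assert TensorProto.FLOAT16 == 10  # cast the result back to float16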
Example #6
    def auto_mixed_precision(onnx_model: OnnxModel,
                             op_block_list: List[str] = [
                                 'Add', 'LayerNormalization', 'FastGelu'
                             ]):
        """Convert GPT-2 model to mixed precision.
           It detects whether original model has fp16 precision weights, and set parameters for float16 conversion automatically.
        Args:
            onnx_model (OnnxModel): optimized ONNX model
            op_block_list (List[str], optional): . Defaults to ['Add', 'LayerNormalization', 'FastGelu']
        Returns:
            parameters(dict): a dictionary of parameters used in float16 conversion
        """
        op_full_set = set([node.op_type for node in onnx_model.nodes()])
        fp32_op_set = set(op_block_list)
        fp16_op_set = op_full_set.difference(fp32_op_set)
        logger.info(f"fp32 op: {fp32_op_set} fp16 op: {fp16_op_set}")

        # logits is the first output
        logits_output_name = onnx_model.graph().output[0].name

        # We use the weight of the last MatMul node to detect whether the model was stored with float16 weights from training.
        is_weight_fp16_precision = False
        output_name_to_node = onnx_model.output_name_to_node()
        assert logits_output_name in output_name_to_node
        node = output_name_to_node[logits_output_name]
        last_matmul_node = None
        if node.op_type == "MatMul":
            last_matmul_node = node
            logger.info(f"Found last MatMul node for logits: {node.name}")
            initializer = None
            for input in node.input:
                initializer = onnx_model.get_initializer(input)
                if initializer is not None:
                    break

            # When the maximum difference after converting the weights from float to float16 is lower than a threshold (1e-6),
            # we can deduce that the weights were stored in float16 precision.
            max_diff = float_to_float16_max_diff(initializer)
            logger.debug(
                f"max diff of converting weights in last MatMul node {node.name}: {max_diff}"
            )
            is_weight_fp16_precision = (max_diff < 1E-6)
        else:
            logger.warning(
                f"Failed to find MatMul node for logits. Found {node.op_type} of node {node.name}"
            )

        if is_weight_fp16_precision:
            keep_io_types = []
            node_block_list = []
        else:
            # When the original weights are float32, keeping the logits output and the last MatMul in float32 can give better precision.
            keep_io_types = [logits_output_name]
            node_block_list = [last_matmul_node.name]

        parameters = {
            "keep_io_types": keep_io_types,
            "op_block_list": op_block_list,
            "node_block_list": node_block_list,
            "force_fp16_initializers": is_weight_fp16_precision
        }

        logger.info(f"auto_mixed_precision parameters: {parameters}")
        onnx_model.convert_float_to_float16(use_symbolic_shape_infer=True,
                                            **parameters)

        fusion_utils = FusionUtils(onnx_model)
        fusion_utils.remove_cascaded_cast_nodes()
        fusion_utils.remove_useless_cast_nodes()

        return parameters
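A hypothetical usage sketch for the method above; the model paths are placeholders, and auto_mixed_precision is assumed to be callable as shown (in onnxruntime it is a static method of the GPT-2 helper class).

import onnx
from onnxruntime.transformers.onnx_model import OnnxModel

# Load an optimized fp32 GPT-2 model, convert it to mixed precision, and save it.
onnx_model = OnnxModel(onnx.load("gpt2_optimized_fp32.onnx"))
parameters = auto_mixed_precision(onnx_model)
onnx_model.save_model_to_file("gpt2_fp16.onnx")
print(parameters)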