Example #1
def add_0_5_to_normalized_input(m):
    """For normalized input between -0.5 ~ 0.5, add 0.5 to the input to keep it
    between 0 ~ 1.

    :param m: the model proto
    """
    g = m.graph
    if len(g.input) > 1:
        print("This model has multiple inputs. Cannot normalize input.")
        return
    input_shape = helper.get_shape_from_value_info(g.input[0])
    if len(input_shape) != 4:
        print("The input shape is not BCHW. Cannot normalize input.")
        return
    # Construct weight
    ch = input_shape[1]
    weight_np = np.zeros((ch, ch, 3, 3)).astype('float32')
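    # Identity mapping: for each channel, only the center tap of the matching
    # input channel is 1, so the Conv leaves the data unchanged except for the
    # +0.5 bias added below.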
    for i in range(ch):
        weight_np[i, i, 1, 1] = 1.0
    new_weight = helper.numpy_to_constant("input_norm_weight", weight_np)
    # Construct bias
    bias_np = np.array([0.5] * ch).astype('float32')
    new_bias = helper.numpy_to_constant("input_norm_bias", bias_np)
    # Construct Conv
    new_conv = onnx.helper.make_node(
        'Conv',
        ['origin_input', "input_norm_weight", "input_norm_bias"],
        [g.input[0].name],
        name='input_norm',
        dilations=[1, 1],
        kernel_shape=[3, 3],
        pads=[1, 1, 1, 1],
        strides=[1, 1]
    )
    # Construct value_infos
    old_input_value = g.input.pop()
    weight_value = onnx.helper.make_tensor_value_info(
        'input_norm_weight',
        old_input_value.type.tensor_type.elem_type,
        [ch, ch, 3, 3]
    )
    bias_value = onnx.helper.make_tensor_value_info(
        'input_norm_bias',
        old_input_value.type.tensor_type.elem_type,
        [3]
    )
    # Connect the graph
    new_input_value = onnx.helper.make_tensor_value_info(
        'origin_input',
        old_input_value.type.tensor_type.elem_type,
        input_shape
    )
    g.input.extend([new_input_value])
    g.node.extend([new_weight, new_bias, new_conv])
    g.value_info.extend([weight_value, bias_value, old_input_value])
    # topological sort
    other.topological_sort(g)
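
A hypothetical usage sketch (assuming onnx, numpy as np, and this repository's helper and other modules are importable; file names are placeholders):

m = onnx.load('model.onnx')
add_0_5_to_normalized_input(m)
onnx.save(m, 'model_shifted.onnx')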
Example #2
def change_first_conv_from_bgr_to_rgb(m):
    """For input channel format BGR model, use this function to change the first
    conv weight to adapt the input into RGB.

    :param m: the model proto
    """
    # Check for first node.
    g = m.graph
    input_name = g.input[0].name
    first_nodes = helper.find_following_nodes_by_input_value_name(
        g, input_name)
    if len(first_nodes) > 1:
        return False
    first_node = first_nodes[0]
    # Now we have the first node. Check this first node.
    if first_node.op_type != 'Conv':
        return False
    weight_value = helper.find_value_by_name(g, first_node.input[1])
    weight_shape = helper.get_shape_from_value_info(weight_value)
    if weight_shape[1] != 3:
        return False
    # Do weight shuffle
    weight_node = helper.find_node_by_output_name(g, weight_value.name)
    weight_np = helper.constant_to_numpy(weight_node)
    b_channel = np.expand_dims(weight_np[:, 0, :, :], axis=1)
    g_channel = np.expand_dims(weight_np[:, 1, :, :], axis=1)
    r_channel = np.expand_dims(weight_np[:, 2, :, :], axis=1)
    new_np = np.concatenate((r_channel, g_channel, b_channel), axis=1)
    new_node = helper.numpy_to_constant(weight_value.name, new_np)
    # Replace the weight and topological sort
    g.node.remove(weight_node)
    g.node.extend([new_node])
    other.topological_sort(g)
    return True
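
A hypothetical usage sketch (placeholder file names); the boolean return value reports whether the in-place weight shuffle was applied:

m = onnx.load('bgr_model.onnx')
if change_first_conv_from_bgr_to_rgb(m):
    onnx.save(m, 'rgb_model.onnx')
else:
    print('First node is not a suitable Conv; model left unchanged.')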
Example #3
def add_rgb2yynn_node(m):
    """Add a conv layer which can convert rgb to yynn input.
    """
    g = m.graph
    if len(g.input) > 1:
        print("This model has multiple inputs. Cannot change to rgb input.")
        return
    input_shape = helper.get_shape_from_value_info(g.input[0])
    if len(input_shape) != 4:
        print("The input shape is not BCHW. Cannot normalize input.")
        return
    # Construct weight in (kH, kW, in_ch, out_ch) layout; transposed to OIHW below.
    weight_np = np.zeros((3, 3, 4, 4)).astype('float32')
    # Y = 0.299 R + 0.587 G + 0.114 B, written into the first two output channels.
    weight_np[1, 1, :3, :2] = np.array([[[[0.299],
                                          [0.587],
                                          [0.114]]]])
    # Pass the fourth input channel through to the last two output channels.
    weight_np[1, 1, 3, 2:] = 1.
    weight_np = np.transpose(weight_np, (3, 2, 0, 1))
    new_weight = helper.numpy_to_constant("input_rgb2yynn_weight", weight_np)
    # Construct conv node
    new_conv = onnx.helper.make_node(
        'Conv',
        ['new_input', "input_rgb2yynn_weight"],
        [g.input[0].name],
        name='input_rgba2yynn',
        dilations=[1, 1],
        kernel_shape=[3, 3],
        pads=[1, 1, 1, 1],
        strides=[1, 1]
    )
    # Construct value_infos
    old_input_value = g.input.pop()
    weight_value = onnx.helper.make_tensor_value_info(
        'input_rgb2yynn_weight',
        old_input_value.type.tensor_type.elem_type,
        [4, 4, 3, 3]
    )
    # Connect the graph
    new_input_value = onnx.helper.make_tensor_value_info(
        'new_input',
        old_input_value.type.tensor_type.elem_type,
        input_shape
    )
    g.input.extend([new_input_value])
    g.node.extend([new_weight, new_conv])
    g.value_info.extend([weight_value, old_input_value])
    # topological sort
    other.topological_sort(g)
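
A hypothetical usage sketch (placeholder file names; the original model is assumed to take a 4-channel YYNN input in BCHW layout):

m = onnx.load('yynn_model.onnx')
add_rgb2yynn_node(m)
onnx.save(m, 'rgbn_model.onnx')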
Example #4
def change_input_from_bgr_to_rgb(m):
    """For input channel format BGR model, use this function to modify the model
    to accepct RGB image.If the first node is a non-group Conv. Modify weight to
    adapt the input into RGB. Otherwise create a new node.

    :param m: the model proto
    """
    g = m.graph
    if len(g.input) > 1:
        print("This model has multiple inputs. Cannot change to RGB input.")
        return
    input_shape = helper.get_shape_from_value_info(g.input[0])
    if len(input_shape) != 4 or input_shape[1] != 3:
        print("The input shape is invalid for bgr conversion.")
        return
    # Try change conv weight first
    if change_first_conv_from_bgr_to_rgb(m):
        return
    # Otherwise, create a special conv node and replace the input
    # Construct weight
    weight_np = np.zeros((3, 3, 3, 3)).astype('float32')
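    # Center taps swap channels 0 and 2 (B <-> R) and pass channel 1 through.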
    weight_np[0, 2, 1, 1] = 1.0
    weight_np[1, 1, 1, 1] = 1.0
    weight_np[2, 0, 1, 1] = 1.0
    new_weight = helper.numpy_to_constant("bgr_shuffle_weight", weight_np)
    # Construct Conv
    new_conv = onnx.helper.make_node(
        'Conv',
        ['rgb_input', "bgr_shuffle_weight"],
        [g.input[0].name],
        name='bgr_shuffle',
        dilations=[1, 1],
        kernel_shape=[3, 3],
        pads=[1, 1, 1, 1],
        strides=[1, 1]
    )
    # Connect the graph
    old_input_value = g.input.pop()
    new_input_value = onnx.helper.make_tensor_value_info(
        'rgb_input',
        old_input_value.type.tensor_type.elem_type,
        input_shape
    )
    g.input.extend([new_input_value])
    g.node.extend([new_weight, new_conv])
    # topological sort
    other.topological_sort(g)
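
A hypothetical usage sketch (placeholder file names); the function first tries the in-place weight shuffle from Example #2 and only inserts the shuffle Conv when that is not possible:

m = onnx.load('bgr_model.onnx')
change_input_from_bgr_to_rgb(m)
onnx.save(m, 'rgb_model.onnx')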
Example #5
def split_ConvTranspose(model):
    """To feed our compiler, split ConvTranspose into Upsample and Conv.

    :param model: the model
    """
    node_to_delete = []
    # Change model properties for upsample.
    if model.ir_version < 3:
        print("Warning: Current model IR version is not fully supported.")
    model.ir_version = 4
    model.opset_import[0].version = 9
    g = model.graph
    # Process each ConvTranspose layer.
    for node in g.node:
        # Find a ConvTranspose node
        if node.op_type != 'ConvTranspose':
            continue
        # Check auto_pad
        auto_pad_proto = helper.get_attribute_by_name(node, "auto_pad")
        if auto_pad_proto is not None:
            print("Currently not split auto_pad ConvTranspose")
            continue
        # Check output_shape
        output_shape_proto = helper.get_attribute_by_name(node, "output_shape")
        if output_shape_proto is not None:
            print("Currently not split output_shape ConvTranspose")
            continue
        # Get input shape
        input_value = helper.find_value_by_name(g, node.input[0])
        if input_value is None:
            input_value = helper.find_input_by_name(g, node.input[0])
        if input_value is None:
            print("Cannot get value info named {}.".format(node.input[0]))
            exit(1)
        input_shape = helper.get_shape_from_value_info(input_value)
        # Get attributes
        attr = deconv_to_conv_info_extraction(input_shape, node)
        # Generate Upsample scales
        upsample_output_shape = list(input_shape)
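        # (size - 1) * stride + 1 is the zero-inserted ("dilated") input size
        # that a ConvTranspose implicitly convolves over.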
        upsample_output_shape[2] = (input_shape[2] - 1) * attr["strides"][0] + 1
        upsample_output_shape[3] = (input_shape[3] - 1) * attr["strides"][1] + 1
        upsample_node_name = node.name + "_inner_upsample"
        upsample_scale_name = upsample_node_name + "_scales"
        scales_np = np.ones([4]).astype('float32')
        scales_np[2] = float(upsample_output_shape[2]) / input_shape[2]
        scales_np[3] = float(upsample_output_shape[3]) / input_shape[3]
        scales_node = helper.numpy_to_constant(upsample_scale_name, scales_np)
        # Generate an Upsample layer and an internal value info
        upsample_node = onnx.helper.make_node(
            "Upsample",
            [node.input[0], upsample_scale_name],
            [upsample_node_name],
            name=upsample_node_name,
            mode="zeros"
        )
        upsample_value_info = onnx.helper.make_tensor_value_info(
            upsample_node_name,
            input_value.type.tensor_type.elem_type,
            upsample_output_shape
        )
        # Check the weight layer, it may need a transpose
        if attr["group"] != input_shape[1]:
            weight_node = helper.find_node_by_output_name(g, node.input[1])
            weight_np = helper.constant_to_numpy(weight_node)
            new_weight_np = np.transpose(weight_np, [1, 0, 2, 3])
            new_weight_node = helper.numpy_to_constant(node.input[1], new_weight_np)
            node_to_delete.append(weight_node)
            g.node.extend([new_weight_node])
            value = helper.find_value_by_name(g, node.input[1])
            g.value_info.remove(value)
        # Generate a Conv layer
        conv_node_name = node.name + "_inner_conv"
        conv_node_input = [upsample_node_name]
        conv_node_input.extend(node.input[1:])
        conv_node = onnx.helper.make_node(
            "Conv",
            conv_node_input,
            [node.output[0]],
            name=conv_node_name,
            pads=[int(i) for i in attr["conv_pads"]],
            dilations=[int(i) for i in attr["dilations"]],
            group=int(attr["group"]),
            kernel_shape=[int(i) for i in attr["kernel_shape"]],
            strides=[int(1), int(1)]
        )
        # Reconnect the graph
        g.node.extend([scales_node, upsample_node, conv_node])
        g.value_info.extend([upsample_value_info])
        node_to_delete.append(node)
    # Delete useless nodes
    for node in node_to_delete:
        g.node.remove(node)
    topological_sort(g)
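
A hypothetical usage sketch (placeholder file names). Note that the function also bumps ir_version to 4 and the default opset to 9 in place:

model = onnx.load('deconv_model.onnx')
split_ConvTranspose(model)
onnx.save(model, 'deconv_split_model.onnx')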
Example #6
def fuse_BN_into_Gemm(g):
    """Fuse the following BN into the previous Gemm.

    :param g: the graph
    """
    node_to_remove = []
    for node in g.node:
        # Check for BN and Gemm
        if node.op_type != 'BatchNormalization':
            continue
        gemm_node = helper.find_node_by_output_name(g, node.input[0])
        if gemm_node is None:
            continue
        if gemm_node.op_type != 'Gemm':
            continue
        if len(
                helper.find_following_nodes_by_input_value_name(
                    g, gemm_node.output[0])) > 1:
            continue
        bn_node = node
        # Get original weights
        gemm_b_node = helper.find_node_by_output_name(g, gemm_node.input[1])
        gemm_b = helper.constant_to_numpy(gemm_b_node)
        gemm_c_node = helper.find_node_by_output_name(g, gemm_node.input[2])
        gemm_c = helper.constant_to_numpy(gemm_c_node)
        bn_scale_node = helper.find_node_by_output_name(g, bn_node.input[1])
        bn_scale = helper.constant_to_numpy(bn_scale_node)
        bn_bias_node = helper.find_node_by_output_name(g, bn_node.input[2])
        bn_bias = helper.constant_to_numpy(bn_bias_node)
        bn_mean_node = helper.find_node_by_output_name(g, bn_node.input[3])
        bn_mean = helper.constant_to_numpy(bn_mean_node)
        bn_var_node = helper.find_node_by_output_name(g, bn_node.input[4])
        bn_var = helper.constant_to_numpy(bn_var_node)
        # Apply attributes
        # epsilon
        epsilon = helper.get_attribute_by_name(bn_node, 'epsilon')
        if epsilon is None:
            epsilon = 0.00001
        else:
            epsilon = epsilon.f
        bn_var = bn_var + epsilon
        # alpha
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is None:
            alpha = 1
        else:
            alpha = alpha.f
        gemm_b = gemm_b * alpha
        # beta
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is None:
            beta = 1
        else:
            beta = beta.f
        gemm_c = gemm_c * beta
        # transA
        transA = helper.get_attribute_by_name(gemm_node, 'transA')
        if transA is not None and transA.i == 1:
            raise RuntimeError("Do not support transA")
        # transB
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None and transB.i == 1:
            gemm_b = gemm_b.transpose()
        # Calculate new weights
        new_gemm_b = gemm_b * bn_scale / np.sqrt(bn_var)
        new_gemm_c = (gemm_c - bn_mean) * bn_scale / np.sqrt(bn_var) + bn_bias
        # Replace original weights
        new_gemm_b_node = helper.numpy_to_constant(gemm_b_node.name + '_fused',
                                                   new_gemm_b)
        new_gemm_c_node = helper.numpy_to_constant(gemm_c_node.name + '_fused',
                                                   new_gemm_c)
        g.node.extend([new_gemm_b_node, new_gemm_c_node])
        node_to_remove.extend([
            gemm_b_node, gemm_c_node, bn_node, bn_scale_node, bn_bias_node,
            bn_mean_node, bn_var_node
        ])
        # Modify attributes
        # alpha
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is not None:
            alpha.f = 1.0
        # beta
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is not None:
            beta.f = 1.0
        # transB
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None:
            transB.i = 0
        # Connect the new graph
        gemm_node.input[1] = new_gemm_b_node.output[0]
        gemm_node.input[2] = new_gemm_c_node.output[0]
        gemm_b_value = helper.find_value_by_name(g, gemm_b_node.output[0])
        gemm_c_value = helper.find_value_by_name(g, gemm_c_node.output[0])
        gemm_b_value.name = new_gemm_b_node.output[0]
        gemm_c_value.name = new_gemm_c_node.output[0]
        gemm_value = helper.find_value_by_name(g, gemm_node.output[0])
        g.value_info.remove(gemm_value)
        gemm_node.output[0] = bn_node.output[0]
        for i in range(1, 5):
            value = helper.find_value_by_name(g, bn_node.input[i])
            g.value_info.remove(value)
    # Remove useless nodes
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)
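
A minimal NumPy sanity check of the fusion math above (illustrative shapes, not from any real model): BatchNormalization computes (x - mean) * scale / sqrt(var + eps) + bias, so applying it after x.dot(B) + C folds into exactly the new_gemm_b / new_gemm_c computed in the function:

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4)).astype('float32')   # batch 2, 4 features
B = rng.standard_normal((4, 3)).astype('float32')   # Gemm weight
C = rng.standard_normal(3).astype('float32')        # Gemm bias
scale = rng.standard_normal(3).astype('float32')
bias = rng.standard_normal(3).astype('float32')
mean = rng.standard_normal(3).astype('float32')
var = rng.random(3).astype('float32') + 0.1
eps = 1e-5
# Reference: Gemm followed by BatchNormalization.
ref = ((x.dot(B) + C) - mean) * scale / np.sqrt(var + eps) + bias
# Fused weights, mirroring new_gemm_b / new_gemm_c.
fused_b = B * scale / np.sqrt(var + eps)
fused_c = (C - mean) * scale / np.sqrt(var + eps) + bias
assert np.allclose(x.dot(fused_b) + fused_c, ref, atol=1e-5)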
Example #7
def fuse_Gemm_into_Gemm(g):
    """Fuse the previous Gemm into the following Gemm.

    :param g: the graph
    """
    node_to_remove = []
    for node in g.node:
        # Check for Gemm and Gemm
        if node.op_type != 'Gemm':
            continue
        prev_node = helper.find_node_by_output_name(g, node.input[0])
        if prev_node is None:
            continue
        if prev_node.op_type != 'Gemm':
            continue
        # Get original weights
        prev_b_node = helper.find_node_by_output_name(g, prev_node.input[1])
        prev_b = helper.constant_to_numpy(prev_b_node)
        prev_c_node = helper.find_node_by_output_name(g, prev_node.input[2])
        prev_c = helper.constant_to_numpy(prev_c_node)
        b_node = helper.find_node_by_output_name(g, node.input[1])
        b = helper.constant_to_numpy(b_node)
        c_node = helper.find_node_by_output_name(g, node.input[2])
        c = helper.constant_to_numpy(c_node)
        # Apply attributes
        # alpha
        alpha = helper.get_attribute_by_name(node, 'alpha')
        if alpha is None:
            alpha = 1
        else:
            alpha = alpha.f
        b = b * alpha
        alpha = helper.get_attribute_by_name(prev_node, 'alpha')
        if alpha is None:
            alpha = 1
        else:
            alpha = alpha.f
        prev_b = prev_b * alpha
        # beta
        beta = helper.get_attribute_by_name(node, 'beta')
        if beta is None:
            beta = 1
        else:
            beta = beta.f
        c = c * beta
        beta = helper.get_attribute_by_name(prev_node, 'beta')
        if beta is None:
            beta = 1
        else:
            beta = beta.f
        prev_c = prev_c * beta
        # transA
        transA = helper.get_attribute_by_name(node, 'transA')
        if transA is not None and transA.i == 1:
            raise RuntimeError("Do not support transA")
        transA = helper.get_attribute_by_name(prev_node, 'transA')
        if transA is not None and transA.i == 1:
            raise RuntimeError("Do not support transA")
        # transB
        transB = helper.get_attribute_by_name(node, 'transB')
        if transB is not None and transB.i == 1:
            b = b.transpose()
        transB = helper.get_attribute_by_name(prev_node, 'transB')
        if transB is not None and transB.i == 1:
            prev_b = prev_b.transpose()
        # Calculate new weights
        new_b = prev_b.dot(b)
        new_c = prev_c.dot(b) + c
        # Replace original weights
        new_b_node = helper.numpy_to_constant(b_node.name + '_fused', new_b)
        new_c_node = helper.numpy_to_constant(c_node.name + '_fused', new_c)
        g.node.extend([new_b_node, new_c_node])
        node_to_remove.extend(
            [b_node, c_node, prev_b_node, prev_c_node, prev_node])
        # Modify attributes
        # alpha
        alpha = helper.get_attribute_by_name(node, 'alpha')
        if alpha is not None:
            alpha.f = 1.0
        # beta
        beta = helper.get_attribute_by_name(node, 'beta')
        if beta is not None:
            beta.f = 1.0
        # transB
        transB = helper.get_attribute_by_name(node, 'transB')
        if transB is not None:
            transB.i = 0
        # Connect the new graph
        node.input[0] = prev_node.input[0]
        prev_value = helper.find_value_by_name(g, prev_node.output[0])
        g.value_info.remove(prev_value)
        for i in range(1, 3):
            value = helper.find_value_by_name(g, prev_node.input[i])
            g.value_info.remove(value)
            value = helper.find_value_by_name(g, node.input[i])
            g.value_info.remove(value)
        node.input[1] = new_b_node.output[0]
        node.input[2] = new_c_node.output[0]
    # Remove useless nodes
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)
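
The same kind of NumPy check for this fusion (illustrative shapes): (x.dot(B1) + C1).dot(B2) + C2 equals x.dot(B1.dot(B2)) + (C1.dot(B2) + C2), which is exactly new_b and new_c above:

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4)).astype('float32')
B1 = rng.standard_normal((4, 5)).astype('float32')  # previous Gemm weight
C1 = rng.standard_normal(5).astype('float32')       # previous Gemm bias
B2 = rng.standard_normal((5, 3)).astype('float32')  # following Gemm weight
C2 = rng.standard_normal(3).astype('float32')       # following Gemm bias
ref = (x.dot(B1) + C1).dot(B2) + C2
assert np.allclose(x.dot(B1.dot(B2)) + C1.dot(B2) + C2, ref, atol=1e-4)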
Example #8
def fuse_BN_with_Reshape_into_Gemm(g):
    """Fuse the following BN into the previous Gemm, even with Reshape or \\
        Squeeze and Unsqueeze surrounding.

    :param g: the graph
    """
    node_to_remove = []
    for node in g.node:
        # Check for the pattern: Gemm -> A -> BN -> B
        # Find BatchNorm Node
        if node.op_type != 'BatchNormalization':
            continue
        bn_node = node
        # Find A Node
        a_node = helper.find_node_by_output_name(g, node.input[0])
        if a_node is None or len(a_node.input) == 0:
            continue
        # Find Gemm Node
        gemm_node = helper.find_node_by_output_name(g, a_node.input[0])
        if gemm_node is None or gemm_node.op_type != 'Gemm':
            continue
        # Find B Node
        b_node_list = helper.find_following_nodes_by_input_value_name(
            g, bn_node.output[0])
        if len(b_node_list) == 0:
            the_output = helper.find_output_by_name(g, bn_node.output[0])
            if the_output is None:
                continue
            b_node = None
        elif len(b_node_list) > 1:
            continue
        else:
            b_node = b_node_list[0]
        # Check for branches
        if len(
                helper.find_following_nodes_by_input_value_name(
                    g, gemm_node.output[0])) > 1:
            continue
        if len(
                helper.find_following_nodes_by_input_value_name(
                    g, a_node.output[0])) > 1:
            continue
        # Check type of A
        if a_node.op_type == 'Unsqueeze':
            axes = helper.get_attribute_by_name(a_node, 'axes')
            if axes.ints != [2]:
                continue
        elif a_node.op_type == 'Reshape':
            a = helper.constant_to_list(
                helper.find_node_by_output_name(g, a_node.input[1]))[1]
            if len(a) != 3 or a[2] != 1:
                continue
        else:
            continue
        # Check type of B
        if b_node is None:
            pass
        elif b_node.op_type == 'Flatten':
            pass
        elif b_node.op_type == 'Squeeze':
            axes = helper.get_attribute_by_name(b_node, 'axes')
            if axes.ints != [2]:
                continue
        elif b_node.op_type == 'Reshape':
            a = helper.constant_to_list(
                helper.find_node_by_output_name(g, b_node.input[1]))[1]
            if len(a) != 2:
                continue
        else:
            continue
        # Construct new Nodes
        # Get original weights
        gemm_b_node = helper.find_node_by_output_name(g, gemm_node.input[1])
        gemm_b = helper.constant_to_numpy(gemm_b_node)
        gemm_c_node = helper.find_node_by_output_name(g, gemm_node.input[2])
        gemm_c = helper.constant_to_numpy(gemm_c_node)
        bn_scale_node = helper.find_node_by_output_name(g, bn_node.input[1])
        bn_scale = helper.constant_to_numpy(bn_scale_node)
        bn_bias_node = helper.find_node_by_output_name(g, bn_node.input[2])
        bn_bias = helper.constant_to_numpy(bn_bias_node)
        bn_mean_node = helper.find_node_by_output_name(g, bn_node.input[3])
        bn_mean = helper.constant_to_numpy(bn_mean_node)
        bn_var_node = helper.find_node_by_output_name(g, bn_node.input[4])
        bn_var = helper.constant_to_numpy(bn_var_node)
        # Apply attributes
        # epsilon
        epsilon = helper.get_attribute_by_name(bn_node, 'epsilon')
        if epsilon is None:
            epsilon = 0.00001
        else:
            epsilon = epsilon.f
        bn_var = bn_var + epsilon
        # alpha
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is None:
            alpha = 1
        else:
            alpha = alpha.f
        gemm_b = gemm_b * alpha
        # beta
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is None:
            beta = 1
        else:
            beta = beta.f
        gemm_c = gemm_c * beta
        # transA
        transA = helper.get_attribute_by_name(gemm_node, 'transA')
        if transA is not None and transA.i == 1:
            raise RuntimeError("Do not support transA")
        # transB
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None and transB.i == 1:
            gemm_b = gemm_b.transpose()
        # Calculate new weights
        new_gemm_b = gemm_b * bn_scale / np.sqrt(bn_var)
        new_gemm_c = (gemm_c - bn_mean) * bn_scale / np.sqrt(bn_var) + bn_bias
        # Replace original weights
        new_gemm_b_node = helper.numpy_to_constant(gemm_b_node.name + '_fused',
                                                   new_gemm_b)
        new_gemm_c_node = helper.numpy_to_constant(gemm_c_node.name + '_fused',
                                                   new_gemm_c)
        g.node.extend([new_gemm_b_node, new_gemm_c_node])
        # Modify attributes
        # alpha
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is not None:
            alpha.f = 1.0
        # beta
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is not None:
            beta.f = 1.0
        # transB
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None:
            transB.i = 0
        # Remove useless nodes
        node_to_remove.extend([
            gemm_b_node, gemm_c_node, bn_node, bn_scale_node, bn_bias_node,
            bn_mean_node, bn_var_node, a_node
        ])
        if a_node.op_type == 'Reshape':
            node_to_remove.append(
                helper.find_node_by_output_name(g, a_node.input[1]))
        if b_node is not None:
            node_to_remove.append(b_node)
            if b_node.op_type == 'Reshape':
                node_to_remove.append(
                    helper.find_node_by_output_name(g, b_node.input[1]))
        # Delete useless value infos
        value = helper.find_value_by_name(g, a_node.output[0])
        g.value_info.remove(value)
        if a_node.op_type == 'Reshape':
            value = helper.find_value_by_name(g, a_node.input[1])
            g.value_info.remove(value)
        for i in range(1, 5):
            value = helper.find_value_by_name(g, bn_node.input[i])
            g.value_info.remove(value)
        value = helper.find_value_by_name(g, bn_node.output[0])
        if value is not None:
            g.value_info.remove(value)
        if b_node is not None:
            value = helper.find_value_by_name(g, gemm_node.output[0])
            g.value_info.remove(value)
            if b_node.op_type == 'Reshape':
                value = helper.find_value_by_name(g, b_node.input[1])
                g.value_info.remove(value)
        # Connect the new graph
        # Connect Gemm new weights
        gemm_node.input[1] = new_gemm_b_node.output[0]
        gemm_node.input[2] = new_gemm_c_node.output[0]
        gemm_b_value = helper.find_value_by_name(g, gemm_b_node.output[0])
        gemm_c_value = helper.find_value_by_name(g, gemm_c_node.output[0])
        gemm_b_value.name = new_gemm_b_node.output[0]
        gemm_b_value.type.tensor_type.shape.dim[
            0].dim_value = new_gemm_b.shape[0]
        gemm_b_value.type.tensor_type.shape.dim[
            1].dim_value = new_gemm_b.shape[1]
        gemm_c_value.name = new_gemm_c_node.output[0]
        if b_node is None:
            # If b node is None, set the Gemm output as the graph output
            output_value = helper.find_output_by_name(g, bn_node.output[0])
            g.output.remove(output_value)
            g.output.extend(
                [helper.find_value_by_name(g, gemm_node.output[0])])
        else:
            # Else, set node B output as gemm output
            gemm_node.output[0] = b_node.output[0]
    # Remove useless nodes
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)
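
A hypothetical usage sketch (placeholder file names); the handled pattern is Gemm -> Unsqueeze/Reshape -> BatchNormalization -> Squeeze/Flatten/Reshape, or a BN that directly produces a graph output:

m = onnx.load('model.onnx')
fuse_BN_with_Reshape_into_Gemm(m.graph)
onnx.save(m, 'model_fused.onnx')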
Example #9
def pattern_matmul_mul_add(g, matmul_node):
    # Check node match - Mul node
    next_nodes = helper.find_nodes_by_input_name(g, matmul_node.output[0])
    if len(next_nodes) != 1:
        return
    if next_nodes[0].op_type != 'Mul':
        return
    mul_node = next_nodes[0]
    # Check node match - Add node
    next_nodes = helper.find_nodes_by_input_name(g, mul_node.output[0])
    if len(next_nodes) != 1:
        return
    if next_nodes[0].op_type != 'Add':
        return
    add_node = next_nodes[0]
    # Check Mul weight
    mul_weight_node = helper.find_node_by_output_name(g, mul_node.input[1])
    if mul_weight_node.op_type != 'Constant':
        return
    weight_size, mul_weight = helper.constant_to_list(mul_weight_node)
    for i in mul_weight:
        if i != 1:
            return
    channel = weight_size[0]
    # Check Add weight
    add_weight_node = helper.find_node_by_output_name(g, add_node.input[1])
    if add_weight_node.op_type != 'Constant':
        return
    # Check MatMul weight to see if it needs to be broadcast
    matmul_weight_node = helper.find_node_by_output_name(g, matmul_node.input[1])
    matmul_weight = helper.constant_to_numpy(matmul_weight_node)
    if matmul_weight.shape[1] == 1:
        # Weight broadcast
        new_matmul_weight = np.tile(matmul_weight, channel)
        new_matmul_weight_node = helper.numpy_to_constant(matmul_weight_node.name, new_matmul_weight)
        g.node.remove(matmul_weight_node)
        g.node.extend([new_matmul_weight_node])
    value = helper.find_value_by_name(g, matmul_weight_node.output[0])
    if value is not None:
        g.value_info.remove(value)
    # Remove Mul node
    g.node.remove(mul_weight_node)
    value = helper.find_value_by_name(g, mul_weight_node.output[0])
    if value is not None:
        g.value_info.remove(value)
    g.node.remove(mul_node)
    value = helper.find_value_by_name(g, mul_node.output[0])
    if value is not None:
        g.value_info.remove(value)
    # Fuse Matmul and Add
    gemm_node = onnx.helper.make_node(
        'Gemm',
        [matmul_node.input[0], matmul_node.input[1], add_node.input[1]],
        [add_node.output[0]],
        name=matmul_node.name,
        alpha=1.0,
        beta=1.0,
        transA=0,
        transB=0
    )
    g.node.extend([gemm_node])
    # Clean up
    g.node.remove(matmul_node)
    g.node.remove(add_node)
    value = helper.find_value_by_name(g, matmul_node.output[0])
    if value is not None:
        g.value_info.remove(value)
    other.topological_sort(g)
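
A hypothetical driver sketch: the function takes the graph plus a single MatMul node, so iterate over a snapshot of the node list, since a successful match mutates g.node:

for matmul in [n for n in g.node if n.op_type == 'MatMul']:
    pattern_matmul_mul_add(g, matmul)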