Example #1
0
def add_bn_before_add(g):
    """Insert do-nothing BatchNormalization nodes in front of Add inputs.

    For every two-input Add whose inputs come from real (non-Constant)
    producer nodes, ensure each input is fed through a dedicated identity
    BatchNormalization (scale=1, bias=0, mean=0, var=1). A BN is inserted
    when the producer is not already a BN, or when the producer's output is
    shared by more than one consumer.

    :param g: the graph
    """
    for n in g.node:
        # Find merge node (Add)
        if n.op_type != 'Add':
            continue
        if len(n.input) != 2:
            continue
        # Get the producer node of each of the two inputs.
        input_node_a = helper.find_node_by_output_name(g, n.input[0])
        input_node_b = helper.find_node_by_output_name(g, n.input[1])
        # Skip constant input add
        if input_node_a is None or input_node_a.op_type == 'Constant':
            continue
        if input_node_b is None or input_node_b.op_type == 'Constant':
            continue

        def add_bn_after(prev_node):
            # Insert an identity BN between prev_node and the Add node `n`
            # (closes over `n` and `g` from the enclosing loop).
            # Get the channel number from value info
            value_name = prev_node.output[0]
            value = helper.find_value_by_name(g, value_name)
            shape = helper.get_shape_from_value_info(value)
            # NOTE(review): channel taken from dim 1 — assumes NCHW-style
            # layout; confirm against the models this pass is run on.
            channel = shape[1]
            # Construct 4 weights (identity BN: scale=1, bias=0, mean=0, var=1)
            node_name = value_name + "_nop_bn"
            ones = [1.0] * channel
            zeros = [0.0] * channel
            scale_node = helper.list_to_constant(node_name + "_scale",
                                                 [channel], ones)
            bias_node = helper.list_to_constant(node_name + "_bias", [channel],
                                                zeros)
            mean_node = helper.list_to_constant(node_name + "_mean", [channel],
                                                zeros)
            var_node = helper.list_to_constant(node_name + "_var", [channel],
                                               ones)
            # Construct BN node
            bn_node = onnx.helper.make_node("BatchNormalization", [
                value_name, scale_node.output[0], bias_node.output[0],
                mean_node.output[0], var_node.output[0]
            ], [node_name],
                                            name=node_name,
                                            epsilon=0.00000001)
            # Reconnect the graph: the Add now reads the BN output.
            replace_node_input(n, value_name, node_name)
            # Add node to the graph
            g.node.extend(
                [bn_node, scale_node, bias_node, mean_node, var_node])

        # Insert a BN for an input when its producer is not a BN, or when the
        # producer's output feeds more than one consumer.
        if not input_node_a.op_type == 'BatchNormalization' or len(
                helper.find_following_nodes_by_input_value_name(
                    g, input_node_a.output[0])) > 1:
            add_bn_after(input_node_a)
        if not input_node_b.op_type == 'BatchNormalization' or len(
                helper.find_following_nodes_by_input_value_name(
                    g, input_node_b.output[0])) > 1:
            add_bn_after(input_node_b)
    topological_sort(g)
Example #2
0
def duplicate_shared_Flatten(g):
    """To feed our compiler, bind Flatten with Gemm. If the output of one\\
    Flatten goes to two Gemm nodes, duplicate the Flatten.

    :param g: the graph
    """
    # Iterate over a snapshot: new Flatten nodes are appended to g.node
    # inside the loop and must not be re-visited.
    for node in list(g.node):
        # Find a Flatten node
        if node.op_type != 'Flatten':
            continue
        # Check Flatten outputs. Get following Gemm
        output_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
        if len(output_nodes) < 2:
            continue
        gemm_nodes = [n for n in output_nodes if n.op_type == 'Gemm']
        if len(gemm_nodes) < 2:
            continue
        # Preserve the original Flatten's axis attribute instead of
        # hardcoding 1 (1 is the ONNX default when the attribute is absent).
        axis_attr = helper.get_attribute_by_name(node, 'axis')
        axis = axis_attr.i if axis_attr is not None else 1
        # Process all the Gemm nodes except for the first one, which keeps
        # the original Flatten.
        for i in range(1, len(gemm_nodes)):
            # Duplicate
            new_flatten_name = node.name + "_copy" + str(i)
            new_flatten_node = onnx.helper.make_node(
                "Flatten",
                node.input,
                [new_flatten_name],
                name=new_flatten_name,
                axis=axis
            )
            # Connect new graph
            replace_node_input(gemm_nodes[i], node.output[0], new_flatten_name)
            g.node.extend([new_flatten_node])
    topological_sort(g)
Example #3
0
def add_nop_bn_after(g, value_names):
    """Add do-nothing BatchNormalization nodes after the given value info. It will\\
    take the given names as the inputs of the new node and replace the inputs\\
    of the following nodes.

    :param g: the graph\\
    :param value_names: a list of string which are the names of value_info.
    """
    for target_name in value_names:
        # Locate the value info among values, graph inputs and graph outputs.
        info = helper.find_value_by_name(g, target_name)
        if info is None:
            info = helper.find_input_by_name(g, target_name)
        if info is None:
            info = helper.find_output_by_name(g, target_name)
        if info is None:
            print("Cannot find an value_info named {}".format(target_name))
            continue
        # Channel count comes from the second dimension of the value shape.
        dims = helper.get_shape_from_value_info(info)
        channel_count = dims[1]
        # Build the four identity-BN weight constants
        # (scale=1, bias=0, mean=0, var=1).
        bn_name = target_name + "_nop_bn"
        one_list = [1.0] * channel_count
        zero_list = [0.0] * channel_count
        scale_node = helper.list_to_constant(bn_name + "_scale",
                                             [channel_count], one_list)
        bias_node = helper.list_to_constant(bn_name + "_bias",
                                            [channel_count], zero_list)
        mean_node = helper.list_to_constant(bn_name + "_mean",
                                            [channel_count], zero_list)
        var_node = helper.list_to_constant(bn_name + "_var",
                                           [channel_count], one_list)
        # Build the BatchNormalization node itself.
        bn_node = onnx.helper.make_node("BatchNormalization", [
            target_name, scale_node.output[0], bias_node.output[0],
            mean_node.output[0], var_node.output[0]
        ], [bn_name],
                                        name=bn_name)
        # Rewire consumers of the original value to read from the BN output.
        consumers = helper.find_following_nodes_by_input_value_name(
            g, target_name)
        if consumers:
            for consumer in consumers:
                replace_node_input(consumer, target_name, bn_name)
        else:
            # No consumers: the value is a graph output. Rebuild the output
            # list with the matching entry replaced by the BN result.
            replacement = onnx.helper.make_tensor_value_info(
                bn_name, info.type.tensor_type.elem_type, dims)
            previous_outputs = list(g.output)
            while len(g.output):
                g.output.pop()
            for prev in previous_outputs:
                if prev.name == target_name:
                    g.output.extend([replacement])
                else:
                    g.output.extend([prev])
        # Register the new nodes.
        g.node.extend([bn_node, scale_node, bias_node, mean_node, var_node])
    topological_sort(g)
Example #4
0
def change_first_conv_from_bgr_to_rgb(m):
    """For input channel format BGR model, use this function to change the first
    conv weight to adapt the input into RGB.

    Returns True when the weight was shuffled; False when the graph does not
    start with exactly one 3-input-channel Conv node.

    :param m: the model proto
    """
    # Check for first node.
    g = m.graph
    input_name = g.input[0].name
    first_nodes = helper.find_following_nodes_by_input_value_name(
        g, input_name)
    # Require exactly one consumer: the original `> 1` check let an empty
    # list through and crashed on first_nodes[0] below.
    if len(first_nodes) != 1:
        return False
    first_node = first_nodes[0]
    # Now we have the first node. Check this first node.
    if first_node.op_type != 'Conv':
        return False
    weight_value = helper.find_value_by_name(g, first_node.input[1])
    weight_shape = helper.get_shape_from_value_info(weight_value)
    # Conv weight layout is (out_channels, in_channels, kH, kW).
    if weight_shape[1] != 3:
        return False
    # Do weight shuffle: swap the B and R input-channel slices so the
    # BGR-trained weights accept RGB input.
    weight_node = helper.find_node_by_output_name(g, weight_value.name)
    weight_np = helper.constant_to_numpy(weight_node)
    b_channel = np.expand_dims(weight_np[:, 0, :, :], axis=1)
    g_channel = np.expand_dims(weight_np[:, 1, :, :], axis=1)
    r_channel = np.expand_dims(weight_np[:, 2, :, :], axis=1)
    new_np = np.concatenate((r_channel, g_channel, b_channel), axis=1)
    new_node = helper.numpy_to_constant(weight_value.name, new_np)
    # Replace the weight and topological sort
    g.node.remove(weight_node)
    g.node.extend([new_node])
    other.topological_sort(g)
    return True
Example #5
0
def rename_all_node_name(g):
    """
    rename all nodes:

        new_name = old_name + "_kn"

    :param g: the onnx graph
    """

    for current in g.node:
        renamed = current.name + "_kn"
        renamed_output = current.output[0] + "_kn"

        # Graph outputs must keep their names, so skip output-producing nodes.
        if helper.find_output_by_name(g, current.output[0]) is not None:
            continue

        # Point every consumer at the renamed output value.
        for consumer in helper.find_following_nodes_by_input_value_name(
                g, current.output[0]):
            replace_node_input(consumer, current.output[0], renamed_output)

        # Keep the value_info entry in sync with the new output name.
        info = helper.find_value_by_name(g, current.output[0])
        if info is not None:
            info.name = renamed_output

        # Finally rename the node and its first output.
        current.output[0] = renamed_output
        current.name = renamed
Example #6
0
def add_nop_conv_after(g, value_names):
    """Add do-nothing depthwise Conv nodes after the given value info. It will\\
    take the given names as the inputs of the new node and replace the inputs\\
    of the following nodes.

    :param g: the graph\\
    :param value_names: a list of string which are the names of value_info.
    """
    for target_name in value_names:
        # Locate the value info among values, graph inputs and graph outputs.
        info = helper.find_value_by_name(g, target_name)
        if info is None:
            info = helper.find_input_by_name(g, target_name)
        if info is None:
            info = helper.find_output_by_name(g, target_name)
        if info is None:
            print("Cannot find an value_info named {}".format(target_name))
            continue
        # Channel count comes from the second dimension of the value shape.
        dims = helper.get_shape_from_value_info(info)
        channel_count = dims[1]
        # Build an all-ones 1x1 depthwise weight (identity convolution).
        conv_name = target_name + "_nop_conv"
        one_list = [1.0] * channel_count
        weight_node = helper.list_to_constant(conv_name + "_weight",
                                              [channel_count, 1, 1, 1],
                                              one_list)
        # Build the depthwise Conv node itself (group == channel count).
        conv_node = onnx.helper.make_node("Conv",
                                          [target_name, weight_node.output[0]],
                                          [conv_name],
                                          name=conv_name,
                                          dilations=[1, 1],
                                          group=channel_count,
                                          kernel_shape=[1, 1],
                                          pads=[0, 0, 0, 0],
                                          strides=[1, 1])
        # Rewire consumers of the original value to read from the Conv output.
        consumers = helper.find_following_nodes_by_input_value_name(
            g, target_name)
        if consumers:
            for consumer in consumers:
                replace_node_input(consumer, target_name, conv_name)
        else:
            # No consumers: the value is a graph output. Rebuild the output
            # list with the matching entry replaced by the Conv result.
            replacement = onnx.helper.make_tensor_value_info(
                conv_name, info.type.tensor_type.elem_type, dims)
            previous_outputs = list(g.output)
            while len(g.output):
                g.output.pop()
            for prev in previous_outputs:
                if prev.name == target_name:
                    g.output.extend([replacement])
                else:
                    g.output.extend([prev])
        # Register the new nodes.
        g.node.extend([conv_node, weight_node])
    topological_sort(g)
Example #7
0
def replace_Sum_with_Adds(g):
    """Replace Sum nodes with equivalent Add nodes.

    A 1-input Sum is deleted and its consumers are reconnected to its input;
    a 2-input Sum becomes a single Add; an n-input Sum (n > 2) becomes a
    chain of n-1 Adds whose final node keeps the original name and output.

    :param g: the graph
    """
    node_to_del = []

    for node in g.node:
        # Check for sum
        if node.op_type != 'Sum':
            continue
        # Check for input number
        if len(node.input) == 1:
            # If input number is 1, delete the sum node.
            following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
            for following_node in following_nodes:
                modhelper.replace_node_input(following_node, node.output[0], node.input[0])
            node_to_del.append(node)
            # Fix: find_value_by_name takes the graph as the first argument
            # (the original call omitted `g`); also look the value up once.
            value = helper.find_value_by_name(g, node.output[0])
            if value is not None:
                g.value_info.remove(value)
        elif len(node.input) == 2:
            # If input number is 2, replace it with add.
            node.op_type = 'Add'
            continue
        elif len(node.input) > 2:
            # If input number is larger than 2, replace it with n-1 add.
            input_count = len(node.input)
            # First node has 2 inputs
            first_node = onnx.helper.make_node(
                "Add",
                [node.input[0], node.input[1]],
                [node.output[0] + '_replacement_1'],
                name=node.name + '_replacement_1'
            )
            # Last node has the same output as the original sum node
            last_node = onnx.helper.make_node(
                "Add",
                [node.output[0] + '_replacement_' + str(input_count - 2), node.input[input_count - 1]],
                [node.output[0]],
                name=node.name
            )
            g.node.extend([first_node, last_node])
            # Middle nodes chain replacement_(i-1) with input i.
            for i in range(2, input_count - 1):
                new_node = onnx.helper.make_node(
                    "Add",
                    [node.output[0] + '_replacement_' + str(i - 1), node.input[i]],
                    [node.output[0] + '_replacement_' + str(i)],
                    name=node.name + '_replacement_' + str(i)
                )
                g.node.extend([new_node])
            node_to_del.append(node)
        else:
            logging.error("Sum node must have at least 1 input.")
            quit(1)

    while node_to_del:
        g.node.remove(node_to_del.pop())

    topological_sort(g)
Example #8
0
def fuse_BN_into_Gemm(g):
    """Fuse the following BN into the previous Gemm.

    For each BatchNormalization fed directly by a Gemm (whose output is not
    shared), fold the BN's scale/bias/mean/var into the Gemm's B and C
    constants:

        new_B = alpha * B * scale / sqrt(var + eps)
        new_C = (beta * C - mean) * scale / sqrt(var + eps) + bias

    The Gemm's alpha/beta/transB attributes are reset to their defaults and
    the BN node plus its weight constants are removed.

    :param g: the graph
    """
    node_to_remove = []
    for node in g.node:
        # Check for BN and Gemm
        if node.op_type != 'BatchNormalization':
            continue
        gemm_node = helper.find_node_by_output_name(g, node.input[0])
        if gemm_node is None:
            continue
        if gemm_node.op_type != 'Gemm':
            continue
        # Skip when the Gemm output is shared by other consumers.
        if len(
                helper.find_following_nodes_by_input_value_name(
                    g, gemm_node.output[0])) > 1:
            continue
        bn_node = node
        # Get original weights (all assumed to be Constant nodes).
        gemm_b_node = helper.find_node_by_output_name(g, gemm_node.input[1])
        gemm_b = helper.constant_to_numpy(gemm_b_node)
        gemm_c_node = helper.find_node_by_output_name(g, gemm_node.input[2])
        gemm_c = helper.constant_to_numpy(gemm_c_node)
        bn_scale_node = helper.find_node_by_output_name(g, bn_node.input[1])
        bn_scale = helper.constant_to_numpy(bn_scale_node)
        bn_bias_node = helper.find_node_by_output_name(g, bn_node.input[2])
        bn_bias = helper.constant_to_numpy(bn_bias_node)
        bn_mean_node = helper.find_node_by_output_name(g, bn_node.input[3])
        bn_mean = helper.constant_to_numpy(bn_mean_node)
        bn_var_node = helper.find_node_by_output_name(g, bn_node.input[4])
        bn_var = helper.constant_to_numpy(bn_var_node)
        # Apply attributes
        # epsilon (ONNX default 1e-5 when absent)
        epsilon = helper.get_attribute_by_name(bn_node, 'epsilon')
        if epsilon is None:
            epsilon = 0.00001
        else:
            epsilon = epsilon.f
        bn_var = bn_var + epsilon
        # alpha: pre-multiply into B so the fused Gemm can use alpha=1.
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is None:
            alpha = 1
        else:
            alpha = alpha.f
        gemm_b = gemm_b * alpha
        # beta: pre-multiply into C so the fused Gemm can use beta=1.
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is None:
            beta = 1
        else:
            beta = beta.f
        gemm_c = gemm_c * beta
        # transA
        transA = helper.get_attribute_by_name(gemm_node, 'transA')
        if transA is not None and transA.i == 1:
            raise RuntimeError("Do not support transA")
        # transB: materialize the transpose so transB can be reset to 0.
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None and transB.i == 1:
            gemm_b = gemm_b.transpose()
        # Calculate new weights
        new_gemm_b = gemm_b * bn_scale / np.sqrt(bn_var)
        new_gemm_c = (gemm_c - bn_mean) * bn_scale / np.sqrt(bn_var) + bn_bias
        # Replace original weights
        new_gemm_b_node = helper.numpy_to_constant(gemm_b_node.name + '_fused',
                                                   new_gemm_b)
        new_gemm_c_node = helper.numpy_to_constant(gemm_c_node.name + '_fused',
                                                   new_gemm_c)
        g.node.extend([new_gemm_b_node, new_gemm_c_node])
        node_to_remove.extend([
            gemm_b_node, gemm_c_node, bn_node, bn_scale_node, bn_bias_node,
            bn_mean_node, bn_var_node
        ])
        # Modify attributes: reset to defaults now that they are folded in.
        # alpha
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is not None:
            alpha.f = 1.0
        # beta
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is not None:
            beta.f = 1.0
        # transB
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None:
            transB.i = 0
        # Connect the new graph
        gemm_node.input[1] = new_gemm_b_node.output[0]
        gemm_node.input[2] = new_gemm_c_node.output[0]
        # Rename the old weight value_infos to the fused names.
        gemm_b_value = helper.find_value_by_name(g, gemm_b_node.output[0])
        gemm_c_value = helper.find_value_by_name(g, gemm_c_node.output[0])
        gemm_b_value.name = new_gemm_b_node.output[0]
        gemm_c_value.name = new_gemm_c_node.output[0]
        # The Gemm takes over the BN's output name; its own intermediate
        # value_info and the BN weight value_infos are dropped.
        gemm_value = helper.find_value_by_name(g, gemm_node.output[0])
        g.value_info.remove(gemm_value)
        gemm_node.output[0] = bn_node.output[0]
        for i in range(1, 5):
            value = helper.find_value_by_name(g, bn_node.input[i])
            g.value_info.remove(value)
    # Remove useless nodes
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)
Example #9
0
def fuse_BN_with_Reshape_into_Gemm(g):
    """Fuse the following BN into the previous Gemm, even with Reshape or \\
        Squeeze and Unsqueeze surrounding.

    Pattern matched: Gemm -> A -> BN -> (optional) B, where A is an
    Unsqueeze(axes=[2]) or a Reshape to a rank-3 shape ending in 1, and B is
    a Flatten, Squeeze(axes=[2]), a Reshape to rank 2, or absent (BN output
    is a graph output). The BN parameters are folded into the Gemm's B and C
    constants and the surrounding shape nodes are removed.

    :param g: the graph
    """
    node_to_remove = []
    for node in g.node:
        # Check for BN and Gemm pattern: Gemm A BN B
        # Find BatchNorm Node
        if node.op_type != 'BatchNormalization':
            continue
        bn_node = node
        # Find A Node
        a_node = helper.find_node_by_output_name(g, node.input[0])
        if a_node is None or len(a_node.input) == 0:
            continue
        # Find Gemm Node
        gemm_node = helper.find_node_by_output_name(g, a_node.input[0])
        if gemm_node is None or gemm_node.op_type != 'Gemm':
            continue
        # Find B Node
        b_node_list = helper.find_following_nodes_by_input_value_name(
            g, bn_node.output[0])
        if len(b_node_list) == 0:
            # No consumers: only acceptable when BN feeds a graph output.
            the_output = helper.find_output_by_name(g, bn_node.output[0])
            if the_output is None:
                continue
            b_node = None
        elif len(b_node_list) > 1:
            continue
        else:
            b_node = b_node_list[0]
        # Check for branches
        if len(
                helper.find_following_nodes_by_input_value_name(
                    g, gemm_node.output[0])) > 1:
            continue
        if len(
                helper.find_following_nodes_by_input_value_name(
                    g, a_node.output[0])) > 1:
            continue
        # Check type of A
        if a_node.op_type == 'Unsqueeze':
            axes = helper.get_attribute_by_name(a_node, 'axes')
            if axes.ints != [2]:
                continue
        elif a_node.op_type == 'Reshape':
            a = helper.constant_to_list(
                helper.find_node_by_output_name(g, a_node.input[1]))[1]
            if len(a) != 3 or a[2] != 1:
                continue
        else:
            continue
        # Check type of B
        if b_node is None:
            pass
        elif b_node.op_type == 'Flatten':
            pass
        elif b_node.op_type == 'Squeeze':
            # Fix: inspect the Squeeze (b_node) axes — the original code
            # mistakenly re-read a_node's axes here.
            axes = helper.get_attribute_by_name(b_node, 'axes')
            if axes.ints != [2]:
                continue
        elif b_node.op_type == 'Reshape':
            a = helper.constant_to_list(
                helper.find_node_by_output_name(g, b_node.input[1]))[1]
            if len(a) != 2:
                continue
        else:
            continue
        # Construct new Nodes
        # Get original weights
        gemm_b_node = helper.find_node_by_output_name(g, gemm_node.input[1])
        gemm_b = helper.constant_to_numpy(gemm_b_node)
        gemm_c_node = helper.find_node_by_output_name(g, gemm_node.input[2])
        gemm_c = helper.constant_to_numpy(gemm_c_node)
        bn_scale_node = helper.find_node_by_output_name(g, bn_node.input[1])
        bn_scale = helper.constant_to_numpy(bn_scale_node)
        bn_bias_node = helper.find_node_by_output_name(g, bn_node.input[2])
        bn_bias = helper.constant_to_numpy(bn_bias_node)
        bn_mean_node = helper.find_node_by_output_name(g, bn_node.input[3])
        bn_mean = helper.constant_to_numpy(bn_mean_node)
        bn_var_node = helper.find_node_by_output_name(g, bn_node.input[4])
        bn_var = helper.constant_to_numpy(bn_var_node)
        # Apply attributes
        # epsilon (ONNX default 1e-5 when absent)
        epsilon = helper.get_attribute_by_name(bn_node, 'epsilon')
        if epsilon is None:
            epsilon = 0.00001
        else:
            epsilon = epsilon.f
        bn_var = bn_var + epsilon
        # alpha: pre-multiply into B so the fused Gemm can use alpha=1.
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is None:
            alpha = 1
        else:
            alpha = alpha.f
        gemm_b = gemm_b * alpha
        # beta: pre-multiply into C so the fused Gemm can use beta=1.
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is None:
            beta = 1
        else:
            beta = beta.f
        gemm_c = gemm_c * beta
        # transA
        transA = helper.get_attribute_by_name(gemm_node, 'transA')
        if transA is not None and transA.i == 1:
            raise RuntimeError("Do not support transA")
        # transB: materialize the transpose so transB can be reset to 0.
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None and transB.i == 1:
            gemm_b = gemm_b.transpose()
        # Calculate new weights
        new_gemm_b = gemm_b * bn_scale / np.sqrt(bn_var)
        new_gemm_c = (gemm_c - bn_mean) * bn_scale / np.sqrt(bn_var) + bn_bias
        # Replace original weights
        new_gemm_b_node = helper.numpy_to_constant(gemm_b_node.name + '_fused',
                                                   new_gemm_b)
        new_gemm_c_node = helper.numpy_to_constant(gemm_c_node.name + '_fused',
                                                   new_gemm_c)
        g.node.extend([new_gemm_b_node, new_gemm_c_node])
        # Modify attributes: reset to defaults now that they are folded in.
        # alpha
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is not None:
            alpha.f = 1.0
        # beta
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is not None:
            beta.f = 1.0
        # transB
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None:
            transB.i = 0
        # Remove useless nodes
        node_to_remove.extend([
            gemm_b_node, gemm_c_node, bn_node, bn_scale_node, bn_bias_node,
            bn_mean_node, bn_var_node, a_node
        ])
        if a_node.op_type == 'Reshape':
            node_to_remove.append(
                helper.find_node_by_output_name(g, a_node.input[1]))
        if b_node is not None:
            node_to_remove.append(b_node)
            if b_node.op_type == 'Reshape':
                node_to_remove.append(
                    helper.find_node_by_output_name(g, b_node.input[1]))
        # Delete useless value infos
        value = helper.find_value_by_name(g, a_node.output[0])
        g.value_info.remove(value)
        if a_node.op_type == 'Reshape':
            value = helper.find_value_by_name(g, a_node.input[1])
            g.value_info.remove(value)
        for i in range(1, 5):
            value = helper.find_value_by_name(g, bn_node.input[i])
            g.value_info.remove(value)
        value = helper.find_value_by_name(g, bn_node.output[0])
        if value is not None:
            g.value_info.remove(value)
        if b_node is not None:
            value = helper.find_value_by_name(g, gemm_node.output[0])
            g.value_info.remove(value)
            if b_node.op_type == 'Reshape':
                value = helper.find_value_by_name(g, b_node.input[1])
                g.value_info.remove(value)
        # Connect the new graph
        # Connect Gemm new weights
        gemm_node.input[1] = new_gemm_b_node.output[0]
        gemm_node.input[2] = new_gemm_c_node.output[0]
        gemm_b_value = helper.find_value_by_name(g, gemm_b_node.output[0])
        gemm_c_value = helper.find_value_by_name(g, gemm_c_node.output[0])
        gemm_b_value.name = new_gemm_b_node.output[0]
        gemm_b_value.type.tensor_type.shape.dim[
            0].dim_value = new_gemm_b.shape[0]
        gemm_b_value.type.tensor_type.shape.dim[
            1].dim_value = new_gemm_b.shape[1]
        gemm_c_value.name = new_gemm_c_node.output[0]
        if b_node is None:
            # If b node is None, set the Gemm output as the graph output
            output_value = helper.find_output_by_name(g, bn_node.output[0])
            g.output.remove(output_value)
            g.output.extend(
                [helper.find_value_by_name(g, gemm_node.output[0])])
        else:
            # Else, set node B output as gemm output
            gemm_node.output[0] = b_node.output[0]
    # Remove useless nodes
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)