Example #1
def inference_resize_shape(g):
    for node in g.node:
        if node.op_type != 'Resize':
            continue

        output_value = helper.find_value_by_name(g, node.output[0])
        output_value = helper.find_output_by_name(
            g, node.output[0]) if output_value is None else output_value
        if output_value is not None:
            continue

        # Currently, only the 4-input form is supported.
        if len(node.input) == 4:  # input: X, roi, scales, sizes
            shape_node = helper.find_node_by_output_name(g, node.input[3])
            if shape_node is None or shape_node.op_type != 'Constant':
                continue

            _, shape_value = helper.constant_to_list(shape_node)
            output_value = onnx.helper.make_tensor_value_info(
                node.output[0], onnx.TensorProto.FLOAT,
                [int(v) for v in shape_value])
            g.value_info.extend([output_value])

            return True
    return False
Example #2
def fuse_Transpose_into_Constant(g):
    """
    Fuse Transpose layers into the Constant layers before

    :param g: the onnx graph
    """
    node_to_remove = []
    for node in g.node:
        if node.op_type != 'Transpose':
            continue
        prev_node = helper.find_node_by_output_name(g, node.input[0])
        if prev_node is None or prev_node.op_type != 'Constant':
            continue
        
        pre_shape, data_list = helper.constant_to_list(prev_node)
        w = np.reshape(data_list, pre_shape)
        w = w.transpose(node.attribute[0].ints)
        new_shape = w.shape
        w = w.flatten()
        
        new_tensor = onnx.helper.make_tensor(
            name=prev_node.name+'_data',
            data_type=prev_node.attribute[0].t.data_type,
            dims=new_shape,
            vals=w.tolist()
        )
        new_node = onnx.helper.make_node(
            'Constant',
            [],
            [node.output[0]],
            name=node.output[0],
            value=new_tensor
        )
        
        value_between = helper.find_value_by_name(g, prev_node.output[0])
        if value_between is None:
            continue
        value_type = value_between.type.tensor_type.elem_type
        g.value_info.remove(value_between)

        g.node.extend([new_node])
        node_to_remove.append(node)
        node_to_remove.append(prev_node)
        
        if new_node.output[0] not in [i.name for i in g.value_info]:
            new_value = onnx.helper.make_tensor_value_info(
                name=new_node.output[0],
                elem_type=value_type,
                shape=new_shape
            )
            g.value_info.extend([new_value])
    
    for node in node_to_remove:
        g.node.remove(node)
    
    topological_sort(g)
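
A minimal usage sketch for the pass above. It assumes fuse_Transpose_into_Constant, the package's own helper module, and topological_sort are importable from the surrounding code; the graph and tensor names here are made up for illustration.

import numpy as np
import onnx

# Constant -> Transpose, which the pass should fold into a single Constant.
w = onnx.helper.make_tensor('w_tensor', onnx.TensorProto.FLOAT, [1, 3, 2],
                            np.arange(6, dtype=np.float32).tolist())
const_node = onnx.helper.make_node('Constant', [], ['w'], name='w', value=w)
trans_node = onnx.helper.make_node('Transpose', ['w'], ['w_t'], name='w_t',
                                   perm=[0, 2, 1])
graph = onnx.helper.make_graph(
    [const_node, trans_node], 'demo', inputs=[],
    outputs=[onnx.helper.make_tensor_value_info(
        'w_t', onnx.TensorProto.FLOAT, [1, 2, 3])])
# The pass expects a value_info entry for the intermediate tensor.
graph.value_info.extend([onnx.helper.make_tensor_value_info(
    'w', onnx.TensorProto.FLOAT, [1, 3, 2])])

fuse_Transpose_into_Constant(graph)
assert [n.op_type for n in graph.node] == ['Constant']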
Example #3
def inference_resize_shape(g):
    for node in g.node:
        if node.op_type != 'Resize':
            continue

        output_value = helper.find_value_by_name(g, node.output[0])
        output_value = helper.find_output_by_name(
            g, node.output[0]) if output_value is None else output_value
        if output_value is not None:
            continue

        if len(node.input) == 4:  # input: X, roi, scales, sizes
            shape_node = helper.find_node_by_output_name(g, node.input[3])
            if shape_node is None or shape_node.op_type != 'Constant':
                continue

            _, shape_value = helper.constant_to_list(shape_node)
            output_value = onnx.helper.make_tensor_value_info(
                node.output[0], onnx.TensorProto.FLOAT,
                [int(v) for v in shape_value])
            g.value_info.extend([output_value])
            return True
        else:
            # If the output shape is not given, infer it from the scales.
            # Get the input shape
            input_value = helper.find_value_by_name(g, node.input[0])
            if input_value is None:
                continue
            shape_value = helper.get_shape_from_value_info(input_value)
            scales_node = helper.find_node_by_output_name(g, node.input[2])
            if scales_node is None or scales_node.op_type != 'Constant':
                continue
            _, scales_value = helper.constant_to_list(scales_node)
            for i in range(len(shape_value)):
                shape_value[i] *= scales_value[i]
            output_value = onnx.helper.make_tensor_value_info(
                node.output[0], onnx.TensorProto.FLOAT,
                [int(v) for v in shape_value])
            g.value_info.extend([output_value])
            return True
    return False
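
Because the function returns True as soon as it fixes a single node, it is meant to be driven to a fixed point. A sketch of such a driver loop, assuming the pass above is importable:

def infer_all_resize_shapes(g):
    # Re-run the single-shot pass until it reports no more progress.
    while inference_resize_shape(g):
        pass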
Example #4
def replace_ConstantOfShape_with_constant(g):
    """Replace Shape with Constant.\\
    This is the first step of reshape constant folding.

    :param g: the input graph\\
    :return: if anything modified, return true.
    """
    node_to_remove = []
    for node in g.node:
        # Find a ConstantOfShape
        if node.op_type != 'ConstantOfShape':
            continue
        # Check the input
        input_value = helper.find_value_by_name(g, node.input[0])
        if input_value is None:
            input_value = helper.find_input_by_name(g, node.input[0])
        if input_value is None or len(
                input_value.type.tensor_type.shape.dim) == 0:
            continue

        # Replace it with a Constant node
        pre_node = helper.find_node_by_output_name(g, node.input[0])
        if pre_node is None or pre_node.op_type != 'Constant':
            continue
        _, target_shape = helper.constant_to_list(pre_node)

        # The 'value' attribute holds a one-element tensor, not an int.
        value_attr = helper.get_attribute_by_name(node, 'value')
        if value_attr is None:
            value = 0.0  # ONNX default fill value for ConstantOfShape
        else:
            value = onnx.numpy_helper.to_array(value_attr.t).item()

        node_name = node.output[0]
        new_node = helper.list_to_constant(node_name, [target_shape[0]],
                                           [value] * target_shape[0])

        g.node.extend([new_node])

        # remove old node
        node_to_remove.append(node)

        # delete value_info
        val_info_used = sum(input_value.name in n.input for n in g.node)
        if val_info_used == 1:
            g.value_info.remove(input_value)

    replaced = len(node_to_remove) > 0

    for node in node_to_remove:
        g.node.remove(node)

    topological_sort(g)

    return replaced
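
In numpy terms, the equivalence this pass relies on is simply materializing a fill: a ConstantOfShape fed by a constant shape tensor and carrying fill value v is a fixed tensor filled with v. A small sketch with made-up numbers:

import numpy as np

target_shape = [4]   # the folded shape constant feeding ConstantOfShape
value = 1.5          # the node's one-element 'value' tensor
folded = np.full(target_shape, value, dtype=np.float32)
assert folded.tolist() == [1.5, 1.5, 1.5, 1.5]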
Example #5
def inference_upsample_shape(g):
    """For onnx v1.4.1+, onnx cannot inference upsample output shape. Let's\\
    do it ourselves. This function only inference the next upsample without\\
    output shape each time.

    :param g: the graph\\
    :return: True if any Upsample shape is generated. Otherwise, False.
    """
    for node in g.node:
        if node.op_type != 'Upsample':
            continue
        output_value = helper.find_value_by_name(g, node.output[0])
        if output_value is None:
            output_value = helper.find_output_by_name(g, node.output[0])
        if output_value and helper.get_shape_from_value_info(output_value):
            continue
        # Get input shape
        input_value = helper.find_value_by_name(g, node.input[0])
        if input_value is None:
            continue
            #raise RuntimeError("Shape for {} has not been generated.".format(node.input[0]))
        if not helper.get_shape_from_value_info(input_value):
            continue
            #raise RuntimeError("Shape for {} is empty.".format(node.input[0]))
        input_shape = helper.get_shape_from_value_info(input_value)
        # Get upsample weight
        weight_node = helper.find_node_by_output_name(g, node.input[1])
        if weight_node is None or weight_node.op_type != 'Constant':
            continue
        weight_shape, weight = helper.constant_to_list(weight_node)
        if len(input_shape) != weight_shape[0]:
            raise RuntimeError(
                "Mismatched input shape and weight shape: {} vs {}".format(
                    input_shape, weight_shape))
        # Calculate shape
        output_shape = list(input_shape)
        for i in range(len(output_shape)):
            output_shape[i] = int(input_shape[i] * weight[i])
        output_value = onnx.helper.make_tensor_value_info(
            node.output[0], input_value.type.tensor_type.elem_type,
            output_shape)
        g.value_info.extend([output_value])
        return True
    return False
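
The shape arithmetic is a per-dimension multiplication of the input shape by the scales constant; for example, a 2x spatial upsample:

input_shape = [1, 3, 16, 16]
scales = [1.0, 1.0, 2.0, 2.0]   # the Upsample scales constant
output_shape = [int(d * s) for d, s in zip(input_shape, scales)]
assert output_shape == [1, 3, 32, 32]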
Example #6
def replace_mul_to_bn(g):
    """Replace single Mul node with Batchnorm node.
    :param g: input graph.
    :return:
    """
    node_to_del = []
    for node in g.node:
        if node.op_type != 'Mul':
            continue

        mul_op_node = node

        # Only handle the two-input form: one op output and one constant.
        if len(mul_op_node.input) != 2:
            continue

        input_op_node_name = mul_op_node.input[0]
        mul_value_node = helper.find_node_by_output_name(
            g, mul_op_node.input[1])
        if not mul_value_node or mul_value_node.op_type != 'Constant':
            continue

        _, previous_node_output_shape = helper.find_size_shape_from_value(
            helper.find_value_by_name(g, input_op_node_name))
        scale_shape, scale_data = helper.constant_to_list(mul_value_node)

        # only allow 4 dim data input due to the hardware limitation
        if len(previous_node_output_shape) != 4:
            continue

        # channel dimension
        c_dim = previous_node_output_shape[1]

        # only allow channelwise mul or const mul
        if scale_shape != [1, c_dim, 1, 1] and scale_shape != 1:
            continue

        ones = [1.0] * c_dim
        zeros = [0.0] * c_dim
        # If the scale is a scalar, expand it to the channel size.
        if len(scale_data) == 1:
            muls = scale_data * c_dim
        else:
            muls = scale_data
        bn_name = mul_op_node.output[0]
        mean_value_node = helper.list_to_constant(bn_name + '_mean',
                                                  np.array(zeros).shape, zeros)
        variance_value_node = helper.list_to_constant(bn_name + '_var',
                                                      np.array(ones).shape,
                                                      ones)
        bias_value_node = helper.list_to_constant(bn_name + '_add',
                                                  np.array(zeros).shape, zeros)
        new_mul_value_node = helper.list_to_constant(bn_name + '_mul',
                                                     np.array(muls).shape,
                                                     muls)

        bn_node = onnx.helper.make_node('BatchNormalization', [
            input_op_node_name, new_mul_value_node.output[0],
            bias_value_node.output[0], mean_value_node.output[0],
            variance_value_node.output[0]
        ], [mul_op_node.output[0]],
                                        name=bn_name,
                                        epsilon=0.00000001)

        mid_val_info = helper.find_value_by_name(g, mul_op_node.output[0])
        scale_val_info = helper.find_value_by_name(g, mul_value_node.output[0])
        if mid_val_info is not None:
            g.value_info.remove(mid_val_info)
        if scale_val_info is not None:
            g.value_info.remove(scale_val_info)

        g.node.extend([bn_node])
        g.node.extend([mean_value_node])
        g.node.extend([variance_value_node])
        g.node.extend([bias_value_node])
        g.node.extend([new_mul_value_node])

        node_to_del.extend([mul_op_node])
        node_to_del.extend([mul_value_node])

    while node_to_del:
        g.node.remove(node_to_del.pop())

    topological_sort(g)
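
A numpy sanity check (a sketch) of the identity the pass relies on: BatchNormalization with zero mean, unit variance, and zero bias reduces to a channelwise Mul.

import numpy as np

x = np.random.randn(1, 3, 4, 4).astype(np.float32)
scale = np.array([0.5, 2.0, -1.0], dtype=np.float32).reshape(1, 3, 1, 1)
eps = 1e-8
bn_out = scale * (x - 0.0) / np.sqrt(1.0 + eps) + 0.0
mul_out = x * scale
assert np.allclose(bn_out, mul_out, atol=1e-6)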
Example #7
def fuse_mul_and_add_into_gemm(g):
    """Fuse an elementwise Mul followed by an Add into a single Gemm node.

    :param g: input graph.
    """
    node_to_del = []
    for node in g.node:
        if node.op_type != 'Add':
            continue
        add_node = node
        mul_node = helper.find_node_by_output_name(g, add_node.input[0])
        if not mul_node or mul_node.op_type != 'Mul':
            continue
        mul_const = helper.find_node_by_output_name(g, mul_node.input[1])
        if not mul_const or mul_const.op_type != 'Constant':
            continue
        add_const = helper.find_node_by_output_name(g, add_node.input[1])
        if not add_const or add_const.op_type != 'Constant':
            continue

        input_val = helper.find_value_by_name(g, mul_node.input[0])
        if not input_val:
            input_val = helper.find_input_by_name(g, mul_node.input[0])
        if not input_val:
            continue

        _, input_shape = helper.find_size_shape_from_value(input_val)
        if not input_shape:
            continue

        dim = int(np.prod(input_shape))
        if input_shape != [1, dim]:
            continue

        mul_const_shape, mul_const_data = helper.constant_to_list(mul_const)
        add_const_shape, __ = helper.constant_to_list(add_const)

        if len(mul_const_shape) != 1 or mul_const_shape[0] != dim:
            continue
        if len(add_const_shape) != 1 or add_const_shape[0] != dim:
            continue

        # Build a [dim, dim] diagonal matrix from the Mul weights.
        b_data = np.diag(mul_const_data).flatten().tolist()
        b_tensor = onnx.helper.make_tensor(
            name=mul_const.name + '_tensor',
            data_type=mul_const.attribute[0].t.data_type,
            dims=[dim, dim],
            vals=b_data)
        b_const_node = onnx.helper.make_node('Constant', [],
                                             [mul_const.output[0]],
                                             value=b_tensor,
                                             name=mul_const.output[0])

        add_const.attribute[0].t.dims.insert(0, 1)

        gemm_node = onnx.helper.make_node(
            'Gemm',
            [mul_node.input[0], b_const_node.output[0], add_const.output[0]],
            [add_node.output[0]],
            name=add_node.output[0])

        g.node.extend([gemm_node, b_const_node])
        node_to_del.extend([mul_const, mul_node, add_node])

        val_info_mid = helper.find_value_by_name(g, mul_node.output[0])
        val_info_mul_const = helper.find_value_by_name(g, mul_const.output[0])
        val_info_add_const = helper.find_value_by_name(g, add_const.output[0])
        if val_info_mid:
            g.value_info.remove(val_info_mid)
        if val_info_mul_const:
            g.value_info.remove(val_info_mul_const)
        if val_info_add_const:
            g.value_info.remove(val_info_add_const)

    while node_to_del:
        g.node.remove(node_to_del.pop())

    topological_sort(g)
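
The fusion rests on the identity x * m + a == x @ diag(m) + a for a [1, dim] input, which is exactly what the diagonal B matrix above encodes. A quick numpy check:

import numpy as np

dim = 4
x = np.random.randn(1, dim).astype(np.float32)
m = np.random.randn(dim).astype(np.float32)   # the Mul constant
a = np.random.randn(dim).astype(np.float32)   # the Add constant
# Gemm(x, diag(m), a) reproduces Mul followed by Add.
assert np.allclose(x * m + a, x @ np.diag(m) + a.reshape(1, dim))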
Example #8
def fuse_mul_and_add_into_bn(g):
    """Fuse an elementwise Mul followed by an Add into a BatchNormalization node.

    :param g: input graph.
    """
    node_to_del = []
    for node in g.node:
        if node.op_type != 'Add':
            continue
        add_node = node
        input_nodes_add = [
            helper.find_node_by_output_name(g, input_name)
            for input_name in add_node.input
        ]
        if any(n is None for n in input_nodes_add):
            continue
        mul_node, const_add = None, None
        for input_node_add in input_nodes_add:
            if input_node_add.op_type == 'Mul':
                mul_node = input_node_add
            elif input_node_add.op_type == 'Constant':
                const_add = input_node_add
            else:
                pass
        if not mul_node or not const_add:
            continue
        data_input_name, const_mul = None, None
        for input_name in mul_node.input:
            input_node = helper.find_node_by_output_name(g, input_name)
            if not input_node:
                data_input_name = input_name
            elif input_node.op_type == 'Constant':
                if not const_mul:
                    const_mul = input_node
                else:
                    data_input_name = input_name
            else:
                data_input_name = input_name

        if not const_mul:
            continue

        scale_shape, scale_data = helper.constant_to_list(const_mul)
        bias_shape, __ = helper.constant_to_list(const_add)
        c_dim = len(scale_data)
        if scale_shape != bias_shape:
            continue

        _, previous_node_output_shape = helper.find_size_shape_from_value(
            helper.find_value_by_name(g, data_input_name))
        # only allow 4 dim data input due to the hardware limitation
        if len(previous_node_output_shape) != 4:
            continue

        # Check that the scale size matches the input channel dimension.
        if previous_node_output_shape[1] != c_dim:
            continue

        if scale_shape == [1, c_dim, 1, 1]:

            # remove all '1'
            for _ in range(3):
                const_add.attribute[0].t.dims.remove(1)
                const_mul.attribute[0].t.dims.remove(1)

        elif scale_shape == [1, c_dim]:

            # remove all '1'
            const_add.attribute[0].t.dims.remove(1)
            const_mul.attribute[0].t.dims.remove(1)

        else:
            continue

        bn_name = add_node.output[0]
        const_mean = helper.list_to_constant(bn_name + '_mean', [c_dim],
                                             [0.0 for _ in range(c_dim)])
        const_var = helper.list_to_constant(bn_name + '_var', [c_dim],
                                            [1.0 for _ in range(c_dim)])

        bn_node = onnx.helper.make_node(
            'BatchNormalization',
            [data_input_name, const_mul.output[0], const_add.output[0],\
                const_mean.output[0], const_var.output[0]],
            [add_node.output[0]],
            name=bn_name,
            epsilon=0.00000001
        )

        mid_val_info = helper.find_value_by_name(g, mul_node.output[0])
        scale_val_info = helper.find_value_by_name(g, const_mul.output[0])
        bias_val_info = helper.find_value_by_name(g, const_add.output[0])
        if mid_val_info is not None:
            g.value_info.remove(mid_val_info)
        if scale_val_info is not None:
            g.value_info.remove(scale_val_info)
        if bias_val_info is not None:
            g.value_info.remove(bias_val_info)

        new_scale_val_info = onnx.helper.make_tensor_value_info(
            const_mul.output[0], const_mul.attribute[0].t.data_type, [c_dim])
        new_bias_val_info = onnx.helper.make_tensor_value_info(
            const_add.output[0], const_add.attribute[0].t.data_type, [c_dim])
        mean_val_info = onnx.helper.make_tensor_value_info(
            const_mean.output[0], const_mean.attribute[0].t.data_type, [c_dim])
        var_val_info = onnx.helper.make_tensor_value_info(
            const_var.output[0], const_var.attribute[0].t.data_type, [c_dim])

        g.value_info.extend([new_scale_val_info])
        g.value_info.extend([new_bias_val_info])
        g.value_info.extend([mean_val_info])
        g.value_info.extend([var_val_info])
        g.node.extend([bn_node])
        g.node.extend([const_mean])
        g.node.extend([const_var])
        node_to_del.extend([mul_node, add_node])

    while node_to_del:
        g.node.remove(node_to_del.pop())

    topological_sort(g)
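
The same BatchNormalization identity as before, now with both a scale and a bias: with zero mean and unit variance, BN reduces to a channelwise scale plus bias, which is what Mul followed by Add computes. A numpy sketch:

import numpy as np

x = np.random.randn(1, 3, 4, 4).astype(np.float32)
scale = np.array([0.5, 2.0, -1.0], dtype=np.float32).reshape(1, 3, 1, 1)
bias = np.array([0.1, -0.2, 0.3], dtype=np.float32).reshape(1, 3, 1, 1)
eps = 1e-8
bn_out = scale * (x - 0.0) / np.sqrt(1.0 + eps) + bias
assert np.allclose(bn_out, x * scale + bias, atol=1e-6)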
Example #9
def fuse_BN_with_Reshape_into_Gemm(g):
    """Fuse the following BN into the previous Gemm, even with Reshape or \\
        Squeeze and Unsqueeze surrounding.

    :param g: the graph
    """
    node_to_remove = []
    for node in g.node:
        # Check for the pattern: Gemm -> A -> BatchNormalization -> B
        # Find the BatchNormalization node
        if node.op_type != 'BatchNormalization':
            continue
        bn_node = node
        # Find A Node
        a_node = helper.find_node_by_output_name(g, node.input[0])
        if a_node is None or len(a_node.input) == 0:
            continue
        # Find Gemm Node
        gemm_node = helper.find_node_by_output_name(g, a_node.input[0])
        if gemm_node is None or gemm_node.op_type != 'Gemm':
            continue
        # Find B Node
        b_node_list = helper.find_following_nodes_by_input_value_name(
            g, bn_node.output[0])
        if len(b_node_list) == 0:
            the_output = helper.find_output_by_name(g, bn_node.output[0])
            if the_output is None:
                continue
            b_node = None
        elif len(b_node_list) > 1:
            continue
        else:
            b_node = b_node_list[0]
        # Check for branches
        if len(
                helper.find_following_nodes_by_input_value_name(
                    g, gemm_node.output[0])) > 1:
            continue
        if len(
                helper.find_following_nodes_by_input_value_name(
                    g, a_node.output[0])) > 1:
            continue
        # Check type of A
        if a_node.op_type == 'Unsqueeze':
            axes = helper.get_attribute_by_name(a_node, 'axes')
            if axes.ints != [2]:
                continue
        elif a_node.op_type == 'Reshape':
            a = helper.constant_to_list(
                helper.find_node_by_output_name(g, a_node.input[1]))[1]
            if len(a) != 3 or a[2] != 1:
                continue
        else:
            continue
        # Check type of B
        if b_node is None:
            pass
        elif b_node.op_type == 'Flatten':
            pass
        elif b_node.op_type == 'Squeeze':
            axes = helper.get_attribute_by_name(b_node, 'axes')
            if axes.ints != [2]:
                continue
        elif b_node.op_type == 'Reshape':
            a = helper.constant_to_list(
                helper.find_node_by_output_name(g, b_node.input[1]))[1]
            if len(a) != 2:
                continue
        else:
            continue
        # Construct new Nodes
        # Get original weights
        gemm_b_node = helper.find_node_by_output_name(g, gemm_node.input[1])
        gemm_b = helper.constant_to_numpy(gemm_b_node)
        gemm_c_node = helper.find_node_by_output_name(g, gemm_node.input[2])
        gemm_c = helper.constant_to_numpy(gemm_c_node)
        bn_scale_node = helper.find_node_by_output_name(g, bn_node.input[1])
        bn_scale = helper.constant_to_numpy(bn_scale_node)
        bn_bias_node = helper.find_node_by_output_name(g, bn_node.input[2])
        bn_bias = helper.constant_to_numpy(bn_bias_node)
        bn_mean_node = helper.find_node_by_output_name(g, bn_node.input[3])
        bn_mean = helper.constant_to_numpy(bn_mean_node)
        bn_var_node = helper.find_node_by_output_name(g, bn_node.input[4])
        bn_var = helper.constant_to_numpy(bn_var_node)
        # Apply attributes
        # epsilon
        epsilon = helper.get_attribute_by_name(bn_node, 'epsilon')
        if epsilon is None:
            epsilon = 0.00001
        else:
            epsilon = epsilon.f
        bn_var = bn_var + epsilon
        # alpha
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is None:
            alpha = 1
        else:
            alpha = alpha.f
        gemm_b = gemm_b * alpha
        # beta
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is None:
            beta = 1
        else:
            beta = beta.f
        gemm_c = gemm_c * beta
        # transA
        transA = helper.get_attribute_by_name(gemm_node, 'transA')
        if transA is not None and transA.i == 1:
            raise RuntimeError("Do not support transA")
        # transB
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None and transB.i == 1:
            gemm_b = gemm_b.transpose()
        # Calculate new weights
        new_gemm_b = gemm_b * bn_scale / np.sqrt(bn_var)
        new_gemm_c = (gemm_c - bn_mean) * bn_scale / np.sqrt(bn_var) + bn_bias
        # Replace original weights
        new_gemm_b_node = helper.numpy_to_constant(gemm_b_node.name + '_fused',
                                                   new_gemm_b)
        new_gemm_c_node = helper.numpy_to_constant(gemm_c_node.name + '_fused',
                                                   new_gemm_c)
        g.node.extend([new_gemm_b_node, new_gemm_c_node])
        # Modify attributes
        # alpha
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is not None:
            alpha.f = 1.0
        # beta
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is not None:
            beta.f = 1.0
        # transB
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None:
            transB.i = 0
        # Remove useless nodes
        node_to_remove.extend([
            gemm_b_node, gemm_c_node, bn_node, bn_scale_node, bn_bias_node,
            bn_mean_node, bn_var_node, a_node
        ])
        if a_node.op_type == 'Reshape':
            node_to_remove.append(
                helper.find_node_by_output_name(g, a_node.input[1]))
        if b_node is not None:
            node_to_remove.append(b_node)
            if b_node.op_type == 'Reshape':
                node_to_remove.append(
                    helper.find_node_by_output_name(g, b_node.input[1]))
        # Delete useless value infos
        value = helper.find_value_by_name(g, a_node.output[0])
        g.value_info.remove(value)
        if a_node.op_type == 'Reshape':
            value = helper.find_value_by_name(g, a_node.input[1])
            g.value_info.remove(value)
        for i in range(1, 5):
            value = helper.find_value_by_name(g, bn_node.input[i])
            g.value_info.remove(value)
        value = helper.find_value_by_name(g, bn_node.output[0])
        if value is not None:
            g.value_info.remove(value)
        if b_node is not None:
            value = helper.find_value_by_name(g, gemm_node.output[0])
            g.value_info.remove(value)
            if b_node.op_type == 'Reshape':
                value = helper.find_value_by_name(g, b_node.input[1])
                g.value_info.remove(value)
        # Connect the new graph
        # Connect Gemm new weights
        gemm_node.input[1] = new_gemm_b_node.output[0]
        gemm_node.input[2] = new_gemm_c_node.output[0]
        gemm_b_value = helper.find_value_by_name(g, gemm_b_node.output[0])
        gemm_c_value = helper.find_value_by_name(g, gemm_c_node.output[0])
        gemm_b_value.name = new_gemm_b_node.output[0]
        gemm_b_value.type.tensor_type.shape.dim[
            0].dim_value = new_gemm_b.shape[0]
        gemm_b_value.type.tensor_type.shape.dim[
            1].dim_value = new_gemm_b.shape[1]
        gemm_c_value.name = new_gemm_c_node.output[0]
        if b_node is None:
            # If b node is None, set the Gemm output as the graph output
            output_value = helper.find_output_by_name(g, bn_node.output[0])
            g.output.remove(output_value)
            g.output.extend(
                [helper.find_value_by_name(g, gemm_node.output[0])])
        else:
            # Else, set node B output as gemm output
            gemm_node.output[0] = b_node.output[0]
    # Remove useless nodes
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)
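
A numpy check of the weight-folding algebra used for new_gemm_b and new_gemm_c above: Gemm followed by BatchNormalization equals a single Gemm with the folded weights.

import numpy as np

x = np.random.randn(2, 8)
gemm_b, gemm_c = np.random.randn(8, 5), np.random.randn(5)
bn_scale, bn_bias = np.random.randn(5), np.random.randn(5)
bn_mean, bn_var = np.random.randn(5), np.random.rand(5)
epsilon = 1e-5

# Gemm followed by BN ...
y_ref = bn_scale * ((x @ gemm_b + gemm_c) - bn_mean) / np.sqrt(bn_var + epsilon) + bn_bias
# ... equals one Gemm with the folded weights.
new_b = gemm_b * bn_scale / np.sqrt(bn_var + epsilon)
new_c = (gemm_c - bn_mean) * bn_scale / np.sqrt(bn_var + epsilon) + bn_bias
assert np.allclose(y_ref, x @ new_b + new_c)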
Example #10
def pattern_matmul_mul_add(g, matmul_node):
    """Fuse a MatMul -> Mul (identity) -> Add pattern into a single Gemm node.

    :param g: input graph.
    :param matmul_node: the MatMul node to start matching from.
    """
    # Check node match - Mul node
    next_nodes = helper.find_nodes_by_input_name(g, matmul_node.output[0])
    if len(next_nodes) != 1:
        return
    if next_nodes[0].op_type != 'Mul':
        return
    mul_node = next_nodes[0]
    # Check node match - Add node
    next_nodes = helper.find_nodes_by_input_name(g, mul_node.output[0])
    if len(next_nodes) != 1:
        return
    if next_nodes[0].op_type != 'Add':
        return
    add_node = next_nodes[0]
    # Check Mul weight
    mul_weight_node = helper.find_node_by_output_name(g, mul_node.input[1])
    if mul_weight_node is None or mul_weight_node.op_type != 'Constant':
        return
    weight_size, mul_weight = helper.constant_to_list(mul_weight_node)
    # The Mul must be an identity (all ones) for this fusion to be valid.
    if any(v != 1 for v in mul_weight):
        return
    channel = weight_size[0]
    # Check Add weight
    add_weight_node = helper.find_node_by_output_name(g, add_node.input[1])
    if add_weight_node is None or add_weight_node.op_type != 'Constant':
        return
    # Check the MatMul weight to see if it needs broadcasting
    matmul_weight_node = helper.find_node_by_output_name(g, matmul_node.input[1])
    matmul_weight = helper.constant_to_numpy(matmul_weight_node)
    if matmul_weight.shape[1] == 1:
        # Weight broadcast
        new_matmul_weight = np.tile(matmul_weight, channel)
        new_matmul_weight_node = helper.numpy_to_constant(matmul_weight_node.name, new_matmul_weight)
        g.node.remove(matmul_weight_node)
        g.node.extend([new_matmul_weight_node])
    value = helper.find_value_by_name(g, matmul_weight_node.output[0])
    if value is not None:
        g.value_info.remove(value)
    # Remove Mul node
    g.node.remove(mul_weight_node)
    value = helper.find_value_by_name(g, mul_weight_node.output[0])
    if value is not None:
        g.value_info.remove(value)
    g.node.remove(mul_node)
    value = helper.find_value_by_name(g, mul_node.output[0])
    if value is not None:
        g.value_info.remove(value)
    # Fuse Matmul and Add
    gemm_node = onnx.helper.make_node(
        'Gemm',
        [matmul_node.input[0], matmul_node.input[1], add_node.input[1]],
        [add_node.output[0]],
        name=matmul_node.name,
        alpha=1.0,
        beta=1.0,
        transA=0,
        transB=0
    )
    g.node.extend([gemm_node])
    # Clean up
    g.node.remove(matmul_node)
    g.node.remove(add_node)
    value = helper.find_value_by_name(g, matmul_node.output[0])
    if value is not None:
        g.value_info.remove(value)
    other.topological_sort(g)
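
A numpy sketch of why the weight broadcast is sound: when the MatMul weight has a single column, multiplying its output by an all-ones vector is the same as tiling the weight across channels, so MatMul -> Mul(ones) -> Add collapses into one Gemm.

import numpy as np

channel = 4
x = np.random.randn(2, 3)
w = np.random.randn(3, 1)          # MatMul weight with shape[1] == 1
bias = np.random.randn(channel)    # the Add constant

y_ref = (x @ w) * np.ones(channel) + bias      # MatMul -> Mul -> Add
y_fused = x @ np.tile(w, channel) + bias       # Gemm with tiled weight
assert np.allclose(y_ref, y_fused)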
Example #11
def replace_add_to_bn(g):
    """Replace single Add node with Batchnorm node.
    :param g: input graph.
    :return:
    """
    node_to_del = []
    for node in g.node:
        if node.op_type != 'Add':
            continue

        add_op_node = node

        # Only handle the two-input form: one op output and one constant.
        if len(add_op_node.input) != 2:
            continue

        input_op_node_name = add_op_node.input[0]
        add_value_node = helper.find_node_by_output_name(
            g, add_op_node.input[1])
        if not add_value_node or add_value_node.op_type != 'Constant':
            continue

        prev_shape_value_info = helper.find_value_by_name(
            g, input_op_node_name)
        prev_shape_value_info = helper.find_input_by_name(
            g, input_op_node_name
        ) if prev_shape_value_info is None else prev_shape_value_info
        if prev_shape_value_info is None:
            continue

        _, previous_node_output_shape = helper.find_size_shape_from_value(
            prev_shape_value_info)
        bias_shape, bias_data = helper.constant_to_list(add_value_node)

        # channel dimension
        c_dim = previous_node_output_shape[1] if len(
            previous_node_output_shape) > 1 else 1

        # Only allow a channelwise add or a constant add.
        if bias_shape != [1, c_dim, 1, 1] and bias_shape != 1:
            continue

        ones = [1.0] * c_dim
        zeros = [0.0] * c_dim
        # If the bias is a scalar, expand it.
        if len(bias_data) == 1:
            bias = bias_data * c_dim
        else:
            bias = bias_data
        bn_name = add_op_node.output[0]
        mean_value_node = helper.list_to_constant(bn_name + '_mean',
                                                  np.array(zeros).shape, zeros)
        variance_value_node = helper.list_to_constant(bn_name + '_var',
                                                      np.array(ones).shape,
                                                      ones)
        scale_value_node = helper.list_to_constant(bn_name + '_mul',
                                                   np.array(ones).shape, ones)
        new_add_value_node = helper.list_to_constant(bn_name + '_add',
                                                     np.array(bias).shape,
                                                     bias)

        bn_node = onnx.helper.make_node('BatchNormalization', [
            input_op_node_name, scale_value_node.output[0],
            new_add_value_node.output[0], mean_value_node.output[0],
            variance_value_node.output[0]
        ], [add_op_node.output[0]],
                                        name=bn_name,
                                        epsilon=0.00000001)

        add_val_info = helper.find_value_by_name(g, add_value_node.output[0])
        if add_val_info is not None:
            g.value_info.remove(add_val_info)

        g.node.extend([bn_node])
        g.node.extend([mean_value_node])
        g.node.extend([variance_value_node])
        g.node.extend([scale_value_node])
        g.node.extend([new_add_value_node])

        node_to_del.extend([add_op_node])
        node_to_del.extend([add_value_node])

    while node_to_del:
        g.node.remove(node_to_del.pop())

    topological_sort(g)
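
A numpy sketch of the scalar-bias expansion above and the identity behind the replacement: BatchNormalization with unit scale, zero mean, and unit variance is a channelwise add.

import numpy as np

c_dim = 3
bias_data = [0.25]          # a scalar Add constant
bias = bias_data * c_dim    # expanded to channel size, as the pass does

x = np.random.randn(1, c_dim, 2, 2).astype(np.float32)
bn_bias = np.array(bias, dtype=np.float32).reshape(1, c_dim, 1, 1)
bn_out = 1.0 * (x - 0.0) / np.sqrt(1.0 + 1e-8) + bn_bias
assert np.allclose(bn_out, x + 0.25, atol=1e-6)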