Example 1
def fuse_Add_into_Conv(g):
    """
    Fuse Add layers into the previous Conv layers as the bias

    :param g: the onnx graph
    """
    node_to_remove = []
    for node in g.node:
        if node.op_type != 'Add':
            continue
        conv_node = helper.find_node_by_output_name(g, node.input[0])
        cons_node = helper.find_node_by_output_name(g, node.input[1])
        if conv_node is None or cons_node is None:
            continue
        if conv_node.op_type != 'Conv' or cons_node.op_type != 'Constant':
            continue
        if len(conv_node.input) > 2:
            continue
        # This layer should be fused. Connect constant node into convolution node.
        add_node = node
        conv_node.input.extend([cons_node.output[0]])
        old_value = helper.find_value_by_name(g, conv_node.output[0])
        conv_node.output[0] = add_node.output[0]
        # Remove the original conv output value info
        g.value_info.remove(old_value)
        # Remove current node
        node_to_remove.append(add_node)
    # Apply changes to the model
    for node in node_to_remove:
        g.node.remove(node)
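A minimal usage sketch for the passes in this collection (the module layout and file names here are assumptions; these functions rely on a package-local `helper` module):

import onnx

model = onnx.load('model.onnx')      # placeholder path
fuse_Add_into_Conv(model.graph)      # fold Conv + Add(Constant) pairs into the Conv bias
onnx.save(model, 'model_opt.onnx')   # placeholder path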
Example 2
def polish_RESIZE_input_param_node(g, resize_node_name):
    resize_node = helper.find_node_by_output_name(g, resize_node_name)

    shape_data_node = helper.find_node_by_output_name(g, resize_node.input[3])
    shape_data = helper.constant_to_numpy(shape_data_node).astype(int)

    # handle 0 batch size which is invalid
    if shape_data[0] == 0:
        shape_data[0] = 1

    pre_node_output_value_info = helper.find_value_by_name(
        g, resize_node.input[0])
    ori_shape = np.array([
        pre_node_output_value_info.type.tensor_type.shape.dim[0].dim_value,
        pre_node_output_value_info.type.tensor_type.shape.dim[1].dim_value,
        pre_node_output_value_info.type.tensor_type.shape.dim[2].dim_value,
        pre_node_output_value_info.type.tensor_type.shape.dim[3].dim_value
    ])

    resize_node.input.remove(resize_node.input[3])

    resize_scales = np.array(shape_data / ori_shape).astype(float)
    resize_scale_node = helper.list_to_constant(
        'resize_scales_node_' + resize_node.name,
        resize_scales.shape,
        resize_scales,
        data_type=onnx.helper.TensorProto.FLOAT)

    resize_node.input[2] = resize_scale_node.name
    g.node.extend([resize_scale_node])

    other.topological_sort(g)
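The sizes-to-scales conversion above is a plain elementwise division; a quick check with illustrative numbers:

import numpy as np

ori_shape = np.array([1, 3, 32, 32])     # shape of the Resize data input (NCHW)
shape_data = np.array([1, 3, 64, 64])    # requested output sizes
scales = np.array(shape_data / ori_shape).astype(float)
print(scales)                            # [1. 1. 2. 2.]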
Example 3
def add_bn_before_add(g):
    for n in g.node:
        # Find merge node (Add)
        if n.op_type != 'Add':
            continue
        if len(n.input) != 2:
            continue
        # Get two inputs
        input_node_a = helper.find_node_by_output_name(g, n.input[0])
        input_node_b = helper.find_node_by_output_name(g, n.input[1])
        # Skip constant input add
        if input_node_a is None or input_node_a.op_type == 'Constant':
            continue
        if input_node_b is None or input_node_b.op_type == 'Constant':
            continue

        def add_bn_after(prev_node):
            # Get the channel number from value info
            value_name = prev_node.output[0]
            value = helper.find_value_by_name(g, value_name)
            shape = helper.get_shape_from_value_info(value)
            channel = shape[1]
            # Construct 4 weights
            node_name = value_name + "_nop_bn"
            ones = [1.0] * channel
            zeros = [0.0] * channel
            scale_node = helper.list_to_constant(node_name + "_scale",
                                                 [channel], ones)
            bias_node = helper.list_to_constant(node_name + "_bias", [channel],
                                                zeros)
            mean_node = helper.list_to_constant(node_name + "_mean", [channel],
                                                zeros)
            var_node = helper.list_to_constant(node_name + "_var", [channel],
                                               ones)
            # Construct BN node
            bn_node = onnx.helper.make_node("BatchNormalization", [
                value_name, scale_node.output[0], bias_node.output[0],
                mean_node.output[0], var_node.output[0]
            ], [node_name],
                                            name=node_name,
                                            epsilon=0.00000001)
            # Reconnect the graph
            replace_node_input(n, value_name, node_name)
            # Add node to the graph
            g.node.extend(
                [bn_node, scale_node, bias_node, mean_node, var_node])

        if input_node_a.op_type != 'BatchNormalization' or len(
                helper.find_following_nodes_by_input_value_name(
                    g, input_node_a.output[0])) > 1:
            add_bn_after(input_node_a)
        if input_node_b.op_type != 'BatchNormalization' or len(
                helper.find_following_nodes_by_input_value_name(
                    g, input_node_b.output[0])) > 1:
            add_bn_after(input_node_b)
    topological_sort(g)
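The BatchNormalization inserted by add_bn_after is a numerical no-op: with scale=1, bias=0, mean=0, var=1 it computes y = (x - 0) / sqrt(1 + epsilon), which equals x up to the tiny epsilon. A quick numpy check:

import numpy as np

x = np.random.rand(1, 4, 6, 6)
eps = 1e-8  # the epsilon used above
y = 1.0 * (x - 0.0) / np.sqrt(1.0 + eps) + 0.0
assert np.allclose(x, y)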
Example 4
def add_bn_on_skip_branch(g):
    for n in g.node:
        # Find merge node (Add)
        if n.op_type != 'Add':
            continue
        if len(n.input) != 2:
            continue
        # TODO: Still need to consider more cases
        # Check if skip branch exist
        input_node_a = helper.find_node_by_output_name(g, n.input[0])
        input_node_b = helper.find_node_by_output_name(g, n.input[1])
        if input_node_a is None or input_node_b is None:
            continue
        output_of_input_node_a = helper.find_nodes_by_input_name(
            g, input_node_a.output[0])
        output_of_input_node_b = helper.find_nodes_by_input_name(
            g, input_node_b.output[0])
        if len(output_of_input_node_a) == 1 and len(
                output_of_input_node_b) == 1:
            continue
        if len(output_of_input_node_a) == 2:
            split_node = input_node_a
        elif len(output_of_input_node_b) == 2:
            split_node = input_node_b
        else:
            continue
        # Get the channel number from value info
        value_name = split_node.output[0]
        value = helper.find_value_by_name(g, value_name)
        shape = helper.get_shape_from_value_info(value)
        channel = shape[1]
        # Construct 4 weights
        node_name = value_name + "_nop_bn"
        ones = [1.0] * channel
        zeros = [0.0] * channel
        scale_node = helper.list_to_constant(node_name + "_scale", [channel],
                                             ones)
        bias_node = helper.list_to_constant(node_name + "_bias", [channel],
                                            zeros)
        mean_node = helper.list_to_constant(node_name + "_mean", [channel],
                                            zeros)
        var_node = helper.list_to_constant(node_name + "_var", [channel], ones)
        # Construct BN node
        bn_node = onnx.helper.make_node("BatchNormalization", [
            value_name, scale_node.output[0], bias_node.output[0],
            mean_node.output[0], var_node.output[0]
        ], [node_name],
                                        name=node_name)
        # Reconnect the graph
        replace_node_input(n, value_name, node_name)
        # Add node to the graph
        g.node.extend([bn_node, scale_node, bias_node, mean_node, var_node])
    topological_sort(g)
Example 5
def make_UpsamplingBilinear2d_value_info(g, resize_node_name):
    resize_node = helper.find_node_by_output_name(g, resize_node_name)

    shape_data_node = helper.find_node_by_output_name(g, resize_node.input[3])
    shape_data = helper.constant_to_numpy(shape_data_node).astype(int)
    l_shape_data = list(shape_data)
    if l_shape_data[0] == 0:
        # A batch size of 0 is invalid; force it to 1.
        l_shape_data[0] = 1
    shape_data = np.array(l_shape_data)

    new_output_value_info = onnx.helper.make_tensor_value_info(
        resize_node.output[0], onnx.helper.TensorProto.FLOAT,
        shape_data.tolist())

    g.value_info.extend([new_output_value_info])
Example 6
def replace_Reshape_with_Flatten(g):
    """
    Replace a Reshape node with a Flatten node if applicable.

    :param g: the onnx graph
    """
    node_to_remove = []
    for node in g.node:
        if node.op_type != 'Reshape':
            continue
        found = False
        # Replace only when the Reshape is directly followed by a Gemm
        for i in g.node:
            if len(i.input) == 0 or i.input[0] != node.output[0]:
                continue
            if i.op_type == 'Gemm':
                found = True
                break
        if not found:
            continue
        shape_node = helper.find_node_by_output_name(g, node.input[1])
        if shape_node.op_type != 'Constant':
            continue
        # Replace it
        node.op_type = "Flatten"
        for _ in range(len(node.attribute)):
            node.attribute.pop()
        shape_value = helper.find_value_by_name(g, shape_node.output[0])
        node.input.pop()
        node_to_remove.append(shape_node)
        g.value_info.remove(shape_value)
    for node in node_to_remove:
        g.node.remove(node)
Example 7
def inference_resize_shape(g):
    for node in g.node:
        if node.op_type != 'Resize':
            continue

        output_value = helper.find_value_by_name(g, node.output[0])
        output_value = helper.find_output_by_name(
            g, node.output[0]) if output_value is None else output_value
        if output_value is not None:
            continue

        # Currently only the 4-input form is supported
        if len(node.input) == 4:  # input: X, roi, scales, sizes
            shape_node = helper.find_node_by_output_name(g, node.input[3])
            if shape_node.op_type != 'Constant':
                continue

            _, shape_value = helper.constant_to_list(shape_node)
            output_value = onnx.helper.make_tensor_value_info(
                node.output[0], onnx.TensorProto.FLOAT,
                [int(v) for v in shape_value])
            g.value_info.extend([output_value])

            return True
    return False
Example 8
def change_first_conv_from_bgr_to_rgb(m):
    """For input channel format BGR model, use this function to change the first
    conv weight to adapt the input into RGB.

    :param m: the model proto
    """
    # Check for first node.
    g = m.graph
    input_name = g.input[0].name
    first_nodes = helper.find_following_nodes_by_input_value_name(
        g, input_name)
    if len(first_nodes) > 1:
        return False
    first_node = first_nodes[0]
    # Now we have the first node. Check this first node.
    if first_node.op_type != 'Conv':
        return False
    weight_value = helper.find_value_by_name(g, first_node.input[1])
    weight_shape = helper.get_shape_from_value_info(weight_value)
    if weight_shape[1] != 3:
        return False
    # Do weight shuffle
    weight_node = helper.find_node_by_output_name(g, weight_value.name)
    weight_np = helper.constant_to_numpy(weight_node)
    b_channel = np.expand_dims(weight_np[:, 0, :, :], axis=1)
    g_channel = np.expand_dims(weight_np[:, 1, :, :], axis=1)
    r_channel = np.expand_dims(weight_np[:, 2, :, :], axis=1)
    new_np = np.concatenate((r_channel, g_channel, b_channel), axis=1)
    new_node = helper.numpy_to_constant(weight_value.name, new_np)
    # Replace the weight and topological sort
    g.node.remove(weight_node)
    g.node.extend([new_node])
    other.topological_sort(g)
    return True
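Why reversing the weight's input-channel axis works: a convolution sums channelwise products, so reordering the image channels and the kernel's input channels the same way leaves every output value unchanged. A numpy sketch over a single receptive field:

import numpy as np

w = np.random.rand(8, 3, 3, 3).astype(np.float32)        # OIHW weight trained on BGR
patch_bgr = np.random.rand(3, 3, 3).astype(np.float32)   # one 3x3 input patch, BGR order
patch_rgb = patch_bgr[::-1]                              # the same patch in RGB order
w_rgb = w[:, ::-1]                                       # input channels reversed, as above

out_bgr = np.einsum('oihw,ihw->o', w, patch_bgr)
out_rgb = np.einsum('oihw,ihw->o', w_rgb, patch_rgb)
assert np.allclose(out_bgr, out_rgb)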
Example 9
def fuse_conv_and_add_into_conv(g):
    node_to_del = []
    for node in g.node:
        if node.op_type != 'Add':
            continue
        add_node = node
        add_const = helper.find_node_by_output_name(g, add_node.input[1])
        if not add_const or add_const.op_type != 'Constant':
            continue

        conv_node = helper.find_node_by_output_name(g, add_node.input[0])
        if not conv_node or conv_node.op_type != 'Conv':
            continue
        weight_node = helper.find_node_by_output_name(g, conv_node.input[1])
        if not weight_node or weight_node.op_type != 'Constant':
            continue

        m_dim = weight_node.attribute[0].t.dims[0]
        if add_const.attribute[0].t.dims != [1, m_dim, 1, 1]:
            continue
        for _ in range(3):
            add_const.attribute[0].t.dims.remove(1)

        conv_node.input.extend([add_const.output[0]])
        conv_node.output.pop()
        conv_node.output.extend([add_node.output[0]])

        node_to_del.append(add_node)

        old_add_const_val_info = helper.find_value_by_name(
            g, add_node.input[1])
        old_conv_output_val_info = helper.find_value_by_name(
            g, conv_node.output[0])
        if old_add_const_val_info:
            g.value_info.remove(old_add_const_val_info)
        if old_conv_output_val_info:
            g.value_info.remove(old_conv_output_val_info)

        new_add_const_val_info = onnx.helper.make_tensor_value_info(
            add_const.output[0], add_const.attribute[0].t.data_type,
            add_const.attribute[0].t.dims)
        g.value_info.extend([new_add_const_val_info])

    while node_to_del:
        g.node.remove(node_to_del.pop())

    topological_sort(g)
Example 10
def replace_depthwise_1x1_with_bn(g):
    """Replace 1x1 DepthwiseConv node into BN node if applicable.

    :param g: the onnx graph
    """
    node_to_remove = []
    for node in g.node:
        # Check op_type
        if node.op_type != 'Conv':
            continue
        # Check attributes
        attr_map = {attr.name: attr for attr in node.attribute}
        if "group" not in attr_map or attr_map["group"].i == 1:
            continue
        if attr_map["kernel_shape"].ints[0] != 1 or attr_map["kernel_shape"].ints[1] != 1:
            continue
        if "pads" in attr_map and sum(attr_map["pads"].ints) != 0:
            continue
        # Check scale
        scale_node = helper.find_node_by_output_name(g, node.input[1])
        if scale_node is None or scale_node.attribute[0].t.dims[1] != 1:
            continue
        scale_node.attribute[0].t.dims.pop()
        scale_node.attribute[0].t.dims.pop()
        scale_node.attribute[0].t.dims.pop()
        scale_info = helper.find_value_by_name(g, node.input[1])
        if scale_info is not None:
            scale_info.type.tensor_type.shape.dim.pop()
            scale_info.type.tensor_type.shape.dim.pop()
            scale_info.type.tensor_type.shape.dim.pop()
        # Check bias
        if len(node.input) == 3:
            bias_name = node.input[2]
        else:
            bias_name = node.name + "_bias"
            bias_node = helper.list_to_constant(bias_name, [attr_map["group"].i], [0.0] * attr_map["group"].i)
            g.node.extend([bias_node])
        # Construct mean and vars
        mean_name = node.name + "_mean"
        mean_node = helper.list_to_constant(mean_name, [attr_map["group"].i], [0.0] * attr_map["group"].i)
        var_name = node.name + "_var"
        var_node = helper.list_to_constant(var_name, [attr_map["group"].i], [1.0] * attr_map["group"].i)
        g.node.extend([mean_node, var_node])
        # Convert
        bn_node = onnx.helper.make_node(
            op_type='BatchNormalization',
            inputs=[node.input[0], node.input[1], bias_name, mean_name, var_name],
            outputs=node.output,
            name=node.name,
            epsilon=0.00001,
            momentum=0.9
            )
        g.node.extend([bn_node])
        node_to_remove.append(node)
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)
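The equivalence behind this pass: a grouped 1x1 Conv with one input channel per group multiplies each channel by a single weight and adds a per-channel bias, which is exactly what BatchNormalization computes with mean=0 and var=1. A numpy check (epsilon as above):

import numpy as np

x = np.random.rand(1, 4, 6, 6).astype(np.float32)
w = np.random.rand(4).astype(np.float32)   # one weight per channel (depthwise 1x1)
b = np.random.rand(4).astype(np.float32)

conv_out = x * w[None, :, None, None] + b[None, :, None, None]
eps = 1e-5
bn_out = w[None, :, None, None] * (x - 0.0) / np.sqrt(1.0 + eps) + b[None, :, None, None]
assert np.allclose(conv_out, bn_out, atol=1e-4)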
Example 11
def fuse_Transpose_into_Constant(g):
    """
    Fuse Transpose layers into the Constant layers before

    :param g: the onnx graph
    """
    node_to_remove = []
    for node in g.node:
        if node.op_type != 'Transpose':
            continue
        prev_node = helper.find_node_by_output_name(g, node.input[0])
        if prev_node is None or prev_node.op_type != 'Constant':
            continue
        
        pre_shape, data_list = helper.constant_to_list(prev_node)
        w = np.reshape(data_list, pre_shape)
        w = w.transpose(node.attribute[0].ints)
        new_shape = w.shape
        w = w.flatten()
        
        new_tensor = onnx.helper.make_tensor(
            name=prev_node.name+'_data',
            data_type=prev_node.attribute[0].t.data_type,
            dims=new_shape,
            vals=w.tolist()
        )
        new_node = onnx.helper.make_node(
            'Constant',
            [],
            [node.output[0]],
            name=node.output[0],
            value=new_tensor
        )
        
        value_between = helper.find_value_by_name(g, prev_node.output[0])
        value_type = value_between.type.tensor_type.elem_type
        g.value_info.remove(value_between)

        g.node.extend([new_node])
        node_to_remove.append(node)
        node_to_remove.append(prev_node)
        
        # Add a value info for the new output if it does not exist yet.
        if new_node.output[0] not in [i.name for i in g.value_info]:
            new_value = onnx.helper.make_tensor_value_info(
                name=new_node.output[0],
                elem_type=value_type,
                shape=new_shape
                )
            g.value_info.extend([new_value])
    
    for node in node_to_remove:
        g.node.remove(node)
    
    topological_sort(g)
Example 12
def inference_resize_shape(g):
    for node in g.node:
        if node.op_type != 'Resize':
            continue

        output_value = helper.find_value_by_name(g, node.output[0])
        output_value = helper.find_output_by_name(
            g, node.output[0]) if output_value is None else output_value
        if output_value is not None:
            continue

        if len(node.input) == 4:  # input: X, roi, scales, sizes
            shape_node = helper.find_node_by_output_name(g, node.input[3])
            if shape_node.op_type != 'Constant':
                continue

            _, shape_value = helper.constant_to_list(shape_node)
            output_value = onnx.helper.make_tensor_value_info(
                node.output[0], onnx.TensorProto.FLOAT,
                [int(v) for v in shape_value])
            g.value_info.extend([output_value])
            return True
        else:
            # If the output shape is not given, infer it from the scales
            # Get the input shape
            input_value = helper.find_value_by_name(g, node.input[0])
            if input_value is None:
                continue
            shape_value = helper.get_shape_from_value_info(input_value)
            scales_node = helper.find_node_by_output_name(g, node.input[2])
            if scales_node.op_type != 'Constant':
                continue
            _, scales_value = helper.constant_to_list(scales_node)
            for i in range(len(shape_value)):
                shape_value[i] *= scales_value[i]
            output_value = onnx.helper.make_tensor_value_info(
                node.output[0], onnx.TensorProto.FLOAT,
                [int(v) for v in shape_value])
            g.value_info.extend([output_value])
            return True
    return False
Example 13
def replace_ConstantOfShape_with_constant(g):
    """Replace Shape with Constant.\\
    This is the first step of reshape constant folding.

    :param g: the input graph\\
    :return: if anything modified, return true.
    """
    node_to_remove = []
    for node in g.node:
        # Find a ConstantOfShape
        if node.op_type != 'ConstantOfShape':
            continue
        # Check input
        input_value = helper.find_value_by_name(g, node.input[0])
        if input_value is None:
            input_value = helper.find_input_by_name(g, node.input[0])
        if input_value is None or len(
                input_value.type.tensor_type.shape.dim) == 0:
            continue

        # Replace to constant node
        pre_node = helper.find_node_by_output_name(g, node.input[0])
        _, target_shape = helper.constant_to_list(pre_node)

        # The 'value' attribute holds a one-element tensor, not an int.
        value_attr = helper.get_attribute_by_name(node, 'value')
        if value_attr is None:
            # ConstantOfShape defaults to a zero-filled float tensor.
            value = 0.0
        else:
            value = onnx.numpy_helper.to_array(value_attr.t).flatten()[0]

        node_name = node.output[0]
        new_node = helper.list_to_constant(node_name, [target_shape[0]],
                                           [value] * target_shape[0])

        g.node.extend([new_node])

        # remove old node
        node_to_remove.append(node)

        # delete value_info
        val_info_used = sum(
            [input_value.name in node.input for node in g.node])
        if val_info_used == 1:
            g.value_info.remove(input_value)

    replaced = len(node_to_remove) > 0

    for node in node_to_remove:
        g.node.remove(node)

    topological_sort(g)

    return replaced
Example 14
def add_bn_before_activation(g):
    activation_nodes = set(['Relu', 'Clip', 'PRelu', 'LeakyRelu'])
    previous_nodes = set(['Conv', 'BatchNormalization'])
    for n in g.node:
        # Find activation node
        if n.op_type not in activation_nodes:
            continue
        # Get input
        input_node = helper.find_node_by_output_name(g, n.input[0])
        if input_node is None or input_node.op_type in previous_nodes:
            continue

        def add_bn_after(prev_node):
            # Get the channel number from value info
            value_name = prev_node.output[0]
            value = helper.find_value_by_name(g, value_name)
            shape = helper.get_shape_from_value_info(value)
            channel = shape[1]
            # Construct 4 weights
            node_name = value_name + "_nop_bn"
            ones = [1.0] * channel
            zeros = [0.0] * channel
            scale_node = helper.list_to_constant(node_name + "_scale",
                                                 [channel], ones)
            bias_node = helper.list_to_constant(node_name + "_bias", [channel],
                                                zeros)
            mean_node = helper.list_to_constant(node_name + "_mean", [channel],
                                                zeros)
            var_node = helper.list_to_constant(node_name + "_var", [channel],
                                               ones)
            # Construct BN node
            bn_node = onnx.helper.make_node("BatchNormalization", [
                value_name, scale_node.output[0], bias_node.output[0],
                mean_node.output[0], var_node.output[0]
            ], [node_name],
                                            name=node_name,
                                            epsilon=0.00000001)
            # Reconnect the graph
            replace_node_input(n, value_name, node_name)
            # Add node to the graph
            g.node.extend(
                [bn_node, scale_node, bias_node, mean_node, var_node])

        add_bn_after(input_node)
    topological_sort(g)
Example 15
def rename_output_name(g, original_name, new_name):
    # Output
    output_value = helper.find_output_by_name(g, original_name)
    if output_value is None:
        logging.error("Cannot find output value named " + original_name)
        return
    output_value.name = new_name
    # Value Info
    value_info = helper.find_value_by_name(g, original_name)
    if value_info is not None:
        value_info.name = new_name
    # Node output
    node = helper.find_node_by_output_name(g, original_name)
    node.output[0] = new_name
    # Node input
    nodes = helper.find_nodes_by_input_name(g, original_name)
    for node in nodes:
        replace_node_input(node, original_name, new_name)
Example 16
def duplicate_param_shared_constant(g):
    for node in g.node:
        input_names = set()
        for n, input_node_name in enumerate(node.input):
            param_data_node = helper.find_node_by_output_name(g, input_node_name)
            if param_data_node is None or param_data_node.op_type != 'Constant':
                continue
            if input_node_name not in input_names:
                input_names.add(input_node_name)
                continue
            
            duplicated_node = copy.deepcopy(param_data_node)
            new_node_name = param_data_node.name + '_' + str(n)
            
            duplicated_node.name = new_node_name
            duplicated_node.output[0] = new_node_name
            
            node.input[n] = new_node_name
            g.node.extend([duplicated_node])
Example 17
def fuse_consecutive_reducemean(g):
    node_to_del = []
    for node in g.node:
        # Find consecutive ReduceMean
        if node.op_type != 'ReduceMean':
            continue
        pre_node = helper.find_node_by_output_name(g, node.input[0])
        if pre_node is None or pre_node.op_type != 'ReduceMean':
            continue
        # Check attributes
        pre_keepdims = helper.get_var_attribute_by_name(
            pre_node, 'keepdims', 'int')
        pre_axes = helper.get_list_attribute_by_name(pre_node, 'axes', 'int')
        cur_keepdims = helper.get_var_attribute_by_name(
            node, 'keepdims', 'int')
        cur_axes = helper.get_list_attribute_by_name(node, 'axes', 'int')
        if pre_keepdims != 0 or cur_keepdims != 0:
            continue
        axes = sorted(pre_axes + cur_axes)
        if axes != [2, 3]:
            continue
        # Merge two ReduceMean into GlobalAveragePool.
        new_gap_node = onnx.helper.make_node('GlobalAveragePool',
                                             [pre_node.input[0]],
                                             [node.output[0] + '_intermedia'],
                                             name=node.name + '_gap')
        new_flatten_node = onnx.helper.make_node(
            'Flatten', [node.output[0] + '_intermedia'], [node.output[0]],
            name=node.name + '_flatten',
            axis=1)

        # Clean up
        g.node.extend([new_gap_node, new_flatten_node])
        node_to_del.extend([pre_node, node])
        mid_val_info = helper.find_value_by_name(g, node.input[0])
        if mid_val_info:
            g.value_info.remove(mid_val_info)

    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)

    topological_sort(g)
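The fusion is valid because averaging over H and then W (with keepdims=0) equals one global average over both axes; GlobalAveragePool keeps the spatial dims as (N, C, 1, 1), so the trailing Flatten restores the (N, C) shape the ReduceMean pair produced. A numpy check:

import numpy as np

x = np.random.rand(2, 3, 8, 8).astype(np.float32)
reduced = x.mean(axis=3).mean(axis=2)   # ReduceMean(axes=[3]) then ReduceMean(axes=[2]), keepdims=0
gap_flat = x.mean(axis=(2, 3))          # GlobalAveragePool followed by Flatten(axis=1)
assert np.allclose(reduced, gap_flat)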
Example 18
def inference_upsample_shape(g):
    """For onnx v1.4.1+, onnx cannot inference upsample output shape. Let's\\
    do it ourselves. This function only inference the next upsample without\\
    output shape each time.

    :param g: the graph\\
    :return: True if any Upsample shape is generated. Otherwise, False.
    """
    for node in g.node:
        if node.op_type != 'Upsample':
            continue
        output_value = helper.find_value_by_name(g, node.output[0])
        if output_value is None:
            output_value = helper.find_output_by_name(g, node.output[0])
        if output_value and helper.get_shape_from_value_info(output_value):
            continue
        # Get input shape
        input_value = helper.find_value_by_name(g, node.input[0])
        if input_value is None:
            continue
            #raise RuntimeError("Shape for {} has not been generated.".format(node.input[0]))
        if not helper.get_shape_from_value_info(input_value):
            continue
            #raise RuntimeError("Shape for {} is empty.".format(node.input[0]))
        input_shape = helper.get_shape_from_value_info(input_value)
        # Get upsample weight
        weight_node = helper.find_node_by_output_name(g, node.input[1])
        weight_shape, weight = helper.constant_to_list(weight_node)
        if len(input_shape) != weight_shape[0]:
            raise RuntimeError(
                "Unmatch input shape and weight shape: {} vs {}".format(
                    input_shape, weight_shape))
        # Calculate shape
        output_shape = list(input_shape)
        for i in range(len(output_shape)):
            output_shape[i] = int(input_shape[i] * weight[i])
        output_value = onnx.helper.make_tensor_value_info(
            node.output[0], input_value.type.tensor_type.elem_type,
            output_shape)
        g.value_info.extend([output_value])
        return True
    return False
Example 19
def transpose_B_in_Gemm(g):
    """
    If transB is set in Gemm, transpose the B weight and clear the flag

    :param g: the onnx graph
    """
    for node in g.node:
        if node.op_type != 'Gemm':
            continue
        do_it = False
        for attr in node.attribute:
            if attr.name == "transB":
                if attr.i == 1:
                    attr.i = 0
                    do_it = True
                    break
        if not do_it:
            continue
        # Transpose the weight and its output value
        w_node = helper.find_node_by_output_name(g, node.input[1])
        w_output = helper.find_value_by_name(g, node.input[1])
        dim_0 = w_output.type.tensor_type.shape.dim[0].dim_value
        dim_1 = w_output.type.tensor_type.shape.dim[1].dim_value
        w_output.type.tensor_type.shape.dim[0].dim_value = dim_1
        w_output.type.tensor_type.shape.dim[1].dim_value = dim_0
        w_node.attribute[0].t.dims[0] = dim_1
        w_node.attribute[0].t.dims[1] = dim_0
        if w_node.attribute[0].t.raw_data:
            raw_data = w_node.attribute[0].t.raw_data
            fl_data = [i[0] for i in struct.iter_unpack('f', raw_data)]
        else:
            fl_data = w_node.attribute[0].t.float_data
        w = np.reshape(fl_data, (dim_0, dim_1))
        w = w.transpose((1, 0)).flatten()
        if w_node.attribute[0].t.raw_data:
            buf = struct.pack('%sf' % len(w), *w)
            w_node.attribute[0].t.raw_data = buf
        else:
            for i in range(len(fl_data)):
                w_node.attribute[0].t.float_data[i] = w[i]
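The raw_data branch above stores weights as packed little-endian floats; a round-trip sketch of the struct pack/iter_unpack idiom it uses:

import struct

import numpy as np

w = np.arange(6, dtype=np.float32).reshape(2, 3)
raw = struct.pack('%sf' % w.size, *w.flatten())       # how raw_data is written back
back = [v[0] for v in struct.iter_unpack('f', raw)]   # how raw_data is read above
transposed = np.reshape(back, (2, 3)).transpose((1, 0))
assert np.array_equal(transposed, w.T)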
Example 20
def fuse_consecutive_transposes(g):
    node_to_del = []
    for node in g.node:
        if node.op_type != 'Transpose':
            continue
        pre_node = helper.find_node_by_output_name(g, node.input[0])
        if pre_node is None or pre_node.op_type != 'Transpose':
            continue

        pre_permutation = list(pre_node.attribute[0].ints)
        cur_permutation = list(node.attribute[0].ints)
        if len(pre_permutation) != len(cur_permutation):
            continue

        new_permutation = []
        for ind in cur_permutation:
            new_permutation.append(pre_permutation[ind])
        
        new_trans_node = onnx.helper.make_node(
            'Transpose',
            [pre_node.input[0]],
            [node.output[0]],
            name=node.name,
            perm=new_permutation
        )
        
        g.node.extend([new_trans_node])
        node_to_del.extend([pre_node, node])
        
        mid_val_info = helper.find_value_by_name(g, node.input[0])
        if mid_val_info:
            g.value_info.remove(mid_val_info)
    
    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)
    
    topological_sort(g)
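The permutation composition follows numpy's transpose convention: applying perm p1 and then p2 equals a single Transpose with perm[i] = p1[p2[i]], which is exactly what the new_permutation loop builds. A quick check:

import numpy as np

x = np.random.rand(2, 3, 4, 5)
p1 = [0, 2, 3, 1]              # first Transpose
p2 = [0, 3, 1, 2]              # second Transpose
fused = [p1[i] for i in p2]    # new_permutation, as computed above
assert np.array_equal(x.transpose(p1).transpose(p2), x.transpose(fused))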
Example 21
def replace_mul_to_bn(g):
    """Replace single Mul node with Batchnorm node.
    :param g: input graph.
    :return:
    """
    node_to_del = []
    for node in g.node:
        if node.op_type != 'Mul':
            continue

        mul_op_node = node

        # Must have exactly two inputs: the data input and a constant value node
        if len(mul_op_node.input) != 2:
            continue

        input_op_node_name = mul_op_node.input[0]
        mul_value_node = helper.find_node_by_output_name(
            g, mul_op_node.input[1])
        if not mul_value_node or mul_value_node.op_type != 'Constant':
            continue

        _, previous_node_output_shape = helper.find_size_shape_from_value(
            helper.find_value_by_name(g, input_op_node_name))
        scale_shape, scale_data = helper.constant_to_list(mul_value_node)

        # only allow 4 dim data input due to the hardware limitation
        if len(previous_node_output_shape) != 4:
            continue

        # channel dimension
        c_dim = previous_node_output_shape[1]

        # only allow channelwise mul or const mul
        if scale_shape != [1, c_dim, 1, 1] and scale_shape != 1:
            continue

        ones = [1.0] * c_dim
        zeros = [0.0] * c_dim
        # Broadcast a scalar scale to every channel; channelwise data is kept as-is.
        muls = scale_data * c_dim if scale_shape == 1 else scale_data
        bn_name = mul_op_node.output[0]
        mean_value_node = helper.list_to_constant(bn_name + '_mean',
                                                  np.array(zeros).shape, zeros)
        variance_value_node = helper.list_to_constant(bn_name + '_var',
                                                      np.array(ones).shape,
                                                      ones)
        bias_value_node = helper.list_to_constant(bn_name + '_add',
                                                  np.array(zeros).shape, zeros)
        new_mul_value_node = helper.list_to_constant(bn_name + '_mul',
                                                     np.array(muls).shape,
                                                     muls)

        bn_node = onnx.helper.make_node('BatchNormalization', [
            input_op_node_name, new_mul_value_node.output[0],
            bias_value_node.output[0], mean_value_node.output[0],
            variance_value_node.output[0]
        ], [mul_op_node.output[0]],
                                        name=bn_name,
                                        epsilon=0.00000001)

        mid_val_info = helper.find_value_by_name(g, mul_op_node.output[0])
        scale_val_info = helper.find_value_by_name(g, mul_value_node.output[0])
        g.value_info.remove(mid_val_info)
        g.value_info.remove(scale_val_info)

        g.node.extend([bn_node])
        g.node.extend([mean_value_node])
        g.node.extend([variance_value_node])
        g.node.extend([bias_value_node])
        g.node.extend([new_mul_value_node])

        node_to_del.extend([mul_op_node])
        node_to_del.extend([mul_value_node])

    while node_to_del:
        g.node.remove(node_to_del.pop())

    topological_sort(g)
Example 22
def replace_ReduceMean_with_GlobalAveragePool(g):
    """
    Replace ReduceMean with GlobalAveragePool node when available.

    If there is a preceding Transpose, check the Transpose and the ReduceMean
    together. If keepdims is set to 0, add a Flatten.

    :param g: the input graph
    """
    node_to_remove = []
    for node in g.node:
        # Find a ReduceMean layer
        if node.op_type != 'ReduceMean':
            continue
        # Find if it have previous Transpose and its attribute meet the need.
        prev_node = helper.find_node_by_output_name(g, node.input[0])
        if prev_node is not None and prev_node.op_type != 'Transpose':
            prev_node = None
        if prev_node is not None:
            perm = helper.get_list_attribute_by_name(prev_node, 'perm', 'int')
            if perm != [0, 2, 3, 1]:
                prev_node = None
        # Check attributes
        axes = helper.get_list_attribute_by_name(node, 'axes', 'int')
        keepdims = helper.get_var_attribute_by_name(node, 'keepdims', 'int')
        if axes is None:
            continue
        if prev_node is None and axes != [2, 3]:
            continue
        if prev_node is not None and axes != [1, 2]:
            continue
        if keepdims is None:
            keepdims = 1
        # Replace it with GlobalAveragePool
        if prev_node:
            input_list = prev_node.input
        else:
            input_list = node.input
        if keepdims == 1:
            output_list = node.output
        else:
            output_list = [node.output[0] + '_before_flatten']
            flatten_node = onnx.helper.make_node("Flatten",
                                                 output_list,
                                                 node.output,
                                                 name=node.name + "_flatten",
                                                 axis=1)
            g.node.extend([flatten_node])
        new_node = onnx.helper.make_node("GlobalAveragePool",
                                         input_list,
                                         output_list,
                                         name=node.name)
        g.node.extend([new_node])
        node_to_remove.append(node)
        if prev_node:
            value = helper.find_value_by_name(g, prev_node.output[0])
            if value:
                g.value_info.remove(value)
            node_to_remove.append(prev_node)
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)
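The Transpose + ReduceMean case handled above relies on the same reduction being computed in NHWC: transposing with perm [0, 2, 3, 1] and reducing axes [1, 2] is a global average over the original H and W. A numpy check:

import numpy as np

x = np.random.rand(2, 3, 8, 8)
nhwc_mean = x.transpose(0, 2, 3, 1).mean(axis=(1, 2))   # Transpose then ReduceMean(axes=[1, 2])
assert np.allclose(nhwc_mean, x.mean(axis=(2, 3)))      # GlobalAveragePool (+ Flatten)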
Example 23
def split_ConvTranspose(model):
    """To feed our compiler, split ConvTranspose into Upsample and Conv.

    :param model: the model
    """
    node_to_delete = []
    # Change model properties for upsample.
    if model.ir_version < 3:
        print("Warning: Current model IR version is not fully supported.")
    model.ir_version = 4
    model.opset_import[0].version = 9
    g = model.graph
    # Find ConvTranspose layers.
    for node in g.node:
        if node.op_type != 'ConvTranspose':
            continue
        # Check auto_pad
        auto_pad_proto = helper.get_attribute_by_name(node, "auto_pad")
        if auto_pad_proto is not None:
            print("Currently not split auto_pad ConvTranspose")
            continue
        # Check output_shape
        output_shape_proto = helper.get_attribute_by_name(node, "output_shape")
        if output_shape_proto is not None:
            print("Currently not split output_shape ConvTranspose")
            continue
        # Get input shape
        input_value = helper.find_value_by_name(g, node.input[0])
        if input_value is None:
            input_value = helper.find_input_by_name(g, node.input[0])
        if input_value is None:
            print("Cannot get value info named {}.".format(node.input[0]))
            exit(1)
        input_shape = helper.get_shape_from_value_info(input_value)
        # Get attributes
        attr = deconv_to_conv_info_extraction(input_shape, node)
        # Generate Upsample scales
        upsample_output_shape = list(input_shape)
        upsample_output_shape[2] = (input_shape[2] - 1) * attr["strides"][0] + 1
        upsample_output_shape[3] = (input_shape[3] - 1) * attr["strides"][1] + 1
        upsample_node_name = node.name + "_inner_upsample"
        upsample_scale_name = upsample_node_name + "_scales"
        scales_np = np.ones([4]).astype('float32')
        scales_np[2] = float(upsample_output_shape[2]) / input_shape[2]
        scales_np[3] = float(upsample_output_shape[3]) / input_shape[3]
        scales_node = helper.numpy_to_constant(upsample_scale_name, scales_np)
        # Generate a Upsample layer and an internal value info
        upsample_node = onnx.helper.make_node(
            "Upsample",
            [node.input[0], upsample_scale_name],
            [upsample_node_name],
            name=upsample_node_name,
            mode="zeros"
        )
        upsample_value_info = onnx.helper.make_tensor_value_info(
            upsample_node_name,
            input_value.type.tensor_type.elem_type,
            upsample_output_shape
        )
        # Check the weight layer, it may need a transpose
        if attr["group"] != input_shape[1]:
            weight_node = helper.find_node_by_output_name(g, node.input[1])
            weight_np = helper.constant_to_numpy(weight_node)
            new_weight_np = np.transpose(weight_np, [1, 0, 2, 3])
            new_weight_node = helper.numpy_to_constant(node.input[1], new_weight_np)
            node_to_delete.append(weight_node)
            g.node.extend([new_weight_node])
            value = helper.find_value_by_name(g, node.input[1])
            g.value_info.remove(value)
        # Generate a Conv layer
        conv_node_name = node.name + "_inner_conv"
        conv_node_input = [upsample_node_name]
        conv_node_input.extend(node.input[1:])
        conv_node = onnx.helper.make_node(
            "Conv",
            conv_node_input,
            [node.output[0]],
            name=conv_node_name,
            pads=[int(i) for i in attr["conv_pads"]],
            dilations=[int(i) for i in attr["dilations"]],
            group=int(attr["group"]),
            kernel_shape=[int(i) for i in attr["kernel_shape"]],
            strides=[int(1), int(1)]
        )
        # Reconnect the graph
        g.node.extend([scales_node, upsample_node, conv_node])
        g.value_info.extend([upsample_value_info])
        node_to_delete.append(node)
    # Delete useless nodes
    for node in node_to_delete:
        g.node.remove(node)
    topological_sort(g)
Example 24
def fuse_BN_into_Gemm(g):
    """Fuse the following BN into the previous Gemm.

    :param g: the graph
    """
    node_to_remove = []
    for node in g.node:
        # Check for BN and Gemm
        if node.op_type != 'BatchNormalization':
            continue
        gemm_node = helper.find_node_by_output_name(g, node.input[0])
        if gemm_node is None:
            continue
        if gemm_node.op_type != 'Gemm':
            continue
        if len(
                helper.find_following_nodes_by_input_value_name(
                    g, gemm_node.output[0])) > 1:
            continue
        bn_node = node
        # Get original weights
        gemm_b_node = helper.find_node_by_output_name(g, gemm_node.input[1])
        gemm_b = helper.constant_to_numpy(gemm_b_node)
        gemm_c_node = helper.find_node_by_output_name(g, gemm_node.input[2])
        gemm_c = helper.constant_to_numpy(gemm_c_node)
        bn_scale_node = helper.find_node_by_output_name(g, bn_node.input[1])
        bn_scale = helper.constant_to_numpy(bn_scale_node)
        bn_bias_node = helper.find_node_by_output_name(g, bn_node.input[2])
        bn_bias = helper.constant_to_numpy(bn_bias_node)
        bn_mean_node = helper.find_node_by_output_name(g, bn_node.input[3])
        bn_mean = helper.constant_to_numpy(bn_mean_node)
        bn_var_node = helper.find_node_by_output_name(g, bn_node.input[4])
        bn_var = helper.constant_to_numpy(bn_var_node)
        # Apply attributes
        # epsilon
        epsilon = helper.get_attribute_by_name(bn_node, 'epsilon')
        if epsilon is None:
            epsilon = 0.00001
        else:
            epsilon = epsilon.f
        bn_var = bn_var + epsilon
        # alpha
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is None:
            alpha = 1
        else:
            alpha = alpha.f
        gemm_b = gemm_b * alpha
        # beta
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is None:
            beta = 1
        else:
            beta = beta.f
        gemm_c = gemm_c * beta
        # transA
        transA = helper.get_attribute_by_name(gemm_node, 'transA')
        if transA is not None and transA.i == 1:
            raise RuntimeError("Do not support transA")
        # transB
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None and transB.i == 1:
            gemm_b = gemm_b.transpose()
        # Calculate new weights
        new_gemm_b = gemm_b * bn_scale / np.sqrt(bn_var)
        new_gemm_c = (gemm_c - bn_mean) * bn_scale / np.sqrt(bn_var) + bn_bias
        # Replace original weights
        new_gemm_b_node = helper.numpy_to_constant(gemm_b_node.name + '_fused',
                                                   new_gemm_b)
        new_gemm_c_node = helper.numpy_to_constant(gemm_c_node.name + '_fused',
                                                   new_gemm_c)
        g.node.extend([new_gemm_b_node, new_gemm_c_node])
        node_to_remove.extend([
            gemm_b_node, gemm_c_node, bn_node, bn_scale_node, bn_bias_node,
            bn_mean_node, bn_var_node
        ])
        # Modify attributes
        # alpha
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is not None:
            alpha.f = 1.0
        # beta
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is not None:
            beta.f = 1.0
        # transB
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None:
            transB.i = 0
        # Connect the new graph
        gemm_node.input[1] = new_gemm_b_node.output[0]
        gemm_node.input[2] = new_gemm_c_node.output[0]
        gemm_b_value = helper.find_value_by_name(g, gemm_b_node.output[0])
        gemm_c_value = helper.find_value_by_name(g, gemm_c_node.output[0])
        gemm_b_value.name = new_gemm_b_node.output[0]
        gemm_c_value.name = new_gemm_c_node.output[0]
        gemm_value = helper.find_value_by_name(g, gemm_node.output[0])
        g.value_info.remove(gemm_value)
        gemm_node.output[0] = bn_node.output[0]
        for i in range(1, 5):
            value = helper.find_value_by_name(g, bn_node.input[i])
            g.value_info.remove(value)
    # Remove useless nodes
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)
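The weight algebra used above in one line: Gemm followed by BN computes ((xB + C) - mean) * scale / sqrt(var + eps) + bias, which folds into a single Gemm with B' = B * scale / sqrt(var + eps) and C' = (C - mean) * scale / sqrt(var + eps) + bias. A numpy check:

import numpy as np

x = np.random.rand(2, 4)
B, C = np.random.rand(4, 3), np.random.rand(3)
scale, bias = np.random.rand(3), np.random.rand(3)
mean, var, eps = np.random.rand(3), np.random.rand(3), 1e-5

ref = ((x @ B + C) - mean) * scale / np.sqrt(var + eps) + bias
new_B = B * scale / np.sqrt(var + eps)
new_C = (C - mean) * scale / np.sqrt(var + eps) + bias
assert np.allclose(ref, x @ new_B + new_C)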
Example 25
def replace_dilated_conv(g):
    """
    If the dilation of a convolution is not (1, 1), replace it with a regular
    convolution with an expanded kernel.

    :param g: the input graph
    """
    node_to_remove = []
    for node in g.node:
        # Check if this is a conv layer
        if node.op_type != 'Conv':
            continue
        # Check if this has dilation
        has_dilations = False
        has_strides = False
        for attr in node.attribute:
            if attr.name == "dilations":
                dilations = list(attr.ints)
                if dilations != [1, 1]:
                    has_dilations = True
            if attr.name == "strides":
                strides = list(attr.ints)
                if strides != [1, 1]:
                    has_strides = True
        if has_dilations and has_strides:
            print("Warning: Both strides and dilations are set in ", node.name)
            continue
        if not has_dilations:
            continue
        # Construct new kernel
        w_node = helper.find_node_by_output_name(g, node.input[1])
        w_output = helper.find_value_by_name(g, node.input[1])
        shape = list(w_node.attribute[0].t.dims)
        # get original weight from float_data or raw data
        weight = list(w_node.attribute[0].t.float_data)
        if len(weight) == 0:
            # Unpack from raw data
            raw_data = w_node.attribute[0].t.raw_data
            weight = [i[0] for i in struct.iter_unpack('f', raw_data)]
        weight = np.array(weight)
        weight = np.reshape(weight, shape)
        new_shape = copy.copy(shape)
        new_shape[2] = 1 + (shape[2] - 1) * dilations[0]
        new_shape[3] = 1 + (shape[3] - 1) * dilations[1]
        new_weight = np.zeros(new_shape)
        for batch in range(shape[0]):
            for ch in range(shape[1]):
                for h in range(shape[2]):
                    nh = h * dilations[0]
                    for w in range(shape[3]):
                        nw = w * dilations[1]
                        new_weight[batch, ch, nh, nw] = weight[batch, ch, h, w]
        tensor = onnx.helper.make_tensor(w_node.attribute[0].t.name,
                                         w_node.attribute[0].t.data_type,
                                         new_shape, new_weight.ravel())
        new_w_node = onnx.helper.make_node("Constant", [],
                                           list(w_node.output),
                                           name=w_node.name,
                                           value=tensor)
        g.node.extend([new_w_node])
        node_to_remove.append(w_node)
        # Modify attributes and value info shapes
        w_output.type.tensor_type.shape.dim[2].dim_value = new_shape[2]
        w_output.type.tensor_type.shape.dim[3].dim_value = new_shape[3]
        for attr in node.attribute:
            if attr.name == "kernel_shape":
                attr.ints[0] = new_shape[2]
                attr.ints[1] = new_shape[3]
            if attr.name == "dilations":
                attr.ints[0] = 1
                attr.ints[1] = 1
    # Remove old weight nodes
    for node in node_to_remove:
        g.node.remove(node)
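The kernel expansion preserves the convolution exactly: a dilated convolution reads its taps d pixels apart, which is the same as a stride-1 convolution whose kernel has zeros inserted between the taps. A self-contained check with a naive single-channel convolution:

import numpy as np

def naive_conv(x, k, dilation):
    # Single-channel, stride-1, no-padding cross-correlation.
    kh, kw = k.shape
    oh = x.shape[0] - (kh - 1) * dilation
    ow = x.shape[1] - (kw - 1) * dilation
    out = np.zeros((oh, ow))
    for i in range(oh):
        for j in range(ow):
            patch = x[i:i + (kh - 1) * dilation + 1:dilation,
                      j:j + (kw - 1) * dilation + 1:dilation]
            out[i, j] = (patch * k).sum()
    return out

x = np.random.rand(9, 9)
k = np.random.rand(3, 3)
expanded = np.zeros((5, 5))   # 1 + (3 - 1) * 2, matching new_shape above
expanded[::2, ::2] = k        # zeros between the original taps
assert np.allclose(naive_conv(x, k, 2), naive_conv(x, expanded, 1))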
Example 26
def fuse_mul_and_add_into_bn(g):
    node_to_del = []
    for node in g.node:
        if node.op_type != 'Add':
            continue
        add_node = node
        input_nodes_add = [
            helper.find_node_by_output_name(g, input_name)
            for input_name in add_node.input
        ]
        if any(n is None for n in input_nodes_add):
            continue
        mul_node, const_add = None, None
        for input_node_add in input_nodes_add:
            if input_node_add.op_type == 'Mul':
                mul_node = input_node_add
            elif input_node_add.op_type == 'Constant':
                const_add = input_node_add
            else:
                pass
        if not mul_node or not const_add:
            continue
        data_input_name, const_mul = None, None
        for input_name in mul_node.input:
            input_node = helper.find_node_by_output_name(g, input_name)
            if not input_node:
                data_input_name = input_name
            elif input_node.op_type == 'Constant':
                if not const_mul:
                    const_mul = input_node
                else:
                    data_input_name = input_name
            else:
                data_input_name = input_name

        if not const_mul:
            continue

        scale_shape, scale_data = helper.constant_to_list(const_mul)
        bias_shape, __ = helper.constant_to_list(const_add)
        c_dim = len(scale_data)
        if scale_shape != bias_shape:
            continue

        _, previous_node_output_shape = helper.find_size_shape_from_value(
            helper.find_value_by_name(g, data_input_name))
        # only allow 4 dim data input due to the hardware limitation
        if len(previous_node_output_shape) != 4:
            continue

        # check if mul's dim and input channel dimension are matched
        if previous_node_output_shape[1] != c_dim:
            continue

        if scale_shape == [1, c_dim, 1, 1]:
            # Remove all the '1' dims.
            for _ in range(3):
                const_add.attribute[0].t.dims.remove(1)
                const_mul.attribute[0].t.dims.remove(1)
        elif scale_shape == [1, c_dim]:
            # Remove the leading '1' dim.
            const_add.attribute[0].t.dims.remove(1)
            const_mul.attribute[0].t.dims.remove(1)
        else:
            continue

        bn_name = add_node.output[0]
        const_mean = helper.list_to_constant(bn_name + '_mean', [c_dim],
                                             [0.0 for _ in range(c_dim)])
        const_var = helper.list_to_constant(bn_name + '_var', [c_dim],
                                            [1.0 for _ in range(c_dim)])

        bn_node = onnx.helper.make_node(
            'BatchNormalization',
            [data_input_name, const_mul.output[0], const_add.output[0],\
                const_mean.output[0], const_var.output[0]],
            [add_node.output[0]],
            name=bn_name,
            epsilon=0.00000001
        )

        mid_val_info = helper.find_value_by_name(g, mul_node.output[0])
        scale_val_info = helper.find_value_by_name(g, const_mul.output[0])
        bias_val_info = helper.find_value_by_name(g, const_add.output[0])
        g.value_info.remove(mid_val_info)
        g.value_info.remove(scale_val_info)
        g.value_info.remove(bias_val_info)

        new_scale_val_info = onnx.helper.make_tensor_value_info(
            const_mul.output[0], const_mul.attribute[0].t.data_type, [c_dim])
        new_bias_val_info = onnx.helper.make_tensor_value_info(
            const_add.output[0], const_add.attribute[0].t.data_type, [c_dim])
        mean_val_info = onnx.helper.make_tensor_value_info(
            const_mean.output[0], const_mean.attribute[0].t.data_type, [c_dim])
        var_val_info = onnx.helper.make_tensor_value_info(
            const_var.output[0], const_var.attribute[0].t.data_type, [c_dim])

        g.value_info.extend([new_scale_val_info])
        g.value_info.extend([new_bias_val_info])
        g.value_info.extend([mean_val_info])
        g.value_info.extend([var_val_info])
        g.node.extend([bn_node])
        g.node.extend([const_mean])
        g.node.extend([const_var])
        node_to_del.extend([mul_node, add_node])

    while node_to_del:
        g.node.remove(node_to_del.pop())

    topological_sort(g)
Example 27
def fuse_BN_with_Reshape_into_Gemm(g):
    """Fuse the following BN into the previous Gemm, even with Reshape or \\
        Squeeze and Unsqueeze surrounding.

    :param g: the graph
    """
    node_to_remove = []
    for node in g.node:
        # Check for BN and Gemm pattern: Gemm A BN B
        # Find BatchNorm Node
        if node.op_type != 'BatchNormalization':
            continue
        bn_node = node
        # Find A Node
        a_node = helper.find_node_by_output_name(g, node.input[0])
        if a_node is None or len(a_node.input) == 0:
            continue
        # Find Gemm Node
        gemm_node = helper.find_node_by_output_name(g, a_node.input[0])
        if gemm_node is None or gemm_node.op_type != 'Gemm':
            continue
        # Find B Node
        b_node_list = helper.find_following_nodes_by_input_value_name(
            g, bn_node.output[0])
        if len(b_node_list) == 0:
            the_output = helper.find_output_by_name(g, bn_node.output[0])
            if the_output is None:
                continue
            b_node = None
        elif len(b_node_list) > 1:
            continue
        else:
            b_node = b_node_list[0]
        # Check for branches
        if len(
                helper.find_following_nodes_by_input_value_name(
                    g, gemm_node.output[0])) > 1:
            continue
        if len(
                helper.find_following_nodes_by_input_value_name(
                    g, a_node.output[0])) > 1:
            continue
        # Check type of A
        if a_node.op_type == 'Unsqueeze':
            axes = helper.get_attribute_by_name(a_node, 'axes')
            if axes.ints != [2]:
                continue
        elif a_node.op_type == 'Reshape':
            a = helper.constant_to_list(
                helper.find_node_by_output_name(g, a_node.input[1]))[1]
            if len(a) != 3 or a[2] != 1:
                continue
        else:
            continue
        # Check type of B
        if b_node is None:
            pass
        elif b_node.op_type == 'Flatten':
            pass
        elif b_node.op_type == 'Squeeze':
            axes = helper.get_attribute_by_name(b_node, 'axes')
            if axes.ints != [2]:
                continue
        elif b_node.op_type == 'Reshape':
            a = helper.constant_to_list(
                helper.find_node_by_output_name(g, b_node.input[1]))[1]
            if len(a) != 2:
                continue
        else:
            continue
        # Construct new Nodes
        # Get original weights
        gemm_b_node = helper.find_node_by_output_name(g, gemm_node.input[1])
        gemm_b = helper.constant_to_numpy(gemm_b_node)
        gemm_c_node = helper.find_node_by_output_name(g, gemm_node.input[2])
        gemm_c = helper.constant_to_numpy(gemm_c_node)
        bn_scale_node = helper.find_node_by_output_name(g, bn_node.input[1])
        bn_scale = helper.constant_to_numpy(bn_scale_node)
        bn_bias_node = helper.find_node_by_output_name(g, bn_node.input[2])
        bn_bias = helper.constant_to_numpy(bn_bias_node)
        bn_mean_node = helper.find_node_by_output_name(g, bn_node.input[3])
        bn_mean = helper.constant_to_numpy(bn_mean_node)
        bn_var_node = helper.find_node_by_output_name(g, bn_node.input[4])
        bn_var = helper.constant_to_numpy(bn_var_node)
        # Apply attributes
        # epsilon
        epsilon = helper.get_attribute_by_name(bn_node, 'epsilon')
        if epsilon is None:
            epsilon = 0.00001
        else:
            epsilon = epsilon.f
        bn_var = bn_var + epsilon
        # alpha
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is None:
            alpha = 1
        else:
            alpha = alpha.f
        gemm_b = gemm_b * alpha
        # beta
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is None:
            beta = 1
        else:
            beta = beta.f
        gemm_c = gemm_c * beta
        # transA
        transA = helper.get_attribute_by_name(gemm_node, 'transA')
        if transA is not None and transA.i == 1:
            raise RuntimeError("Do not support transA")
        # transB
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None and transB.i == 1:
            gemm_b = gemm_b.transpose()
        # Calculate new weights
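        # BatchNormalization computes bn_scale * (y - bn_mean) / sqrt(bn_var + epsilon)
        # + bn_bias (epsilon was already folded into bn_var above). With
        # y = x.dot(gemm_b) + gemm_c, the Gemm + BN pair collapses into one Gemm: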
        new_gemm_b = gemm_b * bn_scale / np.sqrt(bn_var)
        new_gemm_c = (gemm_c - bn_mean) * bn_scale / np.sqrt(bn_var) + bn_bias
        # Replace original weights
        new_gemm_b_node = helper.numpy_to_constant(gemm_b_node.name + '_fused',
                                                   new_gemm_b)
        new_gemm_c_node = helper.numpy_to_constant(gemm_c_node.name + '_fused',
                                                   new_gemm_c)
        g.node.extend([new_gemm_b_node, new_gemm_c_node])
        # Modify attributes
        # alpha
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is not None:
            alpha.f = 1.0
        # beta
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is not None:
            beta.f = 1.0
        # transB
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None:
            transB.i = 0
        # Remove useless nodes
        node_to_remove.extend([
            gemm_b_node, gemm_c_node, bn_node, bn_scale_node, bn_bias_node,
            bn_mean_node, bn_var_node, a_node
        ])
        if a_node.op_type == 'Reshape':
            node_to_remove.append(
                helper.find_node_by_output_name(g, a_node.input[1]))
        if b_node is not None:
            node_to_remove.append(b_node)
            if b_node.op_type == 'Reshape':
                node_to_remove.append(
                    helper.find_node_by_output_name(g, b_node.input[1]))
        # Delete useless value infos
        value = helper.find_value_by_name(g, a_node.output[0])
        g.value_info.remove(value)
        if a_node.op_type == 'Reshape':
            value = helper.find_value_by_name(g, a_node.input[1])
            g.value_info.remove(value)
        for i in range(1, 5):
            value = helper.find_value_by_name(g, bn_node.input[i])
            g.value_info.remove(value)
        value = helper.find_value_by_name(g, bn_node.output[0])
        if value is not None:
            g.value_info.remove(value)
        if b_node is not None:
            value = helper.find_value_by_name(g, gemm_node.output[0])
            g.value_info.remove(value)
            if b_node.op_type == 'Reshape':
                value = helper.find_value_by_name(g, b_node.input[1])
                g.value_info.remove(value)
        # Connect the new graph
        # Connect Gemm new weights
        gemm_node.input[1] = new_gemm_b_node.output[0]
        gemm_node.input[2] = new_gemm_c_node.output[0]
        gemm_b_value = helper.find_value_by_name(g, gemm_b_node.output[0])
        gemm_c_value = helper.find_value_by_name(g, gemm_c_node.output[0])
        gemm_b_value.name = new_gemm_b_node.output[0]
        gemm_b_value.type.tensor_type.shape.dim[
            0].dim_value = new_gemm_b.shape[0]
        gemm_b_value.type.tensor_type.shape.dim[
            1].dim_value = new_gemm_b.shape[1]
        gemm_c_value.name = new_gemm_c_node.output[0]
        if b_node is None:
            # If b node is None, set the Gemm output as the graph output
            output_value = helper.find_output_by_name(g, bn_node.output[0])
            g.output.remove(output_value)
            g.output.extend(
                [helper.find_value_by_name(g, gemm_node.output[0])])
        else:
            # Else, set node B output as gemm output
            gemm_node.output[0] = b_node.output[0]
    # Remove useless nodes
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)
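A minimal numpy sketch (not part of the original source) to sanity-check the fusion arithmetic above, assuming a 2-D input, default Gemm attributes, and per-channel BatchNormalization parameters:

import numpy as np

x = np.random.rand(2, 4)        # input: batch of 2, 4 features
gemm_b = np.random.rand(4, 3)   # Gemm weight B
gemm_c = np.random.rand(3)      # Gemm bias C
scale = np.random.rand(3)       # BN scale (gamma)
bias = np.random.rand(3)        # BN bias (beta)
mean = np.random.rand(3)        # BN running mean
var = np.random.rand(3)         # BN running variance
eps = 1e-5

# Reference: Gemm followed by BatchNormalization.
y_ref = ((x.dot(gemm_b) + gemm_c) - mean) * scale / np.sqrt(var + eps) + bias

# Fused weights, mirroring new_gemm_b / new_gemm_c above.
new_b = gemm_b * scale / np.sqrt(var + eps)
new_c = (gemm_c - mean) * scale / np.sqrt(var + eps) + bias

assert np.allclose(y_ref, x.dot(new_b) + new_c)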
Example No. 28
def pattern_matmul_mul_add(g, matmul_node):
    """Fuse a MatMul -> Mul (all-ones weight) -> Add chain starting at
    matmul_node into a single Gemm.

    :param g: the onnx graph
    :param matmul_node: the MatMul node at the start of the pattern
    """
    # Check node match - Mul node
    next_nodes = helper.find_nodes_by_input_name(g, matmul_node.output[0])
    if len(next_nodes) != 1:
        return
    if next_nodes[0].op_type != 'Mul':
        return
    mul_node = next_nodes[0]
    # Check node match - Add node
    next_nodes = helper.find_nodes_by_input_name(g, mul_node.output[0])
    if len(next_nodes) != 1:
        return
    if next_nodes[0].op_type != 'Add':
        return
    add_node = next_nodes[0]
    # Check Mul weight: it must be a Constant of all ones
    mul_weight_node = helper.find_node_by_output_name(g, mul_node.input[1])
    if mul_weight_node is None or mul_weight_node.op_type != 'Constant':
        return
    weight_size, mul_weight = helper.constant_to_list(mul_weight_node)
    if any(v != 1 for v in mul_weight):
        return
    channel = weight_size[0]
    # Check Add weight
    add_weight_node = helper.find_node_by_output_name(g, add_node.input[1])
    if add_weight_node is None or add_weight_node.op_type != 'Constant':
        return
    # Check MatMul weight to see if it needs weight broadcast
    matmul_weight_node = helper.find_node_by_output_name(g, matmul_node.input[1])
    if matmul_weight_node is None or matmul_weight_node.op_type != 'Constant':
        return
    matmul_weight = helper.constant_to_numpy(matmul_weight_node)
    if matmul_weight.shape[1] == 1:
        # Weight broadcast
        new_matmul_weight = np.tile(matmul_weight, channel)
        new_matmul_weight_node = helper.numpy_to_constant(matmul_weight_node.name, new_matmul_weight)
        g.node.remove(matmul_weight_node)
        g.node.extend([new_matmul_weight_node])
    value = helper.find_value_by_name(g, matmul_weight_node.output[0])
    if value is not None:
        g.value_info.remove(value)
    # Remove the Mul node and its weight
    g.node.remove(mul_weight_node)
    value = helper.find_value_by_name(g, mul_weight_node.output[0])
    if value is not None:
        g.value_info.remove(value)
    g.node.remove(mul_node)
    value = helper.find_value_by_name(g, mul_node.output[0])
    if value is not None:
        g.value_info.remove(value)
    # Fuse MatMul and Add into Gemm
    gemm_node = onnx.helper.make_node(
        'Gemm',
        [matmul_node.input[0], matmul_node.input[1], add_node.input[1]],
        [add_node.output[0]],
        name=matmul_node.name,
        alpha=1.0,
        beta=1.0,
        transA=0,
        transB=0)
    g.node.extend([gemm_node])
    # Clean up
    g.node.remove(matmul_node)
    g.node.remove(add_node)
    value = helper.find_value_by_name(g, matmul_node.output[0])
    if value is not None:
        g.value_info.remove(value)
    other.topological_sort(g)
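A quick numpy check (a sketch, not from the original source) of why this rewrite is safe: with an all-ones Mul weight the chain is exactly a Gemm with alpha = beta = 1, and tiling a [k, 1] MatMul weight across the channel dimension reproduces the implicit broadcast:

import numpy as np

x = np.random.rand(1, 4)
w = np.random.rand(4, 1)   # [k, 1] weight that needs broadcasting
bias = np.random.rand(3)
ones = np.ones(3)

y_chain = x.dot(w) * ones + bias      # MatMul -> Mul -> Add
y_gemm = x.dot(np.tile(w, 3)) + bias  # fused Gemm with the tiled weight

assert np.allclose(y_chain, y_gemm)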
Example No. 29
def fuse_Gemm_into_Gemm(g):
    """Fuse the previous Gemm into the following Gemm.

    :param g: the graph
    """
    node_to_remove = []
    for node in g.node:
        # Check for Gemm and Gemm
        if node.op_type != 'Gemm':
            continue
        prev_node = helper.find_node_by_output_name(g, node.input[0])
        if prev_node is None:
            continue
        if prev_node.op_type != 'Gemm':
            continue
        # Get original weights. Gemm C is optional, and all four weights
        # must come from Constant nodes.
        if len(prev_node.input) < 3 or len(node.input) < 3:
            continue
        prev_b_node = helper.find_node_by_output_name(g, prev_node.input[1])
        prev_c_node = helper.find_node_by_output_name(g, prev_node.input[2])
        b_node = helper.find_node_by_output_name(g, node.input[1])
        c_node = helper.find_node_by_output_name(g, node.input[2])
        if any(n is None or n.op_type != 'Constant'
               for n in (prev_b_node, prev_c_node, b_node, c_node)):
            continue
        prev_b = helper.constant_to_numpy(prev_b_node)
        prev_c = helper.constant_to_numpy(prev_c_node)
        b = helper.constant_to_numpy(b_node)
        c = helper.constant_to_numpy(c_node)
        # Apply attributes
        # alpha
        alpha = helper.get_attribute_by_name(node, 'alpha')
        if alpha is None:
            alpha = 1
        else:
            alpha = alpha.f
        b = b * alpha
        alpha = helper.get_attribute_by_name(prev_node, 'alpha')
        if alpha is None:
            alpha = 1
        else:
            alpha = alpha.f
        prev_b = prev_b * alpha
        # beta
        beta = helper.get_attribute_by_name(node, 'beta')
        if beta is None:
            beta = 1
        else:
            beta = beta.f
        c = c * beta
        beta = helper.get_attribute_by_name(prev_node, 'beta')
        if beta is None:
            beta = 1
        else:
            beta = beta.f
        prev_c = prev_c * beta
        # transA
        transA = helper.get_attribute_by_name(node, 'transA')
        if transA is not None and transA.i == 1:
            raise RuntimeError("transA == 1 is not supported")
        transA = helper.get_attribute_by_name(prev_node, 'transA')
        if transA is not None and transA.i == 1:
            raise RuntimeError("transA == 1 is not supported")
        # transB
        transB = helper.get_attribute_by_name(node, 'transB')
        if transB is not None and transB.i == 1:
            b = b.transpose()
        transB = helper.get_attribute_by_name(prev_node, 'transB')
        if transB is not None and transB.i == 1:
            prev_b = prev_b.transpose()
        # Calculate new weights
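        # (x.dot(prev_b) + prev_c).dot(b) + c
        #     == x.dot(prev_b.dot(b)) + (prev_c.dot(b) + c)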
        new_b = prev_b.dot(b)
        new_c = prev_c.dot(b) + c
        # Replace original weights
        new_b_node = helper.numpy_to_constant(b_node.name + '_fused', new_b)
        new_c_node = helper.numpy_to_constant(c_node.name + '_fused', new_c)
        g.node.extend([new_b_node, new_c_node])
        node_to_remove.extend(
            [b_node, c_node, prev_b_node, prev_c_node, prev_node])
        # Modify attributes
        # alpha
        alpha = helper.get_attribute_by_name(node, 'alpha')
        if alpha is not None:
            alpha.f = 1.0
        # beta
        beta = helper.get_attribute_by_name(node, 'beta')
        if beta is not None:
            beta.f = 1.0
        # transB
        transB = helper.get_attribute_by_name(node, 'transB')
        if transB is not None:
            transB.i = 0
        # Connect the new graph
        node.input[0] = prev_node.input[0]
        prev_value = helper.find_value_by_name(g, prev_node.output[0])
        g.value_info.remove(prev_value)
        for i in range(1, 3):
            value = helper.find_value_by_name(g, prev_node.input[i])
            g.value_info.remove(value)
            value = helper.find_value_by_name(g, node.input[i])
            g.value_info.remove(value)
        node.input[1] = new_b_node.output[0]
        node.input[2] = new_c_node.output[0]
    # Remove useless nodes
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)
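The identity behind this fusion, checked with numpy (a sketch, not from the original source); it is why new_b = prev_b.dot(b) and new_c = prev_c.dot(b) + c:

import numpy as np

x = np.random.rand(2, 4)
prev_b = np.random.rand(4, 5)
prev_c = np.random.rand(5)
b = np.random.rand(5, 3)
c = np.random.rand(3)

y_ref = (x.dot(prev_b) + prev_c).dot(b) + c           # Gemm -> Gemm
y_fused = x.dot(prev_b.dot(b)) + (prev_c.dot(b) + c)  # single fused Gemm

assert np.allclose(y_ref, y_fused)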
Example No. 30
def fuse_mul_and_add_into_gemm(g):
    """Fuse an element-wise Mul followed by an Add on a [1, dim] tensor into
    a single Gemm with a diagonal weight matrix.

    :param g: the onnx graph
    """
    node_to_del = []
    for node in g.node:
        if node.op_type != 'Add':
            continue
        add_node = node
        mul_node = helper.find_node_by_output_name(g, add_node.input[0])
        if not mul_node or mul_node.op_type != 'Mul':
            continue
        mul_const = helper.find_node_by_output_name(g, mul_node.input[1])
        if not mul_const or mul_const.op_type != 'Constant':
            continue
        add_const = helper.find_node_by_output_name(g, add_node.input[1])
        if not add_const or add_const.op_type != 'Constant':
            continue

        input_val = helper.find_value_by_name(g, mul_node.input[0])
        if not input_val:
            input_val = helper.find_input_by_name(g, mul_node.input[0])
        if not input_val:
            continue

        _, input_shape = helper.find_size_shape_from_value(input_val)
        if not input_shape:
            continue

        dim = int(np.prod(input_shape))
        if input_shape != [1, dim]:
            continue

        mul_const_shape, mul_const_data = helper.constant_to_list(mul_const)
        add_const_shape, __ = helper.constant_to_list(add_const)

        if len(mul_const_shape) != 1 or mul_const_shape[0] != dim:
            continue
        if len(add_const_shape) != 1 or add_const_shape[0] != dim:
            continue

        # Express the element-wise Mul as a [dim, dim] diagonal weight matrix
        # so the Mul + Add pair becomes a single Gemm.
        b_data = np.diag(mul_const_data).flatten().tolist()
        b_tensor = onnx.helper.make_tensor(
            name=mul_const.name + '_tensor',
            data_type=mul_const.attribute[0].t.data_type,
            dims=[dim, dim],
            vals=b_data)
        b_const_node = onnx.helper.make_node('Constant', [],
                                             [mul_const.output[0]],
                                             value=b_tensor,
                                             name=mul_const.output[0])

        add_const.attribute[0].t.dims.insert(0, 1)

        gemm_node = onnx.helper.make_node(
            'Gemm',
            [mul_node.input[0], b_const_node.output[0], add_const.output[0]],
            [add_node.output[0]],
            name=add_node.output[0])

        g.node.extend([gemm_node, b_const_node])
        node_to_del.extend([mul_const, mul_node, add_node])

        val_info_mid = helper.find_value_by_name(g, mul_node.output[0])
        val_info_mul_const = helper.find_value_by_name(g, mul_const.output[0])
        val_info_add_const = helper.find_value_by_name(g, add_const.output[0])
        if val_info_mid:
            g.value_info.remove(val_info_mid)
        if val_info_mul_const:
            g.value_info.remove(val_info_mul_const)
        if val_info_add_const:
            g.value_info.remove(val_info_add_const)

    while node_to_del:
        g.node.remove(node_to_del.pop())

    topological_sort(g)
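The equivalence this pass relies on, checked with numpy (a sketch, not from the original source): an element-wise Mul by m followed by an Add of a on a [1, dim] tensor equals a Gemm whose weight B is diag(m) and whose bias C is a:

import numpy as np

x = np.random.rand(1, 4)
m = np.random.rand(4)   # Mul constant
a = np.random.rand(4)   # Add constant

y_ref = x * m + a                              # Mul -> Add
y_gemm = x.dot(np.diag(m)) + a.reshape(1, 4)   # Gemm with diagonal B

assert np.allclose(y_ref, y_gemm)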