Example #1
def inference_cov_shape(g):
    processed = False
    for node in g.node:
        if node.op_type != 'Conv':
            continue
        input_value_info = helper.find_value_by_name(g, node.input[0])
        if not input_value_info:
            input_value_info = helper.find_input_by_name(g, node.input[0])
        if not input_value_info:
            continue

        kernel_value_info = helper.find_value_by_name(g, node.input[1])
        output_value_info = helper.find_value_by_name(g, node.output[0])
        if not output_value_info:
            output_value_info = helper.find_output_by_name(g, node.output[0])

        if output_value_info and \
            helper.get_shape_from_value_info(output_value_info):
            continue

        _, kernel_shape = helper.find_size_shape_from_value(kernel_value_info)
        _, input_shape = helper.find_size_shape_from_value(input_value_info)
        if not input_shape or not kernel_shape:
            continue
        strides = helper.get_attribute_by_name(node, 'strides').ints
        pads = helper.get_attribute_by_name(node, 'pads').ints
        dilation = helper.get_attribute_by_name(node, 'dilations').ints

        # PyTorch-exported models may provide only a single number for strides
        # and dilations; duplicate it for both spatial axes.
        if len(strides) == 1:
            strides.append(strides[0])
        if len(dilation) == 1:
            dilation.append(dilation[0])

        H = math.floor((input_shape[2]+pads[0]+pads[2]-\
            dilation[0]*(kernel_shape[2]-1)-1)/strides[0]+1)
        W = math.floor((input_shape[3]+pads[1]+pads[3]-\
            dilation[1]*(kernel_shape[3]-1)-1)/strides[1]+1)
        output_shape = [input_shape[0], kernel_shape[0], H, W]

        new_output_value_info = onnx.helper.make_tensor_value_info(
            node.output[0], input_value_info.type.tensor_type.elem_type,
            output_shape)

        processed = True

        if output_value_info:
            g.value_info.remove(output_value_info)
        g.value_info.extend([new_output_value_info])

    return processed
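
A worked check of the output-size formula used above, with hypothetical shapes (the numbers are illustrative, not taken from a real model): for each spatial axis, H_out = floor((H_in + pad_begin + pad_end - dilation * (kernel - 1) - 1) / stride + 1).

import math

# Assumed example values: a 7x7 stride-2 convolution on a 224x224 input.
input_shape = [1, 3, 224, 224]    # N, C, H, W
kernel_shape = [64, 3, 7, 7]      # M, C, kH, kW
pads = [3, 3, 3, 3]               # H_begin, W_begin, H_end, W_end (ONNX order)
strides = [2, 2]
dilation = [1, 1]

H = math.floor((input_shape[2] + pads[0] + pads[2]
                - dilation[0] * (kernel_shape[2] - 1) - 1) / strides[0] + 1)
W = math.floor((input_shape[3] + pads[1] + pads[3]
                - dilation[1] * (kernel_shape[3] - 1) - 1) / strides[1] + 1)
print([input_shape[0], kernel_shape[0], H, W])    # [1, 64, 112, 112]
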
Example #2
def set_upsample_mode_to_align_corner(g):
    """Set all the upsample nodes mode to align_corner.
    """
    for node in g.node:
        if node.op_type != 'Upsample':
            continue
        # Found an Upsample node. Its string attribute may be stored as bytes.
        attribute = helper.get_attribute_by_name(node, "mode")
        if isinstance(attribute.s, bytes):
            attribute.s = "align_corner".encode('utf-8')
        else:
            attribute.s = "align_corner"
Example #3
def replace_ConstantOfShape_with_constant(g):
    """Replace Shape with Constant.\\
    This is the first step of reshape constant folding.

    :param g: the input graph\\
    :return: if anything modified, return true.
    """
    node_to_remove = []
    for node in g.node:
        # Find a Shape
        if node.op_type != 'ConstantOfShape':
            continue
        # Check the input
        input_value = helper.find_value_by_name(g, node.input[0])
        if input_value is None:
            input_value = helper.find_input_by_name(g, node.input[0])
        if input_value is None or len(
                input_value.type.tensor_type.shape.dim) == 0:
            continue

        # Replace to constant node
        pre_node = helper.find_node_by_output_name(g, node.input[0])
        _, target_shape = helper.constant_to_list(pre_node)

        value = helper.get_attribute_by_name(node, 'value').i

        node_name = node.output[0]
        new_node = helper.list_to_constant(node_name, [target_shape[0]],
                                           [value] * target_shape[0])

        g.node.extend([new_node])

        # remove old node
        node_to_remove.append(node)

        # delete value_info
        val_info_used = sum(
            [input_value.name in node.input for node in g.node])
        if val_info_used == 1:
            g.value_info.remove(input_value)

    replaced = len(node_to_remove) > 0

    for node in node_to_remove:
        g.node.remove(node)

    topological_sort(g)

    return replaced
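
The replacement is valid because ConstantOfShape simply fills a tensor of the given shape with one value, so once its shape input is a known constant the whole output is a constant. A numpy sketch with assumed values (the helper calls above build the equivalent Constant node for the 1-D case):

import numpy as np

target_shape = [4]    # shape read from the constant input (assumed)
value = 0             # fill value read from the 'value' attribute (assumed)

# ConstantOfShape(shape) with this fill value is equivalent to:
folded = np.full(target_shape, value)
print(folded)         # [0 0 0 0]
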
Example #4
def split_ConvTranspose(model):
    """To feed our compiler, split ConvTranspose into Upsample and Conv.

    :param model: the model
    """
    node_to_delete = []
    # Change model properties for upsample.
    if model.ir_version < 3:
        print("Warning: Current model IR version is not fully supported.")
    model.ir_version = 4
    model.opset_import[0].version = 9
    g = model.graph
    # Get a ConvTranspose layer
    for node in g.node:
        # Find a ConvTranspose node
        if node.op_type != 'ConvTranspose':
            continue
        # Check auto_pad
        auto_pad_proto = helper.get_attribute_by_name(node, "auto_pad")
        if auto_pad_proto is not None:
            print("Currently not split auto_pad ConvTranspose")
            continue
        # Check output_shape
        output_shape_proto = helper.get_attribute_by_name(node, "output_shape")
        if output_shape_proto is not None:
            print("Currently not split output_shape ConvTranspose")
            continue
        # Get input shape
        input_value = helper.find_value_by_name(g, node.input[0])
        if input_value is None:
            input_value = helper.find_input_by_name(g, node.input[0])
        if input_value is None:
            print("Cannot get value info named {}.".format(node.input[0]))
            exit(1)
        input_shape = helper.get_shape_from_value_info(input_value)
        # Get attributes
        attr = deconv_to_conv_info_extraction(input_shape, node)
        # Generate Upsample scales
        upsample_output_shape = list(input_shape)
        upsample_output_shape[2] = (input_shape[2] - 1) * attr["strides"][0] + 1
        upsample_output_shape[3] = (input_shape[3] - 1) * attr["strides"][1] + 1
        upsample_node_name = node.name + "_inner_upsample"
        upsample_scale_name = upsample_node_name + "_scales"
        scales_np = np.ones([4]).astype('float32')
        scales_np[2] = float(upsample_output_shape[2]) / input_shape[2]
        scales_np[3] = float(upsample_output_shape[3]) / input_shape[3]
        scales_node = helper.numpy_to_constant(upsample_scale_name, scales_np)
        # Generate an Upsample layer and an internal value info
        upsample_node = onnx.helper.make_node(
            "Upsample",
            [node.input[0], upsample_scale_name],
            [upsample_node_name],
            name=upsample_node_name,
            mode="zeros"
        )
        upsample_value_info = onnx.helper.make_tensor_value_info(
            upsample_node_name,
            input_value.type.tensor_type.elem_type,
            upsample_output_shape
        )
        # Check the weight layer; it may need a transpose
        if attr["group"] != input_shape[1]:
            weight_node = helper.find_node_by_output_name(g, node.input[1])
            weight_np = helper.constant_to_numpy(weight_node)
            new_weight_np = np.transpose(weight_np, [1, 0, 2, 3])
            new_weight_node = helper.numpy_to_constant(node.input[1], new_weight_np)
            node_to_delete.append(weight_node)
            g.node.extend([new_weight_node])
            value = helper.find_value_by_name(g, node.input[1])
            g.value_info.remove(value)
        # Generate a Conv layer
        conv_node_name = node.name + "_inner_conv"
        conv_node_input = [upsample_node_name]
        conv_node_input.extend(node.input[1:])
        conv_node = onnx.helper.make_node(
            "Conv",
            conv_node_input,
            [node.output[0]],
            name=conv_node_name,
            pads=[int(i) for i in attr["conv_pads"]],
            dilations=[int(i) for i in attr["dilations"]],
            group=int(attr["group"]),
            kernel_shape=[int(i) for i in attr["kernel_shape"]],
            strides=[1, 1]
        )
        # Reconnect the graph
        g.node.extend([scales_node, upsample_node, conv_node])
        g.value_info.extend([upsample_value_info])
        node_to_delete.append(node)
    # Delete useless nodes
    for node in node_to_delete:
        g.node.remove(node)
    topological_sort(g)
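
A worked example of the Upsample geometry used above, with assumed shapes: for stride s, each spatial dimension grows from H to (H - 1) * s + 1, which reproduces the spacing a transposed convolution inserts before the equivalent Conv is applied.

import numpy as np

input_shape = [1, 16, 14, 14]    # N, C, H, W (assumed example values)
strides = [2, 2]                 # ConvTranspose strides (assumed)

upsample_output_shape = list(input_shape)
upsample_output_shape[2] = (input_shape[2] - 1) * strides[0] + 1
upsample_output_shape[3] = (input_shape[3] - 1) * strides[1] + 1

scales = np.ones([4], dtype='float32')
scales[2] = upsample_output_shape[2] / input_shape[2]
scales[3] = upsample_output_shape[3] / input_shape[3]
print(upsample_output_shape)     # [1, 16, 27, 27]
print(scales)                    # H and W scales are 27/14, roughly 1.93
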
Example #5
def inference_cov_shape(g):
    processed = False
    for node in g.node:
        # Check whether this Conv's output shape needs to be inferred.
        if node.op_type != 'Conv':
            continue
        # Input shape is not ready yet. Skip.
        input_value_info = helper.find_value_by_name(g, node.input[0])
        if not input_value_info:
            input_value_info = helper.find_input_by_name(g, node.input[0])
        if not input_value_info:
            continue
        _, input_shape = helper.find_size_shape_from_value(input_value_info)
        if not input_shape:
            continue
        # Output shape is already there. Skip.
        output_value_info = helper.find_value_by_name(g, node.output[0])
        if not output_value_info:
            output_value_info = helper.find_output_by_name(g, node.output[0])
        if output_value_info and \
            helper.get_shape_from_value_info(output_value_info):
            continue

        # Now start the inference.
        # If auto_pad is set, use the auto_pad.
        auto_pad = helper.get_var_attribute_by_name(node, 'auto_pad', 'string')
        pads = None
        if auto_pad is not None and auto_pad != 'NOTSET':
            if auto_pad == 'SAME_LOWER' or auto_pad == 'SAME_UPPER':
                new_output_value_info = onnx.helper.make_tensor_value_info(
                    node.output[0],
                    input_value_info.type.tensor_type.elem_type,
                    input_shape
                )
                if output_value_info:
                    g.value_info.remove(output_value_info)
                g.value_info.extend([new_output_value_info])
                processed = True
                continue
            elif auto_pad == 'VALID':
                pads = [0, 0, 0, 0]
            else:
                print("Unrecognized auto_pad value: " + str(auto_pad))
                exit(1)
        kernel_value_info = helper.find_value_by_name(g, node.input[1])
        _, kernel_shape = helper.find_size_shape_from_value(kernel_value_info)
        if not input_shape or not kernel_shape:
            continue
        strides = helper.get_attribute_by_name(node, 'strides').ints
        if not pads:
            pads = helper.get_attribute_by_name(node, 'pads').ints
        dilation = helper.get_attribute_by_name(node, 'dilations').ints

        # PyTorch-exported models may provide only a single number for strides
        # and dilations; duplicate it for both spatial axes.
        if len(strides) == 1:
            strides.append(strides[0])
        if len(dilation) == 1:
            dilation.append(dilation[0])

        H = math.floor((input_shape[2]+pads[0]+pads[2]-\
            dilation[0]*(kernel_shape[2]-1)-1)/strides[0]+1)
        W = math.floor((input_shape[3]+pads[1]+pads[3]-\
            dilation[1]*(kernel_shape[3]-1)-1)/strides[1]+1)
        output_shape = [input_shape[0], kernel_shape[0], H, W]

        new_output_value_info = onnx.helper.make_tensor_value_info(
            node.output[0],
            input_value_info.type.tensor_type.elem_type,
            output_shape
        )

        processed = True

        if output_value_info:
            g.value_info.remove(output_value_info)
        g.value_info.extend([new_output_value_info])

    return processed
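
As a sanity check on the SAME_LOWER / SAME_UPPER shortcut above, here is the same output-size formula with stride 1 and total padding equal to dilation * (kernel - 1), the setting where SAME padding keeps the spatial size unchanged (assumed example numbers):

import math

H_in, kernel, dilation, stride = 56, 3, 1, 1   # assumed values
pad_total = dilation * (kernel - 1)            # SAME padding for stride 1
H_out = math.floor((H_in + pad_total - dilation * (kernel - 1) - 1) / stride + 1)
print(H_out == H_in)                           # True
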
Example #6
def fuse_BN_into_Gemm(g):
    """Fuse the following BN into the previous Gemm.

    :param g: the graph
    """
    node_to_remove = []
    for node in g.node:
        # Check for BN and Gemm
        if node.op_type != 'BatchNormalization':
            continue
        gemm_node = helper.find_node_by_output_name(g, node.input[0])
        if gemm_node is None:
            continue
        if gemm_node.op_type != 'Gemm':
            continue
        if len(
                helper.find_following_nodes_by_input_value_name(
                    g, gemm_node.output[0])) > 1:
            continue
        bn_node = node
        # Get original weights
        gemm_b_node = helper.find_node_by_output_name(g, gemm_node.input[1])
        gemm_b = helper.constant_to_numpy(gemm_b_node)
        gemm_c_node = helper.find_node_by_output_name(g, gemm_node.input[2])
        gemm_c = helper.constant_to_numpy(gemm_c_node)
        bn_scale_node = helper.find_node_by_output_name(g, bn_node.input[1])
        bn_scale = helper.constant_to_numpy(bn_scale_node)
        bn_bias_node = helper.find_node_by_output_name(g, bn_node.input[2])
        bn_bias = helper.constant_to_numpy(bn_bias_node)
        bn_mean_node = helper.find_node_by_output_name(g, bn_node.input[3])
        bn_mean = helper.constant_to_numpy(bn_mean_node)
        bn_var_node = helper.find_node_by_output_name(g, bn_node.input[4])
        bn_var = helper.constant_to_numpy(bn_var_node)
        # Apply attributes
        # epsilon
        epsilon = helper.get_attribute_by_name(bn_node, 'epsilon')
        if epsilon is None:
            epsilon = 0.00001
        else:
            epsilon = epsilon.f
        bn_var = bn_var + epsilon
        # alpha
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is None:
            alpha = 1
        else:
            alpha = alpha.f
        gemm_b = gemm_b * alpha
        # beta
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is None:
            beta = 1
        else:
            beta = beta.f
        gemm_c = gemm_c * beta
        # transA
        transA = helper.get_attribute_by_name(gemm_node, 'transA')
        if transA is not None and transA.i == 1:
            raise RuntimeError("Do not support transA")
        # transB
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None and transB.i == 1:
            gemm_b = gemm_b.transpose()
        # Calculate new weights
        new_gemm_b = gemm_b * bn_scale / np.sqrt(bn_var)
        new_gemm_c = (gemm_c - bn_mean) * bn_scale / np.sqrt(bn_var) + bn_bias
        # Replace original weights
        new_gemm_b_node = helper.numpy_to_constant(gemm_b_node.name + '_fused',
                                                   new_gemm_b)
        new_gemm_c_node = helper.numpy_to_constant(gemm_c_node.name + '_fused',
                                                   new_gemm_c)
        g.node.extend([new_gemm_b_node, new_gemm_c_node])
        node_to_remove.extend([
            gemm_b_node, gemm_c_node, bn_node, bn_scale_node, bn_bias_node,
            bn_mean_node, bn_var_node
        ])
        # Modify attributes
        # alpha
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is not None:
            alpha.f = 1.0
        # beta
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is not None:
            beta.f = 1.0
        # transB
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None:
            transB.i = 0
        # Connect the new graph
        gemm_node.input[1] = new_gemm_b_node.output[0]
        gemm_node.input[2] = new_gemm_c_node.output[0]
        gemm_b_value = helper.find_value_by_name(g, gemm_b_node.output[0])
        gemm_c_value = helper.find_value_by_name(g, gemm_c_node.output[0])
        gemm_b_value.name = new_gemm_b_node.output[0]
        gemm_c_value.name = new_gemm_c_node.output[0]
        gemm_value = helper.find_value_by_name(g, gemm_node.output[0])
        g.value_info.remove(gemm_value)
        gemm_node.output[0] = bn_node.output[0]
        for i in range(1, 5):
            value = helper.find_value_by_name(g, bn_node.input[i])
            g.value_info.remove(value)
    # Remove useless nodes
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)
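
The fusion above is just the algebra scale * (x·B + c - mean) / sqrt(var + eps) + bias = x·(B * scale / sqrt(var + eps)) + ((c - mean) * scale / sqrt(var + eps) + bias). A quick numerical check with random values, independent of the helper module:

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4))
gemm_b = rng.standard_normal((4, 3))
gemm_c = rng.standard_normal(3)
bn_scale, bn_bias = rng.standard_normal(3), rng.standard_normal(3)
bn_mean, bn_var = rng.standard_normal(3), rng.random(3) + 0.5
epsilon = 1e-5

# Original: Gemm followed by BatchNormalization.
gemm_out = x.dot(gemm_b) + gemm_c
bn_out = bn_scale * (gemm_out - bn_mean) / np.sqrt(bn_var + epsilon) + bn_bias

# Fused: a single Gemm with the folded weights computed above.
new_gemm_b = gemm_b * bn_scale / np.sqrt(bn_var + epsilon)
new_gemm_c = (gemm_c - bn_mean) * bn_scale / np.sqrt(bn_var + epsilon) + bn_bias
fused_out = x.dot(new_gemm_b) + new_gemm_c

print(np.allclose(bn_out, fused_out))    # True
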
Example #7
def fuse_Gemm_into_Gemm(g):
    """Fuse the previous Gemm into the following Gemm.

    :param g: the graph
    """
    node_to_remove = []
    for node in g.node:
        # Check for Gemm and Gemm
        if node.op_type != 'Gemm':
            continue
        prev_node = helper.find_node_by_output_name(g, node.input[0])
        if prev_node is None:
            continue
        if prev_node.op_type != 'Gemm':
            continue
        # Get original weights
        prev_b_node = helper.find_node_by_output_name(g, prev_node.input[1])
        prev_b = helper.constant_to_numpy(prev_b_node)
        prev_c_node = helper.find_node_by_output_name(g, prev_node.input[2])
        prev_c = helper.constant_to_numpy(prev_c_node)
        b_node = helper.find_node_by_output_name(g, node.input[1])
        b = helper.constant_to_numpy(b_node)
        c_node = helper.find_node_by_output_name(g, node.input[2])
        c = helper.constant_to_numpy(c_node)
        # Apply attributes
        # alpha
        alpha = helper.get_attribute_by_name(node, 'alpha')
        if alpha is None:
            alpha = 1
        else:
            alpha = alpha.f
        b = b * alpha
        alpha = helper.get_attribute_by_name(prev_node, 'alpha')
        if alpha is None:
            alpha = 1
        else:
            alpha = alpha.f
        prev_b = prev_b * alpha
        # beta
        beta = helper.get_attribute_by_name(node, 'beta')
        if beta is None:
            beta = 1
        else:
            beta = beta.f
        c = c * beta
        beta = helper.get_attribute_by_name(prev_node, 'beta')
        if beta is None:
            beta = 1
        else:
            beta = beta.f
        prev_c = prev_c * beta
        # transA
        transA = helper.get_attribute_by_name(node, 'transA')
        if transA is not None and transA.i == 1:
            raise RuntimeError("Do not support transA")
        transA = helper.get_attribute_by_name(prev_node, 'transA')
        if transA is not None and transA.i == 1:
            raise RuntimeError("Do not support transA")
        # transB
        transB = helper.get_attribute_by_name(node, 'transB')
        if transB is not None and transB.i == 1:
            b = b.transpose()
        transB = helper.get_attribute_by_name(prev_node, 'transB')
        if transB is not None and transB.i == 1:
            prev_b = prev_b.transpose()
        # Calculate new weights
        new_b = prev_b.dot(b)
        new_c = prev_c.dot(b) + c
        # Replace original weights
        new_b_node = helper.numpy_to_constant(b_node.name + '_fused', new_b)
        new_c_node = helper.numpy_to_constant(c_node.name + '_fused', new_c)
        g.node.extend([new_b_node, new_c_node])
        node_to_remove.extend(
            [b_node, c_node, prev_b_node, prev_c_node, prev_node])
        # Modify attributes
        # alpha
        alpha = helper.get_attribute_by_name(node, 'alpha')
        if alpha is not None:
            alpha.f = 1.0
        # beta
        beta = helper.get_attribute_by_name(node, 'beta')
        if beta is not None:
            beta.f = 1.0
        # transB
        transB = helper.get_attribute_by_name(node, 'transB')
        if transB is not None:
            transB.i = 0
        # Connect the new graph
        node.input[0] = prev_node.input[0]
        prev_value = helper.find_value_by_name(g, prev_node.output[0])
        g.value_info.remove(prev_value)
        for i in range(1, 3):
            value = helper.find_value_by_name(g, prev_node.input[i])
            g.value_info.remove(value)
            value = helper.find_value_by_name(g, node.input[i])
            g.value_info.remove(value)
        node.input[1] = new_b_node.output[0]
        node.input[2] = new_c_node.output[0]
    # Remove useless nodes
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)
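
This fusion relies on (x·B_prev + c_prev)·B + c = x·(B_prev·B) + (c_prev·B + c). A short numerical check with random matrices (sketch only):

import numpy as np

rng = np.random.default_rng(1)
x = rng.standard_normal((2, 5))
prev_b, prev_c = rng.standard_normal((5, 4)), rng.standard_normal(4)
b, c = rng.standard_normal((4, 3)), rng.standard_normal(3)

two_gemms = (x.dot(prev_b) + prev_c).dot(b) + c

# Fused weights as computed in fuse_Gemm_into_Gemm.
new_b = prev_b.dot(b)
new_c = prev_c.dot(b) + c
one_gemm = x.dot(new_b) + new_c

print(np.allclose(two_gemms, one_gemm))    # True
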
Example #8
def fuse_BN_with_Reshape_into_Gemm(g):
    """Fuse the following BN into the previous Gemm, even with Reshape or \\
        Squeeze and Unsqueeze surrounding.

    :param g: the graph
    """
    node_to_remove = []
    for node in g.node:
        # Check for BN and Gemm pattern: Gemm A BN B
        # Find BatchNorm Node
        if node.op_type != 'BatchNormalization':
            continue
        bn_node = node
        # Find A Node
        a_node = helper.find_node_by_output_name(g, node.input[0])
        if a_node is None or len(a_node.input) == 0:
            continue
        # Find Gemm Node
        gemm_node = helper.find_node_by_output_name(g, a_node.input[0])
        if gemm_node is None or gemm_node.op_type != 'Gemm':
            continue
        # Find B Node
        b_node_list = helper.find_following_nodes_by_input_value_name(
            g, bn_node.output[0])
        if len(b_node_list) == 0:
            the_output = helper.find_output_by_name(g, bn_node.output[0])
            if the_output is None:
                continue
            b_node = None
        elif len(b_node_list) > 1:
            continue
        else:
            b_node = b_node_list[0]
        # Check for branches
        if len(
                helper.find_following_nodes_by_input_value_name(
                    g, gemm_node.output[0])) > 1:
            continue
        if len(
                helper.find_following_nodes_by_input_value_name(
                    g, a_node.output[0])) > 1:
            continue
        # Check type of A
        if a_node.op_type == 'Unsqueeze':
            axes = helper.get_attribute_by_name(a_node, 'axes')
            if axes.ints != [2]:
                continue
        elif a_node.op_type == 'Reshape':
            a = helper.constant_to_list(
                helper.find_node_by_output_name(g, a_node.input[1]))[1]
            if len(a) != 3 or a[2] != 1:
                continue
        else:
            continue
        # Check type of B
        if b_node is None:
            pass
        elif b_node.op_type == 'Flatten':
            pass
        elif b_node.op_type == 'Squeeze':
            axes = helper.get_attribute_by_name(b_node, 'axes')
            if axes.ints != [2]:
                continue
        elif b_node.op_type == 'Reshape':
            a = helper.constant_to_list(
                helper.find_node_by_output_name(g, b_node.input[1]))[1]
            if len(a) != 2:
                continue
        else:
            continue
        # Construct new Nodes
        # Get original weights
        gemm_b_node = helper.find_node_by_output_name(g, gemm_node.input[1])
        gemm_b = helper.constant_to_numpy(gemm_b_node)
        gemm_c_node = helper.find_node_by_output_name(g, gemm_node.input[2])
        gemm_c = helper.constant_to_numpy(gemm_c_node)
        bn_scale_node = helper.find_node_by_output_name(g, bn_node.input[1])
        bn_scale = helper.constant_to_numpy(bn_scale_node)
        bn_bias_node = helper.find_node_by_output_name(g, bn_node.input[2])
        bn_bias = helper.constant_to_numpy(bn_bias_node)
        bn_mean_node = helper.find_node_by_output_name(g, bn_node.input[3])
        bn_mean = helper.constant_to_numpy(bn_mean_node)
        bn_var_node = helper.find_node_by_output_name(g, bn_node.input[4])
        bn_var = helper.constant_to_numpy(bn_var_node)
        # Apply attributes
        # epsilon
        epsilon = helper.get_attribute_by_name(bn_node, 'epsilon')
        if epsilon is None:
            epsilon = 0.00001
        else:
            epsilon = epsilon.f
        bn_var = bn_var + epsilon
        # alpha
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is None:
            alpha = 1
        else:
            alpha = alpha.f
        gemm_b = gemm_b * alpha
        # beta
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is None:
            beta = 1
        else:
            beta = beta.f
        gemm_c = gemm_c * beta
        # transA
        transA = helper.get_attribute_by_name(gemm_node, 'transA')
        if transA is not None and transA.i == 1:
            raise RuntimeError("Do not support transA")
        # transB
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None and transB.i == 1:
            gemm_b = gemm_b.transpose()
        # Calculate new weights
        new_gemm_b = gemm_b * bn_scale / np.sqrt(bn_var)
        new_gemm_c = (gemm_c - bn_mean) * bn_scale / np.sqrt(bn_var) + bn_bias
        # Replace original weights
        new_gemm_b_node = helper.numpy_to_constant(gemm_b_node.name + '_fused',
                                                   new_gemm_b)
        new_gemm_c_node = helper.numpy_to_constant(gemm_c_node.name + '_fused',
                                                   new_gemm_c)
        g.node.extend([new_gemm_b_node, new_gemm_c_node])
        # Modify attributes
        # alpha
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is not None:
            alpha.f = 1.0
        # beta
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is not None:
            beta.f = 1.0
        # transB
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None:
            transB.i = 0
        # Remove useless nodes
        node_to_remove.extend([
            gemm_b_node, gemm_c_node, bn_node, bn_scale_node, bn_bias_node,
            bn_mean_node, bn_var_node, a_node
        ])
        if a_node.op_type == 'Reshape':
            node_to_remove.append(
                helper.find_node_by_output_name(g, a_node.input[1]))
        if b_node is not None:
            node_to_remove.append(b_node)
            if b_node.op_type == 'Reshape':
                node_to_remove.append(
                    helper.find_node_by_output_name(g, b_node.input[1]))
        # Delete useless value infos
        value = helper.find_value_by_name(g, a_node.output[0])
        g.value_info.remove(value)
        if a_node.op_type == 'Reshape':
            value = helper.find_value_by_name(g, a_node.input[1])
            g.value_info.remove(value)
        for i in range(1, 5):
            value = helper.find_value_by_name(g, bn_node.input[i])
            g.value_info.remove(value)
        value = helper.find_value_by_name(g, bn_node.output[0])
        if value is not None:
            g.value_info.remove(value)
        if b_node is not None:
            value = helper.find_value_by_name(g, gemm_node.output[0])
            g.value_info.remove(value)
            if b_node.op_type == 'Reshape':
                value = helper.find_value_by_name(g, b_node.input[1])
                g.value_info.remove(value)
        # Connect the new graph
        # Connect Gemm new weights
        gemm_node.input[1] = new_gemm_b_node.output[0]
        gemm_node.input[2] = new_gemm_c_node.output[0]
        gemm_b_value = helper.find_value_by_name(g, gemm_b_node.output[0])
        gemm_c_value = helper.find_value_by_name(g, gemm_c_node.output[0])
        gemm_b_value.name = new_gemm_b_node.output[0]
        gemm_b_value.type.tensor_type.shape.dim[
            0].dim_value = new_gemm_b.shape[0]
        gemm_b_value.type.tensor_type.shape.dim[
            1].dim_value = new_gemm_b.shape[1]
        gemm_c_value.name = new_gemm_c_node.output[0]
        if b_node is None:
            # If b node is None, set the Gemm output as the graph output
            output_value = helper.find_output_by_name(g, bn_node.output[0])
            g.output.remove(output_value)
            g.output.extend(
                [helper.find_value_by_name(g, gemm_node.output[0])])
        else:
            # Else, set node B output as gemm output
            gemm_node.output[0] = b_node.output[0]
    # Remove useless nodes
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)
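
The Unsqueeze/Reshape node before the BN and the Squeeze/Flatten/Reshape node after it only add and remove a trailing size-1 spatial axis, so after fusion the Gemm output can feed the rest of the graph directly. A numpy shape sketch of that pattern with assumed sizes:

import numpy as np

gemm_out = np.zeros((8, 256))              # Gemm output (N, C), assumed sizes

unsqueezed = np.expand_dims(gemm_out, 2)   # Unsqueeze axes=[2] -> (8, 256, 1)
# BatchNormalization over axis 1 leaves the shape (8, 256, 1) unchanged.
squeezed = np.squeeze(unsqueezed, 2)       # Squeeze axes=[2] -> back to (8, 256)

print(unsqueezed.shape, squeezed.shape)    # (8, 256, 1) (8, 256)
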