コード例 #1
0
ファイル: other.py プロジェクト: thomaswang525/ONNX_Convertor
def add_bn_on_skip_branch(g):
    """Insert a no-op BatchNormalization on the skip branch of Add nodes.

    For every ``Add`` node with exactly two inputs, the first input value that
    meets two conditions is treated as the split point of a skip connection:

    - it is produced by a node, and
    - it is consumed by exactly two nodes (the Add plus one other).

    A ``BatchNormalization`` is inserted between that value and the Add. Its
    parameters (scale=1, bias=0, mean=0, var=1) come from four new Constant
    nodes, so it is (near-)identity; the BN epsilon attribute is left at its
    default. The channel count is read from dimension 1 of the value's
    value_info.

    Add nodes with no such input, or whose chosen value lacks usable shape
    information, are left unchanged. The graph is topologically sorted at the
    end.

    :param g: the onnx graph, modified in place
    """
    # Iterate over a snapshot: new nodes are appended to g.node inside the loop.
    for n in list(g.node):
        # Find merge node (Add)
        if n.op_type != 'Add' or len(n.input) != 2:
            continue
        # TODO: Still need to consider more cases
        # Check if a skip branch exists. Use the Add's own input value names:
        # - The producer may feed the Add from an output other than output[0].
        # - Inputs without a producing node (graph inputs, initializers) are
        #   skipped instead of crashing.
        value_name = None
        for input_name in n.input:
            if helper.find_node_by_output_name(g, input_name) is None:
                continue
            if len(helper.find_nodes_by_input_name(g, input_name)) == 2:
                value_name = input_name
                break
        if value_name is None:
            continue
        node_name = value_name + "_nop_bn"
        # One split value may feed several Adds; never emit a second BN (and
        # its Constants) under the same names.
        if helper.find_node_by_output_name(g, node_name) is not None:
            continue
        # Get the channel number from value info
        value = helper.find_value_by_name(g, value_name)
        if value is None:
            continue
        shape = helper.get_shape_from_value_info(value)
        if shape is None or len(shape) < 2:
            # No channel dimension available to size the BN parameters.
            continue
        channel = shape[1]
        # Construct 4 weights
        ones = [1.0] * channel
        zeros = [0.0] * channel
        scale_node = helper.list_to_constant(node_name + "_scale", [channel],
                                             ones)
        bias_node = helper.list_to_constant(node_name + "_bias", [channel],
                                            zeros)
        mean_node = helper.list_to_constant(node_name + "_mean", [channel],
                                            zeros)
        var_node = helper.list_to_constant(node_name + "_var", [channel], ones)
        # Construct BN node
        bn_node = onnx.helper.make_node(
            "BatchNormalization",
            [value_name, scale_node.output[0], bias_node.output[0],
             mean_node.output[0], var_node.output[0]],
            [node_name],
            name=node_name)
        # Reconnect the graph: only this Add now reads the BN output.
        replace_node_input(n, value_name, node_name)
        # Add node to the graph
        g.node.extend([bn_node, scale_node, bias_node, mean_node, var_node])
    topological_sort(g)
コード例 #2
0
ファイル: replacing.py プロジェクト: kneron/ONNX_Convertor
def replace_initializer_with_Constant(g):
    """
    Turn every initializer into Constant node(s) and drop the initializers.

    For each initializer:

    - its graph-input entry, if any, is removed;
    - one Constant node is created per consuming node (the first consumer
      keeps the initializer's name; later consumers are rewired to
      ``<name>_duplicated_No<i>``);
    - any value_info carrying the initializer's name is removed.

    Finally all initializers are cleared and the graph is topologically
    sorted.

    :param g: the onnx graph, modified in place
    """
    inputs_by_name = {graph_input.name: graph_input for graph_input in g.input}
    for tensor in g.initializer:
        name = tensor.name
        # The initializer no longer needs a graph-input entry.
        matching_input = inputs_by_name.get(name)
        if matching_input is not None:
            g.input.remove(matching_input)
        # One Constant per consumer; only the first keeps the original name.
        constants = []
        for idx, consumer in enumerate(helper.find_nodes_by_input_name(g, name)):
            if idx == 0:
                const_name = name
            else:
                const_name = name + "_duplicated_No" + str(idx)
            modhelper.replace_node_input(consumer, name, const_name)
            constants.append(
                onnx.helper.make_node("Constant", [], [const_name],
                                      name=const_name, value=tensor))
        g.node.extend(constants)
        # A value_info under the initializer's name is dropped as well.
        stale = helper.find_value_by_name(g, name)
        if stale is not None:
            g.value_info.remove(stale)
    # Clear all original initializers in one go.
    del g.initializer[:]
    topological_sort(g)
コード例 #3
0
ファイル: other.py プロジェクト: thomaswang525/ONNX_Convertor
def find_first_sequential_outputs(g, node):
    """Follow the first-consumer chain from ``node`` to the first graph output.

    Return the first output of ``node`` that is also a graph output.
    Otherwise move to the first node consuming ``node.output[0]`` and repeat.

    The walk is iterative, so long chains cannot hit the recursion limit.

    :param g: the onnx graph
    :param node: the node to start from
    :returns: the graph output value info reached first. Returns None when
        the chain ends before any graph output is found, i.e. at a node with
        no outputs or whose first output has no consumer.
    """
    current = node
    while current is not None:
        for value_name in current.output:
            value = helper.find_output_by_name(g, value_name)
            if value is not None:
                return value
        if len(current.output) == 0:
            return None
        consumers = helper.find_nodes_by_input_name(g, current.output[0])
        # Dead end: nothing reads this value, so no graph output is reachable.
        current = consumers[0] if consumers else None
    return None
コード例 #4
0
def replace_initializer_with_Constant(g):
    """
    Replace initializers with Constant nodes and a corresponding value_info.

    Initializers that no node consumes are dropped, together with their
    graph-input entry if one exists.

    Every other initializer becomes a Constant node:

    - If an existing node already has the initializer's name, the Constant
      and its output are named ``<name>_2`` and the consumers are rewired.
    - Any stale value_info named after the initializer is removed.
    - The initializer's graph-input value_info, when present, is moved into
      ``g.value_info`` under the Constant's output name.

    All initializers are removed at the end.

    :param g: the onnx graph, modified in place
    """
    # Create a set of existing node names
    node_names = {node.name for node in g.node}
    # Unused initializers should be removed
    used_values = {in_value for node in g.node for in_value in node.input}
    unused_initializer = {tensor.name for tensor in g.initializer
                          if tensor.name not in used_values}

    input_map = {i.name: i for i in g.input}
    for tensor in g.initializer:
        # An initializer is not required to appear in g.input (ONNX IR >= 4),
        # so look the entry up with .get() instead of indexing.
        input_value = input_map.get(tensor.name)
        if tensor.name in unused_initializer:
            if input_value is not None:
                g.input.remove(input_value)
            continue
        # Convert init to a constant node
        if tensor.name not in node_names:
            new_name = tensor.name
        else:
            new_name = tensor.name + '_2'
            following_nodes = helper.find_nodes_by_input_name(g, tensor.name)
            for node in following_nodes:
                modhelper.replace_node_input(node, tensor.name, new_name)
        new_node = onnx.helper.make_node(
            "Constant",
            [],
            [new_name],
            name=new_name,
            value=tensor
        )
        # Add node to lists
        g.node.extend([new_node])
        # Remove any stale value info BEFORE adding the new one. Doing it
        # afterwards would find and delete the entry that was just added.
        stale_value_info = helper.find_value_by_name(g, tensor.name)
        if stale_value_info is not None:
            g.value_info.remove(stale_value_info)
        if input_value is not None:
            # Move the input value info into value_info (extend stores a
            # copy). Name it after the Constant's actual output.
            g.value_info.extend([input_value])
            g.value_info[-1].name = new_name
            # Remove original input value info
            g.input.remove(input_value)
    # Remove original initializer
    del g.initializer[:]
コード例 #5
0
ファイル: other.py プロジェクト: raefYoussef/ONNX_Convertor
def rename_output_name(g, original_name, new_name):
    """Rename a graph output value and every reference to it.

    Renames, when present:

    - the graph output entry,
    - its value_info,
    - the matching output slot of the producing node,
    - the matching input of every consuming node.

    Logs an error and returns without changes if no graph output named
    ``original_name`` exists.

    :param g: the onnx graph, modified in place
    :param original_name: current name of the graph output value
    :param new_name: name to give the value
    """
    # Output
    output_value = helper.find_output_by_name(g, original_name)
    if output_value is None:
        logging.error("Cannot find output value named " + original_name)
        return
    output_value.name = new_name
    # Value Info
    value_info = helper.find_value_by_name(g, original_name)
    if value_info is not None:
        value_info.name = new_name
    # Node output: rename the slot that actually carries the value, which is
    # not necessarily output[0]. The producer may be absent, e.g. when the
    # output has no producing node.
    producer = helper.find_node_by_output_name(g, original_name)
    if producer is not None:
        for i, out_name in enumerate(producer.output):
            if out_name == original_name:
                producer.output[i] = new_name
    # Node input
    for consumer in helper.find_nodes_by_input_name(g, original_name):
        replace_node_input(consumer, original_name, new_name)
コード例 #6
0
def pattern_matmul_mul_add(g, matmul_node):
    """Fuse ``MatMul -> Mul(ones) -> Add(const)`` into a single Gemm node.

    The pattern matches when:

    - the MatMul output feeds exactly one Mul,
    - the Mul output feeds exactly one Add,
    - the Mul's second input is a Constant whose values are all 1,
    - the Add's second input is a Constant,
    - the MatMul's second input is a 2-D Constant.

    If the MatMul weight has a single column, it is tiled to ``channel``
    columns, where ``channel`` is the first entry of the Mul weight's shape
    as returned by ``helper.constant_to_list``.

    On a match:

    - the Mul is dropped, along with its weight Constant unless something
      else still reads it;
    - MatMul + Add are replaced by ``Gemm(A, B, C)`` with alpha = beta = 1;
    - the graph is topologically sorted.

    If the pattern does not match, the graph is left untouched.

    NOTE(review): Gemm only accepts a 2-D first input. This assumes the
    MatMul's first input is 2-D, which is not checked here.

    :param g: the onnx graph, modified in place
    :param matmul_node: the MatMul node at the head of the pattern
    """
    def remove_value_info(value_name):
        # Drop the value_info of value_name, if there is one.
        value = helper.find_value_by_name(g, value_name)
        if value is not None:
            g.value_info.remove(value)

    # Check node match - Mul node
    next_nodes = helper.find_nodes_by_input_name(g, matmul_node.output[0])
    if len(next_nodes) != 1 or next_nodes[0].op_type != 'Mul':
        return
    mul_node = next_nodes[0]
    # Check node match - Add node
    next_nodes = helper.find_nodes_by_input_name(g, mul_node.output[0])
    if len(next_nodes) != 1 or next_nodes[0].op_type != 'Add':
        return
    add_node = next_nodes[0]
    # Check Mul weight: an all-ones Constant makes the Mul a no-op
    mul_weight_node = helper.find_node_by_output_name(g, mul_node.input[1])
    if mul_weight_node is None or mul_weight_node.op_type != 'Constant':
        return
    weight_size, mul_weight = helper.constant_to_list(mul_weight_node)
    if any(i != 1 for i in mul_weight):
        return
    # Check Add weight
    add_weight_node = helper.find_node_by_output_name(g, add_node.input[1])
    if add_weight_node is None or add_weight_node.op_type != 'Constant':
        return
    # Check MatMul weight to see if it needs weight broadcast
    matmul_weight_node = helper.find_node_by_output_name(g, matmul_node.input[1])
    if matmul_weight_node is None or matmul_weight_node.op_type != 'Constant':
        return
    matmul_weight = helper.constant_to_numpy(matmul_weight_node)
    if matmul_weight.ndim != 2:
        # Gemm needs a 2-D B input.
        return
    matmul_weight_name = matmul_weight_node.output[0]
    need_broadcast = matmul_weight.shape[1] == 1
    if need_broadcast:
        if not weight_size:
            # Scalar Mul weight: no channel count to broadcast to.
            return
        if len(helper.find_nodes_by_input_name(g, matmul_weight_name)) != 1:
            # Tiling a shared weight would change its other consumers.
            return

    # The pattern matches; the graph is only modified from here on.
    if need_broadcast:
        # Weight broadcast: np.tile with an int repeats along the last axis,
        # turning (K, 1) into (K, channel).
        channel = weight_size[0]
        new_matmul_weight = np.tile(matmul_weight, channel)
        # Name the new Constant after the value MatMul actually reads, so the
        # Gemm input below stays connected.
        # NOTE(review): assumes numpy_to_constant uses `name` as the output
        # name too — TODO confirm.
        new_matmul_weight_node = helper.numpy_to_constant(
            matmul_weight_name, new_matmul_weight)
        g.node.remove(matmul_weight_node)
        g.node.extend([new_matmul_weight_node])
    remove_value_info(matmul_weight_name)
    # Remove Mul node, and its weight unless another node still reads it
    mul_weight_name = mul_weight_node.output[0]
    if len(helper.find_nodes_by_input_name(g, mul_weight_name)) == 1:
        g.node.remove(mul_weight_node)
        remove_value_info(mul_weight_name)
    mul_output_name = mul_node.output[0]
    g.node.remove(mul_node)
    remove_value_info(mul_output_name)
    # Fuse Matmul and Add
    gemm_node = onnx.helper.make_node(
        'Gemm',
        [matmul_node.input[0], matmul_node.input[1], add_node.input[1]],
        [add_node.output[0]],
        name=matmul_node.name,
        alpha=1.0,
        beta=1.0,
        transA=0,
        transB=0
    )
    g.node.extend([gemm_node])
    # Clean up
    matmul_output_name = matmul_node.output[0]
    g.node.remove(matmul_node)
    g.node.remove(add_node)
    remove_value_info(matmul_output_name)
    other.topological_sort(g)