Code example #1
File: other.py  Project: thomaswang525/ONNX_Convertor
def add_nop_bn_after(g, value_names):
    """Add do-nothing BatchNormalization nodes after the given value info. It will\\
    take the given names as the inputs of the new node and replace the inputs\\
    of the following nodes.

    :param g: the graph\\
    :param value_names: a list of strings which are the names of value_info.
    """
    for value_name in value_names:
        # Find the value first
        value = helper.find_value_by_name(g, value_name)
        if value is None:
            value = helper.find_input_by_name(g, value_name)
        if value is None:
            value = helper.find_output_by_name(g, value_name)
        if value is None:
            print("Cannot find an value_info named {}".format(value_name))
            continue
        # Get the channel number from value info
        shape = helper.get_shape_from_value_info(value)
        channel = shape[1]
        # Construct 4 weights
        node_name = value_name + "_nop_bn"
        ones = [1.0] * channel
        zeros = [0.0] * channel
        scale_node = helper.list_to_constant(node_name + "_scale", [channel],
                                             ones)
        bias_node = helper.list_to_constant(node_name + "_bias", [channel],
                                            zeros)
        mean_node = helper.list_to_constant(node_name + "_mean", [channel],
                                            zeros)
        var_node = helper.list_to_constant(node_name + "_var", [channel], ones)
        # Construct BN node
        bn_node = onnx.helper.make_node("BatchNormalization", [
            value_name, scale_node.output[0], bias_node.output[0],
            mean_node.output[0], var_node.output[0]
        ], [node_name],
                                        name=node_name)
        # Reconnect the graph
        following_nodes = helper.find_following_nodes_by_input_value_name(
            g, value_name)
        if len(following_nodes) > 0:
            for following_node in following_nodes:
                replace_node_input(following_node, value_name, node_name)
        else:
            # If the value is a graph output, replace that output entry with the new node's output.
            new_value = onnx.helper.make_tensor_value_info(
                node_name, value.type.tensor_type.elem_type, shape)
            output_values = []
            while len(g.output):
                output_values.append(g.output.pop())
            while output_values:
                output_value = output_values.pop()
                if output_value.name == value_name:
                    g.output.extend([new_value])
                else:
                    g.output.extend([output_value])
        # Add node to the graph
        g.node.extend([bn_node, scale_node, bias_node, mean_node, var_node])
    topological_sort(g)
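A minimal usage sketch of the pass above, assuming it is run on a loaded ONNX model whose value_info shapes are available; the model path and tensor names below are illustrative only:

import onnx

model = onnx.load("model.onnx")                       # hypothetical model file
model = onnx.shape_inference.infer_shapes(model)      # the pass reads channel counts from value_info
add_nop_bn_after(model.graph, ["conv1_out", "conv2_out"])  # hypothetical tensor names
onnx.save(model, "model_with_nop_bn.onnx")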
Code example #2
File: other.py  Project: raefYoussef/ONNX_Convertor
def add_bn_after(prev_node):
    # Note: this is a nested helper; `g` (the graph being modified) and `n`
    # (the node whose input gets rewired) come from the enclosing scope.
    # Get the channel number from value info
    value_name = prev_node.output[0]
    value = helper.find_value_by_name(g, value_name)
    shape = helper.get_shape_from_value_info(value)
    channel = shape[1]
    # Construct 4 weights
    node_name = value_name + "_nop_bn"
    ones = [1.0] * channel
    zeros = [0.0] * channel
    scale_node = helper.list_to_constant(node_name + "_scale", [channel], ones)
    bias_node = helper.list_to_constant(node_name + "_bias", [channel], zeros)
    mean_node = helper.list_to_constant(node_name + "_mean", [channel], zeros)
    var_node = helper.list_to_constant(node_name + "_var", [channel], ones)
    # Construct BN node
    bn_node = onnx.helper.make_node(
        "BatchNormalization",
        [value_name,
         scale_node.output[0],
         bias_node.output[0],
         mean_node.output[0],
         var_node.output[0]],
        [node_name],
        name=node_name
    )
    # Reconnect the graph
    replace_node_input(n, value_name, node_name)
    # Add node to the graph
    g.node.extend([bn_node, scale_node, bias_node, mean_node, var_node])
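Why scale=1, bias=0, mean=0 and var=1 make the inserted node a no-op: BatchNormalization computes y = scale * (x - mean) / sqrt(var + eps) + bias, which then reduces to roughly x. A quick NumPy check, not from the project:

import numpy as np

x = np.random.randn(1, 3, 4, 4).astype(np.float32)
eps = 1e-5                                    # ONNX default epsilon
y = 1.0 * (x - 0.0) / np.sqrt(1.0 + eps) + 0.0
print(np.abs(y - x).max())                    # on the order of 1e-5: effectively the identity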
Code example #3
def replace_depthwise_1x1_with_bn(g):
    """Replace 1x1 DepthwiseConv node into BN node if applicable.

    :param g: the onnx graph
    """
    node_to_remove = []
    for node in g.node:
        # Check op_type
        if node.op_type != 'Conv':
            continue
        # Check attributes
        attr_map = {attr.name: attr for attr in node.attribute}
        if "group" not in attr_map or attr_map["group"].i == 1:
            continue
        if attr_map["kernel_shape"].ints[0] != 1 or attr_map["kernel_shape"].ints[1] != 1:
            continue
        if "pads" in attr_map and sum(attr_map["pads"].ints) != 0:
            continue
        # Check scale
        scale_node = helper.find_node_by_output_name(g, node.input[1])
        if scale_node is None or scale_node.attribute[0].t.dims[1] != 1:
            continue
        scale_node.attribute[0].t.dims.pop()
        scale_node.attribute[0].t.dims.pop()
        scale_node.attribute[0].t.dims.pop()
        scale_info = helper.find_value_by_name(g, node.input[1])
        if scale_info is not None:
            scale_info.type.tensor_type.shape.dim.pop()
            scale_info.type.tensor_type.shape.dim.pop()
            scale_info.type.tensor_type.shape.dim.pop()
        # Check bias
        if len(node.input) == 3:
            bias_name = node.input[2]
        else:
            bias_name = node.name + "_bias"
            bias_node = helper.list_to_constant(bias_name, [attr_map["group"].i], [0.0] * attr_map["group"].i)
            g.node.extend([bias_node])
        # Construct mean and vars
        mean_name = node.name + "_mean"
        mean_node = helper.list_to_constant(mean_name, [attr_map["group"].i], [0.0] * attr_map["group"].i)
        var_name = node.name + "_var"
        var_node = helper.list_to_constant(var_name, [attr_map["group"].i], [1.0] * attr_map["group"].i)
        g.node.extend([mean_node, var_node])
        # Convert
        bn_node = onnx.helper.make_node(
            op_type='BatchNormalization',
            inputs=[node.input[0], node.input[1], bias_name, mean_name, var_name],
            outputs=node.output,
            name=node.name,
            epsilon=0.00001,
            momentum=0.9
            )
        g.node.extend([bn_node])
        node_to_remove.append(node)
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)
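The replacement works because a 1x1 depthwise Conv with group == C is just a per-channel scale and shift, which a BatchNormalization node with mean 0 and variance 1 reproduces up to epsilon. A NumPy sanity check, illustrative only:

import numpy as np

x = np.random.randn(1, 3, 5, 5).astype(np.float32)
w = np.random.randn(3, 1, 1, 1).astype(np.float32)   # depthwise 1x1 kernels
b = np.random.randn(3).astype(np.float32)

conv = x * w.reshape(1, 3, 1, 1) + b.reshape(1, 3, 1, 1)                       # the DepthwiseConv
bn = w.reshape(1, 3, 1, 1) * x / np.sqrt(1.0 + 1e-5) + b.reshape(1, 3, 1, 1)   # the BN equivalent
print(np.abs(conv - bn).max())                        # small, same order as epsilon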
Code example #4
File: other.py  Project: thomaswang525/ONNX_Convertor
def add_bn_on_skip_branch(g):
    for n in g.node:
        # Find merge node (Add)
        if n.op_type != 'Add':
            continue
        if len(n.input) != 2:
            continue
        # TODO: Still need to consider more cases
        # Check if skip branch exist
        input_node_a = helper.find_node_by_output_name(g, n.input[0])
        output_of_input_node_a = helper.find_nodes_by_input_name(
            g, input_node_a.output[0])
        input_node_b = helper.find_node_by_output_name(g, n.input[1])
        output_of_input_node_b = helper.find_nodes_by_input_name(
            g, input_node_b.output[0])
        if len(output_of_input_node_a) == 1 and len(
                output_of_input_node_b) == 1:
            continue
        if len(output_of_input_node_a) == 2:
            split_node = input_node_a
        elif len(output_of_input_node_b) == 2:
            split_node = input_node_b
        else:
            continue
        # Get the channel number from value info
        value_name = split_node.output[0]
        value = helper.find_value_by_name(g, value_name)
        shape = helper.get_shape_from_value_info(value)
        channel = shape[1]
        # Construct 4 weights
        node_name = value_name + "_nop_bn"
        ones = [1.0] * channel
        zeros = [0.0] * channel
        scale_node = helper.list_to_constant(node_name + "_scale", [channel],
                                             ones)
        bias_node = helper.list_to_constant(node_name + "_bias", [channel],
                                            zeros)
        mean_node = helper.list_to_constant(node_name + "_mean", [channel],
                                            zeros)
        var_node = helper.list_to_constant(node_name + "_var", [channel], ones)
        # Construct BN node
        bn_node = onnx.helper.make_node("BatchNormalization", [
            value_name, scale_node.output[0], bias_node.output[0],
            mean_node.output[0], var_node.output[0]
        ], [node_name],
                                        name=node_name)
        # Reconnect the graph
        replace_node_input(n, value_name, node_name)
        # Add node to the graph
        g.node.extend([bn_node, scale_node, bias_node, mean_node, var_node])
    topological_sort(g)
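One way this pass might be exercised and checked is to count the no-op BatchNormalization nodes it inserts; the model path below is illustrative only:

import onnx

model = onnx.load("model_with_residuals.onnx")        # hypothetical file
before = sum(1 for n in model.graph.node if n.op_type == "BatchNormalization")
add_bn_on_skip_branch(model.graph)
after = sum(1 for n in model.graph.node if n.op_type == "BatchNormalization")
print("no-op BN nodes inserted on skip branches:", after - before)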
Code example #5
File: replacing.py  Project: henye/ONNX_Convertor
def replace_Squeeze_with_Reshape(g):
    """
    Replace Squeeze nodes with Reshape nodes.

    :param g: the input graph
    """
    node_to_remove = []
    for node in g.node:
        # Find Squeeze node
        if node.op_type != 'Squeeze':
            continue
        # Get the shape and Construct the shape
        output_value = helper.find_value_by_name(g, node.output[0])
        if output_value is None:
            output_value = helper.find_output_by_name(g, node.output[0])
        if output_value is None:
            raise RuntimeError("Cannot get shape for Squeeze")
        shape = [
            dim.dim_value for dim in output_value.type.tensor_type.shape.dim
        ]
        const_node = helper.list_to_constant(node.name + "_shape",
                                             [len(shape)], shape)
        # Construct the Reshape layer with same input, output and name.
        new_node = onnx.helper.make_node("Reshape",
                                         [node.input[0], node.name + "_shape"],
                                         node.output,
                                         name=node.name)
        # Append constructed nodes and append old node to remove_list
        g.node.extend([const_node, new_node])
        node_to_remove.append(node)
    # Remove old nodes
    for node in node_to_remove:
        g.node.remove(node)
    # Topological sort
    topological_sort(g)
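The replacement is shape-driven: a Reshape to the Squeeze node's recorded output shape produces the same tensor, which is why the function insists on finding that shape in value_info. In NumPy terms:

import numpy as np

x = np.zeros((1, 3, 1, 5), dtype=np.float32)
squeezed = np.squeeze(x)                      # what the Squeeze node would output
reshaped = np.reshape(x, squeezed.shape)      # the equivalent Reshape
assert squeezed.shape == reshaped.shape == (3, 5)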
Code example #6
def polish_RESIZE_input_param_node(g, resize_node_name):
    resize_node = helper.find_node_by_output_name(g, resize_node_name)

    shape_data_node = helper.find_node_by_output_name(g, resize_node.input[3])
    shape_data = helper.constant_to_numpy(shape_data_node).astype(int)

    # handle 0 batch size which is invalid
    if shape_data[0] == 0:
        shape_data[0] = 1

    pre_node_output_value_info = helper.find_value_by_name(
        g, resize_node.input[0])
    ori_shape = np.array([
        pre_node_output_value_info.type.tensor_type.shape.dim[0].dim_value,
        pre_node_output_value_info.type.tensor_type.shape.dim[1].dim_value,
        pre_node_output_value_info.type.tensor_type.shape.dim[2].dim_value,
        pre_node_output_value_info.type.tensor_type.shape.dim[3].dim_value
    ])

    resize_node.input.remove(resize_node.input[3])

    resize_scales = np.array(shape_data / ori_shape).astype(float)
    resize_scale_node = helper.list_to_constant(
        'resize_scales_node_' + resize_node.name,
        resize_scales.shape,
        resize_scales,
        data_type=onnx.helper.TensorProto.FLOAT)

    resize_node.input[2] = resize_scale_node.name
    g.node.extend([resize_scale_node])

    other.topological_sort(g)
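The arithmetic behind the conversion, with illustrative NCHW numbers: the Resize node's "sizes" input is turned into per-axis "scales" relative to the input shape, after patching an invalid 0 batch dimension:

import numpy as np

sizes = np.array([0, 3, 64, 64])              # target sizes with an invalid batch of 0 ...
if sizes[0] == 0:
    sizes[0] = 1                              # ... patched to 1, as in the code above
ori_shape = np.array([1, 3, 32, 32])          # shape of the Resize input
scales = (sizes / ori_shape).astype(np.float32)
print(scales)                                 # [1. 1. 2. 2.]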
Code example #7
File: other.py  Project: thomaswang525/ONNX_Convertor
def add_nop_conv_after(g, value_names):
    """Add do-nothing depthwise Conv nodes after the given value info. It will\\
    take the given names as the inputs of the new node and replace the inputs\\
    of the following nodes.

    :param g: the graph\\
    :param value_names: a list of strings which are the names of value_info.
    """
    for value_name in value_names:
        # Find the value first
        value = helper.find_value_by_name(g, value_name)
        if value is None:
            value = helper.find_input_by_name(g, value_name)
        if value is None:
            value = helper.find_output_by_name(g, value_name)
        if value is None:
            print("Cannot find an value_info named {}".format(value_name))
            continue
        # Get the channel number from value info
        shape = helper.get_shape_from_value_info(value)
        channel = shape[1]
        # Construct 4 weights
        node_name = value_name + "_nop_conv"
        ones = [1.0] * channel
        weight_node = helper.list_to_constant(node_name + "_weight",
                                              [channel, 1, 1, 1], ones)
        # Construct BN node
        conv_node = onnx.helper.make_node("Conv",
                                          [value_name, weight_node.output[0]],
                                          [node_name],
                                          name=node_name,
                                          dilations=[1, 1],
                                          group=channel,
                                          kernel_shape=[1, 1],
                                          pads=[0, 0, 0, 0],
                                          strides=[1, 1])
        # Reconnect the graph
        following_nodes = helper.find_following_nodes_by_input_value_name(
            g, value_name)
        if len(following_nodes) > 0:
            for following_node in following_nodes:
                replace_node_input(following_node, value_name, node_name)
        else:
            # If the value is a graph output, replace that output entry with the new node's output.
            new_value = onnx.helper.make_tensor_value_info(
                node_name, value.type.tensor_type.elem_type, shape)
            output_values = []
            while len(g.output):
                output_values.append(g.output.pop())
            while output_values:
                output_value = output_values.pop()
                if output_value.name == value_name:
                    g.output.extend([new_value])
                else:
                    g.output.extend([output_value])
        # Add node to the graph
        g.node.extend([conv_node, weight_node])
    topological_sort(g)
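A hypothetical call pattern, assuming `g` is a loaded, shape-inferred graph; padding every Concat output with a do-nothing depthwise Conv is only one example of how value_names might be chosen:

concat_outputs = [n.output[0] for n in g.node if n.op_type == "Concat"]
add_nop_conv_after(g, concat_outputs)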
Code example #8
File: replacing.py  Project: kneron/ONNX_Convertor
def replace_shape_with_constant(g):
    """Replace Shape with Constant.\\
    This is the first step of reshape constant folding.

    :param g: the input graph\\
    :return: if anything modified, return true.
    """
    node_to_remove = []
    for node in g.node:
        # Find a Shape
        if node.op_type != 'Shape':
            continue
        # Check its input
        input_value = helper.find_value_by_name(g, node.input[0])
        if input_value is None:
            input_value = helper.find_input_by_name(g, node.input[0])
        if input_value is None or len(
                input_value.type.tensor_type.shape.dim) == 0:
            continue
        # Check for case where dimension could be 0 or -1
        tmp = True
        for d in input_value.type.tensor_type.shape.dim:
            tmp = tmp and (d.dim_value > 0)
        if not tmp:
            continue
        # Replace it
        input_shape = [
            d.dim_value for d in input_value.type.tensor_type.shape.dim
        ]
        node_name = node.output[0]
        new_node = helper.list_to_constant(node_name, [len(input_shape)],
                                           input_shape)
        g.node.extend([new_node])
        node_to_remove.append(node)

        # if the input value_info is not used by other node
        # delete this input value_info
        val_info_used = sum(
            [input_value.name in node.input for node in g.node])
        if val_info_used == 1:
            g.value_info.remove(input_value)

    replaced = True if len(node_to_remove) > 0 else False

    for node in node_to_remove:
        g.node.remove(node)

    topological_sort(g)

    return replaced
Code example #9
File: replacing.py  Project: kneron/ONNX_Convertor
def replace_ConstantOfShape_with_constant(g):
    """Replace Shape with Constant.\\
    This is the first step of reshape constant folding.

    :param g: the input graph\\
    :return: if anything modified, return true.
    """
    node_to_remove = []
    for node in g.node:
        # Find a ConstantOfShape
        if node.op_type != 'ConstantOfShape':
            continue
        # Check its input
        input_value = helper.find_value_by_name(g, node.input[0])
        if input_value is None:
            input_value = helper.find_input_by_name(g, node.input[0])
        if input_value is None or len(
                input_value.type.tensor_type.shape.dim) == 0:
            continue

        # Replace to constant node
        pre_node = helper.find_node_by_output_name(g, node.input[0])
        _, target_shape = helper.constant_to_list(pre_node)

        value = helper.get_attribute_by_name(node, 'value').i

        node_name = node.output[0]
        new_node = helper.list_to_constant(node_name, [target_shape[0]],
                                           [value] * target_shape[0])

        g.node.extend([new_node])

        # remove old node
        node_to_remove.append(node)

        # delete value_info
        val_info_used = sum(
            [input_value.name in node.input for node in g.node])
        if val_info_used == 1:
            g.value_info.remove(input_value)

    replaced = True if len(node_to_remove) > 0 else False

    for node in node_to_remove:
        g.node.remove(node)

    topological_sort(g)

    return replaced
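Since both of these passes return True only when they changed the graph, they lend themselves to being run to a fixed point. A sketch of such a driver loop (the loop itself is not from the project, and a real constant-folding pipeline would include further passes); `g` is the graph being processed:

changed = True
while changed:
    changed = False
    changed |= replace_shape_with_constant(g)             # fold Shape -> Constant
    changed |= replace_ConstantOfShape_with_constant(g)   # then ConstantOfShape -> Constant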
Code example #10
def replace_mul_to_bn(g):
    """Replace single Mul node with Batchnorm node.
    :param g: input graph.
    :return:
    """
    node_to_del = []
    for node in g.node:
        if node.op_type != 'Mul':
            continue

        mul_op_node = node

        # only support one input node
        if len(mul_op_node.input) != 2:  # OP node and value node
            continue

        input_op_node_name = mul_op_node.input[0]
        mul_value_node = helper.find_node_by_output_name(
            g, mul_op_node.input[1])
        if not mul_value_node or mul_value_node.op_type != 'Constant':
            continue

        _, previous_node_output_shape = helper.find_size_shape_from_value(
            helper.find_value_by_name(g, input_op_node_name))
        scale_shape, scale_data = helper.constant_to_list(mul_value_node)

        # only allow 4 dim data input due to the hardware limitation
        if len(previous_node_output_shape) != 4:
            continue

        # channel dimension
        c_dim = previous_node_output_shape[1]

        # only allow channelwise mul or const mul
        if scale_shape != [1, c_dim, 1, 1] and scale_shape != 1:
            continue

        ones = [1.0] * c_dim
        zeros = [0.0] * c_dim
        # If the scale is a scalar, expand it to one value per channel.
        if len(scale_data) == 1:
            muls = scale_data * c_dim
        else:
            muls = scale_data
        bn_name = mul_op_node.output[0]
        mean_value_node = helper.list_to_constant(bn_name + '_mean',
                                                  np.array(zeros).shape, zeros)
        variance_value_node = helper.list_to_constant(bn_name + '_var',
                                                      np.array(ones).shape,
                                                      ones)
        bias_value_node = helper.list_to_constant(bn_name + '_add',
                                                  np.array(zeros).shape, zeros)
        new_mul_value_node = helper.list_to_constant(bn_name + '_mul',
                                                     np.array(muls).shape,
                                                     muls)

        bn_node = onnx.helper.make_node('BatchNormalization', [
            input_op_node_name, new_mul_value_node.output[0],
            bias_value_node.output[0], mean_value_node.output[0],
            variance_value_node.output[0]
        ], [mul_op_node.output[0]],
                                        name=bn_name,
                                        epsilon=0.00000001)

        mid_val_info = helper.find_value_by_name(g, mul_op_node.output[0])
        scale_val_info = helper.find_value_by_name(g, mul_value_node.output[0])
        g.value_info.remove(mid_val_info)
        g.value_info.remove(scale_val_info)

        g.node.extend([bn_node])
        g.node.extend([mean_value_node])
        g.node.extend([variance_value_node])
        g.node.extend([bias_value_node])
        g.node.extend([new_mul_value_node])

        node_to_del.extend([mul_op_node])
        node_to_del.extend([mul_value_node])

    while node_to_del:
        g.node.remove(node_to_del.pop())

    topological_sort(g)
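A note on the scalar case above: BatchNormalization's scale input must be a 1-D tensor with one value per channel, so a single Mul constant has to be repeated c_dim times before it can feed the BN node. Illustrative values:

c_dim = 4                        # channel count, for illustration
scale_data = [0.5]               # a scalar Mul constant
scale = scale_data * c_dim if len(scale_data) == 1 else scale_data
print(scale)                     # [0.5, 0.5, 0.5, 0.5]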
Code example #11
File: fusing.py  Project: yiweichen04/ONNX_Convertor
def fuse_mul_and_add_into_bn(g):
    node_to_del = []
    for node in g.node:
        if node.op_type != 'Add':
            continue
        add_node = node
        input_nodes_add = [
            helper.find_node_by_output_name(g, input_name)
            for input_name in add_node.input
        ]
        if any([n == None for n in input_nodes_add]):
            continue
        mul_node, const_add = None, None
        for input_node_add in input_nodes_add:
            if input_node_add.op_type == 'Mul':
                mul_node = input_node_add
            elif input_node_add.op_type == 'Constant':
                const_add = input_node_add
            else:
                pass
        if not mul_node or not const_add:
            continue
        data_input_name, const_mul = None, None
        for input_name in mul_node.input:
            input_node = helper.find_node_by_output_name(g, input_name)
            if not input_node:
                data_input_name = input_name
            elif input_node.op_type == 'Constant':
                if not const_mul:
                    const_mul = input_node
                else:
                    data_input_name = input_name
            else:
                data_input_name = input_name

        if not const_mul:
            continue

        scale_shape, scale_data = helper.constant_to_list(const_mul)
        bais_shape, __ = helper.constant_to_list(const_add)
        c_dim = len(scale_data)
        if scale_shape != bais_shape:
            continue

        _, previous_node_output_shape = helper.find_size_shape_from_value(
            helper.find_value_by_name(g, data_input_name))
        # only allow 4 dim data input due to the hardware limitation
        if len(previous_node_output_shape) != 4:
            continue

        # check if mul's dim and input channel dimension are matched
        if previous_node_output_shape[1] != c_dim:
            continue

        if scale_shape == [1, c_dim, 1, 1]:

            # remove all '1'
            for _ in range(3):
                const_add.attribute[0].t.dims.remove(1)
                const_mul.attribute[0].t.dims.remove(1)

        elif scale_shape == [1, c_dim]:

            # remove all '1'
            const_add.attribute[0].t.dims.remove(1)
            const_mul.attribute[0].t.dims.remove(1)

        else:
            continue

        bn_name = add_node.output[0]
        const_mean = helper.list_to_constant(bn_name + '_mean', [c_dim],
                                             [0.0 for _ in range(c_dim)])
        const_var = helper.list_to_constant(bn_name + '_var', [c_dim],
                                            [1.0 for _ in range(c_dim)])

        bn_node = onnx.helper.make_node(
            'BatchNormalization',
            [data_input_name, const_mul.output[0], const_add.output[0],\
                const_mean.output[0], const_var.output[0]],
            [add_node.output[0]],
            name=bn_name,
            epsilon=0.00000001
        )

        mid_val_info = helper.find_value_by_name(g, mul_node.output[0])
        scale_val_info = helper.find_value_by_name(g, const_mul.output[0])
        bais_val_info = helper.find_value_by_name(g, const_add.output[0])
        g.value_info.remove(mid_val_info)
        g.value_info.remove(scale_val_info)
        g.value_info.remove(bais_val_info)

        new_scale_val_info = onnx.helper.make_tensor_value_info(
            const_mul.output[0], const_mul.attribute[0].t.data_type, [c_dim])
        new_bais_val_info = onnx.helper.make_tensor_value_info(
            const_add.output[0], const_add.attribute[0].t.data_type, [c_dim])
        mean_val_info = onnx.helper.make_tensor_value_info(
            const_mean.output[0], const_mean.attribute[0].t.data_type, [c_dim])
        var_val_info = onnx.helper.make_tensor_value_info(
            const_var.output[0], const_var.attribute[0].t.data_type, [c_dim])

        g.value_info.extend([new_scale_val_info])
        g.value_info.extend([new_bais_val_info])
        g.value_info.extend([mean_val_info])
        g.value_info.extend([var_val_info])
        g.node.extend([bn_node])
        g.node.extend([const_mean])
        g.node.extend([const_var])
        node_to_del.extend([mul_node, add_node])

    while node_to_del:
        g.node.remove(node_to_del.pop())

    topological_sort(g)
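The fusion is exact up to epsilon because x * s + b equals BatchNormalization(x; scale=s, B=b, mean=0, var=1) apart from the 1/sqrt(1 + eps) factor. A quick NumPy check, not from the project:

import numpy as np

x = np.random.randn(1, 3, 4, 4).astype(np.float32)
s = np.random.randn(1, 3, 1, 1).astype(np.float32)
b = np.random.randn(1, 3, 1, 1).astype(np.float32)
bn = s * x / np.sqrt(1.0 + 1e-8) + b          # what the new BN node computes
print(np.abs((x * s + b) - bn).max())         # negligible, around 1e-7 or less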
Code example #12
File: replacing.py  Project: kneron/ONNX_Convertor
def replace_add_to_bn(g):
    """Replace single Add node with Batchnorm node.
    :param g: input graph.
    :return:
    """
    node_to_del = []
    for node in g.node:
        if node.op_type != 'Add':
            continue

        add_op_node = node

        # only support one input node
        if len(add_op_node.input) != 2:  # OP node and value node
            continue

        input_op_node_name = add_op_node.input[0]
        add_value_node = helper.find_node_by_output_name(
            g, add_op_node.input[1])
        if not add_value_node or add_value_node.op_type != 'Constant':
            continue

        prev_shape_value_info = helper.find_value_by_name(
            g, input_op_node_name)
        prev_shape_value_info = helper.find_input_by_name(
            g, input_op_node_name
        ) if prev_shape_value_info is None else prev_shape_value_info
        if prev_shape_value_info is None:
            continue

        _, previous_node_output_shape = helper.find_size_shape_from_value(
            prev_shape_value_info)
        bias_shape, bias_data = helper.constant_to_list(add_value_node)

        # channel dimension
        c_dim = previous_node_output_shape[1] if len(
            previous_node_output_shape) > 1 else 1

        # only allow channelwise add or const add
        if bias_shape != [1, c_dim, 1, 1] and bias_shape != 1:
            continue

        ones = [1.0] * c_dim
        zeros = [0.0] * c_dim
        # If bias is a scalar, expand it.
        if len(bias_data) == 1:
            bias = bias_data * c_dim
        else:
            bias = bias_data
        bn_name = add_op_node.output[0]
        mean_value_node = helper.list_to_constant(bn_name + '_mean',
                                                  np.array(zeros).shape, zeros)
        variance_value_node = helper.list_to_constant(bn_name + '_var',
                                                      np.array(ones).shape,
                                                      ones)
        scale_value_node = helper.list_to_constant(bn_name + '_mul',
                                                   np.array(ones).shape, ones)
        new_add_value_node = helper.list_to_constant(bn_name + '_add',
                                                     np.array(bias).shape,
                                                     bias)

        bn_node = onnx.helper.make_node('BatchNormalization', [
            input_op_node_name, scale_value_node.output[0],
            new_add_value_node.output[0], mean_value_node.output[0],
            variance_value_node.output[0]
        ], [add_op_node.output[0]],
                                        name=bn_name,
                                        epsilon=0.00000001)

        add_val_info = helper.find_value_by_name(g, add_value_node.output[0])
        g.value_info.remove(add_val_info)

        g.node.extend([bn_node])
        g.node.extend([mean_value_node])
        g.node.extend([variance_value_node])
        g.node.extend([scale_value_node])
        g.node.extend([new_add_value_node])

        node_to_del.extend([add_op_node])
        node_to_del.extend([add_value_node])

    while node_to_del:
        g.node.remove(node_to_del.pop())

    topological_sort(g)
Code example #13
File: replacing.py  Project: kneron/ONNX_Convertor
def replace_split_with_slices(g):
    """Replace split node with slice nodes.
    :param g: input graph.
    :return:
    """
    node_to_remove = []
    for node in g.node:
        # Find a Split
        if node.op_type != 'Split':
            continue

        input_value = helper.find_value_by_name(g, node.input[0])
        if not input_value:
            input_value = helper.find_input_by_name(g, node.input[0])
        _, shape = helper.find_size_shape_from_value(input_value)
        if len(shape) == 0:
            continue

        output_val_names = list(node.output)

        axis = 0
        split = []
        for item in node.attribute:
            if item.name == 'axis':
                axis = item.i
            if item.name == 'split':
                split = item.ints

        length = input_value.type.tensor_type.shape.dim[axis].dim_value
        if len(split) > 0:
            n_out = len(node.attribute[1].ints)
            pos = 0
            for i in range(n_out):
                pos += node.attribute[1].ints[i]
                new_node_name = output_val_names[i]
                # Construct starts, ends, axes
                starts_name = new_node_name + '_starts_' + str(i)
                ends_name = new_node_name + '_ends_' + str(i)
                axes_name = new_node_name + '_axes_' + str(i)
                starts_node = helper.list_to_constant(
                    starts_name, (1, ), [int(pos - node.attribute[1].ints[i])])
                ends_node = helper.list_to_constant(ends_name, (1, ),
                                                    [int(pos)])
                axes_node = helper.list_to_constant(axes_name, (1, ),
                                                    [int(axis)])
                # Construct node
                new_node = onnx.helper.make_node(
                    op_type='Slice',
                    inputs=[node.input[0], starts_name, ends_name, axes_name],
                    outputs=[new_node_name],
                    name=new_node_name)
                g.node.extend([starts_node, ends_node, axes_node, new_node])
            node_to_remove.append(node)
        else:
            n_out = len(output_val_names)
            width = length // n_out
            for i in range(n_out):
                new_node_name = output_val_names[i]
                # Construct starts, ends, axes
                starts_name = new_node_name + '_starts_' + str(i)
                ends_name = new_node_name + '_ends_' + str(i)
                axes_name = new_node_name + '_axes_' + str(i)
                starts_node = helper.list_to_constant(starts_name, (1, ),
                                                      [int(i * width)])
                ends_node = helper.list_to_constant(ends_name, (1, ),
                                                    [int((1 + i) * width)])
                axes_node = helper.list_to_constant(axes_name, (1, ),
                                                    [int(axis)])
                # Construct node
                new_node = onnx.helper.make_node(
                    op_type='Slice',
                    inputs=[node.input[0], starts_name, ends_name, axes_name],
                    outputs=[new_node_name],
                    name=new_node_name)
                g.node.extend([starts_node, ends_node, axes_node, new_node])
            node_to_remove.append(node)

    for old_node in node_to_remove:
        g.node.remove(old_node)
    topological_sort(g)
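What the generated Slice nodes compute, in NumPy terms; both the explicit 'split' attribute case and the equal-split fallback are shown with illustrative sizes:

import numpy as np

x = np.arange(12).reshape(2, 6)

# Split with split=[2, 4] on axis 1 becomes Slice(starts=0, ends=2) and Slice(starts=2, ends=6)
parts = [x[:, 0:2], x[:, 2:6]]
assert all(np.array_equal(a, b) for a, b in zip(parts, np.split(x, [2], axis=1)))

# Without a 'split' attribute, the axis is cut into equal slices of width 6 // 3 = 2
equal = [x[:, i * 2:(i + 1) * 2] for i in range(3)]
assert all(np.array_equal(a, b) for a, b in zip(equal, np.split(x, 3, axis=1)))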