コード例 #1
0
ファイル: other.py プロジェクト: thomaswang525/ONNX_Convertor
def add_nop_bn_after(g, value_names):
    """Insert an identity BatchNormalization node after each named value.

    For every name in ``value_names`` a BatchNormalization node is built
    whose scale and variance constants are all ones and whose bias and mean
    constants are all zeros. Nodes that consumed the value are rewired to
    read the new node's output. When no node consumes the value, the graph
    output entry carrying that name is replaced by the new node's output.
    The graph is topologically sorted once at the end.

    :param g: the graph
    :param value_names: a list of strings which are the names of value_info.
    """
    for value_name in value_names:
        # Look the value up in value_info first, then inputs, then outputs.
        value = helper.find_value_by_name(g, value_name)
        for fallback in (helper.find_input_by_name,
                         helper.find_output_by_name):
            if value is not None:
                break
            value = fallback(g, value_name)
        if value is None:
            print("Cannot find an value_info named {}".format(value_name))
            continue

        # The channel count is read from the second dimension of the shape.
        shape = helper.get_shape_from_value_info(value)
        channel = shape[1]

        # Four per-channel constants that make the BN a no-op.
        node_name = value_name + "_nop_bn"
        ones = [1.0] * channel
        zeros = [0.0] * channel
        weight_nodes = [
            helper.list_to_constant(node_name + suffix, [channel], data)
            for suffix, data in (("_scale", ones), ("_bias", zeros),
                                 ("_mean", zeros), ("_var", ones))
        ]

        bn_inputs = [value_name]
        bn_inputs.extend(weight.output[0] for weight in weight_nodes)
        bn_node = onnx.helper.make_node("BatchNormalization",
                                        bn_inputs, [node_name],
                                        name=node_name)

        # Reconnect the graph.
        consumers = helper.find_following_nodes_by_input_value_name(
            g, value_name)
        if len(consumers) == 0:
            # Nothing consumes the value: swap the graph output entry for
            # the BN output, keeping the order of the other outputs.
            new_value = onnx.helper.make_tensor_value_info(
                node_name, value.type.tensor_type.elem_type, shape)
            drained = []
            while len(g.output) > 0:
                drained.append(g.output.pop())
            for output_value in reversed(drained):
                if output_value.name == value_name:
                    g.output.extend([new_value])
                else:
                    g.output.extend([output_value])
        else:
            for consumer in consumers:
                replace_node_input(consumer, value_name, node_name)

        # Add the new nodes to the graph.
        g.node.extend([bn_node] + weight_nodes)
    topological_sort(g)
コード例 #2
0
ファイル: other.py プロジェクト: thomaswang525/ONNX_Convertor
def add_nop_conv_after(g, value_names):
    """Insert an identity depthwise Conv node after each named value.

    For every name in ``value_names`` a 1x1 Conv node is built with
    ``group`` equal to the channel count and a weight constant of shape
    ``[channel, 1, 1, 1]`` filled with ones. Nodes that consumed the value
    are rewired to read the new node's output. When no node consumes the
    value, the graph output entry carrying that name is replaced by the new
    node's output. The graph is topologically sorted once at the end.

    :param g: the graph
    :param value_names: a list of strings which are the names of value_info.
    """
    for value_name in value_names:
        # Look the value up in value_info first, then inputs, then outputs.
        value = helper.find_value_by_name(g, value_name)
        for fallback in (helper.find_input_by_name,
                         helper.find_output_by_name):
            if value is not None:
                break
            value = fallback(g, value_name)
        if value is None:
            print("Cannot find an value_info named {}".format(value_name))
            continue

        # The channel count is read from the second dimension of the shape.
        shape = helper.get_shape_from_value_info(value)
        channel = shape[1]

        # Single all-ones weight constant for the depthwise kernel.
        node_name = value_name + "_nop_conv"
        weight_node = helper.list_to_constant(node_name + "_weight",
                                              [channel, 1, 1, 1],
                                              [1.0] * channel)

        # The Conv node itself: 1x1 kernel, unit stride, no padding.
        conv_attributes = dict(dilations=[1, 1],
                               group=channel,
                               kernel_shape=[1, 1],
                               pads=[0, 0, 0, 0],
                               strides=[1, 1])
        conv_node = onnx.helper.make_node("Conv",
                                          [value_name, weight_node.output[0]],
                                          [node_name],
                                          name=node_name,
                                          **conv_attributes)

        # Reconnect the graph.
        consumers = helper.find_following_nodes_by_input_value_name(
            g, value_name)
        if len(consumers) == 0:
            # Nothing consumes the value: swap the graph output entry for
            # the Conv output, keeping the order of the other outputs.
            new_value = onnx.helper.make_tensor_value_info(
                node_name, value.type.tensor_type.elem_type, shape)
            drained = []
            while len(g.output) > 0:
                drained.append(g.output.pop())
            for output_value in reversed(drained):
                if output_value.name == value_name:
                    g.output.extend([new_value])
                else:
                    g.output.extend([output_value])
        else:
            for consumer in consumers:
                replace_node_input(consumer, value_name, node_name)

        # Add the new nodes to the graph.
        g.node.extend([conv_node, weight_node])
    topological_sort(g)
コード例 #3
0
ファイル: other.py プロジェクト: thomaswang525/ONNX_Convertor
def inference_split_shape(g):
    """Infer and record the output shapes of Split nodes.

    For each Split node whose input shape is known and which has at least
    one output without a recorded shape, a value_info entry is written for
    every output. Each output shape is the input shape with the ``axis``
    dimension replaced by the matching ``split`` size.

    ``axis`` defaults to 0 when the attribute is absent. When ``split`` is
    absent the input is divided evenly among the outputs. Nodes whose
    attributes are inconsistent with the input shape or the output count
    are left untouched.

    :param g: the graph
    :return: True if value_info was written for at least one Split node.
    """
    processed = False
    for node in g.node:
        if node.op_type != 'Split':
            continue

        input_val_info = helper.find_value_by_name(g, node.input[0])
        if not input_val_info:
            input_val_info = helper.find_input_by_name(g, node.input[0])
        if not input_val_info:
            continue

        _, input_shape = helper.find_size_shape_from_value(input_val_info)
        if not input_shape:
            continue

        output_val_names = list(node.output)
        output_vals = [
            helper.find_value_by_name(g, val_name)
            for val_name in output_val_names
        ]

        # An output with no value_info at all counts as "shape missing".
        output_shapes = [
            helper.find_size_shape_from_value(output_val)[1]
            if output_val is not None else None
            for output_val in output_vals
        ]
        # Every output already has a shape: nothing to infer.
        if all(output_shapes):
            continue

        # Reset per node so values never leak in from a previous Split.
        axis = 0
        split = []
        for att in node.attribute:
            if att.name == 'axis':
                axis = att.i
            elif att.name == 'split':
                split = list(att.ints)

        rank = len(input_shape)
        if not -rank <= axis < rank:
            continue

        output_count = len(output_val_names)
        if not split:
            # No explicit sizes: the axis is divided into equal parts.
            if input_shape[axis] % output_count != 0:
                continue
            split = [input_shape[axis] // output_count] * output_count
        elif len(split) != output_count:
            continue

        elem_type = input_val_info.type.tensor_type.elem_type
        new_output_vals = []
        for val_name, part_size in zip(output_val_names, split):
            new_shape = list(input_shape)
            new_shape[axis] = part_size
            new_output_vals.append(
                onnx.helper.make_tensor_value_info(val_name, elem_type,
                                                   new_shape))

        # Replace any stale entries with the freshly inferred ones.
        for val in output_vals:
            if val is not None:
                g.value_info.remove(val)
        g.value_info.extend(new_output_vals)

        processed = True

    return processed
コード例 #4
0
ファイル: other.py プロジェクト: henye/ONNX_Convertor
def inference_cov_shape(g):
    """Infer missing output shapes of 2-D Conv nodes.

    For each Conv node whose input and kernel shapes are known and whose
    output has no recorded shape, the output shape
    ``[N, kernel_shape[0], H, W]`` is computed from the ``strides``,
    ``pads`` and ``dilations`` attributes and written to ``g.value_info``.

    Absent attributes fall back to unit strides, zero pads and unit
    dilations. A single-element ``strides`` or ``dilations`` attribute is
    expanded in place to two elements. Nodes whose shapes or attributes do
    not describe a 2-D convolution are skipped.

    :param g: the graph
    :return: True if at least one output shape was added.
    """
    processed = False
    for node in g.node:
        if node.op_type != 'Conv':
            continue
        input_value_info = helper.find_value_by_name(g, node.input[0])
        if not input_value_info:
            input_value_info = helper.find_input_by_name(g, node.input[0])
        if not input_value_info:
            continue

        kernel_value_info = helper.find_value_by_name(g, node.input[1])

        # Keep the value_info entry apart from a graph-output entry: only
        # the former can later be removed from g.value_info.
        stale_value_info = helper.find_value_by_name(g, node.output[0])
        output_value_info = stale_value_info
        if not output_value_info:
            output_value_info = helper.find_output_by_name(g, node.output[0])

        if output_value_info and \
            helper.get_shape_from_value_info(output_value_info):
            continue

        _, kernel_shape = helper.find_size_shape_from_value(kernel_value_info)
        _, input_shape = helper.find_size_shape_from_value(input_value_info)
        if not input_shape or not kernel_shape:
            continue

        strides_attr = helper.get_attribute_by_name(node, 'strides')
        pads_attr = helper.get_attribute_by_name(node, 'pads')
        dilations_attr = helper.get_attribute_by_name(node, 'dilations')
        strides = strides_attr.ints if strides_attr is not None else [1, 1]
        pads = pads_attr.ints if pads_attr is not None else [0, 0, 0, 0]
        dilation = dilations_attr.ints \
            if dilations_attr is not None else [1, 1]

        # Pytorch model has the case where strides only have one number.
        # Expand it and keep going instead of leaving the whole pass.
        if len(strides) == 1:
            strides.append(strides[0])
        if len(dilation) == 1:
            dilation.append(dilation[0])

        # The formula below only covers 2-D convolutions.
        if len(input_shape) != 4 or len(kernel_shape) != 4:
            continue
        if len(strides) != 2 or len(dilation) != 2 or len(pads) != 4:
            continue

        # Integer floor division gives the same result as floor(x / s + 1).
        H = (input_shape[2] + pads[0] + pads[2]
             - dilation[0] * (kernel_shape[2] - 1) - 1) // strides[0] + 1
        W = (input_shape[3] + pads[1] + pads[3]
             - dilation[1] * (kernel_shape[3] - 1) - 1) // strides[1] + 1
        output_shape = [input_shape[0], kernel_shape[0], H, W]

        new_output_value_info = onnx.helper.make_tensor_value_info(
            node.output[0], input_value_info.type.tensor_type.elem_type,
            output_shape)

        processed = True

        if stale_value_info:
            g.value_info.remove(stale_value_info)
        g.value_info.extend([new_output_value_info])

    return processed
コード例 #5
0
ファイル: replacing.py プロジェクト: kneron/ONNX_Convertor
def replace_shape_with_constant(g):
    """Replace Shape with Constant.

    This is the first step of reshape constant folding. A Shape node is
    replaced only when its input has a recorded shape in which every
    dimension is strictly positive. The replacement Constant reuses the
    Shape node's output name and holds the input dimensions.

    :param g: the input graph
    :return: if anything modified, return true.
    """
    node_to_remove = []
    for node in g.node:
        # Find a Shape
        if node.op_type != 'Shape':
            continue
        # Check its input. Remember whether it came from value_info, since
        # only such an entry may be removed from g.value_info below.
        value_info = helper.find_value_by_name(g, node.input[0])
        input_value = value_info
        if input_value is None:
            input_value = helper.find_input_by_name(g, node.input[0])
        if input_value is None:
            continue
        dims = input_value.type.tensor_type.shape.dim
        if len(dims) == 0:
            continue
        # Skip the case where a dimension could be 0 or -1
        input_shape = [d.dim_value for d in dims]
        if any(dim_value <= 0 for dim_value in input_shape):
            continue
        # Replace it
        new_node = helper.list_to_constant(node.output[0],
                                           [len(input_shape)], input_shape)
        g.node.extend([new_node])
        node_to_remove.append(node)

        # If the input value_info is not used by any other node, delete it.
        # The Shape node itself is still in the graph, hence the count of 1.
        if value_info is not None:
            use_count = sum(
                input_value.name in other.input for other in g.node)
            if use_count == 1:
                g.value_info.remove(value_info)

    replaced = bool(node_to_remove)

    for node in node_to_remove:
        g.node.remove(node)

    topological_sort(g)

    return replaced
コード例 #6
0
ファイル: replacing.py プロジェクト: kneron/ONNX_Convertor
def replace_ConstantOfShape_with_constant(g):
    """Replace ConstantOfShape with Constant.

    A ConstantOfShape node is replaced when its shape input has a recorded
    value_info and is produced by a Constant node. The replacement Constant
    reuses the node's output name, takes the target shape from the
    producing Constant's data and is filled with the node's ``value``.

    :param g: the input graph
    :return: if anything modified, return true.
    """
    node_to_remove = []
    for node in g.node:
        # Find a ConstantOfShape
        if node.op_type != 'ConstantOfShape':
            continue
        # Check input. Remember whether it came from value_info, since only
        # such an entry may be removed from g.value_info below.
        value_info = helper.find_value_by_name(g, node.input[0])
        input_value = value_info
        if input_value is None:
            input_value = helper.find_input_by_name(g, node.input[0])
        if input_value is None or len(
                input_value.type.tensor_type.shape.dim) == 0:
            continue

        # The target shape can only be folded when it comes from a Constant.
        pre_node = helper.find_node_by_output_name(g, node.input[0])
        if pre_node is None or pre_node.op_type != 'Constant':
            continue
        _, target_shape = helper.constant_to_list(pre_node)
        target_shape = list(target_shape)

        # NOTE(review): the ONNX spec describes `value` as a one-element
        # tensor attribute; reading `.i` is kept from the original code and
        # should be confirmed against the models this pass is used on.
        value_attr = helper.get_attribute_by_name(node, 'value')
        value = value_attr.i if value_attr is not None else 0

        # One fill value per element of the full target shape, not only of
        # its first dimension.
        element_count = 1
        for extent in target_shape:
            element_count *= extent

        node_name = node.output[0]
        new_node = helper.list_to_constant(node_name, target_shape,
                                           [value] * element_count)

        g.node.extend([new_node])

        # remove old node
        node_to_remove.append(node)

        # Delete the shape value_info if this node was its only consumer.
        # The node itself is still in the graph, hence the count of 1.
        if value_info is not None:
            use_count = sum(
                input_value.name in other.input for other in g.node)
            if use_count == 1:
                g.value_info.remove(value_info)

    replaced = bool(node_to_remove)

    for node in node_to_remove:
        g.node.remove(node)

    topological_sort(g)

    return replaced
コード例 #7
0
ファイル: other.py プロジェクト: raefYoussef/ONNX_Convertor
def change_input_shape(g, target_list):
    """Overwrite the dimensions of graph inputs.

    Each entry of ``target_list`` is parsed by ``parse_shape_change_input``
    into an input name and a shape. The matching graph input has its
    ``dim_value`` fields overwritten one by one. Entries naming an unknown
    input, or whose rank differs from the input's rank, are reported and
    skipped.

    :param g: the graph
    :param target_list: entries to be parsed into a name and a shape
    """
    for target in target_list:
        try:
            name, shape = parse_shape_change_input(target)
            input_value = helper.find_input_by_name(g, name)
            if input_value is None:
                print("Cannot find input {}".format(name))
                continue
            dims = input_value.type.tensor_type.shape.dim
            if len(shape) != len(dims):
                print("The dimension doesn't match for input {}".format(name))
                continue
            for axis, extent in enumerate(shape):
                dims[axis].dim_value = extent
        except ValueError:
            # This happens when the input cannot be converted into int
            print("Cannot parse {} into name and int".format(target))
            continue
        except TypeError:
            # This happens when the parser function returns None.
            continue
コード例 #8
0
ファイル: replacing.py プロジェクト: henye/ONNX_Convertor
def replace_split_with_slices(g):
    """Replace split node with slice nodes.

    Every Split node whose input shape is known is replaced by one Slice
    node per output. Each Slice reads the Split input, reuses the matching
    Split output name and carries ``axes``, ``starts`` and ``ends`` as
    attributes.

    Section sizes come from the ``split`` attribute when present. Otherwise
    the length of the split axis is divided evenly among the outputs.
    ``axis`` defaults to 0. Nodes whose input shape is unknown, or whose
    ``split`` length does not match the number of outputs, are left
    untouched.

    :param g: input graph.
    :return:
    """
    node_to_remove = []
    for node in g.node:
        # Find a Split
        if node.op_type != 'Split':
            continue

        input_name = node.input[0]
        input_value = helper.find_value_by_name(g, input_name)
        if not input_value:
            input_value = helper.find_input_by_name(g, input_name)
        if not input_value:
            continue
        _, shape = helper.find_size_shape_from_value(input_value)
        if not shape:
            continue

        output_names = list(node.output)
        if not output_names:
            continue

        # Look attributes up by name; their order in the node is not fixed.
        axis = 0
        split = []
        for item in node.attribute:
            if item.name == 'axis':
                axis = item.i
            elif item.name == 'split':
                split = list(item.ints)

        if not split:
            # No explicit sizes: divide the axis evenly among the outputs.
            length = input_value.type.tensor_type.shape.dim[axis].dim_value
            split = [length // len(output_names)] * len(output_names)
        elif len(split) != len(output_names):
            continue

        start = 0
        for output_name, section in zip(output_names, split):
            end = start + section
            new_node = onnx.helper.make_node(op_type='Slice',
                                             inputs=[input_name],
                                             outputs=[output_name],
                                             name=output_name,
                                             axes=[axis],
                                             ends=[end],
                                             starts=[start])
            g.node.extend([new_node])
            start = end
        node_to_remove.append(node)

    for old_node in node_to_remove:
        g.node.remove(old_node)
    topological_sort(g)
コード例 #9
0
ファイル: other.py プロジェクト: raefYoussef/ONNX_Convertor
def split_ConvTranspose(model):
    """To feed our compiler, split ConvTranspose into Upsample and Conv.

    The model's ``ir_version`` is set to 4 and the version of its first
    opset import to 9. Each ConvTranspose node without ``auto_pad`` or
    ``output_shape`` attributes is then replaced by:

    * an Upsample node with ``mode="zeros"`` whose spatial output size is
      ``(input - 1) * stride + 1``, and
    * a stride-1 Conv node that reuses the ConvTranspose output name.

    Unless ``group`` equals the input channel count, the weight Constant is
    replaced by a copy with its first two axes swapped. Nodes whose weight
    has no producing node in the graph are reported and left untouched.

    :param model: the model
    """
    node_to_delete = []
    # Change model properties for upsample.
    if model.ir_version < 3:
        print("Warning: Current model IR version is not fully supported.")
    model.ir_version = 4
    model.opset_import[0].version = 9
    g = model.graph
    for node in g.node:
        # Find a ConvTranspose node
        if node.op_type != 'ConvTranspose':
            continue
        # Check auto_pad
        auto_pad_proto = helper.get_attribute_by_name(node, "auto_pad")
        if auto_pad_proto is not None:
            print("Currently not split auto_pad ConvTranspose")
            continue
        # Check output_shape
        output_shape_proto = helper.get_attribute_by_name(node, "output_shape")
        if output_shape_proto is not None:
            print("Currently not split output_shape ConvTranspose")
            continue
        # Get input shape
        input_value = helper.find_value_by_name(g, node.input[0])
        if input_value is None:
            input_value = helper.find_input_by_name(g, node.input[0])
        if input_value is None:
            print("Cannot get value info named {}.".format(node.input[0]))
            exit(1)
        input_shape = helper.get_shape_from_value_info(input_value)
        # Get attributes
        attr = deconv_to_conv_info_extraction(input_shape, node)
        # The weight may need a transpose; make sure it can be found before
        # anything in the graph is modified.
        needs_transpose = attr["group"] != input_shape[1]
        weight_node = None
        if needs_transpose:
            weight_node = helper.find_node_by_output_name(g, node.input[1])
            if weight_node is None:
                print("Cannot find the weight node named {}.".format(
                    node.input[1]))
                continue
        # Generate Upsample scales
        upsample_output_shape = list(input_shape)
        upsample_output_shape[2] = (input_shape[2] - 1) * attr["strides"][0] + 1
        upsample_output_shape[3] = (input_shape[3] - 1) * attr["strides"][1] + 1
        upsample_node_name = node.name + "_inner_upsample"
        upsample_scale_name = upsample_node_name + "_scales"
        scales_np = np.ones([4]).astype('float32')
        scales_np[2] = float(upsample_output_shape[2]) / input_shape[2]
        scales_np[3] = float(upsample_output_shape[3]) / input_shape[3]
        scales_node = helper.numpy_to_constant(upsample_scale_name, scales_np)
        # Generate a Upsample layer and an internal value info
        upsample_node = onnx.helper.make_node(
            "Upsample",
            [node.input[0], upsample_scale_name],
            [upsample_node_name],
            name=upsample_node_name,
            mode="zeros"
        )
        upsample_value_info = onnx.helper.make_tensor_value_info(
            upsample_node_name,
            input_value.type.tensor_type.elem_type,
            upsample_output_shape
        )
        # Swap the first two weight axes when required
        if needs_transpose:
            weight_np = helper.constant_to_numpy(weight_node)
            new_weight_np = np.transpose(weight_np, [1, 0, 2, 3])
            new_weight_node = helper.numpy_to_constant(node.input[1], new_weight_np)
            node_to_delete.append(weight_node)
            g.node.extend([new_weight_node])
            # The recorded weight shape is stale after the transpose; drop
            # it only if such an entry exists.
            value = helper.find_value_by_name(g, node.input[1])
            if value is not None:
                g.value_info.remove(value)
        # Generate a Conv layer
        conv_node_name = node.name + "_inner_conv"
        conv_node_input = [upsample_node_name]
        conv_node_input.extend(node.input[1:])
        conv_node = onnx.helper.make_node(
            "Conv",
            conv_node_input,
            [node.output[0]],
            name=conv_node_name,
            pads=[int(i) for i in attr["conv_pads"]],
            dilations=[int(i) for i in attr["dilations"]],
            group=int(attr["group"]),
            kernel_shape=[int(i) for i in attr["kernel_shape"]],
            strides=[int(1), int(1)]
        )
        # Reconnect the graph
        g.node.extend([scales_node, upsample_node, conv_node])
        g.value_info.extend([upsample_value_info])
        node_to_delete.append(node)
    # Delete useless nodes
    for node in node_to_delete:
        g.node.remove(node)
    topological_sort(g)
コード例 #10
0
ファイル: other.py プロジェクト: raefYoussef/ONNX_Convertor
def inference_cov_shape(g):
    """Infer missing output shapes of Conv nodes.

    For each Conv node whose input and kernel shapes are known and whose
    output has no recorded shape, an output shape is computed and written
    to ``g.value_info``. The output channel count is ``kernel_shape[0]``.

    * ``auto_pad`` SAME_UPPER / SAME_LOWER: every spatial size is
      ``ceil(input / stride)``.
    * ``auto_pad`` VALID: zero pads with the explicit formula below.
    * otherwise: ``floor((in + pad_begin + pad_end - dilation * (k - 1) - 1)
      / stride + 1)`` for 2-D convolutions, using the ``pads`` attribute.

    Absent attributes fall back to unit strides, zero pads and unit
    dilations. For 4-D inputs a single-element ``strides`` or ``dilations``
    attribute is expanded in place to two elements. An unrecognized
    ``auto_pad`` value prints a message and exits the process.

    :param g: the graph
    :return: True if at least one output shape was added.
    """
    processed = False
    for node in g.node:
        # Only Conv nodes need their output shape inferred here.
        if node.op_type != 'Conv':
            continue
        # Input shape is not ready yet. Skip.
        input_value_info = helper.find_value_by_name(g, node.input[0])
        if not input_value_info:
            input_value_info = helper.find_input_by_name(g, node.input[0])
        if not input_value_info:
            continue
        _, input_shape = helper.find_size_shape_from_value(input_value_info)
        if not input_shape:
            continue
        # Output shape is already there. Skip. Keep the value_info entry
        # apart from a graph-output entry: only the former can later be
        # removed from g.value_info.
        stale_value_info = helper.find_value_by_name(g, node.output[0])
        output_value_info = stale_value_info
        if not output_value_info:
            output_value_info = helper.find_output_by_name(g, node.output[0])
        if output_value_info and \
            helper.get_shape_from_value_info(output_value_info):
            continue

        # Now start the inference.
        # If auto_pad is set, use the auto_pad.
        auto_pad = helper.get_var_attribute_by_name(node, 'auto_pad', 'string')
        same_padding = False
        pads = None
        if auto_pad is not None and auto_pad != 'NOTSET':
            if auto_pad == 'SAME_LOWER' or auto_pad == 'SAME_UPPER':
                same_padding = True
            elif auto_pad == 'VALID':
                pads = [0, 0, 0, 0]
            else:
                print("Unrecognized auto_pad value: " + str(auto_pad))
                exit(1)

        # The kernel shape supplies the output channel count in every mode.
        kernel_value_info = helper.find_value_by_name(g, node.input[1])
        _, kernel_shape = helper.find_size_shape_from_value(kernel_value_info)
        if not kernel_shape:
            continue

        spatial_rank = len(input_shape) - 2
        strides_attr = helper.get_attribute_by_name(node, 'strides')
        dilations_attr = helper.get_attribute_by_name(node, 'dilations')
        strides = strides_attr.ints \
            if strides_attr is not None else [1] * spatial_rank
        dilation = dilations_attr.ints \
            if dilations_attr is not None else [1] * spatial_rank

        # Pytorch model has the case where strides only have one number.
        # Expand it and keep going instead of leaving the whole pass.
        if spatial_rank == 2:
            if len(strides) == 1:
                strides.append(strides[0])
            if len(dilation) == 1:
                dilation.append(dilation[0])

        if same_padding:
            if spatial_rank < 1 or len(strides) != spatial_rank:
                continue
            # -(-a // b) is ceil(a / b) in integer arithmetic.
            spatial = [
                -(-size // stride)
                for size, stride in zip(input_shape[2:], strides)
            ]
            output_shape = [input_shape[0], kernel_shape[0]] + spatial
        else:
            if not pads:
                pads_attr = helper.get_attribute_by_name(node, 'pads')
                pads = pads_attr.ints \
                    if pads_attr is not None else [0] * (2 * spatial_rank)
            # The explicit formula below only covers 2-D convolutions.
            if len(input_shape) != 4 or len(kernel_shape) != 4:
                continue
            if len(strides) != 2 or len(dilation) != 2 or len(pads) != 4:
                continue
            # Integer floor division equals floor(x / s + 1).
            H = (input_shape[2] + pads[0] + pads[2]
                 - dilation[0] * (kernel_shape[2] - 1) - 1) // strides[0] + 1
            W = (input_shape[3] + pads[1] + pads[3]
                 - dilation[1] * (kernel_shape[3] - 1) - 1) // strides[1] + 1
            output_shape = [input_shape[0], kernel_shape[0], H, W]

        new_output_value_info = onnx.helper.make_tensor_value_info(
            node.output[0],
            input_value_info.type.tensor_type.elem_type,
            output_shape
        )

        processed = True

        if stale_value_info:
            g.value_info.remove(stale_value_info)
        g.value_info.extend([new_output_value_info])

    return processed
コード例 #11
0
ファイル: fusing.py プロジェクト: yiweichen04/ONNX_Convertor
def fuse_mul_and_add_into_gemm(g):
    """Fuse a constant Mul followed by a constant Add into one Gemm node.

    The pattern is an Add whose first input is produced by a Mul, where the
    second input of each is produced by a Constant node. Fusion happens
    only when the Mul input has shape ``[1, dim]`` and both constants are
    one-dimensional of length ``dim``. The Gemm weight is a ``dim x dim``
    matrix with the Mul constant on its diagonal. The Add constant gets a
    leading dimension of 1 prepended in place and becomes the Gemm bias.

    :param g: the graph
    """
    node_to_del = []
    for add_node in g.node:
        if add_node.op_type != 'Add':
            continue

        # Match Mul(x, Constant) -> Add(., Constant).
        mul_node = helper.find_node_by_output_name(g, add_node.input[0])
        if not (mul_node and mul_node.op_type == 'Mul'):
            continue
        mul_const = helper.find_node_by_output_name(g, mul_node.input[1])
        if not (mul_const and mul_const.op_type == 'Constant'):
            continue
        add_const = helper.find_node_by_output_name(g, add_node.input[1])
        if not (add_const and add_const.op_type == 'Constant'):
            continue

        # The data input needs a known shape of exactly [1, dim].
        data_name = mul_node.input[0]
        input_val = helper.find_value_by_name(g, data_name)
        if not input_val:
            input_val = helper.find_input_by_name(g, data_name)
        if not input_val:
            continue
        _, input_shape = helper.find_size_shape_from_value(input_val)
        if not input_shape:
            continue
        dim = int(np.prod(input_shape))
        if input_shape != [1, dim]:
            continue

        # Both constants must be vectors of length dim.
        mul_const_shape, mul_const_data = helper.constant_to_list(mul_const)
        add_const_shape, _unused = helper.constant_to_list(add_const)
        if len(mul_const_shape) != 1 or mul_const_shape[0] != dim:
            continue
        if len(add_const_shape) != 1 or add_const_shape[0] != dim:
            continue

        # Diagonal weight matrix carrying the per-element multipliers.
        weight_matrix = np.zeros([dim, dim])
        for idx in range(dim):
            weight_matrix[idx, idx] = mul_const_data[idx]
        weight_name = mul_const.output[0]
        b_tensor = onnx.helper.make_tensor(
            name=mul_const.name + '_tensor',
            data_type=mul_const.attribute[0].t.data_type,
            dims=[dim, dim],
            vals=weight_matrix.flatten().tolist())
        b_const_node = onnx.helper.make_node('Constant', [],
                                             [weight_name],
                                             value=b_tensor,
                                             name=weight_name)

        # Prepend a leading dimension of 1 to the bias constant.
        add_const.attribute[0].t.dims.insert(0, 1)

        fused_name = add_node.output[0]
        gemm_node = onnx.helper.make_node(
            'Gemm',
            [data_name, b_const_node.output[0], add_const.output[0]],
            [fused_name],
            name=fused_name)

        g.node.extend([gemm_node, b_const_node])
        node_to_del.extend([mul_const, mul_node, add_node])

        # Drop value_info entries recorded for the intermediate value and
        # the two constants.
        stale_infos = [
            helper.find_value_by_name(g, stale_name)
            for stale_name in (mul_node.output[0], mul_const.output[0],
                               add_const.output[0])
        ]
        for stale_info in stale_infos:
            if stale_info:
                g.value_info.remove(stale_info)

    # Remove the replaced nodes, last collected first.
    for doomed in reversed(node_to_del):
        g.node.remove(doomed)

    topological_sort(g)
コード例 #12
0
ファイル: replacing.py プロジェクト: kneron/ONNX_Convertor
def replace_add_to_bn(g):
    """Replace single Add node with Batchnorm node.

    An Add node is replaced when it has exactly two inputs, its second
    input is produced by a Constant node and a shape is recorded for its
    first input. The constant must either have shape ``[1, C, 1, 1]``,
    where ``C`` is the channel dimension of the first input, or be reported
    by ``helper.constant_to_list`` with a shape equal to ``1``.

    The resulting BatchNormalization has unit scale, zero mean, unit
    variance, ``epsilon=1e-8`` and the added constant as bias. A
    single-element constant is repeated ``C`` times.

    :param g: input graph.
    :return:
    """
    node_to_del = []
    for node in g.node:
        if node.op_type != 'Add':
            continue

        add_op_node = node

        # only support one input node
        if len(add_op_node.input) != 2:  # OP node and value node
            continue

        input_op_node_name = add_op_node.input[0]
        add_value_node = helper.find_node_by_output_name(
            g, add_op_node.input[1])
        if not add_value_node or add_value_node.op_type != 'Constant':
            continue

        prev_shape_value_info = helper.find_value_by_name(
            g, input_op_node_name)
        if prev_shape_value_info is None:
            prev_shape_value_info = helper.find_input_by_name(
                g, input_op_node_name)
        if prev_shape_value_info is None:
            continue

        _, previous_node_output_shape = helper.find_size_shape_from_value(
            prev_shape_value_info)
        if previous_node_output_shape is None:
            continue
        bias_shape, bias_data = helper.constant_to_list(add_value_node)

        # channel dimension
        c_dim = previous_node_output_shape[1] if len(
            previous_node_output_shape) > 1 else 1

        # only allow channelwise add or const add
        if bias_shape != [1, c_dim, 1, 1] and bias_shape != 1:
            continue

        ones = [1.0] * c_dim
        zeros = [0.0] * c_dim
        # If bias is a scalar, expand it.
        if len(bias_data) == 1:
            bias = bias_data * c_dim
        else:
            bias = bias_data
        bn_name = add_op_node.output[0]
        mean_value_node = helper.list_to_constant(bn_name + '_mean',
                                                  np.array(zeros).shape, zeros)
        variance_value_node = helper.list_to_constant(bn_name + '_var',
                                                      np.array(ones).shape,
                                                      ones)
        scale_value_node = helper.list_to_constant(bn_name + '_mul',
                                                   np.array(ones).shape, ones)
        new_add_value_node = helper.list_to_constant(bn_name + '_add',
                                                     np.array(bias).shape,
                                                     bias)

        bn_node = onnx.helper.make_node('BatchNormalization', [
            input_op_node_name, scale_value_node.output[0],
            new_add_value_node.output[0], mean_value_node.output[0],
            variance_value_node.output[0]
        ], [add_op_node.output[0]],
                                        name=bn_name,
                                        epsilon=0.00000001)

        # The old constant may have no value_info entry; only remove one
        # that actually exists.
        add_val_info = helper.find_value_by_name(g, add_value_node.output[0])
        if add_val_info is not None:
            g.value_info.remove(add_val_info)

        g.node.extend([
            bn_node, mean_value_node, variance_value_node, scale_value_node,
            new_add_value_node
        ])

        node_to_del.extend([add_op_node, add_value_node])

    while node_to_del:
        g.node.remove(node_to_del.pop())

    topological_sort(g)
コード例 #13
0
def replace_mul_to_bn(g):
    """Replace single Mul node with Batchnorm node.

    A Mul node is replaced when it has exactly two inputs, its second input
    is produced by a Constant node and its first input has a recorded 4-D
    shape. The constant must either have shape ``[1, C, 1, 1]``, where
    ``C`` is the channel dimension of the first input, or be reported by
    ``helper.constant_to_list`` with a shape equal to ``1``.

    The resulting BatchNormalization has zero bias, zero mean, unit
    variance, ``epsilon=1e-8`` and the multiplier as scale. A
    single-element constant is repeated ``C`` times; a channel-wise
    constant is used as is.

    :param g: input graph.
    :return:
    """
    node_to_del = []
    for node in g.node:
        if node.op_type != 'Mul':
            continue

        mul_op_node = node

        # only support one input node
        if len(mul_op_node.input) != 2:  # OP node and value node
            continue

        input_op_node_name = mul_op_node.input[0]
        mul_value_node = helper.find_node_by_output_name(
            g, mul_op_node.input[1])
        if not mul_value_node or mul_value_node.op_type != 'Constant':
            continue

        prev_shape_value_info = helper.find_value_by_name(
            g, input_op_node_name)
        if prev_shape_value_info is None:
            prev_shape_value_info = helper.find_input_by_name(
                g, input_op_node_name)
        if prev_shape_value_info is None:
            continue

        _, previous_node_output_shape = helper.find_size_shape_from_value(
            prev_shape_value_info)
        scale_shape, scale_data = helper.constant_to_list(mul_value_node)

        # only allow 4 dim data input due to the hardware limitation
        if not previous_node_output_shape or \
                len(previous_node_output_shape) != 4:
            continue

        # channel dimension
        c_dim = previous_node_output_shape[1]

        # only allow channelwise mul or const mul
        if scale_shape != [1, c_dim, 1, 1] and scale_shape != 1:
            continue

        ones = [1.0] * c_dim
        zeros = [0.0] * c_dim
        # Only a scalar multiplier is expanded. A channel-wise one already
        # has one entry per channel and must not be repeated.
        if len(scale_data) == 1:
            muls = scale_data * c_dim
        else:
            muls = scale_data
        bn_name = mul_op_node.output[0]
        mean_value_node = helper.list_to_constant(bn_name + '_mean',
                                                  np.array(zeros).shape, zeros)
        variance_value_node = helper.list_to_constant(bn_name + '_var',
                                                      np.array(ones).shape,
                                                      ones)
        bias_value_node = helper.list_to_constant(bn_name + '_add',
                                                  np.array(zeros).shape, zeros)
        new_mul_value_node = helper.list_to_constant(bn_name + '_mul',
                                                     np.array(muls).shape,
                                                     muls)

        bn_node = onnx.helper.make_node('BatchNormalization', [
            input_op_node_name, new_mul_value_node.output[0],
            bias_value_node.output[0], mean_value_node.output[0],
            variance_value_node.output[0]
        ], [mul_op_node.output[0]],
                                        name=bn_name,
                                        epsilon=0.00000001)

        # Either entry may be absent from value_info; only remove the ones
        # that actually exist.
        mid_val_info = helper.find_value_by_name(g, mul_op_node.output[0])
        scale_val_info = helper.find_value_by_name(g, mul_value_node.output[0])
        if mid_val_info is not None:
            g.value_info.remove(mid_val_info)
        if scale_val_info is not None:
            g.value_info.remove(scale_val_info)

        g.node.extend([
            bn_node, mean_value_node, variance_value_node, bias_value_node,
            new_mul_value_node
        ])

        node_to_del.extend([mul_op_node, mul_value_node])

    while node_to_del:
        g.node.remove(node_to_del.pop())

    topological_sort(g)