Example #1
def inference_resize_shape(g):
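    """Infer the output shape of Resize nodes from a constant 'sizes' input.

    :param g: the graph\\
    :return: True if any Resize output shape is generated. Otherwise, False.
    """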
    for node in g.node:
        if node.op_type != 'Resize':
            continue

        output_value = helper.find_value_by_name(g, node.output[0])
        output_value = helper.find_output_by_name(
            g, node.output[0]) if output_value is None else output_value
        if output_value is not None:
            continue

        # Currently, only the 4-input form is supported.
        if len(node.input) == 4:  # input: X, roi, scales, sizes
            shape_node = helper.find_node_by_output_name(g, node.input[3])
            if shape_node is None or shape_node.op_type != 'Constant':
                continue

            _, shape_value = helper.constant_to_list(shape_node)
            output_value = onnx.helper.make_tensor_value_info(
                node.output[0], onnx.TensorProto.FLOAT,
                [int(v) for v in shape_value])
            g.value_info.extend([output_value])

            return True
    return False
Example #2
def replace_Squeeze_with_Reshape(g):
    """
    Replace Squeeze nodes with Reshape nodes.

    :param g: the input graph
    """
    node_to_remove = []
    for node in g.node:
        # Find Squeeze node
        if node.op_type != 'Squeeze':
            continue
        # Get the output shape and construct the shape constant
        output_value = helper.find_value_by_name(g, node.output[0])
        if output_value is None:
            output_value = helper.find_output_by_name(g, node.output[0])
        if output_value is None:
            raise RuntimeError("Cannot get shape for Squeeze")
        shape = [
            dim.dim_value for dim in output_value.type.tensor_type.shape.dim
        ]
        const_node = helper.list_to_constant(node.name + "_shape",
                                             [len(shape)], shape)
        # Construct the Reshape layer with same input, output and name.
        new_node = onnx.helper.make_node("Reshape",
                                         [node.input[0], node.name + "_shape"],
                                         node.output,
                                         name=node.name)
        # Append the constructed nodes and add the old node to the removal list
        g.node.extend([const_node, new_node])
        node_to_remove.append(node)
    # Remove old nodes
    for node in node_to_remove:
        g.node.remove(node)
    # Topological sort
    topological_sort(g)
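A minimal usage sketch for passes like the one above, assuming they live in the same module as the `helper` and `topological_sort` utilities they call and are applied to the `graph` field of a loaded ONNX model (the paths are placeholders):

import onnx

# Load a model, rewrite its graph in place, then validate and save it.
model = onnx.load("model.onnx")
replace_Squeeze_with_Reshape(model.graph)
onnx.checker.check_model(model)
onnx.save(model, "model_modified.onnx")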
Example #3
def add_nop_bn_after(g, value_names):
    """Add do-nothing BatchNormalization nodes after the given value info. It will\\
    take the given names as the inputs of the new node and replace the inputs\\
    of the following nodes.

    :param g: the graph\\
    :param value_names: a list of strings which are the names of value_info.
    """
    for value_name in value_names:
        # Find the value first
        value = helper.find_value_by_name(g, value_name)
        if value is None:
            value = helper.find_input_by_name(g, value_name)
        if value is None:
            value = helper.find_output_by_name(g, value_name)
        if value is None:
            print("Cannot find an value_info named {}".format(value_name))
            continue
        # Get the channel number from value info
        shape = helper.get_shape_from_value_info(value)
        channel = shape[1]
        # Construct 4 weights
        node_name = value_name + "_nop_bn"
        ones = [1.0] * channel
        zeros = [0.0] * channel
        scale_node = helper.list_to_constant(node_name + "_scale", [channel],
                                             ones)
        bias_node = helper.list_to_constant(node_name + "_bias", [channel],
                                            zeros)
        mean_node = helper.list_to_constant(node_name + "_mean", [channel],
                                            zeros)
        var_node = helper.list_to_constant(node_name + "_var", [channel], ones)
        # Construct BN node
        bn_node = onnx.helper.make_node("BatchNormalization", [
            value_name, scale_node.output[0], bias_node.output[0],
            mean_node.output[0], var_node.output[0]
        ], [node_name],
                                        name=node_name)
        # Reconnect the graph
        following_nodes = helper.find_following_nodes_by_input_value_name(
            g, value_name)
        if len(following_nodes) > 0:
            for following_node in following_nodes:
                replace_node_input(following_node, value_name, node_name)
        else:
            # If the node is the output, replace the output with the previous input.
            new_value = onnx.helper.make_tensor_value_info(
                node_name, value.type.tensor_type.elem_type, shape)
            output_values = []
            while len(g.output):
                output_values.append(g.output.pop())
            while output_values:
                output_value = output_values.pop()
                if output_value.name == value_name:
                    g.output.extend([new_value])
                else:
                    g.output.extend([output_value])
        # Add node to the graph
        g.node.extend([bn_node, scale_node, bias_node, mean_node, var_node])
    topological_sort(g)
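Why these weights make the BatchNormalization a no-op: BN computes scale * (x - mean) / sqrt(var + epsilon) + bias, so with scale = 1, bias = 0, mean = 0, and var = 1 the output is x / sqrt(1 + epsilon), i.e. identity up to the default epsilon of 1e-5. A quick numpy check of that claim (plain numpy, standing in for a runtime):

import numpy as np

x = np.random.randn(1, 3, 4, 4).astype(np.float32)
scale, bias, mean, var = 1.0, 0.0, 0.0, 1.0
epsilon = 1e-5  # the ONNX default
y = scale * (x - mean) / np.sqrt(var + epsilon) + bias
# y differs from x only by the 1/sqrt(1 + epsilon) factor.
assert np.allclose(y, x, atol=1e-4)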
Example #4
def rename_all_node_name(g):
    """
    rename all nodes:

        new_name = old_name + "_kn"

    :param g: the onnx graph
    """

    for node in g.node:
        new_node_name = node.name + "_kn"
        new_node_output0_name = node.output[0] + "_kn"

        # To keep the graph output names unchanged, skip output nodes.
        output_value_info = helper.find_output_by_name(g, node.output[0])
        if output_value_info is not None:
            continue

        # Rename the inputs of all the following nodes
        following_nodes = helper.find_following_nodes_by_input_value_name(
            g, node.output[0])
        for following_node in following_nodes:
            replace_node_input(following_node, node.output[0],
                               new_node_output0_name)

        # rename value info
        value_info = helper.find_value_by_name(g, node.output[0])
        if value_info is not None:
            value_info.name = new_node_output0_name

        # rename node
        node.output[0] = new_node_output0_name
        node.name = new_node_name
Example #5
def find_first_sequential_outputs(g, node):
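    """Find the first graph output reachable from the given node, following\\
    the first consumer of its first output when the node's outputs are not\\
    graph outputs themselves.

    :param g: the graph\\
    :param node: the node to start from\\
    :return: the output value_info found.
    """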
    for value_name in node.output:
        value = helper.find_output_by_name(g, value_name)
        if value is not None:
            return value
    return find_first_sequential_outputs(
        g,
        helper.find_nodes_by_input_name(g, node.output[0])[0])
Example #6
def add_nop_conv_after(g, value_names):
    """Add do-nothing depthwise Conv nodes after the given value info. It will\\
    take the given names as the inputs of the new node and replace the inputs\\
    of the following nodes.

    :param g: the graph\\
    :param value_names: a list of strings which are the names of value_info.
    """
    for value_name in value_names:
        # Find the value first
        value = helper.find_value_by_name(g, value_name)
        if value is None:
            value = helper.find_input_by_name(g, value_name)
        if value is None:
            value = helper.find_output_by_name(g, value_name)
        if value is None:
            print("Cannot find an value_info named {}".format(value_name))
            continue
        # Get the channel number from value info
        shape = helper.get_shape_from_value_info(value)
        channel = shape[1]
        # Construct the identity weight
        node_name = value_name + "_nop_conv"
        ones = [1.0] * channel
        weight_node = helper.list_to_constant(node_name + "_weight",
                                              [channel, 1, 1, 1], ones)
        # Construct Conv node
        conv_node = onnx.helper.make_node("Conv",
                                          [value_name, weight_node.output[0]],
                                          [node_name],
                                          name=node_name,
                                          dilations=[1, 1],
                                          group=channel,
                                          kernel_shape=[1, 1],
                                          pads=[0, 0, 0, 0],
                                          strides=[1, 1])
        # Reconnect the graph
        following_nodes = helper.find_following_nodes_by_input_value_name(
            g, value_name)
        if len(following_nodes) > 0:
            for following_node in following_nodes:
                replace_node_input(following_node, value_name, node_name)
        else:
            # If the node is the output, replace the output with the previous input.
            new_value = onnx.helper.make_tensor_value_info(
                node_name, value.type.tensor_type.elem_type, shape)
            output_values = []
            while len(g.output):
                output_values.append(g.output.pop())
            while output_values:
                output_value = output_values.pop()
                if output_value.name == value_name:
                    g.output.extend([new_value])
                else:
                    g.output.extend([output_value])
        # Add node to the graph
        g.node.extend([conv_node, weight_node])
    topological_sort(g)
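The Conv above is an identity because group equals the channel count, so each group convolves one input channel with a single 1x1 kernel whose weight is 1.0. A numpy sketch of that per-channel computation (plain numpy, standing in for a runtime):

import numpy as np

x = np.random.randn(1, 3, 4, 4).astype(np.float32)
weight = np.ones((3, 1, 1, 1), dtype=np.float32)  # shape [channel, 1, 1, 1]
# With group == channel, output channel c is x[:, c] scaled by weight[c].
y = np.stack([x[:, c] * weight[c, 0, 0, 0] for c in range(3)], axis=1)
assert np.array_equal(y, x)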
Example #7
def inference_cov_shape(g):
    processed = False
    for node in g.node:
        if node.op_type != 'Conv':
            continue
        input_value_info = helper.find_value_by_name(g, node.input[0])
        if not input_value_info:
            input_value_info = helper.find_input_by_name(g, node.input[0])
        if not input_value_info:
            continue

        kernel_value_info = helper.find_value_by_name(g, node.input[1])
        output_value_info = helper.find_value_by_name(g, node.output[0])
        if not output_value_info:
            output_value_info = helper.find_output_by_name(g, node.output[0])

        if output_value_info and \
            helper.get_shape_from_value_info(output_value_info):
            continue

        _, kernel_shape = helper.find_size_shape_from_value(kernel_value_info)
        _, input_shape = helper.find_size_shape_from_value(input_value_info)
        if not input_shape or not kernel_shape:
            continue
        strides = helper.get_attribute_by_name(node, 'strides').ints
        pads = helper.get_attribute_by_name(node, 'pads').ints
        dilation = helper.get_attribute_by_name(node, 'dilations').ints

        # PyTorch models may carry strides/dilations with a single number.
        if len(strides) == 1:
            strides.append(strides[0])
        if len(dilation) == 1:
            dilation.append(dilation[0])

        H = math.floor((input_shape[2]+pads[0]+pads[2]-\
            dilation[0]*(kernel_shape[2]-1)-1)/strides[0]+1)
        W = math.floor((input_shape[3]+pads[1]+pads[3]-\
            dilation[1]*(kernel_shape[3]-1)-1)/strides[1]+1)
        output_shape = [input_shape[0], kernel_shape[0], H, W]

        new_output_value_info = onnx.helper.make_tensor_value_info(
            node.output[0], input_value_info.type.tensor_type.elem_type,
            output_shape)

        processed = True

        if output_value_info:
            g.value_info.remove(output_value_info)
        g.value_info.extend([new_output_value_info])

    return processed
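The spatial dimensions above follow the standard convolution output formula, floor((in + pad_begin + pad_end - dilation * (kernel - 1) - 1) / stride + 1). For example, a 224x224 input with a 7x7 kernel, pads of 3 on each side, stride 2, and dilation 1:

import math

H = math.floor((224 + 3 + 3 - 1 * (7 - 1) - 1) / 2 + 1)
W = math.floor((224 + 3 + 3 - 1 * (7 - 1) - 1) / 2 + 1)
print(H, W)  # 112 112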
Example #8
def rename_output_name(g, original_name, new_name):
    # Output
    output_value = helper.find_output_by_name(g, original_name)
    if output_value is None:
        logging.error("Cannot find output value named " + original_name)
        return
    output_value.name = new_name
    # Value Info
    value_info = helper.find_value_by_name(g, original_name)
    if value_info is not None:
        value_info.name = new_name
    # Node output. Rename the matching outputs of the producing node.
    node = helper.find_node_by_output_name(g, original_name)
    for i, name in enumerate(node.output):
        if name == original_name:
            node.output[i] = new_name
    # Node input
    nodes = helper.find_nodes_by_input_name(g, original_name)
    for node in nodes:
        replace_node_input(node, original_name, new_name)
Example #9
def inference_upsample_shape(g):
    """For onnx v1.4.1+, onnx cannot inference upsample output shape. Let's\\
    do it ourselves. This function only inference the next upsample without\\
    output shape each time.

    :param g: the graph\\
    :return: True if any Upsample shape is generated. Otherwise, False.
    """
    for node in g.node:
        if node.op_type != 'Upsample':
            continue
        output_value = helper.find_value_by_name(g, node.output[0])
        if output_value is None:
            output_value = helper.find_output_by_name(g, node.output[0])
        if output_value and helper.get_shape_from_value_info(output_value):
            continue
        # Get input shape
        input_value = helper.find_value_by_name(g, node.input[0])
        if input_value is None:
            continue
            #raise RuntimeError("Shape for {} has not been generated.".format(node.input[0]))
        if not helper.get_shape_from_value_info(input_value):
            continue
            #raise RuntimeError("Shape for {} is empty.".format(node.input[0]))
        input_shape = helper.get_shape_from_value_info(input_value)
        # Get upsample weight
        weight_node = helper.find_node_by_output_name(g, node.input[1])
        if weight_node is None:
            continue
        weight_shape, weight = helper.constant_to_list(weight_node)
        if len(input_shape) != weight_shape[0]:
            raise RuntimeError(
                "Mismatched input shape and weight shape: {} vs {}".format(
                    input_shape, weight_shape))
        # Calculate shape
        output_shape = list(input_shape)
        for i in range(len(output_shape)):
            output_shape[i] = int(input_shape[i] * weight[i])
        output_value = onnx.helper.make_tensor_value_info(
            node.output[0], input_value.type.tensor_type.elem_type,
            output_shape)
        g.value_info.extend([output_value])
        return True
    return False
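For example, with an input shape of [1, 3, 32, 32] and a scales constant of [1, 1, 2, 2], the loop above produces [1, 3, 64, 64]:

input_shape = [1, 3, 32, 32]
weight = [1.0, 1.0, 2.0, 2.0]  # the 'scales' constant
output_shape = [int(d * s) for d, s in zip(input_shape, weight)]
print(output_shape)  # [1, 3, 64, 64]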
Example #10
def change_output_shape(g, target_list):
    for target in target_list:
        try:
            name, shape = parse_shape_change_input(target)
            output_value = helper.find_output_by_name(g, name)
            if output_value is None:
                print("Cannot find output {}".format(name))
                continue
            if len(shape) != len(output_value.type.tensor_type.shape.dim):
                print("The dimension doesn't match for output {}".format(name))
                continue
            for i in range(len(shape)):
                output_value.type.tensor_type.shape.dim[i].dim_value = shape[i]
        except TypeError:
            # This happens when the parser function returns None.
            continue
        except ValueError:
            # This happens when the input cannot be converted into int
            print("Cannot parse {} into name and int".format(target))
            continue
Example #11
def inference_resize_shape(g):
    for node in g.node:
        if node.op_type != 'Resize':
            continue

        output_value = helper.find_value_by_name(g, node.output[0])
        output_value = helper.find_output_by_name(
            g, node.output[0]) if output_value is None else output_value
        if output_value is not None:
            continue

        if len(node.input) == 4:  # input: X, roi, scales, sizes
            shape_node = helper.find_node_by_output_name(g, node.input[3])
            if shape_node is None or shape_node.op_type != 'Constant':
                continue

            _, shape_value = helper.constant_to_list(shape_node)
            output_value = onnx.helper.make_tensor_value_info(
                node.output[0], onnx.TensorProto.FLOAT,
                [int(v) for v in shape_value])
            g.value_info.extend([output_value])
            return True
        else:
            # If the output shape is not given, infer it from the scales
            # Get the input shape
            input_value = helper.find_value_by_name(g, node.input[0])
            if input_value is None:
                continue
            shape_value = helper.get_shape_from_value_info(input_value)
            scales_node = helper.find_node_by_output_name(g, node.input[2])
            if scales_node is None or scales_node.op_type != 'Constant':
                continue
            _, scales_value = helper.constant_to_list(scales_node)
            for i in range(len(shape_value)):
                shape_value[i] *= scales_value[i]
            output_value = onnx.helper.make_tensor_value_info(
                node.output[0], onnx.TensorProto.FLOAT,
                [int(v) for v in shape_value])
            g.value_info.extend([output_value])
            return True
    return False
Example #12
def remove_nodes(g, cut_nodes=[], cut_types=[]):
    node_to_delete = []
    # Find target nodes
    for node in g.node:
        if node.name in cut_nodes or node.op_type in cut_types:
            node_to_delete.append(node)
    # Mapping outputs
    output_mapping = {}
    new_output = set()
    for node in node_to_delete:
        original_output = find_first_sequential_outputs(g, node)
        if original_output.name not in output_mapping:
            output_mapping[original_output.name] = []
        for input_name in node.input:
            value = helper.find_value_by_name(g, input_name)
            if value is not None and helper.find_output_by_name(g, input_name) is None and value.name not in new_output:
                output_mapping[original_output.name].append(value)
                new_output.add(value.name)
    # Remove them
    while node_to_delete:
        g.node.remove(node_to_delete.pop())
    # Remove unreachable nodes
    visited_values = set()
    unused_constant_map = {}
    for input_value in g.input:
        visited_values.add(input_value.name)
    for node in g.node:
        if node.op_type == 'Constant':
            visited_values.add(node.output[0])
            unused_constant_map[node.output[0]] = node
            continue
        can_reach = True
        for input_name in node.input:
            if input_name not in visited_values:
                can_reach = False
                break
        if can_reach:
            for output_name in node.output:
                visited_values.add(output_name)
        else:
            node_to_delete.append(node)
    # Mapping outputs again
    for node in node_to_delete:
        original_output = find_first_sequential_outputs(g, node)
        if original_output.name not in output_mapping:
            output_mapping[original_output.name] = []
        for input_name in node.input:
            value = helper.find_value_by_name(g, input_name)
            if value is not None and helper.find_output_by_name(g, input_name) is None and value.name not in new_output:
                output_mapping[original_output.name].append(value)
                new_output.add(value.name)
    # Remove them
    while node_to_delete:
        g.node.remove(node_to_delete.pop())
    # Remove unused constants
    for node in g.node:
        for input_name in node.input:
            if input_name in unused_constant_map:
                del unused_constant_map[input_name]
    for node in unused_constant_map.values():
        g.node.remove(node)
    # Remove unreachable value infos
    reachable_values = set()
    for input_value in g.input:
        reachable_values.add(input_value.name)
    for node in g.node:
        for input_name in node.input:
            reachable_values.add(input_name)
        for output_name in node.output:
            reachable_values.add(output_name)
    value_to_remove = []
    for value_info in g.value_info:
        if value_info.name not in reachable_values:
            value_to_remove.append(value_info)
    while value_to_remove:
        value_info = value_to_remove.pop()
        g.value_info.remove(value_info)
    # Reorder output
    output_values = []
    while len(g.output):
        output_values.append(g.output.pop())
    while output_values:
        output_value = output_values.pop()
        if output_value.name in reachable_values:
            logging.info("Keep output {}".format(output_value.name))
            g.output.extend([output_value])
        elif output_value.name in output_mapping:
            real_outputs = [i for i in output_mapping[output_value.name] if i.name in reachable_values]
            logging.info("Replace output {} with {}".format(output_value.name, [i.name for i in real_outputs]))
            g.output.extend(real_outputs)
        else:
            logging.info("Abandon output {}".format(output_value.name))
            continue
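A hypothetical call, cutting one node by name and every Dropout node by type (the node name and paths here are placeholders):

import onnx

model = onnx.load("model.onnx")
remove_nodes(model.graph, cut_nodes=["postprocess_node"], cut_types=["Dropout"])
onnx.save(model, "model_cut.onnx")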
Example #13
def inference_cov_shape(g):
    processed = False
    for node in g.node:
        # Check for Conv nodes whose output shape needs to be inferred.
        if node.op_type != 'Conv':
            continue
        # Input shape is not ready yet. Skip.
        input_value_info = helper.find_value_by_name(g, node.input[0])
        if not input_value_info:
            input_value_info = helper.find_input_by_name(g, node.input[0])
        if not input_value_info:
            continue
        _, input_shape = helper.find_size_shape_from_value(input_value_info)
        if not input_shape:
            continue
        # Output shape is already there. Skip.
        output_value_info = helper.find_value_by_name(g, node.output[0])
        if not output_value_info:
            output_value_info = helper.find_output_by_name(g, node.output[0])
        if output_value_info and \
            helper.get_shape_from_value_info(output_value_info):
            continue

        # Now start the inference.
        # If auto_pad is set, use the auto_pad.
        auto_pad = helper.get_var_attribute_by_name(node, 'auto_pad', 'string')
        pads = None
        if auto_pad is not None and auto_pad != 'NOTSET':
            if auto_pad == 'SAME_LOWER' or auto_pad == 'SAME_UPPER':
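                # Note: with SAME padding the spatial size equals the input
                # size only when the stride is 1; this branch assumes that.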
                new_output_value_info = onnx.helper.make_tensor_value_info(
                    node.output[0],
                    input_value_info.type.tensor_type.elem_type,
                    input_shape
                )
                if output_value_info:
                    g.value_info.remove(output_value_info)
                g.value_info.extend([new_output_value_info])
                processed = True
                continue
            elif auto_pad == 'VALID':
                pads = [0, 0, 0, 0]
            else:
                print("Unrecognized auto_pad value: " + str(auto_pad))
                exit(1)
        kernel_value_info = helper.find_value_by_name(g, node.input[1])
        _, kernel_shape = helper.find_size_shape_from_value(kernel_value_info)
        if not input_shape or not kernel_shape:
            continue
        strides = helper.get_attribute_by_name(node, 'strides').ints
        if not pads:
            pads = helper.get_attribute_by_name(node, 'pads').ints
        dilation = helper.get_attribute_by_name(node, 'dilations').ints

        # PyTorch models may carry strides/dilations with a single number.
        if len(strides) == 1:
            strides.append(strides[0])
        if len(dilation) == 1:
            dilation.append(dilation[0])

        H = math.floor((input_shape[2]+pads[0]+pads[2]-\
            dilation[0]*(kernel_shape[2]-1)-1)/strides[0]+1)
        W = math.floor((input_shape[3]+pads[1]+pads[3]-\
            dilation[1]*(kernel_shape[3]-1)-1)/strides[1]+1)
        output_shape = [input_shape[0], kernel_shape[0], H, W]

        new_output_value_info = onnx.helper.make_tensor_value_info(
            node.output[0],
            input_value_info.type.tensor_type.elem_type,
            output_shape
        )

        processed = True

        if output_value_info:
            g.value_info.remove(output_value_info)
        g.value_info.extend([new_output_value_info])

    return processed
Example #14
def fuse_BN_with_Reshape_into_Gemm(g):
    """Fuse the following BN into the previous Gemm, even with Reshape or \\
        Squeeze and Unsqueeze surrounding.

    :param g: the graph
    """
    node_to_remove = []
    for node in g.node:
        # Check for BN and Gemm pattern: Gemm A BN B
        # Find BatchNorm Node
        if node.op_type != 'BatchNormalization':
            continue
        bn_node = node
        # Find A Node
        a_node = helper.find_node_by_output_name(g, node.input[0])
        if a_node is None or len(a_node.input) == 0:
            continue
        # Find Gemm Node
        gemm_node = helper.find_node_by_output_name(g, a_node.input[0])
        if gemm_node is None or gemm_node.op_type != 'Gemm':
            continue
        # Find B Node
        b_node_list = helper.find_following_nodes_by_input_value_name(
            g, bn_node.output[0])
        if len(b_node_list) == 0:
            the_output = helper.find_output_by_name(g, bn_node.output[0])
            if the_output is None:
                continue
            b_node = None
        elif len(b_node_list) > 1:
            continue
        else:
            b_node = b_node_list[0]
        # Check for branches
        if len(
                helper.find_following_nodes_by_input_value_name(
                    g, gemm_node.output[0])) > 1:
            continue
        if len(
                helper.find_following_nodes_by_input_value_name(
                    g, a_node.output[0])) > 1:
            continue
        # Check type of A
        if a_node.op_type == 'Unsqueeze':
            axes = helper.get_attribute_by_name(a_node, 'axes')
            if axes.ints != [2]:
                continue
        elif a_node.op_type == 'Reshape':
            a = helper.constant_to_list(
                helper.find_node_by_output_name(g, a_node.input[1]))[1]
            if len(a) != 3 or a[2] != 1:
                continue
        else:
            continue
        # Check type of B
        if b_node is None:
            pass
        elif b_node.op_type == 'Flatten':
            pass
        elif b_node.op_type == 'Squeeze':
            axes = helper.get_attribute_by_name(b_node, 'axes')
            if axes.ints != [2]:
                continue
        elif b_node.op_type == 'Reshape':
            a = helper.constant_to_list(
                helper.find_node_by_output_name(g, b_node.input[1]))[1]
            if len(a) != 2:
                continue
        else:
            continue
        # Construct new Nodes
        # Get original weights
        gemm_b_node = helper.find_node_by_output_name(g, gemm_node.input[1])
        gemm_b = helper.constant_to_numpy(gemm_b_node)
        gemm_c_node = helper.find_node_by_output_name(g, gemm_node.input[2])
        gemm_c = helper.constant_to_numpy(gemm_c_node)
        bn_scale_node = helper.find_node_by_output_name(g, bn_node.input[1])
        bn_scale = helper.constant_to_numpy(bn_scale_node)
        bn_bias_node = helper.find_node_by_output_name(g, bn_node.input[2])
        bn_bias = helper.constant_to_numpy(bn_bias_node)
        bn_mean_node = helper.find_node_by_output_name(g, bn_node.input[3])
        bn_mean = helper.constant_to_numpy(bn_mean_node)
        bn_var_node = helper.find_node_by_output_name(g, bn_node.input[4])
        bn_var = helper.constant_to_numpy(bn_var_node)
        # Apply attributes
        # epsilon
        epsilon = helper.get_attribute_by_name(bn_node, 'epsilon')
        if epsilon is None:
            epsilon = 0.00001
        else:
            epsilon = epsilon.f
        bn_var = bn_var + epsilon
        # alpha
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is None:
            alpha = 1
        else:
            alpha = alpha.f
        gemm_b = gemm_b * alpha
        # beta
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is None:
            beta = 1
        else:
            beta = beta.f
        gemm_c = gemm_c * beta
        # transA
        transA = helper.get_attribute_by_name(gemm_node, 'transA')
        if transA is not None and transA.i == 1:
            raise RuntimeError("Do not support transA")
        # transB
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None and transB.i == 1:
            gemm_b = gemm_b.transpose()
        # Calculate new weights
        new_gemm_b = gemm_b * bn_scale / np.sqrt(bn_var)
        new_gemm_c = (gemm_c - bn_mean) * bn_scale / np.sqrt(bn_var) + bn_bias
        # Replace original weights
        new_gemm_b_node = helper.numpy_to_constant(gemm_b_node.name + '_fused',
                                                   new_gemm_b)
        new_gemm_c_node = helper.numpy_to_constant(gemm_c_node.name + '_fused',
                                                   new_gemm_c)
        g.node.extend([new_gemm_b_node, new_gemm_c_node])
        # Modify attributes
        # alpha
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is not None:
            alpha.f = 1.0
        # beta
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is not None:
            beta.f = 1.0
        # transB
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None:
            transB.i = 0
        # Remove useless nodes
        node_to_remove.extend([
            gemm_b_node, gemm_c_node, bn_node, bn_scale_node, bn_bias_node,
            bn_mean_node, bn_var_node, a_node
        ])
        if a_node.op_type == 'Reshape':
            node_to_remove.append(
                helper.find_node_by_output_name(g, a_node.input[1]))
        if b_node is not None:
            node_to_remove.append(b_node)
            if b_node.op_type == 'Reshape':
                node_to_remove.append(
                    helper.find_node_by_output_name(g, b_node.input[1]))
        # Delete useless value infos
        value = helper.find_value_by_name(g, a_node.output[0])
        g.value_info.remove(value)
        if a_node.op_type == 'Reshape':
            value = helper.find_value_by_name(g, a_node.input[1])
            g.value_info.remove(value)
        for i in range(1, 5):
            value = helper.find_value_by_name(g, bn_node.input[i])
            g.value_info.remove(value)
        value = helper.find_value_by_name(g, bn_node.output[0])
        if value is not None:
            g.value_info.remove(value)
        if b_node is not None:
            value = helper.find_value_by_name(g, gemm_node.output[0])
            g.value_info.remove(value)
            if b_node.op_type == 'Reshape':
                value = helper.find_value_by_name(g, b_node.input[1])
                g.value_info.remove(value)
        # Connect the new graph
        # Connect Gemm new weights
        gemm_node.input[1] = new_gemm_b_node.output[0]
        gemm_node.input[2] = new_gemm_c_node.output[0]
        gemm_b_value = helper.find_value_by_name(g, gemm_b_node.output[0])
        gemm_c_value = helper.find_value_by_name(g, gemm_c_node.output[0])
        gemm_b_value.name = new_gemm_b_node.output[0]
        gemm_b_value.type.tensor_type.shape.dim[
            0].dim_value = new_gemm_b.shape[0]
        gemm_b_value.type.tensor_type.shape.dim[
            1].dim_value = new_gemm_b.shape[1]
        gemm_c_value.name = new_gemm_c_node.output[0]
        if b_node is None:
            # If b node is None, set the Gemm output as the graph output
            output_value = helper.find_output_by_name(g, bn_node.output[0])
            g.output.remove(output_value)
            g.output.extend(
                [helper.find_value_by_name(g, gemm_node.output[0])])
        else:
            # Else, set node B output as gemm output
            gemm_node.output[0] = b_node.output[0]
    # Remove useless nodes
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)
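The fused weights come from substituting the Gemm result into the BN formula: BN(x @ B + c) = (x @ B + c - mean) * scale / sqrt(var + eps) + bias = x @ (B * scale / sqrt(var + eps)) + ((c - mean) * scale / sqrt(var + eps) + bias), which is exactly new_gemm_b and new_gemm_c above (epsilon is folded into bn_var first). A numpy check of the algebra:

import numpy as np

x = np.random.randn(2, 4)
B = np.random.randn(4, 3)   # Gemm weight, already in transB == 0 layout
c = np.random.randn(3)      # Gemm bias
scale, bias = np.random.randn(3), np.random.randn(3)
mean, var = np.random.randn(3), np.abs(np.random.randn(3))
eps = 1e-5

bn_of_gemm = (x @ B + c - mean) * scale / np.sqrt(var + eps) + bias
new_B = B * scale / np.sqrt(var + eps)
new_c = (c - mean) * scale / np.sqrt(var + eps) + bias
assert np.allclose(bn_of_gemm, x @ new_B + new_c)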
Example #15
def remove_nodes(g, cut_nodes=[], cut_types=[]):
    node_to_delete = []
    # Find target nodes
    for node in g.node:
        if node.name in cut_nodes or node.op_type in cut_types:
            node_to_delete.append(node)
    # Remove them and add new outputs
    new_output = []
    while node_to_delete:
        node = node_to_delete.pop()
        for input_name in node.input:
            value = helper.find_value_by_name(g, input_name)
            if value is not None and helper.find_output_by_name(
                    g, input_name) is None:
                new_output.append(value)
        g.node.remove(node)
    g.output.extend(new_output)
    # Remove unreachable nodes
    visited_values = set()
    unused_constant_map = {}
    for input_value in g.input:
        visited_values.add(input_value.name)
    for node in g.node:
        if node.op_type == 'Constant':
            visited_values.add(node.output[0])
            unused_constant_map[node.output[0]] = node
            continue
        can_reach = True
        for input_name in node.input:
            if input_name not in visited_values:
                can_reach = False
                break
        if can_reach:
            for output_name in node.output:
                visited_values.add(output_name)
        else:
            node_to_delete.append(node)
    new_output = []
    while node_to_delete:
        node = node_to_delete.pop()
        for input_name in node.input:
            value = helper.find_value_by_name(g, input_name)
            if value is not None and helper.find_output_by_name(
                    g, input_name) is None:
                new_output.append(value)
        g.node.remove(node)
    g.output.extend(new_output)
    # Remove unused constants
    for node in g.node:
        for input_name in node.input:
            if input_name in unused_constant_map:
                del unused_constant_map[input_name]
    for node in unused_constant_map.values():
        g.node.remove(node)
    # Remove unreachable value infos and outputs
    reachable_values = set()
    for input_value in g.input:
        reachable_values.add(input_value.name)
    for node in g.node:
        for input_name in node.input:
            reachable_values.add(input_name)
        for output_name in node.output:
            reachable_values.add(output_name)
    value_to_remove = []
    for value_info in g.value_info:
        if value_info.name not in reachable_values:
            value_to_remove.append(value_info)
    while value_to_remove:
        value_info = value_to_remove.pop()
        g.value_info.remove(value_info)
    for value_info in g.output:
        if value_info.name not in reachable_values:
            value_to_remove.append(value_info)
    while value_to_remove:
        value_info = value_to_remove.pop()
        g.output.remove(value_info)