Example #1
0
def fuse_consecutive_reducemean(g):
    """Fuse two consecutive ReduceMean nodes into a GlobalAveragePool.

    Two chained ReduceMean nodes with keepdims=0 whose axes together are
    exactly [2, 3] reduce away both spatial dimensions, which is what
    GlobalAveragePool computes.  Since GlobalAveragePool keeps the spatial
    dims as size 1, a Flatten(axis=1) is appended to reproduce the
    keepdims=0 output shape.

    :param g: the input graph
    """
    node_to_del = []
    for node in g.node:
        # Find consecutive ReduceMean
        if node.op_type != 'ReduceMean':
            continue
        pre_node = helper.find_node_by_output_name(g, node.input[0])
        if pre_node is None or pre_node.op_type != 'ReduceMean':
            continue
        # If the predecessor was already fused in an earlier iteration it
        # is scheduled for deletion; fusing it again would put it into
        # node_to_del twice and make the second g.node.remove() raise.
        if pre_node in node_to_del:
            continue
        # Check attributes
        pre_keepdims = helper.get_var_attribute_by_name(
            pre_node, 'keepdims', 'int')
        pre_axes = helper.get_list_attribute_by_name(pre_node, 'axes', 'int')
        cur_keepdims = helper.get_var_attribute_by_name(
            node, 'keepdims', 'int')
        cur_axes = helper.get_list_attribute_by_name(node, 'axes', 'int')
        if pre_keepdims != 0 or cur_keepdims != 0:
            continue
        # 'axes' may be absent from either node (e.g. newer opsets pass it
        # as an input); concatenating None would raise TypeError.
        if pre_axes is None or cur_axes is None:
            continue
        axes = sorted(pre_axes + cur_axes)
        if axes != [2, 3]:
            continue
        # Merge two ReduceMean into GlobalAveragePool.
        new_gap_node = onnx.helper.make_node('GlobalAveragePool',
                                             [pre_node.input[0]],
                                             [node.output[0] + '_intermedia'],
                                             name=node.name + '_gap')
        new_flatten_node = onnx.helper.make_node(
            'Flatten', [node.output[0] + '_intermedia'], [node.output[0]],
            name=node.name + '_flatten',
            axis=1)

        # Clean up: drop the now-dangling value_info between the two
        # fused ReduceMean nodes.
        g.node.extend([new_gap_node, new_flatten_node])
        node_to_del.extend([pre_node, node])
        mid_val_info = helper.find_value_by_name(g, node.input[0])
        if mid_val_info:
            g.value_info.remove(mid_val_info)

    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)

    topological_sort(g)
Example #2
0
def replace_ReduceMean_with_GlobalAveragePool(g):
    """Replace ReduceMean with GlobalAveragePool node when available.

    A ReduceMean over the spatial axes is equivalent to a
    GlobalAveragePool.  A directly preceding NCHW->NHWC Transpose
    (perm [0, 2, 3, 1]) is absorbed as well, in which case the expected
    axes are [1, 2] instead of [2, 3].  When keepdims is 0, a
    Flatten(axis=1) is appended to squeeze the kept spatial dims.

    :param g: the input graph
    """
    dead_nodes = []
    for node in g.node:
        if node.op_type != 'ReduceMean':
            continue
        # Look for a directly preceding NCHW->NHWC Transpose.
        transpose = helper.find_node_by_output_name(g, node.input[0])
        if transpose is not None:
            if transpose.op_type != 'Transpose':
                transpose = None
            elif helper.get_list_attribute_by_name(
                    transpose, 'perm', 'int') != [0, 2, 3, 1]:
                transpose = None
        # Validate the ReduceMean attributes.
        axes = helper.get_list_attribute_by_name(node, 'axes', 'int')
        keepdims = helper.get_var_attribute_by_name(node, 'keepdims', 'int')
        if axes is None:
            continue
        if axes != ([1, 2] if transpose is not None else [2, 3]):
            continue
        if keepdims is None:
            keepdims = 1  # ONNX default
        # Build the replacement GlobalAveragePool (reusing the Transpose
        # input when the Transpose is absorbed), plus a Flatten when the
        # original dropped the reduced dims.
        gap_inputs = transpose.input if transpose else node.input
        if keepdims == 1:
            gap_outputs = node.output
        else:
            gap_outputs = [node.output[0] + '_before_flatten']
            g.node.extend([onnx.helper.make_node("Flatten",
                                                 gap_outputs,
                                                 node.output,
                                                 name=node.name + "_flatten",
                                                 axis=1)])
        g.node.extend([onnx.helper.make_node("GlobalAveragePool",
                                             gap_inputs,
                                             gap_outputs,
                                             name=node.name)])
        dead_nodes.append(node)
        if transpose:
            # Drop the now-dangling value_info between Transpose and
            # ReduceMean, then retire the Transpose itself.
            mid_value = helper.find_value_by_name(g, transpose.output[0])
            if mid_value:
                g.value_info.remove(mid_value)
            dead_nodes.append(transpose)
    for dead in dead_nodes:
        g.node.remove(dead)
    topological_sort(g)
Example #3
0
def _deconv_valid_conv_pads(attr):
    """Conv pads for a deconv with VALID (i.e. no) padding.

    Each side gets the full kernel halo (kernel - 1); the whole
    output_padding goes on the tail side (head is 0).
    """
    pad1_h = attr["kernel_shape"][0] - 1
    pad1_w = attr["kernel_shape"][1] - 1
    tail_h = attr["output_padding"][0]
    tail_w = attr["output_padding"][1]
    return [pad1_h, pad1_w, pad1_h + tail_h, pad1_w + tail_w]


def deconv_to_conv_info_extraction(input_size, node_proto):
    """Extract the information needed for deconv split.

    :param input_size: input shape of the deconv node.\\
    :param node_proto: the deconv node proto.\\
    :return: a dictionary of extracted params.
    """
    attr = dict()
    # Get attributes from Deconv node
    attr["auto_pad"] = helper.get_var_attribute_by_name(node_proto, "auto_pad", "string")
    attr["dilations"] = helper.get_list_attribute_by_name(node_proto, "dilations", "int")
    attr["group"] = helper.get_var_attribute_by_name(node_proto, "group", "int")
    attr["kernel_shape"] = helper.get_list_attribute_by_name(node_proto, "kernel_shape", "int")
    attr["output_padding"] = helper.get_list_attribute_by_name(node_proto, "output_padding", "int")
    attr["pads"] = helper.get_list_attribute_by_name(node_proto, "pads", "int")
    attr["strides"] = helper.get_list_attribute_by_name(node_proto, "strides", "int")
    # Infer output_padding when not given explicitly.
    if attr["output_padding"] is None:
        if attr["auto_pad"] == "SAME_LOWER" or attr["auto_pad"] == "SAME_UPPER":
            # SAME padding: output = input * stride, which needs
            # stride - 1 extra cells per axis.  (Fixed: the width entry
            # previously lacked the "- 1" the height entry has.)
            attr["output_padding"] = [attr["strides"][0] - 1, attr["strides"][1] - 1]
        else:
            attr["output_padding"] = [max(attr["strides"][0] - attr["kernel_shape"][0], 0),
                                      max(attr["strides"][1] - attr["kernel_shape"][1], 0)]
    # Calculate conv_padding
    if attr["auto_pad"] == "SAME_LOWER" or attr["auto_pad"] == "SAME_UPPER":
        # Split kernel halo and output_padding between head and tail.
        pad1_h = attr["kernel_shape"][0] - (attr["kernel_shape"][0] - 1) // 2 - 1
        pad1_w = attr["kernel_shape"][1] - (attr["kernel_shape"][1] - 1) // 2 - 1
        head_h = min(attr["kernel_shape"][0] // 2, (attr["output_padding"][0] + 1) // 2)
        head_w = min(attr["kernel_shape"][1] // 2, (attr["output_padding"][1] + 1) // 2)
        tail_h = attr["output_padding"][0] - head_h
        tail_w = attr["output_padding"][1] - head_w
        attr["conv_pads"] = [pad1_h + head_h, pad1_w + head_w, pad1_h + tail_h, pad1_w + tail_w]
    elif attr["pads"] is not None:
        sum_of_pads = sum(attr["pads"])
        if sum_of_pads == 0:
            # Valid padding
            attr["conv_pads"] = _deconv_valid_conv_pads(attr)
        else:
            # Calculate the deconv output shape from the ONNX
            # ConvTranspose formula, assuming NCHW input_size.
            tmp_output_shape = [0, 0]
            tmp_output_shape[0] = attr["strides"][0] * (input_size[2] - 1) + attr["output_padding"][0] + attr["kernel_shape"][0] - attr["pads"][0] - attr["pads"][2]
            tmp_output_shape[1] = attr["strides"][1] * (input_size[3] - 1) + attr["output_padding"][1] + attr["kernel_shape"][1] - attr["pads"][1] - attr["pads"][3]
            # Calculate real conv output shape (input dilated by stride)
            tmp_center_shape = [0, 0]
            tmp_center_shape[0] = (input_size[2] - 1) * attr["strides"][0] + 1
            tmp_center_shape[1] = (input_size[3] - 1) * attr["strides"][1] + 1
            # Total padding the equivalent conv needs to reach that shape
            total_padding = [0, 0]
            total_padding[0] = tmp_output_shape[0] - tmp_center_shape[0] + attr["kernel_shape"][0] - 1
            total_padding[1] = tmp_output_shape[1] - tmp_center_shape[1] + attr["kernel_shape"][1] - 1
            if total_padding[0] < 0 or total_padding[1] < 0:
                raise RuntimeError(node_proto.name + " cannot infer conv padding.")
            # Split evenly, extra cell on the tail side.
            conv_pads_ = [0] * 4
            conv_pads_[0] = total_padding[0] // 2
            conv_pads_[1] = total_padding[1] // 2
            conv_pads_[2] = total_padding[0] - total_padding[0] // 2
            conv_pads_[3] = total_padding[1] - total_padding[1] // 2
            attr["conv_pads"] = conv_pads_
    else:
        # No pads attribute at all: same as valid padding.
        attr["conv_pads"] = _deconv_valid_conv_pads(attr)
    return attr