Пример #1
0
    def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
        """Bring each operator's variables into the data order it requires.

        Visits every operator returned by ``traverse.listup_operators(graph)``
        and dispatches on its type:

        * operators with a fixed layout requirement hand the required order(s)
          to ``_replace_input`` / ``_replace_output`` (presumably these insert
          a transpose when the current order differs -- confirm against the
          helpers, which are defined outside this block);
        * ``Softmax`` that is not already a 2D softmax over the last axis is
          replaced by a transpose / reshape / softmax / reshape / transpose
          chain;
        * any other operator only has redundant transposes around its inputs
          and outputs cleaned up.

        Args:
            graph: Graph to be rewritten in place.

        Returns:
            The same ``graph`` object and a flag that is truthy when a helper
            reported a change or the softmax rewrite added an operator.
        """
        changed = False
        for op in traverse.listup_operators(graph):
            # Each section below finishes with ``continue``, so at most one
            # of them handles a given operator.

            if isinstance(op, (Reshape, ReinterpretAxis)):
                # The required orders are stored on the operator itself.
                changed |= _replace_input(graph, op, "x",
                                          op.parameters["in_order"])
                changed |= _replace_output(graph, op, "y",
                                           op.parameters["out_order"])
                continue

            if isinstance(op, LSTM):
                changed |= _replace_input(graph, op, "x", OrderNTC)
                changed |= _replace_input(graph, op, "w_input", OrderCN)
                changed |= _replace_input(graph, op, "w_hidden", OrderCN)
                # The order required for "y" depends on "return_sequences".
                if op.parameters["return_sequences"]:
                    lstm_y_order = OrderNTC
                else:
                    lstm_y_order = OrderNC
                changed |= _replace_output(graph, op, "y", lstm_y_order)
                changed |= _replace_output(graph, op, "final_c", OrderNC)
                continue

            if isinstance(op, Embedding):
                changed |= _replace_input(graph, op, "x", OrderNT)
                changed |= _replace_input(graph, op, "w", OrderCN)
                changed |= _replace_output(graph, op, "y", OrderNTC)
                continue

            if isinstance(op, Im2Col):
                changed |= _replace_input(graph, op, "im", OrderNHWC)
                # A list of orders is passed for "col"; presumably any of
                # them is acceptable -- confirm against _replace_output.
                col_orders = [
                    Order([Axis.N, Axis.H, Axis.W, Axis.KH, Axis.KW, Axis.C]),
                    Order([Axis.KH, Axis.KW, Axis.C, Axis.N, Axis.H, Axis.W]),
                ]
                changed |= _replace_output(graph, op, "col", col_orders)
                continue

            if isinstance(op, Col2Im):
                col_orders = [
                    Order([Axis.N, Axis.H, Axis.W, Axis.KH, Axis.KW, Axis.C]),
                ]
                changed |= _replace_input(graph, op, "col", col_orders)
                changed |= _replace_output(graph, op, "im", OrderNHWC)
                continue

            if isinstance(op, (Tensordot, )):
                lhs = op.inputs["A"]
                rhs = op.inputs["B"]
                out = op.outputs["C"]
                lhs_reduced = op.axes[0]
                rhs_reduced = op.axes[1]

                # Reduced axes must be located on the inner (trailing) side.
                lhs_axes = list(lhs.order.axes)
                for axis in lhs_reduced:
                    lhs_axes.remove(axis)
                    lhs_axes.append(axis)

                rhs_axes = list(rhs.order.axes)
                for axis in rhs_reduced:
                    rhs_axes.remove(axis)
                    rhs_axes.append(axis)

                # Number of axes of A that are not reduced; C's axes are
                # split at this position into an A part and a B part.
                n_lhs_kept = lhs.ndim - len(lhs_reduced)
                out_head = out.order.axes[:n_lhs_kept]
                out_tail = out.order.axes[n_lhs_kept:]

                # Remaining axes must appear in the same order as in C.
                # NOTE(review): this tests membership in the *reduced* axes of
                # A -- confirm that is intended rather than A's remaining axes.
                if all(axis in lhs_reduced for axis in out_head):
                    # C can stay as it is: reorder A and B to follow C.
                    for position, axis in enumerate(out_head):
                        lhs_axes.remove(axis)
                        lhs_axes.insert(position, axis)

                    for position, axis in enumerate(out_tail):
                        rhs_axes.remove(axis)
                        rhs_axes.insert(position, axis)

                else:
                    # Otherwise C must follow [*A remaining, *B remaining].
                    n_rhs_kept = rhs.ndim - len(rhs_reduced)
                    out_axes = lhs_axes[:n_lhs_kept] + rhs_axes[:n_rhs_kept]
                    changed |= _replace_output(graph, op, "C",
                                               Order(out_axes))

                changed |= _replace_input(graph, op, "A", Order(lhs_axes))
                changed |= _replace_input(graph, op, "B", Order(rhs_axes))
                continue

            if isinstance(op, (Convolution2D, Deconvolution2D, MaxPooling2D,
                               AveragePooling2D, Space2Depth, Depth2Space,
                               LocalResponseNormalization, Unpooling2D)):
                changed |= _replace_input(graph, op, "x", OrderNHWC)
                changed |= _replace_output(graph, op, "y", OrderNHWC)
                continue

            if isinstance(op, Softmax):
                x = op.inputs["x"]
                y = op.outputs["y"]
                softmax_axis = op.parameters["axis"]

                # Only rewrite when x is not already 2D with the softmax axis
                # in last position.
                if (x.ndim != 2
                        or x.order.axes_dict[softmax_axis] != x.ndim - 1):
                    """
                    Before)
                    | x   |              | y   |
                    |-----| -{softmax}-> |-----|
                    | XYZ |   axis=Y     | XYZ |

                    After)
                    | x   |                | hx1 |              | hx2 |              | hy1 |              | hy2 |                | y   |
                    |-----| -{transpose}-> |-----| -{reshape}-> |-----| -{softmax}-> |-----| -{reshape}-> |-----| -{transpose}-> |-----|
                    | XYZ |                | XZY |              | NC  |   axis=C     | NC  |              | XZY |                | XYZ |
                                              :                    :
                                        order_nd = XZY       order_2d = NC
                    """
                    op.remove_all()

                    # N-dimensional layout with the softmax axis moved last.
                    nd_axes = list(x.order.axes)
                    nd_axes.remove(softmax_axis)
                    nd_axes.append(softmax_axis)
                    nd_order = Order(nd_axes)
                    nd_shape = tuple(x.shape_dict[axis] for axis in nd_axes)

                    # 2D layout: everything else collapsed into one axis.
                    flat_order = OrderNC
                    axis_size = x.shape_dict[softmax_axis]
                    flat_shape = (x.size // axis_size, axis_size)

                    if x.order == nd_order:
                        hx1 = x
                    else:
                        hx1 = x.transpose(nd_order)
                        changed = True

                    if hx1.order == flat_order and hx1.shape == flat_shape:
                        hx2 = hx1
                    else:
                        hx2 = hx1.reshape(flat_shape, flat_order)
                        changed = True

                    (hy1,) = Softmax(None, axis=Axis.C)(hx2)

                    if hy1.order == nd_order and hy1.shape == nd_shape:
                        hy2 = hy1
                    else:
                        hy2 = hy1.reshape(nd_shape, nd_order)
                        changed = True

                    if hy2.order == y.order:
                        new_y = hy2
                    else:
                        new_y = hy2.transpose(y.order)
                        changed = True

                    OptimizeRule.replace_variable(graph, new_y, y)

                continue

            # "op" accepts any order. Remove redundant transpose operations
            # if they exist.
            for key in op.inputs:
                changed |= _optimize_redundant_transposed_input(
                    graph, op, key, None)
            for key in op.outputs:
                changed |= _optimize_redundant_transposed_output(
                    graph, op, key, None)

        return graph, changed
Пример #2
0
    def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
        """Remove needless transposes and fix operators' required data orders.

        Visits every operator returned by ``traverse.listup_operators(graph)``:

        * ``Transpose`` whose input and output orders are equal is removed.
        * ``Transpose`` whose output is not a graph output and is consumed
          only by ``Elementwise`` / ``SplitAxis`` operators is removed, and
          the consumers read the untransposed variable directly.
        * ``Reshape``, ``Tensordot`` and the 2D image operators have their
          variables brought to the required order through ``_replace_input``
          / ``_replace_output`` (presumably by inserting transposes -- confirm
          against the helpers, which are defined outside this block).
        * ``Softmax`` that is not already a 2D softmax over the last axis is
          replaced by a transpose / reshape / softmax / reshape / transpose
          chain.

        Args:
            graph: Graph to be rewritten in place.

        Returns:
            The same ``graph`` object and a flag that is truthy when anything
            was changed.
        """
        flag_changed = False
        for op in traverse.listup_operators(graph):
            if isinstance(op, Transpose):
                x = op.inputs["x0"]
                y = op.outputs["y"]

                if x.order == y.order:
                    # The transpose does nothing: drop it and merge x and y.
                    op.remove_all()
                    OptimizeRule.replace_variable(graph, x, y)

                    if x in graph.inputs:
                        # Keep the graph input at the same position.
                        index = graph.inputs.index(x)
                        graph.inputs.remove(x)
                        graph.inputs.insert(index, y)

                    flag_changed = True
                    continue

                if y not in graph.outputs and all(
                        isinstance(op2, (Elementwise, SplitAxis))
                        for op2 in y.input_to):
                    # Every consumer is Elementwise or SplitAxis, so they are
                    # rewired to read x directly.
                    op.remove_all()
                    # Iterate over a copy: remove_input() is expected to
                    # mutate y.input_to.
                    for op2 in list(y.input_to):
                        name = op2.get_input_name(y)
                        op2.remove_input(y)
                        op2.append_input(name, x)

                    flag_changed = True
                    continue

            elif isinstance(op, Reshape):
                # The required orders are stored on the operator itself.
                flag_changed |= _replace_input(op, "x",
                                               op.parameters["in_order"])
                flag_changed |= _replace_output(op, "y",
                                                op.parameters["out_order"])
                continue

            elif isinstance(op, (Tensordot, )):
                op = op  # type: Tensordot
                A = op.inputs["A"]
                B = op.inputs["B"]
                C = op.outputs["C"]

                # Reduced axes must be located in inner side.
                a_axes = list(A.order.axes)
                for axis in op.axes[0]:
                    a_axes.remove(axis)
                    a_axes.append(axis)

                b_axes = list(B.order.axes)
                for axis in op.axes[1]:
                    b_axes.remove(axis)
                    b_axes.append(axis)

                # Number of non-reduced ("remaining") axes of each operand.
                # They now occupy the leading positions of a_axes / b_axes.
                a_remained = A.ndim - len(op.axes[0])
                b_remained = B.ndim - len(op.axes[1])

                # Remained axes must be located in same order as A and B's axes order.
                if all(axis in a_axes for axis in C.order.axes[:a_remained]):
                    # C's order is [*a_remained_axes, *b_remained_axes], so C
                    # needs no transpose; reorder A and B to follow C.
                    for i, axis in enumerate(C.order.axes[:a_remained]):
                        a_axes.remove(axis)
                        a_axes.insert(i, axis)

                    for i, axis in enumerate(C.order.axes[a_remained:]):
                        b_axes.remove(axis)
                        b_axes.insert(i, axis)

                else:
                    # C must follow [*a_remained_axes, *b_remained_axes].
                    # Slice by the number of REMAINING axes; slicing by the
                    # number of reduced axes yields a wrong order whenever the
                    # two counts differ.
                    c_axes = a_axes[:a_remained] + b_axes[:b_remained]
                    flag_changed |= _replace_output(op, "C", Order(c_axes))

                flag_changed |= _replace_input(op, "A", Order(a_axes))
                flag_changed |= _replace_input(op, "B", Order(b_axes))
                continue

            elif isinstance(op, (Convolution2D, Deconvolution2D, MaxPooling2D,
                                 AveragePooling2D, Space2Depth, Depth2Space,
                                 LocalResponseNormalization, Unpooling2D)):
                flag_changed |= _replace_input(op, "x", OrderNHWC)
                flag_changed |= _replace_output(op, "y", OrderNHWC)
                continue

            elif isinstance(op, Softmax):
                x = op.inputs["x"]
                y = op.outputs["y"]
                target_axis = op.parameters["axis"]

                # Only rewrite when x is not already 2D with the softmax axis
                # in last position.
                if not (x.ndim == 2
                        and x.order.axes_dict[target_axis] == x.ndim - 1):
                    """
                    Before)
                    | x   |              | y   |
                    |-----| -{softmax}-> |-----|
                    | XYZ |   axis=Y     | XYZ |

                    After)
                    | x   |                | hx1 |              | hx2 |              | hy1 |              | hy2 |                | y   |
                    |-----| -{transpose}-> |-----| -{reshape}-> |-----| -{softmax}-> |-----| -{reshape}-> |-----| -{transpose}-> |-----|
                    | XYZ |                | XZY |              | NC  |   axis=C     | NC  |              | XZY |                | XYZ |
                                              :                    :
                                        order_nd = XZY       order_2d = NC
                    """
                    op.remove_all()

                    # N-dimensional layout with the softmax axis moved last.
                    axes_nd = list(x.order.axes)
                    axes_nd.remove(target_axis)
                    axes_nd.append(target_axis)
                    order_nd = Order(axes_nd)
                    shape_nd = tuple([x.shape_dict[axis] for axis in axes_nd])

                    # 2D layout: all other axes collapsed into one.
                    order_2d = OrderNC
                    shape_2d = tuple([
                        x.size // x.shape_dict[target_axis],
                        x.shape_dict[target_axis]
                    ])

                    if x.order == order_nd:
                        hx1 = x

                    else:
                        hx1 = x.transpose(order_nd)
                        flag_changed = True

                    if hx1.order == order_2d and hx1.shape == shape_2d:
                        hx2 = hx1

                    else:
                        hx2 = hx1.reshape(shape_2d, order_2d)
                        flag_changed = True

                    hy1, = Softmax(None, axis=Axis.C)(hx2)

                    if hy1.order == order_nd and hy1.shape == shape_nd:
                        hy2 = hy1

                    else:
                        hy2 = hy1.reshape(shape_nd, order_nd)
                        flag_changed = True

                    if hy2.order == y.order:
                        y_dummy = hy2

                    else:
                        y_dummy = hy2.transpose(y.order)
                        flag_changed = True

                    # Splice the new chain's result in place of the old y.
                    OptimizeRule.replace_variable(graph, y_dummy, y)

                    continue

        return graph, flag_changed