    def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
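        # Fold a constant ElementwiseMul that follows an Sgemm into the Sgemm's
        # constant weight: find Sgemm -> Variable -> ElementwiseMul chains where one
        # Sgemm input (w1) and the multiplier (w2) are constant and all of w2's axes
        # are covered by w1's virtual order, then premultiply w1 by w2 and remove
        # the ElementwiseMul from the graph.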
        flag_changed = False
        matches = traverse.search_sub_structure(
            graph, [Sgemm, Variable, ElementwiseMul])
        while len(matches) > 0:
            match = matches.pop()
            sgemm = match[0]  # type: Sgemm
            elementwise_mul = match[2]  # type: ElementwiseMul

            out_order = sgemm.parameters["out_order"]
            out_shape = sgemm.parameters["out_shape"]

            axis_k = Axis('AxisK')

            if not isinstance(sgemm.inputs["A"],
                              ConstantVariable) and not isinstance(
                                  sgemm.inputs["B"], ConstantVariable):
                # neither x nor w1 is constant
                continue

            elif isinstance(sgemm.inputs["A"], ConstantVariable):
                w1 = sgemm.inputs["A"]  # type: ConstantVariable

                if sgemm.transpose_A:
                    # w1.shape = (M, K)

                    shape = []
                    axes = []
                    for axis, size in zip(out_order.axes, out_shape):
                        shape.append(size)
                        axes.append(axis)

                        if mul(shape) >= sgemm.M:
                            break

                    if mul(shape) != sgemm.M:
                        # output axes are derived from both w1 and x
                        continue

                    w1_virtual_order = Order(axes + [axis_k])
                    w1_virtual_shape = shape + [sgemm.K]

                else:
                    # w1.shape = (K, M)

                    shape = [sgemm.K]
                    axes = [axis_k]
                    for axis, size in zip(out_order.axes, out_shape):
                        shape.append(size)
                        axes.append(axis)

                        if mul(shape) >= w1.size:
                            break

                    if mul(shape) != w1.size:
                        # output axes are derived from both w1 and x
                        continue

                    w1_virtual_order = Order(axes)
                    w1_virtual_shape = shape

            else:
                w1 = sgemm.inputs["B"]  # type: ConstantVariable

                if sgemm.transpose_B:
                    # w1.shape = (K, N)

                    shape = []
                    axes = []
                    for axis, size in reversed(
                            list(zip(out_order.axes, out_shape))):
                        shape.insert(0, size)
                        axes.insert(0, axis)

                        if mul(shape) >= sgemm.N:
                            break

                    if mul(shape) != sgemm.N:
                        # output axes are derived from both w1 and x
                        continue

                    w1_virtual_order = Order([axis_k] + axes)
                    w1_virtual_shape = [sgemm.K] + shape

                else:
                    # w1.shape = (N, K)
                    shape = [sgemm.K]
                    axes = [axis_k]
                    for axis, size in reversed(
                            list(zip(out_order.axes, out_shape))):
                        shape.insert(0, size)
                        axes.insert(0, axis)

                        if mul(shape) >= w1.size:
                            break

                    if mul(shape) != w1.size:
                        # output axes are derived from both w1 and x
                        continue

                    w1_virtual_order = Order(axes)
                    w1_virtual_shape = shape

            h = sgemm.outputs["C"]  # type: Variable

            x0 = elementwise_mul.inputs["x0"]
            x1 = elementwise_mul.inputs["x1"]
            if h == x1:
                if not isinstance(x0, ConstantVariable):
                    # w2 is not constant
                    continue

                w2 = x0  # type: ConstantVariable

            else:
                if not isinstance(x1, ConstantVariable):
                    # w2 is not constant
                    continue

                w2 = x1  # type: ConstantVariable

            y = elementwise_mul.outputs["y"]  # type: Variable

            if not all(axis in w1_virtual_order.axes
                       for axis in w2.order.axes):
                # w2's axes are derived from both w1 and x
                continue

            elementwise_mul.remove_all()
            y_dummy, = Transpose(None)(h)
            y_dummy.change_order(y.order)
            y_dummy.replace(y)

            w2.change_order(w1_virtual_order)
            w_new = ConstantVariable(
                w1.data.reshape(w1_virtual_shape),
                w1_virtual_order) * w2  # type: ConstantVariable
            w1.data = w_new.data.reshape(w1.shape)

            flag_changed = True
            matches = traverse.search_sub_structure(
                graph, [Sgemm, Variable, ElementwiseMul])

        return graph, flag_changed
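The rewrite above relies on the identity that multiplying the output of a matrix product by a constant that only spans the weight's output axes is the same as scaling the constant weight first. A minimal NumPy sketch of that identity for the simplest case, a per-column multiplier (illustration only, not part of the pass; shapes are arbitrary):

import numpy as np

A = np.random.rand(4, 3)   # variable input x, shape (M, K)
B = np.random.rand(3, 5)   # constant weight w1, shape (K, N)
w2 = np.random.rand(5)     # constant multiplier over the output (N) axis

y1 = (A @ B) * w2          # multiply after the matrix product ...
y2 = A @ (B * w2)          # ... or fold the multiplier into the weight first

assert np.allclose(y1, y2)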
Example #2
    def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
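        # Normalize variable orders for each operator: Reshape keeps its declared
        # in/out orders, 2D convolution / pooling / Space2Depth / Depth2Space are
        # pinned to NHWC, and a Softmax whose input is not 2D with the target axis
        # last is rewritten as transpose -> reshape -> 2D softmax -> reshape ->
        # transpose (see the docstring diagram below).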
        flag_changed = False
        for op in traverse.listup_operators(graph):
            if isinstance(op, Reshape):
                flag_changed |= _replace_input(op, "x",
                                               op.parameters["in_order"])
                flag_changed |= _replace_output(op, "y",
                                                op.parameters["out_order"])
                continue

            elif isinstance(op, (Convolution2D, MaxPooling2D, AveragePooling2D,
                                 Deconvolution2D, Space2Depth, Depth2Space)):
                flag_changed |= _replace_input(op, "x", OrderNHWC)
                flag_changed |= _replace_output(op, "y", OrderNHWC)
                continue

            elif isinstance(op, Softmax):
                x = op.inputs["x"]
                y = op.outputs["y"]
                target_axis = op.parameters["axis"]

                if not (x.ndim == 2
                        and x.order.axes_dict[target_axis] == x.ndim - 1):
                    """
                    Before)
                    | x   |              | y   |
                    |-----| -{softmax}-> |-----|
                    | XYZ |   axis=Y     | XYZ |
                    
                    After)
                    | x   |                | hx1 |              | hx2 |              | hy1 |              | hy2 |                | y   |
                    |-----| -{transpose}-> |-----| -{reshape}-> |-----| -{softmax}-> |-----| -{reshape}-> |-----| -{transpose}-> |-----|
                    | XYZ |                | XZY |              | NC  |   axis=C     | NC  |              | XZY |                | XYZ |
                                              :                    :
                                        order_nd = XZY       order_2d = NC
                    """
                    op.remove_all()

                    axes_nd = list(x.order.axes)
                    axes_nd.remove(target_axis)
                    axes_nd.append(target_axis)
                    order_nd = Order(axes_nd)
                    shape_nd = tuple([x.shape_dict[axis] for axis in axes_nd])

                    order_2d = OrderNC
                    shape_2d = tuple([
                        x.size // x.shape_dict[target_axis],
                        x.shape_dict[target_axis]
                    ])

                    if x.order == order_nd:
                        hx1 = x

                    else:
                        hx1, = Transpose(None)(x)
                        hx1.change_order(order_nd)
                        flag_changed = True

                    if hx1.order == order_2d and hx1.shape == shape_2d:
                        hx2 = hx1

                    else:
                        hx2, = Reshape(None,
                                       in_order=hx1.order,
                                       out_order=order_2d,
                                       out_shape=shape_2d)(hx1)
                        flag_changed = True

                    hy1, = Softmax(None, axis=Axis.C)(hx2)

                    if hy1.order == order_nd and hy1.shape == shape_nd:
                        hy2 = hy1

                    else:
                        hy2, = Reshape(None,
                                       in_order=hy1.order,
                                       out_order=order_nd,
                                       out_shape=shape_nd)(hy1)
                        flag_changed = True

                    if hy2.order == y.order:
                        y_dummy = hy2

                    else:
                        y_dummy, = Transpose(None)(hy2)
                        y_dummy.change_order(y.order)
                        flag_changed = True

                    y_dummy.replace(y)

                    continue

        return graph, flag_changed
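The Softmax rewrite shown in the docstring only moves data around: softmax over an inner axis equals transposing that axis to the end, flattening the remaining axes, applying a 2D softmax over the last axis, and undoing the reshape and transpose. A small NumPy check of that equivalence (illustration only; the shapes and axis choice are arbitrary):

import numpy as np

def softmax(v, axis):
    e = np.exp(v - v.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

x = np.random.rand(2, 3, 4)          # order XYZ, softmax over axis=1 ("Y")
y_ref = softmax(x, axis=1)

hx1 = np.transpose(x, (0, 2, 1))     # XYZ -> XZY: target axis moved last
hx2 = hx1.reshape(-1, 3)             # XZY -> NC  (N = 2*4, C = 3)
hy1 = softmax(hx2, axis=1)           # 2D softmax over the last axis
hy2 = hy1.reshape(2, 4, 3)           # NC  -> XZY
y = np.transpose(hy2, (0, 2, 1))     # XZY -> XYZ

assert np.allclose(y, y_ref)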
Example #3
    def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
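        # Same order normalization as the previous example, plus Transpose cleanup:
        # an identity transpose (matching input and output orders) is removed, and a
        # transpose whose consumers are all order-agnostic (Elementwise, SplitAxis)
        # is bypassed by rewiring those consumers to its input.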
        flag_changed = False
        for op in traverse.listup_operators(graph):
            if isinstance(op, Transpose):
                x = op.inputs["x0"]
                y = op.outputs["y"]
                if x.order == y.order:
                    # The orders match, so the transpose is an identity; remove it
                    # and splice the graph back together.
                    op.remove_all()
                    x.replace(y)
                    flag_changed = True
                    continue

                if all(isinstance(op2, (Elementwise, SplitAxis)) for op2 in y.input_to):
                    # Every consumer is order-agnostic, so bypass the transpose and
                    # feed them its input directly.
                    op.remove_all()
                    for op2 in list(y.input_to):
                        name = op2._get_input_name(y)
                        op2.remove_input(y)
                        op2.append_input(name, x)
                    flag_changed = True

            elif isinstance(op, Reshape):
                flag_changed |= _replace_input(op, "x", op.parameters["in_order"])
                flag_changed |= _replace_output(op, "y", op.parameters["out_order"])

            elif isinstance(op, (Convolution2D, MaxPooling2D, AveragePooling2D, Deconvolution2D)):
                flag_changed |= _replace_input(op, "x", OrderNHWC)
                flag_changed |= _replace_output(op, "y", OrderNHWC)

            elif isinstance(op, Softmax):
                x = op.inputs["x"]
                y = op.outputs["y"]

                if x.ndim > 2:
                    """
                    Before)
                    | x    |              | y    |
                    |------| -{softmax}-> |------|
                    | NCHW |              | NCHW |

                    After)
                    | x    |                | hx1  |              | hx2 |              | hy1 |              | hy2  |                | y    |
                    |------| -{transpose}-> |------| -{reshape}-> |-----| -{softmax}-> |-----| -{reshape}-> |------| -{transpose}-> |------|
                    | NCHW |                | NHWC |              | NC  |              | NC  |              | NHWC |                | NCHW |
                    """
                    op.remove_all()

                    target_axis = op.parameters["axis"]
                    axes_nd = list(x.order.axes)
                    axes_nd.remove(target_axis)
                    axes_nd.append(target_axis)
                    order_nd = Order(axes_nd)
                    shape_nd = [x.shape_dict[axis] for axis in axes_nd]

                    order_2d = OrderNC
                    shape_2d = [x.size // x.shape_dict[target_axis], x.shape_dict[target_axis]]

                    hx1, = Transpose(None)(x)
                    hx1.change_order(order_nd)

                    hx2, = Reshape(None, in_order=hx1.order, out_order=order_2d, out_shape=shape_2d)(hx1)

                    hy1, = Softmax(None, axis=Axis.C)(hx2)

                    hy2, = Reshape(None, in_order=hy1.order, out_order=order_nd, out_shape=shape_nd)(hy1)

                    y_dummy, = Transpose(None)(hy2)
                    y_dummy.change_order(y.order)

                    y_dummy.replace(y)
                    flag_changed = True

                else:
                    flag_changed |= _replace_input(op, "x", OrderNC)
                    flag_changed |= _replace_output(op, "y", OrderNC)

        return graph, flag_changed
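Each optimize method returns the graph together with a flag reporting whether anything changed, which lets a caller re-run a rule until it stabilizes. A minimal sketch of such a driver loop; the rule object here is a placeholder for any instance exposing this optimize() signature, not an API taken from the source:

def apply_until_stable(rule, graph):
    """Re-apply a rewrite rule until it reports no further change."""
    changed = True
    while changed:
        graph, changed = rule.optimize(graph)
    return graph

# optimized_graph = apply_until_stable(some_rule, graph)  # some_rule is hypothetical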