Example #1
0
    def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
        """
        Replace every Linear operator found in the graph with a Sgemm operator.

        For each Linear operator, its variables are detached (`op.remove_all()`), a Sgemm operator is applied
        to the same `x` and `w`, and the Sgemm output is swapped for the original output variable `y`.

        Args:
            graph: the graph to rewrite

        Returns:
            The graph and a flag which is True when at least one Linear operator was replaced.
        """
        flag_changed = False
        for op in traverse.filter_nodes(traverse.listup_operators(graph), Linear):  # type: Linear
            x = op.inputs["x"]
            w = op.inputs["w"]
            y = op.outputs["y"]
            # Each message names the variable which is actually checked, so that a failure points at the right one.
            assert x.order == OrderNC or x.order == OrderNHWC, f"(x.order) = {x.order}"
            assert w.order == OrderCN or w.order == OrderHWCN, f"(w.order) = {w.order}"
            assert y.order == OrderNC or y.order == OrderNHWC, f"(y.order) = {y.order}"
            assert w.ndim == x.ndim, f"(w.ndim, x.ndim) = ({w.ndim}, {x.ndim})"

            flag_changed = True
            op.remove_all()

            # M is the batch size; N and K are the number of remaining elements per sample of y and x respectively.
            sgemm = Sgemm(None,
                          M=y.shape_dict[Axis.N],
                          N=y.size // y.shape_dict[Axis.N],
                          K=x.size // x.shape_dict[Axis.N],
                          out_shape=y.shape,
                          out_order=y.order,
                          transpose_A=True,
                          transpose_B=True)
            new_y, = sgemm(x, w)

            # Put the original output variable back in place of the freshly created one.
            sgemm.replace_output(new_y, y)

        return graph, flag_changed
Example #2
0
def template(transpose_A=False,
             transpose_B=False,
             M=5,
             N=8,
             K=6,
             description: str = ""):
    """
    Generate a kernel test case which compares Sgemm with the result of `np.dot`.

    Two random float32 matrices of shape (M, K) and (K, N) are created, their first row / first column is
    overwritten with 2, and their product computed by numpy is registered as the expected output.

    Args:
        transpose_A: forwarded to Sgemm; when falsy, the first operand is passed transposed
        transpose_B: forwarded to Sgemm; when falsy, the second operand is passed transposed
        M: number of rows of the first matrix and of the result
        N: number of columns of the second matrix and of the result
        K: size of the contracted dimension
        description: text appended to the test case description
    """
    lhs = np.random.rand(M, K).astype(np.float32)
    rhs = np.random.rand(K, N).astype(np.float32)
    lhs[0, :] = 2
    rhs[:, 0] = 2

    product = np.dot(lhs, rhs)

    # Data layout actually handed to the operator for each operand.
    a_data = lhs if transpose_A else lhs.transpose()
    b_data = rhs if transpose_B else rhs.transpose()

    a = Variable(a_data.shape, order=OrderNC)
    b = ConstantVariable(b_data, order=OrderNC)

    gemm = Sgemm(None,
                 M=M,
                 N=N,
                 K=K,
                 out_shape=[M, N],
                 out_order=OrderNC,
                 transpose_A=transpose_A,
                 transpose_B=transpose_B)
    c, = gemm(a, b)

    generate_kernel_test_case(
        description=f"Sgemm {description}",
        backend=["webgpu", "webassembly", "webgl"],
        graph=Graph([a], [c]),
        inputs={a: a_data},
        expected={c: product})
Example #3
0
def test_sgemm_invalid_C_shape():
    """
    Apply Sgemm whose `out_shape` holds 1 * 2 * 3 * 4 = 24 elements while M * N is 10 * 20 = 200.

    NOTE(review): presumably this call is expected to raise and the expectation is declared outside of this
    function (e.g. by a decorator which is not visible here) - confirm.
    """
    params = dict(M=10,
                  N=20,
                  K=30,
                  out_shape=[1, 2, 3, 4],
                  out_order=OrderNHWC,
                  transpose_A=True,
                  transpose_B=True)
    gemm = Sgemm(None, **params)

    lhs = Variable((10, 30), OrderNC)
    rhs = Variable((20, 30), OrderNC)
    gemm(lhs, rhs)
Example #4
0
    def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
        """
        Replace every Convolution2D operator in the graph with Sgemm, preceded by Im2Col when required.

        Im2Col is skipped only if `op.WH` and `op.WW` are 1, the stride is (1, 1) and the padding is (0, 0);
        in that case the input variable itself is fed to Sgemm.

        Args:
            graph: the graph to rewrite

        Returns:
            The graph and a flag which is True when at least one Convolution2D operator was replaced.
        """
        flag_changed = False
        conv_ops = traverse.filter_nodes(traverse.listup_operators(graph), Convolution2D)
        for op in conv_ops:  # type: Convolution2D
            x, w = op.inputs["x"], op.inputs["w"]
            y = op.outputs["y"]

            assert x.order == OrderNHWC
            assert y.order == OrderNHWC
            assert isinstance(w, ConstantVariable)

            flag_changed = True
            op.remove_all()
            w.change_order(OrderHWCN)

            needs_im2col = (op.WH != 1 or op.WW != 1
                            or op.stride != (1, 1) or op.padding != (0, 0))
            if not needs_im2col:
                col = x

            else:
                im2col = Im2Col(None,
                                ksize=op.ksize,
                                stride=op.stride,
                                padding=op.padding,
                                dilation_rate=op.dilation_rate)
                col, = im2col(x)
                col.change_order(OrderNHWC)

            batch = col.shape_dict[Axis.N]
            height = col.shape_dict[Axis.H]
            width = col.shape_dict[Axis.W]
            out_channels = w.shape_dict[Axis.N]

            sgemm = Sgemm(None,
                          M=batch * height * width,
                          N=out_channels,
                          K=col.shape_dict[Axis.C],
                          out_shape=[batch, height, width, out_channels],
                          out_order=OrderNHWC,
                          transpose_A=bool(col.order == OrderNHWC),
                          transpose_B=True)

            new_y, = sgemm(col, w)
            new_y.replace(y)

        return graph, flag_changed
Example #5
0
def sgemm(op: Sgemm, memory_layout: MemoryLayout) -> List[Kernel]:
    """
    Generate the kernel for a Sgemm operator.

    The kernel source comes from `generate_template_eigen` when the operator has the `SgemmWithEigen`
    attribute and from `generate_template` otherwise. Buffer placeholders and the kernel name are then
    injected into the source.

    Args:
        op: the Sgemm operator
        memory_layout: mapping used to look up the allocations of `A`, `B` and `C`

    Returns:
        A list which contains the single generated kernel.
    """
    A = op.inputs["A"]
    B = op.inputs["B"]
    C = op.outputs["C"]

    def placeholder_values(with_dims: bool) -> dict:
        # Build a fresh mapping on every call; the M/N/K entries are included only when requested.
        values = {
            "sgemm_A": memory_layout[A],
            "sgemm_B": memory_layout[B],
            "sgemm_C": memory_layout[C]
        }
        if with_dims:
            values["sgemm_M"] = op.M
            values["sgemm_N"] = op.N
            values["sgemm_K"] = op.K
        return values

    buffer_injector = BufferInjector()
    buffer_injector.register(placeholder_values(with_dims=True))

    # NOTE(review): the registration below repeats values which are already registered above; it looks
    # redundant but is kept as-is because the semantics of `register` are not visible here.
    use_eigen = op.has_attribute(SgemmWithEigen)
    make_source = generate_template_eigen if use_eigen else generate_template
    source = make_source(op.transpose_A, op.transpose_B)
    buffer_injector.register(placeholder_values(with_dims=not use_eigen))

    name_injector = KernelNameInjector(op)

    for injector in (buffer_injector, name_injector):
        source = injector.inject(source)

    kernel = Kernel({name_injector.name: source},
                    name_injector.name,
                    buffer_injector.buffer,
                    buffer_injector.unresolved_value_list)

    return [kernel]
    def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
        """
        Replace every Deconvolution2D operator in the graph with Sgemm followed by Col2Im.

        For each Deconvolution2D operator, its variables are detached, the weight is reordered to CHWN,
        Sgemm is applied to `x` and `w`, Col2Im is applied to the Sgemm output, and the Col2Im output is
        swapped for the original output variable `y`.

        Args:
            graph: the graph to rewrite

        Returns:
            The graph and a flag which is True when at least one Deconvolution2D operator was replaced.
        """
        flag_changed = False
        for op in traverse.filter_nodes(
                traverse.listup_operators(graph),
                Deconvolution2D):  # type: Deconvolution2D
            x = op.inputs["x"]
            w = op.inputs["w"]
            y = op.outputs["y"]

            # The input order is validated as well as the output order (previously `y.order` was checked
            # twice and `x.order` was never checked); M and out_shape below are built from x's N, H and W.
            assert x.order == OrderNHWC
            assert y.order == OrderNHWC
            assert isinstance(w, ConstantVariable)

            flag_changed = True
            op.remove_all()
            w.change_order(OrderCHWN)

            sgemm = Sgemm(None,
                          M=x.shape_dict[Axis.N] * x.shape_dict[Axis.H] *
                          x.shape_dict[Axis.W],
                          N=w.shape_dict[Axis.H] * w.shape_dict[Axis.W] *
                          w.shape_dict[Axis.N],
                          K=x.shape_dict[Axis.C],
                          out_shape=[
                              x.shape_dict[Axis.N], x.shape_dict[Axis.H],
                              x.shape_dict[Axis.W], w.shape_dict[Axis.H] *
                              w.shape_dict[Axis.W] * w.shape_dict[Axis.N]
                          ],
                          out_order=OrderNHWC,
                          transpose_A=True if x.order == OrderNHWC else False,
                          transpose_B=True)
            col, = sgemm(x, w)

            col2im = Col2Im(None,
                            ksize=op.ksize,
                            stride=op.stride,
                            padding=op.padding)
            new_y, = col2im(col)

            # Put the original output variable back in place of the freshly created one.
            col2im.replace_output(new_y, y)

        return graph, flag_changed
Example #7
0
def test_sgemm():
    """
    Check that the output of Sgemm has the order and the shape given as `out_order` and `out_shape`.
    """
    params = dict(M=10,
                  N=20,
                  K=30,
                  out_shape=[1, 10, 4, 5],
                  out_order=OrderNHWC,
                  transpose_A=True,
                  transpose_B=True)
    gemm = Sgemm(None, **params)

    lhs = Variable((10, 30), OrderNC)
    rhs = Variable((20, 30), OrderNC)

    out, = gemm(lhs, rhs)

    assert out.order == OrderNHWC
    for axis, expected_size in ((Axis.N, 1), (Axis.H, 10), (Axis.W, 4), (Axis.C, 5)):
        assert out.shape_dict[axis] == expected_size
Example #8
0
    def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
        """
        Replace every Linear operator in the graph with Transpose and Sgemm operators.

        Both inputs pass through a Transpose operator and are reordered so that `Axis.C` comes first.
        Sgemm is applied to them, its output passes through another Transpose operator, and the result
        replaces the original output variable.

        Args:
            graph: the graph to rewrite

        Returns:
            The graph and a flag which is True when at least one Linear operator was replaced.
        """
        flag_changed = False
        linear_ops = traverse.filter_nodes(traverse.listup_operators(graph), Linear)
        for op in linear_ops:
            x = op.inputs["x"]
            w = op.inputs["w"]
            y = op.outputs["y"]

            flag_changed = True
            op.remove_all()

            # Axis.C is the contracted axis; the other axis of w is the output axis.
            a_k = Axis.C
            w_axes = w.order.axes
            if w_axes[1] == a_k:
                a_n = w_axes[0]
            else:
                a_n = w_axes[1]
            axes_m = [a for a in x.order.axes if a != a_k]

            reduce_size = x.shape_dict[a_k]
            num_rows = x.size // reduce_size
            num_cols = w.shape_dict[a_n]

            x_t, = Transpose(None)(x)
            x_t.change_order(Order([a_k] + axes_m))

            w_t, = Transpose(None)(w)
            w_t.change_order(Order([a_k, a_n]))

            gemm = Sgemm(None,
                         M=num_rows,
                         N=num_cols,
                         K=reduce_size,
                         out_shape=[x_t.shape_dict[a] for a in axes_m] + [num_cols],
                         out_order=Order(axes_m + [a_n]),
                         transpose_A=False,
                         transpose_B=True)
            gemm_out, = gemm(x_t, w_t)
            new_y, = Transpose(None)(gemm_out)

            OptimizeRule.replace_variable(graph, new_y, y)

        return graph, flag_changed
Example #9
0
def _split_sgemm(graph: Graph, op: Sgemm, v: Variable, v_pair: Sequence[Variable], axis: Axis):
    """
    Replace a Sgemm operator by two smaller Sgemm operators after one of its variables has been split.

    `v` is one of the variables connected to `op` and `v_pair` holds the two pieces into which it was split
    along `axis`. `op` is removed from the graph. Depending on which logical dimension (M, K or N) contains
    `axis`, the two partial results are merged by addition (K) or by concatenation followed by a reshape
    (M and N). The merged variable replaces the original output `C` through `OptimizeRule.replace_variable`.

    Args:
        graph: the graph which contains `op`
        op: the Sgemm operator to replace
        v: the variable of `op` which was split
        v_pair: the two variables into which `v` was split; their sizes along `axis` are the split sizes
        axis: the axis along which `v` was split

    Raises:
        NotImplementedError: if `v` is the output `C` of `op`
        UnexpectedAndPleaseReportError: if `v` is none of `A`, `B` and `C`
        ValueError: if the shape of `v` cannot be factorized into the logical dimensions of the Sgemm
    """
    s1 = v_pair[0].shape_dict[axis]
    s2 = v_pair[1].shape_dict[axis]
    A = op.inputs["A"]
    B = op.inputs["B"]
    C = op.outputs["C"]
    transpose_A, transpose_B = op.transpose_A, op.transpose_B
    M, K, N = op.M, op.K, op.N
    # Fresh axes which stand for a whole logical dimension in the temporary orders built below.
    axis_M, axis_K, axis_N = Axis(None), Axis(None), Axis(None)

    op.remove_all()

    def decompose_logical_axes(logical_shape: Tuple[int, int], v: Variable):
        """
        Decompose logical axes into real axes

        The leading axes of `v` whose sizes multiply up to `logical_shape[0]` form the first group, and all
        remaining axes form the second group.

        Examples::

            A.order, A.shape
            >>> "NCHW", (1, 128, 8, 8)

            M = 128
            K = 64
            decompose_logical_axes([M, K], A)
            >>> ["<Axis N>", "<Axis C>"], ["<Axis H>", "<Axis W>"]

        Raises:
            ValueError: if no prefix of the axes of `v` has a total size of exactly `logical_shape[0]`
        """
        total_size = 1
        axes1 = []  # type: List[Axis]
        axes2 = list(v.order.axes)  # type: List[Axis]
        for size, a in zip(v.shape, v.order.axes):
            if total_size == logical_shape[0]:
                return axes1, axes2

            elif total_size > logical_shape[0]:
                raise ValueError(f"Shape {v.shape} cannot be decomposed into logical shape {logical_shape}")

            axes1.append(a)
            axes2.remove(a)
            total_size *= size

        # The loop checks the size only before consuming an axis, so the case where every axis belongs to
        # the first group (or where the last axis overshoots) has to be handled here. Without this check
        # the function would implicitly return None and the caller would fail while unpacking the result.
        if total_size == logical_shape[0]:
            return axes1, axes2

        raise ValueError(f"Shape {v.shape} cannot be decomposed into logical shape {logical_shape}")

    if v == A:
        A1, A2 = v_pair
        if transpose_A:  # A.shape = [M, K]
            axes_M, axes_K = decompose_logical_axes((M, K), A)

        else:  # A.shape = [K, M]
            axes_K, axes_M = decompose_logical_axes((K, M), A)

        if axis in axes_K:
            """
            before)

                A -{sgemm}- C

            after) In case `axis` is in `K`,

                A_0 -{sgemm}- C_0 -+
                                   +-{Add}- C
                A_1 -{sgemm}- C_1 -+
            """
            K1, K2 = K * s1 // (s1 + s2), K * s2 // (s1 + s2)

            # Factorize B's axes included in K into A's corresponding axes
            if transpose_B:  # B: [k_b1, k_b2, ..., N] -{reshape}-> [k_a1, k_a2, ..., N]
                B, = Reshape(None,
                             in_order=B.order,
                             out_order=Order(axes_K + [axis_N]),
                             out_shape=[A.shape_dict[a] for a in axes_K] + [N])(B)
            else:  # B: [N, k_b1, k_b2, ...] -{reshape}-> [N, k_a1, k_a2, ...]
                B, = Reshape(None,
                             in_order=B.order,
                             out_order=Order([axis_N] + axes_K),
                             out_shape=[N] + [A.shape_dict[a] for a in axes_K])(B)

            # Split B along the same axis and at the same position as A was split.
            B1, B2 = SplitAxis(None, axis=axis, sections=[s1])(B)

            C1, = Sgemm(None, M=M, K=K1, N=N,
                        transpose_A=transpose_A,
                        transpose_B=transpose_B,
                        out_shape=op.parameters["out_shape"],
                        out_order=op.parameters["out_order"])(A1, B1)

            C2, = Sgemm(None, M=M, K=K2, N=N,
                        transpose_A=transpose_A,
                        transpose_B=transpose_B,
                        out_shape=op.parameters["out_shape"],
                        out_order=op.parameters["out_order"])(A2, B2)

            OptimizeRule.replace_variable(graph, C1 + C2, C)

        else:
            assert axis in axes_M
            """
            before)

                A -{sgemm}- C

            after) In case `axis` is in `M`,

                A_0 -{sgemm}- C_0 -+
                                   +-{Concat}- C
                A_1 -{sgemm}- C_1 -+
            """
            M1, M2 = M * s1 // (s1 + s2), M * s2 // (s1 + s2)

            # The partial results keep A's real M axes so that they can be concatenated along `axis`.
            c_tmp_order = Order(axes_M + [axis_N])
            c1_shape = [A1.shape_dict[a] for a in axes_M] + [N]
            c2_shape = [A2.shape_dict[a] for a in axes_M] + [N]

            C1, = Sgemm(None, M=M1, K=K, N=N,
                        transpose_A=transpose_A,
                        transpose_B=transpose_B,
                        out_shape=c1_shape,
                        out_order=c_tmp_order)(A1, B)

            C2, = Sgemm(None, M=M2, K=K, N=N,
                        transpose_A=transpose_A,
                        transpose_B=transpose_B,
                        out_shape=c2_shape,
                        out_order=c_tmp_order)(A2, B)

            C_new, = Concat(None, axis=axis)(C1, C2)
            C_new, = Reshape(None, in_order=c_tmp_order, out_order=C.order, out_shape=C.shape)(C_new)
            OptimizeRule.replace_variable(graph, C_new, C)

    elif v == B:
        B1, B2 = v_pair
        if transpose_B:  # B.shape = [K, N]
            axes_K, axes_N = decompose_logical_axes((K, N), B)

        else:  # B.shape = [N, K]
            axes_N, axes_K = decompose_logical_axes((N, K), B)

        if axis in axes_K:
            """
            before)

                B -{sgemm}- C

            after) In case `axis` is in `K`,

                B_0 -{sgemm}- C_0 -+
                                   +-{Add}- C
                B_1 -{sgemm}- C_1 -+
            """
            K1, K2 = K * s1 // (s1 + s2), K * s2 // (s1 + s2)

            # Factorize A's axes included in K into B's corresponding axes
            if transpose_A:  # A: [M, k_a1, k_a2, k_a3, ...] -{reshape}-> [M, k_b1, k_b2, ...]
                A, = Reshape(None,
                             in_order=A.order,
                             out_order=Order([axis_M] + axes_K),
                             out_shape=[M] + [B.shape_dict[a] for a in axes_K])(A)
            else:  # A: [k_a1, k_a2, k_a3, ..., M] -{reshape}-> [k_b1, k_b2, ..., M]
                A, = Reshape(None,
                             in_order=A.order,
                             out_order=Order(axes_K + [axis_M]),
                             out_shape=[B.shape_dict[a] for a in axes_K] + [M])(A)

            # Split A along the same axis and at the same position as B was split.
            A1, A2 = SplitAxis(None, axis=axis, sections=[s1])(A)

            C1, = Sgemm(None, M=M, K=K1, N=N,
                        transpose_A=transpose_A,
                        transpose_B=transpose_B,
                        out_shape=op.parameters["out_shape"],
                        out_order=op.parameters["out_order"])(A1, B1)

            C2, = Sgemm(None, M=M, K=K2, N=N,
                        transpose_A=transpose_A,
                        transpose_B=transpose_B,
                        out_shape=op.parameters["out_shape"],
                        out_order=op.parameters["out_order"])(A2, B2)

            OptimizeRule.replace_variable(graph, C1 + C2, C)

        else:
            assert axis in axes_N
            """
            before)

                C[M, N] = A[M, K] @ B[K, N]

            after) In case `axis` is in `N`,

                C[M, N] = Concat(C1[M, N1], C2[M, N2])
                        = Concat(A[M, K] @ B1[K, N1], A[M, K] @ B2[K, N2])
            """
            N1, N2 = N * s1 // (s1 + s2), N * s2 // (s1 + s2)

            c_tmp_order = Order([axis_M] + axes_N)
            c1_shape = [M] + [B1.shape_dict[a] for a in axes_N]
            c2_shape = [M] + [B2.shape_dict[a] for a in axes_N]

            C1, = Sgemm(None, M=M, K=K, N=N1,
                        transpose_A=transpose_A,
                        transpose_B=transpose_B,
                        out_shape=c1_shape,
                        out_order=c_tmp_order)(A, B1)
            # C1.shape = [M, B.shape_dict[n1], B.shape_dict[n2], ..., B1.shape_dict[axis], ...]
            # C1.order = [axis_M, n1, n2, ..., axis, ...]

            C2, = Sgemm(None, M=M, K=K, N=N2,
                        transpose_A=transpose_A,
                        transpose_B=transpose_B,
                        out_shape=c2_shape,
                        out_order=c_tmp_order)(A, B2)

            C_new, = Concat(None, axis=axis)(C1, C2)
            # C_new.shape = [M, B.shape_dict[n1], B.shape_dict[n2], ..., B1.shape_dict[axis]+B2.shape_dict[axis], ...]
            # C_new.order = [axis_M, n1, n2, ..., axis, ...]

            C_new, = Reshape(None, in_order=c_tmp_order, out_order=C.order, out_shape=C.shape)(C_new)
            OptimizeRule.replace_variable(graph, C_new, C)

    elif v == C:
        """
        before)

            C[M, N] = A[M, K] @ B[K, N]

        after) In case `axis` is in `N`,

            C[M, N1] = Concat(A[M, K] @ B1[K, N1])
            C[M, N2] = Concat(A[M, K] @ B2[K, N2])
        """
        raise NotImplementedError(f"Variable is too large to handle in WebGL backend: {v}")

    else:
        raise UnexpectedAndPleaseReportError