def symbolic_match(a: ty.handle, b: ty.handle, n: ty.int32,
                   m: ty.int32) -> None:
    # TIR script whose buffer shapes are expressions of the scalar
    # parameters n and m.
    #
    # For every i in [0, n) a block matches one sub-region of A and one of B,
    # stores 1 into every element of the A sub-region, and passes the data
    # handle, element offset, strides and shape of the B sub-region to
    # tir.intrin_test.
    A = tir.match_buffer(a, (n * m, m))
    B = tir.match_buffer(b, (n * 2, m * 4))
    for i in range(0, n):
        # Block declared with an empty list of block iteration variables.
        with tir.block([]):
            tir.reads([])
            # NOTE(review): the A region declared here spans rows
            # i*m .. i*m + n, while the region matched as sub_A below spans
            # i*m .. i*m + m. Likewise the B rows start at i*n although B is
            # declared with n*2 rows. Confirm whether this is intentional.
            tir.writes([A[i * m:i * m + n, 0:m], B[i * n:i * n + 2, 0:m * 4]])
            # Fresh int32 variables used as the two strides of sub_B.
            Bs_0 = tir.var("int32")
            Bs_1 = tir.var("int32")
            # (m, m) view of A; no strides are given for this one.
            sub_A = tir.match_buffer(A[i * m:i * m + m, 0:m], (m, m),
                                     offset_factor=1)
            # (2, m*4) view of B with symbolic strides [Bs_0, Bs_1].
            sub_B = tir.match_buffer(B[i * n:i * n + 2, 0:m * 4], (2, m * 4),
                                     strides=[Bs_0, Bs_1],
                                     offset_factor=1)
            for ii, jj in tir.grid(m, m):
                sub_A[ii, jj] = 1
            # j is not referenced inside the call, so the same call expression
            # is evaluated on each of the 4 iterations.
            for j in range(0, 4):
                tir.evaluate(
                    tir.intrin_test(
                        sub_B.data,
                        sub_B.elem_offset,
                        sub_B.strides[0],
                        sub_B.strides[1],
                        sub_B.shape[0],
                        sub_B.shape[1],
                        dtype="handle",
                    ))
def recursive_match(a: ty.handle, b: ty.handle) -> None:
    # TIR script with two levels of nested match_buffer.
    #
    # Outer level: for each (i, j, k) a 16x16 tile of A and of B is matched
    # as sub_A / sub_B. Inner level: for each (jj, kk) a 4x4 tile of sub_A
    # and of sub_B is matched as sub_sub_A / sub_sub_B. The metadata of
    # sub_sub_A is passed to tir.intrin_test, and sub_sub_B is filled with 1.
    A = tir.match_buffer(a, (64, 64, 64))
    B = tir.match_buffer(b, (64, 64, 64))
    for i, j, k in tir.grid(64, 4, 4):
        with tir.block([]):
            tir.reads([])
            tir.writes([
                A[i, j * 16:j * 16 + 16, k * 16:k * 16 + 16],
                B[i, j * 16:j * 16 + 16, k * 16:k * 16 + 16],
            ])
            # Symbolic strides for the first-level view of A.
            As_0 = tir.var("int32")
            As_1 = tir.var("int32")
            sub_A = tir.match_buffer(
                A[i, j * 16:j * 16 + 16, k * 16:k * 16 + 16],
                (16, 16),
                strides=[As_0, As_1],
                offset_factor=1,
            )
            # The first-level view of B is matched without explicit strides.
            sub_B = tir.match_buffer(
                B[i, j * 16:j * 16 + 16, k * 16:k * 16 + 16],
                (16, 16),
                offset_factor=1,
            )
            for jj, kk in tir.grid(4, 4):
                with tir.block([]):
                    tir.reads([])
                    # The write regions are expressed on the first-level views
                    # (sub_A / sub_B), not on A / B directly.
                    tir.writes([
                        sub_A[jj * 4:jj * 4 + 4, kk * 4:kk * 4 + 4],
                        sub_B[jj * 4:jj * 4 + 4, kk * 4:kk * 4 + 4],
                    ])
                    # Symbolic strides for the second-level view of A.
                    Ass_0 = tir.var("int32")
                    Ass_1 = tir.var("int32")
                    sub_sub_A = tir.match_buffer(
                        sub_A[jj * 4:jj * 4 + 4, kk * 4:kk * 4 + 4],
                        (4, 4),
                        strides=[Ass_0, Ass_1],
                        offset_factor=1,
                    )
                    sub_sub_B = tir.match_buffer(
                        sub_B[jj * 4:jj * 4 + 4, kk * 4:kk * 4 + 4],
                        (4, 4),
                        offset_factor=1,
                    )
                    # Pass the metadata of the 4x4 view of A to the intrinsic.
                    tir.evaluate(
                        tir.intrin_test(
                            sub_sub_A.data,
                            sub_sub_A.elem_offset,
                            sub_sub_A.strides[0],
                            sub_sub_A.strides[1],
                            sub_sub_A.shape[0],
                            sub_sub_A.shape[1],
                            dtype="handle",
                        ))
                    # Element-wise store into the 4x4 view of B.
                    for jjj, kkk in tir.grid(4, 4):
                        sub_sub_B[jjj, kkk] = 1
예제 #3
0
def matmul_m_8x(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # TIR script for a reduction of the form
    #     C[vi, vj] = sum over vk of A[vi, vk] * B[vj, vk]
    # where A and B are both declared as (m, x * 8) and C as (m, m).
    # Both operands are indexed with vk on their second axis.
    x = tir.var("int32")
    m = tir.var("int32")
    A = tir.match_buffer(a, [m, x * 8])
    B = tir.match_buffer(b, [m, x * 8])
    C = tir.match_buffer(c, [m, m])

    # Block "update": two axes of extent m plus a reduce axis over [0, x * 8).
    with tir.block([m, m, tir.reduce_axis(0, x * 8)],
                   "update") as [vi, vj, vk]:
        # Initial value of the accumulator.
        with tir.init():
            C[vi, vj] = 0.0
        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
예제 #4
0
 def test_tir_fma(A: ty.handle, B: ty.handle, C: ty.handle,
                  d: ty.handle) -> None:
     # TIR script computing, for every i in [0, n),
     #     d[i * stride_3] = A[i * stride] * B[i * stride_1] + C[i * stride_2]
     # through direct loads/stores on the buffers' data handles.
     # All four buffers are 1-D of length n, each with its own symbolic stride.
     # function attr dict
     tir.func_attr({"global_symbol": "test_fma", "tir.noalias": True})
     n = tir.var("int32")
     # One symbolic stride variable per buffer.
     stride = tir.var("int32")
     stride_1 = tir.var("int32")
     stride_2 = tir.var("int32")
     stride_3 = tir.var("int32")
     # The four match_buffer calls below differ only in the handle and in the
     # stride variable they use.
     A_1 = tir.match_buffer(
         A,
         [n],
         strides=[stride],
         elem_offset=0,
         align=128,
         offset_factor=1,
         type="auto",
     )
     B_1 = tir.match_buffer(
         B,
         [n],
         strides=[stride_1],
         elem_offset=0,
         align=128,
         offset_factor=1,
         type="auto",
     )
     C_1 = tir.match_buffer(
         C,
         [n],
         strides=[stride_2],
         elem_offset=0,
         align=128,
         offset_factor=1,
         type="auto",
     )
     d_1 = tir.match_buffer(
         d,
         [n],
         strides=[stride_3],
         elem_offset=0,
         align=128,
         offset_factor=1,
         type="auto",
     )
     # body
     # Multiply-then-add written as (A * B) + C on float32 loads; the element
     # index of each access is i scaled by that buffer's stride variable.
     for i in tir.serial(0, n):
         d_1.data[(i * stride_3)] = (tir.load("float32", A_1.data,
                                              (i * stride)) *
                                     tir.load("float32", B_1.data,
                                              (i * stride_1))) + tir.load(
                                                  "float32", C_1.data,
                                                  (i * stride_2))
예제 #5
0
 def tir_packed_call() -> None:
     # TIR script that declares three handle variables and evaluates a single
     # tir.tvm_call_cpacked call to "tvm_test_cpacked" with them as arguments.
     A = tir.var("handle")
     B = tir.var("handle")
     C = tir.var("handle")
     # body
     # The handles are passed directly as call arguments; the call's dtype is
     # "int32".
     tir.evaluate(
         tir.tvm_call_cpacked(
             "tvm_test_cpacked",
             A,
             B,
             C,
             dtype="int32",
         ))
예제 #6
0
def tir_multi_output(a0: ty.handle, a1: ty.handle, b0: ty.handle, b1: ty.handle) -> None:
    # TIR script with two inputs (A0, A1) and two outputs (B0, B1), all
    # declared with the same symbolic shape (m, n):
    #     B0[i, j] = A0[i, j] + 2.0
    #     B1[i, j] = A1[i, j] * 3.0
    m = tir.var("int32")
    n = tir.var("int32")
    A0 = tir.match_buffer(a0, (m, n))
    A1 = tir.match_buffer(a1, (m, n))
    B0 = tir.match_buffer(b0, (m, n))
    B1 = tir.match_buffer(b1, (m, n))

    # Both blocks sit inside the same (m, n) loop nest. The loop variables
    # i0 / i1 are not referenced explicitly in the block bodies.
    # NOTE(review): presumably the script parser binds them to the block
    # variables i / j -- confirm.
    for i0, i1 in tir.grid(m, n):
        with tir.block([m, n], "B.v0") as [i, j]:
            B0[i, j] = A0[i, j] + 2.0
        with tir.block([m, n], "B.v1") as [i, j]:
            B1[i, j] = A1[i, j] * 3.0
예제 #7
0
def element_wise(a: ty.handle, c: ty.handle) -> None:
    # TIR script for a two-stage element-wise computation over float32
    # buffers of symbolic shape (m, n):
    #     B = A * 2.0   (B is an intermediate buffer allocated in this function)
    #     C = B + 1.0
    m = tir.var("int32")
    n = tir.var("int32")
    A = tir.match_buffer(a, (m, n), "float32")
    C = tir.match_buffer(c, (m, n), "float32")

    # Intermediate buffer; it is not bound to any function parameter.
    B = tir.alloc_buffer((m, n), "float32")

    # Stage 1: writes B from A.
    with tir.block([m, n], "B") as [vi, vj]:
        B[vi, vj] = A[vi, vj] * 2.0

    # Stage 2: writes C from B.
    with tir.block([m, n], "C") as [vi, vj]:
        C[vi, vj] = B[vi, vj] + 1.0
def opaque_access(a: ty.handle, b: ty.handle) -> None:
    # TIR script with two independent loop nests. Each matches a sub-region
    # of a buffer and passes the matched buffer's data handle, element offset,
    # first two strides and first two shape entries to tir.intrin_test.
    # The matched buffers are never indexed element-wise.
    A = tir.match_buffer(a, (32, 64, 128))
    B = tir.match_buffer(b, (64, 64, 64))
    # First nest: a (16, 1, 16) region of A with constant strides.
    for i, j, k in tir.grid(2, 64, 8):
        with tir.block([]):
            tir.reads([])
            tir.writes(A[i * 16:i * 16 + 16, j, k * 16:k * 16 + 16])
            # The strides are given as constants. 8192 = 64 * 128 and 128
            # equal the compact row-major strides of A's declared shape
            # (32, 64, 128).
            sub_A = tir.match_buffer(
                A[i * 16:i * 16 + 16, j, k * 16:k * 16 + 16],
                (16, 1, 16),
                strides=[8192, 128, 1],
                offset_factor=1,
            )
            # sub_A is 3-D, but only entries 0 and 1 of its strides and shape
            # are passed to the intrinsic.
            tir.evaluate(
                tir.intrin_test(
                    sub_A.data,
                    sub_A.elem_offset,
                    sub_A.strides[0],
                    sub_A.strides[1],
                    sub_A.shape[0],
                    sub_A.shape[1],
                    dtype="handle",
                ))
    # Second nest: a (32, 8) region of B with symbolic strides.
    for i, j, k in tir.grid(64, 2, 8):
        with tir.block([]):
            # Fresh int32 variables used as the strides of sub_B.
            Bs_0 = tir.var("int32")
            Bs_1 = tir.var("int32")
            tir.reads([])
            tir.writes(B[i, j * 32:j * 32 + 32, k * 8:k * 8 + 8])
            # The leading axis of the region is the single index i, so the
            # matched buffer is declared 2-D.
            sub_B = tir.match_buffer(
                B[i, j * 32:j * 32 + 32, k * 8:k * 8 + 8],
                (32, 8),
                strides=[Bs_0, Bs_1],
                offset_factor=1,
            )
            tir.evaluate(
                tir.intrin_test(
                    sub_B.data,
                    sub_B.elem_offset,
                    sub_B.strides[0],
                    sub_B.strides[1],
                    sub_B.shape[0],
                    sub_B.shape[1],
                    dtype="handle",
                ))
def fail_buffer_bind(a: ty.handle) -> None:
    # TIR script that matches a (1, 4) sub-region of the 8x8 buffer A using
    # the SAME variable for both strides, then stores 1 through the match.
    # NOTE(review): judging by the name, this is presumably a negative case
    # that a later pass is expected to reject -- confirm against the test
    # that consumes it.
    A = tir.match_buffer(a, (8, 8))
    for i, j in tir.grid(8, 2):
        with tir.block([]):
            # A single variable, used below for both stride entries.
            stride = tir.var("int32")
            sub_A = tir.match_buffer(A[i, j * 4:j * 4 + 4], (1, 4),
                                     strides=[stride, stride],
                                     offset_factor=1)
            # The store indexes sub_A with (i, j * 4 + jj), i.e. with the same
            # expressions that address the region inside A, even though sub_A
            # is declared with shape (1, 4).
            for jj in range(0, 4):
                sub_A[i, j * 4 + jj] = 1
예제 #10
0
    def tir_packed_call() -> None:
        # TIR script that calls "tvm_test_cpacked" with three stack-allocated
        # values instead of the raw handles A, B and C.
        #
        # For each of A, B and C: a value is obtained from
        # tir.tvm_stack_alloca("array", 1), bound with tir.let, and then
        # tir.tvm_struct_set stores the handle into it. The three let-bound
        # values are the arguments of the final tir.tvm_call_cpacked.
        A = tir.var("handle")
        B = tir.var("handle")
        C = tir.var("handle")

        # body
        tvm_value_2 = tir.var("handle")
        tvm_value_1 = tir.var("handle")
        tvm_value_0 = tir.var("handle")
        # The lets are nested in the order tvm_value_2, tvm_value_1,
        # tvm_value_0 (outermost first).
        with tir.let(tvm_value_2,
                     tir.tvm_stack_alloca("array", 1, dtype="handle")):
            with tir.let(tvm_value_1,
                         tir.tvm_stack_alloca("array", 1, dtype="handle")):
                with tir.let(tvm_value_0,
                             tir.tvm_stack_alloca("array", 1, dtype="handle")):
                    # Store A, B and C into tvm_value_0, _1 and _2.
                    # NOTE(review): the literal arguments 0 and 1 are
                    # presumably an index and a field selector -- confirm
                    # against the tvm_struct_set documentation.
                    tir.evaluate(
                        tir.tvm_struct_set(tvm_value_0,
                                           0,
                                           1,
                                           A,
                                           dtype="handle"))
                    tir.evaluate(
                        tir.tvm_struct_set(tvm_value_1,
                                           0,
                                           1,
                                           B,
                                           dtype="handle"))
                    tir.evaluate(
                        tir.tvm_struct_set(tvm_value_2,
                                           0,
                                           1,
                                           C,
                                           dtype="handle"))
                    # Argument order is tvm_value_0, tvm_value_1, tvm_value_2,
                    # i.e. the values holding A, B, C.
                    tir.evaluate(
                        tir.tvm_call_cpacked(
                            "tvm_test_cpacked",
                            tvm_value_0,
                            tvm_value_1,
                            tvm_value_2,
                            dtype="int32",
                        ))
예제 #11
0
def param_in_arith_exprs_n_16(a: ty.handle, b: ty.handle) -> None:
    # TIR script over constant-shaped int32 buffers: A is (2, 8), B is (16,).
    # For vi in [0, 15):  B[vi] = A[vi // 8, vi % 8] + 714
    # NOTE(review): the constants match param_in_arith_exprs evaluated at
    # n = 16 (16 // 8 = 2, 16 - 1 = 15, (16 + 1) * 42 = 714), so this is
    # presumably its specialised form -- confirm.
    # n is declared but not referenced anywhere below.
    n = tir.var("int32")
    A = tir.match_buffer(a, [2, 8], "int32")
    B = tir.match_buffer(b, [16], "int32")
    # The block extent is 15, so B[15] is not written.
    with tir.block([15], "") as [vi]:
        B[vi] = A[vi // 8, vi % 8] + 714
예제 #12
0
def param_in_arith_exprs(a: ty.handle, b: ty.handle) -> None:
    # TIR script in which the symbolic variable n appears inside arithmetic
    # expressions: in a buffer shape (n // 8), in the block extent (n - 1)
    # and in the stored value ((n + 1) * 42).
    # For vi in [0, n - 1):  B[vi] = A[vi // 8, vi % 8] + (n + 1) * 42
    n = tir.var("int32")
    A = tir.match_buffer(a, [n // 8, 8], "int32")
    B = tir.match_buffer(b, [n], "int32")
    # The block extent is n - 1, so B[n - 1] is not written.
    with tir.block([n - 1], "") as [vi]:
        B[vi] = A[vi // 8, vi % 8] + (n + 1) * 42
예제 #13
0
def tensorcore_gemm(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # TIR script for a 1024x1024 GEMM with float16 inputs (A, B) and a float32
    # output (C), expressed with 16x16 fragment intrinsics.
    #
    # Structure of the body, per (blockIdx.x, blockIdx.y, threadIdx.y,
    # threadIdx.z):
    #   1. tvm_fill_fragment on each of the 2x4 accumulator tiles.
    #   2. For ko in [0, 32):
    #        a. copy elements of A and B into shared_A / shared_B;
    #        b. for ki in [0, 2): tvm_load_matrix_sync for the A tiles and
    #           the B tiles, then tvm_mma_sync on each of the 2x4 tiles.
    #   3. tvm_store_matrix_sync of each accumulator tile into C.
    #
    # NOTE(review): the loop variable named `ty` below has the same name as
    # the `ty` used in this function's parameter annotations -- confirm this
    # is harmless for the script parser in use.
    # match buffer
    A = tir.match_buffer(a, [1024, 1024], "float16")
    B = tir.match_buffer(b, [1024, 1024], "float16")
    C = tir.match_buffer(c, [1024, 1024], "float32")

    # body
    for blockIdx_x in tir.thread_binding(0, 16, "blockIdx.x"):
        for blockIdx_y in tir.thread_binding(0, 8, "blockIdx.y"):
            with tir.block([16, 8]) as [bx, by]:
                tir.bind(bx, blockIdx_x)
                tir.bind(by, blockIdx_y)
                # Intermediate buffers; each is declared with the full
                # 1024x1024 shape and a distinct storage scope string.
                shared_A = tir.alloc_buffer([1024, 1024], "float16", scope="shared")
                shared_B = tir.alloc_buffer([1024, 1024], "float16", scope="shared")
                wmma_A = tir.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_a")
                wmma_B = tir.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_b")
                wmma_C = tir.alloc_buffer([1024, 1024], "float32", scope="wmma.accumulator")
                for ty in tir.thread_binding(0, 2, "threadIdx.y"):
                    for tz in tir.thread_binding(0, 2, "threadIdx.z"):
                        # Step 1: fill every 16x16 accumulator tile with
                        # float32 zero. (vi, vj) index tiles in a 64x64 grid.
                        for i, j in tir.grid(2, 4):
                            with tir.block([64, 64]) as [vi, vj]:
                                tir.bind(vi, bx * 4 + ty * 2 + i)
                                tir.bind(vj, by * 8 + tz * 4 + j)
                                tir.reads([])
                                tir.writes(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                C0 = tir.match_buffer(
                                    wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                # i * 4 + j is passed as the index argument;
                                # it takes the values 0..7 over the 2x4 grid.
                                tir.evaluate(
                                    tir.tvm_fill_fragment(
                                        C0.data,
                                        16,
                                        16,
                                        16,
                                        i * 4 + j,
                                        tir.float32(0),
                                        dtype="handle",
                                    )
                                )

                        # Step 2: loop over the reduction dimension.
                        for ko in range(0, 32):
                            # copy data from global to shared
                            # NOTE(review): both copies write column vj + 8 of
                            # the shared buffer. With vj reaching 1023 this
                            # index exceeds the declared 1024 columns --
                            # confirm whether the allocation is meant to be
                            # padded.
                            for tx in tir.thread_binding(0, 32, "threadIdx.x"):
                                for i0, j0 in tir.grid(1, 4):
                                    for j1 in tir.vectorized(0, 4):
                                        with tir.block([1024, 1024]) as [vi, vj]:
                                            tir.bind(vi, bx * 64 + ty * 32 + tx + i0)
                                            tir.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1)
                                            shared_A[vi, vj + 8] = A[vi, vj]

                                for i0, j0 in tir.grid(2, 4):
                                    for j1 in tir.vectorized(0, 4):
                                        with tir.block([1024, 1024]) as [vi, vj]:
                                            tir.bind(vi, by * 128 + ty * 64 + tx * 2 + i0)
                                            tir.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1)
                                            shared_B[vi, vj + 8] = B[vi, vj]

                            for ki in range(0, 2):
                                # Load the A tiles: shared_A -> wmma_A.
                                for i in range(0, 2):
                                    with tir.block([64, 64]) as [vi, vk]:
                                        tir.bind(vi, bx * 4 + ty * 2 + i)
                                        tir.bind(vk, ko * 2 + ki)
                                        # The shared region is 16 x (16 + 8):
                                        # 8 extra columns relative to the
                                        # 16x16 destination tile.
                                        tir.reads(
                                            shared_A[
                                                vi * 16 : vi * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ]
                                        )
                                        tir.writes(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16]
                                        )
                                        # Symbolic strides for the shared view.
                                        s0 = tir.var("int32")
                                        s1 = tir.var("int32")
                                        A0 = tir.match_buffer(
                                            shared_A[
                                                vi * 16 : vi * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[s0, s1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_A0 = tir.match_buffer(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        # The source pointer starts at
                                        # A0.elem_offset + 8, matching the + 8
                                        # column offset used when shared_A was
                                        # written.
                                        tir.evaluate(
                                            tir.tvm_load_matrix_sync(
                                                wmma_A0.data,
                                                16,
                                                16,
                                                16,
                                                i,
                                                tir.tvm_access_ptr(
                                                    tir.type_annotation(dtype="float16"),
                                                    A0.data,
                                                    A0.elem_offset + 8,
                                                    A0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                A0.strides[0],
                                                "row_major",
                                                dtype="handle",
                                            )
                                        )
                                # Load the B tiles: shared_B -> wmma_B. Same
                                # shape as the A load, but indexed by j and
                                # with layout string "col_major".
                                for j in range(0, 4):
                                    with tir.block([64, 64]) as [vj, vk]:
                                        tir.bind(vj, by * 8 + tz * 4 + j)
                                        tir.bind(vk, ko * 2 + ki)
                                        tir.reads(
                                            shared_B[
                                                vj * 16 : vj * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ]
                                        )
                                        tir.writes(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16]
                                        )
                                        s0 = tir.var("int32")
                                        s1 = tir.var("int32")
                                        B0 = tir.match_buffer(
                                            shared_B[
                                                vj * 16 : vj * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[s0, s1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_B0 = tir.match_buffer(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        tir.evaluate(
                                            tir.tvm_load_matrix_sync(
                                                wmma_B0.data,
                                                16,
                                                16,
                                                16,
                                                j,
                                                tir.tvm_access_ptr(
                                                    tir.type_annotation(dtype="float16"),
                                                    B0.data,
                                                    B0.elem_offset + 8,
                                                    B0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                B0.strides[0],
                                                "col_major",
                                                dtype="handle",
                                            )
                                        )
                                # Multiply-accumulate on every (i, j) tile;
                                # vk is declared as a reduce axis.
                                for i, j in tir.grid(2, 4):
                                    with tir.block([64, 64, tir.reduce_axis(0, 64)]) as [
                                        vi,
                                        vj,
                                        vk,
                                    ]:
                                        tir.bind(vi, bx * 4 + ty * 2 + i)
                                        tir.bind(vj, by * 8 + tz * 4 + j)
                                        tir.bind(vk, ko * 2 + ki)
                                        # The accumulator tile appears in both
                                        # the read and the write regions.
                                        tir.reads(
                                            [
                                                wmma_A[
                                                    vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16
                                                ],
                                                wmma_B[
                                                    vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16
                                                ],
                                                wmma_C[
                                                    vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16
                                                ],
                                            ]
                                        )
                                        tir.writes(
                                            wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]
                                        )
                                        wmma_A1 = tir.match_buffer(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        wmma_B1 = tir.match_buffer(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        wmma_C1 = tir.match_buffer(
                                            wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                            (16, 16),
                                            "float32",
                                            strides=[16 * 4, 1],
                                            scope="wmma.accumulator",
                                            offset_factor=1,
                                        )
                                        # The arguments come as (data, index)
                                        # pairs. The accumulator pair is passed
                                        # first and again last, with the A pair
                                        # (index i) and the B pair (index j)
                                        # in between.
                                        tir.evaluate(
                                            tir.tvm_mma_sync(
                                                wmma_C1.data,
                                                i * 4 + j,
                                                wmma_A1.data,
                                                i,
                                                wmma_B1.data,
                                                j,
                                                wmma_C1.data,
                                                i * 4 + j,
                                                dtype="handle",
                                            )
                                        )
                        # Step 3: store every accumulator tile into C.
                        for i, j in tir.grid(2, 4):
                            with tir.block([64, 64]) as [vi, vj]:
                                tir.bind(vi, bx * 4 + ty * 2 + i)
                                tir.bind(vj, by * 8 + tz * 4 + j)
                                tir.reads(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                tir.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                # Symbolic strides for the view of the output
                                # buffer C.
                                s0 = tir.var("int32")
                                s1 = tir.var("int32")
                                wmma_C2 = tir.match_buffer(
                                    wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                C1 = tir.match_buffer(
                                    C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[s0, s1],
                                    offset_factor=1,
                                )
                                # Unlike the loads, the destination pointer
                                # uses C1.elem_offset with no extra offset.
                                tir.evaluate(
                                    tir.tvm_store_matrix_sync(
                                        wmma_C2.data,
                                        16,
                                        16,
                                        16,
                                        i * 4 + j,
                                        tir.tvm_access_ptr(
                                            tir.type_annotation(dtype="float32"),
                                            C1.data,
                                            C1.elem_offset,
                                            C1.strides[0],
                                            1,
                                            dtype="handle",
                                        ),
                                        C1.strides[0],
                                        "row_major",
                                        dtype="handle",
                                    )
                                )