Exemple #1
0
def element_wise_invalid_annotation(a: ty.handle, c: ty.handle) -> None:
    # Two-stage element-wise pipeline over 128x128 buffers:
    #   block "B": B[vi, vj] = A[vi, vj] * 2
    #   block "C": C[vi, vj] = B[vi, vj] + 1
    # Both stages sit under the same outer loop i0, and the intermediate
    # buffer B is allocated inside the "root" block.
    # NOTE(review): the function name suggests that the "buffer_dim_align"
    # attribute on block "B" is deliberately malformed -- confirm against the
    # test that consumes this function before "fixing" it.
    C = tir.match_buffer(c, [128, 128],
                         elem_offset=0,
                         align=128,
                         offset_factor=1)
    A = tir.match_buffer(a, [128, 128],
                         elem_offset=0,
                         align=128,
                         offset_factor=1)
    # body
    with tir.block([], "root"):
        # The root block declares empty read and write regions.
        tir.reads([])
        tir.writes([])
        B = tir.alloc_buffer([128, 128],
                             elem_offset=0,
                             align=128,
                             offset_factor=1)
        for i0 in tir.serial(0, 128):
            # First stage: producer of B, iterating the inner axis as ax1.
            for ax1 in tir.serial(0, 128):
                with tir.block([128, 128], "B") as [vi, vj]:
                    # The attribute value is a one-element list holding 0.
                    tir.block_attr({"buffer_dim_align": [0]})
                    tir.bind(vi, i0)
                    tir.bind(vj, ax1)
                    tir.reads([A[vi, vj]])
                    tir.writes([B[vi, vj]])
                    B[vi, vj] = (A[vi, vj] * tir.float32(2))
            # Second stage: consumer of B, iterating the inner axis as i1.
            for i1 in tir.serial(0, 128):
                with tir.block([128, 128], "C") as [vi_1, vj_1]:
                    tir.bind(vi_1, i0)
                    tir.bind(vj_1, i1)
                    tir.reads([B[vi_1, vj_1]])
                    tir.writes([C[vi_1, vj_1]])
                    C[vi_1, vj_1] = (B[vi_1, vj_1] + tir.float32(1))
Exemple #2
0
def get_valid_counts(
    data: ty.handle,
    valid_count: ty.handle,
    out: ty.handle,
    out_indices: ty.handle,
    score_threshold: ty.float32,
    id_index: ty.int32,
    score_index: ty.int32,
) -> None:
    # Compacts the rows of `data` that pass a score/id filter to the front of
    # `out`, records their source row in `out_indices`, and counts them in
    # `valid_count`.
    #
    # A row vj is kept when
    #   data[vi, vj, score_index] > score_threshold
    #   and (id_index < 0 or data[vi, vj, id_index] >= 0).
    # After the filter step, a row position vj that is >= the current count
    # is filled with -1 in both `out` and `out_indices`.

    data_buf = tir.match_buffer(data, (1, 2500, 6), "float32")
    valid_count_buf = tir.match_buffer(valid_count, (1, ), "int32")
    out_buf = tir.match_buffer(out, (1, 2500, 6), "float32")
    out_indices_buf = tir.match_buffer(out_indices, (1, 2500), "int32")

    with tir.block([1], "init") as [vi]:
        # Reset the running count before the "update" block runs.
        valid_count_buf[vi] = tir.int32(0)
        with tir.block([2500], "update") as [vj]:
            # NOTE(review): the declared regions below use index 6 on the last
            # axis, while that axis has extent 6 in the match_buffer shapes
            # above -- confirm whether this is intentional.
            tir.reads([data_buf[vi, vj, 6]])
            tir.writes([
                valid_count_buf[vi], out_indices_buf[vi, vj], out_buf[vi, vj,
                                                                      6]
            ])
            if (data_buf[vi, vj, score_index] > score_threshold) and (
                (id_index < 0) or
                (data_buf[vi, vj, id_index] >= tir.float32(0))):
                # Copy the 6 elements of the accepted row into the next free
                # output slot, which is indexed by the current count.
                for k in tir.serial(0, 6):
                    out_buf[vi, valid_count_buf[vi], k] = data_buf[vi, vj, k]
                out_indices_buf[vi, valid_count_buf[vi]] = vj
                valid_count_buf[vi] = valid_count_buf[vi] + 1
            # This check reads the count after the possible increment above.
            if vj >= valid_count_buf[vi]:
                for k in tir.serial(0, 6):
                    out_buf[vi, vj, k] = tir.float32(-1)
                out_indices_buf[vi, vj] = tir.int32(-1)
Exemple #3
0
def func_with_opaque_block(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # An outer block with no iter vars wraps two unnamed inner blocks:
    #   1. a block with no iter vars that writes only B[0, 0] = A[0, 0] + 1;
    #   2. a 128x128 block computing C[vi, vj] = B[vi, vj] + 1.
    # None of the blocks declares tir.reads / tir.writes explicitly.
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])

    with tir.block([]) as []:
        # Inner block without iter vars: touches one fixed element.
        with tir.block([]) as []:
            B[0, 0] = A[0, 0] + tir.float32(1)

        # Element-wise block over the full 128x128 domain.
        with tir.block([128, 128]) as [vi, vj]:
            C[vi, vj] = B[vi, vj] + tir.float32(1)
Exemple #4
0
def elementwise_with_root(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # Two chained element-wise stages over 128x128 buffers, both nested inside
    # an explicit outer block that has no iter vars:
    #   B[vi, vj] = A[vi, vj] + 1
    #   C[vi, vj] = B[vi, vj] + 1
    # No block declares tir.reads / tir.writes explicitly.
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])

    with tir.block([]) as []:
        # Stage 1: produce B from A.
        with tir.block([128, 128]) as [vi, vj]:
            B[vi, vj] = A[vi, vj] + tir.float32(1)

        # Stage 2: produce C from B.
        with tir.block([128, 128]) as [vi, vj]:
            C[vi, vj] = B[vi, vj] + tir.float32(1)
def func_with_part_access_region(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # Same two-stage computation (B = A + 1, then C = B + 1) as a plain
    # element-wise chain, but each inner block declares only one of its access
    # regions: the first declares reads only, the second writes only.
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])

    with tir.block([]) as []:
        with tir.block([128, 128]) as [vi, vj]:
            # Only the read region is declared; no tir.writes for B.
            tir.reads(A[vi, vj])
            B[vi, vj] = A[vi, vj] + tir.float32(1)

        with tir.block([128, 128]) as [vi, vj]:
            # Only the write region is declared; no tir.reads for B.
            tir.writes(C[vi, vj])
            C[vi, vj] = B[vi, vj] + tir.float32(1)
Exemple #6
0
def buffer_opaque_access(b: ty.handle, c: ty.handle) -> None:
    # Exercises accesses that do not go through Buffer[...] indexing:
    # tir.store / tir.load on a raw allocation, and an intrinsic call that
    # receives B.data. A final 16x16 block then copies B into C with ordinary
    # buffer indexing.
    B = tir.match_buffer(b, [16, 16], "float32")
    C = tir.match_buffer(c, [16, 16], "float32")

    with tir.block([]):
        # Declared regions: nothing read, the whole of B written.
        tir.reads([])
        tir.writes(B[0:16, 0:16])
        # A is a flat 256-element allocation, addressed below as i * 16 + j.
        A = tir.allocate([256], "float32", "global")
        for i, j in tir.grid(16, 16):
            tir.store(A, i * 16 + j, 1)
        for i in range(0, 16):
            for j in range(0, 16):
                # The loaded value is only evaluated, not stored anywhere.
                tir.evaluate(tir.load("float32", A, i * 16 + j))
            for j in range(0, 16):
                # B is touched only through its data handle here.
                tir.evaluate(
                    tir.tvm_fill_fragment(B.data,
                                          16,
                                          16,
                                          16,
                                          0,
                                          tir.float32(0),
                                          dtype="handle"))

    for i, j in tir.grid(16, 16):
        with tir.block([16, 16]) as [vi, vj]:
            tir.bind(vi, i)
            tir.bind(vj, j)
            C[vi, vj] = B[vi, vj]
def transformed_func() -> None:
    # Nested-block version of a 4x4-tiled accumulation into B.
    #   * Block 1 zero-fills A over [128, 128].
    #   * Block 2 iterates [32, 32, reduce_axis(0, 32)] and allocates B inside
    #     itself. When k == 0 it copies the (i, j) 4x4 tile of A into B.
    #   * For every (ii, jj) in the tile, a block with no iter vars and
    #     explicit reads/writes on one element of B allocates C and then
    #     accumulates into that element in two kk loops.
    #   * The second kk loop wraps each update in a further block with no
    #     iter vars, which allocates D.
    # C and D are only read in this function; nothing here writes to them.
    A = tir.alloc_buffer([128, 128])
    with tir.block([128, 128], "") as [i, j]:
        A[i, j] = tir.float32(0)
    with tir.block([32, 32, tir.reduce_axis(0, 32)], "") as [i, j, k]:
        B = tir.alloc_buffer([128, 128])
        # Initialisation is written as an explicit k == 0 branch, not as a
        # tir.init() section.
        if k == 0:
            for ii, jj in tir.grid(4, 4):
                B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj]
        for ii, jj in tir.grid(4, 4):
            with tir.block([], ""):
                tir.reads([B[((i * 4) + ii), ((j * 4) + jj)]])
                tir.writes([B[((i * 4) + ii), ((j * 4) + jj)]])
                C = tir.alloc_buffer([128, 128])
                # First accumulation: B += C.
                for kk in tir.serial(0, 4):
                    B[((i * 4) + ii),
                      ((j * 4) + jj)] = (B[((i * 4) + ii),
                                           ((j * 4) + jj)] + C[((i * 4) + ii),
                                                               ((k * 4) + kk)])
                # Second accumulation: B += D * C, one nested block per kk.
                for kk in tir.serial(0, 4):
                    with tir.block([], ""):
                        tir.reads([
                            B[((i * 4) + ii), ((j * 4) + jj)],
                            C[((i * 4) + ii), ((k * 4) + kk)],
                        ])
                        tir.writes([B[((i * 4) + ii), ((j * 4) + jj)]])
                        D = tir.alloc_buffer([128, 128])
                        B[((i * 4) + ii),
                          ((j * 4) +
                           jj)] = B[((i * 4) + ii),
                                    ((j * 4) + jj)] + (D[((j * 4) + jj), (
                                        (k * 4) + kk)] * C[((i * 4) + ii),
                                                           ((k * 4) + kk)])
Exemple #8
0
def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # 128x128 matrix multiply expressed as a single reduction block named
    # "update". B is indexed as B[vj, vk], so the result is
    #   C[vi, vj] = sum over vk of A[vi, vk] * B[vj, vk].
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])

    with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
        # Initialisation of the accumulator is given as a tir.init() section.
        with tir.init():
            C[vi, vj] = tir.float32(0)
        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
Exemple #9
0
    def main(a: ty.handle, b: ty.handle) -> None:
        # Reduces the 64x64x64 buffer A into the 64-element buffer B by
        # accumulating A[i, j, k] over two reduction iter vars j and k.
        # B[i] is zeroed by an explicit branch when j == 0 and k == 32; those
        # values equal the first arguments of the two tir.reduce_axis calls.
        # NOTE(review): presumably (32, 64) means k starts at 32 -- confirm
        # the argument convention of tir.reduce_axis.
        A = tir.match_buffer(a, [64, 64, 64])
        B = tir.match_buffer(b, [64])

        with tir.block([64,
                        tir.reduce_axis(0, 64),
                        tir.reduce_axis(32, 64)]) as [i, j, k]:
            if (j == 0) and (k == 32):
                B[i] = tir.float32(0)
            B[i] += A[i, j, k]
Exemple #10
0
def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # 128x128 matrix multiply with initialisation and accumulation split into
    # two separate blocks:
    #   "init":   C[vi, vj] = 0, once per (i, j)
    #   "update": C[vi, vj] += A[vi, vk] * B[vj, vk], inside the k loop
    # Neither block calls tir.bind explicitly.
    # NOTE(review): presumably the iter vars bind to the enclosing loops in
    # order -- confirm against the parser's auto-binding rules.
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    for i, j in tir.grid(128, 128):
        with tir.block([128, 128], "init") as [vi, vj]:
            C[vi, vj] = tir.float32(0)
        for k in range(0, 128):
            with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
Exemple #11
0
    def main(a: ty.handle, b: ty.handle) -> None:
        # Reduces the 64x64x64 buffer A into the 64-element buffer B, with all
        # accesses inside the block going through sub-region buffers created
        # by tir.match_buffer:
        #   BB: the single element B[i], matched with shape ()
        #   AA: the slice A[i, 0:64, 0:64], matched with shape (64, 64)
        # The accumulator is zeroed by an explicit branch when j == 0 and
        # k == 32, and AA[j, k] is then added to it.
        A = tir.match_buffer(a, [64, 64, 64])
        B = tir.match_buffer(b, [64])

        with tir.block([64,
                        tir.reduce_axis(0, 64),
                        tir.reduce_axis(32, 64)]) as [i, j, k]:
            BB = tir.match_buffer(B[i], ())
            AA = tir.match_buffer(A[i, 0:64, 0:64], (64, 64))
            if (j == 0) and (k == 32):
                # BB has an empty shape, so it is indexed with an empty tuple.
                BB[()] = tir.float32(0)
            BB[()] += AA[j, k]
Exemple #12
0
def buffer_load_store_func(a: ty.handle, b: ty.handle) -> None:
    # Tiled accumulation written with ordinary buffer indexing only.
    #   * Block 1 zero-fills A over [128, 128].
    #   * Block 2 iterates [32, 32, reduce_axis(0, 32)]. Its tir.init()
    #     section copies the (i, j) 4x4 tile of A into B; the body then adds
    #     C, and afterwards D * C, into each element of that tile.
    # C and D are allocated at function scope and are only read here; nothing
    # in this function writes to them.
    A = tir.match_buffer(a, (128, 128), "float32")
    B = tir.match_buffer(b, (128, 128), "float32")
    C = tir.alloc_buffer((128, 128), "float32")
    D = tir.alloc_buffer((128, 128), "float32")
    with tir.block([128, 128]) as [i, j]:
        A[i, j] = tir.float32(0)
    with tir.block([32, 32, tir.reduce_axis(0, 32)]) as [i, j, k]:
        with tir.init():
            for ii, jj in tir.grid(4, 4):
                B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj]
        for ii, jj in tir.grid(4, 4):
            # First accumulation: B += C.
            for kk in range(0, 4):
                B[i * 4 + ii, j * 4 + jj] += C[i * 4 + ii, k * 4 + kk]
            # Second accumulation: B += D * C.
            for kk in range(0, 4):
                B[i * 4 + ii, j * 4 + jj] += D[j * 4 + jj, k * 4 + kk] * C[i * 4 + ii, k * 4 + kk]
def original_func() -> None:
    # Tiled accumulation in which every buffer is a local allocation.
    #   * Block 1 zero-fills A over [128, 128].
    #   * Block 2 iterates [32, 32, reduce_axis(0, 32)] and allocates B, C and
    #     D inside itself. When k == 0 it copies the (i, j) 4x4 tile of A into
    #     B; it then adds C, and afterwards D * C, into each tile element.
    # C and D are only read in this function; nothing here writes to them.
    A = tir.alloc_buffer((128, 128), "float32")
    with tir.block([128, 128]) as [i, j]:
        A[i, j] = tir.float32(0)
    with tir.block([32, 32, tir.reduce_axis(0, 32)]) as [i, j, k]:
        B = tir.alloc_buffer((128, 128), "float32")
        C = tir.alloc_buffer((128, 128), "float32")
        D = tir.alloc_buffer((128, 128), "float32")
        # Initialisation is an explicit k == 0 branch, not a tir.init()
        # section.
        if k == 0:
            for ii, jj in tir.grid(4, 4):
                B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj]
        for ii, jj in tir.grid(4, 4):
            # First accumulation: B += C.
            for kk in range(0, 4):
                B[i * 4 + ii, j * 4 + jj] += C[i * 4 + ii, k * 4 + kk]
            # Second accumulation: B += D * C.
            for kk in range(0, 4):
                B[i * 4 + ii, j * 4 +
                  jj] += D[j * 4 + jj, k * 4 + kk] * C[i * 4 + ii, k * 4 + kk]
Exemple #14
0
def matmul_original(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # 128x128 matrix multiply tiled into 4x4(x4) sub-tiles, with separate
    # "init" and "update" blocks under a 32x32 outer grid:
    #   "init":   zeroes the 4x4 tile C[vi*4 + ii, vj*4 + jj]
    #   "update": C[vi*4+ii, vj*4+jj] += A[vi*4+ii, vk*4+kk] * B[vj*4+jj, vk*4+kk]
    # NOTE(review): the "update" block declares its domain as
    # [128, 128, reduce_axis(0, 128)], while the enclosing loops have extent
    # 32 and every index is scaled by 4. Confirm whether this wider
    # declaration is intentional.
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])

    for i, j in tir.grid(32, 32):
        with tir.block([32, 32], "init") as [vi, vj]:
            for ii, jj in tir.grid(4, 4):
                C[vi * 4 + ii, vj * 4 + jj] = tir.float32(0)

        for k in range(0, 32):
            with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
                # The inner grid walks one 4x4 output tile and 4 reduction
                # steps.
                for ii, jj, kk in tir.grid(4, 4, 4):
                    C[vi * 4 + ii, vj * 4 + jj] = (
                        C[vi * 4 + ii, vj * 4 + jj]
                        + A[vi * 4 + ii, vk * 4 + kk] * B[vj * 4 + jj, vk * 4 + kk]
                    )
Exemple #15
0
def cuda_matmul_3(a: ty.handle, b: ty.handle, c: ty.handle) -> None:  # pylint: disable=undefined-loop-variable
    # 2048x2048 float32 matrix multiply laid out over thread-bound loops.
    #
    # Data flow, as written below:
    #   A -> A_shared ("shared") -> A_shared_local ("local")
    #   B -> B_shared ("shared") -> B_shared_local ("local")
    #   C_local ("local") accumulates, then is copied to C.
    # The "C" block reads A_shared_local[vk, vi] and B_shared_local[vk, vj],
    # so the reduction index is the first axis of both operands.
    #
    # Loop nest: blockIdx.y/x (32 each), vthread.y/x (2 each),
    # threadIdx.y/x (8 each), then k0 (serial, 256) and k1 (unrolled, 8).
    # Note that the loop variable `ty` has the same name as the `ty`
    # namespace used for the parameter annotations in the signature.
    A = tir.match_buffer(a, [2048, 2048], "float32")
    B = tir.match_buffer(b, [2048, 2048], "float32")
    C = tir.match_buffer(c, [2048, 2048], "float32")
    A_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared")
    B_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared")
    A_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local")
    B_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local")
    C_local = tir.alloc_buffer([2048, 2048], "float32", scope="local")
    # Whole-matrix copies into the "shared" buffers; these two blocks sit
    # outside the thread-bound loop nest.
    with tir.block([2048, 2048], "A_shared") as [v0, v1]:
        A_shared[v0, v1] = A[v0, v1]
    with tir.block([2048, 2048], "B_shared") as [v0, v1]:
        B_shared[v0, v1] = B[v0, v1]
    for by in tir.thread_binding(0, 32, thread="blockIdx.y"):
        for bx in tir.thread_binding(0, 32, thread="blockIdx.x"):
            for vy in tir.thread_binding(0, 2, thread="vthread.y"):
                for vx in tir.thread_binding(0, 2, thread="vthread.x"):
                    for ty in tir.thread_binding(0, 8, thread="threadIdx.y"):
                        for tx in tir.thread_binding(0,
                                                     8,
                                                     thread="threadIdx.x"):
                            for k0 in tir.serial(0, 256):
                                for k1 in tir.unroll(0, 8):
                                    # Stage a 1x4 strip of A_shared for the
                                    # current reduction step into the "local"
                                    # buffer; the column is selected by the
                                    # y-direction indices.
                                    for i, j in tir.grid(1, 4):
                                        with tir.block(
                                            [2048, 2048],
                                                "A_shared_local") as [v0, v1]:
                                            tir.bind(v0, k0 * 8 + k1 + i)
                                            tir.bind(
                                                v1,
                                                by * 64 + vy * 32 + ty * 4 + j)
                                            A_shared_local[v0,
                                                           v1] = A_shared[v0,
                                                                          v1]
                                    # Same staging for B_shared; the column is
                                    # selected by the x-direction indices.
                                    for i, j in tir.grid(1, 4):
                                        with tir.block(
                                            [2048, 2048],
                                                "B_shared_local") as [v0, v1]:
                                            tir.bind(v0, k0 * 8 + k1 + i)
                                            tir.bind(
                                                v1,
                                                bx * 64 + vx * 32 + tx * 4 + j)
                                            B_shared_local[v0,
                                                           v1] = B_shared[v0,
                                                                          v1]
                                    # Accumulate a 4x4 output tile for the
                                    # reduction index k0 * 8 + k1. The
                                    # extent-1 loop variable `_` is unused.
                                    for _, i, j in tir.grid(1, 4, 4):
                                        with tir.block([
                                                2048, 2048,
                                                tir.reduce_axis(0, 2048)
                                        ], "C") as [vi, vj, vk]:
                                            tir.bind(
                                                vi,
                                                by * 64 + vy * 32 + ty * 4 + i)
                                            tir.bind(
                                                vj,
                                                bx * 64 + vx * 32 + tx * 4 + j)
                                            tir.bind(vk, k0 * 8 + k1)
                                            with tir.init():
                                                C_local[vi,
                                                        vj] = tir.float32(0)
                                            C_local[vi, vj] = C_local[
                                                vi, vj] + A_shared_local[
                                                    vk, vi] * B_shared_local[
                                                        vk, vj]
                            # After the k0 loop: copy the finished 4x4 tile
                            # from C_local into the output buffer C.
                            for i, j in tir.grid(4, 4):
                                with tir.block([2048, 2048],
                                               "C_local") as [v0, v1]:
                                    tir.bind(v0,
                                             by * 64 + vy * 32 + ty * 4 + i)
                                    tir.bind(v1,
                                             bx * 64 + vx * 32 + tx * 4 + j)
                                    C[v0, v1] = C_local[v0, v1]
Exemple #16
0
 def main(
     placeholder: ty.handle,
     placeholder_1: ty.handle,
     placeholder_2: ty.handle,
     ethosu_conv2d: ty.handle,
 ) -> None:
     # Binds four buffers and issues one extern call named "ethosu_conv2d".
     #   placeholder_3   : uint8 [1, 8, 8, 3]
     #   placeholder_4   : uint8 [48]
     #   placeholder_5   : int32 [16]
     #   ethosu_conv2d_1 : uint8 [1, 8, 8, 16]
     # Each buffer reaches the call as a tir.load of element 0 from its .data
     # handle. Every load, including the one from the int32 buffer
     # placeholder_5, is written with dtype "uint8".
     # NOTE(review): the meaning of the many positional arguments is fixed by
     # the callee and cannot be determined from this function -- consult the
     # extern's specification before editing any of them.
     # function attr dict
     tir.func_attr({"global_symbol": "main", "tir.noalias": True})
     placeholder_3 = tir.match_buffer(placeholder, [1, 8, 8, 3],
                                      dtype="uint8",
                                      elem_offset=0,
                                      align=128,
                                      offset_factor=1)
     placeholder_4 = tir.match_buffer(placeholder_1, [48],
                                      dtype="uint8",
                                      elem_offset=0,
                                      align=128,
                                      offset_factor=1)
     placeholder_5 = tir.match_buffer(placeholder_2, [16],
                                      dtype="int32",
                                      elem_offset=0,
                                      align=128,
                                      offset_factor=1)
     ethosu_conv2d_1 = tir.match_buffer(ethosu_conv2d, [1, 8, 8, 16],
                                        dtype="uint8",
                                        elem_offset=0,
                                        align=128,
                                        offset_factor=1)
     # body
     # The whole function body is this single evaluated extern call.
     tir.evaluate(
         tir.call_extern(
             "ethosu_conv2d",
             "uint8",
             8,
             8,
             3,
             8,
             0,
             8,
             tir.load("uint8", placeholder_3.data, 0),
             0,
             0,
             0,
             tir.float32(0.5),
             10,
             "NHWC",
             24,
             3,
             1,
             "uint8",
             8,
             8,
             16,
             8,
             0,
             8,
             tir.load("uint8", ethosu_conv2d_1.data, 0),
             0,
             0,
             0,
             tir.float32(0.25),
             14,
             "NHWC",
             128,
             16,
             1,
             1,
             1,
             1,
             1,
             1,
             1,
             tir.load("uint8", placeholder_4.data, 0),
             0,
             12,
             tir.load("uint8", placeholder_5.data, 0),
             0,
             0,
             0,
             0,
             0,
             "CLIP",
             0,
             0,
             "NONE",
             dtype="uint8",
         ))
Exemple #17
0
def tensorcore_gemm(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # match buffer
    A = tir.match_buffer(a, [1024, 1024], "float16")
    B = tir.match_buffer(b, [1024, 1024], "float16")
    C = tir.match_buffer(c, [1024, 1024], "float32")

    # body
    for blockIdx_x in tir.thread_binding(0, 16, "blockIdx.x"):
        for blockIdx_y in tir.thread_binding(0, 8, "blockIdx.y"):
            with tir.block([16, 8]) as [bx, by]:
                tir.bind(bx, blockIdx_x)
                tir.bind(by, blockIdx_y)
                shared_A = tir.alloc_buffer([1024, 1024], "float16", scope="shared")
                shared_B = tir.alloc_buffer([1024, 1024], "float16", scope="shared")
                wmma_A = tir.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_a")
                wmma_B = tir.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_b")
                wmma_C = tir.alloc_buffer([1024, 1024], "float32", scope="wmma.accumulator")
                for ty in tir.thread_binding(0, 2, "threadIdx.y"):
                    for tz in tir.thread_binding(0, 2, "threadIdx.z"):
                        for i, j in tir.grid(2, 4):
                            with tir.block([64, 64]) as [vi, vj]:
                                tir.bind(vi, bx * 4 + ty * 2 + i)
                                tir.bind(vj, by * 8 + tz * 4 + j)
                                tir.reads([])
                                tir.writes(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                C0 = tir.match_buffer(
                                    wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                tir.evaluate(
                                    tir.tvm_fill_fragment(
                                        C0.data,
                                        16,
                                        16,
                                        16,
                                        i * 4 + j,
                                        tir.float32(0),
                                        dtype="handle",
                                    )
                                )

                        for ko in range(0, 32):
                            # copy data from global to shared
                            for tx in tir.thread_binding(0, 32, "threadIdx.x"):
                                for i0, j0 in tir.grid(1, 4):
                                    for j1 in tir.vectorized(0, 4):
                                        with tir.block([1024, 1024]) as [vi, vj]:
                                            tir.bind(vi, bx * 64 + ty * 32 + tx + i0)
                                            tir.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1)
                                            shared_A[vi, vj + 8] = A[vi, vj]

                                for i0, j0 in tir.grid(2, 4):
                                    for j1 in tir.vectorized(0, 4):
                                        with tir.block([1024, 1024]) as [vi, vj]:
                                            tir.bind(vi, by * 128 + ty * 64 + tx * 2 + i0)
                                            tir.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1)
                                            shared_B[vi, vj + 8] = B[vi, vj]

                            for ki in range(0, 2):
                                for i in range(0, 2):
                                    with tir.block([64, 64]) as [vi, vk]:
                                        tir.bind(vi, bx * 4 + ty * 2 + i)
                                        tir.bind(vk, ko * 2 + ki)
                                        tir.reads(
                                            shared_A[
                                                vi * 16 : vi * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ]
                                        )
                                        tir.writes(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16]
                                        )
                                        s0 = tir.var("int32")
                                        s1 = tir.var("int32")
                                        A0 = tir.match_buffer(
                                            shared_A[
                                                vi * 16 : vi * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[s0, s1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_A0 = tir.match_buffer(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        tir.evaluate(
                                            tir.tvm_load_matrix_sync(
                                                wmma_A0.data,
                                                16,
                                                16,
                                                16,
                                                i,
                                                tir.tvm_access_ptr(
                                                    tir.type_annotation(dtype="float16"),
                                                    A0.data,
                                                    A0.elem_offset + 8,
                                                    A0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                A0.strides[0],
                                                "row_major",
                                                dtype="handle",
                                            )
                                        )
                                for j in range(0, 4):
                                    with tir.block([64, 64]) as [vj, vk]:
                                        tir.bind(vj, by * 8 + tz * 4 + j)
                                        tir.bind(vk, ko * 2 + ki)
                                        tir.reads(
                                            shared_B[
                                                vj * 16 : vj * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ]
                                        )
                                        tir.writes(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16]
                                        )
                                        s0 = tir.var("int32")
                                        s1 = tir.var("int32")
                                        B0 = tir.match_buffer(
                                            shared_B[
                                                vj * 16 : vj * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[s0, s1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_B0 = tir.match_buffer(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        tir.evaluate(
                                            tir.tvm_load_matrix_sync(
                                                wmma_B0.data,
                                                16,
                                                16,
                                                16,
                                                j,
                                                tir.tvm_access_ptr(
                                                    tir.type_annotation(dtype="float16"),
                                                    B0.data,
                                                    B0.elem_offset + 8,
                                                    B0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                B0.strides[0],
                                                "col_major",
                                                dtype="handle",
                                            )
                                        )
                                for i, j in tir.grid(2, 4):
                                    with tir.block([64, 64, tir.reduce_axis(0, 64)]) as [
                                        vi,
                                        vj,
                                        vk,
                                    ]:
                                        tir.bind(vi, bx * 4 + ty * 2 + i)
                                        tir.bind(vj, by * 8 + tz * 4 + j)
                                        tir.bind(vk, ko * 2 + ki)
                                        tir.reads(
                                            [
                                                wmma_A[
                                                    vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16
                                                ],
                                                wmma_B[
                                                    vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16
                                                ],
                                                wmma_C[
                                                    vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16
                                                ],
                                            ]
                                        )
                                        tir.writes(
                                            wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]
                                        )
                                        wmma_A1 = tir.match_buffer(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        wmma_B1 = tir.match_buffer(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        wmma_C1 = tir.match_buffer(
                                            wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                            (16, 16),
                                            "float32",
                                            strides=[16 * 4, 1],
                                            scope="wmma.accumulator",
                                            offset_factor=1,
                                        )
                                        tir.evaluate(
                                            tir.tvm_mma_sync(
                                                wmma_C1.data,
                                                i * 4 + j,
                                                wmma_A1.data,
                                                i,
                                                wmma_B1.data,
                                                j,
                                                wmma_C1.data,
                                                i * 4 + j,
                                                dtype="handle",
                                            )
                                        )
                        for i, j in tir.grid(2, 4):
                            with tir.block([64, 64]) as [vi, vj]:
                                tir.bind(vi, bx * 4 + ty * 2 + i)
                                tir.bind(vj, by * 8 + tz * 4 + j)
                                tir.reads(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                tir.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                s0 = tir.var("int32")
                                s1 = tir.var("int32")
                                wmma_C2 = tir.match_buffer(
                                    wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                C1 = tir.match_buffer(
                                    C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[s0, s1],
                                    offset_factor=1,
                                )
                                tir.evaluate(
                                    tir.tvm_store_matrix_sync(
                                        wmma_C2.data,
                                        16,
                                        16,
                                        16,
                                        i * 4 + j,
                                        tir.tvm_access_ptr(
                                            tir.type_annotation(dtype="float32"),
                                            C1.data,
                                            C1.elem_offset,
                                            C1.strides[0],
                                            1,
                                            dtype="handle",
                                        ),
                                        C1.strides[0],
                                        "row_major",
                                        dtype="handle",
                                    )
                                )
 def main(placeholder: ty.handle, placeholder_1: ty.handle,
          placeholder_2: ty.handle, ethosu_write: ty.handle,
          placeholder_3: ty.handle, placeholder_4: ty.handle,
          placeholder_5: ty.handle, placeholder_6: ty.handle,
          placeholder_7: ty.handle, placeholder_8: ty.handle,
          placeholder_9: ty.handle, placeholder_10: ty.handle) -> None:
     # Overview: binds the twelve handle parameters to buffers, allocates
     # three local buffers, then issues a fixed sequence of extern calls:
     # one "ethosu_conv2d", followed by four rounds of
     # (two "ethosu_copy" calls + one "ethosu_conv2d").
     # NOTE(review): the extern calls take long positional argument lists
     # whose meaning is not visible here; any role described below (input,
     # output, weights, ...) is an assumption to confirm against the callee.
     # function attr dict
     tir.func_attr({
         "from_legacy_te_schedule": True,
         "global_symbol": "main",
         "tir.noalias": True
     })
     # Buffer bindings. The buffer_N numbering does NOT follow the
     # placeholder_N numbering, so the mapping is spelled out per binding.
     # buffer <- placeholder_7 (80 x uint8); copied in round 3 below.
     buffer = tir.match_buffer(placeholder_7, [80],
                               dtype="uint8",
                               elem_offset=0,
                               align=128,
                               offset_factor=1)
     # buffer_1 <- placeholder_5 (80 x uint8); copied in round 2 below.
     buffer_1 = tir.match_buffer(placeholder_5, [80],
                                 dtype="uint8",
                                 elem_offset=0,
                                 align=128,
                                 offset_factor=1)
     # buffer_2 <- placeholder_3 (80 x uint8); copied in round 1 below.
     buffer_2 = tir.match_buffer(placeholder_3, [80],
                                 dtype="uint8",
                                 elem_offset=0,
                                 align=128,
                                 offset_factor=1)
     # buffer_3 <- placeholder_4 (32 x uint8); copied in round 1 below.
     buffer_3 = tir.match_buffer(placeholder_4, [32],
                                 dtype="uint8",
                                 elem_offset=0,
                                 align=128,
                                 offset_factor=1)
     # buffer_4 <- placeholder_9 (80 x uint8); copied in round 4 below.
     buffer_4 = tir.match_buffer(placeholder_9, [80],
                                 dtype="uint8",
                                 elem_offset=0,
                                 align=128,
                                 offset_factor=1)
     # buffer_5 <- placeholder_6 (32 x uint8); copied in round 2 below.
     buffer_5 = tir.match_buffer(placeholder_6, [32],
                                 dtype="uint8",
                                 elem_offset=0,
                                 align=128,
                                 offset_factor=1)
     # placeholder_11 <- placeholder ([1, 16, 16, 32] int8); referenced only
     # by the first conv2d call.
     placeholder_11 = tir.match_buffer(placeholder, [1, 16, 16, 32],
                                       dtype="int8",
                                       elem_offset=0,
                                       align=128,
                                       offset_factor=1)
     # buffer_6 <- placeholder_1 (592 x uint8); referenced only by the first
     # conv2d call.
     buffer_6 = tir.match_buffer(placeholder_1, [592],
                                 dtype="uint8",
                                 elem_offset=0,
                                 align=128,
                                 offset_factor=1)
     # ethosu_write_1 <- ethosu_write ([1, 16, 16, 8] int8); referenced by the
     # four later conv2d calls at element indices 0, 2, 4 and 6.
     ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 16, 16, 8],
                                       dtype="int8",
                                       elem_offset=0,
                                       align=128,
                                       offset_factor=1)
     # buffer_7 <- placeholder_2 (160 x uint8); referenced only by the first
     # conv2d call.
     buffer_7 = tir.match_buffer(placeholder_2, [160],
                                 dtype="uint8",
                                 elem_offset=0,
                                 align=128,
                                 offset_factor=1)
     # buffer_8 <- placeholder_8 (32 x uint8); copied in round 3 below.
     buffer_8 = tir.match_buffer(placeholder_8, [32],
                                 dtype="uint8",
                                 elem_offset=0,
                                 align=128,
                                 offset_factor=1)
     # buffer_9 <- placeholder_10 (32 x uint8); copied in round 4 below.
     buffer_9 = tir.match_buffer(placeholder_10, [32],
                                 dtype="uint8",
                                 elem_offset=0,
                                 align=128,
                                 offset_factor=1)
     # body
     # Local allocations with "global" scope. The 80- and 32-element uint8
     # buffers have the same sizes as the [80] / [32] parameter buffers that
     # the "ethosu_copy" calls pair them with, and are reused in every round.
     # NOTE(review): ethosu_write_2 (4096 x int8) presumably holds the
     # intermediate result passed from the first conv2d to the later ones —
     # confirm.
     ethosu_write_2 = tir.allocate([4096], "int8", "global")
     placeholder_global = tir.allocate([80], "uint8", "global")
     placeholder_d_global = tir.allocate([32], "uint8", "global")
     # First conv2d call: references placeholder_11, ethosu_write_2,
     # buffer_6 (with the literal 592) and buffer_7 (with the literal 160).
     # Unlike the later conv2d calls it takes buffer_6 / buffer_7 directly
     # instead of the locally allocated placeholder_global /
     # placeholder_d_global.
     tir.evaluate(
         tir.call_extern("ethosu_conv2d",
                         "int8",
                         16,
                         16,
                         32,
                         16,
                         0,
                         16,
                         tir.load("int8", placeholder_11.data, 0),
                         0,
                         0,
                         0,
                         tir.float32(0.5),
                         10,
                         "NHWC",
                         512,
                         32,
                         1,
                         "int8",
                         16,
                         16,
                         16,
                         16,
                         0,
                         16,
                         tir.load("int8", ethosu_write_2, 0),
                         0,
                         0,
                         0,
                         tir.float32(0.25),
                         14,
                         "NHWC",
                         256,
                         16,
                         1,
                         1,
                         1,
                         1,
                         1,
                         1,
                         1,
                         tir.load("uint8", buffer_6.data, 0),
                         592,
                         12,
                         tir.load("uint8", buffer_7.data, 0),
                         160,
                         0,
                         0,
                         0,
                         0,
                         "NONE",
                         0,
                         0,
                         "NONE",
                         dtype="handle"))
     # Round 1: "ethosu_copy" calls pairing buffer_2 with placeholder_global
     # (length 80) and buffer_3 with placeholder_d_global (length 32).
     # NOTE(review): argument order looks like (source, length, destination)
     # — confirm against the extern's definition.
     tir.evaluate(
         tir.call_extern("ethosu_copy",
                         tir.load("uint8", buffer_2.data, 0),
                         80,
                         tir.load("uint8", placeholder_global, 0),
                         dtype="handle"))
     tir.evaluate(
         tir.call_extern("ethosu_copy",
                         tir.load("uint8", buffer_3.data, 0),
                         32,
                         tir.load("uint8", placeholder_d_global, 0),
                         dtype="handle"))
     # Round 1 conv2d: references ethosu_write_2, ethosu_write_1.data at
     # index 0, placeholder_global (80) and placeholder_d_global (32).
     # The conv2d calls of rounds 1-4 are identical except for the
     # ethosu_write_1.data index (0, 2, 4, 6).
     tir.evaluate(
         tir.call_extern("ethosu_conv2d",
                         "int8",
                         16,
                         16,
                         16,
                         16,
                         0,
                         16,
                         tir.load("int8", ethosu_write_2, 0),
                         0,
                         0,
                         0,
                         tir.float32(0.5),
                         10,
                         "NHWC",
                         256,
                         16,
                         1,
                         "int8",
                         16,
                         16,
                         2,
                         16,
                         0,
                         16,
                         tir.load("int8", ethosu_write_1.data, 0),
                         0,
                         0,
                         0,
                         tir.float32(0.25),
                         14,
                         "NHWC",
                         128,
                         8,
                         1,
                         1,
                         1,
                         1,
                         1,
                         1,
                         1,
                         tir.load("uint8", placeholder_global, 0),
                         80,
                         12,
                         tir.load("uint8", placeholder_d_global, 0),
                         32,
                         0,
                         0,
                         0,
                         0,
                         "NONE",
                         0,
                         0,
                         "NONE",
                         dtype="handle"))
     # Round 2: same copy pattern, now from buffer_1 (80) and buffer_5 (32)
     # into the same two local buffers.
     tir.evaluate(
         tir.call_extern("ethosu_copy",
                         tir.load("uint8", buffer_1.data, 0),
                         80,
                         tir.load("uint8", placeholder_global, 0),
                         dtype="handle"))
     tir.evaluate(
         tir.call_extern("ethosu_copy",
                         tir.load("uint8", buffer_5.data, 0),
                         32,
                         tir.load("uint8", placeholder_d_global, 0),
                         dtype="handle"))
     # Round 2 conv2d: as round 1, but ethosu_write_1.data index is 2.
     tir.evaluate(
         tir.call_extern("ethosu_conv2d",
                         "int8",
                         16,
                         16,
                         16,
                         16,
                         0,
                         16,
                         tir.load("int8", ethosu_write_2, 0),
                         0,
                         0,
                         0,
                         tir.float32(0.5),
                         10,
                         "NHWC",
                         256,
                         16,
                         1,
                         "int8",
                         16,
                         16,
                         2,
                         16,
                         0,
                         16,
                         tir.load("int8", ethosu_write_1.data, 2),
                         0,
                         0,
                         0,
                         tir.float32(0.25),
                         14,
                         "NHWC",
                         128,
                         8,
                         1,
                         1,
                         1,
                         1,
                         1,
                         1,
                         1,
                         tir.load("uint8", placeholder_global, 0),
                         80,
                         12,
                         tir.load("uint8", placeholder_d_global, 0),
                         32,
                         0,
                         0,
                         0,
                         0,
                         "NONE",
                         0,
                         0,
                         "NONE",
                         dtype="handle"))
     # Round 3: same copy pattern, now from buffer (80) and buffer_8 (32).
     tir.evaluate(
         tir.call_extern("ethosu_copy",
                         tir.load("uint8", buffer.data, 0),
                         80,
                         tir.load("uint8", placeholder_global, 0),
                         dtype="handle"))
     tir.evaluate(
         tir.call_extern("ethosu_copy",
                         tir.load("uint8", buffer_8.data, 0),
                         32,
                         tir.load("uint8", placeholder_d_global, 0),
                         dtype="handle"))
     # Round 3 conv2d: as round 1, but ethosu_write_1.data index is 4.
     tir.evaluate(
         tir.call_extern("ethosu_conv2d",
                         "int8",
                         16,
                         16,
                         16,
                         16,
                         0,
                         16,
                         tir.load("int8", ethosu_write_2, 0),
                         0,
                         0,
                         0,
                         tir.float32(0.5),
                         10,
                         "NHWC",
                         256,
                         16,
                         1,
                         "int8",
                         16,
                         16,
                         2,
                         16,
                         0,
                         16,
                         tir.load("int8", ethosu_write_1.data, 4),
                         0,
                         0,
                         0,
                         tir.float32(0.25),
                         14,
                         "NHWC",
                         128,
                         8,
                         1,
                         1,
                         1,
                         1,
                         1,
                         1,
                         1,
                         tir.load("uint8", placeholder_global, 0),
                         80,
                         12,
                         tir.load("uint8", placeholder_d_global, 0),
                         32,
                         0,
                         0,
                         0,
                         0,
                         "NONE",
                         0,
                         0,
                         "NONE",
                         dtype="handle"))
     # Round 4: same copy pattern, now from buffer_4 (80) and buffer_9 (32).
     tir.evaluate(
         tir.call_extern("ethosu_copy",
                         tir.load("uint8", buffer_4.data, 0),
                         80,
                         tir.load("uint8", placeholder_global, 0),
                         dtype="handle"))
     tir.evaluate(
         tir.call_extern("ethosu_copy",
                         tir.load("uint8", buffer_9.data, 0),
                         32,
                         tir.load("uint8", placeholder_d_global, 0),
                         dtype="handle"))
     # Round 4 conv2d: as round 1, but ethosu_write_1.data index is 6.
     # NOTE(review): indices 0/2/4/6 together with the literal 2 in each of
     # these calls suggest every round fills 2 of the 8 innermost elements of
     # ethosu_write_1 — confirm.
     tir.evaluate(
         tir.call_extern("ethosu_conv2d",
                         "int8",
                         16,
                         16,
                         16,
                         16,
                         0,
                         16,
                         tir.load("int8", ethosu_write_2, 0),
                         0,
                         0,
                         0,
                         tir.float32(0.5),
                         10,
                         "NHWC",
                         256,
                         16,
                         1,
                         "int8",
                         16,
                         16,
                         2,
                         16,
                         0,
                         16,
                         tir.load("int8", ethosu_write_1.data, 6),
                         0,
                         0,
                         0,
                         tir.float32(0.25),
                         14,
                         "NHWC",
                         128,
                         8,
                         1,
                         1,
                         1,
                         1,
                         1,
                         1,
                         1,
                         tir.load("uint8", placeholder_global, 0),
                         80,
                         12,
                         tir.load("uint8", placeholder_d_global, 0),
                         32,
                         0,
                         0,
                         0,
                         0,
                         "NONE",
                         0,
                         0,
                         "NONE",
                         dtype="handle"))