def gemm() -> None:
    """Blocked 16x16 GEMM in TVMScript: C[i, j] = sum_k A[i, k] * B[j, k].

    The spatial loops are pre-split into 4x4 outer tiles (i, j) and 4x4
    inner tiles (ii, jj); k is the un-split reduction axis.
    """
    A = T.alloc_buffer([16, 16], "float32")
    B = T.alloc_buffer([16, 16], "float32")
    C = T.alloc_buffer([16, 16], "float32")
    for i, j, k, ii, jj in T.grid(4, 4, 16, 4, 4):
        with T.block("update"):
            # Spatial block iters recombine the split loops; vk is the
            # reduction iter.
            vi = T.axis.S(16, i * 4 + ii)
            vj = T.axis.S(16, j * 4 + jj)
            vk = T.axis.R(16, k)
            T.reads(A[vi, vk], B[vj, vk])
            T.writes(C[vi, vj])
            with T.init():
                T.reads([])
                T.writes(C[vi, vj])
                C[vi, vj] = 0
            # B is indexed [vj, vk], i.e. accessed as if transposed.
            C[vi, vj] += A[vi, vk] * B[vj, vk]
# Ejemplo n.º 2 — scraped example separator (score: 0)
 def main(a: T.handle, b: T.handle, c: T.handle) -> None:
     """1024^3 matmul whose block carries a custom meta-schedule rule attr."""
     T.func_attr({"global_symbol": "main"})
     A = T.match_buffer(a, (1024, 1024), "float32")
     B = T.match_buffer(b, (1024, 1024), "float32")
     C = T.match_buffer(c, (1024, 1024), "float32")
     with T.block("root"):
         for i, j, k in T.grid(1024, 1024, 1024):
             with T.block("matmul"):
                 # Points meta-schedule at the search space registered
                 # under this name instead of the default rules.
                 T.block_attr({
                     "schedule_rule":
                     "tvm.meta_schedule.test.custom_search_space"
                 })
                 vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                 with T.init():
                     C[vi, vj] = 0.0
                 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
# Ejemplo n.º 3 — scraped example separator (score: 0)
def outer_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    """Tensor-intrinsic description of an outer-product accumulation.

    A and B are 16x1 column vectors; each cell of the 16x16 accumulator
    C is updated with A[i, 0] * B[j, 0].
    """
    A = T.match_buffer(a, (16, 1), offset_factor=1)
    B = T.match_buffer(b, (16, 1), offset_factor=1)
    C = T.match_buffer(c, (16, 16), offset_factor=1)

    with T.block("root"):
        # The root block declares the full access region of the intrinsic.
        T.reads(
            C[0:16, 0:16],
            A[0:16, 0:1],
            B[0:16, 0:1],
        )
        T.writes(C[0:16, 0:16])
        for row, col in T.grid(16, 16):
            with T.block("update"):
                vi, vj = T.axis.remap("SS", [row, col])
                C[vi, vj] = C[vi, vj] + A[vi, 0] * B[vj, 0]
# Ejemplo n.º 4 — scraped example separator (score: 0)
 def func(a: T.handle, b: T.handle, c: T.handle) -> None:
     """Elementwise chain B = A, then C = B, under one anonymous outer block."""
     A = T.match_buffer(a, [64, 32], dtype="float32")
     B = T.match_buffer(b, [64, 32], dtype="float32")
     C = T.match_buffer(c, [64, 32], dtype="float32")
     for i, j in T.grid(64, 32):  # type: ignore
         # Anonymous opaque block wrapping the two named child blocks.
         with T.block():
             T.reads([A[i, j], B[i, j]])  # type: ignore
             T.writes([B[i, j], C[i, j]])  # type: ignore
             with T.block("B"):
                 T.reads([A[i, j]])  # type: ignore
                 T.writes([B[i, j]])  # type: ignore
                 B[i, j] = A[i, j]  # type: ignore
             with T.block("C"):
                 T.reads([B[i, j]])  # type: ignore
                 T.writes([C[i, j]])  # type: ignore
                 C[i, j] = B[i, j]  # type: ignore
# Ejemplo n.º 5 — scraped example separator (score: 0)
def transformed_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
    """128x128 matmul with the reduction axis split 4 x 8 x 4 (= 128).

    B is accessed as B[vj, vk], i.e. transposed relative to a plain
    C = A @ B.
    """
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])

    for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(
            128, 128, 4, 8, 4):
        with T.block("update"):
            vi, vj = T.axis.remap("SS", [i0, i1])
            # Recombine the three-way split reduction loop: 4*32 + 8*4 + 4.
            vk = T.axis.R(128,
                          i2_outer * 32 + i2_inner_outer * 4 + i2_inner_inner)
            T.reads([C[vi, vj], A[vi, vk], B[vj, vk]])
            T.writes([C[vi, vj]])
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])
def tir_multi_output(a0: T.handle, a1: T.handle, b0: T.handle,
                     b1: T.handle) -> None:
    """Two independent elementwise outputs sharing one (m, n) loop nest.

    B0 = A0 + 2 and B1 = A1 * 3; m and n are symbolic shape vars.
    """
    m = T.var("int32")
    n = T.var("int32")
    A0 = T.match_buffer(a0, (m, n))
    A1 = T.match_buffer(a1, (m, n))
    B0 = T.match_buffer(b0, (m, n))
    B1 = T.match_buffer(b1, (m, n))

    for i0, i1 in T.grid(m, n):
        with T.block("B.v0"):
            i, j = T.axis.remap("SS", [i0, i1])
            B0[i, j] = A0[i, j] + 2.0
        with T.block("B.v1"):
            i, j = T.axis.remap("SS", [i0, i1])
            B1[i, j] = A1[i, j] * 3.0
# Ejemplo n.º 7 — scraped example separator (score: 0)
def tir_multi_output(a0: T.handle, a1: T.handle, b0: T.handle, b1: T.handle) -> None:
    """Variant of tir_multi_output with func attrs (global_symbol, noalias).

    B0 = A0 + 2 and B1 = A1 * 3 over symbolic shape (m, n).
    """
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    m = T.var("int32")
    n = T.var("int32")
    A0 = T.match_buffer(a0, (m, n))
    A1 = T.match_buffer(a1, (m, n))
    B0 = T.match_buffer(b0, (m, n))
    B1 = T.match_buffer(b1, (m, n))

    for i0, i1 in T.grid(m, n):
        with T.block("B.v0"):
            i, j = T.axis.remap("SS", [i0, i1])
            B0[i, j] = A0[i, j] + 2.0
        with T.block("B.v1"):
            i, j = T.axis.remap("SS", [i0, i1])
            B1[i, j] = A1[i, j] * 3.0
 def main(a: T.handle, b: T.handle) -> None:
     """Tiled 1024^3 copy B = A with a T.where guard on the block.

     The guard is redundant here (the recombined indices already stay in
     range) but exercises predicate handling; the root block carries
     meta-schedule parallel/vectorize hints.
     """
     # function attr dict
     T.func_attr({"global_symbol": "main"})
     A = T.match_buffer(a, [1024, 1024, 1024], dtype="float32")
     B = T.match_buffer(b, [1024, 1024, 1024], dtype="float32")
     # body
     with T.block("root"):
         T.block_attr({"meta_schedule.parallel":128, "meta_schedule.vectorize":32})
         for i0, j0, i1, j1, k0, i2, j2, k1 in T.grid(128, 64, 4, 4, 64, 4, 8, 32):
             with T.block("move"):
                 # Recombine the split loops: i split 128*4*4, j split
                 # 64*4*8, k split 64*32.
                 vi = T.axis.spatial(1024, i0 * 16 + i1 * 4 + i2)
                 vj = T.axis.spatial(1024, j0 * 32 + j1 * 8 + j2)
                 vk = T.axis.spatial(1024, k0 * 32 + k1)
                 T.where((i0 * 4 + i1) * 4 + i2 < 1024 and (j0 * 4 + j1) * 8 + j2 < 1024 and k0 * 32 + k1 < 1024)
                 T.reads([A[vi, vj, vk]])
                 T.writes([B[vi, vj, vk]])
                 B[vi, vj, vk] = A[vi, vj, vk]
def matmul_relu_ann1(a: T.handle, b: T.handle, d: T.handle) -> None:
    """Matmul followed by ReLU: D = max(A @ B, 0), with loop annotations.

    The i and j loops carry test annotations ("test1"/"test2"); C is an
    intermediate accumulator buffer.
    """
    A = T.match_buffer(a, (1024, 1024))
    B = T.match_buffer(b, (1024, 1024))
    C = T.alloc_buffer((1024, 1024))
    D = T.match_buffer(d, (1024, 1024))
    for i in T.serial(0, 1024, annotations={"test1": "aaa"}):
        for j in T.serial(0, 1024, annotations={"test2": 612}):
            for k in T.serial(0, 1024):
                with T.block("matmul"):
                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                    with T.init():
                        C[vi, vj] = 0.0
                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
    for i, j in T.grid(1024, 1024):
        with T.block("relu"):
            vi, vj = T.axis.remap("SS", [i, j])
            D[vi, vj] = T.max(C[vi, vj], 0.0)
# Ejemplo n.º 10 — scraped example separator (score: 0)
def transformed_high_dim_opaque_access(a: T.handle) -> None:
    """Opaque intrinsic writes over 16x16 tiles of a (16, 32, 64) buffer.

    The block declares writes on tile regions but performs the access
    via raw A.data inside T.intrin_test, so it is opaque to the analyzer.
    The flat offset i * 2048 + j * 1024 + k * 16 matches the contiguous
    (32 * 64 = 2048, 64) layout of A.
    """
    A = T.match_buffer(a, (16, 32, 64))
    for i, j, k in T.grid(16, 2, 4):
        with T.block([]):
            T.reads([])
            T.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16])
            T.evaluate(
                T.intrin_test(
                    A.data,
                    i * 2048 + j * 1024 + k * 16,
                    64,
                    1,
                    16,
                    16,
                    dtype="handle",
                )
            )
# Ejemplo n.º 11 — scraped example separator (score: 0)
 def main(A: T.Buffer[(1, 256, 256), "float32"],
          D: T.Buffer[(1, ), "float32"]) -> None:
     """Square-sum reduction then square root: D = sqrt(sum(A * A)).

     Both stages are bound to trivially sized (1-thread) GPU launch axes.
     """
     C = T.alloc_buffer([1], dtype="float32")
     for i0_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
         for i0_fused_1 in T.thread_binding(1, thread="threadIdx.x"):
             for i1, i2 in T.grid(256, 256):
                 with T.block("C"):
                     b = T.axis.S(1, 0)
                     # Both i1 and i2 are reduction axes of the sum.
                     i, j = T.axis.remap("RR", [i1, i2])
                     with T.init():
                         C[b] = T.float32(0)
                     C[b] = C[b] + A[b, i, j] * A[b, i, j]
     for i0_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
         for i0_fused_1 in T.thread_binding(1, thread="threadIdx.x"):
             with T.block("D"):
                 b = T.axis.S(1, 0)
                 D[b] = T.sqrt(C[b], dtype="float32")
 def main(  # type: ignore
         placeholder: T.Buffer[(1, 3, 16, 16), "float32"],  # type: ignore
         T_layout_trans: T.Buffer[(1, 1, 16, 16, 3),
                                  "float32"],  # type: ignore
 ) -> None:  # type: ignore
     """Layout transform NCHW -> NCHWc-style (1, 1, 16, 16, 3).

     The channel axis (3) moves to the innermost position ax4; the
     declared read region is the simplified form of the write index.
     """
     # function attr dict
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     # body
     # with T.block("root")
     for i0, i1, i2, i3, i4 in T.grid(1, 1, 16, 16, 3):
         with T.block("T_layout_trans"):
             ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS",
                                                    [i0, i1, i2, i3, i4])
             T.reads(placeholder[0, ax4, ax2, ax3])
             T.writes(T_layout_trans[ax0, ax1, ax2, ax3, ax4])
             T_layout_trans[ax0, ax1, ax2, ax3, ax4] = placeholder[0, ax4,
                                                                   ax2, ax3]
# Ejemplo n.º 13 — scraped example separator (score: 0)
def transformed_square_sum_square_root_factor_one_2(a: T.handle, d: T.handle) -> None:
    """Batched square-sum + sqrt with the two reduction loops fused.

    D[b] = sqrt(sum_{i,j} A[b, i, j]^2); the fused 256*256 = 65536 loop
    is decomposed back into (i, j) via floordiv/floormod.
    """
    A = T.match_buffer(a, [16, 256, 256])
    D = T.match_buffer(d, [16])
    C = T.alloc_buffer([16])

    for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 1, 65536):
        with T.block("C"):
            b = T.axis.S(16, i0)
            i = T.axis.R(256, T.floordiv(i1_i2_fused_inner, 256))
            j = T.axis.R(256, T.floormod(i1_i2_fused_inner, 256))
            with T.init():
                C[b] = 0.0
            C[b] = C[b] + (A[b, i, j] * A[b, i, j])
    for i0_1 in T.serial(0, 16):
        with T.block("D"):
            b_1 = T.axis.S(16, i0_1)
            D[b_1] = T.sqrt(C[b_1], dtype="float32")
# Ejemplo n.º 14 — scraped example separator (score: 0)
def transformed_high_dim_opaque_access_with_source_strides(
        a: T.handle) -> None:
    """Opaque tile writes on a strided (non-contiguous) source buffer.

    A uses explicit strides [2576, 80, 1], and the flat offset passed to
    T.intrin_test mirrors them: i * 2576 + j * (16 * 80) + k * 16.
    """
    A = T.match_buffer(a, (16, 32, 64), strides=[2576, 80, 1])
    for i, j, k in T.grid(16, 2, 4):
        with T.block():
            T.reads([])
            T.writes(A[i, j * 16:j * 16 + 16, k * 16:k * 16 + 16])
            T.evaluate(
                T.intrin_test(
                    A.data,
                    i * 2576 + j * 1280 + k * 16,
                    80,
                    1,
                    16,
                    16,
                    dtype="handle",
                ))
# Ejemplo n.º 15 — scraped example separator (score: 0)
def matmul_loop_multiple_children(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None:
    """Two matmul blocks sharing one loop nest (the k loop is outermost).

    C = A @ B^T-style accumulation with C[ci, cj] += A[ci, ck] * B[ck, cj],
    and D accumulates B[di, dk] * A[dk, dj]; both are children of the same
    (k, i, j) loops, which is the fixture's point.
    """
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    D = T.match_buffer(d, [128, 128])

    for k, i, j in T.grid(128, 128, 128):
        with T.block("C"):
            ck, ci, cj = T.axis.remap("RSS", [k, i, j])
            with T.init():
                C[ci, cj] = 0.0
            C[ci, cj] = C[ci, cj] + A[ci, ck] * B[ck, cj]
        with T.block("D"):
            dk, di, dj = T.axis.remap("RSS", [k, i, j])
            with T.init():
                D[di, dj] = 0.0
            D[di, dj] = D[di, dj] + B[di, dk] * A[dk, dj]
# Ejemplo n.º 16 — scraped example separator (score: 0)
def different_access_indices(a: T.handle, b: T.handle) -> None:
    """Reduction fixture whose init and update use different B indices.

    The init writes B[vj, vi] while the update writes B[vi, vj] — this is
    the deliberate point of the fixture, and the declared write region is
    the covering box [min(vi,vj):max(vi,vj)+1] in both dimensions.
    """
    A = T.match_buffer(a, [128, 128, 128], dtype="float32")
    B = T.match_buffer(b, [128, 128], dtype="float32")
    for i, j in T.grid(128, 128):
        # The reduction loop is bound to GPU threads.
        for k in T.thread_binding(0, 128, thread="threadIdx.x"):
            with T.block("B"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                T.reads([B[vi, vj], A[vi, vj, vk]])
                T.writes([
                    B[T.min(vj, vi):T.min(vj, vi) +
                      (T.max(vj, vi) + 1 - T.min(vj, vi)),
                      T.min(vi, vj):T.min(vi, vj) +
                      (T.max(vi, vj) + 1 - T.min(vi, vj)), ]
                ])
                with T.init():
                    B[vj, vi] = T.float32(0)
                B[vi, vj] = B[vi, vj] + A[vi, vj, vk]
# Ejemplo n.º 17 — scraped example separator (score: 0)
def elementwise_non_single_branch(a: T.handle, b: T.handle) -> None:
    """Two-stage elementwise chain B = (A * 2) * 2 using legacy block syntax.

    The shared (i, j) loops have two k-loop children (non-single-branch),
    with C as the intermediate buffer. Uses the old
    ``T.block([shape], name) as [...]`` + ``T.bind`` TVMScript form.
    """
    A = T.match_buffer(a, (128, 128, 128))
    C = T.alloc_buffer((128, 128, 128))
    B = T.match_buffer(b, (128, 128, 128))
    for i, j in T.grid(128, 128):
        for k in T.serial(0, 128):
            with T.block([128, 128, 128], "C") as [vi, vj, vk]:
                T.bind(vi, i)
                T.bind(vj, j)
                T.bind(vk, k)
                C[vi, vj, vk] = A[vi, vj, vk] * 2.0
        for k in T.serial(0, 128):
            with T.block([128, 128, 128], "B") as [vi, vj, vk]:
                T.bind(vi, i)
                T.bind(vj, j)
                T.bind(vk, k)
                B[vi, vj, vk] = C[vi, vj, vk] * 2.0
# Ejemplo n.º 18 — scraped example separator (score: 0)
 def main(placeholder: T.Buffer[(12, 64, 64), "float32"],
          T_reshape: T.Buffer[(64, 768), "float32"]) -> None:
     """Reshape (12, 64, 64) -> (64, 768) with unsimplified index math.

     NOTE(review): the fused loop extent 1536000 * 32 far exceeds the
     64 * 768 output elements, so the same cells are rewritten many
     times — this appears to be a deliberate fixture for index
     simplification; verify against the original test.
     """
     for i0_i1_fused_0, i0_i1_fused_1 in T.grid(1536000, 32):
         with T.block("T_reshape_1"):
             ax0 = T.axis.spatial(
                 64, (i0_i1_fused_0 * 32 + i0_i1_fused_1) // 768)
             ax1 = T.axis.spatial(
                 768, (i0_i1_fused_0 * 32 + i0_i1_fused_1) % 768)
             T.reads(placeholder[ax1 % 768 // 64, (ax1 // 768 + ax0) % 64,
                                 ax1 % 64])
             T.writes(T_reshape[ax0, ax1])
             # The RHS index expressions simplify to the declared read
             # region above.
             T_reshape[ax0, ax1] = placeholder[(
                 (ax1 % 64 // 64 +
                  (ax1 // 768 + ax0) % 64) // 64 + ax1 % 768 // 64) % 12,
                                               (ax1 % 64 // 64 +
                                                (ax1 // 768 + ax0) % 64) %
                                               64, ax1 % 64 % 64, ]
def matmul(
    A: T.Buffer[(512, 512), "float32"],
    B: T.Buffer[(512, 512), "float32"],
    C: T.Buffer[(512, 512), "float32"],
) -> None:
    """Plain 512x512x512 matrix multiply: C = A @ B."""
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    for ax0, ax1, ax2 in T.grid(512, 512, 512):
        with T.block("C"):
            vi, vj, vk = T.axis.remap("SSR", [ax0, ax1, ax2])
            T.reads(C[vi, vj], A[vi, vk], B[vk, vj])
            T.writes(C[vi, vj])
            with T.init():
                # Zero the accumulator on the first reduction step.
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
def access_of_padding_pattern() -> None:
    """Padding round-trip fixture: pad 28x28 into 32x32 and back.

    "padding" builds X_pad with a 2-cell zero border; "padding_reverse"
    copies the interior back into Y. The declared read regions use the
    unclamped index X[vi - 2, vj - 2] on purpose (the pattern under test).
    """
    X = T.alloc_buffer([28, 28])
    X_pad = T.alloc_buffer([32, 32])
    Y = T.alloc_buffer([28, 28])
    for i, j in T.grid(32, 32):
        with T.block("padding"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads([X[vi - 2, vj - 2]])
            T.writes([X_pad[vi, vj]])
            X_pad[vi, vj] = T.if_then_else(
                2 <= vi and vi < 30 and 2 <= vj and vj < 30, X[vi - 2, vj - 2], 0.0, dtype="float32"
            )
        with T.block("padding_reverse"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads([X_pad[vi, vj]])
            T.writes([Y[vi - 2, vj - 2]])
            # Only interior cells are copied back; the border is skipped.
            if 2 <= vi and vi < 30 and 2 <= vj and vj < 30:
                Y[vi - 2, vj - 2] = X_pad[vi, vj]
# Ejemplo n.º 21 — scraped example separator (score: 0)
def thread_bound_nested_block_after_cache_read(
    A: T.Buffer[(16, 16), "float32"], B: T.Buffer[(16,), "float32"]
) -> None:
    """Row-sum reduction B[i] = sum_j A[i, j] with a shared-memory cache.

    Each blockIdx.x owns one row: an "A_shared" stage copies the row into
    a shared-scope buffer, then threadIdx.x-bound "inner" blocks reduce it.
    """
    for i in T.thread_binding(16, thread="blockIdx.x"):
        with T.block("outer"):
            vi = T.axis.spatial(16, i)
            A_shared = T.alloc_buffer([1, 16], dtype="float32", scope="shared")
            for ax0, ax1 in T.grid(1, 16):
                with T.block("A_shared"):
                    v0 = T.axis.spatial(16, vi + ax0)
                    v1 = T.axis.spatial(16, ax1)
                    A_shared[v0, v1] = A[v0, v1]
            for j in T.thread_binding(16, thread="threadIdx.x"):
                with T.block("inner"):
                    vj = T.axis.reduce(16, j)
                    with T.init():
                        B[vi] = T.float32(0)
                    B[vi] = B[vi] + A_shared[vi, vj]
# Ejemplo n.º 22 — scraped example separator (score: 0)
 def wmma_store_desc(a: T.handle, c: T.handle) -> None:
     """Tensor-intrinsic description: copy a wmma.accumulator tile out.

     m_dim, n_dim, dtype and scope are free variables captured from the
     enclosing scope (this is a template instantiated per config).
     """
     A = T.match_buffer(a, (m_dim, n_dim),
                        dtype,
                        align=128,
                        offset_factor=16,
                        scope="wmma.accumulator")
     C = T.match_buffer(c, (m_dim, n_dim),
                        dtype,
                        align=128,
                        offset_factor=16,
                        scope=scope)
     with T.block("root"):
         T.reads(A[0:m_dim, 0:n_dim])
         T.writes(C[0:m_dim, 0:n_dim])
         for i, j in T.grid(m_dim, n_dim):
             with T.block("store"):
                 vii, vjj = T.axis.remap("SS", [i, j])
                 C[vii, vjj] = A[vii, vjj]
# Ejemplo n.º 23 — scraped example separator (score: 0)
 def wmma_load_desc(a: T.handle, c: T.handle) -> None:
     """Tensor-intrinsic description: load a shared-memory tile into a
     wmma fragment.

     m_dim, n_dim, dtype, shared_scope and wmma_fragment_scope are free
     variables captured from the enclosing scope (template per config).
     """
     A = T.match_buffer(a, (m_dim, n_dim),
                        dtype,
                        align=128,
                        offset_factor=16,
                        scope=shared_scope)
     C = T.match_buffer(c, (m_dim, n_dim),
                        dtype,
                        align=128,
                        offset_factor=16,
                        scope=wmma_fragment_scope)
     with T.block("root"):
         T.reads(A[0:m_dim, 0:n_dim])
         T.writes(C[0:m_dim, 0:n_dim])
         for i, j in T.grid(m_dim, n_dim):
             with T.block("load"):
                 vii, vjj = T.axis.remap("SS", [i, j])
                 C[vii, vjj] = A[vii, vjj]
# Ejemplo n.º 24 — scraped example separator (score: 0)
 def main(a: T.handle, b: T.handle, c: T.handle) -> None:
     """Legacy-TE-schedule matmul using flat .data accesses and T.load.

     Computes C[x, y] = sum_k A[x, k] * B[y, k] (B accessed transposed)
     via raw flat offsets; T.load is the deprecated pre-Buffer-API form.
     """
     # function attr dict
     T.func_attr({
         "global_symbol": "main",
         "from_legacy_te_schedule": True,
         "tir.noalias": True
     })
     A = T.match_buffer(a, [128, 128])
     B = T.match_buffer(b, [128, 128])
     C = T.match_buffer(c, [128, 128])
     # body
     for x, y in T.grid(128, 128):
         C.data[x * 128 + y] = 0.0
         for k in T.serial(0, 128):
             C.data[x * 128 + y] = T.load(
                 "float32", C.data, x * 128 +
                 y) + T.load("float32", A.data, x * 128 + k) * T.load(
                     "float32", B.data, y * 128 + k)
# Ejemplo n.º 25 — scraped example separator (score: 0)
def transformed_rank0_buffer(a: T.handle, b: T.handle) -> None:
    """Mixes a normal buffer store with an opaque rank-0-style intrinsic.

    A[i, j] is written directly; B is touched only through T.intrin_test
    on B.data at flat offset i * 8 + j, with all extent/stride arguments
    zero (the rank-0 access under test).
    """
    A = T.match_buffer(a, (8, 8))
    B = T.match_buffer(b, (8, 8))
    for i, j in T.grid(8, 8):
        with T.block():
            T.reads([])
            T.writes([A[i, j], B[i, j]])
            A[i, j] = 1
            T.evaluate(
                T.intrin_test(
                    B.data,
                    i * 8 + j,
                    0,
                    0,
                    0,
                    0,
                    dtype="handle",
                ))
# Ejemplo n.º 26 — scraped example separator (score: 0)
 def main(var_X: T.handle, var_W: T.handle, var_B: T.handle,
          var_bn_scale: T.handle, var_bn_offset: T.handle,
          var_compute: T.handle) -> None:
     """Fused conv2d(3x3, pad 1) + bias + batch-norm scale/offset + ReLU.

     Pipeline over (1, 512, 56, 56) activations:
       pad_temp -> compute_1 (conv) -> bias_add -> bn_mul -> bn_add
       -> compute (ReLU).
     """
     X = T.match_buffer(var_X, [1, 512, 56, 56], dtype="float32")
     W = T.match_buffer(var_W, [512, 512, 3, 3], dtype="float32")
     B = T.match_buffer(var_B, [512, 1, 1], dtype="float32")
     bn_scale = T.match_buffer(var_bn_scale, [512, 1, 1], dtype="float32")
     bn_offset = T.match_buffer(var_bn_offset, [512, 1, 1], dtype="float32")
     compute = T.match_buffer(var_compute, [1, 512, 56, 56],
                              dtype="float32")
     pad_temp = T.alloc_buffer([1, 512, 58, 58], dtype="float32")
     compute_1 = T.alloc_buffer([1, 512, 56, 56], dtype="float32")
     bias_add = T.alloc_buffer([1, 512, 56, 56], dtype="float32")
     bn_mul = T.alloc_buffer([1, 512, 56, 56], dtype="float32")
     bn_add = T.alloc_buffer([1, 512, 56, 56], dtype="float32")
     # Stage 1: zero-pad the input by 1 on each spatial side (56 -> 58).
     for i0, i1, i2, i3 in T.grid(1, 512, 58, 58):
         with T.block("pad_temp"):
             i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
             pad_temp[i0_1, i1_1, i2_1,
                      i3_1] = T.if_then_else(i2_1 >= 1 and i2_1 < 57
                                             and i3_1 >= 1 and i3_1 < 57,
                                             X[i0_1, i1_1, i2_1 - 1,
                                               i3_1 - 1],
                                             T.float32(0),
                                             dtype="float32")
     # Stage 2: 3x3 convolution; rc/ry/rx are the reduction axes.
     for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 512, 56, 56, 512, 3, 3):
         with T.block("compute"):
             nn, ff, yy, xx, rc, ry, rx = T.axis.remap(
                 "SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
             with T.init():
                 compute_1[nn, ff, yy, xx] = T.float32(0)
             compute_1[nn, ff, yy,
                       xx] = compute_1[nn, ff, yy, xx] + pad_temp[
                           nn, rc, yy + ry, xx + rx] * W[ff, rc, ry, rx]
     # Stage 3: per-channel bias.
     for i0, i1, i2, i3 in T.grid(1, 512, 56, 56):
         with T.block("bias_add"):
             i, j, k, l = T.axis.remap("SSSS", [i0, i1, i2, i3])
             bias_add[i, j, k, l] = compute_1[i, j, k, l] + B[j, 0, 0]
     # Stage 4: batch-norm scale.
     for i0, i1, i2, i3 in T.grid(1, 512, 56, 56):
         with T.block("bn_mul"):
             i, j, k, l = T.axis.remap("SSSS", [i0, i1, i2, i3])
             bn_mul[i, j, k, l] = bias_add[i, j, k, l] * bn_scale[j, 0, 0]
     # Stage 5: batch-norm offset.
     for i0, i1, i2, i3 in T.grid(1, 512, 56, 56):
         with T.block("bn_add"):
             i, j, k, l = T.axis.remap("SSSS", [i0, i1, i2, i3])
             bn_add[i, j, k, l] = bn_mul[i, j, k, l] + bn_offset[j, 0, 0]
     # Stage 6: ReLU into the output buffer.
     for i0, i1, i2, i3 in T.grid(1, 512, 56, 56):
         with T.block("compute_1"):
             i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3])
             compute[i0_2, i1_2, i2_2,
                     i3_2] = T.max(bn_add[i0_2, i1_2, i2_2, i3_2],
                                   T.float32(0))
# Ejemplo n.º 27 — scraped example separator (score: 0)
def transformed_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
    """Split-reduction matmul in the legacy T.block/T.bind TVMScript form.

    Same computation as the newer transformed_matmul above (reduction
    split 4 x 8 x 4, B accessed as B[vj, vk]), expressed with the old
    ``T.block([shape], name) as [...]`` + ``T.bind`` syntax.
    """
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])

    for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(
            128, 128, 4, 8, 4):
        with T.block([128, 128, T.reduce_axis(0, 128)],
                     "update") as [vi, vj, vk]:
            T.bind(vi, i0)
            T.bind(vj, i1)
            # Recombine the split reduction loop: 4*32 + 8*4 + 4 = 128.
            T.bind(vk,
                   (((i2_outer * 32) + (i2_inner_outer * 4)) + i2_inner_inner))
            T.reads([C[vi, vj], A[vi, vk], B[vj, vk]])
            T.writes([C[vi, vj]])
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])
# Ejemplo n.º 28 — scraped example separator (score: 0)
 def func_match_buffer(A: T.Buffer[(128, 128), "float32"],
                       B: T.Buffer[(128, 128), "float32"]):
     """B = A * 2 through a region-matched view A0 with symbolic strides.

     Exercises T.match_buffer on a sub-region: A0 (and its stride/offset
     vars s, e) must be remapped by the transform under test.
     """
     with T.block("root"):
         s = T.var("int32")
         e = T.var("int32")
         # A0 should be remapped
         A0 = T.match_buffer(
             A[0:128, 0:128],
             shape=(128, 128),
             dtype="float32",
             # s and e should be remapped
             strides=[s, s],
             elem_offset=e,
         )
         for i, j in T.grid(128, 128):
             with T.block("B"):
                 vi, vj = T.axis.remap("SS", [i, j])
                 B[vi, vj] = A0[vi, vj] * 2.0
# Ejemplo n.º 29 — scraped example separator (score: 0)
def concat_func_3(
    placeholder: T.Buffer[(50176,), "int8"],
    placeholder_1: T.Buffer[(25088,), "int8"],
    placeholder_2: T.Buffer[(25088,), "int8"],
    T_concat: T.Buffer[(100352,), "int8"],
) -> None:
    """Flattened NCHW concat of three tensors along the channel axis.

    Channels [0, 64) come from `placeholder`, [64, 96) from
    `placeholder_1`, [96, 128) from `placeholder_2`; the channel loop
    carries a partition hint so the three disjoint cases can be split.
    """
    T.preflattened_buffer(placeholder, (1, 64, 28, 28), "int8", data=placeholder.data)
    T.preflattened_buffer(placeholder_1, (1, 32, 28, 28), "int8", data=placeholder_1.data)
    T.preflattened_buffer(placeholder_2, (1, 32, 28, 28), "int8", data=placeholder_2.data)
    T.preflattened_buffer(T_concat, (1, 128, 28, 28), "int8", data=T_concat.data)
    for c in T.serial(128, annotations={"pragma_loop_partition_hint": 1}):
        for h, w in T.grid(28, 28):
            # First 64 channels: direct copy from the first input.
            if c < 64:
                T_concat[c * 784 + h * 28 + w] = placeholder[c * 784 + h * 28 + w]
            # Middle 32 channels: second input, rebased by 64 * 784.
            if 64 <= c and c < 96:
                T_concat[c * 784 + h * 28 + w] = placeholder_1[c * 784 + h * 28 + w - 50176]
            # Last 32 channels: third input, rebased by 96 * 784.
            if 96 <= c:
                T_concat[c * 784 + h * 28 + w] = placeholder_2[c * 784 + h * 28 + w - 75264]
# Ejemplo n.º 30 — scraped example separator (score: 0)
    def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
        """Tensor-intrinsic description for a warp-level MMA accumulation.

        All buffers live in "warp" scope as (WARP_SIZE, local_size)
        fragments. WARP_SIZE, local_size(_out), in/out dtypes, M_DIM,
        N_DIM, k_dim and the maybe_swap / maybe_cast / index_map_*
        helpers are free variables captured from the enclosing scope —
        this is a template instantiated per MMA configuration.
        """
        A = T.match_buffer(a, (WARP_SIZE, local_size),
                           in_dtype,
                           align=128,
                           offset_factor=16,
                           scope="warp")
        B = T.match_buffer(b, (WARP_SIZE, local_size),
                           in_dtype,
                           align=128,
                           offset_factor=16,
                           scope="warp")
        C = T.match_buffer(c, (WARP_SIZE, local_size_out),
                           out_dtype,
                           align=128,
                           offset_factor=16,
                           scope="warp")

        with T.block("root"):
            T.reads(
                C[0:WARP_SIZE, 0:local_size_out],
                A[0:WARP_SIZE, 0:local_size],
                B[0:WARP_SIZE, 0:local_size],
            )
            T.writes(C[0:WARP_SIZE, 0:local_size_out])

            for i, j, k in T.grid(M_DIM, N_DIM, k_dim):
                with T.block("C"):
                    # NOTE: rebinding i, j, k shadows the loop vars with
                    # the block iter vars of the same names.
                    i, j, k = T.axis.remap("SSR", [i, j, k])
                    # maybe_swap handles the transposed-B layout case.
                    b_row_ind, b_col_ind = maybe_swap(k, j)

                    # Map logical (row, col) to (lane, fragment index).
                    thread_id_C, local_id_C = index_map_C(i, j)
                    thread_id_A, local_id_A = index_map_A(i, k)
                    thread_id_B, local_id_B = index_map_B(b_row_ind, b_col_ind)

                    T.reads(
                        C[thread_id_C, local_id_C],
                        A[thread_id_A, local_id_A],
                        B[thread_id_B, local_id_B],
                    )
                    T.writes(C[thread_id_C, local_id_C])

                    C[thread_id_C, local_id_C] += maybe_cast(
                        A[thread_id_A, local_id_A]) * maybe_cast(B[thread_id_B,
                                                                   local_id_B])