Example #1
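All of the snippets below are TVMScript (TIR) fragments printed without their decorators. Each `def` is assumed to sit under the standard prelude and an `@T.prim_func` decorator, omitted throughout for brevity; a minimal sketch of that assumed boilerplate (the body here is illustrative only, not taken from any example):

import tvm
from tvm.script import tir as T

@T.prim_func
def example(a: T.handle) -> None:
    A = T.match_buffer(a, [128], dtype="float32")
    for i in T.serial(0, 128):
        with T.block("B"):
            vi = T.axis.spatial(128, i)
            A[vi] = A[vi] * 2.0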
def lowered_two_bound_loops(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128, 128], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    reduce_temp0 = T.alloc_buffer([1],
                                  dtype="float32",
                                  strides=[1],
                                  scope="local")
    for i in T.serial(0, 128):
        for ko in T.thread_binding(0, 4, thread="threadIdx.x"):
            for ki in T.thread_binding(0, 32, thread="threadIdx.y"):
                with T.block("B_cross_thread_reduction"):
                    vi = T.axis.spatial(128, i)
                    vk = T.axis.reduce(128, ko * 32 + ki)
                    T.reads([A[vi, vk]])
                    T.writes([reduce_temp0[0]])
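                    # "reduce_scope" attaches the commutative reducer (sum
                    # with identity 0) that the tvm_thread_allreduce call
                    # below applies across threadIdx.x and threadIdx.y.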
                    T.attr(
                        T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
                        "reduce_scope",
                        T.reinterpret(T.uint64(0), dtype="handle"),
                    )
                    T.evaluate(
                        T.tvm_thread_allreduce(T.uint32(1),
                                               A[vi, vk],
                                               True,
                                               reduce_temp0.data,
                                               ko,
                                               ki,
                                               dtype="handle"))
                with T.block("B_write_back"):
                    vi = T.axis.spatial(128, i)
                    T.reads([reduce_temp0[0]])
                    T.writes([B[vi]])
                    B[vi] = reduce_temp0[0]
def main(a: T.handle, b: T.handle) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    A = T.match_buffer(a, (128, 128, 4),
                       dtype="float32",
                       scope="global.texture")
    B = T.alloc_buffer((128, 128, 4),
                       dtype="float32",
                       scope="global.texture")
    C = T.match_buffer(b, (128, 128, 4),
                       dtype="float32",
                       scope="global.texture")
    for block_idx in T.thread_binding(0, 128, thread="blockIdx.x"):
        for thread_idx in T.thread_binding(0,
                                           128,
                                           thread="threadIdx.x"):
            for k in T.serial(4):
                with T.block("B"):
                    vb, vt, vk = T.axis.remap(
                        "SSS", [block_idx, thread_idx, k])
                    B[vb, vt, vk] = A[vb, vt, vk] + T.float32(1)
    for block_idx in T.thread_binding(0, 128, thread="blockIdx.x"):
        for thread_idx in T.thread_binding(0,
                                           128,
                                           thread="threadIdx.x"):
            for k in T.serial(4):
                with T.block("C"):
                    vb, vt, vk = T.axis.remap(
                        "SSS", [block_idx, thread_idx, k])
                    C[vb, vt, vk] = B[vb, vt, vk] * T.float32(2)
Example #3
def main(placeholder: T.Buffer[(12, 64, 64), "float32"],
         T_reshape: T.Buffer[(64, 768), "float32"]) -> None:
    for i0_i1_fused_0_i0_i1_fused_1_fused_0 in T.thread_binding(
            48, thread="blockIdx.x"):
        for i0_i1_fused_0_i0_i1_fused_1_fused_1 in T.thread_binding(
                1024, thread="threadIdx.x"):
            with T.block("T_reshape_1"):
                ax0 = T.axis.spatial(
                    64,
                    ((i0_i1_fused_0_i0_i1_fused_1_fused_0 * 1024 +
                      i0_i1_fused_0_i0_i1_fused_1_fused_1) // 32 * 32 +
                     (i0_i1_fused_0_i0_i1_fused_1_fused_0 * 1024 +
                      i0_i1_fused_0_i0_i1_fused_1_fused_1) % 32) // 768,
                )
                ax1 = T.axis.spatial(
                    768,
                    ((i0_i1_fused_0_i0_i1_fused_1_fused_0 * 1024 +
                      i0_i1_fused_0_i0_i1_fused_1_fused_1) // 32 * 32 +
                     (i0_i1_fused_0_i0_i1_fused_1_fused_0 * 1024 +
                      i0_i1_fused_0_i0_i1_fused_1_fused_1) % 32) % 768,
                )
                T.reads(placeholder[ax1 % 768 // 64,
                                    (ax1 // 768 + ax0) % 64, ax1 % 64])
                T.writes(T_reshape[ax0, ax1])
                T_reshape[ax0, ax1] = placeholder[
                    ((ax1 % 64 // 64 +
                      (ax1 // 768 + ax0) % 64) // 64 + ax1 % 768 // 64) %
                    12, (ax1 % 64 // 64 + (ax1 // 768 + ax0) % 64) % 64,
                    ax1 % 64 % 64]
Example #4
def lowered_multiple_blocks_under_reduction_loop(a: T.handle,
                                                 b: T.handle) -> None:
    A = T.match_buffer(a, [16, 16, 16], dtype="float32")
    B = T.match_buffer(b, [16], dtype="float32")
    B_rf_local = T.alloc_buffer([16, 16], dtype="float32", scope="local")
    reduce_temp0 = T.alloc_buffer([1],
                                  dtype="float32",
                                  strides=[1],
                                  scope="local")
    normal_reduce_temp0 = T.alloc_buffer([1],
                                         dtype="float32",
                                         strides=[1],
                                         scope="local")
    for i in T.thread_binding(0, 16, thread="blockIdx.x"):
        for k0o in T.thread_binding(0, 4, thread="threadIdx.x"):
            with T.block("B_in_thread_init"):
                T.reads([])
                T.writes([normal_reduce_temp0[0]])
                normal_reduce_temp0[0] = T.float32(0)
            for k0i0, k1 in T.grid(4, 16):
                with T.block("B_rf"):
                    vk0 = T.axis.spatial(16, k0o * 4 + k0i0)
                    vi, vk1 = T.axis.remap("SR", [i, k1])
                    T.reads([B_rf_local[vk0, vi], A[vi, vk0, vk1]])
                    T.writes([B_rf_local[vk0, vi]])
                    with T.init():
                        B_rf_local[vk0, vi] = T.float32(0)
                    B_rf_local[vk0, vi] = B_rf_local[vk0, vi] + A[vi, vk0, vk1]
            for k0i1 in T.serial(0, 4):
                with T.block("B_normal_reduction"):
                    vk0 = T.axis.reduce(16, k0o * 4 + k0i1)
                    vi = T.axis.spatial(16, i)
                    T.reads([B_rf_local[vk0, vi], normal_reduce_temp0[0]])
                    T.writes([normal_reduce_temp0[0]])
                    normal_reduce_temp0[
                        0] = normal_reduce_temp0[0] + B_rf_local[vk0, vi]
            with T.block("B_cross_thread_reduction"):
                T.reads([normal_reduce_temp0[0]])
                T.writes([reduce_temp0[0]])
                T.attr(
                    T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp0[0],
                        True,
                        reduce_temp0.data,
                        k0o,
                        dtype="handle",
                    ))
            with T.block("B_write_back"):
                vi = T.axis.spatial(16, i)
                T.reads([reduce_temp0[0]])
                T.writes([B[vi]])
                B[vi] = reduce_temp0[0]
Example #5
def softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None:
    A = T.match_buffer(var_A, [256, 256], dtype="float32")
    T_softmax_norm = T.match_buffer(var_T_softmax_norm, [256, 256], dtype="float32")
    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
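    # Three passes per row: a cross-thread max into
    # T_softmax_maxelem_shared, a cross-thread sum of exponentials into
    # T_softmax_expsum_shared, and a final normalization that divides by it.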
    for i0 in T.thread_binding(0, 256, thread="blockIdx.x"):
        for ax0_0 in T.serial(0, 8):
            for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_maxelem"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads([T_softmax_maxelem_shared[i0_1], A[i0_1, k]])
                    T.writes([T_softmax_maxelem_shared[i0_1]])
                    with T.init():
                        T_softmax_maxelem_shared[i0_1] = T.min_value("float32")
                    T_softmax_maxelem_shared[i0_1] = T.max(
                        T_softmax_maxelem_shared[i0_1], A[i0_1, k]
                    )
        for ax0_0 in T.serial(0, 8):
            for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_expsum"):
                    i0_2 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads(
                        [
                            T_softmax_expsum_shared[i0_2],
                            A[i0_2, k],
                            T_softmax_maxelem_shared[i0_2],
                        ]
                    )
                    T.writes([T_softmax_expsum_shared[i0_2]])
                    with T.init():
                        T_softmax_expsum_shared[i0_2] = T.float32(0)
                    T_softmax_expsum_shared[i0_2] = T_softmax_expsum_shared[i0_2] + T.exp(
                        A[i0_2, k] - T_softmax_maxelem_shared[i0_2], dtype="float32"
                    )
        for i1_0 in T.serial(0, 8):
            for i1_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_3 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_0 * 32 + i1_1)
                    T.reads(
                        [
                            A[i0_3, i1],
                            T_softmax_maxelem_shared[i0_3],
                            T_softmax_expsum_shared[i0_3],
                        ]
                    )
                    T.writes([T_softmax_norm[i0_3, i1]])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_3, i1] = (
                        T.exp(
                            A[i0_3, i1] - T_softmax_maxelem_shared[i0_3],
                            dtype="float32",
                        )
                        / T_softmax_expsum_shared[i0_3]
                    )
Example #6
def main(X: T.Buffer[(1, 512, 56, 56), "float32"],
         W: T.Buffer[(512, 512, 3, 3), "float32"],
         B: T.Buffer[(512, 1, 1), "float32"],
         bn_scale: T.Buffer[(512, 1, 1), "float32"],
         bn_offset: T.Buffer[(512, 1, 1), "float32"],
         compute: T.Buffer[(1, 512, 56, 56), "float32"]) -> None:
    compute_local = T.alloc_buffer([1, 512, 56, 56],
                                   dtype="float32",
                                   scope="local")
    for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(224,
                                                      thread="blockIdx.x"):
        for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(
                2, thread="vthread.x"):
            for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(
                    8, thread="threadIdx.x"):
                for i4_0, i5_0, i6_0, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(
                        1, 3, 1, 32, 1, 1, 1, 1, 1, 1, 16, 1, 3, 1, 8, 2,
                        28):
                    with T.block("compute"):
                        nn = T.axis.spatial(1, 0)
                        ff = T.axis.spatial(
                            512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 +
                            i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + i1_4)
                        yy = T.axis.spatial(
                            56, i0_0_i1_0_i2_0_i3_0_fused // 2 % 7 * 8 +
                            i0_1_i1_1_i2_1_i3_1_fused * 4 +
                            i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + i2_4)
                        xx = T.axis.spatial(
                            56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + i3_4)
                        rc = T.axis.reduce(512, i4_1 * 16 + i4_2)
                        ry, rx = T.axis.remap("RR", [i5_0, i6_2])
                        with T.init():
                            compute_local[nn, ff, yy, xx] = T.float32(0)
                        compute_local[nn, ff, yy, xx] = compute_local[
                            nn, ff, yy, xx] + T.if_then_else(
                                yy + ry >= 1 and yy + ry < 57
                                and xx + rx >= 1 and xx + rx < 57,
                                X[nn, rc, yy + ry - 1, xx + rx - 1],
                                T.float32(0),
                                dtype="float32") * W[ff, rc, ry, rx]
                for ax0, ax1, ax2, ax3 in T.grid(1, 8, 2, 28):
                    with T.block("compute_local"):
                        v0 = T.axis.spatial(1, ax0)
                        v1 = T.axis.spatial(
                            512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 +
                            i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + ax1)
                        v2 = T.axis.spatial(
                            56, i0_0_i1_0_i2_0_i3_0_fused % 14 // 2 * 8 +
                            i0_1_i1_1_i2_1_i3_1_fused * 4 +
                            i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + ax2)
                        v3 = T.axis.spatial(
                            56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + ax3)
                        compute[v0, v1, v2, v3] = T.max(
                            (compute_local[v0, v1, v2, v3] + B[v1, 0, 0]) *
                            bn_scale[v1, 0, 0] + bn_offset[v1, 0, 0],
                            T.float32(0))
Example #7
def loop_syntax_sugar(a: T.handle) -> None:
    A = T.match_buffer(a, (128, 128, 128, 128))
    for i in T.serial(128):
        for j in T.parallel(128):
            for k in T.vectorized(128):
                for x in T.unroll(128):
                    for y in T.thread_binding(128, "threadIdx.x"):
                        for z in T.thread_binding(128, thread="threadIdx.x"):
                            A[i, j, k, x] = A[i, j, k, x] * 2.0
def GmmCuda2(X: T.Buffer[(1, 128, 128), "float32"],
             Y: T.Buffer[(1, 128, 128), "float32"],
             Z: T.Buffer[(1, 128, 128), "float32"]) -> None:
    Z_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local")
    X_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
    Y_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
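    # 16 blocks each own a 32x32 tile of Z; the 128 threads per block each
    # compute a 4x2 sub-tile, staging one 32-wide K slice of X and Y
    # through shared memory per i3_0 iteration.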
    for i0_0_i1_0_i2_0_fused in T.thread_binding(16, thread="blockIdx.x"):
        for i0_1_i1_1_i2_1_fused in T.thread_binding(1, thread="vthread.x"):
            for i0_2_i1_2_i2_2_fused in T.thread_binding(128, thread="threadIdx.x"):
                for i1_3_init, i2_4_init in T.grid(4, 2):
                    with T.block("Z_init"):
                        b = T.axis.spatial(1, 0)
                        i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3_init)
                        j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4_init)
                        T.reads()
                        T.writes(Z_local[b, i, j])
                        Z_local[b, i, j] = T.float32(0)
                for i3_0 in T.serial(4):
                    for ax0_ax1_ax2_fused_0 in T.serial(4):
                        for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"):
                            for ax0_ax1_ax2_fused_2 in T.vectorized(2):
                                with T.block("X_shared"):
                                    v0 = T.axis.spatial(1, 0)
                                    v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) // 32)
                                    v2 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) % 32)
                                    T.reads(X[v0, v1, v2])
                                    T.writes(X_shared[v0, v1, v2])
                                    X_shared[v0, v1, v2] = X[v0, v1, v2]
                    for ax0_ax1_ax2_fused_0 in T.serial(8):
                        for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"):
                            with T.block("Y_shared"):
                                v0 = T.axis.spatial(1, 0)
                                v1 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) // 32)
                                v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) % 32)
                                T.reads(Y[v0, v1, v2])
                                T.writes(Y_shared[v0, v1, v2])
                                Y_shared[v0, v1, v2] = Y[v0, v1, v2]
                    for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(1, 1, 4, 1, 32, 1, 1, 2):
                        with T.block("Z_update"):
                            b = T.axis.spatial(1, 0)
                            i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3)
                            j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4)
                            k = T.axis.reduce(128, i3_0 * 32 + i3_2)
                            T.block_attr({
                                "meta_schedule.thread_extent_low_inclusive": 1024,
                                "meta_schedule.thread_extent_high_inclusive": 1024,
                            })
                            T.reads(Z_local[b, i, j], X_shared[b, i, k], Y_shared[b, k, j])
                            T.writes(Z_local[b, i, j])
                            Z_local[b, i, j] = Z_local[b, i, j] + X_shared[b, i, k] * Y_shared[b, k, j]
                for ax0, ax1, ax2 in T.grid(1, 4, 2):
                    with T.block("Z_local"):
                        v0 = T.axis.spatial(1, ax0)
                        v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + ax1)
                        v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + ax2)
                        T.reads(Z_local[v0, v1, v2])
                        T.writes(Z[v0, v1, v2])
                        Z[v0, v1, v2] = Z_local[v0, v1, v2]
def unified_element_wise_vthread_x(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
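    # After thread-axis unification, the two vthread.x loops of
    # element_wise_vthread_x below collapse into a single vthread_x
    # variable that appears in both the row and column indices.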
    for vthread_x in T.thread_binding(0, 2, "vthread.x"):
        for threadIdx_x in T.thread_binding(0, 64, "threadIdx.x"):
            for j_1 in T.serial(0, 64):
                with T.block(""):
                    B[vthread_x * 64 + threadIdx_x, vthread_x * 64 + j_1] = (
                        A[vthread_x * 64 + threadIdx_x, vthread_x * 64 + j_1] *
                        2.0)
Example #10
 def main(A: T.Buffer[8, "float32"], B: T.Buffer[8, "float32"]) -> None:
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     for i_0 in T.thread_binding(2, thread="blockIdx.x"):
         for i_2 in T.thread_binding(2, thread="threadIdx.x"):
             for i_1 in T.serial(2):
                 with T.block("B"):
                     vi = T.axis.spatial(8, i_0 * 4 + i_1 * 2 + i_2)
                     T.reads(A[vi])
                     T.writes(B[vi])
                     B[vi] = A[vi] + T.float32(1)
def element_wise_vthread_x(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    for i_0 in T.thread_binding(0, 2, "vthread.x"):
        for i_1 in T.thread_binding(0, 64, "threadIdx.x"):
            for j_0 in T.thread_binding(0, 2, "vthread.x"):
                for j_1 in T.serial(0, 64):
                    with T.block(""):
                        B[i_0 * 64 + i_1, j_0 * 64 +
                          j_1] = A[i_0 * 64 + i_1, j_0 * 64 + j_1] * 2.0
Example #12
def main(var_A: T.handle, var_B: T.handle) -> None:
    A = T.match_buffer(var_A, [512, 512], dtype="float32")
    B = T.match_buffer(var_B, [512, 512], dtype="float32")
    for i_j_fused_0 in T.thread_binding(0, 8192, thread="blockIdx.x"):
        for i_j_fused_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
            with T.block("C"):
                vi = T.axis.spatial(
                    512, (i_j_fused_0 * 32 + i_j_fused_1) // 512)
                vj = T.axis.spatial(512,
                                    (i_j_fused_0 * 32 + i_j_fused_1) % 512)
                B[vi, vj] = A[vi, vj] + 1.0
def element_wise_two_thread_x_in_same_kernel_not_equal(a: T.handle,
                                                       b: T.handle,
                                                       c: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 64])
    for i in T.thread_binding(0, 128, "blockIdx.x"):
        for j0 in T.thread_binding(0, 128, "threadIdx.x"):
            B[i, j0] = A[i, j0] * 2.0
        for j1 in T.thread_binding(0, 64, "threadIdx.x"):
            C[i, j1] = A[i, j1] + 1.0
def element_wise_kernels_with_different_size(a: T.handle, b: T.handle,
                                             c: T.handle, d: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [256, 256])
    D = T.match_buffer(d, [256, 256])
    for i0 in T.thread_binding(0, 128, "blockIdx.x"):
        for j0 in T.thread_binding(0, 128, "threadIdx.x"):
            B[i0, j0] = A[i0, j0] * 2.0
    for i1 in T.thread_binding(0, 256, "blockIdx.x"):
        for j1 in T.thread_binding(0, 256, "threadIdx.x"):
            D[i1, j1] = C[i1, j1] + 1.0
def unified_element_wise_kernels_with_different_size(a: T.handle, b: T.handle,
                                                     c: T.handle,
                                                     d: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [256, 256])
    D = T.match_buffer(d, [256, 256])
    for blockIdx_x in T.thread_binding(0, 128, "blockIdx.x"):
        for threadIdx_x in T.thread_binding(0, 128, "threadIdx.x"):
            B[blockIdx_x, threadIdx_x] = A[blockIdx_x, threadIdx_x] * 2.0
    for blockIdx_x in T.thread_binding(0, 256, "blockIdx.x"):
        for threadIdx_x in T.thread_binding(0, 256, "threadIdx.x"):
            D[blockIdx_x, threadIdx_x] = C[blockIdx_x, threadIdx_x] + 1.0
def element_wise_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    for i in T.thread_binding(0, 128, "blockIdx.x"):
        for j0_0 in T.thread_binding(0, 4, "threadIdx.x"):
            for j0_1 in T.serial(0, 32):
                with T.block(""):
                    B[i, j0_0 * 32 + j0_1] = A[i, j0_0 * 32 + j0_1] * 2.0
        for j1_0 in T.thread_binding(0, 4, "threadIdx.x"):
            for j1_1 in T.serial(0, 32):
                with T.block(""):
                    C[i, j1_0 * 32 + j1_1] = B[i, j1_0 * 32 + j1_1] + 1.0
Example #17
def single_reduction_loop_with_block_predicate(
        A: T.Buffer[(256, 256), "float32"],
        T_softmax_norm: T.Buffer[(256, 256), "float32"]) -> None:
    T_softmax_maxelem_shared = T.alloc_buffer([256],
                                              dtype="float32",
                                              scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256],
                                             dtype="float32",
                                             scope="shared")
    for i0 in T.serial(256):
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_maxelem"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(T_softmax_maxelem_shared[i0_1], A[i0_1, k])
                    T.writes(T_softmax_maxelem_shared[i0_1])
                    with T.init():
                        T_softmax_maxelem_shared[i0_1] = T.float32(
                            -3.4028234663852886e38)
                    T_softmax_maxelem_shared[i0_1] = T.max(
                        T_softmax_maxelem_shared[i0_1], A[i0_1, k])
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_expsum"):
                    i0_2 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(T_softmax_expsum_shared[i0_2], A[i0_2, k],
                            T_softmax_maxelem_shared[i0_2])
                    T.writes(T_softmax_expsum_shared[i0_2])
                    with T.init():
                        T_softmax_expsum_shared[i0_2] = T.float32(0)
                    T_softmax_expsum_shared[
                        i0_2] = T_softmax_expsum_shared[i0_2] + T.exp(
                            A[i0_2, k] - T_softmax_maxelem_shared[i0_2],
                            dtype="float32")
        for i1_0 in T.serial(1):
            for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_3 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_1)
                    T.where(i1_0 * 512 + i1_1 < 256)
                    T.reads(A[i0_3, i1], T_softmax_maxelem_shared[i0_3],
                            T_softmax_expsum_shared[i0_3])
                    T.writes(T_softmax_norm[i0_3, i1])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_3, i1] = (
                        T.exp(A[i0_3, i1] - T_softmax_maxelem_shared[i0_3],
                              dtype="float32") / T_softmax_expsum_shared[i0_3])
Example #18
def thread_bound_nested_block(
    A: T.Buffer[(16, 16, 16, 16), "float32"], B: T.Buffer[(16, 16, 16), "float32"]
) -> None:
    for i in T.serial(16):
        for j in T.thread_binding(16, thread="blockIdx.x"):
            with T.block("outer"):
                vi, vj = T.axis.remap("SS", [i, j])
                for k in T.serial(16):
                    for l in T.thread_binding(16, thread="threadIdx.x"):
                        with T.block("inner"):
                            vk, vl = T.axis.remap("SR", [k, l])
                            with T.init():
                                B[vi, vj, vk] = T.float32(0)
                            B[vi, vj, vk] = B[vi, vj, vk] + A[vi, vj, vk, vl]
Example #19
def warp_memory(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    C = T.match_buffer(c, [128, 128])
    B = T.alloc_buffer([128, 4, 32], scope="warp")
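    # B sits in "warp" scope: its trailing 32-wide axis is indexed by
    # lane_id (threadIdx.x) and the 4-wide axis by warp_id (threadIdx.y).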
    for i_o in T.thread_binding(0, 4, thread="threadIdx.y"):
        for i_i in T.thread_binding(0, 32, thread="threadIdx.x"):
            for j in T.serial(0, 128):
                with T.block("B"):
                    warp_id, lane_id, vj = T.axis.remap("SSS", [i_o, i_i, j])
                    B[vj, warp_id, lane_id] = A[warp_id * 32 + lane_id, vj] * 2.0
            for j in T.serial(0, 128):
                with T.block("C"):
                    warp_id, lane_id, vj = T.axis.remap("SSS", [i_o, i_i, j])
                    C[warp_id * 32 + lane_id, vj] = B[vj, warp_id, lane_id] + 1.0
Example #20
def two_bound_loops(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128, 128], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    for i in T.serial(0, 128):
        for ko in T.thread_binding(0, 4, thread="threadIdx.x"):
            for ki in T.thread_binding(0, 32, thread="threadIdx.y"):
                with T.block("B"):
                    vi = T.axis.spatial(128, i)
                    vk = T.axis.reduce(128, ko * 32 + ki)
                    T.reads([B[vi], A[vi, vk]])
                    T.writes([B[vi]])
                    with T.init():
                        B[vi] = T.float32(0)
                    B[vi] = B[vi] + A[vi, vk]
def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    A = T.match_buffer(var_A, [512, 512], dtype="float32")
    B = T.match_buffer(var_B, [512, 512], dtype="float32")
    C = T.match_buffer(var_C, [512, 512], dtype="float32")
    # body
    # with T.block("root")
    C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local")
    A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
    B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
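    # Shared-memory GEMM tiling: the 8 threadIdx.x threads cooperatively
    # fetch the A and B tiles, and each thread accumulates a 32x4 slice of
    # C in C_local before the final write-back.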
    for i0_0_i1_0_fused in T.thread_binding(0, 16, thread="blockIdx.x"):
        for i0_1_i1_1_fused in T.thread_binding(0, 16, thread="vthread.x"):
            for i0_2_i1_2_fused in T.thread_binding(0, 8, thread="threadIdx.x"):
                for i2_0 in T.serial(0, 1):
                    for ax0_ax1_fused_0 in T.serial(0, 32768):
                        for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"):
                            with T.block("A_shared"):
                                v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) // 512)
                                v1 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) % 512)
                                T.reads([A[v0, v1]])
                                T.writes([A_shared[v0, v1]])
                                A_shared[v0, v1] = A[v0, v1]
                    for ax0_ax1_fused_0 in T.serial(0, 1024):
                        for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"):
                            for ax0_ax1_fused_2 in T.vectorized(0, 2):
                                with T.block("B_shared"):
                                    v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) // 32)
                                    v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) % 32)
                                    T.reads([B[v0, v1]])
                                    T.writes([B_shared[v0, v1]])
                                    B_shared[v0, v1] = B[v0, v1]
                    for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(16, 2, 2, 32, 16, 2):
                        with T.block("C"):
                            i = T.axis.spatial(512, i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4)
                            j = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + i1_3 * 2 + i1_4)
                            k = T.axis.reduce(512, i2_1 * 32 + i2_2)
                            T.reads([A_shared[i, k], B_shared[k, j]])
                            T.writes([C_local[i, j]])
                            with T.init():
                                C_local[i, j] = T.float32(0)
                            C_local[i, j] = C_local[i, j] + A_shared[i, k] * B_shared[k, j]
                for ax0, ax1 in T.grid(32, 4):
                    with T.block("C_local"):
                        v0 = T.axis.spatial(512, i0_1_i1_1_fused * 32 + ax0)
                        v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + ax1)
                        T.reads([C_local[v0, v1]])
                        T.writes([C[v0, v1]])
                        C[v0, v1] = C_local[v0, v1]
Example #22
def element_wise_compute_at_split_j0_j1o_bound(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    C = T.match_buffer(c, (128, 128))
    B = T.alloc_buffer((128, 128))
    for i in T.serial(0, 128):
        for j0 in T.thread_binding(0, 128, thread="threadIdx.x"):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j0])
                B[vi, vj] = A[vi, vj] * 2.0
        for j1o in T.thread_binding(0, 32, thread="threadIdx.x"):
            for j1i in T.serial(0, 4):
                with T.block("C"):
                    vi = T.axis.S(128, i)
                    vj = T.axis.S(128, j1o * 4 + j1i)
                    C[vi, vj] = B[vi, vj] + 1.0
def element_wise_thread_x_different_dtype(
    A: T.Buffer[(128, 128), "float32"],
    B: T.Buffer[(128, 128), "float32"],
    C: T.Buffer[(128, 128), "float32"],
) -> None:
    for i in T.thread_binding(128, "blockIdx.x"):
        for j0_0 in T.thread_binding(4, "threadIdx.x"):
            for j0_1 in T.serial(0, 32):
                with T.block(""):
                    B[i, j0_0 * 32 + j0_1] = A[i, j0_0 * 32 + j0_1] * 2.0
        for j1_0 in T.thread_binding(T.int64(4), "threadIdx.x"):
            for j1_1 in T.serial(T.int64(32)):
                with T.block(""):
                    C[i, j1_0 * T.int64(32) +
                      j1_1] = B[i, j1_0 * T.int64(32) + j1_1] + 1.0
Example #24
def transformed_simple_compute(A: T.Buffer[(16, 16), "float32"],
                               C: T.Buffer[(16, 16), "float32"]) -> None:
    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
        with T.block():
            T.reads([A[tx, 0:16]])
            T.writes([C[tx, 0:16]])
            B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
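            # B is double-buffered: iteration i prefetches element i + 1
            # into slot (i + 1) % 2 while the consumer reads slot i % 2.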
            with T.block():
                T.reads([A[tx, 0]])
                T.writes([B[0, tx, 0]])
                B[0, tx, 0] = A[tx, 0] * T.float32(2)
            with T.block():
                T.reads([A[tx, 1:16], B[0:2, tx, 0]])
                T.writes([B[0:2, tx, 0], C[tx, 0:15]])
                for i in T.serial(0, 15):
                    with T.block():
                        T.reads([A[tx, i + 1]])
                        T.writes([B[(i + 1) % 2, tx, 0]])
                        B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
                    with T.block():
                        T.reads([B[i % 2, tx, 0]])
                        T.writes([C[tx, i]])
                        C[tx, i] = B[i % 2, tx, 0] + T.float32(1)
            with T.block():
                T.reads([B[1, tx, 0]])
                T.writes([C[tx, 15]])
                C[tx, 15] = B[1, tx, 0] + T.float32(1)
def equal_ranked_threads(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    C = T.match_buffer(c, [128, 128])
    B = T.alloc_buffer([128, 128], scope="shared")
    for i_o in T.thread_binding(0, 16, thread="threadIdx.x"):
        for i_i in T.thread_binding(0, 8, thread="threadIdx.y"):
            for j in T.serial(0, 128):
                with T.block("B"):
                    vi = T.axis.S(128, i_o * 8 + i_i)
                    vj = T.axis.S(128, j)
                    B[vi, vj] = A[vi, vj] * 2.0
            for j in T.serial(0, 128):
                with T.block("C"):
                    vi = T.axis.S(128, i_o * 8 + i_i)
                    vj = T.axis.S(128, j)
                    C[vj, vi] = B[vj, vi] + 1.0
Example #26
def lowered_reducer_max(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128, 128], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    reduce_temp0 = T.alloc_buffer([1],
                                  dtype="float32",
                                  strides=[1],
                                  scope="local")
    for i in T.serial(0, 128):
        for k in T.thread_binding(0, 128, thread="threadIdx.x"):
            with T.block("B_cross_thread_reduction"):
                vi, vk = T.axis.remap("SR", [i, k])
                T.reads([A[vi, vk]])
                T.writes([reduce_temp0[0]])
                T.attr(
                    T.comm_reducer(lambda x, y: T.max(x, y),
                                   [T.min_value("float32")]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                T.evaluate(
                    T.tvm_thread_allreduce(T.uint32(1),
                                           A[vi, vk],
                                           True,
                                           reduce_temp0.data,
                                           k,
                                           dtype="handle"))
            with T.block("B_write_back"):
                vi = T.axis.spatial(128, i)
                T.reads([reduce_temp0[0]])
                T.writes([B[vi]])
                B[vi] = reduce_temp0[0]
def unified_element_wise_thread_x(a: T.handle, b: T.handle,
                                  c: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])

    for blockIdx_x in T.thread_binding(0, 128, "blockIdx.x"):
        for threadIdx_x in T.thread_binding(0, 4, "threadIdx.x"):
            for j0_1 in T.serial(0, 32):
                with T.block(""):
                    B[blockIdx_x, threadIdx_x * 32 +
                      j0_1] = (A[blockIdx_x, threadIdx_x * 32 + j0_1] * 2.0)
            for j1_1 in T.serial(0, 32):
                with T.block(""):
                    C[blockIdx_x, threadIdx_x * 32 +
                      j1_1] = (B[blockIdx_x, threadIdx_x * 32 + j1_1] + 1.0)
Example #28
def lowered_zero_rank_buffer(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128], dtype="float32")
    B = T.match_buffer(b, [], dtype="float32")
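    # B is a rank-0 (scalar) buffer, so the write-back below indexes it
    # with the empty tuple: B[()].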
    reduce_temp0 = T.alloc_buffer([1],
                                  dtype="float32",
                                  strides=[1],
                                  scope="local")
    for k in T.thread_binding(0, 128, thread="threadIdx.x"):
        with T.block("B_cross_thread_reduction"):
            vk = T.axis.reduce(128, k)
            T.reads([A[vk]])
            T.writes([reduce_temp0[0]])
            T.attr(
                T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
                "reduce_scope",
                T.reinterpret(T.uint64(0), dtype="handle"),
            )
            T.evaluate(
                T.tvm_thread_allreduce(T.uint32(1),
                                       A[vk],
                                       True,
                                       reduce_temp0.data,
                                       k,
                                       dtype="handle"))
        with T.block("B_write_back"):
            T.reads([reduce_temp0[0]])
            T.writes([B[()]])
            B[()] = reduce_temp0[0]
Example #29
def simple_compute_conflicting_order(A: T.Buffer[(16, 16), "float32"],
                                     D: T.Buffer[(16, 16), "float32"]):
    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
        for i in T.serial(
                0,
                16,
                annotations={
                    "software_pipeline_stage": [0, 1, 1],
                    "software_pipeline_order": [0, 1, 1],
                },
        ):
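            # The annotations place the three child blocks in pipeline
            # stages [0, 1, 1]; giving two of them the same order index in
            # [0, 1, 1] is the ordering conflict the function name refers to.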
            with T.block():
                T.reads(A[tx, i])
                T.writes(D[tx, i])
                B = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
                C = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
                with T.block():
                    T.reads(A[tx, i])
                    T.writes(B[tx, 0])
                    B[tx, 0] = A[tx, i] * T.float32(2)
                with T.block():
                    T.reads(B[tx, 0])
                    T.writes(C[tx, 0])
                    C[tx, 0] = B[tx, 0] + T.float32(2)
                with T.block():
                    T.reads(C[tx, 0])
                    T.writes(D[tx, i])
                    D[tx, i] = C[tx, 0] + T.float32(1)
Example #30
def element_wise_compute_at_split_j0_j1o_bound(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    C = T.match_buffer(c, (128, 128))
    B = T.alloc_buffer((128, 128))
    for i in T.serial(0, 128):
        for j0 in T.thread_binding(0, 128, thread="threadIdx.x"):
            with T.block([128, 128], "B") as [vi, vj]:
                T.bind(vi, i)
                T.bind(vj, j0)
                B[vi, vj] = A[vi, vj] * 2.0
        for j1o in T.thread_binding(0, 32, thread="threadIdx.x"):
            for j1i in T.serial(0, 4):
                with T.block([128, 128], "C") as [vi, vj]:
                    T.bind(vi, i)
                    T.bind(vj, j1o * 4 + j1i)
                    C[vi, vj] = B[vi, vj] + 1.0