Пример #1
0
def elementwise_predicate(a: T.handle, b: T.handle) -> None:
    # Elementwise B = A * 2.0 on 128x128x128x128 buffers, restricted by a block
    # predicate. Uses the legacy `T.block([extents], name) as [vars]` syntax;
    # there is no explicit T.bind, so the block vars presumably bind to the
    # enclosing loop vars in order (confirm against the parser version).
    A = T.match_buffer(a, (128, 128, 128, 128))
    B = T.match_buffer(b, (128, 128, 128, 128))
    for i, j, k, l in T.grid(128, 128, 128, 128):
        with T.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]:
            # 2097152 == 128**3 and 16384 == 128**2, so the left-hand side is the
            # row-major flat index of (i, j, k, l): only the first 100 elements
            # of B are written by this function.
            T.where(i * 2097152 + j * 16384 + k * 128 + l < 100)
            B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0
Пример #2
0
def elementwise_predicate_inlined(a: T.handle, c: T.handle) -> None:
    # C = A * 2.0 + 1.0 on 128x128 buffers, written only where A * 2.0 < 10.0.
    # Judging by the name, this looks like an intermediate `A * 2.0` stage
    # inlined into its consumer, predicate included; confirm against the
    # pre-inlining fixture. Legacy block syntax (no explicit T.bind).
    A = T.match_buffer(a, (128, 128))
    C = T.match_buffer(c, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block([128, 128], "C") as [vi, vj]:
            # Data-dependent predicate: it reads buffer A (indexed by the loop
            # vars i, j), not just loop indices.
            T.where(A[i, j] * 2.0 < 10.0)
            C[vi, vj] = A[vi, vj] * 2.0 + 1.0
def matmul_decompose4(a: T.handle, b: T.handle, c: T.handle) -> None:
    # 128x128 matmul C[i, j] = sum_k A[i, k] * B[j, k] (B is indexed
    # transposed), with the reduction initialisation split out into its own
    # "update_init" block instead of a T.init().
    # Loop structure: i = i0_0 * 8 + i0_1 (16 x 8); k = i2_0 * 7 + i2_1
    # (19 x 7 = 133 >= 128, so a predicate masks the 5 excess iterations).
    C = T.match_buffer(c, [128, 128],
                       elem_offset=0,
                       align=128,
                       offset_factor=1)
    B = T.match_buffer(b, [128, 128],
                       elem_offset=0,
                       align=128,
                       offset_factor=1)
    A = T.match_buffer(a, [128, 128],
                       elem_offset=0,
                       align=128,
                       offset_factor=1)
    # body
    with T.block("root"):
        T.reads([])
        T.writes([])
        for i0_0 in T.serial(0, 16):
            # Zero the 8 rows of C belonging to this i0_0 tile before the
            # update loop for the same tile accumulates into them.
            for i0_1_init, i1_init in T.grid(8, 128):
                with T.block("update_init"):
                    vi_init = T.axis.S(128, i0_0 * 8 + i0_1_init)
                    vj_init = T.axis.S(128, i1_init)
                    C[vi_init, vj_init] = T.float32(0)
            for i0_1, i1, i2_0, i2_1 in T.grid(8, 128, 19, 7):
                with T.block("update_update"):
                    # Drop the split-k iterations that fall past k = 127.
                    T.where((((i2_0 * 7) + i2_1) < 128))
                    vi = T.axis.S(128, i0_0 * 8 + i0_1)
                    vj = T.axis.S(128, i1)
                    vk = T.axis.R(128, i2_0 * 7 + i2_1)
                    C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])
Пример #4
0
def matmul_decompose4(a: T.handle, b: T.handle, c: T.handle) -> None:
    # Legacy-syntax version of a decomposed matmul:
    # C[i, j] = sum_k A[i, k] * B[j, k] on 128x128 buffers, with a separate
    # "update_init" block zeroing C per i0_0 tile, then an "update_update"
    # block accumulating over k = i2_0 * 7 + i2_1 (19 x 7 = 133 >= 128,
    # masked by the predicate). Block vars are bound explicitly via T.bind.
    C = T.match_buffer(c, [128, 128],
                       elem_offset=0,
                       align=128,
                       offset_factor=1)
    B = T.match_buffer(b, [128, 128],
                       elem_offset=0,
                       align=128,
                       offset_factor=1)
    A = T.match_buffer(a, [128, 128],
                       elem_offset=0,
                       align=128,
                       offset_factor=1)
    # body
    with T.block([], "root"):
        T.reads([])
        T.writes([])
        for i0_0 in T.serial(0, 16):
            for i0_1_init, i1_init in T.grid(8, 128):
                with T.block([128, 128], "update_init") as [vi_init, vj_init]:
                    T.bind(vi_init, ((i0_0 * 8) + i0_1_init))
                    T.bind(vj_init, i1_init)
                    C[vi_init, vj_init] = T.float32(0)
            for i0_1, i1, i2_0, i2_1 in T.grid(8, 128, 19, 7):
                # T.reduce_axis marks vk as the reduction axis (legacy
                # counterpart of T.axis.R).
                with T.block([128, 128, T.reduce_axis(0, 128)],
                             "update_update") as [
                                 vi,
                                 vj,
                                 vk,
                             ]:
                    # Drop the split-k iterations that fall past k = 127.
                    T.where((((i2_0 * 7) + i2_1) < 128))
                    T.bind(vi, ((i0_0 * 8) + i0_1))
                    T.bind(vj, i1)
                    T.bind(vk, ((i2_0 * 7) + i2_1))
                    C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])
def non_perfect_tiling_cache(a: T.handle, b: T.handle) -> None:
    # 3x3, stride-1 max reduction over a zero-padded 224x224 input:
    #   Y[h, w] = max(0.0, max over kh, kw of Xpad[h + kh - 1, w + kw - 1])
    # where out-of-range taps contribute 0.0. The output is tiled 8x8
    # (28 x 28 tiles). For each tile, a 10x10 window of X (the tile plus a
    # 1-element halo) is first copied into `cache`, and the tile is then
    # computed from `cache`.
    X = T.match_buffer(a, [224, 224], dtype="float32")
    Y = T.match_buffer(b, [224, 224], dtype="float32")
    # Full-size staging buffer; it is not shrunk to the 10x10 window here.
    cache = T.alloc_buffer([224, 224], dtype="float32")
    for hh_0, ww_0 in T.grid(28, 28):
        for ax0 in T.serial(0, 10):
            for ax1 in T.serial(0, 10):
                with T.block("cache"):
                    # The window starts one element before the tile (halo of 1).
                    h = T.axis.spatial(224, hh_0 * 8 - 1 + ax0)
                    w = T.axis.spatial(224, ww_0 * 8 - 1 + ax1)
                    # Equivalent to 0 <= h < 224 and 0 <= w < 224: skip halo
                    # elements that fall outside X.
                    T.where(1 <= hh_0 * 8 + ax0 and hh_0 * 8 + ax0 < 225
                            and 1 <= ww_0 * 8 + ax1 and ww_0 * 8 + ax1 < 225)
                    cache[h, w] = X[h, w]
        for hh_1, ww_1, khh, kww in T.grid(8, 8, 3, 3):
            with T.block("compute"):
                h = T.axis.spatial(224, hh_0 * 8 + hh_1)
                w = T.axis.spatial(224, ww_0 * 8 + ww_1)
                kh, kw = T.axis.remap("RR", [khh, kww])
                with T.init():
                    # The reduction starts at 0.0 (not -inf), the same value
                    # used for padded taps below.
                    Y[h, w] = 0.0
                # The guards are 0 <= h + kh - 1 < 224 and likewise for w;
                # taps outside that range read 0.0 instead of `cache`.
                Y[h, w] = T.max(
                    Y[h, w],
                    T.if_then_else(
                        T.likely(1 <= h + kh, dtype="bool")
                        and T.likely(h + kh < 225, dtype="bool")
                        and T.likely(1 <= w + kw, dtype="bool")
                        and T.likely(w + kw < 225, dtype="bool"),
                        cache[h + kh - 1, w + kw - 1],
                        0.0,
                        dtype="float32",
                    ),
                )
Пример #6
0
def elementwise_predicate(a: T.handle, b: T.handle) -> None:
    # Elementwise B = A * 2.0 on 128x128x128x128 buffers, restricted by a block
    # predicate so that only the first 100 elements (row-major) are written.
    A = T.match_buffer(a, (128, 128, 128, 128))
    B = T.match_buffer(b, (128, 128, 128, 128))
    for i, j, k, l in T.grid(128, 128, 128, 128):
        with T.block("B"):
            # 2097152 == 128**3 and 16384 == 128**2: the left-hand side is the
            # row-major flat index of (i, j, k, l).
            T.where(i * 2097152 + j * 16384 + k * 128 + l < 100)
            vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
            B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0
Пример #7
0
def elementwise_predicate_inlined(a: T.handle, c: T.handle) -> None:
    # C = A * 2.0 + 1.0 on 128x128 buffers, written only where A * 2.0 < 10.0.
    # The name suggests an intermediate `A * 2.0` stage inlined into this
    # consumer (predicate included); confirm against the pre-inlining fixture.
    A = T.match_buffer(a, (128, 128))
    C = T.match_buffer(c, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            # Data-dependent predicate: reads buffer A via the loop vars i, j.
            T.where(A[i, j] * 2.0 < 10.0)
            C[vi, vj] = A[vi, vj] * 2.0 + 1.0
def narrow_shape(A: T.Buffer[(10,), "float32"], B: T.Buffer[(10,), "float32"]) -> None:
    # Two stages over 10-element buffers: fill the staging buffer B_cache from
    # B, then A[i] = B_cache[i] + 1. The fill loop nest covers 3 x 4 = 12 flat
    # slots j * 4 + k; the predicate drops the 2 slots beyond index 9.
    B_cache = T.alloc_buffer(10, "float32")
    for j in T.serial(3):
        for k in T.serial(4):
            with T.block("B_cache"):
                T.where(j * 4 + k < 10)
                # NOTE(review): each B[j] is replicated into four consecutive
                # slots (B is read at j, not at j * 4 + k); presumably
                # intentional for this fixture. Confirm.
                B_cache[j * 4 + k] = B[j]
    # The consumer is plain loop code, not wrapped in a block.
    for i in T.serial(10):
        A[i] = B_cache[i] + T.float32(1)
Пример #9
0
def element_wise_split_predicate(a: T.handle, b: T.handle) -> None:
    # B = A * 2.0 on 128x128 buffers, with the j axis split into
    # 13 x 10 = 130 iterations; the predicate masks the 2 excess ones.
    # Legacy block syntax with explicit T.bind.
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    for i, j_0, j_1 in T.grid(128, 13, 10):
        with T.block([128, 128], "B") as [vi, vj]:
            T.where(j_0 * 10 + j_1 < 128)
            T.bind(vi, i)
            T.bind(vj, j_0 * 10 + j_1)
            B[vi, vj] = A[vi, vj] * 2.0
Пример #10
0
def element_wise_split_predicate(a: T.handle, b: T.handle) -> None:
    # B = A * 2.0 on 128x128 buffers, with the j axis split into
    # 13 x 10 = 130 iterations; the predicate masks the 2 excess ones.
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    for i, j_0, j_1 in T.grid(128, 13, 10):
        with T.block("B"):
            T.where(j_0 * 10 + j_1 < 128)
            vi = T.axis.S(128, i)
            vj = T.axis.S(128, j_0 * 10 + j_1)
            B[vi, vj] = A[vi, vj] * 2.0
Пример #11
0
def elementwise_predicate(a: T.handle, c: T.handle) -> None:
    # Two stages in legacy block syntax: B = A * 2.0 into an intermediate
    # 128x128 buffer, then C = B + 1.0 only where B < 10.0.
    A = T.match_buffer(a, (128, 128))
    B = T.alloc_buffer((128, 128))
    C = T.match_buffer(c, (128, 128))
    # This block has no enclosing loops; in the legacy syntax the 128x128
    # domain is presumably iterated implicitly from the block extents. Confirm.
    with T.block([128, 128], "B") as [vi, vj]:
        B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block([128, 128], "C") as [vi, vj]:
            # Data-dependent predicate on the intermediate buffer B.
            T.where(B[i, j] < 10.0)
            C[vi, vj] = B[vi, vj] + 1.0
def compacted_predicate_func(a: T.handle, c: T.handle) -> None:
    # C = A + 1.0 over 32 elements, iterated as a 5 x 7 = 35 loop nest over the
    # flat index i * 7 + j; the predicate drops the 3 out-of-range iterations.
    # NOTE: `(32)` is the int 32, not a 1-tuple; the shape argument presumably
    # accepts a bare int. Confirm.
    A = T.match_buffer(a, (32), "float32")
    C = T.match_buffer(c, (32), "float32")

    for i, j in T.grid(5, 7):
        # Anonymous block with no block iter vars; its read/write regions are
        # stated explicitly and indexed directly by the loop vars.
        with T.block() as []:
            T.reads(A[i * 7 + j])
            T.writes(C[i * 7 + j])
            T.where(i * 7 + j < 32)
            C[i * 7 + j] = A[i * 7 + j] + 1.0
Пример #13
0
def elementwise_reordered_with_predicate(a: T.handle, b: T.handle) -> None:
    # B = A * 2.0 on 128x128x128x128 buffers, writing only the first 100
    # elements in row-major (i, j, k, l) order, with the loops reordered to
    # l, j, k, i. The explicit T.bind calls keep vi..vl bound to i..l, so the
    # predicate still refers to the original (i, j, k, l) flat index.
    A = T.match_buffer(a, (128, 128, 128, 128))
    B = T.match_buffer(b, (128, 128, 128, 128))
    for l, j, k, i in T.grid(128, 128, 128, 128):
        with T.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]:
            # 2097152 == 128**3, 16384 == 128**2.
            T.where(i * 2097152 + j * 16384 + k * 128 + l < 100)
            T.bind(vi, i)
            T.bind(vj, j)
            T.bind(vk, k)
            T.bind(vl, l)
            B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0
Пример #14
0
def rowsum_predicate(a: T.handle, b: T.handle) -> None:
    # Row sum B[i] = sum_k A[i, k] over a 128x128 input, with k split into
    # 13 x 10 = 130 iterations; the predicate masks the 2 excess ones.
    # Legacy block syntax: T.reduce_axis marks vk as the reduction axis.
    A = T.match_buffer(a, [128, 128], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    for i, k_0, k_1 in T.grid(128, 13, 10):
        with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
            T.where(k_0 * 10 + k_1 < 128)
            T.bind(vi, i)
            T.bind(vk, k_0 * 10 + k_1)
            with T.init():
                # Initial value of the sum reduction.
                B[vi] = 0.0
            B[vi] = B[vi] + A[vi, vk]
Пример #15
0
def rowsum_predicate(a: T.handle, b: T.handle) -> None:
    # Row sum B[i] = sum_k A[i, k] over a 128x128 input, with k split into
    # 13 x 10 = 130 iterations; the predicate masks the 2 excess ones.
    A = T.match_buffer(a, [128, 128], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    for i, k_0, k_1 in T.grid(128, 13, 10):
        with T.block("B"):
            T.where(k_0 * 10 + k_1 < 128)
            vi = T.axis.S(128, i)
            vk = T.axis.R(128, k_0 * 10 + k_1)
            with T.init():
                # Initial value of the sum reduction.
                B[vi] = 0.0
            B[vi] = B[vi] + A[vi, vk]
Пример #16
0
def element_wise_split_predicate_parallelized(a: T.handle, b: T.handle) -> None:
    # B = A * 2.0 on 128x128 buffers with j split into 13 x 10 = 130
    # iterations and the outer j_0 loop marked parallel; the predicate masks
    # the 2 excess iterations.
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    for i in T.serial(0, 128):
        for j_0 in T.parallel(0, 13):
            for j_1 in T.serial(0, 10):
                with T.block("B"):
                    T.where(j_0 * 10 + j_1 < 128)
                    vi = T.axis.S(128, i)
                    vj = T.axis.S(128, j_0 * 10 + j_1)
                    B[vi, vj] = A[vi, vj] * 2.0
Пример #17
0
def elementwise_symbolic_split(a: T.handle, b: T.handle, n: T.int32) -> None:
    # B = A * 2.0 on 128 x 128 x n buffers with a symbolic last extent n.
    # The k axis is split into 10 outer x floordiv(n + 9, 10) inner iterations
    # (ceil(n / 10) for n >= 0), which can overshoot n, hence the predicate.
    A = T.match_buffer(a, (128, 128, n))
    B = T.match_buffer(b, (128, 128, n))
    for i, j, k0, k1 in T.grid(128, 128, 10, T.floordiv((n + 9), 10)):
        with T.block("B"):
            T.where((((k0 * T.floordiv((n + 9), 10)) + k1) < n))
            vi, vj = T.axis.remap("SS", [i, j])
            vk = T.axis.S(n, k0 * T.floordiv(n + 9, 10) + k1)
            T.reads([A[vi, vj, vk]])
            T.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def transformed_three_stage_compute(A: T.Buffer[(16, 16), "float32"],
                                    D: T.Buffer[(16, 16), "float32"]) -> None:
    # Per-row three-stage computation, one threadIdx.x per row tx:
    #   B = A * 2, C = B + 2, D = C + 1, so D[tx, k] == A[tx, k] * 2 + 3.
    # B and C are double-buffered in "shared" scope: leading dimension 2,
    # slot = element index % 2. The structure below (prologue / steady state /
    # epilogue) looks like the output of a software-pipelining transform.
    for tx in T.thread_binding(16, thread="threadIdx.x"):
        with T.block():
            T.reads(A[tx, 0:16])
            T.writes(D[tx, 0:16])
            B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
            C = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
            # Prologue: stage B for elements 0 and 1; stage C for element 0
            # only (its predicate skips i == 0).
            with T.block():
                T.reads(A[tx, 0:2], B[0:2, tx, 0])
                T.writes(B[0:2, tx, 0], C[0:2, tx, 0])
                for i in T.unroll(2):
                    with T.block():
                        T.reads(A[tx, i])
                        T.writes(B[0:2, tx, 0])
                        B[i, tx, 0] = A[tx, i] * T.float32(2)
                    with T.block():
                        T.where(1 <= i)
                        T.reads(B[0:2, tx, 0])
                        T.writes(C[0:2, tx, 0])
                        C[(i + 1) % 2, tx,
                          0] = B[(i + 1) % 2, tx, 0] + T.float32(2)
            # Steady state, iteration i: B for element i + 2, C for element
            # i + 1, D for element i. Within an iteration, each stage writes the
            # slot that the stage reading from it does not use.
            with T.block():
                T.reads(A[tx, 2:16], B[0:2, tx, 0], C[0:2, tx, 0])
                T.writes(B[0:2, tx, 0], C[0:2, tx, 0], D[tx, 0:14])
                for i in T.serial(14):
                    with T.block():
                        T.reads(A[tx, i + 2])
                        T.writes(B[0:2, tx, 0])
                        B[i % 2, tx, 0] = A[tx, i + 2] * T.float32(2)
                    with T.block():
                        T.reads(B[0:2, tx, 0])
                        T.writes(C[0:2, tx, 0])
                        C[(i + 1) % 2, tx,
                          0] = B[(i + 1) % 2, tx, 0] + T.float32(2)
                    with T.block():
                        T.reads(C[0:2, tx, 0])
                        T.writes(D[tx, i])
                        D[tx, i] = C[i % 2, tx, 0] + T.float32(1)
            # Epilogue: stage C for element 15 (only at i == 0), then stage D
            # for elements 14 and 15 (slot i holds element 14 + i).
            with T.block():
                T.reads(B[0:2, tx, 0], C[0:2, tx, 0])
                T.writes(C[0:2, tx, 0], D[tx, 14:16])
                for i in T.unroll(2):
                    with T.block():
                        T.where(i < 1)
                        T.reads(B[0:2, tx, 0])
                        T.writes(C[0:2, tx, 0])
                        C[(i + 1) % 2, tx,
                          0] = B[(i + 1) % 2, tx, 0] + T.float32(2)
                    with T.block():
                        T.reads(C[0:2, tx, 0])
                        T.writes(D[tx, i + 14])
                        D[tx, i + 14] = C[i, tx, 0] + T.float32(1)
Пример #19
0
def elementwise_split_with_predicate(a: T.handle, b: T.handle) -> None:
    # B = A * 2.0 on 128^3 buffers with every axis split into more iterations
    # than the extent: i into 1000 x 2 x 3 (6000), j into 1 x 129, and
    # k into 3 x 43 (129). A single conjunctive predicate keeps all three
    # indices below 128; its three terms equal vi, vj and vk respectively.
    B = T.match_buffer(b, [128, 128, 128])
    A = T.match_buffer(a, [128, 128, 128])
    for i0, i1, i2, j0, j1, k0, k1 in T.grid(1000, 2, 3, 1, 129, 3, 43):
        with T.block("B"):
            vi = T.axis.S(128, i0 * 6 + i1 * 3 + i2)
            vj = T.axis.S(128, j0 * 129 + j1)
            vk = T.axis.S(128, k0 * 43 + k1)
            T.where((i0 * 2 + i1) * 3 + i2 < 128 and j0 * 129 + j1 < 128 and k0 * 43 + k1 < 128)
            T.reads([A[vi, vj, vk]])
            T.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def compacted_narrow_shape(A: T.Buffer[(10,), "float32"], B: T.Buffer[(10,), "float32"]) -> None:
    # Two stages over 10-element buffers: B_cache[j * 4 + k] = B[j] for flat
    # slots j * 4 + k < 10 (3 x 4 = 12 iterations, 2 masked), then
    # A[i] = B_cache[i] + 1. Read/write regions are annotated explicitly.
    # B_cache keeps its full 10 elements, which equals the written range
    # 0..9; the name suggests this is the expected output of a
    # buffer-compaction pass (confirm).
    # body
    # with T.block("root")
    B_cache = T.alloc_buffer([10], dtype="float32")
    for j, k in T.grid(3, 4):
        with T.block("B_cache"):
            T.where(j * 4 + k < 10)
            T.reads(B[j])
            T.writes(B_cache[j * 4 + k])
            # NOTE(review): each B[j] fills four consecutive slots; presumably
            # intentional for this fixture. Confirm.
            B_cache[j * 4 + k] = B[j]
    for i in T.serial(10):
        A[i] = B_cache[i] + T.float32(1)
Пример #21
0
def elementwise_predicate(a: T.handle, c: T.handle) -> None:
    # Two stages: B = A * 2.0 into an intermediate 128x128 buffer, then
    # C = B + 1.0 only where B < 10.0.
    A = T.match_buffer(a, (128, 128))
    B = T.alloc_buffer((128, 128))
    C = T.match_buffer(c, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            # Data-dependent predicate on the intermediate buffer B (indexed
            # by the loop vars i, j).
            T.where(B[i, j] < 10.0)
            C[vi, vj] = B[vi, vj] + 1.0
Пример #22
0
def func_with_block_predicate() -> None:
    # Producer/consumer over two 120-element buffers allocated in the function
    # body: A = 0, then B = A + 1. Both loop nests run 16 x 8 = 128 iterations
    # over the flat index i * 8 + j and share the same predicate to stay
    # below 120. `(120)` is the int 120, not a tuple.
    A = T.alloc_buffer((120))
    B = T.alloc_buffer((120))
    for i, j in T.grid(16, 8):
        with T.block("producer"):
            T.where(i * 8 + j < 120)
            ax = T.axis.S(120, i * 8 + j)
            A[ax] = 0.0
    for i, j in T.grid(16, 8):
        with T.block("consumer"):
            T.where(i * 8 + j < 120)
            ax = T.axis.S(120, i * 8 + j)
            B[ax] = A[ax] + 1.0
Пример #23
0
def with_block_predicate(a: T.handle, b: T.handle) -> None:
    # Row sum B[i] = sum_k A[i, k] over a 128 x 120 input. k is split into
    # 4 x 32 = 128 iterations, with the inner 32 bound to threadIdx.x; the
    # predicate drops the 8 iterations with k >= 120.
    A = T.match_buffer(a, [128, 120], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    for i, ko in T.grid(128, 4):
        for ki in T.thread_binding(0, 32, thread="threadIdx.x"):
            with T.block("B"):
                vi = T.axis.spatial(128, i)
                vk = T.axis.reduce(120, ko * 32 + ki)
                T.where(ko * 32 + ki < 120)
                # B[vi] is listed as a read because the update accumulates
                # into it.
                T.reads([B[vi], A[vi, vk]])
                T.writes([B[vi]])
                with T.init():
                    B[vi] = T.float32(0)
                B[vi] = B[vi] + A[vi, vk]
 def compacted_spatial_tiled_pad_and_pooling(
     X: T.Buffer[(64, 112, 112), "int32"], Y: T.Buffer[(64, 56, 56), "int32"]
 ) -> None:
     # Tiled 3x3, stride-2 max pooling with zero padding (top/left), in a
     # form where the staging buffer has been compacted. Outputs are
     # processed in 14 x 14 tiles of 4x4. Each tile first stages a 9x9 window
     # (x 64 channels) of X into X_cache. X is read as X[c, h, w], while
     # X_cache is laid out as [h, w, c].
     # NOTE(review): Y is declared (64, 56, 56) but indexed Y[h, w, c] with
     # c in [0, 64), so its third index can exceed 55. Confirm whether this
     # shape/layout mismatch is intended by the fixture.
     for h_o, w_o in T.grid(14, 14):
         with T.block():
             T.reads(X[0:64, h_o * 8 - 1 : h_o * 8 + 8, w_o * 8 - 1 : w_o * 8 + 8])
             T.writes(Y[h_o * 4 : h_o * 4 + 4, w_o * 4 : w_o * 4 + 4, 0:64])
             X_cache = T.alloc_buffer([9, 9, 64], dtype="int32")
             for ax0, ax1, ax2 in T.grid(64, 9, 9):
                 with T.block("cache"):
                     # Only the lower bound needs a check: the largest row or
                     # column read is 13 * 8 + 8 - 1 = 111, still inside X.
                     T.where(1 <= h_o * 8 + ax1 and 1 <= w_o * 8 + ax2)
                     T.reads(X[ax0, h_o * 8 + ax1 - 1, w_o * 8 + ax2 - 1])
                     # T.max(0, h_o * 8 - 1) is the first row of X covered
                     # by the clipped window (likewise for columns);
                     # subtracting it maps global coordinates into the
                     # compacted 9x9 cache.
                     T.writes(
                         X_cache[
                             h_o * 8 + ax1 - T.max(0, h_o * 8 - 1) - 1,
                             w_o * 8 + ax2 - T.max(0, w_o * 8 - 1) - 1,
                             ax0,
                         ]
                     )
                     X_cache[
                         h_o * 8 + ax1 - T.max(0, h_o * 8 - 1) - 1,
                         w_o * 8 + ax2 - T.max(0, w_o * 8 - 1) - 1,
                         ax0,
                     ] = X[ax0, h_o * 8 + ax1 - 1, w_o * 8 + ax2 - 1]
             for h_i, w_i, kh, kw, c in T.grid(4, 4, 3, 3, 64):
                 with T.block("compute"):
                     T.reads(
                         X_cache[
                             h_o * 8 + h_i * 2 + kh - T.max(0, h_o * 8 - 1) - 1,
                             w_o * 8 + w_i * 2 + kw - T.max(0, w_o * 8 - 1) - 1,
                             c,
                         ]
                     )
                     T.writes(Y[h_o * 4 + h_i, w_o * 4 + w_i, c])
                     # Reduction init, done with an explicit first-tap check
                     # instead of T.init().
                     if kh == 0 and kw == 0:
                         Y[h_o * 4 + h_i, w_o * 4 + w_i, c] = 0
                     # Input row = 2 * (output row) + kh - 1 (stride 2,
                     # padding 1). Taps above or left of X contribute 0.
                     Y[h_o * 4 + h_i, w_o * 4 + w_i, c] = T.max(
                         Y[h_o * 4 + h_i, w_o * 4 + w_i, c],
                         T.if_then_else(
                             T.likely(1 <= h_o * 8 + h_i * 2 + kh, dtype="bool")
                             and T.likely(1 <= w_o * 8 + w_i * 2 + kw, dtype="bool"),
                             X_cache[
                                 h_o * 8 + h_i * 2 + kh - T.max(0, h_o * 8 - 1) - 1,
                                 w_o * 8 + w_i * 2 + kw - T.max(0, w_o * 8 - 1) - 1,
                                 c,
                             ],
                             0,
                             dtype="int32",
                         ),
                     )
 def main(
     T_reshape: T.Buffer[(1, 12, 384, 384), "float32"],
     placeholder_1: T.Buffer[(T.int64(1), T.int64(12), T.int64(384),
                              384), "bool"],
     T_where: T.Buffer[(T.int64(1), T.int64(12), T.int64(384), 384),
                       "float32"]
 ) -> None:
     # Masked fill over a 1x12x384x384 tensor: T_where = -1e9 where
     # placeholder_1 is nonzero, otherwise T_reshape.
     # The 4-D index space (1769472 = 12 * 384 * 384 elements) is fused into
     # the flat index (fused_0 * 256 + fused_1) * 1024 + fused_2, covering
     # 7 * 256 * 1024 = 1835008 iterations (serial x blockIdx.x x
     # threadIdx.x). The predicate drops the excess.
     # Shapes mix int64 and int32 extents; ax3 is computed in int64 and cast
     # to int32 to match its int32 extent of 384.
     # function attr dict
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     # body
     # with T.block("root")
     for i0_i1_i2_i3_fused_1 in T.thread_binding(T.int64(256),
                                                 thread="blockIdx.x"):
         for i0_i1_i2_i3_fused_2 in T.thread_binding(
                 T.int64(1024), thread="threadIdx.x"):
             for i0_i1_i2_i3_fused_0 in T.serial(T.int64(7)):
                 with T.block("T_where"):
                     ax0 = T.axis.spatial(T.int64(1), T.int64(0))
                     # 147456 == 384 * 384 (elements per ax1 slice).
                     ax1 = T.axis.spatial(
                         T.int64(12),
                         ((i0_i1_i2_i3_fused_0 * T.int64(256) +
                           i0_i1_i2_i3_fused_1) * T.int64(1024) +
                          i0_i1_i2_i3_fused_2) % T.int64(1769472) //
                         T.int64(147456))
                     ax2 = T.axis.spatial(
                         T.int64(384),
                         ((i0_i1_i2_i3_fused_0 * T.int64(256) +
                           i0_i1_i2_i3_fused_1) * T.int64(1024) +
                          i0_i1_i2_i3_fused_2) % T.int64(147456) //
                         T.int64(384))
                     ax3 = T.axis.spatial(
                         384,
                         T.cast(((i0_i1_i2_i3_fused_0 * T.int64(256) +
                                  i0_i1_i2_i3_fused_1) * T.int64(1024) +
                                 i0_i1_i2_i3_fused_2) % T.int64(384),
                                "int32"))
                     # Skip flat indices past the 1769472 real elements.
                     T.where((i0_i1_i2_i3_fused_0 * T.int64(256) +
                              i0_i1_i2_i3_fused_1) * T.int64(1024) +
                             i0_i1_i2_i3_fused_2 < T.int64(1769472))
                     T.reads(placeholder_1[ax0, ax1, ax2, ax3],
                             T_reshape[ax0, ax1, ax2, ax3])
                     T.writes(T_where[ax0, ax1, ax2, ax3])
                     T_where[ax0, ax1, ax2, ax3] = T.Select(
                         T.cast(placeholder_1[ax0, ax1, ax2,
                                              ax3], "int32") != 0,
                         T.float32(-1000000000), T_reshape[ax0, ax1,
                                                           ax2, ax3])
Пример #26
0
def lowered_with_block_predicate(a: T.handle, b: T.handle) -> None:
    # Row sum B[i] = sum_{k < 120} A[i, k] over a 128 x 120 input, expressed as
    # a cross-thread reduction over threadIdx.x (32 threads). For each row:
    # each thread zeroes its own accumulator, accumulates its strided slice
    # k = ko * 32 + ki, the partial sums are combined by
    # tvm_thread_allreduce, and the result is written back to B.
    A = T.match_buffer(a, [128, 120], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    # One-element "local"-scope buffers: reduce_temp0 receives the all-reduce
    # result, normal_reduce_temp0 holds a thread's partial sum.
    reduce_temp0 = T.alloc_buffer([1],
                                  dtype="float32",
                                  strides=[1],
                                  scope="local")
    normal_reduce_temp0 = T.alloc_buffer([1],
                                         dtype="float32",
                                         strides=[1],
                                         scope="local")
    for i in T.serial(0, 128):
        for ki in T.thread_binding(0, 32, thread="threadIdx.x"):
            with T.block("B_in_thread_init"):
                T.reads([])
                T.writes([normal_reduce_temp0[0]])
                normal_reduce_temp0[0] = T.float32(0)
            for ko in T.serial(0, 4):
                with T.block("B_normal_reduction"):
                    vi = T.axis.spatial(128, i)
                    vk = T.axis.reduce(120, ko * 32 + ki)
                    # The block predicate carries over: skip k >= 120.
                    T.where(ko * 32 + ki < 120)
                    T.reads([A[vi, vk], normal_reduce_temp0[0]])
                    T.writes([normal_reduce_temp0[0]])
                    normal_reduce_temp0[0] = normal_reduce_temp0[0] + A[vi, vk]
            with T.block("B_cross_thread_reduction"):
                T.reads([normal_reduce_temp0[0]])
                T.writes([reduce_temp0[0]])
                # Declares the combiner (x + y, identity 0.0) for this
                # reduce scope.
                T.attr(
                    T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                # Arguments presumably are (number of values, per-thread
                # value, condition, result buffer, reduction thread var);
                # confirm against the tvm_thread_allreduce builtin docs.
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp0[0],
                        True,
                        reduce_temp0.data,
                        ki,
                        dtype="handle",
                    ))
            with T.block("B_write_back"):
                vi = T.axis.spatial(128, i)
                T.reads([reduce_temp0[0]])
                T.writes([B[vi]])
                B[vi] = reduce_temp0[0]
Пример #27
0
def rowsum_predicate_rfactor(a: T.handle, b: T.handle) -> None:
    # Row sum B[i] = sum_k A[i, k] over a 128x128 input, factored on the outer
    # split k_0 (k = k_0 * 10 + k_1, 13 x 10). Stage "B_rf" writes per-chunk
    # partial sums B_rf[i, k_0]; stage "B" reduces those 13 partials into B[i].
    A = T.match_buffer(a, [128, 128], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    B_rf = T.alloc_buffer([128, 13], dtype="float32")
    for i, k_0, k_1 in T.grid(128, 13, 10):
        with T.block("B_rf"):
            # vk_0 is a spatial axis of B_rf here; only vk_1 is reduced.
            vk_0, vi, vk_1 = T.axis.remap("SSR", [k_0, i, k_1])
            # Masks k >= 128, so the last chunk (k_0 == 12) sums 8 elements.
            T.where(k_0 * 10 + k_1 < 128)
            with T.init():
                B_rf[vi, vk_0] = T.float32(0)
            B_rf[vi, vk_0] = B_rf[vi, vk_0] + A[vi, vk_0 * 10 + vk_1]
    for i, k_0 in T.grid(128, 13):
        with T.block("B"):
            vk_0, vi = T.axis.remap("RS", [k_0, i])
            with T.init():
                B[vi] = T.float32(0)
            B[vi] = B[vi] + B_rf[vi, vk_0]
 def main(a: T.handle, b: T.handle) -> None:
     # Copies A into B (1024^3 float32). Each axis is tiled to extent 2048,
     # twice the buffer extent: i as 128 x 4 x 4, j as 64 x 4 x 8, and k as
     # 64 x 32. The predicate keeps only in-range indices; its three terms
     # equal vi, vj and vk respectively.
     # function attr dict
     T.func_attr({"global_symbol": "main"})
     A = T.match_buffer(a, [1024, 1024, 1024], dtype="float32")
     B = T.match_buffer(b, [1024, 1024, 1024], dtype="float32")
     # body
     with T.block("root"):
         # Meta-schedule annotations on the root block; presumably hints for
         # later parallelization and vectorization. Confirm against the
         # meta_schedule rule that consumes them.
         T.block_attr({"meta_schedule.parallel":128, "meta_schedule.vectorize":32})
         for i0, j0, i1, j1, k0, i2, j2, k1 in T.grid(128, 64, 4, 4, 64, 4, 8, 32):
             with T.block("move"):
                 vi = T.axis.spatial(1024, i0 * 16 + i1 * 4 + i2)
                 vj = T.axis.spatial(1024, j0 * 32 + j1 * 8 + j2)
                 vk = T.axis.spatial(1024, k0 * 32 + k1)
                 T.where((i0 * 4 + i1) * 4 + i2 < 1024 and (j0 * 4 + j1) * 8 + j2 < 1024 and k0 * 32 + k1 < 1024)
                 T.reads([A[vi, vj, vk]])
                 T.writes([B[vi, vj, vk]])
                 B[vi, vj, vk] = A[vi, vj, vk]
Пример #29
0
def block_predicate_cache_write_output_buf() -> None:
    # Producer A = 0 and consumer A + 1 over 120 elements. Each loop nest runs
    # 16 x 8 = 128 iterations and a predicate drops the last 8. The consumer
    # writes into a "shared"-scope staging buffer B_shared, which is then
    # copied into B. This looks like a cache_write of the consumer's output
    # (confirm).
    A = T.alloc_buffer([120], dtype="float32")
    B = T.alloc_buffer([120], dtype="float32")
    B_shared = T.alloc_buffer([120], dtype="float32", scope="shared")
    for i, j in T.grid(16, 8):
        with T.block("producer"):
            ax = T.axis.spatial(120, i * 8 + j)
            T.where(i * 8 + j < 120)
            A[ax] = T.float32(0)
    for i, j in T.grid(16, 8):
        with T.block("consumer"):
            ax = T.axis.spatial(120, i * 8 + j)
            T.where(i * 8 + j < 120)
            B_shared[ax] = A[ax] + T.float32(1)
    # The copy-back loop runs exactly 120 times, so it needs no predicate.
    for ax0 in T.serial(120):
        with T.block("B_shared"):
            v0 = T.axis.spatial(120, ax0)
            B[v0] = B_shared[v0]
Пример #30
0
 def main(A: T.Buffer[(1, 256, 256), "float32"],
          D: T.Buffer[(1, ), "float32"]) -> None:
     # D[0] = sqrt(sum_{i, j} A[0, i, j]^2), i.e. the L2 norm of A[0],
     # computed in two kernels: a sum-of-squares reduction into C, then sqrt.
     # The batch axis (extent 1) is bound to 1 block x 32 threads.
     # T.where(i0_fused_1 < 1) lets only thread 0 do the work; the other 31
     # threads stay idle.
     C = T.alloc_buffer([1], dtype="float32")
     for i0_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
         for i0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
             for i1, i2 in T.grid(256, 256):
                 with T.block("C"):
                     b = T.axis.S(1, 0)
                     i, j = T.axis.remap("RR", [i1, i2])
                     T.where(i0_fused_1 < 1)
                     with T.init():
                         C[b] = T.float32(0)
                     C[b] = C[b] + A[b, i, j] * A[b, i, j]
     for i0_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
         for i0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
             with T.block("D"):
                 b = T.axis.S(1, 0)
                 T.where(i0_fused_1 < 1)
                 D[b] = T.sqrt(C[b], dtype="float32")