def expected_recursive_bufferslice_indices(data: T.handle, index: T.handle) -> None:
    """Gather fixture: out[vi, vj] = data[index[index[0]], index[0]].

    NOTE(review): appears to be an expected-result fixture for a TIR
    transform test (presumably decorated with @T.prim_func in the full
    file) — confirm against the test that consumes it.
    """
    index_buf = T.match_buffer(index, [1], dtype="int32", elem_offset=0, align=128, offset_factor=1)
    data_buf = T.match_buffer(data, [16, 16], elem_offset=0, align=128, offset_factor=1)
    with T.block([], "root"):
        T.reads([])
        T.writes([])
        out_buf = T.alloc_buffer([16, 16], elem_offset=0, align=128, offset_factor=1)
        for i0, i1 in T.grid(16, 16):
            with T.block([16, 16], "") as [vi, vj]:
                T.bind(vi, i0)
                T.bind(vj, i1)
                # The whole data region is declared read because the
                # gather indices are data-dependent (index_buf values).
                T.reads([data_buf[0:16, 0:16], index_buf[0]])
                T.writes([out_buf[vi, vj]])
                out_buf[vi, vj] = data_buf[index_buf[index_buf[0]], index_buf[0]]
def transformed_opaque_access(a: T.handle, b: T.handle) -> None:
    """Copy A -> B in 8 chunks of 128 through a cache buffer.

    Each chunk is first filled by an opaque extern call ("test") that
    touches A_cache.data / A.data directly, then copied element-wise
    into B.
    """
    A = T.match_buffer(a, [1024])
    B = T.match_buffer(b, [1024])
    for i in T.serial(0, 8):
        with T.block([8]) as [vi]:
            T.reads(A[vi * 128:vi * 128 + 128])
            T.writes(B[vi * 128:vi * 128 + 128])
            A_cache = T.alloc_buffer([1024])
            with T.block([8]) as [v]:
                T.bind(v, vi)
                T.reads([A[v * 128:v * 128 + 128]])
                T.writes([A_cache[v * 128:v * 128 + 128]])
                # Opaque access: raw data pointers are handed to the
                # extern, so reads/writes must be declared explicitly.
                T.evaluate(
                    T.call_extern("test", A_cache.data, v * 128, 128, A.data, v * 128, 128,
                                  dtype="float32"))
            for j in T.serial(0, 128):
                with T.block([1024]) as [v]:
                    T.bind(v, ((vi * 128) + j))
                    T.reads([A_cache[v]])
                    T.writes([B[v]])
                    B[v] = A_cache[v]
def buffer_opaque_access(b: T.handle, c: T.handle) -> None:
    """Fixture mixing opaque buffer access (raw store/load, fragment
    intrinsic on B.data) with an ordinary element-wise copy B -> C.
    """
    B = T.match_buffer(b, [16, 16], "float32")
    C = T.match_buffer(c, [16, 16], "float32")
    with T.block([]):
        T.reads([])
        T.writes(B[0:16, 0:16])
        A = T.allocate([256], "float32", "global")
        for i, j in T.grid(16, 16):
            T.store(A, i * 16 + j, 1)
        for i in range(0, 16):
            for j in range(0, 16):
                T.evaluate(T.load("float32", A, i * 16 + j))
            # Sibling j-loop: fills fragments of B via an intrinsic,
            # i.e. an opaque write to B.data.
            for j in range(0, 16):
                T.evaluate(
                    T.tvm_fill_fragment(B.data, 16, 16, 16, 0, T.float32(0), dtype="handle"))
    # Ordinary (non-opaque) copy, outside the opaque block since it
    # reads B and writes C, which the block above does not declare.
    for i, j in T.grid(16, 16):
        with T.block([16, 16]) as [vi, vj]:
            T.bind(vi, i)
            T.bind(vj, j)
            C[vi, vj] = B[vi, vj]
def matmul_decompose4(a: T.handle, b: T.handle, c: T.handle) -> None:
    """Matmul (C = A @ B^T over the reduce axis) after decompose_reduction:
    an explicit zero-init block followed by the update block, with the
    reduction axis split 19*7 and guarded by a predicate (< 128).
    """
    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
    B = T.match_buffer(b, [128, 128], elem_offset=0, align=128, offset_factor=1)
    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
    # body
    with T.block([], "root"):
        T.reads([])
        T.writes([])
        for i0_0 in T.serial(0, 16):
            for i0_1_init, i1_init in T.grid(8, 128):
                with T.block([128, 128], "update_init") as [vi_init, vj_init]:
                    T.bind(vi_init, ((i0_0 * 8) + i0_1_init))
                    T.bind(vj_init, i1_init)
                    C[vi_init, vj_init] = T.float32(0)
            for i0_1, i1, i2_0, i2_1 in T.grid(8, 128, 19, 7):
                with T.block([128, 128, T.reduce_axis(0, 128)], "update_update") as [vi, vj, vk]:
                    # 19 * 7 = 133 > 128, so the tail is masked off.
                    T.where((((i2_0 * 7) + i2_1) < 128))
                    T.bind(vi, ((i0_0 * 8) + i0_1))
                    T.bind(vj, i1)
                    T.bind(vk, ((i2_0 * 7) + i2_1))
                    C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])
def element_wise_parallelized(a: T.handle, b: T.handle) -> None:
    """B = A * 2 with the outer spatial loop marked parallel."""
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i0 in T.parallel(0, 128):
        for i1 in T.serial(0, 128):
            with T.block([128, 128], "B") as [vi, vj]:
                T.bind(vi, i0)
                T.bind(vj, i1)
                B[vi, vj] = A[vi, vj] * 2.0
def element_wise_split_predicate(a: T.handle, b: T.handle) -> None:
    """B = A * 2 with the j loop split 13 * 10 (> 128), so a predicate
    masks the overhanging iterations."""
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    for i, j_0, j_1 in T.grid(128, 13, 10):
        with T.block([128, 128], "B") as [vi, vj]:
            T.where(j_0 * 10 + j_1 < 128)
            T.bind(vi, i)
            T.bind(vj, j_0 * 10 + j_1)
            B[vi, vj] = A[vi, vj] * 2.0
def element_wise_i_bound(a: T.handle, b: T.handle) -> None:
    """B = A * 2 with the outer loop bound to GPU threadIdx.x."""
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i0 in T.thread_binding(0, 128, thread="threadIdx.x"):
        for i1 in T.serial(0, 128):
            with T.block([128, 128], "B") as [vi, vj]:
                T.bind(vi, i0)
                T.bind(vj, i1)
                B[vi, vj] = A[vi, vj] * 2.0
def transformed_match_buffer_func() -> None:
    """Nested match_buffer fixture: C0 views one row of C, and C1 views
    a single element of C0 as a rank-0 buffer before zeroing it."""
    for i in range(0, 128):
        with T.block([128]) as [vi]:
            T.bind(vi, i)
            C = T.alloc_buffer((128, 128))
            # Row view of C for this block instance.
            C0 = T.match_buffer(C[vi, 0:128], (128))
            with T.block([128]) as [jj]:
                # Scalar (shape ()) view of one element of the row.
                C1 = T.match_buffer(C0[jj], ())
                C1[()] = 0
def rowsum_not_quasi_affine(a: T.handle, b: T.handle) -> None:
    """Row-sum fixture whose reduce binding floordiv(k*k, 2) is
    intentionally NOT quasi-affine in the loop var (negative test)."""
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128,))
    for i, k in T.grid(128, 16):
        with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
            T.bind(vi, i)
            # Quadratic in k — cannot be analyzed as an affine binding.
            T.bind(vk, T.floordiv(k * k, 2))
            with T.init():
                B[vi] = 0.0
            B[vi] = B[vi] + A[vi, vk]
def tiled(a: T.handle, c: T.handle) -> None:
    """B = A * 2 computed over 16x16 tiles, then C = B + 1 in sugared
    form (block vars auto-bound, loops implied)."""
    A = T.match_buffer(a, [128, 128], "float32")
    B = T.alloc_buffer([128, 128], "float32")
    C = T.match_buffer(c, [128, 128], "float32")
    for i_0, j_0, i_1, j_1 in T.grid(8, 8, 16, 16):
        with T.block([128, 128], "B") as [vi, vj]:
            T.bind(vi, i_0 * 16 + i_1)
            T.bind(vj, j_0 * 16 + j_1)
            B[vi, vj] = A[vi, vj] * 2.0
    with T.block([128, 128], "C") as [vi, vj]:
        C[vi, vj] = B[vi, vj] + 1.0
def rowsum_predicate(a: T.handle, b: T.handle) -> None:
    """Row-sum B[i] = sum_k A[i, k] with the reduce loop split 13 * 10
    (> 128) and a predicate masking the tail."""
    A = T.match_buffer(a, [128, 128], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    for i, k_0, k_1 in T.grid(128, 13, 10):
        with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
            T.where(k_0 * 10 + k_1 < 128)
            T.bind(vi, i)
            T.bind(vk, k_0 * 10 + k_1)
            with T.init():
                B[vi] = 0.0
            B[vi] = B[vi] + A[vi, vk]
def rowsum_transformed(a: T.handle, b: T.handle) -> None:
    """Row-sum after loop transforms: i split into (io, ii), ii fused
    with ko, and the fused binding recovered via floordiv/floormod."""
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, ))
    for io, ii_ko_fused, ki in T.grid(32, 128, 4):
        with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
            T.bind(vi, io * 4 + T.floordiv(ii_ko_fused, 32))
            T.bind(vk, T.floormod(ii_ko_fused, 32) * 4 + ki)
            with T.init():
                B[vi] = 0.0
            B[vi] = B[vi] + A[vi, vk]
def rowsum_unrolled(a: T.handle, b: T.handle) -> None:
    """Row-sum with the spatial loop annotated for unrolling."""
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128,))
    for i0 in T.unroll(0, 128):
        for i1 in T.serial(0, 128):
            with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
                T.bind(vi, i0)
                T.bind(vk, i1)
                with T.init():
                    B[vi] = 0.0
                B[vi] = B[vi] + A[vi, vk]
def rowsum_cross_thread_reduction(a: T.handle, b: T.handle) -> None:
    """Row-sum with the reduce loop bound to threadIdx.x, i.e. the
    reduction is carried out across GPU threads."""
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128,))
    for i0 in T.serial(0, 128):
        for i1 in T.thread_binding(0, 128, thread="threadIdx.x"):
            with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
                T.bind(vi, i0)
                T.bind(vk, i1)
                with T.init():
                    B[vi] = 0.0
                B[vi] = B[vi] + A[vi, vk]
def element_wise_split_predicate_parallelized(a: T.handle, b: T.handle) -> None:
    """B = A * 2 with j split 13 * 10 (predicated tail) and the outer
    split loop j_0 parallelized."""
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    for i in T.serial(0, 128):
        for j_0 in T.parallel(0, 13):
            for j_1 in T.serial(0, 10):
                with T.block([128, 128], "B") as [vi, vj]:
                    T.where(j_0 * 10 + j_1 < 128)
                    T.bind(vi, i)
                    T.bind(vj, j_0 * 10 + j_1)
                    B[vi, vj] = A[vi, vj] * 2.0
def rowsum_not_serial(a: T.handle, b: T.handle) -> None:
    """Negative fixture: the reduce loop k is (illegally) parallel,
    which a scheduling check is expected to reject."""
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, ))
    for i in T.serial(0, 128):
        for k in T.parallel(0, 128):
            with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
                T.bind(vi, i)
                T.bind(vk, k)
                with T.init():
                    B[vi] = 0.0
                B[vi] = B[vi] + A[vi, vk]
def two_elementwise_after_compute_at(a: T.handle, c: T.handle) -> None:
    """Two chained element-wise ops (B = A * 2, C = B + 1) after
    compute_at moved the B block under the consumer's i loop.

    Fix: the consumer block was also named "B", duplicating the
    producer block's name; block names are how Schedule.get_block
    addresses blocks, so the consumer is renamed to "C".
    """
    A = T.match_buffer(a, (128, 128), "float32")
    B = T.alloc_buffer((128, 128), "float32")
    C = T.match_buffer(c, (128, 128), "float32")
    for i in range(0, 128):
        for ax0, ax1 in T.grid(1, 128):
            with T.block([128, 128], "B") as [vi, vj]:
                T.bind(vi, i + ax0)
                T.bind(vj, ax1)
                B[vi, vj] = A[vi, vj] * 2.0
        for j in range(0, 128):
            # Sugared block: vi/vj are auto-bound to the enclosing i/j.
            with T.block([128, 128], "C") as [vi, vj]:
                C[vi, vj] = B[vi, vj] + 1.0
def blockized_1(a: T.handle, c: T.handle) -> None:
    """Fixture after blockize: C is computed by an outer 8x8 block over
    16x16 tiles wrapping the inner per-element block."""
    A = T.match_buffer(a, [128, 128], "float32")
    B = T.alloc_buffer([128, 128], "float32")
    C = T.match_buffer(c, [128, 128], "float32")
    with T.block([128, 128], "B") as [vi, vj]:
        B[vi, vj] = A[vi, vj] * 2.0
    with T.block([8, 8], "C_outer") as [vi_o, vj_o]:
        # The outer block declares the whole 16x16 tile as its region.
        T.reads([B[vi_o * 16:vi_o * 16 + 16, vj_o * 16:vj_o * 16 + 16, ]])
        T.writes([C[vi_o * 16:vi_o * 16 + 16, vj_o * 16:vj_o * 16 + 16]])
        for i_i, j_i in T.grid(16, 16):
            with T.block([128, 128], "C_inner") as [vi, vj]:
                T.bind(vi, vi_o * 16 + i_i)
                T.bind(vj, vj_o * 16 + j_i)
                C[vi, vj] = B[vi, vj] + 1.0
def cuda_matmul_1(a: T.handle, b: T.handle, c: T.handle) -> None:  # pylint: disable=undefined-loop-variable
    """GPU matmul fixture: C[i, j] = sum_k A_shared_local[k, i] *
    B_shared_local[k, j], tiled over blockIdx/vthread/threadIdx with a
    256x8 split reduce loop and a local accumulator written back to C.
    """
    A = T.match_buffer(a, [2048, 2048], "float32")
    B = T.match_buffer(b, [2048, 2048], "float32")
    C = T.match_buffer(c, [2048, 2048], "float32")
    A_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared")
    B_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared")
    A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    C_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    # Staging copies (sugared blocks; loops implied by the block dims).
    with T.block([2048, 2048], "A_shared") as [v0, v1]:
        A_shared[v0, v1] = A[v0, v1]
    with T.block([2048, 2048], "B_shared") as [v0, v1]:
        B_shared[v0, v1] = B[v0, v1]
    with T.block([2048, 2048], "A_shared_local") as [v0, v1]:
        A_shared_local[v0, v1] = A_shared[v0, v1]
    with T.block([2048, 2048], "B_shared_local") as [v0, v1]:
        B_shared_local[v0, v1] = B_shared[v0, v1]
    for by in T.thread_binding(0, 32, thread="blockIdx.y"):
        for bx in T.thread_binding(0, 32, thread="blockIdx.x"):
            for vy in T.thread_binding(0, 2, thread="vthread.y"):
                for vx in T.thread_binding(0, 2, thread="vthread.x"):
                    for ty in T.thread_binding(0, 8, thread="threadIdx.y"):
                        for tx in T.thread_binding(0, 8, thread="threadIdx.x"):
                            for k_0 in T.serial(0, 256):
                                for k_1 in T.unroll(0, 8):
                                    for _, i, j in T.grid(1, 4, 4):
                                        with T.block([
                                                2048, 2048,
                                                T.reduce_axis(0, 2048)
                                        ], "C") as [vi, vj, vk]:
                                            T.bind(vi, by * 64 + vy * 32 + ty * 4 + i)
                                            T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j)
                                            T.bind(vk, k_0 * 8 + k_1)
                                            with T.init():
                                                C_local[vi, vj] = 0.0
                                            # Note: operands are indexed [vk, vi] /
                                            # [vk, vj], i.e. both are read transposed.
                                            C_local[vi, vj] = C_local[
                                                vi, vj] + A_shared_local[
                                                    vk, vi] * B_shared_local[vk, vj]
                            # Write-back of this thread's 4x4 tile.
                            for i, j in T.grid(4, 4):
                                with T.block([2048, 2048], "C_local") as [vi, vj]:
                                    T.bind(vi, by * 64 + vy * 32 + ty * 4 + i)
                                    T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j)
                                    C[vi, vj] = C_local[vi, vj]
def read_out_of_bound_after_compute_at(a: T.handle, c: T.handle) -> None:
    """Fixture after compute_at for a stencil C[v] = max(B[v], B[v+1]):
    the producer loop extent min(1, 15 - j) + 1 is clamped so B is never
    written out of bounds, while C's declared read B[v:v+2] still
    overhangs at v = 15 (guarded by if_then_else)."""
    A = T.match_buffer(a, [16], "float32")
    B = T.alloc_buffer([16], "float32")
    C = T.match_buffer(c, [16], "float32")
    for j in T.serial(0, 16):
        for i in T.serial(0, T.min(1, 15 - j) + 1):
            with T.block([16], "B") as [v]:
                T.bind(v, j + i)
                B[v] = A[v]
        with T.block([16], "C") as [v]:
            T.bind(v, j)
            T.reads([B[v:v + 2]])
            C[v] = T.if_then_else(v < 15, T.max(B[v], B[v + 1]), B[v], dtype="float32")
def func_multi_consumer() -> None:
    """Producer A with two consumers: B (computed tile-wise alongside A)
    and C (a separate full-range copy of A)."""
    A = T.alloc_buffer((128))
    B = T.alloc_buffer((128))
    C = T.alloc_buffer((128))
    for i in T.grid(8):
        for j in T.grid(16):
            with T.block([128], "A") as [vi]:
                T.bind(vi, i * 16 + j)
                A[vi] = 1.0
        for j in T.grid(16):
            with T.block([128], "B") as [vi]:
                T.bind(vi, i * 16 + j)
                B[vi] = A[vi] + 1.0
    for i in T.grid(128):
        with T.block([128], "C") as [vi]:
            C[vi] = A[vi]
def rowsum_blockized(a: T.handle, b: T.handle) -> None:
    """Blockized row-sum B[io, ii] = sum_k A[io, ii, k]: the outer block
    iterates a 16-way-split reduce axis, with an init sub-block and an
    inner 4x8 reduction sub-block."""
    B = T.match_buffer(b, [32, 4])
    A = T.match_buffer(a, [32, 4, 128])
    for i0, i2_0 in T.grid(32, 16):
        with T.block([32, T.reduce_axis(0, 16)], "blockized_B") as [io, ko]:
            T.bind(io, i0)
            T.bind(ko, i2_0)
            with T.init():
                for i1 in T.serial(0, 4):
                    with T.block([4], "B_init") as [ii_init]:
                        T.bind(ii_init, i1)
                        B[io, ii_init] = 0.0
            for i1_1, i2_1 in T.grid(4, 8):
                with T.block([4, T.reduce_axis(0, 128)], "B") as [ii, k]:
                    T.bind(ii, i1_1)
                    T.bind(k, ko * 8 + i2_1)
                    B[io, ii] = B[io, ii] + A[io, ii, k]
def elementwise_reordered2(a: T.handle, b: T.handle) -> None:
    """4-D element-wise B = A * 2 with loops reordered to (k, j, i, l);
    the block bindings are unchanged, so the result is identical."""
    A = T.match_buffer(a, (128, 128, 128, 128))
    B = T.match_buffer(b, (128, 128, 128, 128))
    for k, j, i, l in T.grid(128, 128, 128, 128):
        with T.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]:
            T.bind(vi, i)
            T.bind(vj, j)
            T.bind(vk, k)
            T.bind(vl, l)
            B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0
def elementwise_not_affine(a: T.handle, b: T.handle) -> None:
    """Negative fixture: vl = l * 16 with l in [0, 8) covers only a
    strided subset of the 128-wide axis, so the bindings are not a
    bijective/affine cover of the block's iteration space."""
    A = T.match_buffer(a, (128, 128, 128, 128))
    B = T.match_buffer(b, (128, 128, 128, 128))
    for i, j, k, l in T.grid(128, 128, 128, 8):
        with T.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]:
            T.bind(vi, i)
            T.bind(vj, j)
            T.bind(vk, k)
            T.bind(vl, l * 16)
            B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0
def elementwise_reordered_with_predicate(a: T.handle, b: T.handle) -> None:
    """4-D element-wise B = A * 2 reordered to (l, j, k, i), keeping a
    predicate over the original flattened index (i, j, k, l) < 100."""
    A = T.match_buffer(a, (128, 128, 128, 128))
    B = T.match_buffer(b, (128, 128, 128, 128))
    for l, j, k, i in T.grid(128, 128, 128, 128):
        with T.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]:
            # 2097152 = 128^3, 16384 = 128^2: flattened (i, j, k, l).
            T.where(i * 2097152 + j * 16384 + k * 128 + l < 100)
            T.bind(vi, i)
            T.bind(vj, j)
            T.bind(vk, k)
            T.bind(vl, l)
            B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0
def matmul_decompose1(a: T.handle, b: T.handle) -> None:
    """Blockized row-sum after decompose_reduction: a standalone init
    block nest followed by the update block nest."""
    A = T.match_buffer(a, [32, 4, 128], elem_offset=0, align=128, offset_factor=1)
    B = T.match_buffer(b, [32, 4], elem_offset=0, align=128, offset_factor=1)
    for i0 in T.serial(0, 32):
        with T.block([32], "blockized_B_init") as [io]:
            for i1 in T.serial(0, 4):
                with T.block([4], "B_init") as [ii]:
                    B[io, ii] = T.float32(0)
    for i0, i2_o in T.grid(32, 16):
        with T.block([32, T.reduce_axis(0, 16)], "blockized_B_update") as [io, ko]:
            for i1, i2_i in T.grid(4, 8):
                with T.block([4, T.reduce_axis(0, 128)], "B") as [ii, k]:
                    T.bind(ii, i1)
                    T.bind(k, ((ko * 8) + i2_i))
                    B[io, ii] = B[io, ii] + A[io, ii, k]
def square_sum_rfactor(a: T.handle, c: T.handle) -> None:
    """Square-sum C[b] = sum_{i,i2} A[b, i, i2]^2 after rfactor on the
    i2 axis: C_rf holds per-i2 partial sums, then C reduces over i2."""
    A = T.match_buffer(a, [16, 256, 256])
    C = T.match_buffer(c, [16])
    C_rf = T.alloc_buffer([16, 256])
    for i0, i1, i2 in T.grid(16, 256, 256):
        with T.block([256, 16, T.reduce_axis(0, 256)], "C_rf") as [vi2, b, i]:
            # vi2 (the factored axis) is spatial here; i stays a reduce axis.
            T.bind(vi2, i2)
            T.bind(b, i0)
            T.bind(i, i1)
            with T.init():
                C_rf[b, vi2] = 0.0
            C_rf[b, vi2] = C_rf[b, vi2] + (A[b, i, vi2] * A[b, i, vi2])
    for i0_1, i2_1 in T.grid(16, 256):
        with T.block([T.reduce_axis(0, 256), 16], "C") as [vi2_1, b_1]:
            T.bind(vi2_1, i2_1)
            T.bind(b_1, i0_1)
            with T.init():
                C[b_1] = 0.0
            C[b_1] = C[b_1] + C_rf[b_1, vi2_1]
def rowsum_predicate_rfactor(a: T.handle, b: T.handle) -> None:
    """Predicated row-sum after rfactor on k_0: B_rf accumulates the 10
    inner iterations per k_0 chunk (tail masked), then B reduces the 13
    partials."""
    A = T.match_buffer(a, [128, 128], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    B_rf = T.alloc_buffer([128, 13], dtype="float32")
    for i, k_0, k_1 in T.grid(128, 13, 10):
        with T.block([13, 128, T.reduce_axis(0, 10)], "B_rf") as [vk_0, vi, vk_1]:
            T.where(k_0 * 10 + k_1 < 128)
            T.bind(vk_0, k_0)
            T.bind(vi, i)
            T.bind(vk_1, k_1)
            with T.init():
                B_rf[vi, vk_0] = T.float32(0)
            B_rf[vi, vk_0] = B_rf[vi, vk_0] + A[vi, vk_0 * 10 + vk_1]
    for i, k_0 in T.grid(128, 13):
        with T.block([T.reduce_axis(0, 13), 128], "B") as [vk_0, vi]:
            T.bind(vk_0, k_0)
            T.bind(vi, i)
            with T.init():
                B[vi] = T.float32(0)
            B[vi] = B[vi] + B_rf[vi, vk_0]
def factorized_after_reverse_compute_at(a: T.handle, b: T.handle) -> None:
    """rfactor-ed reduction B[j] = sum_{i,k} A[j, i, k] after
    reverse_compute_at: the final cross-partial reduction over
    B_rf_local is nested under the i_o thread loop."""
    A = T.match_buffer(a, [16, 16, 16], "float32")
    B = T.match_buffer(b, [16], "float32")
    B_rf_local = T.alloc_buffer([16, 16], "float32", scope="local")
    for j in T.thread_binding(0, 16, thread="blockIdx.x"):
        for i_o in T.thread_binding(0, 4, thread="threadIdx.x"):
            for i_i, k in T.grid(4, 16):
                with T.block([16, 16, T.reduce_axis(0, 16)], "B_rf") as [vi, vj, vk]:
                    T.bind(vi, i_o * 4 + i_i)
                    T.bind(vj, j)
                    T.bind(vk, k)
                    with T.init():
                        B_rf_local[vi, vj] = 0.0
                    B_rf_local[vi, vj] = B_rf_local[vi, vj] + A[vj, vi, vk]
            # Final reduction over the factored axis (4 partials per i_o).
            for k in T.serial(0, 4):
                with T.block([16, T.reduce_axis(0, 16)], "B") as [vi, vk]:
                    T.bind(vi, j)
                    T.bind(vk, i_o * 4 + k)
                    with T.init():
                        B[vi] = 0.0
                    B[vi] = B[vi] + B_rf_local[vk, vi]
def element_wise_compute_at_split(a: T.handle, c: T.handle) -> None:
    """Chained element-wise ops after compute_at + split: B = A * 2 is
    computed under C's i loop, and C's j loop is split 32 x 4."""
    A = T.match_buffer(a, (128, 128))
    C = T.match_buffer(c, (128, 128))
    B = T.alloc_buffer((128, 128))
    for i in T.serial(0, 128):
        for j0 in T.serial(0, 128):
            with T.block([128, 128], "B") as [vi, vj]:
                T.bind(vi, i)
                T.bind(vj, j0)
                B[vi, vj] = A[vi, vj] * 2.0
        for j1o, j1i in T.grid(32, 4):
            with T.block([128, 128], "C") as [vi, vj]:
                T.bind(vi, i)
                T.bind(vj, j1o * 4 + j1i)
                C[vi, vj] = B[vi, vj] + 1.0