def cuda_matmul_1(a: T.handle, b: T.handle, c: T.handle) -> None:  # pylint: disable=undefined-loop-variable
    # 2048x2048 float32 matmul fixture. A and B are copied into "shared" scope
    # buffers, then into "local" scope buffers; products are accumulated into
    # C_local and finally copied out to C.
    A = T.match_buffer(a, [2048, 2048], "float32")
    B = T.match_buffer(b, [2048, 2048], "float32")
    C = T.match_buffer(c, [2048, 2048], "float32")
    A_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared")
    B_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared")
    A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    C_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    # The four staging copies sit at function scope, outside every thread loop.
    with T.block([2048, 2048], "A_shared") as [v0, v1]:
        A_shared[v0, v1] = A[v0, v1]
    with T.block([2048, 2048], "B_shared") as [v0, v1]:
        B_shared[v0, v1] = B[v0, v1]
    with T.block([2048, 2048], "A_shared_local") as [v0, v1]:
        A_shared_local[v0, v1] = A_shared[v0, v1]
    with T.block([2048, 2048], "B_shared_local") as [v0, v1]:
        B_shared_local[v0, v1] = B_shared[v0, v1]
    # Thread nest: 32x32 blockIdx, 2x2 vthread, 8x8 threadIdx; each innermost
    # iteration addresses a 4x4 patch (64 = 2 * 8 * 4 rows/cols per block).
    for by in T.thread_binding(0, 32, thread="blockIdx.y"):
        for bx in T.thread_binding(0, 32, thread="blockIdx.x"):
            for vy in T.thread_binding(0, 2, thread="vthread.y"):
                for vx in T.thread_binding(0, 2, thread="vthread.x"):
                    for ty in T.thread_binding(0, 8, thread="threadIdx.y"):
                        for tx in T.thread_binding(0, 8, thread="threadIdx.x"):
                            # Reduction axis is split 256 (serial) x 8 (unrolled).
                            for k_0 in T.serial(0, 256):
                                for k_1 in T.unroll(0, 8):
                                    for _, i, j in T.grid(1, 4, 4):
                                        with T.block([2048, 2048, T.reduce_axis(0, 2048)], "C") as [vi, vj, vk]:
                                            T.bind(vi, by * 64 + vy * 32 + ty * 4 + i)
                                            T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j)
                                            T.bind(vk, k_0 * 8 + k_1)
                                            with T.init():
                                                C_local[vi, vj] = 0.0
                                            # Both operands are indexed with the reduction axis first.
                                            C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj]
                            # Write-back of the 4x4 patch, placed after the k_0 loop
                            # inside the same thread nest.
                            for i, j in T.grid(4, 4):
                                with T.block([2048, 2048], "C_local") as [vi, vj]:
                                    T.bind(vi, by * 64 + vy * 32 + ty * 4 + i)
                                    T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j)
                                    C[vi, vj] = C_local[vi, vj]
def cuda_matmul_4(a: T.handle, b: T.handle, c: T.handle) -> None:  # pylint: disable=undefined-loop-variable
    # 2048x2048 float32 matmul fixture. The B_shared copy is at function scope;
    # the A_shared, A_shared_local and B_shared_local copies are nested inside
    # the thread / reduction loops.
    A = T.match_buffer(a, [2048, 2048], "float32")
    B = T.match_buffer(b, [2048, 2048], "float32")
    C = T.match_buffer(c, [2048, 2048], "float32")
    A_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared")
    B_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared")
    A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    C_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    # Full-size copy of B into B_shared.
    for i, j in T.grid(2048, 2048):
        with T.block("B_shared"):
            v0, v1 = T.axis.remap("SS", [i, j])
            B_shared[v0, v1] = B[v0, v1]
    for by in T.thread_binding(0, 32, thread = "blockIdx.y"):
        for bx in T.thread_binding(0, 32, thread = "blockIdx.x"):
            for vy in T.thread_binding(0, 2, thread = "vthread.y"):
                for vx in T.thread_binding(0, 2, thread = "vthread.x"):
                    for ty in T.thread_binding(0, 8, thread = "threadIdx.y"):
                        for tx in T.thread_binding(0, 8, thread = "threadIdx.x"):
                            for k0 in T.serial(0, 256):
                                # Per k0 step: copy an 8 x 64 slab of A
                                # (rows k0*8.., columns by*64..) into A_shared.
                                for i, j in T.grid(8, 64):
                                    with T.block("A_shared"):
                                        v0 = T.axis.S(2048, k0 * 8 + i)
                                        v1 = T.axis.S(2048, by * 64 + j)
                                        A_shared[v0, v1] = A[v0, v1]
                                for k1 in T.unroll(0, 8):
                                    # Per k1 step: 1 x 4 strips into the local buffers.
                                    for i, j in T.grid(1, 4):
                                        with T.block("A_shared_local"):
                                            v0 = T.axis.S(2048, k0 * 8 + k1 + i)
                                            v1 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + j)
                                            A_shared_local[v0, v1] = A_shared[v0, v1]
                                    for i, j in T.grid(1, 4):
                                        with T.block("B_shared_local"):
                                            v0 = T.axis.S(2048, k0 * 8 + k1 + i)
                                            v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j)
                                            B_shared_local[v0, v1] = B_shared[v0, v1]
                                    for _, i, j in T.grid(1, 4, 4):
                                        with T.block("C"):
                                            vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i)
                                            vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j)
                                            vk = T.axis.R(2048, k0 * 8 + k1)
                                            with T.init():
                                                C_local[vi, vj] = 0.0
                                            # Both operands are indexed with the reduction axis first.
                                            C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj]
                            # Write-back of the 4x4 patch after the k0 loop.
                            for i, j in T.grid(4, 4):
                                with T.block("C_local"):
                                    v0 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i)
                                    v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j)
                                    C[v0, v1] = C_local[v0, v1]
def test_tir_fma(A: T.handle, B: T.handle, C: T.handle, d: T.handle) -> None:
    # Multiply-add over four length-n buffers, each with its own symbolic
    # stride: d_1[i * stride_3] = A_1[i * stride] * B_1[i * stride_1] + C_1[i * stride_2].
    # function attr dict
    T.func_attr({"global_symbol": "test_fma", "tir.noalias": True})
    n = T.var("int32")
    stride = T.var("int32")
    stride_1 = T.var("int32")
    stride_2 = T.var("int32")
    stride_3 = T.var("int32")
    # All four bindings share the same layout options and differ only in stride.
    A_1 = T.match_buffer(A, [n], strides=[stride], elem_offset=0, align=128, offset_factor=1, buffer_type="auto")
    B_1 = T.match_buffer(B, [n], strides=[stride_1], elem_offset=0, align=128, offset_factor=1, buffer_type="auto")
    C_1 = T.match_buffer(C, [n], strides=[stride_2], elem_offset=0, align=128, offset_factor=1, buffer_type="auto")
    d_1 = T.match_buffer(d, [n], strides=[stride_3], elem_offset=0, align=128, offset_factor=1, buffer_type="auto")
    # body
    for idx in T.serial(0, n):
        d_1[idx * stride_3] = A_1[idx * stride] * B_1[idx * stride_1] + C_1[idx * stride_2]
def compacted_func(A: T.Buffer[(960, 770), "float32"], B: T.Buffer[(770, 2304), "float32"], C: T.Buffer[(960, 2304), "float32"]) -> None:
    # Tiled update C += A_shared * B_shared where the "shared" scope staging
    # buffers are allocated with shapes [128, 4] and [4, 128] and addressed
    # with coordinates that carry no bx / k_0 offset on the shared side.
    for bx in T.thread_binding(144, thread="blockIdx.x"):
        for vx in T.thread_binding(2, thread="vthread.x"):
            for tx_p in T.thread_binding(256, thread="threadIdx.x"):
                with T.block():
                    for k_0 in T.serial(193):
                        with T.block():
                            A_shared = T.alloc_buffer([128, 4], dtype="float32", scope="shared")
                            B_shared = T.alloc_buffer([4, 128], dtype="float32", scope="shared")
                            # Load of A: 256 threads x 3 lanes = 768 candidate
                            # elements; the predicate keeps those below 512
                            # (= 128 * 4) that are also inside A's 960 x 770 bounds.
                            for v_u in T.serial(1):
                                for tx in T.thread_binding(256, thread="threadIdx.x"):
                                    for vec in T.vectorized(3):
                                        with T.block("A_shared"):
                                            T.where(bx // 18 * 128 + (tx * 3 + vec) // 4 < 960 and k_0 * 4 + (tx * 3 + vec) % 4 < 770 and tx * 3 + vec < 512)
                                            A_shared[(tx * 3 + vec) // 4, (tx * 3 + vec) % 4] = A[bx // 18 * 128 + (tx * 3 + vec) // 4, k_0 * 4 + (tx * 3 + vec) % 4]
                            # Load of B: 256 threads x 4 lanes; the predicate keeps
                            # the first 512 (= 4 * 128) elements with a valid row.
                            for v_u in T.serial(1):
                                for tx in T.thread_binding(256, thread="threadIdx.x"):
                                    for vec in T.vectorized(4):
                                        with T.block("B_shared"):
                                            T.where(k_0 * 4 + tx // 32 < 770 and tx * 4 + vec < 512)
                                            B_shared[tx // 32, tx % 32 * 4 + vec] = B[k_0 * 4 + tx // 32, bx % 18 * 128 + tx % 32 * 4 + vec]
                            # Accumulate into C (global coordinates) from the
                            # shared buffers (buffer-relative coordinates).
                            for k_1, i_3, j_3, k_2, i_4, j_4 in T.grid(1, 8, 1, 4, 2, 2):
                                with T.block("update_update"):
                                    C[bx // 18 * 128 + tx_p // 32 * 16 + i_3 * 2 + i_4, bx % 18 * 128 + vx * 64 + tx_p % 32 * 2 + j_4] = C[bx // 18 * 128 + tx_p // 32 * 16 + i_3 * 2 + i_4, bx % 18 * 128 + vx * 64 + tx_p % 32 * 2 + j_4] + A_shared[tx_p // 32 * 16 + i_3 * 2 + i_4, k_2] * B_shared[k_2, vx * 64 + tx_p % 32 * 2 + j_4]
def func(A: T.Buffer[(960, 770), "float32"], B: T.Buffer[(770, 2304), "float32"], C: T.Buffer[(960, 2304), "float32"]) -> None:
    # Tiled update C += A_shared * B_shared where the "shared" scope staging
    # buffers are allocated at the full shapes of A and B ([960, 770] and
    # [770, 2304]) and addressed with the same global coordinates as A and B.
    for bx in T.thread_binding(144, thread="blockIdx.x"):
        for vx in T.thread_binding(2, thread="vthread.x"):
            for tx_p in T.thread_binding(256, thread="threadIdx.x"):
                with T.block():
                    for k_0 in T.serial(193):
                        with T.block():
                            A_shared = T.alloc_buffer([960, 770], dtype="float32", scope="shared")
                            B_shared = T.alloc_buffer([770, 2304], dtype="float32", scope="shared")
                            # Load of A: the predicate bounds the row (< 960), the
                            # column (< 770) and the flat lane index (< 512).
                            for _u in T.serial(1):
                                for tx in T.thread_binding(256, thread="threadIdx.x"):
                                    for vec in T.vectorized(3):
                                        with T.block("A_shared"):
                                            T.where(bx // 18 * 128 + ((_u * 256 + tx) * 3 + vec) // 4 < 960 and k_0 * 4 + ((_u * 256 + tx) * 3 + vec) % 4 < 770 and (_u * 256 + tx) * 3 + vec < 512)
                                            A_shared[bx // 18 * 128 + (_u * 768 + tx * 3 + vec) // 4, k_0 * 4 + (_u * 768 + tx * 3 + vec) % 4] = A[bx // 18 * 128 + (_u * 768 + tx * 3 + vec) // 4, k_0 * 4 + (_u * 768 + tx * 3 + vec) % 4]
                            # Load of B: the predicate bounds the row (< 770) and
                            # the flat lane index (< 512).
                            for _u in T.serial(1):
                                for tx in T.thread_binding(256, thread="threadIdx.x"):
                                    for vec in T.vectorized(4):
                                        with T.block("B_shared"):
                                            T.where(k_0 * 4 + ((_u * 256 + tx) * 4 + vec) // 128 < 770 and (_u * 256 + tx) * 4 + vec < 512)
                                            B_shared[k_0 * 4 + (_u * 1024 + tx * 4 + vec) // 128, bx % 18 * 128 + (_u * 1024 + tx * 4 + vec) % 128] = B[k_0 * 4 + (_u * 1024 + tx * 4 + vec) // 128, bx % 18 * 128 + (_u * 1024 + tx * 4 + vec) % 128]
                            # Accumulate into C; all three buffers use global coordinates.
                            for k_1, i_3, j_3, k_2, i_4, j_4 in T.grid(1, 8, 1, 4, 2, 2):
                                with T.block("update_update"):
                                    C[(((bx // 18 + 0) * 8 + tx_p // 32) * 8 + i_3) * 2 + i_4, ((bx % 18 * 2 + vx % 2) * 32 + tx_p % 32 + j_3) * 2 + j_4] = C[(((bx // 18 + 0) * 8 + tx_p // 32) * 8 + i_3) * 2 + i_4, ((bx % 18 * 2 + vx % 2) * 32 + tx_p % 32 + j_3) * 2 + j_4] + A_shared[(((bx // 18 + 0) * 8 + tx_p // 32) * 8 + i_3) * 2 + i_4, (k_0 + k_1) * 4 + k_2] * B_shared[(k_0 + k_1) * 4 + k_2, ((bx % 18 * 2 + vx % 2) * 32 + tx_p % 32 + j_3) * 2 + j_4]
def element_wise_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None:
    # Two element-wise stages over A (B = A * 2.0, C = A + 1.0), each launched
    # under its own "threadIdx.x" env thread (j0_0 and j1_0) inside one
    # "blockIdx.x" launch.
    j1_0 = T.env_thread("threadIdx.x")
    j0_0 = T.env_thread("threadIdx.x")
    i = T.env_thread("blockIdx.x")
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    T.launch_thread(i, 128)
    # First stage: the `with` form limits j0_0's launch to this loop.
    with T.launch_thread(j0_0, 4):
        for j0_1 in T.serial(0, 32):
            T.store(
                B.data,
                i * 128 + j0_0 * 32 + j0_1,
                T.load("float32", A.data, i * 128 + j0_0 * 32 + j0_1) * 2.0,
                True,
            )
    # Second stage: statement form, covering the remaining statements.
    T.launch_thread(j1_0, 4)
    for j1_1 in T.serial(0, 32):
        T.store(
            C.data,
            i * 128 + j1_0 * 32 + j1_1,
            T.load("float32", A.data, i * 128 + j1_0 * 32 + j1_1) + 1.0,
            True,
        )
def non_perfect_tiling_cache(a: T.handle, b: T.handle) -> None:
    # 3x3 max reduction over a 224x224 input, computed per 8x8 tile. Each tile
    # first copies a 10x10 window of X (one element of halo on every side) into
    # `cache`; out-of-range window positions are masked by the block predicate.
    X = T.match_buffer(a, [224, 224], dtype="float32")
    Y = T.match_buffer(b, [224, 224], dtype="float32")
    cache = T.alloc_buffer([224, 224], dtype="float32")
    for tile_h, tile_w in T.grid(28, 28):
        # Stage the window; the -1 shifts it to start one row/column early.
        for win_h, win_w in T.grid(10, 10):
            with T.block("cache"):
                h = T.axis.S(224, tile_h * 8 - 1 + win_h)
                w = T.axis.S(224, tile_w * 8 - 1 + win_w)
                T.where(
                    1 <= tile_h * 8 + win_h
                    and tile_h * 8 + win_h < 225
                    and 1 <= tile_w * 8 + win_w
                    and tile_w * 8 + win_w < 225
                )
                cache[h, w] = X[h, w]
        # Reduce each 3x3 neighbourhood; positions outside the input read as 0.0.
        for out_h, out_w, red_h, red_w in T.grid(8, 8, 3, 3):
            with T.block("compute"):
                h = T.axis.S(224, tile_h * 8 + out_h)
                w = T.axis.S(224, tile_w * 8 + out_w)
                kh = T.axis.R(3, red_h)
                kw = T.axis.R(3, red_w)
                with T.init():
                    Y[h, w] = 0.0
                Y[h, w] = T.max(
                    Y[h, w],
                    T.if_then_else(
                        T.likely(1 <= h + kh, dtype="bool")
                        and T.likely(h + kh < 225, dtype="bool")
                        and T.likely(1 <= w + kw, dtype="bool")
                        and T.likely(w + kw < 225, dtype="bool"),
                        cache[h + kh - 1, w + kw - 1],
                        0.0,
                        dtype="float32",
                    ),
                )
def elementwise_not_affine_fused(a: T.handle, b: T.handle) -> None:
    # Copy A to B (127 x 128) with the row-within-chunk and column loops fused
    # into one loop whose extent depends on the outer index through T.min.
    A = T.match_buffer(a, [127, 128])
    B = T.match_buffer(b, [127, 128])
    for chunk in T.serial(0, 4):
        for fused in T.serial(0, T.min(31, 126 - chunk * 32) * 128 + 128):
            with T.block("B"):
                # Row = chunk base + (fused // 128) wrapped by the chunk's row count.
                vi = T.axis.spatial(
                    127,
                    chunk * 32 + T.floormod(T.floordiv(fused, 128), T.min(31, 126 - chunk * 32) + 1),
                )
                # Column = fused % 128.
                vj = T.axis.spatial(128, T.floormod(fused, 128))
                T.reads([A[vi, vj]])
                T.writes([B[vi, vj]])
                B[vi, vj] = A[vi, vj]
def tir_matmul(
    A: T.Buffer[(16384,), "float32"],
    B: T.Buffer[(16384,), "float32"],
    C: T.Buffer[(16384,), "float32"],
) -> None:
    # Matmul on flat 16384-element buffers declared as pre-flattened 128x128:
    # C[x, y] = sum_k A[x, k] * B[y, k], with B indexed row-by-row.
    # function attr dict
    T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
    T.preflattened_buffer(A, [128, 128], dtype="float32", data=A.data)
    T.preflattened_buffer(B, [128, 128], dtype="float32", data=B.data)
    T.preflattened_buffer(C, [128, 128], dtype="float32", data=C.data)
    # body
    for row in T.serial(0, 128):
        for col in T.serial(0, 128):
            # Zero the output element before accumulating into it.
            C[row * 128 + col] = T.float32(0)
            for red in T.serial(0, 128):
                C[row * 128 + col] = C[row * 128 + col] + A[row * 128 + red] * B[col * 128 + red]
def main(a: T.handle, b: T.handle) -> None:
    # Reduction of A[i, j, k] into B[i] written through sub-buffer views
    # created with T.match_buffer inside the block.
    A = T.match_buffer(a, [64, 64, 64])
    B = T.match_buffer(b, [64])
    for i0, j0 in T.grid(64, 64):
        # k0 is declared as T.serial(32, 64), i.e. it does not start at 0.
        for k0 in T.serial(32, 64):
            with T.block():
                i, j, k = T.axis.remap("SRR", [i0, j0, k0])
                T.reads(A[i, j, k])
                T.writes(B[i])
                # BB: zero-dimensional view of B[i]; AA: 64x64 view of A[i, :, :].
                BB = T.match_buffer(B[i], ())
                AA = T.match_buffer(A[i, 0:64, 0:64], (64, 64))
                # Reset the accumulator when j == 0 and k == 32, which is the
                # first (j, k) pair if k0 starts at 32.
                if (j == 0) and (k == 32):
                    BB[()] = T.float32(0)
                BB[()] += AA[j, k]
def element_wise_invalid_annotation(a: T.handle, c: T.handle) -> None:
    # Two-stage element-wise fixture (B = A * 2, C = B + 1) whose producer block
    # carries a "buffer_dim_align" attribute with the value [0].
    # NOTE(review): the function name marks this annotation as the invalid
    # case; confirm what the consuming test expects to reject.
    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
    # body
    with T.block([], "root"):
        T.reads([])
        T.writes([])
        B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1)
        for i0 in T.serial(0, 128):
            # Producer: fills row i0 of B.
            for ax1 in T.serial(0, 128):
                with T.block([128, 128], "B") as [vi, vj]:
                    T.block_attr({"buffer_dim_align": [0]})
                    T.bind(vi, i0)
                    T.bind(vj, ax1)
                    T.reads([A[vi, vj]])
                    T.writes([B[vi, vj]])
                    B[vi, vj] = (A[vi, vj]*T.float32(2))
            # Consumer: reads row i0 of B under the same i0 iteration.
            for i1 in T.serial(0, 128):
                with T.block([128, 128], "C") as [vi_1, vj_1]:
                    T.bind(vi_1, i0)
                    T.bind(vj_1, i1)
                    T.reads([B[vi_1, vj_1]])
                    T.writes([C[vi_1, vj_1]])
                    C[vi_1, vj_1] = (B[vi_1, vj_1] + T.float32(1))
def unified_element_wise_vthread_x(a: T.handle, b: T.handle) -> None:
    # B = A * 2.0 on flat indices, using one "vthread.x" env thread and one
    # "threadIdx.x" env thread.
    vthread_x = T.env_thread("vthread.x")
    thread_x = T.env_thread("threadIdx.x")
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    # vthread_x is launched twice (before and after thread_x), both with extent 2.
    T.launch_thread(vthread_x, 2)
    T.launch_thread(thread_x, 64)
    T.launch_thread(vthread_x, 2)
    for j_1 in T.serial(0, 64):
        # vthread_x carries a combined stride of 8256 (= 8192 + 64) in the index.
        T.store(
            B.data,
            vthread_x * 8256 + thread_x * 128 + j_1,
            T.load("float32", A.data, vthread_x * 8256 + thread_x * 128 + j_1) * 2.0,
            True,
        )
def access_opaque_ptr_then_elemwise_inline(a: T.handle, b: T.handle) -> None:
    # An opaque block touches A and A_cache only through access_ptr, followed
    # by an element-wise block computing B = A_cache * 2.0 + 1.0 on 512 elements.
    A = T.match_buffer(a, [1024], dtype="float32")
    B = T.match_buffer(b, [1024], dtype="float32")
    A_cache = T.alloc_buffer([1024], dtype="float32")
    with T.block("opaque"):
        # The explicitly annotated partial regions (first 512 elements) of the
        # opaque accesses are expected to be kept as written.
        T.reads(A[0:512])
        T.writes([A_cache[0:512]])
        T.evaluate(A.access_ptr("r", extent=512))
        T.evaluate(A_cache.access_ptr("w", extent=512))
    for idx in T.serial(0, 512):
        with T.block("B"):
            v = T.axis.S(512, idx)
            T.reads([A_cache[v]])
            T.writes([B[v]])
            B[v] = A_cache[v] * 2.0 + 1.0
def rowsum_blockized(a: T.handle, b: T.handle) -> None:
    # Row sum B[io, ii] = sum_k A[io, ii, k] with the 128-long reduction split
    # into 16 outer steps; each (io, ko) pair is an outer block wrapping a
    # 4 x 8 inner block nest.
    B = T.match_buffer(b, [32, 4])
    A = T.match_buffer(a, [32, 4, 128])
    for outer_s, outer_r in T.grid(32, 16):
        with T.block("blockized_B"):
            io = T.axis.S(32, outer_s)
            ko = T.axis.R(16, outer_r)
            # Outer-block init zeroes the four outputs belonging to `io`.
            with T.init():
                for init_s in T.serial(0, 4):
                    with T.block("B_init"):
                        ii_init = T.axis.S(4, init_s)
                        B[io, ii_init] = 0.0
            for inner_s in T.serial(0, 4):
                for inner_r in T.serial(0, 8):
                    with T.block("B"):
                        ii = T.axis.S(4, inner_s)
                        k = T.axis.R(128, ko * 8 + inner_r)
                        B[io, ii] = B[io, ii] + A[io, ii, k]
def main(
    A: T.Buffer[(16384, ), "float32"],
    B: T.Buffer[(16384, ), "float32"],
    C: T.Buffer[(16384, ), "float32"],
) -> None:
    # Matmul over flat 16384-element buffers declared as pre-flattened 128x128.
    # Both A and B are read along the reduction index within a row:
    # C[x, y] = sum_k A[x, k] * B[y, k].
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    T.preflattened_buffer(A, [128, 128], data=A.data)
    T.preflattened_buffer(B, [128, 128], data=B.data)
    T.preflattened_buffer(C, [128, 128], data=C.data)
    # body
    for out_r in T.serial(0, 128):
        for out_c in T.serial(0, 128):
            # Clear the destination element, then accumulate 128 products.
            C[out_r * 128 + out_c] = 0.0
            for acc in T.serial(0, 128):
                C[out_r * 128 + out_c] = C[out_r * 128 + out_c] + A[out_r * 128 + acc] * B[out_c * 128 + acc]
def after_rowsum_blockize(
    A: T.Buffer[(128, 128), "float32"],
    B: T.Buffer[(128, ), "float32"],
) -> None:
    # Row sum B[vi] = sum_vk A[vi, vk] wrapped in a single outer block whose
    # two iteration axes both have extent 1 and are bound to 0.
    with T.block("blockized_B"):
        vko = T.axis.reduce(1, 0)
        vio = T.axis.spatial(1, 0)
        # Outer-block init zeroes all 128 outputs.
        with T.init():
            for init_idx in T.serial(0, 128):
                with T.block("B_init"):
                    vi_init = T.axis.spatial(128, init_idx)
                    B[vi_init] = T.float32(0)
        # The reduction loop is the outer one, the spatial loop the inner one.
        for red_idx in T.serial(0, 128):
            for spa_idx in T.serial(0, 128):
                with T.block("B"):
                    vk = T.axis.reduce(128, red_idx)
                    vi = T.axis.spatial(128, spa_idx)
                    B[vi] = B[vi] + A[vi, vk]
def main(
    T_reshape: T.Buffer[(1, 12, 384, 384), "float32"],
    placeholder_1: T.Buffer[(T.int64(1), T.int64(12), T.int64(384), 384), "bool"],
    T_where: T.Buffer[(T.int64(1), T.int64(12), T.int64(384), 384), "float32"]
) -> None:
    # Element-wise select: T_where = -1e9 where placeholder_1 is non-zero,
    # otherwise T_reshape. The fused loop index is computed in int64; only the
    # last axis (ax3) is cast back to int32.
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    for i0_i1_i2_i3_fused_1 in T.thread_binding(T.int64(256), thread="blockIdx.x"):
        for i0_i1_i2_i3_fused_2 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
            for i0_i1_i2_i3_fused_0 in T.serial(T.int64(7)):
                with T.block("T_where"):
                    # Flat index = (fused_0 * 256 + fused_1) * 1024 + fused_2,
                    # decomposed with 1769472 = 12 * 384 * 384 and 147456 = 384 * 384.
                    ax0 = T.axis.spatial(T.int64(1), T.int64(0))
                    ax1 = T.axis.spatial(
                        T.int64(12),
                        ((i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) * T.int64(1024) + i0_i1_i2_i3_fused_2) % T.int64(1769472) // T.int64(147456))
                    ax2 = T.axis.spatial(
                        T.int64(384),
                        ((i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) * T.int64(1024) + i0_i1_i2_i3_fused_2) % T.int64(147456) // T.int64(384))
                    ax3 = T.axis.spatial(
                        384,
                        T.cast(((i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) * T.int64(1024) + i0_i1_i2_i3_fused_2) % T.int64(384), "int32"))
                    # 7 * 256 * 1024 iterations exceed 1769472 elements, so the
                    # tail is masked off.
                    T.where((i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) * T.int64(1024) + i0_i1_i2_i3_fused_2 < T.int64(1769472))
                    T.reads(placeholder_1[ax0, ax1, ax2, ax3], T_reshape[ax0, ax1, ax2, ax3])
                    T.writes(T_where[ax0, ax1, ax2, ax3])
                    T_where[ax0, ax1, ax2, ax3] = T.Select(
                        T.cast(placeholder_1[ax0, ax1, ax2, ax3], "int32") != 0,
                        T.float32(-1000000000),
                        T_reshape[ax0, ax1, ax2, ax3])
def rowsum_blockized(a: T.handle, b: T.handle) -> None:
    # Row sum B[io, ii] = sum_k A[io, ii, k] with the 128-long reduction split
    # 16 x 8; written with the signature-style T.block([...], name) / T.bind
    # syntax.
    B = T.match_buffer(b, [32, 4])
    A = T.match_buffer(a, [32, 4, 128])
    for i0, i2_0 in T.grid(32, 16):
        # Outer block: spatial axis io (32) and reduction axis ko (16).
        with T.block([32, T.reduce_axis(0, 16)], "blockized_B") as [io, ko]:
            T.bind(io, i0)
            T.bind(ko, i2_0)
            # Init zeroes the four outputs belonging to io.
            with T.init():
                for i1 in T.serial(0, 4):
                    with T.block([4], "B_init") as [ii_init]:
                        T.bind(ii_init, i1)
                        B[io, ii_init] = 0.0
            for i1_1, i2_1 in T.grid(4, 8):
                with T.block([4, T.reduce_axis(0, 128)], "B") as [ii, k]:
                    T.bind(ii, i1_1)
                    # Full reduction index rebuilt from the outer/inner parts.
                    T.bind(k, ko * 8 + i2_1)
                    B[io, ii] = B[io, ii] + A[io, ii, k]
def opaque_access_func() -> None:
    # Eight iterations, each handing a 128-element window of B and of A to the
    # extern function "test"; the touched windows are declared explicitly as
    # the block's read and write regions.
    A = T.alloc_buffer([1024])
    B = T.alloc_buffer([1024])
    for chunk in T.serial(0, 8):
        with T.block():
            vc = T.axis.spatial(8, chunk)
            T.reads([A[vc * 128:vc * 128 + 128]])
            T.writes([B[vc * 128:vc * 128 + 128]])
            T.evaluate(
                T.call_extern(
                    "test",
                    B.data,
                    vc * 128,
                    128,
                    A.data,
                    vc * 128,
                    128,
                    dtype="float32",
                )
            )
def exp_exp_opaque_access_with_tvm_access_ptr_inlined(
    lookup_table: T.Buffer[(1024,), "int8"],
    x: T.Buffer[(16,), "float16"],
    compute: T.Buffer[(16,), "float16"],
) -> None:
    # Single block computing exp(exp(x)) per element, where the outer T.exp
    # call also receives a read-mode access pointer to lookup_table.
    for idx in T.serial(0, 16):
        with T.block("compute_1"):
            v = T.axis.S(16, idx)
            # The lookup table is passed via access_ptr("r"), so it is listed
            # under reads only and must not appear in the write region.
            T.reads(x[v], lookup_table[0:1024])
            T.writes(compute[v])
            inner = T.exp(x[v], dtype="float16")
            compute[v] = T.exp(inner, lookup_table.access_ptr("r"), dtype="float16")
def element_wise_vthread_x(a: T.handle, b: T.handle) -> None:
    # B = A * 2.0 on flat indices, using two distinct env threads (i_0 and j_0)
    # that are both tagged "vthread.x", plus one "threadIdx.x" env thread.
    i_0 = T.env_thread("vthread.x")
    i_1 = T.env_thread("threadIdx.x")
    j_0 = T.env_thread("vthread.x")
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    T.launch_thread(i_0, 2)
    T.launch_thread(i_1, 64)
    T.launch_thread(j_0, 2)
    for j_1 in T.serial(0, 64):
        # Row = i_0 * 64 + i_1 (stride 128), column = j_0 * 64 + j_1.
        T.store(
            B.data,
            i_0 * 8192 + i_1 * 128 + j_0 * 64 + j_1,
            T.load("float32", A.data, i_0 * 8192 + i_1 * 128 + j_0 * 64 + j_1) * 2.0,
            True,
        )
def transformed_square_sum_square_root_factor_one_2(a: T.handle, d: T.handle) -> None:
    # D[b] = sqrt(sum_{i, j} A[b, i, j] ** 2). The two 256-long reduction axes
    # are driven by one fused loop of extent 65536 plus an extent-1 outer loop.
    A = T.match_buffer(a, [16, 256, 256])
    D = T.match_buffer(d, [16])
    C = T.alloc_buffer([16])
    for batch, fused_outer, fused_inner in T.grid(16, 1, 65536):
        with T.block("C"):
            b = T.axis.spatial(16, batch)
            # Recover the two reduction coordinates from the fused index.
            i = T.axis.reduce(256, T.floordiv(fused_inner, 256))
            j = T.axis.reduce(256, T.floormod(fused_inner, 256))
            with T.init():
                C[b] = 0.0
            C[b] = C[b] + (A[b, i, j] * A[b, i, j])
    for out in T.serial(0, 16):
        with T.block("D"):
            b_1 = T.axis.spatial(16, out)
            D[b_1] = T.sqrt(C[b_1], dtype="float32")
def simple_compute_missing_annotation(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]):
    # Two-stage computation (B = A * 2, C = B + 1) whose loop carries only a
    # "software_pipeline_stage" annotation; the annotations dict has no
    # "software_pipeline_order" key.
    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
        for i in T.serial(0, 16, annotations={"software_pipeline_stage": [0, 1]}):
            with T.block():
                T.reads(A[tx, i])
                T.writes(C[tx, i])
                B = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
                # First child block: producer of B.
                with T.block():
                    T.reads(A[tx, i])
                    T.writes(B[tx, 0])
                    B[tx, 0] = A[tx, i] * T.float32(2)
                # Second child block: consumer of B.
                with T.block():
                    T.reads(B[tx, 0])
                    T.writes(C[tx, i])
                    C[tx, i] = B[tx, 0] + T.float32(1)
def main(a: T.handle, b: T.handle, c: T.handle) -> None:
    # 128x128 matmul written against the raw data pointers: stores go through
    # C.data[...] and loads through T.load. B is read as B[y, k], so the result
    # is C[x, y] = sum_k A[x, k] * B[y, k].
    # function attr dict
    T.func_attr({
        "global_symbol": "main",
        "from_legacy_te_schedule": True,
        "tir.noalias": True
    })
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    # body
    for x, y in T.grid(128, 128):
        # Zero the output element before accumulating.
        C.data[x * 128 + y] = 0.0
        for k in T.serial(0, 128):
            C.data[x * 128 + y] = T.load(
                "float32", C.data, x * 128 + y) + T.load("float32", A.data, x * 128 + k) * T.load(
                    "float32", B.data, y * 128 + k)
def concat_func_3(
    placeholder: T.Buffer[(50176,), "int8"],
    placeholder_1: T.Buffer[(25088,), "int8"],
    placeholder_2: T.Buffer[(25088,), "int8"],
    T_concat: T.Buffer[(100352,), "int8"],
) -> None:
    # Concatenation of three int8 inputs (64 + 32 + 32 channels of 28x28) into
    # a 128-channel output, on flattened buffers. The source is chosen by three
    # separate `if` statements on the channel index i1.
    T.preflattened_buffer(placeholder, (1, 64, 28, 28), "int8", data=placeholder.data)
    T.preflattened_buffer(placeholder_1, (1, 32, 28, 28), "int8", data=placeholder_1.data)
    T.preflattened_buffer(placeholder_2, (1, 32, 28, 28), "int8", data=placeholder_2.data)
    T.preflattened_buffer(T_concat, (1, 128, 28, 28), "int8", data=T_concat.data)
    # The channel loop carries a "pragma_loop_partition_hint" annotation.
    for i1 in T.serial(128, annotations={"pragma_loop_partition_hint": 1}):
        for i2, i3 in T.grid(28, 28):
            # Channels 96..127: offset 75264 = 96 * 784.
            if 96 <= i1:
                T_concat[i1 * 784 + i2 * 28 + i3] = placeholder_2[i1 * 784 + i2 * 28 + i3 - 75264]
            # Channels 64..95: offset 50176 = 64 * 784.
            if 64 <= i1 and i1 < 96:
                T_concat[i1 * 784 + i2 * 28 + i3] = placeholder_1[i1 * 784 + i2 * 28 + i3 - 50176]
            # Channels 0..63: copied at the same flat index.
            if i1 < 64:
                T_concat[i1 * 784 + i2 * 28 + i3] = placeholder[i1 * 784 + i2 * 28 + i3]
def block_predicate_cache_write_output_buf() -> None:
    # Producer/consumer pair over 120 elements iterated as a 16 x 8 grid (128
    # iterations) with a block predicate masking the last 8; the consumer
    # writes to a "shared" scope buffer that a final block copies into B.
    A = T.alloc_buffer([120], dtype="float32")
    B = T.alloc_buffer([120], dtype="float32")
    B_shared = T.alloc_buffer([120], dtype="float32", scope="shared")
    for outer, inner in T.grid(16, 8):
        with T.block("producer"):
            ax = T.axis.S(120, outer * 8 + inner)
            T.where(outer * 8 + inner < 120)
            A[ax] = T.float32(0)
    for outer, inner in T.grid(16, 8):
        with T.block("consumer"):
            ax = T.axis.S(120, outer * 8 + inner)
            T.where(outer * 8 + inner < 120)
            B_shared[ax] = A[ax] + T.float32(1)
    # Copy-out needs no predicate: the loop extent equals the buffer size.
    for flat in T.serial(0, 120):
        with T.block("B_shared"):
            v0 = T.axis.S(120, flat)
            B[v0] = B_shared[v0]
def expected(A: T.Buffer[(4, 4), "float32"]):
    # Loop nest in which the `i < 2` test sits directly under the i loop and
    # the `j < 3` test directly under each j loop; the k loop is innermost in
    # every branch. Values written: 0.0 / 2.0 for i < 2, 1.0 / 3.0 otherwise.
    for i in T.serial(4):
        if i < 2:
            for j in T.serial(4):
                if j < 3:
                    for k in T.serial(4):
                        A[i, j] = 0.0
                else:
                    for k in T.serial(4):
                        A[i, j] = 2.0
        else:
            for j in T.serial(4):
                if j < 3:
                    for k in T.serial(4):
                        A[i, j] = 1.0
                else:
                    for k in T.serial(4):
                        A[i, j] = 3.0
def func_3(
    C: T.Buffer[(1,), "float32"],
    A: T.Buffer[(16,), "float32"],
    D: T.Buffer[(2,), "float32"],
    E: T.Buffer[(16,), "float32"],
    F: T.Buffer[(16,), "float32"],
):
    # Nested blocks without explicit T.reads / T.writes. The second inner
    # block both reads and writes several buffers (E is written then read,
    # A is read then written, C is read and written).
    for i in T.serial(
        0,
        16,
    ):
        with T.block():
            # Scratch buffer local to the outer block.
            B = T.alloc_buffer((1,), dtype="float32")
            with T.block():
                B[0] = A[i] * T.float32(2)
            with T.block():
                E[i] = A[i]
                F[i] = E[i] + 1.0
                C[0] = C[0] + A[i] + B[0] + T.float32(1) + D[0]
                A[i] = B[0] + T.float32(1) + D[1]
def dag_interleaving(
    A: T.Buffer[(16, 16), "float32"],
    B: T.Buffer[(16, 16), "float32"],
    C: T.Buffer[(16, 16), "float32"],
) -> None:
    # Five child blocks forming two independent chains (A -> AS -> AL and
    # B -> BS -> BL) that join in the final multiply. The loop annotations give
    # a stage list [0, 0, 0, 0, 1] and an order list [0, 2, 1, 3, 4].
    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
        for i in T.serial(
            0,
            16,
            annotations={
                "software_pipeline_stage": [0, 0, 0, 0, 1],
                "software_pipeline_order": [0, 2, 1, 3, 4],
            },
        ):
            with T.block():
                # NOTE(review): only A is declared as read here although a child
                # block reads B[tx, i]; confirm this is intentional.
                T.reads(A[tx, i])
                T.writes(C[tx, i])
                AS = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
                BS = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
                AL = T.alloc_buffer((1, 1), dtype="float32", scope="local")
                BL = T.alloc_buffer((1, 1), dtype="float32", scope="local")
                # Child 0: A -> AS.
                with T.block():
                    T.reads(A[tx, i])
                    T.writes(AS[tx, 0])
                    AS[tx, 0] = A[tx, i] * T.float32(2)
                # Child 1: AS -> AL.
                with T.block():
                    T.reads(AS[tx, 0])
                    T.writes(AL[0, 0])
                    AL[0, 0] = AS[tx, 0]
                # Child 2: B -> BS.
                with T.block():
                    T.reads(B[tx, i])
                    T.writes(BS[tx, 0])
                    BS[tx, 0] = B[tx, i] + T.float32(2)
                # Child 3: BS -> BL.
                with T.block():
                    T.reads(BS[tx, 0])
                    T.writes(BL[0, 0])
                    BL[0, 0] = BS[tx, 0]
                # Child 4: join of the two chains.
                with T.block():
                    T.reads(AL[0, 0], BL[0, 0])
                    T.writes(C[tx, i])
                    C[tx, i] = AL[0, 0] * BL[0, 0]
def matmul_decompose4(a: T.handle, b: T.handle, c: T.handle) -> None:
    # 128x128 matmul (C[vi, vj] += A[vi, vk] * B[vj, vk]) with the init step in
    # its own "update_init" block, separate from "update_update". The reduction
    # loop is split 19 x 7 = 133 > 128, so the update block is predicated.
    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
    B = T.match_buffer(b, [128, 128], elem_offset=0, align=128, offset_factor=1)
    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
    # body
    with T.block("root"):
        T.reads([])
        T.writes([])
        for i0_0 in T.serial(0, 16):
            # Init loops sit under i0_0: they zero an 8 x 128 slab of C.
            for i0_1_init, i1_init in T.grid(8, 128):
                with T.block("update_init"):
                    vi_init = T.axis.S(128, i0_0 * 8 + i0_1_init)
                    vj_init = T.axis.S(128, i1_init)
                    C[vi_init, vj_init] = T.float32(0)
            for i0_1, i1, i2_0, i2_1 in T.grid(8, 128, 19, 7):
                with T.block("update_update"):
                    # The predicate is written before the axis declarations.
                    T.where((((i2_0 * 7) + i2_1) < 128))
                    vi = T.axis.S(128, i0_0 * 8 + i0_1)
                    vj = T.axis.S(128, i1)
                    vk = T.axis.R(128, i2_0 * 7 + i2_1)
                    C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])