def annotated_tensorized_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
    """128x128x128 matmul fixture whose init block carries a block annotation
    and whose update block is tensorized into ``tvm_mma_sync`` calls over
    16x16 sub-tiles matched out of A, B and C."""
    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
    B = T.match_buffer(b, [128, 128], elem_offset=0, align=128, offset_factor=1)
    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
    for i_outer, j_outer in T.grid(8, 8):
        for i_inner_init, j_inner_init in T.grid(16, 16):
            with T.block("init"):
                vi_init = T.axis.S(128, ((i_outer * 16) + i_inner_init))
                vj_init = T.axis.S(128, ((j_outer * 16) + j_inner_init))
                # Annotation that lowering / round-tripping must preserve.
                T.block_attr({"test_annotation": True})
                C[vi_init, vj_init] = T.float32(0)
        for k_outer in T.grid(8):
            with T.block("update"):
                vi, vj, vk = T.axis.remap("SSR", [i_outer, j_outer, k_outer])
                # Explicit read/write regions: each block instance touches one
                # 16x16 tile of each operand.
                T.reads([
                    C[vi * 16:vi * 16 + 16, vj * 16:vj * 16 + 16],
                    A[vi * 16:vi * 16 + 16, vk * 16:vk * 16 + 16],
                    B[vj * 16:vj * 16 + 16, vk * 16:vk * 16 + 16],
                ])
                T.writes(C[vi * 16:vi * 16 + 16, vj * 16:vj * 16 + 16])
                # Symbolic offsets to be bound when the sub-buffers are matched.
                A_elem_offset = T.var("int32")
                B_elem_offset = T.var("int32")
                C_elem_offset = T.var("int32")
                A_sub = T.match_buffer(
                    A[vi * 16:vi * 16 + 16, vk * 16:vk * 16 + 16],
                    [16, 16],
                    elem_offset=A_elem_offset,
                )
                B_sub = T.match_buffer(
                    B[vj * 16:vj * 16 + 16, vk * 16:vk * 16 + 16],
                    [16, 16],
                    elem_offset=B_elem_offset,
                )
                C_sub = T.match_buffer(
                    C[vi * 16:vi * 16 + 16, vj * 16:vj * 16 + 16],
                    [16, 16],
                    elem_offset=C_elem_offset,
                )
                # elem_offset / 256 converts the flat element offset into a
                # 16x16 (=256-element) fragment index.
                T.evaluate(
                    T.tvm_mma_sync(
                        C_sub.data,
                        T.floordiv(C_sub.elem_offset, 256),
                        A_sub.data,
                        T.floordiv(A_sub.elem_offset, 256),
                        B_sub.data,
                        T.floordiv(B_sub.elem_offset, 256),
                        C_sub.data,
                        T.floordiv(C_sub.elem_offset, 256),
                        dtype="handle",
                    ))
def mma_store_impl(a: T.handle, c: T.handle) -> None:
    """Intrinsic implementation: store a warp-scope MMA accumulator buffer
    into a strided global buffer via ``T.mma_store``.

    ``WARP_SIZE``, ``local_size``, ``M_DIM``, ``N_DIM`` and ``dtype`` are
    captured from the enclosing test module.
    """
    # Symbolic strides of the destination, bound by match_buffer below.
    s0 = T.var("int32")
    s1 = T.var("int32")
    C_warp = T.match_buffer(a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", offset_factor=1)
    C = T.match_buffer(c, [M_DIM, N_DIM], dtype=dtype, scope="global", offset_factor=1, strides=[s0, s1])
    with T.block("root"):
        T.reads(C_warp[0:WARP_SIZE, 0:local_size])
        T.writes(C[0:M_DIM, 0:N_DIM])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, WARP_SIZE)
        T.evaluate(
            T.mma_store(
                M_DIM,
                N_DIM,
                C.access_ptr("w"),
                C_warp.data,
                C_warp.elem_offset,
                s0,  # leading stride of the destination
                dtype=dtype,
            ))
def symbolic_match(a: T.handle, b: T.handle, n: T.int32, m: T.int32) -> None:
    """Fixture exercising ``match_buffer`` on regions whose extents are
    symbolic expressions of the parameters ``n`` and ``m``."""
    A = T.match_buffer(a, (n * m, m))
    B = T.match_buffer(b, (n * 2, m * 4))
    for i in range(0, n):
        with T.block():
            T.reads([])
            # NOTE(review): the declared write region of A uses extent ``n``
            # (i*m .. i*m+n) while sub_A below matches extent ``m``
            # (i*m .. i*m+m) — confirm this asymmetry is intended by the test.
            T.writes([A[i * m:i * m + n, 0:m], B[i * n:i * n + 2, 0:m * 4]])
            # Symbolic strides for the matched sub-region of B.
            Bs_0 = T.var("int32")
            Bs_1 = T.var("int32")
            sub_A = T.match_buffer(A[i * m:i * m + m, 0:m], (m, m), offset_factor=1)
            sub_B = T.match_buffer(B[i * n:i * n + 2, 0:m * 4], (2, m * 4), strides=[Bs_0, Bs_1], offset_factor=1)
            for ii, jj in T.grid(m, m):
                sub_A[ii, jj] = 1
            for j in range(0, 4):
                # Opaque access through the matched buffer's metadata.
                T.evaluate(
                    T.intrin_test(
                        sub_B.data,
                        sub_B.elem_offset,
                        sub_B.strides[0],
                        sub_B.strides[1],
                        sub_B.shape[0],
                        sub_B.shape[1],
                        dtype="handle",
                    ))
def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None:
    """Intrinsic implementation: load matrix fragments from shared memory into
    warp-scope registers via ``T.ptx_ldmatrix``.

    ``shmem_shape``, ``dtype``, ``shared_scope``, ``WARP_SIZE``,
    ``local_size``, ``row_dim``, ``col_dim``, ``ldmatrix_col_major``,
    ``lift`` and ``shared_offset`` are captured from the enclosing module.
    """
    # Symbolic strides of the shared-memory source.
    s0 = T.var("int32")
    s1 = T.var("int32")
    shared = T.match_buffer(
        shared_handle,
        shmem_shape,
        dtype,
        align=128,
        offset_factor=16,
        scope=shared_scope,
        strides=[s0, s1],
    )
    warp = T.match_buffer(warp_handle, (WARP_SIZE, local_size), dtype, align=128, offset_factor=16, scope="warp")
    with T.block("root"):
        T.reads(shared[0:row_dim, 0:col_dim])
        T.writes(warp[0:WARP_SIZE, 0:local_size])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, WARP_SIZE)
        T.evaluate(
            T.ptx_ldmatrix(
                ldmatrix_col_major,
                4,  # Always load 4 matrices
                ".b16",
                warp.data,
                # Each lane owns ``local_size`` elements of the warp buffer.
                warp.elem_offset + lift(local_size) * tx,
                shared.access_ptr("r"),
                shared_offset(tx, s0),
                dtype=dtype,
            ))
def recursive_match(a: T.handle, b: T.handle) -> None:
    """Fixture exercising nested (recursive) ``match_buffer``: a 16x16 tile is
    matched out of A/B, and 4x4 tiles are then matched out of those matches."""
    A = T.match_buffer(a, (64, 64, 64))
    B = T.match_buffer(b, (64, 64, 64))
    for i, j, k in T.grid(64, 4, 4):
        with T.block([]):
            T.reads([])
            T.writes(
                [
                    A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16],
                    B[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16],
                ]
            )
            # Symbolic strides for the outer matched tile of A.
            As_0 = T.var("int32")
            As_1 = T.var("int32")
            sub_A = T.match_buffer(
                A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16],
                (16, 16),
                strides=[As_0, As_1],
                offset_factor=1,
            )
            # B's tile is matched without explicit strides.
            sub_B = T.match_buffer(
                B[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16],
                (16, 16),
                offset_factor=1,
            )
            for jj, kk in T.grid(4, 4):
                with T.block([]):
                    T.reads([])
                    T.writes(
                        [
                            sub_A[jj * 4 : jj * 4 + 4, kk * 4 : kk * 4 + 4],
                            sub_B[jj * 4 : jj * 4 + 4, kk * 4 : kk * 4 + 4],
                        ]
                    )
                    # Inner (second-level) match on the already-matched tile.
                    Ass_0 = T.var("int32")
                    Ass_1 = T.var("int32")
                    sub_sub_A = T.match_buffer(
                        sub_A[jj * 4 : jj * 4 + 4, kk * 4 : kk * 4 + 4],
                        (4, 4),
                        strides=[Ass_0, Ass_1],
                        offset_factor=1,
                    )
                    sub_sub_B = T.match_buffer(
                        sub_B[jj * 4 : jj * 4 + 4, kk * 4 : kk * 4 + 4],
                        (4, 4),
                        offset_factor=1,
                    )
                    # Opaque use of the innermost match's metadata.
                    T.evaluate(
                        T.intrin_test(
                            sub_sub_A.data,
                            sub_sub_A.elem_offset,
                            sub_sub_A.strides[0],
                            sub_sub_A.strides[1],
                            sub_sub_A.shape[0],
                            sub_sub_A.shape[1],
                            dtype="handle",
                        )
                    )
                    for jjj, kkk in T.grid(4, 4):
                        sub_sub_B[jjj, kkk] = 1
def wmma_load_impl(a: T.handle, c: T.handle) -> None:
    """Intrinsic implementation: load a WMMA fragment from a strided shared
    buffer via ``T.tvm_load_matrix_sync``.

    ``m_dim``, ``n_dim``, ``k_dim``, ``dtype``, ``shared_scope``,
    ``wmma_fragment_scope``, ``layout`` and ``get_wmma_fragment_index``
    are captured from the enclosing module.
    """
    # Symbolic strides of the source buffer.
    s1 = T.var("int32")
    s0 = T.var("int32")
    A = T.match_buffer(
        a,
        (m_dim, n_dim),
        dtype,
        align=128,
        offset_factor=16,
        scope=shared_scope,
        strides=[s1, s0],
    )
    C = T.match_buffer(c, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope=wmma_fragment_scope)
    with T.block("root"):
        T.reads(A[0:m_dim, 0:n_dim])
        T.writes(C[0:m_dim, 0:n_dim])
        T.evaluate(
            T.tvm_load_matrix_sync(
                C.data,
                m_dim,
                n_dim,
                k_dim,
                get_wmma_fragment_index(C, m_dim, n_dim),
                A.access_ptr("r"),
                s1,  # leading stride of the source
                layout,
                dtype="handle",
            ))
def high_dim_opaque_access_with_source_strides(a: T.handle) -> None:
    """Fixture matching 16x16 tiles out of a 3-D buffer that itself has
    explicit (non-compact) source strides, then using the match opaquely."""
    A = T.match_buffer(a, (16, 32, 64), strides=[2576, 80, 1])
    for i, j, k in T.grid(16, 2, 4):
        with T.block([]):
            # Symbolic strides bound when the sub-region is matched.
            As_0 = T.var("int32")
            As_1 = T.var("int32")
            T.reads([])
            T.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16])
            sub_A = T.match_buffer(
                A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16],
                (16, 16),
                strides=[As_0, As_1],
                offset_factor=1,
            )
            T.evaluate(
                T.intrin_test(
                    sub_A.data,
                    sub_A.elem_offset,
                    sub_A.strides[0],
                    sub_A.strides[1],
                    sub_A.shape[0],
                    sub_A.shape[1],
                    dtype="handle",
                )
            )
def wmma_store_impl(a: T.handle, c: T.handle) -> None:
    """Intrinsic implementation: store a WMMA accumulator fragment into a
    strided destination via ``T.tvm_store_matrix_sync``.

    ``m_dim``, ``n_dim``, ``k_dim``, ``dtype``, ``scope`` and
    ``get_wmma_fragment_index`` are captured from the enclosing module.
    """
    # Symbolic strides of the destination buffer.
    s1 = T.var("int32")
    s0 = T.var("int32")
    A = T.match_buffer(a, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope="wmma.accumulator")
    C = T.match_buffer(c, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope=scope, strides=[s1, s0])
    with T.block("root"):
        T.reads(A[0:m_dim, 0:n_dim])
        T.writes(C[0:m_dim, 0:n_dim])
        T.evaluate(
            T.tvm_store_matrix_sync(
                A.data,
                m_dim,
                n_dim,
                k_dim,
                get_wmma_fragment_index(A, m_dim, n_dim),
                C.access_ptr("w"),
                s1,  # leading stride of the destination
                "row_major",
                dtype="handle",
            ))
def matmul_m_8x(a: T.handle, b: T.handle, c: T.handle) -> None:
    """Matmul with symbolic row count ``m`` and reduction length ``x*8``,
    written in the sugared (implicit-binding) block form."""
    x = T.var("int32")
    m = T.var("int32")
    A = T.match_buffer(a, [m, x * 8])
    B = T.match_buffer(b, [m, x * 8])
    C = T.match_buffer(c, [m, m])
    # Sugared block: iteration domain and bindings declared together.
    with T.block([m, m, T.reduce_axis(0, x * 8)], "update") as [vi, vj, vk]:
        with T.init():
            C[vi, vj] = 0.0
        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
def test_tir_fma(A: T.handle, B: T.handle, C: T.handle, d: T.handle) -> None:
    """Legacy-style fixture computing d = A*B + C over strided 1-D buffers,
    using the old ``T.load`` / direct ``.data`` indexing form."""
    # function attr dict
    T.func_attr({"global_symbol": "test_fma", "tir.noalias": True})
    n = T.var("int32")
    # One symbolic stride per buffer, bound by the match_buffers below.
    stride = T.var("int32")
    stride_1 = T.var("int32")
    stride_2 = T.var("int32")
    stride_3 = T.var("int32")
    A_1 = T.match_buffer(
        A,
        [n],
        strides=[stride],
        elem_offset=0,
        align=128,
        offset_factor=1,
        type="auto",
    )
    B_1 = T.match_buffer(
        B,
        [n],
        strides=[stride_1],
        elem_offset=0,
        align=128,
        offset_factor=1,
        type="auto",
    )
    C_1 = T.match_buffer(
        C,
        [n],
        strides=[stride_2],
        elem_offset=0,
        align=128,
        offset_factor=1,
        type="auto",
    )
    d_1 = T.match_buffer(
        d,
        [n],
        strides=[stride_3],
        elem_offset=0,
        align=128,
        offset_factor=1,
        type="auto",
    )
    # body
    for i in T.serial(0, n):
        d_1.data[(i * stride_3)] = (T.load("float32", A_1.data, (i * stride)) * T.load("float32", B_1.data, (i * stride_1))) + T.load(
            "float32", C_1.data, (i * stride_2))
def tensorized_batch_matmul_mma(
    A: T.Buffer[(16, 128, 128), "float32"],
    B: T.Buffer[(16, 128, 128), "float32"],
    C: T.Buffer[(16, 128, 128), "float32"],
) -> None:
    """Batched 128x128x128 matmul fixture with the update step tensorized
    into ``tvm_mma_sync`` calls over 16x16 sub-tiles per batch element."""
    for n, i, j in T.grid(16, 128, 128):
        with T.block("init"):
            vn, vi, vj = T.axis.remap("SSS", [n, i, j])
            T.reads()
            T.writes(C[vn, vi, vj])
            C[vn, vi, vj] = T.float32(0)
    for n in range(0, 16):
        for i, j, k in T.grid(8, 8, 8):
            with T.block("update"):
                vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k])
                T.reads(
                    C[vn:vn + 1, vi * 16:vi * 16 + 16, vj * 16:vj * 16 + 16],
                    A[vn:vn + 1, vi * 16:vi * 16 + 16, vk * 16:vk * 16 + 16],
                    B[vn:vn + 1, vj * 16:vj * 16 + 16, vk * 16:vk * 16 + 16],
                )
                T.writes(C[vn:vn + 1, vi * 16:vi * 16 + 16, vj * 16:vj * 16 + 16])
                # Symbolic offsets bound when the 16x16 tiles are matched.
                A_elem_offset = T.var("int32")
                B_elem_offset = T.var("int32")
                C_elem_offset = T.var("int32")
                A_sub = T.match_buffer(
                    A[vn:vn + 1, vi * 16:vi * 16 + 16, vk * 16:vk * 16 + 16],
                    (16, 16),
                    elem_offset=A_elem_offset,
                )
                B_sub = T.match_buffer(
                    B[vn:vn + 1, vj * 16:vj * 16 + 16, vk * 16:vk * 16 + 16],
                    (16, 16),
                    elem_offset=B_elem_offset,
                )
                C_sub = T.match_buffer(
                    C[vn:vn + 1, vi * 16:vi * 16 + 16, vj * 16:vj * 16 + 16],
                    (16, 16),
                    elem_offset=C_elem_offset,
                )
                # elem_offset / 256 converts a flat element offset into a
                # 16x16 (=256-element) fragment index.
                T.evaluate(
                    T.tvm_mma_sync(
                        C_sub.data,
                        T.floordiv(C_sub.elem_offset, 256),
                        A_sub.data,
                        T.floordiv(A_sub.elem_offset, 256),
                        B_sub.data,
                        T.floordiv(B_sub.elem_offset, 256),
                        C_sub.data,
                        T.floordiv(C_sub.elem_offset, 256),
                        dtype="handle",
                    ))
def tir_packed_call() -> None:
    """Minimal fixture emitting a single ``tvm_call_cpacked`` invocation with
    three opaque handle arguments."""
    A = T.var("handle")
    B = T.var("handle")
    C = T.var("handle")
    # body
    T.evaluate(
        T.tvm_call_cpacked(
            "tvm_test_cpacked",
            A,
            B,
            C,
            dtype="int32",
        ))
def gemm_dyn_shape(a: T.handle, b: T.handle, c: T.handle):
    """GEMM fixture with fully dynamic shapes: C[N, M] = A[N, K] @ B[K, M]."""
    N = T.var("int32")
    M = T.var("int32")
    K = T.var("int32")
    A = T.match_buffer(a, (N, K), "float32")
    B = T.match_buffer(b, (K, M), "float32")
    C = T.match_buffer(c, (N, M), "float32")
    for i, j, k in T.grid(N, M, K):
        with T.block("gemm"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
def element_wise(a: T.handle, c: T.handle) -> None:
    """Element-wise fixture (sugared block form): C = A * 2 + 1 via an
    intermediate allocated buffer B."""
    m = T.var("int32")
    n = T.var("int32")
    A = T.match_buffer(a, (m, n), "float32")
    C = T.match_buffer(c, (m, n), "float32")
    B = T.alloc_buffer((m, n), "float32")
    with T.block([m, n], "B") as [vi, vj]:
        B[vi, vj] = A[vi, vj] * 2.0
    with T.block([m, n], "C") as [vi, vj]:
        C[vi, vj] = B[vi, vj] + 1.0
def matmul_m_8x(a: T.handle, b: T.handle, c: T.handle) -> None:
    """Matmul with symbolic row count ``m`` and reduction length ``x*8``,
    written in the desugared (explicit loops + axis.remap) form."""
    x = T.var("int32")
    m = T.var("int32")
    A = T.match_buffer(a, [m, x * 8])
    B = T.match_buffer(b, [m, x * 8])
    C = T.match_buffer(c, [m, m])
    for i, j, k in T.grid(m, m, x * 8):
        with T.block("update"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
def main(buffer2: T.Buffer[(160,), "uint8"], buffer4: T.Buffer[(144,), "uint8"], buffer6: T.Buffer[(144,), "uint8"], buffer8: T.Buffer[(144,), "uint8"]) -> None:
    """Ethos-U fixture: four weight-copy / conv2d pairs, each wrapped in a
    ``pragma_compute_cycles_hint`` attribute node.

    The argument lists of the ``ethosu_copy`` / ``ethosu_conv2d`` externs
    follow the NPU driver's calling convention; only the hint values and
    buffer offsets differ between the four stages.
    """
    # function attr dict
    T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
    # Iteration variables carrying the per-stage cycle-hint attributes.
    v1a = T.var("int32")
    v1c = T.var("int32")
    v2a = T.var("int32")
    v2c = T.var("int32")
    v3a = T.var("int32")
    v3c = T.var("int32")
    v4a = T.var("int32")
    v4c = T.var("int32")
    # Input feature map and output buffer declared, not allocated here.
    buffer1 = T.buffer_decl([8192], "int8")
    buffer10 = T.buffer_decl([2048], "int8")
    # body
    # Scratch buffers receiving the copied weight/bias streams.
    p4 = T.allocate([160], "uint8", "global")
    p7 = T.allocate([144], "uint8", "global")
    p10 = T.allocate([144], "uint8", "global")
    p11 = T.allocate([144], "uint8", "global")
    with T.attr(T.iter_var(v1a, None, "DataPar", ""), "pragma_compute_cycles_hint", 201):
        T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 160, p4[0], dtype="handle"))
    with T.attr(T.iter_var(v2a, None, "DataPar", ""), "pragma_compute_cycles_hint", 205):
        T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 144, p7[0], dtype="handle"))
    with T.attr(T.iter_var(v1c, None, "DataPar", ""), "pragma_compute_cycles_hint", 300):
        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p4[0], 128, 12, p4[128], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
    with T.attr(T.iter_var(v3a, None, "DataPar", ""), "pragma_compute_cycles_hint", 209):
        T.evaluate(T.call_extern("ethosu_copy", buffer6[0], 144, p10[0], dtype="handle"))
    with T.attr(T.iter_var(v2c, None, "DataPar", ""), "pragma_compute_cycles_hint", 301):
        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p7[0], 112, 12, p7[112], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
    with T.attr(T.iter_var(v4a, None, "DataPar", ""), "pragma_compute_cycles_hint", 213):
        T.evaluate(T.call_extern("ethosu_copy", buffer8[0], 144, p11[0], dtype="handle"))
    with T.attr(T.iter_var(v3c, None, "DataPar", ""), "pragma_compute_cycles_hint", 302):
        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p10[0], 112, 12, p10[112], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
    with T.attr(T.iter_var(v4c, None, "DataPar", ""), "pragma_compute_cycles_hint", 303):
        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p11[0], 112, 12, p11[112], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
def tir_multi_output(a0: T.handle, a1: T.handle, b0: T.handle, b1: T.handle) -> None:
    """Two-output fixture (sugared block form): B0 = A0 + 2, B1 = A1 * 3,
    with both blocks nested under one shared loop nest."""
    m = T.var("int32")
    n = T.var("int32")
    A0 = T.match_buffer(a0, (m, n))
    A1 = T.match_buffer(a1, (m, n))
    B0 = T.match_buffer(b0, (m, n))
    B1 = T.match_buffer(b1, (m, n))
    for i0, i1 in T.grid(m, n):
        with T.block([m, n], "B.v0") as [i, j]:
            B0[i, j] = A0[i, j] + 2.0
        with T.block([m, n], "B.v1") as [i, j]:
            B1[i, j] = A1[i, j] * 3.0
def tir_multi_output(a0: T.handle, a1: T.handle, b0: T.handle, b1: T.handle) -> None:
    """Two-output fixture (desugared form, with function attributes):
    B0 = A0 + 2, B1 = A1 * 3 under one shared loop nest."""
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    m = T.var("int32")
    n = T.var("int32")
    A0 = T.match_buffer(a0, (m, n))
    A1 = T.match_buffer(a1, (m, n))
    B0 = T.match_buffer(b0, (m, n))
    B1 = T.match_buffer(b1, (m, n))
    for i0, i1 in T.grid(m, n):
        with T.block("B.v0"):
            i, j = T.axis.remap("SS", [i0, i1])
            B0[i, j] = A0[i, j] + 2.0
        with T.block("B.v1"):
            i, j = T.axis.remap("SS", [i0, i1])
            B1[i, j] = A1[i, j] * 3.0
def func_distributivity_expected(i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32) -> None:
    """Expected CSE output: the common subexpression ``x*y + x*z`` is hoisted
    into a single let-bound variable reused by both stores."""
    B = T.buffer_decl((50, ), "int32")
    cse_var_1 = T.var("int32")
    with T.let(cse_var_1, x * y + x * z):
        B[i1] = cse_var_1
        B[i2] = cse_var_1
def symbolic_func(a: T.handle, b: T.handle, n: T.int32):
    """Fixture with a symbolic shape parameter: each A element is duplicated
    into two adjacent columns of B (B has twice A's column count)."""
    m = T.var("int32")
    A = T.match_buffer(a, (n, m))
    B = T.match_buffer(b, (n, m * 2))
    for i, j in T.grid(n, m):
        B[i, j * 2] = A[i, j]
        B[i, j * 2 + 1] = A[i, j]
def func_associativity_expected(i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32) -> None:
    """Expected CSE output: the common subexpression ``(x + y) + z`` is
    hoisted into a single let-bound variable reused by both stores."""
    B = T.buffer_decl((50, ), "int32")
    cse_var_1 = T.var("int32")
    with T.let(cse_var_1, (x + y) + z):
        B[i1] = cse_var_1
        B[i2] = cse_var_1
def tir_multi_output(a0: T.handle, a1: T.handle, b0: T.handle, b1: T.handle) -> None:
    """Two-output fixture (desugared form, no function attributes):
    B0 = A0 + 2, B1 = A1 * 3 under one shared loop nest."""
    m = T.var("int32")
    n = T.var("int32")
    A0 = T.match_buffer(a0, (m, n))
    A1 = T.match_buffer(a1, (m, n))
    B0 = T.match_buffer(b0, (m, n))
    B1 = T.match_buffer(b1, (m, n))
    for i0, i1 in T.grid(m, n):
        with T.block("B.v0"):
            i, j = T.axis.remap("SS", [i0, i1])
            B0[i, j] = A0[i, j] + 2.0
        with T.block("B.v1"):
            i, j = T.axis.remap("SS", [i0, i1])
            B1[i, j] = A1[i, j] * 3.0
def opaque_access(a: T.handle, b: T.handle) -> None:
    """Fixture with two opaque-access loop nests: A's tiles are matched with
    constant strides, B's with symbolic strides; both are consumed only
    through ``T.intrin_test``."""
    A = T.match_buffer(a, (32, 64, 128))
    B = T.match_buffer(b, (64, 64, 64))
    for i, j, k in T.grid(2, 64, 8):
        with T.block([]):
            T.reads([])
            T.writes(A[i * 16 : i * 16 + 16, j, k * 16 : k * 16 + 16])
            # Constant strides: the matched region keeps A's own layout.
            sub_A = T.match_buffer(
                A[i * 16 : i * 16 + 16, j, k * 16 : k * 16 + 16],
                (16, 1, 16),
                strides=[8192, 128, 1],
                offset_factor=1,
            )
            T.evaluate(
                T.intrin_test(
                    sub_A.data,
                    sub_A.elem_offset,
                    sub_A.strides[0],
                    sub_A.strides[1],
                    sub_A.shape[0],
                    sub_A.shape[1],
                    dtype="handle",
                )
            )
    for i, j, k in T.grid(64, 2, 8):
        with T.block([]):
            # Symbolic strides bound when the region is matched.
            Bs_0 = T.var("int32")
            Bs_1 = T.var("int32")
            T.reads([])
            T.writes(B[i, j * 32 : j * 32 + 32, k * 8 : k * 8 + 8])
            sub_B = T.match_buffer(
                B[i, j * 32 : j * 32 + 32, k * 8 : k * 8 + 8],
                (32, 8),
                strides=[Bs_0, Bs_1],
                offset_factor=1,
            )
            T.evaluate(
                T.intrin_test(
                    sub_B.data,
                    sub_B.elem_offset,
                    sub_B.strides[0],
                    sub_B.strides[1],
                    sub_B.shape[0],
                    sub_B.shape[1],
                    dtype="handle",
                )
            )
def scalar_func(a: T.handle, b: T.handle) -> None:
    """Scalar-loop fixture with a fixed row count (n=100) and symbolic column
    count ``m``.

    NOTE(review): at i=0 / j=0 the reads index B[i-1, ...] and A[..., j-1]
    with negative offsets — presumably intentional for whatever pass this
    fixture exercises; confirm against the test that uses it.
    """
    m = T.var("int32")
    n = 100
    A = T.match_buffer(a, (n, m))
    B = T.match_buffer(b, (n, m))
    for i, j in T.grid(n, m):
        A[i, j] = B[i - 1, j + 1] + A[i - 1, j - 1]
def element_wise(a: T.handle, c: T.handle) -> None:
    """Element-wise fixture (desugared form): C = A * 2 + 1 via an
    intermediate allocated buffer B."""
    m = T.var("int32")
    n = T.var("int32")
    A = T.match_buffer(a, (m, n), "float32")
    C = T.match_buffer(c, (m, n), "float32")
    B = T.alloc_buffer((m, n), "float32")
    for i, j in T.grid(m, n):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(m, n):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0
def param_in_arith_exprs_n_16(a: T.handle, b: T.handle) -> None:
    """Expected result of specializing ``param_in_arith_exprs`` with n=16:
    shapes and the arithmetic constant ((16+1)*42 = 714) are folded.

    ``n`` is still declared (unused) to mirror the unspecialized version.
    """
    n = T.var("int32")
    A = T.match_buffer(a, [2, 8], "int32")
    B = T.match_buffer(b, [16], "int32")
    for i in range(15):
        with T.block():
            vi = T.axis.S(15, i)
            B[vi] = A[vi // 8, vi % 8] + 714
def param_in_arith_exprs(a: T.handle, b: T.handle) -> None:
    """Fixture whose shapes and loop bounds are arithmetic expressions of the
    symbolic parameter ``n`` (specialization target for the n=16 variant)."""
    n = T.var("int32")
    A = T.match_buffer(a, [n // 8, 8], "int32")
    B = T.match_buffer(b, [n], "int32")
    for i in range(n - 1):
        with T.block():
            vi = T.axis.S(n - 1, i)
            B[vi] = A[vi // 8, vi % 8] + (n + 1) * 42
def func_match_buffer(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]):
    """Fixture matching the whole of A inside the root block; the matched
    buffer A0 (and its symbolic strides/offset) should be remapped by the
    pass under test."""
    with T.block("root"):
        s = T.var("int32")
        e = T.var("int32")
        # A0 should be remapped
        A0 = T.match_buffer(
            A[0:128, 0:128],
            shape=(128, 128),
            dtype="float32",
            # s and e should be remapped
            strides=[s, s],
            elem_offset=e,
        )
        for i, j in T.grid(128, 128):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                # Reads go through the matched alias, not A directly.
                B[vi, vj] = A0[vi, vj] * 2.0
def vector_func(a: T.handle, b: T.handle) -> None:
    """Fixture with a serial outer loop over a symbolic extent and a
    vectorized inner loop over a fixed extent (m=128): A += B."""
    n = T.var("int32")
    m = 128
    A = T.match_buffer(a, (n, m))
    B = T.match_buffer(b, (n, m))
    for i in T.serial(n):
        for j in T.vectorized(m):
            A[i, j] = A[i, j] + B[i, j]
def fail_buffer_bind(a: T.handle) -> None:
    """Negative fixture (per its ``fail_`` name): the match_buffer binding is
    expected to be rejected.

    NOTE(review): the same stride var is used for both dimensions, and the
    store indexes sub_A with the outer loop vars rather than local tile
    coordinates — presumably both are deliberate parts of the failure case.
    """
    A = T.match_buffer(a, (8, 8))
    for i, j in T.grid(8, 2):
        with T.block():
            stride = T.var("int32")
            sub_A = T.match_buffer(A[i, j * 4:j * 4 + 4], (1, 4), strides=[stride, stride], offset_factor=1)
            for jj in range(0, 4):
                sub_A[i, j * 4 + jj] = 1