def opaque_access_store(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.alloc_buffer((128, 128))
    C = T.match_buffer(c, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B[0:128, 0:128])
            T.writes(C[0:128, 0:128])
            T.evaluate(B.access_ptr("r", extent=128))
            T.evaluate(C.access_ptr("w", extent=128))
            C[vi, vj] = B[vi, vj] + 1.0

def access_opaque_ptr_then_elemwise_inline(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [1024], dtype="float32")
    B = T.match_buffer(b, [1024], dtype="float32")
    A_cache = T.alloc_buffer([1024], dtype="float32")
    with T.block("opaque"):
        # annotated opaque partial access should be kept
        T.reads(A[0:512])
        T.writes([A_cache[0:512]])
        T.evaluate(A.access_ptr("r", extent=512))
        T.evaluate(A_cache.access_ptr("w", extent=512))
    for i in T.serial(0, 512):
        with T.block("B"):
            vi = T.axis.spatial(512, i)
            T.reads([A_cache[vi]])
            T.writes([B[vi]])
            B[vi] = A_cache[vi] * 2.0 + 1.0

def gemm_mma_m16n8k256_row_col_b1b1s32(a: T.handle, b: T.handle, c: T.handle):
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [16, 256], dtype="int1")
    B = T.match_buffer(b, [8, 256], dtype="int1")
    C = T.match_buffer(c, [16, 8], dtype="int32")
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    MultiA = T.allocate([128], "int1", scope="local")
    MultiB = T.allocate([64], "int1", scope="local")
    Accum = T.allocate([4], "int32", scope="local")
    for i in range(4):
        Accum[i] = T.int32(0)
    for mma_multi_a_col in range(128):
        MultiA[mma_multi_a_col] = A[
            (tx % 32) // 4 + mma_multi_a_col % 64 // 32 * 8,
            (tx % 32) % 4 * 32 + mma_multi_a_col % 32 + mma_multi_a_col // 64 * 128,
        ]
    # MultiB holds 64 int1 elements per thread, so the fill loop covers the full fragment
    for mma_multi_b_col in range(64):
        MultiB[mma_multi_b_col] = B[
            (tx % 32) // 4,
            (tx % 32) % 4 * 32 + mma_multi_b_col % 32 + mma_multi_b_col // 32 * 128,
        ]
    T.evaluate(
        T.ptx_mma(
            "m16n8k256",
            "row",
            "col",
            "int1",
            "int1",
            "int32",
            MultiA.data,
            0,
            MultiB.data,
            0,
            Accum.data,
            0,
            False,
            "xor",
            dtype="int32",
        )
    )
    for mma_accum_c_id in range(4):
        C[
            (tx % 32) // 4 + mma_accum_c_id // 2 * 8,
            (tx % 32) % 4 * 2 + mma_accum_c_id % 2,
        ] = Accum[mma_accum_c_id]

def gemm_mma_m16n8k8_row_col_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle):
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [16, 8], dtype="float16")
    B = T.match_buffer(b, [8, 8], dtype="float16")
    C = T.match_buffer(c, [16, 8], dtype="float32")
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    MultiA = T.allocate([4], "float16", scope="local")
    MultiB = T.allocate([2], "float16", scope="local")
    Accum = T.allocate([4], "float32", scope="local")
    for i in range(4):
        Accum[i] = T.float32(0)
    for mma_multi_a_col in T.vectorized(4):
        MultiA[mma_multi_a_col] = A[
            (tx % 32) // 4 + mma_multi_a_col // 2 * 8,
            (tx % 32) % 4 * 2 + mma_multi_a_col % 2,
        ]
    # MultiB holds only 2 float16 elements per thread for the m16n8k8 B fragment
    for mma_multi_b_col in T.vectorized(2):
        MultiB[mma_multi_b_col] = B[
            (tx % 32) // 4 + mma_multi_b_col // 2 * 8,
            (tx % 32) % 4 * 2 + mma_multi_b_col % 2,
        ]
    T.evaluate(
        T.ptx_mma(
            "m16n8k8",
            "row",
            "col",
            "fp16",
            "fp16",
            "fp32",
            MultiA,
            0,
            MultiB,
            0,
            Accum,
            0,
            False,
            dtype="float32",
        )
    )
    for mma_accum_c_id in range(4):
        C[
            (tx % 32) // 4 + mma_accum_c_id // 2 * 8,
            (tx % 32) % 4 * 2 + mma_accum_c_id % 2,
        ] = T.load("float32", Accum, mma_accum_c_id)

def opaque_access(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (32, 64, 128))
    B = T.match_buffer(b, (64, 64, 64))
    for i, j, k in T.grid(2, 64, 8):
        with T.block([]):
            T.reads([])
            T.writes(A[i * 16 : i * 16 + 16, j, k * 16 : k * 16 + 16])
            sub_A = T.match_buffer(
                A[i * 16 : i * 16 + 16, j, k * 16 : k * 16 + 16],
                (16, 1, 16),
                strides=[8192, 128, 1],
                offset_factor=1,
            )
            T.evaluate(
                T.intrin_test(
                    sub_A.data,
                    sub_A.elem_offset,
                    sub_A.strides[0],
                    sub_A.strides[1],
                    sub_A.shape[0],
                    sub_A.shape[1],
                    dtype="handle",
                )
            )
    for i, j, k in T.grid(64, 2, 8):
        with T.block([]):
            Bs_0 = T.var("int32")
            Bs_1 = T.var("int32")
            T.reads([])
            T.writes(B[i, j * 32 : j * 32 + 32, k * 8 : k * 8 + 8])
            sub_B = T.match_buffer(
                B[i, j * 32 : j * 32 + 32, k * 8 : k * 8 + 8],
                (32, 8),
                strides=[Bs_0, Bs_1],
                offset_factor=1,
            )
            T.evaluate(
                T.intrin_test(
                    sub_B.data,
                    sub_B.elem_offset,
                    sub_B.strides[0],
                    sub_B.strides[1],
                    sub_B.shape[0],
                    sub_B.shape[1],
                    dtype="handle",
                )
            )

def lowered_with_block_predicate(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128, 120], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    for i in T.serial(0, 128):
        for ki in T.thread_binding(0, 32, thread="threadIdx.x"):
            with T.block("B_in_thread_init"):
                T.reads([])
                T.writes([normal_reduce_temp0[0]])
                normal_reduce_temp0[0] = T.float32(0)
            for ko in T.serial(0, 4):
                with T.block("B_normal_reduction"):
                    vi = T.axis.spatial(128, i)
                    vk = T.axis.reduce(120, ko * 32 + ki)
                    T.where(ko * 32 + ki < 120)
                    T.reads([A[vi, vk], normal_reduce_temp0[0]])
                    T.writes([normal_reduce_temp0[0]])
                    normal_reduce_temp0[0] = normal_reduce_temp0[0] + A[vi, vk]
            with T.block("B_cross_thread_reduction"):
                T.reads([normal_reduce_temp0[0]])
                T.writes([reduce_temp0[0]])
                T.attr(
                    T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp0[0],
                        True,
                        reduce_temp0.data,
                        ki,
                        dtype="handle",
                    )
                )
            with T.block("B_write_back"):
                vi = T.axis.spatial(128, i)
                T.reads([reduce_temp0[0]])
                T.writes([B[vi]])
                B[vi] = reduce_temp0[0]

def opaque_access_fused(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [16, 16])
    B = T.match_buffer(b, [16, 16])
    for i_j_fused in T.serial(0, 256):
        with T.block("A"):
            vi = T.axis.S(16, T.floordiv(i_j_fused, 16))
            vj = T.axis.S(16, T.floormod(i_j_fused, 16))
            T.reads([])
            T.writes([A[0:16, 0:16]])
            A[vi, vj] = 1
    for i_j_fused in T.serial(0, 256):
        with T.block("B"):
            vi = T.axis.S(16, T.floordiv(i_j_fused, 16))
            vj = T.axis.S(16, T.floormod(i_j_fused, 16))
            T.reads([])
            T.writes([B[0:16, 0:16]])
            T.evaluate(
                T.tvm_fill_fragment(B.data, 16, 16, 16, 0, ((vi * 16) + vj), dtype="handle")
            )

def exp_exp_opaque_access_with_tvm_access_ptr_inlined(
    lookup_table: T.Buffer[(1024,), "int8"],
    x: T.Buffer[(16,), "float16"],
    compute: T.Buffer[(16,), "float16"],
) -> None:
    for i0 in T.serial(16):
        with T.block("compute_1"):
            i0_1 = T.axis.spatial(16, i0)
            # The opaque access is not added to the write region because it is
            # wrapped in a tvm_access_ptr whose access mask is read-only.
            T.reads(lookup_table[0:1024], x[i0_1])
            T.writes(compute[i0_1])
            T.evaluate(lookup_table.access_ptr("r"))
            compute[i0_1] = T.exp(
                T.exp(x[i0_1], dtype="float16"),
                dtype="float16",
            )

def mma_fill_impl(a: T.handle) -> None:
    C_warp = T.match_buffer(
        a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", offset_factor=1
    )
    with T.block("root"):
        T.reads()
        T.writes(C_warp[0:WARP_SIZE, 0:local_size])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, WARP_SIZE)
        T.evaluate(T.mma_fill(local_size, C_warp.data, C_warp.elem_offset, dtype=dtype))

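# Note: `mma_fill_impl` above uses `WARP_SIZE`, `local_size`, and `dtype` as free
# variables, so they must already be bound in the enclosing scope (in practice they
# come from the MMA tensor-intrinsic registration that wraps this function).
# A minimal sketch with illustrative, assumed values:
WARP_SIZE = 32      # threads per warp on NVIDIA GPUs
local_size = 8      # accumulator elements held per thread, e.g. a 16x16 fragment / 32 threads
dtype = "float32"   # element type of the warp-level accumulator fragment
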
def transformed_high_dim_opaque_access(a: T.handle) -> None:
    A = T.match_buffer(a, (16, 32, 64))
    for i, j, k in T.grid(16, 2, 4):
        with T.block([]):
            T.reads([])
            T.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16])
            T.evaluate(
                T.intrin_test(
                    A.data,
                    i * 2048 + j * 1024 + k * 16,
                    64,
                    1,
                    16,
                    16,
                    dtype="handle",
                )
            )

def transformed_high_dim_opaque_access_with_source_strides(a: T.handle) -> None:
    A = T.match_buffer(a, (16, 32, 64), strides=[2576, 80, 1])
    for i, j, k in T.grid(16, 2, 4):
        with T.block():
            T.reads([])
            T.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16])
            T.evaluate(
                T.intrin_test(
                    A.data,
                    i * 2576 + j * 1280 + k * 16,
                    80,
                    1,
                    16,
                    16,
                    dtype="handle",
                )
            )

def opaque_access_split(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (16, 16))
    B = T.match_buffer(b, (16, 16))
    for i, j0, j1 in T.grid(16, 4, 4):
        with T.block("A"):
            vi = T.axis.S(16, i)
            vj = T.axis.S(16, j0 * 4 + j1)
            T.reads([])
            T.writes([A[0:16, 0:16]])
            A[vi, vj] = 1
    for i, j0, j1 in T.grid(16, 4, 4):
        with T.block("B"):
            vi = T.axis.S(16, i)
            vj = T.axis.S(16, j0 * 4 + j1)
            T.reads([])
            T.writes([B[0:16, 0:16]])
            T.evaluate(
                T.tvm_fill_fragment(B.data, 16, 16, 16, 0, ((vi * 16) + vj), dtype="handle")
            )

def opaque_access_func() -> None:
    A = T.alloc_buffer([1024])
    B = T.alloc_buffer([1024])
    for i in T.serial(0, 8):
        with T.block():
            v = T.axis.S(8, i)
            T.reads([A[v * 128 : v * 128 + 128]])
            T.writes([B[v * 128 : v * 128 + 128]])
            T.evaluate(
                T.call_extern(
                    "test", B.data, v * 128, 128, A.data, v * 128, 128, dtype="float32"
                )
            )

def opaque_access(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [16, 16], "float32")
    B = T.match_buffer(b, [16, 16], "float32")
    with T.block([16, 16], "A") as [vi, vj]:
        T.reads([])
        T.writes([A[0:16, 0:16]])
        T.store(A.data, vi * 16 + vj, 1)
    with T.block([16, 16], "B") as [vi, vj]:
        T.reads([])
        T.writes([B[0:16, 0:16]])
        T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle"))

def transformed_rank0_buffer(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (8, 8))
    B = T.match_buffer(b, (8, 8))
    for i, j in T.grid(8, 8):
        with T.block():
            T.reads([])
            T.writes([A[i, j], B[i, j]])
            A[i, j] = 1
            T.evaluate(
                T.intrin_test(
                    B.data,
                    i * 8 + j,
                    0,
                    0,
                    0,
                    0,
                    dtype="handle",
                )
            )

def main(buffer2: T.Buffer[(160,), "uint8"]) -> None:
    # function attr dict
    T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
    buffer1 = T.buffer_decl([8192], "int8")
    buffer10 = T.buffer_decl([2048], "int8")
    # body
    p5 = T.allocate([160], "uint8", "global")
    T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 160, p5[0], dtype="handle"))
    T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p5[0], 128, 12, p5[128], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
    T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 160, p5[0], dtype="handle"))
    T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p5[0], 128, 12, p5[128], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))

def tir_extern(a: T.handle, b: T.handle, c: T.handle) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    C = T.match_buffer(c, (128, 128))
    # body
    with T.block("C"):
        T.reads([A[0:128, 0:128], B[0:128, 0:128]])
        T.writes([C[0:128, 0:128]])
        T.evaluate(
            T.tvm_call_packed(
                "tvm.contrib.cblas.matmul",
                T.tvm_stack_make_array(
                    A.data,
                    T.tvm_stack_make_shape(128, 128, dtype="handle"),
                    0,
                    2,
                    0.0,
                    0,
                    dtype="handle",
                ),
                T.tvm_stack_make_array(
                    B.data,
                    T.tvm_stack_make_shape(128, 128, dtype="handle"),
                    0,
                    2,
                    0.0,
                    0,
                    dtype="handle",
                ),
                T.tvm_stack_make_array(
                    C.data,
                    T.tvm_stack_make_shape(128, 128, dtype="handle"),
                    0,
                    2,
                    0.0,
                    0,
                    dtype="handle",
                ),
                0,
                0,
                dtype="int32",
            )
        )

def main(buffer2: T.Buffer[(80,), "uint8"], buffer3: T.Buffer[(64,), "uint8"]) -> None:
    T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
    buffer0 = T.buffer_decl([390336], "int8")
    buffer1 = T.buffer_decl([97156], "int8")
    buffer6 = T.buffer_decl([390336], "int8")
    # body
    p2 = T.allocate([80], "uint8", "global")
    p3 = T.allocate([64], "uint8", "global")
    T.evaluate(T.call_extern("ethosu_pooling", "int8", 214, 227, 2, 214, 0, 227, buffer1[0], 0, 0, 0, T.float32(1), 0, "NHWC", 454, 2, 1, "int8", 214, 114, 2, 214, 0, 114, buffer0[0], 0, 0, 0, T.float32(1), 0, "NHCWB16", 1824, 16, 1, "MAX", 2, 1, 2, 1, 1, 1, 0, 0, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
    T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 80, p2[0], dtype="handle"))
    T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 64, p3[0], dtype="handle"))
    T.evaluate(T.call_extern("ethosu_conv2d", "int8", 214, 114, 2, 214, 0, 114, buffer0[0], 0, 0, 0, T.float32(0.00392157), -128, "NHCWB16", 1824, 16, 1, "int8", 214, 114, 5, 214, 0, 114, buffer6[0], 0, 0, 0, T.float32(0.0174839), -128, "NHCWB16", 1824, 16, 1, 3, 1, 1, 1, 1, 2, p2[0], 80, 0, p3[0], 64, 0, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))

def match_buffer_func(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128), "float32")
    B = T.match_buffer(b, (128, 128), "float32")
    for i, j in T.grid(8, 8):
        with T.block("block"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B[vi * 16 + 2 : vi * 16 + 12, vj * 16 + 2 : vj * 16 + 16])
            T.writes(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
            B0 = T.match_buffer(B[vi * 16 + 2 : vi * 16 + 6, vj * 16 + 2 : vj * 16 + 6], (4, 4))
            B1 = T.match_buffer(B[vi * 16 + 8 : vi * 16 + 12, vj * 16 + 8 : vj * 16 + 16], (4, 8))
            for ii, jj in T.grid(16, 16):
                with T.block("AAA"):
                    vii, vjj = T.axis.remap("SS", [ii, jj])
                    AA = T.match_buffer(A[vii, vjj], ())
                    AA[()] = 1.0
            T.evaluate(B0.data)
            T.evaluate(B1.data)

def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (4,), offset_factor=1)
    B = T.match_buffer(b, (4,), offset_factor=1)
    C = T.match_buffer(c, (), offset_factor=1)
    with T.block("root"):
        T.reads(C[()], A[0:4], B[0:4])
        T.writes(C[()])
        T.evaluate(
            T.call_extern(
                "vec4add",
                C.data,
                C.elem_offset,
                A.data,
                A.elem_offset,
                B.data,
                B.elem_offset,
                dtype="int32",
            )
        )

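# Note: `dot_product_intrin` above is the implementation half of a TensorIR tensor
# intrinsic. A hedged sketch of how such an implementation is usually paired with a
# computation description and registered for tensorization; the description function,
# the intrinsic name, and the @T.prim_func decorators are assumptions for illustration,
# not part of the original source.
import tvm
from tvm.script import tir as T


@T.prim_func
def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (4,), offset_factor=1)
    B = T.match_buffer(b, (4,), offset_factor=1)
    C = T.match_buffer(c, (), offset_factor=1)
    with T.block("root"):
        T.reads(C[()], A[0:4], B[0:4])
        T.writes(C[()])
        for i in range(4):
            with T.block("update"):
                vi = T.axis.remap("R", [i])
                C[()] = C[()] + A[vi] * B[vi]


# Registration pairs the description with the extern-call implementation (assuming
# dot_product_intrin is itself a @T.prim_func) so schedules can tensorize matching loops.
tvm.tir.TensorIntrin.register("test_vec4add", dot_product_desc, dot_product_intrin)
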
def gemm_mma_m8n8k32_row_col_s4u4s32(a: T.handle, b: T.handle, c: T.handle):
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [8, 32], dtype="int4")
    B = T.match_buffer(b, [8, 32], dtype="uint4")
    C = T.match_buffer(c, [8, 8], dtype="int32")
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    MultiA = T.allocate([8], "int4", scope="local")
    MultiB = T.allocate([8], "uint4", scope="local")
    Accum = T.allocate([2], "int32", scope="local")
    for i in range(2):
        Accum[i] = T.int32(0)
    for mma_multi_a_col in T.vectorized(8):
        MultiA[mma_multi_a_col] = A[(tx % 32) // 4, mma_multi_a_col + (tx % 32) % 4 * 8]
    for mma_multi_b_col in T.vectorized(8):
        MultiB[mma_multi_b_col] = B[(tx % 32) // 4, mma_multi_b_col + (tx % 32) % 4 * 8]
    T.evaluate(
        T.ptx_mma(
            "m8n8k32",
            "row",
            "col",
            "int4",
            "uint4",
            "int32",
            MultiA,
            0,
            MultiB,
            0,
            Accum,
            0,
            False,
            dtype="int32",
        )
    )
    for mma_accum_c_id in range(2):
        C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = T.load(
            "int32", Accum, mma_accum_c_id
        )

def access_opaque_ptr_then_elemwise(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [1024])
    B = T.match_buffer(b, [1024])
    A_cache = T.alloc_buffer([1024])
    BB = T.alloc_buffer([1024])
    with T.block("opaque"):
        # annotated opaque partial access
        T.reads(A[0:512])
        T.writes(A_cache[0:512])
        T.evaluate(A.access_ptr("r", extent=512))
        T.evaluate(A_cache.access_ptr("w", extent=512))
    for i in range(512):
        with T.block("BB"):
            vi = T.axis.remap("S", [i])
            BB[vi] = A_cache[vi] * 2.0
    for i in range(512):
        with T.block("B"):
            vi = T.axis.remap("S", [i])
            B[vi] = BB[vi] + 1.0

def compacted_opaque_access_annotated_func(a: T.handle) -> None:
    A = T.match_buffer(a, (1024,), "float32")
    with T.block():
        B = T.alloc_buffer((1024,), dtype="float32")
        C = T.alloc_buffer((520,), dtype="float32")
        for i in range(0, 512):
            with T.block():
                # no annotation: the opaque access is assumed to cover the full region
                T.reads([])
                T.writes([])
                T.evaluate(T.call_extern("opaque_extern_function", A.data, B.data, dtype="int32"))
                B[i] = A[i]
            with T.block():
                # treat the opaque access as touching only the annotated regions, even if
                # they are not consistent with the actual buffer accesses
                T.reads([B[i]])
                T.writes([C[i : i + 9]])
                T.evaluate(T.call_extern("opaque_extern_function", B.data, C.data, dtype="int32"))
                C[i] = B[i]

def func() -> None:
    A = T.alloc_buffer((128, 128), "float32")
    B = T.alloc_buffer((128, 128), "float32")
    C = T.alloc_buffer((128, 128), "float32")
    D = T.alloc_buffer((128, 128), "float32")
    with T.block():
        # The read/write regions must be added manually to avoid triggering the
        # block access region detector.
        T.reads([B[0, 0], C[0:16, 0:16], A[4:12, 4:12]])
        T.writes([A[0:12, 0:12]])
        for i, j in T.grid(8, 8):
            A[i, j] = B[0, 0] + C[0, 0]
        for i, j in T.grid(2, 2):
            with T.block():
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads([A[vi * 4 + 4 : vi * 4 + 8, vj * 4 + 4 : vj * 4 + 8], C[12:16, 12:16]])
                T.writes([A[vi * 4 + 4 : vi * 4 + 8, vj * 4 + 4 : vj * 4 + 8]])
                for i, j in T.grid(4, 4):
                    A[vi * 4 + 4 + i, vj * 4 + 4 + j] += C[i + 12, j + 12]
        T.evaluate(D.data)

def transformed_trivial_pipeline(
    A: T.Buffer[(16, 1), "float32"], C: T.Buffer[(16, 1), "float32"]
) -> None:
    for tx in T.thread_binding(16, thread="threadIdx.x"):
        with T.block():
            T.reads(A[tx, 0])
            T.writes(C[tx, 0])
            B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
            with T.block():
                T.reads(A[tx, 0])
                T.writes(B[0, tx, 0])
                B[0, tx, 0] = A[tx, 0] * T.float32(2)
            with T.block():
                T.reads()
                T.writes()
                T.evaluate(0)
            with T.block():
                T.reads(B[0, tx, 0])
                T.writes(C[tx, 0])
                C[tx, 0] = B[0, tx, 0] + T.float32(1)

def wmma_fill_impl(c: T.handle) -> None:
    C = T.match_buffer(
        c, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope="wmma.accumulator"
    )
    with T.block("root"):
        T.reads()
        T.writes(C[0:m_dim, 0:n_dim])
        T.evaluate(
            T.tvm_fill_fragment(
                C.data,
                m_dim,
                n_dim,
                k_dim,
                get_wmma_fragment_index(C, m_dim, n_dim),
                T.float32(0),
                dtype="handle",
            )
        )

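# Note: `wmma_fill_impl` above relies on `m_dim`, `n_dim`, `k_dim`, `dtype`, and
# `get_wmma_fragment_index` being defined in the enclosing scope (in practice they come
# from the wmma tensor-intrinsic registration that wraps this function). A minimal
# sketch with assumed values and a simplified helper:
m_dim, n_dim, k_dim = 16, 16, 16  # illustrative wmma fragment shape
dtype = "float32"                 # illustrative accumulator element type


def get_wmma_fragment_index(buffer, frag_m, frag_n):
    # Simplified assumption: fragments are packed contiguously, so the index is the
    # element offset divided by the fragment size. The real helper in TVM also takes
    # the buffer's leading-dimension stride into account.
    return buffer.elem_offset // (frag_m * frag_n)
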
def tir_extern(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    C = T.match_buffer(c, (128, 128))
    # body
    with T.block([], "C"):
        T.reads([A[0:128, 0:128], B[0:128, 0:128]])
        T.writes([C[0:128, 0:128]])
        T.evaluate(
            T.tvm_call_packed(
                "tvm.contrib.cblas.matmul",
                T.tvm_stack_make_array(
                    A.data,
                    T.tvm_stack_make_shape(128, 128, dtype="handle"),
                    0,
                    2,
                    0.0,
                    0,
                    dtype="handle",
                ),
                T.tvm_stack_make_array(
                    B.data,
                    T.tvm_stack_make_shape(128, 128, dtype="handle"),
                    0,
                    2,
                    0.0,
                    0,
                    dtype="handle",
                ),
                T.tvm_stack_make_array(
                    C.data,
                    T.tvm_stack_make_shape(128, 128, dtype="handle"),
                    0,
                    2,
                    0.0,
                    0,
                    dtype="handle",
                ),
                0,
                0,
                dtype="int32",
            )
        )

def transformed_recursive_match(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (64, 64, 64))
    B = T.match_buffer(b, (64, 64, 64))
    for i, j, k in T.grid(64, 4, 4):
        with T.block([]):
            T.reads([])
            T.writes(
                [
                    A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16],
                    B[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16],
                ]
            )
            for jj, kk in T.grid(4, 4):
                with T.block([]):
                    T.reads([])
                    T.writes(
                        [
                            A[
                                i,
                                j * 16 + jj * 4 : j * 16 + jj * 4 + 4,
                                k * 16 + kk * 4 : k * 16 + kk * 4 + 4,
                            ],
                            B[
                                i,
                                j * 16 + jj * 4 : j * 16 + jj * 4 + 4,
                                k * 16 + kk * 4 : k * 16 + kk * 4 + 4,
                            ],
                        ]
                    )
                    T.evaluate(
                        T.intrin_test(
                            A.data,
                            i * 4096 + j * 1024 + jj * 256 + k * 16 + kk * 4,
                            64,
                            1,
                            4,
                            4,
                            dtype="handle",
                        )
                    )
                    for jjj, kkk in T.grid(4, 4):
                        B[i, j * 16 + jj * 4 + jjj, k * 16 + kk * 4 + kkk] = 1

def mma_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), align=128, offset_factor=1)
    B = T.match_buffer(b, (16, 16), align=128, offset_factor=1)
    C = T.match_buffer(c, (16, 16), align=128, offset_factor=1)
    with T.block("root"):
        T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        T.evaluate(
            T.tvm_mma_sync(
                C.data,
                C.elem_offset // 256,
                A.data,
                A.elem_offset // 256,
                B.data,
                B.elem_offset // 256,
                C.data,
                C.elem_offset // 256,
                dtype="handle",
            )
        )

def rank0_buffer(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (8, 8))
    B = T.match_buffer(b, (8, 8))
    for i, j in T.grid(8, 8):
        with T.block():
            T.reads([])
            T.writes([A[i, j], B[i, j]])
            sub_A = T.match_buffer(A[i, j], (), offset_factor=1)
            sub_B = T.match_buffer(B[i, j], (), offset_factor=1)
            sub_A[()] = 1
            T.evaluate(
                T.intrin_test(
                    sub_B.data,
                    sub_B.elem_offset,
                    0,
                    0,
                    0,
                    0,
                    dtype="handle",
                )
            )