def main(a: T.handle, b: T.handle) -> None:
    # TVMScript kernel written with explicit T.load / T.store calls. Each
    # thread zeroes an 8x8 tile in B_local, accumulates into it from values of
    # A staged (with zero fill outside the guarded range) through Apad_shared
    # and Apad_shared_local, and finally writes the tile to B.
    # NOTE(review): B_local is allocated with 6400000 elements although only
    # indices 0..63 are touched below -- presumably deliberate, to exceed a
    # local-memory limit in a verification test; confirm.
    # NOTE(review): the attribute key here is "T.noalias" while other functions
    # in this file use "tir.noalias" -- confirm which spelling is intended.
    # function attr dict
    T.func_attr({"global_symbol": "main", "T.noalias": True})
    # var definition
    threadIdx_x = T.env_thread("threadIdx.x")
    threadIdx_y = T.env_thread("threadIdx.y")
    blockIdx_x = T.env_thread("blockIdx.x")
    blockIdx_y = T.env_thread("blockIdx.y")
    blockIdx_z = T.env_thread("blockIdx.z")
    A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32")
    B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32")
    # body
    T.launch_thread(blockIdx_z, 196)
    B_local = T.allocate([6400000], "float32", "local")
    Apad_shared = T.allocate([512], "float32", "shared")
    Apad_shared_local = T.allocate([8], "float32", "local")
    T.launch_thread(blockIdx_y, 8)
    T.launch_thread(blockIdx_x, 4)
    T.launch_thread(threadIdx_y, 8)
    T.launch_thread(threadIdx_x, 8)
    # Zero the 8x8 accumulator tile.
    for ff_c_init, nn_c_init in T.grid(8, 8):
        T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True)
    for rc_outer, ry, rx in T.grid(32, 3, 3):
        # Stage two float32x4 vectors per thread into Apad_shared; positions
        # failing the 1 <= ... < 15 bounds checks are filled with zeros.
        for ax3_inner_outer in T.serial(0, 2):
            T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4))
        for rc_inner in T.serial(0, 8):
            # Copy 8 staged values into the per-thread scratch buffer.
            for ax3 in T.serial(0, 8):
                T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True)
            # Add the scratch values into the accumulator tile.
            for ff_c, nn_c in T.grid(8, 8):
                T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True)
    # Write the accumulated tile to the output buffer.
    for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8):
        T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)
# fmt: on
def ptx_cp_async(A: T.Buffer[(32, 128), "float16"], B: T.Buffer[(32, 128), "float16"]) -> None:
    # Copies A to B by way of the shared buffer A_shared. Each of the 32
    # threadIdx.x threads issues 16 T.ptx_cp_async calls for its own row
    # (offset tx * 128 + 8 * i), commits and waits on the group, and then
    # copies its 128-element row of A_shared into B.
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    bx = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(bx, 1)
    T.launch_thread(tx, 32)
    with T.block():
        A_shared = T.alloc_buffer([32, 128], "float16", scope="shared")
        T.reads(A[0:32, 0:128])
        T.writes(B[0:32, 0:128])
        for i in range(16):
            # The literal 16 is presumably the transfer size in bytes
            # (8 float16 values, matching the 8 * i stride) -- TODO confirm.
            T.evaluate(
                T.ptx_cp_async(A_shared.data, tx * 128 + 8 * i, A.data, tx * 128 + 8 * i, 16, dtype="float16"))
        # TODO(masahi): Remove dtype requirement from TVMScript parser
        T.evaluate(T.ptx_commit_group(dtype="float16"))
        T.evaluate(T.ptx_wait_group(0, dtype="float16"))
        for i in range(128):
            B[tx, i] = A_shared[tx, i]
def main(inputs: T.Buffer[(8192,), "float32"], weight: T.Buffer[(2097152,), "float32"], conv2d_transpose_nhwc: T.Buffer[(16384,), "float32"]) -> None:
    # Kernel over flat buffers whose pre-flattened shapes are declared below
    # ([1, 4, 4, 512], [4, 4, 512, 256] and [1, 8, 8, 256]). Launched with 64
    # blockIdx.x blocks of 32 threadIdx.x threads; each thread keeps 8 local
    # accumulators that are zeroed, updated over 16 outer iterations from the
    # PadInput_shared and weight_shared staging buffers, and written out last.
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # var definition
    threadIdx_x = T.env_thread("threadIdx.x")
    blockIdx_x = T.env_thread("blockIdx.x")
    T.preflattened_buffer(inputs, [1, 4, 4, 512], dtype="float32", data=inputs.data)
    T.preflattened_buffer(weight, [4, 4, 512, 256], dtype="float32", data=weight.data)
    T.preflattened_buffer(conv2d_transpose_nhwc, [1, 8, 8, 256], dtype="float32", data=conv2d_transpose_nhwc.data)
    # body
    T.launch_thread(blockIdx_x, 64)
    conv2d_transpose_nhwc_local = T.allocate([8], "float32", "local")
    PadInput_shared = T.allocate([768], "float32", "shared")
    weight_shared = T.allocate([4096], "float32", "shared")
    T.launch_thread(threadIdx_x, 32)
    # Zero the 8 per-thread accumulators.
    for i2_3_init, i1_4_init, i2_4_init in T.grid(2, 2, 2):
        conv2d_transpose_nhwc_local[i1_4_init * 4 + i2_3_init * 2 + i2_4_init] = T.float32(0)
    for i6_0 in T.serial(16):
        # Stage input values; positions failing the bounds checks are zero.
        for ax0_ax1_ax2_ax3_fused_0 in T.serial(24):
            PadInput_shared[ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x] = T.if_then_else(4 <= ax0_ax1_ax2_ax3_fused_0 and ax0_ax1_ax2_ax3_fused_0 < 20 and 1 <= blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 and blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 < 5, inputs[blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560], T.float32(0), dtype="float32")
        # Stage weights four elements at a time (T.ramp with 4 lanes).
        for ax0_ax1_ax2_ax3_fused_0 in T.serial(32):
            weight_shared[T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4)] = weight[T.ramp(ax0_ax1_ax2_ax3_fused_0 // 2 * 131072 + i6_0 * 8192 + ax0_ax1_ax2_ax3_fused_0 % 2 * 4096 + threadIdx_x // 2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4)]
        # Multiply-accumulate; the staged input contributes only when both
        # (i1_4 + i4_2) and (i2_4 + i5_2) are even, otherwise zero is used.
        for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2):
            conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] = conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, PadInput_shared[threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2], T.float32(0), dtype="float32") * weight_shared[i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024]
    # Write the 8 accumulators to the flat output buffer.
    for ax1, ax2 in T.grid(2, 4):
        conv2d_transpose_nhwc[threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8] = conv2d_transpose_nhwc_local[ax1 * 4 + ax2]
def unified_element_wise_kernels_with_different_size(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None:
    # Two independent launch regions with different extents: a 128 x 128
    # region storing A * 2.0 into B, followed by a 256 x 256 region storing
    # C + 1.0 into D. Each region uses its own pair of blockIdx.x /
    # threadIdx.x variables.
    block_x = T.env_thread("blockIdx.x")
    thread_x = T.env_thread("threadIdx.x")
    block_x_1 = T.env_thread("blockIdx.x")
    thread_x_1 = T.env_thread("threadIdx.x")
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [256, 256])
    D = T.match_buffer(d, [256, 256])
    # First region: the `with` form scopes block_x / thread_x to this store.
    with T.launch_thread(block_x, 128):
        T.launch_thread(thread_x, 128)
        T.store(
            B.data,
            block_x * 128 + thread_x,
            T.load("float32", A.data, block_x * 128 + thread_x) * 2.0,
            True,
        )
    # Second region: covers the remainder of the function body.
    T.launch_thread(block_x_1, 256)
    T.launch_thread(thread_x_1, 256)
    T.store(
        D.data,
        block_x_1 * 256 + thread_x_1,
        T.load("float32", C.data, block_x_1 * 256 + thread_x_1) + 1.0,
        True,
    )
def unified_element_wise_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None:
    # One blockIdx.x launch of extent 128 containing two threadIdx.x scopes of
    # extent 4 that share the single variable thread_x. Each thread handles 32
    # consecutive elements: the first scope stores A * 2.0 into B, the second
    # stores A + 1.0 into C.
    thread_x = T.env_thread("threadIdx.x")
    block_x = T.env_thread("blockIdx.x")
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    T.launch_thread(block_x, 128)
    with T.launch_thread(thread_x, 4):
        for j0_1 in T.serial(0, 32):
            T.store(
                B.data,
                block_x * 128 + thread_x * 32 + j0_1,
                T.load("float32", A.data, block_x * 128 + thread_x * 32 + j0_1) * 2.0,
                True,
            )
    # thread_x is launched a second time with the same extent.
    T.launch_thread(thread_x, 4)
    for j1_1 in T.serial(0, 32):
        T.store(
            C.data,
            block_x * 128 + thread_x * 32 + j1_1,
            T.load("float32", A.data, block_x * 128 + thread_x * 32 + j1_1) + 1.0,
            True,
        )
def ptx_ldmatrix(A: T.Buffer[(16, 16), "float16"], B: T.Buffer[(16, 16), "float16"], num: T.int32, trans: T.uint8) -> None:
    # Stages the 16x16 buffer A into A_shared, issues one T.ptx_ldmatrix call
    # that targets the 8-element per-thread buffer A_local, and scatters
    # A_local into B. `trans` and `num` are forwarded unchanged as the first
    # two arguments of the intrinsic.
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    bx = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(bx, 1)
    T.launch_thread(tx, 32)
    with T.block():
        A_shared = T.alloc_buffer([16, 16], "float16", scope="shared")
        A_local = T.alloc_buffer([8], "float16", scope="local")
        # 32 threads x 8 iterations cover all 256 elements of A.
        for i in range(8):
            A_shared[i * 2 + tx // 16, tx % 16] = A[i * 2 + tx // 16, tx % 16]
        T.evaluate(
            T.ptx_ldmatrix(
                trans,
                num,
                ".b16",
                A_local.data,
                0,
                A_shared.data,
                16 * (tx % 16) + 8 * (tx // 16),
                dtype="float16",
            ))
        # Write the 8 per-thread values back; (k, j, i) selects element
        # 4 * k + 2 * j + i of A_local.
        for k in range(2):
            for j in range(2):
                for i in range(2):
                    B[8 * j + tx // 4, 8 * k + (tx % 4) * 2 + i] = A_local[4 * k + 2 * j + i]
def ptx_global_to_shared_dyn_copy_fp16x8(
    A: T.Buffer[(32, 128), "float16"],
    B: T.Buffer[(32, 128), "float16"],
    C: T.Buffer[(32, 128), "float16"],
) -> None:
    # Copies A and B into two "shared.dyn" buffers under an "async_scope"
    # attribute, eight vectorized elements at a time, then commits / waits on
    # the group and stores the element-wise sum of the two staged rows into C.
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    bx = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(bx, 1)
    T.launch_thread(tx, 32)
    with T.block():
        A_shared = T.alloc_buffer([32, 128], "float16", scope="shared.dyn")
        B_shared = T.alloc_buffer([32, 128], "float16", scope="shared.dyn")
        T.reads(A[0:32, 0:128], B[0:32, 0:128])
        T.writes(C[0:32, 0:128])
        # NOTE(review): presumably this attribute marks the following copies
        # for lowering to asynchronous copies -- confirm against the pass
        # that consumes "async_scope".
        T.attr("default", "async_scope", 1)
        for i in T.serial(16):
            for j in T.vectorized(8):
                A_shared[tx, i * 8 + j] = A[tx, i * 8 + j]
                B_shared[tx, i * 8 + j] = B[tx, i * 8 + j]
        T.evaluate(T.ptx_commit_group(dtype=""))
        T.evaluate(T.ptx_wait_group(0, dtype=""))
        for i in range(128):
            C[tx, i] = A_shared[tx, i] + B_shared[tx, i]
def main(a: T.handle, b: T.handle) -> None:
    # TVMScript kernel over flat float32 buffers using buffer-index syntax.
    # Each thread zeroes a 64-element B_local tile, accumulates into it from
    # values of A staged (with zero fill outside the guarded range) through
    # Apad_shared and Apad_shared_local, and finally writes the tile to B.
    # Apad_shared is allocated with 512000 elements; see the comment at the
    # store to its last element below.
    # NOTE(review): the attribute key here is "T.noalias" while other functions
    # in this file use "tir.noalias" -- confirm which spelling is intended.
    # function attr dict
    T.func_attr({"global_symbol": "main", "T.noalias": True})
    # var definition
    threadIdx_x = T.env_thread("threadIdx.x")
    threadIdx_y = T.env_thread("threadIdx.y")
    blockIdx_x = T.env_thread("blockIdx.x")
    blockIdx_y = T.env_thread("blockIdx.y")
    blockIdx_z = T.env_thread("blockIdx.z")
    A = T.match_buffer(a, [14 * 14 * 256 * 256], dtype="float32")
    B = T.match_buffer(b, [14 * 14 * 512 * 256], dtype="float32")
    # body
    T.launch_thread(blockIdx_z, 196)
    B_local = T.allocate([64], "float32", "local")
    Apad_shared = T.allocate([512000], "float32", "shared")
    Apad_shared_local = T.allocate([8], "float32", "local")
    T.launch_thread(blockIdx_y, 8)
    T.launch_thread(blockIdx_x, 4)
    T.launch_thread(threadIdx_y, 8)
    T.launch_thread(threadIdx_x, 8)
    # Zero the 8x8 accumulator tile.
    for ff_c_init, nn_c_init in T.grid(8, 8):
        B_local[ff_c_init * 8 + nn_c_init] = T.float32(0)
    for rc_outer, ry, rx in T.grid(32, 3, 3):
        for ax3_inner_outer in T.serial(0, 2):
            Apad_shared[T.ramp(
                threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4)] = T.if_then_else(
                1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15,
                A[T.ramp(
                    ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4)],
                T.broadcast(T.float32(0), 4),
                dtype="float32x4",
            )
            # Access of the last element of Apad_shared prevents
            # buffer compacting from reducing the amount of shared
            # memory used.
            Apad_shared[512000 - 1] = 0.0
        for rc_inner in T.serial(0, 8):
            # Copy 8 staged values into the per-thread scratch buffer.
            for ax3 in T.serial(0, 8):
                Apad_shared_local[ax3] = Apad_shared[rc_inner * 64 + threadIdx_x * 8 + ax3]
            # Add the scratch values into the accumulator tile.
            for ff_c, nn_c in T.grid(8, 8):
                B_local[ff_c * 8 + nn_c] = B_local[ff_c * 8 + nn_c] + Apad_shared_local[nn_c]
    # Write the accumulated tile to the output buffer.
    for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8):
        B[blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner] = B_local[ff_inner_inner_inner * 8 + nn_inner_inner_inner]
# fmt: on
def gemm_mma_m8n8k4_row_row_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle):
    # Single-warp (32 threads, 1x1 blocks) kernel around one T.ptx_mma call
    # with shape string "m8n8k4", "row"/"row" layouts, fp16 inputs and fp32
    # accumulation. Each thread loads 4 values of A and 4 of B, zeroes 8
    # accumulators, runs the intrinsic and scatters the accumulators into C.
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [16, 4], dtype="float16")
    B = T.match_buffer(b, [4, 16], dtype="float16")
    C = T.match_buffer(c, [16, 16], dtype="float32")
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    MultiA = T.allocate([4], "float16", scope="local")
    MultiB = T.allocate([4], "float16", scope="local")
    Accum = T.allocate([8], "float32", scope="local")
    for i in range(8):
        Accum[i] = T.float32(0)
    # Per-thread operand loads; the row / column formulas map the thread index
    # onto the operand elements this thread supplies to the intrinsic.
    for mma_multi_a_col in T.vectorized(4):
        MultiA[mma_multi_a_col] = A[
            ((tx % 32) % 4) + (4 * ((((tx % 32) // 16 + (tx % 32) % 16 // 4 * 2)) % 4)),
            mma_multi_a_col,
        ]
    for mma_multi_b_col in T.vectorized(4):
        MultiB[mma_multi_b_col] = B[
            (tx % 32) % 4,
            mma_multi_b_col + (4 * ((tx % 32) // 8)),
        ]
    T.evaluate(
        T.ptx_mma(
            "m8n8k4",
            "row",
            "row",
            "fp16",
            "fp16",
            "fp32",
            MultiA,
            0,
            MultiB,
            0,
            Accum,
            0,
            False,
            dtype="float32",
        )
    )
    # Scatter the 8 per-thread accumulators into the 16x16 result.
    for mma_accum_c_id in range(8):
        C[
            ((tx % 32) % 2) + ((mma_accum_c_id // 2 % 2) * 2) + 4 * ((tx % 32) // 16) + ((tx % 32) % 16 // 4) % 2 * 8,
            (tx % 32) % 4 // 2 * 2 + (tx % 32) % 16 // 8 * 4 + mma_accum_c_id % 2 + mma_accum_c_id // 4 * 8,
        ] = T.load("float32", Accum, mma_accum_c_id)
def mma_sp_m16n8k16_f16f16f32(a: T.handle, b: T.handle, c: T.handle, _metadata: T.handle):
    # Single-warp (32 threads, 1x1 blocks) kernel around one T.ptx_mma_sp call
    # with shape string "m16n8k16", "row"/"col" layouts, fp16 inputs and fp32
    # accumulation. Each thread loads 4 values of A, 4 values of B and one
    # uint32 metadata word, runs the intrinsic and writes its 4 accumulators
    # into C.
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [16, 8], dtype="float16")
    B = T.match_buffer(b, [16, 8], dtype="float16")
    C = T.match_buffer(c, [16, 8], dtype="float32")
    metadata = T.match_buffer(_metadata, [8], dtype="uint32")
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    multi_a = T.allocate([4], "float16", scope="local")
    multi_b = T.allocate([4], "float16", scope="local")
    accum = T.allocate([4], "float32", scope="local")
    meta_local = T.allocate([1], "uint32", scope="local")
    # accum is a float32 buffer, so it is zeroed with a float32 constant to
    # keep the stored value's dtype consistent with the buffer's dtype.
    for i in range(4):
        accum[i] = T.float32(0)
    # Per-thread operand loads.
    for i in range(4):
        multi_a[i] = A[tx // 4 + i // 2 * 8, tx % 4 * 2 + i % 2]
    for i in range(4):
        multi_b[i] = B[tx % 4 * 2 + i % 2 + i // 2 * 8, tx // 4]
    # One metadata word is shared by each group of four consecutive threads.
    meta_local[0] = metadata[tx // 4]
    T.evaluate(
        T.ptx_mma_sp(
            "m16n8k16",
            "row",
            "col",
            "fp16",
            "fp16",
            "fp32",
            multi_a.data,
            0,
            multi_b.data,
            0,
            accum.data,
            0,
            meta_local.data,
            0,
            0,
            False,
            dtype="float32",
        )
    )
    # Scatter the 4 per-thread accumulators into the 16x8 result.
    for i in range(4):
        C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = accum[i]
def gemm_mma_m16n8k256_row_col_b1b1s32(a: T.handle, b: T.handle, c: T.handle):
    # Single-warp (32 threads, 1x1 blocks) kernel around one T.ptx_mma call
    # with shape string "m16n8k256", "row"/"col" layouts, int1 inputs and
    # int32 accumulation. Each thread loads 128 bits of A and 64 bits of B,
    # zeroes 4 accumulators, runs the intrinsic and writes the accumulators
    # into C.
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [16, 256], dtype="int1")
    B = T.match_buffer(b, [8, 256], dtype="int1")
    C = T.match_buffer(c, [16, 8], dtype="int32")
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    MultiA = T.allocate([128], "int1", scope="local")
    MultiB = T.allocate([64], "int1", scope="local")
    Accum = T.allocate([4], "int32", scope="local")
    for i in range(4):
        Accum[i] = T.int32(0)
    # Fill all 128 per-thread bits of the A operand.
    for mma_multi_a_col in range(128):
        MultiA[mma_multi_a_col] = A[
            (tx % 32) // 4 + mma_multi_a_col % 64 // 32 * 8,
            (tx % 32) % 4 * 32 + mma_multi_a_col % 32 + mma_multi_a_col // 64 * 128,
        ]
    # Fill all 64 per-thread bits of the B operand. The loop must span the
    # whole MultiB allocation: the column formula's `% 32` and `// 32 * 128`
    # terms address two 32-bit groups, and a shorter loop would hand
    # uninitialised bits to the intrinsic.
    for mma_multi_b_col in range(64):
        MultiB[mma_multi_b_col] = B[
            (tx % 32) // 4,
            (tx % 32) % 4 * 32 + mma_multi_b_col % 32 + mma_multi_b_col // 32 * 128,
        ]
    T.evaluate(
        T.ptx_mma(
            "m16n8k256",
            "row",
            "col",
            "int1",
            "int1",
            "int32",
            MultiA.data,
            0,
            MultiB.data,
            0,
            Accum.data,
            0,
            False,
            dtype="int32",
        )
    )
    # Scatter the 4 per-thread accumulators into the 16x8 result.
    for mma_accum_c_id in range(4):
        C[
            (tx % 32) // 4 + mma_accum_c_id // 2 * 8,
            (tx % 32) % 4 * 2 + mma_accum_c_id % 2,
        ] = Accum[mma_accum_c_id]
def gemm_mma_m16n8k32_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle):
    # Single-warp (32 threads, 1x1 blocks) kernel around one T.ptx_mma call
    # with shape string "m16n8k32", "row"/"col" layouts, int8 inputs and int32
    # accumulation. Each thread loads 16 values of A and 8 of B, zeroes 4
    # accumulators, runs the intrinsic and writes the accumulators into C.
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [16, 32], dtype="int8")
    B = T.match_buffer(b, [8, 32], dtype="int8")
    C = T.match_buffer(c, [16, 8], dtype="int32")
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    MultiA = T.allocate([16], "int8", scope="local")
    MultiB = T.allocate([8], "int8", scope="local")
    Accum = T.allocate([4], "int32", scope="local")
    for i in range(4):
        Accum[i] = T.int32(0)
    # Per-thread operand loads.
    for mma_multi_a_col in range(16):
        MultiA[mma_multi_a_col] = A[
            (tx % 32) // 4 + mma_multi_a_col % 8 // 4 * 8,
            (tx % 32) % 4 * 4 + mma_multi_a_col % 4 + mma_multi_a_col // 8 * 16,
        ]
    for mma_multi_b_col in range(8):
        MultiB[mma_multi_b_col] = B[
            (tx % 32) // 4,
            (tx % 32) % 4 * 4 + mma_multi_b_col % 4 + mma_multi_b_col // 4 * 16,
        ]
    T.evaluate(
        T.ptx_mma(
            "m16n8k32",
            "row",
            "col",
            "int8",
            "int8",
            "int32",
            MultiA,
            0,
            MultiB,
            0,
            Accum,
            0,
            False,
            dtype="int32",
        )
    )
    # Scatter the 4 per-thread accumulators into the 16x8 result.
    for mma_accum_c_id in range(4):
        C[
            (tx % 32) // 4 + mma_accum_c_id // 2 * 8,
            (tx % 32) % 4 * 2 + mma_accum_c_id % 2,
        ] = T.load("int32", Accum, mma_accum_c_id)
def gemm_mma_m16n8k16_row_col_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle):
    # Single-warp (32 threads, 1x1 blocks) kernel around one T.ptx_mma call
    # with shape string "m16n8k16", "row"/"col" layouts, fp16 inputs and fp32
    # accumulation. Each thread loads 8 values of A and 4 of B, zeroes 4
    # accumulators, runs the intrinsic and writes the accumulators into C.
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [16, 16], dtype="float16")
    B = T.match_buffer(b, [8, 16], dtype="float16")
    C = T.match_buffer(c, [16, 8], dtype="float32")
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    MultiA = T.allocate([8], "float16", scope="local")
    MultiB = T.allocate([4], "float16", scope="local")
    Accum = T.allocate([4], "float32", scope="local")
    for i in range(4):
        Accum[i] = T.float32(0)
    # Per-thread operand loads.
    for mma_multi_a_col in range(8):
        MultiA[mma_multi_a_col] = A[
            (tx % 32) // 4 + mma_multi_a_col % 4 // 2 * 8,
            (tx % 32) % 4 * 2 + mma_multi_a_col % 2 + mma_multi_a_col // 4 * 8,
        ]
    for mma_multi_b_col in T.vectorized(4):
        MultiB[mma_multi_b_col] = B[
            (tx % 32) // 4,
            (tx % 32) % 4 * 2 + mma_multi_b_col % 2 + mma_multi_b_col // 2 * 8,
        ]
    T.evaluate(
        T.ptx_mma(
            "m16n8k16",
            "row",
            "col",
            "fp16",
            "fp16",
            "fp32",
            MultiA.data,
            0,
            MultiB.data,
            0,
            Accum.data,
            0,
            False,
            dtype="float32",
        )
    )
    # Scatter the 4 per-thread accumulators into the 16x8 result.
    for mma_accum_c_id in range(4):
        C[
            (tx % 32) // 4 + mma_accum_c_id // 2 * 8,
            (tx % 32) % 4 * 2 + mma_accum_c_id % 2,
        ] = Accum[mma_accum_c_id]
def flattened_gpu_func(a: T.handle, c: T.handle) -> None:
    # Flat-index form of a 16x16 element-wise kernel: each (i0, i1, i2)
    # thread combination handles 16 consecutive elements starting at
    # i0 * 64 + i1 * 32 + i2 * 16, computing (A + 1.0) into the local buffer
    # B and then B * 2.0 into C.
    A = T.match_buffer(a, (16, 16), "float32")
    C = T.match_buffer(c, (16, 16), "float32")
    i0 = T.env_thread("blockIdx.x")
    i1 = T.env_thread("threadIdx.x")
    i2 = T.env_thread("vthread")
    T.launch_thread(i0, 4)
    T.launch_thread(i1, 2)
    T.launch_thread(i2, 2)
    B = T.allocate([16], "float32", "local")
    for j in range(0, 16):
        B[j] = T.load("float32", A.data, i0 * 64 + i1 * 32 + i2 * 16 + j) + 1.0
    for j in range(0, 16):
        C.data[i0 * 64 + i1 * 32 + i2 * 16 + j] = T.load("float32", B, j) * 2.0
def gpu_func(a: T.handle, c: T.handle) -> None:
    # Two-dimensional form of a 16x16 element-wise kernel: each
    # (i0, i1, i2) thread combination handles row i0 * 4 + i1 * 2 + i2,
    # computing (A + 1.0) into the local 1x16 buffer B and then B * 2.0
    # into C.
    A = T.match_buffer(a, (16, 16), "float32")
    C = T.match_buffer(c, (16, 16), "float32")
    i0 = T.env_thread("blockIdx.x")
    i1 = T.env_thread("threadIdx.x")
    i2 = T.env_thread("vthread")
    T.launch_thread(i0, 4)
    T.launch_thread(i1, 2)
    T.launch_thread(i2, 2)
    B = T.allocate([1, 16], "float32", "local")
    for j in range(0, 16):
        B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0
    for j in range(0, 16):
        C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0
def mma_store_impl(a: T.handle, c: T.handle) -> None:
    # Implementation body that stores a "warp"-scope buffer C_warp into a
    # strided "global"-scope buffer C through a single T.mma_store call.
    # WARP_SIZE, local_size, dtype, M_DIM and N_DIM are not defined in this
    # function; they are taken from the enclosing scope.
    s0 = T.var("int32")
    s1 = T.var("int32")
    C_warp = T.match_buffer(a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", offset_factor=1)
    # s0 / s1 are symbolic strides bound by the match; only s0 is passed on
    # to the intrinsic below.
    C = T.match_buffer(c, [M_DIM, N_DIM], dtype=dtype, scope="global", offset_factor=1, strides=[s0, s1])
    with T.block("root"):
        T.reads(C_warp[0:WARP_SIZE, 0:local_size])
        T.writes(C[0:M_DIM, 0:N_DIM])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, WARP_SIZE)
        T.evaluate(
            T.mma_store(
                M_DIM,
                N_DIM,
                C.access_ptr("w"),
                C_warp.data,
                C_warp.elem_offset,
                s0,
                dtype=dtype,
            ))
def element_wise_two_thread_x_in_same_kernel_not_equal(a: T.handle, b: T.handle, c: T.handle) -> None:
    # One blockIdx.x launch of extent 128 containing two threadIdx.x scopes
    # with different extents (128 and 64). The first scope stores A * 2.0
    # into the 128-wide buffer B; the second stores A + 1.0 into the 64-wide
    # buffer C.
    i = T.env_thread("blockIdx.x")
    j0 = T.env_thread("threadIdx.x")
    j1 = T.env_thread("threadIdx.x")
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 64])
    T.launch_thread(i, 128)
    with T.launch_thread(j0, 128):
        # B has 128 columns, so its flattened row stride is 128 (same as A);
        # a stride of 64 would make adjacent blocks write overlapping ranges.
        T.store(B.data, i * 128 + j0, T.load("float32", A.data, i * 128 + j0) * 2.0, True)
    T.launch_thread(j1, 64)
    # C has 64 columns, hence the row stride of 64 on the store side.
    T.store(C.data, i * 64 + j1, T.load("float32", A.data, i * 128 + j1) + 1.0, True)
def unified_element_wise_vthread_x(a: T.handle, b: T.handle) -> None:
    # Element-wise B = A * 2.0 over flat indices where a single vthread.x
    # variable is launched twice with extent 2 around one threadIdx.x launch
    # of extent 64. The coefficient 8256 on vthread_x equals 8192 + 64.
    vthread_x = T.env_thread("vthread.x")
    thread_x = T.env_thread("threadIdx.x")
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    T.launch_thread(vthread_x, 2)
    T.launch_thread(thread_x, 64)
    T.launch_thread(vthread_x, 2)
    for j_1 in T.serial(0, 64):
        T.store(
            B.data,
            vthread_x * 8256 + thread_x * 128 + j_1,
            T.load("float32", A.data, vthread_x * 8256 + thread_x * 128 + j_1) * 2.0,
            True,
        )
def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None:
    # Implementation body that loads from a strided shared-memory buffer into
    # a "warp"-scope buffer through a single T.ptx_ldmatrix call.
    # shmem_shape, dtype, shared_scope, row_dim, col_dim, WARP_SIZE,
    # local_size, ldmatrix_col_major, lift and shared_offset are not defined
    # in this function; they are taken from the enclosing scope.
    s0 = T.var("int32")
    s1 = T.var("int32")
    shared = T.match_buffer(
        shared_handle,
        shmem_shape,
        dtype,
        align=128,
        offset_factor=16,
        scope=shared_scope,
        strides=[s0, s1],
    )
    warp = T.match_buffer(warp_handle, (WARP_SIZE, local_size), dtype, align=128, offset_factor=16, scope="warp")
    with T.block("root"):
        T.reads(shared[0:row_dim, 0:col_dim])
        T.writes(warp[0:WARP_SIZE, 0:local_size])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, WARP_SIZE)
        # The destination offset advances by local_size per thread; the source
        # offset is computed by shared_offset from the thread index and the
        # first stride.
        T.evaluate(
            T.ptx_ldmatrix(
                ldmatrix_col_major,
                4,  # Always load 4 matrices
                ".b16",
                warp.data,
                warp.elem_offset + lift(local_size) * tx,
                shared.access_ptr("r"),
                shared_offset(tx, s0),
                dtype=dtype,
            ))
def before(A: T.Buffer[(128, 16), "float32"], n: T.int32):
    # Input form: threadIdx.x is bound with T.env_thread / T.launch_thread and
    # the `i < 32` guard sits inside the serial loop over j. The parameter n
    # is not referenced in the body.
    i = T.env_thread("threadIdx.x")
    T.launch_thread(i, 128)
    for j in T.serial(16):
        if i < 32:
            A[i, j] = 0.0
def element_wise_vthread_x(a: T.handle, b: T.handle) -> None:
    # Element-wise B = A * 2.0 over flat indices using two distinct vthread.x
    # variables (i_0 and j_0, extent 2 each) around one threadIdx.x launch of
    # extent 64.
    i_0 = T.env_thread("vthread.x")
    i_1 = T.env_thread("threadIdx.x")
    j_0 = T.env_thread("vthread.x")
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    T.launch_thread(i_0, 2)
    T.launch_thread(i_1, 64)
    T.launch_thread(j_0, 2)
    for j_1 in T.serial(0, 64):
        T.store(
            B.data,
            i_0 * 8192 + i_1 * 128 + j_0 * 64 + j_1,
            T.load("float32", A.data, i_0 * 8192 + i_1 * 128 + j_0 * 64 + j_1) * 2.0,
            True,
        )
def element_wise_env_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None:
    # Two-stage element-wise computation bound through env threads: the first
    # loop stores A * 2.0 into B using threadIdx.x variable j0_0, the second
    # stores B + 1.0 into C using a separate threadIdx.x variable j1_0. Both
    # are launched with extent 4 under a blockIdx.x launch of extent 128.
    j1_0 = T.env_thread("threadIdx.x")
    j0_0 = T.env_thread("threadIdx.x")
    i = T.env_thread("blockIdx.x")
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    T.launch_thread(i, 128)
    T.launch_thread(j0_0, 4)
    T.launch_thread(j1_0, 4)
    for j0_1 in T.serial(0, 32):
        with T.block(""):
            B[i, j0_0 * 32 + j0_1] = A[i, j0_0 * 32 + j0_1] * 2.0
    for j1_1 in T.serial(0, 32):
        with T.block(""):
            C[i, j1_0 * 32 + j1_1] = B[i, j1_0 * 32 + j1_1] + 1.0
def expected(A: T.Buffer[(128, 16), "float32"], n: T.int32):
    # Reference form: the `n == 0` condition wraps the whole threadIdx.x
    # thread_binding loop nest that zeroes A. An env thread thread_x is also
    # launched with extent 128 but is not used in the indices.
    thread_x = T.env_thread("threadIdx.x")
    T.launch_thread(thread_x, 128)
    if n == 0:
        for i in T.thread_binding(0, 128, thread="threadIdx.x"):
            for j in T.serial(16):
                A[i, j] = 0.0
def before(A: T.Buffer[(128, 16), "float32"], n: T.int32):
    # Input form: the `i < 32` guard sits between the threadIdx.x
    # thread_binding loop and the serial loop over j. An env thread thread_x
    # is also launched with extent 128 but is not used in the indices, and
    # the parameter n is not referenced in the body.
    thread_x = T.env_thread("threadIdx.x")
    T.launch_thread(thread_x, 128)
    for i in T.thread_binding(0, 128, thread="threadIdx.x"):
        if i < 32:
            for j in T.serial(16):
                A[i, j] = 0.0
def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
    # Implementation body that issues two T.ptx_mma calls over "warp"-scope
    # buffers A, B and C. WARP_SIZE, local_size, local_size_out, in_dtype,
    # out_dtype, mma_prefix, in_dtype_abbrv, out_dtype_abbrv and lift are not
    # defined in this function; they are taken from the enclosing scope.
    A = T.match_buffer(
        a, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
    )
    B = T.match_buffer(
        b, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
    )
    C = T.match_buffer(
        c, (WARP_SIZE, local_size_out), out_dtype, align=128, offset_factor=16, scope="warp"
    )
    with T.block("root"):
        # C appears in both reads and writes: it is the accumulator.
        T.reads(
            C[0:WARP_SIZE, 0:local_size_out],
            A[0:WARP_SIZE, 0:local_size],
            B[0:WARP_SIZE, 0:local_size],
        )
        T.writes(C[0:WARP_SIZE, 0:local_size_out])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, WARP_SIZE)
        # First call: offsets start at each thread's slice of A, B and C.
        T.evaluate(
            T.ptx_mma(
                mma_prefix,
                "row",
                "col",
                in_dtype_abbrv,
                in_dtype_abbrv,
                out_dtype_abbrv,
                A.data,
                A.elem_offset + tx * lift(local_size),
                B.data,
                B.elem_offset + tx * lift(local_size),
                C.data,
                C.elem_offset + tx * lift(local_size_out),
                False,
                dtype=out_dtype,
            )
        )
        # Second call: same A offset, with the B and C offsets advanced by
        # half of local_size and half of local_size_out respectively.
        T.evaluate(
            T.ptx_mma(
                mma_prefix,
                "row",
                "col",
                in_dtype_abbrv,
                in_dtype_abbrv,
                out_dtype_abbrv,
                A.data,
                A.elem_offset + tx * lift(local_size),
                B.data,
                B.elem_offset + tx * lift(local_size) + lift(local_size) // 2,
                C.data,
                C.elem_offset + tx * lift(local_size_out) + lift(local_size_out) // 2,
                False,
                dtype=out_dtype,
            )
        )
def flattened_gpu_func(a: T.handle, c: T.handle) -> None:
    # Flat-buffer form of a 16x16 element-wise kernel: A and C are matched as
    # 256-element buffers and their (16, 16) pre-flattened shapes are declared
    # separately. Each (i0, i1, i2) thread combination handles 16 consecutive
    # elements, computing (A + 1.0) into the local buffer B and then B * 2.0
    # into C.
    A = T.match_buffer(a, 256, "float32")
    C = T.match_buffer(c, 256, "float32")
    T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data)
    T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data)
    i0 = T.env_thread("blockIdx.x")
    i1 = T.env_thread("threadIdx.x")
    i2 = T.env_thread("vthread")
    T.launch_thread(i0, 4)
    T.launch_thread(i1, 2)
    T.launch_thread(i2, 2)
    B = T.allocate([16], "float32", "local")
    for j in range(0, 16):
        B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + 1.0
    for j in range(0, 16):
        C[i0 * 64 + i1 * 32 + i2 * 16 + j] = B[j] * 2.0
def element_wise_kernels_with_different_size(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None:
    # Two independent launch regions with different extents: a 128 x 128
    # region storing A * 2.0 into B, followed by a 256 x 256 region storing
    # C + 1.0 into D. Each region has its own blockIdx.x / threadIdx.x
    # variables (i0, j0) and (i1, j1).
    i0 = T.env_thread("blockIdx.x")
    j0 = T.env_thread("threadIdx.x")
    i1 = T.env_thread("blockIdx.x")
    j1 = T.env_thread("threadIdx.x")
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [256, 256])
    D = T.match_buffer(d, [256, 256])
    # First region: the `with` form scopes i0 / j0 to this store.
    with T.launch_thread(i0, 128):
        T.launch_thread(j0, 128)
        T.store(B.data, i0 * 128 + j0, T.load("float32", A.data, i0 * 128 + j0) * 2.0, True)
    # Second region: covers the remainder of the function body.
    T.launch_thread(i1, 256)
    T.launch_thread(j1, 256)
    T.store(D.data, i1 * 256 + j1, T.load("float32", C.data, i1 * 256 + j1) + 1.0, True)
def gemm_mma_m8n8k32_row_col_s4u4s32(a: T.handle, b: T.handle, c: T.handle):
    # Single-warp (32 threads, 1x1 blocks) kernel around one T.ptx_mma call
    # with shape string "m8n8k32", "row"/"col" layouts, int4 x uint4 inputs
    # and int32 accumulation. Each thread loads 8 values of A and 8 of B,
    # zeroes 2 accumulators, runs the intrinsic and writes the accumulators
    # into C.
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [8, 32], dtype="int4")
    B = T.match_buffer(b, [8, 32], dtype="uint4")
    C = T.match_buffer(c, [8, 8], dtype="int32")
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    MultiA = T.allocate([8], "int4", scope="local")
    MultiB = T.allocate([8], "uint4", scope="local")
    Accum = T.allocate([2], "int32", scope="local")
    for i in range(2):
        Accum[i] = T.int32(0)
    # Per-thread operand loads: row tx // 4, eight consecutive columns
    # starting at (tx % 4) * 8.
    for mma_multi_a_col in T.vectorized(8):
        MultiA[mma_multi_a_col] = A[(tx % 32) // 4, mma_multi_a_col + (tx % 32) % 4 * 8]
    for mma_multi_b_col in T.vectorized(8):
        MultiB[mma_multi_b_col] = B[(tx % 32) // 4, mma_multi_b_col + (tx % 32) % 4 * 8]
    T.evaluate(
        T.ptx_mma(
            "m8n8k32",
            "row",
            "col",
            "int4",
            "uint4",
            "int32",
            MultiA,
            0,
            MultiB,
            0,
            Accum,
            0,
            False,
            dtype="int32",
        )
    )
    # Scatter the 2 per-thread accumulators into the 8x8 result.
    for mma_accum_c_id in range(2):
        C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = T.load(
            "int32", Accum, mma_accum_c_id
        )
def gemm_mma_m8n8k4_row_col_fp64pf64fp64(a: T.handle, b: T.handle, c: T.handle):
    # Single-warp (32 threads, 1x1 blocks) kernel around one T.ptx_mma call
    # with shape string "m8n8k4", "row"/"col" layouts and fp64 inputs and
    # accumulation. Each thread loads one value of A and one of B, zeroes 2
    # accumulators, runs the intrinsic and writes the accumulators into C.
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [8, 4], dtype="float64")
    B = T.match_buffer(b, [8, 4], dtype="float64")
    C = T.match_buffer(c, [8, 8], dtype="float64")
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    MultiA = T.allocate([1], "float64", scope="local")
    MultiB = T.allocate([1], "float64", scope="local")
    Accum = T.allocate([2], "float64", scope="local")
    for i in range(2):
        Accum[i] = T.float64(0)
    # Per-thread operand loads: element (tx // 4, tx % 4) of each input.
    MultiA[0] = A[(tx % 32) // 4, (tx % 32) % 4]
    MultiB[0] = B[(tx % 32) // 4, (tx % 32) % 4]
    T.evaluate(
        T.ptx_mma(
            "m8n8k4",
            "row",
            "col",
            "fp64",
            "fp64",
            "fp64",
            MultiA,
            0,
            MultiB,
            0,
            Accum,
            0,
            False,
            dtype="float64",
        )
    )
    # Scatter the 2 per-thread accumulators into the 8x8 result.
    for mma_accum_c_id in range(2):
        C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = T.load(
            "float64", Accum, mma_accum_c_id
        )
def element_wise_env_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None:
    # Two-stage element-wise computation (B = A * 2.0, then C = B + 1.0).
    # Env threads i, j0_0 and j1_0 are declared and launched, but the buffer
    # indices use the T.thread_binding loop variables blockIdx_x and
    # threadIdx_x instead of those env-thread variables.
    j1_0 = T.env_thread("threadIdx.x")
    j0_0 = T.env_thread("threadIdx.x")
    i = T.env_thread("blockIdx.x")
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    T.launch_thread(i, 128)
    T.launch_thread(j0_0, 4)
    T.launch_thread(j1_0, 4)
    for blockIdx_x in T.thread_binding(0, 128, "blockIdx.x"):
        for threadIdx_x in T.thread_binding(0, 4, "threadIdx.x"):
            for j0_1 in T.serial(0, 32):
                with T.block(""):
                    B[blockIdx_x, threadIdx_x * 32 + j0_1] = (A[blockIdx_x, threadIdx_x * 32 + j0_1] * 2.0)
            for j1_1 in T.serial(0, 32):
                with T.block(""):
                    C[blockIdx_x, threadIdx_x * 32 + j1_1] = (B[blockIdx_x, threadIdx_x * 32 + j1_1] + 1.0)