def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
    # Issues two T.ptx_mma calls over three "warp"-scope buffers A, B and C.
    #
    # WARP_SIZE, local_size, local_size_out, in_dtype, out_dtype, mma_prefix,
    # in_dtype_abbrv, out_dtype_abbrv and lift are not defined in this function;
    # they are taken from the enclosing scope (not visible in this block).
    #
    # a, b: handles matched as (WARP_SIZE, local_size) buffers of in_dtype.
    # c:    handle matched as a (WARP_SIZE, local_size_out) buffer of out_dtype.
    A = T.match_buffer(
        a, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
    )
    B = T.match_buffer(
        b, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
    )
    C = T.match_buffer(
        c, (WARP_SIZE, local_size_out), out_dtype, align=128, offset_factor=16, scope="warp"
    )

    with T.block("root"):
        # C is declared in both the read and the write region: the block
        # reads C as well as A and B, and writes C.
        T.reads(
            C[0:WARP_SIZE, 0:local_size_out],
            A[0:WARP_SIZE, 0:local_size],
            B[0:WARP_SIZE, 0:local_size],
        )
        T.writes(C[0:WARP_SIZE, 0:local_size_out])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, WARP_SIZE)

        # First call: every operand is offset by tx times its per-thread row
        # length (local_size for A and B, local_size_out for C).
        # NOTE(review): the trailing positional False is presumably the
        # saturate flag of T.ptx_mma -- confirm against its signature.
        T.evaluate(
            T.ptx_mma(
                mma_prefix,
                "row",
                "col",
                in_dtype_abbrv,
                in_dtype_abbrv,
                out_dtype_abbrv,
                A.data,
                A.elem_offset + tx * lift(local_size),
                B.data,
                B.elem_offset + tx * lift(local_size),
                C.data,
                C.elem_offset + tx * lift(local_size_out),
                False,
                dtype=out_dtype,
            )
        )

        # Second call: same A offset as above, but the B and C offsets are
        # advanced by half of local_size / local_size_out respectively.
        T.evaluate(
            T.ptx_mma(
                mma_prefix,
                "row",
                "col",
                in_dtype_abbrv,
                in_dtype_abbrv,
                out_dtype_abbrv,
                A.data,
                A.elem_offset + tx * lift(local_size),
                B.data,
                B.elem_offset + tx * lift(local_size) + lift(local_size) // 2,
                C.data,
                C.elem_offset + tx * lift(local_size_out) + lift(local_size_out) // 2,
                False,
                dtype=out_dtype,
            )
        )
def gemm_mma_m8n8k4_row_row_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle):
    # Loads per-thread pieces of A (16x4, float16) and B (4x16, float16), issues
    # one T.ptx_mma call with shape "m8n8k4" and "row"/"row" layouts, and stores
    # the 8-element float32 accumulator into C (16x16, float32).
    #
    # NOTE(review): this kernel passes MultiA/MultiB/Accum directly to
    # T.ptx_mma and reads the result with T.load, while other kernels in this
    # file pass `.data` and index the buffer directly -- confirm which form
    # matches the targeted TVM version.
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [16, 4], dtype="float16")
    B = T.match_buffer(b, [4, 16], dtype="float16")
    C = T.match_buffer(c, [16, 16], dtype="float32")
    # One block in each grid dimension, 32 threads along threadIdx.x.
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    # "local"-scope staging buffers for the operands and the accumulator.
    MultiA = T.allocate([4], "float16", scope="local")
    MultiB = T.allocate([4], "float16", scope="local")
    Accum = T.allocate([8], "float32", scope="local")
    # The accumulator starts at zero.
    for i in range(8):
        Accum[i] = T.float32(0)

    # Each thread copies 4 consecutive columns of one A row; the row index is
    # derived from tx only.
    for mma_multi_a_col in T.vectorized(4):
        MultiA[mma_multi_a_col] = A[
            ((tx % 32) % 4) + (4 * ((((tx % 32) // 16 + (tx % 32) % 16 // 4 * 2)) % 4)),
            mma_multi_a_col,
        ]
    # Each thread copies 4 consecutive columns of B row (tx % 4), starting at
    # column 4 * (tx // 8).
    for mma_multi_b_col in T.vectorized(4):
        MultiB[mma_multi_b_col] = B[
            (tx % 32) % 4,
            mma_multi_b_col + (4 * ((tx % 32) // 8)),
        ]
    # NOTE(review): the trailing positional False is presumably the saturate
    # flag of T.ptx_mma -- confirm against its signature.
    T.evaluate(
        T.ptx_mma(
            "m8n8k4",
            "row",
            "row",
            "fp16",
            "fp16",
            "fp32",
            MultiA,
            0,
            MultiB,
            0,
            Accum,
            0,
            False,
            dtype="float32",
        )
    )
    # Scatter the 8 accumulator values into C; both the row and the column
    # index combine tx with the accumulator element id.
    for mma_accum_c_id in range(8):
        C[
            ((tx % 32) % 2)
            + ((mma_accum_c_id // 2 % 2) * 2)
            + 4 * ((tx % 32) // 16)
            + ((tx % 32) % 16 // 4) % 2 * 8,
            (tx % 32) % 4 // 2 * 2
            + (tx % 32) % 16 // 8 * 4
            + mma_accum_c_id % 2
            + mma_accum_c_id // 4 * 8,
        ] = T.load("float32", Accum, mma_accum_c_id)
def gemm_mma_m16n8k32_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle):
    # Loads per-thread pieces of A (16x32, int8) and B (8x32, int8), issues one
    # T.ptx_mma call with shape "m16n8k32" and "row"/"col" layouts, and stores
    # the 4-element int32 accumulator into C (16x8, int32).
    #
    # NOTE(review): this kernel passes MultiA/MultiB/Accum directly to
    # T.ptx_mma and reads the result with T.load, while other kernels in this
    # file pass `.data` and index the buffer directly -- confirm which form
    # matches the targeted TVM version.
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [16, 32], dtype="int8")
    B = T.match_buffer(b, [8, 32], dtype="int8")
    C = T.match_buffer(c, [16, 8], dtype="int32")
    # One block in each grid dimension, 32 threads along threadIdx.x.
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    # "local"-scope staging buffers for the operands and the accumulator.
    MultiA = T.allocate([16], "int8", scope="local")
    MultiB = T.allocate([8], "int8", scope="local")
    Accum = T.allocate([4], "int32", scope="local")
    # The accumulator starts at zero.
    for i in range(4):
        Accum[i] = T.int32(0)

    # 16 A elements per thread: rows tx // 4 and tx // 4 + 8, taken in groups
    # of 4 consecutive columns starting at (tx % 4) * 4, plus 16 for the
    # upper half of the loop range.
    for mma_multi_a_col in range(16):
        MultiA[mma_multi_a_col] = A[
            (tx % 32) // 4 + mma_multi_a_col % 8 // 4 * 8,
            (tx % 32) % 4 * 4 + mma_multi_a_col % 4 + mma_multi_a_col // 8 * 16,
        ]
    # 8 B elements per thread: row tx // 4, two groups of 4 consecutive
    # columns that are 16 columns apart.
    for mma_multi_b_col in range(8):
        MultiB[mma_multi_b_col] = B[
            (tx % 32) // 4,
            (tx % 32) % 4 * 4 + mma_multi_b_col % 4 + mma_multi_b_col // 4 * 16,
        ]
    # NOTE(review): the trailing positional False is presumably the saturate
    # flag of T.ptx_mma -- confirm against its signature.
    T.evaluate(
        T.ptx_mma(
            "m16n8k32",
            "row",
            "col",
            "int8",
            "int8",
            "int32",
            MultiA,
            0,
            MultiB,
            0,
            Accum,
            0,
            False,
            dtype="int32",
        )
    )
    # Accumulator elements 0-1 go to row tx // 4, elements 2-3 to row
    # tx // 4 + 8; the column is (tx % 4) * 2 plus the element parity.
    for mma_accum_c_id in range(4):
        C[
            (tx % 32) // 4 + mma_accum_c_id // 2 * 8,
            (tx % 32) % 4 * 2 + mma_accum_c_id % 2,
        ] = T.load("int32", Accum, mma_accum_c_id)
def gemm_mma_m16n8k16_row_col_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle):
    # Loads per-thread pieces of A (16x16, float16) and B (8x16, float16),
    # issues one T.ptx_mma call with shape "m16n8k16" and "row"/"col" layouts,
    # and stores the 4-element float32 accumulator into C (16x8, float32).
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [16, 16], dtype="float16")
    B = T.match_buffer(b, [8, 16], dtype="float16")
    C = T.match_buffer(c, [16, 8], dtype="float32")
    # One block in each grid dimension, 32 threads along threadIdx.x.
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    # "local"-scope staging buffers for the operands and the accumulator.
    MultiA = T.allocate([8], "float16", scope="local")
    MultiB = T.allocate([4], "float16", scope="local")
    Accum = T.allocate([4], "float32", scope="local")
    # The accumulator starts at zero.
    for i in range(4):
        Accum[i] = T.float32(0)

    # 8 A elements per thread: rows tx // 4 and tx // 4 + 8, taken in pairs of
    # consecutive columns starting at (tx % 4) * 2, plus 8 for the upper half
    # of the loop range.
    for mma_multi_a_col in range(8):
        MultiA[mma_multi_a_col] = A[
            (tx % 32) // 4 + mma_multi_a_col % 4 // 2 * 8,
            (tx % 32) % 4 * 2 + mma_multi_a_col % 2 + mma_multi_a_col // 4 * 8,
        ]
    # 4 B elements per thread: row tx // 4, two pairs of consecutive columns
    # that are 8 columns apart.
    for mma_multi_b_col in T.vectorized(4):
        MultiB[mma_multi_b_col] = B[
            (tx % 32) // 4,
            (tx % 32) % 4 * 2 + mma_multi_b_col % 2 + mma_multi_b_col // 2 * 8,
        ]
    # Operands are passed as (data pointer, element offset) pairs.
    # NOTE(review): the trailing positional False is presumably the saturate
    # flag of T.ptx_mma -- confirm against its signature.
    T.evaluate(
        T.ptx_mma(
            "m16n8k16",
            "row",
            "col",
            "fp16",
            "fp16",
            "fp32",
            MultiA.data,
            0,
            MultiB.data,
            0,
            Accum.data,
            0,
            False,
            dtype="float32",
        )
    )
    # Accumulator elements 0-1 go to row tx // 4, elements 2-3 to row
    # tx // 4 + 8; the column is (tx % 4) * 2 plus the element parity.
    for mma_accum_c_id in range(4):
        C[
            (tx % 32) // 4 + mma_accum_c_id // 2 * 8,
            (tx % 32) % 4 * 2 + mma_accum_c_id % 2,
        ] = Accum[mma_accum_c_id]
def gemm_mma_m16n8k256_row_col_b1b1s32(a: T.handle, b: T.handle, c: T.handle):
    # Loads per-thread pieces of A (16x256, int1) and B (8x256, int1), issues
    # one T.ptx_mma call with shape "m16n8k256" and "row"/"col" layouts, and
    # stores the 4-element int32 accumulator into C (16x8, int32).
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [16, 256], dtype="int1")
    B = T.match_buffer(b, [8, 256], dtype="int1")
    C = T.match_buffer(c, [16, 8], dtype="int32")
    # One block in each grid dimension, 32 threads along threadIdx.x.
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    # "local"-scope staging buffers for the operands and the accumulator.
    MultiA = T.allocate([128], "int1", scope="local")
    MultiB = T.allocate([64], "int1", scope="local")
    Accum = T.allocate([4], "int32", scope="local")
    # The accumulator starts at zero.
    for i in range(4):
        Accum[i] = T.int32(0)

    # 128 A elements per thread: rows tx // 4 and tx // 4 + 8, taken in runs of
    # 32 consecutive columns starting at (tx % 4) * 32, plus 128 for the upper
    # half of the loop range.
    for mma_multi_a_col in range(128):
        MultiA[mma_multi_a_col] = A[
            (tx % 32) // 4 + mma_multi_a_col % 64 // 32 * 8,
            (tx % 32) % 4 * 32 + mma_multi_a_col % 32 + mma_multi_a_col // 64 * 128,
        ]
    # Fill all 64 entries of MultiB. The `// 32 * 128` term only contributes
    # for indices >= 32, and the whole buffer is handed to ptx_mma, so a
    # shorter loop would leave part of the B operand unwritten.
    for mma_multi_b_col in range(64):
        MultiB[mma_multi_b_col] = B[
            (tx % 32) // 4,
            (tx % 32) % 4 * 32 + mma_multi_b_col % 32 + mma_multi_b_col // 32 * 128,
        ]
    # Operands are passed as (data pointer, element offset) pairs.
    # NOTE(review): the trailing positional False is presumably the saturate
    # flag of T.ptx_mma -- confirm against its signature.
    T.evaluate(
        T.ptx_mma(
            "m16n8k256",
            "row",
            "col",
            "int1",
            "int1",
            "int32",
            MultiA.data,
            0,
            MultiB.data,
            0,
            Accum.data,
            0,
            False,
            dtype="int32",
        )
    )
    # Accumulator elements 0-1 go to row tx // 4, elements 2-3 to row
    # tx // 4 + 8; the column is (tx % 4) * 2 plus the element parity.
    for mma_accum_c_id in range(4):
        C[
            (tx % 32) // 4 + mma_accum_c_id // 2 * 8,
            (tx % 32) % 4 * 2 + mma_accum_c_id % 2,
        ] = Accum[mma_accum_c_id]
def gemm_mma_m8n8k32_row_col_s4u4s32(a: T.handle, b: T.handle, c: T.handle):
    # Loads per-thread pieces of A (8x32, int4) and B (8x32, uint4), issues one
    # T.ptx_mma call with shape "m8n8k32" and "row"/"col" layouts, and stores
    # the 2-element int32 accumulator into C (8x8, int32).
    #
    # NOTE(review): this kernel passes MultiA/MultiB/Accum directly to
    # T.ptx_mma and reads the result with T.load, while other kernels in this
    # file pass `.data` and index the buffer directly -- confirm which form
    # matches the targeted TVM version.
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [8, 32], dtype="int4")
    B = T.match_buffer(b, [8, 32], dtype="uint4")
    C = T.match_buffer(c, [8, 8], dtype="int32")
    # One block in each grid dimension, 32 threads along threadIdx.x.
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    # "local"-scope staging buffers for the operands and the accumulator.
    MultiA = T.allocate([8], "int4", scope="local")
    MultiB = T.allocate([8], "uint4", scope="local")
    Accum = T.allocate([2], "int32", scope="local")
    # The accumulator starts at zero.
    for i in range(2):
        Accum[i] = T.int32(0)

    # Each thread copies 8 consecutive columns of row tx // 4, starting at
    # column (tx % 4) * 8, from A and from B.
    for mma_multi_a_col in T.vectorized(8):
        MultiA[mma_multi_a_col] = A[(tx % 32) // 4, mma_multi_a_col + (tx % 32) % 4 * 8]
    for mma_multi_b_col in T.vectorized(8):
        MultiB[mma_multi_b_col] = B[(tx % 32) // 4, mma_multi_b_col + (tx % 32) % 4 * 8]
    # NOTE(review): the trailing positional False is presumably the saturate
    # flag of T.ptx_mma -- confirm against its signature.
    T.evaluate(
        T.ptx_mma(
            "m8n8k32",
            "row",
            "col",
            "int4",
            "uint4",
            "int32",
            MultiA,
            0,
            MultiB,
            0,
            Accum,
            0,
            False,
            dtype="int32",
        )
    )
    # Each thread writes its two accumulator values to row tx // 4, columns
    # (tx % 4) * 2 and (tx % 4) * 2 + 1 of C.
    for mma_accum_c_id in range(2):
        C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = T.load(
            "int32", Accum, mma_accum_c_id
        )
def gemm_mma_m8n8k4_row_col_fp64pf64fp64(a: T.handle, b: T.handle, c: T.handle):
    # Loads one element each of A (8x4, float64) and B (8x4, float64) per
    # thread, issues one T.ptx_mma call with shape "m8n8k4" and "row"/"col"
    # layouts, and stores the 2-element float64 accumulator into C (8x8).
    #
    # NOTE(review): "pf64" in the function name looks like a typo for "fp64";
    # it is left unchanged because the name is this function's identifier.
    # NOTE(review): this kernel passes MultiA/MultiB/Accum directly to
    # T.ptx_mma and reads the result with T.load, while other kernels in this
    # file pass `.data` and index the buffer directly -- confirm which form
    # matches the targeted TVM version.
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [8, 4], dtype="float64")
    B = T.match_buffer(b, [8, 4], dtype="float64")
    C = T.match_buffer(c, [8, 8], dtype="float64")
    # One block in each grid dimension, 32 threads along threadIdx.x.
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    # "local"-scope staging buffers for the operands and the accumulator.
    MultiA = T.allocate([1], "float64", scope="local")
    MultiB = T.allocate([1], "float64", scope="local")
    Accum = T.allocate([2], "float64", scope="local")
    # The accumulator starts at zero.
    for i in range(2):
        Accum[i] = T.float64(0)

    # Each thread reads element [tx // 4, tx % 4] from both A and B.
    MultiA[0] = A[(tx % 32) // 4, (tx % 32) % 4]
    MultiB[0] = B[(tx % 32) // 4, (tx % 32) % 4]
    # NOTE(review): the trailing positional False is presumably the saturate
    # flag of T.ptx_mma -- confirm against its signature.
    T.evaluate(
        T.ptx_mma(
            "m8n8k4",
            "row",
            "col",
            "fp64",
            "fp64",
            "fp64",
            MultiA,
            0,
            MultiB,
            0,
            Accum,
            0,
            False,
            dtype="float64",
        )
    )
    # Each thread writes its two accumulator values to row tx // 4, columns
    # (tx % 4) * 2 and (tx % 4) * 2 + 1 of C.
    for mma_accum_c_id in range(2):
        C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = T.load(
            "float64", Accum, mma_accum_c_id
        )