def warp_memory(a: ty.handle, c: ty.handle) -> None:
    """Two-stage elementwise pipeline staged through a "warp"-scoped buffer.

    Block "B" computes B[vj, warp_id, lane_id] = A[warp_id*32 + lane_id, vj] * 2.0;
    block "C" reads it back as C[warp_id*32 + lane_id, vj] = B[...] + 1.0.
    B's trailing axes [4, 32] line up with the threadIdx.y (4) and
    threadIdx.x (32) bindings.
    NOTE(review): looks like a TVM-script test fixture for warp-memory
    lowering — confirm against the surrounding test module.
    """
    A = tir.match_buffer(a, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    # Intermediate buffer in "warp" scope; first axis (128) is the serial j axis,
    # the 4x32 tail matches the thread layout below.
    B = tir.alloc_buffer([128, 4, 32], scope="warp")
    for i_o in tir.thread_binding(0, 4, thread="threadIdx.y"):
        for i_i in tir.thread_binding(0, 32, thread="threadIdx.x"):
            # Producer: each thread writes its own (warp_id, lane_id) slice of B.
            for j in tir.serial(0, 128):
                with tir.block([4, 32, 128], "B") as [warp_id, lane_id, vj]:
                    B[vj, warp_id, lane_id] = A[warp_id * 32 + lane_id, vj] * 2.0
            # Consumer: reads back the same slice, writes the result to C.
            for j in tir.serial(0, 128):
                with tir.block([4, 32, 128], "C") as [warp_id, lane_id, vj]:
                    C[warp_id * 32 + lane_id, vj] = B[vj, warp_id, lane_id] + 1.0
def equal_ranked_threads(a: ty.handle, c: ty.handle) -> None:
    """Shared-memory staging pipeline where one logical axis (128) is split
    across two thread bindings: threadIdx.x (16, outer) * threadIdx.y (8, inner),
    with vi explicitly bound to i_o * 8 + i_i in both blocks.

    Block "B": B[vi, vj] = A[vi, vj] * 2.0; block "C": C[vj, vi] = B[vj, vi] + 1.0
    (note the transposed access in "C").
    NOTE(review): presumably a test fixture exercising thread bindings of
    equal rank — confirm against the surrounding test module.
    """
    A = tir.match_buffer(a, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    B = tir.alloc_buffer([128, 128], scope="shared")
    for i_o in tir.thread_binding(0, 16, thread="threadIdx.x"):
        for i_i in tir.thread_binding(0, 8, thread="threadIdx.y"):
            # Producer block: vi fused from the two thread axes.
            for j in tir.serial(0, 128):
                with tir.block([128, 128], "B") as [vi, vj]:
                    tir.bind(vi, i_o * 8 + i_i)
                    tir.bind(vj, j)
                    B[vi, vj] = A[vi, vj] * 2.0
            # Consumer block: same binding, but indices are transposed on use.
            for j in tir.serial(0, 128):
                with tir.block([128, 128], "C") as [vi, vj]:
                    tir.bind(vi, i_o * 8 + i_i)
                    tir.bind(vj, j)
                    C[vj, vi] = B[vj, vi] + 1.0
def compacted_gpu_func(a: ty.handle, c: ty.handle) -> None:
    """GPU kernel with a compacted "local"-scope cache buffer.

    Each (blockIdx.x, threadIdx.x, vthread) triple owns one row
    i0*4 + i1*2 + i2 of the 16x16 buffers: B caches A's row plus 1.0,
    then C receives B * 2.0.  The opaque blocks (`tir.block([])` /
    `tir.block()`) carry explicit reads/writes annotations, and B is
    compacted to shape [1, 16] — one row per thread.
    NOTE(review): presumably a fixture for buffer-compaction lowering —
    confirm against the surrounding test module.
    """
    A = tir.match_buffer(a, (16, 16), "float32")
    C = tir.match_buffer(c, (16, 16), "float32")
    for i0 in tir.thread_binding(0, 4, thread="blockIdx.x"):
        for i1 in tir.thread_binding(0, 2, thread="threadIdx.x"):
            for i2 in tir.thread_binding(0, 2, thread="vthread"):
                # Opaque outer block declaring this thread's row-level footprint.
                with tir.block([]):
                    tir.reads(A[i0 * 4 + i1 * 2 + i2, 0:16])
                    tir.writes(C[i0 * 4 + i1 * 2 + i2, 0:16])
                    # Compacted per-thread cache: a single row of 16 elements.
                    B = tir.alloc_buffer([1, 16], "float32", scope="local")
                    # Stage A's row into the local cache (+1.0).
                    for j in range(0, 16):
                        with tir.block() as []:
                            tir.reads(A[i0 * 4 + i1 * 2 + i2, j])
                            tir.writes(B[0, j])
                            B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0
                    # Write the doubled cache row out to C.
                    for j in range(0, 16):
                        with tir.block() as []:
                            tir.reads(B[0, j])
                            tir.writes(C[i0 * 4 + i1 * 2 + i2, j])
                            C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0
def bound_to_thread(a: ty.handle, c: ty.handle) -> None:
    """Shared-memory staging pipeline with the full i axis (128) bound to
    threadIdx.x.

    Block "B": B[vi, vj] = A[vi, vj] * 2.0; block "C" reads it transposed:
    C[vj, vi] = B[vj, vi] + 1.0.  Block vars rely on TVM script's implicit
    binding (no explicit tir.bind here).
    NOTE(review): presumably a test fixture — confirm against the
    surrounding test module.
    """
    A = tir.match_buffer(a, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    B = tir.alloc_buffer([128, 128], scope="shared")
    for i in tir.thread_binding(0, 128, thread="threadIdx.x"):
        # Producer: one row of B per thread.
        for j in tir.serial(0, 128):
            with tir.block([128, 128], "B") as [vi, vj]:
                B[vi, vj] = A[vi, vj] * 2.0
        # Consumer: transposed access relative to the producer.
        for j in tir.serial(0, 128):
            with tir.block([128, 128], "C") as [vi, vj]:
                C[vj, vi] = B[vj, vi] + 1.0
def elementwise_with_thread_binding(a: ty.handle, b: ty.handle) -> None:
    """Elementwise B = A * 2.0 over a 128^3 domain, with the innermost axis
    bound to threadIdx.x and explicit tir.bind / reads / writes annotations.
    """
    A = tir.match_buffer(a, (128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    for i, j in tir.grid(128, 128):
        # Only k is thread-bound; i and j stay serial.
        for k in tir.thread_binding(0, 128, thread="threadIdx.x"):
            with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
                tir.bind(vi, i)
                tir.bind(vj, j)
                tir.bind(vk, k)
                tir.reads([A[vi, vj, vk]])
                tir.writes([B[vi, vj, vk]])
                B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def factorized_after_reverse_compute_at(a: ty.handle, b: ty.handle) -> None:
    """rfactor-style two-stage reduction B[j] = sum_{i,k} A[j, i, k].

    Stage "B_rf" accumulates partial sums per (i, j) into a local buffer
    B_rf_local[16, 16] (reducing over k); stage "B" then reduces
    B_rf_local over the factored i axis into B[16].  j is bound to
    blockIdx.x and the outer split of i (4) to threadIdx.x.
    NOTE(review): the name suggests this is the expected IR after
    reverse_compute_at on an rfactor-ed reduction — confirm against the
    surrounding test module.
    """
    A = tir.match_buffer(a, [16, 16, 16], "float32")
    B = tir.match_buffer(b, [16], "float32")
    # Per-thread partial-sum buffer, indexed [i, j].
    B_rf_local = tir.alloc_buffer([16, 16], "float32", scope="local")
    for j in tir.thread_binding(0, 16, thread="blockIdx.x"):
        for i_o in tir.thread_binding(0, 4, thread="threadIdx.x"):
            # Stage 1: partial reduction over k for vi = i_o * 4 + i_i.
            for i_i, k in tir.grid(4, 16):
                with tir.block([16, 16, tir.reduce_axis(0, 16)], "B_rf") as [vi, vj, vk]:
                    tir.bind(vi, i_o * 4 + i_i)
                    tir.bind(vj, j)
                    tir.bind(vk, k)
                    with tir.init():
                        B_rf_local[vi, vj] = 0.0
                    B_rf_local[vi, vj] = B_rf_local[vi, vj] + A[vj, vi, vk]
            # Stage 2: fold the factored axis (vk = i_o * 4 + k) into B[vi].
            for k in tir.serial(0, 4):
                with tir.block([16, tir.reduce_axis(0, 16)], "B") as [vi, vk]:
                    tir.bind(vi, j)
                    tir.bind(vk, i_o * 4 + k)
                    with tir.init():
                        B[vi] = 0.0
                    B[vi] = B[vi] + B_rf_local[vk, vi]
def cuda_matmul_1(a: ty.handle, b: ty.handle, c: ty.handle) -> None:  # pylint: disable=undefined-loop-variable
    """CUDA-style 2048x2048 matmul schedule sketch.

    C[vi, vj] = sum_k A_shared_local[vk, vi] * B_shared_local[vk, vj]
    (i.e. both operands are consumed transposed).  The four cache-read
    blocks at the top copy A/B into shared then local buffers at full
    size; the loop nest below tiles C over blockIdx.y/x (32x32),
    vthread.y/x (2x2), threadIdx.y/x (8x8), with a 256x8 split of the
    reduction axis and a 4x4 register tile accumulated in C_local.
    NOTE(review): the pylint disable suggests loop vars (by, bx, ...) are
    intentionally referenced after their headers in collapsed form —
    presumably a schedule test fixture; confirm against the test module.
    """
    A = tir.match_buffer(a, [2048, 2048], "float32")
    B = tir.match_buffer(b, [2048, 2048], "float32")
    C = tir.match_buffer(c, [2048, 2048], "float32")
    A_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared")
    B_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared")
    A_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local")
    B_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local")
    C_local = tir.alloc_buffer([2048, 2048], "float32", scope="local")
    # Cache-read stages: global -> shared -> local, each a full copy here.
    with tir.block([2048, 2048], "A_shared") as [v0, v1]:
        A_shared[v0, v1] = A[v0, v1]
    with tir.block([2048, 2048], "B_shared") as [v0, v1]:
        B_shared[v0, v1] = B[v0, v1]
    with tir.block([2048, 2048], "A_shared_local") as [v0, v1]:
        A_shared_local[v0, v1] = A_shared[v0, v1]
    with tir.block([2048, 2048], "B_shared_local") as [v0, v1]:
        B_shared_local[v0, v1] = B_shared[v0, v1]
    for by in tir.thread_binding(0, 32, thread="blockIdx.y"):
        for bx in tir.thread_binding(0, 32, thread="blockIdx.x"):
            for vy in tir.thread_binding(0, 2, thread="vthread.y"):
                for vx in tir.thread_binding(0, 2, thread="vthread.x"):
                    for ty in tir.thread_binding(0, 8, thread="threadIdx.y"):
                        for tx in tir.thread_binding(0, 8, thread="threadIdx.x"):
                            # Reduction split: 2048 = 256 (serial) * 8 (unrolled).
                            for k_0 in tir.serial(0, 256):
                                for k_1 in tir.unroll(0, 8):
                                    for _, i, j in tir.grid(1, 4, 4):
                                        with tir.block([
                                                2048, 2048,
                                                tir.reduce_axis(0, 2048)
                                        ], "C") as [vi, vj, vk]:
                                            tir.bind(
                                                vi,
                                                by * 64 + vy * 32 + ty * 4 + i)
                                            tir.bind(
                                                vj,
                                                bx * 64 + vx * 32 + tx * 4 + j)
                                            tir.bind(vk, k_0 * 8 + k_1)
                                            with tir.init():
                                                C_local[vi, vj] = 0.0
                                            C_local[vi, vj] = C_local[
                                                vi, vj] + A_shared_local[
                                                    vk, vi] * B_shared_local[
                                                        vk, vj]
                            # Write-back: 4x4 register tile from C_local to C.
                            for i, j in tir.grid(4, 4):
                                with tir.block([2048, 2048], "C_local") as [vi, vj]:
                                    tir.bind(vi, by * 64 + vy * 32 + ty * 4 + i)
                                    tir.bind(vj, bx * 64 + vx * 32 + tx * 4 + j)
                                    C[vi, vj] = C_local[vi, vj]
def tensorcore_gemm(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    """Tensor-core GEMM: C(fp32, 1024x1024) = A(fp16) x B(fp16) via wmma
    intrinsics.

    Pipeline per (blockIdx.x, blockIdx.y, threadIdx.y, threadIdx.z):
      1. tvm_fill_fragment zeroes a 2x4 tile of 16x16 wmma.accumulator
         fragments.
      2. For each of 32 reduction steps `ko`: threadIdx.x (32 lanes) copies
         A/B tiles into shared memory with a +8 column offset (the shared
         rows are padded by 8 halves; the loads below account for it via
         `elem_offset + 8`), then tvm_load_matrix_sync fills wmma.matrix_a
         (row_major) / wmma.matrix_b (col_major) fragments and
         tvm_mma_sync accumulates.  Note B is loaded col_major, i.e. the
         contraction runs over B's second 16-axis.
      3. tvm_store_matrix_sync writes the accumulator fragments to C.
    NOTE(review): presumably a lowering/roundtrip test fixture — the exact
    strides/offset_factor values in each match_buffer are part of the
    expected IR; confirm against the surrounding test module.
    """
    # match buffer
    A = tir.match_buffer(a, [1024, 1024], "float16")
    B = tir.match_buffer(b, [1024, 1024], "float16")
    C = tir.match_buffer(c, [1024, 1024], "float32")
    # body
    for blockIdx_x in tir.thread_binding(0, 16, "blockIdx.x"):
        for blockIdx_y in tir.thread_binding(0, 8, "blockIdx.y"):
            with tir.block([16, 8]) as [bx, by]:
                tir.bind(bx, blockIdx_x)
                tir.bind(by, blockIdx_y)
                # Staging buffers: shared-memory tiles plus wmma fragment pools.
                shared_A = tir.alloc_buffer([1024, 1024], "float16", scope="shared")
                shared_B = tir.alloc_buffer([1024, 1024], "float16", scope="shared")
                wmma_A = tir.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_a")
                wmma_B = tir.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_b")
                wmma_C = tir.alloc_buffer([1024, 1024], "float32", scope="wmma.accumulator")
                for ty in tir.thread_binding(0, 2, "threadIdx.y"):
                    for tz in tir.thread_binding(0, 2, "threadIdx.z"):
                        # Step 1: zero-initialize this warp's 2x4 accumulator fragments.
                        for i, j in tir.grid(2, 4):
                            with tir.block([64, 64]) as [vi, vj]:
                                tir.bind(vi, bx * 4 + ty * 2 + i)
                                tir.bind(vj, by * 8 + tz * 4 + j)
                                tir.reads([])
                                tir.writes(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                C0 = tir.match_buffer(
                                    wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                tir.evaluate(
                                    tir.tvm_fill_fragment(
                                        C0.data,
                                        16,
                                        16,
                                        16,
                                        i * 4 + j,
                                        tir.float32(0),
                                        dtype="handle",
                                    )
                                )
                        for ko in range(0, 32):
                            # copy data from global to shared
                            for tx in tir.thread_binding(0, 32, "threadIdx.x"):
                                # A tile: note the +8 column padding on the shared store.
                                for i0, j0 in tir.grid(1, 4):
                                    for j1 in tir.vectorized(0, 4):
                                        with tir.block([1024, 1024]) as [vi, vj]:
                                            tir.bind(vi, bx * 64 + ty * 32 + tx + i0)
                                            tir.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1)
                                            shared_A[vi, vj + 8] = A[vi, vj]
                                # B tile: two rows per lane, same +8 padding.
                                for i0, j0 in tir.grid(2, 4):
                                    for j1 in tir.vectorized(0, 4):
                                        with tir.block([1024, 1024]) as [vi, vj]:
                                            tir.bind(vi, by * 128 + ty * 64 + tx * 2 + i0)
                                            tir.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1)
                                            shared_B[vi, vj + 8] = B[vi, vj]
                            for ki in range(0, 2):
                                # Load two A fragments (row_major) from padded shared memory.
                                for i in range(0, 2):
                                    with tir.block([64, 64]) as [vi, vk]:
                                        tir.bind(vi, bx * 4 + ty * 2 + i)
                                        tir.bind(vk, ko * 2 + ki)
                                        tir.reads(
                                            shared_A[
                                                vi * 16 : vi * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ]
                                        )
                                        tir.writes(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16]
                                        )
                                        # Symbolic strides: the shared tile's layout is
                                        # matched, not assumed.
                                        s0 = tir.var("int32")
                                        s1 = tir.var("int32")
                                        A0 = tir.match_buffer(
                                            shared_A[
                                                vi * 16 : vi * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[s0, s1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_A0 = tir.match_buffer(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        # elem_offset + 8 skips the padding columns.
                                        tir.evaluate(
                                            tir.tvm_load_matrix_sync(
                                                wmma_A0.data,
                                                16,
                                                16,
                                                16,
                                                i,
                                                tir.tvm_access_ptr(
                                                    tir.type_annotation(dtype="float16"),
                                                    A0.data,
                                                    A0.elem_offset + 8,
                                                    A0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                A0.strides[0],
                                                "row_major",
                                                dtype="handle",
                                            )
                                        )
                                # Load four B fragments (col_major).
                                for j in range(0, 4):
                                    with tir.block([64, 64]) as [vj, vk]:
                                        tir.bind(vj, by * 8 + tz * 4 + j)
                                        tir.bind(vk, ko * 2 + ki)
                                        tir.reads(
                                            shared_B[
                                                vj * 16 : vj * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ]
                                        )
                                        tir.writes(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16]
                                        )
                                        s0 = tir.var("int32")
                                        s1 = tir.var("int32")
                                        B0 = tir.match_buffer(
                                            shared_B[
                                                vj * 16 : vj * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[s0, s1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_B0 = tir.match_buffer(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        tir.evaluate(
                                            tir.tvm_load_matrix_sync(
                                                wmma_B0.data,
                                                16,
                                                16,
                                                16,
                                                j,
                                                tir.tvm_access_ptr(
                                                    tir.type_annotation(dtype="float16"),
                                                    B0.data,
                                                    B0.elem_offset + 8,
                                                    B0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                B0.strides[0],
                                                "col_major",
                                                dtype="handle",
                                            )
                                        )
                                # Accumulate: one mma_sync per (i, j) fragment pair.
                                for i, j in tir.grid(2, 4):
                                    with tir.block([64, 64, tir.reduce_axis(0, 64)]) as [
                                        vi,
                                        vj,
                                        vk,
                                    ]:
                                        tir.bind(vi, bx * 4 + ty * 2 + i)
                                        tir.bind(vj, by * 8 + tz * 4 + j)
                                        tir.bind(vk, ko * 2 + ki)
                                        tir.reads(
                                            [
                                                wmma_A[
                                                    vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16
                                                ],
                                                wmma_B[
                                                    vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16
                                                ],
                                                wmma_C[
                                                    vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16
                                                ],
                                            ]
                                        )
                                        tir.writes(
                                            wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]
                                        )
                                        wmma_A1 = tir.match_buffer(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        wmma_B1 = tir.match_buffer(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        wmma_C1 = tir.match_buffer(
                                            wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                            (16, 16),
                                            "float32",
                                            strides=[16 * 4, 1],
                                            scope="wmma.accumulator",
                                            offset_factor=1,
                                        )
                                        tir.evaluate(
                                            tir.tvm_mma_sync(
                                                wmma_C1.data,
                                                i * 4 + j,
                                                wmma_A1.data,
                                                i,
                                                wmma_B1.data,
                                                j,
                                                wmma_C1.data,
                                                i * 4 + j,
                                                dtype="handle",
                                            )
                                        )
                        # Step 3: store the finished accumulator fragments to C.
                        for i, j in tir.grid(2, 4):
                            with tir.block([64, 64]) as [vi, vj]:
                                tir.bind(vi, bx * 4 + ty * 2 + i)
                                tir.bind(vj, by * 8 + tz * 4 + j)
                                tir.reads(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                tir.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                s0 = tir.var("int32")
                                s1 = tir.var("int32")
                                wmma_C2 = tir.match_buffer(
                                    wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                C1 = tir.match_buffer(
                                    C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[s0, s1],
                                    offset_factor=1,
                                )
                                tir.evaluate(
                                    tir.tvm_store_matrix_sync(
                                        wmma_C2.data,
                                        16,
                                        16,
                                        16,
                                        i * 4 + j,
                                        tir.tvm_access_ptr(
                                            tir.type_annotation(dtype="float32"),
                                            C1.data,
                                            C1.elem_offset,
                                            C1.strides[0],
                                            1,
                                            dtype="handle",
                                        ),
                                        C1.strides[0],
                                        "row_major",
                                        dtype="handle",
                                    )
                                )