# Collected TVM TIRScript listings. They assume the standard TVM script
# prelude below; the @T.prim_func decorators are implied by the T.handle
# signatures and have been restored here.
import tvm
from tvm.script import tir as T


@T.prim_func
def annotated_tensorized_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
    B = T.match_buffer(b, [128, 128], elem_offset=0, align=128, offset_factor=1)
    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
    for i_outer, j_outer in T.grid(8, 8):
        for i_inner_init, j_inner_init in T.grid(16, 16):
            with T.block("init"):
                vi_init = T.axis.S(128, i_outer * 16 + i_inner_init)
                vj_init = T.axis.S(128, j_outer * 16 + j_inner_init)
                T.block_attr({"test_annotation": True})
                C[vi_init, vj_init] = T.float32(0)
        for k_outer in T.grid(8):
            with T.block("update"):
                vi, vj, vk = T.axis.remap("SSR", [i_outer, j_outer, k_outer])
                T.reads(
                    [
                        C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                        A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                        B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                    ]
                )
                T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                A_elem_offset = T.var("int32")
                B_elem_offset = T.var("int32")
                C_elem_offset = T.var("int32")
                A_sub = T.match_buffer(
                    A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                    [16, 16],
                    elem_offset=A_elem_offset,
                )
                B_sub = T.match_buffer(
                    B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                    [16, 16],
                    elem_offset=B_elem_offset,
                )
                C_sub = T.match_buffer(
                    C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                    [16, 16],
                    elem_offset=C_elem_offset,
                )
                T.evaluate(
                    T.tvm_mma_sync(
                        C_sub.data,
                        T.floordiv(C_sub.elem_offset, 256),
                        A_sub.data,
                        T.floordiv(A_sub.elem_offset, 256),
                        B_sub.data,
                        T.floordiv(B_sub.elem_offset, 256),
                        C_sub.data,
                        T.floordiv(C_sub.elem_offset, 256),
                        dtype="handle",
                    )
                )
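# Usage sketch (an assumption, not part of the original listing): hand-written
# TIR like the function above usually serves as the *expected* output of
# schedule tensorization. A hypothetical test would tensorize a plain matmul
# against a registered 16x16x16 MMA intrinsic and compare structurally;
# "matmul_workload" and "test_mma_intrin" are illustrative names.
#
#     sch = tvm.tir.Schedule(matmul_workload)
#     sch.tensorize(sch.get_block("update"), "test_mma_intrin")
#     tvm.ir.assert_structural_equal(sch.mod["main"], annotated_tensorized_matmul)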
@T.prim_func
def tensorized_batch_matmul_mma(
    A: T.Buffer[(16, 128, 128), "float32"],
    B: T.Buffer[(16, 128, 128), "float32"],
    C: T.Buffer[(16, 128, 128), "float32"],
) -> None:
    for n, i, j in T.grid(16, 128, 128):
        with T.block("init"):
            vn, vi, vj = T.axis.remap("SSS", [n, i, j])
            T.reads()
            T.writes(C[vn, vi, vj])
            C[vn, vi, vj] = T.float32(0)
    for n in range(0, 16):
        for i, j, k in T.grid(8, 8, 8):
            with T.block("update"):
                vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k])
                T.reads(
                    C[vn : vn + 1, vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                    A[vn : vn + 1, vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                    B[vn : vn + 1, vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                )
                T.writes(C[vn : vn + 1, vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                A_elem_offset = T.var("int32")
                B_elem_offset = T.var("int32")
                C_elem_offset = T.var("int32")
                A_sub = T.match_buffer(
                    A[vn : vn + 1, vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                    (16, 16),
                    elem_offset=A_elem_offset,
                )
                B_sub = T.match_buffer(
                    B[vn : vn + 1, vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                    (16, 16),
                    elem_offset=B_elem_offset,
                )
                C_sub = T.match_buffer(
                    C[vn : vn + 1, vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                    (16, 16),
                    elem_offset=C_elem_offset,
                )
                T.evaluate(
                    T.tvm_mma_sync(
                        C_sub.data,
                        T.floordiv(C_sub.elem_offset, 256),
                        A_sub.data,
                        T.floordiv(A_sub.elem_offset, 256),
                        B_sub.data,
                        T.floordiv(B_sub.elem_offset, 256),
                        C_sub.data,
                        T.floordiv(C_sub.elem_offset, 256),
                        dtype="handle",
                    )
                )
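# Note on the indexing above (informational): each sub-buffer tile is
# 16 x 16 = 256 elements, so T.floordiv(X_sub.elem_offset, 256) turns an
# element offset into the linear index of the corresponding warp-level
# fragment. The T.grid(8, 8, 8) tile loops cover the full 128 x 128 x 128
# matmul of each batch, since 128 / 16 = 8 tiles per dimension.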
@T.prim_func
def mma_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), align=128, offset_factor=1)
    B = T.match_buffer(b, (16, 16), align=128, offset_factor=1)
    C = T.match_buffer(c, (16, 16), align=128, offset_factor=1)
    with T.block("root"):
        T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        T.evaluate(
            T.tvm_mma_sync(
                C.data,
                C.elem_offset // 256,
                A.data,
                A.elem_offset // 256,
                B.data,
                B.elem_offset // 256,
                C.data,
                C.elem_offset // 256,
                dtype="handle",
            )
        )
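# Sketch of the matching "description" half (an assumption; only the
# implementation half appears above). A tensor intrinsic pairs a description
# PrimFunc, which states the computation in plain TIR, with an implementation
# PrimFunc such as mma_intrin; the registered name "test_mma_intrin" is
# illustrative.
@T.prim_func
def mma_intrin_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), align=128, offset_factor=1)
    B = T.match_buffer(b, (16, 16), align=128, offset_factor=1)
    C = T.match_buffer(c, (16, 16), align=128, offset_factor=1)
    with T.block("root"):
        T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        for i, j, k in T.grid(16, 16, 16):
            with T.block("update"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                # B is indexed [vj, vk]: the intrinsic consumes B transposed.
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]


tvm.tir.TensorIntrin.register("test_mma_intrin", mma_intrin_desc, mma_intrin)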
# Implementation half of a parameterized WMMA sync intrinsic. It closes over
# m_dim, n_dim, k_dim, in_dtype, out_dtype, maybe_swap, b_shape_0, b_shape_1,
# and get_wmma_fragment_index, which are bound in an enclosing generator scope.
@T.prim_func
def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(
        a, (m_dim, k_dim), in_dtype, align=128, offset_factor=16, scope="wmma.matrix_a"
    )
    B = T.match_buffer(
        b,
        maybe_swap(k_dim, n_dim),
        in_dtype,
        align=128,
        offset_factor=16,
        scope="wmma.matrix_b",
    )
    C = T.match_buffer(
        c, (m_dim, n_dim), out_dtype, align=128, offset_factor=16, scope="wmma.accumulator"
    )
    with T.block("root"):
        T.reads(C[0:m_dim, 0:n_dim], A[0:m_dim, 0:k_dim], B[0:b_shape_0, 0:b_shape_1])
        T.writes(C[0:m_dim, 0:n_dim])
        T.evaluate(
            T.tvm_mma_sync(
                C.data,
                get_wmma_fragment_index(C, m_dim, n_dim),
                A.data,
                get_wmma_fragment_index(A, m_dim, k_dim),
                B.data,
                get_wmma_fragment_index(B, b_shape_0, b_shape_1),
                C.data,
                get_wmma_fragment_index(C, m_dim, n_dim),
                dtype="handle",
            )
        )
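# Sketch of the index helper assumed above (an illustrative guess, not the
# original definition): with fragments stored contiguously, the linear
# fragment index is the element offset divided by the fragment size. This
# matches the elem_offset // 256 pattern of the 16x16x16 intrinsics earlier;
# the real helper may also account for the fragment's leading stride.
#
#     def get_wmma_fragment_index(buffer, frag_rows, frag_cols):
#         return buffer.elem_offset // (frag_rows * frag_cols)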
@T.prim_func
def tensorcore_gemm(handle_a: T.handle, handle_b: T.handle, handle_c: T.handle) -> None:
    # pylint: disable=missing-function-docstring
    # match buffer
    match_buffer_a = T.match_buffer(handle_a, [1024, 1024], "float16")
    match_buffer_b = T.match_buffer(handle_b, [1024, 1024], "float16")
    match_buffer_c = T.match_buffer(handle_c, [1024, 1024], "float32")
    # body
    for block_idx_x in T.thread_binding(0, 16, "blockIdx.x"):
        for block_idx_y in T.thread_binding(0, 8, "blockIdx.y"):
            with T.block():
                axis_bx, axis_by = T.axis.remap("SS", [block_idx_x, block_idx_y])
                shared_a = T.alloc_buffer([1024, 1024], "float16", scope="shared")
                shared_b = T.alloc_buffer([1024, 1024], "float16", scope="shared")
                wmma_a = T.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_a")
                wmma_b = T.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_b")
                wmma_c = T.alloc_buffer([1024, 1024], "float32", scope="wmma.accumulator")
                # pylint: disable=too-many-nested-blocks
                for thread_ty in T.thread_binding(0, 2, "threadIdx.y"):
                    for thread_tz in T.thread_binding(0, 2, "threadIdx.z"):
                        # initialize the 2x4 accumulator fragments owned by this warp
                        for index_i, index_jj in T.grid(2, 4):
                            with T.block():
                                new_axis_vi = T.axis.S(64, axis_bx * 4 + thread_ty * 2 + index_i)
                                new_axis_vj = T.axis.S(64, axis_by * 8 + thread_tz * 4 + index_jj)
                                T.reads([])
                                T.writes(
                                    wmma_c[
                                        new_axis_vi * 16 : new_axis_vi * 16 + 16,
                                        new_axis_vj * 16 : new_axis_vj * 16 + 16,
                                    ]
                                )
                                match_buffer_c0 = T.match_buffer(
                                    wmma_c[
                                        new_axis_vi * 16 : new_axis_vi * 16 + 16,
                                        new_axis_vj * 16 : new_axis_vj * 16 + 16,
                                    ],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                T.evaluate(
                                    T.tvm_fill_fragment(
                                        match_buffer_c0.data,
                                        16,
                                        16,
                                        16,
                                        index_i * 4 + index_jj,
                                        T.float32(0),  # pylint: disable=not-callable
                                        dtype="handle",
                                    )
                                )
                        for k_o in range(0, 32):
                            # copy data from global to shared; note the 8-column
                            # shift: tiles staged at [vi, vj + 8] are read back
                            # below at elem_offset + 8
                            for thread_tx in T.thread_binding(0, 32, "threadIdx.x"):
                                for index_i0, index_j0 in T.grid(1, 4):
                                    for index_j1 in T.vectorized(0, 4):
                                        with T.block():
                                            new_axis_vi = T.axis.S(
                                                1024,
                                                axis_bx * 64 + thread_ty * 32 + thread_tx + index_i0,
                                            )
                                            new_axis_vj = T.axis.S(
                                                1024,
                                                k_o * 32 + thread_tz * 16 + index_j0 * 4 + index_j1,
                                            )
                                            shared_a[new_axis_vi, new_axis_vj + 8] = match_buffer_a[
                                                new_axis_vi, new_axis_vj
                                            ]
                                for index_i0, index_j0 in T.grid(2, 4):
                                    for index_j1 in T.vectorized(0, 4):
                                        with T.block():
                                            new_axis_vi = T.axis.S(
                                                1024,
                                                axis_by * 128 + thread_ty * 64 + thread_tx * 2 + index_i0,
                                            )
                                            new_axis_vj = T.axis.S(
                                                1024,
                                                k_o * 32 + thread_tz * 16 + index_j0 * 4 + index_j1,
                                            )
                                            shared_b[new_axis_vi, new_axis_vj + 8] = match_buffer_b[
                                                new_axis_vi, new_axis_vj
                                            ]
                            for k_i in range(0, 2):
                                # load A tiles from shared memory into wmma fragments
                                for index_i in range(0, 2):
                                    with T.block():
                                        new_axis_vi = T.axis.S(64, axis_bx * 4 + thread_ty * 2 + index_i)
                                        axis_vk = T.axis.S(64, k_o * 2 + k_i)
                                        T.reads(
                                            shared_a[
                                                new_axis_vi * 16 : new_axis_vi * 16 + 16,
                                                axis_vk * 16 : axis_vk * 16 + 16 + 8,
                                            ]
                                        )
                                        T.writes(
                                            wmma_a[
                                                new_axis_vi * 16 : new_axis_vi * 16 + 16,
                                                axis_vk * 16 : axis_vk * 16 + 16,
                                            ]
                                        )
                                        stride0 = T.var("int32")
                                        stride1 = T.var("int32")
                                        match_buffer_a0 = T.match_buffer(
                                            shared_a[
                                                new_axis_vi * 16 : new_axis_vi * 16 + 16,
                                                axis_vk * 16 : axis_vk * 16 + 16 + 8,
                                            ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[stride0, stride1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_a0 = T.match_buffer(
                                            wmma_a[
                                                new_axis_vi * 16 : new_axis_vi * 16 + 16,
                                                axis_vk * 16 : axis_vk * 16 + 16,
                                            ],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_load_matrix_sync(
                                                wmma_a0.data,
                                                16,
                                                16,
                                                16,
                                                index_i,
                                                T.tvm_access_ptr(
                                                    T.type_annotation(dtype="float16"),
                                                    match_buffer_a0.data,
                                                    match_buffer_a0.elem_offset + 8,
                                                    match_buffer_a0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                match_buffer_a0.strides[0],
                                                "row_major",
                                                dtype="handle",
                                            )
                                        )
                                # load B tiles (consumed col_major, i.e. transposed)
                                for index_jj in range(0, 4):
                                    with T.block():
                                        new_axis_vj = T.axis.S(64, axis_by * 8 + thread_tz * 4 + index_jj)
                                        axis_vk = T.axis.S(64, k_o * 2 + k_i)
                                        T.reads(
                                            shared_b[
                                                new_axis_vj * 16 : new_axis_vj * 16 + 16,
                                                axis_vk * 16 : axis_vk * 16 + 16 + 8,
                                            ]
                                        )
                                        T.writes(
                                            wmma_b[
                                                new_axis_vj * 16 : new_axis_vj * 16 + 16,
                                                axis_vk * 16 : axis_vk * 16 + 16,
                                            ]
                                        )
                                        stride0 = T.var("int32")
                                        stride1 = T.var("int32")
                                        match_buffer_b0 = T.match_buffer(
                                            shared_b[
                                                new_axis_vj * 16 : new_axis_vj * 16 + 16,
                                                axis_vk * 16 : axis_vk * 16 + 16 + 8,
                                            ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[stride0, stride1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_b0 = T.match_buffer(
                                            wmma_b[
                                                new_axis_vj * 16 : new_axis_vj * 16 + 16,
                                                axis_vk * 16 : axis_vk * 16 + 16,
                                            ],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_load_matrix_sync(
                                                wmma_b0.data,
                                                16,
                                                16,
                                                16,
                                                index_jj,
                                                T.tvm_access_ptr(
                                                    T.type_annotation(dtype="float16"),
                                                    match_buffer_b0.data,
                                                    match_buffer_b0.elem_offset + 8,
                                                    match_buffer_b0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                match_buffer_b0.strides[0],
                                                "col_major",
                                                dtype="handle",
                                            )
                                        )
                                # 16x16x16 mma on the loaded fragments
                                for index_i, index_jj in T.grid(2, 4):
                                    with T.block():
                                        new_axis_vi = T.axis.S(64, axis_bx * 4 + thread_ty * 2 + index_i)
                                        new_axis_vj = T.axis.S(64, axis_by * 8 + thread_tz * 4 + index_jj)
                                        axis_vk = T.axis.R(64, k_o * 2 + k_i)
                                        T.reads(
                                            [
                                                wmma_a[
                                                    new_axis_vi * 16 : new_axis_vi * 16 + 16,
                                                    axis_vk * 16 : axis_vk * 16 + 16,
                                                ],
                                                wmma_b[
                                                    new_axis_vj * 16 : new_axis_vj * 16 + 16,
                                                    axis_vk * 16 : axis_vk * 16 + 16,
                                                ],
                                                wmma_c[
                                                    new_axis_vi * 16 : new_axis_vi * 16 + 16,
                                                    new_axis_vj * 16 : new_axis_vj * 16 + 16,
                                                ],
                                            ]
                                        )
                                        T.writes(
                                            wmma_c[
                                                new_axis_vi * 16 : new_axis_vi * 16 + 16,
                                                new_axis_vj * 16 : new_axis_vj * 16 + 16,
                                            ]
                                        )
                                        wmma_a1 = T.match_buffer(
                                            wmma_a[
                                                new_axis_vi * 16 : new_axis_vi * 16 + 16,
                                                axis_vk * 16 : axis_vk * 16 + 16,
                                            ],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        wmma_b1 = T.match_buffer(
                                            wmma_b[
                                                new_axis_vj * 16 : new_axis_vj * 16 + 16,
                                                axis_vk * 16 : axis_vk * 16 + 16,
                                            ],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        wmma_c1 = T.match_buffer(
                                            wmma_c[
                                                new_axis_vi * 16 : new_axis_vi * 16 + 16,
                                                new_axis_vj * 16 : new_axis_vj * 16 + 16,
                                            ],
                                            (16, 16),
                                            "float32",
                                            strides=[16 * 4, 1],
                                            scope="wmma.accumulator",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_mma_sync(
                                                wmma_c1.data,
                                                index_i * 4 + index_jj,
                                                wmma_a1.data,
                                                index_i,
                                                wmma_b1.data,
                                                index_jj,
                                                wmma_c1.data,
                                                index_i * 4 + index_jj,
                                                dtype="handle",
                                            )
                                        )
                        # write the accumulator fragments back to global memory
                        for index_i, index_jj in T.grid(2, 4):
                            with T.block():
                                new_axis_vi = T.axis.S(64, axis_bx * 4 + thread_ty * 2 + index_i)
                                new_axis_vj = T.axis.S(64, axis_by * 8 + thread_tz * 4 + index_jj)
                                T.reads(
                                    wmma_c[
                                        new_axis_vi * 16 : new_axis_vi * 16 + 16,
                                        new_axis_vj * 16 : new_axis_vj * 16 + 16,
                                    ]
                                )
                                T.writes(
                                    match_buffer_c[
                                        new_axis_vi * 16 : new_axis_vi * 16 + 16,
                                        new_axis_vj * 16 : new_axis_vj * 16 + 16,
                                    ]
                                )
                                stride0 = T.var("int32")
                                stride1 = T.var("int32")
                                wmma_c2 = T.match_buffer(
                                    wmma_c[
                                        new_axis_vi * 16 : new_axis_vi * 16 + 16,
                                        new_axis_vj * 16 : new_axis_vj * 16 + 16,
                                    ],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                match_buffer_c1 = T.match_buffer(
                                    match_buffer_c[
                                        new_axis_vi * 16 : new_axis_vi * 16 + 16,
                                        new_axis_vj * 16 : new_axis_vj * 16 + 16,
                                    ],
                                    (16, 16),
                                    "float32",
                                    strides=[stride0, stride1],
                                    offset_factor=1,
                                )
                                T.evaluate(
                                    T.tvm_store_matrix_sync(
                                        wmma_c2.data,
                                        16,
                                        16,
                                        16,
                                        index_i * 4 + index_jj,
                                        T.tvm_access_ptr(
                                            T.type_annotation(dtype="float32"),
                                            match_buffer_c1.data,
                                            match_buffer_c1.elem_offset,
                                            match_buffer_c1.strides[0],
                                            1,
                                            dtype="handle",
                                        ),
                                        match_buffer_c1.strides[0],
                                        "row_major",
                                        dtype="handle",
                                    )
                                )
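# Driver sketch (an assumption, not part of the original listing): build the
# kernel above for CUDA and validate it numerically. Since B is consumed
# through col_major fragment loads indexed by [j, k], the kernel computes
# C = A @ B^T. Requires an NVIDIA GPU with Tensor Core support; the function
# and module names are illustrative.
import numpy as np


def check_tensorcore_gemm():
    rt_mod = tvm.build(tensorcore_gemm, target="cuda", name="gemm")
    dev = tvm.cuda(0)
    a_np = np.random.uniform(size=(1024, 1024)).astype("float16")
    b_np = np.random.uniform(size=(1024, 1024)).astype("float16")
    a = tvm.nd.array(a_np, dev)
    b = tvm.nd.array(b_np, dev)
    c = tvm.nd.array(np.zeros((1024, 1024), dtype="float32"), dev)
    rt_mod(a, b, c)
    c_ref = a_np.astype("float32") @ b_np.astype("float32").T
    np.testing.assert_allclose(c.numpy(), c_ref, rtol=1e-3, atol=1e-3)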
# The same 1024x1024x1024 Tensor Core GEMM, written in the legacy TVM-script
# block syntax (iteration domains declared in T.block([...]) and bound with
# T.bind). Kept under the same name as above; the two listings are alternative
# renderings, not meant to coexist in one module.
@T.prim_func
def tensorcore_gemm(a: T.handle, b: T.handle, c: T.handle) -> None:
    # match buffer
    A = T.match_buffer(a, [1024, 1024], "float16")
    B = T.match_buffer(b, [1024, 1024], "float16")
    C = T.match_buffer(c, [1024, 1024], "float32")
    # body
    for blockIdx_x in T.thread_binding(0, 16, "blockIdx.x"):
        for blockIdx_y in T.thread_binding(0, 8, "blockIdx.y"):
            with T.block([16, 8]) as [bx, by]:
                T.bind(bx, blockIdx_x)
                T.bind(by, blockIdx_y)
                shared_A = T.alloc_buffer([1024, 1024], "float16", scope="shared")
                shared_B = T.alloc_buffer([1024, 1024], "float16", scope="shared")
                wmma_A = T.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_a")
                wmma_B = T.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_b")
                wmma_C = T.alloc_buffer([1024, 1024], "float32", scope="wmma.accumulator")
                for ty in T.thread_binding(0, 2, "threadIdx.y"):
                    for tz in T.thread_binding(0, 2, "threadIdx.z"):
                        for i, j in T.grid(2, 4):
                            with T.block([64, 64]) as [vi, vj]:
                                T.bind(vi, bx * 4 + ty * 2 + i)
                                T.bind(vj, by * 8 + tz * 4 + j)
                                T.reads([])
                                T.writes(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                C0 = T.match_buffer(
                                    wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                T.evaluate(
                                    T.tvm_fill_fragment(
                                        C0.data,
                                        16,
                                        16,
                                        16,
                                        i * 4 + j,
                                        T.float32(0),
                                        dtype="handle",
                                    )
                                )
                        for ko in range(0, 32):
                            # copy data from global to shared
                            for tx in T.thread_binding(0, 32, "threadIdx.x"):
                                for i0, j0 in T.grid(1, 4):
                                    for j1 in T.vectorized(0, 4):
                                        with T.block([1024, 1024]) as [vi, vj]:
                                            T.bind(vi, bx * 64 + ty * 32 + tx + i0)
                                            T.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1)
                                            shared_A[vi, vj + 8] = A[vi, vj]
                                for i0, j0 in T.grid(2, 4):
                                    for j1 in T.vectorized(0, 4):
                                        with T.block([1024, 1024]) as [vi, vj]:
                                            T.bind(vi, by * 128 + ty * 64 + tx * 2 + i0)
                                            T.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1)
                                            shared_B[vi, vj + 8] = B[vi, vj]
                            for ki in range(0, 2):
                                for i in range(0, 2):
                                    with T.block([64, 64]) as [vi, vk]:
                                        T.bind(vi, bx * 4 + ty * 2 + i)
                                        T.bind(vk, ko * 2 + ki)
                                        T.reads(
                                            shared_A[
                                                vi * 16 : vi * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ]
                                        )
                                        T.writes(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16]
                                        )
                                        s0 = T.var("int32")
                                        s1 = T.var("int32")
                                        A0 = T.match_buffer(
                                            shared_A[
                                                vi * 16 : vi * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[s0, s1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_A0 = T.match_buffer(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_load_matrix_sync(
                                                wmma_A0.data,
                                                16,
                                                16,
                                                16,
                                                i,
                                                T.tvm_access_ptr(
                                                    T.type_annotation(dtype="float16"),
                                                    A0.data,
                                                    A0.elem_offset + 8,
                                                    A0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                A0.strides[0],
                                                "row_major",
                                                dtype="handle",
                                            )
                                        )
                                for j in range(0, 4):
                                    with T.block([64, 64]) as [vj, vk]:
                                        T.bind(vj, by * 8 + tz * 4 + j)
                                        T.bind(vk, ko * 2 + ki)
                                        T.reads(
                                            shared_B[
                                                vj * 16 : vj * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ]
                                        )
                                        T.writes(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16]
                                        )
                                        s0 = T.var("int32")
                                        s1 = T.var("int32")
                                        B0 = T.match_buffer(
                                            shared_B[
                                                vj * 16 : vj * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[s0, s1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_B0 = T.match_buffer(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_load_matrix_sync(
                                                wmma_B0.data,
                                                16,
                                                16,
                                                16,
                                                j,
                                                T.tvm_access_ptr(
                                                    T.type_annotation(dtype="float16"),
                                                    B0.data,
                                                    B0.elem_offset + 8,
                                                    B0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                B0.strides[0],
                                                "col_major",
                                                dtype="handle",
                                            )
                                        )
                                for i, j in T.grid(2, 4):
                                    with T.block([64, 64, T.reduce_axis(0, 64)]) as [vi, vj, vk]:
                                        T.bind(vi, bx * 4 + ty * 2 + i)
                                        T.bind(vj, by * 8 + tz * 4 + j)
                                        T.bind(vk, ko * 2 + ki)
                                        T.reads(
                                            [
                                                wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                                                wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                                                wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                            ]
                                        )
                                        T.writes(
                                            wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]
                                        )
                                        wmma_A1 = T.match_buffer(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        wmma_B1 = T.match_buffer(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        wmma_C1 = T.match_buffer(
                                            wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                            (16, 16),
                                            "float32",
                                            strides=[16 * 4, 1],
                                            scope="wmma.accumulator",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_mma_sync(
                                                wmma_C1.data,
                                                i * 4 + j,
                                                wmma_A1.data,
                                                i,
                                                wmma_B1.data,
                                                j,
                                                wmma_C1.data,
                                                i * 4 + j,
                                                dtype="handle",
                                            )
                                        )
                        for i, j in T.grid(2, 4):
                            with T.block([64, 64]) as [vi, vj]:
                                T.bind(vi, bx * 4 + ty * 2 + i)
                                T.bind(vj, by * 8 + tz * 4 + j)
                                T.reads(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                s0 = T.var("int32")
                                s1 = T.var("int32")
                                wmma_C2 = T.match_buffer(
                                    wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                C1 = T.match_buffer(
                                    C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[s0, s1],
                                    offset_factor=1,
                                )
                                T.evaluate(
                                    T.tvm_store_matrix_sync(
                                        wmma_C2.data,
                                        16,
                                        16,
                                        16,
                                        i * 4 + j,
                                        T.tvm_access_ptr(
                                            T.type_annotation(dtype="float32"),
                                            C1.data,
                                            C1.elem_offset,
                                            C1.strides[0],
                                            1,
                                            dtype="handle",
                                        ),
                                        C1.strides[0],
                                        "row_major",
                                        dtype="handle",
                                    )
                                )