def access_opaque_ptr_then_elemwise_inline(a: T.handle, b: T.handle) -> None:
    # TVMScript fixture: an opaque block touching A / A_cache through raw
    # access pointers, followed by a single fused elementwise stage.
    # NOTE(review): presumably the expected output after compute-inlining the
    # intermediate stage of `access_opaque_ptr_then_elemwise` — confirm
    # against the test driver that pairs these two funcs.
    A = T.match_buffer(a, [1024], dtype="float32")
    B = T.match_buffer(b, [1024], dtype="float32")
    A_cache = T.alloc_buffer([1024], dtype="float32")
    with T.block("opaque"):
        # annotated opaque partial access should be kept
        T.reads(A[0:512])
        T.writes([A_cache[0:512]])
        # Raw pointers: read the first 512 elements of A ("r"), write the
        # first 512 elements of A_cache ("w"); the bodies are opaque to the
        # access-region analyzer, hence the explicit reads/writes above.
        T.evaluate(
            T.tvm_access_ptr(T.type_annotation(dtype="float32"), A.data, 0, 512, "r",
                             dtype="handle"))
        T.evaluate(
            T.tvm_access_ptr(T.type_annotation(dtype="float32"), A_cache.data, 0, 512, "w",
                             dtype="handle"))
    for i in T.serial(0, 512):
        with T.block("B"):
            vi = T.axis.spatial(512, i)
            T.reads([A_cache[vi]])
            T.writes([B[vi]])
            # Fused elementwise stage: B = A_cache * 2 + 1.
            B[vi] = A_cache[vi] * 2.0 + 1.0
def access_opaque_ptr_then_elemwise(a: T.handle, b: T.handle) -> None:
    # TVMScript fixture: opaque pointer access filling A_cache, then two
    # separate elementwise stages (BB = A_cache * 2, then B = BB + 1).
    # NOTE(review): looks like the pre-inlining counterpart of
    # `access_opaque_ptr_then_elemwise_inline` — confirm against caller.
    A = T.match_buffer(a, [1024])
    B = T.match_buffer(b, [1024])
    A_cache = T.alloc_buffer([1024])
    BB = T.alloc_buffer([1024])
    with T.block("opaque"):
        # annotated opaque partial access
        T.reads(A[0:512])
        T.writes(A_cache[0:512])
        # Opaque handles: read A[0:512] ("r"), write A_cache[0:512] ("w").
        T.evaluate(
            T.tvm_access_ptr(T.type_annotation(dtype="float32"), A.data, 0, 512, "r",
                             dtype="handle"))
        T.evaluate(
            T.tvm_access_ptr(T.type_annotation(dtype="float32"), A_cache.data, 0, 512, "w",
                             dtype="handle"))
    for i in range(512):
        with T.block("BB"):
            vi = T.axis.remap("S", [i])
            # First elementwise stage into the intermediate buffer BB.
            BB[vi] = A_cache[vi] * 2.0
    for i in range(512):
        with T.block("B"):
            vi = T.axis.remap("S", [i])
            # Second elementwise stage producing the output B.
            B[vi] = BB[vi] + 1.0
def opaque_access_store(a: T.handle, c: T.handle) -> None:
    # TVMScript fixture: producer block "B" (B = A * 2) followed by block "C"
    # that mixes opaque whole-buffer pointer accesses with an ordinary
    # per-element store (C = B + 1).
    A = T.match_buffer(a, (128, 128))
    B = T.alloc_buffer((128, 128))
    C = T.match_buffer(c, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            # Full-buffer regions: the opaque pointers below may touch any
            # element of B / C, not just [vi, vj], so the annotations cover
            # the whole 128x128 extent.
            T.reads(B[0:128, 0:128])
            T.writes(C[0:128, 0:128])
            T.evaluate(
                T.tvm_access_ptr(T.type_annotation(dtype="float32"), B.data, 0, 128, "r",
                                 dtype="handle"))
            T.evaluate(
                T.tvm_access_ptr(T.type_annotation(dtype="float32"), C.data, 0, 128, "w",
                                 dtype="handle"))
            C[vi, vj] = B[vi, vj] + 1.0
def tensorcore_gemm(handle_a: T.handle, handle_b: T.handle, handle_c: T.handle) -> None:
    # pylint: disable=missing-function-docstring
    # 1024x1024 GEMM-style kernel (float16 operands, float32 accumulator)
    # written with WMMA tensor-core intrinsics, using the current TVMScript
    # block-axis syntax (T.axis.S / T.axis.R).
    # Pipeline per 32 k-steps: fill accumulator fragments, stage A/B tiles
    # into shared memory, load 16x16 fragments, tvm_mma_sync, then store the
    # accumulator fragments back to global memory.
    # match buffer
    match_buffer_a = T.match_buffer(handle_a, [1024, 1024], "float16")
    match_buffer_b = T.match_buffer(handle_b, [1024, 1024], "float16")
    match_buffer_c = T.match_buffer(handle_c, [1024, 1024], "float32")
    # body
    for block_idx_x in T.thread_binding(0, 16, "blockIdx.x"):
        for block_idx_y in T.thread_binding(0, 8, "blockIdx.y"):
            with T.block():
                axis_bx, axis_by = T.axis.remap("SS", [block_idx_x, block_idx_y])
                shared_a = T.alloc_buffer([1024, 1024], "float16", scope="shared")
                shared_b = T.alloc_buffer([1024, 1024], "float16", scope="shared")
                wmma_a = T.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_a")
                wmma_b = T.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_b")
                wmma_c = T.alloc_buffer([1024, 1024], "float32", scope="wmma.accumulator")
                # pylint: disable=too-many-nested-blocks
                for thread_ty in T.thread_binding(0, 2, "threadIdx.y"):
                    for thread_tz in T.thread_binding(0, 2, "threadIdx.z"):
                        # Zero-initialize the 2x4 accumulator fragments owned
                        # by this (ty, tz) warp pair.
                        for index_i, index_jj in T.grid(2, 4):
                            with T.block():
                                new_axis_vi = T.axis.S(
                                    64, axis_bx * 4 + thread_ty * 2 + index_i)
                                new_axis_vj = T.axis.S(
                                    64, axis_by * 8 + thread_tz * 4 + index_jj)
                                T.reads([])
                                T.writes(wmma_c[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                new_axis_vj * 16:new_axis_vj * 16 + 16, ])
                                match_buffer_c0 = T.match_buffer(
                                    wmma_c[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                           new_axis_vj * 16:new_axis_vj * 16 + 16, ],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                T.evaluate(
                                    T.tvm_fill_fragment(
                                        match_buffer_c0.data,
                                        16,
                                        16,
                                        16,
                                        index_i * 4 + index_jj,
                                        T.float32(0),  # pylint: disable=not-callable
                                        dtype="handle",
                                    ))
                        for k_o in range(0, 32):
                            # copy data from global to shared
                            # NOTE(review): destination column is vj + 8 while
                            # the fragment loads below use elem_offset + 8 —
                            # presumably an 8-column shared-memory pad to
                            # avoid bank conflicts; confirm.
                            for thread_tx in T.thread_binding(
                                    0, 32, "threadIdx.x"):
                                for index_i0, index_j0 in T.grid(1, 4):
                                    for index_j1 in T.vectorized(0, 4):
                                        with T.block():
                                            new_axis_vi = T.axis.S(
                                                1024,
                                                axis_bx * 64 + thread_ty * 32 + thread_tx +
                                                index_i0,
                                            )
                                            new_axis_vj = T.axis.S(
                                                1024,
                                                k_o * 32 + thread_tz * 16 + index_j0 * 4 +
                                                index_j1,
                                            )
                                            shared_a[new_axis_vi, new_axis_vj + 8] = match_buffer_a[
                                                new_axis_vi, new_axis_vj]
                                for index_i0, index_j0 in T.grid(2, 4):
                                    for index_j1 in T.vectorized(0, 4):
                                        with T.block():
                                            new_axis_vi = T.axis.S(
                                                1024,
                                                axis_by * 128 + thread_ty * 64 + thread_tx * 2 +
                                                index_i0,
                                            )
                                            new_axis_vj = T.axis.S(
                                                1024,
                                                k_o * 32 + thread_tz * 16 + index_j0 * 4 +
                                                index_j1,
                                            )
                                            shared_b[new_axis_vi, new_axis_vj + 8] = match_buffer_b[
                                                new_axis_vi, new_axis_vj]
                            for k_i in range(0, 2):
                                # Load A fragments for this k-step; shared
                                # source tile is 16x(16+8) (padded columns).
                                for index_i in range(0, 2):
                                    with T.block():
                                        new_axis_vi = T.axis.S(
                                            64, axis_bx * 4 + thread_ty * 2 + index_i)
                                        axis_vk = T.axis.S(64, k_o * 2 + k_i)
                                        T.reads(shared_a[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                         axis_vk * 16:axis_vk * 16 + 16 + 8, ])
                                        T.writes(
                                            wmma_a[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 + 16, ])
                                        # Strides of the shared-memory view are
                                        # symbolic; the lowering pass binds them.
                                        stride0 = T.var("int32")
                                        stride1 = T.var("int32")
                                        match_buffer_a0 = T.match_buffer(
                                            shared_a[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                     axis_vk * 16:axis_vk * 16 + 16 + 8, ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[stride0, stride1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_a0 = T.match_buffer(
                                            wmma_a[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 + 16, ],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_load_matrix_sync(
                                                wmma_a0.data,
                                                16,
                                                16,
                                                16,
                                                index_i,
                                                T.tvm_access_ptr(
                                                    T.type_annotation(
                                                        dtype="float16"),
                                                    match_buffer_a0.data,
                                                    match_buffer_a0.elem_offset + 8,
                                                    match_buffer_a0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                match_buffer_a0.strides[0],
                                                "row_major",
                                                dtype="handle",
                                            ))
                                # Load B fragments; note col_major layout
                                # below, unlike A's row_major.
                                for index_jj in range(0, 4):
                                    with T.block():
                                        new_axis_vj = T.axis.S(
                                            64, axis_by * 8 + thread_tz * 4 + index_jj)
                                        axis_vk = T.axis.S(64, k_o * 2 + k_i)
                                        T.reads(shared_b[new_axis_vj * 16:new_axis_vj * 16 + 16,
                                                         axis_vk * 16:axis_vk * 16 + 16 + 8, ])
                                        T.writes(
                                            wmma_b[new_axis_vj * 16:new_axis_vj * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 + 16, ])
                                        stride0 = T.var("int32")
                                        stride1 = T.var("int32")
                                        match_buffer_b0 = T.match_buffer(
                                            shared_b[new_axis_vj * 16:new_axis_vj * 16 + 16,
                                                     axis_vk * 16:axis_vk * 16 + 16 + 8, ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[stride0, stride1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_b0 = T.match_buffer(
                                            wmma_b[new_axis_vj * 16:new_axis_vj * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 + 16, ],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_load_matrix_sync(
                                                wmma_b0.data,
                                                16,
                                                16,
                                                16,
                                                index_jj,
                                                T.tvm_access_ptr(
                                                    T.type_annotation(
                                                        dtype="float16"),
                                                    match_buffer_b0.data,
                                                    match_buffer_b0.elem_offset + 8,
                                                    match_buffer_b0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                match_buffer_b0.strides[0],
                                                "col_major",
                                                dtype="handle",
                                            ))
                                # Accumulate: C-fragment += A-fragment x
                                # B-fragment; vk is the reduction axis (T.axis.R).
                                for index_i, index_jj in T.grid(2, 4):
                                    with T.block():
                                        new_axis_vi = T.axis.S(
                                            64, axis_bx * 4 + thread_ty * 2 + index_i)
                                        new_axis_vj = T.axis.S(
                                            64, axis_by * 8 + thread_tz * 4 + index_jj)
                                        axis_vk = T.axis.R(64, k_o * 2 + k_i)
                                        T.reads([
                                            wmma_a[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 + 16, ],
                                            wmma_b[new_axis_vj * 16:new_axis_vj * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 + 16, ],
                                            wmma_c[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                   new_axis_vj * 16:new_axis_vj * 16 + 16, ],
                                        ])
                                        T.writes(
                                            wmma_c[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                   new_axis_vj * 16:new_axis_vj * 16 + 16, ])
                                        wmma_a1 = T.match_buffer(
                                            wmma_a[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 + 16, ],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        wmma_b1 = T.match_buffer(
                                            wmma_b[new_axis_vj * 16:new_axis_vj * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 + 16, ],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        wmma_c1 = T.match_buffer(
                                            wmma_c[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                   new_axis_vj * 16:new_axis_vj * 16 + 16, ],
                                            (16, 16),
                                            "float32",
                                            strides=[16 * 4, 1],
                                            scope="wmma.accumulator",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_mma_sync(
                                                wmma_c1.data,
                                                index_i * 4 + index_jj,
                                                wmma_a1.data,
                                                index_i,
                                                wmma_b1.data,
                                                index_jj,
                                                wmma_c1.data,
                                                index_i * 4 + index_jj,
                                                dtype="handle",
                                            ))
                        # Epilogue: write accumulator fragments back to the
                        # global output buffer.
                        for index_i, index_jj in T.grid(2, 4):
                            with T.block():
                                new_axis_vi = T.axis.S(
                                    64, axis_bx * 4 + thread_ty * 2 + index_i)
                                new_axis_vj = T.axis.S(
                                    64, axis_by * 8 + thread_tz * 4 + index_jj)
                                T.reads(wmma_c[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                               new_axis_vj * 16:new_axis_vj * 16 + 16, ])
                                T.writes(
                                    match_buffer_c[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                   new_axis_vj * 16:new_axis_vj * 16 + 16, ])
                                stride0 = T.var("int32")
                                stride1 = T.var("int32")
                                wmma_c2 = T.match_buffer(
                                    wmma_c[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                           new_axis_vj * 16:new_axis_vj * 16 + 16, ],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                match_buffer_c1 = T.match_buffer(
                                    match_buffer_c[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                   new_axis_vj * 16:new_axis_vj * 16 + 16, ],
                                    (16, 16),
                                    "float32",
                                    strides=[stride0, stride1],
                                    offset_factor=1,
                                )
                                T.evaluate(
                                    T.tvm_store_matrix_sync(
                                        wmma_c2.data,
                                        16,
                                        16,
                                        16,
                                        index_i * 4 + index_jj,
                                        T.tvm_access_ptr(
                                            T.type_annotation(dtype="float32"),
                                            match_buffer_c1.data,
                                            match_buffer_c1.elem_offset,
                                            match_buffer_c1.strides[0],
                                            1,
                                            dtype="handle",
                                        ),
                                        match_buffer_c1.strides[0],
                                        "row_major",
                                        dtype="handle",
                                    ))
def tensorcore_gemm(a: T.handle, b: T.handle, c: T.handle) -> None:
    # Legacy-syntax twin of the tensorcore GEMM above: same kernel written
    # with the older TVMScript block sugar (`T.block([dom]) as [...]` plus
    # explicit `T.bind`). NOTE(review): presumably kept so a parser/printer
    # round-trip or syntax-migration test can compare the two — confirm.
    # match buffer
    A = T.match_buffer(a, [1024, 1024], "float16")
    B = T.match_buffer(b, [1024, 1024], "float16")
    C = T.match_buffer(c, [1024, 1024], "float32")
    # body
    for blockIdx_x in T.thread_binding(0, 16, "blockIdx.x"):
        for blockIdx_y in T.thread_binding(0, 8, "blockIdx.y"):
            with T.block([16, 8]) as [bx, by]:
                T.bind(bx, blockIdx_x)
                T.bind(by, blockIdx_y)
                shared_A = T.alloc_buffer([1024, 1024], "float16", scope="shared")
                shared_B = T.alloc_buffer([1024, 1024], "float16", scope="shared")
                wmma_A = T.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_a")
                wmma_B = T.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_b")
                wmma_C = T.alloc_buffer([1024, 1024], "float32", scope="wmma.accumulator")
                for ty in T.thread_binding(0, 2, "threadIdx.y"):
                    for tz in T.thread_binding(0, 2, "threadIdx.z"):
                        # Zero-fill the 2x4 accumulator fragments.
                        for i, j in T.grid(2, 4):
                            with T.block([64, 64]) as [vi, vj]:
                                T.bind(vi, bx * 4 + ty * 2 + i)
                                T.bind(vj, by * 8 + tz * 4 + j)
                                T.reads([])
                                T.writes(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                C0 = T.match_buffer(
                                    wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                T.evaluate(
                                    T.tvm_fill_fragment(
                                        C0.data,
                                        16,
                                        16,
                                        16,
                                        i * 4 + j,
                                        T.float32(0),
                                        dtype="handle",
                                    )
                                )
                        for ko in range(0, 32):
                            # copy data from global to shared
                            # (written at vj + 8: shared tiles carry an
                            # 8-column offset matching elem_offset + 8 in the
                            # fragment loads below)
                            for tx in T.thread_binding(0, 32, "threadIdx.x"):
                                for i0, j0 in T.grid(1, 4):
                                    for j1 in T.vectorized(0, 4):
                                        with T.block([1024, 1024]) as [vi, vj]:
                                            T.bind(vi, bx * 64 + ty * 32 + tx + i0)
                                            T.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1)
                                            shared_A[vi, vj + 8] = A[vi, vj]
                                for i0, j0 in T.grid(2, 4):
                                    for j1 in T.vectorized(0, 4):
                                        with T.block([1024, 1024]) as [vi, vj]:
                                            T.bind(vi, by * 128 + ty * 64 + tx * 2 + i0)
                                            T.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1)
                                            shared_B[vi, vj + 8] = B[vi, vj]
                            for ki in range(0, 2):
                                # Load A fragments (row_major) from padded
                                # shared tiles.
                                for i in range(0, 2):
                                    with T.block([64, 64]) as [vi, vk]:
                                        T.bind(vi, bx * 4 + ty * 2 + i)
                                        T.bind(vk, ko * 2 + ki)
                                        T.reads(
                                            shared_A[
                                                vi * 16 : vi * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ]
                                        )
                                        T.writes(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16]
                                        )
                                        # Symbolic strides, bound when the
                                        # match_buffer is lowered.
                                        s0 = T.var("int32")
                                        s1 = T.var("int32")
                                        A0 = T.match_buffer(
                                            shared_A[
                                                vi * 16 : vi * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[s0, s1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_A0 = T.match_buffer(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_load_matrix_sync(
                                                wmma_A0.data,
                                                16,
                                                16,
                                                16,
                                                i,
                                                T.tvm_access_ptr(
                                                    T.type_annotation(dtype="float16"),
                                                    A0.data,
                                                    A0.elem_offset + 8,
                                                    A0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                A0.strides[0],
                                                "row_major",
                                                dtype="handle",
                                            )
                                        )
                                # Load B fragments (col_major).
                                for j in range(0, 4):
                                    with T.block([64, 64]) as [vj, vk]:
                                        T.bind(vj, by * 8 + tz * 4 + j)
                                        T.bind(vk, ko * 2 + ki)
                                        T.reads(
                                            shared_B[
                                                vj * 16 : vj * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ]
                                        )
                                        T.writes(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16]
                                        )
                                        s0 = T.var("int32")
                                        s1 = T.var("int32")
                                        B0 = T.match_buffer(
                                            shared_B[
                                                vj * 16 : vj * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[s0, s1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_B0 = T.match_buffer(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_load_matrix_sync(
                                                wmma_B0.data,
                                                16,
                                                16,
                                                16,
                                                j,
                                                T.tvm_access_ptr(
                                                    T.type_annotation(dtype="float16"),
                                                    B0.data,
                                                    B0.elem_offset + 8,
                                                    B0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                B0.strides[0],
                                                "col_major",
                                                dtype="handle",
                                            )
                                        )
                                # Accumulate via tvm_mma_sync; vk is declared
                                # as a reduce axis in the block domain.
                                for i, j in T.grid(2, 4):
                                    with T.block([64, 64, T.reduce_axis(0, 64)]) as [
                                        vi,
                                        vj,
                                        vk,
                                    ]:
                                        T.bind(vi, bx * 4 + ty * 2 + i)
                                        T.bind(vj, by * 8 + tz * 4 + j)
                                        T.bind(vk, ko * 2 + ki)
                                        T.reads(
                                            [
                                                wmma_A[
                                                    vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16
                                                ],
                                                wmma_B[
                                                    vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16
                                                ],
                                                wmma_C[
                                                    vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16
                                                ],
                                            ]
                                        )
                                        T.writes(
                                            wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]
                                        )
                                        wmma_A1 = T.match_buffer(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        wmma_B1 = T.match_buffer(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        wmma_C1 = T.match_buffer(
                                            wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                            (16, 16),
                                            "float32",
                                            strides=[16 * 4, 1],
                                            scope="wmma.accumulator",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_mma_sync(
                                                wmma_C1.data,
                                                i * 4 + j,
                                                wmma_A1.data,
                                                i,
                                                wmma_B1.data,
                                                j,
                                                wmma_C1.data,
                                                i * 4 + j,
                                                dtype="handle",
                                            )
                                        )
                        # Epilogue: store accumulator fragments to C.
                        for i, j in T.grid(2, 4):
                            with T.block([64, 64]) as [vi, vj]:
                                T.bind(vi, bx * 4 + ty * 2 + i)
                                T.bind(vj, by * 8 + tz * 4 + j)
                                T.reads(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                s0 = T.var("int32")
                                s1 = T.var("int32")
                                wmma_C2 = T.match_buffer(
                                    wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                C1 = T.match_buffer(
                                    C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[s0, s1],
                                    offset_factor=1,
                                )
                                T.evaluate(
                                    T.tvm_store_matrix_sync(
                                        wmma_C2.data,
                                        16,
                                        16,
                                        16,
                                        i * 4 + j,
                                        T.tvm_access_ptr(
                                            T.type_annotation(dtype="float32"),
                                            C1.data,
                                            C1.elem_offset,
                                            C1.strides[0],
                                            1,
                                            dtype="handle",
                                        ),
                                        C1.strides[0],
                                        "row_major",
                                        dtype="handle",
                                    )
                                )
def opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None:
    # TVMScript fixture exercising three access styles (current axis syntax):
    #   "load_store"   — ordinary element-wise buffer load/store (D = A),
    #   "opaque"       — intrinsic call through raw tvm_access_ptr (A -> B),
    #   "match_buffer" — intrinsic call through match_buffer sub-views (A -> C).
    A = T.match_buffer(a, (128, 128), dtype="float16")
    B = T.match_buffer(b, (128, 128), dtype="float16")
    C = T.match_buffer(c, (128, 128), dtype="float16")
    D = T.match_buffer(d, (128, 128), dtype="float16")
    for i, j in T.grid(128, 128):
        with T.block("load_store"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[vi, vj])
            T.writes(D[vi, vj])
            D[vi, vj] = A[vi, vj]
    for i, j in T.grid(8, 8):
        with T.block("opaque"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
            T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
            # Offset vi * 2048 + vj * 16 = (vi * 16) * 128 + vj * 16:
            # the flat address of the 16x16 tile inside the 128x128 buffer.
            T.evaluate(
                T.tvm_load_matrix_sync(
                    B.data,
                    16,
                    16,
                    16,
                    vi * 8 + vj,
                    T.tvm_access_ptr(
                        T.type_annotation(dtype="float16"),
                        A.data,
                        vi * 2048 + vj * 16,
                        128,
                        1,
                        dtype="handle",
                    ),
                    128,
                    "row_major",
                    dtype="handle",
                )
            )
    for i, j in T.grid(8, 8):
        with T.block("match_buffer"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
            T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
            # Same tile access as block "opaque", but expressed through
            # strided 16x16 sub-buffer views instead of a hand-computed offset.
            A0 = T.match_buffer(
                A[
                    vi * 16 : vi * 16 + 16,
                    vj * 16 : vj * 16 + 16,
                ],
                (16, 16),
                "float16",
                strides=[128, 1],
                offset_factor=1,
            )
            C0 = T.match_buffer(
                C[
                    vi * 16 : vi * 16 + 16,
                    vj * 16 : vj * 16 + 16,
                ],
                (16, 16),
                "float16",
                strides=[128, 1],
                offset_factor=1,
            )
            T.evaluate(
                T.tvm_load_matrix_sync(
                    C0.data,
                    16,
                    16,
                    16,
                    vi * 8 + vj,
                    T.tvm_access_ptr(
                        T.type_annotation(dtype="float16"),
                        A0.data,
                        A0.elem_offset,
                        A0.strides[0],
                        1,
                        dtype="handle",
                    ),
                    128,
                    "row_major",
                    dtype="handle",
                )
            )
def opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None:
    # Legacy-syntax twin of the opaque_access fixture above, written with the
    # old block sugar (`T.block([dom], name) as [...]`, implicit loops) and a
    # raw T.load / .data store in the "load_store" block instead of the
    # sugared buffer indexing. NOTE(review): presumably paired with the new-
    # syntax version for a round-trip / migration test — confirm.
    A = T.match_buffer(a, (128, 128), dtype="float16")
    B = T.match_buffer(b, (128, 128), dtype="float16")
    C = T.match_buffer(c, (128, 128), dtype="float16")
    D = T.match_buffer(d, (128, 128), dtype="float16")
    with T.block([128, 128], "load_store") as [vi, vj]:
        T.reads(A[vi, vj])
        T.writes(D[vi, vj])
        # Element copy expressed through flat data pointers rather than
        # sugared D[vi, vj] = A[vi, vj].
        D.data[vi * 128 + vj] = T.load("float16", A.data, vi * 128 + vj)
    with T.block([8, 8], "opaque") as [vi, vj]:
        T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
        T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
        # vi * 2048 + vj * 16 is the flat offset of the 16x16 tile.
        T.evaluate(
            T.tvm_load_matrix_sync(
                B.data,
                16,
                16,
                16,
                vi * 8 + vj,
                T.tvm_access_ptr(
                    T.type_annotation(dtype="float16"),
                    A.data,
                    vi * 2048 + vj * 16,
                    128,
                    1,
                    dtype="handle",
                ),
                128,
                "row_major",
                dtype="handle",
            )
        )
    with T.block([8, 8], "match_buffer") as [vi, vj]:
        T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
        T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
        # Same tile access through strided sub-buffer views.
        A0 = T.match_buffer(
            A[
                vi * 16 : vi * 16 + 16,
                vj * 16 : vj * 16 + 16,
            ],
            (16, 16),
            "float16",
            strides=[128, 1],
            offset_factor=1,
        )
        C0 = T.match_buffer(
            C[
                vi * 16 : vi * 16 + 16,
                vj * 16 : vj * 16 + 16,
            ],
            (16, 16),
            "float16",
            strides=[128, 1],
            offset_factor=1,
        )
        T.evaluate(
            T.tvm_load_matrix_sync(
                C0.data,
                16,
                16,
                16,
                vi * 8 + vj,
                T.tvm_access_ptr(
                    T.type_annotation(dtype="float16"),
                    A0.data,
                    A0.elem_offset,
                    A0.strides[0],
                    1,
                    dtype="handle",
                ),
                128,
                "row_major",
                dtype="handle",
            )
        )