def get_valid_counts(
    data: T.handle,
    valid_count: T.handle,
    out: T.handle,
    out_indices: T.handle,
    score_threshold: T.float32,
    id_index: T.int32,
    score_index: T.int32,
) -> None:
    # Copies the rows of `data` that pass a score/id filter to the front of
    # `out`, records each kept row's source index in `out_indices`, and keeps
    # the running number of kept rows in `valid_count`.
    #
    # Parameters (shapes/dtypes as bound by the match_buffer calls below):
    #   data:            float32 (1, 2500, 6) input rows.
    #   valid_count:     int32 (1,) running count of rows that passed.
    #   out:             float32 (1, 2500, 6) filtered rows.
    #   out_indices:     int32 (1, 2500) source row index for each kept row.
    #   score_threshold: a row passes only if its score is strictly greater.
    #   id_index:        column holding the id; a negative value skips the id test.
    #   score_index:     column holding the score.
    data_buf = T.match_buffer(data, (1, 2500, 6), "float32")
    valid_count_buf = T.match_buffer(valid_count, (1, ), "int32")
    out_buf = T.match_buffer(out, (1, 2500, 6), "float32")
    out_indices_buf = T.match_buffer(out_indices, (1, 2500), "int32")

    with T.block([1], "init") as [vi]:
        # Start the count at zero before visiting any row.
        valid_count_buf[vi] = T.int32(0)
        with T.block([2500], "update") as [vj]:
            # NOTE(review): the declared regions use index 6 on an axis of
            # extent 6 (valid indices 0..5), while the body touches columns
            # 0..5, `score_index` and `id_index`, and writes `out_buf` at row
            # `valid_count_buf[vi]`. Presumably `0:6` was intended - confirm.
            T.reads([data_buf[vi, vj, 6]])
            T.writes([
                valid_count_buf[vi], out_indices_buf[vi, vj], out_buf[vi, vj, 6]
            ])
            # Keep the row when the score exceeds the threshold and either the
            # id test is disabled (id_index < 0) or the id is non-negative.
            if (data_buf[vi, vj, score_index] > score_threshold) and (
                    (id_index < 0) or (data_buf[vi, vj, id_index] >= T.float32(0))):
                # Write the row into the next free output slot.
                for k in T.serial(0, 6):
                    out_buf[vi, valid_count_buf[vi], k] = data_buf[vi, vj, k]
                out_indices_buf[vi, valid_count_buf[vi]] = vj
                valid_count_buf[vi] = valid_count_buf[vi] + 1
            # If slot vj has not been claimed by a kept row so far, fill it
            # with the -1 sentinel.
            if vj >= valid_count_buf[vi]:
                for k in T.serial(0, 6):
                    out_buf[vi, vj, k] = T.float32(-1)
                out_indices_buf[vi, vj] = T.int32(-1)
def dot_product_4x4_i8i8i32_neon(
    A: T.Buffer((4, ), "int8", offset_factor=1),
    B: T.Buffer((4, 4), "int8", offset_factor=1),
    C: T.Buffer((4, ), "int32", offset_factor=1),
) -> None:
    # Accumulates into the 4-element int32 buffer C a value computed from the
    # 4 int8 elements of A and the 4x4 int8 buffer B, using the LLVM AArch64
    # NEON intrinsics smull, saddlp and addp.
    # Presumably C[i] += sum_k A[k] * B[i, k], going by the function name -
    # confirm against the matching description function.
    with T.block("root"):
        T.reads(C[0:4], A[0:4], B[0:4, 0:4])
        T.writes(C[0:4])

        # Load A as one int8x4 vector, view it as a single int32, duplicate
        # that int32 twice and view the result as int8x8 (A repeated twice).
        A_int8 = A.vload([0], "int8x4")
        re_int32 = T.reinterpret(A_int8, dtype="int32")
        vec_ai32 = T.broadcast(re_int32, 2)
        vec_a = T.reinterpret(vec_ai32, dtype="int8x8")

        # All 16 int8 values of B as one vector.
        vec_b = B.vload([0, 0], dtype="int8x16")

        # TODO(masahi): Remove duplication when inlined function call is supported
        # Low half of B: int8x8 multiply to int16x8, then pairwise add to int32x4.
        vec_b_low = T.vectorlow(vec_b, dtype="int8x8")

        multiply_low = T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
            T.uint32(2),
            vec_a,
            vec_b_low,
            dtype="int16x8",
        )

        pairwise_reduction_low = T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
            T.uint32(1),
            multiply_low,
            dtype="int32x4",
        )

        # High half of B: same two steps as for the low half.
        vec_b_high = T.vectorhigh(vec_b, dtype="int8x8")

        multiply_high = T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
            T.uint32(2),
            vec_a,
            vec_b_high,
            dtype="int16x8",
        )

        pairwise_reduction_high = T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
            T.uint32(1),
            multiply_high,
            dtype="int32x4",
        )

        # Combine the two int32x4 partial results with addp and accumulate
        # the int32x4 result into C[0:4].
        C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.addp.v4i32"),
            T.uint32(2),
            pairwise_reduction_low,
            pairwise_reduction_high,
            dtype="int32x4",
        )
def dp4a_impl(
    A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
    B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
    C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"),
) -> None:
    # Accumulates into C[0] the result of the external `__dp4a` call applied
    # to the 4 int8 elements of A and the 4 int8 elements of B.
    # A and B are declared in "shared" scope, C in "local" scope.
    with T.block("root"):
        T.reads(C[0], A[0:4], B[0:4])
        T.writes(C[0])

        # The third operand passed to __dp4a is a constant 0; the accumulation
        # into C[0] is done by the `+=` here.
        C[0] += T.call_pure_extern(
            "__dp4a", A.vload([0], "int8x4"), B.vload([0], "int8x4"), T.int32(0), dtype="int32"
        )
def gemm_mma_m16n8k32_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle):
    # Single T.ptx_mma call of shape "m16n8k32" on int8 inputs with int32
    # accumulators: A is (16, 32) int8, B is (8, 32) int8, C is (16, 8) int32.
    # Launched with blockIdx.y = blockIdx.x = 1 and 32 threads on threadIdx.x;
    # each thread stages its own elements in "local"-scope scratch storage.
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [16, 32], dtype="int8")
    B = T.match_buffer(b, [8, 32], dtype="int8")
    C = T.match_buffer(c, [16, 8], dtype="int32")
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    # Per-thread storage: 16 elements of A, 8 of B, 4 accumulators.
    MultiA = T.allocate([16], "int8", scope="local")
    MultiB = T.allocate([8], "int8", scope="local")
    Accum = T.allocate([4], "int32", scope="local")
    # Zero the accumulators before the multiply-accumulate.
    for i in range(4):
        Accum[i] = T.int32(0)

    # Gather this thread's 16 A elements: row = tx // 4 (+8 for elements 4..7
    # and 12..15), column = (tx % 4) * 4 + element % 4 (+16 for elements 8..15).
    for mma_multi_a_col in range(16):
        MultiA[mma_multi_a_col] = A[
            (tx % 32) // 4 + mma_multi_a_col % 8 // 4 * 8,
            (tx % 32) % 4 * 4 + mma_multi_a_col % 4 + mma_multi_a_col // 8 * 16,
        ]
    # Gather this thread's 8 B elements: row = tx // 4,
    # column = (tx % 4) * 4 + element % 4 (+16 for elements 4..7).
    for mma_multi_b_col in range(8):
        MultiB[mma_multi_b_col] = B[
            (tx % 32) // 4,
            (tx % 32) % 4 * 4 + mma_multi_b_col % 4 + mma_multi_b_col // 4 * 16,
        ]
    # Each storage argument is followed by an offset of 0.
    # NOTE(review): the trailing False is presumably the saturate flag -
    # confirm against the T.ptx_mma signature.
    T.evaluate(
        T.ptx_mma(
            "m16n8k32",
            "row",
            "col",
            "int8",
            "int8",
            "int32",
            MultiA,
            0,
            MultiB,
            0,
            Accum,
            0,
            False,
            dtype="int32",
        )
    )
    # Scatter the 4 accumulators back to C: row = tx // 4 (+8 for elements 2, 3),
    # column = (tx % 4) * 2 + element % 2.
    for mma_accum_c_id in range(4):
        C[
            (tx % 32) // 4 + mma_accum_c_id // 2 * 8,
            (tx % 32) % 4 * 2 + mma_accum_c_id % 2,
        ] = T.load("int32", Accum, mma_accum_c_id)
def gemm_mma_m16n8k256_row_col_b1b1s32(a: T.handle, b: T.handle, c: T.handle):
    # Single T.ptx_mma call of shape "m16n8k256" on 1-bit inputs with int32
    # accumulators: A is (16, 256) int1, B is (8, 256) int1, C is (16, 8) int32.
    # Launched with blockIdx.y = blockIdx.x = 1 and 32 threads on threadIdx.x;
    # each thread stages its own elements in "local"-scope scratch buffers.
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [16, 256], dtype="int1")
    B = T.match_buffer(b, [8, 256], dtype="int1")
    C = T.match_buffer(c, [16, 8], dtype="int32")
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    # Per-thread storage: 128 bits of A, 64 bits of B, 4 accumulators.
    MultiA = T.allocate([128], "int1", scope="local")
    MultiB = T.allocate([64], "int1", scope="local")
    Accum = T.allocate([4], "int32", scope="local")
    # Zero the accumulators before the multiply-accumulate.
    for i in range(4):
        Accum[i] = T.int32(0)

    # Gather this thread's 128 A bits: row = tx // 4 (+8 for the second group
    # of 32 within each 64), column = (tx % 4) * 32 + element % 32 (+128 for
    # elements 64..127).
    for mma_multi_a_col in range(128):
        MultiA[mma_multi_a_col] = A[
            (tx % 32) // 4 + mma_multi_a_col % 64 // 32 * 8,
            (tx % 32) % 4 * 32 + mma_multi_a_col % 32 + mma_multi_a_col // 64 * 128,
        ]
    # Gather this thread's B bits: row = tx // 4,
    # column = (tx % 4) * 32 + element % 32 (+128 for elements 32..63).
    # The loop must cover all 64 elements of MultiB: the `% 32` / `// 32`
    # terms split the index into two groups of 32, and the whole MultiB
    # buffer is handed to T.ptx_mma below. A shorter loop would leave most of
    # the buffer unwritten.
    for mma_multi_b_col in range(64):
        MultiB[mma_multi_b_col] = B[
            (tx % 32) // 4,
            (tx % 32) % 4 * 32 + mma_multi_b_col % 32 + mma_multi_b_col // 32 * 128,
        ]
    # Each buffer pointer is followed by an offset of 0.
    # NOTE(review): the trailing False is presumably the saturate flag -
    # confirm against the T.ptx_mma signature.
    T.evaluate(
        T.ptx_mma(
            "m16n8k256",
            "row",
            "col",
            "int1",
            "int1",
            "int32",
            MultiA.data,
            0,
            MultiB.data,
            0,
            Accum.data,
            0,
            False,
            dtype="int32",
        )
    )
    # Scatter the 4 accumulators back to C: row = tx // 4 (+8 for elements 2, 3),
    # column = (tx % 4) * 2 + element % 2.
    for mma_accum_c_id in range(4):
        C[
            (tx % 32) // 4 + mma_accum_c_id // 2 * 8,
            (tx % 32) % 4 * 2 + mma_accum_c_id % 2,
        ] = Accum[mma_accum_c_id]
def dot_product_16x4_u8i8i32_desc(
    A: T.Buffer((4,), "uint8", offset_factor=1),
    B: T.Buffer((16, 4), "int8", offset_factor=1),
    C: T.Buffer((16,), "int32", offset_factor=1),
) -> None:
    # Scalar description of a 16x4 dot product:
    #   C[i] += int32(A[k]) * int32(B[i, k])   for i in 0..15, k in 0..3
    # with A uint8 (4,), B int8 (16, 4) and C int32 (16,).
    with T.block("root"):
        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
        T.writes(C[0:16])
        for i in T.serial(0, 16):
            # NOTE(review): this T.init() sits directly under the `i` loop
            # rather than inside the "update" block - confirm that this
            # placement is intended and accepted by the TVMScript version in use.
            with T.init():
                C[i] = T.int32(0)
            for k in T.serial(0, 4):
                with T.block("update"):
                    # "SR": vi is bound as a spatial axis, vk as a reduction axis.
                    vi, vk = T.axis.remap("SR", [i, k])
                    # Both operands are cast to int32 before multiplying.
                    C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")
def sdot4(
    A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
    B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
    C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"),
) -> None:
    # Accumulates into C[0] the result of the LLVM intrinsic
    # `llvm.amdgcn.sdot4` applied to the 4 int8 elements of A and of B.
    # A and B are declared in "shared" scope, C in "local" scope.
    with T.block("root"):
        T.reads(C[0], A[0:4], B[0:4])
        T.writes(C[0])

        # Each int8x4 vector is reinterpreted as a single int32 operand.
        # The remaining operands are a constant 0 and a boolean 1.
        # NOTE(review): these are presumably the intrinsic's accumulator input
        # and clamp flag - confirm against the LLVM AMDGPU intrinsic definition.
        C[0] += T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.amdgcn.sdot4"),
            T.uint32(4),
            T.reinterpret(A.vload([0], "int8x4"), dtype="int32"),
            T.reinterpret(B.vload([0], "int8x4"), dtype="int32"),
            T.int32(0),
            T.bool(1),
            dtype="int32",
        )
def gemm_mma_m8n8k32_row_col_s4u4s32(a: T.handle, b: T.handle, c: T.handle):
    # Single T.ptx_mma call of shape "m8n8k32" with int4 A, uint4 B and int32
    # accumulators: A is (8, 32) int4, B is (8, 32) uint4, C is (8, 8) int32.
    # Launched with blockIdx.y = blockIdx.x = 1 and 32 threads on threadIdx.x;
    # each thread stages its own elements in "local"-scope scratch storage.
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [8, 32], dtype="int4")
    B = T.match_buffer(b, [8, 32], dtype="uint4")
    C = T.match_buffer(c, [8, 8], dtype="int32")
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    # Per-thread storage: 8 elements of A, 8 of B, 2 accumulators.
    MultiA = T.allocate([8], "int4", scope="local")
    MultiB = T.allocate([8], "uint4", scope="local")
    Accum = T.allocate([2], "int32", scope="local")
    # Zero the accumulators before the multiply-accumulate.
    for i in range(2):
        Accum[i] = T.int32(0)

    # Each thread copies 8 consecutive elements of row tx // 4, starting at
    # column (tx % 4) * 8, from A and from B (vectorized loops).
    for mma_multi_a_col in T.vectorized(8):
        MultiA[mma_multi_a_col] = A[(tx % 32) // 4, mma_multi_a_col + (tx % 32) % 4 * 8]
    for mma_multi_b_col in T.vectorized(8):
        MultiB[mma_multi_b_col] = B[(tx % 32) // 4, mma_multi_b_col + (tx % 32) % 4 * 8]
    # Each storage argument is followed by an offset of 0.
    # NOTE(review): the trailing False is presumably the saturate flag -
    # confirm against the T.ptx_mma signature.
    T.evaluate(
        T.ptx_mma(
            "m8n8k32",
            "row",
            "col",
            "int4",
            "uint4",
            "int32",
            MultiA,
            0,
            MultiB,
            0,
            Accum,
            0,
            False,
            dtype="int32",
        )
    )
    # Scatter the 2 accumulators back to C: row = tx // 4,
    # column = (tx % 4) * 2 + element.
    for mma_accum_c_id in range(2):
        C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = T.load(
            "int32", Accum, mma_accum_c_id
        )
def tir_argmax_val_idx(
    var_val: T.handle, var_idx: T.handle, var_argmax_v0: T.handle, var_argmax_v1: T.handle
) -> None:
    # Row-wise arg-max over paired (value, index) inputs with symbolic shape
    # (m, n): for each row i, argmax_v0[i] ends up holding the largest
    # val[i, k] and argmax_v1[i] the idx[i, k] that accompanied it.
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    m = T.var("int32")
    n = T.var("int32")
    val = T.match_buffer(var_val, [m, n], dtype="float32")
    idx = T.match_buffer(var_idx, [m, n], dtype="int32")
    argmax_v0 = T.match_buffer(var_argmax_v0, [m], dtype="float32")
    argmax_v1 = T.match_buffer(var_argmax_v1, [m], dtype="int32")
    for i0, i1 in T.grid(m, n):
        with T.block("argmax"):
            # "SR": i is bound as a spatial axis, k as a reduction axis.
            i, k = T.axis.remap("SR", [i0, i1])
            T.reads(val[i, k], idx[i, k])
            T.writes(argmax_v0[i], argmax_v1[i])
            with T.init():
                # Start from the smallest float32 and index -1.
                argmax_v0[i] = T.min_value("float32")
                argmax_v1[i] = T.int32(-1)
            # Both selects use the same `>=` comparison against the current
            # best, so on a tie the existing (value, index) pair is kept.
            v_argmax_v0: T.float32 = T.Select(argmax_v0[i] >= val[i, k], argmax_v0[i], val[i, k])
            v_argmax_v1: T.int32 = T.Select(argmax_v0[i] >= val[i, k], argmax_v1[i], idx[i, k])
            # Both new values are computed before either store, so the index
            # select sees the old argmax_v0[i].
            argmax_v0[i] = v_argmax_v0
            argmax_v1[i] = v_argmax_v1
def dot_product_16x4_u8i8i32_vnni(
    A: T.Buffer((4,), "uint8", offset_factor=1),
    B: T.Buffer((16, 4), "int8", offset_factor=1),
    C: T.Buffer((16,), "int32", offset_factor=1),
) -> None:
    # Accumulates into the 16-element int32 buffer C the result of the LLVM
    # intrinsic `llvm.x86.avx512.vpdpbusd.512` applied to the 4 uint8 elements
    # of A and the 16x4 int8 buffer B.
    # Presumably C[i] += sum_k A[k] * B[i, k], going by the function name -
    # confirm against the matching description function.
    with T.block("root"):
        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
        T.writes(C[0:16])

        # View the four uint8 values of A as one int32.
        A_u8x4 = A.vload([0], "uint8x4")
        A_i32 = T.reinterpret(A_u8x4, dtype="int32")

        # View all 64 int8 values of B as sixteen int32 lanes.
        B_i8x64 = B.vload([0, 0], dtype="int8x64")
        B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")

        # The intrinsic receives a zero int32x16 as its first data operand;
        # the accumulation into C is done by the `+=`. A_i32 is broadcast to
        # all 16 lanes.
        # NOTE(review): T.uint32(0) does not equal the number of data operands
        # (3) - confirm what this field means in the TVM version in use.
        C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin(  # Note: this is an update +=
            T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
            T.uint32(0),
            T.int32x16(0),
            T.broadcast(A_i32, 16),
            B_i32x16,
            dtype="int32x16",
        )
def dot_product_4x4_i8i8i32_sdot(
    A: T.Buffer((4, ), "int8", offset_factor=1),
    B: T.Buffer((4, 4), "int8", offset_factor=1),
    C: T.Buffer((4, ), "int32", offset_factor=1),
) -> None:
    # Accumulates into the 4-element int32 buffer C the result of the LLVM
    # intrinsic `llvm.aarch64.neon.sdot.v4i32.v16i8` applied to the 4 int8
    # elements of A and the 4x4 int8 buffer B.
    # Presumably C[i] += sum_k A[k] * B[i, k], going by the function name -
    # confirm against the matching description function.
    with T.block("root"):
        T.reads(C[0:4], A[0:4], B[0:4, 0:4])
        T.writes(C[0:4])

        # Load A as int8x4, view it as one int32, duplicate it four times and
        # view the result as int8x16 (A repeated four times).
        A_i8x4 = A.vload([0], "int8x4")
        A_i32 = T.reinterpret(A_i8x4, dtype="int32")
        vec_ai32 = T.broadcast(A_i32, 4)
        vec_a = T.reinterpret(vec_ai32, dtype="int8x16")

        # All 16 int8 values of B as one vector.
        vec_b = B.vload([0, 0], dtype="int8x16")

        # The intrinsic receives a zero int32x4 as its first data operand;
        # the accumulation into C is done by the `+=`.
        C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.sdot.v4i32.v16i8"),
            T.uint32(3),
            T.int32x4(0),
            vec_a,
            vec_b,
            dtype="int32x4",
        )
def constant_binds_wrapped():
    # Binds two constants that are explicitly wrapped in TIR dtype
    # constructors (int32 1 and float32 42.0) to local names, then evaluates
    # float32(x) + y.
    x = T.int32(1)
    y = T.float32(42.0)
    # x is cast to float32 so that both operands of the addition share a dtype.
    T.evaluate(T.cast(x, "float32") + y)
def preflattened_buffer_map(A: T.handle, B: T.handle):
    # Copies the single element of A into B. Both handles are matched to
    # 1-element buffers and each buffer is also passed to
    # T.preflattened_buffer.
    A_1 = T.match_buffer(A, [1])
    # For A_1, `align` is given as an int32 constant and `offset_factor` as an
    # int64 constant (the two use different integer widths).
    T.preflattened_buffer(A_1, [1], align=T.int32(1), offset_factor=T.int64(2))
    B_1 = T.match_buffer(B, [1])
    # For B_1 no align/offset_factor is passed.
    T.preflattened_buffer(B_1, [1])
    B_1[0] = A_1[0]