def dot_product_4x4_i8i8i32_neon(
    A: T.Buffer((4, ), "int8", offset_factor=1),
    B: T.Buffer((4, 4), "int8", offset_factor=1),
    C: T.Buffer((4, ), "int32", offset_factor=1),
) -> None:
    # TIR body that accumulates into the four int32 lanes of C a value built
    # from A (4 x int8) and B (4x4 int8) with the AArch64 NEON intrinsics
    # smull -> saddlp -> addp.
    with T.block("root"):
        T.reads(C[0:4], A[0:4], B[0:4, 0:4])
        T.writes(C[0:4])

        # Load the four int8 values of A, view them as one int32, duplicate
        # that int32 into two lanes and view the result as int8x8, i.e. the
        # four bytes of A repeated twice.
        A_int8 = A.vload([0], "int8x4")
        re_int32 = T.reinterpret(A_int8, dtype="int32")
        vec_ai32 = T.broadcast(re_int32, 2)
        vec_a = T.reinterpret(vec_ai32, dtype="int8x8")

        # All 16 int8 values of B as a single vector.
        vec_b = B.vload([0, 0], dtype="int8x16")

        # TODO(masahi): Remove duplication when inlined function call is supported
        # Lower half of B: multiply against vec_a (int8x8 operands, int16x8
        # result), then pairwise-add the int16x8 products into int32x4.
        vec_b_low = T.vectorlow(vec_b, dtype="int8x8")

        multiply_low = T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
            T.uint32(2),
            vec_a,
            vec_b_low,
            dtype="int16x8",
        )

        pairwise_reduction_low = T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
            T.uint32(1),
            multiply_low,
            dtype="int32x4",
        )

        # Upper half of B: same two steps as for the lower half.
        vec_b_high = T.vectorhigh(vec_b, dtype="int8x8")

        multiply_high = T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
            T.uint32(2),
            vec_a,
            vec_b_high,
            dtype="int16x8",
        )

        pairwise_reduction_high = T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
            T.uint32(1),
            multiply_high,
            dtype="int32x4",
        )

        # Combine the two int32x4 partial results with addp and accumulate
        # (+=) the int32x4 result into C[0:4].
        C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.addp.v4i32"),
            T.uint32(2),
            pairwise_reduction_low,
            pairwise_reduction_high,
            dtype="int32x4",
        )
def lowered_reducer_max(a: T.handle, b: T.handle) -> None:
    # Row-wise max of a 128x128 float32 buffer A into a 128-element buffer B,
    # written with an explicit cross-thread all-reduce over threadIdx.x.
    A = T.match_buffer(a, [128, 128], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    # One-element local buffer that receives the all-reduce result.
    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    for i in T.serial(0, 128):
        for k in T.thread_binding(0, 128, thread="threadIdx.x"):
            with T.block("B_cross_thread_reduction"):
                # vi is spatial ("S"), vk is the reduction axis ("R").
                vi, vk = T.axis.remap("SR", [i, k])
                T.reads([A[vi, vk]])
                T.writes([reduce_temp0[0]])
                # Attach the reducer: combiner max(x, y) with identity
                # min_value("float32").
                T.attr(
                    T.comm_reducer(lambda x, y: T.max(x, y), [T.min_value("float32")]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                # All-reduce A[vi, vk] across thread variable k; the result is
                # delivered through the data handle of reduce_temp0.
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1), A[vi, vk], True, reduce_temp0.data, k, dtype="handle"
                    )
                )
            with T.block("B_write_back"):
                vi = T.axis.spatial(128, i)
                T.reads([reduce_temp0[0]])
                T.writes([B[vi]])
                # Copy the reduced value to the output row.
                B[vi] = reduce_temp0[0]
def lowered_zero_rank_buffer(a: T.handle, b: T.handle) -> None:
    # Sum of a 128-element float32 buffer A into a zero-rank (shape []) buffer
    # B, written with an explicit cross-thread all-reduce over threadIdx.x.
    A = T.match_buffer(a, [128], dtype="float32")
    B = T.match_buffer(b, [], dtype="float32")
    # One-element local buffer that receives the all-reduce result.
    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    for k in T.thread_binding(0, 128, thread="threadIdx.x"):
        with T.block("B_cross_thread_reduction"):
            vk = T.axis.reduce(128, k)
            T.reads([A[vk]])
            T.writes([reduce_temp0[0]])
            # Attach the reducer: combiner x + y with identity float32(0).
            T.attr(
                T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
                "reduce_scope",
                T.reinterpret(T.uint64(0), dtype="handle"),
            )
            # All-reduce A[vk] across thread variable k into reduce_temp0.
            T.evaluate(
                T.tvm_thread_allreduce(
                    T.uint32(1), A[vk], True, reduce_temp0.data, k, dtype="handle"
                )
            )
        with T.block("B_write_back"):
            T.reads([reduce_temp0[0]])
            T.writes([B[()]])
            # B has no dimensions, so it is indexed with the empty tuple.
            B[()] = reduce_temp0[0]
def lowered_two_bound_loops(a: T.handle, b: T.handle) -> None:
    # Row-wise sum of a 128x128 float32 buffer A into B, where the reduction
    # axis is split across two thread-bound loops (threadIdx.x and
    # threadIdx.y) and both take part in the all-reduce.
    A = T.match_buffer(a, [128, 128], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    # One-element local buffer that receives the all-reduce result.
    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    for i in T.serial(0, 128):
        for ko in T.thread_binding(0, 4, thread="threadIdx.x"):
            for ki in T.thread_binding(0, 32, thread="threadIdx.y"):
                with T.block("B_cross_thread_reduction"):
                    vi = T.axis.spatial(128, i)
                    # Fused reduction index: 4 * 32 = 128 positions.
                    vk = T.axis.reduce(128, ko * 32 + ki)
                    T.reads([A[vi, vk]])
                    T.writes([reduce_temp0[0]])
                    # Attach the reducer: combiner x + y, identity float32(0).
                    T.attr(
                        T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
                        "reduce_scope",
                        T.reinterpret(T.uint64(0), dtype="handle"),
                    )
                    # Both thread variables (ko, ki) are listed as reduction
                    # thread indices.
                    T.evaluate(
                        T.tvm_thread_allreduce(
                            T.uint32(1), A[vi, vk], True, reduce_temp0.data, ko, ki, dtype="handle"
                        )
                    )
                with T.block("B_write_back"):
                    vi = T.axis.spatial(128, i)
                    T.reads([reduce_temp0[0]])
                    T.writes([B[vi]])
                    # Copy the reduced value to the output row.
                    B[vi] = reduce_temp0[0]
def lowered_multiple_blocks_under_reduction_loop(a: T.handle, b: T.handle) -> None:
    # Sum of a 16x16x16 float32 buffer A over its last two axes into B[16].
    # Under the threadIdx.x loop there are several blocks: an rfactor-style
    # partial sum into B_rf_local, an in-thread reduction of those partials,
    # a cross-thread all-reduce, and the write-back.
    A = T.match_buffer(a, [16, 16, 16], dtype="float32")
    B = T.match_buffer(b, [16], dtype="float32")
    B_rf_local = T.alloc_buffer([16, 16], dtype="float32", scope="local")
    # reduce_temp0 receives the all-reduce result; normal_reduce_temp0 is the
    # per-thread accumulator fed into the all-reduce.
    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    for i in T.thread_binding(0, 16, thread="blockIdx.x"):
        for k0o in T.thread_binding(0, 4, thread="threadIdx.x"):
            with T.block("B_in_thread_init"):
                T.reads([])
                T.writes([normal_reduce_temp0[0]])
                # Reset the per-thread accumulator to the additive identity.
                normal_reduce_temp0[0] = T.float32(0)
            for k0i0, k1 in T.grid(4, 16):
                with T.block("B_rf"):
                    # vk0 is spatial here: each (vk0, vi) cell of B_rf_local
                    # holds the sum of A[vi, vk0, :] over vk1.
                    vk0 = T.axis.spatial(16, k0o * 4 + k0i0)
                    vi, vk1 = T.axis.remap("SR", [i, k1])
                    T.reads([B_rf_local[vk0, vi], A[vi, vk0, vk1]])
                    T.writes([B_rf_local[vk0, vi]])
                    with T.init():
                        B_rf_local[vk0, vi] = T.float32(0)
                    B_rf_local[vk0, vi] = B_rf_local[vk0, vi] + A[vi, vk0, vk1]
            for k0i1 in T.serial(0, 4):
                with T.block("B_normal_reduction"):
                    # vk0 is a reduction axis here: fold this thread's four
                    # partial sums into normal_reduce_temp0.
                    vk0 = T.axis.reduce(16, k0o * 4 + k0i1)
                    vi = T.axis.spatial(16, i)
                    T.reads([B_rf_local[vk0, vi], normal_reduce_temp0[0]])
                    T.writes([normal_reduce_temp0[0]])
                    normal_reduce_temp0[0] = normal_reduce_temp0[0] + B_rf_local[vk0, vi]
            with T.block("B_cross_thread_reduction"):
                T.reads([normal_reduce_temp0[0]])
                T.writes([reduce_temp0[0]])
                # Attach the reducer: combiner x + y, identity float32(0).
                T.attr(
                    T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                # All-reduce the per-thread accumulator across k0o into
                # reduce_temp0.
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp0[0],
                        True,
                        reduce_temp0.data,
                        k0o,
                        dtype="handle",
                    )
                )
            with T.block("B_write_back"):
                vi = T.axis.spatial(16, i)
                T.reads([reduce_temp0[0]])
                T.writes([B[vi]])
                # Copy the reduced value to the output element.
                B[vi] = reduce_temp0[0]
def tir_packed_call() -> None:
    # Evaluates one tvm_call_cpacked to the symbol "tvm_test_cpacked", passing
    # three stack-built array descriptors (for A, B and C) followed by a
    # device_context handle. The call is typed as returning int32.
    A = T.var("handle")
    B = T.var("handle")
    C = T.var("handle")
    device_context = T.var("handle")
    # body
    T.evaluate(
        T.tvm_call_cpacked(
            "tvm_test_cpacked",
            # Each tvm_stack_make_array below is built with the same
            # arguments apart from the leading handle variable.
            T.tvm_stack_make_array(
                A,
                T.tvm_stack_make_shape(1, dtype="handle"),
                T.reinterpret(T.uint64(0), dtype="handle"),
                T.uint32(1),
                T.cast(0, dtype="float32"),
                0,
                dtype="handle",
            ),
            T.tvm_stack_make_array(
                B,
                T.tvm_stack_make_shape(1, dtype="handle"),
                T.reinterpret(T.uint64(0), dtype="handle"),
                T.uint32(1),
                T.cast(0, dtype="float32"),
                0,
                dtype="handle",
            ),
            T.tvm_stack_make_array(
                C,
                T.tvm_stack_make_shape(1, dtype="handle"),
                T.reinterpret(T.uint64(0), dtype="handle"),
                T.uint32(1),
                T.cast(0, dtype="float32"),
                0,
                dtype="handle",
            ),
            device_context,
            dtype="int32",
        )
    )
def lowered_with_block_predicate(a: T.handle, b: T.handle) -> None:
    # Row-wise sum of a 128x120 float32 buffer A into B[128]. The reduction
    # axis (extent 120) is covered by 4 serial steps x 32 threads = 128
    # iterations, so the in-thread reduction block carries a predicate that
    # masks out the iterations past 120.
    A = T.match_buffer(a, [128, 120], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    # reduce_temp0 receives the all-reduce result; normal_reduce_temp0 is the
    # per-thread accumulator fed into the all-reduce.
    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    for i in T.serial(0, 128):
        for ki in T.thread_binding(0, 32, thread="threadIdx.x"):
            with T.block("B_in_thread_init"):
                T.reads([])
                T.writes([normal_reduce_temp0[0]])
                # Reset the per-thread accumulator to the additive identity.
                normal_reduce_temp0[0] = T.float32(0)
            for ko in T.serial(0, 4):
                with T.block("B_normal_reduction"):
                    vi = T.axis.spatial(128, i)
                    vk = T.axis.reduce(120, ko * 32 + ki)
                    # Block predicate: skip fused indices outside [0, 120).
                    T.where(ko * 32 + ki < 120)
                    T.reads([A[vi, vk], normal_reduce_temp0[0]])
                    T.writes([normal_reduce_temp0[0]])
                    normal_reduce_temp0[0] = normal_reduce_temp0[0] + A[vi, vk]
            with T.block("B_cross_thread_reduction"):
                T.reads([normal_reduce_temp0[0]])
                T.writes([reduce_temp0[0]])
                # Attach the reducer: combiner x + y, identity float32(0).
                T.attr(
                    T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                # The destination operand must be the buffer's data handle
                # (reduce_temp0.data), not a loaded element: the all-reduce
                # writes its result through that handle.
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp0[0],
                        True,
                        reduce_temp0.data,
                        ki,
                        dtype="handle",
                    )
                )
            with T.block("B_write_back"):
                vi = T.axis.spatial(128, i)
                T.reads([reduce_temp0[0]])
                T.writes([B[vi]])
                # Copy the reduced value to the output row.
                B[vi] = reduce_temp0[0]
def sdot4(
    A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
    B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
    C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"),
) -> None:
    # TIR body that accumulates into C[0] the result of the AMDGPU
    # "llvm.amdgcn.sdot4" intrinsic applied to the four int8 values of A and
    # the four int8 values of B. A and B are declared in "shared" scope, C in
    # "local" scope.
    with T.block("root"):
        T.reads(C[0], A[0:4], B[0:4])
        T.writes(C[0])

        # Each int8x4 load is reinterpreted as a single int32 operand.
        # NOTE(review): the trailing T.int32(0) and T.bool(1) are presumably
        # the intrinsic's accumulator and clamp operands — confirm against the
        # LLVM AMDGPU intrinsic definition.
        C[0] += T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.amdgcn.sdot4"),
            T.uint32(4),
            T.reinterpret(A.vload([0], "int8x4"), dtype="int32"),
            T.reinterpret(B.vload([0], "int8x4"), dtype="int32"),
            T.int32(0),
            T.bool(1),
            dtype="int32",
        )
def dot_product_16x4_u8i8i32_vnni(
    A: T.Buffer((4,), "uint8", offset_factor=1),
    B: T.Buffer((16, 4), "int8", offset_factor=1),
    C: T.Buffer((16,), "int32", offset_factor=1),
) -> None:
    # TIR body that accumulates into the 16 int32 lanes of C the result of the
    # x86 "llvm.x86.avx512.vpdpbusd.512" intrinsic applied to A (4 x uint8)
    # and B (16x4 int8).
    with T.block("root"):
        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
        T.writes(C[0:16])

        # View the four uint8 values of A as one int32.
        A_u8x4 = A.vload([0], "uint8x4")
        A_i32 = T.reinterpret(A_u8x4, dtype="int32")

        # View all 64 int8 values of B as sixteen int32 lanes.
        B_i8x64 = B.vload([0, 0], dtype="int8x64")
        B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")

        # The intrinsic receives a zero int32x16 as its first vector operand,
        # A's int32 broadcast to 16 lanes, and B's int32x16 view; its result
        # is added onto the existing contents of C.
        C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin(  # Note: this is an update +=
            T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
            T.uint32(0),
            T.int32x16(0),
            T.broadcast(A_i32, 16),
            B_i32x16,
            dtype="int32x16",
        )
def dot_product_4x4_i8i8i32_sdot(
    A: T.Buffer((4, ), "int8", offset_factor=1),
    B: T.Buffer((4, 4), "int8", offset_factor=1),
    C: T.Buffer((4, ), "int32", offset_factor=1),
) -> None:
    # TIR body that accumulates into the four int32 lanes of C the result of
    # the AArch64 "llvm.aarch64.neon.sdot.v4i32.v16i8" intrinsic applied to A
    # (4 x int8) and B (4x4 int8).
    with T.block("root"):
        T.reads(C[0:4], A[0:4], B[0:4, 0:4])
        T.writes(C[0:4])

        # Load the four int8 values of A, view them as one int32, broadcast
        # to four lanes and view the result as int8x16, i.e. the four bytes
        # of A repeated four times.
        A_i8x4 = A.vload([0], "int8x4")
        A_i32 = T.reinterpret(A_i8x4, dtype="int32")
        vec_ai32 = T.broadcast(A_i32, 4)
        vec_a = T.reinterpret(vec_ai32, dtype="int8x16")

        # All 16 int8 values of B as a single vector.
        vec_b = B.vload([0, 0], dtype="int8x16")

        # The intrinsic receives a zero int32x4 as its first vector operand;
        # its int32x4 result is added onto the existing contents of C.
        C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.sdot.v4i32.v16i8"),
            T.uint32(3),
            T.int32x4(0),
            vec_a,
            vec_b,
            dtype="int32x4",
        )
def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None:
    # Softmax over the second axis of a 256x256 float32 buffer, in three
    # stages per row i0 (bound to blockIdx.x):
    #   1. row maximum  -> T_softmax_maxelem_shared[i0]
    #   2. sum of exp(A - max) -> T_softmax_expsum_shared[i0]
    #   3. exp(A - max) / sum  -> T_softmax_norm[i0, :]
    # Stages 1 and 2 each use an in-thread reduction (8 serial steps per
    # thread) followed by a cross-thread all-reduce over 32 threads.
    A = T.match_buffer(var_A, [256, 256], dtype="float32")
    T_softmax_norm = T.match_buffer(var_T_softmax_norm, [256, 256], dtype="float32")
    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    # *_temp0 buffers serve the max stage, *_temp1 buffers the exp-sum stage.
    # reduce_temp* receive all-reduce results; normal_reduce_temp* are the
    # per-thread accumulators.
    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    reduce_temp1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    normal_reduce_temp1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    for i0 in T.thread_binding(0, 256, thread="blockIdx.x"):
        # ---- Stage 1: row maximum ----
        for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
            with T.block("T_softmax_maxelem_normal_reduction_init"):
                T.reads([])
                T.writes([normal_reduce_temp0[0]])
                # Identity of max: the smallest float32.
                normal_reduce_temp0[0] = T.min_value("float32")
            for ax0_0 in T.serial(0, 8):
                with T.block("T_softmax_maxelem_normal_reduction"):
                    i0_1 = T.axis.spatial(256, i0)
                    # Fused reduction index: 8 * 32 = 256 columns.
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads([A[i0_1, k], normal_reduce_temp0[0]])
                    T.writes([normal_reduce_temp0[0]])
                    normal_reduce_temp0[0] = T.max(normal_reduce_temp0[0], A[i0_1, k])
            with T.block("T_softmax_maxelem_cross_thread_reduction"):
                T.reads([normal_reduce_temp0[0]])
                T.writes([reduce_temp0[0]])
                T.attr(
                    T.comm_reducer(lambda x, y: T.max(x, y), [T.min_value("float32")]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                # All-reduce the per-thread maxima across ax0_1.
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp0[0],
                        True,
                        reduce_temp0.data,
                        ax0_1,
                        dtype="handle",
                    )
                )
            with T.block("T_softmax_maxelem_write_back"):
                i0_2 = T.axis.spatial(256, i0)
                T.reads([reduce_temp0[0]])
                T.writes([T_softmax_maxelem_shared[i0_2]])
                T_softmax_maxelem_shared[i0_2] = reduce_temp0[0]
        # ---- Stage 2: sum of exponentials ----
        for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
            with T.block("T_softmax_expsum_normal_reduction_init"):
                T.reads([])
                T.writes([normal_reduce_temp1[0]])
                # Identity of addition.
                normal_reduce_temp1[0] = T.float32(0)
            for ax0_0 in T.serial(0, 8):
                with T.block("T_softmax_expsum_normal_reduction"):
                    i0_3 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads([
                        A[i0_3, k],
                        T_softmax_maxelem_shared[i0_3],
                        normal_reduce_temp1[0],
                    ])
                    T.writes([normal_reduce_temp1[0]])
                    # Accumulate exp(A - row_max) for this thread's columns.
                    normal_reduce_temp1[0] = normal_reduce_temp1[0] + T.exp(
                        A[i0_3, k] - T_softmax_maxelem_shared[i0_3], dtype="float32")
            with T.block("T_softmax_expsum_cross_thread_reduction"):
                T.reads([normal_reduce_temp1[0]])
                T.writes([reduce_temp1[0]])
                T.attr(
                    T.comm_reducer(lambda x_1, y_1: x_1 + y_1, [T.float32(0)]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                # All-reduce the per-thread sums across ax0_1.
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp1[0],
                        True,
                        reduce_temp1.data,
                        ax0_1,
                        dtype="handle",
                    )
                )
            with T.block("T_softmax_expsum_write_back"):
                i0_4 = T.axis.spatial(256, i0)
                T.reads([reduce_temp1[0]])
                T.writes([T_softmax_expsum_shared[i0_4]])
                T_softmax_expsum_shared[i0_4] = reduce_temp1[0]
        # ---- Stage 3: normalisation ----
        for i1_0 in T.serial(0, 8):
            for i1_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_5 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_0 * 32 + i1_1)
                    T.reads([
                        A[i0_5, i1],
                        T_softmax_maxelem_shared[i0_5],
                        T_softmax_expsum_shared[i0_5],
                    ])
                    T.writes([T_softmax_norm[i0_5, i1]])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_5, i1] = (T.exp(
                        A[i0_5, i1] - T_softmax_maxelem_shared[i0_5],
                        dtype="float32",
                    ) / T_softmax_expsum_shared[i0_5])
def lowered_single_reduction_loop_with_block_predicate(
    A: T.Buffer[(256, 256), "float32"],
    T_softmax_norm: T.Buffer[(256, 256), "float32"]) -> None:
    # Softmax over the second axis of a 256x256 float32 buffer where every
    # stage is bound to 512 threads on threadIdx.x although the row has only
    # 256 elements; the in-thread reduction blocks and the normalisation
    # block carry a predicate that masks indices >= 256.
    # Stages per row i0:
    #   1. row maximum  -> T_softmax_maxelem_shared[i0]
    #   2. sum of exp(A - max) -> T_softmax_expsum_shared[i0]
    #   3. exp(A - max) / sum  -> T_softmax_norm[i0, :]
    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    # *_0 buffers serve the max stage, *_1 buffers the exp-sum stage.
    # cross_thread_* receive all-reduce results; in_thread_* are the
    # per-thread accumulators.
    cross_thread_0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    in_thread_0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    cross_thread_1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    in_thread_1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    for i0 in T.serial(256):
        # ---- Stage 1: row maximum ----
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_maxelem_in_thread_init"):
                    T.reads()
                    T.writes(in_thread_0[0])
                    # Initial value used as the identity of max (matches the
                    # identity given to the comm_reducer below).
                    in_thread_0[0] = T.float32(-3.4028234663852886e38)
                with T.block("T_softmax_maxelem_in_thread"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    # Predicate: only thread indices below 256 contribute.
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(A[i0_1, k], in_thread_0[0])
                    T.writes(in_thread_0[0])
                    in_thread_0[0] = T.max(in_thread_0[0], A[i0_1, k])
                with T.block("T_softmax_maxelem_cross_thread"):
                    T.reads(in_thread_0[0])
                    T.writes(cross_thread_0[0])
                    T.attr(
                        T.comm_reducer(lambda x, y: T.max(x, y),
                                       [T.float32(-3.4028234663852886e38)]),
                        "reduce_scope",
                        T.reinterpret(T.uint64(0), dtype="handle"),
                    )
                    # All-reduce the per-thread maxima across ax1_1.
                    T.evaluate(
                        T.tvm_thread_allreduce(
                            T.uint32(1),
                            in_thread_0[0],
                            True,
                            cross_thread_0.data,
                            ax1_1,
                            dtype="handle",
                        )
                    )
                with T.block("T_softmax_maxelem_write_back"):
                    i0_2 = T.axis.spatial(256, i0)
                    T.reads(cross_thread_0[0])
                    T.writes(T_softmax_maxelem_shared[i0_2])
                    T_softmax_maxelem_shared[i0_2] = cross_thread_0[0]
        # ---- Stage 2: sum of exponentials ----
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_expsum_in_thread_init"):
                    T.reads()
                    T.writes(in_thread_1[0])
                    # Identity of addition.
                    in_thread_1[0] = T.float32(0)
                with T.block("T_softmax_expsum_in_thread"):
                    i0_3 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    # Predicate: only thread indices below 256 contribute.
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(A[i0_3, k], T_softmax_maxelem_shared[i0_3], in_thread_1[0])
                    T.writes(in_thread_1[0])
                    in_thread_1[0] = in_thread_1[0] + T.exp(
                        A[i0_3, k] - T_softmax_maxelem_shared[i0_3], dtype="float32")
                with T.block("T_softmax_expsum_cross_thread"):
                    T.reads(in_thread_1[0])
                    T.writes(cross_thread_1[0])
                    T.attr(
                        T.comm_reducer(lambda x_1, y_1: x_1 + y_1, [T.float32(0)]),
                        "reduce_scope",
                        T.reinterpret(T.uint64(0), dtype="handle"),
                    )
                    # All-reduce the per-thread sums across ax1_1.
                    T.evaluate(
                        T.tvm_thread_allreduce(
                            T.uint32(1),
                            in_thread_1[0],
                            True,
                            cross_thread_1.data,
                            ax1_1,
                            dtype="handle",
                        )
                    )
                with T.block("T_softmax_expsum_write_back"):
                    i0_4 = T.axis.spatial(256, i0)
                    T.reads(cross_thread_1[0])
                    T.writes(T_softmax_expsum_shared[i0_4])
                    T_softmax_expsum_shared[i0_4] = cross_thread_1[0]
        # ---- Stage 3: normalisation ----
        for i1_0 in T.serial(1):
            for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_5 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_1)
                    # Predicate: only thread indices below 256 write output.
                    T.where(i1_0 * 512 + i1_1 < 256)
                    T.reads(A[i0_5, i1], T_softmax_maxelem_shared[i0_5],
                            T_softmax_expsum_shared[i0_5])
                    T.writes(T_softmax_norm[i0_5, i1])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_5, i1] = (
                        T.exp(A[i0_5, i1] - T_softmax_maxelem_shared[i0_5], dtype="float32") /
                        T_softmax_expsum_shared[i0_5])
def main(
    placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
    placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
    conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"],
) -> None:
    # Tiled schedule that fills conv2d_NCHWc_int8 (int32) from placeholder
    # (uint8) and placeholder_1 (int8). For each outer tile it first zeroes
    # the output tile, then accumulates into it using the x86
    # "llvm.x86.avx512.vpdpbusd.512" intrinsic on 16-lane int32 vectors.
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    for i0_0, i1_0, i2_0, i3_0, i4_0_0, i0_1, i1_1, i2_1, i3_1, i4_0_1, i5_0, i6_0 in T.grid(
            1, 1, 2, 1, 1, 1, 4, 1, 14, 1, 1, 1):
        # ---- Initialisation of the output tile ----
        for i1_2_init, i2_2_init, i2_3_init, i3_3_init in T.grid(4, 7, 4, 4):
            with T.block("conv2d_NCHWc_int8_o_init"):
                n = T.axis.spatial(1, 0)
                oc_chunk = T.axis.spatial(16, i1_1 * 4 + i1_2_init)
                oh = T.axis.spatial(56, i2_0 * 28 + i2_2_init * 4 + i2_3_init)
                ow = T.axis.spatial(56, i3_1 * 4 + i3_3_init)
                oc_block_o = T.axis.spatial(1, 0)
                T.reads()
                T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16])
                # Zero the 16 innermost output elements with a vectorized loop.
                for i4_1 in T.vectorized(16):
                    with T.block("conv2d_NCHWc_int8_init"):
                        oc_block_init = T.axis.spatial(16, i4_1)
                        T.reads()
                        T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_init])
                        conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_init] = 0
        # ---- Accumulation into the output tile ----
        for (
                i7_0,
                i8_0,
                i9_0_0,
                i0_2,
                i1_2,
                i2_2,
                i3_2,
                i4_0_2,
                i5_1,
                i6_1,
                i7_1,
                i8_1,
                i9_0_1,
                i0_3,
                i1_3,
                i2_3,
                i3_3,
                i4_0_3,
        ) in T.grid(4, 1, 1, 1, 4, 7, 1, 1, 1, 1, 1, 4, 1, 1, 1, 4, 4, 1):
            with T.block("conv2d_NCHWc_int8_o_update"):
                n = T.axis.spatial(1, 0)
                oc_chunk = T.axis.spatial(16, i1_1 * 4 + i1_2)
                oh = T.axis.spatial(56, i2_0 * 28 + i2_2 * 4 + i2_3)
                ow = T.axis.spatial(56, i3_1 * 4 + i3_3)
                oc_block_o = T.axis.spatial(1, 0)
                # Kernel height/width reduction axes have extent 1 here.
                kh = T.axis.reduce(1, 0)
                kw = T.axis.reduce(1, 0)
                ic_outer, ic_f_inner = T.axis.remap("RR", [i7_0, i8_1])
                ic_s_inner_o = T.axis.reduce(1, 0)
                T.reads(
                    conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16],
                    placeholder[n, ic_outer, oh + kh, ow + kw,
                                ic_f_inner * 4:ic_f_inner * 4 + 4],
                    placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:16, 0:4],
                )
                T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16])
                # Sub-buffer views of the three regions touched by this block:
                # A = 4 uint8 inputs, B = 16x4 int8 weights, C = 16 int32
                # outputs.
                A = T.match_buffer(
                    placeholder[n, ic_outer, oh + kh, ow + kw,
                                ic_f_inner * 4:ic_f_inner * 4 + 4],
                    [4],
                    dtype="uint8",
                    offset_factor=1,
                )
                B = T.match_buffer(
                    placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:16, 0:4],
                    [16, 4],
                    dtype="int8",
                    offset_factor=1,
                )
                C = T.match_buffer(
                    conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16],
                    [16],
                    dtype="int32",
                    offset_factor=1,
                )
                # View the four uint8 inputs as one int32 and the 64 int8
                # weights as sixteen int32 lanes.
                A_u8x4 = A.vload([0], "uint8x4")
                A_i32 = T.reinterpret(A_u8x4, dtype="int32")
                B_i8x64 = B.vload([0, 0], dtype="int8x64")
                B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")
                # The intrinsic receives a zero vector, A broadcast to 16
                # lanes, and B; its result is added onto the current C.
                C[T.ramp(0, 1, 16)] = C[T.ramp(0, 1, 16)] + T.call_llvm_pure_intrin(
                    T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
                    T.uint32(0),
                    T.broadcast(0, 16),
                    T.broadcast(A_i32, 16),
                    B_i32x16,
                    dtype="int32x16",
                )