def expected_func(): B = T.allocate([16], "int32", "shared") # The indices for B should each be a single Ramp node, and # should not be the sum of a Ramp and Broadcast node. B[0 * 4:0 * 4 + 4] = T.broadcast(0, 4) B[1 * 4:1 * 4 + 4] = T.broadcast(1, 4) B[2 * 4:2 * 4 + 4] = T.broadcast(2, 4) B[3 * 4:3 * 4 + 4] = T.broadcast(3, 4)
def dot_product_4x4_i8i8i32_neon(
    A: T.Buffer((4, ), "int8", offset_factor=1),
    B: T.Buffer((4, 4), "int8", offset_factor=1),
    C: T.Buffer((4, ), "int32", offset_factor=1),
) -> None:
    # Accumulates into C[0:4] a value computed from the 4 int8 elements of A
    # and the 4x4 int8 elements of B using AArch64 NEON LLVM intrinsics
    # (smull -> saddlp -> addp).
    #
    # A: 4 x int8 input, read as one int8x4 vector.
    # B: 4x4 int8 input, read as one int8x16 vector.
    # C: 4 x int32 accumulator; both read and written (see T.reads / T.writes).
    with T.block("root"):
        T.reads(C[0:4], A[0:4], B[0:4, 0:4])
        T.writes(C[0:4])

        # Replicate the four int8 values of A twice: view them as one int32,
        # broadcast that to 2 lanes, then view the result as int8x8.
        A_int8 = A.vload([0], "int8x4")
        re_int32 = T.reinterpret(A_int8, dtype="int32")
        vec_ai32 = T.broadcast(re_int32, 2)
        vec_a = T.reinterpret(vec_ai32, dtype="int8x8")

        vec_b = B.vload([0, 0], dtype="int8x16")

        # TODO(masahi): Remove duplication when inlined function call is supported
        # Low half of B: smull on int8x8 operands gives int16x8, then saddlp
        # reduces it to int32x4.
        # NOTE(review): the T.uint32(N) argument after the intrinsic id is
        # presumably the operand count -- confirm against call_llvm_pure_intrin.
        vec_b_low = T.vectorlow(vec_b, dtype="int8x8")

        multiply_low = T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
            T.uint32(2),
            vec_a,
            vec_b_low,
            dtype="int16x8",
        )

        pairwise_reduction_low = T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
            T.uint32(1),
            multiply_low,
            dtype="int32x4",
        )

        # High half of B: same smull + saddlp sequence as the low half.
        vec_b_high = T.vectorhigh(vec_b, dtype="int8x8")

        multiply_high = T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
            T.uint32(2),
            vec_a,
            vec_b_high,
            dtype="int16x8",
        )

        pairwise_reduction_high = T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
            T.uint32(1),
            multiply_high,
            dtype="int32x4",
        )

        # Combine the low/high partial results with addp and accumulate the
        # int32x4 result into C[0:4].
        C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.addp.v4i32"),
            T.uint32(2),
            pairwise_reduction_low,
            pairwise_reduction_high,
            dtype="int32x4",
        )
def main(a: T.handle, b: T.handle) -> None:
    # GPU kernel fixture over flat float32 buffers A (14*14*256*256 elements)
    # and B (14*14*512*256 elements). It stages a zero-padded tile of A in a
    # 512000-element "shared" allocation, copies 8 values at a time into a
    # "local" scratch buffer, accumulates them into B_local, and writes the
    # 64 per-thread results to B.
    #
    # NOTE(review): loop nesting was reconstructed from a whitespace-collapsed
    # source; in particular the placement of `Apad_shared[512000 - 1] = 0.0`
    # inside the `ax3_inner_outer` loop should be confirmed.

    # function attr dict
    # "tir.noalias" is the attribute key used by the other `main` fixtures in
    # this file; the previous "T.noalias" spelling is not that key.
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # var definition
    threadIdx_x = T.env_thread("threadIdx.x")
    threadIdx_y = T.env_thread("threadIdx.y")
    blockIdx_x = T.env_thread("blockIdx.x")
    blockIdx_y = T.env_thread("blockIdx.y")
    blockIdx_z = T.env_thread("blockIdx.z")
    A = T.match_buffer(a, [14 * 14 * 256 * 256], dtype="float32")
    B = T.match_buffer(b, [14 * 14 * 512 * 256], dtype="float32")
    # body
    T.launch_thread(blockIdx_z, 196)
    B_local = T.allocate([64], "float32", "local")
    Apad_shared = T.allocate([512000], "float32", "shared")
    Apad_shared_local = T.allocate([8], "float32", "local")
    T.launch_thread(blockIdx_y, 8)
    T.launch_thread(blockIdx_x, 4)
    T.launch_thread(threadIdx_y, 8)
    T.launch_thread(threadIdx_x, 8)
    # Zero the 8x8 per-thread accumulator.
    for ff_c_init, nn_c_init in T.grid(8, 8):
        B_local[ff_c_init * 8 + nn_c_init] = T.float32(0)
    for rc_outer, ry, rx in T.grid(32, 3, 3):
        for ax3_inner_outer in T.serial(0, 2):
            # Vectorized (4-lane) load from A into shared memory; lanes are
            # replaced by 0 when the row/column offset falls outside [1, 15).
            Apad_shared[T.ramp(
                threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1,
                4)] = T.if_then_else(
                    1 <= blockIdx_z // 14 + ry
                    and blockIdx_z // 14 + ry < 15
                    and 1 <= rx + blockIdx_z % 14
                    and rx + blockIdx_z % 14 < 15,
                    A[T.ramp(
                        ry * 917504 + blockIdx_z * 65536 + rx * 65536
                        + rc_outer * 2048 + threadIdx_y * 256
                        + blockIdx_x * 64 + threadIdx_x * 8
                        + ax3_inner_outer * 4 - 983040, 1, 4)],
                    T.broadcast(T.float32(0), 4),
                    dtype="float32x4",
                )
            # Access of the last element of Apad_shared prevents
            # buffer compacting from reducing the amount of shared
            # memory used.
            Apad_shared[512000 - 1] = 0.0
        for rc_inner in T.serial(0, 8):
            # Copy 8 shared values into the local scratch buffer ...
            for ax3 in T.serial(0, 8):
                Apad_shared_local[ax3] = Apad_shared[rc_inner * 64 + threadIdx_x * 8 + ax3]
            # ... and add them into every row of the 8x8 accumulator.
            for ff_c, nn_c in T.grid(8, 8):
                B_local[ff_c * 8 + nn_c] = B_local[ff_c * 8 + nn_c] + Apad_shared_local[nn_c]
    # Write the 8x8 accumulator back to B.
    for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8):
        B[blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048
          + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8
          + nn_inner_inner_inner] = B_local[ff_inner_inner_inner * 8 + nn_inner_inner_inner]
# fmt: on
def dot_product_16x4_u8i8i32_vnni(
    A: T.Buffer((4,), "uint8", offset_factor=1),
    B: T.Buffer((16, 4), "int8", offset_factor=1),
    C: T.Buffer((16,), "int32", offset_factor=1),
) -> None:
    # Accumulates into C[0:16] the result of the x86 AVX-512 `vpdpbusd`
    # LLVM intrinsic applied to A (4 x uint8) and B (16x4 int8).
    #
    # A: 4 x uint8 input, read as one uint8x4 vector.
    # B: 16x4 int8 input, read as one int8x64 vector.
    # C: 16 x int32 accumulator; both read and written (see T.reads / T.writes).
    with T.block("root"):
        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
        T.writes(C[0:16])

        # View the four uint8 values of A as a single int32 so they can be
        # broadcast to all 16 lanes below.
        A_u8x4 = A.vload([0], "uint8x4")
        A_i32 = T.reinterpret(A_u8x4, dtype="int32")

        # View the 64 int8 values of B as 16 int32 lanes.
        B_i8x64 = B.vload([0, 0], dtype="int8x64")
        B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")

        # The intrinsic is given a zero int32x16 as its first vector operand;
        # the accumulation into C is done by the `+=` on the store.
        C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin(  # Note: this is an update +=
            T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
            T.uint32(0),
            T.int32x16(0),
            T.broadcast(A_i32, 16),
            B_i32x16,
            dtype="int32x16",
        )
def dot_product_4x4_i8i8i32_sdot(
    A: T.Buffer((4, ), "int8", offset_factor=1),
    B: T.Buffer((4, 4), "int8", offset_factor=1),
    C: T.Buffer((4, ), "int32", offset_factor=1),
) -> None:
    # Accumulates into C[0:4] the result of the AArch64 NEON `sdot` LLVM
    # intrinsic applied to A (4 x int8) and B (4x4 int8).
    #
    # A: 4 x int8 input, read as one int8x4 vector.
    # B: 4x4 int8 input, read as one int8x16 vector.
    # C: 4 x int32 accumulator; both read and written (see T.reads / T.writes).
    with T.block("root"):
        T.reads(C[0:4], A[0:4], B[0:4, 0:4])
        T.writes(C[0:4])

        # Replicate the four int8 values of A four times: view them as one
        # int32, broadcast to 4 lanes, then view the result as int8x16.
        A_i8x4 = A.vload([0], "int8x4")
        A_i32 = T.reinterpret(A_i8x4, dtype="int32")
        vec_ai32 = T.broadcast(A_i32, 4)
        vec_a = T.reinterpret(vec_ai32, dtype="int8x16")

        vec_b = B.vload([0, 0], dtype="int8x16")

        # The intrinsic is given a zero int32x4 as its first vector operand;
        # the accumulation into C is done by the `+=` on the store.
        C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.sdot.v4i32.v16i8"),
            T.uint32(3),
            T.int32x4(0),
            vec_a,
            vec_b,
            dtype="int32x4",
        )
def main(
    inputs: T.Buffer[(1, 4, 4, 512), "float32"],
    weight: T.Buffer[(4, 4, 512, 256), "float32"],
    conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None:
    # GPU kernel fixture written with explicit T.store / T.load calls. Each
    # thread accumulates 8 float32 values in a "local" buffer from tiles of
    # `inputs` and `weight` staged in "shared" buffers, then writes them to
    # `conv2d_transpose_nhwc`.
    #
    # inputs:                (1, 4, 4, 512) float32, read through inputs.data.
    # weight:                (4, 4, 512, 256) float32, read through weight.data.
    # conv2d_transpose_nhwc: (1, 8, 8, 256) float32, written through .data.
    #
    # NOTE(review): loop nesting was reconstructed from a whitespace-collapsed
    # source; confirm that the 7-way T.grid accumulation loop sits inside the
    # `i6_0` loop.

    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # var definition
    threadIdx_x = T.env_thread("threadIdx.x")
    blockIdx_x = T.env_thread("blockIdx.x")
    # body
    T.launch_thread(blockIdx_x, 64)
    conv2d_transpose_nhwc_local = T.allocate([8], "float32", "local")
    PadInput_shared = T.allocate([768], "float32", "shared")
    weight_shared = T.allocate([4096], "float32", "shared")
    T.launch_thread(threadIdx_x, 32)
    # Zero the 8-element per-thread accumulator.
    for i2_3_init, i1_4_init, i2_4_init in T.grid(2, 2, 2):
        T.store(conv2d_transpose_nhwc_local,
                i1_4_init * 4 + i2_3_init * 2 + i2_4_init, T.float32(0), True)
    for i6_0 in T.serial(16):
        # Stage 768 input values in shared memory; positions that fail the
        # bounds predicate are filled with 0.
        for ax0_ax1_ax2_ax3_fused_0 in T.serial(24):
            T.store(
                PadInput_shared,
                ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x,
                T.if_then_else(
                    128 <= ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x
                    and ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x < 640
                    and 1 <= blockIdx_x // 32 * 2 +
                    (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32
                    and blockIdx_x // 32 * 2 +
                    (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32 < 5,
                    T.load(
                        "float32", inputs.data,
                        blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512
                        + i6_0 * 32 + threadIdx_x - 2560),
                    T.float32(0),
                    dtype="float32"), True)
        # Stage 4096 weight values in shared memory, 4 lanes per store.
        for ax0_ax1_ax2_ax3_fused_0 in T.serial(32):
            T.store(
                weight_shared,
                T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4),
                T.load(
                    "float32x4", weight.data,
                    T.ramp(
                        (ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4)
                        // 256 * 131072 + i6_0 * 8192 +
                        (ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4)
                        % 256 // 8 * 256 + blockIdx_x % 32 * 8
                        + threadIdx_x % 2 * 4, 1, 4), T.broadcast(True, 4)),
                T.broadcast(True, 4))
        # Multiply-accumulate: the input term contributes only when both
        # (i1_4 + i4_2) and (i2_4 + i5_2) are even, otherwise 0 is used.
        for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(
                4, 2, 4, 4, 8, 2, 2):
            T.store(
                conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4,
                T.load("float32", conv2d_transpose_nhwc_local,
                       i1_4 * 4 + i2_3 * 2 + i2_4) +
                T.if_then_else(
                    (i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0,
                    T.load(
                        "float32", PadInput_shared,
                        threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 +
                        (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2),
                    T.float32(0),
                    dtype="float32") * T.load(
                        "float32", weight_shared,
                        i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840
                        - i5_2 * 256 - i4_2 * 1024), True)
    # Write the 8 accumulated values back to the output buffer.
    for ax1, ax2 in T.grid(2, 4):
        T.store(
            conv2d_transpose_nhwc.data,
            threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024
            + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8,
            T.load("float32", conv2d_transpose_nhwc_local, ax1 * 4 + ax2),
            True)
def main(a: T.handle, b: T.handle) -> None:
    # GPU kernel fixture written with explicit T.store / T.load calls over
    # float32 buffers A ([14, 14, 256, 256]) and B ([14, 14, 512, 256]). It
    # allocates a 6400000-element "local" accumulator, stages a zero-padded
    # tile of A in a 512-element "shared" buffer, accumulates through an
    # 8-element "local" scratch buffer, and writes 64 results per thread to B.
    #
    # NOTE(review): loop nesting was reconstructed from a whitespace-collapsed
    # source; confirm against the upstream layout.

    # function attr dict
    # "tir.noalias" is the attribute key used by the other `main` fixtures in
    # this file; the previous "T.noalias" spelling is not that key.
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # var definition
    threadIdx_x = T.env_thread("threadIdx.x")
    threadIdx_y = T.env_thread("threadIdx.y")
    blockIdx_x = T.env_thread("blockIdx.x")
    blockIdx_y = T.env_thread("blockIdx.y")
    blockIdx_z = T.env_thread("blockIdx.z")
    A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32")
    B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32")
    # body
    T.launch_thread(blockIdx_z, 196)
    B_local = T.allocate([6400000], "float32", "local")
    Apad_shared = T.allocate([512], "float32", "shared")
    Apad_shared_local = T.allocate([8], "float32", "local")
    T.launch_thread(blockIdx_y, 8)
    T.launch_thread(blockIdx_x, 4)
    T.launch_thread(threadIdx_y, 8)
    T.launch_thread(threadIdx_x, 8)
    # Zero the first 8x8 entries of the accumulator.
    for ff_c_init, nn_c_init in T.grid(8, 8):
        T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True)
    for rc_outer, ry, rx in T.grid(32, 3, 3):
        for ax3_inner_outer in T.serial(0, 2):
            # Vectorized (4-lane) load from A into shared memory; lanes are
            # replaced by 0 when the row/column offset falls outside [1, 15).
            T.store(
                Apad_shared,
                T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4),
                T.if_then_else(
                    1 <= blockIdx_z // 14 + ry
                    and blockIdx_z // 14 + ry < 15
                    and 1 <= rx + blockIdx_z % 14
                    and rx + blockIdx_z % 14 < 15,
                    T.load(
                        "float32x4", A.data,
                        T.ramp(
                            ry * 917504 + blockIdx_z * 65536 + rx * 65536
                            + rc_outer * 2048 + threadIdx_y * 256
                            + blockIdx_x * 64 + threadIdx_x * 8
                            + ax3_inner_outer * 4 - 983040, 1, 4),
                        T.broadcast(True, 4)),
                    T.broadcast(T.float32(0), 4),
                    dtype="float32x4"),
                T.broadcast(True, 4))
        for rc_inner in T.serial(0, 8):
            # Copy 8 shared values into the local scratch buffer ...
            for ax3 in T.serial(0, 8):
                T.store(Apad_shared_local, ax3,
                        T.load("float32", Apad_shared,
                               rc_inner * 64 + threadIdx_x * 8 + ax3), True)
            # ... and add them into every row of the 8x8 accumulator.
            for ff_c, nn_c in T.grid(8, 8):
                T.store(B_local, ff_c * 8 + nn_c,
                        T.load("float32", B_local, ff_c * 8 + nn_c)
                        + T.load("float32", Apad_shared_local, nn_c), True)
    # Write the 8x8 accumulator back to B.
    for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8):
        T.store(
            B.data,
            blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048
            + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8
            + nn_inner_inner_inner,
            T.load("float32", B_local,
                   ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)
# fmt: on
def expected() -> None:
    # Declares a one-element int32x8 buffer over the pointer returned by the
    # extern call "dummy_func" and stores 42 broadcast over 8 lanes into
    # element 0.
    A_data: T.Ptr[T.int32x8] = T.call_extern("dummy_func", dtype="handle")
    A = T.buffer_decl([1], "int32x8", data=A_data)
    A[0] = T.broadcast(42, 8)
def before() -> None:
    # Declares an 8-element int32 buffer over the pointer returned by the
    # extern call "dummy_func" and stores 42 broadcast over 8 lanes into the
    # slice A[0:8].
    A_data: T.Ptr[T.int32] = T.call_extern("dummy_func", dtype="handle")
    A = T.buffer_decl([8], "int32", data=A_data)
    A[0:8] = T.broadcast(42, 8)
def main(
    placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
    placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
    conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"],
) -> None:
    # Block-structured fixture that fills `conv2d_NCHWc_int8` from
    # `placeholder` (uint8) and `placeholder_1` (int8). An init block zeroes
    # each 16-wide output row with a vectorized loop; an update block then
    # accumulates into that row with the x86 AVX-512 `vpdpbusd` LLVM
    # intrinsic, operating on sub-buffers bound via T.match_buffer.
    #
    # placeholder:       (1, 4, 56, 56, 16) uint8 input.
    # placeholder_1:     (16, 4, 1, 1, 4, 16, 4) int8 input.
    # conv2d_NCHWc_int8: (1, 16, 56, 56, 16) int32 output.

    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    for i0_0, i1_0, i2_0, i3_0, i4_0_0, i0_1, i1_1, i2_1, i3_1, i4_0_1, i5_0, i6_0 in T.grid(
            1, 1, 2, 1, 1, 1, 4, 1, 14, 1, 1, 1):
        for i1_2_init, i2_2_init, i2_3_init, i3_3_init in T.grid(
                4, 7, 4, 4):
            # Outer init block: covers one 16-element output row.
            with T.block("conv2d_NCHWc_int8_o_init"):
                n = T.axis.spatial(1, 0)
                oc_chunk = T.axis.spatial(16, i1_1 * 4 + i1_2_init)
                oh = T.axis.spatial(56, i2_0 * 28 + i2_2_init * 4 + i2_3_init)
                ow = T.axis.spatial(56, i3_1 * 4 + i3_3_init)
                oc_block_o = T.axis.spatial(1, 0)
                T.reads()
                T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16])
                # Zero the 16 elements of the row with a vectorized loop.
                for i4_1 in T.vectorized(16):
                    with T.block("conv2d_NCHWc_int8_init"):
                        oc_block_init = T.axis.spatial(16, i4_1)
                        T.reads()
                        T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_init])
                        conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_init] = 0
        for (
            i7_0,
            i8_0,
            i9_0_0,
            i0_2,
            i1_2,
            i2_2,
            i3_2,
            i4_0_2,
            i5_1,
            i6_1,
            i7_1,
            i8_1,
            i9_0_1,
            i0_3,
            i1_3,
            i2_3,
            i3_3,
            i4_0_3,
        ) in T.grid(4, 1, 1, 1, 4, 7, 1, 1, 1, 1, 1, 4, 1, 1, 1, 4, 4, 1):
            # Update block: accumulates one 16-element output row from a
            # 4-element slice of `placeholder` and a 16x4 slice of
            # `placeholder_1`.
            with T.block("conv2d_NCHWc_int8_o_update"):
                n = T.axis.spatial(1, 0)
                oc_chunk = T.axis.spatial(16, i1_1 * 4 + i1_2)
                oh = T.axis.spatial(56, i2_0 * 28 + i2_2 * 4 + i2_3)
                ow = T.axis.spatial(56, i3_1 * 4 + i3_3)
                oc_block_o = T.axis.spatial(1, 0)
                kh = T.axis.reduce(1, 0)
                kw = T.axis.reduce(1, 0)
                ic_outer, ic_f_inner = T.axis.remap("RR", [i7_0, i8_1])
                ic_s_inner_o = T.axis.reduce(1, 0)
                T.reads(
                    conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16],
                    placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4:ic_f_inner * 4 + 4],
                    placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:16, 0:4],
                )
                T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16])
                # Bind the regions touched by this block to small local
                # buffer views: A (4 x uint8), B (16x4 int8), C (16 x int32).
                A = T.match_buffer(
                    placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4:ic_f_inner * 4 + 4],
                    [4],
                    dtype="uint8",
                    offset_factor=1,
                )
                B = T.match_buffer(
                    placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:16, 0:4],
                    [16, 4],
                    dtype="int8",
                    offset_factor=1,
                )
                C = T.match_buffer(
                    conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16],
                    [16],
                    dtype="int32",
                    offset_factor=1,
                )
                # View A's four uint8 values as one int32 and B's 64 int8
                # values as 16 int32 lanes for the intrinsic call.
                A_u8x4 = A.vload([0], "uint8x4")
                A_i32 = T.reinterpret(A_u8x4, dtype="int32")
                B_i8x64 = B.vload([0, 0], dtype="int8x64")
                B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")
                # The intrinsic is given a zero broadcast as its first vector
                # operand; accumulation is done by the explicit `C + ...`.
                C[T.ramp(
                    0, 1,
                    16)] = C[T.ramp(0, 1, 16)] + T.call_llvm_pure_intrin(
                        T.llvm_lookup_intrinsic_id(
                            "llvm.x86.avx512.vpdpbusd.512"),
                        T.uint32(0),
                        T.broadcast(0, 16),
                        T.broadcast(A_i32, 16),
                        B_i32x16,
                        dtype="int32x16",
                    )
def expected_func():
    # Allocates 4 elements of dtype int32x4 in "shared" scope and stores the
    # value i broadcast over 4 lanes into element `i * 4 / 4`, for i = 0..3.
    # The index expressions are deliberately left unsimplified.
    B = T.allocate([4], "int32x4", "shared")
    B[0 * 4 / 4] = T.broadcast(0, 4)
    B[1 * 4 / 4] = T.broadcast(1, 4)
    B[2 * 4 / 4] = T.broadcast(2, 4)
    B[3 * 4 / 4] = T.broadcast(3, 4)
def before_func():
    # Launches the "vthread" environment thread with extent 4, allocates a
    # 4-element int32 buffer in "shared" scope, and stores the thread index
    # broadcast over 4 lanes into the slice B[0:4].
    vthread = T.env_thread("vthread")
    T.launch_thread(vthread, 4)
    B = T.allocate([4], "int32", "shared")
    B[0:4] = T.broadcast(vthread, 4)
def access_reversed_slice(A: T.handle):
    # Negative case: a slice with a reversed (negative) step is not allowed as a buffer index.
    A = T.match_buffer((128,), "int32")  # NOTE(review): no handle argument is passed here -- confirm intentional
    A[0:128:-1] = T.broadcast(1, 128)  # error