def test_get_tensorize_loop_mapping_dense_vnni():
    """Check the loop mapping of the 16x4 u8/i8->i32 dot-product intrinsic on ``DenseVNNIModule``.

    The two loops of ``dot_product_16x4_u8i8i32_desc`` are expected to map onto
    the second and third loops of the ``compute`` block, in that order.
    """
    sch = Schedule(DenseVNNIModule)
    compute_block = sch.get_block("compute")

    tensorize_info = get_tensorize_loop_mapping(sch, compute_block, dot_product_16x4_u8i8i32_desc)
    assert isinstance(tensorize_info, TensorizeInfo)

    # Invert ``loop_map`` so it can be queried by the description's loops.
    sref_by_desc_loop = {
        desc_loop: sref for sref, desc_loop in tensorize_info.loop_map.items()
    }
    intrin_loops = collect_loops(dot_product_16x4_u8i8i32_desc)

    _, loop_j, loop_k = sch.get_loops(compute_block)

    assert intrin_loops[0] in sref_by_desc_loop
    assert intrin_loops[1] in sref_by_desc_loop
    assert sch.get(sref_by_desc_loop[intrin_loops[0]]) == sch.get(loop_j)
    assert sch.get(sref_by_desc_loop[intrin_loops[1]]) == sch.get(loop_k)
def test_get_tensorize_loop_mapping_conv2d_nchwc_vnni():
    """Check the loop mapping of the 16x4 u8/i8->i32 dot-product intrinsic on an NCHWc conv2d.

    The first description loop is expected to map onto block loop ``i4`` and the
    second onto block loop ``i9`` of the ``conv2d_NCHWc_int8`` block.
    """
    s = Schedule(Conv2dNCHWcVNNIModule)
    block = s.get_block("conv2d_NCHWc_int8")

    info = get_tensorize_loop_mapping(s, block, dot_product_16x4_u8i8i32_desc)
    # Validate the result before touching ``info.loop_map``. If no mapping was
    # produced (e.g. ``None`` came back), this reports a clear assertion failure
    # instead of an AttributeError. This matches the dense VNNI test.
    assert isinstance(info, TensorizeInfo)

    # Invert ``loop_map`` so it can be queried by the description's loops.
    desc_loop_to_sref = {v: k for k, v in info.loop_map.items()}
    desc_loops = collect_loops(dot_product_16x4_u8i8i32_desc)

    # i4 corresponds to the inner output channel axis of the NCHWc output tensor.
    # The block's loop nest is:
    # for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4):
    _, _, _, _, i4, _, _, _, _, i9 = s.get_loops(block)

    assert desc_loops[0] in desc_loop_to_sref
    assert desc_loops[1] in desc_loop_to_sref
    assert s.get(desc_loop_to_sref[desc_loops[0]]) == s.get(i4)
    assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(i9)
def test_get_tensorize_loop_mapping_matmul_mma():
    """Check the loop mapping of a 16x16x16 float16 matmul description on a matmul_relu block.

    The mapping is computed twice: once on the block's original loop order and
    once after ``reorder``. Each description loop must map onto the same block
    loop in both cases.
    """

    @T.prim_func
    def matmul_16x16x16xf16f16f16_desc(
        A: T.Buffer((16, 16), "float16", align=128, offset_factor=1),
        B: T.Buffer((16, 16), "float16", align=128, offset_factor=1),
        C: T.Buffer((16, 16), "float16", align=128, offset_factor=1),
    ) -> None:
        with T.block("root"):
            T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16])
            T.writes(C[0:16, 0:16])
            for i, j, k in T.grid(16, 16, 16):
                with T.block("update"):
                    vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
                    C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk]

    workload = create_prim_func(te_workload.matmul_relu(n=512, m=512, k=512))

    sch = Schedule(workload)
    c_block = sch.get_block("C")
    loop_a, loop_b, loop_c = sch.get_loops(c_block)
    desc_loops = collect_loops(matmul_16x16x16xf16f16f16_desc)

    # Desc loop at position idx must map onto expected_loops[idx], regardless of order.
    expected_loops = (loop_a, loop_b, loop_c)

    for do_reorder in (False, True):
        # Mapping should be invariant to the loop permutation
        if do_reorder:
            sch.reorder(loop_c, loop_a, loop_b)

        info = get_tensorize_loop_mapping(sch, c_block, matmul_16x16x16xf16f16f16_desc)
        assert info is not None

        # Invert ``loop_map`` so it can be queried by the description's loops.
        sref_by_desc_loop = {desc_loop: sref for sref, desc_loop in info.loop_map.items()}

        # First confirm every description loop is mapped at all...
        for idx, _ in enumerate(expected_loops):
            assert desc_loops[idx] in sref_by_desc_loop
        # ...then confirm each lands on the expected block loop.
        for idx, block_loop in enumerate(expected_loops):
            assert sch.get(sref_by_desc_loop[desc_loops[idx]]) == sch.get(block_loop)