Example #1
0
def test_get_tensorize_loop_mapping_dense_vnni():
    """Map the 16x4 u8/i8->i32 dot-product description onto the dense VNNI block.

    The first two loops of ``dot_product_16x4_u8i8i32_desc`` are expected to
    map onto the second and third loops (``loop_j``, ``loop_k``) of the
    ``compute`` block.
    """
    s = Schedule(DenseVNNIModule)
    block = s.get_block("compute")

    info = get_tensorize_loop_mapping(s, block, dot_product_16x4_u8i8i32_desc)

    assert isinstance(info, TensorizeInfo)

    # ``info.loop_map`` goes from schedule loop srefs to description loops;
    # invert it so each description loop can be looked up directly.
    desc_loop_to_sref = {v: k for k, v in info.loop_map.items()}

    desc_loops = collect_loops(dot_product_16x4_u8i8i32_desc)
    _, loop_j, loop_k = s.get_loops(block)

    # Separate asserts so a failure pinpoints which description loop is unmapped.
    assert desc_loops[0] in desc_loop_to_sref
    assert desc_loops[1] in desc_loop_to_sref
    assert s.get(desc_loop_to_sref[desc_loops[0]]) == s.get(loop_j)
    assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(loop_k)
def test_get_tensorize_loop_mapping_conv2d_nchwc_vnni():
    """Map the 16x4 u8/i8->i32 dot-product description onto the NCHWc conv2d block.

    The first two description loops are expected to map onto ``i4`` and ``i9``
    of the ten-deep loop nest of the ``conv2d_NCHWc_int8`` block.
    """
    s = Schedule(Conv2dNCHWcVNNIModule)
    block = s.get_block("conv2d_NCHWc_int8")

    info = get_tensorize_loop_mapping(s, block, dot_product_16x4_u8i8i32_desc)

    # Validate the result before dereferencing ``info.loop_map``: if no mapping
    # is found (e.g. ``None`` is returned), this reports a clear assertion
    # failure instead of an AttributeError.
    assert isinstance(info, TensorizeInfo)

    # Invert sref -> description-loop map so description loops can be looked up.
    desc_loop_to_sref = {v: k for k, v in info.loop_map.items()}

    desc_loops = collect_loops(dot_product_16x4_u8i8i32_desc)

    # i4 corresponds to the inner output channel axis of the NCHWc output tensor
    # for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4):
    _, _, _, _, i4, _, _, _, _, i9 = s.get_loops(block)

    # Separate asserts so a failure pinpoints which description loop is unmapped.
    assert desc_loops[0] in desc_loop_to_sref
    assert desc_loops[1] in desc_loop_to_sref
    assert s.get(desc_loop_to_sref[desc_loops[0]]) == s.get(i4)
    assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(i9)
def test_get_tensorize_loop_mapping_matmul_mma():
    """Map a 16x16x16 fp16 matmul description onto the ``C`` block of a 512^3 matmul.

    The mapping is computed twice, on the original loop order and after
    reordering the block's loops. Both times each description loop must map
    onto the corresponding original block loop (i0, i1, i2).
    """
    @T.prim_func
    def matmul_16x16x16xf16f16f16_desc(
        A: T.Buffer((16, 16), "float16", align=128, offset_factor=1),
        B: T.Buffer((16, 16), "float16", align=128, offset_factor=1),
        C: T.Buffer((16, 16), "float16", align=128, offset_factor=1),
    ) -> None:
        with T.block("root"):
            T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16])
            T.writes(C[0:16, 0:16])
            for i, j, k in T.grid(16, 16, 16):
                with T.block("update"):
                    vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
                    C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk]

    matmul = create_prim_func(te_workload.matmul_relu(
        n=512,
        m=512,
        k=512,
    ))

    s = Schedule(matmul)
    block = s.get_block("C")
    i0, i1, i2 = s.get_loops(block)
    block_loops = (i0, i1, i2)
    desc_loops = collect_loops(matmul_16x16x16xf16f16f16_desc)
    # One description loop per block loop; guards the zip below against
    # silently truncating if the description ever changes.
    assert len(desc_loops) == len(block_loops)

    for do_reorder in [False, True]:
        # Mapping should be invariant to the loop permutation
        if do_reorder:
            s.reorder(i2, i0, i1)

        info = get_tensorize_loop_mapping(s, block,
                                          matmul_16x16x16xf16f16f16_desc)
        assert isinstance(info, TensorizeInfo)
        # Invert sref -> description-loop map so description loops can be looked up.
        desc_loop_to_sref = {v: k for k, v in info.loop_map.items()}

        for desc_loop, block_loop in zip(desc_loops, block_loops):
            assert desc_loop in desc_loop_to_sref
            assert s.get(desc_loop_to_sref[desc_loop]) == s.get(block_loop)