def get_sch_rules_for_dp4a(intrin):
    """Return the schedule rules used to auto-tensorize a CUDA workload with ``intrin``.

    The returned list holds, in order:

    1. ``MultiLevelTilingWithIntrin`` with the "SSSRRSRS" structure. Its fused
       outer spatial tiles are bound to ``blockIdx.x`` / ``vthread.x`` /
       ``threadIdx.x``. It requires a "shared"-scope read cache at level 4 and a
       "local"-scope write cache at level 3.
    2. ``AutoInline`` into both producers and consumers, with no op restrictions.
    3. ``CrossThreadReduction`` over thread extents 4 .. 512.
    4. ``ParallelizeVectorizeUnroll`` with parallelization and vectorization
       disabled (both set to -1), sampling only the unroll step.

    Parameters
    ----------
    intrin
        Tensor intrinsic forwarded to ``MultiLevelTilingWithIntrin``. The tests in
        this file pass names such as ``DP4A_INTRIN``.

    Returns
    -------
    list
        The four schedule rules, in the order listed above.
    """
    # Tiling around the intrinsic: reads must be cached in "shared" scope and the
    # output must be written through a "local" cache.
    tiling = schedule_rule.MultiLevelTilingWithIntrin(
        intrin,
        structure="SSSRRSRS",
        tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"],
        max_innermost_factor=64,
        vector_load_lens=[1, 2, 3, 4],
        reuse_read=schedule_rule.ReuseType(req="must", levels=[4], scope="shared"),
        reuse_write=schedule_rule.ReuseType(req="must", levels=[3], scope="local"),
    )
    inliner = schedule_rule.AutoInline(
        into_producer=True,
        into_consumer=True,
        inline_const_tensor=True,
        disallow_if_then_else=False,
        require_injective=False,
        require_ordered=False,
        disallow_op=None,
    )
    cross_thread = schedule_rule.CrossThreadReduction(
        thread_extents=[4, 8, 16, 32, 64, 128, 256, 512],
    )
    unroller = schedule_rule.ParallelizeVectorizeUnroll(
        max_jobs_per_core=-1,  # -1 turns parallelization off
        max_vectorize_extent=-1,  # -1 turns vectorization off
        unroll_max_steps=[0, 16, 64, 512, 1024],
        unroll_explicit=True,
    )
    return [tiling, inliner, cross_thread, unroller]
def test_multi_level_tiling_dense_dpa4():
    """Check the design space that MultiLevelTilingWithIntrin yields for an int8 dense on CUDA.

    The workload is a 128x128x128 matmul. Its int8 operands are cast to int32
    before the product is summed over the shared ``k`` axis. Exactly one trace is
    expected. It splits the reduction loop by 4, blockizes the inner part and
    marks it for auto-tensorization with "dp4a". It then applies the SSSRRSRS
    tiling, with "shared" read caches for both inputs and a "local" write cache.

    NOTE: the "dpa4" spelling in the test name is kept on purpose so existing test
    selectors keep matching.
    """
    m = n = k = 128

    x_data = te.placeholder((m, k), name="X", dtype="int8")
    w_data = te.placeholder((n, k), name="W", dtype="int8")
    red_k = te.reduce_axis((0, k), name="k")

    # te.compute reads this function's argument names, so keep them as ``i, j``.
    def dot_row_col(i, j):
        widened = x_data[i, red_k].astype("int32") * w_data[j, red_k].astype("int32")
        return te.sum(widened, axis=red_k)

    matmul = te.compute((m, n), dot_row_col, name="compute")
    func = te.create_prim_func([x_data, w_data, matmul])

    cuda_target = tvm.target.Target("cuda")
    tiling_rule = schedule_rule.MultiLevelTilingWithIntrin(
        DP4A_INTRIN,
        structure="SSSRRSRS",
        tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"],
        max_innermost_factor=64,
        vector_load_lens=[1, 2, 3, 4],
        reuse_read=schedule_rule.ReuseType(req="must", levels=[4], scope="shared"),
        reuse_write=schedule_rule.ReuseType(req="must", levels=[3], scope="local"),
    )
    ctx = _create_context(func, target=cuda_target, rule=tiling_rule)

    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)

    only_trace = """b0 = sch.get_block(name="compute", func_name="main")
sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")
l1, l2, l3 = sch.get_loops(block=b0)
l4, l5 = sch.split(loop=l3, factors=[32, 4])
sch.reorder(l5)
b6 = sch.blockize(loop=l5)
sch.annotate(block_or_loop=b6, ann_key="meta_schedule.auto_tensorize", ann_val="dp4a")
l7, l8, l9 = sch.get_loops(block=b6)
v10, v11, v12, v13, v14 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64)
l15, l16, l17, l18, l19 = sch.split(loop=l7, factors=[v10, v11, v12, v13, v14])
v20, v21, v22, v23, v24 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64)
l25, l26, l27, l28, l29 = sch.split(loop=l8, factors=[v20, v21, v22, v23, v24])
v30, v31, v32 = sch.sample_perfect_tile(loop=l9, n=3, max_innermost_factor=64)
l33, l34, l35 = sch.split(loop=l9, factors=[v30, v31, v32])
sch.reorder(l15, l25, l16, l26, l17, l27, l33, l34, l18, l28, l35, l19, l29)
l36 = sch.fuse(l15, l25)
sch.bind(loop=l36, thread_axis="blockIdx.x")
l37 = sch.fuse(l16, l26)
sch.bind(loop=l37, thread_axis="vthread.x")
l38 = sch.fuse(l17, l27)
sch.bind(loop=l38, thread_axis="threadIdx.x")
b39 = sch.cache_write(block=b6, write_buffer_index=0, storage_scope="local")
sch.reverse_compute_at(block=b39, loop=l38, preserve_unit_loops=True)
b40 = sch.cache_read(block=b6, read_buffer_index=0, storage_scope="shared")
sch.compute_at(block=b40, loop=l33, preserve_unit_loops=True)
l41, l42, l43, l44, l45, l46 = sch.get_loops(block=b40)
l47 = sch.fuse(l45, l46)
v48 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
sch.annotate(block_or_loop=b40, ann_key="meta_schedule.cooperative_fetch", ann_val=v48)
b49 = sch.cache_read(block=b6, read_buffer_index=1, storage_scope="shared")
sch.compute_at(block=b49, loop=l33, preserve_unit_loops=True)
l50, l51, l52, l53, l54, l55 = sch.get_loops(block=b49)
l56 = sch.fuse(l54, l55)
v57 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
sch.annotate(block_or_loop=b49, ann_key="meta_schedule.cooperative_fetch", ann_val=v57)""".split("\n")

    check_trace(spaces, [only_trace])
def test_multi_level_tiling_conv2d_nchwc_vnni():
    """Check the design spaces that MultiLevelTilingWithIntrin yields for an NCHWc int8 conv2d with VNNI.

    The target is a 4-core cascadelake LLVM CPU. All three expected traces share
    the same prefix. That prefix splits two loops (by 4 and by 16), reorders and
    blockizes at the inner 16-wide loop and tags the block with "dot_16x4_vnni".
    It then applies the SSRSRS tiling. The traces differ only in how the output is
    written back:

    * a "global" write cache attached at l75, the last loop of the second spatial
      tile group in the reorder;
    * a "global" write cache attached at l74, the last loop of the first spatial
      tile group;
    * no write cache at all, since ``reuse_write`` is only ``req="may"``.
    """
    target = "llvm -mcpu=cascadelake -num-cores 4"
    ctx = _create_context(
        Conv2dNCHWcVNNIModule,
        target=tvm.target.Target(target),
        rule=schedule_rule.MultiLevelTilingWithIntrin(
            VNNI_INTRIN,
            structure="SSRSRS",
            tile_binds=None,
            max_innermost_factor=64,
            vector_load_lens=None,
            reuse_read=None,
            reuse_write=schedule_rule.ReuseType(
                req="may",
                levels=[1, 2],
                scope="global",
            ),
        ),
    )

    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)

    # Instructions common to every expected design space. Previously this block
    # was copy-pasted three times; keeping one copy avoids the copies drifting apart.
    common_prefix = """b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main")
sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")
l1, l2, l3, l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b0)
l11, l12 = sch.split(loop=l10, factors=[1, 4])
l13, l14 = sch.split(loop=l5, factors=[1, 16])
l15, l16, l17, l18, l19, l20, l21, l22, l23, l24, l25, l26 = sch.get_loops(block=b0)
sch.reorder(l21, l22, l23, l24, l25, l14, l12)
b27 = sch.blockize(loop=l14)
sch.annotate(block_or_loop=b27, ann_key="meta_schedule.auto_tensorize", ann_val="dot_16x4_vnni")
l28, l29, l30, l31, l32, l33, l34, l35, l36, l37 = sch.get_loops(block=b27)
v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l28, n=4, max_innermost_factor=64)
l42, l43, l44, l45 = sch.split(loop=l28, factors=[v38, v39, v40, v41])
v46, v47, v48, v49 = sch.sample_perfect_tile(loop=l29, n=4, max_innermost_factor=64)
l50, l51, l52, l53 = sch.split(loop=l29, factors=[v46, v47, v48, v49])
v54, v55, v56, v57 = sch.sample_perfect_tile(loop=l30, n=4, max_innermost_factor=64)
l58, l59, l60, l61 = sch.split(loop=l30, factors=[v54, v55, v56, v57])
v62, v63, v64, v65 = sch.sample_perfect_tile(loop=l31, n=4, max_innermost_factor=64)
l66, l67, l68, l69 = sch.split(loop=l31, factors=[v62, v63, v64, v65])
v70, v71, v72, v73 = sch.sample_perfect_tile(loop=l32, n=4, max_innermost_factor=64)
l74, l75, l76, l77 = sch.split(loop=l32, factors=[v70, v71, v72, v73])
v78, v79 = sch.sample_perfect_tile(loop=l33, n=2, max_innermost_factor=64)
l80, l81 = sch.split(loop=l33, factors=[v78, v79])
v82, v83 = sch.sample_perfect_tile(loop=l34, n=2, max_innermost_factor=64)
l84, l85 = sch.split(loop=l34, factors=[v82, v83])
v86, v87 = sch.sample_perfect_tile(loop=l35, n=2, max_innermost_factor=64)
l88, l89 = sch.split(loop=l35, factors=[v86, v87])
v90, v91 = sch.sample_perfect_tile(loop=l36, n=2, max_innermost_factor=64)
l92, l93 = sch.split(loop=l36, factors=[v90, v91])
v94, v95 = sch.sample_perfect_tile(loop=l37, n=2, max_innermost_factor=64)
l96, l97 = sch.split(loop=l37, factors=[v94, v95])
sch.reorder(l42, l50, l58, l66, l74, l43, l51, l59, l67, l75, l80, l84, l88, l92, l96, l44, l52, l60, l68, l76, l81, l85, l89, l93, l97, l45, l53, l61, l69, l77)""".split(
        "\n"
    )

    def _with_global_write_cache_at(loop):
        """Return the common prefix followed by a "global" write cache attached at ``loop``."""
        return common_prefix + [
            'b98 = sch.cache_write(block=b27, write_buffer_index=0, storage_scope="global")',
            "sch.reverse_compute_at(block=b98, loop={}, preserve_unit_loops=True)".format(loop),
        ]

    expected = [
        _with_global_write_cache_at("l75"),
        _with_global_write_cache_at("l74"),
        # reuse_write is only "may", so the space without a write cache is expected too.
        list(common_prefix),
    ]

    check_trace(spaces, expected)
        into_producer=False,
        into_consumer=True,
        inline_const_tensor=True,
        disallow_if_then_else=True,
        require_injective=True,
        require_ordered=True,
        disallow_op=["tir.exp"],
    ),
    schedule_rule.AddRFactor(max_jobs_per_core=16, max_innermost_factor=64),
    schedule_rule.MultiLevelTilingWithIntrin(
        VNNI_INTRIN,
        structure="SSRSRS",
        tile_binds=None,
        max_innermost_factor=64,
        vector_load_lens=None,
        reuse_read=None,
        reuse_write=schedule_rule.ReuseType(
            req="may",
            levels=[1, 2],
            scope="global",
        ),
    ),
    schedule_rule.ParallelizeVectorizeUnroll(
        max_jobs_per_core=16,
        max_vectorize_extent=64,
        unroll_max_steps=[0, 16, 64, 512],
        unroll_explicit=True,
    ),
    schedule_rule.RandomComputeLocation(),
]