def get_sch_rules_for_dp4a(intrin):
    """Build the list of schedule rules used together with ``intrin``.

    Parameters
    ----------
    intrin :
        The tensor intrinsic forwarded as the first argument of
        ``schedule_rule.MultiLevelTilingWithIntrin``.

    Returns
    -------
    list
        Four rules, in this order: multi-level tiling with the intrinsic,
        auto-inline, cross-thread reduction, and parallelize/vectorize/unroll.
    """
    # Reads must be cached in "shared" scope at level 4, writes in "local"
    # scope at level 3.
    read_reuse = schedule_rule.ReuseType(
        req="must",
        levels=[4],
        scope="shared",
    )
    write_reuse = schedule_rule.ReuseType(
        req="must",
        levels=[3],
        scope="local",
    )
    tiling = schedule_rule.MultiLevelTilingWithIntrin(
        intrin,
        structure="SSSRRSRS",
        tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"],
        max_innermost_factor=64,
        vector_load_lens=[1, 2, 3, 4],
        reuse_read=read_reuse,
        reuse_write=write_reuse,
    )

    inline = schedule_rule.AutoInline(
        into_producer=True,
        into_consumer=True,
        inline_const_tensor=True,
        disallow_if_then_else=False,
        require_injective=False,
        require_ordered=False,
        disallow_op=None,
    )

    cross_thread = schedule_rule.CrossThreadReduction(
        thread_extents=[4, 8, 16, 32, 64, 128, 256, 512]
    )

    unroll = schedule_rule.ParallelizeVectorizeUnroll(
        max_jobs_per_core=-1,  # disable parallelize
        max_vectorize_extent=-1,  # disable vectorize
        unroll_max_steps=[0, 16, 64, 512, 1024],
        unroll_explicit=True,
    )

    return [tiling, inline, cross_thread, unroll]
def test_multi_level_tiling_dense_dpa4():
    """Check the design space of an int8 dense workload tiled with DP4A_INTRIN.

    A 128x128x128 dense computation (int8 inputs, products accumulated as
    int32) is lowered to a PrimFunc, a design space is generated for the
    "cuda" target with ``MultiLevelTilingWithIntrin``, and the result is
    compared against a single expected trace via ``check_trace``.
    """
    rows, cols, depth = 128, 128, 128

    lhs = te.placeholder((rows, depth), name="X", dtype="int8")
    rhs = te.placeholder((cols, depth), name="W", dtype="int8")
    red_k = te.reduce_axis((0, depth), name="k")

    # Both operands are widened to int32 before the multiply-accumulate.
    dense = te.compute(
        (rows, cols),
        lambda i, j: te.sum(
            lhs[i, red_k].astype("int32") * rhs[j, red_k].astype("int32"),
            axis=red_k,
        ),
        name="compute",
    )
    prim_func = te.create_prim_func([lhs, rhs, dense])

    cuda_target = tvm.target.Target("cuda")
    read_reuse = schedule_rule.ReuseType(
        req="must",
        levels=[4],
        scope="shared",
    )
    write_reuse = schedule_rule.ReuseType(
        req="must",
        levels=[3],
        scope="local",
    )
    tiling_rule = schedule_rule.MultiLevelTilingWithIntrin(
        DP4A_INTRIN,
        structure="SSSRRSRS",
        tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"],
        max_innermost_factor=64,
        vector_load_lens=[1, 2, 3, 4],
        reuse_read=read_reuse,
        reuse_write=write_reuse,
    )
    ctx = _create_context(prim_func, target=cuda_target, rule=tiling_rule)

    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)

    # One trace instruction per line; ``split("\n")`` turns the literal into
    # the list of instructions that ``check_trace`` receives.
    dp4a_trace = """b0 = sch.get_block(name="compute", func_name="main")
sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")
l1, l2, l3 = sch.get_loops(block=b0)
l4, l5 = sch.split(loop=l3, factors=[32, 4])
sch.reorder(l5)
b6 = sch.blockize(loop=l5)
sch.annotate(block_or_loop=b6, ann_key="meta_schedule.auto_tensorize", ann_val="dp4a")
l7, l8, l9 = sch.get_loops(block=b6)
v10, v11, v12, v13, v14 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64)
l15, l16, l17, l18, l19 = sch.split(loop=l7, factors=[v10, v11, v12, v13, v14])
v20, v21, v22, v23, v24 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64)
l25, l26, l27, l28, l29 = sch.split(loop=l8, factors=[v20, v21, v22, v23, v24])
v30, v31, v32 = sch.sample_perfect_tile(loop=l9, n=3, max_innermost_factor=64)
l33, l34, l35 = sch.split(loop=l9, factors=[v30, v31, v32])
sch.reorder(l15, l25, l16, l26, l17, l27, l33, l34, l18, l28, l35, l19, l29)
l36 = sch.fuse(l15, l25)
sch.bind(loop=l36, thread_axis="blockIdx.x")
l37 = sch.fuse(l16, l26)
sch.bind(loop=l37, thread_axis="vthread.x")
l38 = sch.fuse(l17, l27)
sch.bind(loop=l38, thread_axis="threadIdx.x")
b39 = sch.cache_write(block=b6, write_buffer_index=0, storage_scope="local")
sch.reverse_compute_at(block=b39, loop=l38, preserve_unit_loops=True)
b40 = sch.cache_read(block=b6, read_buffer_index=0, storage_scope="shared")
sch.compute_at(block=b40, loop=l33, preserve_unit_loops=True)
l41, l42, l43, l44, l45, l46 = sch.get_loops(block=b40)
l47 = sch.fuse(l45, l46)
v48 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
sch.annotate(block_or_loop=b40, ann_key="meta_schedule.cooperative_fetch", ann_val=v48)
b49 = sch.cache_read(block=b6, read_buffer_index=1, storage_scope="shared")
sch.compute_at(block=b49, loop=l33, preserve_unit_loops=True)
l50, l51, l52, l53, l54, l55 = sch.get_loops(block=b49)
l56 = sch.fuse(l54, l55)
v57 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
sch.annotate(block_or_loop=b49, ann_key="meta_schedule.cooperative_fetch", ann_val=v57)""".split(
        "\n"
    )
    expected = [dp4a_trace]

    check_trace(spaces, expected)
def test_multi_level_tiling_conv2d_nchwc_vnni():
    """Check the design space of ``Conv2dNCHWcVNNIModule`` tiled with VNNI_INTRIN.

    A design space is generated for an "llvm -mcpu=cascadelake" target with
    ``MultiLevelTilingWithIntrin`` and compared via ``check_trace`` against
    three expected traces. The traces share the same tiling instructions and
    differ only in their tail: a write cache attached at loop ``l75``, a write
    cache attached at loop ``l74``, or no write cache at all.
    """
    target = "llvm -mcpu=cascadelake -num-cores 4"
    vnni_target = tvm.target.Target(target)

    # Write reuse is optional ("may") at levels 1 and 2 in "global" scope.
    write_reuse = schedule_rule.ReuseType(
        req="may",
        levels=[1, 2],
        scope="global",
    )
    tiling_rule = schedule_rule.MultiLevelTilingWithIntrin(
        VNNI_INTRIN,
        structure="SSRSRS",
        tile_binds=None,
        max_innermost_factor=64,
        vector_load_lens=None,
        reuse_read=None,
        reuse_write=write_reuse,
    )
    ctx = _create_context(
        Conv2dNCHWcVNNIModule,
        target=vnni_target,
        rule=tiling_rule,
    )

    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)

    # Each literal holds one trace instruction per line; ``split("\n")``
    # converts it into the instruction list handed to ``check_trace``.
    trace_write_cache_at_l75 = """b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main")
sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")
l1, l2, l3, l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b0)
l11, l12 = sch.split(loop=l10, factors=[1, 4])
l13, l14 = sch.split(loop=l5, factors=[1, 16])
l15, l16, l17, l18, l19, l20, l21, l22, l23, l24, l25, l26 = sch.get_loops(block=b0)
sch.reorder(l21, l22, l23, l24, l25, l14, l12)
b27 = sch.blockize(loop=l14)
sch.annotate(block_or_loop=b27, ann_key="meta_schedule.auto_tensorize", ann_val="dot_16x4_vnni")
l28, l29, l30, l31, l32, l33, l34, l35, l36, l37 = sch.get_loops(block=b27)
v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l28, n=4, max_innermost_factor=64)
l42, l43, l44, l45 = sch.split(loop=l28, factors=[v38, v39, v40, v41])
v46, v47, v48, v49 = sch.sample_perfect_tile(loop=l29, n=4, max_innermost_factor=64)
l50, l51, l52, l53 = sch.split(loop=l29, factors=[v46, v47, v48, v49])
v54, v55, v56, v57 = sch.sample_perfect_tile(loop=l30, n=4, max_innermost_factor=64)
l58, l59, l60, l61 = sch.split(loop=l30, factors=[v54, v55, v56, v57])
v62, v63, v64, v65 = sch.sample_perfect_tile(loop=l31, n=4, max_innermost_factor=64)
l66, l67, l68, l69 = sch.split(loop=l31, factors=[v62, v63, v64, v65])
v70, v71, v72, v73 = sch.sample_perfect_tile(loop=l32, n=4, max_innermost_factor=64)
l74, l75, l76, l77 = sch.split(loop=l32, factors=[v70, v71, v72, v73])
v78, v79 = sch.sample_perfect_tile(loop=l33, n=2, max_innermost_factor=64)
l80, l81 = sch.split(loop=l33, factors=[v78, v79])
v82, v83 = sch.sample_perfect_tile(loop=l34, n=2, max_innermost_factor=64)
l84, l85 = sch.split(loop=l34, factors=[v82, v83])
v86, v87 = sch.sample_perfect_tile(loop=l35, n=2, max_innermost_factor=64)
l88, l89 = sch.split(loop=l35, factors=[v86, v87])
v90, v91 = sch.sample_perfect_tile(loop=l36, n=2, max_innermost_factor=64)
l92, l93 = sch.split(loop=l36, factors=[v90, v91])
v94, v95 = sch.sample_perfect_tile(loop=l37, n=2, max_innermost_factor=64)
l96, l97 = sch.split(loop=l37, factors=[v94, v95])
sch.reorder(l42, l50, l58, l66, l74, l43, l51, l59, l67, l75, l80, l84, l88, l92, l96, l44, l52, l60, l68, l76, l81, l85, l89, l93, l97, l45, l53, l61, l69, l77)
b98 = sch.cache_write(block=b27, write_buffer_index=0, storage_scope="global")
sch.reverse_compute_at(block=b98, loop=l75, preserve_unit_loops=True)""".split(
        "\n"
    )

    trace_write_cache_at_l74 = """b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main")
sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")
l1, l2, l3, l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b0)
l11, l12 = sch.split(loop=l10, factors=[1, 4])
l13, l14 = sch.split(loop=l5, factors=[1, 16])
l15, l16, l17, l18, l19, l20, l21, l22, l23, l24, l25, l26 = sch.get_loops(block=b0)
sch.reorder(l21, l22, l23, l24, l25, l14, l12)
b27 = sch.blockize(loop=l14)
sch.annotate(block_or_loop=b27, ann_key="meta_schedule.auto_tensorize", ann_val="dot_16x4_vnni")
l28, l29, l30, l31, l32, l33, l34, l35, l36, l37 = sch.get_loops(block=b27)
v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l28, n=4, max_innermost_factor=64)
l42, l43, l44, l45 = sch.split(loop=l28, factors=[v38, v39, v40, v41])
v46, v47, v48, v49 = sch.sample_perfect_tile(loop=l29, n=4, max_innermost_factor=64)
l50, l51, l52, l53 = sch.split(loop=l29, factors=[v46, v47, v48, v49])
v54, v55, v56, v57 = sch.sample_perfect_tile(loop=l30, n=4, max_innermost_factor=64)
l58, l59, l60, l61 = sch.split(loop=l30, factors=[v54, v55, v56, v57])
v62, v63, v64, v65 = sch.sample_perfect_tile(loop=l31, n=4, max_innermost_factor=64)
l66, l67, l68, l69 = sch.split(loop=l31, factors=[v62, v63, v64, v65])
v70, v71, v72, v73 = sch.sample_perfect_tile(loop=l32, n=4, max_innermost_factor=64)
l74, l75, l76, l77 = sch.split(loop=l32, factors=[v70, v71, v72, v73])
v78, v79 = sch.sample_perfect_tile(loop=l33, n=2, max_innermost_factor=64)
l80, l81 = sch.split(loop=l33, factors=[v78, v79])
v82, v83 = sch.sample_perfect_tile(loop=l34, n=2, max_innermost_factor=64)
l84, l85 = sch.split(loop=l34, factors=[v82, v83])
v86, v87 = sch.sample_perfect_tile(loop=l35, n=2, max_innermost_factor=64)
l88, l89 = sch.split(loop=l35, factors=[v86, v87])
v90, v91 = sch.sample_perfect_tile(loop=l36, n=2, max_innermost_factor=64)
l92, l93 = sch.split(loop=l36, factors=[v90, v91])
v94, v95 = sch.sample_perfect_tile(loop=l37, n=2, max_innermost_factor=64)
l96, l97 = sch.split(loop=l37, factors=[v94, v95])
sch.reorder(l42, l50, l58, l66, l74, l43, l51, l59, l67, l75, l80, l84, l88, l92, l96, l44, l52, l60, l68, l76, l81, l85, l89, l93, l97, l45, l53, l61, l69, l77)
b98 = sch.cache_write(block=b27, write_buffer_index=0, storage_scope="global")
sch.reverse_compute_at(block=b98, loop=l74, preserve_unit_loops=True)""".split(
        "\n"
    )

    trace_without_write_cache = """b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main")
sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")
l1, l2, l3, l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b0)
l11, l12 = sch.split(loop=l10, factors=[1, 4])
l13, l14 = sch.split(loop=l5, factors=[1, 16])
l15, l16, l17, l18, l19, l20, l21, l22, l23, l24, l25, l26 = sch.get_loops(block=b0)
sch.reorder(l21, l22, l23, l24, l25, l14, l12)
b27 = sch.blockize(loop=l14)
sch.annotate(block_or_loop=b27, ann_key="meta_schedule.auto_tensorize", ann_val="dot_16x4_vnni")
l28, l29, l30, l31, l32, l33, l34, l35, l36, l37 = sch.get_loops(block=b27)
v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l28, n=4, max_innermost_factor=64)
l42, l43, l44, l45 = sch.split(loop=l28, factors=[v38, v39, v40, v41])
v46, v47, v48, v49 = sch.sample_perfect_tile(loop=l29, n=4, max_innermost_factor=64)
l50, l51, l52, l53 = sch.split(loop=l29, factors=[v46, v47, v48, v49])
v54, v55, v56, v57 = sch.sample_perfect_tile(loop=l30, n=4, max_innermost_factor=64)
l58, l59, l60, l61 = sch.split(loop=l30, factors=[v54, v55, v56, v57])
v62, v63, v64, v65 = sch.sample_perfect_tile(loop=l31, n=4, max_innermost_factor=64)
l66, l67, l68, l69 = sch.split(loop=l31, factors=[v62, v63, v64, v65])
v70, v71, v72, v73 = sch.sample_perfect_tile(loop=l32, n=4, max_innermost_factor=64)
l74, l75, l76, l77 = sch.split(loop=l32, factors=[v70, v71, v72, v73])
v78, v79 = sch.sample_perfect_tile(loop=l33, n=2, max_innermost_factor=64)
l80, l81 = sch.split(loop=l33, factors=[v78, v79])
v82, v83 = sch.sample_perfect_tile(loop=l34, n=2, max_innermost_factor=64)
l84, l85 = sch.split(loop=l34, factors=[v82, v83])
v86, v87 = sch.sample_perfect_tile(loop=l35, n=2, max_innermost_factor=64)
l88, l89 = sch.split(loop=l35, factors=[v86, v87])
v90, v91 = sch.sample_perfect_tile(loop=l36, n=2, max_innermost_factor=64)
l92, l93 = sch.split(loop=l36, factors=[v90, v91])
v94, v95 = sch.sample_perfect_tile(loop=l37, n=2, max_innermost_factor=64)
l96, l97 = sch.split(loop=l37, factors=[v94, v95])
sch.reorder(l42, l50, l58, l66, l74, l43, l51, l59, l67, l75, l80, l84, l88, l92, l96, l44, l52, l60, l68, l76, l81, l85, l89, l93, l97, l45, l53, l61, l69, l77)""".split(
        "\n"
    )

    expected = [
        trace_write_cache_at_l75,
        trace_write_cache_at_l74,
        trace_without_write_cache,
    ]

    check_trace(spaces, expected)
into_producer=False, into_consumer=True, inline_const_tensor=True, disallow_if_then_else=True, require_injective=True, require_ordered=True, disallow_op=["tir.exp"], ), schedule_rule.AddRFactor(max_jobs_per_core=16, max_innermost_factor=64), schedule_rule.MultiLevelTilingWithIntrin( VNNI_INTRIN, structure="SSRSRS", tile_binds=None, max_innermost_factor=64, vector_load_lens=None, reuse_read=None, reuse_write=schedule_rule.ReuseType( req="may", levels=[1, 2], scope="global", ), ), schedule_rule.ParallelizeVectorizeUnroll( max_jobs_per_core=16, max_vectorize_extent=64, unroll_max_steps=[0, 16, 64, 512], unroll_explicit=True, ), schedule_rule.RandomComputeLocation(), ]