def _sch_rules() -> List[ScheduleRule]:
    """Return the meta-schedule rule list used for tuning.

    The list contains, in order:
      1. ``MultiLevelTiling`` with the ``"SSSRRSRS"`` structure, binding tile
         loops to ``blockIdx.x`` / ``vthread.x`` / ``threadIdx.x``, caching
         reads in ``"shared"`` scope (level 4) and writes in ``"local"``
         scope (level 3).
      2. ``AutoInline`` that inlines into both producers and consumers.
      3. ``CrossThreadReduction`` over thread extents 4..512.
      4. ``ParallelizeVectorizeUnroll`` with parallelization and
         vectorization disabled, only unrolling enabled.
      5. ``AutoBind`` with at most 256 thread blocks.

    The binding names suggest a CUDA-style GPU target. That is an assumption,
    not something this function checks.
    """
    # Imported lazily so the module can be loaded without pulling in
    # meta_schedule up front.
    from tvm.meta_schedule import schedule_rule as rules

    # Reuse policies for the tiling rule: read cache in shared scope,
    # write cache in local scope.
    shared_read_reuse = rules.ReuseType(req="must", levels=[4], scope="shared")
    local_write_reuse = rules.ReuseType(req="must", levels=[3], scope="local")

    tiling = rules.MultiLevelTiling(
        structure="SSSRRSRS",
        tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"],
        max_innermost_factor=64,
        vector_load_lens=[1, 2, 3, 4],
        reuse_read=shared_read_reuse,
        reuse_write=local_write_reuse,
    )

    inliner = rules.AutoInline(
        into_producer=True,
        into_consumer=True,
        inline_const_tensor=True,
        disallow_if_then_else=False,
        require_injective=False,
        require_ordered=False,
        disallow_op=None,
    )

    cross_thread = rules.CrossThreadReduction(
        thread_extents=[4, 8, 16, 32, 64, 128, 256, 512],
    )

    unroller = rules.ParallelizeVectorizeUnroll(
        max_jobs_per_core=-1,  # -1 turns parallelization off
        max_vectorize_extent=-1,  # -1 turns vectorization off
        unroll_max_steps=[0, 16, 64, 512, 1024],
        unroll_explicit=True,
    )

    binder = rules.AutoBind(
        max_threadblocks=256,
        thread_extents=[32, 64, 128, 256, 512, 1024],
    )

    return [tiling, inliner, cross_thread, unroller, binder]
def get_sch_rules_for_dp4a(intrin):
    """Return meta-schedule rules that tile around a tensor intrinsic.

    Args:
        intrin: Tensor intrinsic handed to ``MultiLevelTilingWithIntrin``.
            The function name suggests a DP4A intrinsic; that is an
            assumption, so confirm it at call sites.

    Returns:
        A list of four rules, in this order:
          1. ``MultiLevelTilingWithIntrin`` (structure ``"SSSRRSRS"``,
             binds ``blockIdx.x`` / ``vthread.x`` / ``threadIdx.x``,
             shared-scope read reuse at level 4, local-scope write reuse
             at level 3).
          2. ``AutoInline``.
          3. ``CrossThreadReduction`` over thread extents 4..512.
          4. ``ParallelizeVectorizeUnroll`` with parallelize/vectorize
             disabled.

        No ``AutoBind`` rule is included.
    """
    read_reuse = schedule_rule.ReuseType(req="must", levels=[4], scope="shared")
    write_reuse = schedule_rule.ReuseType(req="must", levels=[3], scope="local")

    tiling_rule = schedule_rule.MultiLevelTilingWithIntrin(
        intrin,
        structure="SSSRRSRS",
        tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"],
        max_innermost_factor=64,
        vector_load_lens=[1, 2, 3, 4],
        reuse_read=read_reuse,
        reuse_write=write_reuse,
    )

    inline_rule = schedule_rule.AutoInline(
        into_producer=True,
        into_consumer=True,
        inline_const_tensor=True,
        disallow_if_then_else=False,
        require_injective=False,
        require_ordered=False,
        disallow_op=None,
    )

    reduction_rule = schedule_rule.CrossThreadReduction(
        thread_extents=[4, 8, 16, 32, 64, 128, 256, 512],
    )

    unroll_rule = schedule_rule.ParallelizeVectorizeUnroll(
        max_jobs_per_core=-1,  # -1 turns parallelization off
        max_vectorize_extent=-1,  # -1 turns vectorization off
        unroll_max_steps=[0, 16, 64, 512, 1024],
        unroll_explicit=True,
    )

    return [tiling_rule, inline_rule, reduction_rule, unroll_rule]