def _sch_rules() -> List[ScheduleRule]:
    """Return the fixed list of meta-schedule rules used by this module.

    The rules are returned in this order: AutoInline, AddRFactor,
    MultiLevelTiling (no thread bindings, ``tile_binds=None``),
    ParallelizeVectorizeUnroll, RandomComputeLocation.

    Returns:
        A new list of freshly constructed ``ScheduleRule`` objects.
    """
    # Imported lazily so TVM is only required when the rules are built.
    from tvm.meta_schedule import schedule_rule

    # Inline into consumers (not producers); restricted to injective, ordered
    # blocks without if-then-else, and never inlining ops named "tir.exp".
    auto_inline = schedule_rule.AutoInline(
        into_producer=False,
        into_consumer=True,
        inline_const_tensor=True,
        disallow_if_then_else=True,
        require_injective=True,
        require_ordered=True,
        disallow_op=["tir.exp"],
    )
    add_rfactor = schedule_rule.AddRFactor(
        max_jobs_per_core=16,
        max_innermost_factor=64,
    )

    # Write-side reuse spec handed to the tiling rule (req="may", levels 1 and 2,
    # "global" scope); no read-side reuse is configured.
    write_reuse = schedule_rule.ReuseType(
        req="may",
        levels=[1, 2],
        scope="global",
    )
    multi_level_tiling = schedule_rule.MultiLevelTiling(
        structure="SSRSRS",
        tile_binds=None,
        max_innermost_factor=64,
        vector_load_lens=None,
        reuse_read=None,
        reuse_write=write_reuse,
    )

    parallelize_vectorize_unroll = schedule_rule.ParallelizeVectorizeUnroll(
        max_jobs_per_core=16,
        max_vectorize_extent=64,
        unroll_max_steps=[0, 16, 64, 512],
        unroll_explicit=True,
    )
    random_compute_location = schedule_rule.RandomComputeLocation()

    return [
        auto_inline,
        add_rfactor,
        multi_level_tiling,
        parallelize_vectorize_unroll,
        random_compute_location,
    ]
max_innermost_factor=64, vector_load_lens=None, reuse_read=None, reuse_write=schedule_rule.ReuseType( req="may", levels=[1, 2], scope="global", ), ), schedule_rule.ParallelizeVectorizeUnroll( max_jobs_per_core=16, max_vectorize_extent=64, unroll_max_steps=[0, 16, 64, 512], unroll_explicit=True, ), schedule_rule.RandomComputeLocation(), ] def get_sch_rules_for_dp4a(intrin): return [ schedule_rule.MultiLevelTilingWithIntrin( intrin, structure="SSSRRSRS", tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"], max_innermost_factor=64, vector_load_lens=[1, 2, 3, 4], reuse_read=schedule_rule.ReuseType( req="must", levels=[4], scope="shared",