def test_meta_schedule_post_order_apply_double():
    """Apply ``DoubleScheduleRule`` to ``Matmul`` via ``PostOrderApply``.

    Exactly two design spaces are expected. Each must differ structurally
    from the input module and must pass ``_check_correct``.
    """
    matmul_mod = Matmul
    tune_ctx = TuneContext(
        mod=matmul_mod,
        target=Target("llvm"),
        task_name="Double Rules Task",
        sch_rules=[DoubleScheduleRule()],
    )
    space_generator = PostOrderApply()
    space_generator.initialize_with_tune_context(tune_ctx)
    design_spaces = space_generator.generate_design_space(matmul_mod)
    assert len(design_spaces) == 2
    for candidate in design_spaces:
        # Every candidate must have been transformed by the rule.
        assert not tvm.ir.structural_equal(candidate.mod, matmul_mod)
        _check_correct(candidate)
def test_conv2d_winograd_cuda():
    """Replay hand-picked sampling decisions on the CUDA Winograd conv2d space.

    The design space is generated with the default CUDA schedule rules. Every
    ``Sample*`` instruction in the resulting trace is pinned to a fixed
    decision, and the trace is replayed on a fresh schedule. The replayed
    module must be structurally equal to ``_get_mod()``.
    """
    mod = conv2d_winograd_cuda
    mod = IRModule({"main": mod})
    context = TuneContext(
        mod=mod,
        target=Target("nvidia/geforce-rtx-3090", host="llvm"),
        task_name="Custom Search Space Task",
        sch_rules=DefaultCUDA._sch_rules(),  # pylint: disable=protected-access
    )
    for sch_rule in context.sch_rules:
        sch_rule.initialize_with_tune_context(context)
    post_order_apply = PostOrderApply()
    post_order_apply.initialize_with_tune_context(context)
    (sch,) = post_order_apply.generate_design_space(mod)
    insts = sch.trace.insts
    sample_insts = [i for i in insts if i.kind.name.startswith("Sample")]
    decision_values = [
        # data_pack
        [3, 3],
        [64, 2],
        2,
        # inverse
        [3, 3],
        [2, 64],
        2,
        # bgemm
        [1, 1, 1, 1, 6],
        [1, 1, 1, 3, 2],
        [3, 1, 1, 1, 3],
        [4, 2, 1, 4, 4],
        [32, 1, 4],
        1,
        1,
        # root anno
        2,
        # conv2d
        2,
    ]
    # zip() silently stops at the shorter input. If the trace's sampling
    # instructions and this list drift apart, some samples would be left
    # without a fixed decision or some decisions would be dropped, with no
    # diagnostic. Fail loudly instead.
    assert len(sample_insts) == len(decision_values), (
        f"trace has {len(sample_insts)} sampling instructions, "
        f"but {len(decision_values)} decisions were provided"
    )
    decisions = dict(zip(sample_insts, decision_values))
    trace = Trace(insts, decisions=decisions)
    sch = Schedule(mod=mod)
    trace.apply_to_schedule(sch, remove_postproc=False)
    answer = sch.mod
    expected = _get_mod()
    tvm.ir.assert_structural_equal(answer, expected)
def test_conv2d_winograd_cpu():
    """Replay hand-picked sampling decisions on the CPU Winograd conv2d space.

    The design space is generated with the default LLVM schedule rules. All
    trace instructions except the last four are kept, and every ``Sample*``
    instruction among them is pinned to a fixed decision. The truncated trace
    is replayed on a fresh schedule, and the result must be structurally
    equal to ``_get_mod()``.
    """
    mod = conv2d_winograd_cpu
    mod = IRModule({"main": mod})
    context = TuneContext(
        mod=mod,
        target=Target("llvm"),
        task_name="Custom Search Space Task",
        sch_rules=DefaultLLVM._sch_rules(),  # pylint: disable=protected-access
    )
    post_order_apply = PostOrderApply()
    post_order_apply.initialize_with_tune_context(context)
    (sch,) = post_order_apply.generate_design_space(mod)
    # The trailing four instructions are excluded from the replay.
    # NOTE(review): presumably these are root-block annotations not covered
    # by the expected module; confirm against _get_mod().
    replay_insts = sch.trace.insts[:-4]
    sample_insts = [i for i in replay_insts if i.kind.name.startswith("Sample")]
    decision_values = [
        # data_pack
        [9, 1],
        [32, 4],
        # input_tile
        4,
        # data_pad
        -2,
        # inverse
        [1, 9],
        [2, 64],
        # bgemm
        [1, 2, 3, 1],
        [1, 1, 1, 6],
        [1, 1, 1, 9],
        [2, 1, 16, 4],
        [16, 8],
    ]
    # zip() silently stops at the shorter input. A count mismatch would leave
    # samples unpinned or drop decisions without notice, so check it
    # explicitly.
    assert len(sample_insts) == len(decision_values), (
        f"trace has {len(sample_insts)} sampling instructions, "
        f"but {len(decision_values)} decisions were provided"
    )
    decisions = dict(zip(sample_insts, decision_values))
    trace = Trace(replay_insts, decisions=decisions)
    sch = Schedule(mod=mod)
    trace.apply_to_schedule(sch, remove_postproc=False)
    answer = sch.mod
    expected = _get_mod()
    tvm.ir.assert_structural_equal(answer, expected)
def test_meta_schedule_post_order_apply_remove_block():
    """Check design-space generation when one rule removes a block.

    ``RemoveBlock`` inlines block ``B``. ``TrinityDouble`` then splits and
    reorders the two loops of every remaining non-root block in two different
    ways. Four design spaces are expected. In each one, block ``B`` no longer
    exists, and the simplified trace matches one of four expected split
    combinations.
    """

    @derived_object
    class TrinityDouble(PyScheduleRule):
        def initialize_with_tune_context(self, context: "TuneContext") -> None:
            pass

        def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
            if _is_root(sch, block):
                return [sch]
            # Each pair of split factors yields one independent candidate
            # copied from the incoming schedule.
            factor_pairs = (
                ([16, 64], [64, 16]),
                ([2, 512], [2, 512]),
            )
            candidates = []
            for outer_factors, inner_factors in factor_pairs:
                candidate = sch.copy()
                loop_i, loop_j = candidate.get_loops(block=block)
                i_outer, i_inner = candidate.split(loop=loop_i, factors=outer_factors)
                j_outer, j_inner = candidate.split(loop=loop_j, factors=inner_factors)
                candidate.reorder(i_outer, j_outer, i_inner, j_inner)
                candidates.append(candidate)
            return candidates

    @derived_object
    class RemoveBlock(PyScheduleRule):
        def initialize_with_tune_context(self, context: "TuneContext") -> None:
            pass

        def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
            if _is_root(sch, block):
                return [sch]
            copied = sch.copy()
            # Only block "B" is inlined; every other block passes through.
            if copied.get(block).name_hint == "B":
                copied.compute_inline(block)
            return [copied]

    def correct_trace(a, b, c, d):
        # a/b: split factors for block C's loops; c/d: for block A's loops.
        trace_lines = [
            'b0 = sch.get_block(name="A", func_name="main")',
            'b1 = sch.get_block(name="B", func_name="main")',
            'b2 = sch.get_block(name="C", func_name="main")',
            "sch.compute_inline(block=b1)",
            "l3, l4 = sch.get_loops(block=b2)",
            f"l5, l6 = sch.split(loop=l3, factors={a})",
            f"l7, l8 = sch.split(loop=l4, factors={b})",
            "sch.reorder(l5, l7, l6, l8)",
            "l9, l10 = sch.get_loops(block=b0)",
            f"l11, l12 = sch.split(loop=l9, factors={c})",
            f"l13, l14 = sch.split(loop=l10, factors={d})",
            "sch.reorder(l11, l13, l12, l14)",
        ]
        return "\n".join(trace_lines)

    mod = TrinityMatmul
    tune_ctx = TuneContext(
        mod=mod,
        target=Target("llvm"),
        task_name="Remove Block Task",
        sch_rules=[RemoveBlock(), TrinityDouble()],
    )
    space_generator = PostOrderApply()
    space_generator.initialize_with_tune_context(tune_ctx)
    design_spaces = space_generator.generate_design_space(mod)
    assert len(design_spaces) == 4

    # Every combination of the two split variants for blocks C and A.
    expected_traces = [
        correct_trace([16, 64], [64, 16], [2, 512], [2, 512]),
        correct_trace([2, 512], [2, 512], [2, 512], [2, 512]),
        correct_trace([16, 64], [64, 16], [16, 64], [64, 16]),
        correct_trace([2, 512], [2, 512], [16, 64], [64, 16]),
    ]
    for candidate in design_spaces:
        # Block B was inlined, so looking it up must fail.
        with pytest.raises(
            tvm.tir.schedule.schedule.ScheduleError,
            match="ScheduleError: An error occurred in the schedule primitive 'get-block'.",
        ):
            candidate.get_block("B", "main")
        simplified = candidate.trace.simplified(True)
        assert str(simplified) in expected_traces