def test_meta_schedule_task_scheduler_single():
    """Tune one Matmul task with RoundRobin and check every trial lands in the database."""
    trials_per_round = 3
    trial_budget = 10
    db = ms.database.MemoryDatabase()

    # Components of the single task, built in the same order as the constructor arguments.
    llvm_target = tvm.target.Target("llvm")
    space = ms.space_generator.ScheduleFn(sch_fn=_schedule_matmul)
    strategy = ms.search_strategy.ReplayTrace(trials_per_round, trial_budget)
    only_task = ms.TuneContext(
        MatmulModule,
        target=llvm_target,
        space_generator=space,
        search_strategy=strategy,
        task_name="Test",
        rand_state=42,
    )

    scheduler = ms.task_scheduler.RoundRobin(
        [only_task],
        [1.0],
        builder=DummyBuilder(),
        runner=DummyRunner(),
        database=db,
        measure_callbacks=[ms.measure_callback.AddToDatabase()],
        max_trials=trial_budget,
    )
    scheduler.tune()

    # AddToDatabase records each measured trial, so the database size equals the trial cap.
    assert len(db) == trial_budget
def test_meta_schedule_measure_callback_fail():
    """A ValueError raised in a Python measure callback's apply reaches the caller."""

    @ms.derived_object
    class FailingMeasureCallback(ms.measure_callback.PyMeasureCallback):
        def apply(
            self,
            task_scheduler: ms.task_scheduler.TaskScheduler,
            task_id: int,
            measure_candidates: List[ms.MeasureCandidate],
            builder_results: List[ms.builder.BuilderResult],
            runner_results: List[ms.runner.RunnerResult],
        ) -> None:
            raise ValueError("test")

    callback = FailingMeasureCallback()
    # Argument construction stays inside the raises-block, exactly as in the inline form.
    with pytest.raises(ValueError, match="test"):
        placeholder_scheduler = ms.task_scheduler.RoundRobin(
            tasks=[],
            task_weights=[],
            builder=DummyBuilder(),
            runner=DummyRunner(),
            database=ms.database.MemoryDatabase(),
            max_trials=1,
        )
        candidates = [ms.MeasureCandidate(Schedule(Matmul), None)]
        build_outcomes = [ms.builder.BuilderResult("test_build", None)]
        run_outcomes = [ms.runner.RunnerResult([1.0, 2.1], None)]
        callback.apply(
            placeholder_scheduler,
            0,
            candidates,
            build_outcomes,
            run_outcomes,
        )
def test_meta_schedule_task_scheduler_multiple():
    """Tune three tasks round-robin and check each receives exactly its trial budget."""
    trials_per_round = 6
    per_task_budget = 101

    # (module, schedule function, task name, random seed) for each task.
    task_specs = [
        (MatmulModule, _schedule_matmul, "Matmul", 42),
        (MatmulReluModule, _schedule_matmul, "MatmulRelu", 0xDEADBEEF),
        (BatchMatmulModule, _schedule_batch_matmul, "BatchMatmul", 0x114514),
    ]
    tasks = [
        ms.TuneContext(
            module,
            target=tvm.target.Target("llvm"),
            space_generator=ms.space_generator.ScheduleFn(sch_fn=schedule_fn),
            search_strategy=ms.search_strategy.ReplayTrace(
                trials_per_round,
                per_task_budget,
            ),
            task_name=name,
            rand_state=seed,
        )
        for module, schedule_fn, name, seed in task_specs
    ]

    db = ms.database.MemoryDatabase()
    total_budget = per_task_budget * len(tasks)
    scheduler = ms.task_scheduler.RoundRobin(
        tasks,
        [1.0] * len(tasks),
        builder=DummyBuilder(),
        runner=DummyRunner(),
        database=db,
        measure_callbacks=[ms.measure_callback.AddToDatabase()],
        max_trials=total_budget,
    )
    scheduler.tune()

    assert len(db) == total_budget
    for task in tasks:
        workload = db.commit_workload(task.mod)
        # Request far more records than exist so every record for the workload is returned.
        records = db.get_top_k(workload, 100000)
        assert len(records) == per_task_budget
def test_meta_schedule_task_scheduler_avoid_cyclic():  # pylint: disable=invalid-name
    """Dropping the only strong reference to a MyTaskScheduler destroys it immediately."""
    db = ms.database.MemoryDatabase()
    scheduler = MyTaskScheduler(
        [],
        builder=DummyBuilder(),
        runner=DummyRunner(),
        database=db,
        measure_callbacks=[ms.measure_callback.AddToDatabase()],
        max_trials=10,
    )
    # Keep only a weak handle: if the scheduler sat in a reference cycle, it would
    # outlive the `del` below and the handle would still resolve.
    handle = weakref.ref(scheduler)
    del scheduler
    assert handle() is None
def test_meta_schedule_task_scheduler_NIE():  # pylint: disable=invalid-name
    """A PyTaskScheduler subclass that does not override next_task_id raises TVMError."""

    @ms.derived_object
    class NIETaskScheduler(ms.task_scheduler.PyTaskScheduler):
        pass

    expected_message = "PyTaskScheduler's NextTaskId method not implemented!"
    # Construction and the call both happen inside the raises-block, as in the original.
    with pytest.raises(TVMError, match=expected_message):
        NIETaskScheduler(
            tasks=[],
            builder=DummyBuilder(),
            runner=DummyRunner(),
            database=ms.database.MemoryDatabase(),
            max_trials=1,
        ).next_task_id()
def test_meta_schedule_measure_callback():
    """A Python measure callback's apply sees the candidates and results passed to it."""

    @ms.derived_object
    class FancyMeasureCallback(ms.measure_callback.PyMeasureCallback):
        def apply(
            self,
            task_scheduler: ms.task_scheduler.TaskScheduler,
            task_id: int,
            measure_candidates: List[ms.MeasureCandidate],
            builder_results: List[ms.builder.BuilderResult],
            runner_results: List[ms.runner.RunnerResult],
        ) -> None:
            assert len(measure_candidates) == 1
            tvm.ir.assert_structural_equal(measure_candidates[0].sch.mod, Matmul)

            # Checks run in sequence; a failing check stops before later ones are evaluated.
            assert len(builder_results) == 1
            build = builder_results[0]
            assert build.error_msg is None
            assert build.artifact_path == "test_build"

            assert len(runner_results) == 1
            run = runner_results[0]
            assert run.error_msg is None
            assert len(run.run_secs) == 2

    callback = FancyMeasureCallback()
    placeholder_scheduler = ms.task_scheduler.RoundRobin(
        tasks=[],
        task_weights=[],
        builder=DummyBuilder(),
        runner=DummyRunner(),
        database=ms.database.MemoryDatabase(),
        max_trials=1,
    )
    candidates = [ms.MeasureCandidate(Schedule(Matmul), None)]
    build_outcomes = [ms.builder.BuilderResult("test_build", None)]
    run_outcomes = [ms.runner.RunnerResult([1.0, 2.1], None)]
    callback.apply(
        placeholder_scheduler,
        0,
        candidates,
        build_outcomes,
        run_outcomes,
    )