def test_meta_schedule_task_scheduler_single():
    """Tune one matmul task through a RoundRobin scheduler and verify that
    every trial is recorded in the database."""
    trials_per_iter = 3
    trials_total = 10
    matmul_task = TuneContext(
        MatmulModule,
        target=tvm.target.Target("llvm"),
        space_generator=ScheduleFn(sch_fn=_schedule_matmul),
        search_strategy=ReplayTrace(trials_per_iter, trials_total),
        task_name="Test",
        rand_state=42,
    )
    database = DummyDatabase()
    scheduler = RoundRobin(
        [matmul_task],
        [1.0],
        DummyBuilder(),
        DummyRunner(),
        database,
        measure_callbacks=[measure_callback.AddToDatabase()],
        max_trials=trials_total,
    )
    scheduler.tune()
    # AddToDatabase must have committed exactly one record per trial.
    assert len(database) == trials_total
def test_meta_schedule_task_scheduler_override_next_task_id_only():  # pylint: disable=invalid-name
    """Tune three tasks with the custom MyTaskScheduler and verify per-task
    trial counts in the shared database.

    NOTE(review): a later definition in this file reuses this exact function
    name, so only the later one survives at module level and pytest will not
    collect this version — consider renaming one of the two.
    """
    trials_per_iter = 6
    trials_per_task = 101
    task_specs = [
        (MatmulModule, _schedule_matmul, "Matmul", 42),
        (MatmulReluModule, _schedule_matmul, "MatmulRelu", 0xDEADBEEF),
        (BatchMatmulModule, _schedule_batch_matmul, "BatchMatmul", 0x114514),
    ]
    tasks = [
        TuneContext(
            mod,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=space_fn),
            search_strategy=ReplayTrace(trials_per_iter, trials_per_task),
            task_name=name,
            rand_state=seed,
        )
        for mod, space_fn, name, seed in task_specs
    ]
    database = DummyDatabase()
    scheduler = MyTaskScheduler(
        tasks,
        DummyBuilder(),
        DummyRunner(),
        database,
        measure_callbacks=[
            measure_callback.AddToDatabase(),
        ],
        max_trials=trials_per_task * len(tasks),
    )
    scheduler.tune()
    assert len(database) == trials_per_task * len(tasks)
    # Every task must have contributed exactly its own trial budget.
    for task in tasks:
        workload = database.commit_workload(task.mod)
        assert len(database.get_top_k(workload, 100000)) == trials_per_task
def test_meta_schedule_task_scheduler_multiple():
    """Round-robin over three tasks; each task must contribute exactly its
    own trial budget to the shared database."""
    trials_per_iter = 6
    trials_per_task = 101
    task_specs = [
        (MatmulModule, _schedule_matmul, "Matmul", 42),
        (MatmulReluModule, _schedule_matmul, "MatmulRelu", 0xDEADBEEF),
        (BatchMatmulModule, _schedule_batch_matmul, "BatchMatmul", 0x114514),
    ]
    tasks = [
        TuneContext(
            mod,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=space_fn),
            search_strategy=ReplayTrace(trials_per_iter, trials_per_task),
            task_name=name,
            rand_state=seed,
        )
        for mod, space_fn, name, seed in task_specs
    ]
    database = DummyDatabase()
    scheduler = RoundRobin(
        tasks,
        [1.0],
        DummyBuilder(),
        DummyRunner(),
        database,
        measure_callbacks=[measure_callback.AddToDatabase()],
        max_trials=trials_per_task * len(tasks),
    )
    scheduler.tune()
    assert len(database) == trials_per_task * len(tasks)
    # Per-task check: each workload has exactly `trials_per_task` records.
    for task in tasks:
        workload = database.commit_workload(task.mod)
        assert len(database.get_top_k(workload, 100000)) == trials_per_task
def test_meta_schedule_task_scheduler_avoid_cyclic():  # pylint: disable=invalid-name
    """The scheduler must be destructible: no reference cycle may keep it
    alive once the last strong reference is dropped."""
    scheduler = MyTaskScheduler(
        [],
        DummyBuilder(),
        DummyRunner(),
        DummyDatabase(),
        measure_callbacks=[
            measure_callback.AddToDatabase(),
        ],
    )
    # A weak reference observes destruction without keeping the object alive.
    probe = weakref.ref(scheduler)
    del scheduler
    assert probe() is None
def callbacks(  # pylint: disable=redefined-outer-name
    measure_callbacks: Optional[List[MeasureCallback]],
) -> List[MeasureCallback]:
    """Normalize the input to List[tvm.meta_schedule.MeasureCallback].

    Parameters
    ----------
    measure_callbacks : Optional[List[MeasureCallback]]
        None (use the default callback stack) or a list/tuple of callbacks.

    Raises
    ------
    TypeError
        If the argument is neither None nor a list/tuple of MeasureCallback.
    """
    if measure_callbacks is None:
        # Imported lazily inside the function; only the default path needs it.
        from tvm.meta_schedule import measure_callback as M

        # Default callback stack.
        return [
            M.AddToDatabase(),
            M.RemoveBuildArtifact(),
            M.EchoStatistics(),
            M.UpdateCostModel(),
        ]
    if not isinstance(measure_callbacks, (list, tuple)):
        raise TypeError(
            f"Expected `measure_callbacks` to be List[MeasureCallback], "
            f"but gets: {measure_callbacks}"
        )
    # Copy to a list and validate each element.
    normalized = list(measure_callbacks)
    for idx, item in enumerate(normalized):
        if not isinstance(item, MeasureCallback):
            raise TypeError(
                f"Expected `measure_callbacks` to be List[MeasureCallback], "
                f"but measure_callbacks[{idx}] is: {item}"
            )
    return normalized
def test_meta_schedule_task_scheduler_override_next_task_id_only():  # pylint: disable=invalid-name
    """Exercise a PyTaskScheduler subclass that overrides `next_task_id`
    only, leaving every other hook to the base implementation."""

    class MyTaskScheduler(PyTaskScheduler):
        # NOTE: class-level mutable set, shared across instances — acceptable
        # here because this test creates exactly one scheduler instance.
        done = set()

        def next_task_id(self) -> int:
            # Draw random task ids until every task is known to be stopped.
            while len(self.done) != len(tasks):
                candidate = random.randint(0, len(tasks) - 1)
                if tasks[candidate].is_stopped:
                    self.done.add(candidate)
                    continue
                # Calling the base func follows this route:
                #   Python side: PyTaskScheduler does not have
                #   `_is_task_running`; TaskScheduler's `is_task_running` is
                #   called, which goes through the ffi.
                #   C++ side: the ffi calls TaskScheduler's `is_task_running`,
                #   which is overridden in PyTaskScheduler; PyTaskScheduler
                #   checks whether the function is overridden in Python and,
                #   if not, falls back to TaskScheduler's vtable, calling
                #   TaskScheduler::IsTaskRunning.
                if self._is_task_running(candidate):
                    # Same routing applies here.
                    self._join_running_task(candidate)
                return candidate
            return -1

    trials_per_iter = 6
    trials_total = 101
    task_specs = [
        (MatmulModule, _schedule_matmul, "Matmul", 42),
        (MatmulReluModule, _schedule_matmul, "MatmulRelu", 0xDEADBEEF),
        (BatchMatmulModule, _schedule_batch_matmul, "BatchMatmul", 0x114514),
    ]
    tasks = [
        TuneContext(
            mod,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=space_fn),
            search_strategy=ReplayTrace(trials_per_iter, trials_total),
            task_name=name,
            rand_state=seed,
        )
        for mod, space_fn, name, seed in task_specs
    ]
    database = DummyDatabase()
    scheduler = MyTaskScheduler(
        tasks,
        DummyBuilder(),
        DummyRunner(),
        database,
        measure_callbacks=[
            measure_callback.AddToDatabase(),
        ],
    )
    scheduler.tune()
    assert len(database) == trials_total * len(tasks)
    # Each task must account for exactly `trials_total` database records.
    for task in tasks:
        workload = database.commit_workload(task.mod)
        assert len(database.get_top_k(workload, 100000)) == trials_total