def test_meta_schedule_space_generator_schedule_fn():
    mod = Matmul
    space_generator = ScheduleFn(sch_fn=schedule_matmul)
    design_spaces = space_generator.generate_design_space(mod)
    assert len(design_spaces) == 1
    (schedule,) = design_spaces
    _check_correct(schedule)
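# `Matmul`, `schedule_matmul`, and `_check_correct` are assumed to be module-level
# fixtures. A minimal sketch of what they could look like; the workload shape,
# tiling splits, and the trace check below are illustrative assumptions, not the
# exact fixtures of this test suite:
import tvm
from tvm.script import tir as T
from tvm.tir.schedule import Schedule


@tvm.script.ir_module
class Matmul:
    @T.prim_func
    def main(a: T.handle, b: T.handle, c: T.handle) -> None:
        T.func_attr({"global_symbol": "main"})
        A = T.match_buffer(a, (1024, 1024), "float32")
        B = T.match_buffer(b, (1024, 1024), "float32")
        C = T.match_buffer(c, (1024, 1024), "float32")
        for i, j, k in T.grid(1024, 1024, 1024):
            with T.block("matmul"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = 0.0
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


def schedule_matmul(sch: Schedule):
    # Sample a perfect tiling for each loop, split, and reorder; the sampled
    # decisions are what distinguish the generated design spaces.
    block = sch.get_block("matmul")
    i, j, k = sch.get_loops(block=block)
    i_0, i_1, i_2, i_3 = sch.split(i, factors=sch.sample_perfect_tile(i, n=4))
    j_0, j_1, j_2, j_3 = sch.split(j, factors=sch.sample_perfect_tile(j, n=4))
    k_0, k_1 = sch.split(k, factors=sch.sample_perfect_tile(k, n=2))
    sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3)


def _check_correct(schedule: Schedule):
    # The generated design space should carry the sampling instructions in its trace.
    assert "sample_perfect_tile" in str(schedule.trace)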
@pytest.mark.parametrize("TestClass", [ReplayFunc, ReplayTrace])  # assumed parametrization; the body expects a replay-style SearchStrategy class
def test_meta_schedule_replay_func(TestClass: SearchStrategy):  # pylint: disable=invalid-name
    num_trials_per_iter = 7
    max_trials_per_task = 20

    strategy = TestClass(
        num_trials_per_iter=num_trials_per_iter,
        max_trials_per_task=max_trials_per_task,
    )
    context = TuneContext(mod=Matmul, space_generator=ScheduleFn(sch_fn=_schedule_matmul))
    context.space_generator.initialize_with_tune_context(context)
    spaces = context.space_generator.generate_design_space(context.mod)

    strategy.initialize_with_tune_context(context)
    strategy.pre_tuning(spaces)
    (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul)
    num_trials_each_iter: List[int] = []
    candidates = strategy.generate_measure_candidates()
    while candidates is not None:
        num_trials_each_iter.append(len(candidates))
        runner_results: List[RunnerResult] = []
        for candidate in candidates:
            _is_trace_equal(
                candidate.sch,
                correct_sch,
                remove_decisions=(isinstance(strategy, ReplayTrace)),
            )
            runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None))
        strategy.notify_runner_results(context, candidates, runner_results)
        candidates = strategy.generate_measure_candidates()
    strategy.post_tuning()
    assert num_trials_each_iter == [7, 7, 6]
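# `_is_trace_equal` is assumed to be a module-level helper. A plausible sketch,
# assuming it compares traces textually and can strip sampling decisions (replay
# strategies re-sample, so their traces only match modulo decisions):
from tvm.tir.schedule import Schedule, Trace


def _is_trace_equal(sch_1: Schedule, sch_2: Schedule, remove_decisions: bool = True) -> bool:
    if remove_decisions:
        # Compare instruction sequences only, ignoring the sampled decisions.
        trace_1 = Trace(sch_1.trace.insts, {})
        trace_2 = Trace(sch_2.trace.insts, {})
    else:
        trace_1 = sch_1.trace
        trace_2 = sch_2.trace
    return str(trace_1) == str(trace_2)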
def test_meta_schedule_task_scheduler_override_next_task_id_only():  # pylint: disable=invalid-name
    num_trials_per_iter = 6
    max_trials_per_task = 101
    tasks = [
        TuneContext(
            MatmulModule,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=_schedule_matmul),
            search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task),
            task_name="Matmul",
            rand_state=42,
        ),
        TuneContext(
            MatmulReluModule,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=_schedule_matmul),
            search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task),
            task_name="MatmulRelu",
            rand_state=0xDEADBEEF,
        ),
        TuneContext(
            BatchMatmulModule,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=_schedule_batch_matmul),
            search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task),
            task_name="BatchMatmul",
            rand_state=0x114514,
        ),
    ]
    database = DummyDatabase()
    scheduler = MyTaskScheduler(
        tasks,
        DummyBuilder(),
        DummyRunner(),
        database,
        measure_callbacks=[
            measure_callback.AddToDatabase(),
        ],
        max_trials=max_trials_per_task * len(tasks),
    )
    scheduler.tune()
    assert len(database) == max_trials_per_task * len(tasks)
    for task in tasks:
        assert (
            len(
                database.get_top_k(
                    database.commit_workload(task.mod),
                    100000,
                )
            )
            == max_trials_per_task
        )
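# `DummyBuilder`, `DummyRunner`, and `DummyDatabase` are assumed to be module-level
# test doubles (as is `MyTaskScheduler`, the PyTaskScheduler subclass defined inline
# in the override test further below). A minimal sketch of the builder and runner,
# assuming they return fake artifacts and timings instead of compiling or running
# anything; the exact return values are illustrative:
import random
from typing import List

from tvm.meta_schedule.builder import BuilderInput, BuilderResult, PyBuilder
from tvm.meta_schedule.runner import (
    PyRunner,
    PyRunnerFuture,
    RunnerFuture,
    RunnerInput,
    RunnerResult,
)
from tvm.meta_schedule.utils import derived_object


@derived_object
class DummyBuilder(PyBuilder):
    def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]:
        # Pretend every build succeeded and left an artifact at a fake path.
        return [BuilderResult("test_path", None) for _ in build_inputs]


@derived_object
class DummyRunnerFuture(PyRunnerFuture):
    def done(self) -> bool:
        return True

    def result(self) -> RunnerResult:
        # Report a few fake run times so the database has something to rank.
        return RunnerResult([random.uniform(5, 30) for _ in range(3)], None)


@derived_object
class DummyRunner(PyRunner):
    def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
        return [DummyRunnerFuture() for _ in runner_inputs]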
def test_meta_schedule_task_scheduler_multiple():
    num_trials_per_iter = 6
    max_trials_per_task = 101
    tasks = [
        TuneContext(
            MatmulModule,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=_schedule_matmul),
            search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task),
            task_name="Matmul",
            rand_state=42,
        ),
        TuneContext(
            MatmulReluModule,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=_schedule_matmul),
            search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task),
            task_name="MatmulRelu",
            rand_state=0xDEADBEEF,
        ),
        TuneContext(
            BatchMatmulModule,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=_schedule_batch_matmul),
            search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task),
            task_name="BatchMatmul",
            rand_state=0x114514,
        ),
    ]
    database = DummyDatabase()
    round_robin = RoundRobin(
        tasks,
        [1.0],
        DummyBuilder(),
        DummyRunner(),
        database,
        measure_callbacks=[measure_callback.AddToDatabase()],
        max_trials=max_trials_per_task * len(tasks),
    )
    round_robin.tune()
    assert len(database) == max_trials_per_task * len(tasks)
    for task in tasks:
        assert (
            len(
                database.get_top_k(
                    database.commit_workload(task.mod),
                    100000,
                )
            )
            == max_trials_per_task
        )
def test_meta_schedule_task_scheduler_single():
    num_trials_per_iter = 3
    max_trials_per_task = 10
    sch_fn = ScheduleFn(sch_fn=_schedule_matmul)
    replay = ReplayTrace(num_trials_per_iter, max_trials_per_task)
    task = TuneContext(
        MatmulModule,
        target=tvm.target.Target("llvm"),
        space_generator=sch_fn,
        search_strategy=replay,
        task_name="Test",
        rand_state=42,
    )
    database = DummyDatabase()
    round_robin = RoundRobin(
        [task],
        [1.0],
        DummyBuilder(),
        DummyRunner(),
        database,
        measure_callbacks=[measure_callback.AddToDatabase()],
        max_trials=max_trials_per_task,
    )
    round_robin.tune()
    assert len(database) == max_trials_per_task
def test_meta_schedule_replay_trace():
    num_trials_per_iter = 7
    num_trials_total = 20

    (example_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul)
    replay = ReplayTrace(
        num_trials_per_iter=num_trials_per_iter,
        num_trials_total=num_trials_total,
    )
    tune_context = TuneContext(mod=Matmul)
    replay.initialize_with_tune_context(tune_context)

    num_trials_each_round: List[int] = []
    replay.pre_tuning([example_sch])
    while True:
        candidates = replay.generate_measure_candidates()
        if candidates is None:
            break
        num_trials_each_round.append(len(candidates))
        runner_results: List[RunnerResult] = []
        for candidate in candidates:
            assert _is_trace_equal(candidate.sch, example_sch)
            runner_results.append(RunnerResult(run_secs=[0.5, 0.4, 0.3], error_msg=None))
        replay.notify_runner_results(runner_results)
    replay.post_tuning()
    assert num_trials_each_round == [7, 7, 6]
def test_meta_schedule_design_space_generator_union():
    mod = Matmul
    space_generator = ScheduleFn(sch_fn=schedule_matmul)
    space_generator_union = SpaceGeneratorUnion([space_generator, space_generator])
    design_spaces = space_generator_union.generate_design_space(mod)
    assert len(design_spaces) == 2
    for design_space in design_spaces:
        _check_correct(design_space)
def test_meta_schedule_task_scheduler_multiple():
    num_trials_per_iter = 6
    num_trials_total = 101
    tasks = [
        TuneContext(
            MatmulModule,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=_schedule_matmul),
            search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total),
            task_name="Matmul",
            rand_state=42,
        ),
        TuneContext(
            MatmulReluModule,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=_schedule_matmul),
            search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total),
            task_name="MatmulRelu",
            rand_state=0xDEADBEEF,
        ),
        TuneContext(
            BatchMatmulModule,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=_schedule_batch_matmul),
            search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total),
            task_name="BatchMatmul",
            rand_state=0x114514,
        ),
    ]
    database = DummyDatabase()
    round_robin = RoundRobin(tasks, DummyBuilder(), DummyRunner(), database)
    round_robin.tune()
    assert len(database) == num_trials_total * len(tasks)
    print(database.workload_reg)
    for task in tasks:
        assert (
            len(database.get_top_k(database.commit_workload(task.mod), 1e9)) == num_trials_total
        )
def test_meta_schedule_task_scheduler_single():
    num_trials_per_iter = 3
    num_trials_total = 10
    sch_fn = ScheduleFn(sch_fn=_schedule_matmul)
    replay = ReplayTrace(num_trials_per_iter, num_trials_total)
    task = TuneContext(
        MatmulModule,
        target=tvm.target.Target("llvm"),
        space_generator=sch_fn,
        search_strategy=replay,
        task_name="Test",
        rand_state=42,
    )
    database = DummyDatabase()
    round_robin = RoundRobin([task], DummyBuilder(), DummyRunner(), database)
    round_robin.tune()
    assert len(database) == num_trials_total
def test_meta_schedule_task_scheduler_override_next_task_id_only():
    class MyTaskScheduler(PyTaskScheduler):
        done = set()

        def next_task_id(self) -> int:
            while len(self.done) != len(tasks):
                x = random.randint(0, len(tasks) - 1)
                task = tasks[x]
                if not task.is_stopped:
                    """Calling base func via following route:
                    Python side:
                        PyTaskScheduler does not have `_is_task_running`
                        Call TaskScheduler's `is_task_running`, which calls ffi
                    C++ side:
                        The ffi calls TaskScheduler's `is_task_running`
                        But it is overridden in PyTaskScheduler
                        PyTaskScheduler checks if the function is overridden in python
                        If not, it returns the TaskScheduler's vtable, calling
                            TaskScheduler::IsTaskRunning
                    """
                    if self._is_task_running(x):
                        # Same Here
                        self._join_running_task(x)
                    return x
                else:
                    self.done.add(x)
            return -1

    num_trials_per_iter = 6
    num_trials_total = 101
    tasks = [
        TuneContext(
            MatmulModule,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=_schedule_matmul),
            search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total),
            task_name="Matmul",
            rand_state=42,
        ),
        TuneContext(
            MatmulReluModule,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=_schedule_matmul),
            search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total),
            task_name="MatmulRelu",
            rand_state=0xDEADBEEF,
        ),
        TuneContext(
            BatchMatmulModule,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=_schedule_batch_matmul),
            search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total),
            task_name="BatchMatmul",
            rand_state=0x114514,
        ),
    ]
    database = DummyDatabase()
    scheduler = MyTaskScheduler(tasks, DummyBuilder(), DummyRunner(), database)
    scheduler.tune()
    assert len(database) == num_trials_total * len(tasks)
    for task in tasks:
        assert (
            len(database.get_top_k(database.commit_workload(task.mod), 1e9)) == num_trials_total
        )
def test_meta_schedule_evolutionary_search_early_stop():  # pylint: disable=invalid-name
    def _schedule_matmul_empty(sch: Schedule):
        return sch

    num_trials_per_iter = 10
    max_trials_per_task = 100

    strategy = EvolutionarySearch(
        num_trials_per_iter=num_trials_per_iter,
        max_trials_per_task=max_trials_per_task,
        population_size=5,
        init_measured_ratio=0.1,
        init_min_unmeasured=50,
        genetic_num_iters=3,
        genetic_mutate_prob=0.5,
        genetic_max_fail_count=10,
        eps_greedy=0.9,
    )
    context = TuneContext(
        mod=Matmul,
        space_generator=ScheduleFn(sch_fn=_schedule_matmul_empty),
        mutator_probs={
            DummyMutator(): 1.0,
        },
        target=tvm.target.Target("llvm"),
        num_threads=1,  # because we are using a mutator from the python side
    )
    _scheduler = RoundRobin(
        tasks=[context],
        task_weights=[1.0],
        builder=ms.builder.LocalBuilder(),
        runner=ms.runner.LocalRunner(),
        database=DummyDatabase(),
        cost_model=ms.cost_model.RandomModel(),
        measure_callbacks=[],
        max_trials=1,
    )
    context.space_generator.initialize_with_tune_context(context)
    spaces = context.space_generator.generate_design_space(context.mod)
    strategy.initialize_with_tune_context(context)
    strategy.pre_tuning(spaces)
    (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul)
    num_trials_each_iter: List[int] = []
    candidates = strategy.generate_measure_candidates()
    while candidates is not None:
        num_trials_each_iter.append(len(candidates))
        runner_results: List[RunnerResult] = []
        for candidate in candidates:
            _is_trace_equal(
                candidate.sch,
                correct_sch,
                remove_decisions=(isinstance(strategy, ReplayTrace)),
            )
            runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None))
        strategy.notify_runner_results(context, candidates, runner_results)
        candidates = strategy.generate_measure_candidates()
    strategy.post_tuning()
    assert num_trials_each_iter == [1, 0, 0, 0, 0]
    del _scheduler
def test_meta_schedule_evolutionary_search():  # pylint: disable=invalid-name
    @derived_object
    class DummyMutator(PyMutator):
        """Dummy Mutator for testing"""

        def initialize_with_tune_context(self, context: "TuneContext") -> None:
            pass

        def apply(self, trace: Trace, _) -> Optional[Trace]:
            return Trace(trace.insts, {})

    @derived_object
    class DummyDatabase(PyDatabase):
        """Dummy Database for testing"""

        def __init__(self):
            super().__init__()
            self.records = []
            self.workload_reg = []

        def has_workload(self, mod: IRModule) -> bool:
            for workload in self.workload_reg:
                if tvm.ir.structural_equal(workload.mod, mod):
                    return True
            return False

        def commit_tuning_record(self, record: TuningRecord) -> None:
            self.records.append(record)

        def commit_workload(self, mod: IRModule) -> Workload:
            for workload in self.workload_reg:
                if tvm.ir.structural_equal(workload.mod, mod):
                    return workload
            workload = Workload(mod)
            self.workload_reg.append(workload)
            return workload

        def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
            return list(
                filter(
                    lambda x: x.workload == workload,
                    sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)),
                )
            )[: int(top_k)]

        def __len__(self) -> int:
            return len(self.records)

        def print_results(self) -> None:
            print("\n".join([str(r) for r in self.records]))

    num_trials_per_iter = 10
    num_trials_total = 100

    strategy = EvolutionarySearch(
        num_trials_per_iter=num_trials_per_iter,
        num_trials_total=num_trials_total,
        population_size=5,
        init_measured_ratio=0.1,
        init_min_unmeasured=50,
        genetic_num_iters=3,
        genetic_mutate_prob=0.5,
        genetic_max_fail_count=10,
        eps_greedy=0.9,
    )
    context = TuneContext(
        mod=Matmul,
        space_generator=ScheduleFn(sch_fn=_schedule_matmul),
        mutator_probs={
            DummyMutator(): 1.0,
        },
        target=tvm.target.Target("llvm"),
        num_threads=1,  # because we are using a mutator from the python side
    )
    _scheduler = RoundRobin(
        tasks=[context],
        builder=LocalBuilder(),
        runner=LocalRunner(),
        database=DummyDatabase(),
        cost_model=RandomModel(),
        measure_callbacks=[],
    )
    context.space_generator.initialize_with_tune_context(context)
    spaces = context.space_generator.generate_design_space(context.mod)
    strategy.initialize_with_tune_context(context)
    strategy.pre_tuning(spaces)
    (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul)
    num_trials_each_iter: List[int] = []
    candidates = strategy.generate_measure_candidates()
    while candidates is not None:
        num_trials_each_iter.append(len(candidates))
        runner_results: List[RunnerResult] = []
        for candidate in candidates:
            _is_trace_equal(
                candidate.sch,
                correct_sch,
                remove_decisions=(isinstance(strategy, ReplayTrace)),
            )
            runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None))
        strategy.notify_runner_results(context, candidates, runner_results)
        candidates = strategy.generate_measure_candidates()
    strategy.post_tuning()
    print(num_trials_each_iter)
    correct_count = 10  # For each iteration except the last one
    assert num_trials_each_iter == [correct_count] * (num_trials_total // correct_count) + (
        [num_trials_total % correct_count] if num_trials_total % correct_count != 0 else []
    )
    del _scheduler
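# An assumed entry point for running this file directly; test files like this are
# commonly collected by pytest, e.g.:
if __name__ == "__main__":
    import sys

    import pytest

    sys.exit(pytest.main([__file__] + sys.argv[1:]))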