import os
import shutil
import tempfile
from typing import List, Optional

import tvm
from tvm.ir import IRModule
from tvm.meta_schedule import TuneContext
from tvm.meta_schedule.builder import LocalBuilder
from tvm.meta_schedule.cost_model import RandomModel
from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload
from tvm.meta_schedule.mutator import PyMutator
from tvm.meta_schedule.runner import LocalRunner, RunnerResult
from tvm.meta_schedule.search_strategy import (
    EvolutionarySearch,
    MeasureCandidate,
    ReplayTrace,
)
from tvm.meta_schedule.space_generator import ScheduleFn
from tvm.meta_schedule.task_scheduler import RoundRobin
from tvm.meta_schedule.utils import derived_object
from tvm.tir.schedule import Schedule, Trace

# NOTE: the imports above reflect the tvm.meta_schedule package layout these
# tests were written against. `Matmul`, `_schedule_matmul`, and
# `_is_trace_equal` are assumed to be defined earlier in this file.


def test_meta_schedule_random_model():
    model = RandomModel()
    model.update(TuneContext(), [], [])
    res = model.predict(
        TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for _ in range(10)]
    )
    assert len(res) == 10
    assert min(res) >= 0 and max(res) <= model.max_range
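
# A minimal standalone sketch (not one of the original tests) of the RandomModel
# API exercised above: `update` is effectively a no-op for this model, and
# `predict` returns one random score per candidate, each in [0, max_range].
# The helper name below is hypothetical and for illustration only.
def _example_random_model_usage():
    model = RandomModel(seed=42)
    candidates = [MeasureCandidate(Schedule(Matmul), []) for _ in range(4)]
    model.update(TuneContext(), [], [])  # RandomModel does not learn from results
    return model.predict(TuneContext(), candidates)  # NumPy array of 4 scores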
def test_meta_schedule_random_model_reload():
    model = RandomModel(seed=25973)
    model.predict(
        TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for _ in range(30)]
    )  # advance the random state so there is something nontrivial to save
    path = os.path.join(tempfile.mkdtemp(), "test_output_meta_schedule_random_model.npy")
    model.save(path)
    res1 = model.predict(
        TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for _ in range(70)]
    )
    model.load(path)  # restore the saved state; the next 70 predictions must repeat
    res2 = model.predict(
        TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for _ in range(70)]
    )
    shutil.rmtree(os.path.dirname(path))
    assert (res1 == res2).all()
def test_meta_schedule_random_model_reseed():
    model = RandomModel(seed=100)
    res = model.predict(
        TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for _ in range(20)]
    )
    new_model = RandomModel(seed=100)  # same seed => identical prediction stream
    new_res = new_model.predict(
        TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for _ in range(20)]
    )
    assert (res == new_res).all()
@derived_object
class DummyMutator(PyMutator):
    """Dummy Mutator for testing"""

    def initialize_with_tune_context(self, context: "TuneContext") -> None:
        pass

    def apply(self, trace: Trace, _) -> Optional[Trace]:
        return Trace(trace.insts, {})


@derived_object
class DummyDatabase(PyDatabase):
    """Dummy Database for testing"""

    def __init__(self):
        super().__init__()
        self.records = []
        self.workload_reg = []

    def has_workload(self, mod: IRModule) -> bool:
        for workload in self.workload_reg:
            if tvm.ir.structural_equal(workload.mod, mod):
                return True
        return False

    def commit_tuning_record(self, record: TuningRecord) -> None:
        self.records.append(record)

    def commit_workload(self, mod: IRModule) -> Workload:
        for workload in self.workload_reg:
            if tvm.ir.structural_equal(workload.mod, mod):
                return workload
        workload = Workload(mod)
        self.workload_reg.append(workload)
        return workload

    def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
        return list(
            filter(
                lambda x: x.workload == workload,
                sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)),
            )
        )[: int(top_k)]

    def __len__(self) -> int:
        return len(self.records)

    def print_results(self) -> None:
        print("\n".join([str(r) for r in self.records]))


def test_meta_schedule_evolutionary_search():  # pylint: disable = invalid-name
    num_trials_per_iter = 10
    num_trials_total = 100
    strategy = EvolutionarySearch(
        num_trials_per_iter=num_trials_per_iter,
        num_trials_total=num_trials_total,
        population_size=5,
        init_measured_ratio=0.1,
        init_min_unmeasured=50,
        genetic_num_iters=3,
        genetic_mutate_prob=0.5,
        genetic_max_fail_count=10,
        eps_greedy=0.9,
    )
    context = TuneContext(
        mod=Matmul,
        space_generator=ScheduleFn(sch_fn=_schedule_matmul),
        mutator_probs={
            DummyMutator(): 1.0,
        },
        target=tvm.target.Target("llvm"),
        num_threads=1,  # because we are using a mutator from the python side
    )
    _scheduler = RoundRobin(
        tasks=[context],
        builder=LocalBuilder(),
        runner=LocalRunner(),
        database=DummyDatabase(),
        cost_model=RandomModel(),
        measure_callbacks=[],
    )
    context.space_generator.initialize_with_tune_context(context)
    spaces = context.space_generator.generate_design_space(context.mod)
    strategy.initialize_with_tune_context(context)
    strategy.pre_tuning(spaces)
    (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul)
    num_trials_each_iter: List[int] = []
    candidates = strategy.generate_measure_candidates()
    while candidates is not None:
        num_trials_each_iter.append(len(candidates))
        runner_results: List[RunnerResult] = []
        for candidate in candidates:
            _is_trace_equal(
                candidate.sch,
                correct_sch,
                remove_decisions=(isinstance(strategy, ReplayTrace)),
            )
            runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None))
        strategy.notify_runner_results(context, candidates, runner_results)
        candidates = strategy.generate_measure_candidates()
    strategy.post_tuning()
    print(num_trials_each_iter)
    correct_count = 10  # expected trial count for every iteration except possibly the last
    assert num_trials_each_iter == [correct_count] * (num_trials_total // correct_count) + (
        [num_trials_total % correct_count] if num_trials_total % correct_count != 0 else []
    )
    del _scheduler
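
# Conventional direct-run entry point (an added convenience, not part of the
# original tests); assumes these four tests are the only ones in this section.
if __name__ == "__main__":
    test_meta_schedule_random_model()
    test_meta_schedule_random_model_reload()
    test_meta_schedule_random_model_reseed()
    test_meta_schedule_evolutionary_search()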