# Imports below assume TVM's `tvm.meta_schedule` package layout.
import os
import shutil
import tempfile
from typing import List

import numpy as np
import pytest

import tvm
from tvm.ir import assert_structural_equal
from tvm.meta_schedule.builder import BuilderResult
from tvm.meta_schedule.cost_model import PyCostModel, RandomModel
from tvm.meta_schedule.measure_callback import PyMeasureCallback
from tvm.meta_schedule.runner import RunnerResult
from tvm.meta_schedule.search_strategy import MeasureCandidate
from tvm.meta_schedule.task_scheduler import RoundRobin, TaskScheduler
from tvm.meta_schedule.tune_context import TuneContext
from tvm.meta_schedule.utils import derived_object
from tvm.script import tir as T

# `DummyBuilder`, `DummyRunner`, and `DummyDatabase` are assumed to be the
# stub implementations defined elsewhere in this test suite.
from tvm.tir import Schedule


def test_meta_schedule_random_model_reseed():
    # Two models constructed with the same seed must produce identical scores.
    model = RandomModel(seed=100)
    res = model.predict(
        TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for _ in range(20)]
    )
    new_model = RandomModel(seed=100)
    new_res = new_model.predict(
        TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for _ in range(20)]
    )
    assert (res == new_res).all()

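# A minimal sketch of the `Matmul` workload the tests in this file schedule.
# The shapes and body here are assumptions for illustration; any IRModule
# with a schedulable PrimFunc would serve equally well.
@tvm.script.ir_module
class Matmul:
    @T.prim_func
    def main(a: T.handle, b: T.handle, c: T.handle) -> None:
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        A = T.match_buffer(a, (128, 128), "float32")
        B = T.match_buffer(b, (128, 128), "float32")
        C = T.match_buffer(c, (128, 128), "float32")
        for i, j, k in T.grid(128, 128, 128):
            with T.block("matmul"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = 0.0
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
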
def test_meta_schedule_random_model_reload():
    model = RandomModel(seed=25973)
    # Advance the random state before saving it.
    model.predict(
        TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for _ in range(30)]
    )
    path = os.path.join(tempfile.mkdtemp(), "test_output_meta_schedule_random_model.npy")
    model.save(path)
    res1 = model.predict(
        TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for _ in range(70)]
    )
    # Restoring the saved state must reproduce the same predictions.
    model.load(path)
    res2 = model.predict(
        TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for _ in range(70)]
    )
    shutil.rmtree(os.path.dirname(path))
    assert (res1 == res2).all()

def test_meta_schedule_cost_model():
    @derived_object
    class FancyCostModel(PyCostModel):
        def load(self, path: str) -> None:
            pass

        def save(self, path: str) -> None:
            pass

        def update(
            self,
            tune_context: TuneContext,
            candidates: List[MeasureCandidate],
            results: List[RunnerResult],
        ) -> None:
            pass

        def predict(
            self, tune_context: TuneContext, candidates: List[MeasureCandidate]
        ) -> np.ndarray:
            return np.random.rand(10)

    model = FancyCostModel()
    model.save("fancy_test_location")
    model.load("fancy_test_location")
    model.update(TuneContext(), [], [])
    # Pass a `TuneContext()` instance, not the `TuneContext` class itself.
    results = model.predict(TuneContext(), [MeasureCandidate(Schedule(mod=Matmul), [])])
    assert results.shape == (10,)

def test_meta_schedule_measure_callback_fail():
    @derived_object
    class FailingMeasureCallback(PyMeasureCallback):
        def apply(
            self,
            task_scheduler: TaskScheduler,
            task_id: int,
            measure_candidates: List[MeasureCandidate],
            builds: List[BuilderResult],
            results: List[RunnerResult],
        ) -> None:
            raise ValueError("test")

    measure_callback = FailingMeasureCallback()
    with pytest.raises(ValueError, match="test"):
        measure_callback.apply(
            RoundRobin([], [], DummyBuilder(), DummyRunner(), DummyDatabase(), max_trials=1),
            0,
            [MeasureCandidate(Schedule(Matmul), None)],
            [BuilderResult("test_build", None)],
            [RunnerResult([1.0, 2.1], None)],
        )

def test_meta_schedule_measure_callback():
    @derived_object
    class FancyMeasureCallback(PyMeasureCallback):
        def apply(
            self,
            task_scheduler: TaskScheduler,
            task_id: int,
            measure_candidates: List[MeasureCandidate],
            builds: List[BuilderResult],
            results: List[RunnerResult],
        ) -> None:
            assert len(measure_candidates) == 1
            assert_structural_equal(measure_candidates[0].sch.mod, Matmul)
            assert (
                len(builds) == 1
                and builds[0].error_msg is None
                and builds[0].artifact_path == "test_build"
            )
            assert (
                len(results) == 1
                and results[0].error_msg is None
                and len(results[0].run_secs) == 2
            )

    measure_callback = FancyMeasureCallback()
    measure_callback.apply(
        RoundRobin([], [], DummyBuilder(), DummyRunner(), DummyDatabase(), max_trials=1),
        0,
        [MeasureCandidate(Schedule(Matmul), None)],
        [BuilderResult("test_build", None)],
        [RunnerResult([1.0, 2.1], None)],
    )

def test_meta_schedule_random_model():
    model = RandomModel()
    model.update(TuneContext(), [], [])
    res = model.predict(
        TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for _ in range(10)]
    )
    assert len(res) == 10
    assert min(res) >= 0 and max(res) <= model.max_range

def _dummy_candidate():
    return MeasureCandidate(Schedule(Matmul), [])

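# Conventional pytest entry point so the file can be run directly; this is an
# assumed convenience, and the surrounding test suite may use an equivalent
# runner of its own instead.
if __name__ == "__main__":
    pytest.main([__file__])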