def test_meta_schedule_random_model_reseed():
    """Two RandomModel instances created with the same seed must score identically."""

    def _scores_from_fresh_model():
        # A brand-new model seeded with 100 scoring 20 identical Matmul candidates.
        seeded_model = RandomModel(seed=100)
        context = TuneContext()
        batch = [MeasureCandidate(Schedule(Matmul), []) for _ in range(20)]
        return seeded_model.predict(context, batch)

    first_scores = _scores_from_fresh_model()
    second_scores = _scores_from_fresh_model()
    assert (first_scores == second_scores).all()
def test_meta_schedule_random_model_reload():
    """Reloading a saved RandomModel must reproduce the predictions made right after saving.

    The model first scores 30 candidates so that the saved state is not the
    freshly seeded one. It then saves, predicts, reloads from the saved file
    and predicts again. Both prediction batches must be equal.
    """
    model = RandomModel(seed=25973)
    model.predict(
        TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for _ in range(30)]
    )  # change state
    # TemporaryDirectory removes the directory even if save/predict/load raises;
    # a bare mkdtemp + trailing rmtree would leak it on any failure.
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "test_output_meta_schedule_random_model.npy")
        model.save(path)
        res1 = model.predict(
            TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for _ in range(70)]
        )
        model.load(path)
        res2 = model.predict(
            TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for _ in range(70)]
        )
    assert (res1 == res2).all()
def test_meta_schedule_cost_model():
    """Exercise every PyCostModel hook (save/load/update/predict) on a stub subclass."""

    class FancyCostModel(PyCostModel):
        """Stub cost model: save/load/update do nothing; predict returns 10 random scores."""

        def load(self, path: str) -> None:
            pass

        def save(self, path: str) -> None:
            pass

        def update(
            self,
            tune_context: TuneContext,
            candidates: List[MeasureCandidate],
            results: List[RunnerResult],
        ) -> None:
            pass

        def predict(
            self,
            tune_context: TuneContext,
            candidates: List[MeasureCandidate],
        ) -> np.ndarray:
            # Always 10 scores regardless of len(candidates); the shape
            # assertion at the end of the test relies on this.
            return np.random.rand(10)

    model = FancyCostModel()
    model.save("fancy_test_location")
    model.load("fancy_test_location")
    model.update(TuneContext(), [], [])
    # Pass a TuneContext *instance*, matching the update() call above and the
    # predict() annotation (previously the class object itself was passed).
    results = model.predict(
        TuneContext(),
        [MeasureCandidate(Schedule(mod=Matmul), [])],
    )
    assert results.shape == (10,)
def test_meta_schedule_measure_callback_fail():
    """An exception raised inside a measure callback's apply() reaches the caller."""

    @derived_object
    class FailingMeasureCallback(PyMeasureCallback):
        """Callback whose apply() always raises ValueError("test")."""

        def apply(
            self,
            task_scheduler: TaskScheduler,
            task_id: int,
            measure_candidates: List[MeasureCandidate],
            builds: List[BuilderResult],
            results: List[RunnerResult],
        ) -> None:
            raise ValueError("test")

    callback = FailingMeasureCallback()
    with pytest.raises(ValueError, match="test"):
        # All arguments are built inside the raises-block, as before, so any
        # ValueError("test") from construction would be caught the same way.
        scheduler = RoundRobin(
            [],
            [],
            DummyBuilder(),
            DummyRunner(),
            DummyDatabase(),
            max_trials=1,
        )
        candidates = [MeasureCandidate(Schedule(Matmul), None)]
        build_results = [BuilderResult("test_build", None)]
        runner_results = [RunnerResult([1.0, 2.1], None)]
        callback.apply(scheduler, 0, candidates, build_results, runner_results)
def test_meta_schedule_measure_callback():
    """A measure callback's apply() sees exactly the candidates, builds and results passed in."""

    @derived_object
    class FancyMeasureCallback(PyMeasureCallback):
        """Callback that checks the shape and content of every argument it receives."""

        def apply(
            self,
            task_scheduler: TaskScheduler,
            task_id: int,
            measure_candidates: List[MeasureCandidate],
            builds: List[BuilderResult],
            results: List[RunnerResult],
        ) -> None:
            # A single candidate whose schedule wraps the Matmul module.
            assert len(measure_candidates) == 1
            assert_structural_equal(measure_candidates[0].sch.mod, Matmul)
            # A single build with no error and the expected artifact path.
            assert len(builds) == 1
            assert builds[0].error_msg is None
            assert builds[0].artifact_path == "test_build"
            # A single run with no error and two recorded timings.
            assert len(results) == 1
            assert results[0].error_msg is None
            assert len(results[0].run_secs) == 2

    callback = FancyMeasureCallback()
    scheduler = RoundRobin(
        [],
        [],
        DummyBuilder(),
        DummyRunner(),
        DummyDatabase(),
        max_trials=1,
    )
    candidates = [MeasureCandidate(Schedule(Matmul), None)]
    build_results = [BuilderResult("test_build", None)]
    runner_results = [RunnerResult([1.0, 2.1], None)]
    callback.apply(scheduler, 0, candidates, build_results, runner_results)
def test_meta_schedule_random_model():
    """RandomModel yields one score per candidate, each within [0, model.max_range]."""
    model = RandomModel()
    model.update(TuneContext(), [], [])
    context = TuneContext()
    batch = [MeasureCandidate(Schedule(Matmul), []) for _ in range(10)]
    scores = model.predict(context, batch)
    assert len(scores) == 10
    assert min(scores) >= 0
    assert max(scores) <= model.max_range
# Example No. 7
# 0
def test_meta_schedule_measure_callback_fail():
    """An exception raised inside a measure callback's apply() reaches the caller.

    NOTE(review): this reuses the name of an earlier test function in the same
    source, so if both live in one module only this later definition is
    collected by pytest. Consider renaming one of them.
    """

    class FailingMeasureCallback(PyMeasureCallback):
        """Callback whose apply() always raises ValueError("test")."""

        def apply(
            self,
            task_scheduler: TaskScheduler,
            task_id: int,
            measure_candidates: List[MeasureCandidate],
            builds: List[BuilderResult],
            results: List[RunnerResult],
        ) -> None:
            raise ValueError("test")

    callback = FailingMeasureCallback()
    with pytest.raises(ValueError, match="test"):
        scheduler = TaskScheduler()
        candidates = [MeasureCandidate(Schedule(Matmul), None)]
        build_results = [BuilderResult("test_build", None)]
        runner_results = [RunnerResult([1.0, 2.1], None)]
        callback.apply(scheduler, 0, candidates, build_results, runner_results)
def _dummy_candidate():
    """Return a MeasureCandidate wrapping a Schedule of Matmul, with an empty list as its second argument."""
    matmul_schedule = Schedule(Matmul)
    return MeasureCandidate(matmul_schedule, [])