def test_meta_schedule_measure_callback_fail():
    """Check that an exception raised inside ``apply`` reaches the caller.

    A ``PyMeasureCallback`` subclass whose ``apply`` always raises
    ``ValueError("test")`` is invoked directly, and the test expects that
    error to surface from the call.
    """

    @derived_object
    class FailingMeasureCallback(PyMeasureCallback):
        # ``apply`` ignores all of its arguments and fails unconditionally.
        def apply(
            self,
            task_scheduler: TaskScheduler,
            task_id: int,
            measure_candidates: List[MeasureCandidate],
            builds: List[BuilderResult],
            results: List[RunnerResult],
        ) -> None:
            raise ValueError("test")

    measure_callback = FailingMeasureCallback()

    # Build the arguments outside ``pytest.raises`` so that only the
    # ``apply`` call itself can satisfy the expected ``ValueError``; an error
    # raised while constructing the fixtures must fail the test instead of
    # being mistaken for the callback's exception.
    task_scheduler = RoundRobin(
        [],
        [],
        DummyBuilder(),
        DummyRunner(),
        DummyDatabase(),
        max_trials=1,
    )
    measure_candidates = [MeasureCandidate(Schedule(Matmul), None)]
    builds = [BuilderResult("test_build", None)]
    results = [RunnerResult([1.0, 2.1], None)]

    with pytest.raises(ValueError, match="test"):
        measure_callback.apply(
            task_scheduler,
            0,
            measure_candidates,
            builds,
            results,
        )
def test_meta_schedule_measure_callback():
    """Invoke a Python-defined measure callback and validate its inputs.

    The callback's ``apply`` asserts that it receives exactly one measure
    candidate whose schedule module is structurally equal to ``Matmul``, one
    build result with artifact path ``"test_build"`` and no error, and one
    runner result with two run times and no error.
    """

    @derived_object
    class FancyMeasureCallback(PyMeasureCallback):
        def apply(
            self,
            task_scheduler: TaskScheduler,
            task_id: int,
            measure_candidates: List[MeasureCandidate],
            builds: List[BuilderResult],
            results: List[RunnerResult],
        ) -> None:
            # Candidate checks.
            assert len(measure_candidates) == 1
            assert_structural_equal(measure_candidates[0].sch.mod, Matmul)

            # Build-result checks, one condition per assert, same order as
            # the values are inspected.
            assert len(builds) == 1
            assert builds[0].error_msg is None
            assert builds[0].artifact_path == "test_build"

            # Runner-result checks.
            assert len(results) == 1
            assert results[0].error_msg is None
            assert len(results[0].run_secs) == 2

    callback = FancyMeasureCallback()

    scheduler = RoundRobin(
        [],
        [],
        DummyBuilder(),
        DummyRunner(),
        DummyDatabase(),
        max_trials=1,
    )
    task_index = 0
    candidates = [MeasureCandidate(Schedule(Matmul), None)]
    build_results = [BuilderResult("test_build", None)]
    run_results = [RunnerResult([1.0, 2.1], None)]

    callback.apply(scheduler, task_index, candidates, build_results, run_results)
# Example #3
# 0
def test_meta_schedule_measure_callback_fail():
    """Check that an exception raised inside ``apply`` reaches the caller.

    A ``PyMeasureCallback`` subclass whose ``apply`` always raises
    ``ValueError("test")`` is invoked directly, and the test expects that
    error to surface from the call.
    """

    class FailingMeasureCallback(PyMeasureCallback):
        # ``apply`` ignores all of its arguments and fails unconditionally.
        def apply(
            self,
            task_scheduler: TaskScheduler,
            task_id: int,
            measure_candidates: List[MeasureCandidate],
            builds: List[BuilderResult],
            results: List[RunnerResult],
        ) -> None:
            raise ValueError("test")

    measure_callback = FailingMeasureCallback()

    # Build the arguments outside ``pytest.raises`` so that only the
    # ``apply`` call itself can satisfy the expected ``ValueError``; an error
    # raised while constructing the fixtures must fail the test instead of
    # being mistaken for the callback's exception.
    task_scheduler = TaskScheduler()
    measure_candidates = [MeasureCandidate(Schedule(Matmul), None)]
    builds = [BuilderResult("test_build", None)]
    results = [RunnerResult([1.0, 2.1], None)]

    with pytest.raises(ValueError, match="test"):
        measure_callback.apply(
            task_scheduler,
            0,
            measure_candidates,
            builds,
            results,
        )
# Example #4
# 0
 def build(
     self,
     build_inputs: List[BuilderInput],
 ) -> List[BuilderResult]:
     """Return one successful placeholder result per build input.

     Each element of ``build_inputs`` is ignored; the output simply holds a
     ``BuilderResult("test_path", None)`` for every input received.
     """
     artifact_path, error_msg = "test_path", None
     return [BuilderResult(artifact_path, error_msg) for _ in build_inputs]
# Example #5
# 0
 def build(  # pylint: disable=no-self-use
     self,
     build_inputs: List[BuilderInput],
 ) -> List[BuilderResult]:
     """Return one failed placeholder result per build input.

     Each element of ``build_inputs`` is ignored; the output simply holds a
     ``BuilderResult(None, "error")`` for every input received.
     """
     artifact_path, error_msg = None, "error"
     return [BuilderResult(artifact_path, error_msg) for _ in build_inputs]