def test_meta_schedule_task_scheduler_single():
    """Tune a single task with RoundRobin and check the database size.

    One ``TuneContext`` over ``MatmulModule`` is handed to a ``RoundRobin``
    scheduler whose ``max_trials`` equals the per-task trial budget. After
    ``tune()`` the database is expected to hold exactly that many entries.
    """
    trials_per_iter = 3
    trial_budget = 10
    database = ms.database.MemoryDatabase()

    # Build the pieces of the tuning context in the order they are consumed.
    llvm_target = tvm.target.Target("llvm")
    space = ms.space_generator.ScheduleFn(sch_fn=_schedule_matmul)
    strategy = ms.search_strategy.ReplayTrace(trials_per_iter, trial_budget)
    context = ms.TuneContext(
        MatmulModule,
        target=llvm_target,
        space_generator=space,
        search_strategy=strategy,
        task_name="Test",
        rand_state=42,
    )

    scheduler = ms.task_scheduler.RoundRobin(
        [context],
        [1.0],
        builder=DummyBuilder(),
        runner=DummyRunner(),
        database=database,
        measure_callbacks=[ms.measure_callback.AddToDatabase()],
        max_trials=trial_budget,
    )
    scheduler.tune()

    assert len(database) == trial_budget
# Example #2
# 0
def test_meta_schedule_measure_callback_fail():
    """Check that a ValueError raised in a Python measure callback propagates.

    ``FailingMeasureCallback.apply`` always raises ``ValueError("test")``;
    calling it directly must surface that exception to the caller.
    """

    @ms.derived_object
    class FailingMeasureCallback(ms.measure_callback.PyMeasureCallback):
        """Measure callback whose ``apply`` unconditionally fails."""

        def apply(
            self,
            task_scheduler: ms.task_scheduler.TaskScheduler,
            task_id: int,
            measure_candidates: List[ms.MeasureCandidate],
            builder_results: List[ms.builder.BuilderResult],
            runner_results: List[ms.runner.RunnerResult],
        ) -> None:
            raise ValueError("test")

    callback = FailingMeasureCallback()
    with pytest.raises(ValueError, match="test"):
        # The arguments are built inside the guarded scope, exactly as when
        # they are written inline in the call expression.
        scheduler = ms.task_scheduler.RoundRobin(
            tasks=[],
            task_weights=[],
            builder=DummyBuilder(),
            runner=DummyRunner(),
            database=ms.database.MemoryDatabase(),
            max_trials=1,
        )
        candidates = [ms.MeasureCandidate(Schedule(Matmul), None)]
        build_results = [ms.builder.BuilderResult("test_build", None)]
        run_results = [ms.runner.RunnerResult([1.0, 2.1], None)]
        callback.apply(scheduler, 0, candidates, build_results, run_results)
def test_meta_schedule_task_scheduler_multiple():
    """Tune three tasks with RoundRobin and check per-task record counts.

    Each task gets the same trial budget; the scheduler's ``max_trials`` is
    the budget times the number of tasks. After ``tune()`` the database must
    contain the full total, and the top-k query for each task's workload
    must return exactly one budget's worth of records.
    """
    trials_per_iter = 6
    trials_per_task = 101

    # (module, schedule function, task name, random seed) for every task.
    task_specs = (
        (MatmulModule, _schedule_matmul, "Matmul", 42),
        (MatmulReluModule, _schedule_matmul, "MatmulRelu", 0xDEADBEEF),
        (BatchMatmulModule, _schedule_batch_matmul, "BatchMatmul", 0x114514),
    )
    tasks = [
        ms.TuneContext(
            mod,
            target=tvm.target.Target("llvm"),
            space_generator=ms.space_generator.ScheduleFn(sch_fn=sch_fn),
            search_strategy=ms.search_strategy.ReplayTrace(
                trials_per_iter,
                trials_per_task,
            ),
            task_name=name,
            rand_state=seed,
        )
        for mod, sch_fn, name, seed in task_specs
    ]
    total_trials = trials_per_task * len(tasks)

    database = ms.database.MemoryDatabase()
    scheduler = ms.task_scheduler.RoundRobin(
        tasks,
        [1.0] * len(tasks),
        builder=DummyBuilder(),
        runner=DummyRunner(),
        database=database,
        measure_callbacks=[ms.measure_callback.AddToDatabase()],
        max_trials=total_trials,
    )
    scheduler.tune()

    assert len(database) == total_trials
    for task in tasks:
        workload = database.commit_workload(task.mod)
        # Ask for far more records than any task can have so none are cut off.
        records = database.get_top_k(workload, 100000)
        assert len(records) == trials_per_task
def test_meta_schedule_task_scheduler_avoid_cyclic():  # pylint: disable=invalid-name
    """Check that a ``MyTaskScheduler`` is released when its name is deleted.

    A weak reference is taken to the scheduler, the only local name bound to
    it is deleted, and the weak reference must then resolve to ``None``.
    """
    db = ms.database.MemoryDatabase()
    sched = MyTaskScheduler(
        [],
        builder=DummyBuilder(),
        runner=DummyRunner(),
        database=db,
        measure_callbacks=[ms.measure_callback.AddToDatabase()],
        max_trials=10,
    )

    # A weak reference does not keep the scheduler alive on its own.
    sched_ref = weakref.ref(sched)
    del sched

    assert sched_ref() is None
def test_meta_schedule_task_scheduler_NIE():  # pylint: disable=invalid-name
    """Check the error raised by a scheduler that overrides nothing.

    ``NIETaskScheduler`` derives from ``PyTaskScheduler`` without defining
    any method; building it and calling ``next_task_id()`` is expected to
    raise ``TVMError`` with the "not implemented" message.
    """

    @ms.derived_object
    class NIETaskScheduler(ms.task_scheduler.PyTaskScheduler):
        # Intentionally empty: no scheduler method is overridden.
        pass

    expected_msg = "PyTaskScheduler's NextTaskId method not implemented!"
    with pytest.raises(TVMError, match=expected_msg):
        NIETaskScheduler(
            tasks=[],
            builder=DummyBuilder(),
            runner=DummyRunner(),
            database=ms.database.MemoryDatabase(),
            max_trials=1,
        ).next_task_id()
# Example #6
# 0
def test_meta_schedule_measure_callback():
    """Check that a Python measure callback receives the arguments passed in.

    ``FancyMeasureCallback.apply`` asserts on the single candidate, builder
    result and runner result it is handed; the test calls it directly with
    matching values.
    """

    @ms.derived_object
    class FancyMeasureCallback(ms.measure_callback.PyMeasureCallback):
        """Measure callback that validates the inputs given to ``apply``."""

        def apply(
            self,
            task_scheduler: ms.task_scheduler.TaskScheduler,
            task_id: int,
            measure_candidates: List[ms.MeasureCandidate],
            builder_results: List[ms.builder.BuilderResult],
            runner_results: List[ms.runner.RunnerResult],
        ) -> None:
            # Exactly one candidate, whose module matches ``Matmul``.
            assert len(measure_candidates) == 1
            candidate_mod = measure_candidates[0].sch.mod
            tvm.ir.assert_structural_equal(candidate_mod, Matmul)

            # Exactly one successful build pointing at the expected artifact.
            assert len(builder_results) == 1
            build = builder_results[0]
            assert build.error_msg is None
            assert build.artifact_path == "test_build"

            # Exactly one successful run carrying two timing samples.
            assert len(runner_results) == 1
            run = runner_results[0]
            assert run.error_msg is None
            assert len(run.run_secs) == 2

    callback = FancyMeasureCallback()
    scheduler = ms.task_scheduler.RoundRobin(
        tasks=[],
        task_weights=[],
        builder=DummyBuilder(),
        runner=DummyRunner(),
        database=ms.database.MemoryDatabase(),
        max_trials=1,
    )
    candidates = [ms.MeasureCandidate(Schedule(Matmul), None)]
    build_results = [ms.builder.BuilderResult("test_build", None)]
    run_results = [ms.runner.RunnerResult([1.0, 2.1], None)]
    callback.apply(scheduler, 0, candidates, build_results, run_results)