示例#1
0
def test_meta_schedule_task_scheduler_single():
    """Round-robin tuning of a single matmul task fills the database to budget.

    ``AddToDatabase`` is the only measure callback; the test expects the
    database to hold exactly ``budget`` records once tuning finishes.
    """
    per_iter, budget = 3, 10
    space = ScheduleFn(sch_fn=_schedule_matmul)
    strategy = ReplayTrace(per_iter, budget)
    matmul_task = TuneContext(
        MatmulModule,
        target=tvm.target.Target("llvm"),
        space_generator=space,
        search_strategy=strategy,
        task_name="Test",
        rand_state=42,
    )
    db = DummyDatabase()
    scheduler = RoundRobin(
        [matmul_task],
        [1.0],  # one weight for the one task
        DummyBuilder(),
        DummyRunner(),
        db,
        measure_callbacks=[measure_callback.AddToDatabase()],
        max_trials=budget,
    )
    scheduler.tune()
    assert len(db) == budget
示例#2
0
def test_meta_schedule_task_scheduler_override_next_task_id_only():  # pylint: disable=invalid-name
    """Tune three tasks with ``MyTaskScheduler`` and check per-task record counts.

    ``MyTaskScheduler`` is not defined in this function (presumably a
    module-level test scheduler — confirm). The scheduler-wide budget is the
    sum of the per-task budgets, and every workload must end up with exactly
    ``trials_per_task`` records in the database.
    """
    trials_per_iter = 6
    trials_per_task = 101
    # (module, schedule function, task name, random seed) for each task.
    specs = [
        (MatmulModule, _schedule_matmul, "Matmul", 42),
        (MatmulReluModule, _schedule_matmul, "MatmulRelu", 0xDEADBEEF),
        (BatchMatmulModule, _schedule_batch_matmul, "BatchMatmul", 0x114514),
    ]
    tasks = [
        TuneContext(
            mod,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=sch),
            search_strategy=ReplayTrace(trials_per_iter, trials_per_task),
            task_name=name,
            rand_state=seed,
        )
        for mod, sch, name, seed in specs
    ]
    expected_total = trials_per_task * len(tasks)
    database = DummyDatabase()
    scheduler = MyTaskScheduler(
        tasks,
        DummyBuilder(),
        DummyRunner(),
        database,
        measure_callbacks=[measure_callback.AddToDatabase()],
        max_trials=expected_total,
    )
    scheduler.tune()
    assert len(database) == expected_total
    for task in tasks:
        workload = database.commit_workload(task.mod)
        records = database.get_top_k(workload, 100000)
        assert len(records) == trials_per_task
示例#3
0
def test_meta_schedule_task_scheduler_multiple():
    """Round-robin tuning of three tasks gives each task its full trial budget.

    Each task's ``ReplayTrace`` is capped at ``max_trials_per_task`` and the
    scheduler-wide budget is the sum over all tasks. After tuning, the database
    must hold ``max_trials_per_task`` records per workload.
    """
    num_trials_per_iter = 6
    max_trials_per_task = 101
    tasks = [
        TuneContext(
            MatmulModule,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=_schedule_matmul),
            search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task),
            task_name="Matmul",
            rand_state=42,
        ),
        TuneContext(
            MatmulReluModule,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=_schedule_matmul),
            search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task),
            task_name="MatmulRelu",
            rand_state=0xDEADBEEF,
        ),
        TuneContext(
            BatchMatmulModule,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=_schedule_batch_matmul),
            search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task),
            task_name="BatchMatmul",
            rand_state=0x114514,
        ),
    ]
    total_trials = max_trials_per_task * len(tasks)
    database = DummyDatabase()
    round_robin = RoundRobin(
        tasks,
        # One weight per task. The previous `[1.0]` supplied a single weight
        # for three tasks, so the weights did not line up with `tasks`.
        [1.0] * len(tasks),
        DummyBuilder(),
        DummyRunner(),
        database,
        measure_callbacks=[measure_callback.AddToDatabase()],
        max_trials=total_trials,
    )
    round_robin.tune()
    assert len(database) == total_trials
    for task in tasks:
        workload = database.commit_workload(task.mod)
        assert len(database.get_top_k(workload, 100000)) == max_trials_per_task
def test_meta_schedule_task_scheduler_avoid_cyclic():  # pylint: disable=invalid-name
    """A task scheduler with no tasks is reclaimed once its last reference goes.

    A weak reference is taken to the scheduler and the only strong local
    reference is deleted; the weak reference is then expected to be dead,
    which would fail if the scheduler were kept alive by a reference cycle.
    """
    db = DummyDatabase()
    scheduler = MyTaskScheduler(
        [],
        DummyBuilder(),
        DummyRunner(),
        db,
        measure_callbacks=[measure_callback.AddToDatabase()],
    )
    scheduler_ref = weakref.ref(scheduler)
    del scheduler
    # No cycle should hold the scheduler, so the weakref must now be dead.
    assert scheduler_ref() is None
示例#5
0
def callbacks(  # pylint: disable=redefined-outer-name
    measure_callbacks: Optional[List[MeasureCallback]],
) -> List[MeasureCallback]:
    """Normalize the input to List[tvm.meta_schedule.MeasureCallback].

    Parameters
    ----------
    measure_callbacks : Optional[List[MeasureCallback]]
        ``None`` selects the default callbacks (AddToDatabase,
        RemoveBuildArtifact, EchoStatistics, UpdateCostModel, in that order).
        Otherwise it must be a list or tuple whose items are all
        ``MeasureCallback`` instances.

    Returns
    -------
    List[MeasureCallback]
        A new list of callbacks.

    Raises
    ------
    TypeError
        If the argument is not a list/tuple, or any item is not a
        ``MeasureCallback``.
    """
    if measure_callbacks is not None:
        if not isinstance(measure_callbacks, (list, tuple)):
            raise TypeError(
                "Expected `measure_callbacks` to be List[MeasureCallback], "
                f"but gets: {measure_callbacks}"
            )
        result = list(measure_callbacks)
        for index, item in enumerate(result):
            if isinstance(item, MeasureCallback):
                continue
            raise TypeError(
                "Expected `measure_callbacks` to be List[MeasureCallback], "
                f"but measure_callbacks[{index}] is: {item}"
            )
        return result

    # Imported lazily, only when the default set is requested.
    from tvm.meta_schedule import measure_callback as M  # pylint: disable=import-outside-toplevel

    defaults = [
        M.AddToDatabase(),
        M.RemoveBuildArtifact(),
        M.EchoStatistics(),
        M.UpdateCostModel(),
    ]
    return defaults
示例#6
0
def test_meta_schedule_task_scheduler_override_next_task_id_only():  # pylint: disable=invalid-name
    """Tune three tasks with a Python scheduler that only overrides ``next_task_id``.

    The local scheduler picks task indices at random until every task reports
    ``is_stopped``. After tuning, the database must hold
    ``num_trials_total`` records per task.
    """

    class MyTaskScheduler(PyTaskScheduler):
        # Indices of tasks already seen as stopped. This is a class attribute,
        # but the class is re-created on every call of the enclosing test, so
        # the set does not carry over between calls.
        done = set()

        def next_task_id(self) -> int:
            """Return a random not-yet-stopped task index, or -1 when all are stopped."""
            while len(self.done) != len(tasks):
                x = random.randint(0, len(tasks) - 1)
                task = tasks[x]
                if task.is_stopped:
                    self.done.add(x)
                    continue
                # Calling base func via the following route:
                # Python side:
                #   PyTaskScheduler does not define `_is_task_running`, so
                #   TaskScheduler's `is_task_running` is used, which calls ffi.
                # C++ side:
                #   The ffi calls TaskScheduler's `is_task_running`, but it is
                #   overridden in PyTaskScheduler. PyTaskScheduler checks whether
                #   the function is overridden in Python; if not, it falls back
                #   to the TaskScheduler vtable, calling
                #   TaskScheduler::IsTaskRunning.
                if self._is_task_running(x):
                    # `_join_running_task` follows the same route as above.
                    self._join_running_task(x)
                return x
            return -1

    num_trials_per_iter = 6
    num_trials_total = 101
    tasks = [
        TuneContext(
            MatmulModule,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=_schedule_matmul),
            search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total),
            task_name="Matmul",
            rand_state=42,
        ),
        TuneContext(
            MatmulReluModule,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=_schedule_matmul),
            search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total),
            task_name="MatmulRelu",
            rand_state=0xDEADBEEF,
        ),
        TuneContext(
            BatchMatmulModule,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=_schedule_batch_matmul),
            search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total),
            task_name="BatchMatmul",
            rand_state=0x114514,
        ),
    ]
    database = DummyDatabase()
    scheduler = MyTaskScheduler(
        tasks,
        DummyBuilder(),
        DummyRunner(),
        database,
        measure_callbacks=[
            measure_callback.AddToDatabase(),
        ],
    )
    scheduler.tune()
    assert len(database) == num_trials_total * len(tasks)
    for task in tasks:
        workload = database.commit_workload(task.mod)
        assert len(database.get_top_k(workload, 100000)) == num_trials_total