예제 #1
0
def test_meta_schedule_space_generator_schedule_fn():
    """Check that a ``ScheduleFn`` space generator produces one design space.

    The generator wraps ``schedule_matmul`` and is run on ``Matmul``; the
    single resulting schedule is handed to ``_check_correct``.
    """
    generator = ScheduleFn(sch_fn=schedule_matmul)
    spaces = generator.generate_design_space(Matmul)
    assert len(spaces) == 1
    # Unpacking doubles as a second check that exactly one space came back.
    [only_space] = spaces
    _check_correct(only_space)
def test_meta_schedule_replay_func(
        TestClass: SearchStrategy):  # pylint: disable = invalid-name
    """Drive a search strategy by hand and check its per-iteration batch sizes.

    ``TestClass`` is constructed with 7 trials per iteration and a budget of
    20 trials. The loop records how many candidates each call to
    ``generate_measure_candidates`` returns and expects ``[7, 7, 6]``.
    """
    trials_per_iter = 7
    trial_budget = 20

    strategy = TestClass(num_trials_per_iter=trials_per_iter,
                         max_trials_per_task=trial_budget)
    context = TuneContext(mod=Matmul,
                          space_generator=ScheduleFn(sch_fn=_schedule_matmul))
    context.space_generator.initialize_with_tune_context(context)
    design_spaces = context.space_generator.generate_design_space(context.mod)

    strategy.initialize_with_tune_context(context)
    strategy.pre_tuning(design_spaces)

    # Reference schedule generated independently of the tuning context.
    reference_generator = ScheduleFn(sch_fn=_schedule_matmul)
    (correct_sch, ) = reference_generator.generate_design_space(Matmul)

    batch_sizes: List[int] = []
    while True:
        candidates = strategy.generate_measure_candidates()
        if candidates is None:
            break
        batch_sizes.append(len(candidates))
        runner_results: List[RunnerResult] = []
        for candidate in candidates:
            # NOTE(review): the return value of `_is_trace_equal` is discarded,
            # so this call does not currently verify anything -- confirm
            # whether an assertion was intended.
            _is_trace_equal(
                candidate.sch,
                correct_sch,
                remove_decisions=isinstance(strategy, ReplayTrace),
            )
            # Every candidate is reported with the same fixed timings.
            fake_result = RunnerResult(run_secs=[0.11, 0.41, 0.54],
                                       error_msg=None)
            runner_results.append(fake_result)
        strategy.notify_runner_results(context, candidates, runner_results)
    strategy.post_tuning()
    assert batch_sizes == [7, 7, 6]
예제 #3
0
def test_meta_schedule_task_scheduler_override_next_task_id_only():  # pylint: disable=invalid-name
    """Tune three tasks with ``MyTaskScheduler`` and count database records.

    After tuning, the database must hold ``max_trials_per_task`` records for
    each task and ``max_trials_per_task * len(tasks)`` records overall.
    """
    num_trials_per_iter = 6
    max_trials_per_task = 101

    # (module, schedule function, task name, random seed) for each workload.
    task_specs = [
        (MatmulModule, _schedule_matmul, "Matmul", 42),
        (MatmulReluModule, _schedule_matmul, "MatmulRelu", 0xDEADBEEF),
        (BatchMatmulModule, _schedule_batch_matmul, "BatchMatmul", 0x114514),
    ]
    tasks = [
        TuneContext(
            module,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=schedule_fn),
            search_strategy=ReplayTrace(num_trials_per_iter,
                                        max_trials_per_task),
            task_name=name,
            rand_state=seed,
        )
        for module, schedule_fn, name, seed in task_specs
    ]

    database = DummyDatabase()
    scheduler = MyTaskScheduler(
        tasks,
        DummyBuilder(),
        DummyRunner(),
        database,
        measure_callbacks=[measure_callback.AddToDatabase()],
        max_trials=max_trials_per_task * len(tasks),
    )
    scheduler.tune()

    assert len(database) == max_trials_per_task * len(tasks)
    for task in tasks:
        workload = database.commit_workload(task.mod)
        # 100000 is far above the per-task budget, so every record is returned.
        records = database.get_top_k(workload, 100000)
        assert len(records) == max_trials_per_task
예제 #4
0
def test_meta_schedule_task_scheduler_multiple():
    """Tune three tasks with ``RoundRobin`` and count database records.

    After tuning, the database must hold ``max_trials_per_task`` records for
    each task and ``max_trials_per_task * len(tasks)`` records overall.
    """
    num_trials_per_iter = 6
    max_trials_per_task = 101

    def _make_task(module, schedule_fn, name, seed):
        """Build one llvm ``TuneContext`` that searches with ``ReplayTrace``."""
        return TuneContext(
            module,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=schedule_fn),
            search_strategy=ReplayTrace(num_trials_per_iter,
                                        max_trials_per_task),
            task_name=name,
            rand_state=seed,
        )

    tasks = [
        _make_task(MatmulModule, _schedule_matmul, "Matmul", 42),
        _make_task(MatmulReluModule, _schedule_matmul, "MatmulRelu",
                   0xDEADBEEF),
        _make_task(BatchMatmulModule, _schedule_batch_matmul, "BatchMatmul",
                   0x114514),
    ]

    database = DummyDatabase()
    # NOTE(review): a single weight is supplied for three tasks -- confirm
    # that RoundRobin accepts a weight list shorter than the task list.
    round_robin = RoundRobin(
        tasks,
        [1.0],
        DummyBuilder(),
        DummyRunner(),
        database,
        measure_callbacks=[measure_callback.AddToDatabase()],
        max_trials=max_trials_per_task * len(tasks),
    )
    round_robin.tune()

    assert len(database) == max_trials_per_task * len(tasks)
    for task in tasks:
        workload = database.commit_workload(task.mod)
        records = database.get_top_k(workload, 100000)
        assert len(records) == max_trials_per_task
예제 #5
0
def test_meta_schedule_task_scheduler_single():
    """Tune a single task with ``RoundRobin`` and count database records.

    The task searches with ``ReplayTrace`` (3 trials per iteration, 10 trials
    in total); the database must end up with exactly 10 records.
    """
    num_trials_per_iter = 3
    max_trials_per_task = 10

    task = TuneContext(
        MatmulModule,
        target=tvm.target.Target("llvm"),
        space_generator=ScheduleFn(sch_fn=_schedule_matmul),
        search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task),
        task_name="Test",
        rand_state=42,
    )

    database = DummyDatabase()
    scheduler = RoundRobin(
        [task],
        [1.0],
        DummyBuilder(),
        DummyRunner(),
        database,
        measure_callbacks=[measure_callback.AddToDatabase()],
        max_trials=max_trials_per_task,
    )
    scheduler.tune()

    assert len(database) == max_trials_per_task
def test_meta_schedule_replay_trace():
    """Drive ``ReplayTrace`` by hand and check its per-round batch sizes.

    With 7 trials per iteration and 20 trials in total, the strategy is
    expected to return batches of ``[7, 7, 6]`` candidates. Every candidate
    schedule must satisfy ``_is_trace_equal`` against the example schedule.
    """
    num_trials_per_iter = 7
    num_trials_total = 20

    generator = ScheduleFn(sch_fn=_schedule_matmul)
    (example_sch, ) = generator.generate_design_space(Matmul)
    replay = ReplayTrace(num_trials_per_iter=num_trials_per_iter,
                         num_trials_total=num_trials_total)
    replay.initialize_with_tune_context(TuneContext(mod=Matmul))

    batch_sizes: List[int] = []
    replay.pre_tuning([example_sch])
    candidates = replay.generate_measure_candidates()
    while candidates is not None:
        batch_sizes.append(len(candidates))
        runner_results: List[RunnerResult] = []
        for candidate in candidates:
            assert _is_trace_equal(candidate.sch, example_sch)
            # Every candidate is reported with the same fixed timings.
            fake_result = RunnerResult(run_secs=[0.5, 0.4, 0.3],
                                       error_msg=None)
            runner_results.append(fake_result)
        replay.notify_runner_results(runner_results)
        candidates = replay.generate_measure_candidates()
    replay.post_tuning()
    assert batch_sizes == [7, 7, 6]
예제 #7
0
def test_meta_schedule_design_space_generator_union():
    """Check that a union of two generators yields two design spaces.

    The same ``ScheduleFn`` generator is listed twice in the union; each of
    the two resulting design spaces is handed to ``_check_correct``.
    """
    base_generator = ScheduleFn(sch_fn=schedule_matmul)
    union = SpaceGeneratorUnion([base_generator, base_generator])
    spaces = union.generate_design_space(Matmul)
    assert len(spaces) == 2
    for space in spaces:
        _check_correct(space)
예제 #8
0
def test_meta_schedule_task_scheduler_multiple():
    """Tune three tasks with ``RoundRobin`` and count database records.

    Each task searches with ``ReplayTrace`` (6 trials per iteration, 101
    trials in total). After tuning, the database must hold
    ``num_trials_total`` records for each task and
    ``num_trials_total * len(tasks)`` records overall.
    """
    num_trials_per_iter = 6
    num_trials_total = 101
    # Integer upper bound for `get_top_k`: large enough to return every record.
    all_records = 10**9

    # (module, schedule function, task name, random seed) for each workload.
    task_specs = [
        (MatmulModule, _schedule_matmul, "Matmul", 42),
        (MatmulReluModule, _schedule_matmul, "MatmulRelu", 0xDEADBEEF),
        (BatchMatmulModule, _schedule_batch_matmul, "BatchMatmul", 0x114514),
    ]
    tasks = [
        TuneContext(
            module,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=schedule_fn),
            search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total),
            task_name=name,
            rand_state=seed,
        )
        for module, schedule_fn, name, seed in task_specs
    ]

    database = DummyDatabase()
    round_robin = RoundRobin(tasks, DummyBuilder(), DummyRunner(), database)
    round_robin.tune()

    assert len(database) == num_trials_total * len(tasks)
    for task in tasks:
        workload = database.commit_workload(task.mod)
        assert len(database.get_top_k(workload, all_records)) == num_trials_total
예제 #9
0
def test_meta_schedule_task_scheduler_single():
    """Tune a single task with ``RoundRobin`` and count database records.

    The task searches with ``ReplayTrace`` (3 trials per iteration, 10 trials
    in total); the database must end up with exactly 10 records.
    """
    num_trials_per_iter = 3
    num_trials_total = 10

    task = TuneContext(
        MatmulModule,
        target=tvm.target.Target("llvm"),
        space_generator=ScheduleFn(sch_fn=_schedule_matmul),
        search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total),
        task_name="Test",
        rand_state=42,
    )

    database = DummyDatabase()
    scheduler = RoundRobin([task], DummyBuilder(), DummyRunner(), database)
    scheduler.tune()

    assert len(database) == num_trials_total
예제 #10
0
def test_meta_schedule_task_scheduler_override_next_task_id_only():
    """Tune three tasks with a scheduler that overrides only ``next_task_id``.

    ``MyTaskScheduler`` picks tasks at random and leaves every other method to
    its ``PyTaskScheduler`` base. After tuning, the database must hold
    ``num_trials_total`` records for each task and
    ``num_trials_total * len(tasks)`` records overall.
    """

    class MyTaskScheduler(PyTaskScheduler):
        """Task scheduler that selects the next task uniformly at random."""

        # Indices of tasks that were observed as stopped.
        done = set()

        def next_task_id(self) -> int:
            """Return the index of a random unfinished task, or -1 if none remain."""
            while len(self.done) != len(tasks):
                x = random.randint(0, len(tasks) - 1)
                if tasks[x].is_stopped:
                    self.done.add(x)
                    continue
                # Calling the base function via the following route:
                # Python side:
                #     PyTaskScheduler does not have `_is_task_running`
                #     Call TaskScheduler's `is_task_running`, which calls ffi
                # C++ side:
                #     The ffi calls TaskScheduler's `is_task_running`
                #     But it is overridden in PyTaskScheduler
                #     PyTaskScheduler checks if the function is overridden in python
                #     If not, it returns the TaskScheduler's vtable, calling
                #         TaskScheduler::IsTaskRunning
                if self._is_task_running(x):
                    # `_join_running_task` is resolved through the same route.
                    self._join_running_task(x)
                return x
            return -1

    num_trials_per_iter = 6
    num_trials_total = 101
    # Integer upper bound for `get_top_k`: large enough to return every record.
    all_records = 10**9
    tasks = [
        TuneContext(
            MatmulModule,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=_schedule_matmul),
            search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total),
            task_name="Matmul",
            rand_state=42,
        ),
        TuneContext(
            MatmulReluModule,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=_schedule_matmul),
            search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total),
            task_name="MatmulRelu",
            rand_state=0xDEADBEEF,
        ),
        TuneContext(
            BatchMatmulModule,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=_schedule_batch_matmul),
            search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total),
            task_name="BatchMatmul",
            rand_state=0x114514,
        ),
    ]
    database = DummyDatabase()
    scheduler = MyTaskScheduler(tasks, DummyBuilder(), DummyRunner(), database)
    scheduler.tune()
    assert len(database) == num_trials_total * len(tasks)
    for task in tasks:
        workload = database.commit_workload(task.mod)
        assert len(database.get_top_k(workload, all_records)) == num_trials_total
예제 #11
0
def test_meta_schedule_evolutionary_search_early_stop():  # pylint: disable=invalid-name
    """Drive ``EvolutionarySearch`` over an empty schedule function.

    The design space comes from a schedule function that returns the schedule
    untouched. The strategy is expected to return batches of
    ``[1, 0, 0, 0, 0]`` candidates before it reports ``None``.
    """

    def _schedule_matmul_empty(sch: Schedule):
        """Return the schedule unchanged."""
        return sch

    num_trials_per_iter = 10
    max_trials_per_task = 100

    strategy = EvolutionarySearch(
        num_trials_per_iter=num_trials_per_iter,
        max_trials_per_task=max_trials_per_task,
        population_size=5,
        init_measured_ratio=0.1,
        init_min_unmeasured=50,
        genetic_num_iters=3,
        genetic_mutate_prob=0.5,
        genetic_max_fail_count=10,
        eps_greedy=0.9,
    )
    context = TuneContext(
        mod=Matmul,
        space_generator=ScheduleFn(sch_fn=_schedule_matmul_empty),
        mutator_probs={DummyMutator(): 1.0},
        target=tvm.target.Target("llvm"),
        num_threads=1,  # because we are using a mutator from the python side
    )
    # The scheduler is only constructed, never asked to tune; it is kept
    # referenced until the explicit `del` at the end of the test.
    _scheduler = RoundRobin(
        tasks=[context],
        task_weights=[1.0],
        builder=ms.builder.LocalBuilder(),
        runner=ms.runner.LocalRunner(),
        database=DummyDatabase(),
        cost_model=ms.cost_model.RandomModel(),
        measure_callbacks=[],
        max_trials=1,
    )
    context.space_generator.initialize_with_tune_context(context)
    design_spaces = context.space_generator.generate_design_space(context.mod)

    strategy.initialize_with_tune_context(context)
    strategy.pre_tuning(design_spaces)

    reference_generator = ScheduleFn(sch_fn=_schedule_matmul)
    (correct_sch, ) = reference_generator.generate_design_space(Matmul)

    batch_sizes: List[int] = []
    while True:
        candidates = strategy.generate_measure_candidates()
        if candidates is None:
            break
        batch_sizes.append(len(candidates))
        runner_results: List[RunnerResult] = []
        for candidate in candidates:
            # NOTE(review): the return value of `_is_trace_equal` is discarded,
            # so this call does not currently verify anything.
            _is_trace_equal(
                candidate.sch,
                correct_sch,
                remove_decisions=isinstance(strategy, ReplayTrace),
            )
            fake_result = RunnerResult(run_secs=[0.11, 0.41, 0.54],
                                       error_msg=None)
            runner_results.append(fake_result)
        strategy.notify_runner_results(context, candidates, runner_results)
    strategy.post_tuning()
    assert batch_sizes == [1, 0, 0, 0, 0]
    del _scheduler
예제 #12
0
def test_meta_schedule_evolutionary_search():  # pylint: disable=invalid-name
    """Drive ``EvolutionarySearch`` by hand and check its batch sizes.

    The strategy is configured with 10 trials per iteration and 100 trials in
    total. Every batch is expected to hold ``num_trials_per_iter`` candidates,
    followed by one shorter batch when the total is not an exact multiple.
    """

    @derived_object
    class DummyMutator(PyMutator):
        """Dummy Mutator for testing"""

        def initialize_with_tune_context(self, context: "TuneContext") -> None:
            """Do nothing; this mutator needs no context."""

        def apply(self, trace: Trace, _) -> Optional[Trace]:
            """Return a new trace with the same instructions and no decisions."""
            return Trace(trace.insts, {})

    @derived_object
    class DummyDatabase(PyDatabase):
        """Dummy Database for testing"""

        def __init__(self):
            super().__init__()
            self.records = []  # committed tuning records, in commit order
            self.workload_reg = []  # registered workloads, in commit order

        def _find_workload(self, mod: IRModule) -> Optional[Workload]:
            """Return the registered workload structurally equal to ``mod``, if any."""
            for workload in self.workload_reg:
                if tvm.ir.structural_equal(workload.mod, mod):
                    return workload
            return None

        def has_workload(self, mod: IRModule) -> bool:
            """Tell whether a structurally equal module is already registered."""
            return self._find_workload(mod) is not None

        def commit_tuning_record(self, record: TuningRecord) -> None:
            """Append ``record`` to the in-memory record list."""
            self.records.append(record)

        def commit_workload(self, mod: IRModule) -> Workload:
            """Return the workload for ``mod``, registering it on first use."""
            workload = self._find_workload(mod)
            if workload is None:
                workload = Workload(mod)
                self.workload_reg.append(workload)
            return workload

        def get_top_k(self, workload: Workload,
                      top_k: int) -> List[TuningRecord]:
            """Return up to ``top_k`` records of ``workload``, lowest mean run time first."""
            # Filter before sorting so the mean is only computed for records
            # of the requested workload; the sort is stable, so the resulting
            # order matches sorting everything and filtering afterwards.
            matching = [r for r in self.records if r.workload == workload]
            matching.sort(key=lambda r: sum(r.run_secs) / len(r.run_secs))
            return matching[:int(top_k)]

        def __len__(self) -> int:
            """Return the number of committed tuning records."""
            return len(self.records)

        def print_results(self) -> None:
            """Print every committed record, one per line."""
            print("\n".join([str(r) for r in self.records]))

    num_trials_per_iter = 10
    num_trials_total = 100

    strategy = EvolutionarySearch(
        num_trials_per_iter=num_trials_per_iter,
        num_trials_total=num_trials_total,
        population_size=5,
        init_measured_ratio=0.1,
        init_min_unmeasured=50,
        genetic_num_iters=3,
        genetic_mutate_prob=0.5,
        genetic_max_fail_count=10,
        eps_greedy=0.9,
    )
    context = TuneContext(
        mod=Matmul,
        space_generator=ScheduleFn(sch_fn=_schedule_matmul),
        mutator_probs={
            DummyMutator(): 1.0,
        },
        target=tvm.target.Target("llvm"),
        num_threads=1,  # because we are using a mutator from the python side
    )
    # The scheduler is only constructed, never asked to tune; it is kept
    # referenced until the explicit `del` at the end of the test.
    _scheduler = RoundRobin(
        tasks=[context],
        builder=LocalBuilder(),
        runner=LocalRunner(),
        database=DummyDatabase(),
        cost_model=RandomModel(),
        measure_callbacks=[],
    )
    context.space_generator.initialize_with_tune_context(context)
    spaces = context.space_generator.generate_design_space(context.mod)

    strategy.initialize_with_tune_context(context)
    strategy.pre_tuning(spaces)
    (correct_sch, ) = ScheduleFn(
        sch_fn=_schedule_matmul).generate_design_space(Matmul)
    num_trials_each_iter: List[int] = []
    candidates = strategy.generate_measure_candidates()
    while candidates is not None:
        num_trials_each_iter.append(len(candidates))
        runner_results: List[RunnerResult] = []
        for candidate in candidates:
            # NOTE(review): the return value of `_is_trace_equal` is discarded,
            # so this call does not currently verify anything.
            _is_trace_equal(
                candidate.sch,
                correct_sch,
                remove_decisions=(isinstance(strategy, ReplayTrace)),
            )
            runner_results.append(
                RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None))
        strategy.notify_runner_results(context, candidates, runner_results)
        candidates = strategy.generate_measure_candidates()
    strategy.post_tuning()

    # Full batches of `num_trials_per_iter`, plus one partial batch if the
    # total budget is not an exact multiple of the batch size.
    full_batches, remainder = divmod(num_trials_total, num_trials_per_iter)
    expected = [num_trials_per_iter] * full_batches
    if remainder:
        expected.append(remainder)
    assert num_trials_each_iter == expected
    del _scheduler