Ejemplo n.º 1
0
def test_meta_schedule_integration_apply_history_best():
    """ApplyHistoryBest.query returns the module of the committed workload.

    A single tuning record for ``MockModule`` is stored in a dummy database.
    Querying with ``MockModule`` as the dispatched module must give back a
    module structurally equal to that workload's module.
    """
    network_mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
    db = DummyDatabase()
    history_best = ApplyHistoryBest(db)
    llvm_target = Target("llvm")
    recorded_workload = db.commit_workload(MockModule)
    mock_trace = Schedule(MockModule).trace
    record = TuningRecord(mock_trace, [1.0], recorded_workload, llvm_target, [])
    db.commit_tuning_record(record)
    queried_mod = history_best.query(
        task_name="mock-task",
        mod=network_mod,
        target=llvm_target,
        dispatched=[MockModule],
    )
    assert tvm.ir.structural_equal(queried_mod, recorded_workload.mod)
Ejemplo n.º 2
0
def test_meta_schedule_task_scheduler_override_next_task_id_only():  # pylint: disable=invalid-name
    """A custom scheduler still gives every task its full trial budget.

    Judging by the test name, ``MyTaskScheduler`` (defined elsewhere) is
    expected to override only ``next_task_id``; confirm against its
    definition. After tuning, the database must hold exactly
    ``per_task_budget`` records per task and no more in total.
    """
    per_iter = 6
    per_task_budget = 101
    # (module, schedule function, task name, random seed) for each task.
    task_specs = [
        (MatmulModule, _schedule_matmul, "Matmul", 42),
        (MatmulReluModule, _schedule_matmul, "MatmulRelu", 0xDEADBEEF),
        (BatchMatmulModule, _schedule_batch_matmul, "BatchMatmul", 0x114514),
    ]
    tasks = [
        TuneContext(
            module,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=sched_fn),
            search_strategy=ReplayTrace(per_iter, per_task_budget),
            task_name=name,
            rand_state=seed,
        )
        for module, sched_fn, name, seed in task_specs
    ]
    db = DummyDatabase()
    total_budget = per_task_budget * len(tasks)
    scheduler = MyTaskScheduler(
        tasks,
        DummyBuilder(),
        DummyRunner(),
        db,
        measure_callbacks=[measure_callback.AddToDatabase()],
        max_trials=total_budget,
    )
    scheduler.tune()
    assert len(db) == total_budget
    for ctx in tasks:
        workload = db.commit_workload(ctx.mod)
        top_records = db.get_top_k(workload, 100000)
        assert len(top_records) == per_task_budget
Ejemplo n.º 3
0
def test_meta_schedule_task_scheduler_multiple():
    """RoundRobin over three tasks gives each task exactly its trial budget.

    After tuning, the database must hold ``per_task_budget`` records for every
    task and ``per_task_budget * len(tasks)`` records overall.
    """
    per_iter = 6
    per_task_budget = 101
    # (module, schedule function, task name, random seed) for each task.
    task_specs = [
        (MatmulModule, _schedule_matmul, "Matmul", 42),
        (MatmulReluModule, _schedule_matmul, "MatmulRelu", 0xDEADBEEF),
        (BatchMatmulModule, _schedule_batch_matmul, "BatchMatmul", 0x114514),
    ]
    tasks = [
        TuneContext(
            module,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=sched_fn),
            search_strategy=ReplayTrace(per_iter, per_task_budget),
            task_name=name,
            rand_state=seed,
        )
        for module, sched_fn, name, seed in task_specs
    ]
    db = DummyDatabase()
    total_budget = per_task_budget * len(tasks)
    # NOTE(review): a single task weight is passed for three tasks; presumably
    # RoundRobin ignores the weights -- confirm.
    scheduler = RoundRobin(
        tasks,
        [1.0],
        DummyBuilder(),
        DummyRunner(),
        db,
        measure_callbacks=[measure_callback.AddToDatabase()],
        max_trials=total_budget,
    )
    scheduler.tune()
    assert len(db) == total_budget
    for ctx in tasks:
        workload = db.commit_workload(ctx.mod)
        top_records = db.get_top_k(workload, 100000)
        assert len(top_records) == per_task_budget
Ejemplo n.º 4
0
def test_meta_schedule_task_scheduler_single():
    """RoundRobin over one task fills the database with exactly ``max_trials`` records."""
    per_iter, budget = 3, 10
    design_space = ScheduleFn(sch_fn=_schedule_matmul)
    strategy = ReplayTrace(per_iter, budget)
    only_task = TuneContext(
        MatmulModule,
        target=tvm.target.Target("llvm"),
        space_generator=design_space,
        search_strategy=strategy,
        task_name="Test",
        rand_state=42,
    )
    db = DummyDatabase()
    scheduler = RoundRobin(
        [only_task],
        [1.0],
        DummyBuilder(),
        DummyRunner(),
        db,
        measure_callbacks=[measure_callback.AddToDatabase()],
        max_trials=budget,
    )
    scheduler.tune()
    assert len(db) == budget
def test_meta_schedule_measure_callback_fail():
    """An exception raised in a Python measure callback propagates out of ``apply``."""

    # Python-side measure callback whose ``apply`` unconditionally raises.
    @derived_object
    class FailingMeasureCallback(PyMeasureCallback):
        def apply(
            self,
            task_scheduler: TaskScheduler,
            task_id: int,
            measure_candidates: List[MeasureCandidate],
            builds: List[BuilderResult],
            results: List[RunnerResult],
        ) -> None:
            raise ValueError("test")

    # Named so it does not shadow the ``measure_callback`` module.
    failing_callback = FailingMeasureCallback()
    # Build every argument before entering ``pytest.raises``, so the expected
    # ValueError can only come from the ``apply`` call under test.
    scheduler = RoundRobin(
        [],
        [],
        DummyBuilder(),
        DummyRunner(),
        DummyDatabase(),
        max_trials=1,
    )
    candidates = [MeasureCandidate(Schedule(Matmul), None)]
    builds = [BuilderResult("test_build", None)]
    results = [RunnerResult([1.0, 2.1], None)]
    with pytest.raises(ValueError, match="test"):
        failing_callback.apply(scheduler, 0, candidates, builds, results)
def test_meta_schedule_measure_callback():
    """A Python measure callback sees exactly the arguments passed to ``apply``.

    The checks live inside the callback itself. The test fails if ``apply``
    does not receive one ``Matmul`` candidate, one error-free build whose
    artifact path is "test_build", and one error-free result with two run
    times.
    """

    # Python-side callback that asserts on everything it is handed.
    @derived_object
    class FancyMeasureCallback(PyMeasureCallback):
        def apply(
            self,
            task_scheduler: TaskScheduler,
            task_id: int,
            measure_candidates: List[MeasureCandidate],
            builds: List[BuilderResult],
            results: List[RunnerResult],
        ) -> None:
            assert len(measure_candidates) == 1
            assert_structural_equal(measure_candidates[0].sch.mod, Matmul)

            assert len(builds) == 1
            first_build = builds[0]
            assert first_build.error_msg is None
            assert first_build.artifact_path == "test_build"

            assert len(results) == 1
            first_result = results[0]
            assert first_result.error_msg is None
            assert len(first_result.run_secs) == 2

    fancy_callback = FancyMeasureCallback()
    scheduler = RoundRobin(
        [],
        [],
        DummyBuilder(),
        DummyRunner(),
        DummyDatabase(),
        max_trials=1,
    )
    candidates = [MeasureCandidate(Schedule(Matmul), None)]
    builds = [BuilderResult("test_build", None)]
    results = [RunnerResult([1.0, 2.1], None)]
    fancy_callback.apply(scheduler, 0, candidates, builds, results)
Ejemplo n.º 7
0
def test_meta_schedule_task_scheduler_NIE():  # pylint: disable=invalid-name
    """``next_task_id`` on a PyTaskScheduler subclass that does not implement it raises TVMError."""

    # Deliberately empty subclass: nothing is overridden.
    @derived_object
    class NIETaskScheduler(PyTaskScheduler):
        pass

    expected_message = "PyTaskScheduler's NextTaskId method not implemented!"
    # Construction stays inside the ``raises`` block, as in the original check.
    with pytest.raises(TVMError, match=expected_message):
        unimplemented = NIETaskScheduler(
            [],
            DummyBuilder(),
            DummyRunner(),
            DummyDatabase(),
            1,
        )
        unimplemented.next_task_id()
def test_meta_schedule_task_scheduler_avoid_cyclic():  # pylint: disable=invalid-name
    """A task scheduler is reclaimed as soon as its last strong reference is dropped.

    A weak reference to the scheduler must be dead right after ``del``. That
    holds only if the scheduler is not kept alive by a reference cycle waiting
    for the cyclic garbage collector.
    """
    db = DummyDatabase()
    scheduler = MyTaskScheduler(
        [],
        DummyBuilder(),
        DummyRunner(),
        db,
        measure_callbacks=[measure_callback.AddToDatabase()],
    )
    scheduler_ref = weakref.ref(scheduler)
    del scheduler
    assert scheduler_ref() is None
Ejemplo n.º 9
0
def test_meta_schedule_evolutionary_search():  # pylint: disable=invalid-name
    """Drive EvolutionarySearch by hand and check its per-iteration batch sizes.

    The strategy is initialized from a Matmul tune context and then stepped
    manually: every candidate receives the same fake runner timing. The sizes
    of the generated batches must be ``num_trials_per_iter`` for each
    iteration, followed by one remainder batch if the total budget is not a
    multiple of the per-iteration count.
    """
    num_trials_per_iter = 10
    num_trials_total = 100

    strategy = EvolutionarySearch(
        num_trials_per_iter=num_trials_per_iter,
        num_trials_total=num_trials_total,
        population_size=5,
        init_measured_ratio=0.1,
        init_min_unmeasured=50,
        genetic_num_iters=3,
        genetic_mutate_prob=0.5,
        genetic_max_fail_count=10,
        eps_greedy=0.9,
    )
    context = TuneContext(
        mod=Matmul,
        space_generator=ScheduleFn(sch_fn=_schedule_matmul),
        mutator_probs={
            DummyMutator(): 1.0,
        },
        target=tvm.target.Target("llvm"),
        num_threads=1,  # because we are using a mutator from the python side
    )
    # The scheduler is never tuned here. It is kept alive until the end of the
    # test, presumably because constructing it wires ``context`` to a
    # scheduler -- TODO confirm.
    _scheduler = RoundRobin(
        tasks=[context],
        builder=LocalBuilder(),
        runner=LocalRunner(),
        database=DummyDatabase(),
        cost_model=RandomModel(),
        measure_callbacks=[],
    )
    context.space_generator.initialize_with_tune_context(context)
    spaces = context.space_generator.generate_design_space(context.mod)

    strategy.initialize_with_tune_context(context)
    strategy.pre_tuning(spaces)
    (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul)
    num_trials_each_iter: List[int] = []
    candidates = strategy.generate_measure_candidates()
    while candidates is not None:
        num_trials_each_iter.append(len(candidates))
        runner_results: List[RunnerResult] = []
        for candidate in candidates:
            # NOTE(review): the return value of this comparison is discarded,
            # so it verifies nothing. Because ``strategy`` is an
            # EvolutionarySearch, ``remove_decisions`` is always False here.
            # Confirm whether an ``assert`` is intended before tightening it.
            _is_trace_equal(
                candidate.sch,
                correct_sch,
                remove_decisions=(isinstance(strategy, ReplayTrace)),
            )
            runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None))
        strategy.notify_runner_results(context, candidates, runner_results)
        candidates = strategy.generate_measure_candidates()
    strategy.post_tuning()
    # Every iteration returns a full batch of ``num_trials_per_iter`` candidates,
    # except possibly a final batch carrying the remainder of the budget.
    full_batches, remainder = divmod(num_trials_total, num_trials_per_iter)
    expected = [num_trials_per_iter] * full_batches + ([remainder] if remainder else [])
    assert num_trials_each_iter == expected, num_trials_each_iter
    del _scheduler