def test_meta_schedule_integration_apply_history_best():
    """Check that ApplyHistoryBest hands back the committed workload's module.

    A single tuning record for ``MockModule`` is committed to a
    ``DummyDatabase``. Querying an ``ApplyHistoryBest`` environment built on
    that database, with ``MockModule`` as the dispatched candidate, must
    produce a module structurally equal to the committed workload's module.
    """
    network_mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
    db = DummyDatabase()
    history_best = ApplyHistoryBest(db)
    llvm_target = Target("llvm")

    # Register the mock workload and a single record that points at it.
    mock_workload = db.commit_workload(MockModule)
    record = TuningRecord(
        Schedule(MockModule).trace,
        [1.0],
        mock_workload,
        llvm_target,
        [],
    )
    db.commit_tuning_record(record)

    queried_mod = history_best.query(
        task_name="mock-task",
        mod=network_mod,
        target=llvm_target,
        dispatched=[MockModule],
    )
    assert tvm.ir.structural_equal(queried_mod, mock_workload.mod)
def test_meta_schedule_task_scheduler_override_next_task_id_only():  # pylint: disable=invalid-name
    """Tune three tasks with ``MyTaskScheduler`` and check the record counts.

    Per the test name, ``MyTaskScheduler`` (defined elsewhere) is expected to
    override only the next-task selection. Three ``ReplayTrace`` tasks share
    a global budget of ``max_trials_per_task`` trials per task. After tuning,
    the database must hold exactly that total. Each workload must own exactly
    ``max_trials_per_task`` records.
    """
    trials_per_iter = 6
    trials_per_task = 101

    # (IRModule, schedule function, task name, RNG seed) for each task.
    task_specs = [
        (MatmulModule, _schedule_matmul, "Matmul", 42),
        (MatmulReluModule, _schedule_matmul, "MatmulRelu", 0xDEADBEEF),
        (BatchMatmulModule, _schedule_batch_matmul, "BatchMatmul", 0x114514),
    ]
    tasks = [
        TuneContext(
            ir_mod,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=sched),
            search_strategy=ReplayTrace(trials_per_iter, trials_per_task),
            task_name=label,
            rand_state=seed,
        )
        for ir_mod, sched, label, seed in task_specs
    ]

    total_trials = trials_per_task * len(tasks)
    db = DummyDatabase()
    custom_scheduler = MyTaskScheduler(
        tasks,
        DummyBuilder(),
        DummyRunner(),
        db,
        measure_callbacks=[measure_callback.AddToDatabase()],
        max_trials=total_trials,
    )
    custom_scheduler.tune()

    assert len(db) == total_trials
    for tune_ctx in tasks:
        top_records = db.get_top_k(db.commit_workload(tune_ctx.mod), 100000)
        assert len(top_records) == trials_per_task
def test_meta_schedule_task_scheduler_multiple():
    """Tune three tasks round-robin and check each one uses its full budget.

    Three ``ReplayTrace`` tasks (Matmul, MatmulRelu, BatchMatmul) are driven
    by a ``RoundRobin`` scheduler. Its global budget is ``max_trials_per_task``
    trials for each task. After tuning, the database must hold exactly that
    total. Each workload must own exactly ``max_trials_per_task`` records.
    """
    num_trials_per_iter = 6
    max_trials_per_task = 101
    tasks = [
        TuneContext(
            MatmulModule,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=_schedule_matmul),
            search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task),
            task_name="Matmul",
            rand_state=42,
        ),
        TuneContext(
            MatmulReluModule,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=_schedule_matmul),
            search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task),
            task_name="MatmulRelu",
            rand_state=0xDEADBEEF,
        ),
        TuneContext(
            BatchMatmulModule,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=_schedule_batch_matmul),
            search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task),
            task_name="BatchMatmul",
            rand_state=0x114514,
        ),
    ]
    database = DummyDatabase()
    round_robin = RoundRobin(
        tasks,
        # Give every task the same weight, and keep this list the same
        # length as ``tasks``.
        [1.0] * len(tasks),
        DummyBuilder(),
        DummyRunner(),
        database,
        measure_callbacks=[measure_callback.AddToDatabase()],
        max_trials=max_trials_per_task * len(tasks),
    )
    round_robin.tune()

    assert len(database) == max_trials_per_task * len(tasks)
    for task in tasks:
        records = database.get_top_k(database.commit_workload(task.mod), 100000)
        assert len(records) == max_trials_per_task