def test_meta_schedule_task_scheduler_multiple():
    """Round-robin over three tuning tasks; each task must consume its full
    trial budget and have every measured record committed to the database."""
    trials_per_iter = 6
    trials_per_task = 101
    # (task name, module, schedule function, RNG seed) for each task.
    task_specs = [
        ("Matmul", MatmulModule, _schedule_matmul, 42),
        ("MatmulRelu", MatmulReluModule, _schedule_matmul, 0xDEADBEEF),
        ("BatchMatmul", BatchMatmulModule, _schedule_batch_matmul, 0x114514),
    ]
    tasks = [
        ms.TuneContext(
            mod,
            target=tvm.target.Target("llvm"),
            space_generator=ms.space_generator.ScheduleFn(sch_fn=sch_fn),
            search_strategy=ms.search_strategy.ReplayTrace(
                trials_per_iter,
                trials_per_task,
            ),
            task_name=name,
            rand_state=seed,
        )
        for name, mod, sch_fn, seed in task_specs
    ]
    database = ms.database.MemoryDatabase()
    scheduler = ms.task_scheduler.RoundRobin(
        tasks,
        [1.0, 1.0, 1.0],
        builder=DummyBuilder(),
        runner=DummyRunner(),
        database=database,
        measure_callbacks=[ms.measure_callback.AddToDatabase()],
        max_trials=trials_per_task * len(tasks),
    )
    scheduler.tune()
    # Every task's trials land in the shared database.
    assert len(database) == trials_per_task * len(tasks)
    for task in tasks:
        workload = database.commit_workload(task.mod)
        assert len(database.get_top_k(workload, 100000)) == trials_per_task
def test_meta_schedule_task_scheduler_single():
    """A single-task round-robin scheduler must record exactly the requested
    number of trials in the database."""
    trials_per_iter = 3
    trials_per_task = 10
    database = ms.database.MemoryDatabase()
    task = ms.TuneContext(
        MatmulModule,
        target=tvm.target.Target("llvm"),
        space_generator=ms.space_generator.ScheduleFn(sch_fn=_schedule_matmul),
        search_strategy=ms.search_strategy.ReplayTrace(
            trials_per_iter,
            trials_per_task,
        ),
        task_name="Test",
        rand_state=42,
    )
    scheduler = ms.task_scheduler.RoundRobin(
        [task],
        [1.0],
        builder=DummyBuilder(),
        runner=DummyRunner(),
        database=database,
        measure_callbacks=[ms.measure_callback.AddToDatabase()],
        max_trials=trials_per_task,
    )
    scheduler.tune()
    assert len(database) == trials_per_task
def test_conv2d_winograd_cuda():
    """Generate the design space for Winograd conv2d on CUDA, replay a fixed
    set of sampling decisions, and check the result is structurally equal to
    the expected module."""
    mod = IRModule({"main": conv2d_winograd_cuda})
    context = ms.TuneContext(
        mod=mod,
        target=Target("nvidia/geforce-rtx-3090", host="llvm"),
        task_name="Custom Search Space Task",
        space_generator=ms.space_generator.PostOrderApply(),
        sch_rules=ms.default_config.schedule_rules(  # pylint: disable=protected-access
            None, Target("cuda")),
    )
    (space,) = context.space_generator.generate_design_space(mod)
    # Pin every "Sample*" instruction in trace order to a fixed decision.
    sample_insts = [
        inst for inst in space.trace.insts if inst.kind.name.startswith("Sample")
    ]
    sample_values = [
        # data_pack
        [3, 3],
        [64, 2],
        2,
        # inverse
        [3, 3],
        [2, 64],
        2,
        # bgemm
        [1, 1, 1, 1, 6],
        [1, 1, 1, 3, 2],
        [3, 1, 1, 1, 3],
        [4, 2, 1, 4, 4],
        [32, 1, 4],
        1,
        1,
        # root anno
        2,
        # conv2d
        2,
    ]
    trace = Trace(space.trace.insts, decisions=dict(zip(sample_insts, sample_values)))
    sch = Schedule(mod=mod)
    trace.apply_to_schedule(sch, remove_postproc=False)
    tvm.ir.assert_structural_equal(sch.mod, _get_mod())
def _make_context(target) -> ms.TuneContext:
    """Build a minimal, single-threaded TuneContext for the given target."""
    return ms.TuneContext(target=target, num_threads=1)
def measure_candidates(database, builder, runner):
    """Send the candidates to builder and runner for distributed measurement,
    and save the results in a new json database.

    Parameters
    ----------
    database : JSONDatabase
        The database for candidates to be measured.
    builder : Builder
        The builder for building the candidates.
    runner : Runner
        The runner for measuring the candidates.

    Returns
    -------
    None
    """
    candidates, runner_results, build_fail_indices, run_fail_indices = [], [], [], []
    context = ms.TuneContext(target=Target(args.target))
    tuning_records = database.get_all_tuning_records()
    if not tuning_records:
        # Guard: the original code raised IndexError on tuning_records[0]
        # (and KeyError on profiler.get()) when the database was empty.
        print("No tuning records found in the database; nothing to measure.")
        return
    for record in tuning_records:
        candidates.append(record.as_measure_candidate())
    with ms.Profiler() as profiler:
        # Measure in batches so builder/runner memory stays bounded.
        for idx in range(0, len(candidates), args.batch_size):
            batch_candidates = candidates[idx:idx + args.batch_size]
            context._set_measure_candidates(batch_candidates)  # pylint: disable=protected-access
            with ms.Profiler.timeit("build"):
                context._send_to_builder(builder)  # pylint: disable=protected-access
            with ms.Profiler.timeit("run"):
                context._send_to_runner(runner)  # pylint: disable=protected-access
            batch_runner_results = context._join()  # pylint: disable=protected-access
            runner_results.extend(batch_runner_results)
            # Successful builds leave artifacts on disk; clean them up now.
            # Failed builds are remembered by their global candidate index.
            for i, result in enumerate(context.builder_results):
                if result.error_msg is None:
                    ms.utils.remove_build_dir(result.artifact_path)
                else:
                    build_fail_indices.append(i + idx)
            context._clear_measure_state()  # pylint: disable=protected-access
    # Mirror the input database's directory layout under the result cache dir.
    model_name, workload_name = database.path_workload.split("/")[-2:]
    record_name = database.path_tuning_record.split("/")[-1]
    new_database = ms.database.JSONDatabase(
        path_workload=os.path.join(args.result_cache_dir, model_name, workload_name),
        path_tuning_record=os.path.join(args.result_cache_dir, model_name, record_name),
    )
    # NOTE(review): assumes every record shares the first record's workload —
    # TODO confirm this holds for the databases this script consumes.
    workload = tuning_records[0].workload
    new_database.commit_workload(workload.mod)
    for i, (record, result) in enumerate(zip(tuning_records, runner_results)):
        if result.error_msg is None:
            new_database.commit_tuning_record(
                ms.database.TuningRecord(
                    trace=record.trace,
                    workload=workload,
                    run_secs=[v.value for v in result.run_secs],
                    target=Target(args.target),
                ))
        else:
            run_fail_indices.append(i)
    # Persist the indices of candidates whose runs failed, for later triage.
    fail_indices_name = workload_name.replace("_workload.json", "_failed_indices.txt")
    with open(os.path.join(args.result_cache_dir, model_name, fail_indices_name),
              "w",
              encoding="utf8") as file:
        file.write(" ".join([str(n) for n in run_fail_indices]))
    print(
        f"Builder time: {profiler.get()['build']}, Runner time: {profiler.get()['run']}\n\
Failed number of builds: {len(build_fail_indices)},\
Failed number of runs: {len(run_fail_indices)}")