def test_meta_schedule_task_scheduler_multiple():
    """Round-robin scheduling over three tasks: every task must complete all
    of its trials, and every trial must land in the shared database."""
    trials_per_iter = 6
    trials_per_task = 101

    def _make_task(mod, sch_fn, name, seed):
        # One replay-trace tuning context per workload; all target llvm.
        return ms.TuneContext(
            mod,
            target=tvm.target.Target("llvm"),
            space_generator=ms.space_generator.ScheduleFn(sch_fn=sch_fn),
            search_strategy=ms.search_strategy.ReplayTrace(
                trials_per_iter,
                trials_per_task,
            ),
            task_name=name,
            rand_state=seed,
        )

    tasks = [
        _make_task(MatmulModule, _schedule_matmul, "Matmul", 42),
        _make_task(MatmulReluModule, _schedule_matmul, "MatmulRelu", 0xDEADBEEF),
        _make_task(BatchMatmulModule, _schedule_batch_matmul, "BatchMatmul", 0x114514),
    ]
    database = ms.database.MemoryDatabase()
    scheduler = ms.task_scheduler.RoundRobin(
        tasks,
        [1.0, 1.0, 1.0],
        builder=DummyBuilder(),
        runner=DummyRunner(),
        database=database,
        measure_callbacks=[ms.measure_callback.AddToDatabase()],
        max_trials=trials_per_task * len(tasks),
    )
    scheduler.tune()
    # Every trial of every task was recorded.
    assert len(database) == trials_per_task * len(tasks)
    for task in tasks:
        workload = database.commit_workload(task.mod)
        top_records = database.get_top_k(workload, 100000)
        assert len(top_records) == trials_per_task
def test_meta_schedule_task_scheduler_single():
    """A round robin with a single task must run exactly the task's trial
    budget and record every trial in the database."""
    num_trials_per_iter = 3
    max_trials_per_task = 10
    database = ms.database.MemoryDatabase()
    task = ms.TuneContext(
        MatmulModule,
        target=tvm.target.Target("llvm"),
        space_generator=ms.space_generator.ScheduleFn(sch_fn=_schedule_matmul),
        search_strategy=ms.search_strategy.ReplayTrace(
            num_trials_per_iter,
            max_trials_per_task,
        ),
        task_name="Test",
        rand_state=42,
    )
    scheduler = ms.task_scheduler.RoundRobin(
        [task],
        [1.0],
        builder=DummyBuilder(),
        runner=DummyRunner(),
        database=database,
        measure_callbacks=[ms.measure_callback.AddToDatabase()],
        max_trials=max_trials_per_task,
    )
    scheduler.tune()
    assert len(database) == max_trials_per_task
# Example #3
# 0
def test_conv2d_winograd_cuda():
    """Replay a fixed set of sampling decisions over the design space generated
    for Winograd conv2d on CUDA and check the result structurally matches the
    expected module."""
    mod = IRModule({"main": conv2d_winograd_cuda})
    context = ms.TuneContext(
        mod=mod,
        target=Target("nvidia/geforce-rtx-3090", host="llvm"),
        task_name="Custom Search Space Task",
        space_generator=ms.space_generator.PostOrderApply(),
        sch_rules=ms.default_config.schedule_rules(  # pylint: disable=protected-access
            None, Target("cuda")),
    )
    (sch, ) = context.space_generator.generate_design_space(mod)
    # Pin every sampling instruction in the trace to a fixed decision so the
    # replayed schedule is deterministic.
    sample_insts = [
        inst for inst in sch.trace.insts if inst.kind.name.startswith("Sample")
    ]
    fixed_decisions = [
        # data_pack
        [3, 3],
        [64, 2],
        2,
        # inverse
        [3, 3],
        [2, 64],
        2,
        # bgemm
        [1, 1, 1, 1, 6],
        [1, 1, 1, 3, 2],
        [3, 1, 1, 1, 3],
        [4, 2, 1, 4, 4],
        [32, 1, 4],
        1,
        1,
        # root anno
        2,
        # conv2d
        2,
    ]
    trace = Trace(sch.trace.insts,
                  decisions=dict(zip(sample_insts, fixed_decisions)))
    replayed = Schedule(mod=mod)
    trace.apply_to_schedule(replayed, remove_postproc=False)
    tvm.ir.assert_structural_equal(replayed.mod, _get_mod())
def _make_context(target) -> ms.TuneContext:
    """Build a minimal single-threaded TuneContext for *target*."""
    return ms.TuneContext(target=target, num_threads=1)
def measure_candidates(database, builder, runner):
    """Send the candidates to builder and runner for distributed measurement,
    and save the results in a new json database.

    Parameters
    ----------
    database : JSONDatabase
        The database for candidates to be measured.
    builder : Builder
        The builder for building the candidates.
    runner : Runner
        The runner for measuring the candidates.

    Returns
    -------
    None
    """
    candidates, runner_results, build_fail_indices, run_fail_indices = [], [], [], []
    # NOTE(review): this function reads a module-level `args` namespace
    # (target, batch_size, result_cache_dir) — presumably argparse output;
    # verify it is defined before this function is called.
    context = ms.TuneContext(target=Target(args.target))
    tuning_records = database.get_all_tuning_records()
    for record in tuning_records:
        candidates.append(record.as_measure_candidate())
    with ms.Profiler() as profiler:
        # Measure in batches of args.batch_size through TuneContext's
        # protected measurement pipeline: set candidates -> build -> run ->
        # join results -> clear state. Build and run phases are timed
        # separately under the "build"/"run" profiler keys.
        for idx in range(0, len(candidates), args.batch_size):
            batch_candidates = candidates[idx:idx + args.batch_size]
            context._set_measure_candidates(batch_candidates)  # pylint: disable=protected-access
            with ms.Profiler.timeit("build"):
                context._send_to_builder(builder)  # pylint: disable=protected-access
            with ms.Profiler.timeit("run"):
                context._send_to_runner(runner)  # pylint: disable=protected-access
                batch_runner_results = context._join()  # pylint: disable=protected-access
            runner_results.extend(batch_runner_results)
            # Delete build artifacts of successful builds; record the global
            # candidate index (batch offset + position) of failed builds.
            for i, result in enumerate(context.builder_results):
                if result.error_msg is None:
                    ms.utils.remove_build_dir(result.artifact_path)
                else:
                    build_fail_indices.append(i + idx)
            context._clear_measure_state()  # pylint: disable=protected-access

    # Mirror the source database's file layout under args.result_cache_dir:
    # <result_cache_dir>/<model_name>/<workload_name | record_name>.
    model_name, workload_name = database.path_workload.split("/")[-2:]
    record_name = database.path_tuning_record.split("/")[-1]
    new_database = ms.database.JSONDatabase(
        path_workload=os.path.join(args.result_cache_dir, model_name,
                                   workload_name),
        path_tuning_record=os.path.join(args.result_cache_dir, model_name,
                                        record_name),
    )
    # Assumes all tuning records share a single workload — TODO confirm.
    workload = tuning_records[0].workload
    new_database.commit_workload(workload.mod)
    # Commit only the records whose run succeeded; failed run indices are
    # written to a sidecar text file below.
    for i, (record, result) in enumerate(zip(tuning_records, runner_results)):
        if result.error_msg is None:
            new_database.commit_tuning_record(
                ms.database.TuningRecord(
                    trace=record.trace,
                    workload=workload,
                    run_secs=[v.value for v in result.run_secs],
                    target=Target(args.target),
                ))
        else:
            run_fail_indices.append(i)
    fail_indices_name = workload_name.replace("_workload.json",
                                              "_failed_indices.txt")
    with open(os.path.join(args.result_cache_dir, model_name,
                           fail_indices_name),
              "w",
              encoding="utf8") as file:
        file.write(" ".join([str(n) for n in run_fail_indices]))
    print(
        f"Builder time: {profiler.get()['build']}, Runner time: {profiler.get()['run']}\n\
            Failed number of builds: {len(build_fail_indices)},\
            Failed number of runs: {len(run_fail_indices)}")