Example No. 1
def test_meta_schedule_task_scheduler_override_next_task_id_only():  # pylint: disable=invalid-name

    num_trials_per_iter = 6
    max_trials_per_task = 101
    tasks = [
        TuneContext(
            MatmulModule,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=_schedule_matmul),
            search_strategy=ReplayTrace(num_trials_per_iter,
                                        max_trials_per_task),
            task_name="Matmul",
            rand_state=42,
        ),
        TuneContext(
            MatmulReluModule,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=_schedule_matmul),
            search_strategy=ReplayTrace(num_trials_per_iter,
                                        max_trials_per_task),
            task_name="MatmulRelu",
            rand_state=0xDEADBEEF,
        ),
        TuneContext(
            BatchMatmulModule,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=_schedule_batch_matmul),
            search_strategy=ReplayTrace(num_trials_per_iter,
                                        max_trials_per_task),
            task_name="BatchMatmul",
            rand_state=0x114514,
        ),
    ]
    database = DummyDatabase()
    scheduler = MyTaskScheduler(
        tasks,
        DummyBuilder(),
        DummyRunner(),
        database,
        measure_callbacks=[
            measure_callback.AddToDatabase(),
        ],
        max_trials=max_trials_per_task * len(tasks),
    )
    scheduler.tune()
    assert len(database) == max_trials_per_task * len(tasks)
    for task in tasks:
        workload = database.commit_workload(task.mod)
        assert len(database.get_top_k(workload, 100000)) == max_trials_per_task
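MyTaskScheduler is defined elsewhere in the test file. A rough sketch of a scheduler that overrides only next_task_id(), assuming the PyTaskScheduler base class and derived_object decorator from the same TVM version (the bookkeeping details are illustrative, not the upstream implementation):

import random

from tvm.meta_schedule.task_scheduler import PyTaskScheduler
from tvm.meta_schedule.utils import derived_object


@derived_object
class MyTaskScheduler(PyTaskScheduler):
    """Overrides only next_task_id(); everything else is inherited."""

    done: set = set()

    def next_task_id(self) -> int:
        # Pick random tasks until every task has terminated.
        while len(self.done) != len(self.tasks):
            task_id = random.randint(0, len(self.tasks) - 1)
            if task_id in self.done:
                continue
            if self.tasks[task_id].is_terminated:
                self.done.add(task_id)
            else:
                return task_id
        return -1  # -1 signals that tuning is finished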
Example No. 2
def test_meta_schedule_task_scheduler_multiple():
    num_trials_per_iter = 6
    max_trials_per_task = 101
    tasks = [
        TuneContext(
            MatmulModule,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=_schedule_matmul),
            search_strategy=ReplayTrace(num_trials_per_iter,
                                        max_trials_per_task),
            task_name="Matmul",
            rand_state=42,
        ),
        TuneContext(
            MatmulReluModule,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=_schedule_matmul),
            search_strategy=ReplayTrace(num_trials_per_iter,
                                        max_trials_per_task),
            task_name="MatmulRelu",
            rand_state=0xDEADBEEF,
        ),
        TuneContext(
            BatchMatmulModule,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=_schedule_batch_matmul),
            search_strategy=ReplayTrace(num_trials_per_iter,
                                        max_trials_per_task),
            task_name="BatchMatmul",
            rand_state=0x114514,
        ),
    ]
    database = DummyDatabase()
    round_robin = RoundRobin(
        tasks,
        [1.0],
        DummyBuilder(),
        DummyRunner(),
        database,
        measure_callbacks=[measure_callback.AddToDatabase()],
        max_trials=max_trials_per_task * len(tasks),
    )
    round_robin.tune()
    assert len(database) == max_trials_per_task * len(tasks)
    for task in tasks:
        workload = database.commit_workload(task.mod)
        assert len(database.get_top_k(workload, 100000)) == max_trials_per_task
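DummyBuilder, DummyRunner, and DummyDatabase are shared test doubles from the same file. A minimal sketch of the database, assuming the PyDatabase interface of this TVM version (storage is a plain Python list, and the method names mirror the interface):

from typing import List

import tvm
from tvm.ir import IRModule
from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload
from tvm.meta_schedule.utils import derived_object


@derived_object
class DummyDatabase(PyDatabase):
    def __init__(self):
        super().__init__()
        self.records: List[TuningRecord] = []
        self.workload_reg: List[Workload] = []

    def has_workload(self, mod: IRModule) -> bool:
        return any(
            tvm.ir.structural_equal(workload.mod, mod)
            for workload in self.workload_reg)

    def commit_workload(self, mod: IRModule) -> Workload:
        # Reuse the registered workload if the module was committed before.
        for workload in self.workload_reg:
            if tvm.ir.structural_equal(workload.mod, mod):
                return workload
        workload = Workload(mod)
        self.workload_reg.append(workload)
        return workload

    def commit_tuning_record(self, record: TuningRecord) -> None:
        self.records.append(record)

    def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
        # Sort by mean run time, best first.
        return sorted(
            (rec for rec in self.records if rec.workload == workload),
            key=lambda rec: sum(rec.run_secs) / len(rec.run_secs),
        )[:int(top_k)]

    def __len__(self) -> int:
        return len(self.records)

Example No. 3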
@pytest.mark.parametrize("TestClass", [ReplayFunc, ReplayTrace])
def test_meta_schedule_replay_func(
        TestClass: SearchStrategy):  # pylint: disable = invalid-name
    num_trials_per_iter = 7
    max_trials_per_task = 20

    strategy = TestClass(num_trials_per_iter=num_trials_per_iter,
                         max_trials_per_task=max_trials_per_task)
    context = TuneContext(mod=Matmul,
                          space_generator=ScheduleFn(sch_fn=_schedule_matmul))
    context.space_generator.initialize_with_tune_context(context)
    spaces = context.space_generator.generate_design_space(context.mod)

    strategy.initialize_with_tune_context(context)
    strategy.pre_tuning(spaces)
    (correct_sch, ) = ScheduleFn(
        sch_fn=_schedule_matmul).generate_design_space(Matmul)
    num_trials_each_iter: List[int] = []
    candidates = strategy.generate_measure_candidates()
    while candidates is not None:
        num_trials_each_iter.append(len(candidates))
        runner_results: List[RunnerResult] = []
        for candidate in candidates:
            assert _is_trace_equal(
                candidate.sch,
                correct_sch,
                remove_decisions=(isinstance(strategy, ReplayTrace)),
            )
            runner_results.append(
                RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None))
        strategy.notify_runner_results(context, candidates, runner_results)
        candidates = strategy.generate_measure_candidates()
    strategy.post_tuning()
    assert num_trials_each_iter == [7, 7, 6]
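_is_trace_equal compares two schedules by their traces. A plausible sketch, matching how the tests call it (remove_decisions strips the random-sampling decisions so that replayed traces compare structurally):

from tvm.tir.schedule import Schedule, Trace


def _is_trace_equal(sch_1: Schedule, sch_2: Schedule,
                    remove_decisions: bool = True) -> bool:
    if remove_decisions:
        # Drop sampling decisions and compare instruction structure only.
        trace_1 = Trace(sch_1.trace.insts, {})
        trace_2 = Trace(sch_2.trace.insts, {})
    else:
        trace_1 = sch_1.trace
        trace_2 = sch_2.trace
    return str(trace_1) == str(trace_2)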
Example No. 4
def test_meta_schedule_custom_search_space():
    mod = MatmulCustomized
    context = TuneContext(
        mod=mod,
        target=Target("llvm"),
        task_name="Custom Search Space Task",
        sch_rules=[],
    )
    post_order_apply = PostOrderApply()
    post_order_apply.initialize_with_tune_context(context)

    post_order_apply.generate_design_space(mod)

    called = False

    def custom_search_space_func(sch: Schedule, _: BlockRV) -> List[Schedule]:
        nonlocal called
        called = True
        return [sch]

    register_func("tvm.meta_schedule.test.custom_search_space",
                  custom_search_space_func)

    post_order_apply.generate_design_space(mod)
    assert called
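Example No. 5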
def test_meta_schedule_replay_trace():
    num_trials_per_iter = 7
    num_trials_total = 20

    (example_sch, ) = ScheduleFn(
        sch_fn=_schedule_matmul).generate_design_space(Matmul)
    replay = ReplayTrace(num_trials_per_iter=num_trials_per_iter,
                         num_trials_total=num_trials_total)
    tune_context = TuneContext(mod=Matmul)
    replay.initialize_with_tune_context(tune_context)

    num_trials_each_round: List[int] = []
    replay.pre_tuning([example_sch])
    while True:
        candidates = replay.generate_measure_candidates()
        if candidates is None:
            break
        num_trials_each_round.append(len(candidates))
        runner_results: List[RunnerResult] = []
        for candidate in candidates:
            assert _is_trace_equal(candidate.sch, example_sch)
            runner_results.append(
                RunnerResult(run_secs=[0.5, 0.4, 0.3], error_msg=None))
        replay.notify_runner_results(runner_results)
    replay.post_tuning()
    assert num_trials_each_round == [7, 7, 6]
Example No. 6
def _make_mutator(target: Target) -> Mutator:
    ctx = TuneContext(
        mod=matmul,
        target=target,
        mutator_probs={MutateTileSize(): 1.0},
    )
    return list(ctx.mutator_probs.keys())[0]
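A hedged usage sketch for the mutator returned above (matmul and _schedule_matmul live elsewhere in the test file; Mutator.apply takes a trace and returns either a mutated trace or None):

def _test_mutate_tile_size():
    mutator = _make_mutator(Target("llvm --num-cores=16"))
    sch = Schedule(matmul)
    _schedule_matmul(sch)
    # None would mean the mutator found no tile-size decision to mutate.
    new_trace = mutator.apply(sch.trace)
    assert new_trace is not None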
Example No. 7
def test_meta_schedule_task_scheduler_single():
    num_trials_per_iter = 3
    max_trials_per_task = 10
    sch_fn = ScheduleFn(sch_fn=_schedule_matmul)
    replay = ReplayTrace(num_trials_per_iter, max_trials_per_task)
    task = TuneContext(
        MatmulModule,
        target=tvm.target.Target("llvm"),
        space_generator=sch_fn,
        search_strategy=replay,
        task_name="Test",
        rand_state=42,
    )
    database = DummyDatabase()
    round_robin = RoundRobin(
        [task],
        [1.0],
        DummyBuilder(),
        DummyRunner(),
        database,
        measure_callbacks=[measure_callback.AddToDatabase()],
        max_trials=max_trials_per_task,
    )
    round_robin.tune()
    assert len(database) == max_trials_per_task
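_schedule_matmul, the ScheduleFn callback used throughout these examples, is essentially a tiled matmul schedule with sampled split factors. A sketch along the lines of the upstream helper (the block name assumes the Matmul workload):

def _schedule_matmul(sch: Schedule):
    block = sch.get_block("matmul")
    i, j, k = sch.get_loops(block=block)
    # Sampled tilings keep the design space non-trivial; these are the
    # decisions that ReplayTrace re-samples when replaying the trace.
    i_0, i_1, i_2, i_3 = sch.split(i, sch.sample_perfect_tile(i, n=4))
    j_0, j_1, j_2, j_3 = sch.split(j, sch.sample_perfect_tile(j, n=4))
    k_0, k_1 = sch.split(k, sch.sample_perfect_tile(k, n=2))
    sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3)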
Example No. 8
def _create_context(mod, target, postprocs):
    ctx = TuneContext(
        mod=mod,
        target=target,
        postprocs=postprocs,
        task_name="test",
    )
    return ctx
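Example No. 9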
def test_tune_context_create():
    mod = Matmul
    context = TuneContext(mod=mod,
                          target=Target("llvm"),
                          task_name="Test Task")
    assert context.num_threads > 0
    assert context.rand_state != -1
    assert context.task_name == "Test Task"
    assert context.mod == mod or tvm.ir.structural_equal(context.mod, mod)
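Example No. 10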
def _create_context(mod, target) -> TuneContext:
    return TuneContext(
        mod=mod,
        target=target,
        postprocs=[
            RewriteLayout(),
        ],
        task_name="test",
    )
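Example No. 11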
def _make_mutator(target: Target) -> Mutator:
    ctx = TuneContext(
        mod=add,
        target=target,
        mutator_probs={
            MutateComputeLocation(): 1.0,
        },
    )
    return list(ctx.mutator_probs.keys())[0]
Example No. 12
def _make_mutator(target: Target) -> Mutator:
    ctx = TuneContext(
        mod=element_wise,
        target=target,
        mutator_probs={
            MutateThreadBinding(): 1.0,
        },
    )
    return list(ctx.mutator_probs.keys())[0]
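Example No. 13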
def _make_mutator(target: Target, max_jobs_per_core: int) -> Mutator:
    ctx = TuneContext(
        mod=matmul,
        target=target,
        mutator_probs={
            MutateParallel(max_jobs_per_core): 1.0,
        },
    )
    return list(ctx.mutator_probs.keys())[0]
Example No. 14
def _create_context(mod, target) -> TuneContext:
    ctx = TuneContext(
        mod=mod,
        target=target,
        postprocs=[
            DisallowDynamicLoop(),
        ],
        task_name="test",
    )
    return ctx
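Example No. 15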
def _create_context(mod, target) -> TuneContext:
    ctx = TuneContext(
        mod=mod,
        target=target,
        postprocs=[
            RewriteCooperativeFetch(),
        ],
        task_name="test",
    )
    return ctx
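Example No. 16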
def _create_context(mod, target) -> TuneContext:
    ctx = TuneContext(
        mod=mod,
        target=target,
        postprocs=[
            VerifyGPUCode(),
        ],
        task_name="test",
    )
    return ctx
Example No. 17
def _create_context(mod, target) -> TuneContext:
    ctx = TuneContext(
        mod=mod,
        target=target,
        postprocs=[
            RewriteUnboundBlock(),
        ],
        task_name="test",
    )
    return ctx
Example No. 18
def _create_context(mod, target, postprocs):
    ctx = TuneContext(
        mod=mod,
        target=target,
        postprocs=postprocs,
        task_name="test",
    )
    for rule in ctx.postprocs:
        rule.initialize_with_tune_context(ctx)
    return ctx
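Contexts built this way are typically exercised by running every postprocessor over a traced schedule; each apply() returns False when the schedule should be rejected. A minimal sketch:

def _apply_postprocs(sch: Schedule, ctx: TuneContext) -> bool:
    # True only if every postprocessor accepts (and possibly rewrites) sch.
    return all(postproc.apply(sch) for postproc in ctx.postprocs)

Example No. 19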
def _create_context(mod, target) -> TuneContext:
    ctx = TuneContext(
        mod=mod,
        target=target,
        postprocs=[
            VerifyGPUCode(),
        ],
        task_name="test",
    )
    for rule in ctx.postprocs:
        rule.initialize_with_tune_context(ctx)
    return ctx
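Example No. 20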
def _create_context(mod, target) -> TuneContext:
    ctx = TuneContext(
        mod=mod,
        target=target,
        postprocs=[
            DisallowDynamicLoop(),
        ],
        task_name="test",
    )
    for rule in ctx.postprocs:
        rule.initialize_with_tune_context(ctx)
    return ctx
Example No. 21
def _create_context(mod, target) -> TuneContext:
    ctx = TuneContext(
        mod=mod,
        target=target,
        postprocs=[
            RewriteCooperativeFetch(),
        ],
        task_name="test",
    )
    for rule in ctx.postprocs:
        rule.initialize_with_tune_context(ctx)
    return ctx
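Example No. 22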
def test_meta_schedule_feature_extractor():
    class FancyFeatureExtractor(PyFeatureExtractor):
        def extract_from(
            self,
            context: TuneContext,  # pylint: disable = unused-argument
            candidates: List[MeasureCandidate],  # pylint: disable = unused-argument
        ) -> List[np.ndarray]:
            return [np.random.rand(4, 5)]

    extractor = FancyFeatureExtractor()
    features = extractor.extract_from(TuneContext(), [])
    assert len(features) == 1
    assert features[0].shape == (4, 5)
Example No. 23
def test_meta_schedule_post_order_apply():
    mod = Matmul
    context = TuneContext(
        mod=mod,
        target=Target("llvm"),
        task_name="Test Task",
        sch_rules=[WowSoFancyScheduleRule()],
    )
    post_order_apply = PostOrderApply()
    post_order_apply.initialize_with_tune_context(context)
    schs = post_order_apply.generate_design_space(mod)
    assert len(schs) == 1
    assert not tvm.ir.structural_equal(schs[0].mod, mod)
    _check_correct(schs[0])
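_check_correct validates the sampled decisions recorded in the trace. A plausible sketch, assuming (as in the upstream test) that every sampled perfect tile must multiply back to the loop extent of 1024:

import math

from tvm.tir.schedule import Schedule


def _check_correct(schedule: Schedule):
    trace = schedule.trace
    for inst in trace.decisions:
        # Each SamplePerfectTile decision must multiply to the loop extent.
        assert math.prod(trace.decisions[inst]) == 1024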
Example No. 24
def test_meta_schedule_task_scheduler_multiple():
    num_trials_per_iter = 6
    num_trials_total = 101
    tasks = [
        TuneContext(
            MatmulModule,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=_schedule_matmul),
            search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total),
            task_name="Matmul",
            rand_state=42,
        ),
        TuneContext(
            MatmulReluModule,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=_schedule_matmul),
            search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total),
            task_name="MatmulRelu",
            rand_state=0xDEADBEEF,
        ),
        TuneContext(
            BatchMatmulModule,
            target=tvm.target.Target("llvm"),
            space_generator=ScheduleFn(sch_fn=_schedule_batch_matmul),
            search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total),
            task_name="BatchMatmul",
            rand_state=0x114514,
        ),
    ]
    database = DummyDatabase()
    round_robin = RoundRobin(tasks, DummyBuilder(), DummyRunner(), database)
    round_robin.tune()
    assert len(database) == num_trials_total * len(tasks)
    print(database.workload_reg)
    for task in tasks:
        workload = database.commit_workload(task.mod)
        assert len(database.get_top_k(workload, 1e9)) == num_trials_total
Example No. 25
def test_meta_schedule_post_order_apply_multiple():
    mod = Matmul
    context = TuneContext(
        mod=mod,
        target=Target("llvm"),
        task_name="Double Rules Task",
        sch_rules=[DoubleScheduleRule(), ReorderScheduleRule()],
    )
    post_order_apply = PostOrderApply()
    post_order_apply.initialize_with_tune_context(context)
    schs = post_order_apply.generate_design_space(mod)
    assert len(schs) == 4
    for sch in schs:
        assert not tvm.ir.structural_equal(sch.mod, mod)
        _check_correct(sch)
Example No. 26
def test_meta_schedule_post_order_apply_double():
    mod = Matmul
    context = TuneContext(
        mod=mod,
        target=Target("llvm"),
        task_name="Double Rules Task",
        space_generator=PostOrderApply(),
        sch_rules=[DoubleScheduleRule()],
    )
    post_order_apply = context.space_generator
    schs = post_order_apply.generate_design_space(mod)
    assert len(schs) == 2
    for sch in schs:
        assert not tvm.ir.structural_equal(sch.mod, mod)
        _check_correct(sch)
Example No. 27
def test_meta_schedule_post_order_apply_duplicate_matmul():
    mod = DuplicateMatmul
    context = TuneContext(
        mod=mod,
        target=Target("llvm"),
        task_name="Duplicate Matmul Task",
        sch_rules=[WowSoFancyScheduleRule()],
    )
    post_order_apply = PostOrderApply()
    post_order_apply.initialize_with_tune_context(context)
    with pytest.raises(
        TVMError,
        match=r".*TVMError: Check failed: \(block_names_.count\(block->name_hint\) == 0\)"
        r" is false: Duplicated block name matmul in function main not supported!",
    ):
        post_order_apply.generate_design_space(mod)
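Example No. 28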
def test_conv2d_winograd_cuda():
    mod = conv2d_winograd_cuda
    mod = IRModule({"main": mod})
    context = TuneContext(
        mod=mod,
        target=Target("nvidia/geforce-rtx-3090", host="llvm"),
        task_name="Custom Search Space Task",
        sch_rules=DefaultCUDA._sch_rules(),  # pylint: disable=protected-access
    )
    for sch_rule in context.sch_rules:
        sch_rule.initialize_with_tune_context(context)
    post_order_apply = PostOrderApply()
    post_order_apply.initialize_with_tune_context(context)
    (sch,) = post_order_apply.generate_design_space(mod)
    decisions = dict(
        zip(
            [i for i in sch.trace.insts if i.kind.name.startswith("Sample")],
            [
                # data_pack
                [3, 3],
                [64, 2],
                2,
                # inverse
                [3, 3],
                [2, 64],
                2,
                # bgemm
                [1, 1, 1, 1, 6],
                [1, 1, 1, 3, 2],
                [3, 1, 1, 1, 3],
                [4, 2, 1, 4, 4],
                [32, 1, 4],
                1,
                1,
                # root anno
                2,
                # conv2d
                2,
            ],
        )
    )
    trace = Trace(sch.trace.insts, decisions=decisions)
    sch = Schedule(mod=mod)
    trace.apply_to_schedule(sch, remove_postproc=False)
    answer = sch.mod
    expected = _get_mod()
    tvm.ir.assert_structural_equal(answer, expected)
Example No. 29
def test_meta_schedule_task_scheduler_single():
    num_trials_per_iter = 3
    num_trials_total = 10
    sch_fn = ScheduleFn(sch_fn=_schedule_matmul)
    replay = ReplayTrace(num_trials_per_iter, num_trials_total)
    task = TuneContext(
        MatmulModule,
        target=tvm.target.Target("llvm"),
        space_generator=sch_fn,
        search_strategy=replay,
        task_name="Test",
        rand_state=42,
    )
    database = DummyDatabase()
    round_robin = RoundRobin([task], DummyBuilder(), DummyRunner(), database)
    round_robin.tune()
    assert len(database) == num_trials_total
Example No. 30
def test_conv2d_winograd_cpu():
    mod = conv2d_winograd_cpu
    mod = IRModule({"main": mod})
    context = TuneContext(
        mod=mod,
        target=Target("llvm"),
        task_name="Custom Search Space Task",
        sch_rules=DefaultLLVM._sch_rules(),  # pylint: disable=protected-access
    )
    post_order_apply = PostOrderApply()
    post_order_apply.initialize_with_tune_context(context)
    (sch, ) = post_order_apply.generate_design_space(mod)

    decisions = dict(
        zip(
            [
                i for i in sch.trace.insts[:-4]
                if i.kind.name.startswith("Sample")
            ],
            [
                # data_pack
                [9, 1],
                [32, 4],
                # input_tile
                4,
                # data_pad
                -2,
                # inverse
                [1, 9],
                [2, 64],
                # bgemm
                [1, 2, 3, 1],
                [1, 1, 1, 6],
                [1, 1, 1, 9],
                [2, 1, 16, 4],
                [16, 8],
            ],
        ))
    trace = Trace(sch.trace.insts[:-4], decisions=decisions)
    sch = Schedule(mod=mod)
    trace.apply_to_schedule(sch, remove_postproc=False)
    answer = sch.mod
    expected = _get_mod()
    tvm.ir.assert_structural_equal(answer, expected)