def check_index_map(workload, block_name, intrin_name, expected_index_map):
    """Assert that auto-tensorize mapping of one block yields the expected index map.

    Parameters
    ----------
    workload
        The workload a ``Schedule`` is created from.
    block_name : str
        Name of the block to match against the intrinsic.
    intrin_name : str
        Name of the registered ``TensorIntrin`` whose ``desc`` is matched.
    expected_index_map : Callable
        Function converted with ``IndexMap.from_func`` into the expected mapping.
    """
    sch = Schedule(workload)
    target_block = sch.get_block(block_name)
    intrin_desc = TensorIntrin.get(intrin_name).desc
    mapping_info = get_auto_tensorize_mapping_info(sch, target_block, intrin_desc)
    # Exactly one candidate mapping must be found for the block.
    assert len(mapping_info.mappings) == 1
    expected_map = IndexMap.from_func(expected_index_map)
    assert expected_map.is_equivalent_to(mapping_info.mappings[0])
예제 #2
0
def _find_match_sketch_id(
    mod: IRModule,
    sketches: List[Schedule],
    expected_mod: IRModule,
    expected_decision: List[Tuple[str, List[int]]],
    *,
    debug_mask="all",
) -> Optional[int]:
    """Return the index of the first sketch that reproduces ``expected_mod``.

    The ``Sample*`` instructions of each sketch's trace are walked in order and
    matched greedily against ``expected_decision``. A sampling instruction whose
    kind name equals the next expected kind name receives that decision; other
    sampling instructions are skipped. A sketch that does not consume every
    expected decision is ignored. Otherwise its trace is replayed, with the
    collected decisions and postprocessors removed, on a fresh ``Schedule`` of
    ``mod``. The sketch matches when the result is structurally equal to
    ``expected_mod``, and the trace round-trip is then verified.

    Returns
    -------
    Optional[int]
        Index of the matching sketch, or ``None`` when no sketch matches.

    Raises
    ------
    AssertionError
        If a sketch has a sampling instruction left after all expected decisions
        have been consumed.
    """
    num_expected = len(expected_decision)
    for sketch_id, sketch in enumerate(sketches):
        decisions = {}
        cursor = 0
        for inst in sketch.trace.insts:
            kind_name = inst.kind.name
            if kind_name.startswith("Sample"):
                # A sketch must not sample more often than decisions were supplied.
                assert cursor < num_expected
                if kind_name == expected_decision[cursor][0]:
                    decisions[inst] = expected_decision[cursor][1]
                    cursor += 1
        if len(decisions) != num_expected:
            continue
        candidate = Schedule(mod, debug_mask=debug_mask)
        replay = Trace(insts=sketch.trace.insts, decisions=decisions)
        replay.apply_to_schedule(candidate, remove_postproc=True)
        if not structural_equal(candidate.mod, expected_mod):
            continue
        verify_trace_roundtrip(sch=candidate, mod=mod, debug_mask=debug_mask)
        return sketch_id
    return None
def _schedule_batch_matmul(sch: Schedule):
    """Tile and reorder the ``matmul`` block of ``sch`` in place with fixed factors.

    The block's four loops are split as i -> [2, 2, 2, 2], j -> [2, 4, 64, 2],
    k -> [32, 32] and t -> [2, 512]. The pieces are then reordered so that the i/j
    tile levels interleave with the k tiles, and both t tiles come innermost.
    """
    matmul_block = sch.get_block("matmul")
    loop_i, loop_j, loop_k, loop_t = sch.get_loops(block=matmul_block)
    i0, i1, i2, i3 = sch.split(loop=loop_i, factors=[2, 2, 2, 2])
    j0, j1, j2, j3 = sch.split(loop=loop_j, factors=[2, 4, 64, 2])
    k0, k1 = sch.split(loop=loop_k, factors=[32, 32])
    t0, t1 = sch.split(loop=loop_t, factors=[2, 512])
    sch.reorder(
        i0, j0,
        i1, j1,
        k0,
        i2, j2,
        k1,
        i3, j3,
        t0, t1,
    )
예제 #4
0
def test_meta_schedule_integration_apply_history_best():
    """Check that ``ApplyHistoryBest.query`` returns the module of the committed workload.

    A minimal list-backed ``PyDatabase`` is defined locally. One record for
    ``MockModule`` is committed, and the database is then queried with
    ``dispatched=[MockModule]``.
    """

    class DummyDatabase(PyDatabase):
        """List-backed database: records and workloads live in plain Python lists."""

        def __init__(self):
            super().__init__()
            self.records = []
            self.workload_reg = []

        def has_workload(self, mod: IRModule) -> bool:
            # Returns a bool; the previous `-> Workload` annotation was wrong.
            return any(
                tvm.ir.structural_equal(workload.mod, mod)
                for workload in self.workload_reg
            )

        def commit_tuning_record(self, record: TuningRecord) -> None:
            self.records.append(record)

        def commit_workload(self, mod: IRModule) -> Workload:
            # Reuse a structurally equal workload if one is already registered.
            for workload in self.workload_reg:
                if tvm.ir.structural_equal(workload.mod, mod):
                    return workload
            workload = Workload(mod)
            self.workload_reg.append(workload)
            return workload

        def get_top_k(self, workload: Workload,
                      top_k: int) -> List[TuningRecord]:
            # Keep only this workload's records, then rank them by mean run time.
            # Filtering first avoids evaluating (and possibly dividing by zero in)
            # the key for unrelated records. sort() is stable, so ties keep
            # commit order exactly as the sort-then-filter version did.
            matching = [
                record for record in self.records if record.workload == workload
            ]
            matching.sort(
                key=lambda record: sum(record.run_secs) / len(record.run_secs))
            return matching[:int(top_k)]

        def __len__(self) -> int:
            return len(self.records)

        def print_results(self) -> None:
            print("\n".join([str(r) for r in self.records]))

    mod, _, _, _ = get_network(
        name="resnet-18",
        batch_size=1,
        layout="NHWC",
        dtype="float32",
    )
    database = DummyDatabase()
    env = ApplyHistoryBest(database)
    target = Target("llvm")
    workload = database.commit_workload(MockModule)
    database.commit_tuning_record(
        TuningRecord(Schedule(MockModule).trace, [1.0], workload, target, []))
    mod = env.query(task_name="mock-task",
                    mod=mod,
                    target=target,
                    dispatched=[MockModule])
    assert tvm.ir.structural_equal(mod, workload.mod)
예제 #5
0
def _schedule_batch_matmul(sch: Schedule):
    """Tile and reorder the ``matmul`` block of ``sch`` in place with fixed factors.

    The four loops are split as i -> [2, 2, 2, 2], j -> [2, 4, 64, 2],
    k -> [32, 32] and t -> [2, 512]. The pieces are then reordered so that the i/j
    tile levels interleave with the k tiles, and both t tiles come innermost.
    """
    matmul_block = sch.get_block("matmul")
    loop_i, loop_j, loop_k, loop_t = sch.get_loops(block=matmul_block)
    # TODO(@zxybazh): Change to `sample_perfect_tile` after upstreaming
    i0, i1, i2, i3 = sch.split(loop=loop_i, factors=[2, 2, 2, 2])
    j0, j1, j2, j3 = sch.split(loop=loop_j, factors=[2, 4, 64, 2])
    k0, k1 = sch.split(loop=loop_k, factors=[32, 32])
    t0, t1 = sch.split(loop=loop_t, factors=[2, 512])
    sch.reorder(
        i0, j0,
        i1, j1,
        k0,
        i2, j2,
        k1,
        i3, j3,
        t0, t1,
    )
예제 #6
0
def test_get_tensorize_loop_mapping_matmul_mma():
    """Map a 16x16x16 fp16 MMA description onto a 512^3 matmul, before and after reordering."""

    @T.prim_func
    def matmul_16x16x16xf16f16f16_desc(
        A: T.Buffer((16, 16), "float16", align=128, offset_factor=1),
        B: T.Buffer((16, 16), "float16", align=128, offset_factor=1),
        C: T.Buffer((16, 16), "float16", align=128, offset_factor=1),
    ) -> None:
        with T.block("root"):
            T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16])
            T.writes(C[0:16, 0:16])
            for i, j, k in T.grid(16, 16, 16):
                with T.block("update"):
                    vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
                    C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk]

    workload = te_workload.matmul_relu(n=512, m=512, k=512)
    matmul = create_prim_func(workload)

    s = Schedule(matmul)
    block = s.get_block("C")
    i0, i1, i2 = s.get_loops(block)
    workload_loops = (i0, i1, i2)
    desc_loops = collect_loops(matmul_16x16x16xf16f16f16_desc)

    # The mapping must not depend on loop order. Check the original nest first,
    # then the same nest permuted to (i2, i0, i1).
    for do_reorder in (False, True):
        if do_reorder:
            s.reorder(i2, i0, i1)

        info = get_tensorize_loop_mapping(s, block,
                                          matmul_16x16x16xf16f16f16_desc)
        assert info is not None
        # Invert loop_map: description loop -> workload loop sref.
        desc_loop_to_sref = {
            desc_loop: sref for sref, desc_loop in info.loop_map.items()
        }

        for idx in range(3):
            assert desc_loops[idx] in desc_loop_to_sref

        for idx, loop in enumerate(workload_loops):
            assert s.get(desc_loop_to_sref[desc_loops[idx]]) == s.get(loop)
예제 #7
0
def test_meta_schedule_integration_apply_history_best():
    """Check that ``ApplyHistoryBest.query`` returns the module of the committed workload."""
    mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
    database = DummyDatabase()
    env = ApplyHistoryBest(database)
    target = Target("llvm")

    # Register MockModule and commit exactly one record for it.
    workload = database.commit_workload(MockModule)
    record = TuningRecord(Schedule(MockModule).trace, [1.0], workload, target, [])
    database.commit_tuning_record(record)

    queried = env.query(
        task_name="mock-task",
        mod=mod,
        target=target,
        dispatched=[MockModule],
    )
    assert tvm.ir.structural_equal(queried, workload.mod)
예제 #8
0
파일: utils.py 프로젝트: junrushao1994/tvm
def apply_fixed_schedules(
    relay_mod: Union[RelayFunc, IRModule],
    target: Union[str, Target],
    params: Optional[Dict[str, NDArray]],
    schedule_fn: Callable[[ms.ExtractedTask, Schedule], bool],
    te_filter_func=None,
):
    """Apply fixed schedules (manually written, without any tunable knobs) as specified by
    schedule_fn to extracted tasks, and return a database that can be passed to ApplyHistoryBest.

    Parameters
    ----------
    relay_mod : Union[RelayFunc, IRModule]
        The Relay module to apply fixed schedules.
    target : Union[str, Target]
        The target used to extract tasks. A string is converted with ``Target(...)``.
    params : Optional[Dict[str, tvm.runtime.NDArray]]
        The associated parameters of the module.
    schedule_fn : Callable[[ExtractedTask, Schedule], bool]
        A callable that is applied for each extracted task and the corresponding default schedule.
        Returns True if the given schedule should be committed to the database, False otherwise.
    te_filter_func : Union[str, None, Callable[[List[Tensor]], PrimFunc]] = None
        The filtering function for TE computation
        If it's a string, it's the name of the filtering function. Built in functions are
          - "meta_schedule.DefaultTaskFilter"
          - "meta_schedule.DefaultTaskFilterAllowExtern"
        If it's None, it's the default filtering function
        If it's a callable, it's the filtering function

    Returns
    -------
    database : Database
        A ``MemoryDatabase`` containing one dummy tuning record (run time ``[0.0]``)
        for each task whose schedule was accepted by ``schedule_fn``.
    """
    target = Target(target) if isinstance(target, str) else target
    extracted_tasks = ms.extract_task_from_relay(
        relay_mod,
        target,
        params,
        te_filter_func=te_filter_func,
    )
    database = ms.database.MemoryDatabase()
    for task in extracted_tasks:
        # Only the first dispatched module of each task is scheduled.
        mod = ms.default_config.mod(task.dispatched[0])
        sch = Schedule(mod)

        if schedule_fn(task, sch):
            workload = database.commit_workload(mod)
            # Positional order (trace, workload, run_secs, target, args_info).
            # NOTE(review): assumed to match this TVM version's TuningRecord
            # signature; confirm. The run time [0.0] is a placeholder, not a
            # measurement.
            tune_rec = ms.database.TuningRecord(sch.trace, workload, [0.0],
                                                target, [])
            database.commit_tuning_record(tune_rec)

    return database
def test_tile_with_tensor_intrin_conv2d_nchwc_vnni():
    """Tile the NCHWc int8 conv2d block for the VNNI 16x4 dot-product intrinsic."""
    sch = Schedule(Conv2dNCHWcVNNIModule)
    conv_block = sch.get_block("conv2d_NCHWc_int8")
    intrin_loop = tile_with_tensor_intrin(sch, conv_block, VNNI_DOT_16x4_INTRIN)

    loops_after_tiling = sch.get_loops(conv_block)
    # Tiling leaves 12 loops, and the returned loop is the second-innermost one.
    assert len(loops_after_tiling) == 12
    assert sch.get(intrin_loop) == sch.get(loops_after_tiling[-2])
    tvm.ir.assert_structural_equal(sch.mod, Conv2dNCHWcVNNIModuleTiled)
def _sch(decision: int) -> Schedule:
    """Build the schedule of ``add`` for one ``sample_compute_location`` decision.

    Block ``move`` is moved with ``compute_at`` to the loop returned by
    ``sample_compute_location`` for the given ``decision``.
    """
    sch = Schedule(add, debug_mask="all")
    move_block = sch.get_block(name="move", func_name="main")
    target_loop = sch.sample_compute_location(block=move_block, decision=decision)
    sch.compute_at(block=move_block, loop=target_loop, preserve_unit_loops=True)
    return sch
def test_tile_with_tensor_intrin_dense_vnni():
    """Tile the dense ``compute`` block for the VNNI 16x4 dot-product intrinsic."""
    sch = Schedule(DenseVNNIModule)
    dense_block = sch.get_block("compute")

    intrin_loop = tile_with_tensor_intrin(sch, dense_block, VNNI_DOT_16x4_INTRIN)

    # After tiling the block sits under five loops. The loop returned by the
    # tiling helper must be the fourth one.
    _, _, _, expected_loop, _ = sch.get_loops(dense_block)
    assert sch.get(intrin_loop) == sch.get(expected_loop)
    tvm.ir.assert_structural_equal(sch.mod, DenseVNNIModuleTiled)
예제 #12
0
def _schedule_matmul(sch: Schedule):
    """Tile and reorder the ``matmul`` block of ``sch`` in place with fixed factors.

    The loops are split as i -> [1, 1, 2, 512], j -> [1, 512, 1, 2] and
    k -> [256, 4]. The i/j tile levels are then interleaved with the k tiles.
    """
    matmul_block = sch.get_block("matmul")
    loop_i, loop_j, loop_k = sch.get_loops(block=matmul_block)
    i0, i1, i2, i3 = sch.split(loop=loop_i, factors=[1, 1, 2, 512])
    j0, j1, j2, j3 = sch.split(loop=loop_j, factors=[1, 512, 1, 2])
    k0, k1 = sch.split(loop=loop_k, factors=[256, 4])
    sch.reorder(
        i0, j0,
        i1, j1,
        k0,
        i2, j2,
        k1,
        i3, j3,
    )
예제 #13
0
def apply_fixed_schedules(
    relay_mod: Union[RelayFunc, IRModule],
    target: Union[str, Target],
    params: Optional[Dict[str, NDArray]],
    schedule_fn: Callable[[ExtractedTask, Schedule], bool],
):
    """Apply fixed schedules (manually written, without any tunable knobs) as specified by
    schedule_fn to extracted tasks, and return a database that can be passed to ApplyHistoryBest.

    Parameters
    ----------
    relay_mod : Union[RelayFunc, IRModule]
        The Relay module to apply fixed schedules.
    target : Union[str, Target]
        The target used to extract tasks. A string is converted with ``Target(...)``.
    params : Optional[Dict[str, tvm.runtime.NDArray]]
        The associated parameters of the module.
    schedule_fn : Callable[[ExtractedTask, Schedule], bool]
        A callable that is applied for each extracted task and the corresponding default schedule.
        Returns True if the given schedule should be committed to the database, False otherwise.

    Returns
    -------
    database : Database
        A ``DummyDatabase`` containing one dummy tuning record (run time ``[0.0]``)
        for each task whose schedule was accepted by ``schedule_fn``.
    """
    target = Target(target) if isinstance(target, str) else target
    extracted_tasks = extract_task_from_relay(relay_mod, target, params)

    database = DummyDatabase()

    for task in extracted_tasks:
        # Only the first dispatched module of each task is scheduled.
        mod = Parse._mod(task.dispatched[0])  # pylint: disable=protected-access
        sch = Schedule(mod)

        if schedule_fn(task, sch):
            workload = database.commit_workload(mod)
            # Positional order (trace, run_secs, workload, target, args_info).
            # NOTE(review): assumed to match this TVM version's TuningRecord
            # signature; confirm. The run time [0.0] is a placeholder, not a
            # measurement.
            tune_rec = TuningRecord(sch.trace, [0.0], workload, target, [])
            database.commit_tuning_record(tune_rec)

    return database
예제 #14
0
def test_get_tensorize_loop_mapping_dense_vnni():
    """Map the u8/i8 -> i32 16x4 dot-product description onto the dense block's j/k loops."""
    sch = Schedule(DenseVNNIModule)
    dense_block = sch.get_block("compute")

    info = get_tensorize_loop_mapping(sch, dense_block, dot_product_16x4_u8i8i32_desc)
    assert isinstance(info, TensorizeInfo)

    # Invert loop_map: description loop -> workload loop sref.
    sref_of = {desc_loop: sref for sref, desc_loop in info.loop_map.items()}

    desc_loops = collect_loops(dot_product_16x4_u8i8i32_desc)
    _, loop_j, loop_k = sch.get_loops(dense_block)

    assert desc_loops[0] in sref_of and desc_loops[1] in sref_of
    assert sch.get(sref_of[desc_loops[0]]) == sch.get(loop_j)
    assert sch.get(sref_of[desc_loops[1]]) == sch.get(loop_k)
예제 #15
0
def test_get_tensorize_loop_mapping_conv2d_nchwc_vnni():
    """Map the u8/i8 -> i32 16x4 dot-product description onto the NCHWc conv2d loops i4/i9."""
    s = Schedule(Conv2dNCHWcVNNIModule)
    block = s.get_block("conv2d_NCHWc_int8")

    info = get_tensorize_loop_mapping(s, block, dot_product_16x4_u8i8i32_desc)
    # Fail with a clear assertion, as the sibling tests do, rather than with an
    # AttributeError on `info.loop_map` when no mapping is found.
    assert info is not None

    # Invert loop_map: description loop -> workload loop sref.
    desc_loop_to_sref = {v: k for k, v in info.loop_map.items()}

    desc_loops = collect_loops(dot_product_16x4_u8i8i32_desc)

    # i4 corresponds to the inner output channel axis of the NCHWc output tensor
    # for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4):
    _, _, _, _, i4, _, _, _, _, i9 = s.get_loops(block)

    assert desc_loops[0] in desc_loop_to_sref and desc_loops[1] in desc_loop_to_sref
    assert s.get(desc_loop_to_sref[desc_loops[0]]) == s.get(i4)
    assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(i9)
예제 #16
0
def tune_tir(
    mod: Union[IRModule, PrimFunc],
    target: Union[str, Target],
    config: TuneConfig,
    work_dir: str,
    *,
    builder: Optional[Builder] = None,
    runner: Optional[Runner] = None,
    database: Optional[Database] = None,
    cost_model: Optional[CostModel] = None,
    measure_callbacks: Optional[List[MeasureCallback]] = None,
    space: Optional[FnSpaceGenerator] = None,
    sch_rules: Optional[FnScheduleRule] = None,
    postprocs: Optional[FnPostproc] = None,
    mutator_probs: Optional[FnMutatorProb] = None,
    task_name: str = "main",
    num_threads: Optional[int] = None,
) -> Optional[Schedule]:
    """Tune a TIR IRModule with a given target.

    The module is wrapped in a single ``ExtractedTask`` (weight 1) and tuned with
    ``tune_extracted_tasks``. The best record found for it is then replayed on a
    fresh ``Schedule``.

    Parameters
    ----------
    mod : Union[IRModule, PrimFunc]
        The module to tune.
    target : Union[str, Target]
        The target to tune for.
    config : TuneConfig
        The search strategy config. Its loggers are created under ``work_dir/logs``.
    work_dir : str
        The working directory to save intermediate results.
    builder : Optional[Builder]
        The builder to use.
    runner : Optional[Runner]
        The runner to use.
    database : Optional[Database]
        The database to use.
    cost_model : Optional[CostModel]
        The cost model to use.
    measure_callbacks : Optional[List[MeasureCallback]]
        The callbacks used during tuning.
    space : Optional[FnSpaceGenerator]
        The space generator to use.
    sch_rules : Optional[FnScheduleRule]
        The schedule rules to use.
    postprocs : Optional[FnPostproc]
        The postprocessors to use.
    mutator_probs : Optional[FnMutatorProb]
        The probability distribution over mutators.
    task_name : str
        The name of the tuning task. It is also used in the logger name.
    num_threads : Optional[int]
        The number of threads to use.

    Returns
    -------
    sch : Optional[Schedule]
        The tuned schedule, or None if the database holds no record for ``mod``.
    """
    # logging directory is set to `work_dir/logs` by default
    log_dir = osp.join(work_dir, "logs")
    os.makedirs(log_dir, exist_ok=True)

    config.create_loggers(
        log_dir=log_dir,
        params=[{
            "log_dir": log_dir,
            "logger_name": __name__ + f".task_{task_name}"
        }],
    )

    # pylint: disable=protected-access
    mod = Parse._mod(mod)
    target = Parse._target(target)
    # pylint: enable=protected-access
    database = tune_extracted_tasks(
        extracted_tasks=[
            ExtractedTask(
                task_name=task_name,
                mod=mod,
                dispatched=[mod],
                target=target,
                weight=1,
            ),
        ],
        config=config,
        work_dir=work_dir,
        builder=builder,
        runner=runner,
        database=database,
        cost_model=cost_model,
        measure_callbacks=measure_callbacks,
        space=space,
        sch_rules=sch_rules,
        postprocs=postprocs,
        mutator_probs=mutator_probs,
        num_threads=num_threads,
    )
    bests: List[TuningRecord] = database.get_top_k(
        database.commit_workload(mod),
        top_k=1,
    )
    if not bests:
        return None
    assert len(bests) == 1
    # Replay the best trace, keeping its postprocessing instructions.
    sch = Schedule(mod)
    bests[0].trace.apply_to_schedule(sch, remove_postproc=False)
    return sch
예제 #17
0
def _sch(decisions: List[List[int]]) -> Schedule:
    """Replay a fixed tiling trace on ``matmul`` with the given tile decisions.

    Variable names (b0, l2, v5, ...) mirror the printed trace and are kept on
    purpose, hence the pylint pragmas.

    Parameters
    ----------
    decisions : List[List[int]]
        Exactly three decisions ``[d0, d1, d2]``. They are used for the
        ``sample_perfect_tile`` calls on the three loops of block ``C``, which
        produce 4, 4 and 2 factors respectively.

    Returns
    -------
    Schedule
        The schedule, created with ``debug_mask="all"``, after all instructions below.
    """
    sch = Schedule(matmul, debug_mask="all")
    # pylint: disable=invalid-name
    d0, d1, d2 = decisions
    b0 = sch.get_block(name="C", func_name="main")
    root = sch.get_block(name="root", func_name="main")
    # Result is discarded; presumably called so the instruction appears in the
    # trace being reproduced. TODO confirm.
    sch.get_consumers(block=b0)
    # Stage C's output through a cache block (b1) in "global" scope.
    b1 = sch.cache_write(block=b0,
                         write_buffer_index=0,
                         storage_scope="global")
    l2, l3, l4 = sch.get_loops(block=b0)
    # First loop: 4-level tiling with decision d0.
    v5, v6, v7, v8 = sch.sample_perfect_tile(
        loop=l2,
        n=4,
        max_innermost_factor=64,
        decision=d0,
    )
    l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])
    # Second loop: 4-level tiling with decision d1.
    v13, v14, v15, v16 = sch.sample_perfect_tile(
        loop=l3,
        n=4,
        max_innermost_factor=64,
        decision=d1,
    )
    l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])
    # Third loop: 2-level tiling with decision d2.
    v21, v22 = sch.sample_perfect_tile(
        loop=l4,
        n=2,
        max_innermost_factor=64,
        decision=d2,
    )
    l23, l24 = sch.split(loop=l4, factors=[v21, v22])
    # Interleave the tile levels of the first two loops with the third loop's tiles.
    sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)
    # Place the cache-write block under l18 (second tile level of the second loop).
    sch.reverse_compute_at(block=b1, loop=l18, preserve_unit_loops=True)
    # Explicit-unroll setting for the root block. decision=0 presumably selects
    # index 0 of `candidates`; verify against sample_categorical docs.
    v57 = sch.sample_categorical(
        candidates=[0, 16, 64, 512],
        probs=[0.25, 0.25, 0.25, 0.25],
        decision=0,
    )
    sch.annotate(block_or_loop=root,
                 ann_key="meta_schedule.unroll_explicit",
                 ann_val=v57)
    # pylint: enable=invalid-name
    return sch
예제 #18
0
파일: tune.py 프로젝트: wenxcs/tvm
def tune_tir(
    mod: Union[IRModule, PrimFunc],
    target: Union[str, Target],
    config: SearchStrategyConfig,
    work_dir: str,
    *,
    task_name: str = "main",
    builder: Optional[Builder] = None,
    runner: Optional[Runner] = None,
    database: Optional[Database] = None,
    cost_model: Optional[CostModel] = None,
    measure_callbacks: Optional[List[MeasureCallback]] = None,
    task_scheduler: Optional[TaskScheduler] = None,
    space: Optional[FnSpaceGenerator] = None,
    sch_rules: Optional[FnScheduleRule] = None,
    postprocs: Optional[FnPostproc] = None,
    mutator_probs: Optional[FnMutatorProb] = None,
    num_threads: Optional[int] = None,
) -> Optional[Schedule]:
    """Tune a TIR IRModule with a given target.

    A single ``TuneContext`` is built for ``mod`` and tuned by a task scheduler.
    The best record is then replayed on a fresh ``Schedule``. When a best record
    exists, the task scheduler's cost model is also saved to
    ``work_dir/<task_name>.xgb``.

    Parameters
    ----------
    mod : Union[IRModule, PrimFunc]
        The module to tune.
    target : Union[str, Target]
        The target to tune for.
    config : SearchStrategyConfig
        The search strategy config.
    work_dir : str
        The working directory to save intermediate results.
    task_name : str
        The name of the tuning task. It is also used for the saved cost model file name.
    builder : Optional[Builder]
        The builder to use.
    runner : Optional[Runner]
        The runner to use.
    database : Optional[Database]
        The database to use. It is resolved via ``Parse._database`` together with
        ``task_name`` and ``work_dir`` (presumably a default is created when None).
    cost_model : Optional[CostModel]
        The cost model to use.
    measure_callbacks : Optional[List[MeasureCallback]]
        The callbacks used during tuning.
    task_scheduler : Optional[TaskScheduler]
        The task scheduler to use, resolved via ``Parse._task_scheduler``.
    space : Optional[FnSpaceGenerator]
        The space generator to use.
    sch_rules : Optional[FnScheduleRule]
        The schedule rules to use.
    postprocs : Optional[FnPostproc]
        The postprocessors to use.
    mutator_probs : Optional[FnMutatorProb]
        The probability distribution over mutators.
    num_threads : Optional[int]
        The number of threads to use.

    Returns
    -------
    sch : Optional[Schedule]
        The tuned schedule, or None if the database holds no record for ``mod``.
    """

    logger.info("Working directory: %s", work_dir)
    # pylint: disable=protected-access
    mod = Parse._mod(mod)
    database = Parse._database(database, task_name, work_dir)
    tune_context = Parse._tune_context(
        tune_context=None,
        mod=mod,
        target=Parse._target(target),
        config=config,
        task_name=task_name,
        space_generator=space,
        sch_rules=sch_rules,
        postprocs=postprocs,
        mutator_probs=mutator_probs,
        num_threads=num_threads,
    )
    task_scheduler = Parse._task_scheduler(
        task_scheduler,
        [tune_context],
        builder=Parse._builder(builder),
        runner=Parse._runner(runner),
        database=database,
        cost_model=Parse._cost_model(cost_model),
        measure_callbacks=Parse._callbacks(measure_callbacks),
    )
    # pylint: enable=protected-access
    task_scheduler.tune()
    bests: List[TuningRecord] = database.get_top_k(
        database.commit_workload(mod),
        top_k=1,
    )
    if not bests:
        return None
    assert len(bests) == 1
    # Replay the best trace, keeping its postprocessing instructions.
    sch = Schedule(mod)
    bests[0].trace.apply_to_schedule(sch, remove_postproc=False)
    task_scheduler.cost_model.save(os.path.join(work_dir, f"{task_name}.xgb"))
    return sch
예제 #19
0
def _sch(decisions: List[List[int]]) -> Schedule:
    """Replay a fixed tiling trace on ``matmul`` where only the first loop's tiling is sampled.

    Variable names (b0, l2, v5, ...) mirror the printed trace and are kept on
    purpose, hence the pylint pragmas.

    Parameters
    ----------
    decisions : List[List[int]]
        Exactly one decision ``[d0]``, used for the 4-level
        ``sample_perfect_tile`` on the first loop of block ``C``. The other two
        loops are split with hard-coded factors.

    Returns
    -------
    Schedule
        The schedule, created with ``debug_mask="all"``, after all instructions below.
    """
    sch = Schedule(matmul, debug_mask="all")
    # pylint: disable=invalid-name
    (d0,) = decisions
    b0 = sch.get_block(name="C", func_name="main")
    # Result is discarded; presumably called so the instruction appears in the
    # trace being reproduced. TODO confirm.
    sch.get_consumers(block=b0)
    # Stage C's output through a cache block (b1) in "global" scope.
    b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")
    l2, l3, l4 = sch.get_loops(block=b0)
    v5, v6, v7, v8 = sch.sample_perfect_tile(
        loop=l2,
        n=4,
        max_innermost_factor=64,
        decision=d0,
    )
    l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])
    # Fixed (non-sampled) tilings for the second and third loops.
    l17, l18, l19, l20 = sch.split(loop=l3, factors=[8, 4, 8, 2])
    l23, l24 = sch.split(loop=l4, factors=[512, 1])
    sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)
    # Place the cache-write block under l18 (second tile level of the second loop).
    sch.reverse_compute_at(block=b1, loop=l18, preserve_unit_loops=True)
    # pylint: enable=invalid-name
    return sch
def _sch() -> Schedule:
    """Build the schedule of ``element_wise`` with a sampled per-thread-block extent.

    The two loops of block ``C`` are fused. The fused loop is split by a factor
    drawn uniformly from six candidates via ``sample_categorical`` (fixed
    ``decision=3``), and the pieces are bound to ``blockIdx.x`` and ``threadIdx.x``.
    """
    sch = Schedule(element_wise, debug_mask="all")
    block_c = sch.get_block(name="C", func_name="main")
    outer, inner = sch.get_loops(block=block_c)
    fused = sch.fuse(outer, inner)
    thread_candidates = [32, 64, 128, 256, 512, 1024]
    # 1 / 6 == 0.16666666666666666 exactly, so this matches a literal uniform list.
    thread_extent = sch.sample_categorical(
        candidates=thread_candidates,
        probs=[1 / 6] * len(thread_candidates),
        decision=3,
    )
    block_loop, thread_loop = sch.split(loop=fused, factors=[None, thread_extent])
    sch.bind(loop=block_loop, thread_axis="blockIdx.x")
    sch.bind(loop=thread_loop, thread_axis="threadIdx.x")
    return sch
예제 #21
0
파일: tune.py 프로젝트: chenghanpeng/tvm
def tune_tir(
    mod: Union[IRModule, PrimFunc],
    target: Union[str, Target],
    config: TuneConfig,
    work_dir: str,
    *,
    builder: Optional[Builder] = None,
    runner: Optional[Runner] = None,
    database: Optional[Database] = None,
    cost_model: Optional[CostModel] = None,
    measure_callbacks: Optional[List[MeasureCallback]] = None,
    space: Optional[FnSpaceGenerator] = None,
    blocks: Optional[List[str]] = None,
    sch_rules: Optional[FnScheduleRule] = None,
    postprocs: Optional[FnPostproc] = None,
    mutator_probs: Optional[FnMutatorProb] = None,
    task_name: str = "main",
    num_threads: Optional[int] = None,
) -> Optional[Schedule]:
    """Tune a TIR IRModule with a given target.

    The module is wrapped in a single ``ExtractedTask`` (weight 1) and tuned with
    ``tune_extracted_tasks``. The best record is then replayed on a fresh
    ``Schedule`` inside a ``Profiler.timeit("ApplyHistoryBest")`` scope.

    Parameters
    ----------
    mod : Union[IRModule, PrimFunc]
        The module to tune.
    target : Union[str, Target]
        The target to tune for.
    config : TuneConfig
        The search strategy config. Its loggers are created under ``work_dir/logs``.
    work_dir : str
        The working directory to save intermediate results.
    builder : Optional[Builder]
        The builder to use.
    runner : Optional[Runner]
        The runner to use.
    database : Optional[Database]
        The database to use.
    cost_model : Optional[CostModel]
        The cost model to use.
    measure_callbacks : Optional[List[MeasureCallback]]
        The callbacks used during tuning.
    space : Optional[FnSpaceGenerator]
        The space generator to use.
    blocks : Optional[List[str]]
        Names of the blocks to tune. When given, a ``PostOrderApply`` space that
        only visits blocks with these names is used, so other blocks are not
        tuned. It must not be combined with ``space``.
    sch_rules : Optional[FnScheduleRule]
        The schedule rules to use.
    postprocs : Optional[FnPostproc]
        The postprocessors to use.
    mutator_probs : Optional[FnMutatorProb]
        The probability distribution over mutators.
    task_name : str
        The name of the tuning task. It is also used in the logger name.
    num_threads : Optional[int]
        The number of threads to use.

    Returns
    -------
    sch : Optional[Schedule]
        The tuned schedule, or None if the database holds no record for ``mod``.
    """
    # Loggers write into `work_dir/logs`.
    log_dir = osp.join(work_dir, "logs")
    os.makedirs(log_dir, exist_ok=True)
    logger_params = {
        "log_dir": log_dir,
        "logger_name": __name__ + f".task_{task_name}",
    }
    config.create_loggers(log_dir=log_dir, params=[logger_params])

    if blocks is not None:
        assert space is None, "Can not specify blocks to tune when a search space is given."
        # Restrict the search space to blocks whose name is listed in `blocks`.
        space = PostOrderApply(
            f_block_filter=lambda block: block.name_hint in blocks)

    mod = default_config.mod(mod)
    target = default_config.target(target)
    task = ExtractedTask(
        task_name=task_name,
        mod=mod,
        dispatched=[mod],
        target=target,
        weight=1,
    )
    database = tune_extracted_tasks(
        extracted_tasks=[task],
        config=config,
        work_dir=work_dir,
        builder=builder,
        runner=runner,
        database=database,
        cost_model=cost_model,
        measure_callbacks=measure_callbacks,
        space=space,
        sch_rules=sch_rules,
        postprocs=postprocs,
        mutator_probs=mutator_probs,
        num_threads=num_threads,
    )
    with Profiler.timeit("ApplyHistoryBest"):
        workload = database.commit_workload(mod)
        bests: List[TuningRecord] = database.get_top_k(workload, top_k=1)
        if not bests:
            return None
        assert len(bests) == 1
        # Replay the best trace, keeping its postprocessing instructions.
        sch = Schedule(mod)
        bests[0].trace.apply_to_schedule(sch, remove_postproc=False)
    return sch