def _create_context(mod, target, rule):
    """Build a TuneContext around a single schedule rule.

    The context is named ``"test"``, uses a fresh ``PostOrderApply`` space
    generator, and carries ``rule`` as its only schedule rule.
    """
    return TuneContext(
        mod=mod,
        target=target,
        space_generator=PostOrderApply(),
        sch_rules=[rule],
        task_name="test",
    )
def _create_context(mod, target, rule):
    """Build a TuneContext for ``rule`` and initialize its components with it.

    The context is named ``"test"`` and uses a fresh ``PostOrderApply`` space
    generator. After construction, the space generator and every schedule
    rule held by the context are initialized with the context itself.
    """
    context = TuneContext(
        mod=mod,
        target=target,
        space_generator=PostOrderApply(),
        sch_rules=[rule],
        task_name="test",
    )
    # Initialize components through the context's own attributes so the same
    # objects the context holds are the ones that get initialized.
    context.space_generator.initialize_with_tune_context(context)
    for schedule_rule in context.sch_rules:
        schedule_rule.initialize_with_tune_context(context)
    return context
# Example No. 3
# 0
def tune_tir(
    mod: Union[IRModule, PrimFunc],
    target: Union[str, Target],
    config: TuneConfig,
    work_dir: str,
    *,
    builder: Optional[Builder] = None,
    runner: Optional[Runner] = None,
    database: Optional[Database] = None,
    cost_model: Optional[CostModel] = None,
    measure_callbacks: Optional[List[MeasureCallback]] = None,
    space: Optional[FnSpaceGenerator] = None,
    blocks: Optional[List[str]] = None,
    sch_rules: Optional[FnScheduleRule] = None,
    postprocs: Optional[FnPostproc] = None,
    mutator_probs: Optional[FnMutatorProb] = None,
    task_name: str = "main",
    num_threads: Optional[int] = None,
) -> Optional[Schedule]:
    """Tune a TIR IRModule with a given target.

    The module is wrapped into a single extracted task (weight 1) and tuned
    with ``tune_extracted_tasks``. The best record in the resulting database
    is then replayed onto a fresh ``Schedule`` of the module.

    Parameters
    ----------
    mod : Union[IRModule, PrimFunc]
        The module to tune.
    target : Union[str, Target]
        The target to tune for.
    config : TuneConfig
        The search strategy config. Its loggers are created under
        ``<work_dir>/logs``.
    work_dir : str
        The working directory to save intermediate results. It must be a real
        path, not None: the log directory ``<work_dir>/logs`` is created
        from it.
    builder : Optional[Builder]
        The builder to use.
    runner : Optional[Runner]
        The runner to use.
    database : Optional[Database]
        The database to use.
    cost_model : Optional[CostModel]
        The cost model to use.
    measure_callbacks : Optional[List[MeasureCallback]]
        The callbacks used during tuning.
    space : Optional[FnSpaceGenerator]
        The space generator to use.
    blocks : Optional[List[str]]
        A list of block names specifying blocks to be tuned. Note that if
        the list is not None, blocks outside this list will not be tuned.
        Only one of this argument and space may be provided.
    sch_rules : Optional[FnScheduleRule]
        The search rules to use.
    postprocs : Optional[FnPostproc]
        The postprocessors to use.
    mutator_probs : Optional[FnMutatorProb]
        The probability distribution to use different mutators.
    task_name : str
        The name of the single extracted tuning task. It is also used as the
        suffix of the task logger's name (``<this module>.task_<task_name>``).
    num_threads : Optional[int]
        The number of threads to use.

    Returns
    -------
    sch : Optional[Schedule]
        The tuned schedule, or None if the database holds no tuning record
        for the module after tuning.

    Raises
    ------
    AssertionError
        If both ``blocks`` and ``space`` are given. The check is an
        ``assert``, so it is skipped when Python runs with ``-O``.
    """
    # logging directory is set to `work_dir/logs` by default
    log_dir = osp.join(work_dir, "logs")
    os.makedirs(log_dir, exist_ok=True)

    config.create_loggers(
        log_dir=log_dir,
        params=[{
            "log_dir": log_dir,
            "logger_name": __name__ + f".task_{task_name}"
        }],
    )

    if blocks is not None:
        # NOTE(review): kept as `assert` to preserve the exception type;
        # under `python -O` a user-supplied `space` would be silently replaced.
        assert space is None, "Can not specify blocks to tune when a search space is given."

        # Restrict tuning to the blocks whose name appears in `blocks`.
        def _f_block_filter(block) -> bool:
            return block.name_hint in blocks

        # Create a space generator that targets specific blocks.
        space = PostOrderApply(f_block_filter=_f_block_filter)

    # Normalize the inputs through the default config helpers
    # (presumably PrimFunc -> IRModule and str -> Target; confirm in default_config).
    # pylint: disable=protected-access
    mod = default_config.mod(mod)
    target = default_config.target(target)
    # pylint: enable=protected-access
    database = tune_extracted_tasks(
        extracted_tasks=[
            ExtractedTask(
                task_name=task_name,
                mod=mod,
                dispatched=[mod],
                target=target,
                weight=1,
            ),
        ],
        config=config,
        work_dir=work_dir,
        builder=builder,
        runner=runner,
        database=database,
        cost_model=cost_model,
        measure_callbacks=measure_callbacks,
        space=space,
        sch_rules=sch_rules,
        postprocs=postprocs,
        mutator_probs=mutator_probs,
        num_threads=num_threads,
    )
    with Profiler.timeit("ApplyHistoryBest"):
        # Look up the single best record for this module's workload.
        bests: List[TuningRecord] = database.get_top_k(
            database.commit_workload(mod), top_k=1)
        if not bests:
            return None
        assert len(bests) == 1
        # Replay the winning trace (postprocessors included) on a fresh schedule.
        sch = Schedule(mod)
        bests[0].trace.apply_to_schedule(sch, remove_postproc=False)
    return sch