Beispiel #1
0
def tune_each_task(
    mod,
    target,
    config,
    runner,
    work_dir,
    params,
):
    """Tune every task extracted from a Relay module one by one, then build it.

    Tasks are extracted with ``ms.extract_task_from_relay``. Each task gets its
    own tune context and its own task scheduler (weight ``1.0``), and every
    scheduler records into the same JSON database stored under ``work_dir``.
    Once all tasks have been tuned, the module is built while the database is
    applied through ``ms.ApplyHistoryBest``.

    Parameters
    ----------
    mod
        The module that tasks are extracted from and that is finally built.
    target
        The target used for extraction, tuning and building. It is entered as
        a context manager for the final build, so it presumably has to be a
        target object rather than a plain string -- TODO confirm with callers.
    config
        The tuning config. It is forwarded to every tune context, and its
        ``max_trials_per_task`` attribute is used as ``max_trials`` of every
        task scheduler.
    runner
        The runner, passed through ``ms.tune.Parse._runner`` for each task.
    work_dir
        Directory in which ``default_database_workload.json`` and
        ``default_database_tuning_record.json`` are located.
    params
        The parameters handed to task extraction and to the final build.

    Returns
    -------
    The value returned by ``relay_build`` for ``mod``.
    """
    tasks = ms.extract_task_from_relay(mod, target, params)

    workload_json = os.path.join(work_dir, "default_database_workload.json")
    record_json = os.path.join(work_dir, "default_database_tuning_record.json")
    database = ms.database.JSONDatabase(
        path_workload=workload_json,
        path_tuning_record=record_json,
    )

    for task in tasks:
        # pylint: disable=protected-access
        parse = ms.tune.Parse
        # Only the first dispatched module of the task is tuned.
        task_mod = parse._mod(task.dispatched[0])
        context = parse._tune_context(
            tune_context=None,
            mod=task_mod,
            target=target,
            config=config,
            task_name=task.task_name,
            space_generator=None,
            sch_rules=None,
            postprocs=None,
            mutator_probs=None,
            num_threads=os.cpu_count(),
        )
        # Fresh builder / runner / cost model / callbacks are parsed for every
        # task, in the same order in which the scheduler receives them.
        builder = parse._builder(None)
        parsed_runner = parse._runner(runner)
        trial_budget = config.max_trials_per_task
        cost_model = parse._cost_model(None)
        callbacks = parse._callbacks(None)
        scheduler = parse._task_scheduler(
            None,
            [context],
            task_weights=[1.0],
            builder=builder,
            runner=parsed_runner,
            database=database,
            max_trials=trial_budget,
            cost_model=cost_model,
            measure_callbacks=callbacks,
        )
        # pylint: enable=protected-access
        scheduler.tune()

    # Context managers are entered left to right, exactly like nested blocks.
    with target, ms.ApplyHistoryBest(database), PassContext(
        opt_level=3,
        config={"relay.backend.use_meta_schedule": True},
    ):
        return relay_build(mod, target=target, params=params)
Beispiel #2
0
def tune_relay(
    mod: IRModule,
    target: Union[str, Target],
    config: TuneConfig,
    work_dir: str,
    *,
    params: Optional[Dict[str, NDArray]] = None,
    builder: Optional[Builder] = None,
    runner: Optional[Runner] = None,
    database: Optional[Database] = None,
    cost_model: Optional[CostModel] = None,
    measure_callbacks: Optional[List[MeasureCallback]] = None,
    space: Optional[FnSpaceGenerator] = None,
    sch_rules: Optional[FnScheduleRule] = None,
    postprocs: Optional[FnPostproc] = None,
    mutator_probs: Optional[FnMutatorProb] = None,
    num_threads: Optional[int] = None,
) -> Module:
    """Tune a Relay workload for a target and build it with the tuning results.

    Tasks are extracted from ``mod``, tuned through ``tune_extracted_tasks``,
    and the module is then built while the resulting database is applied via
    ``ApplyHistoryBest``.

    Parameters
    ----------
    mod : IRModule
        The module to extract tasks from and to build.
    target : Union[str, Target]
        The target to tune for; it is normalised with ``Parse._target`` first.
    config : TuneConfig
        The tuning config, forwarded to ``tune_extracted_tasks``.
    work_dir : str
        The working directory, forwarded to ``tune_extracted_tasks``.
    params : Optional[Dict[str, tvm.runtime.NDArray]]
        The associated parameters of the program, used for task extraction
        and for the final build.
    builder : Optional[Builder]
        The builder to use.
    runner : Optional[Runner]
        The runner to use.
    database : Optional[Database]
        The database to use.
    cost_model : Optional[CostModel]
        The cost model to use.
    measure_callbacks : Optional[List[MeasureCallback]]
        The callbacks used during tuning.
    space : Optional[FnSpaceGenerator]
        Forwarded unchanged to ``tune_extracted_tasks``.
    sch_rules : Optional[FnScheduleRule]
        Forwarded unchanged to ``tune_extracted_tasks``.
    postprocs : Optional[FnPostproc]
        Forwarded unchanged to ``tune_extracted_tasks``.
    mutator_probs : Optional[FnMutatorProb]
        Forwarded unchanged to ``tune_extracted_tasks``.
    num_threads : Optional[int]
        Forwarded unchanged to ``tune_extracted_tasks``.

    Returns
    -------
    lib : Module
        The built runtime module for the given relay workload.
    """
    # pylint: disable=import-outside-toplevel
    from tvm.relay import build as relay_build

    from .relay_integration import extract_task_from_relay

    # pylint: enable=import-outside-toplevel

    target = Parse._target(target)  # pylint: disable=protected-access

    # Every optional knob is handed to the tuner untouched.
    tuning_options = {
        "builder": builder,
        "runner": runner,
        "database": database,
        "cost_model": cost_model,
        "measure_callbacks": measure_callbacks,
        "space": space,
        "sch_rules": sch_rules,
        "postprocs": postprocs,
        "mutator_probs": mutator_probs,
        "num_threads": num_threads,
    }
    tasks = extract_task_from_relay(mod, target, params)
    database = tune_extracted_tasks(tasks, config, work_dir, **tuning_options)

    # Context managers are entered left to right, exactly like nested blocks.
    with target, autotvm_silencer(), ApplyHistoryBest(database), PassContext(
        opt_level=3,
        config={"relay.backend.use_meta_schedule": True},
    ):
        return relay_build(mod, target=target, params=params)
Beispiel #3
0
def tune_relay(
    mod: Union[RelayFunc, IRModule],
    target: Union[str, Target],
    config: SearchStrategyConfig,
    work_dir: str,
    *,
    params: Optional[Dict[str, NDArray]] = None,
    builder: Optional[Builder] = None,
    runner: Optional[Runner] = None,
    database: Optional[Database] = None,
    cost_model: Optional[CostModel] = None,
    measure_callbacks: Optional[List[MeasureCallback]] = None,
    task_scheduler: Optional[TaskScheduler] = None,
    space: Optional[FnSpaceGenerator] = None,
    sch_rules: Optional[FnScheduleRule] = None,
    postprocs: Optional[FnPostproc] = None,
    mutator_probs: Optional[FnMutatorProb] = None,
    num_threads: Optional[int] = None,
) -> Module:
    """Tune a Relay workload for a target and build it with the tuning results.

    The working directory is logged, tasks are extracted from ``mod`` and
    tuned through ``tune_extracted_tasks``, and the module is then built while
    the resulting database is applied via ``ApplyHistoryBest``.

    Parameters
    ----------
    mod : Union[RelayFunc, IRModule]
        The module to extract tasks from and to build.
    target : Union[str, Target]
        The target to tune for. It is passed on as given to task extraction,
        tuning and the final build.
    config : SearchStrategyConfig
        The search strategy config.
    work_dir : str
        The working directory, logged and forwarded to
        ``tune_extracted_tasks``.
    params : Optional[Dict[str, tvm.runtime.NDArray]]
        The associated parameters of the program, used for task extraction
        and for the final build.
    builder : Optional[Builder]
        The builder to use.
    runner : Optional[Runner]
        The runner to use.
    database : Optional[Database]
        The database to use.
    cost_model : Optional[CostModel]
        The cost model to use.
    measure_callbacks : Optional[List[MeasureCallback]]
        The callbacks used during tuning.
    task_scheduler : Optional[TaskScheduler]
        The task scheduler to use.
    space : Optional[FnSpaceGenerator]
        Forwarded unchanged to ``tune_extracted_tasks``.
    sch_rules : Optional[FnScheduleRule]
        Forwarded unchanged to ``tune_extracted_tasks``.
    postprocs : Optional[FnPostproc]
        Forwarded unchanged to ``tune_extracted_tasks``.
    mutator_probs : Optional[FnMutatorProb]
        Forwarded unchanged to ``tune_extracted_tasks``.
    num_threads : Optional[int]
        Forwarded unchanged to ``tune_extracted_tasks``.

    Returns
    -------
    lib : Module
        The built runtime module for the given relay workload.
    """
    logger.info("Working directory: %s", work_dir)

    # Every optional knob is handed to the tuner untouched.
    tuning_options = {
        "builder": builder,
        "runner": runner,
        "database": database,
        "cost_model": cost_model,
        "measure_callbacks": measure_callbacks,
        "task_scheduler": task_scheduler,
        "space": space,
        "sch_rules": sch_rules,
        "postprocs": postprocs,
        "mutator_probs": mutator_probs,
        "num_threads": num_threads,
    }
    tasks = extract_task_from_relay(mod, target, params)
    database = tune_extracted_tasks(tasks, target, config, work_dir, **tuning_options)

    # Context managers are entered left to right, exactly like nested blocks.
    with ApplyHistoryBest(database), tvm.transform.PassContext(
        opt_level=3,
        config={"relay.backend.use_meta_schedule": True},
    ):
        return relay_build(mod, target=target, params=params)