def _create_context(mod, target, rule):
    """Return a TuneContext that wraps a single schedule rule.

    The context uses a default ``PostOrderApply`` space generator, ``rule`` as
    its only schedule rule, and the task name ``"test"``. No component is
    initialized against the context here.

    NOTE(review): another ``_create_context`` with the same signature is
    defined right after this one and shadows it, so this version is never the
    one that callers reach.
    """
    return TuneContext(
        mod=mod,
        target=target,
        space_generator=PostOrderApply(),
        sch_rules=[rule],
        task_name="test",
    )
def _create_context(mod, target, rule):
    """Return a TuneContext for one schedule rule, with its components initialized.

    The context is built with a default ``PostOrderApply`` space generator,
    ``rule`` as its only schedule rule, and the task name ``"test"``. Before
    returning, the space generator and every schedule rule held by the context
    are initialized against that same context.
    """
    context = TuneContext(
        mod=mod,
        target=target,
        space_generator=PostOrderApply(),
        sch_rules=[rule],
        task_name="test",
    )
    # Initialize the space generator first, then each schedule rule in turn.
    context.space_generator.initialize_with_tune_context(context)
    for schedule_rule in context.sch_rules:
        schedule_rule.initialize_with_tune_context(context)
    return context
def tune_tir(
    mod: Union[IRModule, PrimFunc],
    target: Union[str, Target],
    config: TuneConfig,
    work_dir: str,
    *,
    builder: Optional[Builder] = None,
    runner: Optional[Runner] = None,
    database: Optional[Database] = None,
    cost_model: Optional[CostModel] = None,
    measure_callbacks: Optional[List[MeasureCallback]] = None,
    space: Optional[FnSpaceGenerator] = None,
    blocks: Optional[List[str]] = None,
    sch_rules: Optional[FnScheduleRule] = None,
    postprocs: Optional[FnPostproc] = None,
    mutator_probs: Optional[FnMutatorProb] = None,
    task_name: str = "main",
    num_threads: Optional[int] = None,
) -> Optional[Schedule]:
    """Tune a TIR IRModule with a given target.

    Parameters
    ----------
    mod : Union[IRModule, PrimFunc]
        The module to tune.
    target : Union[str, Target]
        The target to tune for.
    config : TuneConfig
        The search strategy config.
    work_dir : str
        The working directory to save intermediate results. A ``logs``
        subdirectory is created inside it (if missing) and handed to the
        config's loggers.
    builder : Optional[Builder]
        The builder to use.
    runner : Optional[Runner]
        The runner to use.
    database : Optional[Database]
        The database to use.
    cost_model : Optional[CostModel]
        The cost model to use.
    measure_callbacks : Optional[List[MeasureCallback]]
        The callbacks used during tuning.
    space : Optional[FnSpaceGenerator]
        The space generator to use.
    blocks : Optional[List[str]]
        A list of block names specifying blocks to be tuned. Note that if the
        list is not None, blocks outside this list will not be tuned. Only one
        of this argument and space may be provided.
    sch_rules : Optional[FnScheduleRule]
        The search rules to use.
    postprocs : Optional[FnPostproc]
        The postprocessors to use.
    mutator_probs : Optional[FnMutatorProb]
        The probability distribution to use different mutators.
    task_name : str
        The name of the tuning task; it is also embedded in the logger name.
    num_threads : Optional[int]
        The number of threads to use.

    Returns
    -------
    sch : Optional[Schedule]
        The tuned schedule, or None if the database holds no tuning record for
        the module once tuning finishes.

    Raises
    ------
    AssertionError
        If both ``blocks`` and ``space`` are given.
    """
    # logging directory is set to `work_dir/logs` by default
    log_dir = osp.join(work_dir, "logs")
    os.makedirs(log_dir, exist_ok=True)

    config.create_loggers(
        log_dir=log_dir,
        params=[{"log_dir": log_dir, "logger_name": __name__ + f".task_{task_name}"}],
    )

    if blocks is not None:
        # An explicit check instead of `assert`: under `python -O` an assert is
        # stripped and a user-supplied `space` would be silently overwritten.
        # AssertionError is kept so callers see the same exception type.
        if space is not None:
            raise AssertionError(
                "Can not specify blocks to tune when a search space is given."
            )
        # Snapshot the requested names once. This gives O(1) membership per
        # block, and later mutation of the caller's list cannot alter the filter.
        target_names = frozenset(blocks)

        def _f_block_filter(block) -> bool:
            """Return True when `block` is one of the requested blocks."""
            return block.name_hint in target_names

        # Create a space generator that targets specific blocks.
        space = PostOrderApply(f_block_filter=_f_block_filter)

    # pylint: disable=protected-access
    mod = default_config.mod(mod)
    target = default_config.target(target)
    # pylint: enable=protected-access
    database = tune_extracted_tasks(
        extracted_tasks=[
            ExtractedTask(
                task_name=task_name,
                mod=mod,
                dispatched=[mod],
                target=target,
                weight=1,
            ),
        ],
        config=config,
        work_dir=work_dir,
        builder=builder,
        runner=runner,
        database=database,
        cost_model=cost_model,
        measure_callbacks=measure_callbacks,
        space=space,
        sch_rules=sch_rules,
        postprocs=postprocs,
        mutator_probs=mutator_probs,
        num_threads=num_threads,
    )
    with Profiler.timeit("ApplyHistoryBest"):
        # Ask the tuning database for the single best record of this module.
        bests: List[TuningRecord] = database.get_top_k(
            database.commit_workload(mod), top_k=1
        )
        if not bests:
            return None
        assert len(bests) == 1
        sch = Schedule(mod)
        # Replay the best trace, keeping its postprocessing steps.
        bests[0].trace.apply_to_schedule(sch, remove_postproc=False)
        return sch