def check_index_map(workload, block_name, intrin_name, expected_index_map):
    """Assert that auto-tensorize analysis of `block_name` in `workload` yields
    exactly one mapping, equivalent to `expected_index_map`."""
    sch = Schedule(workload)
    target_block = sch.get_block(block_name)
    desc = TensorIntrin.get(intrin_name).desc
    mapping_info = get_auto_tensorize_mapping_info(sch, target_block, desc)
    assert len(mapping_info.mappings) == 1
    expected = IndexMap.from_func(expected_index_map)
    assert expected.is_equivalent_to(mapping_info.mappings[0])
def _find_match_sketch_id(
    mod: IRModule,
    sketches: List[Schedule],
    expected_mod: IRModule,
    expected_decision: List[Tuple[str, List[int]]],
    *,
    debug_mask="all",
) -> Optional[int]:
    """Return the index of the sketch that reproduces `expected_mod` after
    substituting the expected sampling decisions, or None if no sketch matches.

    Each entry of `expected_decision` is an (instruction-kind-name, decision)
    pair consumed in order as matching Sample* instructions are found.
    """
    for sketch_id, sketch in enumerate(sketches):
        substituted = {}
        cursor = 0
        for inst in sketch.trace.insts:
            # Only sampling instructions carry tunable decisions.
            if not inst.kind.name.startswith("Sample"):
                continue
            assert cursor < len(expected_decision)
            kind_name, decision = expected_decision[cursor]
            if inst.kind.name == kind_name:
                substituted[inst] = decision
                cursor += 1
        # Skip sketches whose sampling instructions do not line up.
        if len(substituted) != len(expected_decision):
            continue
        sch = Schedule(mod, debug_mask=debug_mask)
        replay = Trace(insts=sketch.trace.insts, decisions=substituted)
        replay.apply_to_schedule(sch, remove_postproc=True)
        if structural_equal(sch.mod, expected_mod):
            verify_trace_roundtrip(sch=sch, mod=mod, debug_mask=debug_mask)
            return sketch_id
    return None
def _schedule_batch_matmul(sch: Schedule):
    """Apply a fixed split-and-reorder tiling to the "matmul" block."""
    block = sch.get_block("matmul")
    i, j, k, t = sch.get_loops(block=block)
    # Hard-coded tile factors — no tunable knobs.
    i0, i1, i2, i3 = sch.split(loop=i, factors=[2, 2, 2, 2])
    j0, j1, j2, j3 = sch.split(loop=j, factors=[2, 4, 64, 2])
    k0, k1 = sch.split(loop=k, factors=[32, 32])
    t0, t1 = sch.split(loop=t, factors=[2, 512])
    sch.reorder(i0, j0, i1, j1, k0, i2, j2, k1, i3, j3, t0, t1)
def test_meta_schedule_integration_apply_history_best():
    """ApplyHistoryBest should rewrite the dispatched MockModule using the
    best tuning record committed to the database."""

    class DummyDatabase(PyDatabase):
        """Minimal in-memory database stub for tuning records/workloads."""

        def __init__(self):
            super().__init__()
            self.records = []
            self.workload_reg = []

        # BUGFIX: the return annotation was `-> Workload`, but the method
        # returns True/False (matching the Database.has_workload contract).
        def has_workload(self, mod: IRModule) -> bool:
            for workload in self.workload_reg:
                if tvm.ir.structural_equal(workload.mod, mod):
                    return True
            return False

        def commit_tuning_record(self, record: TuningRecord) -> None:
            self.records.append(record)

        def commit_workload(self, mod: IRModule) -> Workload:
            # Reuse an existing workload when the module is structurally equal.
            for workload in self.workload_reg:
                if tvm.ir.structural_equal(workload.mod, mod):
                    return workload
            workload = Workload(mod)
            self.workload_reg.append(workload)
            return workload

        def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
            # Keep only this workload's records, best (lowest mean run time) first.
            return list(
                filter(
                    lambda x: x.workload == workload,
                    sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)),
                ))[:int(top_k)]

        def __len__(self) -> int:
            return len(self.records)

        def print_results(self) -> None:
            print("\n".join([str(r) for r in self.records]))

    mod, _, _, _ = get_network(
        name="resnet-18",
        batch_size=1,
        layout="NHWC",
        dtype="float32",
    )
    database = DummyDatabase()
    env = ApplyHistoryBest(database)
    target = Target("llvm")
    workload = database.commit_workload(MockModule)
    database.commit_tuning_record(
        TuningRecord(Schedule(MockModule).trace, [1.0], workload, target, []))
    mod = env.query(task_name="mock-task", mod=mod, target=target, dispatched=[MockModule])
    assert tvm.ir.structural_equal(mod, workload.mod)
def _schedule_batch_matmul(sch: Schedule):
    """Manually tile the "matmul" block with hard-coded split factors."""
    # TODO(@zxybazh): Change to `sample_perfect_tile` after upstreaming
    block = sch.get_block("matmul")
    axis_i, axis_j, axis_k, axis_t = sch.get_loops(block=block)
    i_0, i_1, i_2, i_3 = sch.split(loop=axis_i, factors=[2, 2, 2, 2])
    j_0, j_1, j_2, j_3 = sch.split(loop=axis_j, factors=[2, 4, 64, 2])
    k_0, k_1 = sch.split(loop=axis_k, factors=[32, 32])
    t_0, t_1 = sch.split(loop=axis_t, factors=[2, 512])
    sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3, t_0, t_1)
def test_get_tensorize_loop_mapping_matmul_mma():
    """Check that the tensorize loop mapping for a 16x16x16 fp16 MMA intrin
    description is invariant to loop permutation on a matmul_relu workload."""

    @T.prim_func
    def matmul_16x16x16xf16f16f16_desc(
        A: T.Buffer((16, 16), "float16", align=128, offset_factor=1),
        B: T.Buffer((16, 16), "float16", align=128, offset_factor=1),
        C: T.Buffer((16, 16), "float16", align=128, offset_factor=1),
    ) -> None:
        with T.block("root"):
            T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16])
            T.writes(C[0:16, 0:16])
            for i, j, k in T.grid(16, 16, 16):
                with T.block("update"):
                    vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
                    C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk]

    matmul = create_prim_func(te_workload.matmul_relu(
        n=512,
        m=512,
        k=512,
    ))
    s = Schedule(matmul)
    block = s.get_block("C")
    i0, i1, i2 = s.get_loops(block)
    desc_loops = collect_loops(matmul_16x16x16xf16f16f16_desc)
    for do_reorder in [False, True]:
        # Mapping should be invariant to the loop permutation
        if do_reorder:
            s.reorder(i2, i0, i1)
        info = get_tensorize_loop_mapping(s, block, matmul_16x16x16xf16f16f16_desc)
        assert info is not None
        # Invert loop_map so schedule loop srefs can be looked up by desc loop.
        desc_loop_to_sref = dict((v, k) for k, v in info.loop_map.items())
        for i in range(3):
            assert desc_loops[i] in desc_loop_to_sref
        # Regardless of permutation, each desc loop maps back to the original
        # (i0, i1, i2) loops of the workload.
        assert s.get(desc_loop_to_sref[desc_loops[0]]) == s.get(i0)
        assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(i1)
        assert s.get(desc_loop_to_sref[desc_loops[2]]) == s.get(i2)
def test_meta_schedule_integration_apply_history_best():
    """Querying ApplyHistoryBest must return the module stored for the
    dispatched MockModule workload."""
    mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
    database = DummyDatabase()
    env = ApplyHistoryBest(database)
    target = Target("llvm")
    workload = database.commit_workload(MockModule)
    record = TuningRecord(Schedule(MockModule).trace, [1.0], workload, target, [])
    database.commit_tuning_record(record)
    mod = env.query(task_name="mock-task", mod=mod, target=target, dispatched=[MockModule])
    assert tvm.ir.structural_equal(mod, workload.mod)
def apply_fixed_schedules(
    relay_mod: Union[RelayFunc, IRModule],
    target: Union[str, Target],
    params: Optional[Dict[str, NDArray]],
    schedule_fn: Callable[[ms.ExtractedTask, Schedule], bool],
    te_filter_func=None,
):
    """Apply fixed schedules (manually written, without any tunable knobs) as
    specified by schedule_fn to extracted tasks, and return a database that can
    be passed to ApplyHistoryBest.

    Parameters
    ----------
    relay_mod : Union[RelayFunc, IRModule]
        The Relay module to apply fixed schedules.
    target : Union[str, Target]
        The target used to extract tasks.
    params : Optional[Dict[str, tvm.runtime.NDArray]]
        The associated parameters of the module.
    schedule_fn : Callable[[ExtractedTask, Schedule], bool]
        A callable applied to each extracted task and its default schedule.
        Returns True if the schedule should be committed to the database.
    te_filter_func : Union[str, None, Callable[[List[Tensor]], PrimFunc]] = None
        The filtering function for TE computation.
        If a string, it names a built-in filter:
          - "meta_schedule.DefaultTaskFilter"
          - "meta_schedule.DefaultTaskFilterAllowExtern"
        If None, the default filter is used; if a callable, it is used directly.

    Returns
    -------
    database : Database
        The database containing dummy tuning records for manually scheduled traces.
    """
    if isinstance(target, str):
        target = Target(target)
    tasks = ms.extract_task_from_relay(
        relay_mod,
        target,
        params,
        te_filter_func=te_filter_func,
    )
    database = ms.database.MemoryDatabase()
    for task in tasks:
        task_mod = ms.default_config.mod(task.dispatched[0])
        sch = Schedule(task_mod)
        # Commit only the schedules the caller accepts.
        if not schedule_fn(task, sch):
            continue
        workload = database.commit_workload(task_mod)
        record = ms.database.TuningRecord(sch.trace, workload, [0.0], target, [])
        database.commit_tuning_record(record)
    return database
def test_tile_with_tensor_intrin_conv2d_nchwc_vnni():
    """Tiling conv2d_NCHWc_int8 for the VNNI intrin should yield 12 loops with
    the returned loop second-from-innermost, and match the reference module."""
    sch = Schedule(Conv2dNCHWcVNNIModule)
    block = sch.get_block("conv2d_NCHWc_int8")
    tiled_loop = tile_with_tensor_intrin(sch, block, VNNI_DOT_16x4_INTRIN)
    loops = sch.get_loops(block)
    assert len(loops) == 12
    assert sch.get(tiled_loop) == sch.get(loops[-2])
    tvm.ir.assert_structural_equal(sch.mod, Conv2dNCHWcVNNIModuleTiled)
def _sch(decision: int) -> Schedule:
    """Schedule `add`: compute the "move" block at the sampled location
    selected by `decision`."""
    sch = Schedule(add, debug_mask="all")
    # pylint: disable=invalid-name
    block = sch.get_block(name="move", func_name="main")
    loop = sch.sample_compute_location(block=block, decision=decision)
    sch.compute_at(block=block, loop=loop, preserve_unit_loops=True)
    # pylint: enable=invalid-name
    return sch
def test_tile_with_tensor_intrin_dense_vnni():
    """VNNI tiling of the dense "compute" block should return the inner split
    of the second loop and produce the reference tiled module."""
    sch = Schedule(DenseVNNIModule)
    block = sch.get_block("compute")
    tiled_loop = tile_with_tensor_intrin(sch, block, VNNI_DOT_16x4_INTRIN)
    _, _, _, i1_1, _ = sch.get_loops(block)
    assert sch.get(tiled_loop) == sch.get(i1_1)
    tvm.ir.assert_structural_equal(sch.mod, DenseVNNIModuleTiled)
def _schedule_matmul(sch: Schedule):
    """Tile the "matmul" block with fixed 4-way (i, j) and 2-way (k) splits."""
    block = sch.get_block("matmul")
    i, j, k = sch.get_loops(block=block)
    # Hard-coded tile factors.
    i0, i1, i2, i3 = sch.split(loop=i, factors=[1, 1, 2, 512])
    j0, j1, j2, j3 = sch.split(loop=j, factors=[1, 512, 1, 2])
    k0, k1 = sch.split(loop=k, factors=[256, 4])
    sch.reorder(i0, j0, i1, j1, k0, i2, j2, k1, i3, j3)
def apply_fixed_schedules(
    relay_mod: Union[RelayFunc, IRModule],
    target: Union[str, Target],
    params: Optional[Dict[str, NDArray]],
    schedule_fn: Callable[[ExtractedTask, Schedule], bool],
):
    """Apply fixed schedules (manually written, without any tunable knobs) as
    specified by schedule_fn to extracted tasks, and return a database that can
    be passed to ApplyHistoryBest.

    Parameters
    ----------
    relay_mod : Union[RelayFunc, IRModule]
        The Relay module to apply fixed schedules.
    target : Union[str, Target]
        The target used to extract tasks.
    params : Optional[Dict[str, tvm.runtime.NDArray]]
        The associated parameters of the module.
    schedule_fn : Callable[[ExtractedTask, Schedule], bool]
        A callable applied to each extracted task and its default schedule.
        Returns True if the schedule should be committed to the database.

    Returns
    -------
    database : Database
        The database containing dummy tuning records for manually scheduled traces.
    """
    if isinstance(target, str):
        target = Target(target)
    database = DummyDatabase()
    for task in extract_task_from_relay(relay_mod, target, params):
        task_mod = Parse._mod(task.dispatched[0])
        sch = Schedule(task_mod)
        # Commit only the schedules the caller accepts.
        if not schedule_fn(task, sch):
            continue
        workload = database.commit_workload(task_mod)
        database.commit_tuning_record(
            TuningRecord(sch.trace, [0.0], workload, target, []))
    return database
def test_get_tensorize_loop_mapping_dense_vnni():
    """The VNNI dot-product desc loops should map onto the j and k loops of
    the dense "compute" block."""
    sch = Schedule(DenseVNNIModule)
    block = sch.get_block("compute")
    info = get_tensorize_loop_mapping(sch, block, dot_product_16x4_u8i8i32_desc)
    assert isinstance(info, TensorizeInfo)
    # Invert loop_map: desc loop -> schedule loop sref.
    desc_to_sref = {v: k for k, v in info.loop_map.items()}
    desc_loops = collect_loops(dot_product_16x4_u8i8i32_desc)
    _, loop_j, loop_k = sch.get_loops(block)
    assert desc_loops[0] in desc_to_sref
    assert desc_loops[1] in desc_to_sref
    assert sch.get(desc_to_sref[desc_loops[0]]) == sch.get(loop_j)
    assert sch.get(desc_to_sref[desc_loops[1]]) == sch.get(loop_k)
def test_get_tensorize_loop_mapping_conv2d_nchwc_vnni():
    """The VNNI dot-product desc loops should map onto the inner output
    channel loop (i4) and the innermost reduction loop (i9) of the conv2d."""
    sch = Schedule(Conv2dNCHWcVNNIModule)
    block = sch.get_block("conv2d_NCHWc_int8")
    info = get_tensorize_loop_mapping(sch, block, dot_product_16x4_u8i8i32_desc)
    # Invert loop_map: desc loop -> schedule loop sref.
    desc_to_sref = {v: k for k, v in info.loop_map.items()}
    desc_loops = collect_loops(dot_product_16x4_u8i8i32_desc)
    # i4 corresponds to the inner output channel axis of the NCHWc output tensor
    # for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4):
    _, _, _, _, i4, _, _, _, _, i9 = sch.get_loops(block)
    assert desc_loops[0] in desc_to_sref
    assert desc_loops[1] in desc_to_sref
    assert sch.get(desc_to_sref[desc_loops[0]]) == sch.get(i4)
    assert sch.get(desc_to_sref[desc_loops[1]]) == sch.get(i9)
def tune_tir( mod: Union[IRModule, PrimFunc], target: Union[str, Target], config: TuneConfig, work_dir: str, *, builder: Optional[Builder] = None, runner: Optional[Runner] = None, database: Optional[Database] = None, cost_model: Optional[CostModel] = None, measure_callbacks: Optional[List[MeasureCallback]] = None, space: Optional[FnSpaceGenerator] = None, sch_rules: Optional[FnScheduleRule] = None, postprocs: Optional[FnPostproc] = None, mutator_probs: Optional[FnMutatorProb] = None, task_name: str = "main", num_threads: Optional[int] = None, ) -> Optional[Schedule]: """Tune a TIR IRModule with a given target. Parameters ---------- mod : Union[IRModule, PrimFunc] The module to tune. target : Union[str, Target] The target to tune for. config : TuneConfig The search strategy config. work_dir : Optional[str] The working directory to save intermediate results. builder : Optional[Builder] The builder to use. runner : Optional[Runner] The runner to use. database : Optional[Database] The database to use. cost_model : Optional[CostModel] The cost model to use. measure_callbacks : Optional[List[MeasureCallback]] The callbacks used during tuning. Returns ------- sch : Optional[Schedule] The tuned schedule. 
""" # logging directory is set to `work_dir/logs` by default log_dir = osp.join(work_dir, "logs") os.makedirs(log_dir, exist_ok=True) config.create_loggers( log_dir=log_dir, params=[{ "log_dir": log_dir, "logger_name": __name__ + f".task_{task_name}" }], ) # pylint: disable=protected-access mod = Parse._mod(mod) target = Parse._target(target) # pylint: enable=protected-access database = tune_extracted_tasks( extracted_tasks=[ ExtractedTask( task_name=task_name, mod=mod, dispatched=[mod], target=target, weight=1, ), ], config=config, work_dir=work_dir, builder=builder, runner=runner, database=database, cost_model=cost_model, measure_callbacks=measure_callbacks, space=space, sch_rules=sch_rules, postprocs=postprocs, mutator_probs=mutator_probs, num_threads=num_threads, ) bests: List[TuningRecord] = database.get_top_k( database.commit_workload(mod), top_k=1, ) if not bests: return None assert len(bests) == 1 sch = Schedule(mod) bests[0].trace.apply_to_schedule(sch, remove_postproc=False) return sch
def _sch(decisions: List[List[int]]) -> Schedule: sch = Schedule(matmul, debug_mask="all") # pylint: disable=invalid-name d0, d1, d2 = decisions b0 = sch.get_block(name="C", func_name="main") root = sch.get_block(name="root", func_name="main") sch.get_consumers(block=b0) b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global") l2, l3, l4 = sch.get_loops(block=b0) v5, v6, v7, v8 = sch.sample_perfect_tile( loop=l2, n=4, max_innermost_factor=64, decision=d0, ) l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8]) v13, v14, v15, v16 = sch.sample_perfect_tile( loop=l3, n=4, max_innermost_factor=64, decision=d1, ) l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16]) v21, v22 = sch.sample_perfect_tile( loop=l4, n=2, max_innermost_factor=64, decision=d2, ) l23, l24 = sch.split(loop=l4, factors=[v21, v22]) sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20) sch.reverse_compute_at(block=b1, loop=l18, preserve_unit_loops=True) v57 = sch.sample_categorical( candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=0, ) sch.annotate(block_or_loop=root, ann_key="meta_schedule.unroll_explicit", ann_val=v57) # pylint: enable=invalid-name return sch
def tune_tir(
    mod: Union[IRModule, PrimFunc],
    target: Union[str, Target],
    config: SearchStrategyConfig,
    work_dir: str,
    *,
    task_name: str = "main",
    builder: Optional[Builder] = None,
    runner: Optional[Runner] = None,
    database: Optional[Database] = None,
    cost_model: Optional[CostModel] = None,
    measure_callbacks: Optional[List[MeasureCallback]] = None,
    task_scheduler: Optional[TaskScheduler] = None,
    space: Optional[FnSpaceGenerator] = None,
    sch_rules: Optional[FnScheduleRule] = None,
    postprocs: Optional[FnPostproc] = None,
    mutator_probs: Optional[FnMutatorProb] = None,
    num_threads: Optional[int] = None,
) -> Optional[Schedule]:
    """Tune a TIR IRModule with a given target.

    Parameters
    ----------
    mod : Union[IRModule, PrimFunc]
        The module to tune.
    target : Union[str, Target]
        The target to tune for.
    config : SearchStrategyConfig
        The search strategy config.
    work_dir : Optional[str]
        The working directory to save intermediate results.
    task_name : str
        The name used for the tuning task, its database, and the saved
        cost-model file.
    builder : Optional[Builder]
        The builder to use.
    runner : Optional[Runner]
        The runner to use.
    database : Optional[Database]
        The database to use.
    cost_model : Optional[CostModel]
        The cost model to use.
    measure_callbacks : Optional[List[MeasureCallback]]
        The callbacks used during tuning.
    task_scheduler : Optional[TaskScheduler]
        The task scheduler to use.
    space : Optional[FnSpaceGenerator]
        The space generator to use.
    sch_rules : Optional[FnScheduleRule]
        The search rules to use.
    postprocs : Optional[FnPostproc]
        The postprocessors to use.
    mutator_probs : Optional[FnMutatorProb]
        The probability distribution to use different mutators.
    num_threads : Optional[int]
        The number of threads to use.

    Returns
    -------
    sch : Optional[Schedule]
        The tuned schedule, or None if no tuning record was produced.
    """
    logger.info("Working directory: %s", work_dir)
    # pylint: disable=protected-access
    mod = Parse._mod(mod)
    database = Parse._database(database, task_name, work_dir)
    tune_context = Parse._tune_context(
        tune_context=None,
        mod=mod,
        target=Parse._target(target),
        config=config,
        task_name=task_name,
        space_generator=space,
        sch_rules=sch_rules,
        postprocs=postprocs,
        mutator_probs=mutator_probs,
        num_threads=num_threads,
    )
    task_scheduler = Parse._task_scheduler(
        task_scheduler,
        [tune_context],
        builder=Parse._builder(builder),
        runner=Parse._runner(runner),
        database=database,
        cost_model=Parse._cost_model(cost_model),
        measure_callbacks=Parse._callbacks(measure_callbacks),
    )
    # pylint: enable=protected-access
    task_scheduler.tune()
    # Replay the single best trace found onto a fresh schedule.
    bests: List[TuningRecord] = database.get_top_k(
        database.commit_workload(mod),
        top_k=1,
    )
    if not bests:
        return None
    assert len(bests) == 1
    sch = Schedule(mod)
    bests[0].trace.apply_to_schedule(sch, remove_postproc=False)
    # Persist the trained cost model alongside the tuning artifacts.
    task_scheduler.cost_model.save(os.path.join(work_dir, f"{task_name}.xgb"))
    return sch
def _sch(decisions: List[List[int]]) -> Schedule:
    """Build a matmul schedule whose single sampled decision is the 4-way
    perfect tile of the first spatial loop; j and k use fixed splits."""
    sch = Schedule(matmul, debug_mask="all")
    # pylint: disable=invalid-name
    (i_decision,) = decisions
    c_block = sch.get_block(name="C", func_name="main")
    sch.get_consumers(block=c_block)
    cache = sch.cache_write(block=c_block, write_buffer_index=0, storage_scope="global")
    li, lj, lk = sch.get_loops(block=c_block)
    t0, t1, t2, t3 = sch.sample_perfect_tile(
        loop=li,
        n=4,
        max_innermost_factor=64,
        decision=i_decision,
    )
    i0, i1, i2, i3 = sch.split(loop=li, factors=[t0, t1, t2, t3])
    j0, j1, j2, j3 = sch.split(loop=lj, factors=[8, 4, 8, 2])
    k0, k1 = sch.split(loop=lk, factors=[512, 1])
    sch.reorder(i0, j0, i1, j1, k0, i2, j2, k1, i3, j3)
    sch.reverse_compute_at(block=cache, loop=j1, preserve_unit_loops=True)
    # pylint: enable=invalid-name
    return sch
def _sch() -> Schedule:
    """Schedule element_wise: fuse the two loops, sample a split factor, and
    bind the resulting loops to GPU block/thread axes."""
    sch = Schedule(element_wise, debug_mask="all")
    # pylint: disable=invalid-name
    block = sch.get_block(name="C", func_name="main")
    loop_i, loop_j = sch.get_loops(block=block)
    fused = sch.fuse(loop_i, loop_j)
    factor = sch.sample_categorical(
        candidates=[32, 64, 128, 256, 512, 1024],
        probs=[
            0.16666666666666666,
            0.16666666666666666,
            0.16666666666666666,
            0.16666666666666666,
            0.16666666666666666,
            0.16666666666666666,
        ],
        decision=3,
    )
    outer, inner = sch.split(loop=fused, factors=[None, factor])
    sch.bind(loop=outer, thread_axis="blockIdx.x")
    sch.bind(loop=inner, thread_axis="threadIdx.x")
    # pylint: enable=invalid-name
    return sch
def tune_tir( mod: Union[IRModule, PrimFunc], target: Union[str, Target], config: TuneConfig, work_dir: str, *, builder: Optional[Builder] = None, runner: Optional[Runner] = None, database: Optional[Database] = None, cost_model: Optional[CostModel] = None, measure_callbacks: Optional[List[MeasureCallback]] = None, space: Optional[FnSpaceGenerator] = None, blocks: Optional[List[str]] = None, sch_rules: Optional[FnScheduleRule] = None, postprocs: Optional[FnPostproc] = None, mutator_probs: Optional[FnMutatorProb] = None, task_name: str = "main", num_threads: Optional[int] = None, ) -> Optional[Schedule]: """Tune a TIR IRModule with a given target. Parameters ---------- mod : Union[IRModule, PrimFunc] The module to tune. target : Union[str, Target] The target to tune for. config : TuneConfig The search strategy config. work_dir : Optional[str] The working directory to save intermediate results. builder : Optional[Builder] The builder to use. runner : Optional[Runner] The runner to use. database : Optional[Database] The database to use. cost_model : Optional[CostModel] The cost model to use. measure_callbacks : Optional[List[MeasureCallback]] The callbacks used during tuning. space : Optional[FnSpaceGenerator] The space generator to use. blocks : Optional[List[str]] A list of block names specifying blocks to be tuned. Note that if the list is not None, blocks outside this list will not be tuned. Only one of this argument and space may be provided. sch_rules : Optional[FnScheduleRule] The search rules to use. postprocs : Optional[FnPostproc] The postprocessors to use. mutator_probs : Optional[FnMutatorProb] The probability distribution to use different mutators. task_name : str The name of the function to extract schedules from. num_threads : Optional[int] The number of threads to use Returns ------- sch : Optional[Schedule] The tuned schedule. 
""" # logging directory is set to `work_dir/logs` by default log_dir = osp.join(work_dir, "logs") os.makedirs(log_dir, exist_ok=True) config.create_loggers( log_dir=log_dir, params=[{ "log_dir": log_dir, "logger_name": __name__ + f".task_{task_name}" }], ) if blocks is not None: assert space is None, "Can not specify blocks to tune when a search space is given." # Create a filter function to identify named blocks. def _f_block_filter(block, target_names) -> bool: return block.name_hint in target_names # Create a space generator that targets specific blocks. space = PostOrderApply( f_block_filter=lambda block: _f_block_filter(block, blocks)) # pylint: disable=protected-access mod = default_config.mod(mod) target = default_config.target(target) # pylint: enable=protected-access database = tune_extracted_tasks( extracted_tasks=[ ExtractedTask( task_name=task_name, mod=mod, dispatched=[mod], target=target, weight=1, ), ], config=config, work_dir=work_dir, builder=builder, runner=runner, database=database, cost_model=cost_model, measure_callbacks=measure_callbacks, space=space, sch_rules=sch_rules, postprocs=postprocs, mutator_probs=mutator_probs, num_threads=num_threads, ) with Profiler.timeit("ApplyHistoryBest"): bests: List[TuningRecord] = database.get_top_k( database.commit_workload(mod), top_k=1) if not bests: return None assert len(bests) == 1 sch = Schedule(mod) bests[0].trace.apply_to_schedule(sch, remove_postproc=False) return sch