def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
    """Inline the block named ``B``; any other non-root block is left as-is.

    The root block is returned unchanged (same schedule object). For every
    other block the schedule is copied first, so the caller's ``sch`` is never
    modified; the copy is returned whether or not anything was inlined.
    """
    if _is_root(sch, block):
        return [sch]
    candidate = sch.copy()
    is_target = candidate.get(block).name_hint == "B"
    if is_target:
        candidate.compute_inline(block)
    return [candidate]
def input_tile_data_pad(sch: Schedule):
    """Attach ``input_tile`` inside its sole consumer and inline ``data_pad``.

    ``input_tile`` is computed at the fourth of its consumer's eight loops, and
    its buffer at index 0 is moved to the ``"local"`` storage scope.
    """
    tile = sch.get_block(name="input_tile")
    # Exactly one consumer is expected; the unpack fails loudly otherwise.
    (consumer,) = sch.get_consumers(block=tile)
    # The consumer must have exactly eight loops; attach under the fourth.
    _, _, _, attach_loop, _, _, _, _ = sch.get_loops(block=consumer)
    sch.compute_at(block=tile, loop=attach_loop, preserve_unit_loops=True)
    sch.set_scope(block=tile, buffer_index=0, storage_scope="local")
    pad = sch.get_block(name="data_pad")
    sch.compute_inline(block=pad)
def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
    """Return two loop-reordered copies of a block with a 10-loop nest.

    The block's nest is expected to be
    (i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3).

    Candidate 1 moves (i_0, j_0) to the innermost positions. Candidate 2 uses
    the order i_1, j_3, i_0, j_0, j_1, k_0, i_2, j_2, k_1, i_3. The root block
    is returned unchanged.
    """
    if _is_root(sch, block):
        return [sch]
    # Each entry lists positions in the original nest, outermost first.
    permutations = (
        (2, 3, 4, 5, 6, 7, 8, 9, 0, 1),
        (2, 9, 0, 1, 3, 4, 5, 6, 7, 8),
    )
    candidates = []
    for permutation in permutations:
        candidate = sch.copy()
        # Unpacking enforces the 10-loop shape exactly as before.
        i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3 = candidate.get_loops(block=block)
        nest = (i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3)
        candidate.reorder(*[nest[pos] for pos in permutation])
        candidates.append(candidate)
    return candidates
def _schedule_matmul(sch: Schedule):
    """Tile ``matmul`` with sampled perfect tiles and reorder the tiles.

    i and j are split four ways and k two ways. The final order is
    i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3.
    """
    matmul = sch.get_block("matmul")
    i, j, k = sch.get_loops(block=matmul)
    i_factors = sch.sample_perfect_tile(i, n=4)
    i_0, i_1, i_2, i_3 = sch.split(i, i_factors)
    j_factors = sch.sample_perfect_tile(j, n=4)
    j_0, j_1, j_2, j_3 = sch.split(j, j_factors)
    k_factors = sch.sample_perfect_tile(k, n=2)
    k_0, k_1 = sch.split(k, k_factors)
    sch.reorder(
        i_0, j_0,
        i_1, j_1,
        k_0,
        i_2, j_2,
        k_1,
        i_3, j_3,
    )
def root_anno(sch: Schedule):
    """Annotate the root block of ``main`` with a sampled explicit-unroll value.

    The value is a uniform categorical sample over [0, 16, 64, 512, 1024].
    The decision is pinned to index 2.
    """
    root = sch.get_block(name="root", func_name="main")
    unroll_choices = [0, 16, 64, 512, 1024]
    unroll_value = sch.sample_categorical(
        candidates=unroll_choices,
        probs=[0.2] * len(unroll_choices),
        decision=2,
    )
    sch.annotate(
        block_or_loop=root,
        ann_key="meta_schedule.unroll_explicit",
        ann_val=unroll_value,
    )
def input_tile_data_pad(sch: Schedule):
    """Move ``input_tile`` and then ``data_pad`` to sampled compute locations.

    Location decisions are pinned: 4 for ``input_tile`` and -2 for ``data_pad``.
    """
    for block_name, pinned in (("input_tile", 4), ("data_pad", -2)):
        blk = sch.get_block(name=block_name)
        location = sch.sample_compute_location(block=blk, decision=pinned)
        sch.compute_at(block=blk, loop=location, preserve_unit_loops=True)
def test_meta_schedule_measure_callback():
    """A user-defined measure callback receives the candidates, builds and runs passed to it."""

    @derived_object
    class FancyMeasureCallback(PyMeasureCallback):
        def apply(
            self,
            task_scheduler: TaskScheduler,
            task_id: int,
            measure_candidates: List[MeasureCandidate],
            builds: List[BuilderResult],
            results: List[RunnerResult],
        ) -> None:
            assert len(measure_candidates) == 1
            assert_structural_equal(measure_candidates[0].sch.mod, Matmul)
            # Builder side: one successful artifact.
            assert len(builds) == 1
            assert builds[0].error_msg is None
            assert builds[0].artifact_path == "test_build"
            # Runner side: one successful run with two timings.
            assert len(results) == 1
            assert results[0].error_msg is None
            assert len(results[0].run_secs) == 2

    callback = FancyMeasureCallback()
    scheduler = RoundRobin([], [], DummyBuilder(), DummyRunner(), DummyDatabase(), max_trials=1)
    callback.apply(
        scheduler,
        0,
        [MeasureCandidate(Schedule(Matmul), None)],
        [BuilderResult("test_build", None)],
        [RunnerResult([1.0, 2.1], None)],
    )
def generate_design_space(self, mod: IRModule) -> List[Schedule]:
    """Generate design spaces given a module.

    ``self.sch_fn`` is called on a fresh traced schedule. Its result is
    normalized as follows:

    * ``None`` means it mutated the schedule in place, which is returned.
    * A single ``Schedule`` is wrapped in a list.
    * A list, tuple or ``Array`` of ``Schedule`` is returned as given.

    Parameters
    ----------
    mod : IRModule
        The module used for design space generation.

    Returns
    -------
    design_spaces : List[Schedule]
        The generated design spaces, i.e., schedules.

    Raises
    ------
    TypeError
        If ``sch_fn`` returns any other type, or a sequence containing a
        non-``Schedule`` element.
    """
    traced = Schedule(mod)  # Make sure the schedule is traced
    produced = self.sch_fn(traced)
    if produced is None:
        return [traced]
    if isinstance(produced, Schedule):
        return [produced]
    if not isinstance(produced, (list, tuple, Array)):
        raise TypeError(f"Unexpected return type {type(produced)}: {produced}")
    for item in produced:
        if not isinstance(item, Schedule):
            raise TypeError(
                "Wrong type of element in the list, expected Schedule got "
                f"'{type(item)}': {item}"
            )
    return produced
def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
    """Return two tiled copies of a block with a two-loop nest (i, j).

    Candidate 1 splits i into [16, 64] and j into [64, 16]. Candidate 2 splits
    both into [2, 512]. Each copy is reordered to i_0, j_0, i_1, j_1. The root
    block is returned unchanged.
    """
    if _is_root(sch, block):
        return [sch]
    tilings = (
        ([16, 64], [64, 16]),
        ([2, 512], [2, 512]),
    )
    candidates = []
    for i_factors, j_factors in tilings:
        candidate = sch.copy()
        i, j = candidate.get_loops(block=block)
        i_0, i_1 = candidate.split(loop=i, factors=i_factors)
        j_0, j_1 = candidate.split(loop=j, factors=j_factors)
        candidate.reorder(i_0, j_0, i_1, j_1)
        candidates.append(candidate)
    return candidates
def test_meta_schedule_postproc_rewrite_parallel_unroll_vectorize():
    """Apply RewriteParallelVectorizeUnroll to ``Move_PUV`` and compare against ``Move_PUV0``.

    The postprocessed module is passed through ``tvm.tir.transform.Simplify``
    before the structural comparison.
    """
    postproc = RewriteParallelVectorizeUnroll()
    sch = Schedule(Move_PUV)
    assert postproc.apply(sch)
    # The previous version printed the whole script here; that was debug output
    # and has been removed so the test runs quietly.
    mod = tvm.tir.transform.Simplify()(sch.mod)
    tvm.ir.assert_structural_equal(mod["main"], Move_PUV0)
def verify_trace_roundtrip(
    sch: Schedule,
    mod: Union[PrimFunc, IRModule],
    *,
    debug_mask: Union[str, int] = "all",
    text_format: Union[str, Sequence[str]] = ("python", "json"),
) -> Schedule:
    """Replay a traced schedule from its text form and verify it is reproduced.

    The trace of ``sch`` is serialized to each requested text format and
    replayed onto a fresh schedule built from ``mod``. The replayed schedule
    must produce a structurally equal module and an identical Python rendering
    of its trace.

    Parameters
    ----------
    sch : tir.Schedule
        The traced TensorIR schedule to be verified
    mod : Union[PrimFunc, IRModule]
        The IRModule or PrimFunc to construct the fresh new schedule
    debug_mask : Union[str, int]
        Do extra correctness checking after the class creation and each time
        after calling the Replace method.
        Possible choices of `debug_mask`:
        1) "all" - Turn on all the checks
        2) "none" - Turn off all the checks
        3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask
    text_format: Union[str, Sequence[str]]
        The text format or formats whose round-trip behavior should be
        validated. A single string validates that one format only; supported
        values are ``"json"`` and ``"python"``. A sequence validates each
        format in order.

    Returns
    -------
    new_sch : tir.Schedule
        The replayed schedule. For a sequence of formats, this is the one
        produced by the last format.

    Raises
    ------
    ValueError
        If ``text_format`` is an empty sequence or names an unknown format.
    """
    if not isinstance(text_format, str):
        formats = list(text_format)
        if not formats:
            # Previously an empty sequence skipped the loop and hit an
            # UnboundLocalError on `new_sch`.
            raise ValueError("text_format must contain at least one format")
        for opt in formats:
            new_sch = verify_trace_roundtrip(sch, mod, debug_mask=debug_mask, text_format=opt)
        return new_sch

    if text_format not in ("json", "python"):
        # An explicit raise (unlike `assert`) survives `python -O`. Under -O the
        # old code would otherwise compare an unscheduled module.
        raise ValueError(f"Unknown text format: {text_format}")

    trace = sch.trace
    assert trace is not None

    # Step 1. Perform a round-trip through the text-format
    new_sch = Schedule(mod=mod, debug_mask=debug_mask)
    if text_format == "json":
        json_obj = trace.as_json()
        Trace.apply_json_to_schedule(json_obj=json_obj, sch=new_sch)
    else:
        py_trace = "\n".join(trace.as_python())
        # The trace is generated by our own schedule, not external input.
        exec(py_trace, tvm.tir.__dict__, {"sch": new_sch})  # pylint: disable=exec-used

    # Step 2. Verify that the round-trip produced the same scheduling
    assert structural_equal(new_sch.mod, sch.mod)

    # Step 3. Check the consistency of the text format between the old and new traces
    py_repr = "\n".join(trace.as_python())
    new_py_repr = "\n".join(new_sch.trace.as_python())
    assert py_repr == new_py_repr

    # Step 4. Return the new schedule in case it could be useful
    return new_sch
def test_meta_schedule_measure_callback_fail():
    """An exception raised inside a measure callback propagates to the caller."""

    @ms.derived_object
    class FailingMeasureCallback(ms.measure_callback.PyMeasureCallback):
        def apply(
            self,
            task_scheduler: ms.task_scheduler.TaskScheduler,
            task_id: int,
            measure_candidates: List[ms.MeasureCandidate],
            builder_results: List[ms.builder.BuilderResult],
            runner_results: List[ms.runner.RunnerResult],
        ) -> None:
            raise ValueError("test")

    callback = FailingMeasureCallback()
    with pytest.raises(ValueError, match="test"):
        scheduler = ms.task_scheduler.RoundRobin(
            tasks=[],
            task_weights=[],
            builder=DummyBuilder(),
            runner=DummyRunner(),
            database=ms.database.MemoryDatabase(),
            max_trials=1,
        )
        candidates = [ms.MeasureCandidate(Schedule(Matmul), None)]
        builder_results = [ms.builder.BuilderResult("test_build", None)]
        runner_results = [ms.runner.RunnerResult([1.0, 2.1], None)]
        callback.apply(scheduler, 0, candidates, builder_results, runner_results)
def test_meta_schedule_measure_callback_fail():
    """An exception raised inside a measure callback propagates to the caller."""

    @derived_object
    class FailingMeasureCallback(PyMeasureCallback):
        def apply(
            self,
            task_scheduler: TaskScheduler,
            task_id: int,
            measure_candidates: List[MeasureCandidate],
            builds: List[BuilderResult],
            results: List[RunnerResult],
        ) -> None:
            raise ValueError("test")

    callback = FailingMeasureCallback()
    with pytest.raises(ValueError, match="test"):
        scheduler = RoundRobin([], [], DummyBuilder(), DummyRunner(), DummyDatabase(), max_trials=1)
        candidates = [MeasureCandidate(Schedule(Matmul), None)]
        builds = [BuilderResult("test_build", None)]
        results = [RunnerResult([1.0, 2.1], None)]
        callback.apply(scheduler, 0, candidates, builds, results)
def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
    """Return two independently tiled copies of a block with loops (i, j, k).

    Both copies use the same tiling: i into [4, 64, 2, 2], j into
    [2, 4, 64, 2] and k into [32, 32]. The loops are reordered to
    i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3. The root block is
    returned unchanged.

    NOTE(review): the two candidates are identical; confirm the duplication is
    intentional (e.g. a fixture for duplicate handling) rather than a typo in
    one set of factors.
    """
    if _is_root(sch, block):
        return [sch]

    def _tiled_copy() -> Schedule:
        # Copy the schedule and apply the fixed tiling to the copy only.
        candidate = sch.copy()
        i, j, k = candidate.get_loops(block=block)
        i_0, i_1, i_2, i_3 = candidate.split(loop=i, factors=[4, 64, 2, 2])
        j_0, j_1, j_2, j_3 = candidate.split(loop=j, factors=[2, 4, 64, 2])
        k_0, k_1 = candidate.split(loop=k, factors=[32, 32])
        candidate.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3)
        return candidate

    return [_tiled_copy() for _ in range(2)]
def schedule_matmul(sch: Schedule):
    """Tile ``matmul`` with fixed split factors and reorder the tiles.

    The splits are i -> [2, 4, 64, 2], j -> [4, 64, 2, 2] and k -> [32, 32].
    The final order is i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3.
    """
    matmul = sch.get_block("matmul")
    i, j, k = sch.get_loops(block=matmul)
    # TODO(@zxybazh): Change to `sample_perfect_tile` after upstreaming
    tiling = {
        "i": [2, 4, 64, 2],
        "j": [4, 64, 2, 2],
        "k": [32, 32],
    }
    i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=tiling["i"])
    j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=tiling["j"])
    k_0, k_1 = sch.split(loop=k, factors=tiling["k"])
    sch.reorder(
        i_0, j_0,
        i_1, j_1,
        k_0,
        i_2, j_2,
        k_1,
        i_3, j_3,
    )
def conv2d(sch: Schedule):
    """Fuse the four loops of ``conv2d_winograd``, split them, and bind the parts.

    The inner extent is a uniform categorical sample over
    [32, 64, 128, 256, 512, 1024], with the decision pinned to index 2. The
    outer part is bound to ``blockIdx.x`` and the inner part to ``threadIdx.x``.
    """
    blk = sch.get_block(name="conv2d_winograd")
    # Unpacking enforces exactly four loops.
    loop_a, loop_b, loop_c, loop_d = sch.get_loops(block=blk)
    fused = sch.fuse(loop_a, loop_b, loop_c, loop_d)
    thread_extents = [32, 64, 128, 256, 512, 1024]
    inner_extent = sch.sample_categorical(
        candidates=thread_extents,
        probs=[1 / 6] * len(thread_extents),
        decision=2,
    )
    outer, inner = sch.split(loop=fused, factors=[None, inner_extent])
    sch.bind(loop=outer, thread_axis="blockIdx.x")
    sch.bind(loop=inner, thread_axis="threadIdx.x")
def test_conv2d_winograd_cuda():
    """Replay the default CUDA design space with pinned decisions and compare to ``_get_mod()``."""
    mod = IRModule({"main": conv2d_winograd_cuda})
    context = TuneContext(
        mod=mod,
        target=Target("nvidia/geforce-rtx-3090", host="llvm"),
        task_name="Custom Search Space Task",
        sch_rules=DefaultCUDA._sch_rules(),  # pylint: disable=protected-access
    )
    for rule in context.sch_rules:
        rule.initialize_with_tune_context(context)
    space_generator = PostOrderApply()
    space_generator.initialize_with_tune_context(context)
    (design_space,) = space_generator.generate_design_space(mod)

    # Decisions for the Sample* instructions, in trace order.
    pinned = [
        # data_pack
        [3, 3],
        [64, 2],
        2,
        # inverse
        [3, 3],
        [2, 64],
        2,
        # bgemm
        [1, 1, 1, 1, 6],
        [1, 1, 1, 3, 2],
        [3, 1, 1, 1, 3],
        [4, 2, 1, 4, 4],
        [32, 1, 4],
        1,
        1,
        # root anno
        2,
        # conv2d
        2,
    ]
    all_insts = design_space.trace.insts
    sample_insts = [inst for inst in all_insts if inst.kind.name.startswith("Sample")]
    decisions = dict(zip(sample_insts, pinned))

    replayed = Schedule(mod=mod)
    Trace(all_insts, decisions=decisions).apply_to_schedule(replayed, remove_postproc=False)
    tvm.ir.assert_structural_equal(replayed.mod, _get_mod())
def test_conv2d_winograd_cpu():
    """Replay the default LLVM design space with pinned decisions and compare to ``_get_mod()``.

    The last four instructions of the generated trace are excluded from both
    the decision mapping and the replay.
    """
    mod = IRModule({"main": conv2d_winograd_cpu})
    context = TuneContext(
        mod=mod,
        target=Target("llvm"),
        task_name="Custom Search Space Task",
        sch_rules=DefaultLLVM._sch_rules(),  # pylint: disable=protected-access
    )
    space_generator = PostOrderApply()
    space_generator.initialize_with_tune_context(context)
    (design_space,) = space_generator.generate_design_space(mod)

    kept_insts = design_space.trace.insts[:-4]
    # Decisions for the Sample* instructions, in trace order.
    pinned = [
        # data_pack
        [9, 1],
        [32, 4],
        # input_tile
        4,
        # data_pad
        -2,
        # inverse
        [1, 9],
        [2, 64],
        # bgemm
        [1, 2, 3, 1],
        [1, 1, 1, 6],
        [1, 1, 1, 9],
        [2, 1, 16, 4],
        [16, 8],
    ]
    sample_insts = [inst for inst in kept_insts if inst.kind.name.startswith("Sample")]
    decisions = dict(zip(sample_insts, pinned))

    replayed = Schedule(mod=mod)
    Trace(kept_insts, decisions=decisions).apply_to_schedule(replayed, remove_postproc=False)
    tvm.ir.assert_structural_equal(replayed.mod, _get_mod())
def test_meta_schedule_measure_callback_fail():
    """An exception raised inside a measure callback propagates to the caller."""

    class FailingMeasureCallback(PyMeasureCallback):
        def apply(
            self,
            task_scheduler: TaskScheduler,
            task_id: int,
            measure_candidates: List[MeasureCandidate],
            builds: List[BuilderResult],
            results: List[RunnerResult],
        ) -> None:
            raise ValueError("test")

    callback = FailingMeasureCallback()
    with pytest.raises(ValueError, match="test"):
        scheduler = TaskScheduler()
        candidates = [MeasureCandidate(Schedule(Matmul), None)]
        builds = [BuilderResult("test_build", None)]
        results = [RunnerResult([1.0, 2.1], None)]
        callback.apply(scheduler, 0, candidates, builds, results)
def verify_trace_roundtrip(
    sch: Schedule,
    mod: Union[PrimFunc, IRModule],
    *,
    debug_mask: Union[str, int] = "all",
) -> Schedule:
    """Replay a traced schedule's JSON form on a fresh schedule and check it matches.

    The trace of ``sch`` is serialized to JSON and applied to a new schedule
    built from ``mod``. The new schedule must yield a structurally equal module
    and an identical Python rendering of its trace.

    Parameters
    ----------
    sch : tir.Schedule
        The traced TensorIR schedule to be verified
    mod : Union[PrimFunc, IRModule]
        The IRModule or PrimFunc to construct the fresh new schedule
    debug_mask : Union[str, int]
        Do extra correctness checking after the class creation and each time
        after calling the Replace method.
        Possible choices of `debug_mask`:
        1) "all" - Turn on all the checks
        2) "none" - Turn off all the checks
        3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask

    Returns
    -------
    new_sch : tir.Schedule
        The replayed schedule, in case the caller wants to inspect it.
    """
    original_trace = sch.trace
    assert original_trace is not None
    serialized = original_trace.as_json()

    replayed = Schedule(mod=mod, debug_mask=debug_mask)
    Trace.apply_json_to_schedule(json_obj=serialized, sch=replayed)
    assert structural_equal(replayed.mod, sch.mod)

    # The printed traces must also agree line-for-line.
    expected_text = "\n".join(original_trace.as_python())
    actual_text = "\n".join(replayed.trace.as_python())
    assert expected_text == actual_text
    return replayed
def test_meta_schedule_measure_callback():
    """A user-defined measure callback receives the candidates, builder and runner results passed to it."""

    @ms.derived_object
    class FancyMeasureCallback(ms.measure_callback.PyMeasureCallback):
        def apply(
            self,
            task_scheduler: ms.task_scheduler.TaskScheduler,
            task_id: int,
            measure_candidates: List[ms.MeasureCandidate],
            builder_results: List[ms.builder.BuilderResult],
            runner_results: List[ms.runner.RunnerResult],
        ) -> None:
            assert len(measure_candidates) == 1
            tvm.ir.assert_structural_equal(measure_candidates[0].sch.mod, Matmul)
            # Builder side: one successful artifact.
            assert len(builder_results) == 1
            assert builder_results[0].error_msg is None
            assert builder_results[0].artifact_path == "test_build"
            # Runner side: one successful run with two timings.
            assert len(runner_results) == 1
            assert runner_results[0].error_msg is None
            assert len(runner_results[0].run_secs) == 2

    callback = FancyMeasureCallback()
    scheduler = ms.task_scheduler.RoundRobin(
        tasks=[],
        task_weights=[],
        builder=DummyBuilder(),
        runner=DummyRunner(),
        database=ms.database.MemoryDatabase(),
        max_trials=1,
    )
    callback.apply(
        scheduler,
        0,
        [ms.MeasureCandidate(Schedule(Matmul), None)],
        [ms.builder.BuilderResult("test_build", None)],
        [ms.runner.RunnerResult([1.0, 2.1], None)],
    )
def _schedule_matmul_small(sch: Schedule):
    """Split the j loop and then the k loop of ``matmul`` into two sampled perfect tiles each."""
    matmul = sch.get_block("matmul")
    _, j, k = sch.get_loops(block=matmul)
    for loop in (j, k):
        factors = sch.sample_perfect_tile(loop, n=2)
        _, _ = sch.split(loop, factors)
def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
    """Apply ``compute_inline`` to every non-root block.

    The root block is returned unchanged (same schedule object). For any other
    block the schedule is copied first, so the caller's ``sch`` is not
    modified.

    Parameters
    ----------
    sch : Schedule
        The schedule to extend.
    block : BlockRV
        The block currently being visited.

    Returns
    -------
    List[Schedule]
        A single-element list with the (possibly new) schedule.
    """
    # Detect the root structurally (its sref has no parent) rather than by the
    # name "root", matching the `_is_root` check the other rules use.
    if _is_root(sch, block):
        return [sch]
    new_sch = sch.copy()
    new_sch.compute_inline(block)
    return [new_sch]
def _is_root(sch: Schedule, block: BlockRV) -> bool:
    """Return True when ``block``'s sref has no parent, i.e. it is the root block."""
    block_sref = sch.get_sref(block)
    return block_sref.parent is None
def bgemm(sch: Schedule):
    """Apply a fixed multi-level tiling schedule to the ``bgemm`` block.

    Every sampling call below carries a pinned ``decision``, so the resulting
    schedule is deterministic. The steps, in order, are:

    1. Annotate ``bgemm`` with tiling structure ``"SSSRRSRS"`` and add a
       ``"local"`` cache-write stage.
    2. Split the first four loops into 5 tiles each and the fifth loop into
       3 tiles, then reorder all tiles level by level.
    3. Fuse the first three tile levels and bind them to ``blockIdx.x``,
       ``vthread.x`` and ``threadIdx.x`` respectively.
    4. Add ``"shared"`` cache-read stages for read buffers 1 and 2. Attach
       them under the outermost tile of the fifth loop, fuse their last four
       loops, and annotate them with ``meta_schedule.cooperative_fetch``.
    5. Move the cache-write stage under the ``threadIdx.x`` loop.
    """
    b31 = sch.get_block(name="bgemm")
    sch.annotate(
        block_or_loop=b31,
        ann_key="meta_schedule.tiling_structure",
        ann_val="SSSRRSRS",
    )
    b32 = sch.cache_write(block=b31, write_buffer_index=0, storage_scope="local")
    # Swap the handles. From here on, b31 is the block returned by cache_write
    # and b32 is the original "bgemm" block.
    b31, b32 = b32, b31
    l33, l34, l35, l36, l37 = sch.get_loops(block=b32)
    # First four loops: 5-way perfect tiles with pinned decisions.
    v38, v39, v40, v41, v42 = sch.sample_perfect_tile(
        n=5,
        loop=l33,
        max_innermost_factor=64,
        decision=[1, 1, 1, 1, 6],
    )
    l43, l44, l45, l46, l47 = sch.split(loop=l33, factors=[v38, v39, v40, v41, v42])
    v48, v49, v50, v51, v52 = sch.sample_perfect_tile(
        n=5,
        loop=l34,
        max_innermost_factor=64,
        decision=[1, 1, 1, 3, 2],
    )
    l53, l54, l55, l56, l57 = sch.split(loop=l34, factors=[v48, v49, v50, v51, v52])
    v58, v59, v60, v61, v62 = sch.sample_perfect_tile(
        n=5,
        loop=l35,
        max_innermost_factor=64,
        decision=[3, 1, 1, 1, 3],
    )
    l63, l64, l65, l66, l67 = sch.split(loop=l35, factors=[v58, v59, v60, v61, v62])
    v68, v69, v70, v71, v72 = sch.sample_perfect_tile(
        n=5,
        loop=l36,
        max_innermost_factor=64,
        decision=[4, 2, 1, 4, 4],
    )
    l73, l74, l75, l76, l77 = sch.split(loop=l36, factors=[v68, v69, v70, v71, v72])
    # Fifth loop: 3-way perfect tile.
    v78, v79, v80 = sch.sample_perfect_tile(
        n=3,
        loop=l37,
        max_innermost_factor=64,
        decision=[32, 1, 4],
    )
    l81, l82, l83 = sch.split(loop=l37, factors=[v78, v79, v80])
    # Each row below is one tile level. The tiles of l37 (l81, l82, l83) sit at
    # the "R" positions of the annotated "SSSRRSRS" structure; presumably l37 is
    # the reduction loop (confirm against the bgemm definition).
    sch.reorder(
        # fmt: off
        l43, l53, l63, l73,
        l44, l54, l64, l74,
        l45, l55, l65, l75,
        l81,
        l82,
        l46, l56, l66, l76,
        l83,
        l47, l57, l67, l77,
        # fmt: on
    )
    l84 = sch.fuse(l43, l53, l63, l73)
    sch.bind(loop=l84, thread_axis="blockIdx.x")
    l85 = sch.fuse(l44, l54, l64, l74)
    sch.bind(loop=l85, thread_axis="vthread.x")
    l86 = sch.fuse(l45, l55, l65, l75)
    sch.bind(loop=l86, thread_axis="threadIdx.x")
    # Shared-memory stage for read buffer 1, attached under l81.
    b87 = sch.cache_read(block=b32, read_buffer_index=1, storage_scope="shared")
    sch.compute_at(block=b87, loop=l81, preserve_unit_loops=True)
    _, _, _, _, l92, l93, l94, l95 = sch.get_loops(block=b87)
    sch.fuse(l92, l93, l94, l95)
    # Pinned to index 1 of [1, 2, 3, 4]. NOTE(review): presumably a fetch
    # vector width consumed by a later postproc; confirm.
    v97 = sch.sample_categorical(
        candidates=[1, 2, 3, 4],
        probs=[0.25, 0.25, 0.25, 0.25],
        decision=1,
    )
    sch.annotate(
        block_or_loop=b87,
        ann_key="meta_schedule.cooperative_fetch",
        ann_val=v97,
    )
    # Shared-memory stage for read buffer 2, handled the same way.
    b101 = sch.cache_read(block=b32, read_buffer_index=2, storage_scope="shared")
    sch.compute_at(block=b101, loop=l81, preserve_unit_loops=True)
    _, _, _, _, l106, l107, l108, l109 = sch.get_loops(block=b101)
    sch.fuse(l106, l107, l108, l109)
    v110 = sch.sample_categorical(
        candidates=[1, 2, 3, 4],
        probs=[0.25, 0.25, 0.25, 0.25],
        decision=1,
    )
    sch.annotate(
        block_or_loop=b101,
        ann_key="meta_schedule.cooperative_fetch",
        ann_val=v110,
    )
    # b31 is the cache-write block (see the swap above); place it under the
    # threadIdx.x-bound loop.
    sch.reverse_compute_at(block=b31, loop=l86, preserve_unit_loops=True)
def data_pack(sch: Schedule):
    """Schedule ``data_pack`` with pinned tiling decisions.

    The block's six loops are handled as follows:

    * The first two loops are unrolled.
    * The third and fourth loops are split with 2-way perfect tiles (pinned to
      [3, 3] and [64, 2]).
    * The last two loops are unrolled.

    The result is reordered to
    (tile_a_hi, tile_b_hi, tile_a_lo, tile_b_lo, outer_0, outer_1, inner_0, inner_1).
    """
    blk = sch.get_block(name="data_pack")
    outer_0, outer_1, tile_a, tile_b, inner_0, inner_1 = sch.get_loops(block=blk)
    sch.unroll(loop=outer_0)
    sch.unroll(loop=outer_1)
    a_hi_factor, a_lo_factor = sch.sample_perfect_tile(
        n=2,
        loop=tile_a,
        max_innermost_factor=64,
        decision=[3, 3],
    )
    tile_a_hi, tile_a_lo = sch.split(loop=tile_a, factors=[a_hi_factor, a_lo_factor])
    b_hi_factor, b_lo_factor = sch.sample_perfect_tile(
        n=2,
        loop=tile_b,
        max_innermost_factor=64,
        decision=[64, 2],
    )
    tile_b_hi, tile_b_lo = sch.split(loop=tile_b, factors=[b_hi_factor, b_lo_factor])
    sch.unroll(loop=inner_0)
    sch.unroll(loop=inner_1)
    sch.reorder(
        tile_a_hi, tile_b_hi,
        tile_a_lo, tile_b_lo,
        outer_0, outer_1,
        inner_0, inner_1,
    )
def inline(sch: Schedule):
    """Inline block ``A`` and then block ``B``."""
    for block_name in ("A", "B"):
        sch.compute_inline(block=sch.get_block(name=block_name))
def _get_mod():  # pylint: disable=invalid-name
    """Build the expected scheduled module from ``conv2d_winograd_cuda``.

    Applies ``inline``, ``data_pack``, ``input_tile_data_pad``, ``bgemm`` and
    ``inverse``, in that order, to a fresh schedule and returns ``sch.mod``.
    Each step is defined locally, and every sampling call has a pinned
    decision.

    NOTE(review): no root-block unroll annotation and no ``conv2d_winograd``
    fuse/bind step are applied here. The CUDA test in this file pins decisions
    labelled "root anno" and "conv2d"; confirm this expected module matches
    the trace that test replays.
    """

    def inline(sch: Schedule):
        """Inline blocks ``A`` and ``B``."""
        b125 = sch.get_block(name="A")
        sch.compute_inline(block=b125)
        b126 = sch.get_block(name="B")
        sch.compute_inline(block=b126)

    def input_tile_data_pad(sch: Schedule):
        """Attach ``input_tile`` under its consumer's fourth loop in local scope; inline ``data_pad``."""
        b115 = sch.get_block(name="input_tile")
        (b116, ) = sch.get_consumers(block=b115)
        _, _, _, l120, _, _, _, _ = sch.get_loops(block=b116)
        sch.compute_at(block=b115, loop=l120, preserve_unit_loops=True)
        sch.set_scope(block=b115, buffer_index=0, storage_scope="local")
        b127 = sch.get_block(name="data_pad")
        sch.compute_inline(block=b127)

    def data_pack(sch: Schedule):
        """Unroll the outer two and inner two loops of ``data_pack``; tile the middle two."""
        b16 = sch.get_block(name="data_pack")
        l17, l18, l19, l20, l21, l22 = sch.get_loops(block=b16)
        sch.unroll(loop=l17)
        sch.unroll(loop=l18)
        v23, v24 = sch.sample_perfect_tile(
            n=2,
            loop=l19,
            max_innermost_factor=64,
            decision=[3, 3],
        )
        l25, l26 = sch.split(loop=l19, factors=[v23, v24])
        v27, v28 = sch.sample_perfect_tile(
            n=2,
            loop=l20,
            max_innermost_factor=64,
            decision=[64, 2],
        )
        l29, l30 = sch.split(loop=l20, factors=[v27, v28])
        sch.unroll(loop=l21)
        sch.unroll(loop=l22)
        sch.reorder(l25, l29, l26, l30, l17, l18, l21, l22)

    def bgemm(sch: Schedule):
        """Multi-level tile ``bgemm``, bind to GPU axes, and add shared/local cache stages."""
        b31 = sch.get_block(name="bgemm")
        sch.annotate(
            block_or_loop=b31,
            ann_key="meta_schedule.tiling_structure",
            ann_val="SSSRRSRS",
        )
        b32 = sch.cache_write(block=b31, write_buffer_index=0, storage_scope="local")
        # After the swap, b31 is the cache-write block and b32 is "bgemm".
        b31, b32 = b32, b31
        l33, l34, l35, l36, l37 = sch.get_loops(block=b32)
        v38, v39, v40, v41, v42 = sch.sample_perfect_tile(
            n=5,
            loop=l33,
            max_innermost_factor=64,
            decision=[1, 1, 1, 1, 6],
        )
        l43, l44, l45, l46, l47 = sch.split(loop=l33, factors=[v38, v39, v40, v41, v42])
        v48, v49, v50, v51, v52 = sch.sample_perfect_tile(
            n=5,
            loop=l34,
            max_innermost_factor=64,
            decision=[1, 1, 1, 3, 2],
        )
        l53, l54, l55, l56, l57 = sch.split(loop=l34, factors=[v48, v49, v50, v51, v52])
        v58, v59, v60, v61, v62 = sch.sample_perfect_tile(
            n=5,
            loop=l35,
            max_innermost_factor=64,
            decision=[3, 1, 1, 1, 3],
        )
        l63, l64, l65, l66, l67 = sch.split(loop=l35, factors=[v58, v59, v60, v61, v62])
        v68, v69, v70, v71, v72 = sch.sample_perfect_tile(
            n=5,
            loop=l36,
            max_innermost_factor=64,
            decision=[4, 2, 1, 4, 4],
        )
        l73, l74, l75, l76, l77 = sch.split(loop=l36, factors=[v68, v69, v70, v71, v72])
        v78, v79, v80 = sch.sample_perfect_tile(
            n=3,
            loop=l37,
            max_innermost_factor=64,
            decision=[32, 1, 4],
        )
        l81, l82, l83 = sch.split(loop=l37, factors=[v78, v79, v80])
        # One row per tile level, matching the "SSSRRSRS" annotation above.
        sch.reorder(
            # fmt: off
            l43, l53, l63, l73,
            l44, l54, l64, l74,
            l45, l55, l65, l75,
            l81,
            l82,
            l46, l56, l66, l76,
            l83,
            l47, l57, l67, l77,
            # fmt: on
        )
        l84 = sch.fuse(l43, l53, l63, l73)
        sch.bind(loop=l84, thread_axis="blockIdx.x")
        l85 = sch.fuse(l44, l54, l64, l74)
        sch.bind(loop=l85, thread_axis="vthread.x")
        l86 = sch.fuse(l45, l55, l65, l75)
        sch.bind(loop=l86, thread_axis="threadIdx.x")
        b87 = sch.cache_read(block=b32, read_buffer_index=1, storage_scope="shared")
        sch.compute_at(block=b87, loop=l81, preserve_unit_loops=True)
        _, _, _, _, l92, l93, l94, l95 = sch.get_loops(block=b87)
        sch.fuse(l92, l93, l94, l95)
        v97 = sch.sample_categorical(
            candidates=[1, 2, 3, 4],
            probs=[0.25, 0.25, 0.25, 0.25],
            decision=1,
        )
        sch.annotate(
            block_or_loop=b87,
            ann_key="meta_schedule.cooperative_fetch",
            ann_val=v97,
        )
        b101 = sch.cache_read(block=b32, read_buffer_index=2, storage_scope="shared")
        sch.compute_at(block=b101, loop=l81, preserve_unit_loops=True)
        _, _, _, _, l106, l107, l108, l109 = sch.get_loops(block=b101)
        sch.fuse(l106, l107, l108, l109)
        v110 = sch.sample_categorical(
            candidates=[1, 2, 3, 4],
            probs=[0.25, 0.25, 0.25, 0.25],
            decision=1,
        )
        sch.annotate(
            block_or_loop=b101,
            ann_key="meta_schedule.cooperative_fetch",
            ann_val=v110,
        )
        # Place the cache-write block under the threadIdx.x-bound loop.
        sch.reverse_compute_at(block=b31, loop=l86, preserve_unit_loops=True)

    def inverse(sch: Schedule):
        """Unroll the outer two and inner two loops of ``inverse``; tile the middle two."""
        b1 = sch.get_block(name="inverse")
        l2, l3, l4, l5, l6, l7 = sch.get_loops(block=b1)
        sch.unroll(loop=l2)
        sch.unroll(loop=l3)
        v8, v9 = sch.sample_perfect_tile(
            n=2,
            loop=l4,
            max_innermost_factor=64,
            decision=[3, 3],
        )
        l10, l11 = sch.split(loop=l4, factors=[v8, v9])
        v12, v13 = sch.sample_perfect_tile(
            n=2,
            loop=l5,
            max_innermost_factor=64,
            decision=[2, 64],
        )
        l14, l15 = sch.split(loop=l5, factors=[v12, v13])
        sch.unroll(loop=l6)
        sch.unroll(loop=l7)
        sch.reorder(l10, l14, l11, l15, l2, l3, l6, l7)

    # pylint: enable=invalid-name
    sch = Schedule(mod=conv2d_winograd_cuda)
    inline(sch)
    data_pack(sch)
    input_tile_data_pad(sch)
    bgemm(sch)
    inverse(sch)
    return sch.mod
def inverse(sch: Schedule):
    """Schedule ``inverse`` with pinned tiling decisions.

    The block's six loops are handled as follows:

    * The first two loops are unrolled.
    * The third and fourth loops are split with 2-way perfect tiles (pinned to
      [3, 3] and [2, 64]).
    * The last two loops are unrolled.

    The result is reordered to
    (tile_a_hi, tile_b_hi, tile_a_lo, tile_b_lo, outer_0, outer_1, inner_0, inner_1).
    """
    blk = sch.get_block(name="inverse")
    outer_0, outer_1, tile_a, tile_b, inner_0, inner_1 = sch.get_loops(block=blk)
    sch.unroll(loop=outer_0)
    sch.unroll(loop=outer_1)
    a_hi_factor, a_lo_factor = sch.sample_perfect_tile(
        n=2,
        loop=tile_a,
        max_innermost_factor=64,
        decision=[3, 3],
    )
    tile_a_hi, tile_a_lo = sch.split(loop=tile_a, factors=[a_hi_factor, a_lo_factor])
    b_hi_factor, b_lo_factor = sch.sample_perfect_tile(
        n=2,
        loop=tile_b,
        max_innermost_factor=64,
        decision=[2, 64],
    )
    tile_b_hi, tile_b_lo = sch.split(loop=tile_b, factors=[b_hi_factor, b_lo_factor])
    sch.unroll(loop=inner_0)
    sch.unroll(loop=inner_1)
    sch.reorder(
        tile_a_hi, tile_b_hi,
        tile_a_lo, tile_b_lo,
        outer_0, outer_1,
        inner_0, inner_1,
    )
def test_vectorize_inner_loop():
    """RewriteParallelVectorizeUnroll turns ``before_matmul_vectorize`` into ``after_matmul_vectorize``."""
    schedule = Schedule(before_matmul_vectorize)
    postproc = RewriteParallelVectorizeUnroll()
    applied = postproc.apply(schedule)
    assert applied
    tvm.ir.assert_structural_equal(schedule.mod["main"], after_matmul_vectorize)