Example #1
 def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
     if _is_root(sch, block):
         return [sch]
     sch = sch.copy()
     if sch.get(block).name_hint == "B":
         sch.compute_inline(block)
     return [sch]
Example #2
    def input_tile_data_pad(sch: Schedule):
        b115 = sch.get_block(name="input_tile")
        (b116, ) = sch.get_consumers(block=b115)
        _, _, _, l120, _, _, _, _ = sch.get_loops(block=b116)
        sch.compute_at(block=b115, loop=l120, preserve_unit_loops=True)
        sch.set_scope(block=b115, buffer_index=0, storage_scope="local")

        b127 = sch.get_block(name="data_pad")
        sch.compute_inline(block=b127)
Example #3
 def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
     if _is_root(sch, block):
         return [sch]
     new_sch = sch.copy()
     i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3 = new_sch.get_loops(block=block)
     new_sch.reorder(i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3, i_0, j_0)
     result = [new_sch]
     new_sch = sch.copy()
     i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3 = new_sch.get_loops(block=block)
     new_sch.reorder(i_1, j_3, i_0, j_0, j_1, k_0, i_2, j_2, k_1, i_3)
     result.append(new_sch)
     return result
Example #4
def _schedule_matmul(sch: Schedule):
    block = sch.get_block("matmul")
    i, j, k = sch.get_loops(block=block)
    i_0, i_1, i_2, i_3 = sch.split(i, sch.sample_perfect_tile(i, n=4))
    j_0, j_1, j_2, j_3 = sch.split(j, sch.sample_perfect_tile(j, n=4))
    k_0, k_1 = sch.split(k, sch.sample_perfect_tile(k, n=2))
    sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3)
Example #5
 def root_anno(sch: Schedule):
     b8 = sch.get_block(name="root", func_name="main")
     v140 = sch.sample_categorical(
         candidates=[0, 16, 64, 512, 1024],
         probs=[
             0.20000000000000001,
             0.20000000000000001,
             0.20000000000000001,
             0.20000000000000001,
             0.20000000000000001,
         ],
         decision=2,
     )
     sch.annotate(block_or_loop=b8, ann_key="meta_schedule.unroll_explicit", ann_val=v140)
Example #6
    def input_tile_data_pad(sch: Schedule):
        b78 = sch.get_block(name="input_tile")
        l80 = sch.sample_compute_location(block=b78, decision=4)
        sch.compute_at(block=b78, loop=l80, preserve_unit_loops=True)

        b81 = sch.get_block(name="data_pad")
        l83 = sch.sample_compute_location(block=b81, decision=-2)
        sch.compute_at(block=b81, loop=l83, preserve_unit_loops=True)
Example #7
def test_meta_schedule_measure_callback():
    @derived_object
    class FancyMeasureCallback(PyMeasureCallback):
        def apply(
            self,
            task_scheduler: TaskScheduler,
            task_id: int,
            measure_candidates: List[MeasureCandidate],
            builds: List[BuilderResult],
            results: List[RunnerResult],
        ) -> None:
            assert len(measure_candidates) == 1
            assert_structural_equal(measure_candidates[0].sch.mod, Matmul)
            assert (len(builds) == 1 and builds[0].error_msg is None
                    and builds[0].artifact_path == "test_build")
            assert (len(results) == 1 and results[0].error_msg is None
                    and len(results[0].run_secs) == 2)

    measure_callback = FancyMeasureCallback()
    measure_callback.apply(
        RoundRobin([], [],
                   DummyBuilder(),
                   DummyRunner(),
                   DummyDatabase(),
                   max_trials=1),
        0,
        [MeasureCandidate(Schedule(Matmul), None)],
        [BuilderResult("test_build", None)],
        [RunnerResult([1.0, 2.1], None)],
    )
Example #8
    def generate_design_space(self, mod: IRModule) -> List[Schedule]:
        """Generate design spaces given a module.

        Parameters
        ----------
        mod : IRModule
            The module used for design space generation.

        Returns
        -------
        design_spaces : List[Schedule]
            The generated design spaces, i.e., schedules.
        """
        sch = Schedule(mod)  # Make sure the schedule is traced
        result = self.sch_fn(sch)  # Call the schedule function
        if result is None:  # Case 1. No output
            return [sch]
        if isinstance(result, Schedule):  # Case 2. Single output
            return [result]
        if isinstance(result,
                      (list, tuple, Array)):  # Case 3. Multiple outputs
            for ret in result:  # enumerate the outputs
                if not isinstance(ret, Schedule):
                    raise TypeError(
                        "Wrong type of element in the list, expected Schedule got "
                        + f"'{type(ret)}': {ret}")
            return result
        raise TypeError(f"Unexpected return type {type(result)}: {result}")
Example #9
 def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
     if _is_root(sch, block):
         return [sch]
     new_sch = sch.copy()
     i, j = new_sch.get_loops(block=block)
     i_0, i_1 = new_sch.split(loop=i, factors=[16, 64])
     j_0, j_1 = new_sch.split(loop=j, factors=[64, 16])
     new_sch.reorder(i_0, j_0, i_1, j_1)
     result = [new_sch]
     new_sch = sch.copy()
     i, j = new_sch.get_loops(block=block)
     i_0, i_1 = new_sch.split(loop=i, factors=[2, 512])
     j_0, j_1 = new_sch.split(loop=j, factors=[2, 512])
     new_sch.reorder(i_0, j_0, i_1, j_1)
     result.append(new_sch)
     return result
Example #10
def test_meta_schedule_postproc_rewrite_parallel_unroll_vectorize():
    postproc = RewriteParallelVectorizeUnroll()
    sch = Schedule(Move_PUV)
    assert postproc.apply(sch)
    print(sch.mod["main"].script())
    mod = tvm.tir.transform.Simplify()(sch.mod)
    tvm.ir.assert_structural_equal(mod["main"], Move_PUV0)
Example #11
def verify_trace_roundtrip(
    sch: Schedule,
    mod: Union[PrimFunc, IRModule],
    *,
    debug_mask: Union[str, int] = "all",
    text_format: Union[str, Sequence[str]] = ["python", "json"],
) -> Schedule:
    """Serialize a traced schedule to JSON, then replay the JSON trace by applying to
    a fresh new schedule, verifying the reproducibility of scheduling.

    Parameters
    ----------
    sch : tir.Schedule
        The traced TensorIR schedule to be verified
    mod : Union[PrimFunc, IRModule]
        The IRModule or PrimFunc to construct the fresh new schedule
    debug_mask : Union[str, int]
        Do extra correctness checking after the class creation and each time
        after calling the Replace method.
        Possible choices of `debug_mask`:
        1) "all" - Turn on all the checks
        2) "none" - Turn off all the checks
        3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask
    text_format: Union[str, Sequence[str]]
        The text format or formats whose round-trip behavior should be
        validated.  If a single string, validate the round-trip through that
        format only; if a sequence, validate each listed format in turn.
    """
    if not isinstance(text_format, str):
        for opt in text_format:
            new_sch = verify_trace_roundtrip(sch,
                                             mod,
                                             debug_mask=debug_mask,
                                             text_format=opt)
        return new_sch

    trace = sch.trace
    assert trace is not None

    # Step 1. Perform a round-trip through the text-format
    new_sch = Schedule(mod=mod, debug_mask=debug_mask)
    if text_format == "json":
        json_obj = trace.as_json()
        Trace.apply_json_to_schedule(json_obj=json_obj, sch=new_sch)
    elif text_format == "python":
        py_trace = "\n".join(trace.as_python())
        exec(py_trace, tvm.tir.__dict__, {"sch": new_sch})  # pylint: disable=exec-used
    else:
        assert text_format in ("json",
                               "python"), f"Unknown text format: {text_format}"

    # Step 2. Verify that the round-trip produced the same scheduling
    assert structural_equal(new_sch.mod, sch.mod)

    # Step 3. Check the consistency of the text format between the old and new traces
    py_repr = "\n".join(trace.as_python())
    new_py_repr = "\n".join(new_sch.trace.as_python())
    assert py_repr == new_py_repr

    # Step 4. Return the new schedule in case it could be useful
    return new_sch
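A hedged usage sketch of the helper above, assuming `Matmul` is a PrimFunc with a "matmul" block such as the one used in the measure-callback tests; the split factor is arbitrary.

# Illustrative only: `Matmul` stands in for any PrimFunc/IRModule with a "matmul" block.
sch = Schedule(Matmul)
block = sch.get_block("matmul")
i, _j, _k = sch.get_loops(block=block)
sch.split(loop=i, factors=[None, 32])
# Round-trip the recorded trace through both supported text formats.
verify_trace_roundtrip(sch, Matmul, text_format=["python", "json"])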
Example #12
def test_meta_schedule_measure_callback_fail():
    @ms.derived_object
    class FailingMeasureCallback(ms.measure_callback.PyMeasureCallback):
        def apply(
            self,
            task_scheduler: ms.task_scheduler.TaskScheduler,
            task_id: int,
            measure_candidates: List[ms.MeasureCandidate],
            builder_results: List[ms.builder.BuilderResult],
            runner_results: List[ms.runner.RunnerResult],
        ) -> None:
            raise ValueError("test")

    measure_callback = FailingMeasureCallback()
    with pytest.raises(ValueError, match="test"):
        measure_callback.apply(
            ms.task_scheduler.RoundRobin(
                tasks=[],
                task_weights=[],
                builder=DummyBuilder(),
                runner=DummyRunner(),
                database=ms.database.MemoryDatabase(),
                max_trials=1,
            ),
            0,
            [ms.MeasureCandidate(Schedule(Matmul), None)],
            [ms.builder.BuilderResult("test_build", None)],
            [ms.runner.RunnerResult([1.0, 2.1], None)],
        )
Example #13
def test_meta_schedule_measure_callback_fail():
    @derived_object
    class FailingMeasureCallback(PyMeasureCallback):
        def apply(
            self,
            task_scheduler: TaskScheduler,
            task_id: int,
            measure_candidates: List[MeasureCandidate],
            builds: List[BuilderResult],
            results: List[RunnerResult],
        ) -> None:
            raise ValueError("test")

    measure_callback = FailingMeasureCallback()
    with pytest.raises(ValueError, match="test"):
        measure_callback.apply(
            RoundRobin([], [],
                       DummyBuilder(),
                       DummyRunner(),
                       DummyDatabase(),
                       max_trials=1),
            0,
            [MeasureCandidate(Schedule(Matmul), None)],
            [BuilderResult("test_build", None)],
            [RunnerResult([1.0, 2.1], None)],
        )
Example #14
 def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
     if _is_root(sch, block):
         return [sch]
     new_sch = sch.copy()
     i, j, k = new_sch.get_loops(block=block)
     i_0, i_1, i_2, i_3 = new_sch.split(loop=i, factors=[4, 64, 2, 2])
     j_0, j_1, j_2, j_3 = new_sch.split(loop=j, factors=[2, 4, 64, 2])
     k_0, k_1 = new_sch.split(loop=k, factors=[32, 32])
     new_sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3)
     result = [new_sch]
     new_sch = sch.copy()
     i, j, k = new_sch.get_loops(block=block)
     i_0, i_1, i_2, i_3 = new_sch.split(loop=i, factors=[4, 64, 2, 2])
     j_0, j_1, j_2, j_3 = new_sch.split(loop=j, factors=[2, 4, 64, 2])
     k_0, k_1 = new_sch.split(loop=k, factors=[32, 32])
     new_sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3)
     result.append(new_sch)
     return result
Example #15
def schedule_matmul(sch: Schedule):
    block = sch.get_block("matmul")
    i, j, k = sch.get_loops(block=block)
    # TODO(@zxybazh): Change to `sample_perfect_tile` after upstreaming
    i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2])
    j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[4, 64, 2, 2])
    k_0, k_1 = sch.split(loop=k, factors=[32, 32])
    sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3)
Example #16
 def conv2d(sch: Schedule):
     b7 = sch.get_block(name="conv2d_winograd")
     l141, l142, l143, l144 = sch.get_loops(block=b7)
     l145 = sch.fuse(l141, l142, l143, l144)
     v146 = sch.sample_categorical(
         candidates=[32, 64, 128, 256, 512, 1024],
         probs=[
             0.16666666666666666,
             0.16666666666666666,
             0.16666666666666666,
             0.16666666666666666,
             0.16666666666666666,
             0.16666666666666666,
         ],
         decision=2,
     )
     l147, l148 = sch.split(loop=l145, factors=[None, v146])
     sch.bind(loop=l147, thread_axis="blockIdx.x")
     sch.bind(loop=l148, thread_axis="threadIdx.x")
Example #17
def test_conv2d_winograd_cuda():
    mod = conv2d_winograd_cuda
    mod = IRModule({"main": mod})
    context = TuneContext(
        mod=mod,
        target=Target("nvidia/geforce-rtx-3090", host="llvm"),
        task_name="Custom Search Space Task",
        sch_rules=DefaultCUDA._sch_rules(),  # pylint: disable=protected-access
    )
    for sch_rule in context.sch_rules:
        sch_rule.initialize_with_tune_context(context)
    post_order_apply = PostOrderApply()
    post_order_apply.initialize_with_tune_context(context)
    (sch,) = post_order_apply.generate_design_space(mod)
    decisions = dict(
        zip(
            [i for i in sch.trace.insts if i.kind.name.startswith("Sample")],
            [
                # data_pack
                [3, 3],
                [64, 2],
                2,
                # inverse
                [3, 3],
                [2, 64],
                2,
                # bgemm
                [1, 1, 1, 1, 6],
                [1, 1, 1, 3, 2],
                [3, 1, 1, 1, 3],
                [4, 2, 1, 4, 4],
                [32, 1, 4],
                1,
                1,
                # root anno
                2,
                # conv2d
                2,
            ],
        )
    )
    trace = Trace(sch.trace.insts, decisions=decisions)
    sch = Schedule(mod=mod)
    trace.apply_to_schedule(sch, remove_postproc=False)
    answer = sch.mod
    expected = _get_mod()
    tvm.ir.assert_structural_equal(answer, expected)
Example #18
def test_conv2d_winograd_cpu():
    mod = conv2d_winograd_cpu
    mod = IRModule({"main": mod})
    context = TuneContext(
        mod=mod,
        target=Target("llvm"),
        task_name="Custom Search Space Task",
        sch_rules=DefaultLLVM._sch_rules(),  # pylint: disable=protected-access
    )
    post_order_apply = PostOrderApply()
    post_order_apply.initialize_with_tune_context(context)
    (sch, ) = post_order_apply.generate_design_space(mod)

    decisions = dict(
        zip(
            [
                i for i in sch.trace.insts[:-4]
                if i.kind.name.startswith("Sample")
            ],
            [
                # data_pack
                [9, 1],
                [32, 4],
                # input_tile
                4,
                # data_pad
                -2,
                # inverse
                [1, 9],
                [2, 64],
                # bgemm
                [1, 2, 3, 1],
                [1, 1, 1, 6],
                [1, 1, 1, 9],
                [2, 1, 16, 4],
                [16, 8],
            ],
        ))
    trace = Trace(sch.trace.insts[:-4], decisions=decisions)
    sch = Schedule(mod=mod)
    trace.apply_to_schedule(sch, remove_postproc=False)
    answer = sch.mod
    expected = _get_mod()
    tvm.ir.assert_structural_equal(answer, expected)
Example #19
def test_meta_schedule_measure_callback_fail():
    class FailingMeasureCallback(PyMeasureCallback):
        def apply(
            self,
            task_scheduler: TaskScheduler,
            task_id: int,
            measure_candidates: List[MeasureCandidate],
            builds: List[BuilderResult],
            results: List[RunnerResult],
        ) -> None:
            raise ValueError("test")

    measure_callback = FailingMeasureCallback()
    with pytest.raises(ValueError, match="test"):
        measure_callback.apply(
            TaskScheduler(),
            0,
            [MeasureCandidate(Schedule(Matmul), None)],
            [BuilderResult("test_build", None)],
            [RunnerResult([1.0, 2.1], None)],
        )
Example #20
def verify_trace_roundtrip(
    sch: Schedule,
    mod: Union[PrimFunc, IRModule],
    *,
    debug_mask: Union[str, int] = "all",
) -> Schedule:
    """Serialize a traced schedule to JSON, then replay the JSON trace by applying to
    a fresh new schedule, verifying the reproducibility of scheduling.

    Parameters
    ----------
    sch : tir.Schedule
        The traced TensorIR schedule to be verified
    mod : Union[PrimFunc, IRModule]
        The IRModule or PrimFunc to construct the fresh new schedule
    debug_mask : Union[str, int]
        Do extra correctness checking after the class creation and each time
        after calling the Replace method.
        Possible choices of `debug_mask`:
        1) "all" - Turn on all the checks
        2) "none" - Turn off all the checks
        3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask
    """
    # Step 1. Serialize the trace to JSON
    trace = sch.trace
    assert trace is not None
    json_obj = trace.as_json()
    # Step 2. Apply the JSON trace to a new schedule, then check if it reproduces the scheduling
    new_sch = Schedule(mod=mod, debug_mask=debug_mask)
    Trace.apply_json_to_schedule(json_obj=json_obj, sch=new_sch)
    assert structural_equal(new_sch.mod, sch.mod)
    # Step 3. Check the consistency of the text format between the old and new traces
    py_repr = "\n".join(trace.as_python())
    new_py_repr = "\n".join(new_sch.trace.as_python())
    assert py_repr == new_py_repr
    # Step 4. Return the new schedule in case it could be useful
    return new_sch
Example #21
def test_meta_schedule_measure_callback():
    @ms.derived_object
    class FancyMeasureCallback(ms.measure_callback.PyMeasureCallback):
        def apply(
            self,
            task_scheduler: ms.task_scheduler.TaskScheduler,
            task_id: int,
            measure_candidates: List[ms.MeasureCandidate],
            builder_results: List[ms.builder.BuilderResult],
            runner_results: List[ms.runner.RunnerResult],
        ) -> None:
            assert len(measure_candidates) == 1
            tvm.ir.assert_structural_equal(measure_candidates[0].sch.mod,
                                           Matmul)
            assert (len(builder_results) == 1
                    and builder_results[0].error_msg is None
                    and builder_results[0].artifact_path == "test_build")
            assert (len(runner_results) == 1
                    and runner_results[0].error_msg is None
                    and len(runner_results[0].run_secs) == 2)

    measure_callback = FancyMeasureCallback()
    measure_callback.apply(
        ms.task_scheduler.RoundRobin(
            tasks=[],
            task_weights=[],
            builder=DummyBuilder(),
            runner=DummyRunner(),
            database=ms.database.MemoryDatabase(),
            max_trials=1,
        ),
        0,
        [ms.MeasureCandidate(Schedule(Matmul), None)],
        [ms.builder.BuilderResult("test_build", None)],
        [ms.runner.RunnerResult([1.0, 2.1], None)],
    )
Example #22
 def _schedule_matmul_small(sch: Schedule):
     block = sch.get_block("matmul")
     _, j, k = sch.get_loops(block=block)
     _, _ = sch.split(j, sch.sample_perfect_tile(j, n=2))
     _, _ = sch.split(k, sch.sample_perfect_tile(k, n=2))
Example #23
 def apply(self, sch: Schedule, block: BlockRV):
     if sch.get(block).name_hint == "root":
         return [sch]
     sch = sch.copy()
     sch.compute_inline(block)
     return [sch]
Example #24
def _is_root(sch: Schedule, block: BlockRV) -> bool:
    return sch.get_sref(block).parent is None
Example #25
    def bgemm(sch: Schedule):
        b31 = sch.get_block(name="bgemm")
        sch.annotate(
            block_or_loop=b31,
            ann_key="meta_schedule.tiling_structure",
            ann_val="SSSRRSRS",
        )
        b32 = sch.cache_write(block=b31,
                              write_buffer_index=0,
                              storage_scope="local")
        b31, b32 = b32, b31
        l33, l34, l35, l36, l37 = sch.get_loops(block=b32)
        v38, v39, v40, v41, v42 = sch.sample_perfect_tile(
            n=5,
            loop=l33,
            max_innermost_factor=64,
            decision=[1, 1, 1, 1, 6],
        )
        l43, l44, l45, l46, l47 = sch.split(loop=l33,
                                            factors=[v38, v39, v40, v41, v42])
        v48, v49, v50, v51, v52 = sch.sample_perfect_tile(
            n=5,
            loop=l34,
            max_innermost_factor=64,
            decision=[1, 1, 1, 3, 2],
        )
        l53, l54, l55, l56, l57 = sch.split(loop=l34,
                                            factors=[v48, v49, v50, v51, v52])
        v58, v59, v60, v61, v62 = sch.sample_perfect_tile(
            n=5,
            loop=l35,
            max_innermost_factor=64,
            decision=[3, 1, 1, 1, 3],
        )
        l63, l64, l65, l66, l67 = sch.split(loop=l35,
                                            factors=[v58, v59, v60, v61, v62])
        v68, v69, v70, v71, v72 = sch.sample_perfect_tile(
            n=5,
            loop=l36,
            max_innermost_factor=64,
            decision=[4, 2, 1, 4, 4],
        )
        l73, l74, l75, l76, l77 = sch.split(loop=l36,
                                            factors=[v68, v69, v70, v71, v72])
        v78, v79, v80 = sch.sample_perfect_tile(
            n=3,
            loop=l37,
            max_innermost_factor=64,
            decision=[32, 1, 4],
        )
        l81, l82, l83 = sch.split(loop=l37, factors=[v78, v79, v80])
        sch.reorder(
            # fmt: off
            l43,
            l53,
            l63,
            l73,
            l44,
            l54,
            l64,
            l74,
            l45,
            l55,
            l65,
            l75,
            l81,
            l82,
            l46,
            l56,
            l66,
            l76,
            l83,
            l47,
            l57,
            l67,
            l77,
            # fmt: on
        )
        l84 = sch.fuse(l43, l53, l63, l73)
        sch.bind(loop=l84, thread_axis="blockIdx.x")
        l85 = sch.fuse(l44, l54, l64, l74)
        sch.bind(loop=l85, thread_axis="vthread.x")
        l86 = sch.fuse(l45, l55, l65, l75)
        sch.bind(loop=l86, thread_axis="threadIdx.x")

        b87 = sch.cache_read(block=b32,
                             read_buffer_index=1,
                             storage_scope="shared")
        sch.compute_at(block=b87, loop=l81, preserve_unit_loops=True)
        _, _, _, _, l92, l93, l94, l95 = sch.get_loops(block=b87)
        sch.fuse(l92, l93, l94, l95)
        v97 = sch.sample_categorical(
            candidates=[1, 2, 3, 4],
            probs=[0.25, 0.25, 0.25, 0.25],
            decision=1,
        )
        sch.annotate(
            block_or_loop=b87,
            ann_key="meta_schedule.cooperative_fetch",
            ann_val=v97,
        )

        b101 = sch.cache_read(block=b32,
                              read_buffer_index=2,
                              storage_scope="shared")
        sch.compute_at(block=b101, loop=l81, preserve_unit_loops=True)
        _, _, _, _, l106, l107, l108, l109 = sch.get_loops(block=b101)
        sch.fuse(l106, l107, l108, l109)
        v110 = sch.sample_categorical(
            candidates=[1, 2, 3, 4],
            probs=[0.25, 0.25, 0.25, 0.25],
            decision=1,
        )
        sch.annotate(
            block_or_loop=b101,
            ann_key="meta_schedule.cooperative_fetch",
            ann_val=v110,
        )

        sch.reverse_compute_at(block=b31, loop=l86, preserve_unit_loops=True)
Example #26
 def data_pack(sch: Schedule):
     b16 = sch.get_block(name="data_pack")
     l17, l18, l19, l20, l21, l22 = sch.get_loops(block=b16)
     sch.unroll(loop=l17)
     sch.unroll(loop=l18)
     v23, v24 = sch.sample_perfect_tile(
         n=2,
         loop=l19,
         max_innermost_factor=64,
         decision=[3, 3],
     )
     l25, l26 = sch.split(loop=l19, factors=[v23, v24])
     v27, v28 = sch.sample_perfect_tile(
         n=2,
         loop=l20,
         max_innermost_factor=64,
         decision=[64, 2],
     )
     l29, l30 = sch.split(loop=l20, factors=[v27, v28])
     sch.unroll(loop=l21)
     sch.unroll(loop=l22)
     sch.reorder(l25, l29, l26, l30, l17, l18, l21, l22)
Example #27
 def inline(sch: Schedule):
     b125 = sch.get_block(name="A")
     sch.compute_inline(block=b125)
     b126 = sch.get_block(name="B")
     sch.compute_inline(block=b126)
Example #28
def _get_mod():
    # pylint: disable=invalid-name
    def inline(sch: Schedule):
        b125 = sch.get_block(name="A")
        sch.compute_inline(block=b125)
        b126 = sch.get_block(name="B")
        sch.compute_inline(block=b126)

    def input_tile_data_pad(sch: Schedule):
        b115 = sch.get_block(name="input_tile")
        (b116, ) = sch.get_consumers(block=b115)
        _, _, _, l120, _, _, _, _ = sch.get_loops(block=b116)
        sch.compute_at(block=b115, loop=l120, preserve_unit_loops=True)
        sch.set_scope(block=b115, buffer_index=0, storage_scope="local")

        b127 = sch.get_block(name="data_pad")
        sch.compute_inline(block=b127)

    def data_pack(sch: Schedule):
        b16 = sch.get_block(name="data_pack")
        l17, l18, l19, l20, l21, l22 = sch.get_loops(block=b16)
        sch.unroll(loop=l17)
        sch.unroll(loop=l18)
        v23, v24 = sch.sample_perfect_tile(
            n=2,
            loop=l19,
            max_innermost_factor=64,
            decision=[3, 3],
        )
        l25, l26 = sch.split(loop=l19, factors=[v23, v24])
        v27, v28 = sch.sample_perfect_tile(
            n=2,
            loop=l20,
            max_innermost_factor=64,
            decision=[64, 2],
        )
        l29, l30 = sch.split(loop=l20, factors=[v27, v28])
        sch.unroll(loop=l21)
        sch.unroll(loop=l22)
        sch.reorder(l25, l29, l26, l30, l17, l18, l21, l22)

    def bgemm(sch: Schedule):
        b31 = sch.get_block(name="bgemm")
        sch.annotate(
            block_or_loop=b31,
            ann_key="meta_schedule.tiling_structure",
            ann_val="SSSRRSRS",
        )
        b32 = sch.cache_write(block=b31,
                              write_buffer_index=0,
                              storage_scope="local")
        b31, b32 = b32, b31
        l33, l34, l35, l36, l37 = sch.get_loops(block=b32)
        v38, v39, v40, v41, v42 = sch.sample_perfect_tile(
            n=5,
            loop=l33,
            max_innermost_factor=64,
            decision=[1, 1, 1, 1, 6],
        )
        l43, l44, l45, l46, l47 = sch.split(loop=l33,
                                            factors=[v38, v39, v40, v41, v42])
        v48, v49, v50, v51, v52 = sch.sample_perfect_tile(
            n=5,
            loop=l34,
            max_innermost_factor=64,
            decision=[1, 1, 1, 3, 2],
        )
        l53, l54, l55, l56, l57 = sch.split(loop=l34,
                                            factors=[v48, v49, v50, v51, v52])
        v58, v59, v60, v61, v62 = sch.sample_perfect_tile(
            n=5,
            loop=l35,
            max_innermost_factor=64,
            decision=[3, 1, 1, 1, 3],
        )
        l63, l64, l65, l66, l67 = sch.split(loop=l35,
                                            factors=[v58, v59, v60, v61, v62])
        v68, v69, v70, v71, v72 = sch.sample_perfect_tile(
            n=5,
            loop=l36,
            max_innermost_factor=64,
            decision=[4, 2, 1, 4, 4],
        )
        l73, l74, l75, l76, l77 = sch.split(loop=l36,
                                            factors=[v68, v69, v70, v71, v72])
        v78, v79, v80 = sch.sample_perfect_tile(
            n=3,
            loop=l37,
            max_innermost_factor=64,
            decision=[32, 1, 4],
        )
        l81, l82, l83 = sch.split(loop=l37, factors=[v78, v79, v80])
        sch.reorder(
            # fmt: off
            l43,
            l53,
            l63,
            l73,
            l44,
            l54,
            l64,
            l74,
            l45,
            l55,
            l65,
            l75,
            l81,
            l82,
            l46,
            l56,
            l66,
            l76,
            l83,
            l47,
            l57,
            l67,
            l77,
            # fmt: on
        )
        l84 = sch.fuse(l43, l53, l63, l73)
        sch.bind(loop=l84, thread_axis="blockIdx.x")
        l85 = sch.fuse(l44, l54, l64, l74)
        sch.bind(loop=l85, thread_axis="vthread.x")
        l86 = sch.fuse(l45, l55, l65, l75)
        sch.bind(loop=l86, thread_axis="threadIdx.x")

        b87 = sch.cache_read(block=b32,
                             read_buffer_index=1,
                             storage_scope="shared")
        sch.compute_at(block=b87, loop=l81, preserve_unit_loops=True)
        _, _, _, _, l92, l93, l94, l95 = sch.get_loops(block=b87)
        sch.fuse(l92, l93, l94, l95)
        v97 = sch.sample_categorical(
            candidates=[1, 2, 3, 4],
            probs=[0.25, 0.25, 0.25, 0.25],
            decision=1,
        )
        sch.annotate(
            block_or_loop=b87,
            ann_key="meta_schedule.cooperative_fetch",
            ann_val=v97,
        )

        b101 = sch.cache_read(block=b32,
                              read_buffer_index=2,
                              storage_scope="shared")
        sch.compute_at(block=b101, loop=l81, preserve_unit_loops=True)
        _, _, _, _, l106, l107, l108, l109 = sch.get_loops(block=b101)
        sch.fuse(l106, l107, l108, l109)
        v110 = sch.sample_categorical(
            candidates=[1, 2, 3, 4],
            probs=[0.25, 0.25, 0.25, 0.25],
            decision=1,
        )
        sch.annotate(
            block_or_loop=b101,
            ann_key="meta_schedule.cooperative_fetch",
            ann_val=v110,
        )

        sch.reverse_compute_at(block=b31, loop=l86, preserve_unit_loops=True)

    def inverse(sch: Schedule):
        b1 = sch.get_block(name="inverse")
        l2, l3, l4, l5, l6, l7 = sch.get_loops(block=b1)
        sch.unroll(loop=l2)
        sch.unroll(loop=l3)
        v8, v9 = sch.sample_perfect_tile(
            n=2,
            loop=l4,
            max_innermost_factor=64,
            decision=[3, 3],
        )
        l10, l11 = sch.split(loop=l4, factors=[v8, v9])
        v12, v13 = sch.sample_perfect_tile(
            n=2,
            loop=l5,
            max_innermost_factor=64,
            decision=[2, 64],
        )
        l14, l15 = sch.split(loop=l5, factors=[v12, v13])
        sch.unroll(loop=l6)
        sch.unroll(loop=l7)
        sch.reorder(l10, l14, l11, l15, l2, l3, l6, l7)

    # pylint: enable=invalid-name

    sch = Schedule(mod=conv2d_winograd_cuda)
    inline(sch)
    data_pack(sch)
    input_tile_data_pad(sch)
    bgemm(sch)
    inverse(sch)

    return sch.mod
Example #29
 def inverse(sch: Schedule):
     b1 = sch.get_block(name="inverse")
     l2, l3, l4, l5, l6, l7 = sch.get_loops(block=b1)
     sch.unroll(loop=l2)
     sch.unroll(loop=l3)
     v8, v9 = sch.sample_perfect_tile(
         n=2,
         loop=l4,
         max_innermost_factor=64,
         decision=[3, 3],
     )
     l10, l11 = sch.split(loop=l4, factors=[v8, v9])
     v12, v13 = sch.sample_perfect_tile(
         n=2,
         loop=l5,
         max_innermost_factor=64,
         decision=[2, 64],
     )
     l14, l15 = sch.split(loop=l5, factors=[v12, v13])
     sch.unroll(loop=l6)
     sch.unroll(loop=l7)
     sch.reorder(l10, l14, l11, l15, l2, l3, l6, l7)
Example #30
def test_vectorize_inner_loop():
    sch = Schedule(before_matmul_vectorize)
    rule = RewriteParallelVectorizeUnroll()
    assert rule.apply(sch)
    tvm.ir.assert_structural_equal(sch.mod["main"], after_matmul_vectorize)