def input_tile_data_pad(sch: Schedule):
    """Schedule the ``input_tile``, ``data_pad`` and ``data_pack`` blocks.

    * ``input_tile`` is computed at the fourth loop of its single consumer and
      its buffer 0 is given the ``local`` storage scope.
    * ``data_pad`` is inlined.
    * The first four loops of ``data_pack`` are fused, split by a sampled
      factor, and the two resulting loops are bound to ``blockIdx.x`` and
      ``threadIdx.x``.

    Args:
        sch: The schedule that is mutated in place.
    """
    input_tile = sch.get_block(name="input_tile")
    # The one-element unpack requires `input_tile` to have exactly one consumer.
    (consumer,) = sch.get_consumers(block=input_tile)
    consumer_loops = sch.get_loops(block=consumer)
    _, _, _, attach_loop, _, _, _, _ = consumer_loops
    sch.compute_at(block=input_tile, loop=attach_loop, preserve_unit_loops=True)
    sch.set_scope(block=input_tile, buffer_index=0, storage_scope="local")

    sch.compute_inline(block=sch.get_block(name="data_pad"))

    data_pack = sch.get_block(name="data_pack")
    head0, head1, head2, head3, _, _, _, _ = sch.get_loops(block=data_pack)
    fused = sch.fuse(head0, head1, head2, head3)
    inner_extents = [32, 64, 128, 256, 512, 1024]
    # Uniform probabilities; 1 / 6 is the same double as 0.16666666666666666.
    inner_extent = sch.sample_categorical(
        candidates=inner_extents,
        probs=[1 / len(inner_extents)] * len(inner_extents),
        decision=2,
    )
    block_loop, thread_loop = sch.split(loop=fused, factors=[None, inner_extent])
    sch.bind(loop=block_loop, thread_axis="blockIdx.x")
    sch.bind(loop=thread_loop, thread_axis="threadIdx.x")
示例#2
0
    def input_tile_data_pad(sch: Schedule):
        """Move ``input_tile`` and ``data_pad`` to sampled compute locations.

        For each block, a compute location is sampled with a pinned decision
        (4 for ``input_tile``, -2 for ``data_pad``) and the block is then
        passed to ``compute_at`` with that location.

        Args:
            sch: The schedule that is mutated in place.
        """
        pinned_decisions = (("input_tile", 4), ("data_pad", -2))
        for block_name, decision in pinned_decisions:
            block = sch.get_block(name=block_name)
            location = sch.sample_compute_location(block=block, decision=decision)
            sch.compute_at(block=block, loop=location, preserve_unit_loops=True)
示例#3
0
    def input_tile_data_pad(sch: Schedule):
        """Attach ``input_tile`` to its consumer and inline ``data_pad``.

        ``input_tile`` is computed at the fourth loop of its single consumer,
        and its buffer 0 is given the ``local`` storage scope. ``data_pad`` is
        then inlined.

        Args:
            sch: The schedule that is mutated in place.
        """
        input_tile = sch.get_block(name="input_tile")
        # The one-element unpack requires exactly one consumer block.
        (consumer,) = sch.get_consumers(block=input_tile)
        consumer_loops = sch.get_loops(block=consumer)
        _, _, _, attach_loop, _, _, _, _ = consumer_loops
        sch.compute_at(block=input_tile, loop=attach_loop, preserve_unit_loops=True)
        sch.set_scope(block=input_tile, buffer_index=0, storage_scope="local")

        sch.compute_inline(block=sch.get_block(name="data_pad"))
def _schedule_matmul(sch: Schedule):
    """Tile the three loops of the ``matmul`` block with sampled factors.

    The first two loops are each split four ways and the third loop two ways,
    using factors from ``sample_perfect_tile``. The resulting ten loops are
    then reordered.

    Args:
        sch: The schedule that is mutated in place.
    """
    matmul = sch.get_block("matmul")
    loop_i, loop_j, loop_k = sch.get_loops(block=matmul)

    i_factors = sch.sample_perfect_tile(loop_i, n=4)
    i0, i1, i2, i3 = sch.split(loop_i, i_factors)

    j_factors = sch.sample_perfect_tile(loop_j, n=4)
    j0, j1, j2, j3 = sch.split(loop_j, j_factors)

    k_factors = sch.sample_perfect_tile(loop_k, n=2)
    k0, k1 = sch.split(loop_k, k_factors)

    sch.reorder(i0, j0, i1, j1, k0, i2, j2, k1, i3, j3)
示例#5
0
def schedule_matmul(sch: Schedule):
    """Tile the three loops of the ``matmul`` block with fixed factors.

    The first two loops are each split four ways and the third loop two ways,
    using hard-coded factors. The resulting ten loops are then reordered.

    Args:
        sch: The schedule that is mutated in place.
    """
    matmul = sch.get_block("matmul")
    loop_i, loop_j, loop_k = sch.get_loops(block=matmul)
    # TODO(@zxybazh): Change to `sample_perfect_tile` after upstreaming
    i_factors = [2, 4, 64, 2]
    j_factors = [4, 64, 2, 2]
    k_factors = [32, 32]
    i0, i1, i2, i3 = sch.split(loop=loop_i, factors=i_factors)
    j0, j1, j2, j3 = sch.split(loop=loop_j, factors=j_factors)
    k0, k1 = sch.split(loop=loop_k, factors=k_factors)
    sch.reorder(i0, j0, i1, j1, k0, i2, j2, k1, i3, j3)
 def root_anno(sch: Schedule):
     """Annotate the ``root`` block of ``main`` with a sampled unroll setting.

     A value is sampled from ``[0, 16, 64, 512, 1024]`` with uniform
     probabilities and a pinned decision of 2, then attached to the block
     under the key ``meta_schedule.unroll_explicit``.

     Args:
         sch: The schedule that is mutated in place.
     """
     root = sch.get_block(name="root", func_name="main")
     unroll_options = [0, 16, 64, 512, 1024]
     # Uniform probabilities; 1 / 5 is the same double as 0.20000000000000001.
     uniform = [1 / len(unroll_options)] * len(unroll_options)
     unroll_choice = sch.sample_categorical(
         candidates=unroll_options,
         probs=uniform,
         decision=2,
     )
     sch.annotate(
         block_or_loop=root,
         ann_key="meta_schedule.unroll_explicit",
         ann_val=unroll_choice,
     )
 def conv2d(sch: Schedule):
     """Fuse the ``conv2d_winograd`` loops and bind them to GPU thread axes.

     The block's four loops are fused into one. That loop is split by a
     sampled inner factor, and the two resulting loops are bound to
     ``blockIdx.x`` and ``threadIdx.x``.

     Args:
         sch: The schedule that is mutated in place.
     """
     winograd = sch.get_block(name="conv2d_winograd")
     loop0, loop1, loop2, loop3 = sch.get_loops(block=winograd)
     fused = sch.fuse(loop0, loop1, loop2, loop3)
     inner_extents = [32, 64, 128, 256, 512, 1024]
     # Uniform probabilities; 1 / 6 is the same double as 0.16666666666666666.
     inner_extent = sch.sample_categorical(
         candidates=inner_extents,
         probs=[1 / len(inner_extents)] * len(inner_extents),
         decision=2,
     )
     block_loop, thread_loop = sch.split(loop=fused, factors=[None, inner_extent])
     sch.bind(loop=block_loop, thread_axis="blockIdx.x")
     sch.bind(loop=thread_loop, thread_axis="threadIdx.x")
示例#8
0
 def inverse(sch: Schedule):
     """Unroll, tile and reorder the six loops of the ``inverse`` block.

     Loops 1, 2, 5 and 6 are unrolled. Loops 3 and 4 are each split in two
     with pinned ``sample_perfect_tile`` decisions (``[1, 9]`` and
     ``[2, 64]``). The split parts are then reordered ahead of the unrolled
     loops.

     Args:
         sch: The schedule that is mutated in place.
     """

     def tile_in_two(loop, decision):
         # Sample a two-way perfect tiling for `loop`, then split by it.
         outer_factor, inner_factor = sch.sample_perfect_tile(
             loop=loop,
             n=2,
             max_innermost_factor=64,
             decision=decision,
         )
         return sch.split(loop=loop, factors=[outer_factor, inner_factor])

     block = sch.get_block(name="inverse")
     lead0, lead1, tiled0, tiled1, tail0, tail1 = sch.get_loops(block=block)
     sch.unroll(loop=lead0)
     sch.unroll(loop=lead1)
     outer0, inner0 = tile_in_two(tiled0, [1, 9])
     outer1, inner1 = tile_in_two(tiled1, [2, 64])
     sch.unroll(loop=tail0)
     sch.unroll(loop=tail1)
     sch.reorder(outer0, outer1, inner0, inner1, lead0, lead1, tail0, tail1)
示例#9
0
 def data_pack(sch: Schedule):
     """Unroll, tile and reorder the six loops of the ``data_pack`` block.

     Loops 1, 2, 5 and 6 are unrolled. Loops 3 and 4 are each split in two
     with pinned ``sample_perfect_tile`` decisions (``[3, 3]`` and
     ``[64, 2]``). The split parts are then reordered ahead of the unrolled
     loops.

     Args:
         sch: The schedule that is mutated in place.
     """
     block = sch.get_block(name="data_pack")
     lead0, lead1, tiled0, tiled1, tail0, tail1 = sch.get_loops(block=block)
     for loop in (lead0, lead1):
         sch.unroll(loop=loop)

     # Each entry is (loop to split, pinned tiling decision).
     split_parts = []
     for loop, decision in ((tiled0, [3, 3]), (tiled1, [64, 2])):
         outer_factor, inner_factor = sch.sample_perfect_tile(
             loop=loop,
             n=2,
             max_innermost_factor=64,
             decision=decision,
         )
         outer, inner = sch.split(loop=loop, factors=[outer_factor, inner_factor])
         split_parts.append((outer, inner))
     (outer0, inner0), (outer1, inner1) = split_parts

     for loop in (tail0, tail1):
         sch.unroll(loop=loop)
     sch.reorder(outer0, outer1, inner0, inner1, lead0, lead1, tail0, tail1)
示例#10
0
 def data_pack(sch: Schedule):
     """Unroll, tile and reorder the six loops of the ``data_pack`` block.

     Loops 1, 2, 5 and 6 are unrolled. Loops 3 and 4 are each split in two
     with pinned ``sample_perfect_tile`` decisions (``[9, 1]`` and
     ``[32, 4]``). The split parts are then reordered ahead of the unrolled
     loops.

     Args:
         sch: The schedule that is mutated in place.
     """

     def tile_in_two(loop, decision):
         # Sample a two-way perfect tiling for `loop`, then split by it.
         outer_factor, inner_factor = sch.sample_perfect_tile(
             loop=loop,
             n=2,
             max_innermost_factor=64,
             decision=decision,
         )
         return sch.split(loop=loop, factors=[outer_factor, inner_factor])

     block = sch.get_block(name="data_pack")
     lead0, lead1, tiled0, tiled1, tail0, tail1 = sch.get_loops(block=block)
     sch.unroll(loop=lead0)
     sch.unroll(loop=lead1)
     outer0, inner0 = tile_in_two(tiled0, [9, 1])
     outer1, inner1 = tile_in_two(tiled1, [32, 4])
     sch.unroll(loop=tail0)
     sch.unroll(loop=tail1)
     sch.reorder(outer0, outer1, inner0, inner1, lead0, lead1, tail0, tail1)
示例#11
0
 def inverse(sch: Schedule):
     """Unroll, tile and reorder the six loops of the ``inverse`` block.

     Loops 1, 2, 5 and 6 are unrolled. Loops 3 and 4 are each split in two
     with pinned ``sample_perfect_tile`` decisions (``[3, 3]`` and
     ``[2, 64]``). The split parts are then reordered ahead of the unrolled
     loops.

     Args:
         sch: The schedule that is mutated in place.
     """
     block = sch.get_block(name="inverse")
     lead0, lead1, tiled0, tiled1, tail0, tail1 = sch.get_loops(block=block)
     for loop in (lead0, lead1):
         sch.unroll(loop=loop)

     # Each entry is (loop to split, pinned tiling decision).
     split_parts = []
     for loop, decision in ((tiled0, [3, 3]), (tiled1, [2, 64])):
         outer_factor, inner_factor = sch.sample_perfect_tile(
             loop=loop,
             n=2,
             max_innermost_factor=64,
             decision=decision,
         )
         outer, inner = sch.split(loop=loop, factors=[outer_factor, inner_factor])
         split_parts.append((outer, inner))
     (outer0, inner0), (outer1, inner1) = split_parts

     for loop in (tail0, tail1):
         sch.unroll(loop=loop)
     sch.reorder(outer0, outer1, inner0, inner1, lead0, lead1, tail0, tail1)
 def inverse(sch: Schedule):
     """Tile the ``inverse`` block and bind the tiled loops to thread axes.

     Loops 1, 2, 5 and 6 are unrolled. Loops 3 and 4 are each split in two
     with pinned ``sample_perfect_tile`` decisions (``[3, 3]`` and
     ``[2, 64]``), and the parts are reordered ahead of the unrolled loops.
     The four split parts are then fused and split again by a sampled inner
     factor, and the two resulting loops are bound to ``blockIdx.x`` and
     ``threadIdx.x``.

     Args:
         sch: The schedule that is mutated in place.
     """

     def tile_in_two(loop, decision):
         # Sample a two-way perfect tiling for `loop`, then split by it.
         outer_factor, inner_factor = sch.sample_perfect_tile(
             loop=loop,
             n=2,
             max_innermost_factor=64,
             decision=decision,
         )
         return sch.split(loop=loop, factors=[outer_factor, inner_factor])

     block = sch.get_block(name="inverse")
     lead0, lead1, tiled0, tiled1, tail0, tail1 = sch.get_loops(block=block)
     sch.unroll(loop=lead0)
     sch.unroll(loop=lead1)
     outer0, inner0 = tile_in_two(tiled0, [3, 3])
     outer1, inner1 = tile_in_two(tiled1, [2, 64])
     sch.unroll(loop=tail0)
     sch.unroll(loop=tail1)
     sch.reorder(outer0, outer1, inner0, inner1, lead0, lead1, tail0, tail1)

     fused = sch.fuse(outer0, outer1, inner0, inner1)
     inner_extents = [32, 64, 128, 256, 512, 1024]
     # Uniform probabilities; 1 / 6 is the same double as 0.16666666666666666.
     inner_extent = sch.sample_categorical(
         candidates=inner_extents,
         probs=[1 / len(inner_extents)] * len(inner_extents),
         decision=2,
     )
     block_loop, thread_loop = sch.split(loop=fused, factors=[None, inner_extent])
     sch.bind(loop=block_loop, thread_axis="blockIdx.x")
     sch.bind(loop=thread_loop, thread_axis="threadIdx.x")
示例#13
0
    def bgemm(sch: Schedule):
        """Tile and bind the ``bgemm`` block and stage its reads and write.

        Steps, in the order they are issued on ``sch``:

        1. Annotate ``bgemm`` with tiling structure ``"SSSRRSRS"`` and create
           a ``local``-scope cache-write stage for its write buffer 0.
        2. Split each of the first four loops five ways and the fifth loop
           three ways, using pinned ``sample_perfect_tile`` decisions.
        3. Reorder all parts, then fuse parts 0, 1 and 2 of the four five-way
           splits and bind the fused loops to ``blockIdx.x``, ``vthread.x``
           and ``threadIdx.x`` respectively.
        4. For read buffers 1 and 2, create a ``shared``-scope cache-read
           stage, compute it at the first part of the fifth loop, fuse its
           last four loops and annotate it with a sampled
           ``meta_schedule.cooperative_fetch`` value.
        5. Place the cache-write stage under the ``threadIdx.x`` loop with
           ``reverse_compute_at``.

        Args:
            sch: The schedule that is mutated in place.
        """
        compute = sch.get_block(name="bgemm")
        sch.annotate(
            block_or_loop=compute,
            ann_key="meta_schedule.tiling_structure",
            ann_val="SSSRRSRS",
        )
        write_stage = sch.cache_write(
            block=compute,
            write_buffer_index=0,
            storage_scope="local",
        )

        axis0, axis1, axis2, axis3, axis4 = sch.get_loops(block=compute)

        # Five-way tiling of the first four loops: (loop, pinned decision).
        five_way_plan = (
            (axis0, [1, 1, 1, 1, 6]),
            (axis1, [1, 1, 1, 3, 2]),
            (axis2, [3, 1, 1, 1, 3]),
            (axis3, [4, 2, 1, 4, 4]),
        )
        parts_per_axis = []
        for loop, decision in five_way_plan:
            f0, f1, f2, f3, f4 = sch.sample_perfect_tile(
                loop=loop,
                n=5,
                max_innermost_factor=64,
                decision=decision,
            )
            p0, p1, p2, p3, p4 = sch.split(loop=loop, factors=[f0, f1, f2, f3, f4])
            parts_per_axis.append((p0, p1, p2, p3, p4))
        # Regroup by split position: level k holds part k of every tiled loop.
        level0, level1, level2, level3, level4 = (
            list(level) for level in zip(*parts_per_axis)
        )

        g0, g1, g2 = sch.sample_perfect_tile(
            loop=axis4,
            n=3,
            max_innermost_factor=64,
            decision=[32, 1, 4],
        )
        tail0, tail1, tail2 = sch.split(loop=axis4, factors=[g0, g1, g2])

        sch.reorder(
            *level0,
            *level1,
            *level2,
            tail0,
            tail1,
            *level3,
            tail2,
            *level4,
        )

        block_loop = sch.fuse(*level0)
        sch.bind(loop=block_loop, thread_axis="blockIdx.x")
        vthread_loop = sch.fuse(*level1)
        sch.bind(loop=vthread_loop, thread_axis="vthread.x")
        thread_loop = sch.fuse(*level2)
        sch.bind(loop=thread_loop, thread_axis="threadIdx.x")

        # Read buffers 1 and 2 are staged the same way, one after the other.
        for read_index in (1, 2):
            read_stage = sch.cache_read(
                block=compute,
                read_buffer_index=read_index,
                storage_scope="shared",
            )
            sch.compute_at(block=read_stage, loop=tail0, preserve_unit_loops=True)
            _, _, _, _, copy0, copy1, copy2, copy3 = sch.get_loops(block=read_stage)
            sch.fuse(copy0, copy1, copy2, copy3)
            fetch = sch.sample_categorical(
                candidates=[1, 2, 3, 4],
                probs=[0.25] * 4,
                decision=1,
            )
            sch.annotate(
                block_or_loop=read_stage,
                ann_key="meta_schedule.cooperative_fetch",
                ann_val=fetch,
            )

        sch.reverse_compute_at(
            block=write_stage,
            loop=thread_loop,
            preserve_unit_loops=True,
        )
示例#14
0
 def inline(sch: Schedule):
     """Inline block ``A`` and then block ``B``.

     Each block is looked up and inlined before the next one is looked up.

     Args:
         sch: The schedule that is mutated in place.
     """
     for block_name in ("A", "B"):
         target = sch.get_block(name=block_name)
         sch.compute_inline(block=target)
示例#15
0
 def bgemm(sch: Schedule):
     """Tile the ``bgemm`` block and attach its cache-write stage.

     Steps, in the order they are issued on ``sch``:

     1. Create a ``global``-scope cache-write stage for write buffer 0 of
        ``bgemm``, then annotate ``bgemm`` with tiling structure ``"SSRSRS"``.
     2. Split each of the first four loops four ways and the fifth loop two
        ways, using pinned ``sample_perfect_tile`` decisions.
     3. Reorder all parts.
     4. Place the cache-write stage under part 1 of the fourth loop with
        ``reverse_compute_at``.

     Args:
         sch: The schedule that is mutated in place.
     """
     compute = sch.get_block(name="bgemm")
     write_stage = sch.cache_write(
         block=compute,
         write_buffer_index=0,
         storage_scope="global",
     )
     sch.annotate(
         block_or_loop=compute,
         ann_key="meta_schedule.tiling_structure",
         ann_val="SSRSRS",
     )

     axis0, axis1, axis2, axis3, axis4 = sch.get_loops(block=compute)

     # Four-way tiling of the first four loops: (loop, pinned decision).
     four_way_plan = (
         (axis0, [1, 2, 3, 1]),
         (axis1, [1, 1, 1, 6]),
         (axis2, [1, 1, 1, 9]),
         (axis3, [2, 1, 16, 4]),
     )
     parts_per_axis = []
     for loop, decision in four_way_plan:
         f0, f1, f2, f3 = sch.sample_perfect_tile(
             loop=loop,
             n=4,
             max_innermost_factor=64,
             decision=decision,
         )
         p0, p1, p2, p3 = sch.split(loop=loop, factors=[f0, f1, f2, f3])
         parts_per_axis.append((p0, p1, p2, p3))
     # Regroup by split position: level k holds part k of every tiled loop.
     level0, level1, level2, level3 = (
         list(level) for level in zip(*parts_per_axis)
     )

     g0, g1 = sch.sample_perfect_tile(
         loop=axis4,
         n=2,
         max_innermost_factor=64,
         decision=[16, 8],
     )
     tail0, tail1 = sch.split(loop=axis4, factors=[g0, g1])

     sch.reorder(
         *level0,
         *level1,
         tail0,
         *level2,
         tail1,
         *level3,
     )
     # level1[-1] is part 1 of the fourth tiled loop.
     sch.reverse_compute_at(
         block=write_stage,
         loop=level1[-1],
         preserve_unit_loops=True,
     )
 def _schedule_matmul_small(sch: Schedule):
     """Split the second and third loops of ``matmul`` with sampled factors.

     Each of the two loops is split in two using factors from
     ``sample_perfect_tile``. The resulting loops are not used further.

     Args:
         sch: The schedule that is mutated in place.
     """
     matmul = sch.get_block("matmul")
     _, loop_j, loop_k = sch.get_loops(block=matmul)
     for loop in (loop_j, loop_k):
         factors = sch.sample_perfect_tile(loop, n=2)
         # The two-name unpack keeps the check that the split yields two loops.
         _outer, _inner = sch.split(loop, factors)
示例#17
0
 def inline(sch: Schedule):
     """Look up blocks ``A`` and ``B``, then inline them in that order.

     Both blocks are looked up before either one is inlined.

     Args:
         sch: The schedule that is mutated in place.
     """
     targets = [sch.get_block(name=block_name) for block_name in ("A", "B")]
     for target in targets:
         sch.compute_inline(block=target)