def input_tile_data_pad(sch: Schedule):
    """Place ``input_tile``, inline ``data_pad`` and thread-bind ``data_pack``.

    Steps, in order:

    1. ``input_tile`` is computed at the fourth loop of its single consumer
       and its buffer 0 is given the ``"local"`` storage scope.
    2. ``data_pad`` is inlined.
    3. The first four loops of ``data_pack`` are fused, the fused loop is
       split by a sampled factor, and the two resulting loops are bound to
       ``blockIdx.x`` and ``threadIdx.x``.

    Args:
        sch: The schedule to mutate in place.
    """
    input_tile = sch.get_block(name="input_tile")
    # Exactly one consumer is expected; the unpacking fails otherwise.
    [tile_consumer] = sch.get_consumers(block=input_tile)
    # The consumer is expected to have exactly eight loops.
    _, _, _, attach_loop, _, _, _, _ = sch.get_loops(block=tile_consumer)
    sch.compute_at(block=input_tile, loop=attach_loop, preserve_unit_loops=True)
    sch.set_scope(block=input_tile, buffer_index=0, storage_scope="local")

    data_pad = sch.get_block(name="data_pad")
    sch.compute_inline(block=data_pad)

    data_pack = sch.get_block(name="data_pack")
    lead_0, lead_1, lead_2, lead_3, _, _, _, _ = sch.get_loops(block=data_pack)
    fused = sch.fuse(lead_0, lead_1, lead_2, lead_3)

    extent_candidates = [32, 64, 128, 256, 512, 1024]
    # Every candidate carries the same probability.
    uniform_probs = [1 / len(extent_candidates)] * len(extent_candidates)
    inner_extent = sch.sample_categorical(
        candidates=extent_candidates,
        probs=uniform_probs,
        decision=2,
    )

    block_loop, thread_loop = sch.split(loop=fused, factors=[None, inner_extent])
    sch.bind(loop=block_loop, thread_axis="blockIdx.x")
    sch.bind(loop=thread_loop, thread_axis="threadIdx.x")
def input_tile_data_pad(sch: Schedule):
    """Move ``input_tile`` and ``data_pad`` to sampled compute locations.

    For each block a compute location is sampled (with recorded decisions
    ``4`` and ``-2`` respectively) and the block is then computed at that
    location with unit loops preserved.

    Args:
        sch: The schedule to mutate in place.
    """
    input_tile = sch.get_block(name="input_tile")
    tile_location = sch.sample_compute_location(block=input_tile, decision=4)
    sch.compute_at(block=input_tile, loop=tile_location, preserve_unit_loops=True)

    data_pad = sch.get_block(name="data_pad")
    pad_location = sch.sample_compute_location(block=data_pad, decision=-2)
    sch.compute_at(block=data_pad, loop=pad_location, preserve_unit_loops=True)
def input_tile_data_pad(sch: Schedule):
    """Attach ``input_tile`` to its consumer and inline ``data_pad``.

    ``input_tile`` is computed at the fourth loop of its single consumer,
    its buffer 0 is given the ``"local"`` storage scope, and ``data_pad``
    is then inlined.

    Args:
        sch: The schedule to mutate in place.
    """
    input_tile = sch.get_block(name="input_tile")
    # Exactly one consumer is expected; the unpacking fails otherwise.
    [tile_consumer] = sch.get_consumers(block=input_tile)
    # The consumer is expected to have exactly eight loops.
    _, _, _, attach_loop, _, _, _, _ = sch.get_loops(block=tile_consumer)
    sch.compute_at(block=input_tile, loop=attach_loop, preserve_unit_loops=True)
    sch.set_scope(block=input_tile, buffer_index=0, storage_scope="local")

    data_pad = sch.get_block(name="data_pad")
    sch.compute_inline(block=data_pad)
def _schedule_matmul(sch: Schedule):
    """Tile the ``matmul`` block with sampled perfect-tile factors.

    The first two loops are each split four ways and the third loop two
    ways, using factors drawn from ``sample_perfect_tile``. The resulting
    ten loops are then reordered.

    Args:
        sch: The schedule to mutate in place.
    """
    matmul = sch.get_block("matmul")
    i, j, k = sch.get_loops(block=matmul)

    # Each loop is sampled and split before the next one is sampled.
    i_factors = sch.sample_perfect_tile(i, n=4)
    i_0, i_1, i_2, i_3 = sch.split(i, i_factors)

    j_factors = sch.sample_perfect_tile(j, n=4)
    j_0, j_1, j_2, j_3 = sch.split(j, j_factors)

    k_factors = sch.sample_perfect_tile(k, n=2)
    k_0, k_1 = sch.split(k, k_factors)

    sch.reorder(
        i_0, j_0,
        i_1, j_1,
        k_0,
        i_2, j_2,
        k_1,
        i_3, j_3,
    )
def schedule_matmul(sch: Schedule):
    """Tile the ``matmul`` block with fixed split factors.

    The first two loops are each split four ways and the third loop two
    ways, using hard-coded factors. The resulting ten loops are then
    reordered.

    Args:
        sch: The schedule to mutate in place.
    """
    matmul = sch.get_block("matmul")
    i, j, k = sch.get_loops(block=matmul)

    # TODO(@zxybazh): Change to `sample_perfect_tile` after upstreaming
    i_factors = [2, 4, 64, 2]
    j_factors = [4, 64, 2, 2]
    k_factors = [32, 32]

    i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=i_factors)
    j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=j_factors)
    k_0, k_1 = sch.split(loop=k, factors=k_factors)

    sch.reorder(
        i_0, j_0,
        i_1, j_1,
        k_0,
        i_2, j_2,
        k_1,
        i_3, j_3,
    )
def root_anno(sch: Schedule):
    """Annotate the root block of ``main`` with a sampled unroll value.

    A value is sampled from ``[0, 16, 64, 512, 1024]`` with equal
    probabilities (recorded decision ``2``) and stored on the root block
    under the ``"meta_schedule.unroll_explicit"`` annotation key.

    Args:
        sch: The schedule to mutate in place.
    """
    root = sch.get_block(name="root", func_name="main")

    unroll_candidates = [0, 16, 64, 512, 1024]
    # Every candidate carries the same probability.
    uniform_probs = [1 / len(unroll_candidates)] * len(unroll_candidates)
    unroll_value = sch.sample_categorical(
        candidates=unroll_candidates,
        probs=uniform_probs,
        decision=2,
    )

    sch.annotate(
        block_or_loop=root,
        ann_key="meta_schedule.unroll_explicit",
        ann_val=unroll_value,
    )
def conv2d(sch: Schedule):
    """Fuse and thread-bind the loops of ``conv2d_winograd``.

    The block's four loops are fused into one, the fused loop is split by
    a sampled factor, and the outer/inner results are bound to
    ``blockIdx.x`` and ``threadIdx.x`` respectively.

    Args:
        sch: The schedule to mutate in place.
    """
    winograd = sch.get_block(name="conv2d_winograd")
    # The block is expected to have exactly four loops.
    loop_0, loop_1, loop_2, loop_3 = sch.get_loops(block=winograd)
    fused = sch.fuse(loop_0, loop_1, loop_2, loop_3)

    extent_candidates = [32, 64, 128, 256, 512, 1024]
    # Every candidate carries the same probability.
    uniform_probs = [1 / len(extent_candidates)] * len(extent_candidates)
    inner_extent = sch.sample_categorical(
        candidates=extent_candidates,
        probs=uniform_probs,
        decision=2,
    )

    block_loop, thread_loop = sch.split(loop=fused, factors=[None, inner_extent])
    sch.bind(loop=block_loop, thread_axis="blockIdx.x")
    sch.bind(loop=thread_loop, thread_axis="threadIdx.x")
def inverse(sch: Schedule):
    """Unroll, tile and reorder the six loops of the ``inverse`` block.

    Loops 1, 2, 5 and 6 are unrolled. Loops 3 and 4 are each split in two
    with sampled perfect-tile factors (recorded decisions ``[1, 9]`` and
    ``[2, 64]``). Finally the four tiled loops are moved in front of the
    four unrolled loops.

    Args:
        sch: The schedule to mutate in place.
    """
    inverse_block = sch.get_block(name="inverse")
    head_0, head_1, mid_0, mid_1, tail_0, tail_1 = sch.get_loops(block=inverse_block)

    sch.unroll(loop=head_0)
    sch.unroll(loop=head_1)

    mid_0_outer_f, mid_0_inner_f = sch.sample_perfect_tile(
        n=2,
        loop=mid_0,
        max_innermost_factor=64,
        decision=[1, 9],
    )
    mid_0_outer, mid_0_inner = sch.split(
        loop=mid_0, factors=[mid_0_outer_f, mid_0_inner_f]
    )

    mid_1_outer_f, mid_1_inner_f = sch.sample_perfect_tile(
        n=2,
        loop=mid_1,
        max_innermost_factor=64,
        decision=[2, 64],
    )
    mid_1_outer, mid_1_inner = sch.split(
        loop=mid_1, factors=[mid_1_outer_f, mid_1_inner_f]
    )

    sch.unroll(loop=tail_0)
    sch.unroll(loop=tail_1)

    sch.reorder(
        mid_0_outer, mid_1_outer,
        mid_0_inner, mid_1_inner,
        head_0, head_1,
        tail_0, tail_1,
    )
def data_pack(sch: Schedule):
    """Unroll, tile and reorder the six loops of the ``data_pack`` block.

    Loops 1, 2, 5 and 6 are unrolled. Loops 3 and 4 are each split in two
    with sampled perfect-tile factors (recorded decisions ``[3, 3]`` and
    ``[64, 2]``). Finally the four tiled loops are moved in front of the
    four unrolled loops.

    Args:
        sch: The schedule to mutate in place.
    """
    pack_block = sch.get_block(name="data_pack")
    eps, nu, p, ci, r_a, r_b = sch.get_loops(block=pack_block)

    sch.unroll(loop=eps)
    sch.unroll(loop=nu)

    p_outer_factor, p_inner_factor = sch.sample_perfect_tile(
        n=2,
        loop=p,
        max_innermost_factor=64,
        decision=[3, 3],
    )
    p_outer, p_inner = sch.split(loop=p, factors=[p_outer_factor, p_inner_factor])

    ci_outer_factor, ci_inner_factor = sch.sample_perfect_tile(
        n=2,
        loop=ci,
        max_innermost_factor=64,
        decision=[64, 2],
    )
    ci_outer, ci_inner = sch.split(loop=ci, factors=[ci_outer_factor, ci_inner_factor])

    sch.unroll(loop=r_a)
    sch.unroll(loop=r_b)

    sch.reorder(p_outer, ci_outer, p_inner, ci_inner, eps, nu, r_a, r_b)
def data_pack(sch: Schedule):
    """Unroll, tile and reorder the six loops of the ``data_pack`` block.

    Loops 1, 2, 5 and 6 are unrolled. Loops 3 and 4 are each split in two
    with sampled perfect-tile factors (recorded decisions ``[9, 1]`` and
    ``[32, 4]``). Finally the four tiled loops are moved in front of the
    four unrolled loops.

    Args:
        sch: The schedule to mutate in place.
    """
    pack_block = sch.get_block(name="data_pack")
    first, second, third, fourth, fifth, sixth = sch.get_loops(block=pack_block)

    sch.unroll(loop=first)
    sch.unroll(loop=second)

    third_tile = sch.sample_perfect_tile(
        n=2,
        loop=third,
        max_innermost_factor=64,
        decision=[9, 1],
    )
    # Unpack to insist on exactly two factors before splitting.
    third_outer_f, third_inner_f = third_tile
    third_outer, third_inner = sch.split(
        loop=third, factors=[third_outer_f, third_inner_f]
    )

    fourth_tile = sch.sample_perfect_tile(
        n=2,
        loop=fourth,
        max_innermost_factor=64,
        decision=[32, 4],
    )
    fourth_outer_f, fourth_inner_f = fourth_tile
    fourth_outer, fourth_inner = sch.split(
        loop=fourth, factors=[fourth_outer_f, fourth_inner_f]
    )

    sch.unroll(loop=fifth)
    sch.unroll(loop=sixth)

    sch.reorder(
        third_outer,
        fourth_outer,
        third_inner,
        fourth_inner,
        first,
        second,
        fifth,
        sixth,
    )
def inverse(sch: Schedule):
    """Unroll, tile and reorder the six loops of the ``inverse`` block.

    Loops 1, 2, 5 and 6 are unrolled. Loops 3 and 4 are each split in two
    with sampled perfect-tile factors (recorded decisions ``[3, 3]`` and
    ``[2, 64]``). Finally the four tiled loops are moved in front of the
    four unrolled loops.

    Args:
        sch: The schedule to mutate in place.
    """
    inverse_block = sch.get_block(name="inverse")
    unroll_a, unroll_b, tile_a, tile_b, unroll_c, unroll_d = sch.get_loops(
        block=inverse_block
    )

    sch.unroll(loop=unroll_a)
    sch.unroll(loop=unroll_b)

    a_outer_factor, a_inner_factor = sch.sample_perfect_tile(
        n=2,
        loop=tile_a,
        max_innermost_factor=64,
        decision=[3, 3],
    )
    a_outer, a_inner = sch.split(loop=tile_a, factors=[a_outer_factor, a_inner_factor])

    b_outer_factor, b_inner_factor = sch.sample_perfect_tile(
        n=2,
        loop=tile_b,
        max_innermost_factor=64,
        decision=[2, 64],
    )
    b_outer, b_inner = sch.split(loop=tile_b, factors=[b_outer_factor, b_inner_factor])

    sch.unroll(loop=unroll_c)
    sch.unroll(loop=unroll_d)

    sch.reorder(
        a_outer, b_outer, a_inner, b_inner,
        unroll_a, unroll_b, unroll_c, unroll_d,
    )
def inverse(sch: Schedule):
    """Unroll, tile, reorder and thread-bind the ``inverse`` block.

    Loops 1, 2, 5 and 6 are unrolled. Loops 3 and 4 are each split in two
    with sampled perfect-tile factors (recorded decisions ``[3, 3]`` and
    ``[2, 64]``), and the four tiled loops are moved in front of the four
    unrolled loops. The four tiled loops are then fused, the fused loop is
    split by a sampled factor, and the outer/inner results are bound to
    ``blockIdx.x`` and ``threadIdx.x`` respectively.

    Args:
        sch: The schedule to mutate in place.
    """
    inverse_block = sch.get_block(name="inverse")
    unroll_a, unroll_b, tile_a, tile_b, unroll_c, unroll_d = sch.get_loops(
        block=inverse_block
    )

    sch.unroll(loop=unroll_a)
    sch.unroll(loop=unroll_b)

    a_outer_factor, a_inner_factor = sch.sample_perfect_tile(
        n=2,
        loop=tile_a,
        max_innermost_factor=64,
        decision=[3, 3],
    )
    a_outer, a_inner = sch.split(loop=tile_a, factors=[a_outer_factor, a_inner_factor])

    b_outer_factor, b_inner_factor = sch.sample_perfect_tile(
        n=2,
        loop=tile_b,
        max_innermost_factor=64,
        decision=[2, 64],
    )
    b_outer, b_inner = sch.split(loop=tile_b, factors=[b_outer_factor, b_inner_factor])

    sch.unroll(loop=unroll_c)
    sch.unroll(loop=unroll_d)

    sch.reorder(
        a_outer, b_outer, a_inner, b_inner,
        unroll_a, unroll_b, unroll_c, unroll_d,
    )

    # Collapse the tiled loops into one and hand it to the thread axes.
    fused = sch.fuse(a_outer, b_outer, a_inner, b_inner)

    extent_candidates = [32, 64, 128, 256, 512, 1024]
    # Every candidate carries the same probability.
    uniform_probs = [1 / len(extent_candidates)] * len(extent_candidates)
    inner_extent = sch.sample_categorical(
        candidates=extent_candidates,
        probs=uniform_probs,
        decision=2,
    )

    block_loop, thread_loop = sch.split(loop=fused, factors=[None, inner_extent])
    sch.bind(loop=block_loop, thread_axis="blockIdx.x")
    sch.bind(loop=thread_loop, thread_axis="threadIdx.x")
def bgemm(sch: Schedule):
    """Tile, thread-bind and add cache stages around the ``bgemm`` block.

    Steps, in order:

    1. Annotate ``bgemm`` with the ``"SSSRRSRS"`` tiling structure and add
       a ``"local"``-scope cache-write stage for its write buffer 0.
    2. Split the first four loops five ways and the fifth loop three ways
       using sampled perfect-tile factors, then reorder all pieces.
    3. Fuse the level-0, level-1 and level-2 pieces of the first four
       loops and bind them to ``blockIdx.x``, ``vthread.x`` and
       ``threadIdx.x`` respectively.
    4. For read buffers 1 and 2, add a ``"shared"``-scope cache-read stage
       computed at the outermost piece of the fifth loop, fuse its last
       four loops and annotate it with a sampled
       ``"meta_schedule.cooperative_fetch"`` value.
    5. Reverse-compute the cache-write stage at the ``threadIdx.x`` loop.

    Args:
        sch: The schedule to mutate in place.
    """
    bgemm_block = sch.get_block(name="bgemm")
    sch.annotate(
        block_or_loop=bgemm_block,
        ann_key="meta_schedule.tiling_structure",
        ann_val="SSSRRSRS",
    )
    # Tiling below operates on the handle obtained from get_block; the handle
    # returned by cache_write is only used for the final reverse_compute_at.
    cache_write_block = sch.cache_write(
        block=bgemm_block, write_buffer_index=0, storage_scope="local"
    )

    ax0, ax1, ax2, ax3, ax4 = sch.get_loops(block=bgemm_block)

    f0_0, f0_1, f0_2, f0_3, f0_4 = sch.sample_perfect_tile(
        n=5,
        loop=ax0,
        max_innermost_factor=64,
        decision=[1, 1, 1, 1, 6],
    )
    ax0_0, ax0_1, ax0_2, ax0_3, ax0_4 = sch.split(
        loop=ax0, factors=[f0_0, f0_1, f0_2, f0_3, f0_4]
    )

    f1_0, f1_1, f1_2, f1_3, f1_4 = sch.sample_perfect_tile(
        n=5,
        loop=ax1,
        max_innermost_factor=64,
        decision=[1, 1, 1, 3, 2],
    )
    ax1_0, ax1_1, ax1_2, ax1_3, ax1_4 = sch.split(
        loop=ax1, factors=[f1_0, f1_1, f1_2, f1_3, f1_4]
    )

    f2_0, f2_1, f2_2, f2_3, f2_4 = sch.sample_perfect_tile(
        n=5,
        loop=ax2,
        max_innermost_factor=64,
        decision=[3, 1, 1, 1, 3],
    )
    ax2_0, ax2_1, ax2_2, ax2_3, ax2_4 = sch.split(
        loop=ax2, factors=[f2_0, f2_1, f2_2, f2_3, f2_4]
    )

    f3_0, f3_1, f3_2, f3_3, f3_4 = sch.sample_perfect_tile(
        n=5,
        loop=ax3,
        max_innermost_factor=64,
        decision=[4, 2, 1, 4, 4],
    )
    ax3_0, ax3_1, ax3_2, ax3_3, ax3_4 = sch.split(
        loop=ax3, factors=[f3_0, f3_1, f3_2, f3_3, f3_4]
    )

    f4_0, f4_1, f4_2 = sch.sample_perfect_tile(
        n=3,
        loop=ax4,
        max_innermost_factor=64,
        decision=[32, 1, 4],
    )
    ax4_0, ax4_1, ax4_2 = sch.split(loop=ax4, factors=[f4_0, f4_1, f4_2])

    sch.reorder(
        # fmt: off
        ax0_0, ax1_0, ax2_0, ax3_0,
        ax0_1, ax1_1, ax2_1, ax3_1,
        ax0_2, ax1_2, ax2_2, ax3_2,
        ax4_0, ax4_1,
        ax0_3, ax1_3, ax2_3, ax3_3,
        ax4_2,
        ax0_4, ax1_4, ax2_4, ax3_4,
        # fmt: on
    )

    block_idx = sch.fuse(ax0_0, ax1_0, ax2_0, ax3_0)
    sch.bind(loop=block_idx, thread_axis="blockIdx.x")
    vthread = sch.fuse(ax0_1, ax1_1, ax2_1, ax3_1)
    sch.bind(loop=vthread, thread_axis="vthread.x")
    thread_idx = sch.fuse(ax0_2, ax1_2, ax2_2, ax3_2)
    sch.bind(loop=thread_idx, thread_axis="threadIdx.x")

    fetch_candidates = [1, 2, 3, 4]
    fetch_probs = [0.25, 0.25, 0.25, 0.25]

    # Cache-read stage for read buffer 1.
    read_1 = sch.cache_read(block=bgemm_block, read_buffer_index=1, storage_scope="shared")
    sch.compute_at(block=read_1, loop=ax4_0, preserve_unit_loops=True)
    _, _, _, _, r1_a, r1_b, r1_c, r1_d = sch.get_loops(block=read_1)
    sch.fuse(r1_a, r1_b, r1_c, r1_d)
    fetch_1 = sch.sample_categorical(
        candidates=fetch_candidates,
        probs=fetch_probs,
        decision=1,
    )
    sch.annotate(
        block_or_loop=read_1,
        ann_key="meta_schedule.cooperative_fetch",
        ann_val=fetch_1,
    )

    # Cache-read stage for read buffer 2.
    read_2 = sch.cache_read(block=bgemm_block, read_buffer_index=2, storage_scope="shared")
    sch.compute_at(block=read_2, loop=ax4_0, preserve_unit_loops=True)
    _, _, _, _, r2_a, r2_b, r2_c, r2_d = sch.get_loops(block=read_2)
    sch.fuse(r2_a, r2_b, r2_c, r2_d)
    fetch_2 = sch.sample_categorical(
        candidates=fetch_candidates,
        probs=fetch_probs,
        decision=1,
    )
    sch.annotate(
        block_or_loop=read_2,
        ann_key="meta_schedule.cooperative_fetch",
        ann_val=fetch_2,
    )

    sch.reverse_compute_at(
        block=cache_write_block, loop=thread_idx, preserve_unit_loops=True
    )
def inline(sch: Schedule):
    """Inline blocks ``A`` and ``B``, one after the other.

    Each block is looked up and inlined before the next one is looked up.

    Args:
        sch: The schedule to mutate in place.
    """
    for block_name in ("A", "B"):
        target = sch.get_block(name=block_name)
        sch.compute_inline(block=target)
def bgemm(sch: Schedule):
    """Tile the ``bgemm`` block and attach its cache-write stage.

    Steps, in order:

    1. Add a ``"global"``-scope cache-write stage for write buffer 0 and
       annotate ``bgemm`` with the ``"SSRSRS"`` tiling structure.
    2. Split the first four loops four ways and the fifth loop two ways
       using sampled perfect-tile factors, then reorder all pieces.
    3. Reverse-compute the cache-write stage at the level-1 piece of the
       fourth loop.

    Args:
        sch: The schedule to mutate in place.
    """
    bgemm_block = sch.get_block(name="bgemm")
    cache_write_block = sch.cache_write(
        block=bgemm_block,
        write_buffer_index=0,
        storage_scope="global",
    )
    sch.annotate(
        block_or_loop=bgemm_block,
        ann_key="meta_schedule.tiling_structure",
        ann_val="SSRSRS",
    )

    # The two block handles are deliberately not swapped here: tiling is
    # applied to the handle obtained from get_block.
    ax0, ax1, ax2, ax3, ax4 = sch.get_loops(block=bgemm_block)

    f0_0, f0_1, f0_2, f0_3 = sch.sample_perfect_tile(
        n=4,
        loop=ax0,
        max_innermost_factor=64,
        decision=[1, 2, 3, 1],
    )
    ax0_0, ax0_1, ax0_2, ax0_3 = sch.split(loop=ax0, factors=[f0_0, f0_1, f0_2, f0_3])

    f1_0, f1_1, f1_2, f1_3 = sch.sample_perfect_tile(
        n=4,
        loop=ax1,
        max_innermost_factor=64,
        decision=[1, 1, 1, 6],
    )
    ax1_0, ax1_1, ax1_2, ax1_3 = sch.split(loop=ax1, factors=[f1_0, f1_1, f1_2, f1_3])

    f2_0, f2_1, f2_2, f2_3 = sch.sample_perfect_tile(
        n=4,
        loop=ax2,
        max_innermost_factor=64,
        decision=[1, 1, 1, 9],
    )
    ax2_0, ax2_1, ax2_2, ax2_3 = sch.split(loop=ax2, factors=[f2_0, f2_1, f2_2, f2_3])

    f3_0, f3_1, f3_2, f3_3 = sch.sample_perfect_tile(
        n=4,
        loop=ax3,
        max_innermost_factor=64,
        decision=[2, 1, 16, 4],
    )
    ax3_0, ax3_1, ax3_2, ax3_3 = sch.split(loop=ax3, factors=[f3_0, f3_1, f3_2, f3_3])

    f4_0, f4_1 = sch.sample_perfect_tile(
        n=2,
        loop=ax4,
        max_innermost_factor=64,
        decision=[16, 8],
    )
    ax4_0, ax4_1 = sch.split(loop=ax4, factors=[f4_0, f4_1])

    sch.reorder(
        # fmt: off
        ax0_0, ax1_0, ax2_0, ax3_0,
        ax0_1, ax1_1, ax2_1, ax3_1,
        ax4_0,
        ax0_2, ax1_2, ax2_2, ax3_2,
        ax4_1,
        ax0_3, ax1_3, ax2_3, ax3_3,
        # fmt: on
    )

    sch.reverse_compute_at(
        block=cache_write_block, loop=ax3_1, preserve_unit_loops=True
    )
def _schedule_matmul_small(sch: Schedule):
    """Split the second and third loops of ``matmul`` in two.

    Each of the two loops is split with factors drawn from
    ``sample_perfect_tile``; the first loop is left untouched and the
    resulting loops are not reordered.

    Args:
        sch: The schedule to mutate in place.
    """
    matmul = sch.get_block("matmul")
    _, j, k = sch.get_loops(block=matmul)

    # The split results are unused; unpacking only insists on two pieces.
    j_factors = sch.sample_perfect_tile(j, n=2)
    _j_outer, _j_inner = sch.split(j, j_factors)

    k_factors = sch.sample_perfect_tile(k, n=2)
    _k_outer, _k_inner = sch.split(k, k_factors)
def inline(sch: Schedule):
    """Inline blocks ``A`` and ``B``.

    Both blocks are looked up first; they are then inlined in the order
    ``A``, ``B``.

    Args:
        sch: The schedule to mutate in place.
    """
    targets = [sch.get_block(name=block_name) for block_name in ("A", "B")]
    for target in targets:
        sch.compute_inline(block=target)