def _sch(decisions: List[List[int]], ann_val: int) -> Schedule:
    """Build a tiled schedule over ``matmul`` and annotate its root block.

    Parameters
    ----------
    decisions : List[List[int]]
        Exactly three entries. They are forwarded, in order, as the
        ``decision`` argument of the ``sample_perfect_tile`` call made for
        each of the three loops of block "C".
    ann_val : int
        Value stored under the ``meta_schedule.parallel`` annotation key on
        block "root".

    Returns
    -------
    Schedule
        The schedule object after every call below has been applied.
    """
    sch = Schedule(matmul, debug_mask="all")
    # pylint: disable=invalid-name
    decision_x, decision_y, decision_z = decisions

    def _sample_and_split(loop, n_parts, decision):
        # Sample a perfect tiling of `loop` into `n_parts` factors, then split
        # the same loop by those factors and hand back the resulting loops.
        factors = sch.sample_perfect_tile(
            loop=loop,
            n=n_parts,
            max_innermost_factor=64,
            decision=decision,
        )
        return sch.split(loop=loop, factors=list(factors))

    block_c = sch.get_block(name="C", func_name="main")
    root_block = sch.get_block(name="root", func_name="main")
    # The result of this query is discarded; the call itself is kept so the
    # sequence of schedule calls stays the same.
    sch.get_consumers(block=block_c)
    write_cache = sch.cache_write(
        block=block_c,
        write_buffer_index=0,
        storage_scope="global",
    )
    loop_x, loop_y, loop_z = sch.get_loops(block=block_c)
    x0, x1, x2, x3 = _sample_and_split(loop_x, 4, decision_x)
    y0, y1, y2, y3 = _sample_and_split(loop_y, 4, decision_y)
    z0, z1 = _sample_and_split(loop_z, 2, decision_z)
    # Interleave the tiles of the three original loops.
    sch.reorder(x0, y0, x1, y1, z0, x2, y2, z1, x3, y3)
    # Attach the cache-write block under the second tile of the second loop.
    sch.reverse_compute_at(block=write_cache, loop=y1, preserve_unit_loops=True)
    sch.annotate(
        block_or_loop=root_block,
        ann_key="meta_schedule.parallel",
        ann_val=ann_val,
    )
    # pylint: enable=invalid-name
    return sch
def _schedule_matmul(sch: Schedule):
    """Split the three loops of block "matmul" by fixed factors and reorder them.

    Works on ``sch`` purely through its methods and returns nothing.
    """
    # Fixed split factors for the first, second and third loop of the block.
    i_factors = [2, 4, 64, 2]
    j_factors = [4, 64, 2, 2]
    k_factors = [32, 32]

    matmul_block = sch.get_block("matmul")
    loop_i, loop_j, loop_k = sch.get_loops(block=matmul_block)
    i0, i1, i2, i3 = sch.split(loop=loop_i, factors=i_factors)
    j0, j1, j2, j3 = sch.split(loop=loop_j, factors=j_factors)
    k0, k1 = sch.split(loop=loop_k, factors=k_factors)

    # Interleave the tiles of the three original loops.
    new_order = (i0, j0, i1, j1, k0, i2, j2, k1, i3, j3)
    sch.reorder(*new_order)
def _schedule_matmul(sch: Schedule):
    """Apply a hard-coded tiling to block "matmul" and reorder the resulting loops.

    The first two loops of the block are each split four ways and the third
    loop two ways; the ten resulting loops are then reordered. Nothing is
    returned.
    """
    target = sch.get_block("matmul")
    i_loop, j_loop, k_loop = sch.get_loops(block=target)
    # TODO(@zxybazh): Change to `sample_perfect_tile` after upstreaming
    i_a, i_b, i_c, i_d = sch.split(loop=i_loop, factors=[2, 4, 64, 2])
    j_a, j_b, j_c, j_d = sch.split(loop=j_loop, factors=[4, 64, 2, 2])
    k_a, k_b = sch.split(loop=k_loop, factors=[32, 32])
    sch.reorder(
        i_a, j_a,
        i_b, j_b,
        k_a,
        i_c, j_c,
        k_b,
        i_d, j_d,
    )
def _schedule_matmul(sch: Schedule):
    """Tile block "matmul" with fixed factors, then reorder the tiled loops.

    Uses split factors ``[1, 1, 2, 512]``, ``[1, 512, 1, 2]`` and ``[256, 4]``
    for the block's three loops, in that order. Returns nothing.
    """
    matmul_block = sch.get_block("matmul")
    loop_i, loop_j, loop_k = sch.get_loops(block=matmul_block)
    i_outer, i_mid_a, i_mid_b, i_inner = sch.split(loop=loop_i, factors=[1, 1, 2, 512])
    j_outer, j_mid_a, j_mid_b, j_inner = sch.split(loop=loop_j, factors=[1, 512, 1, 2])
    k_outer, k_inner = sch.split(loop=loop_k, factors=[256, 4])
    # Interleave the pieces of the three original loops.
    sch.reorder(
        i_outer,
        j_outer,
        i_mid_a,
        j_mid_a,
        k_outer,
        i_mid_b,
        j_mid_b,
        k_inner,
        i_inner,
        j_inner,
    )
def _sch(decisions: List[List[int]]) -> Schedule:
    """Build a tiled schedule over ``matmul`` where only the first loop is sampled.

    Parameters
    ----------
    decisions : List[List[int]]
        Exactly one entry, used as the ``decision`` of the single
        ``sample_perfect_tile`` call. The other two loops of block "C" are
        split with fixed factors.

    Returns
    -------
    Schedule
        The schedule object after every call below has been applied.
    """
    sch = Schedule(matmul, debug_mask="all")
    # pylint: disable=invalid-name
    (tile_decision,) = decisions
    block_c = sch.get_block(name="C", func_name="main")
    # Result discarded; the call is kept so the sequence of schedule calls
    # stays the same.
    sch.get_consumers(block=block_c)
    write_cache = sch.cache_write(
        block=block_c,
        write_buffer_index=0,
        storage_scope="global",
    )
    loop_x, loop_y, loop_z = sch.get_loops(block=block_c)

    # First loop: factors come from the sampler, pinned by `tile_decision`.
    f0, f1, f2, f3 = sch.sample_perfect_tile(
        loop=loop_x,
        n=4,
        max_innermost_factor=64,
        decision=tile_decision,
    )
    x0, x1, x2, x3 = sch.split(loop=loop_x, factors=[f0, f1, f2, f3])

    # Remaining loops: fixed factors.
    y_factors = [8, 4, 8, 2]
    z_factors = [512, 1]
    y0, y1, y2, y3 = sch.split(loop=loop_y, factors=y_factors)
    z0, z1 = sch.split(loop=loop_z, factors=z_factors)

    sch.reorder(x0, y0, x1, y1, z0, x2, y2, z1, x3, y3)
    # Attach the cache-write block under the second tile of the second loop.
    sch.reverse_compute_at(block=write_cache, loop=y1, preserve_unit_loops=True)
    # pylint: enable=invalid-name
    return sch
def _sch() -> Schedule:
    """Build a schedule over ``element_wise`` with a fused, thread-bound loop.

    The two loops of block "C" are fused, the fused loop is split using a
    categorically sampled factor, and the two resulting loops are bound to
    ``blockIdx.x`` and ``threadIdx.x``.

    Returns
    -------
    Schedule
        The schedule object after every call below has been applied.
    """
    sch = Schedule(element_wise, debug_mask="all")
    # pylint: disable=invalid-name
    block_c = sch.get_block(name="C", func_name="main")
    first_loop, second_loop = sch.get_loops(block=block_c)
    fused_loop = sch.fuse(first_loop, second_loop)

    candidates = [32, 64, 128, 256, 512, 1024]
    # Uniform weights: 1 / 6 is the same float as the literal 0.16666666666666666.
    uniform_probs = [1 / len(candidates)] * len(candidates)
    # NOTE(review): decision=3 presumably selects candidates[3] -- confirm
    # against the sample_categorical documentation.
    sampled_factor = sch.sample_categorical(
        candidates=candidates,
        probs=uniform_probs,
        decision=3,
    )

    # `None` leaves the first split factor unspecified.
    block_loop, thread_loop = sch.split(loop=fused_loop, factors=[None, sampled_factor])
    sch.bind(loop=block_loop, thread_axis="blockIdx.x")
    sch.bind(loop=thread_loop, thread_axis="threadIdx.x")
    # pylint: enable=invalid-name
    return sch