Beispiel #1
0
def _sch(decisions: List[List[int]]) -> Schedule:
    sch = Schedule(matmul, debug_mask="all")
    # pylint: disable=invalid-name
    d0, d1, d2 = decisions
    b0 = sch.get_block(name="C", func_name="main")
    root = sch.get_block(name="root", func_name="main")
    sch.get_consumers(block=b0)
    b1 = sch.cache_write(block=b0,
                         write_buffer_index=0,
                         storage_scope="global")
    l2, l3, l4 = sch.get_loops(block=b0)
    v5, v6, v7, v8 = sch.sample_perfect_tile(
        loop=l2,
        n=4,
        max_innermost_factor=64,
        decision=d0,
    )
    l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])
    v13, v14, v15, v16 = sch.sample_perfect_tile(
        loop=l3,
        n=4,
        max_innermost_factor=64,
        decision=d1,
    )
    l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])
    v21, v22 = sch.sample_perfect_tile(
        loop=l4,
        n=2,
        max_innermost_factor=64,
        decision=d2,
    )
    l23, l24 = sch.split(loop=l4, factors=[v21, v22])
    sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)
    sch.reverse_compute_at(block=b1, loop=l18, preserve_unit_loops=True)
    v57 = sch.sample_categorical(
        candidates=[0, 16, 64, 512],
        probs=[0.25, 0.25, 0.25, 0.25],
        decision=0,
    )
    sch.annotate(block_or_loop=root,
                 ann_key="meta_schedule.unroll_explicit",
                 ann_val=v57)
    # pylint: enable=invalid-name
    return sch
def _sch() -> Schedule:
    sch = Schedule(element_wise, debug_mask="all")
    # pylint: disable=invalid-name
    b0 = sch.get_block(name="C", func_name="main")
    l1, l2 = sch.get_loops(block=b0)
    l3 = sch.fuse(l1, l2)
    v4 = sch.sample_categorical(
        candidates=[32, 64, 128, 256, 512, 1024],
        probs=[
            0.16666666666666666,
            0.16666666666666666,
            0.16666666666666666,
            0.16666666666666666,
            0.16666666666666666,
            0.16666666666666666,
        ],
        decision=3,
    )
    l5, l6 = sch.split(loop=l3, factors=[None, v4])
    sch.bind(loop=l5, thread_axis="blockIdx.x")
    sch.bind(loop=l6, thread_axis="threadIdx.x")
    # pylint: enable=invalid-name
    return sch