Code example #1
def _schedule_matmul(sch: Schedule):
    block = sch.get_block("matmul")
    i, j, k = sch.get_loops(block=block)
    i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2])
    j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[4, 64, 2, 2])
    k_0, k_1 = sch.split(loop=k, factors=[32, 32])
    sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3)
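
A hand-written schedule function like this is typically wrapped in a meta_schedule ScheduleFn space generator. The sketch below is a hedged illustration of that usage, not code from the original example: `matmul` stands for the workload IRModule assumed by the surrounding tests, and the exact wiring may differ between TVM versions.

from tvm import meta_schedule as ms

# Minimal usage sketch (assumed context): wrap the manual schedule function in a
# ScheduleFn space generator and materialize the single design space it defines.
space_gen = ms.space_generator.ScheduleFn(sch_fn=_schedule_matmul)
(design_space,) = space_gen.generate_design_space(matmul)  # `matmul`: assumed workload IRModule
print(design_space.mod.script())  # inspect the tiled module
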
Code example #2
def _sch(decisions: List[List[int]], ann_val: int) -> Schedule:
    sch = Schedule(matmul, debug_mask="all")
    # pylint: disable=invalid-name
    d0, d1, d2 = decisions
    b0 = sch.get_block(name="C", func_name="main")
    root = sch.get_block(name="root", func_name="main")
    sch.get_consumers(block=b0)
    b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")
    l2, l3, l4 = sch.get_loops(block=b0)
    v5, v6, v7, v8 = sch.sample_perfect_tile(
        loop=l2,
        n=4,
        max_innermost_factor=64,
        decision=d0,
    )
    l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])
    v13, v14, v15, v16 = sch.sample_perfect_tile(
        loop=l3,
        n=4,
        max_innermost_factor=64,
        decision=d1,
    )
    l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])
    v21, v22 = sch.sample_perfect_tile(
        loop=l4,
        n=2,
        max_innermost_factor=64,
        decision=d2,
    )
    l23, l24 = sch.split(loop=l4, factors=[v21, v22])
    sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)
    sch.reverse_compute_at(block=b1, loop=l18, preserve_unit_loops=True)
    sch.annotate(block_or_loop=root, ann_key="meta_schedule.parallel", ann_val=ann_val)
    # pylint: enable=invalid-name
    return sch
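
A minimal usage sketch for this helper, assuming the `matmul` workload is the 512x512x512 case that the fixed split factors in Code example #10 suggest; the decision lists and the annotation value below are illustrative choices, not values taken from the original test.

# Hypothetical decisions: each list must be a perfect-tile factorization of the
# corresponding loop extent (assumed here to be 512), with the innermost factor
# no larger than 64; ann_val=64 is an illustrative "meta_schedule.parallel" value.
sch = _sch(
    decisions=[
        [8, 4, 8, 2],  # d0: factors for loop l2
        [8, 4, 8, 2],  # d1: factors for loop l3
        [512, 1],      # d2: factors for loop l4
    ],
    ann_val=64,
)
print(sch.mod.script())  # inspect the tiled, annotated module
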
Code example #3
def _schedule_matmul(sch: Schedule):
    block = sch.get_block("matmul")
    i, j, k = sch.get_loops(block=block)
    # TODO(@zxybazh): Change to `sample_perfect_tile` after upstreaming
    i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2])
    j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[4, 64, 2, 2])
    k_0, k_1 = sch.split(loop=k, factors=[32, 32])
    sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3)
Code example #4
def test_tile_with_tensor_intrin_conv2d_nchwc_vnni():
    s = Schedule(Conv2dNCHWcVNNIModule)
    block = s.get_block("conv2d_NCHWc_int8")
    tiled_loop = tile_with_tensor_intrin(s, block, VNNI_DOT_16x4_INTRIN)
    tiled_loops = s.get_loops(block)
    assert len(tiled_loops) == 12
    assert s.get(tiled_loop) == s.get(tiled_loops[-2])
    tvm.ir.assert_structural_equal(s.mod, Conv2dNCHWcVNNIModuleTiled)
Code example #5
def _schedule_matmul(sch: Schedule):
    block = sch.get_block("matmul")
    i, j, k = sch.get_loops(block=block)
    i_tiles = [1, 1, 2, 512]
    j_tiles = [1, 512, 1, 2]
    k_tiles = [256, 4]
    i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=i_tiles)
    j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=j_tiles)
    k_0, k_1 = sch.split(loop=k, factors=k_tiles)
    sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3)
Code example #6
def test_tile_with_tensor_intrin_dense_vnni():
    s = Schedule(DenseVNNIModule)
    block = s.get_block("compute")

    tiled_loop = tile_with_tensor_intrin(s, block, VNNI_DOT_16x4_INTRIN)

    _, _, _, i1_1, _ = s.get_loops(block)

    assert s.get(tiled_loop) == s.get(i1_1)
    tvm.ir.assert_structural_equal(s.mod, DenseVNNIModuleTiled)
Code example #7
def test_get_tensorize_loop_mapping_dense_vnni():
    s = Schedule(DenseVNNIModule)
    block = s.get_block("compute")

    info = get_tensorize_loop_mapping(s, block, dot_product_16x4_u8i8i32_desc)

    assert isinstance(info, TensorizeInfo)

    desc_loop_to_sref = dict((v, k) for k, v in info.loop_map.items())

    desc_loops = collect_loops(dot_product_16x4_u8i8i32_desc)
    _, loop_j, loop_k = s.get_loops(block)

    assert desc_loops[0] in desc_loop_to_sref and desc_loops[1] in desc_loop_to_sref
    assert s.get(desc_loop_to_sref[desc_loops[0]]) == s.get(loop_j)
    assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(loop_k)
Code example #8
def test_get_tensorize_loop_mapping_conv2d_nchwc_vnni():
    s = Schedule(Conv2dNCHWcVNNIModule)
    block = s.get_block("conv2d_NCHWc_int8")

    info = get_tensorize_loop_mapping(s, block, dot_product_16x4_u8i8i32_desc)

    desc_loop_to_sref = dict((v, k) for k, v in info.loop_map.items())

    desc_loops = collect_loops(dot_product_16x4_u8i8i32_desc)

    # i4 corresponds to the inner output channel axis of the NCHWc output tensor
    # for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4):
    _, _, _, _, i4, _, _, _, _, i9 = s.get_loops(block)

    assert desc_loops[0] in desc_loop_to_sref and desc_loops[1] in desc_loop_to_sref
    assert s.get(desc_loop_to_sref[desc_loops[0]]) == s.get(i4)
    assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(i9)
Code example #9
def test_get_tensorize_loop_mapping_matmul_mma():
    @T.prim_func
    def matmul_16x16x16xf16f16f16_desc(
        A: T.Buffer((16, 16), "float16", align=128, offset_factor=1),
        B: T.Buffer((16, 16), "float16", align=128, offset_factor=1),
        C: T.Buffer((16, 16), "float16", align=128, offset_factor=1),
    ) -> None:
        with T.block("root"):
            T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16])
            T.writes(C[0:16, 0:16])
            for i, j, k in T.grid(16, 16, 16):
                with T.block("update"):
                    vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
                    C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk]

    matmul = create_prim_func(te_workload.matmul_relu(
        n=512,
        m=512,
        k=512,
    ))

    s = Schedule(matmul)
    block = s.get_block("C")
    i0, i1, i2 = s.get_loops(block)
    desc_loops = collect_loops(matmul_16x16x16xf16f16f16_desc)

    for do_reorder in [False, True]:
        # Mapping should be invariant to the loop permutation
        if do_reorder:
            s.reorder(i2, i0, i1)

        info = get_tensorize_loop_mapping(s, block,
                                          matmul_16x16x16xf16f16f16_desc)
        assert info is not None
        desc_loop_to_sref = dict((v, k) for k, v in info.loop_map.items())

        for i in range(3):
            assert desc_loops[i] in desc_loop_to_sref

        assert s.get(desc_loop_to_sref[desc_loops[0]]) == s.get(i0)
        assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(i1)
        assert s.get(desc_loop_to_sref[desc_loops[2]]) == s.get(i2)
Code example #10
def _sch(decisions: List[List[int]]) -> Schedule:
    sch = Schedule(matmul, debug_mask="all")
    # pylint: disable=invalid-name
    (d0,) = decisions
    b0 = sch.get_block(name="C", func_name="main")
    sch.get_consumers(block=b0)
    b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")
    l2, l3, l4 = sch.get_loops(block=b0)
    v5, v6, v7, v8 = sch.sample_perfect_tile(
        loop=l2,
        n=4,
        max_innermost_factor=64,
        decision=d0,
    )
    l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])
    l17, l18, l19, l20 = sch.split(loop=l3, factors=[8, 4, 8, 2])
    l23, l24 = sch.split(loop=l4, factors=[512, 1])
    sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)
    sch.reverse_compute_at(block=b1, loop=l18, preserve_unit_loops=True)
    # pylint: enable=invalid-name
    return sch
Code example #11
def _sch() -> Schedule:
    sch = Schedule(element_wise, debug_mask="all")
    # pylint: disable=invalid-name
    b0 = sch.get_block(name="C", func_name="main")
    l1, l2 = sch.get_loops(block=b0)
    l3 = sch.fuse(l1, l2)
    v4 = sch.sample_categorical(
        candidates=[32, 64, 128, 256, 512, 1024],
        probs=[
            0.16666666666666666,
            0.16666666666666666,
            0.16666666666666666,
            0.16666666666666666,
            0.16666666666666666,
            0.16666666666666666,
        ],
        decision=3,
    )
    l5, l6 = sch.split(loop=l3, factors=[None, v4])
    sch.bind(loop=l5, thread_axis="blockIdx.x")
    sch.bind(loop=l6, thread_axis="threadIdx.x")
    # pylint: enable=invalid-name
    return sch
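
A minimal usage sketch, assuming `element_wise` is the workload PrimFunc used by the surrounding tests: with decision=3, sample_categorical picks 256, so the fused loop ends up split into a blockIdx.x outer loop and a 256-thread threadIdx.x inner loop.

# Minimal usage sketch (assumed context): build the schedule and print the
# resulting module to see the fused, split, and GPU-bound loops.
sch = _sch()
print(sch.mod.script())
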