def test_reduction_decompose1():
    """Check ``decompose_reduction`` on block "blockized_B".

    Builds a schedule from ``rowsum_blockized``, calls
    ``decompose_reduction`` on the block at the first of its two loops,
    compares the result with ``matmul_decompose1`` and finally runs the
    trace round-trip check against the original module.
    """
    sch = tir.Schedule(rowsum_blockized, debug_mask="all")
    block = sch.get_block("blockized_B")
    # The block has exactly two loops; only the outer one is needed.
    outer_loop, _ = sch.get_loops(block)
    sch.decompose_reduction(block, outer_loop)
    expected = matmul_decompose1
    tvm.ir.assert_structural_equal(expected, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=rowsum_blockized)
def test_reduction_decompose2():
    """Check ``decompose_reduction`` on block "update" at its third loop.

    Builds a schedule from ``matmul``, calls ``decompose_reduction`` on
    the block at the last of its three loops, compares the result with
    ``matmul_decompose2`` and runs the trace round-trip check.
    """
    sch = tir.Schedule(matmul, debug_mask="all")
    update_block = sch.get_block("update")
    # Three loops are expected; the first two are not used by this test.
    _, _, last_loop = sch.get_loops(update_block)
    sch.decompose_reduction(update_block, last_loop)
    expected = matmul_decompose2
    tvm.ir.assert_structural_equal(expected, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=matmul)
def test_blockize_init_loops():
    """Check ``blockize`` on a reduction block that has a ``T.init()`` section.

    A schedule is created from the local ``rowsum`` function, ``blockize``
    is applied at the first loop of block "B", and the resulting module is
    compared structurally with the local ``after_rowsum_blockize``.
    The trace round-trip check is then run against ``rowsum``.
    """
    # Input: block "B" iterates over loops (k, i); the axes are remapped
    # with "RS", i.e. vk is a reduction axis and vi a spatial axis.
    @T.prim_func
    def rowsum(A: T.Buffer[(128, 128), "float32"],
               B: T.Buffer[(128, ), "float32"]) -> None:
        for k, i in T.grid(128, 128):
            with T.block("B"):
                vk, vi = T.axis.remap("RS", [k, i])
                with T.init():
                    B[vi] = 0.0
                B[vi] = B[vi] + A[vi, vk]

    # Expected output: an outer block "blockized_B" whose init section
    # holds a loop over a "B_init" block, followed by the original
    # two-level loop nest around block "B" (now without its own init).
    @T.prim_func
    def after_rowsum_blockize(
        A: T.Buffer[(128, 128), "float32"],
        B: T.Buffer[(128, ), "float32"],
    ) -> None:
        with T.block("blockized_B"):
            vko = T.axis.R(1, 0)
            vio = T.axis.S(1, 0)
            with T.init():
                for i1 in T.serial(0, 128):
                    with T.block("B_init"):
                        vi_init = T.axis.S(128, i1)
                        B[vi_init] = T.float32(0)
            for i0, i1_1 in T.grid(128, 128):
                with T.block("B"):
                    vk, vi = T.axis.remap("RS", [i0, i1_1])
                    B[vi] = B[vi] + A[vi, vk]

    s = tir.Schedule(rowsum, debug_mask="all")
    # Only the outermost loop (k) of block "B" is needed for blockize.
    k, _ = s.get_loops(s.get_block("B"))
    s.blockize(k)
    tvm.ir.assert_structural_equal(s.mod["main"], after_rowsum_blockize)
    verify_trace_roundtrip(sch=s, mod=rowsum)
def test_continuous_cache_write():
    """Apply ``cache_write`` twice in a row to the same block.

    Block "B" of ``elementwise`` gets a "shared" cache write followed by
    a "local" one, both on write-buffer index 0. The result is compared
    with ``continuous_cache_write`` and the trace round-trip is checked.
    """
    schedule = tir.Schedule(elementwise, debug_mask="all")
    target = schedule.get_block("B")
    # Order matters: "shared" first, then "local".
    for scope in ("shared", "local"):
        schedule.cache_write(target, 0, scope)
    tvm.ir.assert_structural_equal(continuous_cache_write, schedule.mod["main"])
    verify_trace_roundtrip(sch=schedule, mod=elementwise)
def test_compute_inline_multi_loads():
    """Inline block "B" of ``elementwise_multi_loads``.

    After ``compute_inline`` the module must be structurally equal to
    ``elementwise_multi_loads_inlined``; the trace round-trip is then
    verified against the original module.
    """
    schedule = tir.Schedule(elementwise_multi_loads, debug_mask="all")
    producer = schedule.get_block("B")
    schedule.compute_inline(producer)
    expected = elementwise_multi_loads_inlined
    tvm.ir.assert_structural_equal(expected, schedule.mod["main"])
    verify_trace_roundtrip(sch=schedule, mod=elementwise_multi_loads)
def test_tir_schedule_copy_2():
    """Check that a copied schedule evolves independently of its source.

    The assertions establish that, right after ``copy()``, the srefs of
    the three loops differ between the two schedules while their ``stmt``
    objects are the same. After splitting ``i`` on the source and ``j``
    on the copy, the loops produced by one schedule raise ``IndexError``
    when looked up via ``get_sref`` on the other.
    """
    original = tir.Schedule(mod=matmul, debug_mask="all")
    i, j, k = original.get_loops(original.get_block("update"))
    clone = original.copy()

    # Distinct sref objects per schedule ...
    for loop in (i, j, k):
        assert not original.get_sref(loop).same_as(clone.get_sref(loop))
    # ... that still refer to identical statements.
    for loop in (i, j, k):
        assert original.get_sref(loop).stmt.same_as(clone.get_sref(loop).stmt)

    # Diverge: split a different loop in each schedule.
    i_0, i_1 = original.split(i, factors=[None, 64])
    j_0, j_1 = clone.split(j, factors=[None, 32])

    assert original.get_sref(i_0).stmt.extent == 2
    assert original.get_sref(i_1).stmt.extent == 64
    for foreign_loop in (i_0, i_1):
        with pytest.raises(IndexError):
            clone.get_sref(foreign_loop)

    for foreign_loop in (j_0, j_1):
        with pytest.raises(IndexError):
            original.get_sref(foreign_loop)
    assert clone.get_sref(j_0).stmt.extent == 4
    assert clone.get_sref(j_1).stmt.extent == 32

    verify_trace_roundtrip(original, mod=matmul)
    verify_trace_roundtrip(clone, mod=matmul)
# Example #7
# 0
def test_set_scope_subregion():
    """Set the scope of buffer index 0 of block "B" to "shared".

    Runs on ``element_wise_subregion_match`` and expects the module to
    equal ``element_wise_subregion_match_set_scope`` afterwards; the
    trace round-trip is verified against the input function.
    """
    original_func = element_wise_subregion_match
    sch = tir.Schedule(original_func, debug_mask="all")
    block_b = sch.get_block("B")
    sch.set_scope(block_b, 0, "shared")
    expected = element_wise_subregion_match_set_scope
    tvm.ir.assert_structural_equal(expected, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=original_func)
def test_reduction_rfactor_spatial_only():
    """Apply ``rfactor`` to block "acc" of ``rfactor_spatial_only``.

    The fifth of the block's six loops is factored with
    ``factor_axis=4``; the result must equal
    ``rfactor_spatial_only_after`` and survive the trace round-trip.
    """
    sch = tir.Schedule(rfactor_spatial_only, debug_mask="all")
    acc_block = sch.get_block(name="acc", func_name="main")
    # Six loops are expected; only the fifth one is used.
    _, _, _, _, fifth_loop, _ = sch.get_loops(acc_block)
    sch.rfactor(loop=fifth_loop, factor_axis=4)
    actual = sch.mod["main"]
    tvm.ir.assert_structural_equal(actual, rfactor_spatial_only_after)
    verify_trace_roundtrip(sch=sch, mod=rfactor_spatial_only)
def test_compute_at_cuda_matmul_4():
    """Move block "B_shared" under the seventh loop of block "C".

    Starts from ``cuda_matmul_4``, calls ``compute_at`` with
    ``preserve_unit_loops=True`` and expects ``cuda_matmul_5``.
    """
    schedule = tir.Schedule(cuda_matmul_4, debug_mask="all")
    producer = schedule.get_block("B_shared")
    consumer = schedule.get_block("C")
    # Block "C" is expected to have eleven loops; pick the seventh.
    _, _, _, _, _, _, target_loop, _, _, _, _ = schedule.get_loops(consumer)
    schedule.compute_at(producer, target_loop, preserve_unit_loops=True)
    tvm.ir.assert_structural_equal(cuda_matmul_5, schedule.mod["main"])
    verify_trace_roundtrip(sch=schedule, mod=cuda_matmul_4)
# Example #10
# 0
def test_two_elementwise_transform_output_buffer():
    """Transform the layout of write buffer 0 of block "C".

    Uses ``packed_index_map_func`` as the index map on
    ``two_elementwise`` and expects
    ``two_elementwise_transformed_output_buffer``.
    """
    schedule = tir.Schedule(two_elementwise, debug_mask="all")
    block_c = schedule.get_block("C")
    schedule.transform_layout(block_c, 0, "write", packed_index_map_func)
    expected = two_elementwise_transformed_output_buffer
    tvm.ir.assert_structural_equal(expected, schedule.mod["main"])
    verify_trace_roundtrip(sch=schedule, mod=two_elementwise)
def test_compute_at_blockized_1():
    """Move block "B" under the second loop of block "C_outer".

    Starts from ``blockized_1``, calls ``compute_at`` with
    ``preserve_unit_loops=True`` and expects
    ``blockized_after_compute_at``.
    """
    schedule = tir.Schedule(blockized_1, debug_mask="all")
    producer = schedule.get_block("B")
    outer_block = schedule.get_block("C_outer")
    # "C_outer" has two loops; the inner one is the compute_at target.
    _, target_loop = schedule.get_loops(outer_block)
    schedule.compute_at(producer, target_loop, preserve_unit_loops=True)
    tvm.ir.assert_structural_equal(blockized_after_compute_at, schedule.mod["main"])
    verify_trace_roundtrip(sch=schedule, mod=blockized_1)
def test_storage_align(use_block_name):
    """Apply ``storage_align`` to buffer index 0 of block "B".

    Parameters
    ----------
    use_block_name
        When truthy the block is passed as the string "B"; otherwise the
        handle returned by ``get_block("B")`` is passed.

    The result on ``element_wise`` must equal
    ``element_wise_storage_align``.
    """
    original_func = element_wise
    sch = tir.Schedule(original_func, debug_mask="all")
    if use_block_name:
        block_ref = "B"
    else:
        block_ref = sch.get_block("B")
    sch.storage_align(block_ref, 0, axis=0, factor=128, offset=127)
    tvm.ir.assert_structural_equal(element_wise_storage_align, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=original_func)
# Example #13
# 0
def _find_match_sketch_id(
    mod: IRModule,
    sketches: List[Schedule],
    expected_mod: IRModule,
    expected_decision: List[Tuple[str, List[int]]],
    *,
    debug_mask="all",
) -> Optional[int]:
    """Return the index of the first sketch that reproduces ``expected_mod``.

    For every sketch, the instructions in its trace whose kind name starts
    with "Sample" are walked in order and paired with ``expected_decision``
    entries: an instruction whose kind name equals the name of the current
    expected entry receives that entry's decision, and the cursor advances.
    A sketch is skipped unless the number of assigned decisions equals
    ``len(expected_decision)``. Otherwise the sketch's instructions are
    replayed with those decisions on a fresh schedule of ``mod``
    (``remove_postproc=True``); if the resulting module is structurally
    equal to ``expected_mod`` the trace round-trip is verified and the
    sketch index is returned.

    Parameters
    ----------
    mod
        Module the replayed trace is applied to.
    sketches
        Candidate schedules whose traces are inspected.
    expected_mod
        Module the replayed schedule must be structurally equal to.
    expected_decision
        Sequence of ``(instruction kind name, decision)`` pairs.
    debug_mask
        Forwarded to ``Schedule`` and ``verify_trace_roundtrip``.

    Returns
    -------
    The matching sketch index, or ``None`` when no sketch matches.

    Raises
    ------
    AssertionError
        If a sampling instruction is encountered after every expected
        decision has already been consumed.
    """
    num_expected = len(expected_decision)
    for sketch_id, sketch in enumerate(sketches):
        all_insts = sketch.trace.insts
        cursor = 0
        assigned = {}
        sampling_insts = (
            inst for inst in all_insts if inst.kind.name.startswith("Sample")
        )
        for inst in sampling_insts:
            assert cursor < len(expected_decision)
            candidate = expected_decision[cursor]
            if inst.kind.name != candidate[0]:
                continue
            assigned[inst] = candidate[1]
            cursor += 1
        if len(assigned) == num_expected:
            replayed = Schedule(mod, debug_mask=debug_mask)
            trace = Trace(insts=all_insts, decisions=assigned)
            trace.apply_to_schedule(replayed, remove_postproc=True)
            if structural_equal(replayed.mod, expected_mod):
                verify_trace_roundtrip(sch=replayed, mod=mod, debug_mask=debug_mask)
                return sketch_id
    return None
def test_split_with_opaque_block():
    """Split the first loop of block "opaque" with factors ``[None, 16]``.

    Runs on ``elementwise_with_opaque_block`` and expects
    ``elementwise_split_with_opaque_block``.
    """
    schedule = tir.Schedule(elementwise_with_opaque_block, debug_mask="all")
    opaque = schedule.get_block("opaque")
    # Three loops are expected; only the outermost is split.
    outermost, _, _ = schedule.get_loops(opaque)
    schedule.split(outermost, factors=[None, 16])
    expected = elementwise_split_with_opaque_block
    tvm.ir.assert_structural_equal(expected, schedule.mod["main"])
    verify_trace_roundtrip(sch=schedule, mod=elementwise_with_opaque_block)
def test_reorder2():
    """Reorder three of the four loops of block "B".

    ``reorder`` is called with the third, first and fourth loop (in that
    order) on ``elementwise``; the result must equal
    ``elementwise_reordered2``.
    """
    schedule = tir.Schedule(elementwise, debug_mask="all")
    target = schedule.get_block("B")
    # The second loop is deliberately left out of the reorder call.
    first, _, third, fourth = schedule.get_loops(target)
    schedule.reorder(third, first, fourth)
    tvm.ir.assert_structural_equal(elementwise_reordered2, schedule.mod["main"])
    verify_trace_roundtrip(sch=schedule, mod=elementwise)
def test_split_symbolic():
    """Split the third loop of block "B" with factors ``[10, None]``.

    Runs on ``elementwise_symbolic`` and expects
    ``elementwise_symbolic_split``.
    """
    schedule = tir.Schedule(elementwise_symbolic, debug_mask="all")
    target = schedule.get_block("B")
    _, _, innermost = schedule.get_loops(target)
    schedule.split(innermost, factors=[10, None])
    expected = elementwise_symbolic_split
    tvm.ir.assert_structural_equal(expected, schedule.mod["main"])
    verify_trace_roundtrip(sch=schedule, mod=elementwise_symbolic)
def test_fuse_symbolic():
    """Fuse all three loops of block "B" of ``elementwise_symbolic``.

    The result must equal ``elementwise_symbolic_fused``; the trace
    round-trip is verified against the original module.
    """
    schedule = tir.Schedule(elementwise_symbolic, debug_mask="all")
    target = schedule.get_block("B")
    # Unpack explicitly so the test fails if the loop count is not three.
    outer, middle, inner = schedule.get_loops(target)
    schedule.fuse(outer, middle, inner)
    expected = elementwise_symbolic_fused
    tvm.ir.assert_structural_equal(expected, schedule.mod["main"])
    verify_trace_roundtrip(sch=schedule, mod=elementwise_symbolic)
def test_fuse_with_opaque_block():
    """Fuse all three loops of block "opaque".

    Runs on ``elementwise_with_opaque_block`` and expects
    ``elementwise_fuse_with_opaque_block``.
    """
    schedule = tir.Schedule(elementwise_with_opaque_block, debug_mask="all")
    opaque = schedule.get_block("opaque")
    # Unpack explicitly so the test fails if the loop count is not three.
    outer, middle, inner = schedule.get_loops(opaque)
    schedule.fuse(outer, middle, inner)
    expected = elementwise_fuse_with_opaque_block
    tvm.ir.assert_structural_equal(expected, schedule.mod["main"])
    verify_trace_roundtrip(sch=schedule, mod=elementwise_with_opaque_block)
def test_cache_read_under_scope():
    """Apply ``cache_read`` to blocks "B" and "C" of ``access_under_scope``.

    Read-buffer index 0 of "B" is cached to "local", then read-buffer
    index 0 of "C" is cached to "global". The result must equal
    ``cache_read_under_scope``.
    """
    schedule = tir.Schedule(access_under_scope, debug_mask="all")
    # Both handles are fetched before any cache_read is applied.
    consumer_b = schedule.get_block("B")
    consumer_c = schedule.get_block("C")
    for consumer, scope in ((consumer_b, "local"), (consumer_c, "global")):
        schedule.cache_read(consumer, 0, scope)
    tvm.ir.assert_structural_equal(cache_read_under_scope, schedule.mod["main"])
    verify_trace_roundtrip(sch=schedule, mod=access_under_scope)
# Example #20
# 0
def test_reindex_read_basic(use_block_name, use_buffer_name):
    """Apply ``reindex`` to a read buffer of block "B".

    Parameters
    ----------
    use_block_name
        When truthy the block is passed as the string "B"; otherwise the
        handle returned by ``get_block("B")`` is passed.
    use_buffer_name
        When truthy the buffer is passed as the string "A"; otherwise
        the tuple ``("read", 0)`` is passed.

    The result on ``transpose_elementwise`` must equal
    ``transpose_elementwise_reindex_read``.
    """
    # NOTE(review): unlike the sibling tests, no debug_mask is passed here.
    schedule = tir.Schedule(transpose_elementwise)
    if use_block_name:
        block_ref = "B"
    else:
        block_ref = schedule.get_block("B")
    if use_buffer_name:
        buffer_ref = "A"
    else:
        buffer_ref = ("read", 0)
    schedule.reindex(block_ref, buffer_ref)
    expected = transpose_elementwise_reindex_read
    tvm.ir.assert_structural_equal(expected, schedule.mod["main"])
    verify_trace_roundtrip(sch=schedule, mod=transpose_elementwise)
def test_sample_perfect_tile_composite():
    """Sample a 4-way perfect tiling and check the product of the factors.

    ``sample_perfect_tile`` is called with ``n=4`` on the third loop of
    block "B" of ``elementwise``; the first four sampled values, read
    back through ``sch.get``, must multiply to 1470.
    """
    schedule = tir.Schedule(elementwise, debug_mask="all")
    _, _, innermost = schedule.get_loops(schedule.get_block("B"))
    sampled = schedule.sample_perfect_tile(innermost, n=4)
    # Resolve each sampled handle to its concrete value.
    values = []
    for handle in sampled:
        values.append(schedule.get(handle))
    product = values[0] * values[1] * values[2] * values[3]
    assert product == 1470
    verify_trace_roundtrip(schedule, mod=elementwise)
# Example #22
# 0
def test_reverse_compute_inline_under_loop():
    """Reverse-inline block "C" of ``elementwise_under_loop``.

    The result must equal ``elementwise_inlined``, and the handle to
    block "B" obtained before the transformation must still resolve to a
    block whose ``name_hint`` is "B".
    """
    schedule = tir.Schedule(elementwise_under_loop, debug_mask="all")
    # Fetch both handles up front; "B" is re-checked after the inline.
    producer = schedule.get_block("B")
    consumer = schedule.get_block("C")
    schedule.reverse_compute_inline(consumer)
    tvm.ir.assert_structural_equal(elementwise_inlined, schedule.mod["main"])
    surviving_block = schedule.get(producer)
    assert surviving_block.name_hint == "B"
    verify_trace_roundtrip(sch=schedule, mod=elementwise_under_loop)
# Example #23
# 0
def test_bind2():
    """Bind one loop of "B" and one loop of "C" to "threadIdx.x".

    On ``element_wise_compute_at_split`` the second loop of block "B"
    and the second loop of block "C" are both bound to "threadIdx.x";
    the result must equal ``element_wise_compute_at_split_j0_j1o_bound``.
    """
    sch = tir.Schedule(element_wise_compute_at_split, debug_mask="all")
    block_b = sch.get_block("B")
    _, loop_of_b = sch.get_loops(block_b)
    block_c = sch.get_block("C")
    _, loop_of_c, _ = sch.get_loops(block_c)
    for loop in (loop_of_b, loop_of_c):
        sch.bind(loop, "threadIdx.x")
    actual = sch.mod["main"]
    tvm.ir.assert_structural_equal(actual, element_wise_compute_at_split_j0_j1o_bound)
    verify_trace_roundtrip(sch=sch, mod=element_wise_compute_at_split)
# Example #24
# 0
def test_read_out_of_bound():
    """Move block "B" under the single loop of block "C".

    Runs ``compute_at`` on the ``read_out_of_bound`` function and expects
    ``read_out_of_bound_after_compute_at``.
    """
    schedule = tir.Schedule(read_out_of_bound, debug_mask="all")
    producer = schedule.get_block("B")
    consumer = schedule.get_block("C")
    # Block "C" must have exactly one loop.
    (only_loop,) = schedule.get_loops(consumer)
    schedule.compute_at(producer, only_loop)
    expected = read_out_of_bound_after_compute_at
    tvm.ir.assert_structural_equal(expected, schedule.mod["main"])
    verify_trace_roundtrip(sch=schedule, mod=read_out_of_bound)
# Example #25
# 0
def test_compute_inline_as_dce():
    """Inline block "B" of ``elementwise_standalone``.

    The result must equal ``elementwise_standalone_dce``, and the handle
    to block "C" obtained before the transformation must still resolve
    to a block whose ``name_hint`` is "C".
    """
    schedule = tir.Schedule(elementwise_standalone, debug_mask="all")
    # Fetch both handles up front; "C" is re-checked after the inline.
    inlined = schedule.get_block("B")
    kept = schedule.get_block("C")
    schedule.compute_inline(inlined)
    tvm.ir.assert_structural_equal(elementwise_standalone_dce, schedule.mod["main"])
    surviving_block = schedule.get(kept)
    assert surviving_block.name_hint == "C"
    verify_trace_roundtrip(sch=schedule, mod=elementwise_standalone)
# Example #26
# 0
def test_reverse_compute_at_factorized():
    """Move block "B" under the second loop of block "B_rf".

    Runs ``reverse_compute_at`` with ``preserve_unit_loops=False`` on
    ``factorized`` and expects ``factorized_after_reverse_compute_at``.
    """
    schedule = tir.Schedule(factorized, debug_mask="all")
    consumer = schedule.get_block("B")
    rf_block = schedule.get_block("B_rf")
    # "B_rf" is expected to have four loops; the second is the target.
    _, target_loop, _, _ = schedule.get_loops(rf_block)
    schedule.reverse_compute_at(consumer, target_loop, preserve_unit_loops=False)
    expected = factorized_after_reverse_compute_at
    tvm.ir.assert_structural_equal(expected, schedule.mod["main"])
    verify_trace_roundtrip(sch=schedule, mod=factorized)
# Example #27
# 0
def test_compute_at_two_elementwise():
    """Move block "B" under the first loop of block "C".

    Runs ``compute_at`` with ``preserve_unit_loops=True`` on
    ``two_elementwise`` and expects ``two_elementwise_after_compute_at``.
    """
    schedule = tir.Schedule(two_elementwise, debug_mask="all")
    producer = schedule.get_block("B")
    consumer = schedule.get_block("C")
    # "C" has two loops; the outer one is the compute_at target.
    target_loop, _ = schedule.get_loops(consumer)
    schedule.compute_at(producer, target_loop, preserve_unit_loops=True)
    expected = two_elementwise_after_compute_at
    tvm.ir.assert_structural_equal(expected, schedule.mod["main"])
    verify_trace_roundtrip(sch=schedule, mod=two_elementwise)
def test_reverse_compute_inline_affine_load_unit_iter_simplified(use_block_name):
    """Reverse-inline block "C" and compare with the expected module.

    Parameters
    ----------
    use_block_name
        When truthy the block is passed as the string "C"; otherwise the
        handle returned by ``get_block("C")`` is passed.

    Runs on ``elementwise_reverse_affine_load_unit_iter_simplified`` and
    expects ``elementwise_reverse_affine_load_unit_iter_simplified_inlined``.
    """
    original_func = elementwise_reverse_affine_load_unit_iter_simplified
    schedule = tir.Schedule(original_func, debug_mask="all")
    if use_block_name:
        consumer = "C"
    else:
        consumer = schedule.get_block("C")
    schedule.reverse_compute_inline(consumer)
    expected = elementwise_reverse_affine_load_unit_iter_simplified_inlined
    tvm.ir.assert_structural_equal(expected, schedule.mod["main"])
    verify_trace_roundtrip(sch=schedule, mod=original_func)
def test_compute_inline_elementwise(use_block_name):
    """Inline block "B" of ``elementwise``.

    Parameters
    ----------
    use_block_name
        When truthy the block is passed as the string "B"; otherwise the
        handle returned by ``get_block("B")`` is passed.

    The result must equal ``elementwise_inlined``, and the handle to
    block "C" obtained before the transformation must still resolve to a
    block whose ``name_hint`` is "C".
    """
    schedule = tir.Schedule(elementwise, debug_mask="all")
    if use_block_name:
        producer = "B"
    else:
        producer = schedule.get_block("B")
    # "C" is fetched before the inline and re-checked afterwards.
    consumer = schedule.get_block("C")
    schedule.compute_inline(producer)
    tvm.ir.assert_structural_equal(elementwise_inlined, schedule.mod["main"])
    surviving_block = schedule.get(consumer)
    assert surviving_block.name_hint == "C"
    verify_trace_roundtrip(sch=schedule, mod=elementwise)
# Example #30
# 0
def test_reduction_rfactor_square_sum_square_root():
    """Apply ``rfactor`` to the third loop of block "C".

    Runs on ``transformed_square_sum_square_root`` with factor axis 0 and
    expects ``square_sum_square_root_rfactor``. The block handle returned
    by ``rfactor`` must resolve to the same object as block "C_rf".
    """
    sch = tir.Schedule(transformed_square_sum_square_root, debug_mask="all")
    block_c = sch.get_block("C")
    _, _, fused_inner = sch.get_loops(block_c)
    rf_block = sch.rfactor(fused_inner, 0)
    actual = sch.mod["main"]
    tvm.ir.assert_structural_equal(actual, square_sum_square_root_rfactor)
    returned_block = sch.get(rf_block)
    named_block = sch.get(sch.get_block("C_rf"))
    assert returned_block.same_as(named_block)
    verify_trace_roundtrip(sch=sch, mod=transformed_square_sum_square_root)