def test_reduction_decompose1():
    """Decompose the reduction of block ``blockized_B`` at its first loop.

    The scheduled function is compared against ``matmul_decompose1`` and the
    recorded trace is replayed on ``rowsum_blockized``.
    """
    sch = tir.Schedule(rowsum_blockized, debug_mask="all")
    block = sch.get_block("blockized_B")
    # Exactly two loops are expected; only the first one is used.
    first_loop, _second_loop = sch.get_loops(block)
    sch.decompose_reduction(block, first_loop)
    scheduled = sch.mod["main"]
    tvm.ir.assert_structural_equal(matmul_decompose1, scheduled)
    verify_trace_roundtrip(sch, mod=rowsum_blockized)
def test_reduction_decompose2():
    """Decompose the reduction of block ``update`` at the last of its three loops.

    The scheduled function is compared against ``matmul_decompose2`` and the
    recorded trace is replayed on ``matmul``.
    """
    sch = tir.Schedule(matmul, debug_mask="all")
    update_block = sch.get_block("update")
    # Exactly three loops are expected; only the third one is used.
    _loop_i, _loop_j, loop_k = sch.get_loops(update_block)
    sch.decompose_reduction(update_block, loop_k)
    scheduled = sch.mod["main"]
    tvm.ir.assert_structural_equal(matmul_decompose2, scheduled)
    verify_trace_roundtrip(sch, mod=matmul)
def test_blockize_init_loops():
    """Blockize the first loop around block ``B`` of a locally defined ``rowsum``.

    The result is compared against the locally defined
    ``after_rowsum_blockize`` and the recorded trace is replayed on ``rowsum``.
    """

    # Input workload: kept exactly as written, since it is parsed as TVMScript.
    @T.prim_func
    def rowsum(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, ), "float32"]) -> None:
        for k, i in T.grid(128, 128):
            with T.block("B"):
                vk, vi = T.axis.remap("RS", [k, i])
                with T.init():
                    B[vi] = 0.0
                B[vi] = B[vi] + A[vi, vk]

    # Expected function after blockize: kept exactly as written.
    @T.prim_func
    def after_rowsum_blockize(
        A: T.Buffer[(128, 128), "float32"],
        B: T.Buffer[(128, ), "float32"],
    ) -> None:
        with T.block("blockized_B"):
            vko = T.axis.R(1, 0)
            vio = T.axis.S(1, 0)
            with T.init():
                for i1 in T.serial(0, 128):
                    with T.block("B_init"):
                        vi_init = T.axis.S(128, i1)
                        B[vi_init] = T.float32(0)
            for i0, i1_1 in T.grid(128, 128):
                with T.block("B"):
                    vk, vi = T.axis.remap("RS", [i0, i1_1])
                    B[vi] = B[vi] + A[vi, vk]

    sch = tir.Schedule(rowsum, debug_mask="all")
    block = sch.get_block("B")
    # Exactly two loops are expected; only the first one is blockized.
    first_loop, _second_loop = sch.get_loops(block)
    sch.blockize(first_loop)
    scheduled = sch.mod["main"]
    tvm.ir.assert_structural_equal(scheduled, after_rowsum_blockize)
    verify_trace_roundtrip(sch=sch, mod=rowsum)
def test_continuous_cache_write():
    """Apply ``cache_write`` twice to block ``B`` (index 0): "shared", then "local".

    The scheduled function is compared against ``continuous_cache_write`` and
    the recorded trace is replayed on ``elementwise``.
    """
    schedule = tir.Schedule(elementwise, debug_mask="all")
    target = schedule.get_block("B")
    # Order matters: the "shared" cache is requested before the "local" one.
    schedule.cache_write(target, 0, "shared")
    schedule.cache_write(target, 0, "local")
    scheduled = schedule.mod["main"]
    tvm.ir.assert_structural_equal(continuous_cache_write, scheduled)
    verify_trace_roundtrip(sch=schedule, mod=elementwise)
def test_compute_inline_multi_loads():
    """Inline block ``B`` of ``elementwise_multi_loads`` with ``compute_inline``.

    The scheduled function is compared against
    ``elementwise_multi_loads_inlined`` and the trace is replayed on the input.
    """
    schedule = tir.Schedule(elementwise_multi_loads, debug_mask="all")
    producer = schedule.get_block("B")
    schedule.compute_inline(producer)
    scheduled = schedule.mod["main"]
    tvm.ir.assert_structural_equal(elementwise_multi_loads_inlined, scheduled)
    verify_trace_roundtrip(sch=schedule, mod=elementwise_multi_loads)
def test_tir_schedule_copy_2():
    """Check that a copied schedule has its own srefs and evolves independently.

    Before any transformation, the srefs of the three loops differ between the
    schedule and its copy while their ``stmt`` objects are the same. After
    splitting a different loop in each schedule, loop handles created in one
    schedule raise ``IndexError`` when looked up in the other.
    """
    original = tir.Schedule(mod=matmul, debug_mask="all")
    i, j, k = original.get_loops(original.get_block("update"))
    duplicate = original.copy()

    # The sref objects are distinct between the two schedules ...
    for loop in (i, j, k):
        assert not original.get_sref(loop).same_as(duplicate.get_sref(loop))
    # ... but they point at the very same statements.
    for loop in (i, j, k):
        assert original.get_sref(loop).stmt.same_as(duplicate.get_sref(loop).stmt)

    # Diverge the two schedules: split `i` in the original, `j` in the copy.
    i_0, i_1 = original.split(i, factors=[None, 64])
    j_0, j_1 = duplicate.split(j, factors=[None, 32])

    assert original.get_sref(i_0).stmt.extent == 2
    assert original.get_sref(i_1).stmt.extent == 64

    # Handles produced by one schedule are unknown to the other.
    foreign_lookups = (
        (duplicate, i_0),
        (duplicate, i_1),
        (original, j_0),
        (original, j_1),
    )
    for owner, foreign_loop in foreign_lookups:
        with pytest.raises(IndexError):
            owner.get_sref(foreign_loop)

    assert duplicate.get_sref(j_0).stmt.extent == 4
    assert duplicate.get_sref(j_1).stmt.extent == 32

    verify_trace_roundtrip(original, mod=matmul)
    verify_trace_roundtrip(duplicate, mod=matmul)
def test_set_scope_subregion():
    """Set the scope of buffer index 0 of block ``B`` to "shared".

    The input is ``element_wise_subregion_match``; the scheduled function is
    compared against ``element_wise_subregion_match_set_scope``.
    """
    workload = element_wise_subregion_match
    sch = tir.Schedule(workload, debug_mask="all")
    block = sch.get_block("B")
    sch.set_scope(block, 0, "shared")
    scheduled = sch.mod["main"]
    tvm.ir.assert_structural_equal(element_wise_subregion_match_set_scope, scheduled)
    verify_trace_roundtrip(sch=sch, mod=workload)
def test_reduction_rfactor_spatial_only():
    """Apply ``rfactor`` with ``factor_axis=4`` to the fifth loop around block ``acc``.

    The scheduled function is compared against ``rfactor_spatial_only_after``
    and the recorded trace is replayed on ``rfactor_spatial_only``.
    """
    sch = tir.Schedule(rfactor_spatial_only, debug_mask="all")
    acc_block = sch.get_block(name="acc", func_name="main")
    # Exactly six loops are expected; only the fifth one is factored.
    _, _, _, _, target_loop, _ = sch.get_loops(acc_block)
    sch.rfactor(loop=target_loop, factor_axis=4)
    scheduled = sch.mod["main"]
    tvm.ir.assert_structural_equal(scheduled, rfactor_spatial_only_after)
    verify_trace_roundtrip(sch, mod=rfactor_spatial_only)
def test_compute_at_cuda_matmul_4():
    """Move block ``B_shared`` under the seventh loop of block ``C`` via ``compute_at``.

    Unit loops are preserved. The scheduled function is compared against
    ``cuda_matmul_5`` and the trace is replayed on ``cuda_matmul_4``.
    """
    schedule = tir.Schedule(cuda_matmul_4, debug_mask="all")
    producer = schedule.get_block("B_shared")
    consumer = schedule.get_block("C")
    # Exactly eleven loops are expected; the seventh one is the target.
    _, _, _, _, _, _, target_loop, _, _, _, _ = schedule.get_loops(consumer)
    schedule.compute_at(producer, target_loop, preserve_unit_loops=True)
    scheduled = schedule.mod["main"]
    tvm.ir.assert_structural_equal(cuda_matmul_5, scheduled)
    verify_trace_roundtrip(sch=schedule, mod=cuda_matmul_4)
def test_two_elementwise_transform_output_buffer():
    """Transform the layout of write buffer 0 of block ``C`` with ``packed_index_map_func``.

    The scheduled function is compared against
    ``two_elementwise_transformed_output_buffer`` and the trace is replayed on
    ``two_elementwise``.
    """
    schedule = tir.Schedule(two_elementwise, debug_mask="all")
    target_block = schedule.get_block("C")
    schedule.transform_layout(target_block, 0, "write", packed_index_map_func)
    scheduled = schedule.mod["main"]
    tvm.ir.assert_structural_equal(two_elementwise_transformed_output_buffer, scheduled)
    verify_trace_roundtrip(sch=schedule, mod=two_elementwise)
def test_compute_at_blockized_1():
    """Move block ``B`` under the second loop of block ``C_outer`` via ``compute_at``.

    Unit loops are preserved. The scheduled function is compared against
    ``blockized_after_compute_at`` and the trace is replayed on ``blockized_1``.
    """
    schedule = tir.Schedule(blockized_1, debug_mask="all")
    producer = schedule.get_block("B")
    outer_block = schedule.get_block("C_outer")
    # Exactly two loops are expected; the second one is the target.
    _, target_loop = schedule.get_loops(outer_block)
    schedule.compute_at(producer, target_loop, preserve_unit_loops=True)
    scheduled = schedule.mod["main"]
    tvm.ir.assert_structural_equal(blockized_after_compute_at, scheduled)
    verify_trace_roundtrip(sch=schedule, mod=blockized_1)
def test_storage_align(use_block_name):
    """Apply ``storage_align`` to buffer 0 of block ``B`` (axis 0, factor 128, offset 127).

    When ``use_block_name`` is truthy the block is passed by its name ``"B"``;
    otherwise the handle returned by ``get_block`` is passed. The scheduled
    function is compared against ``element_wise_storage_align``.
    """
    workload = element_wise
    sch = tir.Schedule(workload, debug_mask="all")
    if use_block_name:
        block = "B"
    else:
        block = sch.get_block("B")
    sch.storage_align(block, 0, axis=0, factor=128, offset=127)
    scheduled = sch.mod["main"]
    tvm.ir.assert_structural_equal(element_wise_storage_align, scheduled)
    verify_trace_roundtrip(sch=sch, mod=workload)
def _find_match_sketch_id(
    mod: IRModule,
    sketches: List[Schedule],
    expected_mod: IRModule,
    expected_decision: List[Tuple[str, List[int]]],
    *,
    debug_mask="all",
) -> Optional[int]:
    """Return the index of the first sketch that reproduces ``expected_mod``.

    For each sketch, the instructions of its trace whose kind name starts with
    ``"Sample"`` are matched, in order, against ``expected_decision``: when the
    kind name equals the name of the next pending entry, that entry's decision
    is assigned to the instruction. A sketch is only considered when every
    expected decision was assigned. Its instructions are then replayed with
    those decisions on a fresh ``Schedule`` of ``mod`` (with
    ``remove_postproc=True``), and the resulting module is compared with
    ``expected_mod`` via ``structural_equal``.

    Parameters
    ----------
    mod : IRModule
        Module on which each candidate trace is replayed.
    sketches : List[Schedule]
        Candidate schedules whose traces are inspected.
    expected_mod : IRModule
        Module the replayed schedule must be structurally equal to.
    expected_decision : List[Tuple[str, List[int]]]
        Ordered ``(instruction kind name, decision)`` pairs.
    debug_mask
        Forwarded to ``Schedule`` and to ``verify_trace_roundtrip``.

    Returns
    -------
    Optional[int]
        Index of the first matching sketch, or ``None`` when none matches.

    Raises
    ------
    AssertionError
        If a sketch still contains a ``Sample*`` instruction after all
        expected decisions have already been assigned.
    """
    for sketch_id, sketch in enumerate(sketches):
        cursor = 0  # index of the next expected decision still to be matched
        decisions = {}
        for inst in sketch.trace.insts:
            if inst.kind.name.startswith("Sample"):
                assert cursor < len(expected_decision)
                if inst.kind.name == expected_decision[cursor][0]:
                    decisions[inst] = expected_decision[cursor][1]
                    cursor += 1

        # Only sketches that consumed every expected decision are replayed.
        if len(decisions) == len(expected_decision):
            sch = Schedule(mod, debug_mask=debug_mask)
            replay = Trace(
                insts=sketch.trace.insts,
                decisions=decisions,
            )
            replay.apply_to_schedule(sch, remove_postproc=True)
            if structural_equal(sch.mod, expected_mod):
                verify_trace_roundtrip(sch=sch, mod=mod, debug_mask=debug_mask)
                return sketch_id
    return None
def test_split_with_opaque_block():
    """Split the first loop around block ``opaque`` with factors ``[None, 16]``.

    The scheduled function is compared against
    ``elementwise_split_with_opaque_block`` and the trace is replayed on
    ``elementwise_with_opaque_block``.
    """
    schedule = tir.Schedule(elementwise_with_opaque_block, debug_mask="all")
    opaque_block = schedule.get_block("opaque")
    # Exactly three loops are expected; only the first one is split.
    first_loop, _, _ = schedule.get_loops(opaque_block)
    schedule.split(first_loop, factors=[None, 16])
    scheduled = schedule.mod["main"]
    tvm.ir.assert_structural_equal(elementwise_split_with_opaque_block, scheduled)
    verify_trace_roundtrip(sch=schedule, mod=elementwise_with_opaque_block)
def test_reorder2():
    """Reorder three of the four loops around block ``B`` as (third, first, fourth).

    The scheduled function is compared against ``elementwise_reordered2`` and
    the trace is replayed on ``elementwise``.
    """
    schedule = tir.Schedule(elementwise, debug_mask="all")
    target = schedule.get_block("B")
    # Exactly four loops are expected; the second one is not passed to reorder.
    loop_i, _loop_j, loop_k, loop_l = schedule.get_loops(target)
    schedule.reorder(loop_k, loop_i, loop_l)
    scheduled = schedule.mod["main"]
    tvm.ir.assert_structural_equal(elementwise_reordered2, scheduled)
    verify_trace_roundtrip(sch=schedule, mod=elementwise)
def test_split_symbolic():
    """Split the third loop around block ``B`` with factors ``[10, None]``.

    The input is ``elementwise_symbolic``; the scheduled function is compared
    against ``elementwise_symbolic_split``.
    """
    schedule = tir.Schedule(elementwise_symbolic, debug_mask="all")
    target = schedule.get_block("B")
    # Exactly three loops are expected; only the third one is split.
    _, _, last_loop = schedule.get_loops(target)
    schedule.split(last_loop, factors=[10, None])
    scheduled = schedule.mod["main"]
    tvm.ir.assert_structural_equal(elementwise_symbolic_split, scheduled)
    verify_trace_roundtrip(sch=schedule, mod=elementwise_symbolic)
def test_fuse_symbolic():
    """Fuse the three loops around block ``B`` of ``elementwise_symbolic``.

    The scheduled function is compared against ``elementwise_symbolic_fused``
    and the trace is replayed on the input.
    """
    schedule = tir.Schedule(elementwise_symbolic, debug_mask="all")
    target = schedule.get_block("B")
    # Exactly three loops are expected; all are fused in their given order.
    loop_i, loop_j, loop_k = schedule.get_loops(target)
    schedule.fuse(loop_i, loop_j, loop_k)
    scheduled = schedule.mod["main"]
    tvm.ir.assert_structural_equal(elementwise_symbolic_fused, scheduled)
    verify_trace_roundtrip(sch=schedule, mod=elementwise_symbolic)
def test_fuse_with_opaque_block():
    """Fuse the three loops around block ``opaque``.

    The input is ``elementwise_with_opaque_block``; the scheduled function is
    compared against ``elementwise_fuse_with_opaque_block``.
    """
    schedule = tir.Schedule(elementwise_with_opaque_block, debug_mask="all")
    opaque_block = schedule.get_block("opaque")
    # Exactly three loops are expected; all are fused in their given order.
    loop_i, loop_j, loop_k = schedule.get_loops(opaque_block)
    schedule.fuse(loop_i, loop_j, loop_k)
    scheduled = schedule.mod["main"]
    tvm.ir.assert_structural_equal(elementwise_fuse_with_opaque_block, scheduled)
    verify_trace_roundtrip(sch=schedule, mod=elementwise_with_opaque_block)
def test_cache_read_under_scope():
    """Apply ``cache_read`` to block ``B`` ("local") and then block ``C`` ("global").

    Both calls use buffer index 0. The scheduled function is compared against
    ``cache_read_under_scope`` and the trace is replayed on
    ``access_under_scope``.
    """
    schedule = tir.Schedule(access_under_scope, debug_mask="all")
    # Both handles are fetched before either cache stage is inserted.
    first_block = schedule.get_block("B")
    second_block = schedule.get_block("C")
    schedule.cache_read(first_block, 0, "local")
    schedule.cache_read(second_block, 0, "global")
    scheduled = schedule.mod["main"]
    tvm.ir.assert_structural_equal(cache_read_under_scope, scheduled)
    verify_trace_roundtrip(sch=schedule, mod=access_under_scope)
def test_reindex_read_basic(use_block_name, use_buffer_name):
    """Apply ``reindex`` to block ``B`` of ``transpose_elementwise``.

    The block is passed as the name ``"B"`` when ``use_block_name`` is truthy,
    otherwise as a handle. The buffer is passed as the name ``"A"`` when
    ``use_buffer_name`` is truthy, otherwise as the tuple ``("read", 0)``.
    The scheduled function is compared against
    ``transpose_elementwise_reindex_read``.
    """
    # NOTE: this schedule is created without an explicit debug_mask.
    schedule = tir.Schedule(transpose_elementwise)
    if use_block_name:
        block = "B"
    else:
        block = schedule.get_block("B")
    if use_buffer_name:
        buffer = "A"
    else:
        buffer = ("read", 0)
    schedule.reindex(block, buffer)
    scheduled = schedule.mod["main"]
    tvm.ir.assert_structural_equal(transpose_elementwise_reindex_read, scheduled)
    verify_trace_roundtrip(sch=schedule, mod=transpose_elementwise)
def test_sample_perfect_tile_composite():
    """Sample a 4-way perfect tiling of the third loop of block ``B``.

    The product of the four sampled factors must equal 1470. The recorded
    trace is then replayed on ``elementwise``.
    """
    schedule = tir.Schedule(elementwise, debug_mask="all")
    # Exactly three loops are expected; only the third one is tiled.
    _, _, loop = schedule.get_loops(schedule.get_block("B"))
    sampled = schedule.sample_perfect_tile(loop, n=4)
    # Resolve each sampled handle to its concrete value.
    values = [schedule.get(handle) for handle in sampled]
    product = values[0] * values[1] * values[2] * values[3]
    assert product == 1470
    verify_trace_roundtrip(schedule, mod=elementwise)
def test_reverse_compute_inline_under_loop():
    """Reverse-inline block ``C`` of ``elementwise_under_loop``.

    The scheduled function is compared against ``elementwise_inlined``. The
    handle for block ``B``, obtained before the transformation, must still
    resolve to a block whose ``name_hint`` is ``"B"``.
    """
    schedule = tir.Schedule(elementwise_under_loop, debug_mask="all")
    # Both handles are fetched before the transformation is applied.
    kept_block = schedule.get_block("B")
    inlined_block = schedule.get_block("C")
    schedule.reverse_compute_inline(inlined_block)
    scheduled = schedule.mod["main"]
    tvm.ir.assert_structural_equal(elementwise_inlined, scheduled)
    assert schedule.get(kept_block).name_hint == "B"
    verify_trace_roundtrip(sch=schedule, mod=elementwise_under_loop)
def test_bind2():
    """Bind one loop of block ``B`` and one loop of block ``C`` to ``threadIdx.x``.

    The second loop of ``B`` and the second loop of ``C`` are bound. The
    scheduled function is compared against
    ``element_wise_compute_at_split_j0_j1o_bound``.
    """
    sch = tir.Schedule(element_wise_compute_at_split, debug_mask="all")
    # Two loops are expected around "B" and three around "C".
    _, loop_of_b = sch.get_loops(sch.get_block("B"))
    _, loop_of_c, _ = sch.get_loops(sch.get_block("C"))
    sch.bind(loop_of_b, "threadIdx.x")
    sch.bind(loop_of_c, "threadIdx.x")
    scheduled = sch.mod["main"]
    tvm.ir.assert_structural_equal(scheduled, element_wise_compute_at_split_j0_j1o_bound)
    verify_trace_roundtrip(sch, mod=element_wise_compute_at_split)
def test_read_out_of_bound():
    """Move block ``B`` under the single loop of block ``C`` via ``compute_at``.

    The input is ``read_out_of_bound``; the scheduled function is compared
    against ``read_out_of_bound_after_compute_at``.
    """
    schedule = tir.Schedule(read_out_of_bound, debug_mask="all")
    producer = schedule.get_block("B")
    consumer = schedule.get_block("C")
    # Exactly one loop is expected around the consumer.
    (target_loop,) = schedule.get_loops(consumer)
    schedule.compute_at(producer, target_loop)
    scheduled = schedule.mod["main"]
    tvm.ir.assert_structural_equal(read_out_of_bound_after_compute_at, scheduled)
    verify_trace_roundtrip(sch=schedule, mod=read_out_of_bound)
def test_compute_inline_as_dce():
    """Inline block ``B`` of ``elementwise_standalone`` with ``compute_inline``.

    The scheduled function is compared against ``elementwise_standalone_dce``.
    The handle for block ``C``, obtained before the transformation, must still
    resolve to a block whose ``name_hint`` is ``"C"``.
    """
    schedule = tir.Schedule(elementwise_standalone, debug_mask="all")
    # Both handles are fetched before the transformation is applied.
    inlined_block = schedule.get_block("B")
    kept_block = schedule.get_block("C")
    schedule.compute_inline(inlined_block)
    scheduled = schedule.mod["main"]
    tvm.ir.assert_structural_equal(elementwise_standalone_dce, scheduled)
    assert schedule.get(kept_block).name_hint == "C"
    verify_trace_roundtrip(sch=schedule, mod=elementwise_standalone)
def test_reverse_compute_at_factorized():
    """Move block ``B`` under the second loop of ``B_rf`` via ``reverse_compute_at``.

    Unit loops are not preserved. The scheduled function is compared against
    ``factorized_after_reverse_compute_at`` and the trace is replayed on
    ``factorized``.
    """
    schedule = tir.Schedule(factorized, debug_mask="all")
    moved_block = schedule.get_block("B")
    rf_block = schedule.get_block("B_rf")
    # Exactly four loops are expected; the second one is the target.
    _, target_loop, _, _ = schedule.get_loops(rf_block)
    schedule.reverse_compute_at(moved_block, target_loop, preserve_unit_loops=False)
    scheduled = schedule.mod["main"]
    tvm.ir.assert_structural_equal(factorized_after_reverse_compute_at, scheduled)
    verify_trace_roundtrip(sch=schedule, mod=factorized)
def test_compute_at_two_elementwise():
    """Move block ``B`` under the first loop of block ``C`` via ``compute_at``.

    Unit loops are preserved. The scheduled function is compared against
    ``two_elementwise_after_compute_at`` and the trace is replayed on
    ``two_elementwise``.
    """
    schedule = tir.Schedule(two_elementwise, debug_mask="all")
    producer = schedule.get_block("B")
    consumer = schedule.get_block("C")
    # Exactly two loops are expected; the first one is the target.
    target_loop, _ = schedule.get_loops(consumer)
    schedule.compute_at(producer, target_loop, preserve_unit_loops=True)
    scheduled = schedule.mod["main"]
    tvm.ir.assert_structural_equal(two_elementwise_after_compute_at, scheduled)
    verify_trace_roundtrip(sch=schedule, mod=two_elementwise)
def test_reverse_compute_inline_affine_load_unit_iter_simplified(use_block_name):
    """Reverse-inline block ``C``, given either by name or by handle.

    The block is passed as the name ``"C"`` when ``use_block_name`` is truthy,
    otherwise as the handle returned by ``get_block``. The scheduled function
    is compared against
    ``elementwise_reverse_affine_load_unit_iter_simplified_inlined``.
    """
    workload = elementwise_reverse_affine_load_unit_iter_simplified
    schedule = tir.Schedule(workload, debug_mask="all")
    if use_block_name:
        target = "C"
    else:
        target = schedule.get_block("C")
    schedule.reverse_compute_inline(target)
    scheduled = schedule.mod["main"]
    tvm.ir.assert_structural_equal(
        elementwise_reverse_affine_load_unit_iter_simplified_inlined,
        scheduled,
    )
    verify_trace_roundtrip(sch=schedule, mod=workload)
def test_compute_inline_elementwise(use_block_name):
    """Inline block ``B`` of ``elementwise``, given either by name or by handle.

    The block is passed as the name ``"B"`` when ``use_block_name`` is truthy,
    otherwise as a handle. The scheduled function is compared against
    ``elementwise_inlined``. The handle for block ``C``, obtained before the
    transformation, must still resolve to a block whose ``name_hint`` is
    ``"C"``.
    """
    schedule = tir.Schedule(elementwise, debug_mask="all")
    if use_block_name:
        inlined_block = "B"
    else:
        inlined_block = schedule.get_block("B")
    kept_block = schedule.get_block("C")
    schedule.compute_inline(inlined_block)
    scheduled = schedule.mod["main"]
    tvm.ir.assert_structural_equal(elementwise_inlined, scheduled)
    assert schedule.get(kept_block).name_hint == "C"
    verify_trace_roundtrip(sch=schedule, mod=elementwise)
def test_reduction_rfactor_square_sum_square_root():
    """Apply ``rfactor`` (factor axis 0) to the third loop around block ``C``.

    The scheduled function is compared against
    ``square_sum_square_root_rfactor``. The block returned by ``rfactor`` must
    be the same object as the one looked up under the name ``"C_rf"``.
    """
    sch = tir.Schedule(transformed_square_sum_square_root, debug_mask="all")
    # Exactly three loops are expected; only the third one is factored.
    _, _, target_loop = sch.get_loops(sch.get_block("C"))
    returned_block = sch.rfactor(target_loop, 0)
    scheduled = sch.mod["main"]
    tvm.ir.assert_structural_equal(scheduled, square_sum_square_root_rfactor)
    looked_up = sch.get(sch.get_block("C_rf"))
    assert sch.get(returned_block).same_as(looked_up)
    verify_trace_roundtrip(sch, mod=transformed_square_sum_square_root)