def test_block_in_opaque_block():
    """Check the cached flags of every named block in ``block_in_opaque_block``.

    Blocks ``B``, ``C``, ``E``, ``F`` and ``root`` are each expected to report
    ``affine_binding``, ``region_cover`` and ``stage_pipeline`` as ``True``.
    """
    state = tir.ScheduleState(block_in_opaque_block, debug_mask="all")
    # pylint: disable=protected-access
    for block_name in ("B", "C", "E", "F", "root"):
        actual = state._get_cached_flags(_get_block(state, block_name))
        assert actual == CachedFlags(
            affine_binding=True,
            region_cover=True,
            stage_pipeline=True,
        )
def replace_ir_builder_module(deep_copy=False, realize=False):
    """Build a two-function schedule state and a replacement target statement.

    ``elementwise`` is re-parsed from its TVMScript text twice and the two
    results are stored in an ``IRModule`` under ``"main"`` and ``"other"``.

    Parameters
    ----------
    deep_copy : bool
        When true, ``target`` is passed through its own
        ``__getstate__`` / ``__setstate__`` pair before being returned.
    realize : bool
        When true, the target block is wrapped in a ``tvm.tir.BlockRealize``
        with no iter values and a ``True`` predicate.

    Returns
    -------
    tuple
        ``(state, target)``: the ``tir.ScheduleState`` and the target statement.
    """
    main_func = tvm.script.from_source(elementwise.script())
    other_func = tvm.script.from_source(elementwise.script())
    module = IRModule(functions={"main": main_func, "other": other_func})
    state = tir.ScheduleState(module, debug_mask="all")

    # The new block wraps element 1 of the body of the block under "main".
    wrapped_body = state.mod["main"].body.block.body[1]
    target = tvm.tir.Block(
        iter_vars=[],
        reads=[],
        writes=[],
        name_hint="target",
        body=wrapped_body,
        init=None,
        alloc_buffers=None,
        match_buffers=None,
        annotations=None,
    )
    if realize:
        target = tvm.tir.BlockRealize(iter_values=[], predicate=True, block=target)
    if deep_copy:
        target.__setstate__(target.__getstate__())
    gc.collect()
    return state, target
def test_uncovered_producer_region():
    """Check the cached flags of block ``consumer`` in ``uncovered_producer_region``.

    The block is expected to report ``region_cover=False`` while
    ``affine_binding`` and ``stage_pipeline`` are ``True``.
    """
    state = tir.ScheduleState(uncovered_producer_region, debug_mask="all")
    # pylint: disable=protected-access
    consumer_flags = state._get_cached_flags(_get_block(state, "consumer"))
    assert consumer_flags == CachedFlags(
        affine_binding=True,
        region_cover=False,
        stage_pipeline=True,
    )
def test_replace_block_remap():
    """Check that ``replace`` with a block remap keeps the original sref valid.

    A block taken from ``matmul`` replaces a block of ``elementwise``; the
    replaced statement is mapped to the new one through the remap dictionary.
    Afterwards the sref obtained before the replacement must hash the same as a
    freshly looked-up sref, and its statement must be structurally equal to the
    target.
    """
    func = elementwise
    # Use the ``debug_mask`` keyword, as the other ScheduleState tests in this
    # file do, instead of the older ``debug_mode`` keyword.
    s = tir.ScheduleState(func, debug_mask="all")
    # The target stmt
    target = matmul.body.block.body.body.body[0].block
    sref = s.get_sref(s.mod["main"].body.block.body[0].body.body.block)
    s.replace(sref, target, {sref.stmt: target})
    sref_new = s.get_sref(s.mod["main"].body.block.body[0].body.body.block)
    # Check the original sref has been remapped
    assert sref.__hash__() == sref_new.__hash__()
    tvm.ir.assert_structural_equal(sref.stmt, target)
def test_non_perfect_tiling_cache():
    """Check the cached flags of ``cache`` and ``compute`` in ``non_perfect_tiling_cache``.

    ``cache`` is expected to report ``affine_binding=False``; every other flag
    of both blocks is expected to be ``True``.
    """
    state = tir.ScheduleState(non_perfect_tiling_cache, debug_mask="all")
    # pylint: disable=protected-access
    cache_flags = state._get_cached_flags(_get_block(state, "cache"))
    assert cache_flags == CachedFlags(
        affine_binding=False,
        region_cover=True,
        stage_pipeline=True,
    )
    compute_flags = state._get_cached_flags(_get_block(state, "compute"))
    assert compute_flags == CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=True,
    )
def test_war_dependency():
    """Check the single WAR dependency from block ``C`` to block ``B``.

    The dependency is looked up in the scope of ``root`` twice, once by source
    (``C``) and once by destination (``B``). Each query must yield exactly one
    dependency with ``src`` being ``C``, ``dst`` being ``B`` and kind
    ``DepKind.WAR``.
    """
    # Use the ``debug_mask`` keyword, as the other ScheduleState tests in this
    # file do, instead of the older ``debug_mode`` keyword.
    s = tir.ScheduleState(war_dependency, debug_mask="all")
    root = _get_block(s, "root")
    block_c = _get_block(s, "C")
    block_b = _get_block(s, "B")
    # Check get_deps_by_src
    (dep,) = s.get_block_scope(root).get_deps_by_src(block_c)
    assert dep.src.same_as(block_c)
    assert dep.dst.same_as(block_b)
    assert dep.kind == DepKind.WAR
    # Check get_deps_by_dst
    (dep,) = s.get_block_scope(root).get_deps_by_dst(block_b)
    assert dep.src.same_as(block_c)
    assert dep.dst.same_as(block_b)
    assert dep.kind == DepKind.WAR
def test_elementwise_dependency():
    """Check the single RAW dependency from block ``B`` to block ``C``.

    The scope of ``root`` is queried by source (``B``) and by destination
    (``C``); each query must yield exactly one dependency going from ``B`` to
    ``C`` with kind ``DepKind.RAW``.
    """
    state = tir.ScheduleState(elementwise, debug_mask="all")
    root_sref = _get_block(state, "root")
    producer = _get_block(state, "B")
    consumer = _get_block(state, "C")

    def _assert_raw_b_to_c(deps):
        # Unpacking into a 1-tuple also enforces that there is exactly one dep.
        (only_dep,) = deps
        assert only_dep.src.same_as(producer)
        assert only_dep.dst.same_as(consumer)
        assert only_dep.kind == DepKind.RAW

    # Check get_deps_by_src
    _assert_raw_b_to_c(state.get_block_scope(root_sref).get_deps_by_src(producer))
    # Check get_deps_by_dst
    _assert_raw_b_to_c(state.get_block_scope(root_sref).get_deps_by_dst(consumer))
def test_warp_memory_negative():
    """Check the cached flags of ``root``, ``B`` and ``C`` in ``warp_memory_negative``.

    ``root`` is expected to report ``stage_pipeline=False`` and ``C`` to report
    ``region_cover=False``; every other flag is expected to be ``True``.
    """
    state = tir.ScheduleState(warp_memory_negative, debug_mask="all")
    # pylint: disable=protected-access
    root_flags = state._get_cached_flags(_get_block(state, "root"))
    assert root_flags == CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=False,
    )
    b_flags = state._get_cached_flags(_get_block(state, "B"))
    assert b_flags == CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=True,
    )
    c_flags = state._get_cached_flags(_get_block(state, "C"))
    assert c_flags == CachedFlags(
        affine_binding=True,
        region_cover=False,
        stage_pipeline=True,
    )
def test_equal_ranked_threads():
    """Check the cached flags of ``root``, ``B`` and ``C`` in ``equal_ranked_threads``.

    All three blocks are expected to report ``affine_binding``,
    ``region_cover`` and ``stage_pipeline`` as ``True``.
    """
    state = tir.ScheduleState(equal_ranked_threads, debug_mask="all")
    # pylint: disable=protected-access
    for block_name in ("root", "B", "C"):
        actual = state._get_cached_flags(_get_block(state, block_name))
        assert actual == CachedFlags(
            affine_binding=True,
            region_cover=True,
            stage_pipeline=True,
        )
def test_elementwise_affine_producer():
    """Check the cached flags of ``root``, ``B`` and ``C`` in ``elementwise_affine_producer``.

    All three blocks are expected to report ``affine_binding``,
    ``region_cover`` and ``stage_pipeline`` as ``True``.
    """
    state = tir.ScheduleState(elementwise_affine_producer, debug_mask="all")
    # pylint: disable=protected-access
    for block_name in ("root", "B", "C"):
        actual = state._get_cached_flags(_get_block(state, block_name))
        assert actual == CachedFlags(
            affine_binding=True,
            region_cover=True,
            stage_pipeline=True,
        )
def test_loop_carried_dependency():
    """Check the cached flags of ``B``, ``C`` and ``root`` in ``loop_carried_dependency``.

    ``C`` is expected to report ``region_cover=False`` and ``root`` to report
    ``stage_pipeline=False``; every other flag is expected to be ``True``.
    """
    state = tir.ScheduleState(loop_carried_dependency, debug_mask="all")
    # pylint: disable=protected-access
    b_flags = state._get_cached_flags(_get_block(state, "B"))
    assert b_flags == CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=True,
    )
    c_flags = state._get_cached_flags(_get_block(state, "C"))
    assert c_flags == CachedFlags(
        affine_binding=True,
        region_cover=False,
        stage_pipeline=True,
    )
    root_flags = state._get_cached_flags(_get_block(state, "root"))
    assert root_flags == CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=False,
    )
def test_matmul():
    """Check the cached flags of ``init``, ``update`` and ``root`` in ``matmul``.

    All three blocks are expected to report ``affine_binding``,
    ``region_cover`` and ``stage_pipeline`` as ``True``.
    """
    state = tir.ScheduleState(matmul, debug_mask="all")
    # pylint: disable=protected-access
    for block_name in ("init", "update", "root"):
        actual = state._get_cached_flags(_get_block(state, block_name))
        assert actual == CachedFlags(
            affine_binding=True,
            region_cover=True,
            stage_pipeline=True,
        )
def test_thread_binding():
    """Check the cached flags of ``root``, ``B`` and ``C`` in ``bound_to_thread``.

    All three blocks are expected to report ``affine_binding``,
    ``region_cover`` and ``stage_pipeline`` as ``True``.
    """
    # Use the ``debug_mask`` keyword, as the other ScheduleState tests in this
    # file do, instead of the older ``debug_mode`` keyword.
    s = tir.ScheduleState(bound_to_thread, debug_mask="all")
    # pylint: disable=protected-access
    assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=True,
    )
    assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=True,
    )
    assert s._get_cached_flags(_get_block(s, "C")) == CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=True,
    )
def test_multi_producer_consumer():
    """Check the cached flags of the four blocks in ``multi_producer_consumer``.

    Blocks ``A_0``, ``A_1``, ``B_0`` and ``B_1`` are each expected to report
    ``affine_binding``, ``region_cover`` and ``stage_pipeline`` as ``True``.
    """
    state = tir.ScheduleState(multi_producer_consumer, debug_mask="all")
    # pylint: disable=protected-access
    for block_name in ("A_0", "A_1", "B_0", "B_1"):
        actual = state._get_cached_flags(_get_block(state, block_name))
        assert actual == CachedFlags(
            affine_binding=True,
            region_cover=True,
            stage_pipeline=True,
        )
def test_concatenate_multi_producer_covered():  # pylint: disable=invalid-name
    """Check the cached flags of every named block in ``concatenate_multi_producer``.

    Blocks ``A_0``, ``A_1``, ``B`` and ``root`` are each expected to report
    ``affine_binding``, ``region_cover`` and ``stage_pipeline`` as ``True``.
    """
    state = tir.ScheduleState(concatenate_multi_producer, debug_mask="all")
    # pylint: disable=protected-access
    for block_name in ("A_0", "A_1", "B", "root"):
        actual = state._get_cached_flags(_get_block(state, block_name))
        assert actual == CachedFlags(
            affine_binding=True,
            region_cover=True,
            stage_pipeline=True,
        )
def test_matmul_dependency():
    """Check the two dependencies from block ``init`` to block ``update``.

    The scope of ``root`` is queried by source (``init``) and by destination
    (``update``). Each query must yield exactly two dependencies, both going
    from ``init`` to ``update``, one of kind ``DepKind.RAW`` and one of kind
    ``DepKind.WAW`` in either order.
    """
    state = tir.ScheduleState(matmul, debug_mask="all")
    root_sref = _get_block(state, "root")
    init_sref = _get_block(state, "init")
    update_sref = _get_block(state, "update")

    def _assert_init_to_update(deps):
        # Two-name unpacking also enforces that there are exactly two deps.
        first, second = deps
        assert first.src.same_as(init_sref)
        assert first.dst.same_as(update_sref)
        assert second.src.same_as(init_sref)
        assert second.dst.same_as(update_sref)
        raw_then_waw = first.kind == DepKind.RAW and second.kind == DepKind.WAW
        waw_then_raw = first.kind == DepKind.WAW and second.kind == DepKind.RAW
        assert raw_then_waw or waw_then_raw

    # Check get_deps_by_src
    _assert_init_to_update(state.get_block_scope(root_sref).get_deps_by_src(init_sref))
    # Check get_deps_by_dst
    _assert_init_to_update(state.get_block_scope(root_sref).get_deps_by_dst(update_sref))
def test_subblock_uncovered():
    """Check the cached flags of the blocks in ``elementwise_subblock_uncovered``.

    ``root`` is expected to report ``stage_pipeline=False`` and ``C`` to report
    ``region_cover=False``; ``B`` and ``B_sub`` are expected to report all
    three flags as ``True``.
    """
    # Use the ``debug_mask`` keyword, as the other ScheduleState tests in this
    # file do, instead of the older ``debug_mode`` keyword.
    s = tir.ScheduleState(elementwise_subblock_uncovered, debug_mask="all")
    # pylint: disable=protected-access
    assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=False,
    )
    assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=True,
    )
    assert s._get_cached_flags(_get_block(s, "B_sub")) == CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=True,
    )
    assert s._get_cached_flags(_get_block(s, "C")) == CachedFlags(
        affine_binding=True,
        region_cover=False,
        stage_pipeline=True,
    )
def replace_ir_builder_with_opaque():
    """Create a schedule state for a re-parsed copy of ``block_in_opaque_block``.

    The function is printed to TVMScript text and parsed back with
    ``tvm.script.from_source`` before the state is constructed.

    This helper used to be defined twice in a row, with the later definition
    silently shadowing the earlier one and using the older
    ``tvm.script.asscript`` / ``debug_mode`` spellings. It is kept as a single
    definition using the ``.script()`` / ``debug_mask`` spellings found in the
    rest of the file.

    Returns
    -------
    tir.ScheduleState
        The state built with ``debug_mask="all"``.
    """
    func = tvm.script.from_source(block_in_opaque_block.script())
    s = tir.ScheduleState(func, debug_mask="all")
    gc.collect()
    return s