def test_block_in_opaque_block():
    """Blocks B, C, E, F and root of ``block_in_opaque_block`` have every cached flag set."""
    state = tir.ScheduleState(block_in_opaque_block, debug_mask="all")
    all_set = CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=True,
    )
    for block_name in ("B", "C", "E", "F", "root"):
        # pylint: disable=protected-access
        assert state._get_cached_flags(_get_block(state, block_name)) == all_set
def test_non_perfect_tiling_cache():
    """Check cached flags for ``non_perfect_tiling_cache``.

    The "cache" block is expected to have a non-affine binding; "compute" has
    every flag set.
    """
    state = tir.ScheduleState(non_perfect_tiling_cache, debug_mask="all")
    expectations = (
        (
            "cache",
            CachedFlags(
                affine_binding=False,
                region_cover=True,
                stage_pipeline=True,
            ),
        ),
        (
            "compute",
            CachedFlags(
                affine_binding=True,
                region_cover=True,
                stage_pipeline=True,
            ),
        ),
    )
    for block_name, flags in expectations:
        # pylint: disable=protected-access
        assert state._get_cached_flags(_get_block(state, block_name)) == flags
def test_uncovered_producer_region():
    """The "consumer" block of ``uncovered_producer_region`` is expected to lack region cover."""
    state = tir.ScheduleState(uncovered_producer_region, debug_mask="all")
    consumer = _get_block(state, "consumer")
    expected = CachedFlags(
        affine_binding=True,
        region_cover=False,
        stage_pipeline=True,
    )
    # pylint: disable=protected-access
    actual = state._get_cached_flags(consumer)
    assert actual == expected
def test_warp_memory_negative():
    """Check cached flags for ``warp_memory_negative``.

    Expected: root is not a stage pipeline, B has every flag set, and C lacks
    region cover.
    """
    state = tir.ScheduleState(warp_memory_negative, debug_mask="all")
    expectations = (
        (
            "root",
            CachedFlags(
                affine_binding=True,
                region_cover=True,
                stage_pipeline=False,
            ),
        ),
        (
            "B",
            CachedFlags(
                affine_binding=True,
                region_cover=True,
                stage_pipeline=True,
            ),
        ),
        (
            "C",
            CachedFlags(
                affine_binding=True,
                region_cover=False,
                stage_pipeline=True,
            ),
        ),
    )
    for block_name, flags in expectations:
        # pylint: disable=protected-access
        assert state._get_cached_flags(_get_block(state, block_name)) == flags
def test_equal_ranked_threads():
    """Blocks root, B and C of ``equal_ranked_threads`` have every cached flag set."""
    state = tir.ScheduleState(equal_ranked_threads, debug_mask="all")
    all_set = CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=True,
    )
    for block_name in ("root", "B", "C"):
        # pylint: disable=protected-access
        assert state._get_cached_flags(_get_block(state, block_name)) == all_set
def test_elementwise_affine_producer():
    """Blocks root, B and C of ``elementwise_affine_producer`` have every cached flag set."""
    state = tir.ScheduleState(elementwise_affine_producer, debug_mask="all")
    all_set = CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=True,
    )
    for block_name in ("root", "B", "C"):
        # pylint: disable=protected-access
        assert state._get_cached_flags(_get_block(state, block_name)) == all_set
def test_loop_carried_dependency():
    """Check cached flags for ``loop_carried_dependency``.

    Expected: B has every flag set, C lacks region cover, and root is not a
    stage pipeline.
    """
    state = tir.ScheduleState(loop_carried_dependency, debug_mask="all")
    expectations = (
        (
            "B",
            CachedFlags(
                affine_binding=True,
                region_cover=True,
                stage_pipeline=True,
            ),
        ),
        (
            "C",
            CachedFlags(
                affine_binding=True,
                region_cover=False,
                stage_pipeline=True,
            ),
        ),
        (
            "root",
            CachedFlags(
                affine_binding=True,
                region_cover=True,
                stage_pipeline=False,
            ),
        ),
    )
    for block_name, flags in expectations:
        # pylint: disable=protected-access
        assert state._get_cached_flags(_get_block(state, block_name)) == flags
def test_matmul():
    """Blocks init, update and root of ``matmul`` have every cached flag set."""
    state = tir.ScheduleState(matmul, debug_mask="all")
    all_set = CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=True,
    )
    for block_name in ("init", "update", "root"):
        # pylint: disable=protected-access
        assert state._get_cached_flags(_get_block(state, block_name)) == all_set
def test_thread_binding():
    """Blocks root, B and C of ``bound_to_thread`` have every cached flag set."""
    # Use the same ``debug_mask="all"`` keyword as every other test in this
    # file; the previous ``debug_mode=True`` is not the keyword the other
    # ScheduleState constructions use.
    s = tir.ScheduleState(bound_to_thread, debug_mask="all")
    # pylint: disable=protected-access
    assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=True,
    )
    assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=True,
    )
    assert s._get_cached_flags(_get_block(s, "C")) == CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=True,
    )
def test_multi_producer_consumer():
    """Blocks A_0, A_1, B_0 and B_1 of ``multi_producer_consumer`` have every cached flag set."""
    state = tir.ScheduleState(multi_producer_consumer, debug_mask="all")
    all_set = CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=True,
    )
    for block_name in ("A_0", "A_1", "B_0", "B_1"):
        # pylint: disable=protected-access
        assert state._get_cached_flags(_get_block(state, block_name)) == all_set
def test_concatenate_multi_producer_covered():  # pylint: disable=invalid-name
    """Blocks A_0, A_1, B and root of ``concatenate_multi_producer`` have every cached flag set."""
    state = tir.ScheduleState(concatenate_multi_producer, debug_mask="all")
    all_set = CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=True,
    )
    for block_name in ("A_0", "A_1", "B", "root"):
        # pylint: disable=protected-access
        assert state._get_cached_flags(_get_block(state, block_name)) == all_set
def test_subblock_uncovered():
    """Check cached flags for ``elementwise_subblock_uncovered``.

    Expected: root is not a stage pipeline, B and B_sub have every flag set,
    and C lacks region cover.
    """
    # Use the same ``debug_mask="all"`` keyword as every other test in this
    # file; the previous ``debug_mode=True`` is not the keyword the other
    # ScheduleState constructions use.
    s = tir.ScheduleState(elementwise_subblock_uncovered, debug_mask="all")
    # pylint: disable=protected-access
    assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=False,
    )
    assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=True,
    )
    assert s._get_cached_flags(_get_block(s, "B_sub")) == CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=True,
    )
    assert s._get_cached_flags(_get_block(s, "C")) == CachedFlags(
        affine_binding=True,
        region_cover=False,
        stage_pipeline=True,
    )