# Example no. 1
# 0
def test_block_in_opaque_block():
    """Check cached flags for the ``block_in_opaque_block`` fixture.

    Blocks B, C, E, F and the root block are all expected to report an
    affine binding, a covered region and a stage pipeline.
    """
    state = tir.ScheduleState(block_in_opaque_block, debug_mask="all")
    all_set = CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=True,
    )
    # pylint: disable=protected-access
    for block_name in ("B", "C", "E", "F", "root"):
        flags = state._get_cached_flags(_get_block(state, block_name))
        assert flags == all_set
# Example no. 2
# 0
def replace_ir_builder_module(deep_copy: bool = False, realize: bool = False):
    """Build a two-function IRModule schedule state plus a replacement stmt.

    The module holds two independently re-parsed copies of ``elementwise``
    under the names "main" and "other". The returned ``target`` is a new
    ``tir.Block`` named "target" whose body is the second statement
    (``body[1]``) of the root block of ``s.mod["main"]``.

    Args:
        deep_copy: if True, ``target`` is re-initialised from its own
            ``__getstate__()`` via ``__setstate__`` (presumably to obtain a
            copy that does not share the original node — confirm).
        realize: if True, ``target`` is wrapped in a ``tir.BlockRealize``
            with no iter values and a constant ``True`` predicate.

    Returns:
        A ``(s, target)`` tuple: the ``tir.ScheduleState`` built with
        ``debug_mask="all"`` and the (optionally realized) target stmt.
    """
    # Re-parse from the script text so each call gets fresh IR objects
    # rather than sharing the module-level ``elementwise`` fixture.
    new_func = tvm.script.from_source(elementwise.script())
    other_func = tvm.script.from_source(elementwise.script())
    mod = IRModule(functions={"main": new_func, "other": other_func})
    s = tir.ScheduleState(mod, debug_mask="all")
    target = tvm.tir.Block(
        iter_vars=[],
        reads=[],
        writes=[],
        name_hint="target",
        body=s.mod["main"].body.block.body[1],
        init=None,
        alloc_buffers=None,
        match_buffers=None,
        annotations=None,
    )
    if realize:
        target = tvm.tir.BlockRealize(
            iter_values=[],
            predicate=True,
            block=target,
        )
    if deep_copy:
        target.__setstate__(target.__getstate__())
    # NOTE(review): presumably run so that only the references returned to the
    # caller remain alive (relevant to reference-count-based copy-on-write
    # checks in callers) — confirm.
    gc.collect()
    return s, target
def test_uncovered_producer_region():
    """The ``consumer`` block of ``uncovered_producer_region`` lacks region cover.

    Its binding is affine and it is part of a stage pipeline, but
    ``region_cover`` is expected to be False.
    """
    state = tir.ScheduleState(uncovered_producer_region, debug_mask="all")
    consumer = _get_block(state, "consumer")
    expected = CachedFlags(
        affine_binding=True,
        region_cover=False,
        stage_pipeline=True,
    )
    # pylint: disable=protected-access
    assert state._get_cached_flags(consumer) == expected
# Example no. 4
# 0
def test_replace_block_remap():
    """Replacing a block with a reuse map keeps the original sref, remapped.

    ``s.replace`` is given ``{sref.stmt: target}``, which maps the old block
    to ``target``. Afterwards, looking up the block at the same position must
    yield the same sref object, and that sref must now point at ``target``.
    """
    func = elementwise
    # Use ``debug_mask`` like every other ScheduleState in this file; the
    # older ``debug_mode`` keyword belongs to a different API revision.
    s = tir.ScheduleState(func, debug_mask="all")
    # The target stmt
    target = matmul.body.block.body.body.body[0].block
    sref = s.get_sref(s.mod["main"].body.block.body[0].body.body.block)
    s.replace(sref, target, {sref.stmt: target})
    # Re-traverse the updated module: replace() installs a new IR tree.
    sref_new = s.get_sref(s.mod["main"].body.block.body[0].body.body.block)
    # Check the original sref has been remapped
    assert sref.__hash__() == sref_new.__hash__()
    tvm.ir.assert_structural_equal(sref.stmt, target)
def test_non_perfect_tiling_cache():
    """Check cached flags for the ``non_perfect_tiling_cache`` fixture.

    The ``cache`` block is expected to have a non-affine binding; the
    ``compute`` block has every flag set.
    """
    state = tir.ScheduleState(non_perfect_tiling_cache, debug_mask="all")
    expected_by_block = {
        "cache": CachedFlags(
            affine_binding=False, region_cover=True, stage_pipeline=True
        ),
        "compute": CachedFlags(
            affine_binding=True, region_cover=True, stage_pipeline=True
        ),
    }
    # pylint: disable=protected-access
    for block_name, expected in expected_by_block.items():
        flags = state._get_cached_flags(_get_block(state, block_name))
        assert flags == expected
def test_war_dependency():
    """The root scope of ``war_dependency`` records one WAR edge from C to B.

    The edge is checked from both sides: as the only dependency whose source
    is C, and as the only dependency whose destination is B.
    """
    # Use ``debug_mask`` like every other ScheduleState in this file; the
    # older ``debug_mode`` keyword belongs to a different API revision.
    s = tir.ScheduleState(war_dependency, debug_mask="all")
    root = _get_block(s, "root")
    block_c = _get_block(s, "C")
    block_b = _get_block(s, "B")
    # Check get_deps_by_src
    (dep,) = s.get_block_scope(root).get_deps_by_src(block_c)
    assert dep.src.same_as(block_c)
    assert dep.dst.same_as(block_b)
    assert dep.kind == DepKind.WAR
    # Check get_deps_by_dst
    (dep,) = s.get_block_scope(root).get_deps_by_dst(block_b)
    assert dep.src.same_as(block_c)
    assert dep.dst.same_as(block_b)
    assert dep.kind == DepKind.WAR
# Example no. 7
# 0
def test_elementwise_dependency():
    """The root scope of ``elementwise`` records one RAW edge from B to C.

    The edge must be the sole result both when querying by source (B) and
    when querying by destination (C).
    """
    state = tir.ScheduleState(elementwise, debug_mask="all")
    root = _get_block(state, "root")
    producer = _get_block(state, "B")
    consumer = _get_block(state, "C")

    def check_single_raw(deps):
        # Exactly one dependency, going producer -> consumer, of kind RAW.
        (only_dep,) = deps
        assert only_dep.src.same_as(producer)
        assert only_dep.dst.same_as(consumer)
        assert only_dep.kind == DepKind.RAW

    check_single_raw(state.get_block_scope(root).get_deps_by_src(producer))
    check_single_raw(state.get_block_scope(root).get_deps_by_dst(consumer))
def test_warp_memory_negative():
    """Check cached flags for the ``warp_memory_negative`` fixture.

    The root block is expected not to be a stage pipeline, and block C is
    expected not to have its region covered; block B has every flag set.
    """
    state = tir.ScheduleState(warp_memory_negative, debug_mask="all")
    expected_by_block = {
        "root": CachedFlags(
            affine_binding=True, region_cover=True, stage_pipeline=False
        ),
        "B": CachedFlags(
            affine_binding=True, region_cover=True, stage_pipeline=True
        ),
        "C": CachedFlags(
            affine_binding=True, region_cover=False, stage_pipeline=True
        ),
    }
    # pylint: disable=protected-access
    for block_name, expected in expected_by_block.items():
        flags = state._get_cached_flags(_get_block(state, block_name))
        assert flags == expected
def test_equal_ranked_threads():
    """Every checked block of ``equal_ranked_threads`` has all flags set."""
    state = tir.ScheduleState(equal_ranked_threads, debug_mask="all")
    all_set = CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=True,
    )
    # pylint: disable=protected-access
    for block_name in ("root", "B", "C"):
        flags = state._get_cached_flags(_get_block(state, block_name))
        assert flags == all_set
def test_elementwise_affine_producer():
    """Every checked block of ``elementwise_affine_producer`` has all flags set."""
    state = tir.ScheduleState(elementwise_affine_producer, debug_mask="all")
    all_set = CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=True,
    )
    # pylint: disable=protected-access
    for block_name in ("root", "B", "C"):
        flags = state._get_cached_flags(_get_block(state, block_name))
        assert flags == all_set
def test_loop_carried_dependency():
    """Check cached flags for the ``loop_carried_dependency`` fixture.

    Block B has every flag set; block C is expected to lack region cover,
    and the root block is expected not to be a stage pipeline.
    """
    state = tir.ScheduleState(loop_carried_dependency, debug_mask="all")
    expected_by_block = {
        "B": CachedFlags(
            affine_binding=True, region_cover=True, stage_pipeline=True
        ),
        "C": CachedFlags(
            affine_binding=True, region_cover=False, stage_pipeline=True
        ),
        "root": CachedFlags(
            affine_binding=True, region_cover=True, stage_pipeline=False
        ),
    }
    # pylint: disable=protected-access
    for block_name, expected in expected_by_block.items():
        flags = state._get_cached_flags(_get_block(state, block_name))
        assert flags == expected
def test_matmul():
    """The ``init``, ``update`` and root blocks of ``matmul`` have all flags set."""
    state = tir.ScheduleState(matmul, debug_mask="all")
    all_set = CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=True,
    )
    # pylint: disable=protected-access
    for block_name in ("init", "update", "root"):
        flags = state._get_cached_flags(_get_block(state, block_name))
        assert flags == all_set
def test_thread_binding():
    """Every checked block of ``bound_to_thread`` has all cached flags set."""
    # Use ``debug_mask`` like every other ScheduleState in this file; the
    # older ``debug_mode`` keyword belongs to a different API revision.
    s = tir.ScheduleState(bound_to_thread, debug_mask="all")
    # pylint: disable=protected-access
    assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=True,
    )
    assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=True,
    )
    assert s._get_cached_flags(_get_block(s, "C")) == CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=True,
    )
def test_multi_producer_consumer():
    """Blocks A_0, A_1, B_0 and B_1 of ``multi_producer_consumer`` have all flags set."""
    state = tir.ScheduleState(multi_producer_consumer, debug_mask="all")
    all_set = CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=True,
    )
    # pylint: disable=protected-access
    for block_name in ("A_0", "A_1", "B_0", "B_1"):
        flags = state._get_cached_flags(_get_block(state, block_name))
        assert flags == all_set
def test_concatenate_multi_producer_covered():  # pylint: disable=invalid-name
    """With two producers A_0 and A_1, block B's region is still covered.

    Blocks A_0, A_1, B and the root of ``concatenate_multi_producer`` are all
    expected to have every cached flag set.
    """
    state = tir.ScheduleState(concatenate_multi_producer, debug_mask="all")
    all_set = CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=True,
    )
    # pylint: disable=protected-access
    for block_name in ("A_0", "A_1", "B", "root"):
        flags = state._get_cached_flags(_get_block(state, block_name))
        assert flags == all_set
# Example no. 16
# 0
def test_matmul_dependency():
    """``init`` -> ``update`` in ``matmul`` has exactly one RAW and one WAW edge.

    Both edges must appear when querying by source (init) and by destination
    (update); their relative order in the result is not fixed.
    """
    state = tir.ScheduleState(matmul, debug_mask="all")
    root = _get_block(state, "root")
    init = _get_block(state, "init")
    update = _get_block(state, "update")

    def check_raw_and_waw(deps):
        first, second = deps
        for dep in (first, second):
            assert dep.src.same_as(init)
            assert dep.dst.same_as(update)
        # The two kinds may come back in either order.
        assert (first.kind, second.kind) in (
            (DepKind.RAW, DepKind.WAW),
            (DepKind.WAW, DepKind.RAW),
        )

    check_raw_and_waw(state.get_block_scope(root).get_deps_by_src(init))
    check_raw_and_waw(state.get_block_scope(root).get_deps_by_dst(update))
def test_subblock_uncovered():
    """Check cached flags for the ``elementwise_subblock_uncovered`` fixture.

    The root block is expected not to be a stage pipeline, and block C is
    expected to lack region cover; blocks B and B_sub have every flag set.
    """
    # Use ``debug_mask`` like every other ScheduleState in this file; the
    # older ``debug_mode`` keyword belongs to a different API revision.
    s = tir.ScheduleState(elementwise_subblock_uncovered, debug_mask="all")
    # pylint: disable=protected-access
    assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=False,
    )
    assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=True,
    )
    assert s._get_cached_flags(_get_block(s, "B_sub")) == CachedFlags(
        affine_binding=True,
        region_cover=True,
        stage_pipeline=True,
    )
    assert s._get_cached_flags(_get_block(s, "C")) == CachedFlags(
        affine_binding=True,
        region_cover=False,
        stage_pipeline=True,
    )
# Example no. 18
# 0
def replace_ir_builder_with_opaque():
    """Return a ScheduleState over a freshly re-parsed ``block_in_opaque_block``.

    The fixture is printed to TVMScript and parsed back, so the state owns
    IR objects that are not shared with the module-level fixture. The state
    is built with ``debug_mask="all"``.
    """
    func = tvm.script.from_source(block_in_opaque_block.script())
    s = tir.ScheduleState(func, debug_mask="all")
    # NOTE(review): presumably run so that only the references returned to the
    # caller remain alive (relevant to reference-count-based copy-on-write
    # checks in callers) — confirm.
    gc.collect()
    return s
# Example no. 19
# 0
def replace_ir_builder_with_opaque():
    """Return a ScheduleState over a freshly re-parsed ``block_in_opaque_block``.

    The fixture is printed to TVMScript and parsed back, so the state owns
    IR objects that are not shared with the module-level fixture.
    """
    # Use ``PrimFunc.script()`` and ``debug_mask`` to match the rest of this
    # file; ``tvm.script.asscript`` and ``debug_mode`` belong to an older API.
    func = tvm.script.from_source(block_in_opaque_block.script())
    s = tir.ScheduleState(func, debug_mask="all")
    # NOTE(review): presumably run so that only the references returned to the
    # caller remain alive before the caller inspects them — confirm.
    gc.collect()
    return s