示例#1
0
def test_shape_mapping():
    """Check the shapes ``map_shape`` reports for ``i -> [i // 4, i % 4]``."""

    def split_by_four(i):
        return [i // 4, i % 4]

    index_map = IndexMap.from_func(split_by_four)

    # Each pair is (shape handed to map_shape, shape expected back).  The
    # length-14 input is expected to give the same [4, 4] result as the
    # length-16 input.
    shape_cases = (
        ([4], [1, 4]),
        ([16], [4, 4]),
        ([14], [4, 4]),
    )
    for pre_shape, expected_post_shape in shape_cases:
        mapped_shape = index_map.map_shape(pre_shape)
        assert_structural_equal(mapped_shape, expected_post_shape)
示例#2
0
def test_index_mapping():
    """Check ``map_indices`` on single indices for ``i -> [i // 4, i % 4]``."""

    def split_by_four(i):
        return [i // 4, i % 4]

    index_map = IndexMap.from_func(split_by_four)

    # Each pair is (indices handed to map_indices, indices expected back).
    index_cases = (
        ([0], [0, 0]),
        ([3], [0, 3]),
        ([4], [1, 0]),
        ([42], [10, 2]),
    )
    for pre_indices, expected_post_indices in index_cases:
        mapped_indices = index_map.map_indices(pre_indices)
        assert_structural_equal(mapped_indices, expected_post_indices)
示例#3
0
def test_index_map_inverse_no_iter():
    """Check ``inverse`` of a 4-D map whose outputs ignore ``i0`` and ``i1``.

    The forward map is inverted over the shape ``[1, 1, 64, 64]`` and the
    result must be equivalent to the hand-written ``expected_inverse``.
    """

    def input_example(i0, i1, i2, i3):
        # Only i2 and i3 feed the outputs; i0 and i1 are accepted but unused.
        return (
            floordiv(i3, 32),
            floordiv(i2, 2),
            floormod(i2, 2),
            floormod(i3, 32),
        )

    def expected_inverse(i0, i1, i2, i3):
        # The first two outputs are constant zeros; the last two recombine
        # the quotient/remainder pairs produced by input_example.
        return (
            IntImm("int32", 0),
            IntImm("int32", 0),
            i2 + i1 * 2,
            i3 + i0 * 32,
        )

    forward_map = IndexMap.from_func(input_example)
    actual_inverse = forward_map.inverse([1, 1, 64, 64])
    reference_inverse = IndexMap.from_func(expected_inverse)
    assert reference_inverse.is_equivalent_to(actual_inverse)
def check_index_map(workload, block_name, intrin_name, expected_index_map):
    """Assert the auto-tensorize mapping of a block matches an expected map.

    Parameters
    ----------
    workload
        Passed to ``Schedule`` to build the schedule under test.
    block_name
        Name used to look the block up via ``Schedule.get_block``.
    intrin_name
        Name handed to ``TensorIntrin.get``; its ``desc`` is used as the
        description function.
    expected_index_map
        Callable handed to ``IndexMap.from_func`` to build the reference map.

    Raises
    ------
    AssertionError
        If the mapping info does not hold exactly one mapping, or that
        mapping is not equivalent to the reference map.
    """
    sch = Schedule(workload)
    mapping_info = get_auto_tensorize_mapping_info(
        sch,
        sch.get_block(block_name),
        TensorIntrin.get(intrin_name).desc,
    )
    assert len(mapping_info.mappings) == 1

    reference_map = IndexMap.from_func(expected_index_map)
    assert reference_map.is_equivalent_to(mapping_info.mappings[0])
示例#5
0
def test_nonsurjective_inverse(padding_test_case):
    """Check ``non_surjective_inverse`` against one padding test case.

    ``padding_test_case`` is indexed with the keys ``"forward"``,
    ``"pre_shape"``, ``"inverse"``, ``"post_shape"`` and ``"padding"``
    (presumably supplied by a pytest fixture defined elsewhere -- confirm).
    The test checks the returned inverse map, the mapped shape, and the
    returned padding predicate.
    """
    case = padding_test_case
    forward_map = IndexMap.from_func(case["forward"])

    inverse, padding_predicate = forward_map.non_surjective_inverse(
        case["pre_shape"])
    assert inverse.is_equivalent_to(IndexMap.from_func(case["inverse"]))

    tvm.ir.assert_structural_equal(
        forward_map.map_shape(case["pre_shape"]),
        case["post_shape"],
    )

    # analyzer.can_prove_equal is not usable here: it cannot simplify
    # expressions such as `(4*i+j >= 14) - (4*i+j >= 14)`.  Both predicates
    # are therefore simplified and compared structurally instead.
    analyzer = tvm.arith.Analyzer()
    expected_predicate = analyzer.simplify(
        case["padding"](*inverse.initial_indices))
    padding_predicate = analyzer.simplify(padding_predicate)
    tvm.ir.assert_structural_equal(padding_predicate, expected_predicate)
示例#6
0
def test_suggest_index_map_bijective():
    """Check ``suggest_index_map`` for a 1-D buffer accessed by two loops.

    The buffer has shape ``[8]`` and is indexed with
    ``floormod(j, 4) * 2 + i`` under loops ``i`` (extent 2) and ``j``
    (extent 32).  The suggestion must be equivalent to
    ``x -> [floormod(x, 2), floordiv(x, 2)]``.
    """
    i, j = _make_vars("i", "j")

    buffer = decl_buffer(shape=[8])
    access_indices = [floormod(j, 4) * 2 + i]
    loops = _make_loops(loop_vars=[i, j], extents=[2, 32])

    suggested_map = suggest_index_map(
        buffer=buffer,
        indices=access_indices,
        loops=loops,
        predicate=True,
    )

    def expected_mapping(x):
        return [floormod(x, 2), floordiv(x, 2)]

    reference_map = IndexMap.from_func(expected_mapping)
    assert suggested_map.is_equivalent_to(reference_map)
示例#7
0
def test_suggest_index_map_simple():
    """Check ``suggest_index_map`` for a 2-D buffer accessed by two loops.

    The buffer has shape ``[8, 256]`` and is accessed under loops ``i``
    (extent 32) and ``j`` (extent 64).  The suggestion must be equivalent to
    ``(x, y) -> [x // 4, y // 16, x % 4, y % 16]`` (floor semantics).
    """
    i, j = _make_vars("i", "j")

    buffer = decl_buffer(shape=[8, 256])
    row_index = floordiv(i, 16) * 4 + floordiv(j, 16)
    col_index = floormod(i, 16) * 16 + floormod(j, 16)
    loops = _make_loops(loop_vars=[i, j], extents=[32, 64])

    suggested_map = suggest_index_map(
        buffer=buffer,
        indices=[row_index, col_index],
        loops=loops,
        predicate=True,
    )

    def expected_mapping(x, y):
        return [
            floordiv(x, 4),
            floordiv(y, 16),
            floormod(x, 4),
            floormod(y, 16),
        ]

    reference_map = IndexMap.from_func(expected_mapping)
    assert suggested_map.is_equivalent_to(reference_map)
示例#8
0
def test_nonbijective_inverse_gives_error():
    """``inverse`` must raise ``tvm.TVMError`` for the shape ``[14]``.

    The mapping under test is ``i -> [i // 4, i % 4]``; per the test name,
    the error is expected because the map is not bijective over that shape.
    """

    def split_by_four(i):
        return [i // 4, i % 4]

    index_map = IndexMap.from_func(split_by_four)
    pre_shape = [14]

    with pytest.raises(tvm.TVMError):
        index_map.inverse(pre_shape)
示例#9
0
def test_inverse():
    """Check ``inverse`` of ``i -> [i // 4, i % 4]`` over the shape ``[16]``.

    The result must be equivalent to ``(i, j) -> [4 * i + j]``.
    """

    def split_by_four(i):
        return [i // 4, i % 4]

    def merge_by_four(i, j):
        return [4 * i + j]

    index_map = IndexMap.from_func(split_by_four)
    expected_inverse = IndexMap.from_func(merge_by_four)

    actual_inverse = index_map.inverse([16])
    assert actual_inverse.is_equivalent_to(expected_inverse)