def test_shape_mapping():
    """Check IndexMap.map_shape for the split map i -> [i // 4, i % 4].

    Covers extents that are exact multiples of 4 ([4], [16]) and one that
    is not ([14]); the latter is still mapped to a full [4, 4] output.
    """

    def split_by_four(i):
        return [i // 4, i % 4]

    index_map = IndexMap.from_func(split_by_four)

    # (input shape, expected output shape), checked in order.
    shape_cases = [
        ([4], [1, 4]),
        ([16], [4, 4]),
        ([14], [4, 4]),
    ]
    for pre_shape, post_shape in shape_cases:
        assert_structural_equal(index_map.map_shape(pre_shape), post_shape)
def test_index_mapping():
    """Check IndexMap.map_indices for the split map i -> [i // 4, i % 4]."""
    index_map = IndexMap.from_func(lambda i: [i // 4, i % 4])

    # Flat input index -> expected [row, column] output index.
    expected_by_input = {
        0: [0, 0],
        3: [0, 3],
        4: [1, 0],
        42: [10, 2],
    }
    for flat_index, tiled_index in expected_by_input.items():
        assert_structural_equal(index_map.map_indices([flat_index]), tiled_index)
def test_index_map_inverse_no_iter():
    """Invert a 4-D map whose first two inputs are ignored by the forward map.

    The inverse is taken over shape [1, 1, 64, 64]. The expected inverse
    therefore pins the first two outputs to the constant 0 and rebuilds
    i2 and i3 from their floordiv/floormod pieces.
    """

    def input_example(i0, i1, i2, i3):
        # i0 and i1 are intentionally unused; they only set the input rank.
        return floordiv(i3, 32), floordiv(i2, 2), floormod(i2, 2), floormod(i3, 32)

    def expected_inverse(i0, i1, i2, i3):
        first_out = IntImm("int32", 0)
        second_out = IntImm("int32", 0)
        # Undo floordiv/floormod by 2 and by 32 respectively.
        third_out = i2 + i1 * 2
        fourth_out = i3 + i0 * 32
        return first_out, second_out, third_out, fourth_out

    forward_map = IndexMap.from_func(input_example)
    inverse_map = forward_map.inverse([1, 1, 64, 64])
    reference_map = IndexMap.from_func(expected_inverse)
    assert reference_map.is_equivalent_to(inverse_map)
def check_index_map(workload, block_name, intrin_name, expected_index_map):
    """Assert that auto-tensorize mapping yields exactly the expected index map.

    Args:
        workload: Passed to ``Schedule`` to build the schedule under test.
        block_name: Name of the block to look up in that schedule.
        intrin_name: Name of the registered ``TensorIntrin``; its ``desc``
            is matched against the block.
        expected_index_map: Callable given to ``IndexMap.from_func``; the
            single mapping found must be equivalent to the resulting map.

    Raises:
        AssertionError: If the number of mappings found is not exactly one,
            or if that mapping is not equivalent to the expected one.
    """
    sch = Schedule(workload)
    target_block = sch.get_block(block_name)
    intrin_desc = TensorIntrin.get(intrin_name).desc
    mapping_info = get_auto_tensorize_mapping_info(sch, target_block, intrin_desc)

    assert len(mapping_info.mappings) == 1
    reference = IndexMap.from_func(expected_index_map)
    assert reference.is_equivalent_to(mapping_info.mappings[0])
def test_nonsurjective_inverse(padding_test_case):
    """Check non_surjective_inverse, map_shape and the padding predicate.

    ``padding_test_case`` is read as a mapping with the keys "forward",
    "inverse", "pre_shape", "post_shape" and "padding". Presumably it is
    supplied by pytest parametrization/fixture elsewhere in the file; confirm.
    """
    case = padding_test_case
    pre_shape = case["pre_shape"]

    forward_map = IndexMap.from_func(case["forward"])
    inverse, padding_predicate = forward_map.non_surjective_inverse(pre_shape)

    reference_inverse = IndexMap.from_func(case["inverse"])
    assert inverse.is_equivalent_to(reference_inverse)

    tvm.ir.assert_structural_equal(forward_map.map_shape(pre_shape), case["post_shape"])

    # Evaluate the expected predicate on the inverse's own index variables,
    # so both predicates are written in terms of the same Vars.
    reference_predicate = case["padding"](*inverse.initial_indices)

    # Can't use analyzer.can_prove_equal, because it can't simplify
    # expressions like `(4*i+j >= 14) - (4*i+j >= 14)`.
    analyzer = tvm.arith.Analyzer()
    reference_predicate = analyzer.simplify(reference_predicate)
    padding_predicate = analyzer.simplify(padding_predicate)
    tvm.ir.assert_structural_equal(padding_predicate, reference_predicate)
def test_suggest_index_map_bijective():
    """suggest_index_map on an 8-element buffer read at floormod(j, 4) * 2 + i.

    The loops over i and j have extents 2 and 32. The suggested map is
    expected to split the flat index x into [x % 2, x // 2].
    """
    i, j = _make_vars("i", "j")

    target_buffer = decl_buffer(shape=[8])
    access_indices = [floormod(j, 4) * 2 + i]
    loop_nest = _make_loops(
        loop_vars=[i, j],
        extents=[2, 32],
    )
    suggested = suggest_index_map(
        buffer=target_buffer,
        indices=access_indices,
        loops=loop_nest,
        predicate=True,
    )

    def expected_layout(x):
        return [floormod(x, 2), floordiv(x, 2)]

    assert suggested.is_equivalent_to(IndexMap.from_func(expected_layout))
def test_suggest_index_map_simple():
    """suggest_index_map on an [8, 256] buffer indexed through 16-wide tiles.

    The loops over i and j have extents 32 and 64. The suggested map is
    expected to reorder (x, y) into [x // 4, y // 16, x % 4, y % 16].
    """
    i, j = _make_vars("i", "j")

    target_buffer = decl_buffer(shape=[8, 256])
    row_index = floordiv(i, 16) * 4 + floordiv(j, 16)
    col_index = floormod(i, 16) * 16 + floormod(j, 16)
    loop_nest = _make_loops(
        loop_vars=[i, j],
        extents=[32, 64],
    )
    suggested = suggest_index_map(
        buffer=target_buffer,
        indices=[row_index, col_index],
        loops=loop_nest,
        predicate=True,
    )

    def expected_layout(x, y):
        return [floordiv(x, 4), floordiv(y, 16), floormod(x, 4), floormod(y, 16)]

    assert suggested.is_equivalent_to(IndexMap.from_func(expected_layout))
def test_nonbijective_inverse_gives_error():
    """IndexMap.inverse must raise TVMError when the map is not bijective.

    An extent of 14 is not a multiple of 4, so i -> [i // 4, i % 4] does not
    cover its [4, 4] output exactly, and a plain inverse is rejected.
    """
    index_map = IndexMap.from_func(lambda i: [i // 4, i % 4])
    non_multiple_shape = [14]
    with pytest.raises(tvm.TVMError):
        index_map.inverse(non_multiple_shape)
def test_inverse():
    """Inverting i -> [i // 4, i % 4] over [16] yields (i, j) -> [4 * i + j]."""
    forward_map = IndexMap.from_func(lambda i: [i // 4, i % 4])
    expected_inverse = IndexMap.from_func(lambda i, j: [4 * i + j])
    actual_inverse = forward_map.inverse([16])
    assert actual_inverse.is_equivalent_to(expected_inverse)