Example #1
0
def test_shape_mapping():
    """Check ``map_shape`` for the flat-to-(i // 4, i % 4) split.

    A flat extent that is not a multiple of 4 (14 here) maps to the same
    post-transformation shape as the next multiple of 4 (16).
    """
    split_by_four = IndexMap.from_func(lambda i: [i // 4, i % 4])

    cases = [
        ([4], [1, 4]),
        ([16], [4, 4]),
        # 14 is not a multiple of 4; the mapped shape still has 4 rows.
        ([14], [4, 4]),
    ]
    for pre_shape, post_shape in cases:
        assert_structural_equal(split_by_four.map_shape(pre_shape), post_shape)
Example #2
0
def test_index_map_inverse_no_iter():
    """Invert a map whose first two inputs are never referenced.

    Over the shape ``[1, 1, 64, 64]`` the unused inputs ``i0`` and ``i1``
    have extent 1. The inverse is therefore expected to reconstruct them as
    constant zeros and to rebuild ``i2`` / ``i3`` from the split outputs.
    """

    def forward(i0, i1, i2, i3):
        # i0 and i1 are intentionally ignored by the forward map.
        return (
            floordiv(i3, 32),
            floordiv(i2, 2),
            floormod(i2, 2),
            floormod(i3, 32),
        )

    def inverse(i0, i1, i2, i3):
        first = IntImm("int32", 0)
        second = IntImm("int32", 0)
        return first, second, i2 + i1 * 2, i3 + i0 * 32

    forward_map = IndexMap.from_func(forward)
    computed_inverse = forward_map.inverse([1, 1, 64, 64])
    reference_inverse = IndexMap.from_func(inverse)
    assert reference_inverse.is_equivalent_to(computed_inverse)
Example #3
0
def test_index_mapping():
    """Check ``map_indices`` for the flat-to-(i // 4, i % 4) split."""
    split_by_four = IndexMap.from_func(lambda i: [i // 4, i % 4])

    # flat index -> expected [row, column]
    expected_pairs = {
        0: [0, 0],
        3: [0, 3],
        4: [1, 0],
        42: [10, 2],
    }
    for flat, pair in expected_pairs.items():
        assert_structural_equal(split_by_four.map_indices([flat]), pair)
def check_index_map(workload, block_name, intrin_name, expected_index_map):
    """Assert auto-tensorize mapping for a block yields exactly the expected map.

    Parameters
    ----------
    workload
        The workload used to construct a ``Schedule``.
    block_name : str
        Name of the block to look up in the schedule.
    intrin_name : str
        Name of the registered ``TensorIntrin`` whose ``desc`` is matched.
    expected_index_map : callable
        Function passed to ``IndexMap.from_func`` to build the expected map.
    """
    sch = Schedule(workload)
    target_block = sch.get_block(block_name)
    intrin_desc = TensorIntrin.get(intrin_name).desc
    mapping_info = get_auto_tensorize_mapping_info(sch, target_block, intrin_desc)

    # Exactly one candidate mapping is expected for these test workloads.
    assert len(mapping_info.mappings) == 1
    reference = IndexMap.from_func(expected_index_map)
    assert reference.is_equivalent_to(mapping_info.mappings[0])
Example #5
0
def assert_equal_index_map(map1: IndexMap, map2: IndexMap) -> None:
    """Assert that two index maps agree, as proven by the arithmetic analyzer.

    ``map1`` is applied to ``map2``'s input variables. Each resulting
    expression must be provably equal to the matching entry of
    ``map2.final_indices``, and both lists must have the same length.
    """
    mapped = map1.map_indices(map2.initial_indices)
    reference = map2.final_indices
    assert len(mapped) == len(reference)

    analyzer = tvm.arith.Analyzer()
    assert all(
        analyzer.can_prove_equal(lhs, rhs) for lhs, rhs in zip(mapped, reference)
    )
Example #6
0
def test_nonsurjective_inverse(padding_test_case):
    """Check inverse, output shape and padding predicate of a padded mapping.

    ``padding_test_case`` supplies the ``forward``, ``inverse`` and
    ``padding`` callables together with the ``pre_shape`` and expected
    ``post_shape``.
    """
    case = padding_test_case
    pre_shape = case["pre_shape"]
    forward_map = IndexMap.from_func(case["forward"])

    inverse, padding_predicate = forward_map.non_surjective_inverse(pre_shape)
    reference_inverse = IndexMap.from_func(case["inverse"])
    assert inverse.is_equivalent_to(reference_inverse)

    mapped_shape = forward_map.map_shape(pre_shape)
    tvm.ir.assert_structural_equal(mapped_shape, case["post_shape"])

    reference_predicate = case["padding"](*inverse.initial_indices)

    # analyzer.can_prove_equal is deliberately not used: it cannot simplify
    # expressions like `(4*i+j >= 14) - (4*i+j >= 14)`. Instead, simplify
    # both predicates and compare them structurally.
    analyzer = tvm.arith.Analyzer()
    simplified_reference = analyzer.simplify(reference_predicate)
    simplified_actual = analyzer.simplify(padding_predicate)
    tvm.ir.assert_structural_equal(simplified_actual, simplified_reference)
Example #7
0
def test_suggest_index_map_bijective():
    """Suggest a layout for a flat 8-element buffer read at ``floormod(j, 4) * 2 + i``.

    The expected suggestion splits the flat index into
    ``[floormod(x, 2), floordiv(x, 2)]``.
    """
    i, j = _make_vars("i", "j")
    flat_buffer = decl_buffer(shape=[8])
    access = [floormod(j, 4) * 2 + i]
    loop_nest = _make_loops(loop_vars=[i, j], extents=[2, 32])

    suggested = suggest_index_map(
        buffer=flat_buffer,
        indices=access,
        loops=loop_nest,
        predicate=True,
    )
    reference = IndexMap.from_func(lambda x: [floormod(x, 2), floordiv(x, 2)])
    assert suggested.is_equivalent_to(reference)
Example #8
0
def test_suggest_index_map_simple():
    """Suggest a layout for an (8, 256) buffer accessed through 16-wide tiles.

    The expected suggestion splits the first axis by 4 and the second by 16,
    placing both outer parts before both inner parts.
    """
    i, j = _make_vars("i", "j")
    tiled_buffer = decl_buffer(shape=[8, 256])
    access = [
        floordiv(i, 16) * 4 + floordiv(j, 16),
        floormod(i, 16) * 16 + floormod(j, 16),
    ]
    loop_nest = _make_loops(loop_vars=[i, j], extents=[32, 64])

    suggested = suggest_index_map(
        buffer=tiled_buffer,
        indices=access,
        loops=loop_nest,
        predicate=True,
    )
    reference = IndexMap.from_func(
        lambda x, y: [floordiv(x, 4), floordiv(y, 16), floormod(x, 4), floormod(y, 16)]
    )
    assert suggested.is_equivalent_to(reference)
Example #9
0
def test_nonbijective_inverse_gives_error():
    """``inverse`` must raise for the split-by-4 map over 14 elements.

    14 is not a multiple of 4, so the map is not a bijection on that shape.
    """
    split_by_four = IndexMap.from_func(lambda i: [i // 4, i % 4])
    with pytest.raises(tvm.TVMError):
        split_by_four.inverse([14])
Example #10
0
def test_inverse():
    """The inverse of the split-by-4 map over 16 elements is ``4 * i + j``."""
    forward_map = IndexMap.from_func(lambda i: [i // 4, i % 4])
    reference = IndexMap.from_func(lambda i, j: [4 * i + j])

    computed = forward_map.inverse([16])
    assert computed.is_equivalent_to(reference)
Example #11
0
    def transform_layout(self,
                         mapping_function: Callable[...,
                                                    List[tvm.tir.PrimExpr]]):
        """Defines the layout transformation for the current stage's tensor.

        The map from initial_indices to final_indices must be an
        invertible affine transformation.  This method may be called
        more than once for a given tensor, in which case each
        transformation is applied sequentially.

        If the stage is a ComputeOp, then the iteration order of the
        compute stage is rewritten to be a row-major traversal of the
        tensor, and the new loop iteration variables are returned.
        For all other stages, the loop iteration order is unmodified,
        and the return value is None.

        Any axis separators extracted from ``mapping_function`` by
        ``IndexMap.from_func_with_separators`` are also applied to
        this stage.

        Parameters
        ----------
        mapping_function : Callable[..., List[tvm.tir.PrimExpr]]

            A callable that accepts N arguments of type tvm.tir.Var,
            and outputs a list of PrimExpr.  The input arguments
            represent the location of a value in the current stage's
            tensor, using the pre-transformation layout.  The return
            value of the function gives the location of that value in
            the current stage's tensor, using the post-transformation
            layout.

        Returns
        -------
        new_iter_vars : Optional[List[tvm.tir.IterVar]]

            If the stage is a ComputeOp, then the return will be the
            updated loop iteration variables over the data array, in
            the same order as the output values from the
            `mapping_function`.

            Otherwise, the return value is None.

        Examples
        --------
        .. code-block:: python

            # ``A`` is a tensor whose compute definition is in NHWC
            # format, and should be transformed into NCHWc format.

            s[A].transform_layout(
                lambda n,h,w,c: [n, c//4, h, w, c%4]
            )


        .. code-block:: python

            # ``A`` is a tensor whose compute definition is in an
            # arbitrary format, and should be transformed such that
            # the last index is split, with the slower-changing index
            # of the split placed at the slowest changing dimension.

            s[A].transform_layout(
                lambda *indices, i: [i//4, *indices, i%4]
            )

        .. code-block:: python

            # ``B`` is a tensor defined by te.compute to be a copy of
            # ``A``, and should be transformed such that ``B``'s layout
            # is a transpose of ``A``'s layout.  The loop iteration
            # that computes ``B`` will correspond to ``B``'s memory
            # layout.

            A = te.placeholder([n,m])
            B = te.compute(A.shape, lambda i,j: A[i,j])
            s = te.create_schedule(B.op)

            s[B].transform_layout(lambda i,j: [j,i])

        """

        # The number of input indices is the rank of the stage's first output.
        ndim = len(self.op.output(0).shape)
        index_map, axis_separators = IndexMap.from_func_with_separators(
            mapping_function, ndim=ndim)

        new_iter_vars = _ffi_api.StageTransformLayout(
            self, index_map.initial_indices, index_map.final_indices)
        _ffi_api.StageSetAxisSeparators(self, axis_separators)

        # Collapse an empty result into None, matching the documented
        # contract for stages whose iteration order is not rewritten.
        return new_iter_vars or None
def _assert_equal_index_map(map1: IndexMap, map2: IndexMap) -> None:
    """Assert that two index maps produce matching expressions.

    ``map1`` is applied to ``map2``'s input variables. Each result is
    compared with ``expr_deep_equal`` against the matching entry of
    ``map2.final_indices``; no arithmetic simplification is applied here.
    """
    lhs_exprs = map1.map_indices(map2.initial_indices)
    rhs_exprs = map2.final_indices
    assert len(lhs_exprs) == len(rhs_exprs)
    assert all(expr_deep_equal(lhs, rhs) for lhs, rhs in zip(lhs_exprs, rhs_exprs))