コード例 #1
0
ファイル: test_dace_utils.py プロジェクト: havogt/gt4py
 def test_getitem_different_value(self, interval, values):
     """Query a two-entry mapping with ``interval`` and compare the result to ``values``.

     The mapping holds ``get_instance(0)`` on [start, from_start(2)) and
     ``get_instance(1)`` on [from_end(-2), end). ``values`` lists, in order, the
     integers whose ``get_instance`` objects the lookup is expected to return.
     """
     mapping = IntervalMapping()
     lower_value = get_instance(0)
     mapping[Interval(start=AxisBound.start(), end=AxisBound.from_start(2))] = lower_value
     upper_value = get_instance(1)
     mapping[Interval(start=AxisBound.from_end(-2), end=AxisBound.end())] = upper_value

     found = mapping[interval]

     assert isinstance(found, list)
     assert len(found) == len(values)
     # Identity check: the mapping must hand back the very objects stored.
     for position, entry in enumerate(found):
         assert entry is get_instance(values[position])
コード例 #2
0
ファイル: test_oir.py プロジェクト: stubbiali/gt4py
    VerticalLoopSectionFactory,
)


def test_no_horizontal_offset_allowed():
    """Building an assignment with ``left__offset__i=1`` must fail validation.

    The error message is expected to mention that a horizontal offset is not allowed.
    """
    expected_message = r"must not have .*horizontal offset"
    with pytest.raises(ValidationError, match=expected_message):
        AssignStmtFactory(left__offset__i=1)


def test_mask_must_be_bool():
    """A mask statement whose mask is an INT32 field access must fail validation."""
    expected_message = r".*must be.* bool.*"
    with pytest.raises(ValidationError, match=expected_message):
        # Built inside the ``raises`` block, exactly like the original one-liner.
        int_mask = FieldAccessFactory(dtype=DataType.INT32)
        MaskStmtFactory(mask=int_mask)


# Pairs of AxisBounds that are expected to compare equal.
EQUAL_AXISBOUNDS = [
    (AxisBound.start(), AxisBound.start()),
    (AxisBound.end(), AxisBound.end()),
    (AxisBound.from_end(-1), AxisBound.from_end(-1)),
]
# Pairs (lhs, rhs) where lhs is expected to order strictly before rhs.
LESS_AXISBOUNDS = [
    (AxisBound.start(), AxisBound.end()),
    (AxisBound.start(), AxisBound.from_start(1)),
    (AxisBound.from_end(-1), AxisBound.end()),
    (AxisBound.from_start(1), AxisBound.from_end(-1)),
]
# Each LESS pair with its two sides swapped (same order), so lhs is expected to
# order strictly after rhs.
GREATER_AXISBOUNDS = [(rhs, lhs) for lhs, rhs in LESS_AXISBOUNDS]
コード例 #3
0
ファイル: test_dace_utils.py プロジェクト: havogt/gt4py
class TestIntervalMapping:
    """Tests for ``IntervalMapping``: assignment and lookup keyed by ``Interval``."""

    @staticmethod
    def assert_consistency(imap: IntervalMapping):
        """Assert the structural invariants of ``imap``'s parallel lists.

        ``interval_starts``, ``interval_ends`` and ``values`` must have equal length.
        Starts and ends must each be strictly increasing, and consecutive entries must
        not overlap. Entries that touch must hold different value objects; touching
        entries with the same value are expected to have been merged. Every stored
        entry must be non-empty (start < end).
        """
        assert len(imap.interval_starts) == len(imap.interval_ends)
        assert len(imap.interval_starts) == len(imap.values)
        for i in range(len(imap.interval_starts) - 1):
            assert imap.interval_starts[i] < imap.interval_starts[i + 1]
            assert imap.interval_ends[i] < imap.interval_ends[i + 1]
            assert imap.interval_ends[i] <= imap.interval_starts[i + 1]
            if imap.interval_ends[i] == imap.interval_starts[i + 1]:
                assert imap.values[i] is not imap.values[i + 1]

        for start, end in zip(imap.interval_starts, imap.interval_ends):
            assert start < end

    @pytest.mark.parametrize(
        ["intervals", "starts", "ends"],
        [
            ([], [], []),
            (
                [Interval(start=AxisBound.start(), end=AxisBound.end())],
                [AxisBound.start()],
                [AxisBound.end()],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.end()),
                    Interval(start=AxisBound.from_start(1), end=AxisBound.end()),
                ],
                [AxisBound.start()],
                [AxisBound.end()],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.end()),
                    Interval(start=AxisBound.from_end(-1), end=AxisBound.end()),
                ],
                [AxisBound.start()],
                [AxisBound.end()],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_end(-1)),
                    Interval(start=AxisBound.from_end(-1), end=AxisBound.end()),
                ],
                [AxisBound.start()],
                [AxisBound.end()],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(1)),
                    Interval(start=AxisBound.from_start(1), end=AxisBound.end()),
                ],
                [AxisBound.start()],
                [AxisBound.end()],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(1)),
                    Interval(start=AxisBound.from_start(2), end=AxisBound.end()),
                ],
                [AxisBound.start(), AxisBound.from_start(2)],
                [AxisBound.from_start(1), AxisBound.end()],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(2)),
                    Interval(start=AxisBound.from_start(1), end=AxisBound.end()),
                ],
                [AxisBound.start()],
                [AxisBound.end()],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(1)),
                    Interval(start=AxisBound.from_end(-1), end=AxisBound.end()),
                    Interval(start=AxisBound.from_start(1), end=AxisBound.from_end(-1)),
                ],
                [AxisBound.start()],
                [AxisBound.end()],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(1)),
                    Interval(start=AxisBound.from_end(-1), end=AxisBound.end()),
                    Interval(start=AxisBound.from_start(2), end=AxisBound.from_end(-2)),
                ],
                [AxisBound.start(), AxisBound.from_start(2), AxisBound.from_end(-1)],
                [AxisBound.from_start(1), AxisBound.from_end(-2), AxisBound.end()],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(1)),
                    Interval(start=AxisBound.from_end(-1), end=AxisBound.end()),
                    Interval(start=AxisBound.from_start(1), end=AxisBound.from_end(-2)),
                ],
                [AxisBound.start(), AxisBound.from_end(-1)],
                [AxisBound.from_end(-2), AxisBound.end()],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(1)),
                    Interval(start=AxisBound.from_end(-1), end=AxisBound.end()),
                    Interval(start=AxisBound.from_start(2), end=AxisBound.from_end(-1)),
                ],
                [AxisBound.start(), AxisBound.from_start(2)],
                [AxisBound.from_start(1), AxisBound.end()],
            ),
        ],
    )
    def test_setitem_same_value(self, intervals, starts, ends):
        """Assigning one shared value merges overlapping/touching intervals.

        Checks the resulting boundaries against the expected ``starts`` and ``ends``.
        """
        # if all values are the same (same instance), the behavior is the same as for a IntervalSet.
        imap = IntervalMapping()
        for interval in intervals:
            imap[interval] = 0
            self.assert_consistency(imap)

        assert len(starts) == len(imap.interval_starts)
        for expected, observed in zip(starts, imap.interval_starts):
            assert observed == expected

        assert len(ends) == len(imap.interval_ends)
        for expected, observed in zip(ends, imap.interval_ends):
            assert observed == expected

    @hyp.given(intervals_strategy())
    def test_setitem_same_value_hypothesis(self, intervals):
        """With a single shared value, insertion order must not change the boundaries."""
        # NOTE(review): this loops over every permutation, so cost grows factorially
        # with len(intervals); it relies on intervals_strategy yielding short lists.
        imap = IntervalMapping()
        for interval in intervals:
            imap[interval] = 0
            self.assert_consistency(imap)

        for permutation in itertools.permutations(intervals):
            other_imap = IntervalMapping()
            for interval in permutation:
                other_imap[interval] = 0
                self.assert_consistency(other_imap)
            assert imap.interval_starts == other_imap.interval_starts
            assert imap.interval_ends == other_imap.interval_ends

    @pytest.mark.parametrize(
        ["intervals", "starts", "ends", "values"],
        [
            ([], [], [], []),
            (
                [Interval(start=AxisBound.start(), end=AxisBound.end())],
                [AxisBound.start()],
                [AxisBound.end()],
                [0],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.end()),
                    Interval(start=AxisBound.from_start(1), end=AxisBound.end()),
                ],
                [AxisBound.start(), AxisBound.from_start(1)],
                [AxisBound.from_start(1), AxisBound.end()],
                [0, 1],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.end()),
                    Interval(start=AxisBound.from_end(-1), end=AxisBound.end()),
                ],
                [AxisBound.start(), AxisBound.from_end(-1)],
                [AxisBound.from_end(-1), AxisBound.end()],
                [0, 1],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_end(-1)),
                    Interval(start=AxisBound.from_end(-1), end=AxisBound.end()),
                ],
                [AxisBound.start(), AxisBound.from_end(-1)],
                [AxisBound.from_end(-1), AxisBound.end()],
                [0, 1],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(1)),
                    Interval(start=AxisBound.from_start(1), end=AxisBound.end()),
                ],
                [AxisBound.start(), AxisBound.from_start(1)],
                [AxisBound.from_start(1), AxisBound.end()],
                [0, 1],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(1)),
                    Interval(start=AxisBound.from_start(2), end=AxisBound.end()),
                ],
                [AxisBound.start(), AxisBound.from_start(2)],
                [AxisBound.from_start(1), AxisBound.end()],
                [0, 1],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(2)),
                    Interval(start=AxisBound.from_start(1), end=AxisBound.end()),
                ],
                [AxisBound.start(), AxisBound.from_start(1)],
                [AxisBound.from_start(1), AxisBound.end()],
                [0, 1],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(1)),
                    Interval(start=AxisBound.from_end(-1), end=AxisBound.end()),
                    Interval(start=AxisBound.from_start(1), end=AxisBound.from_end(-1)),
                ],
                [AxisBound.start(), AxisBound.from_start(1), AxisBound.from_end(-1)],
                [AxisBound.from_start(1), AxisBound.from_end(-1), AxisBound.end()],
                [0, 2, 1],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(1)),
                    Interval(start=AxisBound.from_end(-1), end=AxisBound.end()),
                    Interval(start=AxisBound.from_start(2), end=AxisBound.from_end(-2)),
                ],
                [AxisBound.start(), AxisBound.from_start(2), AxisBound.from_end(-1)],
                [AxisBound.from_start(1), AxisBound.from_end(-2), AxisBound.end()],
                [0, 2, 1],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(1)),
                    Interval(start=AxisBound.from_end(-1), end=AxisBound.end()),
                    Interval(start=AxisBound.from_start(1), end=AxisBound.from_end(-2)),
                ],
                [AxisBound.start(), AxisBound.from_start(1), AxisBound.from_end(-1)],
                [AxisBound.from_start(1), AxisBound.from_end(-2), AxisBound.end()],
                [0, 2, 1],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(1)),
                    Interval(start=AxisBound.from_end(-1), end=AxisBound.end()),
                    Interval(start=AxisBound.from_start(2), end=AxisBound.from_end(-1)),
                ],
                [AxisBound.start(), AxisBound.from_start(2), AxisBound.from_end(-1)],
                [AxisBound.from_start(1), AxisBound.from_end(-1), AxisBound.end()],
                [0, 2, 1],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(3)),
                    Interval(start=AxisBound.from_start(1), end=AxisBound.from_start(2)),
                ],
                [AxisBound.start(), AxisBound.from_start(1), AxisBound.from_start(2)],
                [AxisBound.from_start(1), AxisBound.from_start(2), AxisBound.from_start(3)],
                [0, 1, 0],
            ),
            (
                [
                    Interval(start=AxisBound.from_start(1), end=AxisBound.from_start(2)),
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(3)),
                ],
                [AxisBound.start()],
                [AxisBound.from_start(3)],
                [1],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(2)),
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(1)),
                ],
                [AxisBound.start(), AxisBound.from_start(1)],
                [AxisBound.from_start(1), AxisBound.from_start(2)],
                [1, 0],
            ),
        ],
    )
    def test_setitem_different_value(self, intervals, starts, ends, values):
        """Assign ``get_instance(k)`` to the k-th interval and check the result.

        The resulting entries must match ``starts``/``ends``. ``values`` lists the
        counter whose ``get_instance`` object each entry must hold (checked by
        identity).
        """
        imap = IntervalMapping()
        for ctr, interval in enumerate(intervals):
            imap[interval] = get_instance(ctr)
            self.assert_consistency(imap)

        assert len(imap.interval_starts) == len(starts)
        assert len(imap.interval_ends) == len(ends)
        assert len(imap.values) == len(values)
        for i, (start, end, value) in enumerate(
            zip(imap.interval_starts, imap.interval_ends, imap.values)
        ):
            assert start == starts[i]
            assert end == ends[i]
            assert value is get_instance(values[i])

    @hyp.given(intervals_strategy())
    def test_setitem_different_value_hypothesis(self, intervals):
        """The interval inserted last must survive intact, whatever the insertion order.

        Every insertion uses a new counter value. This presumes ``get_instance``
        returns a distinct object per counter, so that no merging can happen. For
        each permutation, the last-inserted interval must then appear as exactly
        one entry, with its own bounds and its own value.
        """
        # NOTE(review): iterates all permutations (factorial in len(intervals));
        # relies on intervals_strategy yielding short lists.
        ctr = 0
        imap = IntervalMapping()
        for interval in intervals:
            imap[interval] = get_instance(ctr)
            self.assert_consistency(imap)
            ctr += 1
        for permutation in itertools.permutations(intervals):
            other_imap = IntervalMapping()
            for interval in permutation:
                other_imap[interval] = get_instance(ctr)
                self.assert_consistency(other_imap)
                ctr += 1

            if not permutation:
                # Nothing was inserted, so there is no last interval to look for.
                continue

            last = permutation[-1]
            # Collect every entry starting where the last interval starts. Starts are
            # strictly increasing (assert_consistency), so there is at most one.
            # Requiring exactly one makes a missing entry fail the test instead of
            # passing silently.
            matches = [
                (end, value)
                for start, end, value in zip(
                    other_imap.interval_starts, other_imap.interval_ends, other_imap.values
                )
                if start == last.start
            ]
            assert len(matches) == 1
            end, value = matches[0]
            assert end == last.end
            assert value is get_instance(ctr - 1)

    @pytest.mark.parametrize(
        ["interval", "values"],
        [
            (Interval(start=AxisBound.from_start(-1), end=AxisBound.from_end(1)), [0, 1]),
            (Interval(start=AxisBound.from_start(-1), end=AxisBound.from_start(3)), [0]),
            (Interval(start=AxisBound.from_start(1), end=AxisBound.from_end(-1)), [0, 1]),
            (Interval(start=AxisBound.from_start(2), end=AxisBound.from_end(-2)), []),
        ],
    )
    def test_getitem_different_value(self, interval, values):
        """Look up ``interval`` in a two-entry mapping.

        The mapping holds ``get_instance(0)`` on [start, from_start(2)) and
        ``get_instance(1)`` on [from_end(-2), end). The lookup must return a list of
        exactly the objects whose counters are given in ``values``, in order.
        """
        imap = IntervalMapping()
        imap[Interval(start=AxisBound.start(), end=AxisBound.from_start(2))] = get_instance(0)
        imap[Interval(start=AxisBound.from_end(-2), end=AxisBound.end())] = get_instance(1)
        res = imap[interval]
        assert isinstance(res, list)
        assert len(res) == len(values)
        for observed, expected in zip(res, values):
            assert observed is get_instance(expected)
コード例 #4
0
 def __init__(self) -> None:
     self._interval = Interval(start=AxisBound.start(), end=AxisBound.end())
     self._horizontal_executions: List[HorizontalExecution] = []