def test_getitem_different_value(self, interval, values):
    """Look up ``interval`` in a mapping holding two disjoint, differently valued entries.

    The mapping holds ``get_instance(0)`` from ``AxisBound.start()`` to
    ``AxisBound.from_start(2)`` and ``get_instance(1)`` from ``AxisBound.from_end(-2)`` to
    ``AxisBound.end()``. ``values`` lists, in order, the ``get_instance`` indices expected
    back from the lookup.
    """
    imap = IntervalMapping()
    lower = Interval(start=AxisBound.start(), end=AxisBound.from_start(2))
    upper = Interval(start=AxisBound.from_end(-2), end=AxisBound.end())
    imap[lower] = get_instance(0)
    imap[upper] = get_instance(1)

    found = imap[interval]
    assert isinstance(found, list)
    assert len(found) == len(values)
    for got, wanted in zip(found, values):
        # identity check: the mapping must hand back the very objects that were stored
        assert got is get_instance(wanted)
def get_k_subsets(self, node):
    """Compute the K-axis subset strings read and written by a vertical-loop library node.

    For every field that has a K axis (``self._axes[name][2]`` truthy):

    * the bounding interval of all sections that write it is accumulated;
    * the bounding interval of all sections that read it, each shifted by the read's K
      offset, is accumulated;
    * the smallest start bound among those reads and writes becomes the field's K origin.

    Each bounding interval is re-expressed relative to that origin with ``_offset_origin``.
    It is then rendered as ``"<lo>:<hi>"``, where each bound is its signed offset, prefixed
    with ``__K`` when the bound is relative to the axis end.

    Returns:
        ``(read_subsets, write_subsets)``: dicts mapping field name to subset string.
    """
    assert isinstance(node, VerticalLoopLibraryNode)
    read_hulls = {}
    write_hulls = {}
    k_origins = {}

    def widen(hulls, name, extra):
        # Grow the bounding interval recorded for ``name`` so that it also spans ``extra``.
        hulls.setdefault(name, extra)
        hulls[name] = Interval(
            start=min(hulls[name].start, extra.start),
            end=max(hulls[name].end, extra.end),
        )

    def lower_origin(name, bound):
        # Track the smallest K bound touched by any access of ``name``.
        k_origins[name] = min(k_origins.get(name, bound), bound)

    def render(hull, origin):
        # Format a bounding interval, relative to ``origin``, as "<lo>:<hi>".
        shifted = _offset_origin(hull, origin)
        return ":".join(
            "{}{:+d}".format("__K" if bound.level == LevelMarker.END else "", bound.offset)
            for bound in (shifted.start, shifted.end)
        )

    for section_interval, section_sdfg in node.sections:
        accesses = self._get_access_collection(section_sdfg)
        for name, offsets in accesses.read_offsets().items():
            if not self._axes[name][2]:
                continue
            for offset in offsets:
                shifted_read = section_interval.shifted(offset[2])
                widen(read_hulls, name, shifted_read)
                lower_origin(name, shifted_read.start)
        for name in accesses.write_fields():
            if self._axes[name][2]:
                widen(write_hulls, name, section_interval)
                lower_origin(name, section_interval.start)

    write_subsets = {name: render(hull, k_origins[name]) for name, hull in write_hulls.items()}
    read_subsets = {name: render(hull, k_origins[name]) for name, hull in read_hulls.items()}
    return read_subsets, write_subsets
def __setitem__(self, key: oir.Interval, value: Any) -> None:
    """Map every point of ``key`` to ``value``, replacing whatever was mapped there before.

    Entries live in three parallel lists (``interval_starts``, ``interval_ends``,
    ``values``), kept ordered by start. The work is done in four steps:

    1. Entries that ``key`` covers are dropped.
    2. If some entry covers ``key``, the work is delegated to
       ``_setitem_subset_of_existing``.
    3. Otherwise the first entry that ``key`` intersects or touches is handed to
       ``_setitem_partial_overlap``.
    4. Otherwise ``key`` is inserted as a new entry at its sorted position.

    Raises:
        TypeError: if ``key`` is not an ``oir.Interval``.
    """
    if not isinstance(key, oir.Interval):
        # Message matches ``__getitem__``; the previous text referred to "method add of
        # IntervalSet", a leftover from the class this code was copied from.
        raise TypeError("Only OIR intervals supported for keys of IntervalMapping.")
    key = oir.UnboundedInterval(start=key.start, end=key.end)

    # Drop every entry that ``key`` fully covers. Delete from the back so the
    # remaining collected indices stay valid.
    covered = [
        i
        for i, (start, end) in enumerate(zip(self.interval_starts, self.interval_ends))
        if key.covers(oir.UnboundedInterval(start=start, end=end))
    ]
    for i in reversed(covered):
        del self.interval_starts[i]
        del self.interval_ends[i]
        del self.values[i]

    if not self.interval_starts:
        self.interval_starts.append(key.start)
        self.interval_ends.append(key.end)
        self.values.append(value)
        return

    # ``key`` lies strictly inside an existing entry.
    for i, (start, end) in enumerate(zip(self.interval_starts, self.interval_ends)):
        if oir.UnboundedInterval(start=start, end=end).covers(key):
            self._setitem_subset_of_existing(i, key, value)
            return

    # ``key`` overlaps, or shares a boundary with, an existing entry.
    for i, (start, end) in enumerate(zip(self.interval_starts, self.interval_ends)):
        if (
            key.intersects(oir.UnboundedInterval(start=start, end=end))
            or start == key.end
            or end == key.start
        ):
            self._setitem_partial_overlap(i, key, value)
            return

    # ``key`` is disjoint from everything: insert before the first entry starting after it,
    # or append when there is none (inserting at ``len`` appends).
    position = next(
        (i for i, start in enumerate(self.interval_starts) if start > key.start),
        len(self.interval_starts),
    )
    self.interval_starts.insert(position, key.start)
    self.interval_ends.insert(position, key.end)
    self.values.insert(position, value)
def _setitem_partial_overlap(self, i: int, key: oir.Interval, value: Any) -> None:
    """Store ``key -> value`` where ``key`` partially overlaps or touches entry ``i``.

    Assumes the situation ``__setitem__`` sets up before calling this:
      * entries covered by ``key`` are already deleted;
      * ``key`` is not a subset of any entry;
      * ``i`` is the first entry that intersects ``key`` or shares a boundary with it.

    Entries whose value is the very same object (``is``) as ``value`` are merged with
    ``key``; otherwise the neighbouring entries are trimmed at ``key``'s bounds.
    """
    start = self.interval_starts[i]
    if key.start < start:
        # ``key`` begins before entry ``i``; not covering it, it must end at or inside it.
        if self.values[i] is value:
            # Same value: extend entry ``i`` downwards to absorb ``key``.
            self.interval_starts[i] = key.start
        else:
            # Different value: entry ``i`` now starts where ``key`` ends, and ``key`` is
            # inserted in front of it.
            self.interval_starts[i] = key.end
            self.interval_starts.insert(i, key.start)
            self.interval_ends.insert(i, key.end)
            self.values.insert(i, value)
    else:  # key.end > end
        if self.values[i] is value:
            # Same value: extend entry ``i`` upwards to absorb ``key``.
            self.interval_ends[i] = key.end
            nextidx = i + 1
        else:
            # Different value: cut entry ``i`` off where ``key`` begins, insert ``key`` after it.
            self.interval_ends[i] = key.start
            self.interval_starts.insert(i + 1, key.start)
            self.interval_ends.insert(i + 1, key.end)
            self.values.insert(i + 1, value)
            nextidx = i + 2
        # The entry now holding ``key``'s end is ``nextidx - 1``. The following entry may
        # overlap or touch ``key``'s end; resolve that as well.
        if nextidx < len(self.interval_starts) and (
                key.intersects(
                    oir.Interval(start=self.interval_starts[nextidx], end=self.interval_ends[nextidx]))
                or self.interval_starts[nextidx] == key.end):
            if self.values[nextidx] is value:
                # Same value: fold the following entry into the one holding ``key``.
                self.interval_ends[nextidx - 1] = self.interval_ends[nextidx]
                del self.interval_starts[nextidx]
                del self.interval_ends[nextidx]
                del self.values[nextidx]
            else:
                # Different value: the following entry now starts where ``key`` ends.
                self.interval_starts[nextidx] = key.end
def __getitem__(self, key: oir.Interval) -> List[Any]:
    """Return, in stored order, the values of all entries whose interval intersects ``key``.

    Raises:
        TypeError: if ``key`` is not an ``oir.Interval``.
    """
    if not isinstance(key, oir.Interval):
        raise TypeError("Only OIR intervals supported for keys of IntervalMapping.")
    entries = zip(self.interval_starts, self.interval_ends, self.values)
    return [
        stored
        for lo, hi, stored in entries
        if key.intersects(oir.Interval(start=lo, end=hi))
    ]
def test_same_node_read_write_not_overlap():
    """Build and convert a stencil with two vertical loops over adjacent K ranges.

    The first loop assigns ``other`` to ``field`` on the section from ``AxisBound.start()``
    to ``AxisBound.from_start(1)``. The second loop, on the section from
    ``AxisBound.from_start(1)`` to ``AxisBound.from_start(2)``, assigns ``field`` read at
    K offset -1 back into ``field``. The test passes if building the SDFG and running
    ``convert`` raise nothing.
    """
    lower_loop = VerticalLoopFactory(
        sections__0=VerticalLoopSectionFactory(
            interval=Interval(start=AxisBound.start(), end=AxisBound.from_start(1)),
            horizontal_executions__0__body__0=AssignStmtFactory(
                left__name="field", right__name="other"
            ),
        )
    )
    upper_loop = VerticalLoopFactory(
        sections__0=VerticalLoopSectionFactory(
            interval=Interval(start=AxisBound.from_start(1), end=AxisBound.from_start(2)),
            horizontal_executions__0__body__0=AssignStmtFactory(
                left__name="field", right__name="field", right__offset__k=-1
            ),
        )
    )
    stencil = StencilFactory(vertical_loops=[lower_loop, upper_loop])

    sdfg = OirSDFGBuilder().visit(stencil)
    convert(sdfg, stencil.loc)
def intervals_strategy(draw):
    """Draw a list of 0 to 5 intervals for property-based tests.

    Each interval's two bounds use a level drawn from ``LevelMarker.START`` /
    ``LevelMarker.END`` and an offset drawn from [-5, 5]. The second bound is never
    identical to the first, so every interval has ``start < end`` after ordering with
    ``min``/``max``.

    NOTE(review): the ``draw`` parameter suggests this is wrapped by
    ``hyp_st.composite``; the decorator is not visible here.
    """
    levels = hyp_st.sampled_from([LevelMarker.START, LevelMarker.END])
    count = draw(hyp_st.integers(0, 5))
    drawn = []
    for _ in range(count):
        first_level = draw(levels)
        first_offset = draw(hyp_st.integers(-5, 5))
        first = AxisBound(level=first_level, offset=first_offset)

        second_level = draw(levels)
        # Reject an offset that would make the second bound equal to the first.
        second_offset = draw(
            hyp_st.integers(-5, 5).filter(
                lambda candidate: first_level != second_level or candidate != first_offset
            )
        )
        second = AxisBound(level=second_level, offset=second_offset)

        drawn.append(Interval(start=min(first, second), end=max(first, second)))
    return drawn
class TestIntervalMapping:
    """Unit and property tests for ``IntervalMapping``."""

    @staticmethod
    def assert_consistency(imap: IntervalMapping):
        """Check the structural invariants of ``imap``'s three parallel lists.

        The lists have equal length, entries are strictly ordered and non-overlapping, every
        entry is non-empty, and two touching entries never hold the very same value object
        (those should have been merged).
        """
        assert len(imap.interval_starts) == len(imap.interval_ends)
        assert len(imap.interval_starts) == len(imap.values)
        for i in range(len(imap.interval_starts) - 1):
            assert imap.interval_starts[i] < imap.interval_starts[i + 1]
            assert imap.interval_ends[i] < imap.interval_ends[i + 1]
            assert imap.interval_ends[i] <= imap.interval_starts[i + 1]
            if imap.interval_ends[i] == imap.interval_starts[i + 1]:
                assert imap.values[i] is not imap.values[i + 1]
        for start, end in zip(imap.interval_starts, imap.interval_ends):
            assert start < end

    @pytest.mark.parametrize(
        ["intervals", "starts", "ends"],
        [
            ([], [], []),
            (
                [Interval(start=AxisBound.start(), end=AxisBound.end())],
                [AxisBound.start()],
                [AxisBound.end()],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.end()),
                    Interval(start=AxisBound.from_start(1), end=AxisBound.end()),
                ],
                [AxisBound.start()],
                [AxisBound.end()],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.end()),
                    Interval(start=AxisBound.from_end(-1), end=AxisBound.end()),
                ],
                [AxisBound.start()],
                [AxisBound.end()],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_end(-1)),
                    Interval(start=AxisBound.from_end(-1), end=AxisBound.end()),
                ],
                [AxisBound.start()],
                [AxisBound.end()],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(1)),
                    Interval(start=AxisBound.from_start(1), end=AxisBound.end()),
                ],
                [AxisBound.start()],
                [AxisBound.end()],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(1)),
                    Interval(start=AxisBound.from_start(2), end=AxisBound.end()),
                ],
                [AxisBound.start(), AxisBound.from_start(2)],
                [AxisBound.from_start(1), AxisBound.end()],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(2)),
                    Interval(start=AxisBound.from_start(1), end=AxisBound.end()),
                ],
                [AxisBound.start()],
                [AxisBound.end()],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(1)),
                    Interval(start=AxisBound.from_end(-1), end=AxisBound.end()),
                    Interval(start=AxisBound.from_start(1), end=AxisBound.from_end(-1)),
                ],
                [AxisBound.start()],
                [AxisBound.end()],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(1)),
                    Interval(start=AxisBound.from_end(-1), end=AxisBound.end()),
                    Interval(start=AxisBound.from_start(2), end=AxisBound.from_end(-2)),
                ],
                [AxisBound.start(), AxisBound.from_start(2), AxisBound.from_end(-1)],
                [AxisBound.from_start(1), AxisBound.from_end(-2), AxisBound.end()],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(1)),
                    Interval(start=AxisBound.from_end(-1), end=AxisBound.end()),
                    Interval(start=AxisBound.from_start(1), end=AxisBound.from_end(-2)),
                ],
                [AxisBound.start(), AxisBound.from_end(-1)],
                [AxisBound.from_end(-2), AxisBound.end()],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(1)),
                    Interval(start=AxisBound.from_end(-1), end=AxisBound.end()),
                    Interval(start=AxisBound.from_start(2), end=AxisBound.from_end(-1)),
                ],
                [AxisBound.start(), AxisBound.from_start(2)],
                [AxisBound.from_start(1), AxisBound.end()],
            ),
        ],
    )
    def test_setitem_same_value(self, intervals, starts, ends):
        # if all values are the same (same instance), the behavior is the same as for a IntervalSet.
        imap = IntervalMapping()
        for interval in intervals:
            imap[interval] = 0
            self.assert_consistency(imap)

        assert len(starts) == len(imap.interval_starts)
        for expected, observed in zip(starts, imap.interval_starts):
            assert observed == expected

        assert len(ends) == len(imap.interval_ends)
        for expected, observed in zip(ends, imap.interval_ends):
            assert observed == expected

    @hyp.given(intervals_strategy())
    def test_setitem_same_value_hypothesis(self, intervals):
        """With a single shared value, the final layout must not depend on insertion order."""
        imap = IntervalMapping()
        for interval in intervals:
            imap[interval] = 0
            self.assert_consistency(imap)

        for permutation in itertools.permutations(intervals):
            other_imap = IntervalMapping()
            for interval in permutation:
                other_imap[interval] = 0
                self.assert_consistency(other_imap)
            assert imap.interval_starts == other_imap.interval_starts
            assert imap.interval_ends == other_imap.interval_ends

    @pytest.mark.parametrize(
        ["intervals", "starts", "ends", "values"],
        [
            ([], [], [], []),
            (
                [Interval(start=AxisBound.start(), end=AxisBound.end())],
                [AxisBound.start()],
                [AxisBound.end()],
                [0],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.end()),
                    Interval(start=AxisBound.from_start(1), end=AxisBound.end()),
                ],
                [AxisBound.start(), AxisBound.from_start(1)],
                [AxisBound.from_start(1), AxisBound.end()],
                [0, 1],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.end()),
                    Interval(start=AxisBound.from_end(-1), end=AxisBound.end()),
                ],
                [AxisBound.start(), AxisBound.from_end(-1)],
                [AxisBound.from_end(-1), AxisBound.end()],
                [0, 1],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_end(-1)),
                    Interval(start=AxisBound.from_end(-1), end=AxisBound.end()),
                ],
                [AxisBound.start(), AxisBound.from_end(-1)],
                [AxisBound.from_end(-1), AxisBound.end()],
                [0, 1],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(1)),
                    Interval(start=AxisBound.from_start(1), end=AxisBound.end()),
                ],
                [AxisBound.start(), AxisBound.from_start(1)],
                [AxisBound.from_start(1), AxisBound.end()],
                [0, 1],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(1)),
                    Interval(start=AxisBound.from_start(2), end=AxisBound.end()),
                ],
                [AxisBound.start(), AxisBound.from_start(2)],
                [AxisBound.from_start(1), AxisBound.end()],
                [0, 1],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(2)),
                    Interval(start=AxisBound.from_start(1), end=AxisBound.end()),
                ],
                [AxisBound.start(), AxisBound.from_start(1)],
                [AxisBound.from_start(1), AxisBound.end()],
                [0, 1],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(1)),
                    Interval(start=AxisBound.from_end(-1), end=AxisBound.end()),
                    Interval(start=AxisBound.from_start(1), end=AxisBound.from_end(-1)),
                ],
                [AxisBound.start(), AxisBound.from_start(1), AxisBound.from_end(-1)],
                [AxisBound.from_start(1), AxisBound.from_end(-1), AxisBound.end()],
                [0, 2, 1],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(1)),
                    Interval(start=AxisBound.from_end(-1), end=AxisBound.end()),
                    Interval(start=AxisBound.from_start(2), end=AxisBound.from_end(-2)),
                ],
                [AxisBound.start(), AxisBound.from_start(2), AxisBound.from_end(-1)],
                [AxisBound.from_start(1), AxisBound.from_end(-2), AxisBound.end()],
                [0, 2, 1],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(1)),
                    Interval(start=AxisBound.from_end(-1), end=AxisBound.end()),
                    Interval(start=AxisBound.from_start(1), end=AxisBound.from_end(-2)),
                ],
                [AxisBound.start(), AxisBound.from_start(1), AxisBound.from_end(-1)],
                [AxisBound.from_start(1), AxisBound.from_end(-2), AxisBound.end()],
                [0, 2, 1],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(1)),
                    Interval(start=AxisBound.from_end(-1), end=AxisBound.end()),
                    Interval(start=AxisBound.from_start(2), end=AxisBound.from_end(-1)),
                ],
                [AxisBound.start(), AxisBound.from_start(2), AxisBound.from_end(-1)],
                [AxisBound.from_start(1), AxisBound.from_end(-1), AxisBound.end()],
                [0, 2, 1],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(3)),
                    Interval(start=AxisBound.from_start(1), end=AxisBound.from_start(2)),
                ],
                [AxisBound.start(), AxisBound.from_start(1), AxisBound.from_start(2)],
                [AxisBound.from_start(1), AxisBound.from_start(2), AxisBound.from_start(3)],
                [0, 1, 0],
            ),
            (
                [
                    Interval(start=AxisBound.from_start(1), end=AxisBound.from_start(2)),
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(3)),
                ],
                [AxisBound.start()],
                [AxisBound.from_start(3)],
                [1],
            ),
            (
                [
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(2)),
                    Interval(start=AxisBound.start(), end=AxisBound.from_start(1)),
                ],
                [AxisBound.start(), AxisBound.from_start(1)],
                [AxisBound.from_start(1), AxisBound.from_start(2)],
                [1, 0],
            ),
        ],
    )
    def test_setitem_different_value(self, intervals, starts, ends, values):
        """Insert ``intervals`` with values ``get_instance(0), get_instance(1), ...`` in order."""
        imap = IntervalMapping()
        ctr = 0
        for interval in intervals:
            imap[interval] = get_instance(ctr)
            self.assert_consistency(imap)
            ctr = ctr + 1

        assert len(imap.interval_starts) == len(starts)
        assert len(imap.interval_ends) == len(ends)
        assert len(imap.values) == len(values)
        for i, (start, end, value) in enumerate(
            zip(imap.interval_starts, imap.interval_ends, imap.values)
        ):
            assert start == starts[i]
            assert end == ends[i]
            assert value is get_instance(values[i])

    @hyp.given(intervals_strategy())
    def test_setitem_different_value_hypothesis(self, intervals):
        """The interval set last, with a fresh value, must be present verbatim in any order."""
        ctr = 0
        imap = IntervalMapping()
        for interval in intervals:
            imap[interval] = get_instance(ctr)
            self.assert_consistency(imap)
            ctr += 1

        for permutation in itertools.permutations(intervals):
            other_imap = IntervalMapping()
            for interval in permutation:
                other_imap[interval] = get_instance(ctr)
                self.assert_consistency(other_imap)
                ctr += 1
            for start, end, value in zip(
                other_imap.interval_starts, other_imap.interval_ends, other_imap.values
            ):
                if start == permutation[-1].start:
                    assert end == permutation[-1].end
                    assert value is get_instance(ctr - 1)
                    break
            else:
                # Previously a missing entry passed silently. The last value is a fresh
                # object, so it cannot merge with a neighbour: unless nothing was inserted,
                # an entry starting at the last interval's start must exist.
                assert not permutation, "most recently set interval missing from the mapping"

    @pytest.mark.parametrize(
        ["interval", "values"],
        [
            (Interval(start=AxisBound.from_start(-1), end=AxisBound.from_end(1)), [0, 1]),
            (Interval(start=AxisBound.from_start(-1), end=AxisBound.from_start(3)), [0]),
            (Interval(start=AxisBound.from_start(1), end=AxisBound.from_end(-1)), [0, 1]),
            (Interval(start=AxisBound.from_start(2), end=AxisBound.from_end(-2)), []),
        ],
    )
    def test_getitem_different_value(self, interval, values):
        """Look up ``interval`` in a mapping with two disjoint, differently valued entries."""
        imap = IntervalMapping()
        imap[Interval(start=AxisBound.start(), end=AxisBound.from_start(2))] = get_instance(0)
        imap[Interval(start=AxisBound.from_end(-2), end=AxisBound.end())] = get_instance(1)
        res = imap[interval]

        assert isinstance(res, list)
        assert len(res) == len(values)
        for observed, expected in zip(res, values):
            assert observed is get_instance(expected)
def add_write_after_read_edges(self, node):
    """Add write-after-read edges for ``node``.

    All of the node's accesses are treated as spanning ``Interval.full()``. Returns
    whatever ``_add_write_after_read_edges`` returns.
    """
    whole_range = Interval.full()
    sections = [(whole_range, self._get_access_collection(node))]
    return self._add_write_after_read_edges(node, sections)
def _offset_origin(interval: oir.Interval, origin: oir.AxisBound) -> oir.Interval:
    """Re-express ``interval`` relative to ``origin``.

    ``interval`` is shifted by ``-origin.offset`` only when ``origin`` compares below
    ``AxisBound.start()``; otherwise it is returned unchanged.
    """
    return (
        interval
        if origin >= oir.AxisBound.start()
        else interval.shifted(-origin.offset)
    )
def _set_write(self, name, interval, node):
    """Record ``node`` as the latest writer of ``name`` over ``interval``.

    For fields without a K axis (``self._axes[name][2]`` falsy) the write is recorded over
    ``Interval.full()`` instead of ``interval``.
    """
    if name not in self._recent_write_acc:
        self._recent_write_acc[name] = IntervalMapping()
    has_k_axis = self._axes[name][2]
    target = interval if has_k_axis else Interval.full()
    self._recent_write_acc[name][target] = node
def _get_recent_writes(self, name, interval):
    """Return the recorded writers of ``name`` whose intervals intersect ``interval``.

    An empty mapping is created for ``name`` on first use. For fields without a K axis
    (``self._axes[name][2]`` falsy), ``Interval.full()`` is queried instead of ``interval``.
    """
    if name not in self._recent_write_acc:
        self._recent_write_acc[name] = IntervalMapping()
    has_k_axis = self._axes[name][2]
    query = interval if has_k_axis else Interval.full()
    return self._recent_write_acc[name][query]
def __init__(self) -> None:
    """Start with an interval from the axis start to the axis end and no horizontal executions."""
    whole_axis = Interval(start=AxisBound.start(), end=AxisBound.end())
    self._interval = whole_axis
    self._horizontal_executions: List[HorizontalExecution] = []
def _offset_origin(interval: oir.Interval, origin: Optional[oir.AxisBound]) -> oir.Interval:
    """Re-express ``interval`` relative to ``origin``.

    The interval is shifted by ``-origin.offset`` only when ``origin`` is given and is
    anchored at ``LevelMarker.START``. A missing origin, or one anchored elsewhere, leaves
    the interval unchanged.
    """
    if origin is not None and origin.level == LevelMarker.START:
        return interval.shifted(-origin.offset)
    return interval
@pytest.mark.parametrize(["lhs", "rhs"], GREATER_AXISBOUNDS + EQUAL_AXISBOUNDS) def test_ge_true(self, lhs, rhs): res = lhs >= rhs assert isinstance(res, bool) assert res @pytest.mark.parametrize(["lhs", "rhs"], LESS_AXISBOUNDS) def test_ge_false(self, lhs, rhs): res = lhs >= rhs assert isinstance(res, bool) assert not res COVER_INTERVALS = [ ( Interval(start=AxisBound.start(), end=AxisBound.end()), Interval(start=AxisBound.start(), end=AxisBound.end()), ), ( Interval(start=AxisBound.start(), end=AxisBound.end()), Interval(start=AxisBound.start(), end=AxisBound.from_end(-1)), ), ( Interval(start=AxisBound.start(), end=AxisBound.end()), Interval(start=AxisBound.from_start(1), end=AxisBound.end()), ), ( Interval(start=AxisBound.start(), end=AxisBound.end()), Interval(start=AxisBound.from_start(1), end=AxisBound.from_end(-1)), ), (
def _offset_origin(interval: oir.Interval, origin: oir.AxisBound) -> oir.Interval:
    """Re-express ``interval`` relative to ``origin``.

    The interval is shifted by ``-origin.offset`` only when ``origin`` is anchored at
    ``LevelMarker.START``; otherwise it is returned unchanged.
    """
    anchored_at_start = origin.level == LevelMarker.START
    return interval.shifted(-origin.offset) if anchored_at_start else interval
def get_k_subsets(self, node):
    """Compute the K-axis subset strings read and written by a vertical-loop library node.

    Only fields with a K axis (``self._axes[name][2]`` truthy) are considered.

    * Writes: the bounding interval of all sections writing the field.
    * Reads: the bounding interval of all sections reading it, each shifted by the read's
      K offset. A read whose K offset is ``None`` (dynamic) is taken at offset 0, and its
      section interval is additionally remembered in ``dynamic_read_intervals``.
    * Origin: the field's K origin is the smallest known start bound among its reads and
      writes. A ``None`` start never lowers a known origin.

    Every bounding interval is re-expressed relative to the origin with ``_offset_origin``
    and rendered as ``"<lo>:<hi>"``. A bound is its signed offset, prefixed with ``__K``
    when relative to the axis end. A ``None`` read bound renders as ``0`` (start) or
    ``__K`` (end).

    Returns:
        ``(read_subsets, write_subsets)``: dicts mapping field name to subset string.
    """
    assert isinstance(node, VerticalLoopLibraryNode)
    write_intervals = {}
    read_intervals = {}
    k_origins = {}
    dynamic_read_intervals = {}

    def _lower_origin(name, bound):
        # One None-aware rule for both reads and writes. The old write path used a plain
        # ``min`` and raised TypeError once a read had stored a ``None`` origin.
        if name not in k_origins:
            k_origins[name] = bound
        elif bound is not None:
            current = k_origins[name]
            k_origins[name] = bound if current is None else min(current, bound)

    def _format_bound(bound, unbounded):
        # Render one K bound; ``unbounded`` is used when the bound is ``None``.
        if bound is None:
            return unbounded
        return "{}{:+d}".format("__K" if bound.level == LevelMarker.END else "", bound.offset)

    for interval, sdfg in node.sections:
        collection = self._get_access_collection(sdfg)
        for name, offsets in collection.read_offsets().items():
            if not self._axes[name][2]:
                continue
            for offset in offsets:
                k_offset = 0 if offset[2] is None else offset[2]
                read_interval = interval.shifted(k_offset)
                if offset[2] is None:
                    dynamic_read_intervals.setdefault(name, interval)
                    dynamic_read_intervals[name] = Interval(
                        start=min(dynamic_read_intervals[name].start, interval.start),
                        end=max(dynamic_read_intervals[name].end, interval.end),
                    )
                read_intervals.setdefault(name, read_interval)
                # A ``None`` bound means "unbounded" and absorbs any finite bound.
                start = (
                    min(read_intervals[name].start, read_interval.start)
                    if read_intervals[name].start is not None and read_interval.start is not None
                    else None
                )
                end = (
                    max(read_intervals[name].end, read_interval.end)
                    if read_intervals[name].end is not None and read_interval.end is not None
                    else None
                )
                read_intervals[name] = oir.UnboundedInterval(start=start, end=end)
                _lower_origin(name, read_interval.start)
        for name in collection.write_fields():
            if self._axes[name][2]:
                write_intervals.setdefault(name, interval)
                write_intervals[name] = Interval(
                    start=min(write_intervals[name].start, interval.start),
                    end=max(write_intervals[name].end, interval.end),
                )
                _lower_origin(name, interval.start)

    write_subsets = {}
    for name, interval in write_intervals.items():
        res_interval = _offset_origin(interval, k_origins[name])
        write_subsets[name] = "{}:{}".format(
            _format_bound(res_interval.start, "0"), _format_bound(res_interval.end, "__K")
        )

    read_subsets = {}
    for name, interval in read_intervals.items():
        res_interval = _offset_origin(interval, k_origins[name])
        start, end = res_interval.start, res_interval.end
        if name in dynamic_read_intervals:
            dyn_interval = dynamic_read_intervals[name]
            # Widen only finite bounds. An unbounded (None) side already covers everything;
            # the old code called min/max on None here and crashed before its None checks.
            # NOTE(review): ``dyn_interval`` is not passed through ``_offset_origin`` while
            # ``res_interval`` is, so the two may be in different K frames - confirm intent.
            if start is not None:
                start = min(start, dyn_interval.start)
            if end is not None:
                end = max(end, dyn_interval.end)
        read_subsets[name] = "{}:{}".format(
            _format_bound(start, "0"), _format_bound(end, "__K")
        )
    return read_subsets, write_subsets