def _compute_relative_interval(
    extent: Tuple[int, int], interval: common.HorizontalInterval
) -> Optional[common.HorizontalInterval]:
    """Re-express ``interval`` with offsets taken relative to ``extent``.

    Args:
        extent: Pair of offsets; ``extent[0]`` is subtracted from bounds
            handled on the low side, ``extent[1]`` from bounds handled on
            the high side.
        interval: Horizontal interval whose ``start``/``end`` may be unset.

    Returns:
        A new ``HorizontalInterval`` with clamped offsets, or ``None`` when
        ``_overlap_along_axis(extent, interval)`` is falsy. An unset start
        becomes ``START`` with offset 0, an unset end becomes ``END`` with
        offset 0.
    """

    def _offset(
        extent: Tuple[int, int], bound: Optional[common.AxisBound], start: bool = True
    ) -> int:
        """Return the clamped integer offset of ``bound`` relative to ``extent``."""
        if not bound:
            return 0

        # For a start bound only START is measured from extent[0]; for an end
        # bound everything except END is measured from extent[0].
        if start:
            from_low_side = bound.level == common.LevelMarker.START
        else:
            from_low_side = bound.level != common.LevelMarker.END

        if from_low_side:
            # Never report a negative offset on the low side.
            return max(0, bound.offset - extent[0])
        # Never report a positive offset on the high side.
        return min(0, bound.offset - extent[1])

    if not _overlap_along_axis(extent, interval):
        return None

    return common.HorizontalInterval(
        start=common.AxisBound(
            level=interval.start.level if interval.start else common.LevelMarker.START,
            offset=_offset(extent, interval.start, start=True),
        ),
        end=common.AxisBound(
            level=interval.end.level if interval.end else common.LevelMarker.END,
            offset=_offset(extent, interval.end, start=False),
        ),
    )
def _split_entry_level(
    loop_order: common.LoopOrder,
    section: oir.VerticalLoopSection,
    new_symbol_name: Callable[[str], str],
) -> Tuple[oir.VerticalLoopSection, oir.VerticalLoopSection]:
    """Split a loop section into its one-level entry part and the remainder.

    Args:
        loop_order: Iteration direction; must be FORWARD or BACKWARD.
        section: Loop section to split.
        new_symbol_name: Callback that maps the name of a declaration found
            in ``section`` to the name used for it in the entry section.

    Returns:
        ``(entry, rest)``. For FORWARD order the entry covers
        ``[start, start + 1)``; for BACKWARD order it covers
        ``[end - 1, end)``. The entry section carries the renamed symbols,
        the rest section reuses the original horizontal executions.
    """
    assert loop_order in (common.LoopOrder.FORWARD, common.LoopOrder.BACKWARD)

    whole = section.interval
    if loop_order == common.LoopOrder.FORWARD:
        # The entry level is the first level of the interval.
        cut = common.AxisBound(level=whole.start.level, offset=whole.start.offset + 1)
        entry_interval = oir.Interval(start=whole.start, end=cut)
        rest_interval = oir.Interval(start=cut, end=whole.end)
    else:
        # Iterating backward, the entry level is the last level.
        cut = common.AxisBound(level=whole.end.level, offset=whole.end.offset - 1)
        entry_interval = oir.Interval(start=cut, end=whole.end)
        rest_interval = oir.Interval(start=whole.start, end=cut)

    # Collect the declared names first, then ask for one replacement per name.
    declared_names = [decl.name for decl in section.iter_tree().if_isinstance(oir.Decl)]
    decls_map = {name: new_symbol_name(name) for name in declared_names}

    class FixSymbolNameClashes(NodeTranslator):
        """Rename declarations and the scalar accesses that refer to them."""

        def visit_ScalarAccess(self, node: oir.ScalarAccess) -> oir.ScalarAccess:
            if node.name in decls_map:
                return oir.ScalarAccess(name=decls_map[node.name], dtype=node.dtype)
            # Accesses to symbols declared outside this section stay untouched.
            return node

        def visit_LocalScalar(self, node: oir.LocalScalar) -> oir.LocalScalar:
            return oir.LocalScalar(name=decls_map[node.name], dtype=node.dtype)

    entry_section = oir.VerticalLoopSection(
        interval=entry_interval,
        horizontal_executions=FixSymbolNameClashes().visit(section.horizontal_executions),
        loc=section.loc,
    )
    rest_section = oir.VerticalLoopSection(
        interval=rest_interval,
        horizontal_executions=section.horizontal_executions,
        loc=section.loc,
    )
    return entry_section, rest_section
def make_bound_or_level(bound: AxisBound, level) -> Optional[common.AxisBound]:
    """Convert ``bound`` to a ``common.AxisBound``, or return ``None``.

    ``None`` is returned when ``level`` is START and the offset is at or
    below -10000, or when ``level`` is END and the offset is at or above
    10000. Otherwise the bound's level is translated through
    ``self.GT4PY_LEVELMARKER_TO_GTIR_LEVELMARKER`` (``self`` comes from the
    enclosing scope) and the offset is kept as is.
    """
    # NOTE(review): +/-10000 presumably stands for "no bound on this side";
    # confirm against whatever produces ``bound``.
    if level == LevelMarker.START and bound.offset <= -10000:
        return None
    if level == LevelMarker.END and bound.offset >= 10000:
        return None

    gtir_level = self.GT4PY_LEVELMARKER_TO_GTIR_LEVELMARKER[bound.level]
    return common.AxisBound(level=gtir_level, offset=bound.offset)
def test_HorizontalInterval():
    """Check that HorizontalInterval accepts ordered bounds and rejects reversed ones."""
    # Valid: start lies before end on the same level.
    common.HorizontalInterval(
        start=common.AxisBound(level=common.LevelMarker.START, offset=-1),
        end=common.AxisBound(level=common.LevelMarker.START, offset=0),
    )

    # Each invalid interval gets its own ``raises`` block: inside a shared
    # block the first exception ends the block, so any later statement would
    # never run and would go unchecked.
    with pytest.raises(ValidationError):
        common.HorizontalInterval(
            start=common.AxisBound(level=common.LevelMarker.END, offset=0),
            end=common.AxisBound(level=common.LevelMarker.START, offset=-1),
        )

    with pytest.raises(ValidationError):
        common.HorizontalInterval(
            start=common.AxisBound(level=common.LevelMarker.START, offset=0),
            end=common.AxisBound(level=common.LevelMarker.START, offset=-1),
        )