Exemple #1
0
def _compute_relative_interval(
    extent: Tuple[int, int], interval: common.HorizontalInterval
) -> Optional[common.HorizontalInterval]:
    """Re-express ``interval`` relative to the compute ``extent``.

    Args:
        extent: pair whose first element is subtracted from offsets clamped with
            ``max(0, ...)`` and whose second element is subtracted from offsets
            clamped with ``min(0, ...)``.
        interval: horizontal interval; either of its bounds may be ``None``.

    Returns:
        ``None`` when ``_overlap_along_axis(extent, interval)`` is falsy. Otherwise a
        new ``HorizontalInterval`` with these properties:

        * Each bound keeps its level.
        * A missing start bound becomes ``LevelMarker.START`` with offset 0.
        * A missing end bound becomes ``LevelMarker.END`` with offset 0.
        * Each present bound's offset is shifted by the extent and clamped
          (see ``_offset``).
    """

    def _offset(
        extent: Tuple[int, int],
        bound: Optional[common.AxisBound],
        start: bool = True,
    ) -> int:
        """Return the clamped offset of ``bound`` relative to ``extent``.

        ``start`` selects whether ``bound`` is the interval's start or end bound.
        A missing bound yields 0.
        """
        if bound is None:
            return 0
        if start:
            # Start bound: START-anchored offsets are clamped to be >= 0,
            # any other level is clamped to be <= 0.
            if bound.level == common.LevelMarker.START:
                return max(0, bound.offset - extent[0])
            return min(0, bound.offset - extent[1])
        # End bound: END-anchored offsets are clamped to be <= 0,
        # any other level is clamped to be >= 0.
        if bound.level == common.LevelMarker.END:
            return min(0, bound.offset - extent[1])
        return max(0, bound.offset - extent[0])

    if not _overlap_along_axis(extent, interval):
        return None

    start_level = (
        interval.start.level if interval.start is not None else common.LevelMarker.START
    )
    end_level = (
        interval.end.level if interval.end is not None else common.LevelMarker.END
    )
    return common.HorizontalInterval(
        start=common.AxisBound(
            level=start_level,
            offset=_offset(extent, interval.start, start=True),
        ),
        end=common.AxisBound(
            level=end_level,
            offset=_offset(extent, interval.end, start=False),
        ),
    )
Exemple #2
0
    def _split_entry_level(
        loop_order: common.LoopOrder,
        section: oir.VerticalLoopSection,
        new_symbol_name: Callable[[str], str],
    ) -> Tuple[oir.VerticalLoopSection, oir.VerticalLoopSection]:
        """Split the first-executed vertical level off ``section``.

        For a ``FORWARD`` loop, the entry interval runs from ``section.interval.start``
        to the same level at offset + 1. For a ``BACKWARD`` loop, it runs from the
        level of ``section.interval.end`` at offset - 1 up to ``section.interval.end``.
        The remaining interval covers the rest of the section.

        Args:
            loop_order: ``FORWARD`` or ``BACKWARD``; any other value fails the
                assertion.
            section: loop section to split.
            new_symbol_name: called once per declaration found in ``section``. Maps
                the declared name to the name used in the entry section, presumably
                to avoid clashing with the unchanged declarations kept in the rest
                section.

        Returns:
            ``(entry_section, rest_section)``, both carrying ``section.loc``. Only the
            entry section's horizontal executions have declared scalars renamed; the
            rest section reuses ``section.horizontal_executions`` as-is.
        """
        assert loop_order in (common.LoopOrder.FORWARD, common.LoopOrder.BACKWARD)
        lower, upper = section.interval.start, section.interval.end

        if loop_order == common.LoopOrder.FORWARD:
            split = common.AxisBound(level=lower.level, offset=lower.offset + 1)
            entry_interval, rest_interval = (
                oir.Interval(start=lower, end=split),
                oir.Interval(start=split, end=upper),
            )
        else:
            split = common.AxisBound(level=upper.level, offset=upper.offset - 1)
            entry_interval, rest_interval = (
                oir.Interval(start=split, end=upper),
                oir.Interval(start=lower, end=split),
            )

        declarations = list(section.iter_tree().if_isinstance(oir.Decl))
        renames = {decl.name: new_symbol_name(decl.name) for decl in declarations}

        class _RenameDeclaredSymbols(NodeTranslator):
            # Rewrites accesses to declared scalars; other accesses pass through.
            def visit_ScalarAccess(self, node: oir.ScalarAccess) -> oir.ScalarAccess:
                if node.name in renames:
                    return oir.ScalarAccess(name=renames[node.name], dtype=node.dtype)
                return node

            # Every LocalScalar visited here was collected above, so the lookup
            # is expected to succeed.
            def visit_LocalScalar(self, node: oir.LocalScalar) -> oir.LocalScalar:
                return oir.LocalScalar(name=renames[node.name], dtype=node.dtype)

        entry_section = oir.VerticalLoopSection(
            interval=entry_interval,
            horizontal_executions=_RenameDeclaredSymbols().visit(
                section.horizontal_executions
            ),
            loc=section.loc,
        )
        rest_section = oir.VerticalLoopSection(
            interval=rest_interval,
            horizontal_executions=section.horizontal_executions,
            loc=section.loc,
        )
        return entry_section, rest_section
Exemple #3
0
 def make_bound_or_level(bound: AxisBound, level) -> Optional[common.AxisBound]:
     """Translate ``bound`` into a ``common.AxisBound``, or ``None`` for an open side.

     Args:
         bound: bound whose ``level`` is mapped through
             ``self.GT4PY_LEVELMARKER_TO_GTIR_LEVELMARKER`` and whose ``offset`` is
             copied unchanged.
         level: which side the bound belongs to. The sentinel check applies only
             when it equals ``LevelMarker.START`` or ``LevelMarker.END``.

     Returns:
         ``None`` when ``level`` is ``START`` with ``offset <= -10000`` or ``END``
         with ``offset >= 10000``; otherwise the translated ``common.AxisBound``.
     """
     # NOTE(review): +/-10000 looks like a sentinel for an unbounded side -- confirm.
     open_at_start = level == LevelMarker.START and bound.offset <= -10000
     open_at_end = level == LevelMarker.END and bound.offset >= 10000
     if open_at_start or open_at_end:
         return None

     # ``self`` is not a parameter of this function; it comes from an enclosing scope.
     gtir_level = self.GT4PY_LEVELMARKER_TO_GTIR_LEVELMARKER[bound.level]
     return common.AxisBound(level=gtir_level, offset=bound.offset)
Exemple #4
0
def test_HorizontalInterval():
    """Check that ``common.HorizontalInterval`` accepts ordered bounds and rejects reversed ones."""
    # Valid: start (START, -1) comes before end (START, 0); must not raise.
    common.HorizontalInterval(
        start=common.AxisBound(level=common.LevelMarker.START, offset=-1),
        end=common.AxisBound(level=common.LevelMarker.START, offset=0),
    )

    # Each invalid construction needs its own ``pytest.raises`` context. The first
    # exception raised inside a context exits it, so any statement after it in the
    # same context would never run.
    with pytest.raises(ValidationError):
        # Start anchored at END, end anchored at START.
        common.HorizontalInterval(
            start=common.AxisBound(level=common.LevelMarker.END, offset=0),
            end=common.AxisBound(level=common.LevelMarker.START, offset=-1),
        )

    with pytest.raises(ValidationError):
        # Same level, but the start offset is greater than the end offset.
        common.HorizontalInterval(
            start=common.AxisBound(level=common.LevelMarker.START, offset=0),
            end=common.AxisBound(level=common.LevelMarker.START, offset=-1),
        )