Ejemplo n.º 1
0
 def domain() -> "CartesianIterationSpace":
     """Return the iteration space spanning the whole horizontal domain.

     Both the I and the J interval are built from ``oir.AxisBound.start()``
     to ``oir.AxisBound.end()``; each axis gets its own ``oir.Interval``.
     """

     def _full_axis() -> oir.Interval:
         return oir.Interval(start=oir.AxisBound.start(), end=oir.AxisBound.end())

     return CartesianIterationSpace(i_interval=_full_axis(), j_interval=_full_axis())
Ejemplo n.º 2
0
 def __or__(self, other: "CartesianIterationSpace") -> "CartesianIterationSpace":
     """Combine two iteration spaces axis by axis.

     For both the I and the J axis the result starts at the smaller of the two
     start offsets and ends at the larger of the two end offsets. The bound
     levels of the operands are not inspected: the result is always anchored
     with ``AxisBound.from_start`` / ``AxisBound.from_end``.
     """

     def _hull(lhs: oir.Interval, rhs: oir.Interval) -> oir.Interval:
         lowest_start = min(lhs.start.offset, rhs.start.offset)
         highest_end = max(lhs.end.offset, rhs.end.offset)
         return oir.Interval(
             start=oir.AxisBound.from_start(lowest_start),
             end=oir.AxisBound.from_end(highest_end),
         )

     return CartesianIterationSpace(
         i_interval=_hull(self.i_interval, other.i_interval),
         j_interval=_hull(self.j_interval, other.j_interval),
     )
Ejemplo n.º 3
0
    def _split_entry_level(
        loop_order: common.LoopOrder,
        section: oir.VerticalLoopSection,
        new_symbol_name: Callable[[str], str],
    ) -> Tuple[oir.VerticalLoopSection, oir.VerticalLoopSection]:
        """Split the entry level of a loop section.

        The entry level is the single level ``[start, start + 1)`` of the
        section interval for a forward loop, or ``[end - 1, end)`` for a
        backward loop.

        Args:
            loop_order: forward or backward order.
            section: loop section to split.
            new_symbol_name: maps an existing declaration name to a new name.
                It is used to rename every declaration found in ``section`` (and
                the scalar accesses referring to them) in the entry copy, so the
                entry and rest sections do not share symbol names.

        Returns:
            Two loop sections ``(entry, rest)``. ``entry`` covers only the entry
            level and holds renamed copies of the horizontal executions;
            ``rest`` covers the remaining levels and reuses the original
            horizontal executions unchanged.
        """
        assert loop_order in (common.LoopOrder.FORWARD, common.LoopOrder.BACKWARD)
        if loop_order == common.LoopOrder.FORWARD:
            bound = common.AxisBound(
                level=section.interval.start.level,
                offset=section.interval.start.offset + 1,
            )
            entry_interval = oir.Interval(start=section.interval.start, end=bound)
            rest_interval = oir.Interval(start=bound, end=section.interval.end)
        else:
            bound = common.AxisBound(
                level=section.interval.end.level,
                offset=section.interval.end.offset - 1,
            )
            entry_interval = oir.Interval(start=bound, end=section.interval.end)
            rest_interval = oir.Interval(start=section.interval.start, end=bound)

        # Fresh names for every declaration in the section. Only the entry copy
        # is rewritten, so both sections end up with distinct symbols.
        decls_map = {
            decl.name: new_symbol_name(decl.name)
            for decl in section.iter_tree().if_isinstance(oir.Decl)
        }

        class FixSymbolNameClashes(NodeTranslator):
            def visit_ScalarAccess(self, node: oir.ScalarAccess) -> oir.ScalarAccess:
                if node.name not in decls_map:
                    return node
                # Keep the source location of the access when renaming it.
                return oir.ScalarAccess(
                    name=decls_map[node.name], dtype=node.dtype, loc=node.loc
                )

            def visit_LocalScalar(self, node: oir.LocalScalar) -> oir.LocalScalar:
                # Assumes every LocalScalar is an oir.Decl and was therefore
                # collected into decls_map above. TODO(review): confirm.
                return oir.LocalScalar(
                    name=decls_map[node.name], dtype=node.dtype, loc=node.loc
                )

        return (
            oir.VerticalLoopSection(
                interval=entry_interval,
                horizontal_executions=FixSymbolNameClashes().visit(
                    section.horizontal_executions
                ),
                loc=section.loc,
            ),
            oir.VerticalLoopSection(
                interval=rest_interval,
                horizontal_executions=section.horizontal_executions,
                loc=section.loc,
            ),
        )
Ejemplo n.º 4
0
    def from_offset(offset: CartesianOffset) -> "CartesianIterationSpace":
        """Build the iteration space needed to reach a single Cartesian offset.

        A negative component moves the start bound (``min(0, component)``), a
        positive component moves the end bound (``max(0, component)``); a zero
        component leaves both bounds at offset 0.
        """
        i_lo, i_hi = min(0, offset.i), max(0, offset.i)
        j_lo, j_hi = min(0, offset.j), max(0, offset.j)
        i_axis = oir.Interval(
            start=oir.AxisBound.from_start(i_lo),
            end=oir.AxisBound.from_end(i_hi),
        )
        j_axis = oir.Interval(
            start=oir.AxisBound.from_start(j_lo),
            end=oir.AxisBound.from_end(j_hi),
        )
        return CartesianIterationSpace(i_interval=i_axis, j_interval=j_axis)
Ejemplo n.º 5
0
    def from_offset(
        offset: Union[CartesianOffset, oir.VariableKOffset]
    ) -> "CartesianIterationSpace":
        """Build the iteration space needed to reach a single offset.

        Only the ``"i"`` and ``"j"`` entries of ``offset.to_dict()`` are used.
        A negative component moves the start bound and a positive one moves the
        end bound.
        """

        def _axis_interval(component: int) -> oir.Interval:
            return oir.Interval(
                start=oir.AxisBound.from_start(min(0, component)),
                end=oir.AxisBound.from_end(max(0, component)),
            )

        components = offset.to_dict()
        return CartesianIterationSpace(
            i_interval=_axis_interval(components["i"]),
            j_interval=_axis_interval(components["j"]),
        )
Ejemplo n.º 6
0
    def __setitem__(self, key: oir.Interval, value: Any) -> None:
        """Associate ``value`` with the interval ``key``.

        Stored intervals that ``key`` fully covers are removed first. Then, in
        order of preference:

        * if a stored interval covers ``key``, that interval is updated via
          ``_setitem_subset_of_existing``;
        * otherwise the first stored interval that ``key`` intersects or
          touches (shares an end point with) is updated via
          ``_setitem_partial_overlap``;
        * otherwise ``key`` is inserted before the first stored interval that
          starts after it (or appended), so a sorted ``interval_starts`` stays
          sorted.

        Raises:
            TypeError: if ``key`` is not an ``oir.Interval``.
        """
        if not isinstance(key, oir.Interval):
            raise TypeError("Only OIR intervals supported for keys of IntervalMapping.")

        # Drop every stored interval that the new key completely covers. Delete
        # from the back so the remaining indices stay valid.
        covered = [
            i
            for i, (start, end) in enumerate(zip(self.interval_starts, self.interval_ends))
            if key.covers(oir.Interval(start=start, end=end))
        ]
        for i in reversed(covered):
            del self.interval_starts[i]
            del self.interval_ends[i]
            del self.values[i]

        if not self.interval_starts:
            self.interval_starts.append(key.start)
            self.interval_ends.append(key.end)
            self.values.append(value)
            return

        # key lies entirely inside an existing interval.
        for i, (start, end) in enumerate(zip(self.interval_starts, self.interval_ends)):
            if oir.Interval(start=start, end=end).covers(key):
                self._setitem_subset_of_existing(i, key, value)
                return

        # key overlaps, or is adjacent to, the first such stored interval.
        for i, (start, end) in enumerate(zip(self.interval_starts, self.interval_ends)):
            if (
                key.intersects(oir.Interval(start=start, end=end))
                or start == key.end
                or end == key.start
            ):
                self._setitem_partial_overlap(i, key, value)
                return

        # key is disjoint from all stored intervals: insert it in start order.
        for i, start in enumerate(self.interval_starts):
            if start > key.start:
                self.interval_starts.insert(i, key.start)
                self.interval_ends.insert(i, key.end)
                self.values.insert(i, value)
                return
        self.interval_starts.append(key.start)
        self.interval_ends.append(key.end)
        self.values.append(value)
Ejemplo n.º 7
0
 def _setitem_partial_overlap(self, i: int, key: oir.Interval,
                              value: Any) -> None:
     """Store ``key`` -> ``value`` next to stored interval ``i``, which it partially overlaps.

     In the visible ``__setitem__`` this is called for the first stored
     interval ``i`` that ``key`` intersects or touches (shares an end point
     with). By then, intervals fully covered by ``key`` have been removed and no
     stored interval covers ``key``. Values are compared by identity (``is``),
     so an equal-but-distinct value is treated as a different value.

     Args:
         i: index of the stored interval that ``key`` overlaps or touches.
         key: interval being assigned.
         value: value to associate with ``key``.
     """
     start = self.interval_starts[i]
     if key.start < start:
         # key begins before interval i; presumably it ends inside interval i
         # or exactly at its start, since it does not cover interval i.
         if self.values[i] is value:
             # Same value object: grow interval i leftwards to absorb key.
             self.interval_starts[i] = key.start
         else:
             # Different value: interval i now begins where key ends, and key
             # is inserted as its own entry just before it (at index i).
             self.interval_starts[i] = key.end
             self.interval_starts.insert(i, key.start)
             self.interval_ends.insert(i, key.end)
             self.values.insert(i, value)
     else:  # key starts at/after interval i's start; presumably key.end > end
         if self.values[i] is value:
             # Same value object: grow interval i rightwards to absorb key.
             self.interval_ends[i] = key.end
             nextidx = i + 1
         else:
             # Different value: cut interval i off where key starts and insert
             # key as its own entry right after it (at index i + 1).
             self.interval_ends[i] = key.start
             self.interval_starts.insert(i + 1, key.start)
             self.interval_ends.insert(i + 1, key.end)
             self.values.insert(i + 1, value)
             nextidx = i + 2
         # In both cases the entry at nextidx - 1 now ends at key.end. key may
         # also reach into, or touch, the following stored interval (nextidx).
         if nextidx < len(self.interval_starts) and (
                 key.intersects(
                     oir.Interval(start=self.interval_starts[nextidx],
                                  end=self.interval_ends[nextidx]))
                 or self.interval_starts[nextidx] == key.end):
             if self.values[nextidx] is value:
                 # Same value object: merge the following interval into the
                 # entry ending at key.end, then drop it.
                 self.interval_ends[nextidx -
                                    1] = self.interval_ends[nextidx]
                 del self.interval_starts[nextidx]
                 del self.interval_ends[nextidx]
                 del self.values[nextidx]
             else:
                 # Different value: the following interval now starts where
                 # key ends.
                 self.interval_starts[nextidx] = key.end
Ejemplo n.º 8
0
    def __getitem__(self, key: oir.Interval) -> List[Any]:
        """Return the values of every stored interval that intersects ``key``.

        Values are returned in the order in which the intervals are stored.

        Raises:
            TypeError: if ``key`` is not an ``oir.Interval``.
        """
        if not isinstance(key, oir.Interval):
            raise TypeError("Only OIR intervals supported for keys of IntervalMapping.")

        entries = zip(self.interval_starts, self.interval_ends, self.values)
        return [
            stored_value
            for lo, hi, stored_value in entries
            if key.intersects(oir.Interval(start=lo, end=hi))
        ]
Ejemplo n.º 9
0
 def compose(self, other: "CartesianIterationSpace") -> "CartesianIterationSpace":
     """Compose two iteration spaces by adding their bound offsets per axis.

     For both the I and the J axis, the start offsets are summed and the end
     offsets are summed; the result is anchored with ``AxisBound.from_start`` /
     ``AxisBound.from_end``.
     """

     def _summed(lhs: oir.Interval, rhs: oir.Interval) -> oir.Interval:
         start_offset = lhs.start.offset + rhs.start.offset
         end_offset = lhs.end.offset + rhs.end.offset
         return oir.Interval(
             start=oir.AxisBound.from_start(start_offset),
             end=oir.AxisBound.from_end(end_offset),
         )

     return CartesianIterationSpace(
         i_interval=_summed(self.i_interval, other.i_interval),
         j_interval=_summed(self.j_interval, other.j_interval),
     )
Ejemplo n.º 10
0
 def build(self):
     """Create an ``oir.VerticalLoop`` from the values stored on this object.

     The loop interval is ``[self._start, self._end]`` as an ``oir.Interval``;
     horizontal executions, loop order and declarations are passed through
     unchanged.
     """
     loop_interval = oir.Interval(start=self._start, end=self._end)
     return oir.VerticalLoop(
         interval=loop_interval,
         horizontal_executions=self._horizontal_executions,
         loop_order=self._loop_order,
         declarations=self._declarations,
     )
Ejemplo n.º 11
0
 def visit_Interval(self, node: gtir.Interval) -> oir.Interval:
     """Lower a GTIR interval to an OIR interval, keeping its source location."""
     lowered_start = self.visit(node.start)
     lowered_end = self.visit(node.end)
     return oir.Interval(start=lowered_start, end=lowered_end, loc=node.loc)
Ejemplo n.º 12
0
 def visit_Interval(self, node: gtir.Interval, **kwargs: Any) -> oir.Interval:
     """Lower a GTIR interval to an OIR interval.

     The source location of ``node`` is carried over to the lowered interval
     so it is not lost during translation.

     Args:
         node: GTIR interval to translate.
         **kwargs: accepted for the visitor call signature; not forwarded to
             the child visits.

     Returns:
         The equivalent ``oir.Interval``.
     """
     return oir.Interval(
         start=self.visit(node.start),
         end=self.visit(node.end),
         loc=node.loc,
     )