def domain() -> "CartesianIterationSpace":
    """Return the iteration space spanning the whole horizontal domain.

    Both the I and the J interval run from ``oir.AxisBound.start()`` to
    ``oir.AxisBound.end()``.
    """

    def _whole_axis() -> oir.Interval:
        # A fresh interval per axis (same as the original, which built two).
        return oir.Interval(start=oir.AxisBound.start(), end=oir.AxisBound.end())

    return CartesianIterationSpace(i_interval=_whole_axis(), j_interval=_whole_axis())
def __or__(self, other: "CartesianIterationSpace") -> "CartesianIterationSpace":
    """Combine two iteration spaces into their common bounding extent.

    For each horizontal axis the result starts at the smaller of the two start
    offsets (measured from the axis start) and ends at the larger of the two
    end offsets (measured from the axis end).

    NOTE(review): the operands' bound levels are not consulted; offsets are
    assumed to be relative to start/end respectively, as produced by
    ``AxisBound.from_start`` / ``AxisBound.from_end``.
    """

    def _hull(lhs: oir.Interval, rhs: oir.Interval) -> oir.Interval:
        lowest_start = min(lhs.start.offset, rhs.start.offset)
        highest_end = max(lhs.end.offset, rhs.end.offset)
        return oir.Interval(
            start=oir.AxisBound.from_start(lowest_start),
            end=oir.AxisBound.from_end(highest_end),
        )

    return CartesianIterationSpace(
        i_interval=_hull(self.i_interval, other.i_interval),
        j_interval=_hull(self.j_interval, other.j_interval),
    )
def _split_entry_level(
    loop_order: common.LoopOrder,
    section: oir.VerticalLoopSection,
    new_symbol_name: Callable[[str], str],
) -> Tuple[oir.VerticalLoopSection, oir.VerticalLoopSection]:
    """Split the entry level of a loop section.

    The first level the loop visits is peeled off into its own single-level
    section:
    - for a forward loop this is the level at ``section.interval.start``;
    - for a backward loop it is the last level, just below ``section.interval.end``.

    Every declaration found in ``section`` is renamed in the peeled (entry)
    section so that the two resulting sections do not clash on symbol names.

    Args:
        loop_order: forward or backward order (any other order fails the assert).
        section: loop section to split.
        new_symbol_name: callable mapping an existing declaration name to the
            new name used for it inside the entry section.

    Returns:
        Two loop sections ``(entry_section, rest_section)``. The entry section
        covers one level and uses the renamed symbols. The rest section covers
        the remaining levels and keeps the original horizontal executions.
    """
    assert loop_order in (common.LoopOrder.FORWARD, common.LoopOrder.BACKWARD)
    if loop_order == common.LoopOrder.FORWARD:
        # Entry level is the first level: [start, start + 1).
        bound = common.AxisBound(
            level=section.interval.start.level, offset=section.interval.start.offset + 1
        )
        entry_interval = oir.Interval(start=section.interval.start, end=bound)
        rest_interval = oir.Interval(start=bound, end=section.interval.end)
    else:
        # Entry level is the last level: [end - 1, end).
        bound = common.AxisBound(
            level=section.interval.end.level, offset=section.interval.end.offset - 1
        )
        entry_interval = oir.Interval(start=bound, end=section.interval.end)
        rest_interval = oir.Interval(start=section.interval.start, end=bound)

    decls_map = {
        decl.name: new_symbol_name(decl.name)
        for decl in section.iter_tree().if_isinstance(oir.Decl)
    }

    class FixSymbolNameClashes(NodeTranslator):
        # Rewrites declarations and scalar accesses to their renamed symbols.
        # Source locations are preserved so later diagnostics still point at
        # the original code.
        def visit_ScalarAccess(self, node: oir.ScalarAccess) -> oir.ScalarAccess:
            if node.name not in decls_map:
                # Not declared in this section (e.g. a parameter): keep as is.
                return node
            return oir.ScalarAccess(name=decls_map[node.name], dtype=node.dtype, loc=node.loc)

        def visit_LocalScalar(self, node: oir.LocalScalar) -> oir.LocalScalar:
            # Every LocalScalar in the section is an oir.Decl, so it is in decls_map.
            return oir.LocalScalar(name=decls_map[node.name], dtype=node.dtype, loc=node.loc)

    return (
        oir.VerticalLoopSection(
            interval=entry_interval,
            horizontal_executions=FixSymbolNameClashes().visit(section.horizontal_executions),
            loc=section.loc,
        ),
        oir.VerticalLoopSection(
            interval=rest_interval,
            horizontal_executions=section.horizontal_executions,
            loc=section.loc,
        ),
    )
def from_offset(offset: CartesianOffset) -> "CartesianIterationSpace":
    """Build the iteration-space extent implied by a single Cartesian offset.

    For each of I and J:
    - the start bound is ``min(0, component)`` from the axis start;
    - the end bound is ``max(0, component)`` from the axis end.

    A negative component therefore widens the start, a positive one widens the
    end, and zero leaves both sides at offset 0.
    """

    def _axis_extent(component: int) -> oir.Interval:
        return oir.Interval(
            start=oir.AxisBound.from_start(min(0, component)),
            end=oir.AxisBound.from_end(max(0, component)),
        )

    return CartesianIterationSpace(
        i_interval=_axis_extent(offset.i),
        j_interval=_axis_extent(offset.j),
    )
def from_offset(
    offset: Union[CartesianOffset, oir.VariableKOffset]
) -> "CartesianIterationSpace":
    """Build the horizontal iteration-space extent implied by an offset.

    The offset is read through ``offset.to_dict()``; only the ``"i"`` and
    ``"j"`` entries are used. For each axis:
    - the start bound is ``min(0, value)`` from the axis start;
    - the end bound is ``max(0, value)`` from the axis end.
    """
    components = offset.to_dict()

    def _axis_extent(axis_name: str) -> oir.Interval:
        shift = components[axis_name]
        return oir.Interval(
            start=oir.AxisBound.from_start(min(0, shift)),
            end=oir.AxisBound.from_end(max(0, shift)),
        )

    return CartesianIterationSpace(
        i_interval=_axis_extent("i"),
        j_interval=_axis_extent("j"),
    )
def __setitem__(self, key: oir.Interval, value: Any) -> None:
    """Assign ``value`` to the vertical range ``key``.

    The mapping stores parallel lists ``interval_starts`` / ``interval_ends`` /
    ``values`` (presumably kept sorted by start and non-overlapping). The steps are:

    1. Every stored interval completely covered by ``key`` is removed.
    2. If nothing remains, ``key`` becomes the only entry.
    3. Otherwise, if a stored interval covers ``key``, the update is delegated
       to ``_setitem_subset_of_existing``.
    4. Otherwise, if ``key`` intersects or touches a stored interval, it is
       delegated to ``_setitem_partial_overlap``.
    5. Otherwise ``key`` is inserted before the first interval starting after it,
       or appended at the end.

    Raises:
        TypeError: if ``key`` is not an ``oir.Interval``.
    """
    if not isinstance(key, oir.Interval):
        # Message matches the key check in IntervalMapping.__getitem__.
        raise TypeError("Only OIR intervals supported for keys of IntervalMapping.")

    # Drop every stored interval that the new key fully covers. Delete in
    # reverse so earlier indices stay valid while deleting.
    covered = [
        idx
        for idx, (start, end) in enumerate(zip(self.interval_starts, self.interval_ends))
        if key.covers(oir.Interval(start=start, end=end))
    ]
    for idx in reversed(covered):
        del self.interval_starts[idx]
        del self.interval_ends[idx]
        del self.values[idx]

    if len(self.interval_starts) == 0:
        self.interval_starts.append(key.start)
        self.interval_ends.append(key.end)
        self.values.append(value)
        return

    # key lies entirely inside an existing interval.
    for i, (start, end) in enumerate(zip(self.interval_starts, self.interval_ends)):
        if oir.Interval(start=start, end=end).covers(key):
            self._setitem_subset_of_existing(i, key, value)
            return

    # key overlaps or is adjacent to an existing interval (first match wins).
    for i, (start, end) in enumerate(zip(self.interval_starts, self.interval_ends)):
        if (
            key.intersects(oir.Interval(start=start, end=end))
            or start == key.end
            or end == key.start
        ):
            self._setitem_partial_overlap(i, key, value)
            return

    # Disjoint from everything: insert at the sorted position by start bound.
    for i, start in enumerate(self.interval_starts):
        if start > key.start:
            self.interval_starts.insert(i, key.start)
            self.interval_ends.insert(i, key.end)
            self.values.insert(i, value)
            return
    self.interval_starts.append(key.start)
    self.interval_ends.append(key.end)
    self.values.append(value)
def _setitem_partial_overlap(self, i: int, key: oir.Interval, value: Any) -> None:
    """Merge ``key`` -> ``value`` into the mapping where it partially overlaps interval ``i``.

    In the visible ``__setitem__`` this is called for the first stored interval
    ``i`` that ``key`` intersects or touches. By that point, intervals fully
    covered by ``key`` have been removed and the "key inside an existing
    interval" case has been handled.

    Values are compared by identity (``is``). When the neighbouring stored value
    is the same object, the intervals are merged. Otherwise the stored interval
    is trimmed and ``key`` is inserted as its own entry.

    Args:
        i: index of the stored interval that ``key`` overlaps or touches.
        key: interval being assigned.
        value: value to associate with ``key``.
    """
    start = self.interval_starts[i]
    if key.start < start:
        # key sticks out to the left of interval i.
        if self.values[i] is value:
            # Same value: just extend interval i to the left.
            self.interval_starts[i] = key.start
        else:
            # Different value: interval i now begins where key ends; insert key before it.
            self.interval_starts[i] = key.end
            self.interval_starts.insert(i, key.start)
            self.interval_ends.insert(i, key.end)
            self.values.insert(i, value)
    else:  # key.end > end
        if self.values[i] is value:
            # Same value: extend interval i to the right; its successor is i + 1.
            self.interval_ends[i] = key.end
            nextidx = i + 1
        else:
            # Different value: interval i now ends where key starts; insert key after it.
            self.interval_ends[i] = key.start
            self.interval_starts.insert(i + 1, key.start)
            self.interval_ends.insert(i + 1, key.end)
            self.values.insert(i + 1, value)
            nextidx = i + 2
        # key may also reach into (or touch) the following stored interval.
        if nextidx < len(self.interval_starts) and (
            key.intersects(
                oir.Interval(start=self.interval_starts[nextidx], end=self.interval_ends[nextidx])
            )
            or self.interval_starts[nextidx] == key.end
        ):
            if self.values[nextidx] is value:
                # Same value: fold the following interval into the entry holding key.
                self.interval_ends[nextidx - 1] = self.interval_ends[nextidx]
                del self.interval_starts[nextidx]
                del self.interval_ends[nextidx]
                del self.values[nextidx]
            else:
                # Different value: the following interval now starts where key ends.
                self.interval_starts[nextidx] = key.end
def __getitem__(self, key: oir.Interval) -> List[Any]:
    """Return the values of every stored interval that intersects ``key``.

    Values are returned in the order they are stored; an empty list means no
    stored interval intersects ``key``.

    Raises:
        TypeError: if ``key`` is not an ``oir.Interval``.
    """
    if not isinstance(key, oir.Interval):
        raise TypeError("Only OIR intervals supported for keys of IntervalMapping.")
    entries = zip(self.interval_starts, self.interval_ends, self.values)
    return [
        stored_value
        for lo, hi, stored_value in entries
        if key.intersects(oir.Interval(start=lo, end=hi))
    ]
def compose(self, other: "CartesianIterationSpace") -> "CartesianIterationSpace":
    """Compose two iteration spaces by adding their extents axis by axis.

    For I and J:
    - the result's start offset is the sum of both start offsets (via
      ``AxisBound.from_start``);
    - the end offset is the sum of both end offsets (via ``AxisBound.from_end``).
    """

    def _stack(outer: oir.Interval, inner: oir.Interval) -> oir.Interval:
        summed_start = outer.start.offset + inner.start.offset
        summed_end = outer.end.offset + inner.end.offset
        return oir.Interval(
            start=oir.AxisBound.from_start(summed_start),
            end=oir.AxisBound.from_end(summed_end),
        )

    return CartesianIterationSpace(
        i_interval=_stack(self.i_interval, other.i_interval),
        j_interval=_stack(self.j_interval, other.j_interval),
    )
def build(self):
    """Create an ``oir.VerticalLoop`` from the fields accumulated on this object.

    The loop interval is built from ``self._start`` / ``self._end``. The
    horizontal executions, loop order and declarations are passed through
    unchanged.
    """
    loop_interval = oir.Interval(start=self._start, end=self._end)
    return oir.VerticalLoop(
        interval=loop_interval,
        horizontal_executions=self._horizontal_executions,
        loop_order=self._loop_order,
        declarations=self._declarations,
    )
def visit_Interval(self, node: gtir.Interval) -> oir.Interval:
    """Lower a GTIR interval to an OIR interval.

    Both bounds are lowered through ``self.visit``, and the node's source
    location is carried over unchanged.
    """
    lowered_start = self.visit(node.start)
    lowered_end = self.visit(node.end)
    return oir.Interval(start=lowered_start, end=lowered_end, loc=node.loc)
def visit_Interval(self, node: gtir.Interval, **kwargs: Any) -> oir.Interval:
    """Lower a GTIR interval to an OIR interval.

    Both bounds are lowered through ``self.visit``. ``node.loc`` is forwarded so
    the lowered interval keeps its source location, matching the other
    ``visit_Interval`` lowering that passes ``loc=node.loc``.

    Args:
        node: GTIR interval to lower.
        **kwargs: accepted for visitor-protocol compatibility; not used here.
    """
    return oir.Interval(
        start=self.visit(node.start),
        end=self.visit(node.end),
        loc=node.loc,
    )