def _split_entry_level(
    loop_order: common.LoopOrder,
    section: oir.VerticalLoopSection,
    new_symbol_name: Callable[[str], str],
) -> Tuple[oir.VerticalLoopSection, oir.VerticalLoopSection]:
    """Cut a loop section into its entry level and the remaining levels.

    Args:
        loop_order: forward or backward order.
        section: loop section to split.
        new_symbol_name: callable returning a replacement name for a given
            declaration name; every declaration found in ``section`` is
            renamed this way inside the entry-level section only.

    Returns:
        A pair ``(entry_section, rest_section)``. The entry section spans a
        single level (the first one for forward loops, the last one for
        backward loops) and carries renamed declarations; the rest section
        reuses the original horizontal executions unchanged.
    """
    assert loop_order in (common.LoopOrder.FORWARD, common.LoopOrder.BACKWARD)

    whole = section.interval
    if loop_order == common.LoopOrder.FORWARD:
        # Forward loops enter at the start bound: peel off [start, start + 1).
        cut = common.AxisBound(level=whole.start.level, offset=whole.start.offset + 1)
        entry_interval = oir.Interval(start=whole.start, end=cut)
        rest_interval = oir.Interval(start=cut, end=whole.end)
    else:
        # Backward loops enter at the end bound: peel off [end - 1, end).
        cut = common.AxisBound(level=whole.end.level, offset=whole.end.offset - 1)
        entry_interval = oir.Interval(start=cut, end=whole.end)
        rest_interval = oir.Interval(start=whole.start, end=cut)

    # Old declaration name -> fresh name used in the entry-level copy.
    renamed = {
        decl.name: new_symbol_name(decl.name)
        for decl in section.iter_tree().if_isinstance(oir.Decl)
    }

    class FixSymbolNameClashes(NodeTranslator):
        """Rewrite scalar accesses and local scalars to their fresh names."""

        def visit_ScalarAccess(self, node: oir.ScalarAccess) -> oir.ScalarAccess:
            if node.name in renamed:
                return oir.ScalarAccess(name=renamed[node.name], dtype=node.dtype)
            # Accesses to symbols not declared in this section stay untouched.
            return node

        def visit_LocalScalar(self, node: oir.LocalScalar) -> oir.LocalScalar:
            return oir.LocalScalar(name=renamed[node.name], dtype=node.dtype)

    entry_section = oir.VerticalLoopSection(
        interval=entry_interval,
        horizontal_executions=FixSymbolNameClashes().visit(section.horizontal_executions),
        loc=section.loc,
    )
    rest_section = oir.VerticalLoopSection(
        interval=rest_interval,
        horizontal_executions=section.horizontal_executions,
        loc=section.loc,
    )
    return entry_section, rest_section
def visit_VerticalLoopSection(self, node: oir.VerticalLoopSection) -> Any:
    """Rebuild the section from its visited horizontal executions.

    Returns ``NOTHING`` when visiting the horizontal executions yields an
    empty (falsy) result; otherwise returns a new section with the same
    interval and the visited executions.
    """
    remaining = self.visit(node.horizontal_executions)
    if remaining:
        return oir.VerticalLoopSection(
            interval=node.interval,
            horizontal_executions=remaining,
        )
    return NOTHING
def visit_VerticalLoopSection(
    self, node: oir.VerticalLoopSection, **kwargs: Any
) -> oir.VerticalLoopSection:
    """Return a section whose horizontal executions went through ``_merge``.

    The interval is carried over unchanged; all keyword arguments are
    forwarded to ``self._merge``.
    """
    merged = self._merge(node.horizontal_executions, **kwargs)
    return oir.VerticalLoopSection(interval=node.interval, horizontal_executions=merged)
def visit_VerticalLoop(self, node: gtir.VerticalLoop, *, ctx: Context) -> oir.VerticalLoop:
    """Lower a GTIR vertical loop to an OIR vertical loop with one section.

    Every statement of the loop body becomes its own horizontal execution,
    whose declarations are the local scalars recorded on ``ctx`` while that
    statement was visited. The loop's temporaries are appended to
    ``ctx.temp_fields``.
    """
    executions: List[oir.HorizontalExecution] = []
    for stmt in node.body:
        # Local scalars are tracked per statement, so start from a clean slate.
        ctx.reset_local_scalars()
        visited = self.visit(stmt, ctx=ctx)
        if isinstance(visited, oir.Stmt):
            visited = [visited]
        body = utils.flatten_list(visited)
        executions.append(
            oir.HorizontalExecution(body=body, declarations=ctx.local_scalars)
        )

    new_temporaries = [
        oir.Temporary(name=temp.name, dtype=temp.dtype, dimensions=temp.dimensions)
        for temp in node.temporaries
    ]
    ctx.temp_fields += new_temporaries

    only_section = oir.VerticalLoopSection(
        interval=self.visit(node.interval),
        horizontal_executions=executions,
        loc=node.loc,
    )
    return oir.VerticalLoop(loop_order=node.loop_order, sections=[only_section])
def visit_VerticalLoopSection(
    self,
    node: oir.VerticalLoopSection,
    local_tmps: Set[str],
    **kwargs: Any,
) -> oir.VerticalLoopSection:
    """Rebuild the section with its horizontal executions visited.

    Args:
        node: the vertical loop section to process (previously mis-annotated
            as ``oir.VerticalLoop``; the body reads ``interval`` and
            ``horizontal_executions``, which are section attributes).
        local_tmps: set of names forwarded unchanged to the child visitors.
        **kwargs: additional visitor arguments, forwarded as-is.

    Returns:
        A new section with the original interval and the visited horizontal
        executions.
    """
    return oir.VerticalLoopSection(
        interval=node.interval,
        horizontal_executions=self.visit(
            node.horizontal_executions, local_tmps=local_tmps, **kwargs
        ),
    )
def visit_VerticalLoopSection(
    self, node: oir.VerticalLoopSection, **kwargs: Any
) -> oir.VerticalLoopSection:
    """Apply ``_merge`` repeatedly until it stops shrinking the section.

    ``_merge`` is always run at least once. As long as a pass returns fewer
    horizontal executions than it was given, another pass is run on the
    result; the first result that is not shorter is returned.
    """
    current = node
    while True:
        merged = oir.VerticalLoopSection(
            interval=current.interval,
            horizontal_executions=self._merge(current.horizontal_executions, **kwargs),
        )
        if len(merged.horizontal_executions) >= len(current.horizontal_executions):
            # No progress in this pass: fixpoint reached.
            return merged
        current = merged
def visit_VerticalLoop(
    self, node: gtir.VerticalLoop, *, ctx: Context, **kwargs: Any
) -> oir.VerticalLoop:
    """Lower a GTIR vertical loop to an OIR vertical loop with one section.

    ``ctx.horizontal_executions`` is cleared, the loop body is visited with
    ``ctx`` (presumably filling ``ctx.horizontal_executions`` -- confirm in
    the statement visitors), and each temporary of the loop is registered
    through ``ctx.add_decl``. The resulting loop has no caches.
    """
    ctx.horizontal_executions.clear()
    self.visit(node.body, ctx=ctx)

    for temp in node.temporaries:
        temporary = oir.Temporary(name=temp.name, dtype=temp.dtype)
        ctx.add_decl(temporary)

    interval = self.visit(node.interval, **kwargs)
    only_section = oir.VerticalLoopSection(
        interval=interval,
        horizontal_executions=ctx.horizontal_executions,
    )
    return oir.VerticalLoop(
        loop_order=node.loop_order,
        sections=[only_section],
        caches=[],
    )
def visit_VerticalLoopSection(
    self,
    node: oir.VerticalLoopSection,
    *,
    block_extents: Dict[int, Extent],
    new_symbol_name: Callable[[str], str],
    **kwargs: Any,
) -> oir.VerticalLoopSection:
    """Merge consecutive horizontal executions of a section where possible.

    A horizontal execution is merged into its predecessor unless it reads,
    with a non-zero offset in either of the first two offset components, a
    field written by the predecessor, or unless the two have different block
    extents.

    Args:
        node: section whose horizontal executions are merged; it must contain
            at least one horizontal execution.
        block_extents: extent of each horizontal execution, keyed by the
            ``id()`` of the execution object.
        new_symbol_name: callable returning a fresh name for a local scalar
            that is declared in both executions being merged.
        **kwargs: forwarded to the visitor applied to the merged-in body.

    Returns:
        A new section with the same interval and the merged executions.
    """
    first_hexec = node.horizontal_executions[0]
    horizontal_executions = [first_hexec]
    new_block_extents = [block_extents[id(first_hexec)]]
    # Fields written by the (possibly merged) last execution. Tracked
    # incrementally instead of re-collecting accesses over the growing merged
    # execution on every iteration, which was quadratic in the body size.
    last_writes = AccessCollector.apply(first_hexec).write_fields()

    for this_hexec in node.horizontal_executions[1:]:
        this_accesses = AccessCollector.apply(this_hexec)
        this_writes = this_accesses.write_fields()
        this_offset_reads = {
            name
            for name, offsets in this_accesses.read_offsets().items()
            if any(off[0] != 0 or off[1] != 0 for off in offsets)
        }
        reads_with_offset_after_write = last_writes & this_offset_reads
        this_extent = block_extents[id(this_hexec)]

        if reads_with_offset_after_write or new_block_extents[-1] != this_extent:
            # Cannot merge: simply append to list
            horizontal_executions.append(this_hexec)
            new_block_extents.append(this_extent)
            last_writes = this_writes
            continue

        # Merge
        last_hexec = horizontal_executions[-1]
        last_names = {decl.name for decl in last_hexec.declarations}
        # Map from old to new scalar names applied to the second horizontal
        # execution. Built in declaration order (not by iterating a set) so
        # the emitted declarations do not depend on string hash randomization.
        scalar_map: Dict[str, str] = {}
        for decl in this_hexec.declarations:
            if decl.name in last_names and decl.name not in scalar_map:
                scalar_map[decl.name] = new_symbol_name(decl.name)
        locals_symtable = {decl.name: decl for decl in this_hexec.declarations}

        new_body = self.visit(this_hexec.body, scalar_map=scalar_map, **kwargs)
        this_not_duplicated = [
            decl for decl in this_hexec.declarations if decl.name not in scalar_map
        ]
        this_mapped = [
            oir.ScalarDecl(name=new_name, dtype=locals_symtable[old_name].dtype)
            for old_name, new_name in scalar_map.items()
        ]
        horizontal_executions[-1] = oir.HorizontalExecution(
            body=last_hexec.body + new_body,
            declarations=(last_hexec.declarations + this_not_duplicated + this_mapped),
        )
        # NOTE(review): assumes the scalar renaming above leaves the set of
        # written fields unchanged, so the writes of the merged execution are
        # the union of both parts -- confirm against the scalar visitor.
        last_writes = last_writes | this_writes

    return oir.VerticalLoopSection(
        interval=node.interval,
        horizontal_executions=horizontal_executions,
    )
def visit_VerticalLoopSection(
    self,
    node: oir.VerticalLoopSection,
    *,
    block_extents: Dict[int, Extent],
    new_symbol_name: Callable[[str], str],
    **kwargs: Any,
) -> oir.VerticalLoopSection:
    """Merge consecutive horizontal executions of a section where possible.

    A horizontal execution is merged into its predecessor unless it reads,
    with a non-zero offset in either of the first two offset components, a
    field written by the predecessor, or unless the two have different block
    extents. Intermediate results are held in a plain dataclass and converted
    to ``oir.HorizontalExecution`` only once at the end.

    Args:
        node: section whose horizontal executions are merged; it must contain
            at least one horizontal execution.
        block_extents: extent of each horizontal execution, keyed by the
            ``id()`` of the execution object.
        new_symbol_name: callable returning a fresh name for a local scalar
            that is declared in both executions being merged.
        **kwargs: forwarded to the visitor applied to the merged-in body.

    Returns:
        A new section with the same interval and location and the merged
        executions.
    """

    @dataclass
    class UncheckedHorizontalExecution:
        # local replacement without type checking for type-checked oir node
        # required to reach reasonable run times for large node counts
        body: List[oir.Stmt]
        declarations: List[oir.LocalScalar]
        loc: Optional[SourceLocation]

        # Guard against the oir node gaining fields this stand-in would drop.
        assert set(oir.HorizontalExecution.__fields__) == {
            "loc",
            "symtable_",
            "body",
            "declarations",
        }, (
            "Unexpected field in oir.HorizontalExecution, "
            "probably UncheckedHorizontalExecution needs an update"
        )

        @classmethod
        def from_oir(cls, hexec: oir.HorizontalExecution) -> "UncheckedHorizontalExecution":
            return cls(body=hexec.body, declarations=hexec.declarations, loc=hexec.loc)

        def to_oir(self) -> oir.HorizontalExecution:
            return oir.HorizontalExecution(
                body=self.body, declarations=self.declarations, loc=self.loc
            )

    horizontal_executions = [
        UncheckedHorizontalExecution.from_oir(node.horizontal_executions[0])
    ]
    new_block_extents = [block_extents[id(node.horizontal_executions[0])]]
    # Fields written by the (possibly merged) last execution, kept up to date
    # incrementally so the merged body never has to be re-scanned.
    last_writes = AccessCollector.apply(node.horizontal_executions[0]).write_fields()

    for this_hexec in node.horizontal_executions[1:]:
        last_extent = new_block_extents[-1]
        this_offset_reads = {
            name
            for name, offsets in AccessCollector.apply(this_hexec).read_offsets().items()
            if any(off[0] != 0 or off[1] != 0 for off in offsets)
        }
        reads_with_offset_after_write = last_writes & this_offset_reads
        this_extent = block_extents[id(this_hexec)]

        if reads_with_offset_after_write or last_extent != this_extent:
            # Cannot merge: simply append to list
            horizontal_executions.append(UncheckedHorizontalExecution.from_oir(this_hexec))
            new_block_extents.append(this_extent)
            last_writes = AccessCollector.apply(this_hexec).write_fields()
        else:
            # Merge
            last_names = {decl.name for decl in horizontal_executions[-1].declarations}
            # Map from old to new scalar names applied to the second horizontal
            # execution. Built in declaration order (not by iterating a set)
            # so the emitted declarations do not depend on string hash
            # randomization.
            scalar_map: Dict[str, str] = {}
            for decl in this_hexec.declarations:
                if decl.name in last_names and decl.name not in scalar_map:
                    scalar_map[decl.name] = new_symbol_name(decl.name)
            locals_symtable = {decl.name: decl for decl in this_hexec.declarations}

            new_body = self.visit(this_hexec.body, scalar_map=scalar_map, **kwargs)
            this_not_duplicated = [
                decl for decl in this_hexec.declarations if decl.name not in scalar_map
            ]
            this_mapped = [
                oir.ScalarDecl(name=new_name, dtype=locals_symtable[old_name].dtype)
                for old_name, new_name in scalar_map.items()
            ]
            horizontal_executions[-1] = UncheckedHorizontalExecution(
                body=horizontal_executions[-1].body + new_body,
                declarations=(
                    horizontal_executions[-1].declarations + this_not_duplicated + this_mapped
                ),
                loc=horizontal_executions[-1].loc,
            )
            last_writes |= AccessCollector.apply(new_body).write_fields()

    return oir.VerticalLoopSection(
        interval=node.interval,
        horizontal_executions=[hexec.to_oir() for hexec in horizontal_executions],
        loc=node.loc,
    )