def visit_VerticalLoop(
    self, node: oir.VerticalLoop, *, local_tmps: Set[str], **kwargs: Any
) -> oir.VerticalLoop:
    """Add IJ-caches for eligible local temporaries of a parallel vertical loop.

    A field becomes IJ-cached when it is listed in ``local_tmps``, is not
    already cached on ``node``, and none of its cartesian accesses has a
    non-zero vertical (k) offset.

    Args:
        node: The vertical loop to analyze.
        local_tmps: Names of temporaries that may be IJ-cached.
        **kwargs: Forwarded when visiting the existing caches.

    Returns:
        ``node`` itself for non-parallel loops or when ``local_tmps`` is empty;
        otherwise a new loop whose caches are the visited existing caches
        followed by one ``oir.IJCache`` per newly cacheable field.
    """
    if node.loop_order != common.LoopOrder.PARALLEL or not local_tmps:
        return node

    cached_names = {c.name for c in node.caches}

    def has_vertical_offset(offsets: Set[Tuple[int, int, int]]) -> bool:
        return any(offset[2] != 0 for offset in offsets)

    accesses = AccessCollector.apply(node).cartesian_accesses().offsets()
    cacheable = {
        field
        for field, offsets in accesses.items()
        if field in local_tmps
        and field not in cached_names
        and not has_vertical_offset(offsets)
    }
    # Sort so the generated cache list does not depend on set iteration order,
    # which varies between interpreter runs because string hashing is randomized.
    caches = self.visit(node.caches, **kwargs) + [
        oir.IJCache(name=field) for field in sorted(cacheable)
    ]
    return oir.VerticalLoop(
        sections=node.sections,
        loop_order=node.loop_order,
        caches=caches,
    )
def visit_VerticalLoop(self, node: oir.VerticalLoop) -> Any:
    """Visit the loop's sections and remove the loop if none remain.

    Returns:
        ``NOTHING`` when visiting ``node.sections`` gives an empty (falsy)
        result; otherwise a new ``oir.VerticalLoop`` built from the visited
        sections, keeping ``node``'s loop order and caches.
    """
    remaining_sections = self.visit(node.sections)
    if remaining_sections:
        return oir.VerticalLoop(
            loop_order=node.loop_order,
            sections=remaining_sections,
            caches=node.caches,
        )
    return NOTHING
def visit_VerticalLoop(self, node: gtir.VerticalLoop, *, ctx: Context) -> oir.VerticalLoop:
    """Lower a gtir vertical loop into an oir vertical loop with a single section.

    Every statement in ``node.body`` becomes its own ``oir.HorizontalExecution``.
    Its declarations are whatever ``ctx.local_scalars`` holds after that
    statement has been visited, because the local scalars are reset before each
    statement. The loop's temporaries are appended to ``ctx.temp_fields`` as
    ``oir.Temporary`` declarations.
    """

    def lower_statement(stmt):
        ctx.reset_local_scalars()
        lowered = self.visit(stmt, ctx=ctx)
        # The visitor returns either a single oir statement or something that
        # `utils.flatten_list` turns into a flat list of statements.
        if isinstance(lowered, oir.Stmt):
            lowered = [lowered]
        return oir.HorizontalExecution(
            body=utils.flatten_list(lowered),
            declarations=ctx.local_scalars,
        )

    horizontal_executions: List[oir.HorizontalExecution] = [
        lower_statement(stmt) for stmt in node.body
    ]

    ctx.temp_fields += [
        oir.Temporary(name=tmp.name, dtype=tmp.dtype, dimensions=tmp.dimensions)
        for tmp in node.temporaries
    ]

    only_section = oir.VerticalLoopSection(
        interval=self.visit(node.interval),
        horizontal_executions=horizontal_executions,
        loc=node.loc,
    )
    return oir.VerticalLoop(loop_order=node.loop_order, sections=[only_section])
def _merge(a: oir.VerticalLoop, b: oir.VerticalLoop) -> oir.VerticalLoop:
    """Fuse two vertical loops into one by concatenating their sections.

    The fused loop takes ``a``'s loop order and has an empty cache list. If
    either input declared caches, a warning is emitted because they are
    discarded.
    """
    merged_sections = a.sections + b.sections
    had_caches = bool(a.caches or b.caches)
    if had_caches:
        warnings.warn("AdjacentLoopMerging pass removed previously declared caches")
    return oir.VerticalLoop(loop_order=a.loop_order, sections=merged_sections, caches=[])
def build(self):
    """Create an ``oir.VerticalLoop`` from the values stored on this builder.

    The interval is built from ``self._start``/``self._end``; the horizontal
    executions, loop order and declarations come from the corresponding
    private attributes.
    """
    loop_interval = oir.Interval(start=self._start, end=self._end)
    return oir.VerticalLoop(
        interval=loop_interval,
        horizontal_executions=self._horizontal_executions,
        loop_order=self._loop_order,
        declarations=self._declarations,
    )
def visit_VerticalLoop(self, node: oir.VerticalLoop, **kwargs: Any) -> oir.VerticalLoop:
    """Remove caches whose field is no longer accessed in a parallel loop.

    Loops whose order is not ``PARALLEL`` are returned unchanged. Otherwise the
    sections are visited first, and only caches naming a field that the
    visited sections still access are kept.
    """
    if node.loop_order != common.LoopOrder.PARALLEL:
        return node

    visited_sections = self.visit(node.sections, **kwargs)
    used_fields = AccessCollector.apply(visited_sections).fields()
    kept_caches = [cache for cache in node.caches if cache.name in used_fields]
    return oir.VerticalLoop(
        loop_order=node.loop_order,
        sections=visited_sections,
        caches=kept_caches,
    )
def visit_VerticalLoop(self, node: oir.VerticalLoop, **kwargs: Any) -> oir.VerticalLoop:
    """Add fill-and-flush k-caches to eligible fields of a non-parallel loop.

    Parallel loops, and loops where any section has a number of horizontal
    executions other than one, are handed to ``generic_visit`` unchanged.

    Otherwise a field gets ``oir.KCache(fill=True, flush=True)`` when it:
      * is not already cached on ``node``,
      * has no access with a variable vertical offset (k offset ``None``),
      * is accessed at more than one distinct cartesian offset,
      * has only zero horizontal (i, j) offsets, and
      * has all vertical offsets within ``self.max_cacheable_offset`` in
        absolute value.

    Args:
        node: The vertical loop to analyze.
        **kwargs: Forwarded to ``generic_visit`` and to visiting existing caches.

    Returns:
        A new loop whose caches are the visited existing caches followed by the
        new k-caches, or the result of ``generic_visit`` for ineligible loops.
    """
    # k-caches are restricted to loops with a single horizontal region as all regions without
    # horizontal offsets should be merged before anyway and this restriction allows for easier
    # conversion of fill and flush caches to local caches later
    if node.loop_order == common.LoopOrder.PARALLEL or any(
        len(section.horizontal_executions) != 1 for section in node.sections
    ):
        return self.generic_visit(node, **kwargs)

    all_accesses = AccessCollector.apply(node)
    fields_with_variable_reads = {
        field
        for field, offsets in all_accesses.offsets().items()
        if any(off[2] is None for off in offsets)
    }
    cached_names = {c.name for c in node.caches}

    def accessed_at_multiple_offsets(offsets: Set[Any]) -> bool:
        return len(offsets) > 1

    # TODO(fthaler): k-caches with non-zero ij offsets?
    def has_horizontal_offset(offsets: Set[Tuple[int, int, int]]) -> bool:
        return any(offset[:2] != (0, 0) for offset in offsets)

    def offsets_within_limits(offsets: Set[Tuple[int, int, int]]) -> bool:
        return all(abs(offset[2]) <= self.max_cacheable_offset for offset in offsets)

    accesses = all_accesses.cartesian_accesses().offsets()
    # The checks short-circuit in this order; the variable-offset check runs
    # before `offsets_within_limits`, which calls `abs` on the k offset.
    cacheable = {
        field
        for field, offsets in accesses.items()
        if field not in cached_names
        and field not in fields_with_variable_reads
        and accessed_at_multiple_offsets(offsets)
        and not has_horizontal_offset(offsets)
        and offsets_within_limits(offsets)
    }
    # Sort so the generated cache list does not depend on set iteration order,
    # which varies between interpreter runs because string hashing is randomized.
    caches = self.visit(node.caches, **kwargs) + [
        oir.KCache(name=field, fill=True, flush=True) for field in sorted(cacheable)
    ]
    return oir.VerticalLoop(
        loop_order=node.loop_order,
        sections=node.sections,
        caches=caches,
        loc=node.loc,
    )
def visit_VerticalLoop(
    self,
    node: oir.VerticalLoop,
    tmps_to_replace: Set[str],
    **kwargs: Any,
) -> oir.VerticalLoop:
    """Rebuild the loop with visited sections and without caches on replaced temporaries.

    Args:
        node: The vertical loop to rebuild.
        tmps_to_replace: Names of temporaries being replaced; forwarded to the
            section visitor, and any cache naming one of them is dropped.
        **kwargs: Forwarded to the section visitor.
    """
    visited_sections = self.visit(node.sections, tmps_to_replace=tmps_to_replace, **kwargs)
    surviving_caches = [cache for cache in node.caches if cache.name not in tmps_to_replace]
    return oir.VerticalLoop(
        loop_order=node.loop_order,
        sections=visited_sections,
        caches=surviving_caches,
    )
def visit_VerticalLoop(
    self, node: gtir.VerticalLoop, *, ctx: Context, **kwargs: Any
) -> oir.VerticalLoop:
    """Lower a gtir vertical loop into an oir vertical loop with a single section.

    ``ctx.horizontal_executions`` is cleared and then filled by visiting
    ``node.body``. Each temporary of the loop is registered on ``ctx`` via
    ``add_decl`` as an ``oir.Temporary``.

    Args:
        node: The gtir vertical loop to lower.
        ctx: Lowering context that accumulates horizontal executions and
            declarations.
        **kwargs: Forwarded when visiting the loop interval.

    Returns:
        An ``oir.VerticalLoop`` with one section and no caches.
    """
    ctx.horizontal_executions.clear()
    self.visit(node.body, ctx=ctx)
    for temp in node.temporaries:
        ctx.add_decl(oir.Temporary(name=temp.name, dtype=temp.dtype))

    interval = self.visit(node.interval, **kwargs)
    # Pass a snapshot: `ctx.horizontal_executions` is cleared in place on the next
    # call of this method, so the returned node must not share that list object.
    horizontal_executions = list(ctx.horizontal_executions)
    return oir.VerticalLoop(
        loop_order=node.loop_order,
        sections=[
            oir.VerticalLoopSection(
                interval=interval,
                horizontal_executions=horizontal_executions,
            )
        ],
        caches=[],
    )
def visit_VerticalLoop(
    self,
    node: oir.VerticalLoop,
    *,
    new_tmps: List[oir.Temporary],
    symtable: Dict[str, Any],
    new_symbol_name: Callable[[str], str],
    **kwargs: Any,
) -> oir.VerticalLoop:
    """Replace filling/flushing k-caches by non-filling, non-flushing k-caches on new temporaries.

    For every ``oir.KCache`` with ``fill`` and/or ``flush`` set:
      * a fresh name is obtained from ``new_symbol_name``; a field that is both
        filled and flushed uses a single new name,
      * a 3D ``oir.Temporary`` with the field's dtype (from ``symtable``) is
        appended to ``new_tmps``,
      * fill/flush statements are generated per section via
        ``self._fill_stmts``/``self._flush_stmts`` and handed to the section
        visitor,
      * the original cache is replaced by ``oir.KCache(fill=False, flush=False)``
        on the new temporary.

    When fills are needed, sections are first split with
    ``self._split_section_with_multiple_fills``.

    Args:
        node: The vertical loop to rewrite.
        new_tmps: Output list that receives the new temporary declarations.
        symtable: Symbol table used to look up the dtype of cached fields.
        new_symbol_name: Factory for fresh symbol names, given a field name.
        **kwargs: Forwarded to the section visitor.

    Returns:
        ``node`` itself if it has no filling or flushing k-caches, otherwise a
        new loop with rewritten sections and caches.
    """
    filling_fields: Dict[str, str] = {
        c.name: new_symbol_name(c.name)
        for c in node.caches
        if isinstance(c, oir.KCache) and c.fill
    }
    flushing_fields: Dict[str, str] = {
        c.name: filling_fields[c.name] if c.name in filling_fields else new_symbol_name(c.name)
        for c in node.caches
        if isinstance(c, oir.KCache) and c.flush
    }
    # Both maps agree on fields that are filled and flushed, so the merge cannot
    # conflict. Unlike a union of item sets, a dict merge keeps a deterministic
    # (insertion) order for the temporaries and caches generated below.
    filling_or_flushing_fields: Dict[str, str] = {**filling_fields, **flushing_fields}

    if not filling_or_flushing_fields:
        return node

    # new temporaries used for caches, declarations are later added to stencil
    for field_name, tmp_name in filling_or_flushing_fields.items():
        new_tmps.append(
            oir.Temporary(
                name=tmp_name, dtype=symtable[field_name].dtype, dimensions=(True, True, True)
            )
        )

    if filling_fields:
        # split sections where more than one fill operation is required at the entry level
        first_unfilled: Dict[str, int] = {}
        split_sections: List[oir.VerticalLoopSection] = []
        for section in node.sections:
            # NOTE(review): the helper's second return value is discarded and
            # `first_unfilled` is never rebound here, so it only carries state across
            # sections if the helper updates it in place — confirm against
            # `_split_section_with_multiple_fills`.
            split_section, _ = self._split_section_with_multiple_fills(
                node.loop_order, section, filling_fields, first_unfilled, new_symbol_name
            )
            split_sections += split_section
    else:
        split_sections = node.sections

    # generate cache fill and flush statements
    first_unfilled = {}
    sections = []
    for section in split_sections:
        fills, first_unfilled = self._fill_stmts(
            node.loop_order, section, filling_fields, first_unfilled, symtable
        )
        flushes = self._flush_stmts(node.loop_order, section, flushing_fields, symtable)
        sections.append(
            self.visit(
                section,
                fills=fills,
                flushes=flushes,
                name_map=filling_or_flushing_fields,
                symtable=symtable,
                **kwargs,
            )
        )

    # replace cache declarations
    caches = [c for c in node.caches if c.name not in filling_or_flushing_fields] + [
        oir.KCache(name=f, fill=False, flush=False) for f in filling_or_flushing_fields.values()
    ]
    return oir.VerticalLoop(loop_order=node.loop_order, sections=sections, caches=caches)