def visit_KCache(self, node: oir.KCache, *, pruneable: Set[str], **kwargs: Any) -> oir.KCache:
    """Disable flushing for k-caches whose name is listed in ``pruneable``.

    Args:
        node: The k-cache declaration being visited.
        pruneable: Names of the caches whose flush should be switched off.
        **kwargs: Forwarded to ``generic_visit`` for caches that are kept as-is.

    Returns:
        A new ``oir.KCache`` with the same name, fill flag and location but
        ``flush=False`` when ``node.name`` is in ``pruneable``; otherwise the
        result of ``self.generic_visit(node, **kwargs)``.
    """
    # Guard clause: caches that are not pruneable take the generic path.
    # Note that ``pruneable`` itself is not forwarded, only the remaining kwargs.
    if node.name not in pruneable:
        return self.generic_visit(node, **kwargs)

    return oir.KCache(name=node.name, fill=node.fill, flush=False, loc=node.loc)
def visit_VerticalLoop(self, node: oir.VerticalLoop, **kwargs: Any) -> oir.VerticalLoop:
    """Add fill-and-flush k-caches for every cacheable field of a vertical loop.

    A field is considered cacheable when all of the following hold:

    * it is not already named by one of ``node.caches``,
    * none of its accesses has ``None`` as third (k) offset component,
    * it is accessed with more than one distinct offset,
    * all of its accesses have a ``(0, 0)`` horizontal offset,
    * the absolute value of every k offset is at most ``self.max_cacheable_offset``.

    Args:
        node: The vertical loop to analyse.
        **kwargs: Forwarded to ``generic_visit`` / ``visit``.

    Returns:
        The result of ``generic_visit`` when the loop order is ``PARALLEL`` or any
        section does not have exactly one horizontal execution. Otherwise a new
        ``oir.VerticalLoop`` with the same loop order, sections and location, whose
        caches are the visited existing caches followed by one
        ``oir.KCache(fill=True, flush=True)`` per cacheable field, in sorted name order.
    """
    # k-caches are restricted to loops with a single horizontal region as all regions without
    # horizontal offsets should be merged before anyway and this restriction allows for easier
    # conversion of fill and flush caches to local caches later
    if node.loop_order == common.LoopOrder.PARALLEL or any(
        len(section.horizontal_executions) != 1 for section in node.sections
    ):
        return self.generic_visit(node, **kwargs)

    all_accesses = AccessCollector.apply(node)
    fields_with_variable_reads = {
        field
        for field, offsets in all_accesses.offsets().items()
        if any(off[2] is None for off in offsets)
    }
    # Built once up front instead of once per candidate field.
    cached_names = {cache.name for cache in node.caches}

    def accessed_more_than_once(offsets: Set[Any]) -> bool:
        return len(offsets) > 1

    def already_cached(field: str) -> bool:
        return field in cached_names

    # TODO(fthaler): k-caches with non-zero ij offsets?
    def has_horizontal_offset(offsets: Set[Tuple[int, int, int]]) -> bool:
        return any(offset[:2] != (0, 0) for offset in offsets)

    def offsets_within_limits(offsets: Set[Tuple[int, int, int]]) -> bool:
        return all(abs(offset[2]) <= self.max_cacheable_offset for offset in offsets)

    def has_variable_offset_reads(field: str) -> bool:
        return field in fields_with_variable_reads

    accesses = all_accesses.cartesian_accesses().offsets()
    # Sorted so that the order of the generated cache declarations does not depend on
    # set iteration order (which varies between runs due to string hash randomization).
    cacheable = sorted(
        field
        for field, offsets in accesses.items()
        if not already_cached(field)
        and not has_variable_offset_reads(field)
        and accessed_more_than_once(offsets)
        and not has_horizontal_offset(offsets)
        and offsets_within_limits(offsets)
    )

    caches = self.visit(node.caches, **kwargs) + [
        oir.KCache(name=field, fill=True, flush=True) for field in cacheable
    ]

    # Sections are reused unchanged; only the cache declarations differ.
    return oir.VerticalLoop(
        loop_order=node.loop_order,
        sections=node.sections,
        caches=caches,
        loc=node.loc,
    )
def visit_VerticalLoop(
    self,
    node: oir.VerticalLoop,
    *,
    new_tmps: List[oir.Temporary],
    symtable: Dict[str, Any],
    new_symbol_name: Callable[[str], str],
    **kwargs: Any,
) -> oir.VerticalLoop:
    """Redirect filling and flushing k-caches of a vertical loop to new temporaries.

    For every ``oir.KCache`` in ``node.caches`` that has ``fill`` or ``flush`` set, a
    new name is requested from ``new_symbol_name`` (a cache that both fills and
    flushes gets a single name) and a matching ``oir.Temporary`` is appended to
    ``new_tmps``. Fill and flush statements are then generated per section and
    handed to the section visitor together with the field-to-temporary name map.

    Args:
        node: The vertical loop to transform.
        new_tmps: Output list; the temporaries created here are appended to it.
        symtable: Maps field names to declarations; ``.dtype`` of the entry is
            used for the new temporary.
        new_symbol_name: Callable returning a new symbol name for a given field name.
        **kwargs: Forwarded to ``self.visit`` for each section.

    Returns:
        ``node`` itself when it has no filling or flushing k-caches. Otherwise a new
        ``oir.VerticalLoop`` with the same loop order and location, the visited
        sections, and caches where every filling/flushing cache is replaced by a
        non-filling, non-flushing ``oir.KCache`` on the corresponding temporary.
    """
    filling_fields: Dict[str, str] = {
        c.name: new_symbol_name(c.name)
        for c in node.caches
        if isinstance(c, oir.KCache) and c.fill
    }
    # Caches that both fill and flush share the temporary chosen for filling.
    flushing_fields: Dict[str, str] = {
        c.name: filling_fields[c.name] if c.name in filling_fields else new_symbol_name(c.name)
        for c in node.caches
        if isinstance(c, oir.KCache) and c.flush
    }

    # Shared keys carry identical values in both dicts, so a plain merge yields the
    # same mapping as a union of item sets, but with a deterministic (insertion)
    # order for the temporaries and cache declarations created below.
    filling_or_flushing_fields: Dict[str, str] = {**filling_fields, **flushing_fields}

    if not filling_or_flushing_fields:
        return node

    # new temporaries used for caches, declarations are later added to stencil
    for field_name, tmp_name in filling_or_flushing_fields.items():
        new_tmps.append(
            oir.Temporary(
                name=tmp_name, dtype=symtable[field_name].dtype, dimensions=(True, True, True)
            )
        )

    if filling_fields:
        # split sections where more than one fill operations are required at the entry level
        first_unfilled: Dict[str, int] = {}
        split_sections: List[oir.VerticalLoopSection] = []
        for section in node.sections:
            # NOTE(review): the second return value is discarded and `first_unfilled`
            # is never rebound in this loop; confirm the helper does not need its
            # result threaded through to the next call.
            split_section, _ = self._split_section_with_multiple_fills(
                node.loop_order, section, filling_fields, first_unfilled, new_symbol_name
            )
            split_sections += split_section
    else:
        split_sections = node.sections

    # generate cache fill and flush statements
    first_unfilled = {}
    sections = []
    for section in split_sections:
        fills, first_unfilled = self._fill_stmts(
            node.loop_order, section, filling_fields, first_unfilled, symtable
        )
        flushes = self._flush_stmts(node.loop_order, section, flushing_fields, symtable)
        sections.append(
            self.visit(
                section,
                fills=fills,
                flushes=flushes,
                name_map=filling_or_flushing_fields,
                symtable=symtable,
                **kwargs,
            )
        )

    # replace cache declarations
    caches = [c for c in node.caches if c.name not in filling_or_flushing_fields] + [
        oir.KCache(name=f, fill=False, flush=False) for f in filling_or_flushing_fields.values()
    ]

    # Keep the source location of the original loop on the rewritten node.
    return oir.VerticalLoop(
        loop_order=node.loop_order,
        sections=sections,
        caches=caches,
        loc=node.loc,
    )