def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> oir.Stencil:
    """Select temporaries that are always written before being read, then delegate.

    Starting from every ``oir.Temporary`` in ``kwargs["symtable"]``, each
    horizontal execution of ``node`` may rule candidates out. A temporary
    survives a horizontal execution if either:

    * it is not accessed there at all, or
    * all its accesses there use the zero offset ``(0, 0, 0)`` and its first
      access (in ``AccessCollector`` order) is a write.

    The surviving names are passed to the base class ``visit_Stencil`` as
    ``tmps_to_replace``, together with all received keyword arguments.

    Raises:
        KeyError: if ``kwargs`` carries no ``"symtable"`` entry.
    """
    candidates = {
        name
        for name, symbol in kwargs["symtable"].items()
        if isinstance(symbol, oir.Temporary)
    }

    for horizontal_execution in node.iter_tree().if_isinstance(oir.HorizontalExecution):
        collected = AccessCollector.apply(horizontal_execution)
        offsets_by_field = collected.offsets()
        access_sequence = collected.ordered_accesses()

        still_valid = set()
        for name in candidates:
            if name not in offsets_by_field:
                # Untouched in this horizontal execution: imposes no constraint.
                still_valid.add(name)
            elif offsets_by_field[name] == {(0, 0, 0)}:
                # Only zero-offset accesses: the first one decides.
                first_access = next(a for a in access_sequence if a.field == name)
                if first_access.is_write:
                    still_valid.add(name)
            # Any non-zero offset disqualifies the temporary.
        candidates = still_valid

    return super().visit_Stencil(node, tmps_to_replace=candidates, **kwargs)
def build(self) -> Stencil:
    """Construct a ``Stencil`` from the builder's accumulated name, params, loops and declarations."""
    components = {
        "name": self._name,
        "params": self._params,
        "vertical_loops": self._vertical_loops,
        "declarations": self._declarations,
    }
    return Stencil(**components)
def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> oir.Stencil:
    """Rebuild ``node`` with temporaries used in a single horizontal execution treated as local.

    A temporary (an ``oir.Temporary`` entry of ``node.symtable_``) is
    considered local when it is accessed, via ``oir.FieldAccess``, in exactly
    one horizontal execution. The vertical loops are visited with
    ``local_tmps`` and ``symtable=node.symtable_``. The declarations of those
    local temporaries are dropped from the returned stencil.

    Returns:
        A new ``oir.Stencil`` with the same name and params, the visited
        vertical loops, and the declarations that are not local temporaries.
    """
    temporaries = {
        symbol for symbol, value in node.symtable_.items() if isinstance(value, oir.Temporary)
    }

    # Number of horizontal executions touching each temporary. ``to_set()``
    # de-duplicates names, so each execution contributes at most 1 per
    # temporary. Updating one Counter in place avoids rebuilding a new Counter
    # at every step, as ``sum(counters, Counter())`` would.
    counts: collections.Counter = collections.Counter()
    for horizontal_execution in node.iter_tree().if_isinstance(oir.HorizontalExecution):
        counts.update(
            horizontal_execution.iter_tree()
            .if_isinstance(oir.FieldAccess)
            .getattr("name")
            .if_in(temporaries)
            .to_set()
        )
    local_tmps = {tmp for tmp, count in counts.items() if count == 1}

    return oir.Stencil(
        name=node.name,
        params=node.params,
        vertical_loops=self.visit(
            node.vertical_loops, local_tmps=local_tmps, symtable=node.symtable_, **kwargs
        ),
        declarations=[d for d in node.declarations if d.name not in local_tmps],
    )
def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> Dict[str, oir.Expr]:
    """Collect, per mask field, the expression filled in while visiting the vertical loops.

    Every ``oir.MaskStmt`` whose mask is a plain ``oir.FieldAccess`` yields
    one entry keyed by that field's name. The entries start out as ``None``.
    The dictionary is handed to ``self.visit`` on the vertical loops, which is
    expected to fill every entry in place.

    Returns:
        The mapping from mask field name to its collected expression.
    """
    mask_names = [
        stmt.mask.name
        for stmt in node.iter_tree().if_isinstance(oir.MaskStmt)
        if isinstance(stmt.mask, oir.FieldAccess)
    ]
    # NOTE(review): values are None placeholders until the visit below fills
    # them, so the declared value type is only accurate after the visit.
    masks_to_inline: Dict[str, oir.Expr] = dict.fromkeys(mask_names)
    self.visit(node.vertical_loops, masks_to_inline=masks_to_inline, **kwargs)
    # Postcondition: every collected mask received an expression.
    assert not any(expr is None for expr in masks_to_inline.values())
    return masks_to_inline
def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> oir.Stencil:
    """Visit ``node`` with the set of temporaries used in exactly one vertical loop.

    Temporaries are the names in ``node.declarations``. A temporary is local
    when it appears (as an ``oir.FieldAccess``) in exactly one vertical loop.
    The stencil is then processed by ``generic_visit`` with
    ``local_tmps=<that set>`` plus the received keyword arguments.
    """
    # Loop-invariant: build the set of declared names once, not per vertical loop.
    declared_tmps = {tmp.name for tmp in node.declarations}

    # Number of vertical loops touching each temporary. ``to_set()``
    # de-duplicates names, so each loop contributes at most 1 per temporary.
    # An in-place Counter update replaces the quadratic ``sum(counters, Counter())``.
    counts: collections.Counter = collections.Counter()
    for vertical_loop in node.iter_tree().if_isinstance(oir.VerticalLoop):
        counts.update(
            vertical_loop.iter_tree()
            .if_isinstance(oir.FieldAccess)
            .getattr("name")
            .if_in(declared_tmps)
            .to_set()
        )
    local_tmps = {tmp for tmp, count in counts.items() if count == 1}
    return self.generic_visit(node, local_tmps=local_tmps, **kwargs)
def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> oir.Stencil:
    """Select temporaries used in exactly one horizontal execution, then delegate.

    Temporaries are the names in ``node.declarations``. A temporary qualifies
    when it appears (as an ``oir.FieldAccess``) in exactly one horizontal
    execution. The qualifying names are passed to the base class
    ``visit_Stencil`` as ``tmps_to_replace``.
    """
    # Loop-invariant: build the set of declared names once, not per execution.
    declared_tmps = {tmp.name for tmp in node.declarations}

    # Number of horizontal executions touching each temporary. ``to_set()``
    # de-duplicates names, so each execution contributes at most 1 per
    # temporary. An in-place Counter update replaces the quadratic
    # ``sum(counters, Counter())``.
    counts: collections.Counter = collections.Counter()
    for horizontal_execution in node.iter_tree().if_isinstance(oir.HorizontalExecution):
        counts.update(
            horizontal_execution.iter_tree()
            .if_isinstance(oir.FieldAccess)
            .getattr("name")
            .if_in(declared_tmps)
            .to_set()
        )
    local_tmps = {tmp for tmp, count in counts.items() if count == 1}

    # Forward the caller's keyword arguments instead of silently discarding
    # them, so the base implementation receives the same visitor context.
    return super().visit_Stencil(node, tmps_to_replace=local_tmps, **kwargs)
def has_variable_access(stencil: oir.Stencil) -> bool:
    """Return whether ``stencil`` contains at least one ``oir.VariableKOffset`` node.

    ``any`` stops at the first match rather than collecting every match into
    a list just to compare its length with zero.
    """
    return any(True for _ in stencil.iter_tree().if_isinstance(oir.VariableKOffset))