def prune_unused_parameters(node: gtir.Stencil) -> gtir.Stencil:
    """
    Remove unused parameters from the gtir signature.

    A parameter counts as used when some ``gtir.FieldAccess`` or ``gtir.ScalarAccess``
    anywhere in the stencil tree refers to it by name. The relative order of the
    remaining parameters is preserved. Returns a copy of ``node``; ``node`` itself is
    not modified.

    (Maybe this pass should go into a later stage. If you need to touch this pass,
    e.g. when the definition_ir gets removed, consider moving it to a more appropriate
    level. Maybe to the backend IR?)
    """
    assert isinstance(node, gtir.Stencil)
    # Collect into a set: each membership test below is then O(1) instead of a scan
    # over every access in the tree.
    used_variables = set(
        node.iter_tree()
        .if_isinstance(gtir.FieldAccess, gtir.ScalarAccess)
        .getattr("name")
    )
    used_params = [param for param in node.params if param.name in used_variables]
    return node.copy(update={"params": used_params})
def copy_computation(copy_v_loop):
    """Yield a ``copy_gtir`` Stencil with two FLOAT32 field params, ``a`` and ``b``.

    The stencil's only vertical loop is the ``copy_v_loop`` argument.
    """
    location = SourceLocation(line=1, column=1, source="copy_gtir")
    field_params = [
        FieldDecl(name=field_name, dtype=DataType.FLOAT32) for field_name in ("a", "b")
    ]
    stencil = Stencil(
        name="copy_gtir",
        loc=location,
        params=field_params,
        vertical_loops=[copy_v_loop],
    )
    yield stencil
def visit_Stencil(self, node: gtir.Stencil, **kwargs: Any) -> FIELD_EXT_T:
    """Compute the extent of every field referenced by ``node``.

    Every field name from ``_iter_field_names(node)`` starts with ``Extent.zeros()``.
    All ``gtir.FieldIfStmt`` nodes are visited first with a shared context. The
    parallel assignments are then visited from last to first, with the same context
    and the ``field_extents`` mapping, which is returned afterwards.
    """
    field_extents = {}
    for field_name in _iter_field_names(node):
        # A fresh zero extent per field; the instances are deliberately not shared.
        field_extents[field_name] = Extent.zeros()

    ctx = self.StencilContext()
    field_ifs = node.iter_tree().if_isinstance(gtir.FieldIfStmt)
    for field_if in field_ifs:
        self.visit(field_if, ctx=ctx)

    # Walk the assignments backwards (last statement first).
    assigns_last_to_first = _iter_assigns(node).to_list()[::-1]
    for assign in assigns_last_to_first:
        self.visit(assign, ctx=ctx, field_extents=field_extents)

    return field_extents
def visit_Stencil(self, node: gtir.Stencil, *, mask_inwards: bool, **kwargs: Any) -> FIELD_EXT_T:
    """Compute the extent of every field referenced by ``node``.

    All ``gtir.FieldIfStmt`` nodes are visited first with a shared context. The
    parallel assignments are then visited from last to first, with the same context
    and the ``field_extents`` mapping. Afterwards, every field from
    ``_iter_field_names(node)`` that still has no entry gets ``Extent.zeros()``.

    Args:
        mask_inwards: if true, every per-dimension pair ``(lo, hi)`` of each field's
            extent is replaced by ``(min(0, lo), max(0, hi))``. This zeroes
            inward-pointing extents.
    """
    field_extents: FIELD_EXT_T = {}
    ctx = self.StencilContext()

    for field_if in node.iter_tree().if_isinstance(gtir.FieldIfStmt):
        self.visit(field_if, ctx=ctx)

    for assign in _iter_assigns(node).to_list()[::-1]:
        self.visit(assign, ctx=ctx, field_extents=field_extents)

    for field_name in _iter_field_names(node):
        # Fill in missing fields only now, after visiting. Starting from zero extents
        # would break inward-pointing extents (i.e. negative boundaries).
        extent = field_extents.setdefault(field_name, Extent.zeros())
        if not mask_inwards:
            continue
        # Clamp each dimension so that inward-pointing parts become zero.
        clamped_bounds = [(min(0, bounds[0]), max(0, bounds[1])) for bounds in extent]
        field_extents[field_name] = Extent(*clamped_bounds)

    return field_extents
def copy_computation(copy_v_loop):
    """Yield a ``copy_gtir`` Stencil with two 3-D FLOAT32 field params, ``foo`` and ``bar``.

    Both fields are declared with ``dimensions=(True, True, True)``. The stencil's
    only vertical loop is the ``copy_v_loop`` argument.
    """
    location = SourceLocation(line=1, column=1, source="copy_gtir")
    field_params = [
        FieldDecl(
            name=field_name,
            dtype=DataType.FLOAT32,
            dimensions=(True, True, True),
        )
        for field_name in ("foo", "bar")
    ]
    stencil = Stencil(
        name="copy_gtir",
        loc=location,
        params=field_params,
        vertical_loops=[copy_v_loop],
    )
    yield stencil
def get_nodes_with_name(stencil: Stencil, name: str):
    """Return a list of all nodes in ``stencil``'s tree whose ``name`` attribute equals ``name``."""

    def _has_matching_name(node):
        return node.name == name

    named_nodes = stencil.iter_tree().if_hasattr("name")
    return named_nodes.filter(_has_matching_name).to_list()
def get_nodes_with_name_and_dtype(stencil: Stencil, name: str):
    """Return a list of nodes in ``stencil``'s tree that have a ``dtype`` attribute
    and whose ``name`` attribute equals ``name``."""

    def _is_typed_match(node):
        if not hasattr(node, "dtype"):
            return False
        return node.name == name

    named_nodes = stencil.iter_tree().if_hasattr("name")
    return named_nodes.filter(_is_typed_match).to_list()
def build(self) -> Stencil:
    """Create a ``Stencil`` from ``self._name``, ``self._params`` and ``self._vertical_loops``.

    The stored values are passed through as-is; this method makes no copies.
    """
    stencil_kwargs = {
        "name": self._name,
        "params": self._params,
        "vertical_loops": self._vertical_loops,
    }
    return Stencil(**stencil_kwargs)
def _iter_assigns(node: gtir.Stencil) -> XIterator[gtir.ParAssignStmt]:
    """Return an iterator over every ``gtir.ParAssignStmt`` in the tree under ``node``."""
    tree_nodes = node.iter_tree()
    return tree_nodes.if_isinstance(gtir.ParAssignStmt)