def prune_unused_parameters(node: gtir.Stencil) -> gtir.Stencil:
    """
    Remove unused parameters from the gtir signature.

    A parameter is kept only if its name is referenced by at least one
    ``gtir.FieldAccess`` or ``gtir.ScalarAccess`` node anywhere in the stencil
    tree. The relative order of the kept parameters is preserved.

    (Maybe this pass should go into a later stage. If you need to touch this pass,
    e.g. when the definition_ir gets removed, consider moving it to a more appropriate
    level. Maybe to the backend IR?)

    Args:
        node: the stencil whose parameter list should be pruned.

    Returns:
        The result of ``node.copy(update={"params": ...})`` with only the
        referenced parameters.
    """
    assert isinstance(node, gtir.Stencil)
    # Collect the referenced names into a set: membership tests below are then
    # O(1) each instead of a linear scan over every access in the tree.
    used_names = set(
        node.iter_tree()
        .if_isinstance(gtir.FieldAccess, gtir.ScalarAccess)
        .getattr("name")
    )
    used_params = [param for param in node.params if param.name in used_names]
    return node.copy(update={"params": used_params})
Example #2
0
def copy_computation(copy_v_loop):
    """Yield a ``Stencil`` named "copy_gtir" with float32 fields ``a`` and ``b``.

    The single vertical loop comes from the ``copy_v_loop`` argument.
    """
    field_params = [
        FieldDecl(name=field_name, dtype=DataType.FLOAT32) for field_name in ("a", "b")
    ]
    location = SourceLocation(line=1, column=1, source="copy_gtir")
    stencil = Stencil(
        name="copy_gtir",
        loc=location,
        params=field_params,
        vertical_loops=[copy_v_loop],
    )
    yield stencil
Example #3
0
 def visit_Stencil(self, node: gtir.Stencil, **kwargs: Any) -> FIELD_EXT_T:
     """Build and return a mapping from field name to access extent for ``node``.

     Every field yielded by ``_iter_field_names`` starts at ``Extent.zeros()``.
     All ``gtir.FieldIfStmt`` nodes are visited first, sharing one
     ``StencilContext``. After that, every parallel assignment is visited in
     reverse order. The mapping is passed to each of these visits; presumably
     the assignment visitor updates it in place.
     """
     result = {}
     for field_name in _iter_field_names(node):
         result[field_name] = Extent.zeros()
     context = self.StencilContext()
     field_ifs = node.iter_tree().if_isinstance(gtir.FieldIfStmt)
     for stmt in field_ifs:
         self.visit(stmt, ctx=context)
     assigns = _iter_assigns(node).to_list()
     for stmt in assigns[::-1]:
         self.visit(stmt, ctx=context, field_extents=result)
     return result
Example #4
0
 def visit_Stencil(self, node: gtir.Stencil, *, mask_inwards: bool,
                   **kwargs: Any) -> FIELD_EXT_T:
     """Build and return a mapping from field name to access extent for ``node``.

     ``gtir.FieldIfStmt`` nodes are visited first, then parallel assignments in
     reverse order, all sharing one ``StencilContext`` and the same mapping.
     Afterwards, every field from ``_iter_field_names`` that was not given an
     extent receives ``Extent.zeros()``. If ``mask_inwards`` is true, each
     per-dimension pair ``(lower, upper)`` is then clipped to
     ``(min(0, lower), max(0, upper))``.
     """
     extents: FIELD_EXT_T = {}
     context = self.StencilContext()
     field_ifs = node.iter_tree().if_isinstance(gtir.FieldIfStmt)
     for stmt in field_ifs:
         self.visit(stmt, ctx=context)
     assigns = _iter_assigns(node).to_list()
     for stmt in assigns[::-1]:
         self.visit(stmt, ctx=context, field_extents=extents)
     for field_name in _iter_field_names(node):
         # Zeros are filled in only after the visits, not up front, because
         # pre-seeding with zeros would break inward-pointing extents
         # (negative boundaries).
         current = extents.setdefault(field_name, Extent.zeros())
         if not mask_inwards:
             continue
         # Clip inward-pointing components to zero.
         clipped = [(min(0, bounds[0]), max(0, bounds[1])) for bounds in current]
         extents[field_name] = Extent(*clipped)
     return extents
Example #5
0
File: test_gtir.py Project: fthaler/gt4py
def copy_computation(copy_v_loop):
    """Yield a ``Stencil`` named "copy_gtir" with 3-D float32 fields ``foo`` and ``bar``.

    The single vertical loop comes from the ``copy_v_loop`` argument.
    """
    field_params = [
        FieldDecl(
            name=field_name,
            dtype=DataType.FLOAT32,
            dimensions=(True, True, True),
        )
        for field_name in ("foo", "bar")
    ]
    location = SourceLocation(line=1, column=1, source="copy_gtir")
    stencil = Stencil(
        name="copy_gtir",
        loc=location,
        params=field_params,
        vertical_loops=[copy_v_loop],
    )
    yield stencil
Example #6
0
def get_nodes_with_name(stencil: Stencil, name: str):
    """Return a list of all nodes in ``stencil``'s tree whose ``name`` equals ``name``."""
    named_nodes = stencil.iter_tree().if_hasattr("name")
    return [candidate for candidate in named_nodes if candidate.name == name]
Example #7
0
def get_nodes_with_name_and_dtype(stencil: Stencil, name: str):
    """Return a list of nodes in ``stencil``'s tree that have a ``dtype``
    attribute and whose ``name`` equals ``name``."""
    named_nodes = stencil.iter_tree().if_hasattr("name")
    return [
        candidate
        for candidate in named_nodes
        if hasattr(candidate, "dtype") and candidate.name == name
    ]
Example #8
0
 def build(self) -> Stencil:
     """Construct a ``Stencil`` from the builder's stored name, params and vertical loops."""
     stencil_kwargs = dict(
         name=self._name,
         params=self._params,
         vertical_loops=self._vertical_loops,
     )
     return Stencil(**stencil_kwargs)
Example #9
0
def _iter_assigns(node: gtir.Stencil) -> XIterator[gtir.ParAssignStmt]:
    """Return an iterator over every ``gtir.ParAssignStmt`` in ``node``'s tree."""
    tree_nodes = node.iter_tree()
    return tree_nodes.if_isinstance(gtir.ParAssignStmt)