def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> oir.Stencil:
    """Replace temporaries that are written before being read.

    Starts from every symbol in ``kwargs["symtable"]`` whose value is an
    ``oir.Temporary``. For each horizontal execution it keeps only the
    temporaries that satisfy one of two conditions:

    * they are not accessed in that execution at all, or
    * they are accessed solely at offset ``(0, 0, 0)``, and the first access
      in access order is a write.

    The surviving names are passed to the base-class ``visit_Stencil`` as
    ``tmps_to_replace``, together with all received keyword arguments.
    """
    candidates = set()
    for name, symbol in kwargs["symtable"].items():
        if isinstance(symbol, oir.Temporary):
            candidates.add(name)

    for h_exec in node.iter_tree().if_isinstance(oir.HorizontalExecution):
        collected = AccessCollector.apply(h_exec)
        offsets_by_field = collected.offsets()
        accesses_in_order = collected.ordered_accesses()

        survivors = set()
        for name in candidates:
            if name not in offsets_by_field:
                # Not touched in this execution, so nothing can violate the rule.
                survivors.add(name)
            elif offsets_by_field[name] == {(0, 0, 0)}:
                # Only center accesses: the first access decides.
                first_access = next(
                    acc for acc in accesses_in_order if acc.field == name)
                if first_access.is_write:
                    survivors.add(name)
        candidates = survivors

    return super().visit_Stencil(node, tmps_to_replace=candidates, **kwargs)
Example #2
0
 def build(self) -> Stencil:
     """Return a ``Stencil`` assembled from the state stored on this object.

     The stencil is built from the current values of ``_name``, ``_params``,
     ``_vertical_loops`` and ``_declarations``.
     """
     stencil_fields = dict(
         name=self._name,
         params=self._params,
         vertical_loops=self._vertical_loops,
         declarations=self._declarations,
     )
     return Stencil(**stencil_fields)
Example #3
0
 def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> oir.Stencil:
     """Localize temporaries used by a single horizontal execution.

     A temporary is a symbol in ``node.symtable_`` whose value is an
     ``oir.Temporary``. Temporaries accessed (as ``oir.FieldAccess``) in
     exactly one horizontal execution are collected into ``local_tmps``.
     That set is passed, together with the stencil's symbol table, to the
     visit of the vertical loops. The declarations of those temporaries are
     dropped from the returned stencil.
     """
     temporaries = {
         symbol
         for symbol, value in node.symtable_.items()
         if isinstance(value, oir.Temporary)
     }
     horizontal_executions = node.iter_tree().if_isinstance(
         oir.HorizontalExecution)
     # Number of horizontal executions touching each temporary. Each
     # execution contributes a *set* of names, so repeated accesses within one
     # execution count once. Feeding a single Counter from a flat generator
     # avoids the new Counter that ``sum(..., Counter())`` builds per ``+``.
     counts: collections.Counter = collections.Counter(
         name
         for horizontal_execution in horizontal_executions
         for name in horizontal_execution.iter_tree().if_isinstance(
             oir.FieldAccess).getattr("name").if_in(temporaries).to_set())
     local_tmps = {tmp for tmp, count in counts.items() if count == 1}
     return oir.Stencil(
         name=node.name,
         params=node.params,
         vertical_loops=self.visit(node.vertical_loops,
                                   local_tmps=local_tmps,
                                   symtable=node.symtable_,
                                   **kwargs),
         declarations=[
             d for d in node.declarations if d.name not in local_tmps
         ],
     )
Example #4
0
 def visit_Stencil(self, node: oir.Stencil,
                   **kwargs: Any) -> Dict[str, oir.Expr]:
     """Collect the masks of ``oir.MaskStmt`` nodes that can be inlined.

     Every ``MaskStmt`` whose ``mask`` is an ``oir.FieldAccess`` registers
     the accessed field name with a ``None`` placeholder. The vertical loops
     are then visited with this dict passed as ``masks_to_inline``, with the
     other keyword arguments forwarded. That visit is expected to fill in
     every placeholder; an assertion checks this before the dict is returned.
     """
     field_masked_stmts = node.iter_tree().if_isinstance(oir.MaskStmt).filter(
         lambda stmt: isinstance(stmt.mask, oir.FieldAccess))
     # Values start as None placeholders until the visit below fills them in.
     masks_to_inline: Dict[str, oir.Expr] = {}
     for stmt in field_masked_stmts:
         masks_to_inline[stmt.mask.name] = None
     self.visit(node.vertical_loops, masks_to_inline=masks_to_inline, **kwargs)
     assert all(expr is not None for expr in masks_to_inline.values())
     return masks_to_inline
Example #5
0
 def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> oir.Stencil:
     """Find declared temporaries used by a single vertical loop.

     Names from ``node.declarations`` that are accessed (as
     ``oir.FieldAccess``) in exactly one vertical loop form ``local_tmps``.
     That set is handed to ``generic_visit`` along with all received keyword
     arguments.
     """
     vertical_loops = node.iter_tree().if_isinstance(oir.VerticalLoop)
     # Loop-invariant: built once rather than once per vertical loop.
     declared = {decl.name for decl in node.declarations}
     # Number of vertical loops touching each declared name. Each loop
     # contributes a *set*, so repeated accesses within one loop count once.
     counts: collections.Counter = collections.Counter(
         name
         for vertical_loop in vertical_loops
         for name in vertical_loop.iter_tree().if_isinstance(
             oir.FieldAccess).getattr("name").if_in(declared).to_set())
     local_tmps = {tmp for tmp, count in counts.items() if count == 1}
     return self.generic_visit(node, local_tmps=local_tmps, **kwargs)
Example #6
0
 def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> oir.Stencil:
     """Select declared temporaries used by a single horizontal execution.

     Names from ``node.declarations`` that are accessed (as
     ``oir.FieldAccess``) in exactly one horizontal execution are passed to
     the base-class ``visit_Stencil`` as ``tmps_to_replace``. The keyword
     arguments received are forwarded too, rather than being silently
     dropped.
     """
     horizontal_executions = node.iter_tree().if_isinstance(
         oir.HorizontalExecution)
     # Loop-invariant: built once rather than once per horizontal execution.
     declared = {decl.name for decl in node.declarations}
     # Number of horizontal executions touching each declared name. Each
     # execution contributes a *set*, so repeated accesses count once.
     counts: collections.Counter = collections.Counter(
         name
         for horizontal_execution in horizontal_executions
         for name in horizontal_execution.iter_tree().if_isinstance(
             oir.FieldAccess).getattr("name").if_in(declared).to_set())
     local_tmps = {tmp for tmp, count in counts.items() if count == 1}
     return super().visit_Stencil(node, tmps_to_replace=local_tmps, **kwargs)
Example #7
0
def has_variable_access(stencil: oir.Stencil) -> bool:
    """Return whether *stencil* contains any ``oir.VariableKOffset`` node.

    The tree traversal stops at the first match instead of materializing a
    list of every match only to test it for emptiness.
    """
    return any(
        True for _ in stencil.iter_tree().if_isinstance(oir.VariableKOffset))