Esempio n. 1
0
 def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> oir.Stencil:
     """Visit each vertical loop, passing it the flushing k-cache fields that are pruneable.

     A field held in a flushing ``oir.KCache`` of a vertical loop is reported
     as pruneable for that loop when either

     * the loop reads the field but never writes it, or
     * the field is listed in ``node.declarations`` and no later vertical
       loop reads it.

     Args:
         node: Stencil whose vertical loops are visited.
         **kwargs: Forwarded unchanged to every child visit.

     Returns:
         A new ``oir.Stencil`` with visited params and vertical loops; the
         declarations and source location are taken from ``node``.
     """
     accesses = [
         AccessCollector.apply(vertical_loop)
         for vertical_loop in node.vertical_loops
     ]
     # Loop-invariant: names of all fields declared by the stencil.
     declared_names = {str(d.name) for d in node.declarations}
     vertical_loops = []
     for i, vertical_loop in enumerate(node.vertical_loops):
         flushing_fields = {
             str(c.name)
             for c in vertical_loop.caches
             if isinstance(c, oir.KCache) and c.flush
         }
         read_only_fields = flushing_fields & (
             accesses[i].read_fields() - accesses[i].write_fields())
         # Everything read by the vertical loops that come after this one.
         future_reads: Set[str] = set()
         for later_access in accesses[i + 1:]:
             future_reads.update(later_access.read_fields())
         tmps_without_reuse = (flushing_fields & declared_names) - future_reads
         pruneable = read_only_fields | tmps_without_reuse
         vertical_loops.append(
             self.visit(vertical_loop, pruneable=pruneable, **kwargs))
     return oir.Stencil(
         name=node.name,
         params=self.visit(node.params, **kwargs),
         vertical_loops=vertical_loops,
         declarations=node.declarations,
         # Keep the source location on the rebuilt node.
         loc=node.loc,
     )
Esempio n. 2
0
 def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> oir.Stencil:
     """Visit the vertical loops with the set of temporaries local to one horizontal execution.

     A temporary (an ``oir.Temporary`` entry of ``node.symtable_``) counts as
     local when exactly one ``oir.HorizontalExecution`` in the stencil
     contains an ``oir.FieldAccess`` to it. Those names are passed to the
     child visits as ``local_tmps`` and their declarations are dropped from
     the returned stencil.

     Args:
         node: Stencil to process.
         **kwargs: Forwarded to the visit of the vertical loops.

     Returns:
         A new ``oir.Stencil`` with visited vertical loops and without the
         declarations of the local temporaries.
     """
     temporaries = {
         symbol
         for symbol, value in node.symtable_.items()
         if isinstance(value, oir.Temporary)
     }
     # For every temporary, count the horizontal executions that access it.
     # Each execution contributes a *set* of names, so it counts at most once.
     counts: collections.Counter = collections.Counter()
     for horizontal_execution in node.iter_tree().if_isinstance(
             oir.HorizontalExecution):
         counts.update(
             horizontal_execution.iter_tree()
             .if_isinstance(oir.FieldAccess)
             .getattr("name")
             .if_in(temporaries)
             .to_set())
     local_tmps = {tmp for tmp, count in counts.items() if count == 1}
     return oir.Stencil(
         name=node.name,
         params=node.params,
         vertical_loops=self.visit(node.vertical_loops,
                                   local_tmps=local_tmps,
                                   symtable=node.symtable_,
                                   **kwargs),
         declarations=[
             d for d in node.declarations if d.name not in local_tmps
         ],
         # Keep the source location on the rebuilt node.
         loc=node.loc,
     )
Esempio n. 3
0
 def visit_Stencil(self, node: gtir.Stencil, **kwargs: Any) -> oir.Stencil:
     """Translate a ``gtir.Stencil`` into an ``oir.Stencil``.

     A fresh ``self.Context()`` is handed to the visit of the vertical loops;
     whatever that visit leaves in ``ctx.decls`` becomes the declarations of
     the result.

     Args:
         node: The gtir stencil to translate.
         **kwargs: Accepted for interface compatibility; not forwarded to
             the child visits.

     Returns:
         The translated ``oir.Stencil``, carrying the source location of
         ``node``.
     """
     ctx = self.Context()
     params = self.visit(node.params)
     # Must run before ``ctx.decls`` is read below.
     vertical_loops = self.visit(node.vertical_loops, ctx=ctx)
     return oir.Stencil(
         name=node.name,
         params=params,
         vertical_loops=vertical_loops,
         declarations=ctx.decls,
         # Keep the source location on the translated node.
         loc=node.loc,
     )
Esempio n. 4
0
 def visit_Stencil(self, node: gtir.Stencil) -> oir.Stencil:
     """Translate a ``gtir.Stencil`` into an ``oir.Stencil``.

     The vertical loops are visited first with a fresh ``self.Context()``;
     the ``temp_fields`` found on that context afterwards are used as the
     declarations of the resulting stencil.

     Args:
         node: The gtir stencil to translate.

     Returns:
         The translated ``oir.Stencil`` with the name and source location of
         ``node``.
     """
     context = self.Context()
     # Visit the loops before the params, then read the context.
     lowered_loops = self.visit(node.vertical_loops, ctx=context)
     lowered_params = self.visit(node.params)
     return oir.Stencil(
         name=node.name,
         params=lowered_params,
         vertical_loops=lowered_loops,
         declarations=context.temp_fields,
         loc=node.loc,
     )
Esempio n. 5
0
 def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> oir.Stencil:
     """Visit the vertical loops and append the temporaries they create to the declarations.

     An empty ``new_tmps`` list is passed to the child visits; its contents
     after the visit are appended to ``node.declarations``.

     Args:
         node: Stencil to process.
         **kwargs: Forwarded to the visit of the vertical loops. Must contain
             ``"symtable"``, whose keys seed ``symbol_name_creator``.

     Returns:
         A new ``oir.Stencil`` with visited vertical loops and the extended
         declaration list.

     Raises:
         KeyError: If ``"symtable"`` is missing from ``kwargs``.
     """
     new_tmps: List[oir.Temporary] = []
     vertical_loops = self.visit(
         node.vertical_loops,
         new_tmps=new_tmps,
         new_symbol_name=symbol_name_creator(set(kwargs["symtable"])),
         **kwargs,
     )
     return oir.Stencil(
         name=node.name,
         params=node.params,
         vertical_loops=vertical_loops,
         # Read ``new_tmps`` only after the visit above has run.
         declarations=node.declarations + new_tmps,
         # Keep the source location on the rebuilt node.
         loc=node.loc,
     )
Esempio n. 6
0
 def visit_Stencil(self, node: oir.Stencil, **kwargs):
     """Visit the vertical loops and keep only declarations that are still accessed.

     After the visit, the names of all ``oir.FieldAccess`` nodes found in the
     visited loops are collected; declarations of ``node`` whose name is not
     among them are dropped.

     Args:
         node: Stencil to process.
         **kwargs: Forwarded to the visit of the vertical loops.

     Returns:
         A new ``oir.Stencil`` with visited loops and filtered declarations.
     """
     visited_loops = self.visit(node.vertical_loops, **kwargs)
     field_accesses = iter_tree(visited_loops).if_isinstance(oir.FieldAccess)
     used_names = field_accesses.getattr("name").to_set()
     kept_declarations = []
     for declaration in node.declarations:
         if declaration.name in used_names:
             kept_declarations.append(declaration)
     return oir.Stencil(
         name=node.name,
         vertical_loops=visited_loops,
         params=node.params,
         declarations=kept_declarations,
         loc=node.loc,
     )
Esempio n. 7
0
 def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> oir.Stencil:
     """Visit the vertical loops and drop the declarations of replaced temporaries.

     Args:
         node: Stencil to process.
         **kwargs: Forwarded to the visit of the vertical loops. Must contain
             ``"tmps_to_replace"`` (names whose declarations are removed) and
             ``"symtable"`` (its keys seed ``symbol_name_creator``).

     Returns:
         A new ``oir.Stencil`` with visited vertical loops and without the
         declarations named in ``tmps_to_replace``.

     Raises:
         KeyError: If ``"tmps_to_replace"`` or ``"symtable"`` is missing
             from ``kwargs``.
     """
     tmps_to_replace = kwargs["tmps_to_replace"]
     vertical_loops = self.visit(
         node.vertical_loops,
         new_symbol_name=symbol_name_creator(set(kwargs["symtable"])),
         **kwargs,
     )
     return oir.Stencil(
         name=node.name,
         params=node.params,
         vertical_loops=vertical_loops,
         declarations=[
             d for d in node.declarations if d.name not in tmps_to_replace
         ],
         # Keep the source location on the rebuilt node.
         loc=node.loc,
     )
Esempio n. 8
0
 def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> oir.Stencil:
     """Visit the vertical loops and keep only declarations that are still accessed.

     The child visits receive the stencil's symbol table as ``symtable`` and
     a ``new_symbol_name`` callable built by ``symbol_name_creator`` from a
     copy of the symbol table's keys.

     Args:
         node: Stencil to process.
         **kwargs: Forwarded to the visit of the vertical loops.

     Returns:
         A new ``oir.Stencil`` whose declarations are restricted to the
         fields reported by ``AccessCollector`` for the visited loops.
     """
     vertical_loops = self.visit(
         node.vertical_loops,
         symtable=node.symtable_,
         new_symbol_name=symbol_name_creator(set(node.symtable_)),
         **kwargs,
     )
     accessed = AccessCollector.apply(vertical_loops).fields()
     return oir.Stencil(
         name=node.name,
         params=node.params,
         vertical_loops=vertical_loops,
         declarations=[d for d in node.declarations if d.name in accessed],
         # Keep the source location on the rebuilt node.
         loc=node.loc,
     )
Esempio n. 9
0
 def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> oir.Stencil:
     """Visit the vertical loops and drop the declarations of replaced temporaries.

     Args:
         node: Stencil to process.
         **kwargs: Forwarded to the visit of the vertical loops. Must contain
             ``"tmps_to_replace"``, the names whose declarations are removed.

     Returns:
         A new ``oir.Stencil`` with visited vertical loops, the remaining
         declarations and the source location of ``node``.

     Raises:
         KeyError: If ``"tmps_to_replace"`` is missing from ``kwargs``.
     """
     replaced = kwargs["tmps_to_replace"]
     # Seed the name generator with the names collected from the stencil.
     name_factory = symbol_name_creator(collect_symbol_names(node))
     visited_loops = self.visit(
         node.vertical_loops,
         new_symbol_name=name_factory,
         **kwargs,
     )
     surviving = []
     for declaration in node.declarations:
         if declaration.name not in replaced:
             surviving.append(declaration)
     return oir.Stencil(
         name=node.name,
         params=node.params,
         vertical_loops=visited_loops,
         declarations=surviving,
         loc=node.loc,
     )
Esempio n. 10
0
    def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> oir.Stencil:
        """Visit all vertical loops and merge neighbouring ones where possible.

        Each vertical loop is visited in order. A visited loop is merged into
        the previously kept loop when ``self._mergeable`` accepts the pair;
        otherwise it is kept as a separate loop.

        Args:
            node: Stencil to process.
            **kwargs: Forwarded to every child visit.

        Returns:
            A new ``oir.Stencil`` with the (possibly shorter) list of
            vertical loops, or the result of ``generic_visit`` when the
            stencil has no vertical loops.
        """
        if not node.vertical_loops:
            return self.generic_visit(node, **kwargs)

        vertical_loops = [self.visit(node.vertical_loops[0], **kwargs)]
        for vertical_loop in node.vertical_loops[1:]:
            vertical_loop = self.visit(vertical_loop, **kwargs)
            if self._mergeable(vertical_loops[-1], vertical_loop):
                # Fold the current loop into the last kept one.
                vertical_loops[-1] = self._merge(vertical_loops[-1],
                                                 vertical_loop)
            else:
                vertical_loops.append(vertical_loop)

        return oir.Stencil(
            name=node.name,
            params=node.params,
            vertical_loops=vertical_loops,
            declarations=node.declarations,
            # Keep the source location on the rebuilt node.
            loc=node.loc,
        )
Esempio n. 11
0
def convert(sdfg: dace.SDFG) -> oir.Stencil:
    """Build an ``oir.Stencil`` from an SDFG.

    The SDFG is validated first. Its states are walked in topological order;
    within each state, every ``VerticalLoopLibraryNode`` (again in
    topological order) is converted with ``as_oir()`` and passed through an
    ``OIRFieldRenamer`` built from ``get_node_name_mapping(state, node)``.

    Args:
        sdfg: The SDFG to convert.

    Returns:
        An ``oir.Stencil`` named after the SDFG, with params and declarations
        from ``sdfg_arrays_to_oir_decls`` and the collected vertical loops.
    """
    validate_oir_sdfg(sdfg)

    params, decls = sdfg_arrays_to_oir_decls(sdfg)

    collected_loops = []
    for state in sdfg.topological_sort(sdfg.start_state):
        for candidate in nx.topological_sort(state.nx):
            if not isinstance(candidate, VerticalLoopLibraryNode):
                continue
            renamer = OIRFieldRenamer(get_node_name_mapping(state, candidate))
            collected_loops.append(renamer.visit(candidate.as_oir()))

    return oir.Stencil(
        name=sdfg.name,
        params=params,
        declarations=decls,
        vertical_loops=collected_loops,
    )
 def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> oir.Stencil:
     """Visit the vertical loops from last to first and drop unused declarations.

     The loops are visited in reverse order but returned in their original
     order. Every visit receives the same ``protected_fields`` set (initially
     the names of the stencil params) and a ``new_symbol_name`` callable
     created from the shared ``all_names`` set.

     Args:
         node: Stencil to process.
         **kwargs: Forwarded to every child visit.

     Returns:
         A new ``oir.Stencil`` whose declarations are restricted to the
         fields reported by ``AccessCollector`` for the visited loops.
     """
     protected_fields = set(n.name for n in node.params)
     all_names = collect_symbol_names(node)
     # Visit back-to-front, then restore the original loop order.
     visited_backwards = [
         self.visit(
             vl,
             new_symbol_name=symbol_name_creator(all_names),
             protected_fields=protected_fields,
             **kwargs,
         ) for vl in reversed(node.vertical_loops)
     ]
     vertical_loops: List[oir.VerticalLoop] = visited_backwards[::-1]
     accessed = AccessCollector.apply(vertical_loops).fields()
     return oir.Stencil(
         name=node.name,
         params=node.params,
         vertical_loops=vertical_loops,
         declarations=[d for d in node.declarations if d.name in accessed],
         # Keep the source location on the rebuilt node.
         loc=node.loc,
     )
Esempio n. 13
0
    def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> oir.Stencil:
        """Visit the vertical loops from last to first and drop unused declarations.

        ``protected_fields`` starts as the names of the stencil params. After
        each loop is visited, the fields read by the *visited* loop are added
        to it, so loops earlier in the stencil see the reads of all later
        loops. One ``new_symbol_name`` callable is shared by all visits.

        Args:
            node: Stencil to process.
            **kwargs: Forwarded to every child visit.

        Returns:
            A new ``oir.Stencil`` with the visited loops in their original
            order and only the declarations that are still accessed.
        """
        name_factory = symbol_name_creator(collect_symbol_names(node))
        protected_fields = {param.name for param in node.params}

        processed = []
        for original_loop in reversed(node.vertical_loops):
            visited_loop = self.visit(original_loop,
                                      new_symbol_name=name_factory,
                                      protected_fields=protected_fields,
                                      **kwargs)
            processed.append(visited_loop)
            # Grow the same set object that is handed to the next visit.
            protected_fields |= AccessCollector.apply(visited_loop).read_fields()
        # Loops were collected back-to-front; restore the original order.
        vertical_loops = processed[::-1]

        still_accessed = AccessCollector.apply(vertical_loops).fields()
        kept_declarations = [
            decl for decl in node.declarations if decl.name in still_accessed
        ]
        return oir.Stencil(
            name=node.name,
            params=node.params,
            vertical_loops=vertical_loops,
            declarations=kept_declarations,
            loc=node.loc,
        )
 def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> oir.Stencil:
     """Visit the vertical loops from last to first and drop unused declarations.

     ``protected_fields`` starts as the names of the stencil params. After
     each loop is visited, all fields accessed by the *original* (unvisited)
     loop are added to it, so loops earlier in the stencil see the accesses
     of all later loops.

     Args:
         node: Stencil to process.
         **kwargs: Forwarded to every child visit. Must contain
             ``"symtable"``, whose keys seed ``symbol_name_creator``.

     Returns:
         A new ``oir.Stencil`` with the visited loops in their original order
         and only the declarations that are still accessed.

     Raises:
         KeyError: If ``"symtable"`` is missing from ``kwargs`` and the
             stencil has at least one vertical loop.
     """
     protected_fields = set(n.name for n in node.params)
     vertical_loops: List[oir.VerticalLoop] = []
     for vl in reversed(node.vertical_loops):
         # NOTE(review): a fresh name creator is built for every loop from
         # the same symtable keys; whether generated names can collide
         # across loops depends on symbol_name_creator -- confirm.
         vertical_loops.append(
             self.visit(
                 vl,
                 new_symbol_name=symbol_name_creator(set(
                     kwargs["symtable"])),
                 protected_fields=protected_fields,
                 **kwargs,
             ))
         access_collection = AccessCollector.apply(vl)
         protected_fields |= access_collection.fields()
     # Collected back-to-front with O(1) appends; restore the original order.
     vertical_loops.reverse()
     accessed = AccessCollector.apply(vertical_loops).fields()
     return oir.Stencil(
         name=node.name,
         params=node.params,
         vertical_loops=vertical_loops,
         declarations=[d for d in node.declarations if d.name in accessed],
         # Keep the source location on the rebuilt node.
         loc=node.loc,
     )