Exemplo n.º 1
0
 def visit_NeighborLoop(self, node: nir.NeighborLoop, **kwargs):
     """Lower a NIR neighbor loop into a ``usid.NeighborLoop``.

     Expects ``kwargs`` to carry two lookup tables:

     * ``sids_tbl`` -- maps a neighbor chain to its SID composite;
     * ``conn_tbl`` -- maps a neighbor chain to its connectivity.

     The outer SID is the one registered for the single-element chain of the
     loop's own location type. The inner SID is taken from ``sids_tbl`` when
     the loop's neighbor chain has an entry there, and is ``None`` otherwise.
     The loop body is visited with the same ``kwargs``.
     """
     sids_tbl = kwargs["sids_tbl"]
     neighbors = node.neighbors

     outer_sid_name = sids_tbl[usid.NeighborChain(elements=[node.location_type])].name
     # NOTE(review): ``neighbors`` is the NIR chain node, while the tables built
     # in visit_HorizontalLoop are keyed by USID chains -- confirm the two
     # compare/hash equal.
     connectivity_name = kwargs["conn_tbl"][neighbors].name
     if neighbors in sids_tbl:
         inner_sid_name = sids_tbl[neighbors].name
     else:
         inner_sid_name = None

     return usid.NeighborLoop(
         outer_sid=outer_sid_name,
         connectivity=connectivity_name,
         sid=inner_sid_name,
         location_type=node.location_type,
         # The body iterates over the last location type of the chain.
         body_location_type=neighbors.elements[-1],
         body=self.visit(node.body, **kwargs),
     )
Exemplo n.º 2
0
 def visit_NeighborChain(self, node: nir.NeighborChain, **kwargs):
     """Lower a NIR neighbor chain into the equivalent ``usid.NeighborChain``.

     The location types are copied element-for-element, in order, into a
     fresh list that is passed as ``elements``.
     """
     # ``list(...)`` replaces the former identity comprehension; same result.
     return usid.NeighborChain(elements=list(node.elements))
Exemplo n.º 3
0
    def visit_HorizontalLoop(self, node: nir.HorizontalLoop, **kwargs):
        """Lower a NIR horizontal loop into a USID kernel plus its call site.

        ``node.stmt`` is scanned for field accesses and neighbor loops to derive:

        * the kernel's connectivities -- the primary one, named
          ``<location>_conn``, plus one per neighbor loop, named
          ``<chain>_conn``;
        * the kernel's SID composites -- a primary SID named after the loop's
          location type, holding the fields accessed on that location plus one
          neighbor-table entry per neighbor-loop connectivity, and one extra SID
          per secondary location reached through a two-element chain.

        ``node.stmt`` is then visited with two tables added to ``kwargs``:
        ``sids_tbl`` (SIDs keyed by their ``location`` chain) and ``conn_tbl``
        (connectivities keyed by their ``chain``).

        Returns:
            A ``(usid.Kernel, usid.KernelCall)`` pair. The call refers to the
            kernel by its name, ``"kernel_" + node.id_attr_``.

        Raises:
            NotImplementedError: if a field access uses a neighbor chain whose
                length is not 1 or 2 (only one level of nesting is supported).
        """
        location_type_str = str(common.LocationType(node.location_type).name).lower()
        primary_connectivity = location_type_str + "_conn"
        connectivities = {
            usid.Connectivity(
                name=primary_connectivity, chain=usid.NeighborChain(elements=[node.location_type])
            )
        }

        # Fields read directly on the loop's location go into the primary SID.
        # Fields read through a (location_type, secondary) chain are grouped by
        # the secondary location type; each group becomes its own SID below.
        field_accesses = eve.FindNodes().by_type(nir.FieldAccess, node.stmt)
        other_sids_entries = {}
        primary_sid_entries = set()
        for acc in field_accesses:
            chain_elements = acc.primary.elements
            if len(chain_elements) == 1:
                assert chain_elements[0] == node.location_type
                primary_sid_entries.add(usid.SidCompositeEntry(name=acc.name))
            elif len(chain_elements) == 2:
                # TODO change if we have more than one level of nesting
                secondary_loc = chain_elements[-1]
                other_sids_entries.setdefault(secondary_loc, set()).add(
                    usid.SidCompositeEntry(name=acc.name)
                )
            else:
                # TODO cannot deal with more than one level of nesting.
                # Raised explicitly (not asserted) so it is not stripped by -O,
                # which would otherwise silently build a wrong two-element chain.
                raise NotImplementedError(
                    f"Field access '{acc.name}' uses a neighbor chain of length "
                    f"{len(chain_elements)}; only chains of length 1 or 2 are supported."
                )

        # Each neighbor loop contributes its own connectivity, and the primary
        # SID gets a neighbor-table entry naming that connectivity.
        neighloops = eve.FindNodes().by_type(nir.NeighborLoop, node.stmt)
        for loop in neighloops:
            transformed_neighbors = self.visit(loop.neighbors, **kwargs)
            connectivity_name = str(transformed_neighbors) + "_conn"
            connectivities.add(
                usid.Connectivity(name=connectivity_name, chain=transformed_neighbors)
            )
            primary_sid_entries.add(
                usid.SidCompositeNeighborTableEntry(connectivity=connectivity_name)
            )

        # NOTE(review): ``connectivities`` and the SID entry collections are sets,
        # so their iteration order may depend on element hashing -- confirm that
        # downstream code generation does not depend on it.
        primary_sid = location_type_str
        sids = [
            usid.SidComposite(
                name=primary_sid,
                entries=primary_sid_entries,
                location=usid.NeighborChain(elements=[node.location_type]),
            )
        ]
        for secondary_loc, entries in other_sids_entries.items():
            chain = usid.NeighborChain(elements=[node.location_type, secondary_loc])
            # TODO _conn via property
            sids.append(usid.SidComposite(name=str(chain), entries=entries, location=chain))

        kernel_name = "kernel_" + node.id_attr_
        kernel = usid.Kernel(
            ast=self.visit(
                node.stmt,
                sids_tbl={s.location: s for s in sids},
                conn_tbl={c.chain: c for c in connectivities},
                **kwargs,
            ),
            name=kernel_name,
            primary_connectivity=primary_connectivity,
            primary_sid=primary_sid,
            connectivities=connectivities,
            sids=sids,
        )
        return kernel, usid.KernelCall(name=kernel_name)