Esempio n. 1
0
def test_iter_tree(tree):
    """Check that ``eve.iter_tree`` yields the same number of items in every mode.

    For each member of ``eve.iterators.TraversalOrder`` the tree is walked
    twice: once with ``with_keys=True`` (every yielded item must then be a
    tuple) and once with the default arguments. All collected traversals
    must have the same length.
    """
    collected = []
    for order in eve.iterators.TraversalOrder:
        keyed_items = list(eve.iter_tree(tree, order, with_keys=True))
        assert all(isinstance(item, tuple) for item in keyed_items)
        collected.append(keyed_items)
        collected.append(list(eve.iter_tree(tree, order)))

    # Every traversal has the same length iff at most one distinct length
    # was observed (this also holds when nothing was collected at all).
    distinct_lengths = {len(items) for items in collected}
    assert len(distinct_lengths) <= 1
    def test_find_and_merge_with_2_vertical_loops(self):
        """Merge the horizontal loops of two structurally identical vertical loops.

        Each vertical loop starts with two horizontal loops, each holding one
        statement and one declaration. After ``find_and_merge_horizontal_loops``
        every vertical loop is expected to contain a single horizontal loop
        carrying both statements and both declarations.
        """

        def build_horizontal_loop(var_name, field_name):
            # Same construction order as a hand-written setup: local variable
            # first, then the initialisation statement, then the loop itself.
            local_var = make_local_var(var_name)
            init_stmt, _ = make_init(field_name)
            return make_horizontal_loop(
                make_block_stmt([init_stmt], [local_var]))

        first_loop = build_horizontal_loop("var1", "field1")
        second_loop = build_horizontal_loop("var2", "field2")

        original_vloop = make_vertical_loop([first_loop, second_loop])
        cloned_vloop = original_vloop.copy(deep=True)
        stencil = nir.Stencil(vertical_loops=[original_vloop, cloned_vloop])

        result = find_and_merge_horizontal_loops(stencil)

        found_vloops = eve.iter_tree(result).if_isinstance(
            nir.VerticalLoop).to_list()
        assert len(found_vloops) == 2
        for vloop in found_vloops:
            # TODO more precise checks
            merged_loops = vloop.horizontal_loops
            assert len(merged_loops) == 1
            merged_block = merged_loops[0].stmt
            assert len(merged_block.statements) == 2
            assert len(merged_block.declarations) == 2
Esempio n. 3
0
def find_and_merge_horizontal_loops(root: Node):
    """Return a deep copy of ``root`` with horizontal-loop merging applied.

    ``root`` itself is not passed to the merge step: a deep copy is made
    first, every ``nir.VerticalLoop`` found in that copy is handed to
    ``merge_horizontal_loops`` together with its merge candidates, and the
    copy is returned.
    """
    merged_root = root.copy(deep=True)
    # Materialise the list up front so the traversal is finished before any
    # vertical loop is processed.
    found_vloops = eve.iter_tree(merged_root).if_isinstance(
        nir.VerticalLoop).to_list()
    for v_loop in found_vloops:
        # NOTE(review): the return value is not used, so the merge is
        # presumably performed in place on the copied tree -- confirm
        # against merge_horizontal_loops.
        merge_horizontal_loops(v_loop, _find_merge_candidates(v_loop))

    return merged_root
Esempio n. 4
0
def find_and_merge_neighbor_loops(root: Node):
    """Apply ``MergeNeighborLoops`` to ``root`` and return its result.

    Merge candidates are computed for every ``nir.HorizontalLoop`` found in
    ``root`` and stored in a dict keyed by ``id()`` of the loop node; that
    dict is passed to ``MergeNeighborLoops.apply`` along with ``root``.
    """
    merge_groups = {}
    h_loops = eve.iter_tree(root).if_isinstance(nir.HorizontalLoop).to_list()
    for h_loop in h_loops:
        # Keyed by object identity, so the keys only make sense while the
        # nodes of ``root`` are alive.
        merge_groups[id(h_loop)] = _find_merge_candidates(h_loop)

    return MergeNeighborLoops.apply(root, merge_groups)
Esempio n. 5
0
 def visit_Stencil(self, node: oir.Stencil, **kwargs):
     """Rebuild the stencil, keeping only declarations whose field is accessed.

     The vertical loops are visited first; the names of all
     ``oir.FieldAccess`` nodes found in the visited loops decide which of
     ``node.declarations`` are kept. ``name``, ``params`` and ``loc`` are
     copied from ``node`` unchanged.
     """
     new_vertical_loops = self.visit(node.vertical_loops, **kwargs)

     field_accesses = iter_tree(new_vertical_loops).if_isinstance(oir.FieldAccess)
     used_field_names = field_accesses.getattr("name").to_set()

     kept_declarations = [
         declaration
         for declaration in node.declarations
         if declaration.name in used_field_names
     ]
     return oir.Stencil(
         name=node.name,
         vertical_loops=new_vertical_loops,
         params=node.params,
         declarations=kept_declarations,
         loc=node.loc,
     )
    def test_chain_assignment(self):
        """Location type is inferred through a chain of local variables.

        ``var`` is assigned from a field access and ``another_var`` is
        initialised from ``var``; after the transformation both variable
        declarations are expected to have location type ``Cell``.
        """
        fields = [make_field("field")]
        var_decl = make_var_decl(name="var")
        assign_from_field = make_assign_to_local_var(
            "var", make_field_acc("field"))
        chained_decl = make_var_decl(
            name="another_var",
            dtype=float_type,
            init=make_var_acc("var"),
        )
        stencil = make_stencil(
            fields=fields,
            statements=[var_decl, assign_from_field, chained_decl],
        )

        result = InferLocalVariableLocationTypeTransformation.apply(stencil)

        found_decls = eve.iter_tree(result).if_isinstance(
            sir.VarDeclStmt).to_list()
        assert len(found_decls) == 2
        for decl in found_decls:
            assert decl.location_type == sir.LocationType.Cell
Esempio n. 7
0
def _find_merge_candidates(
        h_loop: nir.HorizontalLoop) -> List[List[nir.NeighborLoop]]:
    """Group consecutive neighbor loops of ``h_loop`` by equal connectivity.

    The ``nir.NeighborLoop`` nodes found under ``h_loop`` are scanned in
    traversal order. A loop joins the current group when its connectivity
    equals that of the group's first loop; otherwise it starts a new group.
    Every neighbor loop ends up in exactly one group, so groups of length
    one are included.
    """
    # This finder is broken and it doesn't compute any data dependency analysis.
    # It will only work for naive cases where neighbor loops are contiguous
    # and they do not need to be reordered.
    found_loops = cast(
        List[nir.NeighborLoop],
        eve.iter_tree(h_loop).if_isinstance(nir.NeighborLoop).to_list())

    groups = []
    for candidate in found_loops:
        # Compare against the first loop of the current group, not the
        # previous loop.
        if groups and candidate.connectivity == groups[-1][0].connectivity:
            groups[-1].append(candidate)
        else:
            groups.append([candidate])

    return groups
    def test_reduction(self):
        """A variable assigned from a neighbor reduction gets the chain's first location.

        The reduction runs over the chain ``[Edge, Cell]``; the declaration of
        the variable receiving it is expected to be typed as ``Edge``.
        """
        var_decl = make_var_decl(name="var")
        reduction = sir.ReductionOverNeighborExpr(
            op="+",
            rhs=make_literal(),
            init=make_literal(),
            chain=[sir.LocationType.Edge, sir.LocationType.Cell],
        )
        assign_reduction = make_assign_to_local_var("var", reduction)
        stencil = make_stencil(
            fields=[],
            statements=[var_decl, assign_reduction],
        )

        result = InferLocalVariableLocationTypeTransformation.apply(stencil)

        found_decls = eve.iter_tree(result).if_isinstance(
            sir.VarDeclStmt).to_list()
        first_decl = found_decls[0]
        assert first_decl.location_type == sir.LocationType.Edge
Esempio n. 9
0
    def visit_HorizontalLoop(self, node: nir.HorizontalLoop, **kwargs):
        """Build a ``usid.Kernel`` and the matching ``usid.KernelCall`` for ``node``.

        The connectivities and SID composites of the kernel are collected
        from the ``nir.FieldAccess`` and ``nir.NeighborLoop`` nodes found in
        ``node.stmt``. ``node.stmt`` is then visited with two lookup tables
        passed as keyword arguments: ``sids_tbl`` (SID composites keyed by
        their location chain) and ``conn_tbl`` (connectivities keyed by
        their chain).

        Returns:
            A ``(kernel, kernel_call)`` tuple; both share the name
            ``"kernel_" + node.id_``.
        """
        loc_type = node.location_type
        location_type_str = str(common.LocationType(loc_type).name).lower()

        # The primary connectivity is named after the loop's own location type.
        primary_connectivity = location_type_str + "_conn"
        connectivities = {
            usid.Connectivity(
                name=primary_connectivity,
                chain=usid.NeighborChain(elements=[loc_type]),
            )
        }

        # Sort field accesses into the primary SID (single-element chain) or
        # into a per-secondary-location bucket (two-element chain).
        primary_sid_entries = set()
        other_sids_entries = {}
        accesses = eve.iter_tree(node.stmt).if_isinstance(
            nir.FieldAccess).to_list()
        for access in accesses:
            chain_elements = access.primary.elements
            if len(chain_elements) == 1:
                assert chain_elements[0] == loc_type
                primary_sid_entries.add(
                    usid.SidCompositeEntry(name=access.name))
                continue

            # TODO cannot deal with more than one level of nesting
            assert len(chain_elements) == 2
            # TODO change if we have more than one level of nesting
            secondary_loc = chain_elements[-1]
            other_sids_entries.setdefault(secondary_loc, set()).add(
                usid.SidCompositeEntry(name=access.name))

        # Each neighbor loop contributes a connectivity and a neighbor-table
        # entry in the primary SID.
        neighbor_loops = eve.iter_tree(node.stmt).if_isinstance(
            nir.NeighborLoop).to_list()
        for neighbor_loop in neighbor_loops:
            neighbor_chain = self.visit(neighbor_loop.neighbors, **kwargs)
            connectivity_name = str(neighbor_chain) + "_conn"
            connectivities.add(
                usid.Connectivity(name=connectivity_name,
                                  chain=neighbor_chain))
            primary_sid_entries.add(
                usid.SidCompositeNeighborTableEntry(
                    connectivity=connectivity_name))

        # The primary SID comes first, followed by one SID per secondary
        # location in first-seen order.
        primary_sid = location_type_str
        sids = [
            usid.SidComposite(
                name=primary_sid,
                entries=primary_sid_entries,
                location=usid.NeighborChain(elements=[loc_type]),
            )
        ]
        for secondary_loc, entries in other_sids_entries.items():
            chain = usid.NeighborChain(elements=[loc_type, secondary_loc])
            # TODO _conn via property
            sids.append(
                usid.SidComposite(name=str(chain), entries=entries,
                                  location=chain))

        kernel_name = "kernel_" + node.id_
        sids_tbl = {sid.location: sid for sid in sids}
        conn_tbl = {conn.chain: conn for conn in connectivities}
        kernel_ast = self.visit(
            node.stmt,
            sids_tbl=sids_tbl,
            conn_tbl=conn_tbl,
            **kwargs,
        )
        kernel = usid.Kernel(
            ast=kernel_ast,
            name=kernel_name,
            primary_connectivity=primary_connectivity,
            primary_sid=primary_sid,
            connectivities=connectivities,
            sids=sids,
        )
        return kernel, usid.KernelCall(name=kernel_name)