def test_iter_tree(tree):
    traversals = []
    for order in eve.iterators.TraversalOrder:
        values = [value for value in eve.iter_tree(tree, order, with_keys=True)]
        assert all(isinstance(v, tuple) for v in values)
        traversals.append(values)
        traversals.append([value for value in eve.iter_tree(tree, order)])

    assert all(len(traversals[0]) == len(t) for t in traversals)

def test_find_and_merge_with_2_vertical_loops(self):
    var1 = make_local_var("var1")
    assignment1, _ = make_init("field1")
    first_loop = make_horizontal_loop(make_block_stmt([assignment1], [var1]))

    var2 = make_local_var("var2")
    assignment2, _ = make_init("field2")
    second_loop = make_horizontal_loop(make_block_stmt([assignment2], [var2]))

    vertical_loop_1 = make_vertical_loop([first_loop, second_loop])
    vertical_loop_2 = vertical_loop_1.copy(deep=True)
    stencil = nir.Stencil(vertical_loops=[vertical_loop_1, vertical_loop_2])

    result = find_and_merge_horizontal_loops(stencil)

    vloops = eve.iter_tree(result).if_isinstance(nir.VerticalLoop).to_list()
    assert len(vloops) == 2
    for vloop in vloops:
        # TODO more precise checks
        assert len(vloop.horizontal_loops) == 1
        assert len(vloop.horizontal_loops[0].stmt.statements) == 2
        assert len(vloop.horizontal_loops[0].stmt.declarations) == 2

def find_and_merge_horizontal_loops(root: Node):
    copy = root.copy(deep=True)
    vertical_loops = eve.iter_tree(copy).if_isinstance(nir.VerticalLoop).to_list()
    for loop in vertical_loops:
        # The merge has to mutate `loop` in place for `copy` to change;
        # the return value was only rebound to the loop variable and never used.
        merge_horizontal_loops(loop, _find_merge_candidates(loop))

    return copy

def find_and_merge_neighbor_loops(root: Node):
    horizontal_loops = eve.iter_tree(root).if_isinstance(nir.HorizontalLoop).to_list()
    merge_groups = {
        id(h_loop): _find_merge_candidates(h_loop) for h_loop in horizontal_loops
    }
    new_root = MergeNeighborLoops.apply(root, merge_groups)
    return new_root

def visit_Stencil(self, node: oir.Stencil, **kwargs):
    vertical_loops = self.visit(node.vertical_loops, **kwargs)
    accessed_fields = (
        iter_tree(vertical_loops).if_isinstance(oir.FieldAccess).getattr("name").to_set()
    )
    declarations = [decl for decl in node.declarations if decl.name in accessed_fields]
    return oir.Stencil(
        name=node.name,
        vertical_loops=vertical_loops,
        params=node.params,
        declarations=declarations,
        loc=node.loc,
    )

def test_chain_assignment(self):
    stencil = make_stencil(
        fields=[make_field("field")],
        statements=[
            make_var_decl(name="var"),
            make_assign_to_local_var("var", make_field_acc("field")),
            make_var_decl(name="another_var", dtype=float_type, init=make_var_acc("var")),
        ],
    )

    result = InferLocalVariableLocationTypeTransformation.apply(stencil)

    vardecls = eve.iter_tree(result).if_isinstance(sir.VarDeclStmt).to_list()
    assert len(vardecls) == 2
    for vardecl in vardecls:
        assert vardecl.location_type == sir.LocationType.Cell

def _find_merge_candidates(
    h_loop: nir.HorizontalLoop,
) -> List[List[nir.NeighborLoop]]:
    # This finder is broken: it performs no data dependency analysis.
    # It only works for naive cases where neighbor loops are contiguous
    # and do not need to be reordered.
    merge_groups = []
    neighbor_loops = cast(
        List[nir.NeighborLoop],
        eve.iter_tree(h_loop).if_isinstance(nir.NeighborLoop).to_list(),
    )
    outer = 0
    max_len = len(neighbor_loops)
    while outer < max_len:
        # Extend the group while the following loops share the same connectivity.
        target_connectivity = neighbor_loops[outer].connectivity
        i = outer + 1
        while i < max_len and neighbor_loops[i].connectivity == target_connectivity:
            i += 1
        merge_groups.append(neighbor_loops[outer:i])
        outer = i

    return merge_groups

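# Illustrative sketch, not part of the original code: the contiguous grouping
# done by _find_merge_candidates can also be written with itertools.groupby.
# The helper name `group_contiguous` and the plain strings standing in for
# connectivities are assumptions made for this example only. Like the original,
# it performs no data dependency analysis and never reorders items.

```python
from itertools import groupby
from typing import Any, Callable, List, Sequence, TypeVar

T = TypeVar("T")


def group_contiguous(items: Sequence[T], key: Callable[[T], Any]) -> List[List[T]]:
    """Split `items` into runs of adjacent elements whose keys compare equal.

    Non-adjacent elements with the same key end up in separate groups.
    """
    return [list(run) for _, run in groupby(items, key=key)]


# Hypothetical stand-ins for the `connectivity` attribute of neighbor loops.
connectivities = ["edge->cell", "edge->cell", "edge->vertex", "edge->cell"]
groups = group_contiguous(connectivities, key=lambda c: c)
```

# With real nodes, the key would presumably be `lambda loop: loop.connectivity`.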
def test_reduction(self):
    stencil = make_stencil(
        fields=[],
        statements=[
            make_var_decl(name="var"),
            make_assign_to_local_var(
                "var",
                sir.ReductionOverNeighborExpr(
                    op="+",
                    rhs=make_literal(),
                    init=make_literal(),
                    chain=[sir.LocationType.Edge, sir.LocationType.Cell],
                ),
            ),
        ],
    )

    result = InferLocalVariableLocationTypeTransformation.apply(stencil)

    vardecl = eve.iter_tree(result).if_isinstance(sir.VarDeclStmt).to_list()[0]
    assert vardecl.location_type == sir.LocationType.Edge

def visit_HorizontalLoop(self, node: nir.HorizontalLoop, **kwargs):
    location_type_str = str(common.LocationType(node.location_type).name).lower()
    primary_connectivity = location_type_str + "_conn"
    connectivities = set()
    connectivities.add(
        usid.Connectivity(
            name=primary_connectivity,
            chain=usid.NeighborChain(elements=[node.location_type]),
        )
    )

    field_accesses = eve.iter_tree(node.stmt).if_isinstance(nir.FieldAccess).to_list()

    other_sids_entries = {}
    primary_sid_entries = set()
    for acc in field_accesses:
        if len(acc.primary.elements) == 1:
            assert acc.primary.elements[0] == node.location_type
            primary_sid_entries.add(usid.SidCompositeEntry(name=acc.name))
        else:
            # TODO cannot deal with more than one level of nesting
            assert len(acc.primary.elements) == 2
            # TODO change if we have more than one level of nesting
            secondary_loc = acc.primary.elements[-1]
            if secondary_loc not in other_sids_entries:
                other_sids_entries[secondary_loc] = set()
            other_sids_entries[secondary_loc].add(usid.SidCompositeEntry(name=acc.name))

    neighloops = eve.iter_tree(node.stmt).if_isinstance(nir.NeighborLoop).to_list()
    for loop in neighloops:
        transformed_neighbors = self.visit(loop.neighbors, **kwargs)
        connectivity_name = str(transformed_neighbors) + "_conn"
        connectivities.add(
            usid.Connectivity(name=connectivity_name, chain=transformed_neighbors)
        )
        primary_sid_entries.add(
            usid.SidCompositeNeighborTableEntry(connectivity=connectivity_name)
        )

    primary_sid = location_type_str
    sids = []
    sids.append(
        usid.SidComposite(
            name=primary_sid,
            entries=primary_sid_entries,
            location=usid.NeighborChain(elements=[node.location_type]),
        )
    )

    for k, v in other_sids_entries.items():
        chain = usid.NeighborChain(elements=[node.location_type, k])
        sids.append(
            usid.SidComposite(name=str(chain), entries=v, location=chain)
        )  # TODO _conn via property

    kernel_name = "kernel_" + node.id_
    kernel = usid.Kernel(
        ast=self.visit(
            node.stmt,
            sids_tbl={s.location: s for s in sids},
            conn_tbl={c.chain: c for c in connectivities},
            **kwargs,
        ),
        name=kernel_name,
        primary_connectivity=primary_connectivity,
        primary_sid=primary_sid,
        connectivities=connectivities,
        sids=sids,
    )
    return kernel, usid.KernelCall(name=kernel_name)