def __init__(self, name, stencil: Stencil, nodes):
    """Create an empty single-state SDFG plus the bookkeeping used to fill it.

    Args:
        name: Name of the new SDFG; its only state is named ``name + "_state"``.
        stencil: Stencil whose ``declarations`` and ``params`` provide the dtype
            of every declaration and the dimensions of every ``FieldDecl``.
        nodes: A ``VerticalLoopLibraryNode``, ``HorizontalExecutionLibraryNode``,
            ``dace.SDFG``, or a (possibly nested) iterable of those. It is handed
            to ``nodes_extent_calculation`` and scanned for field accesses that
            use a ``VariableKOffset``.
    """
    self._stencil = stencil
    self._sdfg = SDFG(name)
    self._state = self._sdfg.add_state(name + "_state")
    self._extents = nodes_extent_calculation(nodes)

    # Declarations and parameters are looked up together by name.
    all_decls = stencil.declarations + stencil.params
    self._dtypes = {d.name: d.dtype for d in all_decls}
    # Only field declarations carry dimensions.
    self._axes = {d.name: d.dimensions for d in all_decls if isinstance(d, FieldDecl)}

    # Name-keyed bookkeeping; everything starts empty here and is presumably
    # populated by other methods of the builder (not visible in this block).
    self._recent_write_acc: Dict[str, dace.nodes.AccessNode] = {}
    self._recent_read_acc: Dict[str, dace.nodes.AccessNode] = {}
    self._access_nodes: Dict[str, dace.nodes.AccessNode] = {}
    self._access_collection_cache: Dict[int, AccessCollector.GeneralAccessCollection] = {}
    self._source_nodes: Dict[str, dace.nodes.AccessNode] = {}
    self._delete_candidates: List[MultiConnectorEdge] = []

    def _variable_k_field_names(node):
        # Recursively yield the name of every oir.FieldAccess whose offset is a
        # VariableKOffset. Duplicates are possible; the caller collects a set.
        if isinstance(node, VerticalLoopLibraryNode):
            # Each entry of ``sections`` is a pair; only the second item is walked.
            for _, section in node.sections:
                yield from _variable_k_field_names(section)
            return
        if isinstance(node, dace.SDFG):
            # Only library nodes inside the SDFG can contain such accesses.
            for child, _ in node.all_nodes_recursive():
                if isinstance(child, dace.nodes.LibraryNode):
                    yield from _variable_k_field_names(child)
            return
        if isinstance(node, HorizontalExecutionLibraryNode):
            field_accesses = node.oir_node.iter_tree().if_isinstance(oir.FieldAccess)
            yield from [
                access.name
                for access in field_accesses
                if isinstance(access.offset, VariableKOffset)
            ]
            return
        # Any other object is treated as a container of further nodes.
        for child in node:
            yield from _variable_k_field_names(child)

    # Names of fields that are read/written with a run-time (variable) k offset.
    self._dynamic_k_fields = set(_variable_k_field_names(nodes))
def __init__(self, name, stencil: Stencil, nodes):
    """Create an empty single-state SDFG plus the bookkeeping used to fill it.

    Args:
        name: Name of the new SDFG; its only state is named ``name + "_state"``.
        stencil: Stencil whose ``declarations`` and ``params`` provide the dtype
            of every declaration and the dimensions of every ``FieldDecl``.
        nodes: Passed unchanged to ``nodes_extent_calculation``; the result is
            stored as ``self._extents``.
    """
    self._stencil = stencil
    self._sdfg = SDFG(name)
    self._state = self._sdfg.add_state(name + "_state")
    self._extents = nodes_extent_calculation(nodes)

    # Declarations and parameters are looked up together by name.
    declarations = stencil.declarations + stencil.params
    self._dtypes = {entry.name: entry.dtype for entry in declarations}
    # Only field declarations carry dimensions.
    self._axes = {
        entry.name: entry.dimensions
        for entry in declarations
        if isinstance(entry, FieldDecl)
    }

    # Name-keyed bookkeeping; everything starts empty here and is presumably
    # populated by other methods of the builder (not visible in this block).
    self._recent_write_acc: Dict[str, dace.nodes.AccessNode] = {}
    self._recent_read_acc: Dict[str, dace.nodes.AccessNode] = {}
    self._access_nodes: Dict[str, dace.nodes.AccessNode] = {}
    self._access_collection_cache: Dict[int, AccessCollector.CartesianAccessCollection] = {}
    self._source_nodes: Dict[str, dace.nodes.AccessNode] = {}
    self._delete_candidates: List[MultiConnectorEdge] = []