Example #1
0
    def __init__(self, name, stencil: Stencil, nodes):
        """Create the SDFG being built and the lookup/bookkeeping tables used while building it.

        Args:
            name: Name of the new SDFG; its single state is named ``name + "_state"``.
            stencil: Stencil whose ``declarations`` and ``params`` provide the per-name
                dtype lookup and, for ``FieldDecl`` entries, the dimensions lookup.
            nodes: Node collection passed to ``nodes_extent_calculation`` and scanned for
                ``oir.FieldAccess`` nodes that use a ``VariableKOffset``. The names of
                those fields are stored in ``self._dynamic_k_fields``.
        """
        self._stencil = stencil
        self._sdfg = SDFG(name)
        self._state = self._sdfg.add_state(name + "_state")
        self._extents = nodes_extent_calculation(nodes)

        # Build the combined declaration list once; both lookups below iterate it.
        all_decls = stencil.declarations + stencil.params
        self._dtypes = {decl.name: decl.dtype for decl in all_decls}
        self._axes = {
            decl.name: decl.dimensions for decl in all_decls if isinstance(decl, FieldDecl)
        }

        # Most recent write / read AccessNode per name (presumably the field/data name;
        # confirm against the methods that populate these).
        self._recent_write_acc: Dict[str, dace.nodes.AccessNode] = dict()
        self._recent_read_acc: Dict[str, dace.nodes.AccessNode] = dict()

        self._access_nodes: Dict[str, dace.nodes.AccessNode] = dict()
        self._access_collection_cache: Dict[int, AccessCollector.GeneralAccessCollection] = dict()
        self._source_nodes: Dict[str, dace.nodes.AccessNode] = dict()
        self._delete_candidates: List[MultiConnectorEdge] = list()

        def _dynamic_k_field_names(node):
            """Yield the names of fields accessed with a ``VariableKOffset`` inside *node*.

            A name may be yielded more than once; the caller deduplicates via ``set``.
            """
            if isinstance(node, VerticalLoopLibraryNode):
                for _, section in node.sections:
                    yield from _dynamic_k_field_names(section)
            elif isinstance(node, dace.SDFG):
                for inner, _ in node.all_nodes_recursive():
                    # Only the gt4py library nodes carry OIR to inspect. Any other
                    # LibraryNode used to fall through to the iterable fallback below
                    # and fail there (a dace node is presumably not iterable).
                    if isinstance(inner, (VerticalLoopLibraryNode, HorizontalExecutionLibraryNode)):
                        yield from _dynamic_k_field_names(inner)
            elif isinstance(node, HorizontalExecutionLibraryNode):
                yield from (
                    acc.name
                    for acc in node.oir_node.iter_tree().if_isinstance(oir.FieldAccess)
                    if isinstance(acc.offset, VariableKOffset)
                )
            else:
                # Anything else is treated as an iterable collection of nodes.
                for child in node:
                    yield from _dynamic_k_field_names(child)

        self._dynamic_k_fields = set(_dynamic_k_field_names(nodes))
Example #2
0
    def __init__(self, name, stencil: Stencil, nodes):
        """Set up an empty SDFG with one state plus the tables used while filling it.

        Args:
            name: SDFG name; the state added to it is called ``name + "_state"``.
            stencil: Source of ``declarations`` and ``params``, used to build the
                name -> dtype table and the name -> dimensions table (``FieldDecl``
                entries only).
            nodes: Passed to ``nodes_extent_calculation``; the result is kept in
                ``self._extents``.
        """
        self._stencil = stencil
        self._sdfg = SDFG(name)
        self._state = self._sdfg.add_state(name + "_state")
        self._extents = nodes_extent_calculation(nodes)

        every_decl = stencil.declarations + stencil.params
        self._dtypes = {d.name: d.dtype for d in every_decl}
        self._axes = {
            d.name: d.dimensions
            for d in every_decl
            if isinstance(d, FieldDecl)
        }

        # Latest write / read AccessNode per name (presumably the data name;
        # confirm against the code that fills these in).
        self._recent_write_acc: Dict[str, dace.nodes.AccessNode] = {}
        self._recent_read_acc: Dict[str, dace.nodes.AccessNode] = {}

        self._access_nodes: Dict[str, dace.nodes.AccessNode] = {}
        self._access_collection_cache: Dict[
            int, AccessCollector.CartesianAccessCollection
        ] = {}
        self._source_nodes: Dict[str, dace.nodes.AccessNode] = {}
        self._delete_candidates: List[MultiConnectorEdge] = []