Пример #1
    def instantiate_loop(
        sdfg: sd.SDFG,
        loop_states: List[sd.SDFGState],
        loop_subgraph: gr.SubgraphView,
        itervar: str,
        value: symbolic.SymbolicType,
        state_suffix=None,
    ):
        """
        Creates one instance of a loop body in which the iteration variable
        is replaced by a given value.

        NOTE(review): takes no ``self``/``cls``; presumably declared as a
        static method outside this excerpt - confirm.

        :param sdfg: The SDFG that receives the copied states' edges; also
                     passed as the deserialization context of the copies.
        :param loop_states: The states that make up the loop body.
        :param loop_subgraph: Subgraph whose edges are replicated between the
                              copied states.
        :param itervar: Name of the iteration variable to replace.
        :param value: Value that replaces the iteration variable.
        :param state_suffix: Optional suffix for the labels of the copied
                             states. If None, ``value`` formatted with
                             ``%d`` is used.
        :return: The copied states, in the same order as ``loop_states``.
        """
        # Using to/from JSON copies faster than deepcopy (which will also
        # copy the parent SDFG)
        new_states = [
            sd.SDFGState.from_json(s.to_json(), context={'sdfg': sdfg})
            for s in loop_states
        ]

        # Replace iterate with value in each state
        for state in new_states:
            state.set_label(state.label + '_' + itervar + '_' + (
                state_suffix if state_suffix is not None else '%d' % value))
            state.replace(itervar, value)

        # Add subgraph to original SDFG
        for edge in loop_subgraph.edges():
            # Copies sit at the same positions as their originals
            src = new_states[loop_states.index(edge.src)]
            dst = new_states[loop_states.index(edge.dst)]

            # Replace conditions in subgraph edges (on a copy, so that the
            # original edge keeps referring to the iteration variable)
            data: sd.InterstateEdge = copy.deepcopy(edge.data)
            if data.condition:
                ASTFindReplace({itervar: str(value)}).visit(data.condition)

            sdfg.add_edge(src, dst, data)

        return new_states
Пример #2
def state_fission(sdfg: SDFG, subgraph: graph.SubgraphView) -> SDFGState:
    """
    Given a subgraph, adds a new SDFG state before the state that contains it,
    removes the subgraph from the original state, and connects the two states.

    :param sdfg: the SDFG in which the new state is created.
    :param subgraph: the subgraph to remove.
    :return: the newly created SDFG state.
    """
    state: SDFGState = subgraph.graph
    newstate = sdfg.add_state_before(state)

    # Save nodes and edges before removing nodes
    orig_edges = subgraph.edges()
    sub_nodes = list(subgraph.nodes())

    # Mark boundary access nodes to keep after fission: subgraph sources with
    # more than one outgoing edge and sinks with more than one incoming edge
    # stay in the original state
    nodes_to_remove = set(sub_nodes)
    nodes_to_remove -= set(n for n in subgraph.source_nodes()
                           if state.out_degree(n) > 1)
    nodes_to_remove -= set(n for n in subgraph.sink_nodes()
                           if state.in_degree(n) > 1)
    # NOTE(review): this call and add_nodes_from below were absent from this
    # copy (nodes_to_remove was computed but unused) - confirm against upstream
    state.remove_nodes_from(nodes_to_remove)

    for n in sub_nodes:
        if isinstance(n, nodes.NestedSDFG):
            # Set the new parent state
            n.sdfg.parent = newstate

    # Add nodes first so that nodes without edges are moved as well
    newstate.add_nodes_from(sub_nodes)

    for e in orig_edges:
        newstate.add_edge(e.src, e.src_conn, e.dst, e.dst_conn, e.data)

    return newstate
Пример #3
def state_fission(sdfg: SDFG, subgraph: graph.SubgraphView) -> SDFGState:
    """
    Given a subgraph, adds a new SDFG state before the state that contains it,
    removes the subgraph from the original state, and connects the two states.

    :param sdfg: the SDFG in which the new state is created.
    :param subgraph: the subgraph to remove.
    :return: the newly created SDFG state.
    """
    state: SDFGState = subgraph.graph
    newstate = sdfg.add_state_before(state)

    # Save nodes and edges before removing nodes
    orig_edges = subgraph.edges()
    sub_nodes = list(subgraph.nodes())

    # Mark boundary access nodes to keep after fission: nodes that have more
    # outgoing (or incoming) edges in the state than within the subgraph
    boundary_nodes = [
        n for n in sub_nodes
        if len(state.out_edges(n)) > len(subgraph.out_edges(n))
    ] + [
        n for n in sub_nodes
        if len(state.in_edges(n)) > len(subgraph.in_edges(n))
    ]

    # Make dictionary of nodes to add to new state. Boundary nodes remain in
    # the original state, so the new state receives deep copies of them
    new_nodes = {n: n for n in sub_nodes}
    new_nodes.update({b: copy.deepcopy(b) for b in boundary_nodes})

    nodes_to_remove = set(sub_nodes) - set(boundary_nodes)
    # NOTE(review): this call and add_nodes_from below were absent from this
    # copy (nodes_to_remove was computed but unused) - confirm against upstream
    state.remove_nodes_from(nodes_to_remove)

    for n in new_nodes.values():
        if isinstance(n, nodes.NestedSDFG):
            # Set the new parent state
            n.sdfg.parent = newstate

    # Add nodes first so that nodes without edges are moved as well
    newstate.add_nodes_from(new_nodes.values())

    for e in orig_edges:
        newstate.add_edge(new_nodes[e.src], e.src_conn, new_nodes[e.dst],
                          e.dst_conn, e.data)

    return newstate
Пример #4
 def get_actions(actions, graph, match):
     """
     Registers a match under every node and edge of the subgraph it matched.

     Keys have the form ``(element, pattern class name, expression index,
     version)``, where ``version`` is the lowest non-negative integer that
     makes the key unused, so earlier entries are never overwritten.

     :param actions: Dictionary to update (modified in place).
     :param graph: Graph whose ``nodes()`` are indexed by the ids stored in
                   ``match.subgraph``.
     :param match: The match to register; provides ``subgraph`` and
                   ``expr_index``.
     :return: The same ``actions`` dictionary.
     """
     pattern_name = type(match).__name__
     expr_index = match.expr_index

     def register(element):
         # Probe versions 0, 1, 2, ... until an unused key is found
         version = 0
         while (element, pattern_name, expr_index, version) in actions:
             version += 1
         actions[(element, pattern_name, expr_index, version)] = match

     # Fetch the node list once instead of once per node id
     all_nodes = graph.nodes()
     subgraph_nodes = [all_nodes[nid] for nid in match.subgraph.values()]

     for node in subgraph_nodes:
         register(node)
     for edge in SubgraphView(graph, subgraph_nodes).edges():
         register(edge)
     return actions
Пример #5
    def apply(self, sdfg: SDFG):
        state: SDFGState = sdfg.nodes()[self.state_id]
        nsdfg_node = state.nodes()[self.subgraph[InlineSDFG._nested_sdfg]]
        nsdfg: SDFG = nsdfg_node.sdfg
        nstate: SDFGState = nsdfg.nodes()[0]

        if nsdfg_node.schedule is not dtypes.ScheduleType.Default:
                nsdfg, nsdfg_node.schedule)

        nsdfg_scope_entry = state.entry_node(nsdfg_node)
        nsdfg_scope_exit = (state.exit_node(nsdfg_scope_entry)
                            if nsdfg_scope_entry is not None else None)

        # Collect and update top-level SDFG metadata

        # Global/init/exit code
        for loc, code in nsdfg.global_code.items():
            sdfg.append_global_code(code.code, loc)
        for loc, code in nsdfg.init_code.items():
            sdfg.append_init_code(code.code, loc)
        for loc, code in nsdfg.exit_code.items():
            sdfg.append_exit_code(code.code, loc)

        # Constants
        for cstname, cstval in nsdfg.constants.items():
            if cstname in sdfg.constants:
                if cstval != sdfg.constants[cstname]:
                    warnings.warn('Constant value mismatch for "%s" while '
                                  'inlining SDFG. Inner = %s != %s = outer' %
                                  (cstname, cstval, sdfg.constants[cstname]))
                sdfg.add_constant(cstname, cstval)

        # Find original source/destination edges (there is only one edge per
        # connector, according to match)
        inputs: Dict[str, MultiConnectorEdge] = {}
        outputs: Dict[str, MultiConnectorEdge] = {}
        input_set: Dict[str, str] = {}
        output_set: Dict[str, str] = {}
        for e in state.in_edges(nsdfg_node):
            inputs[e.dst_conn] = e
            input_set[e.data.data] = e.dst_conn
        for e in state.out_edges(nsdfg_node):
            outputs[e.src_conn] = e
            output_set[e.data.data] = e.src_conn

        # Access nodes that need to be reshaped
        reshapes: Set(str) = set()
        for aname, array in nsdfg.arrays.items():
            if array.transient:
            edge = None
            if aname in inputs:
                edge = inputs[aname]
                if len(array.shape) > len(edge.data.subset):
            if aname in outputs:
                edge = outputs[aname]
                if len(array.shape) > len(edge.data.subset):
            if edge is not None and not InlineSDFG._check_strides(
                    array.strides, sdfg.arrays[edge.data.data].strides,
                    edge.data, nsdfg_node):

        # Replace symbols using invocation symbol mapping
        # Two-step replacement (N -> __dacesym_N --> map[N]) to avoid clashes
        for symname, symvalue in nsdfg_node.symbol_mapping.items():
            if str(symname) != str(symvalue):
                nsdfg.replace(symname, '__dacesym_' + symname)
        for symname, symvalue in nsdfg_node.symbol_mapping.items():
            if str(symname) != str(symvalue):
                nsdfg.replace('__dacesym_' + symname, symvalue)

        # All transients become transients of the parent (if data already
        # exists, find new name)
        # Mapping from nested transient name to top-level name
        transients: Dict[str, str] = {}
        for node in nstate.nodes():
            if isinstance(node, nodes.AccessNode):
                datadesc = nsdfg.arrays[node.data]
                if node.data not in transients and datadesc.transient:
                    name = sdfg.add_datadesc('%s_%s' %
                                             (nsdfg.label, node.data),
                    transients[node.data] = name

        # All transients of edges between code nodes are also added to parent
        for edge in nstate.edges():
            if (isinstance(edge.src, nodes.CodeNode)
                    and isinstance(edge.dst, nodes.CodeNode)):
                if edge.data.data is not None:
                    datadesc = nsdfg.arrays[edge.data.data]
                    if edge.data.data not in transients and datadesc.transient:
                        name = sdfg.add_datadesc('%s_%s' %
                                                 (nsdfg.label, edge.data.data),
                        transients[edge.data.data] = name

        # Collect nodes to add to top-level graph
        new_incoming_edges: Dict[nodes.Node, MultiConnectorEdge] = {}
        new_outgoing_edges: Dict[nodes.Node, MultiConnectorEdge] = {}

        source_accesses = set()
        sink_accesses = set()
        for node in nstate.source_nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.data not in transients
                    and node.data not in reshapes):
                new_incoming_edges[node] = inputs[node.data]
        for node in nstate.sink_nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.data not in transients
                    and node.data not in reshapes):
                new_outgoing_edges[node] = outputs[node.data]

        # Replace data on inlined SDFG nodes/edges

        # Replace data names with their top-level counterparts
        repldict = {}
            k: v.data.data
            for k, v in itertools.chain(inputs.items(), outputs.items())

        # Add views whenever reshapes are necessary
        for dname in reshapes:
            desc = nsdfg.arrays[dname]
            # To avoid potential confusion, rename protected __return keyword
            if dname.startswith('__return'):
                newname = f'{nsdfg.name}_ret{dname[8:]}'
                newname = dname
            newname, _ = sdfg.add_view(newname,
            repldict[dname] = newname

        for node in nstate.nodes():
            if isinstance(node, nodes.AccessNode) and node.data in repldict:
                node.data = repldict[node.data]
        for edge in nstate.edges():
            if edge.data.data in repldict:
                edge.data.data = repldict[edge.data.data]

        # Add extra access nodes for out/in view nodes
        for node in nstate.nodes():
            if isinstance(node, nodes.AccessNode) and node.data in reshapes:
                if nstate.in_degree(node) > 0 and nstate.out_degree(node) > 0:
                    # Such a node has to be in the output set
                    edge = outputs[node.data]

                    # Redirect outgoing edges through access node
                    out_edges = list(nstate.out_edges(node))
                    anode = nstate.add_access(edge.data.data)
                    vnode = nstate.add_access(node.data)
                    nstate.add_nedge(node, anode, edge.data)
                    nstate.add_nedge(anode, vnode, edge.data)
                    for e in out_edges:
                        nstate.add_edge(vnode, e.src_conn, e.dst, e.dst_conn,

        # Add nested SDFG into top-level SDFG

        # Add nested nodes into original state
        subgraph = SubgraphView(nstate, [
            n for n in nstate.nodes()
            if n not in (source_accesses | sink_accesses)
        for edge in subgraph.edges():
            state.add_edge(edge.src, edge.src_conn, edge.dst, edge.dst_conn,

        # Reconnect inlined SDFG

        # If a source/sink node is one of the inputs/outputs, reconnect it,
        # replacing memlets in outgoing/incoming paths
        modified_edges = set()
        modified_edges |= self._modify_memlet_path(new_incoming_edges, nstate,
                                                   state, True)
        modified_edges |= self._modify_memlet_path(new_outgoing_edges, nstate,
                                                   state, False)

        # Reshape: add connections to viewed data
        self._modify_reshape_data(reshapes, repldict, inputs, nstate, state,
        self._modify_reshape_data(reshapes, repldict, outputs, nstate, state,

        # Modify all other internal edges pertaining to input/output nodes
        for node in subgraph.nodes():
            if isinstance(node, nodes.AccessNode):
                if node.data in input_set or node.data in output_set:
                    if node.data in input_set:
                        outer_edge = inputs[input_set[node.data]]
                        outer_edge = outputs[output_set[node.data]]

                    for edge in state.all_edges(node):
                        if (edge not in modified_edges
                                and edge.data.data == node.data):
                            for e in state.memlet_tree(edge):
                                if e.data.data == node.data:
                                    e._data = helpers.unsqueeze_memlet(
                                        e.data, outer_edge.data)

        # If source/sink node is not connected to a source/destination access
        # node, and the nested SDFG is in a scope, connect to scope with empty
        # memlets
        if nsdfg_scope_entry is not None:
            for node in subgraph.nodes():
                if state.in_degree(node) == 0:
                    state.add_edge(nsdfg_scope_entry, None, node, None,
                if state.out_degree(node) == 0:
                    state.add_edge(node, None, nsdfg_scope_exit, None,

        # Replace nested SDFG parents with new SDFG
        for node in nstate.nodes():
            if isinstance(node, nodes.NestedSDFG):
                node.sdfg.parent = state
                node.sdfg.parent_sdfg = sdfg
                node.sdfg.parent_nsdfg_node = node

        # Remove all unused external inputs/output memlet paths, as well as
        # resulting isolated nodes
        removed_in_edges = self._remove_edge_path(state,
                                                  set(inputs.keys()) -
        removed_out_edges = self._remove_edge_path(state,
                                                   set(outputs.keys()) -

        # Re-add in/out edges to first/last nodes in subgraph
        order = [
            x for x in nx.topological_sort(nstate._nx)
            if isinstance(x, nodes.AccessNode)
        for edge in removed_in_edges:
            # Find first access node that refers to this edge
            node = next(n for n in order if n.data == edge.data.data)
            state.add_edge(edge.src, edge.src_conn, node, edge.dst_conn,
        for edge in removed_out_edges:
            # Find last access node that refers to this edge
            node = next(n for n in reversed(order) if n.data == edge.data.data)
            state.add_edge(node, edge.src_conn, edge.dst, edge.dst_conn,

        # Remove nested SDFG node
Пример #6
def nest_state_subgraph(sdfg: SDFG,
                        state: SDFGState,
                        subgraph: SubgraphView,
                        name: Optional[str] = None,
                        full_data: bool = False) -> nodes.NestedSDFG:
    """ Turns a state subgraph into a nested SDFG. Operates in-place.

        :param sdfg: The SDFG containing the state subgraph.
        :param state: The state containing the subgraph.
        :param subgraph: Subgraph to nest.
        :param name: An optional name for the nested SDFG. If not given,
                     'nested_' followed by the state label is used.
        :param full_data: If True, nests entire input/output data.
        :return: The nested SDFG node.
        :raise KeyError: Some or all nodes in the subgraph are not located in
                         this state, or the state does not belong to the given
                         SDFG.
        :raise ValueError: The subgraph is contained in more than one scope,
                           or contains only part of a scope.
    """
    # NOTE(review): this copy of the function had lost several statements
    # (bodies of the conditionals that fill inputs/outputs/transients, the
    # `continue`/`else` branches of the descriptor and reconnection loops, and
    # the node insertion/removal calls). They were restored from how the
    # collected containers are used further down - confirm against upstream.
    if state.parent != sdfg:
        raise KeyError('State does not belong to given SDFG')
    if subgraph is not state and subgraph.graph is not state:
        raise KeyError('Subgraph does not belong to given state')

    # Find the top-level scope
    scope_tree = state.scope_tree()
    scope_dict = state.scope_dict()
    scope_dict_children = state.scope_children()
    top_scopenode = -1  # Initialized to -1 since "None" already means top-level

    for node in subgraph.nodes():
        if node not in scope_dict:
            raise KeyError('Node not found in state')

        # If scope entry/exit, ensure entire scope is in subgraph
        if isinstance(node, nodes.EntryNode):
            scope_nodes = scope_dict_children[node]
            if any(n not in subgraph.nodes() for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (entry)')
        elif isinstance(node, nodes.ExitNode):
            entry = state.entry_node(node)
            scope_nodes = scope_dict_children[entry] + [entry]
            if any(n not in subgraph.nodes() for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (exit)')

        # Nodes whose enclosing scope lies outside the subgraph must all
        # share that scope
        scope_node = scope_dict[node]
        if scope_node not in subgraph.nodes():
            if top_scopenode != -1 and top_scopenode != scope_node:
                raise ValueError('Subgraph is contained in more than one scope')
            top_scopenode = scope_node

    scope = scope_tree[top_scopenode]

    # Consolidate edges in top scope
    utils.consolidate_edges(sdfg, scope)
    snodes = subgraph.nodes()

    # Collect inputs and outputs of the nested SDFG: edges that cross the
    # subgraph boundary
    inputs: List[MultiConnectorEdge] = []
    outputs: List[MultiConnectorEdge] = []
    for node in snodes:
        for edge in state.in_edges(node):
            if edge.src not in snodes:
                inputs.append(edge)
        for edge in state.out_edges(node):
            if edge.dst not in snodes:
                outputs.append(edge)

    # Collect transients not used outside of subgraph (will be removed of
    # top-level graph)
    data_in_subgraph = set(n.data for n in subgraph.nodes() if isinstance(n, nodes.AccessNode))
    # Find other occurrences in SDFG
    other_nodes = set(n.data for s in sdfg.nodes() for n in s.nodes()
                      if isinstance(n, nodes.AccessNode) and n not in subgraph.nodes())
    subgraph_transients = set()
    for data in data_in_subgraph:
        datadesc = sdfg.arrays[data]
        if datadesc.transient and data not in other_nodes:
            subgraph_transients.add(data)

    # All transients of edges between code nodes are also added to nested graph
    for edge in subgraph.edges():
        if (isinstance(edge.src, nodes.CodeNode) and isinstance(edge.dst, nodes.CodeNode)):
            subgraph_transients.add(edge.data.data)

    # Collect data used in access nodes within subgraph (will be referenced in
    # full upon nesting)
    input_arrays = set()
    output_arrays = {}  # data name -> write-conflict resolution of first incoming edge
    for node in subgraph.nodes():
        if (isinstance(node, nodes.AccessNode) and node.data not in subgraph_transients):
            if node.has_reads(state):
                input_arrays.add(node.data)
            if node.has_writes(state):
                output_arrays[node.data] = state.in_edges(node)[0].data.wcr

    # Create the nested SDFG
    nsdfg = SDFG(name or 'nested_' + state.label)

    # Transients are added to the nested graph as-is
    for transient in subgraph_transients:
        nsdfg.add_datadesc(transient, sdfg.arrays[transient])

    # Input/output data that are not source/sink nodes are added to the graph
    # as non-transients
    for dname in (input_arrays | output_arrays.keys()):
        datadesc = copy.deepcopy(sdfg.arrays[dname])
        datadesc.transient = False
        nsdfg.add_datadesc(dname, datadesc)

    # Connected source/sink nodes outside subgraph become global data
    # descriptors in nested SDFG
    # Outer data name -> (name in nested SDFG, subset covered by all edges)
    global_subsets: Dict[str, Tuple[str, Subset]] = {}

    def register_boundary_edges(edges: List[MultiConnectorEdge]) -> Dict[MultiConnectorEdge, str]:
        """ Maps every non-empty boundary edge to its descriptor name in the
            nested SDFG, creating or growing the descriptor as needed. """
        names = {}
        for edge in edges:
            if edge.data.data is None:  # Skip edges with an empty memlet
                continue
            dname = edge.data.data
            if dname not in global_subsets:
                # First edge for this data: create the descriptor
                datadesc = copy.deepcopy(sdfg.arrays[dname])
                datadesc.transient = False
                if not full_data:
                    datadesc.shape = edge.data.subset.size()
                new_name = nsdfg.add_datadesc(dname, datadesc, find_new_name=True)
                global_subsets[dname] = (new_name, edge.data.subset)
            else:
                # Data seen before: extend the covered subset
                new_name, subset = global_subsets[dname]
                if not full_data:
                    new_subset = union(subset, edge.data.subset)
                    if new_subset is None:
                        # No union available: fall back to the whole array
                        new_subset = Range.from_array(sdfg.arrays[dname])
                    global_subsets[dname] = (new_name, new_subset)
                    nsdfg.arrays[new_name].shape = new_subset.size()
            names[edge] = new_name
        return names

    input_names = register_boundary_edges(inputs)
    output_names = register_boundary_edges(outputs)

    # Add scope symbols to the nested SDFG
    defined_vars = set(
        symbolic.pystr_to_symbolic(s) for s in (state.symbols_defined_at(top_scopenode).keys()
                                                | sdfg.symbols))
    for v in defined_vars:
        if v in sdfg.symbols:
            sym = sdfg.symbols[v]
            nsdfg.add_symbol(v, sym.dtype)

    # Add constants to nested SDFG
    for cstname, cstval in sdfg.constants.items():
        nsdfg.add_constant(cstname, cstval)

    # Create nested state
    nstate = nsdfg.add_state()

    # Add subgraph nodes and edges to nested state
    nstate.add_nodes_from(subgraph.nodes())
    for e in subgraph.edges():
        nstate.add_edge(e.src, e.src_conn, e.dst, e.dst_conn, copy.deepcopy(e.data))

    # Modify nested SDFG parents in subgraph
    for node in subgraph.nodes():
        if isinstance(node, nodes.NestedSDFG):
            node.sdfg.parent = nstate
            node.sdfg.parent_sdfg = nsdfg
            node.sdfg.parent_nsdfg_node = node

    # Add access nodes and edges as necessary
    edges_to_offset = []  # Pairs of (original outer edge, new inner edge)
    for edge, dname in input_names.items():
        node = nstate.add_read(dname)
        new_edge = copy.deepcopy(edge.data)
        new_edge.data = dname
        edges_to_offset.append((edge, nstate.add_edge(node, None, edge.dst, edge.dst_conn, new_edge)))
    for edge, dname in output_names.items():
        node = nstate.add_write(dname)
        new_edge = copy.deepcopy(edge.data)
        new_edge.data = dname
        edges_to_offset.append((edge, nstate.add_edge(edge.src, edge.src_conn, node, None, new_edge)))

    # Offset memlet paths inside nested SDFG according to subsets
    for original_edge, new_edge in edges_to_offset:
        for edge in nstate.memlet_tree(new_edge):
            edge.data.data = new_edge.data.data
            if not full_data:
                edge.data.subset.offset(global_subsets[original_edge.data.data][1], True)

    # Add nested SDFG node to the input state
    nested_sdfg = state.add_nested_sdfg(nsdfg, None,
                                        set(input_names.values()) | input_arrays,
                                        set(output_names.values()) | output_arrays.keys())

    # Reconnect memlets to nested SDFG (one edge per connector)
    reconnected_in = set()
    reconnected_out = set()
    empty_input = None
    empty_output = None
    for edge in inputs:
        if edge.data.data is None:
            empty_input = edge
            continue

        dname = input_names[edge]
        if dname in reconnected_in:
            continue
        if full_data:
            data = Memlet.from_array(edge.data.data, sdfg.arrays[edge.data.data])
        else:
            data = copy.deepcopy(edge.data)
            data.subset = global_subsets[edge.data.data][1]
        state.add_edge(edge.src, edge.src_conn, nested_sdfg, dname, data)
        reconnected_in.add(dname)

    for edge in outputs:
        if edge.data.data is None:
            empty_output = edge
            continue

        dname = output_names[edge]
        if dname in reconnected_out:
            continue
        if full_data:
            data = Memlet.from_array(edge.data.data, sdfg.arrays[edge.data.data])
        else:
            data = copy.deepcopy(edge.data)
            data.subset = global_subsets[edge.data.data][1]
        data.wcr = edge.data.wcr
        state.add_edge(nested_sdfg, dname, edge.dst, edge.dst_conn, data)
        reconnected_out.add(dname)

    # Connect access nodes to internal input/output data as necessary
    scope_entry = scope.entry
    scope_exit = scope.exit
    for dname in input_arrays:
        node = state.add_read(dname)
        if scope_entry is not None:
            # Empty memlet keeps the new access node inside the scope
            state.add_nedge(scope_entry, node, Memlet())
        state.add_edge(node, None, nested_sdfg, dname, Memlet.from_array(dname, sdfg.arrays[dname]))
    for dname, wcr in output_arrays.items():
        node = state.add_write(dname)
        if scope_exit is not None:
            state.add_nedge(node, scope_exit, Memlet())
        state.add_edge(nested_sdfg, dname, node, None, Memlet(data=dname, wcr=wcr))

    # Graph was not reconnected, but needs to be
    if state.in_degree(nested_sdfg) == 0 and empty_input is not None:
        state.add_edge(empty_input.src, empty_input.src_conn, nested_sdfg, None, empty_input.data)
    if state.out_degree(nested_sdfg) == 0 and empty_output is not None:
        state.add_edge(nested_sdfg, None, empty_output.dst, empty_output.dst_conn, empty_output.data)

    # Remove subgraph nodes from graph
    state.remove_nodes_from(subgraph.nodes())

    # Remove subgraph transients from top-level graph
    for transient in subgraph_transients:
        del sdfg.arrays[transient]

    # Remove newly isolated nodes due to memlet consolidation
    for edge in inputs:
        if state.in_degree(edge.src) + state.out_degree(edge.src) == 0:
            state.remove_node(edge.src)
    for edge in outputs:
        if state.in_degree(edge.dst) + state.out_degree(edge.dst) == 0:
            state.remove_node(edge.dst)

    return nested_sdfg
Пример #7
    def apply(self, sdfg):
        state: SDFGState = sdfg.nodes()[self.state_id]
        nsdfg_node = state.nodes()[self.subgraph[InlineSDFG._nested_sdfg]]
        nsdfg: SDFG = nsdfg_node.sdfg
        nstate: SDFGState = nsdfg.nodes()[0]

        nsdfg_scope_entry = state.entry_node(nsdfg_node)
        nsdfg_scope_exit = (state.exit_node(nsdfg_scope_entry)
                            if nsdfg_scope_entry is not None else None)

        # Collect and update top-level SDFG metadata

        # Global/init/exit code
        for loc, code in nsdfg.global_code.items():
            sdfg.append_global_code(code.code, loc)
        for loc, code in nsdfg.init_code.items():
            sdfg.append_init_code(code.code, loc)
        for loc, code in nsdfg.exit_code.items():
            sdfg.append_exit_code(code.code, loc)

        # Constants
        for cstname, cstval in nsdfg.constants.items():
            if cstname in sdfg.constants:
                if cstval != sdfg.constants[cstname]:
                    warnings.warn('Constant value mismatch for "%s" while '
                                  'inlining SDFG. Inner = %s != %s = outer' %
                                  (cstname, cstval, sdfg.constants[cstname]))
                sdfg.add_constant(cstname, cstval)

        # Find original source/destination edges (there is only one edge per
        # connector, according to match)
        inputs: Dict[str, MultiConnectorEdge] = {}
        outputs: Dict[str, MultiConnectorEdge] = {}
        input_set: Dict[str, str] = {}
        output_set: Dict[str, str] = {}
        for e in state.in_edges(nsdfg_node):
            inputs[e.dst_conn] = e
            input_set[e.data.data] = e.dst_conn
        for e in state.out_edges(nsdfg_node):
            outputs[e.src_conn] = e
            output_set[e.data.data] = e.src_conn

        # All transients become transients of the parent (if data already
        # exists, find new name)
        # Mapping from nested transient name to top-level name
        transients: Dict[str, str] = {}
        for node in nstate.nodes():
            if isinstance(node, nodes.AccessNode):
                datadesc = nsdfg.arrays[node.data]
                if node.data not in transients and datadesc.transient:
                    name = sdfg.add_datadesc('%s_%s' %
                                             (nsdfg.label, node.data),
                    transients[node.data] = name

        # All transients of edges between code nodes are also added to parent
        for edge in nstate.edges():
            if (isinstance(edge.src, nodes.CodeNode)
                    and isinstance(edge.dst, nodes.CodeNode)):
                datadesc = nsdfg.arrays[edge.data.data]
                if edge.data.data not in transients and datadesc.transient:
                    name = sdfg.add_datadesc('%s_%s' %
                                             (nsdfg.label, edge.data.data),
                    transients[edge.data.data] = name

        # Collect nodes to add to top-level graph
        new_incoming_edges: Dict[nodes.Node, MultiConnectorEdge] = {}
        new_outgoing_edges: Dict[nodes.Node, MultiConnectorEdge] = {}

        source_accesses = set()
        sink_accesses = set()
        for node in nstate.source_nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.data not in transients):
                new_incoming_edges[node] = inputs[node.data]
        for node in nstate.sink_nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.data not in transients):
                new_outgoing_edges[node] = outputs[node.data]

        # Add nested SDFG into top-level SDFG

        # Add nested nodes into original state
        subgraph = SubgraphView(nstate, [
            n for n in nstate.nodes()
            if n not in (source_accesses | sink_accesses)
        for edge in subgraph.edges():
            state.add_edge(edge.src, edge.src_conn, edge.dst, edge.dst_conn,

        # Replace data on inlined SDFG nodes/edges

        # Replace symbols using invocation symbol mapping
        # Two-step replacement (N -> __dacesym_N --> map[N]) to avoid clashes
        for symname, symvalue in nsdfg_node.symbol_mapping.items():
            if str(symname) != str(symvalue):
                nsdfg.replace(symname, '__dacesym_' + symname)
        for symname, symvalue in nsdfg_node.symbol_mapping.items():
            if str(symname) != str(symvalue):
                nsdfg.replace('__dacesym_' + symname, symvalue)

        # Replace data names with their top-level counterparts
        repldict = {}
            k: v.data.data
            for k, v in itertools.chain(inputs.items(), outputs.items())
        for node in subgraph.nodes():
            if isinstance(node, nodes.AccessNode) and node.data in repldict:
                node.data = repldict[node.data]
        for edge in subgraph.edges():
            if edge.data.data in repldict:
                edge.data.data = repldict[edge.data.data]

        # Reconnect inlined SDFG

        # If a source/sink node is one of the inputs/outputs, reconnect it,
        # replacing memlets in outgoing/incoming paths
        modified_edges = set()
        modified_edges |= self._modify_memlet_path(new_incoming_edges, nstate,
                                                   state, True)
        modified_edges |= self._modify_memlet_path(new_outgoing_edges, nstate,
                                                   state, False)

        # Modify all other internal edges pertaining to input/output nodes
        for node in subgraph.nodes():
            if isinstance(node, nodes.AccessNode):
                if node.data in input_set or node.data in output_set:
                    if node.data in input_set:
                        outer_edge = inputs[input_set[node.data]]
                        outer_edge = outputs[output_set[node.data]]

                    for edge in state.all_edges(node):
                        if (edge not in modified_edges
                                and edge.data.data == node.data):
                            for e in state.memlet_tree(edge):
                                if e.data.data == node.data:
                                    e._data = helpers.unsqueeze_memlet(
                                        e.data, outer_edge.data)

        # If source/sink node is not connected to a source/destination access
        # node, and the nested SDFG is in a scope, connect to scope with empty
        # memlets
        if nsdfg_scope_entry is not None:
            for node in subgraph.nodes():
                if state.in_degree(node) == 0:
                    state.add_edge(nsdfg_scope_entry, None, node, None,
                if state.out_degree(node) == 0:
                    state.add_edge(node, None, nsdfg_scope_exit, None,

        # Replace nested SDFG parents with new SDFG
        for node in nstate.nodes():
            if isinstance(node, nodes.NestedSDFG):
                node.sdfg.parent = state
                node.sdfg.parent_sdfg = sdfg
                node.sdfg.parent_nsdfg_node = node

        # Remove all unused external inputs/output memlet paths, as well as
        # resulting isolated nodes
        removed_in_edges = self._remove_edge_path(state,
                                                  set(inputs.keys()) -
        removed_out_edges = self._remove_edge_path(state,
                                                   set(outputs.keys()) -

        # Re-add in/out edges to first/last nodes in subgraph
        order = [
            x for x in nx.topological_sort(nstate._nx)
            if isinstance(x, nodes.AccessNode)
        for edge in removed_in_edges:
            # Find first access node that refers to this edge
            node = next(n for n in order if n.data == edge.data.data)
            state.add_edge(edge.src, edge.src_conn, node, edge.dst_conn,
        for edge in removed_out_edges:
            # Find last access node that refers to this edge
            node = next(n for n in reversed(order) if n.data == edge.data.data)
            state.add_edge(node, edge.src_conn, edge.dst, edge.dst_conn,

        # Remove nested SDFG node
Пример #8
    def apply(self, sdfg: SDFG):
        """ Applies the transformation on the subgraph of states selected by
            ``self.subgraph_view(sdfg)``: a separate kernel SDFG is built from
            the edges of that subgraph, and a "launch" state is created in the
            outer SDFG to host it.

            NOTE(review): this listing has lost lines -- several calls are
            left unclosed and some branches have no body. The affected spots
            are marked below; the missing arguments and statements cannot be
            determined from what is visible here.

            :param sdfg: The SDFG to operate on (modified in-place).
        """
        subgraph = self.subgraph_view(sdfg)

        entry_states_in, entry_states_out = self.get_entry_states(
            sdfg, subgraph)
        _, exit_states_out = self.get_exit_states(sdfg, subgraph)

        # Only one element of each collection is used. The "out" states are
        # optional and become None when their collection is empty.
        entry_state_in = entry_states_in.pop()
        entry_state_out = entry_states_out.pop() \
            if len(entry_states_out) > 0 else None
        exit_state_out = exit_states_out.pop() \
            if len(exit_states_out) > 0 else None

        launch_state = None
        entry_guard_state = None
        exit_guard_state = None

        # generate entry guard state if needed
        # (only when the edge entering the subgraph carries assignments)
        if self.include_in_assignment and entry_state_out is not None:
            entry_edge = sdfg.edges_between(entry_state_out, entry_state_in)[0]
            if len(entry_edge.data.assignments) > 0:
                # NOTE(review): the label/format-string line of this call is
                # missing from the listing.
                entry_guard_state = sdfg.add_state(
                        self.kernel_prefix +
                        '_' if self.kernel_prefix != '' else ''))
                # NOTE(review): two add_edge calls appear to have been merged
                # here; the end of the first and the start of the second are
                # missing. The visible part re-attaches the original
                # assignments to an edge with no condition.
                sdfg.add_edge(entry_state_out, entry_guard_state,
                    entry_guard_state, entry_state_in,
                    InterstateEdge(None, entry_edge.data.assignments))

                # Update SubgraphView
                # NOTE(review): as listed, the node list is rebuilt unchanged;
                # a line adding the new guard state may be missing -- confirm.
                new_node_list = subgraph.nodes()
                subgraph = SubgraphView(sdfg, new_node_list)

                # NOTE(review): the first argument(s) of add_state_before are
                # missing from the listing.
                launch_state = sdfg.add_state_before(
                        self.kernel_prefix +
                        '_' if self.kernel_prefix != '' else ''))

        # generate exit guard state
        if exit_state_out is not None:
            # NOTE(review): the first argument(s) of add_state_before are
            # missing from the listing.
            exit_guard_state = sdfg.add_state_before(
                    self.kernel_prefix +
                    '_' if self.kernel_prefix != '' else ''))

            # Update SubgraphView
            # NOTE(review): same as above -- the node list is rebuilt
            # unchanged in this listing.
            new_node_list = subgraph.nodes()
            subgraph = SubgraphView(sdfg, new_node_list)

            # Create the launch state here only if the entry-guard branch did
            # not already create one.
            if launch_state is None:
                # NOTE(review): leading argument(s) missing from the listing.
                launch_state = sdfg.add_state_before(
                        self.kernel_prefix +
                        '_' if self.kernel_prefix != '' else ''))

        # If the launch state doesn't exist at this point then there is no other
        # states outside of the kernel, so create a stand alone launch state
        if launch_state is None:
            # NOTE(review): entry_state_in comes from pop() above; confirm
            # that it can actually be None when this assertion runs.
            assert (entry_state_in is None and exit_state_out is None)
            # The label is "<prefix>_kernel_launch" when a prefix is set and
            # "kernel_launch" otherwise (the conditional expression covers the
            # whole "prefix + '_'" concatenation).
            launch_state = sdfg.add_state(label='{}kernel_launch'.format(
                self.kernel_prefix + '_' if self.kernel_prefix != '' else ''))

        # Create the SDFG for the kernel and fill it with the states and edges
        # of the subgraph; per the comments below it is nested into the launch
        # state at the end
        kernel_sdfg = SDFG(
            '{}kernel'.format(self.kernel_prefix +
                              '_' if self.kernel_prefix != '' else ''))

        # The same edge payload objects are reused in the kernel SDFG.
        edges = subgraph.edges()
        for edge in edges:
            kernel_sdfg.add_edge(edge.src, edge.dst, edge.data)

        # Setting entry node in nested SDFG if no entry guard was created
        if entry_guard_state is None:
            kernel_sdfg.start_state = kernel_sdfg.node_id(entry_state_in)

        # Re-parent the moved states to the kernel SDFG.
        for state in subgraph:
            state.parent = kernel_sdfg

        # remove the now nested nodes from the outer sdfg and make sure the
        # launch state is properly connected to remaining states
        # NOTE(review): no removal statement is visible in this listing.

        # Add unconditional edges around the launch state only where no edge
        # exists yet.
        if entry_state_out is not None \
                and len(sdfg.edges_between(entry_state_out, launch_state)) == 0:
            sdfg.add_edge(entry_state_out, launch_state, InterstateEdge())

        if exit_state_out is not None \
                and len(sdfg.edges_between(launch_state, exit_state_out)) == 0:
            sdfg.add_edge(launch_state, exit_state_out, InterstateEdge())

        # Handle data for kernel
        # (names of all data referenced by access nodes in the kernel states)
        kernel_data = set(node.data for state in kernel_sdfg
                          for node in state.nodes()
                          if isinstance(node, nodes.AccessNode))

        # move Streams and Register data into the nested SDFG
        # normal data will be added as kernel argument
        # NOTE(review): an "else:" (and probably the statement that fills
        # kernel_args) seems to be missing between the "del" and the deepcopy
        # below; as listed, the deepcopy would read an entry that was just
        # deleted, and kernel_args is never appended to.
        kernel_args = []
        for data in kernel_data:
            if (isinstance(sdfg.arrays[data], dace.data.Stream) or
                (isinstance(sdfg.arrays[data], dace.data.Array)
                 and sdfg.arrays[data].storage == StorageType.Register)):
                kernel_sdfg.add_datadesc(data, sdfg.arrays[data])
                del sdfg.arrays[data]
                copy_desc = copy.deepcopy(sdfg.arrays[data])
                copy_desc.transient = False
                copy_desc.storage = StorageType.Default
                kernel_sdfg.add_datadesc(data, copy_desc)

        # read only data will be passed as input, writeable data will be passed
        # as 'output' otherwise kernel cannot write to data
        kernel_args_read = set()
        kernel_args_write = set()
        for data in kernel_args:
            # One boolean per access node of this data across all kernel
            # states: True when that node is marked read-only.
            # NOTE(review): the closing bracket of this list and the bodies of
            # the read-only / writable branches are missing from the listing.
            data_accesses_read_only = [
                node.access == dtypes.AccessType.ReadOnly
                for state in kernel_sdfg for node in state
                if isinstance(node, nodes.AccessNode) and node.data == data
            if all(data_accesses_read_only):

        # Kernel SDFG is complete at this point
        # NOTE(review): the body of this "if" is missing from the listing.
        if self.validate:

        # Filling launch state with nested SDFG, map and access nodes
        # NOTE(review): most arguments of add_map and all arguments of
        # add_nested_sdfg are missing from the listing.
        map_entry, map_exit = launch_state.add_map(
                self.kernel_prefix + '_' if self.kernel_prefix != '' else ''),

        nested_sdfg = launch_state.add_nested_sdfg(

        # Create and connect read only data access nodes
        # NOTE(review): the call that connects read_node is truncated; only
        # the tail of its memlet argument is visible.
        for arg in kernel_args_read:
            read_node = launch_state.add_read(arg)
                                             arg, sdfg.arrays[arg]))

        # Create and connect writable data access nodes
        # NOTE(review): the call that connects write_node is truncated in the
        # same way.
        for arg in kernel_args_write:
            write_node = launch_state.add_write(arg)
                                             arg, sdfg.arrays[arg]))

        # Transformation is done
        # NOTE(review): the body of this final "if" is missing from the
        # listing.
        if self.validate:
Пример #9
def nest_state_subgraph(sdfg: SDFG,
                        state: SDFGState,
                        subgraph: SubgraphView,
                        name: Optional[str] = None,
                        full_data: bool = False) -> nodes.NestedSDFG:
    """ Turns a state subgraph into a nested SDFG. Operates in-place.

        The subgraph nodes are moved into a single state of a new SDFG, which
        is inserted into ``state`` as a nested SDFG node. Edges crossing the
        subgraph boundary are reconnected to ``__in_*`` / ``__out_*``
        connectors of that node. Edges carrying an empty memlet are not
        reconnected.

        :param sdfg: The SDFG containing the state subgraph.
        :param state: The state containing the subgraph.
        :param subgraph: Subgraph to nest.
        :param name: An optional name for the nested SDFG.
        :param full_data: If True, nests entire input/output data.
        :return: The nested SDFG node.
        :raise KeyError: Some or all nodes in the subgraph are not located in
                         this state, or the state does not belong to the given
                         SDFG.
        :raise ValueError: The subgraph is contained in more than one scope.
    """
    if state.parent != sdfg:
        raise KeyError('State does not belong to given SDFG')
    if subgraph.graph != state:
        raise KeyError('Subgraph does not belong to given state')

    # Query the node list once; membership tests below use the set
    subgraph_nodes = subgraph.nodes()
    subgraph_node_set = set(subgraph_nodes)

    # Find the top-level scope
    scope_tree = state.scope_tree()
    scope_dict = state.scope_dict()
    scope_dict_children = state.scope_dict(True)
    top_scopenode = -1  # Initialized to -1 since "None" already means top-level

    for node in subgraph_nodes:
        if node not in scope_dict:
            raise KeyError('Node not found in state')

        # If scope entry/exit, ensure entire scope is in subgraph
        if isinstance(node, nodes.EntryNode):
            scope_nodes = scope_dict_children[node]
            if any(n not in subgraph_node_set for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (entry)')
        elif isinstance(node, nodes.ExitNode):
            scope_entry = state.entry_node(node)
            scope_nodes = scope_dict_children[scope_entry] + [scope_entry]
            if any(n not in subgraph_node_set for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (exit)')

        scope_node = scope_dict[node]
        if scope_node not in subgraph_node_set:
            if top_scopenode != -1 and top_scopenode != scope_node:
                raise ValueError(
                    'Subgraph is contained in more than one scope')
            top_scopenode = scope_node

    scope = scope_tree[top_scopenode]

    # Collect inputs and outputs of the nested SDFG. Edges with an empty
    # memlet are dropped here, once, so that the edge lists stay aligned with
    # the connector-name lists that are zipped against them further below.
    inputs: List[MultiConnectorEdge] = [
        e for node in subgraph.source_nodes() for e in state.in_edges(node)
        if e.data.data is not None
    ]
    outputs: List[MultiConnectorEdge] = [
        e for node in subgraph.sink_nodes() for e in state.out_edges(node)
        if e.data.data is not None
    ]

    # Collect transients not used outside of subgraph (will be removed of
    # top-level graph)
    data_in_subgraph = set(n.data for n in subgraph_nodes
                           if isinstance(n, nodes.AccessNode))
    # Find other occurrences in SDFG
    other_nodes = set(
        n.data for s in sdfg.nodes() for n in s.nodes()
        if isinstance(n, nodes.AccessNode) and n not in subgraph_node_set)
    subgraph_transients = set(
        data for data in data_in_subgraph
        if sdfg.arrays[data].transient and data not in other_nodes)

    # All transients of edges between code nodes are also added to nested graph
    for edge in subgraph.edges():
        if (isinstance(edge.src, nodes.CodeNode)
                and isinstance(edge.dst, nodes.CodeNode)
                and edge.data.data is not None):  # Empty memlet: no data
            subgraph_transients.add(edge.data.data)

    # Collect data used in access nodes within subgraph (will be referenced in
    # full upon nesting)
    input_arrays = set()
    output_arrays = set()
    for node in subgraph_nodes:
        if (isinstance(node, nodes.AccessNode)
                and node.data not in subgraph_transients):
            if state.out_degree(node) > 0:
                input_arrays.add(node.data)
            if state.in_degree(node) > 0:
                output_arrays.add(node.data)

    # Create the nested SDFG
    nsdfg = SDFG(name or 'nested_' + state.label)

    # Transients are added to the nested graph as-is
    for transient in subgraph_transients:
        nsdfg.add_datadesc(transient, sdfg.arrays[transient])

    # Input/output data that are not source/sink nodes are added to the graph
    # as non-transients
    for array_name in (input_arrays | output_arrays):
        datadesc = copy.deepcopy(sdfg.arrays[array_name])
        datadesc.transient = False
        nsdfg.add_datadesc(array_name, datadesc)

    def _declare_boundary(prefix, edge):
        # Registers a non-transient copy of the data moved by a boundary edge
        # in the nested SDFG and returns the (possibly uniquified) name
        datadesc = copy.deepcopy(sdfg.arrays[edge.data.data])
        datadesc.transient = False
        if not full_data:
            # Only the accessed subset is visible inside the nested SDFG
            datadesc.shape = edge.data.subset.size()
        return nsdfg.add_datadesc(prefix + edge.data.data,
                                  datadesc,
                                  find_new_name=True)

    # Connected source/sink nodes outside subgraph become global data
    # descriptors in nested SDFG
    input_names = [_declare_boundary('__in_', e) for e in inputs]
    output_names = [_declare_boundary('__out_', e) for e in outputs]

    # Add scope symbols to the nested SDFG
    for v in scope.defined_vars:
        if v in sdfg.symbols:
            sym = sdfg.symbols[v]
            nsdfg.add_symbol(v, sym.dtype)

    # Create nested state
    nstate = nsdfg.add_state()

    # Add subgraph nodes and edges to nested state
    nstate.add_nodes_from(subgraph_nodes)
    for e in subgraph.edges():
        nstate.add_edge(e.src, e.src_conn, e.dst, e.dst_conn, e.data)

    # Modify nested SDFG parents in subgraph
    for node in subgraph_nodes:
        if isinstance(node, nodes.NestedSDFG):
            node.sdfg.parent = nstate
            node.sdfg.parent_sdfg = nsdfg

    # Add access nodes and edges as necessary
    edges_to_offset = []
    for conn_name, edge in zip(input_names, inputs):
        node = nstate.add_read(conn_name)
        inner_memlet = copy.deepcopy(edge.data)
        inner_memlet.data = conn_name
        edges_to_offset.append((edge,
                                nstate.add_edge(node, None, edge.dst,
                                                edge.dst_conn, inner_memlet)))
    for conn_name, edge in zip(output_names, outputs):
        node = nstate.add_write(conn_name)
        inner_memlet = copy.deepcopy(edge.data)
        inner_memlet.data = conn_name
        edges_to_offset.append((edge,
                                nstate.add_edge(edge.src, edge.src_conn, node,
                                                None, inner_memlet)))

    # Offset memlet paths inside nested SDFG according to subsets
    for original_edge, new_edge in edges_to_offset:
        for edge in nstate.memlet_tree(new_edge):
            edge.data.data = new_edge.data.data
            if not full_data:
                edge.data.subset.offset(original_edge.data.subset, True)

    # Add nested SDFG node to the input state
    nested_sdfg = state.add_nested_sdfg(nsdfg, None,
                                        set(input_names) | input_arrays,
                                        set(output_names) | output_arrays)

    # Reconnect memlets to nested SDFG
    for conn_name, edge in zip(input_names, inputs):
        if full_data:
            data = Memlet.from_array(edge.data.data,
                                     sdfg.arrays[edge.data.data])
        else:
            data = edge.data
        state.add_edge(edge.src, edge.src_conn, nested_sdfg, conn_name, data)
    for conn_name, edge in zip(output_names, outputs):
        if full_data:
            data = Memlet.from_array(edge.data.data,
                                     sdfg.arrays[edge.data.data])
        else:
            data = edge.data
        state.add_edge(nested_sdfg, conn_name, edge.dst, edge.dst_conn, data)

    # Connect access nodes to internal input/output data as necessary
    scope_entry = scope.entry
    scope_exit = scope.exit
    for array_name in input_arrays:
        node = state.add_read(array_name)
        if scope_entry is not None:
            # Empty memlet keeps the new access node inside the scope
            state.add_nedge(scope_entry, node, EmptyMemlet())
        state.add_edge(node, None, nested_sdfg, array_name,
                       Memlet.from_array(array_name, sdfg.arrays[array_name]))
    for array_name in output_arrays:
        node = state.add_write(array_name)
        if scope_exit is not None:
            state.add_nedge(node, scope_exit, EmptyMemlet())
        state.add_edge(nested_sdfg, array_name, node, None,
                       Memlet.from_array(array_name, sdfg.arrays[array_name]))

    # Remove subgraph nodes from graph
    state.remove_nodes_from(subgraph_nodes)

    # Remove subgraph transients from top-level graph
    for transient in subgraph_transients:
        del sdfg.arrays[transient]

    return nested_sdfg