Beispiel #1
0
def _build_dataflow_graph_recurse(sdfg, state, primitives, modules, superEntry,
                                  super_exit):
    """ Recursively converts a list of AST primitives into dataflow nodes and
        memlet edges inside ``state``.

        For every primitive, an entry/exit node pair is created (map and
        consume primitives get separate scope entry/exit nodes; reduce,
        tasklet and nested-SDFG primitives use the same node for both).
        Incoming edges are then added for ``prim.inputs``, children are
        processed recursively inside the new scope, and outgoing edges are
        added for the primitive's outputs.

        :param sdfg: The SDFG being built. It is used to look up data
                     descriptors and as the parent of nested SDFGs.
        :param state: The SDFG state to which nodes and edges are added.
        :param primitives: List of AST primitive nodes to convert. If empty,
                           a single empty tasklet is injected instead.
        :param modules: Passed to ``ModuleInliner`` when cleaning Python
                        tasklet code.
        :param superEntry: Entry node of the enclosing scope, or a falsy
                           value when at the top level.
        :param super_exit: Exit node of the enclosing scope, or None when at
                           the top level.
        :return: List of ``(exit node, memlet)`` pairs for outputs that the
                 caller (outer recursion level) must continue routing
                 outwards. NOTE: returns None if a ``_ProgramNode`` is
                 encountered in ``primitives``.
        :raise ValueError: If a map/consume primitive has no children.
        :raise SyntaxError: If a non-Python tasklet has no external code.
        :raise TypeError: On an unsupported primitive type.
        :raise RuntimeError: If a top-level output is not under a program or
                             control-flow node, or if a WCR memlet was not
                             handled.
    """
    # Array of pairs (exit node, memlet)
    exit_nodes = []

    if len(primitives) == 0:
        # Inject empty tasklets into empty states
        primitives = [astnodes._EmptyTaskletNode("Empty Tasklet", None)]

    for prim in primitives:
        label = prim.name

        # Expand node to get entry and exit points
        if isinstance(prim, astnodes._MapNode):
            if len(prim.children) == 0:
                raise ValueError("Map node expected to have children")
            mapNode = nd.Map(label,
                             prim.params,
                             prim.range,
                             is_async=prim.is_async)
            # Add connectors for inputs that exist as array nodes
            entry = nd.MapEntry(
                mapNode,
                _get_input_symbols(prim.inputs, prim.range.free_symbols))
            exit = nd.MapExit(mapNode)
        elif isinstance(prim, astnodes._ConsumeNode):
            if len(prim.children) == 0:
                raise ValueError("Consume node expected to have children")
            consumeNode = nd.Consume(label, (prim.params[1], prim.num_pes),
                                     prim.condition)
            entry = nd.ConsumeEntry(consumeNode)
            exit = nd.ConsumeExit(consumeNode)
        elif isinstance(prim, astnodes._ReduceNode):
            # A reduce node serves as both the entry and the exit point
            rednode = nd.Reduce(prim.ast, prim.axes, prim.identity)
            state.add_node(rednode)
            entry = rednode
            exit = rednode
        elif isinstance(prim, astnodes._TaskletNode):
            if isinstance(prim, astnodes._EmptyTaskletNode):
                tasklet = nd.EmptyTasklet(prim.name)
            else:
                # Remove memlets from tasklet AST
                if prim.language == types.Language.Python:
                    clean_code = MemletRemover().visit(prim.ast)
                    clean_code = ModuleInliner(modules).visit(clean_code)
                else:  # Use external code from tasklet definition
                    if prim.extcode is None:
                        raise SyntaxError("Cannot define an intrinsic "
                                          "tasklet without an implementation")
                    clean_code = prim.extcode
                tasklet = nd.Tasklet(
                    prim.name,
                    set(prim.inputs.keys()),
                    set(prim.outputs.keys()),
                    code=clean_code,
                    language=prim.language,
                    code_global=prim.gcode)  # TODO: location=prim.location

            # Need to add the tasklet in case we're in an empty state, where no
            # edge will be drawn to it
            state.add_node(tasklet)
            entry = tasklet
            exit = tasklet

        elif isinstance(prim, astnodes._NestedSDFGNode):
            # Attach the nested SDFG to this state/SDFG before wrapping it
            prim.sdfg.parent = state
            prim.sdfg._parent_sdfg = sdfg
            prim.sdfg.update_sdfg_list([])
            nsdfg = nd.NestedSDFG(prim.name, prim.sdfg,
                                  set(prim.inputs.keys()),
                                  set(prim.outputs.keys()))
            state.add_node(nsdfg)
            entry = nsdfg
            exit = nsdfg

        elif isinstance(prim, astnodes._ProgramNode):
            # NOTE(review): this returns None (not exit_nodes) and stops
            # processing any remaining primitives; a caller that iterates the
            # result (e.g., the recursion below) would fail. Confirm intended.
            return
        elif isinstance(prim, astnodes._ControlFlowNode):
            # Control-flow primitives are skipped here
            continue
        else:
            raise TypeError("Node type not implemented: " +
                            str(prim.__class__))

        # Add incoming edges
        for varname, memlet in prim.inputs.items():
            arr = memlet.dataname
            if (prim.parent is not None
                    and memlet.dataname in prim.parent.transients.keys()):
                # Input is a transient of the parent: read from its access node
                node = input_node_for_array(state, memlet.dataname)

                # Add incoming edge into transient as well
                # FIXME: A bit hacked?
                if arr in prim.parent.inputs:
                    astmem = prim.parent.inputs[arr]
                    _add_astmemlet_edge(sdfg, state, superEntry, None, node,
                                        None, astmem)

                    # Remove local name from incoming edge to parent
                    prim.parent.inputs[arr].local_name = None
            elif superEntry:
                # Inside a scope: data is routed through the scope entry
                node = superEntry
            else:
                # Top level: read directly from the array's access node
                node = input_node_for_array(state, memlet.dataname)

            # Destination connector inference
            # Connected to a tasklet or a nested SDFG
            dst_conn = (memlet.local_name
                        if isinstance(entry, nd.CodeNode) else None)
            # Connected to a scope as part of its range
            # ([9:] strips the 9-character '__DACEIN_' prefix)
            if str(varname).startswith('__DACEIN_'):
                dst_conn = str(varname)[9:]
            # Handle special case of consume input stream
            if (isinstance(entry, nd.ConsumeEntry)
                    and memlet.data == prim.stream):
                dst_conn = 'IN_stream'

            # If a memlet that covers this input already exists, skip
            # generating this one; otherwise replace memlet with ours
            skip_incoming_edge = False
            remove_edge = None
            for e in state.edges_between(node, entry):
                if e.data.data != memlet.dataname or dst_conn != e.dst_conn:
                    continue
                if e.data.subset.covers(memlet.subset):
                    skip_incoming_edge = True
                    break
                elif memlet.subset.covers(e.data.subset):
                    remove_edge = e
                    break
                else:
                    # Neither subset covers the other: widen the existing
                    # edge to the bounding box of both and reuse it
                    print('WARNING: Performing bounding-box union on',
                          memlet.subset, 'and', e.data.subset, '(in)')
                    e.data.subset = sbs.bounding_box_union(
                        e.data.subset, memlet.subset)
                    e.data.num_accesses += memlet.num_accesses
                    skip_incoming_edge = True
                    break

            if remove_edge is not None:
                state.remove_edge(remove_edge)

            if skip_incoming_edge == False:
                _add_astmemlet_edge(sdfg, state, node, None, entry, dst_conn,
                                    memlet)

        # If there are no inputs, generate a dummy edge
        if superEntry and len(prim.inputs) == 0:
            state.add_edge(superEntry, None, entry, None, EmptyMemlet())

        if len(prim.children) > 0:
            # Recurse
            inner_outputs = _build_dataflow_graph_recurse(
                sdfg, state, prim.children, modules, entry, exit)
            # Infer output node for each memlet
            for i, (out_src, mem) in enumerate(inner_outputs):
                # If there is no such array in this primitive's outputs,
                # it's an external array (e.g., a map in a map). In this case,
                # connect to the exit node
                if mem.dataname in prim.outputs:
                    inner_outputs[i] = (out_src, prim.outputs[mem.dataname])
                else:
                    inner_outputs[i] = (out_src, mem)
        else:
            # Leaf primitive: every declared output leaves from its exit node
            inner_outputs = [(exit, mem) for mem in prim.outputs.values()]

        # Add outgoing edges
        for out_src, astmem in inner_outputs:

            data = astmem.data
            dataname = astmem.dataname

            # If WCR is not none, it needs to be handled in the code. Check for
            # this after, as we only expect it for one distinct case
            wcr_was_handled = astmem.wcr is None

            # TODO: This is convoluted. We should find a more readable
            # way of connecting the outgoing edges.

            if super_exit is None:

                # Assert that we're in a top-level node
                if ((not isinstance(prim.parent, astnodes._ProgramNode)) and
                    (not isinstance(prim.parent, astnodes._ControlFlowNode))):
                    raise RuntimeError("Expected to be at the top node")

                # Looks hacky
                src_conn = (astmem.local_name if isinstance(
                    out_src, (nd.Tasklet, nd.NestedSDFG)) else None)

                # Here we just need to connect memlets directly to their
                # respective data nodes
                out_tgt = output_node_for_array(state, astmem.dataname)

                # If a memlet that covers this output already exists, skip
                # generating this one; otherwise replace memlet with ours
                skip_outgoing_edge = False
                remove_edge = None
                for e in state.edges_between(out_src, out_tgt):
                    if e.data.data != astmem.dataname or src_conn != e.src_conn:
                        continue
                    if e.data.subset.covers(astmem.subset):
                        skip_outgoing_edge = True
                        break
                    elif astmem.subset.covers(e.data.subset):
                        remove_edge = e
                        break
                    else:
                        # Neither subset covers the other: widen the existing
                        # edge to the bounding box of both and reuse it
                        print('WARNING: Performing bounding-box union on',
                              astmem.subset, 'and', e.data.subset, '(out)')
                        e.data.subset = sbs.bounding_box_union(
                            e.data.subset, astmem.subset)
                        e.data.num_accesses += astmem.num_accesses
                        skip_outgoing_edge = True
                        break

                if skip_outgoing_edge == True:
                    continue
                if remove_edge is not None:
                    state.remove_edge(remove_edge)

                _add_astmemlet_edge(sdfg,
                                    state,
                                    out_src,
                                    src_conn,
                                    out_tgt,
                                    None,
                                    astmem,
                                    wcr=astmem.wcr,
                                    wcr_identity=astmem.wcr_identity)
                wcr_was_handled = (True if astmem.wcr is not None else
                                   wcr_was_handled)

                # If the program defines another output, connect it too.
                # This refers to the case where we have streams, which
                # must define an input and output, and sometimes this output
                # is defined in pdp.outputs
                if (isinstance(out_tgt, nd.AccessNode)
                        and isinstance(out_tgt.desc(sdfg), dt.Stream)):
                    try:
                        stream_memlet = next(
                            v for k, v in prim.parent.outputs.items()
                            if k == out_tgt.data)
                        stream_output = output_node_for_array(
                            state, stream_memlet.dataname)
                        _add_astmemlet_edge(sdfg, state, out_tgt, None,
                                            stream_output, None, stream_memlet)
                    except StopIteration:  # Stream output not found, skip
                        pass

            else:  # We're in a nest

                if isinstance(prim, astnodes._ScopeNode):
                    # We're a map or a consume node, that needs to connect our
                    # exit to either an array or to the super_exit
                    if data.transient and dataname in prim.parent.transients:
                        # Connect the exit directly
                        out_tgt = output_node_for_array(state, data.dataname)
                        _add_astmemlet_edge(sdfg, state, out_src, None,
                                            out_tgt, None, astmem)
                    else:
                        # This is either a transient defined in an outer scope,
                        # or an I/O array, so redirect through the exit node
                        _add_astmemlet_edge(sdfg, state, out_src, None,
                                            super_exit, None, astmem)
                        # Instruct outer recursion layer to continue the route
                        exit_nodes.append((super_exit, astmem))
                elif isinstance(
                        prim,
                    (astnodes._TaskletNode, astnodes._NestedSDFGNode)):
                    # We're a tasklet, and need to connect either to the exit
                    # if the array is I/O or is defined in a scope further out,
                    # or directly to the transient if it's defined locally
                    if dataname in prim.parent.transients:
                        # This is a local transient variable, so connect to it
                        # directly
                        out_tgt = output_node_for_array(state, data.dataname)
                        _add_astmemlet_edge(sdfg, state, out_src,
                                            astmem.local_name, out_tgt, None,
                                            astmem)
                    else:
                        # This is an I/O array, or an outer level transient, so
                        # redirect through the exit node
                        _add_astmemlet_edge(sdfg,
                                            state,
                                            out_src,
                                            astmem.local_name,
                                            super_exit,
                                            None,
                                            astmem,
                                            wcr=astmem.wcr,
                                            wcr_identity=astmem.wcr_identity)
                        exit_nodes.append((super_exit, astmem))
                        if astmem.wcr is not None:
                            wcr_was_handled = True  # Sanity check
                else:
                    # NOTE(review): the check is on prim's type, but the
                    # message reports out_src's type
                    raise TypeError("Unexpected node type: {}".format(
                        type(out_src).__name__))

            # WCR is only attached to edges on the paths above; any other
            # primitive (except scopes) carrying WCR is an error
            if not wcr_was_handled and not isinstance(prim,
                                                      astnodes._ScopeNode):
                raise RuntimeError("Detected unhandled WCR for primitive '{}' "
                                   "of type {}. WCR is only expected for "
                                   "tasklets in a map/consume scope.".format(
                                       prim.name,
                                       type(prim).__name__))

    return exit_nodes
Beispiel #2
0
class InlineSDFG(pattern_matching.Transformation):
    """ Inlines a single-state nested SDFG into a top-level SDFG.

        In particular, the steps taken are:

        1. All transient arrays become transients of the parent
        2. If a source/sink node is one of the inputs/outputs:
          a. Remove it
          b. Reconnect through external edges (map/accessnode)
          c. Replace and reoffset memlets with external data descriptor
        3. If other nodes carry the names of inputs/outputs:
          a. Replace data with external data descriptor
          b. Replace and reoffset memlets with external data descriptor
        4. If source/sink node is not connected to a source/destination, and
           the nested SDFG is in a scope, connect to scope with empty memlets
        5. Remove all unused external inputs/output memlet paths
        6. Remove isolated nodes resulting from previous step

    """

    # Pattern node used by ``expressions()`` to match a nested SDFG node
    _nested_sdfg = nodes.NestedSDFG('_', sd.SDFG('_'), set(), set())

    @staticmethod
    def annotates_memlets():
        return True

    @staticmethod
    def expressions():
        # Matches anything
        return [nxutil.node_path_graph(InlineSDFG._nested_sdfg)]

    @staticmethod
    def _find_edge(state: SDFGState, node: nodes.Node,
                   connector: str) -> Optional[MultiConnectorEdge]:
        """ Returns the first edge of ``node`` in ``state`` attached to
            ``connector``. Incoming edges (matched on ``dst_conn``) are
            searched before outgoing edges (matched on ``src_conn``).
            :raise NameError: If no edge uses the given connector.
        """
        for edge in state.in_edges(node):
            if edge.dst_conn == connector:
                return edge
        for edge in state.out_edges(node):
            if edge.src_conn == connector:
                return edge
        raise NameError('Edge with connector %s not found on node %s' %
                        (connector, node))

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Returns True if the matched nested SDFG can be inlined. It must:
            have exactly one state; have at most one edge per connector;
            (when inside a scope) have no output access node with outgoing
            edges and no two connectors directly connected; and (in strict
            mode) require no reshaping of non-transient arrays.
        """
        nested_sdfg = graph.nodes()[candidate[InlineSDFG._nested_sdfg]]
        if len(nested_sdfg.sdfg.nodes()) != 1:
            return False

        # Ensure every connector has one incoming/outgoing edge
        in_connectors = set()
        out_connectors = set()
        for edge in graph.in_edges(nested_sdfg):
            if edge.dst_conn in in_connectors:
                return False
            in_connectors.add(edge.dst_conn)
        for edge in graph.out_edges(nested_sdfg):
            if edge.src_conn in out_connectors:
                return False
            out_connectors.add(edge.src_conn)

        # Ensure output connectors have no additional outputs (if in a scope),
        # and ensure no two connectors are directly connected to each other
        if graph.entry_node(nested_sdfg) is not None:
            all_connectors = in_connectors | out_connectors
            nstate = nested_sdfg.sdfg.node(0)
            for node in nstate.nodes():
                if isinstance(node, nodes.AccessNode):
                    if (node.data in out_connectors
                            and nstate.out_degree(node) > 0
                            and (node.data not in in_connectors
                                 or nstate.in_degree(node) > 0)):
                        return False
                    if (node.data in in_connectors
                            and any(e.dst.data in all_connectors
                                    for e in nstate.out_edges(node)
                                    if isinstance(e.dst, nodes.AccessNode))):
                        return False

        # If some reshaping that cannot be inlined / unsqueezed is happening,
        # do not match transformation in strict mode.
        if strict:
            for aname, array in nested_sdfg.sdfg.arrays.items():
                if array.transient:
                    continue
                edge = InlineSDFG._find_edge(graph, nested_sdfg, aname)
                if len(array.shape) > len(edge.data.subset):
                    return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        return graph.label

    def _remove_edge_path(self,
                          state: SDFGState,
                          edge_map: Dict[str, MultiConnectorEdge],
                          unused: Set[str],
                          reverse: bool = False) -> List[MultiConnectorEdge]:
        """ Remove all edges along a path, until memlet tree contains siblings
            that should not be removed. Removes resulting isolated nodes as
            well. Operates in place.
            :param state: The state in which to remove edges.
            :param edge_map: Mapping from identifier to edge, used as a
                             predicate for removal.
            :param unused: Set of edge identifiers to remove.
            :param reverse: If False, removes forward in path, otherwise
                            backward.
            :return: List of edges from removed nodes at the path's end.
        """

        if reverse:
            edge_func = lambda e: state.out_edges(e.src)
            edge_pred = lambda pedge, e: e.src_conn == pedge.src_conn
        else:
            edge_func = lambda e: state.in_edges(e.dst)
            edge_pred = lambda pedge, e: e.dst_conn == pedge.dst_conn

        result = []

        for identifier, edge in edge_map.items():
            if identifier in unused:
                path = state.memlet_path(edge)
                pedge = None
                for pedge in (reversed(path) if reverse else path):
                    # If there are no other edges, it is safe to remove
                    if len([
                            e for e in edge_func(pedge) if edge_pred(pedge, e)
                    ]) == 1:
                        # Remove connectors as well
                        state.remove_edge_and_connectors(pedge)
                    else:
                        break
                else:  # Reached terminus without breaking, remove external node
                    if pedge is not None:
                        node = pedge.src if reverse else pedge.dst

                        # Keep track of edges on the other end of these nodes,
                        # they will be used to reconnect to first/last
                        # occurrence of access nodes in the inlined subgraph.
                        if reverse:
                            result.extend(state.in_edges(node))
                        else:
                            result.extend(state.out_edges(node))

                        state.remove_node(node)

        return result

    def apply(self, sdfg):
        """ Inlines the matched nested SDFG node into its parent state,
            following the steps listed in the class docstring, and finally
            removes the nested SDFG node. Operates in place on ``sdfg``.
        """
        state: SDFGState = sdfg.nodes()[self.state_id]
        nsdfg_node = state.nodes()[self.subgraph[InlineSDFG._nested_sdfg]]
        nsdfg: SDFG = nsdfg_node.sdfg
        nstate: SDFGState = nsdfg.nodes()[0]

        nsdfg_scope_entry = state.entry_node(nsdfg_node)
        nsdfg_scope_exit = (state.exit_nodes(nsdfg_scope_entry)[0]
                            if nsdfg_scope_entry is not None else None)

        #######################################################
        # Collect and update top-level SDFG metadata

        # Global/init/exit code
        if nsdfg.global_code:
            sdfg.set_global_code(sdfg.global_code + nsdfg.global_code)
        if nsdfg.init_code:
            sdfg.set_init_code(sdfg.init_code + nsdfg.init_code)
        if nsdfg.exit_code:
            sdfg.set_exit_code(sdfg.exit_code + nsdfg.exit_code)

        # Find original source/destination edges (there is only one edge per
        # connector, according to match)
        # inputs/outputs: connector name -> outer edge
        # input_set/output_set: outer data name -> connector name
        inputs: Dict[str, MultiConnectorEdge] = {}
        outputs: Dict[str, MultiConnectorEdge] = {}
        input_set: Dict[str, str] = {}
        output_set: Dict[str, str] = {}
        for e in state.in_edges(nsdfg_node):
            inputs[e.dst_conn] = e
            input_set[e.data.data] = e.dst_conn
        for e in state.out_edges(nsdfg_node):
            outputs[e.src_conn] = e
            output_set[e.data.data] = e.src_conn

        # All transients become transients of the parent (if data already
        # exists, find new name)
        # Mapping from nested transient name to top-level name
        transients: Dict[str, str] = {}
        for node in nstate.nodes():
            if isinstance(node, nodes.AccessNode):
                datadesc = nsdfg.arrays[node.data]
                if node.data not in transients and datadesc.transient:
                    name = sdfg.add_datadesc('%s_%s' %
                                             (nsdfg.label, node.data),
                                             datadesc,
                                             find_new_name=True)
                    transients[node.data] = name

        # All transients of edges between code nodes are also added to parent
        for edge in nstate.edges():
            if (isinstance(edge.src, nodes.CodeNode)
                    and isinstance(edge.dst, nodes.CodeNode)):
                datadesc = nsdfg.arrays[edge.data.data]
                if edge.data.data not in transients and datadesc.transient:
                    name = sdfg.add_datadesc('%s_%s' %
                                             (nsdfg.label, edge.data.data),
                                             datadesc,
                                             find_new_name=True)
                    transients[edge.data.data] = name

        # Collect nodes to add to top-level graph
        new_incoming_edges: Dict[nodes.Node, MultiConnectorEdge] = {}
        new_outgoing_edges: Dict[nodes.Node, MultiConnectorEdge] = {}

        source_accesses = set()
        sink_accesses = set()
        # Connector names that are reconnected through source/sink access
        # nodes. These are the "used" inputs/outputs, i.e., the keys of
        # ``inputs``/``outputs`` that must not be pruned below.
        used_inputs = set()
        used_outputs = set()
        for node in nstate.source_nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.data not in transients):
                new_incoming_edges[node] = inputs[node.data]
                source_accesses.add(node)
                used_inputs.add(node.data)
        for node in nstate.sink_nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.data not in transients):
                new_outgoing_edges[node] = outputs[node.data]
                sink_accesses.add(node)
                used_outputs.add(node.data)

        #######################################################
        # Add nested SDFG into top-level SDFG

        # Add nested nodes into original state (source/sink access nodes that
        # correspond to connectors are replaced by the outer edges instead)
        subgraph = SubgraphView(nstate, [
            n for n in nstate.nodes()
            if n not in (source_accesses | sink_accesses)
        ])
        state.add_nodes_from(subgraph.nodes())
        for edge in subgraph.edges():
            state.add_edge(edge.src, edge.src_conn, edge.dst, edge.dst_conn,
                           edge.data)

        #######################################################
        # Replace data on inlined SDFG nodes/edges

        # Replace symbols using invocation symbol mapping
        # Two-step replacement (N -> __dacesym_N --> map[N]) to avoid clashes.
        # str() is used so that non-string symbol keys are handled the same
        # way as in the comparison.
        for symname, symvalue in nsdfg_node.symbol_mapping.items():
            if str(symname) != str(symvalue):
                nsdfg.replace(symname, '__dacesym_' + str(symname))
        for symname, symvalue in nsdfg_node.symbol_mapping.items():
            if str(symname) != str(symvalue):
                nsdfg.replace('__dacesym_' + str(symname), symvalue)

        # Replace data names with their top-level counterparts
        repldict = {}
        repldict.update(transients)
        repldict.update({
            k: v.data.data
            for k, v in itertools.chain(inputs.items(), outputs.items())
        })
        for node in subgraph.nodes():
            if isinstance(node, nodes.AccessNode) and node.data in repldict:
                node.data = repldict[node.data]
        for edge in subgraph.edges():
            if edge.data.data in repldict:
                edge.data.data = repldict[edge.data.data]

        #######################################################
        # Reconnect inlined SDFG

        # If a source/sink node is one of the inputs/outputs, reconnect it,
        # replacing memlets in outgoing/incoming paths
        modified_edges = set()
        modified_edges |= self._modify_memlet_path(new_incoming_edges, nstate,
                                                   state, True)
        modified_edges |= self._modify_memlet_path(new_outgoing_edges, nstate,
                                                   state, False)

        # Modify all other internal edges pertaining to input/output nodes
        for node in subgraph.nodes():
            if isinstance(node, nodes.AccessNode):
                if node.data in input_set or node.data in output_set:
                    if node.data in input_set:
                        outer_edge = inputs[input_set[node.data]]
                    else:
                        outer_edge = outputs[output_set[node.data]]

                    for edge in state.all_edges(node):
                        if (edge not in modified_edges
                                and edge.data.data == node.data):
                            for e in state.memlet_tree(edge):
                                if e.data.data == node.data:
                                    e._data = helpers.unsqueeze_memlet(
                                        e.data, outer_edge.data)

        # If source/sink node is not connected to a source/destination access
        # node, and the nested SDFG is in a scope, connect to scope with empty
        # memlets
        if nsdfg_scope_entry is not None:
            for node in subgraph.nodes():
                if state.in_degree(node) == 0:
                    state.add_edge(nsdfg_scope_entry, None, node, None,
                                   EmptyMemlet())
                if state.out_degree(node) == 0:
                    state.add_edge(node, None, nsdfg_scope_exit, None,
                                   EmptyMemlet())

        # Replace nested SDFG parents with new SDFG
        for node in nstate.nodes():
            if isinstance(node, nodes.NestedSDFG):
                node.sdfg.parent = state
                node.sdfg.parent_sdfg = sdfg

        # Remove all unused external inputs/output memlet paths, as well as
        # resulting isolated nodes. ``inputs``/``outputs`` are keyed by
        # connector name, so the used connectors are subtracted by name (not
        # by access-node object, which would never match a string key).
        removed_in_edges = self._remove_edge_path(state,
                                                  inputs,
                                                  set(inputs.keys()) -
                                                  used_inputs,
                                                  reverse=True)
        removed_out_edges = self._remove_edge_path(state,
                                                   outputs,
                                                   set(outputs.keys()) -
                                                   used_outputs,
                                                   reverse=False)

        # Re-add in/out edges to first/last nodes in subgraph
        order = [
            x for x in nx.topological_sort(nstate._nx)
            if isinstance(x, nodes.AccessNode)
        ]
        for edge in removed_in_edges:
            # Find first access node that refers to this edge
            node = next(n for n in order if n.data == edge.data.data)
            state.add_edge(edge.src, edge.src_conn, node, edge.dst_conn,
                           edge.data)
        for edge in removed_out_edges:
            # Find last access node that refers to this edge
            node = next(n for n in reversed(order) if n.data == edge.data.data)
            state.add_edge(node, edge.src_conn, edge.dst, edge.dst_conn,
                           edge.data)

        #######################################################
        # Remove nested SDFG node
        state.remove_node(nsdfg_node)

    def _modify_memlet_path(self,
                            new_edges: Dict[nodes.Node, MultiConnectorEdge],
                            nstate: SDFGState, state: SDFGState,
                            inputs: bool) -> Set[MultiConnectorEdge]:
        """ Modifies memlet paths in an inlined SDFG. Returns set of modified
            edges.

            For each (inner access node, outer edge) pair, every inner edge
            leaving (``inputs=True``) or entering (``inputs=False``) the
            access node is replaced by an edge that connects directly to the
            outer edge's source/destination. Its memlet is unsqueezed against
            the outer memlet, and the same is applied to all memlets in the
            resulting memlet tree below the new edge.
        """
        result = set()
        for node, top_edge in new_edges.items():
            inner_edges = (nstate.out_edges(node)
                           if inputs else nstate.in_edges(node))
            for inner_edge in inner_edges:
                new_memlet = helpers.unsqueeze_memlet(inner_edge.data,
                                                      top_edge.data)
                if inputs:
                    new_edge = state.add_edge(top_edge.src, top_edge.src_conn,
                                              inner_edge.dst,
                                              inner_edge.dst_conn, new_memlet)
                    mtree = state.memlet_tree(new_edge)
                else:
                    new_edge = state.add_edge(inner_edge.src,
                                              inner_edge.src_conn,
                                              top_edge.dst, top_edge.dst_conn,
                                              new_memlet)
                    mtree = state.memlet_tree(new_edge)

                # Modify all memlets going forward/backward
                # (defined and called within this iteration, so binding
                # ``top_edge`` by closure is safe)
                def traverse(mtree_node):
                    result.add(mtree_node.edge)
                    mtree_node.edge._data = helpers.unsqueeze_memlet(
                        mtree_node.edge.data, top_edge.data)
                    for child in mtree_node.children:
                        traverse(child)

                for child in mtree.children:
                    traverse(child)

        return result
Beispiel #3
0
class CopyToDevice(pattern_matching.Transformation):
    """ Implements the copy-to-device transformation, which copies a nested
        SDFG and its dependencies to a given device.

        The transformation changes all data storage types of a nested SDFG to
        the given `storage` property, and creates new arrays and copies around
        the nested SDFG to that storage.
    """

    _nested_sdfg = nodes.NestedSDFG("", graph.OrderedDiGraph(), set(), set())

    storage = properties.Property(
        dtype=dtypes.StorageType,
        desc="Nested SDFG storage",
        choices=dtypes.StorageType,
        from_string=lambda x: dtypes.StorageType[x],
        default=dtypes.StorageType.Default)

    @staticmethod
    def annotates_memlets():
        return True

    @staticmethod
    def expressions():
        return [nxutil.node_path_graph(CopyToDevice._nested_sdfg)]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        return True

    @staticmethod
    def match_to_str(graph, candidate):
        nested_sdfg = graph.nodes()[candidate[CopyToDevice._nested_sdfg]]
        return nested_sdfg.label

    @staticmethod
    def _add_device_container(sdfg, name, memdata, memlet, storage):
        """ Registers a transient container ``name`` in ``sdfg`` that mirrors
            ``memdata`` with the given storage type.

            Arrays are sized by the over-approximated bounding box of
            ``memlet``; scalars keep their data type.

            :raises NotImplementedError: If ``memdata`` is neither an array
                                         nor a scalar.
        """
        if isinstance(memdata, data.Array):
            sdfg.add_array(name,
                           memdata.dtype, [
                               symbolic.overapproximate(r)
                               for r in memlet.bounding_box_size()
                           ],
                           transient=True,
                           storage=storage)
        elif isinstance(memdata, data.Scalar):
            sdfg.add_scalar(name,
                            memdata.dtype,
                            transient=True,
                            storage=storage)
        else:
            raise NotImplementedError

    @staticmethod
    def _rebase_subset(target, memlet):
        """ Shifts ``target.subset`` so that every dimension of
            ``memlet.subset`` starts at zero, i.e., it addresses the
            bounding-box-sized device container instead of the original data.
        """
        for ind, r in enumerate(memlet.subset):
            start = r[0]
            rng = memlet.subset[ind]
            if isinstance(rng, tuple):
                target.subset[ind] = (rng[0] - start, rng[1] - start, rng[2])
            else:
                target.subset[ind] -= start

    def apply(self, sdfg):
        state = sdfg.nodes()[self.state_id]
        nested_sdfg = state.nodes()[self.subgraph[CopyToDevice._nested_sdfg]]
        storage = self.storage

        # Inputs: original source -> device container -> nested SDFG.
        # Snapshot the edge list, since edges are removed/added in the loop.
        for edge in list(state.in_edges(nested_sdfg)):
            src, src_conn, dst, dst_conn, memlet = edge
            dataname = memlet.data
            device_name = 'device_' + dataname + '_in'
            CopyToDevice._add_device_container(sdfg, device_name,
                                               sdfg.arrays[dataname], memlet,
                                               storage)

            data_node = nodes.AccessNode(device_name)

            # Copy the original subset into the device container, then feed
            # the nested SDFG from the zero-based device container.
            to_data_mm = dcpy(memlet)
            from_data_mm = dcpy(memlet)
            from_data_mm.data = device_name
            CopyToDevice._rebase_subset(from_data_mm, memlet)

            state.remove_edge(edge)
            state.add_edge(src, src_conn, data_node, None, to_data_mm)
            state.add_edge(data_node, None, dst, dst_conn, from_data_mm)

        # Outputs: nested SDFG -> device container -> original destination.
        for edge in list(state.out_edges(nested_sdfg)):
            src, src_conn, dst, dst_conn, memlet = edge
            dataname = memlet.data
            device_name = 'device_' + dataname + '_out'
            # Register the container in the SDFG (not just a detached data
            # descriptor) so that the access node below refers to real data.
            CopyToDevice._add_device_container(sdfg, device_name,
                                               sdfg.arrays[dataname], memlet,
                                               storage)

            data_node = nodes.AccessNode(device_name)

            # The nested SDFG writes into the zero-based device container,
            # which is then copied back to the original subset.
            to_data_mm = dcpy(memlet)
            from_data_mm = dcpy(memlet)
            to_data_mm.data = device_name
            CopyToDevice._rebase_subset(to_data_mm, memlet)

            state.remove_edge(edge)
            state.add_edge(src, src_conn, data_node, None, to_data_mm)
            state.add_edge(data_node, None, dst, dst_conn, from_data_mm)

        # Change storage for all data inside nested SDFG to device.
        change_storage(nested_sdfg.sdfg, storage)
Beispiel #4
0
class MapFission(pattern_matching.Transformation):
    """ Implements the MapFission transformation.
        Map fission refers to subsuming a map scope into its internal subgraph,
        essentially replicating the map into maps in all of its internal
        components. This also extends the dimensions of "border" transient
        arrays (i.e., those between the maps), in order to retain program
        semantics after fission.

        There are two cases that match map fission:
        1. A map with an arbitrary subgraph with more than one computational
           (i.e., non-access) node. The use of arrays connecting the
           computational nodes must be limited to the subgraph, and non
           transient arrays may not be used as "border" arrays.
        2. A map with one internal node that is a nested SDFG, in which
           each state matches the conditions of case (1).

        If a map has nested SDFGs in its subgraph, they are not considered in
        the case (1) above, and MapFission must be invoked again on the maps
        with the nested SDFGs in question.
    """
    # Pattern placeholder nodes: an entry node (used as a map entry below),
    # optionally followed directly by a nested SDFG node.
    _map_entry = nodes.EntryNode()
    _nested_sdfg = nodes.NestedSDFG("", OrderedDiGraph(), set(), set())

    @staticmethod
    def annotates_memlets():
        return False

    @staticmethod
    def expressions():
        # Expression 0: the entry node alone (case 1, map with a subgraph).
        # Expression 1: entry node -> nested SDFG node (case 2).
        return [
            nxutil.node_path_graph(MapFission._map_entry, ),
            nxutil.node_path_graph(
                MapFission._map_entry,
                MapFission._nested_sdfg,
            )
        ]

    @staticmethod
    def _components(
            subgraph: sd.SubgraphView) -> List[Tuple[nodes.Node, nodes.Node]]:
        """
        Returns a list of (input node, output node) tuples, one for each
        computational (code or scope-entry) node in the outermost scope of
        this subgraph. An entry node is paired with its matching exit node;
        a code node is paired with itself. Access nodes are not components.
        """
        graph = (subgraph
                 if isinstance(subgraph, sd.SDFGState) else subgraph.graph)
        sdict = subgraph.scope_dict(node_to_children=True)
        ns = [(n,
               graph.exit_nodes(n)[0]) if isinstance(n, nodes.EntryNode) else
              (n, n) for n in sdict[None]
              if isinstance(n, (nodes.CodeNode, nodes.EntryNode))]

        return ns

    @staticmethod
    def _border_arrays(sdfg, parent, subgraph):
        """ Returns the set of data names of access nodes in the outermost
            scope of the fission subgraph. If ``parent`` is an SDFG state,
            only data that is transient in ``sdfg`` is included. """
        nested = isinstance(parent, sd.SDFGState)
        sdict = subgraph.scope_dict(node_to_children=True)
        subset = sd.SubgraphView(parent, sdict[None])
        if nested:
            return set(node.data for node in subset.nodes()
                       if isinstance(node, nodes.AccessNode)
                       and sdfg.arrays[node.data].transient)
        else:
            return set(node.data for node in subset.nodes()
                       if isinstance(node, nodes.AccessNode))

    @staticmethod
    def _internal_border_arrays(total_components, subgraphs):
        """ Returns the set of border arrays that appear between computational
            components (i.e., without sources and sinks). """
        inputs = set()
        outputs = set()

        for components, subgraph in zip(total_components, subgraphs):
            for component_in, component_out in components:
                for e in subgraph.in_edges(component_in):
                    if isinstance(e.src, nodes.AccessNode):
                        inputs.add(e.src.data)
                for e in subgraph.out_edges(component_out):
                    if isinstance(e.dst, nodes.AccessNode):
                        outputs.add(e.dst.data)

        # Internal arrays are both written by some component and read by
        # some (possibly other) component.
        return inputs & outputs

    @staticmethod
    def _outside_map(node, scope_dict, entry_nodes):
        """ Returns True iff node is not in any of the scopes spanned by
            entry_nodes. """
        while scope_dict[node] is not None:
            if scope_dict[node] in entry_nodes:
                return False
            node = scope_dict[node]
        return True

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks the matching conditions described in the class docstring:
            a statically-ranged map whose subgraph (expression 0) or whose
            single nested SDFG's states (expression 1) consist of multiple
            components connected only by transient border arrays that are
            both produced and consumed inside the map. """
        map_node = graph.node(candidate[MapFission._map_entry])
        nsdfg_node = None

        # If the map is dynamic-ranged, the resulting border arrays would be
        # dynamically sized
        if sd.has_dynamic_map_inputs(graph, map_node):
            return False

        if expr_index == 0:  # Map with subgraph
            subgraphs = [
                graph.scope_subgraph(map_node,
                                     include_entry=False,
                                     include_exit=False)
            ]
        else:  # Map with nested SDFG
            nsdfg_node = graph.node(candidate[MapFission._nested_sdfg])
            # Make sure there are no other internal nodes in the map
            if len(set(e.dst for e in graph.out_edges(map_node))) > 1:
                return False
            subgraphs = list(nsdfg_node.sdfg.nodes())

        # Test subgraphs
        border_arrays = set()
        total_components = []
        for sg in subgraphs:
            components = MapFission._components(sg)
            snodes = sg.nodes()
            # Test that the subgraphs have more than one computational component
            if expr_index == 0 and len(snodes) > 0 and len(components) <= 1:
                return False

            # Test that the components are connected by transients that are not
            # used anywhere else
            border_arrays |= MapFission._border_arrays(
                nsdfg_node.sdfg if expr_index == 1 else sdfg,
                sg if expr_index == 1 else graph, sg)
            total_components.append(components)

            # In nested SDFGs and subgraphs, ensure none of the border
            # values are non-transients
            for array in border_arrays:
                if expr_index == 0:
                    ndesc = sdfg.arrays[array]
                else:
                    ndesc = nsdfg_node.sdfg.arrays[array]

                if ndesc.transient is False:
                    return False

            # In subgraphs, make sure transients are not used/allocated
            # in other scopes or states
            if expr_index == 0:
                # Find all nodes not in subgraph
                # (data names accessed elsewhere in this state or in any
                # other state of the SDFG)
                not_subgraph = set(
                    n.data for n in graph.nodes()
                    if n not in snodes and isinstance(n, nodes.AccessNode))
                not_subgraph.update(
                    set(n.data for s in sdfg.nodes() if s != graph
                        for n in s.nodes() if isinstance(n, nodes.AccessNode)))

                for _, component_out in components:
                    for e in sg.out_edges(component_out):
                        if isinstance(e.dst, nodes.AccessNode):
                            if e.dst.data in not_subgraph:
                                return False

        # Fail if there are arrays inside the map that are not a direct
        # output of a computational component
        # TODO(later): Support this case? Ambiguous array sizes and memlets
        external_arrays = (
            border_arrays -
            MapFission._internal_border_arrays(total_components, subgraphs))
        if len(external_arrays) > 0:
            return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        map_entry = graph.node(candidate[MapFission._map_entry])
        return map_entry.map.label

    def apply(self, sdfg: sd.SDFG):
        """ Applies map fission: wraps every computational component of the
            map's subgraph (or of each state of the nested SDFG) in its own
            copy of the outer map, prepends the map dimensions to border
            arrays and their memlets, reconnects external edges, and finally
            removes the outer map entry/exit nodes. Outer memlets are expected
            to be corrected afterwards by memlet propagation (see NOTEs). """
        graph: sd.SDFGState = sdfg.nodes()[self.state_id]
        map_entry = graph.node(self.subgraph[MapFission._map_entry])
        map_exit = graph.exit_nodes(map_entry)[0]
        nsdfg_node: Optional[nodes.NestedSDFG] = None

        # Obtain subgraph to perform fission to
        # Each element is (state to modify, subgraph view within that state);
        # ``parent`` is the SDFG that owns the data descriptors to modify.
        if self.expr_index == 0:  # Map with subgraph
            subgraphs = [(graph,
                          graph.scope_subgraph(map_entry,
                                               include_entry=False,
                                               include_exit=False))]
            parent = sdfg
        else:  # Map with nested SDFG
            nsdfg_node = graph.node(self.subgraph[MapFission._nested_sdfg])
            subgraphs = [(state, state) for state in nsdfg_node.sdfg.nodes()]
            parent = nsdfg_node.sdfg
        # Arrays already extended with the map dimensions (an array may be
        # encountered in more than one state)
        modified_arrays = set()

        # Get map information
        outer_map: nodes.Map = map_entry.map
        mapsize = outer_map.range.size()

        # Add new symbols from outer map to nested SDFG
        # (symbols used by the map range or by memlets crossing the map, other
        # than the map parameters, are added to the nested SDFG's symbol
        # mapping if not already present)
        if self.expr_index == 1:
            map_syms = outer_map.range.free_symbols
            for edge in graph.out_edges(map_entry):
                if edge.data.data:
                    map_syms.update(edge.data.subset.free_symbols)
            for edge in graph.in_edges(map_exit):
                if edge.data.data:
                    map_syms.update(edge.data.subset.free_symbols)
            for symname, sym in map_syms.items():
                if symname in outer_map.params:
                    continue
                if symname not in nsdfg_node.symbol_mapping.keys():
                    nsdfg_node.symbol_mapping[symname] = sym

        for state, subgraph in subgraphs:
            components = MapFission._components(subgraph)
            sources = subgraph.source_nodes()
            sinks = subgraph.sink_nodes()

            # Collect external edges
            # (edges through the map entry/exit, or in the nested case, edges
            # touching non-transient access nodes of the nested SDFG)
            if self.expr_index == 0:
                external_edges_entry = list(state.out_edges(map_entry))
                external_edges_exit = list(state.in_edges(map_exit))
            else:
                external_edges_entry = [
                    e for e in subgraph.edges()
                    if (isinstance(e.src, nodes.AccessNode)
                        and not nsdfg_node.sdfg.arrays[e.src.data].transient)
                ]
                external_edges_exit = [
                    e for e in subgraph.edges()
                    if (isinstance(e.dst, nodes.AccessNode)
                        and not nsdfg_node.sdfg.arrays[e.dst.data].transient)
                ]

            # Map external edges to outer memlets
            edge_to_outer = {}
            for edge in external_edges_entry:
                if self.expr_index == 0:
                    # Subgraphs use the corresponding outer map edges
                    path = state.memlet_path(edge)
                    eindex = path.index(edge)
                    edge_to_outer[edge] = path[eindex - 1]
                else:
                    # Nested SDFGs use the internal map edges of the node
                    # (matched by connector name == nested data name)
                    outer_edge = next(e for e in graph.in_edges(nsdfg_node)
                                      if e.dst_conn == edge.src.data)
                    edge_to_outer[edge] = outer_edge

            for edge in external_edges_exit:
                if self.expr_index == 0:
                    path = state.memlet_path(edge)
                    eindex = path.index(edge)
                    edge_to_outer[edge] = path[eindex + 1]
                else:
                    # Nested SDFGs use the internal map edges of the node
                    outer_edge = next(e for e in graph.out_edges(nsdfg_node)
                                      if e.src_conn == edge.dst.data)
                    edge_to_outer[edge] = outer_edge

            # Collect all border arrays and code->code edges
            # (code->code edges are grouped by the data name they carry)
            arrays = MapFission._border_arrays(
                nsdfg_node.sdfg if self.expr_index == 1 else sdfg, state,
                subgraph)
            scalars = defaultdict(list)
            for _, component_out in components:
                for e in subgraph.out_edges(component_out):
                    if isinstance(e.dst, nodes.CodeNode):
                        scalars[e.data.data].append(e)

            # Create new arrays for scalars
            # Each value passed directly between code nodes becomes a new
            # transient of the map's size, indexed by the map parameters, with
            # an access node inserted on every such edge.
            for scalar, edges in scalars.items():
                desc = parent.arrays[scalar]
                name, newdesc = parent.add_temp_transient(
                    mapsize,
                    desc.dtype,
                    desc.storage,
                    toplevel=desc.toplevel,
                    debuginfo=desc.debuginfo,
                    allow_conflicts=desc.allow_conflicts)

                # Add extra nodes in component boundaries
                for edge in edges:
                    anode = state.add_access(name)
                    state.add_edge(
                        edge.src, edge.src_conn, anode, None,
                        mm.Memlet(
                            name, outer_map.range.num_elements(),
                            subsets.Range.from_string(','.join(
                                outer_map.params)), 1))
                    state.add_edge(
                        anode, None, edge.dst, edge.dst_conn,
                        mm.Memlet(
                            name, outer_map.range.num_elements(),
                            subsets.Range.from_string(','.join(
                                outer_map.params)), 1))
                    state.remove_edge(edge)

            # Add extra maps around components
            new_map_entries = []
            for component_in, component_out in components:
                me, mx = state.add_map(outer_map.label + '_fission',
                                       [(p, '0:1') for p in outer_map.params],
                                       outer_map.schedule,
                                       unroll=outer_map.unroll,
                                       debuginfo=outer_map.debuginfo)

                # Add dynamic input connectors
                for conn in map_entry.in_connectors:
                    if not conn.startswith('IN_'):
                        me.add_in_connector(conn)

                # The '0:1' range above is a placeholder; the new map gets a
                # copy of the outer map's actual range.
                me.map.range = dcpy(outer_map.range)
                new_map_entries.append(me)

                # Reconnect edges through new map
                for e in state.in_edges(component_in):
                    state.add_edge(me, None, e.dst, e.dst_conn, dcpy(e.data))
                    # Reconnect inner edges at source directly to external nodes
                    if self.expr_index == 0 and e in external_edges_entry:
                        state.add_edge(edge_to_outer[e].src,
                                       edge_to_outer[e].src_conn, me, None,
                                       dcpy(edge_to_outer[e].data))
                    else:
                        state.add_edge(e.src, e.src_conn, me, None,
                                       dcpy(e.data))
                    state.remove_edge(e)
                # Empty memlet edge in nested SDFGs
                if state.in_degree(component_in) == 0:
                    state.add_edge(me, None, component_in, None,
                                   mm.EmptyMemlet())

                for e in state.out_edges(component_out):
                    state.add_edge(e.src, e.src_conn, mx, None, dcpy(e.data))
                    # Reconnect inner edges at sink directly to external nodes
                    if self.expr_index == 0 and e in external_edges_exit:
                        state.add_edge(mx, None, edge_to_outer[e].dst,
                                       edge_to_outer[e].dst_conn,
                                       dcpy(edge_to_outer[e].data))
                    else:
                        state.add_edge(mx, None, e.dst, e.dst_conn,
                                       dcpy(e.data))
                    state.remove_edge(e)
                # Empty memlet edge in nested SDFGs
                if state.out_degree(component_out) == 0:
                    state.add_edge(component_out, None, mx, None,
                                   mm.EmptyMemlet())
            # Connect other sources/sinks not in components (access nodes)
            # directly to external nodes
            if self.expr_index == 0:
                for node in sources:
                    if isinstance(node, nodes.AccessNode):
                        for edge in state.in_edges(node):
                            outer_edge = edge_to_outer[edge]
                            memlet = dcpy(edge.data)
                            memlet.subset = subsets.Range(
                                outer_map.range.ranges + memlet.subset.ranges)
                            state.add_edge(outer_edge.src, outer_edge.src_conn,
                                           edge.dst, edge.dst_conn, memlet)

                for node in sinks:
                    if isinstance(node, nodes.AccessNode):
                        for edge in state.out_edges(node):
                            outer_edge = edge_to_outer[edge]
                            state.add_edge(edge.src, edge.src_conn,
                                           outer_edge.dst, outer_edge.dst_conn,
                                           dcpy(outer_edge.data))

            # Augment arrays by prepending map dimensions
            # (strides, total size, shape and offset are all extended; each
            # array is modified at most once across all states)
            for array in arrays:
                if array in modified_arrays:
                    continue
                desc = parent.arrays[array]
                for sz in reversed(mapsize):
                    desc.strides = [desc.total_size] + list(desc.strides)
                    desc.total_size = desc.total_size * sz

                desc.shape = mapsize + list(desc.shape)
                desc.offset = [0] * len(mapsize) + list(desc.offset)
                modified_arrays.add(array)

            # Fill scope connectors so that memlets can be tracked below
            state.fill_scope_connectors()

            # Correct connectors and memlets in nested SDFGs to account for
            # missing outside map
            if self.expr_index == 1:
                to_correct = ([(e, e.src) for e in external_edges_entry] +
                              [(e, e.dst) for e in external_edges_exit])
                corrected_nodes = set()
                for edge, node in to_correct:
                    if isinstance(node, nodes.AccessNode):
                        if node in corrected_nodes:
                            continue
                        corrected_nodes.add(node)

                        outer_edge = edge_to_outer[edge]
                        desc = parent.arrays[node.data]

                        # Modify shape of internal array to match outer one
                        outer_desc = sdfg.arrays[outer_edge.data.data]
                        if not isinstance(desc, dt.Scalar):
                            desc.shape = outer_desc.shape
                        if isinstance(desc, dt.Array):
                            desc.strides = outer_desc.strides
                            desc.total_size = outer_desc.total_size

                        # Inside the nested SDFG, offset all memlets to include
                        # the offsets from within the map.
                        # NOTE: Relies on propagation to fix outer memlets
                        for internal_edge in state.all_edges(node):
                            for e in state.memlet_tree(internal_edge):
                                e.data.subset.offset(desc.offset, False)
                                e.data.subset = helpers.unsqueeze_memlet(
                                    e.data, outer_edge.data).subset

                        # Only after offsetting memlets we can modify the
                        # overall offset
                        if isinstance(desc, dt.Array):
                            desc.offset = outer_desc.offset

            # Fill in memlet trees for border transients
            # NOTE: Memlet propagation should run to correct the outer edges
            for node in subgraph.nodes():
                if isinstance(node, nodes.AccessNode) and node.data in arrays:
                    for edge in state.all_edges(node):
                        for e in state.memlet_tree(edge):
                            # Prepend map dimensions to memlet
                            e.data.subset = subsets.Range(
                                [(d, d, 1) for d in outer_map.params] +
                                e.data.subset.ranges)

        # If nested SDFG, reconnect nodes around map and modify memlets
        # Map-entry connectors are matched by name: 'IN_<x>' outside pairs
        # with 'OUT_<x>' inside (and vice versa for the map exit).
        if self.expr_index == 1:
            for edge in graph.in_edges(map_entry):
                if not edge.dst_conn or not edge.dst_conn.startswith('IN_'):
                    continue

                # Modify edge coming into nested SDFG to include entire array
                desc = sdfg.arrays[edge.data.data]
                edge.data.subset = subsets.Range.from_array(desc)
                edge.data.num_accesses = edge.data.subset.num_elements()

                # Find matching edge inside map
                inner_edge = next(
                    e for e in graph.out_edges(map_entry)
                    if e.src_conn and e.src_conn[4:] == edge.dst_conn[3:])
                graph.add_edge(edge.src, edge.src_conn, nsdfg_node,
                               inner_edge.dst_conn, dcpy(edge.data))

            for edge in graph.out_edges(map_exit):
                # Modify edge coming out of nested SDFG to include entire array
                desc = sdfg.arrays[edge.data.data]
                edge.data.subset = subsets.Range.from_array(desc)

                # Find matching edge inside map
                inner_edge = next(e for e in graph.in_edges(map_exit)
                                  if e.dst_conn[3:] == edge.src_conn[4:])
                graph.add_edge(nsdfg_node, inner_edge.src_conn, edge.dst,
                               edge.dst_conn, dcpy(edge.data))

        # Remove outer map
        graph.remove_nodes_from([map_entry, map_exit])
Beispiel #5
0
class InlineSDFG(pattern_matching.Transformation):
    """ Inlines a single-state nested SDFG into a top-level SDFG """

    _nested_sdfg = nodes.NestedSDFG('_', sd.SDFG('_'), set(), set())

    @staticmethod
    def annotates_memlets():
        """ Reports that this transformation writes memlets itself.

            ``apply`` sets offset/unsqueezed memlets on the edges it creates,
            so this always returns True.
        """
        sets_own_memlets = True
        return sets_own_memlets

    @staticmethod
    def expressions():
        """ Returns the list of patterns this transformation matches.

            There is a single pattern, built with ``nxutil.node_path_graph``
            from the placeholder nested SDFG node. Any nested SDFG node
            therefore matches it.
        """
        single_node_pattern = nxutil.node_path_graph(InlineSDFG._nested_sdfg)
        return [single_node_pattern]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether a matched nested SDFG node can be inlined.

            Only nested SDFGs that consist of exactly one state are accepted.
            ``expr_index``, ``sdfg`` and ``strict`` are not consulted.

            :param graph: The graph containing the candidate match.
            :param candidate: Mapping from pattern nodes to node indices in
                              ``graph``.
            :return: True iff the matched nested SDFG has exactly one state.
        """
        matched_node = graph.nodes()[candidate[InlineSDFG._nested_sdfg]]
        return len(matched_node.sdfg.nodes()) == 1

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns a printable description of a match.

            The description is the label of the graph in which the match was
            found. ``candidate`` is not used.
        """
        description = graph.label
        return description

    def _modify_memlet(self, internal_memlet: Memlet, external_memlet: Memlet):
        """ Unsqueezes and offsets a memlet, as per the semantics of nested
            SDFGs.

            A copy of the internal memlet is retargeted to the external
            memlet's data container. If it has fewer dimensions than the
            external memlet, it is unsqueezed at the external dimensions of
            size 1. It is then offset by the external memlet's subset.

            :param internal_memlet: The internal memlet (inside nested SDFG)
                                    before modification.
            :param external_memlet: The external memlet (outside the nested
                                    SDFG) before modification.
            :return: Offset Memlet to set on the resulting graph.
            :raise ValueError: If the internal memlet has more dimensions
                               than the external memlet.
            :raise NotImplementedError: If the external memlet has an
                                        ``other_subset``.
        """
        result = dc(internal_memlet)
        result.data = external_memlet.data

        shape = external_memlet.subset.size()
        if len(internal_memlet.subset) < len(external_memlet.subset):
            # Candidate dimensions to re-insert: external dims of size 1
            ones = [i for i, d in enumerate(shape) if d == 1]

            # Special case: If internal memlet is a range of size 1 with (0,0,1),
            #               ignore it when unsqueezing
            if (len(internal_memlet.subset) == 1
                    and (internal_memlet.subset[0] == (0, 0, 1)
                         or internal_memlet.subset[0] == 0)):
                to_unsqueeze = ones[1:]
            else:
                to_unsqueeze = ones

            result.subset.unsqueeze(to_unsqueeze)
        elif len(internal_memlet.subset) > len(external_memlet.subset):
            raise ValueError(
                'Unexpected extra dimensions in internal memlet '
                'while inlining SDFG.\nExternal memlet: %s\n'
                'Internal memlet: %s' % (external_memlet, internal_memlet))

        result.subset.offset(external_memlet.subset, False)

        # TODO: Offset rest of memlet according to other_subset
        if external_memlet.other_subset is not None:
            raise NotImplementedError(
                'Inlining SDFG with an external memlet that defines '
                'other_subset is not supported.\nExternal memlet: %s' %
                external_memlet)

        return result

    def apply(self, sdfg):
        """ Inlines the matched nested SDFG node into its containing state.

            The nodes of the nested SDFG's (single) state are added to the
            containing state. Internal data containers are registered in the
            top-level SDFG, renamed on name clashes. Edges that touched the
            nested SDFG's connectors are reconnected to the outer
            source/destination nodes, with memlets offset by
            ``_modify_memlet``. Finally the nested SDFG node is removed, along
            with any outer source/destination node left without edges.

            :param sdfg: The top-level SDFG. The state at index
                         ``self.state_id`` contains the matched nested SDFG
                         node.
        """
        graph = sdfg.nodes()[self.state_id]
        nsdfg_node = graph.nodes()[self.subgraph[InlineSDFG._nested_sdfg]]
        nsdfg = nsdfg_node.sdfg

        # Find original source/destination nodes
        # inputs:  dst_conn -> (outer src node, src_conn, outer memlet)
        # outputs: src_conn -> (outer dst node, dst_conn, outer memlet)
        inputs = {}
        outputs = {}
        for e in graph.in_edges(nsdfg_node):
            inputs[e.dst_conn] = (e.src, e.src_conn, e.data)
        for e in graph.out_edges(nsdfg_node):
            outputs[e.src_conn] = (e.dst, e.dst_conn, e.data)

        # Inner data names -> outer data names. Inner containers are matched
        # to connectors by name, so connector "k" maps to the data of the
        # outer memlet attached to it.
        torename = {}
        torename.update({k: v[2].data for k, v in inputs.items()})
        torename.update({k: v[2].data for k, v in outputs.items()})
        # Inner data whose access node is kept in the top-level state because
        # it feeds an EntryNode through a non-'IN_' connector
        entry_connectors = set()

        # Add SDFG nodes to top-level SDFG
        # (can_be_applied guarantees the nested SDFG has exactly one state)
        state = nsdfg.nodes()[0]
        for node in state.nodes():
            # Data access nodes
            if isinstance(node, nodes.AccessNode):
                # External node
                if node.data in inputs or node.data in outputs:
                    # NOTE(review): if an out-edge into an EntryNode has
                    # dst_conn None, dst_conn[0:3] raises TypeError — confirm
                    # that cannot happen here.
                    for _, _, dst, dst_conn, _ in state.out_edges(node):
                        # Custom entry connector case
                        if (isinstance(dst, nodes.EntryNode)
                                and dst_conn[0:3] != 'IN_'):
                            entry_connectors.add(node.data)
                            # The descriptor object is shared with the nested
                            # SDFG; its transient flag is mutated in place.
                            sdfg.arrays[node.data] = nsdfg.arrays[node.data]
                            sdfg.arrays[node.data].transient = True
                            graph.add_node(node)
                            torename.pop(node.data)
                            break
                    # Otherwise the external access node is dropped; its edges
                    # are rewired to the outer nodes below.
                    continue
                # Internal node (e.g., transient)
                if node.data not in torename:
                    name = node.data
                    # Name already exists
                    # -> try "<label>_<data>", then "<label>_<data>_<i>"
                    if name in sdfg.arrays:
                        name = '%s_%s' % (nsdfg.label, node.data)
                        i = 0
                        while name in sdfg.arrays:
                            name = '%s_%s_%d' % (nsdfg.label, node.data, i)
                            i += 1
                    # Add transient
                    sdfg.arrays[name] = nsdfg.arrays[node.data]
                    # Rename all internal uses
                    torename[node.data] = name
            # Set all parents of nested SDFG nodes in the inlined SDFG to their
            # new parent
            elif isinstance(node, nodes.NestedSDFG):
                node.sdfg.parent = graph
                node.sdfg.parent_sdfg = sdfg

            graph.add_node(node)

        # TODO: Confirm that the following is always correct
        # Add Scalars of the nested SDFG to the parent
        # NOTE(review): scalars whose name already exists in the parent are
        # neither copied nor renamed here — confirm inner uses should resolve
        # to the parent's scalar.
        for name, arr in nsdfg.arrays.items():
            if isinstance(arr, dt.Scalar) and name not in sdfg.arrays:
                sdfg.arrays[name] = arr

        # Reconnect edges to their original source
        for e in state.edges():
            if isinstance(e.src, nodes.AccessNode) and e.src.data in inputs:
                cnode, cconn, cmemlet = inputs[e.src.data]
                if e.src.data in entry_connectors:
                    # Access node was kept: outer src -> access node -> inner dst
                    graph.add_edge(cnode, cconn, e.src, None, cmemlet)
                    graph.add_edge(e.src, None, e.dst, e.dst_conn, e.data)
                else:
                    # Connect to source node instead
                    newmemlet = self._modify_memlet(e.data, cmemlet)
                    graph.add_edge(cnode, cconn, e.dst, e.dst_conn, newmemlet)
            elif isinstance(e.dst, nodes.AccessNode) and e.dst.data in outputs:
                cnode, cconn, cmemlet = outputs[e.dst.data]
                newmemlet = self._modify_memlet(e.data, cmemlet)
                if state.out_edges(e.dst):
                    # The inner output access node is also read inside the
                    # nested state. Keep the write into it, then re-source e
                    # at the access node so that the edge added after this
                    # block goes access node -> outer destination. This
                    # mutates e (an edge of the nested state) in place.
                    graph.add_edge(e.src, e.src_conn, e.dst, e.dst_conn,
                                   newmemlet)
                    e._src = e._dst
                    e._src_conn = e._dst_conn
                    # Remove wcr
                    # (from the copy-out memlet and from outer access-node
                    # edges of cnode carrying the same data)
                    newmemlet = dc(newmemlet)
                    newmemlet.wcr = None
                    newmemlet.other_subset = dc(newmemlet.subset)
                    for _, _, dst, _, memlet in graph.out_edges(cnode):
                        if isinstance(dst, nodes.AccessNode
                                      ) and memlet.data == cmemlet.data:
                            memlet.wcr = None
                    # # Remove output node
                    # out_conn = 'OUT_{}'.format(cconn[3:])
                    # for _, conn, dst, _, _ in graph.out_edges(cnode):
                    #     if conn == out_conn:
                    #         graph.remove_node(dst)
                    # # Remove connectors
                    # in_connectors = dc(cnode.in_connectors)
                    # in_connectors.remove(cconn)
                    # cnode.in_connectors = in_connectors
                    # out_connectors = dc(cnode.out_connectors)
                    # out_connectors.remove(out_conn)
                    # cnode.out_connectors = out_connectors
                # else:
                # Connect to destination node instead
                graph.add_edge(e.src, e.src_conn, cnode, cconn, newmemlet)
            elif e.data.data in torename:
                # Internal edge whose memlet refers to renamed data: offset it
                # for external data, otherwise only swap the data name
                if e.data.data in inputs:
                    newmemlet = self._modify_memlet(e.data,
                                                    inputs[e.data.data][2])
                elif e.data.data in outputs:
                    newmemlet = self._modify_memlet(e.data,
                                                    outputs[e.data.data][2])
                else:
                    # Rename data
                    cdata = torename[e.data.data]
                    newmemlet = dc(e.data)
                    newmemlet.data = cdata

                graph.add_edge(e.src, e.src_conn, e.dst, e.dst_conn, newmemlet)
            else:
                # Do nothing
                graph.add_edge(e.src, e.src_conn, e.dst, e.dst_conn, e.data)

        # Rename all access nodes
        # (done after rewiring: the lookups above use the original inner names)
        for node in state.nodes():
            if isinstance(node, nodes.AccessNode) and node.data in torename:
                node.data = torename[node.data]

        # If an empty memlet was connected to the nested SDFG, reconnect
        # all source nodes with empty memlets
        # NOTE(review): state.source_nodes() may include external input access
        # nodes that were deliberately not added to graph above; add_edge may
        # then re-introduce them — confirm.
        if None in inputs:
            cnode, cconn, cmemlet = inputs[None]
            for node in state.source_nodes():
                graph.add_edge(cnode, cconn, node, None, EmptyMemlet())

        # Remove the nested SDFG node
        graph.remove_node(nsdfg_node)

        # Remove input/output nodes from top-level graph if not connected to
        # any internal node
        # NOTE(review): if the same outer node appears more than once in
        # inputs/outputs, it can be removed and then queried again — confirm
        # graph.all_edges/remove_node tolerate an already-removed node.
        for node, _, _ in list(inputs.values()) + list(outputs.values()):
            if len(graph.all_edges(node)) == 0:
                graph.remove_node(node)