コード例 #1
0
ファイル: ger.py プロジェクト: mfkiwl/dace
    def expansion(node, parent_state, parent_sdfg, **kwargs):
        """ Builds a nested SDFG node computing
            ``_res[i, j] = alpha * _x[i] * _y[j] + _A[i, j]``.

            The descriptors of the data connected to the ``_A``, ``_x`` and
            ``_y`` inputs and to the ``_res`` output are deep-copied (marked
            non-transient) into a new SDFG that holds one mapped tasklet over
            ``_i in 0:M`` and ``_j in 0:N``.

            :param node: Node being expanded; ``validate``, ``label`` and
                         ``alpha`` are read from it.
            :param parent_state: State containing ``node``.
            :param parent_sdfg: SDFG whose ``arrays`` hold the connected data.
            :param kwargs: Accepted but not used.
            :return: A ``nodes.NestedSDFG`` wrapping the new SDFG, with ``M``
                     and ``N`` mapped to the first two extents of ``_res`` and
                     ``alpha`` mapped to ``node.alpha``.
            :raise NotImplementedError: If a connected memlet subset differs
                                        from the full range of its array.
        """
        node.validate(parent_sdfg, parent_state)

        input_names = ('_A', '_x', '_y')
        output_names = ('_res', )
        input_edges = [
            next(parent_state.in_edges_by_connector(node, connector))
            for connector in input_names
        ]
        output_edges = [
            next(parent_state.out_edges_by_connector(node, connector))
            for connector in output_names
        ]

        # Connector name -> descriptor of the outer data attached to it
        descriptors = {}
        for connector, edge in zip(input_names, input_edges):
            descriptors[connector] = parent_sdfg.arrays[edge.data.data]
        for connector, edge in zip(output_names, output_edges):
            descriptors[connector] = parent_sdfg.arrays[edge.data.data]

        # TODO: Support memlet subsets
        for names, edges in ((input_names, input_edges),
                             (output_names, output_edges)):
            for connector, edge in zip(names, edges):
                full_range = sbs.Range.from_array(descriptors[connector])
                if edge.data.subset != full_range:
                    raise NotImplementedError

        expanded = dace.SDFG(f'{node.label}_sdfg')
        expanded.add_symbol('M', int)
        expanded.add_symbol('N', int)
        expanded.add_symbol('alpha', descriptors['_A'].dtype)

        # Inside the nested SDFG every connector is backed by external data
        for connector, outer_desc in descriptors.items():
            inner_desc = copy.deepcopy(outer_desc)
            inner_desc.transient = False
            expanded.add_datadesc(connector, inner_desc)

        compute_state = expanded.add_state()
        map_ranges = {'_i': '0:M', '_j': '0:N'}
        tasklet_inputs = {
            'a': mm.Memlet('_A[_i, _j]'),
            'xin': mm.Memlet('_x[_i]'),
            'yin': mm.Memlet('_y[_j]'),
        }
        tasklet_outputs = {'aout': mm.Memlet('_res[_i, _j]')}
        compute_state.add_mapped_tasklet(
            'ger',
            map_ranges,
            tasklet_inputs,
            'aout = alpha * xin * yin + a',
            tasklet_outputs,
            external_edges=True,
        )

        result_shape = descriptors['_res'].shape
        symbol_mapping = {
            'M': result_shape[0],
            'N': result_shape[1],
            'alpha': node.alpha,
        }
        return nodes.NestedSDFG(node.label, expanded, set(input_names),
                                set(output_names), symbol_mapping)
コード例 #2
0
class CopyToDevice(pattern_matching.Transformation):
    """ Implements the copy-to-device transformation, which copies a nested
        SDFG and its dependencies to a given device.

        The transformation changes all data storage types of a nested SDFG to
        the given `storage` property, and creates new arrays and copies around
        the nested SDFG to that storage.
    """

    # Placeholder node used only to build the match pattern
    _nested_sdfg = nodes.NestedSDFG("", graph.OrderedDiGraph(), {}, {})

    # Storage type given to the new transients and to the nested SDFG data
    storage = properties.Property(dtype=dtypes.StorageType,
                                  desc="Nested SDFG storage",
                                  choices=dtypes.StorageType,
                                  from_string=lambda x: dtypes.StorageType[x],
                                  default=dtypes.StorageType.Default)

    @staticmethod
    def annotates_memlets():
        """ Always returns True. """
        return True

    @staticmethod
    def expressions():
        """ Returns the single pattern of this transformation: a path graph
            made of one nested SDFG node.
        """
        return [sdutil.node_path_graph(CopyToDevice._nested_sdfg)]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether the matched nested SDFG can be copied to a device.

            :param graph: State in which the match was found.
            :param candidate: Mapping from pattern node to node index in
                              ``graph``.
            :param expr_index: Not used.
            :param sdfg: SDFG whose ``arrays`` describe the outer data.
            :param strict: Not used.
            :return: False if the memlet path of any edge of the nested SDFG
                     starts or ends at an access node of stream data, or if
                     an edge has a write-conflict resolution and a subset of
                     more than one element; True otherwise.
        """
        nested_sdfg = graph.nodes()[candidate[CopyToDevice._nested_sdfg]]

        for edge in graph.all_edges(nested_sdfg):
            # Stream inputs/outputs not allowed
            path = graph.memlet_path(edge)
            if ((isinstance(path[0].src, nodes.AccessNode)
                 and isinstance(sdfg.arrays[path[0].src.data], data.Stream)) or
                (isinstance(path[-1].dst, nodes.AccessNode)
                 and isinstance(sdfg.arrays[path[-1].dst.data], data.Stream))):
                return False
            # WCR outputs with arrays are not allowed
            if (edge.data.wcr is not None
                    and edge.data.subset.num_elements() != 1):
                return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the label of the matched nested SDFG node. """
        nested_sdfg = graph.nodes()[candidate[CopyToDevice._nested_sdfg]]
        return nested_sdfg.label

    @staticmethod
    def _device_transient(sdfg, created_arrays, basename, memdata, memlet,
                          storage):
        """ Returns the name of the device transient for ``basename``,
            creating it on first use.

            :param sdfg: SDFG that receives the new data descriptor.
            :param created_arrays: Mapping from requested base name to the
                                   name actually registered; updated in place.
            :param basename: Requested name of the transient.
            :param memdata: Descriptor of the original data (array or scalar).
            :param memlet: Memlet whose bounding box gives the array shape.
            :param storage: Storage type of the new transient.
            :return: Name under which the transient is registered in ``sdfg``.
            :raise NotImplementedError: If ``memdata`` is neither an array
                                        nor a scalar.
        """
        # Look up by the requested name: ``find_new_name`` may have registered
        # the transient under a different one.
        if basename in created_arrays:
            return created_arrays[basename]

        if isinstance(memdata, data.Array):
            name, _ = sdfg.add_array(
                basename,
                shape=[
                    symbolic.overapproximate(r)
                    for r in memlet.bounding_box_size()
                ],
                dtype=memdata.dtype,
                transient=True,
                storage=storage,
                find_new_name=True)
        elif isinstance(memdata, data.Scalar):
            name, _ = sdfg.add_scalar(basename,
                                      dtype=memdata.dtype,
                                      transient=True,
                                      storage=storage,
                                      find_new_name=True)
        else:
            raise NotImplementedError(
                'Cannot create a device copy of data of type %s' %
                type(memdata).__name__)

        created_arrays[basename] = name
        return name

    @staticmethod
    def _rebase_subset(device_memlet, memlet):
        """ Shifts every dimension of ``device_memlet.subset`` by the first
            component of the matching dimension of ``memlet.subset``.
            Operates in place on ``device_memlet``.

            :param device_memlet: Copy of ``memlet`` that refers to the device
                                  transient.
            :param memlet: Original memlet, left unmodified.
        """
        for ind, rng in enumerate(memlet.subset):
            start = rng[0]
            if isinstance(rng, tuple):
                device_memlet.subset[ind] = (rng[0] - start, rng[1] - start,
                                             rng[2])
            else:
                device_memlet.subset[ind] -= start

    def apply(self, sdfg):
        """ Routes every data-carrying edge of the matched nested SDFG through
            a new transient access node with the configured storage, then
            changes the storage of the data inside the nested SDFG.

            :param sdfg: SDFG containing the matched nested SDFG node.
        """
        state = sdfg.nodes()[self.state_id]
        nested_sdfg = state.nodes()[self.subgraph[CopyToDevice._nested_sdfg]]
        storage = self.storage
        # Requested transient name -> name registered in the SDFG
        created_arrays = {}

        for edge in state.in_edges(nested_sdfg):
            src, src_conn, dst, dst_conn, memlet = edge
            dataname = memlet.data
            if dataname is None:
                continue
            memdata = sdfg.arrays[dataname]

            name = CopyToDevice._device_transient(
                sdfg, created_arrays, 'device_' + dataname + '_in', memdata,
                memlet, storage)
            data_node = nodes.AccessNode(name)

            # original data -> device transient -> nested SDFG
            to_data_mm = dcpy(memlet)
            from_data_mm = dcpy(memlet)
            from_data_mm.data = name
            CopyToDevice._rebase_subset(from_data_mm, memlet)

            state.remove_edge(edge)
            state.add_edge(src, src_conn, data_node, None, to_data_mm)
            state.add_edge(data_node, None, dst, dst_conn, from_data_mm)

        for edge in state.out_edges(nested_sdfg):
            src, src_conn, dst, dst_conn, memlet = edge
            dataname = memlet.data
            if dataname is None:
                continue
            memdata = sdfg.arrays[dataname]

            name = CopyToDevice._device_transient(
                sdfg, created_arrays, 'device_' + dataname + '_out', memdata,
                memlet, storage)
            data_node = nodes.AccessNode(name)

            # nested SDFG -> device transient -> original data
            to_data_mm = dcpy(memlet)
            from_data_mm = dcpy(memlet)
            to_data_mm.data = name
            CopyToDevice._rebase_subset(to_data_mm, memlet)

            state.remove_edge(edge)
            state.add_edge(src, src_conn, data_node, None, to_data_mm)
            state.add_edge(data_node, None, dst, dst_conn, from_data_mm)

        # Change storage for all data inside nested SDFG to device.
        change_storage(nested_sdfg.sdfg, storage)
コード例 #3
0
class InlineSDFG(transformation.Transformation):
    """ Inlines a single-state nested SDFG into a top-level SDFG.

        In particular, the steps taken are:

        1. All transient arrays become transients of the parent
        2. If a source/sink node is one of the inputs/outputs:
          a. Remove it
          b. Reconnect through external edges (map/accessnode)
          c. Replace and reoffset memlets with external data descriptor
        3. If other nodes carry the names of inputs/outputs:
          a. Replace data with external data descriptor
          b. Replace and reoffset memlets with external data descriptor
        4. If source/sink node is not connected to a source/destination, and
           the nested SDFG is in a scope, connect to scope with empty memlets
        5. Remove all unused external inputs/output memlet paths
        6. Remove isolated nodes resulting from previous step

    """

    _nested_sdfg = nodes.NestedSDFG('_', sd.SDFG('_'), {}, {})

    @staticmethod
    def annotates_memlets():
        """ Always returns True.

            NOTE(review): presumably tells the transformation framework that
            this transformation sets up memlets itself -- confirm against the
            ``Transformation`` base class.
        """
        return True

    @staticmethod
    def expressions():
        """ Returns the single pattern of this transformation: the graph that
            ``sdutil.node_path_graph`` builds from the placeholder nested SDFG
            node.
        """
        return [sdutil.node_path_graph(InlineSDFG._nested_sdfg)]

    @staticmethod
    def _check_strides(inner_strides: List[symbolic.SymbolicType],
                       outer_strides: List[symbolic.SymbolicType],
                       memlet: Memlet, nested_sdfg: nodes.NestedSDFG) -> bool:
        """
        Tells whether the inner array's strides agree with the outer array's
        strides once the array is inlined. Outer dimensions in which the
        memlet subset has size 1 are skipped, and inner symbols are rewritten
        through the nested SDFG symbol mapping before comparing.
        :param inner_strides: The strides of the array inside the nested SDFG.
        :param outer_strides: The strides of the array in the external SDFG.
        :param memlet: Memlet connecting the outer array to the nested SDFG;
                       only the sizes of its subset are read.
        :param nested_sdfg: Nested SDFG node with symbol mapping.
        :return: True if all strides match, False otherwise.
        """
        # Dimensions of size 1 in the memlet do not take part in the comparison
        squeezed_dims = {
            dim
            for dim, extent in enumerate(memlet.subset.size()) if extent == 1
        }
        kept_outer = [
            stride for dim, stride in enumerate(outer_strides)
            if dim not in squeezed_dims
        ]
        if not kept_outer:
            # Every dimension was dropped: compare against a single unit stride
            kept_outer = [1]
        if len(kept_outer) != len(inner_strides):
            return False

        # Express the inner strides with the symbols of the outer SDFG
        symbol_subs = {}
        for inner_sym, outer_expr in nested_sdfg.symbol_mapping.items():
            key = symbolic.pystr_to_symbolic(inner_sym)
            symbol_subs[key] = symbolic.pystr_to_symbolic(outer_expr)

        mapped_inner = []
        for stride in inner_strides:
            if symbolic.issymbolic(stride):
                mapped_inner.append(stride.subs(symbol_subs))
            else:
                mapped_inner.append(stride)

        return all(inner == outer
                   for inner, outer in zip(mapped_inner, kept_outer))

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Decides whether the matched nested SDFG can be inlined.

            :param graph: State in which the match was found.
            :param candidate: Mapping from pattern node to node index in
                              ``graph``.
            :param expr_index: Not used.
            :param sdfg: Not used.
            :param strict: Not used.
            :return: False if the node is flagged ``no_inline``, if its SDFG
                     does not have exactly one state, if a connector has more
                     than one edge, or (inside a scope only) if the access
                     node checks below fail; True otherwise.
        """
        nsdfg_node = graph.nodes()[candidate[InlineSDFG._nested_sdfg]]
        if nsdfg_node.no_inline or len(nsdfg_node.sdfg.nodes()) != 1:
            return False

        # Every connector may be attached to at most one edge
        used_inputs = [e.dst_conn for e in graph.in_edges(nsdfg_node)]
        in_connectors = set(used_inputs)
        if len(in_connectors) != len(used_inputs):
            return False
        used_outputs = [e.src_conn for e in graph.out_edges(nsdfg_node)]
        out_connectors = set(used_outputs)
        if len(out_connectors) != len(used_outputs):
            return False

        # The remaining checks only apply to nested SDFGs inside a scope
        if graph.entry_node(nsdfg_node) is None:
            return True

        all_connectors = in_connectors | out_connectors
        nstate = nsdfg_node.sdfg.node(0)
        for node in nstate.nodes():
            if not isinstance(node, nodes.AccessNode):
                continue
            name = node.data

            # Output data must not be read again inside the nested SDFG
            # (unless it is also an input and this node is never written)
            if (name in out_connectors and nstate.out_degree(node) > 0
                    and (name not in in_connectors
                         or nstate.in_degree(node) > 0)):
                return False

            # Two connectors must not be directly connected to each other
            if name in in_connectors:
                for out_edge in nstate.out_edges(node):
                    if (isinstance(out_edge.dst, nodes.AccessNode)
                            and out_edge.dst.data in all_connectors):
                        return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the label of ``graph``; ``candidate`` is not used. """
        return graph.label

    def _remove_edge_path(self,
                          state: SDFGState,
                          edge_map: Dict[str, MultiConnectorEdge],
                          unused: Set[str],
                          reverse: bool = False) -> List[MultiConnectorEdge]:
        """ Deletes the memlet path of every unused edge, one edge at a time,
            stopping at the first edge that shares its connector with another
            edge. If a whole path is deleted, the node at its far end is
            removed too. Operates in place.
            :param state: The state in which to remove edges.
            :param edge_map: Mapping from identifier to edge; only entries
                             whose identifier is in ``unused`` are processed.
            :param unused: Set of edge identifiers to remove.
            :param reverse: If False, the path is walked from its first edge
                            onwards, otherwise from its last edge backwards.
            :return: Edges on the other side of the removed end nodes
                     (incoming edges if ``reverse``, outgoing otherwise).
        """

        def is_only_edge_on_connector(pedge) -> bool:
            # Count edges sharing the connector this edge is attached to
            # (pedge itself is one of them)
            if reverse:
                sharing = [
                    e for e in state.out_edges(pedge.src)
                    if e.src_conn == pedge.src_conn
                ]
            else:
                sharing = [
                    e for e in state.in_edges(pedge.dst)
                    if e.dst_conn == pedge.dst_conn
                ]
            return len(sharing) == 1

        boundary_edges = []

        for identifier, edge in edge_map.items():
            if identifier not in unused:
                continue

            path = state.memlet_path(edge)
            ordered_path = reversed(path) if reverse else path
            last_edge = None
            for last_edge in ordered_path:
                if not is_only_edge_on_connector(last_edge):
                    # A sibling edge still needs this part of the path
                    break
                # Remove connectors as well
                state.remove_edge_and_connectors(last_edge)
            else:
                # Whole path removed: drop the external node at its end
                if last_edge is not None:
                    end_node = last_edge.src if reverse else last_edge.dst

                    # Remember the edges on the other side of the node; the
                    # caller reconnects them to the first/last occurrence of
                    # access nodes in the inlined subgraph.
                    if reverse:
                        boundary_edges.extend(state.in_edges(end_node))
                    else:
                        boundary_edges.extend(state.out_edges(end_node))

                    state.remove_node(end_node)

        return boundary_edges

    def apply(self, sdfg: SDFG):
        """ Inlines the matched single-state nested SDFG into the state of
            ``sdfg`` that contains it and removes the nested SDFG node.
            Operates in place.
            :param sdfg: The SDFG that contains the matched nested SDFG node.
        """
        state: SDFGState = sdfg.nodes()[self.state_id]
        nsdfg_node = state.nodes()[self.subgraph[InlineSDFG._nested_sdfg]]
        nsdfg: SDFG = nsdfg_node.sdfg
        nstate: SDFGState = nsdfg.nodes()[0]

        # A non-default schedule on the node is handed to infer_types together
        # with the nested SDFG before its contents are moved out
        if nsdfg_node.schedule is not dtypes.ScheduleType.Default:
            infer_types.set_default_schedule_and_storage_types(
                nsdfg, nsdfg_node.schedule)

        # Scope surrounding the nested SDFG node (both None at top level)
        nsdfg_scope_entry = state.entry_node(nsdfg_node)
        nsdfg_scope_exit = (state.exit_node(nsdfg_scope_entry)
                            if nsdfg_scope_entry is not None else None)

        #######################################################
        # Collect and update top-level SDFG metadata

        # Global/init/exit code
        for loc, code in nsdfg.global_code.items():
            sdfg.append_global_code(code.code, loc)
        for loc, code in nsdfg.init_code.items():
            sdfg.append_init_code(code.code, loc)
        for loc, code in nsdfg.exit_code.items():
            sdfg.append_exit_code(code.code, loc)

        # Constants
        for cstname, cstval in nsdfg.constants.items():
            if cstname in sdfg.constants:
                # The outer value is kept; a mismatch only produces a warning
                if cstval != sdfg.constants[cstname]:
                    warnings.warn('Constant value mismatch for "%s" while '
                                  'inlining SDFG. Inner = %s != %s = outer' %
                                  (cstname, cstval, sdfg.constants[cstname]))
            else:
                sdfg.add_constant(cstname, cstval)

        # Find original source/destination edges (there is only one edge per
        # connector, according to match)
        # inputs/outputs: connector name -> outer edge
        # input_set/output_set: outer data name -> connector name
        inputs: Dict[str, MultiConnectorEdge] = {}
        outputs: Dict[str, MultiConnectorEdge] = {}
        input_set: Dict[str, str] = {}
        output_set: Dict[str, str] = {}
        for e in state.in_edges(nsdfg_node):
            inputs[e.dst_conn] = e
            input_set[e.data.data] = e.dst_conn
        for e in state.out_edges(nsdfg_node):
            outputs[e.src_conn] = e
            output_set[e.data.data] = e.src_conn

        # Access nodes that need to be reshaped
        # (inner array has more dimensions than the outer memlet subset, or
        # its strides do not match the outer array's strides)
        reshapes: Set[str] = set()
        for aname, array in nsdfg.arrays.items():
            if array.transient:
                continue
            edge = None
            if aname in inputs:
                edge = inputs[aname]
                if len(array.shape) > len(edge.data.subset):
                    reshapes.add(aname)
                    continue
            if aname in outputs:
                edge = outputs[aname]
                if len(array.shape) > len(edge.data.subset):
                    reshapes.add(aname)
                    continue
            if edge is not None and not InlineSDFG._check_strides(
                    array.strides, sdfg.arrays[edge.data.data].strides,
                    edge.data, nsdfg_node):
                reshapes.add(aname)

        # Replace symbols using invocation symbol mapping
        # Two-step replacement (N -> __dacesym_N --> map[N]) to avoid clashes
        for symname, symvalue in nsdfg_node.symbol_mapping.items():
            if str(symname) != str(symvalue):
                nsdfg.replace(symname, '__dacesym_' + symname)
        for symname, symvalue in nsdfg_node.symbol_mapping.items():
            if str(symname) != str(symvalue):
                nsdfg.replace('__dacesym_' + symname, symvalue)

        # All transients become transients of the parent (if data already
        # exists, find new name)
        # Mapping from nested transient name to top-level name
        transients: Dict[str, str] = {}
        for node in nstate.nodes():
            if isinstance(node, nodes.AccessNode):
                datadesc = nsdfg.arrays[node.data]
                if node.data not in transients and datadesc.transient:
                    name = sdfg.add_datadesc('%s_%s' %
                                             (nsdfg.label, node.data),
                                             datadesc,
                                             find_new_name=True)
                    transients[node.data] = name

        # All transients of edges between code nodes are also added to parent
        for edge in nstate.edges():
            if (isinstance(edge.src, nodes.CodeNode)
                    and isinstance(edge.dst, nodes.CodeNode)):
                if edge.data.data is not None:
                    datadesc = nsdfg.arrays[edge.data.data]
                    if edge.data.data not in transients and datadesc.transient:
                        name = sdfg.add_datadesc('%s_%s' %
                                                 (nsdfg.label, edge.data.data),
                                                 datadesc,
                                                 find_new_name=True)
                        transients[edge.data.data] = name

        # Collect nodes to add to top-level graph
        # Inner source/sink access node -> outer edge that replaces it
        new_incoming_edges: Dict[nodes.Node, MultiConnectorEdge] = {}
        new_outgoing_edges: Dict[nodes.Node, MultiConnectorEdge] = {}

        # Source/sink access nodes of external, non-reshaped data; these are
        # left out of the inlined subgraph further below
        source_accesses = set()
        sink_accesses = set()
        for node in nstate.source_nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.data not in transients
                    and node.data not in reshapes):
                new_incoming_edges[node] = inputs[node.data]
                source_accesses.add(node)
        for node in nstate.sink_nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.data not in transients
                    and node.data not in reshapes):
                new_outgoing_edges[node] = outputs[node.data]
                sink_accesses.add(node)

        #######################################################
        # Replace data on inlined SDFG nodes/edges

        # Replace data names with their top-level counterparts
        repldict = {}
        repldict.update(transients)
        repldict.update({
            k: v.data.data
            for k, v in itertools.chain(inputs.items(), outputs.items())
        })

        # Add views whenever reshapes are necessary
        for dname in reshapes:
            desc = nsdfg.arrays[dname]
            # To avoid potential confusion, rename protected __return keyword
            if dname.startswith('__return'):
                newname = f'{nsdfg.name}_ret{dname[8:]}'
            else:
                newname = dname
            newname, _ = sdfg.add_view(newname,
                                       desc.shape,
                                       desc.dtype,
                                       storage=desc.storage,
                                       strides=desc.strides,
                                       offset=desc.offset,
                                       debuginfo=desc.debuginfo,
                                       allow_conflicts=desc.allow_conflicts,
                                       total_size=desc.total_size,
                                       alignment=desc.alignment,
                                       may_alias=desc.may_alias,
                                       find_new_name=True)
            # Overrides the outer data name stored for this connector above
            repldict[dname] = newname

        for node in nstate.nodes():
            if isinstance(node, nodes.AccessNode) and node.data in repldict:
                node.data = repldict[node.data]
        for edge in nstate.edges():
            if edge.data.data in repldict:
                edge.data.data = repldict[edge.data.data]

        # Add extra access nodes for out/in view nodes
        for node in nstate.nodes():
            if isinstance(node, nodes.AccessNode) and node.data in reshapes:
                if nstate.in_degree(node) > 0 and nstate.out_degree(node) > 0:
                    # Such a node has to be in the output set
                    edge = outputs[node.data]

                    # Redirect outgoing edges through access node
                    out_edges = list(nstate.out_edges(node))
                    anode = nstate.add_access(edge.data.data)
                    vnode = nstate.add_access(node.data)
                    nstate.add_nedge(node, anode, edge.data)
                    nstate.add_nedge(anode, vnode, edge.data)
                    for e in out_edges:
                        nstate.remove_edge(e)
                        nstate.add_edge(vnode, e.src_conn, e.dst, e.dst_conn,
                                        e.data)

        #######################################################
        # Add nested SDFG into top-level SDFG

        # Add nested nodes into original state
        subgraph = SubgraphView(nstate, [
            n for n in nstate.nodes()
            if n not in (source_accesses | sink_accesses)
        ])
        state.add_nodes_from(subgraph.nodes())
        for edge in subgraph.edges():
            state.add_edge(edge.src, edge.src_conn, edge.dst, edge.dst_conn,
                           edge.data)

        #######################################################
        # Reconnect inlined SDFG

        # If a source/sink node is one of the inputs/outputs, reconnect it,
        # replacing memlets in outgoing/incoming paths
        modified_edges = set()
        modified_edges |= self._modify_memlet_path(new_incoming_edges, nstate,
                                                   state, True)
        modified_edges |= self._modify_memlet_path(new_outgoing_edges, nstate,
                                                   state, False)

        # Reshape: add connections to viewed data
        self._modify_reshape_data(reshapes, repldict, inputs, nstate, state,
                                  True)
        self._modify_reshape_data(reshapes, repldict, outputs, nstate, state,
                                  False)

        # Modify all other internal edges pertaining to input/output nodes
        for node in subgraph.nodes():
            if isinstance(node, nodes.AccessNode):
                if node.data in input_set or node.data in output_set:
                    if node.data in input_set:
                        outer_edge = inputs[input_set[node.data]]
                    else:
                        outer_edge = outputs[output_set[node.data]]

                    for edge in state.all_edges(node):
                        if (edge not in modified_edges
                                and edge.data.data == node.data):
                            for e in state.memlet_tree(edge):
                                if e.data.data == node.data:
                                    # Writes the edge's private ``_data``
                                    # attribute directly
                                    e._data = helpers.unsqueeze_memlet(
                                        e.data, outer_edge.data)

        # If source/sink node is not connected to a source/destination access
        # node, and the nested SDFG is in a scope, connect to scope with empty
        # memlets
        if nsdfg_scope_entry is not None:
            for node in subgraph.nodes():
                if state.in_degree(node) == 0:
                    state.add_edge(nsdfg_scope_entry, None, node, None,
                                   Memlet())
                if state.out_degree(node) == 0:
                    state.add_edge(node, None, nsdfg_scope_exit, None,
                                   Memlet())

        # Replace nested SDFG parents with new SDFG
        for node in nstate.nodes():
            if isinstance(node, nodes.NestedSDFG):
                node.sdfg.parent = state
                node.sdfg.parent_sdfg = sdfg
                node.sdfg.parent_nsdfg_node = node

        # Remove all unused external inputs/output memlet paths, as well as
        # resulting isolated nodes
        # NOTE(review): the keys of inputs/outputs are connector names (str)
        # while source_accesses/sink_accesses hold access node objects, so
        # these set differences presumably never drop an identifier --
        # confirm that passing every connector is intended.
        removed_in_edges = self._remove_edge_path(state,
                                                  inputs,
                                                  set(inputs.keys()) -
                                                  source_accesses,
                                                  reverse=True)
        removed_out_edges = self._remove_edge_path(state,
                                                   outputs,
                                                   set(outputs.keys()) -
                                                   sink_accesses,
                                                   reverse=False)

        # Re-add in/out edges to first/last nodes in subgraph
        # (access nodes of the nested state in topological order)
        order = [
            x for x in nx.topological_sort(nstate._nx)
            if isinstance(x, nodes.AccessNode)
        ]
        for edge in removed_in_edges:
            # Find first access node that refers to this edge
            node = next(n for n in order if n.data == edge.data.data)
            state.add_edge(edge.src, edge.src_conn, node, edge.dst_conn,
                           edge.data)
        for edge in removed_out_edges:
            # Find last access node that refers to this edge
            node = next(n for n in reversed(order) if n.data == edge.data.data)
            state.add_edge(node, edge.src_conn, edge.dst, edge.dst_conn,
                           edge.data)

        #######################################################
        # Remove nested SDFG node
        state.remove_node(nsdfg_node)

    def _modify_memlet_path(self, new_edges: Dict[nodes.Node,
                                                  MultiConnectorEdge],
                            nstate: SDFGState, state: SDFGState,
                            inputs: bool) -> Set[MultiConnectorEdge]:
        """ Reconnects the inner edges of removed source/sink access nodes
            directly to the outer edges and unsqueezes the memlets further
            along the resulting memlet trees.
            :param new_edges: Mapping from inner access node to the outer
                              edge that takes its place.
            :param nstate: The nested state holding the inner edges.
            :param state: The outer state that receives the new edges.
            :param inputs: If True, ``new_edges`` describes inputs (outgoing
                           inner edges are reconnected), otherwise outputs.
            :return: Set of edges below the new edges whose memlets were
                     modified.
        """
        touched = set()

        def unsqueeze_subtree(tree_node, outer_memlet):
            # Pre-order walk: record and rewrite this edge, then its children
            touched.add(tree_node.edge)
            tree_node.edge._data = helpers.unsqueeze_memlet(
                tree_node.edge.data, outer_memlet)
            for subtree in tree_node.children:
                unsqueeze_subtree(subtree, outer_memlet)

        for node, top_edge in new_edges.items():
            if inputs:
                inner_edges = nstate.out_edges(node)
            else:
                inner_edges = nstate.in_edges(node)

            for inner_edge in inner_edges:
                new_memlet = helpers.unsqueeze_memlet(inner_edge.data,
                                                      top_edge.data)
                if inputs:
                    # Outer source feeds the inner consumer
                    src, src_conn = top_edge.src, top_edge.src_conn
                    dst, dst_conn = inner_edge.dst, inner_edge.dst_conn
                else:
                    # Inner producer feeds the outer destination
                    src, src_conn = inner_edge.src, inner_edge.src_conn
                    dst, dst_conn = top_edge.dst, top_edge.dst_conn
                new_edge = state.add_edge(src, src_conn, dst, dst_conn,
                                          new_memlet)

                # Modify all memlets going forward/backward
                for subtree in state.memlet_tree(new_edge).children:
                    unsqueeze_subtree(subtree, top_edge.data)

        return touched

    def _modify_reshape_data(self, reshapes: Set[str], repldict: Dict[str,
                                                                      str],
                             new_edges: Dict[str, MultiConnectorEdge],
                             nstate: SDFGState, state: SDFGState,
                             inputs: bool):
        """ Connects the source (or sink) access nodes of reshaped data to
            the outer edges of their connectors.
            :param reshapes: Original names of the data being reshaped.
            :param repldict: Mapping from original data name to the name it
                             was replaced with.
            :param new_edges: Mapping from connector name to outer edge.
            :param nstate: The nested state whose source/sink nodes are
                           inspected.
            :param state: The outer state that receives the new edges.
            :param inputs: If True, source nodes are connected to the outer
                           edge sources; otherwise sink nodes are connected
                           to the outer edge destinations.
        """
        if inputs:
            boundary_nodes = nstate.source_nodes()
        else:
            boundary_nodes = nstate.sink_nodes()

        # Replaced (current) data name -> original name used as connector key
        original_names = {repldict[dname]: dname for dname in reshapes}

        for node in boundary_nodes:
            if not (isinstance(node, nodes.AccessNode)
                    and node.data in original_names):
                continue
            outer_edge = new_edges[original_names[node.data]]
            if inputs:
                state.add_edge(outer_edge.src, outer_edge.src_conn, node, None,
                               outer_edge.data)
            else:
                state.add_edge(node, None, outer_edge.dst,
                               outer_edge.dst_conn, outer_edge.data)
コード例 #4
0
ファイル: sdfg_nesting.py プロジェクト: JanKleine/dace
class InlineSDFG(pattern_matching.Transformation):
    """ Inlines a single-state nested SDFG into a top-level SDFG.

        In particular, the steps taken are:

        1. All transient arrays become transients of the parent
        2. If a source/sink node is one of the inputs/outputs:
          a. Remove it
          b. Reconnect through external edges (map/accessnode)
          c. Replace and reoffset memlets with external data descriptor
        3. If other nodes carry the names of inputs/outputs:
          a. Replace data with external data descriptor
          b. Replace and reoffset memlets with external data descriptor
        4. If source/sink node is not connected to a source/destination, and
           the nested SDFG is in a scope, connect to scope with empty memlets
        5. Remove all unused external inputs/output memlet paths
        6. Remove isolated nodes resulting from previous step

    """

    _nested_sdfg = nodes.NestedSDFG('_', sd.SDFG('_'), set(), set())

    @staticmethod
    def annotates_memlets():
        """ Reports that this transformation annotates memlets on its own.
            NOTE(review): presumably queried by the transformation framework
            to decide whether memlets must be re-propagated after applying;
            confirm against the base class.
            :return: Always True.
        """
        return True

    @staticmethod
    def expressions():
        """ Returns the list of pattern graphs to match: a single path graph
            that contains only the nested SDFG placeholder node.
        """
        # Matches any nested SDFG node; candidates are filtered further in
        # can_be_applied
        return [sdutil.node_path_graph(InlineSDFG._nested_sdfg)]

    @staticmethod
    def _find_edge(state: SDFGState, node: nodes.Node,
                   connector: str) -> Optional[MultiConnectorEdge]:
        """ Finds the edge that is attached to a node through a connector.
            Incoming edges are searched first (by destination connector),
            then outgoing edges (by source connector).
            :param state: The state in which to look for the edge.
            :param node: The node whose edges are inspected.
            :param connector: Name of the connector to look for.
            :return: The first edge that uses the connector.
            :raise NameError: If no edge of the node uses the connector.
        """
        found = next(
            (e for e in state.in_edges(node) if e.dst_conn == connector),
            None)
        if found is None:
            found = next(
                (e for e in state.out_edges(node) if e.src_conn == connector),
                None)
        if found is None:
            raise NameError('Edge with connector %s not found on node %s' %
                            (connector, node))
        return found

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether the matched nested SDFG can be inlined.
            :param graph: The state that contains the nested SDFG node.
            :param candidate: Mapping from pattern nodes to node indices in
                              ``graph``.
            :param expr_index: Index of the matched expression (unused).
            :param sdfg: The SDFG that contains ``graph`` (unused).
            :param strict: If True, additionally rejects nested SDFGs whose
                           non-transient arrays have more dimensions than the
                           subset of the memlet connected to them.
            :return: True if the transformation can be applied.
        """
        nsdfg_node = graph.nodes()[candidate[InlineSDFG._nested_sdfg]]
        inner_sdfg = nsdfg_node.sdfg

        # Only single-state nested SDFGs are supported
        if len(inner_sdfg.nodes()) != 1:
            return False

        # Each connector may be used by at most one incoming/outgoing edge
        in_conn_list = [e.dst_conn for e in graph.in_edges(nsdfg_node)]
        in_connectors = set(in_conn_list)
        if len(in_connectors) != len(in_conn_list):
            return False
        out_conn_list = [e.src_conn for e in graph.out_edges(nsdfg_node)]
        out_connectors = set(out_conn_list)
        if len(out_connectors) != len(out_conn_list):
            return False

        # Inside a scope: output connectors must not be read again, and no
        # two connectors may be directly connected to each other
        if graph.entry_node(nsdfg_node) is not None:
            all_connectors = in_connectors | out_connectors
            nstate = inner_sdfg.node(0)
            for anode in nstate.nodes():
                if not isinstance(anode, nodes.AccessNode):
                    continue

                if (anode.data in out_connectors
                        and nstate.out_degree(anode) > 0):
                    if (anode.data not in in_connectors
                            or nstate.in_degree(anode) > 0):
                        return False

                if anode.data in in_connectors:
                    for oedge in nstate.out_edges(anode):
                        if (isinstance(oedge.dst, nodes.AccessNode)
                                and oedge.dst.data in all_connectors):
                            return False

        # In strict mode, refuse reshapes that cannot be inlined/unsqueezed
        if strict:
            for aname, desc in inner_sdfg.arrays.items():
                if not desc.transient:
                    outer_edge = InlineSDFG._find_edge(graph, nsdfg_node,
                                                       aname)
                    if len(desc.shape) > len(outer_edge.data.subset):
                        return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns a string that describes a match; here, the label of the
            graph in which the match was found (``candidate`` is unused).
        """
        return graph.label

    def _remove_edge_path(self,
                          state: SDFGState,
                          edge_map: Dict[str, MultiConnectorEdge],
                          unused: Set[str],
                          reverse: bool = False) -> List[MultiConnectorEdge]:
        """ Walks the memlet path of every unused edge and removes its edges
            (and their connectors) one by one, stopping at the first edge
            that shares its connector with other edges. If an entire path
            was removed, the node at its far end is removed too. Operates in
            place.
            :param state: The state in which to remove edges.
            :param edge_map: Mapping from identifier to edge.
            :param unused: Identifiers in ``edge_map`` whose paths should be
                           removed.
            :param reverse: If True, the path is walked from its last edge
                            backwards; otherwise from its first edge
                            forwards.
            :return: List of the edges that were attached to the other side
                     of the removed end-of-path nodes.
        """
        if reverse:

            def same_connector_edges(e):
                return [
                    other for other in state.out_edges(e.src)
                    if other.src_conn == e.src_conn
                ]
        else:

            def same_connector_edges(e):
                return [
                    other for other in state.in_edges(e.dst)
                    if other.dst_conn == e.dst_conn
                ]

        dangling = []

        for identifier, edge in edge_map.items():
            if identifier not in unused:
                continue

            path = state.memlet_path(edge)
            walk = reversed(path) if reverse else path

            last_edge = None
            whole_path_removed = True
            for last_edge in walk:
                # An edge is only safe to remove if nothing else uses its
                # connector
                if len(same_connector_edges(last_edge)) != 1:
                    whole_path_removed = False
                    break
                # Remove connectors as well
                state.remove_edge_and_connectors(last_edge)

            if not whole_path_removed or last_edge is None:
                continue

            # Reached the end of the path: remove the external node, but
            # remember the edges on its other side so that they can be
            # reconnected to the first/last occurrence of the access nodes
            # in the inlined subgraph.
            end_node = last_edge.src if reverse else last_edge.dst
            if reverse:
                dangling.extend(state.in_edges(end_node))
            else:
                dangling.extend(state.out_edges(end_node))
            state.remove_node(end_node)

        return dangling

    def apply(self, sdfg):
        """ Inlines the matched nested SDFG into the state that contains it.
            Global/init/exit code and constants of the nested SDFG are merged
            into ``sdfg``, transients are re-registered in ``sdfg`` under
            new names, the contents of the nested state are copied into the
            outer state and reconnected to the edges that surrounded the
            nested SDFG node, and finally the nested SDFG node is removed.
            :param sdfg: The SDFG that contains the matched state (selected
                         through ``self.state_id``). Modified in place.
        """
        state: SDFGState = sdfg.nodes()[self.state_id]
        nsdfg_node = state.nodes()[self.subgraph[InlineSDFG._nested_sdfg]]
        nsdfg: SDFG = nsdfg_node.sdfg
        # The nested SDFG's first state is the one that is inlined
        nstate: SDFGState = nsdfg.nodes()[0]

        # Entry/exit nodes of the scope surrounding the nested SDFG node
        # (None if it is not inside a scope)
        nsdfg_scope_entry = state.entry_node(nsdfg_node)
        nsdfg_scope_exit = (state.exit_node(nsdfg_scope_entry)
                            if nsdfg_scope_entry is not None else None)

        #######################################################
        # Collect and update top-level SDFG metadata

        # Global/init/exit code
        for loc, code in nsdfg.global_code.items():
            sdfg.append_global_code(code.code, loc)
        for loc, code in nsdfg.init_code.items():
            sdfg.append_init_code(code.code, loc)
        for loc, code in nsdfg.exit_code.items():
            sdfg.append_exit_code(code.code, loc)

        # Constants (on a name clash the outer value is kept and a warning
        # is issued if the values differ)
        for cstname, cstval in nsdfg.constants.items():
            if cstname in sdfg.constants:
                if cstval != sdfg.constants[cstname]:
                    warnings.warn('Constant value mismatch for "%s" while '
                                  'inlining SDFG. Inner = %s != %s = outer' %
                                  (cstname, cstval, sdfg.constants[cstname]))
            else:
                sdfg.add_constant(cstname, cstval)

        # Find original source/destination edges (there is only one edge per
        # connector, according to match)
        # inputs/outputs: connector name -> outer edge
        # input_set/output_set: outer data name -> connector name
        inputs: Dict[str, MultiConnectorEdge] = {}
        outputs: Dict[str, MultiConnectorEdge] = {}
        input_set: Dict[str, str] = {}
        output_set: Dict[str, str] = {}
        for e in state.in_edges(nsdfg_node):
            inputs[e.dst_conn] = e
            input_set[e.data.data] = e.dst_conn
        for e in state.out_edges(nsdfg_node):
            outputs[e.src_conn] = e
            output_set[e.data.data] = e.src_conn

        # All transients become transients of the parent (if data already
        # exists, find new name)
        # Mapping from nested transient name to top-level name
        transients: Dict[str, str] = {}
        for node in nstate.nodes():
            if isinstance(node, nodes.AccessNode):
                datadesc = nsdfg.arrays[node.data]
                if node.data not in transients and datadesc.transient:
                    name = sdfg.add_datadesc('%s_%s' %
                                             (nsdfg.label, node.data),
                                             datadesc,
                                             find_new_name=True)
                    transients[node.data] = name

        # All transients of edges between code nodes are also added to parent
        for edge in nstate.edges():
            if (isinstance(edge.src, nodes.CodeNode)
                    and isinstance(edge.dst, nodes.CodeNode)):
                datadesc = nsdfg.arrays[edge.data.data]
                if edge.data.data not in transients and datadesc.transient:
                    name = sdfg.add_datadesc('%s_%s' %
                                             (nsdfg.label, edge.data.data),
                                             datadesc,
                                             find_new_name=True)
                    transients[edge.data.data] = name

        # Collect nodes to add to top-level graph
        # Mapping from boundary access node (non-transient source/sink of
        # the nested state) to the outer edge of its connector
        new_incoming_edges: Dict[nodes.Node, MultiConnectorEdge] = {}
        new_outgoing_edges: Dict[nodes.Node, MultiConnectorEdge] = {}

        source_accesses = set()
        sink_accesses = set()
        for node in nstate.source_nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.data not in transients):
                new_incoming_edges[node] = inputs[node.data]
                source_accesses.add(node)
        for node in nstate.sink_nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.data not in transients):
                new_outgoing_edges[node] = outputs[node.data]
                sink_accesses.add(node)

        #######################################################
        # Add nested SDFG into top-level SDFG

        # Add nested nodes into original state (the boundary access nodes
        # collected above are left out; they are replaced by direct
        # connections to the outer edges further below)
        subgraph = SubgraphView(nstate, [
            n for n in nstate.nodes()
            if n not in (source_accesses | sink_accesses)
        ])
        state.add_nodes_from(subgraph.nodes())
        for edge in subgraph.edges():
            state.add_edge(edge.src, edge.src_conn, edge.dst, edge.dst_conn,
                           edge.data)

        #######################################################
        # Replace data on inlined SDFG nodes/edges

        # Replace symbols using invocation symbol mapping
        # Two-step replacement (N -> __dacesym_N --> map[N]) to avoid clashes
        for symname, symvalue in nsdfg_node.symbol_mapping.items():
            if str(symname) != str(symvalue):
                nsdfg.replace(symname, '__dacesym_' + symname)
        for symname, symvalue in nsdfg_node.symbol_mapping.items():
            if str(symname) != str(symvalue):
                nsdfg.replace('__dacesym_' + symname, symvalue)

        # Replace data names with their top-level counterparts
        # (transients -> newly registered names, connectors -> outer data)
        repldict = {}
        repldict.update(transients)
        repldict.update({
            k: v.data.data
            for k, v in itertools.chain(inputs.items(), outputs.items())
        })
        for node in subgraph.nodes():
            if isinstance(node, nodes.AccessNode) and node.data in repldict:
                node.data = repldict[node.data]
        for edge in subgraph.edges():
            if edge.data.data in repldict:
                edge.data.data = repldict[edge.data.data]

        #######################################################
        # Reconnect inlined SDFG

        # If a source/sink node is one of the inputs/outputs, reconnect it,
        # replacing memlets in outgoing/incoming paths
        modified_edges = set()
        modified_edges |= self._modify_memlet_path(new_incoming_edges, nstate,
                                                   state, True)
        modified_edges |= self._modify_memlet_path(new_outgoing_edges, nstate,
                                                   state, False)

        # Modify all other internal edges pertaining to input/output nodes
        # (edges already handled by _modify_memlet_path are skipped)
        for node in subgraph.nodes():
            if isinstance(node, nodes.AccessNode):
                if node.data in input_set or node.data in output_set:
                    if node.data in input_set:
                        outer_edge = inputs[input_set[node.data]]
                    else:
                        outer_edge = outputs[output_set[node.data]]

                    for edge in state.all_edges(node):
                        if (edge not in modified_edges
                                and edge.data.data == node.data):
                            for e in state.memlet_tree(edge):
                                if e.data.data == node.data:
                                    e._data = helpers.unsqueeze_memlet(
                                        e.data, outer_edge.data)

        # If source/sink node is not connected to a source/destination access
        # node, and the nested SDFG is in a scope, connect to scope with empty
        # memlets
        if nsdfg_scope_entry is not None:
            for node in subgraph.nodes():
                if state.in_degree(node) == 0:
                    state.add_edge(nsdfg_scope_entry, None, node, None,
                                   Memlet())
                if state.out_degree(node) == 0:
                    state.add_edge(node, None, nsdfg_scope_exit, None,
                                   Memlet())

        # Replace nested SDFG parents with new SDFG
        for node in nstate.nodes():
            if isinstance(node, nodes.NestedSDFG):
                node.sdfg.parent = state
                node.sdfg.parent_sdfg = sdfg
                node.sdfg.parent_nsdfg_node = node

        # Remove all unused external inputs/output memlet paths, as well as
        # resulting isolated nodes
        # NOTE(review): inputs/outputs are keyed by connector name (str),
        # while source_accesses/sink_accesses contain access node objects,
        # so these set differences likely remove nothing and every connector
        # is passed as "unused". Confirm whether the data names of the
        # boundary nodes were intended here before changing it.
        removed_in_edges = self._remove_edge_path(state,
                                                  inputs,
                                                  set(inputs.keys()) -
                                                  source_accesses,
                                                  reverse=True)
        removed_out_edges = self._remove_edge_path(state,
                                                   outputs,
                                                   set(outputs.keys()) -
                                                   sink_accesses,
                                                   reverse=False)

        # Re-add in/out edges to first/last nodes in subgraph
        # (access nodes of the nested state in topological order)
        order = [
            x for x in nx.topological_sort(nstate._nx)
            if isinstance(x, nodes.AccessNode)
        ]
        for edge in removed_in_edges:
            # Find first access node that refers to this edge
            node = next(n for n in order if n.data == edge.data.data)
            state.add_edge(edge.src, edge.src_conn, node, edge.dst_conn,
                           edge.data)
        for edge in removed_out_edges:
            # Find last access node that refers to this edge
            node = next(n for n in reversed(order) if n.data == edge.data.data)
            state.add_edge(node, edge.src_conn, edge.dst, edge.dst_conn,
                           edge.data)

        #######################################################
        # Remove nested SDFG node
        state.remove_node(nsdfg_node)

    def _modify_memlet_path(self, new_edges: Dict[nodes.Node,
                                                  MultiConnectorEdge],
                            nstate: SDFGState, state: SDFGState,
                            inputs: bool) -> Set[MultiConnectorEdge]:
        """ Reconnects the inner edges of boundary access nodes directly to
            the endpoints of the corresponding outer edges, and unsqueezes
            the memlets of all edges below the new edges in their memlet
            trees with respect to the outer memlet.
            :param new_edges: Mapping from boundary access node of the nested
                              state to its outer edge.
            :param nstate: The nested state, used to query the inner edges.
            :param state: The state in which the new edges are created.
            :param inputs: If True, treats the nodes as inputs (uses their
                           outgoing edges), otherwise as outputs.
            :return: Set of edges whose memlets were modified.
        """
        touched = set()

        for boundary_node, outer_edge in new_edges.items():
            if inputs:
                inner_edges = nstate.out_edges(boundary_node)
            else:
                inner_edges = nstate.in_edges(boundary_node)

            for inner in inner_edges:
                memlet = helpers.unsqueeze_memlet(inner.data, outer_edge.data)
                if inputs:
                    endpoints = (outer_edge.src, outer_edge.src_conn,
                                 inner.dst, inner.dst_conn)
                else:
                    endpoints = (inner.src, inner.src_conn, outer_edge.dst,
                                 outer_edge.dst_conn)
                bridge = state.add_edge(*endpoints, memlet)
                tree_root = state.memlet_tree(bridge)

                # Modify all memlets going forward/backward: pre-order walk
                # over everything below the root of the memlet tree
                pending = list(tree_root.children)[::-1]
                while pending:
                    tree_node = pending.pop()
                    touched.add(tree_node.edge)
                    tree_node.edge._data = helpers.unsqueeze_memlet(
                        tree_node.edge.data, outer_edge.data)
                    pending.extend(list(tree_node.children)[::-1])

        return touched
コード例 #5
0
ファイル: map_fission.py プロジェクト: tobiasholenstein/dace
class MapFission(transformation.Transformation):
    """ Implements the MapFission transformation.
        Map fission refers to subsuming a map scope into its internal subgraph,
        essentially replicating the map into maps in all of its internal
        components. This also extends the dimensions of "border" transient
        arrays (i.e., those between the maps), in order to retain program
        semantics after fission.

        There are two cases that match map fission:
        1. A map with an arbitrary subgraph with more than one computational
           (i.e., non-access) node. The use of arrays connecting the
           computational nodes must be limited to the subgraph, and non
           transient arrays may not be used as "border" arrays.
        2. A map with one internal node that is a nested SDFG, in which
           each state matches the conditions of case (1).

        If a map has nested SDFGs in its subgraph, they are not considered in
        the case (1) above, and MapFission must be invoked again on the maps
        with the nested SDFGs in question.
    """
    _map_entry = nodes.EntryNode()
    _nested_sdfg = nodes.NestedSDFG("", OrderedDiGraph(), {}, {})

    @staticmethod
    def annotates_memlets():
        """ Reports that this transformation does not annotate memlets on
            its own.
            NOTE(review): presumably queried by the transformation framework
            to decide whether memlets must be re-propagated after applying;
            confirm against the base class.
            :return: Always False.
        """
        return False

    @staticmethod
    def expressions():
        """ Returns the two pattern graphs of this transformation, in order:
            a lone map entry (expression 0), and a map entry followed by a
            nested SDFG (expression 1).
        """
        map_only = sdutil.node_path_graph(MapFission._map_entry)
        map_with_nested_sdfg = sdutil.node_path_graph(
            MapFission._map_entry, MapFission._nested_sdfg)
        return [map_only, map_with_nested_sdfg]

    @staticmethod
    def _components(
            subgraph: gr.SubgraphView) -> List[Tuple[nodes.Node, nodes.Node]]:
        """
        Returns the list of non-array components at the top scope level of
        this subgraph. Each element in the list is a 2-tuple of
        (input node, output node) of the component: a scope is represented by
        its (entry node, exit node) pair, and a code node by (node, node).
        """
        if isinstance(subgraph, sd.SDFGState):
            parent_graph = subgraph
        else:
            parent_graph = subgraph.graph

        components = []
        for node in subgraph.scope_children()[None]:
            if isinstance(node, nodes.EntryNode):
                components.append((node, parent_graph.exit_node(node)))
            elif isinstance(node, nodes.CodeNode):
                components.append((node, node))

        return components

    @staticmethod
    def _border_arrays(sdfg, parent, subgraph):
        """ Returns the set of data names of the access nodes at the top
            scope level of the fission subgraph. If ``parent`` is an SDFG
            state, only transient data (according to ``sdfg.arrays``) is
            included. """
        transients_only = isinstance(parent, sd.SDFGState)
        top_level = gr.SubgraphView(parent, subgraph.scope_children()[None])

        names = set()
        for node in top_level.nodes():
            if not isinstance(node, nodes.AccessNode):
                continue
            if transients_only and not sdfg.arrays[node.data].transient:
                continue
            names.add(node.data)
        return names

    @staticmethod
    def _internal_border_arrays(total_components, subgraphs):
        """ Returns the set of arrays that are both read directly by some
            computational component and written directly by some
            computational component (i.e., border arrays without pure
            sources and sinks). """
        read_arrays = set()
        written_arrays = set()

        for components, subgraph in zip(total_components, subgraphs):
            for component_in, component_out in components:
                read_arrays.update(
                    e.src.data for e in subgraph.in_edges(component_in)
                    if isinstance(e.src, nodes.AccessNode))
                written_arrays.update(
                    e.dst.data for e in subgraph.out_edges(component_out)
                    if isinstance(e.dst, nodes.AccessNode))

        return read_arrays & written_arrays

    @staticmethod
    def _outside_map(node, scope_dict, entry_nodes):
        """ Returns True iff node is not in any of the scopes spanned by
            entry_nodes. """
        while scope_dict[node] is not None:
            if scope_dict[node] in entry_nodes:
                return False
            node = scope_dict[node]
        return True

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether map fission can be applied to the matched map.
            :param graph: The state that contains the matched map.
            :param candidate: Mapping from pattern nodes to node indices in
                              ``graph``.
            :param expr_index: 0 if a map with an arbitrary subgraph was
                               matched, 1 for a map with a nested SDFG.
            :param sdfg: The SDFG that contains ``graph``.
            :param strict: Unused.
            :return: True if the transformation can be applied.
        """
        map_node = graph.node(candidate[MapFission._map_entry])
        nsdfg_node = None

        # Dynamic-ranged maps would yield dynamically sized border arrays
        if sd.has_dynamic_map_inputs(graph, map_node):
            return False

        if expr_index == 0:
            # Map with subgraph: test the contents of the map scope
            subgraphs = [
                graph.scope_subgraph(map_node,
                                     include_entry=False,
                                     include_exit=False)
            ]
        else:
            # Map with nested SDFG: test every state of the nested SDFG
            nsdfg_node = graph.node(candidate[MapFission._nested_sdfg])
            # The nested SDFG must be the only internal node of the map
            map_successors = set(e.dst for e in graph.out_edges(map_node))
            if len(map_successors) > 1:
                return False
            subgraphs = list(nsdfg_node.sdfg.nodes())

        border_arrays = set()
        total_components = []
        for sg in subgraphs:
            components = MapFission._components(sg)
            snodes = sg.nodes()

            # A non-empty map subgraph needs more than one computational
            # component for fission to make sense
            if expr_index == 0 and len(snodes) > 0 and len(components) <= 1:
                return False

            # Accumulate the arrays that connect the components
            if expr_index == 1:
                found = MapFission._border_arrays(nsdfg_node.sdfg, sg, sg)
            else:
                found = MapFission._border_arrays(sdfg, graph, sg)
            border_arrays |= found
            total_components.append(components)

            # None of the border arrays may be a non-transient
            if expr_index == 0:
                descriptors = sdfg.arrays
            else:
                descriptors = nsdfg_node.sdfg.arrays
            if any(descriptors[name].transient is False
                   for name in border_arrays):
                return False

            # In map subgraphs, data written by the components must not be
            # used in other scopes or states
            if expr_index == 0:
                used_elsewhere = set()
                for node in graph.nodes():
                    if (node not in snodes
                            and isinstance(node, nodes.AccessNode)):
                        used_elsewhere.add(node.data)
                for other_state in sdfg.nodes():
                    if other_state != graph:
                        for node in other_state.nodes():
                            if isinstance(node, nodes.AccessNode):
                                used_elsewhere.add(node.data)

                for _, component_out in components:
                    for oedge in sg.out_edges(component_out):
                        if (isinstance(oedge.dst, nodes.AccessNode)
                                and oedge.dst.data in used_elsewhere):
                            return False

        # Fail if there are arrays inside the map that are not a direct
        # output of a computational component
        # TODO(later): Support this case? Ambiguous array sizes and memlets
        internal_arrays = MapFission._internal_border_arrays(
            total_components, subgraphs)
        return not (border_arrays - internal_arrays)

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns a string that describes a match; here, the label of the
            matched map.
        """
        map_entry = graph.node(candidate[MapFission._map_entry])
        return map_entry.map.label

    def apply(self, sdfg: sd.SDFG):
        graph: sd.SDFGState = sdfg.nodes()[self.state_id]
        map_entry = graph.node(self.subgraph[MapFission._map_entry])
        map_exit = graph.exit_node(map_entry)
        nsdfg_node: Optional[nodes.NestedSDFG] = None

        # Obtain subgraph to perform fission to
        if self.expr_index == 0:  # Map with subgraph
            subgraphs = [(graph,
                          graph.scope_subgraph(map_entry,
                                               include_entry=False,
                                               include_exit=False))]
            parent = sdfg
        else:  # Map with nested SDFG
            nsdfg_node = graph.node(self.subgraph[MapFission._nested_sdfg])
            subgraphs = [(state, state) for state in nsdfg_node.sdfg.nodes()]
            parent = nsdfg_node.sdfg
        modified_arrays = set()

        # Get map information
        outer_map: nodes.Map = map_entry.map
        mapsize = outer_map.range.size()

        # Add new symbols from outer map to nested SDFG
        if self.expr_index == 1:
            map_syms = outer_map.range.free_symbols
            for edge in graph.out_edges(map_entry):
                if edge.data.data:
                    map_syms.update(edge.data.subset.free_symbols)
            for edge in graph.in_edges(map_exit):
                if edge.data.data:
                    map_syms.update(edge.data.subset.free_symbols)
            for sym in map_syms:
                symname = str(sym)
                if symname in outer_map.params:
                    continue
                if symname not in nsdfg_node.symbol_mapping.keys():
                    nsdfg_node.symbol_mapping[symname] = sym
                    nsdfg_node.sdfg.symbols[
                        symname] = graph.symbols_defined_at(
                            nsdfg_node)[symname]

            # Remove map symbols from nested mapping
            for name in outer_map.params:
                if str(name) in nsdfg_node.symbol_mapping:
                    del nsdfg_node.symbol_mapping[str(name)]
                if str(name) in nsdfg_node.sdfg.symbols:
                    del nsdfg_node.sdfg.symbols[str(name)]

        for state, subgraph in subgraphs:
            components = MapFission._components(subgraph)
            sources = subgraph.source_nodes()
            sinks = subgraph.sink_nodes()

            # Collect external edges
            if self.expr_index == 0:
                external_edges_entry = list(state.out_edges(map_entry))
                external_edges_exit = list(state.in_edges(map_exit))
            else:
                external_edges_entry = [
                    e for e in subgraph.edges()
                    if (isinstance(e.src, nodes.AccessNode)
                        and not nsdfg_node.sdfg.arrays[e.src.data].transient)
                ]
                external_edges_exit = [
                    e for e in subgraph.edges()
                    if (isinstance(e.dst, nodes.AccessNode)
                        and not nsdfg_node.sdfg.arrays[e.dst.data].transient)
                ]

            # Map external edges to outer memlets
            edge_to_outer = {}
            for edge in external_edges_entry:
                if self.expr_index == 0:
                    # Subgraphs use the corresponding outer map edges
                    path = state.memlet_path(edge)
                    eindex = path.index(edge)
                    edge_to_outer[edge] = path[eindex - 1]
                else:
                    # Nested SDFGs use the internal map edges of the node
                    outer_edge = next(e for e in graph.in_edges(nsdfg_node)
                                      if e.dst_conn == edge.src.data)
                    edge_to_outer[edge] = outer_edge

            for edge in external_edges_exit:
                if self.expr_index == 0:
                    path = state.memlet_path(edge)
                    eindex = path.index(edge)
                    edge_to_outer[edge] = path[eindex + 1]
                else:
                    # Nested SDFGs use the internal map edges of the node
                    outer_edge = next(e for e in graph.out_edges(nsdfg_node)
                                      if e.src_conn == edge.dst.data)
                    edge_to_outer[edge] = outer_edge

            # Collect all border arrays and code->code edges
            arrays = MapFission._border_arrays(
                nsdfg_node.sdfg if self.expr_index == 1 else sdfg, state,
                subgraph)
            scalars = defaultdict(list)
            for _, component_out in components:
                for e in subgraph.out_edges(component_out):
                    if isinstance(e.dst, nodes.CodeNode):
                        scalars[e.data.data].append(e)

            # Create new arrays for scalars
            for scalar, edges in scalars.items():
                desc = parent.arrays[scalar]
                del parent.arrays[scalar]
                name, newdesc = parent.add_transient(
                    scalar,
                    mapsize,
                    desc.dtype,
                    desc.storage,
                    lifetime=desc.lifetime,
                    debuginfo=desc.debuginfo,
                    allow_conflicts=desc.allow_conflicts,
                    find_new_name=True)

                # Add extra nodes in component boundaries
                for edge in edges:
                    anode = state.add_access(name)
                    sbs = subsets.Range.from_string(','.join(outer_map.params))
                    # Offset memlet by map range begin (to fit the transient)
                    sbs.offset([r[0] for r in outer_map.range], True)
                    state.add_edge(
                        edge.src, edge.src_conn, anode, None,
                        mm.Memlet.simple(
                            name,
                            sbs,
                            num_accesses=outer_map.range.num_elements()))
                    state.add_edge(
                        anode, None, edge.dst, edge.dst_conn,
                        mm.Memlet.simple(
                            name,
                            sbs,
                            num_accesses=outer_map.range.num_elements()))
                    state.remove_edge(edge)

            # Add extra maps around components
            new_map_entries = []
            for component_in, component_out in components:
                me, mx = state.add_map(outer_map.label + '_fission',
                                       [(p, '0:1') for p in outer_map.params],
                                       outer_map.schedule,
                                       unroll=outer_map.unroll,
                                       debuginfo=outer_map.debuginfo)

                # Add dynamic input connectors
                for conn in map_entry.in_connectors:
                    if not conn.startswith('IN_'):
                        me.add_in_connector(conn)

                me.map.range = dcpy(outer_map.range)
                new_map_entries.append(me)

                # Reconnect edges through new map
                for e in state.in_edges(component_in):
                    state.add_edge(me, None, e.dst, e.dst_conn, dcpy(e.data))
                    # Reconnect inner edges at source directly to external nodes
                    if self.expr_index == 0 and e in external_edges_entry:
                        state.add_edge(edge_to_outer[e].src,
                                       edge_to_outer[e].src_conn, me, None,
                                       dcpy(edge_to_outer[e].data))
                    else:
                        state.add_edge(e.src, e.src_conn, me, None,
                                       dcpy(e.data))
                    state.remove_edge(e)
                # Empty memlet edge in nested SDFGs
                if state.in_degree(component_in) == 0:
                    state.add_edge(me, None, component_in, None, mm.Memlet())

                for e in state.out_edges(component_out):
                    state.add_edge(e.src, e.src_conn, mx, None, dcpy(e.data))
                    # Reconnect inner edges at sink directly to external nodes
                    if self.expr_index == 0 and e in external_edges_exit:
                        state.add_edge(mx, None, edge_to_outer[e].dst,
                                       edge_to_outer[e].dst_conn,
                                       dcpy(edge_to_outer[e].data))
                    else:
                        state.add_edge(mx, None, e.dst, e.dst_conn,
                                       dcpy(e.data))
                    state.remove_edge(e)
                # Empty memlet edge in nested SDFGs
                if state.out_degree(component_out) == 0:
                    state.add_edge(component_out, None, mx, None, mm.Memlet())
            # Connect other sources/sinks not in components (access nodes)
            # directly to external nodes
            if self.expr_index == 0:
                for node in sources:
                    if isinstance(node, nodes.AccessNode):
                        for edge in state.in_edges(node):
                            outer_edge = edge_to_outer[edge]
                            memlet = dcpy(edge.data)
                            memlet.subset = subsets.Range(
                                outer_map.range.ranges + memlet.subset.ranges)
                            state.add_edge(outer_edge.src, outer_edge.src_conn,
                                           edge.dst, edge.dst_conn, memlet)

                for node in sinks:
                    if isinstance(node, nodes.AccessNode):
                        for edge in state.out_edges(node):
                            outer_edge = edge_to_outer[edge]
                            state.add_edge(edge.src, edge.src_conn,
                                           outer_edge.dst, outer_edge.dst_conn,
                                           dcpy(outer_edge.data))

            # Augment arrays by prepending map dimensions
            for array in arrays:
                if array in modified_arrays:
                    continue
                desc = parent.arrays[array]
                for sz in reversed(mapsize):
                    desc.strides = [desc.total_size] + list(desc.strides)
                    desc.total_size = desc.total_size * sz

                desc.shape = mapsize + list(desc.shape)
                desc.offset = [0] * len(mapsize) + list(desc.offset)
                modified_arrays.add(array)

            # Fill scope connectors so that memlets can be tracked below
            state.fill_scope_connectors()

            # Correct connectors and memlets in nested SDFGs to account for
            # missing outside map
            if self.expr_index == 1:
                to_correct = ([(e, e.src) for e in external_edges_entry] +
                              [(e, e.dst) for e in external_edges_exit])
                corrected_nodes = set()
                for edge, node in to_correct:
                    if isinstance(node, nodes.AccessNode):
                        if node in corrected_nodes:
                            continue
                        corrected_nodes.add(node)

                        outer_edge = edge_to_outer[edge]
                        desc = parent.arrays[node.data]

                        # Modify shape of internal array to match outer one
                        outer_desc = sdfg.arrays[outer_edge.data.data]
                        if not isinstance(desc, dt.Scalar):
                            desc.shape = outer_desc.shape
                        if isinstance(desc, dt.Array):
                            desc.strides = outer_desc.strides
                            desc.total_size = outer_desc.total_size

                        # Inside the nested SDFG, offset all memlets to include
                        # the offsets from within the map.
                        # NOTE: Relies on propagation to fix outer memlets
                        for internal_edge in state.all_edges(node):
                            for e in state.memlet_tree(internal_edge):
                                e.data.subset.offset(desc.offset, False)
                                e.data.subset = helpers.unsqueeze_memlet(
                                    e.data, outer_edge.data).subset

                        # Only after offsetting memlets we can modify the
                        # overall offset
                        if isinstance(desc, dt.Array):
                            desc.offset = outer_desc.offset

            # Fill in memlet trees for border transients
            # NOTE: Memlet propagation should run to correct the outer edges
            for node in subgraph.nodes():
                if isinstance(node, nodes.AccessNode) and node.data in arrays:
                    for edge in state.all_edges(node):
                        for e in state.memlet_tree(edge):
                            # Prepend map dimensions to memlet
                            e.data.subset = subsets.Range(
                                [(pystr_to_symbolic(d) - r[0],
                                  pystr_to_symbolic(d) - r[0], 1) for d, r in
                                 zip(outer_map.params, outer_map.range)] +
                                e.data.subset.ranges)

        # If nested SDFG, reconnect nodes around map and modify memlets
        if self.expr_index == 1:
            for edge in graph.in_edges(map_entry):
                if not edge.dst_conn or not edge.dst_conn.startswith('IN_'):
                    continue

                # Modify edge coming into nested SDFG to include entire array
                desc = sdfg.arrays[edge.data.data]
                edge.data.subset = subsets.Range.from_array(desc)
                edge.data.num_accesses = edge.data.subset.num_elements()

                # Find matching edge inside map
                inner_edge = next(
                    e for e in graph.out_edges(map_entry)
                    if e.src_conn and e.src_conn[4:] == edge.dst_conn[3:])
                graph.add_edge(edge.src, edge.src_conn, nsdfg_node,
                               inner_edge.dst_conn, dcpy(edge.data))

            for edge in graph.out_edges(map_exit):
                # Modify edge coming out of nested SDFG to include entire array
                desc = sdfg.arrays[edge.data.data]
                edge.data.subset = subsets.Range.from_array(desc)

                # Find matching edge inside map
                inner_edge = next(e for e in graph.in_edges(map_exit)
                                  if e.dst_conn[3:] == edge.src_conn[4:])
                graph.add_edge(nsdfg_node, inner_edge.src_conn, edge.dst,
                               edge.dst_conn, dcpy(edge.data))

        # Remove outer map
        graph.remove_nodes_from([map_entry, map_exit])