Exemple #1
0
def offset_map(state, map_entry):
    """
    Shift a map's iteration range so that every dimension starts at zero.

    For every map parameter whose range minimum is nonzero, occurrences of
    that parameter inside the map's scope subgraph are replaced by
    ``param+minimum``. The map range is then offset by the same amounts with
    ``negative=True``. The intent is that the scope keeps accessing the same
    elements while the iteration space itself begins at zero.

    :param state: The state containing the map (must provide
                  ``scope_subgraph``).
    :param map_entry: Entry node of the map to offset; its range is modified
                      in place.
    """
    offsets = []
    subgraph = state.scope_subgraph(map_entry)
    for param, range_min in zip(map_entry.map.params,
                                map_entry.map.range.min_element()):
        if range_min != 0:
            offsets.append(range_min)
            # Compensate inside the scope for the range shift applied below.
            replace(subgraph, str(param), f'{param}+{range_min}')
        else:
            offsets.append(0)

    map_entry.map.range.offset(offsets, negative=True)
Exemple #2
0
    def apply(self, sdfg):
        """
            Apply the map fusion transformation.

            Fuses the first map (whose exit node is matched as
            ``MapFusion._first_map_exit``) with the second map (whose entry
            node is matched as ``MapFusion._second_map_entry``) in state
            ``self.state_id`` of ``sdfg``. Other than the removal of the
            second map entry node (SME) and the first map exit node (FME), it
            has the following side effects:

            1.  Any transient adjacent to both FME and SME with degree = 2
                will be removed. The tasklets that use/produce it shall be
                connected directly with a scalar/new transient (if the
                dataflow is more than a single scalar).

            2.  If this transient is adjacent to FME and SME and has other
                uses, it will be adjacent to the new map exit post fusion.
                Tasklet -> Tasklet edges will ALSO be added as mentioned above.

            3.  If an access node is adjacent to FME but not SME, it will be
                adjacent to the new map exit post fusion.

            4.  If an access node is adjacent to SME but not FME, it will be
                adjacent to the new map entry node post fusion.

            Finally, the second map's exit node is retargeted to the first
            map (``second_exit.map = first_entry.map``).

            :param sdfg: The SDFG containing the matched state; it is
                         modified in place.
        """
        # Resolve the matched nodes from the node indices in self.subgraph.
        graph = sdfg.nodes()[self.state_id]
        first_exit = graph.nodes()[self.subgraph[MapFusion._first_map_exit]]
        first_entry = graph.entry_node(first_exit)
        second_entry = graph.nodes()[self.subgraph[
            MapFusion._second_map_entry]]
        second_exit = graph.exit_nodes(second_entry)[0]

        # Every destination of the first map's exit is an intermediate access
        # node between the two maps.
        intermediate_nodes = set()
        for _, _, dst, _, _ in graph.out_edges(first_exit):
            intermediate_nodes.add(dst)
            assert isinstance(dst, nodes.AccessNode)

        # Check if an access node refers to non transient memory, or transient
        # is used at another location (cannot erase)
        do_not_erase = set()
        for node in intermediate_nodes:
            if sdfg.arrays[node.data].transient is False:
                do_not_erase.add(node)
            else:
                # Keep the node if anything other than the first map writes it
                for edge in graph.in_edges(node):
                    if edge.src != first_exit:
                        do_not_erase.add(node)
                        break
                else:
                    # (for-else: reached only when no foreign writer was
                    # found) ... or if anything other than the second map
                    # reads it.
                    for edge in graph.out_edges(node):
                        if edge.dst != second_entry:
                            do_not_erase.add(node)
                            break

        # Find permutation between first and second scopes
        # params_dict: first-map parameter -> second-map parameter at the
        # permuted position.
        perm = MapFusion.find_permutation(first_entry.map, second_entry.map)
        params_dict = {}
        for index, param in enumerate(first_entry.map.params):
            params_dict[param] = second_entry.map.params[perm[index]]

        # Replaces (in memlets and tasklet) the second scope map
        # indices with the permuted first map indices.
        # This works in two passes to avoid problems when e.g., exchanging two
        # parameters (instead of replacing (j,i) and (i,j) to (j,j) and then
        # i,i).
        second_scope = graph.scope_subgraph(second_entry)
        for firstp, secondp in params_dict.items():
            if firstp != secondp:
                replace(second_scope, secondp, '__' + secondp + '_fused')
        for firstp, secondp in params_dict.items():
            if firstp != secondp:
                replace(second_scope, '__' + secondp + '_fused', firstp)

        # Isolate First exit node
        ############################
        # Edges/nodes are only collected here and removed after the loop, so
        # the graph is not mutated while its in-edges are being iterated.
        edges_to_remove = set()
        nodes_to_remove = set()
        for edge in graph.in_edges(first_exit):
            # Locate this edge within its full memlet path; the last
            # destination of the path is the access node written by the
            # first map.
            memlet_path = graph.memlet_path(edge)
            edge_index = next(i for i, e in enumerate(memlet_path)
                              if e == edge)
            access_node = memlet_path[-1].dst
            if access_node not in do_not_erase:
                out_edges = [
                    e for e in graph.out_edges(access_node)
                    if e.dst == second_entry
                ]
                # In this transformation, there can only be one edge to the
                # second map
                assert len(out_edges) == 1
                # Get source connector to the second map
                # ([3:] drops a 3-character prefix, presumably "IN_")
                connector = out_edges[0].dst_conn[3:]

                new_dst = None
                new_dst_conn = None
                # Look at the second map entry out-edges to get the new
                # destination ([4:] drops a 4-character prefix, presumably
                # "OUT_", so it can be compared with the stripped in-connector)
                for _e in graph.out_edges(second_entry):
                    if _e.src_conn[4:] == connector:
                        new_dst = _e.dst
                        new_dst_conn = _e.dst_conn
                        break
                if new_dst is None:
                    # Access node is not used in the second map
                    nodes_to_remove.add(access_node)
                    continue
                # If the source is an access node, modify the memlet to point
                # to it
                # NOTE(review): this branch only rewrites edge.data; the edge
                # is added to edges_to_remove below and no edge to new_dst is
                # created here — confirm the connection is made elsewhere.
                if (isinstance(edge.src, nodes.AccessNode)
                        and edge.data.data != edge.src.data):
                    edge.data.data = edge.src.data
                    edge.data.subset = ("0" if edge.data.other_subset is None
                                        else edge.data.other_subset)
                    edge.data.other_subset = None

                else:
                    # Add a transient scalar/array
                    self.fuse_nodes(sdfg, graph, edge, new_dst, new_dst_conn)

                edges_to_remove.add(edge)

                # Remove transient node between the two maps
                nodes_to_remove.add(access_node)
            else:  # The case where intermediate array node cannot be removed
                # Node will become an output of the second map exit
                # (out_e is the edge following `edge` on its memlet path,
                # i.e. the one leaving the first map exit)
                out_e = memlet_path[edge_index + 1]
                conn = second_exit.next_connector()
                graph.add_edge(
                    second_exit,
                    'OUT_' + conn,
                    out_e.dst,
                    out_e.dst_conn,
                    dcpy(out_e.data),
                )
                second_exit.add_out_connector('OUT_' + conn)

                graph.add_edge(edge.src, edge.src_conn, second_exit,
                               'IN_' + conn, dcpy(edge.data))
                second_exit.add_in_connector('IN_' + conn)

                edges_to_remove.add(out_e)

                # If the second map needs this node, link the connector
                # that generated this to the place where it is needed, with a
                # temp transient/scalar for memlet to be generated
                # (the name out_e is reused; the previous out_e has already
                # been recorded in edges_to_remove)
                for out_e in graph.out_edges(second_entry):
                    second_memlet_path = graph.memlet_path(out_e)
                    source_node = second_memlet_path[0].src
                    if source_node == access_node:
                        self.fuse_nodes(sdfg, graph, edge, out_e.dst,
                                        out_e.dst_conn)

                edges_to_remove.add(edge)
        ###
        # First scope exit is isolated and can now be safely removed
        for e in edges_to_remove:
            graph.remove_edge(e)
        graph.remove_nodes_from(nodes_to_remove)
        graph.remove_node(first_exit)

        # Isolate second_entry node
        ###########################
        for edge in graph.in_edges(second_entry):
            memlet_path = graph.memlet_path(edge)
            edge_index = next(i for i, e in enumerate(memlet_path)
                              if e == edge)
            access_node = memlet_path[0].src
            if access_node in intermediate_nodes:
                # Already handled above, can be safely removed
                graph.remove_edge(edge)
                continue

            # This is an external input to the second map which will now go
            # through the first map.
            # Route the outer edge into a fresh IN_/OUT_ connector pair on the
            # first map entry, then reconnect the inner edge (the one after
            # `edge` on the path) from the matching OUT_ connector.
            conn = first_entry.next_connector()
            graph.add_edge(edge.src, edge.src_conn, first_entry, 'IN_' + conn,
                           dcpy(edge.data))
            first_entry.add_in_connector('IN_' + conn)
            graph.remove_edge(edge)
            out_e = memlet_path[edge_index + 1]
            graph.add_edge(
                first_entry,
                'OUT_' + conn,
                out_e.dst,
                out_e.dst_conn,
                dcpy(out_e.data),
            )
            first_entry.add_out_connector('OUT_' + conn)

            graph.remove_edge(out_e)
        ###
        # Second node is isolated and can now be safely removed
        graph.remove_node(second_entry)

        # Fix scope exit to point to the right map
        second_exit.map = first_entry.map
Exemple #3
0
    def expand(self, sdfg, graph, map_entries, map_base_variables=None):
        """
        Expand each map in a set into an outer and an inner map.

        The resulting outer maps all have the same range and iteration
        variables; variables and memlets inside each map scope are renamed
        accordingly. Each inner map contains the leftover dimensions.

        :param sdfg: Underlying SDFG
        :param graph: Graph (state) in which we expand
        :param map_entries: List of map entries (type MapEntry) to expand
        :param map_base_variables: Optional list of strings.
                                   If None (or empty), expand() searches for
                                   the maximal set of equal map ranges and
                                   pushes those ranges and their loop
                                   variables into the outer map.
                                   If specified, expand() pushes the ranges
                                   belonging to the given iteration variables
                                   into the outer map. (For instance
                                   map_base_variables = ['i', 'j'] assumes
                                   that all maps have common iteration
                                   indices i and j with identical ranges;
                                   this is checked with assertions.)
        """
        # NOTE: never name a local ``map`` here -- the builtin ``map`` is
        # needed further below when handling dynamic map inputs.
        maps = [entry.map for entry in map_entries]

        if not map_base_variables:
            # Find the maximal subset of variables to expand
            # (greedy if there exist multiple ranges that are equal in a map).
            map_base_ranges = helpers.common_map_base_ranges(maps)
            reassignments = helpers.find_reassignment(maps, map_base_ranges)

            # First, regroup and reassign. The outer iteration variable names
            # are taken from the first map's parameters at the common ranges.
            map_base_variables = []
            for rng in map_base_ranges:
                for i in range(len(maps[0].params)):
                    if maps[0].range[i] == rng and maps[0].params[
                            i] not in map_base_variables:
                        map_base_variables.append(maps[0].params[i])
                        break

            params_dict = {}
            if self.debug:
                print("MultiExpansion::Map_base_variables:", map_base_variables)
                print("MultiExpansion::Map_base_ranges:", map_base_ranges)
            for current_map in maps:
                # Start from the identity assignment; every parameter with a
                # reassignment index != -1 must end up mapped to
                # map_base_variables[index]. The mapping is kept a
                # permutation (injective) at every step.
                params_dict_map = {param: param for param in current_map.params}
                for i, reassignment in enumerate(reassignments[current_map]):
                    if reassignment == -1:
                        # Not a common dimension, nothing to do
                        continue
                    current_var = current_map.params[i]
                    current_assignment = params_dict_map[current_var]
                    target_assignment = map_base_variables[reassignment]
                    if current_assignment == target_assignment:
                        continue
                    if target_assignment in params_dict_map.values():
                        # Another parameter is already mapped to the target
                        # name: swap the two *assigned values*. (Swapping the
                        # keys instead is only correct while the mapping is
                        # still the identity.)
                        other_var = next(
                            key for key, value in params_dict_map.items()
                            if value == target_assignment)
                        params_dict_map[current_var] = target_assignment
                        params_dict_map[other_var] = current_assignment
                    else:
                        # Just reassign
                        params_dict_map[current_var] = target_assignment

                # Done, assign params_dict_map to the global one
                params_dict[current_map] = params_dict_map

            for current_map, map_entry in zip(maps, map_entries):
                map_scope = graph.scope_subgraph(map_entry)
                params_dict_map = params_dict[current_map]
                # Two passes through temporary names, so that swapped
                # parameters (e.g. i <-> j) are not collapsed onto one name.
                for firstp, secondp in params_dict_map.items():
                    if firstp != secondp:
                        replace(map_scope, firstp, '__' + firstp + '_fused')
                for firstp, secondp in params_dict_map.items():
                    if firstp != secondp:
                        replace(map_scope, '__' + firstp + '_fused', secondp)

                # Now also rename the iteration variables of the map itself
                for i, param in enumerate(current_map.params):
                    current_map.params[i] = params_dict_map[param]

            if self.debug:
                print("MultiExpansion::Params replaced")

        else:
            # Outer variables were given explicitly: look up their ranges in
            # the first map and check that every map agrees on them.
            map_base_ranges = []

            map0 = maps[0]
            for var in map_base_variables:
                index = map0.params.index(var)
                map_base_ranges.append(map0.range[index])

            for current_map in maps:
                for var, rng in zip(map_base_variables, map_base_ranges):
                    assert current_map.range[current_map.params.index(
                        var)] == rng

        # Then expand all the maps
        for current_map, map_entry in zip(maps, map_entries):
            if current_map.get_param_num() == len(map_base_variables):
                # Nothing to expand, continue
                continue

            map_exit = graph.exit_node(map_entry)
            # Create two new maps, outer and inner. The outer parameter list
            # is copied so that the outer maps do not share one list object
            # with each other or with the caller's map_base_variables.
            params_outer = list(map_base_variables)
            ranges_outer = map_base_ranges

            params_inner = []
            init_ranges_inner = []
            for param, rng in zip(current_map.params, current_map.range):
                if param not in map_base_variables:
                    params_inner.append(param)
                    init_ranges_inner.append(rng)

            ranges_inner = subsets.Range(init_ranges_inner)
            inner_schedule = (dtypes.ScheduleType.Sequential
                              if self.sequential_innermaps else
                              dtypes.ScheduleType.Default)
            inner_map = nodes.Map(label=current_map.label + '_inner',
                                  params=params_inner,
                                  ndrange=ranges_inner,
                                  schedule=inner_schedule)

            current_map.label = current_map.label + '_outer'
            current_map.params = params_outer
            current_map.range = ranges_outer

            # Create new map entries and exits
            map_entry_inner = nodes.MapEntry(inner_map)
            map_exit_inner = nodes.MapExit(inner_map)

            # Analogously to MapExpansion: route every edge leaving the outer
            # entry through the new inner entry.
            for edge in graph.out_edges(map_entry):
                graph.remove_edge(edge)
                graph.add_memlet_path(map_entry,
                                      map_entry_inner,
                                      edge.dst,
                                      src_conn=edge.src_conn,
                                      memlet=edge.data,
                                      dst_conn=edge.dst_conn)

            dynamic_edges = dynamic_map_inputs(graph, map_entry)
            for edge in dynamic_edges:
                # Remove old edge and connector
                graph.remove_edge(edge)
                edge.dst._in_connectors.remove(edge.dst_conn)

                # Propagate to each range it belongs to
                path = []
                for mapnode in [map_entry, map_entry_inner]:
                    path.append(mapnode)
                    # ``map`` is the builtin here (no local shadows it).
                    if any(edge.dst_conn in map(str, symbolic.symlist(r))
                           for r in mapnode.map.range):
                        graph.add_memlet_path(edge.src,
                                              *path,
                                              memlet=edge.data,
                                              src_conn=edge.src_conn,
                                              dst_conn=edge.dst_conn)

            # Route every edge entering the outer exit through the new inner
            # exit.
            for edge in graph.in_edges(map_exit):
                graph.remove_edge(edge)
                graph.add_memlet_path(edge.src,
                                      map_exit_inner,
                                      map_exit,
                                      memlet=edge.data,
                                      src_conn=edge.src_conn,
                                      dst_conn=edge.dst_conn)
Exemple #4
0
    def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG):
        """
        Double-buffer the data read by the (one-dimensional) map
        ``self.map_entry``.

        Steps performed:

        1. The map range end is reduced by one stride, so the loop does one
           fewer iteration and a final state can process the last buffer.
        2. The data of every access node fed directly by the map entry
           (presumably transients) gets a new leading dimension of size 2,
           and memlets touching it are rewritten through
           ``self._modify_memlet`` to index that dimension with the
           placeholder symbol ``__dace_db_param``.
        3. The map is converted into a for-loop with ``MapToForLoop``. In the
           resulting nested SDFG:

           * the initial state performs the first reads (map parameter set to
             the range start) into buffer ``(start / stride) % 2``;
           * a final state receives a copy of the loop body that computes on
             buffer ``(i / stride) % 2``;
           * the loop body computes on buffer ``(i / stride) % 2`` and
             prefetches iteration ``i + stride`` into buffer
             ``((i / stride) + 1) % 2``.

        :param graph: The state containing ``self.map_entry``.
        :param sdfg: The SDFG containing ``graph``.
        :return: The nested SDFG node produced by ``MapToForLoop``.
        """
        map_entry = self.map_entry

        map_param = map_entry.map.params[0]  # Assuming one dimensional

        ##############################
        # Change condition of loop to one fewer iteration (so that the
        # final one reads from the last buffer)
        map_rstart, map_rend, map_rstride = map_entry.map.range[0]
        map_rend = symbolic.pystr_to_symbolic('(%s) - (%s)' %
                                              (map_rend, map_rstride))
        map_entry.map.range = subsets.Range([(map_rstart, map_rend,
                                              map_rstride)])

        ##############################
        # Gather transients to modify
        # (the data names of all access nodes fed directly by the map entry)
        transients_to_modify = set(edge.dst.data
                                   for edge in graph.out_edges(map_entry)
                                   if isinstance(edge.dst, nodes.AccessNode))

        # Add dimension to transients and modify memlets
        # The new outermost dimension (size 2) selects the buffer; its stride
        # is the old total size, so buffer 1 starts right after buffer 0.
        for transient in transients_to_modify:
            desc: data.Array = sdfg.arrays[transient]
            # Using non-python syntax to ensure properties change
            desc.strides = [desc.total_size] + list(desc.strides)
            desc.shape = [2] + list(desc.shape)
            desc.offset = [0] + list(desc.offset)
            desc.total_size = desc.total_size * 2

        ##############################
        # Modify memlets to use map parameter as buffer index
        # NOTE(review): modified_subsets is filled but never read in this
        # method.
        modified_subsets = []  # Store modified memlets for final state
        for edge in graph.scope_subgraph(map_entry).edges():
            if edge.data.data in transients_to_modify:
                edge.data.subset = self._modify_memlet(sdfg, edge.data.subset,
                                                       edge.data.data)
                modified_subsets.append(edge.data.subset)
            else:  # Could be other_subset
                path = graph.memlet_path(edge)
                src_node = path[0].src
                dst_node = path[-1].dst

                # other_subset could be None. In that case, recreate from array
                dataname = None
                if (isinstance(src_node, nodes.AccessNode)
                        and src_node.data in transients_to_modify):
                    dataname = src_node.data
                elif (isinstance(dst_node, nodes.AccessNode)
                      and dst_node.data in transients_to_modify):
                    dataname = dst_node.data
                if dataname is not None:
                    subset = (edge.data.other_subset or
                              subsets.Range.from_array(sdfg.arrays[dataname]))
                    edge.data.other_subset = self._modify_memlet(
                        sdfg, subset, dataname)
                    modified_subsets.append(edge.data.other_subset)

        ##############################
        # Turn map into for loop
        map_to_for = MapToForLoop()
        map_to_for.setup_match(
            sdfg, self.sdfg_id, self.state_id,
            {MapToForLoop.map_entry: graph.node_id(self.map_entry)},
            self.expr_index)
        nsdfg_node, nstate = map_to_for.apply(graph, sdfg)

        ##############################
        # Gather node copies and remove memlets
        # Edges from source nodes into the double-buffered data are taken out
        # of the loop body; they are re-created below in the initial state
        # and (shifted by one stride) as prefetches in the loop body.
        edges_to_replace = []
        for node in nstate.source_nodes():
            for edge in nstate.out_edges(node):
                if (isinstance(edge.dst, nodes.AccessNode)
                        and edge.dst.data in transients_to_modify):
                    edges_to_replace.append(edge)
                    nstate.remove_edge(edge)
            if nstate.out_degree(node) == 0:
                nstate.remove_node(node)

        ##############################
        # Add initial reads to initial nested state
        # (the original source node objects are moved into this state)
        initial_state: sd.SDFGState = nsdfg_node.sdfg.start_state
        initial_state.set_label('%s_init' % map_entry.map.label)
        for edge in edges_to_replace:
            initial_state.add_node(edge.src)
            rnode = edge.src
            wnode = initial_state.add_write(edge.dst.data)
            initial_state.add_edge(rnode, edge.src_conn, wnode, edge.dst_conn,
                                   copy.deepcopy(edge.data))

        # All instances of the map parameter in this state become the loop start
        sd.replace(initial_state, map_param, map_rstart)
        # Initial writes go to the appropriate buffer
        init_expr = symbolic.pystr_to_symbolic('(%s / %s) %% 2' %
                                               (map_rstart, map_rstride))
        sd.replace(initial_state, '__dace_db_param', init_expr)

        ##############################
        # Modify main state's memlets

        # Divide by loop stride
        new_expr = symbolic.pystr_to_symbolic('(%s / %s) %% 2' %
                                              (map_param, map_rstride))
        sd.replace(nstate, '__dace_db_param', new_expr)

        ##############################
        # Add the main state's contents to the last state, modifying
        # memlets appropriately.
        # NOTE(review): assumes the nested SDFG's first sink state is the
        # state executed after the loop — confirm against MapToForLoop.
        final_state: sd.SDFGState = nsdfg_node.sdfg.sink_nodes()[0]
        final_state.set_label('%s_final_computation' % map_entry.map.label)
        dup_nstate = copy.deepcopy(nstate)
        final_state.add_nodes_from(dup_nstate.nodes())
        for e in dup_nstate.edges():
            final_state.add_edge(e.src, e.src_conn, e.dst, e.dst_conn, e.data)

        # If there is a WCR output with transient, only output in last state
        nstate: sd.SDFGState
        for node in nstate.sink_nodes():
            for e in list(nstate.in_edges(node)):
                if e.data.wcr is not None:
                    path = nstate.memlet_path(e)
                    if isinstance(path[0].src, nodes.AccessNode):
                        nstate.remove_memlet_path(e)

        ##############################
        # Add reads into next buffers to main state
        # (same reads as the initial state, but for iteration
        # map_param + stride)
        for edge in edges_to_replace:
            rnode = copy.deepcopy(edge.src)
            nstate.add_node(rnode)
            wnode = nstate.add_write(edge.dst.data)
            new_memlet = copy.deepcopy(edge.data)
            if new_memlet.data in transients_to_modify:
                new_memlet.other_subset = self._replace_in_subset(
                    new_memlet.other_subset, map_param,
                    '(%s + %s)' % (map_param, map_rstride))
            else:
                new_memlet.subset = self._replace_in_subset(
                    new_memlet.subset, map_param,
                    '(%s + %s)' % (map_param, map_rstride))

            nstate.add_edge(rnode, edge.src_conn, wnode, edge.dst_conn,
                            new_memlet)

        nstate.set_label('%s_double_buffered' % map_entry.map.label)
        # Divide by loop stride
        # Presumably only the prefetch memlets added just above still contain
        # the placeholder at this point (the rest of nstate was substituted
        # earlier), so they target the *next* buffer.
        new_expr = symbolic.pystr_to_symbolic('((%s / %s) + 1) %% 2' %
                                              (map_param, map_rstride))
        sd.replace(nstate, '__dace_db_param', new_expr)

        # Remove symbol once done
        del nsdfg_node.sdfg.symbols['__dace_db_param']
        del nsdfg_node.symbol_mapping['__dace_db_param']

        return nsdfg_node
Exemple #5
0
    def apply(self, sdfg):
        """
            This method applies the mapfusion transformation. 
            Other than the removal of the second map entry node (SME), and the first
            map exit (FME) node, it has the following side effects:

            1.  Any transient adjacent to both FME and SME with degree = 2 will be removed. 
                The tasklets that use/produce it shall be connected directly with a 
                scalar/new transient (if the dataflow is more than a single scalar)

            2.  If this transient is adjacent to FME and SME and has other
                uses, it will be adjacent to the new map exit post fusion.
                Tasklet-> Tasklet edges will ALSO be added as mentioned above.

            3.  If an access node is adjacent to FME but not SME, it will be
                adjacent to new map exit post fusion.

            4.  If an access node is adjacent to SME but not FME, it will be
                adjacent to the new map entry node post fusion.

        """
        graph = sdfg.nodes()[self.state_id]
        first_exit = graph.nodes()[self.subgraph[MapFusion._first_map_exit]]
        first_entry = graph.entry_node(first_exit)
        second_entry = graph.nodes()[self.subgraph[
            MapFusion._second_map_entry]]
        second_exit = graph.exit_nodes(second_entry)[0]

        intermediate_nodes = set()
        for _, _, dst, _, _ in graph.out_edges(first_exit):
            intermediate_nodes.add(dst)
            assert isinstance(dst, nodes.AccessNode)

        # Check if an access node refers to non transient memory, or transient
        # is used at another location (cannot erase)
        do_not_erase = set()
        for node in intermediate_nodes:
            if sdfg.arrays[node.data].transient is False:
                do_not_erase.add(node)
            else:
                # If array is used anywhere else in this state.
                num_occurrences = len([
                    n for n in graph.nodes()
                    if isinstance(n, nodes.AccessNode) and n.data == node.data
                ])
                if num_occurrences > 1:
                    return False

                for edge in graph.in_edges(node):
                    if edge.src != first_exit:
                        do_not_erase.add(node)
                        break
                else:
                    for edge in graph.out_edges(node):
                        if edge.dst != second_entry:
                            do_not_erase.add(node)
                            break

        # Find permutation between first and second scopes
        if first_entry.map.params != second_entry.map.params:
            perm = MapFusion.find_permutation(first_entry.map,
                                              second_entry.map)
            params_dict = {}
            for _index, _param in enumerate(first_entry.map.params):
                params_dict[_param] = second_entry.map.params[perm[_index]]

            # Hopefully replaces (in memlets and tasklet) the second scope map
            # indices with the permuted first map indices
            second_scope = graph.scope_subgraph(second_entry)
            for _firstp, _secondp in params_dict.items():
                replace(second_scope, _secondp, _firstp)

        ########Isolate First MapExit node###########
        for _edge in graph.in_edges(first_exit):
            __some_str = _edge.data.data
            _access_node = graph.find_node(__some_str)
            # all outputs of first_exit are in intermediate_nodes set, so all inputs to
            # first_exit should also be!
            if _access_node not in do_not_erase:
                _new_dst = None
                _new_dst_conn = None
                # look at the second map entry out-edges to get the new destination
                for _e in graph.out_edges(second_entry):
                    if _e.data.data == _access_node.data:
                        _new_dst = _e.dst
                        _new_dst_conn = _e.dst_conn
                        break
                if _new_dst is None:
                    # Access node is not even used in the second map
                    graph.remove_node(_access_node)
                    continue
                if _edge.data.data == _access_node.data and isinstance(
                        _edge._src, nodes.AccessNode):
                    _edge.data.data = _edge._src.data
                    _edge.data.subset = "0"
                    graph.add_edge(
                        _edge._src,
                        _edge.src_conn,
                        _new_dst,
                        _new_dst_conn,
                        dcpy(_edge.data),
                    )
                else:
                    if _edge.data.subset.num_elements() == 1:
                        # We will add a scalar
                        local_name = "__s%d_n%d%s_n%d%s" % (
                            self.state_id,
                            graph.node_id(_edge._src),
                            _edge.src_conn,
                            graph.node_id(_edge._dst),
                            _edge.dst_conn,
                        )
                        local_node = sdfg.add_scalar(
                            local_name,
                            dtype=_access_node.desc(graph).dtype,
                            toplevel=False,
                            transient=True,
                            storage=dtypes.StorageType.Register,
                        )
                        _edge.data.data = (
                            local_name)  # graph.add_access(local_name).data
                        _edge.data.subset = "0"
                        graph.add_edge(
                            _edge._src,
                            _edge.src_conn,
                            _new_dst,
                            _new_dst_conn,
                            dcpy(_edge.data),
                        )
                    else:
                        # We will add a transient of size = memlet subset
                        # size
                        local_name = "__s%d_n%d%s_n%d%s" % (
                            self.state_id,
                            graph.node_id(_edge._src),
                            _edge.src_conn,
                            graph.node_id(_edge._dst),
                            _edge.dst_conn,
                        )
                        local_node = graph.add_transient(
                            local_name,
                            _edge.data.subset.size(),
                            dtype=_access_node.desc(graph).dtype,
                            toplevel=False,
                        )
                        _edge.data.data = (
                            local_name)  # graph.add_access(local_name).data
                        _edge.data.subset = ",".join([
                            "0:" + str(_s) for _s in _edge.data.subset.size()
                        ])
                        graph.add_edge(
                            _edge._src,
                            _edge.src_conn,
                            local_node,
                            None,
                            dcpy(_edge.data),
                        )
                        graph.add_edge(local_node, None, _new_dst,
                                       _new_dst_conn, dcpy(_edge.data))
                graph.remove_edge(_edge)
                ####Isolate this node#####
                for _in_e in graph.in_edges(_access_node):
                    graph.remove_edge(_in_e)
                for _out_e in graph.out_edges(_access_node):
                    graph.remove_edge(_out_e)
                graph.remove_node(_access_node)
            else:
                # _access_node will become an output of the second map exit
                for _out_e in graph.out_edges(first_exit):
                    if _out_e.data.data == _access_node.data:
                        graph.add_edge(
                            second_exit,
                            None,
                            _out_e._dst,
                            _out_e.dst_conn,
                            dcpy(_out_e.data),
                        )

                        graph.remove_edge(_out_e)
                        break
                else:
                    raise AssertionError(
                        "No out-edge was found that leads to {}".format(
                            _access_node))
                graph.add_edge(_edge._src, _edge.src_conn, second_exit, None,
                               dcpy(_edge.data))
                ### If the second map needs this node then link the connector
                # that generated this to the place where it is needed, with a
                # temp transient/scalar for memlet to be generated
                for _out_e in graph.out_edges(second_entry):
                    if _out_e.data.data == _access_node.data:
                        if _edge.data.subset.num_elements() == 1:
                            # We will add a scalar
                            local_name = "__s%d_n%d%s_n%d%s" % (
                                self.state_id,
                                graph.node_id(_edge._src),
                                _edge.src_conn,
                                graph.node_id(_edge._dst),
                                _edge.dst_conn,
                            )
                            local_node = sdfg.add_scalar(
                                local_name,
                                dtype=_access_node.desc(graph).dtype,
                                storage=dtypes.StorageType.Register,
                                toplevel=False,
                                transient=True,
                            )
                            _edge.data.data = (
                                local_name
                            )  # graph.add_access(local_name).data
                            _edge.data.subset = "0"
                            graph.add_edge(
                                _edge._src,
                                _edge.src_conn,
                                _out_e._dst,
                                _out_e.dst_conn,
                                dcpy(_edge.data),
                            )
                        else:
                            # We will add a transient of size = memlet subset
                            # size
                            local_name = "__s%d_n%d%s_n%d%s" % (
                                self.state_id,
                                graph.node_id(_edge._src),
                                _edge.src_conn,
                                graph.node_id(_edge._dst),
                                _edge.dst_conn,
                            )
                            local_node = sdfg.add_transient(
                                local_name,
                                _edge.data.subset.size(),
                                dtype=_access_node.desc(graph).dtype,
                                toplevel=False,
                            )
                            _edge.data.data = (
                                local_name
                            )  # graph.add_access(local_name).data
                            _edge.data.subset = ",".join([
                                "0:" + str(_s)
                                for _s in _edge.data.subset.size()
                            ])
                            graph.add_edge(
                                _edge._src,
                                _edge.src_conn,
                                local_node,
                                None,
                                dcpy(_edge.data),
                            )
                            graph.add_edge(
                                local_node,
                                None,
                                _out_e._dst,
                                _out_e.dst_conn,
                                dcpy(_edge.data),
                            )
                        break
                graph.remove_edge(_edge)
        graph.remove_node(first_exit)  # Take a leap of faith

        #############Isolate second_entry node################
        for _edge in graph.in_edges(second_entry):
            _access_node = graph.find_node(_edge.data.data)
            if _access_node in intermediate_nodes:
                # Already handled above, just remove this
                graph.remove_edge(_edge)
                continue
            else:
                # This is an external input to the second map which will now go through the first
                # map.
                graph.add_edge(_edge._src, _edge.src_conn, first_entry, None,
                               dcpy(_edge.data))
                graph.remove_edge(_edge)
                for _out_e in graph.out_edges(second_entry):
                    if _out_e.data.data == _access_node.data:
                        graph.add_edge(
                            first_entry,
                            None,
                            _out_e._dst,
                            _out_e.dst_conn,
                            dcpy(_out_e.data),
                        )
                        graph.remove_edge(_out_e)
                        break
                else:
                    raise AssertionError(
                        "No out-edge was found that leads to {}".format(
                            _access_node))

        graph.remove_node(second_entry)

        # Fix scope exit
        second_exit.map = first_entry.map
        graph.fill_scope_connectors()