Example #1
    def coverage_dicts(sdfg, graph, map_entry, outer_range=True):
        '''
        Returns a tuple of two dicts:
        the first dict maps each data container entering the map
        to its associated access range;
        the second dict maps each data container exiting the map
        to its associated access range.
        If outer_range is True, the outer map ranges are substituted
        into the min/max of the inner access ranges.
        '''
        map_exit = graph.exit_node(map_entry)
        map = map_entry.map

        entry_coverage = {}
        exit_coverage = {}
        # create dicts with which we can replace all iteration
        # variables by their range bounds
        map_min = {dace.symbol(param): e for param, e in zip(map.params, map.range.min_element())}
        map_max = {dace.symbol(param): e for param, e in zip(map.params, map.range.max_element())}

        # look at inner memlets at map entry
        for e in graph.out_edges(map_entry):
            if not e.data.subset:
                continue
            if outer_range:
                # get subset
                min_element = [m.subs(map_min) for m in e.data.subset.min_element()]
                max_element = [m.subs(map_max) for m in e.data.subset.max_element()]
                # create range
                rng = subsets.Range((min_e, max_e, 1) for min_e, max_e in zip(min_element, max_element))
            else:
                rng = dcpy(e.data.subset)

            if e.data.data not in entry_coverage:
                entry_coverage[e.data.data] = rng
            else:
                old_coverage = entry_coverage[e.data.data]
                entry_coverage[e.data.data] = subsets.union(old_coverage, rng)

        # look at inner memlets at map exit
        for e in graph.in_edges(map_exit):
            if outer_range:
                # get subset
                min_element = [m.subs(map_min) for m in e.data.subset.min_element()]
                max_element = [m.subs(map_max) for m in e.data.subset.max_element()]
                # create range
                rng = subsets.Range((min_e, max_e, 1) for min_e, max_e in zip(min_element, max_element))
            else:
                rng = dcpy(e.data.subset)

            if e.data.data not in exit_coverage:
                exit_coverage[e.data.data] = rng
            else:
                old_coverage = exit_coverage[e.data.data]
                exit_coverage[e.data.data] = subsets.union(old_coverage, rng)

        # return both coverages as a tuple
        return (entry_coverage, exit_coverage)
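
For orientation, a minimal usage sketch (assuming a DaCe installation and that this helper is exposed as a static method of StencilTiling in dace.transformation.subgraph.stencil_tiling; the program and names below are illustrative):

import dace
from dace.sdfg import nodes as nd
from dace.transformation.subgraph.stencil_tiling import StencilTiling

N = dace.symbol('N')

@dace.program
def shift(A: dace.float64[N], B: dace.float64[N]):
    for i in dace.map[1:N - 1]:
        B[i] = A[i - 1] + A[i + 1]

sdfg = shift.to_sdfg()
# find the (state, map entry) pair of the generated map
state, map_entry = next((st, n) for st in sdfg.nodes() for n in st.nodes()
                        if isinstance(n, nd.MapEntry))
# entry_cov should map 'A' to the union of read ranges with the map bounds
# substituted in; exit_cov should map 'B' to the written range
entry_cov, exit_cov = StencilTiling.coverage_dicts(sdfg, state, map_entry)
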
Example #2
def consolidate_edges_scope(
        state: SDFGState, scope_node: Union[nd.EntryNode, nd.ExitNode]) -> int:
    """
        Union scope-entering memlets relating to the same data node in a scope.
        This effectively reduces the number of connectors and allows more
        transformations to be performed, at the cost of losing the individual
        per-tasklet memlets.
        :param state: The SDFG state in which the scope to consolidate resides.
        :param scope_node: The scope node whose edges will be consolidated.
        :return: Number of edges removed.
    """
    if scope_node is None:
        return 0
    data_to_conn = {}
    consolidated = 0
    if isinstance(scope_node, nd.EntryNode):
        outer_edges = state.in_edges
        inner_edges = state.out_edges
        remove_outer_connector = scope_node.remove_in_connector
        remove_inner_connector = scope_node.remove_out_connector
        prefix, oprefix = 'IN_', 'OUT_'
    else:
        outer_edges = state.out_edges
        inner_edges = state.in_edges
        remove_outer_connector = scope_node.remove_out_connector
        remove_inner_connector = scope_node.remove_in_connector
        prefix, oprefix = 'OUT_', 'IN_'

    edges_by_connector = collections.defaultdict(list)
    connectors_to_remove = set()
    for e in inner_edges(scope_node):
        edges_by_connector[e.src_conn].append(e)
        if e.data.data not in data_to_conn:
            data_to_conn[e.data.data] = e.src_conn
        elif data_to_conn[e.data.data] != e.src_conn:  # Need to consolidate
            connectors_to_remove.add(e.src_conn)

    for conn in connectors_to_remove:
        e = edges_by_connector[conn][0]
        # Outer side of the scope - remove edge and union subsets
        target_conn = prefix + data_to_conn[e.data.data][len(oprefix):]
        conn_to_remove = prefix + conn[len(oprefix):]
        remove_outer_connector(conn_to_remove)
        out_edge = next(ed for ed in outer_edges(scope_node)
                        if ed.dst_conn == target_conn)
        edge_to_remove = next(ed for ed in outer_edges(scope_node)
                              if ed.dst_conn == conn_to_remove)
        out_edge.data.subset = sbs.union(out_edge.data.subset,
                                         edge_to_remove.data.subset)
        state.remove_edge(edge_to_remove)
        consolidated += 1
        # Inner side of the scope - remove and reconnect
        remove_inner_connector(e.src_conn)
        for e in edges_by_connector[conn]:
            e._src_conn = data_to_conn[e.data.data]

    return consolidated
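
A hedged usage sketch (assuming the function is available as dace.sdfg.utils.consolidate_edges_scope, as in current DaCe releases; the program is illustrative):

import dace
from dace.sdfg import nodes as nd
from dace.sdfg import utils as sdutil

@dace.program
def two_reads(A: dace.float64[64], B: dace.float64[64]):
    for i in dace.map[0:64]:
        B[i] = A[i] + A[63 - i]

sdfg = two_reads.to_sdfg()
state, entry = next((st, n) for st in sdfg.nodes() for n in st.nodes()
                    if isinstance(n, nd.MapEntry))
# if the frontend generated separate scope connectors for the two reads of A,
# they are merged here; otherwise the call simply returns 0
removed = sdutil.consolidate_edges_scope(state, entry)
print('edges removed:', removed)
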
Example #3
def nest_state_subgraph(sdfg: SDFG,
                        state: SDFGState,
                        subgraph: SubgraphView,
                        name: Optional[str] = None,
                        full_data: bool = False) -> nodes.NestedSDFG:
    """ Turns a state subgraph into a nested SDFG. Operates in-place.
        :param sdfg: The SDFG containing the state subgraph.
        :param state: The state containing the subgraph.
        :param subgraph: Subgraph to nest.
        :param name: An optional name for the nested SDFG.
        :param full_data: If True, nests entire input/output data.
        :return: The nested SDFG node.
        :raise KeyError: Some or all nodes in the subgraph are not located in
                         this state, or the state does not belong to the given
                         SDFG.
        :raise ValueError: The subgraph is contained in more than one scope.
    """
    if state.parent != sdfg:
        raise KeyError('State does not belong to given SDFG')
    if subgraph is not state and subgraph.graph is not state:
        raise KeyError('Subgraph does not belong to given state')

    # Find the top-level scope
    scope_tree = state.scope_tree()
    scope_dict = state.scope_dict()
    scope_dict_children = state.scope_children()
    top_scopenode = -1  # Initialized to -1 since "None" already means top-level

    for node in subgraph.nodes():
        if node not in scope_dict:
            raise KeyError('Node not found in state')

        # If scope entry/exit, ensure entire scope is in subgraph
        if isinstance(node, nodes.EntryNode):
            scope_nodes = scope_dict_children[node]
            if any(n not in subgraph.nodes() for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (entry)')
        elif isinstance(node, nodes.ExitNode):
            entry = state.entry_node(node)
            scope_nodes = scope_dict_children[entry] + [entry]
            if any(n not in subgraph.nodes() for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (exit)')

        scope_node = scope_dict[node]
        if scope_node not in subgraph.nodes():
            if top_scopenode != -1 and top_scopenode != scope_node:
                raise ValueError('Subgraph is contained in more than one scope')
            top_scopenode = scope_node

    scope = scope_tree[top_scopenode]
    ###

    # Consolidate edges in top scope
    utils.consolidate_edges(sdfg, scope)
    snodes = subgraph.nodes()

    # Collect inputs and outputs of the nested SDFG
    inputs: List[MultiConnectorEdge] = []
    outputs: List[MultiConnectorEdge] = []
    for node in snodes:
        for edge in state.in_edges(node):
            if edge.src not in snodes:
                inputs.append(edge)
        for edge in state.out_edges(node):
            if edge.dst not in snodes:
                outputs.append(edge)

    # Collect transients not used outside of the subgraph (will be removed
    # from the top-level graph)
    data_in_subgraph = set(n.data for n in subgraph.nodes() if isinstance(n, nodes.AccessNode))
    # Find other occurrences in SDFG
    other_nodes = set(n.data for s in sdfg.nodes() for n in s.nodes()
                      if isinstance(n, nodes.AccessNode) and n not in subgraph.nodes())
    subgraph_transients = set()
    for data in data_in_subgraph:
        datadesc = sdfg.arrays[data]
        if datadesc.transient and data not in other_nodes:
            subgraph_transients.add(data)

    # All transients of edges between code nodes are also added to nested graph
    for edge in subgraph.edges():
        if (isinstance(edge.src, nodes.CodeNode) and isinstance(edge.dst, nodes.CodeNode)):
            subgraph_transients.add(edge.data.data)

    # Collect data used in access nodes within subgraph (will be referenced in
    # full upon nesting)
    input_arrays = set()
    output_arrays = {}
    for node in subgraph.nodes():
        if (isinstance(node, nodes.AccessNode) and node.data not in subgraph_transients):
            if node.has_reads(state):
                input_arrays.add(node.data)
            if node.has_writes(state):
                output_arrays[node.data] = state.in_edges(node)[0].data.wcr

    # Create the nested SDFG
    nsdfg = SDFG(name or 'nested_' + state.label)

    # Transients are added to the nested graph as-is
    for name in subgraph_transients:
        nsdfg.add_datadesc(name, sdfg.arrays[name])

    # Input/output data that are not source/sink nodes are added to the graph
    # as non-transients
    for name in (input_arrays | output_arrays.keys()):
        datadesc = copy.deepcopy(sdfg.arrays[name])
        datadesc.transient = False
        nsdfg.add_datadesc(name, datadesc)

    # Connected source/sink nodes outside subgraph become global data
    # descriptors in nested SDFG
    input_names = {}
    output_names = {}
    global_subsets: Dict[str, Tuple[str, Subset]] = {}
    for edge in inputs:
        if edge.data.data is None:  # Skip edges with an empty memlet
            continue
        name = edge.data.data
        if name not in global_subsets:
            datadesc = copy.deepcopy(sdfg.arrays[edge.data.data])
            datadesc.transient = False
            if not full_data:
                datadesc.shape = edge.data.subset.size()
            new_name = nsdfg.add_datadesc(name, datadesc, find_new_name=True)
            global_subsets[name] = (new_name, edge.data.subset)
        else:
            new_name, subset = global_subsets[name]
            if not full_data:
                new_subset = union(subset, edge.data.subset)
                if new_subset is None:
                    new_subset = Range.from_array(sdfg.arrays[name])
                global_subsets[name] = (new_name, new_subset)
                nsdfg.arrays[new_name].shape = new_subset.size()
        input_names[edge] = new_name
    for edge in outputs:
        if edge.data.data is None:  # Skip edges with an empty memlet
            continue
        name = edge.data.data
        if name not in global_subsets:
            datadesc = copy.deepcopy(sdfg.arrays[edge.data.data])
            datadesc.transient = False
            if not full_data:
                datadesc.shape = edge.data.subset.size()
            new_name = nsdfg.add_datadesc(name, datadesc, find_new_name=True)
            global_subsets[name] = (new_name, edge.data.subset)
        else:
            new_name, subset = global_subsets[name]
            if not full_data:
                new_subset = union(subset, edge.data.subset)
                if new_subset is None:
                    new_subset = Range.from_array(sdfg.arrays[name])
                global_subsets[name] = (new_name, new_subset)
                nsdfg.arrays[new_name].shape = new_subset.size()
        output_names[edge] = new_name
    ###################

    # Add scope symbols to the nested SDFG
    defined_vars = set(
        symbolic.pystr_to_symbolic(s) for s in (state.symbols_defined_at(top_scopenode).keys()
                                                | sdfg.symbols))
    for v in defined_vars:
        if v in sdfg.symbols:
            sym = sdfg.symbols[v]
            nsdfg.add_symbol(v, sym.dtype)

    # Add constants to nested SDFG
    for cstname, cstval in sdfg.constants.items():
        nsdfg.add_constant(cstname, cstval)

    # Create nested state
    nstate = nsdfg.add_state()

    # Add subgraph nodes and edges to nested state
    nstate.add_nodes_from(subgraph.nodes())
    for e in subgraph.edges():
        nstate.add_edge(e.src, e.src_conn, e.dst, e.dst_conn, copy.deepcopy(e.data))

    # Modify nested SDFG parents in subgraph
    for node in subgraph.nodes():
        if isinstance(node, nodes.NestedSDFG):
            node.sdfg.parent = nstate
            node.sdfg.parent_sdfg = nsdfg
            node.sdfg.parent_nsdfg_node = node

    # Add access nodes and edges as necessary
    edges_to_offset = []
    for edge, name in input_names.items():
        node = nstate.add_read(name)
        new_edge = copy.deepcopy(edge.data)
        new_edge.data = name
        edges_to_offset.append((edge, nstate.add_edge(node, None, edge.dst, edge.dst_conn, new_edge)))
    for edge, name in output_names.items():
        node = nstate.add_write(name)
        new_edge = copy.deepcopy(edge.data)
        new_edge.data = name
        edges_to_offset.append((edge, nstate.add_edge(edge.src, edge.src_conn, node, None, new_edge)))

    # Offset memlet paths inside nested SDFG according to subsets
    for original_edge, new_edge in edges_to_offset:
        for edge in nstate.memlet_tree(new_edge):
            edge.data.data = new_edge.data.data
            if not full_data:
                edge.data.subset.offset(global_subsets[original_edge.data.data][1], True)

    # Add nested SDFG node to the input state
    nested_sdfg = state.add_nested_sdfg(nsdfg, None,
                                        set(input_names.values()) | input_arrays,
                                        set(output_names.values()) | output_arrays.keys())

    # Reconnect memlets to nested SDFG
    reconnected_in = set()
    reconnected_out = set()
    empty_input = None
    empty_output = None
    for edge in inputs:
        if edge.data.data is None:
            empty_input = edge
            continue

        name = input_names[edge]
        if name in reconnected_in:
            continue
        if full_data:
            data = Memlet.from_array(edge.data.data, sdfg.arrays[edge.data.data])
        else:
            data = copy.deepcopy(edge.data)
            data.subset = global_subsets[edge.data.data][1]
        state.add_edge(edge.src, edge.src_conn, nested_sdfg, name, data)
        reconnected_in.add(name)

    for edge in outputs:
        if edge.data.data is None:
            empty_output = edge
            continue

        name = output_names[edge]
        if name in reconnected_out:
            continue
        if full_data:
            data = Memlet.from_array(edge.data.data, sdfg.arrays[edge.data.data])
        else:
            data = copy.deepcopy(edge.data)
            data.subset = global_subsets[edge.data.data][1]
        data.wcr = edge.data.wcr
        state.add_edge(nested_sdfg, name, edge.dst, edge.dst_conn, data)
        reconnected_out.add(name)

    # Connect access nodes to internal input/output data as necessary
    entry = scope.entry
    exit = scope.exit
    for name in input_arrays:
        node = state.add_read(name)
        if entry is not None:
            state.add_nedge(entry, node, Memlet())
        state.add_edge(node, None, nested_sdfg, name, Memlet.from_array(name, sdfg.arrays[name]))
    for name, wcr in output_arrays.items():
        node = state.add_write(name)
        if exit is not None:
            state.add_nedge(node, exit, Memlet())
        state.add_edge(nested_sdfg, name, node, None, Memlet(data=name, wcr=wcr))

    # Graph was not reconnected, but needs to be
    if state.in_degree(nested_sdfg) == 0 and empty_input is not None:
        state.add_edge(empty_input.src, empty_input.src_conn, nested_sdfg, None, empty_input.data)
    if state.out_degree(nested_sdfg) == 0 and empty_output is not None:
        state.add_edge(nested_sdfg, None, empty_output.dst, empty_output.dst_conn, empty_output.data)

    # Remove subgraph nodes from graph
    state.remove_nodes_from(subgraph.nodes())

    # Remove subgraph transients from top-level graph
    for transient in subgraph_transients:
        del sdfg.arrays[transient]

    # Remove newly isolated nodes due to memlet consolidation
    for edge in inputs:
        if state.in_degree(edge.src) + state.out_degree(edge.src) == 0:
            state.remove_node(edge.src)
    for edge in outputs:
        if state.in_degree(edge.dst) + state.out_degree(edge.dst) == 0:
            state.remove_node(edge.dst)

    return nested_sdfg
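
A rough sketch of driving this functionality through dace.transformation.helpers.nest_state_subgraph (the module path is assumed from current DaCe; in practice one would usually nest a proper subset of the state's nodes rather than the whole state):

import dace
from dace.sdfg.graph import SubgraphView
from dace.transformation import helpers

@dace.program
def scale(A: dace.float64[32], B: dace.float64[32]):
    for i in dace.map[0:32]:
        B[i] = A[i] * 2.0

sdfg = scale.to_sdfg()
state = next(st for st in sdfg.nodes() if len(st.nodes()) > 0)
# nest every node of the state into a single nested SDFG node
nsdfg_node = helpers.nest_state_subgraph(sdfg, state,
                                         SubgraphView(state, state.nodes()))
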
    def apply(self, sdfg):
        graph = sdfg.nodes()[self.state_id]
        subgraph = self.subgraph_view(sdfg)
        map_entries = helpers.get_outermost_scope_maps(sdfg, graph, subgraph)

        result = StencilTiling.topology(sdfg, graph, map_entries)
        (children_dict, parent_dict, sink_maps) = result

        # next up, calculate inferred ranges for each map
        # for each map entry, this contains a tuple of dicts:
        # each of those maps from data_name of the array to
        # inferred outer ranges. An inferred outer range is created
        # by taking the union of ranges of inner subsets corresponding
        # to that data and substituting this subset by the min / max of the
        # parametrized map boundaries
        # finally, from these outer ranges we can easily calculate
        # strides and tile sizes required for every map
        inferred_ranges = defaultdict(dict)

        # create array of reverse topologically sorted map entries
        # to iterate over
        topo_reversed = []
        queue = set(sink_maps.copy())
        while len(queue) > 0:
            element = next(e for e in queue
                           if not children_dict[e] - set(topo_reversed))
            topo_reversed.append(element)
            queue.remove(element)
            for parent in parent_dict[element]:
                queue.add(parent)

        # main loop
        # first get coverage dicts for each map entry
        # for each map entry, coverage contains a tuple of two dicts,
        # each of which maps from data name to outer range
        coverage = {}
        for map_entry in map_entries:
            coverage[map_entry] = StencilTiling.coverage_dicts(
                sdfg, graph, map_entry, outer_range=True)

        # we have a mapping from data name to outer range
        # however we want a mapping from map parameters to outer ranges
        # for this we need to find out which map parameter
        # each array dimension corresponds to

        variable_mapping = defaultdict(list)
        for map_entry in topo_reversed:
            map = map_entry.map

            # first find out variable mapping
            for e in itertools.chain(
                    graph.out_edges(map_entry),
                    graph.in_edges(graph.exit_node(map_entry))):
                mapping = []
                for dim in e.data.subset:
                    syms = set()
                    for d in dim:
                        syms |= symbolic.symlist(d).keys()
                    if len(syms) > 1:
                        raise NotImplementedError(
                            "One incoming or outgoing stencil subset is indexed "
                            "by multiple map parameters. "
                            "This is not supported yet.")
                    try:
                        mapping.append(syms.pop())
                    except KeyError:
                        # just append None if there is no map symbol in it.
                        # we don't care for now.
                        mapping.append(None)

                if e.data.data in variable_mapping:
                    # assert that this is the same everywhere.
                    # else we might run into problems
                    assert variable_mapping[e.data.data] == mapping
                else:
                    variable_mapping[e.data.data] = mapping

            # now do mapping data name -> outer range
            # and from that infer mapping variable -> outer range
            local_ranges = {dn: None for dn in coverage[map_entry][1].keys()}
            for data_name, cov in coverage[map_entry][1].items():
                local_ranges[data_name] = subsets.union(
                    local_ranges[data_name], cov)
                # now look at succeeding (child) maps
                # and union those subsets -> could be larger with stencil indent
                for child_map in children_dict[map_entry]:
                    if data_name in coverage[child_map][0]:
                        local_ranges[data_name] = subsets.union(
                            local_ranges[data_name],
                            coverage[child_map][0][data_name])

            # final assignment: combine local_ranges and variable_mapping
            # together into inferred_ranges
            inferred_ranges[map_entry] = {p: None for p in map.params}
            for data_name, ranges in local_ranges.items():
                for param, r in zip(variable_mapping[data_name], ranges):
                    # create new range from this subset and assign
                    rng = subsets.Range((r, ))
                    if param:
                        inferred_ranges[map_entry][param] = subsets.union(
                            inferred_ranges[map_entry][param], rng)

        # get parameters -- should all be the same
        params = next(iter(map_entries)).map.params.copy()
        # define reference range as inferred range of one of the sink maps
        self.reference_range = inferred_ranges[next(iter(sink_maps))]
        if self.debug:
            print("StencilTiling::Reference Range", self.reference_range)
        # next up, search for the ranges that don't change
        invariant_dims = []
        for idx, p in enumerate(params):
            different = False
            if self.reference_range[p] is None:
                invariant_dims.append(idx)
                warnings.warn(
                    f"StencilTiling::No Stencil pattern detected for parameter {p}"
                )
                continue
            for m in map_entries:
                if inferred_ranges[m][p] != self.reference_range[p]:
                    different = True
                    break
            if not different:
                invariant_dims.append(idx)
                warnings.warn(
                    f"StencilTiling::No Stencil pattern detected for parameter {p}"
                )

        # during stripmining, we will create new outer map entries
        # for easy access
        self._outer_entries = set()
        # with inferred_ranges constructed, we can begin to strip mine
        for map_entry in map_entries:
            # Retrieve map entry and exit nodes.
            map = map_entry.map

            stripmine_subgraph = {
                StripMining._map_entry: graph.nodes().index(map_entry)
            }

            sdfg_id = sdfg.sdfg_id
            last_map_entry = None
            original_schedule = map_entry.schedule
            self.tile_sizes = []
            self.tile_offset_lower = []
            self.tile_offset_upper = []

            # strip mining each dimension where necessary
            removed_maps = 0
            for dim_idx, param in enumerate(map_entry.map.params):
                # get the tile stride for the current dimension
                if dim_idx >= len(self.strides):
                    tile_stride = symbolic.pystr_to_symbolic(self.strides[-1])
                else:
                    tile_stride = symbolic.pystr_to_symbolic(
                        self.strides[dim_idx])

                trivial = False

                if dim_idx in invariant_dims:
                    self.tile_sizes.append(tile_stride)
                    self.tile_offset_lower.append(0)
                    self.tile_offset_upper.append(0)
                else:
                    target_range_current = inferred_ranges[map_entry][param]
                    reference_range_current = self.reference_range[param]

                    min_diff = symbolic.SymExpr(reference_range_current.min_element()[0] \
                                    - target_range_current.min_element()[0])
                    max_diff = symbolic.SymExpr(target_range_current.max_element()[0] \
                                    - reference_range_current.max_element()[0])

                    try:
                        min_diff = symbolic.evaluate(min_diff, {})
                        max_diff = symbolic.evaluate(max_diff, {})
                    except TypeError:
                        raise RuntimeError("Symbolic evaluation of map "
                                           "ranges failed. Please check "
                                           "your parameters and match.")

                    self.tile_sizes.append(tile_stride + max_diff + min_diff)
                    self.tile_offset_lower.append(
                        symbolic.pystr_to_symbolic(str(min_diff)))
                    self.tile_offset_upper.append(
                        symbolic.pystr_to_symbolic(str(max_diff)))

                # get calculated parameters
                tile_size = self.tile_sizes[-1]

                dim_idx -= removed_maps
                # If map or tile sizes are trivial, skip strip-mining map dimension
                # special cases:
                # if tile size is trivial AND we have an invariant dimension, skip
                if tile_size == map.range.size()[dim_idx] and (
                        dim_idx + removed_maps) in invariant_dims:
                    continue

                # trivial map: we just continue
                if map.range.size()[dim_idx] in [0, 1]:
                    continue

                if tile_size == 1 and tile_stride == 1 and (
                        dim_idx + removed_maps) in invariant_dims:
                    trivial = True
                    removed_maps += 1

                # indent all map ranges accordingly and then perform
                # strip mining on these. Offset inner maps accordingly afterwards

                range_tuple = (map.range[dim_idx][0] +
                               self.tile_offset_lower[-1],
                               map.range[dim_idx][1] -
                               self.tile_offset_upper[-1],
                               map.range[dim_idx][2])
                map.range[dim_idx] = range_tuple
                stripmine = StripMining(sdfg_id, self.state_id,
                                        stripmine_subgraph, 0)

                stripmine.tiling_type = 'ceilrange'
                stripmine.dim_idx = dim_idx
                stripmine.new_dim_prefix = self.prefix if not trivial else ''
                # use tile_stride for both -- we will extend
                # the inner tiles later
                stripmine.tile_size = str(tile_stride)
                stripmine.tile_stride = str(tile_stride)
                outer_map = stripmine.apply(sdfg)
                outer_map.schedule = original_schedule

                # assign the schedule specified for this transformation
                # to the inner map
                map_entry.schedule = self.schedule

                # if tile stride is 1, we can make a nice simplification by just
                # taking the overapproximated inner range as inner range
                # this eliminates the min/max in the range which
                # enables loop unrolling
                if tile_stride == 1:
                    map_entry.range[dim_idx] = tuple(
                        symbolic.SymExpr(el._approx_expr) if isinstance(
                            el, symbolic.SymExpr) else el
                        for el in map_entry.range[dim_idx])

                # in map_entry: enlarge tiles by upper and lower offset
                # doing it this way and not via stripmine strides ensures
                # that the max gets changed as well
                old_range = map_entry.range[dim_idx]
                map_entry.range[dim_idx] = ((old_range[0] -
                                             self.tile_offset_lower[-1]),
                                            (old_range[1] +
                                             self.tile_offset_upper[-1]),
                                            old_range[2])

                # We have to propagate here for correct outer volume and subset sizes
                _propagate_node(graph, map_entry)
                _propagate_node(graph, graph.exit_node(map_entry))

                # usual tiling pipeline
                if last_map_entry:
                    new_map_entry = graph.in_edges(map_entry)[0].src
                    mapcollapse_subgraph = {
                        MapCollapse._outer_map_entry:
                        graph.node_id(last_map_entry),
                        MapCollapse._inner_map_entry:
                        graph.node_id(new_map_entry)
                    }
                    mapcollapse = MapCollapse(sdfg_id, self.state_id,
                                              mapcollapse_subgraph, 0)
                    mapcollapse.apply(sdfg)
                last_map_entry = graph.in_edges(map_entry)[0].src
            # add last instance of map entries to _outer_entries
            if last_map_entry:
                self._outer_entries.add(last_map_entry)

            # Map Unroll Feature: only unroll if the following conditions are met:
            # - at least one of the inner map ranges is strictly larger than 1
            # - all strides are one
            if self.unroll_loops and all(s == 1 for s in self.strides) and any(
                    s not in [0, 1] for s in map_entry.range.size()):
                l = len(map_entry.params)
                if l > 1:
                    subgraph = {
                        MapExpansion.map_entry: graph.nodes().index(map_entry)
                    }
                    trafo_expansion = MapExpansion(sdfg.sdfg_id,
                                                   sdfg.nodes().index(graph),
                                                   subgraph, 0)
                    trafo_expansion.apply(sdfg)
                maps = [map_entry]
                for _ in range(l - 1):
                    map_entry = graph.out_edges(map_entry)[0].dst
                    maps.append(map_entry)

                for map in reversed(maps):
                    # MapToForLoop
                    subgraph = {
                        MapToForLoop._map_entry: graph.nodes().index(map)
                    }
                    trafo_for_loop = MapToForLoop(sdfg.sdfg_id,
                                                  sdfg.nodes().index(graph),
                                                  subgraph, 0)
                    trafo_for_loop.apply(sdfg)
                    nsdfg = trafo_for_loop.nsdfg

                    # LoopUnroll

                    guard = trafo_for_loop.guard
                    end = trafo_for_loop.after_state
                    begin = next(e.dst for e in nsdfg.out_edges(guard)
                                 if e.dst != end)

                    subgraph = {
                        DetectLoop._loop_guard: nsdfg.nodes().index(guard),
                        DetectLoop._loop_begin: nsdfg.nodes().index(begin),
                        DetectLoop._exit_state: nsdfg.nodes().index(end)
                    }
                    transformation = LoopUnroll(0, 0, subgraph, 0)
                    transformation.apply(nsdfg)
            elif self.unroll_loops:
                warnings.warn(
                    "Did not unroll loops. Either all ranges are equal to "
                    "one or range difference is symbolic.")

        self._outer_entries = list(self._outer_entries)
    def can_be_applied(sdfg, subgraph) -> bool:
        # get highest scope maps
        graph = subgraph.graph
        map_entries = set(
            helpers.get_outermost_scope_maps(sdfg, graph, subgraph))
        # 1.1: There has to be more than one outermost scope map entry
        if len(map_entries) <= 1:
            return False

        # 1.2: check basic constraints:
        # - all parameters have to be the same (this implies same length)
        # - no parameter permutations here as ambiguity is very high then
        # - same strides everywhere
        first_map = next(iter(map_entries))
        params = dcpy(first_map.map.params)
        strides = first_map.map.range.strides()
        schedule = first_map.map.schedule

        for map_entry in map_entries:
            if map_entry.map.params != params:
                return False
            if map_entry.map.range.strides() != strides:
                return False
            if map_entry.map.schedule != schedule:
                return False

        # 1.3: check whether all map entries only differ by a const amount
        first_entry = next(iter(map_entries))
        for map_entry in map_entries:
            for r1, r2 in zip(map_entry.map.range, first_entry.map.range):
                if len((r1[0] - r2[0]).free_symbols) > 0:
                    return False
                if len((r1[1] - r2[1]).free_symbols) > 0:
                    return False

        # get intermediate_nodes, out_nodes from SubgraphFusion Transformation
        node_config = SubgraphFusion.get_adjacent_nodes(
            sdfg, graph, map_entries)
        (_, intermediate_nodes, out_nodes) = node_config

        # 1.4: check topological feasibility
        if not SubgraphFusion.check_topo_feasibility(
                sdfg, graph, map_entries, intermediate_nodes, out_nodes):
            return False
        # 1.5 nodes that are both intermediate and out nodes
        # are not supported in StencilTiling
        if len(intermediate_nodes & out_nodes) > 0:
            return False
        # get coverages for every map entry
        coverages = {}
        memlets = {}
        for map_entry in map_entries:
            coverages[map_entry] = StencilTiling.coverage_dicts(
                sdfg, graph, map_entry)
            memlets[map_entry] = StencilTiling.coverage_dicts(
                sdfg, graph, map_entry, outer_range=False)

        # get DAG neighbours for each map
        dag_neighbors = StencilTiling.topology(sdfg, graph, map_entries)
        (children_dict, _, sink_maps) = dag_neighbors

        # 1.6: we now check coverage:
        # each outgoing coverage for a data memlet has to
        # be exactly equal to the union of incoming coverages
        # of all children map memlets of this data

        # important:
        # 1. it has to be equal, and not merely cover it, in order to
        #    rule out ranges that are too large
        # 2. we check coverages by map parameter and not by
        #    array, this way it is even more general
        # 3. map parameter coverages are checked for each
        #    (map_entry, children of this map_entry) - pair
        for map_entry in map_entries:
            # get coverage from current map_entry
            map_coverage = coverages[map_entry][1]

            # final mapping map_parameter -> coverage will be stored here
            param_parent_coverage = {p: None for p in map_entry.params}
            param_children_coverage = {p: None for p in map_entry.params}
            for child_entry in children_dict[map_entry]:
                # get mapping data_name -> coverage
                for (data_name, cov) in map_coverage.items():
                    parent_coverage = cov
                    children_coverage = None
                    if data_name in coverages[child_entry][0]:
                        children_coverage = subsets.union(
                            children_coverage,
                            coverages[child_entry][0][data_name])

                    # extend mapping map_parameter -> coverage
                    # by the previous mapping

                    for i, (p_subset, c_subset) in enumerate(
                            zip(parent_coverage, children_coverage)):

                        # transform into subset
                        p_subset = subsets.Range((p_subset, ))
                        c_subset = subsets.Range((c_subset, ))

                        # get associated parameter in memlet
                        params1 = symbolic.symlist(
                            memlets[map_entry][1][data_name][i]).keys()
                        params2 = symbolic.symlist(
                            memlets[child_entry][0][data_name][i]).keys()
                        if params1 != params2:
                            return False
                        params = params1
                        if len(params) > 1:
                            # this is not supported
                            return False
                        try:
                            symbol = next(iter(params))
                            param_parent_coverage[symbol] = subsets.union(
                                param_parent_coverage[symbol], p_subset)
                            param_children_coverage[symbol] = subsets.union(
                                param_children_coverage[symbol], c_subset)

                        except StopIteration:
                            # current dim has no symbol associated.
                            # ignore and continue
                            warnings.warn(
                                f"In map {map_entry}, there is a "
                                "dimension belonging to {data_name} "
                                "that has no map parameter associated.")
                            pass

            # 1.6: parameter mapping must be the same
            if param_parent_coverage != param_children_coverage:
                return False

        # 1.7: we want all sink maps to have the same range size
        assert len(sink_maps) > 0
        first_sink_map = next(iter(sink_maps))
        if not all([
                map.range.size() == first_sink_map.range.size()
                for map in sink_maps
        ]):
            return False

        return True
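
A hedged end-to-end sketch of how can_be_applied and apply are typically driven. Constructing StencilTiling directly from a SubgraphView matches the older DaCe API these snippets are written against (newer releases construct the transformation and then call setup_match), and whether tiling actually applies depends on the coverage checks above:

import dace
from dace.sdfg.graph import SubgraphView
from dace.transformation.subgraph import StencilTiling

N = dace.symbol('N')

@dace.program
def two_stage(A: dace.float64[N], B: dace.float64[N], C: dace.float64[N]):
    for i in dace.map[1:N - 1]:
        B[i] = (A[i - 1] + A[i + 1]) * 0.5
    for i in dace.map[1:N - 1]:
        C[i] = (B[i - 1] + B[i + 1]) * 0.5

sdfg = two_stage.to_sdfg()
state = next(st for st in sdfg.nodes() if len(st.nodes()) > 0)
subgraph = SubgraphView(state, state.nodes())

if StencilTiling.can_be_applied(sdfg, subgraph):
    tiling = StencilTiling(subgraph)
    tiling.strides = (8,)  # tile stride per map dimension
    tiling.apply(sdfg)
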
Example #6
def propagate_subset(
        memlets: List[Memlet],
        arr: data.Data,
        params: List[str],
        rng: subsets.Subset,
        defined_variables: Set[symbolic.SymbolicType] = None) -> Memlet:
    """ Tries to propagate a list of memlets through a range (computes the 
        image of the memlet function applied on an integer set of, e.g., a 
        map range) and returns a new memlet object.
        :param memlets: The memlets to propagate.
        :param arr: Array descriptor for memlet (used for obtaining extents).
        :param params: A list of variable names.
        :param rng: A subset with dimensionality len(params) that contains the
                    range to propagate with.
        :param defined_variables: A set of symbols defined that will remain the
                                  same throughout propagation. If None, assumes
                                  that all symbols outside of `params` have been
                                  defined.
        :return: Memlet with propagated subset and volume.
    """
    # Argument handling
    if defined_variables is None:
        # Default defined variables is "everything but params"
        defined_variables = set()
        defined_variables |= rng.free_symbols
        for memlet in memlets:
            defined_variables |= memlet.free_symbols
        defined_variables -= set(params)
        defined_variables = set(
            symbolic.pystr_to_symbolic(p) for p in defined_variables)

    # Propagate subset
    variable_context = [
        defined_variables, [symbolic.pystr_to_symbolic(p) for p in params]
    ]

    new_subset = None
    for md in memlets:
        tmp_subset = None
        for pclass in MemletPattern.extensions():
            pattern = pclass()
            if pattern.can_be_applied([md.subset], variable_context, rng,
                                      [md]):
                tmp_subset = pattern.propagate(arr, [md.subset], rng)
                break
        else:
            # No patterns found. Emit a warning and propagate the entire
            # array
            warnings.warn('Cannot find appropriate memlet pattern to '
                          'propagate %s through %s' %
                          (str(md.subset), str(rng)))
            tmp_subset = subsets.Range.from_array(arr)

        # Union edges as necessary
        if new_subset is None:
            new_subset = tmp_subset
        else:
            old_subset = new_subset
            new_subset = subsets.union(new_subset, tmp_subset)
            if new_subset is None:
                warnings.warn('Subset union failed between %s and %s ' %
                              (old_subset, tmp_subset))
                break

    # Some unions failed
    if new_subset is None:
        new_subset = subsets.Range.from_array(arr)
    ### End of subset propagation

    # Create new memlet
    new_memlet = copy.copy(memlets[0])
    new_memlet.subset = new_subset
    new_memlet.other_subset = None

    # Propagate volume:
    # Number of accesses in the propagated memlet is the sum of the internal
    # number of accesses times the size of the map range set (unbounded dynamic)
    new_memlet.volume = (sum(m.volume for m in memlets) *
                         functools.reduce(lambda a, b: a * b, rng.size(), 1))
    if any(m.dynamic for m in memlets):
        new_memlet.dynamic = True
    elif symbolic.issymbolic(new_memlet.volume) and any(
            s not in defined_variables
            for s in new_memlet.volume.free_symbols):
        new_memlet.dynamic = True
        new_memlet.volume = 0

    return new_memlet
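
A small, self-contained sketch of what this propagation computes (assuming the function is importable from dace.sdfg.propagation, as in current DaCe):

import dace
from dace import subsets
from dace.memlet import Memlet
from dace.sdfg.propagation import propagate_subset

sdfg = dace.SDFG('prop_example')
sdfg.add_array('A', [20], dace.float64)

inner = Memlet('A[i]')  # one element accessed per iteration
outer = propagate_subset([inner], sdfg.arrays['A'], ['i'],
                         subsets.Range([(0, 9, 1)]))
# the propagated memlet should cover A[0:10] with a volume of 10
print(outer.subset, outer.volume)
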
Example #7
def propagate_memlet(dfg_state,
                     memlet: Memlet,
                     scope_node: nodes.EntryNode,
                     union_inner_edges: bool,
                     arr=None):
    """ Tries to propagate a memlet through a scope (computes the image of 
        the memlet function applied on an integer set of, e.g., a map range) 
        and returns a new memlet object.
        :param dfg_state: An SDFGState object representing the graph.
        :param memlet: The memlet adjacent to the scope node from the inside.
        :param scope_node: A scope entry or exit node.
        :param union_inner_edges: True if the propagation should take other
                                  neighboring internal memlets within the same
                                  scope into account.
        :param arr: An optional data descriptor for the memlet's data
                    (otherwise it is looked up in the SDFG).
        :return: A new memlet with the propagated subset and number of accesses.
    """
    if isinstance(scope_node, nodes.EntryNode):
        entry_node = scope_node
        neighboring_edges = dfg_state.out_edges(scope_node)
    elif isinstance(scope_node, nodes.ExitNode):
        entry_node = dfg_state.scope_dict()[scope_node]
        neighboring_edges = dfg_state.in_edges(scope_node)
    else:
        raise TypeError('Trying to propagate through a non-scope node')
    if isinstance(memlet, EmptyMemlet):
        return EmptyMemlet()

    sdfg = dfg_state.parent
    defined_vars = [
        symbolic.pystr_to_symbolic(s)
        for s in (sdfg.symbols_defined_at(scope_node, dfg_state).keys())
    ]

    # Find other adjacent edges connected to the scope node
    # and union their subsets
    if union_inner_edges:
        aggdata = [
            e.data for e in neighboring_edges
            if e.data.data == memlet.data and e.data != memlet
        ]
    else:
        aggdata = []

    aggdata.append(memlet)

    if arr is None:
        if memlet.data not in sdfg.arrays:
            raise KeyError('Data descriptor (Array, Stream) "%s" not defined '
                           'in SDFG.' % memlet.data)
        arr = sdfg.arrays[memlet.data]

    # Propagate subset
    if isinstance(entry_node, nodes.MapEntry):
        mapnode = entry_node.map

        variable_context = [
            defined_vars,
            [symbolic.pystr_to_symbolic(p) for p in mapnode.params]
        ]

        new_subset = None
        for md in aggdata:
            tmp_subset = None
            for pattern in MemletPattern.patterns():
                if pattern.match([md.subset], variable_context, mapnode.range,
                                 [md]):
                    tmp_subset = pattern.propagate(arr, [md.subset],
                                                   mapnode.range)
                    break
            else:
                # No patterns found. Emit a warning and propagate the entire
                # array
                warnings.warn('Cannot find appropriate memlet pattern to '
                              'propagate %s through %s' %
                              (str(md.subset), str(mapnode.range)))
                tmp_subset = subsets.Range.from_array(arr)

            # Union edges as necessary
            if new_subset is None:
                new_subset = tmp_subset
            else:
                old_subset = new_subset
                new_subset = subsets.union(new_subset, tmp_subset)
                if new_subset is None:
                    warnings.warn('Subset union failed between %s and %s ' %
                                  (old_subset, tmp_subset))

        # Some unions failed
        if new_subset is None:
            new_subset = subsets.Range.from_array(arr)

        assert new_subset is not None

    elif isinstance(entry_node, nodes.ConsumeEntry):
        # Nothing to analyze/propagate in consume
        new_subset = subsets.Range.from_array(arr)
    else:
        raise NotImplementedError('Unimplemented primitive: %s' %
                                  type(scope_node))
    ### End of subset propagation

    new_memlet = copy.copy(memlet)
    new_memlet.subset = new_subset
    new_memlet.other_subset = None

    # Number of accesses in the propagated memlet is the sum of the internal
    # number of accesses times the size of the map range set
    new_memlet.num_accesses = (
        sum(m.num_accesses for m in aggdata) *
        functools.reduce(lambda a, b: a * b, scope_node.map.range.size(), 1))
    if any(m.num_accesses == -1 for m in aggdata):
        new_memlet.num_accesses = -1
    elif symbolic.issymbolic(new_memlet.num_accesses) and any(
            s not in defined_vars for s in new_memlet.num_accesses.free_symbols):
        new_memlet.num_accesses = -1

    return new_memlet
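
In context, this function is called while propagating memlets outward through a map scope. A minimal runnable sketch of the call pattern (the snippet above reflects an older DaCe API that uses num_accesses and EmptyMemlet, but the modern dace.sdfg.propagation.propagate_memlet is invoked the same way):

import dace
from dace.sdfg.propagation import propagate_memlet

sdfg = dace.SDFG('prop_scope')
sdfg.add_array('A', [20], dace.float64)
sdfg.add_array('B', [20], dace.float64)
state = sdfg.add_state()

_, map_entry, _ = state.add_mapped_tasklet(
    'copy', dict(i='0:20'),
    dict(inp=dace.Memlet('A[i]')), 'out = inp',
    dict(out=dace.Memlet('B[i]')), external_edges=True)

# propagate the inner read memlet A[i] through the map entry
inner_edge = state.out_edges(map_entry)[0]
outer = propagate_memlet(state, inner_edge.data, map_entry, True)
print(outer.subset)  # expected to cover A[0:20]
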
    def can_be_applied(sdfg: SDFG, subgraph: SubgraphView) -> bool:
        '''
        Fusible if
        1. Maps have the same access sets and ranges in order
        2. Any nodes in between two maps are AccessNodes only, without WCR.
           There is at most one AccessNode on any path between two maps;
           no other nodes are allowed.
        3. The exiting memlets' subsets to an intermediate edge must cover
           the respective incoming memlets' subset into the next map.
           Also, as a limitation, the union of all exiting memlets'
           subsets must be contiguous.
        '''
        # get graph
        graph = subgraph.graph
        for node in subgraph.nodes():
            if node not in graph.nodes():
                return False

        # next, get all the maps
        map_entries = helpers.get_outermost_scope_maps(sdfg, graph, subgraph)
        map_exits = [graph.exit_node(map_entry) for map_entry in map_entries]
        maps = [map_entry.map for map_entry in map_entries]

        # 1. basic checks:
        # 1.1 we need to have at least two maps
        if len(maps) <= 1:
            return False
        '''
        # 1.2 Special Case: If we can establish a valid permutation, we can
        #     skip check 1.3
        permutation = self.find_permutation
        '''
        # 1.3 check whether all maps are the same
        base_map = maps[0]
        for map in maps:
            if map.get_param_num() != base_map.get_param_num():
                return False
            if not all(
                [p1 == p2 for (p1, p2) in zip(map.params, base_map.params)]):
                return False
            if not map.range == base_map.range:
                return False
        # 1.4 check whether all map entries have the same schedule
        schedule = map_entries[0].schedule
        if not all([entry.schedule == schedule for entry in map_entries]):
            return False

        # 2. check intermediate feasibility
        # see map_fusion.py for similar checks
        # with the restrictions below being more relaxed

        # 2.1 do some preparation work first:
        # calculate all out_nodes and intermediate_nodes
        # definition see in apply()
        node_config = SubgraphFusion.get_adjacent_nodes(sdfg, graph,
                                                        map_entries)
        _, intermediate_nodes, out_nodes = node_config

        # 2.2 topological feasibility:
        if not SubgraphFusion.check_topo_feasibility(
                sdfg, graph, map_entries, intermediate_nodes, out_nodes):
            return False

        # 2.3 memlet feasibility
        # For each intermediate node, look at whether inner adjacent
        # memlets of the exiting map cover inner adjacent memlets
        # of the next entering map.
        # We also check for any WCRs on the fly.

        for node in intermediate_nodes:
            upper_subsets = set()
            lower_subsets = set()
            # First, determine which dimensions of the memlet ranges
            # change with the map; we do not need to care about the other dimensions.
            try:
                dims_to_discard = SubgraphFusion.get_invariant_dimensions(
                    sdfg, graph, map_entries, map_exits, node)
            except NotImplementedError:
                return False
            # find upper_subsets
            for in_edge in graph.in_edges(node):
                in_in_edge = graph.memlet_path(in_edge)[-2]
                # first check for WCRs
                if in_edge.data.wcr:
                    # check whether the WCR is actually produced at
                    # this edge or further up in the memlet path. If not,
                    # we can still fuse!
                    subset_params = set(
                        [str(s) for s in in_in_edge.data.subset.free_symbols])
                    if any([
                            p not in subset_params
                            for p in in_edge.src.map.params
                    ]):
                        return False
                if in_edge.src in map_exits:
                    subset_to_add = dcpy(in_in_edge.data.subset\
                                         if in_in_edge.data.data == node.data\
                                         else in_in_edge.data.other_subset)
                    subset_to_add.pop(dims_to_discard)
                    upper_subsets.add(subset_to_add)
                else:
                    raise NotImplementedError("Nodes between two maps to be"
                                              "fused with *incoming* edges"
                                              "from outside the maps are not"
                                              "allowed yet.")

            # find lower_subsets
            for out_edge in graph.out_edges(node):
                if out_edge.dst in map_entries:
                    # cannot use the memlet tree here as there could be
                    # more than one succeeding map. Do it manually
                    for oedge in graph.out_edges(out_edge.dst):
                        if oedge.src_conn[3:] == out_edge.dst_conn[2:]:
                            subset_to_add = dcpy(oedge.data.subset \
                                                 if oedge.data.data == node.data \
                                                 else oedge.data.other_subset)
                            subset_to_add.pop(dims_to_discard)
                            lower_subsets.add(subset_to_add)

            # We assume that upper_subsets are contiguous
            # Check for this.
            try:
                contiguous_upper = find_contiguous_subsets(upper_subsets)
                if len(contiguous_upper) > 1:
                    return False
            except TypeError:
                warnings.warn(
                    'Could not determine whether subset is contiguous. '
                    'Exiting check with False.')
                return False

            # now take union of upper subsets
            upper_iter = iter(upper_subsets)
            union_upper = next(upper_iter)
            for subs in upper_iter:
                union_upper = subsets.union(union_upper, subs)
                if not union_upper:
                    # something went wrong using union -- we'd rather abort
                    return False

            # finally check coverage
            # every lower subset must be completely covered by union_upper
            for lower_subset in lower_subsets:
                if not union_upper.covers(lower_subset):
                    return False

        return True
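
A brief driver sketch for the check above (again, constructing the subgraph transformation directly from a SubgraphView follows the older DaCe API used in these snippets, newer releases use setup_match; fusion only triggers if both maps end up in the same state):

import dace
from dace.sdfg.graph import SubgraphView
from dace.transformation.subgraph import SubgraphFusion

@dace.program
def two_maps(A: dace.float64[64], B: dace.float64[64], C: dace.float64[64]):
    for i in dace.map[0:64]:
        B[i] = A[i] + 1
    for i in dace.map[0:64]:
        C[i] = B[i] * 2

sdfg = two_maps.to_sdfg()
state = next(st for st in sdfg.nodes() if len(st.nodes()) > 0)
subgraph = SubgraphView(state, state.nodes())

if SubgraphFusion.can_be_applied(sdfg, subgraph):
    SubgraphFusion(subgraph).apply(sdfg)
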
    def fuse(self, sdfg, graph, map_entries, do_not_override=None, **kwargs):
        """ takes the map_entries specified and tries to fuse maps.

            all maps have to be extended into outer and inner map
            (use MapExpansion as a pre-pass)

            Arrays that don't exist outside the subgraph get pushed
            into the map and their data dimension gets cropped.
            Otherwise the original array is taken.

            For every output, the respective connections are created automatically.

            :param sdfg: SDFG
            :param graph: State
            :param map_entries: Map Entries (class MapEntry) of the outer maps
                                which we want to fuse
            :param do_not_override: List of data names whose corresponding nodes
                                    are fully contained within the subgraph
                                    but should not be augmented/transformed
                                    nevertheless.
        """

        # if there are no maps, return immediately
        if len(map_entries) == 0:
            return

        do_not_override = do_not_override or []

        # get maps and map exits
        maps = [map_entry.map for map_entry in map_entries]
        map_exits = [graph.exit_node(map_entry) for map_entry in map_entries]

        # See function documentation for an explanation of these variables
        node_config = SubgraphFusion.get_adjacent_nodes(sdfg, graph,
                                                        map_entries)
        (in_nodes, intermediate_nodes, out_nodes) = node_config

        if self.debug:
            print("SubgraphFusion::In_nodes", in_nodes)
            print("SubgraphFusion::Out_nodes", out_nodes)
            print("SubgraphFusion::Intermediate_nodes", intermediate_nodes)

        # all maps are assumed to have the same params and range in order
        global_map = nodes.Map(label="outer_fused",
                               params=maps[0].params,
                               ndrange=maps[0].range)
        global_map_entry = nodes.MapEntry(global_map)
        global_map_exit = nodes.MapExit(global_map)

        schedule = map_entries[0].schedule
        global_map_entry.schedule = schedule
        graph.add_node(global_map_entry)
        graph.add_node(global_map_exit)

        # next up, for any intermediate node, find whether it only appears
        # in the subgraph or also somewhere else / as an input
        # create new transients for nodes that are in out_nodes and
        # intermediate_nodes simultaneously
        # also check which dimensions of each transient data element correspond
        # to map axes and write this information into a dict.
        node_info = self.prepare_intermediate_nodes(sdfg, graph, in_nodes, out_nodes, \
                                                    intermediate_nodes,\
                                                    map_entries, map_exits, \
                                                    do_not_override)

        (subgraph_contains_data, transients_created,
         invariant_dimensions) = node_info
        if self.debug:
            print(
                "SubgraphFusion:: {Intermediate_node: subgraph_contains_data} dict"
            )
            print(subgraph_contains_data)

        inconnectors_dict = {}
        # Dict for saving incoming nodes and their assigned connectors
        # Format: {access_node: (edge, in_conn, out_conn)}

        for map_entry, map_exit in zip(map_entries, map_exits):
            # handle inputs
            # TODO: dynamic map range -- this is fairly unrealistic in such a setting
            for edge in graph.in_edges(map_entry):
                src = edge.src
                mmt = graph.memlet_tree(edge)
                out_edges = [child.edge for child in mmt.root().children]

                if src in in_nodes:
                    in_conn = None
                    out_conn = None
                    if src in inconnectors_dict:
                        # no need to augment subset of outer edge.
                        # will do this at the end in one pass.

                        in_conn = inconnectors_dict[src][1]
                        out_conn = inconnectors_dict[src][2]

                    else:
                        next_conn = global_map_entry.next_connector()
                        in_conn = 'IN_' + next_conn
                        out_conn = 'OUT_' + next_conn
                        global_map_entry.add_in_connector(in_conn)
                        global_map_entry.add_out_connector(out_conn)

                        inconnectors_dict[src] = (edge, in_conn, out_conn)

                        # reroute in edge via global_map_entry
                        self.copy_edge(graph, edge, new_dst = global_map_entry, \
                                                        new_dst_conn = in_conn)

                    # map out edges to new map
                    for out_edge in out_edges:
                        self.copy_edge(graph, out_edge, new_src = global_map_entry, \
                                                            new_src_conn = out_conn)

                else:
                    # connect directly
                    for out_edge in out_edges:
                        mm = dcpy(out_edge.data)
                        self.copy_edge(graph,
                                       out_edge,
                                       new_src=src,
                                       new_src_conn=None,
                                       new_data=mm)

            for edge in graph.out_edges(map_entry):
                # special case: for nodes that have no data connections
                if not edge.src_conn:
                    self.copy_edge(graph, edge, new_src=global_map_entry)

            ######################################

            for edge in graph.in_edges(map_exit):
                if not edge.dst_conn:
                    # no destination connector, path ends here.
                    self.copy_edge(graph, edge, new_dst=global_map_exit)
                    continue
                # find corresponding out_edges for current edge, cannot use mmt anymore
                out_edges = [
                    oedge for oedge in graph.out_edges(map_exit)
                    if oedge.src_conn[3:] == edge.dst_conn[2:]
                ]

                # Tuple to store in/out connector port that might be created
                port_created = None

                for out_edge in out_edges:
                    dst = out_edge.dst

                    if dst in intermediate_nodes & out_nodes:

                        # create connection through global map from
                        # dst to dst_transient that was created
                        dst_transient = transients_created[dst]
                        next_conn = global_map_exit.next_connector()
                        in_conn = 'IN_' + next_conn
                        out_conn = 'OUT_' + next_conn
                        global_map_exit.add_in_connector(in_conn)
                        global_map_exit.add_out_connector(out_conn)

                        # for each transient created, create a union
                        # of outgoing memlets' subsets. this is
                        # a cheap fix to override assignments in invariant
                        # dimensions
                        union = None
                        for oe in graph.out_edges(transients_created[dst]):
                            union = subsets.union(union, oe.data.subset)
                        inner_memlet = dcpy(edge.data)
                        for i, s in enumerate(edge.data.subset):
                            if i in invariant_dimensions[dst.label]:
                                inner_memlet.subset[i] = union[i]

                        inner_memlet.other_subset = dcpy(inner_memlet.subset)

                        e_inner = graph.add_edge(dst, None, global_map_exit,
                                                 in_conn, inner_memlet)
                        mm_outer = propagate_memlet(graph, inner_memlet, global_map_entry, \
                                                    union_inner_edges = False)

                        e_outer = graph.add_edge(global_map_exit, out_conn,
                                                 dst_transient, None, mm_outer)

                        # remove edge from dst to dst_transient that was created
                        # in intermediate preparation.
                        for e in graph.out_edges(dst):
                            if e.dst == dst_transient:
                                graph.remove_edge(e)
                                break

                    # handle separately: intermediate_nodes and pure out nodes
                    # case 1: intermediate_nodes: can just redirect edge
                    if dst in intermediate_nodes:
                        self.copy_edge(graph,
                                       out_edge,
                                       new_src=edge.src,
                                       new_src_conn=edge.src_conn,
                                       new_data=dcpy(edge.data))

                    # case 2: pure out node: connect to outer array node
                    if dst in (out_nodes - intermediate_nodes):
                        if edge.dst != global_map_exit:
                            next_conn = global_map_exit.next_connector()
                            in_conn = 'IN_' + next_conn
                            out_conn = 'OUT_' + next_conn
                            global_map_exit.add_in_connector(in_conn)
                            global_map_exit.add_out_connector(out_conn)
                            self.copy_edge(graph,
                                           edge,
                                           new_dst=global_map_exit,
                                           new_dst_conn=in_conn)
                            port_created = (in_conn, out_conn)

                        else:
                            # connector pair was already created for this edge
                            in_conn, out_conn = port_created

                        # connect the global map exit to the out node
                        graph.add_edge(global_map_exit, out_conn, dst, None,
                                       dcpy(out_edge.data))

            # maps are now ready to be discarded
            # all connected edges will be finally removed as well
            graph.remove_node(map_entry)
            graph.remove_node(map_exit)

        # create a mapping from data arrays to offsets
        # for later memlet adjustments
        min_offsets = dict()

        # do one pass to augment all transient arrays
        data_intermediate = set([node.data for node in intermediate_nodes])
        for data_name in data_intermediate:
            if subgraph_contains_data[data_name]:
                all_nodes = [
                    n for n in intermediate_nodes if n.data == data_name
                ]
                in_edges = list(chain(*(graph.in_edges(n) for n in all_nodes)))

                in_edges_iter = iter(in_edges)
                in_edge = next(in_edges_iter)
                target_subset = dcpy(in_edge.data.subset)
                target_subset.pop(invariant_dimensions[data_name])
                ######
                while True:
                    try:  # executed if there are multiple in_edges
                        in_edge = next(in_edges_iter)
                        target_subset_curr = dcpy(in_edge.data.subset)
                        target_subset_curr.pop(invariant_dimensions[data_name])
                        target_subset = subsets.union(target_subset, \
                                                      target_subset_curr)
                    except StopIteration:
                        break

                min_offsets_cropped = target_subset.min_element_approx()
                # calculate the new transient array size.
                target_subset.offset(min_offsets_cropped, True)

                # re-add invariant dimensions with offset 0 and save to min_offsets
                min_offset = []
                index = 0
                for i in range(len(sdfg.data(data_name).shape)):
                    if i in invariant_dimensions[data_name]:
                        min_offset.append(0)
                    else:
                        min_offset.append(min_offsets_cropped[index])
                        index += 1
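                # e.g. a rank-3 array with invariant dimension {1} and cropped
                # offsets [o0, o2] yields min_offset == [o0, 0, o2]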

                min_offsets[data_name] = min_offset

                # determine the shape of the new array.
                new_data_shape = []
                index = 0
                for i, sz in enumerate(sdfg.data(data_name).shape):
                    if i in invariant_dimensions[data_name]:
                        new_data_shape.append(sz)
                    else:
                        new_data_shape.append(target_subset.size()[index])
                        index += 1

                new_data_strides = [
                    data._prod(new_data_shape[i + 1:])
                    for i in range(len(new_data_shape))
                ]
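                # e.g. new_data_shape == [4, 5, 6] yields contiguous C-order
                # strides [30, 6, 1]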

                new_data_totalsize = data._prod(new_data_shape)
                new_data_offset = [0] * len(new_data_shape)
                # augment.
                transient_to_transform = sdfg.data(data_name)
                transient_to_transform.shape = new_data_shape
                transient_to_transform.strides = new_data_strides
                transient_to_transform.total_size = new_data_totalsize
                transient_to_transform.offset = new_data_offset
                transient_to_transform.lifetime = dtypes.AllocationLifetime.Scope
                transient_to_transform.storage = self.transient_allocation

            else:
                # don't modify data container - array is needed outside
                # of subgraph.

                # hack: set lifetime to State if allocation has only been
                # scope so far to avoid allocation issues
                if sdfg.data(
                        data_name).lifetime == dtypes.AllocationLifetime.Scope:
                    sdfg.data(
                        data_name).lifetime = dtypes.AllocationLifetime.State

        # do one pass to adjust the memlets of in-between transients
        for node in intermediate_nodes:
            # all incoming edges to node
            in_edges = graph.in_edges(node)
            # outgoing edges going to another fused part
            out_edges = graph.out_edges(node)

            # memlets of created transient:
            # correct data names
            if node in transients_created:
                transient_in_edges = graph.in_edges(transients_created[node])
                transient_out_edges = graph.out_edges(transients_created[node])
                for edge in chain(transient_in_edges, transient_out_edges):
                    for e in graph.memlet_tree(edge):
                        if e.data.data == node.data:
                            e.data.data += '_OUT'

            # memlets of all in between transients:
            # offset memlets if array has been augmented
            if subgraph_contains_data[node.data]:
                # get min_offset
                min_offset = min_offsets[node.data]
                # re-add invariant dimensions with offset 0
                for iedge in in_edges:
                    for edge in graph.memlet_tree(iedge):
                        if edge.data.data == node.data:
                            edge.data.subset.offset(min_offset, True)
                        elif edge.data.other_subset:
                            edge.data.other_subset.offset(min_offset, True)
                    # nested SDFG: adjust arrays connected
                    if isinstance(iedge.src, nodes.NestedSDFG):
                        nsdfg = iedge.src.sdfg
                        nested_data_name = iedge.src_conn
                        self.adjust_arrays_nsdfg(sdfg, nsdfg, node.data,
                                                 nested_data_name)

                for cedge in out_edges:
                    for edge in graph.memlet_tree(cedge):
                        if edge.data.data == node.data:
                            edge.data.subset.offset(min_offset, True)
                        elif edge.data.other_subset:
                            edge.data.other_subset.offset(min_offset, True)
                        # nested SDFG: adjust arrays connected
                        if isinstance(edge.dst, nodes.NestedSDFG):
                            nsdfg = edge.dst.sdfg
                            nested_data_name = edge.dst_conn
                            self.adjust_arrays_nsdfg(sdfg, nsdfg, node.data,
                                                     nested_data_name)

                # if in_edges has several entries:
                # put other_subset into out_edges for correctness
                if len(in_edges) > 1:
                    for oedge in out_edges:
                        if oedge.dst == global_map_exit and \
                                            oedge.data.other_subset is None:
                            oedge.data.other_subset = dcpy(oedge.data.subset)
                            oedge.data.other_subset.offset(min_offset, True)

        # consolidate edges if desired
        if self.consolidate:
            consolidate_edges_scope(graph, global_map_entry)
            consolidate_edges_scope(graph, global_map_exit)

        # propagate edges adjacent to global map entry and exit
        # if desired
        if self.propagate:
            _propagate_node(graph, global_map_entry)
            _propagate_node(graph, global_map_exit)

        # create a hook for outside access to global_map
        self._global_map_entry = global_map_entry
        if self.schedule_innermaps is not None:
            for node in graph.scope_children()[global_map_entry]:
                if isinstance(node, nodes.MapEntry):
                    node.map.schedule = self.schedule_innermaps
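
A minimal usage sketch for the fuse() method above. Assumptions (not shown in the snippet): the surrounding class is exported as dace.transformation.subgraph.SubgraphFusion, its constructor accepts a SubgraphView of the state, and the outermost map entries are collected manually instead of via helpers.get_highest_scope_maps.

import dace
from dace.sdfg import nodes as sdfg_nodes
from dace.sdfg.graph import SubgraphView
from dace.transformation.dataflow import MapExpansion
from dace.transformation.subgraph import SubgraphFusion

@dace.program
def two_maps(A: dace.float64[64, 64], B: dace.float64[64, 64],
             C: dace.float64[64, 64]):
    tmp = dace.define_local([64, 64], dace.float64)
    for i, j in dace.map[0:64, 0:64]:
        tmp[i, j] = A[i, j] + B[i, j]
    for i, j in dace.map[0:64, 0:64]:
        C[i, j] = 2.0 * tmp[i, j]

sdfg = two_maps.to_sdfg()
# fuse() expects every map to be expanded into an outer and an inner map
sdfg.apply_transformations_repeated(MapExpansion)

state = sdfg.start_state
# collect the outermost (top-level) map entries of the state
map_entries = [n for n in state.nodes()
               if isinstance(n, sdfg_nodes.MapEntry)
               and state.entry_node(n) is None]

# assumption: SubgraphFusion can be constructed from a SubgraphView here
fusion = SubgraphFusion(SubgraphView(state, state.nodes()))
fusion.fuse(sdfg, state, map_entries)
sdfg.validate()
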
Beispiel #10
0
    def can_be_applied(sdfg: SDFG, subgraph: SubgraphView) -> bool:
        '''
        Fusible if
        1. Maps have the same access sets and ranges in order
        2. Any nodes in between two maps are AccessNodes only, without WCR.
           There is at most one AccessNode on any path between two maps;
           no other nodes are allowed.
        3. The exiting memlets' subsets to an intermediate node must cover
           the respective incoming memlets' subsets into the next map.
        '''
        # get graph
        graph = subgraph.graph
        for node in subgraph.nodes():
            if node not in graph.nodes():
                return False

        # next, get all the maps
        map_entries = helpers.get_highest_scope_maps(sdfg, graph, subgraph)
        map_exits = [graph.exit_node(map_entry) for map_entry in map_entries]
        maps = [map_entry.map for map_entry in map_entries]

        # 1. check whether all map ranges and indices are the same
        if len(maps) <= 1:
            return False
        base_map = maps[0]
        for map in maps:
            if map.get_param_num() != base_map.get_param_num():
                return False
            if not all(
                [p1 == p2 for (p1, p2) in zip(map.params, base_map.params)]):
                return False
            if not map.range == base_map.range:
                return False

        # 1.1 check whether all map entries have the same schedule
        schedule = map_entries[0].schedule
        if not all([entry.schedule == schedule for entry in map_entries]):
            return False

        # 2. check intermediate feasibility
        # see map_fusion.py for similar checks
        # we are being more relaxed here

        # 2.1 do some preparation work first:
        # calculate all out_nodes and intermediate_nodes
        # definition see in apply()
        intermediate_nodes = set()
        out_nodes = set()
        for map_entry, map_exit in zip(map_entries, map_exits):
            for edge in graph.out_edges(map_exit):
                current_node = edge.dst
                if len(graph.out_edges(current_node)) == 0:
                    out_nodes.add(current_node)
                else:
                    for dst_edge in graph.out_edges(current_node):
                        if dst_edge.dst in map_entries:
                            intermediate_nodes.add(current_node)
                        else:
                            out_nodes.add(current_node)

        # 2.2 topological feasibility:
        # For each intermediate and out node: it must not reach any map
        # entry other than the ones it is connected to directly
        visited = set()

        # for memoization purposes
        def visit_descendants(graph, node, visited, map_entries):
            # if we have already been at this node
            if node in visited:
                return True
            # reaching a map entry from here makes the subgraph infeasible
            if node in map_entries:
                return False
            # not necessary to add if there aren't any other in connections
            if len(graph.in_edges(node)) > 1:
                visited.add(node)
            for oedge in graph.out_edges(node):
                if not visit_descendants(graph, oedge.dst, visited,
                                         map_entries):
                    return False
            return True

        for node in intermediate_nodes | out_nodes:
            # these nodes must not lead to a map entry
            nodes_to_check = set()
            for oedge in graph.out_edges(node):
                if oedge.dst not in map_entries:
                    nodes_to_check.add(oedge.dst)

            for forbidden_node in nodes_to_check:
                if not visit_descendants(graph, forbidden_node, visited,
                                         map_entries):
                    return False

        # 2.3 memlet feasibility
        # For each intermediate node, look at whether inner adjacent
        # memlets of the exiting map cover inner adjacent memlets
        # of the next entering map.
        # We also check for any WCRs on the fly.

        for node in intermediate_nodes:
            upper_subsets = set()
            lower_subsets = set()
            # First, determine which dimensions of the memlet ranges
            # change with the map, we do not need to care about the other dimensions.
            total_dims = len(sdfg.data(node.data).shape)
            dims_to_discard = SubgraphFusion.get_invariant_dimensions(
                sdfg, graph, map_entries, map_exits, node)

            # find upper_subsets
            for in_edge in graph.in_edges(node):
                # first check for WCRs
                if in_edge.data.wcr:
                    return False
                if in_edge.src in map_exits:
                    edge = graph.memlet_path(in_edge)[-2]
                    subset_to_add = dcpy(edge.data.subset\
                                         if edge.data.data == node.data\
                                         else edge.data.other_subset)
                    subset_to_add.pop(dims_to_discard)
                    upper_subsets.add(subset_to_add)
                else:
                    raise NotImplementedError("Nodes between two maps to be "
                                              "fused with *incoming* edges "
                                              "from outside the maps are not "
                                              "allowed yet.")

            # find lower_subsets
            for out_edge in graph.out_edges(node):
                if out_edge.dst in map_entries:
                    # cannot use the memlet tree here, as there could be
                    # more than one succeeding map. Do it manually
                    for oedge in graph.out_edges(out_edge.dst):
                        if oedge.src_conn[3:] == out_edge.dst_conn[2:]:
                            subset_to_add = dcpy(oedge.data.subset \
                                                 if oedge.data.data == node.data \
                                                 else oedge.data.other_subset)
                            subset_to_add.pop(dims_to_discard)
                            lower_subsets.add(subset_to_add)

            upper_iter = iter(upper_subsets)
            union_upper = next(upper_iter)

            # TODO: add this check at a later point
            # We assume that upper_subsets for each data array
            # are contiguous
            # or do the full check if possible (intersection needed)
            '''
            # check whether subsets in upper_subsets are adjacent.
            # this is a requirement for the current implementation
            #try:
            # O(n^2*|dims|) but very small amount of subsets anyway
            try:
                for dim in range(total_dims - len(dims_to_discard)):
                    ordered_list = [(-1,-1,-1)]
                    for upper_subset in upper_subsets:
                        lo = upper_subset[dim][0]
                        hi = upper_subset[dim][1]
                        for idx,element in enumerate(ordered_list):
                            if element[0] <= lo and element[1] >= hi:
                                break
                            if element[0] > lo:
                                ordered_list.insert(idx, (lo,hi))
                    ordered_list.pop(0)


                    highest = ordered_list[0][1]
                    for i in range(len(ordered_list)):
                        if i < len(ordered_list)-1:
                            current_range = ordered_list[i]
                            if current_range[1] > highest:
                                highest = current_range[1]
                            next_range = ordered_list[i+1]
                            if highest < next_range[0] - 1:
                                return False
            except TypeError:
                #return False
            '''
            # FOR NOW: just emit a warning if unsure
            for lower_subset in lower_subsets:
                covers = False
                for upper_subset in upper_subsets:
                    if upper_subset.covers(lower_subset):
                        covers = True
                        break
                if not covers:
                    warnings.warn(
                        f"WARNING: For node {node}, please assure that "
                        "incoming memlets cover outgoing ones. Ambiguous check (WIP)."
                    )

            # now take union of upper subsets
            for subs in upper_iter:
                union_upper = subsets.union(union_upper, subs)
                if not union_upper:
                    # something went wrong using union -- we'd rather abort
                    return False

            # finally check coverage
            for lower_subset in lower_subsets:
                if not union_upper.covers(lower_subset):
                    return False

        return True
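
A small standalone illustration of the coverage test in step 3 above, under the assumption that dace.subsets is importable as used in the snippet: the union of the subsets written through the exiting maps must cover every subset read by the next entering map.

from dace import subsets

# two exiting memlets write [0:32] and [32:64] of an intermediate array
upper_subsets = [subsets.Range([(0, 31, 1)]), subsets.Range([(32, 63, 1)])]
# the next map reads [10:41]
lower_subsets = [subsets.Range([(10, 40, 1)])]

union_upper = upper_subsets[0]
for subs in upper_subsets[1:]:
    union_upper = subsets.union(union_upper, subs)  # covers indices 0..63

# fusible with respect to this intermediate node
assert all(union_upper.covers(low) for low in lower_subsets)
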
Beispiel #11
0
    def fuse(self, sdfg, graph, map_entries, do_not_override=None, **kwargs):
        """ takes the map_entries specified and tries to fuse maps.

            all maps have to be extended into outer and inner map
            (use MapExpansion as a pre-pass)

            Arrays that don't exist outside the subgraph get pushed
            into the map and their data dimension gets cropped.
            Otherwise the original array is taken.

            For every output, the respective connections are created automatically.

            :param sdfg: SDFG
            :param graph: State
            :param map_entries: Map Entries (class MapEntry) of the outer maps
                                which we want to fuse
            :param do_not_override: List of data names whose corresponding nodes
                                    are fully contained within the subgraph
                                    but should nevertheless not be
                                    augmented/transformed.
        """

        # if there are no maps, return immediately
        if len(map_entries) == 0:
            return

        do_not_override = do_not_override or []

        # get maps and map exits
        maps = [map_entry.map for map_entry in map_entries]
        map_exits = [graph.exit_node(map_entry) for map_entry in map_entries]

        # re-construct the map subgraph if necessary
        try:
            self.subgraph
        except AttributeError:
            subgraph_nodes = set()
            scope_dict = graph.scope_dict(node_to_children=True)
            for node in chain(map_entries, map_exits):
                subgraph_nodes.add(node)
                # add all border arrays
                for e in chain(graph.in_edges(node), graph.out_edges(node)):
                    subgraph_nodes.add(e.src)
                    subgraph_nodes.add(e.dst)
                try:
                    subgraph_nodes |= set(scope_dict[node])
                except KeyError:
                    pass
            self.subgraph = SubgraphView(graph, subgraph_nodes)

        # Nodes that feed data into one or several maps but receive no data from any map
        in_nodes = set()

        # Nodes that receive data from a map but from which no data flows into any map
        out_nodes = set()

        # Nodes that act as intermediate nodes - data flows from a map into them and then
        # continues along an outgoing path into another map
        intermediate_nodes = set()

        ### NOTE:
        #- in_nodes, out_nodes, intermediate_nodes refer to the configuration of the final fused map
        #- in_nodes and out_nodes are trivially disjoint
        #- Intermediate_nodes and out_nodes are not necessarily disjoint
        #- Intermediate_nodes and in_nodes are disjoint by design.
        #  There could be a node that has both incoming edges from a map exit
        #  and from outside, but it is just treated as intermediate_node and handled
        #  automatically.

        for map_entry, map_exit in zip(map_entries, map_exits):
            for edge in graph.in_edges(map_entry):
                in_nodes.add(edge.src)
            for edge in graph.out_edges(map_exit):
                current_node = edge.dst
                if len(graph.out_edges(current_node)) == 0:
                    out_nodes.add(current_node)
                else:
                    for dst_edge in graph.out_edges(current_node):
                        if dst_edge.dst in map_entries:
                            # add to intermediate_nodes
                            intermediate_nodes.add(current_node)

                        else:
                            # add to out_nodes
                            out_nodes.add(current_node)
                for e in graph.in_edges(current_node):
                    if e.src not in map_exits:
                        raise NotImplementedError(
                            "Nodes between two maps to be "
                            "fused with *incoming* edges "
                            "from outside the maps are not "
                            "allowed yet.")

        # any intermediate_nodes currently in in_nodes shouldn't be there
        in_nodes -= intermediate_nodes

        if self.debug:
            print("SubgraphFusion::In_nodes", in_nodes)
            print("SubgraphFusion::Out_nodes", out_nodes)
            print("SubgraphFusion::Intermediate_nodes", intermediate_nodes)

        # all maps are assumed to have the same params and range in order
        global_map = nodes.Map(label="outer_fused",
                               params=maps[0].params,
                               ndrange=maps[0].range)
        global_map_entry = nodes.MapEntry(global_map)
        global_map_exit = nodes.MapExit(global_map)

        schedule = map_entries[0].schedule
        global_map_entry.schedule = schedule
        graph.add_node(global_map_entry)
        graph.add_node(global_map_exit)

        # next up, for any intermediate node, find whether it only appears
        # in the subgraph or also somewhere else / as an input
        # create new transients for nodes that are in out_nodes and
        # intermediate_nodes simultaneously
        # also check which dimensions of each transient data element correspond
        # to map axes and write this information into a dict.
        node_info = self.prepare_intermediate_nodes(sdfg, graph, in_nodes, out_nodes, \
                                                    intermediate_nodes,\
                                                    map_entries, map_exits, \
                                                    do_not_override)

        (subgraph_contains_data, transients_created,
         invariant_dimensions) = node_info
        if self.debug:
            print(
                "SubgraphFusion:: {Intermediate_node: subgraph_contains_data} dict"
            )
            print(subgraph_contains_data)

        inconnectors_dict = {}
        # Dict for saving incoming nodes and their assigned connectors
        # Format: {access_node: (edge, in_conn, out_conn)}

        for map_entry, map_exit in zip(map_entries, map_exits):
            # handle inputs
            # TODO: dynamic map range -- this is fairly unrealistic in such a setting
            for edge in graph.in_edges(map_entry):
                src = edge.src
                mmt = graph.memlet_tree(edge)
                out_edges = [child.edge for child in mmt.root().children]

                if src in in_nodes:
                    in_conn = None
                    out_conn = None
                    if src in inconnectors_dict:
                        # no need to augment subset of outer edge.
                        # will do this at the end in one pass.

                        in_conn = inconnectors_dict[src][1]
                        out_conn = inconnectors_dict[src][2]
                        graph.remove_edge(edge)

                    else:
                        next_conn = global_map_entry.next_connector()
                        in_conn = 'IN_' + next_conn
                        out_conn = 'OUT_' + next_conn
                        global_map_entry.add_in_connector(in_conn)
                        global_map_entry.add_out_connector(out_conn)

                        inconnectors_dict[src] = (edge, in_conn, out_conn)

                        # reroute in edge via global_map_entry
                        self.redirect_edge(graph, edge, new_dst = global_map_entry, \
                                                        new_dst_conn = in_conn)

                    # map out edges to new map
                    for out_edge in out_edges:
                        self.redirect_edge(graph, out_edge, new_src = global_map_entry, \
                                                            new_src_conn = out_conn)

                else:
                    # connect directly
                    for out_edge in out_edges:
                        mm = dcpy(out_edge.data)
                        self.redirect_edge(graph,
                                           out_edge,
                                           new_src=src,
                                           new_data=mm)

                    graph.remove_edge(edge)

            for edge in graph.out_edges(map_entry):
                # special case: for nodes that have no data connections
                if not edge.src_conn:
                    self.redirect_edge(graph, edge, new_src=global_map_entry)

            ######################################

            for edge in graph.in_edges(map_exit):
                if not edge.dst_conn:
                    # no destination connector, path ends here.
                    self.redirect_edge(graph, edge, new_dst=global_map_exit)
                    continue
                # find corresponding out_edges for current edge, cannot use mmt anymore
                out_edges = [
                    oedge for oedge in graph.out_edges(map_exit)
                    if oedge.src_conn[3:] == edge.dst_conn[2:]
                ]

                # Tuple to store in/out connector port that might be created
                port_created = None

                for out_edge in out_edges:
                    dst = out_edge.dst

                    if dst in intermediate_nodes & out_nodes:

                        # create connection through global map from
                        # dst to dst_transient that was created
                        dst_transient = transients_created[dst]
                        next_conn = global_map_exit.next_connector()
                        in_conn = 'IN_' + next_conn
                        out_conn = 'OUT_' + next_conn
                        global_map_exit.add_in_connector(in_conn)
                        global_map_exit.add_out_connector(out_conn)

                        inner_memlet = dcpy(edge.data)
                        inner_memlet.other_subset = dcpy(edge.data.subset)

                        e_inner = graph.add_edge(dst, None, global_map_exit,
                                                 in_conn, inner_memlet)
                        mm_outer = propagate_memlet(graph, inner_memlet, global_map_entry, \
                                                    union_inner_edges = False)

                        e_outer = graph.add_edge(global_map_exit, out_conn,
                                                 dst_transient, None, mm_outer)

                        # remove edge from dst to dst_transient that was created
                        # in intermediate preparation.
                        removed = False
                        for e in graph.out_edges(dst):
                            if e.dst == dst_transient:
                                graph.remove_edge(e)
                                removed = True
                                break

                        if self.debug:
                            assert removed

                    # handle separately: intermediate_nodes and pure out nodes
                    # case 1: intermediate_nodes: can just redirect edge
                    if dst in intermediate_nodes:
                        self.redirect_edge(graph,
                                           out_edge,
                                           new_src=edge.src,
                                           new_src_conn=edge.src_conn,
                                           new_data=dcpy(edge.data))

                    # case 2: pure out node: connect to outer array node
                    if dst in (out_nodes - intermediate_nodes):
                        if edge.dst != global_map_exit:
                            next_conn = global_map_exit.next_connector()
                            in_conn = 'IN_' + next_conn
                            out_conn = 'OUT_' + next_conn
                            global_map_exit.add_in_connector(in_conn)
                            global_map_exit.add_out_connector(out_conn)
                            self.redirect_edge(graph,
                                               edge,
                                               new_dst=global_map_exit,
                                               new_dst_conn=in_conn)
                            port_created = (in_conn, out_conn)
                            #edge.dst = global_map_exit
                            #edge.dst_conn = in_conn

                        else:
                            # connector pair was already created for this edge
                            in_conn, out_conn = port_created

                        # connect the global map exit to the out node
                        graph.add_edge(global_map_exit, out_conn, dst, None,
                                       dcpy(out_edge.data))
                        graph.remove_edge(out_edge)

                # remove the edge if it has not been used by any pure out node
                if not port_created:
                    graph.remove_edge(edge)

            # maps are now ready to be discarded
            graph.remove_node(map_entry)
            graph.remove_node(map_exit)

            # end main loop.

        # create a mapping from data arrays to offsets
        # for later memlet adjustments
        min_offsets = dict()

        # do one pass to augment all transient arrays
        data_intermediate = set([node.data for node in intermediate_nodes])
        for data_name in data_intermediate:
            if subgraph_contains_data[data_name]:
                all_nodes = [
                    n for n in intermediate_nodes if n.data == data_name
                ]
                in_edges = list(chain(*(graph.in_edges(n) for n in all_nodes)))

                in_edges_iter = iter(in_edges)
                in_edge = next(in_edges_iter)
                target_subset = dcpy(in_edge.data.subset)
                target_subset.pop(invariant_dimensions[data_name])
                ######
                while True:
                    try:  # executed if there are multiple in_edges
                        in_edge = next(in_edges_iter)
                        target_subset_curr = dcpy(in_edge.data.subset)
                        target_subset_curr.pop(invariant_dimensions[data_name])
                        target_subset = subsets.union(target_subset, \
                                                      target_subset_curr)
                    except StopIteration:
                        break

                min_offsets_cropped = target_subset.min_element_approx()
                # calculate the new transient array size.
                target_subset.offset(min_offsets_cropped, True)

                # re-add invariant dimensions with offset 0 and save to min_offsets
                min_offset = []
                index = 0
                for i in range(len(sdfg.data(data_name).shape)):
                    if i in invariant_dimensions[data_name]:
                        min_offset.append(0)
                    else:
                        min_offset.append(min_offsets_cropped[index])
                        index += 1

                min_offsets[data_name] = min_offset

                # determine the shape of the new array.
                new_data_shape = []
                index = 0
                for i, sz in enumerate(sdfg.data(data_name).shape):
                    if i in invariant_dimensions[data_name]:
                        new_data_shape.append(sz)
                    else:
                        new_data_shape.append(target_subset.size()[index])
                        index += 1

                new_data_strides = [
                    data._prod(new_data_shape[i + 1:])
                    for i in range(len(new_data_shape))
                ]

                new_data_totalsize = data._prod(new_data_shape)
                new_data_offset = [0] * len(new_data_shape)
                # augment.
                transient_to_transform = sdfg.data(data_name)
                transient_to_transform.shape = new_data_shape
                transient_to_transform.strides = new_data_strides
                transient_to_transform.total_size = new_data_totalsize
                transient_to_transform.offset = new_data_offset
                transient_to_transform.lifetime = dtypes.AllocationLifetime.Scope
                transient_to_transform.storage = self.transient_allocation

            else:
                # don't modify data container - array is needed outside
                # of subgraph.

                # hack: set lifetime to State if allocation has only been
                # scope so far to avoid allocation issues
                if sdfg.data(
                        data_name).lifetime == dtypes.AllocationLifetime.Scope:
                    sdfg.data(
                        data_name).lifetime = dtypes.AllocationLifetime.State

        # do one pass to adjust the memlets of in-between transients
        for node in intermediate_nodes:
            # all incoming edges to node
            in_edges = graph.in_edges(node)
            # outgoing edges going to another fused part
            inter_edges = []
            # outgoing edges that exit global map
            out_edges = []
            for e in graph.out_edges(node):
                if e.dst == global_map_exit:
                    out_edges.append(e)
                else:
                    inter_edges.append(e)

            # offset memlets where necessary
            if subgraph_contains_data[node.data]:
                # get min_offset
                min_offset = min_offsets[node.data]
                # re-add invariant dimensions with offset 0
                for iedge in in_edges:
                    for edge in graph.memlet_tree(iedge):
                        if edge.data.data == node.data:
                            edge.data.subset.offset(min_offset, True)
                        elif edge.data.other_subset:
                            edge.data.other_subset.offset(min_offset, True)

                for cedge in inter_edges:
                    for edge in graph.memlet_tree(cedge):
                        if edge.data.data == node.data:
                            edge.data.subset.offset(min_offset, True)
                        elif edge.data.other_subset:
                            edge.data.other_subset.offset(min_offset, True)

                # if in_edges has several entries:
                # put other_subset into out_edges for correctness
                if len(in_edges) > 1:
                    for oedge in out_edges:
                        oedge.data.other_subset = dcpy(oedge.data.subset)
                        oedge.data.other_subset.offset(min_offset, True)

            # also correct memlets of created transient
            if node in transients_created:
                transient_in_edges = graph.in_edges(transients_created[node])
                transient_out_edges = graph.out_edges(transients_created[node])
                for edge in chain(transient_in_edges, transient_out_edges):
                    for e in graph.memlet_tree(edge):
                        if e.data.data == node.data:
                            e.data.data += '_OUT'

        # do one last pass to correct outside memlets adjacent to global map
        for out_connector in global_map_entry.out_connectors:
            # find corresponding in_connector
            # and the in-connecting edge
            in_connector = 'IN' + out_connector[3:]
            for iedge in graph.in_edges(global_map_entry):
                if iedge.dst_conn == in_connector:
                    in_edge = iedge

            # find corresponding out_connector
            # and all out-connecting edges that belong to it
            # count them
            oedge_counter = 0
            for oedge in graph.out_edges(global_map_entry):
                if oedge.src_conn == out_connector:
                    out_edge = oedge
                    oedge_counter += 1

            # do memlet propagation
            # if there are several out edges, else there is no need

            if oedge_counter > 1:
                memlet_out = propagate_memlet(dfg_state=graph,
                                              memlet=out_edge.data,
                                              scope_node=global_map_entry,
                                              union_inner_edges=True)
                # override number of accesses
                in_edge.data.volume = memlet_out.volume
                in_edge.data.subset = memlet_out.subset

        # create a hook for outside access to global_map
        self._global_map_entry = global_map_entry
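
A side note on the connector bookkeeping used in both fuse() variants above: map entries and exits pair their connectors as 'IN_<n>' / 'OUT_<n>', so the slicing idiom src_conn[3:] == dst_conn[2:] matches an 'OUT_' connector to its 'IN_' counterpart by comparing the shared '_<n>' suffix. A tiny sketch in plain Python (no DaCe required):

def connectors_match(src_conn: str, dst_conn: str) -> bool:
    # 'OUT_7'[3:] == '_7' and 'IN_7'[2:] == '_7'
    return (src_conn.startswith('OUT_') and dst_conn.startswith('IN_')
            and src_conn[3:] == dst_conn[2:])

assert connectors_match('OUT_7', 'IN_7')
assert not connectors_match('OUT_7', 'IN_8')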