Ejemplo n.º 1
0
    def apply(self, sdfg: SDFG):
        """Interchange the matched outer and inner map scopes in place.

        The connectors of the two map entries (and of the two map exits) are
        swapped, the edges that directly link outer to inner entry and inner
        to outer exit are removed, all remaining edges are rerouted through
        ``sdutil.change_edge_dest`` / ``sdutil.change_edge_src``, and the
        removed edges are re-added in the opposite direction. Finally, the
        subsets of the re-added edges are recomputed with
        ``propagate_memlet``.

        :param sdfg: SDFG that contains the state identified by
                     ``self.state_id``.
        """
        state: SDFGState = sdfg.nodes()[self.state_id]

        # Resolve the four scope nodes taking part in the interchange.
        outer_entry = state.nodes()[self.subgraph[
            MapInterchange.outer_map_entry]]
        inner_entry = state.nodes()[self.subgraph[
            MapInterchange.inner_map_entry]]
        inner_exit = state.exit_node(inner_entry)
        outer_exit = state.exit_node(outer_entry)

        # Exchange connector sets pairwise: entries first, then exits.
        for outer, inner in ((outer_entry, inner_entry),
                             (outer_exit, inner_exit)):
            outer.in_connectors, inner.in_connectors = \
                inner.in_connectors, outer.in_connectors
            outer.out_connectors, inner.out_connectors = \
                inner.out_connectors, outer.out_connectors

        # Detach the edges that directly link the two entries / two exits;
        # they are re-added with reversed direction further below.
        entry_links = state.edges_between(outer_entry, inner_entry)
        exit_links = state.edges_between(inner_exit, outer_exit)
        for link in entry_links + exit_links:
            state.remove_edge(link)

        # Reroute every remaining edge so that the two scopes trade places.
        sdutil.change_edge_dest(state, outer_entry, inner_entry)
        sdutil.change_edge_src(state, inner_entry, outer_entry)
        sdutil.change_edge_dest(state, inner_exit, outer_exit)
        sdutil.change_edge_src(state, outer_exit, inner_exit)

        # Re-create the linking edges with source and destination swapped
        # (connector names and memlet objects are reused as they were).
        new_entry_edges = [
            state.add_edge(link.dst, link.src_conn, link.src, link.dst_conn,
                           link.data) for link in entry_links
        ]
        new_exit_edges = [
            state.add_edge(link.dst, link.src_conn, link.src, link.dst_conn,
                           link.data) for link in exit_links
        ]

        def _repropagate(edges, neighbour_offset, scope_node):
            # Recompute the subset of each edge from its neighbour on the
            # memlet path (the next edge for entries, the previous for exits).
            for new_edge in edges:
                path = state.memlet_path(new_edge)
                position = next(i for i, candidate in enumerate(path)
                                if new_edge is candidate)
                neighbour = path[position + neighbour_offset]
                new_edge.data.subset = propagate_memlet(
                    state, neighbour.data, scope_node, True).subset

        # Repropagate memlets in the modified region.
        _repropagate(new_entry_edges, 1, outer_entry)
        _repropagate(new_exit_edges, -1, outer_exit)
Ejemplo n.º 2
0
    def apply(self, sdfg):
        """Merge duplicate input access nodes of a map entry into one.

        All source access nodes (no incoming edges) that refer to the same
        data as the matched array and feed an ``IN_*`` connector of the
        matched map entry are removed. The inner edges that used their
        connectors are redirected to the connector of the matched array, the
        now-unused connectors are removed from the map entry, and the memlet
        on the surviving outer edge is re-propagated.

        :param sdfg: SDFG that contains the state identified by
                     ``self.state_id``.
        :raises StopIteration: if no edge leads from the matched array to the
                               map entry, or no inner edge uses the resulting
                               connector.
        """
        graph = sdfg.node(self.state_id)
        array = graph.node(self.subgraph[InMergeArrays._array1])
        map_entry = graph.node(self.subgraph[InMergeArrays._map_entry])
        map_edge = next(e for e in graph.out_edges(array)
                        if e.dst == map_entry)
        # Connector name without its 'IN_' prefix.
        result_connector = map_edge.dst_conn[3:]

        # Find all other incoming access nodes without incoming edges
        source_edges = [
            e for e in graph.in_edges(map_entry)
            if isinstance(e.src, nodes.AccessNode) and e.src.data == array.data
            and e.src != array and e.dst_conn and e.dst_conn.startswith('IN_')
            and graph.in_degree(e.src) == 0
        ]

        # Modify connectors to point to first array
        connectors_to_remove = {e.dst_conn[3:] for e in source_edges}
        for inner_edge in graph.out_edges(map_entry):
            src_conn = inner_edge.src_conn
            # Edges without a source connector carry no data through the
            # scope and must be skipped (slicing None would raise TypeError).
            if src_conn is not None and src_conn[4:] in connectors_to_remove:
                inner_edge._src_conn = 'OUT_' + result_connector

        # Remove other nodes from state
        graph.remove_nodes_from(set(e.src for e in source_edges))

        # Remove connectors from scope entry
        for c in connectors_to_remove:
            map_entry.remove_in_connector('IN_' + c)
            map_entry.remove_out_connector('OUT_' + c)

        # Re-propagate memlets
        edge_to_propagate = next(
            e for e in graph.out_edges(map_entry)
            if e.src_conn is not None and e.src_conn[4:] == result_connector)
        map_edge._data = propagate_memlet(dfg_state=graph,
                                          memlet=edge_to_propagate.data,
                                          scope_node=map_entry,
                                          union_inner_edges=True)
Ejemplo n.º 3
0
    def apply(self, graph, sdfg):
        """Merge duplicate output access nodes of a map exit into one.

        All sink access nodes (no outgoing edges) that refer to the same data
        as ``self.array1`` and are fed by an ``OUT_*`` connector of
        ``self.map_exit`` are removed. The inner edges that used their
        connectors are redirected to the connector of ``self.array1``, the
        now-unused connectors are removed from the map exit, and the memlet
        on the surviving outer edge is re-propagated.

        :param graph: State that contains the matched nodes.
        :param sdfg: SDFG that contains ``graph`` (not used by this method).
        :raises StopIteration: if no edge leads from the map exit to
                               ``self.array1``, or no inner edge uses the
                               resulting connector.
        """
        array = self.array1
        map_exit = self.map_exit
        map_edge = next(e for e in graph.in_edges(array) if e.src == map_exit)
        # Connector name without its 'OUT_' prefix.
        result_connector = map_edge.src_conn[4:]

        # Find all other outgoing access nodes without outgoing edges
        dst_edges = [
            e for e in graph.out_edges(map_exit)
            if isinstance(e.dst, nodes.AccessNode) and e.dst.data == array.data
            and e.dst != array and e.src_conn and e.src_conn.startswith('OUT_')
            and graph.out_degree(e.dst) == 0
        ]

        # Modify connectors to point to first array
        connectors_to_remove = {e.src_conn[4:] for e in dst_edges}
        for inner_edge in graph.in_edges(map_exit):
            dst_conn = inner_edge.dst_conn
            # Edges without a destination connector carry no data through the
            # scope and must be skipped (slicing None would raise TypeError).
            if dst_conn is not None and dst_conn[3:] in connectors_to_remove:
                inner_edge.dst_conn = 'IN_' + result_connector

        # Remove other nodes from state
        graph.remove_nodes_from(set(e.dst for e in dst_edges))

        # Remove connectors from scope exit
        for c in connectors_to_remove:
            map_exit.remove_in_connector('IN_' + c)
            map_exit.remove_out_connector('OUT_' + c)

        # Re-propagate memlets
        edge_to_propagate = next(
            e for e in graph.in_edges(map_exit)
            if e.dst_conn is not None and e.dst_conn[3:] == result_connector)
        map_edge._data = propagate_memlet(dfg_state=graph,
                                          memlet=edge_to_propagate.data,
                                          scope_node=map_exit,
                                          union_inner_edges=True)
Ejemplo n.º 4
0
    def fuse(self, sdfg, graph, map_entries, do_not_override=None, **kwargs):
        """ Takes the map_entries specified and tries to fuse maps.

            All maps have to be extended into outer and inner map
            (use MapExpansion as a pre-pass).

            Arrays that don't exist outside the subgraph get pushed
            into the map and their data dimension gets cropped.
            Otherwise the original array is taken.

            For every output, respective connections are created automatically.

            The given map entries and their exits are removed from ``graph``
            and replaced by one new map labeled "outer_fused"; its entry node
            is stored in ``self._global_map_entry``.

            :param sdfg: SDFG
            :param graph: State
            :param map_entries: Map Entries (class MapEntry) of the outer maps
                                which we want to fuse
            :param do_not_override: List of data names whose corresponding nodes
                                    are fully contained within the subgraph
                                    but should not be augmented/transformed
                                    nevertheless.
            :param kwargs: Not used by this method.
            :return: None; ``graph`` and data descriptors of ``sdfg`` are
                     modified in place.
        """

        # if there are no maps, return immediately
        if len(map_entries) == 0:
            return

        do_not_override = do_not_override or []

        # get maps and map exits
        maps = [map_entry.map for map_entry in map_entries]
        map_exits = [graph.exit_node(map_entry) for map_entry in map_entries]

        # See function documentation for an explanation of these variables
        node_config = SubgraphFusion.get_adjacent_nodes(sdfg, graph,
                                                        map_entries)
        (in_nodes, intermediate_nodes, out_nodes) = node_config

        if self.debug:
            print("SubgraphFusion::In_nodes", in_nodes)
            print("SubgraphFusion::Out_nodes", out_nodes)
            print("SubgraphFusion::Intermediate_nodes", intermediate_nodes)

        # all maps are assumed to have the same params and range in order
        global_map = nodes.Map(label="outer_fused",
                               params=maps[0].params,
                               ndrange=maps[0].range)
        global_map_entry = nodes.MapEntry(global_map)
        global_map_exit = nodes.MapExit(global_map)

        schedule = map_entries[0].schedule
        global_map_entry.schedule = schedule
        graph.add_node(global_map_entry)
        graph.add_node(global_map_exit)

        # next up, for any intermediate node, find whether it only appears
        # in the subgraph or also somewhere else / as an input
        # create new transients for nodes that are in out_nodes and
        # intermediate_nodes simultaneously
        # also check which dimensions of each transient data element correspond
        # to map axes and write this information into a dict.
        node_info = self.prepare_intermediate_nodes(sdfg, graph, in_nodes,
                                                    out_nodes,
                                                    intermediate_nodes,
                                                    map_entries, map_exits,
                                                    do_not_override)

        (subgraph_contains_data, transients_created,
         invariant_dimensions) = node_info
        if self.debug:
            print(
                "SubgraphFusion:: {Intermediate_node: subgraph_contains_data} dict"
            )
            print(subgraph_contains_data)

        inconnectors_dict = {}
        # Dict for saving incoming nodes and their assigned connectors
        # Format: {access_node: (edge, in_conn, out_conn)}

        for map_entry, map_exit in zip(map_entries, map_exits):
            # handle inputs
            # TODO: dynamic map range -- this is fairly unrealistic in such a setting
            for edge in graph.in_edges(map_entry):
                src = edge.src
                mmt = graph.memlet_tree(edge)
                out_edges = [child.edge for child in mmt.root().children]

                if src in in_nodes:
                    in_conn = None
                    out_conn = None
                    if src in inconnectors_dict:
                        # no need to augment subset of outer edge.
                        # will do this at the end in one pass.

                        in_conn = inconnectors_dict[src][1]
                        out_conn = inconnectors_dict[src][2]

                    else:
                        next_conn = global_map_entry.next_connector()
                        in_conn = 'IN_' + next_conn
                        out_conn = 'OUT_' + next_conn
                        global_map_entry.add_in_connector(in_conn)
                        global_map_entry.add_out_connector(out_conn)

                        inconnectors_dict[src] = (edge, in_conn, out_conn)

                        # reroute in edge via global_map_entry
                        self.copy_edge(graph,
                                       edge,
                                       new_dst=global_map_entry,
                                       new_dst_conn=in_conn)

                    # map out edges to new map
                    for out_edge in out_edges:
                        self.copy_edge(graph,
                                       out_edge,
                                       new_src=global_map_entry,
                                       new_src_conn=out_conn)

                else:
                    # connect directly
                    for out_edge in out_edges:
                        mm = dcpy(out_edge.data)
                        self.copy_edge(graph,
                                       out_edge,
                                       new_src=src,
                                       new_src_conn=None,
                                       new_data=mm)

            for edge in graph.out_edges(map_entry):
                # special case: for nodes that have no data connections
                if not edge.src_conn:
                    self.copy_edge(graph, edge, new_src=global_map_entry)

            ######################################

            for edge in graph.in_edges(map_exit):
                if not edge.dst_conn:
                    # no destination connector, path ends here.
                    self.copy_edge(graph, edge, new_dst=global_map_exit)
                    continue
                # find corresponding out_edges for current edge, cannot use mmt anymore
                # ('OUT_x'[3:] and 'IN_x'[2:] both yield '_x')
                out_edges = [
                    oedge for oedge in graph.out_edges(map_exit)
                    if oedge.src_conn[3:] == edge.dst_conn[2:]
                ]

                # Tuple to store in/out connector port that might be created
                port_created = None

                for out_edge in out_edges:
                    dst = out_edge.dst

                    if dst in intermediate_nodes & out_nodes:

                        # create connection through global map from
                        # dst to dst_transient that was created
                        dst_transient = transients_created[dst]
                        next_conn = global_map_exit.next_connector()
                        in_conn = 'IN_' + next_conn
                        out_conn = 'OUT_' + next_conn
                        global_map_exit.add_in_connector(in_conn)
                        global_map_exit.add_out_connector(out_conn)

                        # for each transient created, create a union
                        # of outgoing memlets' subsets. this is
                        # a cheap fix to override assignments in invariant
                        # dimensions
                        union = None
                        for oe in graph.out_edges(transients_created[dst]):
                            union = subsets.union(union, oe.data.subset)
                        inner_memlet = dcpy(edge.data)
                        for i, _ in enumerate(edge.data.subset):
                            if i in invariant_dimensions[dst.label]:
                                inner_memlet.subset[i] = union[i]

                        inner_memlet.other_subset = dcpy(inner_memlet.subset)

                        graph.add_edge(dst, None, global_map_exit, in_conn,
                                       inner_memlet)
                        mm_outer = propagate_memlet(graph,
                                                    inner_memlet,
                                                    global_map_entry,
                                                    union_inner_edges=False)

                        graph.add_edge(global_map_exit, out_conn,
                                       dst_transient, None, mm_outer)

                        # remove edge from dst to dst_transient that was created
                        # in intermediate preparation.
                        for e in graph.out_edges(dst):
                            if e.dst == dst_transient:
                                graph.remove_edge(e)
                                break

                    # handle separately: intermediate_nodes and pure out nodes
                    # case 1: intermediate_nodes: can just redirect edge
                    if dst in intermediate_nodes:
                        self.copy_edge(graph,
                                       out_edge,
                                       new_src=edge.src,
                                       new_src_conn=edge.src_conn,
                                       new_data=dcpy(edge.data))

                    # case 2: pure out node: connect to outer array node
                    if dst in (out_nodes - intermediate_nodes):
                        # NOTE(review): edge is taken from in_edges(map_exit),
                        # so this condition looks like it always holds unless
                        # copy_edge mutates the edge in place -- confirm.
                        if edge.dst != global_map_exit:
                            next_conn = global_map_exit.next_connector()
                            in_conn = 'IN_' + next_conn
                            out_conn = 'OUT_' + next_conn
                            global_map_exit.add_in_connector(in_conn)
                            global_map_exit.add_out_connector(out_conn)
                            self.copy_edge(graph,
                                           edge,
                                           new_dst=global_map_exit,
                                           new_dst_conn=in_conn)
                            port_created = (in_conn, out_conn)

                        else:
                            # reuse the port stored for this edge; it is a
                            # plain (in_conn, out_conn) tuple
                            in_conn, out_conn = port_created

                        # map
                        graph.add_edge(global_map_exit, out_conn, dst, None,
                                       dcpy(out_edge.data))

            # maps are now ready to be discarded
            # all connected edges will be finally removed as well
            graph.remove_node(map_entry)
            graph.remove_node(map_exit)

        # create a mapping from data arrays to offsets
        # for memlet adjustments later
        min_offsets = dict()

        # do one pass to augment all transient arrays
        data_intermediate = {node.data for node in intermediate_nodes}
        for data_name in data_intermediate:
            if subgraph_contains_data[data_name]:
                all_nodes = [
                    n for n in intermediate_nodes if n.data == data_name
                ]
                in_edges = list(chain(*(graph.in_edges(n) for n in all_nodes)))

                # union of the subsets written by all incoming edges, with
                # the invariant dimensions removed
                in_edges_iter = iter(in_edges)
                in_edge = next(in_edges_iter)
                target_subset = dcpy(in_edge.data.subset)
                target_subset.pop(invariant_dimensions[data_name])
                # executed if there are multiple in_edges
                for in_edge in in_edges_iter:
                    target_subset_curr = dcpy(in_edge.data.subset)
                    target_subset_curr.pop(invariant_dimensions[data_name])
                    target_subset = subsets.union(target_subset,
                                                  target_subset_curr)

                min_offsets_cropped = target_subset.min_element_approx()
                # calculate the new transient array size.
                target_subset.offset(min_offsets_cropped, True)

                # re-add invariant dimensions with offset 0 and save to min_offsets
                min_offset = []
                index = 0
                for i in range(len(sdfg.data(data_name).shape)):
                    if i in invariant_dimensions[data_name]:
                        min_offset.append(0)
                    else:
                        min_offset.append(min_offsets_cropped[index])
                        index += 1

                min_offsets[data_name] = min_offset

                # determine the shape of the new array.
                new_data_shape = []
                index = 0
                for i, sz in enumerate(sdfg.data(data_name).shape):
                    if i in invariant_dimensions[data_name]:
                        new_data_shape.append(sz)
                    else:
                        new_data_shape.append(target_subset.size()[index])
                        index += 1

                # stride of dimension i is the product of all later extents
                new_data_strides = [
                    data._prod(new_data_shape[i + 1:])
                    for i in range(len(new_data_shape))
                ]

                new_data_totalsize = data._prod(new_data_shape)
                new_data_offset = [0] * len(new_data_shape)
                # augment.
                transient_to_transform = sdfg.data(data_name)
                transient_to_transform.shape = new_data_shape
                transient_to_transform.strides = new_data_strides
                transient_to_transform.total_size = new_data_totalsize
                transient_to_transform.offset = new_data_offset
                transient_to_transform.lifetime = dtypes.AllocationLifetime.Scope
                transient_to_transform.storage = self.transient_allocation

            else:
                # don't modify data container - array is needed outside
                # of subgraph.

                # hack: set lifetime to State if allocation has only been
                # scope so far to avoid allocation issues
                if sdfg.data(
                        data_name).lifetime == dtypes.AllocationLifetime.Scope:
                    sdfg.data(
                        data_name).lifetime = dtypes.AllocationLifetime.State

        # do one pass to adjust the memlets of in-between transients
        for node in intermediate_nodes:
            # all incoming edges to node
            in_edges = graph.in_edges(node)
            # outgoing edges going to another fused part
            out_edges = graph.out_edges(node)

            # memlets of created transient:
            # correct data names
            if node in transients_created:
                transient_in_edges = graph.in_edges(transients_created[node])
                transient_out_edges = graph.out_edges(transients_created[node])
                for edge in chain(transient_in_edges, transient_out_edges):
                    for e in graph.memlet_tree(edge):
                        if e.data.data == node.data:
                            e.data.data += '_OUT'

            # memlets of all in between transients:
            # offset memlets if array has been augmented
            if subgraph_contains_data[node.data]:
                # get min_offset
                min_offset = min_offsets[node.data]
                # re-add invariant dimensions with offset 0
                for iedge in in_edges:
                    for edge in graph.memlet_tree(iedge):
                        if edge.data.data == node.data:
                            edge.data.subset.offset(min_offset, True)
                        elif edge.data.other_subset:
                            edge.data.other_subset.offset(min_offset, True)
                    # nested SDFG: adjust arrays connected
                    if isinstance(iedge.src, nodes.NestedSDFG):
                        nsdfg = iedge.src.sdfg
                        # use the connector of the edge whose source was just
                        # checked, not a variable left over from the loop above
                        nested_data_name = iedge.src_conn
                        self.adjust_arrays_nsdfg(sdfg, nsdfg, node.data,
                                                 nested_data_name)

                for cedge in out_edges:
                    for edge in graph.memlet_tree(cedge):
                        if edge.data.data == node.data:
                            edge.data.subset.offset(min_offset, True)
                        elif edge.data.other_subset:
                            edge.data.other_subset.offset(min_offset, True)
                        # nested SDFG: adjust arrays connected
                        if isinstance(edge.dst, nodes.NestedSDFG):
                            nsdfg = edge.dst.sdfg
                            nested_data_name = edge.dst_conn
                            self.adjust_arrays_nsdfg(sdfg, nsdfg, node.data,
                                                     nested_data_name)

                # if in_edges has several entries:
                # put other_subset into out_edges for correctness
                if len(in_edges) > 1:
                    for oedge in out_edges:
                        if oedge.dst == global_map_exit and \
                                            oedge.data.other_subset is None:
                            oedge.data.other_subset = dcpy(oedge.data.subset)
                            oedge.data.other_subset.offset(min_offset, True)

        # consolidate edges if desired
        if self.consolidate:
            consolidate_edges_scope(graph, global_map_entry)
            consolidate_edges_scope(graph, global_map_exit)

        # propagate edges adjacent to global map entry and exit
        # if desired
        if self.propagate:
            _propagate_node(graph, global_map_entry)
            _propagate_node(graph, global_map_exit)

        # create a hook for outside access to global_map
        self._global_map_entry = global_map_entry
        if self.schedule_innermaps is not None:
            for node in graph.scope_children()[global_map_entry]:
                if isinstance(node, nodes.MapEntry):
                    node.map.schedule = self.schedule_innermaps
Ejemplo n.º 5
0
    def fuse(self, sdfg, graph, map_entries, do_not_override=[], **kwargs):
        """ takes the map_entries specified and tries to fuse maps.

            all maps have to be extended into outer and inner map
            (use MapExpansion as a pre-pass)

            Arrays that don't exist outside the subgraph get pushed
            into the map and their data dimension gets cropped.
            Otherwise the original array is taken.

            For every output respective connections are crated automatically.

            :param sdfg: SDFG
            :param graph: State
            :param map_entries: Map Entries (class MapEntry) of the outer maps
                                which we want to fuse
            :param do_not_override: List of data names whose corresponding nodes
                                    are fully contained within the subgraph
                                    but should not be augmented/transformed
                                    nevertheless.
        """

        # if there are no maps, return immediately
        if len(map_entries) == 0:
            return

        # get maps and map exits
        maps = [map_entry.map for map_entry in map_entries]
        map_exits = [graph.exit_node(map_entry) for map_entry in map_entries]

        # re-construct the map subgraph if necessary
        try:
            self.subgraph
        except AttributeError:
            subgraph_nodes = set()
            scope_dict = graph.scope_dict(node_to_children=True)
            for node in chain(map_entries, map_exits):
                subgraph_nodes.add(node)
                # add all border arrays
                for e in chain(graph.in_edges(node), graph.out_edges(node)):
                    subgraph_nodes.add(e.src)
                    subgraph_nodes.add(e.dst)
                try:
                    subgraph_nodes |= set(scope_dict[node])
                except KeyError:
                    pass
            self.subgraph = SubgraphView(graph, subgraph_nodes)

        # Nodes that flow into one or several maps but no data is flowed to them from any map
        in_nodes = set()

        # Nodes into which data is flowed but that no data flows into any map from them
        out_nodes = set()

        # Nodes that act as intermediate node - data flows from a map into them and then there
        # is an outgoing path into another map
        intermediate_nodes = set()

        ### NOTE:
        #- in_nodes, out_nodes, intermediate_nodes refer to the configuration of the final fused map
        #- in_nodes and out_nodes are trivially disjoint
        #- Intermediate_nodes and out_nodes are not necessarily disjoint
        #- Intermediate_nodes and in_nodes are disjoint by design.
        #  There could be a node that has both incoming edges from a map exit
        #  and from outside, but it is just treated as intermediate_node and handled
        #  automatically.

        for map_entry, map_exit in zip(map_entries, map_exits):
            for edge in graph.in_edges(map_entry):
                in_nodes.add(edge.src)
            for edge in graph.out_edges(map_exit):
                current_node = edge.dst
                if len(graph.out_edges(current_node)) == 0:
                    out_nodes.add(current_node)
                else:
                    for dst_edge in graph.out_edges(current_node):
                        if dst_edge.dst in map_entries:
                            # add to intermediate_nodes
                            intermediate_nodes.add(current_node)

                        else:
                            # add to out_nodes
                            out_nodes.add(current_node)
                for e in graph.in_edges(current_node):
                    if e.src not in map_exits:
                        raise NotImplementedError(
                            "Nodes between two maps to be"
                            "fused with *incoming* edges"
                            "from outside the maps are not"
                            "allowed yet.")

        # any intermediate_nodes currently in in_nodes shouldnt be there
        in_nodes -= intermediate_nodes

        if self.debug:
            print("SubgraphFusion::In_nodes", in_nodes)
            print("SubgraphFusion::Out_nodes", out_nodes)
            print("SubgraphFusion::Intermediate_nodes", intermediate_nodes)

        # all maps are assumed to have the same params and range in order
        global_map = nodes.Map(label="outer_fused",
                               params=maps[0].params,
                               ndrange=maps[0].range)
        global_map_entry = nodes.MapEntry(global_map)
        global_map_exit = nodes.MapExit(global_map)

        schedule = map_entries[0].schedule
        global_map_entry.schedule = schedule
        graph.add_node(global_map_entry)
        graph.add_node(global_map_exit)

        # next up, for any intermediate node, find whether it only appears
        # in the subgraph or also somewhere else / as an input
        # create new transients for nodes that are in out_nodes and
        # intermediate_nodes simultaneously
        # also check which dimensions of each transient data element correspond
        # to map axes and write this information into a dict.
        node_info = self.prepare_intermediate_nodes(sdfg, graph, in_nodes, out_nodes, \
                                                    intermediate_nodes,\
                                                    map_entries, map_exits, \
                                                    do_not_override)

        (subgraph_contains_data, transients_created,
         invariant_dimensions) = node_info
        if self.debug:
            print(
                "SubgraphFusion:: {Intermediate_node: subgraph_contains_data} dict"
            )
            print(subgraph_contains_data)

        inconnectors_dict = {}
        # Dict for saving incoming nodes and their assigned connectors
        # Format: {access_node: (edge, in_conn, out_conn)}

        for map_entry, map_exit in zip(map_entries, map_exits):
            # handle inputs
            # TODO: dynamic map range -- this is fairly unrealistic in such a setting
            for edge in graph.in_edges(map_entry):
                src = edge.src
                mmt = graph.memlet_tree(edge)
                out_edges = [child.edge for child in mmt.root().children]

                if src in in_nodes:
                    in_conn = None
                    out_conn = None
                    if src in inconnectors_dict:
                        # no need to augment subset of outer edge.
                        # will do this at the end in one pass.

                        in_conn = inconnectors_dict[src][1]
                        out_conn = inconnectors_dict[src][2]
                        graph.remove_edge(edge)

                    else:
                        next_conn = global_map_entry.next_connector()
                        in_conn = 'IN_' + next_conn
                        out_conn = 'OUT_' + next_conn
                        global_map_entry.add_in_connector(in_conn)
                        global_map_entry.add_out_connector(out_conn)

                        inconnectors_dict[src] = (edge, in_conn, out_conn)

                        # reroute in edge via global_map_entry
                        self.redirect_edge(graph, edge, new_dst = global_map_entry, \
                                                        new_dst_conn = in_conn)

                    # map out edges to new map
                    for out_edge in out_edges:
                        self.redirect_edge(graph, out_edge, new_src = global_map_entry, \
                                                            new_src_conn = out_conn)

                else:
                    # connect directly
                    for out_edge in out_edges:
                        mm = dcpy(out_edge.data)
                        self.redirect_edge(graph,
                                           out_edge,
                                           new_src=src,
                                           new_data=mm)

                    graph.remove_edge(edge)

            for edge in graph.out_edges(map_entry):
                # special case: for nodes that have no data connections
                if not edge.src_conn:
                    self.redirect_edge(graph, edge, new_src=global_map_entry)

            ######################################

            for edge in graph.in_edges(map_exit):
                if not edge.dst_conn:
                    # no destination connector, path ends here.
                    self.redirect_edge(graph, edge, new_dst=global_map_exit)
                    continue
                # find corresponding out_edges for current edge, cannot use mmt anymore
                out_edges = [
                    oedge for oedge in graph.out_edges(map_exit)
                    if oedge.src_conn[3:] == edge.dst_conn[2:]
                ]

                # Tuple to store in/out connector port that might be created
                port_created = None

                for out_edge in out_edges:
                    dst = out_edge.dst

                    if dst in intermediate_nodes & out_nodes:

                        # create connection through global map from
                        # dst to dst_transient that was created
                        dst_transient = transients_created[dst]
                        next_conn = global_map_exit.next_connector()
                        in_conn = 'IN_' + next_conn
                        out_conn = 'OUT_' + next_conn
                        global_map_exit.add_in_connector(in_conn)
                        global_map_exit.add_out_connector(out_conn)

                        inner_memlet = dcpy(edge.data)
                        inner_memlet.other_subset = dcpy(edge.data.subset)

                        e_inner = graph.add_edge(dst, None, global_map_exit,
                                                 in_conn, inner_memlet)
                        mm_outer = propagate_memlet(graph, inner_memlet, global_map_entry, \
                                                    union_inner_edges = False)

                        e_outer = graph.add_edge(global_map_exit, out_conn,
                                                 dst_transient, None, mm_outer)

                        # remove edge from dst to dst_transient that was created
                        # in intermediate preparation.
                        for e in graph.out_edges(dst):
                            if e.dst == dst_transient:
                                graph.remove_edge(e)
                                removed = True
                                break

                        if self.debug:
                            assert removed == True

                    # handle separately: intermediate_nodes and pure out nodes
                    # case 1: intermediate_nodes: can just redirect edge
                    if dst in intermediate_nodes:
                        self.redirect_edge(graph,
                                           out_edge,
                                           new_src=edge.src,
                                           new_src_conn=edge.src_conn,
                                           new_data=dcpy(edge.data))

                    # case 2: pure out node: connect to outer array node
                    if dst in (out_nodes - intermediate_nodes):
                        if edge.dst != global_map_exit:
                            next_conn = global_map_exit.next_connector()
                            in_conn = 'IN_' + next_conn
                            out_conn = 'OUT_' + next_conn
                            global_map_exit.add_in_connector(in_conn)
                            global_map_exit.add_out_connector(out_conn)
                            self.redirect_edge(graph,
                                               edge,
                                               new_dst=global_map_exit,
                                               new_dst_conn=in_conn)
                            port_created = (in_conn, out_conn)
                            #edge.dst = global_map_exit
                            #edge.dst_conn = in_conn

                        else:
                            conn_nr = edge.dst_conn[3:]
                            in_conn = port_created.st
                            out_conn = port_created.nd

                        # map
                        graph.add_edge(global_map_exit, out_conn, dst, None,
                                       dcpy(out_edge.data))
                        graph.remove_edge(out_edge)

                # remove the edge if it has not been used by any pure out node
                if not port_created:
                    graph.remove_edge(edge)

            # maps are now ready to be discarded
            graph.remove_node(map_entry)
            graph.remove_node(map_exit)

            # end main loop.

        # create a mapping from data arrays to offsets
        # for memlet adjustments later on
        min_offsets = dict()

        # do one pass to augment all transient arrays
        data_intermediate = set([node.data for node in intermediate_nodes])
        for data_name in data_intermediate:
            if subgraph_contains_data[data_name]:
                all_nodes = [
                    n for n in intermediate_nodes if n.data == data_name
                ]
                in_edges = list(chain(*(graph.in_edges(n) for n in all_nodes)))

                in_edges_iter = iter(in_edges)
                in_edge = next(in_edges_iter)
                target_subset = dcpy(in_edge.data.subset)
                target_subset.pop(invariant_dimensions[data_name])
                ######
                while True:
                    try:  # executed if there are multiple in_edges
                        in_edge = next(in_edges_iter)
                        target_subset_curr = dcpy(in_edge.data.subset)
                        target_subset_curr.pop(invariant_dimensions[data_name])
                        target_subset = subsets.union(target_subset, \
                                                      target_subset_curr)
                    except StopIteration:
                        break

                min_offsets_cropped = target_subset.min_element_approx()
                # calculate the new transient array size.
                target_subset.offset(min_offsets_cropped, True)

                # re-add invariant dimensions with offset 0 and save to min_offsets
                min_offset = []
                index = 0
                for i in range(len(sdfg.data(data_name).shape)):
                    if i in invariant_dimensions[data_name]:
                        min_offset.append(0)
                    else:
                        min_offset.append(min_offsets_cropped[index])
                        index += 1

                min_offsets[data_name] = min_offset

                # determine the shape of the new array.
                new_data_shape = []
                index = 0
                for i, sz in enumerate(sdfg.data(data_name).shape):
                    if i in invariant_dimensions[data_name]:
                        new_data_shape.append(sz)
                    else:
                        new_data_shape.append(target_subset.size()[index])
                        index += 1

                new_data_strides = [
                    data._prod(new_data_shape[i + 1:])
                    for i in range(len(new_data_shape))
                ]

                new_data_totalsize = data._prod(new_data_shape)
                new_data_offset = [0] * len(new_data_shape)
                # augment.
                transient_to_transform = sdfg.data(data_name)
                transient_to_transform.shape = new_data_shape
                transient_to_transform.strides = new_data_strides
                transient_to_transform.total_size = new_data_totalsize
                transient_to_transform.offset = new_data_offset
                transient_to_transform.lifetime = dtypes.AllocationLifetime.Scope
                transient_to_transform.storage = self.transient_allocation

            else:
                # don't modify data container - array is needed outside
                # of subgraph.

                # hack: set lifetime to State if allocation has only been
                # scope so far to avoid allocation issues
                if sdfg.data(
                        data_name).lifetime == dtypes.AllocationLifetime.Scope:
                    sdfg.data(
                        data_name).lifetime = dtypes.AllocationLifetime.State

        # do one pass to adjust the memlets of in-between transients
        for node in intermediate_nodes:
            # all incoming edges to node
            in_edges = graph.in_edges(node)
            # outgoing edges going to another fused part
            inter_edges = []
            # outgoing edges that exit global map
            out_edges = []
            for e in graph.out_edges(node):
                if e.dst == global_map_exit:
                    out_edges.append(e)
                else:
                    inter_edges.append(e)

            # offset memlets where necessary
            if subgraph_contains_data[node.data]:
                # get min_offset
                min_offset = min_offsets[node.data]
                # re-add invariant dimensions with offset 0
                for iedge in in_edges:
                    for edge in graph.memlet_tree(iedge):
                        if edge.data.data == node.data:
                            edge.data.subset.offset(min_offset, True)
                        elif edge.data.other_subset:
                            edge.data.other_subset.offset(min_offset, True)

                for cedge in inter_edges:
                    for edge in graph.memlet_tree(cedge):
                        if edge.data.data == node.data:
                            edge.data.subset.offset(min_offset, True)
                        elif edge.data.other_subset:
                            edge.data.other_subset.offset(min_offset, True)

                # if in_edges has several entries:
                # put other_subset into out_edges for correctness
                if len(in_edges) > 1:
                    for oedge in out_edges:
                        oedge.data.other_subset = dcpy(oedge.data.subset)
                        oedge.data.other_subset.offset(min_offset, True)

            # also correct memlets of created transient
            if node in transients_created:
                transient_in_edges = graph.in_edges(transients_created[node])
                transient_out_edges = graph.out_edges(transients_created[node])
                for edge in chain(transient_in_edges, transient_out_edges):
                    for e in graph.memlet_tree(edge):
                        if e.data.data == node.data:
                            e.data.data += '_OUT'

        # do one last pass to correct outside memlets adjacent to global map
        for out_connector in global_map_entry.out_connectors:
            # find corresponding in_connector
            # and the in-connecting edge
            in_connector = 'IN' + out_connector[3:]
            for iedge in graph.in_edges(global_map_entry):
                if iedge.dst_conn == in_connector:
                    in_edge = iedge

            # find corresponding out_connector
            # and all out-connecting edges that belong to it
            # count them
            oedge_counter = 0
            for oedge in graph.out_edges(global_map_entry):
                if oedge.src_conn == out_connector:
                    out_edge = oedge
                    oedge_counter += 1

            # do memlet propagation
            # if there are several out edges, else there is no need

            if oedge_counter > 1:
                memlet_out = propagate_memlet(dfg_state=graph,
                                              memlet=out_edge.data,
                                              scope_node=global_map_entry,
                                              union_inner_edges=True)
                # override number of accesses
                in_edge.data.volume = memlet_out.volume
                in_edge.data.subset = memlet_out.subset

        # create a hook for outside access to global_map
        self._global_map_entry = global_map_entry
Ejemplo n.º 6
0
    def apply(self, sdfg: dace.SDFG):
        """Nest the matched map's dimension into the matched stencil node.

        The map around the stencil is removed and its outer edges are
        reconnected directly to the stencil, using memlets propagated through
        the map. The stencil's shape along the nested dimension is set to the
        number of elements of the map range, and every statement of the
        stencil code is rewritten by ``DimensionAdder``.

        All preconditions are checked before the graph is modified, so a
        failed application does not leave the state half-transformed.

        :param sdfg: The SDFG that contains the matched state.
        :raises ValueError: If the stencil code language is not Python, or if
                            no three-dimensional memlet adjacent to the
                            stencil references the first map parameter.
        """
        graph: dace.SDFGState = sdfg.node(self.state_id)
        map_entry: nodes.MapEntry = graph.node(self.subgraph[NestK._map_entry])
        stencil: Stencil = graph.node(self.subgraph[NestK._stencil])

        # Validate up front: the code rewrite at the end only supports Python.
        if stencil.code.language != dace.Language.Python:
            raise ValueError(
                'For NestK to work, Stencil code language must be Python')

        # Find the index of the dimension that is indexed by the (first) map
        # parameter, using the first 3-dimensional memlet that references it.
        pname = map_entry.map.params[0]
        dim_index = None
        for edge in graph.all_edges(stencil):
            if edge.data.data is None:  # Empty memlet
                continue
            if len(edge.data.subset) != 3:
                continue
            for i, rng in enumerate(edge.data.subset.ndrange()):
                if any(pname in map(str, r.free_symbols) for r in rng):
                    dim_index = i
                    break
            if dim_index is not None:
                break

        # Without a dimension index the reshaping below cannot be performed;
        # fail here rather than after the map has been removed.
        if dim_index is None:
            raise ValueError(
                f'NestK could not find a dimension indexed by map parameter '
                f'"{pname}" in any 3-dimensional memlet of the stencil')

        map_exit = graph.exit_node(map_entry)

        # Reconnect external edges directly to the stencil node. The memlet of
        # each inner edge is propagated through the map to obtain the memlet
        # of the new direct edge.
        for edge in graph.in_edges(map_entry):
            for child in graph.memlet_tree(edge).children:
                memlet = propagation.propagate_memlet(graph, child.edge.data,
                                                      map_entry, False)
                graph.add_edge(edge.src, edge.src_conn, stencil,
                               child.edge.dst_conn, memlet)
        for edge in graph.out_edges(map_exit):
            for child in graph.memlet_tree(edge).children:
                memlet = propagation.propagate_memlet(graph, child.edge.data,
                                                      map_entry, False)
                graph.add_edge(stencil, child.edge.src_conn, edge.dst,
                               edge.dst_conn, memlet)

        # Remove the map scope nodes (the old edges are expected to go with
        # them).
        graph.remove_nodes_from([map_entry, map_exit])

        # Reshape stencil node computation based on nested map range
        stencil.shape[dim_index] = map_entry.map.range.num_elements()

        # Mark the nested dimension as used on every 3-dimensional field and
        # remember the connectors for which it was previously unused.
        # NOTE(review): accesses[conn][0] / output_fields[conn][0] look like
        # per-dimension boolean masks -- confirm against the Stencil node.
        add_dims = set()
        for edge in graph.in_edges(stencil):
            if edge.data.data and len(edge.data.subset) == 3:
                if stencil.accesses[edge.dst_conn][0][dim_index] is False:
                    add_dims.add(edge.dst_conn)
                stencil.accesses[edge.dst_conn][0][dim_index] = True
        for edge in graph.out_edges(stencil):
            if edge.data.data and len(edge.data.subset) == 3:
                if stencil.output_fields[edge.src_conn][0][dim_index] is False:
                    add_dims.add(edge.src_conn)
                stencil.output_fields[edge.src_conn][0][dim_index] = True

        # Change all instances in the code as well
        for i, stmt in enumerate(stencil.code.code):
            stencil.code.code[i] = DimensionAdder(add_dims,
                                                  dim_index).visit(stmt)