def apply(self, sdfg: SDFG):
    """Interchange an outer map with the map nested directly inside it.

    The entry and exit nodes of the two maps trade their connectors, the
    edges that linked the two scopes are removed and re-added in the
    opposite direction, edges attached to the outside of one scope are
    moved to the other node (via ``sdutil.change_edge_dest`` /
    ``change_edge_src``), and finally the subsets of the re-added edges
    are recomputed with ``propagate_memlet``.

    :param sdfg: The SDFG containing the state ``self.state_id``.
    """
    state: SDFGState = sdfg.nodes()[self.state_id]
    outer_entry = state.nodes()[self.subgraph[MapInterchange.outer_map_entry]]
    inner_entry = state.nodes()[self.subgraph[MapInterchange.inner_map_entry]]
    inner_exit = state.exit_node(inner_entry)
    outer_exit = state.exit_node(outer_entry)

    # Trade connector sets between each (outer, inner) pair of scope nodes.
    for first, second in ((outer_entry, inner_entry), (outer_exit, inner_exit)):
        for attr in ('in_connectors', 'out_connectors'):
            saved = getattr(first, attr)
            setattr(first, attr, getattr(second, attr))
            setattr(second, attr, saved)

    # Detach the edges that currently link the two scopes.
    entry_links = state.edges_between(outer_entry, inner_entry)
    exit_links = state.edges_between(inner_exit, outer_exit)
    for link in entry_links + exit_links:
        state.remove_edge(link)

    # Move the external edges so they attach to the new outer scope nodes.
    sdutil.change_edge_dest(state, outer_entry, inner_entry)
    sdutil.change_edge_src(state, inner_entry, outer_entry)
    sdutil.change_edge_dest(state, inner_exit, outer_exit)
    sdutil.change_edge_src(state, outer_exit, inner_exit)

    def _reversed_copies(links):
        # Re-add each removed edge with its endpoints swapped (connector
        # names and memlet data are kept as they were).
        return [
            state.add_edge(link.dst, link.src_conn, link.src, link.dst_conn,
                           link.data) for link in links
        ]

    def _repropagate(links, step, scope_node):
        # Recompute each edge's subset from its neighbour on the memlet
        # path (next edge for entries, previous edge for exits).
        for link in links:
            path = state.memlet_path(link)
            position = next(i for i, candidate in enumerate(path)
                            if link is candidate)
            neighbour = path[position + step]
            link.data.subset = propagate_memlet(state, neighbour.data,
                                                scope_node, True).subset

    reversed_entry_links = _reversed_copies(entry_links)
    reversed_exit_links = _reversed_copies(exit_links)

    _repropagate(reversed_entry_links, 1, outer_entry)
    _repropagate(reversed_exit_links, -1, outer_exit)
def apply(self, sdfg):
    """Merge duplicate source access nodes that feed the same map entry.

    Other access nodes of the matched array's data container qualify for
    merging when they
      * have no incoming edges, and
      * enter the matched map through an ``IN_*`` connector.
    Each such node is removed. Inner edges that used its ``OUT_*``
    connector are re-pointed to the connector of the matched array. The
    now-unused connector pairs are removed from the map entry, and the
    outer memlet of the matched array is re-propagated.

    :param sdfg: The SDFG containing the state ``self.state_id``.
    """
    graph = sdfg.node(self.state_id)
    array = graph.node(self.subgraph[InMergeArrays._array1])
    map_entry = graph.node(self.subgraph[InMergeArrays._map_entry])

    map_edge = next(e for e in graph.out_edges(array) if e.dst == map_entry)
    result_connector = map_edge.dst_conn[3:]  # strip 'IN_'

    # Find all other incoming access nodes without incoming edges
    source_edges = [
        e for e in graph.in_edges(map_entry)
        if isinstance(e.src, nodes.AccessNode) and e.src.data == array.data
        and e.src != array and e.dst_conn and e.dst_conn.startswith('IN_')
        and graph.in_degree(e.src) == 0
    ]

    # Modify connectors to point to first array
    connectors_to_remove = set()
    for e in source_edges:
        connector = e.dst_conn[3:]
        connectors_to_remove.add(connector)
        for inner_edge in graph.out_edges(map_entry):
            # Edges without a source connector (empty memlets) carry no
            # data; slicing None would raise a TypeError.
            if inner_edge.src_conn and inner_edge.src_conn[4:] == connector:
                inner_edge._src_conn = 'OUT_' + result_connector

    # Remove other nodes from state
    graph.remove_nodes_from({e.src for e in source_edges})

    # Remove connectors from scope entry
    for c in connectors_to_remove:
        map_entry.remove_in_connector('IN_' + c)
        map_entry.remove_out_connector('OUT_' + c)

    # Re-propagate memlets: the outer memlet becomes the union of all inner
    # edges that now share the surviving connector.
    edge_to_propagate = next(
        e for e in graph.out_edges(map_entry)
        if e.src_conn and e.src_conn[4:] == result_connector)
    map_edge._data = propagate_memlet(dfg_state=graph,
                                      memlet=edge_to_propagate.data,
                                      scope_node=map_entry,
                                      union_inner_edges=True)
def apply(self, graph, sdfg):
    """Merge duplicate sink access nodes written by the same map exit.

    Other access nodes of the matched array's data container qualify for
    merging when they
      * have no outgoing edges, and
      * are written through an ``OUT_*`` connector of the matched map exit.
    Each such node is removed. Inner edges that wrote to its ``IN_*``
    connector are redirected to the connector feeding the matched array.
    The unused connector pairs are removed from the map exit, and the
    outer memlet towards the matched array is re-propagated.

    :param graph: The state containing the matched nodes.
    :param sdfg: The SDFG containing ``graph`` (not referenced here).
    """
    array = self.array1
    map_exit = self.map_exit

    map_edge = next(e for e in graph.in_edges(array) if e.src == map_exit)
    result_connector = map_edge.src_conn[4:]  # strip 'OUT_'

    # Find all other outgoing access nodes without outgoing edges
    dst_edges = [
        e for e in graph.out_edges(map_exit)
        if isinstance(e.dst, nodes.AccessNode) and e.dst.data == array.data
        and e.dst != array and e.src_conn and e.src_conn.startswith('OUT_')
        and graph.out_degree(e.dst) == 0
    ]

    # Modify connectors to point to first array
    connectors_to_remove = set()
    for e in dst_edges:
        connector = e.src_conn[4:]
        connectors_to_remove.add(connector)
        for inner_edge in graph.in_edges(map_exit):
            # Edges without a destination connector (empty memlets) carry
            # no data; slicing None would raise a TypeError.
            if inner_edge.dst_conn and inner_edge.dst_conn[3:] == connector:
                inner_edge.dst_conn = 'IN_' + result_connector

    # Remove other nodes from state
    graph.remove_nodes_from({e.dst for e in dst_edges})

    # Remove connectors from scope exit
    for c in connectors_to_remove:
        map_exit.remove_in_connector('IN_' + c)
        map_exit.remove_out_connector('OUT_' + c)

    # Re-propagate memlets: the outer memlet becomes the union of all inner
    # edges that now share the surviving connector.
    edge_to_propagate = next(
        e for e in graph.in_edges(map_exit)
        if e.dst_conn and e.dst_conn[3:] == result_connector)
    map_edge._data = propagate_memlet(dfg_state=graph,
                                      memlet=edge_to_propagate.data,
                                      scope_node=map_exit,
                                      union_inner_edges=True)
def fuse(self, sdfg, graph, map_entries, do_not_override=None, **kwargs):
    """ takes the map_entries specified and tries to fuse maps.

        all maps have to be extended into outer and inner map
        (use MapExpansion as a pre-pass)

        Arrays that don't exist outside the subgraph get pushed
        into the map and their data dimension gets cropped.
        Otherwise the original array is taken.

        For every output respective connections are created automatically.

        The new outer map entry is stored in ``self._global_map_entry``.

        :param sdfg: SDFG
        :param graph: State
        :param map_entries: Map Entries (class MapEntry) of the outer maps
                            which we want to fuse
        :param do_not_override: List of data names whose corresponding nodes
                                are fully contained within the subgraph
                                but should not be augmented/transformed
                                nevertheless.
    """
    # if there are no maps, return immediately
    if not map_entries:
        return

    do_not_override = do_not_override or []

    # get maps and map exits
    maps = [map_entry.map for map_entry in map_entries]
    map_exits = [graph.exit_node(map_entry) for map_entry in map_entries]

    # See SubgraphFusion.get_adjacent_nodes for an explanation of these sets
    node_config = SubgraphFusion.get_adjacent_nodes(sdfg, graph, map_entries)
    (in_nodes, intermediate_nodes, out_nodes) = node_config

    if self.debug:
        print("SubgraphFusion::In_nodes", in_nodes)
        print("SubgraphFusion::Out_nodes", out_nodes)
        print("SubgraphFusion::Intermediate_nodes", intermediate_nodes)

    # all maps are assumed to have the same params and range in order
    global_map = nodes.Map(label="outer_fused",
                           params=maps[0].params,
                           ndrange=maps[0].range)
    global_map_entry = nodes.MapEntry(global_map)
    global_map_exit = nodes.MapExit(global_map)

    schedule = map_entries[0].schedule
    global_map_entry.schedule = schedule
    graph.add_node(global_map_entry)
    graph.add_node(global_map_exit)

    # next up, for any intermediate node, find whether it only appears
    # in the subgraph or also somewhere else / as an input
    # create new transients for nodes that are in out_nodes and
    # intermediate_nodes simultaneously
    # also check which dimensions of each transient data element correspond
    # to map axes and write this information into a dict.
    node_info = self.prepare_intermediate_nodes(sdfg, graph, in_nodes,
                                                out_nodes, intermediate_nodes,
                                                map_entries, map_exits,
                                                do_not_override)

    (subgraph_contains_data, transients_created,
     invariant_dimensions) = node_info
    if self.debug:
        print(
            "SubgraphFusion:: {Intermediate_node: subgraph_contains_data} dict"
        )
        print(subgraph_contains_data)

    # Dict for saving incoming nodes and their assigned connectors
    # Format: {access_node: (edge, in_conn, out_conn)}
    inconnectors_dict = {}

    for map_entry, map_exit in zip(map_entries, map_exits):
        # handle inputs
        # TODO: dynamic map range -- this is fairly unrealistic in such a setting
        for edge in graph.in_edges(map_entry):
            src = edge.src
            mmt = graph.memlet_tree(edge)
            out_edges = [child.edge for child in mmt.root().children]

            if src in in_nodes:
                if src in inconnectors_dict:
                    # no need to augment subset of outer edge.
                    # will do this at the end in one pass.
                    _, in_conn, out_conn = inconnectors_dict[src]
                else:
                    next_conn = global_map_entry.next_connector()
                    in_conn = 'IN_' + next_conn
                    out_conn = 'OUT_' + next_conn
                    global_map_entry.add_in_connector(in_conn)
                    global_map_entry.add_out_connector(out_conn)

                    inconnectors_dict[src] = (edge, in_conn, out_conn)

                    # reroute in edge via global_map_entry
                    self.copy_edge(graph, edge, new_dst=global_map_entry,
                                   new_dst_conn=in_conn)

                # map out edges to new map
                for out_edge in out_edges:
                    self.copy_edge(graph, out_edge, new_src=global_map_entry,
                                   new_src_conn=out_conn)
            else:
                # connect directly
                for out_edge in out_edges:
                    mm = dcpy(out_edge.data)
                    self.copy_edge(graph, out_edge, new_src=src,
                                   new_src_conn=None, new_data=mm)

        for edge in graph.out_edges(map_entry):
            # special case: for nodes that have no data connections
            if not edge.src_conn:
                self.copy_edge(graph, edge, new_src=global_map_entry)

        ######################################

        for edge in graph.in_edges(map_exit):
            if not edge.dst_conn:
                # no destination connector, path ends here.
                self.copy_edge(graph, edge, new_dst=global_map_exit)
                continue
            # find corresponding out_edges for current edge, cannot use mmt
            # anymore ('OUT_x'[3:] == 'IN_x'[2:] == '_x')
            out_edges = [
                oedge for oedge in graph.out_edges(map_exit)
                if oedge.src_conn[3:] == edge.dst_conn[2:]
            ]

            # (in_conn, out_conn) of the global map exit port created for
            # pure out nodes of this edge; reused by further pure out nodes.
            port_created = None

            for out_edge in out_edges:
                dst = out_edge.dst

                if dst in intermediate_nodes & out_nodes:
                    # create connection through global map from
                    # dst to dst_transient that was created
                    dst_transient = transients_created[dst]
                    next_conn = global_map_exit.next_connector()
                    in_conn = 'IN_' + next_conn
                    out_conn = 'OUT_' + next_conn
                    global_map_exit.add_in_connector(in_conn)
                    global_map_exit.add_out_connector(out_conn)

                    # for each transient created, create a union
                    # of outgoing memlets' subsets. this is
                    # a cheap fix to override assignments in invariant
                    # dimensions
                    union = None
                    for oe in graph.out_edges(transients_created[dst]):
                        union = subsets.union(union, oe.data.subset)
                    inner_memlet = dcpy(edge.data)
                    for i, _ in enumerate(edge.data.subset):
                        if i in invariant_dimensions[dst.label]:
                            inner_memlet.subset[i] = union[i]

                    inner_memlet.other_subset = dcpy(inner_memlet.subset)

                    graph.add_edge(dst, None, global_map_exit, in_conn,
                                   inner_memlet)
                    mm_outer = propagate_memlet(graph, inner_memlet,
                                                global_map_entry,
                                                union_inner_edges=False)

                    graph.add_edge(global_map_exit, out_conn, dst_transient,
                                   None, mm_outer)

                    # remove edge from dst to dst_transient that was created
                    # in intermediate preparation.
                    for e in graph.out_edges(dst):
                        if e.dst == dst_transient:
                            graph.remove_edge(e)
                            break

                # handle separately: intermediate_nodes and pure out nodes
                # case 1: intermediate_nodes: can just redirect edge
                if dst in intermediate_nodes:
                    self.copy_edge(graph, out_edge,
                                   new_src=edge.src,
                                   new_src_conn=edge.src_conn,
                                   new_data=dcpy(edge.data))

                # case 2: pure out node: connect to outer array node
                if dst in (out_nodes - intermediate_nodes):
                    if port_created is None:
                        next_conn = global_map_exit.next_connector()
                        in_conn = 'IN_' + next_conn
                        out_conn = 'OUT_' + next_conn
                        global_map_exit.add_in_connector(in_conn)
                        global_map_exit.add_out_connector(out_conn)
                        self.copy_edge(graph, edge, new_dst=global_map_exit,
                                       new_dst_conn=in_conn)
                        port_created = (in_conn, out_conn)
                    else:
                        # reuse the port already created for this edge
                        in_conn, out_conn = port_created

                    # map
                    graph.add_edge(global_map_exit, out_conn, dst, None,
                                   dcpy(out_edge.data))

        # maps are now ready to be discarded
        # all connected edges will be finally removed as well
        graph.remove_node(map_entry)
        graph.remove_node(map_exit)

    # create a mapping from data arrays to offsets
    # for later memlet adjustments
    min_offsets = dict()

    # do one pass to augment all transient arrays
    data_intermediate = {node.data for node in intermediate_nodes}
    for data_name in data_intermediate:
        if subgraph_contains_data[data_name]:
            all_nodes = [n for n in intermediate_nodes if n.data == data_name]
            in_edges = list(chain(*(graph.in_edges(n) for n in all_nodes)))

            in_edges_iter = iter(in_edges)
            in_edge = next(in_edges_iter)
            target_subset = dcpy(in_edge.data.subset)
            target_subset.pop(invariant_dimensions[data_name])
            # union with the remaining in_edges, if there are several
            for in_edge in in_edges_iter:
                target_subset_curr = dcpy(in_edge.data.subset)
                target_subset_curr.pop(invariant_dimensions[data_name])
                target_subset = subsets.union(target_subset,
                                              target_subset_curr)

            min_offsets_cropped = target_subset.min_element_approx()
            # calculate the new transient array size.
            target_subset.offset(min_offsets_cropped, True)

            # re-add invariant dimensions with offset 0 and save to min_offsets
            min_offset = []
            index = 0
            for i in range(len(sdfg.data(data_name).shape)):
                if i in invariant_dimensions[data_name]:
                    min_offset.append(0)
                else:
                    min_offset.append(min_offsets_cropped[index])
                    index += 1

            min_offsets[data_name] = min_offset

            # determine the shape of the new array.
            new_data_shape = []
            index = 0
            for i, sz in enumerate(sdfg.data(data_name).shape):
                if i in invariant_dimensions[data_name]:
                    new_data_shape.append(sz)
                else:
                    new_data_shape.append(target_subset.size()[index])
                    index += 1

            new_data_strides = [
                data._prod(new_data_shape[i + 1:])
                for i in range(len(new_data_shape))
            ]

            new_data_totalsize = data._prod(new_data_shape)
            new_data_offset = [0] * len(new_data_shape)
            # augment.
            transient_to_transform = sdfg.data(data_name)
            transient_to_transform.shape = new_data_shape
            transient_to_transform.strides = new_data_strides
            transient_to_transform.total_size = new_data_totalsize
            transient_to_transform.offset = new_data_offset
            transient_to_transform.lifetime = dtypes.AllocationLifetime.Scope
            transient_to_transform.storage = self.transient_allocation
        else:
            # don't modify data container - array is needed outside
            # of subgraph.

            # hack: set lifetime to State if allocation has only been
            # scope so far to avoid allocation issues
            if sdfg.data(
                    data_name).lifetime == dtypes.AllocationLifetime.Scope:
                sdfg.data(
                    data_name).lifetime = dtypes.AllocationLifetime.State

    # do one pass to adjust the memlets of in-between transients
    for node in intermediate_nodes:
        # all incoming edges to node
        in_edges = graph.in_edges(node)
        # outgoing edges going to another fused part
        out_edges = graph.out_edges(node)

        # memlets of created transient:
        # correct data names
        if node in transients_created:
            transient_in_edges = graph.in_edges(transients_created[node])
            transient_out_edges = graph.out_edges(transients_created[node])
            for edge in chain(transient_in_edges, transient_out_edges):
                for e in graph.memlet_tree(edge):
                    if e.data.data == node.data:
                        e.data.data += '_OUT'

        # memlets of all in between transients:
        # offset memlets if array has been augmented
        if subgraph_contains_data[node.data]:
            # get min_offset
            min_offset = min_offsets[node.data]
            # re-add invariant dimensions with offset 0
            for iedge in in_edges:
                for edge in graph.memlet_tree(iedge):
                    if edge.data.data == node.data:
                        edge.data.subset.offset(min_offset, True)
                    elif edge.data.other_subset:
                        edge.data.other_subset.offset(min_offset, True)

                # nested SDFG: adjust arrays connected
                # (use iedge itself rather than the leaked loop variable)
                if isinstance(iedge.src, nodes.NestedSDFG):
                    nsdfg = iedge.src.sdfg
                    nested_data_name = iedge.src_conn
                    self.adjust_arrays_nsdfg(sdfg, nsdfg, node.data,
                                             nested_data_name)

            for cedge in out_edges:
                for edge in graph.memlet_tree(cedge):
                    if edge.data.data == node.data:
                        edge.data.subset.offset(min_offset, True)
                    elif edge.data.other_subset:
                        edge.data.other_subset.offset(min_offset, True)

                    # nested SDFG: adjust arrays connected
                    if isinstance(edge.dst, nodes.NestedSDFG):
                        nsdfg = edge.dst.sdfg
                        nested_data_name = edge.dst_conn
                        self.adjust_arrays_nsdfg(sdfg, nsdfg, node.data,
                                                 nested_data_name)

            # if in_edges has several entries:
            # put other_subset into out_edges for correctness
            if len(in_edges) > 1:
                for oedge in out_edges:
                    if oedge.dst == global_map_exit and \
                            oedge.data.other_subset is None:
                        oedge.data.other_subset = dcpy(oedge.data.subset)
                        oedge.data.other_subset.offset(min_offset, True)

    # consolidate edges if desired
    if self.consolidate:
        consolidate_edges_scope(graph, global_map_entry)
        consolidate_edges_scope(graph, global_map_exit)

    # propagate edges adjacent to global map entry and exit
    # if desired
    if self.propagate:
        _propagate_node(graph, global_map_entry)
        _propagate_node(graph, global_map_exit)

    # create a hook for outside access to global_map
    self._global_map_entry = global_map_entry
    if self.schedule_innermaps is not None:
        for node in graph.scope_children()[global_map_entry]:
            if isinstance(node, nodes.MapEntry):
                node.map.schedule = self.schedule_innermaps
def fuse(self, sdfg, graph, map_entries, do_not_override=None, **kwargs):
    """ takes the map_entries specified and tries to fuse maps.

        all maps have to be extended into outer and inner map
        (use MapExpansion as a pre-pass)

        Arrays that don't exist outside the subgraph get pushed
        into the map and their data dimension gets cropped.
        Otherwise the original array is taken.

        For every output respective connections are created automatically.

        The new outer map entry is stored in ``self._global_map_entry``.

        :param sdfg: SDFG
        :param graph: State
        :param map_entries: Map Entries (class MapEntry) of the outer maps
                            which we want to fuse
        :param do_not_override: List of data names whose corresponding nodes
                                are fully contained within the subgraph
                                but should not be augmented/transformed
                                nevertheless.
        :raises NotImplementedError: if a node downstream of a fused map
                                     exit also receives data from outside
                                     the fused maps.
    """
    # if there are no maps, return immediately
    if not map_entries:
        return

    # avoid a shared mutable default argument
    do_not_override = do_not_override or []

    # get maps and map exits
    maps = [map_entry.map for map_entry in map_entries]
    map_exits = [graph.exit_node(map_entry) for map_entry in map_entries]

    # re-construct the map subgraph if necessary
    try:
        self.subgraph
    except AttributeError:
        subgraph_nodes = set()
        scope_dict = graph.scope_dict(node_to_children=True)
        for node in chain(map_entries, map_exits):
            subgraph_nodes.add(node)
            # add all border arrays
            for e in chain(graph.in_edges(node), graph.out_edges(node)):
                subgraph_nodes.add(e.src)
                subgraph_nodes.add(e.dst)
            try:
                subgraph_nodes |= set(scope_dict[node])
            except KeyError:
                pass
        self.subgraph = SubgraphView(graph, subgraph_nodes)

    # Nodes that flow into one or several maps but no data is flowed to them from any map
    in_nodes = set()

    # Nodes into which data is flowed but that no data flows into any map from them
    out_nodes = set()

    # Nodes that act as intermediate node - data flows from a map into them and then there
    # is an outgoing path into another map
    intermediate_nodes = set()

    ### NOTE:
    #- in_nodes, out_nodes, intermediate_nodes refer to the configuration of the final fused map
    #- in_nodes and out_nodes are trivially disjoint
    #- Intermediate_nodes and out_nodes are not necessarily disjoint
    #- Intermediate_nodes and in_nodes are disjoint by design.
    #  There could be a node that has both incoming edges from a map exit
    #  and from outside, but it is just treated as intermediate_node and handled
    #  automatically.

    for map_entry, map_exit in zip(map_entries, map_exits):
        for edge in graph.in_edges(map_entry):
            in_nodes.add(edge.src)
        for edge in graph.out_edges(map_exit):
            current_node = edge.dst
            if len(graph.out_edges(current_node)) == 0:
                out_nodes.add(current_node)
            else:
                for dst_edge in graph.out_edges(current_node):
                    if dst_edge.dst in map_entries:
                        # add to intermediate_nodes
                        intermediate_nodes.add(current_node)
                    else:
                        # add to out_nodes
                        out_nodes.add(current_node)
                for e in graph.in_edges(current_node):
                    if e.src not in map_exits:
                        raise NotImplementedError(
                            "Nodes between two maps to be "
                            "fused with *incoming* edges "
                            "from outside the maps are not "
                            "allowed yet.")

    # any intermediate_nodes currently in in_nodes shouldnt be there
    in_nodes -= intermediate_nodes

    if self.debug:
        print("SubgraphFusion::In_nodes", in_nodes)
        print("SubgraphFusion::Out_nodes", out_nodes)
        print("SubgraphFusion::Intermediate_nodes", intermediate_nodes)

    # all maps are assumed to have the same params and range in order
    global_map = nodes.Map(label="outer_fused",
                           params=maps[0].params,
                           ndrange=maps[0].range)
    global_map_entry = nodes.MapEntry(global_map)
    global_map_exit = nodes.MapExit(global_map)

    schedule = map_entries[0].schedule
    global_map_entry.schedule = schedule
    graph.add_node(global_map_entry)
    graph.add_node(global_map_exit)

    # next up, for any intermediate node, find whether it only appears
    # in the subgraph or also somewhere else / as an input
    # create new transients for nodes that are in out_nodes and
    # intermediate_nodes simultaneously
    # also check which dimensions of each transient data element correspond
    # to map axes and write this information into a dict.
    node_info = self.prepare_intermediate_nodes(sdfg, graph, in_nodes,
                                                out_nodes, intermediate_nodes,
                                                map_entries, map_exits,
                                                do_not_override)

    (subgraph_contains_data, transients_created,
     invariant_dimensions) = node_info
    if self.debug:
        print(
            "SubgraphFusion:: {Intermediate_node: subgraph_contains_data} dict"
        )
        print(subgraph_contains_data)

    # Dict for saving incoming nodes and their assigned connectors
    # Format: {access_node: (edge, in_conn, out_conn)}
    inconnectors_dict = {}

    for map_entry, map_exit in zip(map_entries, map_exits):
        # handle inputs
        # TODO: dynamic map range -- this is fairly unrealistic in such a setting
        for edge in graph.in_edges(map_entry):
            src = edge.src
            mmt = graph.memlet_tree(edge)
            out_edges = [child.edge for child in mmt.root().children]

            if src in in_nodes:
                if src in inconnectors_dict:
                    # no need to augment subset of outer edge.
                    # will do this at the end in one pass.
                    _, in_conn, out_conn = inconnectors_dict[src]
                    graph.remove_edge(edge)
                else:
                    next_conn = global_map_entry.next_connector()
                    in_conn = 'IN_' + next_conn
                    out_conn = 'OUT_' + next_conn
                    global_map_entry.add_in_connector(in_conn)
                    global_map_entry.add_out_connector(out_conn)

                    inconnectors_dict[src] = (edge, in_conn, out_conn)

                    # reroute in edge via global_map_entry
                    self.redirect_edge(graph, edge, new_dst=global_map_entry,
                                       new_dst_conn=in_conn)

                # map out edges to new map
                for out_edge in out_edges:
                    self.redirect_edge(graph, out_edge,
                                       new_src=global_map_entry,
                                       new_src_conn=out_conn)
            else:
                # connect directly
                for out_edge in out_edges:
                    mm = dcpy(out_edge.data)
                    self.redirect_edge(graph, out_edge, new_src=src,
                                       new_data=mm)

                graph.remove_edge(edge)

        for edge in graph.out_edges(map_entry):
            # special case: for nodes that have no data connections
            if not edge.src_conn:
                self.redirect_edge(graph, edge, new_src=global_map_entry)

        ######################################

        for edge in graph.in_edges(map_exit):
            if not edge.dst_conn:
                # no destination connector, path ends here.
                self.redirect_edge(graph, edge, new_dst=global_map_exit)
                continue
            # find corresponding out_edges for current edge, cannot use mmt
            # anymore ('OUT_x'[3:] == 'IN_x'[2:] == '_x')
            out_edges = [
                oedge for oedge in graph.out_edges(map_exit)
                if oedge.src_conn[3:] == edge.dst_conn[2:]
            ]

            # (in_conn, out_conn) of the global map exit port created for
            # pure out nodes of this edge; reused by further pure out nodes.
            port_created = None

            for out_edge in out_edges:
                dst = out_edge.dst

                if dst in intermediate_nodes & out_nodes:
                    # create connection through global map from
                    # dst to dst_transient that was created
                    dst_transient = transients_created[dst]
                    next_conn = global_map_exit.next_connector()
                    in_conn = 'IN_' + next_conn
                    out_conn = 'OUT_' + next_conn
                    global_map_exit.add_in_connector(in_conn)
                    global_map_exit.add_out_connector(out_conn)

                    inner_memlet = dcpy(edge.data)
                    inner_memlet.other_subset = dcpy(edge.data.subset)

                    graph.add_edge(dst, None, global_map_exit, in_conn,
                                   inner_memlet)
                    mm_outer = propagate_memlet(graph, inner_memlet,
                                                global_map_entry,
                                                union_inner_edges=False)

                    graph.add_edge(global_map_exit, out_conn, dst_transient,
                                   None, mm_outer)

                    # remove edge from dst to dst_transient that was created
                    # in intermediate preparation.
                    removed = False
                    for e in graph.out_edges(dst):
                        if e.dst == dst_transient:
                            graph.remove_edge(e)
                            removed = True
                            break

                    if self.debug:
                        assert removed

                # handle separately: intermediate_nodes and pure out nodes
                # case 1: intermediate_nodes: can just redirect edge
                if dst in intermediate_nodes:
                    self.redirect_edge(graph, out_edge,
                                       new_src=edge.src,
                                       new_src_conn=edge.src_conn,
                                       new_data=dcpy(edge.data))

                # case 2: pure out node: connect to outer array node
                if dst in (out_nodes - intermediate_nodes):
                    if port_created is None:
                        next_conn = global_map_exit.next_connector()
                        in_conn = 'IN_' + next_conn
                        out_conn = 'OUT_' + next_conn
                        global_map_exit.add_in_connector(in_conn)
                        global_map_exit.add_out_connector(out_conn)
                        self.redirect_edge(graph, edge,
                                           new_dst=global_map_exit,
                                           new_dst_conn=in_conn)
                        port_created = (in_conn, out_conn)
                    else:
                        # edge was already redirected for an earlier pure
                        # out node: reuse that port
                        in_conn, out_conn = port_created

                    # map
                    graph.add_edge(global_map_exit, out_conn, dst, None,
                                   dcpy(out_edge.data))
                    graph.remove_edge(out_edge)

            # remove the edge if it has not been used by any pure out node
            if not port_created:
                graph.remove_edge(edge)

        # maps are now ready to be discarded
        graph.remove_node(map_entry)
        graph.remove_node(map_exit)

    # end main loop.

    # create a mapping from data arrays to offsets
    # for later memlet adjustments
    min_offsets = dict()

    # do one pass to augment all transient arrays
    data_intermediate = {node.data for node in intermediate_nodes}
    for data_name in data_intermediate:
        if subgraph_contains_data[data_name]:
            all_nodes = [n for n in intermediate_nodes if n.data == data_name]
            in_edges = list(chain(*(graph.in_edges(n) for n in all_nodes)))

            in_edges_iter = iter(in_edges)
            in_edge = next(in_edges_iter)
            target_subset = dcpy(in_edge.data.subset)
            target_subset.pop(invariant_dimensions[data_name])
            # union with the remaining in_edges, if there are several
            for in_edge in in_edges_iter:
                target_subset_curr = dcpy(in_edge.data.subset)
                target_subset_curr.pop(invariant_dimensions[data_name])
                target_subset = subsets.union(target_subset,
                                              target_subset_curr)

            min_offsets_cropped = target_subset.min_element_approx()
            # calculate the new transient array size.
            target_subset.offset(min_offsets_cropped, True)

            # re-add invariant dimensions with offset 0 and save to min_offsets
            min_offset = []
            index = 0
            for i in range(len(sdfg.data(data_name).shape)):
                if i in invariant_dimensions[data_name]:
                    min_offset.append(0)
                else:
                    min_offset.append(min_offsets_cropped[index])
                    index += 1

            min_offsets[data_name] = min_offset

            # determine the shape of the new array.
            new_data_shape = []
            index = 0
            for i, sz in enumerate(sdfg.data(data_name).shape):
                if i in invariant_dimensions[data_name]:
                    new_data_shape.append(sz)
                else:
                    new_data_shape.append(target_subset.size()[index])
                    index += 1

            new_data_strides = [
                data._prod(new_data_shape[i + 1:])
                for i in range(len(new_data_shape))
            ]

            new_data_totalsize = data._prod(new_data_shape)
            new_data_offset = [0] * len(new_data_shape)
            # augment.
            transient_to_transform = sdfg.data(data_name)
            transient_to_transform.shape = new_data_shape
            transient_to_transform.strides = new_data_strides
            transient_to_transform.total_size = new_data_totalsize
            transient_to_transform.offset = new_data_offset
            transient_to_transform.lifetime = dtypes.AllocationLifetime.Scope
            transient_to_transform.storage = self.transient_allocation
        else:
            # don't modify data container - array is needed outside
            # of subgraph.

            # hack: set lifetime to State if allocation has only been
            # scope so far to avoid allocation issues
            if sdfg.data(
                    data_name).lifetime == dtypes.AllocationLifetime.Scope:
                sdfg.data(
                    data_name).lifetime = dtypes.AllocationLifetime.State

    # do one pass to adjust the memlets of in-between transients
    for node in intermediate_nodes:
        # all incoming edges to node
        in_edges = graph.in_edges(node)
        # outgoing edges going to another fused part
        inter_edges = []
        # outgoing edges that exit global map
        out_edges = []
        for e in graph.out_edges(node):
            if e.dst == global_map_exit:
                out_edges.append(e)
            else:
                inter_edges.append(e)

        # offset memlets where necessary
        if subgraph_contains_data[node.data]:
            # get min_offset
            min_offset = min_offsets[node.data]
            # re-add invariant dimensions with offset 0
            for iedge in in_edges:
                for edge in graph.memlet_tree(iedge):
                    if edge.data.data == node.data:
                        edge.data.subset.offset(min_offset, True)
                    elif edge.data.other_subset:
                        edge.data.other_subset.offset(min_offset, True)

            for cedge in inter_edges:
                for edge in graph.memlet_tree(cedge):
                    if edge.data.data == node.data:
                        edge.data.subset.offset(min_offset, True)
                    elif edge.data.other_subset:
                        edge.data.other_subset.offset(min_offset, True)

            # if in_edges has several entries:
            # put other_subset into out_edges for correctness
            if len(in_edges) > 1:
                for oedge in out_edges:
                    oedge.data.other_subset = dcpy(oedge.data.subset)
                    oedge.data.other_subset.offset(min_offset, True)

        # also correct memlets of created transient
        if node in transients_created:
            transient_in_edges = graph.in_edges(transients_created[node])
            transient_out_edges = graph.out_edges(transients_created[node])
            for edge in chain(transient_in_edges, transient_out_edges):
                for e in graph.memlet_tree(edge):
                    if e.data.data == node.data:
                        e.data.data += '_OUT'

    # do one last pass to correct outside memlets adjacent to global map
    for out_connector in global_map_entry.out_connectors:
        # find corresponding in_connector
        # and the in-connecting edge ('OUT_x' -> 'IN_x')
        in_connector = 'IN' + out_connector[3:]
        for iedge in graph.in_edges(global_map_entry):
            if iedge.dst_conn == in_connector:
                in_edge = iedge

        # find corresponding out_connector
        # and all out-connecting edges that belong to it
        # count them
        oedge_counter = 0
        for oedge in graph.out_edges(global_map_entry):
            if oedge.src_conn == out_connector:
                out_edge = oedge
                oedge_counter += 1

        # do memlet propagation
        # if there are several out edges, else there is no need
        if oedge_counter > 1:
            memlet_out = propagate_memlet(dfg_state=graph,
                                          memlet=out_edge.data,
                                          scope_node=global_map_entry,
                                          union_inner_edges=True)
            # override number of accesses
            in_edge.data.volume = memlet_out.volume
            in_edge.data.subset = memlet_out.subset

    # create a hook for outside access to global_map
    self._global_map_entry = global_map_entry
def apply(self, sdfg: dace.SDFG):
    """Absorb the map surrounding a stencil node into the stencil itself.

    The dimension indexed by the map's first parameter is located on one
    of the stencil's 3-dimensional memlets. The map is then removed, and
    its external edges are reconnected directly to the stencil with
    propagated memlets. The stencil's shape in that dimension is set to
    the number of map iterations, and the stencil code is rewritten with
    ``DimensionAdder`` for every field that did not previously access
    that dimension.

    :param sdfg: The SDFG containing the state ``self.state_id``.
    :raises ValueError: If the stencil code is not Python, or if no
        3-dimensional memlet of the stencil references the map parameter.
    """
    graph: dace.SDFGState = sdfg.node(self.state_id)
    map_entry: nodes.MapEntry = graph.node(self.subgraph[NestK._map_entry])
    stencil: Stencil = graph.node(self.subgraph[NestK._stencil])

    # Validate before touching the graph: failing after the map has been
    # removed would leave the state half-transformed.
    if stencil.code.language != dace.Language.Python:
        raise ValueError(
            'For NestK to work, Stencil code language must be Python')

    # Find dimension index and name: the first dimension (of the first
    # 3-D, non-empty memlet) whose range mentions the map parameter.
    pname = map_entry.map.params[0]
    dim_index = None
    for edge in graph.all_edges(stencil):
        if edge.data.data is None:  # Empty memlet
            continue
        if len(edge.data.subset) != 3:
            continue
        dim_index = next(
            (i for i, rng in enumerate(edge.data.subset.ndrange())
             if any(pname in map(str, r.free_symbols) for r in rng)), None)
        if dim_index is not None:
            break

    if dim_index is None:
        raise ValueError(
            'NestK: map parameter "{}" does not index any 3-dimensional '
            'memlet of the stencil'.format(pname))

    map_exit = graph.exit_node(map_entry)

    # Reconnect external edges directly to stencil node
    for edge in graph.in_edges(map_entry):
        # Find matching internal edges
        tree = graph.memlet_tree(edge)
        for child in tree.children:
            memlet = propagation.propagate_memlet(graph, child.edge.data,
                                                  map_entry, False)
            graph.add_edge(edge.src, edge.src_conn, stencil,
                           child.edge.dst_conn, memlet)
    for edge in graph.out_edges(map_exit):
        # Find matching internal edges
        tree = graph.memlet_tree(edge)
        for child in tree.children:
            memlet = propagation.propagate_memlet(graph, child.edge.data,
                                                  map_entry, False)
            graph.add_edge(stencil, child.edge.src_conn, edge.dst,
                           edge.dst_conn, memlet)

    # Remove map
    graph.remove_nodes_from([map_entry, map_exit])

    # Reshape stencil node computation based on nested map range
    stencil.shape[dim_index] = map_entry.map.range.num_elements()

    # Add dimensions to access and output fields; remember which fields did
    # not previously access this dimension so their code can be rewritten.
    add_dims = set()
    for edge in graph.in_edges(stencil):
        if edge.data.data and len(edge.data.subset) == 3:
            if stencil.accesses[edge.dst_conn][0][dim_index] is False:
                add_dims.add(edge.dst_conn)
            stencil.accesses[edge.dst_conn][0][dim_index] = True
    for edge in graph.out_edges(stencil):
        if edge.data.data and len(edge.data.subset) == 3:
            if stencil.output_fields[edge.src_conn][0][dim_index] is False:
                add_dims.add(edge.src_conn)
            stencil.output_fields[edge.src_conn][0][dim_index] = True

    # Change all instances in the code as well
    for i, stmt in enumerate(stencil.code.code):
        stencil.code.code[i] = DimensionAdder(add_dims, dim_index).visit(stmt)