def state_fission(sdfg: SDFG, subgraph: graph.SubgraphView) -> SDFGState:
    '''
    Given a subgraph, adds a new SDFG state before the state that contains it,
    removes the subgraph from the original state, and connects the two states.

    Subgraph source nodes with more than one outgoing edge in the state, and
    subgraph sink nodes with more than one incoming edge in the state, are
    treated as boundary nodes: they stay in the original state, and a deep
    copy of each is placed in the new state. This way no node object is
    shared between two states.

    :param sdfg: the SDFG that contains the state of the subgraph.
    :param subgraph: the subgraph to remove.
    :return: the newly created SDFG state.
    '''
    # Local import so this function does not rely on a module-level import.
    import copy

    state: SDFGState = subgraph.graph
    newstate = sdfg.add_state_before(state)

    # Save edges before removing nodes
    orig_edges = subgraph.edges()

    # Mark boundary nodes to keep in the original state after fission
    boundary_nodes = set(n for n in subgraph.source_nodes()
                         if state.out_degree(n) > 1)
    boundary_nodes |= set(n for n in subgraph.sink_nodes()
                          if state.in_degree(n) > 1)
    nodes_to_remove = set(subgraph.nodes()) - boundary_nodes

    # Map every subgraph node to the object that goes into the new state:
    # moved nodes map to themselves, kept boundary nodes map to a copy.
    new_nodes = {
        n: (copy.deepcopy(n) if n in boundary_nodes else n)
        for n in subgraph.nodes()
    }

    state.remove_nodes_from(nodes_to_remove)

    for n in new_nodes.values():
        if isinstance(n, nodes.NestedSDFG):
            # Set the new parent state (only on objects living in newstate)
            n.sdfg.parent = newstate

    newstate.add_nodes_from(new_nodes.values())

    for e in orig_edges:
        newstate.add_edge(new_nodes[e.src], e.src_conn, new_nodes[e.dst],
                          e.dst_conn, e.data)

    return newstate
def can_be_applied(sdfg: SDFG, subgraph: SubgraphView):
    """
    Check whether the given subgraph of states can be applied on.

    Returns True only if every condition below holds:

    * every subgraph node is a node of ``sdfg``;
    * ``GPUPersistentKernel.is_gpu_state`` accepts every state;
    * there is exactly one inner entry state and at most one outer entry
      state, joined by at most one edge;
    * there is at most one outer exit state;
    * every subgraph node is reachable from the inner entry state.

    :param sdfg: the SDFG containing the subgraph.
    :param subgraph: the view of candidate states.
    :return: True if all checks pass, False otherwise.
    """
    if not set(subgraph.nodes()) <= set(sdfg.nodes()):
        return False

    # All states need to be GPU states
    if not all(GPUPersistentKernel.is_gpu_state(sdfg, st) for st in subgraph):
        return False

    # for now exactly one inner and one outer entry state
    inner_entries, outer_entries = GPUPersistentKernel.get_entry_states(
        sdfg, subgraph)
    if len(inner_entries) != 1:
        return False
    if len(outer_entries) > 1:
        return False

    entry_state_in = inner_entries.pop()

    if len(outer_entries) == 1:
        connecting = sdfg.edges_between(outer_entries.pop(), entry_state_in)
        if len(connecting) > 1:
            return False

    # for now only one outside state allowed, multiple inner exit states
    # allowed
    outer_exits = GPUPersistentKernel.get_exit_states(sdfg, subgraph)[1]
    if len(outer_exits) > 1:
        return False

    # check reachability: breadth-first walk over an append-only worklist
    worklist = [entry_state_in]
    reachable = {entry_state_in}
    cursor = 0
    while cursor < len(worklist):
        current = worklist[cursor]
        cursor += 1
        fresh = [
            nxt for nxt in subgraph.successors(current)
            if nxt not in reachable
        ]
        worklist.extend(fresh)
        reachable.update(fresh)

    return reachable == set(subgraph.nodes())
def state_fission(sdfg: SDFG, subgraph: graph.SubgraphView) -> SDFGState:
    '''
    Split a subgraph out of its state into a fresh state placed right before
    the original one.

    Nodes that have more incoming or outgoing edges in the state than in the
    subgraph are boundary nodes: they remain in the original state, and a deep
    copy of each is inserted into the new state.

    :param sdfg: the SDFG that contains the state of the subgraph.
    :param subgraph: the subgraph to remove.
    :return: the newly created SDFG state.
    '''
    state: SDFGState = subgraph.graph
    newstate = sdfg.add_state_before(state)

    # Edges must be captured before any node is removed
    saved_edges = subgraph.edges()

    members = subgraph.nodes()
    doomed = set(members)

    # Boundary nodes: connected to something outside the subgraph
    leaking_out = [
        n for n in members
        if len(state.out_edges(n)) > len(subgraph.out_edges(n))
    ]
    leaking_in = [
        n for n in members
        if len(state.in_edges(n)) > len(subgraph.in_edges(n))
    ]
    boundary = leaking_out + leaking_in

    # Original node -> node object that will live in the new state
    replacement = {}
    for member in members:
        replacement[member] = member
    for kept in boundary:
        replacement[kept] = copy.deepcopy(kept)

    doomed.difference_update(boundary)
    state.remove_nodes_from(doomed)

    for candidate in replacement.values():
        if isinstance(candidate, nodes.NestedSDFG):
            # Set the new parent state
            candidate.sdfg.parent = newstate

    newstate.add_nodes_from(replacement.values())

    for edge in saved_edges:
        newstate.add_edge(replacement[edge.src], edge.src_conn,
                          replacement[edge.dst], edge.dst_conn, edge.data)

    return newstate
def apply(self, sdfg: SDFG):
    """
    Move the states of the selected subgraph into a nested "kernel" SDFG and
    replace them in ``sdfg`` by a single launch state.

    The launch state holds a map with schedule ``ScheduleType.GPU_Persistent``
    that wraps the nested SDFG. Streams and register arrays used by the kernel
    are moved into the nested SDFG; all other data is passed as an argument,
    as input if every access to it is read-only and as output otherwise.

    :param sdfg: the SDFG to transform (modified in place).
    """
    subgraph = self.subgraph_view(sdfg)

    entry_states_in, entry_states_out = self.get_entry_states(sdfg, subgraph)
    _, exit_states_out = self.get_exit_states(sdfg, subgraph)

    entry_state_in = entry_states_in.pop()
    entry_state_out = entry_states_out.pop() \
        if len(entry_states_out) > 0 else None
    exit_state_out = exit_states_out.pop() \
        if len(exit_states_out) > 0 else None

    # Common prefix for all generated labels
    prefix = self.kernel_prefix + '_' if self.kernel_prefix != '' else ''

    launch_state = None
    entry_guard_state = None
    exit_guard_state = None

    # generate entry guard state if needed
    if self.include_in_assignment and entry_state_out is not None:
        entry_edge = sdfg.edges_between(entry_state_out, entry_state_in)[0]
        if len(entry_edge.data.assignments) > 0:
            entry_guard_state = sdfg.add_state(
                label=prefix + 'kernel_entry_guard')
            # The condition stays outside the kernel, the assignments move
            # onto the edge from the guard into the kernel.
            sdfg.add_edge(entry_state_out, entry_guard_state,
                          InterstateEdge(entry_edge.data.condition))
            sdfg.add_edge(
                entry_guard_state, entry_state_in,
                InterstateEdge(None, entry_edge.data.assignments))
            sdfg.remove_edge(entry_edge)

            # Update SubgraphView
            new_node_list = subgraph.nodes()
            new_node_list.append(entry_guard_state)
            subgraph = SubgraphView(sdfg, new_node_list)

            launch_state = sdfg.add_state_before(
                entry_guard_state, label=prefix + 'kernel_launch')

    # generate exit guard state
    if exit_state_out is not None:
        exit_guard_state = sdfg.add_state_before(
            exit_state_out, label=prefix + 'kernel_exit_guard')

        # Update SubgraphView
        new_node_list = subgraph.nodes()
        new_node_list.append(exit_guard_state)
        subgraph = SubgraphView(sdfg, new_node_list)

        if launch_state is None:
            launch_state = sdfg.add_state_before(
                exit_state_out, label=prefix + 'kernel_launch')

    # If the launch state doesn't exist at this point, create a stand-alone
    # launch state; it is connected to the remaining states further below.
    if launch_state is None:
        # An outer exit state would already have produced a launch state.
        # (The previous check also required entry_state_in to be None, which
        # can never hold because it is popped from a set of states.)
        assert exit_state_out is None
        launch_state = sdfg.add_state(label=prefix + 'kernel_launch')

    # create sdfg for kernel and fill it with states and edges from
    # subgraph; it will be nested at the end
    kernel_sdfg = SDFG(prefix + 'kernel')

    for edge in subgraph.edges():
        kernel_sdfg.add_edge(edge.src, edge.dst, edge.data)

    # Setting entry node in nested SDFG if no entry guard was created
    if entry_guard_state is None:
        kernel_sdfg.start_state = kernel_sdfg.node_id(entry_state_in)

    for state in subgraph:
        state.parent = kernel_sdfg

    # remove the now nested nodes from the outer sdfg and make sure the
    # launch state is properly connected to remaining states
    sdfg.remove_nodes_from(subgraph.nodes())

    if entry_state_out is not None \
            and len(sdfg.edges_between(entry_state_out, launch_state)) == 0:
        sdfg.add_edge(entry_state_out, launch_state, InterstateEdge())

    if exit_state_out is not None \
            and len(sdfg.edges_between(launch_state, exit_state_out)) == 0:
        sdfg.add_edge(launch_state, exit_state_out, InterstateEdge())

    # Handle data for kernel
    kernel_data = set(node.data for state in kernel_sdfg
                      for node in state.nodes()
                      if isinstance(node, nodes.AccessNode))

    # move Streams and Register data into the nested SDFG
    # normal data will be added as kernel argument
    kernel_args = []
    for data in kernel_data:
        desc = sdfg.arrays[data]
        if (isinstance(desc, dace.data.Stream)
                or (isinstance(desc, dace.data.Array)
                    and desc.storage == StorageType.Register)):
            kernel_sdfg.add_datadesc(data, desc)
            del sdfg.arrays[data]
        else:
            copy_desc = copy.deepcopy(desc)
            copy_desc.transient = False
            copy_desc.storage = StorageType.Default
            kernel_sdfg.add_datadesc(data, copy_desc)
            kernel_args.append(data)

    # read only data will be passed as input, writeable data will be passed
    # as 'output' otherwise kernel cannot write to data
    kernel_args_read = set()
    kernel_args_write = set()
    for data in kernel_args:
        read_only = all(node.access == dtypes.AccessType.ReadOnly
                        for state in kernel_sdfg for node in state
                        if isinstance(node, nodes.AccessNode)
                        and node.data == data)
        if read_only:
            kernel_args_read.add(data)
        else:
            kernel_args_write.add(data)

    # Kernel SDFG is complete at this point
    if self.validate:
        kernel_sdfg.validate()

    # Filling launch state with nested SDFG, map and access nodes
    map_entry, map_exit = launch_state.add_map(
        prefix + 'kernel_launch_map',
        dict(ignore='0'),
        schedule=ScheduleType.GPU_Persistent,
    )

    nested_sdfg = launch_state.add_nested_sdfg(
        kernel_sdfg,
        sdfg,
        kernel_args_read,
        kernel_args_write,
    )

    # Create and connect read only data access nodes
    for arg in kernel_args_read:
        read_node = launch_state.add_read(arg)
        launch_state.add_memlet_path(read_node,
                                     map_entry,
                                     nested_sdfg,
                                     dst_conn=arg,
                                     memlet=Memlet.from_array(
                                         arg, sdfg.arrays[arg]))

    # Create and connect writable data access nodes
    for arg in kernel_args_write:
        write_node = launch_state.add_write(arg)
        launch_state.add_memlet_path(nested_sdfg,
                                     map_exit,
                                     write_node,
                                     src_conn=arg,
                                     memlet=Memlet.from_array(
                                         arg, sdfg.arrays[arg]))

    # Transformation is done
    if self.validate:
        sdfg.validate()
def calculate_topology(self, subgraph):
    '''
    Calculates topology information of the graph.

    Sets the following attributes:

    * ``self._adjacency_list``: for every map entry, the set of other map
      entries that touch a common access node (undirected).
    * ``self._source_maps``: map entries that have no parent map, i.e. no
      other map whose exit feeds them through an access node.
    * ``self._labels``: a unique index per map entry, assigned in topological
      order; maps that become available in the same round are ordered by
      their node ID in the graph for determinism.

    :param subgraph: if truthy, only access nodes of this subgraph and the
                     direct neighbors of the map entries/exits are used to
                     build the adjacency list; otherwise the whole graph.
    :raise RuntimeError: if no topological order can be established (cyclic
                         dependencies or duplicated map entries).
    '''
    graph = self._graph

    self._adjacency_list = {m: set() for m in self._map_entries}
    # helper dict needed for a quick build: map exit -> map entry
    exit_nodes = {graph.exit_node(me): me for me in self._map_entries}

    if subgraph:
        proximity_in = set(ie.src for me in self._map_entries
                           for ie in graph.in_edges(me))
        proximity_out = set(oe.dst for mx in exit_nodes
                            for oe in graph.out_edges(mx))
        extended_subgraph = SubgraphView(
            graph,
            set(
                itertools.chain(subgraph.nodes(), proximity_in,
                                proximity_out)))
        nodes_to_scan = extended_subgraph.nodes()
    else:
        nodes_to_scan = graph.nodes()

    for node in nodes_to_scan:
        if not isinstance(node, nodes.AccessNode):
            continue
        adjacent_entries = set()
        for e in graph.in_edges(node):
            if isinstance(e.src, nodes.MapExit) and e.src in exit_nodes:
                adjacent_entries.add(exit_nodes[e.src])
        for e in graph.out_edges(node):
            if isinstance(e.dst,
                          nodes.MapEntry) and e.dst in self._map_entries:
                adjacent_entries.add(e.dst)

        # bidirectional mapping: every ordered pair of distinct entries
        for entry, other_entry in itertools.permutations(adjacent_entries, 2):
            self._adjacency_list[entry].add(other_entry)

    # get DAG children and parents
    children_dict = defaultdict(set)
    parent_dict = defaultdict(set)

    for map_entry in self._map_entries:
        map_exit = graph.exit_node(map_entry)
        for e in graph.out_edges(map_exit):
            if isinstance(e.dst, nodes.AccessNode):
                for oe in graph.out_edges(e.dst):
                    if oe.dst in self._map_entries:
                        other_entry = oe.dst
                        children_dict[map_entry].add(other_entry)
                        parent_dict[other_entry].add(map_entry)

    # find out source nodes
    # NOTE: looking up parent_dict[me] also registers every parentless map
    # as a key, which the labelling loop below relies on.
    self._source_maps = [
        me for me in self._map_entries if len(parent_dict[me]) == 0
    ]

    # assign a unique id to each map entry according to topological
    # ordering. If on same level, sort according to ID for determinism
    self._labels = {}  # map -> ID
    current_id = 0
    while current_id < len(self._map_entries):
        # get unlabelled maps whose in_degree is 0
        candidates = [
            me for (me, parents) in parent_dict.items()
            if len(parents) == 0 and me not in self._labels
        ]
        if not candidates:
            # Without this guard the loop would never terminate.
            raise RuntimeError(
                'Could not establish a topological order of the map '
                'entries (cyclic dependency or duplicate map entries)')
        candidates.sort(key=graph.node_id)
        for c in candidates:
            self._labels[c] = current_id
            current_id += 1
            # remove the labelled map from the parent set of its children
            for c_child in children_dict[c]:
                parent_dict[c_child].remove(c)
def can_be_applied(sdfg: SDFG, subgraph: SubgraphView) -> bool:
    '''
    Fusible if

    1. Maps have the same access sets and ranges in order
    2. Any nodes in between two maps are AccessNodes only, without WCR
       There is at most one AccessNode only on a path between two maps,
       no other nodes are allowed
    3. The exiting memlets' subsets to an intermediate edge must cover
       the respective incoming memlets' subset into the next map.
       Also, as a limitation, the union of all exiting memlets'
       subsets must be contiguous.

    :param sdfg: the SDFG containing the subgraph.
    :param subgraph: the subgraph whose outermost maps should be fused.
    :return: True if the checks above pass, False otherwise.
    :raise NotImplementedError: if an intermediate node has an incoming edge
                                that does not come from one of the map exits.
    '''
    # get graph
    graph = subgraph.graph
    graph_nodes = set(graph.nodes())
    for node in subgraph.nodes():
        if node not in graph_nodes:
            return False

    # next, get all the maps
    map_entries = helpers.get_outermost_scope_maps(sdfg, graph, subgraph)
    map_exits = [graph.exit_node(map_entry) for map_entry in map_entries]
    maps = [map_entry.map for map_entry in map_entries]

    # 1. basic checks:
    # 1.1 we need to have at least two maps
    if len(maps) <= 1:
        return False

    # 1.2 Special case (not implemented): if a valid permutation could be
    #     established, check 1.3 could be skipped.

    # 1.3 check whether all maps are the same
    base_map = maps[0]
    for current_map in maps:
        if current_map.get_param_num() != base_map.get_param_num():
            return False
        if not all(p1 == p2
                   for (p1, p2) in zip(current_map.params, base_map.params)):
            return False
        if not current_map.range == base_map.range:
            return False

    # 1.4 check whether all map entries have the same schedule
    schedule = map_entries[0].schedule
    if not all(entry.schedule == schedule for entry in map_entries):
        return False

    # 2. check intermediate feasibility
    # see map_fusion.py for similar checks
    # with the restrictions below being more relaxed

    # 2.1 do some preparation work first:
    # calculate all out_nodes and intermediate_nodes
    # definition see in apply()
    node_config = SubgraphFusion.get_adjacent_nodes(sdfg, graph, map_entries)
    _, intermediate_nodes, out_nodes = node_config

    # 2.2 topological feasibility:
    if not SubgraphFusion.check_topo_feasibility(
            sdfg, graph, map_entries, intermediate_nodes, out_nodes):
        return False

    # 2.3 memlet feasibility
    # For each intermediate node, look at whether inner adjacent
    # memlets of the exiting map cover inner adjacent memlets
    # of the next entering map.
    # We also check for any WCRs on the fly.
    for node in intermediate_nodes:
        upper_subsets = set()
        lower_subsets = set()
        # First, determine which dimensions of the memlet ranges
        # change with the map, we do not need to care about the other
        # dimensions.
        try:
            dims_to_discard = SubgraphFusion.get_invariant_dimensions(
                sdfg, graph, map_entries, map_exits, node)
        except NotImplementedError:
            return False

        # find upper_subsets
        for in_edge in graph.in_edges(node):
            # Check the source first: the memlet-path and map lookups below
            # are only meaningful for edges that leave one of the map exits.
            if in_edge.src not in map_exits:
                raise NotImplementedError("Nodes between two maps to be "
                                          "fused with *incoming* edges "
                                          "from outside the maps are not "
                                          "allowed yet.")

            in_in_edge = graph.memlet_path(in_edge)[-2]

            # check for WCRs
            if in_edge.data.wcr:
                # check whether the WCR is actually produced at
                # this edge or further up in the memlet path. If not,
                # we can still fuse!
                subset_params = set(
                    str(s) for s in in_in_edge.data.subset.free_symbols)
                if any(p not in subset_params
                       for p in in_edge.src.map.params):
                    return False

            subset_to_add = dcpy(in_in_edge.data.subset
                                 if in_in_edge.data.data == node.data
                                 else in_in_edge.data.other_subset)
            subset_to_add.pop(dims_to_discard)
            upper_subsets.add(subset_to_add)

        # find lower_subsets
        for out_edge in graph.out_edges(node):
            if out_edge.dst in map_entries:
                # cannot use memlet tree here as there could be
                # not just one map succeeding. Do it manually by matching
                # the entry's IN_x connector with its OUT_x connector.
                for oedge in graph.out_edges(out_edge.dst):
                    if oedge.src_conn[3:] == out_edge.dst_conn[2:]:
                        subset_to_add = dcpy(oedge.data.subset
                                             if oedge.data.data == node.data
                                             else oedge.data.other_subset)
                        subset_to_add.pop(dims_to_discard)
                        lower_subsets.add(subset_to_add)

        # We assume that upper_subsets are contiguous
        # Check for this.
        try:
            contiguous_upper = find_contiguous_subsets(upper_subsets)
            if len(contiguous_upper) > 1:
                return False
        except TypeError:
            warnings.warn(
                'Could not determine whether subset is continuous. '
                'Exiting Check with False.')
            return False

        # now take union of upper subsets
        upper_iter = iter(upper_subsets)
        union_upper = next(upper_iter)
        for subs in upper_iter:
            union_upper = subsets.union(union_upper, subs)
            if not union_upper:
                # something went wrong using union -- we'd rather abort
                return False

        # finally check coverage
        # every lower subset must be completely covered by union_upper
        for lower_subset in lower_subsets:
            if not union_upper.covers(lower_subset):
                return False

    return True
def can_be_applied(sdfg: SDFG, subgraph: SubgraphView) -> bool:
    '''
    Fusible if

    1. Maps have the same access sets and ranges in order
    2. Any nodes in between two maps are AccessNodes only, without WCR
       There is at most one AccessNode only on a path between two maps,
       no other nodes are allowed
    3. The exiting memlets' subsets to an intermediate edge must cover
       the respective incoming memlets' subset into the next map

    :param sdfg: the SDFG containing the subgraph.
    :param subgraph: the subgraph whose highest-scope maps should be fused.
    :return: True if the checks above pass, False otherwise.
    :raise NotImplementedError: if an intermediate node has an incoming edge
                                that does not come from one of the map exits.
    '''
    # get graph
    graph = subgraph.graph
    graph_nodes = set(graph.nodes())
    for node in subgraph.nodes():
        if node not in graph_nodes:
            return False

    # next, get all the maps
    map_entries = helpers.get_highest_scope_maps(sdfg, graph, subgraph)
    map_exits = [graph.exit_node(map_entry) for map_entry in map_entries]
    maps = [map_entry.map for map_entry in map_entries]

    # 1. check whether all map ranges and indices are the same
    if len(maps) <= 1:
        return False
    base_map = maps[0]
    for current_map in maps:
        if current_map.get_param_num() != base_map.get_param_num():
            return False
        if not all(p1 == p2
                   for (p1, p2) in zip(current_map.params, base_map.params)):
            return False
        if not current_map.range == base_map.range:
            return False

    # 1.1 check whether all map entries have the same schedule
    schedule = map_entries[0].schedule
    if not all(entry.schedule == schedule for entry in map_entries):
        return False

    # 2. check intermediate feasibility
    # see map_fusion.py for similar checks
    # we are being more relaxed here

    # 2.1 do some preparation work first:
    # calculate all out_nodes and intermediate_nodes
    # definition see in apply()
    intermediate_nodes = set()
    out_nodes = set()
    for map_exit in map_exits:
        for exit_edge in graph.out_edges(map_exit):
            current_node = exit_edge.dst
            successors = graph.out_edges(current_node)
            if len(successors) == 0:
                out_nodes.add(current_node)
            else:
                for dst_edge in successors:
                    if dst_edge.dst in map_entries:
                        intermediate_nodes.add(current_node)
                    else:
                        out_nodes.add(current_node)

    # 2.2 topological feasibility:
    # For each intermediate and out node: must never reach any map
    # entry if it is not connected to map entry immediately
    visited = set()  # for memoization purposes

    def visit_descendants(graph, node, visited, map_entries):
        # Reaching one of the maps to be fused is exactly what is forbidden.
        # (This check was missing, so the traversal could never fail.)
        if node in map_entries:
            return False
        # if we have already been at this node
        if node in visited:
            return True
        # not necessary to add if there aren't any other in connections
        if len(graph.in_edges(node)) > 1:
            visited.add(node)
        for oedge in graph.out_edges(node):
            if not visit_descendants(graph, oedge.dst, visited, map_entries):
                return False
        return True

    for node in intermediate_nodes | out_nodes:
        # these nodes must not lead to a map entry
        nodes_to_check = set()
        for oedge in graph.out_edges(node):
            if oedge.dst not in map_entries:
                nodes_to_check.add(oedge.dst)

        for forbidden_node in nodes_to_check:
            if not visit_descendants(graph, forbidden_node, visited,
                                     map_entries):
                return False

    # 2.3 memlet feasibility
    # For each intermediate node, look at whether inner adjacent
    # memlets of the exiting map cover inner adjacent memlets
    # of the next entering map.
    # We also check for any WCRs on the fly.
    for node in intermediate_nodes:
        upper_subsets = set()
        lower_subsets = set()
        # First, determine which dimensions of the memlet ranges
        # change with the map, we do not need to care about the other
        # dimensions.
        dims_to_discard = SubgraphFusion.get_invariant_dimensions(
            sdfg, graph, map_entries, map_exits, node)

        # find upper_subsets
        for in_edge in graph.in_edges(node):
            # first check for WCRs
            if in_edge.data.wcr:
                return False
            if in_edge.src in map_exits:
                inner_edge = graph.memlet_path(in_edge)[-2]
                subset_to_add = dcpy(inner_edge.data.subset
                                     if inner_edge.data.data == node.data
                                     else inner_edge.data.other_subset)
                subset_to_add.pop(dims_to_discard)
                upper_subsets.add(subset_to_add)
            else:
                raise NotImplementedError("Nodes between two maps to be "
                                          "fused with *incoming* edges "
                                          "from outside the maps are not "
                                          "allowed yet.")

        # find lower_subsets
        for out_edge in graph.out_edges(node):
            if out_edge.dst in map_entries:
                # cannot use memlet tree here as there could be
                # not just one map succeeding. Do it manually by matching
                # the entry's IN_x connector with its OUT_x connector.
                for oedge in graph.out_edges(out_edge.dst):
                    if oedge.src_conn[3:] == out_edge.dst_conn[2:]:
                        # Select the subset from this very edge; previously a
                        # stale edge variable from an earlier loop was used.
                        subset_to_add = dcpy(oedge.data.subset
                                             if oedge.data.data == node.data
                                             else oedge.data.other_subset)
                        subset_to_add.pop(dims_to_discard)
                        lower_subsets.add(subset_to_add)

        upper_iter = iter(upper_subsets)
        union_upper = next(upper_iter)

        # TODO: check that the upper subsets of each data array are
        # contiguous (currently assumed), or do the full check if possible
        # (intersection needed).

        # FORNOW: just emit a warning if unsure
        for lower_subset in lower_subsets:
            covers = False
            for upper_subset in upper_subsets:
                if upper_subset.covers(lower_subset):
                    covers = True
                    break
            if not covers:
                warnings.warn(
                    f"WARNING: For node {node}, please check assure that "
                    "incoming memlets cover outgoing ones. "
                    "Ambiguous check (WIP).")

        # now take union of upper subsets
        for subs in upper_iter:
            union_upper = subsets.union(union_upper, subs)
            if not union_upper:
                # something went wrong using union -- we'd rather abort
                return False

        # finally check coverage
        for lower_subset in lower_subsets:
            if not union_upper.covers(lower_subset):
                return False

    return True
def nest_state_subgraph(sdfg: SDFG,
                        state: SDFGState,
                        subgraph: SubgraphView,
                        name: Optional[str] = None,
                        full_data: bool = False) -> nodes.NestedSDFG:
    """ Turns a state subgraph into a nested SDFG. Operates in-place.

        Edges entering subgraph source nodes and leaving subgraph sink nodes
        become ``__in_*`` / ``__out_*`` connectors of the nested SDFG. Edges
        with an empty memlet are skipped and therefore not reconnected.

        :param sdfg: The SDFG containing the state subgraph.
        :param state: The state containing the subgraph.
        :param subgraph: Subgraph to nest.
        :param name: An optional name for the nested SDFG.
        :param full_data: If True, nests entire input/output data.
        :return: The nested SDFG node.
        :raise KeyError: Some or all nodes in the subgraph are not located in
                         this state, or the state does not belong to the given
                         SDFG.
        :raise ValueError: The subgraph is contained in more than one scope.
    """
    if state.parent != sdfg:
        raise KeyError('State does not belong to given SDFG')
    if subgraph.graph != state:
        raise KeyError('Subgraph does not belong to given state')

    # Set view of the subgraph nodes for the repeated membership tests below
    subgraph_nodes = set(subgraph.nodes())

    # Find the top-level scope
    scope_tree = state.scope_tree()
    scope_dict = state.scope_dict()
    scope_dict_children = state.scope_dict(True)
    top_scopenode = -1  # Initialized to -1 since "None" already means top-level

    for node in subgraph.nodes():
        if node not in scope_dict:
            raise KeyError('Node not found in state')

        # If scope entry/exit, ensure entire scope is in subgraph
        if isinstance(node, nodes.EntryNode):
            scope_nodes = scope_dict_children[node]
            if any(n not in subgraph_nodes for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (entry)')
        elif isinstance(node, nodes.ExitNode):
            entry = state.entry_node(node)
            scope_nodes = scope_dict_children[entry] + [entry]
            if any(n not in subgraph_nodes for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (exit)')

        scope_node = scope_dict[node]
        if scope_node not in subgraph_nodes:
            if top_scopenode != -1 and top_scopenode != scope_node:
                raise ValueError(
                    'Subgraph is contained in more than one scope')
            top_scopenode = scope_node

    scope = scope_tree[top_scopenode]
    ###

    # Collect inputs and outputs of the nested SDFG
    inputs: List[MultiConnectorEdge] = []
    outputs: List[MultiConnectorEdge] = []
    for node in subgraph.source_nodes():
        inputs.extend(state.in_edges(node))
    for node in subgraph.sink_nodes():
        outputs.extend(state.out_edges(node))

    # Collect transients not used outside of subgraph (will be removed of
    # top-level graph)
    data_in_subgraph = set(n.data for n in subgraph_nodes
                           if isinstance(n, nodes.AccessNode))
    # Find other occurrences in SDFG
    other_nodes = set(
        n.data for s in sdfg.nodes() for n in s.nodes()
        if isinstance(n, nodes.AccessNode) and n not in subgraph_nodes)
    subgraph_transients = set()
    for data in data_in_subgraph:
        datadesc = sdfg.arrays[data]
        if datadesc.transient and data not in other_nodes:
            subgraph_transients.add(data)

    # All transients of edges between code nodes are also added to nested
    # graph (edges with an empty memlet carry no data and are ignored)
    for edge in subgraph.edges():
        if (isinstance(edge.src, nodes.CodeNode)
                and isinstance(edge.dst, nodes.CodeNode)
                and edge.data.data is not None):
            subgraph_transients.add(edge.data.data)

    # Collect data used in access nodes within subgraph (will be referenced in
    # full upon nesting)
    input_arrays = set()
    output_arrays = set()
    for node in subgraph.nodes():
        if (isinstance(node, nodes.AccessNode)
                and node.data not in subgraph_transients):
            if state.out_degree(node) > 0:
                input_arrays.add(node.data)
            if state.in_degree(node) > 0:
                output_arrays.add(node.data)

    # Create the nested SDFG
    nsdfg = SDFG(name or 'nested_' + state.label)

    # Transients are added to the nested graph as-is
    for dname in subgraph_transients:
        nsdfg.add_datadesc(dname, sdfg.arrays[dname])

    # Input/output data that are not source/sink nodes are added to the graph
    # as non-transients
    for dname in (input_arrays | output_arrays):
        datadesc = copy.deepcopy(sdfg.arrays[dname])
        datadesc.transient = False
        nsdfg.add_datadesc(dname, datadesc)

    # Connected source/sink nodes outside subgraph become global data
    # descriptors in nested SDFG. Names are stored together with their edge so
    # that skipping an empty memlet cannot shift the name/edge pairing.
    input_pairs = []
    for edge in inputs:
        if edge.data.data is None:  # Skip edges with an empty memlet
            continue
        datadesc = copy.deepcopy(sdfg.arrays[edge.data.data])
        datadesc.transient = False
        if not full_data:
            datadesc.shape = edge.data.subset.size()
        conn = nsdfg.add_datadesc('__in_' + edge.data.data,
                                  datadesc,
                                  find_new_name=True)
        input_pairs.append((conn, edge))

    output_pairs = []
    for edge in outputs:
        if edge.data.data is None:  # Skip edges with an empty memlet
            continue
        datadesc = copy.deepcopy(sdfg.arrays[edge.data.data])
        datadesc.transient = False
        if not full_data:
            datadesc.shape = edge.data.subset.size()
        conn = nsdfg.add_datadesc('__out_' + edge.data.data,
                                  datadesc,
                                  find_new_name=True)
        output_pairs.append((conn, edge))

    input_names = [conn for conn, _ in input_pairs]
    output_names = [conn for conn, _ in output_pairs]
    ###################

    # Add scope symbols to the nested SDFG
    for v in scope.defined_vars:
        if v in sdfg.symbols:
            sym = sdfg.symbols[v]
            nsdfg.add_symbol(v, sym.dtype)

    # Create nested state
    nstate = nsdfg.add_state()

    # Add subgraph nodes and edges to nested state
    nstate.add_nodes_from(subgraph.nodes())
    for e in subgraph.edges():
        nstate.add_edge(e.src, e.src_conn, e.dst, e.dst_conn, e.data)

    # Modify nested SDFG parents in subgraph
    for node in subgraph.nodes():
        if isinstance(node, nodes.NestedSDFG):
            node.sdfg.parent = nstate
            node.sdfg.parent_sdfg = nsdfg

    # Add access nodes and edges as necessary
    edges_to_offset = []
    for conn, edge in input_pairs:
        node = nstate.add_read(conn)
        new_memlet = copy.deepcopy(edge.data)
        new_memlet.data = conn
        edges_to_offset.append(
            (edge,
             nstate.add_edge(node, None, edge.dst, edge.dst_conn,
                             new_memlet)))
    for conn, edge in output_pairs:
        node = nstate.add_write(conn)
        new_memlet = copy.deepcopy(edge.data)
        new_memlet.data = conn
        edges_to_offset.append(
            (edge,
             nstate.add_edge(edge.src, edge.src_conn, node, None,
                             new_memlet)))

    # Offset memlet paths inside nested SDFG according to subsets
    for original_edge, new_edge in edges_to_offset:
        for edge in nstate.memlet_tree(new_edge):
            edge.data.data = new_edge.data.data
            if not full_data:
                edge.data.subset.offset(original_edge.data.subset, True)

    # Add nested SDFG node to the input state
    nested_sdfg = state.add_nested_sdfg(nsdfg, None,
                                        set(input_names) | input_arrays,
                                        set(output_names) | output_arrays)

    # Reconnect memlets to nested SDFG
    for conn, edge in input_pairs:
        if full_data:
            data = Memlet.from_array(edge.data.data,
                                     sdfg.arrays[edge.data.data])
        else:
            data = edge.data
        state.add_edge(edge.src, edge.src_conn, nested_sdfg, conn, data)
    for conn, edge in output_pairs:
        if full_data:
            data = Memlet.from_array(edge.data.data,
                                     sdfg.arrays[edge.data.data])
        else:
            data = edge.data
        state.add_edge(nested_sdfg, conn, edge.dst, edge.dst_conn, data)

    # Connect access nodes to internal input/output data as necessary
    scope_entry = scope.entry
    scope_exit = scope.exit
    for dname in input_arrays:
        node = state.add_read(dname)
        if scope_entry is not None:
            state.add_nedge(scope_entry, node, EmptyMemlet())
        state.add_edge(node, None, nested_sdfg, dname,
                       Memlet.from_array(dname, sdfg.arrays[dname]))
    for dname in output_arrays:
        node = state.add_write(dname)
        if scope_exit is not None:
            state.add_nedge(node, scope_exit, EmptyMemlet())
        state.add_edge(nested_sdfg, dname, node, None,
                       Memlet.from_array(dname, sdfg.arrays[dname]))

    # Remove subgraph nodes from graph
    state.remove_nodes_from(subgraph.nodes())

    # Remove subgraph transients from top-level graph
    for transient in subgraph_transients:
        del sdfg.arrays[transient]

    return nested_sdfg