def expansion(node, parent_state, parent_sdfg, **kwargs):
    """ Expands the library node into a nested SDFG that contains a single
        mapped tasklet computing ``_res[i, j] = alpha * _x[i] * _y[j] + _A[i, j]``.

        :param node: The library node to expand. It is validated first, and
                     its ``label`` and ``alpha`` attributes are used.
        :param parent_state: The state that contains ``node``.
        :param parent_sdfg: The SDFG that contains ``parent_state``.
        :param kwargs: Accepted but not used by this expansion.
        :return: A nested SDFG node with connectors ``_A``, ``_x``, ``_y``
                 (inputs) and ``_res`` (output).
        :raises NotImplementedError: If any connected memlet does not cover
                                     the full range of its array.
    """
    node.validate(parent_sdfg, parent_state)

    in_conns = ('_A', '_x', '_y')
    out_conns = ('_res', )

    # Exactly one edge is taken per connector (the first one found).
    incoming = [
        next(parent_state.in_edges_by_connector(node, conn))
        for conn in in_conns
    ]
    outgoing = [
        next(parent_state.out_edges_by_connector(node, conn))
        for conn in out_conns
    ]
    connected = list(zip(in_conns, incoming)) + list(zip(out_conns, outgoing))

    # Connector name -> data descriptor in the parent SDFG (inputs first).
    descriptors = {}
    for conn, edge in connected:
        descriptors[conn] = parent_sdfg.arrays[edge.data.data]

    # TODO: Support memlet subsets
    for conn, edge in connected:
        if edge.data.subset != sbs.Range.from_array(descriptors[conn]):
            raise NotImplementedError

    nested = dace.SDFG(f'{node.label}_sdfg')
    for symbol in ('M', 'N'):
        nested.add_symbol(symbol, int)
    nested.add_symbol('alpha', descriptors['_A'].dtype)

    # Inside the nested SDFG every connector is backed by a non-transient
    # copy of the outer descriptor.
    for conn, outer_desc in descriptors.items():
        inner_desc = copy.deepcopy(outer_desc)
        inner_desc.transient = False
        nested.add_datadesc(conn, inner_desc)

    map_ranges = {'_i': '0:M', '_j': '0:N'}
    tasklet_inputs = {
        'a': mm.Memlet('_A[_i, _j]'),
        'xin': mm.Memlet('_x[_i]'),
        'yin': mm.Memlet('_y[_j]'),
    }
    tasklet_outputs = {'aout': mm.Memlet('_res[_i, _j]')}

    body = nested.add_state()
    body.add_mapped_tasklet(
        'ger',
        map_ranges,
        tasklet_inputs,
        'aout = alpha * xin * yin + a',
        tasklet_outputs,
        external_edges=True,
    )

    # The map bounds are taken from the shape of the result array.
    rows, cols = descriptors['_res'].shape[0], descriptors['_res'].shape[1]
    symbol_mapping = {'M': rows, 'N': cols, 'alpha': node.alpha}

    return nodes.NestedSDFG(node.label, nested, set(in_conns), set(out_conns),
                            symbol_mapping)
class CopyToDevice(pattern_matching.Transformation):
    """ Implements the copy-to-device transformation, which copies a nested
        SDFG and its dependencies to a given device.

        The transformation changes all data storage types of a nested SDFG to
        the given `storage` property, and creates new arrays and copies around
        the nested SDFG to that storage.
    """

    # Pattern node used to match a single nested SDFG node.
    _nested_sdfg = nodes.NestedSDFG("", graph.OrderedDiGraph(), {}, {})

    storage = properties.Property(dtype=dtypes.StorageType,
                                  desc="Nested SDFG storage",
                                  choices=dtypes.StorageType,
                                  from_string=lambda x: dtypes.StorageType[x],
                                  default=dtypes.StorageType.Default)

    @staticmethod
    def annotates_memlets():
        """ Always returns True. """
        return True

    @staticmethod
    def expressions():
        """ Returns the single pattern of this transformation: a path graph
            that consists of the nested SDFG pattern node only.
        """
        return [sdutil.node_path_graph(CopyToDevice._nested_sdfg)]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether the matched nested SDFG can be copied to a device.

            :param graph: The state in which the candidate was matched.
            :param candidate: Mapping from pattern nodes to node indices in
                              ``graph``.
            :param expr_index: Unused.
            :param sdfg: The SDFG whose data descriptors are inspected.
            :param strict: Unused.
            :return: False if any memlet path of the nested SDFG starts or
                     ends in a stream, or if an adjacent edge has
                     write-conflict resolution on more than one element;
                     True otherwise.
        """
        nested_sdfg = graph.nodes()[candidate[CopyToDevice._nested_sdfg]]

        for edge in graph.all_edges(nested_sdfg):
            # Stream inputs/outputs not allowed
            path = graph.memlet_path(edge)
            path_src = path[0].src
            path_dst = path[-1].dst
            if ((isinstance(path_src, nodes.AccessNode)
                 and isinstance(sdfg.arrays[path_src.data], data.Stream))
                    or (isinstance(path_dst, nodes.AccessNode) and isinstance(
                        sdfg.arrays[path_dst.data], data.Stream))):
                return False
            # WCR outputs with arrays are not allowed
            if (edge.data.wcr is not None
                    and edge.data.subset.num_elements() != 1):
                return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the label of the matched nested SDFG node. """
        nested_sdfg = graph.nodes()[candidate[CopyToDevice._nested_sdfg]]
        return nested_sdfg.label

    @staticmethod
    def _add_device_data(sdfg, memlet, storage, created_arrays, suffix):
        """ Returns the name of the transient that mirrors ``memlet.data`` in
            ``storage``, creating it in ``sdfg`` unless a descriptor with the
            default name was already created during this application.

            :param sdfg: The SDFG to add the transient to.
            :param memlet: Memlet whose data (and bounding box, for arrays)
                           determines the new descriptor.
            :param storage: Storage type of the new transient.
            :param created_arrays: Set of names created so far (updated in
                                   place).
            :param suffix: ``'_in'`` or ``'_out'``, appended to the name.
            :return: The name of the device-side data descriptor.
            :raises NotImplementedError: If the data is neither an array nor
                                         a scalar.
        """
        memdata = sdfg.arrays[memlet.data]
        name = 'device_' + memlet.data + suffix
        if name in created_arrays:
            return name

        if isinstance(memdata, data.Array):
            name, _ = sdfg.add_array(name,
                                     shape=[
                                         symbolic.overapproximate(r)
                                         for r in memlet.bounding_box_size()
                                     ],
                                     dtype=memdata.dtype,
                                     transient=True,
                                     storage=storage,
                                     find_new_name=True)
        elif isinstance(memdata, data.Scalar):
            # find_new_name is passed for both inputs and outputs so that an
            # already-existing descriptor name is handled the same way in
            # every branch.
            name, _ = sdfg.add_scalar(name,
                                      dtype=memdata.dtype,
                                      transient=True,
                                      storage=storage,
                                      find_new_name=True)
        else:
            raise NotImplementedError(
                'Copying data of type %s to a device is not supported' %
                type(memdata).__name__)
        created_arrays.add(name)
        return name

    @staticmethod
    def _offset_to_origin(target, memlet):
        """ Subtracts, per dimension, the first element of the corresponding
            entry of ``memlet.subset`` from ``target.subset`` (in place).

            :param target: Memlet whose subset is modified.
            :param memlet: Original memlet that provides the offsets.
        """
        for ind, r in enumerate(memlet.subset):
            if isinstance(memlet.subset[ind], tuple):
                begin = memlet.subset[ind][0] - r[0]
                end = memlet.subset[ind][1] - r[0]
                step = memlet.subset[ind][2]
                target.subset[ind] = (begin, end, step)
            else:
                target.subset[ind] -= r[0]

    @staticmethod
    def _route_through_device(sdfg, state, edge, storage, created_arrays,
                              incoming):
        """ Replaces ``edge`` by two edges that pass through a new access
            node of a device-side transient.

            :param sdfg: The SDFG that owns the data descriptors.
            :param state: The state that contains ``edge``.
            :param edge: Edge adjacent to the nested SDFG node.
            :param storage: Storage type for the new transient.
            :param created_arrays: Set of names created so far (updated in
                                   place).
            :param incoming: True if ``edge`` enters the nested SDFG, False
                             if it leaves it.
        """
        src, src_conn, dst, dst_conn, memlet = edge
        if memlet.data is None:
            # Edges whose memlet carries no data are left untouched.
            return

        name = CopyToDevice._add_device_data(sdfg, memlet, storage,
                                             created_arrays,
                                             '_in' if incoming else '_out')
        data_node = nodes.AccessNode(name)

        to_data_mm = dcpy(memlet)
        from_data_mm = dcpy(memlet)

        # The memlet on the nested-SDFG side refers to the new transient and
        # is re-offset; the other memlet keeps the original data and subset.
        device_mm = from_data_mm if incoming else to_data_mm
        device_mm.data = name
        CopyToDevice._offset_to_origin(device_mm, memlet)

        state.remove_edge(edge)
        state.add_edge(src, src_conn, data_node, None, to_data_mm)
        state.add_edge(data_node, None, dst, dst_conn, from_data_mm)

    def apply(self, sdfg):
        """ Applies the transformation: routes every data-carrying edge of the
            matched nested SDFG through a transient in ``self.storage`` and
            then changes the storage of the nested SDFG's own data.

            :param sdfg: The SDFG that contains the matched state.
        """
        state = sdfg.nodes()[self.state_id]
        nested_sdfg = state.nodes()[self.subgraph[CopyToDevice._nested_sdfg]]
        storage = self.storage
        created_arrays = set()

        # Edges are removed and added while looping, so iterate over copies.
        for edge in list(state.in_edges(nested_sdfg)):
            CopyToDevice._route_through_device(sdfg, state, edge, storage,
                                               created_arrays, True)

        for edge in list(state.out_edges(nested_sdfg)):
            CopyToDevice._route_through_device(sdfg, state, edge, storage,
                                               created_arrays, False)

        # Change storage for all data inside nested SDFG to device.
        change_storage(nested_sdfg.sdfg, storage)
class InlineSDFG(transformation.Transformation):
    """ Inlines a single-state nested SDFG into a top-level SDFG.

        In particular, the steps taken are:

        1. All transient arrays become transients of the parent
        2. If a source/sink node is one of the inputs/outputs:
          a. Remove it
          b. Reconnect through external edges (map/accessnode)
          c. Replace and reoffset memlets with external data descriptor
        3. If other nodes carry the names of inputs/outputs:
          a. Replace data with external data descriptor
          b. Replace and reoffset memlets with external data descriptor
        4. If source/sink node is not connected to a source/destination, and
           the nested SDFG is in a scope, connect to scope with empty memlets
        5. Remove all unused external inputs/output memlet paths
        6. Remove isolated nodes resulting from previous step
    """

    # Pattern node used to match a single nested SDFG node.
    _nested_sdfg = nodes.NestedSDFG('_', sd.SDFG('_'), {}, {})

    @staticmethod
    def annotates_memlets():
        """ Always returns True. """
        return True

    @staticmethod
    def expressions():
        """ Returns the single pattern of this transformation: a path graph
            that consists of the nested SDFG pattern node only.
        """
        return [sdutil.node_path_graph(InlineSDFG._nested_sdfg)]

    @staticmethod
    def _check_strides(inner_strides: List[symbolic.SymbolicType],
                       outer_strides: List[symbolic.SymbolicType],
                       memlet: Memlet,
                       nested_sdfg: nodes.NestedSDFG) -> bool:
        """ Returns True if the strides of the inner array can be matched
            to the strides of the outer array upon inlining. Takes into
            consideration memlet (un)squeeze and nested SDFG symbol mapping.

            :param inner_strides: The strides of the array inside the nested
                                  SDFG.
            :param outer_strides: The strides of the array in the external
                                  SDFG.
            :param memlet: The memlet that connects the outer array to the
                           nested SDFG; dimensions of its subset with size 1
                           are ignored in ``outer_strides``.
            :param nested_sdfg: Nested SDFG node with symbol mapping.
            :return: True if all strides match, False otherwise.
        """
        # Take unsqueezing into account
        dims_to_ignore = [
            i for i, s in enumerate(memlet.subset.size()) if s == 1
        ]
        ostrides = [
            ostr for i, ostr in enumerate(outer_strides)
            if i not in dims_to_ignore
        ]
        if not ostrides:
            # Every dimension was ignored: compare against a single unit
            # stride.
            ostrides = [1]
        if len(ostrides) != len(inner_strides):
            return False

        # Replace all inner symbols based on symbol mapping
        repldict = {
            symbolic.pystr_to_symbolic(k): symbolic.pystr_to_symbolic(v)
            for k, v in nested_sdfg.symbol_mapping.items()
        }
        istrides = [
            istr.subs(repldict) if symbolic.issymbolic(istr) else istr
            for istr in inner_strides
        ]

        return all(istr == ostr for istr, ostr in zip(istrides, ostrides))

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether the matched nested SDFG can be inlined.

            :param graph: The state in which the candidate was matched.
            :param candidate: Mapping from pattern nodes to node indices in
                              ``graph``.
            :param expr_index: Unused.
            :param sdfg: Unused.
            :param strict: Unused.
            :return: False if the node is marked ``no_inline``, if the nested
                     SDFG does not have exactly one state, if a connector has
                     more than one edge, or (inside a scope) if one of the
                     access-node conditions checked below is violated; True
                     otherwise.
        """
        nested_sdfg = graph.nodes()[candidate[InlineSDFG._nested_sdfg]]
        if nested_sdfg.no_inline:
            return False
        if len(nested_sdfg.sdfg.nodes()) != 1:
            return False

        # Ensure every connector has one incoming/outgoing edge
        in_connectors = set()
        out_connectors = set()
        for edge in graph.in_edges(nested_sdfg):
            if edge.dst_conn in in_connectors:
                return False
            in_connectors.add(edge.dst_conn)
        for edge in graph.out_edges(nested_sdfg):
            if edge.src_conn in out_connectors:
                return False
            out_connectors.add(edge.src_conn)

        # Ensure output connectors have no additional outputs (if in a scope),
        # and ensure no two connectors are directly connected to each other
        if graph.entry_node(nested_sdfg) is not None:
            all_connectors = in_connectors | out_connectors
            nstate = nested_sdfg.sdfg.node(0)
            for node in nstate.nodes():
                if isinstance(node, nodes.AccessNode):
                    if (node.data in out_connectors
                            and nstate.out_degree(node) > 0
                            and (node.data not in in_connectors
                                 or nstate.in_degree(node) > 0)):
                        return False
                    if (node.data in in_connectors
                            and any(e.dst.data in all_connectors
                                    for e in nstate.out_edges(node)
                                    if isinstance(e.dst, nodes.AccessNode))):
                        return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the label of the graph in which the match was found. """
        return graph.label

    def _remove_edge_path(self,
                          state: SDFGState,
                          edge_map: Dict[str, MultiConnectorEdge],
                          unused: Set[str],
                          reverse: bool = False) -> List[MultiConnectorEdge]:
        """ Remove all edges along a path, until memlet tree contains siblings
            that should not be removed. Removes resulting isolated nodes as
            well. Operates in place.

            :param state: The state in which to remove edges.
            :param edge_map: Mapping from identifier to edge, used as a
                             predicate for removal.
            :param unused: Set of edge identifiers to remove.
            :param reverse: If False, removes forward in path, otherwise
                            backward.
            :return: List of edges from removed nodes at the path's end.
        """

        if reverse:
            edge_func = lambda e: state.out_edges(e.src)
            edge_pred = lambda pedge, e: e.src_conn == pedge.src_conn
        else:
            edge_func = lambda e: state.in_edges(e.dst)
            edge_pred = lambda pedge, e: e.dst_conn == pedge.dst_conn

        result = []

        for identifier, edge in edge_map.items():
            if identifier in unused:
                path = state.memlet_path(edge)
                pedge = None
                for pedge in (reversed(path) if reverse else path):
                    # If there are no other edges, it is safe to remove
                    if len([
                            e for e in edge_func(pedge) if edge_pred(pedge, e)
                    ]) == 1:
                        # Remove connectors as well
                        state.remove_edge_and_connectors(pedge)
                    else:
                        break
                else:  # Reached terminus without breaking, remove external node
                    if pedge is not None:
                        node = pedge.src if reverse else pedge.dst

                        # Keep track of edges on the other end of these nodes,
                        # they will be used to reconnect to first/last
                        # occurrence of access nodes in the inlined subgraph.
                        if reverse:
                            result.extend(state.in_edges(node))
                        else:
                            result.extend(state.out_edges(node))

                        state.remove_node(node)

        return result

    def apply(self, sdfg: SDFG):
        """ Inlines the matched nested SDFG node into its parent state,
            following the steps listed in the class docstring. Modifies
            ``sdfg`` and the nested SDFG in place.

            :param sdfg: The SDFG that contains the matched state.
        """
        state: SDFGState = sdfg.nodes()[self.state_id]
        nsdfg_node = state.nodes()[self.subgraph[InlineSDFG._nested_sdfg]]
        nsdfg: SDFG = nsdfg_node.sdfg
        nstate: SDFGState = nsdfg.nodes()[0]

        # Propagate a non-default schedule of the node into the nested SDFG
        # before its contents are moved out.
        if nsdfg_node.schedule is not dtypes.ScheduleType.Default:
            infer_types.set_default_schedule_and_storage_types(
                nsdfg, nsdfg_node.schedule)

        nsdfg_scope_entry = state.entry_node(nsdfg_node)
        nsdfg_scope_exit = (state.exit_node(nsdfg_scope_entry)
                            if nsdfg_scope_entry is not None else None)

        #######################################################
        # Collect and update top-level SDFG metadata

        # Global/init/exit code
        for loc, code in nsdfg.global_code.items():
            sdfg.append_global_code(code.code, loc)
        for loc, code in nsdfg.init_code.items():
            sdfg.append_init_code(code.code, loc)
        for loc, code in nsdfg.exit_code.items():
            sdfg.append_exit_code(code.code, loc)

        # Constants
        for cstname, cstval in nsdfg.constants.items():
            if cstname in sdfg.constants:
                if cstval != sdfg.constants[cstname]:
                    warnings.warn('Constant value mismatch for "%s" while '
                                  'inlining SDFG. Inner = %s != %s = outer' %
                                  (cstname, cstval, sdfg.constants[cstname]))
            else:
                sdfg.add_constant(cstname, cstval)

        # Find original source/destination edges (there is only one edge per
        # connector, according to match)
        inputs: Dict[str, MultiConnectorEdge] = {}
        outputs: Dict[str, MultiConnectorEdge] = {}
        input_set: Dict[str, str] = {}
        output_set: Dict[str, str] = {}
        for e in state.in_edges(nsdfg_node):
            inputs[e.dst_conn] = e
            input_set[e.data.data] = e.dst_conn
        for e in state.out_edges(nsdfg_node):
            outputs[e.src_conn] = e
            output_set[e.data.data] = e.src_conn

        # Access nodes that need to be reshaped
        reshapes: Set[str] = set()
        for aname, array in nsdfg.arrays.items():
            if array.transient:
                continue
            edge = None
            if aname in inputs:
                edge = inputs[aname]
                if len(array.shape) > len(edge.data.subset):
                    reshapes.add(aname)
                    continue
            if aname in outputs:
                edge = outputs[aname]
                if len(array.shape) > len(edge.data.subset):
                    reshapes.add(aname)
                    continue
            if edge is not None and not InlineSDFG._check_strides(
                    array.strides, sdfg.arrays[edge.data.data].strides,
                    edge.data, nsdfg_node):
                reshapes.add(aname)

        # Replace symbols using invocation symbol mapping
        # Two-step replacement (N -> __dacesym_N --> map[N]) to avoid clashes
        for symname, symvalue in nsdfg_node.symbol_mapping.items():
            if str(symname) != str(symvalue):
                nsdfg.replace(symname, '__dacesym_' + symname)
        for symname, symvalue in nsdfg_node.symbol_mapping.items():
            if str(symname) != str(symvalue):
                nsdfg.replace('__dacesym_' + symname, symvalue)

        # All transients become transients of the parent (if data already
        # exists, find new name)
        # Mapping from nested transient name to top-level name
        transients: Dict[str, str] = {}
        for node in nstate.nodes():
            if isinstance(node, nodes.AccessNode):
                datadesc = nsdfg.arrays[node.data]
                if node.data not in transients and datadesc.transient:
                    name = sdfg.add_datadesc('%s_%s' %
                                             (nsdfg.label, node.data),
                                             datadesc,
                                             find_new_name=True)
                    transients[node.data] = name

        # All transients of edges between code nodes are also added to parent
        for edge in nstate.edges():
            if (isinstance(edge.src, nodes.CodeNode)
                    and isinstance(edge.dst, nodes.CodeNode)):
                if edge.data.data is not None:
                    datadesc = nsdfg.arrays[edge.data.data]
                    if edge.data.data not in transients and datadesc.transient:
                        name = sdfg.add_datadesc(
                            '%s_%s' % (nsdfg.label, edge.data.data),
                            datadesc,
                            find_new_name=True)
                        transients[edge.data.data] = name

        # Collect nodes to add to top-level graph
        new_incoming_edges: Dict[nodes.Node, MultiConnectorEdge] = {}
        new_outgoing_edges: Dict[nodes.Node, MultiConnectorEdge] = {}

        source_accesses = set()
        sink_accesses = set()
        for node in nstate.source_nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.data not in transients
                    and node.data not in reshapes):
                new_incoming_edges[node] = inputs[node.data]
                source_accesses.add(node)
        for node in nstate.sink_nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.data not in transients
                    and node.data not in reshapes):
                new_outgoing_edges[node] = outputs[node.data]
                sink_accesses.add(node)

        #######################################################
        # Replace data on inlined SDFG nodes/edges

        # Replace data names with their top-level counterparts
        repldict = {}
        repldict.update(transients)
        repldict.update({
            k: v.data.data
            for k, v in itertools.chain(inputs.items(), outputs.items())
        })

        # Add views whenever reshapes are necessary
        for dname in reshapes:
            desc = nsdfg.arrays[dname]

            # To avoid potential confusion, rename protected __return keyword
            if dname.startswith('__return'):
                newname = f'{nsdfg.name}_ret{dname[8:]}'
            else:
                newname = dname
            newname, _ = sdfg.add_view(newname,
                                       desc.shape,
                                       desc.dtype,
                                       storage=desc.storage,
                                       strides=desc.strides,
                                       offset=desc.offset,
                                       debuginfo=desc.debuginfo,
                                       allow_conflicts=desc.allow_conflicts,
                                       total_size=desc.total_size,
                                       alignment=desc.alignment,
                                       may_alias=desc.may_alias,
                                       find_new_name=True)
            repldict[dname] = newname

        for node in nstate.nodes():
            if isinstance(node, nodes.AccessNode) and node.data in repldict:
                node.data = repldict[node.data]
        for edge in nstate.edges():
            if edge.data.data in repldict:
                edge.data.data = repldict[edge.data.data]

        # Add extra access nodes for out/in view nodes
        for node in nstate.nodes():
            if isinstance(node, nodes.AccessNode) and node.data in reshapes:
                if nstate.in_degree(node) > 0 and nstate.out_degree(node) > 0:
                    # Such a node has to be in the output set
                    edge = outputs[node.data]

                    # Redirect outgoing edges through access node
                    out_edges = list(nstate.out_edges(node))
                    anode = nstate.add_access(edge.data.data)
                    vnode = nstate.add_access(node.data)
                    nstate.add_nedge(node, anode, edge.data)
                    nstate.add_nedge(anode, vnode, edge.data)
                    for e in out_edges:
                        nstate.remove_edge(e)
                        nstate.add_edge(vnode, e.src_conn, e.dst, e.dst_conn,
                                        e.data)

        #######################################################
        # Add nested SDFG into top-level SDFG

        # Add nested nodes into original state (all nodes except the source
        # and sink access nodes that are reconnected below)
        boundary_accesses = source_accesses | sink_accesses
        subgraph = SubgraphView(
            nstate,
            [n for n in nstate.nodes() if n not in boundary_accesses])
        state.add_nodes_from(subgraph.nodes())
        for edge in subgraph.edges():
            state.add_edge(edge.src, edge.src_conn, edge.dst, edge.dst_conn,
                           edge.data)

        #######################################################
        # Reconnect inlined SDFG

        # If a source/sink node is one of the inputs/outputs, reconnect it,
        # replacing memlets in outgoing/incoming paths
        modified_edges = set()
        modified_edges |= self._modify_memlet_path(new_incoming_edges, nstate,
                                                   state, True)
        modified_edges |= self._modify_memlet_path(new_outgoing_edges, nstate,
                                                   state, False)

        # Reshape: add connections to viewed data
        self._modify_reshape_data(reshapes, repldict, inputs, nstate, state,
                                  True)
        self._modify_reshape_data(reshapes, repldict, outputs, nstate, state,
                                  False)

        # Modify all other internal edges pertaining to input/output nodes
        for node in subgraph.nodes():
            if isinstance(node, nodes.AccessNode):
                if node.data in input_set or node.data in output_set:
                    if node.data in input_set:
                        outer_edge = inputs[input_set[node.data]]
                    else:
                        outer_edge = outputs[output_set[node.data]]

                    for edge in state.all_edges(node):
                        if (edge not in modified_edges
                                and edge.data.data == node.data):
                            for e in state.memlet_tree(edge):
                                if e.data.data == node.data:
                                    e._data = helpers.unsqueeze_memlet(
                                        e.data, outer_edge.data)

        # If source/sink node is not connected to a source/destination access
        # node, and the nested SDFG is in a scope, connect to scope with empty
        # memlets
        if nsdfg_scope_entry is not None:
            for node in subgraph.nodes():
                if state.in_degree(node) == 0:
                    state.add_edge(nsdfg_scope_entry, None, node, None,
                                   Memlet())
                if state.out_degree(node) == 0:
                    state.add_edge(node, None, nsdfg_scope_exit, None,
                                   Memlet())

        # Replace nested SDFG parents with new SDFG
        for node in nstate.nodes():
            if isinstance(node, nodes.NestedSDFG):
                node.sdfg.parent = state
                node.sdfg.parent_sdfg = sdfg
                node.sdfg.parent_nsdfg_node = node

        # Remove all unused external inputs/output memlet paths, as well as
        # resulting isolated nodes
        # NOTE(review): source_accesses/sink_accesses contain access nodes,
        # while the keys of inputs/outputs are connector names, so these set
        # differences appear to never drop an identifier - confirm intended.
        removed_in_edges = self._remove_edge_path(state,
                                                  inputs,
                                                  set(inputs.keys()) -
                                                  source_accesses,
                                                  reverse=True)
        removed_out_edges = self._remove_edge_path(state,
                                                   outputs,
                                                   set(outputs.keys()) -
                                                   sink_accesses,
                                                   reverse=False)

        # Re-add in/out edges to first/last nodes in subgraph
        order = [
            x for x in nx.topological_sort(nstate._nx)
            if isinstance(x, nodes.AccessNode)
        ]
        for edge in removed_in_edges:
            # Find first access node that refers to this edge
            node = next(n for n in order if n.data == edge.data.data)
            state.add_edge(edge.src, edge.src_conn, node, edge.dst_conn,
                           edge.data)
        for edge in removed_out_edges:
            # Find last access node that refers to this edge
            node = next(n for n in reversed(order)
                        if n.data == edge.data.data)
            state.add_edge(node, edge.src_conn, edge.dst, edge.dst_conn,
                           edge.data)

        #######################################################
        # Remove nested SDFG node
        state.remove_node(nsdfg_node)

    def _modify_memlet_path(self, new_edges: Dict[nodes.Node,
                                                  MultiConnectorEdge],
                            nstate: SDFGState, state: SDFGState,
                            inputs: bool) -> Set[MultiConnectorEdge]:
        """ Modifies memlet paths in an inlined SDFG. Returns set of modified
            edges.

            :param new_edges: Mapping from a source/sink access node of the
                              nested state to the outer edge it is replaced
                              by.
            :param nstate: The (single) state of the nested SDFG.
            :param state: The parent state that receives the new edges.
            :param inputs: True to process outgoing edges of source nodes,
                           False to process incoming edges of sink nodes.
            :return: Set of edges whose memlets were modified.
        """
        result = set()
        for node, top_edge in new_edges.items():
            inner_edges = (nstate.out_edges(node)
                           if inputs else nstate.in_edges(node))
            for inner_edge in inner_edges:
                new_memlet = helpers.unsqueeze_memlet(inner_edge.data,
                                                      top_edge.data)
                if inputs:
                    new_edge = state.add_edge(top_edge.src, top_edge.src_conn,
                                              inner_edge.dst,
                                              inner_edge.dst_conn, new_memlet)
                else:
                    new_edge = state.add_edge(inner_edge.src,
                                              inner_edge.src_conn,
                                              top_edge.dst, top_edge.dst_conn,
                                              new_memlet)
                mtree = state.memlet_tree(new_edge)

                # Modify all memlets going forward/backward
                def traverse(mtree_node):
                    result.add(mtree_node.edge)
                    mtree_node.edge._data = helpers.unsqueeze_memlet(
                        mtree_node.edge.data, top_edge.data)
                    for child in mtree_node.children:
                        traverse(child)

                for child in mtree.children:
                    traverse(child)

        return result

    def _modify_reshape_data(self, reshapes: Set[str], repldict: Dict[str,
                                                                      str],
                             new_edges: Dict[str, MultiConnectorEdge],
                             nstate: SDFGState, state: SDFGState,
                             inputs: bool):
        """ Connects source/sink access nodes of reshaped data to the outer
            data through the original outer edges.

            :param reshapes: Names (inside the nested SDFG) of data that is
                             reshaped.
            :param repldict: Mapping from nested data names to the names used
                             after replacement.
            :param new_edges: Mapping from connector name to outer edge.
            :param nstate: The (single) state of the nested SDFG.
            :param state: The parent state that receives the new edges.
            :param inputs: True to process source nodes of ``nstate``, False
                           to process sink nodes.
        """
        anodes = nstate.source_nodes() if inputs else nstate.sink_nodes()
        # Reverse lookup: replaced name -> original nested name
        reshp = {repldict[r]: r for r in reshapes}
        for node in anodes:
            if not isinstance(node, nodes.AccessNode):
                continue
            if node.data not in reshp:
                continue
            edge = new_edges[reshp[node.data]]
            if inputs:
                state.add_edge(edge.src, edge.src_conn, node, None, edge.data)
            else:
                state.add_edge(node, None, edge.dst, edge.dst_conn, edge.data)
class InlineSDFG(pattern_matching.Transformation):
    """ Inlines a single-state nested SDFG into a top-level SDFG.

        In particular, the steps taken are:

        1. All transient arrays become transients of the parent
        2. If a source/sink node is one of the inputs/outputs:
          a. Remove it
          b. Reconnect through external edges (map/accessnode)
          c. Replace and reoffset memlets with external data descriptor
        3. If other nodes carry the names of inputs/outputs:
          a. Replace data with external data descriptor
          b. Replace and reoffset memlets with external data descriptor
        4. If source/sink node is not connected to a source/destination, and
           the nested SDFG is in a scope, connect to scope with empty memlets
        5. Remove all unused external inputs/output memlet paths
        6. Remove isolated nodes resulting from previous step
    """

    # Pattern node used to match a single nested SDFG node.
    _nested_sdfg = nodes.NestedSDFG('_', sd.SDFG('_'), set(), set())

    @staticmethod
    def annotates_memlets():
        """ Always returns True. """
        return True

    @staticmethod
    def expressions():
        """ Returns the single pattern of this transformation: a path graph
            that consists of the nested SDFG pattern node only.
        """
        # Matches anything
        return [sdutil.node_path_graph(InlineSDFG._nested_sdfg)]

    @staticmethod
    def _find_edge(state: SDFGState, node: nodes.Node,
                   connector: str) -> MultiConnectorEdge:
        """ Returns the first edge attached to ``connector`` on ``node``.
            Incoming edges are searched before outgoing edges.

            :param state: The state that contains ``node``.
            :param node: The node whose edges are searched.
            :param connector: Name of the connector to look for.
            :return: The first matching edge.
            :raises NameError: If no edge uses the given connector.
        """
        for edge in state.in_edges(node):
            if edge.dst_conn == connector:
                return edge
        for edge in state.out_edges(node):
            if edge.src_conn == connector:
                return edge
        raise NameError('Edge with connector %s not found on node %s' %
                        (connector, node))

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether the matched nested SDFG can be inlined.

            :param graph: The state in which the candidate was matched.
            :param candidate: Mapping from pattern nodes to node indices in
                              ``graph``.
            :param expr_index: Unused.
            :param sdfg: Unused.
            :param strict: If True, additionally rejects nested SDFGs in
                           which a non-transient array has more dimensions
                           than the subset of its connecting memlet.
            :return: True if the transformation can be applied, False
                     otherwise.
            :raises NameError: In strict mode, if a non-transient array of
                               the nested SDFG has no edge on a connector of
                               the same name.
        """
        nested_sdfg = graph.nodes()[candidate[InlineSDFG._nested_sdfg]]
        if len(nested_sdfg.sdfg.nodes()) != 1:
            return False

        # Ensure every connector has one incoming/outgoing edge
        in_connectors = set()
        out_connectors = set()
        for edge in graph.in_edges(nested_sdfg):
            if edge.dst_conn in in_connectors:
                return False
            in_connectors.add(edge.dst_conn)
        for edge in graph.out_edges(nested_sdfg):
            if edge.src_conn in out_connectors:
                return False
            out_connectors.add(edge.src_conn)

        # Ensure output connectors have no additional outputs (if in a scope),
        # and ensure no two connectors are directly connected to each other
        if graph.entry_node(nested_sdfg) is not None:
            all_connectors = in_connectors | out_connectors
            nstate = nested_sdfg.sdfg.node(0)
            for node in nstate.nodes():
                if isinstance(node, nodes.AccessNode):
                    if (node.data in out_connectors
                            and nstate.out_degree(node) > 0
                            and (node.data not in in_connectors
                                 or nstate.in_degree(node) > 0)):
                        return False
                    if (node.data in in_connectors
                            and any(e.dst.data in all_connectors
                                    for e in nstate.out_edges(node)
                                    if isinstance(e.dst, nodes.AccessNode))):
                        return False

        # If some reshaping that cannot be inlined / unsqueezed is happening,
        # do not match transformation in strict mode.
        if strict:
            for aname, array in nested_sdfg.sdfg.arrays.items():
                if array.transient:
                    continue
                edge = InlineSDFG._find_edge(graph, nested_sdfg, aname)
                if len(array.shape) > len(edge.data.subset):
                    return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the label of the graph in which the match was found. """
        return graph.label

    def _remove_edge_path(self,
                          state: SDFGState,
                          edge_map: Dict[str, MultiConnectorEdge],
                          unused: Set[str],
                          reverse: bool = False) -> List[MultiConnectorEdge]:
        """ Remove all edges along a path, until memlet tree contains siblings
            that should not be removed. Removes resulting isolated nodes as
            well. Operates in place.

            :param state: The state in which to remove edges.
            :param edge_map: Mapping from identifier to edge, used as a
                             predicate for removal.
            :param unused: Set of edge identifiers to remove.
            :param reverse: If False, removes forward in path, otherwise
                            backward.
            :return: List of edges from removed nodes at the path's end.
        """

        if reverse:
            edge_func = lambda e: state.out_edges(e.src)
            edge_pred = lambda pedge, e: e.src_conn == pedge.src_conn
        else:
            edge_func = lambda e: state.in_edges(e.dst)
            edge_pred = lambda pedge, e: e.dst_conn == pedge.dst_conn

        result = []

        for identifier, edge in edge_map.items():
            if identifier in unused:
                path = state.memlet_path(edge)
                pedge = None
                for pedge in (reversed(path) if reverse else path):
                    # If there are no other edges, it is safe to remove
                    if len([
                            e for e in edge_func(pedge) if edge_pred(pedge, e)
                    ]) == 1:
                        # Remove connectors as well
                        state.remove_edge_and_connectors(pedge)
                    else:
                        break
                else:  # Reached terminus without breaking, remove external node
                    if pedge is not None:
                        node = pedge.src if reverse else pedge.dst

                        # Keep track of edges on the other end of these nodes,
                        # they will be used to reconnect to first/last
                        # occurrence of access nodes in the inlined subgraph.
                        if reverse:
                            result.extend(state.in_edges(node))
                        else:
                            result.extend(state.out_edges(node))

                        state.remove_node(node)

        return result

    def apply(self, sdfg):
        """ Inlines the matched nested SDFG node into its parent state,
            following the steps listed in the class docstring. Modifies
            ``sdfg`` and the nested SDFG in place.

            :param sdfg: The SDFG that contains the matched state.
        """
        state: SDFGState = sdfg.nodes()[self.state_id]
        nsdfg_node = state.nodes()[self.subgraph[InlineSDFG._nested_sdfg]]
        nsdfg: SDFG = nsdfg_node.sdfg
        nstate: SDFGState = nsdfg.nodes()[0]

        nsdfg_scope_entry = state.entry_node(nsdfg_node)
        nsdfg_scope_exit = (state.exit_node(nsdfg_scope_entry)
                            if nsdfg_scope_entry is not None else None)

        #######################################################
        # Collect and update top-level SDFG metadata

        # Global/init/exit code
        for loc, code in nsdfg.global_code.items():
            sdfg.append_global_code(code.code, loc)
        for loc, code in nsdfg.init_code.items():
            sdfg.append_init_code(code.code, loc)
        for loc, code in nsdfg.exit_code.items():
            sdfg.append_exit_code(code.code, loc)

        # Constants
        for cstname, cstval in nsdfg.constants.items():
            if cstname in sdfg.constants:
                if cstval != sdfg.constants[cstname]:
                    warnings.warn('Constant value mismatch for "%s" while '
                                  'inlining SDFG. Inner = %s != %s = outer' %
                                  (cstname, cstval, sdfg.constants[cstname]))
            else:
                sdfg.add_constant(cstname, cstval)

        # Find original source/destination edges (there is only one edge per
        # connector, according to match)
        inputs: Dict[str, MultiConnectorEdge] = {}
        outputs: Dict[str, MultiConnectorEdge] = {}
        input_set: Dict[str, str] = {}
        output_set: Dict[str, str] = {}
        for e in state.in_edges(nsdfg_node):
            inputs[e.dst_conn] = e
            input_set[e.data.data] = e.dst_conn
        for e in state.out_edges(nsdfg_node):
            outputs[e.src_conn] = e
            output_set[e.data.data] = e.src_conn

        # All transients become transients of the parent (if data already
        # exists, find new name)
        # Mapping from nested transient name to top-level name
        transients: Dict[str, str] = {}
        for node in nstate.nodes():
            if isinstance(node, nodes.AccessNode):
                datadesc = nsdfg.arrays[node.data]
                if node.data not in transients and datadesc.transient:
                    name = sdfg.add_datadesc('%s_%s' %
                                             (nsdfg.label, node.data),
                                             datadesc,
                                             find_new_name=True)
                    transients[node.data] = name

        # All transients of edges between code nodes are also added to parent
        for edge in nstate.edges():
            if (isinstance(edge.src, nodes.CodeNode)
                    and isinstance(edge.dst, nodes.CodeNode)):
                # A memlet without data has no descriptor to look up; skip
                # it instead of indexing the array dictionary with None.
                if edge.data.data is None:
                    continue
                datadesc = nsdfg.arrays[edge.data.data]
                if edge.data.data not in transients and datadesc.transient:
                    name = sdfg.add_datadesc('%s_%s' %
                                             (nsdfg.label, edge.data.data),
                                             datadesc,
                                             find_new_name=True)
                    transients[edge.data.data] = name

        # Collect nodes to add to top-level graph
        new_incoming_edges: Dict[nodes.Node, MultiConnectorEdge] = {}
        new_outgoing_edges: Dict[nodes.Node, MultiConnectorEdge] = {}

        source_accesses = set()
        sink_accesses = set()
        for node in nstate.source_nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.data not in transients):
                new_incoming_edges[node] = inputs[node.data]
                source_accesses.add(node)
        for node in nstate.sink_nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.data not in transients):
                new_outgoing_edges[node] = outputs[node.data]
                sink_accesses.add(node)

        #######################################################
        # Add nested SDFG into top-level SDFG

        # Add nested nodes into original state (all nodes except the source
        # and sink access nodes that are reconnected below)
        boundary_accesses = source_accesses | sink_accesses
        subgraph = SubgraphView(
            nstate,
            [n for n in nstate.nodes() if n not in boundary_accesses])
        state.add_nodes_from(subgraph.nodes())
        for edge in subgraph.edges():
            state.add_edge(edge.src, edge.src_conn, edge.dst, edge.dst_conn,
                           edge.data)

        #######################################################
        # Replace data on inlined SDFG nodes/edges

        # Replace symbols using invocation symbol mapping
        # Two-step replacement (N -> __dacesym_N --> map[N]) to avoid clashes
        for symname, symvalue in nsdfg_node.symbol_mapping.items():
            if str(symname) != str(symvalue):
                nsdfg.replace(symname, '__dacesym_' + symname)
        for symname, symvalue in nsdfg_node.symbol_mapping.items():
            if str(symname) != str(symvalue):
                nsdfg.replace('__dacesym_' + symname, symvalue)

        # Replace data names with their top-level counterparts
        repldict = {}
        repldict.update(transients)
        repldict.update({
            k: v.data.data
            for k, v in itertools.chain(inputs.items(), outputs.items())
        })
        for node in subgraph.nodes():
            if isinstance(node, nodes.AccessNode) and node.data in repldict:
                node.data = repldict[node.data]
        for edge in subgraph.edges():
            if edge.data.data in repldict:
                edge.data.data = repldict[edge.data.data]

        #######################################################
        # Reconnect inlined SDFG

        # If a source/sink node is one of the inputs/outputs, reconnect it,
        # replacing memlets in outgoing/incoming paths
        modified_edges = set()
        modified_edges |= self._modify_memlet_path(new_incoming_edges, nstate,
                                                   state, True)
        modified_edges |= self._modify_memlet_path(new_outgoing_edges, nstate,
                                                   state, False)

        # Modify all other internal edges pertaining to input/output nodes
        for node in subgraph.nodes():
            if isinstance(node, nodes.AccessNode):
                if node.data in input_set or node.data in output_set:
                    if node.data in input_set:
                        outer_edge = inputs[input_set[node.data]]
                    else:
                        outer_edge = outputs[output_set[node.data]]

                    for edge in state.all_edges(node):
                        if (edge not in modified_edges
                                and edge.data.data == node.data):
                            for e in state.memlet_tree(edge):
                                if e.data.data == node.data:
                                    e._data = helpers.unsqueeze_memlet(
                                        e.data, outer_edge.data)

        # If source/sink node is not connected to a source/destination access
        # node, and the nested SDFG is in a scope, connect to scope with empty
        # memlets
        if nsdfg_scope_entry is not None:
            for node in subgraph.nodes():
                if state.in_degree(node) == 0:
                    state.add_edge(nsdfg_scope_entry, None, node, None,
                                   Memlet())
                if state.out_degree(node) == 0:
                    state.add_edge(node, None, nsdfg_scope_exit, None,
                                   Memlet())

        # Replace nested SDFG parents with new SDFG
        for node in nstate.nodes():
            if isinstance(node, nodes.NestedSDFG):
                node.sdfg.parent = state
                node.sdfg.parent_sdfg = sdfg
                node.sdfg.parent_nsdfg_node = node

        # Remove all unused external inputs/output memlet paths, as well as
        # resulting isolated nodes
        # NOTE(review): source_accesses/sink_accesses contain access nodes,
        # while the keys of inputs/outputs are connector names, so these set
        # differences appear to never drop an identifier - confirm intended.
        removed_in_edges = self._remove_edge_path(state,
                                                  inputs,
                                                  set(inputs.keys()) -
                                                  source_accesses,
                                                  reverse=True)
        removed_out_edges = self._remove_edge_path(state,
                                                   outputs,
                                                   set(outputs.keys()) -
                                                   sink_accesses,
                                                   reverse=False)

        # Re-add in/out edges to first/last nodes in subgraph
        order = [
            x for x in nx.topological_sort(nstate._nx)
            if isinstance(x, nodes.AccessNode)
        ]
        for edge in removed_in_edges:
            # Find first access node that refers to this edge
            node = next(n for n in order if n.data == edge.data.data)
            state.add_edge(edge.src, edge.src_conn, node, edge.dst_conn,
                           edge.data)
        for edge in removed_out_edges:
            # Find last access node that refers to this edge
            node = next(n for n in reversed(order)
                        if n.data == edge.data.data)
            state.add_edge(node, edge.src_conn, edge.dst, edge.dst_conn,
                           edge.data)

        #######################################################
        # Remove nested SDFG node
        state.remove_node(nsdfg_node)

    def _modify_memlet_path(self, new_edges: Dict[nodes.Node,
                                                  MultiConnectorEdge],
                            nstate: SDFGState, state: SDFGState,
                            inputs: bool) -> Set[MultiConnectorEdge]:
        """ Modifies memlet paths in an inlined SDFG. Returns set of modified
            edges.

            :param new_edges: Mapping from a source/sink access node of the
                              nested state to the outer edge it is replaced
                              by.
            :param nstate: The (single) state of the nested SDFG.
            :param state: The parent state that receives the new edges.
            :param inputs: True to process outgoing edges of source nodes,
                           False to process incoming edges of sink nodes.
            :return: Set of edges whose memlets were modified.
        """
        result = set()
        for node, top_edge in new_edges.items():
            inner_edges = (nstate.out_edges(node)
                           if inputs else nstate.in_edges(node))
            for inner_edge in inner_edges:
                new_memlet = helpers.unsqueeze_memlet(inner_edge.data,
                                                      top_edge.data)
                if inputs:
                    new_edge = state.add_edge(top_edge.src, top_edge.src_conn,
                                              inner_edge.dst,
                                              inner_edge.dst_conn, new_memlet)
                else:
                    new_edge = state.add_edge(inner_edge.src,
                                              inner_edge.src_conn,
                                              top_edge.dst, top_edge.dst_conn,
                                              new_memlet)
                mtree = state.memlet_tree(new_edge)

                # Modify all memlets going forward/backward
                def traverse(mtree_node):
                    result.add(mtree_node.edge)
                    mtree_node.edge._data = helpers.unsqueeze_memlet(
                        mtree_node.edge.data, top_edge.data)
                    for child in mtree_node.children:
                        traverse(child)

                for child in mtree.children:
                    traverse(child)

        return result
class MapFission(transformation.Transformation):
    """ Implements the MapFission transformation.
        Map fission refers to subsuming a map scope into its internal
        subgraph, essentially replicating the map into maps in all of its
        internal components. This also extends the dimensions of "border"
        transient arrays (i.e., those between the maps), in order to retain
        program semantics after fission.

        There are two cases that match map fission:

        1. A map with an arbitrary subgraph with more than one computational
           (i.e., non-access) node. The use of arrays connecting the
           computational nodes must be limited to the subgraph, and
           non-transient arrays may not be used as "border" arrays.
        2. A map with one internal node that is a nested SDFG, in which
           each state matches the conditions of case (1).

        If a map has nested SDFGs in its subgraph, they are not considered
        in case (1) above, and MapFission must be invoked again on the maps
        with the nested SDFGs in question.
    """
    # Pattern nodes used to build the match expressions below
    _map_entry = nodes.EntryNode()
    _nested_sdfg = nodes.NestedSDFG("", OrderedDiGraph(), {}, {})

    @staticmethod
    def annotates_memlets():
        """ Returns False. NOTE(review): presumably this tells the
            framework that memlets must be propagated after application;
            the consumer of this flag is not visible here. """
        return False

    @staticmethod
    def expressions():
        """ Returns the two pattern graphs of this transformation: a single
            map entry (index 0), and a map entry followed by a nested SDFG
            (index 1). """
        return [
            sdutil.node_path_graph(MapFission._map_entry, ),
            sdutil.node_path_graph(
                MapFission._map_entry,
                MapFission._nested_sdfg,
            )
        ]

    @staticmethod
    def _components(
            subgraph: gr.SubgraphView) -> List[Tuple[nodes.Node, nodes.Node]]:
        """
        Returns the list of non-array components in this subgraph.
        Each element in the list is a 2-tuple of (input node, output node)
        of the component: (entry, exit) for scopes, and (node, node) for
        other code nodes. Only top-level nodes of the subgraph are
        considered.
        """
        graph = (subgraph
                 if isinstance(subgraph, sd.SDFGState) else subgraph.graph)
        schildren = subgraph.scope_children()
        ns = [(n, graph.exit_node(n)) if isinstance(n, nodes.EntryNode) else
              (n, n) for n in schildren[None]
              if isinstance(n, (nodes.CodeNode, nodes.EntryNode))]

        return ns

    @staticmethod
    def _border_arrays(sdfg, parent, subgraph):
        """ Returns a set of array names that are local to the fission
            subgraph. If ``parent`` is an SDFG state, only arrays marked as
            transient in ``sdfg`` are returned; otherwise, the data of every
            top-level access node is returned. """
        nested = isinstance(parent, sd.SDFGState)
        schildren = subgraph.scope_children()
        subset = gr.SubgraphView(parent, schildren[None])
        if nested:
            return set(node.data for node in subset.nodes()
                       if isinstance(node, nodes.AccessNode)
                       and sdfg.arrays[node.data].transient)
        else:
            return set(node.data for node in subset.nodes()
                       if isinstance(node, nodes.AccessNode))

    @staticmethod
    def _internal_border_arrays(total_components, subgraphs):
        """ Returns the set of border arrays that appear between
            computational components (i.e., without sources and sinks):
            arrays that are read by at least one component and written by
            at least one component. """
        inputs = set()
        outputs = set()

        for components, subgraph in zip(total_components, subgraphs):
            for component_in, component_out in components:
                for e in subgraph.in_edges(component_in):
                    if isinstance(e.src, nodes.AccessNode):
                        inputs.add(e.src.data)
                for e in subgraph.out_edges(component_out):
                    if isinstance(e.dst, nodes.AccessNode):
                        outputs.add(e.dst.data)

        return inputs & outputs

    @staticmethod
    def _outside_map(node, scope_dict, entry_nodes):
        """ Returns True iff node is not in any of the scopes spanned by
            entry_nodes. Walks up the scope hierarchy given by
            ``scope_dict`` until the top-level (None) scope is reached. """
        while scope_dict[node] is not None:
            if scope_dict[node] in entry_nodes:
                return False
            node = scope_dict[node]
        return True

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether the matched map can be fissioned.

            :param graph: State in which the pattern was matched.
            :param candidate: Mapping from pattern nodes to node IDs in
                              ``graph``.
            :param expr_index: 0 for a map with a subgraph, 1 for a map
                               that contains a nested SDFG.
            :param sdfg: The SDFG that contains ``graph``.
            :param strict: Unused by this check.
            :return: True if all conditions for fission hold, False
                     otherwise.
        """
        map_node = graph.node(candidate[MapFission._map_entry])
        nsdfg_node = None

        # If the map is dynamic-ranged, the resulting border arrays would be
        # dynamically sized
        if sd.has_dynamic_map_inputs(graph, map_node):
            return False

        if expr_index == 0:  # Map with subgraph
            subgraphs = [
                graph.scope_subgraph(map_node,
                                     include_entry=False,
                                     include_exit=False)
            ]
        else:  # Map with nested SDFG
            nsdfg_node = graph.node(candidate[MapFission._nested_sdfg])
            # Make sure there are no other internal nodes in the map
            if len({e.dst for e in graph.out_edges(map_node)}) > 1:
                return False

            subgraphs = list(nsdfg_node.sdfg.nodes())

        # Test subgraphs
        border_arrays = set()
        total_components = []
        for sg in subgraphs:
            components = MapFission._components(sg)
            snodes = sg.nodes()
            # Test that the subgraphs have more than one computational
            # component
            if expr_index == 0 and len(snodes) > 0 and len(components) <= 1:
                return False

            # Test that the components are connected by transients that are
            # not used anywhere else
            border_arrays |= MapFission._border_arrays(
                nsdfg_node.sdfg if expr_index == 1 else sdfg,
                sg if expr_index == 1 else graph, sg)
            total_components.append(components)

            # In nested SDFGs and subgraphs, ensure none of the border
            # values are non-transients
            for array in border_arrays:
                if expr_index == 0:
                    ndesc = sdfg.arrays[array]
                else:
                    ndesc = nsdfg_node.sdfg.arrays[array]

                if ndesc.transient is False:
                    return False

            # In subgraphs, make sure transients are not used/allocated
            # in other scopes or states
            if expr_index == 0:
                # Find all nodes not in subgraph. The subgraph nodes are put
                # in a set once so that each membership test is O(1) rather
                # than a scan of the node list.
                snode_set = set(snodes)
                not_subgraph = {
                    n.data
                    for n in graph.nodes()
                    if n not in snode_set and isinstance(n, nodes.AccessNode)
                }
                not_subgraph.update(n.data for s in sdfg.nodes() if s != graph
                                    for n in s.nodes()
                                    if isinstance(n, nodes.AccessNode))

                for _, component_out in components:
                    for e in sg.out_edges(component_out):
                        if isinstance(e.dst, nodes.AccessNode):
                            if e.dst.data in not_subgraph:
                                return False

        # Fail if there are arrays inside the map that are not a direct
        # output of a computational component
        # TODO(later): Support this case? Ambiguous array sizes and memlets
        external_arrays = (
            border_arrays -
            MapFission._internal_border_arrays(total_components, subgraphs))
        if external_arrays:
            return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the label of the matched map. """
        map_entry = graph.node(candidate[MapFission._map_entry])
        return map_entry.map.label

    def apply(self, sdfg: sd.SDFG):
        """ Applies map fission to the matched map in-place.

            Every computational component inside the map (or inside each
            state of the nested SDFG, for expression 1) is wrapped in its
            own copy of the outer map, border arrays are extended by the
            map dimensions, memlets are adjusted accordingly, and finally
            the outer map entry/exit nodes are removed from the state.

            :param sdfg: The SDFG that contains the matched state.
            :note: Relies on memlet propagation being run afterwards to
                   correct the outer memlets (see the NOTE comments below).
        """
        graph: sd.SDFGState = sdfg.nodes()[self.state_id]
        map_entry = graph.node(self.subgraph[MapFission._map_entry])
        map_exit = graph.exit_node(map_entry)
        nsdfg_node: Optional[nodes.NestedSDFG] = None

        # Obtain subgraph to perform fission to
        if self.expr_index == 0:  # Map with subgraph
            subgraphs = [(graph,
                          graph.scope_subgraph(map_entry,
                                               include_entry=False,
                                               include_exit=False))]
            parent = sdfg
        else:  # Map with nested SDFG
            nsdfg_node = graph.node(self.subgraph[MapFission._nested_sdfg])
            subgraphs = [(state, state) for state in nsdfg_node.sdfg.nodes()]
            parent = nsdfg_node.sdfg
        # Border arrays already extended (arrays can recur across states)
        modified_arrays = set()

        # Get map information
        outer_map: nodes.Map = map_entry.map
        mapsize = outer_map.range.size()

        # Add new symbols from outer map to nested SDFG
        if self.expr_index == 1:
            map_syms = outer_map.range.free_symbols
            for edge in graph.out_edges(map_entry):
                if edge.data.data:
                    map_syms.update(edge.data.subset.free_symbols)
            for edge in graph.in_edges(map_exit):
                if edge.data.data:
                    map_syms.update(edge.data.subset.free_symbols)
            for sym in map_syms:
                symname = str(sym)
                if symname in outer_map.params:
                    continue
                if symname not in nsdfg_node.symbol_mapping:
                    nsdfg_node.symbol_mapping[symname] = sym
                    nsdfg_node.sdfg.symbols[
                        symname] = graph.symbols_defined_at(
                            nsdfg_node)[symname]

            # Remove map symbols from nested mapping
            for name in outer_map.params:
                if str(name) in nsdfg_node.symbol_mapping:
                    del nsdfg_node.symbol_mapping[str(name)]
                if str(name) in nsdfg_node.sdfg.symbols:
                    del nsdfg_node.sdfg.symbols[str(name)]

        for state, subgraph in subgraphs:
            components = MapFission._components(subgraph)
            sources = subgraph.source_nodes()
            sinks = subgraph.sink_nodes()

            # Collect external edges
            if self.expr_index == 0:
                external_edges_entry = list(state.out_edges(map_entry))
                external_edges_exit = list(state.in_edges(map_exit))
            else:
                external_edges_entry = [
                    e for e in subgraph.edges()
                    if (isinstance(e.src, nodes.AccessNode)
                        and not nsdfg_node.sdfg.arrays[e.src.data].transient)
                ]
                external_edges_exit = [
                    e for e in subgraph.edges()
                    if (isinstance(e.dst, nodes.AccessNode)
                        and not nsdfg_node.sdfg.arrays[e.dst.data].transient)
                ]

            # Map external edges to outer memlets
            edge_to_outer = {}
            for edge in external_edges_entry:
                if self.expr_index == 0:
                    # Subgraphs use the corresponding outer map edges
                    path = state.memlet_path(edge)
                    eindex = path.index(edge)
                    edge_to_outer[edge] = path[eindex - 1]
                else:
                    # Nested SDFGs use the internal map edges of the node
                    outer_edge = next(e for e in graph.in_edges(nsdfg_node)
                                      if e.dst_conn == edge.src.data)
                    edge_to_outer[edge] = outer_edge

            for edge in external_edges_exit:
                if self.expr_index == 0:
                    path = state.memlet_path(edge)
                    eindex = path.index(edge)
                    edge_to_outer[edge] = path[eindex + 1]
                else:
                    # Nested SDFGs use the internal map edges of the node
                    outer_edge = next(e for e in graph.out_edges(nsdfg_node)
                                      if e.src_conn == edge.dst.data)
                    edge_to_outer[edge] = outer_edge

            # Collect all border arrays and code->code edges
            arrays = MapFission._border_arrays(
                nsdfg_node.sdfg if self.expr_index == 1 else sdfg, state,
                subgraph)
            scalars = defaultdict(list)
            for _, component_out in components:
                for e in subgraph.out_edges(component_out):
                    if isinstance(e.dst, nodes.CodeNode):
                        scalars[e.data.data].append(e)

            # Create new arrays for scalars: the original descriptor is
            # replaced by a transient that has the size of the map range
            for scalar, edges in scalars.items():
                desc = parent.arrays[scalar]
                del parent.arrays[scalar]
                name, _ = parent.add_transient(
                    scalar,
                    mapsize,
                    desc.dtype,
                    desc.storage,
                    lifetime=desc.lifetime,
                    debuginfo=desc.debuginfo,
                    allow_conflicts=desc.allow_conflicts,
                    find_new_name=True)

                # Add extra nodes in component boundaries
                for edge in edges:
                    anode = state.add_access(name)
                    sbs = subsets.Range.from_string(','.join(
                        outer_map.params))
                    # Offset memlet by map range begin (to fit the transient)
                    sbs.offset([r[0] for r in outer_map.range], True)
                    state.add_edge(
                        edge.src, edge.src_conn, anode, None,
                        mm.Memlet.simple(
                            name,
                            sbs,
                            num_accesses=outer_map.range.num_elements()))
                    state.add_edge(
                        anode, None, edge.dst, edge.dst_conn,
                        mm.Memlet.simple(
                            name,
                            sbs,
                            num_accesses=outer_map.range.num_elements()))
                    state.remove_edge(edge)

            # Add extra maps around components
            for component_in, component_out in components:
                me, mx = state.add_map(outer_map.label + '_fission',
                                       [(p, '0:1') for p in outer_map.params],
                                       outer_map.schedule,
                                       unroll=outer_map.unroll,
                                       debuginfo=outer_map.debuginfo)

                # Add dynamic input connectors
                for conn in map_entry.in_connectors:
                    if not conn.startswith('IN_'):
                        me.add_in_connector(conn)

                # The placeholder '0:1' range is replaced by a copy of the
                # outer map's range
                me.map.range = dcpy(outer_map.range)

                # Reconnect edges through new map
                for e in state.in_edges(component_in):
                    state.add_edge(me, None, e.dst, e.dst_conn, dcpy(e.data))
                    # Reconnect inner edges at source directly to external
                    # nodes
                    if self.expr_index == 0 and e in external_edges_entry:
                        state.add_edge(edge_to_outer[e].src,
                                       edge_to_outer[e].src_conn, me, None,
                                       dcpy(edge_to_outer[e].data))
                    else:
                        state.add_edge(e.src, e.src_conn, me, None,
                                       dcpy(e.data))
                    state.remove_edge(e)
                # Empty memlet edge in nested SDFGs
                if state.in_degree(component_in) == 0:
                    state.add_edge(me, None, component_in, None, mm.Memlet())

                for e in state.out_edges(component_out):
                    state.add_edge(e.src, e.src_conn, mx, None, dcpy(e.data))
                    # Reconnect inner edges at sink directly to external
                    # nodes
                    if self.expr_index == 0 and e in external_edges_exit:
                        state.add_edge(mx, None, edge_to_outer[e].dst,
                                       edge_to_outer[e].dst_conn,
                                       dcpy(edge_to_outer[e].data))
                    else:
                        state.add_edge(mx, None, e.dst, e.dst_conn,
                                       dcpy(e.data))
                    state.remove_edge(e)
                # Empty memlet edge in nested SDFGs
                if state.out_degree(component_out) == 0:
                    state.add_edge(component_out, None, mx, None, mm.Memlet())

            # Connect other sources/sinks not in components (access nodes)
            # directly to external nodes
            if self.expr_index == 0:
                for node in sources:
                    if isinstance(node, nodes.AccessNode):
                        for edge in state.in_edges(node):
                            outer_edge = edge_to_outer[edge]
                            memlet = dcpy(edge.data)
                            memlet.subset = subsets.Range(
                                outer_map.range.ranges + memlet.subset.ranges)
                            state.add_edge(outer_edge.src,
                                           outer_edge.src_conn, edge.dst,
                                           edge.dst_conn, memlet)

                for node in sinks:
                    if isinstance(node, nodes.AccessNode):
                        for edge in state.out_edges(node):
                            outer_edge = edge_to_outer[edge]
                            state.add_edge(edge.src, edge.src_conn,
                                           outer_edge.dst,
                                           outer_edge.dst_conn,
                                           dcpy(outer_edge.data))

            # Augment arrays by prepending map dimensions
            for array in arrays:
                if array in modified_arrays:
                    continue
                desc = parent.arrays[array]
                for sz in reversed(mapsize):
                    desc.strides = [desc.total_size] + list(desc.strides)
                    desc.total_size = desc.total_size * sz

                desc.shape = mapsize + list(desc.shape)
                desc.offset = [0] * len(mapsize) + list(desc.offset)
                modified_arrays.add(array)

            # Fill scope connectors so that memlets can be tracked below
            state.fill_scope_connectors()

            # Correct connectors and memlets in nested SDFGs to account for
            # missing outside map
            if self.expr_index == 1:
                to_correct = ([(e, e.src) for e in external_edges_entry] +
                              [(e, e.dst) for e in external_edges_exit])
                corrected_nodes = set()
                for edge, node in to_correct:
                    if isinstance(node, nodes.AccessNode):
                        if node in corrected_nodes:
                            continue
                        corrected_nodes.add(node)

                        outer_edge = edge_to_outer[edge]
                        desc = parent.arrays[node.data]

                        # Modify shape of internal array to match outer one
                        outer_desc = sdfg.arrays[outer_edge.data.data]
                        if not isinstance(desc, dt.Scalar):
                            desc.shape = outer_desc.shape
                        if isinstance(desc, dt.Array):
                            desc.strides = outer_desc.strides
                            desc.total_size = outer_desc.total_size

                        # Inside the nested SDFG, offset all memlets to
                        # include the offsets from within the map.
                        # NOTE: Relies on propagation to fix outer memlets
                        for internal_edge in state.all_edges(node):
                            for e in state.memlet_tree(internal_edge):
                                e.data.subset.offset(desc.offset, False)
                                e.data.subset = helpers.unsqueeze_memlet(
                                    e.data, outer_edge.data).subset

                        # Only after offsetting memlets we can modify the
                        # overall offset
                        if isinstance(desc, dt.Array):
                            desc.offset = outer_desc.offset

            # Fill in memlet trees for border transients
            # NOTE: Memlet propagation should run to correct the outer edges
            for node in subgraph.nodes():
                if isinstance(node, nodes.AccessNode) and node.data in arrays:
                    for edge in state.all_edges(node):
                        for e in state.memlet_tree(edge):
                            # Prepend map dimensions to memlet
                            e.data.subset = subsets.Range(
                                [(pystr_to_symbolic(d) - r[0],
                                  pystr_to_symbolic(d) - r[0], 1) for d, r in
                                 zip(outer_map.params, outer_map.range)] +
                                e.data.subset.ranges)

        # If nested SDFG, reconnect nodes around map and modify memlets
        if self.expr_index == 1:
            for edge in graph.in_edges(map_entry):
                if not edge.dst_conn or not edge.dst_conn.startswith('IN_'):
                    continue

                # Modify edge coming into nested SDFG to include entire array
                desc = sdfg.arrays[edge.data.data]
                edge.data.subset = subsets.Range.from_array(desc)
                edge.data.num_accesses = edge.data.subset.num_elements()

                # Find matching edge inside map ('IN_x' pairs with 'OUT_x')
                inner_edge = next(
                    e for e in graph.out_edges(map_entry)
                    if e.src_conn and e.src_conn[4:] == edge.dst_conn[3:])
                graph.add_edge(edge.src, edge.src_conn, nsdfg_node,
                               inner_edge.dst_conn, dcpy(edge.data))

            for edge in graph.out_edges(map_exit):
                # Modify edge coming out of nested SDFG to include entire
                # array
                desc = sdfg.arrays[edge.data.data]
                edge.data.subset = subsets.Range.from_array(desc)

                # Find matching edge inside map ('IN_x' pairs with 'OUT_x').
                # Edges without a destination connector are skipped, in the
                # same way as on the map entry side, instead of failing on
                # slicing None.
                inner_edge = next(
                    e for e in graph.in_edges(map_exit)
                    if e.dst_conn and e.dst_conn[3:] == edge.src_conn[4:])
                graph.add_edge(nsdfg_node, inner_edge.src_conn, edge.dst,
                               edge.dst_conn, dcpy(edge.data))

        # Remove outer map
        graph.remove_nodes_from([map_entry, map_exit])