def apply(self, sdfg: SDFG):
    """Move the nested SDFG's leading interstate edge out to the outer SDFG.

    A new state is inserted before the outer state that contains the nested
    SDFG node. The interstate edge leaving the nested SDFG's start state is
    rewritten to use outer names, and its assignments and condition are
    copied onto the new outer edge (``new_state -> state``). The start
    state is then removed from the nested SDFG, and the destination of
    that edge becomes the nested SDFG's new start state.

    NOTE(review): this variant does not copy the start state's nodes or
    edges into ``new_state`` and only maps input connectors. Presumably the
    matching condition requires that state to be empty -- confirm.

    :param sdfg: The outer SDFG containing the matched nested SDFG node.
    """
    nsdfg: nodes.NestedSDFG = self.nsdfg(sdfg)
    state = sdfg.node(self.state_id)
    new_state = sdfg.add_state_before(state)
    isedge = sdfg.edges_between(new_state, state)[0]

    # Find relevant symbol mapping (inner name -> outer name/expression)
    mapping: Dict[str, str] = {}
    mapping.update({k: str(v) for k, v in nsdfg.symbol_mapping.items()})
    mapping.update({
        k: next(iter(state.in_edges_by_connector(nsdfg, k))).data.data
        for k in nsdfg.in_connectors
    })

    # Use the edge that leaves the nested start state. ``edges()[0]`` is just
    # the first edge stored in the graph and need not originate there.
    source_state = nsdfg.sdfg.start_state
    nisedge = nsdfg.sdfg.out_edges(source_state)[0]

    # Safe (two-step) replacement of edge contents. Every key is first
    # renamed to a fresh intermediate name, so that overlapping mappings
    # (e.g. {M: N, N: M}) do not chain into each other.
    for k in mapping:
        nisedge.data.replace(k, '__dacesym_' + k, replace_keys=False)
    for k, v in mapping.items():
        nisedge.data.replace('__dacesym_' + k, v, replace_keys=False)

    for akey, aval in nisedge.data.assignments.items():
        # Map assignment to outer edge. Prefix the name with the nested SDFG
        # label if it would clash with an outer symbol or data container.
        if akey not in sdfg.symbols and akey not in sdfg.arrays:
            newname = akey
        else:
            newname = nsdfg.label + '_' + akey
        isedge.data.assignments[newname] = aval
        # Add symbol to outer SDFG, using the inner symbol's type
        sdfg.add_symbol(newname, nsdfg.sdfg.symbols[akey])
        # Add symbol mapping to nested SDFG
        nsdfg.symbol_mapping[akey] = newname
    isedge.data.condition = nisedge.data.condition

    # Clean nested SDFG. The removed state was its start state, so the
    # successor along the hoisted edge must be designated explicitly.
    nsdfg.sdfg.remove_node(source_state)
    nsdfg.sdfg.start_state = nsdfg.sdfg.node_id(nisedge.dst)
def apply(self, sdfg: SDFG):
    """Hoist the nested SDFG's start state into the outer SDFG.

    A new state is inserted before the outer state that contains the nested
    SDFG node. The nodes and edges of the nested start state are moved into
    that new state, with access nodes and memlets renamed from inner
    connector/symbol names to their outer counterparts. The interstate edge
    leaving the nested start state is translated the same way, and its
    assignments and condition are copied onto the new outer edge
    (``new_state -> state``). Finally the start state is removed from the
    nested SDFG, and the destination of that edge becomes its new start
    state.

    :param sdfg: The outer SDFG containing the matched nested SDFG node.
    """
    nsdfg: nodes.NestedSDFG = self.nsdfg(sdfg)
    state = sdfg.node(self.state_id)
    new_state = sdfg.add_state_before(state)
    isedge = sdfg.edges_between(new_state, state)[0]

    # Find relevant symbol and data descriptor mapping (inner -> outer name)
    mapping: Dict[str, str] = {}
    mapping.update({k: str(v) for k, v in nsdfg.symbol_mapping.items()})
    mapping.update({
        k: next(iter(state.in_edges_by_connector(nsdfg, k))).data.data
        for k in nsdfg.in_connectors
    })
    mapping.update({
        k: next(iter(state.out_edges_by_connector(nsdfg, k))).data.data
        for k in nsdfg.out_connectors
    })

    # Get internal state and interstate edge
    source_state = nsdfg.sdfg.start_state
    nisedge = nsdfg.sdfg.out_edges(source_state)[0]

    # Add state contents (nodes)
    new_state.add_nodes_from(source_state.nodes())

    # Replace data descriptors and symbols on state graph
    for node in source_state.nodes():
        if isinstance(node, nodes.AccessNode) and node.data in mapping:
            node.data = mapping[node.data]
    for edge in source_state.edges():
        edge.data.replace(mapping)
        if edge.data.data in mapping:
            edge.data.data = mapping[edge.data.data]

    # Add state contents (edges)
    for edge in source_state.edges():
        new_state.add_edge(edge.src, edge.src_conn, edge.dst, edge.dst_conn,
                           edge.data)

    # Safe replacement of edge contents. ``safe_replace`` passes the
    # dictionary for each replacement step to the callback as ``m``, so the
    # callback must apply ``m``. Re-applying the raw ``mapping`` would bypass
    # the two-step scheme: swapped names would collide, and a mapping such as
    # {k: 'k + 1'} would be applied once per step.
    def replfunc(m):
        for k, v in m.items():
            nisedge.data.replace(k, v, replace_keys=False)

    symbolic.safe_replace(mapping, replfunc)

    # Add interstate edge
    for akey, aval in nisedge.data.assignments.items():
        # Map assignment to outer edge. Prefix the name with the nested SDFG
        # label if it would clash with an outer symbol or data container.
        if akey not in sdfg.symbols and akey not in sdfg.arrays:
            newname = akey
        else:
            newname = nsdfg.label + '_' + akey
        isedge.data.assignments[newname] = aval
        # Add symbol to outer SDFG, using the inner symbol's type
        sdfg.add_symbol(newname, nsdfg.sdfg.symbols[akey])
        # Add symbol mapping to nested SDFG
        nsdfg.symbol_mapping[akey] = newname
    isedge.data.condition = nisedge.data.condition

    # Clean nested SDFG
    nsdfg.sdfg.remove_node(source_state)

    # Set new starting state: the removed state was the start state, so the
    # successor along the hoisted edge takes its place.
    nsdfg.sdfg.start_state = nsdfg.sdfg.node_id(nisedge.dst)