def instantiate_loop(
    self,
    sdfg: sd.SDFG,
    loop_states: List[sd.SDFGState],
    loop_subgraph: gr.SubgraphView,
    itervar: str,
    value: symbolic.SymbolicType,
    state_suffix=None,
):
    """
    Creates one copy of the loop body in which the iteration variable is
    replaced by a given value, and adds the copied edges to the SDFG.

    :param sdfg: The SDFG that receives the copied inter-state edges. It is
                 also passed as the deserialization context of the copied
                 states.
    :param loop_states: The states of the loop body to copy.
    :param loop_subgraph: Subgraph view whose edges are replicated between
                          the copied states.
    :param itervar: Name of the loop iteration variable.
    :param value: Value that replaces ``itervar`` in the copied states and
                  in the copied edge conditions.
    :param state_suffix: Optional suffix for the labels of the copied
                         states. If None, the suffix is derived from
                         ``value``.
    :return: The copied states, in the same order as ``loop_states``.
    """
    # Using to/from JSON copies faster than deepcopy (which will also
    # copy the parent SDFG)
    new_states = [
        sd.SDFGState.from_json(s.to_json(), context={'sdfg': sdfg})
        for s in loop_states
    ]

    # Determine the label suffix once. Integer-like values keep the '%d'
    # formatting; values that cannot be converted to an integer (for
    # example, expressions with free symbols) fall back to their string
    # form instead of raising.
    if state_suffix is not None:
        suffix = state_suffix
    else:
        try:
            suffix = '%d' % value
        except (TypeError, ValueError):
            suffix = str(value)

    # Replace iterate with value in each state
    for state in new_states:
        state.set_label(state.label + '_' + itervar + '_' + suffix)
        state.replace(itervar, value)

    # Add subgraph to original SDFG
    for edge in loop_subgraph.edges():
        src = new_states[loop_states.index(edge.src)]
        dst = new_states[loop_states.index(edge.dst)]

        # Replace conditions in subgraph edges
        data: sd.InterstateEdge = copy.deepcopy(edge.data)
        if data.condition:
            ASTFindReplace({itervar: str(value)}).visit(data.condition)

        sdfg.add_edge(src, dst, data)

    return new_states
def state_fission(sdfg: SDFG, subgraph: graph.SubgraphView) -> SDFGState:
    '''
    Given a subgraph, adds a new SDFG state before the state that contains it,
    removes the subgraph from the original state, and connects the two states.

    Nodes of the subgraph that still have edges to nodes outside of the
    subgraph ("boundary" nodes) stay in the original state; the new state
    receives a deep copy of each of them, so that no node object is shared
    between the two states.

    :param sdfg: the SDFG to which the new state is added.
    :param subgraph: the subgraph to remove.
    :return: the newly created SDFG state.
    '''
    # Local import: only needed to duplicate boundary nodes
    import copy

    state: SDFGState = subgraph.graph
    newstate = sdfg.add_state_before(state)

    # Save edges before removing nodes
    orig_edges = subgraph.edges()
    sub_nodes = subgraph.nodes()

    # A node is on the boundary if the state holds more incoming or outgoing
    # edges for it than the subgraph does
    boundary_nodes = set(
        n for n in sub_nodes
        if (len(state.out_edges(n)) > len(subgraph.out_edges(n))
            or len(state.in_edges(n)) > len(subgraph.in_edges(n))))

    # Map every subgraph node to the node object that represents it in the
    # new state: boundary nodes are duplicated, all others are moved
    new_nodes = {
        n: (copy.deepcopy(n) if n in boundary_nodes else n)
        for n in sub_nodes
    }

    state.remove_nodes_from(set(sub_nodes) - boundary_nodes)

    for n in new_nodes.values():
        if isinstance(n, nodes.NestedSDFG):
            # Set the new parent state
            n.sdfg.parent = newstate

    newstate.add_nodes_from(new_nodes.values())

    for e in orig_edges:
        newstate.add_edge(new_nodes[e.src], e.src_conn, new_nodes[e.dst],
                          e.dst_conn, e.data)

    return newstate
def state_fission(sdfg: SDFG, subgraph: graph.SubgraphView) -> SDFGState:
    '''
    Splits a subgraph off into a new state that is inserted before the state
    that currently contains it.

    Subgraph nodes that have more incoming or outgoing edges in the state
    than in the subgraph are kept in the original state, and a deep copy of
    each is placed in the new state. All other subgraph nodes are removed
    from the original state and added to the new one.

    :param sdfg: the SDFG to which the new state is added.
    :param subgraph: the subgraph to remove.
    :return: the newly created SDFG state.
    '''
    state: SDFGState = subgraph.graph
    newstate = sdfg.add_state_before(state)

    # The edge list has to be captured while the nodes are still in the state
    orig_edges = subgraph.edges()
    inner_nodes = subgraph.nodes()

    def _crosses_boundary(n) -> bool:
        # True if the state has edges on this node that the subgraph lacks
        return (len(state.out_edges(n)) > len(subgraph.out_edges(n))
                or len(state.in_edges(n)) > len(subgraph.in_edges(n)))

    boundary_nodes = set(n for n in inner_nodes if _crosses_boundary(n))

    # Node objects to place in the new state, keyed by the original node
    new_nodes = {}
    for n in inner_nodes:
        new_nodes[n] = copy.deepcopy(n) if n in boundary_nodes else n

    # Everything except the boundary leaves the original state
    nodes_to_remove = set(inner_nodes)
    nodes_to_remove -= boundary_nodes
    state.remove_nodes_from(nodes_to_remove)

    # Nested SDFGs that end up in the new state get it as their parent
    for n in new_nodes.values():
        if isinstance(n, nodes.NestedSDFG):
            n.sdfg.parent = newstate

    newstate.add_nodes_from(new_nodes.values())
    for e in orig_edges:
        newstate.add_edge(new_nodes[e.src], e.src_conn, new_nodes[e.dst],
                          e.dst_conn, e.data)

    return newstate
def get_actions(actions, graph, match):
    """
    Registers a match under every node and edge of its matched subgraph.

    Each entry is keyed by a tuple ``(element, class name of the match,
    match.expr_index, version)``, where ``version`` is the smallest
    non-negative integer for which the key is not yet present in
    ``actions``.

    :param actions: Dictionary that is updated in place.
    :param graph: Graph whose node list is indexed by the IDs stored in
                  ``match.subgraph``.
    :param match: The match to register.
    :return: The same ``actions`` dictionary.
    """
    match_name = type(match).__name__

    def _register(element):
        # Probe versions in increasing order until a free key is found
        version = 0
        while (element, match_name, match.expr_index, version) in actions:
            version += 1
        actions[(element, match_name, match.expr_index, version)] = match

    matched_nodes = [graph.nodes()[nid] for nid in match.subgraph.values()]
    for node in matched_nodes:
        _register(node)

    # Edges are registered after the nodes, via the induced subgraph
    for edge in SubgraphView(graph, matched_nodes).edges():
        _register(edge)

    return actions
def apply(self, sdfg: SDFG):
    """
    Inlines the matched nested SDFG node into the state that contains it.

    Only the first state of the nested SDFG (``nsdfg.nodes()[0]``) is read.
    Its nodes and edges are added to the outer state, nested data names are
    replaced by their outer counterparts, memlets are reconnected to the
    original inputs/outputs of the nested SDFG node, and the nested SDFG
    node is removed at the end.

    :param sdfg: The SDFG that contains the state with the nested SDFG
                 node. It is modified in place.
    """
    state: SDFGState = sdfg.nodes()[self.state_id]
    nsdfg_node = state.nodes()[self.subgraph[InlineSDFG._nested_sdfg]]
    nsdfg: SDFG = nsdfg_node.sdfg
    nstate: SDFGState = nsdfg.nodes()[0]

    # A non-default schedule on the nested SDFG node is handed to the
    # schedule/storage inference of the nested SDFG before inlining
    if nsdfg_node.schedule is not dtypes.ScheduleType.Default:
        infer_types.set_default_schedule_and_storage_types(
            nsdfg, nsdfg_node.schedule)

    nsdfg_scope_entry = state.entry_node(nsdfg_node)
    nsdfg_scope_exit = (state.exit_node(nsdfg_scope_entry)
                        if nsdfg_scope_entry is not None else None)

    #######################################################
    # Collect and update top-level SDFG metadata

    # Global/init/exit code
    for loc, code in nsdfg.global_code.items():
        sdfg.append_global_code(code.code, loc)
    for loc, code in nsdfg.init_code.items():
        sdfg.append_init_code(code.code, loc)
    for loc, code in nsdfg.exit_code.items():
        sdfg.append_exit_code(code.code, loc)

    # Constants (an existing outer constant is kept; a differing inner
    # value only produces a warning)
    for cstname, cstval in nsdfg.constants.items():
        if cstname in sdfg.constants:
            if cstval != sdfg.constants[cstname]:
                warnings.warn('Constant value mismatch for "%s" while '
                              'inlining SDFG. Inner = %s != %s = outer' %
                              (cstname, cstval, sdfg.constants[cstname]))
        else:
            sdfg.add_constant(cstname, cstval)

    # Find original source/destination edges (there is only one edge per
    # connector, according to match)
    # inputs/outputs: connector name -> outer edge
    # input_set/output_set: outer data name -> connector name
    inputs: Dict[str, MultiConnectorEdge] = {}
    outputs: Dict[str, MultiConnectorEdge] = {}
    input_set: Dict[str, str] = {}
    output_set: Dict[str, str] = {}
    for e in state.in_edges(nsdfg_node):
        inputs[e.dst_conn] = e
        input_set[e.data.data] = e.dst_conn
    for e in state.out_edges(nsdfg_node):
        outputs[e.src_conn] = e
        output_set[e.data.data] = e.src_conn

    # Access nodes that need to be reshaped: non-transient nested arrays
    # with more dimensions than the subset of the outer memlet, or for which
    # the stride check fails
    reshapes: Set[str] = set()
    for aname, array in nsdfg.arrays.items():
        if array.transient:
            continue
        edge = None
        if aname in inputs:
            edge = inputs[aname]
            if len(array.shape) > len(edge.data.subset):
                reshapes.add(aname)
                continue
        if aname in outputs:
            edge = outputs[aname]
            if len(array.shape) > len(edge.data.subset):
                reshapes.add(aname)
                continue
        if edge is not None and not InlineSDFG._check_strides(
                array.strides, sdfg.arrays[edge.data.data].strides,
                edge.data, nsdfg_node):
            reshapes.add(aname)

    # Replace symbols using invocation symbol mapping
    # Two-step replacement (N -> __dacesym_N --> map[N]) to avoid clashes
    for symname, symvalue in nsdfg_node.symbol_mapping.items():
        if str(symname) != str(symvalue):
            nsdfg.replace(symname, '__dacesym_' + symname)
    for symname, symvalue in nsdfg_node.symbol_mapping.items():
        if str(symname) != str(symvalue):
            nsdfg.replace('__dacesym_' + symname, symvalue)

    # All transients become transients of the parent (if data already
    # exists, find new name)
    # Mapping from nested transient name to top-level name
    transients: Dict[str, str] = {}
    for node in nstate.nodes():
        if isinstance(node, nodes.AccessNode):
            datadesc = nsdfg.arrays[node.data]
            if node.data not in transients and datadesc.transient:
                name = sdfg.add_datadesc('%s_%s' % (nsdfg.label, node.data),
                                         datadesc,
                                         find_new_name=True)
                transients[node.data] = name

    # All transients of edges between code nodes are also added to parent
    for edge in nstate.edges():
        if (isinstance(edge.src, nodes.CodeNode)
                and isinstance(edge.dst, nodes.CodeNode)):
            # Empty memlets carry no data name and are skipped
            if edge.data.data is not None:
                datadesc = nsdfg.arrays[edge.data.data]
                if edge.data.data not in transients and datadesc.transient:
                    name = sdfg.add_datadesc('%s_%s' %
                                             (nsdfg.label, edge.data.data),
                                             datadesc,
                                             find_new_name=True)
                    transients[edge.data.data] = name

    # Collect nodes to add to top-level graph
    # Source/sink access nodes of non-transient, non-reshaped data are not
    # copied; they are mapped to the corresponding outer edge instead
    new_incoming_edges: Dict[nodes.Node, MultiConnectorEdge] = {}
    new_outgoing_edges: Dict[nodes.Node, MultiConnectorEdge] = {}

    source_accesses = set()
    sink_accesses = set()
    for node in nstate.source_nodes():
        if (isinstance(node, nodes.AccessNode)
                and node.data not in transients
                and node.data not in reshapes):
            new_incoming_edges[node] = inputs[node.data]
            source_accesses.add(node)
    for node in nstate.sink_nodes():
        if (isinstance(node, nodes.AccessNode)
                and node.data not in transients
                and node.data not in reshapes):
            new_outgoing_edges[node] = outputs[node.data]
            sink_accesses.add(node)

    #######################################################
    # Replace data on inlined SDFG nodes/edges

    # Replace data names with their top-level counterparts
    repldict = {}
    repldict.update(transients)
    repldict.update({
        k: v.data.data
        for k, v in itertools.chain(inputs.items(), outputs.items())
    })

    # Add views whenever reshapes are necessary
    for dname in reshapes:
        desc = nsdfg.arrays[dname]
        # To avoid potential confusion, rename protected __return keyword
        if dname.startswith('__return'):
            newname = f'{nsdfg.name}_ret{dname[8:]}'
        else:
            newname = dname
        newname, _ = sdfg.add_view(newname,
                                   desc.shape,
                                   desc.dtype,
                                   storage=desc.storage,
                                   strides=desc.strides,
                                   offset=desc.offset,
                                   debuginfo=desc.debuginfo,
                                   allow_conflicts=desc.allow_conflicts,
                                   total_size=desc.total_size,
                                   alignment=desc.alignment,
                                   may_alias=desc.may_alias,
                                   find_new_name=True)
        # Overrides the connector -> outer data mapping for reshaped data
        repldict[dname] = newname

    for node in nstate.nodes():
        if isinstance(node, nodes.AccessNode) and node.data in repldict:
            node.data = repldict[node.data]
    for edge in nstate.edges():
        if edge.data.data in repldict:
            edge.data.data = repldict[edge.data.data]

    # Add extra access nodes for out/in view nodes
    for node in nstate.nodes():
        if isinstance(node, nodes.AccessNode) and node.data in reshapes:
            if nstate.in_degree(node) > 0 and nstate.out_degree(node) > 0:
                # Such a node has to be in the output set
                edge = outputs[node.data]

                # Redirect outgoing edges through access node
                out_edges = list(nstate.out_edges(node))
                anode = nstate.add_access(edge.data.data)
                vnode = nstate.add_access(node.data)
                nstate.add_nedge(node, anode, edge.data)
                nstate.add_nedge(anode, vnode, edge.data)
                for e in out_edges:
                    nstate.remove_edge(e)
                    nstate.add_edge(vnode, e.src_conn, e.dst, e.dst_conn,
                                    e.data)

    #######################################################
    # Add nested SDFG into top-level SDFG

    # Add nested nodes into original state
    subgraph = SubgraphView(nstate, [
        n for n in nstate.nodes()
        if n not in (source_accesses | sink_accesses)
    ])
    state.add_nodes_from(subgraph.nodes())
    for edge in subgraph.edges():
        state.add_edge(edge.src, edge.src_conn, edge.dst, edge.dst_conn,
                       edge.data)

    #######################################################
    # Reconnect inlined SDFG

    # If a source/sink node is one of the inputs/outputs, reconnect it,
    # replacing memlets in outgoing/incoming paths
    modified_edges = set()
    modified_edges |= self._modify_memlet_path(new_incoming_edges, nstate,
                                               state, True)
    modified_edges |= self._modify_memlet_path(new_outgoing_edges, nstate,
                                               state, False)

    # Reshape: add connections to viewed data
    self._modify_reshape_data(reshapes, repldict, inputs, nstate, state,
                              True)
    self._modify_reshape_data(reshapes, repldict, outputs, nstate, state,
                              False)

    # Modify all other internal edges pertaining to input/output nodes
    for node in subgraph.nodes():
        if isinstance(node, nodes.AccessNode):
            if node.data in input_set or node.data in output_set:
                if node.data in input_set:
                    outer_edge = inputs[input_set[node.data]]
                else:
                    outer_edge = outputs[output_set[node.data]]

                for edge in state.all_edges(node):
                    if (edge not in modified_edges
                            and edge.data.data == node.data):
                        for e in state.memlet_tree(edge):
                            if e.data.data == node.data:
                                e._data = helpers.unsqueeze_memlet(
                                    e.data, outer_edge.data)

    # If source/sink node is not connected to a source/destination access
    # node, and the nested SDFG is in a scope, connect to scope with empty
    # memlets
    if nsdfg_scope_entry is not None:
        for node in subgraph.nodes():
            if state.in_degree(node) == 0:
                state.add_edge(nsdfg_scope_entry, None, node, None,
                               Memlet())
            if state.out_degree(node) == 0:
                state.add_edge(node, None, nsdfg_scope_exit, None,
                               Memlet())

    # Replace nested SDFG parents with new SDFG
    for node in nstate.nodes():
        if isinstance(node, nodes.NestedSDFG):
            node.sdfg.parent = state
            node.sdfg.parent_sdfg = sdfg
            node.sdfg.parent_nsdfg_node = node

    # Remove all unused external inputs/output memlet paths, as well as
    # resulting isolated nodes
    # NOTE(review): inputs/outputs are keyed by connector name (str), while
    # source_accesses/sink_accesses hold access nodes, so these set
    # differences look like no-ops -- confirm the intent.
    removed_in_edges = self._remove_edge_path(state,
                                              inputs,
                                              set(inputs.keys()) -
                                              source_accesses,
                                              reverse=True)
    removed_out_edges = self._remove_edge_path(state,
                                               outputs,
                                               set(outputs.keys()) -
                                               sink_accesses,
                                               reverse=False)

    # Re-add in/out edges to first/last nodes in subgraph
    order = [
        x for x in nx.topological_sort(nstate._nx)
        if isinstance(x, nodes.AccessNode)
    ]
    for edge in removed_in_edges:
        # Find first access node that refers to this edge
        node = next(n for n in order if n.data == edge.data.data)
        state.add_edge(edge.src, edge.src_conn, node, edge.dst_conn,
                       edge.data)
    for edge in removed_out_edges:
        # Find last access node that refers to this edge
        node = next(n for n in reversed(order)
                    if n.data == edge.data.data)
        state.add_edge(node, edge.src_conn, edge.dst, edge.dst_conn,
                       edge.data)

    #######################################################
    # Remove nested SDFG node
    state.remove_node(nsdfg_node)
def nest_state_subgraph(sdfg: SDFG,
                        state: SDFGState,
                        subgraph: SubgraphView,
                        name: Optional[str] = None,
                        full_data: bool = False) -> nodes.NestedSDFG:
    """ Turns a state subgraph into a nested SDFG. Operates in-place.

        The subgraph nodes are moved into a single state of a newly created
        SDFG, which is inserted into the original state as a nested SDFG
        node and reconnected to the surrounding graph.

        :param sdfg: The SDFG containing the state subgraph.
        :param state: The state containing the subgraph.
        :param subgraph: Subgraph to nest.
        :param name: An optional name for the nested SDFG. Defaults to
                     ``'nested_'`` followed by the state label.
        :param full_data: If True, nests entire input/output data.
        :return: The nested SDFG node.
        :raise KeyError: Some or all nodes in the subgraph are not located in
                         this state, or the state does not belong to the
                         given SDFG.
        :raise ValueError: The subgraph is contained in more than one scope,
                           or it contains only part of a scope.
    """
    if state.parent != sdfg:
        raise KeyError('State does not belong to given SDFG')
    if subgraph is not state and subgraph.graph is not state:
        raise KeyError('Subgraph does not belong to given state')

    # Find the top-level scope
    scope_tree = state.scope_tree()
    scope_dict = state.scope_dict()
    scope_dict_children = state.scope_children()
    top_scopenode = -1  # Initialized to -1 since "None" already means top-level

    for node in subgraph.nodes():
        if node not in scope_dict:
            raise KeyError('Node not found in state')

        # If scope entry/exit, ensure entire scope is in subgraph
        if isinstance(node, nodes.EntryNode):
            scope_nodes = scope_dict_children[node]
            if any(n not in subgraph.nodes() for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (entry)')
        elif isinstance(node, nodes.ExitNode):
            entry = state.entry_node(node)
            scope_nodes = scope_dict_children[entry] + [entry]
            if any(n not in subgraph.nodes() for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (exit)')

        # A node whose scope node lies outside the subgraph determines the
        # scope that surrounds the subgraph; all such nodes must agree
        scope_node = scope_dict[node]
        if scope_node not in subgraph.nodes():
            if top_scopenode != -1 and top_scopenode != scope_node:
                raise ValueError('Subgraph is contained in more than one scope')
            top_scopenode = scope_node

    scope = scope_tree[top_scopenode]
    ###

    # Consolidate edges in top scope
    utils.consolidate_edges(sdfg, scope)
    snodes = subgraph.nodes()

    # Collect inputs and outputs of the nested SDFG: all state edges that
    # cross the subgraph border
    inputs: List[MultiConnectorEdge] = []
    outputs: List[MultiConnectorEdge] = []
    for node in snodes:
        for edge in state.in_edges(node):
            if edge.src not in snodes:
                inputs.append(edge)
        for edge in state.out_edges(node):
            if edge.dst not in snodes:
                outputs.append(edge)

    # Collect transients not used outside of subgraph (will be removed of
    # top-level graph)
    data_in_subgraph = set(n.data for n in subgraph.nodes()
                           if isinstance(n, nodes.AccessNode))
    # Find other occurrences in SDFG
    other_nodes = set(
        n.data for s in sdfg.nodes() for n in s.nodes()
        if isinstance(n, nodes.AccessNode) and n not in subgraph.nodes())
    subgraph_transients = set()
    for data in data_in_subgraph:
        datadesc = sdfg.arrays[data]
        if datadesc.transient and data not in other_nodes:
            subgraph_transients.add(data)

    # All transients of edges between code nodes are also added to nested graph
    # NOTE(review): edge.data.data is added without a None check; an empty
    # memlet between two code nodes would add None here -- confirm that this
    # cannot occur.
    for edge in subgraph.edges():
        if (isinstance(edge.src, nodes.CodeNode)
                and isinstance(edge.dst, nodes.CodeNode)):
            subgraph_transients.add(edge.data.data)

    # Collect data used in access nodes within subgraph (will be referenced in
    # full upon nesting)
    input_arrays = set()
    # Maps written data name -> wcr of the first incoming edge of the node
    output_arrays = {}
    for node in subgraph.nodes():
        if (isinstance(node, nodes.AccessNode)
                and node.data not in subgraph_transients):
            if node.has_reads(state):
                input_arrays.add(node.data)
            if node.has_writes(state):
                output_arrays[node.data] = state.in_edges(node)[0].data.wcr

    # Create the nested SDFG
    nsdfg = SDFG(name or 'nested_' + state.label)

    # Transients are added to the nested graph as-is
    # (from here on, the ``name`` parameter is reused as a loop variable)
    for name in subgraph_transients:
        nsdfg.add_datadesc(name, sdfg.arrays[name])

    # Input/output data that are not source/sink nodes are added to the graph
    # as non-transients
    for name in (input_arrays | output_arrays.keys()):
        datadesc = copy.deepcopy(sdfg.arrays[name])
        datadesc.transient = False
        nsdfg.add_datadesc(name, datadesc)

    # Connected source/sink nodes outside subgraph become global data
    # descriptors in nested SDFG
    # input_names/output_names: border edge -> data name in nested SDFG
    input_names = {}
    output_names = {}
    # Outer data name -> (name in nested SDFG, subset accessed so far)
    global_subsets: Dict[str, Tuple[str, Subset]] = {}
    for edge in inputs:
        if edge.data.data is None:  # Skip edges with an empty memlet
            continue
        name = edge.data.data
        if name not in global_subsets:
            datadesc = copy.deepcopy(sdfg.arrays[edge.data.data])
            datadesc.transient = False
            if not full_data:
                datadesc.shape = edge.data.subset.size()
            new_name = nsdfg.add_datadesc(name, datadesc, find_new_name=True)
            global_subsets[name] = (new_name, edge.data.subset)
        else:
            new_name, subset = global_subsets[name]
            if not full_data:
                # Grow the nested descriptor to the union of all subsets;
                # fall back to the whole array if no union is available
                new_subset = union(subset, edge.data.subset)
                if new_subset is None:
                    new_subset = Range.from_array(sdfg.arrays[name])
                global_subsets[name] = (new_name, new_subset)
                nsdfg.arrays[new_name].shape = new_subset.size()
        input_names[edge] = new_name
    for edge in outputs:
        if edge.data.data is None:  # Skip edges with an empty memlet
            continue
        name = edge.data.data
        if name not in global_subsets:
            datadesc = copy.deepcopy(sdfg.arrays[edge.data.data])
            datadesc.transient = False
            if not full_data:
                datadesc.shape = edge.data.subset.size()
            new_name = nsdfg.add_datadesc(name, datadesc, find_new_name=True)
            global_subsets[name] = (new_name, edge.data.subset)
        else:
            new_name, subset = global_subsets[name]
            if not full_data:
                new_subset = union(subset, edge.data.subset)
                if new_subset is None:
                    new_subset = Range.from_array(sdfg.arrays[name])
                global_subsets[name] = (new_name, new_subset)
                nsdfg.arrays[new_name].shape = new_subset.size()
        output_names[edge] = new_name
    ###################

    # Add scope symbols to the nested SDFG
    defined_vars = set(
        symbolic.pystr_to_symbolic(s)
        for s in (state.symbols_defined_at(top_scopenode).keys()
                  | sdfg.symbols))
    for v in defined_vars:
        if v in sdfg.symbols:
            sym = sdfg.symbols[v]
            nsdfg.add_symbol(v, sym.dtype)

    # Add constants to nested SDFG
    for cstname, cstval in sdfg.constants.items():
        nsdfg.add_constant(cstname, cstval)

    # Create nested state
    nstate = nsdfg.add_state()

    # Add subgraph nodes and edges to nested state
    nstate.add_nodes_from(subgraph.nodes())
    for e in subgraph.edges():
        nstate.add_edge(e.src, e.src_conn, e.dst, e.dst_conn,
                        copy.deepcopy(e.data))

    # Modify nested SDFG parents in subgraph
    for node in subgraph.nodes():
        if isinstance(node, nodes.NestedSDFG):
            node.sdfg.parent = nstate
            node.sdfg.parent_sdfg = nsdfg
            node.sdfg.parent_nsdfg_node = node

    # Add access nodes and edges as necessary
    # Pairs of (original border edge, replacement edge in nested state)
    edges_to_offset = []
    for edge, name in input_names.items():
        node = nstate.add_read(name)
        new_edge = copy.deepcopy(edge.data)
        new_edge.data = name
        edges_to_offset.append((edge,
                                nstate.add_edge(node, None, edge.dst,
                                                edge.dst_conn, new_edge)))
    for edge, name in output_names.items():
        node = nstate.add_write(name)
        new_edge = copy.deepcopy(edge.data)
        new_edge.data = name
        edges_to_offset.append((edge,
                                nstate.add_edge(edge.src, edge.src_conn, node,
                                                None, new_edge)))

    # Offset memlet paths inside nested SDFG according to subsets
    for original_edge, new_edge in edges_to_offset:
        for edge in nstate.memlet_tree(new_edge):
            edge.data.data = new_edge.data.data
            if not full_data:
                edge.data.subset.offset(
                    global_subsets[original_edge.data.data][1], True)

    # Add nested SDFG node to the input state
    nested_sdfg = state.add_nested_sdfg(
        nsdfg, None,
        set(input_names.values()) | input_arrays,
        set(output_names.values()) | output_arrays.keys())

    # Reconnect memlets to nested SDFG
    # Each connector is connected once; the last empty-memlet border edge
    # on each side is remembered for the fallback further below
    reconnected_in = set()
    reconnected_out = set()
    empty_input = None
    empty_output = None
    for edge in inputs:
        if edge.data.data is None:
            empty_input = edge
            continue

        name = input_names[edge]
        if name in reconnected_in:
            continue
        if full_data:
            data = Memlet.from_array(edge.data.data,
                                     sdfg.arrays[edge.data.data])
        else:
            data = copy.deepcopy(edge.data)
            data.subset = global_subsets[edge.data.data][1]
        state.add_edge(edge.src, edge.src_conn, nested_sdfg, name, data)
        reconnected_in.add(name)

    for edge in outputs:
        if edge.data.data is None:
            empty_output = edge
            continue

        name = output_names[edge]
        if name in reconnected_out:
            continue
        if full_data:
            data = Memlet.from_array(edge.data.data,
                                     sdfg.arrays[edge.data.data])
        else:
            data = copy.deepcopy(edge.data)
            data.subset = global_subsets[edge.data.data][1]
        data.wcr = edge.data.wcr
        state.add_edge(nested_sdfg, name, edge.dst, edge.dst_conn, data)
        reconnected_out.add(name)

    # Connect access nodes to internal input/output data as necessary
    entry = scope.entry
    exit = scope.exit
    for name in input_arrays:
        node = state.add_read(name)
        if entry is not None:
            state.add_nedge(entry, node, Memlet())
        state.add_edge(node, None, nested_sdfg, name,
                       Memlet.from_array(name, sdfg.arrays[name]))
    for name, wcr in output_arrays.items():
        node = state.add_write(name)
        if exit is not None:
            state.add_nedge(node, exit, Memlet())
        state.add_edge(nested_sdfg, name, node, None,
                       Memlet(data=name, wcr=wcr))

    # Graph was not reconnected, but needs to be
    if state.in_degree(nested_sdfg) == 0 and empty_input is not None:
        state.add_edge(empty_input.src, empty_input.src_conn, nested_sdfg,
                       None, empty_input.data)
    if state.out_degree(nested_sdfg) == 0 and empty_output is not None:
        state.add_edge(nested_sdfg, None, empty_output.dst,
                       empty_output.dst_conn, empty_output.data)

    # Remove subgraph nodes from graph
    state.remove_nodes_from(subgraph.nodes())

    # Remove subgraph transients from top-level graph
    for transient in subgraph_transients:
        del sdfg.arrays[transient]

    # Remove newly isolated nodes due to memlet consolidation
    for edge in inputs:
        if state.in_degree(edge.src) + state.out_degree(edge.src) == 0:
            state.remove_node(edge.src)
    for edge in outputs:
        if state.in_degree(edge.dst) + state.out_degree(edge.dst) == 0:
            state.remove_node(edge.dst)

    return nested_sdfg
def apply(self, sdfg):
    """
    Inlines the matched nested SDFG node into the state that contains it.

    Only the first state of the nested SDFG (``nsdfg.nodes()[0]``) is read.
    Its nodes and edges are added to the outer state, nested symbols and
    data names are replaced by their outer counterparts, memlets are
    reconnected to the original inputs/outputs of the nested SDFG node, and
    the nested SDFG node is removed at the end.

    :param sdfg: The SDFG that contains the state with the nested SDFG
                 node. It is modified in place.
    """
    state: SDFGState = sdfg.nodes()[self.state_id]
    nsdfg_node = state.nodes()[self.subgraph[InlineSDFG._nested_sdfg]]
    nsdfg: SDFG = nsdfg_node.sdfg
    nstate: SDFGState = nsdfg.nodes()[0]

    nsdfg_scope_entry = state.entry_node(nsdfg_node)
    nsdfg_scope_exit = (state.exit_node(nsdfg_scope_entry)
                        if nsdfg_scope_entry is not None else None)

    #######################################################
    # Collect and update top-level SDFG metadata

    # Global/init/exit code
    for loc, code in nsdfg.global_code.items():
        sdfg.append_global_code(code.code, loc)
    for loc, code in nsdfg.init_code.items():
        sdfg.append_init_code(code.code, loc)
    for loc, code in nsdfg.exit_code.items():
        sdfg.append_exit_code(code.code, loc)

    # Constants (an existing outer constant is kept; a differing inner
    # value only produces a warning)
    for cstname, cstval in nsdfg.constants.items():
        if cstname in sdfg.constants:
            if cstval != sdfg.constants[cstname]:
                warnings.warn('Constant value mismatch for "%s" while '
                              'inlining SDFG. Inner = %s != %s = outer' %
                              (cstname, cstval, sdfg.constants[cstname]))
        else:
            sdfg.add_constant(cstname, cstval)

    # Find original source/destination edges (there is only one edge per
    # connector, according to match)
    # inputs/outputs: connector name -> outer edge
    # input_set/output_set: outer data name -> connector name
    inputs: Dict[str, MultiConnectorEdge] = {}
    outputs: Dict[str, MultiConnectorEdge] = {}
    input_set: Dict[str, str] = {}
    output_set: Dict[str, str] = {}
    for e in state.in_edges(nsdfg_node):
        inputs[e.dst_conn] = e
        input_set[e.data.data] = e.dst_conn
    for e in state.out_edges(nsdfg_node):
        outputs[e.src_conn] = e
        output_set[e.data.data] = e.src_conn

    # All transients become transients of the parent (if data already
    # exists, find new name)
    # Mapping from nested transient name to top-level name
    transients: Dict[str, str] = {}
    for node in nstate.nodes():
        if isinstance(node, nodes.AccessNode):
            datadesc = nsdfg.arrays[node.data]
            if node.data not in transients and datadesc.transient:
                name = sdfg.add_datadesc('%s_%s' % (nsdfg.label, node.data),
                                         datadesc,
                                         find_new_name=True)
                transients[node.data] = name

    # All transients of edges between code nodes are also added to parent
    for edge in nstate.edges():
        if (isinstance(edge.src, nodes.CodeNode)
                and isinstance(edge.dst, nodes.CodeNode)):
            # An empty memlet carries no data name; looking it up in the
            # array dictionary would raise a KeyError
            if edge.data.data is None:
                continue
            datadesc = nsdfg.arrays[edge.data.data]
            if edge.data.data not in transients and datadesc.transient:
                name = sdfg.add_datadesc('%s_%s' %
                                         (nsdfg.label, edge.data.data),
                                         datadesc,
                                         find_new_name=True)
                transients[edge.data.data] = name

    # Collect nodes to add to top-level graph
    # Source/sink access nodes of non-transient data are not copied; they
    # are mapped to the corresponding outer edge instead
    new_incoming_edges: Dict[nodes.Node, MultiConnectorEdge] = {}
    new_outgoing_edges: Dict[nodes.Node, MultiConnectorEdge] = {}

    source_accesses = set()
    sink_accesses = set()
    for node in nstate.source_nodes():
        if (isinstance(node, nodes.AccessNode)
                and node.data not in transients):
            new_incoming_edges[node] = inputs[node.data]
            source_accesses.add(node)
    for node in nstate.sink_nodes():
        if (isinstance(node, nodes.AccessNode)
                and node.data not in transients):
            new_outgoing_edges[node] = outputs[node.data]
            sink_accesses.add(node)

    #######################################################
    # Add nested SDFG into top-level SDFG

    # Add nested nodes into original state
    subgraph = SubgraphView(nstate, [
        n for n in nstate.nodes()
        if n not in (source_accesses | sink_accesses)
    ])
    state.add_nodes_from(subgraph.nodes())
    for edge in subgraph.edges():
        state.add_edge(edge.src, edge.src_conn, edge.dst, edge.dst_conn,
                       edge.data)

    #######################################################
    # Replace data on inlined SDFG nodes/edges

    # Replace symbols using invocation symbol mapping
    # Two-step replacement (N -> __dacesym_N --> map[N]) to avoid clashes
    for symname, symvalue in nsdfg_node.symbol_mapping.items():
        if str(symname) != str(symvalue):
            nsdfg.replace(symname, '__dacesym_' + symname)
    for symname, symvalue in nsdfg_node.symbol_mapping.items():
        if str(symname) != str(symvalue):
            nsdfg.replace('__dacesym_' + symname, symvalue)

    # Replace data names with their top-level counterparts
    repldict = {}
    repldict.update(transients)
    repldict.update({
        k: v.data.data
        for k, v in itertools.chain(inputs.items(), outputs.items())
    })

    for node in subgraph.nodes():
        if isinstance(node, nodes.AccessNode) and node.data in repldict:
            node.data = repldict[node.data]
    for edge in subgraph.edges():
        if edge.data.data in repldict:
            edge.data.data = repldict[edge.data.data]

    #######################################################
    # Reconnect inlined SDFG

    # If a source/sink node is one of the inputs/outputs, reconnect it,
    # replacing memlets in outgoing/incoming paths
    modified_edges = set()
    modified_edges |= self._modify_memlet_path(new_incoming_edges, nstate,
                                               state, True)
    modified_edges |= self._modify_memlet_path(new_outgoing_edges, nstate,
                                               state, False)

    # Modify all other internal edges pertaining to input/output nodes
    for node in subgraph.nodes():
        if isinstance(node, nodes.AccessNode):
            if node.data in input_set or node.data in output_set:
                if node.data in input_set:
                    outer_edge = inputs[input_set[node.data]]
                else:
                    outer_edge = outputs[output_set[node.data]]

                for edge in state.all_edges(node):
                    if (edge not in modified_edges
                            and edge.data.data == node.data):
                        for e in state.memlet_tree(edge):
                            if e.data.data == node.data:
                                e._data = helpers.unsqueeze_memlet(
                                    e.data, outer_edge.data)

    # If source/sink node is not connected to a source/destination access
    # node, and the nested SDFG is in a scope, connect to scope with empty
    # memlets
    if nsdfg_scope_entry is not None:
        for node in subgraph.nodes():
            if state.in_degree(node) == 0:
                state.add_edge(nsdfg_scope_entry, None, node, None,
                               Memlet())
            if state.out_degree(node) == 0:
                state.add_edge(node, None, nsdfg_scope_exit, None,
                               Memlet())

    # Replace nested SDFG parents with new SDFG
    for node in nstate.nodes():
        if isinstance(node, nodes.NestedSDFG):
            node.sdfg.parent = state
            node.sdfg.parent_sdfg = sdfg
            node.sdfg.parent_nsdfg_node = node

    # Remove all unused external inputs/output memlet paths, as well as
    # resulting isolated nodes
    # NOTE(review): inputs/outputs are keyed by connector name (str), while
    # source_accesses/sink_accesses hold access nodes, so these set
    # differences look like no-ops -- confirm the intent.
    removed_in_edges = self._remove_edge_path(state,
                                              inputs,
                                              set(inputs.keys()) -
                                              source_accesses,
                                              reverse=True)
    removed_out_edges = self._remove_edge_path(state,
                                               outputs,
                                               set(outputs.keys()) -
                                               sink_accesses,
                                               reverse=False)

    # Re-add in/out edges to first/last nodes in subgraph
    order = [
        x for x in nx.topological_sort(nstate._nx)
        if isinstance(x, nodes.AccessNode)
    ]
    for edge in removed_in_edges:
        # Find first access node that refers to this edge
        node = next(n for n in order if n.data == edge.data.data)
        state.add_edge(edge.src, edge.src_conn, node, edge.dst_conn,
                       edge.data)
    for edge in removed_out_edges:
        # Find last access node that refers to this edge
        node = next(n for n in reversed(order)
                    if n.data == edge.data.data)
        state.add_edge(node, edge.src_conn, edge.dst, edge.dst_conn,
                       edge.data)

    #######################################################
    # Remove nested SDFG node
    state.remove_node(nsdfg_node)
def apply(self, sdfg: SDFG):
    """
    Moves the states of the matched subgraph into a new "kernel" SDFG and
    nests that SDFG inside a map with the ``GPU_Persistent`` schedule, which
    is placed in a dedicated launch state of the outer SDFG.

    Guard states are created around the kernel where required, the data
    used by the kernel is either moved into the kernel SDFG (streams and
    register arrays) or passed as an argument, and the launch state is
    connected to the remaining outer states.

    :param sdfg: The SDFG that contains the subgraph. It is modified in
                 place.
    """
    subgraph = self.subgraph_view(sdfg)

    # All generated labels share the same optional prefix
    prefix = self.kernel_prefix + '_' if self.kernel_prefix != '' else ''

    entry_states_in, entry_states_out = self.get_entry_states(
        sdfg, subgraph)
    _, exit_states_out = self.get_exit_states(sdfg, subgraph)

    entry_state_in = entry_states_in.pop()
    entry_state_out = entry_states_out.pop() \
        if len(entry_states_out) > 0 else None
    exit_state_out = exit_states_out.pop() \
        if len(exit_states_out) > 0 else None

    launch_state = None
    entry_guard_state = None
    exit_guard_state = None

    # generate entry guard state if needed
    if self.include_in_assignment and entry_state_out is not None:
        entry_edge = sdfg.edges_between(entry_state_out, entry_state_in)[0]
        if len(entry_edge.data.assignments) > 0:
            # Split the entry edge: the condition stays on the edge into
            # the guard, the assignments move to the edge out of the guard
            entry_guard_state = sdfg.add_state(
                label='{}kernel_entry_guard'.format(prefix))
            sdfg.add_edge(entry_state_out, entry_guard_state,
                          InterstateEdge(entry_edge.data.condition))
            sdfg.add_edge(
                entry_guard_state, entry_state_in,
                InterstateEdge(None, entry_edge.data.assignments))
            sdfg.remove_edge(entry_edge)

            # Update SubgraphView
            new_node_list = subgraph.nodes()
            new_node_list.append(entry_guard_state)
            subgraph = SubgraphView(sdfg, new_node_list)

            launch_state = sdfg.add_state_before(
                entry_guard_state,
                label='{}kernel_launch'.format(prefix))

    # generate exit guard state
    if exit_state_out is not None:
        exit_guard_state = sdfg.add_state_before(
            exit_state_out,
            label='{}kernel_exit_guard'.format(prefix))

        # Update SubgraphView
        new_node_list = subgraph.nodes()
        new_node_list.append(exit_guard_state)
        subgraph = SubgraphView(sdfg, new_node_list)

        if launch_state is None:
            launch_state = sdfg.add_state_before(
                exit_state_out,
                label='{}kernel_launch'.format(prefix))

    # If the launch state doesn't exist at this point then no state follows
    # the kernel and no entry guard was created, so create a stand alone
    # launch state. It is connected to the preceding state (if any) below.
    if launch_state is None:
        # entry_state_in is always a state here (it was popped from the
        # set of entry states), so only the exit side can be checked.
        assert exit_state_out is None
        launch_state = sdfg.add_state(
            label='{}kernel_launch'.format(prefix))

    # create sdfg for kernel and fill it with states and edges from
    # ssubgraph dfg will be nested at the end
    kernel_sdfg = SDFG('{}kernel'.format(prefix))

    edges = subgraph.edges()
    for edge in edges:
        kernel_sdfg.add_edge(edge.src, edge.dst, edge.data)

    # Setting entry node in nested SDFG if no entry guard was created
    if entry_guard_state is None:
        kernel_sdfg.start_state = kernel_sdfg.node_id(entry_state_in)

    for state in subgraph:
        state.parent = kernel_sdfg

    # remove the now nested nodes from the outer sdfg and make sure the
    # launch state is properly connected to remaining states
    sdfg.remove_nodes_from(subgraph.nodes())

    if entry_state_out is not None \
            and len(sdfg.edges_between(entry_state_out, launch_state)) == 0:
        sdfg.add_edge(entry_state_out, launch_state, InterstateEdge())

    if exit_state_out is not None \
            and len(sdfg.edges_between(launch_state, exit_state_out)) == 0:
        sdfg.add_edge(launch_state, exit_state_out, InterstateEdge())

    # Handle data for kernel
    kernel_data = set(node.data for state in kernel_sdfg
                      for node in state.nodes()
                      if isinstance(node, nodes.AccessNode))

    # move Streams and Register data into the nested SDFG
    # normal data will be added as kernel argument
    kernel_args = []
    for data in kernel_data:
        desc = sdfg.arrays[data]
        if (isinstance(desc, dace.data.Stream)
                or (isinstance(desc, dace.data.Array)
                    and desc.storage == StorageType.Register)):
            kernel_sdfg.add_datadesc(data, desc)
            del sdfg.arrays[data]
        else:
            # The outer descriptor stays; the kernel gets a non-transient
            # copy with default storage
            copy_desc = copy.deepcopy(desc)
            copy_desc.transient = False
            copy_desc.storage = StorageType.Default
            kernel_sdfg.add_datadesc(data, copy_desc)
            kernel_args.append(data)

    # read only data will be passed as input, writeable data will be passed
    # as 'output' otherwise kernel cannot write to data
    kernel_args_read = set()
    kernel_args_write = set()
    for data in kernel_args:
        data_accesses_read_only = [
            node.access == dtypes.AccessType.ReadOnly
            for state in kernel_sdfg for node in state
            if isinstance(node, nodes.AccessNode) and node.data == data
        ]
        if all(data_accesses_read_only):
            kernel_args_read.add(data)
        else:
            kernel_args_write.add(data)

    # Kernel SDFG is complete at this point
    if self.validate:
        kernel_sdfg.validate()

    # Filling launch state with nested SDFG, map and access nodes
    map_entry, map_exit = launch_state.add_map(
        '{}kernel_launch_map'.format(prefix),
        dict(ignore='0'),
        schedule=ScheduleType.GPU_Persistent,
    )

    nested_sdfg = launch_state.add_nested_sdfg(
        kernel_sdfg,
        sdfg,
        kernel_args_read,
        kernel_args_write,
    )

    # Create and connect read only data access nodes
    for arg in kernel_args_read:
        read_node = launch_state.add_read(arg)
        launch_state.add_memlet_path(read_node,
                                     map_entry,
                                     nested_sdfg,
                                     dst_conn=arg,
                                     memlet=Memlet.from_array(
                                         arg, sdfg.arrays[arg]))

    # Create and connect writable data access nodes
    for arg in kernel_args_write:
        write_node = launch_state.add_write(arg)
        launch_state.add_memlet_path(nested_sdfg,
                                     map_exit,
                                     write_node,
                                     src_conn=arg,
                                     memlet=Memlet.from_array(
                                         arg, sdfg.arrays[arg]))

    # Transformation is done
    if self.validate:
        sdfg.validate()
def nest_state_subgraph(sdfg: SDFG,
                        state: SDFGState,
                        subgraph: SubgraphView,
                        name: Optional[str] = None,
                        full_data: bool = False) -> nodes.NestedSDFG:
    """ Turns a state subgraph into a nested SDFG. Operates in-place.

        Edges that enter the subgraph's source nodes or leave its sink nodes
        become connectors (``__in_<data>`` / ``__out_<data>``) of the nested
        SDFG node. Such boundary edges that carry an empty memlet are
        skipped: no connector or data descriptor is created for them.

        :param sdfg: The SDFG containing the state subgraph.
        :param state: The state containing the subgraph.
        :param subgraph: Subgraph to nest.
        :param name: An optional name for the nested SDFG.
        :param full_data: If True, nests entire input/output data.
        :return: The nested SDFG node.
        :raise KeyError: Some or all nodes in the subgraph are not located in
                         this state, or the state does not belong to the given
                         SDFG.
        :raise ValueError: The subgraph is contained in more than one scope.
    """
    if state.parent != sdfg:
        raise KeyError('State does not belong to given SDFG')
    if subgraph.graph != state:
        raise KeyError('Subgraph does not belong to given state')

    # Membership in the subgraph is tested many times below (including once
    # per access node of the whole SDFG), so build the lookup set only once.
    subgraph_node_set = set(subgraph.nodes())

    # Find the top-level scope
    scope_tree = state.scope_tree()
    scope_dict = state.scope_dict()
    scope_dict_children = state.scope_dict(True)
    top_scopenode = -1  # Initialized to -1 since "None" already means top-level

    for node in subgraph.nodes():
        if node not in scope_dict:
            raise KeyError('Node not found in state')

        # If scope entry/exit, ensure entire scope is in subgraph
        if isinstance(node, nodes.EntryNode):
            scope_nodes = scope_dict_children[node]
            if any(n not in subgraph_node_set for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (entry)')
        elif isinstance(node, nodes.ExitNode):
            entry = state.entry_node(node)
            scope_nodes = scope_dict_children[entry] + [entry]
            if any(n not in subgraph_node_set for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (exit)')

        # Nodes whose enclosing scope lies outside the subgraph determine the
        # single scope that the nested SDFG node will live in
        scope_node = scope_dict[node]
        if scope_node not in subgraph_node_set:
            if top_scopenode != -1 and top_scopenode != scope_node:
                raise ValueError(
                    'Subgraph is contained in more than one scope')
            top_scopenode = scope_node

    scope = scope_tree[top_scopenode]
    ###

    # Collect inputs and outputs of the nested SDFG. Edges with an empty
    # memlet are filtered out here (rather than skipped later) so that the
    # connector-name lists built below stay index-aligned with these lists.
    inputs: List[MultiConnectorEdge] = [
        e for node in subgraph.source_nodes() for e in state.in_edges(node)
        if e.data.data is not None
    ]
    outputs: List[MultiConnectorEdge] = [
        e for node in subgraph.sink_nodes() for e in state.out_edges(node)
        if e.data.data is not None
    ]

    # Collect transients not used outside of subgraph (will be removed of
    # top-level graph)
    data_in_subgraph = set(n.data for n in subgraph.nodes()
                           if isinstance(n, nodes.AccessNode))
    # Find other occurrences in SDFG
    other_nodes = set(
        n.data for s in sdfg.nodes() for n in s.nodes()
        if isinstance(n, nodes.AccessNode) and n not in subgraph_node_set)
    subgraph_transients = set(
        data for data in data_in_subgraph
        if sdfg.arrays[data].transient and data not in other_nodes)

    # All transients of edges between code nodes are also added to nested
    # graph. Empty memlets reference no data and must not be registered.
    for edge in subgraph.edges():
        if (isinstance(edge.src, nodes.CodeNode)
                and isinstance(edge.dst, nodes.CodeNode)
                and edge.data.data is not None):
            subgraph_transients.add(edge.data.data)

    # Collect data used in access nodes within subgraph (will be referenced in
    # full upon nesting)
    input_arrays = set()
    output_arrays = set()
    for node in subgraph.nodes():
        if (isinstance(node, nodes.AccessNode)
                and node.data not in subgraph_transients):
            if state.out_degree(node) > 0:
                input_arrays.add(node.data)
            if state.in_degree(node) > 0:
                output_arrays.add(node.data)

    # Create the nested SDFG
    nsdfg = SDFG(name or 'nested_' + state.label)

    # Transients are added to the nested graph as-is
    for dname in subgraph_transients:
        nsdfg.add_datadesc(dname, sdfg.arrays[dname])

    # Input/output data that are not source/sink nodes are added to the graph
    # as non-transients
    for dname in (input_arrays | output_arrays):
        datadesc = copy.deepcopy(sdfg.arrays[dname])
        datadesc.transient = False
        nsdfg.add_datadesc(dname, datadesc)

    def _declare_boundary_data(prefix, edge):
        # Registers a non-transient copy of the edge's data descriptor in the
        # nested SDFG and returns the name it was registered under.
        datadesc = copy.deepcopy(sdfg.arrays[edge.data.data])
        datadesc.transient = False
        if not full_data:
            # Only the accessed subset is passed into the nested SDFG
            datadesc.shape = edge.data.subset.size()
        return nsdfg.add_datadesc(prefix + edge.data.data,
                                  datadesc,
                                  find_new_name=True)

    # Connected source/sink nodes outside subgraph become global data
    # descriptors in nested SDFG
    input_names = [_declare_boundary_data('__in_', e) for e in inputs]
    output_names = [_declare_boundary_data('__out_', e) for e in outputs]
    ###################

    # Add scope symbols to the nested SDFG
    for v in scope.defined_vars:
        if v in sdfg.symbols:
            sym = sdfg.symbols[v]
            nsdfg.add_symbol(v, sym.dtype)

    # Create nested state
    nstate = nsdfg.add_state()

    # Add subgraph nodes and edges to nested state
    nstate.add_nodes_from(subgraph.nodes())
    for e in subgraph.edges():
        nstate.add_edge(e.src, e.src_conn, e.dst, e.dst_conn, e.data)

    # Modify nested SDFG parents in subgraph
    for node in subgraph.nodes():
        if isinstance(node, nodes.NestedSDFG):
            node.sdfg.parent = nstate
            node.sdfg.parent_sdfg = nsdfg

    # Add access nodes and edges as necessary
    edges_to_offset = []
    for dname, edge in zip(input_names, inputs):
        node = nstate.add_read(dname)
        new_memlet = copy.deepcopy(edge.data)
        new_memlet.data = dname
        edges_to_offset.append(
            (edge,
             nstate.add_edge(node, None, edge.dst, edge.dst_conn,
                             new_memlet)))
    for dname, edge in zip(output_names, outputs):
        node = nstate.add_write(dname)
        new_memlet = copy.deepcopy(edge.data)
        new_memlet.data = dname
        edges_to_offset.append(
            (edge,
             nstate.add_edge(edge.src, edge.src_conn, node, None,
                             new_memlet)))

    # Offset memlet paths inside nested SDFG according to subsets
    for original_edge, new_edge in edges_to_offset:
        for edge in nstate.memlet_tree(new_edge):
            edge.data.data = new_edge.data.data
            if not full_data:
                edge.data.subset.offset(original_edge.data.subset, True)

    # Add nested SDFG node to the input state
    nested_sdfg = state.add_nested_sdfg(nsdfg, None,
                                        set(input_names) | input_arrays,
                                        set(output_names) | output_arrays)

    # Reconnect memlets to nested SDFG
    for dname, edge in zip(input_names, inputs):
        if full_data:
            data = Memlet.from_array(edge.data.data,
                                     sdfg.arrays[edge.data.data])
        else:
            data = edge.data
        state.add_edge(edge.src, edge.src_conn, nested_sdfg, dname, data)
    for dname, edge in zip(output_names, outputs):
        if full_data:
            data = Memlet.from_array(edge.data.data,
                                     sdfg.arrays[edge.data.data])
        else:
            data = edge.data
        state.add_edge(nested_sdfg, dname, edge.dst, edge.dst_conn, data)

    # Connect access nodes to internal input/output data as necessary.
    # Inside a scope, the new access nodes are tied to the scope entry/exit
    # with empty memlets.
    scope_entry = scope.entry
    scope_exit = scope.exit
    for dname in input_arrays:
        node = state.add_read(dname)
        if scope_entry is not None:
            state.add_nedge(scope_entry, node, EmptyMemlet())
        state.add_edge(node, None, nested_sdfg, dname,
                       Memlet.from_array(dname, sdfg.arrays[dname]))
    for dname in output_arrays:
        node = state.add_write(dname)
        if scope_exit is not None:
            state.add_nedge(node, scope_exit, EmptyMemlet())
        state.add_edge(nested_sdfg, dname, node, None,
                       Memlet.from_array(dname, sdfg.arrays[dname]))

    # Remove subgraph nodes from graph
    state.remove_nodes_from(subgraph.nodes())

    # Remove subgraph transients from top-level graph
    for transient in subgraph_transients:
        del sdfg.arrays[transient]

    return nested_sdfg