def cutout_state(state: SDFGState, *nodes: nd.Node, make_copy: bool = True) -> SDFG:
    """
    Extract part of a state into a new, stand-alone SDFG, e.g. for localized
    testing or optimization.

    The requested nodes are first extended with the access nodes of the data
    containers needed to run the cutout on its own. Every transient container
    that is reported as defined outside the cutout is turned into a
    non-transient (global) container of the new SDFG.

    :param state: The SDFG state in which the subgraph resides.
    :param nodes: The nodes in the subgraph to cut out.
    :param make_copy: If True, nodes and memlets are deep-copied into the
                      cutout. Otherwise the original objects are inserted;
                      note that unused connectors are then removed from those
                      original nodes. Data descriptors are cloned either way.
    :return: A new SDFG with a single state that holds the cutout.
    """
    if make_copy:
        duplicate = copy.deepcopy
    else:

        def duplicate(element):
            return element

    parent_sdfg = state.parent
    view: StateSubgraphView = _extend_subgraph_with_access_nodes(state, StateSubgraphView(state, nodes))
    outside_containers = _containers_defined_outside(parent_sdfg, state, view)

    # New SDFG carrying over the constants, the used symbols and the data
    # containers referenced by the cutout
    cutout = SDFG(f'{state.parent.name}_cutout', parent_sdfg.constants_prop)
    symbol_types = view.defined_symbols()
    for symbol in view.free_symbols:
        cutout.add_symbol(symbol, symbol_types[symbol])

    for access in view.data_nodes():
        container = access.data
        if container not in cutout.arrays:
            descriptor = parent_sdfg.arrays[container].clone()
            # A transient that is defined outside the cutout becomes global
            if container in outside_containers:
                descriptor.transient = False
            cutout.add_datadesc(container, descriptor)

    # The cutout consists of one state holding the extended subgraph
    target_state = cutout.add_state(state.label, is_start_state=True)
    replicas: Dict[nd.Node, nd.Node] = {}

    def replica_of(original):
        # Create the counterpart of a node on first use and reuse it afterwards
        if original not in replicas:
            replicas[original] = duplicate(original)
        return replicas[original]

    for edge in view.edges():
        source = replica_of(edge.src)
        destination = replica_of(edge.dst)
        target_state.add_edge(source, edge.src_conn, destination, edge.dst_conn, duplicate(edge.data))

    # Nodes without any edge have not been inserted by the loop above
    for original in view.nodes():
        if original in replicas:
            continue
        replicas[original] = duplicate(original)
        target_state.add_node(replicas[original])

    # Drop connectors that no edge of the cutout is attached to
    for replica in replicas.values():
        attached = set(e.dst_conn for e in target_state.in_edges(replica))
        for connector in (replica.in_connectors.keys() - attached):
            replica.remove_in_connector(connector)
        attached = set(e.src_conn for e in target_state.out_edges(replica))
        for connector in (replica.out_connectors.keys() - attached):
            replica.remove_out_connector(connector)

    return cutout
def replicate_scope(sdfg: SDFG, state: SDFGState, scope: ScopeSubgraphView) -> ScopeSubgraphView:
    """
    Replicates a scope subgraph view within a state, reconnecting all external
    edges to the same nodes.

    Every node and memlet of the scope is deep-copied. Transient access nodes
    with ``Scope`` allocation lifetime additionally receive a freshly named
    copy of their data descriptor, and all memlets in the memlet trees of
    their edges are renamed accordingly.

    :param sdfg: The SDFG in which the subgraph scope resides.
    :param state: The SDFG state in which the subgraph scope resides.
    :param scope: The scope subgraph to replicate.
    :return: A reconnected replica of the scope.
    """
    exit_node = state.exit_node(scope.entry)

    # Replicate internal graph
    new_nodes = []
    new_entry = None
    new_exit = None
    to_find_new_names: Set[nodes.AccessNode] = set()
    # Original node -> replica; avoids a linear ``list.index`` search (and a
    # fresh ``scope.nodes()`` call) for both endpoints of every edge
    replica_of: Dict[nodes.Node, nodes.Node] = {}
    for node in scope.nodes():
        node_copy = copy.deepcopy(node)
        if node == scope.entry:
            new_entry = node_copy
        elif node == exit_node:
            new_exit = node_copy

        if isinstance(node, nodes.AccessNode):
            desc = node.desc(sdfg)
            if desc.lifetime == dtypes.AllocationLifetime.Scope and desc.transient:
                to_find_new_names.add(node_copy)

        state.add_node(node_copy)
        new_nodes.append(node_copy)
        replica_of[node] = node_copy

    for edge in scope.edges():
        state.add_edge(replica_of[edge.src], edge.src_conn, replica_of[edge.dst], edge.dst_conn,
                       copy.deepcopy(edge.data))

    # Reconnect external scope nodes
    for edge in state.in_edges(scope.entry):
        state.add_edge(edge.src, edge.src_conn, new_entry, edge.dst_conn, copy.deepcopy(edge.data))
    for edge in state.out_edges(exit_node):
        state.add_edge(new_exit, edge.src_conn, edge.dst, edge.dst_conn, copy.deepcopy(edge.data))

    # Entry and exit were deep-copied independently, so make the exit node's
    # map the same object as the entry node's map again
    new_exit.map = new_entry.map

    # Replicate all temporary transients within scope
    for node in to_find_new_names:
        desc = node.desc(sdfg)
        new_name = sdfg.add_datadesc(node.data, copy.deepcopy(desc), find_new_name=True)
        node.data = new_name
        for edge in state.all_edges(node):
            for e in state.memlet_tree(edge):
                e.data.data = new_name

    return ScopeSubgraphView(state, new_nodes, new_entry)
def apply(self, sdfg: SDFG):
    """
    Inlines the matched nested SDFG node into the state that contains it.

    The first state of the nested SDFG is merged into the outer state:
    global/init/exit code and constants are appended to ``sdfg``, nested
    transients are re-registered in ``sdfg`` under new names, nested data
    names are replaced by their outer counterparts (through views where a
    reshape is needed), memlets are reconnected, and finally the nested SDFG
    node is removed from the state.

    :param sdfg: The SDFG that contains the state with the nested SDFG node.
    """
    state: SDFGState = sdfg.nodes()[self.state_id]
    nsdfg_node = state.nodes()[self.subgraph[InlineSDFG._nested_sdfg]]
    nsdfg: SDFG = nsdfg_node.sdfg
    # Only the first state of the nested SDFG is inlined
    nstate: SDFGState = nsdfg.nodes()[0]

    if nsdfg_node.schedule is not dtypes.ScheduleType.Default:
        infer_types.set_default_schedule_and_storage_types(
            nsdfg, nsdfg_node.schedule)

    # Scope surrounding the nested SDFG node (None if it is at top level)
    nsdfg_scope_entry = state.entry_node(nsdfg_node)
    nsdfg_scope_exit = (state.exit_node(nsdfg_scope_entry)
                        if nsdfg_scope_entry is not None else None)

    #######################################################
    # Collect and update top-level SDFG metadata

    # Global/init/exit code
    for loc, code in nsdfg.global_code.items():
        sdfg.append_global_code(code.code, loc)
    for loc, code in nsdfg.init_code.items():
        sdfg.append_init_code(code.code, loc)
    for loc, code in nsdfg.exit_code.items():
        sdfg.append_exit_code(code.code, loc)

    # Constants: new ones are added to the outer SDFG, while a clash with a
    # different value only produces a warning (the outer value is kept)
    for cstname, cstval in nsdfg.constants.items():
        if cstname in sdfg.constants:
            if cstval != sdfg.constants[cstname]:
                warnings.warn('Constant value mismatch for "%s" while '
                              'inlining SDFG. Inner = %s != %s = outer' %
                              (cstname, cstval, sdfg.constants[cstname]))
        else:
            sdfg.add_constant(cstname, cstval)

    # Find original source/destination edges (there is only one edge per
    # connector, according to match)
    # inputs/outputs: connector name -> outer edge
    # input_set/output_set: outer data name -> connector name
    inputs: Dict[str, MultiConnectorEdge] = {}
    outputs: Dict[str, MultiConnectorEdge] = {}
    input_set: Dict[str, str] = {}
    output_set: Dict[str, str] = {}
    for e in state.in_edges(nsdfg_node):
        inputs[e.dst_conn] = e
        input_set[e.data.data] = e.dst_conn
    for e in state.out_edges(nsdfg_node):
        outputs[e.src_conn] = e
        output_set[e.data.data] = e.src_conn

    # Access nodes that need to be reshaped: non-transient nested arrays
    # with more dimensions than the subset of the connected outer memlet,
    # or for which the stride check fails
    reshapes: Set[str] = set()
    for aname, array in nsdfg.arrays.items():
        if array.transient:
            continue
        edge = None
        if aname in inputs:
            edge = inputs[aname]
            if len(array.shape) > len(edge.data.subset):
                reshapes.add(aname)
                continue
        if aname in outputs:
            edge = outputs[aname]
            if len(array.shape) > len(edge.data.subset):
                reshapes.add(aname)
                continue
        if edge is not None and not InlineSDFG._check_strides(
                array.strides, sdfg.arrays[edge.data.data].strides,
                edge.data, nsdfg_node):
            reshapes.add(aname)

    # Replace symbols using invocation symbol mapping
    # Two-step replacement (N -> __dacesym_N --> map[N]) to avoid clashes
    for symname, symvalue in nsdfg_node.symbol_mapping.items():
        if str(symname) != str(symvalue):
            nsdfg.replace(symname, '__dacesym_' + symname)
    for symname, symvalue in nsdfg_node.symbol_mapping.items():
        if str(symname) != str(symvalue):
            nsdfg.replace('__dacesym_' + symname, symvalue)

    # All transients become transients of the parent (if data already
    # exists, find new name)
    # Mapping from nested transient name to top-level name
    transients: Dict[str, str] = {}
    for node in nstate.nodes():
        if isinstance(node, nodes.AccessNode):
            datadesc = nsdfg.arrays[node.data]
            if node.data not in transients and datadesc.transient:
                name = sdfg.add_datadesc('%s_%s' % (nsdfg.label, node.data),
                                         datadesc,
                                         find_new_name=True)
                transients[node.data] = name

    # All transients of edges between code nodes are also added to parent
    # (empty memlets carry no data name and are skipped)
    for edge in nstate.edges():
        if (isinstance(edge.src, nodes.CodeNode)
                and isinstance(edge.dst, nodes.CodeNode)):
            if edge.data.data is not None:
                datadesc = nsdfg.arrays[edge.data.data]
                if edge.data.data not in transients and datadesc.transient:
                    name = sdfg.add_datadesc('%s_%s' %
                                             (nsdfg.label, edge.data.data),
                                             datadesc,
                                             find_new_name=True)
                    transients[edge.data.data] = name

    # Collect nodes to add to top-level graph
    # Source/sink access node of the nested state -> outer edge that will
    # take its place
    new_incoming_edges: Dict[nodes.Node, MultiConnectorEdge] = {}
    new_outgoing_edges: Dict[nodes.Node, MultiConnectorEdge] = {}

    source_accesses = set()
    sink_accesses = set()
    for node in nstate.source_nodes():
        if (isinstance(node, nodes.AccessNode)
                and node.data not in transients
                and node.data not in reshapes):
            new_incoming_edges[node] = inputs[node.data]
            source_accesses.add(node)
    for node in nstate.sink_nodes():
        if (isinstance(node, nodes.AccessNode)
                and node.data not in transients
                and node.data not in reshapes):
            new_outgoing_edges[node] = outputs[node.data]
            sink_accesses.add(node)

    #######################################################
    # Replace data on inlined SDFG nodes/edges

    # Replace data names with their top-level counterparts
    repldict = {}
    repldict.update(transients)
    repldict.update({
        k: v.data.data
        for k, v in itertools.chain(inputs.items(), outputs.items())
    })

    # Add views whenever reshapes are necessary
    for dname in reshapes:
        desc = nsdfg.arrays[dname]
        # To avoid potential confusion, rename protected __return keyword
        if dname.startswith('__return'):
            newname = f'{nsdfg.name}_ret{dname[8:]}'
        else:
            newname = dname
        newname, _ = sdfg.add_view(newname,
                                   desc.shape,
                                   desc.dtype,
                                   storage=desc.storage,
                                   strides=desc.strides,
                                   offset=desc.offset,
                                   debuginfo=desc.debuginfo,
                                   allow_conflicts=desc.allow_conflicts,
                                   total_size=desc.total_size,
                                   alignment=desc.alignment,
                                   may_alias=desc.may_alias,
                                   find_new_name=True)
        # Overrides the input/output mapping set above for reshaped data
        repldict[dname] = newname

    for node in nstate.nodes():
        if isinstance(node, nodes.AccessNode) and node.data in repldict:
            node.data = repldict[node.data]
    for edge in nstate.edges():
        if edge.data.data in repldict:
            edge.data.data = repldict[edge.data.data]

    # Add extra access nodes for out/in view nodes
    for node in nstate.nodes():
        if isinstance(node, nodes.AccessNode) and node.data in reshapes:
            if nstate.in_degree(node) > 0 and nstate.out_degree(node) > 0:
                # Such a node has to be in the output set
                edge = outputs[node.data]

                # Redirect outgoing edges through access node
                # (list copy, since edges are removed while iterating)
                out_edges = list(nstate.out_edges(node))
                anode = nstate.add_access(edge.data.data)
                vnode = nstate.add_access(node.data)
                nstate.add_nedge(node, anode, edge.data)
                nstate.add_nedge(anode, vnode, edge.data)
                for e in out_edges:
                    nstate.remove_edge(e)
                    nstate.add_edge(vnode, e.src_conn, e.dst, e.dst_conn,
                                    e.data)

    #######################################################
    # Add nested SDFG into top-level SDFG

    # Add nested nodes into original state (all nodes except the
    # source/sink access nodes that are replaced by outer edges)
    subgraph = SubgraphView(nstate, [
        n for n in nstate.nodes()
        if n not in (source_accesses | sink_accesses)
    ])
    state.add_nodes_from(subgraph.nodes())
    for edge in subgraph.edges():
        state.add_edge(edge.src, edge.src_conn, edge.dst, edge.dst_conn,
                       edge.data)

    #######################################################
    # Reconnect inlined SDFG

    # If a source/sink node is one of the inputs/outputs, reconnect it,
    # replacing memlets in outgoing/incoming paths
    modified_edges = set()
    modified_edges |= self._modify_memlet_path(new_incoming_edges, nstate,
                                               state, True)
    modified_edges |= self._modify_memlet_path(new_outgoing_edges, nstate,
                                               state, False)

    # Reshape: add connections to viewed data
    self._modify_reshape_data(reshapes, repldict, inputs, nstate, state,
                              True)
    self._modify_reshape_data(reshapes, repldict, outputs, nstate, state,
                              False)

    # Modify all other internal edges pertaining to input/output nodes
    for node in subgraph.nodes():
        if isinstance(node, nodes.AccessNode):
            if node.data in input_set or node.data in output_set:
                if node.data in input_set:
                    outer_edge = inputs[input_set[node.data]]
                else:
                    outer_edge = outputs[output_set[node.data]]

                for edge in state.all_edges(node):
                    if (edge not in modified_edges
                            and edge.data.data == node.data):
                        for e in state.memlet_tree(edge):
                            if e.data.data == node.data:
                                # Replaces the memlet object stored on the
                                # edge through its private attribute
                                e._data = helpers.unsqueeze_memlet(
                                    e.data, outer_edge.data)

    # If source/sink node is not connected to a source/destination access
    # node, and the nested SDFG is in a scope, connect to scope with empty
    # memlets
    if nsdfg_scope_entry is not None:
        for node in subgraph.nodes():
            if state.in_degree(node) == 0:
                state.add_edge(nsdfg_scope_entry, None, node, None,
                               Memlet())
            if state.out_degree(node) == 0:
                state.add_edge(node, None, nsdfg_scope_exit, None,
                               Memlet())

    # Replace nested SDFG parents with new SDFG
    for node in nstate.nodes():
        if isinstance(node, nodes.NestedSDFG):
            node.sdfg.parent = state
            node.sdfg.parent_sdfg = sdfg
            node.sdfg.parent_nsdfg_node = node

    # Remove all unused external inputs/output memlet paths, as well as
    # resulting isolated nodes
    # NOTE(review): ``inputs``/``outputs`` are keyed by connector name while
    # ``source_accesses``/``sink_accesses`` hold access nodes, so these set
    # differences presumably never exclude anything -- confirm whether data
    # names were intended before changing this.
    removed_in_edges = self._remove_edge_path(state,
                                              inputs,
                                              set(inputs.keys()) -
                                              source_accesses,
                                              reverse=True)
    removed_out_edges = self._remove_edge_path(state,
                                               outputs,
                                               set(outputs.keys()) -
                                               sink_accesses,
                                               reverse=False)

    # Re-add in/out edges to first/last nodes in subgraph
    # (access nodes of the nested state in topological order)
    order = [
        x for x in nx.topological_sort(nstate._nx)
        if isinstance(x, nodes.AccessNode)
    ]
    for edge in removed_in_edges:
        # Find first access node that refers to this edge
        node = next(n for n in order if n.data == edge.data.data)
        state.add_edge(edge.src, edge.src_conn, node, edge.dst_conn,
                       edge.data)
    for edge in removed_out_edges:
        # Find last access node that refers to this edge
        node = next(n for n in reversed(order)
                    if n.data == edge.data.data)
        state.add_edge(node, edge.src_conn, edge.dst, edge.dst_conn,
                       edge.data)

    #######################################################
    # Remove nested SDFG node
    state.remove_node(nsdfg_node)
def nest_state_subgraph(sdfg: SDFG,
                        state: SDFGState,
                        subgraph: SubgraphView,
                        name: Optional[str] = None,
                        full_data: bool = False) -> nodes.NestedSDFG:
    """
    Turns a state subgraph into a nested SDFG. Operates in-place.

    Edges that cross the subgraph boundary become connectors of the new
    nested SDFG node. Transients that are only used inside the subgraph are
    moved from ``sdfg`` into the nested SDFG; other data accessed by access
    nodes of the subgraph is passed into the nested SDFG in full.

    :param sdfg: The SDFG containing the state subgraph.
    :param state: The state containing the subgraph.
    :param subgraph: Subgraph to nest.
    :param name: An optional name for the nested SDFG. Defaults to
                 ``'nested_'`` followed by the state label.
    :param full_data: If True, nests entire input/output data. Otherwise,
                      boundary data is shrunk to the union of the accessed
                      subsets and inner memlets are offset accordingly.
    :return: The nested SDFG node.
    :raise KeyError: Some or all nodes in the subgraph are not located in
                     this state, or the state does not belong to the given
                     SDFG.
    :raise ValueError: The subgraph is contained in more than one scope, or
                       contains only part of a scope.
    """
    if state.parent != sdfg:
        raise KeyError('State does not belong to given SDFG')
    if subgraph is not state and subgraph.graph is not state:
        raise KeyError('Subgraph does not belong to given state')

    # Find the top-level scope
    scope_tree = state.scope_tree()
    scope_dict = state.scope_dict()
    scope_dict_children = state.scope_children()
    top_scopenode = -1  # Initialized to -1 since "None" already means top-level

    for node in subgraph.nodes():
        if node not in scope_dict:
            raise KeyError('Node not found in state')

        # If scope entry/exit, ensure entire scope is in subgraph
        if isinstance(node, nodes.EntryNode):
            scope_nodes = scope_dict_children[node]
            if any(n not in subgraph.nodes() for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (entry)')
        elif isinstance(node, nodes.ExitNode):
            entry = state.entry_node(node)
            scope_nodes = scope_dict_children[entry] + [entry]
            if any(n not in subgraph.nodes() for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (exit)')

        scope_node = scope_dict[node]
        if scope_node not in subgraph.nodes():
            if top_scopenode != -1 and top_scopenode != scope_node:
                raise ValueError('Subgraph is contained in more than one scope')
            top_scopenode = scope_node

    scope = scope_tree[top_scopenode]
    ###

    # Consolidate edges in top scope
    utils.consolidate_edges(sdfg, scope)
    snodes = subgraph.nodes()

    # Collect inputs and outputs of the nested SDFG (edges that cross the
    # subgraph boundary)
    inputs: List[MultiConnectorEdge] = []
    outputs: List[MultiConnectorEdge] = []
    for node in snodes:
        for edge in state.in_edges(node):
            if edge.src not in snodes:
                inputs.append(edge)
        for edge in state.out_edges(node):
            if edge.dst not in snodes:
                outputs.append(edge)

    # Collect transients not used outside of subgraph (will be removed of
    # top-level graph)
    data_in_subgraph = set(n.data for n in subgraph.nodes() if isinstance(n, nodes.AccessNode))
    # Find other occurrences in SDFG
    other_nodes = set(n.data for s in sdfg.nodes() for n in s.nodes()
                      if isinstance(n, nodes.AccessNode) and n not in subgraph.nodes())
    subgraph_transients = set()
    for dname in data_in_subgraph:
        datadesc = sdfg.arrays[dname]
        if datadesc.transient and dname not in other_nodes:
            subgraph_transients.add(dname)

    # All transients of edges between code nodes are also added to nested graph
    for edge in subgraph.edges():
        if (isinstance(edge.src, nodes.CodeNode) and isinstance(edge.dst, nodes.CodeNode)):
            # Empty memlets carry no data name; adding None here would make
            # the descriptor lookups below fail with a KeyError
            if edge.data.data is not None:
                subgraph_transients.add(edge.data.data)

    # Collect data used in access nodes within subgraph (will be referenced in
    # full upon nesting). Outputs remember the write-conflict resolution of
    # the first incoming edge of the access node.
    input_arrays = set()
    output_arrays = {}
    for node in subgraph.nodes():
        if (isinstance(node, nodes.AccessNode) and node.data not in subgraph_transients):
            if node.has_reads(state):
                input_arrays.add(node.data)
            if node.has_writes(state):
                output_arrays[node.data] = state.in_edges(node)[0].data.wcr

    # Create the nested SDFG
    nsdfg = SDFG(name or 'nested_' + state.label)

    # Transients are added to the nested graph as-is
    for tname in subgraph_transients:
        nsdfg.add_datadesc(tname, sdfg.arrays[tname])

    # Input/output data that are not source/sink nodes are added to the graph
    # as non-transients
    for dname in (input_arrays | output_arrays.keys()):
        datadesc = copy.deepcopy(sdfg.arrays[dname])
        datadesc.transient = False
        nsdfg.add_datadesc(dname, datadesc)

    # Connected source/sink nodes outside subgraph become global data
    # descriptors in nested SDFG
    input_names = {}
    output_names = {}
    # Outer data name -> (name inside nested SDFG, union of accessed subsets)
    global_subsets: Dict[str, Tuple[str, Subset]] = {}
    # Inputs are processed before outputs, so data that is both read and
    # written across the boundary shares one inner descriptor
    for boundary_edges, inner_names in ((inputs, input_names), (outputs, output_names)):
        for edge in boundary_edges:
            dname = edge.data.data
            if dname is None:  # Skip edges with an empty memlet
                continue
            if dname not in global_subsets:
                datadesc = copy.deepcopy(sdfg.arrays[dname])
                datadesc.transient = False
                if not full_data:
                    datadesc.shape = edge.data.subset.size()
                new_name = nsdfg.add_datadesc(dname, datadesc, find_new_name=True)
                global_subsets[dname] = (new_name, edge.data.subset)
            else:
                new_name, subset = global_subsets[dname]
                if not full_data:
                    new_subset = union(subset, edge.data.subset)
                    if new_subset is None:
                        # No union could be computed: use the entire array
                        new_subset = Range.from_array(sdfg.arrays[dname])
                    global_subsets[dname] = (new_name, new_subset)
                    nsdfg.arrays[new_name].shape = new_subset.size()
            inner_names[edge] = new_name

    ###################

    # Add scope symbols to the nested SDFG
    defined_vars = set(
        symbolic.pystr_to_symbolic(s) for s in (state.symbols_defined_at(top_scopenode).keys() | sdfg.symbols))
    for v in defined_vars:
        if v in sdfg.symbols:
            sym = sdfg.symbols[v]
            nsdfg.add_symbol(v, sym.dtype)

    # Add constants to nested SDFG
    for cstname, cstval in sdfg.constants.items():
        nsdfg.add_constant(cstname, cstval)

    # Create nested state
    nstate = nsdfg.add_state()

    # Add subgraph nodes and edges to nested state
    nstate.add_nodes_from(subgraph.nodes())
    for e in subgraph.edges():
        nstate.add_edge(e.src, e.src_conn, e.dst, e.dst_conn, copy.deepcopy(e.data))

    # Modify nested SDFG parents in subgraph
    for node in subgraph.nodes():
        if isinstance(node, nodes.NestedSDFG):
            node.sdfg.parent = nstate
            node.sdfg.parent_sdfg = nsdfg
            node.sdfg.parent_nsdfg_node = node

    # Add access nodes and edges as necessary
    edges_to_offset = []
    for edge, inner_name in input_names.items():
        node = nstate.add_read(inner_name)
        new_edge = copy.deepcopy(edge.data)
        new_edge.data = inner_name
        edges_to_offset.append((edge, nstate.add_edge(node, None, edge.dst, edge.dst_conn, new_edge)))
    for edge, inner_name in output_names.items():
        node = nstate.add_write(inner_name)
        new_edge = copy.deepcopy(edge.data)
        new_edge.data = inner_name
        edges_to_offset.append((edge, nstate.add_edge(edge.src, edge.src_conn, node, None, new_edge)))

    # Offset memlet paths inside nested SDFG according to subsets
    for original_edge, new_edge in edges_to_offset:
        for edge in nstate.memlet_tree(new_edge):
            edge.data.data = new_edge.data.data
            if not full_data:
                edge.data.subset.offset(global_subsets[original_edge.data.data][1], True)

    # Add nested SDFG node to the input state
    nested_sdfg = state.add_nested_sdfg(nsdfg, None,
                                        set(input_names.values()) | input_arrays,
                                        set(output_names.values()) | output_arrays.keys())

    # Reconnect memlets to nested SDFG (one outer edge per connector)
    reconnected_in = set()
    reconnected_out = set()
    empty_input = None
    empty_output = None
    for edge in inputs:
        if edge.data.data is None:
            empty_input = edge
            continue

        conn = input_names[edge]
        if conn in reconnected_in:
            continue
        if full_data:
            outer_memlet = Memlet.from_array(edge.data.data, sdfg.arrays[edge.data.data])
        else:
            outer_memlet = copy.deepcopy(edge.data)
            outer_memlet.subset = global_subsets[edge.data.data][1]
        state.add_edge(edge.src, edge.src_conn, nested_sdfg, conn, outer_memlet)
        reconnected_in.add(conn)

    for edge in outputs:
        if edge.data.data is None:
            empty_output = edge
            continue

        conn = output_names[edge]
        if conn in reconnected_out:
            continue
        if full_data:
            outer_memlet = Memlet.from_array(edge.data.data, sdfg.arrays[edge.data.data])
        else:
            outer_memlet = copy.deepcopy(edge.data)
            outer_memlet.subset = global_subsets[edge.data.data][1]
        outer_memlet.wcr = edge.data.wcr
        state.add_edge(nested_sdfg, conn, edge.dst, edge.dst_conn, outer_memlet)
        reconnected_out.add(conn)

    # Connect access nodes to internal input/output data as necessary
    scope_entry = scope.entry
    scope_exit = scope.exit
    for dname in input_arrays:
        node = state.add_read(dname)
        if scope_entry is not None:
            state.add_nedge(scope_entry, node, Memlet())
        state.add_edge(node, None, nested_sdfg, dname, Memlet.from_array(dname, sdfg.arrays[dname]))
    for dname, wcr in output_arrays.items():
        node = state.add_write(dname)
        if scope_exit is not None:
            state.add_nedge(node, scope_exit, Memlet())
        state.add_edge(nested_sdfg, dname, node, None, Memlet(data=dname, wcr=wcr))

    # Graph was not reconnected, but needs to be
    if state.in_degree(nested_sdfg) == 0 and empty_input is not None:
        state.add_edge(empty_input.src, empty_input.src_conn, nested_sdfg, None, empty_input.data)
    if state.out_degree(nested_sdfg) == 0 and empty_output is not None:
        state.add_edge(nested_sdfg, None, empty_output.dst, empty_output.dst_conn, empty_output.data)

    # Remove subgraph nodes from graph
    state.remove_nodes_from(subgraph.nodes())

    # Remove subgraph transients from top-level graph
    for transient in subgraph_transients:
        del sdfg.arrays[transient]

    # Remove newly isolated nodes due to memlet consolidation
    for edge in inputs:
        if state.in_degree(edge.src) + state.out_degree(edge.src) == 0:
            state.remove_node(edge.src)
    for edge in outputs:
        if state.in_degree(edge.dst) + state.out_degree(edge.dst) == 0:
            state.remove_node(edge.dst)

    return nested_sdfg
def apply(self, sdfg: sd.SDFG):
    """
    Transforms the given SDFG in-place so that its top-level computations are
    scheduled on the GPU.

    Non-transient arrays that are read or written get transient
    ``GPU_Global`` clones, which replace the originals in all access nodes
    and memlets; new copy-in and copy-out states transfer the data. Transient
    storage is adjusted, free top-level tasklets are wrapped in single
    iteration GPU maps, top-level maps and library nodes receive the
    ``GPU_Device`` schedule, and arrays used in interstate edges are copied
    out in interim states.

    :param sdfg: The SDFG to transform.
    """

    #######################################################
    # Step 0: SDFG metadata

    # Find all input and output data descriptors as (name, descriptor)
    # pairs. The name sets keep each container listed only once; testing the
    # name against the pair lists themselves would never match.
    input_nodes = []
    output_nodes = []
    seen_inputs = set()
    seen_outputs = set()
    global_code_nodes = [[] for _ in sdfg.nodes()]

    for i, state in enumerate(sdfg.nodes()):
        sdict = state.scope_dict()
        for node in state.nodes():
            if (isinstance(node, nodes.AccessNode) and not node.desc(sdfg).transient):
                if (state.out_degree(node) > 0 and node.data not in seen_inputs):
                    # Special case: nodes that lead to top-level dynamic
                    # map ranges must stay on host
                    for e in state.out_edges(node):
                        last_edge = state.memlet_path(e)[-1]
                        if (isinstance(last_edge.dst, nodes.EntryNode) and last_edge.dst_conn
                                and not last_edge.dst_conn.startswith('IN_') and sdict[last_edge.dst] is None):
                            break
                    else:
                        seen_inputs.add(node.data)
                        input_nodes.append((node.data, node.desc(sdfg)))
                if (state.in_degree(node) > 0 and node.data not in seen_outputs):
                    seen_outputs.add(node.data)
                    output_nodes.append((node.data, node.desc(sdfg)))
            elif isinstance(node, nodes.CodeNode) and sdict[node] is None:
                if not isinstance(node, (nodes.LibraryNode, nodes.NestedSDFG)):
                    global_code_nodes[i].append(node)

        # Input nodes may also be nodes with WCR memlets and no identity
        for e in state.edges():
            if e.data.wcr is not None:
                if (e.data.data not in seen_inputs and not sdfg.arrays[e.data.data].transient):
                    seen_inputs.add(e.data.data)
                    input_nodes.append((e.data.data, sdfg.arrays[e.data.data]))

    start_state = sdfg.start_state
    end_states = sdfg.sink_nodes()

    #######################################################
    # Step 1: Create cloned GPU arrays and replace originals

    # Both lists are free of duplicates, so they are traversed directly,
    # which registers the clones in a deterministic order
    cloned_arrays = {}
    for inodename, inode in input_nodes:
        if isinstance(inode, data.Scalar):  # Scalars can remain on host
            continue
        if inode.storage == dtypes.StorageType.GPU_Global:
            continue
        newdesc = inode.clone()
        newdesc.storage = dtypes.StorageType.GPU_Global
        newdesc.transient = True
        name = sdfg.add_datadesc('gpu_' + inodename, newdesc, find_new_name=True)
        cloned_arrays[inodename] = name

    for onodename, onode in output_nodes:
        if onodename in cloned_arrays:
            continue
        if onode.storage == dtypes.StorageType.GPU_Global:
            continue
        newdesc = onode.clone()
        newdesc.storage = dtypes.StorageType.GPU_Global
        newdesc.transient = True
        name = sdfg.add_datadesc('gpu_' + onodename, newdesc, find_new_name=True)
        cloned_arrays[onodename] = name

    # Replace nodes
    for state in sdfg.nodes():
        for node in state.nodes():
            if (isinstance(node, nodes.AccessNode) and node.data in cloned_arrays):
                node.data = cloned_arrays[node.data]

    # Replace memlets
    for state in sdfg.nodes():
        for edge in state.edges():
            if edge.data.data in cloned_arrays:
                edge.data.data = cloned_arrays[edge.data.data]

    #######################################################
    # Step 2: Create copy-in state
    excluded_copyin = self.exclude_copyin.split(',')

    copyin_state = sdfg.add_state(sdfg.label + '_copyin')
    sdfg.add_edge(copyin_state, start_state, sd.InterstateEdge())

    for nname, desc in dtypes.deduplicate(input_nodes):
        if nname in excluded_copyin or nname not in cloned_arrays:
            continue
        src_array = nodes.AccessNode(nname, debuginfo=desc.debuginfo)
        dst_array = nodes.AccessNode(cloned_arrays[nname], debuginfo=desc.debuginfo)
        copyin_state.add_node(src_array)
        copyin_state.add_node(dst_array)
        copyin_state.add_nedge(src_array, dst_array,
                               memlet.Memlet.from_array(src_array.data, src_array.desc(sdfg)))

    #######################################################
    # Step 3: Create copy-out state
    excluded_copyout = self.exclude_copyout.split(',')

    copyout_state = sdfg.add_state(sdfg.label + '_copyout')
    for state in end_states:
        sdfg.add_edge(state, copyout_state, sd.InterstateEdge())

    for nname, desc in dtypes.deduplicate(output_nodes):
        if nname in excluded_copyout or nname not in cloned_arrays:
            continue
        src_array = nodes.AccessNode(cloned_arrays[nname], debuginfo=desc.debuginfo)
        dst_array = nodes.AccessNode(nname, debuginfo=desc.debuginfo)
        copyout_state.add_node(src_array)
        copyout_state.add_node(dst_array)
        copyout_state.add_nedge(src_array, dst_array,
                                memlet.Memlet.from_array(dst_array.data, dst_array.desc(sdfg)))

    #######################################################
    # Step 4: Modify transient data storage

    gpu_storage = [
        dtypes.StorageType.GPU_Global, dtypes.StorageType.GPU_Shared, dtypes.StorageType.CPU_Pinned
    ]
    for state in sdfg.nodes():
        sdict = state.scope_dict()
        for node in state.nodes():
            if isinstance(node, nodes.AccessNode) and node.desc(sdfg).transient:
                nodedesc = node.desc(sdfg)

                # Special case: nodes that lead to dynamic map ranges must
                # stay on host
                if any(
                        isinstance(state.memlet_path(e)[-1].dst, nodes.EntryNode)
                        for e in state.out_edges(node)):
                    continue

                if sdict[node] is None and nodedesc.storage not in gpu_storage:
                    # NOTE: the cloned arrays match too but it's the same
                    # storage so we don't care
                    nodedesc.storage = dtypes.StorageType.GPU_Global

                    # Try to move allocation/deallocation out of loops
                    if (self.toplevel_trans and not isinstance(nodedesc, data.Stream)):
                        nodedesc.lifetime = dtypes.AllocationLifetime.SDFG
                elif nodedesc.storage not in gpu_storage:
                    # Make internal transients registers
                    if self.register_trans:
                        nodedesc.storage = dtypes.StorageType.Register

    #######################################################
    # Step 5: Wrap free tasklets and nested SDFGs with a GPU map

    excluded_tasklets = self.exclude_tasklets.split(',')
    # zip stops at the shorter sequence, so the copy-in/copy-out states
    # created above (which have no entry in global_code_nodes) are skipped
    for state, gcodes in zip(sdfg.nodes(), global_code_nodes):
        for gcode in gcodes:
            if gcode.label in excluded_tasklets:
                continue
            # Create map and connectors
            me, mx = state.add_map(gcode.label + '_gmap', {gcode.label + '__gmapi': '0:1'},
                                   schedule=dtypes.ScheduleType.GPU_Device)
            # Store in/out edges in lists so that they don't get corrupted
            # when they are removed from the graph
            in_edges = list(state.in_edges(gcode))
            out_edges = list(state.out_edges(gcode))
            me.in_connectors = {('IN_' + e.dst_conn): None for e in in_edges}
            me.out_connectors = {('OUT_' + e.dst_conn): None for e in in_edges}
            mx.in_connectors = {('IN_' + e.src_conn): None for e in out_edges}
            mx.out_connectors = {('OUT_' + e.src_conn): None for e in out_edges}

            # Create memlets through map
            for e in in_edges:
                state.remove_edge(e)
                state.add_edge(e.src, e.src_conn, me, 'IN_' + e.dst_conn, e.data)
                state.add_edge(me, 'OUT_' + e.dst_conn, e.dst, e.dst_conn, e.data)
            for e in out_edges:
                state.remove_edge(e)
                state.add_edge(e.src, e.src_conn, mx, 'IN_' + e.src_conn, e.data)
                state.add_edge(mx, 'OUT_' + e.src_conn, e.dst, e.dst_conn, e.data)

            # Map without inputs
            if len(in_edges) == 0:
                state.add_nedge(me, gcode, memlet.Memlet())

    #######################################################
    # Step 6: Change all top-level maps and library nodes to GPU schedule

    for state in sdfg.nodes():
        sdict = state.scope_dict()
        for node in state.nodes():
            if isinstance(node, (nodes.EntryNode, nodes.LibraryNode)):
                if sdict[node] is None:
                    node.schedule = dtypes.ScheduleType.GPU_Device
                elif self.sequential_innermaps:
                    node.schedule = dtypes.ScheduleType.Sequential

    #######################################################
    # Step 7: Introduce copy-out if data used in outgoing interstate edges

    # Iterate over a copy, since interim states are added along the way
    for state in list(sdfg.nodes()):
        arrays_used = set()
        for e in sdfg.out_edges(state):
            # Used arrays = intersection between symbols and cloned arrays
            arrays_used.update(set(e.data.free_symbols) & set(cloned_arrays.keys()))

        # Create a state and copy out used arrays
        if len(arrays_used) > 0:
            co_state = sdfg.add_state(state.label + '_icopyout')

            # Reconnect outgoing edges to after interim copyout state
            # NOTE(review): the call does not depend on the loop variable;
            # presumably one call already redirects every outgoing edge --
            # confirm against sdutil.change_edge_src before simplifying.
            for _ in sdfg.out_edges(state):
                sdutil.change_edge_src(sdfg, state, co_state)
            # Add unconditional edge to interim state
            sdfg.add_edge(state, co_state, sd.InterstateEdge())

            # Add copy-out nodes
            for nname in arrays_used:
                desc = sdfg.arrays[nname]
                src_array = nodes.AccessNode(cloned_arrays[nname], debuginfo=desc.debuginfo)
                dst_array = nodes.AccessNode(nname, debuginfo=desc.debuginfo)
                co_state.add_node(src_array)
                co_state.add_node(dst_array)
                co_state.add_nedge(src_array, dst_array,
                                   memlet.Memlet.from_array(dst_array.data, dst_array.desc(sdfg)))

    #######################################################
    # Step 8: Strict transformations
    if not self.strict_transform:
        return

    # Apply strict state fusions greedily.
    sdfg.apply_strict_transformations()
def nest_state_subgraph(sdfg: SDFG,
                        state: SDFGState,
                        subgraph: SubgraphView,
                        name: Optional[str] = None,
                        full_data: bool = False) -> nodes.NestedSDFG:
    """ Turns a state subgraph into a nested SDFG. Operates in-place.

        Edges entering or leaving the subgraph with an empty memlet carry no
        data; they are not turned into connectors of the nested SDFG node.

        :param sdfg: The SDFG containing the state subgraph.
        :param state: The state containing the subgraph.
        :param subgraph: Subgraph to nest.
        :param name: An optional name for the nested SDFG.
        :param full_data: If True, nests entire input/output data.
        :return: The nested SDFG node.
        :raise KeyError: Some or all nodes in the subgraph are not located in
                         this state, or the state does not belong to the given
                         SDFG.
        :raise ValueError: The subgraph is contained in more than one scope.
    """
    if state.parent != sdfg:
        raise KeyError('State does not belong to given SDFG')
    if subgraph.graph != state:
        raise KeyError('Subgraph does not belong to given state')

    # Evaluate the node list once; membership tests below use the set
    subgraph_nodes = subgraph.nodes()
    subgraph_node_set = set(subgraph_nodes)

    # Find the top-level scope
    scope_tree = state.scope_tree()
    scope_dict = state.scope_dict()
    scope_dict_children = state.scope_dict(True)
    top_scopenode = -1  # Initialized to -1 since "None" already means top-level

    for node in subgraph_nodes:
        if node not in scope_dict:
            raise KeyError('Node not found in state')

        # If scope entry/exit, ensure entire scope is in subgraph
        if isinstance(node, nodes.EntryNode):
            scope_nodes = scope_dict_children[node]
            if any(n not in subgraph_node_set for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (entry)')
        elif isinstance(node, nodes.ExitNode):
            entry = state.entry_node(node)
            scope_nodes = scope_dict_children[entry] + [entry]
            if any(n not in subgraph_node_set for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (exit)')

        scope_node = scope_dict[node]
        if scope_node not in subgraph_node_set:
            if top_scopenode != -1 and top_scopenode != scope_node:
                raise ValueError(
                    'Subgraph is contained in more than one scope')
            top_scopenode = scope_node

    scope = scope_tree[top_scopenode]
    ###

    # Collect inputs and outputs of the nested SDFG
    inputs: List[MultiConnectorEdge] = []
    outputs: List[MultiConnectorEdge] = []
    for node in subgraph.source_nodes():
        inputs.extend(state.in_edges(node))
    for node in subgraph.sink_nodes():
        outputs.extend(state.out_edges(node))

    # Skip edges with an empty memlet. Filtering here (rather than while
    # naming the connectors) keeps `inputs`/`outputs` index-aligned with
    # `input_names`/`output_names`, which are zipped together below.
    inputs = [e for e in inputs if e.data.data is not None]
    outputs = [e for e in outputs if e.data.data is not None]

    # Collect transients not used outside of subgraph (will be removed of
    # top-level graph)
    data_in_subgraph = set(n.data for n in subgraph_nodes
                           if isinstance(n, nodes.AccessNode))
    # Find other occurrences in SDFG
    other_nodes = set(
        n.data for s in sdfg.nodes() for n in s.nodes()
        if isinstance(n, nodes.AccessNode) and n not in subgraph_node_set)
    subgraph_transients = set()
    for data in data_in_subgraph:
        datadesc = sdfg.arrays[data]
        if datadesc.transient and data not in other_nodes:
            subgraph_transients.add(data)

    # All transients of edges between code nodes are also added to nested graph
    for edge in subgraph.edges():
        if (isinstance(edge.src, nodes.CodeNode)
                and isinstance(edge.dst, nodes.CodeNode)
                and edge.data.data is not None):  # Empty memlet: no data
            subgraph_transients.add(edge.data.data)

    # Collect data used in access nodes within subgraph (will be referenced in
    # full upon nesting)
    input_arrays = set()
    output_arrays = set()
    for node in subgraph_nodes:
        if (isinstance(node, nodes.AccessNode)
                and node.data not in subgraph_transients):
            if state.out_degree(node) > 0:
                input_arrays.add(node.data)
            if state.in_degree(node) > 0:
                output_arrays.add(node.data)

    # Create the nested SDFG
    nsdfg = SDFG(name or 'nested_' + state.label)

    # Transients are added to the nested graph as-is
    for dname in subgraph_transients:
        nsdfg.add_datadesc(dname, sdfg.arrays[dname])

    # Input/output data that are not source/sink nodes are added to the graph
    # as non-transients
    for dname in (input_arrays | output_arrays):
        datadesc = copy.deepcopy(sdfg.arrays[dname])
        datadesc.transient = False
        nsdfg.add_datadesc(dname, datadesc)

    # Connected source/sink nodes outside subgraph become global data
    # descriptors in nested SDFG
    input_names = []
    output_names = []
    for edge in inputs:
        datadesc = copy.deepcopy(sdfg.arrays[edge.data.data])
        datadesc.transient = False
        if not full_data:
            datadesc.shape = edge.data.subset.size()
        input_names.append(
            nsdfg.add_datadesc('__in_' + edge.data.data,
                               datadesc,
                               find_new_name=True))
    for edge in outputs:
        datadesc = copy.deepcopy(sdfg.arrays[edge.data.data])
        datadesc.transient = False
        if not full_data:
            datadesc.shape = edge.data.subset.size()
        output_names.append(
            nsdfg.add_datadesc('__out_' + edge.data.data,
                               datadesc,
                               find_new_name=True))
    ###################

    # Add scope symbols to the nested SDFG
    for v in scope.defined_vars:
        if v in sdfg.symbols:
            sym = sdfg.symbols[v]
            nsdfg.add_symbol(v, sym.dtype)

    # Create nested state
    nstate = nsdfg.add_state()

    # Add subgraph nodes and edges to nested state
    nstate.add_nodes_from(subgraph_nodes)
    for e in subgraph.edges():
        nstate.add_edge(e.src, e.src_conn, e.dst, e.dst_conn, e.data)

    # Modify nested SDFG parents in subgraph
    for node in subgraph_nodes:
        if isinstance(node, nodes.NestedSDFG):
            node.sdfg.parent = nstate
            node.sdfg.parent_sdfg = nsdfg

    # Add access nodes and edges as necessary
    edges_to_offset = []
    for conn, edge in zip(input_names, inputs):
        node = nstate.add_read(conn)
        new_edge = copy.deepcopy(edge.data)
        new_edge.data = conn
        edges_to_offset.append(
            (edge,
             nstate.add_edge(node, None, edge.dst, edge.dst_conn, new_edge)))
    for conn, edge in zip(output_names, outputs):
        node = nstate.add_write(conn)
        new_edge = copy.deepcopy(edge.data)
        new_edge.data = conn
        edges_to_offset.append(
            (edge,
             nstate.add_edge(edge.src, edge.src_conn, node, None, new_edge)))

    # Offset memlet paths inside nested SDFG according to subsets
    for original_edge, new_edge in edges_to_offset:
        for edge in nstate.memlet_tree(new_edge):
            edge.data.data = new_edge.data.data
            if not full_data:
                edge.data.subset.offset(original_edge.data.subset, True)

    # Add nested SDFG node to the input state
    nested_sdfg = state.add_nested_sdfg(nsdfg, None,
                                        set(input_names) | input_arrays,
                                        set(output_names) | output_arrays)

    # Reconnect memlets to nested SDFG
    for conn, edge in zip(input_names, inputs):
        if full_data:
            data = Memlet.from_array(edge.data.data,
                                     sdfg.arrays[edge.data.data])
        else:
            data = edge.data
        state.add_edge(edge.src, edge.src_conn, nested_sdfg, conn, data)
    for conn, edge in zip(output_names, outputs):
        if full_data:
            data = Memlet.from_array(edge.data.data,
                                     sdfg.arrays[edge.data.data])
        else:
            data = edge.data
        state.add_edge(nested_sdfg, conn, edge.dst, edge.dst_conn, data)

    # Connect access nodes to internal input/output data as necessary
    scope_entry = scope.entry
    scope_exit = scope.exit
    for dname in input_arrays:
        node = state.add_read(dname)
        if scope_entry is not None:
            state.add_nedge(scope_entry, node, EmptyMemlet())
        state.add_edge(node, None, nested_sdfg, dname,
                       Memlet.from_array(dname, sdfg.arrays[dname]))
    for dname in output_arrays:
        node = state.add_write(dname)
        if scope_exit is not None:
            state.add_nedge(node, scope_exit, EmptyMemlet())
        state.add_edge(nested_sdfg, dname, node, None,
                       Memlet.from_array(dname, sdfg.arrays[dname]))

    # Remove subgraph nodes from graph
    state.remove_nodes_from(subgraph_nodes)

    # Remove subgraph transients from top-level graph
    for transient in subgraph_transients:
        del sdfg.arrays[transient]

    return nested_sdfg
def apply(self, sdfg: SDFG):
    """ Moves the states of the matched subgraph into a nested "kernel" SDFG.

        The kernel SDFG is placed inside a map with the ``GPU_Persistent``
        schedule in a dedicated launch state of ``sdfg``. Guard states are
        inserted around the subgraph where needed, and the launch state is
        connected to the states that preceded and followed the subgraph.

        :param sdfg: The SDFG to transform in-place.
    """
    # Label prefix shared by every state/SDFG/map created below
    prefix = self.kernel_prefix + '_' if self.kernel_prefix != '' else ''

    subgraph = self.subgraph_view(sdfg)

    entry_states_in, entry_states_out = self.get_entry_states(sdfg, subgraph)
    _, exit_states_out = self.get_exit_states(sdfg, subgraph)

    entry_state_in = entry_states_in.pop()
    entry_state_out = None
    if len(entry_states_out) > 0:
        entry_state_out = entry_states_out.pop()
    exit_state_out = None
    if len(exit_states_out) > 0:
        exit_state_out = exit_states_out.pop()

    launch_state = None
    entry_guard_state = None
    exit_guard_state = None

    # Generate entry guard state if needed: the incoming edge is split so that
    # its condition stays outside and its assignments move into the subgraph
    if self.include_in_assignment and entry_state_out is not None:
        entry_edge = sdfg.edges_between(entry_state_out, entry_state_in)[0]
        if len(entry_edge.data.assignments) > 0:
            entry_guard_state = sdfg.add_state(
                label='{}kernel_entry_guard'.format(prefix))
            sdfg.add_edge(entry_state_out, entry_guard_state,
                          InterstateEdge(entry_edge.data.condition))
            sdfg.add_edge(entry_guard_state, entry_state_in,
                          InterstateEdge(None, entry_edge.data.assignments))
            sdfg.remove_edge(entry_edge)

            # Update SubgraphView
            guarded_nodes = subgraph.nodes()
            guarded_nodes.append(entry_guard_state)
            subgraph = SubgraphView(sdfg, guarded_nodes)

            launch_state = sdfg.add_state_before(
                entry_guard_state, label='{}kernel_launch'.format(prefix))

    # Generate exit guard state
    if exit_state_out is not None:
        exit_guard_state = sdfg.add_state_before(
            exit_state_out, label='{}kernel_exit_guard'.format(prefix))

        # Update SubgraphView
        guarded_nodes = subgraph.nodes()
        guarded_nodes.append(exit_guard_state)
        subgraph = SubgraphView(sdfg, guarded_nodes)

        if launch_state is None:
            launch_state = sdfg.add_state_before(
                exit_state_out, label='{}kernel_launch'.format(prefix))

    # If the launch state doesn't exist at this point then there is no other
    # states outside of the kernel, so create a stand alone launch state
    if launch_state is None:
        # NOTE(review): entry_state_in comes from a successful pop() above, so
        # this assertion looks unsatisfiable unless get_entry_states can yield
        # None -- confirm the intended condition.
        assert (entry_state_in is None and exit_state_out is None)
        launch_state = sdfg.add_state(label='{}kernel_launch'.format(prefix))

    # Create the SDFG for the kernel and fill it with the states and edges of
    # the subgraph; it is nested into the launch state at the end
    kernel_sdfg = SDFG('{}kernel'.format(prefix))

    for isedge in subgraph.edges():
        kernel_sdfg.add_edge(isedge.src, isedge.dst, isedge.data)

    # Setting entry node in nested SDFG if no entry guard was created
    if entry_guard_state is None:
        kernel_sdfg.start_state = kernel_sdfg.node_id(entry_state_in)

    for kstate in subgraph:
        kstate.parent = kernel_sdfg

    # Remove the now nested nodes from the outer sdfg and make sure the
    # launch state is properly connected to remaining states
    sdfg.remove_nodes_from(subgraph.nodes())

    if entry_state_out is not None:
        if len(sdfg.edges_between(entry_state_out, launch_state)) == 0:
            sdfg.add_edge(entry_state_out, launch_state, InterstateEdge())

    if exit_state_out is not None:
        if len(sdfg.edges_between(launch_state, exit_state_out)) == 0:
            sdfg.add_edge(launch_state, exit_state_out, InterstateEdge())

    # Handle data for kernel
    kernel_data = set()
    for kstate in kernel_sdfg:
        for knode in kstate.nodes():
            if isinstance(knode, nodes.AccessNode):
                kernel_data.add(knode.data)

    # Move Streams and Register data into the nested SDFG;
    # normal data will be added as kernel argument
    kernel_args = []
    for dname in kernel_data:
        desc = sdfg.arrays[dname]
        is_register_array = (isinstance(desc, dace.data.Array)
                             and desc.storage == StorageType.Register)
        if isinstance(desc, dace.data.Stream) or is_register_array:
            kernel_sdfg.add_datadesc(dname, desc)
            del sdfg.arrays[dname]
            continue

        arg_desc = copy.deepcopy(desc)
        arg_desc.transient = False
        arg_desc.storage = StorageType.Default
        kernel_sdfg.add_datadesc(dname, arg_desc)
        kernel_args.append(dname)

    # Read only data will be passed as input, writeable data will be passed
    # as 'output', otherwise kernel cannot write to data
    kernel_args_read = set()
    kernel_args_write = set()
    for dname in kernel_args:
        only_read = all(knode.access == dtypes.AccessType.ReadOnly
                        for kstate in kernel_sdfg for knode in kstate
                        if isinstance(knode, nodes.AccessNode)
                        and knode.data == dname)
        if only_read:
            kernel_args_read.add(dname)
        else:
            kernel_args_write.add(dname)

    # Kernel SDFG is complete at this point
    if self.validate:
        kernel_sdfg.validate()

    # Filling launch state with nested SDFG, map and access nodes
    map_entry, map_exit = launch_state.add_map(
        '{}kernel_launch_map'.format(prefix),
        dict(ignore='0'),
        schedule=ScheduleType.GPU_Persistent,
    )

    nested_sdfg = launch_state.add_nested_sdfg(
        kernel_sdfg,
        sdfg,
        kernel_args_read,
        kernel_args_write,
    )

    # Create and connect read only data access nodes
    for arg in kernel_args_read:
        reader = launch_state.add_read(arg)
        launch_state.add_memlet_path(reader,
                                     map_entry,
                                     nested_sdfg,
                                     dst_conn=arg,
                                     memlet=Memlet.from_array(
                                         arg, sdfg.arrays[arg]))

    # Create and connect writable data access nodes
    for arg in kernel_args_write:
        writer = launch_state.add_write(arg)
        launch_state.add_memlet_path(nested_sdfg,
                                     map_exit,
                                     writer,
                                     src_conn=arg,
                                     memlet=Memlet.from_array(
                                         arg, sdfg.arrays[arg]))

    # Transformation is done
    if self.validate:
        sdfg.validate()
def apply(self, outer_state: SDFGState, sdfg: SDFG):
    """ Inlines the matched multi-state nested SDFG into its parent SDFG.

        The nested SDFG's global/init/exit code, constants and transients are
        merged into ``sdfg``, its states and interstate edges are added to
        ``sdfg``, and the interstate edges of ``outer_state`` are reconnected
        to the nested start state and sink states. ``outer_state`` is removed
        afterwards.

        :param outer_state: The state containing the nested SDFG node.
        :param sdfg: The SDFG that contains ``outer_state``.
        :return: The states of the (now inlined) nested SDFG.
    """
    nsdfg_node = self.nested_sdfg
    nsdfg: SDFG = nsdfg_node.sdfg

    if nsdfg_node.schedule is not dtypes.ScheduleType.Default:
        infer_types.set_default_schedule_and_storage_types(
            nsdfg, nsdfg_node.schedule)

    #######################################################
    # Collect and update top-level SDFG metadata

    # Global/init/exit code
    for loc, code in nsdfg.global_code.items():
        sdfg.append_global_code(code.code, loc)
    for loc, code in nsdfg.init_code.items():
        sdfg.append_init_code(code.code, loc)
    for loc, code in nsdfg.exit_code.items():
        sdfg.append_exit_code(code.code, loc)

    # Environments
    for nstate in nsdfg.nodes():
        for node in nstate.nodes():
            if isinstance(node, nodes.CodeNode):
                node.environments |= nsdfg_node.environments

    # Constants
    for cstname, cstval in nsdfg.constants.items():
        if cstname in sdfg.constants:
            if cstval != sdfg.constants[cstname]:
                warnings.warn('Constant value mismatch for "%s" while '
                              'inlining SDFG. Inner = %s != %s = outer' %
                              (cstname, cstval, sdfg.constants[cstname]))
        else:
            sdfg.add_constant(cstname, cstval)

    # Symbols
    outer_symbols = {str(k): v for k, v in sdfg.symbols.items()}
    for ise in sdfg.edges():
        outer_symbols.update(ise.data.new_symbols(sdfg, outer_symbols))

    # Find original source/destination edges (there is only one edge per
    # connector, according to match)
    inputs: Dict[str, MultiConnectorEdge] = {}
    outputs: Dict[str, MultiConnectorEdge] = {}
    # NOTE: input_set/output_set are only consumed by the disabled
    # memlet-offsetting code further below
    input_set: Dict[str, str] = {}
    output_set: Dict[str, str] = {}
    for e in outer_state.in_edges(nsdfg_node):
        inputs[e.dst_conn] = e
        input_set[e.data.data] = e.dst_conn
    for e in outer_state.out_edges(nsdfg_node):
        outputs[e.src_conn] = e
        output_set[e.data.data] = e.src_conn

    # Replace symbols using invocation symbol mapping
    # Two-step replacement (N -> __dacesym_N --> map[N]) to avoid clashes
    symbolic.safe_replace(nsdfg_node.symbol_mapping, nsdfg.replace_dict)

    # Access nodes that need to be reshaped
    # reshapes: Set(str) = set()
    # for aname, array in nsdfg.arrays.items():
    #     if array.transient:
    #         continue
    #     edge = None
    #     if aname in inputs:
    #         edge = inputs[aname]
    #         if len(array.shape) > len(edge.data.subset):
    #             reshapes.add(aname)
    #             continue
    #     if aname in outputs:
    #         edge = outputs[aname]
    #         if len(array.shape) > len(edge.data.subset):
    #             reshapes.add(aname)
    #             continue
    #     if edge is not None and not InlineMultistateSDFG._check_strides(
    #             array.strides, sdfg.arrays[edge.data.data].strides,
    #             edge.data, nsdfg_node):
    #         reshapes.add(aname)

    # Mapping from nested transient name to top-level name
    transients: Dict[str, str] = {}

    def _hoist_transient(dname):
        # Registers a nested transient in the parent SDFG, once per name
        datadesc = nsdfg.arrays[dname]
        if dname in transients or not datadesc.transient:
            return
        new_name = dname
        if (new_name in sdfg.arrays or new_name in outer_symbols
                or new_name in sdfg.constants):
            new_name = f'{nsdfg.label}_{dname}'
        transients[dname] = sdfg.add_datadesc(new_name,
                                              datadesc,
                                              find_new_name=True)

    # All transients become transients of the parent (if data already
    # exists, find new name)
    for nstate in nsdfg.nodes():
        for node in nstate.nodes():
            if isinstance(node, nodes.AccessNode):
                _hoist_transient(node.data)

        # All transients of edges between code nodes are also added to parent
        for edge in nstate.edges():
            if (isinstance(edge.src, nodes.CodeNode)
                    and isinstance(edge.dst, nodes.CodeNode)):
                if edge.data.data is not None:
                    _hoist_transient(edge.data.data)

    #######################################################
    # Replace data on inlined SDFG nodes/edges

    # Replace data names with their top-level counterparts
    repldict = {}
    repldict.update(transients)
    repldict.update({
        k: v.data.data
        for k, v in itertools.chain(inputs.items(), outputs.items())
    })

    symbolic.safe_replace(repldict,
                          lambda m: replace_datadesc_names(nsdfg, m),
                          value_as_string=True)

    # Add views whenever reshapes are necessary
    # for dname in reshapes:
    #     desc = nsdfg.arrays[dname]
    #     # To avoid potential confusion, rename protected __return keyword
    #     if dname.startswith('__return'):
    #         newname = f'{nsdfg.name}_ret{dname[8:]}'
    #     else:
    #         newname = dname
    #     newname, _ = sdfg.add_view(newname,
    #                                desc.shape,
    #                                desc.dtype,
    #                                storage=desc.storage,
    #                                strides=desc.strides,
    #                                offset=desc.offset,
    #                                debuginfo=desc.debuginfo,
    #                                allow_conflicts=desc.allow_conflicts,
    #                                total_size=desc.total_size,
    #                                alignment=desc.alignment,
    #                                may_alias=desc.may_alias,
    #                                find_new_name=True)
    #     repldict[dname] = newname

    # Add extra access nodes for out/in view nodes
    # inv_reshapes = {repldict[r]: r for r in reshapes}
    # for nstate in nsdfg.nodes():
    #     for node in nstate.nodes():
    #         if isinstance(node,
    #                       nodes.AccessNode) and node.data in inv_reshapes:
    #             if nstate.in_degree(node) > 0 and nstate.out_degree(
    #                     node) > 0:
    #                 # Such a node has to be in the output set
    #                 edge = outputs[inv_reshapes[node.data]]

    #                 # Redirect outgoing edges through access node
    #                 out_edges = list(nstate.out_edges(node))
    #                 anode = nstate.add_access(edge.data.data)
    #                 vnode = nstate.add_access(node.data)
    #                 nstate.add_nedge(node, anode, edge.data)
    #                 nstate.add_nedge(anode, vnode, edge.data)
    #                 for e in out_edges:
    #                     nstate.remove_edge(e)
    #                     nstate.add_edge(vnode, e.src_conn, e.dst,
    #                                     e.dst_conn, e.data)

    # Make unique names for states. Every label that ends up in use is
    # recorded, so that a generated name cannot collide with the label of an
    # inner state visited later.
    statenames = set(s.label for s in sdfg.nodes())
    for nstate in nsdfg.nodes():
        label = nstate.label
        if label in statenames:
            label = data.find_new_name(label, statenames)
            nstate.set_label(label)
        statenames.add(label)

    #######################################################
    # Collect and modify interstate edges as necessary

    outer_assignments = set()
    for e in sdfg.edges():
        outer_assignments |= e.data.assignments.keys()

    inner_assignments = set()
    for e in nsdfg.edges():
        inner_assignments |= e.data.assignments.keys()

    assignments_to_replace = inner_assignments & outer_assignments
    sym_replacements: Dict[str, str] = {}
    allnames = set(outer_symbols.keys()) | set(sdfg.arrays.keys())
    for assign in assignments_to_replace:
        newname = data.find_new_name(assign, allnames)
        allnames.add(newname)
        sym_replacements[assign] = newname
    nsdfg.replace_dict(sym_replacements)

    #######################################################
    # Add nested SDFG states into top-level SDFG

    outer_start_state = sdfg.start_state

    sdfg.add_nodes_from(nsdfg.nodes())
    for ise in nsdfg.edges():
        sdfg.add_edge(ise.src, ise.dst, ise.data)

    #######################################################
    # Reconnect inlined SDFG

    source = nsdfg.start_state
    sinks = nsdfg.sink_nodes()

    # Reconnect state machine
    for e in sdfg.in_edges(outer_state):
        sdfg.add_edge(e.src, source, e.data)
    for e in sdfg.out_edges(outer_state):
        for sink in sinks:
            # One outgoing edge fans out to every sink state: give each new
            # edge its own copy so they do not share one edge object
            sdfg.add_edge(sink, e.dst, copy.deepcopy(e.data))

    # Modify start state as necessary
    if outer_start_state is outer_state:
        sdfg.start_state = sdfg.node_id(source)

    # TODO: Modify memlets by offsetting
    # If both source and sink nodes are inputs/outputs, reconnect once
    # edges_to_ignore = self._modify_access_to_access(new_incoming_edges,
    #                                                 nsdfg, nstate, state,
    #                                                 orig_data)

    # source_to_outer = {n: e.src for n, e in new_incoming_edges.items()}
    # sink_to_outer = {n: e.dst for n, e in new_outgoing_edges.items()}
    # # If a source/sink node is one of the inputs/outputs, reconnect it,
    # # replacing memlets in outgoing/incoming paths
    # modified_edges = set()
    # modified_edges |= self._modify_memlet_path(new_incoming_edges, nstate,
    #                                            state, sink_to_outer, True,
    #                                            edges_to_ignore)
    # modified_edges |= self._modify_memlet_path(new_outgoing_edges, nstate,
    #                                            state, source_to_outer,
    #                                            False, edges_to_ignore)

    # # Reshape: add connections to viewed data
    # self._modify_reshape_data(reshapes, repldict, inputs, nstate, state,
    #                           True)
    # self._modify_reshape_data(reshapes, repldict, outputs, nstate, state,
    #                           False)

    # Modify all other internal edges pertaining to input/output nodes
    # for nstate in nsdfg.nodes():
    #     for node in nstate.nodes():
    #         if isinstance(node, nodes.AccessNode):
    #             if node.data in input_set or node.data in output_set:
    #                 if node.data in input_set:
    #                     outer_edge = inputs[input_set[node.data]]
    #                 else:
    #                     outer_edge = outputs[output_set[node.data]]

    #                 for edge in state.all_edges(node):
    #                     if (edge not in modified_edges
    #                             and edge.data.data == node.data):
    #                         for e in state.memlet_tree(edge):
    #                             if e.data.data == node.data:
    #                                 e._data = helpers.unsqueeze_memlet(
    #                                     e.data, outer_edge.data)

    # Replace nested SDFG parents with new SDFG
    for nstate in nsdfg.nodes():
        nstate.parent = sdfg
        for node in nstate.nodes():
            if isinstance(node, nodes.NestedSDFG):
                node.sdfg.parent_sdfg = sdfg
                node.sdfg.parent_nsdfg_node = node

    #######################################################
    # Remove nested SDFG and state
    sdfg.remove_node(outer_state)

    return nsdfg.nodes()
def apply(self, sdfg: sd.SDFG):
    """ Transforms the given SDFG in-place so that it works on GPU copies of
        its non-transient data.

        Every non-transient data container that is read or written is cloned
        into a transient ``gpu_<name>`` container with GPU global storage,
        copy-in/copy-out states are added around the original state machine,
        free tasklets and nested SDFGs are wrapped in GPU maps, top-level maps
        receive the GPU device schedule, and (optionally) strict state fusion
        and redundant-array removal are applied greedily.

        :param sdfg: The SDFG to transform.
    """

    #######################################################
    # Step 0: SDFG metadata

    # Find all input and output data descriptors. Both are keyed by data name
    # (insertion-ordered) so that each container is registered exactly once,
    # no matter how many access nodes or WCR memlets refer to it.
    input_nodes = {}
    output_nodes = {}
    global_code_nodes = [[] for _ in sdfg.nodes()]

    for i, state in enumerate(sdfg.nodes()):
        sdict = state.scope_dict()
        for node in state.nodes():
            if (isinstance(node, nodes.AccessNode)
                    and not node.desc(sdfg).transient):
                if (state.out_degree(node) > 0
                        and node.data not in input_nodes):
                    input_nodes[node.data] = node.desc(sdfg)
                if (state.in_degree(node) > 0
                        and node.data not in output_nodes):
                    output_nodes[node.data] = node.desc(sdfg)
            elif isinstance(node, nodes.CodeNode) and sdict[node] is None:
                if not isinstance(node, nodes.EmptyTasklet):
                    global_code_nodes[i].append(node)

        # Input nodes may also be nodes with WCR memlets and no identity
        for e in state.edges():
            if e.data.wcr is not None and e.data.wcr_identity is None:
                wcr_name = e.data.data
                if (wcr_name not in input_nodes
                        and not sdfg.arrays[wcr_name].transient):
                    input_nodes[wcr_name] = sdfg.arrays[wcr_name]

    start_state = sdfg.start_state
    end_states = sdfg.sink_nodes()

    #######################################################
    # Step 1: Create cloned GPU arrays and replace originals

    cloned_arrays = {}
    for inodename, inode in input_nodes.items():
        newdesc = inode.clone()
        newdesc.storage = types.StorageType.GPU_Global
        newdesc.transient = True
        sdfg.add_datadesc('gpu_' + inodename, newdesc)
        cloned_arrays[inodename] = 'gpu_' + inodename

    for onodename, onode in output_nodes.items():
        if onodename in cloned_arrays:
            continue
        newdesc = onode.clone()
        newdesc.storage = types.StorageType.GPU_Global
        newdesc.transient = True
        sdfg.add_datadesc('gpu_' + onodename, newdesc)
        cloned_arrays[onodename] = 'gpu_' + onodename

    # Replace nodes
    for state in sdfg.nodes():
        for node in state.nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.data in cloned_arrays):
                node.data = cloned_arrays[node.data]

    # Replace memlets
    for state in sdfg.nodes():
        for edge in state.edges():
            if edge.data.data in cloned_arrays:
                edge.data.data = cloned_arrays[edge.data.data]

    #######################################################
    # Step 2: Create copy-in state

    copyin_state = sdfg.add_state(sdfg.label + '_copyin')
    sdfg.add_edge(copyin_state, start_state, ed.InterstateEdge())

    for nname, desc in input_nodes.items():
        src_array = nodes.AccessNode(nname, debuginfo=desc.debuginfo)
        dst_array = nodes.AccessNode(cloned_arrays[nname],
                                     debuginfo=desc.debuginfo)
        copyin_state.add_node(src_array)
        copyin_state.add_node(dst_array)
        copyin_state.add_nedge(
            src_array, dst_array,
            memlet.Memlet.from_array(src_array.data, src_array.desc(sdfg)))

    #######################################################
    # Step 3: Create copy-out state

    copyout_state = sdfg.add_state(sdfg.label + '_copyout')
    for state in end_states:
        sdfg.add_edge(state, copyout_state, ed.InterstateEdge())

    for nname, desc in output_nodes.items():
        src_array = nodes.AccessNode(cloned_arrays[nname],
                                     debuginfo=desc.debuginfo)
        dst_array = nodes.AccessNode(nname, debuginfo=desc.debuginfo)
        copyout_state.add_node(src_array)
        copyout_state.add_node(dst_array)
        copyout_state.add_nedge(
            src_array, dst_array,
            memlet.Memlet.from_array(dst_array.data, dst_array.desc(sdfg)))

    #######################################################
    # Step 4: Modify transient data storage

    for state in sdfg.nodes():
        sdict = state.scope_dict()
        for node in state.nodes():
            if isinstance(node,
                          nodes.AccessNode) and node.desc(sdfg).transient:
                nodedesc = node.desc(sdfg)
                if sdict[node] is None:
                    # NOTE: the cloned arrays match too but it's the same
                    # storage so we don't care
                    nodedesc.storage = types.StorageType.GPU_Global

                    # Try to move allocation/deallocation out of loops
                    if self.toplevel_trans:
                        nodedesc.toplevel = True
                else:
                    # Make internal transients registers
                    if self.register_trans:
                        nodedesc.storage = types.StorageType.Register

    #######################################################
    # Step 5: Wrap free tasklets and nested SDFGs with a GPU map

    for state, gcodes in zip(sdfg.nodes(), global_code_nodes):
        for gcode in gcodes:
            # Create map and connectors
            me, mx = state.add_map(gcode.label + '_gmap',
                                   {gcode.label + '__gmapi': '0:1'},
                                   schedule=types.ScheduleType.GPU_Device)
            # Store in/out edges in lists so that they don't get corrupted
            # when they are removed from the graph
            in_edges = list(state.in_edges(gcode))
            out_edges = list(state.out_edges(gcode))
            me.in_connectors = set('IN_' + e.dst_conn for e in in_edges)
            me.out_connectors = set('OUT_' + e.dst_conn for e in in_edges)
            mx.in_connectors = set('IN_' + e.src_conn for e in out_edges)
            mx.out_connectors = set('OUT_' + e.src_conn for e in out_edges)

            # Create memlets through map
            for e in in_edges:
                state.remove_edge(e)
                state.add_edge(e.src, e.src_conn, me, 'IN_' + e.dst_conn,
                               e.data)
                state.add_edge(me, 'OUT_' + e.dst_conn, e.dst, e.dst_conn,
                               e.data)
            for e in out_edges:
                state.remove_edge(e)
                state.add_edge(e.src, e.src_conn, mx, 'IN_' + e.src_conn,
                               e.data)
                state.add_edge(mx, 'OUT_' + e.src_conn, e.dst, e.dst_conn,
                               e.data)

            # Map without inputs
            if len(in_edges) == 0:
                state.add_nedge(me, gcode, memlet.EmptyMemlet())

    #######################################################
    # Step 6: Change all top-level maps to GPU maps

    for state in sdfg.nodes():
        sdict = state.scope_dict()
        for node in state.nodes():
            if isinstance(node, nodes.EntryNode):
                if sdict[node] is None:
                    node.schedule = types.ScheduleType.GPU_Device
                elif self.sequential_innermaps:
                    node.schedule = types.ScheduleType.Sequential

    #######################################################
    # Step 7: Strict transformations
    if not self.strict_transform:
        return

    # Apply strict state fusions greedily.
    opt = optimizer.SDFGOptimizer(sdfg, inplace=True)
    fusions = 0
    arrays = 0

    def _strict_matches():
        # Pattern matches are recomputed after every applied transformation
        return [
            match for match in opt.get_pattern_matches(strict=True)
            if isinstance(match, (StateFusion, RedundantArray))
        ]

    options = _strict_matches()
    while options:
        ssdfg = sdfg.sdfg_list[options[0].sdfg_id]
        options[0].apply(ssdfg)
        ssdfg.validate()
        if isinstance(options[0], StateFusion):
            fusions += 1
        if isinstance(options[0], RedundantArray):
            arrays += 1

        options = _strict_matches()

    if Config.get_bool('debugprint') and (fusions > 0 or arrays > 0):
        print('Automatically applied {} strict state fusions and removed'
              ' {} redundant arrays.'.format(fusions, arrays))
def parse_from_function(function, *compilation_args, strict=None):
    """ Try to parse a DaceProgram object and return the `dace.SDFG` object
        that corresponds to it.

        :param function: DaceProgram object (obtained from the
                         `@dace.program` decorator).
        :param compilation_args: Various compilation arguments e.g. types.
        :param strict: Whether to apply strict transformations or not (None
                       uses configuration-defined value).
        :return: The generated SDFG object.
        :raise TypeError: If `function` is not a DaceProgram.
    """
    if not isinstance(function, DaceProgram):
        raise TypeError(
            'Function must be of type dace.frontend.python.DaceProgram')

    # Obtain parsed DaCe program
    pdp, modules = function.generate_pdp(*compilation_args)

    # Create an empty SDFG
    sdfg = SDFG(pdp.name, pdp.argtypes)

    sdfg.set_sourcecode(pdp.source, 'python')

    # Populate SDFG with states and nodes, according to the parsed DaCe program

    # 1) Inherit dependencies and inject tasklets
    # 2) Traverse program graph and recursively split into states,
    #    annotating edges with their transition conditions.
    # 3) Add arrays, streams, and scalars to the SDFG array store
    # 4) Eliminate empty states with no conditional outgoing transitions
    # 5) Label states in topological order
    # 6) Construct dataflow graph for each state

    # Step 1)
    for primitive in pdp.children:
        depanalysis.inherit_dependencies(primitive)

    # Step 2)
    state_primitives = depanalysis.create_states_simple(pdp, sdfg)

    # Step 3)
    for dataname, datadesc in pdp.all_arrays().items():
        sdfg.add_datadesc(dataname, datadesc)

    # Step 4) Absorb next state into current, if possible
    oldstates = list(sdfg.topological_sort(sdfg.start_state))
    for state in oldstates:
        if state not in sdfg.nodes():  # State already removed
            continue
        if sdfg.out_degree(state) != 1:
            continue

        edge = sdfg.out_edges(state)[0]
        nextState = edge.dst
        if not edge.data.is_unconditional():
            continue
        if sdfg.in_degree(nextState) > 1:  # If other edges point to state
            continue
        if len(state_primitives[nextState]) > 0:  # Don't fuse full states
            continue

        outEdges = list(sdfg.out_edges(nextState))
        for e in outEdges:
            # Construct new edge from the current assignments, new
            # assignments, and new conditions
            newEdge = copy.deepcopy(edge.data)
            newEdge.assignments.update(e.data.assignments)
            newEdge.condition = e.data.condition
            sdfg.add_edge(state, e.dst, newEdge)
        sdfg.remove_node(nextState)

    # Step 5)
    # Materialized (as in Step 4) because the result is iterated twice below
    stateList = list(sdfg.topological_sort(sdfg.start_state))
    for i, state in enumerate(stateList):
        if state.label is None or state.label == "":
            state.set_label("s" + str(i))

    # Step 6)
    for state in stateList:
        depanalysis.build_dataflow_graph(sdfg, state, state_primitives[state],
                                         modules)

    # Fill in scope entry/exit connectors
    sdfg.fill_scope_connectors()

    # Memlet propagation
    if sdfg.propagate:
        labeling.propagate_labels_sdfg(sdfg)

    # Drawing the SDFG before strict transformations
    sdfg.draw_to_file(recursive=True)

    # Apply strict transformations automatically
    if (strict == True
            or (strict is None
                and Config.get_bool('optimizer', 'automatic_state_fusion'))):
        sdfg.apply_strict_transformations()

    # Drawing the SDFG (again) to a .dot file
    sdfg.draw_to_file(recursive=True)

    # Validate SDFG
    sdfg.validate()

    return sdfg