def test_state_fission():
    """
    Exercises state fission on a state that contains two nested SDFGs.

    A subgraph of the first state is split off with
    ``helpers.state_fission``, after which the SDFG is expected to consist of
    two states. The SDFG is then compiled and executed, and both of its
    outputs are compared against NumPy reference results.
    """
    size_n, size_m = 16, 32

    sdfg = make_nested_sdfg_cpu()

    # Fission the first state around four of its nodes
    state = sdfg.states()[0]
    state_nodes = state.nodes()
    node_x = state_nodes[0]
    node_y = state_nodes[1]
    node_z = state_nodes[2]
    vec_add1 = state_nodes[3]
    fission_nodes = [node_x, node_y, vec_add1, node_z]
    helpers.state_fission(sdfg, dace.sdfg.graph.SubgraphView(state, fission_nodes))
    sdfg.validate()
    assert len(sdfg.states()) == 2

    def random_vector(length):
        # Random single-precision vector of the requested length
        return np.random.rand(length).astype(np.float32)

    # Inputs and outputs of the first (n-sized) and second (m-sized) addition
    x, y, z = (random_vector(size_n) for _ in range(3))
    v, w, u = (random_vector(size_m) for _ in range(3))

    # Run the program
    vec_add = sdfg.compile()
    vec_add(x=x, y=y, z=z, v=v, w=w, u=u, n=size_n, m=size_m)

    # Compare each output against its reference sum
    diff1 = np.linalg.norm((x + y) - z) / size_n
    diff2 = np.linalg.norm((v + w) - u) / size_m
    assert diff1 <= 1e-5
    assert diff2 <= 1e-5
def promote_scalars_to_symbols(sdfg: sd.SDFG,
                               ignore: Optional[Set[str]] = None,
                               transients_only: bool = True,
                               integers_only: bool = True) -> Set[str]:
    """
    Promotes all matching transient scalars to SDFG symbols, changing all
    tasklets to inter-state assignments. This enables the transformed symbols
    to be used within states as part of memlets, and allows further
    transformations (such as loop detection) to use the information for
    optimization.

    :param sdfg: The SDFG to run the pass on.
    :param ignore: An optional set of strings of scalars to ignore.
    :param transients_only: If False, also considers global data descriptors
                            (e.g., arguments).
    :param integers_only: If False, also considers non-integral descriptors
                          for promotion.
    :return: Set of promoted scalars (empty if nothing was promoted).
    :note: Operates in-place.
    """
    # Process:
    # 1. Find scalars to promote
    # 2. For every assignment tasklet/access:
    #    2.1. Fission state to isolate assignment
    #    2.2. Replace assignment with inter-state edge assignment
    # 3. For every read of the scalar:
    #    3.1. If destination is tasklet, remove node, edges, and connectors
    #    3.2. If used in tasklet as subscript or connector, modify tasklet code
    #    3.3. If destination is array, change to tasklet that copies symbol data
    # 4. Remove newly-isolated access nodes
    # 5. Remove data descriptors and add symbols to SDFG
    # 6. Replace subscripts in all interstate conditions and assignments
    # 7. Make indirections with symbols a single memlet
    to_promote = find_promotable_scalars(sdfg,
                                         transients_only=transients_only,
                                         integers_only=integers_only)
    if ignore:
        to_promote -= ignore
    if not to_promote:
        return to_promote

    for state in sdfg.nodes():
        scalar_nodes = [
            n for n in state.nodes()
            if isinstance(n, nodes.AccessNode) and n.data in to_promote
        ]
        # Step 2: Assignment tasklets
        for node in scalar_nodes:
            if state.in_degree(node) == 0:
                continue
            # There is only zero or one incoming edges by definition
            in_edge = state.in_edges(node)[0]
            writer = in_edge.src
            writer_inputs = [e.src for e in state.in_edges(writer)]

            # Step 2.1: isolate the assignment (writer, its inputs, and the
            # scalar access node) in a state of its own
            new_state = xfh.state_fission(
                sdfg, gr.SubgraphView(state, {writer, node, *writer_inputs}))
            new_isedge: sd.InterstateEdge = sdfg.out_edges(new_state)[0]

            # Step 2.2: re-fetch the nodes from the new state
            node: nodes.AccessNode = new_state.sink_nodes()[0]
            writer = new_state.in_edges(node)[0].src
            if isinstance(writer, nodes.Tasklet):
                # Convert tasklet to interstate edge
                newcode: str = ''
                if writer.language is dtypes.Language.Python:
                    newcode = astutils.unparse(writer.code.code[0].value)
                elif writer.language is dtypes.Language.CPP:
                    newcode = translate_cpp_tasklet_to_python(
                        writer.code.as_string.strip())

                # Replace tasklet inputs with incoming edges
                for e in new_state.in_edges(writer):
                    memlet_str: str = e.data.data
                    if (e.data.subset is not None and not isinstance(
                            sdfg.arrays[memlet_str], dt.Scalar)):
                        memlet_str += '[%s]' % e.data.subset
                    newcode = re.sub(r'\b%s\b' % re.escape(e.dst_conn),
                                     memlet_str, newcode)
                # Add interstate edge assignment
                new_isedge.data.assignments[node.data] = newcode
            elif isinstance(writer, nodes.AccessNode):
                memlet: mm.Memlet = in_edge.data
                if (memlet.src_subset and
                        not isinstance(sdfg.arrays[memlet.data], dt.Scalar)):
                    new_isedge.data.assignments[node.data] = '%s[%s]' % (
                        writer.data, memlet.src_subset)
                else:
                    new_isedge.data.assignments[node.data] = writer.data

            # Clean up all nodes after assignment was transferred
            new_state.remove_nodes_from(new_state.nodes())

    # Step 3: Scalar reads
    remove_scalar_reads(sdfg, {k: k for k in to_promote})

    # Step 4: Isolated nodes
    for state in sdfg.nodes():
        scalar_nodes = [
            n for n in state.nodes()
            if isinstance(n, nodes.AccessNode) and n.data in to_promote
        ]
        state.remove_nodes_from(
            [n for n in scalar_nodes if len(state.all_edges(n)) == 0])

    # Step 5: Data descriptor management
    for scalar in to_promote:
        desc = sdfg.arrays[scalar]
        sdfg.remove_data(scalar, validate=False)
        # If the scalar is already a symbol (e.g., as part of an array size),
        # do not re-add the symbol
        if scalar not in sdfg.symbols:
            sdfg.add_symbol(scalar, desc.dtype)

    # Step 6: Inter-state edge cleanup
    cleanup_re = {
        s: re.compile(fr'\b{re.escape(s)}\[.*?\]')
        for s in to_promote
    }
    promo = TaskletPromoterDict({k: k for k in to_promote})
    for edge in sdfg.edges():
        ise: InterstateEdge = edge.data
        # Condition
        if not edge.data.is_unconditional():
            if ise.condition.language is dtypes.Language.Python:
                for stmt in ise.condition.code:
                    promo.visit(stmt)
            elif ise.condition.language is dtypes.Language.CPP:
                for scalar in to_promote:
                    ise.condition = cleanup_re[scalar].sub(
                        scalar, ise.condition.as_string)

        # Assignments. The current value is re-read for every scalar so that
        # substitutions accumulate; substituting into the value captured
        # before the loop would discard all but the last scalar's cleanup
        # when one assignment references several promoted scalars.
        # Only existing keys are overwritten, so the key snapshot is stable.
        for aname in list(ise.assignments):
            for scalar in to_promote:
                current = ise.assignments[aname]
                if scalar in current:
                    ise.assignments[aname] = cleanup_re[scalar].sub(
                        scalar, current.strip())

    # Step 7: Indirection
    remove_symbol_indirection(sdfg)

    return to_promote
def apply(self, graph: SDFGState, sdfg: SDFG) -> nodes.MapEntry:
    """
    Inserts a ``warp_tile`` map inside the scope of ``self.mapentry`` and
    adapts the maps directly within it.

    The new map iterates ``__tid`` over ``0:self.warp_size`` with the
    ``GPU_ThreadBlock`` schedule. Each immediate internal map then has its
    last dimension strided by ``self.warp_size`` and offset by ``__tid``.
    If ``self.replicate_maps`` is set, a map scope with a single transient
    output that has several consumers is replicated once per additional
    consumer. Single-element write-conflict-resolution outputs are rewritten
    to go through a ``dace::warpReduce`` tasklet. Finally, the contents of
    the new scope are nested into a nested SDFG.

    :param graph: The state in which the transformation is applied.
    :param sdfg: The SDFG that contains ``graph``.
    :return: The entry node of the newly added ``warp_tile`` map.
    :raises NotImplementedError: If a write-conflict resolution is detected
                                 as a custom reduction.
    """
    me = self.mapentry

    # Add new map within map
    mx = graph.exit_node(me)
    new_me, new_mx = graph.add_map('warp_tile',
                                   dict(__tid=f'0:{self.warp_size}'),
                                   dtypes.ScheduleType.GPU_ThreadBlock)
    __tid = symbolic.pystr_to_symbolic('__tid')
    for e in graph.out_edges(me):
        xfh.reconnect_edge_through_map(graph, e, new_me, True)
    for e in graph.in_edges(mx):
        xfh.reconnect_edge_through_map(graph, e, new_mx, False)

    # Stride and offset all internal maps
    maps_to_stride = xfh.get_internal_scopes(graph, new_me, immediate=True)
    for nstate, nmap in maps_to_stride:
        nsdfg = nstate.parent
        nsdfg_node = nsdfg.parent_nsdfg_node

        # Map cannot be partitioned across a warp. The explicit comparison
        # with True is kept from the original: the size may be symbolic, in
        # which case the comparison need not evaluate to a plain boolean.
        if (nmap.range.size()[-1] < self.warp_size) == True:
            continue

        # Make the thread index visible inside nested SDFGs
        if nsdfg is not sdfg and nsdfg_node is not None:
            nsdfg_node.symbol_mapping['__tid'] = __tid
            if '__tid' not in nsdfg.symbols:
                nsdfg.add_symbol('__tid', dtypes.int32)

        nmap.range[-1] = (nmap.range[-1][0], nmap.range[-1][1] - __tid,
                          nmap.range[-1][2] * self.warp_size)
        subgraph = nstate.scope_subgraph(nmap)
        subgraph.replace(nmap.params[-1], f'{nmap.params[-1]} + __tid')
        inner_map_exit = nstate.exit_node(nmap)

        # If requested, replicate maps with multiple dependent maps
        if self.replicate_maps:
            destinations = [
                nstate.memlet_path(edge)[-1].dst
                for edge in nstate.out_edges(inner_map_exit)
            ]

            # Transformation will not replicate map with more than one
            # output
            if len(destinations) == 1:
                for dst in destinations:
                    if not isinstance(dst, nodes.AccessNode):
                        continue  # Not leading to access node
                    if not xfh.contained_in(nstate, dst, new_me):
                        continue  # Memlet path goes out of map
                    if not nsdfg.arrays[dst.data].transient:
                        continue  # Cannot modify non-transients
                    for edge in nstate.out_edges(dst)[1:]:
                        rep_subgraph = xfh.replicate_scope(
                            nsdfg, nstate, subgraph)
                        rep_edge = nstate.out_edges(
                            rep_subgraph.sink_nodes()[0])[0]
                        # Add copy of data. The descriptor is taken from the
                        # SDFG that owns ``nstate`` (``nsdfg``): ``dst`` is a
                        # node of that state, and the copy is registered in
                        # ``nsdfg`` as well.
                        newdesc = copy.deepcopy(nsdfg.arrays[dst.data])
                        newname = nsdfg.add_datadesc(dst.data,
                                                     newdesc,
                                                     find_new_name=True)
                        newaccess = nstate.add_access(newname)
                        # Redirect edges
                        xfh.redirect_edge(nstate,
                                          rep_edge,
                                          new_dst=newaccess,
                                          new_data=newname)
                        xfh.redirect_edge(nstate,
                                          edge,
                                          new_src=newaccess,
                                          new_data=newname)

        # If has WCR, add warp-collaborative reduction on outputs
        for out_edge in nstate.out_edges(inner_map_exit):
            dst = nstate.memlet_path(out_edge)[-1].dst
            if not xfh.contained_in(nstate, dst, new_me):
                # Skip edges going out of map
                continue
            if dst.desc(nsdfg).storage == dtypes.StorageType.GPU_Global:
                # Skip destinations stored in GPU global memory.
                # NOTE(review): this comment previously read "Skip shared
                # memory" although the check is for GPU_Global -- confirm
                # which storage type was intended.
                continue
            if out_edge.data.wcr is not None:
                ctype = nsdfg.arrays[out_edge.data.data].dtype.ctype
                redtype = detect_reduction_type(out_edge.data.wcr)
                if redtype == dtypes.ReductionType.Custom:
                    raise NotImplementedError(
                        'Custom reductions are not supported for '
                        'warp-collaborative reduction')
                # Enum member name, e.g. "ReductionType.Sum" -> "Sum"
                credtype = ('dace::ReductionType::' +
                            str(redtype)[str(redtype).find('.') + 1:])

                # One element: tasklet
                if out_edge.data.subset.num_elements() == 1:
                    # Add local access between thread-local and warp reduction
                    name = nsdfg._find_new_name(out_edge.data.data)
                    nsdfg.add_scalar(name,
                                     nsdfg.arrays[out_edge.data.data].dtype,
                                     transient=True)

                    # Initialize thread-local to global value
                    read = nstate.add_read(out_edge.data.data)
                    write = nstate.add_write(name)
                    edge = nstate.add_nedge(read, write,
                                            copy.deepcopy(out_edge.data))
                    edge.data.wcr = None
                    xfh.state_fission(nsdfg,
                                      SubgraphView(nstate, [read, write]))

                    # Route the map output through the new thread-local scalar
                    newnode = nstate.add_access(name)
                    nstate.remove_edge(out_edge)
                    edge = nstate.add_edge(out_edge.src, out_edge.src_conn,
                                           newnode, None,
                                           copy.deepcopy(out_edge.data))
                    for e in nstate.memlet_path(edge):
                        e.data.data = name
                        e.data.subset = subsets.Range([(0, 0, 1)])

                    # Reduce across the warp and write to the original
                    # destination without conflict resolution
                    wrt = nstate.add_tasklet(
                        'warpreduce', {'__a'}, {'__out'},
                        f'__out = dace::warpReduce<{credtype}, {ctype}>::reduce(__a);',
                        dtypes.Language.CPP)
                    nstate.add_edge(newnode, None, wrt, '__a', Memlet(name))
                    out_edge.data.wcr = None
                    nstate.add_edge(wrt, '__out', out_edge.dst, None,
                                    out_edge.data)
                else:
                    # More than one element: mapped tasklet
                    # Could be a parallel summation
                    # TODO(later): Check if reduction
                    continue
        # End of WCR to warp reduction

    # Make nested SDFG out of new scope
    xfh.nest_state_subgraph(sdfg, graph,
                            graph.scope_subgraph(new_me, False, False))

    return new_me