def _cart_create(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, dims: ShapeType):
    """
    Registers a new process-grid in the SDFG and emits the nodes that declare its MPI variables.

    The process-grid is implemented with
    [MPI_Cart_create](https://www.mpich.org/static/docs/latest/www3/MPI_Cart_create.html).

    :param pv: Program visitor of the frontend (not used by this function).
    :param sdfg: SDFG that receives the process-grid descriptor.
    :param state: State in which the declaration tasklet is placed.
    :param dims: Shape of the process-grid (see `dims` parameter of `MPI_Cart_create`), e.g., [2, 3, 3].
    :return: Name of the new process-grid descriptor.
    """
    grid = sdfg.add_pgrid(dims)
    ndims = len(dims)

    # The Dummy library node carries the MPI variable declarations into the program's state.
    from dace.libraries.mpi import Dummy
    declarations = [
        f'MPI_Comm {grid}_comm;',
        f'MPI_Group {grid}_group;',
        f'int {grid}_coords[{ndims}];',
        f'int {grid}_dims[{ndims}];',
        f'int {grid}_rank;',
        f'int {grid}_size;',
        f'bool {grid}_valid;',
    ]
    decl_node = Dummy(grid, declarations)
    state.add_node(decl_node)

    # A write to a transient scalar named after the grid keeps transformations from removing the Dummy node.
    _, marker_desc = sdfg.add_scalar(grid, dace.int32, transient=True)
    marker_node = state.add_write(grid)
    state.add_edge(decl_node, '__out', marker_node, None, Memlet.from_array(grid, marker_desc))

    return grid
def _handle_connectors(state: sd.SDFGState, node: nodes.Tasklet, mapping: Dict[str, Tuple[str, subsets.Range]],
                       ignore: Set[str], in_edges: bool) -> None:
    """
    Adds new connectors and removes unused connectors after indirection promotion.

    For every entry ``new_connector -> (original_connector, subset)`` in ``mapping``, a new connector is added to
    ``node`` and an edge parallel to the one attached to the original connector is created, carrying a memlet with
    the same data and the given subset. Afterwards, every original connector referenced by ``mapping`` (unless
    listed in ``ignore``) is removed together with its edge.

    :param state: State that contains ``node``.
    :param node: Tasklet whose connectors are rewritten in place.
    :param mapping: Maps each new connector name to a tuple of (original connector name, new subset).
    :param ignore: Original connector names that must be kept.
    :param in_edges: If True, operate on the input side of ``node``; otherwise on the output side.
    :raises KeyError: If a mapped original connector has no edge on the selected side of ``node``.
    """
    # Select the per-direction operations once instead of re-testing ``in_edges`` in every loop iteration.
    if in_edges:
        orig_edges = {e.dst_conn: e for e in state.in_edges(node)}
        add_connector = node.add_in_connector
        remove_connector = node.remove_in_connector
    else:
        orig_edges = {e.src_conn: e for e in state.out_edges(node)}
        add_connector = node.add_out_connector
        remove_connector = node.remove_out_connector

    for cname, (orig, subset) in mapping.items():
        add_connector(cname)

        # Add new edge that mirrors the original one but ends (or starts) at the new connector
        orig_edge = orig_edges[orig]
        new_memlet = mm.Memlet(data=orig_edge.data.data, subset=subset)
        if in_edges:
            state.add_edge(orig_edge.src, orig_edge.src_conn, orig_edge.dst, cname, new_memlet)
        else:
            state.add_edge(orig_edge.src, cname, orig_edge.dst, orig_edge.dst_conn, new_memlet)

    # Remove connectors and edges that were replaced
    conns_to_remove = {orig for orig, _ in mapping.values()} - ignore
    for conn in conns_to_remove:
        state.remove_edge(orig_edges[conn])
        remove_connector(conn)
def apply(self, state: SDFGState, sdfg: SDFG) -> 'Tuple[nodes.AccessNode, nodes.AccessNode]':
    """
    Replaces the matched array access node with a stream.

    A new transient stream is created (named after the array, with a fresh name if needed), every memlet on the
    paths entering and leaving the access node is rewritten to access element ``0`` of the stream, and the access
    node is replaced by a separate write node and read node of the stream.

    :param state: State containing the matched access node.
    :param sdfg: SDFG that owns the data descriptors.
    :return: A tuple ``(write_node, read_node)`` of the two new stream access nodes.
    """
    access: nodes.AccessNode = self.access

    # Get memlet paths (only the first incoming and first outgoing edge are considered)
    first_edge = state.in_edges(access)[0]
    second_edge = state.out_edges(access)[0]
    first_mpath = state.memlet_path(first_edge)
    second_mpath = state.memlet_path(second_edge)

    # Create new stream of shape 1
    desc = sdfg.arrays[access.data]
    name, newdesc = sdfg.add_stream(access.data,
                                    desc.dtype,
                                    buffer_size=self.buffer_size,
                                    storage=self.storage,
                                    transient=True,
                                    find_new_name=True)

    # Remove transient array if no other state accesses it
    used_elsewhere = any(n.data == access.data for ostate in sdfg.nodes() if ostate is not state
                         for n in ostate.data_nodes())
    if not used_elsewhere:
        del sdfg.arrays[access.data]

    # Replace memlets in both paths with stream access. The incoming path is processed before the outgoing one.
    for e in [*first_mpath, *second_mpath]:
        e.data = mm.Memlet(data=name, subset='0')
        # Edges touching a nested SDFG are marked dynamic and the nested SDFG is handed to
        # _streamify_recursive together with the connector on that side of the edge.
        if isinstance(e.src, nodes.NestedSDFG):
            e.data.dynamic = True
            _streamify_recursive(e.src, e.src_conn, newdesc)
        if isinstance(e.dst, nodes.NestedSDFG):
            e.data.dynamic = True
            _streamify_recursive(e.dst, e.dst_conn, newdesc)

    # Replace array access node with two stream access nodes
    wnode = state.add_write(name)
    rnode = state.add_read(name)
    state.remove_edge(first_edge)
    state.add_edge(first_edge.src, first_edge.src_conn, wnode, first_edge.dst_conn, first_edge.data)
    state.remove_edge(second_edge)
    state.add_edge(rnode, second_edge.src_conn, second_edge.dst, second_edge.dst_conn, second_edge.data)

    # Remove original access node
    state.remove_node(access)

    return wnode, rnode
def _modify_memlet_path(
    self,
    new_edges: Dict[nodes.Node, MultiConnectorEdge],
    nstate: SDFGState,
    state: SDFGState,
    inner_to_outer: Dict[nodes.Node, MultiConnectorEdge],
    inputs: bool,
    edges_to_ignore: Set[MultiConnectorEdge],
) -> Set[MultiConnectorEdge]:
    """
    Modifies memlet paths in an inlined SDFG.

    For each node in ``new_edges``, the edges adjacent to it in ``nstate`` (outgoing if ``inputs``, incoming
    otherwise) are re-created in ``state`` so that they attach to the corresponding outer edge, with their memlets
    unsqueezed w.r.t. the outer edge's memlet. All edges below the new edge in its memlet tree are unsqueezed too.

    :param new_edges: Maps inner nodes to the outer edge that replaces them.
    :param nstate: State of the nested SDFG from which the inner edges are read.
    :param state: Outer state that receives the new edges.
    :param inner_to_outer: Replacement endpoints for inner nodes.
    :param inputs: True to process the input side, False for the output side.
    :param edges_to_ignore: Inner edges that must be skipped.
    :return: Set of edges whose memlets were modified (the children of the new edges in their memlet trees).
    """
    modified = set()

    def _unsqueeze_subtree(tree_node, outer_memlet):
        # Record the edge and rewrite its memlet, then descend into the rest of the memlet tree
        modified.add(tree_node.edge)
        tree_node.edge._data = helpers.unsqueeze_memlet(tree_node.edge.data, outer_memlet)
        for sub in tree_node.children:
            _unsqueeze_subtree(sub, outer_memlet)

    for node, top_edge in new_edges.items():
        adjacent = nstate.out_edges(node) if inputs else nstate.in_edges(node)
        for inner_edge in adjacent:
            if inner_edge in edges_to_ignore:
                continue
            new_memlet = helpers.unsqueeze_memlet(inner_edge.data, top_edge.data)

            if inputs:
                dst = inner_to_outer.get(inner_edge.dst, inner_edge.dst)
                new_edge = state.add_edge(top_edge.src, top_edge.src_conn, dst, inner_edge.dst_conn, new_memlet)
            else:
                if inner_edge.src in inner_to_outer:
                    # don't add edges twice
                    continue
                new_edge = state.add_edge(inner_edge.src, inner_edge.src_conn, top_edge.dst, top_edge.dst_conn,
                                          new_memlet)

            # Modify all memlets going forward/backward from the new edge (the new edge itself is already done)
            for child in state.memlet_tree(new_edge).children:
                _unsqueeze_subtree(child, top_edge.data)

    return modified
def _subarray(pv: 'ProgramVisitor',
              sdfg: SDFG,
              state: SDFGState,
              array: Union[str, ShapeType],
              subarray: Union[str, ShapeType],
              dtype: dtypes.typeclass = None,
              process_grid: str = None,
              correspondence: Sequence[Integral] = None):
    """
    Adds a sub-array descriptor to the DaCe Program.

    Sub-arrays are implemented (when `process_grid` is set) with
    [MPI_Type_create_subarray](https://www.mpich.org/static/docs/v3.2/www3/MPI_Type_create_subarray.html).

    :param pv: Program visitor of the frontend (not used by this function).
    :param sdfg: SDFG that receives the sub-array descriptor.
    :param state: State in which the declaration tasklet is placed (only when `process_grid` is set).
    :param array: Either the name of an Array descriptor or the shape of the array (similar to the
                  `array_of_sizes` parameter of `MPI_Type_create_subarray`).
    :param subarray: Either the name of an Array descriptor or the sub-shape of the (sub-)array (similar to the
                     `array_of_subsizes` parameter of `MPI_Type_create_subarray`).
    :param dtype: Datatype of the array/sub-array (similar to the `oldtype` parameter of
                  `MPI_Type_create_subarray`). If not given, the datatype of `array` and then of `subarray` is
                  used, when those are descriptor names.
    :param process_grid: Name of the process-grid for collective scatter/gather operations.
    :param correspondence: Matching of the array/sub-array's dimensions to the process-grid's dimensions.
    :return: Name of the new sub-array descriptor.
    """

    def _shape_and_dtype(ref):
        # A string refers to a registered descriptor; anything else is taken to be the shape itself
        if isinstance(ref, str):
            ref_desc = sdfg.arrays[ref]
            return ref_desc.shape, ref_desc.dtype
        return ref, None

    # Get dtype, shape, and subshape
    shape, arr_dtype = _shape_and_dtype(array)
    subshape, sub_dtype = _shape_and_dtype(subarray)
    dtype = dtype or arr_dtype or sub_dtype

    subarray_name = sdfg.add_subarray(dtype, shape, subshape, process_grid, correspondence)

    # Generate subgraph only if process-grid is set, i.e., the sub-array will be used for collective
    # scatter/gather ops.
    if not process_grid:
        return subarray_name

    # Dummy tasklet adds MPI variables to the program's state.
    from dace.libraries.mpi import Dummy
    declarations = [
        f'MPI_Datatype {subarray_name};',
        f'int* {subarray_name}_counts;',
        f'int* {subarray_name}_displs;',
    ]
    decl_node = Dummy(subarray_name, declarations)
    state.add_node(decl_node)

    # Pseudo-writing to a dummy variable to avoid removal of Dummy node by transformations.
    _, marker_desc = sdfg.add_scalar(subarray_name, dace.int32, transient=True)
    marker_node = state.add_write(subarray_name)
    state.add_edge(decl_node, '__out', marker_node, None, Memlet.from_array(subarray_name, marker_desc))

    return subarray_name
def _gather(pv: 'ProgramVisitor',
            sdfg: SDFG,
            state: SDFGState,
            in_buffer: str,
            out_buffer: str,
            root: Union[str, sp.Expr, Number] = 0):
    """
    Adds an MPI Gather library node to the state and connects its buffers and root.

    :param pv: Program visitor of the frontend.
    :param sdfg: SDFG that owns the buffers.
    :param state: State in which the library node is placed.
    :param in_buffer: Name of the descriptor read by the library node.
    :param out_buffer: Name of the descriptor written by the library node.
    :param root: Either the name of an existing descriptor holding the root, or a value/expression that is
                 written into a new local scalar by a tasklet.
    :return: None.
    """
    from dace.libraries.mpi.nodes.gather import Gather

    gather_node = Gather('_Gather_')
    src_desc = sdfg.arrays[in_buffer]
    dst_desc = sdfg.arrays[out_buffer]
    src_node = state.add_read(in_buffer)
    dst_node = state.add_write(out_buffer)

    root_is_data = isinstance(root, str) and root in sdfg.arrays.keys()
    if root_is_data:
        root_node = state.add_read(root)
    else:
        # Materialize the root as a local scalar (same storage as the input) set by a small tasklet
        root_name = _define_local_scalar(pv, sdfg, state, dace.int32, src_desc.storage)
        root_node = state.add_access(root_name)
        setter = state.add_tasklet('_set_root_', {}, {'__out'}, '__out = {}'.format(root))
        state.add_edge(setter, '__out', root_node, None, Memlet.simple(root_name, '0'))

    state.add_edge(src_node, None, gather_node, '_inbuffer', Memlet.from_array(in_buffer, src_desc))
    state.add_edge(root_node, None, gather_node, '_root', Memlet.simple(root_node.data, '0'))
    state.add_edge(gather_node, '_outbuffer', dst_node, None, Memlet.from_array(out_buffer, dst_desc))

    return None
def _Reduce(pv: 'ProgramVisitor',
            sdfg: SDFG,
            state: SDFGState,
            buffer: str,
            op: str,
            root: Union[str, sp.Expr, Number] = 0,
            grid: str = None):
    """
    Adds an MPI Reduce library node that reads from and writes to the same buffer.

    :param pv: Program visitor of the frontend.
    :param sdfg: SDFG that owns the buffer.
    :param state: State in which the library node is placed.
    :param buffer: Name of the descriptor used as both input and output.
    :param op: Reduction operation passed to the library node.
    :param root: Either the name of an existing descriptor holding the root, or a value/expression that is
                 written into a new local scalar by a tasklet.
    :param grid: Optional process-grid name passed to the library node.
    :return: None.
    """
    from dace.libraries.mpi.nodes.reduce import Reduce

    reduce_node = Reduce('_Reduce_', op, grid)
    desc = sdfg.arrays[buffer]
    read_node = state.add_read(buffer)
    write_node = state.add_write(buffer)

    root_is_data = isinstance(root, str) and root in sdfg.arrays.keys()
    if root_is_data:
        root_node = state.add_read(root)
    else:
        # Materialize the root as a local scalar (same storage as the buffer) set by a small tasklet
        root_name = _define_local_scalar(pv, sdfg, state, dace.int32, desc.storage)
        root_node = state.add_access(root_name)
        setter = state.add_tasklet('_set_root_', {}, {'__out'}, '__out = {}'.format(root))
        state.add_edge(setter, '__out', root_node, None, Memlet.simple(root_name, '0'))

    state.add_edge(read_node, None, reduce_node, '_inbuffer', Memlet.from_array(buffer, desc))
    state.add_edge(root_node, None, reduce_node, '_root', Memlet.simple(root_node.data, '0'))
    state.add_edge(reduce_node, '_outbuffer', write_node, None, Memlet.from_array(buffer, desc))

    return None
def apply(self, graph: SDFGState, sdfg: SDFG):
    """
    Interchanges the matched outer and inner maps.

    The connectors of the two entries (and of the two exits) are swapped, the edges between the entries and
    between the exits are reversed, all other edges are redirected to the opposite node, and the memlets on the
    reversed edges are re-propagated.

    :param graph: State containing both maps.
    :param sdfg: SDFG containing the state (not used by this method).
    """
    # Extract the entry/exit nodes of the inner/outer maps.
    outer_entry = self.outer_map_entry
    inner_entry = self.inner_map_entry
    inner_exit = graph.exit_node(inner_entry)
    outer_exit = graph.exit_node(outer_entry)

    # Switch connectors (entries first, then exits; inputs before outputs)
    for outer, inner in ((outer_entry, inner_entry), (outer_exit, inner_exit)):
        outer.in_connectors, inner.in_connectors = inner.in_connectors, outer.in_connectors
        outer.out_connectors, inner.out_connectors = inner.out_connectors, outer.out_connectors

    # Get edges between the map entries and exits, and detach them.
    entry_edges = graph.edges_between(outer_entry, inner_entry)
    exit_edges = graph.edges_between(inner_exit, outer_exit)
    for old_edge in entry_edges + exit_edges:
        graph.remove_edge(old_edge)

    # Change source and destination of the remaining edges.
    sdutil.change_edge_dest(graph, outer_entry, inner_entry)
    sdutil.change_edge_src(graph, inner_entry, outer_entry)
    sdutil.change_edge_dest(graph, inner_exit, outer_exit)
    sdutil.change_edge_src(graph, outer_exit, inner_exit)

    # Add edges between the map entries and exits, with endpoints reversed but connectors kept.
    new_entry_edges = [graph.add_edge(e.dst, e.src_conn, e.src, e.dst_conn, e.data) for e in entry_edges]
    new_exit_edges = [graph.add_edge(e.dst, e.src_conn, e.src, e.dst_conn, e.data) for e in exit_edges]

    # Repropagate memlets in modified region
    for new_edge in new_entry_edges:
        path = graph.memlet_path(new_edge)
        position = next(i for i, edge in enumerate(path) if new_edge is edge)
        # Entry side: propagate from the next edge on the path through the outer map entry
        propagated = propagate_memlet(graph, path[position + 1].data, outer_entry, True)
        new_edge.data.subset = propagated.subset
    for new_edge in new_exit_edges:
        path = graph.memlet_path(new_edge)
        position = next(i for i, edge in enumerate(path) if new_edge is edge)
        # Exit side: propagate from the previous edge on the path through the outer map exit
        propagated = propagate_memlet(graph, path[position - 1].data, outer_exit, True)
        new_edge.data.subset = propagated.subset
def _block_gather(pv: 'ProgramVisitor',
                  sdfg: SDFG,
                  state: SDFGState,
                  in_buffer: str,
                  out_buffer: str,
                  gather_grid: str,
                  reduce_grid: str = None,
                  correspondence: Sequence[Integral] = None):
    """
    Block-gathers an Array using process-grids, sub-arrays, and the BlockGather library node.

    This method currently does not support Array slices and imperfect tiling.

    :param pv: Program visitor of the frontend.
    :param sdfg: SDFG that owns the buffers.
    :param state: State in which the library node is placed.
    :param in_buffer: Name of the (local) Array descriptor.
    :param out_buffer: Name of the (global) Array descriptor.
    :param gather_grid: Name of the sub-grid used for gathering the Array (reduction group leaders).
    :param reduce_grid: Name of the sub-grid used for broadcasting the Array (reduction groups).
    :param correspondence: Matching of the array/sub-array's dimensions to the process-grid's dimensions.
    :return: Name of the new sub-array descriptor.
    :raises ValueError: If the datatypes of the two buffers differ.
    """
    if sdfg.arrays[in_buffer].dtype != sdfg.arrays[out_buffer].dtype:
        raise ValueError("Input/output buffer datatypes must match!")

    # The global array is the "array" and the local one the "sub-array" of the new descriptor
    subarray_name = _subarray(pv,
                              sdfg,
                              state,
                              out_buffer,
                              in_buffer,
                              process_grid=gather_grid,
                              correspondence=correspondence)

    from dace.libraries.mpi import BlockGather
    gather_node = BlockGather('_BlockGather_', subarray_name, gather_grid, reduce_grid)

    local_desc = sdfg.arrays[in_buffer]
    local_node = state.add_read(in_buffer)
    local_memlet = Memlet.from_array(in_buffer, local_desc)

    global_desc = sdfg.arrays[out_buffer]
    global_node = state.add_write(out_buffer)
    global_memlet = Memlet.from_array(out_buffer, global_desc)

    state.add_edge(local_node, None, gather_node, '_inp_buffer', local_memlet)
    state.add_edge(gather_node, '_out_buffer', global_node, None, global_memlet)

    return subarray_name
def _modify_reshape_data(self, reshapes: Set[str], repldict: Dict[str, str], new_edges: Dict[str, MultiConnectorEdge],
                         nstate: SDFGState, state: SDFGState, inputs: bool):
    """
    Connects access nodes of reshaped data in the nested state to the matching outer edges.

    :param reshapes: Names of the reshaped data containers.
    :param repldict: Maps each name in ``reshapes`` to the name used by the access nodes in ``nstate``.
    :param new_edges: Maps names in ``reshapes`` to the outer edge to connect to.
    :param nstate: Nested state whose source nodes (if ``inputs``) or sink nodes (otherwise) are inspected.
    :param state: Outer state that receives the new edges.
    :param inputs: True to process the input side, False for the output side.
    """
    boundary_nodes = nstate.source_nodes() if inputs else nstate.sink_nodes()
    # Invert the replacement mapping so that a node's data name leads back to the original name
    replaced_to_orig = {repldict[orig]: orig for orig in reshapes}

    for node in boundary_nodes:
        if not isinstance(node, nodes.AccessNode) or node.data not in replaced_to_orig:
            continue
        outer_edge = new_edges[replaced_to_orig[node.data]]
        if inputs:
            state.add_edge(outer_edge.src, outer_edge.src_conn, node, None, outer_edge.data)
        else:
            state.add_edge(node, None, outer_edge.dst, outer_edge.dst_conn, outer_edge.data)
def expressions():
    """
    Builds the pattern graph for this transformation: one map exit with an edge to each of two arrays.

    :return: A single-element list containing the pattern state.
    """
    # Matching
    #   \======/
    #     |  |
    #     o  o
    pattern = SDFGState()
    for pattern_node in (OutMergeArrays._array1, OutMergeArrays._array2, OutMergeArrays._map_exit):
        pattern.add_node(pattern_node)
    for target in (OutMergeArrays._array1, OutMergeArrays._array2):
        pattern.add_edge(OutMergeArrays._map_exit, None, target, None, memlet.Memlet())
    return [pattern]
def expressions():
    """
    Builds the pattern graph for this transformation: two arrays, each with an edge into one map entry.

    :return: A single-element list containing the pattern state.
    """
    # Matching
    #     o  o
    #     |  |
    #   /======\
    pattern = SDFGState()
    for pattern_node in (InMergeArrays._array1, InMergeArrays._array2, InMergeArrays._map_entry):
        pattern.add_node(pattern_node)
    for source in (InMergeArrays._array1, InMergeArrays._array2):
        pattern.add_edge(source, None, InMergeArrays._map_entry, None, memlet.Memlet())
    return [pattern]
def _transpose(sdfg: SDFG, state: SDFGState, inpname: str):
    """
    Adds a BLAS Transpose library node that writes the transpose of a 2D array into a new transient.

    :param sdfg: SDFG that owns the input and receives the output transient.
    :param state: State in which the library node is placed.
    :param inpname: Name of the input array; its first two dimensions are swapped for the output shape.
    :return: Name of the transient holding the result.
    """
    src_desc = sdfg.arrays[inpname]
    elem_type = src_desc.dtype
    rows, cols = src_desc.shape[0], src_desc.shape[1]

    outname, dst_desc = sdfg.add_temp_transient((cols, rows), elem_type, src_desc.storage)

    src_node = state.add_read(inpname)
    dst_node = state.add_write(outname)

    # Imported here to avoid an import loop. This also binds ``dace`` as a local name, so it must stay ahead of
    # every use of ``dace`` in this function.
    import dace.libraries.blas
    transpose_node = dace.libraries.blas.Transpose('_Transpose_', elem_type)
    state.add_node(transpose_node)

    state.add_edge(src_node, None, transpose_node, '_inp', dace.Memlet.from_array(inpname, src_desc))
    state.add_edge(transpose_node, '_out', dst_node, None, dace.Memlet.from_array(outname, dst_desc))

    return outname
def _simple_call(sdfg: SDFG, state: SDFGState, inpname: str, func: str, restype: dace.typeclass = None):
    """
    Implements a simple call of the form `out = func(inp)`.

    A single tasklet is used when the input has exactly one element; otherwise an element-wise mapped tasklet
    over the whole input shape is created.

    :param sdfg: SDFG that owns the input and receives the output transient.
    :param state: State in which the computation is placed.
    :param inpname: Name of the input data container.
    :param func: Name of the function to call inside the tasklet.
    :param restype: Element type of the result; defaults to the input's element type.
    :return: Name of the transient holding the result.
    """
    inparr = sdfg.arrays[inpname]
    if restype is None:
        restype = inparr.dtype
    outname, outarr = sdfg.add_temp_transient(inparr.shape, restype, inparr.storage)

    call_code = '__out = {f}(__inp)'.format(f=func)
    num_elements = reduce(lambda x, y: x * y, inparr.shape)

    if num_elements == 1:
        # Single element: connect the containers directly to one tasklet
        src_node = state.add_read(inpname)
        dst_node = state.add_write(outname)
        call_tasklet = state.add_tasklet(func, {'__inp'}, {'__out'}, call_code)
        state.add_edge(src_node, None, call_tasklet, '__inp', Memlet.from_array(inpname, inparr))
        state.add_edge(call_tasklet, '__out', dst_node, None, Memlet.from_array(outname, outarr))
        return outname

    # General case: one map parameter per dimension, each element processed independently
    ndims = len(inparr.shape)
    element_index = ','.join(['__i%d' % i for i in range(ndims)])
    state.add_mapped_tasklet(name=func,
                             map_ranges={'__i%d' % i: '0:%s' % n
                                         for i, n in enumerate(inparr.shape)},
                             inputs={'__inp': Memlet.simple(inpname, element_index)},
                             code=call_code,
                             outputs={'__out': Memlet.simple(outname, element_index)},
                             external_edges=True)

    return outname
def _Allreduce(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, buffer: str, op: str, grid: str = None):
    """
    Adds an MPI Allreduce library node that reads from and writes to the same buffer.

    :param pv: Program visitor of the frontend (not used by this function).
    :param sdfg: SDFG that owns the buffer.
    :param state: State in which the library node is placed.
    :param buffer: Name of the descriptor used as both input and output.
    :param op: Reduction operation passed to the library node.
    :param grid: Optional process-grid name passed to the library node.
    :return: None.
    """
    from dace.libraries.mpi.nodes.allreduce import Allreduce

    allreduce_node = Allreduce('_Allreduce_', op, grid)
    desc = sdfg.arrays[buffer]
    read_node = state.add_read(buffer)
    write_node = state.add_write(buffer)

    state.add_edge(read_node, None, allreduce_node, '_inbuffer', Memlet.from_array(buffer, desc))
    state.add_edge(allreduce_node, '_outbuffer', write_node, None, Memlet.from_array(buffer, desc))

    return None
def _cart_sub(pv: 'ProgramVisitor',
              sdfg: SDFG,
              state: SDFGState,
              parent_grid: str,
              color: Sequence[Union[Integral, bool]],
              exact_grid: RankType = None):
    """
    Partitions the `parent_grid` to lower-dimensional sub-grids and adds them to the DaCe program.

    The sub-grids are implemented with
    [MPI_Cart_sub](https://www.mpich.org/static/docs/latest/www3/MPI_Cart_sub.html).

    :param pv: Program visitor of the frontend (not used by this function).
    :param sdfg: SDFG that receives the sub-grid descriptor.
    :param state: State in which the declaration tasklet is placed.
    :param parent_grid: Parent process-grid (similar to the `comm` parameter of `MPI_Cart_sub`).
    :param color: The i-th entry specifies whether the i-th dimension is kept in the sub-grid or is dropped
                  (see `remain_dims` input of `MPI_Cart_sub`).
    :param exact_grid: [DEVELOPER] If set then, out of all the sub-grids created, only the one that contains the
                       rank with id `exact_grid` will be utilized for collective communication.
    :return: Name of the new sub-grid descriptor.
    """
    grid = sdfg.add_pgrid(parent_grid=parent_grid, color=color, exact_grid=exact_grid)

    # Count sub-grid dimensions, i.e., the truthy entries of `color`.
    ndims = sum(1 for c in color if c)

    # Dummy tasklet adds MPI variables to the program's state.
    from dace.libraries.mpi import Dummy
    declarations = [
        f'MPI_Comm {grid}_comm;',
        f'MPI_Group {grid}_group;',
        f'int {grid}_coords[{ndims}];',
        f'int {grid}_dims[{ndims}];',
        f'int {grid}_rank;',
        f'int {grid}_size;',
        f'bool {grid}_valid;',
    ]
    decl_node = Dummy(grid, declarations)
    state.add_node(decl_node)

    # Pseudo-writing to a dummy variable to avoid removal of Dummy node by transformations.
    _, marker_desc = sdfg.add_scalar(grid, dace.int32, transient=True)
    marker_node = state.add_write(grid)
    state.add_edge(decl_node, '__out', marker_node, None, Memlet.from_array(grid, marker_desc))

    return grid
def gemv_libnode(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, A, x, y, alpha, beta, trans=None):
    """
    Adds a Gemv library node to the state and connects it to the given data containers.

    :param pv: Program visitor of the frontend (not used by this function).
    :param sdfg: SDFG that owns the data containers.
    :param state: State in which the library node is placed.
    :param A: Name of the matrix container.
    :param x: Name of the input vector container.
    :param y: Name of the output vector container; also read when ``beta`` is non-zero.
    :param alpha: Value passed as ``alpha`` to the library node.
    :param beta: Value passed as ``beta`` to the library node.
    :param trans: Whether ``A`` is transposed. If None, it is set to whether the first dimensions of ``x`` and
                  ``A`` are equal.
    :return: An empty list.
    """
    # Get properties
    if trans is None:
        trans = (sdfg.arrays[x].shape[0] == sdfg.arrays[A].shape[0])

    # Add nodes
    matrix_node = state.add_read(A)
    vector_node = state.add_read(x)
    result_node = state.add_write(y)

    gemv_node = Gemv('gemv', transA=trans, alpha=alpha, beta=beta)
    state.add_node(gemv_node)

    # Connect nodes
    state.add_edge(matrix_node, None, gemv_node, '_A', mm.Memlet(A))
    state.add_edge(vector_node, None, gemv_node, '_x', mm.Memlet(x))
    state.add_edge(gemv_node, '_y', result_node, None, mm.Memlet(y))

    # With a non-zero beta the previous contents of y are an input as well
    if beta != 0:
        accum_node = state.add_read(y)
        state.add_edge(accum_node, None, gemv_node, '_y', mm.Memlet(y))

    return []
def replicate_scope(sdfg: SDFG, state: SDFGState, scope: ScopeSubgraphView) -> ScopeSubgraphView:
    """
    Replicates a scope subgraph view within a state, reconnecting all external edges to the same nodes.

    Every node and internal edge of the scope is deep-copied into ``state``. Edges entering the scope's entry node
    and leaving its exit node are duplicated so that they also attach to the replica. Transient data with
    ``Scope`` allocation lifetime accessed inside the scope gets a fresh descriptor (under a new name) for the
    replica, and the memlets attached to the replicated access nodes are renamed accordingly.

    :param sdfg: The SDFG in which the subgraph scope resides.
    :param state: The SDFG state in which the subgraph scope resides.
    :param scope: The scope subgraph to replicate.
    :return: A reconnected replica of the scope.
    """
    exit_node = state.exit_node(scope.entry)

    # Replicate internal graph
    new_nodes = []
    new_entry = None
    new_exit = None
    # Original node -> its copy. Replaces a linear ``scope.nodes().index(...)`` search per edge endpoint.
    replicas = {}
    # Kept as a list (in scope order) so that new data names are assigned in a reproducible order.
    to_find_new_names = []
    for node in scope.nodes():
        node_copy = copy.deepcopy(node)
        if node == scope.entry:
            new_entry = node_copy
        elif node == exit_node:
            new_exit = node_copy

        if isinstance(node, nodes.AccessNode):
            desc = node.desc(sdfg)
            if desc.lifetime == dtypes.AllocationLifetime.Scope and desc.transient:
                to_find_new_names.append(node_copy)
        state.add_node(node_copy)
        new_nodes.append(node_copy)
        # setdefault keeps the first copy if a node were listed twice, matching list.index semantics
        replicas.setdefault(node, node_copy)

    for edge in scope.edges():
        state.add_edge(replicas[edge.src], edge.src_conn, replicas[edge.dst], edge.dst_conn,
                       copy.deepcopy(edge.data))

    # Reconnect external scope nodes
    for edge in state.in_edges(scope.entry):
        state.add_edge(edge.src, edge.src_conn, new_entry, edge.dst_conn, copy.deepcopy(edge.data))
    for edge in state.out_edges(exit_node):
        state.add_edge(new_exit, edge.src_conn, edge.dst, edge.dst_conn, copy.deepcopy(edge.data))

    # Set the exit node's map to match the entry node (each was deep-copied separately)
    new_exit.map = new_entry.map

    # Replicate all temporary transients within scope
    for node in to_find_new_names:
        desc = node.desc(sdfg)
        new_name = sdfg.add_datadesc(node.data, copy.deepcopy(desc), find_new_name=True)
        node.data = new_name
        for edge in state.all_edges(node):
            for e in state.memlet_tree(edge):
                e.data.data = new_name

    return ScopeSubgraphView(state, new_nodes, new_entry)
def nccl_reduce(pv: 'ProgramVisitor',
                sdfg: SDFG,
                state: SDFGState,
                redfunction: Callable[[Any, Any], Any],
                in_buffer: str,
                out_buffer: Union[str, None] = None,
                root: str = None,
                group_handle: str = None):
    """
    Adds an NCCL Reduce library node to the state.

    :param pv: Program visitor of the frontend.
    :param sdfg: SDFG that owns the buffers.
    :param state: State in which the library node is placed.
    :param redfunction: Reduction function passed as ``wcr`` to the library node.
    :param in_buffer: Name of the input data container.
    :param out_buffer: Name of the output data container; if None, the operation is in-place on ``in_buffer``.
    :param root: Root passed to the library node.
    :param group_handle: Optional group handle name. If it names an existing container, that container is both
                         read and written by the library node; otherwise a new local scalar is created and only
                         written.
    :return: An empty list.
    """
    inputs = {"_inbuffer"}
    outputs = {"_outbuffer"}

    uses_handle = isinstance(group_handle, str)
    gh_in = None  # stays None when the handle is created here, i.e., there is nothing to read from
    if uses_handle:
        if group_handle in sdfg.arrays.keys():
            gh_name = group_handle
            gh_out = state.add_access(gh_name)
            gh_in = state.add_access(gh_name)
            inputs.add("_group_handle")
        else:
            gh_name = _define_local_scalar(pv, sdfg, state, dace.int32, dtypes.StorageType.GPU_Global)
            gh_out = state.add_access(gh_name)
        outputs.add("_group_handle")

    libnode = Reduce(inputs=inputs, outputs=outputs, wcr=redfunction, root=root)

    if uses_handle:
        gh_memlet = Memlet.simple(gh_name, '0')
        if gh_in is not None:
            state.add_edge(gh_in, None, libnode, "_group_handle", gh_memlet)
        state.add_edge(libnode, "_group_handle", gh_out, None, gh_memlet)

    # If out_buffer is not specified, the operation will be in-place.
    if out_buffer is None:
        out_buffer = in_buffer

    # Add nodes
    src_node = state.add_read(in_buffer)
    dst_node = state.add_write(out_buffer)

    # Connect nodes
    state.add_edge(src_node, None, libnode, '_inbuffer', Memlet(in_buffer))
    state.add_edge(libnode, '_outbuffer', dst_node, None, Memlet(out_buffer))

    return []
def _wait(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, request: str):
    """
    Adds an MPI Wait library node for the given request.

    Two single-element ``int32`` transients are created and connected to the library node's ``_stat_source`` and
    ``_stat_tag`` outputs.

    :param pv: Program visitor of the frontend (not used by this function).
    :param sdfg: SDFG that owns the request container.
    :param state: State in which the library node is placed.
    :param request: Name of the request container, or a ``(name, range)`` tuple selecting part of it.
    :return: None.
    """
    from dace.libraries.mpi.nodes.wait import Wait

    wait_node = Wait('_Wait_')

    # A tuple carries an explicit range; a plain name means the whole container
    if isinstance(request, tuple):
        req_name, req_range = request
    else:
        req_name, req_range = request, None

    req_desc = sdfg.arrays[req_name]
    req_node = state.add_access(req_name)

    src_name, src_desc = sdfg.add_temp_transient([1], dtypes.int32)
    src_node = state.add_write(src_name)
    tag_name, tag_desc = sdfg.add_temp_transient([1], dtypes.int32)
    tag_node = state.add_write(tag_name)

    req_memlet = Memlet.simple(req_name, req_range) if req_range else Memlet.from_array(req_name, req_desc)

    state.add_edge(req_node, None, wait_node, '_request', req_memlet)
    state.add_edge(wait_node, '_stat_source', src_node, None, Memlet.from_array(src_name, src_desc))
    state.add_edge(wait_node, '_stat_tag', tag_node, None, Memlet.from_array(tag_name, tag_desc))

    return None
def _wait(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, request: str):
    """
    Adds an MPI Waitall library node for the given request(s).

    :param pv: Program visitor of the frontend (not used by this function).
    :param sdfg: SDFG that owns the request container.
    :param state: State in which the library node is placed.
    :param request: Name of the request container, or a ``(name, range)`` tuple selecting part of it.
    :return: None.
    """
    from dace.libraries.mpi.nodes.wait import Waitall

    waitall_node = Waitall('_Waitall_')

    # A tuple carries an explicit range; a plain name means the whole container
    if isinstance(request, tuple):
        req_name, req_range = request
    else:
        req_name, req_range = request, None

    req_desc = sdfg.arrays[req_name]
    req_node = state.add_access(req_name)

    req_memlet = Memlet.simple(req_name, req_range) if req_range else Memlet.from_array(req_name, req_desc)

    state.add_edge(req_node, None, waitall_node, '_request', req_memlet)

    return None
def _modify_access_to_access(
    self,
    input_edges: Dict[nodes.Node, MultiConnectorEdge],
    nsdfg: SDFG,
    nstate: SDFGState,
    state: SDFGState,
    orig_data: Dict[Union[nodes.AccessNode, MultiConnectorEdge], str],
) -> Set[MultiConnectorEdge]:
    """
    Deals with access->access edges where both sides are non-transient.

    For every inner edge that leaves an input node and ends at a non-transient access node, a direct edge is
    added in the outer state from the source of the outer input edge to the destination of the outer edge
    leaving the connector named after the inner destination's original data.

    :param input_edges: Maps inner input nodes to their outer edges.
    :param nsdfg: Nested SDFG whose descriptors decide transience.
    :param nstate: State of the nested SDFG from which inner edges are read.
    :param state: Outer state that receives the new edges.
    :param orig_data: Maps inner nodes to their original data names.
    :return: Set of inner edges that were handled.
    """
    handled = set()
    for node, top_edge in input_edges.items():
        for inner_edge in nstate.out_edges(node):
            target = inner_edge.dst
            if target not in orig_data:
                continue
            inner_data = orig_data[target]
            if not isinstance(target, nodes.AccessNode) or nsdfg.arrays[inner_data].transient:
                continue

            # First outer edge leaving the connector named after the inner data
            matching_edge: MultiConnectorEdge = next(state.out_edges_by_connector(top_edge.dst, inner_data))

            # Create memlet by unsqueezing both w.r.t. src and dst subsets: the source-side result is the
            # base, and the destination-side subset becomes its other_subset
            new_memlet = helpers.unsqueeze_memlet(inner_edge.data, top_edge.data)
            dst_side = helpers.unsqueeze_memlet(inner_edge.data, matching_edge.data)
            new_memlet.other_subset = dst_side.subset

            # Connect with new edge
            state.add_edge(top_edge.src, top_edge.src_conn, matching_edge.dst, matching_edge.dst_conn, new_memlet)
            handled.add(inner_edge)

    return handled
def nccl_send(pv: 'ProgramVisitor',
              sdfg: SDFG,
              state: SDFGState,
              in_buffer: str,
              peer: symbolic.SymbolicType = 0,
              group_handle: str = None):
    """
    Adds an NCCL Send library node to the state.

    :param pv: Program visitor of the frontend.
    :param sdfg: SDFG that owns the buffer.
    :param state: State in which the library node is placed.
    :param in_buffer: Name of the data container to send, or a ``(name, range)`` tuple selecting part of it.
    :param peer: Peer passed to the library node.
    :param group_handle: Optional group handle name. If it names an existing container, that container is both
                         read and written by the library node; otherwise a new local scalar is created and only
                         written.
    :return: An empty list.
    """
    inputs = {"_inbuffer"}
    outputs = set()

    uses_handle = isinstance(group_handle, str)
    gh_in = None  # stays None when the handle is created here, i.e., there is nothing to read from
    if uses_handle:
        if group_handle in sdfg.arrays.keys():
            gh_name = group_handle
            gh_out = state.add_access(gh_name)
            gh_in = state.add_access(gh_name)
            inputs.add("_group_handle")
        else:
            gh_name = _define_local_scalar(pv, sdfg, state, dace.int32, dtypes.StorageType.GPU_Global)
            gh_out = state.add_access(gh_name)
        outputs.add("_group_handle")

    libnode = Send(inputs=inputs, outputs=outputs, peer=peer)

    if uses_handle:
        gh_memlet = Memlet.simple(gh_name, '0')
        if gh_in is not None:
            state.add_edge(gh_in, None, libnode, "_group_handle", gh_memlet)
        state.add_edge(libnode, "_group_handle", gh_out, None, gh_memlet)

    # A tuple carries an explicit range; a plain name means the whole container
    if isinstance(in_buffer, tuple):
        in_name, in_range = in_buffer
    else:
        in_name, in_range = in_buffer, None

    desc = sdfg.arrays[in_name]

    # Retype the '_buffer' input connector as a pointer to the data's element type.
    # NOTE(review): this function only ever adds '_inbuffer' and '_group_handle' as inputs, so unless Send
    # itself defines a '_buffer' connector this leaves all connector types unchanged -- confirm intent.
    libnode.in_connectors = {
        c: (dtypes.pointer(desc.dtype) if c == '_buffer' else t)
        for c, t in libnode.in_connectors.items()
    }

    src_node = state.add_read(in_name)
    buf_memlet = Memlet.simple(in_name, in_range) if in_range else Memlet.from_array(in_name, desc)
    state.add_edge(src_node, None, libnode, '_inbuffer', buf_memlet)

    return []
def redirect_edge(state: SDFGState,
                  edge: graph.MultiConnectorEdge[Memlet],
                  new_src: Optional[nodes.Node] = None,
                  new_dst: Optional[nodes.Node] = None,
                  new_src_conn: Optional[str] = None,
                  new_dst_conn: Optional[str] = None,
                  new_data: Optional[str] = None,
                  new_memlet: Optional[Memlet] = None) -> graph.MultiConnectorEdge[Memlet]:
    """
    Redirects an edge in a state. Choose which elements to override by setting the keyword arguments.

    The given edge is removed from the state and a new edge is added in its place; every element that is not
    overridden is taken from the removed edge.

    :param state: The SDFG state in which the edge resides.
    :param edge: The edge to redirect.
    :param new_src: If provided, redirects the source of the new edge.
    :param new_dst: If provided, redirects the destination of the new edge.
    :param new_src_conn: If provided, renames the source connector of the edge.
    :param new_dst_conn: If provided, renames the destination connector of the edge.
    :param new_data: If provided, changes the data on the memlet of the edge, and the entire associated memlet
                     tree.
    :param new_memlet: If provided, changes only the memlet of the new edge.
    :return: The new, redirected edge.
    :raises ValueError: If both ``new_data`` and ``new_memlet`` are given.
    :note: ``new_data`` and ``new_memlet`` cannot be used at the same time.
    """
    rename_data = new_data is not None
    if rename_data and new_memlet is not None:
        raise ValueError('new_data and new_memlet cannot both be given.')

    # The memlet tree has to be collected while the edge is still part of the state
    mtree = state.memlet_tree(edge) if rename_data else None
    state.remove_edge(edge)

    if rename_data:
        memlet = copy.deepcopy(edge.data)
        memlet.data = new_data
        # Rename on full memlet tree
        for tree_edge in mtree:
            tree_edge.data.data = new_data
    else:
        memlet = new_memlet or edge.data

    src = new_src or edge.src
    src_conn = new_src_conn or edge.src_conn
    dst = new_dst or edge.dst
    dst_conn = new_dst_conn or edge.dst_conn
    return state.add_edge(src, src_conn, dst, dst_conn, memlet)
def nccl_recv(pv: 'ProgramVisitor',
              sdfg: SDFG,
              state: SDFGState,
              out_buffer: str,
              peer: symbolic.SymbolicType = 0,
              group_handle: str = None):
    """
    Adds an NCCL Recv library node to the state.

    :param pv: Program visitor of the frontend.
    :param sdfg: SDFG that owns the buffer.
    :param state: State in which the library node is placed.
    :param out_buffer: Name of an existing data container to receive into, or a ``(name, range)`` tuple
                       selecting part of it. Without a range, the output memlet covers index ``0``.
    :param peer: Peer passed to the library node.
    :param group_handle: Optional group handle name. If it names an existing container, that container is both
                         read and written by the library node; otherwise a new local scalar is created and only
                         written.
    :return: An empty list.
    :raises ValueError: If ``out_buffer`` is neither a tuple nor the name of an existing container.
    """
    inputs = set()
    outputs = {"_outbuffer"}

    uses_handle = isinstance(group_handle, str)
    gh_in = None  # stays None when the handle is created here, i.e., there is nothing to read from
    if uses_handle:
        if group_handle in sdfg.arrays.keys():
            gh_name = group_handle
            gh_out = state.add_access(gh_name)
            gh_in = state.add_access(gh_name)
            inputs.add("_group_handle")
        else:
            gh_name = _define_local_scalar(pv, sdfg, state, dace.int32, dtypes.StorageType.GPU_Global)
            gh_out = state.add_access(gh_name)
        outputs.add("_group_handle")

    libnode = Recv(inputs=inputs, outputs=outputs, peer=peer)

    if uses_handle:
        gh_memlet = Memlet.simple(gh_name, '0')
        if gh_in is not None:
            state.add_edge(gh_in, None, libnode, "_group_handle", gh_memlet)
        state.add_edge(libnode, "_group_handle", gh_out, None, gh_memlet)

    # Resolve the target container and the optional range
    if isinstance(out_buffer, tuple):
        out_name, out_range = out_buffer
    elif isinstance(out_buffer, str) and out_buffer in sdfg.arrays.keys():
        out_name, out_range = out_buffer, None
    else:
        raise ValueError("NCCL_Recv out_buffer must be an array, or a an array range tuple.")
    dst_node = state.add_write(out_name)

    out_memlet = Memlet.simple(out_name, out_range if out_range else '0')
    state.add_edge(libnode, '_outbuffer', dst_node, None, out_memlet)

    return []
def ger_libnode(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, A, x, y, output, alpha):
    """
    Adds a Ger library node to the state and connects it to the given data containers.

    :param pv: Program visitor of the frontend (not used by this function).
    :param sdfg: SDFG that owns the data containers (not used by this function).
    :param state: State in which the library node is placed.
    :param A: Name of the container connected to the ``_A`` input.
    :param x: Name of the container connected to the ``_x`` input.
    :param y: Name of the container connected to the ``_y`` input.
    :param output: Name of the container connected to the ``_res`` output.
    :param alpha: Value passed as ``alpha`` to the library node.
    :return: An empty list.
    """
    # Add nodes
    matrix_node = state.add_read(A)
    x_node = state.add_read(x)
    y_node = state.add_read(y)
    result_node = state.add_write(output)

    ger_node = Ger('ger', alpha=alpha)
    state.add_node(ger_node)

    # Connect nodes
    for source, connector, data_name in ((matrix_node, '_A', A), (x_node, '_x', x), (y_node, '_y', y)):
        state.add_edge(source, None, ger_node, connector, mm.Memlet(data_name))
    state.add_edge(ger_node, '_res', result_node, None, mm.Memlet(output))

    return []
def replicate_scope(sdfg: SDFG, state: SDFGState, scope: ScopeSubgraphView) -> ScopeSubgraphView:
    """
    Replicates a scope subgraph view within a state, reconnecting all external edges to the same nodes.

    Every node and internal edge of the scope is deep-copied into ``state``. Edges entering the scope's entry node
    and leaving its exit node are duplicated so that they also attach to the replica.

    :param sdfg: The SDFG in which the subgraph scope resides.
    :param state: The SDFG state in which the subgraph scope resides.
    :param scope: The scope subgraph to replicate.
    :return: A reconnected replica of the scope.
    """
    exit_node = state.exit_node(scope.entry)

    # Replicate internal graph
    new_nodes = []
    new_entry = None
    new_exit = None
    # Original node -> its copy. Replaces a linear ``scope.nodes().index(...)`` search per edge endpoint.
    replicas = {}
    for node in scope.nodes():
        node_copy = copy.deepcopy(node)
        if node == scope.entry:
            new_entry = node_copy
        elif node == exit_node:
            new_exit = node_copy

        state.add_node(node_copy)
        new_nodes.append(node_copy)
        # setdefault keeps the first copy if a node were listed twice, matching list.index semantics
        replicas.setdefault(node, node_copy)

    for edge in scope.edges():
        state.add_edge(replicas[edge.src], edge.src_conn, replicas[edge.dst], edge.dst_conn,
                       copy.deepcopy(edge.data))

    # Reconnect external scope nodes
    for edge in state.in_edges(scope.entry):
        state.add_edge(edge.src, edge.src_conn, new_entry, edge.dst_conn, copy.deepcopy(edge.data))
    for edge in state.out_edges(exit_node):
        state.add_edge(new_exit, edge.src_conn, edge.dst, edge.dst_conn, copy.deepcopy(edge.data))

    # Set the exit node's map to match the entry node (each was deep-copied separately)
    new_exit.map = new_entry.map

    return ScopeSubgraphView(state, new_nodes, new_entry)
def _array_x_binop(visitor: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, op1: str, op2: str, op: str,
                   opcode: str):
    """
    Implements a binary operation between two data containers.

    If both operands are scalars, a single tasklet computing ``s1 <opcode> s2`` is emitted; otherwise the work
    is delegated to ``_binop``.

    :param visitor: Program visitor; its ``numbers`` mapping is used to recover the type of numeric constants.
    :param sdfg: SDFG that owns the operands and receives the result transient.
    :param state: State in which the computation is placed.
    :param op1: Name of the first operand.
    :param op2: Name of the second operand.
    :param op: Name of the operator (used in the tasklet name and to decide whether the result is boolean).
    :param opcode: Operator token inserted into the tasklet code.
    :return: Name of the data container holding the result.
    """

    def _inspect(name):
        # Returns (descriptor, element type, is-scalar) for an operand. For scalars that were registered as
        # numeric constants, the type is looked up through the visitor instead of the descriptor.
        operand_desc = sdfg.arrays[name]
        elem_type = operand_desc.dtype.type
        is_scalar = _is_scalar(sdfg, name)
        if is_scalar and (name in visitor.numbers.values()):
            elem_type = inverse_dict_lookup(visitor.numbers, name)
        return operand_desc, elem_type, is_scalar

    arr1, type1, isscal1 = _inspect(op1)
    arr2, type2, isscal2 = _inspect(op2)

    if _is_op_boolean(op):
        restype = dace.bool
    else:
        restype = dace.DTYPE_TO_TYPECLASS[np.result_type(type1, type2).type]

    if not (isscal1 and isscal2):
        return _binop(sdfg, state, op1, op2, opcode, op, restype)

    # Scalar-scalar case: one tasklet writing into a single-element transient (storage of the second operand)
    op3, arr3 = sdfg.add_temp_transient([1], restype, arr2.storage)
    tasklet = state.add_tasklet('_SS%s_' % op, {'s1', 's2'}, {'s3'}, 's3 = s1 %s s2' % opcode)
    lhs_node = state.add_read(op1)
    rhs_node = state.add_read(op2)
    res_node = state.add_write(op3)
    state.add_edge(lhs_node, None, tasklet, 's1', dace.Memlet.from_array(op1, arr1))
    state.add_edge(rhs_node, None, tasklet, 's2', dace.Memlet.from_array(op2, arr2))
    state.add_edge(tasklet, 's3', res_node, None, dace.Memlet.from_array(op3, arr3))
    return op3
def apply(self, graph: SDFGState, sdfg: SDFG):
    """
    Fuses the matched reduce node into the preceding map.

    The intermediate access node (``self.in_array``) and the reduce node
    (``self.reduce``) are removed. The memlet entering the map exit that wrote
    the intermediate data is redirected to the output array with the
    reduction's write-conflict resolution, and the map exit is connected to
    the reduce node's former destination. Unless ``self.no_init`` is set or
    the reduction has no identity, a state initializing the output with the
    identity is added before ``graph``.

    :param graph: The state containing the matched nodes.
    :param sdfg: The SDFG containing ``graph``.
    :raises RuntimeError: If no edge into the map exit carries the
                          intermediate array.
    """
    map_exit = self.tmap_exit
    transient_node = self.in_array
    reduction = self.reduce
    result_node = self.out_array

    # Both the intermediate access node and the reduce node are dropped
    doomed = [transient_node, reduction]

    # Locate the memlet into the map exit that writes the intermediate array
    memlet_edge = next((e for e in graph.in_edges(map_exit) if e.data.data == transient_node.data), None)
    if memlet_edge is None:
        raise RuntimeError('Reduction memlet cannot be None')

    # Capture everything needed from the reduce node before it is removed.
    # With no explicit axes, every dimension of the input is reduced.
    reduce_input = graph.in_edges(reduction)[0]
    axes = reduction.axes
    if not axes:
        axes = list(range(len(reduce_input.data.subset)))
    reduce_output = graph.out_edges(reduction)[0]

    # Delete relevant edges and nodes
    graph.remove_nodes_from(doomed)

    # Delete relevant data descriptors
    for candidate in set(doomed):
        if not isinstance(candidate, nodes.AccessNode):
            continue
        try:
            sdfg.remove_data(candidate.data)
        except ValueError:
            # Raised if the data descriptor is used somewhere else; keep it
            pass

    # Keep only the dimensions that are not reduced away
    kept_dims = [dim for idx, dim in enumerate(memlet_edge.data.subset) if idx not in axes]
    if not kept_dims:
        # Output is a scalar
        kept_dims = [(0, 0, 1)]

    # Redirect the edge from the tasklet to the map exit
    memlet_edge.data.data = result_node.data
    memlet_edge.data.wcr = reduction.wcr
    memlet_edge.data.subset = type(memlet_edge.data.subset)(kept_dims)

    # Add edge from map exit to output array
    outer_memlet = Memlet.simple(reduce_output.data.data,
                                 reduce_output.data.subset,
                                 num_accesses=reduce_output.data.num_accesses,
                                 wcr_str=reduction.wcr)
    graph.add_edge(memlet_edge.dst, 'OUT_' + memlet_edge.dst_conn[3:], reduce_output.dst, reduce_output.dst_conn,
                   outer_memlet)

    # Add initialization state only when requested and an identity exists
    if self.no_init or reduction.identity is None:
        return

    init_state = sdfg.add_state_before(graph)
    out_subset = reduce_output.data.subset
    init_ranges = [('o%d' % i, '%s:%s:%s' % (r[0], r[1] + 1, r[2])) for i, r in enumerate(out_subset)]
    init_index = ','.join(['o%d' % i for i in range(len(out_subset))])
    init_state.add_mapped_tasklet('freduce_init',
                                  init_ranges, {},
                                  '__out = %s' % reduction.identity,
                                  {'__out': Memlet.simple(reduce_output.data.data, init_index)},
                                  external_edges=True)
def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG):
    """
    Splits the matched map into one map per component of its body.

    Two cases are handled, selected by ``self.expr_index``:

    * ``0``: the map body is the scope subgraph of ``self.map_entry`` in
      ``graph``.
    * otherwise: the map body is the nested SDFG ``self.nested_sdfg``; every
      state of that nested SDFG is processed, and the nested SDFG node is
      reconnected around the map at the end.

    For each processed (state, subgraph) pair, the components returned by
    ``MapFission._components`` are each wrapped in a new map carrying a copy
    of the outer map's range, border arrays are extended by the map
    dimensions, and memlets are adjusted accordingly. Finally, the original
    map entry and exit are removed from ``graph``.

    :param graph: The state containing the matched map.
    :param sdfg: The SDFG containing ``graph``.
    """
    map_entry = self.map_entry
    map_exit = graph.exit_node(map_entry)
    nsdfg_node: Optional[nodes.NestedSDFG] = None

    # Obtain subgraph to perform fission to
    if self.expr_index == 0:  # Map with subgraph
        subgraphs = [(graph, graph.scope_subgraph(map_entry, include_entry=False, include_exit=False))]
        parent = sdfg
    else:  # Map with nested SDFG
        nsdfg_node = self.nested_sdfg
        subgraphs = [(state, state) for state in nsdfg_node.sdfg.nodes()]
        parent = nsdfg_node.sdfg
    # Names of arrays already extended by the map dimensions; shared across
    # all processed states so that each descriptor is augmented only once
    modified_arrays = set()

    # Get map information
    outer_map: nodes.Map = map_entry.map
    mapsize = outer_map.range.size()

    # Add new symbols from outer map to nested SDFG
    if self.expr_index == 1:
        # NOTE(review): the set returned by ``free_symbols`` is updated in
        # place below; presumably it is a fresh set -- confirm.
        map_syms = outer_map.range.free_symbols
        for edge in graph.out_edges(map_entry):
            if edge.data.data:
                map_syms.update(edge.data.subset.free_symbols)
        for edge in graph.in_edges(map_exit):
            if edge.data.data:
                map_syms.update(edge.data.subset.free_symbols)
        for sym in map_syms:
            symname = str(sym)
            # Map parameters are handled separately below
            if symname in outer_map.params:
                continue
            if symname not in nsdfg_node.symbol_mapping.keys():
                nsdfg_node.symbol_mapping[symname] = sym
                nsdfg_node.sdfg.symbols[symname] = graph.symbols_defined_at(nsdfg_node)[symname]

        # Remove map symbols from nested mapping
        for name in outer_map.params:
            if str(name) in nsdfg_node.symbol_mapping:
                del nsdfg_node.symbol_mapping[str(name)]
            if str(name) in nsdfg_node.sdfg.symbols:
                del nsdfg_node.sdfg.symbols[str(name)]

    for state, subgraph in subgraphs:
        components = MapFission._components(subgraph)
        # Sources and sinks are captured before the subgraph is modified
        sources = subgraph.source_nodes()
        sinks = subgraph.sink_nodes()

        # Collect external edges
        if self.expr_index == 0:
            external_edges_entry = list(state.out_edges(map_entry))
            external_edges_exit = list(state.in_edges(map_exit))
        else:
            # In a nested SDFG, edges touching non-transient access nodes
            # are treated as the external ones
            external_edges_entry = [
                e for e in subgraph.edges()
                if (isinstance(e.src, nodes.AccessNode) and not nsdfg_node.sdfg.arrays[e.src.data].transient)
            ]
            external_edges_exit = [
                e for e in subgraph.edges()
                if (isinstance(e.dst, nodes.AccessNode) and not nsdfg_node.sdfg.arrays[e.dst.data].transient)
            ]

        # Map external edges to outer memlets
        edge_to_outer = {}
        for edge in external_edges_entry:
            if self.expr_index == 0:
                # Subgraphs use the corresponding outer map edges
                path = state.memlet_path(edge)
                eindex = path.index(edge)
                edge_to_outer[edge] = path[eindex - 1]
            else:
                # Nested SDFGs use the internal map edges of the node
                outer_edge = next(e for e in graph.in_edges(nsdfg_node) if e.dst_conn == edge.src.data)
                edge_to_outer[edge] = outer_edge

        for edge in external_edges_exit:
            if self.expr_index == 0:
                path = state.memlet_path(edge)
                eindex = path.index(edge)
                edge_to_outer[edge] = path[eindex + 1]
            else:
                # Nested SDFGs use the internal map edges of the node
                outer_edge = next(e for e in graph.out_edges(nsdfg_node) if e.src_conn == edge.dst.data)
                edge_to_outer[edge] = outer_edge

        # Collect all border arrays and code->code edges
        arrays = MapFission._border_arrays(nsdfg_node.sdfg if self.expr_index == 1 else sdfg, state, subgraph)
        # Maps a data name to the edges that leave a component and lead
        # directly into another code node
        scalars = defaultdict(list)
        for _, component_out in components:
            for e in subgraph.out_edges(component_out):
                if isinstance(e.dst, nodes.CodeNode):
                    scalars[e.data.data].append(e)

        # Create new arrays for scalars
        for scalar, edges in scalars.items():
            # The old descriptor is dropped and replaced by a transient
            # sized by the map range
            desc = parent.arrays[scalar]
            del parent.arrays[scalar]
            # NOTE(review): ``newdesc`` is not read within this method
            name, newdesc = parent.add_transient(scalar,
                                                 mapsize,
                                                 desc.dtype,
                                                 desc.storage,
                                                 lifetime=desc.lifetime,
                                                 debuginfo=desc.debuginfo,
                                                 allow_conflicts=desc.allow_conflicts,
                                                 find_new_name=True)

            # Add extra nodes in component boundaries
            for edge in edges:
                anode = state.add_access(name)
                sbs = subsets.Range.from_string(','.join(outer_map.params))
                # Offset memlet by map range begin (to fit the transient)
                sbs.offset([r[0] for r in outer_map.range], True)
                state.add_edge(edge.src, edge.src_conn, anode, None,
                               mm.Memlet.simple(name, sbs, num_accesses=outer_map.range.num_elements()))
                state.add_edge(anode, None, edge.dst, edge.dst_conn,
                               mm.Memlet.simple(name, sbs, num_accesses=outer_map.range.num_elements()))
                state.remove_edge(edge)

        # Add extra maps around components
        # NOTE(review): ``new_map_entries`` is appended to but not read
        # within this method
        new_map_entries = []
        for component_in, component_out in components:
            # The '0:1' ranges are placeholders; the actual range is copied
            # from the outer map right below
            me, mx = state.add_map(outer_map.label + '_fission', [(p, '0:1') for p in outer_map.params],
                                   outer_map.schedule,
                                   unroll=outer_map.unroll,
                                   debuginfo=outer_map.debuginfo)

            # Add dynamic input connectors
            for conn in map_entry.in_connectors:
                if not conn.startswith('IN_'):
                    me.add_in_connector(conn)

            me.map.range = dcpy(outer_map.range)
            new_map_entries.append(me)

            # Reconnect edges through new map
            for e in state.in_edges(component_in):
                state.add_edge(me, None, e.dst, e.dst_conn, dcpy(e.data))
                # Reconnect inner edges at source directly to external nodes
                if self.expr_index == 0 and e in external_edges_entry:
                    state.add_edge(edge_to_outer[e].src, edge_to_outer[e].src_conn, me, None,
                                   dcpy(edge_to_outer[e].data))
                else:
                    state.add_edge(e.src, e.src_conn, me, None, dcpy(e.data))
                state.remove_edge(e)
            # Empty memlet edge in nested SDFGs
            if state.in_degree(component_in) == 0:
                state.add_edge(me, None, component_in, None, mm.Memlet())

            for e in state.out_edges(component_out):
                state.add_edge(e.src, e.src_conn, mx, None, dcpy(e.data))
                # Reconnect inner edges at sink directly to external nodes
                if self.expr_index == 0 and e in external_edges_exit:
                    state.add_edge(mx, None, edge_to_outer[e].dst, edge_to_outer[e].dst_conn,
                                   dcpy(edge_to_outer[e].data))
                else:
                    state.add_edge(mx, None, e.dst, e.dst_conn, dcpy(e.data))
                state.remove_edge(e)
            # Empty memlet edge in nested SDFGs
            if state.out_degree(component_out) == 0:
                state.add_edge(component_out, None, mx, None, mm.Memlet())

        # Connect other sources/sinks not in components (access nodes)
        # directly to external nodes
        if self.expr_index == 0:
            for node in sources:
                if isinstance(node, nodes.AccessNode):
                    for edge in state.in_edges(node):
                        outer_edge = edge_to_outer[edge]
                        memlet = dcpy(edge.data)
                        # Prepend the full map range to the inner subset
                        memlet.subset = subsets.Range(outer_map.range.ranges + memlet.subset.ranges)
                        state.add_edge(outer_edge.src, outer_edge.src_conn, edge.dst, edge.dst_conn, memlet)

            for node in sinks:
                if isinstance(node, nodes.AccessNode):
                    for edge in state.out_edges(node):
                        outer_edge = edge_to_outer[edge]
                        state.add_edge(edge.src, edge.src_conn, outer_edge.dst, outer_edge.dst_conn,
                                       dcpy(outer_edge.data))

        # Augment arrays by prepending map dimensions
        for array in arrays:
            if array in modified_arrays:
                continue
            desc = parent.arrays[array]
            if isinstance(desc, dt.Scalar):  # Scalar needs to be augmented to an array
                desc = dt.Array(desc.dtype, desc.shape, desc.transient, desc.allow_conflicts, desc.storage,
                                desc.location, desc.strides, desc.offset, False, desc.lifetime, 0, desc.debuginfo,
                                desc.total_size, desc.start_offset)
                parent.arrays[array] = desc
            # Each prepended dimension strides over everything accumulated
            # so far, so strides are built from the innermost map dimension
            # outwards
            for sz in reversed(mapsize):
                desc.strides = [desc.total_size] + list(desc.strides)
                desc.total_size = desc.total_size * sz

            desc.shape = mapsize + list(desc.shape)
            desc.offset = [0] * len(mapsize) + list(desc.offset)
            modified_arrays.add(array)

        # Fill scope connectors so that memlets can be tracked below
        state.fill_scope_connectors()

        # Correct connectors and memlets in nested SDFGs to account for
        # missing outside map
        if self.expr_index == 1:
            to_correct = ([(e, e.src) for e in external_edges_entry] + [(e, e.dst) for e in external_edges_exit])
            corrected_nodes = set()
            for edge, node in to_correct:
                if isinstance(node, nodes.AccessNode):
                    # Each access node is corrected at most once
                    if node in corrected_nodes:
                        continue
                    corrected_nodes.add(node)

                    outer_edge = edge_to_outer[edge]
                    desc = parent.arrays[node.data]

                    # Modify shape of internal array to match outer one
                    outer_desc = sdfg.arrays[outer_edge.data.data]
                    if not isinstance(desc, dt.Scalar):
                        desc.shape = outer_desc.shape
                    if isinstance(desc, dt.Array):
                        desc.strides = outer_desc.strides
                        desc.total_size = outer_desc.total_size

                    # Inside the nested SDFG, offset all memlets to include
                    # the offsets from within the map.
                    # NOTE: Relies on propagation to fix outer memlets
                    for internal_edge in state.all_edges(node):
                        for e in state.memlet_tree(internal_edge):
                            e.data.subset.offset(desc.offset, False)
                            e.data.subset = helpers.unsqueeze_memlet(e.data, outer_edge.data).subset

                    # Only after offsetting memlets we can modify the
                    # overall offset
                    if isinstance(desc, dt.Array):
                        desc.offset = outer_desc.offset

        # Fill in memlet trees for border transients
        # NOTE: Memlet propagation should run to correct the outer edges
        for node in subgraph.nodes():
            if isinstance(node, nodes.AccessNode) and node.data in arrays:
                for edge in state.all_edges(node):
                    for e in state.memlet_tree(edge):
                        # Prepend map dimensions to memlet
                        e.data.subset = subsets.Range([(pystr_to_symbolic(d) - r[0], pystr_to_symbolic(d) - r[0], 1)
                                                       for d, r in zip(outer_map.params, outer_map.range)] +
                                                      e.data.subset.ranges)

    # If nested SDFG, reconnect nodes around map and modify memlets
    if self.expr_index == 1:
        for edge in graph.in_edges(map_entry):
            # Only pass-through ('IN_*') connectors have a matching inner edge
            if not edge.dst_conn or not edge.dst_conn.startswith('IN_'):
                continue

            # Modify edge coming into nested SDFG to include entire array
            desc = sdfg.arrays[edge.data.data]
            edge.data.subset = subsets.Range.from_array(desc)
            edge.data.num_accesses = edge.data.subset.num_elements()

            # Find matching edge inside map
            # (the slices presumably strip the 'OUT_' / 'IN_' connector
            # prefixes so that the remaining names can be compared)
            inner_edge = next(e for e in graph.out_edges(map_entry)
                              if e.src_conn and e.src_conn[4:] == edge.dst_conn[3:])
            graph.add_edge(edge.src, edge.src_conn, nsdfg_node, inner_edge.dst_conn, dcpy(edge.data))

        for edge in graph.out_edges(map_exit):
            # Modify edge coming out of nested SDFG to include entire array
            desc = sdfg.arrays[edge.data.data]
            edge.data.subset = subsets.Range.from_array(desc)

            # Find matching edge inside map
            inner_edge = next(e for e in graph.in_edges(map_exit) if e.dst_conn[3:] == edge.src_conn[4:])
            graph.add_edge(nsdfg_node, inner_edge.src_conn, edge.dst, edge.dst_conn, dcpy(edge.data))

    # Remove outer map
    graph.remove_nodes_from([map_entry, map_exit])