def _wait(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, request: str):
    """Lower a wait on a single request into an MPI ``Wait`` library node.

    :param request: Name of the request container, or a ``(name, range)``
                    tuple selecting a subset of it.
    :return: None. Two fresh single-element int32 transients receive the
             node's ``_stat_source`` and ``_stat_tag`` outputs.
    """
    from dace.libraries.mpi.nodes.wait import Wait

    libnode = Wait('_Wait_')

    # Accept either a bare name or a (name, range) pair.
    if isinstance(request, tuple):
        req_name, req_range = request
    else:
        req_name, req_range = request, None

    req_desc = sdfg.arrays[req_name]
    req_node = state.add_access(req_name)

    # Scratch containers for the status outputs of the library node.
    source_name, source_desc = sdfg.add_temp_transient([1], dtypes.int32)
    source_node = state.add_write(source_name)
    tag_name, tag_desc = sdfg.add_temp_transient([1], dtypes.int32)
    tag_node = state.add_write(tag_name)

    req_memlet = (Memlet.simple(req_name, req_range)
                  if req_range else Memlet.from_array(req_name, req_desc))

    state.add_edge(req_node, None, libnode, '_request', req_memlet)
    state.add_edge(libnode, '_stat_source', source_node, None,
                   Memlet.from_array(source_name, source_desc))
    state.add_edge(libnode, '_stat_tag', tag_node, None,
                   Memlet.from_array(tag_name, tag_desc))
    return None
def _gather(pv: 'ProgramVisitor',
            sdfg: SDFG,
            state: SDFGState,
            in_buffer: str,
            out_buffer: str,
            root: Union[str, sp.Expr, Number] = 0):
    """Lower a gather call into an MPI ``Gather`` library node.

    :param in_buffer: Name of the array connected to ``_inbuffer``.
    :param out_buffer: Name of the array connected to ``_outbuffer``.
    :param root: Name of an existing container holding the root rank, or a
                 value/expression that is written into a fresh int32 scalar
                 by a ``_set_root_`` tasklet.
    :return: None.
    """
    from dace.libraries.mpi.nodes.gather import Gather

    libnode = Gather('_Gather_')

    in_desc = sdfg.arrays[in_buffer]
    out_desc = sdfg.arrays[out_buffer]
    in_node = state.add_read(in_buffer)
    out_node = state.add_write(out_buffer)

    root_is_container = isinstance(root, str) and root in sdfg.arrays.keys()
    if root_is_container:
        root_node = state.add_read(root)
    else:
        # Materialize the root rank as a local scalar in the input's storage.
        root_name = _define_local_scalar(pv, sdfg, state, dace.int32, in_desc.storage)
        root_node = state.add_access(root_name)
        set_root = state.add_tasklet('_set_root_', {}, {'__out'}, '__out = {}'.format(root))
        state.add_edge(set_root, '__out', root_node, None, Memlet.simple(root_name, '0'))

    state.add_edge(in_node, None, libnode, '_inbuffer', Memlet.from_array(in_buffer, in_desc))
    state.add_edge(root_node, None, libnode, '_root', Memlet.simple(root_node.data, '0'))
    state.add_edge(libnode, '_outbuffer', out_node, None, Memlet.from_array(out_buffer, out_desc))
    return None
def _Reduce(pv: 'ProgramVisitor',
            sdfg: SDFG,
            state: SDFGState,
            buffer: str,
            op: str,
            root: Union[str, sp.Expr, Number] = 0,
            grid: str = None):
    """Lower a reduce on ``buffer`` into an MPI ``Reduce`` library node.

    The same container is read through ``_inbuffer`` and written through
    ``_outbuffer``.

    :param op: Reduction operation, forwarded to the library node.
    :param root: Name of an existing container holding the root rank, or a
                 value/expression written into a fresh int32 scalar.
    :param grid: Optional process-grid name, forwarded to the library node.
    :return: None.
    """
    from dace.libraries.mpi.nodes.reduce import Reduce

    libnode = Reduce('_Reduce_', op, grid)
    desc = sdfg.arrays[buffer]
    read_node = state.add_read(buffer)
    write_node = state.add_write(buffer)

    if not (isinstance(root, str) and root in sdfg.arrays.keys()):
        # Root is a literal/expression: store it in a local int32 scalar.
        root_name = _define_local_scalar(pv, sdfg, state, dace.int32, desc.storage)
        root_node = state.add_access(root_name)
        root_tasklet = state.add_tasklet('_set_root_', {}, {'__out'}, '__out = {}'.format(root))
        state.add_edge(root_tasklet, '__out', root_node, None, Memlet.simple(root_name, '0'))
    else:
        root_node = state.add_read(root)

    state.add_edge(read_node, None, libnode, '_inbuffer', Memlet.from_array(buffer, desc))
    state.add_edge(root_node, None, libnode, '_root', Memlet.simple(root_node.data, '0'))
    state.add_edge(libnode, '_outbuffer', write_node, None, Memlet.from_array(buffer, desc))
    return None
def _block_gather(pv: 'ProgramVisitor',
                  sdfg: SDFG,
                  state: SDFGState,
                  in_buffer: str,
                  out_buffer: str,
                  gather_grid: str,
                  reduce_grid: str = None,
                  correspondence: Sequence[Integral] = None):
    """
    Block-gathers an Array using process-grids, sub-arrays, and the BlockGather library node.
    This method currently does not support Array slices and imperfect tiling.

    :param in_buffer: Name of the (local) Array descriptor.
    :param out_buffer: Name of the (global) Array descriptor.
    :param gather_grid: Name of the sub-grid used for gathering the Array (reduction group leaders).
    :param reduce_grid: Name of the sub-grid defining the reduction groups (forwarded to BlockGather).
    :param correspondence: Matching of the array/sub-array's dimensions to the process-grid's dimensions.
    :raises ValueError: If the input and output buffers have different datatypes.
    :return: Name of the new sub-array descriptor.
    """
    in_desc = sdfg.arrays[in_buffer]
    out_desc = sdfg.arrays[out_buffer]

    if in_desc.dtype != out_desc.dtype:
        raise ValueError("Input/output buffer datatypes must match!")

    # The sub-array describes how the local block maps into the global array.
    subarray_name = _subarray(pv,
                              sdfg,
                              state,
                              out_buffer,
                              in_buffer,
                              process_grid=gather_grid,
                              correspondence=correspondence)

    from dace.libraries.mpi import BlockGather
    libnode = BlockGather('_BlockGather_', subarray_name, gather_grid, reduce_grid)

    # _subarray only registers new containers, so the descriptors fetched
    # above are still the current ones; no need to look them up again.
    inbuf_node = state.add_read(in_buffer)
    outbuf_node = state.add_write(out_buffer)

    state.add_edge(inbuf_node, None, libnode, '_inp_buffer', Memlet.from_array(in_buffer, in_desc))
    state.add_edge(libnode, '_out_buffer', outbuf_node, None, Memlet.from_array(out_buffer, out_desc))

    return subarray_name
def _reduce(sdfg: SDFG,
            state: SDFGState,
            redfunction: Callable[[Any, Any], Any],
            in_array: str,
            out_array=None,
            axis=None,
            identity=None):
    """Add a ``Reduce`` node reading ``in_array`` to ``state``.

    :param redfunction: Binary reduction function passed to ``state.add_reduce``.
    :param in_array: Name of the input array. It is also parsed as a Python
                     expression to obtain the input memlet subset.
    :param out_array: Optional name of the output array (also parsed as an
                      expression for its subset). If ``None``, a temporary
                      transient is created with the input's dtype/storage and
                      shape equal to the input subset with the reduced axes
                      removed (``[1]`` when ``axis`` is ``None``).
    :param axis: An axis or a tuple/list of axes to reduce; ``None`` reduces all.
    :param identity: Optional identity value forwarded to ``add_reduce``.
    :return: The name of the created transient when ``out_array`` is ``None``,
             otherwise an empty list.
    """
    inarr = in_array

    # Convert axes to a tuple of symbolic expressions (shared by both paths).
    if axis is not None and not isinstance(axis, (tuple, list)):
        axis = (axis, )
    if axis is not None:
        axis = tuple(pystr_to_symbolic(a) for a in axis)

    # Input memlet is computed identically whether or not an output is given.
    input_subset = parse_memlet_subset(sdfg.arrays[inarr], ast.parse(in_array).body[0].value, {})
    input_memlet = Memlet.simple(inarr, input_subset)

    if out_array is None:
        # Derive the output shape by dropping the reduced dimensions.
        if axis is None:
            output_shape = [1]
        else:
            output_subset = copy.deepcopy(input_subset)
            output_subset.pop(axis)
            output_shape = output_subset.size()
        outarr, arr = sdfg.add_temp_transient(output_shape, sdfg.arrays[inarr].dtype,
                                              sdfg.arrays[inarr].storage)
        output_memlet = Memlet.from_array(outarr, arr)
    else:
        outarr = out_array
        output_subset = parse_memlet_subset(sdfg.arrays[outarr], ast.parse(out_array).body[0].value, {})
        output_memlet = Memlet.simple(outarr, output_subset)

    # Create reduce subgraph
    inpnode = state.add_read(inarr)
    rednode = state.add_reduce(redfunction, axis, identity)
    outnode = state.add_write(outarr)
    state.add_nedge(inpnode, rednode, input_memlet)
    state.add_nedge(rednode, outnode, output_memlet)

    if out_array is None:
        return outarr
    return []
def _cart_create(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, dims: ShapeType):
    """
    Creates a process-grid and adds it to the DaCe program. The process-grid is implemented with [MPI_Cart_create](https://www.mpich.org/static/docs/latest/www3/MPI_Cart_create.html).

    :param dims: Shape of the process-grid (see `dims` parameter of `MPI_Cart_create`), e.g., [2, 3, 3].
    :return: Name of the new process-grid descriptor.
    """
    pgrid_name = sdfg.add_pgrid(dims)
    ndims = len(dims)

    # The Dummy node carries the declarations of the grid's MPI variables.
    from dace.libraries.mpi import Dummy
    declarations = [
        f'MPI_Comm {pgrid_name}_comm;',
        f'MPI_Group {pgrid_name}_group;',
        f'int {pgrid_name}_coords[{ndims}];',
        f'int {pgrid_name}_dims[{ndims}];',
        f'int {pgrid_name}_rank;',
        f'int {pgrid_name}_size;',
        f'bool {pgrid_name}_valid;',
    ]
    dummy = Dummy(pgrid_name, declarations)
    state.add_node(dummy)

    # Give the Dummy node an output so transformations do not prune it.
    _, placeholder_desc = sdfg.add_scalar(pgrid_name, dace.int32, transient=True)
    placeholder_node = state.add_write(pgrid_name)
    state.add_edge(dummy, '__out', placeholder_node, None,
                   Memlet.from_array(pgrid_name, placeholder_desc))

    return pgrid_name
def test_gpu_dma():
    """Run ``gpu_dma`` with its map on GPU 0, ``gpu_X`` on GPU 1 and a GPU-0 copy of ``alpha``.

    The scalar ``alpha`` is cloned into a transient GPU-global container
    (location ``{'gpu': 0}``). All access nodes and memlets referring to
    ``alpha`` are redirected to the clone, and a host-to-device copy is
    added to the start state. The result must match ``X * alpha``.
    """
    sdfg: dace.SDFG = gpu_dma.to_sdfg(strict=True)
    sdfg.name = 'gpu_dma'
    sdfg.apply_transformations(GPUTransformSDFG, options={'strict_transform': False})

    map_ = next(n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, nodes.MapEntry))
    add_gpu_location(sdfg, map_, 0)

    sdfg.arrays['gpu_X'].location = {'gpu': 1}

    # clone GPU scalar
    inodename = 'alpha'
    inode = sdfg.arrays['alpha']
    newdesc = inode.clone()
    newdesc.location = {'gpu': 0}
    newdesc.storage = StorageType.GPU_Global
    newdesc.transient = True
    name = sdfg.add_datadesc('gpu_' + inodename, newdesc, find_new_name=True)

    # Replace original scalar
    for state in sdfg.nodes():
        for node in state.nodes():
            if isinstance(node, nodes.AccessNode) and node.data == inodename:
                node.data = name

    # Replace memlets
    for state in sdfg.nodes():
        for edge in state.edges():
            if edge.data.data == inodename:
                edge.data.data = name

    # add GPU scalar to the copyin state
    copyin_state = sdfg.start_state
    src_array = nodes.AccessNode(inodename, debuginfo=inode.debuginfo)
    dst_array = nodes.AccessNode(name, debuginfo=inode.debuginfo)
    copyin_state.add_node(src_array)
    copyin_state.add_node(dst_array)
    copyin_state.add_nedge(src_array, dst_array,
                           Memlet.from_array(src_array.data, src_array.desc(sdfg)))

    sdfg.apply_strict_transformations()

    np.random.seed(0)
    n = 16
    # np.ndarray(...) returns uninitialized memory (possibly NaN/Inf), which
    # would make the comparison below nondeterministic; use seeded data.
    X = np.random.rand(n).astype(np_dtype)
    alpha = np.ndarray(shape=1, dtype=np_dtype)
    alpha.fill(np.random.rand())

    a_times_X = sdfg(X=X, alpha=alpha[0], N=n)
    res = X * alpha

    # Report every mismatching element (per element, not the whole output).
    idx = zip(*np.where(~np.isclose(res, a_times_X, atol=0, rtol=1e-7)))
    for i in idx:
        print(i, res[i], a_times_X[i], X[i] * alpha, X[i], alpha)
    assert np.allclose(res, a_times_X)
    print('PASS')
def _simple_call(sdfg: SDFG, state: SDFGState, inpname: str, func: str, restype: dace.typeclass = None):
    """Implements a simple call of the form `out = func(inp)`.

    The result is written to a new temporary transient with the input's shape
    and storage. Its dtype is ``restype``, defaulting to the input's dtype.
    Single-element inputs use one tasklet; larger inputs use a mapped tasklet
    over every index.

    :return: Name of the output transient.
    """
    inparr = sdfg.arrays[inpname]
    restype = inparr.dtype if restype is None else restype
    outname, outarr = sdfg.add_temp_transient(inparr.shape, restype, inparr.storage)

    code = '__out = {f}(__inp)'.format(f=func)
    num_elements = reduce(lambda acc, dim: acc * dim, inparr.shape)

    if num_elements == 1:
        read_node = state.add_read(inpname)
        write_node = state.add_write(outname)
        tasklet = state.add_tasklet(func, {'__inp'}, {'__out'}, code)
        state.add_edge(read_node, None, tasklet, '__inp', Memlet.from_array(inpname, inparr))
        state.add_edge(tasklet, '__out', write_node, None, Memlet.from_array(outname, outarr))
        return outname

    # One map parameter per dimension: __i0, __i1, ...
    index_names = ['__i%d' % d for d in range(len(inparr.shape))]
    element = ','.join(index_names)
    state.add_mapped_tasklet(
        name=func,
        map_ranges={idx: '0:%s' % size for idx, size in zip(index_names, inparr.shape)},
        inputs={'__inp': Memlet.simple(inpname, element)},
        code=code,
        outputs={'__out': Memlet.simple(outname, element)},
        external_edges=True)
    return outname
def _Allreduce(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, buffer: str, op: str, grid: str = None):
    """Lower an allreduce on ``buffer`` into an MPI ``Allreduce`` library node.

    ``buffer`` is both read (``_inbuffer``) and written (``_outbuffer``).

    :param op: Reduction operation, forwarded to the library node.
    :param grid: Optional process-grid name, forwarded to the library node.
    :return: None.
    """
    from dace.libraries.mpi.nodes.allreduce import Allreduce

    libnode = Allreduce('_Allreduce_', op, grid)
    desc = sdfg.arrays[buffer]
    src_node, dst_node = state.add_read(buffer), state.add_write(buffer)

    state.add_edge(src_node, None, libnode, '_inbuffer', Memlet.from_array(buffer, desc))
    state.add_edge(libnode, '_outbuffer', dst_node, None, Memlet.from_array(buffer, desc))
    return None
def create_batch_gemm_sdfg(dtype, strides):
    """Build a one-state SDFG named ``'einsum'`` computing ``Z = X @ Y`` via ``blas.MatMul``.

    :param dtype: Element type of ``X``, ``Y`` and ``Z``.
    :param strides: Mapping with keys ``BATCH``, ``sAM``, ``sAK``, ``sAB``,
                    ``sBK``, ``sBN``, ``sBB``, ``sCM``, ``sCN``, ``sCB``.
                    Symbolic entries become a fresh symbol with the key's
                    name; other entries are used as given. When
                    ``strides['BATCH'] != 1`` every array gets a leading
                    batch dimension.
    :return: The constructed SDFG.
    """
    sdfg = SDFG('einsum')
    state = sdfg.add_state()
    M, K, N = (symbolic.symbol(s) for s in ['M', 'K', 'N'])

    def _resolve(key):
        value = strides[key]
        return symbolic.symbol(key) if symbolic.issymbolic(value) else value

    BATCH, sAM, sAK, sAB, sBK, sBN, sBB, sCM, sCN, sCB = map(
        _resolve, ['BATCH', 'sAM', 'sAK', 'sAB', 'sBK', 'sBN', 'sBB', 'sCM', 'sCN', 'sCB'])

    batched = strides['BATCH'] != 1

    # (name, batched shape, batched strides); the unbatched layout drops the
    # leading entry of both lists.
    layouts = [
        ('X', [BATCH, M, K], [sAB, sAM, sAK]),
        ('Y', [BATCH, K, N], [sBB, sBK, sBN]),
        ('Z', [BATCH, M, N], [sCB, sCM, sCN]),
    ]
    descs = {}
    for arr_name, arr_shape, arr_strides in layouts:
        if not batched:
            arr_shape, arr_strides = arr_shape[1:], arr_strides[1:]
        descs[arr_name] = sdfg.add_array(arr_name, dtype=dtype, shape=arr_shape, strides=arr_strides)[1]

    gX = state.add_read('X')
    gY = state.add_read('Y')
    gZ = state.add_write('Z')

    import dace.libraries.blas as blas  # Avoid import loop

    libnode = blas.MatMul('einsum_gemm')
    state.add_node(libnode)
    state.add_edge(gX, None, libnode, '_a', Memlet.from_array(gX.data, descs['X']))
    state.add_edge(gY, None, libnode, '_b', Memlet.from_array(gY.data, descs['Y']))
    state.add_edge(libnode, '_c', gZ, None, Memlet.from_array(gZ.data, descs['Z']))

    return sdfg
def nccl_send(pv: 'ProgramVisitor',
              sdfg: SDFG,
              state: SDFGState,
              in_buffer: str,
              peer: symbolic.SymbolicType = 0,
              group_handle: str = None):
    """Lower an NCCL send into a ``Send`` library node.

    :param in_buffer: Name of the array to send, or a ``(name, range)`` tuple.
    :param peer: Peer rank, forwarded to the library node.
    :param group_handle: Optional name of a group-handle scalar. If it already
                         exists in the SDFG, it is both read and written through
                         ``_group_handle``. Otherwise a new int32 GPU-global
                         scalar is defined and only written.
    :return: Empty list (the call produces no value).
    """
    inputs = {"_inbuffer"}
    outputs = set()
    if isinstance(group_handle, str):
        gh_start = False
        if group_handle in sdfg.arrays.keys():
            gh_name = group_handle
            gh_out = state.add_access(gh_name)
            gh_in = state.add_access(gh_name)
            inputs.add("_group_handle")
        else:
            gh_start = True
            gh_name = _define_local_scalar(pv, sdfg, state, dace.int32, dtypes.StorageType.GPU_Global)
            gh_out = state.add_access(gh_name)
        outputs.add("_group_handle")

    libnode = Send(inputs=inputs, outputs=outputs, peer=peer)

    if isinstance(group_handle, str):
        gh_memlet = Memlet.simple(gh_name, '0')
        if not gh_start:
            state.add_edge(gh_in, None, libnode, "_group_handle", gh_memlet)
        state.add_edge(libnode, "_group_handle", gh_out, None, gh_memlet)

    in_range = None
    if isinstance(in_buffer, tuple):
        in_name, in_range = in_buffer
    else:
        in_name = in_buffer

    desc = sdfg.arrays[in_name]
    # Type the buffer connector as a pointer so the send always receives an
    # address. The connector on this node is "_inbuffer"; the previous check
    # for "_buffer" (the MPI connector name) never matched.
    libnode.in_connectors = {
        c: (dtypes.pointer(desc.dtype) if c == '_inbuffer' else t)
        for c, t in libnode.in_connectors.items()
    }
    in_node = state.add_read(in_name)

    if in_range:
        buf_mem = Memlet.simple(in_name, in_range)
    else:
        buf_mem = Memlet.from_array(in_name, desc)

    state.add_edge(in_node, None, libnode, '_inbuffer', buf_mem)

    return []
def _subarray(pv: 'ProgramVisitor',
              sdfg: SDFG,
              state: SDFGState,
              array: Union[str, ShapeType],
              subarray: Union[str, ShapeType],
              dtype: dtypes.typeclass = None,
              process_grid: str = None,
              correspondence: Sequence[Integral] = None):
    """
    Adds a sub-array descriptor to the DaCe Program.
    Sub-arrays are implemented (when `process_grid` is set) with [MPI_Type_create_subarray](https://www.mpich.org/static/docs/v3.2/www3/MPI_Type_create_subarray.html).

    :param array: Either the name of an Array descriptor or the shape of the array (similar to the `array_of_sizes` parameter of `MPI_Type_create_subarray`).
    :param subarray: Either the name of an Array descriptor or the sub-shape of the (sub-)array (similar to the `array_of_subsizes` parameter of `MPI_Type_create_subarray`).
    :param dtype: Datatype of the array/sub-array (similar to the `oldtype` parameter of `MPI_Type_create_subarray`). Defaults to the dtype of `array`, then of `subarray`, when those are descriptor names.
    :param process_grid: Name of the process-grid for collective scatter/gather operations.
    :param correspondence: Matching of the array/sub-array's dimensions to the process-grid's dimensions.
    :return: Name of the new sub-array descriptor.
    """

    def _shape_and_dtype(spec):
        # A string names an existing descriptor; anything else is a raw shape.
        if not isinstance(spec, str):
            return spec, None
        desc = sdfg.arrays[spec]
        return desc.shape, desc.dtype

    shape, arr_dtype = _shape_and_dtype(array)
    subshape, sub_dtype = _shape_and_dtype(subarray)
    dtype = dtype or arr_dtype or sub_dtype

    subarray_name = sdfg.add_subarray(dtype, shape, subshape, process_grid, correspondence)

    # Without a process-grid the sub-array is not used for collective
    # scatter/gather, so no MPI declarations are needed.
    if not process_grid:
        return subarray_name

    from dace.libraries.mpi import Dummy
    dummy = Dummy(subarray_name, [
        f'MPI_Datatype {subarray_name};',
        f'int* {subarray_name}_counts;',
        f'int* {subarray_name}_displs;'
    ])
    state.add_node(dummy)

    # Give the Dummy node an output so transformations do not prune it.
    _, placeholder_desc = sdfg.add_scalar(subarray_name, dace.int32, transient=True)
    placeholder_node = state.add_write(subarray_name)
    state.add_edge(dummy, '__out', placeholder_node, None,
                   Memlet.from_array(subarray_name, placeholder_desc))

    return subarray_name
def _cart_sub(pv: 'ProgramVisitor',
              sdfg: SDFG,
              state: SDFGState,
              parent_grid: str,
              color: Sequence[Union[Integral, bool]],
              exact_grid: RankType = None):
    """
    Partitions the `parent_grid` to lower-dimensional sub-grids and adds them to the DaCe program.
    The sub-grids are implemented with [MPI_Cart_sub](https://www.mpich.org/static/docs/latest/www3/MPI_Cart_sub.html).

    :param parent_grid: Parent process-grid (similar to the `comm` parameter of `MPI_Cart_sub`).
    :param color: The i-th entry specifies whether the i-th dimension is kept in the sub-grid or is dropped (see `remain_dims` input of `MPI_Cart_sub`).
    :param exact_grid: [DEVELOPER] If set then, out of all the sub-grids created, only the one that contains the rank with id `exact_grid` will be utilized for collective communication.
    :return: Name of the new sub-grid descriptor.
    """
    pgrid_name = sdfg.add_pgrid(parent_grid=parent_grid, color=color, exact_grid=exact_grid)

    # The sub-grid keeps one dimension per truthy entry of `color`.
    kept_dims = sum(1 for keep in color if keep)

    # The Dummy node carries the declarations of the sub-grid's MPI variables.
    from dace.libraries.mpi import Dummy
    declarations = [
        f'MPI_Comm {pgrid_name}_comm;',
        f'MPI_Group {pgrid_name}_group;',
        f'int {pgrid_name}_coords[{kept_dims}];',
        f'int {pgrid_name}_dims[{kept_dims}];',
        f'int {pgrid_name}_rank;',
        f'int {pgrid_name}_size;',
        f'bool {pgrid_name}_valid;',
    ]
    dummy = Dummy(pgrid_name, declarations)
    state.add_node(dummy)

    # Give the Dummy node an output so transformations do not prune it.
    _, placeholder_desc = sdfg.add_scalar(pgrid_name, dace.int32, transient=True)
    placeholder_node = state.add_write(pgrid_name)
    state.add_edge(dummy, '__out', placeholder_node, None,
                   Memlet.from_array(pgrid_name, placeholder_desc))

    return pgrid_name
def _wait(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, request: str):
    """Lower a wait on a request container into an MPI ``Waitall`` library node.

    :param request: Name of the request container, or a ``(name, range)``
                    tuple selecting a subset of it.
    :return: None.
    """
    from dace.libraries.mpi.nodes.wait import Waitall

    libnode = Waitall('_Waitall_')

    req_name, req_range = request if isinstance(request, tuple) else (request, None)

    req_desc = sdfg.arrays[req_name]
    req_node = state.add_access(req_name)

    if not req_range:
        req_memlet = Memlet.from_array(req_name, req_desc)
    else:
        req_memlet = Memlet.simple(req_name, req_range)

    state.add_edge(req_node, None, libnode, '_request', req_memlet)
    return None
def make_sdfg(implementation, dtype, storage=dace.StorageType.Default):
    """Build an SDFG that solves ``A x = B`` with LAPACK ``getrf`` followed by ``getrs``.

    With the default storage, ``getrf``/``getrs`` operate directly on ``A``
    and ``B``. With any other ``storage``, ``A`` and ``B`` are copied into
    transient ``*_device`` arrays. ``A`` is transposed by a cuBLAS
    ``Transpose`` node before factorization, and the solution is copied back
    into the host ``B``.

    :param implementation: Implementation name assigned to both LAPACK nodes.
    :param dtype: Element type of ``A`` and ``B``.
    :param storage: Storage type for the working arrays.
    :return: The constructed SDFG.
    """
    n = dace.symbol("n")

    suffix = "_device" if storage != dace.StorageType.Default else ""
    transient = storage != dace.StorageType.Default

    sdfg = dace.SDFG("matrix_solve_getrf_getrs_{}_{}".format(implementation, dtype))
    state = sdfg.add_state("dataflow")

    Ahost_arr = sdfg.add_array("A", [n, n], dtype, storage=dace.StorageType.Default)
    Bhost_arr = sdfg.add_array("B", [n], dtype, storage=dace.StorageType.Default)

    if transient:
        A_arr = sdfg.add_array("A" + suffix, [n, n], dtype, storage=storage, transient=transient)
        AT_arr = sdfg.add_array('AT' + suffix, [n, n], dtype, storage=storage, transient=transient)
        sdfg.add_array("B" + suffix, [n], dtype, storage=storage, transient=transient)
    sdfg.add_array("pivots" + suffix, [n], dace.dtypes.int32, storage=storage, transient=transient)
    sdfg.add_array("result_getrf" + suffix, [1], dace.dtypes.int32, storage=storage, transient=transient)
    sdfg.add_array("result_getrs" + suffix, [1], dace.dtypes.int32, storage=storage, transient=transient)

    if transient:
        Ahi = state.add_read("A")
        Ai = state.add_access("A" + suffix)
        Ain = state.add_access("AT" + suffix)
        Aout = state.add_access("AT" + suffix)
        Bhi = state.add_read("B")
        # Bho receives the device-to-host copy of the solution, so it is a write.
        Bho = state.add_write("B")
        Bin = state.add_access("B" + suffix)
        Bout = state.add_access("B" + suffix)
        transpose_in = blas.nodes.transpose.Transpose("transpose_in", dtype=dtype)
        transpose_in.implementation = "cuBLAS"
        state.add_nedge(Ahi, Ai, Memlet.from_array(*Ahost_arr))
        state.add_nedge(Bhi, Bin, Memlet.from_array(*Bhost_arr))
        state.add_nedge(Bout, Bho, Memlet.from_array(*Bhost_arr))
        state.add_memlet_path(Ai, transpose_in, dst_conn='_inp', memlet=Memlet.from_array(*A_arr))
        state.add_memlet_path(transpose_in, Ain, src_conn='_out', memlet=Memlet.from_array(*AT_arr))
    else:
        Ain = state.add_access("A" + suffix)
        Aout = state.add_access("A" + suffix)
        Bin = state.add_access("B" + suffix)
        Bout = state.add_access("B" + suffix)

    pivots = state.add_access("pivots" + suffix)
    res_getrf = state.add_access("result_getrf" + suffix)
    res_getrs = state.add_access("result_getrs" + suffix)

    getrf_node = lapack.nodes.getrf.Getrf("getrf")
    getrf_node.implementation = implementation
    getrs_node = lapack.nodes.getrs.Getrs("getrs")
    getrs_node.implementation = implementation

    # getrf: factorize A in place and produce pivots + status.
    state.add_memlet_path(Ain, getrf_node, dst_conn="_xin",
                          memlet=Memlet.simple(Ain, "0:n, 0:n", num_accesses=n * n))
    state.add_memlet_path(getrf_node, res_getrf, src_conn="_res",
                          memlet=Memlet.simple(res_getrf, "0", num_accesses=1))
    state.add_memlet_path(getrs_node, res_getrs, src_conn="_res",
                          memlet=Memlet.simple(res_getrs, "0", num_accesses=1))
    state.add_memlet_path(getrf_node, pivots, src_conn="_ipiv",
                          memlet=Memlet.simple(pivots, "0:n", num_accesses=n))
    # getrs: consume the factors and pivots to solve for the right-hand side.
    state.add_memlet_path(pivots, getrs_node, dst_conn="_ipiv",
                          memlet=Memlet.simple(pivots, "0:n", num_accesses=n))
    state.add_memlet_path(getrf_node, Aout, src_conn="_xout",
                          memlet=Memlet.simple(Aout, "0:n, 0:n", num_accesses=n * n))
    state.add_memlet_path(Aout, getrs_node, dst_conn="_a",
                          memlet=Memlet.simple(Aout, "0:n, 0:n", num_accesses=n * n))
    state.add_memlet_path(Bin, getrs_node, dst_conn="_rhs_in",
                          memlet=Memlet.simple(Bin, "0:n", num_accesses=n))
    state.add_memlet_path(getrs_node, Bout, src_conn="_rhs_out",
                          memlet=Memlet.simple(Bout, "0:n", num_accesses=n))

    return sdfg
def mm_small(
    state,
    A_node,
    B_node,
    C_node,
    A_subset=None,
    B_subset=None,
    C_subset=None,
    A_memlet=None,
    B_memlet=None,
    C_memlet=None,
    map_entry=None,
    map_exit=None,
    A_direct=True,
    B_direct=True,
):
    """Add a sequential triple-loop product ``C += A @ B`` to ``state``.

    Exactly one of ``A_node``/``B_node`` must be a ``str``. Its name is
    pasted into the tasklet code and indexed there (presumably an SDFG
    constant; callers in this file pass names registered with
    ``add_constant``). The other operand is a data node whose elements
    reach the tasklet through the new map.

    The sequential map iterates ``i2`` (rows of C), ``i3`` (contraction
    dimension) and ``i4`` (columns of C). Results are accumulated into
    ``C_node`` with a ``lambda a,b: a+b`` write-conflict resolution.

    :param state: State to add the map/tasklet to (``state.parent`` is the SDFG).
    :param A_subset / B_subset / C_subset: Optional 2-D shapes used instead of
        the full descriptor shapes when sizing the map.
    :param A_memlet / B_memlet: Optional outer memlet routed from ``map_entry``
        into the new map. Its last two subset entries (their start values) are
        appended to the tasklet's index.
    :param C_memlet: Optional outer memlet routed from the new map's exit to
        ``map_exit``. If omitted, the result goes to ``C_node`` (or to
        ``map_exit`` if given) with a full-array memlet.
    :param map_entry / map_exit: Enclosing map scope, if any.
    :param A_direct / B_direct: When no operand memlet is given, connect the
        new map directly from the data node (True) or from ``map_entry`` (False).
    :raises AssertionError: If neither ``A_node`` nor ``B_node`` is a ``str``.
    """
    # C = A@B
    sdfg = state.parent
    Cshape = C_node.desc(sdfg).shape if C_subset is None else C_subset
    # Map extents: [rows of C, contraction length, cols of C].
    if isinstance(B_node, str):
        someshape = list(
            A_node.desc(sdfg).shape) if A_subset is None else A_subset
        someshape = [someshape[0]] + [someshape[1]] + [Cshape[1]]
    elif isinstance(A_node, str):
        someshape = list(
            B_node.desc(sdfg).shape) if B_subset is None else B_subset
        someshape = [Cshape[0]] + [someshape[0]] + [someshape[1]]
    else:
        raise AssertionError(
            "one of these have to be an str, else don't call this function")
    mapRange = ["0:" + str(_s) for _s in someshape]
    mapParams = ["i2", "i3", "i4"]
    mmapEntry, mmapExit = state.add_map(
        "matmul_sequential",
        dict(zip(mapParams, mapRange)),
        schedule=dace.ScheduleType.Sequential,
    )
    if isinstance(A_node, str):
        # A is referenced by name inside the code; B flows in as "j1".
        tasklet = state.add_tasklet("matmul_sequential", {"j1"}, {"out"},
                                    "out=" + A_node + "[i2, i3]" + "*j1")
        if B_memlet:
            state.add_edge(map_entry, None, mmapEntry, None, B_memlet)
            # Fix the trailing (outer-scope) dimensions at their start index.
            b_memlet_trailing = [str(_t[0]) for _t in B_memlet.subset[-2:]]
            state.add_edge(
                mmapEntry,
                None,
                tasklet,
                "j1",
                Memlet.simple(B_node,
                              ",".join(["i3", "i4"] + b_memlet_trailing)),
            )
        else:
            state.add_edge(
                B_node if B_direct else map_entry,
                None,
                mmapEntry,
                None,
                Memlet.from_array(B_node, B_node.desc(sdfg)),
            )
            state.add_edge(
                mmapEntry,
                None,
                tasklet,
                "j1",
                Memlet.simple(B_node, ",".join(["i3", "i4"])),
            )
    else:
        # B is referenced by name inside the code; A flows in as "j0".
        tasklet = state.add_tasklet("matmul_sequential", {"j0"}, {"out"},
                                    "out=j0*" + B_node + "[i3, i4]")
        if A_memlet:
            state.add_edge(map_entry, None, mmapEntry, None, A_memlet)
            # Fix the trailing (outer-scope) dimensions at their start index.
            a_memlet_trailing = [str(_t[0]) for _t in A_memlet.subset[-2:]]
            state.add_edge(
                mmapEntry,
                None,
                tasklet,
                "j0",
                Memlet.simple(A_node,
                              ",".join(["i2", "i3"] + a_memlet_trailing)),
            )
        else:
            state.add_edge(
                A_node if A_direct else map_entry,
                None,
                mmapEntry,
                None,
                Memlet.from_array(A_node, A_node.desc(sdfg)),
            )
            state.add_edge(
                mmapEntry,
                None,
                tasklet,
                "j0",
                Memlet.simple(A_node, ",".join(["i2", "i3"])),
            )
    # Output: accumulate over i3 with a sum WCR.
    if C_memlet:
        c_memlet_trailing = [str(_t[0]) for _t in C_memlet.subset[-2:]]
        state.add_edge(
            tasklet,
            "out",
            mmapExit,
            None,
            Memlet.simple(
                C_node,
                ",".join(["i2", "i4"] + c_memlet_trailing),
                wcr_str="lambda a,b: a+b",
                wcr_conflict=False,
            ),
        )
        state.add_edge(mmapExit, None, map_exit, None, C_memlet)
    else:
        state.add_edge(
            tasklet,
            "out",
            mmapExit,
            None,
            Memlet.simple(
                C_node,
                ",".join(["i2", "i4"]),
                wcr_str="lambda a,b:a+b",
                wcr_conflict=False,
            ),
        )
        state.add_edge(
            mmapExit,
            None,
            C_node if map_exit is None else map_exit,
            None,
            Memlet.simple(
                C_node,
                ",".join(["0:" + str(_s) for _s in C_node.desc(sdfg).shape]),
                wcr_str="lambda a,b:a+b",
                wcr_conflict=False,
            ),
        )
def winograd_convolution(dace_session, tf_node):
    """Lower a TensorFlow convolution node into a Winograd-style SDFG subgraph.

    Builds, in ``dace_session.state``:
      1. a copy of the kernel input into a GPU-global transient, plus optional
         input padding when the node's padding attribute is "SAME";
      2. an image-tiling map producing ``IMAGE_TILE_SIZE x IMAGE_TILE_SIZE``
         tiles (``V...``);
      3. the input transform ``Btrans @ tile @ B`` (``V_output...``);
      4. the kernel transform ``G @ kernel @ Gtrans`` (``U...``);
      5. products of the transformed input and kernel via ``mm`` (``m...``);
      6. the output transform ``Atrans @ m @ A`` (``inv_txformed_output...``);
      7. an un-tiling map writing into the node's outputs (NHWC per the
         banner below).

    :param dace_session: Session object providing ``state``, ``graph`` and the
        input/output/memlet helper methods used below.
    :param tf_node: TensorFlow operation being lowered; its ``padding`` and
        ``strides`` attributes and input/output shapes are read.
    """
    # Never populated in this function; see the review note on the debugging
    # section at the end.
    debugNodes = []
    state = dace_session.state
    add_cublas_cusolver(dace_session.graph)
    #############Add constants for transformation matrices###############
    dace_session.graph.add_constant('Btrans', bt)
    dace_session.graph.add_constant('B', b)
    bNode = 'B'
    bTransposeNode = 'Btrans'

    dace_session.graph.add_constant('G', g)
    dace_session.graph.add_constant('Gtrans', gt)
    gNode = 'G'
    gTransposeNode = 'Gtrans'

    dace_session.graph.add_constant('Atrans', at)
    dace_session.graph.add_constant('A', a)
    aNode = 'A'
    aTransposeNode = 'Atrans'

    inputNodes = []
    inputParams = []
    inputDims = []
    for _inp in tf_node.inputs:
        _node, _params, _dims = dace_session.create_and_add_input_node(_inp)
        inputNodes.append(_node)
        inputParams.append(_params)
        inputDims.append(_dims)
    # Manually add copy for kernel from CPU to GPU
    kernel_desc = inputNodes[1].desc(dace_session.graph)
    kernelGPU = state.add_transient(
        inputNodes[1].data + "GPU",
        shape=kernel_desc.shape,
        dtype=kernel_desc.dtype,
        lifetime=dtypes.AllocationLifetime.SDFG,
        storage=dace.StorageType.GPU_Global,
    )
    state.add_edge(
        inputNodes[1],
        None,
        kernelGPU,
        None,
        Memlet.from_array(inputNodes[1],
                          inputNodes[1].desc(dace_session.graph)),
    )
    inputNodes[1] = kernelGPU
    outputList = dace_session.create_and_add_output_node(tf_node)
    outputDims = dace_session.get_default_dims(tf_node.outputs[0])
    # NOTE(review): presumably get_attr returns bytes, so str() yields
    # "b'SAME'" and [2:-1] strips the b'' wrapper -- confirm.
    if str(tf_node.get_attr("padding"))[2:-1] == "SAME":
        paddedInput, paddedDims = dace_session.inputPadding(
            tf_node,
            inputNodes[0],
            inputNodes[0].desc(dace_session.graph),
            outputList[0].desc(dace_session.graph).shape[1],
            inputNodes[1].desc(dace_session.graph).shape[0],
            tf_node.get_attr("strides")[1],
            inputDims[0],
        )
        inputDims[0] = paddedDims
        inputNodes[0] = paddedInput
    outputShape = [int(_s) for _s in tf_node.outputs[0].shape]
    # Tile view: [tile, tile, input channels, number of tiles], where the
    # number of tiles is outputShape[0] * ceil(dim1 / T) * ceil(dim2 / T).
    inputViewShape = [
        IMAGE_TILE_SIZE,
        IMAGE_TILE_SIZE,
        tf_node.inputs[0].shape[-1],
        outputShape[0] * ceil(outputShape[1] / OUTPUT_TILE_SIZE) *
        ceil(outputShape[2] / OUTPUT_TILE_SIZE),
    ]
    inputViewDims = ["0:" + str(_x) for _x in inputViewShape]
    ########Tiling the image#################################
    # Index expressions into the (padded) input for each tile element; the
    # flat tile index i3 is decomposed into (batch, tile row, tile column).
    inputViewParams = [
        "i3%" + str(outputShape[0]),
        "(i3/" + str(outputShape[0]) + ")%"
        # + str(output_shape[0] * ceil(output_shape[1] / OUTPUT_TILE_SIZE))
        + str(ceil(outputShape[2] / OUTPUT_TILE_SIZE)) + "*" +
        str(OUTPUT_TILE_SIZE) + "+i0",
        # + str(
        #    ceil(output_shape[1] / OUTPUT_TILE_SIZE)
        #    * ceil(output_shape[2] / OUTPUT_TILE_SIZE)
        # ),
        "int_floor(i3,"
        # + str(ceil(output_shape[1] / OUTPUT_TILE_SIZE))
        + str(outputShape[0] * ceil(outputShape[2] / OUTPUT_TILE_SIZE)) +
        ")*" + str(OUTPUT_TILE_SIZE) + "+i1",
        "i2",
    ]
    inputView = state.add_transient(
        "V" + "_".join([str(_s) for _s in inputViewShape]),
        inputViewShape,
        dace.float32,
        dace.StorageType.GPU_Global,
    )
    mapEntry, mapExit = state.add_map(
        string_builder(tf_node.name) + "_input_tile",
        dict(zip(inputParams[0], inputViewDims)),
    )
    tasklet = state.add_tasklet(
        string_builder(tf_node.name) + "_input_tile", {"j0"}, {"out"},
        "out = j0")
    dace_session.add_in_memlets([inputNodes[0]], mapEntry, tasklet,
                                [inputDims[0]], [inputViewParams])
    dace_session.add_out_memlets([inputView], mapExit, tasklet,
                                 [inputViewDims], [inputParams[0]])
    ##################Transforming all input tiles#########################
    #[TODO] try to re-use memory
    vNode = state.add_transient(
        "V_output" + "_".join([str(_s) for _s in inputViewShape]),
        inputViewShape,
        dace.float32,
        dace.StorageType.GPU_Global,
    )
    vNode.setzero = True
    # GPU map over (channel, tile index); each iteration transforms one tile.
    mapEntry, mapExit = state.add_map(
        string_builder(tf_node.name) + "_input_txform",
        dict(zip(inputParams[0][0:2], inputViewDims[2:4])),
        dace.ScheduleType.GPU_Device,
    )
    intermediateResultNode = state.add_transient("BtI", bt.shape,
                                                 dace.float32,
                                                 dace.StorageType.Register)
    intermediateResultNode.setzero = True
    state.add_edge(
        inputView,
        None,
        mapEntry,
        None,
        Memlet.simple(inputView, ",".join(inputViewDims)),
    )
    # BtI = Btrans @ tile
    mm_small(
        state,
        bTransposeNode,
        inputView,
        intermediateResultNode,
        B_subset=[IMAGE_TILE_SIZE, IMAGE_TILE_SIZE],
        B_memlet=Memlet.simple(
            inputView, ",".join(inputViewDims[0:2] + inputParams[0][0:2])),
        map_entry=mapEntry,
        B_direct=False,
    )
    # V_output[..., tile] += BtI @ B
    mm_small(
        state,
        intermediateResultNode,
        bNode,
        vNode,
        map_exit=mapExit,
        C_subset=[IMAGE_TILE_SIZE, IMAGE_TILE_SIZE],
        C_memlet=Memlet.simple(
            vNode,
            ",".join(inputViewDims[0:2] + inputParams[0][0:2]),
            wcr_str="lambda a,b: a+b",
            wcr_conflict=False,
        ),
        map_entry=mapEntry,
        A_direct=True,
    )
    state.add_edge(
        mapExit,
        None,
        vNode,
        None,
        Memlet.simple(
            vNode,
            ",".join(inputViewDims),
            wcr_str="lambda a,b: a+b",
            wcr_conflict=False,
        ),
    )
    #############Transforming the kernel###############################
    mapEntry, mapExit = state.add_map(
        string_builder(tf_node.name) + "_kernel_txform",
        dict(zip(inputParams[1][0:2], inputDims[1][2:4])),
        dace.ScheduleType.GPU_Device,
    )
    intermediateResultNode = state.add_transient("GF", g.shape,
                                                 dace.float32,
                                                 dace.StorageType.Register)
    intermediateResultNode.setzero = True
    # U layout: [tile, tile] + the kernel's last two dims in reversed order.
    processedKernelNode = state.add_transient(
        "U" + "_".join([
            str(_s) for _s in inputViewShape[0:2] +
            list(tf_node.inputs[1].shape[-1:-3:-1])
        ]),
        inputViewShape[0:2] + list(tf_node.inputs[1].shape[-1:-3:-1]),
        dace.float32,
        dace.StorageType.GPU_Global,
    )
    processedKernelNode.setzero = True
    state.add_edge(
        inputNodes[1],
        None,
        mapEntry,
        None,
        dace.Memlet.from_array(inputNodes[1].data,
                               inputNodes[1].desc(dace_session.graph)),
    )
    # GF = G @ kernel slice
    mm_small(
        state,
        gNode,
        inputNodes[1],
        intermediateResultNode,
        map_entry=mapEntry,
        B_subset=tf_node.inputs[1].shape[0:2],
        B_memlet=Memlet.simple(
            inputNodes[1],
            ",".join(inputDims[1][0:2] + inputParams[1][0:2])),
        B_direct=False,
    )
    # U[..., p1, p0] += GF @ Gtrans
    mm_small(
        state,
        intermediateResultNode,
        gTransposeNode,
        processedKernelNode,
        C_subset=[IMAGE_TILE_SIZE, IMAGE_TILE_SIZE],
        C_memlet=Memlet.simple(
            processedKernelNode,
            ",".join(inputViewDims[0:2] + [inputParams[0][1]] +
                     [inputParams[0][0]]),
            wcr_str="lambda a,b: a+b",
            wcr_conflict=False,
        ),
        map_entry=mapEntry,
        map_exit=mapExit,
        A_direct=True,
    )
    state.add_edge(
        mapExit,
        None,
        processedKernelNode,
        None,
        Memlet.simple(
            processedKernelNode.data,
            ",".join([
                "0:" + str(_s)
                for _s in processedKernelNode.desc(dace_session.graph).shape
            ]),
            wcr_str="lambda a,b: a+b",
            wcr_conflict=False,
        ),
    )
    # m layout: [tile, tile, output channels, number of tiles].
    mNode = state.add_transient(
        "m" + "_".join([
            str(_s) for _s in inputViewShape[0:2] +
            [tf_node.inputs[1].shape[-1], inputViewShape[-1]]
        ]),
        inputViewShape[0:2] + [tf_node.inputs[1].shape[-1], inputViewShape[-1]],
        dace.float32,
        dace.StorageType.GPU_Global,
    )
    mNodeDims = ["0:" + str(_d) for _d in mNode.desc(dace_session.graph).shape]
    # Sequential map over each (tile row, tile column) element position.
    mapEntry, mapExit = state.add_map(
        string_builder(tf_node.name) + "_eltwise_product",
        dict(zip(inputParams[0][0:2], inputViewDims[0:2])),
        dace.ScheduleType.Sequential,
    )
    state.add_edge(
        vNode,
        None,
        mapEntry,
        None,
        Memlet.from_array(vNode.data, vNode.desc(dace_session.graph)),
    )
    state.add_edge(
        processedKernelNode,
        None,
        mapEntry,
        None,
        Memlet.from_array(processedKernelNode.data,
                          processedKernelNode.desc(dace_session.graph)),
    )
    # Product of the V and U slices at this element position (mm is defined
    # elsewhere).
    mm(
        state,
        vNode,
        processedKernelNode,
        mNode,
        A_subset=inputViewShape[2:4],
        A_memlet=Memlet.simple(
            vNode, ",".join(inputParams[0][0:2] + inputViewDims[-2:])),
        B_subset=tf_node.inputs[1].shape[-1:-3:-1],
        B_memlet=Memlet.simple(
            processedKernelNode,
            ",".join(
                inputParams[0][0:2] +
                ["0:" + str(_s) for _s in tf_node.inputs[1].shape[-1:-3:-1]]),
        ),
        C_subset=[tf_node.inputs[1].shape[-1], inputViewShape[-1]],
        C_memlet=Memlet.simple(mNode,
                               ",".join(inputParams[0][0:2] + mNodeDims[-2:])),
        map_entry=mapEntry,
        map_exit=mapExit,
        shadow_a=True,
        shadow_b=True,
    )
    state.add_edge(mapExit, None, mNode, None,
                   Memlet.simple(mNode, ",".join(mNodeDims)))
    #################OUTPUT TRANSFORMATION################################
    mapRange = [inputDims[1][-1]] + [inputViewDims[-1]]
    mapEntry, mapExit = state.add_map(
        string_builder(tf_node.name) + "_output_txform",
        dict(zip(inputParams[0][0:2], mapRange)),
        dace.ScheduleType.GPU_Device,
    )
    intermediateResultNode = state.add_transient("AtM", at.shape,
                                                 dace.float32,
                                                 dace.StorageType.Register)
    intermediateResultNode.setzero = True
    transformedOutputNode = state.add_transient(
        "inv_txformed_output" + "_".join([str(tf_node.inputs[1].shape[-1])] +
                                         [str(inputViewShape[-1])]),
        [OUTPUT_TILE_SIZE, OUTPUT_TILE_SIZE] + [tf_node.inputs[1].shape[-1]] +
        [inputViewShape[-1]],
        dace.float32,
        dace.StorageType.GPU_Global,
    )
    transformedOutputNode.setzero = True
    state.add_edge(mNode, None, mapEntry, None,
                   Memlet.simple(mNode, ",".join(mNodeDims)))
    # AtM = Atrans @ m slice
    mm_small(
        state,
        aTransposeNode,
        mNode,
        intermediateResultNode,
        B_subset=inputViewShape[0:2],
        B_memlet=Memlet.simple(
            mNode, ",".join(inputViewDims[0:2] + inputParams[0][0:2])),
        map_entry=mapEntry,
        B_direct=False,
    )
    # output tile += AtM @ A
    mm_small(
        state,
        intermediateResultNode,
        aNode,
        transformedOutputNode,
        C_subset=[OUTPUT_TILE_SIZE, OUTPUT_TILE_SIZE],
        C_memlet=Memlet.simple(
            transformedOutputNode,
            ",".join(["0:" + str(OUTPUT_TILE_SIZE), "0:" + str(OUTPUT_TILE_SIZE)] +
                     inputParams[0][0:2]),
            wcr_str="lambda a,b:a+b",
            wcr_conflict=False,
        ),
        map_entry=mapEntry,
        map_exit=mapExit,
        A_direct=True,
    )
    state.add_edge(
        mapExit,
        None,
        transformedOutputNode,
        None,
        Memlet.simple(
            transformedOutputNode.data,
            ",".join([
                "0:" + str(_s)
                for _s in transformedOutputNode.desc(dace_session.graph).shape
            ]),
            wcr_str="lambda a,b: a+b",
            wcr_conflict=False,
        ),
    )
    ###################Un-Tile the output to NHWC format###################
    # Same tile-index decomposition as inputViewParams, applied to the output.
    outputParams = [
        "i3%" + str(outputShape[0]),
        "(i3/" + str(outputShape[0]) + ")%" +
        str(ceil(outputShape[2] / OUTPUT_TILE_SIZE)) + "*" +
        str(OUTPUT_TILE_SIZE) + "+i0",
        "int_floor(i3," +
        str(outputShape[0] * ceil(outputShape[2] / OUTPUT_TILE_SIZE)) + ")*" +
        str(OUTPUT_TILE_SIZE) + "+i1",
        "i2",
    ]
    mapRange = [
        "0:" + str(_s)
        for _s in transformedOutputNode.desc(dace_session.graph).shape
    ]
    mapEntry, mapExit = state.add_map(
        string_builder(tf_node.name) + "_output_untile",
        dict(zip(inputParams[0], mapRange)),
    )
    tasklet = state.add_tasklet(
        string_builder(tf_node.name) + "_output_untile", {"j0"}, {"out"},
        "out = j0")
    dace_session.add_in_memlets([transformedOutputNode], mapEntry, tasklet,
                                [mapRange], [inputParams[0]])
    dace_session.add_out_memlets(outputList, mapExit, tasklet, [outputDims],
                                 [outputParams])
    ################# Debugging with callbacks #############
    # NOTE(review): debugNodes is always empty here, yet this still adds a
    # zero-input C++ tasklet calling "<name>_printer()" and registers the
    # external `printer` callback. Confirm this is intended outside debugging.
    taskletInputs = ["i" + str(index) for index in range(len(debugNodes))]
    callback_tasklet = state.add_tasklet(
        string_builder(tf_node.name) + "_printer",
        {*taskletInputs},
        {},
        string_builder(tf_node.name) + "_printer" + "(" +
        ",".join(taskletInputs) + ");",
        language=dace.dtypes.Language.CPP,
    )
    for _n, _conn in zip(debugNodes, taskletInputs):
        # Copy each debugged container to a CPU heap transient for the callback.
        _n_cpu = state.add_transient(_n.data + "_cpucopy",
                                     _n.desc(dace_session.graph).shape,
                                     _n.desc(dace_session.graph).dtype,
                                     storage=dace.StorageType.CPU_Heap,
                                     lifetime=dtypes.AllocationLifetime.SDFG)
        state.add_edge(_n, None, _n_cpu, None,
                       Memlet.from_array(_n, _n.desc(dace_session.graph)))
        state.add_edge(
            _n_cpu,
            None,
            callback_tasklet,
            _conn,
            Memlet.from_array(_n_cpu, _n_cpu.desc(dace_session.graph)),
        )
    callback_input_types = []
    for somenode in debugNodes:
        callback_input_types.append(somenode.desc(dace_session.graph))
    dace_session.callbackFunctionDict[string_builder(tf_node.name) +
                                      "_printer"] = printer
    dace_session.callbackTypeDict[string_builder(tf_node.name) +
                                  "_printer"] = dace.data.Scalar(
                                      dace.callback(None,
                                                    *callback_input_types))
def _bcgather(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, in_buffer: str, out_buffer: str,
              block_sizes: Union[str, Sequence[Union[sp.Expr, Number]]]):
    """
    Adds a ``BlockCyclicGather`` library node that reads ``in_buffer`` and the
    block sizes and writes ``out_buffer``.

    :param pv: The calling program visitor (not used here).
    :param sdfg: The SDFG holding the data descriptors.
    :param state: The state receiving the new nodes.
    :param in_buffer: Input data name, or a ``(name, range)`` tuple.
    :param out_buffer: Output data name, or a ``(name, range)`` tuple.
    :param block_sizes: Name of existing data, a ``(name, range)`` tuple, or a
                        sequence of sizes that is stored into a new int32
                        transient by an initialization tasklet.
    :return: None.
    """
    from dace.libraries.pblas.nodes.pgeadd import BlockCyclicGather
    gather_node = BlockCyclicGather('_BCGather_')

    # Input operand: bare name or (name, range).
    if isinstance(in_buffer, tuple):
        src_name, src_subset = in_buffer
    else:
        src_name, src_subset = in_buffer, None
    src_desc = sdfg.arrays[src_name]
    src_node = state.add_read(src_name)

    # Block sizes operand.
    sizes_subset = None
    if not isinstance(block_sizes, (list, tuple)):
        sizes_name = block_sizes
        sizes_desc = sdfg.arrays[sizes_name]
        sizes_node = state.add_read(sizes_name)
    elif isinstance(block_sizes[0], str):
        sizes_name, sizes_subset = block_sizes
        sizes_desc = sdfg.arrays[sizes_name]
        sizes_node = state.add_read(sizes_name)
    else:
        # Literal sizes: materialize them into a fresh int32 transient.
        sizes_name, sizes_desc = sdfg.add_temp_transient((len(block_sizes), ), dtype=dace.int32)
        sizes_node = state.add_access(sizes_name)
        assignments = ";".join(["__out[{}] = {}".format(idx, size) for idx, size in enumerate(block_sizes)])
        init_tasklet = state.add_tasklet('_set_bsizes_', {}, {'__out'}, assignments)
        state.add_edge(init_tasklet, '__out', sizes_node, None, Memlet.from_array(sizes_name, sizes_desc))

    # Output operand: bare name or (name, range).
    if isinstance(out_buffer, tuple):
        dst_name, dst_subset = out_buffer
    else:
        dst_name, dst_subset = out_buffer, None
    dst_desc = sdfg.arrays[dst_name]
    dst_node = state.add_write(dst_name)

    def _memlet(data_name, subset, data_desc):
        # A (truthy) range selects a sub-range; otherwise cover the whole container.
        if subset:
            return Memlet.simple(data_name, subset)
        return Memlet.from_array(data_name, data_desc)

    in_memlet = _memlet(src_name, src_subset, src_desc)
    sizes_memlet = _memlet(sizes_name, sizes_subset, sizes_desc)
    out_memlet = _memlet(dst_name, dst_subset, dst_desc)

    state.add_edge(src_node, None, gather_node, '_inbuffer', in_memlet)
    state.add_edge(sizes_node, None, gather_node, '_block_sizes', sizes_memlet)
    state.add_edge(gather_node, '_outbuffer', dst_node, None, out_memlet)
    return None
def _create_einsum_internal(sdfg: SDFG,
                            state: SDFGState,
                            einsum_string: str,
                            *arrays: str,
                            dtype: Optional[dtypes.typeclass] = None,
                            optimize: bool = False,
                            output: Optional[str] = None,
                            nodes: Optional[Dict[str, AccessNode]] = None,
                            init_output: Optional[bool] = None):
    """
    Builds the dataflow of an einsum expression inside ``state``.

    :param sdfg: The SDFG holding the array descriptors.
    :param state: The state in which the computation is created.
    :param einsum_string: The einsum subscript expression.
    :param arrays: Names of the input arrays, one per einsum input.
    :param dtype: Output data type. Defaults to the dtype of the first input
                  (new output) or of ``output`` (given output).
    :param optimize: If True, split the expression into pairwise contractions
                     along the path found by ``opt_einsum``.
    :param output: Name of an existing output container. If None, a new
                   transient is created.
    :param nodes: Optional mapping from array name to an existing access node
                  to reuse as input.
    :param init_output: Whether to zero-initialize the output before the
                        computation. ``None`` means: always for a newly
                        created output; for a given output only when the
                        contraction has conflicting writes. An explicit
                        ``True``/``False`` is honored for a given output.
    :return: Tuple of (output data name, output access node).
    :raise ValueError: On mismatching array count or dimensions, or symbolic
                       sizes when ``optimize`` is set.
    :raise ImportError: If ``optimize`` is set and ``opt_einsum`` is missing.
    """
    # Infer shapes and strides of input/output arrays
    einsum = EinsumParser(einsum_string)

    if len(einsum.inputs) != len(arrays):
        raise ValueError('Invalid number of arrays for einsum expression')

    # Get shapes from arrays and verify dimensionality
    chardict = {}
    for inp, inpname in zip(einsum.inputs, arrays):
        inparr = sdfg.arrays[inpname]
        if len(inp) != len(inparr.shape):
            raise ValueError('Dimensionality mismatch in input "%s"' % inpname)
        for char, shp in zip(inp, inparr.shape):
            if char in chardict and shp != chardict[char]:
                raise ValueError('Dimension mismatch in einsum expression')
            chardict[char] = shp

    if optimize:
        # Try to import opt_einsum
        try:
            import opt_einsum as oe
        except (ModuleNotFoundError, NameError, ImportError):
            raise ImportError('To optimize einsum expressions, please install '
                              'the "opt_einsum" package.')

        for char, shp in chardict.items():
            if symbolic.issymbolic(shp):
                raise ValueError('Einsum optimization cannot be performed '
                                 'on symbolically-sized array dimension "%s" '
                                 'for subscript character "%s"' % (shp, char))

        # Create optimal contraction path
        # noinspection PyTypeChecker
        _, path_info = oe.contract_path(einsum_string, *oe.helpers.build_views(einsum_string, chardict))

        input_nodes = nodes or {arr: state.add_read(arr) for arr in arrays}
        result_node = None

        # Follow path and create a chain of operation SDFG states
        for pair, nonfree, expr, after, blas in path_info.contraction_list:
            result, result_node = _create_einsum_internal(sdfg,
                                                          state,
                                                          expr,
                                                          arrays[pair[0]],
                                                          arrays[pair[1]],
                                                          dtype=dtype,
                                                          optimize=False,
                                                          output=None,
                                                          nodes=input_nodes)
            # Contracted operands are removed and the result is appended.
            arrays = ([a for i, a in enumerate(arrays) if i not in pair] + [result])
            input_nodes[result] = result_node

        return arrays[0], result_node
        # END of einsum optimization

    input_nodes = nodes or {arr: state.add_read(arr) for arr in arrays}

    # Get output shape from chardict, or [1] for a scalar output
    output_shape = list(map(lambda k: chardict[k], einsum.output)) or [1]
    output_index = ','.join(o for o in einsum.output) or '0'

    if output is None:
        dtype = dtype or sdfg.arrays[arrays[0]].dtype
        output, odesc = sdfg.add_temp_transient(output_shape, dtype)
        to_init = True
    else:
        odesc = sdfg.arrays[output]
        dtype = dtype or odesc.dtype
        # Honor an explicit request; default to initializing. (The former
        # ``init_output or True`` was always True and ignored ``False``.)
        to_init = True if init_output is None else bool(init_output)

    # Conflicting writes occur when an input index is summed over (absent
    # from the output); those need accumulation into initialized memory.
    is_conflicted = not all(all(indim in einsum.output for indim in inp) for inp in einsum.inputs)
    if not is_conflicted and init_output is None:
        to_init = False

    if not einsum.is_bmm():
        # Fall back to "pure" SDFG einsum with conflict resolution
        c = state.add_write(output)

        # Add state before this one to initialize the output value
        if to_init:
            init_state = sdfg.add_state_before(state)
            if len(einsum.output) > 0:
                init_state.add_mapped_tasklet('einsum_reset', {k: '0:%s' % chardict[k]
                                                               for k in einsum.output}, {},
                                              'out_%s = 0' % output,
                                              {'out_%s' % output: Memlet.simple(output, output_index)},
                                              external_edges=True)
            else:  # Scalar output
                t = init_state.add_tasklet('einsum_reset', set(), {'out_%s' % output}, 'out_%s = 0' % output)
                onode = init_state.add_write(output)
                init_state.add_edge(t, 'out_%s' % output, onode, None, Memlet.simple(output, '0'))

        wcr = 'lambda a,b: a+b' if is_conflicted else None
        # Pure einsum map
        state.add_mapped_tasklet('einsum', {k: '0:%s' % v
                                            for k, v in chardict.items()},
                                 {'inp_%s' % arr: Memlet.simple(arr, ','.join(inp))
                                  for inp, arr in zip(einsum.inputs, arrays)},
                                 'out_%s = %s' % (output, ' * '.join('inp_%s' % arr for arr in arrays)),
                                 {'out_%s' % output: Memlet.simple(output, output_index, wcr_str=wcr)},
                                 input_nodes=input_nodes,
                                 output_nodes={output: c},
                                 external_edges=True)
    else:
        # Represent einsum as a GEMM or batched GEMM (using library nodes)
        a_shape = sdfg.arrays[arrays[0]].shape
        b_shape = sdfg.arrays[arrays[1]].shape
        c_shape = output_shape

        a = input_nodes[arrays[0]]
        b = input_nodes[arrays[1]]
        c = state.add_write(output)

        # Compute GEMM dimensions and strides
        strides = dict(
            BATCH=prod([c_shape[dim] for dim in einsum.c_batch]),
            M=prod([a_shape[dim] for dim in einsum.a_only]),
            K=prod([a_shape[dim] for dim in einsum.a_sum]),
            N=prod([b_shape[dim] for dim in einsum.b_only]),
            sAM=prod(a_shape[einsum.a_only[-1] + 1:]) if einsum.a_only else 1,
            sAK=prod(a_shape[einsum.a_sum[-1] + 1:]) if einsum.a_sum else 1,
            sAB=prod(a_shape[einsum.a_batch[-1] + 1:]) if einsum.a_batch else 1,
            sBK=prod(b_shape[einsum.b_sum[-1] + 1:]) if einsum.b_sum else 1,
            sBN=prod(b_shape[einsum.b_only[-1] + 1:]) if einsum.b_only else 1,
            sBB=prod(b_shape[einsum.b_batch[-1] + 1:]) if einsum.b_batch else 1,
            sCM=prod(c_shape[einsum.c_a_only[-1] + 1:]) if einsum.c_a_only else 1,
            sCN=prod(c_shape[einsum.c_b_only[-1] + 1:]) if einsum.c_b_only else 1,
            sCB=prod(c_shape[einsum.c_batch[-1] + 1:]) if einsum.c_batch else 1)

        # Complement strides to make matrices as necessary
        if len(a_shape) == 1 and len(einsum.a_sum) == 1:
            strides['sAK'] = 1
            strides['sAB'] = strides['sAM'] = strides['K']
        if len(b_shape) == 1 and len(einsum.b_sum) == 1:
            strides['sBN'] = 1
            strides['sBK'] = 1
            strides['sBB'] = strides['K']
        if len(c_shape) == 1 and len(einsum.a_sum) == len(einsum.b_sum):
            strides['sCN'] = 1
            strides['sCB'] = strides['sCM'] = strides['N']

        # Create nested SDFG for GEMM
        nsdfg = create_batch_gemm_sdfg(dtype, strides)

        nsdfg_node = state.add_nested_sdfg(nsdfg, None, {'X', 'Y'}, {'Z'}, strides)
        state.add_edge(a, None, nsdfg_node, 'X', Memlet.from_array(a.data, a.desc(sdfg)))
        state.add_edge(b, None, nsdfg_node, 'Y', Memlet.from_array(b.data, b.desc(sdfg)))
        state.add_edge(nsdfg_node, 'Z', c, None, Memlet.from_array(c.data, c.desc(sdfg)))

    return output, c
def _redistribute(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, in_buffer: str, in_subarray: str,
                  out_buffer: str, out_subarray: str):
    """
    Redistributes an Array using process-grids, sub-arrays, and the
    Redistribute library node.

    :param in_buffer: Name of the (local) input Array descriptor, or a
                      ``(name, range)`` tuple selecting a sub-range of it.
    :param in_subarray: Input sub-array descriptor.
    :param out_buffer: Name of the (local) output Array descriptor, or a
                       ``(name, range)`` tuple selecting a sub-range of it.
    :param out_subarray: Output sub-array descriptor.
    :return: Name of the new redistribution descriptor.
    """
    # Resolve names (and optional ranges) first, then look up descriptors
    # before the SDFG is mutated, so unknown names fail early. Previously the
    # raw arguments were used as keys, which broke ``(name, range)`` tuples.
    if isinstance(in_buffer, tuple):
        inbuf_name, inbuf_range = in_buffer
    else:
        inbuf_name, inbuf_range = in_buffer, None
    if isinstance(out_buffer, tuple):
        outbuf_name, outbuf_range = out_buffer
    else:
        outbuf_name, outbuf_range = out_buffer, None
    in_desc = sdfg.arrays[inbuf_name]
    out_desc = sdfg.arrays[outbuf_name]

    rdistrarray_name = sdfg.add_rdistrarray(in_subarray, out_subarray)

    from dace.libraries.mpi import Dummy, Redistribute
    # Dummy node carrying the declarations of the redistribution state.
    tasklet = Dummy(rdistrarray_name, [
        f'MPI_Datatype {rdistrarray_name};', f'int {rdistrarray_name}_sends;',
        f'MPI_Datatype* {rdistrarray_name}_send_types;', f'int* {rdistrarray_name}_dst_ranks;',
        f'int {rdistrarray_name}_recvs;', f'MPI_Datatype* {rdistrarray_name}_recv_types;',
        f'int* {rdistrarray_name}_src_ranks;', f'int {rdistrarray_name}_self_copies;',
        f'int* {rdistrarray_name}_self_src;', f'int* {rdistrarray_name}_self_dst;',
        f'int* {rdistrarray_name}_self_size;'
    ])
    state.add_node(tasklet)
    _, scal = sdfg.add_scalar(rdistrarray_name, dace.int32, transient=True)
    wnode = state.add_write(rdistrarray_name)
    state.add_edge(tasklet, '__out', wnode, None, Memlet.from_array(rdistrarray_name, scal))

    libnode = Redistribute('_Redistribute_', rdistrarray_name)

    inbuf_node = state.add_read(inbuf_name)
    outbuf_node = state.add_write(outbuf_name)

    if inbuf_range:
        inbuf_mem = Memlet.simple(inbuf_name, inbuf_range)
    else:
        inbuf_mem = Memlet.from_array(inbuf_name, in_desc)
    if outbuf_range:
        outbuf_mem = Memlet.simple(outbuf_name, outbuf_range)
    else:
        outbuf_mem = Memlet.from_array(outbuf_name, out_desc)

    state.add_edge(inbuf_node, None, libnode, '_inp_buffer', inbuf_mem)
    state.add_edge(libnode, '_out_buffer', outbuf_node, None, outbuf_mem)

    return rdistrarray_name
def _irecv(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, buffer: str, src: Union[str, sp.Expr, Number],
           tag: Union[str, sp.Expr, Number], request: str):
    """
    Adds an ``Irecv`` library node to ``state``.

    :param pv: The calling program visitor (used to define local scalars).
    :param sdfg: The SDFG holding the data descriptors.
    :param state: The state receiving the new nodes.
    :param buffer: Name of the receive buffer, or a ``(name, range)`` tuple.
    :param src: Source: a ``(name, range)`` tuple, the name of existing data,
                or a value/expression written into a new local int32 scalar.
    :param tag: Tag, accepted in the same forms as ``src``.
    :param request: Name of the request data, or a ``(name, range)`` tuple.
    :return: None.
    """
    from dace.libraries.mpi.nodes.irecv import Irecv

    libnode = Irecv('_Irecv_')

    buf_range = None
    if isinstance(buffer, tuple):
        buf_name, buf_range = buffer
    else:
        buf_name = buffer
    desc = sdfg.arrays[buf_name]
    # The buffer is the destination of the library node's '_buffer' output.
    buf_node = state.add_write(buf_name)

    req_range = None
    if isinstance(request, tuple):
        req_name, req_range = request
    else:
        req_name = request
    req_desc = sdfg.arrays[req_name]
    req_node = state.add_write(req_name)

    # Type the '_buffer' and '_request' output connectors as pointers.
    conn = libnode.out_connectors
    conn = {c: (dtypes.pointer(desc.dtype) if c == '_buffer' else t) for c, t in conn.items()}
    conn = {c: (dtypes.pointer(req_desc.dtype) if c == '_request' else t) for c, t in conn.items()}
    libnode.out_connectors = conn

    src_range = None
    if isinstance(src, tuple):
        src_name, src_range = src
        src_node = state.add_read(src_name)
    elif isinstance(src, str) and src in sdfg.arrays.keys():
        src_name = src
        src_node = state.add_read(src_name)
    else:
        storage = desc.storage
        src_name = _define_local_scalar(pv, sdfg, state, dace.int32, storage)
        src_node = state.add_access(src_name)
        src_tasklet = state.add_tasklet('_set_src_', {}, {'__out'}, '__out = {}'.format(src))
        state.add_edge(src_tasklet, '__out', src_node, None, Memlet.simple(src_name, '0'))

    tag_range = None
    if isinstance(tag, tuple):
        tag_name, tag_range = tag
        tag_node = state.add_read(tag_name)
    # ``elif`` (was ``if``): a tuple tag previously fell through to the
    # literal branch and was overwritten by a scalar holding its repr.
    elif isinstance(tag, str) and tag in sdfg.arrays.keys():
        tag_name = tag
        tag_node = state.add_read(tag)
    else:
        storage = desc.storage
        tag_name = _define_local_scalar(pv, sdfg, state, dace.int32, storage)
        tag_node = state.add_access(tag_name)
        tag_tasklet = state.add_tasklet('_set_tag_', {}, {'__out'}, '__out = {}'.format(tag))
        state.add_edge(tag_tasklet, '__out', tag_node, None, Memlet.simple(tag_name, '0'))

    if buf_range:
        buf_mem = Memlet.simple(buf_name, buf_range)
    else:
        buf_mem = Memlet.from_array(buf_name, desc)
    if req_range:
        req_mem = Memlet.simple(req_name, req_range)
    else:
        req_mem = Memlet.from_array(req_name, req_desc)
    if src_range:
        src_mem = Memlet.simple(src_name, src_range)
    else:
        src_mem = Memlet.simple(src_name, '0')
    if tag_range:
        tag_mem = Memlet.simple(tag_name, tag_range)
    else:
        tag_mem = Memlet.simple(tag_name, '0')

    state.add_edge(libnode, '_buffer', buf_node, None, buf_mem)
    state.add_edge(src_node, None, libnode, '_src', src_mem)
    state.add_edge(tag_node, None, libnode, '_tag', tag_mem)
    state.add_edge(libnode, '_request', req_node, None, req_mem)

    return None
def _elementwise(sdfg: SDFG, state: SDFGState, func: str, in_array: str, out_array=None):
    """Apply a lambda function to each element in the input.

    :param sdfg: The SDFG holding ``in_array``.
    :param state: The state receiving the computation.
    :param func: Source of a one-argument Python lambda, e.g. ``"lambda x: x * 2"``.
    :param in_array: Name of the input data.
    :param out_array: Name of an existing output container. If None, a new
                      transient with the input's shape, dtype and storage is created.
    :return: Name of the output container.
    :raise SyntaxError: If ``func`` is not a lambda with exactly one argument.
    """
    # Validate ``func`` before touching the SDFG so that a bad lambda does
    # not leave a dangling temporary output behind.
    try:
        lambda_ast = ast.parse(func).body[0].value
    except (AttributeError, IndexError) as ex:
        # IndexError: empty source; AttributeError: statement without a value.
        raise SyntaxError("Could not parse func {}".format(func)) from ex
    if not isinstance(lambda_ast, ast.Lambda):
        raise SyntaxError("Could not parse func {}".format(func))
    num_args = len(lambda_ast.args.args)
    if num_args != 1:
        raise SyntaxError("Expected lambda with one arg, but {} has {}".format(func, num_args))
    arg = lambda_ast.args.args[0].arg
    body = astutils.unparse(lambda_ast.body)

    inparr = sdfg.arrays[in_array]
    restype = inparr.dtype
    if out_array is None:
        out_array, outarr = sdfg.add_temp_transient(inparr.shape, restype, inparr.storage)
    else:
        outarr = sdfg.arrays[out_array]

    code = "__out = {}".format(body)

    num_elements = reduce(lambda x, y: x * y, inparr.shape)
    if num_elements == 1:
        # Single element: a plain tasklet, no map needed.
        inp = state.add_read(in_array)
        out = state.add_write(out_array)
        tasklet = state.add_tasklet("_elementwise_", {arg}, {'__out'}, code)
        state.add_edge(inp, None, tasklet, arg, Memlet.from_array(in_array, inparr))
        state.add_edge(tasklet, '__out', out, None, Memlet.from_array(out_array, outarr))
    else:
        index = ','.join(['__i%d' % i for i in range(len(inparr.shape))])
        state.add_mapped_tasklet(name="_elementwise_",
                                 map_ranges={'__i%d' % i: '0:%s' % n
                                             for i, n in enumerate(inparr.shape)},
                                 inputs={arg: Memlet.simple(in_array, index)},
                                 code=code,
                                 outputs={'__out': Memlet.simple(out_array, index)},
                                 external_edges=True)

    return out_array
def nest_state_subgraph(sdfg: SDFG,
                        state: SDFGState,
                        subgraph: SubgraphView,
                        name: Optional[str] = None,
                        full_data: bool = False) -> nodes.NestedSDFG:
    """ Turns a state subgraph into a nested SDFG. Operates in-place.

        :param sdfg: The SDFG containing the state subgraph.
        :param state: The state containing the subgraph.
        :param subgraph: Subgraph to nest.
        :param name: An optional name for the nested SDFG.
        :param full_data: If True, nests entire input/output data.
        :return: The nested SDFG node.
        :raise KeyError: Some or all nodes in the subgraph are not located in
                         this state, or the state does not belong to the given
                         SDFG.
        :raise ValueError: The subgraph is contained in more than one scope.
    """
    # NOTE(review): this file defines ``nest_state_subgraph`` a second time
    # further down; that later definition shadows this one at import time.
    if state.parent != sdfg:
        raise KeyError('State does not belong to given SDFG')
    if subgraph.graph != state:
        raise KeyError('Subgraph does not belong to given state')

    # Find the top-level scope
    scope_tree = state.scope_tree()
    scope_dict = state.scope_dict()
    scope_dict_children = state.scope_dict(True)
    top_scopenode = -1  # Initialized to -1 since "None" already means top-level

    for node in subgraph.nodes():
        if node not in scope_dict:
            raise KeyError('Node not found in state')

        # If scope entry/exit, ensure entire scope is in subgraph
        if isinstance(node, nodes.EntryNode):
            scope_nodes = scope_dict_children[node]
            if any(n not in subgraph.nodes() for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (entry)')
        elif isinstance(node, nodes.ExitNode):
            entry = state.entry_node(node)
            scope_nodes = scope_dict_children[entry] + [entry]
            if any(n not in subgraph.nodes() for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (exit)')

        # The enclosing scope of every node whose parent lies outside the
        # subgraph must be the same; that common scope is the "top" scope.
        scope_node = scope_dict[node]
        if scope_node not in subgraph.nodes():
            if top_scopenode != -1 and top_scopenode != scope_node:
                raise ValueError(
                    'Subgraph is contained in more than one scope')
            top_scopenode = scope_node

    scope = scope_tree[top_scopenode]
    ###

    # Collect inputs and outputs of the nested SDFG: in-edges of the
    # subgraph's source nodes and out-edges of its sink nodes.
    inputs: List[MultiConnectorEdge] = []
    outputs: List[MultiConnectorEdge] = []
    for node in subgraph.source_nodes():
        inputs.extend(state.in_edges(node))
    for node in subgraph.sink_nodes():
        outputs.extend(state.out_edges(node))

    # Collect transients not used outside of subgraph (will be removed of
    # top-level graph)
    data_in_subgraph = set(n.data for n in subgraph.nodes()
                           if isinstance(n, nodes.AccessNode))
    # Find other occurrences in SDFG
    other_nodes = set(
        n.data for s in sdfg.nodes() for n in s.nodes()
        if isinstance(n, nodes.AccessNode) and n not in subgraph.nodes())
    subgraph_transients = set()
    for data in data_in_subgraph:
        datadesc = sdfg.arrays[data]
        if datadesc.transient and data not in other_nodes:
            subgraph_transients.add(data)

    # All transients of edges between code nodes are also added to nested graph
    for edge in subgraph.edges():
        if (isinstance(edge.src, nodes.CodeNode)
                and isinstance(edge.dst, nodes.CodeNode)):
            subgraph_transients.add(edge.data.data)

    # Collect data used in access nodes within subgraph (will be referenced in
    # full upon nesting)
    input_arrays = set()
    output_arrays = set()
    for node in subgraph.nodes():
        if (isinstance(node, nodes.AccessNode)
                and node.data not in subgraph_transients):
            if state.out_degree(node) > 0:
                input_arrays.add(node.data)
            if state.in_degree(node) > 0:
                output_arrays.add(node.data)

    # Create the nested SDFG
    nsdfg = SDFG(name or 'nested_' + state.label)

    # Transients are added to the nested graph as-is
    # (note: the loop variable reuses ``name``; the parameter was consumed above)
    for name in subgraph_transients:
        nsdfg.add_datadesc(name, sdfg.arrays[name])

    # Input/output data that are not source/sink nodes are added to the graph
    # as non-transients
    for name in (input_arrays | output_arrays):
        datadesc = copy.deepcopy(sdfg.arrays[name])
        datadesc.transient = False
        nsdfg.add_datadesc(name, datadesc)

    # Connected source/sink nodes outside subgraph become global data
    # descriptors in nested SDFG. Unless ``full_data`` is set, each descriptor
    # is shrunk to the size of the memlet subset on its edge.
    input_names = []
    output_names = []
    for edge in inputs:
        if edge.data.data is None:  # Skip edges with an empty memlet
            continue
        name = '__in_' + edge.data.data
        datadesc = copy.deepcopy(sdfg.arrays[edge.data.data])
        datadesc.transient = False
        if not full_data:
            datadesc.shape = edge.data.subset.size()
        input_names.append(
            nsdfg.add_datadesc(name, datadesc, find_new_name=True))
    for edge in outputs:
        if edge.data.data is None:  # Skip edges with an empty memlet
            continue
        name = '__out_' + edge.data.data
        datadesc = copy.deepcopy(sdfg.arrays[edge.data.data])
        datadesc.transient = False
        if not full_data:
            datadesc.shape = edge.data.subset.size()
        output_names.append(
            nsdfg.add_datadesc(name, datadesc, find_new_name=True))
    # NOTE(review): empty-memlet edges are skipped above but remain in
    # ``inputs``/``outputs``; the ``zip(input_names, inputs)`` /
    # ``zip(output_names, outputs)`` loops below then pair names with the
    # wrong edges when an empty edge precedes a non-empty one — confirm.
    ###################

    # Add scope symbols to the nested SDFG
    # (assumes ``scope.defined_vars`` holds names comparable to keys of
    # ``sdfg.symbols`` — TODO confirm)
    for v in scope.defined_vars:
        if v in sdfg.symbols:
            sym = sdfg.symbols[v]
            nsdfg.add_symbol(v, sym.dtype)

    # Create nested state
    nstate = nsdfg.add_state()

    # Add subgraph nodes and edges to nested state (memlet objects are shared,
    # not copied)
    nstate.add_nodes_from(subgraph.nodes())
    for e in subgraph.edges():
        nstate.add_edge(e.src, e.src_conn, e.dst, e.dst_conn, e.data)

    # Modify nested SDFG parents in subgraph
    for node in subgraph.nodes():
        if isinstance(node, nodes.NestedSDFG):
            node.sdfg.parent = nstate
            node.sdfg.parent_sdfg = nsdfg

    # Add access nodes and edges as necessary
    edges_to_offset = []
    for name, edge in zip(input_names, inputs):
        node = nstate.add_read(name)
        new_edge = copy.deepcopy(edge.data)
        new_edge.data = name
        edges_to_offset.append((edge,
                                nstate.add_edge(node, None, edge.dst,
                                                edge.dst_conn, new_edge)))
    for name, edge in zip(output_names, outputs):
        node = nstate.add_write(name)
        new_edge = copy.deepcopy(edge.data)
        new_edge.data = name
        edges_to_offset.append((edge,
                                nstate.add_edge(edge.src, edge.src_conn, node,
                                                None, new_edge)))

    # Offset memlet paths inside nested SDFG according to subsets: every edge
    # in the memlet tree is renamed to the nested container and, unless
    # ``full_data``, shifted by the original edge's subset.
    for original_edge, new_edge in edges_to_offset:
        for edge in nstate.memlet_tree(new_edge):
            edge.data.data = new_edge.data.data
            if not full_data:
                edge.data.subset.offset(original_edge.data.subset, True)

    # Add nested SDFG node to the input state
    nested_sdfg = state.add_nested_sdfg(nsdfg, None,
                                        set(input_names) | input_arrays,
                                        set(output_names) | output_arrays)

    # Reconnect memlets to nested SDFG
    for name, edge in zip(input_names, inputs):
        if full_data:
            data = Memlet.from_array(edge.data.data,
                                     sdfg.arrays[edge.data.data])
        else:
            data = edge.data
        state.add_edge(edge.src, edge.src_conn, nested_sdfg, name, data)
    for name, edge in zip(output_names, outputs):
        if full_data:
            data = Memlet.from_array(edge.data.data,
                                     sdfg.arrays[edge.data.data])
        else:
            data = edge.data
        state.add_edge(nested_sdfg, name, edge.dst, edge.dst_conn, data)

    # Connect access nodes to internal input/output data as necessary; empty
    # edges from the scope entry / to the scope exit keep them inside the scope.
    entry = scope.entry
    exit = scope.exit
    for name in input_arrays:
        node = state.add_read(name)
        if entry is not None:
            state.add_nedge(entry, node, EmptyMemlet())
        state.add_edge(node, None, nested_sdfg, name,
                       Memlet.from_array(name, sdfg.arrays[name]))
    for name in output_arrays:
        node = state.add_write(name)
        if exit is not None:
            state.add_nedge(node, exit, EmptyMemlet())
        state.add_edge(nested_sdfg, name, node, None,
                       Memlet.from_array(name, sdfg.arrays[name]))

    # Remove subgraph nodes from graph
    state.remove_nodes_from(subgraph.nodes())

    # Remove subgraph transients from top-level graph
    for transient in subgraph_transients:
        del sdfg.arrays[transient]

    return nested_sdfg
def _distr_matmult_block_sizes(sdfg: SDFG, state: SDFGState, block_sizes, tasklet_name: str):
    """
    Resolves one block-sizes operand of ``_distr_matmult``.

    :param block_sizes: Name of existing data, a ``(name, range)`` tuple, or a
                        sequence of sizes that is written into a new int32
                        transient by an initialization tasklet.
    :param tasklet_name: Label of the initialization tasklet.
    :return: Tuple ``(name, range or None, descriptor, access node)``.
    """
    if not isinstance(block_sizes, (list, tuple)):
        name = block_sizes
        desc = sdfg.arrays[name]
        return name, None, desc, state.add_read(name)
    if isinstance(block_sizes[0], str):
        name, subset = block_sizes
        desc = sdfg.arrays[name]
        return name, subset, desc, state.add_read(name)
    name, desc = sdfg.add_temp_transient((len(block_sizes), ), dtype=dace.int32)
    node = state.add_access(name)
    init_tasklet = state.add_tasklet(
        tasklet_name, {}, {'__out'},
        ";".join(["__out[{}] = {}".format(i, sz) for i, sz in enumerate(block_sizes)]))
    state.add_edge(init_tasklet, '__out', node, None, Memlet.from_array(name, desc))
    return name, None, desc, node


def _distr_matmult(pv: 'ProgramVisitor',
                   sdfg: SDFG,
                   state: SDFGState,
                   opa: str,
                   opb: str,
                   shape: Sequence[Union[sp.Expr, Number]],
                   a_block_sizes: Union[str, Sequence[Union[sp.Expr, Number]]] = None,
                   b_block_sizes: Union[str, Sequence[Union[sp.Expr, Number]]] = None,
                   c_block_sizes: Union[str, Sequence[Union[sp.Expr, Number]]] = None):
    """
    Adds a distributed matrix-matrix (``Pgemm``) or matrix-vector (``Pgemv``)
    product of ``opa`` and ``opb`` to ``state``.

    :param shape: Global sizes ``(M, N, K)`` for matrix-matrix, ``(M, N)``
                  for matrix-vector products.
    :param a_block_sizes: Block sizes of ``opa``: a data name, a
                          ``(name, range)`` tuple, or a sequence of sizes.
                          Defaults to the shape of ``opa``.
    :param b_block_sizes: Same as ``a_block_sizes`` for ``opb``.
    :param c_block_sizes: Optional sequence whose first entry sets the local
                          length of a matrix-vector result.
    :return: Name of the new transient holding the result.
    :raise NotImplementedError: If the operand ranks are not 2x2, 2x1 or 1x2.
    :raise ValueError: If a matrix-matrix product is requested without K.
    """
    arra = sdfg.arrays[opa]
    arrb = sdfg.arrays[opb]

    # Validate before adding anything to the SDFG.
    ndims = (len(arra.shape), len(arrb.shape))
    if ndims not in ((2, 2), (2, 1), (1, 2)):
        raise NotImplementedError(
            'Distributed multiplication of {}-D and {}-D operands is not supported'.format(*ndims))

    if len(shape) == 3:
        gm, gn, gk = shape
    else:
        gm, gn = shape
        if ndims == (2, 2):
            raise ValueError('Distributed matrix-matrix multiplication requires a 3-element shape (M, N, K)')

    # Pad single-entry size sequences to two entries. Only sequences are
    # padded; a data name (string) is left untouched.
    a_block_sizes = a_block_sizes or arra.shape
    if isinstance(a_block_sizes, (list, tuple)) and len(a_block_sizes) < 2:
        a_block_sizes = (a_block_sizes[0], 1)
    b_block_sizes = b_block_sizes or arrb.shape
    if isinstance(b_block_sizes, (list, tuple)) and len(b_block_sizes) < 2:
        b_block_sizes = (b_block_sizes[0], 1)

    # Vector-matrix is computed as a transposed matrix-vector product below,
    # so the operands (and their block sizes) trade places.
    if ndims == (1, 2):
        a_block_sizes, b_block_sizes = b_block_sizes, a_block_sizes

    # Both operands go through the same helper (the previous copy-pasted B
    # path tested ``a_block_sizes`` and dropped B's range).
    a_bsizes_name, a_bsizes_range, a_bsizes_desc, a_bsizes_node = _distr_matmult_block_sizes(
        sdfg, state, a_block_sizes, '_set_a_bsizes_')
    b_bsizes_name, b_bsizes_range, b_bsizes_desc, b_bsizes_node = _distr_matmult_block_sizes(
        sdfg, state, b_block_sizes, '_set_b_sizes_')

    if ndims == (2, 2):
        # Gemm
        from dace.libraries.pblas.nodes.pgemm import Pgemm
        tasklet = Pgemm("__DistrMatMult__", gm, gn, gk)
        m = arra.shape[0]
        n = arrb.shape[-1]
        out = sdfg.add_temp_transient((m, n), dtype=arra.dtype)
    elif ndims == (2, 1):
        # Gemv
        from dace.libraries.pblas.nodes.pgemv import Pgemv
        tasklet = Pgemv("__DistrMatVecMult__", m=gm, n=gn)
        if c_block_sizes:
            m = c_block_sizes[0]
        else:
            m = arra.shape[0]
        out = sdfg.add_temp_transient((m, ), dtype=arra.dtype)
    else:
        # Gemv transposed
        # Swap a and b
        opa, opb = opb, opa
        arra, arrb = arrb, arra
        from dace.libraries.pblas.nodes.pgemv import Pgemv
        tasklet = Pgemv("__DistrMatVecMult__", transa='T', m=gm, n=gn)
        if c_block_sizes:
            n = c_block_sizes[0]
        else:
            n = arra.shape[1]
        out = sdfg.add_temp_transient((n, ), dtype=arra.dtype)

    anode = state.add_read(opa)
    bnode = state.add_read(opb)
    cnode = state.add_write(out[0])

    if a_bsizes_range:
        a_bsizes_mem = Memlet.simple(a_bsizes_name, a_bsizes_range)
    else:
        a_bsizes_mem = Memlet.from_array(a_bsizes_name, a_bsizes_desc)
    if b_bsizes_range:
        b_bsizes_mem = Memlet.simple(b_bsizes_name, b_bsizes_range)
    else:
        b_bsizes_mem = Memlet.from_array(b_bsizes_name, b_bsizes_desc)

    state.add_edge(anode, None, tasklet, '_a', Memlet.from_array(opa, arra))
    state.add_edge(bnode, None, tasklet, '_b', Memlet.from_array(opb, arrb))
    state.add_edge(a_bsizes_node, None, tasklet, '_a_block_sizes', a_bsizes_mem)
    state.add_edge(b_bsizes_node, None, tasklet, '_b_block_sizes', b_bsizes_mem)
    state.add_edge(tasklet, '_c', cnode, None, Memlet.from_array(*out))

    return out[0]
def nest_state_subgraph(sdfg: SDFG,
                        state: SDFGState,
                        subgraph: SubgraphView,
                        name: Optional[str] = None,
                        full_data: bool = False) -> nodes.NestedSDFG:
    """ Turns a state subgraph into a nested SDFG. Operates in-place.

        :param sdfg: The SDFG containing the state subgraph.
        :param state: The state containing the subgraph.
        :param subgraph: Subgraph to nest.
        :param name: An optional name for the nested SDFG.
        :param full_data: If True, nests entire input/output data.
        :return: The nested SDFG node.
        :raise KeyError: Some or all nodes in the subgraph are not located in
                         this state, or the state does not belong to the given
                         SDFG.
        :raise ValueError: The subgraph is contained in more than one scope.
    """
    if state.parent != sdfg:
        raise KeyError('State does not belong to given SDFG')
    # The whole state is also accepted as the subgraph.
    if subgraph is not state and subgraph.graph is not state:
        raise KeyError('Subgraph does not belong to given state')

    # Find the top-level scope
    scope_tree = state.scope_tree()
    scope_dict = state.scope_dict()
    scope_dict_children = state.scope_children()
    top_scopenode = -1  # Initialized to -1 since "None" already means top-level

    for node in subgraph.nodes():
        if node not in scope_dict:
            raise KeyError('Node not found in state')

        # If scope entry/exit, ensure entire scope is in subgraph
        if isinstance(node, nodes.EntryNode):
            scope_nodes = scope_dict_children[node]
            if any(n not in subgraph.nodes() for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (entry)')
        elif isinstance(node, nodes.ExitNode):
            entry = state.entry_node(node)
            scope_nodes = scope_dict_children[entry] + [entry]
            if any(n not in subgraph.nodes() for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (exit)')

        # All nodes whose enclosing scope lies outside the subgraph must share
        # that scope; it becomes the "top" scope of the nested SDFG.
        scope_node = scope_dict[node]
        if scope_node not in subgraph.nodes():
            if top_scopenode != -1 and top_scopenode != scope_node:
                raise ValueError('Subgraph is contained in more than one scope')
            top_scopenode = scope_node

    scope = scope_tree[top_scopenode]
    ###

    # Consolidate edges in top scope
    utils.consolidate_edges(sdfg, scope)
    snodes = subgraph.nodes()

    # Collect inputs and outputs of the nested SDFG: every edge that crosses
    # the subgraph boundary (not only those of source/sink nodes).
    inputs: List[MultiConnectorEdge] = []
    outputs: List[MultiConnectorEdge] = []
    for node in snodes:
        for edge in state.in_edges(node):
            if edge.src not in snodes:
                inputs.append(edge)
        for edge in state.out_edges(node):
            if edge.dst not in snodes:
                outputs.append(edge)

    # Collect transients not used outside of subgraph (will be removed of
    # top-level graph)
    data_in_subgraph = set(n.data for n in subgraph.nodes() if isinstance(n, nodes.AccessNode))
    # Find other occurrences in SDFG
    other_nodes = set(n.data for s in sdfg.nodes() for n in s.nodes()
                      if isinstance(n, nodes.AccessNode) and n not in subgraph.nodes())
    subgraph_transients = set()
    for data in data_in_subgraph:
        datadesc = sdfg.arrays[data]
        if datadesc.transient and data not in other_nodes:
            subgraph_transients.add(data)

    # All transients of edges between code nodes are also added to nested graph
    for edge in subgraph.edges():
        if (isinstance(edge.src, nodes.CodeNode) and isinstance(edge.dst, nodes.CodeNode)):
            subgraph_transients.add(edge.data.data)

    # Collect data used in access nodes within subgraph (will be referenced in
    # full upon nesting). For outputs, the WCR of the node's first in-edge is
    # remembered so the outer write edge can reuse it.
    # NOTE(review): ``state.in_edges(node)[0]`` may be an empty-memlet edge
    # when a node has several in-edges — confirm that is acceptable.
    input_arrays = set()
    output_arrays = {}
    for node in subgraph.nodes():
        if (isinstance(node, nodes.AccessNode) and node.data not in subgraph_transients):
            if node.has_reads(state):
                input_arrays.add(node.data)
            if node.has_writes(state):
                output_arrays[node.data] = state.in_edges(node)[0].data.wcr

    # Create the nested SDFG
    nsdfg = SDFG(name or 'nested_' + state.label)

    # Transients are added to the nested graph as-is
    # (note: the loop variable reuses ``name``; the parameter was consumed above)
    for name in subgraph_transients:
        nsdfg.add_datadesc(name, sdfg.arrays[name])

    # Input/output data that are not source/sink nodes are added to the graph
    # as non-transients
    for name in (input_arrays | output_arrays.keys()):
        datadesc = copy.deepcopy(sdfg.arrays[name])
        datadesc.transient = False
        nsdfg.add_datadesc(name, datadesc)

    # Connected source/sink nodes outside subgraph become global data
    # descriptors in nested SDFG. ``global_subsets`` maps each outer data name
    # to (nested name, subset); when several boundary edges touch the same
    # data, their subsets are unioned (falling back to the full array when
    # ``union`` returns None) and the nested shape is resized accordingly.
    # Inputs and outputs share this map, so the same outer container reuses
    # one nested descriptor.
    input_names = {}
    output_names = {}
    global_subsets: Dict[str, Tuple[str, Subset]] = {}
    for edge in inputs:
        if edge.data.data is None:  # Skip edges with an empty memlet
            continue
        name = edge.data.data
        if name not in global_subsets:
            datadesc = copy.deepcopy(sdfg.arrays[edge.data.data])
            datadesc.transient = False
            if not full_data:
                datadesc.shape = edge.data.subset.size()
            new_name = nsdfg.add_datadesc(name, datadesc, find_new_name=True)
            global_subsets[name] = (new_name, edge.data.subset)
        else:
            new_name, subset = global_subsets[name]
            if not full_data:
                new_subset = union(subset, edge.data.subset)
                if new_subset is None:
                    new_subset = Range.from_array(sdfg.arrays[name])
                global_subsets[name] = (new_name, new_subset)
                nsdfg.arrays[new_name].shape = new_subset.size()
        input_names[edge] = new_name
    for edge in outputs:
        if edge.data.data is None:  # Skip edges with an empty memlet
            continue
        name = edge.data.data
        if name not in global_subsets:
            datadesc = copy.deepcopy(sdfg.arrays[edge.data.data])
            datadesc.transient = False
            if not full_data:
                datadesc.shape = edge.data.subset.size()
            new_name = nsdfg.add_datadesc(name, datadesc, find_new_name=True)
            global_subsets[name] = (new_name, edge.data.subset)
        else:
            new_name, subset = global_subsets[name]
            if not full_data:
                new_subset = union(subset, edge.data.subset)
                if new_subset is None:
                    new_subset = Range.from_array(sdfg.arrays[name])
                global_subsets[name] = (new_name, new_subset)
                nsdfg.arrays[new_name].shape = new_subset.size()
        output_names[edge] = new_name
    ###################

    # Add scope symbols to the nested SDFG: symbols defined at the top scope
    # plus SDFG-level symbols, restricted to those present in ``sdfg.symbols``.
    defined_vars = set(
        symbolic.pystr_to_symbolic(s)
        for s in (state.symbols_defined_at(top_scopenode).keys() | sdfg.symbols))
    for v in defined_vars:
        if v in sdfg.symbols:
            sym = sdfg.symbols[v]
            nsdfg.add_symbol(v, sym.dtype)

    # Add constants to nested SDFG
    for cstname, cstval in sdfg.constants.items():
        nsdfg.add_constant(cstname, cstval)

    # Create nested state
    nstate = nsdfg.add_state()

    # Add subgraph nodes and edges to nested state (memlets are deep-copied so
    # the offsetting below does not alter the outer edges)
    nstate.add_nodes_from(subgraph.nodes())
    for e in subgraph.edges():
        nstate.add_edge(e.src, e.src_conn, e.dst, e.dst_conn, copy.deepcopy(e.data))

    # Modify nested SDFG parents in subgraph
    for node in subgraph.nodes():
        if isinstance(node, nodes.NestedSDFG):
            node.sdfg.parent = nstate
            node.sdfg.parent_sdfg = nsdfg
            node.sdfg.parent_nsdfg_node = node

    # Add access nodes and edges as necessary
    edges_to_offset = []
    for edge, name in input_names.items():
        node = nstate.add_read(name)
        new_edge = copy.deepcopy(edge.data)
        new_edge.data = name
        edges_to_offset.append((edge, nstate.add_edge(node, None, edge.dst, edge.dst_conn, new_edge)))
    for edge, name in output_names.items():
        node = nstate.add_write(name)
        new_edge = copy.deepcopy(edge.data)
        new_edge.data = name
        edges_to_offset.append((edge, nstate.add_edge(edge.src, edge.src_conn, node, None, new_edge)))

    # Offset memlet paths inside nested SDFG according to subsets: rename each
    # memlet-tree edge to the nested container and, unless ``full_data``,
    # shift it by the (possibly unioned) global subset.
    for original_edge, new_edge in edges_to_offset:
        for edge in nstate.memlet_tree(new_edge):
            edge.data.data = new_edge.data.data
            if not full_data:
                edge.data.subset.offset(global_subsets[original_edge.data.data][1], True)

    # Add nested SDFG node to the input state
    nested_sdfg = state.add_nested_sdfg(nsdfg, None,
                                        set(input_names.values()) | input_arrays,
                                        set(output_names.values()) | output_arrays.keys())

    # Reconnect memlets to nested SDFG: one outer edge per nested connector
    # (duplicates after consolidation are skipped). The last empty-memlet
    # edge on each side is remembered for the fallback below.
    reconnected_in = set()
    reconnected_out = set()
    empty_input = None
    empty_output = None
    for edge in inputs:
        if edge.data.data is None:
            empty_input = edge
            continue

        name = input_names[edge]
        if name in reconnected_in:
            continue
        if full_data:
            data = Memlet.from_array(edge.data.data, sdfg.arrays[edge.data.data])
        else:
            data = copy.deepcopy(edge.data)
            data.subset = global_subsets[edge.data.data][1]
        state.add_edge(edge.src, edge.src_conn, nested_sdfg, name, data)
        reconnected_in.add(name)

    for edge in outputs:
        if edge.data.data is None:
            empty_output = edge
            continue

        name = output_names[edge]
        if name in reconnected_out:
            continue
        if full_data:
            data = Memlet.from_array(edge.data.data, sdfg.arrays[edge.data.data])
        else:
            data = copy.deepcopy(edge.data)
            data.subset = global_subsets[edge.data.data][1]
        data.wcr = edge.data.wcr
        state.add_edge(nested_sdfg, name, edge.dst, edge.dst_conn, data)
        reconnected_out.add(name)

    # Connect access nodes to internal input/output data as necessary; empty
    # edges from the scope entry / to the scope exit keep them inside the scope.
    entry = scope.entry
    exit = scope.exit
    for name in input_arrays:
        node = state.add_read(name)
        if entry is not None:
            state.add_nedge(entry, node, Memlet())
        state.add_edge(node, None, nested_sdfg, name, Memlet.from_array(name, sdfg.arrays[name]))
    for name, wcr in output_arrays.items():
        node = state.add_write(name)
        if exit is not None:
            state.add_nedge(node, exit, Memlet())
        state.add_edge(nested_sdfg, name, node, None, Memlet(data=name, wcr=wcr))

    # Graph was not reconnected, but needs to be
    if state.in_degree(nested_sdfg) == 0 and empty_input is not None:
        state.add_edge(empty_input.src, empty_input.src_conn, nested_sdfg, None, empty_input.data)
    if state.out_degree(nested_sdfg) == 0 and empty_output is not None:
        state.add_edge(nested_sdfg, None, empty_output.dst, empty_output.dst_conn, empty_output.data)

    # Remove subgraph nodes from graph
    state.remove_nodes_from(subgraph.nodes())

    # Remove subgraph transients from top-level graph
    for transient in subgraph_transients:
        del sdfg.arrays[transient]

    # Remove newly isolated nodes due to memlet consolidation
    # NOTE(review): if two boundary edges share an isolated endpoint, the
    # second iteration queries a node that was already removed — confirm.
    for edge in inputs:
        if state.in_degree(edge.src) + state.out_degree(edge.src) == 0:
            state.remove_node(edge.src)
    for edge in outputs:
        if state.in_degree(edge.dst) + state.out_degree(edge.dst) == 0:
            state.remove_node(edge.dst)

    return nested_sdfg
# Thread-local stream test SDFG: each iteration of map 'par' writes its index
# into a local stream ('ls'), which feeds a local array ('la'); 'la' is written
# through the map exit into a global stream ('gs'), which is finally drained
# into the array 'ga'.
N = dp.symbol('N')

sdfg = SDFG('tlstream')
state = sdfg.add_state('doit')

# Data containers: 'la', 'ls' and 'gs' are transient; 'ga' (length N) is added
# without transient=True.
localarr = state.add_transient('la', [10], dp.float32)
localstream = state.add_stream('ls', dp.float32, 1, transient=True)
globalstream = state.add_stream('gs', dp.float32, 1, transient=True)
globalarr = state.add_array('ga', [N], dp.float32)

# Map over i in 0:N; the input-less tasklet writes the iteration index to 'a'.
me, mx = state.add_map('par', dict(i='0:N'))
tasklet = state.add_tasklet('arange', set(), {'a'}, 'a = i')

# Empty memlet from the map entry; presumably only there to place the
# input-less tasklet inside the map scope.
state.add_nedge(me, tasklet, EmptyMemlet())
# Inside the map: tasklet -> 'ls' -> 'la'.
state.add_edge(tasklet, 'a', localstream, None,
               Memlet.from_array(localstream.data, localstream.desc(sdfg)))
state.add_nedge(localstream, localarr,
                Memlet.from_array(localarr.data, localarr.desc(sdfg)))
# 'la' -> map exit -> 'gs'. NOTE(review): the memlet leaving 'la' is expressed
# in terms of the destination stream 'gs', not 'la'; presumably intentional —
# confirm.
state.add_nedge(localarr, mx,
                Memlet.from_array(globalstream.data, globalstream.desc(sdfg)))
state.add_nedge(mx, globalstream,
                Memlet.from_array(globalstream.data, globalstream.desc(sdfg)))
# Outside the map: 'gs' -> 'ga'.
state.add_nedge(globalstream, globalarr,
                Memlet.from_array(globalarr.data, globalarr.desc(sdfg)))

# Let DaCe fill in the map entry/exit connectors implied by the edges above.
sdfg.fill_scope_connectors()

if __name__ == '__main__':
    print('Thread-local stream test')
    # Give the symbolic size N a concrete value.
    N.set(20)
def make_sdfg(implementation, dtype, storage=dace.StorageType.Default):
    """Build a single-state SDFG that applies the LAPACK ``Getrf`` node to an n x n array ``x``.

    Args:
        implementation: Implementation name assigned to the ``Getrf`` library node.
        dtype: Element type of ``x`` (also passed to the transpose nodes).
        storage: Storage type of the working containers.
            - With the default storage, ``Getrf`` reads and writes ``x`` directly,
              and ``pivots``/``result`` are non-transient.
            - Any other storage creates transient ``*_device`` containers. ``x`` is
              copied into ``x_device`` and transposed by cuBLAS ``Transpose`` nodes
              into ``xt_device``, which ``Getrf`` operates on. The result is then
              transposed back and copied out to ``x``.

    Returns:
        The SDFG, named ``matrix_lufact_getrf_<implementation>_<dtype>``.
    """
    n = dace.symbol("n")
    on_device = storage != dace.StorageType.Default
    suffix = "_device" if on_device else ""

    sdfg = dace.SDFG(f"matrix_lufact_getrf_{implementation}_{dtype!s}")
    state = sdfg.add_state("dataflow")

    host_x = sdfg.add_array("x", [n, n], dtype, storage=dace.StorageType.Default)
    if on_device:
        # Device copy of x and its transpose (the transpose is what Getrf sees).
        dev_x, dev_xt = [
            sdfg.add_array(prefix + suffix, [n, n], dtype, storage=storage, transient=on_device)
            for prefix in ("x", "xt")
        ]
    for prefix, shape in (("pivots", [n]), ("result", [1])):
        sdfg.add_array(prefix + suffix, shape, dace.dtypes.int32, storage=storage, transient=on_device)

    if on_device:
        host_in = state.add_read("x")
        host_out = state.add_write("x")
        dev_in = state.add_access("x" + suffix)
        dev_out = state.add_access("x" + suffix)
        xin = state.add_access("xt" + suffix)
        xout = state.add_access("xt" + suffix)

        def cublas_transpose(label):
            # Transpose library node pinned to the cuBLAS implementation.
            node = blas.nodes.transpose.Transpose(label, dtype=dtype)
            node.implementation = "cuBLAS"
            return node

        transpose_in = cublas_transpose("transpose_in")
        transpose_out = cublas_transpose("transpose_out")

        # Host <-> device copies of x.
        state.add_nedge(host_in, dev_in, Memlet.from_array(*host_x))
        state.add_nedge(dev_out, host_out, Memlet.from_array(*host_x))
        # x_device -> transpose -> xt_device (Getrf input) ...
        state.add_memlet_path(dev_in, transpose_in, dst_conn='_inp', memlet=Memlet.from_array(*dev_x))
        state.add_memlet_path(transpose_in, xin, src_conn='_out', memlet=Memlet.from_array(*dev_xt))
        # ... and xt_device (Getrf output) -> transpose -> x_device.
        state.add_memlet_path(xout, transpose_out, dst_conn='_inp', memlet=Memlet.from_array(*dev_xt))
        state.add_memlet_path(transpose_out, dev_out, src_conn='_out', memlet=Memlet.from_array(*dev_x))
    else:
        xin = state.add_access("x" + suffix)
        xout = state.add_access("x" + suffix)
    pivots = state.add_access("pivots" + suffix)
    result = state.add_access("result" + suffix)

    getrf_node = lapack.nodes.getrf.Getrf("getrf")
    getrf_node.implementation = implementation

    state.add_memlet_path(xin, getrf_node, dst_conn="_xin",
                          memlet=Memlet.simple(xin, "0:n, 0:n", num_accesses=n * n))
    # Getrf outputs: status scalar, pivot vector, factorized matrix.
    for conn, node, subset, count in (("_res", result, "0", 1),
                                      ("_ipiv", pivots, "0:n", n),
                                      ("_xout", xout, "0:n, 0:n", n * n)):
        state.add_memlet_path(getrf_node, node, src_conn=conn,
                              memlet=Memlet.simple(node, subset, num_accesses=count))

    return sdfg
def _isend(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, buffer: str, dst: Union[str, sp.Expr, Number],
           tag: Union[str, sp.Expr, Number], request: str):
    """Add an MPI ``Isend`` library node, with its input/output access nodes, to ``state``.

    ``buffer`` and ``request`` may each be a data name or a ``(name, range)``
    tuple that selects a subset of that container.

    Args:
        pv: The calling program visitor. It is passed to ``_define_local_scalar``
            when a constant operand needs a temporary scalar.
        sdfg: SDFG owning ``state`` and the referenced containers.
        state: State in which all nodes and edges are created.
        buffer: Container read through the node's ``_buffer`` connector.
        dst: Destination operand for the ``_dest`` connector. It is handled as
            follows:
            - a ``(name, range)`` tuple is read directly;
            - the name of an existing container is read as element ``0``;
            - any other value is written into a new int32 scalar by a
              ``_set_dst_`` tasklet.
        tag: Operand for the ``_tag`` connector, resolved exactly like ``dst``.
            Its constant tasklet is named ``_set_tag_``.
        request: Container written through the node's ``_request`` connector.

    Returns:
        None. The graph is modified in place.
    """
    from dace.libraries.mpi.nodes.isend import Isend

    libnode = Isend('_Isend_')

    buf_range = None
    if isinstance(buffer, tuple):
        buf_name, buf_range = buffer
    else:
        buf_name = buffer
    desc = sdfg.arrays[buf_name]
    buf_node = state.add_read(buf_name)

    req_range = None
    if isinstance(request, tuple):
        req_name, req_range = request
    else:
        req_name = request
    req_desc = sdfg.arrays[req_name]
    req_node = state.add_write(req_name)

    # Retype the buffer and request connectors as pointers to the element types
    # of the connected containers; leave every other connector unchanged.
    libnode.in_connectors = {
        c: (dtypes.pointer(desc.dtype) if c == '_buffer' else t)
        for c, t in libnode.in_connectors.items()
    }
    libnode.out_connectors = {
        c: (dtypes.pointer(req_desc.dtype) if c == '_request' else t)
        for c, t in libnode.out_connectors.items()
    }

    def _scalar_operand(value, tasklet_name):
        """Resolve an integer operand to ``(data name, access node, range or None)``.

        The three cases are mutually exclusive. A ``(name, range)`` tuple must
        never also be treated as a constant.
        """
        if isinstance(value, tuple):
            name, subset = value
            return name, state.add_read(name), subset
        if isinstance(value, str) and value in sdfg.arrays.keys():
            return value, state.add_read(value), None
        # Anything else (e.g. a number, a symbolic expression, or a string that
        # is not a container name) is materialized into a fresh int32 scalar by
        # a tasklet that assigns it.
        name = _define_local_scalar(pv, sdfg, state, dace.int32, desc.storage)
        node = state.add_access(name)
        tasklet = state.add_tasklet(tasklet_name, {}, {'__out'}, '__out = {}'.format(value))
        state.add_edge(tasklet, '__out', node, None, Memlet.simple(name, '0'))
        return name, node, None

    # dst and tag share one resolver so their handling cannot diverge.
    dst_name, dst_node, dst_range = _scalar_operand(dst, '_set_dst_')
    tag_name, tag_node, tag_range = _scalar_operand(tag, '_set_tag_')

    buf_mem = Memlet.simple(buf_name, buf_range) if buf_range else Memlet.from_array(buf_name, desc)
    req_mem = Memlet.simple(req_name, req_range) if req_range else Memlet.from_array(req_name, req_desc)
    dst_mem = Memlet.simple(dst_name, dst_range if dst_range else '0')
    tag_mem = Memlet.simple(tag_name, tag_range if tag_range else '0')

    state.add_edge(buf_node, None, libnode, '_buffer', buf_mem)
    state.add_edge(dst_node, None, libnode, '_dest', dst_mem)
    state.add_edge(tag_node, None, libnode, '_tag', tag_mem)
    state.add_edge(libnode, '_request', req_node, None, req_mem)

    return None