def gemv_libnode(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, A, x, y, alpha, beta, trans=None):
    """Build a GEMV library-node subgraph: y = alpha * op(A) @ x + beta * y."""
    # Infer transposition from operand shapes when not given explicitly.
    if trans is None:
        trans = sdfg.arrays[x].shape[0] == sdfg.arrays[A].shape[0]

    # Access nodes for the operands and the result.
    mat_node = state.add_read(A)
    vec_node = state.add_read(x)
    res_node = state.add_write(y)

    libnode = Gemv('gemv', transA=trans, alpha=alpha, beta=beta)
    state.add_node(libnode)

    # Wire operands into the library node and the result out of it.
    state.add_edge(mat_node, None, libnode, '_A', mm.Memlet(A))
    state.add_edge(vec_node, None, libnode, '_x', mm.Memlet(x))
    state.add_edge(libnode, '_y', res_node, None, mm.Memlet(y))

    # With a nonzero beta, y is also consumed as an input.
    if beta != 0:
        state.add_edge(state.add_read(y), None, libnode, '_y', mm.Memlet(y))

    return []
def _gather(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, in_buffer: str, out_buffer: str, root: Union[str, sp.Expr, Number] = 0):
    """Insert an MPI Gather library node collecting `in_buffer` into `out_buffer` on rank `root`."""
    from dace.libraries.mpi.nodes.gather import Gather

    gather_node = Gather('_Gather_')
    src_desc = sdfg.arrays[in_buffer]
    dst_desc = sdfg.arrays[out_buffer]
    src_node = state.add_read(in_buffer)
    dst_node = state.add_write(out_buffer)

    if isinstance(root, str) and root in sdfg.arrays.keys():
        # The root rank already lives in a data container; read it directly.
        root_node = state.add_read(root)
    else:
        # Materialize the constant/symbolic root rank into a fresh local scalar.
        root_name = _define_local_scalar(pv, sdfg, state, dace.int32, src_desc.storage)
        root_node = state.add_access(root_name)
        setter = state.add_tasklet('_set_root_', {}, {'__out'}, '__out = {}'.format(root))
        state.add_edge(setter, '__out', root_node, None, Memlet.simple(root_name, '0'))

    state.add_edge(src_node, None, gather_node, '_inbuffer', Memlet.from_array(in_buffer, src_desc))
    state.add_edge(root_node, None, gather_node, '_root', Memlet.simple(root_node.data, '0'))
    state.add_edge(gather_node, '_outbuffer', dst_node, None, Memlet.from_array(out_buffer, dst_desc))
    return None
def _Reduce(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, buffer: str, op: str, root: Union[str, sp.Expr, Number] = 0, grid: str = None):
    """Insert an MPI Reduce library node reducing `buffer` in-place with `op` onto rank `root`."""
    from dace.libraries.mpi.nodes.reduce import Reduce

    reduce_node = Reduce('_Reduce_', op, grid)
    desc = sdfg.arrays[buffer]
    src_node = state.add_read(buffer)
    dst_node = state.add_write(buffer)

    if isinstance(root, str) and root in sdfg.arrays.keys():
        # The root rank already lives in a data container; read it directly.
        root_node = state.add_read(root)
    else:
        # Materialize the constant/symbolic root rank into a fresh local scalar.
        root_name = _define_local_scalar(pv, sdfg, state, dace.int32, desc.storage)
        root_node = state.add_access(root_name)
        setter = state.add_tasklet('_set_root_', {}, {'__out'}, '__out = {}'.format(root))
        state.add_edge(setter, '__out', root_node, None, Memlet.simple(root_name, '0'))

    state.add_edge(src_node, None, reduce_node, '_inbuffer', Memlet.from_array(buffer, desc))
    state.add_edge(root_node, None, reduce_node, '_root', Memlet.simple(root_node.data, '0'))
    state.add_edge(reduce_node, '_outbuffer', dst_node, None, Memlet.from_array(buffer, desc))
    return None
def _reduce(sdfg: SDFG, state: SDFGState, redfunction: Callable[[Any, Any], Any], in_array: str, out_array=None, axis=None, identity=None):
    """Create a reduction subgraph applying `redfunction` over `in_array`.

    :param redfunction: Binary reduction function (e.g., ``lambda a, b: a + b``).
    :param in_array: Name (or subset expression) of the input array.
    :param out_array: Optional name (or subset expression) of the output array;
                      if None, a transient output is created.
    :param axis: Axis or axes to reduce over; None reduces over all axes.
    :param identity: Optional identity value of the reduction.
    :return: Name of the new transient output if `out_array` is None,
             otherwise an empty list (as the output already exists).
    """
    # Both the with- and without-output paths previously duplicated the axis
    # normalization and input-memlet computation; hoist the common code.
    inarr = in_array

    # Normalize axes to a tuple of symbolic expressions.
    if axis is not None and not isinstance(axis, (tuple, list)):
        axis = (axis, )
    if axis is not None:
        axis = tuple(pystr_to_symbolic(a) for a in axis)

    # Compute the input memlet (the name may carry a subset expression).
    input_subset = parse_memlet_subset(sdfg.arrays[inarr], ast.parse(in_array).body[0].value, {})
    input_memlet = Memlet.simple(inarr, input_subset)

    if out_array is None:
        # Create a transient output; reducing over all axes yields one element.
        if axis is None:
            output_shape = [1]
        else:
            output_subset = copy.deepcopy(input_subset)
            output_subset.pop(axis)
            output_shape = output_subset.size()
        outarr, arr = sdfg.add_temp_transient(output_shape, sdfg.arrays[inarr].dtype, sdfg.arrays[inarr].storage)
        output_memlet = Memlet.from_array(outarr, arr)
    else:
        # Output exists; parse its (possible) subset expression.
        outarr = out_array
        output_subset = parse_memlet_subset(sdfg.arrays[outarr], ast.parse(out_array).body[0].value, {})
        output_memlet = Memlet.simple(outarr, output_subset)

    # Create reduce subgraph
    inpnode = state.add_read(inarr)
    rednode = state.add_reduce(redfunction, axis, identity)
    outnode = state.add_write(outarr)
    state.add_nedge(inpnode, rednode, input_memlet)
    state.add_nedge(rednode, outnode, output_memlet)

    if out_array is None:
        return outarr
    else:
        return []
def apply(self, state: SDFGState, sdfg: SDFG) -> nodes.AccessNode: access: nodes.AccessNode = self.access # Get memlet paths first_edge = state.in_edges(access)[0] second_edge = state.out_edges(access)[0] first_mpath = state.memlet_path(first_edge) second_mpath = state.memlet_path(second_edge) # Create new stream of shape 1 desc = sdfg.arrays[access.data] name, newdesc = sdfg.add_stream(access.data, desc.dtype, buffer_size=self.buffer_size, storage=self.storage, transient=True, find_new_name=True) # Remove transient array if possible for ostate in sdfg.nodes(): if ostate is state: continue if any(n.data == access.data for n in ostate.data_nodes()): break else: del sdfg.arrays[access.data] # Replace memlets in path with stream access for e in first_mpath: e.data = mm.Memlet(data=name, subset='0') if isinstance(e.src, nodes.NestedSDFG): e.data.dynamic = True _streamify_recursive(e.src, e.src_conn, newdesc) if isinstance(e.dst, nodes.NestedSDFG): e.data.dynamic = True _streamify_recursive(e.dst, e.dst_conn, newdesc) for e in second_mpath: e.data = mm.Memlet(data=name, subset='0') if isinstance(e.src, nodes.NestedSDFG): e.data.dynamic = True _streamify_recursive(e.src, e.src_conn, newdesc) if isinstance(e.dst, nodes.NestedSDFG): e.data.dynamic = True _streamify_recursive(e.dst, e.dst_conn, newdesc) # Replace array access node with two stream access nodes wnode = state.add_write(name) rnode = state.add_read(name) state.remove_edge(first_edge) state.add_edge(first_edge.src, first_edge.src_conn, wnode, first_edge.dst_conn, first_edge.data) state.remove_edge(second_edge) state.add_edge(rnode, second_edge.src_conn, second_edge.dst, second_edge.dst_conn, second_edge.data) # Remove original access node state.remove_node(access) return wnode, rnode
def nccl_send(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, in_buffer: str, peer: symbolic.SymbolicType = 0, group_handle: str = None):
    """Insert an NCCL Send library node sending `in_buffer` to rank `peer`.

    :param in_buffer: Name of the buffer descriptor, or a (name, range) tuple.
    :param peer: Destination rank.
    :param group_handle: Optional name of a group-handle scalar used to fuse
                         multiple NCCL calls into one group.
    :return: An empty list (the call has no return value).
    """
    inputs = {"_inbuffer"}
    outputs = set()

    if isinstance(group_handle, str):
        gh_start = False
        if group_handle in sdfg.arrays.keys():
            # Continue an existing group: the handle is both input and output.
            gh_name = group_handle
            gh_out = state.add_access(gh_name)
            gh_in = state.add_access(gh_name)
            inputs.add("_group_handle")
        else:
            # Start a new group: define the handle scalar on the GPU.
            gh_start = True
            gh_name = _define_local_scalar(pv, sdfg, state, dace.int32, dtypes.StorageType.GPU_Global)
            gh_out = state.add_access(gh_name)
            outputs.add("_group_handle")

    libnode = Send(inputs=inputs, outputs=outputs, peer=peer)

    if isinstance(group_handle, str):
        gh_memlet = Memlet.simple(gh_name, '0')
        if not gh_start:
            state.add_edge(gh_in, None, libnode, "_group_handle", gh_memlet)
        state.add_edge(libnode, "_group_handle", gh_out, None, gh_memlet)

    in_range = None
    if isinstance(in_buffer, tuple):
        in_name, in_range = in_buffer
    else:
        in_name = in_buffer
    desc = sdfg.arrays[in_name]

    # Set the buffer connector to a pointer type.
    # FIX: the input connector is named '_inbuffer' (see `inputs` above and the
    # edge below); the previous check against '_buffer' never matched, so the
    # pointer type was never applied.
    conn = libnode.in_connectors
    conn = {c: (dtypes.pointer(desc.dtype) if c == '_inbuffer' else t) for c, t in conn.items()}
    libnode.in_connectors = conn

    in_node = state.add_read(in_name)
    if in_range:
        buf_mem = Memlet.simple(in_name, in_range)
    else:
        buf_mem = Memlet.from_array(in_name, desc)
    state.add_edge(in_node, None, libnode, '_inbuffer', buf_mem)

    return []
def nccl_reduce(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, redfunction: Callable[[Any, Any], Any], in_buffer: str, out_buffer: Union[str, None] = None, root: str = None, group_handle: str = None):
    """Insert an NCCL Reduce library node; the reduction is in-place when `out_buffer` is None."""
    inputs = {"_inbuffer"}
    outputs = {"_outbuffer"}

    has_group_handle = isinstance(group_handle, str)
    if has_group_handle:
        gh_start = False
        if group_handle in sdfg.arrays.keys():
            # Continue an existing group: the handle is both input and output.
            gh_name = group_handle
            gh_out = state.add_access(gh_name)
            gh_in = state.add_access(gh_name)
            inputs.add("_group_handle")
        else:
            # Start a new group: define the handle scalar on the GPU.
            gh_start = True
            gh_name = _define_local_scalar(pv, sdfg, state, dace.int32, dtypes.StorageType.GPU_Global)
            gh_out = state.add_access(gh_name)
            outputs.add("_group_handle")

    libnode = Reduce(inputs=inputs, outputs=outputs, wcr=redfunction, root=root)

    if has_group_handle:
        gh_memlet = Memlet.simple(gh_name, '0')
        if not gh_start:
            state.add_edge(gh_in, None, libnode, "_group_handle", gh_memlet)
        state.add_edge(libnode, "_group_handle", gh_out, None, gh_memlet)

    # If out_buffer is not specified, the operation will be in-place.
    if out_buffer is None:
        out_buffer = in_buffer

    # Connect the data buffers to the library node.
    state.add_edge(state.add_read(in_buffer), None, libnode, '_inbuffer', Memlet(in_buffer))
    state.add_edge(libnode, '_outbuffer', state.add_write(out_buffer), None, Memlet(out_buffer))

    return []
def _block_gather(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, in_buffer: str, out_buffer: str, gather_grid: str, reduce_grid: str = None, correspondence: Sequence[Integral] = None):
    """ Block-gathers an Array using process-grids, sub-arrays, and the BlockGather library node.
        This method currently does not support Array slices and imperfect tiling.

        :param in_buffer: Name of the (local) Array descriptor.
        :param out_buffer: Name of the (global) Array descriptor.
        :param gather_grid: Name of the sub-grid used for gathering the Array (reduction group leaders).
        :param reduce_grid: Name of the sub-grid used for broadcasting the Array (reduction groups).
        :param correspondence: Matching of the array/sub-array's dimensions to the process-grid's dimensions.
        :return: Name of the new sub-array descriptor.
    """
    in_desc = sdfg.arrays[in_buffer]
    out_desc = sdfg.arrays[out_buffer]
    if in_desc.dtype != out_desc.dtype:
        raise ValueError("Input/output buffer datatypes must match!")

    # Register the sub-array describing the distribution of the global array.
    subarray_name = _subarray(pv, sdfg, state, out_buffer, in_buffer, process_grid=gather_grid, correspondence=correspondence)

    from dace.libraries.mpi import BlockGather
    libnode = BlockGather('_BlockGather_', subarray_name, gather_grid, reduce_grid)

    # Wire local input and global output through the library node.
    src_node = state.add_read(in_buffer)
    dst_node = state.add_write(out_buffer)
    state.add_edge(src_node, None, libnode, '_inp_buffer', Memlet.from_array(in_buffer, in_desc))
    state.add_edge(libnode, '_out_buffer', dst_node, None, Memlet.from_array(out_buffer, out_desc))

    return subarray_name
def ger_libnode(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, A, x, y, output, alpha):
    """Build a GER (rank-1 update) library-node subgraph in the given state."""
    # Access nodes for the three operands and the result.
    a_node, x_node, y_node = [state.add_read(n) for n in (A, x, y)]
    res_node = state.add_write(output)

    libnode = Ger('ger', alpha=alpha)
    state.add_node(libnode)

    # Wire operands into the library node and the result out of it.
    state.add_edge(a_node, None, libnode, '_A', mm.Memlet(A))
    state.add_edge(x_node, None, libnode, '_x', mm.Memlet(x))
    state.add_edge(y_node, None, libnode, '_y', mm.Memlet(y))
    state.add_edge(libnode, '_res', res_node, None, mm.Memlet(output))

    return []
def _transpose(sdfg: SDFG, state: SDFGState, inpname: str):
    """Add a BLAS Transpose library node producing a new transient with swapped dimensions."""
    in_desc = sdfg.arrays[inpname]
    dtype = in_desc.dtype

    # Output is a fresh transient with the two dimensions exchanged.
    outname, out_desc = sdfg.add_temp_transient((in_desc.shape[1], in_desc.shape[0]), dtype, in_desc.storage)

    read_node = state.add_read(inpname)
    write_node = state.add_write(outname)

    import dace.libraries.blas  # Avoid import loop
    trans_node = dace.libraries.blas.Transpose('_Transpose_', dtype)
    state.add_node(trans_node)

    state.add_edge(read_node, None, trans_node, '_inp', dace.Memlet.from_array(inpname, in_desc))
    state.add_edge(trans_node, '_out', write_node, None, dace.Memlet.from_array(outname, out_desc))

    return outname
def _simple_call(sdfg: SDFG, state: SDFGState, inpname: str, func: str, restype: dace.typeclass = None):
    """ Implements a simple call of the form `out = func(inp)`. """
    inparr = sdfg.arrays[inpname]
    if restype is None:
        restype = inparr.dtype
    outname, outarr = sdfg.add_temp_transient(inparr.shape, restype, inparr.storage)

    # The total element count decides between a scalar tasklet and a map.
    num_elements = reduce(lambda a, b: a * b, inparr.shape)
    if num_elements == 1:
        # Scalar case: a single tasklet suffices.
        read_node = state.add_read(inpname)
        write_node = state.add_write(outname)
        tasklet = state.add_tasklet(func, {'__inp'}, {'__out'}, '__out = {f}(__inp)'.format(f=func))
        state.add_edge(read_node, None, tasklet, '__inp', Memlet.from_array(inpname, inparr))
        state.add_edge(tasklet, '__out', write_node, None, Memlet.from_array(outname, outarr))
    else:
        # Elementwise map over the whole array.
        index_expr = ','.join(['__i%d' % i for i in range(len(inparr.shape))])
        state.add_mapped_tasklet(name=func,
                                 map_ranges={'__i%d' % i: '0:%s' % n for i, n in enumerate(inparr.shape)},
                                 inputs={'__inp': Memlet.simple(inpname, index_expr)},
                                 code='__out = {f}(__inp)'.format(f=func),
                                 outputs={'__out': Memlet.simple(outname, index_expr)},
                                 external_edges=True)

    return outname
def _Allreduce(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, buffer: str, op: str, grid: str = None):
    """Insert an MPI Allreduce library node reducing `buffer` in-place across all ranks."""
    from dace.libraries.mpi.nodes.allreduce import Allreduce

    allreduce_node = Allreduce('_Allreduce_', op, grid)
    desc = sdfg.arrays[buffer]

    # In-place collective: the same container is read and written.
    state.add_edge(state.add_read(buffer), None, allreduce_node, '_inbuffer', Memlet.from_array(buffer, desc))
    state.add_edge(allreduce_node, '_outbuffer', state.add_write(buffer), None, Memlet.from_array(buffer, desc))

    return None
def _array_x_binop(visitor: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, op1: str, op2: str, op: str, opcode: str):
    """Dispatch a binary operation between two operands, specializing the scalar-scalar case."""
    arr1 = sdfg.arrays[op1]
    type1 = arr1.dtype.type
    isscal1 = _is_scalar(sdfg, op1)
    isnum1 = isscal1 and (op1 in visitor.numbers.values())
    if isnum1:
        # Recover the original Python number's type for result-type inference.
        type1 = inverse_dict_lookup(visitor.numbers, op1)

    arr2 = sdfg.arrays[op2]
    type2 = arr2.dtype.type
    isscal2 = _is_scalar(sdfg, op2)
    isnum2 = isscal2 and (op2 in visitor.numbers.values())
    if isnum2:
        type2 = inverse_dict_lookup(visitor.numbers, op2)

    # Comparisons yield booleans; everything else follows NumPy type promotion.
    if _is_op_boolean(op):
        restype = dace.bool
    else:
        restype = dace.DTYPE_TO_TYPECLASS[np.result_type(type1, type2).type]

    # Non-scalar operands go through the generic element-wise path.
    if not (isscal1 and isscal2):
        return _binop(sdfg, state, op1, op2, opcode, op, restype)

    # Scalar-scalar: a single tasklet computing s3 = s1 <opcode> s2.
    op3, arr3 = sdfg.add_temp_transient([1], restype, arr2.storage)
    tasklet = state.add_tasklet('_SS%s_' % op, {'s1', 's2'}, {'s3'}, 's3 = s1 %s s2' % opcode)
    lhs_node = state.add_read(op1)
    rhs_node = state.add_read(op2)
    res_node = state.add_write(op3)
    state.add_edge(lhs_node, None, tasklet, 's1', dace.Memlet.from_array(op1, arr1))
    state.add_edge(rhs_node, None, tasklet, 's2', dace.Memlet.from_array(op2, arr2))
    state.add_edge(tasklet, 's3', res_node, None, dace.Memlet.from_array(op3, arr3))
    return op3
def _distr_matmult(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, opa: str, opb: str, shape: Sequence[Union[sp.Expr, Number]], a_block_sizes: Union[str, Sequence[Union[sp.Expr, Number]]] = None, b_block_sizes: Union[str, Sequence[Union[sp.Expr, Number]]] = None, c_block_sizes: Union[str, Sequence[Union[sp.Expr, Number]]] = None):
    """ Distributed matrix-matrix or matrix-vector multiplication through PBLAS library nodes.

        :param opa: Name of the left-hand side operand descriptor.
        :param opb: Name of the right-hand side operand descriptor.
        :param shape: Global shape (m, n) or (m, n, k) of the computation.
        :param a_block_sizes: Block sizes of the first operand: a descriptor name,
                              a (name, range) tuple, or a sequence of sizes.
        :param b_block_sizes: Block sizes of the second operand (same forms).
        :param c_block_sizes: Block sizes of the output.
        :return: Name of the new transient output descriptor.
        :raises ValueError: If the operand ranks are not (2, 2), (2, 1) or (1, 2).
    """
    arra = sdfg.arrays[opa]
    arrb = sdfg.arrays[opb]

    if len(shape) == 3:
        gm, gn, gk = shape
    else:
        gm, gn = shape

    # Default block sizes to the local operand shapes; pad vectors to 2D.
    a_block_sizes = a_block_sizes or arra.shape
    if len(a_block_sizes) < 2:
        a_block_sizes = (a_block_sizes[0], 1)
    b_block_sizes = b_block_sizes or arrb.shape
    if len(b_block_sizes) < 2:
        b_block_sizes = (b_block_sizes[0], 1)

    if len(arra.shape) == 1 and len(arrb.shape) == 2:
        # Transposed GEMV: the operands are swapped below, so swap the block
        # sizes here to keep them matched with the right operand.
        a_block_sizes, b_block_sizes = b_block_sizes, a_block_sizes

    def _block_sizes_subgraph(block_sizes, tasklet_name):
        """Create the access node (and setter tasklet if needed) for one block-sizes operand."""
        bsizes_range = None
        if isinstance(block_sizes, (list, tuple)):
            if isinstance(block_sizes[0], str):
                # (name, range) tuple referencing an existing descriptor.
                bsizes_name, bsizes_range = block_sizes
                bsizes_desc = sdfg.arrays[bsizes_name]
                bsizes_node = state.add_read(bsizes_name)
            else:
                # Literal sizes: materialize them into a transient int vector.
                bsizes_name, bsizes_desc = sdfg.add_temp_transient((len(block_sizes), ), dtype=dace.int32)
                bsizes_node = state.add_access(bsizes_name)
                bsizes_tasklet = state.add_tasklet(
                    tasklet_name, {}, {'__out'},
                    ";".join(["__out[{}] = {}".format(i, sz) for i, sz in enumerate(block_sizes)]))
                state.add_edge(bsizes_tasklet, '__out', bsizes_node, None,
                               Memlet.from_array(bsizes_name, bsizes_desc))
        else:
            # Name of an existing descriptor.
            bsizes_name = block_sizes
            bsizes_desc = sdfg.arrays[bsizes_name]
            bsizes_node = state.add_read(bsizes_name)
        return bsizes_name, bsizes_desc, bsizes_node, bsizes_range

    a_bsizes_name, a_bsizes_desc, a_bsizes_node, a_bsizes_range = _block_sizes_subgraph(a_block_sizes, '_set_a_bsizes_')
    # FIX: the second operand previously tested `a_block_sizes` instead of
    # `b_block_sizes` and unpacked into the misspelled `b_sizes_range`, so a
    # (name, range) tuple passed for `b_block_sizes` was silently dropped.
    b_bsizes_name, b_bsizes_desc, b_bsizes_node, b_bsizes_range = _block_sizes_subgraph(b_block_sizes, '_set_b_sizes_')

    if len(arra.shape) == 2 and len(arrb.shape) == 2:  # GEMM
        from dace.libraries.pblas.nodes.pgemm import Pgemm
        tasklet = Pgemm("__DistrMatMult__", gm, gn, gk)
        m = arra.shape[0]
        n = arrb.shape[-1]
        out = sdfg.add_temp_transient((m, n), dtype=arra.dtype)
    elif len(arra.shape) == 2 and len(arrb.shape) == 1:  # GEMV
        from dace.libraries.pblas.nodes.pgemv import Pgemv
        tasklet = Pgemv("__DistrMatVecMult__", m=gm, n=gn)
        if c_block_sizes:
            m = c_block_sizes[0]
        else:
            m = arra.shape[0]
        out = sdfg.add_temp_transient((m, ), dtype=arra.dtype)
    elif len(arra.shape) == 1 and len(arrb.shape) == 2:  # Transposed GEMV
        # Swap a and b
        opa, opb = opb, opa
        arra, arrb = arrb, arra
        from dace.libraries.pblas.nodes.pgemv import Pgemv
        tasklet = Pgemv("__DistrMatVecMult__", transa='T', m=gm, n=gn)
        if c_block_sizes:
            n = c_block_sizes[0]
        else:
            n = arra.shape[1]
        out = sdfg.add_temp_transient((n, ), dtype=arra.dtype)
    else:
        # Previously fell through with `tasklet`/`out` unbound, raising an
        # opaque UnboundLocalError; fail with a clear message instead.
        raise ValueError("Unsupported operand ranks for distributed matrix multiplication: "
                         "{} and {}".format(len(arra.shape), len(arrb.shape)))

    anode = state.add_read(opa)
    bnode = state.add_read(opb)
    cnode = state.add_write(out[0])

    if a_bsizes_range:
        a_bsizes_mem = Memlet.simple(a_bsizes_name, a_bsizes_range)
    else:
        a_bsizes_mem = Memlet.from_array(a_bsizes_name, a_bsizes_desc)
    if b_bsizes_range:
        b_bsizes_mem = Memlet.simple(b_bsizes_name, b_bsizes_range)
    else:
        b_bsizes_mem = Memlet.from_array(b_bsizes_name, b_bsizes_desc)

    state.add_edge(anode, None, tasklet, '_a', Memlet.from_array(opa, arra))
    state.add_edge(bnode, None, tasklet, '_b', Memlet.from_array(opb, arrb))
    state.add_edge(a_bsizes_node, None, tasklet, '_a_block_sizes', a_bsizes_mem)
    state.add_edge(b_bsizes_node, None, tasklet, '_b_block_sizes', b_bsizes_mem)
    state.add_edge(tasklet, '_c', cnode, None, Memlet.from_array(*out))

    return out[0]
def _redistribute(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, in_buffer: str, in_subarray: str, out_buffer: str, out_subarray: str):
    """ Redistributes an Array using process-grids, sub-arrays, and the Redistribute library node.

        :param in_buffer: Name of the (local) input Array descriptor.
        :param in_subarray: Input sub-array descriptor.
        :param out_buffer: Name of the (local) output Array descriptor.
        :param out_subarray: Output sub-array descriptor.
        :return: Name of the new redistribution descriptor.
    """
    in_desc = sdfg.arrays[in_buffer]
    out_desc = sdfg.arrays[out_buffer]

    # Register the redistribution descriptor and emit a dummy tasklet declaring
    # the MPI state that backs it in generated code.
    rdistrarray_name = sdfg.add_rdistrarray(in_subarray, out_subarray)
    from dace.libraries.mpi import Dummy, Redistribute
    tasklet = Dummy(rdistrarray_name, [
        f'MPI_Datatype {rdistrarray_name};',
        f'int {rdistrarray_name}_sends;',
        f'MPI_Datatype* {rdistrarray_name}_send_types;',
        f'int* {rdistrarray_name}_dst_ranks;',
        f'int {rdistrarray_name}_recvs;',
        f'MPI_Datatype* {rdistrarray_name}_recv_types;',
        f'int* {rdistrarray_name}_src_ranks;',
        f'int {rdistrarray_name}_self_copies;',
        f'int* {rdistrarray_name}_self_src;',
        f'int* {rdistrarray_name}_self_dst;',
        f'int* {rdistrarray_name}_self_size;'
    ])
    state.add_node(tasklet)

    # Dummy scalar output so the tasklet is connected to the dataflow.
    _, scal = sdfg.add_scalar(rdistrarray_name, dace.int32, transient=True)
    wnode = state.add_write(rdistrarray_name)
    state.add_edge(tasklet, '__out', wnode, None, Memlet.from_array(rdistrarray_name, scal))

    libnode = Redistribute('_Redistribute_', rdistrarray_name)

    # Buffers may arrive as (name, range) tuples denoting a subset access.
    if isinstance(in_buffer, tuple):
        inbuf_name, inbuf_range = in_buffer
    else:
        inbuf_name, inbuf_range = in_buffer, None
    in_desc = sdfg.arrays[inbuf_name]
    inbuf_node = state.add_read(inbuf_name)

    if isinstance(out_buffer, tuple):
        outbuf_name, outbuf_range = out_buffer
    else:
        outbuf_name, outbuf_range = out_buffer, None
    out_desc = sdfg.arrays[outbuf_name]
    outbuf_node = state.add_write(outbuf_name)

    inbuf_mem = Memlet.simple(inbuf_name, inbuf_range) if inbuf_range else Memlet.from_array(inbuf_name, in_desc)
    outbuf_mem = Memlet.simple(outbuf_name, outbuf_range) if outbuf_range else Memlet.from_array(outbuf_name, out_desc)

    state.add_edge(inbuf_node, None, libnode, '_inp_buffer', inbuf_mem)
    state.add_edge(libnode, '_out_buffer', outbuf_node, None, outbuf_mem)

    return rdistrarray_name
def apply(self, state: SDFGState, sdfg: SDFG) -> nodes.AccessNode:
    """Convert the accesses through `self.access` into stream accesses, adding
    one read/write component per unique memory access pattern. Optionally
    (``self.use_memory_buffering``) inserts a Gearbox and vectorizes the global
    array so wide memory words are (de)serialized into element streams.

    NOTE(review): despite the annotation, this returns the list of newly
    created stream access nodes.
    """
    dnode: nodes.AccessNode = self.access
    # expr_index 0 streamifies reads from the array, otherwise writes to it.
    if self.expr_index == 0:
        edges = state.out_edges(dnode)
    else:
        edges = state.in_edges(dnode)

    # To understand how many components we need to create, all map ranges
    # throughout memlet paths must match exactly. We thus create a
    # dictionary of unique ranges
    mapping: Dict[Tuple[subsets.Range], List[gr.MultiConnectorEdge[mm.Memlet]]] = defaultdict(list)
    ranges = {}
    for edge in edges:
        mpath = state.memlet_path(edge)
        ranges[edge] = _collect_map_ranges(state, mpath)
        mapping[tuple(r[1] for r in ranges[edge])].append(edge)

    # Collect all edges with the same memory access pattern
    components_to_create: Dict[Tuple[symbolic.SymbolicType], List[gr.MultiConnectorEdge[mm.Memlet]]] = defaultdict(list)
    for edges_with_same_range in mapping.values():
        for edge in edges_with_same_range:
            # Get memlet path and innermost edge
            mpath = state.memlet_path(edge)
            innermost_edge = copy.deepcopy(mpath[-1] if self.expr_index == 0 else mpath[0])

            # Store memlets of the same access in the same component
            expr = _canonicalize_memlet(innermost_edge.data, ranges[edge])
            components_to_create[expr].append((innermost_edge, edge))
    components = list(components_to_create.values())

    # Split out components that have dependencies between them to avoid
    # deadlocks
    if self.expr_index == 0:
        ccs_to_add = []
        for i, component in enumerate(components):
            edges_to_remove = set()
            for cedge in component:
                # If another edge in the component can reach this edge's
                # destination, keeping them in one component could deadlock.
                if any(nx.has_path(state.nx, o[1].dst, cedge[1].dst) for o in component if o is not cedge):
                    ccs_to_add.append([cedge])
                    edges_to_remove.add(cedge)
            if edges_to_remove:
                components[i] = [c for c in component if c not in edges_to_remove]
        components.extend(ccs_to_add)
    # End of split

    desc = sdfg.arrays[dnode.data]

    # Create new streams of shape 1
    streams = {}
    mpaths = {}
    for edge in edges:

        if self.use_memory_buffering:

            arrname = str(self.access)

            # Add gearbox
            total_size = edge.data.volume
            # Number of elements per wide memory word.
            vector_size = int(self.memory_buffering_target_bytes / desc.dtype.bytes)

            if not is_int(sdfg.arrays[dnode.data].shape[-1]):
                warnings.warn("Using the MemoryBuffering transformation is potential unsafe since {sym} is not an integer. There should be no issue if {sym} % {vec} == 0".format(sym=sdfg.arrays[dnode.data].shape[-1], vec=vector_size))

            for i in sdfg.arrays[dnode.data].strides:
                if not is_int(i):
                    warnings.warn("Using the MemoryBuffering transformation is potential unsafe since {sym} is not an integer. There should be no issue if {sym} % {vec} == 0".format(sym=i, vec=vector_size))

            if self.expr_index == 0:  # Read: vector words in, scalars out.
                edges = state.out_edges(dnode)
                gearbox_input_type = dtypes.vector(desc.dtype, vector_size)
                gearbox_output_type = desc.dtype
                gearbox_read_volume = total_size / vector_size
                gearbox_write_volume = total_size
            else:  # Write: scalars in, vector words out.
                edges = state.in_edges(dnode)
                gearbox_input_type = desc.dtype
                gearbox_output_type = dtypes.vector(desc.dtype, vector_size)
                gearbox_read_volume = total_size
                gearbox_write_volume = total_size / vector_size

            input_gearbox_name, input_gearbox_newdesc = sdfg.add_stream("gearbox_input",
                                                                        gearbox_input_type,
                                                                        buffer_size=self.buffer_size,
                                                                        storage=self.storage,
                                                                        transient=True,
                                                                        find_new_name=True)

            output_gearbox_name, output_gearbox_newdesc = sdfg.add_stream("gearbox_output",
                                                                          gearbox_output_type,
                                                                          buffer_size=self.buffer_size,
                                                                          storage=self.storage,
                                                                          transient=True,
                                                                          find_new_name=True)

            read_to_gearbox = state.add_read(input_gearbox_name)
            write_from_gearbox = state.add_write(output_gearbox_name)

            gearbox = Gearbox(total_size / vector_size)

            state.add_node(gearbox)

            state.add_memlet_path(read_to_gearbox,
                                  gearbox,
                                  dst_conn="from_memory",
                                  memlet=Memlet(input_gearbox_name + "[0]", volume=gearbox_read_volume))
            state.add_memlet_path(gearbox,
                                  write_from_gearbox,
                                  src_conn="to_kernel",
                                  memlet=Memlet(output_gearbox_name + "[0]", volume=gearbox_write_volume))

            if self.expr_index == 0:
                streams[edge] = input_gearbox_name
                name = output_gearbox_name
                newdesc = output_gearbox_newdesc
            else:
                streams[edge] = output_gearbox_name
                name = input_gearbox_name
                newdesc = input_gearbox_newdesc
        else:
            # Qualify name to avoid name clashes if memory interfaces are not decoupled for Xilinx
            stream_name = "stream_" + dnode.data
            name, newdesc = sdfg.add_stream(stream_name,
                                            desc.dtype,
                                            buffer_size=self.buffer_size,
                                            storage=self.storage,
                                            transient=True,
                                            find_new_name=True)
            streams[edge] = name

            # Add these such that we can easily use output_gearbox_name and input_gearbox_name without using if statements
            output_gearbox_name = name
            input_gearbox_name = name

        mpath = state.memlet_path(edge)
        mpaths[edge] = mpath

        # Replace memlets in path with stream access
        for e in mpath:
            e.data = mm.Memlet(data=name, subset='0', other_subset=e.data.other_subset)
            # Nested SDFGs are streamified recursively; stream accesses there
            # are marked dynamic.
            if isinstance(e.src, nodes.NestedSDFG):
                e.data.dynamic = True
                _streamify_recursive(e.src, e.src_conn, newdesc)
            if isinstance(e.dst, nodes.NestedSDFG):
                e.data.dynamic = True
                _streamify_recursive(e.dst, e.dst_conn, newdesc)

        # Replace access node and memlet tree with one access
        if self.expr_index == 0:
            replacement = state.add_read(output_gearbox_name)
            state.remove_edge(edge)
            state.add_edge(replacement, edge.src_conn, edge.dst, edge.dst_conn, edge.data)
        else:
            replacement = state.add_write(input_gearbox_name)
            state.remove_edge(edge)
            state.add_edge(edge.src, edge.src_conn, replacement, edge.dst_conn, edge.data)

    if self.use_memory_buffering:

        arrname = str(self.access)
        vector_size = int(self.memory_buffering_target_bytes / desc.dtype.bytes)

        # Vectorize access to global array.
        dtype = sdfg.arrays[arrname].dtype
        sdfg.arrays[arrname].dtype = dtypes.vector(dtype, vector_size)
        new_shape = list(sdfg.arrays[arrname].shape)
        contigidx = sdfg.arrays[arrname].strides.index(1)
        new_shape[contigidx] /= vector_size
        try:
            # Keep integer shapes integral; symbolic shapes stay as-is.
            new_shape[contigidx] = int(new_shape[contigidx])
        except TypeError:
            pass
        sdfg.arrays[arrname].shape = new_shape

        # Change strides
        new_strides: List = list(sdfg.arrays[arrname].strides)

        for i in range(len(new_strides)):
            if i == len(new_strides) - 1:  # Skip last dimension since it is always 1
                continue
            new_strides[i] = new_strides[i] / vector_size
        sdfg.arrays[arrname].strides = new_strides

        post_state = get_post_state(sdfg, state)

        if post_state != None:
            # Change subset in the post state such that the correct amount of memory is copied back from the device
            for e in post_state.edges():
                if e.data.data == self.access.data:
                    new_subset = list(e.data.subset)
                    i, j, k = new_subset[-1]
                    new_subset[-1] = (i, (j + 1) / vector_size - 1, k)
                    e.data = mm.Memlet(data=str(e.src), subset=subsets.Range(new_subset))

    # Make read/write components
    ionodes = []
    for component in components:

        # Pick the first edge as the edge to make the component from
        innermost_edge, outermost_edge = component[0]
        mpath = mpaths[outermost_edge]
        mapname = streams[outermost_edge]
        innermost_edge.data.other_subset = None

        # Get edge data and streams
        if self.expr_index == 0:
            opname = 'read'
            path = [e.dst for e in mpath[:-1]]
            rmemlets = [(dnode, '__inp', innermost_edge.data)]
            wmemlets = []
            for i, (_, edge) in enumerate(component):
                name = streams[edge]
                ionode = state.add_write(name)
                ionodes.append(ionode)
                wmemlets.append((ionode, '__out%d' % i, mm.Memlet(data=name, subset='0')))
            code = '\n'.join('__out%d = __inp' % i for i in range(len(component)))
        else:
            # More than one input stream might mean a data race, so we only
            # address the first one in the tasklet code
            if len(component) > 1:
                warnings.warn(f'More than one input found for the same index for {dnode.data}')
            opname = 'write'
            path = [state.entry_node(e.src) for e in reversed(mpath[1:])]
            wmemlets = [(dnode, '__out', innermost_edge.data)]
            rmemlets = []
            for i, (_, edge) in enumerate(component):
                name = streams[edge]
                ionode = state.add_read(name)
                ionodes.append(ionode)
                rmemlets.append((ionode, '__inp%d' % i, mm.Memlet(data=name, subset='0')))
            code = '__out = __inp0'

        # Create map structure for read/write component
        maps = []
        for entry in path:
            map: nodes.Map = entry.map

            ranges = [(p, (r[0], r[1], r[2])) for p, r in zip(map.params, map.range)]

            # Change ranges of map
            if self.use_memory_buffering:

                # Find edges from/to map
                edge_subset = [a_tuple[0] for a_tuple in list(innermost_edge.data.subset)]

                # Change range of map: scale the innermost loop bound by the
                # vector size when it indexes the (now vectorized) array.
                if isinstance(edge_subset[-1], symbol) and str(edge_subset[-1]) == map.params[-1]:
                    if not is_int(ranges[-1][1][1]):
                        warnings.warn("Using the MemoryBuffering transformation is potential unsafe since {sym} is not an integer. There should be no issue if {sym} % {vec} == 0".format(sym=ranges[-1][1][1].args[1], vec=vector_size))
                    ranges[-1] = (ranges[-1][0], (ranges[-1][1][0], (ranges[-1][1][1] + 1) / vector_size - 1, ranges[-1][1][2]))
                elif isinstance(edge_subset[-1], sympy.core.add.Add):
                    for arg in edge_subset[-1].args:
                        if isinstance(arg, symbol) and str(arg) == map.params[-1]:
                            if not is_int(ranges[-1][1][1]):
                                warnings.warn("Using the MemoryBuffering transformation is potential unsafe since {sym} is not an integer. There should be no issue if {sym} % {vec} == 0".format(sym=ranges[-1][1][1].args[1], vec=vector_size))
                            ranges[-1] = (ranges[-1][0], (ranges[-1][1][0], (ranges[-1][1][1] + 1) / vector_size - 1, ranges[-1][1][2]))

            maps.append(state.add_map(f'__s{opname}_{mapname}', ranges, map.schedule))

        tasklet = state.add_tasklet(
            f'{opname}_{mapname}',
            {m[1] for m in rmemlets},
            {m[1] for m in wmemlets},
            code,
        )
        for node, cname, memlet in rmemlets:
            state.add_memlet_path(node, *(me for me, _ in maps), tasklet, dst_conn=cname, memlet=memlet)
        for node, cname, memlet in wmemlets:
            state.add_memlet_path(tasklet, *(mx for _, mx in reversed(maps)), node, src_conn=cname, memlet=memlet)

    return ionodes
def _isend(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, buffer: str,
           dst: Union[str, sp.Expr, Number], tag: Union[str, sp.Expr, Number],
           request: str):
    """ Adds an MPI Isend library node to ``state``.

        :param pv: Program visitor, used to define local scalars when ``dst``
                   or ``tag`` are given as expressions/numbers.
        :param sdfg: The SDFG containing the state.
        :param state: The state in which the library node is created.
        :param buffer: Name of the data to send, or a (name, range) tuple.
        :param dst: Destination rank: data name, symbolic expression or number.
        :param tag: Message tag: data name, symbolic expression or number.
        :param request: Name of the data holding the MPI request, or a
                        (name, range) tuple.
        :return: None
    """
    from dace.libraries.mpi.nodes.isend import Isend

    libnode = Isend('_Isend_')

    # Send buffer, with an optional subset range
    buf_range = None
    if isinstance(buffer, tuple):
        buf_name, buf_range = buffer
    else:
        buf_name = buffer
    desc = sdfg.arrays[buf_name]
    buf_node = state.add_read(buf_name)

    # Request output, with an optional subset range
    req_range = None
    if isinstance(request, tuple):
        req_name, req_range = request
    else:
        req_name = request
    req_desc = sdfg.arrays[req_name]
    req_node = state.add_write(req_name)

    # Buffer and request connectors are passed by pointer
    iconn = libnode.in_connectors
    iconn = {
        c: (dtypes.pointer(desc.dtype) if c == '_buffer' else t)
        for c, t in iconn.items()
    }
    libnode.in_connectors = iconn
    oconn = libnode.out_connectors
    oconn = {
        c: (dtypes.pointer(req_desc.dtype) if c == '_request' else t)
        for c, t in oconn.items()
    }
    libnode.out_connectors = oconn

    # Destination rank: tuple with range, existing data name, or a scalar
    # that is set by a generated tasklet
    dst_range = None
    if isinstance(dst, tuple):
        dst_name, dst_range = dst
        dst_node = state.add_read(dst_name)
    elif isinstance(dst, str) and dst in sdfg.arrays.keys():
        dst_name = dst
        dst_node = state.add_read(dst_name)
    else:
        storage = desc.storage
        dst_name = _define_local_scalar(pv, sdfg, state, dace.int32, storage)
        dst_node = state.add_access(dst_name)
        dst_tasklet = state.add_tasklet('_set_dst_', {}, {'__out'},
                                        '__out = {}'.format(dst))
        state.add_edge(dst_tasklet, '__out', dst_node, None,
                       Memlet.simple(dst_name, '0'))

    # Tag: same three cases as the destination above.
    # FIX: the string case was previously `if` instead of `elif`, so a tuple
    # tag fell through to the `else` branch, overwriting tag_name/tag_node
    # with a newly defined scalar.
    tag_range = None
    if isinstance(tag, tuple):
        tag_name, tag_range = tag
        tag_node = state.add_read(tag_name)
    elif isinstance(tag, str) and tag in sdfg.arrays.keys():
        tag_name = tag
        tag_node = state.add_read(tag)
    else:
        storage = desc.storage
        tag_name = _define_local_scalar(pv, sdfg, state, dace.int32, storage)
        tag_node = state.add_access(tag_name)
        tag_tasklet = state.add_tasklet('_set_tag_', {}, {'__out'},
                                        '__out = {}'.format(tag))
        state.add_edge(tag_tasklet, '__out', tag_node, None,
                       Memlet.simple(tag_name, '0'))

    # Build memlets: ranged view if a range was given, full array/scalar
    # otherwise
    if buf_range:
        buf_mem = Memlet.simple(buf_name, buf_range)
    else:
        buf_mem = Memlet.from_array(buf_name, desc)
    if req_range:
        req_mem = Memlet.simple(req_name, req_range)
    else:
        req_mem = Memlet.from_array(req_name, req_desc)
    if dst_range:
        dst_mem = Memlet.simple(dst_name, dst_range)
    else:
        dst_mem = Memlet.simple(dst_name, '0')
    if tag_range:
        tag_mem = Memlet.simple(tag_name, tag_range)
    else:
        tag_mem = Memlet.simple(tag_name, '0')

    state.add_edge(buf_node, None, libnode, '_buffer', buf_mem)
    state.add_edge(dst_node, None, libnode, '_dest', dst_mem)
    state.add_edge(tag_node, None, libnode, '_tag', tag_mem)
    state.add_edge(libnode, '_request', req_node, None, req_mem)

    return None
def _irecv(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, buffer: str,
           src: Union[str, sp.Expr, Number], tag: Union[str, sp.Expr, Number],
           request: str):
    """ Adds an MPI Irecv library node to ``state``.

        :param pv: Program visitor, used to define local scalars when ``src``
                   or ``tag`` are given as expressions/numbers.
        :param sdfg: The SDFG containing the state.
        :param state: The state in which the library node is created.
        :param buffer: Name of the data to receive into, or a (name, range)
                       tuple.
        :param src: Source rank: data name, symbolic expression or number.
        :param tag: Message tag: data name, symbolic expression or number.
        :param request: Name of the data holding the MPI request, or a
                        (name, range) tuple.
        :return: None
    """
    from dace.libraries.mpi.nodes.irecv import Irecv

    libnode = Irecv('_Irecv_')

    # Receive buffer, with an optional subset range.
    buf_range = None
    if isinstance(buffer, tuple):
        buf_name, buf_range = buffer
    else:
        buf_name = buffer
    desc = sdfg.arrays[buf_name]
    # FIX: the buffer is an *output* of Irecv (connected below via the
    # '_buffer' out-connector), so it must be a write access node; it was
    # previously created with `add_read`.
    buf_node = state.add_write(buf_name)

    # Request output, with an optional subset range
    req_range = None
    if isinstance(request, tuple):
        req_name, req_range = request
    else:
        req_name = request
    req_desc = sdfg.arrays[req_name]
    req_node = state.add_write(req_name)

    # Buffer and request out-connectors are passed by pointer
    conn = libnode.out_connectors
    conn = {
        c: (dtypes.pointer(desc.dtype) if c == '_buffer' else t)
        for c, t in conn.items()
    }
    conn = {
        c: (dtypes.pointer(req_desc.dtype) if c == '_request' else t)
        for c, t in conn.items()
    }
    libnode.out_connectors = conn

    # Source rank: tuple with range, existing data name, or a scalar that is
    # set by a generated tasklet
    src_range = None
    if isinstance(src, tuple):
        src_name, src_range = src
        src_node = state.add_read(src_name)
    elif isinstance(src, str) and src in sdfg.arrays.keys():
        src_name = src
        src_node = state.add_read(src_name)
    else:
        storage = desc.storage
        src_name = _define_local_scalar(pv, sdfg, state, dace.int32, storage)
        src_node = state.add_access(src_name)
        src_tasklet = state.add_tasklet('_set_src_', {}, {'__out'},
                                        '__out = {}'.format(src))
        state.add_edge(src_tasklet, '__out', src_node, None,
                       Memlet.simple(src_name, '0'))

    # Tag: same three cases as the source above.
    # FIX: the string case was previously `if` instead of `elif`, so a tuple
    # tag fell through to the `else` branch, overwriting tag_name/tag_node
    # with a newly defined scalar.
    tag_range = None
    if isinstance(tag, tuple):
        tag_name, tag_range = tag
        tag_node = state.add_read(tag_name)
    elif isinstance(tag, str) and tag in sdfg.arrays.keys():
        tag_name = tag
        tag_node = state.add_read(tag)
    else:
        storage = desc.storage
        tag_name = _define_local_scalar(pv, sdfg, state, dace.int32, storage)
        tag_node = state.add_access(tag_name)
        tag_tasklet = state.add_tasklet('_set_tag_', {}, {'__out'},
                                        '__out = {}'.format(tag))
        state.add_edge(tag_tasklet, '__out', tag_node, None,
                       Memlet.simple(tag_name, '0'))

    # Build memlets: ranged view if a range was given, full array/scalar
    # otherwise
    if buf_range:
        buf_mem = Memlet.simple(buf_name, buf_range)
    else:
        buf_mem = Memlet.from_array(buf_name, desc)
    if req_range:
        req_mem = Memlet.simple(req_name, req_range)
    else:
        req_mem = Memlet.from_array(req_name, req_desc)
    if src_range:
        src_mem = Memlet.simple(src_name, src_range)
    else:
        src_mem = Memlet.simple(src_name, '0')
    if tag_range:
        tag_mem = Memlet.simple(tag_name, tag_range)
    else:
        tag_mem = Memlet.simple(tag_name, '0')

    state.add_edge(libnode, '_buffer', buf_node, None, buf_mem)
    state.add_edge(src_node, None, libnode, '_src', src_mem)
    state.add_edge(tag_node, None, libnode, '_tag', tag_mem)
    state.add_edge(libnode, '_request', req_node, None, req_mem)

    return None
def _create_einsum_internal(sdfg: SDFG,
                            state: SDFGState,
                            einsum_string: str,
                            *arrays: str,
                            dtype: Optional[dtypes.typeclass] = None,
                            optimize: bool = False,
                            output: Optional[str] = None,
                            nodes: Optional[Dict[str, AccessNode]] = None,
                            init_output: bool = None):
    """ Creates an einsum computation in ``state`` from the given subscript
        string and input array names.

        If ``optimize`` is True, an optimal contraction path is computed with
        ``opt_einsum`` and this function recurses once per pairwise
        contraction. Otherwise, the einsum is lowered either to a mapped
        tasklet with write-conflict resolution ("pure" einsum), or — when the
        expression matches a (batched) matrix-matrix multiplication — to a
        nested batched-GEMM SDFG.

        :param sdfg: The SDFG to operate on.
        :param state: The state in which the einsum is created.
        :param einsum_string: The einsum subscript expression (e.g. 'ij,jk->ik').
        :param arrays: Names of the input arrays, one per einsum input.
        :param dtype: Output data type (inferred from inputs/output if None).
        :param optimize: If True, optimize the contraction order (requires the
                         ``opt_einsum`` package and non-symbolic shapes).
        :param output: Existing output array name; a temporary transient is
                       created if None.
        :param nodes: Optional mapping from array name to an existing access
                      node to reuse as input.
        :param init_output: Whether to zero-initialize the output in a
                            preceding state (auto-decided if None).
        :return: A 2-tuple (output array name, output access node).
        :raise ValueError: On arity/dimensionality mismatches, or symbolic
                           shapes with ``optimize``.
        :raise ImportError: If ``optimize`` is requested without ``opt_einsum``.
    """
    # Infer shapes and strides of input/output arrays
    einsum = EinsumParser(einsum_string)

    if len(einsum.inputs) != len(arrays):
        raise ValueError('Invalid number of arrays for einsum expression')

    # Get shapes from arrays and verify dimensionality: every subscript
    # character must map to a single consistent dimension size
    chardict = {}
    for inp, inpname in zip(einsum.inputs, arrays):
        inparr = sdfg.arrays[inpname]
        if len(inp) != len(inparr.shape):
            raise ValueError('Dimensionality mismatch in input "%s"' % inpname)
        for char, shp in zip(inp, inparr.shape):
            if char in chardict and shp != chardict[char]:
                raise ValueError('Dimension mismatch in einsum expression')
            chardict[char] = shp

    if optimize:
        # Try to import opt_einsum
        try:
            import opt_einsum as oe
        except (ModuleNotFoundError, NameError, ImportError):
            raise ImportError('To optimize einsum expressions, please install '
                              'the "opt_einsum" package.')

        # opt_einsum builds concrete views, so all sizes must be concrete
        for char, shp in chardict.items():
            if symbolic.issymbolic(shp):
                raise ValueError('Einsum optimization cannot be performed '
                                 'on symbolically-sized array dimension "%s" '
                                 'for subscript character "%s"' % (shp, char))

        # Create optimal contraction path
        # noinspection PyTypeChecker
        _, path_info = oe.contract_path(
            einsum_string, *oe.helpers.build_views(einsum_string, chardict))

        input_nodes = nodes or {arr: state.add_read(arr) for arr in arrays}
        result_node = None

        # Follow path and create a chain of operation SDFG states: each
        # pairwise contraction recurses with optimize=False, its result
        # replacing the two contracted inputs for the next step
        for pair, nonfree, expr, after, blas in path_info.contraction_list:
            result, result_node = _create_einsum_internal(sdfg,
                                                          state,
                                                          expr,
                                                          arrays[pair[0]],
                                                          arrays[pair[1]],
                                                          dtype=dtype,
                                                          optimize=False,
                                                          output=None,
                                                          nodes=input_nodes)
            arrays = ([a for i, a in enumerate(arrays) if i not in pair] +
                      [result])
            input_nodes[result] = result_node

        return arrays[0], result_node
        # END of einsum optimization

    input_nodes = nodes or {arr: state.add_read(arr) for arr in arrays}

    # Get output shape from chardict, or [1] for a scalar output
    output_shape = list(map(lambda k: chardict[k], einsum.output)) or [1]
    output_index = ','.join(o for o in einsum.output) or '0'

    if output is None:
        dtype = dtype or sdfg.arrays[arrays[0]].dtype
        output, odesc = sdfg.add_temp_transient(output_shape, dtype)
        to_init = True
    else:
        odesc = sdfg.arrays[output]
        dtype = dtype or odesc.dtype
        # NOTE(review): `init_output or True` always evaluates to True, so an
        # explicit init_output=False is ignored here; initialization is only
        # skipped via the `init_output is None` check below — confirm intent.
        to_init = init_output or True

    # An input subscript character absent from the output implies a reduction
    # (write conflict) over that dimension
    is_conflicted = not all(
        all(indim in einsum.output for indim in inp) for inp in einsum.inputs)
    if not is_conflicted and init_output is None:
        to_init = False

    if not einsum.is_bmm():
        # Fall back to "pure" SDFG einsum with conflict resolution
        c = state.add_write(output)

        # Add state before this one to initialize the output value
        if to_init:
            init_state = sdfg.add_state_before(state)
            if len(einsum.output) > 0:
                init_state.add_mapped_tasklet(
                    'einsum_reset',
                    {k: '0:%s' % chardict[k]
                     for k in einsum.output}, {},
                    'out_%s = 0' % output,
                    {'out_%s' % output: Memlet.simple(output, output_index)},
                    external_edges=True)
            else:  # Scalar output
                t = init_state.add_tasklet('einsum_reset', set(),
                                           {'out_%s' % output},
                                           'out_%s = 0' % output)
                onode = init_state.add_write(output)
                init_state.add_edge(t, 'out_%s' % output, onode, None,
                                    Memlet.simple(output, '0'))

        # Sum-reduce conflicting writes; no WCR needed otherwise
        wcr = 'lambda a,b: a+b' if is_conflicted else None
        # Pure einsum map: one parameter per subscript character, multiplying
        # all inputs into the (possibly conflict-resolved) output
        state.add_mapped_tasklet(
            'einsum', {k: '0:%s' % v
                       for k, v in chardict.items()}, {
                           'inp_%s' % arr: Memlet.simple(arr, ','.join(inp))
                           for inp, arr in zip(einsum.inputs, arrays)
                       },
            'out_%s = %s' %
            (output, ' * '.join('inp_%s' % arr for arr in arrays)), {
                'out_%s' % output:
                Memlet.simple(output, output_index, wcr_str=wcr)
            },
            input_nodes=input_nodes,
            output_nodes={output: c},
            external_edges=True)
    else:
        # Represent einsum as a GEMM or batched GEMM (using library nodes)
        a_shape = sdfg.arrays[arrays[0]].shape
        b_shape = sdfg.arrays[arrays[1]].shape
        c_shape = output_shape

        a = input_nodes[arrays[0]]
        b = input_nodes[arrays[1]]
        c = state.add_write(output)

        # Compute GEMM dimensions and strides (flattening batch/M/K/N index
        # groups into single dimensions; strides assume row-major layout)
        strides = dict(
            BATCH=prod([c_shape[dim] for dim in einsum.c_batch]),
            M=prod([a_shape[dim] for dim in einsum.a_only]),
            K=prod([a_shape[dim] for dim in einsum.a_sum]),
            N=prod([b_shape[dim] for dim in einsum.b_only]),
            sAM=prod(a_shape[einsum.a_only[-1] + 1:]) if einsum.a_only else 1,
            sAK=prod(a_shape[einsum.a_sum[-1] + 1:]) if einsum.a_sum else 1,
            sAB=prod(a_shape[einsum.a_batch[-1] + 1:]) if einsum.a_batch else 1,
            sBK=prod(b_shape[einsum.b_sum[-1] + 1:]) if einsum.b_sum else 1,
            sBN=prod(b_shape[einsum.b_only[-1] + 1:]) if einsum.b_only else 1,
            sBB=prod(b_shape[einsum.b_batch[-1] + 1:]) if einsum.b_batch else 1,
            sCM=prod(c_shape[einsum.c_a_only[-1] + 1:])
            if einsum.c_a_only else 1,
            sCN=prod(c_shape[einsum.c_b_only[-1] + 1:])
            if einsum.c_b_only else 1,
            sCB=prod(c_shape[einsum.c_batch[-1] + 1:]) if einsum.c_batch else 1)

        # Complement strides to make matrices as necessary
        if len(a_shape) == 1 and len(einsum.a_sum) == 1:
            strides['sAK'] = 1
            strides['sAB'] = strides['sAM'] = strides['K']

        if len(b_shape) == 1 and len(einsum.b_sum) == 1:
            strides['sBN'] = 1
            strides['sBK'] = 1
            strides['sBB'] = strides['K']

        if len(c_shape) == 1 and len(einsum.a_sum) == len(einsum.b_sum):
            strides['sCN'] = 1
            strides['sCB'] = strides['sCM'] = strides['N']

        # Create nested SDFG for GEMM
        nsdfg = create_batch_gemm_sdfg(dtype, strides)

        nsdfg_node = state.add_nested_sdfg(nsdfg, None, {'X', 'Y'}, {'Z'},
                                           strides)
        state.add_edge(a, None, nsdfg_node, 'X',
                       Memlet.from_array(a.data, a.desc(sdfg)))
        state.add_edge(b, None, nsdfg_node, 'Y',
                       Memlet.from_array(b.data, b.desc(sdfg)))
        state.add_edge(nsdfg_node, 'Z', c, None,
                       Memlet.from_array(c.data, c.desc(sdfg)))

    return output, c
def _elementwise(sdfg: SDFG,
                 state: SDFGState,
                 func: str,
                 in_array: str,
                 out_array=None):
    """ Apply a lambda function to each element in the input.

        :param sdfg: The SDFG to operate on.
        :param state: The state in which the computation is created.
        :param func: A string containing a one-argument Python lambda
                     (e.g. ``"lambda x: x * 2"``).
        :param in_array: Name of the input array.
        :param out_array: Name of the output array; a temporary transient of
                          the same shape/dtype/storage is created if None.
        :return: The name of the output array.
        :raise SyntaxError: If ``func`` is not a one-argument lambda.
    """
    inparr = sdfg.arrays[in_array]
    restype = sdfg.arrays[in_array].dtype

    if out_array is None:
        out_array, outarr = sdfg.add_temp_transient(inparr.shape, restype,
                                                    inparr.storage)
    else:
        outarr = sdfg.arrays[out_array]

    func_ast = ast.parse(func)
    try:
        lambda_ast = func_ast.body[0].value
        if len(lambda_ast.args.args) != 1:
            # FIX: previously read `lambda_ast.args.arrgs` (typo), which
            # raised AttributeError and masked this message with the generic
            # "Could not parse func" error below.
            raise SyntaxError(
                "Expected lambda with one arg, but {} has {}".format(
                    func, len(lambda_ast.args.args)))
        arg = lambda_ast.args.args[0].arg
        body = astutils.unparse(lambda_ast.body)
    except AttributeError:
        raise SyntaxError("Could not parse func {}".format(func))

    code = "__out = {}".format(body)

    num_elements = reduce(lambda x, y: x * y, inparr.shape)
    if num_elements == 1:
        # Scalar case: a single tasklet over the whole (one-element) array
        inp = state.add_read(in_array)
        out = state.add_write(out_array)
        tasklet = state.add_tasklet("_elementwise_", {arg}, {'__out'}, code)
        state.add_edge(inp, None, tasklet, arg,
                       Memlet.from_array(in_array, inparr))
        state.add_edge(tasklet, '__out', out, None,
                       Memlet.from_array(out_array, outarr))
    else:
        # General case: map over all dimensions, one element per iteration
        state.add_mapped_tasklet(
            name="_elementwise_",
            map_ranges={
                '__i%d' % i: '0:%s' % n
                for i, n in enumerate(inparr.shape)
            },
            inputs={
                arg:
                Memlet.simple(
                    in_array,
                    ','.join(['__i%d' % i for i in range(len(inparr.shape))]))
            },
            code=code,
            outputs={
                '__out':
                Memlet.simple(
                    out_array,
                    ','.join(['__i%d' % i for i in range(len(inparr.shape))]))
            },
            external_edges=True)

    return out_array
def _bcgather(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState,
              in_buffer: str, out_buffer: str,
              block_sizes: Union[str, Sequence[Union[sp.Expr, Number]]]):
    """ Adds a PBLAS block-cyclic gather library node to ``state``.

        :param pv: Program visitor (unused here, kept for a uniform signature).
        :param sdfg: The SDFG containing the state.
        :param state: The state in which the library node is created.
        :param in_buffer: Name of the distributed input data, or a
                          (name, range) tuple.
        :param out_buffer: Name of the gathered output data, or a
                           (name, range) tuple.
        :param block_sizes: Name of an existing array of block sizes, or a
                            sequence of sizes materialized into a transient.
        :return: None
    """
    from dace.libraries.pblas.nodes.pgeadd import BlockCyclicGather

    gather_node = BlockCyclicGather('_BCGather_')

    # Input buffer, optionally restricted to a subset range
    if isinstance(in_buffer, tuple):
        in_name, in_range = in_buffer
    else:
        in_name, in_range = in_buffer, None
    in_desc = sdfg.arrays[in_name]
    in_node = state.add_read(in_name)

    # Block sizes come in three flavors: a (name, range) pair, a literal
    # sequence of sizes (materialized via a tasklet into a transient), or a
    # plain data name
    bs_range = None
    if isinstance(block_sizes, (list, tuple)):
        if isinstance(block_sizes[0], str):
            bs_name, bs_range = block_sizes
            bs_desc = sdfg.arrays[bs_name]
            bs_node = state.add_read(bs_name)
        else:
            bs_name, bs_desc = sdfg.add_temp_transient((len(block_sizes), ),
                                                       dtype=dace.int32)
            bs_node = state.add_access(bs_name)
            set_tasklet = state.add_tasklet(
                '_set_bsizes_', {}, {'__out'}, ";".join([
                    "__out[{}] = {}".format(i, sz)
                    for i, sz in enumerate(block_sizes)
                ]))
            state.add_edge(set_tasklet, '__out', bs_node, None,
                           Memlet.from_array(bs_name, bs_desc))
    else:
        bs_name = block_sizes
        bs_desc = sdfg.arrays[bs_name]
        bs_node = state.add_read(bs_name)

    # Output buffer, optionally restricted to a subset range
    if isinstance(out_buffer, tuple):
        out_name, out_range = out_buffer
    else:
        out_name, out_range = out_buffer, None
    out_desc = sdfg.arrays[out_name]
    out_node = state.add_write(out_name)

    # Ranged memlet when a subset was given, otherwise the full array
    in_mem = (Memlet.simple(in_name, in_range)
              if in_range else Memlet.from_array(in_name, in_desc))
    bs_mem = (Memlet.simple(bs_name, bs_range)
              if bs_range else Memlet.from_array(bs_name, bs_desc))
    out_mem = (Memlet.simple(out_name, out_range)
               if out_range else Memlet.from_array(out_name, out_desc))

    state.add_edge(in_node, None, gather_node, '_inbuffer', in_mem)
    state.add_edge(bs_node, None, gather_node, '_block_sizes', bs_mem)
    state.add_edge(gather_node, '_outbuffer', out_node, None, out_mem)

    return None
def nest_state_subgraph(sdfg: SDFG,
                        state: SDFGState,
                        subgraph: SubgraphView,
                        name: Optional[str] = None,
                        full_data: bool = False) -> nodes.NestedSDFG:
    """ Turns a state subgraph into a nested SDFG. Operates in-place.

        :param sdfg: The SDFG containing the state subgraph.
        :param state: The state containing the subgraph.
        :param subgraph: Subgraph to nest.
        :param name: An optional name for the nested SDFG.
        :param full_data: If True, nests entire input/output data.
        :return: The nested SDFG node.
        :raise KeyError: Some or all nodes in the subgraph are not located in
                         this state, or the state does not belong to the given
                         SDFG.
        :raise ValueError: The subgraph is contained in more than one scope.
    """
    if state.parent != sdfg:
        raise KeyError('State does not belong to given SDFG')
    if subgraph.graph != state:
        raise KeyError('Subgraph does not belong to given state')

    # Find the top-level scope
    scope_tree = state.scope_tree()
    scope_dict = state.scope_dict()
    scope_dict_children = state.scope_dict(True)
    top_scopenode = -1  # Initialized to -1 since "None" already means top-level

    for node in subgraph.nodes():
        if node not in scope_dict:
            raise KeyError('Node not found in state')

        # If scope entry/exit, ensure entire scope is in subgraph
        if isinstance(node, nodes.EntryNode):
            scope_nodes = scope_dict_children[node]
            if any(n not in subgraph.nodes() for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (entry)')
        elif isinstance(node, nodes.ExitNode):
            entry = state.entry_node(node)
            scope_nodes = scope_dict_children[entry] + [entry]
            if any(n not in subgraph.nodes() for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (exit)')

        # All nodes must share the same parent scope (or be top-level)
        scope_node = scope_dict[node]
        if scope_node not in subgraph.nodes():
            if top_scopenode != -1 and top_scopenode != scope_node:
                raise ValueError(
                    'Subgraph is contained in more than one scope')
            top_scopenode = scope_node

    scope = scope_tree[top_scopenode]
    ###

    # Collect inputs and outputs of the nested SDFG
    inputs: List[MultiConnectorEdge] = []
    outputs: List[MultiConnectorEdge] = []
    for node in subgraph.source_nodes():
        inputs.extend(state.in_edges(node))
    for node in subgraph.sink_nodes():
        outputs.extend(state.out_edges(node))

    # Collect transients not used outside of subgraph (will be removed of
    # top-level graph)
    data_in_subgraph = set(n.data for n in subgraph.nodes()
                           if isinstance(n, nodes.AccessNode))
    # Find other occurrences in SDFG
    other_nodes = set(
        n.data for s in sdfg.nodes() for n in s.nodes()
        if isinstance(n, nodes.AccessNode) and n not in subgraph.nodes())
    subgraph_transients = set()
    for data in data_in_subgraph:
        datadesc = sdfg.arrays[data]
        if datadesc.transient and data not in other_nodes:
            subgraph_transients.add(data)

    # All transients of edges between code nodes are also added to nested graph
    for edge in subgraph.edges():
        if (isinstance(edge.src, nodes.CodeNode)
                and isinstance(edge.dst, nodes.CodeNode)):
            subgraph_transients.add(edge.data.data)

    # Collect data used in access nodes within subgraph (will be referenced in
    # full upon nesting)
    input_arrays = set()
    output_arrays = set()
    for node in subgraph.nodes():
        if (isinstance(node, nodes.AccessNode)
                and node.data not in subgraph_transients):
            if state.out_degree(node) > 0:
                input_arrays.add(node.data)
            if state.in_degree(node) > 0:
                output_arrays.add(node.data)

    # Create the nested SDFG
    nsdfg = SDFG(name or 'nested_' + state.label)

    # Transients are added to the nested graph as-is.
    # NOTE: the loop variable below shadows the `name` parameter; safe only
    # because the parameter was already consumed when creating `nsdfg` above.
    for name in subgraph_transients:
        nsdfg.add_datadesc(name, sdfg.arrays[name])

    # Input/output data that are not source/sink nodes are added to the graph
    # as non-transients
    for name in (input_arrays | output_arrays):
        datadesc = copy.deepcopy(sdfg.arrays[name])
        datadesc.transient = False
        nsdfg.add_datadesc(name, datadesc)

    # Connected source/sink nodes outside subgraph become global data
    # descriptors in nested SDFG.
    # NOTE(review): edges with empty memlets are skipped when building
    # input_names/output_names, but the later `zip(input_names, inputs)` /
    # `zip(output_names, outputs)` pair names with *all* edges — if an
    # empty-memlet edge exists, names and edges misalign. The dict-based
    # variant of this function avoids this; confirm whether empty memlets can
    # occur on this path.
    input_names = []
    output_names = []
    for edge in inputs:
        if edge.data.data is None:  # Skip edges with an empty memlet
            continue
        name = '__in_' + edge.data.data
        datadesc = copy.deepcopy(sdfg.arrays[edge.data.data])
        datadesc.transient = False
        if not full_data:
            datadesc.shape = edge.data.subset.size()
        input_names.append(
            nsdfg.add_datadesc(name, datadesc, find_new_name=True))
    for edge in outputs:
        if edge.data.data is None:  # Skip edges with an empty memlet
            continue
        name = '__out_' + edge.data.data
        datadesc = copy.deepcopy(sdfg.arrays[edge.data.data])
        datadesc.transient = False
        if not full_data:
            datadesc.shape = edge.data.subset.size()
        output_names.append(
            nsdfg.add_datadesc(name, datadesc, find_new_name=True))

    ###################
    # Add scope symbols to the nested SDFG
    for v in scope.defined_vars:
        if v in sdfg.symbols:
            sym = sdfg.symbols[v]
            nsdfg.add_symbol(v, sym.dtype)

    # Create nested state
    nstate = nsdfg.add_state()

    # Add subgraph nodes and edges to nested state
    nstate.add_nodes_from(subgraph.nodes())
    for e in subgraph.edges():
        nstate.add_edge(e.src, e.src_conn, e.dst, e.dst_conn, e.data)

    # Modify nested SDFG parents in subgraph
    for node in subgraph.nodes():
        if isinstance(node, nodes.NestedSDFG):
            node.sdfg.parent = nstate
            node.sdfg.parent_sdfg = nsdfg

    # Add access nodes and edges as necessary
    edges_to_offset = []
    for name, edge in zip(input_names, inputs):
        node = nstate.add_read(name)
        new_edge = copy.deepcopy(edge.data)
        new_edge.data = name
        edges_to_offset.append((edge,
                                nstate.add_edge(node, None, edge.dst,
                                                edge.dst_conn, new_edge)))
    for name, edge in zip(output_names, outputs):
        node = nstate.add_write(name)
        new_edge = copy.deepcopy(edge.data)
        new_edge.data = name
        edges_to_offset.append((edge,
                                nstate.add_edge(edge.src, edge.src_conn, node,
                                                None, new_edge)))

    # Offset memlet paths inside nested SDFG according to subsets
    for original_edge, new_edge in edges_to_offset:
        for edge in nstate.memlet_tree(new_edge):
            edge.data.data = new_edge.data.data
            if not full_data:
                edge.data.subset.offset(original_edge.data.subset, True)

    # Add nested SDFG node to the input state
    nested_sdfg = state.add_nested_sdfg(nsdfg, None,
                                        set(input_names) | input_arrays,
                                        set(output_names) | output_arrays)

    # Reconnect memlets to nested SDFG
    for name, edge in zip(input_names, inputs):
        if full_data:
            data = Memlet.from_array(edge.data.data,
                                     sdfg.arrays[edge.data.data])
        else:
            data = edge.data
        state.add_edge(edge.src, edge.src_conn, nested_sdfg, name, data)
    for name, edge in zip(output_names, outputs):
        if full_data:
            data = Memlet.from_array(edge.data.data,
                                     sdfg.arrays[edge.data.data])
        else:
            data = edge.data
        state.add_edge(nested_sdfg, name, edge.dst, edge.dst_conn, data)

    # Connect access nodes to internal input/output data as necessary
    entry = scope.entry
    exit = scope.exit
    for name in input_arrays:
        node = state.add_read(name)
        if entry is not None:
            state.add_nedge(entry, node, EmptyMemlet())
        state.add_edge(node, None, nested_sdfg, name,
                       Memlet.from_array(name, sdfg.arrays[name]))
    for name in output_arrays:
        node = state.add_write(name)
        if exit is not None:
            state.add_nedge(node, exit, EmptyMemlet())
        state.add_edge(nested_sdfg, name, node, None,
                       Memlet.from_array(name, sdfg.arrays[name]))

    # Remove subgraph nodes from graph
    state.remove_nodes_from(subgraph.nodes())

    # Remove subgraph transients from top-level graph
    for transient in subgraph_transients:
        del sdfg.arrays[transient]

    return nested_sdfg
def _matmult(visitor, sdfg: SDFG, state: SDFGState, op1: str, op2: str):
    """ Implements the matrix-product operator between two data descriptors,
        adding a MatMul library node to ``state``.

        Supports matrix @ matrix (optionally batched, up to 3 dimensions),
        matrix @ vector, and vector @ vector (dot product).

        :param visitor: Program visitor (unused here, kept for a uniform
                        signature).
        :param sdfg: The SDFG containing the operands.
        :param state: The state in which the product is created.
        :param op1: Name of the left-hand-side operand.
        :param op2: Name of the right-hand-side operand.
        :return: The name of the newly created output transient.
        :raise SyntaxError: On unsupported ranks or mismatched dimensions.
    """
    from dace.libraries.blas.nodes.matmul import MatMul  # Avoid import loop

    desc_a = sdfg.arrays[op1]
    desc_b = sdfg.arrays[op2]
    ndim_a = len(desc_a.shape)
    ndim_b = len(desc_b.shape)

    if ndim_a > 1 and ndim_b > 1:  # matrix * matrix
        if ndim_a > 3 or ndim_b > 3:
            raise SyntaxError(
                'Matrix multiplication of tensors of dimensions > 3 '
                'not supported')
        if desc_a.shape[-1] != desc_b.shape[-2]:
            raise SyntaxError('Matrix dimension mismatch %s != %s' %
                              (desc_a.shape[-1], desc_b.shape[-2]))

        from dace.libraries.blas.nodes.matmul import _get_batchmm_opts

        # Determine batched multiplication
        bopt = _get_batchmm_opts(desc_a.shape, desc_a.strides, desc_b.shape,
                                 desc_b.strides, None, None)
        output_shape = ((bopt['b'], desc_a.shape[-2], desc_b.shape[-1])
                        if bopt else (desc_a.shape[-2], desc_b.shape[-1]))
    elif ndim_a == 2 and ndim_b == 1:  # matrix * vector
        if desc_a.shape[1] != desc_b.shape[0]:
            raise SyntaxError("Number of matrix columns {} must match"
                              "size of vector {}.".format(
                                  desc_a.shape[1], desc_b.shape[0]))
        output_shape = (desc_a.shape[0], )
    elif ndim_a == 1 and ndim_b == 1:  # vector * vector
        if desc_a.shape[0] != desc_b.shape[0]:
            raise SyntaxError("Vectors in vector product must have same size: "
                              "{} vs. {}".format(desc_a.shape[0],
                                                 desc_b.shape[0]))
        output_shape = (1, )
    else:  # Dunno what this is, bail
        raise SyntaxError(
            "Cannot multiply arrays with shapes: {} and {}".format(
                desc_a.shape, desc_b.shape))

    # Result type follows NumPy type-promotion rules
    result_type = dace.DTYPE_TO_TYPECLASS[np.result_type(
        desc_a.dtype.type, desc_b.dtype.type).type]

    op3, desc_c = sdfg.add_temp_transient(output_shape, result_type,
                                          desc_a.storage)

    read_a = state.add_read(op1)
    read_b = state.add_read(op2)
    write_c = state.add_write(op3)

    mm_node = MatMul('_MatMult_', result_type)
    state.add_node(mm_node)
    state.add_edge(read_a, None, mm_node, '_a',
                   dace.Memlet.from_array(op1, desc_a))
    state.add_edge(read_b, None, mm_node, '_b',
                   dace.Memlet.from_array(op2, desc_b))
    state.add_edge(mm_node, '_c', write_c, None,
                   dace.Memlet.from_array(op3, desc_c))

    return op3
def nest_state_subgraph(sdfg: SDFG,
                        state: SDFGState,
                        subgraph: SubgraphView,
                        name: Optional[str] = None,
                        full_data: bool = False) -> nodes.NestedSDFG:
    """ Turns a state subgraph into a nested SDFG. Operates in-place.

        :param sdfg: The SDFG containing the state subgraph.
        :param state: The state containing the subgraph.
        :param subgraph: Subgraph to nest.
        :param name: An optional name for the nested SDFG.
        :param full_data: If True, nests entire input/output data.
        :return: The nested SDFG node.
        :raise KeyError: Some or all nodes in the subgraph are not located in
                         this state, or the state does not belong to the given
                         SDFG.
        :raise ValueError: The subgraph is contained in more than one scope.
    """
    if state.parent != sdfg:
        raise KeyError('State does not belong to given SDFG')
    if subgraph is not state and subgraph.graph is not state:
        raise KeyError('Subgraph does not belong to given state')

    # Find the top-level scope
    scope_tree = state.scope_tree()
    scope_dict = state.scope_dict()
    scope_dict_children = state.scope_children()
    top_scopenode = -1  # Initialized to -1 since "None" already means top-level

    for node in subgraph.nodes():
        if node not in scope_dict:
            raise KeyError('Node not found in state')

        # If scope entry/exit, ensure entire scope is in subgraph
        if isinstance(node, nodes.EntryNode):
            scope_nodes = scope_dict_children[node]
            if any(n not in subgraph.nodes() for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (entry)')
        elif isinstance(node, nodes.ExitNode):
            entry = state.entry_node(node)
            scope_nodes = scope_dict_children[entry] + [entry]
            if any(n not in subgraph.nodes() for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (exit)')

        # All nodes must share the same parent scope (or be top-level)
        scope_node = scope_dict[node]
        if scope_node not in subgraph.nodes():
            if top_scopenode != -1 and top_scopenode != scope_node:
                raise ValueError('Subgraph is contained in more than one scope')
            top_scopenode = scope_node

    scope = scope_tree[top_scopenode]
    ###

    # Consolidate edges in top scope
    utils.consolidate_edges(sdfg, scope)
    snodes = subgraph.nodes()

    # Collect inputs and outputs of the nested SDFG: all edges crossing the
    # subgraph boundary
    inputs: List[MultiConnectorEdge] = []
    outputs: List[MultiConnectorEdge] = []
    for node in snodes:
        for edge in state.in_edges(node):
            if edge.src not in snodes:
                inputs.append(edge)
        for edge in state.out_edges(node):
            if edge.dst not in snodes:
                outputs.append(edge)

    # Collect transients not used outside of subgraph (will be removed of
    # top-level graph)
    data_in_subgraph = set(n.data for n in subgraph.nodes()
                           if isinstance(n, nodes.AccessNode))
    # Find other occurrences in SDFG
    other_nodes = set(n.data for s in sdfg.nodes() for n in s.nodes()
                      if isinstance(n, nodes.AccessNode)
                      and n not in subgraph.nodes())
    subgraph_transients = set()
    for data in data_in_subgraph:
        datadesc = sdfg.arrays[data]
        if datadesc.transient and data not in other_nodes:
            subgraph_transients.add(data)

    # All transients of edges between code nodes are also added to nested graph
    for edge in subgraph.edges():
        if (isinstance(edge.src, nodes.CodeNode)
                and isinstance(edge.dst, nodes.CodeNode)):
            subgraph_transients.add(edge.data.data)

    # Collect data used in access nodes within subgraph (will be referenced in
    # full upon nesting). Output arrays map name -> WCR of the first incoming
    # edge (NOTE(review): WCR is taken from in_edges[0] only — confirm that
    # mixed-WCR writes to the same array cannot occur here).
    input_arrays = set()
    output_arrays = {}
    for node in subgraph.nodes():
        if (isinstance(node, nodes.AccessNode)
                and node.data not in subgraph_transients):
            if node.has_reads(state):
                input_arrays.add(node.data)
            if node.has_writes(state):
                output_arrays[node.data] = state.in_edges(node)[0].data.wcr

    # Create the nested SDFG
    nsdfg = SDFG(name or 'nested_' + state.label)

    # Transients are added to the nested graph as-is.
    # NOTE: the loop variable below shadows the `name` parameter; safe only
    # because the parameter was already consumed when creating `nsdfg` above.
    for name in subgraph_transients:
        nsdfg.add_datadesc(name, sdfg.arrays[name])

    # Input/output data that are not source/sink nodes are added to the graph
    # as non-transients
    for name in (input_arrays | output_arrays.keys()):
        datadesc = copy.deepcopy(sdfg.arrays[name])
        datadesc.transient = False
        nsdfg.add_datadesc(name, datadesc)

    # Connected source/sink nodes outside subgraph become global data
    # descriptors in nested SDFG. `global_subsets` tracks, per outer array,
    # the nested descriptor name and the union of all subsets accessed so far
    # (growing the nested array's shape accordingly).
    input_names = {}
    output_names = {}
    global_subsets: Dict[str, Tuple[str, Subset]] = {}
    for edge in inputs:
        if edge.data.data is None:  # Skip edges with an empty memlet
            continue
        name = edge.data.data
        if name not in global_subsets:
            datadesc = copy.deepcopy(sdfg.arrays[edge.data.data])
            datadesc.transient = False
            if not full_data:
                datadesc.shape = edge.data.subset.size()
            new_name = nsdfg.add_datadesc(name, datadesc, find_new_name=True)
            global_subsets[name] = (new_name, edge.data.subset)
        else:
            new_name, subset = global_subsets[name]
            if not full_data:
                new_subset = union(subset, edge.data.subset)
                if new_subset is None:
                    # Subsets cannot be unified; fall back to the full array
                    new_subset = Range.from_array(sdfg.arrays[name])
                global_subsets[name] = (new_name, new_subset)
                nsdfg.arrays[new_name].shape = new_subset.size()
        input_names[edge] = new_name
    for edge in outputs:
        if edge.data.data is None:  # Skip edges with an empty memlet
            continue
        name = edge.data.data
        if name not in global_subsets:
            datadesc = copy.deepcopy(sdfg.arrays[edge.data.data])
            datadesc.transient = False
            if not full_data:
                datadesc.shape = edge.data.subset.size()
            new_name = nsdfg.add_datadesc(name, datadesc, find_new_name=True)
            global_subsets[name] = (new_name, edge.data.subset)
        else:
            new_name, subset = global_subsets[name]
            if not full_data:
                new_subset = union(subset, edge.data.subset)
                if new_subset is None:
                    # Subsets cannot be unified; fall back to the full array
                    new_subset = Range.from_array(sdfg.arrays[name])
                global_subsets[name] = (new_name, new_subset)
                nsdfg.arrays[new_name].shape = new_subset.size()
        output_names[edge] = new_name
    ###################

    # Add scope symbols to the nested SDFG
    defined_vars = set(
        symbolic.pystr_to_symbolic(s)
        for s in (state.symbols_defined_at(top_scopenode).keys()
                  | sdfg.symbols))
    for v in defined_vars:
        if v in sdfg.symbols:
            sym = sdfg.symbols[v]
            nsdfg.add_symbol(v, sym.dtype)

    # Add constants to nested SDFG
    for cstname, cstval in sdfg.constants.items():
        nsdfg.add_constant(cstname, cstval)

    # Create nested state
    nstate = nsdfg.add_state()

    # Add subgraph nodes and edges to nested state
    nstate.add_nodes_from(subgraph.nodes())
    for e in subgraph.edges():
        nstate.add_edge(e.src, e.src_conn, e.dst, e.dst_conn,
                        copy.deepcopy(e.data))

    # Modify nested SDFG parents in subgraph
    for node in subgraph.nodes():
        if isinstance(node, nodes.NestedSDFG):
            node.sdfg.parent = nstate
            node.sdfg.parent_sdfg = nsdfg
            node.sdfg.parent_nsdfg_node = node

    # Add access nodes and edges as necessary
    edges_to_offset = []
    for edge, name in input_names.items():
        node = nstate.add_read(name)
        new_edge = copy.deepcopy(edge.data)
        new_edge.data = name
        edges_to_offset.append((edge,
                                nstate.add_edge(node, None, edge.dst,
                                                edge.dst_conn, new_edge)))
    for edge, name in output_names.items():
        node = nstate.add_write(name)
        new_edge = copy.deepcopy(edge.data)
        new_edge.data = name
        edges_to_offset.append((edge,
                                nstate.add_edge(edge.src, edge.src_conn, node,
                                                None, new_edge)))

    # Offset memlet paths inside nested SDFG according to subsets
    for original_edge, new_edge in edges_to_offset:
        for edge in nstate.memlet_tree(new_edge):
            edge.data.data = new_edge.data.data
            if not full_data:
                edge.data.subset.offset(
                    global_subsets[original_edge.data.data][1], True)

    # Add nested SDFG node to the input state
    nested_sdfg = state.add_nested_sdfg(
        nsdfg, None,
        set(input_names.values()) | input_arrays,
        set(output_names.values()) | output_arrays.keys())

    # Reconnect memlets to nested SDFG; each unified nested array is
    # reconnected only once, empty-memlet edges are remembered for the
    # fallback connections below
    reconnected_in = set()
    reconnected_out = set()
    empty_input = None
    empty_output = None
    for edge in inputs:
        if edge.data.data is None:
            empty_input = edge
            continue

        name = input_names[edge]
        if name in reconnected_in:
            continue
        if full_data:
            data = Memlet.from_array(edge.data.data,
                                     sdfg.arrays[edge.data.data])
        else:
            data = copy.deepcopy(edge.data)
            data.subset = global_subsets[edge.data.data][1]
        state.add_edge(edge.src, edge.src_conn, nested_sdfg, name, data)
        reconnected_in.add(name)

    for edge in outputs:
        if edge.data.data is None:
            empty_output = edge
            continue

        name = output_names[edge]
        if name in reconnected_out:
            continue
        if full_data:
            data = Memlet.from_array(edge.data.data,
                                     sdfg.arrays[edge.data.data])
        else:
            data = copy.deepcopy(edge.data)
            data.subset = global_subsets[edge.data.data][1]
        data.wcr = edge.data.wcr
        state.add_edge(nested_sdfg, name, edge.dst, edge.dst_conn, data)
        reconnected_out.add(name)

    # Connect access nodes to internal input/output data as necessary
    entry = scope.entry
    exit = scope.exit
    for name in input_arrays:
        node = state.add_read(name)
        if entry is not None:
            state.add_nedge(entry, node, Memlet())
        state.add_edge(node, None, nested_sdfg, name,
                       Memlet.from_array(name, sdfg.arrays[name]))
    for name, wcr in output_arrays.items():
        node = state.add_write(name)
        if exit is not None:
            state.add_nedge(node, exit, Memlet())
        state.add_edge(nested_sdfg, name, node, None,
                       Memlet(data=name, wcr=wcr))

    # Graph was not reconnected, but needs to be
    if state.in_degree(nested_sdfg) == 0 and empty_input is not None:
        state.add_edge(empty_input.src, empty_input.src_conn, nested_sdfg,
                       None, empty_input.data)
    if state.out_degree(nested_sdfg) == 0 and empty_output is not None:
        state.add_edge(nested_sdfg, None, empty_output.dst,
                       empty_output.dst_conn, empty_output.data)

    # Remove subgraph nodes from graph
    state.remove_nodes_from(subgraph.nodes())

    # Remove subgraph transients from top-level graph
    for transient in subgraph_transients:
        del sdfg.arrays[transient]

    # Remove newly isolated nodes due to memlet consolidation
    for edge in inputs:
        if state.in_degree(edge.src) + state.out_degree(edge.src) == 0:
            state.remove_node(edge.src)
    for edge in outputs:
        if state.in_degree(edge.dst) + state.out_degree(edge.dst) == 0:
            state.remove_node(edge.dst)

    return nested_sdfg
def apply(self, state: SDFGState, sdfg: SDFG) -> List[nodes.AccessNode]:
    """
    Converts the memory accesses on ``self.access`` into stream accesses.

    New transient streams (shape-1, created via ``sdfg.add_stream``) replace
    the direct memlets to/from the access node; a separate read/write
    component (map nest + tasklet) is built per unique canonical access
    pattern to move data between the original array and the streams.

    :param state: The state containing the matched access node.
    :param sdfg: The SDFG to modify (new stream descriptors are added to it).
    :return: The newly-created stream access nodes (one per component edge).
    """
    dnode: nodes.AccessNode = self.access
    # expr_index 0 means the access node is being read (outgoing edges are
    # streamified); otherwise it is being written (incoming edges).
    if self.expr_index == 0:
        edges = state.out_edges(dnode)
    else:
        edges = state.in_edges(dnode)

    # To understand how many components we need to create, all map ranges
    # throughout memlet paths must match exactly. We thus create a
    # dictionary of unique ranges
    mapping: Dict[Tuple[subsets.Range],
                  List[gr.MultiConnectorEdge[mm.Memlet]]] = defaultdict(list)
    ranges = {}
    for edge in edges:
        mpath = state.memlet_path(edge)
        ranges[edge] = _collect_map_ranges(state, mpath)
        # Group edges by the tuple of map ranges along their memlet path
        mapping[tuple(r[1] for r in ranges[edge])].append(edge)

    # Collect all edges with the same memory access pattern
    components_to_create: Dict[
        Tuple[symbolic.SymbolicType],
        List[gr.MultiConnectorEdge[mm.Memlet]]] = defaultdict(list)
    for edges_with_same_range in mapping.values():
        for edge in edges_with_same_range:
            # Get memlet path and innermost edge
            mpath = state.memlet_path(edge)
            innermost_edge = copy.deepcopy(mpath[-1] if self.expr_index == 0
                                           else mpath[0])

            # Store memlets of the same access in the same component
            expr = _canonicalize_memlet(innermost_edge.data, ranges[edge])
            components_to_create[expr].append((innermost_edge, edge))
    components = list(components_to_create.values())

    # Split out components that have dependencies between them to avoid
    # deadlocks
    if self.expr_index == 0:
        ccs_to_add = []
        for i, component in enumerate(components):
            edges_to_remove = set()
            for cedge in component:
                # If another edge in this component can reach this edge's
                # destination, reading them from one component could
                # deadlock; move the dependent edge to its own component.
                if any(
                        nx.has_path(state.nx, o[1].dst, cedge[1].dst)
                        for o in component if o is not cedge):
                    ccs_to_add.append([cedge])
                    edges_to_remove.add(cedge)
            if edges_to_remove:
                components[i] = [
                    c for c in component if c not in edges_to_remove
                ]
        components.extend(ccs_to_add)
    # End of split

    desc = sdfg.arrays[dnode.data]

    # Create new streams of shape 1
    streams = {}
    mpaths = {}
    for edge in edges:
        name, newdesc = sdfg.add_stream(dnode.data,
                                        desc.dtype,
                                        buffer_size=self.buffer_size,
                                        storage=self.storage,
                                        transient=True,
                                        find_new_name=True)
        streams[edge] = name
        mpath = state.memlet_path(edge)
        mpaths[edge] = mpath

        # Replace memlets in path with stream access
        for e in mpath:
            e.data = mm.Memlet(data=name,
                               subset='0',
                               other_subset=e.data.other_subset)
            # Crossing a nested SDFG boundary: mark the memlet dynamic and
            # streamify the corresponding descriptor inside the nested SDFG.
            if isinstance(e.src, nodes.NestedSDFG):
                e.data.dynamic = True
                _streamify_recursive(e.src, e.src_conn, newdesc)
            if isinstance(e.dst, nodes.NestedSDFG):
                e.data.dynamic = True
                _streamify_recursive(e.dst, e.dst_conn, newdesc)

        # Replace access node and memlet tree with one access
        if self.expr_index == 0:
            replacement = state.add_read(name)
            state.remove_edge(edge)
            state.add_edge(replacement, edge.src_conn, edge.dst,
                           edge.dst_conn, edge.data)
        else:
            replacement = state.add_write(name)
            state.remove_edge(edge)
            state.add_edge(edge.src, edge.src_conn, replacement,
                           edge.dst_conn, edge.data)

    # Make read/write components
    ionodes = []
    for component in components:

        # Pick the first edge as the edge to make the component from
        innermost_edge, outermost_edge = component[0]
        mpath = mpaths[outermost_edge]
        mapname = streams[outermost_edge]
        innermost_edge.data.other_subset = None

        # Get edge data and streams
        if self.expr_index == 0:
            # Read component: one tasklet reads the array and fans out to
            # all the component's streams.
            opname = 'read'
            path = [e.dst for e in mpath[:-1]]
            rmemlets = [(dnode, '__inp', innermost_edge.data)]
            wmemlets = []
            for i, (_, edge) in enumerate(component):
                name = streams[edge]
                ionode = state.add_write(name)
                ionodes.append(ionode)
                wmemlets.append(
                    (ionode, '__out%d' % i, mm.Memlet(data=name,
                                                      subset='0')))
            code = '\n'.join('__out%d = __inp' % i
                             for i in range(len(component)))
        else:
            # More than one input stream might mean a data race, so we only
            # address the first one in the tasklet code
            if len(component) > 1:
                warnings.warn(
                    f'More than one input found for the same index for {dnode.data}'
                )
            opname = 'write'
            path = [state.entry_node(e.src) for e in reversed(mpath[1:])]
            wmemlets = [(dnode, '__out', innermost_edge.data)]
            rmemlets = []
            for i, (_, edge) in enumerate(component):
                name = streams[edge]
                ionode = state.add_read(name)
                ionodes.append(ionode)
                rmemlets.append(
                    (ionode, '__inp%d' % i, mm.Memlet(data=name,
                                                      subset='0')))
            code = '__out = __inp0'

        # Create map structure for read/write component, mirroring the
        # params/ranges/schedule of the maps along the original path
        maps = []
        for entry in path:
            map: nodes.Map = entry.map
            maps.append(
                state.add_map(f'__s{opname}_{mapname}',
                              [(p, r)
                               for p, r in zip(map.params, map.range)],
                              map.schedule))
        tasklet = state.add_tasklet(
            f'{opname}_{mapname}',
            {m[1] for m in rmemlets},
            {m[1] for m in wmemlets},
            code,
        )
        # Route memlets through the freshly created map nest
        for node, cname, memlet in rmemlets:
            state.add_memlet_path(node,
                                  *(me for me, _ in maps),
                                  tasklet,
                                  dst_conn=cname,
                                  memlet=memlet)
        for node, cname, memlet in wmemlets:
            state.add_memlet_path(tasklet,
                                  *(mx for _, mx in reversed(maps)),
                                  node,
                                  src_conn=cname,
                                  memlet=memlet)

    return ionodes