def _wait(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, request: str):
    """Adds an MPI Wait library node for the given request array.

    :param request: Name of the request array, or a (name, subset-range)
                    tuple selecting part of it.
    :return: None (the call has no value; it only mutates the state).
    """
    from dace.libraries.mpi.nodes.wait import Wait
    wait_node = Wait('_Wait_')

    # The request may come as a plain array name or as (name, range).
    if isinstance(request, tuple):
        req_name, req_range = request
    else:
        req_name, req_range = request, None
    req_desc = sdfg.arrays[req_name]
    req_access = state.add_access(req_name)

    # Transient scalars that receive the status source/tag outputs.
    status_src = sdfg.add_temp_transient([1], dtypes.int32)
    src_access = state.add_write(status_src[0])
    status_tag = sdfg.add_temp_transient([1], dtypes.int32)
    tag_access = state.add_write(status_tag[0])

    request_memlet = (Memlet.simple(req_name, req_range)
                      if req_range else Memlet.from_array(req_name, req_desc))

    state.add_edge(req_access, None, wait_node, '_request', request_memlet)
    state.add_edge(wait_node, '_stat_source', src_access, None,
                   Memlet.from_array(*status_src))
    state.add_edge(wait_node, '_stat_tag', tag_access, None,
                   Memlet.from_array(*status_tag))
    return None
def _reduce(sdfg: SDFG,
            state: SDFGState,
            redfunction: Callable[[Any, Any], Any],
            in_array: str,
            out_array=None,
            axis=None,
            identity=None):
    """Adds a reduction over ``in_array`` to the state.

    :param redfunction: Reduction lambda (e.g. ``lambda a, b: a + b``).
    :param in_array: Name (optionally with subscript) of the input array.
    :param out_array: Optional name of the output array; if None, a transient
                      with the reduced shape is created.
    :param axis: Axis or axes to reduce over; None reduces to a scalar.
    :param identity: Optional identity value for the reduction.
    :return: The output array name if one was created, otherwise ``[]``
             (meaning "no value" to the caller).
    """
    # Normalize axes once for both code paths (was duplicated in each branch).
    if axis is not None:
        if not isinstance(axis, (tuple, list)):
            axis = (axis, )
        axis = tuple(pystr_to_symbolic(a) for a in axis)

    inarr = in_array
    # Parse any subscript in the input expression into a subset.
    input_subset = parse_memlet_subset(sdfg.arrays[inarr],
                                       ast.parse(in_array).body[0].value, {})
    input_memlet = Memlet.simple(inarr, input_subset)

    if out_array is None:
        # Create the output transient: scalar when axis is None, otherwise
        # the input subset with the reduced dimensions removed.
        if axis is None:
            output_shape = [1]
        else:
            output_subset = copy.deepcopy(input_subset)
            output_subset.pop(axis)
            output_shape = output_subset.size()
        outarr, arr = sdfg.add_temp_transient(output_shape,
                                              sdfg.arrays[inarr].dtype,
                                              sdfg.arrays[inarr].storage)
        output_memlet = Memlet.from_array(outarr, arr)
    else:
        outarr = out_array
        output_subset = parse_memlet_subset(sdfg.arrays[outarr],
                                            ast.parse(out_array).body[0].value,
                                            {})
        output_memlet = Memlet.simple(outarr, output_subset)

    # Create reduce subgraph
    inpnode = state.add_read(inarr)
    rednode = state.add_reduce(redfunction, axis, identity)
    outnode = state.add_write(outarr)
    state.add_nedge(inpnode, rednode, input_memlet)
    state.add_nedge(rednode, outnode, output_memlet)

    if out_array is None:
        return outarr
    else:
        return []
def _assignop(sdfg: SDFG, state: SDFGState, op1: str, opcode: str,
              opname: str):
    """ Implements a general element-wise array assignment operator. """
    arr1 = sdfg.arrays[op1]
    name, _ = sdfg.add_temp_transient(arr1.shape, arr1.dtype, arr1.storage)

    # One index symbol per dimension, shared by read and write memlets.
    index_expr = ','.join('__i%d' % i for i in range(len(arr1.shape)))
    if opcode:
        # Augmented assignment: accumulate via write-conflict resolution.
        write_memlet = Memlet.simple(name,
                                     index_expr,
                                     wcr_str='lambda x, y: x %s y' % opcode)
    else:
        write_memlet = Memlet.simple(name, index_expr)

    state.add_mapped_tasklet(
        "_%s_" % opname,
        {'__i%d' % i: '0:%s' % s for i, s in enumerate(arr1.shape)},
        {'__in1': Memlet.simple(op1, index_expr)},
        '__out = __in1',
        {'__out': write_memlet},
        external_edges=True)
    return name
def _define_local_ex(sdfg: SDFG,
                     state: SDFGState,
                     shape: Shape,
                     dtype: dace.typeclass,
                     storage: dtypes.StorageType = dtypes.StorageType.Default):
    """ Defines a local array in a DaCe program. """
    # Only the generated name is returned; the descriptor is discarded.
    return sdfg.add_temp_transient(shape, dtype, storage=storage)[0]
def eye(sdfg: SDFG, state: SDFGState, N, M=None, k=0, dtype=dace.float64):
    """Creates an N-by-M array with ones on the k-th diagonal, zeros
    elsewhere (numpy.eye semantics), and returns its name."""
    M = M or N  # Square matrix when M is unset
    name, _ = sdfg.add_temp_transient([N, M], dtype)
    # Each map iteration writes one element of the output.
    state.add_mapped_tasklet(
        'eye',
        {'i': '0:%s' % N, 'j': '0:%s' % M},
        {},
        'val = 1 if i == (j - %s) else 0' % k,
        {'val': dace.Memlet.simple(name, 'i, j')},
        external_edges=True)
    return name
def _transpose(sdfg: SDFG, state: SDFGState, inpname: str):
    """Transposes a 2-D array via the BLAS Transpose library node and
    returns the name of the new transient holding the result."""
    in_desc = sdfg.arrays[inpname]
    dtype = in_desc.dtype
    # Output has the two dimensions swapped.
    outname, out_desc = sdfg.add_temp_transient(
        (in_desc.shape[1], in_desc.shape[0]), dtype, in_desc.storage)
    read = state.add_read(inpname)
    write = state.add_write(outname)

    import dace.libraries.blas  # Avoid import loop
    node = dace.libraries.blas.Transpose('_Transpose_', dtype)
    state.add_node(node)
    state.add_edge(read, None, node, '_inp',
                   dace.Memlet.from_array(inpname, in_desc))
    state.add_edge(node, '_out', write, None,
                   dace.Memlet.from_array(outname, out_desc))
    return outname
def _simple_call(sdfg: SDFG,
                 state: SDFGState,
                 inpname: str,
                 func: str,
                 restype: dace.typeclass = None):
    """ Implements a simple call of the form `out = func(inp)`. """
    inparr = sdfg.arrays[inpname]
    if restype is None:
        restype = sdfg.arrays[inpname].dtype
    outname, outarr = sdfg.add_temp_transient(inparr.shape, restype,
                                              inparr.storage)

    total = reduce(lambda a, b: a * b, inparr.shape)
    if total == 1:
        # Scalar input: a single tasklet between two access nodes suffices.
        read = state.add_read(inpname)
        write = state.add_write(outname)
        tasklet = state.add_tasklet(func, {'__inp'}, {'__out'},
                                    '__out = {f}(__inp)'.format(f=func))
        state.add_edge(read, None, tasklet, '__inp',
                       Memlet.from_array(inpname, inparr))
        state.add_edge(tasklet, '__out', write, None,
                       Memlet.from_array(outname, outarr))
        return outname

    # General case: element-wise map over the whole array.
    index_expr = ','.join('__i%d' % i for i in range(len(inparr.shape)))
    state.add_mapped_tasklet(
        name=func,
        map_ranges={'__i%d' % i: '0:%s' % n
                    for i, n in enumerate(inparr.shape)},
        inputs={'__inp': Memlet.simple(inpname, index_expr)},
        code='__out = {f}(__inp)'.format(f=func),
        outputs={'__out': Memlet.simple(outname, index_expr)},
        external_edges=True)
    return outname
def _binop(sdfg: SDFG, state: SDFGState, op1: str, op2: str, opcode: str,
           opname: str, restype: dace.typeclass):
    """ Implements a general element-wise array binary operator. """
    left = sdfg.arrays[op1]
    right = sdfg.arrays[op2]
    # Broadcast the two shapes together (NumPy-style), yielding the output
    # shape, the joint map ranges and the per-operand index expressions.
    out_shape, all_idx_dict, all_idx, arr1_idx, arr2_idx = \
        _broadcast_together(left.shape, right.shape)

    name, _ = sdfg.add_temp_transient(out_shape, restype, left.storage)
    inputs = {
        '__in1': Memlet.simple(op1, arr1_idx),
        '__in2': Memlet.simple(op2, arr2_idx),
    }
    state.add_mapped_tasklet("_%s_" % opname,
                             all_idx_dict,
                             inputs,
                             '__out = __in1 %s __in2' % opcode,
                             {'__out': Memlet.simple(name, all_idx)},
                             external_edges=True)
    return name
def _unop(sdfg: SDFG, state: SDFGState, op1: str, opcode: str, opname: str):
    """ Implements a general element-wise array unary operator. """
    desc = sdfg.arrays[op1]
    name, _ = sdfg.add_temp_transient(desc.shape, desc.dtype, desc.storage)

    # Same index expression for both the read and the write.
    index_expr = ','.join('__i%d' % i for i in range(len(desc.shape)))
    state.add_mapped_tasklet(
        "_%s_" % opname,
        {'__i%d' % i: '0:%s' % s for i, s in enumerate(desc.shape)},
        {'__in1': Memlet.simple(op1, index_expr)},
        '__out = %s __in1' % opcode,
        {'__out': Memlet.simple(name, index_expr)},
        external_edges=True)
    return name
def _array_x_binop(visitor: 'ProgramVisitor', sdfg: SDFG, state: SDFGState,
                   op1: str, op2: str, op: str, opcode: str):
    """Dispatches a binary operation between two arrays/scalars.

    Determines the result type from both operand types (NumPy promotion
    rules, or boolean for comparison-style ops), then emits either a single
    scalar-scalar tasklet or a broadcast element-wise map via ``_binop``.

    :param op1: Name of the left operand array in ``sdfg``.
    :param op2: Name of the right operand array in ``sdfg``.
    :param op: Operation name (used for naming and boolean-op detection).
    :param opcode: Python operator string used inside the tasklet code.
    :return: Name of the array containing the result.
    """
    arr1 = sdfg.arrays[op1]
    type1 = arr1.dtype.type
    isscal1 = _is_scalar(sdfg, op1)
    # If the scalar originated from a Python literal, recover its original
    # type from the visitor's number table (presumably maps constants to
    # data names — TODO confirm against ProgramVisitor).
    isnum1 = isscal1 and (op1 in visitor.numbers.values())
    if isnum1:
        type1 = inverse_dict_lookup(visitor.numbers, op1)
    arr2 = sdfg.arrays[op2]
    type2 = arr2.dtype.type
    isscal2 = _is_scalar(sdfg, op2)
    isnum2 = isscal2 and (op2 in visitor.numbers.values())
    if isnum2:
        type2 = inverse_dict_lookup(visitor.numbers, op2)
    # Comparisons and logical ops yield booleans; everything else follows
    # NumPy type-promotion of the two operand types.
    if _is_op_boolean(op):
        restype = dace.bool
    else:
        restype = dace.DTYPE_TO_TYPECLASS[np.result_type(type1, type2).type]
    if isscal1 and isscal2:
        # Scalar-scalar: single tasklet, result stored in a 1-element
        # transient.
        arr1 = sdfg.arrays[op1]
        arr2 = sdfg.arrays[op2]
        op3, arr3 = sdfg.add_temp_transient([1], restype, arr2.storage)
        tasklet = state.add_tasklet('_SS%s_' % op, {'s1', 's2'}, {'s3'},
                                    's3 = s1 %s s2' % opcode)
        n1 = state.add_read(op1)
        n2 = state.add_read(op2)
        n3 = state.add_write(op3)
        state.add_edge(n1, None, tasklet, 's1',
                       dace.Memlet.from_array(op1, arr1))
        state.add_edge(n2, None, tasklet, 's2',
                       dace.Memlet.from_array(op2, arr2))
        state.add_edge(tasklet, 's3', n3, None,
                       dace.Memlet.from_array(op3, arr3))
        return op3
    else:
        # At least one array operand: element-wise broadcast map.
        return _binop(sdfg, state, op1, op2, opcode, op, restype)
def _distr_matmult(pv: 'ProgramVisitor',
                   sdfg: SDFG,
                   state: SDFGState,
                   opa: str,
                   opb: str,
                   shape: Sequence[Union[sp.Expr, Number]],
                   a_block_sizes: Union[str, Sequence[Union[sp.Expr,
                                                            Number]]] = None,
                   b_block_sizes: Union[str, Sequence[Union[sp.Expr,
                                                            Number]]] = None,
                   c_block_sizes: Union[str, Sequence[Union[sp.Expr,
                                                            Number]]] = None):
    """Distributed (PBLAS) matrix-matrix or matrix-vector multiplication.

    :param opa: Name of the first operand array.
    :param opb: Name of the second operand array.
    :param shape: Global shape (gm, gn) or (gm, gn, gk).
    :param a_block_sizes: Block sizes of A; an array name, a (name, range)
                          tuple, or a literal sequence of sizes.
    :param b_block_sizes: Block sizes of B (same conventions).
    :param c_block_sizes: Optional block sizes of the output.
    :return: Name of the output transient array.
    :raises ValueError: If the operand ranks are not (2,2), (2,1) or (1,2).
    """
    arra = sdfg.arrays[opa]
    arrb = sdfg.arrays[opb]

    # NOTE(review): gk is only bound for 3-element shapes but is used by the
    # Pgemm (2-D x 2-D) path below — callers presumably always pass a
    # 3-element shape for that case; confirm.
    if len(shape) == 3:
        gm, gn, gk = shape
    else:
        gm, gn = shape

    a_block_sizes = a_block_sizes or arra.shape
    if len(a_block_sizes) < 2:
        a_block_sizes = (a_block_sizes[0], 1)
    b_block_sizes = b_block_sizes or arrb.shape
    if len(b_block_sizes) < 2:
        b_block_sizes = (b_block_sizes[0], 1)

    # For the transposed-Gemv case (vector x matrix), a and b are swapped
    # below, so swap the block sizes to match.
    if len(arra.shape) == 1 and len(arrb.shape) == 2:
        a_block_sizes, b_block_sizes = b_block_sizes, a_block_sizes

    a_bsizes_range = None
    if isinstance(a_block_sizes, (list, tuple)):
        if isinstance(a_block_sizes[0], str):
            a_bsizes_name, a_bsizes_range = a_block_sizes
            a_bsizes_desc = sdfg.arrays[a_bsizes_name]
            a_bsizes_node = state.add_read(a_bsizes_name)
        else:
            # Literal sizes: materialize into a small int32 transient.
            a_bsizes_name, a_bsizes_desc = sdfg.add_temp_transient(
                (len(a_block_sizes), ), dtype=dace.int32)
            a_bsizes_node = state.add_access(a_bsizes_name)
            a_bsizes_tasklet = state.add_tasklet(
                '_set_a_bsizes_', {}, {'__out'}, ";".join([
                    "__out[{}] = {}".format(i, sz)
                    for i, sz in enumerate(a_block_sizes)
                ]))
            state.add_edge(a_bsizes_tasklet, '__out', a_bsizes_node, None,
                           Memlet.from_array(a_bsizes_name, a_bsizes_desc))
    else:
        a_bsizes_name = a_block_sizes
        a_bsizes_desc = sdfg.arrays[a_bsizes_name]
        a_bsizes_node = state.add_read(a_bsizes_name)

    b_bsizes_range = None
    # BUGFIX: this branch previously tested `a_block_sizes` instead of
    # `b_block_sizes`, and unpacked into a misspelled `b_sizes_range`, so
    # `b_bsizes_range` was never set for (name, range) inputs.
    if isinstance(b_block_sizes, (list, tuple)):
        if isinstance(b_block_sizes[0], str):
            b_bsizes_name, b_bsizes_range = b_block_sizes
            b_bsizes_desc = sdfg.arrays[b_bsizes_name]
            b_bsizes_node = state.add_read(b_bsizes_name)
        else:
            b_bsizes_name, b_bsizes_desc = sdfg.add_temp_transient(
                (len(b_block_sizes), ), dtype=dace.int32)
            b_bsizes_node = state.add_access(b_bsizes_name)
            b_bsizes_tasklet = state.add_tasklet(
                '_set_b_sizes_', {}, {'__out'}, ";".join([
                    "__out[{}] = {}".format(i, sz)
                    for i, sz in enumerate(b_block_sizes)
                ]))
            state.add_edge(b_bsizes_tasklet, '__out', b_bsizes_node, None,
                           Memlet.from_array(b_bsizes_name, b_bsizes_desc))
    else:
        b_bsizes_name = b_block_sizes
        b_bsizes_desc = sdfg.arrays[b_bsizes_name]
        b_bsizes_node = state.add_read(b_bsizes_name)

    if len(arra.shape) == 2 and len(arrb.shape) == 2:
        # Gemm
        from dace.libraries.pblas.nodes.pgemm import Pgemm
        tasklet = Pgemm("__DistrMatMult__", gm, gn, gk)
        m = arra.shape[0]
        n = arrb.shape[-1]
        out = sdfg.add_temp_transient((m, n), dtype=arra.dtype)
    elif len(arra.shape) == 2 and len(arrb.shape) == 1:
        # Gemv
        from dace.libraries.pblas.nodes.pgemv import Pgemv
        tasklet = Pgemv("__DistrMatVecMult__", m=gm, n=gn)
        if c_block_sizes:
            m = c_block_sizes[0]
        else:
            m = arra.shape[0]
        out = sdfg.add_temp_transient((m, ), dtype=arra.dtype)
    elif len(arra.shape) == 1 and len(arrb.shape) == 2:
        # Gemv transposed: swap a and b
        opa, opb = opb, opa
        arra, arrb = arrb, arra
        from dace.libraries.pblas.nodes.pgemv import Pgemv
        tasklet = Pgemv("__DistrMatVecMult__", transa='T', m=gm, n=gn)
        if c_block_sizes:
            n = c_block_sizes[0]
        else:
            n = arra.shape[1]
        out = sdfg.add_temp_transient((n, ), dtype=arra.dtype)
    else:
        # BUGFIX: previously fell through with `tasklet`/`out` unbound,
        # producing an opaque NameError. Fail with a clear message instead.
        raise ValueError(
            "Unsupported operand ranks for distributed matrix product: "
            "{} and {}".format(len(arra.shape), len(arrb.shape)))

    anode = state.add_read(opa)
    bnode = state.add_read(opb)
    cnode = state.add_write(out[0])

    if a_bsizes_range:
        a_bsizes_mem = Memlet.simple(a_bsizes_name, a_bsizes_range)
    else:
        a_bsizes_mem = Memlet.from_array(a_bsizes_name, a_bsizes_desc)
    if b_bsizes_range:
        b_bsizes_mem = Memlet.simple(b_bsizes_name, b_bsizes_range)
    else:
        b_bsizes_mem = Memlet.from_array(b_bsizes_name, b_bsizes_desc)

    state.add_edge(anode, None, tasklet, '_a', Memlet.from_array(opa, arra))
    state.add_edge(bnode, None, tasklet, '_b', Memlet.from_array(opb, arrb))
    state.add_edge(a_bsizes_node, None, tasklet, '_a_block_sizes',
                   a_bsizes_mem)
    state.add_edge(b_bsizes_node, None, tasklet, '_b_block_sizes',
                   b_bsizes_mem)
    state.add_edge(tasklet, '_c', cnode, None, Memlet.from_array(*out))

    return out[0]
def _matmult(visitor, sdfg: SDFG, state: SDFGState, op1: str, op2: str):
    """Implements the matrix-product operator (``@``) between two arrays.

    Supports matrix*matrix (optionally batched, rank <= 3), matrix*vector
    and vector*vector products, dispatching all of them to the BLAS MatMul
    library node.

    :param op1: Name of the left operand array in ``sdfg``.
    :param op2: Name of the right operand array in ``sdfg``.
    :return: Name of the transient array holding the result.
    :raises SyntaxError: On dimension mismatch or unsupported ranks.
    """
    from dace.libraries.blas.nodes.matmul import MatMul  # Avoid import loop
    arr1 = sdfg.arrays[op1]
    arr2 = sdfg.arrays[op2]
    if len(arr1.shape) > 1 and len(arr2.shape) > 1:  # matrix * matrix
        if len(arr1.shape) > 3 or len(arr2.shape) > 3:
            raise SyntaxError(
                'Matrix multiplication of tensors of dimensions > 3 '
                'not supported')
        if arr1.shape[-1] != arr2.shape[-2]:
            raise SyntaxError('Matrix dimension mismatch %s != %s' %
                              (arr1.shape[-1], arr2.shape[-2]))

        from dace.libraries.blas.nodes.matmul import _get_batchmm_opts

        # Determine batched multiplication: a truthy result means one of the
        # operands carries a batch dimension of size bopt['b'].
        bopt = _get_batchmm_opts(arr1.shape, arr1.strides, arr2.shape,
                                 arr2.strides, None, None)
        if bopt:
            output_shape = (bopt['b'], arr1.shape[-2], arr2.shape[-1])
        else:
            output_shape = (arr1.shape[-2], arr2.shape[-1])
    elif len(arr1.shape) == 2 and len(arr2.shape) == 1:  # matrix * vector
        if arr1.shape[1] != arr2.shape[0]:
            # NOTE(review): the two literals concatenate without a space
            # ("...must matchsize of vector...") — message typo, left as-is.
            raise SyntaxError("Number of matrix columns {} must match"
                              "size of vector {}.".format(
                                  arr1.shape[1], arr2.shape[0]))
        output_shape = (arr1.shape[0], )
    elif len(arr1.shape) == 1 and len(arr2.shape) == 1:  # vector * vector
        if arr1.shape[0] != arr2.shape[0]:
            raise SyntaxError("Vectors in vector product must have same size: "
                              "{} vs. {}".format(arr1.shape[0], arr2.shape[0]))
        output_shape = (1, )
    else:  # Dunno what this is, bail
        raise SyntaxError(
            "Cannot multiply arrays with shapes: {} and {}".format(
                arr1.shape, arr2.shape))
    # Result type follows NumPy type promotion of the operand types.
    type1 = arr1.dtype.type
    type2 = arr2.dtype.type
    restype = dace.DTYPE_TO_TYPECLASS[np.result_type(type1, type2).type]

    op3, arr3 = sdfg.add_temp_transient(output_shape, restype, arr1.storage)

    acc1 = state.add_read(op1)
    acc2 = state.add_read(op2)
    acc3 = state.add_write(op3)

    tasklet = MatMul('_MatMult_', restype)
    state.add_node(tasklet)
    state.add_edge(acc1, None, tasklet, '_a', dace.Memlet.from_array(op1, arr1))
    state.add_edge(acc2, None, tasklet, '_b', dace.Memlet.from_array(op2, arr2))
    state.add_edge(tasklet, '_c', acc3, None, dace.Memlet.from_array(op3, arr3))

    return op3
def _create_einsum_internal(sdfg: SDFG,
                            state: SDFGState,
                            einsum_string: str,
                            *arrays: str,
                            dtype: Optional[dtypes.typeclass] = None,
                            optimize: bool = False,
                            output: Optional[str] = None,
                            nodes: Optional[Dict[str, AccessNode]] = None,
                            init_output: bool = None):
    """Builds an einsum contraction in the given state.

    Depending on the expression, emits either a pure SDFG map with
    write-conflict resolution, or a (batched) GEMM via a nested SDFG.
    With ``optimize=True``, uses opt_einsum to split the expression into
    pairwise contractions and recurses on each.

    :param einsum_string: The einsum subscript expression (e.g. 'ij,jk->ik').
    :param arrays: Names of the input arrays, one per einsum input term.
    :param dtype: Optional output dtype; defaults to the first input's dtype
                  (or the existing output's dtype).
    :param output: Optional pre-existing output array name; if None, a
                   transient is created.
    :param nodes: Optional mapping from array name to existing AccessNode,
                  used when chaining contractions.
    :param init_output: Whether to zero-initialize the output before
                        accumulating (None = decide automatically).
    :return: Tuple of (output array name, output access node).
    """
    # Infer shapes and strides of input/output arrays
    einsum = EinsumParser(einsum_string)

    if len(einsum.inputs) != len(arrays):
        raise ValueError('Invalid number of arrays for einsum expression')

    # Get shapes from arrays and verify dimensionality
    chardict = {}
    for inp, inpname in zip(einsum.inputs, arrays):
        inparr = sdfg.arrays[inpname]
        if len(inp) != len(inparr.shape):
            raise ValueError('Dimensionality mismatch in input "%s"' % inpname)
        for char, shp in zip(inp, inparr.shape):
            if char in chardict and shp != chardict[char]:
                raise ValueError('Dimension mismatch in einsum expression')
            chardict[char] = shp

    if optimize:
        # Try to import opt_einsum
        try:
            import opt_einsum as oe
        except (ModuleNotFoundError, NameError, ImportError):
            raise ImportError('To optimize einsum expressions, please install '
                              'the "opt_einsum" package.')

        # Symbolic sizes cannot be fed to opt_einsum's cost model.
        for char, shp in chardict.items():
            if symbolic.issymbolic(shp):
                raise ValueError('Einsum optimization cannot be performed '
                                 'on symbolically-sized array dimension "%s" '
                                 'for subscript character "%s"' % (shp, char))

        # Create optimal contraction path
        # noinspection PyTypeChecker
        _, path_info = oe.contract_path(
            einsum_string, *oe.helpers.build_views(einsum_string, chardict))

        input_nodes = nodes or {arr: state.add_read(arr) for arr in arrays}
        result_node = None

        # Follow path and create a chain of operation SDFG states.
        # Each step contracts a pair of arrays into a new transient, which
        # replaces the pair in the working array list.
        for pair, nonfree, expr, after, blas in path_info.contraction_list:
            result, result_node = _create_einsum_internal(sdfg,
                                                          state,
                                                          expr,
                                                          arrays[pair[0]],
                                                          arrays[pair[1]],
                                                          dtype=dtype,
                                                          optimize=False,
                                                          output=None,
                                                          nodes=input_nodes)
            arrays = ([a for i, a in enumerate(arrays) if i not in pair] +
                      [result])
            input_nodes[result] = result_node

        return arrays[0], result_node
        # END of einsum optimization

    input_nodes = nodes or {arr: state.add_read(arr) for arr in arrays}

    # Get output shape from chardict, or [1] for a scalar output
    output_shape = list(map(lambda k: chardict[k], einsum.output)) or [1]
    output_index = ','.join(o for o in einsum.output) or '0'

    if output is None:
        dtype = dtype or sdfg.arrays[arrays[0]].dtype
        output, odesc = sdfg.add_temp_transient(output_shape, dtype)
        to_init = True
    else:
        odesc = sdfg.arrays[output]
        dtype = dtype or odesc.dtype
        # NOTE(review): `init_output or True` always evaluates truthy, so an
        # explicit init_output=False is ignored here — looks suspicious;
        # confirm intended semantics before changing.
        to_init = init_output or True

    # Conflicted writes occur when any input index does not appear in the
    # output (i.e., a reduction dimension exists).
    is_conflicted = not all(
        all(indim in einsum.output for indim in inp) for inp in einsum.inputs)
    if not is_conflicted and init_output is None:
        to_init = False

    if not einsum.is_bmm():
        # Fall back to "pure" SDFG einsum with conflict resolution
        c = state.add_write(output)

        # Add state before this one to initialize the output value
        if to_init:
            init_state = sdfg.add_state_before(state)
            if len(einsum.output) > 0:
                init_state.add_mapped_tasklet(
                    'einsum_reset',
                    {k: '0:%s' % chardict[k] for k in einsum.output}, {},
                    'out_%s = 0' % output,
                    {'out_%s' % output: Memlet.simple(output, output_index)},
                    external_edges=True)
            else:  # Scalar output
                t = init_state.add_tasklet('einsum_reset', set(),
                                           {'out_%s' % output},
                                           'out_%s = 0' % output)
                onode = init_state.add_write(output)
                init_state.add_edge(t, 'out_%s' % output, onode, None,
                                    Memlet.simple(output, '0'))

        wcr = 'lambda a,b: a+b' if is_conflicted else None
        # Pure einsum map: multiply one element from each input, summing
        # into the output with WCR when reduction dimensions exist.
        state.add_mapped_tasklet(
            'einsum', {k: '0:%s' % v for k, v in chardict.items()}, {
                'inp_%s' % arr: Memlet.simple(arr, ','.join(inp))
                for inp, arr in zip(einsum.inputs, arrays)
            },
            'out_%s = %s' % (output, ' * '.join('inp_%s' % arr
                                                for arr in arrays)),
            {
                'out_%s' % output: Memlet.simple(
                    output, output_index, wcr_str=wcr)
            },
            input_nodes=input_nodes,
            output_nodes={output: c},
            external_edges=True)
    else:
        # Represent einsum as a GEMM or batched GEMM (using library nodes)
        a_shape = sdfg.arrays[arrays[0]].shape
        b_shape = sdfg.arrays[arrays[1]].shape
        c_shape = output_shape

        a = input_nodes[arrays[0]]
        b = input_nodes[arrays[1]]
        c = state.add_write(output)

        # Compute GEMM dimensions and strides. Each s* entry is the element
        # stride of the corresponding logical dimension, derived from the
        # product of trailing shape entries.
        strides = dict(
            BATCH=prod([c_shape[dim] for dim in einsum.c_batch]),
            M=prod([a_shape[dim] for dim in einsum.a_only]),
            K=prod([a_shape[dim] for dim in einsum.a_sum]),
            N=prod([b_shape[dim] for dim in einsum.b_only]),
            sAM=prod(a_shape[einsum.a_only[-1] + 1:]) if einsum.a_only else 1,
            sAK=prod(a_shape[einsum.a_sum[-1] + 1:]) if einsum.a_sum else 1,
            sAB=prod(a_shape[einsum.a_batch[-1] + 1:]) if einsum.a_batch else 1,
            sBK=prod(b_shape[einsum.b_sum[-1] + 1:]) if einsum.b_sum else 1,
            sBN=prod(b_shape[einsum.b_only[-1] + 1:]) if einsum.b_only else 1,
            sBB=prod(b_shape[einsum.b_batch[-1] + 1:]) if einsum.b_batch else 1,
            sCM=prod(c_shape[einsum.c_a_only[-1] + 1:])
            if einsum.c_a_only else 1,
            sCN=prod(c_shape[einsum.c_b_only[-1] + 1:])
            if einsum.c_b_only else 1,
            sCB=prod(c_shape[einsum.c_batch[-1] + 1:]) if einsum.c_batch else 1)

        # Complement strides to make matrices as necessary
        if len(a_shape) == 1 and len(einsum.a_sum) == 1:
            strides['sAK'] = 1
            strides['sAB'] = strides['sAM'] = strides['K']

        if len(b_shape) == 1 and len(einsum.b_sum) == 1:
            strides['sBN'] = 1
            strides['sBK'] = 1
            strides['sBB'] = strides['K']

        if len(c_shape) == 1 and len(einsum.a_sum) == len(einsum.b_sum):
            strides['sCN'] = 1
            strides['sCB'] = strides['sCM'] = strides['N']

        # Create nested SDFG for GEMM
        nsdfg = create_batch_gemm_sdfg(dtype, strides)

        nsdfg_node = state.add_nested_sdfg(nsdfg, None, {'X', 'Y'}, {'Z'},
                                           strides)
        state.add_edge(a, None, nsdfg_node, 'X',
                       Memlet.from_array(a.data, a.desc(sdfg)))
        state.add_edge(b, None, nsdfg_node, 'Y',
                       Memlet.from_array(b.data, b.desc(sdfg)))
        state.add_edge(nsdfg_node, 'Z', c, None,
                       Memlet.from_array(c.data, c.desc(sdfg)))

    return output, c
def _argminmax(sdfg: SDFG,
               state: SDFGState,
               a: str,
               axis,
               func,
               result_type=dace.int32,
               return_both=False):
    """Implements argmin/argmax over one axis of array ``a``.

    Uses a struct of (value, index) with a custom write-conflict resolution
    that keeps the winning value's index, then extracts the index (and
    optionally the value).

    :param a: Name of the input array in ``sdfg``.
    :param axis: Integer axis to reduce over (must be an int).
    :param func: Either 'min' or 'max'.
    :param result_type: dtype of the returned index array.
    :param return_both: If True, returns both the values and the indices.
    :return: (nest, output-name) or (nest, (values-name, indices-name)).
    :raises SyntaxError: If ``axis`` is not a valid int for ``a``'s shape.
    """
    nest = NestedCall(sdfg, state)

    assert func in ['min', 'max']

    if axis is None or type(axis) is not int:
        raise SyntaxError('Axis must be an int')

    a_arr = sdfg.arrays[a]

    if not 0 <= axis < len(a_arr.shape):
        raise SyntaxError("Expected 0 <= axis < len({}.shape), got {}".format(
            a, axis))

    # Output shape is the input shape with the reduced axis removed.
    reduced_shape = list(copy.deepcopy(a_arr.shape))
    reduced_shape.pop(axis)

    # Packed (value, index) pairs let a single WCR pick the arg-extremum.
    val_and_idx = dace.struct('_val_and_idx', val=a_arr.dtype, idx=result_type)

    # HACK: since identity cannot be specified for structs, we have to init
    # the output array
    reduced_structs, reduced_struct_arr = sdfg.add_temp_transient(
        reduced_shape, val_and_idx)

    # Initialize with the neutral value for the chosen extremum and idx=-1.
    code = "__init = _val_and_idx(val={}, idx=-1)".format(
        dtypes.min_value(a_arr.dtype) if func ==
        'max' else dtypes.max_value(a_arr.dtype))

    nest.add_state().add_mapped_tasklet(
        name="_arg{}_convert_".format(func),
        map_ranges={
            '__i%d' % i: '0:%s' % n
            for i, n in enumerate(a_arr.shape) if i != axis
        },
        inputs={},
        code=code,
        outputs={
            '__init': Memlet.simple(
                reduced_structs,
                ','.join('__i%d' % i for i in range(len(a_arr.shape))
                         if i != axis))
        },
        external_edges=True)

    # Reduce over the full input; the WCR keeps the (val, idx) pair whose
    # val wins the min/max comparison.
    nest.add_state().add_mapped_tasklet(
        name="_arg{}_reduce_".format(func),
        map_ranges={'__i%d' % i: '0:%s' % n
                    for i, n in enumerate(a_arr.shape)},
        inputs={
            '__in': Memlet.simple(
                a, ','.join('__i%d' % i for i in range(len(a_arr.shape))))
        },
        code="__out = _val_and_idx(idx={}, val=__in)".format("__i%d" % axis),
        outputs={
            '__out': Memlet.simple(
                reduced_structs,
                ','.join('__i%d' % i for i in range(len(a_arr.shape))
                         if i != axis),
                wcr_str=("lambda x, y:"
                         "_val_and_idx(val={}(x.val, y.val), "
                         "idx=(y.idx if x.val {} y.val else x.idx))").format(
                             func, '<' if func == 'max' else '>'))
        },
        external_edges=True)

    if return_both:
        # Unpack the struct array into separate value and index arrays.
        outidx, outidxarr = sdfg.add_temp_transient(
            sdfg.arrays[reduced_structs].shape, result_type)
        outval, outvalarr = sdfg.add_temp_transient(
            sdfg.arrays[reduced_structs].shape, a_arr.dtype)

        nest.add_state().add_mapped_tasklet(
            name="_arg{}_extract_".format(func),
            map_ranges={
                '__i%d' % i: '0:%s' % n
                for i, n in enumerate(a_arr.shape) if i != axis
            },
            inputs={
                '__in': Memlet.simple(
                    reduced_structs,
                    ','.join('__i%d' % i for i in range(len(a_arr.shape))
                             if i != axis))
            },
            code="__out_val = __in.val\n__out_idx = __in.idx",
            outputs={
                '__out_val': Memlet.simple(
                    outval,
                    ','.join('__i%d' % i for i in range(len(a_arr.shape))
                             if i != axis)),
                '__out_idx': Memlet.simple(
                    outidx,
                    ','.join('__i%d' % i for i in range(len(a_arr.shape))
                             if i != axis))
            },
            external_edges=True)

        return nest, (outval, outidx)
    else:
        # map to result_type: project out only the index field.
        out, outarr = sdfg.add_temp_transient(
            sdfg.arrays[reduced_structs].shape, result_type)
        nest(_elementwise)("lambda x: x.idx", reduced_structs, out_array=out)

        return nest, out
def _elementwise(sdfg: SDFG,
                 state: SDFGState,
                 func: str,
                 in_array: str,
                 out_array=None):
    """Apply a lambda function to each element in the input.

    :param func: Source string of a one-argument lambda, e.g. "lambda x: x+1".
    :param in_array: Name of the input array in ``sdfg``.
    :param out_array: Optional name of the output array; if None, a transient
                      with the input's shape/dtype/storage is created.
    :return: Name of the output array.
    :raises SyntaxError: If ``func`` is not a parseable one-argument lambda.
    """
    inparr = sdfg.arrays[in_array]
    restype = sdfg.arrays[in_array].dtype

    if out_array is None:
        out_array, outarr = sdfg.add_temp_transient(inparr.shape, restype,
                                                    inparr.storage)
    else:
        outarr = sdfg.arrays[out_array]

    func_ast = ast.parse(func)
    try:
        lambda_ast = func_ast.body[0].value
        if len(lambda_ast.args.args) != 1:
            # BUGFIX: this previously read `lambda_ast.args.arrgs`, raising
            # AttributeError while formatting the message — which the handler
            # below caught, masking this error with the generic one.
            raise SyntaxError(
                "Expected lambda with one arg, but {} has {}".format(
                    func, len(lambda_ast.args.args)))
        arg = lambda_ast.args.args[0].arg
        body = astutils.unparse(lambda_ast.body)
    except AttributeError:
        raise SyntaxError("Could not parse func {}".format(func))

    code = "__out = {}".format(body)

    num_elements = reduce(lambda x, y: x * y, inparr.shape)
    if num_elements == 1:
        # Scalar case: single tasklet.
        inp = state.add_read(in_array)
        out = state.add_write(out_array)
        tasklet = state.add_tasklet("_elementwise_", {arg}, {'__out'}, code)
        state.add_edge(inp, None, tasklet, arg,
                       Memlet.from_array(in_array, inparr))
        state.add_edge(tasklet, '__out', out, None,
                       Memlet.from_array(out_array, outarr))
    else:
        # General case: element-wise map over all dimensions.
        state.add_mapped_tasklet(
            name="_elementwise_",
            map_ranges={
                '__i%d' % i: '0:%s' % n
                for i, n in enumerate(inparr.shape)
            },
            inputs={
                arg: Memlet.simple(
                    in_array,
                    ','.join(['__i%d' % i for i in range(len(inparr.shape))]))
            },
            code=code,
            outputs={
                '__out': Memlet.simple(
                    out_array,
                    ','.join(['__i%d' % i for i in range(len(inparr.shape))]))
            },
            external_edges=True)

    return out_array
def _bcgather(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState,
              in_buffer: str, out_buffer: str,
              block_sizes: Union[str, Sequence[Union[sp.Expr, Number]]]):
    """Adds a PBLAS block-cyclic gather library node between the given
    input/output buffers, wiring in the block-size descriptor.

    :param in_buffer: Input array name, or (name, subset-range) tuple.
    :param out_buffer: Output array name, or (name, subset-range) tuple.
    :param block_sizes: Array name, (name, range) tuple, or a literal
                        sequence of block sizes.
    :return: None (only mutates the state).
    """
    from dace.libraries.pblas.nodes.pgeadd import BlockCyclicGather
    libnode = BlockCyclicGather('_BCGather_')

    # Input buffer: plain name or (name, range) pair.
    inbuf_name, inbuf_range = (in_buffer if isinstance(in_buffer, tuple) else
                               (in_buffer, None))
    in_desc = sdfg.arrays[inbuf_name]
    inbuf_node = state.add_read(inbuf_name)

    # Block sizes: existing array, (name, range) pair, or literal sizes.
    bsizes_range = None
    if isinstance(block_sizes, (list, tuple)):
        if isinstance(block_sizes[0], str):
            bsizes_name, bsizes_range = block_sizes
            bsizes_desc = sdfg.arrays[bsizes_name]
            bsizes_node = state.add_read(bsizes_name)
        else:
            # Literal sizes: materialize them into a small int32 transient.
            bsizes_name, bsizes_desc = sdfg.add_temp_transient(
                (len(block_sizes), ), dtype=dace.int32)
            bsizes_node = state.add_access(bsizes_name)
            set_code = ";".join("__out[{}] = {}".format(i, sz)
                                for i, sz in enumerate(block_sizes))
            bsizes_tasklet = state.add_tasklet('_set_bsizes_', {}, {'__out'},
                                               set_code)
            state.add_edge(bsizes_tasklet, '__out', bsizes_node, None,
                           Memlet.from_array(bsizes_name, bsizes_desc))
    else:
        bsizes_name = block_sizes
        bsizes_desc = sdfg.arrays[bsizes_name]
        bsizes_node = state.add_read(bsizes_name)

    # Output buffer: same convention as the input.
    outbuf_name, outbuf_range = (out_buffer if isinstance(out_buffer, tuple)
                                 else (out_buffer, None))
    out_desc = sdfg.arrays[outbuf_name]
    outbuf_node = state.add_write(outbuf_name)

    inbuf_mem = (Memlet.simple(inbuf_name, inbuf_range)
                 if inbuf_range else Memlet.from_array(inbuf_name, in_desc))
    bsizes_mem = (Memlet.simple(bsizes_name, bsizes_range) if bsizes_range
                  else Memlet.from_array(bsizes_name, bsizes_desc))
    outbuf_mem = (Memlet.simple(outbuf_name, outbuf_range) if outbuf_range
                  else Memlet.from_array(outbuf_name, out_desc))

    state.add_edge(inbuf_node, None, libnode, '_inbuffer', inbuf_mem)
    state.add_edge(bsizes_node, None, libnode, '_block_sizes', bsizes_mem)
    state.add_edge(libnode, '_outbuffer', outbuf_node, None, outbuf_mem)

    return None
def apply(self, sdfg: sd.SDFG):
    """Applies access deduplication: edges from the map entry that carry the
    same connector are grouped by contiguous subsets, each group is routed
    through a new transient, and the downstream memlets are renamed/offset
    to match.
    """
    graph: sd.SDFGState = sdfg.nodes()[self.state_id]
    map_entry = graph.node(self.subgraph[DeduplicateAccess._map_entry])
    node1 = graph.node(self.subgraph[DeduplicateAccess._node1])
    node2 = graph.node(self.subgraph[DeduplicateAccess._node2])

    # Steps:
    # 1. Find unique subsets
    # 2. Find sets of contiguous subsets
    # 3. Create transients for subsets
    # 4. Redirect edges through new transients

    edges1 = set(e.src_conn for e in graph.edges_between(map_entry, node1))
    edges2 = set(e.src_conn for e in graph.edges_between(map_entry, node2))

    # Only apply to first connector (determinism)
    conn = sorted(edges1 & edges2)[0]

    edges = [e for e in graph.out_edges(map_entry) if e.src_conn == conn]

    # Get original data descriptor
    dname = edges[0].data.data
    desc = sdfg.arrays[edges[0].data.data]

    # BUGFIX: removed leftover debug code that dumped the SDFG to
    # 'faulty_dedup.sdfg' whenever the destination access node's name
    # contained '15' — a stray development artifact with a file-write
    # side effect.

    # Get unique subsets
    unique_subsets = set(e.data.subset for e in edges)

    # Find largest contiguous subsets
    try:
        # Start from stride-1 dimension
        contiguous_subsets = helpers.find_contiguous_subsets(
            unique_subsets,
            dim=next(i for i, s in enumerate(desc.strides) if s == 1))
    except (StopIteration, NotImplementedError):
        # BUGFIX: corrected class-name typo in the warning message
        # ("DeduplicateAcces" -> "DeduplicateAccess").
        warnings.warn(
            "DeduplicateAccess::Not operating on Stride One Dimension!")
        contiguous_subsets = unique_subsets

    # Then find subsets for rest of the dimensions
    contiguous_subsets = helpers.find_contiguous_subsets(contiguous_subsets)

    # Map original edges to subsets
    edge_mapping = defaultdict(list)
    for e in edges:
        for ind, subset in enumerate(contiguous_subsets):
            if subset.covers(e.data.subset):
                edge_mapping[ind].append(e)
                break
        else:
            raise ValueError(
                "Failed to find contiguous subset for edge %s" % e.data)

    # Create transients for subsets and redirect edges
    for ind, subset in enumerate(contiguous_subsets):
        name, _ = sdfg.add_temp_transient(subset.size(), desc.dtype)
        anode = graph.add_access(name)
        graph.add_edge(map_entry, conn, anode, None,
                       Memlet(data=dname, subset=subset))
        for e in edge_mapping[ind]:
            graph.remove_edge(e)
            new_memlet = copy.deepcopy(e.data)
            new_edge = graph.add_edge(anode, None, e.dst, e.dst_conn,
                                      new_memlet)
            for pe in graph.memlet_tree(new_edge):
                # Rename data on memlet
                pe.data.data = name
                # Offset memlets to match new transient
                pe.data.subset.offset(subset, True)