def _get_buffer_size(self, state, loop_var, loop_axis):
    """
    Computes how many consecutive elements along ``loop_axis`` the buffer
    memlets of ``state`` touch, measured relative to the loop variable.

    For every memlet returned by ``self._buffer_memlets([state])``, the begin
    and end of its range in dimension ``loop_axis`` are rewritten as offsets
    from ``loop_var``. The size is the inclusive span between the smallest
    and the largest of those offsets.

    :param state: The state whose buffer memlets are inspected.
    :param loop_var: Name of the loop iteration variable.
    :param loop_axis: Index of the subset dimension indexed by the loop.
    :return: ``max_offset - min_offset + 1``, or 0 if there are no memlets.
    """
    itersym = symbolic.symbol(loop_var)
    offsets = []
    for memlet in self._buffer_memlets([state]):
        range_begin, range_end, _ = memlet.subset.ranges[loop_axis]
        offsets.append(range_begin - itersym)
        offsets.append(range_end - itersym)

    if not offsets:
        # Nothing is accessed, so nothing needs to be buffered (the previous
        # sentinel-based version returned a meaningless -1999 here).
        return 0

    # Compute the real extremes instead of clamping against fixed
    # sentinels, which silently broke for offsets outside [-1000, 1000].
    # NOTE(review): assumes offsets are comparable (e.g., integer constants).
    return max(offsets) - min(offsets) + 1
def _canonicalize_memlet(
        memlet: mm.Memlet, mapranges: List[Tuple[str, subsets.Range]]
) -> Tuple[symbolic.SymbolicType]:
    """
    Rewrites a single-element memlet subset so that it no longer refers to
    the map's own parameter names.

    The map parameters are renamed positionally to ``__dace0``, ``__dace1``,
    ..., so structurally identical accesses from maps with different
    parameter names yield the same result.

    :param memlet: Memlet whose subset is canonicalized.
    :param mapranges: ``(parameter name, range)`` pairs of the enclosing map.
    :return: A tuple with the canonicalized begin index of every dimension.
    """
    canonical_names = {}
    for position, (param, _) in enumerate(mapranges):
        canonical_names[symbolic.symbol(param)] = symbolic.symbol(
            '__dace%d' % position)

    begins = [begin for begin, _, _ in memlet.subset.ndrange()]
    return tuple(begin.subs(canonical_names) for begin in begins)
def replace(subgraph: 'dace.sdfg.state.StateGraphView', name: str,
            new_name: str):
    """
    Replaces every occurrence of a symbol or array name in a subgraph.

    Node properties are rewritten via ``replace_properties``. On every edge,
    the memlet's data name is renamed when it matches, and the symbol is
    substituted in ``subset``, ``other_subset`` and ``num_accesses``.

    :param subgraph: The graph or subgraph to replace in.
    :param name: Name to find.
    :param new_name: Name (or symbolic expression) to replace it with.
    """
    old_symbol = symbolic.symbol(name)
    symrepl = {
        old_symbol: (symbolic.pystr_to_symbolic(new_name)
                     if isinstance(new_name, str) else new_name)
    }

    # Node properties (ranges, code, data names, ...)
    for node in subgraph.nodes():
        replace_properties(node, name, new_name)

    # Memlets
    for edge in subgraph.edges():
        memlet = edge.data
        if memlet.data == name:
            memlet.data = new_name
        memlet.subset = _replsym(memlet.subset, symrepl)
        memlet.other_subset = _replsym(memlet.other_subset, symrepl)
        memlet.num_accesses = _replsym(memlet.num_accesses, symrepl)
def __call__(self, *args, **kwargs):
    """
    Convenience function that parses, compiles, and runs a DaCe program.

    Positional arguments are forwarded to ``compile``. They are then passed
    to the compiled program by name (following ``self.argnames``), together
    with the values of symbols that appear in the argument shapes.

    :param args: Positional program arguments.
    :param kwargs: Additional keyword arguments for the compiled program.
    :return: Whatever the compiled program call returns.
    """
    binaryobj = self.compile(*args)

    # Add named arguments to the call
    kwargs.update({aname: arg for aname, arg in zip(self.argnames, args)})

    # Update arguments with symbols in data shapes
    kwargs.update({
        sym: symbolic.symbol(sym).get()
        for arg in args
        for sym in (symbolic.symlist(arg.descriptor.shape)
                    if hasattr(arg, 'descriptor') else [])
    })

    # Update arguments with symbol values
    for aname in self.argnames:
        if aname not in binaryobj.sdfg.arrays:
            continue
        for sym in binaryobj.sdfg.arrays[aname].shape:
            if not symbolic.issymbolic(sym):
                continue
            try:
                kwargs[str(sym)] = sym.get()
            except (AttributeError, UnboundLocalError):
                # Best effort: skip symbols without a value (UnboundLocalError
                # from ``get``) and compound shape expressions, which
                # presumably have no ``get`` method. Any other error now
                # propagates instead of being silently swallowed.
                pass
    return binaryobj(**kwargs)
def replace(subgraph: 'dace.sdfg.state.StateGraphView', name: str,
            new_name: str):
    """
    Finds and replaces all occurrences of a symbol or array in the given
    subgraph.

    Does nothing if ``name`` and ``new_name`` have the same string form.
    Subsets and volumes are rewritten only when they actually contain the
    symbol.

    :param subgraph: The given graph or subgraph to replace in.
    :param name: Name to find.
    :param new_name: Name to replace.
    """
    if str(name) == str(new_name):
        return
    symname = symbolic.symbol(name)
    symrepl = {
        symname: (symbolic.pystr_to_symbolic(new_name)
                  if isinstance(new_name, str) else new_name)
    }

    # Replace in node properties. Call with the
    # ``(node, name, new_name)`` signature used by ``replace_properties``;
    # the former 4-argument call raised a TypeError.
    for node in subgraph.nodes():
        replace_properties(node, name, new_name)

    # Replace in memlets
    for edge in subgraph.edges():
        memlet = edge.data
        if memlet.data == name:
            memlet.data = new_name
        if (memlet.subset is not None
                and name in memlet.subset.free_symbols):
            memlet.subset = _replsym(memlet.subset, symrepl)
        if (memlet.other_subset is not None
                and name in memlet.other_subset.free_symbols):
            memlet.other_subset = _replsym(memlet.other_subset, symrepl)
        # The volume may be a plain number without ``free_symbols``
        if symname in getattr(memlet.volume, 'free_symbols', ()):
            memlet.volume = _replsym(memlet.volume, symrepl)
def replace_properties(node: Any, name: str, new_name: str):
    """
    Replaces a symbol or data name in all properties of ``node``, in place.

    The following property kinds are handled:

    - Symbolic properties: the symbol is substituted.
    - Data properties: matching names are renamed.
    - Range and shape properties: rewritten via ``_replsym``.
    - Code properties: C++ string code gets a shadowing declaration
      prepended; string code in other languages is left unchanged with a
      warning; AST code is rewritten with ``ASTFindReplace``.
    - The ``symbol_mapping`` dictionary of nested SDFGs: values are
      substituted.

    :param node: The element whose properties are updated.
    :param name: Name to find.
    :param new_name: Name (or symbolic expression) to replace it with.
    """
    if str(name) == str(new_name):
        return
    symrepl = {
        symbolic.symbol(name):
        symbolic.symbol(new_name) if isinstance(new_name, str) else new_name
    }
    # Match ``name`` as a whole identifier, including at the very start or
    # end of the code (``[^\w]name[^\w]`` missed those positions).
    name_pattern = re.compile(r'(?<!\w)%s(?!\w)' % re.escape(str(name)))

    for propclass, propval in node.properties():
        if propval is None:
            continue
        pname = propclass.attr_name
        if isinstance(propclass, properties.SymbolicProperty):
            setattr(node, pname, propval.subs(symrepl))
        elif isinstance(propclass, properties.DataProperty):
            if propval == name:
                setattr(node, pname, new_name)
        elif isinstance(propclass,
                        (properties.RangeProperty, properties.ShapeProperty)):
            setattr(node, pname, _replsym(list(propval), symrepl))
        elif isinstance(propclass, properties.CodeProperty):
            if isinstance(propval.code, str):
                if not name_pattern.search(propval.code):
                    continue
                lang = propval.language
                if lang is dtypes.Language.CPP:
                    # Replace in C++ code: shadow the old name with a local
                    # variable initialized to the new one
                    replacement = 'auto %s = %s;\n' % (name, new_name)
                    propval.code = replacement + propval.code
                else:
                    warnings.warn('Replacement of %s with %s was not made '
                                  'for string tasklet code of language %s' %
                                  (name, new_name, lang))
            elif propval.code is not None:
                for stmt in propval.code:
                    ASTFindReplace({name: new_name}).visit(stmt)
        elif (isinstance(propclass, properties.DictProperty)
              and pname == 'symbol_mapping'):
            # Symbol mappings for nested SDFGs (only existing keys are
            # reassigned, so mutating while iterating is safe)
            for symname, sym_mapping in propval.items():
                propval[symname] = sym_mapping.subs(symrepl)
def create_batch_gemm_sdfg(dtype, strides):
    """
    Builds an SDFG named ``einsum`` that multiplies ``X`` by ``Y`` into
    ``Z`` using a single BLAS ``MatMul`` library node.

    :param dtype: Element type of the ``X``, ``Y`` and ``Z`` arrays.
    :param strides: Mapping with the entries ``'BATCH'``, ``'sAM'``,
                    ``'sAK'``, ``'sAB'``, ``'sBK'``, ``'sBN'``, ``'sBB'``,
                    ``'sCM'``, ``'sCN'`` and ``'sCB'``.

                    - Symbolic entries are replaced by a symbol of the same
                      name; other values are used as given.
                    - When ``strides['BATCH'] != 1``, every array gets a
                      leading ``BATCH`` dimension.
    :return: The constructed SDFG.
    """
    sdfg = SDFG('einsum')
    state = sdfg.add_state()
    M, K, N = symbolic.symbol('M'), symbolic.symbol('K'), symbolic.symbol('N')

    resolved = {}
    for key in ('BATCH', 'sAM', 'sAK', 'sAB', 'sBK', 'sBN', 'sBB', 'sCM',
                'sCN', 'sCB'):
        value = strides[key]
        resolved[key] = symbolic.symbol(key) if symbolic.issymbolic(
            value) else value

    BATCH = resolved['BATCH']
    batched = strides['BATCH'] != 1

    def _add_operand(array_name, dims, batch_key, row_key, col_key):
        # Registers one operand, with or without the leading batch dimension
        if batched:
            shape = [BATCH, *dims]
            layout = [resolved[batch_key], resolved[row_key], resolved[col_key]]
        else:
            shape = list(dims)
            layout = [resolved[row_key], resolved[col_key]]
        _, desc = sdfg.add_array(array_name,
                                 dtype=dtype,
                                 shape=shape,
                                 strides=layout)
        return desc

    xarr = _add_operand('X', (M, K), 'sAB', 'sAM', 'sAK')
    yarr = _add_operand('Y', (K, N), 'sBB', 'sBK', 'sBN')
    zarr = _add_operand('Z', (M, N), 'sCB', 'sCM', 'sCN')

    read_x = state.add_read('X')
    read_y = state.add_read('Y')
    write_z = state.add_write('Z')

    import dace.libraries.blas as blas  # Avoid import loop

    libnode = blas.MatMul('einsum_gemm')
    state.add_node(libnode)
    state.add_edge(read_x, None, libnode, '_a',
                   Memlet.from_array(read_x.data, xarr))
    state.add_edge(read_y, None, libnode, '_b',
                   Memlet.from_array(read_y.data, yarr))
    state.add_edge(libnode, '_c', write_z, None,
                   Memlet.from_array(write_z.data, zarr))

    return sdfg
def _loop_range(
        itervar: str, inedges: List[gr.Edge],
        condition: sp.Expr) -> Optional[Tuple[sp.Expr, sp.Expr, sp.Expr]]:
    """
    Finds loop range from state machine.

    :param itervar: String representing the iteration variable.
    :param inedges: Incoming edges into guard state (length must be 2).
    :param condition: Condition as sympy expression.
    :return: A three-tuple of (start, end, stride) expressions, or None if
             proper for-loop was not detected. ``end`` is inclusive.
    """
    itersym = symbolic.symbol(itervar)

    # Parse both assignments once. Exactly one of them must depend on the
    # iteration variable (the increment); the other is the initialization.
    first = symbolic.pystr_to_symbolic(inedges[0].data.assignments[itervar])
    second = symbolic.pystr_to_symbolic(inedges[1].data.assignments[itervar])
    first_steps = itersym in first.free_symbols
    second_steps = itersym in second.free_symbols
    if first_steps and not second_steps:
        stride, start = first - itersym, second
    elif second_steps and not first_steps:
        stride, start = second - itersym, first
    else:
        return None

    # Find the inclusive end by matching the condition against the supported
    # comparison forms, tried in this order: i < a, i <= a, i > a, i >= a.
    a = sp.Wild('a')
    for pattern, adjustment in ((itersym < a, -1), (itersym <= a, 0),
                                (itersym > a, 1), (itersym >= a, 0)):
        match = condition.match(pattern)
        if match:
            end = match[a] + adjustment if adjustment else match[a]
            return start, end, stride

    # No match found
    return None
def _do_memlets_correspond(
        memlet_a: mm.Memlet, memlet_b: mm.Memlet,
        mapranges_a: List[Tuple[str, subsets.Range]],
        mapranges_b: List[Tuple[str, subsets.Range]]) -> bool:
    """
    Returns True if the two memlets correspond to each other, disregarding
    symbols from equivalent maps.

    Both memlets are expected to access a single element. Parameters of the
    first map are renamed positionally to those of the second map, and the
    begin index of every dimension is compared.

    :param memlet_a: Memlet inside the first map.
    :param memlet_b: Memlet inside the second map.
    :param mapranges_a: ``(parameter, range)`` pairs of the first map.
    :param mapranges_b: ``(parameter, range)`` pairs of the second map.
    :return: True if every compared dimension matches.
    """
    # Parameter renaming does not depend on the dimension: build it once.
    param_repl = {
        symbolic.symbol(k1): symbolic.symbol(k2)
        for (k1, _), (k2, _) in zip(mapranges_a, mapranges_b)
    }
    for s1, s2 in zip(memlet_a.subset, memlet_b.subset):
        # Since there is one element in both subsets, we can check only the
        # beginning. ``simultaneous=True`` prevents swapped names (i->j,
        # j->i) from being substituted twice by sympy's sequential subs.
        if s1[0].subs(param_repl, simultaneous=True) != s2[0]:
            return False
    return True
def replace(self, repl_dict):
    """
    Substitute a given set of symbols with a different set of symbols.

    Substitution goes through unique intermediate symbols
    (``__dacesym_<name>``) in two passes. Overlapping mappings, such as
    swapping two symbols, therefore behave as if applied at once.

    The following are updated:

    - ``volume``, when it is symbolic;
    - ``subset``, when present;
    - ``other_subset``, when present.

    :param repl_dict: A dict of string symbol names to symbols with which to
                      replace them.
    """
    to_intermediate = {}
    to_final = {}
    for old, new in repl_dict.items():
        if str(old) == str(new):
            continue
        placeholder = symbolic.symbol('__dacesym_' + str(old))
        to_intermediate[symbolic.symbol(old)] = placeholder
        to_final[placeholder] = new

    if not to_intermediate:
        return

    if self.volume is not None and symbolic.issymbolic(self.volume):
        self.volume = self.volume.subs(to_intermediate)
        self.volume = self.volume.subs(to_final)

    for attr in ('subset', 'other_subset'):
        target = getattr(self, attr)
        if target is None:
            continue
        target.replace(to_intermediate)
        target.replace(to_final)
def vectorize_connector(sdfg: dace.SDFG, dfg: dace.SDFGState,
                        node: dace.nodes.Node, par: str, conn: str,
                        is_input: bool):
    """
    Converts the type of connector ``conn`` on ``node`` into an SVE vector
    type, based on how the loop parameter ``par`` is used in its memlets.

    Behavior per memlet on the connector:

    - Stream data becomes a vector of unknown length (-1).
    - Connectors that are already vectors or pointers are left alone.
    - If ``par`` appears in a single subset dimension, that dimension is
      widened to ``util.SVE_LEN`` elements and the connector becomes a
      vector of that length.
    - Connectors without ``par`` and without WCR stay scalar.

    :raises util.NotSupportedError: If ``par`` occurs in more than one
                                    subset dimension.
    """
    edges = get_connector_edges(dfg, node, conn, is_input)
    connectors = node.in_connectors if is_input else node.out_connectors

    for edge in edges:
        memlet = edge.data
        if memlet.data is None:
            # Empty memlets
            return

        desc = sdfg.arrays[memlet.data]
        if isinstance(desc, data.Stream):
            # Streams are treated differently in SVE: instead of pointers
            # they become vectors of unknown size
            connectors[conn] = dace.dtypes.vector(connectors[conn].base_type,
                                                  -1)
            return

        if isinstance(connectors[conn],
                      (dace.dtypes.vector, dace.dtypes.pointer)):
            # No need for vectorization
            return

        subset = memlet.subset
        loop_sym = symbolic.symbol(par)
        sve_dim = None
        for dim, rng in enumerate(subset):
            for expr in rng:
                if loop_sym not in symbolic.pystr_to_symbolic(
                        expr).free_symbols:
                    continue
                if sve_dim is not None and sve_dim != dim:
                    raise util.NotSupportedError(
                        'Non-vectorizable memlet (loop param occurs in more than one dimension)'
                    )
                sve_dim = dim

        if sve_dim is None:
            if memlet.wcr is None:
                # Should stay scalar
                return
        else:
            dim_range = subset[sve_dim]
            memlet.subset[sve_dim] = (dim_range[0],
                                      dim_range[0] + util.SVE_LEN,
                                      dim_range[2])

        connectors[conn] = dace.dtypes.vector(
            connectors[conn].type or desc.dtype, util.SVE_LEN)
def can_be_applied(graph: dace.SDFGState,
                   candidate: Dict[Any, int],
                   expr_index: int,
                   sdfg: dace.SDFG,
                   strict=False):
    """
    Checks whether a one-dimensional map around a single stencil node can be
    nested.

    All of the following must hold:

    - The map has exactly one parameter and no dynamic inputs.
    - The map feeds only the stencil.
    - Every 3-dimensional memlet of the stencil uses the parameter in the
      same dimension.
    - The stencil's shape in that dimension is 1.
    """
    map_entry: nodes.MapEntry = graph.node(candidate[NestK._map_entry])
    stencil: Stencil = graph.node(candidate[NestK._stencil])

    if len(map_entry.map.params) != 1:
        return False
    if sd.has_dynamic_map_inputs(graph, map_entry):
        return False
    pname = map_entry.map.params[0]  # Usually "k"

    # The map may only lead into the stencil
    if any(edge.dst != stencil for edge in graph.out_edges(map_entry)):
        return False

    dim_index = None
    for edge in graph.all_edges(stencil):
        if edge.data.data is None:  # Empty memlet
            continue
        # TODO: Use bitmap to verify lower-dimensional arrays
        if len(edge.data.subset) != 3:
            continue
        for dim, rng in enumerate(edge.data.subset.ndrange()):
            for r in rng:
                if pname not in map(str, r.free_symbols):
                    continue
                if dim_index is not None and dim_index != dim:
                    # k dimension must match in all memlets
                    return False
                if str(r) != pname and symbolic.issymbolic(
                        r - symbolic.symbol(pname), sdfg.constants):
                    warnings.warn('k expression is nontrivial')
                dim_index = dim

    # No nesting dimension found
    if dim_index is None:
        return False

    # Ensure the stencil shape is 1 for the found dimension
    if stencil.shape[dim_index] != 1:
        return False
    return True
def test_memlet_volume_propagation_nsdfg():
    """Checks the memlets and loop-state executions after propagation."""
    sdfg = make_sdfg()
    propagation.propagate_memlets_sdfg(sdfg)

    main_state = sdfg.nodes()[0]
    edges = main_state.edges()

    # (memlet, expected parameters as passed to memlet_check_parameters)
    expectations = [
        (edges[0].data, 0, True, [(0, N - 1, 1)]),  # data in
        (edges[1].data, 1, False, [(0, 0, 1)]),  # bound stream in
        (edges[2].data, 0, True, [(0, 0, 1)]),  # out stream
    ]
    for memlet, volume, dynamic, ranges in expectations:
        memlet_check_parameters(memlet, volume, dynamic, ranges)

    nested_sdfg = main_state.nodes()[3].sdfg
    loop_state = nested_sdfg.nodes()[2]
    state_check_executions(loop_state, symbol('loop_bound'))
def get_stride(self,
               sdfg: 'dace.sdfg.SDFG',
               map: 'dace.sdfg.nodes.Map',
               dim: int = -1) -> 'dace.symbolic.SymExpr':
    """
    Returns the stride of the underlying memory when traversing a Map.

    The subset's first element is flattened into a 1D offset using the array
    strides. The stride is the simplified difference between that offset
    after one step of the map parameter and the offset before it.

    :param sdfg: The SDFG in which the memlet resides.
    :param map: The map in which the memlet resides.
    :param dim: The dimension that is incremented. By default it is the
                innermost.
    :return: The stride expression, or 0 for an empty memlet.
    """
    if self.data is None:
        return symbolic.pystr_to_symbolic('0')

    param = symbolic.symbol(map.params[dim])
    array = sdfg.arrays[self.data]

    offset_now = self.subset.at([0] * len(array.strides), array.strides)
    step = map.range[dim][2]
    offset_next = offset_now.subs(param, param + step)

    return (offset_next - offset_now).simplify()
def infer_symbols_from_datadescriptor(sdfg: SDFG,
                                      args: Dict[str, Any],
                                      exclude: Optional[Set[str]] = None) -> \
        Dict[str, Any]:
    """
    Infers the values of SDFG symbols (not given as arguments) from the
    shapes and strides of input arguments (e.g., arrays).

    Each descriptor dimension/stride is equated with the corresponding
    value of the argument. Stride values are divided by the argument's
    ``itemsize`` when present. The resulting system is solved for all
    unknown symbols at once.

    :param sdfg: The SDFG that is being called.
    :param args: A dictionary mapping from current argument names to their
                 values. This may also include symbols.
    :param exclude: An optional set of symbols to ignore on inference.
    :return: A dictionary mapping from symbol names that are not in ``args``
             to their inferred values.
    :raise ValueError: If symbol values are ambiguous.
    """
    prefix = '__SOLVE_'
    excluded = set(symbolic.symbol(s) for s in (exclude or set()))
    equations = []
    unknowns = set()

    for arg_name, arg_val in args.items():
        if arg_name not in sdfg.arrays:
            continue
        desc = sdfg.arrays[arg_name]
        if not (hasattr(desc, 'shape') and hasattr(arg_val, 'shape')):
            continue

        expected = list(desc.shape) + list(getattr(desc, 'strides', []))
        actual = list(arg_val.shape)
        if hasattr(arg_val, 'strides'):
            # NumPy arrays use bytes in strides
            itemsize = getattr(arg_val, 'itemsize', 1)
            actual.extend(stride // itemsize for stride in arg_val.strides)

        for sym_dim, real_dim in zip(expected, actual):
            renaming = {}
            for sym in symbolic.symlist(sym_dim).values():
                solve_sym = symbolic.symbol(prefix + str(sym))
                if str(sym) in args:
                    excluded.add(solve_sym)
                else:
                    unknowns.add(solve_sym)
                    excluded.add(sym)
                renaming[sym] = solve_sym
            # Rename to __SOLVE_ symbols so the same symbol name may also be
            # used inside the called SDFG
            if renaming:
                sym_dim = sym_dim.subs(renaming)
            equations.append(sym_dim - real_dim)

    if not unknowns:
        return {}

    # Solve for all at once
    results = sympy.solve(equations, *unknowns, dict=True, exclude=excluded)
    if len(results) > 1:
        raise ValueError('Ambiguous values for symbols in inference. '
                         'Options: %s' % str(results))
    if not results or not results[0]:
        raise ValueError('Cannot infer values for symbols in inference.')

    # Strip the __SOLVE_ prefix again
    return {str(k)[len(prefix):]: v for k, v in results[0].items()}
def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
    """
    Checks whether two maps connected through access nodes
    (map -> access -> map) can be fused.

    Returns False if any of the following holds:

    - The first map writes with WCR to data that the second map's entry
      also has on an outgoing memlet.
    - The first map exit leads to anything but access nodes.
    - An intermediate array appears more than once in the state.
    - No parameter permutation between the maps is found.
    - A second-map input that is not intermediate is reachable from an
      intermediate node.
    - An intermediate input of the second map is not provided by an
      identical (parameter-renamed) subset written by the first map.
    """
    first_map_exit = graph.nodes()[candidate[MapFusion._first_map_exit]]
    first_map_entry = graph.entry_node(first_map_exit)
    second_map_entry = graph.nodes()[candidate[
        MapFusion._second_map_entry]]

    for _in_e in graph.in_edges(first_map_exit):
        if _in_e.data.wcr is not None:
            for _out_e in graph.out_edges(second_map_entry):
                if _out_e.data.data == _in_e.data.data:
                    # wcr is on a node that is used in the second map, quit
                    return False
    # Check whether there is a pattern map -> access -> map.
    intermediate_nodes = set()
    intermediate_data = set()
    for _, _, dst, _, _ in graph.out_edges(first_map_exit):
        if isinstance(dst, nodes.AccessNode):
            intermediate_nodes.add(dst)
            intermediate_data.add(dst.data)

            # If array is used anywhere else in this state.
            num_occurrences = len([
                n for n in graph.nodes()
                if isinstance(n, nodes.AccessNode) and n.data == dst.data
            ])
            if num_occurrences > 1:
                return False
        else:
            # Anything other than an access node after the first map
            return False
    # Check map ranges
    perm = MapFusion.find_permutation(first_map_entry.map,
                                      second_map_entry.map)
    if perm is None:
        return False

    # Create a dict that maps parameters of the first map to those of the
    # second map.
    params_dict = {}
    for _index, _param in enumerate(first_map_entry.map.params):
        params_dict[_param] = second_map_entry.map.params[perm[_index]]

    out_memlets = [e.data for e in graph.in_edges(first_map_exit)]

    # Check that input set of second map is provided by the output set
    # of the first map, or other unrelated maps
    for _, _, _, _, second_memlet in graph.out_edges(second_map_entry):
        # Memlets that do not come from one of the intermediate arrays
        if second_memlet.data not in intermediate_data:
            # however, if intermediate_data eventually leads to
            # second_memlet.data, need to fail.
            for _n in intermediate_nodes:
                source_node = _n  # graph.find_node(_n.data)
                # NOTE(review): loop-invariant lookup; also unclear how
                # ``find_node`` behaves for an empty memlet (data None).
                destination_node = graph.find_node(second_memlet.data)
                # NOTE: Assumes graph has networkx version
                if destination_node in nx.descendants(
                        graph._nx, source_node):
                    return False
            continue

        provided = False
        for first_memlet in out_memlets:
            if first_memlet.data != second_memlet.data:
                continue
            # If there is an equivalent subset, it is provided
            # Rename first-map parameters to the second map's names, one
            # range tuple at a time, then compare with the second subset.
            expected_second_subset = []
            for _tup in first_memlet.subset:
                new_tuple = []
                if isinstance(_tup, symbolic.symbol):
                    # NOTE(review): unlike the tuple branch, this does not
                    # check ``str(_tup) in params_dict`` and can KeyError.
                    new_tuple = symbolic.symbol(params_dict[str(_tup)])
                elif isinstance(_tup, (list, tuple)):
                    for _sym in _tup:
                        if (isinstance(_sym, symbolic.symbol)
                                and str(_sym) in params_dict):
                            new_tuple.append(
                                symbolic.symbol(params_dict[str(_sym)]))
                        else:
                            new_tuple.append(_sym)
                    new_tuple = tuple(new_tuple)
                else:
                    new_tuple = _tup
                expected_second_subset.append(new_tuple)
            if expected_second_subset == list(second_memlet.subset):
                provided = True
                break

        # If none of the output memlets of the first map provide the info,
        # fail.
        if provided is False:
            return False

    # Success
    return True
def _construct_args(self, *args, **kwargs):
    """ Main function that controls argument construction for calling
        the C prototype of the SDFG.

        Organizes arguments first by `sdfg.arglist`, then data descriptors
        by alphabetical order, then symbols by alphabetical order.

        Arguments must be given either all positionally or all by keyword.
        Arrays are passed as ``void*`` pointers to their data, and scalars
        are wrapped with ``types._FFI_CTYPES`` when a conversion exists.
        The result is stored in ``self._lastargs`` and returned.

        :raise AttributeError: If both positional and keyword arguments are
                               given.
        :raise KeyError: If a keyword argument of the signature is missing.
        :raise TypeError: If an array/scalar argument kind mismatches.
        :raise UnboundLocalError: If a free symbol has no value.
    """
    if len(kwargs) > 0 and len(args) > 0:
        raise AttributeError(
            'Compiled SDFGs can only be called with either arguments ' +
            '(e.g. "program(a,b,c)") or keyword arguments ' +
            '("program(A=a,B=b)"), but not both')

    # Argument construction
    sig = self._sdfg.signature_arglist(with_types=False)
    typedict = self._sdfg.arglist()
    if len(kwargs) > 0:
        # Construct mapping from arguments to signature
        arglist = []
        argtypes = []
        argnames = []
        for a in sig:
            try:
                arglist.append(kwargs[a])
                argtypes.append(typedict[a])
                argnames.append(a)
            except KeyError:
                raise KeyError("Missing program argument \"{}\"".format(a))
    elif len(args) > 0:
        arglist = list(args)
        argtypes = [typedict[s] for s in sig]
        argnames = sig
        # Clearing ``sig`` means no symbol is filtered out below
        sig = []
    else:
        arglist = []
        argtypes = []
        argnames = []
        sig = []

    # Type checking
    for a, arg, atype in zip(argnames, arglist, argtypes):
        if not isinstance(arg, np.ndarray) and isinstance(atype, dt.Array):
            raise TypeError(
                'Passing an object (type %s) to an array in argument "%s"' %
                (type(arg).__name__, a))
        if isinstance(arg, np.ndarray) and not isinstance(atype, dt.Array):
            raise TypeError(
                'Passing an array to a scalar (type %s) in argument "%s"' %
                (atype.dtype.ctype, a))
        if not isinstance(atype, dt.Array) and not isinstance(
                arg, atype.dtype.type):
            print('WARNING: Casting scalar argument "%s" from %s to %s' %
                  (a, type(arg).__name__, atype.dtype.type))

    # Retain only the element datatype for upcoming checks and casts
    argtypes = [t.dtype.type for t in argtypes]

    sdfg = self._sdfg

    # As in compilation, add symbols used in array sizes to parameters
    symparams = {}
    symtypes = {}
    for symname in sdfg.undefined_symbols(False):
        # Ignore arguments (as they may not be symbols but constants,
        # see below)
        if symname in sdfg.arg_types:
            continue
        try:
            symval = symbolic.symbol(symname)
            symparams[symname] = symval.get()
            symtypes[symname] = symval.dtype.type
        except UnboundLocalError:
            try:
                symparams[symname] = kwargs[symname]
            except KeyError:
                raise UnboundLocalError('Unassigned symbol %s' % symname)
            # NOTE(review): this path fills ``symparams`` but not
            # ``symtypes``, so the two sorted lists below can go out of
            # alignment (the later zip pairs values with the wrong types).

    arglist.extend(
        [symparams[k] for k in sorted(symparams.keys()) if k not in sig])
    argtypes.extend(
        [symtypes[k] for k in sorted(symtypes.keys()) if k not in sig])

    # Obtain SDFG constants
    constants = sdfg.constants

    # Remove symbolic constants from arguments
    callparams = tuple(
        (arg, atype) for arg, atype in zip(arglist, argtypes)
        if not symbolic.issymbolic(arg) or (
            hasattr(arg, 'name') and arg.name not in constants))

    # Replace symbols with their values
    callparams = tuple(
        (atype(symbolic.eval(arg)), atype)
        if symbolic.issymbolic(arg, constants) else (arg, atype)
        for arg, atype in callparams)

    # Replace arrays with their pointers
    newargs = tuple(
        (ctypes.c_void_p(arg.__array_interface__['data'][0]), atype)
        if (isinstance(arg, ndarray.ndarray)
            or isinstance(arg, np.ndarray)) else (arg, atype)
        for arg, atype in callparams)

    # Wrap remaining non-pointer values in their ctypes type, when known
    newargs = tuple(types._FFI_CTYPES[atype](arg) if (
        atype in types._FFI_CTYPES
        and not isinstance(arg, ctypes.c_void_p)) else arg
                    for arg, atype in newargs)

    self._lastargs = newargs
    return self._lastargs
def find_for_loop(
    sdfg: sd.SDFG,
    guard: sd.SDFGState,
    entry: sd.SDFGState,
    itervar: Optional[str] = None
) -> Optional[Tuple[AnyStr, Tuple[symbolic.SymbolicType, symbolic.SymbolicType,
                                  symbolic.SymbolicType], Tuple[
                                      List[sd.SDFGState], sd.SDFGState]]]:
    """
    Finds loop range from state machine.

    :param sdfg: The SDFG containing the loop states.
    :param guard: State from which the outgoing edges detect whether to exit
                  the loop or not.
    :param entry: First state in the loop "body".
    :param itervar: Name of the iteration variable. If None, the first
                    variable assigned on the first in-edge of ``guard`` is
                    used.
    :return: (iteration variable, (start, end, stride),
              (start_states[], last_loop_state)), or None if proper
             for-loop was not detected. ``end`` is inclusive.
    """
    # Extract state transition edge information
    guard_inedges = sdfg.in_edges(guard)
    condition_edges = sdfg.edges_between(guard, entry)
    if not guard_inedges or not condition_edges:
        # Missing edges used to raise IndexError; this is simply not a loop
        return None
    condition_edge = condition_edges[0]
    if itervar is None:
        assigned = list(guard_inedges[0].data.assignments.keys())
        if not assigned:
            return None
        itervar = assigned[0]
    condition = condition_edge.data.condition_sympy()

    # Find the stride edge. All in-edges to the guard except for the stride
    # edge should have exactly the same assignment, since a valid for loop
    # can only have one assignment.
    init_edges = []
    init_assignment = None
    step_edge = None
    itersym = symbolic.symbol(itervar)
    for iedge in guard_inedges:
        if itervar not in iedge.data.assignments:
            # An edge into the guard that does not assign the iteration
            # variable cannot belong to a valid for loop (was a KeyError).
            return None
        assignment = iedge.data.assignments[itervar]
        if itersym in symbolic.pystr_to_symbolic(assignment).free_symbols:
            if step_edge is not None:
                # More than one edge with the iteration variable as a free
                # symbol, which is not legal. Invalid for loop.
                return None
            step_edge = iedge
        else:
            if init_assignment is None:
                init_assignment = assignment
            elif init_assignment != assignment:
                # More than one init assignment variations mean that this
                # for loop is not valid.
                return None
            init_edges.append(iedge)
    if step_edge is None or len(init_edges) == 0 or init_assignment is None:
        # Less than two assignment variations, can't be a valid for loop.
        return None

    # Get the init expression and the stride.
    start = symbolic.pystr_to_symbolic(init_assignment)
    stride = (symbolic.pystr_to_symbolic(step_edge.data.assignments[itervar])
              - itersym)

    # Get a list of the last states before the loop and a reference to the
    # last loop state.
    start_states = []
    for init_edge in init_edges:
        if init_edge.src not in start_states:
            start_states.append(init_edge.src)
    last_loop_state = step_edge.src

    # Find the inclusive end by matching the condition against the supported
    # comparison forms, tried in this order: i < a, i <= a, i > a, i >= a.
    a = sp.Wild('a')
    for pattern, adjustment in ((itersym < a, -1), (itersym <= a, 0),
                                (itersym > a, 1), (itersym >= a, 0)):
        match = condition.match(pattern)
        if match:
            end = match[a] + adjustment if adjustment else match[a]
            return itervar, (start, end, stride), (start_states,
                                                   last_loop_state)

    # No match found
    return None
def expansion(node, parent_state, parent_sdfg):
    """
    Expands a redistribution library node into a C++ tasklet that uses MPI.

    The generated code does the following:

    - Ranks for which ``array_a``'s process grid is valid post ``MPI_Isend``
      calls of the input buffer.
    - Ranks for which ``array_b``'s process grid is valid perform the local
      (self) copies with ``dace::CopyNDDynamic``, then ``MPI_Recv`` into the
      output buffer.
    - Pending sends are completed with ``MPI_Waitall``.

    :param node: The redistribution library node.
    :param parent_state: The state containing the node.
    :param parent_sdfg: The SDFG containing the state.
    :return: The C++ tasklet implementing the redistribution.
    """
    inp_buffer, out_buffer = node.validate(parent_sdfg, parent_state)
    redistr = parent_sdfg.rdistrarrays[node.redistr]
    array_a = parent_sdfg.subarrays[redistr.array_a]
    array_b = parent_sdfg.subarrays[redistr.array_b]

    inp_symbols = [
        symbolic.symbol(f"__inp_s{i}") for i in range(len(inp_buffer.shape))
    ]
    out_symbols = [
        symbolic.symbol(f"__out_s{i}") for i in range(len(out_buffer.shape))
    ]
    inp_subset = subsets.Indices(inp_symbols)
    out_subset = subsets.Indices(out_symbols)
    inp_offset = cpp.cpp_offset_expr(inp_buffer, inp_subset)
    out_offset = cpp.cpp_offset_expr(out_buffer, out_subset)
    # (Leftover debug prints of the two offsets were removed.)

    # Per-copy start indices of the self copies, read from the state arrays
    inp_repl = "".join(
        f"int {s} = __state->{node.redistr}_self_src[__idx * {len(inp_buffer.shape)} + {i}];\n"
        for i, s in enumerate(inp_symbols))
    out_repl = "".join(
        f"int {s} = __state->{node.redistr}_self_dst[__idx * {len(out_buffer.shape)} + {i}];\n"
        for i, s in enumerate(out_symbols))

    copy_args = ", ".join([
        f"__state->{node.redistr}_self_size[__idx * {len(inp_buffer.shape)} + {i}], {istride}, {ostride}"
        for i, (istride, ostride
                ) in enumerate(zip(inp_buffer.strides, out_buffer.strides))
    ])

    # ``req``/``status`` are allocated on every rank, so they are freed on
    # every rank as well (previously only ranks with a valid array_a grid
    # freed them, leaking on all others).
    code = f"""
        int myrank;
        MPI_Comm_rank(MPI_COMM_WORLD, &myrank);

        MPI_Request* req = new MPI_Request[__state->{node._redistr}_sends];
        MPI_Status* status = new MPI_Status[__state->{node._redistr}_sends];
        MPI_Status recv_status;

        if (__state->{array_a.pgrid}_valid) {{
            for (auto __idx = 0; __idx < __state->{node._redistr}_sends; ++__idx) {{
                // printf("({redistr.array_a} -> {redistr.array_b}) I am rank %d and I send to %d\\n", myrank, __state->{node._redistr}_dst_ranks[__idx]);
                // fflush(stdout);
                MPI_Isend(_inp_buffer, 1, __state->{node._redistr}_send_types[__idx], __state->{node._redistr}_dst_ranks[__idx], 0, MPI_COMM_WORLD, &req[__idx]);
            }}
        }}

        if (__state->{array_b.pgrid}_valid) {{
            for (auto __idx = 0; __idx < __state->{node._redistr}_self_copies; ++__idx) {{
                // printf("({redistr.array_a} -> {redistr.array_b}) I am rank %d and I self-copy\\n", myrank);
                // fflush(stdout);
                {inp_repl}
                {out_repl}
                dace::CopyNDDynamic<{inp_buffer.dtype.ctype}, 1, false, {len(inp_buffer.shape)}>::Dynamic::Copy(
                    _inp_buffer + {inp_offset}, _out_buffer + {out_offset}, {copy_args}
                );
            }}
            for (auto __idx = 0; __idx < __state->{node._redistr}_recvs; ++__idx) {{
                // printf("({redistr.array_a} -> {redistr.array_b}) I am rank %d and I receive from %d\\n", myrank, __state->{node._redistr}_src_ranks[__idx]);
                // fflush(stdout);
                MPI_Recv(_out_buffer, 1, __state->{node._redistr}_recv_types[__idx], __state->{node._redistr}_src_ranks[__idx], 0, MPI_COMM_WORLD, &recv_status);
            }}
        }}

        if (__state->{array_a.pgrid}_valid) {{
            MPI_Waitall(__state->{node._redistr}_sends, req, status);
        }}
        delete[] req;
        delete[] status;
        // printf("I am rank %d and I finished the redistribution {redistr.array_a} -> {redistr.array_b}\\n", myrank);
        // fflush(stdout);
    """

    tasklet = nodes.Tasklet(node.name,
                            node.in_connectors,
                            node.out_connectors,
                            code,
                            language=dtypes.Language.CPP)
    return tasklet
def infer_symbols_from_shapes(sdfg: SDFG,
                              args: Dict[str, Any],
                              exclude: Optional[Set[str]] = None) -> \
        Dict[str, Any]:
    """
    Infers the values of SDFG symbols (not given as arguments) from the
    shapes of input arguments (e.g., arrays).

    Each descriptor dimension is equated with the corresponding dimension
    of the argument's ``shape``. The resulting system is solved for all
    unknown symbols at once.

    :param sdfg: The SDFG that is being called.
    :param args: A dictionary mapping from current argument names to their
                 values. This may also include symbols.
    :param exclude: An optional set of symbols to ignore on inference.
    :return: A dictionary mapping from symbol names that are not in ``args``
             to their inferred values.
    :raise ValueError: If symbol values are ambiguous.
    """
    prefix = '__SOLVE_'
    excluded = set(symbolic.symbol(s) for s in (exclude or set()))
    equations = []
    unknowns = set()

    for arg_name, arg_val in args.items():
        if arg_name not in sdfg.arrays:
            continue
        desc = sdfg.arrays[arg_name]
        if not (hasattr(desc, 'shape') and hasattr(arg_val, 'shape')):
            continue

        for sym_dim, real_dim in zip(desc.shape, arg_val.shape):
            renaming = {}
            for sym in symbolic.symlist(sym_dim).values():
                solve_sym = symbolic.symbol(prefix + str(sym))
                if str(sym) in args:
                    excluded.add(solve_sym)
                else:
                    unknowns.add(solve_sym)
                    excluded.add(sym)
                renaming[sym] = solve_sym
            # Rename to __SOLVE_ symbols so the same symbol name may also be
            # used inside the called SDFG
            if renaming:
                sym_dim = sym_dim.subs(renaming)
            equations.append(sym_dim - real_dim)

    if not unknowns:
        return {}

    # Solve for all at once
    results = sympy.solve(equations, *unknowns, dict=True, exclude=excluded)
    if len(results) > 1:
        raise ValueError('Ambiguous values for symbols in inference. '
                         'Options: %s' % str(results))
    if not results or not results[0]:
        raise ValueError('Cannot infer values for symbols in inference.')

    # Strip the __SOLVE_ prefix again
    return {str(k)[len(prefix):]: v for k, v in results[0].items()}
def can_be_applied(cls,
                   state: SDFGState,
                   candidate,
                   expr_index,
                   sdfg: SDFG,
                   strict=False) -> bool:
    """
    Checks whether the matched map can be vectorized for SVE.

    Fails if any of the following holds:

    - The map is already an SVE map.
    - Its scope contains nodes other than tasklets and access nodes.
    - A connector type is not in ``sve.util.TYPE_TO_SVE``, or connectors
      use different element sizes.
    - A memlet stride depends on the innermost map parameter.
    - A WCR is unsupported, non-atomic over the parameter, or targets a
      vector-typed destination.
    - A tasklet reads from anything other than a tasklet or a non-stream
      access node.
    - Vector inference fails.
    """
    map_entry = state.node(candidate[cls.map_entry])
    current_map = map_entry.map
    subgraph = state.scope_subgraph(map_entry)
    subgraph_contents = state.scope_subgraph(map_entry,
                                             include_entry=False,
                                             include_exit=False)

    # Prevent infinite repeats
    if current_map.schedule == dace.dtypes.ScheduleType.SVE_Map:
        return False

    # Infer all connector types for later checks (without modifying the graph)
    inferred = infer_types.infer_connector_types(sdfg, state, subgraph)

    ########################
    # Ensure only Tasklets and AccessNodes are within the map
    for node, _ in subgraph_contents.all_nodes_recursive():
        if not isinstance(node, (nodes.Tasklet, nodes.AccessNode)):
            return False

    ########################
    # Check for unsupported datatypes on the connectors (including on the
    # Map itself). NOTE: ``bit_widths`` actually collects sizes in bytes.
    bit_widths = set()
    for node, _ in subgraph.all_nodes_recursive():
        for conns, is_input in ((node.in_connectors, True),
                                (node.out_connectors, False)):
            for conn in conns:
                t = inferred[(node, conn, is_input)]
                bit_widths.add(util.get_base_type(t).bytes)
                if t.type not in sve.util.TYPE_TO_SVE:
                    return False

    # Multiple different bit widths occuring (messes up the predicates)
    if len(bit_widths) > 1:
        return False

    ########################
    # Check for unsupported memlets
    param_name = current_map.params[-1]
    param_sym = symbolic.symbol(param_name)  # loop-invariant, built once
    for e, _ in subgraph.all_edges_recursive():
        # The only unsupported strides are the ones containing the innermost
        # loop param because they are not constant during a vector step
        if param_sym in e.data.get_stride(sdfg, map_entry.map).free_symbols:
            return False

        # Check for unsupported WCR
        if e.data.wcr is not None:
            # Unsupported reduction type
            reduction_type = dace.frontend.operations.detect_reduction_type(
                e.data.wcr)
            if reduction_type not in sve.util.REDUCTION_TYPE_TO_SVE:
                return False

            # Param in memlet during WCR is not supported
            if param_name in e.data.subset.free_symbols and e.data.wcr_nonatomic:
                return False

            # vreduce is not supported. ``memlet_path`` returns edges, so
            # take the destination *node* (and its connector) of the last
            # edge; previously the edge itself was type-checked and these
            # checks could never fire.
            last_edge = state.memlet_path(e)[-1]
            dst_node = last_edge.dst
            if isinstance(dst_node, nodes.Tasklet):
                if isinstance(dst_node.in_connectors[last_edge.dst_conn],
                              dtypes.vector):
                    return False
            elif isinstance(dst_node, nodes.AccessNode):
                desc = dst_node.desc(sdfg)
                if isinstance(desc, data.Scalar) and isinstance(
                        desc.dtype, dtypes.vector):
                    return False

    ########################
    # Check for invalid copies in the subgraph
    for node, _ in subgraph.all_nodes_recursive():
        if not isinstance(node, nodes.Tasklet):
            continue

        for e in state.in_edges(node):
            # Check for valid copies from other tasklets and/or streams
            if e.data.data is None:
                continue
            src_node = state.memlet_path(e)[0].src
            if not isinstance(src_node, (nodes.Tasklet, nodes.AccessNode)):
                # Make sure we only have Code->Code copies and from arrays
                return False

            if isinstance(src_node, nodes.AccessNode):
                src_desc = src_node.desc(sdfg)
                if isinstance(src_desc, dace.data.Stream):
                    # Stream pops are not implemented
                    return False

    # Run the vector inference algorithm to check if vectorization is
    # feasible (apply=False: the graph is not modified)
    try:
        vector_inference.infer_vectors(
            sdfg,
            state,
            map_entry,
            util.SVE_LEN,
            flags=vector_inference.VectorInferenceFlags.Allow_Stride,
            apply=False)
    except vector_inference.VectorInferenceException as ex:
        print(f'UserWarning: Vector inference failed! {ex}')
        return False

    return True
def _construct_args(self, **kwargs):
    """ Main function that controls argument construction for calling
        the C prototype of the SDFG.

        Arguments are taken from ``kwargs`` in the order given by
        ``sdfg.signature_arglist``. They are followed by the values of
        undefined SDFG symbols that are not part of that signature, in
        alphabetical order.

        :param kwargs: Program arguments (arrays, scalars, callbacks and
                       symbol values) keyed by name.
        :return: A tuple of ctypes-compatible call arguments. The same
                 tuple is also stored in ``self._lastargs``.
        :raise KeyError: If a signature argument is missing from ``kwargs``.
        :raise TypeError: If an array is passed for a scalar argument or
                          vice versa.
        :raise UnboundLocalError: If an undefined symbol has neither a set
                                  value nor a matching keyword argument.
    """
    # Argument construction
    sig = self._sdfg.signature_arglist(with_types=False)
    typedict = self._sdfg.arglist()
    if kwargs:
        # Construct mapping from arguments to signature
        arglist = []
        argtypes = []
        argnames = []
        for a in sig:
            try:
                arglist.append(kwargs[a])
                argtypes.append(typedict[a])
                argnames.append(a)
            except KeyError:
                raise KeyError("Missing program argument \"{}\"".format(a))
    else:
        arglist = []
        argtypes = []
        argnames = []
        sig = []

    # Type checking
    for a, arg, atype in zip(argnames, arglist, argtypes):
        if not isinstance(arg, np.ndarray) and isinstance(atype, dt.Array):
            raise TypeError(
                'Passing an object (type %s) to an array in argument "%s"' %
                (type(arg).__name__, a))
        if isinstance(arg, np.ndarray) and not isinstance(atype, dt.Array):
            raise TypeError(
                'Passing an array to a scalar (type %s) in argument "%s"' %
                (atype.dtype.ctype, a))
        if not isinstance(atype, dt.Array) and not isinstance(
                atype.dtype, dace.callback) and not isinstance(
                    arg, atype.dtype.type):
            print('WARNING: Casting scalar argument "%s" from %s to %s' %
                  (a, type(arg).__name__, atype.dtype.type))

    # Call a wrapper function to make NumPy arrays from pointers.
    for index, (arg, argtype) in enumerate(zip(arglist, argtypes)):
        if isinstance(argtype.dtype, dace.callback):
            arglist[index] = argtype.dtype.get_trampoline(arg)

    # Retain only the element datatype for upcoming checks and casts
    # (array arguments are replaced by raw c_void_p pointers below).
    argtypes = [t.dtype.as_ctypes() for t in argtypes]

    sdfg = self._sdfg

    # As in compilation, add symbols used in array sizes to parameters.
    # Every symbol gets both a value and a ctype under the same key, so that
    # the values and types appended below stay aligned for the zip that
    # follows. Previously a symbol taken from kwargs had no ctype recorded,
    # which shifted all subsequent value/type pairs.
    symparams = {}
    symtypes = {}
    for symname in sdfg.undefined_symbols(False):
        symval = symbolic.symbol(symname)
        try:
            symparams[symname] = symval.get()
        except UnboundLocalError:
            # The symbol has no set value; fall back to a keyword argument
            try:
                symparams[symname] = kwargs[symname]
            except KeyError:
                raise UnboundLocalError('Unassigned symbol %s' % symname)
        symtypes[symname] = symval.dtype.as_ctypes()

    # Append symbols that are not already part of the signature, sorted by name
    for symname in sorted(symparams):
        if symname not in sig:
            arglist.append(symparams[symname])
            argtypes.append(symtypes[symname])

    # Obtain SDFG constants
    constants = sdfg.constants

    # Remove symbolic constants from arguments
    callparams = tuple(
        (arg, atype) for arg, atype in zip(arglist, argtypes)
        if not symbolic.issymbolic(arg) or (
            hasattr(arg, 'name') and arg.name not in constants))

    # Replace symbols with their values
    callparams = tuple(
        (atype(symbolic.eval(arg)), atype)
        if symbolic.issymbolic(arg, constants) else (arg, atype)
        for arg, atype in callparams)

    # Replace arrays with their pointers
    newargs = tuple(
        (ctypes.c_void_p(arg.__array_interface__['data'][0]), atype)
        if isinstance(arg, np.ndarray) else (arg, atype)
        for arg, atype in callparams)

    # Wrap any remaining plain Python values in their ctypes type
    newargs = tuple(
        atype(arg) if (not isinstance(arg, ctypes._SimpleCData)) else arg
        for arg, atype in newargs)

    self._lastargs = newargs
    return self._lastargs