def to_json(self):
    """Serialize this code object into a JSON-compatible dictionary.

    Python code (when present) is passed once more through the parser and
    unparser before serialization; other code is stored as-is.

    :return: A dict holding the code text under ``'string_data'`` and the
             language name under ``'language'``.
    """
    as_python = self.language == dace.dtypes.Language.Python
    if as_python and self.code is not None:
        # Two roundtrips to avoid issues in AST parsing/unparsing of negative
        # numbers, i.e., "(-1)" becomes "(- 1)".
        string_data = unparse(ast.parse(self.as_string))
    else:
        string_data = self.as_string
    return {'string_data': string_data, 'language': self.language.name}
def remove_scalar_transients(top_sdfg: dace.SDFG):
    """ Clean up tasklet->scalar-transient, replacing them with symbols.

    Scans every (nested) SDFG for transient scalars whose access nodes have a
    single writing edge in total, coming from a tasklet whose code is one
    assignment of a numeric literal. For each such scalar the literal value
    is recorded and the initializing tasklet is removed. All recorded scalars
    of an SDFG are then handed to ``_remove_transients``. Scalars that do not
    match the pattern (several writers, no writer, non-tasklet writer,
    complex or non-constant code) are reported and left untouched.

    :param top_sdfg: The top-level SDFG; nested SDFGs are processed as well.
    :note: Operates in-place and prints diagnostics to standard output.
    """
    dprint = print  # lambda *args: pass
    removed_transients = 0
    for sdfg in top_sdfg.all_sdfgs_recursive():
        transients_to_remove = {}
        for dname, desc in sdfg.arrays.items():
            if not isinstance(desc, dace.data.Scalar) or not desc.transient:
                continue

            # Find node where transient is instantiated
            skip = False
            init_tasklet: Optional[dace.nodes.Tasklet] = None
            itstate = None
            for state in sdfg.nodes():
                if skip:
                    break
                for node in state.nodes():
                    if not (isinstance(node, dace.nodes.AccessNode)
                            and node.data == dname):
                        continue
                    if state.in_degree(node) > 1:
                        dprint('Cannot remove scalar', dname,
                               '(more than one input)')
                        skip = True
                        break
                    elif state.in_degree(node) == 1:
                        if init_tasklet is not None:
                            dprint('Cannot remove scalar', dname,
                                   '(initialized multiple times)')
                            skip = True
                            break
                        init_tasklet = state.in_edges(node)[0].src
                        itstate = state

            # Already reported above; do not also report it as uninitialized
            if skip:
                continue
            if init_tasklet is None:
                dprint('Cannot remove scalar', dname, '(uninitialized)')
                continue
            # The single writer is not necessarily a tasklet; only tasklets
            # carry code whose value can be extracted.
            if not isinstance(init_tasklet, dace.nodes.Tasklet):
                dprint('Cannot remove scalar', dname,
                       '(not written by a tasklet)')
                continue

            # We can remove transient, find value from tasklet
            code = init_tasklet.code.code
            if len(code) != 1:
                dprint('Cannot remove scalar', dname, '(complex tasklet)')
                continue
            if not isinstance(code[0], ast.Assign):
                dprint('Cannot remove scalar', dname, '(complex tasklet2)')
                continue
            try:
                # literal_eval (unlike float() on the raw text) accepts
                # unparser output such as "(- 1)"; a non-literal right-hand
                # side now skips this scalar instead of aborting the pass.
                val = float(ast.literal_eval(unparse(code[0].value).strip()))
            except (ValueError, TypeError, SyntaxError):
                dprint('Cannot remove scalar', dname, '(non-constant value)')
                continue
            dprint('Converting', dname, 'to constant with value', val)
            transients_to_remove[dname] = val
            # Remove initialization tasklet
            itstate.remove_node(init_tasklet)

        _remove_transients(sdfg, transients_to_remove)
        removed_transients += len(transients_to_remove)
    print('Cleaned up %d extra scalar transients' % removed_transients)
def _analyze_and_unparse_code(func: DaceProgram) -> str:
    """Return the source of ``func.f`` after closure-based simplification.

    Global variables of ``func`` that are not argument names are treated as
    known values. The function's AST is passed through ``GlobalResolver``,
    ``ConditionalCodeResolver`` and ``DeadCodeEliminator`` (in that order),
    and the resulting tree is unparsed back into source text.
    """
    tree, _, _, _ = astutils.function_to_ast(func.f)
    known_values = {
        name: value
        for name, value in func.global_vars.items()
        if name not in func.argnames
    }
    tree = GlobalResolver(known_values).visit(tree)
    tree = ConditionalCodeResolver(known_values).visit(tree)
    tree = DeadCodeEliminator().visit(tree)
    return astutils.unparse(tree)
def __label__(self, sdfg, state):
    """Return the display label of this reduction: its operation and axes.

    Known reduction types are shown by their enum member name; custom
    reductions are shown as the unparsed body of the WCR lambda.
    """
    # Autodetect reduction type
    redtype = detect_reduction_type(self.wcr)
    if redtype != types.ReductionType.Custom:
        op_text = str(redtype)
        op_text = op_text[op_text.find('.') + 1:]  # Skip "ReductionType."
    else:
        op_text = unparse(ast.parse(self.wcr).body[0].value.body)
    axes_text = 'all' if self.axes is None else str(self.axes)
    return 'Op: {op}\nAxes: {axes}'.format(axes=axes_text, op=op_text)
def visit_Subscript(self, node):
    """Rewrite subscripts of known constants and memlets into C++ syntax.

    Subscripts of constant arrays become ``target[<index>]`` where the index
    is produced by ``DaCeKeywordRemover.ndslice_cpp``. Subscripts of memlets
    become ``__target(<indices>)`` calls. Subscripts of any other name are
    visited generically.

    :raises NotImplementedError: If the subscript is a range, not an index.
    """
    target = rname(node)
    if not (target in self.memlets or target in self.constants):
        return self.generic_visit(node)

    index = self.visit(node.slice)
    if not isinstance(index, ast.Index):
        raise NotImplementedError("Range subscripting not implemented")

    index_text = unparse(index)
    if isinstance(index.value, ast.Tuple):
        # Strip the first and last characters (the tuple's parentheses)
        index_text = index_text[1:-1]

    if target in self.constants:
        cpp_index = DaCeKeywordRemover.ndslice_cpp(
            index_text.split(", "), self.constants[target].shape)
        replacement = ast.parse("%s[%s]" %
                                (target, cpp_index)).body[0].value
    else:
        replacement = ast.parse("__%s(%s)" %
                                (target, index_text)).body[0].value
    return ast.copy_location(replacement, node)
def _replace_assignment(v, repl):
    """Apply the name replacements in ``repl`` to an assignment expression.

    :param v: The expression; it is converted to a string first.
    :param repl: Mapping from names to their replacements.
    :return: The rewritten expression string (an empty input is returned
             unchanged).
    """
    text = str(v)
    if not text:
        return text
    # Fast path (special case): the expression is a single known identifier
    if dtypes.validate_name(text) and text in repl:
        return repl[text]
    # General case: rewrite every matching name inside the parsed expression
    tree = ast.parse(text)
    tree = astutils.ASTFindReplace(repl).visit(tree)
    return astutils.unparse(tree)
def _subscript_expr(self, slicenode: ast.AST,
                    target: str) -> symbolic.SymbolicType:
    """Convert a subscript of ``target`` into a flat symbolic offset.

    Each index expression is multiplied by the stride of its dimension and
    the products are summed.
    - For constants, the strides are derived from the constant's shape.
    - For memlets, they are the absolute strides of the memlet subset. Two
      kinds of dimension are dropped: those the memlet accesses with size 1,
      and unit-stride dimensions whose size equals the vector length of the
      memlet's type.

    :param slicenode: The slice node of the subscript.
    :param target: Name of the subscripted constant or memlet.
    :return: The symbolic offset expression.
    :raises NotImplementedError: If the subscript is a range.
    :raises SyntaxError: If the number of indices does not match the number
                         of remaining dimensions.
    """
    index = self.visit(slicenode)
    if not isinstance(index, ast.Index):
        raise NotImplementedError("Range subscripting not implemented")

    if target in self.constants:
        strides = shape_to_strides(self.constants[target].shape)
    else:
        memlet_entry = self.memlets[target]
        memlet = memlet_entry[0]
        dtype = memlet_entry[3]
        # Memlet absolute strides, including tile sizes
        full_strides = memlet.subset.absolute_strides(
            self.sdfg.arrays[memlet.data].strides)
        # "Squeeze" scalar (size-1) dimensions and the vector dimension
        veclen = dtype.veclen if isinstance(dtype, dtypes.vector) else 1
        sizes = memlet.subset.size()
        scalar_dims = [dim for dim, size in enumerate(sizes) if size == 1]
        strides = []
        for dim, stride in enumerate(full_strides):
            if dim in scalar_dims:
                continue
            if stride == 1 and sizes[dim] == veclen:
                continue
            strides.append(stride)

    if not isinstance(index.value, ast.Tuple):
        if len(strides) != 1:
            raise SyntaxError('Missing dimensions in expression (expected %d, '
                              'got one)' % len(strides))
        return symbolic.pystr_to_symbolic(unparse(index)) * strides[0]

    elements = index.value.elts
    if len(strides) != len(elements):
        raise SyntaxError(
            'Invalid number of dimensions in expression (expected %d, '
            'got %d)' % (len(strides), len(elements)))
    return sum(
        symbolic.pystr_to_symbolic(unparse(elt)) * stride
        for elt, stride in zip(elements, strides))
def to_string(obj):
    """Return the source text of a code object.

    :param obj: A string, a code object exposing ``_as_string``, or a dict
                holding either under ``'code_or_block'`` (its language
                annotation is ignored here).
    :return: Strings as-is; otherwise the originally parsed string when one
             was kept, else the unparsed code.
    """
    if isinstance(obj, dict):
        obj = obj['code_or_block']
    if isinstance(obj, str):
        return obj
    original_text = obj._as_string
    if original_text is not None and original_text != "":
        return original_text
    # Without an original string, the code is assumed to be Python (other
    # languages are expected to carry their source text), so unparse it.
    return unparse(obj)
def visit_Call(self, node: ast.Call) -> Any:
    """Resolve calls whose callee or result can be computed from globals.

    The callee, and when ``self.resolve_functions`` is set also the whole
    call, is evaluated against ``self.globals`` with ``astutils.evalnode``.
    A SyntaxError from evaluation leaves the call to generic visiting.

    With ``resolve_functions``, an evaluated call result that is not a
    ``dtypes.typeclass`` is replaced by the node returned from
    ``global_value_to_node`` (when it returns one). Otherwise, a callee that
    is not a ``dtypes.typeclass`` has its ``func`` replaced by the node from
    ``global_value_to_node`` (when it returns one), and the call is then
    visited generically.
    """
    if hasattr(node.func, 'n') and isinstance(node.func.n, SDFGConvertible):
        # Skip already-parsed calls
        return self.generic_visit(node)

    try:
        global_func = astutils.evalnode(node.func, self.globals)
        if self.resolve_functions:
            global_val = astutils.evalnode(node, self.globals)
        else:
            # Sentinel: the identity check below routes this to the callee path
            global_val = node
    except SyntaxError:
        return self.generic_visit(node)

    newnode = None
    if self.resolve_functions and global_val is not node:
        # Without this check, casts don't generate code
        if not isinstance(global_val, dtypes.typeclass):
            newnode = self.global_value_to_node(
                global_val,
                parent_node=node,
                qualname=astutils.unparse(node),
                recurse=True)
            if newnode is not None:
                return newnode
        # Unresolvable results fall through to generic visiting below
    elif not isinstance(global_func, dtypes.typeclass):
        callables = not self.do_not_detect_callables
        newnode = self.global_value_to_node(
            global_func,
            parent_node=node,
            qualname=astutils.unparse(node),
            recurse=True,
            detect_callables=callables)
        if newnode is not None:
            # Keep the call, but point it at the resolved callee
            node.func = newnode
            return self.generic_visit(node)
    return self.generic_visit(node)
def visit_Attribute(self, node: ast.Attribute) -> Any:
    """Replace an attribute expression by the node for its global value.

    The expression is evaluated against ``self.globals`` only. If that
    raises SyntaxError, or the value is a ``dtypes.typeclass``, or
    ``global_value_to_node`` returns None, the node is visited generically.
    """
    try:
        value = astutils.evalnode(node, self.globals)
    except SyntaxError:
        return self.generic_visit(node)

    if isinstance(value, dtypes.typeclass):
        return self.generic_visit(node)

    replacement = self.global_value_to_node(value,
                                            parent_node=node,
                                            qualname=astutils.unparse(node),
                                            recurse=True)
    if replacement is None:
        return self.generic_visit(node)
    return replacement
def generate_case_body(self, t: ast.If, test_pred: str):
    """Emit the body of one case of an if/elif chain, predicated on a test.

    Cases of an ``elif`` chain are sequential: the first matching case wins
    and all following ones lose.

    :param t: The ``if`` node whose body is emitted.
    :param test_pred: Name of the predicate used by the operations of the body.
    """
    self.fill('// Case ' + astutils.unparse(t.test))
    self.enter()

    # Operations generated for the body use this case's test predicate
    self.pred_name = test_pred

    # Definitions inside the body are local to it: snapshot the defined
    # symbols before dispatching and restore them afterwards.
    saved_symbols = copy.deepcopy(self.defined_symbols)
    self.dispatch(t.body)
    self.defined_symbols = saved_symbols

    self.leave()
def pystr_to_symbolic(expr, symbol_map=None, simplify=None):
    """ Takes a Python string and converts it into a symbolic expression.

    :param expr: A Python expression string. ``SymExpr`` and sympy objects
                 are returned unchanged, and a string that is a valid name
                 becomes a ``symbol``. Any other value is passed to
                 ``sympy.sympify``.
    :param symbol_map: Optional mapping forwarded to ``sympy_to_dace``.
    :param simplify: Forwarded to ``sympy.sympify`` as ``evaluate``.
    :return: The converted symbolic expression.
    :raises TypeError, sympy.SympifyError: If ``expr`` cannot be converted.
        For strings, this happens only if the subscript-to-call fallback
        also fails.
    """
    from dace.frontend.python.astutils import unparse  # Avoid import loops

    if isinstance(expr, (SymExpr, sympy.Basic)):
        return expr
    if isinstance(expr, str) and dtypes.validate_name(expr):
        return symbol(expr)

    symbol_map = symbol_map or {}
    # Named ``sympy_locals`` to avoid shadowing the ``locals`` builtin
    sympy_locals = {
        'abs': sympy.Abs,
        'min': sympy.Min,
        'max': sympy.Max,
        'True': sympy.true,
        'False': sympy.false,
        'GtE': sympy.Ge,
        'LtE': sympy.Le,
        'NotEq': sympy.Ne,
        # Convert and/or to special sympy functions to avoid boolean evaluation
        'And': sympy.Function('AND'),
        'Or': sympy.Function('OR'),
        'var': sympy.Symbol('var'),
        'root': sympy.Symbol('root'),
    }
    # _clash1 enables all one-letter variables like N as symbols
    # _clash also allows pi, beta, zeta and other common greek letters
    sympy_locals.update(_sympy_clash)

    # Sympy processes "not/and/or" as direct evaluation. Replace with
    # And/Or(x, y), Not(x)
    if isinstance(expr, str) and re.search(
            r'\bnot\b|\band\b|\bor\b|\bNone\b|==|!=', expr):
        expr = unparse(SympyBooleanConverter().visit(ast.parse(expr).body[0]))

    # TODO: support SymExpr over-approximated expressions
    try:
        return sympy_to_dace(
            sympy.sympify(expr, sympy_locals, evaluate=simplify), symbol_map)
    except (TypeError, sympy.SympifyError):
        # The fallback below rewrites source text; for non-string inputs,
        # re-raise the real error instead of failing on ``str.replace``.
        if not isinstance(expr, str):
            raise
        # Symbol object is not subscriptable
        # Replace subscript expressions with function calls
        expr = expr.replace('[', '(').replace(']', ')')
        return sympy_to_dace(
            sympy.sympify(expr, sympy_locals, evaluate=simplify), symbol_map)
def visit_Call(self, node):
    """Convert struct-constructor calls into C compound-literal text.

    Applies to calls whose callee is a plain name that starts with
    ``__DAPPSTRUCT_`` (the prefix is stripped) or is registered in
    ``self._structs``. Such a call is replaced by a ``Name`` node whose id
    is ``(T) { .a = x, .b = y }``, with fields sorted by keyword name.
    Other calls are visited generically.
    """
    func = node.func
    is_struct = isinstance(func, ast.Name) and (
        func.id.startswith('__DAPPSTRUCT_') or func.id in self._structs)
    if not is_struct:
        return self.generic_visit(node)

    fields = ', '.join(
        '.%s = %s' % (rname(kw.arg), unparse(kw.value))
        for kw in sorted(node.keywords, key=lambda kw: kw.arg))
    tname = func.id
    if tname.startswith('__DAPPSTRUCT_'):
        tname = tname[len('__DAPPSTRUCT_'):]
    # ``ctx`` must be an ast.Load *instance*; passing the class itself
    # produces an invalid AST node.
    return ast.copy_location(
        ast.Name(id="(%s) { %s }" % (tname, fields), ctx=ast.Load()), node)
def _label(self, shape):
    """Build the human-readable label of this memlet.

    The label is the data name followed, as applicable, by:
    - the access count (``(dyn)`` for -1) when it differs from the subset's
      element count;
    - the subset in brackets;
    - the write-conflict resolution (and its identity);
    - the other subset.

    :param shape: Shape of the accessed data, or None.
    :return: The label string.
    """
    text = self.data if self.data is not None else ''
    if self.subset is None:
        return text

    element_count = self.subset.num_elements()
    if self.num_accesses != element_count:
        if self.num_accesses == -1:
            text += '(dyn) '
        else:
            text += '(%s) ' % SymbolicProperty.to_string(self.num_accesses)

    show_subset = True
    try:
        # A single-element container accessed at index zero is drawn
        # without its subset
        if shape is not None and reduce(operator.mul, shape, 1) == 1:
            if all(s == 0 for s in self.subset.min_element()):
                show_subset = False
    except TypeError:
        # Raised when truth-testing a symbolic (sympy) expression
        pass
    if show_subset:
        text += '[%s]' % str(self.subset)

    if self.wcr is not None and str(self.wcr) != '':
        # Autodetect reduction type
        redtype = detect_reduction_type(self.wcr)
        if redtype != dtypes.ReductionType.Custom:
            wcrstr = str(redtype)
            wcrstr = wcrstr[wcrstr.find('.') + 1:]  # Skip "ReductionType."
        else:
            wcrstr = unparse(ast.parse(self.wcr).body[0].value.body)
        text += ' (CR: %s' % wcrstr
        if self.wcr_identity is not None:
            text += ', id: %s' % str(self.wcr_identity)
        text += ')'

    if self.other_subset is not None:
        text += ' -> [%s]' % str(self.other_subset)
    return text
def inner_eval_ast(defined, node, additional_syms=None):
    """Evaluate an AST node, falling back to a symbolic expression.

    Non-AST inputs are returned unchanged. Otherwise the node is unparsed and
    evaluated with ``eval`` in a namespace built from ``defined``, updated
    with ``additional_syms``. If evaluation raises, subscripts in the source
    text are rewritten as calls (``a[i]`` -> ``a(i)``) and the text is
    converted with ``pystr_to_symbolic``.

    :param defined: Mapping of names visible during evaluation.
    :param node: The AST node to evaluate, or any other value.
    :param additional_syms: Optional extra names; they override ``defined``.
    :return: The evaluated value, or a symbolic expression.
    """
    if not isinstance(node, ast.AST):
        return node
    code = astutils.unparse(node)

    syms = {}
    syms.update(defined)
    if additional_syms is not None:
        syms.update(additional_syms)

    # First try to evaluate normally.
    # NOTE(review): eval() executes the unparsed program text; this is safe
    # only as long as the AST never originates from untrusted input.
    try:
        return eval(code, syms)
    except Exception:
        # Literally anything can happen here, but KeyboardInterrupt and
        # SystemExit must still propagate (hence no bare ``except:``).
        # If doesn't work, try to evaluate as a sympy expression.
        # Replace subscript expressions with function calls (sympy support)
        code = code.replace('[', '(').replace(']', ')')
        return pystr_to_symbolic(code)
def visit_JoinedStr(self, node: ast.JoinedStr) -> Any:
    """Fold an f-string into a constant string.

    If the whole f-string evaluates against ``self.globals``, it becomes a
    single ``Constant``. Otherwise a warning is issued, the parts are visited,
    and a partially-evaluated template string is built. Parts whose value is
    a constant are inlined; remaining formatted values are kept as
    ``{expr}`` placeholders.
    """
    try:
        value = astutils.evalnode(node, self.globals)
    except SyntaxError:
        warnings.warn(f'f-string at line {node.lineno} could not '
                      'be fully evaluated in DaCe program, converting to '
                      'partially-evaluated string.')
    else:
        return ast.copy_location(ast.Constant(kind='', value=value), node)

    visited = self.generic_visit(node)
    pieces = []
    for part in visited.values:
        part_text = astutils.unparse(part.value)
        resolved = (not isinstance(part, ast.FormattedValue)
                    or isinstance(part.value, ast.Constant))
        pieces.append(part_text if resolved else '{%s}' % part_text)
    return ast.copy_location(ast.Constant(kind='', value=''.join(pieces)),
                             node)
def visit_Call(self, node: ast.Call):
    """Redirect constructor calls of registered structs.

    A call whose unparsed callee names an entry of ``self._structs`` must pass
    exactly that struct's fields as keyword arguments. Its callee is then
    replaced by the name ``__DACESTRUCT_<struct name>``, and the call node is
    returned without visiting its arguments. Other calls are visited
    generically.

    :raises SyntaxError: If the keyword names differ from the struct fields.
    """
    callee = astutils.unparse(node.func)
    if callee not in self._structs:
        return self.generic_visit(node)

    struct = self._structs[callee]
    given_fields = {astutils.rname(kw.arg): kw.value for kw in node.keywords}
    expected_fields = tuple(sorted(struct.fields.keys()))
    if tuple(sorted(given_fields.keys())) != expected_fields:
        raise SyntaxError('Mismatch in fields in struct definition')

    new_callee = ast.Name(id='__DACESTRUCT_' + struct.name, ctx=ast.Load())
    node.func = ast.copy_location(new_callee, node.func)
    return node
def remove_constant_stencils(top_sdfg: dace.SDFG):
    """Replace source stencils that compute a single constant by that value.

    In every (nested) SDFG, a ``stencil.Stencil`` node qualifies when:
    - it has no inputs and exactly one output;
    - its code is a single assignment whose right-hand side evaluates to a
      number;
    - no other access node of the same data is written anywhere in the SDFG.
    Each qualifying node is removed, and its output data is recorded with the
    value and handed to ``_remove_transients`` (with ``ReplaceSubscript``).

    :param top_sdfg: The top-level SDFG; nested SDFGs are processed as well.
    :note: Operates in-place and prints diagnostics to standard output.
    """
    dprint = print  # lambda *args: pass
    removed_transients = 0
    for sdfg in top_sdfg.all_sdfgs_recursive():
        transients_to_remove = {}
        for state in sdfg.nodes():
            for node in state.nodes():
                if not (isinstance(node, stencil.Stencil)
                        and state.in_degree(node) == 0
                        and state.out_degree(node) == 1):
                    continue
                # We can remove transient, find value from tasklet
                if len(node.code.code) != 1:
                    dprint('Cannot remove scalar stencil', node.name,
                           '(complex code)')
                    continue
                if not isinstance(node.code.code[0], ast.Assign):
                    dprint('Cannot remove scalar stencil', node.name,
                           '(complex code2)')
                    continue
                # Ensure no one else is writing to it
                onode = state.memlet_path(state.out_edges(node)[0])[-1].dst
                dname = state.out_edges(node)[0].data.data
                if any(
                        s.in_degree(n) > 0 for s in sdfg.nodes()
                        for n in s.nodes()
                        if n != onode and isinstance(n, dace.nodes.AccessNode)
                        and n.data == dname):
                    continue
                try:
                    # NOTE(review): eval() runs stencil code from the SDFG;
                    # arbitrary expressions can fail in arbitrary ways, so a
                    # failure skips this stencil instead of aborting the pass.
                    val = float(eval(unparse(node.code.code[0].value)))
                except Exception:
                    dprint('Cannot remove scalar stencil', node.name,
                           '(non-constant value)')
                    continue
                dprint('Converting scalar stencil result', dname,
                       'to constant with value', val)
                transients_to_remove[dname] = val
                # Remove initialization tasklet
                state.remove_node(node)
        _remove_transients(sdfg, transients_to_remove, ReplaceSubscript)
        removed_transients += len(transients_to_remove)
    print('Cleaned up %d extra scalar stencils' % removed_transients)
def _elementwise(sdfg: SDFG,
                 state: SDFGState,
                 func: str,
                 in_array: str,
                 out_array=None):
    """Apply a lambda function to each element in the input

    :param sdfg: The SDFG containing the arrays.
    :param state: The state to which the computation is added.
    :param func: Source of a single-argument lambda, e.g. ``"lambda x: 2*x"``.
    :param in_array: Name of the input array.
    :param out_array: Name of the output array. When None, a temporary
                      transient with the input's shape, dtype and storage is
                      created.
    :return: Name of the output array.
    :raises SyntaxError: If ``func`` is not a single-argument lambda.
    """
    inparr = sdfg.arrays[in_array]
    restype = inparr.dtype
    if out_array is None:
        out_array, outarr = sdfg.add_temp_transient(inparr.shape, restype,
                                                    inparr.storage)
    else:
        outarr = sdfg.arrays[out_array]

    func_ast = ast.parse(func)
    try:
        lambda_ast = func_ast.body[0].value
        if len(lambda_ast.args.args) != 1:
            raise SyntaxError(
                "Expected lambda with one arg, but {} has {}".format(
                    func, len(lambda_ast.args.args)))
        arg = lambda_ast.args.args[0].arg
        body = astutils.unparse(lambda_ast.body)
    except (AttributeError, IndexError) as err:
        # Not an expression statement holding a lambda (or empty source)
        raise SyntaxError("Could not parse func {}".format(func)) from err

    code = "__out = {}".format(body)

    # Initial value 1 keeps an empty (scalar) shape from raising TypeError
    num_elements = reduce(lambda x, y: x * y, inparr.shape, 1)
    if num_elements == 1:
        inp = state.add_read(in_array)
        out = state.add_write(out_array)
        tasklet = state.add_tasklet("_elementwise_", {arg}, {'__out'}, code)
        state.add_edge(inp, None, tasklet, arg,
                       Memlet.from_array(in_array, inparr))
        state.add_edge(tasklet, '__out', out, None,
                       Memlet.from_array(out_array, outarr))
    else:
        # Same index expression for input and output: "__i0,__i1,..."
        index = ','.join(['__i%d' % i for i in range(len(inparr.shape))])
        state.add_mapped_tasklet(
            name="_elementwise_",
            map_ranges={
                '__i%d' % i: '0:%s' % n
                for i, n in enumerate(inparr.shape)
            },
            inputs={arg: Memlet.simple(in_array, index)},
            code=code,
            outputs={'__out': Memlet.simple(out_array, index)},
            external_edges=True)

    return out_array
def _parse_dim_atom(das, atom):
    """Convert a dimension expression into a symbolic value.

    The atom is first converted with ``pyexpr_to_symbolic``. If that yields
    a data descriptor, the atom's source text is converted with
    ``pystr_to_symbolic`` instead.
    """
    value = pyexpr_to_symbolic(das, atom)
    if not isinstance(value, data.Data):
        return value
    return pystr_to_symbolic(astutils.unparse(atom))
def condition_sympy(self):
    """Return ``self.condition`` (an AST) as a symbolic expression."""
    condition_text = astutils.unparse(self.condition)
    return symbolic.pystr_to_symbolic(condition_text)
def to_string(obj):
    """Return the source text of a lambda property value.

    None becomes ``'lambda: None'``. Strings are normalized by parsing and
    unparsing them; AST objects are unparsed directly.
    """
    if obj is None:
        return 'lambda: None'
    tree = ast.parse(obj) if isinstance(obj, str) else obj
    return unparse(tree)
def as_string(self):
    """Return the code as a string, unparsing it when stored as an AST.

    String code and a missing (None) code are returned unchanged.
    """
    code = self.code
    if code is None or isinstance(code, str):
        return code
    return unparse(code)
def remove_scalar_reads(sdfg: sd.SDFG, array_names: Dict[str, str]):
    """
    Removes all instances of a promoted symbol's read accesses in an SDFG.
    This removes each read-only access node as well as all of its descendant
    edges (in memlet trees) and connectors. Descends recursively to nested
    SDFGs and modifies tasklets (Python and C++).

    Each read edge is handled by the kind of its destination:
    - Tasklet: the connector is replaced by the symbol in the code.
    - AccessNode: written by a new ``symassign`` tasklet instead.
    - Nested SDFG: receives the value as a symbol (renamed if the name is
      taken), recursively.
    - Entry/exit node: only the edge is removed.
    Any other destination raises ValueError.

    :param sdfg: The SDFG to operate on.
    :param array_names: Mapping between scalar names to replace and their
                        replacement symbol name.
    :raises ValueError: If a read reaches an unsupported node type.
    :note: Operates in-place on the SDFG.
    """
    for state in sdfg.nodes():
        scalar_nodes = [
            n for n in state.nodes()
            if isinstance(n, nodes.AccessNode) and n.data in array_names
        ]
        for node in scalar_nodes:
            symname = array_names[node.data]
            for out_edge in state.out_edges(node):
                for e in state.memlet_tree(out_edge):
                    # Step numbers follow the process listed in
                    # promote_scalars_to_symbols.
                    # Step 3.1
                    dst = e.dst
                    state.remove_edge_and_connectors(e)
                    if isinstance(dst, nodes.Tasklet):
                        # Step 3.2
                        if dst.language is dtypes.Language.Python:
                            # Rewrite uses of the input connector in place
                            promo = TaskletPromoter(e.dst_conn, symname)
                            for stmt in dst.code.code:
                                promo.visit(stmt)
                        elif dst.language is dtypes.Language.CPP:
                            # Replace whole-word matches (identifiers) in code
                            dst.code.code = re.sub(
                                r'\b%s\b' % re.escape(e.dst_conn), symname,
                                dst.code.as_string)
                    elif isinstance(dst, nodes.AccessNode):
                        # Step 3.3
                        # Copy the symbol's value into the array via a tasklet
                        t = state.add_tasklet('symassign', {}, {'__out'},
                                              '__out = %s' % symname)
                        state.add_edge(
                            t, '__out', dst, e.dst_conn,
                            mm.Memlet(data=dst.data,
                                      subset=e.data.dst_subset,
                                      volume=1))
                        # Reassign destination for check below
                        dst = t
                    elif isinstance(dst, nodes.NestedSDFG):
                        tmp_symname = symname
                        val = 1
                        while (tmp_symname in dst.sdfg.symbols
                               or tmp_symname in dst.sdfg.arrays):
                            # Find new symbol name
                            tmp_symname = f'{symname}_{val}'
                            val += 1

                        # Descend recursively to remove scalar
                        remove_scalar_reads(dst.sdfg, {e.dst_conn: tmp_symname})
                        for ise in dst.sdfg.edges():
                            ise.data.replace(e.dst_conn, tmp_symname)
                            # Remove subscript occurrences as well
                            for aname, aval in ise.data.assignments.items():
                                vast = ast.parse(aval)
                                vast = astutils.RemoveSubscripts(
                                    {tmp_symname}).visit(vast)
                                ise.data.assignments[aname] = astutils.unparse(
                                    vast)
                            ise.data.replace(tmp_symname + '[0]', tmp_symname)

                        # Set symbol mapping
                        # (the former scalar input becomes a symbol of the
                        # nested SDFG, mapped to the outer symbol)
                        dst.sdfg.remove_data(e.dst_conn, validate=False)
                        dst.remove_in_connector(e.dst_conn)
                        dst.sdfg.symbols[tmp_symname] = sdfg.arrays[
                            node.data].dtype
                        dst.symbol_mapping[tmp_symname] = symname
                    elif isinstance(dst, (nodes.EntryNode, nodes.ExitNode)):
                        # Skip
                        continue
                    else:
                        raise ValueError(
                            'Node type "%s" not supported for promotion' %
                            type(dst).__name__)

                    # If nodes were disconnected, reconnect with empty memlet
                    if (isinstance(e.src, nodes.EntryNode)
                            and len(state.edges_between(e.src, dst)) == 0):
                        state.add_nedge(e.src, dst, mm.Memlet())

        # Remove newly-isolated nodes
        state.remove_nodes_from(
            [n for n in scalar_nodes if len(state.all_edges(n)) == 0])
def VisitTopLevelExpr(self, node):
    """Translate a top-level expression statement of a primitive's body.

    DaCe memlet expressions are written as shifts. ``a << b`` is rewritten
    into ``a = b``, and ``a >> b`` into an assignment to ``b`` (e.g.
    ``b[...] = a``). The memlet may be a call such as
    ``b(1, sum)[i,j,k,l]``:
    - first argument: volume, where -1 marks a dynamic access;
    - second argument: write-conflict resolution function;
    - third argument: its identity value.
    Any other expression is visited generically.

    :param node: An ``ast.Expr`` statement.
    :return: A 3-tuple ``(stmt, end_stmts, prefix_stmts)``:
             - ``stmt`` (or None) replaces the expression in place;
             - ``end_stmts`` (a list or None) are appended at the end of the
               body;
             - ``prefix_stmts`` (a list or None) are emitted before the
               enclosing construct.
             This matches how ``visit_FunctionDef`` consumes the result.
    """
    # DaCe memlet expression
    if isinstance(node.value, ast.BinOp):
        rhs = node.value.right
        lhs = node.value.left
        arrays = self.curprim.arrays()

        if isinstance(node.value.op, ast.LShift):
            # "a << b": read memlet, emitted as "a = b"
            # Dynamic access. Emit nothing and load memory on encounter
            if isinstance(rhs, ast.Call) and ast.literal_eval(
                    rhs.args[0]) == -1:
                array_name = rhs.func.id
                stripped_subscript = '%s[:]' % (array_name)
                self.storeOnAssignment[node.value.left.id] = \
                    ast.parse(stripped_subscript).body[0].value
                return None, None, None

            if isinstance(rhs, ast.Subscript) and isinstance(
                    rhs.value, ast.Call):
                # Dynamic access. Emit nothing and load memory on encounter
                if ast.literal_eval(rhs.value.args[0]) == -1:
                    array_name = rhs.value.func.id
                    stripped_subscript = '%s[%s]' % (array_name,
                                                     unparse(rhs.slice))
                    self.storeOnAssignment[node.value.left.id] = \
                        ast.parse(stripped_subscript).body[0].value
                    return None, None, None

                # Drop the memlet call: "A(...)[idx]" is read as "A[idx]"
                rhs = ast.Subscript(value=rhs.value.func,
                                    ctx=ast.Load(),
                                    slice=rhs.slice)

            result = _copy_location(
                ast.Assign(targets=[node.value.left], value=rhs), node)
            result.targets[0].ctx = ast.Store()
            return result, None, None
        # END of "a << b"
        elif isinstance(node.value.op, ast.RShift):
            # "a >> b": write memlet, emitted as an assignment to "b"
            # If the memlet refers to a sub-array (view), also add an expression to initialize it
            init_expr = None
            result = None
            prefix = []

            if isinstance(rhs, ast.Subscript):
                # Index subscript expression ("tmp >> b(1, sum)[i,j,k,l]")
                if isinstance(rhs.value, ast.Call):
                    # Only match expressions with possible write-conflict resolution, such as "A(...)[...]"
                    array_name = rhs.value.func.id
                    stripped_subscript = '%s[%s]' % (array_name,
                                                     unparse(rhs.slice))

                    # WCR initialization with identity value
                    if len(rhs.value.args) >= 3:
                        prefix.append(
                            _copy_location(
                                ast.parse(
                                    '%s = %s' %
                                    (stripped_subscript,
                                     unparse(rhs.value.args[2]))).body[0],
                                node))

                    # Dynamic access. Emit nothing and store memory on assignment
                    if ast.literal_eval(rhs.value.args[0]) == -1:
                        if len(rhs.value.args) >= 2:
                            self.accumOnAssignment[node.value.left.id] = \
                                (stripped_subscript, rhs.value.args[1])
                        else:
                            self.storeOnAssignment[node.value.left.id] = \
                                ast.parse(stripped_subscript).body[0].value
                        return init_expr, None, prefix

                    # Make sure WCR function exists
                    if len(rhs.value.args) >= 2:
                        # "b[idx] = (wcr)(b[idx], a)"
                        result = ast.parse(
                            '%s = (%s)(%s, %s)' %
                            (stripped_subscript, unparse(
                                rhs.value.args[1]), stripped_subscript,
                             node.value.left.id)).body[0]
                        result = _copy_location(result, node)
                    else:
                        result = ast.parse(
                            '%s = %s' %
                            (stripped_subscript, node.value.left.id)).body[0]
                        result = _copy_location(result, node)
                else:
                    array_name = rhs.value.id

                # NOTE(review): on Python >= 3.9 subscripts are no longer
                # wrapped in ast.Index, so this branch is always taken there.
                if not isinstance(rhs.slice, ast.Index):
                    init_expr = _copy_location(
                        ast.Assign(targets=[
                            ast.Name(id=node.value.left.id, ctx=ast.Store())
                        ],
                                   value=ast.Subscript(value=ast.Name(
                                       id=array_name, ctx=ast.Load()),
                                                       slice=rhs.slice,
                                                       ctx=ast.Load())), node)
            # Always true here (the previous branch handled subscripts)
            elif not isinstance(rhs, ast.Subscript):
                if isinstance(rhs, ast.Call):
                    array_name = rhs.func
                else:
                    array_name = rhs

                lhs_name = lhs.id

                # In case of "tmp >> array", write "array[:]"
                if node.value.left.id in self.curprim.transients:
                    init_expr = None
                # If reading from a single stream ("b << stream")
                elif (array_name.id in arrays
                      and isinstance(arrays[array_name.id], data.Stream)):
                    if arrays[array_name.id].shape == [1]:
                        init_expr = _copy_location(
                            ast.parse('{v} = {q}[0]'.format(
                                v=lhs_name, q=array_name.id)).body[0], node)
                    return init_expr, None, []
                else:
                    init_expr = _copy_location(
                        ast.Assign(targets=[
                            ast.Name(id=lhs_name, ctx=ast.Store())
                        ],
                                   value=ast.Subscript(
                                       value=ast.Name(id=array_name.id,
                                                      ctx=ast.Load()),
                                       slice=ast.Slice(lower=None,
                                                       upper=None,
                                                       step=None),
                                       ctx=ast.Load())), node)

                # If we are setting a stream's sink
                if lhs_name in arrays and isinstance(
                        arrays[lhs_name], data.Stream):
                    result = ast.parse(
                        '{arr}[0:len({q}[0])] = list({q}[0])'.format(
                            arr=rhs.id, q=lhs.id)).body[0]
                    result = _copy_location(result, node)

                # If WCR function exists
                elif isinstance(rhs, ast.Call) and len(rhs.args) >= 2:
                    # WCR initialization with identity value
                    if len(rhs.args) >= 3:
                        prefix.append(
                            _copy_location(
                                ast.parse('%s[:] = %s' %
                                          (array_name.id,
                                           unparse(rhs.args[2]))).body[0],
                                node))

                    # Dynamic access. Emit nothing and store memory on assignment
                    if ast.literal_eval(rhs.args[0]) == -1:
                        self.accumOnAssignment[lhs.id] = (array_name.id,
                                                          rhs.args[1])
                        return init_expr, None, prefix

                    # "b[:] = (wcr)(b[:], a)"
                    result = ast.parse(
                        '%s[:] = (%s)(%s[:], %s)' %
                        (array_name.id, unparse(rhs.args[1]), array_name.id,
                         node.value.left.id)).body[0]
                    result = _copy_location(result, node)

                else:
                    # Plain whole-array write: "b[:] = a"
                    result = _copy_location(
                        ast.Assign(targets=[
                            ast.Subscript(value=ast.Name(id=array_name.id,
                                                         ctx=ast.Load()),
                                          slice=ast.Slice(lower=None,
                                                          upper=None,
                                                          step=None),
                                          ctx=ast.Store())
                        ],
                                   value=node.value.left), node)

            if result is None:
                # Fallback: assign directly to the right-hand side expression
                result = _copy_location(
                    ast.Assign(targets=[node.value.right],
                               value=node.value.left), node)
                result.targets[0].ctx = ast.Store()
            return init_expr, [result], prefix
        # END of "a >> b"

    return self.generic_visit(node), [], None
def visit_FunctionDef(self, node):
    """Rewrite a function decorated with a DaCe primitive into plain Python.

    The first function visited (``self.curprim is None``) is the program
    itself. Its decorator provides the module alias stored in
    ``self.module_name``, and the decorator is removed.

    Nested functions decorated with a primitive name advance to the next
    child primitive and are rewritten according to its node type:
    - map / iterate: a ``for`` loop over ``<module>.ndrange(...)``;
    - consume: a ``while len(stream) > 0`` loop that pops elements;
    - tasklet: an ``if True:`` block;
    - reduce: kept as a function, followed by a call to
      ``<module>.simulator.simulate_reduce``;
    - loop: a ``while`` loop on the decorator's first argument.
    Functions without a primitive decorator are visited generically.

    :return: A list of statements: prefix statements collected from the
             body, the rewritten node, and (for reductions) the reduction
             call.
    :raises RuntimeError: If the current child primitive is of an
                          unsupported node type.
    """
    after_nodes = []

    if self.curprim is None:
        # Top-level DaCe program
        self.curprim = self.pdp
        self.curchild = -1
        if isinstance(node.decorator_list[0], ast.Call):
            self.module_name = node.decorator_list[0].func.value.id
        else:
            self.module_name = node.decorator_list[0].value.id
        # Strip decorator
        del node.decorator_list[0]

        # NOTE(review): saved after curprim was set, so the restore at the
        # end leaves self.curprim == self.pdp rather than None.
        oldchild = self.curchild
        oldprim = self.curprim

    else:
        if len(node.decorator_list) == 0:
            return self.generic_visit(node)
        dec = node.decorator_list[0]
        if isinstance(dec, ast.Call):
            decname = rname(dec.func.attr)
        else:
            decname = rname(dec.attr)

        if decname in [
                'map', 'reduce', 'consume', 'tasklet', 'iterate', 'loop',
                'conditional'
        ]:
            # Enter the next child primitive; the parent state is restored
            # after the body has been processed.
            self.curchild += 1

            oldchild = self.curchild
            oldprim = self.curprim
            self.curprim = self.curprim.children[self.curchild]
            self.curchild = -1

            # Dispatch is on the child primitive's node type, not decname
            if isinstance(self.curprim, astnodes._MapNode):
                newnode = \
                    _copy_location(ast.For(target=ast.Tuple(ctx=ast.Store(),
                                                            elts=[ast.Name(id=name, ctx=ast.Store()) for name in self.curprim.params]),
                                           iter=ast.parse('%s.ndrange(%s)' % (self.module_name, self.curprim.range.pystr())).body[0].value,
                                           body=node.body, orelse=[]),
                                   node)
                node = newnode
            elif isinstance(self.curprim, astnodes._ConsumeNode):
                stream = self.curprim.stream
                if isinstance(self.curprim.stream, ast.AST):
                    stream = unparse(self.curprim.stream)
                if '[' not in stream:
                    stream += '[0]'

                newnode = \
                    _copy_location(ast.While(
                        test=ast.parse('len(%s) > 0' % stream).body[0].value,
                        body=node.body, orelse=[]),
                        node)
                node = newnode
                # First statement of the loop pops the next element
                node.body.insert(
                    0,
                    _copy_location(
                        ast.parse(
                            '%s = %s.popleft()' % (str(self.curprim.params[0]),
                                                   stream)).body[0], node))

            elif isinstance(self.curprim, astnodes._TaskletNode):
                # Strip decorator
                del node.decorator_list[0]

                newnode = \
                    _copy_location(ast.parse('if True: pass').body[0], node)
                newnode.body = node.body
                newnode = ast.fix_missing_locations(newnode)
                node = newnode
            elif isinstance(self.curprim, astnodes._ReduceNode):
                # NOTE(review): in_memlet and out_memlet are not used below
                in_memlet = self.curprim.inputs['input']
                out_memlet = self.curprim.outputs['output']
                # Create reduction call
                params = [unparse(p) for p in node.decorator_list[0].args]
                params.extend([
                    unparse(kp) for kp in node.decorator_list[0].keywords
                ])
                reduction = ast.parse(
                    '%s.simulator.simulate_reduce(%s, %s)' %
                    (self.module_name, node.name, ', '.join(params))).body[0]
                reduction = _copy_location(reduction, node)
                # Place the call after the function body
                reduction = ast.increment_lineno(reduction,
                                                 len(node.body) + 1)
                reduction = ast.fix_missing_locations(reduction)

                # Strip decorator
                del node.decorator_list[0]

                after_nodes.append(reduction)
            elif isinstance(self.curprim, astnodes._IterateNode):
                newnode = \
                    _copy_location(ast.For(target=ast.Tuple(ctx=ast.Store(),
                                                            elts=[ast.Name(id=name, ctx=ast.Store()) for name in self.curprim.params]),
                                           iter=ast.parse('%s.ndrange(%s)' % (self.module_name, self.curprim.range.pystr())).body[0].value,
                                           body=node.body, orelse=[]),
                                   node)
                newnode = ast.fix_missing_locations(newnode)
                node = newnode
            elif isinstance(self.curprim, astnodes._LoopNode):
                newnode = \
                    _copy_location(ast.While(test=node.decorator_list[0].args[0],
                                             body=node.body, orelse=[]),
                                   node)
                newnode = ast.fix_missing_locations(newnode)
                node = newnode
            else:
                raise RuntimeError('Unimplemented primitive %s' % decname)
        else:
            return self.generic_visit(node)

    newbody = []
    end_stmts = []
    substitute_stmts = []
    # Incrementally build new body from original body
    for stmt in node.body:
        if isinstance(stmt, ast.Expr):
            # Memlet expressions may produce in-place, trailing and prefix
            # statements
            res, append, prepend = self.VisitTopLevelExpr(stmt)
            if res is not None:
                newbody.append(res)
            if append is not None:
                end_stmts.extend(append)
            if prepend is not None:
                substitute_stmts.extend(prepend)
        else:
            subnodes = self.visit(stmt)
            if subnodes is not None:
                if isinstance(subnodes, list):
                    newbody.extend(subnodes)
                else:
                    newbody.append(subnodes)
    node.body = newbody + end_stmts

    # Restore the parent primitive context
    self.curchild = oldchild
    self.curprim = oldprim

    substitute_stmts.append(node)
    if len(after_nodes) > 0:
        return substitute_stmts + after_nodes
    return substitute_stmts
def _generate_pdp(self, args, kwargs, simplify=None) -> SDFG:
    """ Generates the parsed AST representation of a DaCe program.

    Argument types are collected from the given arguments. The global
    namespace of the program is prepared: None arguments, unspecified default
    arguments and constant arguments are folded into globals, and symbols are
    registered by name. The program is then preprocessed, and the SDFG is
    either taken from ``self._cache`` or parsed anew.

    Side effects: may add ``self.objname`` to ``self.global_vars``, and sets
    ``self.closure_arg_mapping``, ``self.closure_array_keys``,
    ``self.closure_constant_keys`` and ``self.resolver``.

    :param args: The given arguments to the program.
    :param kwargs: The given keyword arguments to the program.
    :param simplify: Whether to apply simplification pass when parsing
                     nested dace programs.
    :return: A 2-tuple of (parsed SDFG object, was the SDFG retrieved from
             cache).
    """
    # NOTE(review): the return annotation says ``SDFG``, but a 2-tuple
    # ``(sdfg, cached)`` is returned; ``Tuple[SDFG, bool]`` seems intended.
    dace_func = self.f

    # If exist, obtain type annotations (for compilation)
    argtypes, _, gvars, specified = self._get_type_annotations(
        args, kwargs)

    # Move "self" from an argument into the closure
    if self.methodobj is not None:
        self.global_vars[self.objname] = self.methodobj

    for k, v in argtypes.items():
        if isinstance(
                v, data.View):  # Arguments to (nested) SDFG cannot be Views
            argtypes[k] = v.as_array()
            argtypes[k].transient = False
        elif v.transient:  # Arguments to (nested) SDFGs cannot be transient
            v_cpy = copy.deepcopy(v)
            v_cpy.transient = False
            argtypes[k] = v_cpy

    #############################################

    # Parse allowed global variables
    # (for inferring types and values in the DaCe program)
    global_vars = copy.copy(self.global_vars)

    # Remove None arguments and make into globals that can be folded
    removed_args = set()
    for k, v in argtypes.items():
        if v.dtype.type is None:
            global_vars[k] = None
            removed_args.add(k)

    # Set module aliases to point to their actual names
    modules = {
        k: v.__name__
        for k, v in global_vars.items() if dtypes.ismodule(v)
    }
    modules['builtins'] = ''

    # Add symbols as globals with their actual names (sym_0 etc.)
    global_vars.update({
        v.name: v
        for _, v in global_vars.items() if isinstance(v, symbolic.symbol)
    })

    # Add default arguments to global vars
    unspecified_default_args = {
        k: v
        for k, v in self.default_args.items() if k not in specified
    }
    removed_args.update(unspecified_default_args)
    gvars.update(unspecified_default_args)

    # Add constant arguments to global_vars
    global_vars.update(gvars)

    # Folded arguments are no longer SDFG arguments; the free symbols of the
    # remaining argument types become visible as globals
    argtypes = {k: v for k, v in argtypes.items() if k not in removed_args}
    for argtype in argtypes.values():
        global_vars.update({v.name: v for v in argtype.free_symbols})

    # Parse AST to create the SDFG
    parsed_ast, closure = preprocessing.preprocess_dace_program(
        dace_func,
        argtypes,
        global_vars,
        modules,
        resolve_functions=self.resolve_functions,
        default_args=unspecified_default_args.keys())

    # Create new argument mapping from closure arrays
    arg_mapping = {
        k: v
        for k, (_, _, v, _) in closure.closure_arrays.items()
    }
    self.closure_arg_mapping = arg_mapping
    self.closure_array_keys = set(
        closure.closure_arrays.keys()) - removed_args
    self.closure_constant_keys = set(
        closure.closure_constants.keys()) - removed_args
    self.resolver = closure

    # If parsed SDFG is already cached, use it
    cachekey = self._cache.make_key(argtypes, specified,
                                    self.closure_array_keys,
                                    self.closure_constant_keys, gvars)
    if self._cache.has(cachekey):
        sdfg = self._cache.get(cachekey).sdfg
        cached = True
    else:
        cached = False

        try:
            sdfg = newast.parse_dace_program(self.name,
                                             parsed_ast,
                                             argtypes,
                                             self.dec_kwargs,
                                             closure,
                                             simplify=simplify)
        except Exception:
            # Optionally dump the preprocessed program before re-raising
            if Config.get_bool('frontend', 'verbose_errors'):
                from dace.frontend.python import astutils
                print('VERBOSE: Failed to parse the following program:')
                print(astutils.unparse(parsed_ast.preprocessed_ast))
            raise

        # Set SDFG argument names, filtering out constants
        sdfg.arg_names = [a for a in self.argnames if a in argtypes]

        # TODO: Add to parsed SDFG cache

    return sdfg, cached
def wcr_name(self):
    """Return a label for the WCR lambda: its body plus the identity if set."""
    parts = [astutils.unparse(self.wcr.body)]
    if self.wcr_identity is not None:
        parts.append('id: ' + str(self.wcr_identity))
    return ', '.join(parts)
def promote_scalars_to_symbols(sdfg: sd.SDFG,
                               ignore: Optional[Set[str]] = None,
                               transients_only: bool = True,
                               integers_only: bool = True) -> Set[str]:
    """
    Promotes all matching transient scalars to SDFG symbols, changing all
    tasklets to inter-state assignments. This enables the transformed symbols
    to be used within states as part of memlets, and allows further
    transformations (such as loop detection) to use the information for
    optimization.

    :param sdfg: The SDFG to run the pass on.
    :param ignore: An optional set of strings of scalars to ignore.
    :param transients_only: If False, also considers global data descriptors
                            (e.g., arguments).
    :param integers_only: If False, also considers non-integral descriptors
                          for promotion.
    :return: Set of promoted scalars.
    :note: Operates in-place.
    """
    # Process:
    # 1. Find scalars to promote
    # 2. For every assignment tasklet/access:
    #    2.1. Fission state to isolate assignment
    #    2.2. Replace assignment with inter-state edge assignment
    # 3. For every read of the scalar:
    #    3.1. If destination is tasklet, remove node, edges, and connectors
    #    3.2. If used in tasklet as subscript or connector, modify tasklet code
    #    3.3. If destination is array, change to tasklet that copies symbol data
    # 4. Remove newly-isolated access nodes
    # 5. Remove data descriptors and add symbols to SDFG
    # 6. Replace subscripts in all interstate conditions and assignments
    # 7. Make indirections with symbols a single memlet

    # Step 1: Candidates
    to_promote = find_promotable_scalars(sdfg,
                                         transients_only=transients_only,
                                         integers_only=integers_only)
    if ignore:
        to_promote -= ignore
    if len(to_promote) == 0:
        return to_promote

    for state in sdfg.nodes():
        scalar_nodes = [
            n for n in state.nodes()
            if isinstance(n, nodes.AccessNode) and n.data in to_promote
        ]
        # Step 2: Assignment tasklets
        for node in scalar_nodes:
            if state.in_degree(node) == 0:
                continue
            # There is only zero or one incoming edges by definition
            in_edge = state.in_edges(node)[0]
            src_node = in_edge.src
            tasklet_inputs = [e.src for e in state.in_edges(src_node)]

            # Step 2.1: Move the assignment (and its inputs) to its own state
            new_state = xfh.state_fission(
                sdfg,
                gr.SubgraphView(state,
                                set([src_node, node] + tasklet_inputs)))
            new_isedge: sd.InterstateEdge = sdfg.out_edges(new_state)[0]

            # Step 2.2: Re-locate the scalar and its writer in the new state
            moved_node: nodes.AccessNode = new_state.sink_nodes()[0]
            src_node = new_state.in_edges(moved_node)[0].src
            if isinstance(src_node, nodes.Tasklet):
                # Convert tasklet to interstate edge
                newcode: str = ''
                if src_node.language is dtypes.Language.Python:
                    newcode = astutils.unparse(src_node.code.code[0].value)
                elif src_node.language is dtypes.Language.CPP:
                    newcode = translate_cpp_tasklet_to_python(
                        src_node.code.as_string.strip())

                # Replace tasklet inputs with incoming edges
                for e in new_state.in_edges(src_node):
                    memlet_str: str = e.data.data
                    if (e.data.subset is not None and not isinstance(
                            sdfg.arrays[memlet_str], dt.Scalar)):
                        memlet_str += '[%s]' % e.data.subset
                    newcode = re.sub(r'\b%s\b' % re.escape(e.dst_conn),
                                     memlet_str, newcode)
                # Add interstate edge assignment
                new_isedge.data.assignments[moved_node.data] = newcode
            elif isinstance(src_node, nodes.AccessNode):
                memlet: mm.Memlet = in_edge.data
                if (memlet.src_subset and not isinstance(
                        sdfg.arrays[memlet.data], dt.Scalar)):
                    new_isedge.data.assignments[
                        moved_node.data] = '%s[%s]' % (src_node.data,
                                                       memlet.src_subset)
                else:
                    new_isedge.data.assignments[
                        moved_node.data] = src_node.data

            # Clean up all nodes after assignment was transferred
            new_state.remove_nodes_from(new_state.nodes())

    # Step 3: Scalar reads
    remove_scalar_reads(sdfg, {k: k for k in to_promote})

    # Step 4: Isolated nodes
    for state in sdfg.nodes():
        scalar_nodes = [
            n for n in state.nodes()
            if isinstance(n, nodes.AccessNode) and n.data in to_promote
        ]
        state.remove_nodes_from(
            [n for n in scalar_nodes if len(state.all_edges(n)) == 0])

    # Step 5: Data descriptor management
    for scalar in to_promote:
        desc = sdfg.arrays[scalar]
        sdfg.remove_data(scalar, validate=False)
        # If the scalar is already a symbol (e.g., as part of an array size),
        # do not re-add the symbol
        if scalar not in sdfg.symbols:
            sdfg.add_symbol(scalar, desc.dtype)

    # Step 6: Inter-state edge cleanup
    # Matches a subscripted use of a promoted scalar, e.g., ``s[0]`` -> ``s``
    cleanup_re = {
        s: re.compile(fr'\b{re.escape(s)}\[.*?\]')
        for s in to_promote
    }
    promo = TaskletPromoterDict({k: k for k in to_promote})
    for edge in sdfg.edges():
        ise: InterstateEdge = edge.data
        # Condition
        if not edge.data.is_unconditional():
            if ise.condition.language is dtypes.Language.Python:
                for stmt in ise.condition.code:
                    promo.visit(stmt)
            elif ise.condition.language is dtypes.Language.CPP:
                for scalar in to_promote:
                    ise.condition = cleanup_re[scalar].sub(
                        scalar, ise.condition.as_string)

        # Assignments. Every scalar's substitution is applied to the same,
        # progressively rewritten string; rewriting from the original text for
        # each scalar would keep only the last substitution when an
        # assignment references several promoted scalars.
        for aname, assignment in list(ise.assignments.items()):
            new_assignment = assignment
            for scalar in to_promote:
                if scalar in new_assignment:
                    new_assignment = cleanup_re[scalar].sub(
                        scalar, new_assignment.strip())
            if new_assignment != assignment:
                ise.assignments[aname] = new_assignment

    # Step 7: Indirection
    remove_symbol_indirection(sdfg)

    return to_promote
def global_value_to_node(self,
                         value,
                         parent_node,
                         qualname,
                         recurse=False,
                         detect_callables=False):
    """
    Converts a global value referenced by the program into an AST node that
    can replace ``parent_node``. As a side effect, the value is registered
    in ``self.closure`` (constants, SDFGs, arrays, or callbacks).

    :param value: The Python object the global reference evaluates to.
    :param parent_node: The AST node being replaced. If not None, its source
                        location is copied to the returned node.
    :param qualname: Qualified name of the value. It is stored in the closure,
                     and for arrays it is evaluated against ``self.globals``
                     when the array is fetched.
    :param recurse: If True, allows converting the elements of one level of
                    list/tuple values.
    :param detect_callables: If True, tries to parse callables as DaCe
                             programs, falling back to a callback node.
    :return: The new AST node, or None if the value cannot be converted.
    """
    if isinstance(value, (list, tuple)):
        # If recurse is False, we don't allow recursion into lists. This
        # should not happen anyway; the globals dict should only contain
        # single "level" lists. Elements are converted with recurse=False,
        # so nested sequences bail out here.
        if not recurse:
            return None
        elts = [
            self.global_value_to_node(v,
                                      parent_node,
                                      qualname + f'[{i}]',
                                      detect_callables=detect_callables)
            for i, v in enumerate(value)
        ]
        if any(e is None for e in elts):
            return None
        # The parent may be None or a node without an expression context
        # (e.g., ast.Call); default to a load context in that case.
        ctx = getattr(parent_node, 'ctx', None) or ast.Load()
        seq_type = ast.List if isinstance(value, list) else ast.Tuple
        newnode = seq_type(elts=elts, ctx=ctx)
    elif isinstance(value, symbolic.symbol):
        # Symbols resolve to the symbol name
        newnode = ast.Name(id=value.name, ctx=ast.Load())
    elif (dtypes.isconstant(value) or isinstance(value, SDFG)
          or hasattr(value, '__sdfg__')):
        # Could be a constant, an SDFG, or SDFG-convertible object
        if isinstance(value, SDFG) or hasattr(value, '__sdfg__'):
            self.closure.closure_sdfgs[id(value)] = (qualname, value)
        else:
            # If this is a function call to a None function, do not add its
            # result to the closure
            if isinstance(parent_node, ast.Call):
                func = parent_node.func
                fqname = getattr(func, 'qualname', astutils.rname(func))
                if (fqname in self.closure.closure_constants
                        and self.closure.closure_constants[fqname] is None):
                    return None
                # The callee was already folded to the constant None. Check
                # ``value`` directly: ``ast.Constant.n`` is only a deprecated
                # alias of it and is removed in newer Python versions.
                if isinstance(func, ast.Constant) and func.value is None:
                    return None
            self.closure.closure_constants[qualname] = value

        # Compatibility check since Python changed their AST nodes
        if sys.version_info >= (3, 8):
            newnode = ast.Constant(value=value, kind='')
        else:
            if value is None:
                newnode = ast.NameConstant(value=None)
            elif isinstance(value, str):
                newnode = ast.Str(s=value)
            else:
                newnode = ast.Num(n=value)

        newnode.qualname = qualname

    elif detect_callables and hasattr(value, '__call__') and hasattr(
            value.__call__, '__sdfg__'):
        return self.global_value_to_node(value.__call__, parent_node,
                                         qualname, recurse, detect_callables)
    elif isinstance(value, numpy.ndarray):
        # Arrays need to be stored as a new name and fed as an argument
        if id(value) in self.closure.array_mapping:
            arrname = self.closure.array_mapping[id(value)]
        else:
            arrname = self._qualname_to_array_name(qualname)
            desc = data.create_datadescriptor(value)
            # qualname comes from the program's own global references, not
            # from external input
            self.closure.closure_arrays[arrname] = (
                qualname, desc, lambda: eval(qualname, self.globals), False)
            self.closure.array_mapping[id(value)] = arrname

        newnode = ast.Name(id=arrname, ctx=ast.Load())
    elif detect_callables and callable(value):
        # Try parsing the function as a dace function/method
        newnode = None
        try:
            from dace.frontend.python import parser  # Avoid import loops

            parent_object = None
            if hasattr(value, '__self__'):
                parent_object = value.__self__

            # If it is a callable object
            if (not inspect.isfunction(value)
                    and not inspect.ismethod(value)
                    and not inspect.isbuiltin(value)
                    and hasattr(value, '__call__')):
                parent_object = value
                value = value.__call__

            # Replacements take precedence over auto-parsing
            try:
                if has_replacement(value, parent_object, parent_node):
                    return None
            except Exception:
                # Best-effort: a failing replacement lookup falls through to
                # auto-parsing
                pass

            # Store the handle to the original callable, in case parsing fails
            if isinstance(parent_node, ast.Call):
                cbqualname = astutils.unparse(parent_node.func)
            else:
                cbqualname = astutils.rname(parent_node)
            cbname = self._qualname_to_array_name(cbqualname, prefix='')
            self.closure.callbacks[cbname] = (cbqualname, value, False)

            # From this point on, any failure will result in a callback
            newnode = ast.Name(id=cbname, ctx=ast.Load())

            # Decorated or functions with missing source code
            sast, _, _, _ = astutils.function_to_ast(value)
            if len(sast.body[0].decorator_list) > 0:
                return newnode

            parsed = parser.DaceProgram(value, [], {}, False,
                                        dtypes.DeviceType.CPU)
            # If method, add the first argument (which disappears due to
            # being a bound method) and the method's object
            if parent_object is not None:
                parsed.methodobj = parent_object
                parsed.objname = inspect.getfullargspec(value).args[0]

            res = self.global_value_to_node(parsed, parent_node, qualname,
                                            recurse, detect_callables)
            res.oldnode = copy.deepcopy(parent_node)
            # Keep callback in callbacks in case of parsing failure
            # del self.closure.callbacks[cbname]
            return res
        except Exception:
            # Parsing failed (almost any exception can occur); fall back to
            # the callback node if one was created
            return newnode
    else:
        return None

    if parent_node is not None:
        return ast.copy_location(newnode, parent_node)
    return newnode