def parse_memlet(visitor, src: MemletType, dst: MemletType,
                 defined_arrays_and_symbols: Dict[str, data.Data]):
    """
    Parses both ends of a memlet connection, at most one of which may be a
    tasklet-local variable.

    A side counts as a local variable when it is a bare ``ast.Name`` that is
    not a key of ``defined_arrays_and_symbols``. Every other side is parsed
    with ``ParseMemlet``.

    :param visitor: Visitor object, forwarded to ``ParseMemlet`` and used for
                    error reporting.
    :param src: AST node of the memlet source.
    :param dst: AST node of the memlet destination.
    :param defined_arrays_and_symbols: Mapping of known data/symbol names.
    :return: A tuple ``(local variable name or None, Memlet)``. The
             ``Memlet`` is built from the side that references data.
    :raises DaceSyntaxError: If both sides are local variables.
    :raises NotImplementedError: If both sides reference data.
    """
    known = defined_arrays_and_symbols

    def _is_local(side) -> bool:
        # Bare names that are not known data/symbols are tasklet-local
        return isinstance(side, ast.Name) and rname(side) not in known

    local_name = None
    src_memlet = None
    dst_memlet = None

    if _is_local(src):
        local_name = rname(src)
    else:
        src_memlet = ParseMemlet(visitor, known, src)

    if not _is_local(dst):
        dst_memlet = ParseMemlet(visitor, known, dst)
    elif local_name is None:
        local_name = rname(dst)
    else:
        raise DaceSyntaxError(
            visitor, src,
            'Memlet source and destination cannot both be local variables')

    if src_memlet is not None and dst_memlet is not None:
        # Data on both ends would require creating two memlets
        raise NotImplementedError

    chosen = dst_memlet if src_memlet is None else src_memlet
    return local_name, Memlet(chosen.name, chosen.accesses, chosen.subset,
                              1, wcr=chosen.wcr)
def _Call(self, t):
    """
    Unparses a call expression.

    If the inferred result type is not a ``dtypes.vector``, the call is
    written as a scalar call. The Python module prefix (e.g. ``math``) is
    replaced by the name registered in ``dtypes._ALLOWED_MODULES``.

    Otherwise the call is written as ``<name>_x(<self.pred_name>, args...)``.
    ``<name>`` comes from the preprocessor's internal operations (plain
    names) or from ``util.MATH_FUNCTION_TO_SVE`` (attribute calls).

    :raises util.NotSupportedError: If no result type can be inferred.
    :raises NotImplementedError: If the module, function, or callee form has
                                 no supported translation.
    """
    res_type = self.infer(t)[0]
    if not res_type:
        raise util.NotSupportedError('Unsupported call')

    if not isinstance(res_type, dtypes.vector):
        # Call does not involve any vectors (to our knowledge).
        # Replace default modules (e.g., math) with dace::math::
        attr_name = astutils.rname(t)
        # NOTE(review): for an undotted name, rfind returns -1, so
        # module_name becomes the name without its last character. That
        # string is unlikely to be an allowed module.
        dot = attr_name.rfind(".")
        module_name = attr_name[:dot]
        func_name = attr_name[dot + 1:]
        if module_name not in dtypes._ALLOWED_MODULES:
            raise NotImplementedError(
                f'Module {module_name} is not implemented')
        name = dtypes._ALLOWED_MODULES[module_name] + func_name
        self.write(name)
        self.write('(')
        for i, e in enumerate(t.args):
            if i > 0:
                self.write(", ")
            self.dispatch(e)
        self.write(')')
        return

    if isinstance(t.func, ast.Name):
        # Could be an internal operation (provided by the preprocessor)
        if not util.is_sve_internal(t.func.id):
            raise NotImplementedError(
                f'Function {t.func.id} is not implemented')
        name = util.internal_to_external(t.func.id)[0]
    elif isinstance(t.func, ast.Attribute):
        # Some module function (xxx.xxx), make sure it is available
        name = util.MATH_FUNCTION_TO_SVE.get(astutils.rname(t.func))
        if name is None:
            raise NotImplementedError(
                f'Function {astutils.rname(t.func)} is not implemented')
    else:
        # Without this branch, ``name`` stayed None and "None_x(...)" was
        # emitted as generated code.
        raise NotImplementedError(
            f'Calls through {type(t.func).__name__} expressions are not '
            'implemented')

    # Vectorized function: the predicate name is passed as first argument
    self.write('{}_x({}, '.format(name, self.pred_name))
    for i, e in enumerate(t.args):
        if i > 0:
            self.write(", ")
        self.dispatch_expect(e, res_type)
    self.write(')')
def visit_TopLevelExpr(self, node):
    """
    Rewrites top-level memlet expressions in tasklet code.

    ``a << A[i]`` becomes a leading ``a = A[i]`` (or, for dynamic memlets,
    an in-place name replacement). ``a >> A[i]`` becomes a trailing
    ``A[i] = a`` (wrapped in the WCR function if one is given), or, for
    dynamic memlets, an assignment/WCR replacement. Both forms return None
    so they are removed from the tasklet body. Any other expression is
    visited normally.
    """
    expr = node.value
    if not (isinstance(expr, ast.BinOp)
            and isinstance(expr.op, (ast.LShift, ast.RShift))):
        return self.generic_visit(node)

    # Obtain memlet metadata and clean AST node from memlet syntax
    clean_rhs, is_dynamic, wcr_func = self._clean_memlet(expr.right)

    if isinstance(expr.op, ast.LShift):
        if is_dynamic:
            # In-place replacement
            self.name_replacements[rname(expr.left)] = clean_rhs
        else:
            # Replace "a << A[i]" with "a = A[i]" at the beginning
            target = copy.deepcopy(expr.left)
            target.ctx = ast.Store()
            self.pre_statements.append(
                _copy_location(
                    ast.Assign(targets=[target], value=clean_rhs), node))
        return None  # Remove from final tasklet code

    # Right shift: output memlet
    if is_dynamic:
        if wcr_func is None:
            # In-place replacement
            self.assign_replacements[rname(expr.left)] = clean_rhs
        else:
            # Replace assignments with the WCR lambda every time
            self.wcr_replacements[rname(expr.left)] = (clean_rhs, wcr_func)
        return None  # Remove from final tasklet code

    # Replace "a >> A[i]" with "A[i] = a" at the end; with WCR this becomes
    # "A[i] = (lambda a,b: a+b)(A[i], a)"
    value = expr.left
    if wcr_func is not None:
        value = _copy_location(
            ast.Call(func=wcr_func, args=[clean_rhs, value], keywords=[]),
            value)
    target = copy.deepcopy(clean_rhs)
    target.ctx = ast.Store()
    self.post_statements.append(
        _copy_location(ast.Assign(targets=[target], value=value), node))
    return None  # Remove from final tasklet code
def has_replacement(callobj: Callable,
                    parent_object: Optional[Any] = None,
                    node: Optional[ast.AST] = None) -> bool:
    """
    Returns True if the function/operator replacement repository has a
    registered replacement for the called function described by a live
    object.

    Checks are made in this order:
    1. Objects whose module is ``dace``/``math`` (or a submodule of either)
       are always reported as replaceable.
    2. Attribute/method replacements on the type of ``parent_object``.
    3. NumPy ufuncs (either ``callobj`` or ``parent_object``).
    4. Constructor calls, by the name they are called with in ``node``.
    5. The fully-qualified name of ``callobj``.
    6. The name ``callobj`` is called with in ``node``.
    """
    from dace.frontend.common import op_repository as oprepo
    replacements = oprepo.Replacements

    # Nothing from the `dace` namespace needs preprocessing
    try:
        module_name = callobj.__module__
    except AttributeError:
        module_name = getattr(parent_object, '__module__', None)
    if module_name and any(
            module_name == root or module_name.startswith(root + '.')
            for root in ('dace', 'math')):
        return True

    # Attributes and methods
    classname = None
    if parent_object is not None:
        classname = type(parent_object).__name__
        attrname = callobj.__name__
        for lookup in (replacements.get_attribute, replacements.get_method):
            if lookup(classname, attrname) is not None:
                return True

    # NumPy ufuncs
    if isinstance(callobj, numpy.ufunc) or isinstance(parent_object,
                                                      numpy.ufunc):
        return True

    # Special case: constructor method (e.g., numpy.ndarray)
    if classname == "type" and replacements.get(
            astutils.rname(node)) is not None:
        return True

    # Functions, by fully-qualified name
    qualified_name = callobj.__module__ + '.' + callobj.__qualname__
    if replacements.get(qualified_name) is not None:
        return True

    # Also try the function as it is called in the AST
    return replacements.get(astutils.rname(node)) is not None
def visit_Expr(self, node):
    """
    Removes expression statements that call DaCe ``define_*`` helpers.

    Calls whose callee name is ``define_local``, ``define_local_scalar``,
    ``define_stream`` or ``define_streamarray`` are not parsed and are
    dropped (None is returned). All other expressions are visited normally.
    """
    call = node.value
    # Some calls should not be parsed
    if isinstance(call, ast.Call) and rname(call.func) in (
            "define_local", "define_local_scalar", "define_stream",
            "define_streamarray"):
        return None
    return self.generic_visit(node)
def has_replacement(callobj: Callable,
                    parent_object: Optional[Any] = None,
                    node: Optional[ast.AST] = None) -> bool:
    """
    Returns True if the function/operator replacement repository has a
    registered replacement for the called function described by a live
    object.

    Checks are made in this order:
    1. Attribute/method replacements on the type of ``parent_object``
       (if given).
    2. NumPy ufuncs.
    3. The fully-qualified name of ``callobj``.
    4. The name ``callobj`` is called with in ``node``.
    """
    from dace.frontend.common import op_repository as oprepo
    replacements = oprepo.Replacements

    # Attributes and methods
    if parent_object is not None:
        owner = type(parent_object).__name__
        member = callobj.__name__
        if replacements.get_attribute(owner, member) is not None:
            return True
        if replacements.get_method(owner, member) is not None:
            return True

    # NumPy ufuncs
    if isinstance(callobj, numpy.ufunc):
        return True

    # Functions, by fully-qualified name
    qualified_name = callobj.__module__ + '.' + callobj.__qualname__
    if replacements.get(qualified_name) is not None:
        return True

    # Also try the function as it is called in the AST
    return replacements.get(astutils.rname(node)) is not None
def visit_Subscript(self, node: ast.Subscript) -> Any:
    """
    Replaces a subscript by a symbol name, if ``self.conn_to_sym`` maps it.

    The lookup key is the subscript's ``astutils.rname``. On a match, a
    loading ``ast.Name`` with the mapped symbol name is returned. Otherwise
    the subscript is visited normally.
    """
    subscript_name = astutils.rname(node)
    if subscript_name not in self.conn_to_sym:
        return self.generic_visit(node)
    symbol_ref = ast.Name(id=self.conn_to_sym[subscript_name], ctx=ast.Load())
    return ast.copy_location(symbol_ref, node)
def visit_Subscript(self, node: ast.Subscript) -> Any:
    """
    Replaces promotable subscripts of connectors by fresh versioned names.

    For a subscript whose ``rname`` is an input (``self.iconns``) or output
    (``self.oconns``) connector:
    - its version counter in ``self.latest`` is bumped;
    - if its range is promotable, it is replaced by a ``<name>_<version>``
      Name (Load for inputs, Store for outputs), and the mapping is recorded
      in ``self.in_mapping``/``self.out_mapping``;
    - otherwise the connector is added to ``self.do_not_remove``.
    """
    conn = astutils.rname(node)
    if conn in self.iconns:
        is_output = False
    elif conn in self.oconns:
        is_output = True
    else:
        return self.generic_visit(node)

    self.latest[conn] += 1
    versioned = f'{conn}_{self.latest[conn]}'
    rng = subsets.Range(
        astutils.subscript_to_slice(node, self.sdfg.arrays)[1])

    # Check if range can be collapsed
    if not _range_is_promotable(rng, self.defined):
        self.do_not_remove.add(conn)
        return self.generic_visit(node)

    if is_output:
        self.out_mapping[versioned] = (conn, rng)
        context = ast.Store()
    else:
        self.in_mapping[versioned] = (conn, rng)
        context = ast.Load()
    return ast.copy_location(ast.Name(id=versioned, ctx=context), node)
def visit_Assign(self, node):
    """
    Rewrites assignments to accumulated variables as accumulation calls.

    If the first target's ``rname`` is a key of ``self.accumOnAssignment``,
    the statement becomes ``out = accum(out, value)``. Here ``out`` is the
    mapped array name, subscripted like the original target, or ``[:]``
    when no subscript is present. Other assignments are visited normally.
    """
    target = node.targets[0]
    var_name = rname(target)
    if var_name not in self.accumOnAssignment:
        return self.generic_visit(node)

    out_expr, accumulator = self.accumOnAssignment[var_name]
    if isinstance(target, ast.Subscript):
        out_expr = out_expr + '[' + unparse(target.slice) + ']'
    if '[' not in out_expr:
        # No explicit subscript: accumulate into the whole array
        out_expr += '[:]'

    code = '{out} = {accum}({out}, {val})'.format(out=out_expr,
                                                   accum=unparse(accumulator),
                                                   val=unparse(node.value))
    return _copy_location(ast.parse(code).body[0], node)
def visit_Attribute(self, node):
    """
    Replaces ``module.func`` references to allowed modules by a plain name.

    The part of ``rname(node)`` before the last dot is looked up in
    ``dtypes._ALLOWED_MODULES``. On a match, the attribute is replaced by a
    loading ``ast.Name`` whose id is the mapped prefix followed by the part
    after the last dot. Other attributes are visited normally.
    """
    attrname = rname(node)
    dot = attrname.rfind(".")
    module_name = attrname[:dot]
    func_name = attrname[dot + 1:]
    if module_name in dtypes._ALLOWED_MODULES:
        cppmodname = dtypes._ALLOWED_MODULES[module_name]
        # ``ctx`` must be an expr_context instance (ast.Load()), not the
        # ast.Load class itself
        return ast.copy_location(
            ast.Name(id=(cppmodname + func_name), ctx=ast.Load()), node)
    return self.generic_visit(node)
def visit_Call(self, node: ast.Call) -> Any:
    """
    Detects calls to attributes (e.g. ``mod.func(...)``).

    Sets ``self.detected`` and stops descending on the first such call. The
    exception is a single-constant-argument call on the ``dace`` module
    (e.g. ``dace.int64(2)``), which is visited normally, as are non-attribute
    calls.
    """
    callee = node.func
    if not isinstance(callee, ast.Attribute):
        return self.generic_visit(node)

    # Special case: calling attributed functions on constants
    # (e.g., dace.int64(2))
    is_dace_constant_cast = (len(node.args) == 1
                             and astutils.is_constant(node.args[0])
                             and astutils.rname(callee.value) == 'dace')
    if is_dace_constant_cast:
        return self.generic_visit(node)

    self.detected = True
    return None
def visit_Name(self, node: ast.Name): name = rname(node) if name not in self.memlets: return self.generic_visit(node) memlet, nc, wcr, dtype = self.memlets[name] if (isinstance(dtype, dtypes.pointer) and memlet.subset.num_elements() == 1): return ast.Name(id="(*{})".format(name), ctx=node.ctx) else: return self.generic_visit(node)
def visit_AugAssign(self, node):
    """
    Rejects augmented assignments (e.g. ``+=``) to subscripted memlets.

    :raises SyntaxError: If the target is a subscript whose ``rname`` is in
                         ``self.memlets``.
    """
    target = node.target
    if isinstance(target, ast.Subscript) and rname(target) in self.memlets:
        raise SyntaxError("Augmented assignments (e.g. +=) not allowed on "
                          "array memlets")
    return self.generic_visit(node)
def visit_Assign(self, node):
    """
    Rewrites assignments to output memlets as generated write statements.

    Assignments whose (first) target ``rname`` is not in ``self.memlets``
    are visited normally.

    - Non-subscript targets: if the memlet has a negative ``num_accesses``,
      the assignment becomes either ``write_and_resolve_expr(...)`` (with
      WCR) or ``__<target>.write(<value>);``. Otherwise it is visited
      normally.
    - Subscript targets: the same forms are used, with the unparsed index
      passed as indices / an extra write argument.

    :raises NotImplementedError: If the visited slice is not an
                                 ``ast.Index``.
    """
    target = rname(node.targets[0])
    if target not in self.memlets:
        return self.generic_visit(node)

    memlet, nc, wcr = self.memlets[target]
    value = self.visit(node.value)

    if not isinstance(node.targets[0], ast.Subscript):
        # Dynamic accesses -> every access counts
        try:
            # num_accesses may be symbolic; comparing it can raise TypeError
            if memlet is not None and memlet.num_accesses < 0:
                if wcr is not None:
                    newnode = ast.Name(id=write_and_resolve_expr(
                        self.sdfg, memlet, nc, '__' + target,
                        cppunparse.cppunparse(value, expr_semicolon=False)))
                else:
                    newnode = ast.Name(id="__%s.write(%s);" % (
                        target,
                        cppunparse.cppunparse(value, expr_semicolon=False),
                    ))
                return ast.copy_location(newnode, node)
        except TypeError:  # cannot determine truth value of Relational
            pass

        return self.generic_visit(node)

    # NOTE(review): ``slice`` shadows the builtin within this method.
    slice = self.visit(node.targets[0].slice)
    # NOTE(review): on Python >= 3.9 subscripts no longer wrap indices in
    # ast.Index, so this check may reject every subscript there — confirm
    # the supported Python versions.
    if not isinstance(slice, ast.Index):
        raise NotImplementedError("Range subscripting not implemented")

    if isinstance(slice.value, ast.Tuple):
        # Strip the outer delimiters unparse emits around a tuple
        subscript = unparse(slice)[1:-1]
    else:
        subscript = unparse(slice)

    if wcr is not None:
        newnode = ast.Name(id=write_and_resolve_expr(
            self.sdfg,
            memlet,
            nc,
            "__" + target,
            cppunparse.cppunparse(value, expr_semicolon=False),
            indices=subscript,
        ))
    else:
        newnode = ast.Name(id="__%s.write(%s, %s);" % (
            target,
            cppunparse.cppunparse(value, expr_semicolon=False),
            subscript,
        ))

    return ast.copy_location(newnode, node)
def visit_Call(self, node: ast.Call):
    """
    Records called names and resolves closures of nested SDFG-convertibles.

    Non-constant callees are recorded in ``self.seen_calls`` and visited
    normally. For constant callees (values substituted by the globals
    resolver), the original callee name from ``oldnode`` is recorded. If
    the value exposes ``__sdfg__`` and is not an ``SDFG``, its closure is
    resolved and appended to ``self.closure.nested_closures``.

    If resolution fails with anything other than ``DaceRecursionError``, a
    warning is emitted, the value's entry is removed from
    ``closure_sdfgs``, and the original callee expression is restored so
    that the call is processed as before substitution.
    """
    # Only parse calls to parsed SDFGConvertibles
    if not isinstance(node.func, (ast.Num, ast.Constant)):
        self.seen_calls.add(astutils.rname(node.func))
        return self.generic_visit(node)

    if hasattr(node.func, 'oldnode'):
        if isinstance(node.func.oldnode, ast.Call):
            self.seen_calls.add(astutils.rname(node.func.oldnode.func))
        else:
            self.seen_calls.add(astutils.rname(node.func.oldnode))

    if isinstance(node.func, ast.Num):
        value = node.func.n
    else:
        value = node.func.value

    if not hasattr(value, '__sdfg__') or isinstance(value, SDFG):
        return self.generic_visit(node)

    constant_args = self._eval_args(node)

    # Resolve nested closure as necessary
    qualname = None
    try:
        qualname = next(k for k, v in self.closure.closure_sdfgs.items()
                        if v is value)
        self.seen_calls.add(qualname)
        if hasattr(value, 'closure_resolver'):
            self.closure.nested_closures.append(
                (qualname, value.closure_resolver(constant_args,
                                                  self.closure)))
        else:
            self.closure.nested_closures.append((qualname, SDFGClosure()))
    except DaceRecursionError:  # Parsing failed in a nested context, raise
        raise
    except Exception as ex:  # Parsing failed (anything can happen here)
        warnings.warn(f'Parsing SDFGConvertible {value} failed: {ex}')
        if qualname in self.closure.closure_sdfgs:
            del self.closure.closure_sdfgs[qualname]

        # Return old call AST instead. ``oldnode`` is either the whole
        # original Call (restore its callee) or the original callee itself,
        # matching the bookkeeping above. Without it there is nothing to
        # restore, so the original error is propagated.
        oldnode = getattr(node.func, 'oldnode', None)
        if oldnode is None:
            raise
        node.func = oldnode.func if isinstance(oldnode, ast.Call) else oldnode

    return self.generic_visit(node)
def visit_Subscript(self, node: ast.Subscript) -> ast.Subscript:
    """
    Offsets subscripts of ``self.to_refine`` by ``self.subset``.

    The accessed range (without the array name) is parsed into a
    ``subsets.Range``. It is offset via ``Range.offset(self.subset, True,
    self.indices)`` and re-emitted as a subscript of ``self.to_refine``.
    Subscripts of other data are visited normally.
    """
    if astutils.rname(node.value) == self.to_refine:
        accessed = subsets.Range(
            astutils.subscript_to_slice(node,
                                        self.sdfg.arrays,
                                        without_array=True))
        accessed.offset(self.subset, True, self.indices)
        refined = astutils.slice_to_subscript(self.to_refine, accessed)
        return ast.copy_location(refined, node)
    return self.generic_visit(node)
def visit_Call(self, node: ast.Call):
    """
    Records called names and resolves closures of nested SDFG-convertibles.

    Non-constant callees are recorded (unparsed) in ``self.seen_calls`` and
    visited normally. For constant callees (values substituted by the
    globals resolver), the original callee from ``oldnode`` is recorded. If
    the value exposes ``__sdfg__`` and is not an ``SDFG``, its closure is
    resolved and appended to ``self.closure.nested_closures``. The qualified
    name comes from ``closure_sdfgs`` (keyed by ``id(value)``) or from the
    callee's ``qualname`` attribute.

    On failure other than ``DaceRecursionError``, a warning is emitted. The
    error is re-raised if the ``frontend.raise_nested_parsing_errors``
    config entry is set. Otherwise the value is removed from
    ``closure_sdfgs`` and the original callee expression is restored.
    """
    # Only parse calls to parsed SDFGConvertibles
    if not isinstance(node.func, (ast.Num, ast.Constant)):
        self.seen_calls.add(astutils.unparse(node.func))
        return self.generic_visit(node)

    if hasattr(node.func, 'oldnode'):
        if isinstance(node.func.oldnode, ast.Call):
            self.seen_calls.add(astutils.unparse(node.func.oldnode.func))
        else:
            self.seen_calls.add(astutils.rname(node.func.oldnode))

    if isinstance(node.func, ast.Num):
        value = node.func.n
    else:
        value = node.func.value

    if not hasattr(value, '__sdfg__') or isinstance(value, SDFG):
        return self.generic_visit(node)

    constant_args = self._eval_args(node)

    # Resolve nested closure as necessary
    qualname = None
    try:
        if id(value) in self.closure.closure_sdfgs:
            qualname, _ = self.closure.closure_sdfgs[id(value)]
        elif hasattr(node.func, 'qualname'):
            qualname = node.func.qualname
        self.seen_calls.add(qualname)
        if hasattr(value, 'closure_resolver'):
            self.closure.nested_closures.append(
                (qualname, value.closure_resolver(constant_args,
                                                  self.closure)))
        else:
            self.closure.nested_closures.append((qualname, SDFGClosure()))
    except DaceRecursionError:  # Parsing failed in a nested context, raise
        raise
    except Exception as ex:  # Parsing failed (anything can happen here)
        optional_qname = ''
        if qualname is not None:
            optional_qname = f' ("{qualname}")'
        warnings.warn(
            f'Preprocessing SDFGConvertible {value}{optional_qname} failed with {type(ex).__name__}: {ex}'
        )
        if Config.get_bool('frontend', 'raise_nested_parsing_errors'):
            raise
        if id(value) in self.closure.closure_sdfgs:
            del self.closure.closure_sdfgs[id(value)]

        # Return old call AST instead. ``oldnode`` is either the whole
        # original Call (restore its callee) or the original callee itself,
        # matching the bookkeeping above.
        if not hasattr(node.func, 'oldnode'):
            raise
        oldnode = node.func.oldnode
        node.func = oldnode.func if isinstance(oldnode, ast.Call) else oldnode

    return self.generic_visit(node)
def visit_Call(self, node: ast.Call):
    """
    Marks struct-initializer calls for code generation.

    A call whose callee ``rname`` is a key of ``self._structs`` must pass
    exactly the struct's fields as keywords. Its callee is then replaced by
    the name ``__DACESTRUCT_<struct name>``, and the (mutated) call is
    returned without visiting its children. Other calls are visited
    normally.

    :raises SyntaxError: If the keyword names do not match the struct fields.
    """
    # Struct initializer
    callee = astutils.rname(node.func)
    if callee not in self._structs:
        return self.generic_visit(node)

    struct = self._structs[callee]
    given = {astutils.rname(kw.arg): kw.value for kw in node.keywords}
    if sorted(given.keys()) != sorted(struct.fields.keys()):
        raise SyntaxError('Mismatch in fields in struct definition')

    marker = ast.Name(id='__DACESTRUCT_' + struct.name, ctx=ast.Load())
    node.func = ast.copy_location(marker, node.func)
    return node
def visit_Subscript(self, node):
    """
    Emits subscripts of memlets or constants as ``<name>[<C++ index>]``.

    The index is built by ``self._subscript_expr`` and converted with
    ``sym2cpp``. Subscripts of other names are visited normally.
    """
    target = rname(node)
    if target in self.memlets or target in self.constants:
        index_expr = self._subscript_expr(node.slice, target)
        # Emitted as a Name rather than a Subscript, as otherwise the visitor
        # would recursively descend into the new expression and modify it
        # erroneously.
        replacement = ast.Name(id="%s[%s]" % (target, sym2cpp(index_expr)))
        return ast.copy_location(replacement, node)
    return self.generic_visit(node)
def visit_Call(self, node):
    """
    Emits struct-initializer calls as ``<Type> { .field = value, ... }``.

    Applies to calls of a plain name that starts with ``__DACESTRUCT_`` (the
    prefix is stripped from the emitted type name) or that is a key of
    ``self._structs``. Keywords are sorted by name. Each value is converted
    with ``cppunparse.pyexpr2cpp``. The result is a loading ``ast.Name``.
    Other calls are visited normally.
    """
    if isinstance(node.func, ast.Name) and (
            node.func.id.startswith('__DACESTRUCT_')
            or node.func.id in self._structs):
        fields = ', '.join([
            '.%s = %s' % (rname(arg.arg), cppunparse.pyexpr2cpp(arg.value))
            for arg in sorted(node.keywords, key=lambda x: x.arg)
        ])

        tname = node.func.id
        if node.func.id.startswith('__DACESTRUCT_'):
            tname = node.func.id[len('__DACESTRUCT_'):]

        # ``ctx`` must be an expr_context instance (ast.Load()), not the
        # ast.Load class itself
        return ast.copy_location(
            ast.Name(id="%s { %s }" % (tname, fields), ctx=ast.Load()), node)
    return self.generic_visit(node)
def ParseMemlet(visitor,
                defined_arrays_and_symbols: Dict[str, Any],
                node: MemletType,
                parsed_slice: Any = None) -> MemletExpr:
    """
    Parses a memlet expression such as ``A[i]``, ``A(2)[...]``, ``A(300)``
    or ``A(1, sum)[:]``.

    When the data reference is called, its first argument gives the number
    of accesses and an optional second argument gives the write-conflict
    resolution. Without a call, the number of accesses defaults to the number
    of elements in the parsed subset.

    :param visitor: Visitor object used for error reporting.
    :param defined_arrays_and_symbols: Mapping of known data/symbol names.
    :param node: Memlet AST node.
    :param parsed_slice: Optional pre-parsed slice, forwarded to
                         ``parse_memlet_subset``.
    :return: A ``MemletExpr`` with name, access count, WCR, subset, new axes
             and array dimensions.
    :raises DaceSyntaxError: If the data is undefined or the call has fewer
                             than 1 or more than 3 arguments.
    """
    das = defined_arrays_and_symbols
    arrname = rname(node)
    if arrname not in das:
        raise DaceSyntaxError(visitor, node,
                              'Use of undefined data "%s" in memlet' % arrname)
    array = das[arrname]

    # Locate the call carrying the access count / WCR, if any
    if isinstance(node, ast.Call):
        annotation = node
    elif isinstance(node, ast.Subscript) and isinstance(node.value, ast.Call):
        annotation = node.value
    else:
        annotation = None

    num_accesses = None
    write_conflict_resolution = None
    if annotation is not None:
        call_args = annotation.args
        if not 1 <= len(call_args) <= 3:
            raise DaceSyntaxError(
                visitor, node,
                'Number of accesses in memlet must be a number, symbolic '
                'expression, or -1 (dynamic)')
        num_accesses = pyexpr_to_symbolic(das, call_args[0])
        if len(call_args) >= 2:
            write_conflict_resolution = call_args[1]

    subset, new_axes, arrdims = parse_memlet_subset(array, node, das,
                                                    parsed_slice)

    # If undefined, default number of accesses is the slice size
    if num_accesses is None:
        num_accesses = subset.num_elements()

    return MemletExpr(arrname, num_accesses, write_conflict_resolution,
                      subset, new_axes, arrdims)
def visit_Subscript(self, node):
    """
    Rewrites subscripts of memlets and constants for code generation.

    - Constants become ``name[<ndslice>]``, with the index built by
      ``DaCeKeywordRemover.ndslice_cpp`` from the constant's shape.
    - Memlets become ``__name(<indices>)``.

    Other subscripts are visited normally.

    :raises NotImplementedError: If the visited slice is not an
                                 ``ast.Index``.
    """
    target = rname(node)
    if not (target in self.memlets or target in self.constants):
        return self.generic_visit(node)

    index_node = self.visit(node.slice)
    if not isinstance(index_node, ast.Index):
        raise NotImplementedError("Range subscripting not implemented")

    is_tuple = isinstance(index_node.value, ast.Tuple)
    subscript = unparse(index_node)
    if is_tuple:
        # Strip the outer delimiters unparse emits around a tuple
        subscript = subscript[1:-1]

    if target in self.constants:
        slice_str = DaCeKeywordRemover.ndslice_cpp(
            subscript.split(", "), self.constants[target].shape)
        code = "%s[%s]" % (target, slice_str)
    else:
        code = "__%s(%s)" % (target, subscript)
    replacement = ast.parse(code).body[0].value
    return ast.copy_location(replacement, node)
def visit_FunctionDef(self, node):
    """
    Rewrites a decorated DaCe function definition into plain Python
    statements.

    - First visit (``self.curprim is None``): the outermost program. Its
      module name is taken from the first decorator and the decorator is
      stripped.
    - Nested functions decorated with a primitive (``map``, ``reduce``,
      ``consume``, ``tasklet``, ``iterate``, ``loop``, ``conditional``):
      the next child of the current primitive becomes the current
      primitive, and the definition is replaced according to its type:
        - map/iterate -> ``for`` over ``<module>.ndrange(...)``
        - consume -> ``while len(stream) > 0`` with a ``popleft()``
        - tasklet -> ``if True:`` block
        - loop -> ``while`` on the decorator's first argument
        - reduce -> the function is kept and a
          ``<module>.simulator.simulate_reduce(...)`` call is appended after
          it.
    - Any other decorator, or none: visited normally.

    The body is rebuilt statement by statement. Top-level expressions go
    through ``self.VisitTopLevelExpr``, which may contribute statements to
    append at the end of the body or to prepend before the node.

    :return: A list of statements replacing ``node``, or the result of
             ``generic_visit`` for undecorated/unknown functions.
    :raises RuntimeError: If the current primitive type has no rewrite
                          (e.g. a ``conditional`` primitive).
    """
    after_nodes = []

    if self.curprim is None:
        self.curprim = self.pdp
        self.curchild = -1
        if isinstance(node.decorator_list[0], ast.Call):
            self.module_name = node.decorator_list[0].func.value.id
        else:
            self.module_name = node.decorator_list[0].value.id
        # Strip decorator
        del node.decorator_list[0]

        # NOTE(review): captured after curprim/curchild were reassigned, so
        # the restore at the end leaves curprim == self.pdp (not None) —
        # confirm this is intended.
        oldchild = self.curchild
        oldprim = self.curprim

    else:
        if len(node.decorator_list) == 0:
            return self.generic_visit(node)
        dec = node.decorator_list[0]
        if isinstance(dec, ast.Call):
            decname = rname(dec.func.attr)
        else:
            decname = rname(dec.attr)

        if decname in [
                'map', 'reduce', 'consume', 'tasklet', 'iterate', 'loop',
                'conditional'
        ]:
            # Descend into the next child primitive; the previous state is
            # restored after the body has been processed.
            self.curchild += 1

            oldchild = self.curchild
            oldprim = self.curprim

            self.curprim = self.curprim.children[self.curchild]
            self.curchild = -1

            if isinstance(self.curprim, astnodes._MapNode):
                newnode = _copy_location(
                    ast.For(target=ast.Tuple(ctx=ast.Store(),
                                             elts=[
                                                 ast.Name(id=name,
                                                          ctx=ast.Store())
                                                 for name in self.curprim.params
                                             ]),
                            iter=ast.parse(
                                '%s.ndrange(%s)' %
                                (self.module_name,
                                 self.curprim.range.pystr())).body[0].value,
                            body=node.body,
                            orelse=[]), node)
                node = newnode
            elif isinstance(self.curprim, astnodes._ConsumeNode):
                stream = self.curprim.stream
                if isinstance(self.curprim.stream, ast.AST):
                    stream = unparse(self.curprim.stream)
                # Unindexed streams are consumed from their first element
                if '[' not in stream:
                    stream += '[0]'

                newnode = _copy_location(
                    ast.While(test=ast.parse('len(%s) > 0' %
                                             stream).body[0].value,
                              body=node.body,
                              orelse=[]), node)
                node = newnode
                # Pop one element into the consume parameter per iteration
                node.body.insert(
                    0,
                    _copy_location(
                        ast.parse('%s = %s.popleft()' %
                                  (str(self.curprim.params[0]),
                                   stream)).body[0], node))

            elif isinstance(self.curprim, astnodes._TaskletNode):
                # Strip decorator
                del node.decorator_list[0]

                newnode = _copy_location(ast.parse('if True: pass').body[0],
                                         node)
                newnode.body = node.body
                newnode = ast.fix_missing_locations(newnode)
                node = newnode
            elif isinstance(self.curprim, astnodes._ReduceNode):
                # NOTE(review): in_memlet and out_memlet are not used below.
                in_memlet = self.curprim.inputs['input']
                out_memlet = self.curprim.outputs['output']

                # Create reduction call
                params = [unparse(p) for p in node.decorator_list[0].args]
                params.extend([
                    unparse(kp) for kp in node.decorator_list[0].keywords
                ])
                reduction = ast.parse(
                    '%s.simulator.simulate_reduce(%s, %s)' %
                    (self.module_name, node.name, ', '.join(params))).body[0]
                reduction = _copy_location(reduction, node)
                # Place the call's line number after the function body
                reduction = ast.increment_lineno(reduction,
                                                 len(node.body) + 1)
                reduction = ast.fix_missing_locations(reduction)

                # Strip decorator
                del node.decorator_list[0]

                after_nodes.append(reduction)
            elif isinstance(self.curprim, astnodes._IterateNode):
                newnode = _copy_location(
                    ast.For(target=ast.Tuple(ctx=ast.Store(),
                                             elts=[
                                                 ast.Name(id=name,
                                                          ctx=ast.Store())
                                                 for name in self.curprim.params
                                             ]),
                            iter=ast.parse(
                                '%s.ndrange(%s)' %
                                (self.module_name,
                                 self.curprim.range.pystr())).body[0].value,
                            body=node.body,
                            orelse=[]), node)
                newnode = ast.fix_missing_locations(newnode)
                node = newnode
            elif isinstance(self.curprim, astnodes._LoopNode):
                newnode = _copy_location(
                    ast.While(test=node.decorator_list[0].args[0],
                              body=node.body,
                              orelse=[]), node)
                newnode = ast.fix_missing_locations(newnode)
                node = newnode
            else:
                raise RuntimeError('Unimplemented primitive %s' % decname)
        else:
            return self.generic_visit(node)

    newbody = []
    end_stmts = []
    substitute_stmts = []
    # Incrementally build new body from original body
    for stmt in node.body:
        if isinstance(stmt, ast.Expr):
            res, append, prepend = self.VisitTopLevelExpr(stmt)
            if res is not None:
                newbody.append(res)
            if append is not None:
                end_stmts.extend(append)
            if prepend is not None:
                substitute_stmts.extend(prepend)
        else:
            subnodes = self.visit(stmt)
            if subnodes is not None:
                if isinstance(subnodes, list):
                    newbody.extend(subnodes)
                else:
                    newbody.append(subnodes)
    node.body = newbody + end_stmts

    # Restore the enclosing primitive's traversal state
    self.curchild = oldchild
    self.curprim = oldprim

    substitute_stmts.append(node)
    if len(after_nodes) > 0:
        return substitute_stmts + after_nodes
    return substitute_stmts
def visit_Call(self, node):
    """
    Renames ``<obj>.push(...)`` method calls to ``<obj>.append(...)``.

    Only calls whose callee is an attribute named exactly ``push`` are
    rewritten. Arguments and nested expressions are then visited normally.
    """
    callee = node.func
    # Match the method name exactly. A substring test on the dotted name
    # would also rewrite e.g. ``q.pushx(...)``, or rename the wrong
    # attribute in ``q.push.clear()``.
    if isinstance(callee, ast.Attribute) and callee.attr == 'push':
        callee.attr = 'append'
    return self.generic_visit(node)
def visit_Assign(self, node):
    """
    Rewrites assignments to memlet outputs as generated C++ statements.

    The target checked is the last entry of ``node.targets``. Assignments
    whose target ``rname`` is not in ``self.memlets`` are visited normally.

    - Non-subscript target with a dynamic memlet or a Stream container:
        - with WCR, ``node.value`` is replaced by the
          ``write_and_resolve_expr`` expression and ``node`` is returned;
        - for streams, ``<data>.push(<value>);`` is emitted;
        - otherwise a plain ``<data or array expr> = <value>;`` is emitted.
    - Other non-subscript targets are visited normally.
    - Subscript targets: the indices from ``self._subscript_expr`` are
      emitted as ``target[idx] = value;``, or via
      ``write_and_resolve_expr`` with WCR.
    """
    target = rname(node.targets[-1])
    if target not in self.memlets:
        return self.generic_visit(node)

    memlet, nc, wcr, dtype = self.memlets[target]
    value = self.visit(node.value)

    if not isinstance(node.targets[-1], ast.Subscript):
        # Dynamic accesses or streams -> every access counts
        try:
            # memlet.dynamic may be symbolic; its truth test can raise TypeError
            if memlet and memlet.data and (memlet.dynamic or isinstance(
                    self.sdfg.arrays[memlet.data], data.Stream)):
                if wcr is not None:
                    newnode = ast.Name(
                        id=self.codegen.write_and_resolve_expr(
                            self.sdfg,
                            memlet,
                            nc,
                            target,
                            cppunparse.cppunparse(value,
                                                  expr_semicolon=False),
                            dtype=dtype))
                    # Only the right-hand side is replaced here; the
                    # Assign node itself is kept
                    node.value = ast.copy_location(newnode, node.value)
                    return node
                elif isinstance(self.sdfg.arrays[memlet.data], data.Stream):
                    newnode = ast.Name(id="%s.push(%s);" % (
                        memlet.data,
                        cppunparse.cppunparse(value, expr_semicolon=False),
                    ))
                else:
                    var_type, ctypedef = self.codegen._dispatcher.defined_vars.get(
                        memlet.data)
                    if var_type == DefinedType.Scalar:
                        newnode = ast.Name(id="%s = %s;" % (
                            memlet.data,
                            cppunparse.cppunparse(value,
                                                  expr_semicolon=False),
                        ))
                    else:
                        newnode = ast.Name(id="%s = %s;" % (
                            cpp_array_expr(self.sdfg, memlet),
                            cppunparse.cppunparse(value,
                                                  expr_semicolon=False),
                        ))

                return self._replace_assignment(newnode, node)
        except TypeError:  # cannot determine truth value of Relational
            pass

        return self.generic_visit(node)

    subscript = self._subscript_expr(node.targets[-1].slice, target)

    if wcr is not None:
        newnode = ast.Name(id=self.codegen.write_and_resolve_expr(
            self.sdfg,
            memlet,
            nc,
            target,
            cppunparse.cppunparse(value, expr_semicolon=False),
            indices=sym2cpp(subscript),
            dtype=dtype) + ';')
    else:
        newnode = ast.Name(
            id="%s[%s] = %s;" %
            (target, sym2cpp(subscript),
             cppunparse.cppunparse(value, expr_semicolon=False)))

    return self._replace_assignment(newnode, node)
def _is_promotable_tasklet_input(sdfg: sd.SDFG, state, tinput) -> bool:
    """
    Returns True if a tasklet input edge reads a single element from a
    non-stream access node that has no inputs of its own in ``state``.
    """
    # If inputs to tasklets are not arrays, skip
    if not isinstance(tinput.src, nodes.AccessNode):
        return False
    if isinstance(sdfg.arrays[tinput.src.data], dt.Stream):
        return False
    # If input is not a single-element memlet, skip
    if tinput.data.dynamic or tinput.data.subset.num_elements() != 1:
        return False
    # If input array has inputs of its own (cannot promote within same
    # state), skip
    return state.in_degree(tinput.src) == 0


def _tasklet_code_is_promotable(edge) -> bool:
    """
    Returns True if the tasklet at ``edge.src`` is a single assignment that
    scalar-to-symbol promotion can handle.

    Python tasklets must be exactly one ``ast.Assign`` to the single target
    ``edge.src_conn``, without attributed calls (e.g. ``dace.int64(...)``)
    in the value. C++ tasklets must match a single ``name = ...;`` statement
    whose Python translation parses. Other languages are rejected.
    """
    cb: props.CodeBlock = edge.src.code
    if cb.language is dtypes.Language.Python:
        # Check that tasklets have exactly one statement. Using ``!= 1``
        # also rejects empty tasklets, which would otherwise raise
        # IndexError on ``cb.code[0]``.
        if len(cb.code) != 1 or not isinstance(cb.code[0], ast.Assign):
            return False
        stmt = cb.code[0]
        # Ensure the candidate is assigned to
        if (len(stmt.targets) != 1
                or astutils.rname(stmt.targets[0]) != edge.src_conn):
            return False
        # Ensure that the candidate is not assigned through an "attribute"
        # call, e.g., "dace.int64". These calls are not supported currently
        # by the SymPy-based symbolic module.
        detector = AttributedCallDetector()
        detector.visit(stmt.value)
        return not detector.detected

    if cb.language is dtypes.Language.CPP:
        # Try to match a single C assignment
        cstr = cb.as_string.strip()
        # Since we cannot remove subscripts from C++ tasklets,
        # if the type of the data is an array we will also skip
        if re.match(r'^[a-zA-Z_][a-zA-Z_0-9]*\s*=.*;$', cstr) is None:
            return False
        newcode = translate_cpp_tasklet_to_python(cstr)
        try:
            ast.parse(str(newcode))
        except SyntaxError:
            # If we cannot parse the expression to pythonize it, we cannot
            # promote the candidate
            return False
        return True

    # Other languages are currently unsupported
    return False


def find_promotable_scalars(sdfg: sd.SDFG,
                            transients_only: bool = True,
                            integers_only: bool = True) -> Set[str]:
    """
    Finds scalars that can be promoted to symbols in the given SDFG.
    Conditions for matching a scalar for symbol-promotion are as follows:

        * Size of data must be 1, it must not be a stream and must be
          transient (unless ``transients_only`` is False), and it must not
          have a persistent lifetime.
        * Scalar must not be read by a library node.
        * Only inputs to candidate scalars must be either arrays or tasklets.
        * All tasklets that lead to it must have one statement, one output,
          and may have zero or more **array** inputs and not be in a scope.
        * Scalar must not be accessed with a write-conflict resolution.
        * Scalar must not be written to more than once in a state.
        * If scalar is not integral (i.e., int type), it must also appear in
          an inter-state condition to be promotable.

    These conditions must apply on all occurrences of the scalar in order for
    it to be promotable.

    :param sdfg: The SDFG to query.
    :param transients_only: If False, also considers global data descriptors
                            (e.g., arguments).
    :param integers_only: If False, also considers non-integral descriptors
                          for promotion.
    :return: A set of promotable scalar names.
    """
    # Keep set of active candidates
    candidates: Set[str] = set()

    # General array checks
    for aname, desc in sdfg.arrays.items():
        if (transients_only and not desc.transient) or isinstance(
                desc, dt.Stream):
            continue
        if desc.total_size != 1:
            continue
        if desc.lifetime is dtypes.AllocationLifetime.Persistent:
            continue
        candidates.add(aname)

    # Check all occurrences of candidates in SDFG and filter out
    candidates_seen: Set[str] = set()
    for state in sdfg.nodes():
        candidates_in_state: Set[str] = set()

        for node in state.nodes():
            if not isinstance(node, nodes.AccessNode):
                continue
            candidate = node.data
            if candidate not in candidates:
                continue

            # If candidate is read by a library node, skip. This runs before
            # the read-only shortcut below: reads typically go through access
            # nodes without inputs, which the shortcut would otherwise exempt
            # from this check.
            if any(
                    isinstance(e.dst, nodes.LibraryNode)
                    for oe in state.out_edges(node)
                    for e in state.memlet_tree(oe)):
                candidates.remove(candidate)
                continue

            # If candidate is read-only, continue normally
            if state.in_degree(node) == 0:
                continue

            # Candidate may only be accessed in a top-level scope
            if state.entry_node(node) is not None:
                candidates.remove(candidate)
                continue

            # Candidate may only be written to once within a state
            if candidate in candidates_in_state and state.in_degree(
                    node) == 1:
                candidates.remove(candidate)
                continue
            candidates_in_state.add(candidate)

            # If input is not a single array nor tasklet, skip
            if state.in_degree(node) > 1:
                candidates.remove(candidate)
                continue
            edge = state.in_edges(node)[0]

            # Edge must not be WCR
            if edge.data.wcr is not None:
                candidates.remove(candidate)
                continue

            # Check inputs
            if isinstance(edge.src, nodes.AccessNode):
                # If input is array, ensure it is not a stream and that no
                # inputs exist to the array
                if (isinstance(sdfg.arrays[edge.src.data], dt.Stream)
                        or state.in_degree(edge.src) > 0):
                    candidates.remove(candidate)
            elif isinstance(edge.src, nodes.Tasklet):
                # The tasklet must have a single output, only simple array
                # inputs, and a single promotable assignment
                if (state.out_degree(edge.src) > 1 or not all(
                        _is_promotable_tasklet_input(sdfg, state, tinput)
                        for tinput in state.in_edges(edge.src))
                        or not _tasklet_code_is_promotable(edge)):
                    candidates.remove(candidate)
            else:
                # If input is not an acceptable node type, skip
                candidates.remove(candidate)

        candidates_seen |= candidates_in_state

    # Filter out non-integral symbols that do not appear in inter-state edges
    interstate_symbols = set()
    for edge in sdfg.edges():
        interstate_symbols |= edge.data.free_symbols
    if integers_only:
        for candidate in (candidates - interstate_symbols):
            if sdfg.arrays[candidate].dtype not in dtypes.INTEGER_TYPES:
                candidates.remove(candidate)

    # Only keep candidates that were found in SDFG
    candidates &= (candidates_seen | interstate_symbols)

    return candidates
def global_value_to_node(self,
                         value,
                         parent_node,
                         qualname,
                         recurse=False,
                         detect_callables=False):
    """
    Converts a live global value into an AST node that replaces
    ``parent_node``, registering it in ``self.closure`` as needed.

    - list/tuple: converted element-wise (one level only, and only if
      ``recurse`` is True). Returns None if any element cannot be converted.
    - ``symbolic.symbol``: a Name with the symbol's name.
    - Constants, SDFGs and objects with ``__sdfg__``: stored in
      ``closure_sdfgs`` / ``closure_constants`` under ``qualname`` and
      emitted as a constant node carrying a deep copy of ``parent_node`` in
      ``.oldnode``.
    - Objects whose ``__call__`` has ``__sdfg__`` (with ``detect_callables``):
      converted through ``value.__call__``.
    - ``numpy.ndarray``: registered once per object id as a closure array and
      emitted as a Name.
    - Other callables (with ``detect_callables``): returns None if a
      replacement exists. Otherwise the callable is registered as a
      callback and parsing as a DaCe program is attempted, with a fallback
      to the callback Name on failure.
    - Anything else: None.

    :param value: The live Python value.
    :param parent_node: The AST node being replaced (used for locations,
                        ``ctx`` of lists/tuples, and callback names).
    :param qualname: Qualified name of the value, used as the closure key.
    :param recurse: Whether a list/tuple may be expanded.
    :param detect_callables: Whether callables should be converted.
    :return: The replacement AST node, or None if the value is unsupported.
    """
    # if recurse is false, we don't allow recursion into lists
    # this should not happen anyway; the globals dict should only contain
    # single "level" lists
    if not recurse and isinstance(value, (list, tuple)):
        # bail after more than one level of lists
        return None

    if isinstance(value, list):
        # Elements are converted with recurse=False, so nested lists bail
        elts = [
            self.global_value_to_node(v,
                                      parent_node,
                                      qualname + f'[{i}]',
                                      detect_callables=detect_callables)
            for i, v in enumerate(value)
        ]
        if any(e is None for e in elts):
            return None
        # NOTE(review): requires parent_node to be non-None with a ``ctx``
        newnode = ast.List(elts=elts, ctx=parent_node.ctx)
    elif isinstance(value, tuple):
        elts = [
            self.global_value_to_node(v,
                                      parent_node,
                                      qualname + f'[{i}]',
                                      detect_callables=detect_callables)
            for i, v in enumerate(value)
        ]
        if any(e is None for e in elts):
            return None
        newnode = ast.Tuple(elts=elts, ctx=parent_node.ctx)
    elif isinstance(value, symbolic.symbol):
        # Symbols resolve to the symbol name
        newnode = ast.Name(id=value.name, ctx=ast.Load())
    elif (dtypes.isconstant(value) or isinstance(value, SDFG)
          or hasattr(value, '__sdfg__')):
        # Could be a constant, an SDFG, or SDFG-convertible object
        if isinstance(value, SDFG) or hasattr(value, '__sdfg__'):
            self.closure.closure_sdfgs[qualname] = value
        else:
            self.closure.closure_constants[qualname] = value

        # Compatibility check since Python changed their AST nodes
        if sys.version_info >= (3, 8):
            newnode = ast.Constant(value=value, kind='')
        else:
            if value is None:
                newnode = ast.NameConstant(value=None)
            elif isinstance(value, str):
                newnode = ast.Str(s=value)
            else:
                newnode = ast.Num(n=value)

        # Keep the replaced node so later passes can restore the original
        newnode.oldnode = copy.deepcopy(parent_node)

    elif detect_callables and hasattr(value, '__call__') and hasattr(
            value.__call__, '__sdfg__'):
        return self.global_value_to_node(value.__call__, parent_node,
                                         qualname, recurse, detect_callables)
    elif isinstance(value, numpy.ndarray):
        # Arrays need to be stored as a new name and fed as an argument
        if id(value) in self.closure.array_mapping:
            # Same array object already registered: reuse its name
            arrname = self.closure.array_mapping[id(value)]
        else:
            arrname = self._qualname_to_array_name(qualname)
            desc = data.create_datadescriptor(value)
            # The lambda re-evaluates ``qualname`` in the globals on demand
            self.closure.closure_arrays[arrname] = (
                qualname, desc, lambda: eval(qualname, self.globals), False)
            self.closure.array_mapping[id(value)] = arrname

        newnode = ast.Name(id=arrname, ctx=ast.Load())
    elif detect_callables and callable(value):
        # Try parsing the function as a dace function/method
        newnode = None
        try:
            from dace.frontend.python import parser  # Avoid import loops

            parent_object = None
            if hasattr(value, '__self__'):
                parent_object = value.__self__

            # If it is a callable object
            if (not inspect.isfunction(value)
                    and not inspect.ismethod(value)
                    and not inspect.isbuiltin(value)
                    and hasattr(value, '__call__')):
                parent_object = value
                value = value.__call__

            # Replacements take precedence over auto-parsing
            try:
                if has_replacement(value, parent_object, parent_node):
                    return None
            except Exception:
                # Replacement lookup is best-effort; fall through to parsing
                pass

            # Store the handle to the original callable, in case parsing fails
            cbqualname = astutils.rname(parent_node)
            cbname = self._qualname_to_array_name(cbqualname, prefix='')
            self.closure.callbacks[cbname] = (cbqualname, value, False)

            # From this point on, any failure will result in a callback
            newnode = ast.Name(id=cbname, ctx=ast.Load())

            # Decorated or functions with missing source code
            sast, _, _, _ = astutils.function_to_ast(value)
            if len(sast.body[0].decorator_list) > 0:
                return newnode

            parsed = parser.DaceProgram(value, [], {}, False,
                                        dtypes.DeviceType.CPU)
            # If method, add the first argument (which disappears due to
            # being a bound method) and the method's object
            if parent_object is not None:
                parsed.methodobj = parent_object
                parsed.objname = inspect.getfullargspec(value).args[0]

            res = self.global_value_to_node(parsed, parent_node, qualname,
                                            recurse, detect_callables)
            # Keep callback in callbacks in case of parsing failure
            # del self.closure.callbacks[cbname]
            return res
        except Exception:  # Parsing failed (almost any exception can occur)
            # newnode is None if the failure happened before the callback
            # was registered
            return newnode
    else:
        return None

    if parent_node is not None:
        return ast.copy_location(newnode, parent_node)
    else:
        return newnode