def parse_memlet(visitor, src: MemletType, dst: MemletType, defined_arrays_and_symbols: Dict[str, data.Data]):
    """Parse both sides of a memlet assignment into a local name and a memlet.

    A side is treated as a local variable when it is a bare ``ast.Name``
    whose ``rname`` is not in ``defined_arrays_and_symbols``; any other side
    is parsed with ``ParseMemlet``. For the call to succeed, exactly one side
    must be such a local variable.

    :param visitor: Visitor object, passed through for error reporting.
    :param src: AST of the memlet source.
    :param dst: AST of the memlet destination.
    :param defined_arrays_and_symbols: Names of data containers and symbols
                                       visible at this point.
    :return: A tuple ``(local_name, memlet)``. ``local_name`` is the name of
             the local side. ``memlet`` is built from the parsed
             expression of the other side.
    :raises DaceSyntaxError: If both sides are local variables.
    :raises NotImplementedError: If both sides are memlet expressions.
    """
    known = defined_arrays_and_symbols
    local_name = None
    src_expr = None
    dst_expr = None

    # Source side: either a local variable name or a memlet expression
    if isinstance(src, ast.Name) and rname(src) not in known:
        local_name = rname(src)
    else:
        src_expr = ParseMemlet(visitor, known, src)

    # Destination side: same rule, but at most one side may be local
    dst_is_local = isinstance(dst, ast.Name) and rname(dst) not in known
    if not dst_is_local:
        dst_expr = ParseMemlet(visitor, known, dst)
    elif local_name is not None:
        raise DaceSyntaxError(
            visitor, src,
            'Memlet source and destination cannot both be local variables')
    else:
        local_name = rname(dst)

    if src_expr is not None and dst_expr is not None:
        # Would require creating two memlets
        raise NotImplementedError

    chosen = dst_expr if src_expr is None else src_expr
    # NOTE(review): the literal 1 is presumably Memlet's vector length -
    # confirm against the Memlet constructor signature.
    return local_name, Memlet(chosen.name,
                              chosen.accesses,
                              chosen.subset,
                              1,
                              wcr=chosen.wcr)
def ParseMemlet(visitor, defined_arrays_and_symbols: Dict[str, Any], node: MemletType, parsed_slice: Any = None) -> MemletExpr:
    """Parse a memlet expression AST node into a ``MemletExpr``.

    Supported forms include ``A[...]``, ``A(2)[...]``, ``A(300)`` and
    ``A(1, sum)[:]``. The optional call arguments give the number of
    accesses (first argument) and a write-conflict resolution (second
    argument). Between one and three call arguments are accepted; only the
    first two are used here.

    :param visitor: Visitor object, passed through for error reporting.
    :param defined_arrays_and_symbols: Names of data containers and symbols
                                       visible at this point.
    :param node: The memlet expression AST.
    :param parsed_slice: Optional pre-parsed slice, forwarded to
                         ``parse_memlet_subset``.
    :return: A ``MemletExpr`` built from the data name, the number of
             accesses, the write-conflict resolution AST (or ``None``), and
             the subset, new axes and array dimensions returned by
             ``parse_memlet_subset``.
    :raises DaceSyntaxError: If the name returned by ``rname(node)`` is not
                             defined, or if the access-count call has fewer
                             than 1 or more than 3 arguments.
    """
    known = defined_arrays_and_symbols
    name = rname(node)
    if name not in known:
        raise DaceSyntaxError(visitor, node,
                              'Use of undefined data "%s" in memlet' % name)
    desc = known[name]

    # Locate the call that carries the access count / WCR, if any:
    # "A(300)" is itself a Call; "A(2)[...]" is a Subscript of a Call.
    if isinstance(node, ast.Call):
        call = node
    elif isinstance(node, ast.Subscript) and isinstance(node.value, ast.Call):
        call = node.value
    else:
        call = None

    accesses = None
    wcr = None
    if call is not None:
        nargs = len(call.args)
        if not 1 <= nargs <= 3:
            # Errors are reported on the whole memlet node in both forms
            raise DaceSyntaxError(
                visitor, node,
                'Number of accesses in memlet must be a number, symbolic '
                'expression, or -1 (dynamic)')
        accesses = pyexpr_to_symbolic(known, call.args[0])
        if nargs >= 2:
            wcr = call.args[1]

    subset, new_axes, arrdims = parse_memlet_subset(desc, node, known,
                                                    parsed_slice)

    # Without an explicit count, the number of accesses is the slice size
    if accesses is None:
        accesses = subset.num_elements()

    return MemletExpr(name, accesses, wcr, subset, new_axes, arrdims)
def visit_For(self, node: ast.For) -> Any:
    """Unroll a ``for`` loop at parse time when allowed and possible.

    A loop is considered for unrolling when it is explicitly wrapped in
    ``unroll(...)``, or when it is implicitly allowed. A loop is implicitly
    allowed when its iterator is neither a call nor a subscript, or when it
    is a call/subscript of one of ``LoopUnroller.STATELESS_GENERATORS``.

    Loops that contain a ``break``/``continue`` belonging to this loop, or
    that have an ``else`` clause, are never unrolled. The loop target must
    be a single name or a tuple of targets. The iterated values must be
    available at compile time: the elements of a literal tuple/list/set,
    the keys of a literal dict, or the result of ``astutils.evalnode`` on
    the iterator in ``self.globals``.

    :param node: The ``for`` loop AST node.
    :return: The (child-visited) ``node`` if the loop is not unrolled.
             Otherwise, a list of statements with one copy of the loop body
             per iterated element, where the loop target(s) are replaced by
             that element's value(s).
    :raises DaceSyntaxError: If ``unroll`` was requested for a loop containing
                             ``break``, ``continue`` or an ``else`` clause.
    """
    # Avoid import loops
    EXPLICIT_GENERATORS = [
        range,  # Handled in ProgramVisitor
        dace.map,
        dace.consume,
    ]

    node = self.generic_visit(node)

    # First, skip loops that contain break/continue that is part of this
    # for loop (rather than nested ones), or that have an "else" clause
    cflow_finder = _FindBreakContinueStmts()
    for stmt in node.body:
        cflow_finder.visit(stmt)
    cannot_unroll = bool(cflow_finder.has_cflow or node.orelse)

    niter = node.iter

    # Find out if loop was explicitly requested to be unrolled with unroll
    explicitly_requested = False
    if isinstance(niter, ast.Call):
        # Avoid import loop
        from dace.frontend.python.interface import unroll
        try:
            genfunc = astutils.evalnode(niter.func, self.globals)
        except SyntaxError:
            genfunc = None

        if genfunc is unroll:
            explicitly_requested = True
            niter = niter.args[0]

    if cannot_unroll:
        if explicitly_requested:
            raise DaceSyntaxError(
                None, node, 'Cannot unroll loop due to '
                '"break", "continue", or "else" statements.')
        # Implicit unrolling is not possible either: pasting the body would
        # drop the "else" clause and leave "break"/"continue" outside of
        # this loop.
        return node

    # Find out if unrolling should be done implicitly.
    # Anything that is not a call/subscript is implicitly allowed.
    implicit = True
    if isinstance(niter, (ast.Call, ast.Subscript)):
        if isinstance(niter, ast.Subscript):
            nfunc = niter.value
        else:
            nfunc = niter.func

        implicit = False
        # Try to see if it's one of the allowed stateless generators
        try:
            genfunc = astutils.evalnode(nfunc, self.globals)

            # If genfunc is a bound method, try to extract function from type
            if hasattr(genfunc, '__self__'):
                genfunc = getattr(type(genfunc.__self__), genfunc.__name__,
                                  False)

            if genfunc in LoopUnroller.STATELESS_GENERATORS:
                implicit = True
            elif genfunc in EXPLICIT_GENERATORS:
                implicit = False

        except SyntaxError:
            pass

    # Loop will not be unrolled
    if not implicit and not explicitly_requested:
        return node

    # Check if loop target is supported
    if isinstance(node.target, ast.Tuple):
        to_replace = node.target.elts
    elif isinstance(node.target, ast.Name):
        to_replace = [node.target]
    else:
        # Unsupported loop target
        return node

    if isinstance(niter, (ast.Tuple, ast.List, ast.Set)):
        # Literal tuple/list/set: iterate over its element AST nodes
        generator = niter.elts
    elif isinstance(niter, ast.Dict):
        # Literal dict: iterate over its keys (Python compatible)
        generator = niter.keys
    else:
        # Check if the generator is compile-time constant
        try:
            generator = astutils.evalnode(niter, self.globals)
        except SyntaxError:
            # Cannot evaluate generator at compile time
            return node

    if implicit and not explicitly_requested:
        warnings.warn(f'Loop at {self.filename}:{node.lineno} will be '
                      'implicitly unrolled.')

    ##########################################
    # Unroll loop
    new_body = []
    for elem in generator:
        if isinstance(node.target, ast.Name):
            # A single name binds the whole element (Python semantics), even
            # if the element is itself iterable (e.g., a tuple produced by
            # zip/enumerate, a string, or an array) - never unpack it.
            values = [elem]
        elif isinstance(elem, (ast.Tuple, ast.List)):
            # Literal generators yield AST nodes; unpack literal
            # tuples/lists into the tuple target's names
            values = elem.elts
        else:
            values = elem
            try:
                iter(elem)
            except (TypeError, ValueError):
                values = [elem]

        # Paste loop body with the target(s) replaced by this element
        elembody = copy.deepcopy(node.body)
        replace = astutils.ASTFindReplace(dict(zip(to_replace, values)))
        for stmt in elembody:
            new_body.append(replace.visit(stmt))

    return new_body