def test_simple(self):
    """Resolve the globals ``b`` -> 9 and ``a`` -> -4 in ``toresolve``.

    The unparsed result must contain ``return 9`` while the call expressions
    ``f(a, b)`` and ``g(b`` must still appear verbatim in the output.
    (``toresolve`` itself is defined elsewhere in this file.)
    """
    test_ast, _, _, _ = astutils.function_to_ast(toresolve)
    code = astunparse.unparse(
        GlobalResolver({
            'b': 9,
            'a': -4
        }).visit(test_ast))
    # assertIn reports the full unparsed code on failure, unlike
    # assertTrue(x in code) which only reports "False is not true"
    self.assertIn('return 9', code)
    self.assertIn('f(a, b)', code)
    self.assertIn('g(b', code)
def _analyze_and_unparse_code(func: DaceProgram) -> str:
    """
    Preprocess the source of a DaCe program and return it as source text.

    The program's Python function is parsed and then run through three passes
    in order: global resolution, conditional-code resolution, and dead-code
    elimination. Globals whose names collide with one of the program's
    argument names are excluded from resolution.

    :param func: The DaCe program whose underlying function is analyzed.
    :return: The transformed AST unparsed back to Python source.
    """
    tree, _, _, _ = astutils.function_to_ast(func.f)
    known_globals = {
        name: val
        for name, val in func.global_vars.items() if name not in func.argnames
    }

    # Factories keep the original construction order: each pass is built
    # only right before it visits the tree.
    pass_factories = (
        lambda: GlobalResolver(known_globals),
        lambda: ConditionalCodeResolver(known_globals),
        DeadCodeEliminator,
    )
    for make_pass in pass_factories:
        tree = make_pass().visit(tree)

    return astutils.unparse(tree)
def test_dace_unroll_break():
    """ Tests unrolling functionality with control flow statements.

    Unrolling a ``dace.unroll`` loop whose body contains ``break`` must raise
    a ``DaceSyntaxError``.
    """
    @dace.program
    def tounroll(A: dace.float64[1]):
        for i in dace.unroll(range(1, 4)):
            A[0] += i * i
            if i in (2, 3):
                break

    src_ast, fname, _, _ = astutils.function_to_ast(tounroll.f)
    lu = LoopUnroller(tounroll.global_vars, fname)
    with pytest.raises(DaceSyntaxError):
        # No result is kept: the visit itself is expected to fail
        lu.visit(src_ast)
def test_dace_unroll():
    """ Checks that ``dace.unroll`` loops are unrolled inside DaCe programs,
    both at the AST level and when the program is executed. """
    @dace.program
    def tounroll(A: dace.float64[1]):
        for i in dace.unroll(range(1, 4)):
            A[0] += i * i

    # Unroll the AST directly: the single loop-body statement is emitted
    # once per iteration of range(1, 4), i.e. three times.
    tree, filename, _, _ = astutils.function_to_ast(tounroll.f)
    unrolled_tree = LoopUnroller(tounroll.global_vars, filename).visit(tree)
    assert len(unrolled_tree.body[0].body) == 3

    # Executing the program accumulates 1 + 4 + 9.
    buf = np.zeros([1])
    tounroll(buf)
    assert buf[0] == 14
def test_dace_unroll_multistatement():
    """ Checks unrolling of a ``dace.unroll`` loop whose body holds more than
    one statement (an augmented assignment plus a conditional). """
    @dace.program
    def tounroll(A: dace.float64[1]):
        for i in dace.unroll(range(1, 4)):
            A[0] += i * i
            if i in (3, ):
                A[0] += 2

    # Two body statements times three iterations gives six unrolled statements.
    tree, filename, _, _ = astutils.function_to_ast(tounroll.f)
    unrolled_tree = LoopUnroller(tounroll.global_vars, filename).visit(tree)
    assert len(unrolled_tree.body[0].body) == 6

    # Executing the program accumulates 1 + 4 + 9, plus 2 on the last iteration.
    buf = np.zeros([1])
    tounroll(buf)
    assert buf[0] == 16
def preprocess_dace_program(
    f: Callable[..., Any],
    argtypes: Dict[str, data.Data],
    global_vars: Dict[str, Any],
    modules: Dict[str, Any],
    resolve_functions: bool = False,
    parent_closure: Optional[SDFGClosure] = None
) -> Tuple[PreprocessedAST, SDFGClosure]:
    """
    Preprocesses a ``@dace.program`` and all its nested functions, returning
    a preprocessed AST object and the closure of the resulting SDFG.

    Note that this function mutates two of its inputs. ``global_vars`` gains
    one extra entry per non-builtin module in ``modules``. ``argtypes`` is
    extended with the data descriptors of every closure array that is still
    used after dead code elimination.

    :param f: A Python function to parse.
    :param argtypes: A dictionary of (name, type) for the given function's
                     arguments, which may pertain to data nodes or symbols
                     (scalars). Updated in place with closure array
                     descriptors.
    :param global_vars: A dictionary of global variables in the closure of
                        `f`. Updated in place with module aliases.
    :param modules: A dictionary from an imported module name to the name it
                    is resolved to. For every entry except ``builtins``, the
                    value stored in ``global_vars`` under the original name
                    is also registered under the resolved name.
    :param resolve_functions: If True, treats all global functions defined
                              outside of the program as returning constant
                              values.
    :param parent_closure: If not None, represents the closure of the parent of
                           the currently processed function.
    :return: A 2-tuple of the AST and its reduced (used) closure.
    :raises DaceRecursionError: If `f` is already on the parent closure's call
                                stack, or if the call-tree analysis detects
                                recursion that does not originate in `f`.
    :raises TypeError: If the implicit recursion depth
                       (``frontend.implicit_recursion_depth``) is exceeded,
                       or if the call-tree analysis detects recursion back
                       into `f`.
    """
    src_ast, src_file, src_line, src = astutils.function_to_ast(f)

    # Resolve data structures
    src_ast = StructTransformer(global_vars).visit(src_ast)

    src_ast = ModuleResolver(modules).visit(src_ast)
    # Convert modules after resolution: make each module reachable under its
    # resolved name as well (the original entry is intentionally kept)
    for mod, modval in modules.items():
        if mod == 'builtins':
            continue
        global_vars[modval] = global_vars[mod]

    # Resolve constants to their values (if they are not already defined in
    # this scope) and symbols to their names
    resolved = {
        k: v
        for k, v in global_vars.items() if k not in argtypes and k != '_'
    }
    closure_resolver = GlobalResolver(resolved, resolve_functions)

    # Append element to call stack and handle max recursion depth
    if parent_closure is not None:
        fid = id(f)
        if fid in parent_closure.callstack:
            raise DaceRecursionError(fid)
        if len(parent_closure.callstack) > Config.get(
                'frontend', 'implicit_recursion_depth'):
            raise TypeError(
                'Implicit (automatically parsed) recursion depth '
                'exceeded. Functions below this call will not be '
                'parsed. To change this setting, modify the value '
                '`frontend.implicit_recursion_depth` in .dace.conf')
        closure_resolver.closure.callstack = parent_closure.callstack + [fid]

    # Transformation passes; their order matters (globals must be resolved
    # before loops are unrolled and conditionals/dead code are folded away)
    src_ast = closure_resolver.visit(src_ast)
    src_ast = LoopUnroller(resolved, src_file).visit(src_ast)
    src_ast = ConditionalCodeResolver(resolved).visit(src_ast)
    src_ast = DeadCodeEliminator().visit(src_ast)
    try:
        ctr = CallTreeResolver(closure_resolver.closure, resolved)
        ctr.visit(src_ast)
    except DaceRecursionError as ex:
        if id(f) == ex.fid:
            raise TypeError(
                'Parsing failed due to recursion in a data-centric '
                'context called from this function') from ex
        # Recursion belongs to another function: let it propagate unchanged
        raise

    used_arrays = ArrayClosureResolver(closure_resolver.closure)
    used_arrays.visit(src_ast)

    # Filter out arrays that are not used after dead code elimination
    closure_resolver.closure.closure_arrays = {
        k: v
        for k, v in closure_resolver.closure.closure_arrays.items()
        if k in used_arrays.arrays
    }

    # Filter out callbacks that were removed after dead code elimination
    closure_resolver.closure.callbacks = {
        k: v
        for k, v in closure_resolver.closure.callbacks.items()
        if k in ctr.seen_calls
    }

    # Filter remaining global variables according to type and scoping rules.
    # This must happen before ``argtypes`` is extended below, otherwise the
    # closure-array names would be filtered out of the program globals too.
    program_globals = {
        k: v
        for k, v in global_vars.items() if k not in argtypes
    }

    # Fill in data descriptors from closure arrays
    argtypes.update({
        arrname: v[1]
        for arrname, v in closure_resolver.closure.closure_arrays.items()
    })

    # Combine nested closures with the current one
    closure_resolver.closure.combine_nested_closures()

    past = PreprocessedAST(src_file, src_line, src, src_ast, program_globals)
    return past, closure_resolver.closure
def global_value_to_node(self,
                         value,
                         parent_node,
                         qualname,
                         recurse=False,
                         detect_callables=False):
    """
    Convert a global (closure) Python value into an AST node.

    Depending on its type, the value is also recorded in ``self.closure``:
    constants go to ``closure_constants``, SDFGs and SDFG-convertible objects
    go to ``closure_sdfgs``, numpy arrays go to ``closure_arrays`` and
    ``array_mapping``, and callables go to ``callbacks``.

    :param value: The global value to convert.
    :param parent_node: The AST node being replaced. Its ``ctx`` is reused for
                        list/tuple nodes and its location is copied onto the
                        result when it is not None.
    :param qualname: Qualified name of the value in the source (list/tuple
                     elements get an ``[i]`` suffix). It is used as the
                     closure key and, for arrays, is re-evaluated lazily via
                     ``eval`` against ``self.globals``.
    :param recurse: If False, list/tuple values are rejected. Nested elements
                    are converted with the default (False), so only one level
                    of lists/tuples is supported.
    :param detect_callables: If True, callables are handled: either parsed as
                             DaCe programs or registered as callbacks.
    :return: The new AST node, or None if the value cannot be represented
             (unsupported type, nested lists, any element failing, or a
             callable that has a registered replacement).
    """
    # if recurse is false, we don't allow recursion into lists
    # this should not happen anyway; the globals dict should only contain
    # single "level" lists
    if not recurse and isinstance(value, (list, tuple)):
        # bail after more than one level of lists
        return None

    if isinstance(value, list):
        # Elements are converted without `recurse`, so a nested list/tuple
        # element yields None and the whole list is rejected
        elts = [
            self.global_value_to_node(v,
                                      parent_node,
                                      qualname + f'[{i}]',
                                      detect_callables=detect_callables)
            for i, v in enumerate(value)
        ]
        if any(e is None for e in elts):
            return None
        newnode = ast.List(elts=elts, ctx=parent_node.ctx)
    elif isinstance(value, tuple):
        # Same as the list case, producing a Tuple node
        elts = [
            self.global_value_to_node(v,
                                      parent_node,
                                      qualname + f'[{i}]',
                                      detect_callables=detect_callables)
            for i, v in enumerate(value)
        ]
        if any(e is None for e in elts):
            return None
        newnode = ast.Tuple(elts=elts, ctx=parent_node.ctx)
    elif isinstance(value, symbolic.symbol):
        # Symbols resolve to the symbol name
        newnode = ast.Name(id=value.name, ctx=ast.Load())
    elif (dtypes.isconstant(value) or isinstance(value, SDFG)
          or hasattr(value, '__sdfg__')):
        # Could be a constant, an SDFG, or SDFG-convertible object
        if isinstance(value, SDFG) or hasattr(value, '__sdfg__'):
            self.closure.closure_sdfgs[qualname] = value
        else:
            self.closure.closure_constants[qualname] = value

        # Compatibility check since Python changed their AST nodes
        if sys.version_info >= (3, 8):
            newnode = ast.Constant(value=value, kind='')
        else:
            if value is None:
                newnode = ast.NameConstant(value=None)
            elif isinstance(value, str):
                newnode = ast.Str(s=value)
            else:
                newnode = ast.Num(n=value)

        # Keep a copy of the replaced node on the new constant node
        newnode.oldnode = copy.deepcopy(parent_node)

    elif detect_callables and hasattr(value, '__call__') and hasattr(
            value.__call__, '__sdfg__'):
        # Callable object whose __call__ is SDFG-convertible: convert the
        # __call__ attribute instead of the object itself
        return self.global_value_to_node(value.__call__, parent_node,
                                         qualname, recurse, detect_callables)
    elif isinstance(value, numpy.ndarray):
        # Arrays need to be stored as a new name and fed as an argument
        if id(value) in self.closure.array_mapping:
            # The same array object was already registered: reuse its name
            arrname = self.closure.array_mapping[id(value)]
        else:
            arrname = self._qualname_to_array_name(qualname)
            desc = data.create_datadescriptor(value)
            # NOTE(review): the lambda presumably lets callers re-fetch the
            # array's current value later by evaluating `qualname` against
            # self.globals; qualname comes from program source, not user input
            self.closure.closure_arrays[arrname] = (
                qualname, desc, lambda: eval(qualname, self.globals), False)
            self.closure.array_mapping[id(value)] = arrname

        newnode = ast.Name(id=arrname, ctx=ast.Load())
    elif detect_callables and callable(value):
        # Try parsing the function as a dace function/method
        newnode = None
        try:
            from dace.frontend.python import parser  # Avoid import loops

            parent_object = None
            if hasattr(value, '__self__'):
                # Bound method: remember the object it is bound to
                parent_object = value.__self__

            # If it is a callable object
            if (not inspect.isfunction(value) and not inspect.ismethod(value)
                    and not inspect.isbuiltin(value)
                    and hasattr(value, '__call__')):
                parent_object = value
                value = value.__call__

            # Replacements take precedence over auto-parsing
            try:
                if has_replacement(value, parent_object, parent_node):
                    return None
            except Exception:
                # NOTE(review): errors while checking for a replacement are
                # silently ignored and auto-parsing is attempted instead
                pass

            # Store the handle to the original callable, in case parsing fails
            cbqualname = astutils.rname(parent_node)
            cbname = self._qualname_to_array_name(cbqualname, prefix='')
            self.closure.callbacks[cbname] = (cbqualname, value, False)

            # From this point on, any failure will result in a callback
            newnode = ast.Name(id=cbname, ctx=ast.Load())

            # Decorated or functions with missing source code
            # (missing source raises inside function_to_ast and is caught
            # by the outer handler, which returns the callback node)
            sast, _, _, _ = astutils.function_to_ast(value)
            if len(sast.body[0].decorator_list) > 0:
                return newnode

            parsed = parser.DaceProgram(value, [], {}, False,
                                        dtypes.DeviceType.CPU)
            # If method, add the first argument (which disappears due to
            # being a bound method) and the method's object
            if parent_object is not None:
                parsed.methodobj = parent_object
                parsed.objname = inspect.getfullargspec(value).args[0]

            # Convert the parsed program through the generic path above
            res = self.global_value_to_node(parsed, parent_node, qualname,
                                            recurse, detect_callables)
            # The callback entry is intentionally kept in
            # self.closure.callbacks in case of a later parsing failure
            return res
        except Exception:
            # Parsing failed (almost any exception can occur): fall back to
            # the callback node (or None if failure happened before it was
            # created)
            return newnode
    else:
        # Unsupported value type
        return None

    if parent_node is not None:
        return ast.copy_location(newnode, parent_node)
    else:
        return newnode