def closure_resolver(self, constant_args, parent_closure=None):
    """
    Resolves the closure of the wrapped function without building an SDFG.

    Builds the global-variable environment for the preprocessor, runs
    ``preprocessing.preprocess_dace_program`` with no argument types,
    discards the parsed AST and returns only the resolved closure.

    :param constant_args: Compile-time constant arguments, or None. Integer
                          keys are positional indices into
                          ``self.argnames``; any other key is used directly
                          as the argument name.
    :param parent_closure: Closure of an enclosing program, forwarded as-is
                           to ``preprocessing.preprocess_dace_program``.
    :return: The closure object produced by the preprocessor.
    """
    # If exist, obtain compile-time constants, keyed by argument name
    gvars = {}
    if constant_args is not None:
        gvars = {self.argnames[i]: v for i, v in constant_args.items() if isinstance(i, int)}
        gvars.update({k: v for k, v in constant_args.items() if not isinstance(k, int)})

    # Parse allowed global variables
    # (for inferring types and values in the DaCe program).
    # Names that are program arguments are dropped; constant arguments are
    # re-added below. The comprehension already builds a fresh dict, so
    # ``self.global_vars`` itself is never modified here.
    global_vars = {k: v for k, v in self.global_vars.items() if k not in self.argnames}

    # Move "self" from an argument into the closure
    if self.methodobj is not None:
        global_vars[self.objname] = self.methodobj

    # Set module aliases to point to their actual names
    modules = {k: v.__name__ for k, v in global_vars.items() if dtypes.ismodule(v)}
    modules['builtins'] = ''

    # Add symbols as globals with their actual names (sym_0 etc.).
    # The dict is fully built before ``update`` runs, so global_vars is not
    # mutated while being iterated.
    global_vars.update({v.name: v for v in global_vars.values() if isinstance(v, symbolic.symbol)})

    # Add constant arguments to global_vars (they take precedence)
    global_vars.update(gvars)

    # Parse AST only to resolve the closure; the AST itself is not needed
    _, closure = preprocessing.preprocess_dace_program(self.f, {},
                                                       global_vars,
                                                       modules,
                                                       resolve_functions=self.resolve_functions,
                                                       parent_closure=parent_closure)
    return closure
def _generate_pdp(self, args, kwargs, strict=None) -> 'Tuple[SDFG, bool]':
    """
    Generates the parsed AST representation of a DaCe program.

    Side effects: writes ``self.global_vars[self.objname]`` (for methods) and
    sets ``self.closure_arg_mapping``, ``self.closure_array_keys``,
    ``self.closure_constant_keys`` and ``self.resolver``.

    :param args: The given arguments to the program.
    :param kwargs: The given keyword arguments to the program.
    :param strict: Whether to apply strict transforms when parsing nested
                   dace programs.
    :return: A 2-tuple of (parsed SDFG object, was the SDFG retrieved from
             cache).
    """
    dace_func = self.f

    # If exist, obtain type annotations (for compilation)
    argtypes, _, gvars = self._get_type_annotations(args, kwargs)

    # Move "self" from an argument into the closure.
    # NOTE: this writes into ``self.global_vars`` itself (not a copy), so the
    # entry persists across calls.
    if self.methodobj is not None:
        self.global_vars[self.objname] = self.methodobj

    # Arguments to (nested) SDFGs cannot be transient. Replace transient
    # descriptors with non-transient deep copies instead of flipping the flag
    # on the original descriptor. Reassigning existing keys while iterating is
    # safe because the dict does not change size.
    for k, v in argtypes.items():
        if v.transient:
            v_cpy = copy.deepcopy(v)
            v_cpy.transient = False
            argtypes[k] = v_cpy

    #############################################

    # Parse allowed global variables
    # (for inferring types and values in the DaCe program)
    global_vars = copy.copy(self.global_vars)

    # Remove None arguments and make them into globals that can be folded
    removed_args = set()
    for k, v in argtypes.items():
        if v.dtype.type is None:
            global_vars[k] = None
            removed_args.add(k)
    argtypes = {k: v for k, v in argtypes.items() if v.dtype.type is not None}

    # Set module aliases to point to their actual names
    modules = {k: v.__name__ for k, v in global_vars.items() if dtypes.ismodule(v)}
    modules['builtins'] = ''

    # Add symbols as globals with their actual names (sym_0 etc.), including
    # the free symbols of every remaining argument descriptor
    global_vars.update({v.name: v for v in global_vars.values() if isinstance(v, symbolic.symbol)})
    for argtype in argtypes.values():
        global_vars.update({v.name: v for v in argtype.free_symbols})

    # Add constant arguments to global_vars (they take precedence)
    global_vars.update(gvars)

    # Parse AST to create the SDFG
    parsed_ast, closure = preprocessing.preprocess_dace_program(dace_func,
                                                                argtypes,
                                                                global_vars,
                                                                modules,
                                                                resolve_functions=self.resolve_functions)

    # Create new argument mapping from closure arrays (the third element of
    # each closure_arrays entry)
    self.closure_arg_mapping = {k: v for k, (_, _, v, _) in closure.closure_arrays.items()}
    # Arguments folded into globals above are not closure entries
    self.closure_array_keys = set(closure.closure_arrays.keys()) - removed_args
    self.closure_constant_keys = set(closure.closure_constants.keys()) - removed_args
    self.resolver = closure

    # If parsed SDFG is already cached, use it
    cachekey = self._cache.make_key(argtypes, self.closure_array_keys, self.closure_constant_keys, gvars)
    if self._cache.has(cachekey):
        return self._cache.get(cachekey).sdfg, True

    sdfg = newast.parse_dace_program(self.name, parsed_ast, argtypes, self.dec_kwargs, closure, strict=strict)

    # Set SDFG argument names, filtering out constants
    sdfg.arg_names = [a for a in self.argnames if a in argtypes]

    # TODO: Add to parsed SDFG cache
    return sdfg, False