Esempio n. 1
0
    def closure_resolver(self, constant_args, parent_closure=None):
        """ Resolves the closure of this DaCe program without producing an
            SDFG (the preprocessed AST is discarded).

            :param constant_args: Mapping of compile-time constant arguments,
                                  or None. Integer keys are positional indices
                                  into ``self.argnames``; any other key is used
                                  as the argument name directly. Named entries
                                  take precedence over positional ones.
            :param parent_closure: Closure of an enclosing program; forwarded
                                   unchanged to the preprocessor.
            :return: The closure object returned by
                     ``preprocessing.preprocess_dace_program``.
        """
        # Collect compile-time constants, translating positional indices into
        # argument names; named entries override positional ones.
        constants = {}
        if constant_args is not None:
            positional = {}
            named = {}
            for key, value in constant_args.items():
                if isinstance(key, int):
                    positional[self.argnames[key]] = value
                else:
                    named[key] = value
            constants = {**positional, **named}

        # Start from the allowed globals, minus anything shadowed by a program
        # argument (a fresh dict, so self.global_vars is never modified).
        arg_names = self.argnames
        global_vars = {
            name: val
            for name, val in self.global_vars.items() if name not in arg_names
        }

        # For a bound method, the receiver lives in the closure rather than in
        # the argument list.
        if self.methodobj is not None:
            global_vars[self.objname] = self.methodobj

        # Map every module alias to the module's real name.
        modules = {}
        for alias, obj in global_vars.items():
            if dtypes.ismodule(obj):
                modules[alias] = obj.__name__
        modules['builtins'] = ''

        # Expose symbols under their own names as well (e.g. sym_0).
        symbols_by_name = {
            sym.name: sym
            for sym in global_vars.values() if isinstance(sym, symbolic.symbol)
        }
        global_vars.update(symbols_by_name)

        # Constant arguments are (re-)added last so they win over the above.
        global_vars.update(constants)

        # Only the closure is of interest; the preprocessed AST is dropped.
        _, closure = preprocessing.preprocess_dace_program(
            self.f, {},
            global_vars,
            modules,
            resolve_functions=self.resolve_functions,
            parent_closure=parent_closure)
        return closure
Esempio n. 2
0
    def _generate_pdp(self, args, kwargs, strict=None) -> tuple:
        """ Parses this DaCe program into an SDFG for the given arguments.

            :param args: The given positional arguments to the program.
            :param kwargs: The given keyword arguments to the program.
            :param strict: Whether to apply strict transforms when parsing
                           nested dace programs.
            :return: A 2-tuple of (SDFG object, bool): the parsed SDFG and
                     whether it was retrieved from the parsed-SDFG cache.

            Side effects: sets ``self.closure_arg_mapping``,
            ``self.closure_array_keys``, ``self.closure_constant_keys`` and
            ``self.resolver``. For bound methods, it also stores the bound
            object in ``self.global_vars`` under ``self.objname``.
        """
        dace_func = self.f

        # If exist, obtain type annotations (for compilation)
        argtypes, _, gvars = self._get_type_annotations(args, kwargs)

        # Move "self" from an argument into the closure.
        # NOTE: this writes to ``self.global_vars`` itself (not a local copy),
        # so the entry persists after this call.
        if self.methodobj is not None:
            self.global_vars[self.objname] = self.methodobj

        # Arguments to (nested) SDFGs cannot be transient: replace transient
        # descriptors with non-transient deep copies (originals untouched).
        for k, v in argtypes.items():
            if v.transient:
                v_cpy = copy.deepcopy(v)
                v_cpy.transient = False
                argtypes[k] = v_cpy

        #############################################

        # Parse allowed global variables
        # (for inferring types and values in the DaCe program)
        global_vars = copy.copy(self.global_vars)

        # Remove None arguments (descriptors whose dtype has no type) and make
        # them into globals that can be folded -- done in a single pass.
        removed_args = set()
        remaining_argtypes = {}
        for k, v in argtypes.items():
            if v.dtype.type is None:
                global_vars[k] = None
                removed_args.add(k)
            else:
                remaining_argtypes[k] = v
        argtypes = remaining_argtypes

        # Set module aliases to point to their actual names
        modules = {
            k: v.__name__
            for k, v in global_vars.items() if dtypes.ismodule(v)
        }
        modules['builtins'] = ''

        # Add symbols as globals with their actual names (sym_0 etc.),
        # including the free symbols of every argument descriptor.
        global_vars.update({
            v.name: v
            for v in global_vars.values() if isinstance(v, symbolic.symbol)
        })
        for argtype in argtypes.values():
            global_vars.update({v.name: v for v in argtype.free_symbols})

        # Add constant arguments to global_vars (applied last, so they
        # override any entry gathered above)
        global_vars.update(gvars)

        # Preprocess the AST and resolve the program's closure
        parsed_ast, closure = preprocessing.preprocess_dace_program(
            dace_func,
            argtypes,
            global_vars,
            modules,
            resolve_functions=self.resolve_functions)

        # Create new argument mapping from closure arrays
        # (key -> third element of each closure_arrays entry)
        self.closure_arg_mapping = {
            k: v
            for k, (_, _, v, _) in closure.closure_arrays.items()
        }
        # Folded None arguments are excluded from the closure key sets
        self.closure_array_keys = set(
            closure.closure_arrays.keys()) - removed_args
        self.closure_constant_keys = set(
            closure.closure_constants.keys()) - removed_args
        self.resolver = closure

        # If parsed SDFG is already cached, use it
        cachekey = self._cache.make_key(argtypes, self.closure_array_keys,
                                        self.closure_constant_keys, gvars)
        if self._cache.has(cachekey):
            return self._cache.get(cachekey).sdfg, True

        sdfg = newast.parse_dace_program(self.name,
                                         parsed_ast,
                                         argtypes,
                                         self.dec_kwargs,
                                         closure,
                                         strict=strict)

        # Set SDFG argument names, filtering out constants and folded
        # None arguments (neither remains in argtypes)
        sdfg.arg_names = [a for a in self.argnames if a in argtypes]

        # TODO: Add to parsed SDFG cache

        return sdfg, False