예제 #1
0
파일: parser.py 프로젝트: sscholbe/dace
    def _generate_pdp(self, args, kwargs, strict=None) -> 'Tuple[SDFG, bool]':
        """ Generates the parsed AST representation of a DaCe program.
            :param args: The given arguments to the program.
            :param kwargs: The given keyword arguments to the program.
            :param strict: Whether to apply strict transforms when parsing 
                           nested dace programs.
            :return: A 2-tuple of (parsed SDFG object, was the SDFG retrieved
                     from cache).
            :note: Also stores ``self.closure_arg_mapping``,
                   ``self.closure_array_keys``, ``self.closure_constant_keys``
                   and ``self.resolver`` from the preprocessed closure, and,
                   when ``self.methodobj`` is set, registers it under
                   ``self.objname`` in ``self.global_vars``.
        """
        dace_func = self.f

        # If exist, obtain type annotations (for compilation)
        argtypes, _, gvars = self._get_type_annotations(args, kwargs)

        # Move "self" from an argument into the closure
        if self.methodobj is not None:
            self.global_vars[self.objname] = self.methodobj

        # Arguments to (nested) SDFGs cannot be transient. Swap in
        # non-transient deep copies so the original descriptor objects are
        # not modified; reassigning existing keys while iterating is safe
        # because the dict's size does not change.
        for k, v in argtypes.items():
            if v.transient:
                v_cpy = copy.deepcopy(v)
                v_cpy.transient = False
                argtypes[k] = v_cpy

        #############################################

        # Parse allowed global variables
        # (for inferring types and values in the DaCe program)
        global_vars = copy.copy(self.global_vars)

        # Remove None arguments and make into globals that can be folded
        removed_args = set()
        for k, v in argtypes.items():
            if v.dtype.type is None:
                global_vars[k] = None
                removed_args.add(k)
        argtypes = {
            k: v
            for k, v in argtypes.items() if v.dtype.type is not None
        }

        # Set module aliases to point to their actual names
        modules = {
            k: v.__name__
            for k, v in global_vars.items() if dtypes.ismodule(v)
        }
        modules['builtins'] = ''

        # Add symbols as globals with their actual names (sym_0 etc.)
        global_vars.update({
            v.name: v
            for _, v in global_vars.items() if isinstance(v, symbolic.symbol)
        })
        for argtype in argtypes.values():
            global_vars.update({v.name: v for v in argtype.free_symbols})

        # Add constant arguments to global_vars
        global_vars.update(gvars)

        # Parse AST to create the SDFG
        parsed_ast, closure = preprocessing.preprocess_dace_program(
            dace_func,
            argtypes,
            global_vars,
            modules,
            resolve_functions=self.resolve_functions)

        # Create new argument mapping from closure arrays
        # (closure-array name -> third element of its closure_arrays entry)
        arg_mapping = {
            k: v
            for k, (_, _, v, _) in closure.closure_arrays.items()
        }
        self.closure_arg_mapping = arg_mapping
        # Exclude arguments that were folded into None-valued globals above
        self.closure_array_keys = set(
            closure.closure_arrays.keys()) - removed_args
        self.closure_constant_keys = set(
            closure.closure_constants.keys()) - removed_args
        self.resolver = closure

        # If parsed SDFG is already cached, use it
        cachekey = self._cache.make_key(argtypes, self.closure_array_keys,
                                        self.closure_constant_keys, gvars)
        if self._cache.has(cachekey):
            sdfg = self._cache.get(cachekey).sdfg
            cached = True
        else:
            cached = False
            sdfg = newast.parse_dace_program(self.name,
                                             parsed_ast,
                                             argtypes,
                                             self.dec_kwargs,
                                             closure,
                                             strict=strict)

            # Set SDFG argument names, filtering out constants
            sdfg.arg_names = [a for a in self.argnames if a in argtypes]

            # TODO: Add to parsed SDFG cache

        return sdfg, cached
예제 #2
0
파일: parser.py 프로젝트: thobauma/dace
    def _generate_pdp(self,
                      args,
                      kwargs,
                      strict=None) -> Tuple[SDFG, Dict[str, str]]:
        """ Generates the parsed AST representation of a DaCe program.
            :param args: The given arguments to the program.
            :param kwargs: The given keyword arguments to the program.
            :param strict: Whether to apply strict transforms when parsing 
                           nested dace programs.
            :return: A 2-tuple of (sdfg, arg_mapping). ``sdfg`` is the SDFG
                     produced by ``newast.parse_dace_program``, with
                     ``arg_names`` set to the program arguments that remain
                     after ``None``-typed arguments are folded into constants.
                     ``arg_mapping`` is built from the closure arrays: an
                     entry whose first tuple element is a string maps that
                     string to the closure-array name; any other entry maps
                     the closure-array name to that element (so values are
                     not necessarily strings). The same mapping is stored in
                     ``self.closure_arg_mapping``.
            :raises SyntaxError: If the program has parameters, no type
                                 annotations are available and no sample
                                 arguments were given, or the sample
                                 arguments do not correspond exactly to the
                                 program's parameters.
        """
        dace_func = self.f

        # If exist, obtain type annotations (for compilation)
        argtypes, _, gvars = self._get_type_annotations(args, kwargs)

        # Move "self" from an argument into the closure
        if self.methodobj is not None:
            self.global_vars[self.objname] = self.methodobj

        # Parse argument types from call
        if len(self.argnames) > 0:
            if not argtypes:
                if not args and not kwargs:
                    raise SyntaxError(
                        'Compiling DaCe programs requires static types. '
                        'Please provide type annotations on the function, '
                        'or add sample arguments to the compilation call.')

                # zip() below would silently drop surplus positional values
                if len(args) > len(self.argnames):
                    raise SyntaxError(
                        'Too many positional arguments '
                        f'(expecting at most {len(self.argnames)}, '
                        f'got {len(args)})')

                # Parse compilation arguments. Later sources override earlier
                # ones: defaults, then positional, then keyword arguments.
                argtypes = {
                    k: create_datadescriptor(v)
                    for k, v in itertools.chain(self.default_args.items(
                    ), zip(self.argnames, args), kwargs.items())
                }
                # Compare names rather than counts, so that an unknown keyword
                # cannot stand in for a missing parameter.
                if set(argtypes.keys()) != set(self.argnames):
                    raise SyntaxError(
                        'Number of arguments must match parameters '
                        f'(expecting {self.argnames}, got {list(argtypes.keys())})'
                    )

        # Arguments to (nested) SDFGs cannot be transient. Swap in
        # non-transient deep copies so the original descriptor objects are
        # not modified; reassigning existing keys while iterating is safe.
        for k, v in argtypes.items():
            if v.transient:
                v_cpy = copy.deepcopy(v)
                v_cpy.transient = False
                argtypes[k] = v_cpy

        #############################################

        # Parse allowed global variables
        # (for inferring types and values in the DaCe program)
        global_vars = copy.copy(self.global_vars)

        # Remove None arguments and make into globals that can be folded
        removed_args = set()
        for k, v in argtypes.items():
            if v.dtype.type is None:
                global_vars[k] = None
                removed_args.add(k)
        argtypes = {
            k: v
            for k, v in argtypes.items() if v.dtype.type is not None
        }

        # Set module aliases to point to their actual names
        modules = {
            k: v.__name__
            for k, v in global_vars.items() if dtypes.ismodule(v)
        }
        modules['builtins'] = ''

        # Add symbols as globals with their actual names (sym_0 etc.)
        global_vars.update({
            v.name: v
            for _, v in global_vars.items() if isinstance(v, symbolic.symbol)
        })
        for argtype in argtypes.values():
            global_vars.update({v.name: v for v in argtype.free_symbols})

        # Add constant arguments to global_vars
        global_vars.update(gvars)

        # Parse AST to create the SDFG
        sdfg, closure = newast.parse_dace_program(
            dace_func,
            self.name,
            argtypes,
            global_vars,
            modules,
            self.dec_kwargs,
            strict=strict,
            resolve_functions=self.resolve_functions)

        # Set SDFG argument names, filtering out constants
        sdfg.arg_names = [a for a in self.argnames if a in argtypes]

        # Create new argument mapping from closure arrays. String entries are
        # inverted (string -> closure-array name); the remaining entries are
        # added afterwards (closure-array name -> value) and win on collision.
        arg_mapping = {
            v: k
            for k, (v, _) in closure.closure_arrays.items()
            if isinstance(v, str)
        }
        arg_mapping.update({
            k: v
            for k, (v, _) in closure.closure_arrays.items()
            if not isinstance(v, str)
        })
        self.closure_arg_mapping = arg_mapping
        # Exclude arguments that were folded into None-valued globals above
        self.closure_array_keys = set(
            closure.closure_arrays.keys()) - removed_args
        self.closure_constant_keys = set(
            closure.closure_constants.keys()) - removed_args

        return sdfg, arg_mapping
예제 #3
0
    def generate_pdp(self, *compilation_args):
        """ Builds the parsed AST representation of this DaCe program.
            @param compilation_args: Sample values (e.g., arrays) from which
                                     argument types are inferred when the
                                     function has no type annotations.
            @return: The result of ``newast.parse_dace_program`` for this
                     program. NOTE(review): earlier documentation described
                     it as a 2-tuple of (program, modules), where `program`
                     is a `dace.astnodes._ProgramNode` and `modules` maps
                     imported module names to their actual module names;
                     confirm against ``newast``.
            @raise SyntaxError: If the function has positional parameters but
                                neither type annotations nor compilation
                                arguments are available, or if the number of
                                compilation arguments differs from the number
                                of program parameters.
        """
        func = self.f
        stored_args = self.args

        # Prefer explicit type annotations on the function, if any exist
        descriptors = _get_type_annotations(func, self.argnames, stored_args)

        has_params = len(inspect.getfullargspec(func).args) > 0
        if has_params and not descriptors:
            # No annotations available: infer types from the sample arguments
            if not compilation_args:
                raise SyntaxError(
                    'DaCe program compilation requires either type annotations '
                    'or arrays')
            expected = len(self.argnames)
            if len(compilation_args) != expected:
                raise SyntaxError(
                    'Arguments must match DaCe program parameters (expecting '
                    '%d)' % expected)
            descriptors = {}
            for pname, sample in zip(self.argnames, compilation_args):
                descriptors[pname] = create_datadescriptor(sample)

        # Nested-SDFG arguments must not be transient: replace such entries
        # with non-transient deep copies (the originals stay untouched)
        for pname, desc in descriptors.items():
            if not desc.transient:
                continue
            clone = copy.deepcopy(desc)
            clone.transient = False
            descriptors[pname] = clone

        #############################################

        # Globals the program may reference (to infer types and values)
        scope = copy.copy(self.global_vars)

        # Map module aliases to their real module names
        modules = {
            alias: mod.__name__
            for alias, mod in scope.items() if dtypes.ismodule(mod)
        }
        modules['builtins'] = ''

        # Also expose symbols under their actual names (sym_0 etc.)
        named_symbols = {
            val.name: val
            for val in scope.values() if isinstance(val, symbolic.symbol)
        }
        scope.update(named_symbols)

        # Collect SDFGs and DaceProgram objects from the function's globals.
        # NOTE: These are the globals AT THE TIME OF INVOCATION, NOT DEFINITION
        nested = {}
        for gname, gval in func.__globals__.items():
            if isinstance(gval, (SDFG, DaceProgram)):
                nested[gname] = gval

        # Parse AST to create the SDFG
        return newast.parse_dace_program(func, descriptors, scope, modules,
                                         nested, self.kwargs)
예제 #4
0
    def _generate_pdp(self, args, kwargs, strict=None):
        """ Generates the parsed AST representation of a DaCe program.
            :param args: The given arguments to the program.
            :param kwargs: The given keyword arguments to the program.
            :param strict: Whether to apply strict transforms when parsing 
                           nested dace programs.
            :return: The SDFG produced by ``newast.parse_dace_program``, with
                     ``arg_names`` set to the program arguments that remain
                     after ``None``-typed arguments are folded into constants.
            :raises SyntaxError: If the program has parameters, no type
                                 annotations are available and no sample
                                 arguments were given, or the sample
                                 arguments do not correspond exactly to the
                                 program's parameters.
        """
        dace_func = self.f

        # If exist, obtain type annotations (for compilation)
        argtypes, _, gvars = self._get_type_annotations(args, kwargs)

        # Parse argument types from call
        if len(self.argnames) > 0:
            if not argtypes:
                if not args and not kwargs:
                    raise SyntaxError(
                        'Compiling DaCe programs requires static types. '
                        'Please provide type annotations on the function, '
                        'or add sample arguments to the compilation call.')

                # zip() below would silently drop surplus positional values
                if len(args) > len(self.argnames):
                    raise SyntaxError(
                        'Too many positional arguments '
                        f'(expecting at most {len(self.argnames)}, '
                        f'got {len(args)})')

                # Parse compilation arguments. Later sources override earlier
                # ones: defaults, then positional, then keyword arguments.
                argtypes = {
                    k: create_datadescriptor(v)
                    for k, v in itertools.chain(self.default_args.items(
                    ), zip(self.argnames, args), kwargs.items())
                }
                # Compare names rather than counts, so that an unknown keyword
                # cannot stand in for a missing parameter.
                if set(argtypes.keys()) != set(self.argnames):
                    raise SyntaxError(
                        'Number of arguments must match parameters '
                        f'(expecting {self.argnames}, got {list(argtypes.keys())})'
                    )

        # Arguments to (nested) SDFGs cannot be transient. Swap in
        # non-transient deep copies so the original descriptor objects are
        # not modified; reassigning existing keys while iterating is safe.
        for k, v in argtypes.items():
            if v.transient:
                v_cpy = copy.deepcopy(v)
                v_cpy.transient = False
                argtypes[k] = v_cpy

        #############################################

        # Parse allowed global variables
        # (for inferring types and values in the DaCe program)
        global_vars = copy.copy(self.global_vars)

        # Remove None arguments and make into globals that can be folded
        for k, v in argtypes.items():
            if v.dtype.type is None:
                global_vars[k] = None
        argtypes = {
            k: v
            for k, v in argtypes.items() if v.dtype.type is not None
        }

        # Set module aliases to point to their actual names
        modules = {
            k: v.__name__
            for k, v in global_vars.items() if dtypes.ismodule(v)
        }
        modules['builtins'] = ''

        # Add symbols as globals with their actual names (sym_0 etc.)
        global_vars.update({
            v.name: v
            for _, v in global_vars.items() if isinstance(v, symbolic.symbol)
        })
        for argtype in argtypes.values():
            global_vars.update({v.name: v for v in argtype.free_symbols})

        # Add constant arguments to global_vars
        global_vars.update(gvars)

        # Allow SDFGs and DaceProgram objects
        # NOTE: These are the globals AT THE TIME OF INVOCATION, NOT DEFINITION
        other_sdfgs = {
            k: v
            for k, v in _get_locals_and_globals(dace_func).items()
            if isinstance(v, (SDFG, DaceProgram))
        }

        # Parse AST to create the SDFG
        sdfg = newast.parse_dace_program(dace_func,
                                         self.name,
                                         argtypes,
                                         global_vars,
                                         modules,
                                         other_sdfgs,
                                         self.dec_kwargs,
                                         strict=strict)

        # Set SDFG argument names, filtering out constants
        sdfg.arg_names = [a for a in self.argnames if a in argtypes]

        return sdfg