Example #1
0
def get_type_annotations(f, f_argnames, decorator_args) -> ArgTypes:
    """Obtain argument types from decorator arguments or from the
    function's type annotations.

    :param f: The function whose ``__annotations__`` are inspected.
    :param f_argnames: Names of the function's parameters.
    :param decorator_args: Types given to the decorator (may be empty).
    :return: Mapping from argument name to a data descriptor.
    :raises SyntaxError: If decorator arguments or annotations do not
                         cover all parameters.
    """
    annotations = dict(getattr(f, '__annotations__', {}))

    # Which source of type information is available
    use_decorator = bool(decorator_args)
    use_annotations = bool(annotations)

    # Expand a 'return' annotation into __return / __return_<i> entries
    if 'return' in annotations:
        rettype = annotations.pop('return')
        if isinstance(rettype, tuple):
            for i, subrettype in enumerate(rettype):
                annotations[f'__return_{i}'] = subrettype
        else:
            annotations['__return'] = rettype

    # Annotations take precedence when both sources are present
    if use_decorator and use_annotations:
        use_decorator = False

    if use_decorator:
        # Every parameter must receive exactly one decorator argument
        if len(decorator_args) != len(f_argnames):
            raise SyntaxError(
                'Decorator arguments must match number of DaCe ' +
                'program parameters (expecting ' + str(len(f_argnames)) + ')')
        # Pair parameters with their decorator-provided types
        return {
            name: create_datadescriptor(dtype)
            for name, dtype in zip(f_argnames, decorator_args)
        }
    if use_annotations:
        # All-or-nothing: every parameter (ignoring return entries)
        # must carry an annotation
        nonreturn = [a for a in annotations if not a.startswith('__return')]
        if len(nonreturn) != len(f_argnames):
            raise SyntaxError(
                'Either none or all DaCe program parameters must ' +
                'have type annotations')
    return {k: create_datadescriptor(v) for k, v in annotations.items()}
Example #2
0
 def _evaluate_descriptors(
         self,
         arrays: Set[str],
         extra_constants: Dict[str, Any] = None) -> ConstantTypes:
     """Create data descriptors for closure arrays at call time."""
     # Evaluate each closure array via the callback, then wrap the
     # resulting value in a data descriptor
     descriptors = {}
     for name in arrays:
         value = self.eval_callback(name, extra_constants)
         descriptors[name] = dt.create_datadescriptor(value)
     return descriptors
Example #3
0
def _get_type_annotations(f, f_argnames, decorator_args):
    """Obtain argument types from the decorator or from the function's
    type annotations.

    :param f: Function whose ``__annotations__`` are inspected.
    :param f_argnames: Names of the function's parameters.
    :param decorator_args: Types passed to the decorator (may be empty).
    :return: Mapping from argument name to a data descriptor.
    :raises TypeError: If a return type annotation is present.
    :raises SyntaxError: On mixed, missing, or incomplete type sources.
    """
    annotations = dict(getattr(f, '__annotations__', {}))

    # Which source of type information is available
    from_decorator = bool(decorator_args)
    from_annotations = bool(annotations)
    if 'return' in annotations:
        raise TypeError('DaCe programs do not have a return type')
    if from_decorator and from_annotations:
        raise SyntaxError('DaCe programs can only have decorator arguments ' +
                          '(\'@dace.program(...)\') or type annotations ' +
                          '(\'def program(arr: type, ...)\'), but not both')

    if from_decorator:
        # Every parameter must receive exactly one decorator argument
        if len(decorator_args) != len(f_argnames):
            raise SyntaxError(
                'Decorator arguments must match number of DaCe ' +
                'program parameters (expecting ' + str(len(f_argnames)) + ')')
        # Pair parameters with their decorator-provided types
        return {
            name: create_datadescriptor(dtype)
            for name, dtype in zip(f_argnames, decorator_args)
        }
    if from_annotations:
        # All-or-nothing: every parameter must carry an annotation
        if len(annotations) != len(f_argnames):
            raise SyntaxError(
                'Either none or all DaCe program parameters must ' +
                'have type annotations')
    return {k: create_datadescriptor(v) for k, v in annotations.items()}
Example #4
0
File: parser.py  Project: mfkiwl/dace
    def _create_sdfg_args(self, sdfg: SDFG, args: Tuple[Any], kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Construct the full argument dictionary for an SDFG call.

        Combines default arguments, positional and keyword call arguments,
        closure arrays, and symbols inferred from data descriptors.
        """
        # Defaults first; every later update overrides earlier entries
        result = dict(self.default_args)
        # Positional arguments become keyword arguments by parameter name
        result.update(zip(self.argnames, args))
        result.update(kwargs)

        # Arguments coming from the program's closure
        result.update(self.__sdfg_closure__())

        # Infer symbols appearing in data shapes from the actual arguments
        # (constants are excluded from descriptor creation)
        descriptors = {
            k: create_datadescriptor(v)
            for k, v in result.items() if k not in self.constant_args
        }
        result.update(infer_symbols_from_datadescriptor(sdfg, descriptors))
        return result
Example #5
0
File: parser.py  Project: sscholbe/dace
    def _get_type_annotations(
        self, given_args: Tuple[Any], given_kwargs: Dict[str, Any]
    ) -> Tuple[ArgTypes, Dict[str, Any], Dict[str, Any]]:
        """
        Obtains types from decorator and/or from type annotations in a function.

        Annotations take precedence over call-site arguments; constant-typed
        parameters are diverted into the global-variable mapping, and
        variable-length (starred) arguments are expanded into synthetic
        ``__arg<i>`` / ``__kwarg_<name>`` entries.

        :param given_args: The call-site arguments to the dace.program.
        :param given_kwargs: The call-site keyword arguments to the program.
        :return: A 3-tuple containing (argument type mapping, extra argument
                 mapping, extra global variable mapping)
        """
        types: ArgTypes = {}
        arg_mapping: Dict[str, Any] = {}
        gvar_mapping: Dict[str, Any] = {}

        # Filter symbols out of given keyword arguments
        given_kwargs = {
            k: v
            for k, v in given_kwargs.items() if k not in self.symbols
        }

        # Make argument mapping to either type annotation, given argument,
        # default argument, or ignore (symbols and constants).
        nargs = len(given_args)
        # Index of the next unconsumed positional argument in given_args
        arg_ind = 0
        # NOTE(review): the enumerate index `i` is unused in the loop body
        for i, (aname,
                sig_arg) in enumerate(self.signature.parameters.items()):
            if self.objname is not None and aname == self.objname:
                # Skip "self" argument
                continue

            ann = sig_arg.annotation

            # Variable-length arguments: obtain from the remainder of given_*
            if sig_arg.kind is sig_arg.VAR_POSITIONAL:
                vargs = given_args[arg_ind:]

                # If an annotation is given but the argument list is empty, fail
                if not _is_empty(ann) and len(vargs) == 0:
                    raise SyntaxError(
                        'Cannot compile DaCe program with type-annotated '
                        'variable-length (starred) arguments and no given '
                        'parameters. Please compile the program with arguments, '
                        'call it without annotations, or remove the starred '
                        f'arguments (invalid argument name: "{aname}").')

                # Expose each remaining positional as a synthetic __arg<j> name
                types.update({
                    f'__arg{j}': create_datadescriptor(varg)
                    for j, varg in enumerate(vargs)
                })
                arg_mapping.update(
                    {f'__arg{j}': varg
                     for j, varg in enumerate(vargs)})
                # The original starred name maps to the tuple of synthetic names
                gvar_mapping[aname] = tuple(f'__arg{j}'
                                            for j in range(len(vargs)))
                # Shift arg_ind to the end
                arg_ind = len(given_args)
            elif sig_arg.kind is sig_arg.VAR_KEYWORD:
                # Remaining keyword arguments (not already consumed above)
                # become synthetic __kwarg_<name> entries
                vargs = {
                    k: create_datadescriptor(v)
                    for k, v in given_kwargs.items() if k not in types
                }
                # If an annotation is given but the argument list is empty, fail
                if not _is_empty(ann) and len(vargs) == 0:
                    raise SyntaxError(
                        'Cannot compile DaCe program with type-annotated '
                        'variable-length (starred) keyword arguments and no given '
                        'parameters. Please compile the program with arguments, '
                        'call it without annotations, or remove the starred '
                        f'arguments (invalid argument name: "{aname}").')
                types.update({f'__kwarg_{k}': v for k, v in vargs.items()})
                arg_mapping.update(
                    {f'__kwarg_{k}': given_kwargs[k]
                     for k in vargs.keys()})
                gvar_mapping[aname] = {k: f'__kwarg_{k}' for k in vargs.keys()}
            # END OF VARIABLE-LENGTH ARGUMENTS
            else:
                # Regular arguments (annotations take precedence)
                curarg = None
                is_constant = False
                if not _is_empty(ann):
                    # If constant, use given argument
                    if ann is dtypes.constant:
                        curarg = None
                        is_constant = True
                    else:
                        curarg = ann

                # If no annotation is provided, use given arguments
                if sig_arg.kind is sig_arg.POSITIONAL_ONLY:
                    if arg_ind >= nargs:
                        # Positionals exhausted: fall back to the default value
                        if curarg is None and not _is_empty(sig_arg.default):
                            curarg = sig_arg.default
                        elif curarg is None:
                            raise SyntaxError(
                                'Not enough arguments given to program (missing '
                                f'argument: "{aname}").')
                    else:
                        if curarg is None:
                            curarg = given_args[arg_ind]
                        arg_ind += 1
                elif sig_arg.kind is sig_arg.POSITIONAL_OR_KEYWORD:
                    if arg_ind >= nargs:
                        # Try keyword arguments, then the default value
                        if aname not in given_kwargs:
                            if curarg is None and not _is_empty(
                                    sig_arg.default):
                                curarg = sig_arg.default
                            elif curarg is None:
                                raise SyntaxError(
                                    'Not enough arguments given to program (missing '
                                    f'argument: "{aname}").')
                        elif curarg is None:
                            curarg = given_kwargs[aname]
                    else:
                        if curarg is None:
                            curarg = given_args[arg_ind]
                        arg_ind += 1
                elif sig_arg.kind is sig_arg.KEYWORD_ONLY:
                    if aname not in given_kwargs:
                        if curarg is None and not _is_empty(sig_arg.default):
                            curarg = sig_arg.default
                        elif curarg is None:
                            raise SyntaxError(
                                'Not enough arguments given to program (missing '
                                f'argument: "{aname}").')
                    elif curarg is None:
                        curarg = given_kwargs[aname]

                if is_constant:
                    # Constants become compile-time globals, not arguments
                    gvar_mapping[aname] = curarg
                    continue  # Skip argument

                # Set type
                types[aname] = create_datadescriptor(curarg)

        # Set __return* arrays from return type annotations
        rettype = self.signature.return_annotation
        if not _is_empty(rettype):
            if isinstance(rettype, tuple):
                for i, subrettype in enumerate(rettype):
                    types[f'__return_{i}'] = create_datadescriptor(subrettype)
            else:
                types['__return'] = create_datadescriptor(rettype)

        return types, arg_mapping, gvar_mapping
Example #6
0
    def generate_pdp(self, *compilation_args):
        """Generate the parsed AST representation of a DaCe program.

        :param compilation_args: Various compilation arguments, e.g. dtypes,
                                 used when no type annotations are present.
        :return: A 2-tuple of (program, modules), where `program` is a
                 `dace.astnodes._ProgramNode` representing the parsed DaCe
                 program, and `modules` is a dictionary mapping imported
                 module names to their actual module names (for maintaining
                 import aliases).
        """
        dace_func = self.f
        decorator_args = self.args

        # Obtain types from annotations or decorator, if available
        argtypes = _get_type_annotations(dace_func, self.argnames,
                                         decorator_args)

        # Otherwise, parse argument types from the compilation call
        if len(inspect.getfullargspec(dace_func).args) > 0 and not argtypes:
            if not compilation_args:
                raise SyntaxError(
                    'DaCe program compilation requires either type annotations '
                    'or arrays')

            # Every parameter must receive one compilation argument
            if len(compilation_args) != len(self.argnames):
                raise SyntaxError(
                    'Arguments must match DaCe program parameters (expecting '
                    '%d)' % len(self.argnames))
            argtypes = {
                name: create_datadescriptor(arg)
                for name, arg in zip(self.argnames, compilation_args)
            }

        # Arguments to (nested) SDFGs cannot be transient
        for name, desc in argtypes.items():
            if desc.transient:
                nontransient = copy.deepcopy(desc)
                nontransient.transient = False
                argtypes[name] = nontransient
        #############################################

        # Globals allowed for inferring types and values in the DaCe program
        global_vars = copy.copy(self.global_vars)

        # Map module aliases to their actual module names
        modules = {
            name: mod.__name__
            for name, mod in global_vars.items() if dtypes.ismodule(mod)
        }
        modules['builtins'] = ''

        # Register symbols under their actual names (sym_0 etc.)
        global_vars.update({
            sym.name: sym
            for sym in global_vars.values() if isinstance(sym, symbolic.symbol)
        })

        # Allow SDFGs and DaceProgram objects
        # NOTE: These are the globals AT THE TIME OF INVOCATION, NOT DEFINITION
        other_sdfgs = {
            name: prog
            for name, prog in dace_func.__globals__.items()
            if isinstance(prog, (SDFG, DaceProgram))
        }

        # Parse the AST to create the SDFG
        return newast.parse_dace_program(dace_func, argtypes, global_vars,
                                         modules, other_sdfgs, self.kwargs)
Example #7
0
    def global_value_to_node(self,
                             value,
                             parent_node,
                             qualname,
                             recurse=False,
                             detect_callables=False):
        """Convert a global (closure) value into a replacement AST node.

        Depending on the value's type, the value is registered in the closure
        (SDFGs, constants, arrays, callbacks) and an AST node referring to it
        is returned; ``None`` is returned for values that cannot be
        converted.

        :param value: The global value to convert.
        :param parent_node: The AST node being replaced (used for source
                            location and, for constants, kept as ``oldnode``).
        :param qualname: Qualified name of the value in the original code.
        :param recurse: If False, lists/tuples are not converted at all.
        :param detect_callables: If True, also try to parse callables as DaCe
                                 programs, falling back to callbacks.
        :return: The replacement AST node, or None if not convertible.
        """
        # if recurse is false, we don't allow recursion into lists
        # this should not happen anyway; the globals dict should only contain
        # single "level" lists
        if not recurse and isinstance(value, (list, tuple)):
            # bail after more than one level of lists
            return None

        if isinstance(value, list):
            # Convert each element; give up if any element is unconvertible
            elts = [
                self.global_value_to_node(v,
                                          parent_node,
                                          qualname + f'[{i}]',
                                          detect_callables=detect_callables)
                for i, v in enumerate(value)
            ]
            if any(e is None for e in elts):
                return None
            newnode = ast.List(elts=elts, ctx=parent_node.ctx)
        elif isinstance(value, tuple):
            # Same as lists, but produce an ast.Tuple
            elts = [
                self.global_value_to_node(v,
                                          parent_node,
                                          qualname + f'[{i}]',
                                          detect_callables=detect_callables)
                for i, v in enumerate(value)
            ]
            if any(e is None for e in elts):
                return None
            newnode = ast.Tuple(elts=elts, ctx=parent_node.ctx)
        elif isinstance(value, symbolic.symbol):
            # Symbols resolve to the symbol name
            newnode = ast.Name(id=value.name, ctx=ast.Load())
        elif (dtypes.isconstant(value) or isinstance(value, SDFG)
              or hasattr(value, '__sdfg__')):
            # Could be a constant, an SDFG, or SDFG-convertible object
            if isinstance(value, SDFG) or hasattr(value, '__sdfg__'):
                self.closure.closure_sdfgs[qualname] = value
            else:
                self.closure.closure_constants[qualname] = value

            # Compatibility check since Python changed their AST nodes
            if sys.version_info >= (3, 8):
                newnode = ast.Constant(value=value, kind='')
            else:
                if value is None:
                    newnode = ast.NameConstant(value=None)
                elif isinstance(value, str):
                    newnode = ast.Str(s=value)
                else:
                    newnode = ast.Num(n=value)

            # Keep a copy of the replaced node for later reference
            newnode.oldnode = copy.deepcopy(parent_node)

        elif detect_callables and hasattr(value, '__call__') and hasattr(
                value.__call__, '__sdfg__'):
            # SDFG-convertible callable object: convert its __call__ instead
            return self.global_value_to_node(value.__call__, parent_node,
                                             qualname, recurse,
                                             detect_callables)
        elif isinstance(value, numpy.ndarray):
            # Arrays need to be stored as a new name and fed as an argument
            if id(value) in self.closure.array_mapping:
                arrname = self.closure.array_mapping[id(value)]
            else:
                arrname = self._qualname_to_array_name(qualname)
                desc = data.create_datadescriptor(value)
                # NOTE: the lambda re-evaluates the qualified name at call time
                self.closure.closure_arrays[arrname] = (
                    qualname, desc, lambda: eval(qualname, self.globals),
                    False)
                self.closure.array_mapping[id(value)] = arrname

            newnode = ast.Name(id=arrname, ctx=ast.Load())
        elif detect_callables and callable(value):
            # Try parsing the function as a dace function/method
            newnode = None
            try:
                from dace.frontend.python import parser  # Avoid import loops

                parent_object = None
                if hasattr(value, '__self__'):
                    parent_object = value.__self__

                # If it is a callable object
                if (not inspect.isfunction(value)
                        and not inspect.ismethod(value)
                        and not inspect.isbuiltin(value)
                        and hasattr(value, '__call__')):
                    parent_object = value
                    value = value.__call__

                # Replacements take precedence over auto-parsing
                try:
                    if has_replacement(value, parent_object, parent_node):
                        return None
                except Exception:
                    pass

                # Store the handle to the original callable, in case parsing fails
                cbqualname = astutils.rname(parent_node)
                cbname = self._qualname_to_array_name(cbqualname, prefix='')
                self.closure.callbacks[cbname] = (cbqualname, value, False)

                # From this point on, any failure will result in a callback
                newnode = ast.Name(id=cbname, ctx=ast.Load())

                # Decorated or functions with missing source code
                sast, _, _, _ = astutils.function_to_ast(value)
                if len(sast.body[0].decorator_list) > 0:
                    return newnode

                parsed = parser.DaceProgram(value, [], {}, False,
                                            dtypes.DeviceType.CPU)
                # If method, add the first argument (which disappears due to
                # being a bound method) and the method's object
                if parent_object is not None:
                    parsed.methodobj = parent_object
                    parsed.objname = inspect.getfullargspec(value).args[0]

                res = self.global_value_to_node(parsed, parent_node, qualname,
                                                recurse, detect_callables)
                # Keep callback in callbacks in case of parsing failure
                # del self.closure.callbacks[cbname]
                return res
            except Exception:  # Parsing failed (almost any exception can occur)
                # Deliberate broad catch: fall back to the callback Name node
                # (or None if failure happened before it was created)
                return newnode
        else:
            return None

        if parent_node is not None:
            return ast.copy_location(newnode, parent_node)
        else:
            return newnode
Example #8
0
def validate_sdfg(sdfg: 'dace.sdfg.SDFG'):
    """ Verifies the correctness of an SDFG by applying multiple tests.
        :param sdfg: The SDFG to verify.

        Raises an InvalidSDFGError with the erroneous node/edge
        on failure.
    """
    # Avoid import loop
    from dace.codegen.targets import fpga

    try:
        # SDFG-level checks
        if not dtypes.validate_name(sdfg.name):
            raise InvalidSDFGError("Invalid name", sdfg, None)

        if len(sdfg.source_nodes()) > 1 and sdfg.start_state is None:
            raise InvalidSDFGError("Starting state undefined", sdfg, None)

        # State labels must be unique
        if len({s.label for s in sdfg.nodes()}) != len(sdfg.nodes()):
            raise InvalidSDFGError("Found multiple states with the same name",
                                   sdfg, None)

        # Validate data descriptors
        for name, desc in sdfg._arrays.items():
            # Validate array names
            if name is not None and not dtypes.validate_name(name):
                raise InvalidSDFGError("Invalid array name %s" % name, sdfg,
                                       None)
            # Allocation lifetime checks
            if (desc.lifetime is dtypes.AllocationLifetime.Persistent
                    and desc.storage is dtypes.StorageType.Register):
                raise InvalidSDFGError(
                    "Array %s cannot be both persistent and use Register as "
                    "storage type. Please use a different storage location." %
                    name, sdfg, None)

            # Check for valid bank assignments
            try:
                bank_assignment = fpga.parse_location_bank(desc)
            except ValueError as e:
                # Chain the original error for a clearer traceback
                raise InvalidSDFGError(str(e), sdfg, None) from e
            if bank_assignment is not None:
                if bank_assignment[0] == "DDR" or bank_assignment[0] == "HBM":
                    # The bank specifier must parse as a subset range
                    # (result unused; the call is made only for validation)
                    try:
                        subsets.Range.from_string(bank_assignment[1])
                    except SyntaxError as e:
                        raise InvalidSDFGError(
                            "Memory bank specifier must be convertible to subsets.Range"
                            f" for array {name}", sdfg, None) from e
                    try:
                        low, high = fpga.get_multibank_ranges_from_subset(
                            bank_assignment[1], sdfg)
                    except ValueError as e:
                        raise InvalidSDFGError(str(e), sdfg, None) from e
                    if (high - low < 1):
                        raise InvalidSDFGError(
                            "Memory bank specifier must at least define one bank to be used"
                            f" for array {name}", sdfg, None)
                    if (high - low > 1 and
                        (high - low != desc.shape[0] or len(desc.shape) < 2)):
                        raise InvalidSDFGError(
                            "Arrays that use a multibank access pattern must have the size of the first dimension equal"
                            f" the number of banks and have at least 2 dimensions for array {name}",
                            sdfg, None)

        # Check every state separately
        start_state = sdfg.start_state
        initialized_transients = {'__pystate'}
        symbols = copy.deepcopy(sdfg.symbols)
        symbols.update(sdfg.arrays)
        symbols.update({
            k: dt.create_datadescriptor(v)
            for k, v in sdfg.constants.items()
        })
        # Free symbols appearing in data descriptors are also defined
        for desc in sdfg.arrays.values():
            for sym in desc.free_symbols:
                symbols[str(sym)] = sym.dtype
        visited = set()
        visited_edges = set()
        # Run through states via DFS, ensuring that only the defined symbols
        # are available for validation
        for edge in sdfg.dfs_edges(start_state):
            # Source -> inter-state definition -> Destination
            ##########################################
            visited_edges.add(edge)
            # Source
            if edge.src not in visited:
                visited.add(edge.src)
                validate_state(edge.src, sdfg.node_id(edge.src), sdfg, symbols,
                               initialized_transients)

            ##########################################
            # Edge
            # Check inter-state edge for undefined symbols
            undef_syms = set(edge.data.free_symbols) - set(symbols.keys())
            if len(undef_syms) > 0:
                eid = sdfg.edge_id(edge)
                raise InvalidSDFGInterstateEdgeError(
                    "Undefined symbols in edge: %s" % undef_syms, sdfg, eid)

            # Validate inter-state edge names
            issyms = edge.data.new_symbols(sdfg, symbols)
            if any(not dtypes.validate_name(s) for s in issyms):
                invalid = next(s for s in issyms
                               if not dtypes.validate_name(s))
                eid = sdfg.edge_id(edge)
                raise InvalidSDFGInterstateEdgeError(
                    "Invalid interstate symbol name %s" % invalid, sdfg, eid)

            # Add edge symbols into defined symbols
            symbols.update(issyms)

            ##########################################
            # Destination
            if edge.dst not in visited:
                visited.add(edge.dst)
                validate_state(edge.dst, sdfg.node_id(edge.dst), sdfg, symbols,
                               initialized_transients)
        # End of state DFS

        # If there is only one state, the DFS will miss it
        if start_state not in visited:
            validate_state(start_state, sdfg.node_id(start_state), sdfg,
                           symbols, initialized_transients)

        # Validate all inter-state edges (including self-loops not found by DFS)
        for eid, edge in enumerate(sdfg.edges()):
            if edge in visited_edges:
                continue
            issyms = edge.data.assignments.keys()
            if any(not dtypes.validate_name(s) for s in issyms):
                invalid = next(s for s in issyms
                               if not dtypes.validate_name(s))
                raise InvalidSDFGInterstateEdgeError(
                    "Invalid interstate symbol name %s" % invalid, sdfg, eid)

    except InvalidSDFGError as ex:
        # If the SDFG is invalid, save it for offline debugging
        sdfg.save(os.path.join('_dacegraphs', 'invalid.sdfg'), exception=ex)
        raise
Example #9
0
File: parser.py  Project: thobauma/dace
    def _generate_pdp(self,
                      args,
                      kwargs,
                      strict=None) -> Tuple[SDFG, Dict[str, str]]:
        """ Generates the parsed AST representation of a DaCe program.
            :param args: The given arguments to the program.
            :param kwargs: The given keyword arguments to the program.
            :param strict: Whether to apply strict transforms when parsing
                           nested dace programs.
            :return: A 2-tuple of (program, modules), where `program` is a
                     `dace.astnodes._ProgramNode` representing the parsed DaCe
                     program, and `modules` is a dictionary mapping imported
                     module names to their actual module names (for maintaining
                     import aliases).
        """
        dace_func = self.f

        # If exist, obtain type annotations (for compilation)
        argtypes, _, gvars = self._get_type_annotations(args, kwargs)

        # Move "self" from an argument into the closure
        if self.methodobj is not None:
            self.global_vars[self.objname] = self.methodobj

        # Parse argument types from call
        if len(self.argnames) > 0:
            if not argtypes:
                if not args and not kwargs:
                    raise SyntaxError(
                        'Compiling DaCe programs requires static types. '
                        'Please provide type annotations on the function, '
                        'or add sample arguments to the compilation call.')

                # Parse compilation arguments
                # (defaults first, then positional, then keyword arguments)
                argtypes = {
                    k: create_datadescriptor(v)
                    for k, v in itertools.chain(self.default_args.items(
                    ), zip(self.argnames, args), kwargs.items())
                }
                if len(argtypes) != len(self.argnames):
                    raise SyntaxError(
                        'Number of arguments must match parameters '
                        f'(expecting {self.argnames}, got {list(argtypes.keys())})'
                    )

        for k, v in argtypes.items():
            if v.transient:  # Arguments to (nested) SDFGs cannot be transient
                v_cpy = copy.deepcopy(v)
                v_cpy.transient = False
                argtypes[k] = v_cpy

        #############################################

        # Parse allowed global variables
        # (for inferring types and values in the DaCe program)
        global_vars = copy.copy(self.global_vars)

        # Remove None arguments and make into globals that can be folded
        removed_args = set()
        for k, v in argtypes.items():
            if v.dtype.type is None:
                global_vars[k] = None
                removed_args.add(k)
        argtypes = {
            k: v
            for k, v in argtypes.items() if v.dtype.type is not None
        }

        # Set module aliases to point to their actual names
        modules = {
            k: v.__name__
            for k, v in global_vars.items() if dtypes.ismodule(v)
        }
        modules['builtins'] = ''

        # Add symbols as globals with their actual names (sym_0 etc.)
        global_vars.update({
            v.name: v
            for _, v in global_vars.items() if isinstance(v, symbolic.symbol)
        })
        # Free symbols of argument data descriptors (e.g., symbolic shapes)
        # are also made available as globals
        for argtype in argtypes.values():
            global_vars.update({v.name: v for v in argtype.free_symbols})

        # Add constant arguments to global_vars
        global_vars.update(gvars)

        # Parse AST to create the SDFG
        sdfg, closure = newast.parse_dace_program(
            dace_func,
            self.name,
            argtypes,
            global_vars,
            modules,
            self.dec_kwargs,
            strict=strict,
            resolve_functions=self.resolve_functions)

        # Set SDFG argument names, filtering out constants
        sdfg.arg_names = [a for a in self.argnames if a in argtypes]

        # Create new argument mapping from closure arrays
        # NOTE(review): closure_arrays values are unpacked as 2-tuples here;
        # confirm against the closure type returned by parse_dace_program.
        arg_mapping = {
            v: k
            for k, (v, _) in closure.closure_arrays.items()
            if isinstance(v, str)
        }
        arg_mapping.update({
            k: v
            for k, (v, _) in closure.closure_arrays.items()
            if not isinstance(v, str)
        })
        self.closure_arg_mapping = arg_mapping
        self.closure_array_keys = set(
            closure.closure_arrays.keys()) - removed_args
        self.closure_constant_keys = set(
            closure.closure_constants.keys()) - removed_args

        return sdfg, arg_mapping
Example #10
0
    def _generate_pdp(self, args, kwargs, strict=None):
        """ Generates the parsed AST representation of a DaCe program.

            :param args: The given positional arguments to the program.
            :param kwargs: The given keyword arguments to the program.
            :param strict: Whether to apply strict transformations when
                           parsing nested dace programs.
            :return: The SDFG object parsed from the DaCe program.
            :raises SyntaxError: If no static types can be determined, or if
                                 the number of arguments does not match the
                                 program parameters.
        """
        dace_func = self.f

        # If they exist, obtain type annotations (for compilation)
        argtypes, _, gvars = self._get_type_annotations(args, kwargs)

        # If no annotations were given, derive types from the actual call
        if len(self.argnames) > 0:
            if not argtypes:
                if not args and not kwargs:
                    raise SyntaxError(
                        'Compiling DaCe programs requires static types. '
                        'Please provide type annotations on the function, '
                        'or add sample arguments to the compilation call.')

                # Parse compilation arguments: defaults first, then
                # positional, then keyword arguments (later entries override)
                argtypes = {
                    k: create_datadescriptor(v)
                    for k, v in itertools.chain(self.default_args.items(),
                                                zip(self.argnames, args),
                                                kwargs.items())
                }
                if len(argtypes) != len(self.argnames):
                    raise SyntaxError(
                        'Number of arguments must match parameters '
                        f'(expecting {self.argnames}, got {list(argtypes.keys())})'
                    )

        # Arguments to (nested) SDFGs cannot be transient; replace such
        # descriptors with non-transient deep copies so the caller's
        # original descriptors remain untouched
        for k, v in argtypes.items():
            if v.transient:
                v_cpy = copy.deepcopy(v)
                v_cpy.transient = False
                argtypes[k] = v_cpy

        #############################################

        # Parse allowed global variables
        # (for inferring types and values in the DaCe program)
        global_vars = copy.copy(self.global_vars)

        # Remove None arguments and make into globals that can be folded
        for k, v in argtypes.items():
            if v.dtype.type is None:
                global_vars[k] = None
        argtypes = {
            k: v
            for k, v in argtypes.items() if v.dtype.type is not None
        }

        # Set module aliases to point to their actual names
        modules = {
            k: v.__name__
            for k, v in global_vars.items() if dtypes.ismodule(v)
        }
        modules['builtins'] = ''

        # Add symbols as globals with their actual names (sym_0 etc.)
        global_vars.update({
            v.name: v
            for v in global_vars.values() if isinstance(v, symbolic.symbol)
        })
        for argtype in argtypes.values():
            global_vars.update({v.name: v for v in argtype.free_symbols})

        # Add constant arguments to global_vars
        global_vars.update(gvars)

        # Allow SDFGs and DaceProgram objects
        # NOTE: These are the globals AT THE TIME OF INVOCATION, NOT DEFINITION
        other_sdfgs = {
            k: v
            for k, v in _get_locals_and_globals(dace_func).items()
            if isinstance(v, (SDFG, DaceProgram))
        }

        # Parse AST to create the SDFG
        sdfg = newast.parse_dace_program(dace_func,
                                         self.name,
                                         argtypes,
                                         global_vars,
                                         modules,
                                         other_sdfgs,
                                         self.dec_kwargs,
                                         strict=strict)

        # Set SDFG argument names, filtering out constants
        sdfg.arg_names = [a for a in self.argnames if a in argtypes]

        return sdfg
示例#11
0
def validate_sdfg(sdfg: 'dace.sdfg.SDFG'):
    """ Verifies the correctness of an SDFG by applying multiple tests.

        Checks performed, in order: SDFG name validity, start-state
        presence, state-label uniqueness, data-descriptor naming and
        allocation-lifetime/storage compatibility, per-state validation
        (delegated to ``validate_state``) along a DFS traversal, and
        inter-state edge symbol/name validation.

        :param sdfg: The SDFG to verify.

        Raises an InvalidSDFGError (or the InvalidSDFGInterstateEdgeError
        subclass) with the erroneous node/edge on failure. On failure, the
        invalid SDFG is also saved to ``_dacegraphs/invalid.sdfg`` before
        the exception propagates.
    """
    try:
        # SDFG-level checks
        if not dtypes.validate_name(sdfg.name):
            raise InvalidSDFGError("Invalid name", sdfg, None)

        # With multiple entry points, an explicit start state is mandatory
        if len(sdfg.source_nodes()) > 1 and sdfg.start_state is None:
            raise InvalidSDFGError("Starting state undefined", sdfg, None)

        # State labels must be unique across the SDFG
        if len(set([s.label for s in sdfg.nodes()])) != len(sdfg.nodes()):
            raise InvalidSDFGError("Found multiple states with the same name",
                                   sdfg, None)

        # Validate data descriptors
        for name, desc in sdfg._arrays.items():
            # Validate array names
            if name is not None and not dtypes.validate_name(name):
                raise InvalidSDFGError("Invalid array name %s" % name, sdfg,
                                       None)
            # Allocation lifetime checks
            if (desc.lifetime is dtypes.AllocationLifetime.Persistent
                    and desc.storage is dtypes.StorageType.Register):
                raise InvalidSDFGError(
                    "Array %s cannot be both persistent and use Register as "
                    "storage type. Please use a different storage location." %
                    name, sdfg, None)

        # Check every state separately
        start_state = sdfg.start_state
        # Seed the defined-symbol table with SDFG symbols, arrays, constants,
        # and the free symbols of all array descriptors. Deep copy so later
        # updates don't mutate sdfg.symbols.
        symbols = copy.deepcopy(sdfg.symbols)
        symbols.update(sdfg.arrays)
        symbols.update({
            k: dt.create_datadescriptor(v)
            for k, v in sdfg.constants.items()
        })
        for desc in sdfg.arrays.values():
            for sym in desc.free_symbols:
                symbols[str(sym)] = sym.dtype
        visited = set()
        visited_edges = set()
        # Run through states via DFS, ensuring that only the defined symbols
        # are available for validation
        for edge in sdfg.dfs_edges(start_state):
            # Source -> inter-state definition -> Destination
            ##########################################
            visited_edges.add(edge)
            # Source
            if edge.src not in visited:
                visited.add(edge.src)
                validate_state(edge.src, sdfg.node_id(edge.src), sdfg, symbols)

            ##########################################
            # Edge
            # Check inter-state edge for undefined symbols
            undef_syms = set(edge.data.free_symbols) - set(symbols.keys())
            if len(undef_syms) > 0:
                eid = sdfg.edge_id(edge)
                raise InvalidSDFGInterstateEdgeError(
                    "Undefined symbols in edge: %s" % undef_syms, sdfg, eid)

            # Validate inter-state edge names
            issyms = edge.data.new_symbols(symbols)
            if any(not dtypes.validate_name(s) for s in issyms):
                invalid = next(s for s in issyms
                               if not dtypes.validate_name(s))
                eid = sdfg.edge_id(edge)
                raise InvalidSDFGInterstateEdgeError(
                    "Invalid interstate symbol name %s" % invalid, sdfg, eid)

            # Add edge symbols into defined symbols
            # (symbol definitions accumulate along the DFS order)
            symbols.update(issyms)

            ##########################################
            # Destination
            if edge.dst not in visited:
                visited.add(edge.dst)
                validate_state(edge.dst, sdfg.node_id(edge.dst), sdfg, symbols)
        # End of state DFS

        # If there is only one state, the DFS will miss it
        if start_state not in visited:
            validate_state(start_state, sdfg.node_id(start_state), sdfg,
                           symbols)

        # Validate all inter-state edges (including self-loops not found by DFS)
        # NOTE(review): uses the enumeration index as the edge ID — assumes
        # it matches sdfg.edge_id ordering; confirm against the SDFG API.
        for eid, edge in enumerate(sdfg.edges()):
            if edge in visited_edges:
                continue
            issyms = edge.data.assignments.keys()
            if any(not dtypes.validate_name(s) for s in issyms):
                invalid = next(s for s in issyms
                               if not dtypes.validate_name(s))
                raise InvalidSDFGInterstateEdgeError(
                    "Invalid interstate symbol name %s" % invalid, sdfg, eid)

    except InvalidSDFGError as ex:
        # If the SDFG is invalid, save it for offline inspection, then
        # re-raise the original error
        sdfg.save(os.path.join('_dacegraphs', 'invalid.sdfg'), exception=ex)
        raise