Example #1
0
    def validate(self, sdfg, state):
        """ Validates this nested SDFG node, then its nested SDFG.

            Checks that the node label and every connector name are valid
            identifiers; that each non-scalar, non-transient data descriptor
            of the nested SDFG is exposed as a connector, and that no
            connector maps to a transient descriptor; and that every free
            symbol of the nested SDFG that is not a connector appears in
            ``self.symbol_mapping``. Finally calls ``self.sdfg.validate()``.

            :param sdfg: The parent SDFG (not used by these checks).
            :param state: The parent state (not used by these checks).
            :raises NameError: On an invalid name or a connector/descriptor
                               mismatch.
            :raises ValueError: If free symbols are missing from the mapping.
        """
        if not dtypes.validate_name(self.label):
            raise NameError('Invalid nested SDFG name "%s"' % self.label)
        for conn in self.in_connectors:
            if dtypes.validate_name(conn):
                continue
            raise NameError('Invalid input connector "%s"' % conn)
        for conn in self.out_connectors:
            if dtypes.validate_name(conn):
                continue
            raise NameError('Invalid output connector "%s"' % conn)

        connectors = self.in_connectors.keys() | self.out_connectors.keys()
        for dname, desc in self.sdfg.arrays.items():
            # TODO(later): Disallow scalars without access nodes (so that this
            #              check passes for them too).
            if isinstance(desc, data.Scalar):
                continue
            is_connector = dname in connectors
            # Non-transient data must come in/out through a connector
            if not desc.transient and not is_connector:
                raise NameError('Data descriptor "%s" not found in nested '
                                'SDFG connectors' % dname)
            # ...and a connector must never refer to transient data
            if is_connector and desc.transient:
                raise NameError(
                    '"%s" is a connector but its corresponding array is transient'
                    % dname)

        # Validate undefined symbols: every free symbol that is not a
        # connector must be mapped from the outer scope
        unmapped = {k for k in self.sdfg.free_symbols if k not in connectors}
        missing_symbols = [s for s in unmapped if s not in self.symbol_mapping]
        if missing_symbols:
            raise ValueError('Missing symbols on nested SDFG: %s' %
                             (missing_symbols))

        # Recursively validate nested SDFG
        self.sdfg.validate()
Example #2
0
 def validate(self, sdfg, state):
     """ Checks that the tasklet label and all connector names are valid.

         :param sdfg: The parent SDFG (not used by these checks).
         :param state: The parent state (not used by these checks).
         :raises NameError: On the first invalid label or connector name.
     """
     label_ok = dtypes.validate_name(self.label)
     if not label_ok:
         raise NameError('Invalid tasklet name "%s"' % self.label)
     for conn in self.in_connectors:
         if dtypes.validate_name(conn):
             continue
         raise NameError('Invalid input connector "%s"' % conn)
     for conn in self.out_connectors:
         if dtypes.validate_name(conn):
             continue
         raise NameError('Invalid output connector "%s"' % conn)
Example #3
0
    def _parse_from_subexpr(self, expr: str):
        """
        Parses a single memlet sub-expression of the form ``ARRAY`` or
        ``ARRAY[SUBSET]``.

        :param expr: The sub-expression to parse.
        :return: A tuple ``(array_name, subset)``. ``subset`` is None when
                 no subscript is given; otherwise it is the result of
                 ``SubsetProperty.from_string``.
        :raises SyntaxError: If the expression is malformed (e.g., empty,
                             missing an opening bracket, or with more than
                             one opening bracket) or the array name is
                             invalid.
        """
        # ``endswith`` (rather than ``expr[-1]``) also handles the empty string
        if not expr.endswith(']'):  # No subset given, try to use whole array
            if not dtypes.validate_name(expr):
                raise SyntaxError('Invalid memlet syntax "%s"' % expr)
            return expr, None

        # array[subset] syntax: exactly one opening bracket is supported
        arrname, bracket, subset_str = expr[:-1].partition('[')
        if not bracket or '[' in subset_str:
            raise SyntaxError('Invalid memlet syntax "%s"' % expr)
        if not dtypes.validate_name(arrname):
            raise SyntaxError('Invalid array name "%s" in memlet' % arrname)
        return arrname, SubsetProperty.from_string(subset_str)
Example #4
0
    def __new__(cls, name=None, dtype=DEFAULT_SYMBOL_TYPE, **assumptions):
        """ Creates a new symbol.

            :param name: The symbol name. If None, a name of the form
                         ``sym_<counter>`` is generated from
                         ``symbol.s_currentsymbol``, which is then incremented.
            :param dtype: The DaCe typeclass of the symbol.
            :param assumptions: Extra sympy assumptions. If ``integer`` is not
                                among them and ``str(dtype)`` contains
                                ``int``, ``integer=True`` is added.
            :raises NameError: If the name starts with ``__DACE`` or is not a
                               valid identifier.
            :raises TypeError: If ``dtype`` is not a ``dtypes.typeclass``.
        """
        if name is not None:
            if name.startswith('__DACE'):
                raise NameError('Symbols cannot start with __DACE')
            if not dtypes.validate_name(name):
                raise NameError('Invalid symbol name "%s"' % name)
        else:
            # Set name dynamically from the class-wide counter
            name = "sym_" + str(symbol.s_currentsymbol)
            symbol.s_currentsymbol += 1

        if not isinstance(dtype, dtypes.typeclass):
            raise TypeError('dtype must be a DaCe type, got %s' % str(dtype))

        # Integer-like dtypes imply the sympy integer assumption, unless the
        # caller already specified it explicitly
        implied = {}
        if 'integer' not in assumptions and 'int' in str(dtype):
            implied['integer'] = True

        # Using __xnew__ as the regular __new__ is cached, which leads
        # to modifying different references of symbols with the same name.
        obj = sympy.Symbol.__xnew__(cls, name, **implied, **assumptions)

        obj.dtype = dtype
        obj._constraints = []
        obj.value = None
        return obj
Example #5
0
    def _parse_memlet_from_str(self, expr: str):
        """
        Parses a memlet string and sets this memlet's ``data`` and ``subset``
        fields, plus ``other_subset`` for the ``->`` form.
        :param expr: A string expression of this memlet, given as an ease
                of use API. Must follow one of the following forms:
                1. ``ARRAY``,
                2. ``ARRAY[SUBSET]``,
                3. ``ARRAY[SUBSET] -> OTHER_SUBSET``.
                Note that modes 2 and 3 are deprecated and will leave
                the memlet uninitialized until inserted into an SDFG.
        :raises SyntaxError: If the expression is malformed, e.g., it contains
                             more than one ``->`` or the left-hand side of
                             ``->`` has no valid data name.
        """
        expr = expr.strip()
        if '->' not in expr:  # Options 1 and 2
            self.data, self.subset = self._parse_from_subexpr(expr)
            return

        # Option 3: exactly one arrow separates the two subsets
        parts = expr.split('->')
        if len(parts) != 2:
            raise SyntaxError('Invalid memlet syntax "%s"' % expr)
        src_expr, dst_expr = parts[0].strip(), parts[1].strip()
        if '[' not in src_expr and not dtypes.validate_name(src_expr):
            raise SyntaxError('Expression without data name not yet allowed')

        self.data, self.subset = self._parse_from_subexpr(src_expr)
        self.other_subset = SubsetProperty.from_string(dst_expr)
Exemple #6
0
def pystr_to_symbolic(expr, symbol_map=None, simplify=None):
    """ Takes a Python string and converts it into a symbolic expression.

    Existing SymExpr/sympy expressions are returned unchanged, and strings
    that are valid names become a ``symbol``. Anything else is parsed with
    ``sympy.sympify``.

    :param expr: The expression to convert.
    :param symbol_map: Optional mapping forwarded to ``sympy_to_dace``.
    :param simplify: Forwarded as ``evaluate`` to ``sympy.sympify``.
    :return: The converted symbolic expression.
    """
    from dace.frontend.python.astutils import unparse  # Avoid import loops

    if isinstance(expr, (SymExpr, sympy.Basic)):
        return expr
    if isinstance(expr, str) and dtypes.validate_name(expr):
        return symbol(expr)

    symbol_map = symbol_map or {}
    # Named ``sympy_locals`` to avoid shadowing the ``locals`` builtin
    sympy_locals = {'min': sympy.Min, 'max': sympy.Max}
    # _clash1 enables all one-letter variables like N as symbols
    # _clash also allows pi, beta, zeta and other common greek letters
    sympy_locals.update(sympy.abc._clash)

    # Sympy processes "not/and/or" as direct evaluation. Replace with
    # And/Or(x, y), Not(x)
    if isinstance(expr, str) and re.search(r'\bnot\b|\band\b|\bor\b', expr):
        expr = unparse(SympyBooleanConverter().visit(ast.parse(expr).body[0]))

    # TODO: support SymExpr over-approximated expressions
    try:
        return sympy_to_dace(
            sympy.sympify(expr, sympy_locals, evaluate=simplify), symbol_map)
    except TypeError:  # Symbol object is not subscriptable
        # The subscript-rewriting fallback only applies to strings; re-raise
        # the original error instead of failing on ``expr.replace``
        if not isinstance(expr, str):
            raise
        # Replace subscript expressions with function calls
        expr = expr.replace('[', '(').replace(']', ')')
        return sympy_to_dace(
            sympy.sympify(expr, sympy_locals, evaluate=simplify), symbol_map)
Example #7
0
def pystr_to_symbolic(expr, symbol_map=None, simplify=None) -> sympy.Basic:
    """ Takes a Python string and converts it into a symbolic expression.

    Existing SymExpr/sympy expressions are returned unchanged. Strings that
    parse with ``int``/``float`` become ``sympy.Integer``/``sympy.Float``,
    and strings that are valid names become a ``symbol``. Anything else is
    parsed with ``sympy.sympify``. Before that, strings containing
    ``not``/``and``/``or``/``None``/``==``/``!=`` are rewritten through
    ``SympyBooleanConverter``.

    :param expr: The expression to convert.
    :param symbol_map: Optional mapping forwarded to ``sympy_to_dace``.
    :param simplify: Forwarded as ``evaluate`` to ``sympy.sympify``.
    :return: The converted symbolic expression.
    """
    from dace.frontend.python.astutils import unparse  # Avoid import loops

    if isinstance(expr, (SymExpr, sympy.Basic)):
        return expr
    if isinstance(expr, str):
        try:
            return sympy.Integer(int(expr))
        except ValueError:
            pass
        try:
            return sympy.Float(float(expr))
        except ValueError:
            pass
        if dtypes.validate_name(expr):
            return symbol(expr)

    symbol_map = symbol_map or {}
    # Named ``sympy_locals`` to avoid shadowing the ``locals`` builtin
    sympy_locals = {
        'abs': sympy.Abs,
        'min': sympy.Min,
        'max': sympy.Max,
        'True': sympy.true,
        'False': sympy.false,
        'GtE': sympy.Ge,
        'LtE': sympy.Le,
        'NotEq': sympy.Ne,
        'floor': sympy.floor,
        'ceil': sympy.ceiling,
        'round': ROUND,
        # Convert and/or to special sympy functions to avoid boolean evaluation
        'And': AND,
        'Or': OR,
        'var': sympy.Symbol('var'),
        'root': sympy.Symbol('root'),
        'arg': sympy.Symbol('arg'),
    }
    # _clash1 enables all one-letter variables like N as symbols
    # _clash also allows pi, beta, zeta and other common greek letters
    sympy_locals.update(_sympy_clash)

    # Sympy processes "not/and/or" as direct evaluation. Replace with
    # And/Or(x, y), Not(x)
    if isinstance(expr, str) and re.search(
            r'\bnot\b|\band\b|\bor\b|\bNone\b|==|!=', expr):
        expr = unparse(SympyBooleanConverter().visit(ast.parse(expr).body[0]))

    # TODO: support SymExpr over-approximated expressions
    try:
        return sympy_to_dace(
            sympy.sympify(expr, sympy_locals, evaluate=simplify), symbol_map)
    except (TypeError, sympy.SympifyError):
        # E.g., "Symbol object is not subscriptable". The subscript-rewriting
        # fallback only applies to strings; re-raise the original error
        # instead of failing on ``expr.replace``
        if not isinstance(expr, str):
            raise
        # Replace subscript expressions with function calls
        expr = expr.replace('[', '(').replace(']', ')')
        return sympy_to_dace(
            sympy.sympify(expr, sympy_locals, evaluate=simplify), symbol_map)
Example #8
0
 def _qualname_to_array_name(self, qualname: str) -> str:
     """ Converts a Python qualified attribute name to an SDFG array name.

         Each ``.``, ``[``, ``]``, quote, or comma becomes an underscore, and
         the result is prefixed with ``__g_``.

         :param qualname: The qualified name (attributes/subscripts only).
         :return: The sanitized array name.
         :raises NameError: If the sanitized name is still not a valid name.
     """
     # We only support attributes and subscripts for now
     sanitized = re.sub(r'[\.\[\]\'\",]', '_', qualname)
     if dtypes.validate_name(sanitized):
         return f"__g_{sanitized}"
     raise NameError(
         f'Variable name "{sanitized}" is not sanitized '
         'properly during parsing. Please report this issue.')
        def _replace_assignment(v, repl):
            """ Applies the name replacements in ``repl`` to the string form
                of ``v`` and returns the resulting source text.
            """
            v = str(v)
            if not v:
                return v
            # Fast path: a bare name that is itself a replacement key
            if dtypes.validate_name(v) and v in repl:
                return repl[v]

            # General case: rewrite names in the parsed AST, then unparse
            parsed = ast.parse(v)
            rewritten = astutils.ASTFindReplace(repl).visit(parsed)
            return astutils.unparse(rewritten)
Example #10
0
def free_symbols_and_functions(expr: Union[SymbolicType, str]) -> Set[str]:
    """ Returns the names of free symbols and user-defined functions in an
        expression.

        :param expr: A symbolic expression or a string. A string that is a
                     valid name is returned as-is in a singleton set; any
                     other string is converted with ``pystr_to_symbolic``.
        :return: The set of symbol names plus names of user functions
                 (per ``is_sympy_userfunction``) that are not listed in
                 ``_builtin_userfunctions``. Non-sympy inputs yield an
                 empty set.
    """
    if isinstance(expr, str):
        if dtypes.validate_name(expr):
            return {expr}
        expr = pystr_to_symbolic(expr)
    if not isinstance(expr, sympy.Basic):
        return set()

    names = set(map(str, expr.free_symbols))
    names.update(
        str(node.func) for node in swalk(expr)
        if is_sympy_userfunction(node)
        and str(node.func) not in _builtin_userfunctions)
    return names
Example #11
0
 def validate(self, sdfg, state, node):
     """ Raises NameError if this consume scope's label is not a valid name.

         :param sdfg: The parent SDFG (unused).
         :param state: The parent state (unused).
         :param node: The associated node (unused).
     """
     if dtypes.validate_name(self.label):
         return
     raise NameError('Invalid consume name "%s"' % self.label)
Example #12
0
def validate_sdfg(sdfg: 'dace.sdfg.SDFG'):
    """ Verifies the correctness of an SDFG by applying multiple tests.
        :param sdfg: The SDFG to verify.

        Raises an InvalidSDFGError with the erroneous node/edge
        on failure. Before re-raising, the SDFG is saved to
        ``_dacegraphs/invalid.sdfg`` together with the exception.
    """
    try:
        # SDFG-level checks
        if not validate_name(sdfg.name):
            raise InvalidSDFGError("Invalid name", sdfg, None)

        if len(sdfg.source_nodes()) > 1 and sdfg.start_state is None:
            raise InvalidSDFGError("Starting state undefined", sdfg, None)

        if len({s.label for s in sdfg.nodes()}) != len(sdfg.nodes()):
            raise InvalidSDFGError("Found multiple states with the same name",
                                   sdfg, None)

        # Validate array names
        for name in sdfg._arrays:
            if name is not None and not validate_name(name):
                raise InvalidSDFGError("Invalid array name %s" % name, sdfg,
                                       None)

        # Check every state separately
        start_state = sdfg.start_state
        # Names visible to every state: SDFG symbols, data containers,
        # constants, and the free symbols of all data descriptors
        symbols = copy.deepcopy(sdfg.symbols)
        symbols.update(sdfg.arrays)
        symbols.update(sdfg.constants)
        for desc in sdfg.arrays.values():
            for sym in desc.free_symbols:
                symbols[str(sym)] = sym.dtype
        visited = set()
        visited_edges = set()
        # Run through states via DFS, ensuring that only the defined symbols
        # are available for validation
        for edge in sdfg.dfs_edges(start_state):
            # Source -> inter-state definition -> Destination
            ##########################################
            visited_edges.add(edge)
            # Source
            if edge.src not in visited:
                visited.add(edge.src)
                validate_state(edge.src, sdfg.node_id(edge.src), sdfg, symbols)

            ##########################################
            # Edge
            # Check inter-state edge for undefined symbols
            undef_syms = set(edge.data.free_symbols) - set(symbols.keys())
            if len(undef_syms) > 0:
                eid = sdfg.edge_id(edge)
                raise InvalidSDFGInterstateEdgeError(
                    "Undefined symbols in edge: %s" % undef_syms, sdfg, eid)

            # Validate inter-state edge names (single pass over the names)
            issyms = edge.data.new_symbols(symbols)
            invalid = [s for s in issyms if not validate_name(s)]
            if invalid:
                eid = sdfg.edge_id(edge)
                raise InvalidSDFGInterstateEdgeError(
                    "Invalid interstate symbol name %s" % invalid[0], sdfg,
                    eid)

            # Add edge symbols into defined symbols
            symbols.update(issyms)

            ##########################################
            # Destination
            if edge.dst not in visited:
                visited.add(edge.dst)
                validate_state(edge.dst, sdfg.node_id(edge.dst), sdfg, symbols)
        # End of state DFS

        # If there is only one state, the DFS will miss it
        if start_state not in visited:
            validate_state(start_state, sdfg.node_id(start_state), sdfg,
                           symbols)

        # Validate all inter-state edges (including self-loops not found by DFS)
        for eid, edge in enumerate(sdfg.edges()):
            if edge in visited_edges:
                continue
            invalid = [
                s for s in edge.data.assignments.keys()
                if not validate_name(s)
            ]
            if invalid:
                raise InvalidSDFGInterstateEdgeError(
                    "Invalid interstate symbol name %s" % invalid[0], sdfg,
                    eid)

    except InvalidSDFGError as ex:
        # If the SDFG is invalid, save it
        sdfg.save(os.path.join('_dacegraphs', 'invalid.sdfg'), exception=ex)
        raise
Example #13
0
def validate_state(state: 'dace.sdfg.SDFGState',
                   state_id: int = None,
                   sdfg: 'dace.sdfg.SDFG' = None,
                   symbols: Dict[str, typeclass] = None):
    """ Verifies the correctness of an SDFG state by applying multiple
        tests. Raises an InvalidSDFGError with the erroneous node on
        failure. May also emit a warning for uses of uninitialized
        transients.

        :param state: The state to validate.
        :param state_id: Index of the state in its SDFG; computed with
                         ``sdfg.node_id(state)`` if not given.
        :param sdfg: The SDFG containing the state; defaults to
                     ``state.parent``.
        :param symbols: Names defined at this state. Tasklet connectors may
                        not reuse these names.
        :raises InvalidSDFGError: (or one of its node/edge subclasses) on the
                                  first failed check.
    """
    # Avoid import loops
    from dace.sdfg import SDFG
    from dace.config import Config
    from dace.sdfg import nodes as nd
    from dace.sdfg.scope import scope_contains_scope
    from dace import data as dt
    from dace import subsets as sbs

    if sdfg is None:
        sdfg = state.parent
    # Compare against None explicitly: 0 is a valid state ID
    if state_id is None:
        state_id = sdfg.node_id(state)
    symbols = symbols or {}

    if not validate_name(state._label):
        raise InvalidSDFGError("Invalid state name", sdfg, state_id)

    if state._parent != sdfg:
        raise InvalidSDFGError("State does not point to the correct "
                               "parent", sdfg, state_id)

    # Unreachable
    ########################################
    if (sdfg.number_of_nodes() > 1 and sdfg.in_degree(state) == 0
            and sdfg.out_degree(state) == 0):
        raise InvalidSDFGError("Unreachable state", sdfg, state_id)

    for nid, node in enumerate(state.nodes()):
        # Node validation
        try:
            node.validate(sdfg, state)
        except InvalidSDFGError:
            raise
        except Exception as ex:
            raise InvalidSDFGNodeError("Node validation failed: " + str(ex),
                                       sdfg, state_id, nid) from ex

        # Isolated nodes
        ########################################
        # One corner case: isolated code nodes are allowed
        if (state.in_degree(node) + state.out_degree(node) == 0
                and not isinstance(node, nd.CodeNode)):
            raise InvalidSDFGNodeError("Isolated node", sdfg, state_id, nid)

        # Scope tests
        ########################################
        if isinstance(node, nd.EntryNode):
            try:
                state.exit_node(node)
            except StopIteration:
                raise InvalidSDFGNodeError(
                    "Entry node does not have matching "
                    "exit node",
                    sdfg,
                    state_id,
                    nid,
                )

        if isinstance(node, (nd.EntryNode, nd.ExitNode)):
            # Every IN_x connector on a scope node must pair with an OUT_x
            # connector, and vice versa
            for iconn in node.in_connectors:
                if (iconn is not None and iconn.startswith("IN_")
                        and ("OUT_" + iconn[3:]) not in node.out_connectors):
                    raise InvalidSDFGNodeError(
                        "No match for input connector %s in output "
                        "connectors" % iconn,
                        sdfg,
                        state_id,
                        nid,
                    )
            for oconn in node.out_connectors:
                if (oconn is not None and oconn.startswith("OUT_")
                        and ("IN_" + oconn[4:]) not in node.in_connectors):
                    raise InvalidSDFGNodeError(
                        "No match for output connector %s in input "
                        "connectors" % oconn,
                        sdfg,
                        state_id,
                        nid,
                    )

        # Node-specific tests
        ########################################
        if isinstance(node, nd.AccessNode):
            if node.data not in sdfg.arrays:
                raise InvalidSDFGNodeError(
                    "Access node must point to a valid array name in the SDFG",
                    sdfg,
                    state_id,
                    nid,
                )

            # Find uninitialized transients
            arr = sdfg.arrays[node.data]
            if (arr.transient and state.in_degree(node) == 0
                    and state.out_degree(node) > 0
                    # Streams do not need to be initialized
                    and not isinstance(arr, dt.Stream)):
                # Find other instances of node in predecessor states
                states = sdfg.predecessor_states(state)
                input_found = False
                for s in states:
                    for onode in s.nodes():
                        if (isinstance(onode, nd.AccessNode)
                                and onode.data == node.data):
                            if s.in_degree(onode) > 0:
                                input_found = True
                                break
                    if input_found:
                        break
                if not input_found and node.setzero == False:
                    warnings.warn(
                        'WARNING: Use of uninitialized transient "%s" in state %s'
                        % (node.data, state.label))

            # Find writes to input-only arrays
            if not arr.transient and state.in_degree(node) > 0:
                nsdfg_node = sdfg.parent_nsdfg_node
                if nsdfg_node is not None:
                    if node.data not in nsdfg_node.out_connectors:
                        raise InvalidSDFGNodeError(
                            'Data descriptor %s is '
                            'written to, but only given to nested SDFG as an '
                            'input connector' % node.data, sdfg, state_id, nid)

        if (isinstance(node, nd.ConsumeEntry)
                and "IN_stream" not in node.in_connectors):
            raise InvalidSDFGNodeError(
                "Consume entry node must have an input stream", sdfg, state_id,
                nid)
        if (isinstance(node, nd.ConsumeEntry)
                and "OUT_stream" not in node.out_connectors):
            raise InvalidSDFGNodeError(
                "Consume entry node must have an internal stream",
                sdfg,
                state_id,
                nid,
            )

        # Connector tests
        ########################################
        # Check for duplicate connector names (unless it's a nested SDFG)
        if (len(node.in_connectors.keys() & node.out_connectors.keys()) > 0
                and not isinstance(node, nd.NestedSDFG)):
            dups = node.in_connectors.keys() & node.out_connectors.keys()
            raise InvalidSDFGNodeError("Duplicate connectors: " + str(dups),
                                       sdfg, state_id, nid)

        # Check for connectors that are also array/symbol names
        if isinstance(node, nd.Tasklet):
            for conn in node.in_connectors.keys():
                if conn in sdfg.arrays or conn in symbols:
                    raise InvalidSDFGNodeError(
                        f"Input connector {conn} already "
                        "defined as array or symbol", sdfg, state_id, nid)
            for conn in node.out_connectors.keys():
                if conn in sdfg.arrays or conn in symbols:
                    raise InvalidSDFGNodeError(
                        f"Output connector {conn} already "
                        "defined as array or symbol", sdfg, state_id, nid)

        # Check for dangling connectors (incoming)
        for conn in node.in_connectors:
            incoming_edges = 0
            for e in state.in_edges(node):
                # Connector found
                if e.dst_conn == conn:
                    incoming_edges += 1

            if incoming_edges == 0:
                raise InvalidSDFGNodeError("Dangling in-connector %s" % conn,
                                           sdfg, state_id, nid)
            # Connectors may have only one incoming edge
            # Due to input connectors of scope exit, this is only correct
            # in some cases:
            if incoming_edges > 1 and not isinstance(node, nd.ExitNode):
                raise InvalidSDFGNodeError(
                    "Connector '%s' cannot have more "
                    "than one incoming edge, found %d" %
                    (conn, incoming_edges),
                    sdfg,
                    state_id,
                    nid,
                )

        # Check for dangling connectors (outgoing)
        for conn in node.out_connectors:
            outgoing_edges = 0
            for e in state.out_edges(node):
                # Connector found
                if e.src_conn == conn:
                    outgoing_edges += 1

            if outgoing_edges == 0:
                raise InvalidSDFGNodeError("Dangling out-connector %s" % conn,
                                           sdfg, state_id, nid)

            # In case of scope exit or code node, only one outgoing edge per
            # connector is allowed.
            if outgoing_edges > 1 and isinstance(node,
                                                 (nd.ExitNode, nd.CodeNode)):
                raise InvalidSDFGNodeError(
                    "Connector '%s' cannot have more "
                    "than one outgoing edge, found %d" %
                    (conn, outgoing_edges),
                    sdfg,
                    state_id,
                    nid,
                )

        # Check for edges to nonexistent connectors
        for e in state.in_edges(node):
            if e.dst_conn is not None and e.dst_conn not in node.in_connectors:
                raise InvalidSDFGNodeError(
                    ("Memlet %s leading to " + "nonexistent connector %s") %
                    (str(e.data), e.dst_conn),
                    sdfg,
                    state_id,
                    nid,
                )
        for e in state.out_edges(node):
            if e.src_conn is not None and e.src_conn not in node.out_connectors:
                raise InvalidSDFGNodeError(
                    ("Memlet %s coming from " + "nonexistent connector %s") %
                    (str(e.data), e.src_conn),
                    sdfg,
                    state_id,
                    nid,
                )
        ########################################

    # Memlet checks
    scope = state.scope_dict()
    for eid, e in enumerate(state.edges()):
        # Edge validation
        try:
            e.data.validate(sdfg, state)
        except InvalidSDFGError:
            raise
        except Exception as ex:
            # Chain the cause, as node validation above does
            raise InvalidSDFGEdgeError("Edge validation failed: " + str(ex),
                                       sdfg, state_id, eid) from ex

        # For every memlet, obtain its full path in the DFG
        path = state.memlet_path(e)
        src_node = path[0].src
        dst_node = path[-1].dst

        # Check if memlet data matches src or dst nodes
        if (e.data.data is not None
                and (isinstance(src_node, nd.AccessNode)
                     or isinstance(dst_node, nd.AccessNode))
                and (not isinstance(src_node, nd.AccessNode)
                     or e.data.data != src_node.data)
                and (not isinstance(dst_node, nd.AccessNode)
                     or e.data.data != dst_node.data)):
            raise InvalidSDFGEdgeError(
                "Memlet data does not match source or destination "
                "data nodes",
                sdfg,
                state_id,
                eid,
            )

        # Check memlet subset validity with respect to source/destination nodes
        if e.data.data is not None and e.data.allow_oob == False:
            subset_node = (dst_node if isinstance(dst_node, nd.AccessNode)
                           and e.data.data == dst_node.data else src_node)
            other_subset_node = (
                dst_node if isinstance(dst_node, nd.AccessNode)
                and e.data.data != dst_node.data else src_node)

            if isinstance(subset_node, nd.AccessNode):
                arr = sdfg.arrays[subset_node.data]
                # Dimensionality
                if e.data.subset.dims() != len(arr.shape):
                    raise InvalidSDFGEdgeError(
                        "Memlet subset does not match node dimension "
                        "(expected %d, got %d)" %
                        (len(arr.shape), e.data.subset.dims()),
                        sdfg,
                        state_id,
                        eid,
                    )

                # Bounds. ``== True`` only flags comparisons that are
                # definitely true (symbolic comparisons may be undecided).
                if any(((minel + off) < 0) == True for minel, off in zip(
                        e.data.subset.min_element(), arr.offset)):
                    raise InvalidSDFGEdgeError(
                        "Memlet subset negative out-of-bounds", sdfg, state_id,
                        eid)
                if any(((maxel + off) >= s) == True for maxel, s, off in zip(
                        e.data.subset.max_element(), arr.shape, arr.offset)):
                    raise InvalidSDFGEdgeError("Memlet subset out-of-bounds",
                                               sdfg, state_id, eid)
            # Test other_subset as well
            if e.data.other_subset is not None and isinstance(
                    other_subset_node, nd.AccessNode):
                arr = sdfg.arrays[other_subset_node.data]
                # Dimensionality
                if e.data.other_subset.dims() != len(arr.shape):
                    raise InvalidSDFGEdgeError(
                        "Memlet other_subset does not match node dimension "
                        "(expected %d, got %d)" %
                        (len(arr.shape), e.data.other_subset.dims()),
                        sdfg,
                        state_id,
                        eid,
                    )

                # Bounds
                if any(((minel + off) < 0) == True for minel, off in zip(
                        e.data.other_subset.min_element(), arr.offset)):
                    raise InvalidSDFGEdgeError(
                        "Memlet other_subset negative out-of-bounds",
                        sdfg,
                        state_id,
                        eid,
                    )
                if any(((maxel + off) >= s) == True for maxel, s, off in zip(
                        e.data.other_subset.max_element(), arr.shape,
                        arr.offset)):
                    raise InvalidSDFGEdgeError(
                        "Memlet other_subset out-of-bounds", sdfg, state_id,
                        eid)

            # Test subset and other_subset for undefined symbols
            if Config.get_bool('experimental', 'validate_undefs'):
                # TODO: Traverse by scopes and accumulate data
                # Convert the returned mapping to a set of names once, so
                # that both subtractions below are set - set
                defined_symbols = set(state.symbols_defined_at(e.dst).keys())
                undefs = e.data.subset.free_symbols - defined_symbols
                if len(undefs) > 0:
                    raise InvalidSDFGEdgeError(
                        'Undefined symbols %s found in memlet subset' % undefs,
                        sdfg, state_id, eid)
                if e.data.other_subset is not None:
                    undefs = (e.data.other_subset.free_symbols -
                              defined_symbols)
                    if len(undefs) > 0:
                        raise InvalidSDFGEdgeError(
                            'Undefined symbols %s found in memlet '
                            'other_subset' % undefs, sdfg, state_id, eid)
        #######################################

        # Memlet path scope lifetime checks
        # If scope(src) == scope(dst): OK
        if scope[src_node] == scope[dst_node] or src_node == scope[dst_node]:
            pass
        # If scope(src) contains scope(dst), then src must be a data node
        elif scope_contains_scope(scope, src_node, dst_node):
            if not isinstance(src_node, nd.AccessNode):
                # NOTE: this check is currently disabled
                pass
                # raise InvalidSDFGEdgeError(
                #     "Memlet creates an "
                #     "invalid path (source node %s should "
                #     "be a data node)" % str(src_node),
                #     sdfg,
                #     state_id,
                #     eid,
                # )
        # If scope(dst) contains scope(src), then dst must be a data node
        elif scope_contains_scope(scope, dst_node, src_node):
            if not isinstance(dst_node, nd.AccessNode):
                raise InvalidSDFGEdgeError(
                    "Memlet creates an "
                    "invalid path (sink node %s should "
                    "be a data node)" % str(dst_node),
                    sdfg,
                    state_id,
                    eid,
                )
        # If scope(dst) is disjoint from scope(src), it's an illegal memlet
        else:
            raise InvalidSDFGEdgeError(
                "Illegal memlet between disjoint scopes", sdfg, state_id, eid)

        # Check dimensionality of memory access
        if isinstance(e.data.subset, (sbs.Range, sbs.Indices)):
            if e.data.subset.dims() != len(sdfg.arrays[e.data.data].shape):
                raise InvalidSDFGEdgeError(
                    "Memlet subset uses the wrong dimensions"
                    " (%dD for a %dD data node)" %
                    (e.data.subset.dims(), len(
                        sdfg.arrays[e.data.data].shape)),
                    sdfg,
                    state_id,
                    eid,
                )

        # Verify that source and destination subsets contain the same
        # number of elements
        if e.data.other_subset is not None and not (
            (isinstance(src_node, nd.AccessNode)
             and isinstance(sdfg.arrays[src_node.data], dt.Stream)) or
            (isinstance(dst_node, nd.AccessNode)
             and isinstance(sdfg.arrays[dst_node.data], dt.Stream))):
            if (e.data.src_subset.num_elements() *
                    sdfg.arrays[src_node.data].veclen !=
                    e.data.dst_subset.num_elements() *
                    sdfg.arrays[dst_node.data].veclen):
                raise InvalidSDFGEdgeError(
                    'Dimensionality mismatch between src/dst subsets', sdfg,
                    state_id, eid)
Example #14
0
def validate_sdfg(sdfg: 'dace.sdfg.SDFG'):
    """ Verifies the correctness of an SDFG by applying multiple tests.

        Checks performed, in order:
          * the SDFG name is a valid identifier;
          * a start state is defined when there are multiple source states;
          * state labels are unique;
          * every data descriptor has a valid name, is not both persistent
            and register-stored, and (if present) has a well-formed FPGA
            memory bank assignment;
          * states visited by ``sdfg.dfs_edges(start_state)`` (plus the start
            state itself) are validated with the symbols defined along the
            traversal, and each traversed inter-state edge is checked for
            undefined and invalid symbol names;
          * inter-state edges not traversed by the DFS (e.g., self-loops)
            have valid assignment names.

        NOTE(review): states never visited by the DFS are not passed to
        ``validate_state`` here.

        :param sdfg: The SDFG to verify.

        Raises an InvalidSDFGError with the erroneous node/edge
        on failure. Any InvalidSDFGError raised during validation causes the
        SDFG to be saved to ``_dacegraphs/invalid.sdfg`` before re-raising.
    """
    # Avoid import loop
    from dace.codegen.targets import fpga

    try:
        # SDFG-level checks
        if not dtypes.validate_name(sdfg.name):
            raise InvalidSDFGError("Invalid name", sdfg, None)

        if len(sdfg.source_nodes()) > 1 and sdfg.start_state is None:
            raise InvalidSDFGError("Starting state undefined", sdfg, None)

        if len({s.label for s in sdfg.nodes()}) != len(sdfg.nodes()):
            raise InvalidSDFGError("Found multiple states with the same name",
                                   sdfg, None)

        # Validate data descriptors
        for name, desc in sdfg._arrays.items():
            # Validate array names
            if name is not None and not dtypes.validate_name(name):
                raise InvalidSDFGError("Invalid array name %s" % name, sdfg,
                                       None)
            # Allocation lifetime checks
            if (desc.lifetime is dtypes.AllocationLifetime.Persistent
                    and desc.storage is dtypes.StorageType.Register):
                raise InvalidSDFGError(
                    "Array %s cannot be both persistent and use Register as "
                    "storage type. Please use a different storage location." %
                    name, sdfg, None)

            # Check for valid bank assignments
            try:
                bank_assignment = fpga.parse_location_bank(desc)
            except ValueError as e:
                raise InvalidSDFGError(str(e), sdfg, None) from e
            if (bank_assignment is not None
                    and bank_assignment[0] in ("DDR", "HBM")):
                # Parsed only to verify the specifier's syntax; the resulting
                # range object itself is not needed.
                try:
                    subsets.Range.from_string(bank_assignment[1])
                except SyntaxError as e:
                    raise InvalidSDFGError(
                        "Memory bank specifier must be convertible to subsets.Range"
                        f" for array {name}", sdfg, None) from e
                try:
                    low, high = fpga.get_multibank_ranges_from_subset(
                        bank_assignment[1], sdfg)
                except ValueError as e:
                    raise InvalidSDFGError(str(e), sdfg, None) from e
                num_banks = high - low
                if num_banks < 1:
                    raise InvalidSDFGError(
                        "Memory bank specifier must at least define one bank to be used"
                        f" for array {name}", sdfg, None)
                # The rank is checked before indexing shape[0], so that
                # descriptors with an empty shape produce a validation error
                # rather than an IndexError.
                if num_banks > 1 and (len(desc.shape) < 2
                                      or num_banks != desc.shape[0]):
                    raise InvalidSDFGError(
                        "Arrays that use a multibank access pattern must have the size of the first dimension equal"
                        f" the number of banks and have at least 2 dimensions for array {name}",
                        sdfg, None)

        # Check every state separately
        start_state = sdfg.start_state
        initialized_transients = {'__pystate'}
        # Symbols visible at the current point of the traversal: SDFG symbols,
        # data containers, constants, and free symbols of descriptors.
        symbols = copy.deepcopy(sdfg.symbols)
        symbols.update(sdfg.arrays)
        symbols.update({
            k: dt.create_datadescriptor(v)
            for k, v in sdfg.constants.items()
        })
        for desc in sdfg.arrays.values():
            for sym in desc.free_symbols:
                symbols[str(sym)] = sym.dtype
        visited = set()
        visited_edges = set()
        # Run through states via DFS, ensuring that only the defined symbols
        # are available for validation
        for edge in sdfg.dfs_edges(start_state):
            # Source -> inter-state definition -> Destination
            ##########################################
            visited_edges.add(edge)
            # Source
            if edge.src not in visited:
                visited.add(edge.src)
                validate_state(edge.src, sdfg.node_id(edge.src), sdfg, symbols,
                               initialized_transients)

            ##########################################
            # Edge
            # Check inter-state edge for undefined symbols
            undef_syms = set(edge.data.free_symbols) - set(symbols.keys())
            if len(undef_syms) > 0:
                eid = sdfg.edge_id(edge)
                raise InvalidSDFGInterstateEdgeError(
                    "Undefined symbols in edge: %s" % undef_syms, sdfg, eid)

            # Validate inter-state edge names
            issyms = edge.data.new_symbols(sdfg, symbols)
            if any(not dtypes.validate_name(s) for s in issyms):
                invalid = next(s for s in issyms
                               if not dtypes.validate_name(s))
                eid = sdfg.edge_id(edge)
                raise InvalidSDFGInterstateEdgeError(
                    "Invalid interstate symbol name %s" % invalid, sdfg, eid)

            # Add edge symbols into defined symbols, so that states further
            # along the traversal may use them
            symbols.update(issyms)

            ##########################################
            # Destination
            if edge.dst not in visited:
                visited.add(edge.dst)
                validate_state(edge.dst, sdfg.node_id(edge.dst), sdfg, symbols,
                               initialized_transients)
        # End of state DFS

        # If the start state has no traversed edges (e.g., a single-state
        # SDFG), the DFS above never visits it
        if start_state not in visited:
            validate_state(start_state, sdfg.node_id(start_state), sdfg,
                           symbols, initialized_transients)

        # Validate all inter-state edges (including self-loops not found by DFS)
        for eid, edge in enumerate(sdfg.edges()):
            if edge in visited_edges:
                continue
            issyms = edge.data.assignments.keys()
            if any(not dtypes.validate_name(s) for s in issyms):
                invalid = next(s for s in issyms
                               if not dtypes.validate_name(s))
                raise InvalidSDFGInterstateEdgeError(
                    "Invalid interstate symbol name %s" % invalid, sdfg, eid)

    except InvalidSDFGError as ex:
        # If the SDFG is invalid, save it
        sdfg.save(os.path.join('_dacegraphs', 'invalid.sdfg'), exception=ex)
        raise