Ejemplo n.º 1
0
def validate_state(state: 'dace.sdfg.SDFGState',
                   state_id: int = None,
                   sdfg: 'dace.sdfg.SDFG' = None,
                   symbols: Dict[str, dtypes.typeclass] = None):
    """ Verifies the correctness of an SDFG state by applying multiple
        tests. Raises an InvalidSDFGError with the erroneous node on
        failure.

        The checks run in three stages: state-level checks (name, parent
        pointer, reachability), per-node checks (node validation, isolated
        nodes, scopes, access nodes, connectors) and per-edge checks
        (memlet validation, data/subset consistency, scope crossing and
        element counts).

        :param state: The state to validate.
        :param state_id: Index of ``state`` in ``sdfg``, used for error
                         reporting. If None, it is looked up with
                         ``sdfg.node_id(state)``.
        :param sdfg: The SDFG that should contain ``state``. If not given,
                     ``state.parent`` is used.
        :param symbols: Mapping from symbol names to their types. Only the
                        names are used here, to reject tasklet connectors
                        that are named like a symbol. Defaults to an empty
                        mapping.
        :raises InvalidSDFGError: If a state-level check fails.
        :raises InvalidSDFGNodeError: If a node-level check fails.
        :raises InvalidSDFGEdgeError: If an edge (memlet) check fails.
        :note: A transient that is read without a visible prior write only
               triggers a warning, not an error.
    """
    # Avoid import loops
    from dace.config import Config
    from dace.sdfg import nodes as nd
    from dace.sdfg.scope import scope_contains_scope
    from dace import data as dt
    from dace import subsets as sbs

    sdfg = sdfg or state.parent
    # Compare with None explicitly: 0 is a valid state index and must not
    # be mistaken for "not given".
    if state_id is None:
        state_id = sdfg.node_id(state)
    symbols = symbols or {}

    if not dtypes.validate_name(state._label):
        raise InvalidSDFGError("Invalid state name", sdfg, state_id)

    if state._parent != sdfg:
        raise InvalidSDFGError("State does not point to the correct "
                               "parent", sdfg, state_id)

    # Unreachable
    ########################################
    # A state without any incident transition is only acceptable when it is
    # the sole state of the SDFG.
    if (sdfg.number_of_nodes() > 1 and sdfg.in_degree(state) == 0
            and sdfg.out_degree(state) == 0):
        raise InvalidSDFGError("Unreachable state", sdfg, state_id)

    for nid, node in enumerate(state.nodes()):
        # Node validation
        try:
            node.validate(sdfg, state)
        except InvalidSDFGError:
            # Already carries location information; pass through untouched
            raise
        except Exception as ex:
            raise InvalidSDFGNodeError("Node validation failed: " + str(ex),
                                       sdfg, state_id, nid) from ex

        # Isolated nodes
        ########################################
        # One corner case: an isolated node is OK if it is a code node
        if (state.in_degree(node) + state.out_degree(node) == 0
                and not isinstance(node, nd.CodeNode)):
            raise InvalidSDFGNodeError("Isolated node", sdfg, state_id, nid)

        # Scope tests
        ########################################
        if isinstance(node, nd.EntryNode):
            try:
                state.exit_node(node)
            except StopIteration:
                raise InvalidSDFGNodeError(
                    "Entry node does not have matching "
                    "exit node",
                    sdfg,
                    state_id,
                    nid,
                )

        # On scope nodes, every "IN_<name>" connector must be paired with an
        # "OUT_<name>" connector and vice versa.
        if isinstance(node, (nd.EntryNode, nd.ExitNode)):
            for iconn in node.in_connectors:
                if (iconn is not None and iconn.startswith("IN_")
                        and ("OUT_" + iconn[3:]) not in node.out_connectors):
                    raise InvalidSDFGNodeError(
                        "No match for input connector %s in output "
                        "connectors" % iconn,
                        sdfg,
                        state_id,
                        nid,
                    )
            for oconn in node.out_connectors:
                if (oconn is not None and oconn.startswith("OUT_")
                        and ("IN_" + oconn[4:]) not in node.in_connectors):
                    raise InvalidSDFGNodeError(
                        "No match for output connector %s in input "
                        "connectors" % oconn,
                        sdfg,
                        state_id,
                        nid,
                    )

        # Node-specific tests
        ########################################
        if isinstance(node, nd.AccessNode):
            if node.data not in sdfg.arrays:
                raise InvalidSDFGNodeError(
                    "Access node must point to a valid array name in the SDFG",
                    sdfg,
                    state_id,
                    nid,
                )
            arr = sdfg.arrays[node.data]

            # Verify View references
            if isinstance(arr, dt.View):
                from dace.sdfg import utils as sdutil  # Avoid import loops
                if sdutil.get_view_edge(state, node) is None:
                    raise InvalidSDFGNodeError(
                        "Ambiguous or invalid edge to/from a View access node",
                        sdfg, state_id, nid)

            # Find uninitialized transients
            if (arr.transient and state.in_degree(node) == 0
                    and state.out_degree(node) > 0
                    # Streams do not need to be initialized
                    and not isinstance(arr, dt.Stream)):
                # Look for an access node of the same data that is written
                # to (has incoming edges) in any predecessor state
                input_found = any(
                    isinstance(onode, nd.AccessNode)
                    and onode.data == node.data and s.in_degree(onode) > 0
                    for s in sdfg.predecessor_states(state)
                    for onode in s.nodes())
                if not input_found and node.setzero == False:
                    warnings.warn(
                        'WARNING: Use of uninitialized transient "%s" in state %s'
                        % (node.data, state.label))

            # Find writes to input-only arrays
            only_empty_inputs = all(e.data.is_empty()
                                    for e in state.in_edges(node))
            if (not arr.transient) and (not only_empty_inputs):
                # Only relevant inside a nested SDFG: non-transient data that
                # is written must be exposed as an output connector of the
                # enclosing nested SDFG node.
                nsdfg_node = sdfg.parent_nsdfg_node
                if nsdfg_node is not None:
                    if node.data not in nsdfg_node.out_connectors:
                        raise InvalidSDFGNodeError(
                            'Data descriptor %s is '
                            'written to, but only given to nested SDFG as an '
                            'input connector' % node.data, sdfg, state_id, nid)

        if isinstance(node, nd.ConsumeEntry):
            if "IN_stream" not in node.in_connectors:
                raise InvalidSDFGNodeError(
                    "Consume entry node must have an input stream", sdfg,
                    state_id, nid)
            if "OUT_stream" not in node.out_connectors:
                raise InvalidSDFGNodeError(
                    "Consume entry node must have an internal stream",
                    sdfg,
                    state_id,
                    nid,
                )

        # Connector tests
        ########################################
        # Check for duplicate connector names (unless it's a nested SDFG or
        # a library node)
        if not isinstance(node, (nd.NestedSDFG, nd.LibraryNode)):
            dups = node.in_connectors.keys() & node.out_connectors.keys()
            if dups:
                raise InvalidSDFGNodeError(
                    "Duplicate connectors: " + str(dups), sdfg, state_id, nid)

        # Check for connectors that are also array/symbol names
        if isinstance(node, nd.Tasklet):
            for conn in node.in_connectors.keys():
                if conn in sdfg.arrays or conn in symbols:
                    raise InvalidSDFGNodeError(
                        f"Input connector {conn} already "
                        "defined as array or symbol", sdfg, state_id, nid)
            for conn in node.out_connectors.keys():
                if conn in sdfg.arrays or conn in symbols:
                    raise InvalidSDFGNodeError(
                        f"Output connector {conn} already "
                        "defined as array or symbol", sdfg, state_id, nid)

        # Check for dangling connectors (incoming)
        for conn in node.in_connectors:
            incoming_edges = sum(1 for e in state.in_edges(node)
                                 if e.dst_conn == conn)

            if incoming_edges == 0:
                raise InvalidSDFGNodeError("Dangling in-connector %s" % conn,
                                           sdfg, state_id, nid)
            # Connectors may have only one incoming edge
            # Due to input connectors of scope exit, this is only correct
            # in some cases:
            if incoming_edges > 1 and not isinstance(node, nd.ExitNode):
                raise InvalidSDFGNodeError(
                    "Connector '%s' cannot have more "
                    "than one incoming edge, found %d" %
                    (conn, incoming_edges),
                    sdfg,
                    state_id,
                    nid,
                )

        # Check for dangling connectors (outgoing)
        for conn in node.out_connectors:
            outgoing_edges = sum(1 for e in state.out_edges(node)
                                 if e.src_conn == conn)

            if outgoing_edges == 0:
                raise InvalidSDFGNodeError("Dangling out-connector %s" % conn,
                                           sdfg, state_id, nid)

            # In case of scope exit or code node, only one outgoing edge per
            # connector is allowed.
            if outgoing_edges > 1 and isinstance(node,
                                                 (nd.ExitNode, nd.CodeNode)):
                raise InvalidSDFGNodeError(
                    "Connector '%s' cannot have more "
                    "than one outgoing edge, found %d" %
                    (conn, outgoing_edges),
                    sdfg,
                    state_id,
                    nid,
                )

        # Check for edges to nonexistent connectors
        for e in state.in_edges(node):
            if e.dst_conn is not None and e.dst_conn not in node.in_connectors:
                raise InvalidSDFGNodeError(
                    ("Memlet %s leading to " + "nonexistent connector %s") %
                    (str(e.data), e.dst_conn),
                    sdfg,
                    state_id,
                    nid,
                )
        for e in state.out_edges(node):
            if e.src_conn is not None and e.src_conn not in node.out_connectors:
                raise InvalidSDFGNodeError(
                    ("Memlet %s coming from " + "nonexistent connector %s") %
                    (str(e.data), e.src_conn),
                    sdfg,
                    state_id,
                    nid,
                )
        ########################################

    # Memlet checks
    scope = state.scope_dict()
    for eid, e in enumerate(state.edges()):
        # Edge validation
        try:
            e.data.validate(sdfg, state)
        except InvalidSDFGError:
            raise
        except Exception as ex:
            # Chain the original exception, as done for node validation
            raise InvalidSDFGEdgeError("Edge validation failed: " + str(ex),
                                       sdfg, state_id, eid) from ex

        # For every memlet, obtain its full path in the DFG
        path = state.memlet_path(e)
        src_node = path[0].src
        dst_node = path[-1].dst

        # Check if memlet data matches src or dst nodes: if at least one
        # end of the path is an access node, the memlet must refer to the
        # data of one of the access-node endpoints.
        if (e.data.data is not None
                and (isinstance(src_node, nd.AccessNode)
                     or isinstance(dst_node, nd.AccessNode))
                and (not isinstance(src_node, nd.AccessNode)
                     or e.data.data != src_node.data)
                and (not isinstance(dst_node, nd.AccessNode)
                     or e.data.data != dst_node.data)):
            raise InvalidSDFGEdgeError(
                "Memlet data does not match source or destination "
                "data nodes)",
                sdfg,
                state_id,
                eid,
            )

        # Check memlet subset validity with respect to source/destination nodes
        if e.data.data is not None and e.data.allow_oob == False:
            # `subset` applies to the endpoint whose data the memlet names;
            # `other_subset` applies to the opposite endpoint.
            subset_node = (dst_node if isinstance(dst_node, nd.AccessNode)
                           and e.data.data == dst_node.data else src_node)
            other_subset_node = (
                dst_node if isinstance(dst_node, nd.AccessNode)
                and e.data.data != dst_node.data else src_node)

            if isinstance(subset_node, nd.AccessNode):
                arr = sdfg.arrays[subset_node.data]
                # Dimensionality
                if e.data.subset.dims() != len(arr.shape):
                    raise InvalidSDFGEdgeError(
                        "Memlet subset does not match node dimension "
                        "(expected %d, got %d)" %
                        (len(arr.shape), e.data.subset.dims()),
                        sdfg,
                        state_id,
                        eid,
                    )

                # Bounds
                # NOTE: The explicit `== True` is intentional; presumably
                # the bounds can be symbolic, in which case the comparison
                # need not reduce to a Python bool and only a definite
                # violation should be reported.
                if any(((minel + off) < 0) == True for minel, off in zip(
                        e.data.subset.min_element(), arr.offset)):
                    raise InvalidSDFGEdgeError(
                        "Memlet subset negative out-of-bounds", sdfg, state_id,
                        eid)
                if any(((maxel + off) >= s) == True for maxel, s, off in zip(
                        e.data.subset.max_element(), arr.shape, arr.offset)):
                    raise InvalidSDFGEdgeError("Memlet subset out-of-bounds",
                                               sdfg, state_id, eid)
            # Test other_subset as well
            if e.data.other_subset is not None and isinstance(
                    other_subset_node, nd.AccessNode):
                arr = sdfg.arrays[other_subset_node.data]
                # Dimensionality
                if e.data.other_subset.dims() != len(arr.shape):
                    raise InvalidSDFGEdgeError(
                        "Memlet other_subset does not match node dimension "
                        "(expected %d, got %d)" %
                        (len(arr.shape), e.data.other_subset.dims()),
                        sdfg,
                        state_id,
                        eid,
                    )

                # Bounds (see the note on `== True` above)
                if any(((minel + off) < 0) == True for minel, off in zip(
                        e.data.other_subset.min_element(), arr.offset)):
                    raise InvalidSDFGEdgeError(
                        "Memlet other_subset negative out-of-bounds",
                        sdfg,
                        state_id,
                        eid,
                    )
                if any(((maxel + off) >= s) == True for maxel, s, off in zip(
                        e.data.other_subset.max_element(), arr.shape,
                        arr.offset)):
                    raise InvalidSDFGEdgeError(
                        "Memlet other_subset out-of-bounds", sdfg, state_id,
                        eid)

            # Test subset and other_subset for undefined symbols
            if Config.get_bool('experimental', 'validate_undefs'):
                # TODO: Traverse by scopes and accumulate data
                defined_symbols = state.symbols_defined_at(e.dst)
                undefs = (e.data.subset.free_symbols -
                          set(defined_symbols.keys()))
                if len(undefs) > 0:
                    raise InvalidSDFGEdgeError(
                        'Undefined symbols %s found in memlet subset' % undefs,
                        sdfg, state_id, eid)
                if e.data.other_subset is not None:
                    undefs = (e.data.other_subset.free_symbols -
                              set(defined_symbols.keys()))
                    if len(undefs) > 0:
                        raise InvalidSDFGEdgeError(
                            'Undefined symbols %s found in memlet '
                            'other_subset' % undefs, sdfg, state_id, eid)
        #######################################

        # Memlet path scope lifetime checks
        # If scope(src) == scope(dst): OK
        if scope[src_node] == scope[dst_node] or src_node == scope[dst_node]:
            pass
        # If scope(src) contains scope(dst), then src must be a data node,
        # unless the memlet is empty in order to connect to a scope
        # NOTE(review): this requirement is not enforced here; the branch
        # accepts every such edge.
        elif scope_contains_scope(scope, src_node, dst_node):
            pass
        # If scope(dst) contains scope(src), then dst must be a data node,
        # unless the memlet is empty in order to connect to a scope
        elif scope_contains_scope(scope, dst_node, src_node):
            if not isinstance(dst_node, nd.AccessNode):
                if e.data.is_empty() and isinstance(dst_node, nd.ExitNode):
                    pass
                else:
                    raise InvalidSDFGEdgeError(
                        f"Memlet creates an invalid path (sink node {dst_node}"
                        " should be a data node)", sdfg, state_id, eid)
        # If scope(dst) is disjoint from scope(src), it's an illegal memlet
        else:
            raise InvalidSDFGEdgeError(
                "Illegal memlet between disjoint scopes", sdfg, state_id, eid)

        # Check dimensionality of memory access
        if isinstance(e.data.subset, (sbs.Range, sbs.Indices)):
            if e.data.subset.dims() != len(sdfg.arrays[e.data.data].shape):
                raise InvalidSDFGEdgeError(
                    "Memlet subset uses the wrong dimensions"
                    " (%dD for a %dD data node)" %
                    (e.data.subset.dims(), len(
                        sdfg.arrays[e.data.data].shape)),
                    sdfg,
                    state_id,
                    eid,
                )

        # Verify that source and destination subsets contain the same
        # number of elements (skipped for out-of-bounds-tolerant memlets and
        # whenever either endpoint is a stream)
        # NOTE(review): the element-count computation below reads `.data`
        # from both path endpoints, so it assumes both are access nodes
        # whenever other_subset is set - confirm this always holds.
        if not e.data.allow_oob and e.data.other_subset is not None and not (
            (isinstance(src_node, nd.AccessNode)
             and isinstance(sdfg.arrays[src_node.data], dt.Stream)) or
            (isinstance(dst_node, nd.AccessNode)
             and isinstance(sdfg.arrays[dst_node.data], dt.Stream))):
            # Scale by the vector length so that differently-vectorized
            # containers are compared in the same unit
            src_expr = (e.data.src_subset.num_elements() *
                        sdfg.arrays[src_node.data].veclen)
            dst_expr = (e.data.dst_subset.num_elements() *
                        sdfg.arrays[dst_node.data].veclen)
            if symbolic.inequal_symbols(src_expr, dst_expr):
                raise InvalidSDFGEdgeError(
                    'Dimensionality mismatch between src/dst subsets', sdfg,
                    state_id, eid)
Ejemplo n.º 2
0
def _element_counts_differ(src_subset, dst_subset) -> bool:
    """ Returns True if the two subsets provably cover a different number
        of elements (both the regular and the exact counts disagree). """
    src_expr = src_subset.num_elements()
    src_expr_exact = src_subset.num_elements_exact()
    dst_expr = dst_subset.num_elements()
    dst_expr_exact = dst_subset.num_elements_exact()
    return (src_expr != dst_expr
            and symbolic.inequal_symbols(src_expr_exact, dst_expr_exact))


def _pad_or_trim_subset(subset, ndim: int):
    """ Adapts ``subset`` to ``ndim`` dimensions by prepending zero
        indices/ranges or dropping leading dimensions, and returns it.

        If the dimensionality had to change, the result is additionally
        offset by itself (``subset.offset(subset, True)``). A subset that
        already has ``ndim`` dimensions is returned unchanged. ``subset``
        may be modified in place, so callers should pass a copy.
    """
    padding = ndim - len(subset)
    if padding == 0:
        return subset

    if padding > 0:
        # Fewer dimensions than the array: pad at the front
        if isinstance(subset, subsets.Indices):
            subset = subsets.Indices([0] * padding + subset.indices)
        elif isinstance(subset, subsets.Range):
            subset = subsets.Range([(0, 0, 1)] * padding + subset.ranges)
    else:
        # More dimensions than the array: keep only the trailing ones
        if isinstance(subset, subsets.Indices):
            subset = subsets.Indices(subset.indices[-padding:])
        elif isinstance(subset, subsets.Range):
            subset = subsets.Range(subset.ranges[-padding:])
    subset.offset(subset, True)
    return subset


def _validate_subsets(
    edge: graph.MultiConnectorEdge,
    arrays: typing.Dict[str, data.Data],
    src_name: str = None,
    dst_name: str = None
) -> typing.Tuple[typing.Optional[subsets.Subset],
                  typing.Optional[subsets.Subset]]:
    """ Extracts and validates src and dst subsets from the edge.

        If the memlet only provides one of the two subsets, the missing one
        is reconstructed from the descriptor of the corresponding array,
        provided that array's name is known:

        * If the descriptor is a View, or the memlet's data is the array on
          the opposite side, the missing subset is the full array range and
          must cover as many elements as the given subset.
        * Otherwise the given subset is copied and padded/trimmed to the
          dimensionality of the array.

        :param edge: Edge whose memlet (``edge.data``) holds the subsets.
        :param arrays: Mapping from data names to their descriptors.
        :param src_name: Name of the source data. If not given and the edge
                         source is an access node, its data name is used.
        :param dst_name: Name of the destination data. If not given and the
                         edge destination is an access node, its data name
                         is used.
        :return: A tuple ``(src_subset, dst_subset)`` of copies of the
                 memlet's subsets. A missing subset is left as is when the
                 name of its array is unknown.
        :raises NotImplementedError: If neither data name can be determined,
                                     or if the memlet has neither subset.
        :raises ValueError: If a missing subset is replaced by the full
                            array range, but the element counts mismatch.
    """

    # Find src and dst names
    if not src_name and isinstance(edge.src, nodes.AccessNode):
        src_name = edge.src.data
    if not dst_name and isinstance(edge.dst, nodes.AccessNode):
        dst_name = edge.dst.data
    if not src_name and not dst_name:
        raise NotImplementedError(
            "Cannot validate subsets: no source or destination data name "
            "was given and neither edge endpoint is an access node")

    # Find the src and dst subsets (deep-copy to allow manipulation)
    src_subset = copy.deepcopy(edge.data.src_subset)
    dst_subset = copy.deepcopy(edge.data.dst_subset)

    if not src_subset and not dst_subset:
        # NOTE: This should never happen
        raise NotImplementedError(
            "Cannot validate subsets: the memlet has neither a source nor "
            "a destination subset")

    # NOTE: If any of the subsets is None, it means that we proceed in
    # experimental mode. If we can locate the array of the missing side, we
    # either use its full range or copy the other subset and pop or pad
    # indices/ranges according to the array's dimensionality.
    if not src_subset:
        if src_name:
            desc = arrays[src_name]
            if isinstance(desc, data.View) or edge.data.data == dst_name:
                src_subset = subsets.Range.from_array(desc)
                if _element_counts_differ(src_subset, dst_subset):
                    raise ValueError(
                        "Source subset is missing (dst_subset: {}, "
                        "src_shape: {}".format(dst_subset, desc.shape))
            else:
                src_subset = _pad_or_trim_subset(copy.deepcopy(dst_subset),
                                                 len(desc.shape))
    elif not dst_subset:
        if dst_name:
            desc = arrays[dst_name]
            if isinstance(desc, data.View) or edge.data.data == src_name:
                dst_subset = subsets.Range.from_array(desc)
                if _element_counts_differ(src_subset, dst_subset):
                    raise ValueError(
                        "Destination subset is missing (src_subset: {}, "
                        "dst_shape: {}".format(src_subset, desc.shape))
            else:
                dst_subset = _pad_or_trim_subset(copy.deepcopy(src_subset),
                                                 len(desc.shape))

    return src_subset, dst_subset