def validate_state(state: 'dace.sdfg.SDFGState',
                   state_id: int = None,
                   sdfg: 'dace.sdfg.SDFG' = None,
                   symbols: Dict[str, dtypes.typeclass] = None):
    """
    Verifies the correctness of an SDFG state by applying multiple tests.

    The tests run in two passes: first over every node of the state (node
    self-validation, isolated nodes, scope entry/exit matching, access-node
    checks, connector checks), then over every edge (memlet self-validation,
    data/subset consistency, bounds, scope lifetime, element counts).

    :param state: The state to validate.
    :param state_id: Index of the state within ``sdfg``. If None, it is
                     looked up with ``sdfg.node_id(state)``.
    :param sdfg: The SDFG that contains the state. If not given,
                 ``state.parent`` is used.
    :param symbols: Mapping of already-defined symbol names; tasklet
                    connectors may not reuse any of its keys. Defaults to
                    an empty mapping.
    :raises InvalidSDFGError: If a state-level test fails.
    :raises InvalidSDFGNodeError: If a node-level test fails.
    :raises InvalidSDFGEdgeError: If an edge/memlet-level test fails.
    """
    # Avoid import loops
    from dace.sdfg import SDFG
    from dace.config import Config
    from dace.sdfg import nodes as nd
    from dace.sdfg.scope import scope_contains_scope
    from dace import data as dt
    from dace import subsets as sbs

    sdfg = sdfg or state.parent
    # NOTE: Compare against None explicitly; 0 is a valid state index and
    # must not trigger a new lookup.
    if state_id is None:
        state_id = sdfg.node_id(state)
    symbols = symbols or {}

    if not dtypes.validate_name(state._label):
        raise InvalidSDFGError("Invalid state name", sdfg, state_id)

    if state._parent != sdfg:
        raise InvalidSDFGError("State does not point to the correct "
                               "parent", sdfg, state_id)

    # Unreachable
    ########################################
    if (sdfg.number_of_nodes() > 1 and sdfg.in_degree(state) == 0
            and sdfg.out_degree(state) == 0):
        raise InvalidSDFGError("Unreachable state", sdfg, state_id)

    for nid, node in enumerate(state.nodes()):
        # Node validation
        try:
            node.validate(sdfg, state)
        except InvalidSDFGError:
            raise
        except Exception as ex:
            raise InvalidSDFGNodeError("Node validation failed: " + str(ex),
                                       sdfg, state_id, nid) from ex

        # Isolated nodes
        ########################################
        # One corner case: OK if this is a code node
        if (state.in_degree(node) + state.out_degree(node) == 0
                and not isinstance(node, nd.CodeNode)):
            raise InvalidSDFGNodeError("Isolated node", sdfg, state_id, nid)

        # Scope tests
        ########################################
        if isinstance(node, nd.EntryNode):
            try:
                state.exit_node(node)
            except StopIteration:
                raise InvalidSDFGNodeError(
                    "Entry node does not have matching "
                    "exit node",
                    sdfg,
                    state_id,
                    nid,
                )

        if isinstance(node, (nd.EntryNode, nd.ExitNode)):
            # Every "IN_x" connector must be paired with an "OUT_x" connector
            # and vice versa
            for iconn in node.in_connectors:
                if (iconn is not None and iconn.startswith("IN_")
                        and ("OUT_" + iconn[3:]) not in node.out_connectors):
                    raise InvalidSDFGNodeError(
                        "No match for input connector %s in output "
                        "connectors" % iconn,
                        sdfg,
                        state_id,
                        nid,
                    )
            for oconn in node.out_connectors:
                if (oconn is not None and oconn.startswith("OUT_")
                        and ("IN_" + oconn[4:]) not in node.in_connectors):
                    raise InvalidSDFGNodeError(
                        "No match for output connector %s in input "
                        "connectors" % oconn,
                        sdfg,
                        state_id,
                        nid,
                    )

        # Materialize the adjacent edges once; they are scanned several
        # times by the tests below.
        in_edges = list(state.in_edges(node))
        out_edges = list(state.out_edges(node))

        # Node-specific tests
        ########################################
        if isinstance(node, nd.AccessNode):
            if node.data not in sdfg.arrays:
                raise InvalidSDFGNodeError(
                    "Access node must point to a valid array name in the SDFG",
                    sdfg,
                    state_id,
                    nid,
                )
            arr = sdfg.arrays[node.data]

            # Verify View references
            if isinstance(arr, dt.View):
                from dace.sdfg import utils as sdutil  # Avoid import loops
                if sdutil.get_view_edge(state, node) is None:
                    raise InvalidSDFGNodeError(
                        "Ambiguous or invalid edge to/from a View access node",
                        sdfg, state_id, nid)

            # Find uninitialized transients
            if (arr.transient and state.in_degree(node) == 0
                    and state.out_degree(node) > 0
                    # Streams do not need to be initialized
                    and not isinstance(arr, dt.Stream)):
                # Find other instances of node in predecessor states that
                # are written to (i.e., have incoming edges)
                input_found = any(
                    isinstance(onode, nd.AccessNode)
                    and onode.data == node.data and s.in_degree(onode) > 0
                    for s in sdfg.predecessor_states(state)
                    for onode in s.nodes())
                if not input_found and node.setzero == False:
                    warnings.warn(
                        'WARNING: Use of uninitialized transient "%s" in state %s'
                        % (node.data, state.label))

            # Find writes to input-only arrays
            only_empty_inputs = all(e.data.is_empty() for e in in_edges)
            if (not arr.transient) and (not only_empty_inputs):
                nsdfg_node = sdfg.parent_nsdfg_node
                if (nsdfg_node is not None
                        and node.data not in nsdfg_node.out_connectors):
                    raise InvalidSDFGNodeError(
                        'Data descriptor %s is '
                        'written to, but only given to nested SDFG as an '
                        'input connector' % node.data, sdfg, state_id, nid)

        if isinstance(node, nd.ConsumeEntry):
            if "IN_stream" not in node.in_connectors:
                raise InvalidSDFGNodeError(
                    "Consume entry node must have an input stream", sdfg,
                    state_id, nid)
            if "OUT_stream" not in node.out_connectors:
                raise InvalidSDFGNodeError(
                    "Consume entry node must have an internal stream",
                    sdfg,
                    state_id,
                    nid,
                )

        # Connector tests
        ########################################
        # Check for duplicate connector names (unless it's a nested SDFG
        # or a library node)
        dups = node.in_connectors.keys() & node.out_connectors.keys()
        if dups and not isinstance(node, (nd.NestedSDFG, nd.LibraryNode)):
            raise InvalidSDFGNodeError("Duplicate connectors: " + str(dups),
                                       sdfg, state_id, nid)

        # Check for connectors that are also array/symbol names
        if isinstance(node, nd.Tasklet):
            for conn in node.in_connectors:
                if conn in sdfg.arrays or conn in symbols:
                    raise InvalidSDFGNodeError(
                        f"Input connector {conn} already "
                        "defined as array or symbol", sdfg, state_id, nid)
            for conn in node.out_connectors:
                if conn in sdfg.arrays or conn in symbols:
                    raise InvalidSDFGNodeError(
                        f"Output connector {conn} already "
                        "defined as array or symbol", sdfg, state_id, nid)

        # Check for dangling connectors (incoming)
        for conn in node.in_connectors:
            incoming_edges = sum(1 for e in in_edges if e.dst_conn == conn)

            if incoming_edges == 0:
                raise InvalidSDFGNodeError("Dangling in-connector %s" % conn,
                                           sdfg, state_id, nid)
            # Connectors may have only one incoming edge
            # Due to input connectors of scope exit, this is only correct
            # in some cases:
            if incoming_edges > 1 and not isinstance(node, nd.ExitNode):
                raise InvalidSDFGNodeError(
                    "Connector '%s' cannot have more "
                    "than one incoming edge, found %d" %
                    (conn, incoming_edges),
                    sdfg,
                    state_id,
                    nid,
                )

        # Check for dangling connectors (outgoing)
        for conn in node.out_connectors:
            outgoing_edges = sum(1 for e in out_edges if e.src_conn == conn)

            if outgoing_edges == 0:
                raise InvalidSDFGNodeError("Dangling out-connector %s" % conn,
                                           sdfg, state_id, nid)

            # In case of scope exit or code node, only one outgoing edge per
            # connector is allowed.
            if outgoing_edges > 1 and isinstance(node,
                                                 (nd.ExitNode, nd.CodeNode)):
                raise InvalidSDFGNodeError(
                    "Connector '%s' cannot have more "
                    "than one outgoing edge, found %d" %
                    (conn, outgoing_edges),
                    sdfg,
                    state_id,
                    nid,
                )

        # Check for edges to nonexistent connectors
        for e in in_edges:
            if e.dst_conn is not None and e.dst_conn not in node.in_connectors:
                raise InvalidSDFGNodeError(
                    ("Memlet %s leading to " + "nonexistent connector %s") %
                    (str(e.data), e.dst_conn),
                    sdfg,
                    state_id,
                    nid,
                )
        for e in out_edges:
            if (e.src_conn is not None
                    and e.src_conn not in node.out_connectors):
                raise InvalidSDFGNodeError(
                    ("Memlet %s coming from " + "nonexistent connector %s") %
                    (str(e.data), e.src_conn),
                    sdfg,
                    state_id,
                    nid,
                )
    ########################################

    # Memlet checks
    scope = state.scope_dict()
    for eid, e in enumerate(state.edges()):
        # Edge validation
        try:
            e.data.validate(sdfg, state)
        except InvalidSDFGError:
            raise
        except Exception as ex:
            # Chain the original exception, as done for node validation
            raise InvalidSDFGEdgeError("Edge validation failed: " + str(ex),
                                       sdfg, state_id, eid) from ex

        # For every memlet, obtain its full path in the DFG
        path = state.memlet_path(e)
        src_node = path[0].src
        dst_node = path[-1].dst
        src_is_access = isinstance(src_node, nd.AccessNode)
        dst_is_access = isinstance(dst_node, nd.AccessNode)

        # Check if memlet data matches src or dst nodes
        if (e.data.data is not None and (src_is_access or dst_is_access)
                and (not src_is_access or e.data.data != src_node.data)
                and (not dst_is_access or e.data.data != dst_node.data)):
            raise InvalidSDFGEdgeError(
                "Memlet data does not match source or destination "
                "data nodes)",
                sdfg,
                state_id,
                eid,
            )

        # Check memlet subset validity with respect to source/destination nodes
        if e.data.data is not None and e.data.allow_oob == False:
            # `subset` refers to the endpoint whose data the memlet names;
            # `other_subset` refers to the opposite endpoint
            subset_node = (dst_node if dst_is_access
                           and e.data.data == dst_node.data else src_node)
            other_subset_node = (dst_node if dst_is_access
                                 and e.data.data != dst_node.data else
                                 src_node)

            if isinstance(subset_node, nd.AccessNode):
                arr = sdfg.arrays[subset_node.data]
                # Dimensionality
                if e.data.subset.dims() != len(arr.shape):
                    raise InvalidSDFGEdgeError(
                        "Memlet subset does not match node dimension "
                        "(expected %d, got %d)" %
                        (len(arr.shape), e.data.subset.dims()),
                        sdfg,
                        state_id,
                        eid,
                    )

                # Bounds
                # NOTE: The explicit `== True` is kept so that only
                # comparisons that evaluate to a definite True are reported.
                if any(((minel + off) < 0) == True for minel, off in zip(
                        e.data.subset.min_element(), arr.offset)):
                    raise InvalidSDFGEdgeError(
                        "Memlet subset negative out-of-bounds", sdfg,
                        state_id, eid)
                if any(((maxel + off) >= s) == True for maxel, s, off in zip(
                        e.data.subset.max_element(), arr.shape, arr.offset)):
                    raise InvalidSDFGEdgeError("Memlet subset out-of-bounds",
                                               sdfg, state_id, eid)

            # Test other_subset as well
            if e.data.other_subset is not None and isinstance(
                    other_subset_node, nd.AccessNode):
                arr = sdfg.arrays[other_subset_node.data]
                # Dimensionality
                if e.data.other_subset.dims() != len(arr.shape):
                    raise InvalidSDFGEdgeError(
                        "Memlet other_subset does not match node dimension "
                        "(expected %d, got %d)" %
                        (len(arr.shape), e.data.other_subset.dims()),
                        sdfg,
                        state_id,
                        eid,
                    )

                # Bounds
                if any(((minel + off) < 0) == True for minel, off in zip(
                        e.data.other_subset.min_element(), arr.offset)):
                    raise InvalidSDFGEdgeError(
                        "Memlet other_subset negative out-of-bounds",
                        sdfg,
                        state_id,
                        eid,
                    )
                if any(((maxel + off) >= s) == True for maxel, s, off in zip(
                        e.data.other_subset.max_element(), arr.shape,
                        arr.offset)):
                    raise InvalidSDFGEdgeError(
                        "Memlet other_subset out-of-bounds", sdfg, state_id,
                        eid)

            # Test subset and other_subset for undefined symbols
            if Config.get_bool('experimental', 'validate_undefs'):
                # TODO: Traverse by scopes and accumulate data
                defined_symbols = state.symbols_defined_at(e.dst)
                defined_names = set(defined_symbols.keys())
                undefs = e.data.subset.free_symbols - defined_names
                if len(undefs) > 0:
                    raise InvalidSDFGEdgeError(
                        'Undefined symbols %s found in memlet subset' % undefs,
                        sdfg, state_id, eid)
                if e.data.other_subset is not None:
                    undefs = e.data.other_subset.free_symbols - defined_names
                    if len(undefs) > 0:
                        raise InvalidSDFGEdgeError(
                            'Undefined symbols %s found in memlet '
                            'other_subset' % undefs, sdfg, state_id, eid)
        #######################################

        # Memlet path scope lifetime checks
        # If scope(src) == scope(dst): OK
        if scope[src_node] == scope[dst_node] or src_node == scope[dst_node]:
            pass
        # If scope(src) contains scope(dst), then src must be a data node,
        # unless the memlet is empty in order to connect to a scope
        elif scope_contains_scope(scope, src_node, dst_node):
            pass
        # If scope(dst) contains scope(src), then dst must be a data node,
        # unless the memlet is empty in order to connect to a scope
        elif scope_contains_scope(scope, dst_node, src_node):
            if not dst_is_access and not (e.data.is_empty() and isinstance(
                    dst_node, nd.ExitNode)):
                raise InvalidSDFGEdgeError(
                    f"Memlet creates an invalid path (sink node {dst_node}"
                    " should be a data node)", sdfg, state_id, eid)
        # If scope(dst) is disjoint from scope(src), it's an illegal memlet
        else:
            raise InvalidSDFGEdgeError(
                "Illegal memlet between disjoint scopes", sdfg, state_id, eid)

        # Check dimensionality of memory access
        if isinstance(e.data.subset, (sbs.Range, sbs.Indices)):
            if e.data.subset.dims() != len(sdfg.arrays[e.data.data].shape):
                raise InvalidSDFGEdgeError(
                    "Memlet subset uses the wrong dimensions"
                    " (%dD for a %dD data node)" %
                    (e.data.subset.dims(), len(
                        sdfg.arrays[e.data.data].shape)),
                    sdfg,
                    state_id,
                    eid,
                )

        # Verify that source and destination subsets contain the same
        # number of elements. This is only meaningful when both ends of the
        # memlet path are access nodes (their descriptors are needed below)
        # and neither of them is a stream.
        if (not e.data.allow_oob and e.data.other_subset is not None
                and src_is_access and dst_is_access):
            src_desc = sdfg.arrays[src_node.data]
            dst_desc = sdfg.arrays[dst_node.data]
            if not (isinstance(src_desc, dt.Stream)
                    or isinstance(dst_desc, dt.Stream)):
                src_expr = e.data.src_subset.num_elements() * src_desc.veclen
                dst_expr = e.data.dst_subset.num_elements() * dst_desc.veclen
                if symbolic.inequal_symbols(src_expr, dst_expr):
                    raise InvalidSDFGEdgeError(
                        'Dimensionality mismatch between src/dst subsets',
                        sdfg, state_id, eid)
def _pad_or_trim_subset(subset: subsets.Subset, rank: int) -> subsets.Subset:
    """
    Adjusts an Indices/Range subset to ``rank`` dimensions.

    If the dimensionality already matches, the subset is returned untouched.
    Otherwise leading dimensions are padded (with index 0 / range (0, 0, 1))
    or dropped, and the result is offset by itself (negatively) so that it
    starts from 0 in each dimension.

    :param subset: The subset to adjust. May be modified in place.
    :param rank: The number of dimensions of the target data descriptor.
    :return: The adjusted subset.
    """
    padding = rank - len(subset)
    if padding == 0:
        return subset

    if padding > 0:
        # Too few dimensions: prepend unit-size leading dimensions
        if isinstance(subset, subsets.Indices):
            subset = subsets.Indices([0] * padding + subset.indices)
        elif isinstance(subset, subsets.Range):
            subset = subsets.Range([(0, 0, 1)] * padding + subset.ranges)
    else:
        # Too many dimensions: drop the leading ones
        if isinstance(subset, subsets.Indices):
            subset = subsets.Indices(subset.indices[-padding:])
        elif isinstance(subset, subsets.Range):
            subset = subsets.Range(subset.ranges[-padding:])

    # Re-base the subset so that it starts from 0 in each dimension
    subset.offset(subset, True)
    return subset


def _validate_subsets(
        edge: graph.MultiConnectorEdge,
        arrays: typing.Dict[str, data.Data],
        src_name: str = None,
        dst_name: str = None
) -> typing.Tuple[subsets.Subset, subsets.Subset]:
    """
    Extracts and validates src and dst subsets from the edge.

    :param edge: The edge whose memlet (``edge.data``) provides the subsets.
    :param arrays: Mapping from data names to their descriptors.
    :param src_name: Name of the source data. If not given and ``edge.src``
                     is an access node, that node's data name is used.
    :param dst_name: Name of the destination data. If not given and
                     ``edge.dst`` is an access node, that node's data name
                     is used.
    :return: A ``(src_subset, dst_subset)`` pair of deep copies. If one
             subset is missing on the memlet it is inferred from the other
             one; if the corresponding data name is unknown, that subset is
             left as-is (i.e., missing).
    :raises NotImplementedError: If neither data name can be determined, or
                                 if both subsets are missing.
    :raises ValueError: If a missing subset is replaced by the full range of
                        its descriptor and the element counts of the two
                        subsets do not match.
    """

    # Find src and dst names
    if not src_name and isinstance(edge.src, nodes.AccessNode):
        src_name = edge.src.data
    if not dst_name and isinstance(edge.dst, nodes.AccessNode):
        dst_name = edge.dst.data
    if not src_name and not dst_name:
        raise NotImplementedError

    # Find the src and dst subsets (deep-copy to allow manipulation)
    src_subset = copy.deepcopy(edge.data.src_subset)
    dst_subset = copy.deepcopy(edge.data.dst_subset)

    if not src_subset and not dst_subset:
        # NOTE: This should never happen
        raise NotImplementedError

    # NOTE: If any of the subsets is None, it means that we proceed in
    # experimental mode. The base case here is that we just copy the other
    # subset. However, if we can locate the other array, we check the
    # dimensionality of the subset and we pop or pad indices/ranges accordingly.
    # In that case, we also set the subset to start from 0 in each dimension.
    if not src_subset:
        if src_name:
            desc = arrays[src_name]
            if isinstance(desc, data.View) or edge.data.data == dst_name:
                # Assume the whole source array and verify that it holds as
                # many elements as the destination subset
                src_subset = subsets.Range.from_array(desc)
                src_expr = src_subset.num_elements()
                src_expr_exact = src_subset.num_elements_exact()
                dst_expr = dst_subset.num_elements()
                dst_expr_exact = dst_subset.num_elements_exact()
                if (src_expr != dst_expr and symbolic.inequal_symbols(
                        src_expr_exact, dst_expr_exact)):
                    raise ValueError(
                        "Source subset is missing (dst_subset: {}, "
                        "src_shape: {}".format(dst_subset, desc.shape))
            else:
                src_subset = _pad_or_trim_subset(copy.deepcopy(dst_subset),
                                                 len(desc.shape))
    elif not dst_subset:
        if dst_name:
            desc = arrays[dst_name]
            if isinstance(desc, data.View) or edge.data.data == src_name:
                # Assume the whole destination array and verify that it
                # holds as many elements as the source subset
                dst_subset = subsets.Range.from_array(desc)
                src_expr = src_subset.num_elements()
                src_expr_exact = src_subset.num_elements_exact()
                dst_expr = dst_subset.num_elements()
                dst_expr_exact = dst_subset.num_elements_exact()
                if (src_expr != dst_expr and symbolic.inequal_symbols(
                        src_expr_exact, dst_expr_exact)):
                    raise ValueError(
                        "Destination subset is missing (src_subset: {}, "
                        "dst_shape: {}".format(src_subset, desc.shape))
            else:
                dst_subset = _pad_or_trim_subset(copy.deepcopy(src_subset),
                                                 len(desc.shape))

    return src_subset, dst_subset