def _get_copy_dispatcher(self, src_node, dst_node, edge, sdfg, dfg, state_id, function_stream, output_stream):
    """
    (Internal) Returns a code generator that should be dispatched for a
    memory copy operation.

    Dispatchers are looked up by ``(source storage, destination storage,
    schedule)``, first with the schedule of the enclosing scope and then
    with ``None`` as the schedule. Predicate-guarded dispatchers take
    precedence over the generic ones.

    :param src_node: Source node of the copy.
    :param dst_node: Destination node of the copy.
    :param edge: The edge carrying the copy.
    :param sdfg: The SDFG in which the copy takes place.
    :param dfg: Passed through to the dispatcher predicates.
    :param state_id: Index of the state in ``sdfg`` that contains the nodes.
    :param function_stream: Not used by this method.
    :param output_stream: Not used by this method.
    :return: The selected code generator, or None if ``edge`` is the view
             edge of a View source or destination.
    :raise RuntimeError: If more than one predicate is satisfied, or if no
                         dispatcher is registered for the copy.
    """
    state = sdfg.node(state_id)

    # Code nodes have no data descriptor and count as register storage
    src_has_desc = not isinstance(src_node, nodes.CodeNode)
    dst_has_desc = not isinstance(dst_node, nodes.CodeNode)
    src_storage = (src_node.desc(sdfg).storage if src_has_desc else dtypes.StorageType.Register)
    dst_storage = (dst_node.desc(sdfg).storage if dst_has_desc else dtypes.StorageType.Register)

    # Skip copies to/from views where edge matches
    for endpoint, has_desc in ((src_node, src_has_desc), (dst_node, dst_has_desc)):
        if not (has_desc and isinstance(endpoint.desc(sdfg), dt.View)):
            continue
        if sdutil.get_view_edge(state, endpoint) is edge:
            return None

    # Special case: when copying from a tasklet to a non-tasklet, the
    # schedule of the copy is the one of the copying tasklet
    tasklet_to_other = (isinstance(src_node, nodes.Tasklet) and not isinstance(dst_node, nodes.Tasklet))
    scope_entry = state.entry_node(src_node if tasklet_to_other else dst_node)
    dst_schedule = None if scope_entry is None else scope_entry.map.schedule

    scheduled_key = (src_storage, dst_storage, dst_schedule)
    unscheduled_key = (src_storage, dst_storage, None)
    lookup_order = (scheduled_key, unscheduled_key)

    # Collect the dispatchers whose predicate accepts this copy
    predicated = self._copy_dispatchers
    chosen_key = next((key for key in lookup_order if key in predicated), None)
    matching = []
    if chosen_key is not None:
        for predicate, dispatcher in predicated[chosen_key]:
            if predicate(sdfg, dfg, src_node, dst_node) is True:
                matching.append(dispatcher)

    if len(matching) > 1:
        raise RuntimeError("Multiple predicates satisfied for copy: {}".format(", ".join(
            type(candidate).__name__ for candidate in matching)))
    if matching:
        return matching[0]

    # No predicate matched: fall back to the generic copy dispatchers
    generic = self._generic_copy_dispatchers
    for key in lookup_order:
        if key in generic:
            return generic[key]

    raise RuntimeError('Copy dispatcher for %s->%s with schedule %s' %
                       (str(src_storage), str(dst_storage), str(dst_schedule)) + ' not found')
def validate_state(state: 'dace.sdfg.SDFGState',
                   state_id: int = None,
                   sdfg: 'dace.sdfg.SDFG' = None,
                   symbols: Dict[str, dtypes.typeclass] = None):
    """
    Verifies the correctness of an SDFG state by applying multiple tests.
    Raises an InvalidSDFGError with the erroneous node on failure.

    The tests cover the state itself (name, parent, reachability), every
    node (node-level validation, isolation, scope entry/exit matching,
    access-node and connector consistency) and every edge (memlet
    validation, subset dimensionality and bounds, scope crossing rules).

    :param state: The state to validate.
    :param state_id: Index of the state in ``sdfg``. If None, it is obtained
                     from ``sdfg.node_id(state)``.
    :param sdfg: The SDFG containing the state. Defaults to ``state.parent``.
    :param symbols: Mapping whose keys are symbol names; used to detect
                    tasklet connectors that clash with a symbol.
    :raise InvalidSDFGError: On a state-level failure.
    :raise InvalidSDFGNodeError: On a node-level failure.
    :raise InvalidSDFGEdgeError: On an edge-level failure.
    """
    # Avoid import loops
    from dace.sdfg import SDFG
    from dace.config import Config
    from dace.sdfg import nodes as nd
    from dace.sdfg.scope import scope_contains_scope
    from dace import data as dt
    from dace import subsets as sbs

    sdfg = sdfg or state.parent
    # 0 is a valid state index, so only look the index up when none was given
    if state_id is None:
        state_id = sdfg.node_id(state)
    symbols = symbols or {}

    if not dtypes.validate_name(state._label):
        raise InvalidSDFGError("Invalid state name", sdfg, state_id)

    if state._parent != sdfg:
        raise InvalidSDFGError("State does not point to the correct "
                               "parent", sdfg, state_id)

    # Unreachable
    ########################################
    if (sdfg.number_of_nodes() > 1 and sdfg.in_degree(state) == 0 and sdfg.out_degree(state) == 0):
        raise InvalidSDFGError("Unreachable state", sdfg, state_id)

    for nid, node in enumerate(state.nodes()):
        # Node validation
        try:
            node.validate(sdfg, state)
        except InvalidSDFGError:
            raise
        except Exception as ex:
            raise InvalidSDFGNodeError("Node validation failed: " + str(ex), sdfg, state_id, nid) from ex

        # Isolated nodes
        ########################################
        # One corner case: an isolated node is OK if it is a code node
        if (state.in_degree(node) + state.out_degree(node) == 0 and not isinstance(node, nd.CodeNode)):
            raise InvalidSDFGNodeError("Isolated node", sdfg, state_id, nid)

        # Scope tests
        ########################################
        if isinstance(node, nd.EntryNode):
            try:
                state.exit_node(node)
            except StopIteration:
                raise InvalidSDFGNodeError(
                    "Entry node does not have matching "
                    "exit node",
                    sdfg,
                    state_id,
                    nid,
                )

        if isinstance(node, (nd.EntryNode, nd.ExitNode)):
            # Every IN_x connector needs an OUT_x counterpart and vice versa
            for iconn in node.in_connectors:
                if (iconn is not None and iconn.startswith("IN_")
                        and ("OUT_" + iconn[3:]) not in node.out_connectors):
                    raise InvalidSDFGNodeError(
                        "No match for input connector %s in output "
                        "connectors" % iconn,
                        sdfg,
                        state_id,
                        nid,
                    )
            for oconn in node.out_connectors:
                if (oconn is not None and oconn.startswith("OUT_")
                        and ("IN_" + oconn[4:]) not in node.in_connectors):
                    raise InvalidSDFGNodeError(
                        "No match for output connector %s in input "
                        "connectors" % oconn,
                        sdfg,
                        state_id,
                        nid,
                    )

        # Node-specific tests
        ########################################
        if isinstance(node, nd.AccessNode):
            if node.data not in sdfg.arrays:
                raise InvalidSDFGNodeError(
                    "Access node must point to a valid array name in the SDFG",
                    sdfg,
                    state_id,
                    nid,
                )
            arr = sdfg.arrays[node.data]

            # Verify View references
            if isinstance(arr, dt.View):
                from dace.sdfg import utils as sdutil  # Avoid import loops
                if sdutil.get_view_edge(state, node) is None:
                    raise InvalidSDFGNodeError("Ambiguous or invalid edge to/from a View access node", sdfg,
                                               state_id, nid)

            # Find uninitialized transients
            if (arr.transient and state.in_degree(node) == 0 and state.out_degree(node) > 0
                    # Streams do not need to be initialized
                    and not isinstance(arr, dt.Stream)):
                # Look for a written instance of the same data in the
                # predecessor states
                input_found = any(
                    isinstance(onode, nd.AccessNode) and onode.data == node.data and s.in_degree(onode) > 0
                    for s in sdfg.predecessor_states(state) for onode in s.nodes())
                if not input_found and node.setzero == False:
                    warnings.warn('WARNING: Use of uninitialized transient "%s" in state %s' %
                                  (node.data, state.label))

            # Find writes to input-only arrays
            only_empty_inputs = all(e.data.is_empty() for e in state.in_edges(node))
            if (not arr.transient) and (not only_empty_inputs):
                nsdfg_node = sdfg.parent_nsdfg_node
                if nsdfg_node is not None:
                    if node.data not in nsdfg_node.out_connectors:
                        raise InvalidSDFGNodeError(
                            'Data descriptor %s is '
                            'written to, but only given to nested SDFG as an '
                            'input connector' % node.data, sdfg, state_id, nid)

        if isinstance(node, nd.ConsumeEntry):
            if "IN_stream" not in node.in_connectors:
                raise InvalidSDFGNodeError("Consume entry node must have an input stream", sdfg, state_id, nid)
            if "OUT_stream" not in node.out_connectors:
                raise InvalidSDFGNodeError(
                    "Consume entry node must have an internal stream",
                    sdfg,
                    state_id,
                    nid,
                )

        # Connector tests
        ########################################
        # Check for duplicate connector names (unless it's a nested SDFG
        # or a library node)
        dups = node.in_connectors.keys() & node.out_connectors.keys()
        if len(dups) > 0 and not isinstance(node, (nd.NestedSDFG, nd.LibraryNode)):
            raise InvalidSDFGNodeError("Duplicate connectors: " + str(dups), sdfg, state_id, nid)

        # Check for connectors that are also array/symbol names
        if isinstance(node, nd.Tasklet):
            for kind, connectors in (("Input", node.in_connectors), ("Output", node.out_connectors)):
                for conn in connectors.keys():
                    if conn in sdfg.arrays or conn in symbols:
                        raise InvalidSDFGNodeError(f"{kind} connector {conn} already "
                                                   "defined as array or symbol", sdfg, state_id, nid)

        # Check for dangling connectors (incoming)
        for conn in node.in_connectors:
            incoming_edges = sum(1 for e in state.in_edges(node) if e.dst_conn == conn)

            if incoming_edges == 0:
                raise InvalidSDFGNodeError("Dangling in-connector %s" % conn, sdfg, state_id, nid)
            # Connectors may have only one incoming edge
            # Due to input connectors of scope exit, this is only correct
            # in some cases:
            if incoming_edges > 1 and not isinstance(node, nd.ExitNode):
                raise InvalidSDFGNodeError(
                    "Connector '%s' cannot have more "
                    "than one incoming edge, found %d" % (conn, incoming_edges),
                    sdfg,
                    state_id,
                    nid,
                )

        # Check for dangling connectors (outgoing)
        for conn in node.out_connectors:
            outgoing_edges = sum(1 for e in state.out_edges(node) if e.src_conn == conn)

            if outgoing_edges == 0:
                raise InvalidSDFGNodeError("Dangling out-connector %s" % conn, sdfg, state_id, nid)

            # In case of scope exit or code node, only one outgoing edge per
            # connector is allowed.
            if outgoing_edges > 1 and isinstance(node, (nd.ExitNode, nd.CodeNode)):
                raise InvalidSDFGNodeError(
                    "Connector '%s' cannot have more "
                    "than one outgoing edge, found %d" % (conn, outgoing_edges),
                    sdfg,
                    state_id,
                    nid,
                )

        # Check for edges to nonexistent connectors
        for e in state.in_edges(node):
            if e.dst_conn is not None and e.dst_conn not in node.in_connectors:
                raise InvalidSDFGNodeError(
                    "Memlet %s leading to nonexistent connector %s" % (str(e.data), e.dst_conn),
                    sdfg,
                    state_id,
                    nid,
                )
        for e in state.out_edges(node):
            if e.src_conn is not None and e.src_conn not in node.out_connectors:
                raise InvalidSDFGNodeError(
                    "Memlet %s coming from nonexistent connector %s" % (str(e.data), e.src_conn),
                    sdfg,
                    state_id,
                    nid,
                )
        ########################################

    # Memlet checks
    scope = state.scope_dict()
    for eid, e in enumerate(state.edges()):
        # Edge validation
        try:
            e.data.validate(sdfg, state)
        except InvalidSDFGError:
            raise
        except Exception as ex:
            # Chain the cause, as done for node validation above
            raise InvalidSDFGEdgeError("Edge validation failed: " + str(ex), sdfg, state_id, eid) from ex

        # For every memlet, obtain its full path in the DFG
        path = state.memlet_path(e)
        src_node = path[0].src
        dst_node = path[-1].dst
        src_is_access = isinstance(src_node, nd.AccessNode)
        dst_is_access = isinstance(dst_node, nd.AccessNode)

        # Check if memlet data matches src or dst nodes
        if (e.data.data is not None and (src_is_access or dst_is_access)
                and (not src_is_access or e.data.data != src_node.data)
                and (not dst_is_access or e.data.data != dst_node.data)):
            raise InvalidSDFGEdgeError(
                "Memlet data does not match source or destination "
                "data nodes",
                sdfg,
                state_id,
                eid,
            )

        # Check memlet subset validity with respect to source/destination nodes
        if e.data.data is not None and e.data.allow_oob == False:
            # ``subset`` refers to the node named by the memlet,
            # ``other_subset`` to the node on the other side of the path
            subset_node = (dst_node if dst_is_access and e.data.data == dst_node.data else src_node)
            other_subset_node = (dst_node if dst_is_access and e.data.data != dst_node.data else src_node)

            to_check = []
            if isinstance(subset_node, nd.AccessNode):
                to_check.append(("subset", e.data.subset, subset_node))
            # Test other_subset as well
            if e.data.other_subset is not None and isinstance(other_subset_node, nd.AccessNode):
                to_check.append(("other_subset", e.data.other_subset, other_subset_node))

            for label, subset, access in to_check:
                arr = sdfg.arrays[access.data]
                # Dimensionality
                if subset.dims() != len(arr.shape):
                    raise InvalidSDFGEdgeError(
                        "Memlet %s does not match node dimension "
                        "(expected %d, got %d)" % (label, len(arr.shape), subset.dims()),
                        sdfg,
                        state_id,
                        eid,
                    )

                # Bounds
                # NOTE(review): the explicit ``== True`` presumably keeps
                # comparisons that cannot be decided (e.g. symbolic sizes)
                # from being reported as violations - confirm before
                # simplifying.
                if any(((minel + off) < 0) == True for minel, off in zip(subset.min_element(), arr.offset)):
                    raise InvalidSDFGEdgeError("Memlet %s negative out-of-bounds" % label, sdfg, state_id, eid)
                if any(((maxel + off) >= s) == True
                       for maxel, s, off in zip(subset.max_element(), arr.shape, arr.offset)):
                    raise InvalidSDFGEdgeError("Memlet %s out-of-bounds" % label, sdfg, state_id, eid)

            # Test subset and other_subset for undefined symbols
            if Config.get_bool('experimental', 'validate_undefs'):
                # TODO: Traverse by scopes and accumulate data
                defined_symbols = state.symbols_defined_at(e.dst)
                undefs = (e.data.subset.free_symbols - set(defined_symbols.keys()))
                if len(undefs) > 0:
                    raise InvalidSDFGEdgeError('Undefined symbols %s found in memlet subset' % undefs, sdfg,
                                               state_id, eid)
                if e.data.other_subset is not None:
                    undefs = (e.data.other_subset.free_symbols - set(defined_symbols.keys()))
                    if len(undefs) > 0:
                        raise InvalidSDFGEdgeError(
                            'Undefined symbols %s found in memlet '
                            'other_subset' % undefs, sdfg, state_id, eid)
        #######################################

        # Memlet path scope lifetime checks
        # If scope(src) == scope(dst): OK
        if scope[src_node] == scope[dst_node] or src_node == scope[dst_node]:
            pass
        # If scope(src) contains scope(dst), then src must be a data node,
        # unless the memlet is empty in order to connect to a scope
        elif scope_contains_scope(scope, src_node, dst_node):
            pass
        # If scope(dst) contains scope(src), then dst must be a data node,
        # unless the memlet is empty in order to connect to a scope
        elif scope_contains_scope(scope, dst_node, src_node):
            if not dst_is_access:
                if not (e.data.is_empty() and isinstance(dst_node, nd.ExitNode)):
                    raise InvalidSDFGEdgeError(
                        f"Memlet creates an invalid path (sink node {dst_node}"
                        " should be a data node)", sdfg, state_id, eid)
        # If scope(dst) is disjoint from scope(src), it's an illegal memlet
        else:
            raise InvalidSDFGEdgeError("Illegal memlet between disjoint scopes", sdfg, state_id, eid)

        # Check dimensionality of memory access
        if isinstance(e.data.subset, (sbs.Range, sbs.Indices)):
            expected_dims = len(sdfg.arrays[e.data.data].shape)
            if e.data.subset.dims() != expected_dims:
                raise InvalidSDFGEdgeError(
                    "Memlet subset uses the wrong dimensions"
                    " (%dD for a %dD data node)" % (e.data.subset.dims(), expected_dims),
                    sdfg,
                    state_id,
                    eid,
                )

        # Verify that source and destination subsets contain the same
        # number of elements
        # NOTE(review): the body reads ``src_node.data``/``dst_node.data``
        # without checking that both path endpoints are access nodes;
        # confirm that ``other_subset`` is only ever set on such memlets.
        if not e.data.allow_oob and e.data.other_subset is not None and not (
            (src_is_access and isinstance(sdfg.arrays[src_node.data], dt.Stream)) or
            (dst_is_access and isinstance(sdfg.arrays[dst_node.data], dt.Stream))):
            src_expr = (e.data.src_subset.num_elements() * sdfg.arrays[src_node.data].veclen)
            dst_expr = (e.data.dst_subset.num_elements() * sdfg.arrays[dst_node.data].veclen)
            if symbolic.inequal_symbols(src_expr, dst_expr):
                raise InvalidSDFGEdgeError('Dimensionality mismatch between src/dst subsets', sdfg, state_id, eid)
def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
    """
    Checks whether the two maps matched by ``candidate`` can be fused.

    :param graph: The state in which the pattern was matched.
    :param candidate: Mapping from the pattern nodes of ``MapFusion`` to
                      node indices in ``graph``.
    :param expr_index: Not used by this check.
    :param sdfg: The SDFG that contains ``graph``.
    :param strict: Not used by this check.
    :return: True if every check below passes, False otherwise.
    """
    exit_one = graph.nodes()[candidate[MapFusion.first_map_exit]]
    entry_one = graph.entry_node(exit_one)
    entry_two = graph.nodes()[candidate[MapFusion.second_map_entry]]
    exit_two = graph.exit_node(entry_two)

    def _accesses_differ(lhs, rhs):
        """ True if ``lhs`` offset by ``rhs`` (negative) is not (0, 0, 1)
            in some dimension. """
        if isinstance(lhs, subsets.Indices):
            delta = subsets.Range.from_indices(lhs)
            delta.offset(rhs, negative=True)
        else:
            delta = lhs.offset_new(rhs, negative=True)
        return any(dim != (0, 0, 1) for dim in delta)

    # A write-conflict resolution on data that the second map uses
    # rules the fusion out
    for produced in graph.in_edges(exit_one):
        if produced.data.wcr is None:
            continue
        if any(consumed.data.data == produced.data.data for consumed in graph.out_edges(entry_two)):
            return False

    # The first map may only lead to access nodes (map -> access -> map),
    # and each of those containers may have a single access node in the SDFG
    intermediate_nodes = set()
    intermediate_data = set()
    for _, _, target, _, _ in graph.out_edges(exit_one):
        if not isinstance(target, nodes.AccessNode):
            return False
        intermediate_nodes.add(target)
        intermediate_data.add(target.data)
        uses = sum(1 for other_state in sdfg.nodes() for other in other_state.nodes()
                   if isinstance(other, nodes.AccessNode) and other.data == target.data)
        if uses > 1:
            return False

    # Check map ranges
    perm = MapFusion.find_permutation(entry_one.map, entry_two.map)
    if perm is None:
        return False

    # An intermediate node that feeds the second map must not also go to
    # another location
    second_inputs = {edge.src for edge in graph.in_edges(entry_two) if isinstance(edge.src, nodes.AccessNode)}
    if any(graph.out_degree(shared) > 1 for shared in intermediate_nodes & second_inputs):
        return False

    # Map every parameter of the second map to one of the first map
    params_dict = {
        second_param: entry_one.map.params[perm[position]]
        for position, second_param in enumerate(entry_two.map.params)
    }

    # Replace in two steps through temporary names to avoid conflicts,
    # such as {i:j, j:i}
    to_temporary = {
        symbolic.pystr_to_symbolic(old): symbolic.pystr_to_symbolic('__dacesym_' + str(new))
        for old, new in params_dict.items()
    }
    from_temporary = {
        symbolic.pystr_to_symbolic('__dacesym_' + str(new)): symbolic.pystr_to_symbolic(new)
        for new in params_dict.values()
    }

    first_outputs = [edge.data for edge in graph.in_edges(exit_one)]

    # Everything the second map reads from an intermediate array has to be
    # written by the first map
    for second_edge in graph.out_edges(entry_two):
        read_data = second_edge.data.data

        if read_data not in intermediate_data:
            # Unrelated input; it still must not be computed from an
            # intermediate node.
            # NOTE: Assumes graph has networkx version
            if any(
                    graph.memlet_path(second_edge)[0].src in nx.descendants(graph._nx, inode)
                    for inode in intermediate_nodes):
                return False
            continue

        # Express the read subset in the first map's symbols
        read_subset = dcpy(second_edge.data.subset)
        if read_subset:
            read_subset.replace(to_temporary)
            read_subset.replace(from_temporary)

        covered = any(written.data == read_data and written.subset.covers(read_subset)
                      for written in first_outputs)
        if not covered:
            return False

    # Stencil pattern and common input/output data (after fusing the maps)
    first_inputs = {
        edge.src: edge.src.data
        for edge in graph.in_edges(entry_one) if isinstance(edge.src, nodes.AccessNode)
    }
    viewed_inputs = {}
    input_views = {access for access in first_inputs.keys() if isinstance(access.desc(sdfg), data.View)}
    for view in input_views:
        del first_inputs[view]
        view_edge = sdutil.get_view_edge(graph, view)
        if view_edge:
            first_inputs[view_edge.src] = view_edge.src.data
            viewed_inputs[view_edge.src.data] = view

    second_outputs = {
        edge.dst: edge.dst.data
        for edge in graph.out_edges(exit_two) if isinstance(edge.dst, nodes.AccessNode)
    }
    viewed_outputs = {}
    output_views = {access for access in second_outputs if isinstance(access.desc(sdfg), data.View)}
    for view in output_views:
        del second_outputs[view]
        view_edge = sdutil.get_view_edge(graph, view)
        if view_edge:
            second_outputs[view_edge.dst] = view_edge.dst.data
            viewed_outputs[view_edge.dst.data] = view

    common_data = set(first_inputs.values()).intersection(set(second_outputs.values()))
    if not common_data:
        return True

    # All reads of the common data inside the first map have to agree
    input_names = [viewed_inputs[name].data if name in viewed_inputs else name for name in common_data]
    input_accesses = [
        graph.memlet_path(edge)[-1].data.src_subset for edge in graph.out_edges(entry_one)
        if edge.data.data in input_names
    ]
    for position, earlier in enumerate(input_accesses[:-1]):
        for later in input_accesses[position + 1:]:
            if _accesses_differ(earlier, later):
                return False

    output_names = [viewed_outputs[name].data if name in viewed_outputs else name for name in common_data]
    output_accesses = [
        graph.memlet_path(edge)[0].data.dst_subset for edge in graph.in_edges(exit_two)
        if edge.data.data in output_names
    ]

    # Express the writes of the second map in the first map's symbols
    permuted_outputs = [dcpy(access) for access in output_accesses]
    for access in permuted_outputs:
        access.replace(to_temporary)
        access.replace(from_temporary)

    # ... and compare each of them with the first read
    reference = input_accesses[0]
    for written in permuted_outputs:
        if _accesses_differ(reference, written):
            return False

    # Success
    return True
def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
    """
    Checks whether the ``RedundantSecondArray`` transformation may be
    applied to the matched pair of access nodes ``in_array -> out_array``.

    :param graph: The state in which the pattern was matched.
    :param candidate: Mapping from the pattern nodes of
                      ``RedundantSecondArray`` to node indices in ``graph``.
    :param expr_index: Not used by this check.
    :param sdfg: The SDFG that contains ``graph``.
    :param strict: If True, additionally runs the coverage, library-node
                   and data-race checks below.
    :return: True if the transformation can be applied, False otherwise.
    """
    in_array = graph.nodes()[candidate[RedundantSecondArray._in_array]]
    out_array = graph.nodes()[candidate[RedundantSecondArray._out_array]]

    in_desc = in_array.desc(sdfg)
    out_desc = out_array.desc(sdfg)

    # Ensure in degree is one (only one source, which is in_array)
    if graph.in_degree(out_array) != 1:
        return False

    # Make sure that the candidate is a transient variable
    if not out_desc.transient:
        return False

    # 1. Get edge e1 and extract/validate subsets for arrays A and B
    e1 = graph.edges_between(in_array, out_array)[0]
    a_subset, b1_subset = _validate_subsets(e1, sdfg.arrays)

    if strict:
        # In strict mode, make sure the memlet covers the removed array
        if not b1_subset:
            return False
        # Compare the squeezed subset sizes against the shape of out_array
        # with its size-1 dimensions dropped
        subset = copy.deepcopy(b1_subset)
        subset.squeeze()
        shape = [sz for sz in out_desc.shape if sz != 1]
        if any(m != a for m, a in zip(subset.size(), shape)):
            return False

        # NOTE: Library node check
        # The transformation must not apply in strict mode if out_array is
        # not a view, is input to a library node, and an access or a view
        # of in_desc is also output to the same library node.
        # The reason is that the application of the transformation will lead
        # to in_desc being both input and output of the library node.
        # We do not know if this is safe.

        # First find the true in_desc (in case in_array is a view).
        true_in_desc = in_desc
        if isinstance(in_desc, data.View):
            e = sdutil.get_view_edge(graph, in_array)
            if not e:
                return False
            true_in_desc = sdfg.arrays[e.dst.data]
        if not isinstance(out_desc, data.View):

            # Collect the edges that lead from out_array to a library node,
            # either directly or through a View access node
            edges_to_check = []
            for a in graph.out_edges(out_array):
                if isinstance(a.dst, nodes.LibraryNode):
                    edges_to_check.append(a)
                elif (isinstance(a.dst, nodes.AccessNode) and isinstance(sdfg.arrays[a.dst.data], data.View)):
                    for b in graph.out_edges(a.dst):
                        edges_to_check.append(graph.memlet_path(b)[-1])

            # Inspect the access nodes written by those library nodes
            for a in edges_to_check:
                if isinstance(a.dst, nodes.LibraryNode):
                    for b in graph.out_edges(a.dst):
                        if isinstance(b.dst, nodes.AccessNode):
                            desc = sdfg.arrays[b.dst.data]
                            if isinstance(desc, data.View):
                                e = sdutil.get_view_edge(graph, b.dst)
                                if not e:
                                    return False
                                desc = sdfg.arrays[e.dst.data]
                                # NOTE(review): nesting reconstructed from
                                # collapsed source - confirm whether this
                                # comparison is also meant to run for
                                # non-View outputs.
                                if desc is true_in_desc:
                                    return False

        # In strict mode, check if the state has two or more access nodes
        # for in_array and at least one of them is a write access. There
        # might be a RW, WR, or WW dependency.
        accesses = [
            n for n in graph.nodes() if isinstance(n, nodes.AccessNode) and n.desc(sdfg) == in_desc and n is not in_array
        ]
        if len(accesses) > 0:
            if (graph.in_degree(in_array) > 0 or any(graph.in_degree(a) > 0 for a in accesses)):
                # We need to ensure that a data race will not happen if we
                # remove in_array.
                # First, we simplify the graph
                G = helpers.simplify_state(graph)
                # Loop over the accesses
                for a in accesses:
                    # Only accesses with an incoming memlet that intersects
                    # (or may intersect: ``None``) a_subset are of interest
                    subsets_intersect = False
                    for e in graph.in_edges(a):
                        _, subset = _validate_subsets(e, sdfg.arrays, dst_name=a.data)
                        res = subsets.intersects(a_subset, subset)
                        if res == True or res is None:
                            subsets_intersect = True
                            break
                    if not subsets_intersect:
                        continue
                    # Query the simplified graph first; if it raises
                    # NodeNotFound, fall back to the state's own networkx
                    # graph
                    try:
                        has_bward_path = nx.has_path(G, a, in_array)
                    except NodeNotFound:
                        has_bward_path = nx.has_path(graph.nx, a, in_array)
                    try:
                        has_fward_path = nx.has_path(G, in_array, a)
                    except NodeNotFound:
                        has_fward_path = nx.has_path(graph.nx, in_array, a)
                    # If there is no path between the access nodes
                    # (disconnected components), then it is definitely
                    # possible to have data races. Abort.
                    if not (has_bward_path or has_fward_path):
                        return False
                    # If there is a forward path then a must not be a direct
                    # successor of in_array.
                    if has_fward_path and a in G.successors(in_array):
                        # Tolerated predecessors of ``a``: in_array itself,
                        # and nodes other than out_array that are reachable
                        # from in_array. Any other predecessor aborts.
                        for src, _ in G.in_edges(a):
                            if src is in_array:
                                continue
                            if (nx.has_path(G, in_array, src) and src != out_array):
                                continue
                            return False

    # Make sure that both arrays are using the same storage location
    # and are of the same type (e.g., Stream->Stream)
    if in_desc.storage != out_desc.storage:
        return False
    if in_desc.location != out_desc.location:
        return False
    if type(in_desc) != type(out_desc):
        if isinstance(in_desc, data.View):
            # Case View -> Access
            # If the View points to the Access (and has a different shape?)
            # then we should (probably) not remove the Access.
            e = sdutil.get_view_edge(graph, in_array)
            if e and e.dst is out_array and in_desc.shape != out_desc.shape:
                return False
            # Check that the View's immediate ancestors are Accesses.
            # Otherwise, the application of the transformation will result
            # in an ambiguous View.
            view_ancestors_desc = [
                e.src.desc(sdfg) if isinstance(e.src, nodes.AccessNode) else None for e in graph.in_edges(in_array)
            ]
            if any([not desc or isinstance(desc, data.View) for desc in view_ancestors_desc]):
                return False
        elif isinstance(out_desc, data.View):
            # Case Access -> View
            # If the View points to the Access and has the same shape,
            # it can be removed
            # (this branch returns right away; the occurrence and coverage
            # checks below are not reached)
            e = sdutil.get_view_edge(graph, out_array)
            if e and e.src is in_array and in_desc.shape == out_desc.shape:
                return True
            return False
        else:
            # Something else, for example, Stream
            return False
    else:
        # Two views connected to each other
        if isinstance(in_desc, data.View):
            return False

    # Find occurrences in this and other states
    occurrences = []
    for state in sdfg.nodes():
        occurrences.extend(
            [n for n in state.nodes() if isinstance(n, nodes.AccessNode) and n.desc(sdfg) == out_desc])
    # Inter-state edges whose free symbols include the data name count too
    for isedge in sdfg.edges():
        if out_array.data in isedge.data.free_symbols:
            occurrences.append(isedge)

    if len(occurrences) > 1:
        return False

    # Check whether the data copied from the first datanode cover
    # the subsets of all the output edges of the second datanode.
    # We assume the following pattern: A -- e1 --> B -- e2 --> others

    # 2. Iterate over the e2 edges
    for e2 in graph.out_edges(out_array):
        # 2-a. Extract/validate subsets for array B and others
        try:
            b2_subset, _ = _validate_subsets(e2, sdfg.arrays)
        except NotImplementedError:
            return False
        # 2-b. Check whether b1_subset covers b2_subset
        if not b1_subset.covers(b2_subset):
            return False
        # 2-c. Validate subsets in memlet tree
        # (should not be needed for valid SDFGs)
        path = graph.memlet_tree(e2)
        for e3 in path:
            if e3 is not e2:
                try:
                    _validate_subsets(e3, sdfg.arrays, src_name=out_array.data)
                except NotImplementedError:
                    return False

    return True