def apply(self, _, sdfg):
    """Stage the data used by ``self.state`` through FPGA global memory.

    For every array that is visible outside the state (non-transient, or a
    shared transient) and is accessed through a source or sink access node,
    a transient ``fpga_<name>`` array with ``StorageType.FPGA_Global``
    storage is registered. A ``pre_<label>`` state copying the inputs into
    those arrays is inserted before the state, a ``post_<label>`` state
    copying the outputs back is inserted after it, and memlets on the
    state's edges that refer to a staged array are renamed accordingly.
    ``fpga_update`` is called on the state at the end.

    :param _: Unused.
    :param sdfg: The SDFG containing ``self.state``; modified in place.
    """
    state = self.state

    # Find source/sink (data) nodes that are relevant outside this FPGA
    # kernel
    shared_transients = set(sdfg.shared_transients())
    input_nodes = [
        n for n in sdutil.find_source_nodes(state)
        if isinstance(n, nodes.AccessNode) and
        (not sdfg.arrays[n.data].transient or n.data in shared_transients)
    ]
    output_nodes = [
        n for n in sdutil.find_sink_nodes(state)
        if isinstance(n, nodes.AccessNode) and
        (not sdfg.arrays[n.data].transient or n.data in shared_transients)
    ]

    # Original array name -> value returned by ``sdfg.add_array`` for the
    # matching ``fpga_<name>`` array
    fpga_data = {}

    def register_fpga_array(name, desc):
        """Create the ``fpga_<name>`` counterpart of ``desc`` and record it."""
        fpga_array = sdfg.add_array('fpga_' + name,
                                    desc.shape,
                                    desc.dtype,
                                    transient=True,
                                    storage=dtypes.StorageType.FPGA_Global,
                                    allow_conflicts=desc.allow_conflicts,
                                    strides=desc.strides,
                                    offset=desc.offset)
        # The location information moves from the original descriptor to the
        # newly created one
        fpga_array[1].location = copy.copy(desc.location)
        desc.location.clear()
        fpga_data[name] = fpga_array

    # Input nodes may also be nodes with WCR memlets
    # We have to recur across nested SDFGs to find them
    wcr_input_nodes = set()

    parent_sdfg = {state: sdfg}  # Map states to their parent SDFG
    for node, graph in state.all_nodes_recursive():
        if isinstance(graph, dace.SDFG):
            parent_sdfg[node] = graph
        if not isinstance(node, dace.sdfg.nodes.AccessNode):
            continue
        for e in graph.in_edges(node):
            if e.data.wcr is None:
                continue
            trace = dace.sdfg.trace_nested_access(node, graph,
                                                  parent_sdfg[graph])
            for node_trace, memlet_trace, state_trace, sdfg_trace in trace:
                # Find the name of the accessed node in our scope
                if state_trace == state and sdfg_trace == sdfg:
                    _, outer_node = node_trace
                    if outer_node is not None:
                        break
            else:
                # This does not trace back to the current state, so
                # we don't care
                continue
            # A node reached through several WCR edges must be staged only
            # once; otherwise the pre-state would contain duplicate copies
            # of the same array.
            if outer_node not in wcr_input_nodes:
                input_nodes.append(outer_node)
                wcr_input_nodes.add(outer_node)

    if input_nodes:
        # create pre_state
        pre_state = sd.SDFGState('pre_' + state.label, sdfg)

        for node in input_nodes:
            if not isinstance(node, dace.sdfg.nodes.AccessNode):
                continue
            desc = node.desc(sdfg)
            if not isinstance(desc, dace.data.Array):
                # TODO: handle streams
                continue

            # WCR inputs do not register an array here; they keep their
            # node in the state and only get the host-to-device copy.
            if node.data not in fpga_data and node not in wcr_input_nodes:
                register_fpga_array(node.data, desc)

            pre_node = pre_state.add_read(node.data)
            pre_fpga_node = pre_state.add_write('fpga_' + node.data)
            mem = memlet.Memlet(data=node.data,
                                subset=subsets.Range.from_array(desc))
            pre_state.add_edge(pre_node, None, pre_fpga_node, None, mem)

            if node not in wcr_input_nodes:
                fpga_node = state.add_read('fpga_' + node.data)
                sdutil.change_edge_src(state, node, fpga_node)
                state.remove_node(node)

        sdfg.add_node(pre_state)
        sdutil.change_edge_dest(sdfg, state, pre_state)
        sdfg.add_edge(pre_state, state, sd.InterstateEdge())

    if output_nodes:
        post_state = sd.SDFGState('post_' + state.label, sdfg)

        for node in output_nodes:
            if not isinstance(node, dace.sdfg.nodes.AccessNode):
                continue
            desc = node.desc(sdfg)
            if not isinstance(desc, dace.data.Array):
                # TODO: handle streams
                continue

            if node.data not in fpga_data:
                register_fpga_array(node.data, desc)

            post_node = post_state.add_write(node.data)
            post_fpga_node = post_state.add_read('fpga_' + node.data)
            mem = memlet.Memlet(f"fpga_{node.data}", None,
                                subsets.Range.from_array(desc))
            post_state.add_edge(post_fpga_node, None, post_node, None, mem)

            fpga_node = state.add_write('fpga_' + node.data)
            sdutil.change_edge_dest(state, node, fpga_node)
            state.remove_node(node)

        sdfg.add_node(post_state)
        sdutil.change_edge_src(sdfg, state, post_state)
        sdfg.add_edge(state, post_state, sd.InterstateEdge())

    # propagate memlet info from a nested sdfg
    for edge in state.edges():
        mem = edge.data
        if mem.data is not None and mem.data in fpga_data:
            mem.data = 'fpga_' + mem.data

    fpga_update(sdfg, state, 0)
def apply(self, sdfg):
    """Stage the arrays of the matched state through ``fpga_<name>`` copies.

    Source and sink array access nodes of the matched state (plus targets of
    WCR memlets found in nested scopes) are redirected to transient
    ``fpga_<name>`` arrays with ``StorageType.FPGA_Global`` storage. Copies
    into those arrays are placed in a new ``pre_<label>`` state and copies
    back in a new ``post_<label>`` state. Memlets on the state's edges are
    then renamed and ``fpga_update`` is called.

    :param sdfg: The SDFG containing the matched state; modified in place.
    """
    access_node_type = dace.sdfg.nodes.AccessNode
    nested_sdfg_type = dace.sdfg.nodes.NestedSDFG

    state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

    # Find source/sink (data) nodes
    input_nodes = sdutil.find_source_nodes(state)
    output_nodes = sdutil.find_sink_nodes(state)

    # Original array name -> value returned by ``sdfg.add_array``
    fpga_data = {}

    def add_fpga_array(name, desc):
        """Register the ``fpga_<name>`` array mirroring ``desc``."""
        fpga_data[name] = sdfg.add_array(
            'fpga_' + name,
            desc.shape,
            desc.dtype,
            materialize_func=desc.materialize_func,
            transient=True,
            storage=dtypes.StorageType.FPGA_Global,
            allow_conflicts=desc.allow_conflicts,
            strides=desc.strides,
            offset=desc.offset)

    def whole_array_memlet(data_name, desc):
        """Build a memlet on ``data_name`` covering every element of ``desc``."""
        full_range = subsets.Range([(0, s - 1, 1) for s in desc.shape])
        return memlet.Memlet(data_name, full_range.num_elements(),
                             full_range, 1)

    def veclen_from_nested(nested_node, connector, current):
        """Return the vector length found on the inner array ``connector``.

        The value of the last matching access node wins; ``current`` is
        returned unchanged when nothing matches.
        """
        for inner_state in nested_node.sdfg.states():
            for inner in inner_state.nodes():
                if isinstance(inner, access_node_type) and \
                        inner.data == connector:
                    # assuming all memlets have the same vector length
                    current = inner_state.all_edges(inner)[0].data.veclen
        return current

    # Input nodes may also be nodes with WCR memlets
    # We have to recur across nested SDFGs to find them
    wcr_input_nodes = set()
    parent_sdfg = {state: sdfg}  # Map states to their parent SDFG
    for node, graph in state.all_nodes_recursive():
        if isinstance(graph, dace.SDFG):
            parent_sdfg[node] = graph
        if not isinstance(node, access_node_type):
            continue
        for e in graph.all_edges(node):
            if e.data.wcr is None:
                continue
            trace = dace.sdfg.trace_nested_access(node, graph,
                                                  parent_sdfg[graph])
            for node_trace, state_trace, sdfg_trace in trace:
                # Find the name of the accessed node in our scope
                if state_trace == state and sdfg_trace == sdfg:
                    outer_node = node_trace
                    break
            else:
                # This does not trace back to the current state, so
                # we don't care
                continue
            input_nodes.append(outer_node)
            wcr_input_nodes.add(outer_node)

    if input_nodes:
        # Host-to-device copies live in a state of their own
        pre_state = sd.SDFGState('pre_' + state.label, sdfg)

        for node in input_nodes:
            if not isinstance(node, access_node_type):
                continue
            desc = node.desc(sdfg)
            if not isinstance(desc, dace.data.Array):
                # TODO: handle streams
                continue

            is_wcr = node in wcr_input_nodes
            if node.data not in fpga_data and not is_wcr:
                add_fpga_array(node.data, desc)

            host_node = pre_state.add_read(node.data)
            device_node = pre_state.add_write('fpga_' + node.data)
            pre_state.add_edge(host_node, None, device_node, None,
                               whole_array_memlet(node.data, desc))

            if not is_wcr:
                replacement = state.add_read('fpga_' + node.data)
                sdutil.change_edge_src(state, node, replacement)
                state.remove_node(node)

        sdfg.add_node(pre_state)
        sdutil.change_edge_dest(sdfg, state, pre_state)
        sdfg.add_edge(pre_state, state, sd.InterstateEdge())

    if output_nodes:
        # Device-to-host copies live in a state of their own
        post_state = sd.SDFGState('post_' + state.label, sdfg)

        for node in output_nodes:
            if not isinstance(node, access_node_type):
                continue
            desc = node.desc(sdfg)
            if not isinstance(desc, dace.data.Array):
                # TODO: handle streams
                continue

            if node.data not in fpga_data:
                add_fpga_array(node.data, desc)

            host_node = post_state.add_write(node.data)
            device_node = post_state.add_read('fpga_' + node.data)
            post_state.add_edge(device_node, None, host_node, None,
                                whole_array_memlet('fpga_' + node.data, desc))

            replacement = state.add_write('fpga_' + node.data)
            sdutil.change_edge_dest(state, node, replacement)
            state.remove_node(node)

        sdfg.add_node(post_state)
        sdutil.change_edge_src(sdfg, state, post_state)
        sdfg.add_edge(state, post_state, sd.InterstateEdge())

    # propagate vector info from a nested sdfg
    # NOTE(review): veclen_ is never reset between edges, so a value picked
    # up from one nested SDFG carries over to later edges.
    veclen_ = 1
    for src, src_conn, dst, dst_conn, mem in state.edges():
        # need to go inside the nested SDFG and grab the vector length
        if isinstance(dst, nested_sdfg_type):
            # this edge is going to the nested SDFG
            veclen_ = veclen_from_nested(dst, dst_conn, veclen_)
        if isinstance(src, nested_sdfg_type):
            # this edge is coming from the nested SDFG
            veclen_ = veclen_from_nested(src, src_conn, veclen_)

        if mem.data is not None and mem.data in fpga_data:
            mem.data = 'fpga_' + mem.data
            # NOTE(review): the whitespace-flattened original does not show
            # whether this assignment is nested under the condition above;
            # nesting is assumed here -- confirm against version history.
            mem.veclen = veclen_

    fpga_update(sdfg, state, 0)
def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
    """Decide whether the two matched states may be fused.

    :param graph: The state machine graph containing both states.
    :param candidate: Mapping from the pattern nodes
                      ``StateFusion._first_state`` and
                      ``StateFusion._second_state`` to node indices in
                      ``graph``.
    :param expr_index: Unused.
    :param sdfg: Unused.
    :param strict: If True, additionally reject fusions that could introduce
                   read-write or write-write dependencies between the states.
    :return: True if fusion is allowed, False otherwise.
    """
    first_state = graph.nodes()[candidate[StateFusion._first_state]]
    second_state = graph.nodes()[candidate[StateFusion._second_state]]

    out_edges = graph.out_edges(first_state)
    in_edges = graph.in_edges(first_state)

    # First state must have only one output edge (with dst the second
    # state).
    if len(out_edges) != 1:
        return False
    # The interstate edge must not have a condition.
    if not out_edges[0].data.is_unconditional():
        return False
    # The interstate edge may have assignments, as long as there are input
    # edges to the first state that can absorb them.
    if out_edges[0].data.assignments:
        if not in_edges:
            return False
        # Fail if symbol is set before the state to fuse
        # TODO: Also fail if symbol is used in the dataflow of that state
        new_assignments = set(out_edges[0].data.assignments.keys())
        if any((new_assignments & set(e.data.assignments.keys()))
               for e in in_edges):
            return False

    # There can be no state that has output edges pointing to both the
    # first and the second state. Such a case will produce a multi-graph.
    for src, _, _ in in_edges:
        for _, dst, _ in graph.out_edges(src):
            if dst == second_state:
                return False

    if strict:
        # If second state has other input edges, there might be issues
        # Exceptions are when none of the states contain dataflow, unless
        # the first state is an initial state (in which case the new initial
        # state would be ambiguous).
        second_in_edges = graph.in_edges(second_state)
        if ((not second_state.is_empty() or not first_state.is_empty()
             or len(in_edges) == 0) and len(second_in_edges) != 1):
            return False

        # Get connected components of the first state.
        first_cc = list(nx.weakly_connected_components(first_state._nx))

        # Find source/sink (data) nodes. "Output" here means every access
        # node that is not a source node of its state.
        first_input = {
            node
            for node in sdutil.find_source_nodes(first_state)
            if isinstance(node, nodes.AccessNode)
        }
        first_output = {
            node
            for node in first_state.nodes()
            if isinstance(node, nodes.AccessNode) and node not in first_input
        }
        second_input = {
            node
            for node in sdutil.find_source_nodes(second_state)
            if isinstance(node, nodes.AccessNode)
        }
        second_output = {
            node
            for node in second_state.nodes()
            if isinstance(node, nodes.AccessNode) and node not in second_input
        }

        # Output nodes grouped by connected component
        first_cc_output = [cc.intersection(first_output) for cc in first_cc]

        # Labels are compared repeatedly below; collect them once
        second_input_labels = {node.label for node in second_input}
        second_output_labels = {node.label for node in second_output}

        # Apply transformation in case all paths to the second state's
        # nodes go through the same access node, which implies sequential
        # behavior in SDFG semantics.
        check_strict = len(first_cc)
        first_sinks = first_state.sink_nodes()
        for cc_output in first_cc_output:
            out_nodes = [n for n in first_sinks if n in cc_output]
            # Branching exists, multiple paths may involve same access node
            # potentially causing data races
            if len(out_nodes) > 1:
                continue

            # Otherwise, check whether the component's single sink is read
            # as an input of the second state
            if any(node.label in second_input_labels for node in out_nodes):
                check_strict -= 1

        if check_strict > 0:
            # Check strict conditions
            # RW dependency
            if any(node.label in second_output_labels
                   for node in first_input):
                return False
            # WW dependency
            if any(node.label in second_output_labels
                   for node in first_output):
                return False

    return True
def apply(self, sdfg):
    """Fuse the second matched state into the first one.

    Assignments on the interstate edge between the two states are moved onto
    every edge entering the first state, and the edge is removed. If either
    state is empty it is simply dropped and its edges are redirected.
    Otherwise all nodes and edges of the second state are added to the
    first, source access nodes of the second state are merged with an access
    node of the same label in the first, and the second state is removed.

    :param sdfg: The SDFG containing both states; modified in place.
    """
    first_state = sdfg.nodes()[self.subgraph[StateFusion._first_state]]
    second_state = sdfg.nodes()[self.subgraph[StateFusion._second_state]]

    # Remove interstate edge(s)
    edges = sdfg.edges_between(first_state, second_state)
    for edge in edges:
        if edge.data.assignments:
            for in_edge in sdfg.in_edges(first_state):
                in_edge.data.assignments.update(edge.data.assignments)
        sdfg.remove_edge(edge)

    # Special case 1: first state is empty
    if first_state.is_empty():
        sdutil.change_edge_dest(sdfg, first_state, second_state)
        sdfg.remove_node(first_state)
        return

    # Special case 2: second state is empty
    if second_state.is_empty():
        sdutil.change_edge_src(sdfg, second_state, first_state)
        sdutil.change_edge_dest(sdfg, second_state, first_state)
        sdfg.remove_node(second_state)
        return

    # Normal case: both states are not empty

    # Source (data) nodes of the second state; these are the merge candidates
    second_input = [
        node for node in sdutil.find_source_nodes(second_state)
        if isinstance(node, nodes.AccessNode)
    ]

    # Merge second state to first state
    # First keep a backup of the topological sorted order of the nodes.
    # It must be taken before the merge, so that only nodes originally in
    # the first state are considered as merge targets.
    order = [
        x for x in reversed(list(nx.topological_sort(first_state._nx)))
        if isinstance(x, nodes.AccessNode)
    ]
    for node in second_state.nodes():
        first_state.add_node(node)
    for src, src_conn, dst, dst_conn, data in second_state.edges():
        first_state.add_edge(src, src_conn, dst, dst_conn, data)

    # Merge common (data) nodes
    for node in second_input:
        if first_state.in_degree(node) == 0:
            # Topologically-last access node of the first state with the
            # same label, if any
            n = next((x for x in order if x.label == node.label), None)
            if n is not None:
                sdutil.change_edge_src(first_state, node, n)
                first_state.remove_node(node)
                n.access = dtypes.AccessType.ReadWrite

    # Redirect edges and remove second state
    sdutil.change_edge_src(sdfg, second_state, first_state)
    sdfg.remove_node(second_state)
    # NOTE(review): the counter is only updated on this path, not when one
    # of the empty-state special cases returns early -- confirm intent.
    if Config.get_bool("debugprint"):
        StateFusion._states_fused += 1
def apply(self, sdfg):
    """Fuse the second matched state into the first one.

    Assignments on the interstate edge between the two states are moved onto
    every edge entering the first state, and the edge is removed. If either
    state is empty it is dropped and its edges are redirected. Otherwise all
    nodes and edges of the second state are added to the first, top-level
    source access nodes of the second state are merged with a top-level
    access node of the same data in the first, and the second state is
    removed. If a removed state was the SDFG's start state, the surviving
    state becomes the start state.

    :param sdfg: The SDFG containing both states; modified in place.
    """
    # Workaround for supporting old and new conventions: the subgraph may
    # hold the states themselves or their node indices
    if isinstance(self.subgraph[StateFusion.first_state], SDFGState):
        first_state: SDFGState = self.subgraph[StateFusion.first_state]
        second_state: SDFGState = self.subgraph[StateFusion.second_state]
    else:
        first_state: SDFGState = sdfg.node(
            self.subgraph[StateFusion.first_state])
        second_state: SDFGState = sdfg.node(
            self.subgraph[StateFusion.second_state])

    # Capture the start state before the graph is mutated. Asking the SDFG
    # for its start state only after a state has been removed cannot be
    # relied upon to still identify the removed state.
    start_state = sdfg.start_state

    # Remove interstate edge(s)
    edges = sdfg.edges_between(first_state, second_state)
    for edge in edges:
        if edge.data.assignments:
            for in_edge in sdfg.in_edges(first_state):
                in_edge.data.assignments.update(edge.data.assignments)
        sdfg.remove_edge(edge)

    # Special case 1: first state is empty
    if first_state.is_empty():
        sdutil.change_edge_dest(sdfg, first_state, second_state)
        sdfg.remove_node(first_state)
        if start_state == first_state:
            sdfg.start_state = sdfg.node_id(second_state)
        return

    # Special case 2: second state is empty
    if second_state.is_empty():
        sdutil.change_edge_src(sdfg, second_state, first_state)
        sdutil.change_edge_dest(sdfg, second_state, first_state)
        sdfg.remove_node(second_state)
        if start_state == second_state:
            sdfg.start_state = sdfg.node_id(first_state)
        return

    # Normal case: both states are not empty

    # Source (data) nodes of the second state; these are the merge candidates
    second_input = [
        node for node in sdutil.find_source_nodes(second_state)
        if isinstance(node, nodes.AccessNode)
    ]

    top2 = top_level_nodes(second_state)

    # Merge second state to first state
    # First keep a backup of the topological sorted order of the nodes.
    # Both the scope dictionary and the order are taken before the merge, so
    # that only nodes originally in the first state are merge targets.
    sdict = first_state.scope_dict()
    order = [
        x for x in reversed(list(nx.topological_sort(first_state._nx)))
        if isinstance(x, nodes.AccessNode) and sdict[x] is None
    ]
    for node in second_state.nodes():
        if isinstance(node, nodes.NestedSDFG):
            # update parent information
            node.sdfg.parent = first_state
        first_state.add_node(node)
    for src, src_conn, dst, dst_conn, data in second_state.edges():
        first_state.add_edge(src, src_conn, dst, dst_conn, data)

    top = top_level_nodes(first_state)

    # Merge common (data) nodes
    for node in second_input:
        # merge only top level nodes, skip everything else
        if node not in top2:
            continue

        if first_state.in_degree(node) == 0:
            candidates = [
                x for x in order if x.data == node.data and x in top
            ]
            if len(candidates) == 0:
                continue
            elif len(candidates) == 1:
                n = candidates[0]
            else:
                # Choose first candidate that intersects memlets
                for cand in candidates:
                    if StateFusion.memlets_intersect(first_state, [cand],
                                                     False, second_state,
                                                     [node], True):
                        n = cand
                        break
                else:
                    # No node intersects, use topologically-last node
                    n = candidates[0]

            sdutil.change_edge_src(first_state, node, n)
            first_state.remove_node(node)
            n.access = dtypes.AccessType.ReadWrite

    # Redirect edges and remove second state
    sdutil.change_edge_src(sdfg, second_state, first_state)
    sdfg.remove_node(second_state)
    if start_state == second_state:
        sdfg.start_state = sdfg.node_id(first_state)
def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
    """Decide whether the two matched states may be fused.

    The non-strict part checks the control flow around the two states: a
    single unconditional edge from the first to the second state, no
    ambiguity from multiple incoming edges, and interstate assignments that
    can be moved onto the edges entering the first state. With ``strict``
    set, the dataflow of both states is additionally analyzed per connected
    component for write-write, read-write and read-after-write hazards.

    :param graph: The state machine graph containing both states.
    :param candidate: Mapping from ``StateFusion.first_state`` and
                      ``StateFusion.second_state`` to either the states
                      themselves or their node indices in ``graph``.
    :param expr_index: Not used by this function.
    :param sdfg: Not used by this function.
    :param strict: If True, also run the data hazard checks.
    :return: True if fusion is allowed, False otherwise.
    """
    # Workaround for supporting old and new conventions
    if isinstance(candidate[StateFusion.first_state], SDFGState):
        first_state: SDFGState = candidate[StateFusion.first_state]
        second_state: SDFGState = candidate[StateFusion.second_state]
    else:
        first_state: SDFGState = graph.node(
            candidate[StateFusion.first_state])
        second_state: SDFGState = graph.node(
            candidate[StateFusion.second_state])

    out_edges = graph.out_edges(first_state)
    in_edges = graph.in_edges(first_state)

    # First state must have only one output edge (with dst the second
    # state).
    if len(out_edges) != 1:
        return False
    # If both states have more than one incoming edge, some control flow
    # may become ambiguous
    if len(in_edges) > 1 and graph.in_degree(second_state) > 1:
        return False
    # The interstate edge must not have a condition.
    if not out_edges[0].data.is_unconditional():
        return False
    # The interstate edge may have assignments, as long as there are input
    # edges to the first state that can absorb them.
    if out_edges[0].data.assignments:
        if not in_edges:
            return False
        # Fail if symbol is set before the state to fuse
        new_assignments = set(out_edges[0].data.assignments.keys())
        if any((new_assignments & set(e.data.assignments.keys()))
               for e in in_edges):
            return False
        # Fail if symbol is used in the dataflow of that state
        if len(new_assignments & first_state.free_symbols) > 0:
            return False
        # Fail if assignments have free symbols that are updated in the
        # first state
        freesyms = out_edges[0].data.free_symbols
        if freesyms and any(n.data in freesyms for n in first_state.nodes()
                            if isinstance(n, nodes.AccessNode)
                            and first_state.in_degree(n) > 0):
            return False
        # Fail if symbols assigned on the first edge are free symbols on the
        # second edge
        symbols_used = set(out_edges[0].data.free_symbols)
        for e in in_edges:
            if e.data.assignments.keys() & symbols_used:
                return False

    # There can be no state that has output edges pointing to both the
    # first and the second state. Such a case will produce a multi-graph.
    for src, _, _ in in_edges:
        for _, dst, _ in graph.out_edges(src):
            if dst == second_state:
                return False

    if strict:
        # NOTE: This is quick fix for MPI Waitall (probably also needed for
        # Wait), until we have a better SDFG representation of the buffer
        # dependencies.
        # ``next`` raises StopIteration when no such node exists, in which
        # case the check passes; finding one rejects the fusion.
        try:
            from dace.libraries.mpi import Waitall
            next(node for node in first_state.nodes()
                 if isinstance(node, Waitall) or node.label == '_Waitall_')
            return False
        except StopIteration:
            pass
        try:
            from dace.libraries.mpi import Waitall
            next(node for node in second_state.nodes()
                 if isinstance(node, Waitall) or node.label == '_Waitall_')
            return False
        except StopIteration:
            pass

        # If second state has other input edges, there might be issues
        # Exceptions are when none of the states contain dataflow, unless
        # the first state is an initial state (in which case the new initial
        # state would be ambiguous).
        first_in_edges = graph.in_edges(first_state)
        second_in_edges = graph.in_edges(second_state)
        if ((not second_state.is_empty() or not first_state.is_empty()
             or len(first_in_edges) == 0) and len(second_in_edges) != 1):
            return False

        # Get connected components.
        first_cc = [
            cc_nodes
            for cc_nodes in nx.weakly_connected_components(first_state._nx)
        ]
        second_cc = [
            cc_nodes
            for cc_nodes in nx.weakly_connected_components(second_state._nx)
        ]

        # Find source/sink (data) nodes. "Output" here means every top-level
        # access node that is not a source node of its state.
        first_input = {
            node
            for node in sdutil.find_source_nodes(first_state)
            if isinstance(node, nodes.AccessNode)
        }
        first_output = {
            node
            for node in first_state.scope_children()[None]
            if isinstance(node, nodes.AccessNode) and node not in first_input
        }
        second_input = {
            node
            for node in sdutil.find_source_nodes(second_state)
            if isinstance(node, nodes.AccessNode)
        }
        second_output = {
            node
            for node in second_state.scope_children()[None]
            if isinstance(node, nodes.AccessNode)
            and node not in second_input
        }

        # Find source/sink (data) nodes by connected component
        first_cc_input = [cc.intersection(first_input) for cc in first_cc]
        first_cc_output = [cc.intersection(first_output) for cc in first_cc]
        second_cc_input = [
            cc.intersection(second_input) for cc in second_cc
        ]
        second_cc_output = [
            cc.intersection(second_output) for cc in second_cc
        ]

        # Apply transformation in case all paths to the second state's
        # nodes go through the same access node, which implies sequential
        # behavior in SDFG semantics.
        first_output_names = {node.data for node in first_output}
        second_input_names = {node.data for node in second_input}

        # If any second input appears more than once, fail
        if len(second_input) > len(second_input_names):
            return False

        # If any first output that is an input to the second state
        # appears in more than one CC, fail
        matches = first_output_names & second_input_names
        for match in matches:
            cc_appearances = 0
            for cc in first_cc_output:
                if len([n for n in cc if n.data == match]) > 0:
                    cc_appearances += 1
            if cc_appearances > 1:
                return False

        # Recreate fused connected component correspondences, and then
        # check for hazards
        resulting_ccs: List[CCDesc] = StateFusion.find_fused_components(
            first_cc_input, first_cc_output, second_cc_input,
            second_cc_output)

        # Check for data races
        for fused_cc in resulting_ccs:
            # Write-Write hazard - data is output of both first and second
            # states, without a read in between
            write_write_candidates = (
                (fused_cc.first_outputs & fused_cc.second_outputs) -
                fused_cc.second_inputs)

            # Find the leaf (topological) instances of the matches
            order = [
                x for x in reversed(
                    list(nx.topological_sort(first_state._nx)))
                if isinstance(x, nodes.AccessNode)
                and x.data in fused_cc.first_outputs
            ]
            # Those nodes will be the connection points upon fusion
            match_nodes = {
                next(n for n in order if n.data == match)
                for match in (fused_cc.first_outputs
                              & fused_cc.second_inputs)
            }

            # If we have potential candidates, check if there is a
            # path from the first write to the second write (in that
            # case, there is no hazard):
            for cand in write_write_candidates:
                nodes_first = [n for n in first_output if n.data == cand]
                nodes_second = [n for n in second_output if n.data == cand]

                # If there is a path for the candidate that goes through
                # the match nodes in both states, there is no conflict
                fail = False
                path_found = False
                for match in match_nodes:
                    for node in nodes_first:
                        path_to = nx.has_path(first_state._nx, node, match)
                        if not path_to:
                            continue
                        path_found = True
                        node2 = next(n for n in second_input
                                     if n.data == match.data)
                        if not all(
                                nx.has_path(second_state._nx, node2, n)
                                for n in nodes_second):
                            fail = True
                            break
                    if fail or path_found:
                        break

                # Check for intersection (if None, fusion is ok)
                if fail or not path_found:
                    if StateFusion.memlets_intersect(
                            first_state, nodes_first, False, second_state,
                            nodes_second, False):
                        return False
            # End of write-write hazard check

            first_inout = fused_cc.first_inputs | fused_cc.first_outputs
            for other_cc in resulting_ccs:
                # NOTE: Special handling for `other_cc is fused_cc`
                if other_cc is fused_cc:
                    # Checking for potential Read-Write data races
                    for d in first_inout:
                        if d in other_cc.second_outputs:
                            nodes_second = [
                                n for n in second_output if n.data == d
                            ]
                            # Read-Write race
                            if d in fused_cc.first_inputs:
                                nodes_first = [
                                    n for n in first_input if n.data == d
                                ]
                            else:
                                nodes_first = []
                            # For every write in the second state, follow
                            # its memlet path back to its origin. If that
                            # origin is a second-state input produced by
                            # the first state, every first-state read of
                            # `d` must have a path to the producing node.
                            for n2 in nodes_second:
                                for e in second_state.in_edges(n2):
                                    path = second_state.memlet_path(e)
                                    src = path[0].src
                                    if src in second_input and src.data in fused_cc.first_outputs:
                                        for n1 in fused_cc.first_output_nodes:
                                            if n1.data == src.data:
                                                for n0 in nodes_first:
                                                    if not nx.has_path(
                                                            first_state._nx,
                                                            n0, n1):
                                                        return False
                    continue
                # If an input/output of a connected component in the first
                # state is an output of another connected component in the
                # second state, we have a potential data race (Read-Write
                # or Write-Write)
                for d in first_inout:
                    if d in other_cc.second_outputs:
                        # Check for intersection (if None, fusion is ok)
                        nodes_second = [
                            n for n in second_output if n.data == d
                        ]
                        # Read-Write race
                        if d in fused_cc.first_inputs:
                            nodes_first = [
                                n for n in first_input if n.data == d
                            ]
                            if StateFusion.memlets_intersect(
                                    first_state, nodes_first, True,
                                    second_state, nodes_second, False):
                                return False
                        # Write-Write race
                        if d in fused_cc.first_outputs:
                            nodes_first = [
                                n for n in first_output if n.data == d
                            ]
                            if StateFusion.memlets_intersect(
                                    first_state, nodes_first, False,
                                    second_state, nodes_second, False):
                                return False
            # End of data race check

            # Read-after-write dependencies: if there is an output of the
            # second state that is an input of the first, ensure all paths
            # from the input of the first state lead to the output.
            # Otherwise, there may be a RAW due to topological sort or
            # concurrency.
            second_inout = ((fused_cc.first_inputs | fused_cc.first_outputs)
                            & fused_cc.second_outputs)
            for inout in second_inout:
                nodes_first = [
                    n for n in match_nodes if n.data == inout
                ]
                if any(first_state.out_degree(n) > 0 for n in nodes_first):
                    return False

            # Read-after-write dependencies: if there is more than one first
            # output with the same data, make sure it can be unambiguously
            # connected to the second state
            if (len(fused_cc.first_output_nodes) > len(
                    fused_cc.first_outputs)):
                for inpnode in fused_cc.second_input_nodes:
                    found = None
                    for outnode in fused_cc.first_output_nodes:
                        if outnode.data != inpnode.data:
                            continue
                        if StateFusion.memlets_intersect(
                                first_state, [outnode], False, second_state,
                                [inpnode], True):
                            # If found more than once, either there is a
                            # path from one to another or it is ambiguous
                            if found is not None:
                                if nx.has_path(first_state.nx, outnode,
                                               found):
                                    # Found is a descendant, continue
                                    continue
                                elif nx.has_path(first_state.nx, found,
                                                 outnode):
                                    # New node is a descendant, set as found
                                    found = outnode
                                else:
                                    # No path: ambiguous match
                                    return False
                            found = outnode

    return True
def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
    """Decide whether the two matched states may be fused.

    Control-flow requirements are always checked: a single unconditional
    edge leaves the first state, its assignments (if any) can be absorbed by
    the edges entering the first state, and no predecessor of the first
    state also points to the second state. With ``strict`` set, the dataflow
    of both states is additionally inspected, per connected component, for
    write-write and read-write hazards.

    :param graph: The state machine graph containing both states.
    :param candidate: Mapping from ``StateFusion._first_state`` and
                      ``StateFusion._second_state`` to node indices in
                      ``graph``.
    :param expr_index: Not used by this function.
    :param sdfg: Not used by this function.
    :param strict: If True, also run the data hazard checks.
    :return: True if fusion is allowed, False otherwise.
    """
    all_states = graph.nodes()
    first_state = all_states[candidate[StateFusion._first_state]]
    second_state = graph.nodes()[candidate[StateFusion._second_state]]

    leaving = graph.out_edges(first_state)
    entering = graph.in_edges(first_state)

    # Exactly one edge may leave the first state (towards the second state)
    if len(leaving) != 1:
        return False

    # ... and that edge must be unconditional
    if not leaving[0].data.is_unconditional():
        return False

    # Assignments on that edge have to be absorbed by the edges entering
    # the first state
    if leaving[0].data.assignments:
        if not entering:
            return False
        assigned = set(leaving[0].data.assignments.keys())
        # The symbol must not already be set before the first state
        for pred_edge in entering:
            if assigned & set(pred_edge.data.assignments.keys()):
                return False
        # The symbol must not be used in the dataflow of the first state
        if len(assigned & first_state.free_symbols) > 0:
            return False

    # A predecessor that points to both states would yield a multi-graph
    for pred, _, _ in entering:
        if any(succ == second_state for _, succ, _ in graph.out_edges(pred)):
            return False

    if not strict:
        return True

    # Extra incoming edges on the second state are only tolerated when
    # neither state has dataflow and the first state is not an initial
    # state (otherwise the new initial state would be ambiguous).
    first_preds = graph.in_edges(first_state)
    second_preds = graph.in_edges(second_state)
    dataflow_or_initial = (not second_state.is_empty()
                           or not first_state.is_empty()
                           or len(first_preds) == 0)
    if dataflow_or_initial and len(second_preds) != 1:
        return False

    # Connected components of both states
    first_cc = list(nx.weakly_connected_components(first_state._nx))
    second_cc = list(nx.weakly_connected_components(second_state._nx))

    # Source access nodes are "inputs"; all other access nodes "outputs"
    first_input = {
        a
        for a in sdutil.find_source_nodes(first_state)
        if isinstance(a, nodes.AccessNode)
    }
    first_output = {
        a
        for a in first_state.nodes()
        if isinstance(a, nodes.AccessNode) and a not in first_input
    }
    second_input = {
        a
        for a in sdutil.find_source_nodes(second_state)
        if isinstance(a, nodes.AccessNode)
    }
    second_output = {
        a
        for a in second_state.nodes()
        if isinstance(a, nodes.AccessNode) and a not in second_input
    }

    # The same node sets, split by connected component
    first_cc_input = [component.intersection(first_input)
                      for component in first_cc]
    first_cc_output = [component.intersection(first_output)
                       for component in first_cc]
    second_cc_input = [component.intersection(second_input)
                       for component in second_cc]
    second_cc_output = [component.intersection(second_output)
                        for component in second_cc]

    first_output_names = {a.data for a in first_output}
    second_input_names = {a.data for a in second_input}

    # A second-state input that appears more than once is not supported
    if len(second_input) > len(second_input_names):
        return False

    # A first-state output feeding the second state must live in a single
    # connected component
    for shared_name in first_output_names & second_input_names:
        holders = sum(1 for component in first_cc_output
                      if any(a.data == shared_name for a in component))
        if holders > 1:
            return False

    # Recreate fused connected component correspondences, and then
    # check for hazards
    resulting_ccs: List[CCDesc] = StateFusion.find_fused_components(
        first_cc_input, first_cc_output, second_cc_input, second_cc_output)

    for fused_cc in resulting_ccs:
        # Write-Write hazard - data is output of both first and second
        # states, without a read in between
        write_write_candidates = (
            (fused_cc.first_outputs & fused_cc.second_outputs) -
            fused_cc.second_inputs)

        if len(write_write_candidates) > 0:
            # Topologically-last instance of every array written by the
            # first state; those are the connection points upon fusion
            order = [
                x for x in reversed(
                    list(nx.topological_sort(first_state._nx)))
                if isinstance(x, nodes.AccessNode)
                and x.data in fused_cc.first_outputs
            ]
            match_nodes = {
                next(a for a in order if a.data == name)
                for name in (fused_cc.first_outputs
                             & fused_cc.second_inputs)
            }
        else:
            match_nodes = set()

        for cand in write_write_candidates:
            writers_first = [a for a in first_output if a.data == cand]
            writers_second = [a for a in second_output if a.data == cand]

            # No conflict if the candidate is ordered through a connection
            # point: first write -> match node, and the matching input of
            # the second state -> every second write
            fail = False
            path_found = False
            for match in match_nodes:
                for writer in writers_first:
                    if not nx.has_path(first_state._nx, writer, match):
                        continue
                    path_found = True
                    entry = next(a for a in second_input
                                 if a.data == match.data)
                    if not all(
                            nx.has_path(second_state._nx, entry, w)
                            for w in writers_second):
                        fail = True
                        break
                if fail or path_found:
                    break

            # Without such an ordering, fusion is only fine if the written
            # memlets do not intersect
            if fail or not path_found:
                if StateFusion.memlets_intersect(first_state, writers_first,
                                                 False, second_state,
                                                 writers_second, False):
                    return False
        # End of write-write hazard check

        # An input/output of this component that is written by a
        # *different* component of the second state is a potential
        # Read-Write or Write-Write race
        first_inout = fused_cc.first_inputs | fused_cc.first_outputs
        for other_cc in resulting_ccs:
            if other_cc is fused_cc:
                continue
            for d in first_inout:
                if d not in other_cc.second_outputs:
                    continue
                # Check for intersection (if None, fusion is ok)
                writers_second = [a for a in second_output if a.data == d]
                # Read-Write race
                if d in fused_cc.first_inputs:
                    readers_first = [a for a in first_input if a.data == d]
                    if StateFusion.memlets_intersect(
                            first_state, readers_first, True, second_state,
                            writers_second, False):
                        return False
                # Write-Write race
                if d in fused_cc.first_outputs:
                    writers_first = [a for a in first_output if a.data == d]
                    if StateFusion.memlets_intersect(
                            first_state, writers_first, False, second_state,
                            writers_second, False):
                        return False
        # End of data race check

    return True