def expressions():
    """ Returns the candidate state-machine patterns built from the
        DetectLoop guard/begin/exit pattern states, in this order:

        0. loop with one state: guard -> begin, guard -> exit, and the
           back-edge begin -> guard;
        1. loop with multiple states: guard -> begin and guard -> exit only
           (no back-edge from the begin state).
    """
    guard = DetectLoop._loop_guard
    begin = DetectLoop._loop_begin
    exit_state = DetectLoop._exit_state

    patterns = []
    for with_back_edge in (True, False):
        pattern = sd.SDFG('_')
        pattern.add_nodes_from([guard, begin, exit_state])
        pattern.add_edge(guard, begin, edges.InterstateEdge())
        pattern.add_edge(guard, exit_state, edges.InterstateEdge())
        if with_back_edge:
            # Single-state loop: the body state jumps straight back to guard
            pattern.add_edge(begin, guard, edges.InterstateEdge())
        patterns.append(pattern)
    return patterns
def create_states_simple(pdp,
                         out_sdfg,
                         start_state=None,
                         end_state=None,
                         start_edge=None):
    """ Creates a state per primitive, with the knowledge that they can be
        optimized later.
        @param pdp: A parsed dace program.
        @param out_sdfg: The output SDFG.
        @param start_state: The starting/parent state to connect from (for
                            recursive calls).
        @param end_state: The end/parent state to connect to (for recursive
                          calls).
        @param start_edge: The interstate edge leaving ``start_state`` towards
                           the first child state (for recursive calls);
                           defaults to an unconditional edge.
        @return: A dictionary mapping between a state and the list of dace
                 primitives included in it.
    """
    state_to_primitives = OrderedDict()

    # Create starting state and edge
    if start_state is None:
        start_state = out_sdfg.add_state('start')
        state_to_primitives[start_state] = []
    if start_edge is None:
        start_edge = ed.InterstateEdge()

    previous_state = start_state
    previous_edge = start_edge

    for i, primitive in enumerate(pdp.children):
        state = out_sdfg.add_state(primitive.name)
        state_to_primitives[state] = []
        # Edge that can be created on entry to control flow children
        entry_edge = None

        #########################################
        # Cases depending on primitive type
        #########################################

        # Nothing special happens with a dataflow node (nested states are
        # handled with a separate call to create_states_simple)
        if isinstance(primitive, astnodes._DataFlowNode):
            out_sdfg.add_edge(previous_state, state, previous_edge)
            state_to_primitives[state] = [primitive]
            previous_state = state
            previous_edge = ed.InterstateEdge()

        # Control flow needs to traverse into children nodes
        elif isinstance(primitive, astnodes._ControlFlowNode):
            # Iteration has >=3 states - begin, loop[...], end; and connects
            # the loop states, as well as the begin to end directly if the
            # condition did not evaluate to true
            if isinstance(primitive, astnodes._IterateNode):
                loop_var = primitive.params[0]
                loop_begin = primitive.range[0][0]
                loop_end = primitive.range[0][1]
                loop_step = primitive.range[0][2]

                condition = ast.parse(
                    _iterate_condition_source(loop_var, loop_end,
                                              loop_step)).body[0]
                condition_neg = astutils.negate_expr(condition)

                # Loop-start state
                lstart_state = out_sdfg.add_state(primitive.name + '_start')
                state_to_primitives[lstart_state] = []
                out_sdfg.add_edge(previous_state, lstart_state, previous_edge)
                out_sdfg.add_edge(
                    lstart_state, state,
                    ed.InterstateEdge(assignments={loop_var: loop_begin}))

                # Loop-end state that jumps back to `state`
                loop_state = out_sdfg.add_state(primitive.name + '_end')
                state_to_primitives[loop_state] = []
                # Connect loop (increment the loop variable by the step)
                out_sdfg.add_edge(
                    loop_state, state,
                    ed.InterstateEdge(
                        assignments={
                            loop_var:
                            symbolic.pystr_to_symbolic(loop_var) + loop_step
                        }))

                # End connection
                previous_state = state
                previous_edge = ed.InterstateEdge(condition=condition_neg)

                # Create children states
                cmap = create_states_simple(
                    primitive, out_sdfg, state, loop_state,
                    ed.InterstateEdge(condition=condition))
                state_to_primitives.update(cmap)

            # Loop is similar to iterate, but more general w.r.t. conditions
            elif isinstance(primitive, astnodes._LoopNode):
                loop_condition = primitive.condition

                # Entry
                out_sdfg.add_edge(previous_state, state, previous_edge)

                # Loop-end state that jumps back to `state`
                loop_state = out_sdfg.add_state(primitive.name + '_end')
                state_to_primitives[loop_state] = []
                # Loopback
                out_sdfg.add_edge(loop_state, state, ed.InterstateEdge())
                # End connection
                previous_state = state
                previous_edge = ed.InterstateEdge(
                    condition=astutils.negate_expr(loop_condition))
                entry_edge = ed.InterstateEdge(condition=loop_condition)

                # Create children states
                cmap = create_states_simple(primitive, out_sdfg, state,
                                            loop_state, entry_edge)
                state_to_primitives.update(cmap)

            elif isinstance(primitive, astnodes._IfNode):
                if_condition = primitive.condition
                # Check if we have an else node, otherwise add a skip
                # condition ourselves
                if (i + 1) < len(pdp.children) and isinstance(
                        pdp.children[i + 1], astnodes._ElseNode):
                    has_else = True
                    else_prim = pdp.children[i + 1]
                    else_condition = else_prim.condition
                else:
                    has_else = False
                    else_condition = astutils.negate_expr(primitive.condition)

                # End-of-branch state (converge to this)
                bend_state = out_sdfg.add_state(primitive.name + '_end')
                state_to_primitives[bend_state] = []

                # Entry
                out_sdfg.add_edge(previous_state, state, previous_edge)

                # Create children states
                cmap = create_states_simple(
                    primitive, out_sdfg, state, bend_state,
                    ed.InterstateEdge(condition=if_condition))
                state_to_primitives.update(cmap)

                # Handle 'else' condition
                if not has_else:
                    out_sdfg.add_edge(
                        state, bend_state,
                        ed.InterstateEdge(condition=else_condition))
                else:
                    # Recursively parse 'else' primitive's children
                    cmap = create_states_simple(
                        else_prim, out_sdfg, state, bend_state,
                        ed.InterstateEdge(condition=else_condition))
                    state_to_primitives.update(cmap)

                # Exit
                previous_state = bend_state
                previous_edge = ed.InterstateEdge()

            elif isinstance(primitive, astnodes._ElseNode):
                if i - 1 < 0 or not isinstance(pdp.children[i - 1],
                                               astnodes._IfNode):
                    raise SyntaxError('Found else state without matching if')

                # If 'else' state is correct, we already processed it
                del state_to_primitives[state]
                out_sdfg.remove_node(state)

    # Connect to end_state (and create it if necessary)
    if end_state is None:
        end_state = out_sdfg.add_state('end')
        state_to_primitives[end_state] = []
    out_sdfg.add_edge(previous_state, end_state, previous_edge)

    return state_to_primitives


def _iterate_condition_source(param, end, step):
    """ Returns the source text of an iterate loop's continuation test.

        The range end is inclusive (the ascending test is
        ``param < end + 1``), so the descending test mirrors it as
        ``param > end - 1``. Comparing against ``end + 1`` for a negative
        step would skip the last two iterations.
        @param param: Name of the loop variable.
        @param end: Inclusive last value of the range.
        @param step: Range step; its sign selects the comparison.
        @return: A parenthesized Python expression string.
    """
    if step >= 0:
        return '(%s < %s)' % (param, end + 1)
    return '(%s > %s)' % (param, end - 1)
class StateFusion(pattern_matching.Transformation):
    """ Implements the state-fusion transformation.

        State-fusion takes two states that are connected through a single
        edge, and fuses them into one state. If strict, only applies if no
        memory access hazards are created.
    """

    # Counter of applied fusions (incremented only when debugprint is on)
    _states_fused = 0
    _first_state = sdfg.SDFGState()
    _edge = edges.InterstateEdge()
    _second_state = sdfg.SDFGState()

    @staticmethod
    def annotates_memlets():
        return False

    @staticmethod
    def expressions():
        return [
            nxutil.node_path_graph(StateFusion._first_state,
                                   StateFusion._second_state)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Returns True if the matched pair of states can be fused. """
        first_state = graph.nodes()[candidate[StateFusion._first_state]]
        second_state = graph.nodes()[candidate[StateFusion._second_state]]

        out_edges = graph.out_edges(first_state)
        in_edges = graph.in_edges(first_state)

        # First state must have only one output edge (with dst the second
        # state).
        if len(out_edges) != 1:
            return False
        # The interstate edge must not have a condition.
        if not out_edges[0].data.is_unconditional():
            return False
        # The interstate edge may have assignments, as long as there are input
        # edges to the first state, that can absorb them.
        if out_edges[0].data.assignments and not in_edges:
            return False

        # There can be no state that have output edges pointing to both the
        # first and the second state. Such a case will produce a multi-graph.
        for src, _, _ in in_edges:
            for _, dst, _ in graph.out_edges(src):
                if dst == second_state:
                    return False

        if strict:
            # If second state has other input edges, there might be issues
            # Exceptions are when none of the states contain dataflow, unless
            # the first state is an initial state (in which case the new
            # initial state would be ambiguous).
            first_in_edges = graph.in_edges(first_state)
            second_in_edges = graph.in_edges(second_state)
            if ((not second_state.is_empty() or not first_state.is_empty()
                 or len(first_in_edges) == 0) and len(second_in_edges) != 1):
                return False

            # Get connected components.
            first_cc = list(nx.weakly_connected_components(first_state._nx))

            # Find source/sink (data) nodes
            first_input = {
                node
                for node in nxutil.find_source_nodes(first_state)
                if isinstance(node, nodes.AccessNode)
            }
            first_output = {
                node
                for node in first_state.nodes()
                if isinstance(node, nodes.AccessNode)
                and node not in first_input
            }
            second_input = {
                node
                for node in nxutil.find_source_nodes(second_state)
                if isinstance(node, nodes.AccessNode)
            }
            second_output = {
                node
                for node in second_state.nodes()
                if isinstance(node, nodes.AccessNode)
                and node not in second_input
            }
            second_input_labels = {x.label for x in second_input}
            second_output_labels = {x.label for x in second_output}

            # Find sink (data) nodes by connected component
            first_cc_output = [
                cc.intersection(first_output) for cc in first_cc
            ]

            # Apply transformation in case all paths to the second state's
            # nodes go through the same access node, which implies sequential
            # behavior in SDFG semantics.
            check_strict = len(first_cc)
            first_sinks = first_state.sink_nodes()
            for cc_output in first_cc_output:
                out_nodes = [n for n in first_sinks if n in cc_output]
                # Branching exists, multiple paths may involve same access
                # node potentially causing data races
                if len(out_nodes) > 1:
                    continue

                # Otherwise, check if any of the second state's connected
                # components has a matching input
                if any(node.label in second_input_labels
                       for node in cc_output):
                    check_strict -= 1

            if check_strict > 0:
                # Check strict conditions
                # RW dependency
                if any(node.label in second_output_labels
                       for node in first_input):
                    return False
                # WW dependency
                if any(node.label in second_output_labels
                       for node in first_output):
                    return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        first_state = graph.nodes()[candidate[StateFusion._first_state]]
        second_state = graph.nodes()[candidate[StateFusion._second_state]]

        return " -> ".join(state.label
                           for state in [first_state, second_state])

    def apply(self, sdfg):
        """ Fuses the second matched state into the first one. """
        first_state = sdfg.nodes()[self.subgraph[StateFusion._first_state]]
        second_state = sdfg.nodes()[self.subgraph[StateFusion._second_state]]

        # Remove interstate edge(s); their assignments move onto the first
        # state's incoming edges.
        fused_edges = list(sdfg.edges_between(first_state, second_state))
        for edge in fused_edges:
            if edge.data.assignments:
                for src, dst, other_data in sdfg.in_edges(first_state):
                    other_data.assignments.update(edge.data.assignments)
            sdfg.remove_edge(edge)

        # Special case 1: first state is empty
        if first_state.is_empty():
            nxutil.change_edge_dest(sdfg, first_state, second_state)
            sdfg.remove_node(first_state)
            return

        # Special case 2: second state is empty
        if second_state.is_empty():
            nxutil.change_edge_src(sdfg, second_state, first_state)
            nxutil.change_edge_dest(sdfg, second_state, first_state)
            sdfg.remove_node(second_state)
            return

        # Normal case: both states are not empty

        # Find source/sink (data) nodes
        first_input = [
            node for node in nxutil.find_source_nodes(first_state)
            if isinstance(node, nodes.AccessNode)
        ]
        first_output = [
            node for node in nxutil.find_sink_nodes(first_state)
            if isinstance(node, nodes.AccessNode)
        ]
        second_input = [
            node for node in nxutil.find_source_nodes(second_state)
            if isinstance(node, nodes.AccessNode)
        ]

        # first input = first input - first output
        first_output_labels = {x.label for x in first_output}
        first_input = [
            node for node in first_input
            if node.label not in first_output_labels
        ]

        # Merge second state to first state
        # First keep a backup of the topological sorted order of the nodes
        order = [
            x for x in reversed(list(nx.topological_sort(first_state._nx)))
            if isinstance(x, nodes.AccessNode)
        ]
        for node in second_state.nodes():
            first_state.add_node(node)
        for src, src_conn, dst, dst_conn, data in second_state.edges():
            first_state.add_edge(src, src_conn, dst, dst_conn, data)

        # Merge common (data) nodes
        for node in first_input:
            old_node = next(
                (x for x in second_input if x.label == node.label), None)
            if old_node is None:
                continue
            nxutil.change_edge_src(first_state, old_node, node)
            first_state.remove_node(old_node)
            second_input.remove(old_node)
            node.access = dtypes.AccessType.ReadWrite

        for node in first_output:
            new_node = next(
                (x for x in second_input if x.label == node.label), None)
            if new_node is None:
                continue
            nxutil.change_edge_dest(first_state, node, new_node)
            first_state.remove_node(node)
            second_input.remove(new_node)
            new_node.access = dtypes.AccessType.ReadWrite
            # `node` was just removed from the fused state; make later
            # lookups in `order` resolve to its replacement instead of
            # re-adding the removed node as an orphan.
            order = [new_node if x is node else x for x in order]

        # Check if any input nodes of the second state have to be merged with
        # non-input/output nodes of the first state.
        for node in second_input:
            if first_state.in_degree(node) == 0:
                n = next((x for x in order if x.label == node.label), None)
                if n is not None:
                    nxutil.change_edge_src(first_state, node, n)
                    first_state.remove_node(node)
                    n.access = dtypes.AccessType.ReadWrite

        # Redirect edges and remove second state
        nxutil.change_edge_src(sdfg, second_state, first_state)
        sdfg.remove_node(second_state)
        if Config.get_bool("debugprint"):
            StateFusion._states_fused += 1
def apply(self, sdfg):
    """ Replaces the matched map with a nested SDFG that executes the map
        body inside a sequential loop.

        Only the map's first parameter and range are used. The nested SDFG
        has the states ``begin -> guard``, ``guard -> body`` (while
        ``idx <= to``), ``body -> guard`` (``idx += step``) and
        ``guard -> end`` (when ``idx > to``). Each map input/output is routed
        through a new ``_in_<data>`` / ``_out_<data>`` access node and
        nested-SDFG connector.
    """
    # Retrieve map entry and exit nodes.
    graph = sdfg.nodes()[self.state_id]
    map_entry = graph.nodes()[self.subgraph[MapToForLoop._map_entry]]
    map_exits = graph.exit_nodes(map_entry)

    loop_idx = map_entry.map.params[0]
    loop_from, loop_to, loop_step = map_entry.map.range[0]

    nested_sdfg = dace.SDFG(graph.label + '_' + map_entry.map.label)

    # Construct nested SDFG
    begin = nested_sdfg.add_state('begin')
    guard = nested_sdfg.add_state('guard')
    body = nested_sdfg.add_state('body')
    end = nested_sdfg.add_state('end')

    nested_sdfg.add_edge(
        begin, guard,
        edges.InterstateEdge(assignments={str(loop_idx): str(loop_from)}))
    nested_sdfg.add_edge(
        guard, body,
        edges.InterstateEdge(condition=str(loop_idx) + ' <= ' +
                             str(loop_to)))
    nested_sdfg.add_edge(
        guard, end,
        edges.InterstateEdge(condition=str(loop_idx) + ' > ' +
                             str(loop_to)))
    nested_sdfg.add_edge(
        body, guard,
        edges.InterstateEdge(assignments={
            str(loop_idx): str(loop_idx) + ' + ' + str(loop_step)
        }))

    # Add map contents
    map_subgraph = graph.scope_subgraph(map_entry)
    for node in map_subgraph.nodes():
        if node is not map_entry and node not in map_exits:
            body.add_node(node)
    for src, src_conn, dst, dst_conn, memlet in map_subgraph.edges():
        if src is not map_entry and dst not in map_exits:
            body.add_edge(src, src_conn, dst, dst_conn, memlet)

    def add_transfer_array(data_label, memlet):
        # Registers the descriptor for a new transfer node, shaped after the
        # memlet's bounding box.
        # NOTE(review): this registers on the outer `sdfg` although the
        # access node lives in the nested SDFG's `body` state; confirm.
        memdata = sdfg.arrays[memlet.data]
        if isinstance(memdata, data.Array):
            sdfg.add_array(data_label, memdata.dtype, [
                symbolic.overapproximate(r)
                for r in memlet.bounding_box_size()
            ])
        elif isinstance(memdata, data.Scalar):
            sdfg.add_scalar(data_label, memdata.dtype)
        else:
            raise NotImplementedError()

    def rename_body_memlets(old_name, new_name):
        # Points the body's memlets at the transfer container.
        for _, _, _, _, old_memlet in body.edges():
            if old_memlet.data == old_name:
                old_memlet.data = new_name

    # Reconnect inputs. The map entry's 'IN_<x>' connector feeds its
    # 'OUT_<x>' connector, so index in-edges by connector suffix rather than
    # assuming the n-th edge is 'IN_<n+1>'.
    in_edges = list(graph.in_edges(map_entry))
    nested_in_data_nodes = {}
    nested_in_connectors = {}
    nested_in_memlets = {}
    in_conn_index = {}
    for i, (src, src_conn, dst, dst_conn, memlet) in enumerate(in_edges):
        if dst_conn is not None:
            in_conn_index[dst_conn[len('IN_'):]] = i
        data_label = '_in_' + memlet.data
        add_transfer_array(data_label, memlet)
        data_node = nodes.AccessNode(data_label)
        body.add_node(data_node)
        nested_in_data_nodes[i] = data_node
        nested_in_connectors[i] = data_label
        nested_in_memlets[i] = memlet
        rename_body_memlets(memlet.data, data_label)

    # Reconnect outputs (keyed per exit so several exits do not collide)
    exit_out_edges = [list(graph.out_edges(mx)) for mx in map_exits]
    nested_out_data_nodes = {}
    nested_out_connectors = {}
    nested_out_memlets = {}
    out_conn_index = {}
    for j, out_edges in enumerate(exit_out_edges):
        for i, (src, src_conn, dst, dst_conn,
                memlet) in enumerate(out_edges):
            if src_conn is not None:
                out_conn_index[(j, src_conn[len('OUT_'):])] = i
            data_label = '_out_' + memlet.data
            add_transfer_array(data_label, memlet)
            data_node = nodes.AccessNode(data_label)
            body.add_node(data_node)
            nested_out_data_nodes[(j, i)] = data_node
            nested_out_connectors[(j, i)] = data_label
            nested_out_memlets[(j, i)] = memlet
            # NOTE(review): if an array is both read and written, the input
            # pass already renamed its body memlets to '_in_<name>', so
            # nothing matches here; confirm in-place maps are supported.
            rename_body_memlets(memlet.data, data_label)

    # Add nested SDFG and reconnect it
    nested_node = graph.add_nested_sdfg(
        nested_sdfg, sdfg, set(nested_in_connectors.values()),
        set(nested_out_connectors.values()))

    for i, (src, src_conn, _, _, _) in enumerate(in_edges):
        graph.add_edge(src, src_conn, nested_node, nested_in_connectors[i],
                       nested_in_memlets[i])
    for j, out_edges in enumerate(exit_out_edges):
        for i, (_, _, dst, dst_conn, _) in enumerate(out_edges):
            graph.add_edge(nested_node, nested_out_connectors[(j, i)], dst,
                           dst_conn, nested_out_memlets[(j, i)])

    # Connect the transfer nodes to the body's former map neighbours
    for src, src_conn, dst, dst_conn, memlet in graph.out_edges(map_entry):
        if src_conn is None:
            # Empty (ordering-only) memlet: no data to route
            continue
        i = in_conn_index[src_conn[len('OUT_'):]]
        new_memlet = dcpy(memlet)
        new_memlet.data = nested_in_data_nodes[i].data
        body.add_edge(nested_in_data_nodes[i], None, dst, dst_conn,
                      new_memlet)
    for j, map_exit in enumerate(map_exits):
        for src, src_conn, dst, dst_conn, memlet in graph.in_edges(map_exit):
            if dst_conn is None:
                # Empty (ordering-only) memlet: no data to route
                continue
            i = out_conn_index[(j, dst_conn[len('IN_'):])]
            new_memlet = dcpy(memlet)
            new_memlet.data = nested_out_data_nodes[(j, i)].data
            body.add_edge(src, src_conn, nested_out_data_nodes[(j, i)], None,
                          new_memlet)

    # Materialize the node list before mutating the graph it views
    for node in list(map_subgraph.nodes()):
        graph.remove_node(node)
def apply(self, sdfg):
    """ Offloads the matched state to FPGA global memory.

        Array source nodes of the state, and access nodes that receive
        write-conflict-resolved (WCR) writes (traced through nested SDFGs),
        are copied into transient ``fpga_<name>`` arrays in a new
        ``pre_<label>`` state. Array sink nodes are copied back from their
        ``fpga_<name>`` counterparts in a new ``post_<label>`` state. Memlets
        in the state are then renamed to the FPGA copies.
    """
    state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

    # Find source/sink (data) nodes
    input_nodes = nxutil.find_source_nodes(state)
    output_nodes = nxutil.find_sink_nodes(state)

    # Maps original data names to their FPGA descriptors
    fpga_data = {}

    # Input nodes may also be nodes with WCR memlets
    # We have to recur across nested SDFGs to find them
    wcr_input_nodes = set()
    parent_sdfg = {state: sdfg}  # Map states to their parent SDFG
    for node, graph in state.all_nodes_recursive():
        if isinstance(graph, dace.SDFG):
            parent_sdfg[node] = graph
        if isinstance(node, dace.graph.nodes.AccessNode):
            for e in graph.all_edges(node):
                if e.data.wcr is not None:
                    trace = dace.sdfg.trace_nested_access(
                        node, graph, parent_sdfg[graph])
                    for node_trace, state_trace, sdfg_trace in trace:
                        # Find the accessed node in our scope
                        if state_trace == state and sdfg_trace == sdfg:
                            outer_node = node_trace
                            break
                    else:
                        # This does not trace back to the current state, so
                        # we don't care
                        continue
                    # Record the access node itself (not its data name): the
                    # copy-in loop below skips non-AccessNode entries and
                    # tests membership in wcr_input_nodes by node.
                    if outer_node not in wcr_input_nodes:
                        input_nodes.append(outer_node)
                        wcr_input_nodes.add(outer_node)

    if input_nodes:
        # create pre_state
        pre_state = sd.SDFGState('pre_' + state.label, sdfg)

        for node in input_nodes:
            if not isinstance(node, dace.graph.nodes.AccessNode):
                continue
            desc = node.desc(sdfg)
            if not isinstance(desc, dace.data.Array):
                # TODO: handle streams
                continue

            if node.data in fpga_data:
                fpga_array = fpga_data[node.data]
            elif node not in wcr_input_nodes:
                fpga_array = sdfg.add_array(
                    'fpga_' + node.data,
                    desc.shape,
                    desc.dtype,
                    materialize_func=desc.materialize_func,
                    transient=True,
                    storage=dtypes.StorageType.FPGA_Global,
                    allow_conflicts=desc.allow_conflicts,
                    strides=desc.strides,
                    offset=desc.offset)
                fpga_data[node.data] = fpga_array

            pre_node = pre_state.add_read(node.data)
            pre_fpga_node = pre_state.add_write('fpga_' + node.data)
            full_range = subsets.Range([(0, s - 1, 1) for s in desc.shape])
            mem = memlet.Memlet(node.data, full_range.num_elements(),
                                full_range, 1)
            pre_state.add_edge(pre_node, None, pre_fpga_node, None, mem)

            if node not in wcr_input_nodes:
                fpga_node = state.add_read('fpga_' + node.data)
                nxutil.change_edge_src(state, node, fpga_node)
                state.remove_node(node)

        sdfg.add_node(pre_state)
        nxutil.change_edge_dest(sdfg, state, pre_state)
        sdfg.add_edge(pre_state, state, edges.InterstateEdge())

    if output_nodes:
        post_state = sd.SDFGState('post_' + state.label, sdfg)

        for node in output_nodes:
            if not isinstance(node, dace.graph.nodes.AccessNode):
                continue
            desc = node.desc(sdfg)
            if not isinstance(desc, dace.data.Array):
                # TODO: handle streams
                continue

            if node.data in fpga_data:
                fpga_array = fpga_data[node.data]
            else:
                fpga_array = sdfg.add_array(
                    'fpga_' + node.data,
                    desc.shape,
                    desc.dtype,
                    materialize_func=desc.materialize_func,
                    transient=True,
                    storage=dtypes.StorageType.FPGA_Global,
                    allow_conflicts=desc.allow_conflicts,
                    strides=desc.strides,
                    offset=desc.offset)
                fpga_data[node.data] = fpga_array

            post_node = post_state.add_write(node.data)
            post_fpga_node = post_state.add_read('fpga_' + node.data)
            full_range = subsets.Range([(0, s - 1, 1) for s in desc.shape])
            mem = memlet.Memlet('fpga_' + node.data,
                                full_range.num_elements(), full_range, 1)
            post_state.add_edge(post_fpga_node, None, post_node, None, mem)

            fpga_node = state.add_write('fpga_' + node.data)
            nxutil.change_edge_dest(state, node, fpga_node)
            state.remove_node(node)

        sdfg.add_node(post_state)
        nxutil.change_edge_src(sdfg, state, post_state)
        sdfg.add_edge(state, post_state, edges.InterstateEdge())

    # NOTE(review): veclen_ is not reset per edge, so a vector length found
    # on a nested-SDFG edge also applies to every later FPGA-data edge in
    # iteration order; confirm this propagation is intended.
    veclen_ = 1
    # propagate vector info from a nested sdfg
    for src, src_conn, dst, dst_conn, mem in state.edges():
        # need to go inside the nested SDFG and grab the vector length
        if isinstance(dst, dace.graph.nodes.NestedSDFG):
            # this edge is going to the nested SDFG
            for inner_state in dst.sdfg.states():
                for n in inner_state.nodes():
                    if (isinstance(n, dace.graph.nodes.AccessNode)
                            and n.data == dst_conn):
                        # assuming all memlets have the same vector length
                        veclen_ = inner_state.all_edges(n)[0].data.veclen
        if isinstance(src, dace.graph.nodes.NestedSDFG):
            # this edge is coming from the nested SDFG
            for inner_state in src.sdfg.states():
                for n in inner_state.nodes():
                    if (isinstance(n, dace.graph.nodes.AccessNode)
                            and n.data == src_conn):
                        # assuming all memlets have the same vector length
                        veclen_ = inner_state.all_edges(n)[0].data.veclen

        if mem.data is not None and mem.data in fpga_data:
            mem.data = 'fpga_' + mem.data
            mem.veclen = veclen_

    fpga_update(sdfg, state, 0)
def apply(self, sdfg: sd.SDFG):
    """ Moves the SDFG's computation to the GPU.

        Non-transient data that is read or written (including data
        accumulated via WCR without an identity) is cloned into transient
        ``gpu_<name>`` GPU_Global arrays, with copy-in/copy-out states added
        around the SDFG. Top-level transients become GPU_Global, free code
        nodes are wrapped in single-iteration GPU_Device maps, top-level maps
        are scheduled as GPU_Device, and (if ``strict_transform``) strict
        StateFusion/RedundantArray matches are applied greedily.
    """
    #######################################################
    # Step 0: SDFG metadata

    # Find all input and output data descriptors, as (name, descriptor)
    # pairs. Membership is tracked by name in separate sets, since a name
    # never compares equal to a (name, descriptor) tuple.
    input_nodes = []
    output_nodes = []
    input_names = set()
    output_names = set()
    global_code_nodes = [[] for _ in sdfg.nodes()]

    for i, state in enumerate(sdfg.nodes()):
        sdict = state.scope_dict()
        for node in state.nodes():
            if (isinstance(node, nodes.AccessNode)
                    and not node.desc(sdfg).transient):
                if (state.out_degree(node) > 0
                        and node.data not in input_names):
                    input_names.add(node.data)
                    input_nodes.append((node.data, node.desc(sdfg)))
                if (state.in_degree(node) > 0
                        and node.data not in output_names):
                    output_names.add(node.data)
                    output_nodes.append((node.data, node.desc(sdfg)))
            elif isinstance(node, nodes.CodeNode) and sdict[node] is None:
                if not isinstance(node, nodes.EmptyTasklet):
                    global_code_nodes[i].append(node)

        # Input nodes may also be nodes with WCR memlets and no identity
        for e in state.edges():
            if e.data.wcr is not None and e.data.wcr_identity is None:
                wcr_desc = sdfg.arrays[e.data.data]
                if (e.data.data not in input_names
                        and not wcr_desc.transient):
                    input_names.add(e.data.data)
                    input_nodes.append((e.data.data, wcr_desc))

    start_state = sdfg.start_state
    end_states = sdfg.sink_nodes()

    #######################################################
    # Step 1: Create cloned GPU arrays and replace originals

    cloned_arrays = {}
    for inodename, inode in input_nodes:
        newdesc = inode.clone()
        newdesc.storage = types.StorageType.GPU_Global
        newdesc.transient = True
        sdfg.add_datadesc('gpu_' + inodename, newdesc)
        cloned_arrays[inodename] = 'gpu_' + inodename

    for onodename, onode in output_nodes:
        if onodename in cloned_arrays:
            continue
        newdesc = onode.clone()
        newdesc.storage = types.StorageType.GPU_Global
        newdesc.transient = True
        sdfg.add_datadesc('gpu_' + onodename, newdesc)
        cloned_arrays[onodename] = 'gpu_' + onodename

    # Replace nodes
    for state in sdfg.nodes():
        for node in state.nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.data in cloned_arrays):
                node.data = cloned_arrays[node.data]

    # Replace memlets
    for state in sdfg.nodes():
        for edge in state.edges():
            if edge.data.data in cloned_arrays:
                edge.data.data = cloned_arrays[edge.data.data]

    #######################################################
    # Step 2: Create copy-in state

    copyin_state = sdfg.add_state(sdfg.label + '_copyin')
    sdfg.add_edge(copyin_state, start_state, ed.InterstateEdge())

    for nname, desc in input_nodes:
        src_array = nodes.AccessNode(nname, debuginfo=desc.debuginfo)
        dst_array = nodes.AccessNode(
            cloned_arrays[nname], debuginfo=desc.debuginfo)
        copyin_state.add_node(src_array)
        copyin_state.add_node(dst_array)
        copyin_state.add_nedge(
            src_array, dst_array,
            memlet.Memlet.from_array(src_array.data, src_array.desc(sdfg)))

    #######################################################
    # Step 3: Create copy-out state

    copyout_state = sdfg.add_state(sdfg.label + '_copyout')
    for state in end_states:
        sdfg.add_edge(state, copyout_state, ed.InterstateEdge())

    for nname, desc in output_nodes:
        src_array = nodes.AccessNode(
            cloned_arrays[nname], debuginfo=desc.debuginfo)
        dst_array = nodes.AccessNode(nname, debuginfo=desc.debuginfo)
        copyout_state.add_node(src_array)
        copyout_state.add_node(dst_array)
        copyout_state.add_nedge(
            src_array, dst_array,
            memlet.Memlet.from_array(dst_array.data, dst_array.desc(sdfg)))

    #######################################################
    # Step 4: Modify transient data storage

    for state in sdfg.nodes():
        sdict = state.scope_dict()
        for node in state.nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.desc(sdfg).transient):
                nodedesc = node.desc(sdfg)
                if sdict[node] is None:
                    # NOTE: the cloned arrays match too but it's the same
                    # storage so we don't care
                    nodedesc.storage = types.StorageType.GPU_Global

                    # Try to move allocation/deallocation out of loops
                    if self.toplevel_trans:
                        nodedesc.toplevel = True
                else:
                    # Make internal transients registers
                    if self.register_trans:
                        nodedesc.storage = types.StorageType.Register

    #######################################################
    # Step 5: Wrap free tasklets and nested SDFGs with a GPU map

    for state, gcodes in zip(sdfg.nodes(), global_code_nodes):
        for gcode in gcodes:
            # Create map and connectors
            me, mx = state.add_map(
                gcode.label + '_gmap', {gcode.label + '__gmapi': '0:1'},
                schedule=types.ScheduleType.GPU_Device)
            # Store in/out edges in lists so that they don't get corrupted
            # when they are removed from the graph
            in_edges = list(state.in_edges(gcode))
            out_edges = list(state.out_edges(gcode))
            me.in_connectors = set('IN_' + e.dst_conn for e in in_edges)
            me.out_connectors = set('OUT_' + e.dst_conn for e in in_edges)
            mx.in_connectors = set('IN_' + e.src_conn for e in out_edges)
            mx.out_connectors = set('OUT_' + e.src_conn for e in out_edges)

            # Create memlets through map
            for e in in_edges:
                state.remove_edge(e)
                state.add_edge(e.src, e.src_conn, me, 'IN_' + e.dst_conn,
                               e.data)
                state.add_edge(me, 'OUT_' + e.dst_conn, e.dst, e.dst_conn,
                               e.data)
            for e in out_edges:
                state.remove_edge(e)
                state.add_edge(e.src, e.src_conn, mx, 'IN_' + e.src_conn,
                               e.data)
                state.add_edge(mx, 'OUT_' + e.src_conn, e.dst, e.dst_conn,
                               e.data)

            # Map without inputs
            if len(in_edges) == 0:
                state.add_nedge(me, gcode, memlet.EmptyMemlet())

    #######################################################
    # Step 6: Change all top-level maps to GPU maps

    for state in sdfg.nodes():
        sdict = state.scope_dict()
        for node in state.nodes():
            if isinstance(node, nodes.EntryNode):
                if sdict[node] is None:
                    node.schedule = types.ScheduleType.GPU_Device
                elif self.sequential_innermaps:
                    node.schedule = types.ScheduleType.Sequential

    #######################################################
    # Step 7: Strict transformations
    if not self.strict_transform:
        return

    # Apply strict state fusions greedily.
    opt = optimizer.SDFGOptimizer(sdfg, inplace=True)

    def strict_matches():
        # Strict-mode matches of the two transformations applied here
        return [
            match for match in opt.get_pattern_matches(strict=True)
            if isinstance(match, (StateFusion, RedundantArray))
        ]

    fusions = 0
    arrays = 0
    options = strict_matches()
    while options:
        ssdfg = sdfg.sdfg_list[options[0].sdfg_id]
        options[0].apply(ssdfg)
        ssdfg.validate()
        if isinstance(options[0], StateFusion):
            fusions += 1
        if isinstance(options[0], RedundantArray):
            arrays += 1

        options = strict_matches()

    if Config.get_bool('debugprint') and (fusions > 0 or arrays > 0):
        print('Automatically applied {} strict state fusions and removed'
              ' {} redundant arrays.'.format(fusions, arrays))
def apply(self, sdfg):
    """ Offloads the matched state to FPGA global memory.

        Array source nodes are moved into a new ``pre_<label>`` state where
        they are copied into transient ``fpga_<name>`` arrays. Array sink
        nodes are moved into a new ``post_<label>`` state that copies back
        from the FPGA arrays. Memlets in the state referring to transferred
        data are renamed to their ``fpga_`` counterparts.
    """
    state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

    # Find source/sink (data) nodes
    input_nodes = nxutil.find_source_nodes(state)
    output_nodes = nxutil.find_sink_nodes(state)

    # Maps original data names (node.data) to the FPGA array descriptors
    fpga_data = {}

    if input_nodes:
        pre_state = sd.SDFGState('pre_' + state.label, sdfg)

        for node in input_nodes:
            if (not isinstance(node, dace.graph.nodes.AccessNode)
                    or not isinstance(node.desc(sdfg), dace.data.Array)):
                # Only transfer array nodes
                # TODO: handle streams
                continue

            array = node.desc(sdfg)

            # Keyed by data name, consistently with the output loop below
            if node.data in fpga_data:
                fpga_array = fpga_data[node.data]
            else:
                fpga_array = sdfg.add_array(
                    'fpga_' + node.data,
                    array.dtype,
                    array.shape,
                    materialize_func=array.materialize_func,
                    transient=True,
                    storage=types.StorageType.FPGA_Global,
                    allow_conflicts=array.allow_conflicts,
                    access_order=array.access_order,
                    strides=array.strides,
                    offset=array.offset)
                fpga_data[node.data] = fpga_array
            fpga_node = type(node)(fpga_array)

            pre_state.add_node(node)
            pre_state.add_node(fpga_node)
            full_range = subsets.Range([(0, s - 1, 1) for s in array.shape])
            mem = memlet.Memlet(array, full_range.num_elements(), full_range,
                                1)
            pre_state.add_edge(node, None, fpga_node, None, mem)

            state.add_node(fpga_node)
            nxutil.change_edge_src(state, node, fpga_node)
            state.remove_node(node)

        sdfg.add_node(pre_state)
        nxutil.change_edge_dest(sdfg, state, pre_state)
        sdfg.add_edge(pre_state, state, edges.InterstateEdge())

    if output_nodes:
        post_state = sd.SDFGState('post_' + state.label, sdfg)

        for node in output_nodes:
            if (not isinstance(node, dace.graph.nodes.AccessNode)
                    or not isinstance(node.desc(sdfg), dace.data.Array)):
                # Only transfer array nodes
                # TODO: handle streams
                continue

            array = node.desc(sdfg)
            if node.data in fpga_data:
                fpga_array = fpga_data[node.data]
            else:
                fpga_array = sdfg.add_array(
                    'fpga_' + node.data,
                    array.dtype,
                    array.shape,
                    materialize_func=array.materialize_func,
                    transient=True,
                    storage=types.StorageType.FPGA_Global,
                    allow_conflicts=array.allow_conflicts,
                    access_order=array.access_order,
                    strides=array.strides,
                    offset=array.offset)
                fpga_data[node.data] = fpga_array
            fpga_node = type(node)(fpga_array)

            post_state.add_node(node)
            post_state.add_node(fpga_node)
            full_range = subsets.Range([(0, s - 1, 1) for s in array.shape])
            mem = memlet.Memlet(fpga_array, full_range.num_elements(),
                                full_range, 1)
            post_state.add_edge(fpga_node, None, node, None, mem)

            state.add_node(fpga_node)
            nxutil.change_edge_dest(state, node, fpga_node)
            state.remove_node(node)

        sdfg.add_node(post_state)
        nxutil.change_edge_src(sdfg, state, post_state)
        sdfg.add_edge(state, post_state, edges.InterstateEdge())

    for src, _, dst, _, mem in state.edges():
        if mem.data is not None and mem.data in fpga_data:
            # Rename by the memlet's own data; the stale loop variable
            # `node` would point every memlet at the last output array.
            mem.data = 'fpga_' + mem.data

    fpga_update(state, 0)
def apply(self, sdfg):
    """ Offloads the matched state's array data to FPGA global memory.

        Array source nodes of the matched state are copied into transient
        ``fpga_<name>`` arrays (``FPGA_Global`` storage) in a new
        ``pre_<label>`` state, and the matched state is rewired to read the
        clones. Access nodes targeted by WCR memlets inside nested SDFGs are
        also looked up one nesting level up and copied to the device. Array
        sink nodes are redirected to their clones and copied back in a new
        ``post_<label>`` state. Memlets that refer to cloned data are renamed
        and given a vector length read from adjacent nested SDFGs. Finally,
        ``fpga_update(sdfg, state, 0)`` is called.

        @param sdfg: The SDFG containing the matched state (modified in place).
    """
    state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

    # Find source/sink (data) nodes
    input_nodes = nxutil.find_source_nodes(state)
    output_nodes = nxutil.find_sink_nodes(state)

    # Original data name -> value returned by sdfg.add_array for its clone.
    fpga_data = {}

    def _add_fpga_array(name, array):
        # Register the transient FPGA_Global clone 'fpga_<name>' of `array`.
        return sdfg.add_array(
            'fpga_' + name,
            array.shape,
            array.dtype,
            materialize_func=array.materialize_func,
            transient=True,
            storage=dtypes.StorageType.FPGA_Global,
            allow_conflicts=array.allow_conflicts,
            strides=array.strides,
            offset=array.offset)

    def _is_array_node(node):
        # Only transfer array nodes
        # TODO: handle streams
        return (isinstance(node, dace.graph.nodes.AccessNode)
                and isinstance(node.desc(sdfg), dace.data.Array))

    def _full_range(array):
        # Subset spanning every element of `array`.
        return subsets.Range([(0, s - 1, 1) for s in array.shape])

    def _nested_veclen(nsdfg_node, connector, current):
        # Scan every state of the nested SDFG for access nodes named like
        # `connector`; the last match wins. Nodes without edges are skipped
        # instead of raising IndexError.
        for inner_state in nsdfg_node.sdfg.states():
            for inner_node in inner_state.nodes():
                if (isinstance(inner_node, dace.graph.nodes.AccessNode)
                        and inner_node.data == connector):
                    inner_edges = inner_state.all_edges(inner_node)
                    if inner_edges:
                        # assuming all memlets have the same vector length
                        current = inner_edges[0].data.veclen
        return current

    # Input nodes may also be nodes with WCR memlets
    # We have to recur across nested SDFGs to find them
    wcr_input_nodes = set()
    for node, graph in state.all_nodes_recursive():
        if not isinstance(node, dace.graph.nodes.AccessNode):
            continue
        for e in graph.all_edges(node):
            if e.data.wcr is None:
                continue
            # This is an output node with wcr: find the target in the parent
            # SDFG. Following the structure State->SDFG->State->SDFG, from
            # the current state we have to go two levels up.
            parent_state = graph.parent.parent
            if parent_state is None:
                continue
            for parent_edge in parent_state.edges():
                if parent_edge.src_conn == e.dst.data or (
                        isinstance(parent_edge.dst,
                                   dace.graph.nodes.AccessNode)
                        and e.dst.data == parent_edge.dst.data):
                    # This must be copied to device
                    input_nodes.append(parent_edge.dst)
                    wcr_input_nodes.add(parent_edge.dst)

    if input_nodes:
        # create pre_state
        pre_state = sd.SDFGState('pre_' + state.label, sdfg)

        for node in input_nodes:
            if not _is_array_node(node):
                continue
            array = node.desc(sdfg)
            # WCR targets are not cloned here.
            # NOTE(review): presumably their clone is created when they are
            # processed as output nodes below -- confirm.
            if node.data not in fpga_data and node not in wcr_input_nodes:
                fpga_data[node.data] = _add_fpga_array(node.data, array)

            # Host -> FPGA copy in the pre-state
            pre_node = pre_state.add_read(node.data)
            pre_fpga_node = pre_state.add_write('fpga_' + node.data)
            full_range = _full_range(array)
            mem = memlet.Memlet(node.data, full_range.num_elements(),
                                full_range, 1)
            pre_state.add_edge(pre_node, None, pre_fpga_node, None, mem)

            if node not in wcr_input_nodes:
                # The matched state now reads from the FPGA clone instead
                fpga_node = state.add_read('fpga_' + node.data)
                nxutil.change_edge_src(state, node, fpga_node)
                state.remove_node(node)

        sdfg.add_node(pre_state)
        nxutil.change_edge_dest(sdfg, state, pre_state)
        sdfg.add_edge(pre_state, state, edges.InterstateEdge())

    if output_nodes:
        post_state = sd.SDFGState('post_' + state.label, sdfg)

        for node in output_nodes:
            if not _is_array_node(node):
                continue
            array = node.desc(sdfg)
            if node.data not in fpga_data:
                fpga_data[node.data] = _add_fpga_array(node.data, array)

            # FPGA -> host copy in the post-state
            post_node = post_state.add_write(node.data)
            post_fpga_node = post_state.add_read('fpga_' + node.data)
            full_range = _full_range(array)
            mem = memlet.Memlet('fpga_' + node.data, full_range.num_elements(),
                                full_range, 1)
            post_state.add_edge(post_fpga_node, None, post_node, None, mem)

            # The matched state now writes to the FPGA clone instead
            fpga_node = state.add_write('fpga_' + node.data)
            nxutil.change_edge_dest(state, node, fpga_node)
            state.remove_node(node)

        sdfg.add_node(post_state)
        nxutil.change_edge_src(sdfg, state, post_state)
        sdfg.add_edge(state, post_state, edges.InterstateEdge())

    # Propagate vector info from nested SDFGs and rename memlets to clones.
    # NOTE(review): veclen_ is not reset per edge, so an edge that touches no
    # nested SDFG inherits the value found for an earlier edge -- confirm
    # this order dependence is intended.
    veclen_ = 1
    for src, src_conn, dst, dst_conn, mem in state.edges():
        if isinstance(dst, dace.graph.nodes.NestedSDFG):
            # this edge is going to the nested SDFG
            veclen_ = _nested_veclen(dst, dst_conn, veclen_)
        if isinstance(src, dace.graph.nodes.NestedSDFG):
            # this edge is coming from the nested SDFG
            veclen_ = _nested_veclen(src, src_conn, veclen_)
        if mem.data is not None and mem.data in fpga_data:
            mem.data = 'fpga_' + mem.data
            mem.veclen = veclen_

    fpga_update(sdfg, state, 0)
def apply(self, sdfg: sd.SDFG):
    """ Transforms `sdfg` in place so that its computation runs on the GPU.

        Steps:
          0. Collect non-transient data that is read (except data that only
             feeds dynamic map ranges) or written, plus WCR targets without
             identity, and the unscoped code nodes of each state.
          1. Create transient ``GPU_Global`` clones (``gpu_<name>``) of that
             data and rename access nodes and memlets to use them. Input
             scalars stay on host.
          2. Add a copy-in state (host -> GPU) before the start state.
          3. Add a copy-out state (GPU -> host) after every sink state.
          4. Move top-level transients to GPU global memory (optionally as
             ``toplevel``); optionally make scoped transients registers.
          5. Wrap unscoped code nodes in a single-iteration GPU map.
          6. Schedule top-level maps and Reduce nodes on the GPU device
             (optionally making inner maps sequential).
          7. Copy cloned data back to host before interstate edges whose
             conditions use it.
          8. Optionally apply strict transformations.

        @param sdfg: The SDFG to transform.
    """
    #######################################################
    # Step 0: SDFG metadata

    # Find all input and output data descriptors, as (name, descriptor)
    # pairs in first-seen order. The *_names sets provide the membership test;
    # checking a name against the list of tuples would never match.
    input_nodes = []
    output_nodes = []
    input_names = set()
    output_names = set()
    global_code_nodes = [[] for _ in sdfg.nodes()]

    def _add_input(name, desc):
        if name not in input_names:
            input_names.add(name)
            input_nodes.append((name, desc))

    for i, state in enumerate(sdfg.nodes()):
        sdict = state.scope_dict()
        for node in state.nodes():
            if (isinstance(node, nodes.AccessNode)
                    and not node.desc(sdfg).transient):
                if (state.out_degree(node) > 0
                        and node.data not in input_names):
                    # Special case: nodes that lead to dynamic map ranges
                    # must stay on host
                    for e in state.out_edges(node):
                        last_edge = state.memlet_path(e)[-1]
                        if (isinstance(last_edge.dst, nodes.EntryNode)
                                and last_edge.dst_conn and
                                not last_edge.dst_conn.startswith('IN_')):
                            break
                    else:
                        _add_input(node.data, node.desc(sdfg))
                if (state.in_degree(node) > 0
                        and node.data not in output_names):
                    output_names.add(node.data)
                    output_nodes.append((node.data, node.desc(sdfg)))
            elif isinstance(node, nodes.CodeNode) and sdict[node] is None:
                if not isinstance(node, nodes.EmptyTasklet):
                    global_code_nodes[i].append(node)

        # Input nodes may also be nodes with WCR memlets and no identity
        for e in state.edges():
            if e.data.wcr is not None and e.data.wcr_identity is None:
                if not sdfg.arrays[e.data.data].transient:
                    _add_input(e.data.data, sdfg.arrays[e.data.data])

    start_state = sdfg.start_state
    end_states = sdfg.sink_nodes()

    #######################################################
    # Step 1: Create cloned GPU arrays and replace originals

    cloned_arrays = {}

    def _clone_to_gpu(name, desc):
        # Register a transient GPU_Global copy of `desc`, remember its name.
        newdesc = desc.clone()
        newdesc.storage = dtypes.StorageType.GPU_Global
        newdesc.transient = True
        cloned_arrays[name] = sdfg.add_datadesc(
            'gpu_' + name, newdesc, find_new_name=True)

    for inodename, inode in set(input_nodes):
        if isinstance(inode, data.Scalar):  # Scalars can remain on host
            continue
        _clone_to_gpu(inodename, inode)

    for onodename, onode in set(output_nodes):
        if onodename in cloned_arrays:
            continue
        _clone_to_gpu(onodename, onode)

    # Replace nodes
    for state in sdfg.nodes():
        for node in state.nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.data in cloned_arrays):
                node.data = cloned_arrays[node.data]

    # Replace memlets
    for state in sdfg.nodes():
        for edge in state.edges():
            if edge.data.data in cloned_arrays:
                edge.data.data = cloned_arrays[edge.data.data]

    def _add_copy(target_state, host_name, to_device, debuginfo):
        # Add a full-array copy between host data `host_name` and its GPU
        # clone; the memlet always describes the host-side array.
        host_node = nodes.AccessNode(host_name, debuginfo=debuginfo)
        gpu_node = nodes.AccessNode(cloned_arrays[host_name],
                                    debuginfo=debuginfo)
        src, dst = ((host_node, gpu_node) if to_device else
                    (gpu_node, host_node))
        target_state.add_node(src)
        target_state.add_node(dst)
        target_state.add_nedge(
            src, dst,
            memlet.Memlet.from_array(host_node.data, host_node.desc(sdfg)))

    #######################################################
    # Step 2: Create copy-in state

    excluded_copyin = self.exclude_copyin.split(',')

    copyin_state = sdfg.add_state(sdfg.label + '_copyin')
    sdfg.add_edge(copyin_state, start_state, ed.InterstateEdge())

    for nname, desc in dtypes.deduplicate(input_nodes):
        if nname in excluded_copyin or nname not in cloned_arrays:
            continue
        _add_copy(copyin_state, nname, True, desc.debuginfo)

    #######################################################
    # Step 3: Create copy-out state

    excluded_copyout = self.exclude_copyout.split(',')

    copyout_state = sdfg.add_state(sdfg.label + '_copyout')
    for state in end_states:
        sdfg.add_edge(state, copyout_state, ed.InterstateEdge())

    for nname, desc in dtypes.deduplicate(output_nodes):
        if nname in excluded_copyout or nname not in cloned_arrays:
            continue
        _add_copy(copyout_state, nname, False, desc.debuginfo)

    #######################################################
    # Step 4: Modify transient data storage

    for state in sdfg.nodes():
        sdict = state.scope_dict()
        for node in state.nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.desc(sdfg).transient):
                nodedesc = node.desc(sdfg)

                # Special case: nodes that lead to dynamic map ranges must
                # stay on host
                if any(
                        isinstance(
                            state.memlet_path(e)[-1].dst, nodes.EntryNode)
                        for e in state.out_edges(node)):
                    continue

                if sdict[node] is None:
                    # NOTE: the cloned arrays match too but it's the same
                    # storage so we don't care
                    nodedesc.storage = dtypes.StorageType.GPU_Global

                    # Try to move allocation/deallocation out of loops
                    if (self.toplevel_trans
                            and not isinstance(nodedesc, data.Stream)):
                        nodedesc.toplevel = True
                elif self.register_trans:
                    # Make internal transients registers
                    nodedesc.storage = dtypes.StorageType.Register

    #######################################################
    # Step 5: Wrap free tasklets and nested SDFGs with a GPU map

    excluded_tasklets = self.exclude_tasklets.split(',')

    # zip() stops at len(global_code_nodes), i.e. the states from Step 0.
    for state, gcodes in zip(sdfg.nodes(), global_code_nodes):
        for gcode in gcodes:
            if gcode.label in excluded_tasklets:
                continue

            # Create map and connectors
            me, mx = state.add_map(
                gcode.label + '_gmap', {gcode.label + '__gmapi': '0:1'},
                schedule=dtypes.ScheduleType.GPU_Device)

            # Store in/out edges in lists so that they don't get corrupted
            # when they are removed from the graph
            in_edges = list(state.in_edges(gcode))
            out_edges = list(state.out_edges(gcode))
            me.in_connectors = set('IN_' + e.dst_conn for e in in_edges)
            me.out_connectors = set('OUT_' + e.dst_conn for e in in_edges)
            mx.in_connectors = set('IN_' + e.src_conn for e in out_edges)
            mx.out_connectors = set('OUT_' + e.src_conn for e in out_edges)

            # Create memlets through map
            for e in in_edges:
                state.remove_edge(e)
                state.add_edge(e.src, e.src_conn, me, 'IN_' + e.dst_conn,
                               e.data)
                state.add_edge(me, 'OUT_' + e.dst_conn, e.dst, e.dst_conn,
                               e.data)
            for e in out_edges:
                state.remove_edge(e)
                state.add_edge(e.src, e.src_conn, mx, 'IN_' + e.src_conn,
                               e.data)
                state.add_edge(mx, 'OUT_' + e.src_conn, e.dst, e.dst_conn,
                               e.data)

            # Map without inputs
            if not in_edges:
                state.add_nedge(me, gcode, memlet.EmptyMemlet())

    #######################################################
    # Step 6: Change all top-level maps and Reduce nodes to GPU schedule

    for state in sdfg.nodes():
        sdict = state.scope_dict()
        for node in state.nodes():
            if isinstance(node, (nodes.EntryNode, nodes.Reduce)):
                if sdict[node] is None:
                    node.schedule = dtypes.ScheduleType.GPU_Device
                elif (isinstance(node, nodes.EntryNode)
                      and self.sequential_innermaps):
                    node.schedule = dtypes.ScheduleType.Sequential

    #######################################################
    # Step 7: Introduce copy-out if data used in outgoing interstate edges

    cloned_names = set(cloned_arrays.keys())
    for state in list(sdfg.nodes()):
        arrays_used = set()
        for e in sdfg.out_edges(state):
            # Used arrays = intersection between symbols and cloned arrays
            arrays_used.update(
                set(e.data.condition_symbols()) & cloned_names)

        # Create a state and copy out used arrays
        if arrays_used:
            co_state = sdfg.add_state(state.label + '_icopyout')

            # Reconnect outgoing edges to after interim copyout state. A
            # single call moves all of them.
            nxutil.change_edge_src(sdfg, state, co_state)
            # Add unconditional edge to interim state
            sdfg.add_edge(state, co_state, ed.InterstateEdge())

            # Add copy-out nodes
            for nname in arrays_used:
                _add_copy(co_state, nname, False,
                          sdfg.arrays[nname].debuginfo)

    #######################################################
    # Step 8: Strict transformations
    if not self.strict_transform:
        return

    # Apply strict state fusions greedily.
    sdfg.apply_strict_transformations()
def apply(self, sdfg):
    """ Fully unrolls the matched loop in place.

        The loop body (states reachable from the loop's first state without
        passing through the guard) is copied once per iteration value. In
        each copy the iteration variable is replaced by the literal value, and
        the copies are chained with unconditional edges from the state before
        the loop to the loop's exit state. The guard and the original body
        states are then removed. A loop with zero iterations becomes a direct
        edge from the state before the loop to the exit state.

        @param sdfg: The SDFG containing the loop.
        @raise ValueError: If ``count == 0`` and the loop range is symbolic.
        @raise NotImplementedError: If partial unrolling (``count != 0``) is
                                    requested.
    """
    # Obtain loop information
    guard: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._loop_guard])
    begin: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._loop_begin])
    after_state: sd.SDFGState = sdfg.node(
        self.subgraph[DetectLoop._exit_state])

    # Obtain iteration variable, range, and stride
    guard_inedges = sdfg.in_edges(guard)
    condition_edge = sdfg.edges_between(guard, begin)[0]
    itervar = list(guard_inedges[0].data.assignments.keys())[0]
    condition = condition_edge.data.condition_sympy()
    rng = LoopUnroll._loop_range(itervar, guard_inedges, condition)

    # Loop must be unrollable
    if self.count == 0 and any(
            symbolic.issymbolic(r, sdfg.constants) for r in rng):
        raise ValueError('Loop cannot be fully unrolled, size is symbolic')
    if self.count != 0:
        raise NotImplementedError  # TODO(later)

    # Find the state prior to the loop: of the two guard in-edges, the one
    # assigning the range start comes from before the loop; the other edge's
    # source is treated as the last body state.
    before_state: sd.SDFGState
    last_state: sd.SDFGState
    if str(rng[0]) == guard_inedges[0].data.assignments[itervar]:
        before_state = guard_inedges[0].src
        last_state = guard_inedges[1].src
    else:
        before_state = guard_inedges[1].src
        last_state = guard_inedges[0].src

    # Get loop states
    loop_states = list(
        nxutil.dfs_topological_sort(
            sdfg, sources=[begin], condition=lambda _, child: child != guard))
    first_id = loop_states.index(begin)
    last_id = loop_states.index(last_state)
    # Position of each body state, for O(1) lookups while copying edges.
    state_index = {s: idx for idx, s in enumerate(loop_states)}
    loop_subgraph = gr.SubgraphView(sdfg, loop_states)

    # Evaluate the real values of the loop
    start, end, stride = (symbolic.evaluate(r, sdfg.constants) for r in rng)

    # Create states for loop subgraph
    unrolled_states = []
    for i in range(start, end + 1, stride):
        # Using to/from JSON copies faster than deepcopy (which will also
        # copy the parent SDFG)
        new_states = [
            sd.SDFGState.from_json(s.to_json(), context={'sdfg': sdfg})
            for s in loop_states
        ]

        # Replace iterate with value in each state
        for state in new_states:
            state.set_label(state.label + '_%s_%d' % (itervar, i))
            state.replace(itervar, str(i))

        # Add subgraph to original SDFG
        for edge in loop_subgraph.edges():
            src = new_states[state_index[edge.src]]
            dst = new_states[state_index[edge.dst]]

            # Replace conditions in subgraph edges
            edge_data: edges.InterstateEdge = copy.deepcopy(edge.data)
            if edge_data.condition:
                ASTFindReplace({itervar: str(i)}).visit(edge_data.condition)

            sdfg.add_edge(src, dst, edge_data)

        # Connect iterations with unconditional edges
        if unrolled_states:
            sdfg.add_edge(unrolled_states[-1][1], new_states[first_id],
                          edges.InterstateEdge())

        unrolled_states.append((new_states[first_id], new_states[last_id]))

    # Connect new states to before and after states without conditions
    if unrolled_states:
        sdfg.add_edge(before_state, unrolled_states[0][0],
                      edges.InterstateEdge())
        sdfg.add_edge(unrolled_states[-1][1], after_state,
                      edges.InterstateEdge())
    else:
        # Zero-trip loop: the guard (and with it the guard -> after_state
        # edge) is removed below, so link before_state to after_state
        # directly to keep the exit state reachable.
        sdfg.add_edge(before_state, after_state, edges.InterstateEdge())

    # Remove old states from SDFG
    sdfg.remove_nodes_from([guard] + loop_states)
class StateFusion(pattern_matching.Transformation):
    """ Implements the state-fusion transformation.

        State-fusion takes two states that are connected through a single edge,
        and fuses them into one state. If strict, only applies if no memory
        access hazards are created.
    """

    _first_state = sdfg.SDFGState()
    _edge = edges.InterstateEdge()
    _second_state = sdfg.SDFGState()

    @staticmethod
    def annotates_memlets():
        return False

    @staticmethod
    def expressions():
        # Pattern: a path from the first state to the second state.
        return [
            nxutil.node_path_graph(StateFusion._first_state,
                                   StateFusion._second_state)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether the matched pair of states can be fused.

            @param graph: The state graph that was matched.
            @param candidate: Maps pattern nodes to indices into
                              ``graph.nodes()``.
            @param expr_index: Index of the matched expression (unused).
            @param sdfg: The SDFG (unused).
            @param strict: If True, also reject fusions that may create
                           read-write or write-write hazards.
            @return: True if the fusion may be applied, False otherwise.
        """
        first_state = graph.nodes()[candidate[StateFusion._first_state]]
        second_state = graph.nodes()[candidate[StateFusion._second_state]]

        out_edges = graph.out_edges(first_state)
        in_edges = graph.in_edges(first_state)

        # First state must have only one output edge (with dst the second
        # state).
        if len(out_edges) != 1:
            return False
        # The interstate edge must not have a condition.
        if out_edges[0].data.condition.as_string != '':
            return False
        # The interstate edge may have assignments, as long as there are input
        # edges to the first state, that can absorb them.
        if out_edges[0].data.assignments and not in_edges:
            return False

        # There can be no state that have output edges pointing to both the
        # first and the second state. Such a case will produce a multi-graph.
        for src, _, _ in in_edges:
            if any(dst == second_state for _, dst, _ in graph.out_edges(src)):
                return False

        if not strict:
            return True

        # If second state has other input edges, there might be issues
        second_in_edges = graph.in_edges(second_state)
        if ((not second_state.is_empty() or not first_state.is_empty())
                and len(second_in_edges) != 1):
            return False

        # Weakly connected components of the first state's dataflow graph.
        first_cc = list(nx.weakly_connected_components(first_state._nx))

        # Find source/sink (data) nodes
        first_input = {
            node
            for node in nxutil.find_source_nodes(first_state)
            if isinstance(node, nodes.AccessNode)
        }
        first_output = {
            node
            for node in first_state.nodes()
            if isinstance(node, nodes.AccessNode) and node not in first_input
        }
        second_input = {
            node
            for node in nxutil.find_source_nodes(second_state)
            if isinstance(node, nodes.AccessNode)
        }
        second_output = {
            node
            for node in second_state.nodes()
            if isinstance(node, nodes.AccessNode) and node not in second_input
        }

        # Data nodes are matched across the two states by label.
        second_input_labels = {node.label for node in second_input}
        second_output_labels = {node.label for node in second_output}

        # Number of first-state components none of whose output nodes is
        # read by the second state.
        check_strict = sum(
            1 for cc in first_cc
            if not any(node.label in second_input_labels
                       for node in cc.intersection(first_output)))

        if check_strict > 0:
            # Check strict conditions
            # RW dependency
            if any(node.label in second_output_labels for node in first_input):
                return False
            # WW dependency
            if any(node.label in second_output_labels
                   for node in first_output):
                return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        first_state = graph.nodes()[candidate[StateFusion._first_state]]
        second_state = graph.nodes()[candidate[StateFusion._second_state]]

        return ' -> '.join(state.label for state in [first_state, second_state])

    def apply(self, sdfg):
        """ Fuses the matched second state into the first one, in place.

            @param sdfg: The SDFG containing both states.
        """
        first_state = sdfg.nodes()[self.subgraph[StateFusion._first_state]]
        second_state = sdfg.nodes()[self.subgraph[StateFusion._second_state]]

        # Remove interstate edge(s). Any assignments they carry are merged
        # into the assignments of every edge entering the first state.
        for connecting_edge in list(
                sdfg.edges_between(first_state, second_state)):
            if connecting_edge.data.assignments:
                for _, _, other_data in sdfg.in_edges(first_state):
                    other_data.assignments.update(
                        connecting_edge.data.assignments)
            sdfg.remove_edge(connecting_edge)

        # Special case 1: first state is empty
        if first_state.is_empty():
            nxutil.change_edge_dest(sdfg, first_state, second_state)
            sdfg.remove_node(first_state)
            return

        # Special case 2: second state is empty
        if second_state.is_empty():
            nxutil.change_edge_src(sdfg, second_state, first_state)
            nxutil.change_edge_dest(sdfg, second_state, first_state)
            sdfg.remove_node(second_state)
            return

        # Normal case: both states are not empty

        # Find source/sink (data) nodes
        first_input = [
            node for node in nxutil.find_source_nodes(first_state)
            if isinstance(node, nodes.AccessNode)
        ]
        first_output = [
            node for node in nxutil.find_sink_nodes(first_state)
            if isinstance(node, nodes.AccessNode)
        ]
        second_input = [
            node for node in nxutil.find_source_nodes(second_state)
            if isinstance(node, nodes.AccessNode)
        ]

        # first input = first input - first output
        first_output_labels = {node.label for node in first_output}
        first_input = [
            node for node in first_input
            if node.label not in first_output_labels
        ]

        # Merge second state to first state
        for node in second_state.nodes():
            first_state.add_node(node)
        for src, src_conn, dst, dst_conn, edge_data in second_state.edges():
            first_state.add_edge(src, src_conn, dst, dst_conn, edge_data)

        # Merge common (data) nodes
        for node in first_input:
            old_node = next(
                (x for x in second_input if x.label == node.label), None)
            if old_node is None:
                continue
            nxutil.change_edge_src(first_state, old_node, node)
            first_state.remove_node(old_node)
            # old_node is gone; another first-state input with the same label
            # must not try to merge with it again.
            second_input.remove(old_node)

        for node in first_output:
            new_node = next(
                (x for x in second_input if x.label == node.label), None)
            if new_node is None:
                continue
            nxutil.change_edge_dest(first_state, node, new_node)
            first_state.remove_node(node)

        # Redirect edges and remove second state
        nxutil.change_edge_src(sdfg, second_state, first_state)
        sdfg.remove_node(second_state)