def expressions():
    """ Builds the single pattern matched by DeduplicateAccess.

    The pattern is one state in which the pattern map entry is connected,
    through connector-less edges (``add_nedge``) carrying a default
    ``Memlet()``, to the pattern nodes ``_node1`` and ``_node2``.

    :return: A one-element list holding the pattern state.
    """
    pattern = sd.SDFGState()
    for target in (DeduplicateAccess._node1, DeduplicateAccess._node2):
        pattern.add_nedge(DeduplicateAccess._map_entry, target, Memlet())
    return [pattern]
class StartStateElimination(transformation.Transformation):
    """ Start-state elimination removes a redundant state that has one
        outgoing edge and no contents. This transformation applies only to
        nested SDFGs.
    """

    start_state = sdfg.SDFGState()

    @staticmethod
    def expressions():
        # The pattern is a single state: the start-state candidate.
        return [sdutil.node_path_graph(StartStateElimination.start_state)]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Returns True if the matched state is an empty start state of a
            nested SDFG that leaves through exactly one unconditional edge.
        """
        matched = graph.nodes()[candidate[StartStateElimination.start_state]]

        # Only SDFGs that have a parent (i.e., nested SDFGs) qualify
        if not graph.parent:
            return False

        predecessors = graph.in_edges(matched)
        successors = graph.out_edges(matched)

        # A start state has no predecessors
        if len(predecessors) != 0:
            return False

        # Exactly one successor, taken without a condition
        if len(successors) != 1:
            return False
        if not successors[0].data.is_unconditional():
            return False

        # Removing the state must not lose any contents
        return matched.number_of_nodes() == 0

    @staticmethod
    def match_to_str(graph, candidate):
        return graph.nodes()[candidate[StartStateElimination.start_state]].label

    def apply(self, sdfg):
        """ Removes the matched start state. The assignments on its outgoing
            edge are copied into the ``symbol_mapping`` of the nested SDFG
            node that contains ``sdfg``.
        """
        doomed = sdfg.nodes()[self.subgraph[StartStateElimination.start_state]]
        nsdfg_node = sdfg.parent_nsdfg_node
        exit_edge = sdfg.out_edges(doomed)[0]

        # The values are now bound on entry to the nested SDFG instead of on
        # the (removed) first transition.
        for symbol, value in exit_edge.data.assignments.items():
            nsdfg_node.symbol_mapping[symbol] = value

        sdfg.remove_node(doomed)
class EndStateElimination(transformation.Transformation):
    """ End-state elimination removes a redundant state that has one incoming
        edge and no contents.
    """

    _end_state = sdfg.SDFGState()

    @staticmethod
    def expressions():
        # The pattern is a single state: the end-state candidate.
        return [sdutil.node_path_graph(EndStateElimination._end_state)]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Returns True if the matched state is an empty sink state entered
            through exactly one unconditional edge.
        """
        matched = graph.nodes()[candidate[EndStateElimination._end_state]]

        successors = graph.out_edges(matched)
        predecessors = graph.in_edges(matched)

        # An end state has no successors
        if len(successors) != 0:
            return False

        # Exactly one predecessor, reached without a condition
        if len(predecessors) != 1:
            return False
        if not predecessors[0].data.is_unconditional():
            return False

        # Removing the state must not lose any contents
        return matched.number_of_nodes() == 0

    @staticmethod
    def match_to_str(graph, candidate):
        return graph.nodes()[candidate[EndStateElimination._end_state]].label

    def apply(self, sdfg):
        """ Removes the matched end state. Afterwards, every symbol that the
            removed incoming edge assigned is removed from the SDFG if it
            appears in ``sdfg.free_symbols``.
        """
        terminal = sdfg.nodes()[self.subgraph[EndStateElimination._end_state]]

        # Capture the symbols assigned on the sole incoming edge; that edge
        # is deleted together with the state, which may orphan them.
        assigned_symbols = sdfg.in_edges(terminal)[0].data.assignments.keys()

        sdfg.remove_node(terminal)

        for symbol in assigned_symbols:
            if symbol in sdfg.free_symbols:
                sdfg.remove_symbol(symbol)
class StateAssignElimination(transformation.Transformation):
    """ State assign elimination removes all assignments into the final state
        and subsumes the assigned value into its contents.
    """

    _end_state = sdfg.SDFGState()

    @staticmethod
    def expressions():
        # The pattern is a single state: the final-state candidate.
        return [sdutil.node_path_graph(StateAssignElimination._end_state)]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Returns True if the matched state has no successors and exactly
            one incoming edge that carries at least one assignment.
        """
        matched = graph.nodes()[candidate[StateAssignElimination._end_state]]

        successors = graph.out_edges(matched)
        predecessors = graph.in_edges(matched)

        # An end state has no successors
        if len(successors) != 0:
            return False

        # Exactly one predecessor edge, and it must assign something
        if len(predecessors) != 1:
            return False
        return len(predecessors[0].data.assignments) != 0

    @staticmethod
    def match_to_str(graph, candidate):
        return graph.nodes()[candidate[StateAssignElimination._end_state]].label

    def apply(self, sdfg):
        """ Substitutes every assignment of the incoming edge into the final
            state's contents and then clears the edge's assignments.
        """
        terminal = sdfg.nodes()[self.subgraph[StateAssignElimination._end_state]]
        incoming = sdfg.in_edges(terminal)[0]

        # Inter-state assignments that read a value assigned on the same edge
        # (e.g., {m: n, n: m}) lead to undefined behavior, so each assignment
        # can be substituted independently.
        for symbol, expression in incoming.data.assignments.items():
            terminal.replace(symbol, expression)

        # The values now live inside the state; drop them from the edge
        incoming.data.assignments = {}
class StateFusion(pattern_matching.Transformation):
    """ Implements the state-fusion transformation.

        State-fusion takes two states that are connected through a single
        edge, and fuses them into one state. If strict, only applies if no
        memory access hazards are created.
    """

    _states_fused = 0
    _first_state = sdfg.SDFGState()
    _edge = sdfg.InterstateEdge()
    _second_state = sdfg.SDFGState()

    @staticmethod
    def annotates_memlets():
        return False

    @staticmethod
    def expressions():
        return [
            sdutil.node_path_graph(StateFusion._first_state,
                                   StateFusion._second_state)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Returns True if the two matched states can be fused.

            :param graph: The SDFG containing the candidate states.
            :param candidate: Mapping from pattern nodes to state indices.
            :param strict: If True, additionally reject fusions that may
                           introduce read-after-write or write-after-write
                           hazards between the two states' dataflow.
        """
        first_state = graph.nodes()[candidate[StateFusion._first_state]]
        second_state = graph.nodes()[candidate[StateFusion._second_state]]

        out_edges = graph.out_edges(first_state)
        in_edges = graph.in_edges(first_state)

        # First state must have only one output edge (with dst the second
        # state).
        if len(out_edges) != 1:
            return False
        # The interstate edge must not have a condition.
        if not out_edges[0].data.is_unconditional():
            return False
        # The interstate edge may have assignments, as long as there are input
        # edges to the first state that can absorb them.
        if out_edges[0].data.assignments:
            if not in_edges:
                return False
            # Fail if symbol is set before the state to fuse
            # TODO: Also fail if symbol is used in the dataflow of that state
            new_assignments = set(out_edges[0].data.assignments.keys())
            if any((new_assignments & set(e.data.assignments.keys()))
                   for e in in_edges):
                return False

        # There can be no state that has output edges pointing to both the
        # first and the second state. Such a case would produce a multi-graph.
        for src, _, _ in in_edges:
            for _, dst, _ in graph.out_edges(src):
                if dst == second_state:
                    return False

        if strict:
            # If the second state has other input edges, there might be
            # issues. Exceptions are when none of the states contain dataflow,
            # unless the first state is an initial state (in which case the
            # new initial state would be ambiguous).
            first_in_edges = graph.in_edges(first_state)
            second_in_edges = graph.in_edges(second_state)
            if ((not second_state.is_empty() or not first_state.is_empty()
                 or len(first_in_edges) == 0) and len(second_in_edges) != 1):
                return False

            # Connected components of the first state's dataflow
            first_cc = [
                cc_nodes
                for cc_nodes in nx.weakly_connected_components(first_state._nx)
            ]

            # Find source/sink (data) nodes. "Output" here means every access
            # node that is not a source, including intermediate ones.
            first_input = {
                node
                for node in sdutil.find_source_nodes(first_state)
                if isinstance(node, nodes.AccessNode)
            }
            first_output = {
                node
                for node in first_state.nodes()
                if isinstance(node, nodes.AccessNode)
                and node not in first_input
            }
            second_input = {
                node
                for node in sdutil.find_source_nodes(second_state)
                if isinstance(node, nodes.AccessNode)
            }
            second_output = {
                node
                for node in second_state.nodes()
                if isinstance(node, nodes.AccessNode)
                and node not in second_input
            }

            # Output (data) nodes of the first state, per connected component
            first_cc_output = [
                cc.intersection(first_output) for cc in first_cc
            ]

            # Access nodes are matched by label (data name)
            second_input_labels = {node.label for node in second_input}
            second_output_labels = {node.label for node in second_output}
            first_sinks = list(first_state.sink_nodes())

            # Apply transformation in case all paths to the second state's
            # nodes go through the same access node, which implies sequential
            # behavior in SDFG semantics.
            check_strict = len(first_cc)
            for cc_output in first_cc_output:
                out_nodes = [n for n in first_sinks if n in cc_output]
                # Branching exists, multiple paths may involve same access
                # node potentially causing data races
                if len(out_nodes) > 1:
                    continue

                # Otherwise, check whether the second state reads this
                # component's output (at most one decrement per component)
                if any(node.label in second_input_labels
                       for node in out_nodes):
                    check_strict -= 1

            if check_strict > 0:
                # Check strict conditions
                # RW dependency: the first state reads what the second writes
                if any(node.label in second_output_labels
                       for node in first_input):
                    return False
                # WW dependency: both states write the same data
                if any(node.label in second_output_labels
                       for node in first_output):
                    return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        first_state = graph.nodes()[candidate[StateFusion._first_state]]
        second_state = graph.nodes()[candidate[StateFusion._second_state]]

        return " -> ".join(state.label
                           for state in [first_state, second_state])

    def apply(self, sdfg):
        """ Fuses the second matched state into the first one.

            Assignments on the connecting edge(s) are moved onto every
            incoming edge of the first state. If either state is empty it is
            simply removed. Otherwise, the second state's contents are copied
            into the first state, and each source access node of the second
            state is merged with the last (in topological order) access node
            of the same label in the first state.
        """
        first_state = sdfg.nodes()[self.subgraph[StateFusion._first_state]]
        second_state = sdfg.nodes()[self.subgraph[StateFusion._second_state]]

        # Remove interstate edge(s)
        edges = sdfg.edges_between(first_state, second_state)
        for edge in edges:
            if edge.data.assignments:
                for src, dst, other_data in sdfg.in_edges(first_state):
                    other_data.assignments.update(edge.data.assignments)
            sdfg.remove_edge(edge)

        # Special case 1: first state is empty
        if first_state.is_empty():
            sdutil.change_edge_dest(sdfg, first_state, second_state)
            sdfg.remove_node(first_state)
            return

        # Special case 2: second state is empty
        if second_state.is_empty():
            sdutil.change_edge_src(sdfg, second_state, first_state)
            sdutil.change_edge_dest(sdfg, second_state, first_state)
            sdfg.remove_node(second_state)
            return

        # Normal case: both states are not empty

        # Source (data) nodes of the second state, collected before merging
        second_input = [
            node for node in sdutil.find_source_nodes(second_state)
            if isinstance(node, nodes.AccessNode)
        ]

        # Merge second state to first state
        # First keep a backup of the (reversed) topological order of the
        # first state's access nodes, so the last writer is found first.
        order = [
            x for x in reversed(list(nx.topological_sort(first_state._nx)))
            if isinstance(x, nodes.AccessNode)
        ]
        for node in second_state.nodes():
            first_state.add_node(node)
        for src, src_conn, dst, dst_conn, edge_data in second_state.edges():
            first_state.add_edge(src, src_conn, dst, dst_conn, edge_data)

        # Merge common (data) nodes
        for node in second_input:
            if first_state.in_degree(node) == 0:
                n = next((x for x in order if x.label == node.label), None)
                if n is not None:
                    sdutil.change_edge_src(first_state, node, n)
                    first_state.remove_node(node)
                    n.access = dtypes.AccessType.ReadWrite

        # Redirect edges and remove second state
        sdutil.change_edge_src(sdfg, second_state, first_state)
        sdfg.remove_node(second_state)
        if Config.get_bool("debugprint"):
            StateFusion._states_fused += 1
class DetectLoop(pattern_matching.Transformation):
    """ Detects a for-loop construct from an SDFG. """

    _loop_guard = sd.SDFGState()
    _loop_begin = sd.SDFGState()
    _exit_state = sd.SDFGState()

    @staticmethod
    def expressions():
        """ Returns the two loop patterns: a single-state body with a direct
            back-edge to the guard, and a multi-state body without one.
        """
        guard = DetectLoop._loop_guard
        begin = DetectLoop._loop_begin
        exit_state = DetectLoop._exit_state

        def loop_pattern(with_backedge):
            # guard -> begin, guard -> exit, and optionally begin -> guard
            pattern = sd.SDFG('_')
            pattern.add_nodes_from([guard, begin, exit_state])
            pattern.add_edge(guard, begin, edges.InterstateEdge())
            pattern.add_edge(guard, exit_state, edges.InterstateEdge())
            if with_backedge:
                pattern.add_edge(begin, guard, edges.InterstateEdge())
            return pattern

        # Case 1: Loop with one state
        single_state = loop_pattern(with_backedge=True)
        # Case 2: Loop with multiple states (no back-edge from state)
        multi_state = loop_pattern(with_backedge=False)
        return [single_state, multi_state]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Returns True if the matched guard/begin states form a for-loop:
            the guard has two incoming edges assigning the same single
            iteration variable, two outgoing edges with no assignments and
            mutually negated conditions, every body state reachable from
            ``begin`` is dominated by the guard, and some body state jumps
            back to the guard.
        """
        guard = graph.node(candidate[DetectLoop._loop_guard])
        begin = graph.node(candidate[DetectLoop._loop_begin])

        # A for-loop guard only has two incoming edges (init and increment)
        entries = graph.in_edges(guard)
        if len(entries) != 2:
            return False

        # A for-loop guard only has two outgoing edges (loop and exit-loop)
        exits = graph.out_edges(guard)
        if len(exits) != 2:
            return False

        # Both incoming edges must set exactly one variable, the same one
        first_assign = entries[0].data.assignments
        second_assign = entries[1].data.assignments
        if len(first_assign) != 1 or len(second_assign) != 1:
            return False
        itervar = next(iter(first_assign.keys()))
        if itervar not in second_assign:
            return False

        # Outgoing edges must not assign and must negate each other
        if any(len(e.data.assignments) > 0 for e in exits):
            return False
        if exits[0].data.condition_sympy() != sp.Not(
                exits[1].data.condition_sympy()):
            return False

        # All nodes inside the loop must be dominated by the loop guard
        idom = nx.dominance.immediate_dominators(sdfg.nx, sdfg.start_state)

        def dominated_by_guard(state):
            # Climb the immediate-dominator tree; the root is its own
            # immediate dominator, and reaching it without meeting the guard
            # means the state is not inside the loop.
            while state != idom[state]:
                if state == guard:
                    return True
                state = idom[state]
            return False

        body = nxutil.dfs_topological_sort(
            sdfg, sources=[begin], condition=lambda _, child: child != guard)

        has_backedge = False
        for state in body:
            if any(e.dst == guard for e in graph.out_edges(state)):
                has_backedge = True
            if not dominated_by_guard(state):
                return False

        return has_backedge

    @staticmethod
    def match_to_str(graph, candidate):
        matched = [
            graph.node(candidate[key])
            for key in (DetectLoop._loop_guard, DetectLoop._loop_begin,
                        DetectLoop._exit_state)
        ]
        # The iteration variable is the single symbol assigned on the
        # guard's first incoming edge
        itervar = list(graph.in_edges(matched[0])[0].data.assignments.keys())[0]
        return ' -> '.join(s.label for s in matched) + (
            ' (for loop over "%s")' % itervar)

    def apply(self, sdfg):
        """ Detection only; this transformation does not modify the SDFG. """
class SymbolAliasPromotion(transformation.Transformation):
    """ SymbolAliasPromotion moves inter-state assignments that create
        symbolic aliases to the previous inter-state edge according to the
        topological order. The purpose of this transformation is to
        iteratively move symbolic aliases together, so that true duplicates
        can be easily removed.
    """

    _first_state = sdfg.SDFGState()
    _second_state = sdfg.SDFGState()

    @staticmethod
    def expressions():
        return [
            sdutil.node_path_graph(SymbolAliasPromotion._first_state,
                                   SymbolAliasPromotion._second_state)
        ]

    @staticmethod
    def _promotable_assignments(sdfg, fstate, edge, in_edge):
        """ Returns the alias assignments of ``edge`` that may be moved onto
            ``in_edge``.

            Shared by ``can_be_applied`` and ``apply`` so that both always
            agree on which assignments are promoted.

            :param sdfg: The SDFG containing both edges.
            :param fstate: The state that ``in_edge`` enters and ``edge``
                           leaves.
            :param edge: Data of the edge leaving ``fstate``.
            :param in_edge: Data of the unique edge entering ``fstate``.
            :return: The mapping returned by ``_alias_assignments`` with every
                     non-promotable entry removed.
        """
        to_consider = _alias_assignments(sdfg, edge)
        if not to_consider:
            return to_consider

        # Symbols taking part in the edge's condition (computed once)
        condsyms = {str(s) for s in edge.condition_sympy().free_symbols}

        to_not_consider = set()
        for k, v in to_consider.items():
            # Remove symbols that are taking part in the edge's condition
            if k in condsyms:
                to_not_consider.add(k)
            # Remove symbols that are set in the in_edge
            # with a different assignment
            if k in in_edge.assignments and in_edge.assignments[k] != v:
                to_not_consider.add(k)
            # Remove symbols whose assignment (RHS) is a symbol
            # and is set in the in_edge.
            if v in sdfg.symbols and v in in_edge.assignments:
                to_not_consider.add(k)
            # Remove symbols whose assignment (RHS) is a scalar
            # and is accessed (read or written) in the first state.
            if v in sdfg.arrays and isinstance(sdfg.arrays[v], dt.Scalar):
                if any(
                        isinstance(n, nodes.AccessNode) and n.data == v
                        for n in fstate.nodes()):
                    to_not_consider.add(k)

        for k in to_not_consider:
            del to_consider[k]
        return to_consider

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Returns True if at least one alias assignment on the edge between
            the matched states can be promoted to the first state's unique
            incoming edge.
        """
        fstate = graph.nodes()[candidate[SymbolAliasPromotion._first_state]]
        sstate = graph.nodes()[candidate[SymbolAliasPromotion._second_state]]

        # For the topological order to be unambiguous:
        # 1. First state must have unique input edge.
        in_fedges = graph.in_edges(fstate)
        if len(in_fedges) != 1:
            return False
        in_edge = in_fedges[0].data

        # 2. There must be a unique edge from the first state to the second
        # one and no edge from the second state to the first one. With a
        # back edge, the first state's only predecessor would be the second
        # state, and the promoted assignment would be skipped on the first
        # transition into the second state.
        forward_edges = graph.edges_between(fstate, sstate)
        if len(forward_edges) != 1:
            return False
        if len(graph.edges_between(sstate, fstate)) != 0:
            return False
        edge = forward_edges[0].data

        # 3. There must be something left to promote after filtering.
        promotable = SymbolAliasPromotion._promotable_assignments(
            sdfg, fstate, edge, in_edge)
        return len(promotable) > 0

    @staticmethod
    def match_to_str(graph, candidate):
        state = graph.nodes()[candidate[SymbolAliasPromotion._second_state]]
        return state.label

    def apply(self, sdfg):
        """ Moves every promotable alias assignment from the edge between the
            matched states to the first state's incoming edge.
        """
        fstate = sdfg.nodes()[self.subgraph[SymbolAliasPromotion._first_state]]
        sstate = sdfg.nodes()[self.subgraph[
            SymbolAliasPromotion._second_state]]

        edge = sdfg.edges_between(fstate, sstate)[0].data
        in_edge = sdfg.in_edges(fstate)[0].data

        promoted = SymbolAliasPromotion._promotable_assignments(
            sdfg, fstate, edge, in_edge)
        for k, v in promoted.items():
            del edge.assignments[k]
            in_edge.assignments[k] = v
class FPGATransformState(pattern_matching.Transformation):
    """ Implements the FPGATransformState transformation. """

    _state = sd.SDFGState()

    @staticmethod
    def expressions():
        return [nxutil.node_path_graph(FPGATransformState._state)]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Returns True if every access node in the matched state uses
            default storage, and every map has at most 3 dimensions, is not
            MPI/GPU/FPGA scheduled, and is not nested inside a GPU or FPGA
            scheduled map.
        """
        state = graph.nodes()[candidate[FPGATransformState._state]]

        for node in state.nodes():

            if (isinstance(node, nodes.AccessNode)
                    and node.desc(sdfg).storage != types.StorageType.Default):
                return False

            if not isinstance(node, nodes.MapEntry):
                continue

            map_entry = node
            candidate_map = map_entry.map

            # No more than 3 dimensions
            if candidate_map.range.dims() > 3:
                return False

            # Map schedules that are disallowed to transform to FPGAs
            if (candidate_map.schedule == types.ScheduleType.MPI
                    or candidate_map.schedule == types.ScheduleType.GPU_Device
                    or candidate_map.schedule == types.ScheduleType.FPGA_Device
                    or candidate_map.schedule ==
                    types.ScheduleType.GPU_ThreadBlock):
                return False

            # Recursively check parent for FPGA schedules
            sdict = state.scope_dict()
            current_node = map_entry
            while current_node is not None:
                if (current_node.map.schedule == types.ScheduleType.GPU_Device
                        or current_node.map.schedule ==
                        types.ScheduleType.FPGA_Device
                        or current_node.map.schedule ==
                        types.ScheduleType.GPU_ThreadBlock):
                    return False
                current_node = sdict[current_node]

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        state = graph.nodes()[candidate[FPGATransformState._state]]

        return state.label

    @staticmethod
    def _get_fpga_array(sdfg, node, array, fpga_data):
        """ Returns the ``fpga_<name>`` transient (FPGA_Global storage) that
            mirrors ``node``'s array, creating it on first use.

            ``fpga_data`` is keyed by the access node's data name in both the
            input and the output pass, so a name that is both read and
            written maps to a single FPGA array.
        """
        if node.data in fpga_data:
            return fpga_data[node.data]
        fpga_array = sdfg.add_array(
            'fpga_' + node.data,
            array.dtype,
            array.shape,
            materialize_func=array.materialize_func,
            transient=True,
            storage=types.StorageType.FPGA_Global,
            allow_conflicts=array.allow_conflicts,
            access_order=array.access_order,
            strides=array.strides,
            offset=array.offset)
        fpga_data[node.data] = fpga_array
        return fpga_array

    def apply(self, sdfg):
        """ Moves the matched state's array inputs/outputs to FPGA copies.

            Source array nodes are copied to ``fpga_*`` arrays in a new
            ``pre_<label>`` state placed before the matched state. Sink array
            nodes are copied back from ``fpga_*`` arrays in a new
            ``post_<label>`` state placed after it. Memlets inside the state
            that refer to transferred data are renamed to the FPGA copies.
        """
        state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

        # Find source/sink (data) nodes
        input_nodes = nxutil.find_source_nodes(state)
        output_nodes = nxutil.find_sink_nodes(state)

        # Maps host data names to their FPGA-side array descriptors
        fpga_data = {}

        if input_nodes:

            pre_state = sd.SDFGState('pre_' + state.label, sdfg)

            for node in input_nodes:

                if (not isinstance(node, dace.graph.nodes.AccessNode)
                        or not isinstance(node.desc(sdfg), dace.data.Array)):
                    # Only transfer array nodes
                    # TODO: handle streams
                    continue

                array = node.desc(sdfg)
                fpga_array = FPGATransformState._get_fpga_array(
                    sdfg, node, array, fpga_data)
                fpga_node = type(node)(fpga_array)

                # Host -> FPGA copy in the pre-state
                pre_state.add_node(node)
                pre_state.add_node(fpga_node)
                full_range = subsets.Range([(0, s - 1, 1)
                                            for s in array.shape])
                mem = memlet.Memlet(array, full_range.num_elements(),
                                    full_range, 1)
                pre_state.add_edge(node, None, fpga_node, None, mem)

                # The state now reads from the FPGA copy
                state.add_node(fpga_node)
                nxutil.change_edge_src(state, node, fpga_node)
                state.remove_node(node)

            sdfg.add_node(pre_state)
            nxutil.change_edge_dest(sdfg, state, pre_state)
            sdfg.add_edge(pre_state, state, edges.InterstateEdge())

        if output_nodes:

            post_state = sd.SDFGState('post_' + state.label, sdfg)

            for node in output_nodes:

                if (not isinstance(node, dace.graph.nodes.AccessNode)
                        or not isinstance(node.desc(sdfg), dace.data.Array)):
                    # Only transfer array nodes
                    # TODO: handle streams
                    continue

                array = node.desc(sdfg)
                fpga_array = FPGATransformState._get_fpga_array(
                    sdfg, node, array, fpga_data)
                fpga_node = type(node)(fpga_array)

                # FPGA -> host copy in the post-state
                post_state.add_node(node)
                post_state.add_node(fpga_node)
                full_range = subsets.Range([(0, s - 1, 1)
                                            for s in array.shape])
                mem = memlet.Memlet(fpga_array, full_range.num_elements(),
                                    full_range, 1)
                post_state.add_edge(fpga_node, None, node, None, mem)

                # The state now writes to the FPGA copy
                state.add_node(fpga_node)
                nxutil.change_edge_dest(state, node, fpga_node)
                state.remove_node(node)

            sdfg.add_node(post_state)
            nxutil.change_edge_src(sdfg, state, post_state)
            sdfg.add_edge(state, post_state, edges.InterstateEdge())

        # Redirect memlets inside the state to the FPGA copies. The new name
        # is derived from each memlet's own data.
        for _, _, _, _, mem in state.edges():
            if mem.data is not None and mem.data in fpga_data:
                mem.data = 'fpga_' + mem.data

        fpga_update(state, 0)
class FPGATransformState(pattern_matching.Transformation):
    """ Implements the FPGATransformState transformation. """

    _state = sd.SDFGState()

    @staticmethod
    def expressions():
        return [sdutil.node_path_graph(FPGATransformState._state)]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Returns True if the matched state (including nested SDFGs) can be
            offloaded to an FPGA.

            Rejects code-to-code memlets, consume scopes, streams that violate
            FPGA code-generator limits, non-default-storage access nodes in
            the state, maps with more than 3 dimensions, and maps that are
            (or are nested in) GPU/FPGA scheduled maps or are MPI scheduled.
        """
        state = graph.nodes()[candidate[FPGATransformState._state]]

        # TODO: Support most of these cases
        # (Loop variables are named so that the ``graph`` parameter is not
        # shadowed.)
        for edge, _ in state.all_edges_recursive():
            # Code->Code memlets are disallowed (for now)
            if (isinstance(edge.src, nodes.CodeNode)
                    and isinstance(edge.dst, nodes.CodeNode)):
                return False

        for node, parent in state.all_nodes_recursive():

            # Consume scopes are currently unsupported
            if isinstance(node, (nodes.ConsumeEntry, nodes.ConsumeExit)):
                return False

            if not isinstance(node, nodes.AccessNode):
                continue

            # Streams have strict conditions due to code generator
            # limitations. Look the descriptor up in the SDFG that owns the
            # node (``parent.parent``): for nodes inside nested SDFGs this is
            # the nested SDFG, not the outer ``sdfg``.
            nodedesc = parent.parent.arrays[node.data]
            if not isinstance(nodedesc, data.Stream):
                continue

            sdict = parent.scope_dict()
            if nodedesc.storage in [
                    dtypes.StorageType.CPU_Heap, dtypes.StorageType.CPU_Pinned,
                    dtypes.StorageType.CPU_ThreadLocal
            ]:
                return False
            # Cannot allocate FIFO from CPU code
            if sdict[node] is None:
                return False
            # Arrays of streams cannot have symbolic size on FPGA
            if dace.symbolic.issymbolic(nodedesc.total_size,
                                        parent.parent.constants):
                return False
            # Streams cannot be unbounded on FPGA
            if nodedesc.buffer_size < 1:
                return False

        for node in state.nodes():

            if (isinstance(node, nodes.AccessNode)
                    and node.desc(sdfg).storage != dtypes.StorageType.Default):
                return False

            if not isinstance(node, nodes.MapEntry):
                continue

            map_entry = node
            candidate_map = map_entry.map

            # No more than 3 dimensions
            if candidate_map.range.dims() > 3:
                return False

            # Map schedules that are disallowed to transform to FPGAs
            if (candidate_map.schedule == dtypes.ScheduleType.MPI
                    or candidate_map.schedule == dtypes.ScheduleType.GPU_Device
                    or candidate_map.schedule ==
                    dtypes.ScheduleType.FPGA_Device
                    or candidate_map.schedule ==
                    dtypes.ScheduleType.GPU_ThreadBlock):
                return False

            # Recursively check parent for FPGA schedules
            sdict = state.scope_dict()
            current_node = map_entry
            while current_node is not None:
                if (current_node.map.schedule ==
                        dtypes.ScheduleType.GPU_Device
                        or current_node.map.schedule ==
                        dtypes.ScheduleType.FPGA_Device
                        or current_node.map.schedule ==
                        dtypes.ScheduleType.GPU_ThreadBlock):
                    return False
                current_node = sdict[current_node]

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        state = graph.nodes()[candidate[FPGATransformState._state]]

        return state.label

    @staticmethod
    def _add_fpga_array(sdfg, name, desc):
        """ Adds the transient ``fpga_<name>`` array (FPGA_Global storage)
            mirroring ``desc`` to ``sdfg`` and returns what ``add_array``
            returns.
        """
        return sdfg.add_array('fpga_' + name,
                              desc.shape,
                              desc.dtype,
                              materialize_func=desc.materialize_func,
                              transient=True,
                              storage=dtypes.StorageType.FPGA_Global,
                              allow_conflicts=desc.allow_conflicts,
                              strides=desc.strides,
                              offset=desc.offset)

    @staticmethod
    def _nested_veclen(nsdfg_node, connector, veclen):
        """ Returns the vector length of the first memlet attached to an
            access node named ``connector`` inside ``nsdfg_node``'s SDFG (the
            last such node found wins). Returns ``veclen`` unchanged if no
            such node with edges exists.
        """
        for inner_state in nsdfg_node.sdfg.states():
            for n in inner_state.nodes():
                if (isinstance(n, dace.sdfg.nodes.AccessNode)
                        and n.data == connector):
                    inner_edges = inner_state.all_edges(n)
                    if inner_edges:
                        # assuming all memlets have the same vector length
                        veclen = inner_edges[0].data.veclen
        return veclen

    def apply(self, sdfg):
        """ Moves the matched state's array inputs/outputs to FPGA copies.

            Source array nodes (plus outer nodes that receive WCR writes,
            traced through nested SDFGs) are copied to ``fpga_*`` arrays in a
            new ``pre_<label>`` state. Sink array nodes are copied back from
            ``fpga_*`` arrays in a new ``post_<label>`` state. Memlets inside
            the state that refer to transferred data are renamed accordingly.
        """
        state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

        # Find source/sink (data) nodes
        input_nodes = sdutil.find_source_nodes(state)
        output_nodes = sdutil.find_sink_nodes(state)

        # Maps host data names to their FPGA-side arrays
        fpga_data = {}

        # Input nodes may also be nodes with WCR memlets
        # We have to recur across nested SDFGs to find them
        wcr_input_nodes = set()
        parent_sdfg = {state: sdfg}  # Map states to their parent SDFG
        for node, graph in state.all_nodes_recursive():
            if isinstance(graph, dace.SDFG):
                parent_sdfg[node] = graph
            if not isinstance(node, dace.sdfg.nodes.AccessNode):
                continue
            if all(e.data.wcr is None for e in graph.all_edges(node)):
                continue
            # Find the accessed node in our scope
            trace = dace.sdfg.trace_nested_access(node, graph,
                                                  parent_sdfg[graph])
            outer_node = next(
                (node_trace for node_trace, state_trace, sdfg_trace in trace
                 if state_trace == state and sdfg_trace == sdfg), None)
            if outer_node is None:
                # This does not trace back to the current state, so
                # we don't care
                continue
            # Register each outer node only once, even if it has several WCR
            # edges; duplicates would create redundant copies in the pre-state
            if outer_node not in wcr_input_nodes:
                input_nodes.append(outer_node)
                wcr_input_nodes.add(outer_node)

        if input_nodes:
            # create pre_state
            pre_state = sd.SDFGState('pre_' + state.label, sdfg)

            for node in input_nodes:

                if not isinstance(node, dace.sdfg.nodes.AccessNode):
                    continue
                desc = node.desc(sdfg)
                if not isinstance(desc, dace.data.Array):
                    # TODO: handle streams
                    continue

                # WCR targets get their FPGA array when processed as outputs
                if (node.data not in fpga_data
                        and node not in wcr_input_nodes):
                    fpga_data[node.data] = FPGATransformState._add_fpga_array(
                        sdfg, node.data, desc)

                # Host -> FPGA copy in the pre-state
                pre_node = pre_state.add_read(node.data)
                pre_fpga_node = pre_state.add_write('fpga_' + node.data)
                full_range = subsets.Range([(0, s - 1, 1)
                                            for s in desc.shape])
                mem = memlet.Memlet(node.data, full_range.num_elements(),
                                    full_range, 1)
                pre_state.add_edge(pre_node, None, pre_fpga_node, None, mem)

                if node not in wcr_input_nodes:
                    fpga_node = state.add_read('fpga_' + node.data)
                    sdutil.change_edge_src(state, node, fpga_node)
                    state.remove_node(node)

            sdfg.add_node(pre_state)
            sdutil.change_edge_dest(sdfg, state, pre_state)
            sdfg.add_edge(pre_state, state, sd.InterstateEdge())

        if output_nodes:

            post_state = sd.SDFGState('post_' + state.label, sdfg)

            for node in output_nodes:

                if not isinstance(node, dace.sdfg.nodes.AccessNode):
                    continue
                desc = node.desc(sdfg)
                if not isinstance(desc, dace.data.Array):
                    # TODO: handle streams
                    continue

                if node.data not in fpga_data:
                    fpga_data[node.data] = FPGATransformState._add_fpga_array(
                        sdfg, node.data, desc)

                # FPGA -> host copy in the post-state
                post_node = post_state.add_write(node.data)
                post_fpga_node = post_state.add_read('fpga_' + node.data)
                full_range = subsets.Range([(0, s - 1, 1)
                                            for s in desc.shape])
                mem = memlet.Memlet('fpga_' + node.data,
                                    full_range.num_elements(), full_range, 1)
                post_state.add_edge(post_fpga_node, None, post_node, None,
                                    mem)

                fpga_node = state.add_write('fpga_' + node.data)
                sdutil.change_edge_dest(state, node, fpga_node)
                state.remove_node(node)

            sdfg.add_node(post_state)
            sdutil.change_edge_src(sdfg, state, post_state)
            sdfg.add_edge(state, post_state, sd.InterstateEdge())

        # Propagate vector info from nested SDFGs.
        # NOTE(review): ``veclen_`` is not reset between edges, so a vector
        # length found on one nested-SDFG edge also applies to every FPGA
        # memlet visited after it; confirm whether this is intended.
        veclen_ = 1
        for src, src_conn, dst, dst_conn, mem in state.edges():
            if isinstance(dst, dace.sdfg.nodes.NestedSDFG):
                # this edge is going to the nested SDFG
                veclen_ = FPGATransformState._nested_veclen(
                    dst, dst_conn, veclen_)
            if isinstance(src, dace.sdfg.nodes.NestedSDFG):
                # this edge is coming from the nested SDFG
                veclen_ = FPGATransformState._nested_veclen(
                    src, src_conn, veclen_)
            if mem.data is not None and mem.data in fpga_data:
                mem.data = 'fpga_' + mem.data
                mem.veclen = veclen_

        fpga_update(sdfg, state, 0)
class StateFusion(transformation.Transformation):
    """ Implements the state-fusion transformation.

        State-fusion takes two states that are connected through a single edge,
        and fuses them into one state. If strict, only applies if no memory
        access hazards are created.
    """

    _first_state = sdfg.SDFGState()
    _second_state = sdfg.SDFGState()

    @staticmethod
    def annotates_memlets():
        return False

    @staticmethod
    def expressions():
        # Match any two states connected by an edge (first -> second)
        return [
            sdutil.node_path_graph(StateFusion._first_state,
                                   StateFusion._second_state)
        ]

    @staticmethod
    def find_fused_components(first_cc_input, first_cc_output, second_cc_input,
                              second_cc_output) -> List[CCDesc]:
        """ Groups the connected components of both states into the
            components that would result from fusing the states.

            A first-state component and a second-state component end up in the
            same fused component when an output data name of the former is an
            input data name of the latter.

            :param first_cc_input: Per-component input access nodes of the
                                   first state.
            :param first_cc_output: Per-component output access nodes of the
                                    first state (same indexing as inputs).
            :param second_cc_input: Per-component input access nodes of the
                                    second state.
            :param second_cc_output: Per-component output access nodes of the
                                     second state (same indexing as inputs).
            :return: A list of ``CCDesc`` objects holding the data *names* of
                     first inputs, first outputs, second inputs and second
                     outputs of each fused component.
        """
        # Make a bipartite graph out of the first and second components.
        # Graph nodes are (state index, component index) pairs.
        g = nx.DiGraph()
        g.add_nodes_from((0, i) for i in range(len(first_cc_output)))
        g.add_nodes_from((1, i) for i in range(len(second_cc_output)))
        # Find matching nodes in second state
        for i, cc1 in enumerate(first_cc_output):
            outnames1 = {n.data for n in cc1}
            for j, cc2 in enumerate(second_cc_input):
                inpnames2 = {n.data for n in cc2}
                if outnames1 & inpnames2:
                    g.add_edge((0, i), (1, j))

        # Construct result out of connected components of the bipartite graph
        result = []
        for cc in nx.weakly_connected_components(g):
            input1, output1, input2, output2 = set(), set(), set(), set()
            for gind, cind in cc:
                if gind == 0:
                    input1 |= {n.data for n in first_cc_input[cind]}
                    output1 |= {n.data for n in first_cc_output[cind]}
                else:
                    input2 |= {n.data for n in second_cc_input[cind]}
                    output2 |= {n.data for n in second_cc_output[cind]}
            result.append(CCDesc(input1, output1, input2, output2))

        return result

    @staticmethod
    def memlets_intersect(graph_a: SDFGState, group_a: List[nodes.AccessNode],
                          inputs_a: bool, graph_b: SDFGState,
                          group_b: List[nodes.AccessNode],
                          inputs_b: bool) -> bool:
        """ Performs an all-pairs check for subset intersection on two
            groups of nodes. If group intersects or result is indeterminate,
            returns True as a precaution.

            :param graph_a: The graph in which the first set of nodes reside.
            :param group_a: The first set of nodes to check.
            :param inputs_a: If True, checks inputs of the first group.
            :param graph_b: The graph in which the second set of nodes reside.
            :param group_b: The second set of nodes to check.
            :param inputs_b: If True, checks inputs of the second group.
            :return: True if subsets intersect or result is indeterminate.
        """
        # Set traversal functions: an "input" node is read through its
        # outgoing edges (source side), an "output" node is written through
        # its incoming edges (destination side). Fall back to the other
        # subset when one side is not set.
        src_subset = lambda e: (e.data.src_subset if e.data.src_subset is
                                not None else e.data.dst_subset)
        dst_subset = lambda e: (e.data.dst_subset if e.data.dst_subset is
                                not None else e.data.src_subset)
        if inputs_a:
            edges_a = [e for n in group_a for e in graph_a.out_edges(n)]
            subset_a = src_subset
        else:
            edges_a = [e for n in group_a for e in graph_a.in_edges(n)]
            subset_a = dst_subset
        if inputs_b:
            edges_b = [e for n in group_b for e in graph_b.out_edges(n)]
            subset_b = src_subset
        else:
            edges_b = [e for n in group_b for e in graph_b.in_edges(n)]
            subset_b = dst_subset

        # Simple all-pairs check; None (indeterminate) counts as a conflict
        for ea in edges_a:
            for eb in edges_b:
                result = subsets.intersects(subset_a(ea), subset_b(eb))
                if result is True or result is None:
                    return True
        return False

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        first_state = graph.nodes()[candidate[StateFusion._first_state]]
        second_state = graph.nodes()[candidate[StateFusion._second_state]]

        out_edges = graph.out_edges(first_state)
        in_edges = graph.in_edges(first_state)

        # First state must have only one output edge (with dst the second
        # state).
        if len(out_edges) != 1:
            return False
        # The interstate edge must not have a condition.
        if not out_edges[0].data.is_unconditional():
            return False
        # The interstate edge may have assignments, as long as there are input
        # edges to the first state that can absorb them.
        if out_edges[0].data.assignments:
            if not in_edges:
                return False
            # Fail if symbol is set before the state to fuse
            new_assignments = set(out_edges[0].data.assignments.keys())
            if any((new_assignments & set(e.data.assignments.keys()))
                   for e in in_edges):
                return False
            # Fail if symbol is used in the dataflow of that state, since
            # apply() moves the assignment to *before* the first state
            if new_assignments & first_state.free_symbols:
                return False

        # There can be no state that have output edges pointing to both the
        # first and the second state. Such a case will produce a multi-graph.
        for src, _, _ in in_edges:
            for _, dst, _ in graph.out_edges(src):
                if dst == second_state:
                    return False

        if strict:
            # If second state has other input edges, there might be issues
            # Exceptions are when none of the states contain dataflow, unless
            # the first state is an initial state (in which case the new initial
            # state would be ambiguous).
            first_in_edges = graph.in_edges(first_state)
            second_in_edges = graph.in_edges(second_state)
            if ((not second_state.is_empty() or not first_state.is_empty()
                 or len(first_in_edges) == 0) and len(second_in_edges) != 1):
                return False

            # Get connected components.
            first_cc = list(nx.weakly_connected_components(first_state._nx))
            second_cc = list(nx.weakly_connected_components(
                second_state._nx))

            # Find source/sink (data) nodes. Note that "output" here means
            # every access node that is not a source node, not only sinks.
            first_input = {
                node
                for node in sdutil.find_source_nodes(first_state)
                if isinstance(node, nodes.AccessNode)
            }
            first_output = {
                node
                for node in first_state.nodes()
                if isinstance(node, nodes.AccessNode)
                and node not in first_input
            }
            second_input = {
                node
                for node in sdutil.find_source_nodes(second_state)
                if isinstance(node, nodes.AccessNode)
            }
            second_output = {
                node
                for node in second_state.nodes()
                if isinstance(node, nodes.AccessNode)
                and node not in second_input
            }

            # Find source/sink (data) nodes by connected component
            first_cc_input = [cc.intersection(first_input) for cc in first_cc]
            first_cc_output = [
                cc.intersection(first_output) for cc in first_cc
            ]
            second_cc_input = [
                cc.intersection(second_input) for cc in second_cc
            ]
            second_cc_output = [
                cc.intersection(second_output) for cc in second_cc
            ]

            # Apply transformation in case all paths to the second state's
            # nodes go through the same access node, which implies sequential
            # behavior in SDFG semantics.
            first_output_names = {node.data for node in first_output}
            second_input_names = {node.data for node in second_input}

            # If any second input appears more than once, fail
            if len(second_input) > len(second_input_names):
                return False

            # If any first output that is an input to the second state
            # appears in more than one CC, fail
            matches = first_output_names & second_input_names
            for match in matches:
                cc_appearances = sum(
                    1 for cc in first_cc_output
                    if any(n.data == match for n in cc))
                if cc_appearances > 1:
                    return False

            # Recreate fused connected component correspondences, and then
            # check for hazards
            resulting_ccs: List[CCDesc] = StateFusion.find_fused_components(
                first_cc_input, first_cc_output, second_cc_input,
                second_cc_output)

            # Check for data races
            for fused_cc in resulting_ccs:
                # Write-Write hazard - data is output of both first and second
                # states, without a read in between
                write_write_candidates = (
                    (fused_cc.first_outputs & fused_cc.second_outputs) -
                    fused_cc.second_inputs)
                if len(write_write_candidates) > 0:
                    # If we have potential candidates, check if there is a
                    # path from the first write to the second write (in that
                    # case, there is no hazard):
                    # Find the leaf (topological) instances of the matches
                    order = [
                        x for x in reversed(
                            list(nx.topological_sort(first_state._nx)))
                        if isinstance(x, nodes.AccessNode)
                        and x.data in fused_cc.first_outputs
                    ]
                    # Those nodes will be the connection points upon fusion
                    match_nodes = {
                        next(n for n in order if n.data == match)
                        for match in (fused_cc.first_outputs
                                      & fused_cc.second_inputs)
                    }
                else:
                    match_nodes = set()

                for cand in write_write_candidates:
                    nodes_first = [n for n in first_output if n.data == cand]
                    nodes_second = [
                        n for n in second_output if n.data == cand
                    ]

                    # If there is a path for the candidate that goes through
                    # the match nodes in both states, there is no conflict
                    fail = False
                    path_found = False
                    for match in match_nodes:
                        for node in nodes_first:
                            path_to = nx.has_path(first_state._nx, node,
                                                  match)
                            if not path_to:
                                continue
                            path_found = True
                            node2 = next(n for n in second_input
                                         if n.data == match.data)
                            if not all(
                                    nx.has_path(second_state._nx, node2, n)
                                    for n in nodes_second):
                                fail = True
                                break
                        if fail or path_found:
                            break

                    # Check for intersection (if None, fusion is ok)
                    if fail or not path_found:
                        if StateFusion.memlets_intersect(
                                first_state, nodes_first, False, second_state,
                                nodes_second, False):
                            return False
                # End of write-write hazard check

                first_inout = fused_cc.first_inputs | fused_cc.first_outputs
                for other_cc in resulting_ccs:
                    if other_cc is fused_cc:
                        continue
                    # If an input/output of a connected component in the first
                    # state is an output of another connected component in the
                    # second state, we have a potential data race (Read-Write
                    # or Write-Write)
                    for d in first_inout:
                        if d in other_cc.second_outputs:
                            # Check for intersection (if None, fusion is ok)
                            nodes_second = [
                                n for n in second_output if n.data == d
                            ]
                            # Read-Write race
                            if d in fused_cc.first_inputs:
                                nodes_first = [
                                    n for n in first_input if n.data == d
                                ]
                                if StateFusion.memlets_intersect(
                                        first_state, nodes_first, True,
                                        second_state, nodes_second, False):
                                    return False
                            # Write-Write race
                            if d in fused_cc.first_outputs:
                                nodes_first = [
                                    n for n in first_output if n.data == d
                                ]
                                if StateFusion.memlets_intersect(
                                        first_state, nodes_first, False,
                                        second_state, nodes_second, False):
                                    return False
            # End of data race check

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        first_state = graph.nodes()[candidate[StateFusion._first_state]]
        second_state = graph.nodes()[candidate[StateFusion._second_state]]

        return " -> ".join(state.label for state in [first_state, second_state])

    def apply(self, sdfg):
        """ Fuses the matched second state into the first state.

            Interstate assignments on the removed edge are moved to every
            incoming edge of the first state. Source access nodes of the
            second state are merged with the last (in topological order)
            same-label access node of the first state.
        """
        first_state = sdfg.nodes()[self.subgraph[StateFusion._first_state]]
        second_state = sdfg.nodes()[self.subgraph[StateFusion._second_state]]

        # Remove interstate edge(s)
        edges = sdfg.edges_between(first_state, second_state)
        for edge in edges:
            if edge.data.assignments:
                for _, _, other_data in sdfg.in_edges(first_state):
                    other_data.assignments.update(edge.data.assignments)
            sdfg.remove_edge(edge)

        # Special case 1: first state is empty
        if first_state.is_empty():
            sdutil.change_edge_dest(sdfg, first_state, second_state)
            sdfg.remove_node(first_state)
            return

        # Special case 2: second state is empty
        if second_state.is_empty():
            sdutil.change_edge_src(sdfg, second_state, first_state)
            sdutil.change_edge_dest(sdfg, second_state, first_state)
            sdfg.remove_node(second_state)
            return

        # Normal case: both states are not empty

        # Source (data) nodes of the second state are the merge candidates
        second_input = [
            node for node in sdutil.find_source_nodes(second_state)
            if isinstance(node, nodes.AccessNode)
        ]

        # Merge second state to first state
        # First keep a backup of the topological sorted order of the nodes,
        # reversed so that the *last* access to each data is found first
        order = [
            x for x in reversed(list(nx.topological_sort(first_state._nx)))
            if isinstance(x, nodes.AccessNode)
        ]
        for node in second_state.nodes():
            first_state.add_node(node)
        for src, src_conn, dst, dst_conn, data in second_state.edges():
            first_state.add_edge(src, src_conn, dst, dst_conn, data)

        # Merge common (data) nodes
        for node in second_input:
            if first_state.in_degree(node) == 0:
                n = next((x for x in order if x.label == node.label), None)
                if n is not None:
                    sdutil.change_edge_src(first_state, node, n)
                    first_state.remove_node(node)
                    n.access = dtypes.AccessType.ReadWrite

        # Redirect edges and remove second state
        sdutil.change_edge_src(sdfg, second_state, first_state)
        sdfg.remove_node(second_state)
class FPGATransformState(transformation.Transformation):
    """ Implements the FPGATransformState transformation.

        Offloads a whole state to the FPGA: externally-visible array inputs
        are copied into ``fpga_<name>`` arrays (storage ``FPGA_Global``) in a
        new ``pre_<label>`` state, outputs are copied back in a new
        ``post_<label>`` state, and memlets inside the state are renamed to
        use the FPGA copies.
    """

    _state = sd.SDFGState()

    @staticmethod
    def expressions():
        return [sdutil.node_path_graph(FPGATransformState._state)]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, permissive=False):
        state = graph.nodes()[candidate[FPGATransformState._state]]

        for node, parent_graph in state.all_nodes_recursive():
            # Consume scopes are currently unsupported
            if isinstance(node, (nodes.ConsumeEntry, nodes.ConsumeExit)):
                return False

            # Streams have strict conditions due to code generator limitations
            if (isinstance(node, nodes.AccessNode) and isinstance(
                    parent_graph.parent.arrays[node.data], data.Stream)):
                nodedesc = parent_graph.parent.arrays[node.data]
                sdict = parent_graph.scope_dict()
                if nodedesc.storage in (
                        dtypes.StorageType.CPU_Heap,
                        dtypes.StorageType.CPU_Pinned,
                        dtypes.StorageType.CPU_ThreadLocal,
                ):
                    return False

                # Cannot allocate FIFO from CPU code
                if sdict[node] is None:
                    return False

                # Arrays of streams cannot have symbolic size on FPGA
                if dace.symbolic.issymbolic(nodedesc.total_size,
                                            parent_graph.parent.constants):
                    return False

                # Streams cannot be unbounded on FPGA
                if nodedesc.buffer_size < 1:
                    return False

        # Map schedules that are disallowed to transform to FPGAs
        disallowed_schedules = (
            dtypes.ScheduleType.MPI,
            dtypes.ScheduleType.GPU_Device,
            dtypes.ScheduleType.FPGA_Device,
            dtypes.ScheduleType.GPU_ThreadBlock,
        )
        # Schedules that may not appear in any enclosing scope
        disallowed_parent_schedules = (
            dtypes.ScheduleType.GPU_Device,
            dtypes.ScheduleType.FPGA_Device,
            dtypes.ScheduleType.GPU_ThreadBlock,
        )
        # The state is not modified here, so its scope dictionary is
        # computed only once
        state_sdict = state.scope_dict()

        for node in state.nodes():

            if (isinstance(node, nodes.AccessNode)
                    and node.desc(sdfg).storage not in (
                        dtypes.StorageType.Default,
                        dtypes.StorageType.Register)):
                return False

            if not isinstance(node, nodes.MapEntry):
                continue

            map_entry = node
            candidate_map = map_entry.map

            if candidate_map.schedule in disallowed_schedules:
                return False

            # Recursively check parent for FPGA schedules
            current_node = map_entry
            while current_node is not None:
                if current_node.map.schedule in disallowed_parent_schedules:
                    return False
                current_node = state_sdict[current_node]

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        state = graph.nodes()[candidate[FPGATransformState._state]]

        return state.label

    def apply(self, sdfg):
        state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

        # Find source/sink (data) nodes that are relevant outside this FPGA
        # kernel
        shared_transients = set(sdfg.shared_transients())
        input_nodes = [
            n for n in sdutil.find_source_nodes(state)
            if isinstance(n, nodes.AccessNode) and
            (not sdfg.arrays[n.data].transient or n.data in shared_transients)
        ]
        output_nodes = [
            n for n in sdutil.find_sink_nodes(state)
            if isinstance(n, nodes.AccessNode) and
            (not sdfg.arrays[n.data].transient or n.data in shared_transients)
        ]

        # Maps original data name -> result of sdfg.add_array for its copy
        fpga_data = {}

        # Input nodes may also be nodes with WCR memlets
        # We have to recur across nested SDFGs to find them
        wcr_input_nodes = set()
        parent_sdfg = {state: sdfg}  # Map states to their parent SDFG
        for node, graph in state.all_nodes_recursive():
            if isinstance(graph, dace.SDFG):
                parent_sdfg[node] = graph
            if isinstance(node, dace.sdfg.nodes.AccessNode):
                for e in graph.in_edges(node):
                    if e.data.wcr is not None:
                        trace = dace.sdfg.trace_nested_access(
                            node, graph, parent_sdfg[graph])
                        for node_trace, _, state_trace, sdfg_trace in trace:
                            # Find the name of the accessed node in our scope
                            if state_trace == state and sdfg_trace == sdfg:
                                _, outer_node = node_trace
                                if outer_node is not None:
                                    break
                        else:
                            # This does not trace back to the current state, so
                            # we don't care
                            continue
                        # Several WCR edges may resolve to the same outer
                        # node; register it only once to avoid duplicate
                        # host->FPGA copies in the pre-state
                        if outer_node not in wcr_input_nodes:
                            input_nodes.append(outer_node)
                            wcr_input_nodes.add(outer_node)

        if input_nodes:
            # create pre_state
            pre_state = sd.SDFGState('pre_' + state.label, sdfg)

            for node in input_nodes:

                if not isinstance(node, dace.sdfg.nodes.AccessNode):
                    continue
                desc = node.desc(sdfg)
                if not isinstance(desc, dace.data.Array):
                    # TODO: handle streams
                    continue

                if node.data in fpga_data:
                    fpga_array = fpga_data[node.data]
                elif node not in wcr_input_nodes:
                    fpga_array = sdfg.add_array(
                        'fpga_' + node.data,
                        desc.shape,
                        desc.dtype,
                        transient=True,
                        storage=dtypes.StorageType.FPGA_Global,
                        allow_conflicts=desc.allow_conflicts,
                        strides=desc.strides,
                        offset=desc.offset)
                    # Move the location hints over to the FPGA copy
                    fpga_array[1].location = copy.copy(desc.location)
                    desc.location.clear()
                    fpga_data[node.data] = fpga_array
                # NOTE(review): for a WCR node not yet in fpga_data no
                # 'fpga_' array is created here; this relies on the output
                # loop below creating it (i.e. the node also being a sink).

                # Host -> FPGA copy of the full array
                pre_node = pre_state.add_read(node.data)
                pre_fpga_node = pre_state.add_write('fpga_' + node.data)
                mem = memlet.Memlet(data=node.data,
                                    subset=subsets.Range.from_array(desc))
                pre_state.add_edge(pre_node, None, pre_fpga_node, None, mem)

                # WCR nodes stay in place; they are rewired as outputs below
                if node not in wcr_input_nodes:
                    fpga_node = state.add_read('fpga_' + node.data)
                    sdutil.change_edge_src(state, node, fpga_node)
                    state.remove_node(node)

            sdfg.add_node(pre_state)
            sdutil.change_edge_dest(sdfg, state, pre_state)
            sdfg.add_edge(pre_state, state, sd.InterstateEdge())

        if output_nodes:

            post_state = sd.SDFGState('post_' + state.label, sdfg)

            for node in output_nodes:

                if not isinstance(node, dace.sdfg.nodes.AccessNode):
                    continue
                desc = node.desc(sdfg)
                if not isinstance(desc, dace.data.Array):
                    # TODO: handle streams
                    continue

                if node.data in fpga_data:
                    fpga_array = fpga_data[node.data]
                else:
                    fpga_array = sdfg.add_array(
                        'fpga_' + node.data,
                        desc.shape,
                        desc.dtype,
                        transient=True,
                        storage=dtypes.StorageType.FPGA_Global,
                        allow_conflicts=desc.allow_conflicts,
                        strides=desc.strides,
                        offset=desc.offset)
                    # Move the location hints over to the FPGA copy
                    fpga_array[1].location = copy.copy(desc.location)
                    desc.location.clear()
                    fpga_data[node.data] = fpga_array

                # FPGA -> host copy of the full array
                post_node = post_state.add_write(node.data)
                post_fpga_node = post_state.add_read('fpga_' + node.data)
                mem = memlet.Memlet(f"fpga_{node.data}", None,
                                    subsets.Range.from_array(desc))
                post_state.add_edge(post_fpga_node, None, post_node, None, mem)

                fpga_node = state.add_write('fpga_' + node.data)
                sdutil.change_edge_dest(state, node, fpga_node)
                state.remove_node(node)

            sdfg.add_node(post_state)
            sdutil.change_edge_src(sdfg, state, post_state)
            sdfg.add_edge(state, post_state, sd.InterstateEdge())

        # Rename memlets inside the kernel state to use the FPGA copies
        for _, _, _, _, mem in state.edges():
            if mem.data is not None and mem.data in fpga_data:
                mem.data = 'fpga_' + mem.data

        fpga_update(sdfg, state, 0)
class StateFusion(pattern_matching.Transformation):
    """ Implements the state-fusion transformation.

        State-fusion takes two states that are connected through a single edge,
        and fuses them into one state. If strict, only applies if no memory
        access hazards are created.

        Data nodes are matched between the two states by their ``label``.
    """

    # Counter of applied fusions, incremented only when the "debugprint"
    # config option is set (see apply / print_debuginfo)
    _states_fused = 0
    _first_state = sdfg.SDFGState()
    # NOTE(review): _edge is not referenced by expressions() or elsewhere in
    # this class
    _edge = edges.InterstateEdge()
    _second_state = sdfg.SDFGState()

    @staticmethod
    def annotates_memlets():
        return False

    @staticmethod
    def expressions():
        # Pattern: two states connected by an edge (first -> second)
        return [
            nxutil.node_path_graph(StateFusion._first_state,
                                   StateFusion._second_state)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether the matched pair of states can be fused.

            :param graph: The SDFG (state graph) containing the match.
            :param candidate: Mapping from pattern nodes to node indices.
            :param expr_index: Index of the matched expression (unused).
            :param sdfg: The SDFG (unused beyond ``graph``).
            :param strict: If True, additionally rejects fusions that may
                           create read/write hazards (label-based check).
            :return: True if the states can be fused.
        """
        first_state = graph.nodes()[candidate[StateFusion._first_state]]
        second_state = graph.nodes()[candidate[StateFusion._second_state]]

        out_edges = graph.out_edges(first_state)
        in_edges = graph.in_edges(first_state)

        # First state must have only one output edge (with dst the second
        # state).
        if len(out_edges) != 1:
            return False
        # The interstate edge must not have a condition.
        if out_edges[0].data.condition.as_string != "":
            return False
        # The interstate edge may have assignments, as long as there are input
        # edges to the first state, that can absorb them.
        # NOTE(review): unlike newer versions of this transformation, this
        # does not check whether the assigned symbols are already assigned
        # on the incoming edges or used inside the first state; moving the
        # assignment before the first state may change semantics then.
        if out_edges[0].data.assignments and not in_edges:
            return False

        # There can be no state that have output edges pointing to both the
        # first and the second state. Such a case will produce a multi-graph.
        for src, _, _ in in_edges:
            for _, dst, _ in graph.out_edges(src):
                if dst == second_state:
                    return False

        if strict:
            # If second state has other input edges, there might be issues
            # Exceptions are when none of the states contain dataflow, unless
            # the first state is an initial state (in which case the new initial
            # state would be ambiguous).
            first_in_edges = graph.in_edges(first_state)
            second_in_edges = graph.in_edges(second_state)
            if ((not second_state.is_empty() or not first_state.is_empty()
                 or len(first_in_edges) == 0) and len(second_in_edges) != 1):
                return False

            # Get connected components.
            first_cc = [
                cc_nodes
                for cc_nodes in nx.weakly_connected_components(first_state._nx)
            ]
            second_cc = [
                cc_nodes for cc_nodes in nx.weakly_connected_components(
                    second_state._nx)
            ]

            # Find source/sink (data) nodes. "Output" here means every
            # access node that is not a source node (not only sink nodes).
            first_input = {
                node
                for node in nxutil.find_source_nodes(first_state)
                if isinstance(node, nodes.AccessNode)
            }
            first_output = {
                node
                for node in first_state.nodes()
                if isinstance(node, nodes.AccessNode)
                and node not in first_input
            }
            second_input = {
                node
                for node in nxutil.find_source_nodes(second_state)
                if isinstance(node, nodes.AccessNode)
            }
            second_output = {
                node
                for node in second_state.nodes()
                if isinstance(node, nodes.AccessNode)
                and node not in second_input
            }

            # Find source/sink (data) nodes by connected component
            # (only first_cc_output is consulted below)
            first_cc_input = [cc.intersection(first_input) for cc in first_cc]
            first_cc_output = [
                cc.intersection(first_output) for cc in first_cc
            ]
            second_cc_input = [
                cc.intersection(second_input) for cc in second_cc
            ]
            second_cc_output = [
                cc.intersection(second_output) for cc in second_cc
            ]

            # Count first-state components that have NO output whose label
            # matches a second-state input (i.e. components that would not
            # be connected to the second state after fusion).
            check_strict = len(first_cc)
            for cc_output in first_cc_output:
                for node in cc_output:
                    if (next(
                        (x for x in second_input if x.label == node.label),
                            None) is not None):
                        check_strict -= 1
                        break

            if check_strict > 0:
                # Check strict conditions
                # RW dependency
                for node in first_input:
                    if (next(
                        (x for x in second_output if x.label == node.label),
                            None) is not None):
                        return False
                # WW dependency
                for node in first_output:
                    if (next(
                        (x for x in second_output if x.label == node.label),
                            None) is not None):
                        return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        first_state = graph.nodes()[candidate[StateFusion._first_state]]
        second_state = graph.nodes()[candidate[StateFusion._second_state]]

        return " -> ".join(state.label for state in [first_state, second_state])

    def apply(self, sdfg):
        """ Fuses the matched second state into the first state.

            Assignments on the removed interstate edge are copied onto every
            incoming edge of the first state. Same-label access nodes are
            merged: first-state inputs with second-state inputs, and
            first-state sink outputs with second-state inputs.
        """
        first_state = sdfg.nodes()[self.subgraph[StateFusion._first_state]]
        second_state = sdfg.nodes()[self.subgraph[StateFusion._second_state]]

        # Remove interstate edge(s)
        # (the local name `edges` shadows the `edges` module in this method)
        edges = sdfg.edges_between(first_state, second_state)
        for edge in edges:
            if edge.data.assignments:
                for src, dst, other_data in sdfg.in_edges(first_state):
                    other_data.assignments.update(edge.data.assignments)
            sdfg.remove_edge(edge)

        # Special case 1: first state is empty
        if first_state.is_empty():
            nxutil.change_edge_dest(sdfg, first_state, second_state)
            sdfg.remove_node(first_state)
            return

        # Special case 2: second state is empty
        if second_state.is_empty():
            nxutil.change_edge_src(sdfg, second_state, first_state)
            nxutil.change_edge_dest(sdfg, second_state, first_state)
            sdfg.remove_node(second_state)
            return

        # Normal case: both states are not empty

        # Find source/sink (data) nodes
        first_input = [
            node for node in nxutil.find_source_nodes(first_state)
            if isinstance(node, nodes.AccessNode)
        ]
        first_output = [
            node for node in nxutil.find_sink_nodes(first_state)
            if isinstance(node, nodes.AccessNode)
        ]
        second_input = [
            node for node in nxutil.find_source_nodes(second_state)
            if isinstance(node, nodes.AccessNode)
        ]

        # first input = first input - first output
        first_input = [
            node for node in first_input
            if next((x for x in first_output
                     if x.label == node.label), None) is None
        ]

        # Merge second state to first state
        for node in second_state.nodes():
            first_state.add_node(node)
        for src, src_conn, dst, dst_conn, data in second_state.edges():
            first_state.add_edge(src, src_conn, dst, dst_conn, data)

        # Merge common (data) nodes
        # Inputs of both states: keep the first state's node, redirect the
        # second state's outgoing edges to it
        for node in first_input:
            try:
                old_node = next(x for x in second_input
                                if x.label == node.label)
            except StopIteration:
                continue
            nxutil.change_edge_src(first_state, old_node, node)
            first_state.remove_node(old_node)
            second_input.remove(old_node)
        # First-state outputs read by the second state: keep the second
        # state's node, redirect the first state's incoming edges to it
        for node in first_output:
            try:
                new_node = next(x for x in second_input
                                if x.label == node.label)
            except StopIteration:
                continue
            nxutil.change_edge_dest(first_state, node, new_node)
            first_state.remove_node(node)
            second_input.remove(new_node)

        # Redirect edges and remove second state
        nxutil.change_edge_src(sdfg, second_state, first_state)
        sdfg.remove_node(second_state)
        if Config.get_bool("debugprint"):
            StateFusion._states_fused += 1

    @staticmethod
    def print_debuginfo():
        # Reports how many fusions were counted (only counted with debugprint)
        print("Automatically fused {} states using StateFusion transform.".
              format(StateFusion._states_fused))
def apply(self, sdfg):
    """ Offloads the matched state to the FPGA.

        Array source nodes are copied into ``fpga_<name>`` arrays
        (storage ``FPGA_Global``) in a new ``pre_<label>`` state, array sink
        nodes are copied back from their ``fpga_<name>`` counterparts in a
        new ``post_<label>`` state, and memlets inside the state are renamed
        to refer to the FPGA copies.
    """
    state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

    # Find source/sink (data) nodes
    input_nodes = nxutil.find_source_nodes(state)
    output_nodes = nxutil.find_sink_nodes(state)

    # Maps original data name -> value returned by sdfg.add_array for the
    # FPGA copy. Always keyed by node.data so input and output handling
    # agree on the same key.
    fpga_data = {}

    if input_nodes:

        pre_state = sd.SDFGState('pre_' + state.label, sdfg)

        for node in input_nodes:
            if (not isinstance(node, dace.graph.nodes.AccessNode)
                    or not isinstance(node.desc(sdfg), dace.data.Array)):
                # Only transfer array nodes
                # TODO: handle streams
                continue

            array = node.desc(sdfg)
            fpga_name = 'fpga_' + node.data

            if node.data in fpga_data:
                fpga_array = fpga_data[node.data]
            else:
                fpga_array = sdfg.add_array(
                    fpga_name,
                    array.dtype,
                    array.shape,
                    materialize_func=array.materialize_func,
                    transient=True,
                    storage=types.StorageType.FPGA_Global,
                    allow_conflicts=array.allow_conflicts,
                    access_order=array.access_order,
                    strides=array.strides,
                    offset=array.offset)
                fpga_data[node.data] = fpga_array

            # Access nodes and memlets refer to data by *name*
            fpga_node = type(node)(fpga_name)

            # Host -> FPGA copy of the full array
            pre_state.add_node(node)
            pre_state.add_node(fpga_node)
            full_range = subsets.Range([(0, s - 1, 1) for s in array.shape])
            mem = memlet.Memlet(node.data, full_range.num_elements(),
                                full_range, 1)
            pre_state.add_edge(node, None, fpga_node, None, mem)

            state.add_node(fpga_node)
            nxutil.change_edge_src(state, node, fpga_node)
            state.remove_node(node)

        sdfg.add_node(pre_state)
        nxutil.change_edge_dest(sdfg, state, pre_state)
        sdfg.add_edge(pre_state, state, edges.InterstateEdge())

    if output_nodes:

        post_state = sd.SDFGState('post_' + state.label, sdfg)

        for node in output_nodes:
            if (not isinstance(node, dace.graph.nodes.AccessNode)
                    or not isinstance(node.desc(sdfg), dace.data.Array)):
                # Only transfer array nodes
                # TODO: handle streams
                continue

            array = node.desc(sdfg)
            fpga_name = 'fpga_' + node.data

            if node.data in fpga_data:
                fpga_array = fpga_data[node.data]
            else:
                fpga_array = sdfg.add_array(
                    fpga_name,
                    array.dtype,
                    array.shape,
                    materialize_func=array.materialize_func,
                    transient=True,
                    storage=types.StorageType.FPGA_Global,
                    allow_conflicts=array.allow_conflicts,
                    access_order=array.access_order,
                    strides=array.strides,
                    offset=array.offset)
                fpga_data[node.data] = fpga_array

            fpga_node = type(node)(fpga_name)

            # FPGA -> host copy of the full array
            post_state.add_node(node)
            post_state.add_node(fpga_node)
            full_range = subsets.Range([(0, s - 1, 1) for s in array.shape])
            mem = memlet.Memlet(fpga_name, full_range.num_elements(),
                                full_range, 1)
            post_state.add_edge(fpga_node, None, node, None, mem)

            state.add_node(fpga_node)
            nxutil.change_edge_dest(state, node, fpga_node)
            state.remove_node(node)

        sdfg.add_node(post_state)
        nxutil.change_edge_src(sdfg, state, post_state)
        sdfg.add_edge(state, post_state, edges.InterstateEdge())

    # Rename each memlet to the FPGA copy of the data it actually accesses
    # (previously used the stale loop variable `node`, which renamed every
    # matching memlet after the last output node).
    for _, _, _, _, mem in state.edges():
        if mem.data is not None and mem.data in fpga_data:
            mem.data = 'fpga_' + mem.data

    fpga_update(state, 0)
def apply(self, _, sdfg):
    """ Offloads ``self.state`` to the FPGA.

        Externally-visible array sources (non-transient or shared transient)
        and outer nodes written through WCR memlets inside nested SDFGs are
        copied into ``fpga_<name>`` arrays (storage ``FPGA_Global``) in a new
        ``pre_<label>`` state. Externally-visible array sinks are copied back
        in a new ``post_<label>`` state. Memlets inside the state are renamed
        to use the FPGA copies.
    """
    state = self.state

    # Find source/sink (data) nodes that are relevant outside this FPGA
    # kernel
    shared_transients = set(sdfg.shared_transients())
    input_nodes = [
        n for n in sdutil.find_source_nodes(state)
        if isinstance(n, nodes.AccessNode) and
        (not sdfg.arrays[n.data].transient or n.data in shared_transients)
    ]
    output_nodes = [
        n for n in sdutil.find_sink_nodes(state)
        if isinstance(n, nodes.AccessNode) and
        (not sdfg.arrays[n.data].transient or n.data in shared_transients)
    ]

    # Maps original data name -> result of sdfg.add_array for its copy
    fpga_data = {}

    # Input nodes may also be nodes with WCR memlets
    # We have to recur across nested SDFGs to find them
    wcr_input_nodes = set()
    parent_sdfg = {state: sdfg}  # Map states to their parent SDFG
    for node, graph in state.all_nodes_recursive():
        if isinstance(graph, dace.SDFG):
            parent_sdfg[node] = graph
        if isinstance(node, dace.sdfg.nodes.AccessNode):
            for e in graph.in_edges(node):
                if e.data.wcr is not None:
                    trace = dace.sdfg.trace_nested_access(
                        node, graph, parent_sdfg[graph])
                    for node_trace, _, state_trace, sdfg_trace in trace:
                        # Find the name of the accessed node in our scope
                        if state_trace == state and sdfg_trace == sdfg:
                            _, outer_node = node_trace
                            if outer_node is not None:
                                break
                    else:
                        # This does not trace back to the current state, so
                        # we don't care
                        continue
                    # Several WCR edges may resolve to the same outer node;
                    # register it only once to avoid duplicate host->FPGA
                    # copies in the pre-state
                    if outer_node not in wcr_input_nodes:
                        input_nodes.append(outer_node)
                        wcr_input_nodes.add(outer_node)

    if input_nodes:
        # create pre_state
        pre_state = sd.SDFGState('pre_' + state.label, sdfg)

        for node in input_nodes:

            if not isinstance(node, dace.sdfg.nodes.AccessNode):
                continue
            desc = node.desc(sdfg)
            if not isinstance(desc, dace.data.Array):
                # TODO: handle streams
                continue

            if node.data in fpga_data:
                fpga_array = fpga_data[node.data]
            elif node not in wcr_input_nodes:
                fpga_array = sdfg.add_array(
                    'fpga_' + node.data,
                    desc.shape,
                    desc.dtype,
                    transient=True,
                    storage=dtypes.StorageType.FPGA_Global,
                    allow_conflicts=desc.allow_conflicts,
                    strides=desc.strides,
                    offset=desc.offset)
                # Move the location hints over to the FPGA copy
                fpga_array[1].location = copy.copy(desc.location)
                desc.location.clear()
                fpga_data[node.data] = fpga_array
            # NOTE(review): for a WCR node not yet in fpga_data no 'fpga_'
            # array is created here; this relies on the output loop below
            # creating it (i.e. the node also being a sink).

            # Host -> FPGA copy of the full array
            pre_node = pre_state.add_read(node.data)
            pre_fpga_node = pre_state.add_write('fpga_' + node.data)
            mem = memlet.Memlet(data=node.data,
                                subset=subsets.Range.from_array(desc))
            pre_state.add_edge(pre_node, None, pre_fpga_node, None, mem)

            # WCR nodes stay in place; they are rewired as outputs below
            if node not in wcr_input_nodes:
                fpga_node = state.add_read('fpga_' + node.data)
                sdutil.change_edge_src(state, node, fpga_node)
                state.remove_node(node)

        sdfg.add_node(pre_state)
        sdutil.change_edge_dest(sdfg, state, pre_state)
        sdfg.add_edge(pre_state, state, sd.InterstateEdge())

    if output_nodes:

        post_state = sd.SDFGState('post_' + state.label, sdfg)

        for node in output_nodes:

            if not isinstance(node, dace.sdfg.nodes.AccessNode):
                continue
            desc = node.desc(sdfg)
            if not isinstance(desc, dace.data.Array):
                # TODO: handle streams
                continue

            if node.data in fpga_data:
                fpga_array = fpga_data[node.data]
            else:
                fpga_array = sdfg.add_array(
                    'fpga_' + node.data,
                    desc.shape,
                    desc.dtype,
                    transient=True,
                    storage=dtypes.StorageType.FPGA_Global,
                    allow_conflicts=desc.allow_conflicts,
                    strides=desc.strides,
                    offset=desc.offset)
                # Move the location hints over to the FPGA copy
                fpga_array[1].location = copy.copy(desc.location)
                desc.location.clear()
                fpga_data[node.data] = fpga_array

            # FPGA -> host copy of the full array
            post_node = post_state.add_write(node.data)
            post_fpga_node = post_state.add_read('fpga_' + node.data)
            mem = memlet.Memlet(f"fpga_{node.data}", None,
                                subsets.Range.from_array(desc))
            post_state.add_edge(post_fpga_node, None, post_node, None, mem)

            fpga_node = state.add_write('fpga_' + node.data)
            sdutil.change_edge_dest(state, node, fpga_node)
            state.remove_node(node)

        sdfg.add_node(post_state)
        sdutil.change_edge_src(sdfg, state, post_state)
        sdfg.add_edge(state, post_state, sd.InterstateEdge())

    # Rename memlets inside the kernel state to use the FPGA copies
    for _, _, _, _, mem in state.edges():
        if mem.data is not None and mem.data in fpga_data:
            mem.data = 'fpga_' + mem.data

    fpga_update(sdfg, state, 0)
class DoubleBuffering(pattern_matching.Transformation):
    """ Implements the double buffering pattern, which pipelines reading
        and processing data by creating a second copy of the memory. """

    _begin = sd.SDFGState()
    _guard = sd.SDFGState()
    _body = sd.SDFGState()
    _end = sd.SDFGState()

    @staticmethod
    def expressions():
        # begin -> guard <-> body, guard -> end
        for_loop_graph = dace.graph.graph.OrderedDiGraph()
        for_loop_graph.add_nodes_from([
            DoubleBuffering._begin, DoubleBuffering._guard,
            DoubleBuffering._body, DoubleBuffering._end
        ])
        for_loop_graph.add_edge(DoubleBuffering._begin,
                                DoubleBuffering._guard, None)
        for_loop_graph.add_edge(DoubleBuffering._guard,
                                DoubleBuffering._body, None)
        for_loop_graph.add_edge(DoubleBuffering._body,
                                DoubleBuffering._guard, None)
        for_loop_graph.add_edge(DoubleBuffering._guard,
                                DoubleBuffering._end, None)
        return [for_loop_graph]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        begin = graph.nodes()[candidate[DoubleBuffering._begin]]
        guard = graph.nodes()[candidate[DoubleBuffering._guard]]
        body = graph.nodes()[candidate[DoubleBuffering._body]]
        end = graph.nodes()[candidate[DoubleBuffering._end]]

        # Only the loop body may contain anything
        if not begin.is_empty():
            return False
        if not guard.is_empty():
            return False
        if not end.is_empty():
            return False
        if body.is_empty():
            return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        begin = graph.nodes()[candidate[DoubleBuffering._begin]]
        guard = graph.nodes()[candidate[DoubleBuffering._guard]]
        body = graph.nodes()[candidate[DoubleBuffering._body]]
        end = graph.nodes()[candidate[DoubleBuffering._end]]

        return ', '.join(state.label for state in [begin, guard, body, end])

    def apply(self, sdfg):
        """Double-buffer the body's loop-variable-indexed input copies.

        :raises NotImplementedError: if the body -> guard edge does not assign
            exactly one variable, or a copied container is neither an array
            nor a scalar.
        """
        guard = sdfg.nodes()[self.subgraph[DoubleBuffering._guard]]
        body = sdfg.nodes()[self.subgraph[DoubleBuffering._body]]

        # The loop variable is the single symbol assigned on body -> guard
        loop_vars = []
        for _, dst, e in sdfg.out_edges(body):
            if dst is guard:
                for var in e.assignments.keys():
                    loop_vars.append(var)

        if len(loop_vars) != 1:
            raise NotImplementedError()

        loop_var = loop_vars[0]
        sym_var = dace.symbolic.pystr_to_symbolic(loop_var)

        def _with_leading_index(subset, idx):
            """Return ``subset`` as a Range with a leading ``(idx, idx, 1)``
            dimension, or None if the subset type is not supported."""
            if isinstance(subset, dace.subsets.Range):
                return dace.subsets.Range([(idx, idx, 1)] +
                                          list(subset.ranges))
            if isinstance(subset, dace.subsets.Indices):
                return dace.subsets.Range(
                    [(idx, idx, 1)] +
                    [(index, index, 1) for index in subset.indices])
            return None

        # Find source/sink (data) nodes
        input_nodes = nxutil.find_source_nodes(body)
        #output_nodes = nxutil.find_sink_nodes(body)

        copied_nodes = set()
        db_nodes = {}  # original destination node -> doubled node in guard
        for node in input_nodes:
            for _, _, dst, _, mem in body.out_edges(node):
                if (isinstance(dst, dace.graph.nodes.AccessNode)
                        and loop_var in mem.subset.free_symbols):
                    # Create new data and nodes in guard
                    if node not in copied_nodes:
                        guard.add_node(node)
                        copied_nodes.add(node)
                    if dst not in copied_nodes:
                        old_data = dst.desc(sdfg)
                        # NOTE(review): ``old_data`` is the descriptor from
                        # ``dst.desc(sdfg)``; confirm it exposes ``.data``
                        # (the access node's own name is ``dst.data``).
                        # ``add_array`` takes (name, shape, dtype, ...).
                        if isinstance(old_data, dace.data.Array):
                            new_shape = tuple([2] + list(old_data.shape))
                            sdfg.add_array(old_data.data,
                                           new_shape,
                                           old_data.dtype,
                                           transient=True)
                        elif isinstance(old_data, dace.data.Scalar):
                            sdfg.add_array(old_data.data, (2, ),
                                           old_data.dtype,
                                           transient=True)
                        else:
                            raise NotImplementedError()
                        guard_node = dace.graph.nodes.AccessNode(
                            old_data.data)
                        guard.add_node(guard_node)
                        copied_nodes.add(dst)
                        db_nodes[dst] = guard_node
                    # Always use the node created for *this* destination, not
                    # whichever one was created most recently.
                    new_node = db_nodes[dst]

                    # Create memlet in guard (fills slot 0 before the loop)
                    new_mem = copy.deepcopy(mem)
                    old_index = new_mem.other_subset
                    if isinstance(old_index, dace.subsets.Range):
                        new_ranges = [(0, 0, 1)] + old_index.ranges
                        new_mem.other_subset = dace.subsets.Range(new_ranges)
                    elif isinstance(old_index, dace.subsets.Indices):
                        new_indices = [0] + old_index.indices
                        new_mem.other_subset = dace.subsets.Indices(
                            new_indices)
                    guard.add_edge(node, None, new_node, None, new_mem)

                    # Create nodes, memlets in body
                    first_node = copy.deepcopy(new_node)
                    second_node = copy.deepcopy(new_node)
                    body.add_nodes_from([first_node, second_node])
                    dace.graph.nxutil.change_edge_dest(body, dst, first_node)
                    dace.graph.nxutil.change_edge_src(body, dst, second_node)
                    for src, _, dest, _, memm in body.edges():
                        if src is node and dest is first_node:
                            # Copy from ``node`` into slot (i + 1) % 2
                            new_subset = _with_leading_index(
                                memm.other_subset, (sym_var + 1) % 2)
                            if new_subset is not None:
                                memm.other_subset = new_subset
                        elif memm.data == dst.data:
                            # Other accesses to ``dst`` use slot i % 2
                            new_subset = _with_leading_index(
                                memm.subset, sym_var % 2)
                            if new_subset is not None:
                                memm.subset = new_subset
                            memm.data = first_node.data
                    body.remove_node(dst)
def apply(self, sdfg):
    """Wrap the matched state in host<->FPGA copy states.

    Source/sink array access nodes of the state, plus outer nodes targeted by
    WCR memlets inside nested SDFGs, are mirrored into ``'fpga_' + name``
    transient arrays with ``FPGA_Global`` storage.  A ``pre_<label>`` state
    (host -> device) is inserted before the state and a ``post_<label>``
    state (device -> host) after it.  The state's memlets are retargeted to
    the device copies, and their vector length is taken from the nested SDFG
    they connect to, if any.  ``fpga_update`` is invoked at the end.

    :param sdfg: The SDFG containing the matched state; modified in place.
    """
    state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

    # Find source/sink (data) nodes
    input_nodes = sdutil.find_source_nodes(state)
    output_nodes = sdutil.find_sink_nodes(state)

    # Original data name -> value returned by ``sdfg.add_array`` for its
    # device copy
    fpga_data = {}

    # Input nodes may also be nodes with WCR memlets
    # We have to recur across nested SDFGs to find them
    wcr_input_nodes = set()

    parent_sdfg = {state: sdfg}  # Map states to their parent SDFG
    for node, graph in state.all_nodes_recursive():
        if isinstance(graph, dace.SDFG):
            parent_sdfg[node] = graph
        if isinstance(node, dace.sdfg.nodes.AccessNode):
            for e in graph.all_edges(node):
                if e.data.wcr is not None:
                    trace = dace.sdfg.trace_nested_access(
                        node, graph, parent_sdfg[graph])
                    for node_trace, state_trace, sdfg_trace in trace:
                        # Find the name of the accessed node in our scope
                        if state_trace == state and sdfg_trace == sdfg:
                            outer_node = node_trace
                            break
                    else:
                        # This does not trace back to the current state, so
                        # we don't care
                        continue
                    # Several WCR edges can resolve to the same outer node;
                    # record it once so the pre-state has no duplicate copies.
                    if outer_node not in wcr_input_nodes:
                        input_nodes.append(outer_node)
                        wcr_input_nodes.add(outer_node)

    if input_nodes:
        # create pre_state
        pre_state = sd.SDFGState('pre_' + state.label, sdfg)

        for node in input_nodes:
            if not isinstance(node, dace.sdfg.nodes.AccessNode):
                continue
            desc = node.desc(sdfg)
            if not isinstance(desc, dace.data.Array):
                # TODO: handle streams
                continue

            if node.data in fpga_data:
                fpga_array = fpga_data[node.data]
            elif node not in wcr_input_nodes:
                fpga_array = sdfg.add_array(
                    'fpga_' + node.data,
                    desc.shape,
                    desc.dtype,
                    materialize_func=desc.materialize_func,
                    transient=True,
                    storage=dtypes.StorageType.FPGA_Global,
                    allow_conflicts=desc.allow_conflicts,
                    strides=desc.strides,
                    offset=desc.offset)
                fpga_data[node.data] = fpga_array

            # Host -> device copy of the whole array
            pre_node = pre_state.add_read(node.data)
            pre_fpga_node = pre_state.add_write('fpga_' + node.data)
            full_range = subsets.Range([(0, s - 1, 1) for s in desc.shape])
            mem = memlet.Memlet(node.data, full_range.num_elements(),
                                full_range, 1)
            pre_state.add_edge(pre_node, None, pre_fpga_node, None, mem)

            # WCR targets keep their node; plain inputs read the device copy
            if node not in wcr_input_nodes:
                fpga_node = state.add_read('fpga_' + node.data)
                sdutil.change_edge_src(state, node, fpga_node)
                state.remove_node(node)

        sdfg.add_node(pre_state)
        sdutil.change_edge_dest(sdfg, state, pre_state)
        sdfg.add_edge(pre_state, state, sd.InterstateEdge())

    if output_nodes:
        post_state = sd.SDFGState('post_' + state.label, sdfg)

        for node in output_nodes:
            if not isinstance(node, dace.sdfg.nodes.AccessNode):
                continue
            desc = node.desc(sdfg)
            if not isinstance(desc, dace.data.Array):
                # TODO: handle streams
                continue

            if node.data in fpga_data:
                fpga_array = fpga_data[node.data]
            else:
                fpga_array = sdfg.add_array(
                    'fpga_' + node.data,
                    desc.shape,
                    desc.dtype,
                    materialize_func=desc.materialize_func,
                    transient=True,
                    storage=dtypes.StorageType.FPGA_Global,
                    allow_conflicts=desc.allow_conflicts,
                    strides=desc.strides,
                    offset=desc.offset)
                fpga_data[node.data] = fpga_array

            # Device -> host copy of the whole array
            post_node = post_state.add_write(node.data)
            post_fpga_node = post_state.add_read('fpga_' + node.data)
            full_range = subsets.Range([(0, s - 1, 1) for s in desc.shape])
            mem = memlet.Memlet('fpga_' + node.data,
                                full_range.num_elements(), full_range, 1)
            post_state.add_edge(post_fpga_node, None, post_node, None, mem)

            # The state now writes into the device copy
            fpga_node = state.add_write('fpga_' + node.data)
            sdutil.change_edge_dest(state, node, fpga_node)
            state.remove_node(node)

        sdfg.add_node(post_state)
        sdutil.change_edge_src(sdfg, state, post_state)
        sdfg.add_edge(state, post_state, sd.InterstateEdge())

    def _nested_veclen(nested, connector, default):
        """Vector length of the first memlet on an access node named
        ``connector`` inside ``nested.sdfg`` (the last match across all
        states wins); ``default`` if no such node has any edge."""
        veclen = default
        for inner_state in nested.sdfg.states():
            for n in inner_state.nodes():
                if (isinstance(n, dace.sdfg.nodes.AccessNode)
                        and n.data == connector):
                    inner_edges = inner_state.all_edges(n)
                    if inner_edges:
                        # assuming all memlets have the same vector length
                        veclen = inner_edges[0].data.veclen
        return veclen

    # Retarget memlets to the device copies and propagate vector info from
    # a nested SDFG.  The vector length is computed per edge (starting from
    # the edge's own value) so a nested SDFG's value does not leak onto
    # unrelated edges visited later.
    for src, src_conn, dst, dst_conn, mem in state.edges():
        if mem.data is None or mem.data not in fpga_data:
            continue
        veclen_ = mem.veclen
        if isinstance(dst, dace.sdfg.nodes.NestedSDFG):
            # this edge is going to the nested SDFG
            veclen_ = _nested_veclen(dst, dst_conn, veclen_)
        if isinstance(src, dace.sdfg.nodes.NestedSDFG):
            # this edge is coming from the nested SDFG
            veclen_ = _nested_veclen(src, src_conn, veclen_)
        mem.data = 'fpga_' + mem.data
        mem.veclen = veclen_

    fpga_update(sdfg, state, 0)
class StateAssignElimination(transformation.Transformation):
    """ Removes the inter-state assignments found on the single edge that
        enters a state, substituting each assigned value directly into the
        contents of that state instead.
    """

    _end_state = sdfg.SDFGState()

    @staticmethod
    def expressions():
        return [sdutil.node_path_graph(StateAssignElimination._end_state)]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        target = graph.nodes()[candidate[StateAssignElimination._end_state]]
        outgoing = graph.out_edges(target)
        incoming = graph.in_edges(target)

        # Exactly one predecessor edge is required
        if len(incoming) != 1:
            return False
        entry_edge = incoming[0]

        candidates = _assignments_to_consider(sdfg, entry_edge)
        # Nothing to eliminate
        if len(candidates) == 0:
            return False

        # Nothing follows a sink state, so no later uses can exist
        if len(outgoing) == 0:
            return True

        # Otherwise, the assigned symbols must not appear on any other
        # inter-state edge, nor in any other state
        assigned = set(candidates.keys())
        if any(e.data.free_symbols & assigned for e in sdfg.edges()
               if e is not entry_edge):
            return False
        return not any(s.free_symbols & assigned for s in sdfg.nodes()
                       if s is not target)

    @staticmethod
    def match_to_str(graph, candidate):
        matched = candidate[StateAssignElimination._end_state]
        return graph.nodes()[matched].label

    def apply(self, sdfg):
        target = sdfg.nodes()[self.subgraph[StateAssignElimination._end_state]]
        entry_edge = sdfg.in_edges(target)[0]

        # Inter-state assignments that read another assigned value are
        # undefined behavior (e.g., {m: n, n: m}), so every assignment may
        # be substituted independently.
        removed = set()
        candidates = _assignments_to_consider(sdfg, entry_edge)
        for name, value in candidates.items():
            target.replace(name, value)
            removed.add(name)

        replacements = {}
        for name in removed:
            # Drop the assignment from the entry edge
            del entry_edge.data.assignments[name]

            # Symbols still referenced by some inter-state edge are kept
            if any(name in e.data.free_symbols for e in sdfg.edges()):
                continue

            if name in sdfg.symbols:
                sdfg.remove_symbol(name)
            if name in sdfg.free_symbols:
                replacements[name] = candidates[name]

        def _replace_all(mapping):
            for old, new in mapping.items():
                sdfg.replace(str(old), str(new))

        if replacements:
            symbolic.safe_replace(replacements, _replace_all)
class DetectLoop(transformation.Transformation):
    """ Detects a for-loop construct from an SDFG. """

    _loop_guard = sd.SDFGState()
    _loop_begin = sd.SDFGState()
    _exit_state = sd.SDFGState()

    @staticmethod
    def expressions():
        # Case 1: Loop with one state
        sdfg = sd.SDFG('_')
        sdfg.add_nodes_from([
            DetectLoop._loop_guard, DetectLoop._loop_begin,
            DetectLoop._exit_state
        ])
        sdfg.add_edge(DetectLoop._loop_guard, DetectLoop._loop_begin,
                      sd.InterstateEdge())
        sdfg.add_edge(DetectLoop._loop_guard, DetectLoop._exit_state,
                      sd.InterstateEdge())
        sdfg.add_edge(DetectLoop._loop_begin, DetectLoop._loop_guard,
                      sd.InterstateEdge())

        # Case 2: Loop with multiple states (no back-edge from state)
        msdfg = sd.SDFG('_')
        msdfg.add_nodes_from([
            DetectLoop._loop_guard, DetectLoop._loop_begin,
            DetectLoop._exit_state
        ])
        msdfg.add_edge(DetectLoop._loop_guard, DetectLoop._loop_begin,
                       sd.InterstateEdge())
        msdfg.add_edge(DetectLoop._loop_guard, DetectLoop._exit_state,
                       sd.InterstateEdge())

        return [sdfg, msdfg]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, permissive=False):
        guard = graph.node(candidate[DetectLoop._loop_guard])
        begin = graph.node(candidate[DetectLoop._loop_begin])

        # A for-loop guard needs at least two incoming edges (init and
        # increment)
        guard_inedges = graph.in_edges(guard)
        if len(guard_inedges) < 2:
            return False
        # A for-loop guard only has two outgoing edges (loop and exit-loop)
        guard_outedges = graph.out_edges(guard)
        if len(guard_outedges) != 2:
            return False

        # All incoming edges to the guard must set the same variable
        itvar = None
        for iedge in guard_inedges:
            if itvar is None:
                itvar = set(iedge.data.assignments.keys())
            else:
                itvar &= iedge.data.assignments.keys()
        if itvar is None:
            return False

        # Outgoing edges must be a negation of each other
        if guard_outedges[0].data.condition_sympy() != (sp.Not(
                guard_outedges[1].data.condition_sympy())):
            return False

        # All nodes inside loop must be dominated by loop guard
        dominators = nx.dominance.immediate_dominators(sdfg.nx,
                                                       sdfg.start_state)
        loop_nodes = sdutil.dfs_conditional(
            sdfg, sources=[begin], condition=lambda _, child: child != guard)
        backedge = None
        for node in loop_nodes:
            for e in graph.out_edges(node):
                if e.dst == guard:
                    backedge = e
                    break

            # Traverse the dominator tree upwards; reaching the guard means
            # the node is in the loop.  The guard test comes first so a
            # guard that is itself the root (the start state) is accepted.
            # Reaching the root without passing the guard -- or a node that
            # is unreachable from the start state -- fails the match.
            dom = node
            while dom != guard:
                idom = dominators.get(dom, dom)
                if idom == dom:
                    return False
                dom = idom

        if backedge is None:
            return False

        # The backedge must assign the iteration variable
        itvar &= backedge.data.assignments.keys()
        if len(itvar) != 1:
            # Either no consistent iteration variable found, or too many
            # consistent iteration variables found
            return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        guard = graph.node(candidate[DetectLoop._loop_guard])
        begin = graph.node(candidate[DetectLoop._loop_begin])
        sexit = graph.node(candidate[DetectLoop._exit_state])

        # The iteration variable is the one symbol every edge into the guard
        # assigns; otherwise fall back to the first edge's first assignment.
        in_edges = graph.in_edges(guard)
        common = None
        for e in in_edges:
            keys = set(e.data.assignments.keys())
            common = keys if common is None else common & keys
        if common is not None and len(common) == 1:
            ind = next(iter(common))
        else:
            ind = list(in_edges[0].data.assignments.keys())[0]

        return (' -> '.join(state.label for state in [guard, begin, sexit]) +
                ' (for loop over "%s")' % ind)

    def apply(self, sdfg):
        pass
def apply(self, sdfg):
    """Wrap the matched state in host<->FPGA copy states.

    Source/sink array access nodes of the state, plus access nodes in the
    grandparent state (``graph.parent.parent``) that match WCR edges found
    while recursing into nested SDFGs, are mirrored into
    ``'fpga_' + name`` transient arrays with ``FPGA_Global`` storage.  A
    ``pre_<label>`` state (host -> device) is inserted before the state and
    a ``post_<label>`` state (device -> host) after it.  The state's memlets
    are retargeted to the device copies, with their vector length taken
    from the nested SDFG they connect to, if any.  ``fpga_update`` is
    invoked at the end.

    :param sdfg: The SDFG containing the matched state; modified in place.
    """
    state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

    # Find source/sink (data) nodes
    input_nodes = nxutil.find_source_nodes(state)
    output_nodes = nxutil.find_sink_nodes(state)

    # Original data name -> value returned by ``sdfg.add_array`` for its
    # device copy
    fpga_data = {}

    # Input nodes may also be nodes with WCR memlets
    # We have to recur across nested SDFGs to find them
    wcr_input_nodes = set()

    for node, graph in state.all_nodes_recursive():
        if isinstance(node, dace.graph.nodes.AccessNode):
            for e in graph.all_edges(node):
                if e.data.wcr is not None:
                    # This is an output node with wcr
                    # find the target in the parent sdfg
                    # following the structure State->SDFG->State-> SDFG
                    # from the current_state we have to go two levels up
                    parent_state = graph.parent.parent
                    if parent_state is not None:
                        for parent_edges in parent_state.edges():
                            if parent_edges.src_conn == e.dst.data or (
                                    isinstance(parent_edges.dst,
                                               dace.graph.nodes.AccessNode)
                                    and e.dst.data == parent_edges.dst.data):
                                # This must be copied to device (once, even
                                # if several edges match it)
                                if parent_edges.dst not in wcr_input_nodes:
                                    input_nodes.append(parent_edges.dst)
                                    wcr_input_nodes.add(parent_edges.dst)

    if input_nodes:
        # create pre_state
        pre_state = sd.SDFGState('pre_' + state.label, sdfg)

        for node in input_nodes:
            if (not isinstance(node, dace.graph.nodes.AccessNode)
                    or not isinstance(node.desc(sdfg), dace.data.Array)):
                # Only transfer array nodes
                # TODO: handle streams
                continue

            array = node.desc(sdfg)

            if node.data in fpga_data:
                fpga_array = fpga_data[node.data]
            elif node not in wcr_input_nodes:
                fpga_array = sdfg.add_array(
                    'fpga_' + node.data,
                    array.shape,
                    array.dtype,
                    materialize_func=array.materialize_func,
                    transient=True,
                    storage=dtypes.StorageType.FPGA_Global,
                    allow_conflicts=array.allow_conflicts,
                    strides=array.strides,
                    offset=array.offset)
                fpga_data[node.data] = fpga_array

            # Host -> device copy of the whole array
            pre_node = pre_state.add_read(node.data)
            pre_fpga_node = pre_state.add_write('fpga_' + node.data)
            full_range = subsets.Range([(0, s - 1, 1) for s in array.shape])
            mem = memlet.Memlet(node.data, full_range.num_elements(),
                                full_range, 1)
            pre_state.add_edge(pre_node, None, pre_fpga_node, None, mem)

            # WCR targets keep their node; plain inputs read the device copy
            if node not in wcr_input_nodes:
                fpga_node = state.add_read('fpga_' + node.data)
                nxutil.change_edge_src(state, node, fpga_node)
                state.remove_node(node)

        sdfg.add_node(pre_state)
        nxutil.change_edge_dest(sdfg, state, pre_state)
        sdfg.add_edge(pre_state, state, edges.InterstateEdge())

    if output_nodes:
        post_state = sd.SDFGState('post_' + state.label, sdfg)

        for node in output_nodes:
            if (not isinstance(node, dace.graph.nodes.AccessNode)
                    or not isinstance(node.desc(sdfg), dace.data.Array)):
                # Only transfer array nodes
                # TODO: handle streams
                continue

            array = node.desc(sdfg)

            if node.data in fpga_data:
                fpga_array = fpga_data[node.data]
            else:
                fpga_array = sdfg.add_array(
                    'fpga_' + node.data,
                    array.shape,
                    array.dtype,
                    materialize_func=array.materialize_func,
                    transient=True,
                    storage=dtypes.StorageType.FPGA_Global,
                    allow_conflicts=array.allow_conflicts,
                    strides=array.strides,
                    offset=array.offset)
                fpga_data[node.data] = fpga_array

            # Device -> host copy of the whole array
            post_node = post_state.add_write(node.data)
            post_fpga_node = post_state.add_read('fpga_' + node.data)
            full_range = subsets.Range([(0, s - 1, 1) for s in array.shape])
            mem = memlet.Memlet('fpga_' + node.data,
                                full_range.num_elements(), full_range, 1)
            post_state.add_edge(post_fpga_node, None, post_node, None, mem)

            # The state now writes into the device copy
            fpga_node = state.add_write('fpga_' + node.data)
            nxutil.change_edge_dest(state, node, fpga_node)
            state.remove_node(node)

        sdfg.add_node(post_state)
        nxutil.change_edge_src(sdfg, state, post_state)
        sdfg.add_edge(state, post_state, edges.InterstateEdge())

    def _nested_veclen(nested, connector, default):
        """Vector length of the first memlet on an access node named
        ``connector`` inside ``nested.sdfg`` (the last match across all
        states wins); ``default`` if no such node has any edge."""
        veclen = default
        for inner_state in nested.sdfg.states():
            for n in inner_state.nodes():
                if (isinstance(n, dace.graph.nodes.AccessNode)
                        and n.data == connector):
                    inner_edges = inner_state.all_edges(n)
                    if inner_edges:
                        # assuming all memlets have the same vector length
                        veclen = inner_edges[0].data.veclen
        return veclen

    # Retarget memlets to the device copies and propagate vector info from
    # a nested SDFG.  The vector length is computed per edge (starting from
    # the edge's own value) so a nested SDFG's value does not leak onto
    # unrelated edges visited later.
    for src, src_conn, dst, dst_conn, mem in state.edges():
        if mem.data is None or mem.data not in fpga_data:
            continue
        veclen_ = mem.veclen
        if isinstance(dst, dace.graph.nodes.NestedSDFG):
            # this edge is going to the nested SDFG
            veclen_ = _nested_veclen(dst, dst_conn, veclen_)
        if isinstance(src, dace.graph.nodes.NestedSDFG):
            # this edge is coming from the nested SDFG
            veclen_ = _nested_veclen(src, src_conn, veclen_)
        mem.data = 'fpga_' + mem.data
        mem.veclen = veclen_

    fpga_update(sdfg, state, 0)