def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
    """Check whether the two matched states may be fused into one.

    :param graph: State graph that contains the candidate states.
    :param candidate: Mapping from the pattern nodes
                      (``StateFusion._first_state`` and
                      ``StateFusion._second_state``) to indices into
                      ``graph.nodes()``.
    :param expr_index: Not used by this check.
    :param sdfg: Not used by this check.
    :param strict: If True, additionally run the dataflow checks below
                   (extra incoming edges of the second state and
                   read-write / write-write conflicts).
    :return: True if the fusion may be applied, False otherwise.
    """
    first_state = graph.nodes()[candidate[StateFusion._first_state]]
    second_state = graph.nodes()[candidate[StateFusion._second_state]]

    out_edges = graph.out_edges(first_state)
    in_edges = graph.in_edges(first_state)

    # First state must have only one output edge (with dst the second
    # state).
    if len(out_edges) != 1:
        return False
    # The interstate edge must not have a condition.
    if not out_edges[0].data.is_unconditional():
        return False
    # The interstate edge may have assignments, as long as there are input
    # edges to the first state, that can absorb them.
    if out_edges[0].data.assignments and not in_edges:
        return False

    # There can be no state that have output edges pointing to both the
    # first and the second state. Such a case will produce a multi-graph.
    for src, _, _ in in_edges:
        for _, dst, _ in graph.out_edges(src):
            if dst == second_state:
                return False

    if strict:

        def shares_label(node, others):
            # True if any node in `others` carries the same label as `node`.
            return any(x.label == node.label for x in others)

        # If second state has other input edges, there might be issues
        # Exceptions are when none of the states contain dataflow, unless
        # the first state is an initial state (in which case the new initial
        # state would be ambiguous).
        second_in_edges = graph.in_edges(second_state)
        if ((not second_state.is_empty() or not first_state.is_empty()
             or len(in_edges) == 0) and len(second_in_edges) != 1):
            return False

        # Get connected components of the first state.
        first_cc = list(nx.weakly_connected_components(first_state._nx))

        # Find source/sink (data) nodes. "Output" here means every access
        # node that is not a source node of its state.
        first_input = {
            node
            for node in nxutil.find_source_nodes(first_state)
            if isinstance(node, nodes.AccessNode)
        }
        first_output = {
            node
            for node in first_state.nodes()
            if isinstance(node, nodes.AccessNode) and node not in first_input
        }
        second_input = {
            node
            for node in nxutil.find_source_nodes(second_state)
            if isinstance(node, nodes.AccessNode)
        }
        second_output = {
            node
            for node in second_state.nodes()
            if isinstance(node, nodes.AccessNode) and node not in second_input
        }

        # Non-source data nodes grouped by connected component
        first_cc_output = [cc.intersection(first_output) for cc in first_cc]

        # Sink nodes do not change while checking; compute them once.
        first_sinks = list(first_state.sink_nodes())

        # Apply transformation in case all paths to the second state's
        # nodes go through the same access node, which implies sequential
        # behavior in SDFG semantics.
        check_strict = len(first_cc)
        for cc_output in first_cc_output:
            out_nodes = [n for n in first_sinks if n in cc_output]
            # Branching exists, multiple paths may involve same access node
            # potentially causing data races
            if len(out_nodes) > 1:
                continue

            # Otherwise, check whether any of the component's data nodes
            # matches an input of the second state
            if any(shares_label(node, second_input) for node in cc_output):
                check_strict -= 1

        if check_strict > 0:
            # Check strict conditions
            # RW dependency
            if any(shares_label(node, second_output) for node in first_input):
                return False
            # WW dependency
            if any(
                    shares_label(node, second_output)
                    for node in first_output):
                return False

    return True
def apply(self, sdfg):
    """Move the contents of ``sdfg`` into a nested SDFG node.

    All states and interstate edges are added to a new SDFG named
    ``'nested_' + sdfg.label``. Non-transient source/sink access nodes are
    renamed with an ``_in`` / ``_out`` suffix, ``sdfg`` is emptied, and a
    single new state is created that holds one nested SDFG node connected
    to copies of the original access nodes.

    If ``self.promote_global_trans`` is set, transient access nodes for
    which ``scope_dict`` has no entry are also exposed as outputs and
    marked non-transient.

    :param sdfg: The SDFG to transform (modified in place).
    """
    # Copy SDFG to nested SDFG
    nested_sdfg = dace.SDFG('nested_' + sdfg.label)
    nested_sdfg.add_nodes_from(sdfg.nodes())
    for src, dst, data in sdfg.edges():
        nested_sdfg.add_edge(src, dst, data)

    input_data = set()
    input_nodes = {}  # connector name -> copy of the original access node
    output_data = set()
    output_nodes = {}  # connector name -> copy of the original access node
    for state in sdfg.nodes():
        for node in nxutil.find_source_nodes(state):
            if isinstance(
                    node, nodes.AccessNode) and not node.desc(sdfg).transient:
                if node.data not in input_data:
                    input_nodes.update({node.data + '_in': dc(node)})
                    new_data = dc(node.desc(sdfg))
                    input_data.add(node.data)
                    sdfg.arrays.update({node.data + '_in': new_data})
                node.data = node.data + '_in'
        for node in nxutil.find_sink_nodes(state):
            if isinstance(
                    node, nodes.AccessNode) and not node.desc(sdfg).transient:
                if node.data not in output_data:
                    output_nodes.update({node.data + '_out': dc(node)})
                    new_data = dc(node.desc(sdfg))
                    output_data.add(node.data)
                    sdfg.arrays.update({node.data + '_out': new_data})

                # WCR Fix
                if self.promote_global_trans:
                    for edge in state.in_edges(node):
                        if sd._memlet_path(state, edge)[0].data.wcr:
                            if node.data not in input_data:
                                input_nodes.update(
                                    {node.data + '_in': dc(node)})
                                new_data = dc(node.desc(sdfg))
                                sdfg.arrays.update(
                                    {node.data + '_in': new_data})
                                input_data.add(node.data + '_in')
                            break

                node.data = node.data + '_out'
        if self.promote_global_trans:
            scope_dict = state.scope_dict()
            for node in state.nodes():
                if (isinstance(node, nodes.AccessNode)
                        and node.desc(sdfg).transient
                        and not scope_dict[node]):
                    if node.data not in output_data:
                        output_nodes.update({node.data + '_out': dc(node)})
                        new_data = dc(node.desc(sdfg))
                        output_data.add(node.data + '_out')
                        sdfg.arrays.update({node.data + '_out': new_data})
                    node.data = node.data + '_out'
                    node.desc(sdfg).transient = False
        for edge in state.edges():
            _, _, _, _, mem = edge
            # Resolve the memlet path once; both endpoints come from it.
            path = sd._memlet_path(state, edge)
            src = path[0].src
            dst = path[-1].dst
            if isinstance(src, nodes.AccessNode) and src.data in input_data:
                mem.data = src.data
            if isinstance(src, nodes.AccessNode) and src.data in output_data:
                mem.data = src.data
            if isinstance(dst, nodes.AccessNode) and dst.data in output_data:
                mem.data = dst.data

    sdfg.remove_nodes_from(sdfg.nodes())

    state = sdfg.add_state(sdfg.label)
    state.add_nodes_from(input_nodes.values())
    state.add_nodes_from(output_nodes.values())

    # The connector names must be exactly the names used by the edges added
    # below, i.e. the keys of input_nodes / output_nodes. (The previous code
    # called .keys() on the input_data / output_data sets, which raises
    # AttributeError.)
    nested_node = state.add_nested_sdfg(nested_sdfg, sdfg,
                                        set(input_nodes.keys()),
                                        set(output_nodes.keys()))
    for key, val in input_nodes.items():
        state.add_edge(
            val, None, nested_node, key,
            memlet.Memlet.simple(
                val, str(subsets.Range.from_array(val.desc(sdfg)))))
    for key, val in output_nodes.items():
        state.add_edge(
            nested_node, key, val, None,
            memlet.Memlet.simple(
                val, str(subsets.Range.from_array(val.desc(sdfg)))))
def apply(self, sdfg):
    """Fuse the second matched state into the first matched state.

    The interstate edge(s) between the two states are removed (their
    assignments are pushed onto the edges entering the first state). If one
    of the states is empty it is simply dropped. Otherwise the nodes and
    edges of the second state are added to the first state, access nodes
    with equal labels are merged, and the second state is removed.

    :param sdfg: The SDFG that contains both states (modified in place).
    """

    def access_nodes_of(candidates):
        # Keep only the data (access) nodes, preserving order.
        return [c for c in candidates if isinstance(c, nodes.AccessNode)]

    def first_with_label(pool, label):
        # First element of `pool` with the given label, or None.
        for candidate in pool:
            if candidate.label == label:
                return candidate
        return None

    state_list = sdfg.nodes()
    first_state = state_list[self.subgraph[StateFusion._first_state]]
    second_state = state_list[self.subgraph[StateFusion._second_state]]

    # Drop the connecting interstate edge(s), moving any assignments to the
    # edges that lead into the first state
    for connecting in sdfg.edges_between(first_state, second_state):
        if connecting.data.assignments:
            for _, _, incoming in sdfg.in_edges(first_state):
                incoming.assignments.update(connecting.data.assignments)
        sdfg.remove_edge(connecting)

    # Trivial case: nothing in the first state, keep only the second
    if first_state.is_empty():
        nxutil.change_edge_dest(sdfg, first_state, second_state)
        sdfg.remove_node(first_state)
        return

    # Trivial case: nothing in the second state, keep only the first
    if second_state.is_empty():
        nxutil.change_edge_src(sdfg, second_state, first_state)
        nxutil.change_edge_dest(sdfg, second_state, first_state)
        sdfg.remove_node(second_state)
        return

    # General case: both states contain dataflow
    first_input = access_nodes_of(nxutil.find_source_nodes(first_state))
    first_output = access_nodes_of(nxutil.find_sink_nodes(first_state))
    second_input = access_nodes_of(nxutil.find_source_nodes(second_state))

    # Inputs of the first state whose label also appears among its outputs
    # are handled through the output node instead
    first_input = [
        n for n in first_input
        if first_with_label(first_output, n.label) is None
    ]

    # Remember the first state's access nodes in reverse topological order
    # before the second state's contents are added
    topo = list(nx.topological_sort(first_state._nx))
    topo.reverse()
    order = access_nodes_of(topo)

    # Copy the second state's contents into the first state
    for moved in second_state.nodes():
        first_state.add_node(moved)
    for src, src_conn, dst, dst_conn, data in second_state.edges():
        first_state.add_edge(src, src_conn, dst, dst_conn, data)

    # Merge data nodes that are read in both states
    for node in first_input:
        old_node = first_with_label(second_input, node.label)
        if old_node is None:
            continue
        nxutil.change_edge_src(first_state, old_node, node)
        first_state.remove_node(old_node)
        second_input.remove(old_node)
        node.access = dtypes.AccessType.ReadWrite

    # Merge data nodes written by the first state and read by the second
    for node in first_output:
        new_node = first_with_label(second_input, node.label)
        if new_node is None:
            continue
        nxutil.change_edge_dest(first_state, node, new_node)
        first_state.remove_node(node)
        second_input.remove(new_node)
        new_node.access = dtypes.AccessType.ReadWrite

    # Remaining inputs of the second state may still match a data node of
    # the first state that is neither an input nor an output
    for node in second_input:
        if first_state.in_degree(node) != 0:
            continue
        n = first_with_label(order, node.label)
        if n:
            nxutil.change_edge_src(first_state, node, n)
            first_state.remove_node(node)
            n.access = dtypes.AccessType.ReadWrite

    # Outgoing edges of the second state now leave the fused state
    nxutil.change_edge_src(sdfg, second_state, first_state)
    sdfg.remove_node(second_state)
    if Config.get_bool("debugprint"):
        StateFusion._states_fused += 1
def apply(self, sdfg):
    """Surround the matched state with host<->FPGA copy states.

    For every array access node that is a source of the state, a transient
    ``'fpga_' + name`` array with ``StorageType.FPGA_Global`` storage is
    added and a copy into it is placed in a new ``pre_`` state; for every
    array sink node a copy back is placed in a new ``post_`` state. Memlets
    in the state that refer to a copied array are renamed to the ``fpga_``
    array, and ``fpga_update`` is called on the state at the end.

    :param sdfg: The SDFG that contains the state (modified in place).
    """
    state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

    # Find source/sink (data) nodes
    input_nodes = nxutil.find_source_nodes(state)
    output_nodes = nxutil.find_sink_nodes(state)

    # Maps host array name -> result of sdfg.add_array for its FPGA copy
    fpga_data = {}

    # Input nodes may also be nodes with WCR memlets
    # We have to recur across nested SDFGs to find them
    wcr_input_nodes = set()
    stack = []  # NOTE(review): never read in this function

    parent_sdfg = {state: sdfg}  # Map states to their parent SDFG
    for node, graph in state.all_nodes_recursive():
        if isinstance(graph, dace.SDFG):
            parent_sdfg[node] = graph
        if isinstance(node, dace.graph.nodes.AccessNode):
            for e in graph.all_edges(node):
                if e.data.wcr is not None:
                    trace = dace.sdfg.trace_nested_access(
                        node, graph, parent_sdfg[graph])
                    for node_trace, state_trace, sdfg_trace in trace:
                        # Find the name of the accessed node in our scope
                        if state_trace == state and sdfg_trace == sdfg:
                            outer_name = node_trace.data
                            break
                    else:
                        # This does not trace back to the current state, so
                        # we don't care
                        continue
                    # NOTE(review): outer_name comes from `node_trace.data`.
                    # If that is a data name (as `node.data` is used below)
                    # rather than an AccessNode, the isinstance filter in
                    # the loop below skips these entries, and the
                    # `node not in wcr_input_nodes` tests compare nodes
                    # against names -- confirm this is intended.
                    input_nodes.append(outer_name)
                    wcr_input_nodes.add(outer_name)

    if input_nodes:
        # create pre_state
        pre_state = sd.SDFGState('pre_' + state.label, sdfg)

        for node in input_nodes:
            if not isinstance(node, dace.graph.nodes.AccessNode):
                continue
            desc = node.desc(sdfg)
            if not isinstance(desc, dace.data.Array):
                # TODO: handle streams
                continue

            if node.data in fpga_data:
                fpga_array = fpga_data[node.data]
            elif node not in wcr_input_nodes:
                fpga_array = sdfg.add_array(
                    'fpga_' + node.data,
                    desc.shape,
                    desc.dtype,
                    materialize_func=desc.materialize_func,
                    transient=True,
                    storage=dtypes.StorageType.FPGA_Global,
                    allow_conflicts=desc.allow_conflicts,
                    strides=desc.strides,
                    offset=desc.offset)
                fpga_data[node.data] = fpga_array

            # Copy the whole host array into its FPGA counterpart
            pre_node = pre_state.add_read(node.data)
            pre_fpga_node = pre_state.add_write('fpga_' + node.data)
            full_range = subsets.Range([(0, s - 1, 1) for s in desc.shape])
            mem = memlet.Memlet(node.data, full_range.num_elements(),
                                full_range, 1)
            pre_state.add_edge(pre_node, None, pre_fpga_node, None, mem)

            if node not in wcr_input_nodes:
                # Replace the host node in the state with the FPGA node
                fpga_node = state.add_read('fpga_' + node.data)
                nxutil.change_edge_src(state, node, fpga_node)
                state.remove_node(node)

        # Insert pre_state in front of the state
        sdfg.add_node(pre_state)
        nxutil.change_edge_dest(sdfg, state, pre_state)
        sdfg.add_edge(pre_state, state, edges.InterstateEdge())

    if output_nodes:

        post_state = sd.SDFGState('post_' + state.label, sdfg)

        for node in output_nodes:
            if not isinstance(node, dace.graph.nodes.AccessNode):
                continue
            desc = node.desc(sdfg)
            if not isinstance(desc, dace.data.Array):
                # TODO: handle streams
                continue

            if node.data in fpga_data:
                fpga_array = fpga_data[node.data]
            else:
                fpga_array = sdfg.add_array(
                    'fpga_' + node.data,
                    desc.shape,
                    desc.dtype,
                    materialize_func=desc.materialize_func,
                    transient=True,
                    storage=dtypes.StorageType.FPGA_Global,
                    allow_conflicts=desc.allow_conflicts,
                    strides=desc.strides,
                    offset=desc.offset)
                fpga_data[node.data] = fpga_array
            # fpga_node = type(node)(fpga_array)

            # Copy the whole FPGA array back into the host array
            post_node = post_state.add_write(node.data)
            post_fpga_node = post_state.add_read('fpga_' + node.data)
            full_range = subsets.Range([(0, s - 1, 1) for s in desc.shape])
            mem = memlet.Memlet('fpga_' + node.data,
                                full_range.num_elements(), full_range, 1)
            post_state.add_edge(post_fpga_node, None, post_node, None, mem)

            # Replace the host node in the state with the FPGA node
            fpga_node = state.add_write('fpga_' + node.data)
            nxutil.change_edge_dest(state, node, fpga_node)
            state.remove_node(node)

        # Insert post_state after the state
        sdfg.add_node(post_state)
        nxutil.change_edge_src(sdfg, state, post_state)
        sdfg.add_edge(state, post_state, edges.InterstateEdge())

    # NOTE(review): veclen_ is not reset per edge, so a value found for one
    # edge carries over to the following edges -- confirm this is intended.
    veclen_ = 1

    # propagate vector info from a nested sdfg
    for src, src_conn, dst, dst_conn, mem in state.edges():
        # need to go inside the nested SDFG and grab the vector length
        if isinstance(dst, dace.graph.nodes.NestedSDFG):
            # this edge is going to the nested SDFG
            for inner_state in dst.sdfg.states():
                for n in inner_state.nodes():
                    if isinstance(n, dace.graph.nodes.AccessNode
                                  ) and n.data == dst_conn:
                        # assuming all memlets have the same vector length
                        veclen_ = inner_state.all_edges(n)[0].data.veclen
        if isinstance(src, dace.graph.nodes.NestedSDFG):
            # this edge is coming from the nested SDFG
            for inner_state in src.sdfg.states():
                for n in inner_state.nodes():
                    if isinstance(n, dace.graph.nodes.AccessNode
                                  ) and n.data == src_conn:
                        # assuming all memlets have the same vector length
                        veclen_ = inner_state.all_edges(n)[0].data.veclen

        # Point memlets of copied arrays at the FPGA array
        if mem.data is not None and mem.data in fpga_data:
            mem.data = 'fpga_' + mem.data
        mem.veclen = veclen_

    fpga_update(sdfg, state, 0)
def apply(self, sdfg):
    """Double-buffer data that is copied inside the matched loop body.

    The loop variable is taken from the assignments on the edge(s) from
    the body state to the guard state; exactly one variable is supported.
    For every edge in the body from a source node to an access node whose
    memlet subset depends on the loop variable, a transient array with an
    extra leading dimension of size 2 is created, an initial copy into
    index 0 is added to the guard state, and the memlets in the body are
    re-indexed with ``(loop_var + 1) % 2`` (incoming copy) or
    ``loop_var % 2`` (other memlets on that data).

    :param sdfg: The SDFG that contains the loop (modified in place).
    :raises NotImplementedError: If the number of loop variables is not
                                 exactly one, or the destination data is
                                 neither an Array nor a Scalar.
    """
    # NOTE(review): `begin` and `end` are looked up but not used below.
    begin = sdfg.nodes()[self.subgraph[DoubleBuffering._begin]]
    guard = sdfg.nodes()[self.subgraph[DoubleBuffering._guard]]
    body = sdfg.nodes()[self.subgraph[DoubleBuffering._body]]
    end = sdfg.nodes()[self.subgraph[DoubleBuffering._end]]

    # Collect the variables assigned on the back edge(s) body -> guard
    loop_vars = []
    for _, dst, e in sdfg.out_edges(body):
        if dst is guard:
            loop_vars.extend(e.assignments.keys())
    if len(loop_vars) != 1:
        raise NotImplementedError()
    loop_var = loop_vars[0]
    sym_var = dace.symbolic.pystr_to_symbolic(loop_var)

    # Find source (data) nodes
    input_nodes = nxutil.find_source_nodes(body)
    copied_nodes = set()
    db_nodes = {}  # original destination node -> double-buffered node
    for node in input_nodes:
        for _, _, dst, _, mem in body.out_edges(node):
            if (isinstance(dst, dace.graph.nodes.AccessNode)
                    and loop_var in mem.subset.free_symbols):
                # Create new data and nodes in guard
                if node not in copied_nodes:
                    guard.add_node(node)
                    copied_nodes.add(node)
                if dst not in copied_nodes:
                    old_data = dst.desc(sdfg)
                    if isinstance(old_data, dace.data.Array):
                        new_shape = tuple([2] + list(old_data.shape))
                        new_data = sdfg.add_array(old_data.data,
                                                  old_data.dtype,
                                                  new_shape,
                                                  transient=True)
                    elif isinstance(old_data, data.Scalar):
                        # The shape must be a one-element tuple; `(2)`
                        # without the trailing comma is just the int 2.
                        new_data = sdfg.add_array(old_data.data,
                                                  old_data.dtype, (2, ),
                                                  transient=True)
                    else:
                        raise NotImplementedError()
                    new_node = dace.graph.nodes.AccessNode(old_data.data)
                    guard.add_node(new_node)
                    copied_nodes.add(dst)
                    db_nodes.update({dst: new_node})
                # Create memlet in guard: initial copy goes to buffer 0
                new_mem = copy.deepcopy(mem)
                old_index = new_mem.other_subset
                if isinstance(old_index, dace.subsets.Range):
                    new_ranges = [(0, 0, 1)] + old_index.ranges
                    new_mem.other_subset = dace.subsets.Range(new_ranges)
                elif isinstance(old_index, dace.subsets.Indices):
                    new_indices = [0] + old_index.indices
                    new_mem.other_subset = dace.subsets.Indices(
                        new_indices)
                guard.add_edge(node, None, new_node, None, new_mem)
                # Create nodes, memlets in body: one node receives the
                # incoming copy, the other feeds the consumers
                first_node = copy.deepcopy(new_node)
                second_node = copy.deepcopy(new_node)
                body.add_nodes_from([first_node, second_node])
                dace.graph.nxutil.change_edge_dest(body, dst, first_node)
                dace.graph.nxutil.change_edge_src(body, dst, second_node)
                for src, _, dest, _, memm in body.edges():
                    if src is node and dest is first_node:
                        # Prefetch into the buffer not used this iteration
                        old_index = memm.other_subset
                        idx = (sym_var + 1) % 2
                        if isinstance(old_index, dace.subsets.Range):
                            new_ranges = [(idx, idx, 1)] + old_index.ranges
                        elif isinstance(old_index, dace.subsets.Indices):
                            new_ranges = [(idx, idx, 1)]
                            for index in old_index.indices:
                                new_ranges.append((index, index, 1))
                        memm.other_subset = dace.subsets.Range(new_ranges)
                    elif memm.data == dst.data:
                        # Consumers read the buffer of this iteration
                        old_index = memm.subset
                        idx = sym_var % 2
                        if isinstance(old_index, dace.subsets.Range):
                            new_ranges = [(idx, idx, 1)] + old_index.ranges
                        elif isinstance(old_index, dace.subsets.Indices):
                            new_ranges = [(idx, idx, 1)]
                            for index in old_index.indices:
                                new_ranges.append((index, index, 1))
                        memm.subset = dace.subsets.Range(new_ranges)
                        memm.data = first_node.data
                body.remove_node(dst)
def apply(self, sdfg):
    """Fuse the second matched state into the first matched state.

    Interstate edge(s) between the two states are removed, with their
    assignments pushed onto the edges entering the first state. An empty
    state is simply dropped; otherwise the second state's nodes and edges
    are added to the first state and source access nodes of the second
    state are merged with equally-labelled source/sink access nodes of the
    first state.

    :param sdfg: The SDFG that contains both states (modified in place).
    """

    def access_nodes_of(candidates):
        # Keep only the data (access) nodes, preserving order.
        return [c for c in candidates if isinstance(c, nodes.AccessNode)]

    def first_with_label(pool, label):
        # First element of `pool` with the given label, or None.
        for candidate in pool:
            if candidate.label == label:
                return candidate
        return None

    state_list = sdfg.nodes()
    first_state = state_list[self.subgraph[StateFusion._first_state]]
    second_state = state_list[self.subgraph[StateFusion._second_state]]

    # Drop the connecting interstate edge(s), moving any assignments to the
    # edges that lead into the first state
    for connecting in sdfg.edges_between(first_state, second_state):
        if connecting.data.assignments:
            for _, _, incoming in sdfg.in_edges(first_state):
                incoming.assignments.update(connecting.data.assignments)
        sdfg.remove_edge(connecting)

    # Trivial case: nothing in the first state, keep only the second
    if first_state.is_empty():
        nxutil.change_edge_dest(sdfg, first_state, second_state)
        sdfg.remove_node(first_state)
        return

    # Trivial case: nothing in the second state, keep only the first
    if second_state.is_empty():
        nxutil.change_edge_src(sdfg, second_state, first_state)
        nxutil.change_edge_dest(sdfg, second_state, first_state)
        sdfg.remove_node(second_state)
        return

    # General case: both states contain dataflow
    first_input = access_nodes_of(nxutil.find_source_nodes(first_state))
    first_output = access_nodes_of(nxutil.find_sink_nodes(first_state))
    second_input = access_nodes_of(nxutil.find_source_nodes(second_state))

    # Inputs of the first state whose label also appears among its outputs
    # are handled through the output node instead
    first_input = [
        n for n in first_input
        if first_with_label(first_output, n.label) is None
    ]

    # Copy the second state's contents into the first state
    for moved in second_state.nodes():
        first_state.add_node(moved)
    for src, src_conn, dst, dst_conn, data in second_state.edges():
        first_state.add_edge(src, src_conn, dst, dst_conn, data)

    # Merge data nodes that are read in both states
    for node in first_input:
        old_node = first_with_label(second_input, node.label)
        if old_node is None:
            continue
        nxutil.change_edge_src(first_state, old_node, node)
        first_state.remove_node(old_node)
        second_input.remove(old_node)

    # Merge data nodes written by the first state and read by the second
    for node in first_output:
        new_node = first_with_label(second_input, node.label)
        if new_node is None:
            continue
        nxutil.change_edge_dest(first_state, node, new_node)
        first_state.remove_node(node)
        second_input.remove(new_node)

    # Outgoing edges of the second state now leave the fused state
    nxutil.change_edge_src(sdfg, second_state, first_state)
    sdfg.remove_node(second_state)
    if Config.get_bool("debugprint"):
        StateFusion._states_fused += 1
def apply(self, sdfg):
    """Surround the matched state with host<->FPGA copy states.

    For every array access node that is a source of the state, a transient
    ``'fpga_' + node.data`` array with ``StorageType.FPGA_Global`` storage
    is created and a copy into it is placed in a new ``pre_`` state; for
    every array sink node a copy back is placed in a new ``post_`` state.
    Memlets in the state that refer to a copied array are renamed to the
    corresponding ``fpga_`` array, and ``fpga_update`` is called on the
    state at the end.

    :param sdfg: The SDFG that contains the state (modified in place).
    """
    state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

    # Find source/sink (data) nodes
    input_nodes = nxutil.find_source_nodes(state)
    output_nodes = nxutil.find_sink_nodes(state)

    # Maps host data name (node.data) -> FPGA array created for it. The
    # same key is used for inputs, outputs and the memlet renaming below.
    # (The input section previously tested `array.name` but read
    # `node.data`.)
    fpga_data = {}

    if input_nodes:

        pre_state = sd.SDFGState('pre_' + state.label, sdfg)

        for node in input_nodes:
            if (not isinstance(node, dace.graph.nodes.AccessNode)
                    or not isinstance(node.desc(sdfg), dace.data.Array)):
                # Only transfer array nodes
                # TODO: handle streams
                continue

            array = node.desc(sdfg)
            if node.data in fpga_data:
                fpga_array = fpga_data[node.data]
            else:
                fpga_array = sdfg.add_array(
                    'fpga_' + node.data,
                    array.dtype,
                    array.shape,
                    materialize_func=array.materialize_func,
                    transient=True,
                    storage=types.StorageType.FPGA_Global,
                    allow_conflicts=array.allow_conflicts,
                    access_order=array.access_order,
                    strides=array.strides,
                    offset=array.offset)
                fpga_data[node.data] = fpga_array

            # Copy the whole host array into its FPGA counterpart
            fpga_node = type(node)(fpga_array)
            pre_state.add_node(node)
            pre_state.add_node(fpga_node)
            full_range = subsets.Range([(0, s - 1, 1) for s in array.shape])
            mem = memlet.Memlet(array, full_range.num_elements(), full_range,
                                1)
            pre_state.add_edge(node, None, fpga_node, None, mem)

            # Replace the host node in the state with the FPGA node
            state.add_node(fpga_node)
            nxutil.change_edge_src(state, node, fpga_node)
            state.remove_node(node)

        # Insert pre_state in front of the state
        sdfg.add_node(pre_state)
        nxutil.change_edge_dest(sdfg, state, pre_state)
        sdfg.add_edge(pre_state, state, edges.InterstateEdge())

    if output_nodes:

        post_state = sd.SDFGState('post_' + state.label, sdfg)

        for node in output_nodes:
            if (not isinstance(node, dace.graph.nodes.AccessNode)
                    or not isinstance(node.desc(sdfg), dace.data.Array)):
                # Only transfer array nodes
                # TODO: handle streams
                continue

            array = node.desc(sdfg)
            if node.data in fpga_data:
                fpga_array = fpga_data[node.data]
            else:
                fpga_array = sdfg.add_array(
                    'fpga_' + node.data,
                    array.dtype,
                    array.shape,
                    materialize_func=array.materialize_func,
                    transient=True,
                    storage=types.StorageType.FPGA_Global,
                    allow_conflicts=array.allow_conflicts,
                    access_order=array.access_order,
                    strides=array.strides,
                    offset=array.offset)
                fpga_data[node.data] = fpga_array

            # Copy the whole FPGA array back into the host array
            fpga_node = type(node)(fpga_array)
            post_state.add_node(node)
            post_state.add_node(fpga_node)
            full_range = subsets.Range([(0, s - 1, 1) for s in array.shape])
            mem = memlet.Memlet(fpga_array, full_range.num_elements(),
                                full_range, 1)
            post_state.add_edge(fpga_node, None, node, None, mem)

            # Replace the host node in the state with the FPGA node
            state.add_node(fpga_node)
            nxutil.change_edge_dest(state, node, fpga_node)
            state.remove_node(node)

        # Insert post_state after the state
        sdfg.add_node(post_state)
        nxutil.change_edge_src(sdfg, state, post_state)
        sdfg.add_edge(state, post_state, edges.InterstateEdge())

    # Point memlets of copied arrays at their own FPGA array. The name is
    # derived from the memlet itself; using the loop variable `node` left
    # over from the loops above would rename every memlet to the last
    # visited node's array (or fail if no node was visited).
    for _, _, _, _, mem in state.edges():
        if mem.data is not None and mem.data in fpga_data:
            mem.data = 'fpga_' + mem.data

    fpga_update(state, 0)
def apply(self, sdfg):
    """Surround the matched state with host<->FPGA copy states.

    Array source nodes of the state (plus nodes found through
    write-conflict-resolution memlets in nested graphs) are copied to
    transient ``'fpga_' + name`` arrays in a new ``pre_`` state; array sink
    nodes are copied back in a new ``post_`` state. Memlets of copied
    arrays are renamed, vector lengths found inside nested SDFGs are
    written to the state's memlets, and ``fpga_update`` is called last.

    :param sdfg: The SDFG that contains the state (modified in place).
    """

    def is_array_access(candidate):
        # Only access nodes whose descriptor is an Array are transferred.
        # TODO: handle streams
        return (isinstance(candidate, dace.graph.nodes.AccessNode)
                and isinstance(candidate.desc(sdfg), dace.data.Array))

    def add_device_array(host_name, host_desc):
        # Register the transient FPGA-global counterpart of a host array.
        return sdfg.add_array('fpga_' + host_name,
                              host_desc.shape,
                              host_desc.dtype,
                              materialize_func=host_desc.materialize_func,
                              transient=True,
                              storage=dtypes.StorageType.FPGA_Global,
                              allow_conflicts=host_desc.allow_conflicts,
                              strides=host_desc.strides,
                              offset=host_desc.offset)

    def whole_array_memlet(data_name, host_desc):
        # Memlet that covers every element of the array.
        full_range = subsets.Range([(0, s - 1, 1) for s in host_desc.shape])
        return memlet.Memlet(data_name, full_range.num_elements(),
                             full_range, 1)

    state = sdfg.nodes()[self.subgraph[FPGATransformState._state]]

    # Source/sink nodes of the state
    input_nodes = nxutil.find_source_nodes(state)
    output_nodes = nxutil.find_sink_nodes(state)

    # Host array name -> FPGA array
    fpga_data = {}

    # Nodes written with WCR also need their current contents on the
    # device; look for them in all (nested) graphs
    wcr_input_nodes = set()
    for node, graph in state.all_nodes_recursive():
        if not isinstance(node, dace.graph.nodes.AccessNode):
            continue
        for e in graph.all_edges(node):
            if e.data.wcr is None:
                continue
            # Output node with WCR: find the target in the parent SDFG.
            # Following State->SDFG->State->SDFG, the parent state is two
            # levels up from the current graph.
            parent_state = graph.parent.parent
            if parent_state is None:
                continue
            for parent_edge in parent_state.edges():
                if parent_edge.src_conn == e.dst.data or (
                        isinstance(parent_edge.dst,
                                   dace.graph.nodes.AccessNode)
                        and e.dst.data == parent_edge.dst.data):
                    # This must be copied to device
                    input_nodes.append(parent_edge.dst)
                    wcr_input_nodes.add(parent_edge.dst)

    if input_nodes:
        pre_state = sd.SDFGState('pre_' + state.label, sdfg)

        for node in input_nodes:
            if not is_array_access(node):
                continue
            array = node.desc(sdfg)

            if node.data not in fpga_data and node not in wcr_input_nodes:
                fpga_data[node.data] = add_device_array(node.data, array)

            # Host -> device copy
            pre_node = pre_state.add_read(node.data)
            pre_fpga_node = pre_state.add_write('fpga_' + node.data)
            mem = whole_array_memlet(node.data, array)
            pre_state.add_edge(pre_node, None, pre_fpga_node, None, mem)

            if node not in wcr_input_nodes:
                # The state now reads from the device array
                fpga_node = state.add_read('fpga_' + node.data)
                nxutil.change_edge_src(state, node, fpga_node)
                state.remove_node(node)

        sdfg.add_node(pre_state)
        nxutil.change_edge_dest(sdfg, state, pre_state)
        sdfg.add_edge(pre_state, state, edges.InterstateEdge())

    if output_nodes:
        post_state = sd.SDFGState('post_' + state.label, sdfg)

        for node in output_nodes:
            if not is_array_access(node):
                continue
            array = node.desc(sdfg)

            if node.data not in fpga_data:
                fpga_data[node.data] = add_device_array(node.data, array)

            # Device -> host copy
            post_node = post_state.add_write(node.data)
            post_fpga_node = post_state.add_read('fpga_' + node.data)
            mem = whole_array_memlet('fpga_' + node.data, array)
            post_state.add_edge(post_fpga_node, None, post_node, None, mem)

            # The state now writes to the device array
            fpga_node = state.add_write('fpga_' + node.data)
            nxutil.change_edge_dest(state, node, fpga_node)
            state.remove_node(node)

        sdfg.add_node(post_state)
        nxutil.change_edge_src(sdfg, state, post_state)
        sdfg.add_edge(state, post_state, edges.InterstateEdge())

    veclen_ = 1

    # Propagate vector info from nested SDFGs onto the state's memlets
    for src, src_conn, dst, dst_conn, mem in state.edges():
        # Destination side is inspected first, then the source side, so a
        # match on the source side takes precedence.
        for endpoint, connector in ((dst, dst_conn), (src, src_conn)):
            if not isinstance(endpoint, dace.graph.nodes.NestedSDFG):
                continue
            for inner_state in endpoint.sdfg.states():
                for n in inner_state.nodes():
                    if isinstance(
                            n, dace.graph.nodes.AccessNode
                    ) and n.data == connector:
                        # assuming all memlets have the same vector length
                        veclen_ = inner_state.all_edges(n)[0].data.veclen

        if mem.data is not None and mem.data in fpga_data:
            mem.data = 'fpga_' + mem.data
        mem.veclen = veclen_

    fpga_update(sdfg, state, 0)
def apply(self, sdfg):
    """Move the contents of ``sdfg`` into a nested SDFG node.

    A deep copy of ``sdfg`` becomes the nested SDFG; the original is
    emptied (arrays and nodes) and rebuilt as a single state holding one
    nested SDFG node. Inside the nested SDFG, non-transient source/sink
    access nodes are renamed with an ``_in`` / ``_out`` suffix; the outer
    SDFG keeps arrays under the original names and connects them to the
    nested node.

    If ``self.promote_global_trans`` is set, transient access nodes are
    renamed with ``_out`` as well, and those for which ``scope_dict`` has
    no entry are exposed as non-transient outputs.

    :param sdfg: The SDFG to transform (modified in place).
    """
    outer_sdfg = sdfg
    nested_sdfg = dc(sdfg)

    # Start the outer SDFG from scratch; everything lives in the copy now
    outer_sdfg.arrays.clear()
    outer_sdfg.remove_nodes_from(outer_sdfg.nodes())

    # Each maps original array name -> name used inside the nested SDFG
    inputs = {}
    outputs = {}
    transients = {}

    for state in nested_sdfg.nodes():
        for node in nxutil.find_source_nodes(state):
            if (isinstance(node, nodes.AccessNode)
                    and not node.desc(nested_sdfg).transient):
                arrname = node.data
                if arrname not in inputs:
                    arrobj = nested_sdfg.arrays[arrname]
                    nested_sdfg.arrays[arrname + '_in'] = arrobj
                    outer_sdfg.arrays[arrname] = dc(arrobj)
                    inputs[arrname] = arrname + '_in'
                node.data = arrname + '_in'
        for node in nxutil.find_sink_nodes(state):
            if (isinstance(node, nodes.AccessNode)
                    and not node.desc(nested_sdfg).transient):
                arrname = node.data
                if arrname not in outputs:
                    arrobj = nested_sdfg.arrays[arrname]
                    nested_sdfg.arrays[arrname + '_out'] = arrobj
                    # The outer array may already exist as an input
                    if arrname not in inputs:
                        outer_sdfg.arrays[arrname] = dc(arrobj)
                    outputs[arrname] = arrname + '_out'

                # TODO: Is this needed any longer ?
                # NOTE(review): the disabled code below refers to names
                # (input_data, input_orig, input_nodes) that are not
                # defined in this function.
                # # WCR Fix
                # if self.promote_global_trans:
                #     for edge in state.in_edges(node):
                #         if state.memlet_path(edge)[0].data.wcr:
                #             if node.data not in input_data:
                #                 input_orig.update({
                #                     node.data + '_in':
                #                     node.data
                #                 })
                #                 input_nodes.update({
                #                     node.data + '_in':
                #                     dc(node)
                #                 })
                #                 new_data = dc(node.desc(sdfg))
                #                 sdfg.arrays.update({
                #                     node.data + '_in':
                #                     new_data
                #                 })
                #                 input_data.add(node.data + '_in')
                #             break

                node.data = arrname + '_out'
        if self.promote_global_trans:
            scope_dict = state.scope_dict()
            for node in state.nodes():
                if (isinstance(node, nodes.AccessNode)
                        and node.desc(nested_sdfg).transient):
                    arrname = node.data
                    if arrname not in transients and not scope_dict[node]:
                        arrobj = nested_sdfg.arrays[arrname]
                        nested_sdfg.arrays[arrname + '_out'] = arrobj
                        outer_sdfg.arrays[arrname] = dc(arrobj)
                        transients[arrname] = arrname + '_out'
                    # NOTE(review): the rename is applied to every
                    # transient access node reaching this point, including
                    # ones for which no '_out' array was registered above
                    # -- confirm this is intended.
                    node.data = arrname + '_out'

    # Remove the original names from the nested SDFG; only the suffixed
    # names remain there
    for arrname in inputs.keys():
        nested_sdfg.arrays.pop(arrname)
    for arrname in outputs.keys():
        # May already have been popped as an input
        nested_sdfg.arrays.pop(arrname, None)
    for oldarrname, newarrname in transients.items():
        nested_sdfg.arrays.pop(oldarrname)
        nested_sdfg.arrays[newarrname].transient = False
        outer_sdfg.arrays[oldarrname].transient = False
    # Promoted transients are connected like regular outputs from here on
    outputs.update(transients)

    # Rename memlet data to match the renamed access node at the start or
    # end of each memlet path
    for state in nested_sdfg.nodes():
        for _, edge in enumerate(state.edges()):
            _, _, _, _, mem = edge
            src = state.memlet_path(edge)[0].src
            dst = state.memlet_path(edge)[-1].dst
            if isinstance(src, nodes.AccessNode):
                if (mem.data in inputs.keys()
                        and src.data == inputs[mem.data]):
                    mem.data = inputs[mem.data]
                elif (mem.data in outputs.keys()
                      and src.data == outputs[mem.data]):
                    mem.data = outputs[mem.data]
            elif (isinstance(dst, nodes.AccessNode)
                  and mem.data in outputs.keys()
                  and dst.data == outputs[mem.data]):
                mem.data = outputs[mem.data]

    outer_state = outer_sdfg.add_state(outer_sdfg.label)

    # Connectors of the nested node carry the suffixed (inner) names
    nested_node = outer_state.add_nested_sdfg(nested_sdfg, outer_sdfg,
                                              inputs.values(),
                                              outputs.values())
    for key, val in inputs.items():
        arrnode = outer_state.add_read(key)
        outer_state.add_edge(
            arrnode, None, nested_node, val,
            memlet.Memlet.from_array(key, arrnode.desc(outer_sdfg)))
    for key, val in outputs.items():
        arrnode = outer_state.add_write(key)
        outer_state.add_edge(
            nested_node, val, arrnode, None,
            memlet.Memlet.from_array(key, arrnode.desc(outer_sdfg)))
def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
    """Check whether the two matched states may be fused into one.

    :param graph: State graph that contains the candidate states.
    :param candidate: Mapping from the pattern nodes
                      (``StateFusion._first_state`` and
                      ``StateFusion._second_state``) to indices into
                      ``graph.nodes()``.
    :param expr_index: Not used by this check.
    :param sdfg: Not used by this check.
    :param strict: If True, additionally run the dataflow checks below
                   (extra incoming edges of the second state and
                   read-write / write-write conflicts).
    :return: True if the fusion may be applied, False otherwise.
    """
    first_state = graph.nodes()[candidate[StateFusion._first_state]]
    second_state = graph.nodes()[candidate[StateFusion._second_state]]

    out_edges = graph.out_edges(first_state)
    in_edges = graph.in_edges(first_state)

    # First state must have only one output edge (with dst the second
    # state).
    if len(out_edges) != 1:
        return False
    # The interstate edge must not have a condition.
    if out_edges[0].data.condition.as_string != '':
        return False
    # The interstate edge may have assignments, as long as there are input
    # edges to the first state, that can absorb them.
    if out_edges[0].data.assignments and not in_edges:
        return False

    # There can be no state that have output edges pointing to both the
    # first and the second state. Such a case will produce a multi-graph.
    for src, _, _ in in_edges:
        for _, dst, _ in graph.out_edges(src):
            if dst == second_state:
                return False

    if strict:

        def shares_label(node, others):
            # True if any node in `others` carries the same label as `node`.
            return any(x.label == node.label for x in others)

        # If second state has other input edges, there might be issues
        second_in_edges = graph.in_edges(second_state)
        if ((not second_state.is_empty() or not first_state.is_empty())
                and len(second_in_edges) != 1):
            return False

        # Get connected components of the first state.
        first_cc = list(nx.weakly_connected_components(first_state._nx))

        # Find source/sink (data) nodes. "Output" here means every access
        # node that is not a source node of its state.
        first_input = {
            node
            for node in nxutil.find_source_nodes(first_state)
            if isinstance(node, nodes.AccessNode)
        }
        first_output = {
            node
            for node in first_state.nodes()
            if isinstance(node, nodes.AccessNode) and node not in first_input
        }
        second_input = {
            node
            for node in nxutil.find_source_nodes(second_state)
            if isinstance(node, nodes.AccessNode)
        }
        second_output = {
            node
            for node in second_state.nodes()
            if isinstance(node, nodes.AccessNode) and node not in second_input
        }

        # Non-source data nodes grouped by connected component
        first_cc_output = [cc.intersection(first_output) for cc in first_cc]

        # A component needs no further checking when one of its data nodes
        # matches an input of the second state.
        check_strict = len(first_cc)
        for cc_output in first_cc_output:
            if any(shares_label(node, second_input) for node in cc_output):
                check_strict -= 1

        if check_strict > 0:
            # Check strict conditions
            # RW dependency
            if any(shares_label(node, second_output) for node in first_input):
                return False
            # WW dependency
            if any(
                    shares_label(node, second_output)
                    for node in first_output):
                return False

    return True