def apply(self, sdfg):
    """Substitute the end state's incoming inter-state assignments into it.

    The assignments selected by ``_assignments_to_consider`` are replaced
    inside the matched state and then deleted from its incoming edge. Each
    such symbol that no remaining edge references is removed from
    ``sdfg.symbols``. If its assigned value is itself a registered SDFG
    symbol, the old name is also replaced by that value across the SDFG.
    """
    state = sdfg.nodes()[self.subgraph[StateAssignElimination._end_state]]
    edge = sdfg.in_edges(state)[0]

    # Since inter-state assignments that use an assigned value lead to
    # undefined behavior (e.g., {m: n, n: m}), we can replace each
    # assignment separately.
    assignments_to_consider = _assignments_to_consider(sdfg, edge)
    for varname, assignment in assignments_to_consider.items():
        state.replace(varname, assignment)

    repl_dict = {}
    # Iterate the dict directly (insertion order). A set copy of the keys
    # would iterate in an order that varies between interpreter runs for
    # string keys, making symbol removal / repl_dict contents order-dependent.
    for varname, assignment in assignments_to_consider.items():
        # Remove assignments from edge
        del edge.data.assignments[varname]
        if any(varname in e.data.free_symbols for e in sdfg.edges()):
            # Still referenced by another edge: keep the symbol.
            continue
        # The removed assignment does not appear in any other edge:
        # replace and remove the symbol.
        if assignment in sdfg.symbols:
            repl_dict[varname] = assignment
        if varname in sdfg.symbols:
            sdfg.remove_symbol(varname)

    def _str_repl(s, d):
        # Apply each (old -> new) pair via the object's own ``replace``.
        for k, v in d.items():
            s.replace(str(k), str(v))

    if repl_dict:
        symbolic.safe_replace(repl_dict, lambda m: _str_repl(sdfg, m))
def apply(self, sdfg):
    """Remove the matched end state and clean up the symbols it orphans.

    The first incoming edge of the state is inspected for assignments, then
    the state is removed from ``sdfg``. Each symbol that edge assigned is
    removed from ``sdfg.symbols`` when it is still reported in
    ``sdfg.free_symbols`` and is actually registered in ``sdfg.symbols``.
    """
    state = sdfg.nodes()[self.subgraph[EndStateElimination._end_state]]
    # Handle orphan symbols (due to the deletion of the incoming edge).
    # Copy the keys instead of holding a live ``keys()`` view of an edge
    # that is discarded together with the state.
    edge = sdfg.in_edges(state)[0]
    sym_assign = list(edge.data.assignments.keys())
    sdfg.remove_node(state)

    # Read ``free_symbols`` once rather than once per symbol.
    # NOTE(review): assumes it is recomputed on access and that removing one
    # symbol never changes whether a *different* symbol is free -- confirm.
    free_syms = sdfg.free_symbols
    # Remove orphan symbols. Guard on ``sdfg.symbols`` so a symbol that is
    # free but was never registered is skipped instead of being passed to
    # ``remove_symbol`` (presumably a KeyError -- confirm).
    for sym in sym_assign:
        if sym in free_syms and sym in sdfg.symbols:
            sdfg.remove_symbol(sym)
def apply(self, _, sdfg):
    """Remove the matched end state and clean up the symbols it orphans.

    The first incoming edge of ``self.end_state`` is inspected for
    assignments, then the state is removed from ``sdfg``. Each symbol that
    edge assigned is removed from ``sdfg.symbols`` when it is still reported
    in ``sdfg.free_symbols`` and is actually registered in ``sdfg.symbols``.
    """
    state = self.end_state
    # Handle orphan symbols (due to the deletion of the incoming edge).
    # Copy the keys instead of holding a live ``keys()`` view of an edge
    # that is discarded together with the state.
    edge = sdfg.in_edges(state)[0]
    sym_assign = list(edge.data.assignments.keys())
    sdfg.remove_node(state)

    # Read ``free_symbols`` once rather than once per symbol.
    # NOTE(review): assumes it is recomputed on access and that removing one
    # symbol never changes whether a *different* symbol is free -- confirm.
    free_syms = sdfg.free_symbols
    # Remove orphan symbols. Guard on ``sdfg.symbols`` so a symbol that is
    # free but was never registered is skipped instead of being passed to
    # ``remove_symbol`` (presumably a KeyError -- confirm).
    for sym in sym_assign:
        if sym in free_syms and sym in sdfg.symbols:
            sdfg.remove_symbol(sym)
def apply(self, sdfg):
    """Fold every inter-state assignment of the incoming edge into the state.

    Each ``name = value`` assignment on the first incoming edge of the
    matched state is substituted into that state via ``replace``. The edge
    is then left with an empty assignment dictionary.
    """
    target = sdfg.nodes()[self.subgraph[StateAssignElimination._end_state]]
    incoming = sdfg.in_edges(target)[0]

    # Inter-state assignments that read another assigned value are undefined
    # behavior (e.g., {m: n, n: m}), so substituting them one at a time is
    # sound.
    pending = incoming.data.assignments
    for symbol_name in pending:
        target.replace(symbol_name, pending[symbol_name])

    # Clear the edge's assignments by installing a fresh empty dict.
    incoming.data.assignments = {}
def apply(self, _, sdfg: SDFG):
    """Propagate constant inter-state assignments into the end state.

    The assignments selected by ``_assignments_to_consider(..., True)`` on
    the state's first incoming edge are substituted into:
    - the state itself;
    - every edge reachable from it by BFS;
    - every state reachable from it by BFS.

    They are then deleted from the incoming edge. A symbol that no remaining
    edge references is removed from ``sdfg.symbols``. If it is still reported
    free afterwards, it is also replaced SDFG-wide.
    """
    state = self.end_state
    edge = sdfg.in_edges(state)[0]
    # Since inter-state assignments that use an assigned value lead to
    # undefined behavior (e.g., {m: n, n: m}), we can replace each
    # assignment separately.
    assignments_to_consider = _assignments_to_consider(sdfg, edge, True)

    def _str_repl(s, d, **kwargs):
        # Apply each (old -> new) pair via the object's own ``replace``.
        for k, v in d.items():
            s.replace(str(k), str(v), **kwargs)

    # Replace in state, and all successors.
    symbolic.safe_replace(assignments_to_consider,
                          lambda m: _str_repl(state, m))
    # ``state`` has just been processed, so mark it visited too. Otherwise,
    # if a cycle leads back through ``edge``, the BFS yields ``edge`` and its
    # destination ``state`` would be rewritten a second time.
    visited = {edge, state}
    for isedge in sdfg.bfs_edges(state):
        if isedge not in visited:
            # ``replace_keys=False``: presumably keeps the assigned names on
            # the edge untouched and rewrites only the expressions.
            symbolic.safe_replace(
                assignments_to_consider,
                lambda m: _str_repl(isedge.data, m, replace_keys=False))
            visited.add(isedge)
        if isedge.dst not in visited:
            symbolic.safe_replace(assignments_to_consider,
                                  lambda m: _str_repl(isedge.dst, m))
            visited.add(isedge.dst)

    repl_dict = {}
    for varname, assignment in assignments_to_consider.items():
        # Remove assignments from edge
        del edge.data.assignments[varname]
        if any(varname in e.data.free_symbols for e in sdfg.edges()):
            # Still referenced by another edge: keep the symbol.
            continue
        # The removed assignment does not appear in any other edge:
        # remove the symbol, and replace it where it is still used.
        if varname in sdfg.symbols:
            sdfg.remove_symbol(varname)
        if varname in sdfg.free_symbols:
            repl_dict[varname] = assignment

    if repl_dict:
        symbolic.safe_replace(repl_dict, lambda m: _str_repl(sdfg, m))
def apply(self, sdfg):
    """Promote alias assignments one edge earlier.

    Candidates are the ``_alias_assignments`` of the edge from the first to
    the second matched state. A candidate ``k = v`` is kept only if all of
    the following hold:
    * ``k`` does not appear in that edge's condition;
    * the edge entering the first state does not assign ``k`` a different
      value;
    * ``v`` is not an SDFG symbol that the entering edge assigns;
    * ``v`` is not a scalar that the first state accesses.

    Kept assignments are deleted from the first-to-second edge and written
    onto the first incoming edge of the first state.
    """
    fstate = sdfg.nodes()[self.subgraph[SymbolAliasPromotion._first_state]]
    sstate = sdfg.nodes()[self.subgraph[
        SymbolAliasPromotion._second_state]]
    edge = sdfg.edges_between(fstate, sstate)[0].data
    in_edge = sdfg.in_edges(fstate)[0].data

    to_consider = _alias_assignments(sdfg, edge)
    if not to_consider:
        # Nothing to promote; also avoids evaluating the edge condition.
        return

    # The edge condition does not change inside the loop: compute the names
    # of its free symbols once, as a set for O(1) membership tests.
    condsyms = {str(s) for s in edge.condition_sympy().free_symbols}

    to_not_consider = set()
    for k, v in to_consider.items():
        # Remove symbols that are taking part in the edge's condition
        if k in condsyms:
            to_not_consider.add(k)
        # Remove symbols that are set in the in_edge
        # with a different assignment
        if k in in_edge.assignments and in_edge.assignments[k] != v:
            to_not_consider.add(k)
        # Remove symbols whose assignment (RHS) is a symbol
        # and is set in the in_edge.
        if v in sdfg.symbols and v in in_edge.assignments:
            to_not_consider.add(k)
        # Remove symbols whose assignment (RHS) is a scalar
        # and is accessed in the first state.
        if v in sdfg.arrays and isinstance(sdfg.arrays[v], dt.Scalar):
            if any(
                    isinstance(n, nodes.AccessNode) and n.data == v
                    for n in fstate.nodes()):
                to_not_consider.add(k)

    for k in to_not_consider:
        del to_consider[k]

    # Move every surviving assignment onto the edge entering the first state.
    for k, v in to_consider.items():
        del edge.assignments[k]
        in_edge.assignments[k] = v
def apply(self, sdfg):
    """Fuse the matched second state into the first state.

    Assignments on the edge(s) between the two states are copied onto every
    edge entering the first state, and those edge(s) are removed. If either
    state is empty, it is bypassed and removed.

    Otherwise, all nodes and edges of the second state are added to the first
    state. Each source access node of the second state that has no incoming
    edges is then merged into the topologically-last access node of the
    (pre-merge) first state with the same label. That node is marked
    ``ReadWrite``. Finally, the second state's outgoing edges are re-rooted
    at the first state and the second state is removed.
    """
    first_state = sdfg.nodes()[self.subgraph[StateFusion._first_state]]
    second_state = sdfg.nodes()[self.subgraph[StateFusion._second_state]]

    # Remove interstate edge(s), copying their assignments onto every edge
    # entering the first state.
    edges = sdfg.edges_between(first_state, second_state)
    for edge in edges:
        if edge.data.assignments:
            for src, dst, other_data in sdfg.in_edges(first_state):
                other_data.assignments.update(edge.data.assignments)
        sdfg.remove_edge(edge)

    # Special case 1: first state is empty
    if first_state.is_empty():
        sdutil.change_edge_dest(sdfg, first_state, second_state)
        sdfg.remove_node(first_state)
        return

    # Special case 2: second state is empty
    if second_state.is_empty():
        sdutil.change_edge_src(sdfg, second_state, first_state)
        sdutil.change_edge_dest(sdfg, second_state, first_state)
        sdfg.remove_node(second_state)
        return

    # Normal case: both states are not empty.
    # Only the second state's source access nodes are merge candidates.
    second_input = [
        node for node in sdutil.find_source_nodes(second_state)
        if isinstance(node, nodes.AccessNode)
    ]

    # Before merging, keep a backup of the first state's access nodes in
    # reverse topological order, so the first label match found below is
    # the topologically-last one.
    order = [
        x for x in reversed(list(nx.topological_sort(first_state._nx)))
        if isinstance(x, nodes.AccessNode)
    ]

    # Merge second state to first state
    for node in second_state.nodes():
        first_state.add_node(node)
    for src, src_conn, dst, dst_conn, data in second_state.edges():
        first_state.add_edge(src, src_conn, dst, dst_conn, data)

    # Merge common (data) nodes
    for node in second_input:
        if first_state.in_degree(node) == 0:
            n = next((x for x in order if x.label == node.label), None)
            if n is not None:
                sdutil.change_edge_src(first_state, node, n)
                first_state.remove_node(node)
                n.access = dtypes.AccessType.ReadWrite

    # Redirect edges and remove second state
    sdutil.change_edge_src(sdfg, second_state, first_state)
    sdfg.remove_node(second_state)
    if Config.get_bool("debugprint"):
        StateFusion._states_fused += 1
def apply(self, sdfg):
    """Fuse the matched second state into the first state.

    The pattern entries may be the states themselves or their node ids in
    ``sdfg``. Assignments on the edge(s) between the two states are copied
    onto every edge entering the first state, and those edge(s) are removed.
    If either state is empty, it is bypassed and removed, updating
    ``sdfg.start_state`` when it compares equal to the removed state.

    Otherwise, all nodes and edges of the second state are added to the first
    state, with nested SDFGs re-parented. Each top-level source access node
    of the second state that has no incoming edges is then merged with a
    top-level access node of the same data in the first state:
    - if there is a single candidate, that candidate is used;
    - otherwise, the first candidate whose memlets intersect is used;
    - if none intersect, the topologically-last candidate is used.
    The chosen node is marked ``ReadWrite``.
    """
    # The pattern may hold the states themselves or their node ids.
    if isinstance(self.subgraph[StateFusion.first_state], SDFGState):
        first_state: SDFGState = self.subgraph[StateFusion.first_state]
        second_state: SDFGState = self.subgraph[StateFusion.second_state]
    else:
        first_state: SDFGState = sdfg.node(
            self.subgraph[StateFusion.first_state])
        second_state: SDFGState = sdfg.node(
            self.subgraph[StateFusion.second_state])

    # Remove interstate edge(s), copying their assignments onto every edge
    # entering the first state.
    edges = sdfg.edges_between(first_state, second_state)
    for edge in edges:
        if edge.data.assignments:
            for src, dst, other_data in sdfg.in_edges(first_state):
                other_data.assignments.update(edge.data.assignments)
        sdfg.remove_edge(edge)

    # NOTE(review): each ``sdfg.start_state == <state>`` check below runs
    # after that state was removed from ``sdfg``. Confirm that
    # ``start_state`` can still return the removed state at that point;
    # otherwise these start-state updates never trigger.

    # Special case 1: first state is empty
    if first_state.is_empty():
        sdutil.change_edge_dest(sdfg, first_state, second_state)
        sdfg.remove_node(first_state)
        if sdfg.start_state == first_state:
            sdfg.start_state = sdfg.node_id(second_state)
        return

    # Special case 2: second state is empty
    if second_state.is_empty():
        sdutil.change_edge_src(sdfg, second_state, first_state)
        sdutil.change_edge_dest(sdfg, second_state, first_state)
        sdfg.remove_node(second_state)
        if sdfg.start_state == second_state:
            sdfg.start_state = sdfg.node_id(first_state)
        return

    # Normal case: both states are not empty.
    # Only the second state's source access nodes are merge candidates.
    second_input = [
        node for node in sdutil.find_source_nodes(second_state)
        if isinstance(node, nodes.AccessNode)
    ]
    # Top-level nodes of the second state, captured before merging; a set
    # gives O(1) membership tests in the merge loop.
    top2 = set(top_level_nodes(second_state))

    # Before merging, keep a backup of the first state's top-level access
    # nodes in reverse topological order (topologically-last first).
    sdict = first_state.scope_dict()
    order = [
        x for x in reversed(list(nx.topological_sort(first_state._nx)))
        if isinstance(x, nodes.AccessNode) and sdict[x] is None
    ]

    # Merge second state to first state
    for node in second_state.nodes():
        if isinstance(node, nodes.NestedSDFG):
            # update parent information
            node.sdfg.parent = first_state
        first_state.add_node(node)
    for src, src_conn, dst, dst_conn, data in second_state.edges():
        first_state.add_edge(src, src_conn, dst, dst_conn, data)

    top = set(top_level_nodes(first_state))

    # Merge common (data) nodes
    for node in second_input:
        # merge only top level nodes, skip everything else
        if node not in top2:
            continue
        if first_state.in_degree(node) != 0:
            continue
        candidates = [x for x in order if x.data == node.data and x in top]
        if not candidates:
            continue
        if len(candidates) == 1:
            n = candidates[0]
        else:
            # Choose first candidate that intersects memlets
            for cand in candidates:
                if StateFusion.memlets_intersect(first_state, [cand], False,
                                                 second_state, [node], True):
                    n = cand
                    break
            else:
                # No node intersects, use topologically-last node
                n = candidates[0]

        sdutil.change_edge_src(first_state, node, n)
        first_state.remove_node(node)
        n.access = dtypes.AccessType.ReadWrite

    # Redirect edges and remove second state
    sdutil.change_edge_src(sdfg, second_state, first_state)
    sdfg.remove_node(second_state)
    if sdfg.start_state == second_state:
        sdfg.start_state = sdfg.node_id(first_state)
def apply(self, sdfg):
    """Fuse the matched second state into the first state.

    Edges between the two states are removed; their assignments are copied
    onto every edge entering the first state. An empty state on either side
    is simply bypassed and removed.

    Otherwise, the second state's nodes and edges are added to the first
    state, and same-label access nodes are deduplicated:
    - a first-state input (a source not also a sink by label) absorbs a
      matching second-state input;
    - a first-state output (sink) is replaced by a matching remaining
      second-state input.

    The second state's outgoing edges are then re-rooted at the first state
    and the second state is removed.
    """
    all_states = sdfg.nodes()
    first_state = all_states[self.subgraph[StateFusion._first_state]]
    second_state = all_states[self.subgraph[StateFusion._second_state]]

    # Drop the connecting interstate edge(s), forwarding any assignments.
    for connecting in sdfg.edges_between(first_state, second_state):
        moved = connecting.data.assignments
        if moved:
            for _, _, incoming_data in sdfg.in_edges(first_state):
                incoming_data.assignments.update(moved)
        sdfg.remove_edge(connecting)

    # Degenerate fusions: an empty state is bypassed and deleted.
    if first_state.is_empty():
        nxutil.change_edge_dest(sdfg, first_state, second_state)
        sdfg.remove_node(first_state)
        return
    if second_state.is_empty():
        nxutil.change_edge_src(sdfg, second_state, first_state)
        nxutil.change_edge_dest(sdfg, second_state, first_state)
        sdfg.remove_node(second_state)
        return

    def _access_nodes(candidates):
        # Keep only data access nodes.
        return [c for c in candidates if isinstance(c, nodes.AccessNode)]

    def _first_with_label(pool, label):
        # First node in ``pool`` carrying ``label``, or None.
        return next((c for c in pool if c.label == label), None)

    first_input = _access_nodes(nxutil.find_source_nodes(first_state))
    first_output = _access_nodes(nxutil.find_sink_nodes(first_state))
    second_input = _access_nodes(nxutil.find_source_nodes(second_state))

    # A first-state source whose label also appears among its sinks is not
    # treated as an input.
    first_input = [
        src for src in first_input
        if _first_with_label(first_output, src.label) is None
    ]

    # Copy the second state's contents into the first state.
    for moved_node in second_state.nodes():
        first_state.add_node(moved_node)
    for src, src_conn, dst, dst_conn, data in second_state.edges():
        first_state.add_edge(src, src_conn, dst, dst_conn, data)

    # Reuse first-state inputs for same-label second-state inputs.
    for reader in first_input:
        duplicate = _first_with_label(second_input, reader.label)
        if duplicate is None:
            continue
        nxutil.change_edge_src(first_state, duplicate, reader)
        first_state.remove_node(duplicate)
        second_input.remove(duplicate)

    # Replace first-state outputs by same-label second-state inputs.
    for writer in first_output:
        successor = _first_with_label(second_input, writer.label)
        if successor is None:
            continue
        nxutil.change_edge_dest(first_state, writer, successor)
        first_state.remove_node(writer)
        second_input.remove(successor)

    # Re-root the second state's outgoing edges and drop it.
    nxutil.change_edge_src(sdfg, second_state, first_state)
    sdfg.remove_node(second_state)
    if Config.get_bool("debugprint"):
        StateFusion._states_fused += 1