def apply(self, sdfg: dace.SDFG) -> None:
    """Fuse the ``left`` and ``right`` library nodes of one state into a single node.

    A new ``HorizontalExecutionLibraryNode`` is built whose OIR body and
    declarations are ``left``'s followed by ``right``'s. It uses ``left``'s
    iteration space and the debug info merged from both nodes.

    Edges attached to ``left``, ``right`` and to the access nodes lying between
    them are rewired onto the new node. ``left`` and ``right`` are then removed.
    Finally, the in-between access nodes are cleaned up according to which
    edges they still have.

    Args:
        sdfg: SDFG whose state ``self.state_id`` contains both nodes.
    """
    state = sdfg.node(self.state_id)
    left = self.left(sdfg)
    right = self.right(sdfg)

    # Merge source locations
    dinfo = self._merge_source_locations(left, right)

    # Merge OIR nodes: left's statements/declarations precede right's.
    res = HorizontalExecutionLibraryNode(
        oir_node=oir.HorizontalExecution(
            body=left.as_oir().body + right.as_oir().body,
            declarations=left.as_oir().declarations + right.as_oir().declarations,
        ),
        iteration_space=left.iteration_space,
        debuginfo=dinfo,
    )
    state.add_node(res)

    # Nodes strictly between left and right on any simple path; path[0] is
    # left and path[-1] is right, so they are sliced off.
    intermediate_accesses = {
        n for path in nx.all_simple_paths(state.nx, left, right) for n in path[1:-1]
    }

    # rewire edges and connectors to left and delete right
    for edge in state.edges_between(left, right):
        state.remove_edge_and_connectors(edge)
    for acc in intermediate_accesses:
        # Edges coming from left are dropped; any other incoming edge is
        # handed to rewire_edge with dst=res.
        for edge in state.in_edges(acc):
            if edge.src is not left:
                rewire_edge(state, edge, dst=res)
            else:
                state.remove_edge_and_connectors(edge)
        # Edges going to right are dropped; any other outgoing edge is
        # handed to rewire_edge with src=res.
        for edge in state.out_edges(acc):
            if edge.dst is not right:
                rewire_edge(state, edge, src=res)
            else:
                state.remove_edge_and_connectors(edge)
    for edge in state.in_edges(left):
        rewire_edge(state, edge, dst=res)
    for edge in state.out_edges(right):
        rewire_edge(state, edge, src=res)
    for edge in state.out_edges(left):
        rewire_edge(state, edge, src=res)
    for edge in state.in_edges(right):
        rewire_edge(state, edge, dst=res)
    state.remove_node(left)
    state.remove_node(right)

    # Clean up intermediate access nodes based on the edges they still have.
    for acc in intermediate_accesses:
        if not state.in_edges(acc):
            if not state.out_edges(acc):
                # Fully disconnected: drop the node.
                state.remove_node(acc)
            else:
                # No writer remains: the node may only feed res through a
                # single edge. Drop it together with res's matching input
                # connector.
                assert (
                    len(state.edges_between(acc, res)) == 1
                    and len(state.out_edges(acc)) == 1
                ), "Previously written array now read-only."
                state.remove_node(acc)
                res.remove_in_connector("IN_" + acc.label)
        elif not state.out_edges(acc):
            # Still written but no longer read: mark it write-only. The call
            # (acc) is required; the bare bound method is always truthy, which
            # made this branch unreachable.
            acc.access = dace.AccessType.WriteOnly
def offsets_match(left: HorizontalExecutionLibraryNode, right: HorizontalExecutionLibraryNode) -> bool:
    """Return True when running ``left`` then ``right`` yields no offset conflicts.

    The accesses of each node's OIR are gathered with ``AccessCollector``. Two
    conflict checks are then combined:

    * read-after-write: ``left``'s write offsets against ``right``'s read offsets;
    * write-after-read: ``left``'s read offsets against ``right``'s write offsets.

    Both checks work on the ``ij_offsets`` projection of the offsets. The pair
    matches only if the union of the two results is empty/falsy.
    """
    accesses_left = AccessCollector.apply(left.as_oir())
    accesses_right = AccessCollector.apply(right.as_oir())

    raw = read_after_write_conflicts(
        ij_offsets(accesses_left.write_offsets()),
        ij_offsets(accesses_right.read_offsets()),
    )
    war = write_after_read_conflicts(
        ij_offsets(accesses_left.read_offsets()),
        ij_offsets(accesses_right.write_offsets()),
    )
    return not (raw | war)
def visit_HorizontalExecution(self, node: oir.HorizontalExecution, *, iteration_spaces, **kwargs):
    """Wrap an OIR ``HorizontalExecution`` in a ``HorizontalExecutionLibraryNode``.

    The library node's name is derived from ``id(node)``. Its iteration space
    is the entry stored under that same key in ``iteration_spaces``. Any other
    keyword arguments are accepted and ignored.
    """
    node_name = f"HorizontalExecution_{id(node)}"
    space = iteration_spaces[id(node)]
    return HorizontalExecutionLibraryNode(name=node_name, oir_node=node, iteration_space=space)
def get_vertical_loop_section_sdfg(section: "VerticalLoopSection") -> SDFG:
    """Build an SDFG that runs a section's horizontal executions as a chain of states.

    The SDFG opens with an empty ``start_state``. Then, in the order of
    ``section.horizontal_executions``, each execution gets its own state:

    * the state holds a single ``HorizontalExecutionLibraryNode`` wrapping it;
    * it is linked from the preceding state by a default-constructed
      ``InterstateEdge``.
    """
    from gtc.dace.nodes import HorizontalExecutionLibraryNode

    sdfg = SDFG(f"VerticalLoopSection_{id(section)}")
    tail = sdfg.add_state("start_state", is_start_state=True)
    for execution in section.horizontal_executions:
        state = sdfg.add_state(f"HorizontalExecution_{id(execution)}_state")
        sdfg.add_edge(tail, state, InterstateEdge())
        state.add_node(HorizontalExecutionLibraryNode(oir_node=execution))
        tail = state
    return sdfg
def visit_HorizontalExecution(self, node: oir.HorizontalExecution, *, block_extents, **kwargs):
    """Wrap an OIR ``HorizontalExecution`` in a ``HorizontalExecutionLibraryNode``.

    The library node's name is derived from ``id(node)``. Its ``extent`` is the
    entry stored under that same key in ``block_extents``. Any other keyword
    arguments are accepted and ignored.
    """
    node_name = f"HorizontalExecution_{id(node)}"
    node_extent = block_extents[id(node)]
    return HorizontalExecutionLibraryNode(name=node_name, oir_node=node, extent=node_extent)