def apply(self, state: SDFGState, sdfg: SDFG):
    """Prune the candidate data containers selected for ``self.nsdfg``.

    The selection itself is delegated to ``self._candidates``, which yields
    the candidate container names and ``(nested_state, node)`` pairs inside
    the nested SDFG. Removal happens in three passes, in this order:

    1. Every out-edge of the nested SDFG node whose source connector names a
       candidate has its memlet path removed from ``state``. The container
       named by that edge's memlet is then removed from ``sdfg``.
    2. Every memlet path writing into one of the candidate nodes is removed
       from its nested state.
    3. Each candidate container is removed from the nested SDFG.

    :param state: State containing the nested SDFG node; modified in place.
    :param sdfg: SDFG owning ``state``; outer containers are removed from it.
    """
    nested = self.nsdfg
    cand_names, cand_nodes = self._candidates(nested)

    # Pass 1: detach the pruned outputs on the outer side.
    for edge in state.out_edges(nested):
        if edge.src_conn not in cand_names:
            continue
        state.remove_memlet_path(edge)
        # Validation is skipped; presumably the graph is only consistent
        # again once all passes have run.
        sdfg.remove_data(edge.data.data, validate=False)

    # Pass 2: drop everything that writes into the candidate nodes inside.
    for inner_state, access_node in cand_nodes:
        for write_edge in inner_state.in_edges(access_node):
            inner_state.remove_memlet_path(write_edge)

    # Pass 3: the candidate descriptors are now unreferenced in the nested SDFG.
    for cand_name in cand_names:
        nested.sdfg.remove_data(cand_name, validate=False)
def aggregate_calls(sdfg: dace.SDFG, state: dace.SDFGState, lib_node: nodes.LibraryNode, code: str):
    """Fold ``lib_node``'s ``_group_handle`` connectors into its generated code.

    Library nodes are linked through access nodes attached to a connector
    named ``_group_handle``. This function removes that link for ``lib_node``
    from ``state`` and adjusts ``code``:

    * Incoming handle with no producer: ``ncclGroupStart();`` is prepended.
    * Incoming handle with a producer: the producer is connected directly to
      ``lib_node`` with an empty memlet, replacing the handle access node.
    * Outgoing handle with no consumer: ``ncclGroupEnd();`` is appended. The
      handle access node and its data descriptor are then removed; a failure
      to remove the descriptor is reported as a warning instead of raised.

    The ``_group_handle`` connectors are removed from ``lib_node`` in all cases.

    :param sdfg: SDFG owning the handle data descriptors.
    :param state: State containing ``lib_node``; modified in place.
    :param lib_node: Library node whose group-handle connectors are removed.
    :param code: Code generated for ``lib_node`` so far.
    :return: ``code`` with any ``ncclGroupStart``/``ncclGroupEnd`` calls added.
    :raises ValueError: if ``lib_node`` has a ``_group_handle`` connector with
        no edge attached to it.
    """
    group_handle_conn = '_group_handle'

    if group_handle_conn in lib_node.in_connectors:
        # Keep the last matching edge, as the original linear scan did.
        in_matches = [e for e in state.in_edges(lib_node) if e.dst_conn == group_handle_conn]
        if not in_matches:
            raise ValueError(f'Library node "{lib_node}" has a "{group_handle_conn}" '
                             'input connector without an incoming edge')
        in_gh_edge = in_matches[-1]
        in_gh_node = in_gh_edge.src
        gh_producers = state.in_edges(in_gh_node)
        if not gh_producers:
            # Nothing writes the handle: this node opens the group.
            code = """ncclGroupStart();\n""" + code
        else:
            # Replace producer -> handle -> lib_node with a direct empty-memlet
            # edge. The reconnected node is taken from the very edge that is
            # removed, so the two always agree.
            producer_edge = gh_producers[0]
            state.add_edge(producer_edge.src, None, lib_node, None, dace.Memlet())
            state.remove_edge_and_connectors(producer_edge)
        state.remove_edge_and_connectors(in_gh_edge)
        lib_node.remove_in_connector(group_handle_conn)
        state.remove_node(in_gh_node)

    if group_handle_conn in lib_node.out_connectors:
        out_matches = [e for e in state.out_edges(lib_node) if e.src_conn == group_handle_conn]
        if not out_matches:
            raise ValueError(f'Library node "{lib_node}" has a "{group_handle_conn}" '
                             'output connector without an outgoing edge')
        out_gh_edge = out_matches[-1]
        out_gh_node = out_gh_edge.dst
        if not state.successors(out_gh_node):
            # Nothing reads the handle: this node closes the group, and the
            # handle access node and descriptor are no longer needed.
            code += """ncclGroupEnd();"""
            out_gh_data = out_gh_node.data
            state.remove_edge_and_connectors(out_gh_edge)
            state.remove_node(out_gh_node)
            try:
                sdfg.remove_data(out_gh_data)
            except ValueError as ex:
                # Best effort: report instead of failing the expansion.
                warnings.warn(str(ex))
        # When the handle has consumers it is left in place here. A consumer
        # passed to this function takes the "incoming handle with a producer"
        # path above, which removes the handle node.
        lib_node.remove_out_connector(group_handle_conn)

    # NOTE: a trailing `cudaStreamSynchronize(__dace_current_stream);` used to
    # be appended for nodes without an incoming handle or closing a group; it
    # is currently disabled.
    return code