Exemplo n.º 1
0
    def apply(self, state: SDFGState, sdfg: SDFG):
        """Strip the candidate data containers out of the nested SDFG node.

        ``self._candidates`` is queried with the nested SDFG node and yields
        two things:

        * the candidate names, which are matched against the node's output
          connectors in the outer state and are finally removed from the
          nested SDFG's data descriptors;
        * ``(nested_state, node)`` pairs whose incoming memlet paths are
          removed inside the nested SDFG.

        All ``remove_data`` calls are made with ``validate=False``.

        :param state: The state that contains ``self.nsdfg``.
        :param sdfg: The SDFG from which the outer data named by each removed
                     output edge's memlet is deleted.
        """
        nested_node = self.nsdfg
        cand_names, cand_sites = self._candidates(nested_node)

        # Outside the nested SDFG: drop every memlet path leaving through a
        # candidate connector, then delete the outer container it referenced.
        for edge in state.out_edges(nested_node):
            if edge.src_conn not in cand_names:
                continue
            state.remove_memlet_path(edge)
            sdfg.remove_data(edge.data.data, validate=False)

        # Inside the nested SDFG: cut every memlet path feeding a candidate node.
        for inner_state, inner_node in cand_sites:
            for incoming in inner_state.in_edges(inner_node):
                inner_state.remove_memlet_path(incoming)

        # Finally remove the candidate descriptors from the nested SDFG itself.
        for name in cand_names:
            nested_node.sdfg.remove_data(name, validate=False)
Exemplo n.º 2
0
def aggregate_calls(sdfg: dace.SDFG, state: dace.SDFGState,
                    lib_node: nodes.LibraryNode, code: str):
    """Fold a library node's ``_group_handle`` edges into its generated code.

    Removes the group-handle plumbing around ``lib_node`` from ``state`` and
    returns ``code`` with the NCCL group markers that apply to this call:

    * Input side: if ``lib_node`` has a ``_group_handle`` in-connector, the
      node feeding it (the "handle node") is inspected. If the handle node has
      no predecessors, ``ncclGroupStart();`` is prepended to ``code``.
      Otherwise an empty-memlet edge is added from the handle node's first
      predecessor to ``lib_node``, and the edge into the handle node is
      removed. In both cases the handle edge, the in-connector and the handle
      node are removed.
    * Output side: if ``lib_node`` has a ``_group_handle`` out-connector and
      the node it feeds has no successors, ``ncclGroupEnd();`` is appended.
      That edge, node and its data descriptor are then removed. A
      ``ValueError`` from ``sdfg.remove_data`` is downgraded to a warning.
      The out-connector is removed whether or not the group ends here.

    :param sdfg: SDFG owning ``state``; the terminating handle's data is
                 removed from it.
    :param state: State containing ``lib_node``; modified in place.
    :param lib_node: The library node whose code is being generated.
    :param code: The generated code for this call.
    :return: ``code`` with any ``ncclGroupStart();`` / ``ncclGroupEnd();``
             statements added.
    :raises ValueError: If ``lib_node`` declares a ``_group_handle``
                        connector but no edge is attached to it.
    """
    group_handle_conn = '_group_handle'

    if group_handle_conn in lib_node.in_connectors:
        in_matches = [
            e for e in state.in_edges(lib_node)
            if e.dst_conn == group_handle_conn
        ]
        if not in_matches:
            # Previously this surfaced as an UnboundLocalError.
            raise ValueError(
                f'Library node {lib_node} declares input connector '
                f'"{group_handle_conn}" but no edge is attached to it')
        # If several edges match, the last one is used (as before).
        in_gh_edge = in_matches[-1]
        in_gh_node = in_gh_edge.src

        predecessors = state.predecessors(in_gh_node)
        if not predecessors:
            # Nothing precedes the handle: this call opens the group.
            code = """ncclGroupStart();\n""" + code
        else:
            # Presumably keeps this call ordered after the previous call in
            # the group once the handle node is gone.
            state.add_edge(predecessors[0], None, lib_node, None,
                           dace.Memlet())
            state.remove_edge_and_connectors(state.in_edges(in_gh_node)[0])
        state.remove_edge_and_connectors(in_gh_edge)
        lib_node.remove_in_connector(group_handle_conn)
        state.remove_node(in_gh_node)

    if group_handle_conn in lib_node.out_connectors:
        out_matches = [
            e for e in state.out_edges(lib_node)
            if e.src_conn == group_handle_conn
        ]
        if not out_matches:
            # Previously this surfaced as an UnboundLocalError.
            raise ValueError(
                f'Library node {lib_node} declares output connector '
                f'"{group_handle_conn}" but no edge is attached to it')
        # If several edges match, the last one is used (as before).
        out_gh_edge = out_matches[-1]
        out_gh_node = out_gh_edge.dst

        if not state.successors(out_gh_node):
            # Nothing consumes the handle: this call closes the group.
            code += """ncclGroupEnd();"""
            out_gh_data = out_gh_node.data
            state.remove_edge_and_connectors(out_gh_edge)
            state.remove_node(out_gh_node)
            try:
                sdfg.remove_data(out_gh_data)
            except ValueError as ex:
                warnings.warn(str(ex))
        # NOTE(review): when the handle has successors, out_gh_edge is left in
        # place while the connector is removed. Presumably the next node in
        # the group removes it via the input-side branch above (which deletes
        # the edge into its handle node) -- confirm this ordering holds.
        lib_node.remove_out_connector(group_handle_conn)

    # NOTE(review): an earlier version tracked whether this call ends a group
    # to optionally append "cudaStreamSynchronize(__dace_current_stream);";
    # that was disabled, so the tracking variable has been dropped.
    return code