def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):
    """Expand a non-blocking MPI send node into a C++ tasklet that calls ``MPI_Isend``.

    :param node: The library node being expanded; provides ``validate``,
        ``name``, its connectors and the ``nosync`` flag.
    :param parent_state: State containing ``node``.
    :param parent_sdfg: SDFG containing ``parent_state``.
    :param n: Unused; kept for the expansion interface.
    :return: A CPP ``Tasklet`` with connectors ``_buffer``, ``_dest``,
        ``_tag`` (inputs, as declared on ``node``) and ``_request`` (output).
    :raises NotImplementedError: If the buffer dtype is a vector type
        (``veclen > 1``).
    """
    # Only the buffer descriptor, count expression and optional derived
    # datatype description are needed; the offset, dest, tag and request
    # descriptors returned by validate() are not used here.
    (buffer, count_str, _, ddt), _, _, _ = node.validate(parent_sdfg, parent_state)
    mpi_dtype_str = dace.libraries.mpi.utils.MPI_DDT(buffer.dtype.base_type)

    if buffer.dtype.veclen > 1:
        raise NotImplementedError

    code = ""

    # For GPU-resident buffers, synchronize the current CUDA stream before the
    # send (presumably so that pending device work on the buffer completes
    # before MPI reads it). ``node.nosync`` opts out of this.
    if not node.nosync and buffer.storage == dtypes.StorageType.GPU_Global:
        code += "cudaStreamSynchronize(__dace_current_stream);\n"

    if ddt is not None:
        # Describe the strided layout with MPI_Type_vector. The datatype is
        # `static` and guarded by `init`, so it is created and committed only
        # on the first execution of this tasklet and reused afterwards. For
        # the same reason it is intentionally never freed.
        code += (
            "static MPI_Datatype newtype;\n"
            "static int init=1;\n"
            "if (init) {\n"
            f"    MPI_Type_vector({ddt['count']}, {ddt['blocklen']}, {ddt['stride']}, {ddt['oldtype']}, &newtype);\n"
            "    MPI_Type_commit(&newtype);\n"
            "    init=0;\n"
            "}\n"
        )
        # One element of the derived type covers the whole strided region.
        mpi_dtype_str = "newtype"
        count_str = "1"

    # The frontend already adjusts the `_buffer` pointer to the start of the
    # memlet subset, so pass it unchanged. Indexing it again with the offset
    # returned by validate() would shift the send buffer a second time.
    code += f"MPI_Isend(_buffer, {count_str}, {mpi_dtype_str}, _dest, _tag, MPI_COMM_WORLD, _request);"

    tasklet = dace.sdfg.nodes.Tasklet(node.name,
                                      node.in_connectors,
                                      node.out_connectors,
                                      code,
                                      language=dace.dtypes.Language.CPP)
    # The destination rank is handed to MPI as a plain 32-bit int.
    tasklet.in_connectors = {c: (dtypes.int32 if c == '_dest' else t) for c, t in tasklet.in_connectors.items()}
    # The request output is written through a pointer to an MPI_Request.
    tasklet.out_connectors = {
        c: (dtypes.pointer(dtypes.opaque("MPI_Request")) if c == '_request' else t)
        for c, t in tasklet.out_connectors.items()
    }
    return tasklet
def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):
    """Expand an MPI wait-all node into a C++ tasklet that calls ``MPI_Waitall``.

    :param node: The library node being expanded; ``validate`` returns the
        number of requests to wait on.
    :param parent_state: State containing ``node``.
    :param parent_sdfg: SDFG containing ``parent_state``.
    :param n: Unused; kept for the expansion interface.
    :return: A CPP ``Tasklet`` whose ``_request`` input is typed as a
        pointer to ``MPI_Request``.
    """
    count = node.validate(parent_sdfg, parent_state)

    # The completion statuses are never inspected, so let MPI discard them.
    # A local `MPI_Status _s[count]` array would be a variable-length array
    # (non-standard C++) whenever `count` is not a compile-time constant. It
    # would also be ill-formed for count == 0 and could be large on the stack.
    code = f"MPI_Waitall({count}, _request, MPI_STATUSES_IGNORE);\n"

    tasklet = dace.sdfg.nodes.Tasklet(node.name,
                                      node.in_connectors,
                                      node.out_connectors,
                                      code,
                                      language=dace.dtypes.Language.CPP)
    tasklet.in_connectors = {
        c: (dtypes.pointer(dtypes.opaque("MPI_Request")) if c == '_request' else t)
        for c, t in tasklet.in_connectors.items()
    }
    return tasklet
def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):
    """Lower an MPI wait node to a C++ tasklet that blocks in ``MPI_Wait``.

    The generated code waits on ``_request``, then copies the tag and source
    fields of the resulting status to the ``_stat_tag`` and ``_stat_source``
    outputs.

    :param node: The library node being expanded.
    :param parent_state: State containing ``node``.
    :param parent_sdfg: SDFG containing ``parent_state``.
    :param n: Unused; kept for the expansion interface.
    :return: A CPP ``Tasklet`` whose ``_request`` input is typed as a
        pointer to ``MPI_Request``.
    """
    # validate() must return a pair; its values are not used below.
    _req_desc, _status_desc = node.validate(parent_sdfg, parent_state)

    wait_source = (" MPI_Status _s; MPI_Wait(_request, &_s); "
                   "_stat_tag = _s.MPI_TAG; _stat_source = _s.MPI_SOURCE; ")

    result = dace.sdfg.nodes.Tasklet(node.name,
                                     node.in_connectors,
                                     node.out_connectors,
                                     wait_source,
                                     language=dace.dtypes.Language.CPP)

    # Retype only the request input; all other inputs keep their types.
    typed_inputs = {}
    for conn_name, conn_type in result.in_connectors.items():
        if conn_name == '_request':
            typed_inputs[conn_name] = dtypes.pointer(dtypes.opaque("MPI_Request"))
        else:
            typed_inputs[conn_name] = conn_type
    result.in_connectors = typed_inputs
    return result
def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):
    """Lower a non-blocking MPI receive node to a C++ tasklet calling ``MPI_Irecv``.

    When ``validate`` reports a derived datatype description, a strided MPI
    vector type is built once (guarded by a `static` flag). The receive is
    then issued as a single element of that type.

    :param node: The library node being expanded.
    :param parent_state: State containing ``node``.
    :param parent_sdfg: SDFG containing ``parent_state``.
    :param n: Unused; kept for the expansion interface.
    :return: A CPP ``Tasklet`` whose ``_request`` output is typed as a
        pointer to ``MPI_Request``.
    :raises NotImplementedError: If the buffer dtype is a vector type
        (``veclen > 1``).
    """
    (recv_buf, count_expr, _offset, vector_ddt), _src, _tag = node.validate(parent_sdfg, parent_state)
    dtype_expr = dace.libraries.mpi.utils.MPI_DDT(recv_buf.dtype.base_type)

    if recv_buf.dtype.veclen > 1:
        raise NotImplementedError

    pieces = []
    uses_vector_type = vector_ddt is not None
    if uses_vector_type:
        pieces.append("static MPI_Datatype newtype; static int init=1; if (init) { "
                      f"MPI_Type_vector({vector_ddt['count']}, {vector_ddt['blocklen']}, "
                      f"{vector_ddt['stride']}, {vector_ddt['oldtype']}, &newtype); "
                      "MPI_Type_commit(&newtype); init = 0; } ")
        dtype_expr = "newtype"
        count_expr = "1"

    # `_buffer` is used as-is: the frontend has already moved the pointer to
    # the start of the received region, so no extra offset is applied.
    pieces.append(f"MPI_Irecv(_buffer, {count_expr}, {dtype_expr}, _src, _tag, MPI_COMM_WORLD, _request);")
    if uses_vector_type:
        pieces.append("// MPI_Type_free(&newtype); ")

    result = dace.sdfg.nodes.Tasklet(node.name,
                                     node.in_connectors,
                                     node.out_connectors,
                                     "".join(pieces),
                                     language=dace.dtypes.Language.CPP)

    # Retype only the request output; all other outputs keep their types.
    typed_outputs = {}
    for conn_name, conn_type in result.out_connectors.items():
        if conn_name == '_request':
            typed_outputs[conn_name] = dtypes.pointer(dtypes.opaque("MPI_Request"))
        else:
            typed_outputs[conn_name] = conn_type
    result.out_connectors = typed_outputs
    return result