Пример #1
0
    def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):
        """Expand an MPI Isend library node into a C++ tasklet.

        The generated code posts ``MPI_Isend`` of ``_buffer`` to rank ``_dest``
        with tag ``_tag`` on ``MPI_COMM_WORLD`` and stores the request handle
        through the ``_request`` output connector.

        :param node: The Isend library node being expanded.
        :param parent_state: The state that contains ``node``.
        :param parent_sdfg: The SDFG that contains ``parent_state``.
        :param n: Unused.
        :return: A ``Tasklet`` replacing ``node``.
        :raises NotImplementedError: If the buffer has a vector dtype
                                     (``veclen > 1``).
        """
        (buffer, count_str, _frontend_offset, ddt), _dest, _tag, _req = (
            node.validate(parent_sdfg, parent_state))
        mpi_dtype_str = dace.libraries.mpi.utils.MPI_DDT(buffer.dtype.base_type)
        if buffer.dtype.veclen > 1:
            raise NotImplementedError

        # Unless disabled via ``node.nosync``, wait on the current CUDA stream
        # before sending from a GPU buffer (presumably so that work producing
        # the data has finished -- confirm).
        sync = ""
        if not node.nosync and buffer.storage == dtypes.StorageType.GPU_Global:
            sync = """
            cudaStreamSynchronize(__dace_current_stream);
            """

        # A strided vector datatype replaces the element type. It is created
        # and committed only once (guarded by a static flag), and the whole
        # strided region is sent as a single element of ``newtype``.
        prologue = epilogue = ""
        if ddt is not None:
            prologue = f"""static MPI_Datatype newtype;
                        static int init=1;
                        if (init) {{
                           MPI_Type_vector({ddt['count']}, {ddt['blocklen']}, {ddt['stride']}, {ddt['oldtype']}, &newtype);
                           MPI_Type_commit(&newtype);
                           init=0;
                        }}
                            """
            # Freeing stays commented out in the generated code; the static
            # type is created once and reused by later executions.
            epilogue = """// MPI_Type_free(&newtype);
            """
            mpi_dtype_str = "newtype"
            count_str = "1"

        # The offset returned by validate() is deliberately not applied; the
        # send always starts at element 0. Presumably the frontend has already
        # advanced the buffer pointer -- TODO confirm.
        send_call = (f"MPI_Isend(&(_buffer[0]), {count_str}, {mpi_dtype_str}, "
                     "_dest, _tag, MPI_COMM_WORLD, _request);")
        code = sync + prologue + send_call + epilogue

        tasklet = dace.sdfg.nodes.Tasklet(node.name,
                                          node.in_connectors,
                                          node.out_connectors,
                                          code,
                                          language=dace.dtypes.Language.CPP)

        def retyped(connectors, target, make_type):
            # Copy of ``connectors`` where only ``target`` gets a new type.
            return {c: (make_type() if c == target else t)
                    for c, t in connectors.items()}

        # MPI takes the destination rank as an ``int`` and writes the request
        # through an ``MPI_Request*``.
        tasklet.in_connectors = retyped(tasklet.in_connectors, '_dest',
                                        lambda: dtypes.int32)
        tasklet.out_connectors = retyped(
            tasklet.out_connectors, '_request',
            lambda: dtypes.pointer(dtypes.opaque("MPI_Request")))
        return tasklet
Пример #2
0
 def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):
     """Expand an MPI Waitall library node into a C++ tasklet.

     The generated code calls ``MPI_Waitall`` on the ``count`` requests held
     in the ``_request`` array, blocking until all of them have completed.

     :param node: The Waitall library node being expanded.
     :param parent_state: The state that contains ``node``.
     :param parent_sdfg: The SDFG that contains ``parent_state``.
     :param n: Unused.
     :return: A ``Tasklet`` replacing ``node``.
     """
     count = node.validate(parent_sdfg, parent_state)
     # The completion statuses were never read, so let MPI discard them rather
     # than declaring ``MPI_Status _s[count]``. If ``count`` is a runtime size,
     # that declaration is a variable-length array. VLAs are not standard C++,
     # and this one lives on the stack, so it could overflow for large request
     # counts.
     code = f"""
         MPI_Waitall({count}, _request, MPI_STATUSES_IGNORE);
         """
     tasklet = dace.sdfg.nodes.Tasklet(node.name,
                                       node.in_connectors,
                                       node.out_connectors,
                                       code,
                                       language=dace.dtypes.Language.CPP)
     # ``_request`` is the array of pending requests, passed as ``MPI_Request*``.
     tasklet.in_connectors = {
         c: (dtypes.pointer(dtypes.opaque("MPI_Request")) if c == '_request' else t)
         for c, t in tasklet.in_connectors.items()
     }
     return tasklet
Пример #3
0
 def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):
     """Expand an MPI Wait library node into a C++ tasklet.

     The generated code waits on the request behind ``_request``. It then
     copies the tag and source rank from the completion status into the
     ``_stat_tag`` and ``_stat_source`` connectors.

     :param node: The Wait library node being expanded.
     :param parent_state: The state that contains ``node``.
     :param parent_sdfg: The SDFG that contains ``parent_state``.
     :param n: Unused.
     :return: A ``Tasklet`` replacing ``node``.
     """
     # validate() is called for its checks; the two values it returns are not
     # needed to build the code.
     _req_desc, _status_desc = node.validate(parent_sdfg, parent_state)
     wait_code = """
         MPI_Status _s;
         MPI_Wait(_request, &_s);
         _stat_tag = _s.MPI_TAG;
         _stat_source = _s.MPI_SOURCE;
         """
     tasklet = dace.sdfg.nodes.Tasklet(node.name,
                                       node.in_connectors,
                                       node.out_connectors,
                                       wait_code,
                                       language=dace.dtypes.Language.CPP)
     # MPI_Wait takes the request by pointer, so retype that connector only.
     retyped = {}
     for name, ctype in tasklet.in_connectors.items():
         if name == '_request':
             retyped[name] = dtypes.pointer(dtypes.opaque("MPI_Request"))
         else:
             retyped[name] = ctype
     tasklet.in_connectors = retyped
     return tasklet
Пример #4
0
    def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):
        """Expand an MPI Irecv library node into a C++ tasklet.

        The generated code posts ``MPI_Irecv`` into ``_buffer`` from rank
        ``_src`` with tag ``_tag`` on ``MPI_COMM_WORLD`` and stores the request
        handle through the ``_request`` output connector.

        :param node: The Irecv library node being expanded.
        :param parent_state: The state that contains ``node``.
        :param parent_sdfg: The SDFG that contains ``parent_state``.
        :param n: Unused.
        :return: A ``Tasklet`` replacing ``node``.
        :raises NotImplementedError: If the buffer has a vector dtype
                                     (``veclen > 1``).
        """
        (buffer, count_str, _frontend_offset, ddt), _src, _tag = (
            node.validate(parent_sdfg, parent_state))
        mpi_dtype_str = dace.libraries.mpi.utils.MPI_DDT(buffer.dtype.base_type)
        if buffer.dtype.veclen > 1:
            raise NotImplementedError

        # A strided vector datatype replaces the element type. It is created
        # and committed only once (guarded by a static flag), and the whole
        # strided region is received as a single element of ``newtype``.
        prologue = epilogue = ""
        if ddt is not None:
            prologue = f"""static MPI_Datatype newtype;
                        static int init=1;
                        if (init) {{
                           MPI_Type_vector({ddt['count']}, {ddt['blocklen']}, {ddt['stride']}, {ddt['oldtype']}, &newtype);
                           MPI_Type_commit(&newtype);
                           init = 0;
                        }}
                            """
            # Freeing stays commented out in the generated code; the static
            # type is created once and reused by later executions.
            epilogue = """// MPI_Type_free(&newtype);
            """
            mpi_dtype_str = "newtype"
            count_str = "1"

        # No offset is applied to ``_buffer`` because the frontend already
        # changes the pointer; the offset from validate() goes unused.
        recv_call = (f"MPI_Irecv(_buffer, {count_str}, {mpi_dtype_str}, "
                     "_src, _tag, MPI_COMM_WORLD, _request);")
        code = prologue + recv_call + epilogue

        tasklet = dace.sdfg.nodes.Tasklet(node.name,
                                          node.in_connectors,
                                          node.out_connectors,
                                          code,
                                          language=dace.dtypes.Language.CPP)

        # MPI writes the request handle through an ``MPI_Request*``.
        retyped = {}
        for name, ctype in tasklet.out_connectors.items():
            if name == '_request':
                retyped[name] = dtypes.pointer(dtypes.opaque("MPI_Request"))
            else:
                retyped[name] = ctype
        tasklet.out_connectors = retyped
        return tasklet