コード例 #1
0
ファイル: test_comm.py プロジェクト: mpi4py/mpi4py
    def testContructor(self):
        comm = MPI.Comm()
        self.assertEqual(comm, MPI.COMM_NULL)
        self.assertFalse(comm is MPI.COMM_NULL)

        def construct():
            MPI.Comm((1, 2, 3))

        self.assertRaises(TypeError, construct)
コード例 #2
0
ファイル: utils.py プロジェクト: kiminh/mpi4jax
def MPIComm_from_ptr(ptr):
    """
    MPIComm_from_ptr(ptr)

    Constructs a MPI Comm object from a pointer
    """
    comm = _MPI.Comm()
    comm_ptr = ctypes.c_void_p.from_address(_MPI._addressof(comm))
    comm_ptr.value = int(ptr)
    return comm
コード例 #3
0
def sum_inplace_jax(x, comm):
    if not isinstance(x, jax.interpreters.xla.DeviceArray):
        raise TypeError("Argument to sum_inplace_jax must be a DeviceArray, got {}"
                        .format(type(x)))

    _x = jax.xla._force(x.block_until_ready())
    ptr = _x.device_buffer.unsafe_buffer_pointer()

    # rebuild comm
    _comm = MPI.Comm()
    _comm_ptr = ctypes.c_void_p.from_address(MPI._addressof(_comm))
    _comm_ptr.value = int(comm)

    # using native numpy because jax's numpy does not have ctypeslib
    data_pointer = _np.ctypeslib.ndpointer(x.dtype, shape=x.shape)

    # wrap jax data into a standard numpy array which is handled by MPI
    arr = data_pointer(ptr).contents

    _comm.Allreduce(MPI.IN_PLACE, arr, op=MPI.SUM)

    return _x
コード例 #4
0
ファイル: test_comm.py プロジェクト: mpi4py/mpi4py
 def testContructor(self):
     comm = MPI.Comm(self.COMM)
     self.assertEqual(comm, self.COMM)
     self.assertFalse(comm is self.COMM)
コード例 #5
0
ファイル: test_comm.py プロジェクト: mpi4py/mpi4py
 def construct():
     MPI.Comm((1, 2, 3))
コード例 #6
0
 def testContructor(self):
     comm = MPI.Comm()
     self.assertFalse(comm is MPI.COMM_NULL)
     self.assertEqual(comm, MPI.COMM_NULL)
コード例 #7
0
ファイル: solver.py プロジェクト: michelemartone/lyncs
 def comm(self):
     "Returns the MPI communicator used by the library."
     comm = MPI.Comm()
     comm_ptr = ll.cast["MPI_Comm*"](MPI._addressof(comm))
     ll.assign(comm_ptr, lib.DDalphaAMG_get_communicator())
     return comm