def testContructor(self):
    # A default-constructed Comm must compare equal to MPI.COMM_NULL while
    # being a distinct object, and a tuple argument must be rejected.
    fresh = MPI.Comm()
    self.assertEqual(fresh, MPI.COMM_NULL)
    self.assertIsNot(fresh, MPI.COMM_NULL)
    with self.assertRaises(TypeError):
        MPI.Comm((1, 2, 3))
def MPIComm_from_ptr(ptr):
    """
    MPIComm_from_ptr(ptr)

    Build an ``MPI.Comm`` object that wraps an existing raw MPI_Comm handle.

    A fresh (null) ``_MPI.Comm`` is created, and its internal handle slot is
    overwritten in place with the given value.

    Args:
        ptr: the raw MPI_Comm handle value; anything accepted by ``int()``.

    Returns:
        A ``_MPI.Comm`` whose underlying handle equals ``int(ptr)``.
    """
    comm = _MPI.Comm()
    # The width of an MPI_Comm handle depends on the implementation (a C int
    # in MPICH-derived libraries, a pointer in Open MPI). Always writing a
    # c_void_p would write past an int-sized handle and corrupt adjacent
    # memory, so choose the ctypes type whose size matches the handle.
    if _MPI._sizeof(_MPI.Comm) == ctypes.sizeof(ctypes.c_int):
        handle_type = ctypes.c_int
    else:
        handle_type = ctypes.c_void_p
    comm_ptr = handle_type.from_address(_MPI._addressof(comm))
    comm_ptr.value = int(ptr)
    return comm
def sum_inplace_jax(x, comm):
    """
    Sum ``x`` across the ranks of ``comm`` with an in-place MPI Allreduce.

    The reduction writes directly into the buffer that backs ``x``, which
    bypasses JAX's immutability. The forced array whose buffer was
    overwritten is returned.

    Args:
        x: a ``jax.interpreters.xla.DeviceArray``.
        comm: the raw MPI_Comm handle value; anything accepted by ``int()``.

    Returns:
        The array produced by ``jax.xla._force(x.block_until_ready())``. Its
        buffer holds the element-wise sum over all ranks.

    Raises:
        TypeError: if ``x`` is not a DeviceArray.
    """
    if not isinstance(x, jax.interpreters.xla.DeviceArray):
        raise TypeError(
            "Argument to sum_inplace_jax must be a DeviceArray, got {}".format(type(x))
        )

    # Wait for any pending computation before touching the raw buffer.
    _x = jax.xla._force(x.block_until_ready())
    # NOTE(review): this pointer is passed to MPI as host memory; presumably
    # this is only valid for CPU-backed arrays. Confirm before use on accelerators.
    ptr = _x.device_buffer.unsafe_buffer_pointer()

    # Rebuild an mpi4py communicator around the raw handle. The width of
    # MPI_Comm depends on the implementation (int in MPICH, pointer in Open
    # MPI), so pick the matching ctypes type to avoid writing past the slot.
    _comm = MPI.Comm()
    if MPI._sizeof(MPI.Comm) == ctypes.sizeof(ctypes.c_int):
        handle_type = ctypes.c_int
    else:
        handle_type = ctypes.c_void_p
    _comm_ptr = handle_type.from_address(MPI._addressof(_comm))
    _comm_ptr.value = int(comm)

    # numpy's ctypeslib is used because jax.numpy has no equivalent. `contents`
    # returns an ndarray view of the buffer (no copy), so MPI reduces in place.
    data_pointer = _np.ctypeslib.ndpointer(x.dtype, shape=x.shape)
    arr = data_pointer(ptr).contents

    _comm.Allreduce(MPI.IN_PLACE, arr, op=MPI.SUM)
    return _x
def testContructor(self):
    # Copy-constructing from self.COMM must yield an equal but distinct object.
    duplicate = MPI.Comm(self.COMM)
    self.assertEqual(duplicate, self.COMM)
    self.assertIsNot(duplicate, self.COMM)
def construct():
    # Pass MPI.Comm an argument that is not a communicator. This is presumably
    # a negative-path callable expected to raise TypeError; confirm with caller.
    not_a_comm = (1, 2, 3)
    MPI.Comm(not_a_comm)
def testContructor(self):
    # MPI.Comm() must return a new object rather than the predefined
    # MPI.COMM_NULL itself, yet still compare equal to it.
    fresh = MPI.Comm()
    self.assertIsNot(fresh, MPI.COMM_NULL)
    self.assertEqual(fresh, MPI.COMM_NULL)
def comm(self):
    """Return an ``MPI.Comm`` wrapping the handle reported by
    ``lib.DDalphaAMG_get_communicator()``."""
    result = MPI.Comm()
    # View the Comm object's internal handle slot as an MPI_Comm* so it can
    # be overwritten with the library's communicator.
    handle_slot = ll.cast["MPI_Comm*"](MPI._addressof(result))
    native_comm = lib.DDalphaAMG_get_communicator()
    ll.assign(handle_slot, native_comm)
    return result