def test_sendrecv_array(self, array):
    """Check that point-to-point send/recv of a NumPy array agrees between
    the plain ``comm`` and the hybrid ``h_comm`` communicators.

    Rank 0 sends ``array`` to rank 1 over both communicators, using distinct
    tags (123 and 456) so the two messages cannot be confused. Rank 1
    receives both and asserts they are numerically equal.
    """
    if not rank:
        comm.send(array, 1, 123)
        h_comm.send(array, 1, 456)
    elif rank == 1:
        # Only rank 1 is the destination of the sends above. Letting every
        # non-zero rank post a blocking recv (plain ``else``) would leave
        # ranks >= 2 waiting forever when run on more than two processes.
        assert np.allclose(comm.recv(None, 0, 123),
                           h_comm.recv(None, 0, 456))
def test_sendrecv_list(self, lst):
    """Check that point-to-point send/recv of a Python list agrees between
    the plain ``comm`` and the hybrid ``h_comm`` communicators.

    Rank 0 sends a shallow copy of ``lst`` to rank 1 over both communicators,
    using distinct tags (123 and 456). Rank 1 receives both messages and
    asserts they are numerically equal.
    """
    if not rank:
        comm.send(list(lst), 1, 123)
        h_comm.send(list(lst), 1, 456)
    elif rank == 1:
        # Only rank 1 is addressed by the sends above; restricting the
        # receive to it keeps ranks >= 2 from blocking forever on a recv
        # that has no matching send.
        assert np.allclose(comm.recv(None, 0, 123),
                           h_comm.recv(None, 0, 456))
def test_gather_list(self, lst):
    """Check that gathering a list onto rank 0 gives the same result through
    ``comm`` and ``h_comm``.

    Every rank takes part in both gathers. All ranks check that the two
    results have exactly the same type. On the root, the gathered
    sequences must also have the same length and element-wise
    numerically equal entries.
    """
    g_lst1 = comm.gather(lst, 0)
    g_lst2 = h_comm.gather(lst, 0)
    # Exact type identity is intended here, not isinstance: the hybrid
    # communicator must return the very same kind of object as ``comm``.
    assert type(g_lst1) is type(g_lst2)
    if not rank:
        # zip() stops at the shorter input, so check the lengths first;
        # otherwise missing or extra contributions would go unnoticed.
        assert len(g_lst1) == len(g_lst2)
        for lst1, lst2 in zip(g_lst1, g_lst2):
            assert np.allclose(lst1, lst2)
def test_gather_array(self, array):
    """Check that gathering a NumPy array onto rank 0 gives the same result
    through ``comm`` and ``h_comm``.

    Every rank takes part in both gathers. All ranks check that the two
    results have exactly the same type. On the root, the gathered
    sequences must also have the same length and element-wise
    numerically equal arrays.
    """
    g_array1 = comm.gather(array, 0)
    g_array2 = h_comm.gather(array, 0)
    # Exact type identity is intended: both communicators must return the
    # same kind of object on every rank.
    assert type(g_array1) is type(g_array2)
    if not rank:
        # Guard against zip() silently truncating a length mismatch.
        assert len(g_array1) == len(g_array2)
        for array1, array2 in zip(g_array1, g_array2):
            assert np.allclose(array1, array2)
def array(self):
    """Build a random array with shape ``(size, 10)``.

    NumPy's global RNG is reseeded with this process's rank first, so the
    data is reproducible per rank but differs between ranks.
    """
    seed = comm.Get_rank()
    np.random.seed(seed)
    data = np.random.rand(size, 10)
    return data
def test_comm_size_unity(self):
    """Check that a single-process communicator maps to the dummy comm.

    ``comm.Split`` is called with each process's own rank as the colour, so
    every process ends up alone in its own communicator.
    ``get_HybridComm_obj`` is expected to return ``d_comm`` (the dummy
    ``COMM_WORLD``) for such a communicator.
    """
    s_comm = comm.Split(comm.Get_rank(), 0)
    try:
        assert get_HybridComm_obj(s_comm) is d_comm
    finally:
        # Release the split communicator even when the assertion fails,
        # so a failing run does not leak it.
        s_comm.Free()
def test_scatter_list(self, lst):
    """Check that scattering a list from rank 0 yields the same chunk on each
    rank through ``comm`` and ``h_comm``.
    """
    # Both scatters are collective; keep the comm-then-h_comm order so all
    # ranks issue them in the same sequence.
    from_comm = comm.scatter(list(lst), 0)
    from_h_comm = h_comm.scatter(list(lst), 0)
    assert np.allclose(from_comm, from_h_comm)
def test_scatter_array(self, array):
    """Check that scattering a NumPy array from rank 0 yields the same chunk
    on each rank through ``comm`` and ``h_comm``.
    """
    # Collective calls: every rank runs comm.scatter before h_comm.scatter.
    from_comm = comm.scatter(array, 0)
    from_h_comm = h_comm.scatter(array, 0)
    assert np.allclose(from_comm, from_h_comm)
def test_bcast_list(self, lst):
    """Check that broadcasting a list from rank 0 gives the same value on
    every rank through ``comm`` and ``h_comm``.
    """
    # Collective calls: every rank runs comm.bcast before h_comm.bcast.
    via_comm = comm.bcast(lst, 0)
    via_h_comm = h_comm.bcast(lst, 0)
    assert np.allclose(via_comm, via_h_comm)
def test_bcast_array(self, array):
    """Check that broadcasting a NumPy array from rank 0 gives the same value
    on every rank through ``comm`` and ``h_comm``.
    """
    # Collective calls: every rank runs comm.bcast before h_comm.bcast.
    via_comm = comm.bcast(array, 0)
    via_h_comm = h_comm.bcast(array, 0)
    assert np.allclose(via_comm, via_h_comm)
# %% IMPORTS # Built-in imports from types import BuiltinMethodType, MethodType # Package imports import numpy as np import pytest # mpi4pyd imports from mpi4pyd import MPI from mpi4pyd.dummyMPI import COMM_WORLD as d_comm from mpi4pyd.MPI import (COMM_WORLD as comm, HYBRID_COMM_WORLD as h_comm, get_HybridComm_obj) # Get size and rank rank = comm.Get_rank() size = comm.Get_size() # Get method types m_types = (BuiltinMethodType, MethodType) # %% PYTEST CLASSES AND FUNCTIONS # Pytest for get_HybridComm_obj() function class Test_get_HybridComm_obj(object): # Test if default input arguments work def test_default(self): assert get_HybridComm_obj() is h_comm # Test if providing comm returns h_comm def test_comm(self):