Esempio n. 1
0
 def testIReadIWriteAll(self):
     # Round-trip check of the nonblocking collective Iwrite_all/Iread_all
     # pair: each rank writes 13 items of value 42 after seeking to
     # 13 * rank, then reads them back into a buffer that has one extra
     # trailing slot.
     world = self.COMM
     size = world.Get_size()  # NOTE(review): not used below
     me = world.Get_rank()
     handle = self.FILE
     nitems = 13
     try:  # MPI 3.1
         for array in arrayimpl.ArrayTypes:
             for typecode in arrayimpl.TypeMap:
                 etype = arrayimpl.TypeMap[typecode]
                 # Start from an empty file viewed as a sequence of etype.
                 handle.Set_size(0)
                 handle.Set_view(0, etype)
                 position = nitems * me
                 outgoing = array(42, typecode, nitems)
                 handle.Seek(position, MPI.SEEK_SET)
                 write_request = handle.Iwrite_all(outgoing.as_raw())
                 write_request.Wait()
                 handle.Sync()
                 world.Barrier()
                 handle.Sync()
                 # The extra last slot is pre-filled with -1 and must still
                 # hold -1 after the read.
                 incoming = array(-1, typecode, nitems + 1)
                 handle.Seek(position, MPI.SEEK_SET)
                 read_request = handle.Iread_all(incoming.as_mpi_c(nitems))
                 read_request.Wait()
                 for value in incoming[:-1]:
                     self.assertEqual(value, 42)
                 self.assertEqual(incoming[-1], -1)
                 world.Barrier()
     except NotImplementedError:
         # A missing implementation is an error from MPI 3.1 onwards;
         # on older versions the test is skipped instead.
         if MPI.Get_version() >= (3, 1):
             raise
         self.skipTest("mpi-iwrite_all")
Esempio n. 2
0
 def testCreateEnv(self):
     # Exercise MPI.Info.Create_env(): query every key in self.KEYS, check
     # that a duplicate reports the same values, then call Create_env with
     # several argument spellings (positional and keyword).
     try:
         env = MPI.Info.Create_env()
     except NotImplementedError:
         # A missing implementation is an error from MPI 4.0 onwards;
         # on older versions the test is skipped instead.
         if MPI.Get_version() >= (4, 0):
             raise
         raise unittest.SkipTest("mpi-info-create-env")
     try:
         # The lookups run inside the try block so that env is freed even
         # when one of them raises.
         for key in self.KEYS:
             env.Get(key)
         dup = env.Dup()
         try:
             for key in self.KEYS:
                 self.assertEqual(env.Get(key), dup.Get(key))
         finally:
             dup.Free()
     finally:
         env.Free()
     for args in (
         None, [], (),
         sys.executable,
         [sys.executable],
         (sys.executable,),
     ):
         MPI.Info.Create_env(args).Free()
         MPI.Info.Create_env(args=args).Free()
Esempio n. 3
0
def getlibraryinfo():
    """Return a short description of the MPI version and, if known, its vendor.

    The result has the form ``"MPI X.Y"``; when ``MPI.get_vendor()`` reports
    a name other than ``"unknown"``, ``" (NAME A.B.C)"`` is appended.
    """
    from mpi4py import MPI
    pieces = ["MPI %d.%d" % MPI.Get_version()]
    name, version = MPI.get_vendor()
    if name != "unknown":
        vendor_version = '%d.%d.%d' % version
        pieces.append(" (%s %s)" % (name, vendor_version))
    return "".join(pieces)
Esempio n. 4
0
 def testSetCancelled(self):
     # Setting the cancelled flag on the status must be visible through
     # Is_cancelled().
     try:
         status = self.STATUS
         status.Set_cancelled(True)
         self.assertTrue(status.Is_cancelled())
     except NotImplementedError:
         # Tolerated only when the MPI version is older than 2.0.
         if MPI.Get_version() >= (2, 0):
             raise
Esempio n. 5
0
 def testIReadIWriteAtAll(self):
     # Round-trip check of the nonblocking collective
     # Iwrite_at_all/Iread_at_all pair: each rank writes 13 items of value
     # 42 at explicit offset 13 * rank and reads them back into a buffer
     # that has one extra trailing slot.
     world = self.COMM
     size = world.Get_size()  # NOTE(review): not used below
     me = world.Get_rank()
     handle = self.FILE
     for array, typecode in arrayimpl_loop_io():
         with arrayimpl.test(self):
             try:  # MPI 3.1
                 etype = array.TypeMap[typecode]
                 # Start from an empty file viewed as a sequence of etype.
                 handle.Set_size(0)
                 handle.Set_view(0, etype)
                 nitems = 13
                 outgoing = array(42, typecode, nitems)
                 write_request = handle.Iwrite_at_all(
                     nitems * me, outgoing.as_raw()
                 )
                 write_request.Wait()
                 handle.Sync()
                 world.Barrier()
                 handle.Sync()
                 # The extra last slot is pre-filled with -1 and must still
                 # hold -1 after the read.
                 incoming = array(-1, typecode, nitems + 1)
                 read_request = handle.Iread_at_all(
                     nitems * me, incoming.as_mpi_c(nitems)
                 )
                 read_request.Wait()
                 for value in incoming[:-1]:
                     self.assertEqual(value, 42)
                 self.assertEqual(incoming[-1], -1)
                 world.Barrier()
             except NotImplementedError:
                 # A missing implementation is an error from MPI 3.1
                 # onwards; on older versions the test is skipped instead.
                 if MPI.Get_version() >= (3, 1):
                     raise
                 self.skipTest('mpi-iwrite_at_all')
Esempio n. 6
0
 def assertRaisesMPI(self, IErrClass, callableObj, *args, **kwargs):
     """Fail unless the call raises ``MPI.Exception`` with an expected class.

     ``callableObj(*args, **kwargs)`` is invoked and must raise
     ``MPI.Exception`` whose ``Get_error_class()`` equals ``IErrClass`` or,
     when ``IErrClass`` is a list or tuple, is one of its members.

     A ``NotImplementedError`` raised by the call is tolerated only when the
     MPI version is older than 2.0; from 2.0 onwards it is reported as a
     failure.

     Raises ``self.failureException`` when no ``MPI.Exception`` is raised,
     when the error class does not match, or when ``NotImplementedError``
     is raised on MPI >= 2.0.
     """
     from mpi4py import MPI
     excClass = MPI.Exception
     try:
         callableObj(*args, **kwargs)
     except NotImplementedError:
         # The comparison used to be "< (2, 0)", which let missing
         # functionality pass silently on every modern MPI and failed only
         # on pre-2.0 implementations, where it is legitimately absent.
         if MPI.Get_version() >= (2, 0):
             raise self.failureException("raised NotImplementedError")
     except excClass as excValue:
         error_class = excValue.Get_error_class()
         several = isinstance(IErrClass, (list, tuple))
         if several:
             match = (error_class in IErrClass)
         else:
             match = (error_class == IErrClass)
         if not match:
             if several:
                 # Keep the container type of the caller's argument so the
                 # message mirrors what was passed in.
                 IErrClassName = [ErrClsName(e) for e in IErrClass]
                 IErrClassName = type(IErrClass)(IErrClassName)
             else:
                 IErrClassName = ErrClsName(IErrClass)
             raise self.failureException(
                 "generated error class is '{}' ({}), "
                 "but expected '{}' ({})".format(
                     ErrClsName(error_class),
                     error_class,
                     IErrClassName,
                     IErrClass,
                 ))
     else:
         raise self.failureException(f"{excClass.__name__} not raised")
Esempio n. 7
0
 def testISendrecv(self):
     # Ring exchange with Isendrecv: each rank sends to rank + 1 and
     # receives from rank - 1 (both modulo the communicator size).
     comm = self.COMM
     nprocs = comm.Get_size()
     me = comm.Get_rank()
     right = (me + 1) % nprocs
     left = (me - 1) % nprocs
     try:
         # One-byte exchange first, to find out whether Isendrecv exists.
         probe = comm.Isendrecv(
             bytearray(1), right, 0,
             bytearray(1), left, 0,
         )
         probe.Wait()
     except NotImplementedError:
         # A missing implementation is an error from MPI 4.0 onwards;
         # on older versions the test is skipped instead.
         if MPI.Get_version() >= (4, 0):
             raise
         raise unittest.SkipTest("mpi-isendrecv")
     for array, typecode in arrayimpl.loop():
         with arrayimpl.test(self):
             for length in range(nprocs + 1):
                 with self.subTest(s=length):
                     outgoing = array(length, typecode, length)
                     # One slot more than is sent; the last one is
                     # pre-filled with -1 and must stay -1.
                     incoming = array(-1, typecode, length + 1)
                     request = comm.Isendrecv(
                         outgoing.as_mpi(), right, 0,
                         incoming.as_mpi(), left, 0,
                     )
                     request.Wait()
                     for value in incoming[:-1]:
                         self.assertEqual(value, length)
                     self.assertEqual(incoming[-1], -1)
Esempio n. 8
0
    def pytest_terminal_summary(self, terminalreporter, exitstatus, *args):
        """
        Pytest hook that appends an "MPI Information" section to the report.

        Nothing is written unless ``self._is_testing_mpi`` is true.  When
        mpi4py cannot be imported a single notice is written; otherwise the
        rank, size, MPI version, library version, vendor, and the mpi4py
        ``rc`` and build configuration entries are listed.
        """
        # pylint: disable=unused-argument
        if not self._is_testing_mpi:
            return

        terminalreporter.section("MPI Information")
        try:
            from mpi4py import MPI, rc, get_config
        except ImportError:
            terminalreporter.write("Unable to import mpi4py")
            return

        emit = terminalreporter.write
        world = MPI.COMM_WORLD
        emit("rank: {}\n".format(world.rank))
        emit("size: {}\n".format(world.size))

        mpi_version = '.'.join(map(str, MPI.Get_version()))
        emit("MPI version: {}\n".format(mpi_version))
        emit("MPI library version: {}\n".format(MPI.Get_library_version()))

        vendor, vendor_version = MPI.get_vendor()
        emit("MPI vendor: {} {}\n".format(
            vendor, '.'.join(map(str, vendor_version))
        ))

        emit("mpi4py rc: \n")
        for name, value in vars(rc).items():
            emit(" {}: {}\n".format(name, value))

        emit("mpi4py config:\n")
        for name, value in get_config().items():
            emit(" {}: {}\n".format(name, value))
Esempio n. 9
0
 def testGetVersion(self):
     # MPI.Get_version() must return a (major, minor) pair of plain ints
     # with major >= 1 and minor >= 0.
     version = MPI.Get_version()
     self.assertEqual(len(version), 2)
     major, minor = version
     for component, lower_bound in ((major, 1), (minor, 0)):
         self.assertTrue(type(component) is int)
         self.assertTrue(component >= lower_bound)
Esempio n. 10
0
 def testSetElements(self):
     # After Set_elements(MPI.BYTE, 7) both Get_count and Get_elements
     # must report 7 for MPI.BYTE.
     expected = 7
     try:
         status = self.STATUS
         status.Set_elements(MPI.BYTE, expected)
         self.assertEqual(status.Get_count(MPI.BYTE), expected)
         self.assertEqual(status.Get_elements(MPI.BYTE), expected)
     except NotImplementedError:
         # Tolerated only when the MPI version is older than 2.0.
         if MPI.Get_version() >= (2, 0):
             raise
Esempio n. 11
0
def getlibraryinfo():
    """Return a short description of the MPI version and, if known, its vendor.

    The result has the form ``"MPI X.Y"``; when ``MPI.get_vendor()`` reports
    a name other than ``"unknown"``, ``" (NAME A.B.C)"`` is appended.
    """
    from mpi4py import MPI
    major, minor = MPI.Get_version()
    summary = f"MPI {major}.{minor}"
    name, version = MPI.get_vendor()
    if name == "unknown":
        return summary
    v_major, v_minor, v_patch = version
    return summary + f" ({name} {v_major}.{v_minor}.{v_patch})"
Esempio n. 12
0
def mpi_predicate(predicate):
    """Return the parsed version predicate if the running MPI satisfies it.

    ``predicate`` is normalised (spaces, slashes and dashes removed,
    "Microsoft" shortened to "MS", lower-cased) and parsed with
    ``VersionPredicate``.  A predicate named ``mpi`` is checked against
    ``MPI.Get_version()`` padded with a trailing 0; any other name is
    checked against ``MPI.get_vendor()``.  Returns the ``VersionPredicate``
    on a match and ``None`` otherwise.
    """
    from mpi4py import MPI

    def key(s):
        for old, new in ((' ', ''), ('/', ''), ('-', ''), ('Microsoft', 'MS')):
            s = s.replace(old, new)
        return s.lower()

    vp = VersionPredicate(key(predicate))
    if vp.name == 'mpi':
        name = 'mpi'
        version = MPI.Get_version() + (0,)
    else:
        name, version = MPI.get_vendor()
    if vp.name != key(name):
        return None
    if vp.satisfied_by('%d.%d.%d' % version):
        return vp
    return None
Esempio n. 13
0
def skipMPI(predicate, *conditions):
    """Return a unittest skip decorator when the running MPI matches.

    ``predicate`` is normalised (spaces, slashes and dashes removed,
    "Microsoft" shortened to "MS", lower-cased) and parsed with
    ``VersionPredicate``.  A predicate named ``mpi`` is checked against
    ``MPI.Get_version()`` padded with a trailing 0; any other name is
    checked against ``MPI.get_vendor()``.  If it matches and either no
    ``conditions`` were given or at least one is true, ``unittest.skip``
    is returned; otherwise a ``skipIf(False, '')`` decorator that never
    skips.
    """
    from mpi4py import MPI

    def key(s):
        for old, new in ((' ', ''), ('/', ''), ('-', ''), ('Microsoft', 'MS')):
            s = s.replace(old, new)
        return s.lower()

    vp = VersionPredicate(key(predicate))
    if vp.name == 'mpi':
        name = 'mpi'
        version = MPI.Get_version() + (0,)
    else:
        name, version = MPI.get_vendor()
    matched = vp.name == key(name) and vp.satisfied_by('%d.%d.%d' % version)
    if matched and (not conditions or any(conditions)):
        return unittest.skip(str(vp))
    return unittest.skipIf(False, '')
Esempio n. 14
0
    def get_intro_string(self):
        """Build the Processor fabric text appended to the relax introduction string.

        @return:    The string describing this Processor fabric.
        @rtype:     str
        """

        # The MPI version pair reported by mpi4py.
        mpi_version = MPI.Get_version()

        # The vendor name and its dotted version string.
        vendor = MPI.get_vendor()
        vendor_name = vendor[0]
        numbers = vendor[1]
        vendor_version = str(numbers[0]) + ''.join('.%i' % number for number in numbers[1:])

        # Assemble and return the description.
        return "MPI %s.%s running via mpi4py with %i slave processors & 1 master.  Using %s %s." % (mpi_version[0], mpi_version[1], self.processor_size(), vendor_name, vendor_version)
Esempio n. 15
0
 def testProcNullISendrecv(self):
     # Isendrecv and Isendrecv_replace must accept MPI.PROC_NULL as both
     # peer ranks together with None buffers.
     try:
         null = MPI.PROC_NULL
         request = self.COMM.Isendrecv(None, null, 0, None, null, 0)
         request.Wait()
         request = self.COMM.Isendrecv_replace(None, null, 0, null, 0)
         request.Wait()
     except NotImplementedError:
         # A missing implementation is an error from MPI 4.0 onwards;
         # on older versions the test is skipped instead.
         if MPI.Get_version() >= (4, 0):
             raise
         raise unittest.SkipTest("mpi-isendrecv")
Esempio n. 16
0
def check_mpi():
    """
    Print a short table of MPI-related facts, on rank 0 only.

    Example output when called via::

        # python3 -m netket.tools.check_mpi
        mpi_available                : True
        mpi4jax_available            : True
        avalable_cpus (rank 0)       : 12
        n_nodes                      : 1
        mpi4py | MPI version         : (3, 1)
        mpi4py | MPI library_version : Open MPI v4.1.0, ...

    The node count and the mpi4py version rows are only included when
    ``mpi_available`` is true.  The table lets users check whether the
    environment has been set up correctly.
    """
    if rank > 0:
        return

    info = {}
    info["mpi_available"] = mpi_available
    info["mpi4jax_available"] = mpi4jax_available
    info["avalable_cpus (rank 0)"] = available_cpus()

    if mpi_available:
        from mpi4py import MPI

        info["n_nodes"] = n_nodes
        info["mpi4py | MPI version"] = MPI.Get_version()
        info["mpi4py | MPI library_version"] = MPI.Get_library_version()

    # Pad every label to the longest one so the colons line up.
    width = max(map(len, info))

    for label, value in info.items():
        print(f"{label:{width}} : {value}")
Esempio n. 17
0
def call(func, *args, **kwargs):
    """Distribute function call to MPI processes via mpi4py.futures serialization/deserialization mechanism.

    ``func(*args, **kwargs)`` is submitted ``size - 1`` times to an
    ``MPIPoolExecutor`` (``size`` being the size of ``MPI.COMM_WORLD``) and
    is additionally executed once in the calling process.

    Returns the value produced by the local call (``None`` if it returned
    nothing).  Results of the submitted jobs are not collected.
    """
    logger = logging.getLogger(__name__)
    comm = MPI.COMM_WORLD

    with MPIPoolExecutor() as executor:
        size = comm.Get_size()
        # Arguments are passed separately so the message is only formatted
        # when debug logging is actually enabled.
        logger.debug("MPI version: %s", str(MPI.Get_version()))
        logger.debug("Current MPI size is: %s", size)
        universe_size = comm.Get_attr(MPI.UNIVERSE_SIZE)
        logger.debug("MPI universe size is: %s", universe_size)

        # distribute payload to rank 1 to MPI_UNIVERSE_SIZE
        # NOTE(review): the returned futures are discarded, so an exception
        # raised inside a submitted job is never surfaced here -- confirm
        # that this is intended.
        for _ in range(size - 1):
            executor.submit(func, *args, **kwargs)
        # also run payload on rank 0
        data = func(*args, **kwargs)

    # function may have return value on rank 0
    return data
Esempio n. 18
0
# Decide, per MPI vendor, whether the process-pool tests are skipped.
# MPICH and MPICH2: skip when COMM_WORLD has no APPNUM attribute.
if name in ('MPICH', 'MPICH2'):
    if MPI.COMM_WORLD.Get_attr(MPI.APPNUM) is None:
        SKIP_POOL_TEST = True
# MVAPICH2 and Platform MPI: always skip.
if name in ('MVAPICH2', 'Platform MPI'):
    SKIP_POOL_TEST = True
# Microsoft MPI: skip before 8.1.0, or when APPNUM is missing.
if name == 'Microsoft MPI':
    if version < (8, 1, 0):
        SKIP_POOL_TEST = True
    if MPI.COMM_WORLD.Get_attr(MPI.APPNUM) is None:
        SKIP_POOL_TEST = True
# Any vendor: skip when the MPI version is older than 2.0.
if MPI.Get_version() < (2, 0):
    SKIP_POOL_TEST = True

# Remove the tests (or whole test classes) that do not apply to this run.
if SHARED_POOL:
    del MPICommExecutorTest.test_arg_root
    del MPICommExecutorTest.test_arg_comm_bad
    del ProcessPoolInitTest.test_init_globals
    if WORLD_SIZE == 1:
        del ProcessPoolInitTest.test_run_name
        del ProcessPoolPickleTest
elif WORLD_SIZE > 1 or SKIP_POOL_TEST:
    del ProcessPoolInitTest
    del ProcessPoolBootupTest
    del ProcessPoolShutdownTest
    del ProcessPoolWaitTest
    del ProcessPoolAsCompletedTest
Esempio n. 19
0
from mpi4py import MPI

# Each rank prints a greeting, waits at a barrier, then prints "Done!".
comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()
hostname = MPI.Get_processor_name()
ver = MPI.Get_version()
# NOTE(review): the library version is queried and then immediately replaced
# by an empty string, so the "Lib" field below always prints empty -- confirm
# whether that is intentional.
lib_ver = MPI.Get_library_version()
lib_ver = ""

where = (rank, size, hostname)
details = where + (ver, lib_ver)

try:
    print("Hello World! ({}/{}: {}; Ver{}; Lib{})".format(*details))
    comm.Barrier()
    print("Done! ({}/{}: {})".format(*where))
except Exception as ex:
    # Report the failure together with the same identifying details.
    print("{} ({}/{}: {}; Ver{}; Lib{})".format(ex, *details))
Esempio n. 20
0
class MPIEnvironment:
    """Class-level snapshot of the MPI runtime, computed when the class body runs.

    MPI is only loaded when ``OMPI_COMM_WORLD_RANK`` is present in the
    environment.  On success the class attributes hold the world size, the
    processor rank, the pid, version/vendor information and the thread
    support level, and ``mpi_initialized`` is True.

    If the variable is absent, or anything in the MPI setup raises, the
    class falls back to serial defaults (world size 1, rank 0, empty
    version strings, thread level -1), ``mpi_initialized`` stays False and
    the formatted traceback is kept in ``mpi_initialization_error_msg``.
    """

    # Static variables #################################################################################################

    # Set hostname
    hostname = socket.gethostname()

    # Initialization
    mpi_initialized = False
    try:
        # don't load mpi unless we are already running under mpi
        # trying to load a broken mpi installation will abort the process not
        # giving us a chance to run in the serial mode
        # testing mpi via a forked import causes deadlock on process end when
        # running test_mpi4casa[test_server_not_responsive]
        if 'OMPI_COMM_WORLD_RANK' not in os.environ:
            raise ValueError('MPI disabled')

        # Set mpi4py runtime configuration
        from mpi4py import rc as __mpi_runtime_config
        # Automatic MPI initialization at import time
        __mpi_runtime_config.initialize = True
        # Request for thread support at MPI initialization
        __mpi_runtime_config.threaded = True
        # Level of thread support to request at MPI initialization
        # "single" : use MPI_THREAD_SINGLE
        # "funneled" : use MPI_THREAD_FUNNELED
        # "serialized" : use MPI_THREAD_SERIALIZED
        # "multiple" : use MPI_THREAD_MULTIPLE
        __mpi_runtime_config.thread_level = 'multiple'
        # Automatic MPI finalization at exit time
        __mpi_runtime_config.finalize = False

        # Import mpi4py and thus initialize MPI
        from mpi4py import MPI as __mpi_factory  # NOTE: This is a private variable to avoid uncontrolled access to MPI

        # Get world size and processor rank
        mpi_world_size = __mpi_factory.COMM_WORLD.Get_size()
        mpi_processor_rank = __mpi_factory.COMM_WORLD.Get_rank()

        # Get pid
        # NOTE(review): mpi_pid is only defined on this success path; the
        # serial fallback below does not set it -- confirm that is intended.
        mpi_pid = os.getpid()

        # Get version and vendor info
        mpi_version_info = __mpi_factory.Get_version()
        mpi_vendor_info = __mpi_factory.get_vendor()
        mpi_thread_safe_level = __mpi_factory.Query_thread()

        # Prepare version info string (dotted, e.g. "3.1")
        mpi_version_str = ".".join(map(str, mpi_version_info))

        # Prepare vendor info string ("<name> v<dotted version>")
        mpi_vendor_str = str(mpi_vendor_info[0])
        if len(mpi_vendor_info) > 1:
            mpi_vendor_version = mpi_vendor_info[1]
            mpi_vendor_version_str = ".".join(map(str, mpi_vendor_version))
            mpi_vendor_str = mpi_vendor_str + " v" + mpi_vendor_version_str

        # Set initialization flag
        mpi_initialized = True
    except Exception:
        # Deliberately broad: any failure while setting up MPI must leave
        # the class usable in serial mode.  The previous
        # "except Exception, instance" spelling is a syntax error on
        # Python 3 and bound a name that was never used.
        mpi_initialization_error_msg = traceback.format_exc()
        __mpi_factory = None
        mpi_world_size = 1
        mpi_processor_rank = 0
        mpi_version_info = None
        mpi_vendor_info = None
        mpi_version_str = ""
        mpi_vendor_str = ""
        mpi_vendor_version_str = ""
        mpi_thread_safe_level = -1
        mpi_initialized = False
Esempio n. 21
0
def info():
    """
    Print NetKet diagnostic information to standard output.

    The report lists, in order: the NetKet version, the Python
    implementation, host and CPU details, the versions of NetKet's
    dependencies and, for each of jax, mpi4jax and mpi4py that is
    available, a dedicated section (jax backends and devices, the mpi4jax
    GPU-extension flag, and the MPI compiler, link flags and versions).
    It lets users check whether the environment has been set up correctly.
    """
    print("====================================================")
    print("==         NetKet Diagnostic Informations         ==")
    print("====================================================")

    # try to import version without import netket itself
    from .. import _version

    printfmt("NetKet version", _version.version)
    print()

    print("# Python")
    python_facts = (
        ("implementation", platform.python_implementation),
        ("version", platform.python_version),
        ("distribution", platform.python_compiler),
    )
    for label, getter in python_facts:
        printfmt(label, getter(), indent=1)
    printfmt("path", sys.executable, indent=1)
    print()

    # Try to detect platform
    print("# Host informations")
    printfmt("System      ", platform.platform(), indent=1)
    printfmt("Architecture", platform.machine(), indent=1)

    # Try to query cpu info
    platform_info = cpu_info()

    printfmt("AVX", platform_info["supports_avx"], indent=1)
    printfmt("AVX2", platform_info["supports_avx2"], indent=1)
    # The core count may be stored under any of these keys; report the
    # first one that is present.
    for cores_key in ("cpu cores", "cpu_cores", "core_count"):
        if cores_key in platform_info:
            printfmt("Cores", platform_info[cores_key], indent=1)
            break
    print()

    # Versions of the packages NetKet depends on
    print("# NetKet dependencies")
    dependencies = (
        "numpy",
        "jaxlib",
        "jax",
        "flax",
        "optax",
        "numba",
        "mpi4py",
        "mpi4jax",
        "netket",
    )
    for package in dependencies:
        printfmt(package, version(package), indent=1)
    print()

    if is_available("jax"):
        print("# Jax ")
        import jax

        backends = _jax_backends()
        printfmt("backends", backends, indent=1)
        for backend in backends:
            devices = [_fmt_device(dev) for dev in jax.devices(backend)]
            printfmt(f"{backend}", devices, indent=2)
        print()

    if is_available("mpi4jax"):
        print("# MPI4JAX")
        import mpi4jax

        # Prefer the has_cuda_support() function; otherwise fall back to
        # the HAS_GPU_EXT flag under _src.xla_bridge when it exists.
        if hasattr(mpi4jax, "has_cuda_support"):
            printfmt("HAS_GPU_EXT", mpi4jax.has_cuda_support(), indent=1)
        elif (
            hasattr(mpi4jax, "_src")
            and hasattr(mpi4jax._src, "xla_bridge")
            and hasattr(mpi4jax._src.xla_bridge, "HAS_GPU_EXT")
        ):
            printfmt("HAS_GPU_EXT", mpi4jax._src.xla_bridge.HAS_GPU_EXT, indent=1)
        print()

    if is_available("mpi4py"):
        print("# MPI ")
        import mpi4py
        from mpi4py import MPI

        printfmt("mpi4py", indent=1)
        mpicc = mpi4py.get_config()["mpicc"]
        printfmt("MPICC", mpicc, indent=1)
        link_flags = get_link_flags(exec_in_terminal([mpicc, "-show"]))
        printfmt("MPI link flags", link_flags, indent=1)
        printfmt("MPI version", MPI.Get_version(), indent=2)
        printfmt("MPI library_version", MPI.Get_library_version(), indent=2)

        global_info = get_global_mpi_info()
        printfmt("global", indent=1)
        printfmt("MPICC", global_info["mpicc"], indent=2)
        printfmt("MPI link flags", global_info["link_flags"], indent=2)
        print()
Esempio n. 22
0
        self.WIN.Unlock_all()


class TestRMASelf(BaseTestRMA, unittest.TestCase):
    """BaseTestRMA test cases with COMM set to MPI.COMM_SELF."""
    COMM = MPI.COMM_SELF


class TestRMAWorld(BaseTestRMA, unittest.TestCase):
    """BaseTestRMA test cases with COMM set to MPI.COMM_WORLD."""
    COMM = MPI.COMM_WORLD


# Probe window creation once at import time and drop both RMA test classes
# when it is not implemented or the MPI in use is too old.
try:
    MPI.Win.Create(None, 1, MPI.INFO_NULL, MPI.COMM_SELF).Free()
except NotImplementedError:
    del TestRMASelf, TestRMAWorld
else:
    name, version = MPI.get_vendor()
    if name == 'Open MPI':
        # Open MPI before 1.8.1.
        if version < (1, 8, 1):
            del TestRMASelf, TestRMAWorld
    elif name == 'MPICH2':
        # MPICH2 before 1.5.0, or a 2.x vendor version paired with an MPI
        # standard version below 3 (Intel MPI).
        if version < (1, 5, 0) or (version >= (2, 0, 0) and MPI.VERSION < 3):
            del TestRMASelf, TestRMAWorld
    elif MPI.Get_version() < (3, 0):
        # Any other vendor: require MPI 3.0.
        del TestRMASelf, TestRMAWorld

if __name__ == "__main__":
    unittest.main()