Esempio n. 1
0
    def pytest_terminal_summary(self, terminalreporter, exitstatus, *args):
        """
        Hook for printing MPI info at the end of the run.

        When ``self._is_testing_mpi`` is true, an "MPI Information" section is
        written to ``terminalreporter`` containing the rank, size, MPI
        version/library/vendor information, mpi4py's ``rc`` settings and
        mpi4py's build configuration. If mpi4py cannot be imported, a single
        explanatory line is written instead. Nothing is written when
        ``self._is_testing_mpi`` is false.

        Parameters
        ----------
        terminalreporter
            pytest's terminal reporter; only its ``section`` and ``write``
            methods are used.
        exitstatus
            Unused; part of the pytest hook signature.
        *args
            Unused; absorbs any additional hook arguments.
        """
        # pylint: disable=unused-argument
        if not self._is_testing_mpi:
            return

        terminalreporter.section("MPI Information")
        try:
            from mpi4py import MPI, rc, get_config
        except ImportError:
            # Terminate the line so later terminal output starts on a new line.
            terminalreporter.write("Unable to import mpi4py\n")
            return

        write = terminalreporter.write

        def dotted(parts):
            # Render a version tuple such as (3, 1) as "3.1".
            return ".".join(str(part) for part in parts)

        comm = MPI.COMM_WORLD
        write("rank: {}\n".format(comm.rank))
        write("size: {}\n".format(comm.size))

        write("MPI version: {}\n".format(dotted(MPI.Get_version())))
        write("MPI library version: {}\n".format(MPI.Get_library_version()))

        vendor, vendor_version = MPI.get_vendor()
        write("MPI vendor: {} {}\n".format(vendor, dotted(vendor_version)))

        write("mpi4py rc: \n")
        for name, value in vars(rc).items():
            write(" {}: {}\n".format(name, value))

        write("mpi4py config:\n")
        for name, value in get_config().items():
            write(" {}: {}\n".format(name, value))
Esempio n. 2
0
def have_feature():
    """
    Probe for MPI partitioned point-to-point communication support.

    Runs a minimal self-message on ``MPI.COMM_SELF`` using the partitioned
    API (``Psend_init``/``Precv_init``, ``Pready``, ``Parrived``). Returns
    None when every call succeeds; any exception raised along the way,
    including the explicit ``NotImplementedError`` below, indicates the
    feature is unusable. NOTE(review): presumably callers treat an exception
    as "feature missing" — confirm at the call site.
    """
    info = MPI.Get_library_version()
    # MPICH builds whose library string mentions the ch3 device are rejected
    # without probing. NOTE(review): the reason is not visible here — confirm.
    if 'MPICH' in info and 'ch3:' in info:
        raise NotImplementedError
    # 1-byte send/receive buffers, 1 partition, peer rank 0 and tag 0
    # (argument order assumed to be buf, partitions, dest/source, tag —
    # confirm against the mpi4py Psend_init/Precv_init signature).
    sreq = MPI.COMM_SELF.Psend_init(bytearray(1), 1, 0, 0)
    rreq = MPI.COMM_SELF.Precv_init(bytearray(1), 1, 0, 0)
    sreq.Start()
    rreq.Start()
    # Mark partition 0 as ready on the send side, then query its arrival on
    # the receive side; the Parrived result is not used.
    sreq.Pready(0)
    rreq.Parrived(0)
    # Complete and release the receive request, then the send request.
    rreq.Wait()
    rreq.Free()
    del rreq
    sreq.Wait()
    sreq.Free()
    del sreq
Esempio n. 3
0
def check_mpi():
    """
    Print basic MPI diagnostics so users can check their environment.

    When called via::

        # python3 -m netket.tools.check_mpi
        mpi_available                : True
        mpi4jax_available            : True
        available_cpus (rank 0)      : 12
        n_nodes                      : 1
        mpi4py | MPI version         : (3, 1)
        mpi4py | MPI library_version : Open MPI v4.1.0, ...

    this will print out basic MPI information to allow users to check whether
    the environment has been set up correctly. Only rank 0 prints; every other
    rank returns immediately without output. The mpi4py-specific entries are
    only reported when ``mpi_available`` is true.
    """
    if rank > 0:
        return

    info = {
        "mpi_available": mpi_available,
        "mpi4jax_available": mpi4jax_available,
        "available_cpus (rank 0)": available_cpus(),
    }
    if mpi_available:
        from mpi4py import MPI

        info.update(
            {
                "n_nodes": n_nodes,
                "mpi4py | MPI version": MPI.Get_version(),
                "mpi4py | MPI library_version": MPI.Get_library_version(),
            }
        )

    # Pad every key to the longest one so the ":" separators line up.
    width = max(len(key) for key in info)

    for key, value in info.items():
        print(f"{key:{width}} : {value}")
Esempio n. 4
0
def ch4_ucx():
    """Return True when the MPI library version string mentions ``ch4:ucx``."""
    library_version = MPI.Get_library_version()
    return library_version.find('ch4:ucx') != -1
Esempio n. 5
0
def ch4_ofi():
    """Return True when the MPI library version string mentions ``ch4:ofi``."""
    library_version = MPI.Get_library_version()
    return library_version.find('ch4:ofi') != -1
Esempio n. 6
0
def info():
    """
    Print a diagnostic report about NetKet and its runtime environment.

    The report is written to stdout and contains, in order:

    * the NetKet version (read from the package's ``_version`` module);
    * the Python implementation, version, compiler and executable path;
    * the host platform, machine architecture, AVX/AVX2 support and, when
      ``cpu_info()`` reports one, the core count;
    * the installed versions of NetKet's main dependencies;
    * the available jax backends and their devices (only if jax is available);
    * whether mpi4jax has GPU support (only if mpi4jax is available and
      exposes that information);
    * the MPI compiler wrapper, link flags and MPI version information, both
      as configured for mpi4py and as reported by ``get_global_mpi_info()``
      (only if mpi4py is available).
    """
    print("====================================================")
    print("==         NetKet Diagnostic Informations         ==")
    print("====================================================")

    # Import the version module directly so netket itself is not imported.
    from .. import _version

    printfmt("NetKet version", _version.version)
    print()

    print("# Python")
    printfmt("implementation", platform.python_implementation(), indent=1)
    printfmt("version", platform.python_version(), indent=1)
    printfmt("distribution", platform.python_compiler(), indent=1)
    printfmt("path", sys.executable, indent=1)
    print()

    # Try to detect platform
    print("# Host informations")
    printfmt("System      ", platform.platform(), indent=1)
    printfmt("Architecture", platform.machine(), indent=1)

    # Try to query cpu info
    platform_info = cpu_info()

    printfmt("AVX", platform_info["supports_avx"], indent=1)
    printfmt("AVX2", platform_info["supports_avx2"], indent=1)
    # The core-count key varies (presumably by platform / cpu_info backend);
    # report the first one present, if any.
    for cores_key in ("cpu cores", "cpu_cores", "core_count"):
        if cores_key in platform_info:
            printfmt("Cores", platform_info[cores_key], indent=1)
            break
    print()

    print("# NetKet dependencies")
    dependencies = (
        "numpy",
        "jaxlib",
        "jax",
        "flax",
        "optax",
        "numba",
        "mpi4py",
        "mpi4jax",
        "netket",
    )
    for package in dependencies:
        printfmt(package, version(package), indent=1)
    print()

    if is_available("jax"):
        print("# Jax ")
        import jax

        backends = _jax_backends()
        printfmt("backends", backends, indent=1)
        for backend in backends:
            printfmt(
                f"{backend}",
                [_fmt_device(dev) for dev in jax.devices(backend)],
                indent=2,
            )
        print()

    if is_available("mpi4jax"):
        print("# MPI4JAX")
        import mpi4jax

        if hasattr(mpi4jax, "has_cuda_support"):
            printfmt("HAS_GPU_EXT", mpi4jax.has_cuda_support(), indent=1)
        else:
            # Older mpi4jax releases expose the flag on an internal module.
            xla_bridge = getattr(getattr(mpi4jax, "_src", None), "xla_bridge", None)
            if hasattr(xla_bridge, "HAS_GPU_EXT"):
                printfmt("HAS_GPU_EXT", xla_bridge.HAS_GPU_EXT, indent=1)
        print()

    if is_available("mpi4py"):
        print("# MPI ")
        import mpi4py
        from mpi4py import MPI

        printfmt("mpi4py", indent=1)
        # Query the build configuration once and reuse it.
        mpicc = mpi4py.get_config()["mpicc"]
        printfmt("MPICC", mpicc, indent=1)
        printfmt(
            "MPI link flags",
            get_link_flags(exec_in_terminal([mpicc, "-show"])),
            indent=1,
        )
        printfmt("MPI version", MPI.Get_version(), indent=2)
        printfmt("MPI library_version", MPI.Get_library_version(), indent=2)

        global_info = get_global_mpi_info()
        printfmt("global", indent=1)
        printfmt("MPICC", global_info["mpicc"], indent=2)
        printfmt("MPI link flags", global_info["link_flags"], indent=2)
        print()
from mpi4py import MPI

# Set to True to include MPI.Get_library_version() in the output. The default
# of False matches the previous behaviour, in which the queried value was
# immediately overwritten with "" and so never printed. NOTE(review): that
# looked like a deliberate suppression (the string can be long) — confirm.
_SHOW_LIB_VERSION = False

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
hostname = MPI.Get_processor_name()
ver = MPI.Get_version()
lib_ver = MPI.Get_library_version() if _SHOW_LIB_VERSION else ""

try:
    print("Hello World! ({}/{}: {}; Ver{}; Lib{})".format(rank, size, hostname, ver, lib_ver))
    # Synchronize all ranks before reporting completion.
    comm.Barrier()
    print("Done! ({}/{}: {})".format(rank, size, hostname))
except Exception as ex:
    # Report the failure together with this rank's identity.
    print("{} ({}/{}: {}; Ver{}; Lib{})".format(ex, rank, size, hostname, ver, lib_ver))
Esempio n. 8
0

Revisions:  YYMMDD    Author            Comments
-------------------------------------------------------------------------------------
            171201    Cody Herndon, Brian Phung, Ashley Spear Speed up of python code
            151201    Brian Leavy       Initial visualization of particle domains
"""

try:
    from mpi4py import MPI
except ImportError:
    # No mpi4py: behave as a single-process (serial) run.
    comm = None
    rank = 0
    size = 1
else:
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()
    # Only the root rank emits the library-compatibility reminder.
    if rank == 0:
        print("uda2xmf.py: You must use the same MPI library version when running the script in parallel.")
        print("uda2xmf.py: MPI library Version: ", MPI.Get_library_version())

class Data_parser:
    """!
    @brief A factory class that returns variable data and descriptors.

    Parses input from binary files, and generates descriptors for that data.
    """
    parse_all=False

    @staticmethod
Esempio n. 9
0
def mpi_run(func,
            targets=None,
            delete_tempfile=True,
            log=False,
            log_screen=False,
            log_file_name='neurokernel.log'):
    """
    Run a function with mpiexec.

    Implemented as a fix to 'import neurokernel.mpi_relaunch', which does not
    work within notebooks. Writes the source code for a function to a temporary
    file and then runs the temporary file using mpiexec. Returns the stdout
    produced by that run. If the function raised, a RuntimeError carrying
    that stdout is raised instead.

    Parameters
    ----------
    func : function or str
        Function to be executed with mpiexec. All imports and variables used
        must be imported or defined within the function. func can either be a callable
        function or code that represents a valid function.
    targets : list
        Dependencies of the manager, such as child classes of the Module class
        from neurokernel.core_gpu or neurokernel.core.
    delete_tempfile : bool
        Whether or not to delete temporary file once func is executed.
    log : boolean
        Whether or not to connect to logger for func if logger exists.
    log_screen : bool
        Whether or not to send log messages to the screen.
    log_file_name : str
        File to send log messages to.

    Returns
    -------
    output : str
        The stdout from the function run with mpiexec, as text.

    Raises
    ------
    ValueError
        If ``func`` is a string that contains no ``def name():`` definition.
    RuntimeError
        If the executed script reports ``MPI_RUN_FAILURE``.
    subprocess.CalledProcessError
        If the launched process exits with a non-zero status.

    Usage
    -----
    Does not seem to work with openmpi version 2
    func should not import neurokernel.mpi_relaunch
    All modules and variables used must be imported or defined within func
    Returns the stdout from the function run under 'mpiexec -np 1 python {tmp_file_name}'
    """

    l = LoggerMixin("mpi_run()", log_on=log)

    if callable(func):
        func_text = inspect.getsource(func)
        # Make a feeble attempt at fixing indentation. Will work for a nested function
        # that takes no args, not a member function that expects (self) or a class
        func_text = "\n" + re.sub(r"(^\s+)def ", "def ", func_text) + "\n"
        func_name = func.__name__
    else:
        func_text = "\n" + func + "\n"
        match = re.search(r'def *(.*)\(\):', func_text)
        if match is None:
            raise ValueError(
                "func must contain a function taking no arguments, "
                "e.g. 'def name():'")
        func_name = match.group(1)

    target_text = "\n"

    if targets:
        for t in targets:
            target_text += "\n" + inspect.getsource(t) + "\n"

    # Build the __main__ section of the generated script. It calls the
    # function and prints a success/failure marker that is checked below.
    main_code = "\n"
    main_code += "\nif __name__ == \"__main__\":"
    main_code += "\n   import neurokernel.mpi as mpi"
    main_code += "\n   from neurokernel.mixins import LoggerMixin"
    main_code += "\n   from mpi4py import MPI"

    if log:
        main_code += "\n   mpi.setup_logger(screen=%s, file_name=\"%s\"," % (
            log_screen, log_file_name)
        main_code += "\n                    mpi_comm=MPI.COMM_WORLD, multiline=True)"

    main_code += "\n   l = LoggerMixin(\"%s\",%s)" % (func_name, str(log))
    main_code += "\n   try:"
    main_code += "\n      %s()" % func_name
    main_code += "\n      print(\"MPI_RUN_SUCCESS: %s\")" % func_name
    main_code += "\n      l.log_info(\"MPI_RUN_SUCCESS: %s\")" % func_name
    main_code += "\n   except Exception as e:"
    main_code += "\n      print(\"MPI_RUN_FAILURE: %s\")" % func_name
    main_code += "\n      l.log_error(\"MPI_RUN_FAILURE: %s\")" % func_name
    main_code += "\n      print(e)"
    main_code += "\n"

    # Pre-bind so the finally clause is safe even if the import fails first.
    temp = None
    try:
        from mpi4py import MPI
        # Write code for the function to a temp file. Text mode is required:
        # the generated source is str, which the default binary mode ('w+b')
        # rejects on Python 3.
        temp = tempfile.NamedTemporaryFile(mode='w', delete=delete_tempfile)
        temp.write(target_text)
        temp.write(func_text)
        temp.write(main_code)
        temp.flush()

        #Execute the code
        #There's a bug in Open MPI v2 that prevents running this with mpiexec. Running 'from mpi4py import MPI'
        #does a basic mpi_relaunch which will work for the notebook code, but you give up some of the features
        #of mpiexec.
        if MPI.Get_library_version().startswith("Open MPI v2"):
            command = ["python", temp.name]
        else:
            command = ["mpiexec", "-np", "1", "python", temp.name]

        env = os.environ.copy()
        l.log_info("Calling: " + " ".join(command))
        # universal_newlines=True makes check_output return str, not bytes,
        # so the marker check and the return value below operate on text.
        out = subprocess.check_output(command, env=env,
                                      universal_newlines=True)

    except Exception as e:
        l.log_error(str(e))
        raise

    finally:
        # Closing the temp file also deletes it when delete_tempfile is True.
        if temp is not None:
            temp.close()

    # The generated script reports failures via stdout rather than exit code.
    if "MPI_RUN_FAILURE" in out:
        raise RuntimeError(out)

    return out
Esempio n. 10
0
 def testGetLibraryVersion(self):
     """MPI.Get_library_version() returns a non-empty str."""
     version = MPI.Get_library_version()
     # assertIsInstance / assertGreater report the offending value on failure,
     # whereas assertTrue(...) only reports "False is not true".
     self.assertIsInstance(version, str)
     self.assertGreater(len(version), 0)