MPI greets JAX and speeds it up

mpi4jax

MPI plugin for JAX, allowing MPI operations to be inserted in jitted blocks.

Installation

You can install mpi4jax through pip or conda:

pip install mpi4jax                     # Pip
conda install -c conda-forge mpi4jax    # conda

Supported operations

  • Send
  • Recv
  • Sendrecv
  • Allreduce

Usage

from mpi4py import MPI
import jax
import mpi4jax

comm = MPI.COMM_WORLD
a = jax.numpy.ones((5, 4))  # the shape must be passed as a tuple
b = mpi4jax.Allreduce(a, op=MPI.SUM, comm=comm)
b_jit = jax.jit(lambda x: mpi4jax.Allreduce(x, op=MPI.SUM, comm=comm))(a)
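To make the semantics concrete: with op=MPI.SUM, an Allreduce leaves every rank holding the elementwise sum of all ranks' inputs. The sketch below simulates that behaviour in plain Python, without MPI or JAX. `simulate_allreduce_sum` is a hypothetical helper written for illustration only and is not part of mpi4jax.

```python
def simulate_allreduce_sum(per_rank_inputs):
    """Return what each rank would hold after an Allreduce with a sum op.

    per_rank_inputs: one flat list of numbers per simulated rank.
    This is a plain-Python illustration, not mpi4jax code.
    """
    # Elementwise sum across all simulated ranks.
    reduced = [sum(values) for values in zip(*per_rank_inputs)]
    # Every rank receives its own copy of the same reduced result.
    return [list(reduced) for _ in per_rank_inputs]


# Two simulated ranks, each contributing a vector of ones.
results = simulate_allreduce_sum([[1.0] * 4, [1.0] * 4])
print(results[0])
```

By the same reasoning, running the usage snippet on two ranks (for example with mpirun -n 2) should leave an array of twos in b on both ranks.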

Debugging

You can set the environment variable MPI4JAX_DEBUG to 1 to enable debug logging every time an MPI primitive is called from within a jitted function. You will then see messages like this:

$ MPI4JAX_DEBUG=1 mpirun -n 2 python send_recv.py
r0 | MPI_Send -> 1 with tag 0 and token 7fd7abc5f5c0
r1 | MPI_Recv <- 0 with tag -1 and token 7f9af7419ac0
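
When several ranks write to the same terminal, their debug messages can interleave. The following is a minimal sketch that groups captured messages by rank using only the standard library. The line pattern is an assumption inferred from the two sample lines above, and `group_by_rank` is a hypothetical helper, not something mpi4jax provides. Other primitives may log in a different shape, so lines that do not match are skipped.

```python
import re
from collections import defaultdict

# Pattern inferred from the two sample lines above (an assumption, not a
# documented format); lines that do not match are ignored.
LINE_RE = re.compile(
    r"^r(?P<rank>\d+) \| (?P<call>MPI_\w+) (?P<dir>->|<-) (?P<peer>-?\d+) "
    r"with tag (?P<tag>-?\d+) and token (?P<token>[0-9a-f]+)$"
)


def group_by_rank(lines):
    """Map each rank to a list of (call, peer, tag) tuples in log order."""
    by_rank = defaultdict(list)
    for line in lines:
        match = LINE_RE.match(line.strip())
        if match is None:
            continue  # not a recognised debug line
        by_rank[int(match["rank"])].append(
            (match["call"], int(match["peer"]), int(match["tag"]))
        )
    return dict(by_rank)


sample = [
    "r0 | MPI_Send -> 1 with tag 0 and token 7fd7abc5f5c0",
    "r1 | MPI_Recv <- 0 with tag -1 and token 7f9af7419ac0",
]
calls = group_by_rank(sample)
print(calls)
```

Each rank's calls stay in the order they were logged, which makes it easier to check that every send on one rank has a matching receive on its peer.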

Contributors

  • Filippo Vicentini @PhilipVinc
  • Dion Häfner @dionhaefner
