mpi4jax enables zero-copy, multi-host communication of JAX arrays, even from jitted code and from GPU memory.
The JAX framework has great performance for scientific computing workloads, but its multi-host capabilities are still limited.
With mpi4jax, you can scale your JAX-based simulations to entire CPU and GPU clusters (without ever leaving `jax.jit`).
In the spirit of differentiable programming, mpi4jax also supports differentiating through some MPI operations.
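For example, a sum-`allreduce` is linear in its input, so gradients can propagate through it. Here is a minimal sketch (the loss function is our own illustration, not part of the library):

```python
from mpi4py import MPI
import jax
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD


def loss(x):
    # the sum-allreduce is one of the operations mpi4jax can differentiate
    summed, _ = mpi4jax.allreduce(x, op=MPI.SUM, comm=comm)
    return jnp.sum(summed**2)


# gradients flow through the MPI call like any other JAX operation
grad = jax.grad(loss)(jnp.ones(3))
```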
mpi4jax is available through `pip` and `conda`:
```bash
$ pip install mpi4jax                     # Pip
$ conda install -c conda-forge mpi4jax    # conda
```
Our documentation includes some more advanced installation examples.
```python
from mpi4py import MPI
import jax
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD
rank = comm.Get_rank()


@jax.jit
def foo(arr):
    arr = arr + rank
    arr_sum, _ = mpi4jax.allreduce(arr, op=MPI.SUM, comm=comm)
    return arr_sum


a = jnp.zeros((3, 3))
result = foo(a)

if rank == 0:
    print(result)
```
Running this script on 4 processes gives:
```bash
$ mpirun -n 4 python example.py
[[6. 6. 6.]
 [6. 6. 6.]
 [6. 6. 6.]]
```
Each process adds its rank to an array of zeros, and `allreduce` sums the results over all 4 ranks, so every entry equals 0 + 1 + 2 + 3 = 6.

`allreduce` is just one of the MPI primitives you can use. See our documentation for all supported operations.
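For instance, point-to-point communication is available via `send` and `recv`. Below is a minimal sketch of a two-process exchange (the script contents are our own illustration; `send` returns a token, and `recv` returns the received array plus a token, which can be threaded into later MPI calls to fix their ordering):

```python
# send_recv.py -- run with: mpirun -n 2 python send_recv.py
from mpi4py import MPI
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

arr = jnp.ones((3, 3)) * rank

if rank == 0:
    # rank 0 sends its array to rank 1
    token = mpi4jax.send(arr, dest=1, comm=comm)
else:
    # recv takes an array of the desired shape and dtype as a template
    arr, token = mpi4jax.recv(arr, source=0, comm=comm)
    print(f"rank {rank} received:\n{arr}")
```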
We use pre-commit hooks to enforce a common code format. To install them, just run:
```bash
$ pip install pre-commit
$ pre-commit install
```
You can set the environment variable `MPI4JAX_DEBUG` to `1` to enable debug logging every time an MPI primitive is called from within a jitted function. You will then see messages like this:
```bash
$ MPI4JAX_DEBUG=1 mpirun -n 2 python send_recv.py
r0 | MPI_Send -> 1 with tag 0 and token 7fd7abc5f5c0
r1 | MPI_Recv <- 0 with tag -1 and token 7f9af7419ac0
```
Our documentation covers these topics in more depth:

- Usage examples
- Demo application: Shallow-water model
- 🔪 The Sharp Bits 🔪
- Communication primitives