Usage examples
Basic example: global sum
The following script computes the element-wise sum of an array across several processes (similar to jax.lax.psum()), using allreduce():
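from mpi4py import MPI
import jax
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

@jax.jit
def foo(arr):
    # every process contributes its own rank
    arr = arr + rank
    # element-wise sum over all processes in the communicator
    arr_sum = mpi4jax.allreduce(arr, op=MPI.SUM, comm=comm)
    return arr_sum

a = jnp.zeros((3, 3))
result = foo(a)

# every process holds the same result, so print it only once
if rank == 0:
    print(result)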
Most MPI libraries supply a wrapper executable mpirun to execute a script on several processes:
$ mpirun -n 4 python mpi4jax-example.py
[[6. 6. 6.]
 [6. 6. 6.]
 [6. 6. 6.]]
The result is an array filled with the value 6: each of the 4 processes adds its rank (0, 1, 2, or 3) to its local array of zeros, and allreduce() sums these arrays element-wise across all processes (0 + 1 + 2 + 3 = 6).
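The arithmetic behind this result can be checked without MPI. The following is a minimal single-process sketch in plain Python that mimics what a sum reduction over 4 ranks produces. It does not call mpi4jax, and the names n_procs, local_arrays, and global_sum are illustrative assumptions, not part of the library.

```python
# Single-process sketch of a sum reduction over 4 simulated ranks.
# No MPI involved: the "ranks" are just loop indices.
n_procs = 4
shape = (3, 3)

# Each simulated rank starts from zeros and adds its own rank.
local_arrays = [
    [[0.0 + rank for _ in range(shape[1])] for _ in range(shape[0])]
    for rank in range(n_procs)
]

# Element-wise sum of the per-rank arrays.
global_sum = [
    [sum(arr[i][j] for arr in local_arrays) for j in range(shape[1])]
    for i in range(shape[0])
]

print(global_sum)
```

Every entry of global_sum equals 0 + 1 + 2 + 3 = 6, matching the output of the MPI run above.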
Basic example: sending and receiving
mpi4jax can also transfer data between processes without applying a reduction to it. For this, use send() and recv():
from mpi4py import MPI
import jax
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

assert size == 2  # make sure we are on 2 processes

@jax.jit
def foo(arr):
    arr = arr + rank

    # note: this could also use mpi4jax.sendrecv
    if rank == 0:
        # send, then receive
        mpi4jax.send(arr, dest=1, comm=comm)
        other_arr = mpi4jax.recv(arr, source=1, comm=comm)
    else:
        # receive, then send
        other_arr = mpi4jax.recv(arr, source=0, comm=comm)
        mpi4jax.send(arr, dest=0, comm=comm)

    return other_arr

a = jnp.zeros((3, 3))
result = foo(a)

print(f'r{rank} | {result}')
Executing this shows that each process has received the data from the other process:
$ mpirun -n 2 python mpi4jax-example-2.py
r1 | [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
r0 | [[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]]
For operations like this, the correct order of the send() / recv() calls is critical to prevent the program from deadlocking (e.g. when both processes wait for data at the same time).
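The ordering problem can be illustrated without MPI. The following is a rough sketch that simulates two ranks with Python threads and queues. The helper run_pair and its arguments are hypothetical and not part of mpi4jax, and a timeout stands in for the indefinite hang that a real deadlock would cause. Unlike a real MPI send, the simulated send never blocks, so only the receive-first deadlock is modelled here.

```python
import queue
import threading


def run_pair(order_rank0, order_rank1, timeout=0.5):
    """Simulate two ranks that each run a list of "send"/"recv" steps."""
    # inbox[r] holds messages addressed to rank r.
    inbox = {0: queue.Queue(), 1: queue.Queue()}
    results = {}

    def worker(rank, order):
        other = 1 - rank
        received = None
        try:
            for step in order:
                if step == "send":
                    inbox[other].put(f"data from r{rank}")
                else:
                    # Blocking receive: wait until a message arrives.
                    received = inbox[rank].get(timeout=timeout)
        except queue.Empty:
            # A real MPI program would hang here instead of timing out.
            received = "timed out"
        results[rank] = received

    threads = [
        threading.Thread(target=worker, args=(0, order_rank0)),
        threading.Thread(target=worker, args=(1, order_rank1)),
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


# Complementary order, as in the example above: the exchange completes.
ok = run_pair(["send", "recv"], ["recv", "send"])
# Both ranks receive first: neither ever reaches its send.
stuck = run_pair(["recv", "send"], ["recv", "send"])

print(ok)
print(stuck)
```

With the complementary order, each simulated rank ends up holding the other's message. When both ranks receive first, both time out, which corresponds to a deadlock in an actual MPI run.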