Usage examples

Basic example: global sum

The following computes the sum of an array over several processes (similar to jax.lax.psum()), using allreduce():

from mpi4py import MPI
import jax
import jax.numpy as jnp
import mpi4jax

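comm = MPI.COMM_WORLD
rank = comm.Get_rank()

def foo(arr):
    # each process contributes its own rank
    arr = arr + rank
    # sum the arrays of all processes; the returned token is not needed here
    arr_sum, _ = mpi4jax.allreduce(arr, op=MPI.SUM, comm=comm)
    return arr_sum

a = jnp.zeros((3, 3))
result = foo(a)

if rank == 0:
    print(result)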

Most MPI libraries supply a wrapper executable mpirun to execute a script on several processes:

$ mpirun -n 4 python
[[6. 6. 6.]
 [6. 6. 6.]
 [6. 6. 6.]]

The result is an array filled with the value 6: each of the 4 processes adds its rank (0, 1, 2, or 3) to its local array before the reduction, and 0 + 1 + 2 + 3 = 6.
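To see where the value 6 comes from, the reduction can be reproduced in a single process without MPI. The sketch below is only a simulation: `local_array` and `simulated_allreduce_sum` are hypothetical helpers, not part of mpi4jax, and nested Python lists stand in for JAX arrays.

```python
def local_array(rank, shape=(3, 3)):
    """Array one process holds after `arr = arr + rank` (zeros plus its rank)."""
    rows, cols = shape
    return [[0.0 + rank for _ in range(cols)] for _ in range(rows)]


def simulated_allreduce_sum(arrays):
    """Element-wise sum over one array per rank (hypothetical helper, no MPI).

    With a real allreduce, every rank would receive this same result.
    """
    rows, cols = len(arrays[0]), len(arrays[0][0])
    return [
        [sum(arr[i][j] for arr in arrays) for j in range(cols)]
        for i in range(rows)
    ]


n_processes = 4
per_rank = [local_array(rank) for rank in range(n_processes)]
reduced = simulated_allreduce_sum(per_rank)
print(reduced[0])  # first row of the reduced array: [6.0, 6.0, 6.0]
```

Changing `n_processes` changes the result in the same way as changing the `-n` argument of mpirun would: with 2 processes every entry becomes 1, with 3 processes it becomes 3.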

Basic example: sending and receiving

mpi4jax can also send and receive data without performing an operation on it. For this, use send() and recv():

from mpi4py import MPI
import jax
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()
assert size == 2  # make sure we are on 2 processes

def foo(arr):
    arr = arr + rank
    # note: this could also use mpi4jax.sendrecv
    if rank == 0:
        # send, then receive
        token = mpi4jax.send(arr, dest=1, comm=comm)
        other_arr, token = mpi4jax.recv(arr, source=1, comm=comm, token=token)
    else:
        # receive, then send
        other_arr, token = mpi4jax.recv(arr, source=0, comm=comm)
        token = mpi4jax.send(arr, dest=0, comm=comm, token=token)

    return other_arr

a = jnp.zeros((3, 3))
result = foo(a)

print(f'r{rank} | {result}')

Executing this shows that each process has received the data from the other process:

$ mpirun -n 2 python
r1 | [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
r0 | [[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]]

For operations like this, the correct order of the send() / recv() calls is critical to prevent the program from deadlocking (e.g. when both processes wait for data at the same time).
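The effect of call order can be explored with a small single-process model. The sketch below is an illustration under a stated assumption, not a description of any MPI implementation: it treats every send as synchronous, meaning it completes only once the peer has reached the matching recv. Real MPI libraries may buffer small messages, so this is the worst case. `runs_to_completion` is a hypothetical helper.

```python
def runs_to_completion(programs):
    """Return True if every rank finishes, False if the ranks deadlock.

    `programs[r]` is the ordered list of blocking operations on rank r, each
    either ("send", peer) or ("recv", peer). Assumption: a send completes only
    once the peer has reached the matching recv (synchronous semantics).
    """
    position = [0] * len(programs)  # index of the next operation on each rank
    progressed = True
    while progressed:
        progressed = False
        for rank, ops in enumerate(programs):
            if position[rank] == len(ops):
                continue  # this rank has already finished
            kind, peer = ops[position[rank]]
            if kind != "send" or position[peer] == len(programs[peer]):
                continue
            # A send completes only if the peer waits in the matching recv.
            if programs[peer][position[peer]] == ("recv", rank):
                position[rank] += 1
                position[peer] += 1
                progressed = True
    return all(pos == len(ops) for pos, ops in zip(position, programs))


# Ordering from the example: rank 0 sends first, rank 1 receives first.
example_order = [
    [("send", 1), ("recv", 1)],
    [("recv", 0), ("send", 0)],
]
# Both ranks wait for data at the same time.
both_receive_first = [
    [("recv", 1), ("send", 1)],
    [("recv", 0), ("send", 0)],
]

print(runs_to_completion(example_order))       # True
print(runs_to_completion(both_receive_first))  # False
```

In this model, the ordering used in the example completes, while two ranks that both start with a receive never make progress.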

In mpi4jax, we enforce the order of execution through tokens. In the example code, you can see this in the following lines:

token = mpi4jax.send(arr, dest=1, comm=comm)
other_arr, token = mpi4jax.recv(arr, source=1, comm=comm, token=token)

The first call to send() returns a token, which we then pass to recv(). recv() also returns a new token that we could pass to subsequent communication primitives.

Because JAX is free to reorder operations that have no data dependency on one another, using tokens to enforce order is not optional. Without correct token management, you will experience deadlocks and crashes.
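A toy model can make the role of the token concrete. The sketch below assumes a scheduler that may pick any execution order consistent with the data dependencies between operations; this is a simplification for illustration, not a description of how XLA actually schedules work. `valid_orders` is a hypothetical helper.

```python
from itertools import permutations


def valid_orders(ops, depends_on):
    """List every execution order that respects the data dependencies.

    `depends_on` maps an operation to the operations whose output it consumes.
    """
    orders = []
    for order in permutations(ops):
        index = {op: i for i, op in enumerate(order)}
        if all(index[dep] < index[op]
               for op, deps in depends_on.items() for dep in deps):
            orders.append(order)
    return orders


ops = ("send", "recv")

# Without a token, recv consumes nothing that send produces.
without_token = valid_orders(ops, depends_on={})
# With a token, recv takes the token returned by send as an input.
with_token = valid_orders(ops, depends_on={"recv": ["send"]})

print(without_token)  # [('send', 'recv'), ('recv', 'send')]
print(with_token)     # [('send', 'recv')]
```

Passing the token from send() into recv() creates the dependency that removes the second, unintended ordering.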

See also

For more information on tokens, see Token management.