API Reference
Utilities
has_cuda_support
- mpi4jax.has_cuda_support() → bool
Returns True if mpi4jax was built with CUDA support and can be used with GPU-based JAX arrays, False otherwise.
Communication primitives
allgather
- mpi4jax.allgather(x, *, comm=None, token=None)
Perform an allgather operation.
Warning
x must have the same shape and dtype on all processes.
- Parameters
x – Array or scalar input to send.
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
- Returns
Received data.
A new, modified token that depends on this operation.
- Return type
Tuple[DeviceArray, Token]
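For instance, a minimal sketch of gathering each rank's local array onto every process (this assumes mpi4jax and mpi4py are installed and the script is launched with something like `mpirun -n 4 python script.py`):

```python
from mpi4py import MPI
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

# Every rank must contribute an array of identical shape and dtype
x = jnp.full((3,), rank, dtype=jnp.float32)

# res has shape (nproc, 3) and is identical on all ranks:
# row i holds the data contributed by process i
res, token = mpi4jax.allgather(x, comm=comm)
```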
allreduce
- mpi4jax.allreduce(x, op, *, comm=None, token=None)
Perform an allreduce operation.
Note
This primitive can be differentiated via jax.grad() and related functions if op is mpi4py.MPI.SUM.
- Parameters
x – Array or scalar input.
op (mpi4py.MPI.Op) – The reduction operator (e.g. mpi4py.MPI.SUM).
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
- Returns
Result of the allreduce operation.
A new, modified token that depends on this operation.
- Return type
Tuple[DeviceArray, Token]
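A minimal sketch of a jitted global mean, assuming the script is launched under mpirun with several processes; because op is MPI.SUM, the resulting function can also be differentiated with jax.grad:

```python
from mpi4py import MPI
import jax
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD

@jax.jit
def global_mean(x):
    # Sum x elementwise over all ranks, then divide by the process count
    total, token = mpi4jax.allreduce(x, op=MPI.SUM, comm=comm)
    return total / comm.Get_size()

x = jnp.arange(4.0) * (comm.Get_rank() + 1)
res = global_mean(x)  # identical on all ranks
```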
alltoall
- mpi4jax.alltoall(x, *, comm=None, token=None)
Perform an alltoall operation.
- Parameters
x – Array input to send. First axis must have size nproc.
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
- Returns
Received data.
A new, modified token that depends on this operation.
- Return type
Tuple[DeviceArray, Token]
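A minimal sketch of the all-to-all exchange pattern (assuming an mpirun launch with several processes): each rank sends row i of its input to process i and receives row i of its output from process i.

```python
from mpi4py import MPI
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
nproc = comm.Get_size()

# One row per destination rank; row i is sent to process i
x = jnp.full((nproc, 2), rank, dtype=jnp.int32)

# Row i of res now holds the data received from process i,
# so res[i] == [i, i] on every rank
res, token = mpi4jax.alltoall(x, comm=comm)
```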
barrier
- mpi4jax.barrier(*, comm=None, token=None)
Perform a barrier operation.
- Parameters
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
- Returns
A new, modified token that depends on this operation.
- Return type
Token
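Since barrier returns only a token, its main use is ordering: threading the token from a previous operation into the barrier pins the execution order. A minimal sketch (assuming an mpirun launch):

```python
from mpi4py import MPI
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD

# Chain the token through successive operations to fix their order:
# the allreduce is guaranteed to execute before the barrier
res, token = mpi4jax.allreduce(jnp.ones(3), op=MPI.SUM, comm=comm)
token = mpi4jax.barrier(comm=comm, token=token)
```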
bcast
- mpi4jax.bcast(x, root, *, comm=None, token=None)
Perform a bcast (broadcast) operation.
Warning
Unlike mpi4py’s bcast, this returns a new array with the received data.
- Parameters
x – Array or scalar input. Data is only read on root process. On non-root processes, this is used to determine the shape and dtype of the result.
root (int) – The process to use as source.
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
- Returns
Received data.
A new, modified token that depends on this operation.
- Return type
Tuple[DeviceArray, Token]
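A minimal broadcast sketch (assuming an mpirun launch with several processes); note that unlike mpi4py, the received data comes back as a new array:

```python
from mpi4py import MPI
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

# Only the values on the root rank matter; non-root ranks just
# supply an array of matching shape and dtype
x = jnp.arange(4.0) if rank == 0 else jnp.empty((4,))

# After the broadcast, res == [0., 1., 2., 3.] on every rank
res, token = mpi4jax.bcast(x, root=0, comm=comm)
```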
gather
- mpi4jax.gather(x, root, *, comm=None, token=None)
Perform a gather operation.
Warning
x must have the same shape and dtype on all processes.
Warning
The shape of the returned data varies between ranks. On the root process, it is (nproc, *input_shape). On all other processes, the output is identical to the input.
- Parameters
x – Array or scalar input to send.
root (int) – Rank of the root MPI process.
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
- Returns
Received data on root process, otherwise unmodified input.
A new, modified token that depends on this operation.
- Return type
Tuple[DeviceArray, Token]
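A minimal sketch illustrating the rank-dependent output shape (assuming an mpirun launch with several processes):

```python
from mpi4py import MPI
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

x = jnp.full((2,), rank, dtype=jnp.float32)
res, token = mpi4jax.gather(x, root=0, comm=comm)

if rank == 0:
    # On the root, res has shape (nproc, 2); row i came from process i
    print(res.shape)
else:
    # On all other ranks, res is identical to x
    pass
```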
recv
- mpi4jax.recv(x, source=-1, *, tag=-1, comm=None, status=None, token=None)
Perform a recv (receive) operation.
Warning
Unlike mpi4py’s recv, this returns a new array with the received data.
- Parameters
x – Array or scalar input with the correct shape and dtype. This can contain arbitrary data and will not be overwritten.
source (int) – Rank of the source MPI process.
tag (int) – Tag of this message.
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
status (mpi4py.MPI.Status) – Status object, can be used for introspection.
token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
- Returns
Received data.
A new, modified token that depends on this operation.
- Return type
Tuple[DeviceArray, Token]
reduce
- mpi4jax.reduce(x, op, root, *, comm=None, token=None)
Perform a reduce operation.
- Parameters
x – Array or scalar input to send.
op (mpi4py.MPI.Op) – The reduction operator (e.g. mpi4py.MPI.SUM).
root (int) – Rank of the root MPI process.
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
- Returns
Result of the reduce operation on root process, otherwise unmodified input.
A new, modified token that depends on this operation.
- Return type
Tuple[DeviceArray, Token]
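A minimal sketch of reducing onto a single root rank (assuming an mpirun launch with several processes):

```python
from mpi4py import MPI
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

x = jnp.ones(3) * rank
res, token = mpi4jax.reduce(x, op=MPI.SUM, root=0, comm=comm)

if rank == 0:
    # res is the elementwise sum of x over all ranks;
    # every other rank gets its input back unchanged
    print(res)
```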
scan
- mpi4jax.scan(x, op, *, comm=None, token=None)
Perform a scan operation.
- Parameters
x – Array or scalar input to send.
op (mpi4py.MPI.Op) – The reduction operator (e.g. mpi4py.MPI.SUM).
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
- Returns
Result of the scan operation.
A new, modified token that depends on this operation.
- Return type
Tuple[DeviceArray, Token]
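An MPI scan is an inclusive prefix reduction over the ranks: rank r receives x_0 op x_1 op … op x_r. A minimal sketch (assuming an mpirun launch with several processes):

```python
from mpi4py import MPI
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

x = jnp.ones(2) * (rank + 1)

# With op=MPI.SUM, rank r receives the elementwise sum of the
# inputs from ranks 0 through r (inclusive)
res, token = mpi4jax.scan(x, op=MPI.SUM, comm=comm)
```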
scatter
- mpi4jax.scatter(x, root, *, comm=None, token=None)
Perform a scatter operation.
Warning
Unlike mpi4py’s scatter, this returns a new array with the received data.
Warning
The expected shape of the first input varies between ranks. On the root process, it is (nproc, *input_shape). On all other processes, it is input_shape.
- Parameters
x – Array or scalar input with the correct shape and dtype. On the root process, this contains the data to send, and its first axis must have size nproc. On non-root processes, this may contain arbitrary data and will not be overwritten.
root (int) – Rank of the root MPI process.
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
- Returns
Received data.
A new, modified token that depends on this operation.
- Return type
Tuple[DeviceArray, Token]
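A minimal sketch of the rank-dependent input shapes (assuming an mpirun launch with several processes):

```python
from mpi4py import MPI
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
nproc = comm.Get_size()

if rank == 0:
    # The root provides one row per rank: row i goes to process i
    x = jnp.arange(nproc * 3, dtype=jnp.float32).reshape(nproc, 3)
else:
    # Non-root ranks only supply the shape and dtype of a single row
    x = jnp.empty((3,), dtype=jnp.float32)

# Every rank (including the root) receives its row, of shape (3,)
res, token = mpi4jax.scatter(x, root=0, comm=comm)
```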
send
- mpi4jax.send(x, dest, *, tag=0, comm=None, token=None)
Perform a send operation.
- Parameters
x – Array or scalar input to send.
dest (int) – Rank of the destination MPI process.
tag (int) – Tag of this message.
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
- Returns
A new, modified token that depends on this operation.
- Return type
Token
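A minimal point-to-point sketch pairing send with recv (assuming an mpirun launch with at least 2 processes):

```python
from mpi4py import MPI
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

# Simple ping from rank 0 to rank 1
if rank == 0:
    token = mpi4jax.send(jnp.arange(4.0), dest=1, tag=7, comm=comm)
elif rank == 1:
    # The first argument to recv only fixes the shape and dtype of
    # the result; its contents are never read
    buf = jnp.empty((4,))
    res, token = mpi4jax.recv(buf, source=0, tag=7, comm=comm)
```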
sendrecv
- mpi4jax.sendrecv(sendbuf, recvbuf, source, dest, *, sendtag=0, recvtag=-1, comm=None, status=None, token=None)
Perform a sendrecv operation.
Warning
Unlike mpi4py’s sendrecv, this returns a new array with the received data.
- Parameters
sendbuf – Array or scalar input to send.
recvbuf – Array or scalar input with the correct shape and dtype. This can contain arbitrary data and will not be overwritten.
source (int) – Rank of the source MPI process.
dest (int) – Rank of the destination MPI process.
sendtag (int) – Tag of this message for sending.
recvtag (int) – Tag of this message for receiving.
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
status (mpi4py.MPI.Status) – Status object, can be used for introspection.
token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
- Returns
Received data.
A new, modified token that depends on this operation.
- Return type
Tuple[DeviceArray, Token]
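A minimal ring-exchange sketch, a common use of sendrecv since the combined operation avoids the deadlocks a naive send-then-recv can produce (assuming an mpirun launch with several processes):

```python
from mpi4py import MPI
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
nproc = comm.Get_size()

# Ring exchange: send to the right neighbour, receive from the left
right = (rank + 1) % nproc
left = (rank - 1) % nproc

sendbuf = jnp.full((2,), rank, dtype=jnp.int32)
recvbuf = jnp.empty((2,), dtype=jnp.int32)  # fixes shape and dtype only

# res contains the left neighbour's data on every rank
res, token = mpi4jax.sendrecv(sendbuf, recvbuf, source=left, dest=right,
                              comm=comm)
```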
Experimental
auto_tokenize
Warning
auto_tokenize is currently broken for JAX 0.4.4 and later. To use it, downgrade to jax<=0.4.3. See issue #192 for more details.
- mpi4jax.experimental.auto_tokenize(f)
Automatically manage tokens between all mpi4jax ops.
Supports most JAX operations, including fori_loop, cond, jit, while_loop, and scan. This transforms all operations in the decorated function, even through subfunctions and nested applications of jax.jit.
Note
This transform overrides all mpi4jax token management completely. Do not use this transform if you need manual control over token management.
- Parameters
f – Any function that uses mpi4jax primitives (jitted or not).
- Returns
A transformed version of f that automatically manages all mpi4jax tokens.
Example
>>> @auto_tokenize
... def f(a):
...     # no token handling necessary
...     res, _ = allreduce(a, op=MPI.SUM)
...     res, _ = allreduce(res, op=MPI.SUM)
...     return res
>>> arr = jnp.ones((3, 2))
>>> res = f(arr)