API Reference
Utilities
has_cuda_support
- mpi4jax.has_cuda_support() -> bool
Returns True if mpi4jax was built with CUDA support and can be used with GPU-based JAX arrays, and False otherwise.
Communication primitives
allgather
- mpi4jax.allgather(x, *, comm=None, token=None)
Perform an allgather operation.
Warning
x must have the same shape and dtype on all processes.
- Parameters:
x – Array or scalar input to send.
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
- Returns:
Received data.
A new, modified token that depends on this operation.
- Return type:
Tuple[DeviceArray, Token]
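The shape contract can be illustrated without MPI. The sketch below is not mpi4jax code: it simulates nproc ranks inside one process as a list of per-rank NumPy arrays, `simulate_allgather` is a hypothetical helper, and it assumes that the received data stacks every rank's x along a new leading axis of size nproc.

```python
import numpy as np


def simulate_allgather(per_rank_x):
    """Model allgather: every rank receives all inputs, stacked by rank.

    per_rank_x[i] is the value rank i passes as `x`. Hypothetical helper;
    assumes the result has shape (nproc, *x.shape) on every rank.
    """
    shapes = {np.shape(x) for x in per_rank_x}
    dtypes = {np.asarray(x).dtype for x in per_rank_x}
    # Mirrors the warning above: shape and dtype must agree on all ranks.
    if len(shapes) != 1 or len(dtypes) != 1:
        raise ValueError("x must have the same shape and dtype on all processes")
    stacked = np.stack([np.asarray(x) for x in per_rank_x])
    # Each rank gets its own copy of the same stacked result.
    return [stacked.copy() for _ in per_rank_x]


nproc = 3
inputs = [np.full(2, rank, dtype=np.float32) for rank in range(nproc)]
outputs = simulate_allgather(inputs)
print(outputs[0].shape)  # (3, 2)
```

Row i of the result holds rank i's contribution, so every rank can index the gathered data by rank.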
allreduce
- mpi4jax.allreduce(x, op, *, comm=None, token=None)
Perform an allreduce operation.
Note
This primitive can be differentiated via jax.grad() and related functions if op is mpi4py.MPI.SUM.
- Parameters:
x – Array or scalar input.
op (mpi4py.MPI.Op) – The reduction operator (e.g. mpi4py.MPI.SUM).
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
- Returns:
Result of the allreduce operation.
A new, modified token that depends on this operation.
- Return type:
Tuple[DeviceArray, Token]
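A single-process sketch of what each rank receives, under the assumption that NumPy ufuncs can stand in for MPI operators (np.add for mpi4py.MPI.SUM, np.maximum for mpi4py.MPI.MAX). `simulate_allreduce` is a hypothetical helper, not part of mpi4jax.

```python
import numpy as np


def simulate_allreduce(per_rank_x, op=np.add):
    """Model allreduce: every rank receives op applied across all ranks.

    `op` is a NumPy ufunc standing in for an mpi4py.MPI.Op
    (np.add ~ MPI.SUM, np.maximum ~ MPI.MAX); this is an assumption of
    the sketch, not how mpi4jax dispatches reductions.
    """
    stacked = np.stack([np.asarray(x) for x in per_rank_x])
    # Elementwise reduction over the rank axis; the result has x's shape.
    reduced = op.reduce(stacked, axis=0)
    return [reduced.copy() for _ in per_rank_x]


inputs = [np.array([1.0, 5.0]), np.array([2.0, 3.0]), np.array([4.0, 0.0])]
print(simulate_allreduce(inputs)[0])              # [7. 8.]
print(simulate_allreduce(inputs, np.maximum)[2])  # [4. 5.]
```

Unlike gather or reduce, the output has the same shape as x and is identical on every rank.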
alltoall
- mpi4jax.alltoall(x, *, comm=None, token=None)
Perform an alltoall operation.
- Parameters:
x – Array input to send. Its first axis must have size nproc (the number of processes in comm).
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
- Returns:
Received data.
A new, modified token that depends on this operation.
- Return type:
Tuple[DeviceArray, Token]
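The data movement can be sketched in one process. This assumes the standard MPI_Alltoall layout, where block j along the first axis of rank i's input is delivered to rank j and stored there as block i; `simulate_alltoall` is a hypothetical helper, not mpi4jax API.

```python
import numpy as np


def simulate_alltoall(per_rank_x):
    """Model alltoall: block j of rank i's input lands in block i on rank j.

    Hypothetical helper; assumes the first axis of x (size nproc)
    indexes the destination rank, as in MPI_Alltoall.
    """
    nproc = len(per_rank_x)
    stacked = np.stack([np.asarray(x) for x in per_rank_x])
    if stacked.ndim < 2 or stacked.shape[1] != nproc:
        raise ValueError("first axis of x must have size nproc")
    # stacked[src, dst] is the block src sends to dst; swapping the two
    # leading axes gives received[dst, src].
    received = np.swapaxes(stacked, 0, 1)
    return [received[dst].copy() for dst in range(nproc)]


nproc = 3
# Rank i sends the value 10 * i + j to rank j.
inputs = [np.array([10 * i + j for j in range(nproc)]) for i in range(nproc)]
outputs = simulate_alltoall(inputs)
print(outputs[1])  # [ 1 11 21]
```

Viewed across all ranks, alltoall transposes the (rank, first axis) pair, so the output keeps the input's shape.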
barrier
- mpi4jax.barrier(*, comm=None, token=None)
Perform a barrier operation.
- Parameters:
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
- Returns:
A new, modified token that depends on this operation.
- Return type:
Token
bcast
- mpi4jax.bcast(x, root, *, comm=None, token=None)
Perform a bcast (broadcast) operation.
Warning
Unlike mpi4py’s bcast, this returns a new array with the received data.
- Parameters:
x – Array or scalar input. Data is only read on root process. On non-root processes, this is used to determine the shape and dtype of the result.
root (int) – The process to use as source.
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
- Returns:
Received data.
A new, modified token that depends on this operation.
- Return type:
Tuple[DeviceArray, Token]
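A single-process sketch of the broadcast semantics described above: root's data is copied into a new array on every rank, and non-root inputs only fix the shape and dtype. `simulate_bcast` is a hypothetical helper, and rejecting mismatched shapes is an assumption of the sketch (MPI requires matching buffers).

```python
import numpy as np


def simulate_bcast(per_rank_x, root):
    """Model bcast: every rank receives a new array holding root's data.

    Non-root inputs only supply the shape and dtype of the result.
    Hypothetical helper; the mismatch check is an assumption.
    """
    source = np.asarray(per_rank_x[root])
    results = []
    for x in per_rank_x:
        x = np.asarray(x)
        if x.shape != source.shape or x.dtype != source.dtype:
            raise ValueError("x must match root's shape and dtype")
        # A fresh array, so the caller's input is left untouched.
        results.append(source.copy())
    return results


inputs = [np.zeros(2), np.array([3.0, 4.0]), np.zeros(2)]
outputs = simulate_bcast(inputs, root=1)
print(outputs[0])  # [3. 4.]
print(inputs[0])   # [0. 0.]
```

The second print shows the point of the warning above: the input on a non-root rank is not filled in place, so the returned array must be used.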
gather
- mpi4jax.gather(x, root, *, comm=None, token=None)
Perform a gather operation.
Warning
x must have the same shape and dtype on all processes.
Warning
The shape of the returned data varies between ranks. On the root process, it is (nproc, *input_shape). On all other processes, the output is identical to the input.
- Parameters:
x – Array or scalar input to send.
root (int) – Rank of the root MPI process.
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
- Returns:
Received data on the root process; unmodified input on all other processes.
A new, modified token that depends on this operation.
- Return type:
Tuple[DeviceArray, Token]
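The rank-dependent output shape from the warning above can be sketched in one process. `simulate_gather` is a hypothetical helper that models each rank's return value; it is not mpi4jax API.

```python
import numpy as np


def simulate_gather(per_rank_x, root):
    """Model gather: root gets (nproc, *input_shape); others get x back.

    Hypothetical helper for illustrating result shapes per rank.
    """
    arrays = [np.asarray(x) for x in per_rank_x]
    stacked = np.stack(arrays)
    # Only root receives the stacked data; other ranks get their input.
    return [stacked if rank == root else arrays[rank] for rank in range(len(arrays))]


inputs = [np.arange(2) + 10 * rank for rank in range(4)]
outputs = simulate_gather(inputs, root=0)
print(outputs[0].shape, outputs[2].shape)  # (4, 2) (2,)
```

Because the shapes differ, code that consumes the result typically branches on whether comm.Get_rank() equals root.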
recv
- mpi4jax.recv(x, source=-1, *, tag=-1, comm=None, status=None, token=None)
Perform a recv (receive) operation.
Warning
Unlike mpi4py’s recv, this returns a new array with the received data.
- Parameters:
x – Array or scalar input with the correct shape and dtype. This can contain arbitrary data and will not be overwritten.
source (int) – Rank of the source MPI process.
tag (int) – Tag of this message.
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
status (mpi4py.MPI.Status) – Status object that can be used for introspection.
token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
- Returns:
Received data.
A new, modified token that depends on this operation.
- Return type:
Tuple[DeviceArray, Token]
reduce
- mpi4jax.reduce(x, op, root, *, comm=None, token=None)
Perform a reduce operation.
- Parameters:
x – Array or scalar input to send.
op (mpi4py.MPI.Op) – The reduction operator (e.g. mpi4py.MPI.SUM).
root (int) – Rank of the root MPI process.
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
- Returns:
Result of the reduce operation on the root process; unmodified input on all other processes.
A new, modified token that depends on this operation.
- Return type:
Tuple[DeviceArray, Token]
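A single-process sketch of what each rank gets back from reduce, assuming np.add stands in for mpi4py.MPI.SUM. `simulate_reduce` is a hypothetical helper, not mpi4jax API.

```python
import numpy as np


def simulate_reduce(per_rank_x, root, op=np.add):
    """Model reduce: root receives the reduction, others get x back.

    np.add stands in for mpi4py.MPI.SUM (an assumption of this sketch).
    """
    arrays = [np.asarray(x) for x in per_rank_x]
    # Elementwise reduction over the rank axis.
    reduced = op.reduce(np.stack(arrays), axis=0)
    return [reduced if rank == root else arrays[rank] for rank in range(len(arrays))]


inputs = [np.array([1, 2]), np.array([3, 4]), np.array([5, 6])]
outputs = simulate_reduce(inputs, root=2)
print(outputs[2], outputs[0])  # [ 9 12] [1 2]
```

Here the output shape matches the input on every rank, but only root's value is meaningful as a reduction; use allreduce when every rank needs the result.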
scan
- mpi4jax.scan(x, op, *, comm=None, token=None)
Perform a scan operation (an inclusive prefix reduction across ranks).
- Parameters:
x – Array or scalar input to send.
op (mpi4py.MPI.Op) – The reduction operator (e.g. mpi4py.MPI.SUM).
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
- Returns:
Result of the scan operation.
A new, modified token that depends on this operation.
- Return type:
Tuple[DeviceArray, Token]
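A single-process sketch of scan, assuming standard MPI_Scan semantics: rank i receives op applied to the inputs of ranks 0 through i. NumPy's ufunc accumulate stands in for the MPI operator, and `simulate_scan` is a hypothetical helper.

```python
import numpy as np


def simulate_scan(per_rank_x, op=np.add):
    """Model scan as an inclusive prefix reduction over ranks.

    Rank i receives op(x_0, ..., x_i), following MPI_Scan semantics;
    np.add stands in for mpi4py.MPI.SUM.
    """
    stacked = np.stack([np.asarray(x) for x in per_rank_x])
    # accumulate along the rank axis yields every inclusive prefix.
    prefixes = op.accumulate(stacked, axis=0)
    return [prefixes[rank] for rank in range(len(per_rank_x))]


inputs = [np.array(v) for v in (1, 2, 3, 4)]
print([int(y) for y in simulate_scan(inputs)])  # [1, 3, 6, 10]
```

A common use is computing each rank's offset into a global array from local sizes (subtract the rank's own size from its scan result).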
scatter
- mpi4jax.scatter(x, root, *, comm=None, token=None)
Perform a scatter operation.
Warning
Unlike mpi4py’s scatter, this returns a new array with the received data.
Warning
The expected shape of the input x varies between ranks. On the root process, it is (nproc, *input_shape). On all other processes, it is input_shape.
- Parameters:
x – Array or scalar input with the correct shape and dtype. On the root process, this contains the data to send, and its first axis must have size nproc. On non-root processes, this may contain arbitrary data and will not be overwritten.
root (int) – Rank of the root MPI process.
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
- Returns:
Received data.
A new, modified token that depends on this operation.
- Return type:
Tuple[DeviceArray, Token]
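A single-process sketch of the rank-dependent input shapes described above: root supplies (nproc, *input_shape), every other rank supplies a placeholder of shape input_shape, and rank i receives row i. `simulate_scatter` is a hypothetical helper, not mpi4jax API.

```python
import numpy as np


def simulate_scatter(per_rank_x, root):
    """Model scatter: rank i receives block i of root's input.

    Root passes shape (nproc, *input_shape); other ranks pass
    input_shape and their data is ignored. Hypothetical helper.
    """
    nproc = len(per_rank_x)
    source = np.asarray(per_rank_x[root])
    if source.ndim == 0 or source.shape[0] != nproc:
        raise ValueError("on root, the first axis of x must have size nproc")
    for rank, x in enumerate(per_rank_x):
        if rank != root and np.shape(x) != source.shape[1:]:
            raise ValueError("on non-root ranks, x must have shape input_shape")
    # Copies: the result is a new array and inputs are not overwritten.
    return [source[rank].copy() for rank in range(nproc)]


nproc = 3
root_data = np.arange(6).reshape(nproc, 2)
inputs = [root_data if rank == 0 else np.zeros(2, dtype=root_data.dtype) for rank in range(nproc)]
outputs = simulate_scatter(inputs, root=0)
print(outputs[2])  # [4 5]
```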
send
- mpi4jax.send(x, dest, *, tag=0, comm=None, token=None)
Perform a send operation.
- Parameters:
x – Array or scalar input to send.
dest (int) – Rank of the destination MPI process.
tag (int) – Tag of this message.
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
- Returns:
A new, modified token that depends on this operation.
- Return type:
Token
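How send and recv pair up can be sketched with a hypothetical in-process mailbox. Messages are matched on (source, dest, tag) and delivered in FIFO order; source and dest are explicit only because one process plays every rank. The exact shape/dtype check is stricter than MPI and is an assumption of the sketch, and the ordering that the token argument enforces in real mpi4jax code is not modeled.

```python
from collections import defaultdict, deque

import numpy as np


class Mailbox:
    """Hypothetical single-process stand-in for point-to-point messaging.

    Messages are matched on (source, dest, tag) and delivered in FIFO
    order, mirroring MPI's non-overtaking rule for a fixed sender/tag.
    """

    def __init__(self):
        self._queues = defaultdict(deque)

    def send(self, x, source, dest, tag=0):
        # Copy so later changes to the sender's array are not visible.
        self._queues[(source, dest, tag)].append(np.array(x, copy=True))

    def recv(self, x, source, dest, tag=0):
        template = np.asarray(x)
        data = self._queues[(source, dest, tag)].popleft()
        # Stricter than MPI: require an exact shape and dtype match.
        if data.shape != template.shape or data.dtype != template.dtype:
            raise ValueError("x must have the shape and dtype of the message")
        # Return a new array; the template x is not overwritten.
        return data


box = Mailbox()
box.send(np.array([1.0, 2.0]), source=0, dest=1, tag=7)
buf = np.zeros(2)
received = box.recv(buf, source=0, dest=1, tag=7)
print(received, buf)  # [1. 2.] [0. 0.]
```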
sendrecv
- mpi4jax.sendrecv(sendbuf, recvbuf, source, dest, *, sendtag=0, recvtag=-1, comm=None, status=None, token=None)
Perform a sendrecv operation.
Warning
Unlike mpi4py’s sendrecv, this returns a new array with the received data.
- Parameters:
sendbuf – Array or scalar input to send.
recvbuf – Array or scalar input with the correct shape and dtype. This can contain arbitrary data and will not be overwritten.
source (int) – Rank of the source MPI process.
dest (int) – Rank of the destination MPI process.
sendtag (int) – Tag of this message for sending.
recvtag (int) – Tag of this message for receiving.
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
status (mpi4py.MPI.Status) – Status object that can be used for introspection.
token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
- Returns:
Received data.
A new, modified token that depends on this operation.
- Return type:
Tuple[DeviceArray, Token]
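A typical use of sendrecv is a ring shift, where rank i sends to (i + 1) % nproc and receives from (i - 1) % nproc. The sketch below models only the resulting data movement in one process; `simulate_ring_sendrecv` is a hypothetical helper, not mpi4jax API.

```python
import numpy as np


def simulate_ring_sendrecv(per_rank_sendbuf):
    """Model a ring shift built from sendrecv.

    Rank i sends to dest = (i + 1) % nproc and receives from
    source = (i - 1) % nproc. Hypothetical helper.
    """
    nproc = len(per_rank_sendbuf)
    received = []
    for rank in range(nproc):
        source = (rank - 1) % nproc
        # Rank `rank` gets a new array holding source's send buffer.
        received.append(np.array(per_rank_sendbuf[source], copy=True))
    return received


sendbufs = [np.array([rank]) for rank in range(4)]
print([int(r[0]) for r in simulate_ring_sendrecv(sendbufs)])  # [3, 0, 1, 2]
```

In real code each rank would make one call such as mpi4jax.sendrecv(x, x, source=(rank - 1) % size, dest=(rank + 1) % size, comm=comm). Combining both directions in one call avoids the deadlock that separate blocking send and recv calls can cause when every rank sends first.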