API Reference
Utilities
has_cuda_support
- mpi4jax.has_cuda_support() → bool
Returns True if mpi4jax was built with CUDA support and can be used with JAX arrays on the GPU, False otherwise.
Communication primitives
allgather
- mpi4jax.allgather(x, *, comm=None, token=<object object>)
Perform an allgather operation.
Warning
x must have the same shape and dtype on all processes.
- Parameters:
x – Array or scalar input to send.
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
- Returns:
Received data.
- Return type:
DeviceArray
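To make the shape contract concrete, here is a minimal single-process NumPy model of what each rank receives. It is not mpi4jax: `model_allgather` is a hypothetical helper, the list `xs` stands in for the per-rank values of `x`, and it assumes that contributions are stacked along a new leading axis of size nproc (the layout the gather warning describes for the root). In a real MPI program, each rank would instead call `mpi4jax.allgather(x, comm=comm)`.

```python
import numpy as np


def model_allgather(xs):
    """Hypothetical single-process model of mpi4jax.allgather.

    xs[r] plays the role of the array passed as x on rank r.
    """
    first = xs[0]
    for x in xs:
        # Mirrors the documented requirement: same shape and dtype everywhere.
        if x.shape != first.shape or x.dtype != first.dtype:
            raise ValueError("x must have the same shape and dtype on all ranks")
    # Assumed result layout: (nproc, *input_shape), identical on every rank.
    stacked = np.stack(xs)
    return [stacked.copy() for _ in xs]


# Three simulated ranks; rank r contributes the vector [r, r].
xs = [np.full(2, r, dtype=np.float32) for r in range(3)]
out = model_allgather(xs)
print(out[0].shape)  # (3, 2)
```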
allreduce
- mpi4jax.allreduce(x, op, *, comm=None, token=<object object>)
Perform an allreduce operation.
Note
This primitive can be differentiated via jax.grad() and related functions if op is mpi4py.MPI.SUM.
- Parameters:
x – Array or scalar input.
op (mpi4py.MPI.Op) – The reduction operator (e.g. mpi4py.MPI.SUM).
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
- Returns:
Result of the allreduce operation.
- Return type:
DeviceArray
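The following sketch models allreduce in a single process, assuming that a NumPy ufunc can stand in for an MPI reduction operator (`np.add` for `mpi4py.MPI.SUM`, `np.maximum` for `mpi4py.MPI.MAX`). `model_allreduce` is hypothetical; on each rank of a real program the call would be `mpi4jax.allreduce(x, op=MPI.SUM, comm=comm)`.

```python
import numpy as np


def model_allreduce(xs, op):
    """Hypothetical model of mpi4jax.allreduce.

    xs[r] is rank r's input; op is a NumPy ufunc standing in for an MPI.Op
    (np.add ~ MPI.SUM, np.maximum ~ MPI.MAX).
    """
    # Reduce elementwise across the rank axis ...
    result = op.reduce(np.stack(xs), axis=0)
    # ... and hand every rank its own copy of the same result.
    return [result.copy() for _ in xs]


xs = [np.array([1.0, 5.0]), np.array([4.0, 2.0]), np.array([0.5, 3.0])]
sums = model_allreduce(xs, np.add)       # every rank: [5.5, 10.0]
maxes = model_allreduce(xs, np.maximum)  # every rank: [4.0, 5.0]
```

Unlike reduce, every rank ends up holding the same reduced array.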
alltoall
- mpi4jax.alltoall(x, *, comm=None, token=<object object>)
Perform an alltoall operation.
- Parameters:
x – Array input to send. Its first axis must have size nproc.
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
- Returns:
Received data.
- Return type:
DeviceArray
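The data movement of alltoall can be modeled in one process, assuming the standard MPI_Alltoall layout: block j of rank i's input ends up as block i of rank j's output, so the output keeps the input shape. `model_alltoall` is a hypothetical helper; the real per-rank call would be `mpi4jax.alltoall(x, comm=comm)`.

```python
import numpy as np


def model_alltoall(xs):
    """Hypothetical model of mpi4jax.alltoall.

    Rank src sends block xs[src][dst] to rank dst, which stores it at index src.
    """
    nproc = len(xs)
    for x in xs:
        if x.shape[0] != nproc:
            raise ValueError("first axis of x must have size nproc")
    stacked = np.stack(xs)  # stacked[src, dst, ...]
    # Output on rank dst keeps the input shape: out[dst][src] == xs[src][dst].
    return [stacked[:, dst].copy() for dst in range(nproc)]


# Three ranks; rank r sends the value 10 * r + dst to rank dst.
xs = [np.array([10 * r + dst for dst in range(3)]) for r in range(3)]
out = model_alltoall(xs)
print(out[1])  # [ 1 11 21]
```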
barrier
- mpi4jax.barrier(*, comm=None, token=<object object>)
Perform a barrier operation.
- Parameters:
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
bcast
- mpi4jax.bcast(x, root, *, comm=None, token=<object object>)
Perform a bcast (broadcast) operation.
Warning
Unlike mpi4py’s bcast, this returns a new array with the received data.
- Parameters:
x – Array or scalar input. Data is only read on the root process. On non-root processes, this is used to determine the shape and dtype of the result.
root (int) – Rank of the root MPI process, which is the source of the data.
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
- Returns:
Received data.
- Return type:
DeviceArray
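A single-process model of the broadcast, with the hypothetical helper `model_bcast`: only the root's data is read, non-root inputs act as shape/dtype templates, and every rank gets a new array while its own input is left as it was. As a simplification, the model treats a shape or dtype mismatch as an error. The real per-rank call would be `mpi4jax.bcast(x, root=0, comm=comm)`.

```python
import numpy as np


def model_bcast(xs, root):
    """Hypothetical model of mpi4jax.bcast.

    Only xs[root] supplies data; the other entries only fix shape and dtype.
    """
    src = xs[root]
    out = []
    for x in xs:
        if x.shape != src.shape or x.dtype != src.dtype:
            raise ValueError("shape/dtype on non-root ranks must match the root")
        # A fresh array is returned; the caller's x is left untouched.
        out.append(src.copy())
    return out


# Rank 0 is the root; ranks 1 and 2 pass zero-filled placeholders.
xs = [np.array([1.0, 2.0, 3.0]), np.zeros(3), np.zeros(3)]
out = model_bcast(xs, root=0)
```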
gather
- mpi4jax.gather(x, root, *, comm=None, token=<object object>)
Perform a gather operation.
Warning
x must have the same shape and dtype on all processes.
Warning
The shape of the returned data varies between ranks. On the root process, it is (nproc, *input_shape). On all other processes, the output is identical to the input.
- Parameters:
x – Array or scalar input to send.
root (int) – Rank of the root MPI process.
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
- Returns:
Received data on the root process, otherwise the unmodified input.
- Return type:
DeviceArray
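The rank-dependent output shape from the warning is easiest to see in a single-process model. `model_gather` is hypothetical; on each rank the real call would be `mpi4jax.gather(x, root=0, comm=comm)`.

```python
import numpy as np


def model_gather(xs, root):
    """Hypothetical model of mpi4jax.gather: xs[r] is rank r's input."""
    first = xs[0]
    for x in xs:
        if x.shape != first.shape or x.dtype != first.dtype:
            raise ValueError("x must have the same shape and dtype on all ranks")
    stacked = np.stack(xs)  # (nproc, *input_shape), delivered only to root
    # Non-root ranks get their own input back unchanged.
    return [stacked if r == root else x for r, x in enumerate(xs)]


xs = [np.full(2, r) for r in range(4)]
out = model_gather(xs, root=0)
print(out[0].shape, out[3].shape)  # (4, 2) (2,)
```

Code that runs on every rank after a gather therefore has to account for the two possible shapes, for example by only using the stacked array on the root.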
recv
- mpi4jax.recv(x, source=-1, *, tag=-1, comm=None, status=None, token=<object object>)
Perform a recv (receive) operation.
Warning
Unlike mpi4py’s recv, this returns a new array with the received data.
- Parameters:
x – Array or scalar input with the correct shape and dtype. This can contain arbitrary data and will not be overwritten.
source (int) – Rank of the source MPI process.
tag (int) – Tag of this message.
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
status (mpi4py.MPI.Status) – Status object that can be used for introspection.
- Returns:
Received data.
- Return type:
DeviceArray
reduce
- mpi4jax.reduce(x, op, root, *, comm=None, token=<object object>)
Perform a reduce operation.
- Parameters:
x – Array or scalar input to send.
op (mpi4py.MPI.Op) – The reduction operator (e.g. mpi4py.MPI.SUM).
root (int) – Rank of the root MPI process.
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
- Returns:
Result of the reduce operation on the root process, otherwise unmodified input.
- Return type:
DeviceArray
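A single-process sketch of reduce, again assuming a NumPy ufunc (`np.add`) can stand in for `mpi4py.MPI.SUM`. The hypothetical `model_reduce` keeps the documented argument order (x, op, root); the real per-rank call would be `mpi4jax.reduce(x, op=MPI.SUM, root=1, comm=comm)`.

```python
import numpy as np


def model_reduce(xs, op, root):
    """Hypothetical model of mpi4jax.reduce; np.add stands in for MPI.SUM."""
    result = op.reduce(np.stack(xs), axis=0)
    # Only the root receives the reduction; other ranks get their input back.
    return [result if r == root else x for r, x in enumerate(xs)]


xs = [np.array([1, 2]), np.array([3, 4]), np.array([5, 6])]
out = model_reduce(xs, np.add, root=1)
print(out[1])  # [ 9 12]
```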
scan
- mpi4jax.scan(x, op, *, comm=None, token=<object object>)
Perform a scan operation.
- Parameters:
x – Array or scalar input to send.
op (mpi4py.MPI.Op) – The reduction operator (e.g. mpi4py.MPI.SUM).
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
- Returns:
Result of the scan operation.
- Return type:
DeviceArray
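A single-process sketch of scan, assuming mpi4jax.scan follows the inclusive semantics of MPI_Scan (rank r receives the reduction over ranks 0 through r) and that `np.add` stands in for `mpi4py.MPI.SUM`. `model_scan` is hypothetical; the real per-rank call would be `mpi4jax.scan(x, op=MPI.SUM, comm=comm)`.

```python
import numpy as np


def model_scan(xs, op):
    """Hypothetical model of mpi4jax.scan (inclusive prefix reduction).

    Rank r receives op applied over the inputs of ranks 0..r.
    """
    prefix = op.accumulate(np.stack(xs), axis=0)
    return [prefix[r].copy() for r in range(len(xs))]


xs = [np.array([r + 1]) for r in range(4)]  # ranks contribute 1, 2, 3, 4
out = model_scan(xs, np.add)
print([int(o[0]) for o in out])  # [1, 3, 6, 10]
```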
scatter
- mpi4jax.scatter(x, root, *, comm=None, token=<object object>)
Perform a scatter operation.
Warning
Unlike mpi4py’s scatter, this returns a new array with the received data.
Warning
The expected shape of x varies between ranks. On the root process, it is (nproc, *input_shape). On all other processes, it is input_shape.
- Parameters:
x – Array or scalar input with the correct shape and dtype. On the root process, this contains the data to send, and its first axis must have size nproc. On non-root processes, this may contain arbitrary data and will not be overwritten.
root (int) – Rank of the root MPI process.
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
- Returns:
Received data.
- Return type:
DeviceArray
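The asymmetric input shapes of scatter can be checked with a single-process model. `model_scatter` is a hypothetical helper that enforces the shapes stated in the warning and returns a new array per rank; the real per-rank call would be `mpi4jax.scatter(x, root=0, comm=comm)`.

```python
import numpy as np


def model_scatter(xs, root):
    """Hypothetical model of mpi4jax.scatter.

    xs[root] has shape (nproc, *input_shape); other entries have input_shape
    and only supply shape/dtype (their contents are ignored, not overwritten).
    """
    nproc = len(xs)
    send = xs[root]
    if send.shape[0] != nproc:
        raise ValueError("first axis on the root must have size nproc")
    for r, x in enumerate(xs):
        if r != root and x.shape != send.shape[1:]:
            raise ValueError("non-root input must have shape input_shape")
    # Rank r receives the r-th slice of the root's data as a new array.
    return [send[r].copy() for r in range(nproc)]


root_data = np.arange(6, dtype=np.int32).reshape(3, 2)  # (nproc, 2)
xs = [root_data, np.zeros(2, dtype=np.int32), np.zeros(2, dtype=np.int32)]
out = model_scatter(xs, root=0)
print(out[2])  # [4 5]
```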
send
- mpi4jax.send(x, dest, *, tag=0, comm=None, token=<object object>)
Perform a send operation.
- Parameters:
x – Array or scalar input to send.
dest (int) – Rank of the destination MPI process.
tag (int) – Tag of this message.
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
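Every send must be matched by a recv on the destination rank, and the recv template x fixes the shape and dtype of what arrives. The toy model below uses a plain list as a "mailbox" and the hypothetical helpers `model_send` and `model_recv`; it only matches explicit source and tag values, and does not model wildcard matching or the blocking behavior of real MPI.

```python
import numpy as np


def model_send(mailbox, x, source, dest, tag=0):
    """Hypothetical model of mpi4jax.send: post a copy of x for rank dest."""
    mailbox.append((source, dest, tag, np.array(x, copy=True)))


def model_recv(mailbox, x, dest, source, tag):
    """Hypothetical model of mpi4jax.recv on rank dest.

    x is only a template for shape and dtype; a new array is returned.
    Only exact (source, tag) matches are modeled, no wildcards.
    """
    for i, (s, d, t, data) in enumerate(mailbox):
        if (s, d, t) == (source, dest, tag):
            if data.shape != np.shape(x) or data.dtype != np.asarray(x).dtype:
                raise ValueError("template x does not match the sent data")
            del mailbox[i]
            return data
    # A real recv would block here until a matching message arrives.
    raise LookupError("no matching message")


mailbox = []
# Rank 0 would call: mpi4jax.send(x, dest=1, tag=7, comm=comm)
model_send(mailbox, np.array([1.5, 2.5]), source=0, dest=1, tag=7)
# Rank 1 would call: mpi4jax.recv(template, source=0, tag=7, comm=comm)
template = np.zeros(2)
received = model_recv(mailbox, template, dest=1, source=0, tag=7)
print(received)  # [1.5 2.5]
```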
sendrecv
- mpi4jax.sendrecv(sendbuf, recvbuf, source, dest, *, sendtag=0, recvtag=-1, comm=None, status=None, token=<object object>)
Perform a sendrecv operation.
Warning
Unlike mpi4py’s sendrecv, this returns a new array with the received data.
- Parameters:
sendbuf – Array or scalar input to send.
recvbuf – Array or scalar input with the correct shape and dtype. This can contain arbitrary data and will not be overwritten.
source (int) – Rank of the source MPI process.
dest (int) – Rank of the destination MPI process.
sendtag (int) – Tag of this message for sending.
recvtag (int) – Tag of this message for receiving.
comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
status (mpi4py.MPI.Status) – Status object that can be used for introspection.
- Returns:
Received data.
- Return type:
DeviceArray
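A common use of sendrecv is a ring shift, where on each rank the real call would look like `mpi4jax.sendrecv(sendbuf, recvbuf, source=(rank - 1) % size, dest=(rank + 1) % size, comm=comm)`. The sketch below models all ranks at once in a single process; `model_sendrecv` is hypothetical and ignores tags.

```python
import numpy as np


def model_sendrecv(sendbufs, recvbufs, sources, dests):
    """Hypothetical model of mpi4jax.sendrecv across all ranks at once.

    Rank r sends sendbufs[r] to dests[r] and receives from sources[r];
    recvbufs[r] only fixes the expected shape and dtype. Tags are ignored.
    """
    out = []
    for r, src in enumerate(sources):
        if dests[src] != r:
            raise ValueError(f"rank {src} does not send to rank {r}")
        data = sendbufs[src]
        if data.shape != recvbufs[r].shape or data.dtype != recvbufs[r].dtype:
            raise ValueError("recvbuf does not match the incoming data")
        out.append(data.copy())
    return out


# Ring shift on 4 ranks: each rank passes its value to the right neighbor.
n = 4
sendbufs = [np.array([float(r)]) for r in range(n)]
recvbufs = [np.zeros(1) for _ in range(n)]
sources = [(r - 1) % n for r in range(n)]
dests = [(r + 1) % n for r in range(n)]
out = model_sendrecv(sendbufs, recvbufs, sources, dests)
print([float(o[0]) for o in out])  # [3.0, 0.0, 1.0, 2.0]
```

Combining the send and the receive in one call lets every rank shift data around the ring without having to order separate send and recv calls by hand.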