API Reference

Utilities

has_cuda_support

mpi4jax.has_cuda_support() → bool

Returns True if mpi4jax was built with CUDA support and can therefore be used with GPU-resident JAX arrays, and False otherwise.

Communication primitives

allgather

mpi4jax.allgather(x, *, comm=None, token=<object object>)

Perform an allgather operation.

Warning

x must have the same shape and dtype on all processes.

Parameters:
  • x – Array or scalar input to send.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

Returns:

Received data, with shape (nproc, *x.shape); row i holds the input of rank i.

Return type:

DeviceArray

allreduce

mpi4jax.allreduce(x, op, *, comm=None, token=<object object>)

Perform an allreduce operation.

Note

This primitive can be differentiated via jax.grad() and related functions if op is mpi4py.MPI.SUM.

Parameters:
  • x – Array or scalar input.

  • op (mpi4py.MPI.Op) – The reduction operator (e.g. mpi4py.MPI.SUM).

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

Returns:

Result of the allreduce operation.

Return type:

DeviceArray

alltoall

mpi4jax.alltoall(x, *, comm=None, token=<object object>)

Perform an alltoall operation.

Parameters:
  • x – Array input to send. Its first axis must have size nproc; slice i is sent to rank i.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

Returns:

Received data, with the same shape as x; slice i holds the data sent by rank i.

Return type:

DeviceArray

barrier

mpi4jax.barrier(*, comm=None, token=<object object>)

Perform a barrier operation.

Parameters:
  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

bcast

mpi4jax.bcast(x, root, *, comm=None, token=<object object>)

Perform a bcast (broadcast) operation.

Warning

Unlike mpi4py’s bcast, this returns a new array with the received data.

Parameters:
  • x – Array or scalar input. Data is only read on root process. On non-root processes, this is used to determine the shape and dtype of the result.

  • root (int) – The process to use as source.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

Returns:

Received data.

Return type:

DeviceArray

gather

mpi4jax.gather(x, root, *, comm=None, token=<object object>)

Perform a gather operation.

Warning

x must have the same shape and dtype on all processes.

Warning

The shape of the returned data varies between ranks. On the root process, it is (nproc, *input_shape). On all other processes, the output is identical to the input.

Parameters:
  • x – Array or scalar input to send.

  • root (int) – Rank of the root MPI process.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

Returns:

Received data on root process, otherwise unmodified input.

Return type:

DeviceArray

recv

mpi4jax.recv(x, source=-1, *, tag=-1, comm=None, status=None, token=<object object>)

Perform a recv (receive) operation.

Warning

Unlike mpi4py’s recv, this returns a new array with the received data.

Parameters:
  • x – Array or scalar input with the correct shape and dtype. This can contain arbitrary data and will not be overwritten.

  • source (int) – Rank of the source MPI process.

  • tag (int) – Tag of this message.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • status (mpi4py.MPI.Status) – Status object that can be used for introspection.

Returns:

Received data.

Return type:

DeviceArray

reduce

mpi4jax.reduce(x, op, root, *, comm=None, token=<object object>)

Perform a reduce operation.

Parameters:
  • x – Array or scalar input to send.

  • op (mpi4py.MPI.Op) – The reduction operator (e.g. mpi4py.MPI.SUM).

  • root (int) – Rank of the root MPI process.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

Returns:

Result of the reduce operation on the root process, otherwise unmodified input.

Return type:

DeviceArray

scan

mpi4jax.scan(x, op, *, comm=None, token=<object object>)

Perform a scan (inclusive prefix reduction) operation: on rank i, the result is the reduction of x over ranks 0 to i.

Parameters:
  • x – Array or scalar input to send.

  • op (mpi4py.MPI.Op) – The reduction operator (e.g. mpi4py.MPI.SUM).

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

Returns:

Result of the scan operation.

Return type:

DeviceArray

scatter

mpi4jax.scatter(x, root, *, comm=None, token=<object object>)

Perform a scatter operation.

Warning

Unlike mpi4py’s scatter, this returns a new array with the received data.

Warning

The expected shape of x varies between ranks. On the root process, it is (nproc, *input_shape). On all other processes, it is input_shape.

Parameters:
  • x – Array or scalar input with the correct shape and dtype. On the root process, this contains the data to send, and its first axis must have size nproc. On non-root processes, this may contain arbitrary data and will not be overwritten.

  • root (int) – Rank of the root MPI process.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

Returns:

Received data.

Return type:

DeviceArray

send

mpi4jax.send(x, dest, *, tag=0, comm=None, token=<object object>)

Perform a send operation.

Parameters:
  • x – Array or scalar input to send.

  • dest (int) – Rank of the destination MPI process.

  • tag (int) – Tag of this message.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

sendrecv

mpi4jax.sendrecv(sendbuf, recvbuf, source, dest, *, sendtag=0, recvtag=-1, comm=None, status=None, token=<object object>)

Perform a sendrecv operation.

Warning

Unlike mpi4py’s sendrecv, this returns a new array with the received data.

Parameters:
  • sendbuf – Array or scalar input to send.

  • recvbuf – Array or scalar input with the correct shape and dtype. This can contain arbitrary data and will not be overwritten.

  • source (int) – Rank of the source MPI process.

  • dest (int) – Rank of the destination MPI process.

  • sendtag (int) – Tag of this message for sending.

  • recvtag (int) – Tag of this message for receiving.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • status (mpi4py.MPI.Status) – Status object that can be used for introspection.

Returns:

Received data.

Return type:

DeviceArray