API Reference

Utilities

has_cuda_support

mpi4jax.has_cuda_support() bool

Returns True if mpi4jax was built with CUDA support and can be used with GPU-based jax arrays, False otherwise.
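
Example (illustrative sketch, not part of the upstream reference):

>>> import mpi4jax
>>> if mpi4jax.has_cuda_support():
...     print("GPU arrays can be passed to mpi4jax directly")
... else:
...     print("this build only supports CPU arrays")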

Communication primitives

allgather

mpi4jax.allgather(x, *, comm=None, token=None)

Perform an allgather operation.

Warning

x must have the same shape and dtype on all processes.

Parameters
  • x – Array or scalar input to send.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns

  • Received data.

  • A new, modified token that depends on this operation.

Return type

Tuple[DeviceArray, Token]
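
Example (illustrative sketch, not from the upstream reference; assumes an MPI launch such as mpirun -n 4 python script.py, and that the result is stacked along a new leading axis, mirroring gather's (nproc, *input_shape) layout):

>>> import jax.numpy as jnp
>>> import mpi4jax
>>> x = jnp.arange(3)                     # same shape and dtype on every rank
>>> out, token = mpi4jax.allgather(x)
>>> # out is assumed to have shape (nproc, 3) on every rank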

allreduce

mpi4jax.allreduce(x, op, *, comm=None, token=None)

Perform an allreduce operation.

Note

This primitive can be differentiated via jax.grad() and related functions if op is mpi4py.MPI.SUM.

Parameters
  • x – Array or scalar input.

  • op (mpi4py.MPI.Op) – The reduction operator (e.g. mpi4py.MPI.SUM).

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns

  • Result of the allreduce operation.

  • A new, modified token that depends on this operation.

Return type

Tuple[DeviceArray, Token]
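
Example (illustrative sketch, not from the upstream reference; assumes an MPI launch and the usual jax/mpi4py imports):

>>> from mpi4py import MPI
>>> import jax
>>> import jax.numpy as jnp
>>> import mpi4jax
>>> def global_sum(x):
...     res, _ = mpi4jax.allreduce(x, op=MPI.SUM)
...     return res.sum()
>>> grads = jax.grad(global_sum)(jnp.ones(4))   # differentiable because op is MPI.SUM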

alltoall

mpi4jax.alltoall(x, *, comm=None, token=None)

Perform an alltoall operation.

Parameters
  • x – Array input to send. First axis must have size nproc.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns

  • Received data.

  • A new, modified token that depends on this operation.

Return type

Tuple[DeviceArray, Token]
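
Example (illustrative sketch, not from the upstream reference; assumes an MPI launch):

>>> from mpi4py import MPI
>>> import jax.numpy as jnp
>>> import mpi4jax
>>> nproc = MPI.COMM_WORLD.Get_size()
>>> x = jnp.ones((nproc, 2))              # first axis must have size nproc
>>> out, token = mpi4jax.alltoall(x)
>>> # row i of x is sent to process i; out collects one row from every process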

barrier

mpi4jax.barrier(*, comm=None, token=None)

Perform a barrier operation.

Parameters
  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns

  • A new, modified token that depends on this operation.

Return type

Token
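
Example (illustrative sketch, not from the upstream reference, showing how the returned token can order subsequent operations):

>>> from mpi4py import MPI
>>> import jax.numpy as jnp
>>> import mpi4jax
>>> token = mpi4jax.barrier()
>>> # passing the token ensures the allreduce is ordered after the barrier
>>> res, token = mpi4jax.allreduce(jnp.ones(2), op=MPI.SUM, token=token)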

bcast

mpi4jax.bcast(x, root, *, comm=None, token=None)

Perform a bcast (broadcast) operation.

Warning

Unlike mpi4py’s bcast, this returns a new array with the received data.

Parameters
  • x – Array or scalar input. Data is only read on root process. On non-root processes, this is used to determine the shape and dtype of the result.

  • root (int) – The process to use as source.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns

  • Received data.

  • A new, modified token that depends on this operation.

Return type

Tuple[DeviceArray, Token]
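
Example (illustrative sketch, not from the upstream reference; assumes an MPI launch):

>>> from mpi4py import MPI
>>> import jax.numpy as jnp
>>> import mpi4jax
>>> rank = MPI.COMM_WORLD.Get_rank()
>>> x = jnp.arange(3.0) if rank == 0 else jnp.zeros(3)   # same shape/dtype on all ranks
>>> x, token = mpi4jax.bcast(x, root=0)
>>> # every rank now holds the data from rank 0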

gather

mpi4jax.gather(x, root, *, comm=None, token=None)

Perform a gather operation.

Warning

x must have the same shape and dtype on all processes.

Warning

The shape of the returned data varies between ranks. On the root process, it is (nproc, *input_shape). On all other processes the output is identical to the input.

Parameters
  • x – Array or scalar input to send.

  • root (int) – Rank of the root MPI process.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns

  • Received data on root process, otherwise unmodified input.

  • A new, modified token that depends on this operation.

Return type

Tuple[DeviceArray, Token]
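
Example (illustrative sketch, not from the upstream reference; assumes an MPI launch):

>>> from mpi4py import MPI
>>> import jax.numpy as jnp
>>> import mpi4jax
>>> rank = MPI.COMM_WORLD.Get_rank()
>>> x = jnp.full((2,), rank)              # same shape and dtype on every rank
>>> out, token = mpi4jax.gather(x, root=0)
>>> # rank 0: out has shape (nproc, 2); other ranks: out is identical to x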

recv

mpi4jax.recv(x, source=-1, *, tag=-1, comm=None, status=None, token=None)

Perform a recv (receive) operation.

Warning

Unlike mpi4py’s recv, this returns a new array with the received data.

Parameters
  • x – Array or scalar input with the correct shape and dtype. This can contain arbitrary data and will not be overwritten.

  • source (int) – Rank of the source MPI process.

  • tag (int) – Tag of this message.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • status (mpi4py.MPI.Status) – Status object, can be used for introspection.

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns

  • Received data.

  • A new, modified token that depends on this operation.

Return type

Tuple[DeviceArray, Token]
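
Example (illustrative sketch of a simple point-to-point exchange, not from the upstream reference; assumes at least two MPI processes):

>>> from mpi4py import MPI
>>> import jax.numpy as jnp
>>> import mpi4jax
>>> rank = MPI.COMM_WORLD.Get_rank()
>>> if rank == 0:
...     buf = jnp.zeros(3)                # placeholder with correct shape and dtype
...     data, token = mpi4jax.recv(buf, source=1, tag=7)
... elif rank == 1:
...     token = mpi4jax.send(jnp.arange(3.0), dest=0, tag=7)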

reduce

mpi4jax.reduce(x, op, root, *, comm=None, token=None)

Perform a reduce operation.

Parameters
  • x – Array or scalar input to send.

  • op (mpi4py.MPI.Op) – The reduction operator (e.g. mpi4py.MPI.SUM).

  • root (int) – Rank of the root MPI process.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns

  • Result of the reduce operation on root process, otherwise unmodified input.

  • A new, modified token that depends on this operation.

Return type

Tuple[DeviceArray, Token]
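
Example (illustrative sketch, not from the upstream reference; assumes an MPI launch):

>>> from mpi4py import MPI
>>> import jax.numpy as jnp
>>> import mpi4jax
>>> x = jnp.ones(4)
>>> out, token = mpi4jax.reduce(x, op=MPI.SUM, root=0)
>>> # rank 0: out is the elementwise sum over all ranks; other ranks: out equals x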

scan

mpi4jax.scan(x, op, *, comm=None, token=None)

Perform a scan operation.

Parameters
  • x – Array or scalar input to send.

  • op (mpi4py.MPI.Op) – The reduction operator (e.g. mpi4py.MPI.SUM).

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns

  • Result of the scan operation.

  • A new, modified token that depends on this operation.

Return type

Tuple[DeviceArray, Token]
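
Example (illustrative sketch, not from the upstream reference; the prefix-sum interpretation follows standard MPI_Scan semantics):

>>> from mpi4py import MPI
>>> import jax.numpy as jnp
>>> import mpi4jax
>>> rank = MPI.COMM_WORLD.Get_rank()
>>> x = jnp.array(float(rank))
>>> out, token = mpi4jax.scan(x, op=MPI.SUM)
>>> # out on rank r is the inclusive prefix sum over ranks 0..r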

scatter

mpi4jax.scatter(x, root, *, comm=None, token=None)

Perform a scatter operation.

Warning

Unlike mpi4py’s scatter, this returns a new array with the received data.

Warning

The expected shape of the first input varies between ranks. On the root process, it is (nproc, *input_shape). On all other processes, it is input_shape.

Parameters
  • x – Array or scalar input with the correct shape and dtype. On the root process, this contains the data to send, and its first axis must have size nproc. On non-root processes, this may contain arbitrary data and will not be overwritten.

  • root (int) – Rank of the root MPI process.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns

  • Received data.

  • A new, modified token that depends on this operation.

Return type

Tuple[DeviceArray, Token]
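
Example (illustrative sketch, not from the upstream reference; assumes an MPI launch):

>>> from mpi4py import MPI
>>> import jax.numpy as jnp
>>> import mpi4jax
>>> nproc = MPI.COMM_WORLD.Get_size()
>>> rank = MPI.COMM_WORLD.Get_rank()
>>> if rank == 0:
...     x = jnp.arange(2.0 * nproc).reshape(nproc, 2)   # (nproc, *chunk_shape) on root
... else:
...     x = jnp.zeros(2)                                # placeholder chunk, not read
>>> out, token = mpi4jax.scatter(x, root=0)
>>> # out has shape (2,) on every rank and holds that rank's row of the root array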

send

mpi4jax.send(x, dest, *, tag=0, comm=None, token=None)

Perform a send operation.

Parameters
  • x – Array or scalar input to send.

  • dest (int) – Rank of the destination MPI process.

  • tag (int) – Tag of this message.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns

A new, modified token that depends on this operation.

Return type

Token
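
Example (illustrative sketch, not from the upstream reference; assumes a matching receive is posted on rank 1, as in the recv example above):

>>> import jax.numpy as jnp
>>> import mpi4jax
>>> token = mpi4jax.send(jnp.ones(3), dest=1, tag=0)
>>> # the returned token can be passed to subsequent mpi4jax calls to enforce ordering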

sendrecv

mpi4jax.sendrecv(sendbuf, recvbuf, source, dest, *, sendtag=0, recvtag=-1, comm=None, status=None, token=None)

Perform a sendrecv operation.

Warning

Unlike mpi4py’s sendrecv, this returns a new array with the received data.

Parameters
  • sendbuf – Array or scalar input to send.

  • recvbuf – Array or scalar input with the correct shape and dtype. This can contain arbitrary data and will not be overwritten.

  • source (int) – Rank of the source MPI process.

  • dest (int) – Rank of the destination MPI process.

  • sendtag (int) – Tag of this message for sending.

  • recvtag (int) – Tag of this message for receiving.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • status (mpi4py.MPI.Status) – Status object, can be used for introspection.

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns

  • Received data.

  • A new, modified token that depends on this operation.

Return type

Tuple[DeviceArray, Token]
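
Example (illustrative sketch of a ring shift, not from the upstream reference; assumes an MPI launch):

>>> from mpi4py import MPI
>>> import jax.numpy as jnp
>>> import mpi4jax
>>> comm = MPI.COMM_WORLD
>>> rank, size = comm.Get_rank(), comm.Get_size()
>>> sendbuf = jnp.full((2,), float(rank))
>>> recvbuf = jnp.zeros(2)                # placeholder with correct shape and dtype
>>> out, token = mpi4jax.sendrecv(sendbuf, recvbuf,
...                               source=(rank - 1) % size, dest=(rank + 1) % size)
>>> # out holds the data sent by the left neighbour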

Experimental

auto_tokenize

Warning

auto_tokenize is currently broken for JAX 0.4.4 and later. To use it, downgrade to jax<=0.4.3. See issue #192 for more details.

mpi4jax.experimental.auto_tokenize(f)

Automatically manage tokens between all mpi4jax ops.

Supports most JAX operations, including fori_loop, cond, jit, while_loop, and scan. This transforms all operations in the decorated function, even through subfunctions and nested applications of jax.jit.

Note

This transform completely overrides all mpi4jax token management. Do not use it if you need manual control over token management.

Parameters

f – Any function that uses mpi4jax primitives (jitted or not).

Returns

A transformed version of f that automatically manages all mpi4jax tokens.

Example

>>> from mpi4py import MPI
>>> import jax.numpy as jnp
>>> from mpi4jax import allreduce
>>> from mpi4jax.experimental import auto_tokenize
>>> @auto_tokenize
... def f(a):
...     # no token handling necessary
...     res, _ = allreduce(a, op=MPI.SUM)
...     res, _ = allreduce(res, op=MPI.SUM)
...     return res
>>> arr = jnp.ones((3, 2))
>>> res = f(arr)