API Reference

Utilities

has_cuda_support

mpi4jax.has_cuda_support() -> bool

Returns True if mpi4jax was built with CUDA support and can be used with GPU-backed JAX arrays, False otherwise.
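
Example

A minimal sketch of how this query might be used; forcing the CPU backend via the jax_platform_name option is an illustrative assumption, not something mpi4jax does itself:

import jax
import mpi4jax

# Fall back to CPU arrays when mpi4jax was built without CUDA support,
# so GPU-backed arrays are never handed to the CPU-only MPI path.
if not mpi4jax.has_cuda_support():
    jax.config.update("jax_platform_name", "cpu")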

Communication primitives

allgather

mpi4jax.allgather(x, *, comm=None, token=None)

Perform an allgather operation.

Warning

x must have the same shape and dtype on all processes.

Parameters
  • x – Array or scalar input to send.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns

  • Received data.

  • A new, modified token that depends on this operation.

Return type

Tuple[DeviceArray, Token]
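
Example

A minimal usage sketch, not taken from the original reference; it assumes the script is launched on several MPI processes (e.g. mpirun -n 4 python script.py):

import jax.numpy as jnp
from mpi4py import MPI
import mpi4jax

comm = MPI.COMM_WORLD

# Every rank contributes one scalar and receives all contributions,
# stacked along a new leading axis of length comm.Get_size().
x = jnp.array(comm.Get_rank(), dtype=jnp.float32)
gathered, token = mpi4jax.allgather(x, comm=comm)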

allreduce

mpi4jax.allreduce(x, op, *, comm=None, token=None)

Perform an allreduce operation.

Note

This primitive can be differentiated via jax.grad() and related functions if op is mpi4py.MPI.SUM.

Parameters
  • x – Array or scalar input.

  • op (mpi4py.MPI.Op) – The reduction operator (e.g. mpi4py.MPI.SUM).

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns

  • Result of the allreduce operation.

  • A new, modified token that depends on this operation.

Return type

Tuple[DeviceArray, Token]
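
Example

A minimal usage sketch (an illustration, not from the original reference), assuming several MPI processes; it shows the primitive being used inside a jax.jit-compiled function:

import jax
import jax.numpy as jnp
from mpi4py import MPI
import mpi4jax

@jax.jit
def global_mean(x):
    # Sum the local arrays over all ranks, then divide by the number of ranks.
    total, _ = mpi4jax.allreduce(x, op=MPI.SUM)
    return total / MPI.COMM_WORLD.Get_size()

result = global_mean(jnp.ones((3,)))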

alltoall

mpi4jax.alltoall(x, *, comm=None, token=None)

Perform an alltoall operation.

Parameters
  • x – Array input to send. First axis must have size nproc.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns

  • Received data.

  • A new, modified token that depends on this operation.

Return type

Tuple[DeviceArray, Token]
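
Example

A minimal usage sketch (illustrative only), assuming nproc MPI processes; the first axis of the input must have size nproc:

import jax.numpy as jnp
from mpi4py import MPI
import mpi4jax

comm = MPI.COMM_WORLD
size = comm.Get_size()

# Row i of x is sent to rank i; row i of the result was sent by rank i.
x = jnp.full((size, 2), comm.Get_rank(), dtype=jnp.float32)
received, token = mpi4jax.alltoall(x, comm=comm)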

barrier

mpi4jax.barrier(*, comm=None, token=None)

Perform a barrier operation.

Parameters
  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns

  • A new, modified token that depends on this operation.

Return type

Token
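
Example

A minimal usage sketch (illustrative only) that synchronizes all ranks and chains the returned token into a subsequent operation:

import jax.numpy as jnp
from mpi4py import MPI
import mpi4jax

# Block until every rank has reached this point, then use the returned
# token to order the following allreduce after the barrier.
token = mpi4jax.barrier()
result, token = mpi4jax.allreduce(jnp.ones(()), op=MPI.SUM, token=token)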

bcast

mpi4jax.bcast(x, root, *, comm=None, token=None)

Perform a bcast (broadcast) operation.

Warning

Unlike mpi4py’s bcast, this returns a new array with the received data.

Parameters
  • x – Array or scalar input. Data is only read on root process. On non-root processes, this is used to determine the shape and dtype of the result.

  • root (int) – The process to use as source.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns

  • Received data.

  • A new, modified token that depends on this operation.

Return type

Tuple[DeviceArray, Token]
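
Example

A minimal usage sketch (illustrative only), assuming several MPI processes:

import jax.numpy as jnp
from mpi4py import MPI
import mpi4jax

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

# The contents only matter on the root; on other ranks the array merely
# fixes the shape and dtype of the result.
x = jnp.full((2, 2), rank, dtype=jnp.float32)
x, token = mpi4jax.bcast(x, root=0, comm=comm)
# x now holds rank 0's data on every process.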

gather

mpi4jax.gather(x, root, *, comm=None, token=None)

Perform a gather operation.

Warning

x must have the same shape and dtype on all processes.

Warning

The shape of the returned data varies between ranks. On the root process, it is (nproc, *input_shape). On all other processes the output is identical to the input.

Parameters
  • x – Array or scalar input to send.

  • root (int) – Rank of the root MPI process.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns

  • Received data on the root process; otherwise, the unmodified input.

  • A new, modified token that depends on this operation.

Return type

Tuple[DeviceArray, Token]
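
Example

A minimal usage sketch (illustrative only), assuming several MPI processes:

import jax.numpy as jnp
from mpi4py import MPI
import mpi4jax

comm = MPI.COMM_WORLD

x = jnp.full((3,), comm.Get_rank(), dtype=jnp.float32)
res, token = mpi4jax.gather(x, root=0, comm=comm)
# On rank 0, res has shape (nproc, 3); on all other ranks, res equals x.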

recv

mpi4jax.recv(x, source=-1, *, tag=-1, comm=None, status=None, token=None)

Perform a recv (receive) operation.

Warning

Unlike mpi4py’s recv, this returns a new array with the received data.

Parameters
  • x – Array or scalar input with the correct shape and dtype. This can contain arbitrary data and will not be overwritten.

  • source (int) – Rank of the source MPI process.

  • tag (int) – Tag of this message.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • status (mpi4py.MPI.Status) – Status object, which can be used for introspection.

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns

  • Received data.

  • A new, modified token that depends on this operation.

Return type

Tuple[DeviceArray, Token]
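
Example

A minimal sketch of a matched send/recv pair (illustrative only), assuming at least two MPI processes:

import jax.numpy as jnp
from mpi4py import MPI
import mpi4jax

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

if rank == 0:
    token = mpi4jax.send(jnp.arange(4.0), dest=1, tag=0, comm=comm)
elif rank == 1:
    # The buffer only supplies the shape and dtype; its values are ignored.
    buf = jnp.zeros((4,))
    data, token = mpi4jax.recv(buf, source=0, tag=0, comm=comm)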

reduce

mpi4jax.reduce(x, op, root, *, comm=None, token=None)

Perform a reduce operation.

Parameters
  • x – Array or scalar input to send.

  • op (mpi4py.MPI.Op) – The reduction operator (e.g. mpi4py.MPI.SUM).

  • root (int) – Rank of the root MPI process.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns

  • Result of the reduce operation on the root process; otherwise, the unmodified input.

  • A new, modified token that depends on this operation.

Return type

Tuple[DeviceArray, Token]
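
Example

A minimal usage sketch (illustrative only), assuming several MPI processes:

import jax.numpy as jnp
from mpi4py import MPI
import mpi4jax

comm = MPI.COMM_WORLD

x = jnp.ones((2,))
res, token = mpi4jax.reduce(x, op=MPI.SUM, root=0, comm=comm)
# On rank 0, res is the element-wise sum over all ranks;
# on every other rank, res is the unmodified input.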

scan

mpi4jax.scan(x, op, *, comm=None, token=None)

Perform a scan operation.

Parameters
  • x – Array or scalar input to send.

  • op (mpi4py.MPI.Op) – The reduction operator (e.g. mpi4py.MPI.SUM).

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns

  • Result of the scan operation.

  • A new, modified token that depends on this operation.

Return type

Tuple[DeviceArray, Token]
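
Example

A minimal usage sketch (illustrative only), assuming the usual inclusive MPI_Scan semantics, i.e. rank i receives the reduction over ranks 0 through i:

import jax.numpy as jnp
from mpi4py import MPI
import mpi4jax

comm = MPI.COMM_WORLD

x = jnp.ones(())
res, token = mpi4jax.scan(x, op=MPI.SUM, comm=comm)
# Inclusive prefix sum: on rank i, res should equal i + 1.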

scatter

mpi4jax.scatter(x, root, *, comm=None, token=None)

Perform a scatter operation.

Warning

Unlike mpi4py’s scatter, this returns a new array with the received data.

Warning

The expected shape of the first input varies between ranks. On the root process, it is (nproc, *input_shape). On all other processes, it is input_shape.

Parameters
  • x – Array or scalar input with the correct shape and dtype. On the root process, this contains the data to send, and its first axis must have size nproc. On non-root processes, this may contain arbitrary data and will not be overwritten.

  • root (int) – Rank of the root MPI process.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns

  • Received data.

  • A new, modified token that depends on this operation.

Return type

Tuple[DeviceArray, Token]
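
Example

A minimal usage sketch (illustrative only), assuming several MPI processes:

import jax.numpy as jnp
from mpi4py import MPI
import mpi4jax

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

if rank == 0:
    # Row i is delivered to rank i.
    x = jnp.arange(size * 3, dtype=jnp.float32).reshape(size, 3)
else:
    # On non-root ranks, only the shape and dtype of x are used.
    x = jnp.zeros((3,), dtype=jnp.float32)
res, token = mpi4jax.scatter(x, root=0, comm=comm)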

send

mpi4jax.send(x, dest, *, tag=0, comm=None, token=None)

Perform a send operation.

Parameters
  • x – Array or scalar input to send.

  • dest (int) – Rank of the destination MPI process.

  • tag (int) – Tag of this message.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns

A new, modified token that depends on this operation.

Return type

Token
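
Example

A minimal sketch of the sending side only (illustrative); it assumes at least two MPI processes and that rank 1 posts two matching recv calls (see recv above):

import jax.numpy as jnp
from mpi4py import MPI
import mpi4jax

comm = MPI.COMM_WORLD

if comm.Get_rank() == 0:
    # Chain the token so the two messages are issued in program order.
    token = mpi4jax.send(jnp.zeros((2,)), dest=1, tag=0, comm=comm)
    token = mpi4jax.send(jnp.ones((2,)), dest=1, tag=1, comm=comm, token=token)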

sendrecv

mpi4jax.sendrecv(sendbuf, recvbuf, source, dest, *, sendtag=0, recvtag=-1, comm=None, status=None, token=None)

Perform a sendrecv operation.

Warning

Unlike mpi4py’s sendrecv, this returns a new array with the received data.

Parameters
  • sendbuf – Array or scalar input to send.

  • recvbuf – Array or scalar input with the correct shape and dtype. This can contain arbitrary data and will not be overwritten.

  • source (int) – Rank of the source MPI process.

  • dest (int) – Rank of the destination MPI process.

  • sendtag (int) – Tag of this message for sending.

  • recvtag (int) – Tag of this message for receiving.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • status (mpi4py.MPI.Status) – Status object, which can be used for introspection.

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns

  • Received data.

  • A new, modified token that depends on this operation.

Return type

Tuple[DeviceArray, Token]
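
Example

A minimal sketch of a ring exchange (illustrative only), assuming several MPI processes:

import jax.numpy as jnp
from mpi4py import MPI
import mpi4jax

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

# Send to the right neighbour while receiving from the left neighbour.
right = (rank + 1) % size
left = (rank - 1) % size

sendbuf = jnp.full((2,), rank, dtype=jnp.float32)
recvbuf = jnp.zeros((2,), dtype=jnp.float32)  # shape/dtype template only
data, token = mpi4jax.sendrecv(sendbuf, recvbuf, source=left, dest=right, comm=comm)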

Experimental

auto_tokenize

Warning

auto_tokenize is currently broken for JAX 0.4.4 and later. To use it, downgrade to jax<=0.4.3. See issue #192 for more details.

mpi4jax.experimental.auto_tokenize(f)

Automatically manage tokens between all mpi4jax ops.

Supports most JAX operations, including fori_loop, cond, jit, while_loop, and scan. This transforms all operations in the decorated function, even through subfunctions and nested applications of jax.jit.

Note

This transform completely overrides mpi4jax's token management. Do not use it if you need manual control over token management.

Parameters

f – Any function that uses mpi4jax primitives (jitted or not).

Returns

A transformed version of f that automatically manages all mpi4jax tokens.

Example

>>> from mpi4py import MPI
>>> import jax.numpy as jnp
>>> from mpi4jax import allreduce
>>> from mpi4jax.experimental import auto_tokenize
>>> @auto_tokenize
... def f(a):
...     # No manual token handling is necessary inside f.
...     res, _ = allreduce(a, op=MPI.SUM)
...     res, _ = allreduce(res, op=MPI.SUM)
...     return res
>>> arr = jnp.ones((3, 2))
>>> res = f(arr)