To install mpi4jax in editable mode along with all optional dependencies for testing, just run
$ pip install -e .[dev]
from the repository root.
We use pytest for testing. After installing the development dependencies, you can run our test suite with the following commands:
$ pytest .
$ mpirun -np 2 pytest .
Plain pytest runs the tests on a single process, which means that a large part of mpi4jax cannot be tested (because it relies on communication between different processes). Therefore, you should always make sure that the tests also pass on multiple processes (via mpirun).
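If you want to run the single-process and multi-process checks in one go, a minimal Python sketch could look like the following. The helper names build_commands and run_all are hypothetical (they are not part of mpi4jax), and the sketch assumes that pytest and mpirun are on your PATH.

```python
import subprocess


def build_commands(process_counts=(1, 2), path="."):
    """Build one pytest invocation per process count (hypothetical helper).

    A count of 1 runs pytest directly; larger counts wrap it in mpirun,
    mirroring `pytest .` and `mpirun -np 2 pytest .`.
    """
    commands = []
    for n in process_counts:
        if n == 1:
            commands.append(["pytest", path])
        else:
            commands.append(["mpirun", "-np", str(n), "pytest", path])
    return commands


def run_all(process_counts=(1, 2), path="."):
    """Run every invocation and return True only if all of them exit with 0."""
    all_passed = True
    for cmd in build_commands(process_counts, path):
        print("$", " ".join(cmd))
        # Keep going after a failure so every process count gets reported.
        if subprocess.run(cmd).returncode != 0:
            all_passed = False
    return all_passed
```

For example, calling run_all((1, 2, 4)) from the repository root would additionally run the suite on 4 processes, if your machine and MPI installation support that many.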
We welcome code contributions or changes to the documentation via pull requests (PRs).
We use pre-commit hooks to enforce a common code format. To install them, just run:
$ pre-commit install
in the repository root. Then, all changes will be validated automatically before you commit.
If you introduce new code, please make sure that it is covered by tests. To catch problems early, we recommend that you run the test suite locally before creating a PR.
You can set the environment variable MPI4JAX_DEBUG=1 to enable debug logging every time an MPI primitive is called from within a jitted function. You will then see messages like these:
$ MPI4JAX_DEBUG=1 mpirun -n 2 python send_recv.py
r0 | MPI_Send -> 1 with tag 0 and token 7fd7abc5f5c0
r1 | MPI_Recv <- 0 with tag -1 and token 7f9af7419ac0
This can be useful to debug deadlocks or MPI errors.
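As a rough aid when hunting deadlocks, you could tally the send and receive messages in such a log. The following is a hypothetical helper, not part of mpi4jax: it assumes every point-to-point line follows exactly the format shown above, only looks at MPI_Send and MPI_Recv, and ignores tags. Receives posted with a wildcard source cannot be attributed to a sender and will therefore show up as unmatched.

```python
import re
from collections import Counter

# Hypothetical pattern for log lines such as
#   r0 | MPI_Send -> 1 with tag 0 and token 7fd7abc5f5c0
#   r1 | MPI_Recv <- 0 with tag -1 and token 7f9af7419ac0
# Tags are matched but not captured, i.e. they are ignored when pairing.
LINE_RE = re.compile(
    r"r(?P<rank>\d+) \| (?P<op>MPI_Send|MPI_Recv) (?:->|<-) (?P<peer>-?\d+) "
    r"with tag -?\d+"
)


def point_to_point_counts(log_text):
    """Count sends and receives per (sender, receiver) pair."""
    sends, recvs = Counter(), Counter()
    for match in LINE_RE.finditer(log_text):
        rank, peer = int(match["rank"]), int(match["peer"])
        if match["op"] == "MPI_Send":
            sends[(rank, peer)] += 1  # this rank sends to peer
        else:
            recvs[(peer, rank)] += 1  # this rank receives from peer
    return sends, recvs


def unmatched(log_text):
    """Return {(sender, receiver): sends - receives} for every imbalanced pair."""
    sends, recvs = point_to_point_counts(log_text)
    pairs = set(sends) | set(recvs)
    return {p: sends[p] - recvs[p] for p in pairs if sends[p] != recvs[p]}
```

For the two log lines above, unmatched returns an empty dict. If rank 0 logged a second MPI_Send to rank 1 with no matching MPI_Recv, it would return {(0, 1): 1}, pointing at the pair to inspect first.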