"""mpi4jax demo application -- Shallow water

A non-linear shallow water solver, adapted from:

https://github.com/dionhaefner/shallow-water

Usage examples:

    # runs demo on 4 processes
    $ mpirun -n 4 python shallow_water.py

    # saves output animation as shallow-water.mp4
    $ mpirun -n 4 python shallow_water.py --save-animation

    # runs demo as a benchmark (no output)
    $ mpirun -n 4 python shallow_water.py --benchmark

"""
19
20import os
21import sys
22import math
23import time
24import warnings
25from contextlib import ExitStack
26from collections import namedtuple
27from functools import partial
28
29import numpy as np
30from mpi4py import MPI
31
# tqdm is optional: HAS_TQDM records whether a progress bar can be shown
try:
    import tqdm

    HAS_TQDM = True
except ImportError:
    HAS_TQDM = False
    warnings.warn("Could not import tqdm, can't show progress bar")
39
mpi_comm = MPI.COMM_WORLD
mpi_rank = mpi_comm.Get_rank()
mpi_size = mpi_comm.Get_size()

# on GPU, put each process on its own device
# Device indices are per node, so the index is derived from the rank within
# the group of processes sharing this node rather than from the world rank
# (the two coincide when all processes run on a single node). This must be
# set before JAX is imported below.
mpi_local_rank = mpi_comm.Split_type(MPI.COMM_TYPE_SHARED).Get_rank()
os.environ["CUDA_VISIBLE_DEVICES"] = str(mpi_local_rank)
46
47import jax # noqa: E402
48import jax.numpy as jnp # noqa: E402
49
50import mpi4jax # noqa: E402
51
52
#
# MPI setup
#

# process counts this demo accepts
supported_nproc = (1, 2, 4, 6, 8, 16)
if mpi_size not in supported_nproc:
    raise RuntimeError(
        f"Got invalid number of MPI processes: {mpi_size}. "
        f"Please choose one of these: {supported_nproc}."
    )

# lay the processes out on a (nproc_y, nproc_x) grid with at most two rows
nproc_y = 2 if mpi_size > 2 else mpi_size
nproc_x = mpi_size // nproc_y

# (row, column) position of this process on that grid
proc_idx = np.unravel_index(mpi_rank, (nproc_y, nproc_x))
68
#
# Grid setup
#

# global cell counts; the "+ 2" is the 1-cell overlap on each side of the domain
nx_global = 360 + 2
ny_global = 180 + 2

# grid spacing in metres
dx = 5e3
dy = 5e3

# the interior (non-overlap) cells must split evenly across the process grid
assert (nx_global - 2) % nproc_x == 0
assert (ny_global - 2) % nproc_y == 0

# per-process cell counts: an equal share of the interior plus the overlap
nx_local = (nx_global - 2) // nproc_x + 2
ny_local = (ny_global - 2) // nproc_y + 2

# coordinate axes start at index -1 so that the first interior cell sits at 0
x_global = np.arange(-1, nx_global - 1) * dx
y_global = np.arange(-1, ny_global - 1) * dy
yy_global, xx_global = np.meshgrid(y_global, x_global, indexing="ij")

# extent of the domain measured between the first and last interior cells
length_x = x_global[-2] - x_global[1]
length_y = y_global[-2] - y_global[1]

# this extracts the processor-local domain (overlap included) from a global
# array; entries are ordered (y, x) like the array axes
local_slice = tuple(
    slice(idx * (n_local - 2), idx * (n_local - 2) + n_local)
    for idx, n_local in zip(proc_idx, (ny_local, nx_local))
)

x_local = x_global[local_slice[1]]
y_local = y_global[local_slice[0]]

xx_local = xx_global[local_slice]
yy_local = yy_global[local_slice]
108
#
# Model parameters
#

# physical parameters
GRAVITY = 9.81
DEPTH = 100.0
CORIOLIS_F = 2e-4
CORIOLIS_BETA = 2e-11
# Coriolis parameter on the local grid; varies linearly with the y coordinate
CORIOLIS_PARAM = CORIOLIS_F + CORIOLIS_BETA * yy_local
LATERAL_VISCOSITY = 1e-3 * CORIOLIS_F * dx**2

# other parameters
DAY_IN_SECONDS = 86_400
PERIODIC_BOUNDARY_X = True

# weights applied to the new and the previous tendency in the time stepping
ADAMS_BASHFORTH_A = 1.5 + 0.1
ADAMS_BASHFORTH_B = -(0.5 + 0.1)

# output parameters
PLOT_ETA_RANGE = 10
PLOT_EVERY = 100
MAX_QUIVERS = (25, 50)


# set time step based on CFL condition
dt = 0.125 * min(dx, dy) / np.sqrt(GRAVITY * DEPTH)
136
137
@jax.jit
def get_initial_conditions():
    """For the initial conditions, we use a horizontal jet in geostrophic balance.

    The fields are built on the global grid, cut down to this process'
    sub-domain with ``local_slice`` and passed through ``enforce_boundaries``.

    Returns:
        Tuple ``(h0_local, u0_local, v0_local)`` of process-local arrays.
    """
    # global initial conditions: jet in u centred at half the domain height
    u0_global = 10 * jnp.exp(
        -((yy_global - 0.5 * length_y) ** 2) / (0.02 * length_x) ** 2
    )
    v0_global = jnp.zeros_like(u0_global)

    # approximate balance h_y = -(f/g)u
    coriolis_global = CORIOLIS_F + yy_global * CORIOLIS_BETA
    # the integrand is a traced JAX value inside jit, so integrate it with
    # jax.numpy rather than relying on NumPy dispatching to the array method
    h_geostrophy = jnp.cumsum(-dy * u0_global * coriolis_global / GRAVITY, axis=0)
    h0_global = (
        DEPTH
        + h_geostrophy
        # make sure h0 is centered around depth
        - h_geostrophy.mean()
        # small perturbation to break symmetry; its inputs are plain NumPy
        # arrays, so NumPy evaluates it once while the function is traced
        + 0.2
        * np.sin(xx_global / length_x * 10 * np.pi)
        * np.cos(yy_global / length_y * 8 * np.pi)
    )

    h0_local = h0_global[local_slice]
    u0_local = u0_global[local_slice]
    v0_local = v0_global[local_slice]

    # thread one token through the exchanges so they keep their order
    token = jax.lax.create_token()
    h0_local, token = enforce_boundaries(h0_local, "h", token)
    u0_local, token = enforce_boundaries(u0_local, "u", token)
    v0_local, token = enforce_boundaries(v0_local, "v", token)

    return h0_local, u0_local, v0_local
171
172
@partial(jax.jit, static_argnums=(1,))
def enforce_boundaries(arr, grid, token=None):
    """Exchange overlap cells with neighbouring processes and apply walls.

    This is where mpi4jax comes in!

    Args:
        arr: Process-local array; its outermost rows and columns are the
            overlap cells that get filled from the neighbours.
        grid: One of ``"h"``, ``"u"`` or ``"v"`` (static under ``jit``);
            selects which wall condition is applied at the end.
        token: Token threaded through the mpi4jax calls. A new one is created
            when omitted.

    Returns:
        Tuple ``(arr, token)`` with the updated array and the latest token.
    """
    assert grid in ("h", "u", "v")

    row, col = proc_idx
    last_row, last_col = nproc_y - 1, nproc_x - 1
    full = slice(None)

    # (direction we send towards, direction we receive from), listed in the
    # order the exchanges are issued: sends go west, north, east, south
    exchanges = (
        ("west", "east"),
        ("north", "south"),
        ("east", "west"),
        ("south", "north"),
    )

    # outermost interior row/column on each side: the data that is sent
    send_edge = {
        "south": (1, full, Ellipsis),
        "west": (full, 1, Ellipsis),
        "north": (-2, full, Ellipsis),
        "east": (full, -2, Ellipsis),
    }

    # overlap row/column on each side: where received data is stored
    recv_edge = {
        "south": (0, full, Ellipsis),
        "west": (full, 0, Ellipsis),
        "north": (-1, full, Ellipsis),
        "east": (full, -1, Ellipsis),
    }

    # process-grid coordinates of the adjacent processes; None where this
    # process sits on the edge of the process grid
    neighbors = {
        "south": (row - 1, col) if row > 0 else None,
        "west": (row, col - 1) if col > 0 else None,
        "north": (row + 1, col) if row < last_row else None,
        "east": (row, col + 1) if col < last_col else None,
    }

    if PERIODIC_BOUNDARY_X:
        # wrap around in x: first and last process column are neighbours
        if col == 0:
            neighbors["west"] = (row, last_col)

        if col == last_col:
            neighbors["east"] = (row, 0)

    def to_rank(coords):
        """Convert process-grid coordinates to an MPI rank (None stays None)."""
        if coords is None:
            return None
        return np.ravel_multi_index(coords, (nproc_y, nproc_x))

    if token is None:
        token = jax.lax.create_token()

    for send_dir, recv_dir in exchanges:
        dest = to_rank(neighbors[send_dir])
        source = to_rank(neighbors[recv_dir])

        if dest is None and source is None:
            # nothing to exchange along this direction
            continue

        send_idx = send_edge[send_dir]
        recv_idx = recv_edge[recv_dir]

        if dest is not None and source is not None:
            incoming, token = mpi4jax.sendrecv(
                arr[send_idx],
                jnp.empty_like(arr[recv_idx]),
                source=source,
                dest=dest,
                comm=mpi_comm,
                token=token,
            )
            arr = arr.at[recv_idx].set(incoming)
        elif source is not None:
            # no process to send to, receive only
            incoming, token = mpi4jax.recv(
                jnp.empty_like(arr[recv_idx]),
                source=source,
                comm=mpi_comm,
                token=token,
            )
            arr = arr.at[recv_idx].set(incoming)
        else:
            # no process to receive from, send only
            token = mpi4jax.send(arr[send_idx], dest=dest, comm=mpi_comm, token=token)

    # closed walls: zero u on the last interior column of the last process
    # column (only without periodic x) and v on the last interior row of the
    # last process row
    if grid == "u" and col == last_col and not PERIODIC_BOUNDARY_X:
        arr = arr.at[:, -2].set(0.0)

    if grid == "v" and row == last_row:
        arr = arr.at[-2, :].set(0.0)

    return arr, token
272
273
# Prognostic fields (h, u, v) followed by their most recent tendencies
ModelState = namedtuple("ModelState", ["h", "u", "v", "dh", "du", "dv"])
275
276
@partial(jax.jit, static_argnums=(1,))
def shallow_water_step(state, is_first_step):
    """Perform one step of the shallow-water model.

    Args:
        state: ``ModelState`` holding ``h``, ``u``, ``v`` and the tendencies
            ``dh``, ``du``, ``dv`` computed by the previous step.
        is_first_step: Static flag. When true the fields are advanced with
            ``dt`` times the new tendencies only; otherwise the new and the
            previous tendencies are combined using ``ADAMS_BASHFORTH_A`` and
            ``ADAMS_BASHFORTH_B``.

    Returns:
        ``ModelState`` with the advanced ``h``, ``u``, ``v`` and the newly
        computed tendencies.
    """
    token = jax.lax.create_token()

    h, u, v, dh, du, dv = state

    # copy of h whose overlap cells are first filled from the nearest interior
    # cell and then overwritten by whatever the neighbours send
    hc = jnp.pad(h[1:-1, 1:-1], 1, "edge")
    hc, token = enforce_boundaries(hc, "h", token)

    fe = jnp.empty_like(u)
    fn = jnp.empty_like(u)

    # fluxes: h averaged onto the u / v points, times the velocity
    fe = fe.at[1:-1, 1:-1].set(0.5 * (hc[1:-1, 1:-1] + hc[1:-1, 2:]) * u[1:-1, 1:-1])
    fn = fn.at[1:-1, 1:-1].set(0.5 * (hc[1:-1, 1:-1] + hc[2:, 1:-1]) * v[1:-1, 1:-1])
    fe, token = enforce_boundaries(fe, "u", token)
    fn, token = enforce_boundaries(fn, "v", token)

    # tendency of h: minus the divergence of the fluxes
    dh_new = dh.at[1:-1, 1:-1].set(
        -(fe[1:-1, 1:-1] - fe[1:-1, :-2]) / dx - (fn[1:-1, 1:-1] - fn[:-2, 1:-1]) / dy
    )

    # nonlinear momentum equation
    q = jnp.empty_like(u)
    ke = jnp.empty_like(u)

    # planetary and relative vorticity
    q = q.at[1:-1, 1:-1].set(
        CORIOLIS_PARAM[1:-1, 1:-1]
        + ((v[1:-1, 2:] - v[1:-1, 1:-1]) / dx - (u[2:, 1:-1] - u[1:-1, 1:-1]) / dy)
    )
    # potential vorticity
    q = q.at[1:-1, 1:-1].mul(
        1.0 / (0.25 * (hc[1:-1, 1:-1] + hc[1:-1, 2:] + hc[2:, 1:-1] + hc[2:, 2:]))
    )
    q, token = enforce_boundaries(q, "h", token)

    # pressure gradient plus vorticity term
    du_new = du.at[1:-1, 1:-1].set(
        -GRAVITY * (h[1:-1, 2:] - h[1:-1, 1:-1]) / dx
        + 0.5
        * (
            q[1:-1, 1:-1] * 0.5 * (fn[1:-1, 1:-1] + fn[1:-1, 2:])
            + q[:-2, 1:-1] * 0.5 * (fn[:-2, 1:-1] + fn[:-2, 2:])
        )
    )
    dv_new = dv.at[1:-1, 1:-1].set(
        -GRAVITY * (h[2:, 1:-1] - h[1:-1, 1:-1]) / dy
        - 0.5
        * (
            q[1:-1, 1:-1] * 0.5 * (fe[1:-1, 1:-1] + fe[2:, 1:-1])
            + q[1:-1, :-2] * 0.5 * (fe[1:-1, :-2] + fe[2:, :-2])
        )
    )
    # kinetic energy from the squared velocities averaged over adjacent points
    ke = ke.at[1:-1, 1:-1].set(
        0.5
        * (
            0.5 * (u[1:-1, 1:-1] ** 2 + u[1:-1, :-2] ** 2)
            + 0.5 * (v[1:-1, 1:-1] ** 2 + v[:-2, 1:-1] ** 2)
        )
    )
    ke, token = enforce_boundaries(ke, "h", token)

    # subtract the kinetic energy gradient from the momentum tendencies
    du_new = du_new.at[1:-1, 1:-1].add(-(ke[1:-1, 2:] - ke[1:-1, 1:-1]) / dx)
    dv_new = dv_new.at[1:-1, 1:-1].add(-(ke[2:, 1:-1] - ke[1:-1, 1:-1]) / dy)

    if is_first_step:
        # no previous tendencies available yet: use the new ones alone
        u = u.at[1:-1, 1:-1].add(dt * du_new[1:-1, 1:-1])
        v = v.at[1:-1, 1:-1].add(dt * dv_new[1:-1, 1:-1])
        h = h.at[1:-1, 1:-1].add(dt * dh_new[1:-1, 1:-1])
    else:
        # weighted combination of the new and the previous tendencies
        u = u.at[1:-1, 1:-1].add(
            dt
            * (
                ADAMS_BASHFORTH_A * du_new[1:-1, 1:-1]
                + ADAMS_BASHFORTH_B * du[1:-1, 1:-1]
            )
        )
        v = v.at[1:-1, 1:-1].add(
            dt
            * (
                ADAMS_BASHFORTH_A * dv_new[1:-1, 1:-1]
                + ADAMS_BASHFORTH_B * dv[1:-1, 1:-1]
            )
        )
        h = h.at[1:-1, 1:-1].add(
            dt
            * (
                ADAMS_BASHFORTH_A * dh_new[1:-1, 1:-1]
                + ADAMS_BASHFORTH_B * dh[1:-1, 1:-1]
            )
        )

    h, token = enforce_boundaries(h, "h", token)
    u, token = enforce_boundaries(u, "u", token)
    v, token = enforce_boundaries(v, "v", token)

    if LATERAL_VISCOSITY > 0:
        # lateral friction on u: viscosity times the gradient of u
        fe = fe.at[1:-1, 1:-1].set(
            LATERAL_VISCOSITY * (u[1:-1, 2:] - u[1:-1, 1:-1]) / dx
        )
        fn = fn.at[1:-1, 1:-1].set(
            LATERAL_VISCOSITY * (u[2:, 1:-1] - u[1:-1, 1:-1]) / dy
        )
        fe, token = enforce_boundaries(fe, "u", token)
        fn, token = enforce_boundaries(fn, "v", token)

        u = u.at[1:-1, 1:-1].add(
            dt
            * (
                (fe[1:-1, 1:-1] - fe[1:-1, :-2]) / dx
                + (fn[1:-1, 1:-1] - fn[:-2, 1:-1]) / dy
            )
        )

        # lateral friction on v: the differences must be taken between v
        # values (mixing in u here would not be a gradient of v)
        fe = fe.at[1:-1, 1:-1].set(
            LATERAL_VISCOSITY * (v[1:-1, 2:] - v[1:-1, 1:-1]) / dx
        )
        fn = fn.at[1:-1, 1:-1].set(
            LATERAL_VISCOSITY * (v[2:, 1:-1] - v[1:-1, 1:-1]) / dy
        )
        fe, token = enforce_boundaries(fe, "u", token)
        fn, token = enforce_boundaries(fn, "v", token)

        v = v.at[1:-1, 1:-1].add(
            dt
            * (
                (fe[1:-1, 1:-1] - fe[1:-1, :-2]) / dx
                + (fn[1:-1, 1:-1] - fn[:-2, 1:-1]) / dy
            )
        )
        # NOTE(review): u and v are changed here after their last overlap
        # exchange, so the overlap cells handed to the next step predate the
        # friction update -- confirm this is intended.

    return ModelState(h, u, v, dh_new, du_new, dv_new)
413
414
@partial(jax.jit, static_argnums=(1,))
def do_multistep(state, num_steps):
    """Advance ``state`` by ``num_steps`` model steps in one compiled call.

    ``num_steps`` is static under ``jit``. Every step is taken with
    ``is_first_step=False``.
    """

    def advance(_step, current):
        # the loop index is not needed, only the carried model state
        return shallow_water_step(current, False)

    return jax.lax.fori_loop(0, num_steps, advance, state)
421
422
def solve_shallow_water(t1, num_multisteps=10):
    """Iterate the model forward in time.

    Args:
        t1: Model time in seconds up to which to integrate.
        num_multisteps: Number of model steps performed per ``do_multistep``
            call; one snapshot is stored after each call.

    Returns:
        List of ``ModelState``: the initial state, the state after the first
        step, then one state per ``do_multistep`` call.
    """
    # initial conditions
    h, u, v = get_initial_conditions()
    du, dv, dh = (jnp.zeros((ny_local, nx_local)) for _ in range(3))

    state = ModelState(h, u, v, dh, du, dv)
    sol = [state]

    # the very first step has no previous tendencies to combine with
    state = shallow_water_step(state, True)
    sol.append(state)
    t = dt

    if HAS_TQDM:
        pbar = tqdm.tqdm(
            total=math.ceil(t1 / dt),
            disable=mpi_rank != 0,
            unit="model day",
            # the counter is in model steps (unit_scale converts it to days
            # for display only) and one step has already been taken
            initial=1,
            unit_scale=dt / DAY_IN_SECONDS,
            bar_format=(
                "{l_bar}{bar}| {n:.2f}/{total:.2f} [{elapsed}<{remaining}, "
                "{rate_fmt}{postfix}]"
            ),
        )

    # pre-compile JAX kernel; the result is discarded, but wait for it so the
    # asynchronously running warm-up does not leak into the timed section
    do_multistep(state, num_multisteps)[0].block_until_ready()

    start = time.perf_counter()
    with ExitStack() as es:
        if HAS_TQDM:
            es.enter_context(pbar)

        while t < t1:
            state = do_multistep(state, num_multisteps)
            state[0].block_until_ready()
            sol.append(state)

            t += dt * num_multisteps

            # skip the update once t1 is reached so the bar cannot overshoot
            if t < t1 and HAS_TQDM:
                pbar.update(num_multisteps)

    end = time.perf_counter()

    if mpi_rank == 0:
        print(f"\nSolution took {end - start:.2f}s")

    return sol
473
474
@jax.vmap
@jax.jit
def reassemble_array(arr):
    """Stitch the per-process sub-domains back into one global array.

    For each entry along the leading (vmapped) axis:
    shape (mpi_size, ny_local, nx_local) -> (ny_global, nx_global)
    """
    inner_ny = ny_local - 2
    inner_nx = nx_local - 2

    full = jnp.empty((ny_global, nx_global), dtype=arr.dtype)
    for rank in range(mpi_size):
        # position of this rank on the process grid gives its offset
        row, col = np.unravel_index(rank, (nproc_y, nproc_x))
        y0 = inner_ny * row
        x0 = inner_nx * col
        # later ranks overwrite the overlap cells shared with earlier ones
        full = full.at[y0 : y0 + ny_local, x0 : x0 + nx_local].set(arr[rank])

    return full
499
500
def animate_shallow_water(sol):
    """Create a matplotlib animation of the result.

    Args:
        sol: Sequence of ``ModelState`` on the global grid, laid out as
            produced by ``solve_shallow_water`` with
            ``num_multisteps=PLOT_EVERY``: the initial state, the state after
            one step, then one state every ``PLOT_EVERY`` steps.

    Returns:
        The ``matplotlib.animation.FuncAnimation`` object; it has to be kept
        alive by the caller while it is shown or saved.
    """
    import matplotlib.pyplot as plt
    from matplotlib import animation

    # thin out the arrows; the stride is clamped to 1 so that grids smaller
    # than MAX_QUIVERS do not produce an invalid zero slice step
    quiver_stride = (
        slice(1, -1, max(1, ny_global // MAX_QUIVERS[0])),
        slice(1, -1, max(1, nx_global // MAX_QUIVERS[1])),
    )

    # set up figure
    fig = plt.figure(figsize=(6, 4))
    ax = plt.gca()

    # surface plot of height anomaly; midpoints between grid coordinates (in
    # km) serve as edges of the interior cells
    cs = ax.pcolormesh(
        0.5 * (x_global[:-1] + x_global[1:]) / 1e3,
        0.5 * (y_global[:-1] + y_global[1:]) / 1e3,
        sol[0].h[1:-1, 1:-1] - DEPTH,
        vmin=-PLOT_ETA_RANGE,
        vmax=PLOT_ETA_RANGE,
        cmap="RdBu_r",
    )

    # quiver plot of velocity
    cq = ax.quiver(
        xx_global[quiver_stride] / 1e3,
        yy_global[quiver_stride] / 1e3,
        sol[0].u[quiver_stride],
        sol[0].v[quiver_stride],
        clip_on=True,
    )

    # time indicator
    t = ax.text(
        s="",
        x=0.05,
        y=0.95,
        ha="left",
        va="top",
        backgroundcolor=(1, 1, 1, 0.8),
        transform=ax.transAxes,
    )

    ax.set(
        aspect="equal",
        xlim=(x_global[1] / 1e3, x_global[-2] / 1e3),
        ylim=(y_global[1] / 1e3, y_global[-2] / 1e3),
        xlabel="$x$ (km)",
        ylabel="$y$ (km)",
    )

    plt.colorbar(
        cs,
        orientation="horizontal",
        label="Surface height anomaly (m)",
        pad=0.2,
        shrink=0.8,
    )
    fig.tight_layout()

    def animate(i):
        """Update the artists for frame ``i`` and return them for blitting."""
        state = sol[i]

        eta = state.h - DEPTH
        cs.set_array(eta[1:-1, 1:-1].flatten())
        cq.set_UVC(state.u[quiver_stride], state.v[quiver_stride])

        # frame 0 is the initial state, frame 1 follows the single start-up
        # step, and every later frame is PLOT_EVERY further steps on
        steps_taken = 0 if i == 0 else 1 + (i - 1) * PLOT_EVERY
        current_time = steps_taken * dt
        t.set_text(f"t = {current_time / DAY_IN_SECONDS:.2f} days")
        return (cs, cq, t)

    anim = animation.FuncAnimation(
        fig, animate, frames=len(sol), interval=50, blit=True, repeat_delay=3_000
    )
    return anim
577
578
if __name__ == "__main__":
    benchmark_mode = "--benchmark" in sys.argv

    sol = solve_shallow_water(t1=10 * DAY_IN_SECONDS, num_multisteps=PLOT_EVERY)

    # a benchmark run only needs the timing printed by the solver
    if benchmark_mode:
        sys.exit(0)

    # copy solution to mpi_rank 0
    full_sol_arr, _ = mpi4jax.gather(jnp.asarray(sol), root=0, comm=mpi_comm)

    if mpi_rank == 0:
        # full_sol_arr has shape (nproc, time, nvars, ny, nx); move the
        # process axis behind the variable axis so each snapshot can be
        # reassembled variable by variable
        full_sol_arr = jnp.moveaxis(full_sol_arr, 0, 2)
        full_sol = [
            ModelState(*reassemble_array(snapshot)) for snapshot in full_sol_arr
        ]

        anim = animate_shallow_water(full_sol)

        if "--save-animation" not in sys.argv:
            import matplotlib.pyplot as plt

            plt.show()
        else:
            # save animation as MP4 video (requires ffmpeg)
            anim.save("shallow-water.mp4", writer="ffmpeg", dpi=100)