1"""mpi4jax demo application -- Shallow water
2
3A non-linear shallow water solver, adapted from:
4
5https://github.com/dionhaefner/shallow-water
6
7Usage examples:
8
9 # runs demo on 4 processes
10 $ mpirun -n 4 python shallow_water.py
11
12 # saves output animation as shallow-water.mp4
13 $ mpirun -n 4 python shallow_water.py --save-animation
14
15 # runs demo as a benchmark (no output)
16 $ mpirun -n 4 python shallow_water.py --benchmark
17
18"""
19
20import os
21import sys
22import math
23import time
24import warnings
25from contextlib import ExitStack
26from collections import namedtuple
27from functools import partial
28
29import numpy as np
30from mpi4py import MPI
31
# tqdm is optional: it is only used to draw a progress bar
try:
    import tqdm

    HAS_TQDM = True
except ImportError:
    HAS_TQDM = False
    warnings.warn("Could not import tqdm, can't show progress bar")
39
# world communicator, plus this process' position in it
mpi_comm = MPI.COMM_WORLD
mpi_size = mpi_comm.Get_size()
mpi_rank = mpi_comm.Get_rank()

# on GPU, put each process on its own device
# (done here, ahead of the jax import that follows)
os.environ["CUDA_VISIBLE_DEVICES"] = f"{mpi_rank}"
46
47import jax # noqa: E402
48import jax.numpy as jnp # noqa: E402
49
50import mpi4jax # noqa: E402
51
52
53#
54# MPI setup
55#
56
# process counts for which the domain decomposition below is supported
supported_nproc = (1, 2, 4, 6, 8, 16)
if mpi_size not in supported_nproc:
    raise RuntimeError(
        f"Got invalid number of MPI processes: {mpi_size}. "
        f"Please choose one of these: {supported_nproc}."
    )

# at most two process rows in y; the remaining processes are laid out along x
nproc_y = 1 if mpi_size < 2 else 2
nproc_x = mpi_size // nproc_y

# (row, column) of this process in the (nproc_y, nproc_x) process grid
proc_idx = np.unravel_index(mpi_rank, (nproc_y, nproc_x))
68
69#
70# Grid setup
71#
72
# the domain carries 1 overlap cell on each side
nx_global = 360 + 2
ny_global = 180 + 2

# grid spacing in metres
dx = 5e3
dy = 5e3

# every process has to own the same number of interior cells
assert (nx_global - 2) % nproc_x == 0
assert (ny_global - 2) % nproc_y == 0

# local extent = share of the interior + 1 overlap cell on each side
nx_local = (nx_global - 2) // nproc_x + 2
ny_local = (ny_global - 2) // nproc_y + 2

# cell coordinates; index 0 and index -1 are the overlap cells
x_global = np.arange(-1, nx_global - 1) * dx
y_global = np.arange(-1, ny_global - 1) * dy
yy_global, xx_global = np.meshgrid(y_global, x_global, indexing="ij")

# extent of the domain between the first and last interior cell
length_x = x_global[-2] - x_global[1]
length_y = y_global[-2] - y_global[1]

# this extracts the processor-local domain (overlap included) from a
# global array; ordered (y, x) like the arrays themselves
local_slice = tuple(
    slice(start, start + extent)
    for start, extent in (
        ((ny_local - 2) * proc_idx[0], ny_local),
        ((nx_local - 2) * proc_idx[1], nx_local),
    )
)

y_local = y_global[local_slice[0]]
x_local = x_global[local_slice[1]]

yy_local = yy_global[local_slice]
xx_local = xx_global[local_slice]
108
109#
110# Model parameters
111#
112
# physical parameters
GRAVITY = 9.81
DEPTH = 100.0
CORIOLIS_F = 2e-4
CORIOLIS_BETA = 2e-11
# Coriolis parameter on the local grid; varies linearly with y
CORIOLIS_PARAM = CORIOLIS_F + CORIOLIS_BETA * yy_local
LATERAL_VISCOSITY = 1e-3 * CORIOLIS_F * dx**2

# other parameters
DAY_IN_SECONDS = 24 * 60 * 60
PERIODIC_BOUNDARY_X = True

# weights applied to the new / previous tendency when time stepping
ADAMS_BASHFORTH_A = 1.5 + 0.1
ADAMS_BASHFORTH_B = -(0.5 + 0.1)

# output parameters
PLOT_ETA_RANGE = 10
PLOT_EVERY = 100
MAX_QUIVERS = (25, 50)


# set time step based on CFL condition
dt = 0.125 * min(dx, dy) / np.sqrt(GRAVITY * DEPTH)
136
137
@jax.jit
def get_initial_conditions():
    """For the initial conditions, we use a horizontal jet in geostrophic balance.

    The fields are first built on the global grid, then cut down to the
    processor-local domain and passed through ``enforce_boundaries``.

    Returns:
        Tuple ``(h0_local, u0_local, v0_local)`` of processor-local arrays.
    """
    # global initial conditions: zonal jet centred at y = length_y / 2
    u0_global = 10 * jnp.exp(
        -((yy_global - 0.5 * length_y) ** 2) / (0.02 * length_x) ** 2
    )
    v0_global = jnp.zeros_like(u0_global)

    # approximate balance h_y = -(f/g)u
    coriolis_global = CORIOLIS_F + yy_global * CORIOLIS_BETA
    # u0_global is a JAX value here, so integrate with the JAX cumsum
    # instead of relying on NumPy forwarding the call to the array method
    h_geostrophy = jnp.cumsum(-dy * u0_global * coriolis_global / GRAVITY, axis=0)
    h0_global = (
        DEPTH
        + h_geostrophy
        # make sure h0 is centered around depth
        - h_geostrophy.mean()
        # small perturbation to break symmetry
        # (pure NumPy: only depends on the constant coordinate arrays)
        + 0.2
        * np.sin(xx_global / length_x * 10 * np.pi)
        * np.cos(yy_global / length_y * 8 * np.pi)
    )

    h0_local = h0_global[local_slice]
    u0_local = u0_global[local_slice]
    v0_local = v0_global[local_slice]

    h0_local = enforce_boundaries(h0_local, "h")
    u0_local = enforce_boundaries(u0_local, "u")
    v0_local = enforce_boundaries(v0_local, "v")

    return h0_local, u0_local, v0_local
170
171
@partial(jax.jit, static_argnums=(1,))
def enforce_boundaries(arr, grid):
    """Fill the overlap cells of ``arr`` and apply the wall conditions.

    Each process sends its outermost interior row / column to the
    neighbouring process and stores what it receives in its own overlap
    row / column. This is where mpi4jax comes in!

    Args:
        arr: processor-local array; axis 0 is y, axis 1 is x.
        grid: one of ``"h"``, ``"u"``, ``"v"`` (static); selects which wall
            condition is applied after the exchange.

    Returns:
        The array with updated overlap cells.
    """
    assert grid in ("h", "u", "v")

    # (send direction, receive direction), processed in this order:
    # sending starts west and goes clockwise, receiving starts east
    exchange_plan = (
        ("west", "east"),
        ("north", "south"),
        ("east", "west"),
        ("south", "north"),
    )

    # outermost interior row / column handed to the neighbour
    send_slices = {
        "south": (1, slice(None), Ellipsis),
        "west": (slice(None), 1, Ellipsis),
        "north": (-2, slice(None), Ellipsis),
        "east": (slice(None), -2, Ellipsis),
    }

    # overlap row / column filled with the neighbour's data
    recv_slices = {
        "south": (0, slice(None), Ellipsis),
        "west": (slice(None), 0, Ellipsis),
        "north": (-1, slice(None), Ellipsis),
        "east": (slice(None), -1, Ellipsis),
    }

    row, col = proc_idx[0], proc_idx[1]
    grid_shape = (nproc_y, nproc_x)

    # process-grid index of each neighbour; None where the domain ends
    neighbors = {"south": None, "west": None, "north": None, "east": None}
    if row > 0:
        neighbors["south"] = (row - 1, col)
    if col > 0:
        neighbors["west"] = (row, col - 1)
    if row < nproc_y - 1:
        neighbors["north"] = (row + 1, col)
    if col < nproc_x - 1:
        neighbors["east"] = (row, col + 1)

    if PERIODIC_BOUNDARY_X:
        # wrap around in x: first and last process column are neighbours
        if col == 0:
            neighbors["west"] = (row, nproc_x - 1)

        if col == nproc_x - 1:
            neighbors["east"] = (row, 0)

    def as_rank(index):
        # flat MPI rank for a process-grid index (None stays None)
        if index is None:
            return None
        return np.ravel_multi_index(index, grid_shape)

    for send_dir, recv_dir in exchange_plan:
        dest = as_rank(neighbors[send_dir])
        source = as_rank(neighbors[recv_dir])

        if dest is None and source is None:
            continue

        recv_idx = recv_slices[recv_dir]
        incoming = jnp.empty_like(arr[recv_idx])
        outgoing = arr[send_slices[send_dir]]

        if dest is None:
            # nobody to send to: receive only
            incoming = mpi4jax.recv(incoming, source=source, comm=mpi_comm)
            arr = arr.at[recv_idx].set(incoming)
        elif source is None:
            # nobody to receive from: send only
            mpi4jax.send(outgoing, dest=dest, comm=mpi_comm)
        else:
            incoming = mpi4jax.sendrecv(
                outgoing,
                incoming,
                source=source,
                dest=dest,
                comm=mpi_comm,
            )
            arr = arr.at[recv_idx].set(incoming)

    # zero the last interior column of u on the last process column
    # (only when the domain is not periodic in x)
    on_last_column = proc_idx[1] == nproc_x - 1
    if grid == "u" and on_last_column and not PERIODIC_BOUNDARY_X:
        arr = arr.at[:, -2].set(0.0)

    # zero the last interior row of v on the last process row
    on_last_row = proc_idx[0] == nproc_y - 1
    if grid == "v" and on_last_row:
        arr = arr.at[-2, :].set(0.0)

    return arr
265
266
# prognostic fields (h, u, v) followed by their latest tendencies (dh, du, dv)
ModelState = namedtuple("ModelState", ["h", "u", "v", "dh", "du", "dv"])
268
269
@partial(jax.jit, static_argnums=(1,))
def shallow_water_step(state, is_first_step):
    """Perform one step of the shallow-water model.

    Args:
        state: ``ModelState`` holding the current fields ``h, u, v`` and the
            tendencies ``dh, du, dv`` computed by the previous step.
        is_first_step: static flag. When true, the fields are advanced with
            the new tendencies only (no previous tendencies exist yet);
            otherwise new and previous tendencies are combined using the
            ``ADAMS_BASHFORTH_A`` / ``ADAMS_BASHFORTH_B`` weights.

    Returns:
        Modified model state: the updated fields together with the
        tendencies computed in this step.
    """
    h, u, v, dh, du, dv = state

    # copy of h whose outer ring is filled from the nearest interior cell
    # before the overlap exchange
    hc = jnp.pad(h[1:-1, 1:-1], 1, "edge")
    hc = enforce_boundaries(hc, "h")

    fe = jnp.empty_like(u)
    fn = jnp.empty_like(u)

    # volume fluxes in x (fe) and y (fn): h averaged onto the velocity points
    fe = fe.at[1:-1, 1:-1].set(0.5 * (hc[1:-1, 1:-1] + hc[1:-1, 2:]) * u[1:-1, 1:-1])
    fn = fn.at[1:-1, 1:-1].set(0.5 * (hc[1:-1, 1:-1] + hc[2:, 1:-1]) * v[1:-1, 1:-1])
    fe = enforce_boundaries(fe, "u")
    fn = enforce_boundaries(fn, "v")

    # continuity equation: height tendency is minus the flux divergence
    dh_new = dh.at[1:-1, 1:-1].set(
        -(fe[1:-1, 1:-1] - fe[1:-1, :-2]) / dx - (fn[1:-1, 1:-1] - fn[:-2, 1:-1]) / dy
    )

    # nonlinear momentum equation
    q = jnp.empty_like(u)
    ke = jnp.empty_like(u)

    # planetary and relative vorticity
    q = q.at[1:-1, 1:-1].set(
        CORIOLIS_PARAM[1:-1, 1:-1]
        + ((v[1:-1, 2:] - v[1:-1, 1:-1]) / dx - (u[2:, 1:-1] - u[1:-1, 1:-1]) / dy)
    )
    # potential vorticity
    q = q.at[1:-1, 1:-1].mul(
        1.0 / (0.25 * (hc[1:-1, 1:-1] + hc[1:-1, 2:] + hc[2:, 1:-1] + hc[2:, 2:]))
    )
    q = enforce_boundaries(q, "h")

    # pressure gradient plus vorticity-flux term
    du_new = du.at[1:-1, 1:-1].set(
        -GRAVITY * (h[1:-1, 2:] - h[1:-1, 1:-1]) / dx
        + 0.5
        * (
            q[1:-1, 1:-1] * 0.5 * (fn[1:-1, 1:-1] + fn[1:-1, 2:])
            + q[:-2, 1:-1] * 0.5 * (fn[:-2, 1:-1] + fn[:-2, 2:])
        )
    )
    dv_new = dv.at[1:-1, 1:-1].set(
        -GRAVITY * (h[2:, 1:-1] - h[1:-1, 1:-1]) / dy
        - 0.5
        * (
            q[1:-1, 1:-1] * 0.5 * (fe[1:-1, 1:-1] + fe[2:, 1:-1])
            + q[1:-1, :-2] * 0.5 * (fe[1:-1, :-2] + fe[2:, :-2])
        )
    )

    # kinetic energy, averaged from the velocity points
    ke = ke.at[1:-1, 1:-1].set(
        0.5
        * (
            0.5 * (u[1:-1, 1:-1] ** 2 + u[1:-1, :-2] ** 2)
            + 0.5 * (v[1:-1, 1:-1] ** 2 + v[:-2, 1:-1] ** 2)
        )
    )
    ke = enforce_boundaries(ke, "h")

    # subtract the kinetic energy gradient from the momentum tendencies
    du_new = du_new.at[1:-1, 1:-1].add(-(ke[1:-1, 2:] - ke[1:-1, 1:-1]) / dx)
    dv_new = dv_new.at[1:-1, 1:-1].add(-(ke[2:, 1:-1] - ke[1:-1, 1:-1]) / dy)

    if is_first_step:
        # no previous tendencies available: single forward step
        u = u.at[1:-1, 1:-1].add(dt * du_new[1:-1, 1:-1])
        v = v.at[1:-1, 1:-1].add(dt * dv_new[1:-1, 1:-1])
        h = h.at[1:-1, 1:-1].add(dt * dh_new[1:-1, 1:-1])
    else:
        # weighted combination of the new and the previous tendencies
        u = u.at[1:-1, 1:-1].add(
            dt
            * (
                ADAMS_BASHFORTH_A * du_new[1:-1, 1:-1]
                + ADAMS_BASHFORTH_B * du[1:-1, 1:-1]
            )
        )
        v = v.at[1:-1, 1:-1].add(
            dt
            * (
                ADAMS_BASHFORTH_A * dv_new[1:-1, 1:-1]
                + ADAMS_BASHFORTH_B * dv[1:-1, 1:-1]
            )
        )
        h = h.at[1:-1, 1:-1].add(
            dt
            * (
                ADAMS_BASHFORTH_A * dh_new[1:-1, 1:-1]
                + ADAMS_BASHFORTH_B * dh[1:-1, 1:-1]
            )
        )

    h = enforce_boundaries(h, "h")
    u = enforce_boundaries(u, "u")
    v = enforce_boundaries(v, "v")

    if LATERAL_VISCOSITY > 0:
        # lateral friction on u: divergence of the viscous fluxes of u
        fe = fe.at[1:-1, 1:-1].set(
            LATERAL_VISCOSITY * (u[1:-1, 2:] - u[1:-1, 1:-1]) / dx
        )
        fn = fn.at[1:-1, 1:-1].set(
            LATERAL_VISCOSITY * (u[2:, 1:-1] - u[1:-1, 1:-1]) / dy
        )
        fe = enforce_boundaries(fe, "u")
        fn = enforce_boundaries(fn, "v")

        u = u.at[1:-1, 1:-1].add(
            dt
            * (
                (fe[1:-1, 1:-1] - fe[1:-1, :-2]) / dx
                + (fn[1:-1, 1:-1] - fn[:-2, 1:-1]) / dy
            )
        )

        # lateral friction on v: the fluxes must be built from gradients of
        # v itself (differencing v against u mixes the two components)
        fe = fe.at[1:-1, 1:-1].set(
            LATERAL_VISCOSITY * (v[1:-1, 2:] - v[1:-1, 1:-1]) / dx
        )
        fn = fn.at[1:-1, 1:-1].set(
            LATERAL_VISCOSITY * (v[2:, 1:-1] - v[1:-1, 1:-1]) / dy
        )
        fe = enforce_boundaries(fe, "u")
        fn = enforce_boundaries(fn, "v")

        v = v.at[1:-1, 1:-1].add(
            dt
            * (
                (fe[1:-1, 1:-1] - fe[1:-1, :-2]) / dx
                + (fn[1:-1, 1:-1] - fn[:-2, 1:-1]) / dy
            )
        )

    return ModelState(h, u, v, dh_new, du_new, dv_new)
404
405
@partial(jax.jit, static_argnums=(1,))
def do_multistep(state, num_steps):
    """Advance ``state`` by ``num_steps`` consecutive (non-initial) model steps."""

    def advance_once(_, current):
        # the loop index is unused; every iteration is a regular step
        return shallow_water_step(current, False)

    return jax.lax.fori_loop(0, num_steps, advance_once, state)
412
413
def solve_shallow_water(t1, num_multisteps=10):
    """Iterate the model forward in time.

    Args:
        t1: model time up to which to integrate (same unit as ``dt``).
        num_multisteps: number of model steps between two stored states.

    Returns:
        List of ``ModelState``: the initial state followed by one state per
        block of ``num_multisteps`` steps, so consecutive entries are evenly
        spaced in time.
    """
    # initial conditions
    h, u, v = get_initial_conditions()
    du, dv, dh = (jnp.zeros((ny_local, nx_local)) for _ in range(3))

    state = ModelState(h, u, v, dh, du, dv)
    sol = [state]

    # Single bootstrap step that provides tendencies for the multi-step
    # scheme. It is deliberately not stored: it lies only one dt after the
    # initial state and would shift all later entries by one output slot.
    state = shallow_water_step(state, True)
    t = dt

    if HAS_TQDM:
        pbar = tqdm.tqdm(
            total=math.ceil(t1 / dt),
            disable=mpi_rank != 0,
            unit="model day",
            initial=t / DAY_IN_SECONDS,
            unit_scale=dt / DAY_IN_SECONDS,
            bar_format=(
                "{l_bar}{bar}| {n:.2f}/{total:.2f} [{elapsed}<{remaining}, "
                "{rate_fmt}{postfix}]"
            ),
        )

    # pre-compile JAX kernel; wait for the warm-up run to finish so that its
    # (asynchronously dispatched) execution is not counted in the timing
    do_multistep(state, num_multisteps)[0].block_until_ready()

    start = time.perf_counter()
    with ExitStack() as es:
        if HAS_TQDM:
            es.enter_context(pbar)

        while t < t1:
            state = do_multistep(state, num_multisteps)
            # block so the loop measures actual compute time
            state[0].block_until_ready()
            sol.append(state)

            t += dt * num_multisteps

            if t < t1 and HAS_TQDM:
                pbar.update(num_multisteps)

    end = time.perf_counter()

    if mpi_rank == 0:
        print(f"\nSolution took {end - start:.2f}s")

    return sol
464
465
@jax.vmap
@jax.jit
def reassemble_array(arr):
    """Stitch the per-process pieces of a field back into the global field.

    ``arr`` has the process index as first axis; each piece is written to
    the position its process owns in the global domain.

    Shape (mpi_size, ny_local, nx_local) -> (ny_global, nx_global)
    """
    # number of cells by which neighbouring pieces are offset
    ny_offset = ny_local - 2
    nx_offset = nx_local - 2

    out = jnp.empty((ny_global, nx_global), dtype=arr.dtype)
    for rank in range(mpi_size):
        row, col = np.unravel_index(rank, (nproc_y, nproc_x))
        y0 = ny_offset * row
        x0 = nx_offset * col
        # pieces overlap by their edge cells; later ranks overwrite them
        out = out.at[y0 : y0 + ny_local, x0 : x0 + nx_local].set(arr[rank])

    return out
490
491
def animate_shallow_water(sol):
    """Build a matplotlib animation from a list of global model states.

    Args:
        sol: sequence of ``ModelState`` objects holding global fields.

    Returns:
        The ``matplotlib.animation.FuncAnimation`` object.
    """
    import matplotlib.pyplot as plt
    from matplotlib import animation

    # subsample the interior so that not every cell gets an arrow
    arrow_sel = (
        slice(1, -1, ny_global // MAX_QUIVERS[0]),
        slice(1, -1, nx_global // MAX_QUIVERS[1]),
    )

    # set up figure
    fig = plt.figure(figsize=(6, 4))
    ax = plt.gca()

    first = sol[0]

    # surface plot of height anomaly (mesh coordinates in km)
    x_mid_km = 0.5 * (x_global[:-1] + x_global[1:]) / 1e3
    y_mid_km = 0.5 * (y_global[:-1] + y_global[1:]) / 1e3
    mesh = ax.pcolormesh(
        x_mid_km,
        y_mid_km,
        first.h[1:-1, 1:-1] - DEPTH,
        vmin=-PLOT_ETA_RANGE,
        vmax=PLOT_ETA_RANGE,
        cmap="RdBu_r",
    )

    # quiver plot of velocity
    arrows = ax.quiver(
        xx_global[arrow_sel] / 1e3,
        yy_global[arrow_sel] / 1e3,
        first.u[arrow_sel],
        first.v[arrow_sel],
        clip_on=True,
    )

    # time indicator
    time_label = ax.text(
        s="",
        x=0.05,
        y=0.95,
        ha="left",
        va="top",
        backgroundcolor=(1, 1, 1, 0.8),
        transform=ax.transAxes,
    )

    # show the interior of the domain only
    ax.set(
        aspect="equal",
        xlim=(x_global[1] / 1e3, x_global[-2] / 1e3),
        ylim=(y_global[1] / 1e3, y_global[-2] / 1e3),
        xlabel="$x$ (km)",
        ylabel="$y$ (km)",
    )

    plt.colorbar(
        mesh,
        orientation="horizontal",
        label="Surface height anomaly (m)",
        pad=0.2,
        shrink=0.8,
    )
    fig.tight_layout()

    def animate(i):
        # update all artists in place for frame i
        frame = sol[i]

        anomaly = frame.h - DEPTH
        mesh.set_array(anomaly[1:-1, 1:-1].flatten())
        arrows.set_UVC(frame.u[arrow_sel], frame.v[arrow_sel])

        elapsed = PLOT_EVERY * dt * i
        time_label.set_text(f"t = {elapsed / DAY_IN_SECONDS:.2f} days")
        return (mesh, arrows, time_label)

    return animation.FuncAnimation(
        fig, animate, frames=len(sol), interval=50, blit=True, repeat_delay=3_000
    )
568
569
if __name__ == "__main__":
    benchmark_mode = "--benchmark" in sys.argv

    sol = solve_shallow_water(t1=10 * DAY_IN_SECONDS, num_multisteps=PLOT_EVERY)

    if benchmark_mode:
        # benchmark runs only report the timing printed by the solver
        sys.exit(0)

    # copy solution to mpi_rank 0
    # (this file uses the token-free mpi4jax calls, which return the array
    # itself, so there is no second value to unpack)
    full_sol_arr = mpi4jax.gather(jnp.asarray(sol), root=0, comm=mpi_comm)

    if mpi_rank == 0:
        # full_sol_arr has shape (nproc, time, nvars, ny, nx)
        # -> (time, nvars, nproc, ny, nx), so each variable of each time
        #    slice can be reassembled from its per-process pieces
        full_sol_arr = jnp.moveaxis(full_sol_arr, 0, 2)
        full_sol = [ModelState(*reassemble_array(x)) for x in full_sol_arr]

        anim = animate_shallow_water(full_sol)

        if "--save-animation" in sys.argv:
            # save animation as MP4 video (requires ffmpeg)
            anim.save("shallow-water.mp4", writer="ffmpeg", dpi=100)
        else:
            import matplotlib.pyplot as plt

            plt.show()