examples/shallow_water.py

  1"""mpi4jax demo application -- Shallow water
  2
  3A non-linear shallow water solver, adapted from:
  4
  5https://github.com/dionhaefner/shallow-water
  6
  7Usage examples:
  8
  9    # runs demo on 4 processes
 10    $ mpirun -n 4 python shallow_water.py
 11
 12    # saves output animation as shallow-water.mp4
 13    $ mpirun -n 4 python shallow_water.py --save-animation
 14
 15    # runs demo as a benchmark (no output)
 16    $ mpirun -n 4 python shallow_water.py --benchmark
 17
 18"""
 19
 20import os
 21import sys
 22import math
 23import time
 24import warnings
 25from contextlib import ExitStack
 26from collections import namedtuple
 27from functools import partial
 28
 29import numpy as np
 30from mpi4py import MPI
 31
 32try:
 33    import tqdm
 34except ImportError:
 35    warnings.warn("Could not import tqdm, can't show progress bar")
 36    HAS_TQDM = False
 37else:
 38    HAS_TQDM = True
 39
# communicator used for all communication in this file, together with this
# process's rank in it and the total number of processes
mpi_comm = MPI.COMM_WORLD
mpi_rank = mpi_comm.Get_rank()
mpi_size = mpi_comm.Get_size()

# on GPU, put each process on its own device
# (set before jax is imported further down; presumably the import is deferred
# for exactly this reason)
# NOTE(review): the global rank is used as the device index, which looks like
# it assumes all processes share one node -- confirm for multi-node runs
os.environ["CUDA_VISIBLE_DEVICES"] = str(mpi_rank)
 46
 47import jax  # noqa: E402
 48import jax.numpy as jnp  # noqa: E402
 49
 50import mpi4jax  # noqa: E402
 51
 52
 53#
 54# MPI setup
 55#
 56
# only these process counts are accepted; any other count aborts at startup
supported_nproc = (1, 2, 4, 6, 8, 16)
if mpi_size not in supported_nproc:
    raise RuntimeError(
        f"Got invalid number of MPI processes: {mpi_size}. "
        f"Please choose one of these: {supported_nproc}."
    )

# lay the processes out on a 2D grid: at most 2 rows along y, the rest along x
nproc_y = min(mpi_size, 2)
nproc_x = mpi_size // nproc_y

# (row, column) position of this process on the (nproc_y, nproc_x) process grid
proc_idx = np.unravel_index(mpi_rank, (nproc_y, nproc_x))
 68
 69#
 70# Grid setup
 71#
 72
 73# we use 1 cell overlap on each side of the domain
 74nx_global = 360 + 2
 75ny_global = 180 + 2
 76
 77# grid spacing in metres
 78dx = 5e3
 79dy = 5e3
 80
 81# make sure processes divide the domain evenly
 82assert (nx_global - 2) % nproc_x == 0
 83assert (ny_global - 2) % nproc_y == 0
 84
 85nx_local = (nx_global - 2) // nproc_x + 2
 86ny_local = (ny_global - 2) // nproc_y + 2
 87
 88x_global, y_global = (
 89    np.arange(-1, nx_global - 1) * dx,
 90    np.arange(-1, ny_global - 1) * dy,
 91)
 92yy_global, xx_global = np.meshgrid(y_global, x_global, indexing="ij")
 93
 94length_x = x_global[-2] - x_global[1]
 95length_y = y_global[-2] - y_global[1]
 96
 97# this extracts the processor-local domain from a global array
 98local_slice = (
 99    slice((ny_local - 2) * proc_idx[0], (ny_local - 2) * proc_idx[0] + ny_local),
100    slice((nx_local - 2) * proc_idx[1], (nx_local - 2) * proc_idx[1] + nx_local),
101)
102
103x_local = x_global[local_slice[1]]
104y_local = y_global[local_slice[0]]
105
106xx_local = xx_global[local_slice]
107yy_local = yy_global[local_slice]
108
109#
110# Model parameters
111#
112
# physical parameters
# NOTE(review): units are not stated here; the values look like SI units
# (metres, seconds), consistent with the grid spacing -- confirm
GRAVITY = 9.81
DEPTH = 100.0
CORIOLIS_F = 2e-4
CORIOLIS_BETA = 2e-11
# Coriolis parameter on the local grid; varies linearly with the y coordinate
CORIOLIS_PARAM = CORIOLIS_F + yy_local * CORIOLIS_BETA
# a value of 0 switches the lateral friction in shallow_water_step off
LATERAL_VISCOSITY = 1e-3 * CORIOLIS_F * dx**2

# other parameters
DAY_IN_SECONDS = 86_400
# if True, the first and last process column exchange boundaries in x
PERIODIC_BOUNDARY_X = True

# weights of the new and the previous tendency in every time step after the
# first one (see shallow_water_step)
ADAMS_BASHFORTH_A = 1.5 + 0.1
ADAMS_BASHFORTH_B = -(0.5 + 0.1)

# output parameters
# symmetric color range of the height anomaly plot
PLOT_ETA_RANGE = 10
# number of model steps between two stored / plotted snapshots
PLOT_EVERY = 100
# used to derive the (y, x) subsampling stride of the velocity arrows
MAX_QUIVERS = (25, 50)


# set time step based on CFL condition
dt = 0.125 * min(dx, dy) / np.sqrt(GRAVITY * DEPTH)
136
137
@jax.jit
def get_initial_conditions():
    """Build the initial fields for this process's part of the domain.

    For the initial conditions, we use a horizontal jet in geostrophic balance.
    The fields are first computed on the global grid, then cut down to the
    local domain with ``local_slice`` and passed through
    ``enforce_boundaries``.

    Returns:
        Tuple ``(h0_local, u0_local, v0_local)`` of local arrays.
    """
    # global initial conditions
    u0_global = 10 * jnp.exp(
        -((yy_global - 0.5 * length_y) ** 2) / (0.02 * length_x) ** 2
    )
    v0_global = jnp.zeros_like(u0_global)

    # approximate balance h_y = -(f/g)u
    coriolis_global = CORIOLIS_F + yy_global * CORIOLIS_BETA
    # u0_global is a JAX array, so integrate with jnp directly instead of
    # relying on NumPy forwarding the call to the array's cumsum method
    h_geostrophy = jnp.cumsum(-dy * u0_global * coriolis_global / GRAVITY, axis=0)
    h0_global = (
        DEPTH
        + h_geostrophy
        # make sure h0 is centered around depth
        - h_geostrophy.mean()
        # small perturbation to break symmetry
        # (built from plain NumPy coordinate arrays only)
        + 0.2
        * np.sin(xx_global / length_x * 10 * np.pi)
        * np.cos(yy_global / length_y * 8 * np.pi)
    )

    h0_local = h0_global[local_slice]
    u0_local = u0_global[local_slice]
    v0_local = v0_global[local_slice]

    h0_local = enforce_boundaries(h0_local, "h")
    u0_local = enforce_boundaries(u0_local, "u")
    v0_local = enforce_boundaries(v0_local, "v")

    return h0_local, u0_local, v0_local
170
171
@partial(jax.jit, static_argnums=(1,))
def enforce_boundaries(arr, grid):
    """Fill the overlap cells of ``arr`` and apply the closed-boundary values.

    The overlap cells are filled with data from the neighboring processes in
    four exchange rounds; this is the part of the model that communicates
    through mpi4jax.

    Args:
        arr: Local array; its first axis is y and its second axis is x.
        grid: One of ``"h"``, ``"u"``, ``"v"``. Static under jit. Only ``"u"``
            and ``"v"`` get additional zeroed rows / columns, see the end of
            the function.

    Returns:
        The array with updated boundary cells.
    """
    assert grid in ("h", "u", "v")

    row, col = proc_idx
    proc_grid = (nproc_y, nproc_x)

    # (send direction, receive direction) of every exchange round: sending
    # starts west and receiving starts east, both going clockwise
    exchange_rounds = (
        ("west", "east"),
        ("north", "south"),
        ("east", "west"),
        ("south", "north"),
    )

    # rows / columns next to the overlap cells, handed to the neighbor
    send_slices = {
        "south": (1, slice(None), Ellipsis),
        "west": (slice(None), 1, Ellipsis),
        "north": (-2, slice(None), Ellipsis),
        "east": (slice(None), -2, Ellipsis),
    }

    # overlap rows / columns, overwritten with the neighbor's data
    recv_slices = {
        "south": (0, slice(None), Ellipsis),
        "west": (slice(None), 0, Ellipsis),
        "north": (-1, slice(None), Ellipsis),
        "east": (slice(None), -1, Ellipsis),
    }

    # position of each neighbor on the process grid, None at the domain edge
    neighbors = {
        "south": (row - 1, col) if row > 0 else None,
        "west": (row, col - 1) if col > 0 else None,
        "north": (row + 1, col) if row < nproc_y - 1 else None,
        "east": (row, col + 1) if col < nproc_x - 1 else None,
    }

    if PERIODIC_BOUNDARY_X:
        # wrap around in x: the first and last process column are neighbors
        if col == 0:
            neighbors["west"] = (row, nproc_x - 1)

        if col == nproc_x - 1:
            neighbors["east"] = (row, 0)

    def as_rank(grid_pos):
        """Convert a (row, column) position to an MPI rank; None stays None."""
        if grid_pos is None:
            return None
        return np.ravel_multi_index(grid_pos, proc_grid)

    for send_dir, recv_dir in exchange_rounds:
        dest = as_rank(neighbors[send_dir])
        source = as_rank(neighbors[recv_dir])

        if dest is None and source is None:
            continue

        if source is None:
            # nobody to receive from in this round: send only
            mpi4jax.send(arr[send_slices[send_dir]], dest=dest, comm=mpi_comm)
            continue

        recv_idx = recv_slices[recv_dir]
        recv_buffer = jnp.empty_like(arr[recv_idx])

        if dest is None:
            # nobody to send to in this round: receive only
            halo = mpi4jax.recv(recv_buffer, source=source, comm=mpi_comm)
        else:
            halo = mpi4jax.sendrecv(
                arr[send_slices[send_dir]],
                recv_buffer,
                source=source,
                dest=dest,
                comm=mpi_comm,
            )

        arr = arr.at[recv_idx].set(halo)

    # closed boundaries: zero the last interior column of u at the eastern
    # edge (non-periodic case only) and the last interior row of v at the
    # northern edge
    at_east_edge = col == nproc_x - 1
    at_north_edge = row == nproc_y - 1

    if not PERIODIC_BOUNDARY_X and grid == "u" and at_east_edge:
        arr = arr.at[:, -2].set(0.0)

    if grid == "v" and at_north_edge:
        arr = arr.at[-2, :].set(0.0)

    return arr
265
266
# full model state: the fields h, u, v followed by their most recent
# tendencies dh, du, dv (in this order)
ModelState = namedtuple("ModelState", ["h", "u", "v", "dh", "du", "dv"])
268
269
@partial(jax.jit, static_argnums=(1,))
def shallow_water_step(state, is_first_step):
    """Perform one step of the shallow-water model.

    Args:
        state: ModelState holding the fields ``h``, ``u``, ``v`` and the
            tendencies ``dh``, ``du``, ``dv`` of the previous step.
        is_first_step: Static flag. If True, the fields are advanced with the
            new tendencies only; otherwise the new and the previous
            tendencies are combined with the ADAMS_BASHFORTH_A / _B weights.

    Returns:
        ModelState with the updated fields and the newly computed tendencies.
    """
    h, u, v, dh, du, dv = state

    # copy of h whose outer cells repeat the nearest interior value before the
    # boundaries are exchanged
    hc = jnp.pad(h[1:-1, 1:-1], 1, "edge")
    hc = enforce_boundaries(hc, "h")

    fe = jnp.empty_like(u)
    fn = jnp.empty_like(u)

    # fluxes in x (fe) and y (fn): h averaged between neighboring cells, times
    # the velocity component
    fe = fe.at[1:-1, 1:-1].set(0.5 * (hc[1:-1, 1:-1] + hc[1:-1, 2:]) * u[1:-1, 1:-1])
    fn = fn.at[1:-1, 1:-1].set(0.5 * (hc[1:-1, 1:-1] + hc[2:, 1:-1]) * v[1:-1, 1:-1])
    fe = enforce_boundaries(fe, "u")
    fn = enforce_boundaries(fn, "v")

    # tendency of h: negative divergence of the fluxes
    dh_new = dh.at[1:-1, 1:-1].set(
        -(fe[1:-1, 1:-1] - fe[1:-1, :-2]) / dx - (fn[1:-1, 1:-1] - fn[:-2, 1:-1]) / dy
    )

    # nonlinear momentum equation
    q = jnp.empty_like(u)
    ke = jnp.empty_like(u)

    # planetary and relative vorticity
    q = q.at[1:-1, 1:-1].set(
        CORIOLIS_PARAM[1:-1, 1:-1]
        + ((v[1:-1, 2:] - v[1:-1, 1:-1]) / dx - (u[2:, 1:-1] - u[1:-1, 1:-1]) / dy)
    )
    # potential vorticity
    q = q.at[1:-1, 1:-1].mul(
        1.0 / (0.25 * (hc[1:-1, 1:-1] + hc[1:-1, 2:] + hc[2:, 1:-1] + hc[2:, 2:]))
    )
    q = enforce_boundaries(q, "h")

    # tendencies of u and v: gradient of h plus the vorticity term
    du_new = du.at[1:-1, 1:-1].set(
        -GRAVITY * (h[1:-1, 2:] - h[1:-1, 1:-1]) / dx
        + 0.5
        * (
            q[1:-1, 1:-1] * 0.5 * (fn[1:-1, 1:-1] + fn[1:-1, 2:])
            + q[:-2, 1:-1] * 0.5 * (fn[:-2, 1:-1] + fn[:-2, 2:])
        )
    )
    dv_new = dv.at[1:-1, 1:-1].set(
        -GRAVITY * (h[2:, 1:-1] - h[1:-1, 1:-1]) / dy
        - 0.5
        * (
            q[1:-1, 1:-1] * 0.5 * (fe[1:-1, 1:-1] + fe[2:, 1:-1])
            + q[1:-1, :-2] * 0.5 * (fe[1:-1, :-2] + fe[2:, :-2])
        )
    )
    # kinetic energy from the squared velocities of neighboring cells
    ke = ke.at[1:-1, 1:-1].set(
        0.5
        * (
            0.5 * (u[1:-1, 1:-1] ** 2 + u[1:-1, :-2] ** 2)
            + 0.5 * (v[1:-1, 1:-1] ** 2 + v[:-2, 1:-1] ** 2)
        )
    )
    ke = enforce_boundaries(ke, "h")

    du_new = du_new.at[1:-1, 1:-1].add(-(ke[1:-1, 2:] - ke[1:-1, 1:-1]) / dx)
    dv_new = dv_new.at[1:-1, 1:-1].add(-(ke[2:, 1:-1] - ke[1:-1, 1:-1]) / dy)

    if is_first_step:
        # no previous tendencies yet: advance with the new ones only
        u = u.at[1:-1, 1:-1].add(dt * du_new[1:-1, 1:-1])
        v = v.at[1:-1, 1:-1].add(dt * dv_new[1:-1, 1:-1])
        h = h.at[1:-1, 1:-1].add(dt * dh_new[1:-1, 1:-1])
    else:
        # weighted combination of the new and the previous tendencies
        u = u.at[1:-1, 1:-1].add(
            dt
            * (
                ADAMS_BASHFORTH_A * du_new[1:-1, 1:-1]
                + ADAMS_BASHFORTH_B * du[1:-1, 1:-1]
            )
        )
        v = v.at[1:-1, 1:-1].add(
            dt
            * (
                ADAMS_BASHFORTH_A * dv_new[1:-1, 1:-1]
                + ADAMS_BASHFORTH_B * dv[1:-1, 1:-1]
            )
        )
        h = h.at[1:-1, 1:-1].add(
            dt
            * (
                ADAMS_BASHFORTH_A * dh_new[1:-1, 1:-1]
                + ADAMS_BASHFORTH_B * dh[1:-1, 1:-1]
            )
        )

    h = enforce_boundaries(h, "h")
    u = enforce_boundaries(u, "u")
    v = enforce_boundaries(v, "v")

    if LATERAL_VISCOSITY > 0:
        # lateral friction: viscosity times the gradient of each velocity
        # component, applied as the divergence of these fluxes
        fe = fe.at[1:-1, 1:-1].set(
            LATERAL_VISCOSITY * (u[1:-1, 2:] - u[1:-1, 1:-1]) / dx
        )
        fn = fn.at[1:-1, 1:-1].set(
            LATERAL_VISCOSITY * (u[2:, 1:-1] - u[1:-1, 1:-1]) / dy
        )
        fe = enforce_boundaries(fe, "u")
        fn = enforce_boundaries(fn, "v")

        u = u.at[1:-1, 1:-1].add(
            dt
            * (
                (fe[1:-1, 1:-1] - fe[1:-1, :-2]) / dx
                + (fn[1:-1, 1:-1] - fn[:-2, 1:-1]) / dy
            )
        )

        # the friction on v must be built from differences of v alone;
        # subtracting u here would mix the two velocity components
        fe = fe.at[1:-1, 1:-1].set(
            LATERAL_VISCOSITY * (v[1:-1, 2:] - v[1:-1, 1:-1]) / dx
        )
        fn = fn.at[1:-1, 1:-1].set(
            LATERAL_VISCOSITY * (v[2:, 1:-1] - v[1:-1, 1:-1]) / dy
        )
        fe = enforce_boundaries(fe, "u")
        fn = enforce_boundaries(fn, "v")

        v = v.at[1:-1, 1:-1].add(
            dt
            * (
                (fe[1:-1, 1:-1] - fe[1:-1, :-2]) / dx
                + (fn[1:-1, 1:-1] - fn[:-2, 1:-1]) / dy
            )
        )

    return ModelState(h, u, v, dh_new, du_new, dv_new)
404
405
@partial(jax.jit, static_argnums=(1,))
def do_multistep(state, num_steps):
    """Advance ``state`` by ``num_steps`` regular (non-first) model steps.

    ``num_steps`` is static under jit. Returns the state after the last step.
    """

    def advance(_, current_state):
        # the loop index is not needed; every iteration is a regular step
        return shallow_water_step(current_state, False)

    return jax.lax.fori_loop(0, num_steps, advance, state)
412
413
def solve_shallow_water(t1, num_multisteps=10):
    """Iterate the model forward in time.

    Args:
        t1: Model time at which to stop, in the same unit as ``dt``. The loop
            runs whole batches of steps until the model time reaches ``t1``.
        num_multisteps: Number of model steps between two stored snapshots.

    Returns:
        List of ModelState snapshots: the initial state, followed by one
        snapshot after every batch of ``num_multisteps`` steps.
    """
    # initial conditions
    h, u, v = get_initial_conditions()
    du, dv, dh = (jnp.zeros((ny_local, nx_local)) for _ in range(3))

    state = ModelState(h, u, v, dh, du, dv)
    sol = [state]

    # The first step has no previous tendencies and is taken on its own. Its
    # result is deliberately not stored: snapshots have to stay num_multisteps
    # steps apart, because animate_shallow_water derives the time of frame i
    # from the frame index alone.
    state = shallow_water_step(state, True)
    t = dt

    if HAS_TQDM:
        pbar = tqdm.tqdm(
            total=math.ceil(t1 / dt),
            disable=mpi_rank != 0,
            unit="model day",
            initial=t / DAY_IN_SECONDS,
            unit_scale=dt / DAY_IN_SECONDS,
            bar_format=(
                "{l_bar}{bar}| {n:.2f}/{total:.2f} [{elapsed}<{remaining}, "
                "{rate_fmt}{postfix}]"
            ),
        )

    # pre-compile JAX kernel
    # (the result is discarded, so compilation stays out of the timing below)
    do_multistep(state, num_multisteps)

    start = time.perf_counter()
    with ExitStack() as es:
        # the progress bar only exists if tqdm is available
        if HAS_TQDM:
            es.enter_context(pbar)

        while t < t1:
            state = do_multistep(state, num_multisteps)
            # wait for the result so that the measured time is meaningful
            state[0].block_until_ready()
            sol.append(state)

            t += dt * num_multisteps

            if t < t1 and HAS_TQDM:
                pbar.update(num_multisteps)

    end = time.perf_counter()

    if mpi_rank == 0:
        print(f"\nSolution took {end - start:.2f}s")

    return sol
464
465
@jax.vmap
@jax.jit
def reassemble_array(arr):
    """Stitch the per-process local arrays together into one global array.

    ``arr`` holds the local solution of every process along its first axis.

    Shape (mpi_size, ny_local, nx_local) -> (ny_global, nx_global)

    Because of the vmap decorator, the function as called from outside maps
    this over an additional leading axis.
    """
    # number of rows / columns by which neighboring local domains are offset
    offset_y = ny_local - 2
    offset_x = nx_local - 2

    global_arr = jnp.empty((ny_global, nx_global), dtype=arr.dtype)
    for rank in range(mpi_size):
        row, col = np.unravel_index(rank, (nproc_y, nproc_x))
        y_start = offset_y * row
        x_start = offset_x * col
        # same window that local_slice selects on the process with this rank;
        # shared overlap cells are overwritten by the later rank
        window = (
            slice(y_start, y_start + ny_local),
            slice(x_start, x_start + nx_local),
        )
        global_arr = global_arr.at[window].set(arr[rank])

    return global_arr
490
491
def animate_shallow_water(sol):
    """Create a matplotlib animation of the result.

    Args:
        sol: Sequence of states on the global grid; the attributes ``h``,
            ``u`` and ``v`` of each entry are plotted.

    Returns:
        The ``FuncAnimation`` object, with one frame per entry of ``sol``.
    """
    import matplotlib.pyplot as plt
    from matplotlib import animation

    # subsample the interior of the velocity field for the arrow plot
    stride_y = ny_global // MAX_QUIVERS[0]
    stride_x = nx_global // MAX_QUIVERS[1]
    quiver_stride = (slice(1, -1, stride_y), slice(1, -1, stride_x))

    initial_state = sol[0]

    # set up figure
    fig = plt.figure(figsize=(6, 4))
    ax = plt.gca()

    # surface plot of height anomaly
    # (mesh corners are the midpoints between neighboring grid coordinates)
    mesh_x = 0.5 * (x_global[:-1] + x_global[1:]) / 1e3
    mesh_y = 0.5 * (y_global[:-1] + y_global[1:]) / 1e3
    height_mesh = ax.pcolormesh(
        mesh_x,
        mesh_y,
        initial_state.h[1:-1, 1:-1] - DEPTH,
        vmin=-PLOT_ETA_RANGE,
        vmax=PLOT_ETA_RANGE,
        cmap="RdBu_r",
    )

    # quiver plot of velocity
    velocity_arrows = ax.quiver(
        xx_global[quiver_stride] / 1e3,
        yy_global[quiver_stride] / 1e3,
        initial_state.u[quiver_stride],
        initial_state.v[quiver_stride],
        clip_on=True,
    )

    # time indicator
    time_label = ax.text(
        s="",
        x=0.05,
        y=0.95,
        ha="left",
        va="top",
        backgroundcolor=(1, 1, 1, 0.8),
        transform=ax.transAxes,
    )

    # show the interior of the domain only, in kilometres
    ax.set(
        aspect="equal",
        xlim=(x_global[1] / 1e3, x_global[-2] / 1e3),
        ylim=(y_global[1] / 1e3, y_global[-2] / 1e3),
        xlabel="$x$ (km)",
        ylabel="$y$ (km)",
    )

    plt.colorbar(
        height_mesh,
        orientation="horizontal",
        label="Surface height anomaly (m)",
        pad=0.2,
        shrink=0.8,
    )
    fig.tight_layout()

    def draw_frame(frame_idx):
        """Update all artists for one frame and return them for blitting."""
        frame_state = sol[frame_idx]

        height_anomaly = frame_state.h - DEPTH
        height_mesh.set_array(height_anomaly[1:-1, 1:-1].flatten())
        velocity_arrows.set_UVC(
            frame_state.u[quiver_stride], frame_state.v[quiver_stride]
        )

        # frames are assumed to be PLOT_EVERY model steps apart
        current_time = PLOT_EVERY * dt * frame_idx
        time_label.set_text(f"t = {current_time / DAY_IN_SECONDS:.2f} days")
        return (height_mesh, velocity_arrows, time_label)

    return animation.FuncAnimation(
        fig,
        draw_frame,
        frames=len(sol),
        interval=50,
        blit=True,
        repeat_delay=3_000,
    )
568
569
if __name__ == "__main__":
    # command line flags: --benchmark (no output), --save-animation (write MP4)
    benchmark_mode = "--benchmark" in sys.argv

    sol = solve_shallow_water(t1=10 * DAY_IN_SECONDS, num_multisteps=PLOT_EVERY)

    if benchmark_mode:
        # benchmark runs stop right after the solver has printed its timing
        sys.exit(0)

    # copy solution to mpi_rank 0
    # (gather hands back just the array, like recv / sendrecv in
    # enforce_boundaries, so there is nothing to unpack here)
    full_sol_arr = mpi4jax.gather(jnp.asarray(sol), root=0, comm=mpi_comm)

    if mpi_rank == 0:
        # full_sol_arr has shape (nproc, time, nvars, ny, nx)
        # -> move the process axis next to the spatial axes, so that each
        #    element along time has shape (nvars, nproc, ny, nx)
        full_sol_arr = jnp.moveaxis(full_sol_arr, 0, 2)
        full_sol = [ModelState(*reassemble_array(x)) for x in full_sol_arr]

        anim = animate_shallow_water(full_sol)

        if "--save-animation" in sys.argv:
            # save animation as MP4 video (requires ffmpeg)
            anim.save("shallow-water.mp4", writer="ffmpeg", dpi=100)
        else:
            import matplotlib.pyplot as plt

            plt.show()

Download source (shallow_water.py)