examples/shallow_water.py

  1"""mpi4jax demo application -- Shallow water
  2
  3A non-linear shallow water solver, adapted from:
  4
  5https://github.com/dionhaefner/shallow-water
  6
  7Usage examples:
  8
  9    # runs demo on 4 processes
 10    $ mpirun -n 4 python shallow_water.py
 11
 12    # saves output animation as shallow-water.mp4
 13    $ mpirun -n 4 python shallow_water.py --save-animation
 14
 15    # runs demo as a benchmark (no output)
 16    $ mpirun -n 4 python shallow_water.py --benchmark
 17
 18"""
 19
 20import os
 21import sys
 22import math
 23import time
 24import warnings
 25from contextlib import ExitStack
 26from collections import namedtuple
 27from functools import partial
 28
 29import numpy as np
 30from mpi4py import MPI
 31
 32try:
 33    import tqdm
 34except ImportError:
 35    warnings.warn("Could not import tqdm, can't show progress bar")
 36    HAS_TQDM = False
 37else:
 38    HAS_TQDM = True
 39
 40mpi_comm = MPI.COMM_WORLD
 41mpi_rank = mpi_comm.Get_rank()
 42mpi_size = mpi_comm.Get_size()
 43
 44# on GPU, put each process on its own device
 45os.environ["CUDA_VISIBLE_DEVICES"] = str(mpi_rank)
 46
 47import jax  # noqa: E402
 48import jax.numpy as jnp  # noqa: E402
 49
 50import mpi4jax  # noqa: E402
 51
 52
 53#
 54# MPI setup
 55#
 56
 57supported_nproc = (1, 2, 4, 6, 8, 16)
 58if mpi_size not in supported_nproc:
 59    raise RuntimeError(
 60        f"Got invalid number of MPI processes: {mpi_size}. "
 61        f"Please choose one of these: {supported_nproc}."
 62    )
 63
 64nproc_y = min(mpi_size, 2)
 65nproc_x = mpi_size // nproc_y
 66
 67proc_idx = np.unravel_index(mpi_rank, (nproc_y, nproc_x))
 68
 69#
 70# Grid setup
 71#
 72
 73# we use 1 cell overlap on each side of the domain
 74nx_global = 360 + 2
 75ny_global = 180 + 2
 76
 77# grid spacing in metres
 78dx = 5e3
 79dy = 5e3
 80
 81# make sure processes divide the domain evenly
 82assert (nx_global - 2) % nproc_x == 0
 83assert (ny_global - 2) % nproc_y == 0
 84
 85nx_local = (nx_global - 2) // nproc_x + 2
 86ny_local = (ny_global - 2) // nproc_y + 2
 87
 88x_global, y_global = (
 89    np.arange(-1, nx_global - 1) * dx,
 90    np.arange(-1, ny_global - 1) * dy,
 91)
 92yy_global, xx_global = np.meshgrid(y_global, x_global, indexing="ij")
 93
 94length_x = x_global[-2] - x_global[1]
 95length_y = y_global[-2] - y_global[1]
 96
 97# this extracts the processor-local domain from a global array
 98local_slice = (
 99    slice((ny_local - 2) * proc_idx[0], (ny_local - 2) * proc_idx[0] + ny_local),
100    slice((nx_local - 2) * proc_idx[1], (nx_local - 2) * proc_idx[1] + nx_local),
101)
102
103x_local = x_global[local_slice[1]]
104y_local = y_global[local_slice[0]]
105
106xx_local = xx_global[local_slice]
107yy_local = yy_global[local_slice]
108
#
# Model parameters
#

# physical parameters
GRAVITY = 9.81
DEPTH = 100.0
CORIOLIS_F = 2e-4
CORIOLIS_BETA = 2e-11
# Coriolis parameter varies linearly with y on the local grid
CORIOLIS_PARAM = CORIOLIS_F + yy_local * CORIOLIS_BETA
LATERAL_VISCOSITY = 1e-3 * CORIOLIS_F * dx**2

# other parameters
DAY_IN_SECONDS = 86_400
PERIODIC_BOUNDARY_X = True

# Adams-Bashforth weights: the classic second-order pair (1.5, -0.5),
# each shifted in magnitude by 0.1
ADAMS_BASHFORTH_A = 1.5 + 0.1
ADAMS_BASHFORTH_B = -(0.5 + 0.1)

# output parameters
PLOT_ETA_RANGE = 10
PLOT_EVERY = 100
MAX_QUIVERS = (25, 50)

# time step from the CFL condition using the gravity-wave speed sqrt(g * H)
dt = 0.125 * min(dx, dy) / np.sqrt(GRAVITY * DEPTH)
136
137
@jax.jit
def get_initial_conditions():
    """Build the processor-local initial fields ``(h, u, v)``.

    For the initial conditions, we use a horizontal jet in geostrophic balance.
    The fields are first constructed on the global grid; each process then cuts
    out its local window (``local_slice``) and exchanges halo cells with its
    neighbours via ``enforce_boundaries``.

    Returns
    -------
    tuple
        ``(h0_local, u0_local, v0_local)``, each of shape
        ``(ny_local, nx_local)``.
    """
    # global initial conditions: Gaussian profile of u in y, centred at half
    # the interior length in y
    u0_global = 10 * jnp.exp(
        -((yy_global - 0.5 * length_y) ** 2) / (0.02 * length_x) ** 2
    )
    v0_global = jnp.zeros_like(u0_global)

    # approximate balance h_y = -(f/g)u, integrated along y by a cumulative
    # sum. u0_global is a traced JAX value inside jit, so use jnp.cumsum
    # explicitly instead of relying on np.cumsum delegating to `.cumsum`.
    coriolis_global = CORIOLIS_F + yy_global * CORIOLIS_BETA
    h_geostrophy = jnp.cumsum(-dy * u0_global * coriolis_global / GRAVITY, axis=0)
    h0_global = (
        DEPTH
        + h_geostrophy
        # make sure h0 is centered around depth
        - h_geostrophy.mean()
        # small perturbation to break symmetry
        + 0.2
        * np.sin(xx_global / length_x * 10 * np.pi)
        * np.cos(yy_global / length_y * 8 * np.pi)
    )

    h0_local = h0_global[local_slice]
    u0_local = u0_global[local_slice]
    v0_local = v0_global[local_slice]

    # fill halo cells from neighbouring processes (and apply wall conditions)
    token = jax.lax.create_token()
    h0_local, token = enforce_boundaries(h0_local, "h", token)
    u0_local, token = enforce_boundaries(u0_local, "u", token)
    v0_local, token = enforce_boundaries(v0_local, "v", token)

    return h0_local, u0_local, v0_local
171
172
@partial(jax.jit, static_argnums=(1,))
def enforce_boundaries(arr, grid, token=None):
    """Handle boundary exchange between processors.

    This is where mpi4jax comes in! The innermost interior row/column on each
    side is sent to the neighbouring process, which stores it in its outermost
    (halo) row/column. Afterwards, wall conditions are applied for velocity
    grids on processes that touch a physical boundary.

    Parameters
    ----------
    arr : array
        Processor-local 2D field, indexed ``[y, x]``.
    grid : str
        One of ``"h"``, ``"u"``, ``"v"`` (static); selects the wall condition.
    token : optional
        mpi4jax token used to order the communication; a new one is created
        when omitted.

    Returns
    -------
    tuple
        ``(arr, token)`` with updated halo cells and the final token.
    """
    assert grid in ("h", "u", "v")

    proc_grid = (nproc_y, nproc_x)
    row, col = proc_idx

    # neighbouring process coordinates; None means a physical boundary
    neighbors = {
        "south": (row - 1, col) if row > 0 else None,
        "west": (row, col - 1) if col > 0 else None,
        "north": (row + 1, col) if row < nproc_y - 1 else None,
        "east": (row, col + 1) if col < nproc_x - 1 else None,
    }
    if PERIODIC_BOUNDARY_X:
        # wrap around in x
        if col == 0:
            neighbors["west"] = (row, nproc_x - 1)
        if col == nproc_x - 1:
            neighbors["east"] = (row, 0)

    # innermost interior row/column on each side: what gets sent
    send_index = {
        "south": (1, slice(None), Ellipsis),
        "west": (slice(None), 1, Ellipsis),
        "north": (-2, slice(None), Ellipsis),
        "east": (slice(None), -2, Ellipsis),
    }
    # outermost halo row/column on each side: where received data goes
    recv_index = {
        "south": (0, slice(None), Ellipsis),
        "west": (slice(None), 0, Ellipsis),
        "north": (-1, slice(None), Ellipsis),
        "east": (slice(None), -1, Ellipsis),
    }

    if token is None:
        token = jax.lax.create_token()

    # (send direction, receive direction): start sending west while receiving
    # from the east, then go clockwise. This order must be kept as is: it
    # pairs every send with the matching receive on the neighbour.
    exchanges = (
        ("west", "east"),
        ("north", "south"),
        ("east", "west"),
        ("south", "north"),
    )

    for send_dir, recv_dir in exchanges:
        dest = neighbors[send_dir]
        source = neighbors[recv_dir]

        if dest is None and source is None:
            continue

        if dest is not None:
            dest = np.ravel_multi_index(dest, proc_grid)
        if source is not None:
            source = np.ravel_multi_index(source, proc_grid)

        halo_idx = recv_index[recv_dir]
        incoming = jnp.empty_like(arr[halo_idx])
        outgoing = arr[send_index[send_dir]]

        if dest is None:
            # boundary on the sending side: receive only
            incoming, token = mpi4jax.recv(
                incoming, source=source, comm=mpi_comm, token=token
            )
        elif source is None:
            # boundary on the receiving side: send only, nothing to store
            token = mpi4jax.send(outgoing, dest=dest, comm=mpi_comm, token=token)
            continue
        else:
            incoming, token = mpi4jax.sendrecv(
                outgoing,
                incoming,
                source=source,
                dest=dest,
                comm=mpi_comm,
                token=token,
            )

        arr = arr.at[halo_idx].set(incoming)

    # walls: zero normal velocity at the last interior u/v point
    if not PERIODIC_BOUNDARY_X and grid == "u" and col == nproc_x - 1:
        arr = arr.at[:, -2].set(0.0)

    if grid == "v" and row == nproc_y - 1:
        arr = arr.at[-2, :].set(0.0)

    return arr, token
272
273
# model fields (h, u, v) followed by their tendencies from the last step
ModelState = namedtuple("ModelState", ["h", "u", "v", "dh", "du", "dv"])
275
276
@partial(jax.jit, static_argnums=(1,))
def shallow_water_step(state, is_first_step):
    """Perform one step of the shallow-water model.

    Parameters
    ----------
    state : ModelState
        Processor-local fields ``h, u, v`` and the tendencies ``dh, du, dv``
        computed in the previous step.
    is_first_step : bool
        Static flag. When true, the fields are advanced with a forward Euler
        step (no previous tendencies exist yet); otherwise the Adams-Bashforth
        combination of the new and previous tendencies is used.

    Returns
    -------
    ModelState
        Updated ``h, u, v`` together with the newly computed tendencies.
    """
    token = jax.lax.create_token()

    h, u, v, dh, du, dv = state

    # height with edge-replicated halo, then exchanged between processes
    hc = jnp.pad(h[1:-1, 1:-1], 1, "edge")
    hc, token = enforce_boundaries(hc, "h", token)

    # mass fluxes in x (fe) and y (fn)
    fe = jnp.empty_like(u)
    fn = jnp.empty_like(u)

    fe = fe.at[1:-1, 1:-1].set(0.5 * (hc[1:-1, 1:-1] + hc[1:-1, 2:]) * u[1:-1, 1:-1])
    fn = fn.at[1:-1, 1:-1].set(0.5 * (hc[1:-1, 1:-1] + hc[2:, 1:-1]) * v[1:-1, 1:-1])
    fe, token = enforce_boundaries(fe, "u", token)
    fn, token = enforce_boundaries(fn, "v", token)

    # continuity equation: height tendency is minus the flux divergence
    dh_new = dh.at[1:-1, 1:-1].set(
        -(fe[1:-1, 1:-1] - fe[1:-1, :-2]) / dx - (fn[1:-1, 1:-1] - fn[:-2, 1:-1]) / dy
    )

    # nonlinear momentum equation
    q = jnp.empty_like(u)
    ke = jnp.empty_like(u)

    # planetary and relative vorticity
    q = q.at[1:-1, 1:-1].set(
        CORIOLIS_PARAM[1:-1, 1:-1]
        + ((v[1:-1, 2:] - v[1:-1, 1:-1]) / dx - (u[2:, 1:-1] - u[1:-1, 1:-1]) / dy)
    )
    # potential vorticity: divide by the height averaged over four cells
    q = q.at[1:-1, 1:-1].mul(
        1.0 / (0.25 * (hc[1:-1, 1:-1] + hc[1:-1, 2:] + hc[2:, 1:-1] + hc[2:, 2:]))
    )
    q, token = enforce_boundaries(q, "h", token)

    # pressure gradient plus potential-vorticity flux terms
    du_new = du.at[1:-1, 1:-1].set(
        -GRAVITY * (h[1:-1, 2:] - h[1:-1, 1:-1]) / dx
        + 0.5
        * (
            q[1:-1, 1:-1] * 0.5 * (fn[1:-1, 1:-1] + fn[1:-1, 2:])
            + q[:-2, 1:-1] * 0.5 * (fn[:-2, 1:-1] + fn[:-2, 2:])
        )
    )
    dv_new = dv.at[1:-1, 1:-1].set(
        -GRAVITY * (h[2:, 1:-1] - h[1:-1, 1:-1]) / dy
        - 0.5
        * (
            q[1:-1, 1:-1] * 0.5 * (fe[1:-1, 1:-1] + fe[2:, 1:-1])
            + q[1:-1, :-2] * 0.5 * (fe[1:-1, :-2] + fe[2:, :-2])
        )
    )
    # kinetic energy
    ke = ke.at[1:-1, 1:-1].set(
        0.5
        * (
            0.5 * (u[1:-1, 1:-1] ** 2 + u[1:-1, :-2] ** 2)
            + 0.5 * (v[1:-1, 1:-1] ** 2 + v[:-2, 1:-1] ** 2)
        )
    )
    ke, token = enforce_boundaries(ke, "h", token)

    # kinetic energy gradient
    du_new = du_new.at[1:-1, 1:-1].add(-(ke[1:-1, 2:] - ke[1:-1, 1:-1]) / dx)
    dv_new = dv_new.at[1:-1, 1:-1].add(-(ke[2:, 1:-1] - ke[1:-1, 1:-1]) / dy)

    if is_first_step:
        # forward Euler: no previous tendencies available yet
        u = u.at[1:-1, 1:-1].add(dt * du_new[1:-1, 1:-1])
        v = v.at[1:-1, 1:-1].add(dt * dv_new[1:-1, 1:-1])
        h = h.at[1:-1, 1:-1].add(dt * dh_new[1:-1, 1:-1])
    else:
        # Adams-Bashforth: weighted combination of new and previous tendencies
        u = u.at[1:-1, 1:-1].add(
            dt
            * (
                ADAMS_BASHFORTH_A * du_new[1:-1, 1:-1]
                + ADAMS_BASHFORTH_B * du[1:-1, 1:-1]
            )
        )
        v = v.at[1:-1, 1:-1].add(
            dt
            * (
                ADAMS_BASHFORTH_A * dv_new[1:-1, 1:-1]
                + ADAMS_BASHFORTH_B * dv[1:-1, 1:-1]
            )
        )
        h = h.at[1:-1, 1:-1].add(
            dt
            * (
                ADAMS_BASHFORTH_A * dh_new[1:-1, 1:-1]
                + ADAMS_BASHFORTH_B * dh[1:-1, 1:-1]
            )
        )

    h, token = enforce_boundaries(h, "h", token)
    u, token = enforce_boundaries(u, "u", token)
    v, token = enforce_boundaries(v, "v", token)

    if LATERAL_VISCOSITY > 0:
        # lateral friction acting on u: viscous fluxes from gradients of u
        fe = fe.at[1:-1, 1:-1].set(
            LATERAL_VISCOSITY * (u[1:-1, 2:] - u[1:-1, 1:-1]) / dx
        )
        fn = fn.at[1:-1, 1:-1].set(
            LATERAL_VISCOSITY * (u[2:, 1:-1] - u[1:-1, 1:-1]) / dy
        )
        fe, token = enforce_boundaries(fe, "u", token)
        fn, token = enforce_boundaries(fn, "v", token)

        u = u.at[1:-1, 1:-1].add(
            dt
            * (
                (fe[1:-1, 1:-1] - fe[1:-1, :-2]) / dx
                + (fn[1:-1, 1:-1] - fn[:-2, 1:-1]) / dy
            )
        )

        # lateral friction acting on v: the viscous fluxes must be built from
        # gradients of v only (v - v), not v - u
        fe = fe.at[1:-1, 1:-1].set(
            LATERAL_VISCOSITY * (v[1:-1, 2:] - v[1:-1, 1:-1]) / dx
        )
        fn = fn.at[1:-1, 1:-1].set(
            LATERAL_VISCOSITY * (v[2:, 1:-1] - v[1:-1, 1:-1]) / dy
        )
        fe, token = enforce_boundaries(fe, "u", token)
        fn, token = enforce_boundaries(fn, "v", token)

        v = v.at[1:-1, 1:-1].add(
            dt
            * (
                (fe[1:-1, 1:-1] - fe[1:-1, :-2]) / dx
                + (fn[1:-1, 1:-1] - fn[:-2, 1:-1]) / dy
            )
        )

    return ModelState(h, u, v, dh_new, du_new, dv_new)
413
414
@partial(jax.jit, static_argnums=(1,))
def do_multistep(state, num_steps):
    """Advance ``state`` by ``num_steps`` non-initial model steps in one compiled loop.

    ``num_steps`` is a static argument, so every distinct value compiles anew.
    """

    def _advance(_step_index, current):
        # every step inside the loop uses the Adams-Bashforth branch
        return shallow_water_step(current, False)

    return jax.lax.fori_loop(0, num_steps, _advance, state)
421
422
def solve_shallow_water(t1, num_multisteps=10):
    """Iterate the model forward in time.

    Parameters
    ----------
    t1 : float
        End time of the integration in seconds.
    num_multisteps : int
        Number of model steps fused into one ``do_multistep`` call; one
        snapshot is stored after every call.

    Returns
    -------
    list of ModelState
        Processor-local snapshots: the initial state, the state after the
        first (forward Euler) step, then one state after every block of
        ``num_multisteps`` steps until ``t1`` is reached or exceeded.
    """
    # initial conditions
    h, u, v = get_initial_conditions()
    du, dv, dh = (jnp.zeros((ny_local, nx_local)) for _ in range(3))

    state = ModelState(h, u, v, dh, du, dv)
    sol = [state]

    # the first step has no previous tendencies and is handled separately
    state = shallow_water_step(state, True)
    sol.append(state)
    t = dt

    # pre-compile JAX kernel. JAX dispatches asynchronously, so wait for the
    # warm-up run to finish; otherwise it would overlap the timed section.
    do_multistep(state, num_multisteps)[0].block_until_ready()

    # the progress bar counts model steps; unit_scale converts them to days
    total_steps = math.ceil(t1 / dt)

    start = time.perf_counter()
    with ExitStack() as es:
        if HAS_TQDM:
            # created after the warm-up so compilation is not counted
            pbar = es.enter_context(
                tqdm.tqdm(
                    total=total_steps,
                    disable=mpi_rank != 0,
                    unit="model day",
                    initial=1,  # one step has already been taken
                    unit_scale=dt / DAY_IN_SECONDS,
                    bar_format=(
                        "{l_bar}{bar}| {n:.2f}/{total:.2f} [{elapsed}<{remaining}, "
                        "{rate_fmt}{postfix}]"
                    ),
                )
            )

        while t < t1:
            state = do_multistep(state, num_multisteps)
            state[0].block_until_ready()
            sol.append(state)

            t += dt * num_multisteps

            if HAS_TQDM:
                # clamp so the bar ends exactly at `total` instead of
                # overshooting (or stopping short) on the last block
                pbar.update(min(num_multisteps, max(total_steps - pbar.n, 0)))

    end = time.perf_counter()

    if mpi_rank == 0:
        print(f"\nSolution took {end - start:.2f}s")

    return sol
473
474
@jax.vmap
@jax.jit
def reassemble_array(arr):
    """Stitch the per-process windows of one field back into the global grid.

    ``arr`` stacks every process' local solution along its first axis:
    shape (mpi_size, ny_local, nx_local) -> (ny_global, nx_global).
    Overlapping halo cells are written in rank order, so higher ranks win.
    The outer ``jax.vmap`` maps this over an additional leading axis.
    """
    interior_y = ny_local - 2
    interior_x = nx_local - 2

    full = jnp.empty((ny_global, nx_global), dtype=arr.dtype)
    for rank in range(mpi_size):
        row, col = np.unravel_index(rank, (nproc_y, nproc_x))
        y0 = interior_y * row
        x0 = interior_x * col
        full = full.at[y0 : y0 + ny_local, x0 : x0 + nx_local].set(arr[rank])

    return full
499
500
def animate_shallow_water(sol):
    """Create a matplotlib animation of the result.

    Parameters
    ----------
    sol : list of ModelState
        Reassembled global snapshots in the order produced by
        ``solve_shallow_water(..., num_multisteps=PLOT_EVERY)``: the initial
        state, the state after the first step, then one snapshot every
        ``PLOT_EVERY`` steps.

    Returns
    -------
    matplotlib.animation.FuncAnimation
        Animation showing the surface height anomaly and velocity arrows.
    """
    import matplotlib.pyplot as plt
    from matplotlib import animation

    # thin out the arrows so that at most roughly MAX_QUIVERS are drawn
    quiver_stride = (
        slice(1, -1, ny_global // MAX_QUIVERS[0]),
        slice(1, -1, nx_global // MAX_QUIVERS[1]),
    )

    # set up figure
    fig = plt.figure(figsize=(6, 4))
    ax = plt.gca()

    # surface plot of height anomaly
    cs = ax.pcolormesh(
        0.5 * (x_global[:-1] + x_global[1:]) / 1e3,
        0.5 * (y_global[:-1] + y_global[1:]) / 1e3,
        sol[0].h[1:-1, 1:-1] - DEPTH,
        vmin=-PLOT_ETA_RANGE,
        vmax=PLOT_ETA_RANGE,
        cmap="RdBu_r",
    )

    # quiver plot of velocity
    cq = ax.quiver(
        xx_global[quiver_stride] / 1e3,
        yy_global[quiver_stride] / 1e3,
        sol[0].u[quiver_stride],
        sol[0].v[quiver_stride],
        clip_on=True,
    )

    # time indicator
    t = ax.text(
        s="",
        x=0.05,
        y=0.95,
        ha="left",
        va="top",
        backgroundcolor=(1, 1, 1, 0.8),
        transform=ax.transAxes,
    )

    ax.set(
        aspect="equal",
        xlim=(x_global[1] / 1e3, x_global[-2] / 1e3),
        ylim=(y_global[1] / 1e3, y_global[-2] / 1e3),
        xlabel="$x$ (km)",
        ylabel="$y$ (km)",
    )

    plt.colorbar(
        cs,
        orientation="horizontal",
        label="Surface height anomaly (m)",
        pad=0.2,
        shrink=0.8,
    )
    fig.tight_layout()

    def animate(i):
        state = sol[i]

        eta = state.h - DEPTH
        cs.set_array(eta[1:-1, 1:-1].flatten())
        cq.set_UVC(state.u[quiver_stride], state.v[quiver_stride])

        # sol[0] is the initial state and sol[1] follows a single step; every
        # later frame lies PLOT_EVERY steps after its predecessor
        num_steps = 0 if i == 0 else 1 + (i - 1) * PLOT_EVERY
        current_time = num_steps * dt
        t.set_text(f"t = {current_time / DAY_IN_SECONDS:.2f} days")
        return (cs, cq, t)

    anim = animation.FuncAnimation(
        fig, animate, frames=len(sol), interval=50, blit=True, repeat_delay=3_000
    )
    return anim
577
578
if __name__ == "__main__":
    benchmark_mode = "--benchmark" in sys.argv

    sol = solve_shallow_water(t1=10 * DAY_IN_SECONDS, num_multisteps=PLOT_EVERY)

    if benchmark_mode:
        # timing only: skip gathering and plotting
        sys.exit(0)

    # collect every process' snapshots on rank 0; the gathered array has
    # shape (nproc, time, nvars, ny, nx)
    full_sol_arr, _ = mpi4jax.gather(jnp.asarray(sol), root=0, comm=mpi_comm)

    if mpi_rank == 0:
        # reorder to (time, nvars, nproc, ny, nx) so that each time slice can
        # be stitched together by reassemble_array (vmapped over nvars)
        full_sol_arr = jnp.moveaxis(full_sol_arr, 0, 2)
        full_sol = []
        for snapshot in full_sol_arr:
            full_sol.append(ModelState(*reassemble_array(snapshot)))

        anim = animate_shallow_water(full_sol)

        if "--save-animation" not in sys.argv:
            import matplotlib.pyplot as plt

            plt.show()
        else:
            # save animation as MP4 video (requires ffmpeg)
            anim.save("shallow-water.mp4", writer="ffmpeg", dpi=100)

Download source (shallow_water.py)