Checkpointer

How to save and restart from checkpoints

ClimaCoupler supports saving and reading simulation checkpoints. This is useful for splitting a long simulation into smaller, more manageable chunks.

Checkpoints are a mix of HDF5 and JLD2 files and are typically saved in a checkpoints folder in the simulation output. See Utilities.setup_output_dirs for more information.

Known limitations
  • The number of MPI processes has to remain the same across checkpoints
  • Restart files are generally not portable across machines, Julia versions, and package versions
  • Adding or changing component models will probably require adding or changing code

Saving checkpoints

If you are running a model (such as AMIP), chances are that you can enable checkpointing just by setting a command-line argument: the checkpoint_dt option controls how frequently a checkpoint is produced.

If your model does not come with this option already, you can checkpoint the simulation by adding a callback that calls the Checkpointer.checkpoint_sims function.

For example, to add a callback that checkpoints every hour of simulated time (assuming you have already defined a start_date):

import Dates

import ClimaCoupler: Checkpointer, TimeManager
import ClimaDiagnostics.Schedules: EveryCalendarDtSchedule

# Checkpoint every hour of simulated time, starting from start_date
schedule_checkpoint = EveryCalendarDtSchedule(Dates.Hour(1); start_date)
checkpoint_callback = TimeManager.Callback(schedule_checkpoint, Checkpointer.checkpoint_sims)

# In the coupling loop:
TimeManager.maybe_trigger_callback(checkpoint_callback, coupled_simulation, time)

Reading checkpoints

There are two ways to restart a simulation from checkpoints. By default, ClimaCoupler tries to find suitable checkpoints and use them automatically. Alternatively, you can specify a directory restart_dir and a simulation time restart_t, and restart from the files saved in that directory at that time. If the model you are running supports writing checkpoints via a command-line argument, it will probably also support reading them. In this case, restart_dir is the path to the top-level directory containing all the checkpoint files, and restart_t is the simulated time in seconds.

If the model does not support reading a checkpoint directly, the Checkpointer module provides a straightforward way to add this feature: Checkpointer.restart! takes a coupled simulation, a restart_dir, a restart_t, and a restart_cache flag, and overwrites the content of the coupled simulation with what is in the checkpoint.
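To make the default behavior more concrete, here is a minimal sketch of how "suitable" checkpoints could be detected: it picks the latest simulated time for which every requested component has a checkpoint file. The helper latest_complete_checkpoint and the "<component>_<t>.hdf5" file naming are hypothetical assumptions for illustration, not ClimaCoupler's API or on-disk layout.

```julia
# Hypothetical sketch: find the latest simulated time (in seconds) for which every
# listed component has a checkpoint file in `dir`.
# ASSUMPTION: files are named "<component>_<t>.hdf5". This naming is for
# illustration only and is not necessarily ClimaCoupler's actual scheme.
function latest_complete_checkpoint(dir::AbstractString, components)
    # Collect the available times separately for each requested component
    times = Dict(String(c) => Set{Int}() for c in components)
    for file in readdir(dir)
        m = match(r"^(.+)_(\d+)\.hdf5$", file)
        m === nothing && continue
        name, t = m.captures[1], parse(Int, m.captures[2])
        haskey(times, name) && push!(times[name], t)
    end
    # A restart is only possible at a time available for every component
    common = reduce(intersect, values(times))
    return isempty(common) ? nothing : maximum(common)
end
```

If the atmosphere has checkpoints at 3600, 7200, and 10800 s but the land model only at 3600 and 7200 s, the sketch returns 7200, the newest time at which the whole coupled state can be rebuilt.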

Developer notes

In theory, the state of the component models should fully determine the state of the coupled simulation, and one should be able to restart a coupled simulation just by using the states of the component models. Unfortunately, this is currently not the case in ClimaCoupler. The main reason is the complex interdependencies between component models and within ClimaAtmos, which make the initialization step inconsistent. For example, in a coupled simulation, the surface albedo should be determined by the surface models and used by the atmospheric model for radiative transfer, but ClimaAtmos also tries to set the surface albedo (since it has to do so when run in standalone mode). In addition, ClimaAtmos has a large cache with internal interdependencies that are hard to disentangle: changing one field might require changing another field in a different part of the cache. As a result, it is not easy for ClimaCoupler to consistently initialize from a cold state. In short, restarting a simulation using only the states of the component models is currently impossible.

Since restarting a simulation from the states alone is impossible, ClimaCoupler needs to save both the states and the caches. Let us review how we use ClimaCore.InputOutput and the JLD2 package to accomplish this.

ClimaCore.InputOutput provides a lossless way to save the content of certain ClimaCore objects to HDF5 files. Objects saved this way are not tied to a particular computing device or configuration. When running with MPI, ClimaCore.InputOutput also writes these files efficiently in parallel.

Unfortunately, ClimaCore.InputOutput only supports certain objects, such as Fields and Spaces, but the cache in component models is more complex than this and contains complex objects with highly stateful quantities (e.g., C pointers). Because of this, model states are saved to HDF5 but caches must be saved to JLD2 files.

JLD2 allows us to save more complex objects without writing specific serialization methods for every struct. This is a big step forward, but several challenges remain:

  1. JLD2 does not support CUDA natively. To work around this, we move everything onto the CPU before writing and, when the data is read back, move it back to the GPU.
  2. JLD2 does not support MPI natively. To work around this, each process writes and reads back its own JLD2 checkpoint. This introduces the constraint that the number of MPI processes cannot change across restarts.
  3. Some quantities should not be saved and read back (for example, anything with pointers). For this, we use a recursive function that traverses the cache and only restores quantities of certain types (typically, ClimaCore objects).

Point 3 adds a significant amount of code and requires each component model to specify how its cache should be restored.

Adding checkpointing to a new component model

There are two ways to add checkpoint/restart support for a new component model:

Path A (ClimaCore-based models): extend get_model_prog_state to return the prognostic state as a ClimaCore.FieldVector. The default checkpoint_model_state and restart_model_state! implementations will handle HDF5 I/O via ClimaCore.InputOutput automatically. This path is intended for models whose prognostic state is a ClimaCore.FieldVector; models that do not use ClimaCore should use Path B instead.

Checkpointer.get_model_prog_state
Checkpointer.get_model_cache
Checkpointer.restore_cache!

Path B (custom checkpoint format): override checkpoint_model_state and restart_model_state! directly for full control over the checkpoint format. This is the approach used by OceananigansSimulation, which writes JLD2 checkpoints via Oceananigans' native checkpoint and restores them with Oceananigans.set!.

Checkpointer.checkpoint_model_state
Checkpointer.checkpoint_model_cache  # optional; no-op if cache checkpointing is not supported
Checkpointer.restart_model_state!
Checkpointer.restart_model_cache!    # optional; warn or no-op if cache restore is not supported
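The difference between the two paths comes down to Julia's multiple dispatch. The following self-contained sketch uses stand-in types (FieldModel, StatelessModel, CustomModel are hypothetical) and a simplified three-argument checkpoint_model_state; the real method, documented below, also takes comms_ctx, prev_checkpoint_t, and output_dir and writes HDF5 files rather than filling a Dict.

```julia
# Stand-in types for illustration only; these are NOT ClimaCoupler's real types.
abstract type AbstractComponentSimulation end
struct FieldModel <: AbstractComponentSimulation      # follows Path A
    state::Vector{Float64}
end
struct StatelessModel <: AbstractComponentSimulation end
struct CustomModel <: AbstractComponentSimulation     # follows Path B
    state::Vector{Float64}
end

# Template function: by default a component has no checkpointable state
get_model_prog_state(::AbstractComponentSimulation) = nothing
# Path A: the model only says where its prognostic state lives
get_model_prog_state(sim::FieldModel) = sim.state

# Generic default (a Dict stands in for HDF5 files on disk)
function checkpoint_model_state(sim::AbstractComponentSimulation, store::AbstractDict, t::Int)
    state = get_model_prog_state(sim)
    state === nothing && return nothing  # nothing to save: no-op
    store[(nameof(typeof(sim)), t)] = copy(state)
    return nothing
end

# Path B: the model overrides the checkpoint method with its own format
function checkpoint_model_state(sim::CustomModel, store::AbstractDict, t::Int)
    store[(:custom_format, t)] = copy(sim.state)
    return nothing
end
```

Calling checkpoint_model_state on a StatelessModel silently does nothing, while FieldModel gets the generic behavior for free and CustomModel bypasses get_model_prog_state entirely.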

ClimaCoupler moves objects to the CPU with Adapt.adapt(Array, x). The adapt function traverses the object recursively, so appropriate Adapt methods have to be defined for every object involved in the chain. The easiest way to do this is with the Adapt.@adapt_structure macro, which defines a recursive adapt_structure method for the given type.

Types to watch for, because they wrap stateful resources that cannot simply be adapted and saved:

  • MPI related objects (e.g., MPICommsContext)
  • TimeVaryingInputs (because they contain NCDatasets, which contain pointers to files)

Adapt and references

For objects that contain multiple fields referencing the same object, the Adapt.@adapt_structure macro leads to unnecessary copies of that object. This happens because the recursive adapt defined by Adapt.@adapt_structure does not account for the possibility that multiple fields reference the same object. As a result, the same object is copied again every time it is encountered when calling adapt on the cache, which can make the saved cache much larger than it needs to be. To address this, we implemented a CacheIterator object (see the section below for details).

CacheIterator

Instead of defining a proper Adapt method for the cache, an alternative approach is to recursively iterate over the cache fields and selectively save only the parts that need to be saved. This recursive iteration is performed by the CacheIterator. To initialize a CacheIterator for a component model, you must implement get_cache_ignore.

Using the CacheIterator allows adapt to be called on each individual field instead of on the entire cache. Furthermore, the file size can be reduced by avoiding duplicate saves of fields that reference the same memory. This is accomplished by tracking object IDs and storing references to objects instead of creating copies when the same object is encountered multiple times.

This approach allows for a significant reduction in the file size of the cache.
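The deduplication idea can be sketched in plain Julia, assuming a cache made of nested NamedTuples whose leaves are arrays. Here copy stands in for Adapt.adapt(Array, x), and SharedRef and collect_leaves are hypothetical names used only for illustration; they are not the CacheIterator API.

```julia
# Marker meaning "same object as entry `index` of the saved vector"
struct SharedRef
    index::Int
end

# Depth-first walk over `obj`, appending array leaves to `saved`.
# `seen` maps objectid => position in `saved`, so a shared array is copied once.
function collect_leaves!(saved::Vector{Any}, seen::Dict{UInt,Int}, obj)
    if obj isa AbstractArray
        id = objectid(obj)
        if haskey(seen, id)
            push!(saved, SharedRef(seen[id]))  # store a reference, not a copy
        else
            push!(saved, copy(obj))            # stand-in for Adapt.adapt(Array, obj)
            seen[id] = length(saved)
        end
    elseif obj isa NamedTuple
        for value in obj                        # iterate over the NamedTuple's values
            collect_leaves!(saved, seen, value)
        end
    end
    # Any other leaf (numbers, pointers, ...) is skipped in this sketch
    return saved
end

collect_leaves(cache) = collect_leaves!(Any[], Dict{UInt,Int}(), cache)
```

For a cache such as (a = shared, b = (c = shared, d = [4.0])), the sketch stores shared once and records SharedRef(1) for the second occurrence. Restoring would reverse the process by resolving each SharedRef to the entry already loaded at that index.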

Checkpointer API

ClimaCoupler.Checkpointer.get_model_prog_state (Function)
Checkpointer.get_model_prog_state(sim::ClimaAtmosSimulation)

Extension of Checkpointer.get_model_prog_state to get the model state.

Checkpointer.get_model_prog_state(sim::BucketSimulation)

Extension of Checkpointer.get_model_prog_state to get the model state.

get_model_prog_state(sim::Interfacer.AbstractComponentSimulation)

Returns the model state of a simulation as a ClimaCore.FieldVector. This is a template function that should be implemented for each component model.

Checkpointer.get_model_prog_state(sim::SlabOceanSimulation)

Extension of Checkpointer.get_model_prog_state to get the model state.

Checkpointer.get_model_prog_state(sim::PrescribedIceSimulation)

Extension of Checkpointer.get_model_prog_state to get the model state.

ClimaCoupler.Checkpointer.get_model_cache (Function)
get_model_cache(sim::Interfacer.AbstractComponentSimulation)

Returns the model cache of a simulation. This is a template function that should be implemented for each component model.

ClimaCoupler.Checkpointer.get_model_cache_to_checkpoint (Function)
get_model_cache_to_checkpoint(sim::Interfacer.AbstractComponentSimulation)

Prepare the cache for checkpointing by moving the entire cache to CPU.

get_model_cache_to_checkpoint(sim::Interfacer.AbstractAtmosSimulation)

Prepare the atmos cache for checkpointing by moving only selected parts of the atmos cache to the CPU instead of the entire cache, resulting in a much smaller saved file.

Implementation Details

When moving the cache from GPU to CPU, calling adapt on the entire cache creates unnecessary duplicate objects because adapt is not properly defined for the entire cache structure. This function addresses four key issues:

  1. Individual adaptation: On GPU, adapt is called on each object separately rather than on the entire cache at once.

  2. Deduplication: Objects sharing the same object ID are not duplicated. Instead, references to already-processed objects are reused.

  3. Selective saving: Only the parts of the cache needed for restoration are saved.

  4. Recreated views: Views are preserved by recreating them on top of the parent arrays that have already been moved to the CPU. Otherwise, the parent arrays are unnecessarily duplicated when adapt is called, because objectid-based checks do not recognize that a view and its parent share memory, resulting in a larger saved cache when running on GPU.

Returns

A vector of objects from the cache. Elements may reference the same underlying data if they share object IDs. The order of the objects in the vector is determined by CacheIterator.
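The view handling in point 4 can be illustrated with a minimal sketch in plain Julia. Here copy stands in for moving data to the CPU, and move_to_cpu! is a hypothetical helper: an IdDict tracks arrays by identity, so each parent is moved once and every view is rebuilt on top of the moved parent using parentindices.

```julia
# Move `x` "to the CPU" (copy is a stand-in for Adapt.adapt(Array, x)),
# moving each parent array at most once. `moved` maps original arrays
# to their moved copies by object identity.
function move_to_cpu!(moved::IdDict{Any,Any}, x::AbstractArray)
    if x isa SubArray
        # Move (or reuse) the parent, then rebuild the view on top of it
        cpu_parent = get!(() -> copy(parent(x)), moved, parent(x))
        return view(cpu_parent, parentindices(x)...)
    end
    return get!(() -> copy(x), moved, x)
end
```

Two views into the same array, plus the array itself, end up sharing a single moved parent instead of producing three independent copies.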

ClimaCoupler.Checkpointer.checkpoint_model_state (Function)
checkpoint_model_state(
    sim::Interfacer.AbstractComponentSimulation,
    comms_ctx::ClimaComms.AbstractCommsContext,
    t::Int,
    prev_checkpoint_t::Int;
    output_dir = "output")

Checkpoint the model state of a simulation at time t (in seconds).

The default implementation uses get_model_prog_state(sim) to obtain a ClimaCore.FieldVector and writes it to an HDF5 file via ClimaCore.InputOutput. If get_model_prog_state returns nothing, this function does nothing.

Component models that do not use ClimaCore can override this method to use their own checkpointing.

If a previous checkpoint exists, it is removed. This is to avoid accumulating many checkpoint files in the output directory. A value of -1 for prev_checkpoint_t is used to indicate that there is no previous checkpoint to remove.
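The "keep only the latest checkpoint" rule, including the -1 sentinel, can be sketched as follows. Julia's Serialization standard library stands in for HDF5 and ClimaCore.InputOutput, and both the file naming and the helpers checkpoint_path and write_rotating_checkpoint are hypothetical.

```julia
using Serialization

# Hypothetical naming scheme, for illustration only
checkpoint_path(dir, name, t) = joinpath(dir, "$(name)_$(t).jls")

# Write the state for time t, then delete the checkpoint for prev_t.
# prev_t == -1 means there is no previous checkpoint to remove.
function write_rotating_checkpoint(dir, name, state, t::Int, prev_t::Int)
    serialize(checkpoint_path(dir, name, t), state)
    if prev_t != -1
        prev = checkpoint_path(dir, name, prev_t)
        isfile(prev) && rm(prev)
    end
    return nothing
end
```

In this sketch the new file is written before the old one is deleted, so an interruption between the two steps still leaves at least one complete checkpoint on disk.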

Checkpointer.checkpoint_model_state(sim, comms_ctx, t, prev_checkpoint_t; output_dir)

Save the state of an Oceananigans-backed simulation to a JLD2 file at time t (in seconds) using Oceananigans.checkpoint.

If a previous checkpoint exists, it is removed to avoid accumulating files. A value of -1 for prev_checkpoint_t indicates there is no previous checkpoint.

ClimaCoupler.Checkpointer.checkpoint_model_cache (Function)
checkpoint_model_cache(
    sim::Interfacer.AbstractComponentSimulation,
    comms_ctx::ClimaComms.AbstractCommsContext,
    t::Int,
    prev_checkpoint_t::Int;
    output_dir = "output")

Checkpoint the model cache to N JLD2 files at a given time, t (in seconds), where N is the number of MPI ranks.

The default implementation uses get_model_cache(sim) to obtain the cache. If get_model_cache returns nothing, this function does nothing.

Component models that do not use ClimaCore can override this method to use their own checkpointing.

Objects are saved to JLD2 files because caches are generally not ClimaCore objects (and ClimaCore.InputOutput can only save Fields or FieldVectors).

If a previous checkpoint exists, it is removed. This is to avoid accumulating many checkpoint files in the output directory. A value of -1 for prev_checkpoint_t is used to indicate that there is no previous checkpoint to remove.
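Writing one cache file per MPI rank is what ties a restart to the original number of processes. The sketch below uses a hypothetical naming pattern and hypothetical helpers (cache_path, missing_cache_ranks) to show that relaunching with more ranks leaves some ranks without a file to read.

```julia
# Hypothetical per-rank naming, for illustration only
cache_path(dir, name, t, rank) = joinpath(dir, "$(name)_cache_$(t)_rank$(rank).jld2")

# Ranks (0-based) of an `nranks`-process run that would find no cache file to read
missing_cache_ranks(dir, name, t, nranks) =
    [r for r in 0:(nranks - 1) if !isfile(cache_path(dir, name, t, r))]
```

Relaunching with fewer ranks does not show up as a missing file in this sketch, but it is still unsafe under the assumption that each file only holds the part of the cache owned by the rank that wrote it.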

ClimaCoupler.Checkpointer.restart! (Function)
restart!(cs::CoupledSimulation, checkpoint_dir, checkpoint_t, restart_cache)

Overwrite the content of cs with checkpoints in checkpoint_dir at time checkpoint_t.

If restart_cache is true, the cache will be read from the restart file using restore_cache!. Otherwise, the cache will be left unchanged.

Return true if the simulation was restarted.

ClimaCoupler.Checkpointer.restart_model_state! (Function)
restart_model_state!(sim, input_file, comms_ctx)

Restore the prognostic state of sim from input_file.

The default implementation reads a ClimaCore.FieldVector from an HDF5 file written by the default checkpoint_model_state. If get_model_prog_state returns nothing, this function does nothing.

Component models that do not use ClimaCore can override this method to use their own checkpointing.

Checkpointer.restart_model_state!(sim, input_file, comms_ctx)

Restore the state of an Oceananigans-backed simulation from a JLD2 checkpoint file using Oceananigans.set!.

The coupler constructs input_file with a .hdf5 extension; this method replaces it with .jld2 to match the format written by checkpoint_model_state.
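A sketch of the extension swap, assuming only the final extension should change. Using splitext rather than a plain string replace avoids also rewriting a ".hdf5" that happens to appear in a directory name. The helper jld2_path is hypothetical.

```julia
# Illustrative only: swap the final ".hdf5" extension of a path for ".jld2"
jld2_path(input_file::AbstractString) = first(splitext(input_file)) * ".jld2"
```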

ClimaCoupler.Checkpointer.restart_model_cache! (Function)
restart_model_cache!(sim, input_file)

Restore the cache of sim from input_file.

The default implementation uses get_model_cache(sim) to check whether the simulation has a cache. If get_model_cache returns nothing, this function does nothing.

It relies on restore_cache!(sim, old_cache), which has to be implemented by the component models that have a cache.

Checkpointer.restart_model_cache!(sim, input_file)

No-op for Oceananigans-backed simulations. All necessary state is restored via restart_model_state!; there is no separate cache to restore.

ClimaCoupler.Checkpointer.restore! (Function)
restore!(v1, v2, comms_ctx; name = "", ignore = Set())

Recursively traverse v1 and v2, setting each field of v1 to the corresponding field of v2 and skipping any property whose name is in the ignore iterable.

This is intended to be used when restarting a simulation's cache object from a checkpoint.

ignore is useful when there are stateful properties, such as live pointers.

restore!(
    v1::Union{
        AbstractTimeVaryingInput,
        ClimaComms.AbstractCommsContext,
        ClimaComms.AbstractDevice,
        UnionAll,
        DataType,
    },
    v2::Union{
        AbstractTimeVaryingInput,
        ClimaComms.AbstractCommsContext,
        ClimaComms.AbstractDevice,
        UnionAll,
        DataType,
    },
    _comms_ctx;
    name = "",
    ignore = Set(),
)

Ignore certain types that don't need to be restored. UnionAll and DataType are infinitely recursive, so we also ignore those.

restore!(
    v1::Union{CC.DataLayouts.AbstractData, AbstractArray},
    v2::Union{CC.DataLayouts.AbstractData, AbstractArray},
    comms_ctx;
    name = "",
    ignore = Set(),
)

For array-like objects, we move the original data (v2) to the device of the new data (v1). Then we copy the original data to the new object.

restore!(v1::LinearIndices, v2::AbstractArray, comms_ctx; name = "", ignore = Set())

Special case to compare LinearIndices to AbstractArray, which is needed for ClimaAtmos v0.32.

restore!(
    v1::Union{StaticArrays.StaticArray, Number, UnitRange, LinRange, Symbol},
    v2::Union{StaticArrays.StaticArray, Number, UnitRange, LinRange, Symbol},
    comms_ctx;
    name = "",
    ignore = Set(),
)

Ensure that immutable objects have been initialized correctly, as they cannot be restored from a checkpoint.

restore!(v1::Dict, v2::Dict, comms_ctx; name = "", ignore = Set())

RRTMGP has some internal dictionaries, which we check for consistency.

restore!(
    v1::T1,
    v2::T2,
    comms_ctx;
    name = "",
    ignore = Set(),
) where {
    T1 <: Union{Dates.DateTime, Dates.UTInstant, Dates.Millisecond},
    T2 <: Union{Dates.DateTime, Dates.UTInstant, Dates.Millisecond},
}

Special case to compare time-related types to allow different timestamps during restore.
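The restore! methods above form a dispatch table: arrays are copied in place, immutable leaves are only checked, time-like values and types are skipped, and everything else is traversed recursively. The following simplified, self-contained sketch reproduces that structure. It is not ClimaCoupler's implementation: it drops the comms_ctx argument, does no device transfer, and handles only a few leaf types.

```julia
import Dates

# Simplified stand-in for the restore! methods: copy checkpointed data (v2)
# into a freshly initialized object (v1).

# Arrays: copy in place (the real method first moves v2 to v1's device)
function restore!(v1::AbstractArray, v2::AbstractArray; name = "", ignore = Set{Symbol}())
    copyto!(v1, v2)
    return nothing
end

# Immutable leaves cannot be overwritten in place, so only check consistency
function restore!(v1::Union{Number,Symbol}, v2::Union{Number,Symbol}; name = "", ignore = Set{Symbol}())
    v1 == v2 || error("immutable value $name differs: $v1 != $v2")
    return nothing
end

# Timestamps may legitimately differ between runs, so accept any pair
restore!(::Dates.DateTime, ::Dates.DateTime; name = "", ignore = Set{Symbol}()) = nothing

# Types would make the generic traversal recurse forever, so skip them
restore!(::Type, ::Type; name = "", ignore = Set{Symbol}()) = nothing

# Everything else (structs, NamedTuples): recurse into non-ignored properties
function restore!(v1, v2; name = "", ignore = Set{Symbol}())
    for p in propertynames(v1)
        p in ignore && continue
        restore!(getproperty(v1, p), getproperty(v2, p); name = "$name.$p", ignore)
    end
    return nothing
end
```

Passing ignore = Set([:ptr]) leaves a field named ptr untouched, which mirrors how the real ignore argument protects stateful properties such as live pointers.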
