Distributed Calibration Tutorial Using Julia Workers

This example will teach you how to use ClimaCalibrate to parallelize your calibration with workers. Workers are additional processes spun up to run code in a distributed fashion. In this tutorial, we will run ensemble members' forward models on different workers.

The example calibration runs CliMA's atmosphere model, ClimaAtmos.jl, in a column spatial configuration for 30 days to simulate outgoing radiative fluxes. The 30-day average of the top-of-atmosphere outgoing shortwave flux (rsut) is used in the observation map to calibrate the astronomical unit.

First, we load in some necessary packages.

using Distributed
import ClimaCalibrate as CAL
import ClimaAnalysis: SimDir, get, slice, average_xy
using ClimaUtilities.ClimaArtifacts
import EnsembleKalmanProcesses as EKP
import EnsembleKalmanProcesses: I, ParameterDistributions.constrained_gaussian

Next, we add workers. Workers are usually added either with Distributed.addprocs or by starting Julia with multiple processes: julia -p <nprocs>.

addprocs itself initializes the workers and registers them with the main Julia process, but there are multiple ways to call it. The simplest is addprocs(nprocs), which creates new local processes on your machine. Another is to pass a SlurmManager, which acquires Slurm resources and starts workers on them. You can use keyword arguments to specify the Slurm resources:

addprocs(CAL.SlurmManager(nprocs), gpus_per_task = 1, time = "01:00:00")

For this example, we would add one worker if it were compatible with Documenter.jl:

addprocs(1)

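We can see the number of workers and their ID numbers. Because no workers were added when this page was built, the main process (ID 1) is reported as the only worker: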

nworkers()
1
workers()
1-element Vector{Int64}:
 1
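Outside of a Documenter build, adding workers changes what these functions report. The following is a minimal sketch using only the Distributed standard library on a single machine; the numbers in the comments assume a fresh session started without -p.

```julia
using Distributed

# Start two local worker processes; addprocs returns their IDs
new_ids = addprocs(2)

# Once workers exist, the main process (ID 1) is no longer counted as a worker
n_with_workers = nworkers()      # 2
ids_with_workers = workers()     # the same IDs as new_ids, e.g. [2, 3]

# Remove the workers again; the main process is then reported as the only worker
rmprocs(new_ids)
n_after_removal = nworkers()     # 1
ids_after_removal = workers()    # [1]
```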

We can call functions on a worker using remotecall, which returns a Future, or remotecall_fetch, which also waits for and returns the result. We pass in the function, the worker ID, and then the function arguments.

remotecall_fetch(*, 1, 4, 4)
16
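remotecall_fetch is a convenience for the two-step pattern of remotecall followed by fetch. A minimal local sketch, again using only the Distributed standard library with one freshly added worker, shows both forms and uses myid to confirm where the call ran:

```julia
using Distributed

w = only(addprocs(1))              # ID of one new local worker

fut = remotecall(*, w, 4, 4)       # returns a Future immediately
product = fetch(fut)               # waits for the worker's result: 16

# remotecall_fetch does both steps at once; myid() returns the ID of the
# process that evaluates it, so this reports the worker's ID
ran_on = remotecall_fetch(myid, w)

rmprocs(w)
```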

ClimaCalibrate uses this functionality to run the forward model on workers.

Since each worker starts its own Julia session, packages must be imported and variables declared on every worker. Distributed.@everywhere executes code on all processes, the main process included, allowing us to load the code that the workers need.
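Definitions made with @everywhere exist separately on each process: a function defined this way can be called on any worker, and each worker holds its own copy of a variable. A minimal sketch with two local workers; scale_by_three and run_label are hypothetical names used only for illustration:

```julia
using Distributed

ids = addprocs(2)

# Define a function and a variable on every process (main and both workers)
@everywhere scale_by_three(x) = 3x
@everywhere run_label = "calibration"

# Each worker runs its own copy of the function
results = [remotecall_fetch(scale_by_three, w, 5) for w in ids]   # [15, 15]

# Each worker reads its own copy of the variable
labels = [remotecall_fetch(() -> run_label, w) for w in ids]

rmprocs(ids)
```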

@everywhere begin
    output_dir = joinpath("output", "climaatmos_calibration")
    import ClimaCalibrate as CAL
    import ClimaAtmos as CA
    import ClimaComms
end
output_dir = joinpath("output", "climaatmos_calibration")
mkpath(output_dir)
"output/climaatmos_calibration"

Next, we define RadiativeFluxModelInterface, which subtypes ClimaCalibrate.AbstractModelInterface. RadiativeFluxModelInterface defines how to run the forward model and the observation map.

The forward model takes in the sampled parameters, runs the simulation, and saves diagnostic output that can be processed and compared to observations. It is defined by ClimaCalibrate.forward_model(interface, iteration, member), which is run in parallel by the WorkerBackend and the HPCBackends.

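Since forward_model(interface, iteration, member) only receives the iteration and member numbers, we use these as hooks to set the model parameters and the output directory. Two useful functions:

  • ClimaCalibrate.path_to_ensemble_member(output_dir, iteration, member): returns the output directory of the given ensemble member
  • ClimaCalibrate.parameter_path(output_dir, iteration, member): returns the path to the parameter file of the given ensemble member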
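To illustrate how the iteration and member numbers act as hooks, here is a sketch of two hypothetical helpers, member_dir and member_parameter_file, that build per-member paths. They are not ClimaCalibrate functions: the iteration_XXX/member_YYY layout is taken from the output folder printed later in this tutorial, while the file name parameters.toml is an assumption.

```julia
# Hypothetical helper (not part of ClimaCalibrate): zero-padded per-member directory,
# mirroring output/climaatmos_calibration/iteration_000/member_000 seen in this tutorial
member_dir(output_dir, iteration, member) = joinpath(
    output_dir,
    "iteration_" * lpad(iteration, 3, '0'),
    "member_" * lpad(member, 3, '0'),
)

# Hypothetical helper; the file name "parameters.toml" is an assumption
member_parameter_file(output_dir, iteration, member) =
    joinpath(member_dir(output_dir, iteration, member), "parameters.toml")

out = joinpath("output", "climaatmos_calibration")
dir_2_15 = member_dir(out, 2, 15)               # .../iteration_002/member_015
file_0_0 = member_parameter_file(out, 0, 0)     # .../iteration_000/member_000/parameters.toml
```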

The forward model below runs ClimaAtmos.jl in a minimal column spatial configuration.

Everywhere macro

Due to limitations in Documenter.jl, we insert @eval $(@__MODULE__) after @everywhere in the calls below. For your own calibration script, you do not need to do this.

@everywhere @eval $(@__MODULE__) struct RadiativeFluxModelInterface <:
                                        CAL.AbstractModelInterface end

@everywhere @eval $(@__MODULE__) function CAL.forward_model(
    ::RadiativeFluxModelInterface,
    iteration,
    member,
)
    config_dict = Dict(
        "dt" => "2000secs",
        "t_end" => "30days",
        "config" => "column",
        "h_elem" => 1,
        "insolation" => "timevarying",
        "output_dir" => output_dir,
        "output_default_diagnostics" => false,
        "dt_rad" => "6hours",
        "rad" => "clearsky",
        "co2_model" => "fixed",
        "log_progress" => false,
        "diagnostics" => [
            Dict(
                "reduction_time" => "average",
                "short_name" => "rsut",
                "period" => "30days",
                "writer" => "nc",
            ),
        ],
    )
    # Set the output path for the current member
    member_path = CAL.path_to_ensemble_member(output_dir, iteration, member)
    config_dict["output_dir"] = member_path

    # Set the parameters for the current member
    parameter_path = CAL.parameter_path(output_dir, iteration, member)
    if haskey(config_dict, "toml")
        push!(config_dict["toml"], parameter_path)
    else
        config_dict["toml"] = [parameter_path]
    end

    comms_ctx = ClimaComms.SingletonCommsContext()
    atmos_config = CA.AtmosConfig(config_dict; comms_ctx)
    simulation = CA.get_simulation(atmos_config)
    CA.solve_atmos!(simulation)
    return simulation
end

Next, the observation map processes a full ensemble of model output for the ensemble update step. It takes in only the interface and the iteration number, and always returns an array. For observation map output G_ensemble, G_ensemble[:, m] must be the output of ensemble member m; this is needed for compatibility with EnsembleKalmanProcesses.jl. Members without simulation output are filled with NaN.

const days = 86_400
function CAL.observation_map(::RadiativeFluxModelInterface, iteration)
    single_member_dims = (1,)
    G_ensemble = Array{Float64}(undef, single_member_dims..., ensemble_size)

    for m in 1:ensemble_size
        member_path = CAL.path_to_ensemble_member(output_dir, iteration, m)
        simdir_path = joinpath(member_path, "output_active")
        if isdir(simdir_path)
            simdir = SimDir(simdir_path)
            G_ensemble[:, m] .= process_member_data(simdir)
        else
            G_ensemble[:, m] .= NaN
        end
    end
    return G_ensemble
end
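The required layout of the observation map output can be checked without running any simulations. In this sketch, toy_member_output is a hypothetical stand-in for reading and processing one member's output; member 3 plays the role of a member with no output directory, so its column is filled with NaN, as in observation_map above.

```julia
# Hypothetical per-member result: member m "produces" 10m, and member 3 has no output
toy_member_output(m) = m == 3 ? nothing : [10.0 * m]

function toy_observation_map(ensemble_size; output_dim = 1)
    # One row per observed quantity, one column per ensemble member
    G_ensemble = Array{Float64}(undef, output_dim, ensemble_size)
    for m in 1:ensemble_size
        out = toy_member_output(m)
        # Column m must hold the output of ensemble member m
        G_ensemble[:, m] .= (isnothing(out) ? NaN : out)
    end
    return G_ensemble
end

G = toy_observation_map(5)   # 1×5 Matrix{Float64}
```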

Separating out the individual ensemble member output processing often results in more readable code.

function process_member_data(simdir::SimDir)
    isempty(simdir.vars) && return NaN
    rsut =
        get(simdir; short_name = "rsut", reduction = "average", period = "30d")
    return slice(average_xy(rsut); time = 30days).data
end
process_member_data (generic function with 1 method)
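process_member_data selects the output at time = 30days. Julia parses a numeric literal placed directly before a variable as multiplication, so 30days means 30 * days, i.e. 30 days expressed in seconds. A quick check confirms that this matches the 2.592e6-second stop time printed for the simulation below:

```julia
const days = 86_400          # seconds per day, as defined in the tutorial

t_select = 30days            # numeric literal coefficient: 30 * days
t_select_float = float(t_select)
```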

Now, we can set up the remaining experiment details:

  • ensemble size and number of iterations
  • the observational noise covariance
  • the prior distribution
  • the observational data

ensemble_size = 30
n_iterations = 7
noise = 0.1 * I
prior = constrained_gaussian("astronomical_unit", 6e10, 1e11, 2e5, Inf)
ParameterDistribution with 1 entries: 
'astronomical_unit' with EnsembleKalmanProcesses.ParameterDistributions.Constraint{EnsembleKalmanProcesses.ParameterDistributions.BoundedBelow}[Bounds: (200000.0, ∞)] over distribution EnsembleKalmanProcesses.ParameterDistributions.Parameterized(Distributions.Normal{Float64}(μ=24.153036641203013, σ=1.1528837102037748)) 
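The printed prior lives in an unconstrained space. The sketch below shows how its mean and standard deviation can arise, assuming the bounded-below transform x = lb + exp(u) (so x - lb is lognormal when u is Gaussian) and matching the constrained mean and standard deviation analytically; the helper name is hypothetical. With the tutorial's inputs it reproduces the printed μ ≈ 24.153 and σ ≈ 1.153, which suggests, but does not confirm, that constrained_gaussian uses this moment matching for a one-sided bound.

```julia
# Hypothetical helper: unconstrained Normal(μ, σ) parameters whose image under
# x = lb + exp(u) has the requested constrained mean and standard deviation
function lower_bounded_gaussian_params(mean_c, std_c, lb)
    m = mean_c - lb                    # mean of the lognormal part x - lb
    σ² = log(1 + (std_c / m)^2)        # lognormal variance relation
    μ = log(m) - σ² / 2                # lognormal mean relation
    return μ, sqrt(σ²)
end

lb = 2e5
μ, σ = lower_bounded_gaussian_params(6e10, 1e11, lb)

# Map the moments back to constrained space as a consistency check
implied_mean = lb + exp(μ + σ^2 / 2)
implied_std = exp(μ + σ^2 / 2) * sqrt(exp(σ^2) - 1)
```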

In a perfect-model setting, we generate the observations from the forward model itself. This is most easily done by creating an empty parameter file, so that no parameters are overridden, and running the 0th ensemble member:

@info "Generating observations"
parameter_file = CAL.parameter_path(output_dir, 0, 0)
mkpath(dirname(parameter_file))
touch(parameter_file)
simulation = CAL.forward_model(RadiativeFluxModelInterface(), 0, 0)
Simulation 
├── Running on: CPUSingleThreaded
├── Output folder: output/climaatmos_calibration/iteration_000/member_000/output_0000
├── Start date: 2010-01-01T00:00:00
├── Current time: 2.592e6 seconds
└── Stop time: 2.592e6 seconds
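The touched parameter file is an empty TOML file. The sketch below uses the TOML standard library to show that such a file parses to an empty dictionary, while a member's file would carry parameter entries. The [name] table with value and type keys is an assumed ClimaParams-style layout, and the value 1.5e11 is purely illustrative.

```julia
import TOML

dir = mktempdir()

# An empty parameter file, like the one created with touch above
empty_file = joinpath(dir, "empty.toml")
touch(empty_file)
empty_overrides = TOML.parsefile(empty_file)       # no entries at all

# A hypothetical member file overriding one parameter (assumed layout and value)
member_file = joinpath(dir, "member.toml")
open(member_file, "w") do io
    TOML.print(io, Dict("astronomical_unit" => Dict("value" => 1.5e11, "type" => "float")))
end
member_overrides = TOML.parsefile(member_file)
```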

Lastly, we process the simulation output with process_member_data, the same function the observation map applies to each ensemble member, to generate the observations.

observations = Vector{Float64}(undef, 1)
observations .= process_member_data(SimDir(simulation.output_dir))
1-element Vector{Float64}:
 126.61408233642578
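These observations are paired with the noise covariance noise = 0.1 * I defined earlier. I is a size-agnostic UniformScaling; for the single observation here it acts as the 1×1 diagonal matrix [0.1], which, together with its inverse [10.0], appears in the EnsembleKalmanProcess printed at the end of this page. The sketch below materializes it with LinearAlgebra and computes a noise-weighted squared misfit for a hypothetical model output g (illustrative only, not necessarily the loss EKP reports):

```julia
using LinearAlgebra

noise = 0.1 * I                        # UniformScaling, as in the tutorial
n_obs = 1                              # one observation: the averaged rsut
Γ = Diagonal(fill(noise.λ, n_obs))     # 1×1 covariance matrix [0.1]
Γinv = inv(Γ)                          # [10.0]

# Noise-weighted squared misfit for a hypothetical model output g
y = [126.61408233642578]               # observation computed above
g = [127.0]                            # hypothetical ensemble-member output
weighted_misfit = dot(y - g, Γinv * (y - g))
```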

Now we are ready to run our calibration, putting it all together using the calibrate function. The WorkerBackend will automatically use all workers available to the main Julia process. Other backends are available for forward models that can't use workers or need to be parallelized internally. The simplest backend is the JuliaBackend, which runs all ensemble members sequentially and does not require Distributed.jl. For more information, see the Backends page.
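Conceptually, the difference between these two backends resembles pmap versus map. The sketch below is not ClimaCalibrate's implementation: toy_forward_model is a hypothetical stand-in that returns the ID of the process that ran it, so you can see members being spread over workers in the parallel case and all running on the main process in the sequential case.

```julia
using Distributed

ids = addprocs(2)

# Hypothetical stand-in for a forward model; every process needs the definition
@everywhere toy_forward_model(iteration, member) = (myid(), 100 * iteration + member)

# The closures capture `iteration` as a local variable, so its value is sent to workers
run_parallel(iteration, n_members) =
    pmap(m -> toy_forward_model(iteration, m), 1:n_members)   # spread over workers
run_sequential(iteration, n_members) =
    map(m -> toy_forward_model(iteration, m), 1:n_members)    # main process only

parallel = run_parallel(1, 6)
sequential = run_sequential(1, 6)

rmprocs(ids)
```

pmap returns results in member order, so the output can be assembled column by column regardless of which worker finished first.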

user_initial_ensemble = EKP.construct_initial_ensemble(prior, ensemble_size)
ekp = EKP.EnsembleKalmanProcess(
    user_initial_ensemble,
    observations,
    noise,
    EKP.Inversion(),
    EKP.default_options_dict(EKP.Inversion()),
)
eki = CAL.calibrate(
    CAL.WorkerBackend(),
    ekp,
    RadiativeFluxModelInterface(),
    n_iterations,
    prior,
    output_dir,
)
EnsembleKalmanProcesses.EnsembleKalmanProcess{Float64, Int64, EnsembleKalmanProcesses.Inversion{Float64, Nothing, Nothing}, EnsembleKalmanProcesses.DataMisfitController{Float64, String}, EnsembleKalmanProcesses.NesterovAccelerator{Float64}, Vector{EnsembleKalmanProcesses.UpdateGroup}, Nothing}(EnsembleKalmanProcesses.DataContainers.DataContainer{Float64}[EnsembleKalmanProcesses.DataContainers.DataContainer{Float64}([23.797365177895426 24.727740911580142 … 25.960814653306233 24.739024928136764]), EnsembleKalmanProcesses.DataContainers.DataContainer{Float64}([23.995802165295157 24.903157290130647 … 25.842863432893274 24.913819113195682]), EnsembleKalmanProcesses.DataContainers.DataContainer{Float64}([24.706736905596294 25.496760325474405 … 25.65949277228212 25.50439989692609]), EnsembleKalmanProcesses.DataContainers.DataContainer{Float64}([25.34256296399913 25.79736036585128 … 25.729074937323592 25.79843588479833]), EnsembleKalmanProcesses.DataContainers.DataContainer{Float64}([25.855509028742205 25.765783281178535 … 25.736817900532234 25.763927252851385]), EnsembleKalmanProcesses.DataContainers.DataContainer{Float64}([25.861058421181195 25.748087416466976 … 25.734258632314404 25.74728396630933]), EnsembleKalmanProcesses.DataContainers.DataContainer{Float64}([25.832800063091653 25.741345438374577 … 25.73312657659528 25.7408933620748]), EnsembleKalmanProcesses.DataContainers.DataContainer{Float64}([25.789624798639426 25.735518032287512 … 25.732078546300144 25.73534453300729])], EnsembleKalmanProcesses.ObservationSeries{Vector{EnsembleKalmanProcesses.Observation{Vector{Vector{Float64}}, Vector{LinearAlgebra.Diagonal{Float64, Vector{Float64}}}, Vector{LinearAlgebra.Diagonal{Float64, Vector{Float64}}}, Vector{String}, Vector{UnitRange{Int64}}, Nothing}}, EnsembleKalmanProcesses.FixedMinibatcher{Vector{Vector{Int64}}, String, Random.TaskLocalRNG}, Vector{String}, Vector{Vector{Vector{Int64}}}, Nothing}(EnsembleKalmanProcesses.Observation{Vector{Vector{Float64}}, 
Vector{LinearAlgebra.Diagonal{Float64, Vector{Float64}}}, Vector{LinearAlgebra.Diagonal{Float64, Vector{Float64}}}, Vector{String}, Vector{UnitRange{Int64}}, Nothing}[EnsembleKalmanProcesses.Observation{Vector{Vector{Float64}}, Vector{LinearAlgebra.Diagonal{Float64, Vector{Float64}}}, Vector{LinearAlgebra.Diagonal{Float64, Vector{Float64}}}, Vector{String}, Vector{UnitRange{Int64}}, Nothing}([[126.61408233642578]], LinearAlgebra.Diagonal{Float64, Vector{Float64}}[[0.1;;]], LinearAlgebra.Diagonal{Float64, Vector{Float64}}[[10.0;;]], ["observation"], UnitRange{Int64}[1:1], nothing)], EnsembleKalmanProcesses.FixedMinibatcher{Vector{Vector{Int64}}, String, Random.TaskLocalRNG}([[1]], "order", Random.TaskLocalRNG()), ["series_1"], Dict("minibatch" => 1, "epoch" => 8), [[[1]], [[1]], [[1]], [[1]], [[1]], [[1]], [[1]], [[1]]], nothing), 30, EnsembleKalmanProcesses.DataContainers.DataContainer{Float64}[EnsembleKalmanProcesses.DataContainers.DataContainer{Float64}([2.6495842933654785 17.030664443969727 … 200.2987518310547 17.41935157775879]), EnsembleKalmanProcesses.DataContainers.DataContainer{Float64}([3.940361499786377 24.18626594543457 … 158.25518798828125 24.707763671875]), EnsembleKalmanProcesses.DataContainers.DataContainer{Float64}([16.329944610595703 79.24510955810547 … 109.70674133300781 80.463623046875]), EnsembleKalmanProcesses.DataContainers.DataContainer{Float64}([58.22529220581055 144.5026397705078 … 126.07538604736328 144.8164825439453]), EnsembleKalmanProcesses.DataContainers.DataContainer{Float64}([162.30064392089844 135.66397094726562 … 128.03598022460938 135.16319274902344]), EnsembleKalmanProcesses.DataContainers.DataContainer{Float64}([164.1119842529297 130.95724487304688 … 127.38359832763672 130.74484252929688]), EnsembleKalmanProcesses.DataContainers.DataContainer{Float64}([155.105224609375 129.2040557861328 … 127.09961700439453 129.08709716796875])], Dict("unweighted_loss" => [1643.6273939150324, 7791.767348703951, 3655.675063141748, 
363.05386979361293, 176.33435425800428, 142.39554046986134, 138.11355477807467], "crps" => [50.856196236571066, 62.27282867972197, 38.79273052190317, 12.804736296133505, 9.20512832950194, 7.2852225859367925, 7.049511524914529], "bayes_loss" => [16436.27850888079, 77918.03404101454, 36557.70255756036, 3631.8999296023467, 1764.7333620622157, 1425.3470694490793, 1382.4956041615346], "unweighted_avg_rmse" => [145.88125658680994, 94.12659982244173, 63.613217906157175, 28.67952135403951, 22.496268018086752, 14.611963653564453, 11.78886464436849], "avg_rmse" => [461.3170387417602, 297.6544438461165, 201.16265787606395, 90.69260968220114, 71.13944579075613, 46.207086232859204, 37.27966330363532], "loss" => [16436.273939150324, 77917.67348703952, 36556.75063141748, 3630.5386979361297, 1763.3435425800428, 1423.9554046986134, 1381.1355477807467]), EnsembleKalmanProcesses.DataMisfitController{Float64, String}([7], 1.0, "stop"), EnsembleKalmanProcesses.NesterovAccelerator{Float64}([25.795735713449748 25.737976121673142 … 25.732494940751252 25.73767619775377], 0.20434762801820305), [2.3904068437735863e-6, 2.6124122256324004e-5, 2.7394817769879495e-5, 5.379231743887082e-5, 9.948354837241961e-5, 0.00021258547284876074, 0.00025354939486462434], EnsembleKalmanProcesses.UpdateGroup[EnsembleKalmanProcesses.UpdateGroup([1], [1], Dict("[1,...,1]" => "[1,...,1]"))], EnsembleKalmanProcesses.Inversion{Float64, Nothing, Nothing}(nothing, nothing, false, 0.0), Random._GLOBAL_RNG(), EnsembleKalmanProcesses.FailureHandler{EnsembleKalmanProcesses.Inversion, EnsembleKalmanProcesses.SampleSuccGauss}(EnsembleKalmanProcesses.var"#failsafe_update#174"()), EnsembleKalmanProcesses.Localizers.Localizer{EnsembleKalmanProcesses.Localizers.SECNice, Float64}(EnsembleKalmanProcesses.Localizers.var"#13#14"{EnsembleKalmanProcesses.Localizers.SECNice{Float64}}(EnsembleKalmanProcesses.Localizers.SECNice{Float64}(1000, 1.0, 1.0))), 0.1, nothing, false)
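Under the hood, each iteration updates the ensemble with an ensemble Kalman step. The sketch below is a textbook ensemble Kalman inversion update with perturbed observations for a hypothetical linear toy problem (G(u) = 2u, true u = 3). It is not EnsembleKalmanProcesses.jl's implementation: the printed object above shows that EKP also uses a DataMisfitController scheduler and a NesterovAccelerator, which this sketch omits.

```julia
using LinearAlgebra, Random, Statistics

# Hypothetical linear forward model: one parameter, one observation
toy_G(u) = 2.0 .* u

# One ensemble Kalman inversion step; U holds one member per column
function eki_step(U, y, Γ, rng)
    J = size(U, 2)
    GU = reduce(hcat, [toy_G(U[:, j]) for j in 1:J])     # outputs, one column per member
    dU = U .- mean(U; dims = 2)                          # parameter anomalies
    dG = GU .- mean(GU; dims = 2)                        # output anomalies
    C_uG = dU * dG' / (J - 1)                            # parameter-output covariance
    C_GG = dG * dG' / (J - 1)                            # output covariance
    K = C_uG / (C_GG + Γ)                                # Kalman gain
    Y = y .+ cholesky(Symmetric(Γ)).L * randn(rng, length(y), J)  # perturbed observations
    return U + K * (Y - GU)
end

function run_eki(; n_members = 50, n_iterations = 10, seed = 42)
    rng = MersenneTwister(seed)
    y = [6.0]                          # synthetic observation, G(3) = 6
    Γ = fill(0.01, 1, 1)               # observational noise covariance
    U = randn(rng, 1, n_members)       # initial ensemble drawn from N(0, 1)
    for _ in 1:n_iterations
        U = eki_step(U, y, Γ, rng)
    end
    return U
end

U_final = run_eki()
```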

This page was generated using Literate.jl.