Perfect convective adjustment calibration with Ensemble Kalman Inversion

This example calibrates a convective adjustment model in the "perfect model context". In this context, synthetic observations are generated by a convective adjustment model with "true" parameters. The true parameters are then "rediscovered" by calibrating the model to match the synthetic observations.

We use the discrepancy between observed and modeled buoyancy $b$ to calibrate the convective adjustment model. The calibration problem is solved by Ensemble Kalman Inversion. For more information about Ensemble Kalman Inversion, see the EnsembleKalmanProcesses.jl documentation.
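To give a feel for what a single EKI iteration does, here is a minimal sketch of one common deterministic form of the update, applied to a toy linear problem. It is not the implementation used by EnsembleKalmanProcesses.jl or ParameterEstimocean: the names `eki_update`, `run_eki`, and `toy_model`, the matrix `A`, and the absence of observation perturbations are all assumptions made for illustration.

```julia
using LinearAlgebra, Statistics, Random

# One deterministic EKI update (no observation perturbations).
# θ: Nparameters × Nensemble parameters, G: Noutput × Nensemble model outputs,
# y: observations, Γ: observation noise covariance, Δt: pseudo time step.
function eki_update(θ, G, y, Γ; Δt = 1.0)
    Nensemble = size(θ, 2)
    δθ = θ .- mean(θ, dims = 2)           # parameter deviations from the ensemble mean
    δG = G .- mean(G, dims = 2)           # output deviations from the ensemble mean
    Cθg = δθ * δG' / (Nensemble - 1)      # parameter-output cross-covariance
    Cgg = δG * δG' / (Nensemble - 1)      # output covariance
    return θ .+ Cθg * ((Cgg + Γ / Δt) \ (y .- G))
end

# Repeatedly evaluate the model on the ensemble and update the parameters.
function run_eki(model, θ, y, Γ; iterations = 10)
    for _ in 1:iterations
        θ = eki_update(θ, model(θ), y, Γ)
    end
    return θ
end

# Hypothetical linear "model", used only for illustration
A = [1.0 2.0; 0.5 -1.0; 3.0 0.0]
toy_model(θ) = A * θ

Random.seed!(123)
θ_true = [1.0, -0.5]
y = toy_model(θ_true)                     # noise-free synthetic observations
Γ = 1e-4 * Matrix(I, 3, 3)                # small assumed observation noise

θ_initial = randn(2, 50)                  # 50 members drawn from a unit normal
θ_final = run_eki(toy_model, θ_initial, y, Γ)

initial_error = norm(vec(mean(θ_initial, dims = 2)) - θ_true)
final_error = norm(vec(mean(θ_final, dims = 2)) - θ_true)
```

In this toy setting the ensemble mean moves toward `θ_true` and `final_error` ends up much smaller than `initial_error`. The calibration below follows the same perfect-model idea, with the convective adjustment simulation playing the role of `toy_model`.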

Install dependencies

First let's make sure we have all required packages installed.

using Pkg
pkg"add ParameterEstimocean, Oceananigans, Distributions, CairoMakie"
using ParameterEstimocean, LinearAlgebra, CairoMakie

We reuse some code from a previous example to generate observations,

examples_path = joinpath(pathof(ParameterEstimocean), "..", "..", "examples")
include(joinpath(examples_path, "intro_to_inverse_problems.jl"))

data_path = generate_synthetic_observations()
observations = SyntheticObservations(data_path, field_names=:b, transformation=ZScore())
SyntheticObservations with fields (:b,)
├── times: [0 s, 4 hrs, 8 hrs, 12 hrs]
├── grid: 1×1×32 RectilinearGrid{Float64, Oceananigans.Grids.Flat, Oceananigans.Grids.Flat, Oceananigans.Grids.Bounded} on Oceananigans.Architectures.CPU with 0×0×3 halo
├── path: "convective_adjustment.jld2"
├── metadata: (:parameters, :grid, :coriolis, :closure)
└── transformation: Dict{Symbol, ParameterEstimocean.Transformations.Transformation{TimeIndices{UnitRange{Int64}}, Nothing, ZScore{Float64}}} with 1 entry
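The `transformation=ZScore()` keyword indicates that the buoyancy data are normalized before being compared with model output. As a rough illustration, assuming the textbook definition of a z-score (the exact way ParameterEstimocean computes the mean and standard deviation is not shown on this page), the normalization can be sketched as follows. The helper `zscore` and the profile `b` are hypothetical.

```julia
using Statistics

# Textbook z-score: subtract the mean, then divide by the standard deviation.
zscore(data; μ = mean(data), σ = std(data)) = (data .- μ) ./ σ

# Hypothetical buoyancy values [m s⁻²], for illustration only
b = [1.0e-4, 2.0e-4, 4.0e-4, 8.0e-4]
b_normalized = zscore(b)
```

After normalization the data have zero mean and unit standard deviation, so the size of the model-data discrepancy does not depend on the physical units of the observed field.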

and an ensemble simulation,

ensemble_simulation, closure★ = build_ensemble_simulation(observations; Nensemble=50)
(Simulation of HydrostaticFreeSurfaceModel{CPU, RectilinearGrid}(time = 0 seconds, iteration = 0)
├── Next time step: 10 seconds
├── Elapsed wall time: 0 seconds
├── Wall time per iteration: NaN years
├── Stop time: 12 hours
├── Stop iteration : Inf
├── Wall time limit: Inf
├── Callbacks: OrderedDict with 4 entries:
│   ├── stop_time_exceeded => Callback of stop_time_exceeded on IterationInterval(1)
│   ├── stop_iteration_exceeded => Callback of stop_iteration_exceeded on IterationInterval(1)
│   ├── wall_time_limit_exceeded => Callback of wall_time_limit_exceeded on IterationInterval(1)
│   └── nan_checker => Callback of NaNChecker for u on IterationInterval(100)
├── Output writers: OrderedDict with no entries
└── Diagnostics: OrderedDict with no entries, ConvectiveAdjustmentVerticalDiffusivity{Oceananigans.TurbulenceClosures.VerticallyImplicitTimeDiscretization}(background_κz=0.0001 convective_κz=1.0 background_νz=1.0e-5 convective_νz=0.9))

The utility function build_ensemble_simulation also returns closure★, the closure that holds the "true" parameters used to generate the synthetic observations:

@show θ★ = (convective_κz = closure★.convective_κz, background_κz = closure★.background_κz)
(convective_κz = 1.0, background_κz = 0.0001)

The InverseProblem

To build an inverse problem we first define free parameters. Here we calibrate convective_κz and background_κz, using log-normal priors to prevent the parameters from becoming negative:

priors = (convective_κz = lognormal(mean=0.3, std=0.5),
          background_κz = lognormal(mean=2.5e-4, std=2.5e-5))

free_parameters = FreeParameters(priors)
FreeParameters with 2 parameters
├── names: (:convective_κz, :background_κz)
├── priors: Dict{Symbol, Any}
│   ├── convective_κz => LogNormal{Float64}(μ=-1.8685407779659071, σ=1.152881584240091)
│   └── background_κz => LogNormal{Float64}(μ=-8.299024805528612, σ=0.0997513451195927)
└── dependent parameters: Dict{Symbol, Any}
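Note that the μ and σ printed above are the parameters of the underlying normal distribution, not the mean and standard deviation passed to `lognormal`. The two are related by the standard moment-matching formulas for a log-normal variable with mean $m$ and standard deviation $s$: $σ^2 = \log(1 + s^2 / m^2)$ and $μ = \log(m) - σ^2 / 2$. The sketch below reproduces the printed values; the helper name `lognormal_parameters` is hypothetical and is not part of ParameterEstimocean.

```julia
# Convert the mean m and standard deviation s of a log-normal variable
# into the parameters (μ, σ) of the underlying normal distribution.
function lognormal_parameters(m, s)
    σ² = log(1 + (s / m)^2)
    μ = log(m) - σ² / 2
    return (μ = μ, σ = sqrt(σ²))
end

convective = lognormal_parameters(0.3, 0.5)        # prior for convective_κz
background = lognormal_parameters(2.5e-4, 2.5e-5)  # prior for background_κz
```

Because EKI operates on the unconstrained (normally distributed) variable, every ensemble member maps back to a strictly positive diffusivity.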

The InverseProblem is then constructed from observations, ensemble_simulation, and free_parameters,

calibration = InverseProblem(observations, ensemble_simulation, free_parameters)
InverseProblem{ConcatenatedOutputMap} with free parameters (:convective_κz, :background_κz)
├── observations: SyntheticObservations of (:b,) on 1×1×32 RectilinearGrid{Float64, Oceananigans.Grids.Flat, Oceananigans.Grids.Flat, Oceananigans.Grids.Bounded} on Oceananigans.Architectures.CPU with 0×0×3 halo
├── simulation: Simulation on 50×1×32 RectilinearGrid{Float64, Oceananigans.Grids.Flat, Oceananigans.Grids.Flat, Oceananigans.Grids.Bounded} on Oceananigans.Architectures.CPU with 0×0×3 halo with Δt=10.0
├── free_parameters: (:convective_κz, :background_κz)
└── output map: ConcatenatedOutputMap

For more information about the above steps, see Intro to observations and Intro to InverseProblem.

Ensemble Kalman Inversion

The calibration is done here using Ensemble Kalman Inversion; for more information about the algorithm, refer to the EnsembleKalmanProcesses.jl documentation.

Next, we construct an EnsembleKalmanInversion (EKI) object,

eki = EnsembleKalmanInversion(calibration; pseudo_stepping = ConstantConvergence(0.5))
EnsembleKalmanInversion
├── inverse_problem: InverseProblem{ConcatenatedOutputMap} with free parameters (:convective_κz, :background_κz)
├── ensemble_kalman_process: EnsembleKalmanProcesses.Inversion
├── mapped_observations: 96-element Vector{Float64}
├── noise_covariance: 96×96 Matrix{Float64}
├── pseudo_stepping: ConstantConvergence{Float64}(0.5)
├── iteration: 0
├── resampler: Resampler{FullEnsembleDistribution}
├── unconstrained_parameters: 2×50 Matrix{Float64}
├── forward_map_output: 96×50 Matrix{Float64}
└── mark_failed_particles: NormExceedsMedian{Float64}

and perform a few iterations to see if we can converge to the true parameter values.

iterate!(eki; iterations = 10)
(convective_κz = 0.6275159654558178, background_κz = 0.0001929344044871522)
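A quick way to judge this result is to compare it with the true parameters printed earlier. The sketch below copies both sets of values from the outputs on this page and computes the relative error of each parameter; the variable names are chosen for illustration.

```julia
# True parameters and the parameters returned after 10 EKI iterations,
# copied from the outputs shown on this page.
θ_true = (convective_κz = 1.0, background_κz = 0.0001)
θ_eki = (convective_κz = 0.6275159654558178, background_κz = 0.0001929344044871522)

# Relative error of each parameter, returned as a NamedTuple with the same names
relative_error = map((estimate, truth) -> abs(estimate - truth) / truth, θ_eki, θ_true)
```

The relative errors are roughly 0.37 for `convective_κz` and 0.93 for `background_κz`, so after 10 iterations the returned parameters are still some distance from the true values.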

Finally, we visualize the results of the EKI calibration.

θ̅(iteration) = [eki.iteration_summaries[iteration].ensemble_mean...]
varθ(iteration) = eki.iteration_summaries[iteration].ensemble_var

weight_distances = [norm(θ̅(iter) - [θ★[1], θ★[2]]) for iter in 0:eki.iteration]
output_distances = [norm(forward_map(calibration, θ̅(iter))[:, 1] - y) for iter in 0:eki.iteration]
ensemble_variances = [varθ(iter) for iter in 0:eki.iteration]

f = Figure()

lines(f[1, 1], 0:eki.iteration, weight_distances, color = :red, linewidth = 2,
      axis = (title = "Parameter distance",
              xlabel = "Iteration",
              ylabel = "|θ̅ₙ - θ★|"))

lines(f[1, 2], 0:eki.iteration, output_distances, color = :blue, linewidth = 2,
      axis = (title = "Output distance",
              xlabel = "Iteration",
              ylabel = "|G(θ̅ₙ) - y|"))

ax3 = Axis(f[2, 1:2],
           title = "Parameter convergence",
           xlabel = "Iteration",
           ylabel = "Ensemble variance",
           yscale = log10)

for (i, pname) in enumerate(free_parameters.names)
    ev = getindex.(ensemble_variances, i)
    lines!(ax3, 0:eki.iteration, ev / ev[1], label = String(pname), linewidth = 2)
end

axislegend(ax3, valign = :top, halign = :right)
[ Info: Initializing simulation...
[ Info:     ... simulation initialization complete (710.209 μs)
[ Info: Executing initial time step...
[ Info:     ... initial time step complete (2.470 ms).
[ Info: Simulation is stopping. Model time 12 hours has hit or exceeded simulation stop time 12 hours.
[ Info: Initializing simulation...
[ Info:     ... simulation initialization complete (622.709 μs)
[ Info: Executing initial time step...
[ Info:     ... initial time step complete (2.222 ms).
[ Info: Simulation is stopping. Model time 12 hours has hit or exceeded simulation stop time 12 hours.
[ Info: Initializing simulation...
[ Info:     ... simulation initialization complete (565.607 μs)
[ Info: Executing initial time step...
[ Info:     ... initial time step complete (2.224 ms).
[ Info: Simulation is stopping. Model time 12 hours has hit or exceeded simulation stop time 12 hours.
[ Info: Initializing simulation...
[ Info:     ... simulation initialization complete (538.007 μs)
[ Info: Executing initial time step...
[ Info:     ... initial time step complete (2.159 ms).
[ Info: Simulation is stopping. Model time 12 hours has hit or exceeded simulation stop time 12 hours.
[ Info: Initializing simulation...
[ Info:     ... simulation initialization complete (548.907 μs)
[ Info: Executing initial time step...
[ Info:     ... initial time step complete (2.154 ms).
[ Info: Simulation is stopping. Model time 12 hours has hit or exceeded simulation stop time 12 hours.
[ Info: Initializing simulation...
[ Info:     ... simulation initialization complete (547.107 μs)
[ Info: Executing initial time step...
[ Info:     ... initial time step complete (2.193 ms).
[ Info: Simulation is stopping. Model time 12 hours has hit or exceeded simulation stop time 12 hours.
[ Info: Initializing simulation...
[ Info:     ... simulation initialization complete (564.807 μs)
[ Info: Executing initial time step...
[ Info:     ... initial time step complete (2.226 ms).
[ Info: Simulation is stopping. Model time 12 hours has hit or exceeded simulation stop time 12 hours.
[ Info: Initializing simulation...
[ Info:     ... simulation initialization complete (399.105 μs)
[ Info: Executing initial time step...
[ Info:     ... initial time step complete (1.674 ms).
[ Info: Simulation is stopping. Model time 12 hours has hit or exceeded simulation stop time 12 hours.
[ Info: Initializing simulation...
[ Info:     ... simulation initialization complete (547.307 μs)
[ Info: Executing initial time step...
[ Info:     ... initial time step complete (2.116 ms).
[ Info: Simulation is stopping. Model time 12 hours has hit or exceeded simulation stop time 12 hours.
[ Info: Initializing simulation...
[ Info:     ... simulation initialization complete (562.407 μs)
[ Info: Executing initial time step...
[ Info:     ... initial time step complete (2.138 ms).
[ Info: Simulation is stopping. Model time 12 hours has hit or exceeded simulation stop time 12 hours.
[ Info: Initializing simulation...
[ Info:     ... simulation initialization complete (561.807 μs)
[ Info: Executing initial time step...
[ Info:     ... initial time step complete (2.188 ms).
[ Info: Simulation is stopping. Model time 12 hours has hit or exceeded simulation stop time 12 hours.

We also plot the distributions of the parameter ensembles at a few EKI iterations to see whether, and how well, they converge to the true diffusivity values.

fig = Figure()

axtop = Axis(fig[1, 1])
axmain = Axis(fig[2, 1], xlabel = "convective_κz [m² s⁻¹]",
                       ylabel = "background_κz [m² s⁻¹]")

axright = Axis(fig[2, 2])
scatters = []
labels = String[]

for iteration in [0, 1, 2, 10]
    # Make parameter matrix
    parameters = eki.iteration_summaries[iteration].parameters
    Nensemble = length(parameters)
    Nparameters = length(first(parameters))
    parameter_ensemble_matrix = [parameters[i][j] for i=1:Nensemble, j=1:Nparameters]

    label = iteration == 0 ? "Initial ensemble" : "Iteration $iteration"
    push!(labels, label)
    push!(scatters, scatter!(axmain, parameter_ensemble_matrix))
    density!(axtop, parameter_ensemble_matrix[:, 1])
    density!(axright, parameter_ensemble_matrix[:, 2], direction = :y)
end

vlines!(axmain, [θ★.convective_κz], color = :red)
vlines!(axtop, [θ★.convective_κz], color = :red)

hlines!(axmain, [θ★.background_κz], color = :red)
hlines!(axright, [θ★.background_κz], color = :red)

colsize!(fig.layout, 1, Fixed(300))
colsize!(fig.layout, 2, Fixed(200))
rowsize!(fig.layout, 1, Fixed(200))
rowsize!(fig.layout, 2, Fixed(300))

Legend(fig[1, 2], scatters, labels, valign = :bottom, halign = :left)

hidedecorations!(axtop, grid = false)
hidedecorations!(axright, grid = false)

xlims!(axmain, -0.25, 3.2)
xlims!(axtop, -0.25, 3.2)
ylims!(axmain, 5e-5, 35e-5)
ylims!(axright, 5e-5, 35e-5)


This page was generated using Literate.jl.