Automatic Differentiation
ClimaTimeSteppers is fully compatible with ForwardDiff.jl. Dual numbers propagate correctly through all solver families — explicit RK, IMEX ARK, and Rosenbrock — enabling differentiation of the ODE solution with respect to initial conditions or parameters.
Differentiating with respect to a parameter
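Consider the decay ODE $du/dt = -\lambda u$ with exact solution $u(T) = u_0 e^{-\lambda T}$, whose sensitivity to the parameter is $\partial u(T)/\partial\lambda = -T u_0 e^{-\lambda T}$. With $u_0 = 1$ and $T = 1$, we can compute $\partial u(1)/\partial\lambda$ at $\lambda = 1$ (exact value $-e^{-1}$) by wrapping the solve in a function: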
using ClimaTimeSteppers
import ClimaTimeSteppers as CTS
using ForwardDiff
function solve_decay(λ)
    f = ClimaODEFunction(; T_exp! = (du, u, p, t) -> (du .= -p[1] .* u))
    # Construct problem with types matching λ (required for dual propagation)
    prob = CTS.ODEProblem(f, [one(λ)], (zero(λ), one(λ)), [λ])
    sol = CTS.solve(prob, ExplicitAlgorithm(RK4()); dt = oftype(λ, 0.1), saveat = (one(λ),))
    return sol.u[end][1]
end
λ = 1.0
value = solve_decay(λ)
gradient = ForwardDiff.derivative(solve_decay, λ)
println("u(1) = ", round(value; digits = 6), " (exact: ", round(exp(-1); digits = 6), ")")
println("du/dλ = ", round(gradient; digits = 6), " (exact: ", round(-exp(-1); digits = 6), ")")
u(1) = 0.36788 (exact: 0.367879)
du/dλ = -0.367878 (exact: -0.367879)
The key requirement: when differentiating with respect to a variable, all quantities derived from it (u0, tspan, dt, jac_prototype) must be constructed with matching type (e.g. one(λ) instead of 1.0) so that ForwardDiff's dual numbers are not accidentally clipped by the solver.
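The type-matching requirement can be illustrated without the solver at all. The sketch below is a hypothetical stand-in, not ClimaTimeSteppers code: `euler_decay` is a hand-written explicit Euler loop for the same decay ODE, and `clipped` and `typed` are illustrative names. It assumes only ForwardDiff.jl. In this sketch, a state buffer allocated as `Float64` cannot store dual numbers and the mismatch surfaces as a `MethodError`, while a buffer built with `one(λ)` carries the derivative through every step.

```julia
using ForwardDiff

# Hypothetical stand-in for a solver loop: ten explicit Euler steps of
# du/dt = -λ u on [0, 1], updating the state buffer `u` in place.
function euler_decay(λ, u)
    dt = oftype(λ, 0.1)          # step size typed to match λ
    for _ in 1:10
        u .+= dt .* (-λ .* u)    # writes values of λ's type into `u`
    end
    return u[1]
end

# Float64 buffer: a Dual cannot be stored in it.
clipped(λ) = euler_decay(λ, [1.0])

# Buffer typed from λ: Duals are stored as-is.
typed(λ) = euler_decay(λ, [one(λ)])

typed_grad = ForwardDiff.derivative(typed, 1.0)

# Record whether the mistyped version fails with a MethodError.
clipped_failed = try
    ForwardDiff.derivative(clipped, 1.0)
    false
catch err
    err isa MethodError
end
```

For this Euler scheme the discrete solution is $(1 - 0.1\lambda)^{10}$, so `typed_grad` should agree with $-(0.9)^9$ rather than with the exact $-e^{-1}$; the difference is the time-discretization error, not an AD error.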
IMEX methods
Implicit solvers work the same way. The Jacobian prototype must also be typed to match:
using LinearAlgebra
function solve_imex(λ)
    T = typeof(λ)
    T_imp! = (du, u, p, t) -> (du .= -p[1] .* u)
    # W = dtγ * J - I, where J = -λ is the Jacobian of T_imp!
    function Wfact!(W, u, p, dtγ, t)
        fill!(W, zero(eltype(W)))
        for i in axes(W, 1)
            W[i, i] = -p[1] * dtγ - 1
        end
    end
    # jac_prototype uses element type T so that W can hold dual numbers
    imp = CTS.ODEFunction(T_imp!; jac_prototype = zeros(T, 1, 1), Wfact = Wfact!)
    f = ClimaODEFunction(;
        T_exp! = (du, u, p, t) -> (du .= zero(u)),
        T_imp! = imp,
    )
    prob = CTS.ODEProblem(f, [one(T)], (zero(T), one(T)), [λ])
    alg = IMEXAlgorithm(ARS343(), NewtonsMethod(; max_iters = 1, update_j = UpdateEvery(NewTimeStep)))
    sol = CTS.solve(prob, alg; dt = T(0.1), saveat = (one(T),))
    return sol.u[end][1]
end
value_imex = solve_imex(1.0)
grad_imex = ForwardDiff.derivative(solve_imex, 1.0)
println("IMEX u(1) = ", round(value_imex; digits = 6), " (exact: ", round(exp(-1); digits = 6), ")")
println("IMEX du/dλ = ", round(grad_imex; digits = 6), " (exact: ", round(-exp(-1); digits = 6), ")")
IMEX u(1) = 0.36787 (exact: 0.367879)
IMEX du/dλ = -0.367906 (exact: -0.367879)
Rosenbrock methods
Rosenbrock solvers use a direct linear solve (lu factorization) instead of Newton iteration, and dual numbers propagate through this path as well:
function solve_rosenbrock(λ)
    T = typeof(λ)
    T_imp! = (du, u, p, t) -> (du .= -p[1] .* u)
    # W = dtγ * J - I, where J = -λ is the Jacobian of T_imp!
    function Wfact!(W, u, p, dtγ, t)
        fill!(W, zero(eltype(W)))
        for i in axes(W, 1)
            W[i, i] = -p[1] * dtγ - 1
        end
    end
    # jac_prototype uses element type T so that W can hold dual numbers
    imp = CTS.ODEFunction(T_imp!; jac_prototype = zeros(T, 1, 1), Wfact = Wfact!)
    f = ClimaODEFunction(; T_imp! = imp)
    prob = CTS.ODEProblem(f, [one(T)], (zero(T), one(T)), [λ])
    alg = CTS.RosenbrockAlgorithm(CTS.tableau(SSPKnoth()))
    sol = CTS.solve(prob, alg; dt = T(0.1), saveat = (one(T),))
    return sol.u[end][1]
end
value_rb = solve_rosenbrock(1.0)
grad_rb = ForwardDiff.derivative(solve_rosenbrock, 1.0)
println("Rosenbrock u(1) = ", round(value_rb; digits = 6), " (exact: ", round(exp(-1); digits = 6), ")")
println("Rosenbrock du/dλ = ", round(grad_rb; digits = 6), " (exact: ", round(-exp(-1); digits = 6), ")")
Rosenbrock u(1) = 0.369393 (exact: 0.367879)
Rosenbrock du/dλ = -0.365124 (exact: -0.367879)
Tips
- Type your allocations: use zeros(typeof(λ), n, n) for jac_prototype, [one(λ)] for u0, and oftype(λ, 0.1) for dt.
- Construct a fresh ODEProblem each iteration if you call solve in a loop, because the integrator mutates u0 in place.
- Rosenbrock methods work as shown above: the lu factorization and ldiv! both handle dual-valued matrices and right-hand sides.
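The last tip can be checked in isolation from the solver. The following is a minimal sketch, assuming only ForwardDiff.jl and the LinearAlgebra standard library. The function `linear_solve` and its 2×2 system are hypothetical; they mimic the shape of the W matrix built in the examples above (diagonal entries `-λ * dtγ - 1`) and differentiate the result of an `lu` factorization followed by `ldiv!`.

```julia
using ForwardDiff
using LinearAlgebra

# Hypothetical linear solve shaped like the examples above:
# W = dtγ * J - I with J = -λ * I, solved against a vector of ones.
function linear_solve(λ)
    T = typeof(λ)
    dtγ = T(0.1)
    W = zeros(T, 2, 2)            # matrix typed to hold dual numbers
    for i in axes(W, 1)
        W[i, i] = -λ * dtγ - 1
    end
    b = ones(T, 2)                # right-hand side, also typed from λ
    ldiv!(lu(W), b)               # overwrites b with W \ b
    return b[1]
end

solve_value = linear_solve(1.0)
solve_grad = ForwardDiff.derivative(linear_solve, 1.0)
```

Here the first solution component is $-1/(1 + 0.1\lambda)$, so `solve_grad` can be compared against the analytic derivative $0.1/(1 + 0.1\lambda)^2$ evaluated at $\lambda = 1$.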