Universal Differential Equations
And finally, the part of the tutorial I prefer! Universal Differential Equations are a smart generalization of Neural Differential Equations: the right-hand side of the equations defining the system can now be a combination of explicit functions (capturing the known part of the system, the domain knowledge we can gather) and neural networks (capturing the unknown elements of the system).
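Schematically, the idea is (a generic sketch; our concrete example below combines the two parts multiplicatively instead):

$$\frac{du}{dt} = f_{\text{known}}(u, p, t) + \mathrm{NN}_{\theta}(u)$$

where $f_{\text{known}}$ carries the mechanistic terms and the network $\mathrm{NN}_{\theta}$ absorbs what the domain knowledge leaves unexplained.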
As always, we start by loading the libraries we'll need. There are a bunch of them this time :-)
using Lux, DataDrivenDiffEq, ModelingToolkit, OrdinaryDiffEq, DataDrivenSparse, LinearAlgebra, Plots
using Optimization, OptimizationOptimisers, OptimizationFlux, OptimizationOptimJL, DiffEqSensitivity
using DifferentialEquations
using Statistics, ComponentArrays, Random
Define the dynamical system
begin
    u₀ = [-1.0, 1.0] # initial state [🐱₀, 🐭₀]
    p_true = [1.0, 0.4, 0.2, 0.1]
    tspan = (0.0, 10.0)
    function cat_love(du, u, p, t)
        🐱, 🐭 = u
        a, b, α, β = p
        du[1] = d🐱 = -a*🐭 + α*🐭*🐱
        du[2] = d🐭 = b*🐱 - β*🐭*🐱
    end
    prob = ODEProblem(cat_love, u₀, tspan, p_true)
    solution = solve(prob, Vern7(), abstol = 1e-12, reltol = 1e-12, saveat = 0.3)
    X = Array(solution)
    t = solution.t
end;
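Written out, the toy system we integrate is

$$\begin{align} \dot{🐱} =& -a\,🐭 + \alpha\,🐭\,🐱 \\ \dot{🐭} =& b\,🐱 - \beta\,🐭\,🐱 \end{align}$$

with true parameters $(a, b, \alpha, \beta) = (1.0, 0.4, 0.2, 0.1)$.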
prob_org = ODEProblem(cat_love, # the equation system
    u₀,          # initial state
    (0.0, 30.0), # time interval
    p_true       # parameters
)
ODEProblem with uType Vector{Float64} and tType Float64. In-place: true
timespan: (0.0, 30.0)
u0: 2-element Vector{Float64}:
 -1.0
  1.0
begin
    solution_org = solve(prob_org, Tsit5(), saveat = 0.1)
    plot(solution_org.t, Array(solution_org)', label = ["🐱" "🐭"])
end
DX = Array(solution(solution.t, Val{1})) # evaluate the first derivative of the dense interpolant at the saved times
2×35 Matrix{Float64}:
 -1.17193   -1.17193   -1.07165   -0.898139  …  -1.17512    -1.20815   -1.20802
 -0.359549  -0.359549  -0.481937  -0.606033     -0.0842729  -0.190613  -0.265336
full_problem = DataDrivenProblem(X, t = t, DX = DX)
Continuous DataDrivenProblem{Float64} ##DDProblem#636 in 2 dimensions and 35 samples
Define the neural network
begin
    rng = Random.default_rng()
    Random.seed!(1234)
    # Define the network
    # Gaussian RBF as activation
    rbf(x) = exp.(-(x .^ 2))
    # Multilayer FeedForward
    U = Lux.Chain(
        Lux.Dense(2, 5, rbf),
        Lux.Dense(5, 5, rbf),
        Lux.Dense(5, 5, rbf),
        Lux.Dense(5, 2)
    )
    # Get the initial parameters and state variables of the model
    pₙₙ, st = Lux.setup(rng, U)
end
((layer_1 = (weight = Float32[0.41883966 -0.5210763; -0.3222286 0.40877005; … ; 0.090825 0.058136694; 0.53497386 -0.27198443], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_2 = (weight = Float32[-0.14838398 0.7236945 … -0.60292 -0.4678297; 0.21620789 0.7046474 … -0.49983823 -0.6270864; … ; 0.52614 0.69810414 … -0.4017533 0.5352689; -0.6902347 0.042112827 … 0.31156725 0.7140419], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[-0.7521383 0.11509942 … 0.24387686 -0.3582348; -0.58238137 0.28912154 … 0.7191186 0.26629153; … ; 0.23063432 -0.6323712 … -0.51756436 -0.031197049; 0.35919273 -0.20792396 … 0.59049577 -0.6734359], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_4 = (weight = Float32[-0.43561655 -0.3437964 … 0.097560994 0.25912023; -0.8772459 -0.28295153 … 0.5422182 0.86514133], bias = Float32[0.0; 0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))
Universal Differential Equation
Here we pretend we know part of the structure: d🐱 depends multiplicatively on 🐭, and d🐭 depends multiplicatively on 🐱. The network then only has to learn the remaining factors; comparing with the true equations, û[1] should converge towards -a + α 🐱 and û[2] towards b - β 🐭.
# Define the hybrid model
function ude_cat!(du, u, p, t)
    û = U(u, p, st)[1] # network prediction
    du[1] = û[1]*u[2]
    du[2] = û[2]*u[1]
end
ude_cat! (generic function with 1 method)
# ODEProblem{iip, specialization}: declare the problem in-place and fully specialized, for solver performance
prob_nn = ODEProblem{true, SciMLBase.FullSpecialize}(ude_cat!, u₀, tspan, pₙₙ)
ODEProblem with uType Vector{Float64} and tType Float64. In-place: true
timespan: (0.0, 10.0)
u0: 2-element Vector{Float64}:
 -1.0
  1.0
And we train!
function predict(θ, X = u₀, T = t)
    _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = θ)
    Array(solve(_prob, Vern7(), saveat = T,
                abstol = 1e-6, reltol = 1e-6,
                sensealg = ForwardDiffSensitivity()
    ))
end
predict (generic function with 3 methods)
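As a quick sanity check (this snippet is ours, not part of the original notebook), the untrained network can already be pushed through the solver; if the solve succeeds, the prediction has the same shape as the training data:

X̂₀ = predict(ComponentVector{Float64}(pₙₙ)) # prediction with the untrained parameters
size(X̂₀) == size(X)                          # should both be 2×35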
# Simple L2 loss
function loss(θ, hyper)
    X̂ = predict(θ)
    sum(abs2, X .- X̂)
end
loss (generic function with 1 method)
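If you want to watch the loss while training, Optimization.jl accepts a callback in solve; a minimal sketch (the losses vector and print frequency are our own additions, not from the original notebook):

losses = Float64[]
callback = function (θ, l)
    push!(losses, l)
    length(losses) % 50 == 0 && println("Loss after $(length(losses)) iterations: $(l)")
    return false # returning true would stop the optimization early
end
# usage: Optimization.solve(optprob, ADAM(0.1), callback = callback, maxiters = 200)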
begin
    adtype = Optimization.AutoZygote()
    optf = Optimization.OptimizationFunction(loss, adtype)
    optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(pₙₙ))
end
OptimizationProblem. In-place: true
u0: ComponentVector{Float64}(layer_1 = (weight = [0.4188396632671356 -0.5210763216018677; -0.3222286105155945 0.40877005457878113; … ; 0.09082499891519547 0.05813669413328171; 0.5349738597869873 -0.27198442816734314], bias = [0.0; 0.0; … ; 0.0; 0.0;;]), layer_2 = (weight = [-0.1483839750289917 0.7236945033073425 … -0.6029199957847595 -0.46782970428466797; 0.21620789170265198 0.704647421836853 … -0.4998382329940796 -0.6270864009857178; … ; 0.5261399745941162 0.6981041431427002 … -0.401753306388855 0.5352689027786255; -0.6902347207069397 0.04211282730102539 … 0.3115672469139099 0.7140418887138367], bias = [0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = [-0.7521383166313171 0.11509942263364792 … 0.2438768595457077 -0.35823479294776917; -0.5823813676834106 0.28912153840065 … 0.719118595123291 0.2662915289402008; … ; 0.23063431680202484 -0.632371187210083 … -0.5175643563270569 -0.031197048723697662; 0.35919272899627686 -0.20792396366596222 … 0.590495765209198 -0.6734359264373779], bias = [0.0; 0.0; … ; 0.0; 0.0;;]), layer_4 = (weight = [-0.43561655282974243 -0.34379640221595764 … 0.09756099432706833 0.25912022590637207; -0.8772459030151367 -0.2829515337944031 … 0.5422182083129883 0.8651413321495056], bias = [0.0; 0.0;;]))
begin
    res1 = Optimization.solve(optprob, ADAM(0.1), maxiters = 200)
end
u: ComponentVector{Float64}(layer_1 = (weight = [0.9329137145362386 -1.14390302410167; 0.5740133392071249 0.5662870873180993; … ; -0.1395497051412112 -1.4075775787854536; 1.8860097440145391 -1.448813214265837], bias = [-0.43235987250908525; -1.4042417890843308; … ; 0.8409739859030176; 0.2663964851796898;;]), layer_2 = (weight = [-0.08283411115410458 0.8032399062662892 … -0.5396684718387518 -0.6574632596279617; 0.43725167063027276 0.023382550764401645 … -1.2532325143886063 0.4794776690813305; … ; 0.8306520086616839 0.6419872749509058 … -0.6313244452762709 1.2275539739440995; -0.34945545168090625 -0.5393183918061216 … -0.351719770916033 1.6359672765617765], bias = [-0.19685687689991843; -0.43827466005949894; … ; -0.2385682493500643; -0.41712804771250467;;]), layer_3 = (weight = [-0.604588197400037 -0.6821058326823396 … -0.2235380487587348 -1.0477307359667047; -0.13782759900749839 1.3252832160697237 … 1.3752238168623576 1.7220475253066423; … ; 1.3191192875414361 0.20908812915557626 … 0.4028351033928086 0.7823028515441364; 1.4649870823537174 0.5863142899254746 … 1.1640569726118053 0.1391639056969676], bias = [0.024859395401106002; 0.48831481666259263; … ; 1.05939570616502; 1.138988026477087;;]), layer_4 = (weight = [-1.318967100974358 -1.5808122425942088 … -0.7897284976373723 -0.3490118740866796; -0.5401545612117324 0.12289341037177055 … 0.482754626019797 1.38381659460503], bias = [-0.8428392130106009; 0.46637384262714715;;]))
begin
    optprob2 = Optimization.OptimizationProblem(optf, res1.minimizer)
    res2 = Optimization.solve(optprob2, Optim.BFGS(initial_stepnorm = 0.01), maxiters = 200)
end
u: ComponentVector{Float64}(layer_1 = (weight = [0.5586930244843454 -1.0653175859292792; -0.08168622464792798 0.7805346329618921; … ; 0.06354679687439149 -0.7207505119924804; 2.4568185519515 -2.088309740051133], bias = [-1.0134142571761495; -1.6936988331035274; … ; -0.1184186911403389; 0.012744160369824457;;]), layer_2 = (weight = [-0.3021431331206667 0.870370397951155 … -0.2574179829792301 -0.8025435456150155; 1.65375485400009 0.016689226414514143 … -1.9498968119037836 1.264071329717429; … ; 1.6655048365883753 -0.2240996622842686 … -1.304472139141909 2.527413354540473; 0.3033838667241871 -1.0682703304925405 … -0.6668502101734235 2.166048158288414], bias = [-0.6301150570386213; -0.5787047024002201; … ; -1.0627635248707525; -0.26988441165154614;;]), layer_3 = (weight = [-0.21914994184111103 -0.09986833653669905 … -0.5409559309448428 -0.8084014053032869; 0.24092636151538047 1.560851522841987 … 1.9328940719452674 1.1071648728324064; … ; 1.7119695666898302 0.08550747264910524 … 0.38698837594622026 0.7782489147754519; 1.2897415654369686 0.676871312448159 … 1.046028157116686 0.3381962238276238], bias = [-0.4905454574372125; -0.2858535317023063; … ; 0.45798096755190193; 0.8119435512490865;;]), layer_4 = (weight = [3.8496744892725383 -2.0908100451828413 … -1.3099105146314345 -0.907530076846057; -2.314065716578152 0.6123368350459606 … 0.3702974634035101 1.203212748354118], bias = [-1.0887535448107724; 0.7038602217116432;;]))
begin
    p_trained = res2.minimizer
    ts = first(solution.t):mean(diff(solution.t))/2:last(solution.t)
    X̂ = predict(p_trained, X[:, 1], ts)
    plot(solution, alpha = 0.75, color = :black, label = ["True Data" nothing])
    scatter!(ts, X̂', alpha = 0.75, color = :red, label = ["NN Data" nothing])
end
Here we had to strongly limit the number of training iterations, since the model runs on GitHub Actions and we don't want the job to be cut off midway. Still, not a bad fit, eh?
Equation discovery
begin
    Ŷ = U(X̂, p_trained, st)[1]
    λ = exp10.(-3:0.01:-1) # thresholds
    opt = STLSQ(λ) # SINDy optimizer
    nn_problem = DirectDataDrivenProblem(X̂, Ŷ)
end
Direct DataDrivenProblem{Float64} ##DDProblem#968 in 2 dimensions and 69 samples
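For reference, STLSQ (sequentially thresholded least squares) looks for a sparse coefficient matrix $\Xi$ such that

$$\hat{Y} \approx \Xi\,\varphi(\hat{X})$$

by alternating ordinary least squares with hard thresholding: after each fit, every coefficient smaller in magnitude than $\lambda$ is zeroed out. Scanning the grid of thresholds defined above lets the solver trade off fit against parsimony.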
begin
    # Create a Basis
    @variables u[1:2]
    basis = Basis(polynomial_basis(u, 2), u)
end
$$\begin{align} \varphi_1 =& 1 \\ \varphi_2 =& u_1 \\ \varphi_3 =& u_1^{2} \\ \varphi_4 =& u_2 \\ \varphi_5 =& u_1 u_2 \\ \varphi_6 =& u_2^{2} \end{align}$$
sampler = DataProcessing(split = 0.9, shuffle = true)
DataProcessing
  split: Float64 0.9
  shuffle: Bool true
  batchsize: Int64 0
  partial: Bool true
  rng: TaskLocalRNG TaskLocalRNG()
nn_res = solve(nn_problem, basis, opt,
options = DataDrivenCommonOptions(
digits = 1,
data_processing = sampler,
maxiters = 10000,
denoise = true,
)
)
"DataDrivenSolution{Float64}"
get_basis(nn_res)
$$\begin{align} \varphi_1 =& p_1 + u_1^{2} p_2 \\ \varphi_2 =& u_1^{2} p_3 \end{align}$$
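To see the numeric values behind $p_1$, $p_2$ and $p_3$, the recovered basis carries its fitted parameters; assuming DataDrivenDiffEq's get_parameter_values accessor (check the API of your version):

get_parameter_values(get_basis(nn_res)) # fitted coefficients of the discovered equations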
tiv = get_basis(nn_res) |> get_iv |> Symbolics.unwrap
$$\begin{equation} t \end{equation}$$
Built with Julia 1.8.3 and
ComponentArrays 0.13.4
DataDrivenDiffEq 1.0.1
DataDrivenSparse 0.1.1
DiffEqSensitivity 6.79.0
DifferentialEquations 7.6.0
Lux 0.4.36
ModelingToolkit 8.36.0
Optimization 3.10.0
OptimizationFlux 0.1.2
OptimizationOptimJL 0.1.5
OptimizationOptimisers 0.1.1
OrdinaryDiffEq 6.35.1
Plots 1.37.2
To run this tutorial locally, download [this file](/tutorials/03udesparse.jl) and open it with Pluto.jl.