Universal Differential Equations

And finally, the part of the tutorial I like best! Universal Differential Equations (UDEs) are a clever generalization of Neural Differential Equations: the right-hand side of the system can now be a combination of explicit functions (capturing the known part of the system, the domain knowledge we can gather) and neural networks (capturing the unknown parts of the system).
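Schematically: where a pure Neural ODE learns the entire right-hand side, a UDE only learns what we cannot write down. For a state u, known parameters p, and network parameters Īø, the general form is

$$\begin{equation} \frac{du}{dt} = f\left(u, t, p, \mathrm{NN}_\theta(u)\right) \end{equation}$$

where f encodes the known structure and the network fills in the unknown pieces.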

As always, we start by loading the libraries we'll need. There are a bunch of them this time :-)

using Lux, DataDrivenDiffEq, ModelingToolkit, OrdinaryDiffEq, DataDrivenSparse, LinearAlgebra, Plots
using Optimization, OptimizationOptimisers, OptimizationFlux, OptimizationOptimJL, DiffEqSensitivity
using DifferentialEquations
using Statistics, ComponentArrays, Random

Define the dynamical system

begin
    uā‚€ = šŸ±ā‚€, šŸ˜„ā‚€ = [-1.0, 1.0]
    p_true = [1.0, .4, .2, .1]
    tspan = (0.0,10.0)
    
    function cat_love(du,u,p,t)
        🐱, šŸ˜„ = u
        a, b, α, β = p # model parameters
        du[1] = d🐱 = - a*šŸ˜„ + α*šŸ˜„*🐱
        du[2] = dšŸ˜„ = b*🐱 - β*šŸ˜„*🐱
    end

    prob = ODEProblem(cat_love, uā‚€, tspan, p_true)
    solution = solve(prob, Vern7(), abstol = 1e-12, reltol = 1e-12, saveat = 0.3)

    X = Array(solution)
    t = solution.t
end;
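Written out (with u₁ standing for 🐱 and uā‚‚ for šŸ˜„), the system we just coded is

$$\begin{align} \dot{u}_1 =& -a u_2 + \alpha u_1 u_2 \\ \dot{u}_2 =& b u_1 - \beta u_1 u_2 \end{align}$$
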
prob_org = ODEProblem(cat_love, # the equation system
    uā‚€, # initial state
    (0.0,30.0), # time interval
    p_true # parameters
)
ODEProblem with uType Vector{Float64} and tType Float64. In-place: true
timespan: (0.0, 30.0)
u0: 2-element Vector{Float64}:
 -1.0
  1.0
begin
    solution_org = solve(prob_org, Tsit5(), saveat = 0.1)
    plot(solution_org.t, Array(solution_org)', label = ["🐱" "šŸ˜„"])
end
DX = Array(solution(solution.t, Val{1})) # first derivative of the dense interpolant, evaluated at the save points
2Ɨ35 Matrix{Float64}:
 -1.17193   -1.17193   -1.07165   -0.898139  …  -1.17512    -1.20815   -1.20802
 -0.359549  -0.359549  -0.481937  -0.606033     -0.0842729  -0.190613  -0.265336
full_problem = DataDrivenProblem(X, t = t, DX = DX)
Continuous DataDrivenProblem{Float64} ##DDProblem#636 in 2 dimensions and 35 samples

Define the neural network

begin
    rng = Random.default_rng()
    Random.seed!(1234)
    # Define the network
    # Gaussian RBF as activation
    rbf(x) = exp.(-(x.^2))

    # Multilayer FeedForward
    U = Lux.Chain(
        Lux.Dense(2,5,rbf),
        Lux.Dense(5,5, rbf),
        Lux.Dense(5,5, rbf),
        Lux.Dense(5,2)
    )
    # Get the initial parameters and state variables of the model
    pₙₙ, st = Lux.setup(rng, U)

end
((layer_1 = (weight = Float32[0.41883966 -0.5210763; -0.3222286 0.40877005; … ; 0.090825 0.058136694; 0.53497386 -0.27198443], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_2 = (weight = Float32[…], bias = Float32[…]), layer_3 = (weight = Float32[…], bias = Float32[…]), layer_4 = (weight = Float32[…], bias = Float32[0.0; 0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))
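Note that Lux models are stateless: parameters and state live outside the model object, so evaluating the network means passing both explicitly. As a quick sanity check (a sketch, not part of the training loop), we can evaluate the untrained network on the initial state:

û, _ = U(uā‚€, pₙₙ, st) # returns the 2-element network output and the (unchanged) layer state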

Universal Differential Equation

Here we pretend we know part of the structure: d🐱 depends on šŸ˜„ multiplicatively, and dšŸ˜„ depends on 🐱 multiplicatively. The neural network only has to learn the unknown factors.
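In other words, we keep the multiplicative structure and let the network U = (U₁, Uā‚‚) learn only the unknown factors:

$$\begin{align} \dot{u}_1 =& U_1(u) u_2 \\ \dot{u}_2 =& U_2(u) u_1 \end{align}$$

Comparing with the true model above, a perfectly trained network would satisfy U₁(u) ā‰ˆ -a + α u₁ and Uā‚‚(u) ā‰ˆ b - β uā‚‚.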

# Define the hybrid model
function ude_cat!(du, u, p, t)
    û = U(u, p, st)[1] # network prediction (first element; the second is the updated state)
    du[1] = û[1]*u[2]  # d🐱 = U₁(u) * šŸ˜„
    du[2] = û[2]*u[1]  # dšŸ˜„ = Uā‚‚(u) * 🐱
end
ude_cat! (generic function with 1 method)
# ODEProblem{iip, specialization}: declare the problem as in-place and fully specialized for performance
prob_nn = ODEProblem{true, SciMLBase.FullSpecialize}(ude_cat!,uā‚€, tspan, pₙₙ)
ODEProblem with uType Vector{Float64} and tType Float64. In-place: true
timespan: (0.0, 10.0)
u0: 2-element Vector{Float64}:
 -1.0
  1.0

And we train!

function predict(Īø, X = uā‚€, T = t)
    # Rebuild the problem with new parameters (and possibly a new initial state / time grid)
    _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = Īø)
    Array(solve(_prob, Vern7(), saveat = T,
                abstol = 1e-6, reltol = 1e-6,
                sensealg = ForwardDiffSensitivity() # forward-mode sensitivities for the gradient
                ))
end
predict (generic function with 3 methods)
# Simple L2 loss
function loss(Īø, hyper) # hyper is required by the Optimization.jl interface, unused here
    XĢ‚ = predict(Īø)
    sum(abs2, X .- XĢ‚)
end
loss (generic function with 1 method)
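In symbols, this is the sum of squared residuals between the data X and the trajectory predicted with parameters Īø:

$$\begin{equation} L(\theta) = \sum_{i,j} \left( X_{ij} - \hat{X}_{ij}(\theta) \right)^2 \end{equation}$$
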
begin
    adtype = Optimization.AutoZygote()
    optf = Optimization.OptimizationFunction(loss, adtype)
    optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(pₙₙ))
end
OptimizationProblem. In-place: true
u0: ComponentVector{Float64}(layer_1 = (weight = [0.4188396632671356 -0.5210763216018677; -0.3222286105155945 0.40877005457878113; … ], bias = [0.0; 0.0; … ;;]), layer_2 = (weight = […], bias = […]), layer_3 = (weight = […], bias = […]), layer_4 = (weight = […], bias = [0.0; 0.0;;]))
begin
    # First pass: ADAM quickly gets the parameters into the right basin
    res1 = Optimization.solve(optprob, ADAM(0.1), maxiters = 200)
end
u: ComponentVector{Float64}(layer_1 = (weight = [0.9329137145362386 -1.14390302410167; 0.5740133392071249 0.5662870873180993; … ], bias = [-0.43235987250908525; -1.4042417890843308; … ;;]), layer_2 = (weight = […], bias = […]), layer_3 = (weight = […], bias = […]), layer_4 = (weight = […], bias = [-0.8428392130106009; 0.46637384262714715;;]))
begin
    # Second pass: polish the ADAM result with BFGS, which converges fast near a minimum
    optprob2 = Optimization.OptimizationProblem(optf, res1.minimizer)
    res2 = Optimization.solve(optprob2, Optim.BFGS(initial_stepnorm=0.01), maxiters = 200)
end
u: ComponentVector{Float64}(layer_1 = (weight = [0.5586930244843454 -1.0653175859292792; -0.08168622464792798 0.7805346329618921; … ], bias = [-1.0134142571761495; -1.6936988331035274; … ;;]), layer_2 = (weight = […], bias = […]), layer_3 = (weight = […], bias = […]), layer_4 = (weight = […], bias = [-1.0887535448107724; 0.7038602217116432;;]))
begin
    p_trained = res2.minimizer
    ts = first(solution.t):mean(diff(solution.t))/2:last(solution.t) # denser grid: half the original step
    XĢ‚ = predict(p_trained, X[:,1], ts)
    plot(solution, alpha = 0.75, color = :black, label = ["True Data" nothing])
    scatter!(ts,XĢ‚',alpha = 0.75, color = :red, label = ["NN Data" nothing])
end

Here we had to strongly limit the number of training iterations, since this notebook runs on GitHub Actions and we don't want the job to be cut off midway through. Still, not a bad fit, eh?

Equation discovery

begin
    YĢ‚ = U(XĢ‚, p_trained, st)[1] # network outputs along the fitted trajectory
    Ī» = exp10.(-3:0.01:-1) # thresholds
    opt = STLSQ(Ī») # SINDy optimizer
    nn_problem = DirectDataDrivenProblem(XĢ‚, YĢ‚)
end
Direct DataDrivenProblem{Float64} ##DDProblem#968 in 2 dimensions and 69 samples
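STLSQ is the sparse regression behind SINDy: it looks for a sparse coefficient matrix Ī so that the network outputs are a linear combination of a few candidate basis functions Θ (defined in the next cell), zeroing out and refitting without any coefficient that falls below the threshold λ. Up to layout conventions:

$$\begin{equation} \hat{Y} \approx \Xi \Theta(\hat{X}) \end{equation}$$
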
begin
    # Create a Basis
    @variables u[1:2]
    basis = Basis(polynomial_basis(u, 2), u) # all monomials in u₁, uā‚‚ up to degree 2
end

$$\begin{align} \varphi_1 =& 1 \\ \varphi_2 =& u_1 \\ \varphi_3 =& u_1^{2} \\ \varphi_4 =& u_2 \\ \varphi_5 =& u_1 u_2 \\ \varphi_6 =& u_2^{2} \end{align}$$

sampler = DataProcessing(split = 0.9, shuffle = true)
DataProcessing
  split: Float64 0.9
  shuffle: Bool true
  batchsize: Int64 0
  partial: Bool true
  rng: TaskLocalRNG TaskLocalRNG()
nn_res = solve(nn_problem, basis, opt,
    options = DataDrivenCommonOptions(
        digits = 1,
        data_processing = sampler,
        maxiters = 10000,
        denoise = true,
        )
)
"DataDrivenSolution{Float64}"
get_basis(nn_res)

$$\begin{align} \varphi_1 =& p_1 + u_1^{2} p_2 \\ \varphi_2 =& u_1^{2} p_3 \end{align}$$
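
Because of the very short training run, the recovered structure does not match the true factors (we would hope for a constant plus a u₁ term, and a constant plus a uā‚‚ term). To read off the fitted coefficients behind p₁, pā‚‚, and p₃, one can query the basis; a sketch, assuming DataDrivenDiffEq's get_parameter_map helper:

get_parameter_map(get_basis(nn_res))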

tiv = get_basis(nn_res) |> get_iv |> Symbolics.unwrap # the independent variable of the recovered basis

$$\begin{equation} t \end{equation}$$

Built with Julia 1.8.3 and

ComponentArrays 0.13.4
DataDrivenDiffEq 1.0.1
DataDrivenSparse 0.1.1
DiffEqSensitivity 6.79.0
DifferentialEquations 7.6.0
Lux 0.4.36
ModelingToolkit 8.36.0
Optimization 3.10.0
OptimizationFlux 0.1.2
OptimizationOptimJL 0.1.5
OptimizationOptimisers 0.1.1
OrdinaryDiffEq 6.35.1
Plots 1.37.2

To run this tutorial locally, download [this file](/tutorials/03udesparse.jl) and open it with Pluto.jl.