Going Neural
In this part of the tutorial we'll see how the DifferentialEquations ecosystem can be interfaced with the deep learning capabilities of Lux (and Flux).
We will train a Neural Network to approximate a dynamical system (the same one we defined in the previous Pluto notebook).
We start by loading the necessary libraries
using Lux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationOptimJL, Random, Plots
Simulate Data
We create some data for our analysis by setting up an ODE system. For the details of what's happening here, refer to the previous notebook.
begin
    rng = Random.default_rng()
    u0 = Float32[1.0; 1.0]
    datasize = 30
    tspan = (0.0f0, 10.0f0)
    tsteps = range(tspan[1], tspan[2], length = datasize)
end
0.0f0:0.3448276f0:10.0f0
function cat_love(du, u, p, t)
    x, y = u
    a, b = 1.0, 0.5
    du[1] = dx = -a*y
    du[2] = dy = b*x
end;
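In equations, the system we are simulating is

$$\frac{dx}{dt} = -a\,y \qquad \frac{dy}{dt} = b\,x$$

with a = 1 and b = 0.5: its trajectories are closed orbits around the origin, so both components keep oscillating.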
begin
    prob_cat = ODEProblem(cat_love, u0, tspan)
    true_data = Array(solve(prob_cat, Tsit5(), saveat = tsteps))
end
2×30 Matrix{Float32}:
 1.0  0.629    0.220789  -0.200486  …  0.918601  0.536955  0.123528  -0.297349
 1.0  1.14113  1.21475    1.21651      1.03835   1.16454   1.22174    1.20661
Define the Neural Network
Lux is quite similar to Flux, but it follows a different philosophy for handling the model (network) parameters, which makes it convenient for our scenario. The network will take the current state of the system as input and return an estimate of its time derivative.
dudt_nn = Lux.Chain(
    x -> x.^3,
    Lux.Dense(2, 20, tanh),
    Lux.Dense(20, 20, tanh),
    Lux.Dense(20, 2)
)
Chain(
    layer_1 = WrappedFunction(#4),
    layer_2 = Dense(2 => 20, tanh_fast),   # 60 parameters
    layer_3 = Dense(20 => 20, tanh_fast),  # 420 parameters
    layer_4 = Dense(20 => 2),              # 42 parameters
)         # Total: 522 parameters,
          #        plus 0 states, summarysize 48 bytes.
p_nn, str_nn = Lux.setup(rng, dudt_nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[0.48141208 0.080691345; -0.19250715 0.28334433; … ; 0.3114537 -0.08708531; -0.5137568 0.50625026], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[0.38209108 0.012740582 … -0.3584142 -0.3686378; -0.24421029 0.35249817 … -0.22306056 -0.058025397; … ; 0.36102444 0.086774394 … -0.1765904 -0.15196739; 0.32950535 0.35649836 … 0.2989066 0.17584707], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_4 = (weight = Float32[0.19966958 -0.5135044 … 0.16056709 0.47664255; 0.27101988 0.4461131 … 0.48264143 -0.21090077], bias = Float32[0.0; 0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))
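As a quick check (not part of the original notebook), we can evaluate the freshly initialized network directly: a Lux model is called as model(x, parameters, state) and returns the output together with the layer states.

# derivative estimate of the untrained network at the initial state u0
du_test, _ = dudt_nn(u0, p_nn, str_nn)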
We build a neural ordinary differential equation; its gradients are calculated via adjoint methods.
prob_neuralode = NeuralODE(dudt_nn, tspan, Tsit5(), saveat = tsteps)
NeuralODE() # 522 parameters
Next, we build a prediction function: given the network parameters, it solves the neural ODE starting from the initial state u0, using the network structure and state we set up above.
function predict_neuralode(p)
    Array(prob_neuralode(u0, p, str_nn)[1])
end
predict_neuralode (generic function with 1 method)
Define the loss function
The key step is now to define a loss function. The parameters of the network will be optimized to minimise this loss. The function needs to take two inputs: the network parameters and the hyperparameters (which we don't actually use in our case).
There are various ways to do it, but we'll use Julia's multiple dispatch, defining a method for one argument and another for two arguments.
function loss_neuralode(network_params)
    pred = predict_neuralode(network_params)
    loss = sum(abs2, true_data .- pred)
    return loss, pred
end
loss_neuralode (generic function with 1 method)
loss_neuralode(p,hyper_p) = loss_neuralode(p)
loss_neuralode (generic function with 2 methods)
Train to Results
We're mostly done: now it's time to train the network on the synthetic data we produced, plot, and enjoy.
function plot_solutions(true_data, prediction)
    plt = plot(tsteps, true_data', label = ["org x" "org y"])
    scatter!(plt, tsteps, prediction', label = ["pred x" "pred y"])
    return plt
end
plot_solutions (generic function with 1 method)
callback = function (p, l, pred; doplot = false)
    println(l)
    # plot current prediction against data
    if doplot
        plt = plot_solutions(true_data, pred)
        return plt
    end
    return false
end
#6 (generic function with 1 method)
At the beginning, the randomly initialized Neural Network does a very bad job, as we would expect!
pinit = Lux.ComponentArray(p_nn)
ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[0.48141208 0.080691345; -0.19250715 0.28334433; … ; 0.3114537 -0.08708531; -0.5137568 0.50625026], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[0.38209108 0.012740582 … -0.3584142 -0.3686378; -0.24421029 0.35249817 … -0.22306056 -0.058025397; … ; 0.36102444 0.086774394 … -0.1765904 -0.15196739; 0.32950535 0.35649836 … 0.2989066 0.17584707], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_4 = (weight = Float32[0.19966958 -0.5135044 … 0.16056709 0.47664255; 0.27101988 0.4461131 … 0.48264143 -0.21090077], bias = Float32[0.0; 0.0;;]))
callback(pinit, loss_neuralode(pinit)...; doplot=true)
We use Zygote.jl to handle Automatic Differentiation (reverse-mode AD), and Optimization.jl to define the problem of minimizing an objective function f over u (here, the network parameters), possibly given hyperparameters p:
$$\min_{u} f(u,p)$$
adtype = Optimization.AutoZygote()
Optimization.AutoZygote()
optf = OptimizationFunction(loss_neuralode, adtype)
(::SciMLBase.OptimizationFunction{true, Optimization.AutoZygote, typeof(Main.var"workspace#4".loss_neuralode), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}) (generic function with 1 method)
optprob = Optimization.OptimizationProblem(optf, pinit)
OptimizationProblem. In-place: true
u0: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[0.48141208 0.080691345; -0.19250715 0.28334433; … ; 0.3114537 -0.08708531; -0.5137568 0.50625026], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[0.38209108 0.012740582 … -0.3584142 -0.3686378; -0.24421029 0.35249817 … -0.22306056 -0.058025397; … ; 0.36102444 0.086774394 … -0.1765904 -0.15196739; 0.32950535 0.35649836 … 0.2989066 0.17584707], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_4 = (weight = Float32[0.19966958 -0.5135044 … 0.16056709 0.47664255; 0.27101988 0.4461131 … 0.48264143 -0.21090077], bias = Float32[0.0; 0.0;;]))
And we train! Once with ADAM.
result_neuralode = Optimization.solve(optprob,
    ADAM(0.05),
    callback = callback,
    maxiters = 300)
u: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[-0.18703368 -0.08665115; 0.09643898 0.59104586; … ; 0.008400643 -0.36410508; -0.20857872 0.83780897], bias = Float32[-0.020561367; 0.22864722; … ; -0.21528973; 0.1539101;;]), layer_3 = (weight = Float32[0.06469259 -0.28483814 … -0.067857966 -0.65807956; 0.018308736 0.62485164 … 0.06786882 0.21999092; … ; 0.055926092 -0.21345933 … 0.11846514 -0.44960493; 0.045831285 0.06742562 … 0.57291824 -0.117072105], bias = Float32[-0.27243486; 0.21784204; … ; -0.30153865; -0.26715395;;]), layer_4 = (weight = Float32[-0.107740164 -0.7094275 … -0.15050524 0.7438748; -0.04830679 0.16073358 … 0.15324427 0.063451245], bias = Float32[-0.28312048; -0.2910977;;]))
callback(result_neuralode.u, loss_neuralode(result_neuralode.u)...; doplot=true)
optprob2 = remake(optprob,u0 = result_neuralode.u)
OptimizationProblem. In-place: true
u0: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[-0.18703368 -0.08665115; 0.09643898 0.59104586; … ; 0.008400643 -0.36410508; -0.20857872 0.83780897], bias = Float32[-0.020561367; 0.22864722; … ; -0.21528973; 0.1539101;;]), layer_3 = (weight = Float32[0.06469259 -0.28483814 … -0.067857966 -0.65807956; 0.018308736 0.62485164 … 0.06786882 0.21999092; … ; 0.055926092 -0.21345933 … 0.11846514 -0.44960493; 0.045831285 0.06742562 … 0.57291824 -0.117072105], bias = Float32[-0.27243486; 0.21784204; … ; -0.30153865; -0.26715395;;]), layer_4 = (weight = Float32[-0.107740164 -0.7094275 … -0.15050524 0.7438748; -0.04830679 0.16073358 … 0.15324427 0.063451245], bias = Float32[-0.28312048; -0.2910977;;]))
And once with BFGS (a quasi-Newton method that maintains an approximation to the Hessian, updating it from the gradients seen at past iterations).
result_neuralode2 = Optimization.solve(optprob2,
    Optim.BFGS(initial_stepnorm = 0.01),
    callback = callback,
    allow_f_increases = false)
u: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[-0.31507474 -0.045717224; -0.02608984 0.5890759; … ; 0.031012857 -0.36588567; -0.22888936 0.8383275], bias = Float32[0.014799015; 0.2504964; … ; -0.17060904; 0.20097949;;]), layer_3 = (weight = Float32[0.064511605 -0.2907945 … -0.06590484 -0.66270703; 0.009276138 0.6127131 … 0.06101837 0.21883063; … ; 0.05731808 -0.20909072 … 0.11285106 -0.43947992; 0.06104287 0.08648598 … 0.57356703 -0.10599235], bias = Float32[-0.27328825; 0.23201706; … ; -0.29051417; -0.2493857;;]), layer_4 = (weight = Float32[-0.06532827 -0.66010374 … -0.14326413 0.74342334; -0.041659426 0.15407513 … 0.12824729 0.09886456], bias = Float32[-0.2034155; -0.2722975;;]))
callback(result_neuralode2.u, loss_neuralode(result_neuralode2.u)...; doplot=true)
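As an extra check (a minimal sketch, not part of the original notebook; tsteps_long, prob_long and pred_long are names introduced here just for illustration), we can integrate the trained neural ODE beyond the training window and see how it extrapolates:

tsteps_long = range(0.0f0, 20.0f0, length = 2 * datasize)
# same NeuralODE construction as above, only with a wider time span
prob_long = NeuralODE(dudt_nn, (0.0f0, 20.0f0), Tsit5(), saveat = tsteps_long)
pred_long = Array(prob_long(u0, result_neuralode2.u, str_nn)[1])
plot(tsteps_long, pred_long', label = ["pred x" "pred y"])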
Bonus: going Symbolic
Finally, if we are not satisfied with a hardly interpretable Neural Network, we can use something like SymbolicRegression.jl to recover two functions that closely approximate the behaviour of the NN.
using SymbolicRegression, SymbolicUtils
We wrap the trained network into a plain function of the state, then evaluate it on a grid of points covering [-2, 2]² to build the input/output dataset for the equation search.
ff(X) = first(dudt_nn(X, result_neuralode2.u, str_nn))
ff (generic function with 1 method)
Search_options = SymbolicRegression.Options(
    binary_operators = (+, *, /, -),
    unary_operators = (cos, exp),
    npopulations = 20
)
Options(
    # Operators: binops=Function[+, *, /, -], unaops=Function[cos, exp],
    # Loss: loss=L2DistLoss,
    # Complexity Management: maxsize=20, maxdepth=20, bin_constraints=[(-1, -1), (-1, -1), (-1, -1), (-1, -1)], una_constraints=[-1, -1], use_frequency=true, use_frequency_in_tournament=true, parsimony=0.0032, warmup_maxsize_by=0.0,
    # Search Size: npopulations=20, ncycles_per_iteration=550, npop=33,
    # Migration: migration=true, hof_migration=true, fraction_replaced=0.00036, fraction_replaced_hof=0.035,
    # Tournaments: prob_pick_first=0.86, tournament_selection_n=12, topn=12,
    # Constant tuning: perturbation_factor=0.076, probability_negate_constant=0.01, should_optimize_constants=true, optimizer_algorithm=BFGS, optimizer_probability=0.14, optimizer_nrestarts=2, optimizer_iterations=8,
    # Mutations: mutation_weights=SymbolicRegression.CoreModule.OptionsStructModule.MutationWeights(0.048, 0.47, 0.79, 5.1, 1.7, 0.002, 0.00023, 0.21, 0.0), crossover_probability=0.066, skip_mutation_failures=true
    # Annealing: annealing=false, alpha=0.1,
    # Speed Tweaks: batching=false, batch_size=50, fast_cycle=false,
    # Logistics: output_file=hall_of_fame.csv, verbosity=0, seed=nothing, progress=true,
    # Early Exit: early_stop_condition=nothing, timeout_in_seconds=nothing,
)
X_inputs = hcat([ [i,j] for i in Float32.(range(-2.0,2.0,100)) for j in Float32.(range(-2.0,2.0,100))]...)
2×10000 Matrix{Float32}:
 -2.0  -2.0     -2.0      -2.0      …  2.0      2.0      2.0      2.0     2.0
 -2.0  -1.9596  -1.91919  -1.87879     1.83838  1.87879  1.91919  1.9596  2.0
y_outputs = mapslices(ff,X_inputs, dims = 1)
2×10000 Matrix{Float32}:
  0.999807   1.00186    1.00431    1.00717   …  -1.46445    -1.46381    -1.46345
 -0.375392  -0.388972  -0.405392  -0.424926      0.0534068   0.0481486   0.0439227
hall_of_fame = EquationSearch(
    X_inputs,
    y_outputs,
    niterations = 40,
    options = Search_options,
    parallelism = :multithreading
)
2-element Vector{HallOfFame{Float32}} (output truncated): one hall of fame per output dimension, each holding a Pareto front of candidate expressions of increasing complexity together with their losses. Among the low-complexity members are (x2 * -0.7747569) - 0.07427521 for the first output and x1 * 0.36743218 for the second.
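These expressions structurally match the true right-hand sides -a*y and b*x of our system. To inspect the recovered equations as proper symbolic expressions, we can walk the Pareto front of each hall of fame and convert its best member with SymbolicUtils. This is a minimal sketch that is not part of the original notebook, assuming the calculate_pareto_frontier and node_to_symbolic helpers exported by the SymbolicRegression.jl version used here:

for (i, hof) in enumerate(hall_of_fame)
    # Pareto front (best loss at each complexity) for output dimension i
    dominating = calculate_pareto_frontier(X_inputs, y_outputs[i, :], hof, Search_options)
    # convert the most accurate member into a SymbolicUtils expression
    eqn = node_to_symbolic(dominating[end].tree, Search_options)
    println("output ", i, ": ", simplify(eqn))
end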
Built with Julia 1.8.3 and
DiffEqFlux 1.52.0
DifferentialEquations 7.6.0
Lux 0.4.36
Optimization 3.10.0
OptimizationOptimJL 0.1.5
Plots 1.37.0
SymbolicRegression 0.14.5
SymbolicUtils 0.19.11
To run this tutorial locally, download [this file](/tutorials/02neuralode.jl) and open it with Pluto.jl.