Going Neural

In this part of the tutorial we'll see how the DifferentialEquations ecosystem can be interfaced with the deep learning capabilities of Lux (and Flux).

We will train a Neural Network to approximate a dynamical system (the same one we defined in the previous Pluto notebook).

We start by loading the necessary libraries:

using Lux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationOptimJL, Random, Plots

Simulate Data

We create some data for our analysis by setting up an ODE system; for the details of what's happening here, refer to the previous notebook.
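For reference, with u = (🐱, πŸ˜„), the system we set up below (cat_love) is the linear model

$$\frac{du_1}{dt} = -a\, u_2, \qquad \frac{du_2}{dt} = b\, u_1, \qquad a = 1.0,\ b = 0.5$$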

begin
    rng = Random.default_rng()
    uβ‚€ = Float32[1.0; 1.0]
    datasize = 30
    tspan = (0.0f0, 10.0f0)
    tsteps = range(tspan[1], tspan[2], length = datasize)
end
0.0f0:0.3448276f0:10.0f0
function cat_love(du,u,p,t)
    🐱, πŸ˜„ = u
    a, b = [1.0, 0.5]        # fixed coefficients (the p argument is unused)
    du[1] = d🐱 = - a*πŸ˜„     # 🐱 is pushed down by πŸ˜„
    du[2] = dπŸ˜„ = b*🐱       # πŸ˜„ is pushed up by 🐱
end;
begin
    prob_cat = ODEProblem(cat_love, uβ‚€, tspan)
    true_data = Array(solve(prob_cat, Tsit5(), saveat = tsteps))
end
2Γ—30 Matrix{Float32}:
 1.0  0.629    0.220789  -0.200486  …  0.918601  0.536955  0.123528  -0.297349
 1.0  1.14113  1.21475    1.21651      1.03835   1.16454   1.22174    1.20661

Define the Neural Network

Lux is quite similar to Flux, but it follows a different philosophy for handling the model (network) parameters: they are kept explicit and separate from the model itself, which makes it convenient for our scenario. The network will take the current state of the system as input and return an estimate of its time derivative.

dudtβ‚™β‚™ = Lux.Chain(
    x -> x.^3,
    Lux.Dense(2, 20, tanh),
    Lux.Dense(20,20,tanh),
    Lux.Dense(20, 2)
)
Chain(
    layer_1 = WrappedFunction(#4),
    layer_2 = Dense(2 => 20, tanh_fast),  # 60 parameters
    layer_3 = Dense(20 => 20, tanh_fast),  # 420 parameters
    layer_4 = Dense(20 => 2),           # 42 parameters
)         # Total: 522 parameters,
          #        plus 0 states, summarysize 48 bytes.
pβ‚™β‚™, strβ‚™β‚™ = Lux.setup(rng, dudtβ‚™β‚™)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[0.48141208 0.080691345; -0.19250715 0.28334433; … ; 0.3114537 -0.08708531; -0.5137568 0.50625026], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[0.38209108 0.012740582 … -0.3584142 -0.3686378; -0.24421029 0.35249817 … -0.22306056 -0.058025397; … ; 0.36102444 0.086774394 … -0.1765904 -0.15196739; 0.32950535 0.35649836 … 0.2989066 0.17584707], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_4 = (weight = Float32[0.19966958 -0.5135044 … 0.16056709 0.47664255; 0.27101988 0.4461131 … 0.48264143 -0.21090077], bias = Float32[0.0; 0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))
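Because Lux keeps the parameters and the layer state outside the model, the network is evaluated by passing them explicitly, and it returns an (output, state) tuple. A minimal sketch, evaluating the untrained network on the initial condition:

# call the Lux model with explicit parameters and state; it returns (output, state)
y, _ = dudtβ‚™β‚™(uβ‚€, pβ‚™β‚™, strβ‚™β‚™)

The same calling convention appears later, when we wrap the trained network in ff for the symbolic-regression step.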

We build a neural ordinary differential equation; the gradient with respect to the network parameters is computed via adjoint methods.
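In other words, the hand-written right-hand side of the previous section is replaced by the network itself, and the ODE solver integrates

$$\frac{du}{dt} = \mathrm{NN}_{p}(u)$$

where p are the network parameters that we will learn.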

prob_neuralode = NeuralODE(dudtβ‚™β‚™, tspan, Tsit5(), saveat = tsteps)
NeuralODE()         # 522 parameters

Next, we build a prediction function: given the network parameters, it solves the neural ODE starting from the initial values uβ‚€, using the network state strβ‚™β‚™, and returns the solution as an array.

function predict_neuralode(p)
  Array(prob_neuralode(uβ‚€, p, strβ‚™β‚™)[1])
end
predict_neuralode (generic function with 1 method)

Define the loss function

The key step is now to define a loss function; the parameters of the network will be optimized to minimise it. Optimization.jl expects the objective to take two inputs: the network parameters and extra (hyper)parameters, which we don't actually use in our case.
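Here the loss is simply the sum of squared differences between the data and the prediction, over all time steps and both variables:

$$L(p) = \sum_{i,j} \left( \mathrm{data}_{ij} - \mathrm{pred}_{ij}(p) \right)^2$$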

There are various ways to do it, but we'll use Julia's multiple dispatch, defining one method with a single argument and another with two arguments.

function loss_neuralode(network_params)
    pred = predict_neuralode(network_params)
    loss = sum(abs2, true_data .- pred)
    return loss, pred
end
loss_neuralode (generic function with 1 method)
loss_neuralode(p,hyper_p) = loss_neuralode(p)
loss_neuralode (generic function with 2 methods)

Train to Results

We're mostly done: now it's time to train the network on the synthetic data we produced, plot, and enjoy.

function plot_solutions(true_data, prediction)
    plt = plot(tsteps, true_data', label = ["org 🐱" "org πŸ˜„"])
    scatter!(plt, tsteps, prediction',  label = ["pred 🐱" "pred πŸ˜„"])
    return plt
end
plot_solutions (generic function with 1 method)
callback = function (p, l, pred; doplot = false)
  println(l)
  # plot current prediction against data
  if doplot
    plt =  plot_solutions(true_data, pred)
      return plt
  end
  return false
end
#6 (generic function with 1 method)

We first gather the nested parameter NamedTuple into a flat ComponentArray, which the optimizer can treat as a single vector of parameters. At the beginning, the randomly initialized Neural Network does a very bad job, as we would expect!

pinit = Lux.ComponentArray(pβ‚™β‚™)
ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[0.48141208 0.080691345; -0.19250715 0.28334433; … ; 0.3114537 -0.08708531; -0.5137568 0.50625026], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[0.38209108 0.012740582 … -0.3584142 -0.3686378; -0.24421029 0.35249817 … -0.22306056 -0.058025397; … ; 0.36102444 0.086774394 … -0.1765904 -0.15196739; 0.32950535 0.35649836 … 0.2989066 0.17584707], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_4 = (weight = Float32[0.19966958 -0.5135044 … 0.16056709 0.47664255; 0.27101988 0.4461131 … 0.48264143 -0.21090077], bias = Float32[0.0; 0.0;;]))
callback(pinit, loss_neuralode(pinit)...; doplot=true)

We use Zygote.jl to handle Automatic Differentiation (reverse-mode AD), and Optimization.jl to set up the optimization problem for an objective function f, defined by:

$$\min_{u} f(u,p)$$

adtype = Optimization.AutoZygote()
Optimization.AutoZygote()
optf = OptimizationFunction(loss_neuralode, adtype)
(::SciMLBase.OptimizationFunction{true, Optimization.AutoZygote, typeof(Main.var"workspace#4".loss_neuralode), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}) (generic function with 1 method)
optprob = Optimization.OptimizationProblem(optf, pinit)
OptimizationProblem. In-place: true
u0: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[0.48141208 0.080691345; -0.19250715 0.28334433; … ; 0.3114537 -0.08708531; -0.5137568 0.50625026], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[0.38209108 0.012740582 … -0.3584142 -0.3686378; -0.24421029 0.35249817 … -0.22306056 -0.058025397; … ; 0.36102444 0.086774394 … -0.1765904 -0.15196739; 0.32950535 0.35649836 … 0.2989066 0.17584707], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_4 = (weight = Float32[0.19966958 -0.5135044 … 0.16056709 0.47664255; 0.27101988 0.4461131 … 0.48264143 -0.21090077], bias = Float32[0.0; 0.0;;]))

And we train! First with ADAM.

result_neuralode = Optimization.solve(optprob,
                                       ADAM(0.05),
                                       callback = callback,
                                       maxiters = 300)
u: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[-0.18703368 -0.08665115; 0.09643898 0.59104586; … ; 0.008400643 -0.36410508; -0.20857872 0.83780897], bias = Float32[-0.020561367; 0.22864722; … ; -0.21528973; 0.1539101;;]), layer_3 = (weight = Float32[0.06469259 -0.28483814 … -0.067857966 -0.65807956; 0.018308736 0.62485164 … 0.06786882 0.21999092; … ; 0.055926092 -0.21345933 … 0.11846514 -0.44960493; 0.045831285 0.06742562 … 0.57291824 -0.117072105], bias = Float32[-0.27243486; 0.21784204; … ; -0.30153865; -0.26715395;;]), layer_4 = (weight = Float32[-0.107740164 -0.7094275 … -0.15050524 0.7438748; -0.04830679 0.16073358 … 0.15324427 0.063451245], bias = Float32[-0.28312048; -0.2910977;;]))
callback(result_neuralode.u, loss_neuralode(result_neuralode.u)...; doplot=true)
optprob2 = remake(optprob,u0 = result_neuralode.u)
OptimizationProblem. In-place: true
u0: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[-0.18703368 -0.08665115; 0.09643898 0.59104586; … ; 0.008400643 -0.36410508; -0.20857872 0.83780897], bias = Float32[-0.020561367; 0.22864722; … ; -0.21528973; 0.1539101;;]), layer_3 = (weight = Float32[0.06469259 -0.28483814 … -0.067857966 -0.65807956; 0.018308736 0.62485164 … 0.06786882 0.21999092; … ; 0.055926092 -0.21345933 … 0.11846514 -0.44960493; 0.045831285 0.06742562 … 0.57291824 -0.117072105], bias = Float32[-0.27243486; 0.21784204; … ; -0.30153865; -0.26715395;;]), layer_4 = (weight = Float32[-0.107740164 -0.7094275 … -0.15050524 0.7438748; -0.04830679 0.16073358 … 0.15324427 0.063451245], bias = Float32[-0.28312048; -0.2910977;;]))

And then with BFGS (a quasi-Newton method that keeps an approximation to the Hessian, updating it at each step from the previous approximation and the observed change in the gradient).
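Concretely, writing sβ‚– for the step uβ‚–β‚Šβ‚ - uβ‚– and yβ‚– for the corresponding change in gradient, the standard BFGS update of the Hessian approximation Bβ‚– is:

$$B_{k+1} = B_k + \frac{y_k y_k^{\top}}{y_k^{\top} s_k} - \frac{B_k s_k s_k^{\top} B_k}{s_k^{\top} B_k s_k}$$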

result_neuralode2 = Optimization.solve(optprob2,
                                        Optim.BFGS(initial_stepnorm=0.01),
                                        callback=callback,
                                        allow_f_increases = false)
u: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[-0.31507474 -0.045717224; -0.02608984 0.5890759; … ; 0.031012857 -0.36588567; -0.22888936 0.8383275], bias = Float32[0.014799015; 0.2504964; … ; -0.17060904; 0.20097949;;]), layer_3 = (weight = Float32[0.064511605 -0.2907945 … -0.06590484 -0.66270703; 0.009276138 0.6127131 … 0.06101837 0.21883063; … ; 0.05731808 -0.20909072 … 0.11285106 -0.43947992; 0.06104287 0.08648598 … 0.57356703 -0.10599235], bias = Float32[-0.27328825; 0.23201706; … ; -0.29051417; -0.2493857;;]), layer_4 = (weight = Float32[-0.06532827 -0.66010374 … -0.14326413 0.74342334; -0.041659426 0.15407513 … 0.12824729 0.09886456], bias = Float32[-0.2034155; -0.2722975;;]))
callback(result_neuralode2.u, loss_neuralode(result_neuralode2.u)...; doplot=true)

Bonus: going Symbolic

Finally, if we are not satisfied with a not-very-interpretable Neural Network, we can use something like SymbolicRegression.jl to recover two functions that closely approximate the behaviour of the NN.

using SymbolicRegression, SymbolicUtils
ff(X) = first(dudtβ‚™β‚™(X, result_neuralode2.u, strβ‚™β‚™))
ff (generic function with 1 method)
Search_options = SymbolicRegression.Options(
    binary_operators=(+, *, /, -),
    unary_operators=(cos, exp),
    npopulations=20
)
Options(
    # Operators:
        binops=Function[+, *, /, -], unaops=Function[cos, exp],
    # Loss:
        loss=L2DistLoss,
    # Complexity Management:
        maxsize=20, maxdepth=20, bin_constraints=[(-1, -1), (-1, -1), (-1, -1), (-1, -1)], una_constraints=[-1, -1], use_frequency=true, use_frequency_in_tournament=true, parsimony=0.0032, warmup_maxsize_by=0.0, 
    # Search Size:
        npopulations=20, ncycles_per_iteration=550, npop=33, 
    # Migration:
        migration=true, hof_migration=true, fraction_replaced=0.00036, fraction_replaced_hof=0.035,
    # Tournaments:
        prob_pick_first=0.86, tournament_selection_n=12, topn=12, 
    # Constant tuning:
        perturbation_factor=0.076, probability_negate_constant=0.01, should_optimize_constants=true, optimizer_algorithm=BFGS, optimizer_probability=0.14, optimizer_nrestarts=2, optimizer_iterations=8,
    # Mutations:
        mutation_weights=SymbolicRegression.CoreModule.OptionsStructModule.MutationWeights(0.048, 0.47, 0.79, 5.1, 1.7, 0.002, 0.00023, 0.21, 0.0), crossover_probability=0.066, skip_mutation_failures=true
    # Annealing:
        annealing=false, alpha=0.1, 
    # Speed Tweaks:
        batching=false, batch_size=50, fast_cycle=false, 
    # Logistics:
        output_file=hall_of_fame.csv, verbosity=0, seed=nothing, progress=true,
    # Early Exit:
        early_stop_condition=nothing, timeout_in_seconds=nothing,
)
X_inputs = hcat([ [i,j] for i in Float32.(range(-2.0,2.0,100)) for j in Float32.(range(-2.0,2.0,100))]...)
2Γ—10000 Matrix{Float32}:
 -2.0  -2.0     -2.0      -2.0      …  2.0      2.0      2.0      2.0     2.0
 -2.0  -1.9596  -1.91919  -1.87879     1.83838  1.87879  1.91919  1.9596  2.0
y_outputs = mapslices(ff,X_inputs, dims = 1)
2Γ—10000 Matrix{Float32}:
  0.999807   1.00186    1.00431    1.00717   …  -1.46445    -1.46381    -1.46345
 -0.375392  -0.388972  -0.405392  -0.424926      0.0534068   0.0481486   0.0439227
hall_of_fame = EquationSearch(
    X_inputs,
    y_outputs,
    niterations=40,
    options=Search_options,
    parallelism=:multithreading
)
2-element Vector{HallOfFame{Float32}}:
 HallOfFame{Float32}(SymbolicRegression.PopMemberModule.PopMember{Float32}[PopMember{Float32}(-0.07428278, 1.0032f0, 0.8772836f0, 16709845251576330, 5076706588222721652, 5084970122937532414), PopMember{Float32}(cos(-1.64911), 1.0064179f0, 0.87729925f0, 16709845439236880, 7850576782106859271, 7636084086841349744), PopMember{Float32}((x2 / -1.2907078), 0.08517534f0, 0.06630101f0, 16709851607858870, 887379761439760799, 942695892763958296), PopMember{Float32}(cos(-20.38411 - x2), 0.10257155f0, 0.0787551f0, 16709854455731110, 7212984294470400568, 8039148132759834383), PopMember{Float32}(((x2 * -0.7747569) - 0.07427521), 0.08528681f0, 0.06078418f0, 16709845265261158, 7367202346146461500, 8838213112199601501), PopMember{Float32}((cos(-1.5319057 - x2) * 1.1738775), 0.0884158f0, 0.06072189f0, 16709847372526400, 1864866950780855736, 5660364932754329303), PopMember{Float32}((((-0.049507447 - x2) + -0.046363194) * 0.774757), 0.09168679f0, 0.06078417f0, 16709853181080970, 8305250481995377216, 1039816071626935425), PopMember{Float32}((cos(x2 * 0.43625915) * (-0.06777891 - x2)), 0.082472645f0, 0.049893435f0, 16709847987948680, 7521684404430242408, 6361211929706889548), PopMember{Float32}((cos(cos(x2 / 0.7017368)) * (-0.09597678 - x2)), 0.06350014f0, 0.030441867f0, 16709845442204720, 9160824584414961171, 3361581301686038635), PopMember{Float32}((cos(cos(x2 / cos(0.8159817))) * (-0.09597678 - x2)), 0.06723676f0, 0.03091263f0, 16709853069270850, 5207699387688815969, 9160824584414961171)  …  PopMember{Float32}((cos(cos((x2 - 0.1682796) / 0.7156774) * -1.0742114) * (-0.059516314 - x2)), 0.062153522f0, 0.018031267f0, 16709850767553000, 8798010999718922636, 7272956553658801524), PopMember{Float32}((cos(-0.055472262) * (cos(cos((0.18118364 - x2) / -0.69111735)) * (-0.055472262 - x2))), 0.066129796f0, 0.018712282f0, 16709853058862640, 6382948575759574930, 8366008705968491200), PopMember{Float32}((cos((cos((0.1651691 - x2) / 0.71373415) / 0.94113725) + 0.030790726) * (-0.07568518 - x2)), 0.06855005f0, 0.018028222f0, 16709848000891600, 2996957690937921374, 8123041194607217162), PopMember{Float32}((cos(0.28539377 / x2) * (cos(cos((0.17224659 - x2) / -0.69537413)) * (-0.069837265 - x2))), 0.06988191f0, 0.016389335f0, 16709851329461520, 9081805609971732856, 6981649544356303419), PopMember{Float32}((cos((cos((0.1651691 - x2) / 0.71373415) / (0.94113725 - 0.054117296)) + 0.030790726) * (-0.07568518 - x2)), 0.07463906f0, 0.017755397f0, 16709853320682040, 7898288797189064930, 4612302297109160375), PopMember{Float32}((cos(0.15058275 / (x2 * x2)) * (cos(cos((0.17224659 - x2) / 0.6906991)) * (-0.069837265 - x2))), 0.0748636f0, 0.015145071f0, 16709847808543730, 6410237777821918747, 6591478495493270255), PopMember{Float32}((cos(exp(1.0759605 / x2) * -0.07723609) * (cos(cos((0.15539327 - x2) / -0.6946569)) * (-0.08689842 - x2))), 0.0781866f0, 0.015252985f0, 16709850810418230, 7581669140362096293, 1829290525319419186), PopMember{Float32}((cos(0.13252431 / ((x2 - 0.19706355) * x2)) * (cos(cos((0.16643244 - x2) / 0.6862906)) * (-0.07448242 - x2))), 0.08073322f0, 0.014679781f0, 16709854329443250, 6072012253835945293, 4934239010082602252), PopMember{Float32}((cos(0.14240952 / exp(-0.1281856 / (x2 * 0.17009799))) * (cos(cos((0.16723333 - x2) / 0.693912)) * (-0.10651281 - x2))), 0.085010275f0, 0.015624663f0, 16709848578614270, 236753394310803842, 4584695452403089170), PopMember{Float32}(1.0, 0.0f0, Inf32, 16709845046976820, 6055331240097854964, -1)], Bool[1, 1, 1, 1, 1, 1, 1, 1, 1, 1  …  1, 1, 1, 1, 1, 1, 1, 1, 1, 0])
 HallOfFame{Float32}(SymbolicRegression.PopMemberModule.PopMember{Float32}[PopMember{Float32}(0.004151821, 1.0032f0, 0.273602f0, 16709845234779500, 8070853189669326107, 2635277088445305388), PopMember{Float32}(cos(-1.5483272), 1.0076267f0, 0.2739376f0, 16709845178162370, 5475386534062334964, 7605980807250228133), PopMember{Float32}((x1 * 0.36743218), 0.3384446f0, 0.08997254f0, 16709845234634320, 1872890424507990925, 3167078465170371481), PopMember{Float32}((x1 * exp(-0.97522247)), 0.34210995f0, 0.09009986f0, 16709845352402338, 1112272242661843186, 1712261236233863731), PopMember{Float32}(((x1 * 0.36743385) - -0.0041470435), 0.34478176f0, 0.089955345f0, 16709845234672400, 3440149990394048360, 5554081351348120563), PopMember{Float32}(((x1 + cos(x2)) * 0.32625407), 0.3332336f0, 0.08592022f0, 16709851874652060, 5938528166413979203, 4613045698859213061), PopMember{Float32}(((x1 * 0.16589196) * exp(cos(x2))), 0.31302765f0, 0.079516314f0, 16709845844643260, 3214912282003004927, 7518735306497207742), PopMember{Float32}((((cos(x2) + -0.38235453) + x1) * 0.35924438), 0.25893888f0, 0.06384199f0, 16709852116007220, 8725833963185887347, 4368249879606765909), PopMember{Float32}(((cos(x2 / cos(0.7712531)) + x1) * 0.3078712), 0.27282313f0, 0.06676522f0, 16709854255337530, 7529009169055405677, 2323972361638573193), PopMember{Float32}((((cos(x2) * (x1 - -0.51495445)) + x1) * 0.26199228), 0.2410858f0, 0.05720629f0, 16709848227196540, 2132428152190934385, 3146871906772545785)  …  PopMember{Float32}((((cos(x2) * ((cos(x1) + x1) + x2)) + x1) * 0.2564885), 0.20804597f0, 0.045539953f0, 16709852408871592, 7710337310754655154, 2250525985789501388), PopMember{Float32}(((x1 - (0.5018772 - (((x1 - -1.3673851) + x2) * cos(x2)))) * 0.2558748), 0.19180539f0, 0.040220972f0, 16709854186594360, 7624623670800037198, 7428845258692451161), PopMember{Float32}((((x1 - (1.0964998 - ((x1 - -1.4363981) * cos(x2)))) + cos(x1)) * 0.25581315), 0.18129107f0, 0.036468703f0, 16709848896985550, 2437010203115423985, 6953421183196286170), PopMember{Float32}((((x1 - (0.8995929 - ((cos(x2) + x1) * cos(x2)))) + cos(x1)) * 0.25581315), 0.18598628f0, 0.036877796f0, 16709848337604610, 6838466055978924860, 1258820144270218292), PopMember{Float32}(((x1 - (0.8929953 - ((((x1 - -1.414429) + x2) + cos(x1)) * cos(x2)))) * 0.25163484), 0.17756225f0, 0.03369744f0, 16709853266634560, 7950141200518708974, 3815182431900933944), PopMember{Float32}((((x1 - (0.8995929 - (((cos(x2) + x2) + x1) * cos(x2)))) + cos(x1)) * 0.25581315), 0.17817798f0, 0.03299038f0, 16709848971623180, 4297569995501528575, 6838466055978924860), PopMember{Float32}(((x1 - (0.8920709 - (((x1 - -1.7248862) + (cos(x1) - (x2 * x2))) * cos(x2)))) * 0.25167146), 0.15561232f0, 0.025940841f0, 16709853833463960, 2373154384291104908, 3536958506904971102), PopMember{Float32}((((x1 - (1.1415747 - ((((x1 + x2) + cos(x1)) + cos(x2)) * cos(x2)))) + 0.55560625) * 0.25959492), 0.15691113f0, 0.025420673f0, 16709849912677810, 3723772632457375661, 5454430744039521221), PopMember{Float32}(((x1 - (0.8929953 - ((((x1 - -1.726694) + x2) + (cos(x1) - (x2 * x2))) * cos(x2)))) * 0.25163484), 0.14735726f0, 0.021931188f0, 16709850634177970, 1234042143914524821, 1618882768450674886), PopMember{Float32}(1.0, 0.0f0, Inf32, 16709845046976850, 1631187884433385412, -1)], Bool[1, 1, 1, 1, 1, 1, 1, 1, 1, 1  …  1, 1, 1, 1, 1, 1, 1, 1, 1, 0])
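To get readable expressions out of the hall of fame, SymbolicRegression.jl exposes helpers such as calculate_pareto_frontier and node_to_symbolic. Their signatures have changed across releases, so treat the following as a sketch, assuming the 0.14-era form that takes the data and options explicitly, and looking only at the first output dimension:

# Pareto-optimal equations (best loss at each complexity) for the first output
dominating = calculate_pareto_frontier(X_inputs, y_outputs[1, :], hall_of_fame[1], Search_options)
# convert the most complex Pareto member into a SymbolicUtils expression
eqn = node_to_symbolic(dominating[end].tree, Search_options)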

Built with Julia 1.8.3 and

DiffEqFlux 1.52.0
DifferentialEquations 7.6.0
Lux 0.4.36
Optimization 3.10.0
OptimizationOptimJL 0.1.5
Plots 1.37.0
SymbolicRegression 0.14.5
SymbolicUtils 0.19.11

To run this tutorial locally, download [this file](/tutorials/02neuralode.jl) and open it with Pluto.jl.