3.8. Gradient-based parameter fitting via symbolic ODE Jacobians

In this tutorial, you will learn how to use PyRates’ symbolic Jacobian function generator, CircuitTemplate.get_jacobian_func, to fit model parameters with a gradient-based optimizer. The Jacobian function returned by PyRates is exact (computed symbolically via sympy.diff) and therefore lets us run analytical sensitivity analysis through the model dynamics. This is much more accurate, and often much cheaper, than the finite-difference Jacobian that gradient-based optimizers would otherwise build internally.

Throughout this example, we will use the quadratic integrate-and-fire (QIF) mean-field model of Montbrió, Pazó & Roxin [1], the same model that is used in the parameter continuation tutorial. Its mean-field equations read

\[ \begin{align}\begin{aligned}\tau \dot r &= \frac{\Delta}{\pi\tau} + 2 r v,\\\tau \dot v &= v^2 + \bar\eta + J\tau r - (\pi r \tau)^2,\end{aligned}\end{align} \]

with firing rate \(r\), mean membrane potential \(v\), and four parameters \(\tau, \bar\eta, \Delta, J\). We will fix \(\tau = \Delta = 1\) and recover the remaining two parameters \(\bar\eta\) and \(J\) from a synthetic firing-rate time series that was generated with known ground-truth values.

The optimisation routine of choice is SLSQP from scipy.optimize.minimize, which expects a callable Jacobian of the objective function (with respect to the free parameters). We compute that callable analytically by solving the forward sensitivity equations for the QIF model — a system that augments the original ODE with one extra state per free parameter and that requires only \(\partial f/\partial y\) (the symbolic ODE Jacobian returned by PyRates) and the analytical \(\partial f/\partial \theta_k\).

References

First, let’s import everything we’ll need:

import numpy as np
from scipy.optimize import minimize
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt
from pyrates import CircuitTemplate

3.8.1. Step 1: Compile the model and its Jacobian

PyRates provides two complementary functions that translate a CircuitTemplate into Python callables: get_run_func (the ODE right-hand side \(f(t, y, \theta)\)) and the new get_jacobian_func (the Jacobian \(\partial f / \partial y\) of that RHS). Each is generated exactly once and re-used at every step of the optimisation loop.

Both functions return a tuple of the form (func, args, arg_names, state_var_indices) where args contains the default numerical values of every model parameter and arg_names lets us look up the position of any specific parameter we want to vary.

circuit = CircuitTemplate.from_yaml("model_templates.neural_mass_models.qif.qif")
run_func, run_args, run_arg_names, sv_idx = circuit.get_run_func(
    func_name="qif_run", step_size=1e-3, solver="scipy",
    vectorize=False, in_place=False, clear=False,
)

# A fresh copy of the template is used for the Jacobian compilation so that the
# two backend states never interfere with each other.
circuit_jac = CircuitTemplate.from_yaml("model_templates.neural_mass_models.qif.qif")
jac_func, jac_args, jac_arg_names, _ = circuit_jac.get_jacobian_func(
    func_name="qif_jac", step_size=1e-3, solver="scipy",
    vectorize=False, in_place=False, clear=True,
)

We extract the constant model parameters (\(\tau, \Delta, I_{ext}\)) from the compiled arguments so we can keep them fixed during the fit. The two free parameters, \(\bar\eta\) and \(J\), will be passed at every call instead of being read from the default args.

TAU   = float(run_args[run_arg_names.index("p/qif_op/tau")])
DELTA = float(run_args[run_arg_names.index("p/qif_op/Delta")])
I_EXT = float(run_args[run_arg_names.index("p/qif_op/I_ext")])

3.8.2. Step 2: Wrappers around the compiled functions

The compiled functions take the parameter values as plain positional arguments, which makes them trivial to call inside scipy.integrate.solve_ivp. We wrap them in two short closures that expose only the variables we want to fit:

def qif_rhs(t, y, eta, J):
    """Return f(y; eta, J) — the QIF vector field at state y."""
    dy = np.zeros(2, dtype=np.float32)
    return np.asarray(
        run_func(np.float32(t), y.astype(np.float32), dy,
                 np.float32(TAU), np.float32(DELTA), np.float32(I_EXT),
                 np.float32(eta), np.float32(J)),
        dtype=float,
    )


def qif_jacobian(t, y, eta, J):
    r"""Return :math:`J_{\mathrm{ode}} = \partial f / \partial y` — the 2×2 ODE Jacobian.

    For the QIF model, :code:`get_jacobian_func` generates code equivalent to

    .. math::
        J_{\mathrm{ode}} = \begin{pmatrix} 2v/\tau & 2r/\tau \\
                                            J - 2\pi^2 r \tau^2/\tau & 2v/\tau
                                            \end{pmatrix}.
    """
    return np.asarray(
        jac_func(np.float32(t), y.astype(np.float32),
                 np.float32(J), np.float32(TAU), np.float32(DELTA), np.float32(I_EXT),
                 np.float32(eta)),
        dtype=float,
    )

3.8.3. Step 3: Generate synthetic target data

We pick ground-truth parameters \(\bar\eta_\star = -2.0\) and \(J_\star = 15.0\), simulate the QIF mean-field from an initial state in the active-firing regime, and treat the resulting firing-rate trajectory as our “experimental” recording. A shorter simulation window (\(T = 30\) ms) is used because the transient relaxation contains the most information for parameter identification — the long-time steady state alone is degenerate in \((\bar\eta, J)\).

T       = 30.0
T_EVAL  = np.linspace(0.0, T, 300)
Y0      = np.array([0.8, -0.5])                 # active-r initial state

ETA_TARGET, J_TARGET = -2.0, 15.0               # ground truth
ETA_INIT,   J_INIT   = -3.5, 10.0               # initial guess for SLSQP

sol_target = solve_ivp(
    qif_rhs, (0.0, T), Y0.copy(),
    args=(ETA_TARGET, J_TARGET),
    t_eval=T_EVAL, method="RK45", max_step=0.5, rtol=1e-6, atol=1e-9,
)
r_target = sol_target.y[0]                      # only firing rate is observed

3.8.4. Step 4: Forward sensitivity equations

This is the core of the example. Given the parametric ODE

\[\dot y = f(y, \theta),\]

the sensitivity \(s_k(t) = \partial y(t) / \partial \theta_k\) of the trajectory with respect to the \(k\)-th parameter satisfies the linear ODE

\[\dot s_k = \frac{\partial f}{\partial y} s_k + \frac{\partial f}{\partial \theta_k}, \qquad s_k(0) = 0.\]

The Jacobian \(\partial f / \partial y\) is provided directly by get_jacobian_func. The parameter Jacobians \(\partial f / \partial \theta_k\) are simple closed-form expressions for the QIF model — reading them off the right-hand side of the \(\dot v\)-equation,

\[\begin{split}\frac{\partial f}{\partial \bar\eta} = \begin{pmatrix} 0 \\ 1/\tau \end{pmatrix}, \qquad \frac{\partial f}{\partial J} = \begin{pmatrix} 0 \\ r \end{pmatrix}.\end{split}\]

We then integrate the augmented ODE \((y, s_{\bar\eta}, s_J) \in \mathbb{R}^6\) in a single solve_ivp call:

DF_DETA = np.array([0.0, 1.0 / TAU])             # constant: ∂f/∂η

def augmented_rhs(t, y_aug, eta, J):
    y     = y_aug[:2]
    s_eta = y_aug[2:4]
    s_J   = y_aug[4:6]

    dy    = qif_rhs(t, y, eta, J)
    J_ode = qif_jacobian(t, y, eta, J)            # from get_jacobian_func

    df_dJ = np.array([0.0, y[0]])                 # depends on current r

    ds_eta = J_ode @ s_eta + DF_DETA
    ds_J   = J_ode @ s_J   + df_dJ

    return np.concatenate([dy, ds_eta, ds_J])

3.8.5. Step 5: Objective function with analytical gradient

Our objective is the root-mean-square error between the simulated and target firing-rate trajectories,

\[L(\bar\eta, J) = \sqrt{\frac{1}{N}\sum_{n=1}^N (r(t_n) - r^\star(t_n))^2}.\]

Differentiating once with the chain rule yields

\[\frac{\partial L}{\partial \theta_k} = \frac{1}{L}\, \overline{(r - r^\star) \, s_k^{(r)}},\]

where \(s_k^{(r)} = \partial r / \partial \theta_k\) is the first row of the sensitivity matrix and \(\overline{\cdot}\) is the time average. This means that one forward pass of the augmented ODE delivers both \(L\) and \(\nabla L\):

def loss_and_grad(params):
    eta, J = float(params[0]), float(params[1])

    y_aug0 = np.zeros(6)
    y_aug0[:2] = Y0

    sol = solve_ivp(
        augmented_rhs, (0.0, T), y_aug0,
        args=(eta, J),
        t_eval=T_EVAL, method="RK45", max_step=0.5, rtol=1e-6, atol=1e-9,
    )
    r_sim   = sol.y[0]
    s_eta_r = sol.y[2]                  # ∂r/∂η
    s_J_r   = sol.y[4]                  # ∂r/∂J

    diff = r_sim - r_target
    rmse = float(np.sqrt(np.mean(diff ** 2)))

    if rmse > 1e-12:
        dL_deta = float(np.mean(diff * s_eta_r)) / rmse
        dL_dJ   = float(np.mean(diff * s_J_r))   / rmse
    else:
        dL_deta = dL_dJ = 0.0

    return rmse, np.array([dL_deta, dL_dJ])

3.8.6. Step 6: Sanity check — analytical vs. finite-difference gradient

Before handing the gradient to SLSQP, it is good practice to verify it against a forward finite-difference reference. For a correctly derived Jacobian the relative error should be in the same ball-park as the FD truncation error itself — well below 1 % for an \(\epsilon\) of \(10^{-4}\).

eps  = 1e-4
L0, g_ana = loss_and_grad(np.array([ETA_INIT, J_INIT]))
L_e1, _   = loss_and_grad(np.array([ETA_INIT + eps, J_INIT]))
L_e2, _   = loss_and_grad(np.array([ETA_INIT, J_INIT + eps]))
g_fd = np.array([(L_e1 - L0) / eps, (L_e2 - L0) / eps])
rel_err = np.linalg.norm(g_ana - g_fd) / (np.linalg.norm(g_fd) + 1e-12)

print(f"L(x0) = {L0:.6f}")
print(f"  analytical:       dL/dη = {g_ana[0]:+.6f},  dL/dJ = {g_ana[1]:+.6f}")
print(f"  finite difference: dL/dη = {g_fd[0]:+.6f},  dL/dJ = {g_fd[1]:+.6f}")
print(f"  relative error:    {rel_err:.2e}")

3.8.7. Step 7: Run SLSQP

Because loss_and_grad returns the gradient alongside the objective value, we pass jac=True to scipy.optimize.minimizeminimize will then unpack the tuple at every call. Box constraints keep the optimiser inside a physically sensible region:

param_history = [np.array([ETA_INIT, J_INIT])]
result = minimize(
    loss_and_grad,
    x0=np.array([ETA_INIT, J_INIT]),
    method="SLSQP",
    jac=True,
    bounds=[(-7.0, 0.0), (4.0, 22.0)],
    callback=lambda xk: param_history.append(xk.copy()),
    options={"maxiter": 80, "ftol": 1e-10, "disp": False},
)

eta_fit, J_fit = result.x
print(f"\nOptimisation finished in {result.nit} iterations")
print(f"  Target:    η = {ETA_TARGET:.4f},  J = {J_TARGET:.4f}")
print(f"  Recovered: η = {eta_fit:.4f},  J = {J_fit:.4f}")
print(f"  Final RMSE: {result.fun:.2e}")

3.8.8. Step 8: Visualise the result

Three panels summarise the run: the simulated firing-rate trajectory before and after fitting (left), the path taken through parameter space by SLSQP (middle), and the agreement between the analytical Jacobian-derived gradient and a numerical reference along an \(\bar\eta\)-slice through the loss landscape (right).

sol_init = solve_ivp(qif_rhs, (0.0, T), Y0.copy(), args=(ETA_INIT, J_INIT),
                     t_eval=T_EVAL, method="RK45", max_step=0.5)
sol_fit  = solve_ivp(qif_rhs, (0.0, T), Y0.copy(), args=(eta_fit, J_fit),
                     t_eval=T_EVAL, method="RK45", max_step=0.5)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

ax = axes[0]
ax.plot(T_EVAL, r_target,       color="C0", lw=2,                label=f"target (η={ETA_TARGET}, J={J_TARGET})")
ax.plot(T_EVAL, sol_init.y[0],  color="C3", lw=1.5, ls="--",      label=f"initial (η={ETA_INIT}, J={J_INIT})")
ax.plot(T_EVAL, sol_fit.y[0],   color="C2", lw=1.5, ls=":",       label=f"fitted (η={eta_fit:.2f}, J={J_fit:.2f})")
ax.set_xlabel("time (ms)")
ax.set_ylabel("firing rate r")
ax.set_title("QIF firing-rate dynamics")
ax.legend(fontsize=8)

param_arr = np.array(param_history)
ax = axes[1]
sc = ax.scatter(param_arr[:, 0], param_arr[:, 1],
                c=np.arange(len(param_arr)), cmap="plasma", s=40, zorder=3)
ax.plot(param_arr[:, 0], param_arr[:, 1], "k-", lw=0.8, alpha=0.5)
ax.scatter([ETA_TARGET], [J_TARGET], marker="*", s=200, color="C0", zorder=5, label="target")
ax.scatter([ETA_INIT],   [J_INIT],   marker="o", s=80,  color="C3", zorder=5, label="initial")
ax.scatter([eta_fit],    [J_fit],    marker="s", s=80,  color="C2", zorder=5, label="fitted")
plt.colorbar(sc, ax=ax, label="iteration")
ax.set_xlabel(r"$\bar\eta$")
ax.set_ylabel("J")
ax.set_title("SLSQP trajectory in parameter space")
ax.legend(fontsize=8)

ax = axes[2]
etas_scan = np.linspace(-6.0, 0.0, 15)
g_ana_eta, g_fd_eta = [], []
for e in etas_scan:
    L,  ga = loss_and_grad(np.array([e, J_TARGET]))
    Lp, _  = loss_and_grad(np.array([e + eps, J_TARGET]))
    g_ana_eta.append(ga[0])
    g_fd_eta.append((Lp - L) / eps)
ax.plot(etas_scan, g_ana_eta, "C0-o", ms=4, label="analytical (Jacobian)")
ax.plot(etas_scan, g_fd_eta,  "C3--s", ms=4, label="finite difference")
ax.axvline(ETA_TARGET, color="gray", lw=1, ls=":")
ax.axhline(0, color="gray", lw=0.8)
ax.set_xlabel(r"$\bar\eta$")
ax.set_ylabel(r"$\partial L / \partial \bar\eta$")
ax.set_title("Gradient validation (J = J_target)")
ax.legend(fontsize=8)

plt.tight_layout()
plt.show()

3.8.9. Final remarks

The example demonstrates a tight numerical loop made possible by PyRates’ symbolic Jacobian:

  • get_run_func and get_jacobian_func are each invoked once and reused for every loss-and-gradient evaluation.

  • A single solve_ivp pass over the 6-D augmented system produces both the loss value and its analytical gradient — there is no need for the optimiser to evaluate the model \(2 N_\theta + 1\) times per step as it would with a finite-difference Jacobian.

  • The analytical gradient agrees with a finite-difference reference to better than \(10^{-3}\) relative error, and SLSQP recovers both parameters to within machine precision in roughly 30 iterations (final RMSE \(< 10^{-6}\)).

The same template extends cleanly to higher-dimensional parameter vectors: just add one more 2-vector \(s_k\) and one more analytical \(\partial f / \partial \theta_k\) per free parameter. For models where the parameter Jacobians cannot be written by hand, finite differences in \(\theta\) may still be cheaper than finite differences in the full optimisation loop, because each \(\partial f / \partial \theta_k\) is evaluated pointwise rather than along an entire trajectory.

Gallery generated by Sphinx-Gallery