3.4. Differential Evolution
In this tutorial, you will learn how to optimize PyRates models via the differential evolution strategy introduced in [1]. It is based on the same model and the same parameter as the single parameter grid search example, so it is worthwhile to have a look at that example before proceeding.
In short, we will use the Jansen-Rit model [2] (check out the model introduction for the Jansen-Rit model to learn about the mathematics behind the model and about its implementation in PyRates). We will perform a 1D evolutionary optimization over its connectivity scaling parameter \(C\). This parameter scales all synaptic strengths in the Jansen-Rit neural mass model and has a critical influence on its behavior [2].
The general idea behind evolutionary model optimization strategies is to optimize a set of model parameters with respect to some objective function that defines the fitness of a certain model parametrization. Initially, a number of different model parametrizations are sampled from a defined parameter space. Then, the iterative optimization starts. At each iteration of the optimization algorithm, the following two steps are performed:
1. The objective function is evaluated for each parametrization, resulting in a fitness value for each model parametrization.
2. The fitness values are used to sample new model parametrizations, mutate the old parametrizations, or create new parametrizations via combinations of the old parametrizations.
These iterations are then repeated until a fitness criterion is reached. The exact functional relationships that are used to translate fitness values and old parametrizations into new parametrizations depend on the type of evolutionary model optimization strategy that is used. For a summary of differential evolution, have a look at this article. Below, we show how PyRates can be used in combination with scipy.optimize.differential_evolution to perform such parameter optimizations.
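The two-step loop described above can be sketched in a few lines of NumPy. This is only a minimal illustration of one classic variant (often called rand/1/bin), not the implementation used by SciPy and not the randtobest1bin strategy used later in this tutorial. The function name toy_differential_evolution, the parameters F and CR, and the quadratic toy objective are assumptions made for the example.

```python
import numpy as np

def toy_differential_evolution(objective, bounds, popsize=10, maxiter=100,
                               F=0.8, CR=0.7, seed=0):
    """Minimal rand/1/bin differential evolution sketch (illustration only)."""
    rng = np.random.default_rng(seed)
    lower, upper = np.asarray(bounds, dtype=float).T
    dim = lower.size
    # sample the initial parametrizations uniformly from the parameter space
    pop = rng.uniform(lower, upper, size=(popsize, dim))
    fitness = np.array([objective(p) for p in pop])
    for _ in range(maxiter):
        for i in range(popsize):
            # mutation: combine three distinct members other than member i
            others = [j for j in range(popsize) if j != i]
            a, b, c = pop[rng.choice(others, size=3, replace=False)]
            mutant = np.clip(a + F * (b - c), lower, upper)
            # binomial crossover: take at least one entry from the mutant
            cross = rng.random(dim) < CR
            cross[rng.integers(dim)] = True
            trial = np.where(cross, mutant, pop[i])
            # selection: keep the trial only if it is at least as fit
            f_trial = objective(trial)
            if f_trial <= fitness[i]:
                pop[i], fitness[i] = trial, f_trial
    best = np.argmin(fitness)
    return pop[best], fitness[best]

# toy objective with a known minimum at x = 135 (hypothetical target value)
best_x, best_f = toy_differential_evolution(
    lambda x: float((x[0] - 135.0) ** 2), bounds=[(30.0, 300.0)])
print(best_x, best_f)
```

In this sketch a lower objective value means a fitter parametrization, which matches the loss-minimization convention of scipy.optimize.differential_evolution.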
First, let’s import the differential_evolution function from scipy together with other essential packages:
from scipy.optimize import differential_evolution
from pyrates.frontend import CircuitTemplate
import matplotlib.pyplot as plt
import numpy as np
3.4.1. Step 1: Definition of the optimization details
(1) To optimize our parameter \(C\), we will have to define the parameter boundaries within which the optimization should be performed:
boundaries = (30.0, 300.0)
(2) Furthermore, we need to define the model and the model parameter that \(C\) refers to:
model_template = "model_templates.neural_mass_models.jansenrit.JRC2"
model_param = 'jrc/jrc_op/c'
(3) Finally, we have to define the objective function that should be optimized. This objective function always needs to calculate a scalar fitness, based on model output. Thus, we first define the model output:
output = {'V_pce': 'jrc/jrc_op/V_e', 'V_pci': 'jrc/jrc_op/V_i'}
…and then the objective function:
def loss(data, min_amp=6e-3, max_amp=10e-3):
    """Calculates the difference between the value range in the data and the
    range defined by min_amp and max_amp.
    """
    # calculate the membrane potential of the PC population
    data = data.loc[:, 'V_pce'] - data.loc[:, 'V_pci']
    # calculate the difference between the membrane potential range
    # of the model and the target membrane potential range
    data_bounds = np.asarray([np.min(data), np.max(data)]).squeeze()
    target_bounds = np.asarray([min_amp, max_amp])
    diff = data_bounds - target_bounds
    # return the sum of the squared errors
    return diff @ diff.T
The value of this loss function depends on the minimum and the maximum value of the average membrane potential of the pyramidal cell (PC) population of the Jansen-Rit model [2]. Depending on the min_amp and max_amp arguments of that function, the differential evolution algorithm should optimize the parameter \(C\) of our model such that the minimum and maximum membrane potential fluctuations of the PC population are as close to those values as possible. Therefore, this function should suffice to find model parametrizations that express oscillatory behavior with different oscillation amplitudes or non-oscillatory behavior.
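To get a feeling for how this loss behaves, the following sketch applies the same range-based error to two synthetic signals. The helper range_loss is a hypothetical stand-in that mirrors the last lines of loss for a plain 1D array (it skips the subtraction of the two output columns), and both signals are made up for illustration; neither comes from a PyRates simulation.

```python
import numpy as np

def range_loss(v_pc, min_amp=6e-3, max_amp=10e-3):
    """Sum of squared errors between the observed [min, max] of a signal
    and the target range [min_amp, max_amp] (hypothetical helper)."""
    diff = np.array([np.min(v_pc), np.max(v_pc)]) - np.array([min_amp, max_amp])
    return float(diff @ diff)

t = np.linspace(0.0, 1.0, 1001)
# synthetic signal oscillating between 6 mV and 10 mV (10 Hz sine wave)
oscillating = 8e-3 + 2e-3 * np.sin(2 * np.pi * 10 * t)
# synthetic non-oscillatory signal, constant at 8 mV
flat = np.full_like(t, 8e-3)

loss_osc = range_loss(oscillating)
loss_flat = range_loss(flat)
print(loss_osc, loss_flat)
```

The oscillating signal matches the target range and yields a loss close to zero, whereas the flat signal misses both bounds by 2 mV and is penalized accordingly.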
As a final preparation, we need to define a function that takes a single value of \(C\) as an argument, calculates the model activity, and returns the loss for that particular parametrization:
def eval_param(value, T, dt, kwargs):
    # load template and update parameter value
    jrc = CircuitTemplate.from_yaml(model_template)
    jrc.update_var({model_param: value})
    # perform simulation
    data = jrc.run(outputs=output, simulation_time=T, step_size=dt, clear=True, verbose=False, **kwargs)
    # calculate and return loss
    return loss(data)
3.4.2. Step 2: Performing the model optimization
Now that we have prepared everything to start the optimization, we can use scipy.optimize.differential_evolution to optimize our parameter of interest:
# simulation settings
T = 5.0
dt = 1e-4
kwargs = {'sampling_step_size': 1e-3, 'solver': 'scipy', 'backend': 'default', 'cutoff': 2.0}
# perform optimization
results = differential_evolution(eval_param, bounds=[boundaries], args=(T, dt, kwargs), strategy='randtobest1bin',
                                 maxiter=100, popsize=10, tol=1e-4)
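Since each call to eval_param runs a full simulation, it can help to first try the same call pattern on a cheap stand-in objective. The sketch below assumes a hypothetical toy_objective with a known minimum at 135 and reuses the solver settings from above. It illustrates that SciPy passes each candidate as a 1D array with one entry per (min, max) pair in bounds, and forwards the entries of args as additional positional arguments.

```python
from scipy.optimize import differential_evolution

def toy_objective(x, target, scale):
    # x is a 1D array with one entry per bounded parameter;
    # target and scale are forwarded from the args tuple
    return scale * (x[0] - target) ** 2

result = differential_evolution(toy_objective, bounds=[(30.0, 300.0)], args=(135.0, 1.0),
                                strategy='randtobest1bin', maxiter=100, popsize=10, tol=1e-4)

# the returned OptimizeResult exposes, among others, these fields
print(result.x)        # best parameter vector found (array of length 1 here)
print(result.fun)      # objective value at result.x
print(result.nit)      # number of iterations performed
print(result.message)  # reason for termination
```

Because differential evolution is stochastic, the exact numbers differ between runs unless the random number generator is fixed.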
The return value of the differential_evolution function contains the winning parameter set and its loss function value. Let’s check whether this model parameter indeed produces the behavior we optimized for:
# check out final value of C
c = results.x
print(f'C = {c}')
# calculate dynamics of the JRC for C
jrc = CircuitTemplate.from_yaml(model_template)
jrc = jrc.update_var(node_vars={model_param: c})
kwargs.pop('cutoff')
results = jrc.run(simulation_time=T, step_size=dt, outputs=output, clear=True, **kwargs)
# visualization of the JRC dynamics
data = results['V_pce'] - results['V_pci']
plt.plot(data)
plt.show()
As can be seen, the model shows oscillatory behavior with minimum and maximum membrane potential amplitudes that are close to our target values of 6.0 and 10.0 mV.