Source code for pyrates.utility

# -*- coding: utf-8 -*-
#
#
# PyRates software framework for flexible implementation of neural
# network model_templates and simulations. See also:
# https://github.com/pyrates-neuroscience/PyRates
#
# Copyright (C) 2017-2018 the original authors (Richard Gast and
# Daniel Rose), the Max-Planck-Institute for Human Cognitive Brain
# Sciences ("MPI CBS") and contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>
#
# CITATION:
#
# Richard Gast and Daniel Rose et. al. in preparation
"""Functions for performing parameter grid simulations with pyrates model_templates.
"""

# meta infos
__author__ = "Christoph Salomon, Richard Gast"
__status__ = "development"

# external imports
from typing import Union, Optional
import pandas as pd
import numpy as np
from copy import deepcopy

# pyrates imports
from pyrates.frontend import CircuitTemplate, OperatorTemplate, template
from pyrates.ir import clear_ir_caches


#####################################
# temporary file clearing functions #
#####################################


[docs]def clear(model: CircuitTemplate, **kwargs): """Function that clears all temporary files and caches that have been created via a PyRates model. Parameters ---------- model Instance of a `CircuitTemplate`. kwargs Additional keyword arguments to be passed to `clear_frontend_caches`. """ try: model.clear() except AttributeError: pass clear_frontend_caches(**kwargs)
[docs]def clear_frontend_caches(clear_template_cache=True, clear_ir_cache=True): """Utility to clear caches in the frontend. Parameters ---------- clear_template_cache toggles whether or not to clear the template_cache that contains all previously loaded templates clear_ir_cache toggles whether or not to clear the cache of unique OperatorIR instances """ if clear_template_cache: template.clear_cache() if clear_ir_cache: OperatorTemplate.cache.clear() clear_ir_caches()
######################## # simulation functions # ########################
[docs]def integrate(circuit: Union[str, CircuitTemplate], **kwargs): """Directly simulate dynamics of a circuit.""" if type(circuit) is str: circuit = CircuitTemplate.from_yaml(circuit) results = circuit.run(**kwargs) if 'clear' in kwargs and kwargs['clear']: clear_frontend_caches() return results
############################# # parameter sweep functions # #############################
[docs]def linearize_grid(grid: dict, permute: bool = False) -> pd.DataFrame: """Turns the grid into a grid that can be traversed linearly, i.e. pairwise. Parameters ---------- grid Parameter grid. permute If true, all combinations of the parameter values in grid will be created. Returns ------- pd.DataFrame Resulting linear grid in form of a data frame. """ arg_lengths = [len(arg) for arg in grid.values()] if len(list(set(arg_lengths))) == 1 and not permute: return pd.DataFrame(grid) elif permute: vals, keys = [], [] for key, val in grid.items(): vals.append(val) keys.append(key) new_grid = np.stack(np.meshgrid(*tuple(vals)), -1).reshape(-1, len(grid)) return pd.DataFrame(new_grid, columns=keys) else: raise ValueError('Wrong number of parameter combinations. If `permute` is False, all parameter vectors in grid ' 'must have the same number of elements.')
[docs]def adapt_circuit(circuit: Union[CircuitTemplate, str], params: dict, param_map: dict) -> CircuitTemplate: """Changes the parametrization of a circuit. Parameters ---------- circuit Circuit instance. params Key-value pairs of the parameters that should be changed. param_map Map between the keys in params and the circuit variables. Returns ------- CircuitIR Updated circuit instance. """ if type(circuit) is str: circuit = deepcopy(CircuitTemplate.from_yaml(circuit)) else: circuit = deepcopy(circuit) node_updates = {} edge_updates = [] for key in params.keys(): val = params[key] if 'nodes' in param_map[key]: for n in param_map[key]['nodes']: for v in param_map[key]['vars']: node_updates[f"{n}/{v}"] = val else: edges = param_map[key]['edges'] if len(edges[0]) < 3: for source, target in edges: for var in param_map[key]['vars']: edge = circuit.get_edge(source=source, target=target, idx=0) edge_updates.append((edge[0], edge[1], {var: val})) else: for source, target, idx in edges: for var in param_map[key]['vars']: edge = circuit.get_edge(source=source, target=target, idx=idx) edge_updates.append((edge[0], edge[1], {var: val})) return circuit.update_var(node_vars=node_updates, edge_vars=edge_updates)
[docs]class Interactive2DParamPlot(object): def __init__(self, data_map: np.array, data_series: pd.DataFrame, x_values: np.array, y_values: np.array, x_key: str, y_key: str, param_map: pd.DataFrame, tmin=0., title=None, **kwargs): """Creates an interactive 2D plot that allows visualization of time series using button press events Derive child class and change get_data() respectively to utilize this plotting method Parameters ---------- data_map 2D ndarray containing a value based on each column data_series, respectively. data_series DataFrame containing all data series used to create the data map x_values ndarray containing values of data-map columns. y_values ndarray containing values of data-map rows. x_key y_key param_map Dataframe containing the mapping between data-series columns (index) and x/y value pairs (columns) (as returned by `grid_search`). tmin Starting point for time-series plots in time units (float). title Title of 2D plot. kwargs Additional information to access a column in data_series if necessary Returns ------- """ import matplotlib.pyplot as plt dt = kwargs.pop('step_size', data_series.index[1] - data_series.index[0]) state_var = kwargs.pop('state_var', '') tmin = int(tmin/dt) self.data = data_series.iloc[tmin:, :] self.x_values = x_values self.y_values = y_values # set up param map matrix self.map = pd.DataFrame(columns=self.x_values, index=self.y_values) for key in param_map.index: x, y = param_map.loc[key, x_key], param_map.loc[key, y_key] x_idx = np.argmin(np.abs(self.x_values - x)) y_idx = np.argmin(np.abs(self.y_values - y)) self.map.iloc[y_idx, x_idx] = key # Create canvas if 'subplots' in kwargs: self.fig, self.ax = kwargs.pop('subplots') else: skwargs = {} for key in ["figsize", "gridspec"]: if key in kwargs: skwargs[key] = kwargs.pop(key) self.fig, self.ax = plt.subplots(ncols=2, nrows=1, **skwargs) # plot signals into right subplot cmap = kwargs.pop("cmap_lines", "magma") self.line_colors = plt.get_cmap(cmap, lut=self.data[self.map.iloc[0, 0]].shape[-1]).colors self.update_lineplot(0, 0) # Initiate marker self.marker = self.ax[0].plot(0, 0, 'x', color='white', markersize='10') # Plot 2D data in left subplot num_x_ticks = kwargs.pop('num_x_ticks', len(x_values)) num_y_ticks = kwargs.pop('num_y_ticks', len(y_values)) shrink = kwargs.pop('cbar_shrink', 0.5) im = self.ax[0].imshow(data_map, **kwargs) set_num_axis_ticks(ax=self.ax[0], num_x_ticks=num_x_ticks, num_y_ticks=num_y_ticks) self.ax[0].set_xlabel(x_key) self.ax[0].set_ylabel(y_key) self.ax[0].set_xticklabels([""] + list(np.round(x_values, decimals=2))) self.ax[0].set_yticklabels([""] + list(np.round(y_values, decimals=2))) if title: self.ax[0].set_title(title) plt.colorbar(im, ax=self.ax[0], shrink=shrink) # Call Interactive2DPlot class instance when mouse button is pressed inside the 2D plot self.fig.canvas.mpl_connect('button_press_event', self)
[docs] def __call__(self, event): """Try to access a column in data_series using x and y values based on cursor position Is called on mouse button press event. Converts the current cursor coordinates inside the plot into x and y values based on the data in x_values and y_values. x and y values are used to access a column in data_series. Access of data_series can be customized in self.get_data(). :param event: :return: """ # Only allow button press events in the (left) 2D overview plot if event.inaxes != self.ax[0]: return # Reset axes self.marker[0].remove() # Transform cursor coordinates in x and y values x_sample = event.xdata y_sample = event.ydata # Add marker at event coordinates self.marker = self.ax[0].plot(x_sample, y_sample, 'x', color='white', markersize='10') # Update serial plot self.update_lineplot(int(np.round(x_sample, decimals=0)), int(np.round(y_sample, decimals=0))) # redraw figure self.fig.canvas.draw() self.fig.canvas.flush_events()
[docs] def update_lineplot(self, x, y): data = self.get_data(x, y) lines = self.ax[1].get_lines() if lines: for line, key in zip(self.ax[1].get_lines(), data.keys()): line.set_data(data.index, data[key].values) else: for i, key in enumerate(data.keys()): self.ax[1].plot(data.index, data[key].values, c=self.line_colors[i]) self.ax[1].set_title( f'x: {np.round(self.x_values[x], decimals=2)}, y: {np.round(self.y_values[y], decimals=2)}') ymin, ymax = np.min(data.values), np.max(data.values) margin = (ymax - ymin)*0.02 self.ax[1].set_ylim([ymin-margin, ymax+margin])
[docs] def set_map_xlabel(self, label): self.ax[0].set_xlabel(label)
[docs] def set_map_ylabel(self, label): self.ax[0].set_ylabel(label)
[docs] def set_map_title(self, title): self.ax[0].set_title(title)
[docs] def set_series_xlabel(self, label): self.ax[1].set_xlabel(label)
[docs] def set_series_ylabel(self, label): self.ax[1].set_ylabel(label)
[docs] def get_data(self, x, y): return self.data[self.map.iloc[y, x]]
[docs]def set_num_axis_ticks(ax, num_x_ticks=10, num_y_ticks=10): """Set the number of x and y ticks of a plot axis Parameters ---------- ax num_x_ticks num_y_ticks Returns ------- """ from matplotlib.ticker import MaxNLocator ax.xaxis.set_major_locator(MaxNLocator(num_x_ticks)) ax.yaxis.set_major_locator(MaxNLocator(num_y_ticks))
########################## # file storage functions # ##########################
[docs]def create_directory(path): """check if a directory exists and create it otherwise""" import os import errno directory = os.path.dirname(path) if directory and not os.path.exists(directory): try: os.makedirs(directory) except OSError as exc: # Guard against race condition if exc.errno != errno.EEXIST: raise