# -*- coding: utf-8 -*-
#
#
# PyRates software framework for flexible implementation of neural
# network model_templates and simulations. See also:
# https://github.com/pyrates-neuroscience/PyRates
#
# Copyright (C) 2017-2018 the original authors (Richard Gast and
# Daniel Rose), the Max-Planck-Institute for Human Cognitive Brain
# Sciences ("MPI CBS") and contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>
#
# CITATION:
#
# Richard Gast and Daniel Rose et. al. in preparation
"""
"""
# external imports
import time
from typing import Union, Dict, Iterator, Optional, List, Tuple
from warnings import filterwarnings
from networkx import MultiDiGraph, DiGraph
import numpy as np
from copy import deepcopy
from warnings import warn
# pyrates-internal imports
from pyrates.backend import PyRatesException, PyRatesWarning
from pyrates.ir.node import NodeIR
from pyrates.ir.edge import EdgeIR
from pyrates.ir.abc import AbstractBaseIR
from pyrates.backend.parser import parse_equations, get_unique_label, replace
from pyrates.backend.computegraph import ComputeGraph, ComputeVar, ComputeGraphBackProp
__author__ = "Daniel Rose, Richard Gast"
__status__ = "Development"
in_edge_indices = {} # cache for the number of input edges per network node
in_edge_vars = {} # cache for the input variables that enter at each target operator
#####################
# class definitions #
#####################
# networkx-based representation of all nodes and edges in circuit
class NetworkGraph(AbstractBaseIR):
    """View on the entire network as a graph. Translates edge operations and attributes into a form that allows to
    parse the network graph into a final compute graph."""

    def __init__(self, label: str = "circuit", nodes: Dict[str, NodeIR] = None, edges: list = None,
                 template: str = None, step_size: float = 1e-3, step_size_adaptation: bool = True, verbose: bool = True,
                 **kwargs):
        """Build the networkx-based representation of the model and pre-process all edges.

        Parameters
        ----------
        label
            Name of the circuit.
        nodes
            Dictionary of form `{node_label: NodeIR instance}`.
        edges
            List of `(source, target, edge_dict)` tuples, passed on to `add_edge`.
        template
            Optional reference to the template this circuit was loaded from.
        step_size
            Integration step size, used to discretize edge delays.
        step_size_adaptation
            If True, delays are kept continuous (for adaptive-step solvers); otherwise they are
            discretized in units of `step_size`.
        verbose
            If True, print compilation progress to the terminal.
        kwargs
            May contain `dde_approx`, `matrix_sparseness` and `vectorized` options for the edge
            pre-processing.
        """
        super().__init__(label=label, template=template)
        self._edge_idx_counter = 0  # counter used to create unique edge-index variable names
        self.step_size = step_size
        self.step_size_adaptation = step_size_adaptation
        self.graph = MultiDiGraph()
        if verbose:
            print("Compilation Progress")
            print("--------------------")
            print('\t(1) Translating the circuit template into a networkx graph representation...')
        # add nodes to graph
        if nodes:
            nodes = ((key, {"node": node}) for key, node in nodes.items())
            self.graph.add_nodes_from(nodes)
        # add edges to graph
        if edges:
            for (source, target, edge_dict) in edges:
                self.add_edge(source, target, **edge_dict)
        if verbose:
            print('\t\t...finished.')
            print("\t(2) Preprocessing edge transmission operations...")
        # translate edge operations and attributes into graph operators.
        # FIX: `vectorized` now defaults to True (matching the default of
        # `_preprocess_edge_operations`) instead of raising a KeyError whenever the
        # caller did not pass it explicitly.
        self._preprocess_edge_operations(dde_approx=kwargs.pop('dde_approx', 0),
                                         matrix_sparseness=kwargs.pop('matrix_sparseness', 0.1),
                                         vectorized=kwargs.pop('vectorized', True))
        if verbose:
            print("\t\t...finished.")
def __getitem__(self, key: str):
    """
    Custom implementation of __getitem__ that dissolves strings of form "key1/key2/key3" into
    lookups of form self[key1][key2][key3].

    Parameters
    ----------
    key
        Path string of form "key1/key2/...". The leading keys may jointly name a single node
        (node labels can themselves contain "/").

    Returns
    -------
    item
        The object found at the end of the path.
    """
    try:
        # fast path: the base-class lookup handles simple keys
        return super().__getitem__(key)
    except KeyError:
        keys = key.split('/')
        # find the longest key prefix that names an existing node (node labels may contain "/")
        for i in range(len(keys)):
            if "/".join(keys[:i+1]) in self.nodes:
                break
        # re-assemble the path: the matched node label first, then the remaining path parts
        key_iter = iter(['/'.join(keys[:i+1])] + keys[i+1:])
        key = next(key_iter)
        item = self.getitem_from_iterator(key, key_iter)
        # descend through the remaining hierarchy levels
        for key in key_iter:
            item = item.getitem_from_iterator(key, key_iter)
        return item
def add_edge(self, source: str, target: str, edge_ir: EdgeIR = None, weight: float = 1., delay: float = None,
             spread: float = None, **data):
    """Parse source/target specifiers and add a single edge to the networkx graph.

    Parameters
    ----------
    source
        Source specifier; either a node label or a full "node/op/var" path.
    target
        Target specifier; either a node label or a full "node/op/var" path.
    edge_ir
        Instance of an EdgeIR that contains information about the internal structure of the edge.
    weight
        Connection strength of the edge.
    delay
        Transmission delay of the edge.
    spread
        Standard deviation of the delay distribution (if any).
    data
        If no template is given, `data` is assumed to conform to the format that is needed to add an edge. I.e.,
        `data` needs to contain fields for `weight`, `delay`, `edge_ir`, `source_var`, `target_var`.

    Returns
    -------
    None
    """
    # step 1: parse and verify source and target specifiers
    source_node, source_var = self._parse_edge_specifier(source, data, "source_var")
    target_node, target_var = self._parse_edge_specifier(target, data, "target_var")
    # step 2: parse source variable specifier (might be single string or dictionary for multiple source variables)
    source_vars, extra_sources = self._parse_source_vars(source_node, source_var, edge_ir,
                                                         data.pop("extra_sources", None))
    # step 3: add edges
    attr_dict = dict(edge_ir=edge_ir,
                     weight=weight,
                     delay=delay,
                     spread=spread,
                     source_var=source_vars,
                     target_var=target_var,
                     extra_sources=extra_sources,
                     **data)
    self.graph.add_edge(source_node, target_node, **attr_dict)
def getitem_from_iterator(self, key: str, key_iter: Iterator[str]):
    """Return the `NodeIR` stored under node label `key` in the underlying networkx graph."""
    node_attrs = self.graph.nodes[key]
    return node_attrs["node"]
def _preprocess_edge_operations(self, dde_approx: int = 0, vectorized: bool = True, **kwargs):
    """Restructures network graph to collapse nodes and edges that share the same operator graphs. Variable values
    get an additional vector dimension. References to the respective index is saved in the internal `label_map`.

    Parameters
    ----------
    dde_approx
        If larger than zero, discrete delays are approximated by systems of coupled ODEs
        (see `_add_edge_buffer`).
    vectorized
        If True, all delayed projections of one output variable share a single edge buffer;
        otherwise one buffer is created per edge.
    kwargs
        Passed on to `_generate_edge_equation` (e.g. `matrix_sparseness`).
    """
    # go through nodes and create buffers for delayed outputs and mappings for their inputs
    #######################################################################################
    for node_name in self.nodes:
        node_outputs = self.graph.out_edges(node_name, keys=True)
        node_outputs = self._sort_edges(node_outputs, 'source_var', data_included=False)
        # loop over output variables of node
        for i, (out_var, edges) in enumerate(node_outputs.items()):
            # extract delay info from variable projections
            op_name, var_name = out_var.split('/')
            delays, spreads, nodes, add_delay = self._collect_delays_from_edges(edges)
            # add synaptic buffer to output variables with delay
            if add_delay:
                if vectorized:
                    self._add_edge_buffer(node_name, op_name, var_name, edges=edges, delays=delays,
                                          nodes=nodes, spreads=spreads, dde_approx=dde_approx)
                else:
                    # TODO: sort edges into unique delay/spread combinations and only loop over those
                    if spreads:
                        for i, (edge, delay, spread, node) in enumerate(zip(edges, delays, spreads, nodes)):
                            self._add_edge_buffer(node_name, op_name, var_name, edges=[edge], delays=[delay],
                                                  nodes=[node], spreads=[spread], dde_approx=dde_approx,
                                                  buffer_id=f"_out{i}")
                    else:
                        for i, (edge, delay, node) in enumerate(zip(edges, delays, nodes)):
                            self._add_edge_buffer(node_name, op_name, var_name, edges=[edge], delays=[delay],
                                                  nodes=[node], dde_approx=dde_approx, buffer_id=f"_out{i}")
    # go through nodes again, and collect and process all inputs to each node variable
    ##################################################################################
    for node_name in self.nodes:
        node_inputs = self.graph.in_edges(node_name, keys=True)
        node_inputs = self._sort_edges(node_inputs, 'target_var', data_included=False)
        # loop over inputs to node variable
        for i, (in_var, edges) in enumerate(node_inputs.items()):
            # extract info from projections to input variable
            op_name, var_name = in_var.split('/')
            data = self._collect_from_edges(edges,
                                            keys=['source_var', 'weight', 'source_idx', 'target_idx'])
            # create the final equations for all edges that target the input variable
            self._generate_edge_equation(tnode=node_name, top=op_name, tvar=var_name, inputs=data, **kwargs)
def _sort_edges(self, edges: List[tuple], attr: str, data_included: bool = False) -> dict:
    """Sorts edges according to the given edge attribute.

    Parameters
    ----------
    edges
        Collection of edges of interest, as `(source, target, key)` tuples, or
        `(source, target, key, data)` tuples if `data_included` is True.
    attr
        Name of the edge attribute to group by.
    data_included
        If True, each entry of `edges` is expected to carry an additional data element.

    Returns
    -------
    dict
        Key-value pairs of the different values the attribute can take on (keys) and the list of edges for which
        the attribute takes on that value (value).

    Raises
    ------
    ValueError
        If an edge tuple does not have the expected length.
    """
    # the two accepted formats only differ in tuple length — handle them uniformly
    expected_length = 4 if data_included else 3
    edges_new = {}
    for edge in edges:
        if len(edge) != expected_length:
            raise ValueError("Missing edge index. This error message should not occur.")
        source, target, key = edge[0], edge[1], edge[2]
        value = self.edges[source, target, key][attr]
        edges_new.setdefault(value, []).append(tuple(edge))
    return edges_new
def _collect_delays_from_edges(self, edges):
    """Collect delay means, delay spreads and source indices from `edges` and decide whether
    delay buffers must be implemented for them.

    Returns
    -------
    tuple
        `(means, stds, nodes, add_delay)` — edge delays, delay distribution spreads (or None
        if all zero), source-variable indices per edge, and a flag indicating whether the
        delays must actually be realized.
    """
    means, stds, nodes = [], [], []
    for s, t, e in edges:
        # extract delay
        d = self.edges[s, t, e]['delay']
        if type(d) is list:
            # missing list entries default to a single-step delay
            d = [1 if d_tmp is None else d_tmp for d_tmp in d]
        # extract and process delay distribution spread
        v = self.edges[s, t, e].pop('spread', [0])
        if v is None or np.sum(v) == 0:
            v = [0] * len(self.edges[s, t, e]['target_idx'])
            discretize = True
        else:
            discretize = False
            v = self._process_delays(v, discretize=discretize)
        # finalize edge delay
        if d is None or np.sum(d) == 0:
            d = [1] * len(self.edges[s, t, e]['target_idx'])
        else:
            # NOTE(review): `discretize` carries over from the spread handling above — delays
            # are only discretized when no spread was defined for this edge.
            d = self._process_delays(d, discretize=discretize)
        # extract source var index
        source = self.edges[s, t, e]['source_idx']
        if len(d) > 1 and len(source) == 1:
            # broadcast a single source index over all delays of this edge
            source = source * len(d)
        # collect values
        means += d
        stds += v
        nodes.append(source)
    # check whether edge delays have to be implemented or can be ignored
    max_delay = np.max(means)
    # integer delays are in steps (>1 step matters); float delays are in seconds (> step size matters)
    add_delay = ("int" in str(type(max_delay)) and max_delay > 1) or \
                ("float" in str(type(max_delay)) and max_delay > self.step_size)
    if sum(stds) == 0:
        stds = None
    # if delays are going to be added from the created lists, remove the delays from the edges themselves
    if add_delay:
        for s, t, e in edges:
            self.edges[s, t, e]['source_idx'] = []
            self.edges[s, t, e]['delay'] = None
    return means, stds, nodes, add_delay
def _collect_from_edges(self, edges: list, keys: list):
    """Collect the edge attributes named in `keys` for all `edges`, grouped by source node.

    List-valued attributes of edges that share a source node are concatenated; scalar values
    are collected into a list on the second occurrence.
    """
    data = dict()
    for source, target, idx in edges:
        edge = self.edges[(source, target, idx)]
        if source not in data:
            data[source] = dict()
        for key in keys:
            # deep-copy so later in-place edits do not alter the original edge attributes
            val = deepcopy(edge[key])
            try:
                data[source][key].extend(val)
            except AttributeError:
                # existing entry is not a list
                field = data[source][key]
                if type(field) is str or field is None:
                    # NOTE(review): for string/None entries (e.g. `source_var`) the new value is
                    # silently dropped — presumably all edges of one source share that value; verify.
                    pass
                else:
                    data[source][key] = [field, val]
            except KeyError:
                # first edge for this source: store the value as-is
                data[source][key] = val
    return data
def _add_edge_buffer(self, node: str, op: str, var: str, edges: list, delays: list, nodes: list,
                     spreads: Optional[list] = None, dde_approx: int = 0, buffer_id: str = "") -> None:
    """Adds a buffer variable to an edge.

    Depending on the configuration, the delayed source variable is realized either as (a) a
    system of coupled ODEs approximating the delay (gamma-kernel convolution), (b) a discretized
    ring buffer (fixed-step solvers), or (c) direct access to the state-variable history
    (`past(...)`, for adaptive-step DDE solving).

    Parameters
    ----------
    node
        Name of the source node of the edge.
    op
        Name of the source operator of the edge.
    var
        Name of the source variable of the edge.
    edges
        List with edge identifier tuples (source_name, target_name, edge_idx).
    delays
        edge delays.
    nodes
        Node indices for each edge delay.
    spreads
        Standard deviations of delay distributions around means given by `delays`.
    dde_approx
        Only relevant for delayed systems. If larger than zero, all discrete delays in the system will be
        automatically approximated by a system of (n+1) coupled ODEs that represent a convolution with a
        gamma distribution centered around the original delay (n is the approximation order).
    buffer_id
        Suffix appended to all generated variable names, to keep per-edge buffers distinct.

    Returns
    -------
    None
    """
    max_delay = np.max(delays)
    # extract target shape and node
    node_var = self[f"{node}/{op}/{var}"]
    target_shape = node_var['shape']
    node_ir = self[node]
    # flatten the per-edge source index lists into a single index vector
    nodes_tmp = list()
    for n in nodes:
        nodes_tmp += n
    source_idx = np.asarray(nodes_tmp, dtype='int').flatten()
    # ODE approximation to DDE
    ##########################
    if dde_approx or spreads:
        # calculate orders and rates of ODE-system approximations to delayed connections
        if spreads:
            # order/rate chosen so the gamma kernel has mean m and variance v (order = (m/v)^2)
            orders, rates = [], []
            for m, v in zip(delays, spreads):
                order = np.round((m / v) ** 2, decimals=0) if v > 0 else 0
                orders.append(int(order) if m and order > dde_approx else dde_approx)
                rates.append(orders[-1] / m if m else 0)
        else:
            orders, rates = [], []
            for m in delays:
                orders.append(dde_approx if m else 0)
                rates.append(dde_approx / m if m else 0)
        # sort all edge information in ascending ODE order
        order_idx = np.argsort(orders, kind='stable')
        orders_sorted = np.asarray(orders, dtype='int')[order_idx]
        orders_tmp = np.asarray(orders, dtype='int')[order_idx]
        rates_tmp = np.asarray(rates)[order_idx]
        source_idx_tmp = source_idx[order_idx]
        buffer_eqs, var_dict, final_idx = [], {}, []
        max_order = max(orders)
        for i in range(max_order + 1):
            # check which edges require the ODE order treated in this iteration of the loop
            k = i + 1
            idx, idx_str, idx_var = self._bool_to_idx(orders_tmp >= k)
            if type(idx) is int:
                idx = [idx]
            var_dict.update(idx_var)
            # define new equation variable/parameter names
            var_next = f"{var}_d{k}{buffer_id}"
            var_prev = f"{var}_d{i}{buffer_id}" if i > 0 else var
            rate = f"k_d{k}{buffer_id}"
            # prepare variables for the next ODE
            idx_apply = len(idx) != len(orders_tmp)
            val = rates_tmp[idx] if idx_apply else rates_tmp
            var_shape = (len(val),) if val.shape else ()
            if i == 0 and idx != [0] and (sum(target_shape) != len(idx) or any(np.diff(order_idx) != 1)):
                # first ODE reads from the original source variable via an explicit index constant
                var_prev_idx = f"index({var_prev}, source_idx{buffer_id})"
                var_dict[f"source_idx{buffer_id}"] = {'vtype': 'constant',
                                                     'dtype': 'int',
                                                     'shape': (len(source_idx_tmp[idx]),),
                                                     'value': source_idx_tmp[idx]}
            elif i != 0 and idx_apply:
                var_prev_idx = _get_indexed_var_str(var_prev, idx_str, var_length=len(rates_tmp))
            else:
                var_prev_idx = var_prev
            # create new ODE string and corresponding variable definitions
            buffer_eqs.append(f"d/dt * {var_next} = {rate} * ({var_prev_idx} - {var_next})")
            var_dict[var_next] = {'vtype': 'state_var',
                                  'dtype': 'float',
                                  'shape': var_shape,
                                  'value': 0.}
            var_dict[rate] = {'vtype': 'constant',
                              'dtype': 'float',
                              'value': val}
            # store indices that are required to fill the edge buffer variable
            if idx_apply:
                # right-hand side index
                if len(orders_tmp) < 2:
                    idx_rhs_str = ''
                elif i == 0:
                    # edges of order 0 read directly from the (indexed) source variable
                    idx_rhs = np.asarray(source_idx_tmp)[orders_tmp == i]
                    n_idx = len(idx_rhs)
                    if n_idx > 1:
                        idx_rhs_str = f"source_idx2{buffer_id}"
                        var_dict[f"source_idx2{buffer_id}"] = {'vtype': 'constant',
                                                              'dtype': 'int',
                                                              'shape': (n_idx,),
                                                              'value': idx_rhs}
                    else:
                        idx_rhs_str = f"{idx_rhs[0]}"
                else:
                    _, idx_rhs_str, _ = self._bool_to_idx(orders_tmp == i)
                # left-hand side index
                if len(delays) > 1:
                    _, idx_lhs_str, _ = self._bool_to_idx(orders_sorted == i)
                else:
                    idx_lhs_str = ''
                final_idx.append((i, idx_lhs_str, idx_rhs_str))
            # reduce lists of orders and rates by the ones that are fully implemented by the current ODE set
            if idx_apply:
                orders_tmp = orders_tmp[idx]
                rates_tmp = rates_tmp[idx]
                if not orders_tmp.shape:
                    # keep 1-d arrays so subsequent indexing keeps working
                    orders_tmp = np.asarray([orders_tmp], dtype='int')
                    rates_tmp = np.asarray([rates_tmp])
        # remove unnecessary ODEs
        for _ in range(len(buffer_eqs) - final_idx[-1][0]):
            i = len(buffer_eqs)
            var_dict.pop(f"{var}_d{i}{buffer_id}")
            var_dict.pop(f"k_d{i}{buffer_id}")
            buffer_eqs.pop(-1)
        # create edge buffer variable
        buffer_length = len(delays)
        for i, idx_l, idx_r in final_idx:
            lhs = _get_indexed_var_str(f"{var}_buffered{buffer_id}", idx_l, var_length=buffer_length)
            rhs = _get_indexed_var_str(f"{var}_d{i}{buffer_id}" if i != 0 else var, idx_r, var_length=buffer_length)
            buffer_eqs.append(f"{lhs} = {rhs}")
        var_dict[f"{var}_buffered{buffer_id}"] = {'vtype': 'variable',
                                                 'dtype': 'float',
                                                 'shape': (buffer_length,),
                                                 'value': 0.0}
        # re-order buffered variable if necessary
        if any(np.diff(order_idx) != 1):
            # undo the order-based sorting so the buffer matches the original edge ordering
            buffer_eqs.append(f"{var}_buffered{buffer_id} = index({var}_buffered{buffer_id}, "
                              f"{var}_buffered_idx{buffer_id})")
            var_dict[f"{var}_buffered_idx{buffer_id}"] = {'vtype': 'constant',
                                                         'dtype': 'int',
                                                         'shape': (len(order_idx),),
                                                         'value': np.argsort(order_idx, kind='stable')}
    # discretized edge buffers
    ##########################
    elif not self.step_size_adaptation:
        # create buffer variable shapes
        if len(target_shape) < 1 or (len(target_shape) == 1 and target_shape[0] == 1):
            buffer_shape = (max_delay + 1,)
        else:
            buffer_shape = (target_shape[0], max_delay + 1)
        # create buffer variable definitions
        var_dict = {f'{var}_buffer{buffer_id}': {'vtype': 'variable',
                                                'dtype': 'float',
                                                'shape': buffer_shape,
                                                'value': 0.},
                    f'{var}_buffered{buffer_id}': {'vtype': 'variable',
                                                  'dtype': 'float',
                                                  'shape': (len(delays),),
                                                  'value': 0.},
                    f'{var}_delays{buffer_id}': {'vtype': 'constant',
                                                'dtype': 'int',
                                                'value': delays},
                    f'source_idx{buffer_id}': {'vtype': 'constant',
                                               'dtype': 'int',
                                               'value': source_idx}}
        # create buffer equations: shift the ring buffer by one step, write the current value at
        # position 0, and read the delayed values out
        if len(target_shape) < 1 or (len(target_shape) == 1 and target_shape[0] == 1):
            buffer_eqs = [f"index_axis({var}_buffer{buffer_id}) = roll({var}_buffer{buffer_id}, 1)",
                          f"index({var}_buffer{buffer_id}, 0) = {var}",
                          f"{var}_buffered{buffer_id} = index({var}_buffer{buffer_id}, {var}_delays{buffer_id})"]
        else:
            buffer_eqs = [f"index_axis({var}_buffer{buffer_id}) = roll({var}_buffer{buffer_id}, 1, 1)",
                          f"index_axis({var}_buffer{buffer_id}, 0, 1) = {var}",
                          f"{var}_buffered{buffer_id} = index_2d({var}_buffer{buffer_id}, source_idx{buffer_id}, "
                          f"{var}_delays{buffer_id})"]
    # Turn ODE system into DDE system
    #################################
    else:
        warn(PyRatesWarning(f'PyRates detected an edge definition that implies to represent the model as a '
                            f'delayed differential equation system.\n PyRates will thus attempt to access the '
                            f'history of the source variable {var} of operator {op} on node {node}. '
                            f'Note that this requires {var} to be a state-variable, i.e. a variable defined by '
                            f'a differential equation.'))
        # create buffer variable definitions
        var_dict = {f'{var}_buffered{buffer_id}': {'vtype': 'variable',
                                                  'dtype': 'float',
                                                  'shape': (len(delays),),
                                                  'value': 0.}
                    }
        buffer_eqs = []
        for i, (d, sidx) in enumerate(zip(delays, source_idx)):
            # `past(var, d)` accesses the solver-maintained history; zero/unit integer delays
            # fall back to the present value
            var_delayed = f"past({var}, {d})" if type(d) is float or d != 1 else var
            if len(target_shape) < 1 or (len(target_shape) == 1 and target_shape[0] == 1):
                buffer_eqs.append(f"{var}_buffered{buffer_id} = {var_delayed}")
            else:
                buffer_eqs.append(f"index({var}_buffered{buffer_id}, {sidx}) = index({var_delayed}, {sidx})")
    # add buffer equations to node operator
    op_info = node_ir[op]
    op_info['equations'] += buffer_eqs
    op_info['variables'].update(var_dict)
    op_info['output'] = f"{var}_buffered{buffer_id}"
    # update input information of node operators connected to this operator
    for succ in node_ir.op_graph.succ[op]:
        inputs = self[f"{node}/{succ}"]['inputs']
        if var not in inputs.keys():
            inputs[var] = {'sources': {op}}
    # update edge information: edges now read from the buffered variable
    idx_l = 0
    for i, edge in enumerate(edges):
        s, t, e = edge
        self.edges[s, t, e]['source_var'] = f"{op}/{var}_buffered{buffer_id}"
        if len(edges) > 1:
            idx_h = idx_l + len(nodes[i])
            self.edges[s, t, e]['source_idx'] = list(range(idx_l, idx_h))
            idx_l = idx_h
def _generate_edge_equation(self, tnode: str, top: str, tvar: str, inputs: dict, matrix_sparseness: float = 0.1,
                            weight_minimum: float = 1e-8):
    """Create the equations that project all incoming edges onto target variable `tnode/top/tvar`
    and add them as a new `in_edge_*` operator on the target node.

    Parameters
    ----------
    tnode
        Name of the target node.
    top
        Name of the target operator.
    tvar
        Name of the target variable.
    inputs
        Per-source-node edge information as returned by `_collect_from_edges`
        (`source_var`, `weight`, `source_idx`, `target_idx`).
    matrix_sparseness
        Density threshold above which a projection is realized as a matrix-vector product
        instead of element-wise indexing.
    weight_minimum
        Tolerance below which edge weights are treated as exactly 1 (no weighting applied).
    """
    # step 0: check properties of the target variable and its inputs
    multiple_inputs = len(inputs) > 1
    tval = self[f"{tnode}/{top}/{tvar}"]
    if tval['shape']:
        tsize = sum(tval['shape'])
    elif type(tval['value']) is list:
        tsize = len(tval['value'])
    else:
        tsize = 0
    # step 1: collect all inputs
    weights, source_indices, target_indices, sources = [], [], [], []
    for snode, sinfo in inputs.items():
        weights.append(sinfo['weight'])
        source_indices.append(sinfo['source_idx'])
        target_indices.append(sinfo['target_idx'])
        sources.append((snode,) + tuple(sinfo['source_var'].split('/')))
    # step 2: process incoming edges
    source_vars, args = {}, {}
    eqs, in_vars = [], []
    for i, (weight, sidx, tidx, (snode, sop, svar)) in \
            enumerate(zip(weights, source_indices, target_indices, sources)):
        # get source variable size
        sval = self[f"{snode}/{sop}/{svar}"]
        if sval['shape']:
            ssize = sum(sval['shape'])
        elif type(sval['value']) is list:
            ssize = len(sval['value'])
        else:
            ssize = 1
        # check whether the edge can be realized via a matrix product
        if not tidx:
            tidx = [0 for _ in range(len(sidx))]
        n, m = len(tidx), len(sidx)
        if len(np.unique(tidx)) < len(tidx):
            # multiple edges converge on the same target entry -> must use a matrix product
            if not sidx:
                sidx = [i for i in range(len(weight))]
            dot_edge = True
        elif n*m > 1 and tsize*ssize > 1:
            # use a matrix product only if the weight matrix would be dense enough
            dot_edge = len(weight) / (n * m) > matrix_sparseness
        else:
            dot_edge = False
        # define new input variable if necessary
        if multiple_inputs:
            in_shape = (tsize,)
            t_str = f'{tvar}_in{i}'
            w_str = f'weight_in{i}'
            s_str = f'{svar}_in{i}'
            sidx_str = f'source_idx_in{i}'
            tidx_str = f'target_idx_in{i}'
            args[t_str] = {'value': np.zeros(in_shape), 'dtype': 'float', 'vtype': 'variable',
                           'shape': in_shape}
        else:
            t_str = tvar
            w_str = 'weight'
            s_str = svar
            sidx_str = 'source_idx'
            tidx_str = 'target_idx'
        # case I: realize edge projection via a matrix product
        if dot_edge:
            # create weight matrix
            sidx_unique = np.unique(sidx)
            tidx_unique = np.unique(tidx)
            weight_mat = np.zeros((len(tidx_unique), len(sidx_unique)))
            for t, s, w in zip(tidx, sidx, weight):
                row = np.argwhere(tidx_unique == t).squeeze()
                col = np.argwhere(sidx_unique == s).squeeze()
                weight_mat[row, col] = w
            # define edge projection equation
            s_str_final = _get_indexed_var_str(s_str, sidx_unique, ssize, idx_str=sidx_str, arg_dict=args)
            t_str_final = _get_indexed_var_str(t_str, tidx_unique, tsize, idx_str=tidx_str, arg_dict=args)
            eq = f"{t_str_final} = matvec({w_str}, {s_str_final})"
            args[w_str] = {'vtype': 'constant', 'value': weight_mat, 'dtype': 'float', 'shape': weight_mat.shape}
        # case II: realize edge projection via source and target indexing
        else:
            # check whether weighting of source variables is required
            if all([abs(w-1) < weight_minimum for w in weight]):
                weighting = ""
            else:
                weighting = f" * {w_str}"
                args[w_str] = {'vtype': 'constant', 'dtype': 'float', 'value': weight if ssize > 1 else weight[0]}
            # get final source and target strings
            s_str_final = _get_indexed_var_str(s_str, sidx, ssize, reduce=m == 1 and tsize > 1 and n == 1,
                                               idx_str=sidx_str, arg_dict=args)
            t_str_final = _get_indexed_var_str(t_str, tidx, tsize, reduce=tsize > 1 or ssize < tsize,
                                               idx_str=tidx_str, arg_dict=args)
            # define edge equation
            eq = f"{t_str_final} = {s_str_final}{weighting}"
        # add equation and source information
        eqs.append(eq)
        source_vars[s_str] = {'sources': [sop], 'node': snode, 'var': svar}
        in_vars.append(t_str)
    # step 3: process multiple inputs to same variable
    if multiple_inputs:
        # finalize edge equations: sum the per-edge intermediate variables
        eq = f"{tvar} = {'+'.join(in_vars)}"
        eqs.append(eq)
    # step 4: define target variable as operator output
    args[tvar] = tval
    args[tvar]['vtype'] = 'variable'
    # step 5: add edge operator to target node (module-level cache keeps indices unique per node)
    if tnode not in in_edge_indices:
        in_edge_indices[tnode] = 0
    op_name = f'in_edge_{in_edge_indices[tnode]}'
    in_edge_indices[tnode] += 1
    tnode_ir = self[tnode]
    tnode_ir.add_op(op_name, inputs=source_vars, output=tvar, equations=eqs, variables=args)
    tnode_ir.add_op_edge(op_name, top)
    # step 6: add input information to target operator
    inputs = self[tnode][top]['inputs']
    if tvar in inputs.keys():
        inputs[tvar]['sources'].add(op_name)
    else:
        inputs[tvar] = {'sources': [op_name]}
def _process_delays(self, d, discretize=True):
    """Normalize delay information `d` (scalar or list) into a list of preprocessed delays."""
    if type(d) is not list:
        return [self._preprocess_delay(d, discretize=discretize)]
    arr = np.asarray(d).squeeze()
    if not arr.shape:
        # squeezed down to a 0-d scalar
        return [self._preprocess_delay(arr, discretize=discretize)]
    return [self._preprocess_delay(entry, discretize=discretize) for entry in arr]
def _preprocess_delay(self, delay, discretize=True):
    """Convert a single delay into units of the integration step size, if discretization applies."""
    if discretize and not self.step_size_adaptation:
        return int(np.round(delay / self.step_size, decimals=0))
    return delay
def _bool_to_idx(self, v):
    """Turn a boolean mask `v` into `(index list, index string, variable definitions)` for use in
    generated equation strings.

    The index string is a `(start, stop)` tuple for contiguous ranges, the name of a newly created
    index constant for scattered indices, a single number for a scalar index, or '' for an empty mask.
    """
    v_idx = np.argwhere(v).squeeze()
    v_dict = {}
    if v_idx.shape and v_idx.shape[0] > 1 and all(np.diff(v_idx) == 1):
        # contiguous range: representable as slice boundaries
        v_idx_str = (f"{v_idx[0]}", f"{v_idx[-1] + 1}")
    elif v_idx.shape and v_idx.shape[0] > 1:
        # scattered indices: store them as a constant with a unique name
        var_name = f"delay_idx{self._edge_idx_counter}"
        v_idx_str = f"{var_name}"
        v_dict[var_name] = {'value': v_idx, 'vtype': 'constant'}
        self._edge_idx_counter += 1
    else:
        try:
            # 0-d result: single scalar index
            v_idx_str = f"{v_idx.max()}"
        except ValueError:
            # empty mask: max() on an empty array raises ValueError
            v_idx_str = ""
    return v_idx.tolist(), v_idx_str, v_dict
def _parse_source_vars(self, source_node: str, source_var: Union[str, dict], edge_ir, extra_sources: dict = None
                       ) -> Tuple[Union[str, dict], dict]:
    """Parse the source variable specification. This tests whether one or more source variables
    are given and verifies all given paths.

    Parameters
    ----------
    source_node
        String that specifies a single node as source of an edge.
    source_var
        Single variable specifier string or dictionary of form
        `{source_op/source_var: edge_op/edge_var}`.
    edge_ir
        Instance of an EdgeIR that contains information about the internal structure of an edge.
    extra_sources
        Optional dictionary of additional source variables of form
        `{edge_var: node/op/var}`.

    Returns
    -------
    Tuple[Union[str, dict], dict]
        The (possibly simplified) source variable specification and the verified extra sources.

    Raises
    ------
    PyRatesException
        If the number of source variables does not match the edge's number of inputs.
    """
    # step 1: figure out, whether only one or more source variables are defined
    try:
        # try to treat source_var as dictionary
        n_source_vars = len(source_var.keys())
    except AttributeError:
        # not a dictionary, so must be a string
        n_source_vars = 1
    else:
        # was a dictionary, treat case that it only has length 1
        if n_source_vars == 1:
            source_var = next(iter(source_var))
    if n_source_vars == 1:
        _, _ = source_var.split("/")  # should be op, var, but we do not need them here
        self._verify_path(source_node, source_var)
    else:
        # verify that number of source variables matches number of input variables in edge
        if extra_sources is not None:
            n_source_vars += len(extra_sources)
        if n_source_vars != edge_ir.n_inputs:
            raise PyRatesException(f"Mismatch between number of source variables ({n_source_vars}) and "
                                   f"inputs ({edge_ir.n_inputs}) in an edge with source '{source_node}' and source"
                                   f"variables {source_var}.")
        for node_var, edge_var in source_var.items():
            self._verify_path(source_node, node_var)
    if extra_sources is not None:
        # extra sources carry their own full node/op/var path
        for edge_var, source in extra_sources.items():
            node, op, var = source.split("/")
            source = "/".join((node, op, var))
            self._verify_path(source)
            extra_sources[edge_var] = source
    return source_var, extra_sources
def _verify_path(self, *parts: str):
    """Check that joining `parts` with '/' yields a valid path within this circuit.

    Parameters
    ----------
    parts
        One or more parts of a path string.

    Raises
    ------
    PyRatesException
        If the assembled path cannot be found in the circuit hierarchy.
    """
    full_path = "/".join(parts)
    if full_path in self:
        return
    raise PyRatesException(f"Could not find object with path `{full_path}`.")
@staticmethod
def _parse_edge_specifier(specifier: str, data: dict, var_string: str) -> Tuple[str, Union[str, dict]]:
    """Parse source or target specifier for an edge.

    Parameters
    ----------
    specifier
        String that defines either a specific node or complete variable path of source or target for an edge.
        Format: *circuits/node/op/var
    data
        dictionary containing additional information about the edge. This function looks for a variable specifier
        as specified in `var_string`
    var_string
        String that points to an optional key of the `data` dictionary. Should be either 'source_var' or
        'target_var'

    Returns
    -------
    (node, var)
    """
    if var_string in data:
        # variable info is stored separately in `data`, so `specifier` names only the node
        return specifier, data.pop(var_string)
    # variable info is embedded in `specifier` itself — split off the trailing op/var pair;
    # everything before it (possibly containing '/') is the node path
    *node_parts, op_part, var_part = specifier.split("/")
    node = "/".join(node_parts)
    var = "/".join((op_part, var_part))
    return node, var
@property
def nodes(self):
    """Shortcut to self.graph.nodes. See documentation of `networkx.MultiDiGraph.nodes`.
    Each node carries its `NodeIR` instance under the attribute key 'node'."""
    return self.graph.nodes
@property
def edges(self):
    """Shortcut to self.graph.edges. See documentation of `networkx.MultiDiGraph.edges`.
    Edges are keyed by `(source, target, edge_idx)` tuples."""
    return self.graph.edges
# intermediate representation that provides interface between frontend and backend
class CircuitIR(AbstractBaseIR):
    """Custom graph data structure that represents a backend of nodes and edges with associated equations
    and variables."""

    # NOTE(review): "_t" is declared here but never assigned in the visible code — confirm it is
    # set elsewhere or remove it.
    __slots__ = ["label", "_front_to_back", "graph", "_t", "_verbose", "_dt", "_dt_adapt", "_def_shape"]

    def __init__(self, label: str = "circuit", nodes: Dict[str, NodeIR] = None, edges: list = None,
                 template: str = None, step_size_adaptation: bool = False, step_size: float = None,
                 verbose: bool = True, backend: str = None, scalar_shape: tuple = None, **kwargs):
        """
        Parameters:
        -----------
        label
            String label, could be used as fallback when subcircuiting this circuit. Currently not used, though.
        nodes
            Dictionary of nodes of form {node_label: `NodeIR` instance}.
        edges
            List of tuples (source:str, target:str, edge_dict). `edge_dict` should contain the key "edge_ir" with an
            `EdgeIR` instance as item and optionally entries for "weight" and "delay". `source` and `target` should be
            formatted as "node/op/var" (with optionally prepended circuits).
        template
            optional string reference to path to template that this circuit was loaded from. Leave empty, if no template
            was used.
        step_size_adaptation
            If True, the model is prepared for adaptive-step solvers (no delay discretization).
        step_size
            Integration step size in s.
        verbose
            If True, print compilation progress to the terminal.
        backend
            Name of the backend the compute graph should be built for.
        scalar_shape
            Default shape used for scalar variables; defaults to (1,).
        """
        # filter displayed warnings
        filterwarnings("ignore", category=FutureWarning)
        # set main attributes
        super().__init__(label=label, template=template)
        self._verbose = verbose
        self._front_to_back = dict()  # maps frontend variable names to backend ComputeVar instances
        self._dt = step_size
        self._dt_adapt = step_size_adaptation
        self._def_shape = (1,) if scalar_shape is None else scalar_shape
        # translate the network into a networkx graph
        net = NetworkGraph(nodes=nodes, edges=edges, label=label, step_size=step_size,
                           step_size_adaptation=step_size_adaptation, verbose=verbose, **kwargs)
        # parse network equations into a compute graph
        if verbose:
            print("\t(3) Parsing the model equations into a compute graph...")
        self.graph = self.network_to_computegraph(graph=net, backend=backend, **kwargs)
        if verbose:
            print("\t\t...finished.")
            print("\tModel compilation was finished.")
def get_var(self, var: str, get_key: bool = False) -> Union[str, ComputeVar]:
    """Extracts variable from the backend (i.e. the `ComputeGraph` instance).

    Parameters
    ----------
    var
        Name of the variable.
    get_key
        If true, the backend variable name will be returned

    Returns
    -------
    Union[str, ComputeVar]
        Either the backend variable or its name.
    """
    try:
        found = self[var]
    except KeyError:
        # fall back to the frontend-to-backend name mapping
        found = self._front_to_back[var]
    if get_key:
        return found.name
    return found
def get_frontend_varname(self, var: str) -> str:
    """Returns the original frontend variable name given the backend variable name `var`.

    Parameters
    ----------
    var
        Name of the backend variable.

    Returns
    -------
    str
        Name of the frontend variable
    """
    backend_var = self.get_var(var)
    mapping = self._front_to_back
    # invert the frontend->backend mapping via positional lookup
    position = list(mapping.values()).index(backend_var)
    return list(mapping.keys())[position]
def run(self,
        simulation_time: float,
        outputs: Optional[dict] = None,
        sampling_step_size: Optional[float] = None,
        solver: str = 'euler',
        **kwargs
        ) -> dict:
    """Simulate the backend behavior over time via a tensorflow session.

    Parameters
    ----------
    simulation_time
        Simulation time in seconds.
    outputs
        Output variables that will be returned. Each key is the desired name of an output variable and each value is
        a string that specifies a variable in the graph in the same format as used for the input definition:
        'node_name/op_name/var_name'.
    sampling_step_size
        Time in seconds between sampling points of the output variables.
    solver
        Numerical solving scheme to use for differential equations. Currently supported ODE solving schemes:
        - 'euler' for the explicit Euler method
        - 'scipy' for integration via the `scipy.integrate.solve_ivp` method.
    kwargs
        Keyword arguments that are passed on to the chosen solver.

    Returns
    -------
    dict
        Output variables in a dictionary.
    """
    filterwarnings("ignore", category=FutureWarning)
    # collect backend variables and functions
    #########################################
    if self._verbose:
        print("Simulation Progress")
        print("-------------------")
        print("\t (1) Generating the network run function...")
    # generate run function
    func_name = kwargs.pop('func_name', 'vector_field')
    func, func_args, _, _ = self.get_run_func(func_name, **kwargs)
    # extract backend variables that correspond to requested output variables
    if self._verbose:
        print("\t (2) Processing output variables...")
    outputs_col = {}
    if outputs:
        # map requested output names to their backend variable keys
        for key, val in outputs.items():
            outputs_col[key] = self.get_var(val, get_key=True)
    if self._verbose:
        print("\t\t...finished.")
    # perform simulation
    ####################
    if self._verbose:
        print("\t (3) Running the simulation...")
    t0 = time.perf_counter()
    # call backend run function
    results = self.graph.run(func=func, func_args=func_args, T=simulation_time, dt=self._dt, dts=sampling_step_size,
                             outputs=outputs_col, solver=solver, **kwargs)
    if self._verbose:
        t1 = time.perf_counter()
        print(f"\t\t...finished after {t1-t0}s.")
    # return simulation results if outputs have been passed
    if outputs_col:
        return results
    # else, find the frontend variable names of the returned results and create a new results dict to return
    for key in results.copy():
        front_key = self.get_frontend_varname(key)
        results[front_key] = results.pop(key)
    return results
def get_run_func(self, func_name: str, file_name: Optional[str] = None, **kwargs) -> tuple:
    """Generate the vector-field function of the model.

    Parameters
    ----------
    func_name
        Name of the generated function.
    file_name
        Name of the file the function source is written to. Defaults to "pyrates_func".
    kwargs
        Passed on to `ComputeGraph.to_func`.

    Returns
    -------
    tuple
        Return value of `ComputeGraph.to_func` (function, function arguments, ...).
    """
    if not file_name:
        file_name = "pyrates_func"  # plain string: the original f-string had no placeholders
    return self.graph.to_func(func_name=func_name, file_name=file_name, dt_adapt=self._dt_adapt, **kwargs)
def network_to_computegraph(self, graph: NetworkGraph, inplace_vectorfield: bool = True, **kwargs):
    """Translate the pre-processed `NetworkGraph` into a `ComputeGraph`.

    Parameters
    ----------
    graph
        The network graph with pre-processed edge operations.
    inplace_vectorfield
        If True, a standard `ComputeGraph` is built; otherwise a `ComputeGraphBackProp`
        (gradient-capable) variant is used.
    kwargs
        Passed on to the `ComputeGraph` constructor; may include `parsing_method`.

    Returns
    -------
    ComputeGraph
        The fully parsed compute graph.
    """
    # initialize compute graph
    cg = ComputeGraph(**kwargs) if inplace_vectorfield else ComputeGraphBackProp(**kwargs)
    # add global time variable to compute graph (integer step counter for fixed-step solvers)
    cg.add_var(label="t", vtype="state_var", value=0.0 if self._dt_adapt else 0, shape=(),
               dtype='float' if self._dt_adapt else 'int')
    # node operators
    parsing_kwargs = ['parsing_method']
    parsing_kwargs = {key: kwargs.pop(key) for key in parsing_kwargs if key in kwargs}
    # parse all non-edge operators first, then the edge operators (layer 0)
    self._parse_op_layers_into_computegraph(graph, cg, layers=[], exclude=True, op_identifier="edge_from_",
                                            **parsing_kwargs)
    # edge operators
    self._parse_op_layers_into_computegraph(graph, cg, layers=[0], exclude=False, op_identifier="edge_from_",
                                            **parsing_kwargs)
    return cg
def getitem_from_iterator(self, key: str, key_iter: Iterator[str]):
    """Look `key` up as a variable in the backend compute graph."""
    backend = self.graph
    return backend.get_var(key)
def clear(self):
    """Clears the backend graph from all operations and variables.

    Also resets the module-level caches `in_edge_indices` and `in_edge_vars`, which are shared
    across all circuit instances, so that subsequent compilations start fresh.
    """
    self._front_to_back.clear()
    self.graph.clear()
    in_edge_indices.clear()
    in_edge_vars.clear()
def _parse_op_layers_into_computegraph(self, net: NetworkGraph, cg: ComputeGraph, layers: list,
                                       exclude: bool = False, op_identifier: Optional[str] = None, **kwargs
                                       ) -> None:
    """Parse a subset of each node's operator hierarchy levels into the compute graph.

    Operators are peeled off layer by layer (operators with no remaining
    in-graph dependencies form the current layer).

    Parameters
    ----------
    net
        Network graph whose node operators should be parsed.
    cg
        Compute graph that parsed equations and variables are added to.
    layers
        Hierarchy-level indices to process (or to skip, if `exclude` is True).
    exclude
        If True, process every layer *except* those in `layers`; otherwise only
        those in `layers`. Also forwarded as `reduce` to `_collect_ops`.
    op_identifier
        Optional substring used to filter operators by name (e.g. "edge_from_"):
        with `exclude`, operators containing it are skipped; without, only
        operators containing it are kept.
    kwargs
        Additional keyword arguments passed on to `parse_equations`.

    Returns
    -------
    None
    """
    for node_name, node in net.nodes.items():

        op_graph = node['node'].op_graph
        # work on a copy so layer removal does not mutate the node's operator graph
        g = op_graph.copy()  # type: DiGraph

        # go through all operators on node and pre-process + extract equations and variables
        i = 0
        while g.nodes:

            # get all operators that have no dependencies on other operators
            # (i.e. the current hierarchy layer)
            # noinspection PyTypeChecker
            ops = [op for op, in_degree in g.in_degree if in_degree == 0]

            if (i in layers and not exclude) or (i not in layers and exclude):

                # collect operator variables and equations from node, applying the
                # optional name filter
                if op_identifier:
                    ops_tmp = [op for op in ops if op_identifier not in op] if exclude else \
                        [op for op in ops if op_identifier in op]
                else:
                    ops_tmp = ops
                op_eqs, op_vars = self._collect_ops(ops_tmp, node_name=node_name, graph=net, compute_graph=cg,
                                                    reduce=exclude)

                # parse equations and variables into computegraph
                variables = parse_equations(op_eqs, op_vars, cg=cg, def_shape=self._def_shape, **kwargs)

                # remember mapping between frontend variable names and node keys in compute graph
                # ('inputs' entries are bookkeeping dicts, not backend variables)
                for key, var in variables.items():
                    if key.split('/')[-1] != 'inputs' and isinstance(var, ComputeVar):
                        self._front_to_back[key] = var

            # remove parsed operators from graph — note: the *unfiltered* layer is
            # removed, so operators dropped by the name filter are never revisited
            g.remove_nodes_from(ops)
            i += 1
def _collect_ops(self, ops: List[str], node_name: str, graph: NetworkGraph, compute_graph: ComputeGraph,
                 reduce: bool) -> tuple:
    """Collect the equations and variables of a set of node operators.

    Parameters
    ----------
    ops
        Names of the operators that should be parsed into the graph.
    node_name
        Name of the node that the operators belong to.
    graph
        Network graph providing the operator definitions and input variables.
    compute_graph
        Compute graph that already-parsed shared variables (the global time
        variable) are retrieved from.
    reduce
        If True, uniform constant float vectors are reduced to scalars
        (see `_finalize_var_def`).

    Returns
    -------
    tuple
        Collected and updated operator equations and variables.
    """
    # collector variables for the equations and variables of all operators
    equations = []
    variables = {}

    for op_name in ops:

        # retrieve operator information and take a private copy of its variables
        scope = f"{node_name}/{op_name}"
        op_info = graph[scope]
        op_args = deepcopy(op_info['variables'])
        op_args['inputs'] = {}

        # collect the inputs that enter each operator input variable
        in_ops = {}
        for var_name, inp in op_info['inputs'].items():

            if inp['sources']:

                in_ops_col = {}
                # source node defaults to the operator's own node
                in_node = inp['node'] if 'node' in inp else node_name
                # NOTE: pop() mutates the input definition stored on the network graph
                in_var_tmp = inp.pop('var', None)

                for in_op in inp['sources']:

                    # resolve the full backend key of this input variable
                    in_var = in_var_tmp if in_var_tmp else graph[f"{in_node}/{in_op}"]['output']
                    in_key = f"{in_node}/{in_op}/{in_var}"
                    try:
                        # variable has already been parsed into the compute graph
                        in_val = self._front_to_back[in_key]
                    except KeyError:
                        # fall back to the frontend variable definition
                        in_val = graph[in_key]
                    in_ops_col[in_key] = in_val

                if len(in_ops_col) > 1:
                    # multiple inputs to the variable: sum them up
                    in_ops[var_name] = self._map_multiple_inputs(in_ops_col, scope=scope)
                else:
                    # single input: map it directly, no replacement term required
                    key, _ = in_ops_col.popitem()
                    in_ops[var_name] = (None, {var_name: key})

        # replace input variables with their summed-up input terms in the equations
        for var, (inp_term, inp_map) in in_ops.items():
            if inp_term:
                # bug fix: substitute the *current* input variable `var`; previously the
                # stale loop variable `var_name` from the collection loop above was used,
                # so with several multi-source inputs only the last one was replaced
                op_info['equations'] = [replace(eq, var, inp_term) for eq in op_info['equations']]
            op_args['inputs'].update(inp_map)

        # collect operator equations and variables
        variables[f"{scope}/inputs"] = {}
        equations += [(eq, scope) for eq in op_info['equations']]
        for key, var in op_args.items():
            full_key = f"{scope}/{key}"
            if key == "t":
                # case I: global time variable, shared via the compute graph
                variables[full_key] = compute_graph.get_var('t')
            elif key == 'inputs' and var:
                # case II: input variables of the operator
                variables[f"{scope}/inputs"].update(var)
                for in_var in var.values():
                    try:
                        variables[in_var] = self._front_to_back[in_var]
                    except KeyError:
                        variables[in_var] = graph[in_var]
            else:
                try:
                    # case III: variables that have already been processed
                    variables[full_key] = self._front_to_back[full_key]
                except KeyError:
                    # case IV: new variables
                    variables[full_key] = self._finalize_var_def(var, reduce)

    return equations, variables
@staticmethod
def _finalize_var_def(v: dict, reduce: bool):
if not reduce:
return v
if not v:
return v
if v['vtype'] != 'constant':
return v
if v['dtype'] != 'float':
return v
if 'shape' in v and len(v['shape']) > 1:
return v
if len(np.unique(v['value'])) > 1:
return v
try:
v['value'] = v['value'][0]
v['shape'] = tuple()
except (TypeError, IndexError):
pass
return v
@staticmethod
def _map_multiple_inputs(inputs: dict, scope: str) -> tuple:
    """Creates mapping between multiple input variables and a single output variable.

    Parameters
    ----------
    inputs
        Input variables, keyed by their full backend variable paths.
    scope
        Scope of the input variables; keys the module-level cache of labels
        that were already handed out at this target operator.

    Returns
    -------
    tuple
        Equation term that sums up all input variables, and the mapping from
        each unique label to the corresponding input-variable path.
    """
    # per-scope registry of labels already in use (module-level cache)
    taken_labels = in_edge_vars.setdefault(scope, {})

    mapping = {}
    summands = []
    for path, variable in inputs.items():

        # base label: the variable's own name if it has one, otherwise the
        # final path segment
        base = variable.name if hasattr(variable, 'name') else path.split('/')[-1]

        # disambiguate against labels already used in this scope
        label, label_updates = get_unique_label(base, taken_labels)
        taken_labels.update(label_updates)

        summands.append(label)
        mapping[label] = path

    return f"({'+'.join(summands)})", mapping
@property
def nodes(self):
    """All network nodes, delegated to `self.graph.nodes`
    (see `networkx.MultiDiGraph.nodes`)."""
    return self.graph.nodes
@property
def edges(self):
    """All network edges, delegated to `self.graph.edges`
    (see `networkx.MultiDiGraph.edges`)."""
    return self.graph.edges
def _get_indexed_var_str(var: str, idx: Union[tuple, str, list], var_length: int = None, reduce: bool = False,
idx_str: str = None, arg_dict: dict = None):
if type(idx) is tuple:
if var_length and int(idx[1]) - int(idx[0]) == var_length:
return var
elif not var_length:
var = f"reshape({var}, 1)"
return f"index_range({var}, {idx[0]}, {idx[1]})"
if type(idx) is str and len(idx) > 0:
if not var_length:
var = f"reshape({var}, 1)"
return f"index({var}, {idx})"
if len(idx) > 0:
if len(idx) == var_length:
identical = True
for i1, i2 in zip(idx, list(np.arange(0, var_length))):
if i1 != i2:
identical = False
break
if identical:
return var
if idx_str:
arg_dict[idx_str] = {'vtype': 'constant', 'value': idx, 'dtype': 'int', 'shape': (len(idx),)}
return f"index({var}, {idx_str})"
return f"index({var}, {idx})"
if reduce:
return f"index({var}, {idx[0]})"
return var
def _get_source_str():
    # Placeholder — no implementation and no callers visible in this chunk.
    # Presumably intended to build the source-variable string of an edge
    # equation (counterpart to `_get_indexed_var_str`) — TODO confirm.
    pass
def _get_target_str():
    # Placeholder — no implementation and no callers visible in this chunk.
    # Presumably intended to build the target-variable string of an edge
    # equation (counterpart to `_get_source_str`) — TODO confirm.
    pass