# -*- coding: utf-8 -*-
#
#
# PyRates software framework for flexible implementation of neural
# network model_templates and simulations. See also:
# https://github.com/pyrates-neuroscience/PyRates
#
# Copyright (C) 2017-2018 the original authors (Richard Gast and
# Daniel Rose), the Max-Planck-Institute for Human Cognitive Brain
# Sciences ("MPI CBS") and contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>
#
# CITATION:
#
# Richard Gast and Daniel Rose et. al. in preparation
"""
"""
# external _imports
import time
from typing import Union, Dict, Iterator, Optional, List, Tuple
from warnings import filterwarnings
import re as _re
from networkx import MultiDiGraph, DiGraph, topological_sort
import numpy as np
from copy import deepcopy
from warnings import warn
# pyrates-internal _imports
from pyrates.backend import PyRatesException, PyRatesWarning
from pyrates.ir.node import NodeIR
from pyrates.ir.edge import EdgeIR
from pyrates.ir.abc import AbstractBaseIR
from pyrates.backend.parser import parse_equations, get_unique_label, replace
from pyrates.backend.computegraph import ComputeGraph, ComputeVar, ComputeGraphBackProp
__author__ = "Daniel Rose, Richard Gast"
__status__ = "Development"
in_edge_indices = {} # cache for the number of input edges per network node
in_edge_vars = {} # cache for the input variables that enter at each target operator
#####################
# class definitions #
#####################
# networkx-based representation of all nodes and edges in circuit
[docs]class NetworkGraph(AbstractBaseIR):
"""View on the entire network as a graph. Translates edge operations and attributes into a form that allows to
parse the network graph into a final compute graph."""
def __init__(self, label: str = "circuit", nodes: Dict[str, NodeIR] = None, edges: list = None,
template: str = None, step_size: float = 1e-3, step_size_adaptation: bool = True, verbose: bool = True,
**kwargs):
super().__init__(label=label, template=template)
self._edge_idx_counter = 0
self.step_size = step_size
self.step_size_adaptation = step_size_adaptation
self.graph = MultiDiGraph()
# Set to True by `_add_edge_buffer` when the in-place ring-buffer
# delay path is taken; consulted by CircuitIR.__init__ to fail early
# on backends that don't support mutable-buffer code-gen.
self._uses_edge_delay_buffer = False
if verbose:
print("Compilation Progress")
print("--------------------")
print('\t(1) Translating the circuit template into a networkx graph representation...')
# add nodes to graph
if nodes:
nodes = ((key, {"node": node}) for key, node in nodes.items())
self.graph.add_nodes_from(nodes)
# add edges to graph
if edges:
for (source, target, edge_dict) in edges:
self.add_edge(source, target, **edge_dict)
if verbose:
print('\t\t...finished.')
print("\t(2) Preprocessing edge transmission operations...")
# translate edge operations and attributes into graph operators
self._preprocess_edge_operations(dde_approx=kwargs.pop('dde_approx', 0),
matrix_sparseness=kwargs.pop('matrix_sparseness', 0.1),
vectorized=kwargs.pop('vectorized'))
if verbose:
print("\t\t...finished.")
[docs] def __getitem__(self, key: str):
"""
Custom implementation of __getitem__ that dissolves strings of form "key1/key2/key3" into
lookups of form self[key1][key2][key3].
Parameters
----------
key
Returns
-------
item
"""
try:
return super().__getitem__(key)
except KeyError:
keys = key.split('/')
for i in range(len(keys)):
if "/".join(keys[:i+1]) in self.nodes:
break
key_iter = iter(['/'.join(keys[:i+1])] + keys[i+1:])
key = next(key_iter)
item = self.getitem_from_iterator(key, key_iter)
for key in key_iter:
item = item.getitem_from_iterator(key, key_iter)
return item
[docs] def add_edge(self, source: str, target: str, edge_ir: EdgeIR = None, weight: float = 1., delay: float = None,
spread: float = None, **data):
"""
Parameters
----------
source
target
edge_ir
weight
delay
spread
data
If no template is given, `data` is assumed to conform to the format that is needed to add an edge. I.e.,
`data` needs to contain fields for `weight`, `delay`, `edge_ir`, `source_var`, `target_var`.
Returns
-------
"""
# step 1: parse and verify source and target specifiers
source_node, source_var = self._parse_edge_specifier(source, data, "source_var")
target_node, target_var = self._parse_edge_specifier(target, data, "target_var")
# step 2: parse source variable specifier (might be single string or dictionary for multiple source variables)
source_vars, extra_sources = self._parse_source_vars(source_node, source_var, edge_ir,
data.pop("extra_sources", None))
# step 3: add edges
attr_dict = dict(edge_ir=edge_ir,
weight=weight,
delay=delay,
spread=spread,
source_var=source_vars,
target_var=target_var,
extra_sources=extra_sources,
**data)
self.graph.add_edge(source_node, target_node, **attr_dict)
[docs] def getitem_from_iterator(self, key: str, key_iter: Iterator[str]):
return self.graph.nodes[key]["node"]
def _preprocess_edge_operations(self, dde_approx: int = 0, vectorized: bool = True, **kwargs):
"""Restructures network graph to collapse nodes and edges that share the same operator graphs. Variable values
get an additional vector dimension. References to the respective index is saved in the internal `label_map`."""
# go through nodes and create buffers for delayed outputs and mappings for their inputs
#######################################################################################
for node_name in self.nodes:
node_outputs = self.graph.out_edges(node_name, keys=True)
node_outputs = self._sort_edges(node_outputs, 'source_var', data_included=False)
# loop over ouput variables of node
for i, (out_var, edges) in enumerate(node_outputs.items()):
# extract delay info from variable projections
op_name, var_name = out_var.split('/')
# Separate matrix-connectivity edges (2-D weight array) from scalar edges.
# Matrix edges carry their own delay handling via _add_matrix_delay.
matrix_edges, scalar_edges = [], []
for s, t, e in edges:
w = self.edges[s, t, e]['weight']
if isinstance(w, np.ndarray) and w.ndim == 2:
matrix_edges.append((s, t, e))
else:
scalar_edges.append((s, t, e))
for s, t, e in matrix_edges:
d = self.edges[s, t, e].get('delay')
v = self.edges[s, t, e].get('spread')
if d is not None and d > self.step_size:
self._add_matrix_delay(node_name, op_name, var_name, (s, t, e),
d, v, dde_approx=dde_approx)
if not scalar_edges:
continue
delays, spreads, nodes, add_delay = self._collect_delays_from_edges(scalar_edges)
# add synaptic buffer to output variables with delay
if add_delay:
# Clear delay fields from edges so _generate_edge_equation ignores them.
# Kept here (not inside _collect_delays_from_edges) so that method is pure.
for s, t, e in scalar_edges:
self.edges[s, t, e]['source_idx'] = []
self.edges[s, t, e]['delay'] = None
if vectorized:
self._add_edge_buffer(node_name, op_name, var_name, edges=scalar_edges, delays=delays,
nodes=nodes, spreads=spreads, dde_approx=dde_approx)
else:
# TODO: sort edges into unique delay/spread combinations and only loop over those
if spreads:
for i, (edge, delay, spread, node) in enumerate(zip(scalar_edges, delays, spreads, nodes)):
self._add_edge_buffer(node_name, op_name, var_name, edges=[edge], delays=[delay],
nodes=[node], spreads=[spread], dde_approx=dde_approx,
buffer_id=f"_out{i}")
else:
for i, (edge, delay, node) in enumerate(zip(scalar_edges, delays, nodes)):
self._add_edge_buffer(node_name, op_name, var_name, edges=[edge], delays=[delay],
nodes=[node], dde_approx=dde_approx, buffer_id=f"_out{i}")
# go through nodes again, and collect and process all inputs to each node variable
##################################################################################
for node_name in self.nodes:
node_inputs = self.graph.in_edges(node_name, keys=True)
node_inputs = self._sort_edges(node_inputs, 'target_var', data_included=False)
# loop over inputs to node variable
for i, (in_var, edges) in enumerate(node_inputs.items()):
# extract info from projections to input variable
op_name, var_name = in_var.split('/')
data = self._collect_from_edges(edges,
keys=['source_var', 'weight', 'source_idx', 'target_idx',
'edge_ir', 'edge_var_map'])
# create the final equations for all edges that target the input variable
self._generate_edge_equation(tnode=node_name, top=op_name, tvar=var_name, inputs=data, **kwargs)
def _sort_edges(self, edges: List[tuple], attr: str, data_included: bool = False) -> dict:
"""Sorts edges according to the given edge attribute.
Parameters
----------
edges
Collection of edges of interest.
attr
Name of the edge attribute.
Returns
-------
dict
Key-value pairs of the different values the attribute can take on (keys) and the list of edges for which
the attribute takes on that value (value).
"""
edges_new = {}
if data_included:
for edge in edges:
if len(edge) == 4:
source, target, edge, data = edge
else:
raise ValueError("Missing edge index. This error message should not occur.")
value = self.edges[source, target, edge][attr]
if value not in edges_new.keys():
edges_new[value] = [(source, target, edge, data)]
else:
edges_new[value].append((source, target, edge, data))
else:
for edge in edges:
if len(edge) == 3:
source, target, edge = edge
else:
raise ValueError("Missing edge index. This error message should not occur.")
value = self.edges[source, target, edge][attr]
if value not in edges_new.keys():
edges_new[value] = [(source, target, edge)]
else:
edges_new[value].append((source, target, edge))
return edges_new
def _collect_delays_from_edges(self, edges):
means, stds, nodes = [], [], []
for s, t, e in edges:
# extract delay
d = self.edges[s, t, e]['delay']
if type(d) is list:
d = [1 if d_tmp is None else d_tmp for d_tmp in d]
# extract and process delay distribution spread
v = self.edges[s, t, e].pop('spread', [0])
n_slots = max(len(self.edges[s, t, e]['target_idx']), 1)
if v is None or np.sum(v) == 0:
v = [0] * n_slots
discretize = True
else:
discretize = False
v = self._process_delays(v, discretize=discretize)
# finalize edge delay
if d is None or np.sum(d) == 0:
d = [1] * n_slots
else:
d = self._process_delays(d, discretize=discretize)
# extract source var index
source = self.edges[s, t, e]['source_idx']
if len(d) > 1 and len(source) == 1:
source = source * len(d)
# collect values
means += d
stds += v
nodes.append(source)
# check whether edge delays have to be implemented or can be ignored
max_delay = np.max(means)
add_delay = ("int" in str(type(max_delay)) and max_delay > 1) or \
("float" in str(type(max_delay)) and max_delay > self.step_size)
if sum(stds) == 0:
stds = None
return means, stds, nodes, add_delay
def _collect_from_edges(self, edges: list, keys: list):
data = dict()
for source, target, idx in edges:
edge = self.edges[(source, target, idx)]
if source not in data:
data[source] = dict()
for key in keys:
raw = edge.get(key)
val = raw if isinstance(raw, (np.ndarray, EdgeIR)) else deepcopy(raw)
try:
data[source][key].extend(val)
except AttributeError:
field = data[source][key]
if type(field) is str or field is None:
pass
else:
data[source][key] = [field, val]
except KeyError:
data[source][key] = val
return data
def _add_matrix_delay(self, node: str, op: str, var: str, edge: tuple,
delay: float, spread: Optional[float] = None,
dde_approx: int = 0, buffer_id: str = "") -> None:
"""Add delay buffer equations for a matrix-connectivity edge.
Supports two modes:
- **Discrete ring buffer** (default, fixed step size): ``(Ns, d+1)`` state variable,
updated by roll/index_axis each step. Selected when *spread* is ``None``,
*dde_approx* is 0, and ``step_size_adaptation`` is ``False``.
- **ODE cascade** (gamma-kernel or explicit order): a chain of *n* ODEs of shape
``(Ns,)`` that convolve the source signal with a gamma kernel. Selected when
*spread* > 0, *dde_approx* > 0, or ``step_size_adaptation`` is ``True``.
"""
s, t, e = edge
node_ir = self[node]
op_info = node_ir[op]
node_var = self[f"{node}/{op}/{var}"]
var_shape = node_var.get('shape', ())
Ns = int(var_shape[0]) if var_shape else 1
var_dict: dict = {}
buffer_eqs: list = []
use_ring_buffer = (
(spread is None or spread == 0)
and dde_approx == 0
and not self.step_size_adaptation
)
if use_ring_buffer:
# Record the in-place-buffer requirement so CircuitIR.__init__
# can reject backends that can't honor it (see the matching
# branch in `_add_edge_buffer` and the check in CircuitIR).
self._uses_edge_delay_buffer = True
# --- Discrete ring buffer of shape (Ns, d_steps+1) ---
d_steps = self._preprocess_delay(delay, discretize=True)
buf = f'{var}_buffer{buffer_id}'
buf_out = f'{var}_buffered{buffer_id}'
var_dict[buf] = {'vtype': 'variable', 'dtype': 'float',
'shape': (Ns, d_steps + 1), 'value': 0.}
var_dict[buf_out] = {'vtype': 'variable', 'dtype': 'float',
'shape': (Ns,), 'value': 0.}
# Inline d_steps as a literal so index_axis returns shape (Ns,) not (Ns, 1)
buffer_eqs = [
f"index_axis({buf}) = roll({buf}, 1, 1)",
f"index_axis({buf}, 0, 1) = {var}",
f"{buf_out} = index_axis({buf}, {d_steps}, 1)",
]
else:
# --- ODE cascade (gamma kernel or adaptive step size) ---
if spread is not None and spread > 0:
n = max(1, int(round((delay / spread) ** 2)))
elif dde_approx > 0:
n = dde_approx
else:
n = 1 # minimum ODE order for adaptive step size
a = n / delay if delay else 0.0
for k in range(1, n + 1):
zk = f'{var}_d{k}{buffer_id}'
zk_rate = f'k_d{k}{buffer_id}'
prev = var if k == 1 else f'{var}_d{k-1}{buffer_id}'
var_dict[zk] = {'vtype': 'state_var', 'dtype': 'float',
'shape': (Ns,), 'value': [0.0] * Ns}
var_dict[zk_rate] = {'vtype': 'constant', 'dtype': 'float',
'value': a, 'shape': (1,)}
buffer_eqs.append(f"d/dt * {zk} = {zk_rate} * ({prev} - {zk})")
buf_out = f'{var}_buffered{buffer_id}'
var_dict[buf_out] = {'vtype': 'variable', 'dtype': 'float',
'shape': (Ns,), 'value': [0.0] * Ns}
buffer_eqs.append(f"{buf_out} = {var}_d{n}{buffer_id}")
# Attach buffer equations and variables to the source operator
op_info['equations'] += buffer_eqs
op_info['variables'].update(var_dict)
op_info['output'] = buf_out
# Update intra-node successor inputs (mirrors _add_edge_buffer)
for succ in node_ir.op_graph.succ[op]:
inputs = self[f"{node}/{succ}"]['inputs']
if var not in inputs:
inputs[var] = {'sources': {op}}
# Point the edge at the buffered source variable
self.edges[s, t, e]['source_var'] = f"{op}/{buf_out}"
self.edges[s, t, e]['delay'] = None
def _add_edge_buffer(self, node: str, op: str, var: str, edges: list, delays: list, nodes: list,
spreads: Optional[list] = None, dde_approx: int = 0, buffer_id: str = "") -> None:
"""Adds a buffer variable to an edge.
Parameters
----------
node
Name of the source node of the edge.
op
Name of the source operator of the edge.
var
Name of the source variable of the edge.
edges
List with edge identifier tuples (source_name, target_name, edge_idx).
delays
edge delays.
nodes
Node indices for each edge delay.
spreads
Standard deviations of delay distributions around means given by `delays`.
dde_approx
Only relevant for delayed systems. If larger than zero, all discrete delays in the system will be
automatically approximated by a system of (n+1) coupled ODEs that represent a convolution with a
gamma distribution centered around the original delay (n is the approximation order).
Returns
-------
None
"""
if not delays:
return
max_delay = np.max(delays)
# extract target shape and node
node_var = self[f"{node}/{op}/{var}"]
target_shape = node_var['shape']
node_ir = self[node]
nodes_tmp = list()
for n in nodes:
nodes_tmp += n
source_idx = np.asarray(nodes_tmp, dtype='int').flatten()
# ODE approximation to DDE
##########################
if dde_approx or spreads:
# --- Per-edge ODE orders and rates ---
if spreads:
orders, rates = [], []
for m, v in zip(delays, spreads):
if v > 0:
n_order = int(np.round((m / v) ** 2))
n_order = n_order if m and n_order > dde_approx else dde_approx
else:
n_order = dde_approx if m else 0
orders.append(n_order)
rates.append(n_order / m if m else 0.0)
else:
orders = [dde_approx if m else 0 for m in delays]
rates = [dde_approx / m if m else 0.0 for m in delays]
# --- Group delay slots by (order, rate) — slots in the same group share one ODE chain ---
groups = {}
for slot_idx, (n_order, rate, src) in enumerate(zip(orders, rates, source_idx)):
key = (n_order, round(rate, 12))
if key not in groups:
groups[key] = []
groups[key].append((slot_idx, int(src)))
n_src_var = sum(target_shape) if target_shape else 1
buffer_eqs, var_dict = [], {}
buf_var = f"{var}_buffered{buffer_id}"
var_dict[buf_var] = {'vtype': 'variable', 'dtype': 'float',
'shape': (len(delays),), 'value': 0.0}
for chain_id, ((n_order, _), group) in enumerate(groups.items()):
slot_indices = [g[0] for g in group]
src_indices = [g[1] for g in group]
G = len(group)
rate_val = rates[slot_indices[0]]
# Build chain input: use source var directly when group covers all its elements
if sorted(src_indices) == list(range(n_src_var)):
chain_in = var
elif G == 1:
chain_in = f"index({var}, {src_indices[0]})"
else:
src_name = f"{var}_src{chain_id}{buffer_id}"
var_dict[src_name] = {'vtype': 'constant', 'dtype': 'int',
'value': np.asarray(src_indices, dtype='int'),
'shape': (G,)}
chain_in = f"index({var}, {src_name})"
# Build the ODE chain (n_order stages, one shared rate constant)
chain_shape = (G,) if G > 1 else ()
if n_order > 0:
rate_name = f"k_d{chain_id}{buffer_id}"
var_dict[rate_name] = {'vtype': 'constant', 'dtype': 'float', 'value': rate_val}
prev = chain_in
for k in range(1, n_order + 1):
zk = f"{var}_d{chain_id}_{k}{buffer_id}"
var_dict[zk] = {'vtype': 'state_var', 'dtype': 'float',
'shape': chain_shape, 'value': 0.}
buffer_eqs.append(f"d/dt * {zk} = {rate_name} * ({prev} - {zk})")
prev = zk
else:
prev = chain_in # zero-order: pass-through (no ODE stages)
# Write chain output into the correct slots of the buffer.
# Partial-buffer writes (the G == 1 and else branches below)
# are JAX-compatible when ALL chains write the buffer on every
# call — the buffer doesn't carry state across calls. Ring
# buffers are different: they DO carry state across calls and
# are flagged separately in the `_add_edge_buffer` discrete
# branch (and in `_add_matrix_delay`'s ring-buffer branch).
all_slots = (slot_indices == list(range(len(delays))))
if all_slots:
buffer_eqs.append(f"{buf_var} = {prev}")
elif G == 1:
buffer_eqs.append(f"index({buf_var}, {slot_indices[0]}) = {prev}")
else:
slot_name = f"{var}_slots{chain_id}{buffer_id}"
var_dict[slot_name] = {'vtype': 'constant', 'dtype': 'int',
'value': np.asarray(slot_indices, dtype='int'),
'shape': (len(slot_indices),)}
buffer_eqs.append(f"index({buf_var}, {slot_name}) = {prev}")
# discretized edge buffers
##########################
elif not self.step_size_adaptation:
# Record that this network uses the in-place ring-buffer edge-delay
# path; CircuitIR.__init__ will refuse to compile if the chosen
# backend can't support it (notably JaxBackend, whose arrays are
# immutable — the buffer would never accumulate).
self._uses_edge_delay_buffer = True
# create buffer variable shapes
if len(target_shape) < 1 or (len(target_shape) == 1 and target_shape[0] == 1):
buffer_shape = (max_delay + 1,)
else:
buffer_shape = (target_shape[0], max_delay + 1)
# create buffer variable definitions
var_dict = {f'{var}_buffer{buffer_id}': {'vtype': 'variable',
'dtype': 'float',
'shape': buffer_shape,
'value': 0.},
f'{var}_buffered{buffer_id}': {'vtype': 'variable',
'dtype': 'float',
'shape': (len(delays),),
'value': 0.},
f'{var}_delays{buffer_id}': {'vtype': 'constant',
'dtype': 'int',
'value': delays},
f'source_idx{buffer_id}': {'vtype': 'constant',
'dtype': 'int',
'value': source_idx}}
# create buffer equations
if len(target_shape) < 1 or (len(target_shape) == 1 and target_shape[0] == 1):
# For a single delay, inline the literal so buffer[int] returns a 0-d scalar
# rather than buffer[(1,)_array] which returns a (1,) array and causes
# numpy 2.3+ errors when assigned to a scalar dy[i] slot.
delay_ref = str(delays[0]) if len(delays) == 1 else f"{var}_delays{buffer_id}"
buffer_eqs = [f"index_axis({var}_buffer{buffer_id}) = roll({var}_buffer{buffer_id}, 1)",
f"index({var}_buffer{buffer_id}, 0) = {var}",
f"{var}_buffered{buffer_id} = index({var}_buffer{buffer_id}, {delay_ref})"]
else:
buffer_eqs = [f"index_axis({var}_buffer{buffer_id}) = roll({var}_buffer{buffer_id}, 1, 1)",
f"index_axis({var}_buffer{buffer_id}, 0, 1) = {var}",
f"{var}_buffered{buffer_id} = index_2d({var}_buffer{buffer_id}, source_idx{buffer_id}, "
f"{var}_delays{buffer_id})"]
# Turn ODE system into DDE system
#################################
else:
warn(PyRatesWarning(f'PyRates detected an edge definition that implies to represent the model as a '
f'delayed differential equation system.\n PyRates will thus attempt to access the '
f'history of the source variable {var} of operator {op} on node {node}. '
f'Note that this requires {var} to be a state-variable, i.e. a variable defined by '
f'a differential equation.'))
# create buffer variable definitions
var_dict = {f'{var}_buffered{buffer_id}': {'vtype': 'variable',
'dtype': 'float',
'shape': (len(delays),),
'value': 0.}
}
buffer_eqs = []
for i, (d, sidx) in enumerate(zip(delays, source_idx)):
var_delayed = f"past({var}, {d})" if type(d) is float or d != 1 else var
if len(target_shape) < 1 or (len(target_shape) == 1 and target_shape[0] == 1):
buffer_eqs.append(f"{var}_buffered{buffer_id} = {var_delayed}")
else:
buffer_eqs.append(f"index({var}_buffered{buffer_id}, {sidx}) = index({var_delayed}, {sidx})")
# add buffer equations to node operator
op_info = node_ir[op]
existing_vars = set(op_info.get('variables', {}).keys())
conflicts = existing_vars & set(var_dict.keys())
if conflicts:
raise PyRatesException(
f"Buffer variable name collision in operator '{op}' on node '{node}': {conflicts}. "
f"Use a unique buffer_id to avoid this."
)
op_info['equations'] += buffer_eqs
op_info['variables'].update(var_dict)
op_info['output'] = f"{var}_buffered{buffer_id}"
# update input information of node operators connected to this operator
for succ in node_ir.op_graph.succ[op]:
inputs = self[f"{node}/{succ}"]['inputs']
if var not in inputs.keys():
inputs[var] = {'sources': {op}}
# update edge information
idx_l = 0
for i, edge in enumerate(edges):
s, t, e = edge
self.edges[s, t, e]['source_var'] = f"{op}/{var}_buffered{buffer_id}"
if len(edges) > 1:
idx_h = idx_l + len(nodes[i])
self.edges[s, t, e]['source_idx'] = list(range(idx_l, idx_h))
idx_l = idx_h
def _generate_edge_equation(self, tnode: str, top: str, tvar: str, inputs: dict, matrix_sparseness: float = 0.1,
weight_minimum: float = 1e-8):
# step 0: check properties of the target variable and its inputs
multiple_inputs = len(inputs) > 1
tval = self[f"{tnode}/{top}/{tvar}"]
if tval['shape']:
tsize = sum(tval['shape'])
elif type(tval['value']) is list:
tsize = len(tval['value'])
else:
tsize = 0
# step 1: collect all inputs
weights, source_indices, target_indices, sources = [], [], [], []
edge_irs, edge_var_maps = [], []
for snode, sinfo in inputs.items():
weights.append(sinfo['weight'])
source_indices.append(sinfo['source_idx'])
target_indices.append(sinfo['target_idx'])
sources.append((snode,) + tuple(sinfo['source_var'].split('/')))
edge_irs.append(sinfo.get('edge_ir'))
edge_var_maps.append(sinfo.get('edge_var_map') or {})
# step 2: process incoming edges
source_vars, args = {}, {}
eqs, in_vars = [], []
for i, (weight, sidx, tidx, (snode, sop, svar), edge_ir, edge_var_map) in \
enumerate(zip(weights, source_indices, target_indices, sources, edge_irs, edge_var_maps)):
# define variable name strings (adjusted when multiple inputs share same target var)
if multiple_inputs:
in_shape = (tsize,)
t_str = f'{tvar}_in{i}'
w_str = f'weight_in{i}'
s_str = f'{svar}_in{i}'
sidx_str = f'source_idx_in{i}'
tidx_str = f'target_idx_in{i}'
args[t_str] = {'value': np.zeros(in_shape), 'dtype': 'float', 'vtype': 'variable',
'shape': in_shape}
else:
t_str = tvar
w_str = 'weight'
s_str = svar
sidx_str = 'source_idx'
tidx_str = 'target_idx'
# case 0: matrix edge — weight is a 2-D numpy array supplied directly
# (used by Connectivity; no scalar expansion needed)
if isinstance(weight, np.ndarray) and weight.ndim == 2:
source_vars[s_str] = {'sources': [sop], 'node': snode, 'var': svar}
# Always register the full 2-D weight for cases 0b/0c (wsum uses it);
# case 0a may override with a 1-D vector for the single-source path.
args[w_str] = {'vtype': 'constant', 'value': weight, 'dtype': 'float', 'shape': weight.shape}
if edge_ir is None:
# case 0a: simple matvec — no coupling function
# When n_source == 1 the source variable is a scalar at runtime.
# np.dot((n,1), scalar) returns (n,1) which causes shape errors
# when assigned to an (n,) target. Squeeze axis=1 to a 1-D weight
# vector and use broadcast multiply instead.
if weight.shape[1] == 1:
w_1d = weight.squeeze(axis=1)
args[w_str] = {'vtype': 'constant', 'value': w_1d, 'dtype': 'float', 'shape': w_1d.shape}
eqs.append(f"{t_str} = {w_str} * {s_str}")
else:
eqs.append(f"{t_str} = matvec({w_str}, {s_str})")
else:
# case 0b / 0c: matrix coupling with custom edge equations
def _subst(text, subst_map):
for var, repl in subst_map.items():
text = _re.sub(r'\b' + _re.escape(var) + r'\b', repl, text)
return text
def _de_lhs_var(lhs):
"""Extract bare variable name from a DE left-hand side."""
return _re.sub(r"d/dt\s*\*?\s*|'", "", lhs).strip()
Nt, Ns = weight.shape
# Detect which edge variables are state variables (have DEs)
edge_de_sv_names = set()
for _ok in edge_ir.op_graph.nodes:
for _eq in edge_ir.op_graph.nodes[_ok].get('equations', []):
_lhs = _eq.split('=')[0].strip()
if "d/dt" in _lhs or "'" in _lhs:
edge_de_sv_names.add(_de_lhs_var(_lhs))
# Build broadcast substitution map for edge input variables
expr_map = {}
for ev, info in edge_var_map.items():
if info['role'] == 'source':
expr_map[ev] = f'broadcast_pre({s_str})'
else:
post_var = info['var']
post_op = info['op']
expr_map[ev] = f'broadcast_post({post_var})'
source_vars[post_var] = {'sources': [post_op], 'node': tnode, 'var': post_var}
if edge_de_sv_names:
# case 0c: dynamic edge
# State variables are stored flat (Nt*Ns,) in the global state vector.
# reshape2d / flatten1d views are used in generated equations so that
# the ODE operates in (Nt, Ns) space while the solver sees a 1-D vector.
for _ok in edge_ir.op_graph.nodes:
for vk, vi in edge_ir.op_graph.nodes[_ok].get('variables', {}).items():
vi_dict = vi if isinstance(vi, dict) else {}
vtype = vi_dict.get('vtype', 'constant')
if vk in edge_de_sv_names:
sv_flat = f'{vk}_edge{i}_flat'
sv_init = vi_dict.get('value', 0.0)
if isinstance(sv_init, list):
sv_init = sv_init[0]
expr_map[vk] = f'reshape2d({sv_flat}, {Nt}, {Ns})'
args[sv_flat] = {
'vtype': 'state_var', 'dtype': 'float',
'value': [float(sv_init)] * (Nt * Ns),
'shape': (Nt * Ns,),
}
elif vtype == 'constant' and vk not in edge_var_map:
const_name = f'{vk}_edge{i}'
val = vi_dict.get('value', 0.0)
if isinstance(val, list):
val = val[0]
args[const_name] = {
'vtype': 'constant', 'dtype': 'float',
'value': float(val), 'shape': (1,),
}
expr_map[vk] = const_name
last_out = None
for _ok in topological_sort(edge_ir.op_graph):
_od = edge_ir.op_graph.nodes[_ok]
for _eq in _od.get('equations', []):
_lhs, _rhs = (_s.strip() for _s in _eq.split('=', 1))
_rhs_s = _subst(_rhs, expr_map)
if "d/dt" in _lhs or "'" in _lhs:
sv_flat = f'{_de_lhs_var(_lhs)}_edge{i}_flat'
eqs.append(f"{sv_flat}' = flatten1d({_rhs_s})")
else:
expr_map[_lhs] = _rhs_s
last_out = _od.get('output')
final_expr = expr_map.get(last_out, last_out)
eqs.append(f"{t_str} = wsum({w_str}, {final_expr})")
else:
# case 0b: non-dynamic (algebraic) edge — inline and reduce
last_out = None
for _ok in topological_sort(edge_ir.op_graph):
_od = edge_ir.op_graph.nodes[_ok]
for _eq in _od.get('equations', []):
_lhs, _rhs = (_s.strip() for _s in _eq.split('=', 1))
expr_map[_lhs] = _subst(_rhs, expr_map)
last_out = _od.get('output')
final_expr = expr_map.get(last_out, last_out)
eqs.append(f"{t_str} = wsum({w_str}, {final_expr})")
in_vars.append(t_str)
continue
# get source variable size
sval = self[f"{snode}/{sop}/{svar}"]
if sval['shape']:
ssize = sum(sval['shape'])
elif type(sval['value']) is list:
ssize = len(sval['value'])
else:
ssize = 1
# check whether the edge can be realized via a matrix product
if not tidx:
tidx = [0 for _ in range(len(sidx))]
n, m = len(tidx), len(sidx)
if len(np.unique(tidx)) < len(tidx):
if not sidx:
sidx = [i for i in range(len(weight))]
dot_edge = True
elif n*m > 1 and tsize*ssize > 1:
dot_edge = len(weight) / (n * m) > matrix_sparseness
else:
dot_edge = False
# case I: realize edge projection via a matrix product
if dot_edge:
# create weight matrix
sidx_unique = np.unique(sidx)
tidx_unique = np.unique(tidx)
weight_mat = np.zeros((len(tidx_unique), len(sidx_unique)))
for t, s, w in zip(tidx, sidx, weight):
row = np.argwhere(tidx_unique == t).squeeze()
col = np.argwhere(sidx_unique == s).squeeze()
weight_mat[row, col] = w
# define edge projection equation
s_str_final = _get_indexed_var_str(s_str, sidx_unique, ssize, idx_str=sidx_str, arg_dict=args)
t_str_final = _get_indexed_var_str(t_str, tidx_unique, tsize, idx_str=tidx_str, arg_dict=args)
if len(sidx_unique) == 1:
# Single-source: the source variable is a scalar at runtime
# (time-series arrays are squeezed to 1D so arr[t] returns 0-d).
# Use broadcast multiply with a 1D weight vector so that
# weight_vec * scalar = (n_targets,) rather than the (n_targets,1)
# result that numpy's dot gives for a 2D matrix times a scalar.
weight_mat = weight_mat.squeeze(axis=1)
eq = f"{t_str_final} = {w_str} * {s_str_final}"
else:
eq = f"{t_str_final} = matvec({w_str}, {s_str_final})"
args[w_str] = {'vtype': 'constant', 'value': weight_mat, 'dtype': 'float', 'shape': weight_mat.shape}
# case II: realize edge projection via source and target indexing
else:
# check wether weighting of source variables is required
if all([abs(w-1) < weight_minimum for w in weight]):
weighting = ""
else:
weighting = f" * {w_str}"
args[w_str] = {'vtype': 'constant', 'dtype': 'float', 'value': weight if ssize > 1 else weight[0]}
# get final source and target strings
s_str_final = _get_indexed_var_str(s_str, sidx, ssize, reduce=m == 1 and tsize > 1 and n == 1,
idx_str=sidx_str, arg_dict=args)
t_str_final = _get_indexed_var_str(t_str, tidx, tsize, reduce=tsize > 1 or ssize < tsize,
idx_str=tidx_str, arg_dict=args)
# define edge equation
eq = f"{t_str_final} = {s_str_final}{weighting}"
# add equation and source information
eqs.append(eq)
source_vars[s_str] = {'sources': [sop], 'node': snode, 'var': svar}
in_vars.append(t_str)
# step 3: process multiple inputs to same variable
if multiple_inputs:
# finalize edge equations
eq = f"{tvar} = {'+'.join(in_vars)}"
eqs.append(eq)
# step 4: define target variable as operator output
args[tvar] = tval
args[tvar]['vtype'] = 'variable'
# step 5: add edge operator to target node
if tnode not in in_edge_indices:
in_edge_indices[tnode] = 0
op_name = f'in_edge_{in_edge_indices[tnode]}'
in_edge_indices[tnode] += 1
tnode_ir = self[tnode]
tnode_ir.add_op(op_name, inputs=source_vars, output=tvar, equations=eqs, variables=args)
tnode_ir.add_op_edge(op_name, top)
# step 6: add input information to target operator
inputs = self[tnode][top]['inputs']
if tvar in inputs.keys():
inputs[tvar]['sources'].add(op_name)
else:
inputs[tvar] = {'sources': [op_name]}
def _process_delays(self, d, discretize=True):
if type(d) is list:
d = np.asarray(d).squeeze()
d = [self._preprocess_delay(d_tmp, discretize=discretize) for d_tmp in d] if d.shape else \
[self._preprocess_delay(d, discretize=discretize)]
else:
d = [self._preprocess_delay(d, discretize=discretize)]
return d
def _preprocess_delay(self, delay, discretize=True):
return int(np.round(delay / self.step_size, decimals=0)) if discretize and not self.step_size_adaptation \
else delay
def _bool_to_idx(self, v):
v_idx = np.argwhere(v).squeeze()
v_dict = {}
if v_idx.shape and v_idx.shape[0] > 1 and all(np.diff(v_idx) == 1):
v_idx_str = (f"{v_idx[0]}", f"{v_idx[-1] + 1}")
elif v_idx.shape and v_idx.shape[0] > 1:
var_name = f"delay_idx{self._edge_idx_counter}"
v_idx_str = f"{var_name}"
v_dict[var_name] = {'value': v_idx, 'vtype': 'constant'}
self._edge_idx_counter += 1
else:
try:
v_idx_str = f"{v_idx.max()}"
except ValueError:
v_idx_str = ""
return v_idx.tolist(), v_idx_str, v_dict
def _parse_source_vars(self, source_node: str, source_var: Union[str, dict], edge_ir, extra_sources: dict = None
) -> Tuple[Union[str, dict], dict]:
"""Parse is source variable specifications. This tests, whether a single or more source variables and verifies
all given paths.
Parameters
----------
source_node
String that specifies a single node as source of an edge.
source_var
Single variable specifier string or dictionary of form `{source_op/source_var: edge_op/edge_var
edge_ir
Instance of an EdgeIR that contains information about the internal structure of an edge.
Returns
-------
source_var
"""
# step 1: figure out, whether only one or more source variables are defined
try:
# try to treat source_var as dictionary
n_source_vars = len(source_var.keys())
except AttributeError:
# not a dictionary, so must be a string
n_source_vars = 1
else:
# was a dictionary, treat case that it only has length 1
if n_source_vars == 1:
source_var = next(iter(source_var))
if n_source_vars == 1:
_, _ = source_var.split("/") # should be op, var, but we do not need them here
self._verify_path(source_node, source_var)
else:
# verify that number of source variables matches number of input variables in edge
if extra_sources is not None:
n_source_vars += len(extra_sources)
if n_source_vars != edge_ir.n_inputs:
raise PyRatesException(f"Mismatch between number of source variables ({n_source_vars}) and "
f"inputs ({edge_ir.n_inputs}) in an edge with source '{source_node}' and source"
f"variables {source_var}.")
for node_var, edge_var in source_var.items():
self._verify_path(source_node, node_var)
if extra_sources is not None:
for edge_var, source in extra_sources.items():
node, op, var = source.split("/")
source = "/".join((node, op, var))
self._verify_path(source)
extra_sources[edge_var] = source
return source_var, extra_sources
def _verify_path(self, *parts: str):
"""
Parameters
----------
parts
One or more parts of a path string
Returns
-------
"""
# go trough circuit hierarchy
path = "/".join(parts)
# check if path is valid
if path not in self:
raise PyRatesException(f"Could not find object with path `{path}`.")
@staticmethod
def _parse_edge_specifier(specifier: str, data: dict, var_string: str) -> Tuple[str, Union[str, dict]]:
"""Parse source or target specifier for an edge.
Parameters
----------
specifier
String that defines either a specific node or complete variable path of source or target for an edge.
Format: *circuits/node/op/var
data
dictionary containing additional information about the edge. This function looks for a variable specifier
as specified in `var_string`
var_string
String that points to an optional key of the `data` dictionary. Should be either 'source_var' or
'target_var'
Returns
-------
(node, var)
"""
# step 1: try to get source and target variables from data dictionary, if not available, get them from
# source/target string
try:
# try to get source variable info from data dictionary
var = data.pop(var_string) # type: Union[str, dict]
except KeyError:
# not found, assume variable info is contained in `source`
# also means that there is only one source variable (on the main source node) to take care of
*node, op, var = specifier.split("/")
node = "/".join(node)
var = "/".join((op, var))
else:
# source_var was in data, so `source` contains only info about source node
node = specifier # type: str
return node, var
@property
def nodes(self):
"""Shortcut to self.graph.nodes. See documentation of `networkx.MultiDiGraph.nodes`."""
return self.graph.nodes
@property
def edges(self):
"""Shortcut to self.graph.edges. See documentation of `networkx.MultiDiGraph.edges`."""
return self.graph.edges
# intermediate representation that provides interface between frontend and backend
[docs]class CircuitIR(AbstractBaseIR):
"""Custom graph data structure that represents a backend of nodes and edges with associated equations
and variables."""
__slots__ = ["label", "_front_to_back", "graph", "_t", "_verbose", "_dt", "_dt_adapt", "_def_shape"]
def __init__(self, label: str = "circuit", nodes: Dict[str, NodeIR] = None, edges: list = None,
template: str = None, step_size_adaptation: bool = False, step_size: float = None,
verbose: bool = True, backend: str = None, scalar_shape: tuple = None, **kwargs):
"""
Parameters:
-----------
label
String label, could be used as fallback when subcircuiting this circuit. Currently not used, though.
nodes
Dictionary of nodes of form {node_label: `NodeIR` instance}.
edges
List of tuples (source:str, target:str, edge_dict). `edge_dict` should contain the key "edge_ir" with an
`EdgeIR` instance as item and optionally entries for "weight" and "delay". `source` and `target` should be
formatted as "node/op/var" (with optionally prepended circuits).
template
optional string reference to path to template that this circuit was loaded from. Leave empty, if no template
was used.
"""
# filter displayed warnings
filterwarnings("ignore", category=FutureWarning)
# set main attributes
super().__init__(label=label, template=template)
self._verbose = verbose
self._front_to_back = dict()
self._dt = step_size
self._dt_adapt = step_size_adaptation
self._def_shape = (1,) if scalar_shape is None else scalar_shape
# translate the network into a networkx graph
net = NetworkGraph(nodes=nodes, edges=edges, label=label, step_size=step_size,
step_size_adaptation=step_size_adaptation, verbose=verbose, **kwargs)
# parse network equations into a compute graph
if verbose:
print("\t(3) Parsing the model equations into a compute graph...")
self.graph = self.network_to_computegraph(graph=net, backend=backend, **kwargs)
# Reject combinations the chosen backend can't support. Doing this
# here (after the compute graph exists, so `self.graph.backend` is
# populated, but before the user gets a callable) gives a single
# actionable error instead of a downstream JIT failure.
if getattr(net, '_uses_edge_delay_buffer', False):
be = self.graph.backend
if not getattr(be, 'SUPPORTS_EDGE_DELAY_BUFFER', True):
raise NotImplementedError(
f"Backend `{type(be).__name__}` does not support the "
"discrete edge-delay ring-buffer path emitted by PyRates "
"for fixed-step solvers with `delay=...` edges. The "
"buffer is updated in place via `buf[:] = roll(buf, 1)`, "
"which is incompatible with the backend's immutable "
"arrays — the delayed value would never accumulate. "
"Use an adaptive solver (e.g. solver='scipy' or "
"solver='diffrax') so PyRates emits the DDEHistory path "
"instead, or rewrite the edge with the explicit "
"`past(x, tau)` notation."
)
if verbose:
print("\t\t...finished.")
print("\tModel compilation was finished.")
[docs] def get_var(self, var: str, get_key: bool = False) -> Union[str, ComputeVar]:
"""Extracts variable from the backend (i.e. the `ComputeGraph` instance).
Parameters
----------
var
Name of the variable.
get_key
If true, the backend variable name will be returned
Returns
-------
Union[str, ComputeVar]
Either the backend variable or its name.
"""
try:
v = self[var]
except KeyError:
v = self._front_to_back[var]
return v.name if get_key else v
[docs] def get_frontend_varname(self, var: str) -> str:
"""Returns the original frontend variable name given the backend variable name `var`.
Parameters
----------
var
Name of the backend variable.
Returns
-------
str
Name of the frontend variable
"""
v = self.get_var(var)
front_vars = list(self._front_to_back.keys())
back_vars = list(self._front_to_back.values())
idx = back_vars.index(v)
return front_vars[idx]
[docs] def run(self,
simulation_time: float,
outputs: Optional[dict] = None,
sampling_step_size: Optional[float] = None,
solver: str = 'euler',
**kwargs
) -> dict:
"""Simulate the backend behavior over time.
Parameters
----------
simulation_time
Simulation time in seconds.
outputs
Output variables that will be returned. Each key is the desired name of an output variable and each value is
a string that specifies a variable in the graph in the same format as used for the input definition:
'node_name/op_name/var_name'.
sampling_step_size
Time in seconds between sampling points of the output variables.
solver
Numerical solving scheme to use for differential equations. Currently supported ODE solving schemes:
- 'euler' for the explicit Euler method
- 'scipy' for integration via the `scipy.integrate.solve_ivp` method.
kwargs
Keyword arguments that are passed on to the chosen solver.
Returns
-------
dict
Output variables in a dictionary.
"""
filterwarnings("ignore", category=FutureWarning)
# collect backend variables and functions
#########################################
if self._verbose:
print("Simulation Progress")
print("-------------------")
print("\t (1) Generating the network run function...")
# generate run function
func_name = kwargs.pop('func_name', 'vector_field')
func, func_args, _, _ = self.get_run_func(func_name, **kwargs)
# extract backend variables that correspond to requested output variables
if self._verbose:
print("\t (2) Processing output variables...")
outputs_col = {}
if outputs:
for key, val in outputs.items():
outputs_col[key] = self.get_var(val, get_key=True)
if self._verbose:
print("\t\t...finished.")
# perform simulation
####################
if self._verbose:
print("\t (3) Running the simulation...")
t0 = time.perf_counter()
# call backend run function
results = self.graph.run(func=func, func_args=func_args, T=simulation_time, dt=self._dt, dts=sampling_step_size,
outputs=outputs_col, solver=solver, **kwargs)
if self._verbose:
t1 = time.perf_counter()
print(f"\t\t...finished after {t1-t0}s.")
# return simulation results if outputs have been passed
if outputs_col:
return results
# else, find the frontend variable names of the returned results and create a new results dict to return
for key in results.copy():
front_key = self.get_frontend_varname(key)
results[front_key] = results.pop(key)
return results
[docs] def get_run_func(self, func_name: str, file_name: Optional[str] = None, **kwargs) -> tuple:
if not file_name:
file_name = f"pyrates_func"
return self.graph.to_func(func_name=func_name, file_name=file_name, dt_adapt=self._dt_adapt,
dt=self._dt, **kwargs)
[docs] def get_jacobian_func(self, func_name: str, file_name: Optional[str] = None, **kwargs) -> tuple:
if not file_name:
file_name = "pyrates_func"
return self.graph.get_jacobian_func(func_name=func_name, file_name=file_name,
dt_adapt=self._dt_adapt, dt=self._dt, **kwargs)
[docs] def network_to_computegraph(self, graph: NetworkGraph, inplace_vectorfield: bool = True, **kwargs):
# initialize compute graph
cg = ComputeGraph(**kwargs) if inplace_vectorfield else ComputeGraphBackProp(**kwargs)
# add global time variable to compute graph
cg.add_var(label="t", vtype="state_var", value=0.0 if self._dt_adapt else 0, shape=(),
dtype='float' if self._dt_adapt else 'int')
# node operators
parsing_kwargs = ['parsing_method', 'vectorized']
parsing_kwargs = {key: kwargs.pop(key) for key in parsing_kwargs if key in kwargs}
self._parse_op_layers_into_computegraph(graph, cg, layers=[], exclude=True, op_identifier="edge_from_",
**parsing_kwargs)
# edge operators
self._parse_op_layers_into_computegraph(graph, cg, layers=[0], exclude=False, op_identifier="edge_from_",
**parsing_kwargs)
return cg
[docs] def getitem_from_iterator(self, key: str, key_iter: Iterator[str]):
return self.graph.get_var(key)
[docs] def clear(self):
"""Clears the backend graph from all operations and variables.
"""
self._front_to_back.clear()
self.graph.clear()
in_edge_indices.clear()
in_edge_vars.clear()
def _parse_op_layers_into_computegraph(self, net: NetworkGraph, cg: ComputeGraph, layers: list,
exclude: bool = False, op_identifier: Optional[str] = None,
vectorized: bool = False, **kwargs) -> None:
"""
Parameters
----------
layers
exclude
op_identifier
kwargs
Returns
-------
"""
for node_name, node in net.nodes.items():
op_graph = node['node'].op_graph
g = op_graph.copy() # type: DiGraph
# go through all operators on node and pre-process + extract equations and variables
i = 0
while g.nodes:
# get all operators that have no dependencies on other operators
# noinspection PyTypeChecker
ops = [op for op, in_degree in g.in_degree if in_degree == 0]
if (i in layers and not exclude) or (i not in layers and exclude):
# collect operator variables and equations from node
if op_identifier:
ops_tmp = [op for op in ops if op_identifier not in op] if exclude else \
[op for op in ops if op_identifier in op]
else:
ops_tmp = ops
op_eqs, op_vars = self._collect_ops(ops_tmp, node_name=node_name, graph=net, compute_graph=cg,
reduce=exclude, vectorized=vectorized)
# parse equations and variables into computegraph
variables = parse_equations(op_eqs, op_vars, cg=cg, def_shape=self._def_shape, **kwargs)
# remember mapping between frontend variable names and node keys in compute graph
for key, var in variables.items():
if key.split('/')[-1] != 'inputs' and isinstance(var, ComputeVar):
self._front_to_back[key] = var
# remove parsed operators from graph
g.remove_nodes_from(ops)
i += 1
def _collect_ops(self, ops: List[str], node_name: str, graph: NetworkGraph, compute_graph: ComputeGraph,
reduce: bool, vectorized: bool) -> tuple:
"""Adds a number of operations to the backend graph.
Parameters
----------
ops
Names of the operators that should be parsed into the graph.
node_name
Name of the node that the operators belong to.
graph
compute_graph
reduce
vectorized
Returns
-------
tuple
Collected and updated operator equations and variables
"""
# set up update operation collector variable
equations = []
variables = {}
# add operations of same hierarchical lvl to compute graph
############################################################
for op_name in ops:
# retrieve operator and operator args
scope = f"{node_name}/{op_name}"
op_info = graph[f"{node_name}/{op_name}"]
op_args = deepcopy(op_info['variables'])
op_args['inputs'] = {}
# handle operator inputs
in_ops = {}
for var_name, inp in op_info['inputs'].items():
# go through inputs to variable
if inp['sources']:
in_ops_col = {}
in_node = inp['node'] if 'node' in inp else node_name
in_var_tmp = inp.pop('var', None)
for i, in_op in enumerate(inp['sources']):
# collect single input to op
in_var = in_var_tmp if in_var_tmp else graph[f"{in_node}/{in_op}"]['output']
in_key = f"{in_node}/{in_op}/{in_var}"
try:
in_val = self._front_to_back[in_key]
except KeyError:
in_val = graph[in_key]
in_ops_col[in_key] = in_val
# if multiple inputs to variable, sum them up
if len(in_ops_col) > 1:
in_ops[var_name] = self._map_multiple_inputs(in_ops_col, scope=scope)
else:
key, _ = in_ops_col.popitem()
in_ops[var_name] = (None, {var_name: key})
# replace input variables with input in operator equations
for var, (inp_term, inp) in in_ops.items():
if inp_term:
op_info['equations'] = [replace(eq, var_name, inp_term) for eq in op_info['equations']]
op_args['inputs'].update(inp)
# collect operator variables and equations
variables[f"{scope}/inputs"] = {}
equations += [(eq, scope) for eq in op_info['equations']]
for key, var in op_args.items():
full_key = f"{scope}/{key}"
# case I: global time variable
if key == "t":
variables[full_key] = compute_graph.get_var('t')
# case II: input variables
elif key == 'inputs' and var:
variables[f"{scope}/inputs"].update(var)
for in_var in var.values():
try:
variables[in_var] = self._front_to_back[in_var]
except KeyError:
variables[in_var] = self._finalize_var_def(graph[in_var], reduce, vectorized)
else:
try:
# case III: variables that have already been processed
variables[full_key] = self._front_to_back[full_key]
except KeyError:
# case IV: new variables
variables[full_key] = self._finalize_var_def(var, reduce, vectorized)
return equations, variables
@staticmethod
def _finalize_var_def(v: dict, reduce: bool, vectorized: bool):
if not reduce:
return v
if not v:
return v
if v['vtype'] != "constant" and vectorized:
return v
if v['dtype'] != 'float':
return v
if 'shape' in v and len(v['shape']) > 1:
return v
if 'shape' in v and np.prod(v['shape']) > 1:
return v
if len(np.unique(v['value'])) > 1:
return v
try:
v['value'] = v['value'][0]
v['shape'] = tuple()
except (TypeError, IndexError):
pass
return v
@staticmethod
def _map_multiple_inputs(inputs: dict, scope: str) -> tuple:
"""Creates mapping between multiple input variables and a single output variable.
Parameters
----------
inputs
Input variables.
scope
Scope of the input variables
Returns
-------
tuple
Equation that sums up all input variables and the mapping to the respective input variables
"""
# preparations
if scope not in in_edge_vars:
in_edge_vars[scope] = {}
inputs_unique = in_edge_vars[scope]
input_mapping = {}
new_input_vars = []
# go through all inputs
for key, var in inputs.items():
# get a unique label for the input variable
try:
in_var = var.name
except AttributeError:
in_var = key.split('/')[-1]
inp, inputs_unique_tmp = get_unique_label(in_var, inputs_unique)
inputs_unique.update(inputs_unique_tmp)
# store input-related information
new_input_vars.append(inp)
input_mapping[inp] = key
# collect input into single variable
input_term = f"({'+'.join(new_input_vars)})"
return input_term, input_mapping
@property
def nodes(self):
"""Shortcut to self.graph.nodes. See documentation of `networkx.MultiDiGraph.nodes`."""
return self.graph.nodes
@property
def edges(self):
"""Shortcut to self.graph.edges. See documentation of `networkx.MultiDiGraph.edges`."""
return self.graph.edges
def _get_indexed_var_str(var: str, idx: Union[tuple, str, list], var_length: int = None, reduce: bool = False,
idx_str: str = None, arg_dict: dict = None):
if type(idx) is tuple:
if var_length and int(idx[1]) - int(idx[0]) == var_length:
return var
elif not var_length:
var = f"reshape({var}, 1)"
return f"index_range({var}, {idx[0]}, {idx[1]})"
if type(idx) is str and len(idx) > 0:
if not var_length:
var = f"reshape({var}, 1)"
return f"index({var}, {idx})"
if len(idx) > 0:
if len(idx) == var_length:
identical = True
for i1, i2 in zip(idx, list(np.arange(0, var_length))):
if i1 != i2:
identical = False
break
if identical:
return var
if idx_str:
arg_dict[idx_str] = {'vtype': 'constant', 'value': idx, 'dtype': 'int', 'shape': (len(idx),)}
return f"index({var}, {idx_str})"
return f"index({var}, {idx})"
if reduce:
return f"index({var}, {idx[0]})"
return var
def _get_source_str():
pass
def _get_target_str():
pass