Source code for pyrates.backend.computegraph

# -*- coding: utf-8 -*-
#
#
# PyRates software framework for flexible implementation of neural
# network model_templates and simulations. See also:
# https://github.com/pyrates-neuroscience/PyRates
#
# Copyright (C) 2017-2018 the original authors (Richard Gast and
# Daniel Rose), the Max-Planck-Institute for Human Cognitive Brain
# Sciences ("MPI CBS") and contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>
#
# CITATION:
#
# Richard Gast and Daniel Rose et. al. in preparation

"""This module provides the backend class that should be used to set up any backend in pyrates.
"""

# external _imports
from typing import Any, Callable, Union, Iterable, Optional
from networkx import MultiDiGraph
from sympy import Symbol, Expr, Function, lambdify
import numpy as np

# meta infos
__author__ = "Richard Gast"
__status__ = "development"


#########################
# compute graph classes #
#########################


# numpy-based node class
[docs]class ComputeNode: """Base class for adding variables to the compute graph. Creates a numpy array with additional attributes for variable identification/retrieval from graph. Should be used as parent class for custom variable classes. Parameters ---------- name Full name of the variable in the original NetworkGraph (including the node and operator it belongs to). dtype Data-type of the variable. For valid data-types, check the documentation of the backend in use. shape Shape of the variable. """ __slots__ = ["name", "symbol", "dtype", "shape", "_value"] def __init__(self, name: str, symbol: Union[Symbol, Expr, Function], dtype: Optional[str] = None, shape: tuple = (), def_shape: tuple = ()): """Instantiates a basic node of a ComputeGraph instance. """ self.name = name self.symbol = symbol self.shape = self._get_shape(shape, def_shape) self._value = np.zeros(self.shape) self.dtype = dtype self.set_dtype()
[docs] def reshape(self, shape: tuple, **kwargs): self._value = self.value.reshape(shape, **kwargs) self.shape = shape return self
[docs] def squeeze(self, axis=None): self._value = self.value.squeeze(axis=axis) self.shape = self._value.shape return self
[docs] def set_value(self, v: Union[float, np.ndarray]): self._value = np.asarray(v, dtype=self.dtype) self.shape = tuple(v.shape)
@property def value(self): """Returns current value of BaseVar. """ return self._value @property def is_constant(self): raise NotImplementedError("This method has to be defined by each child class.") @property def is_float(self): return "float" in self.dtype @property def is_complex(self): return "complex" in self.dtype def _is_equal_to(self, v): for attr in self.__slots__: if not hasattr(v, attr) or getattr(v, attr) != getattr(self, attr): return False return True def _get_value(self, value: Optional[Union[list, np.ndarray]] = None, dtype: Optional[str] = None, shape: tuple = ()): """Defines initial value of variable. """ # case I: create new array from shape and dtype if value is None: return np.zeros(shape=shape, dtype=dtype) # case II: transform values into an array if not hasattr(value, 'shape'): if type(value) is list: return self._get_value(value=np.asarray(value, dtype=dtype), dtype=dtype, shape=shape) return np.zeros(shape=shape, dtype=dtype) + value # case III: match given shape with the shape of the given value array if len(shape) > 0: value = np.asarray(value, dtype=dtype) if value.shape == shape: return value if sum(shape) < sum(value.shape): return value.squeeze() idx = ",".join("None" if s == 1 else ":" for s in shape) return eval(f'value[{idx}]') # case IV: just ensure the correct data type of the value array return np.asarray(value, dtype=dtype)
[docs] def set_dtype(self, dtype: str = None): if dtype is None: if not self.dtype: if 'float' in str(self.value.dtype): self.dtype = 'float' elif 'complex' in str(self.value.dtype): self.dtype = 'complex' else: self.dtype = 'int' else: self.dtype = dtype
@staticmethod def _get_shape(s: tuple, s_def: tuple): if sum(s) <= 1: return s_def return s def __deepcopy__(self, memodict: dict): node = ComputeNode(name=self.name, symbol=self.symbol, dtype=self.dtype, shape=self.shape) node._value = np.zeros_like(node._value) + node._value return node def __str__(self): return self.name def __hash__(self): return hash(str(self))
[docs]class ComputeVar(ComputeNode): """Class for variables and vector-valued constants in the ComputeGraph. """ __slots__ = ComputeNode.__slots__ + ["vtype"] def __init__(self, name: str, symbol: Union[Symbol, Expr, Function], vtype: str, dtype: Optional[str] = None, shape: tuple = (), value: Optional[Union[list, np.ndarray]] = None, def_shape: tuple = ()): # set attributes super().__init__(name=name, symbol=symbol, dtype=dtype, shape=shape, def_shape=def_shape) self.vtype = vtype # adjust variable value self.set_value(self._get_value(value=value, shape=self.shape, dtype=self.dtype)) @property def is_constant(self): return self.vtype == 'constant'
[docs]class ComputeOp(ComputeNode): """Class for ComputeGraph nodes that represent mathematical operations. """ __slots__ = ComputeNode.__slots__ + ["func", "expr", "func_args", "backend_funcs"] def __init__(self, name: str, symbol: Union[Symbol, Expr, Function], expr: Expr, func: Optional[Callable] = None, func_args: Optional[list] = None, backend_funcs: Optional[dict] = None, dtype: Optional[str] = None, shape: tuple = ()): # set attributes super().__init__(name=name, symbol=symbol, dtype=dtype, shape=shape) self.func = func self.expr = expr self.func_args = func_args if func_args is not None else [] self.backend_funcs = backend_funcs if backend_funcs is not None else {}
[docs] def get_func(self) -> Callable: if self.func is None: self.func = lambdify(self.func_args, expr=self.expr, modules=[self.backend_funcs, "numpy"]) return self.func
@property def is_constant(self): return False
# networkx-based graph class
[docs]class ComputeGraph(MultiDiGraph): """Creates a compute graph where nodes are all constants and variables of the network and edges are the mathematical operations linking those variables/constants together to form equations. """ def __init__(self, backend: str, **kwargs): super().__init__() # choose a backend if backend == 'torch': from pyrates.backend.torch import TorchBackend backend = TorchBackend elif backend == 'jax': from pyrates.backend.jax import JaxBackend backend = JaxBackend elif backend == 'fortran': from pyrates.backend.fortran import FortranBackend backend = FortranBackend elif backend == 'julia': from pyrates.backend.julia import JuliaBackend backend = JuliaBackend elif backend == 'matlab': from pyrates.backend.matlab import MatlabBackend backend = MatlabBackend else: from pyrates.backend.base import BaseBackend backend = BaseBackend # backend-related attributes self.backend = backend(**kwargs) self.var_updates = {'DEs': dict(), 'non-DEs': dict()} self._eq_nodes = [] self._state_var_indices = dict() self._state_var_hist = dict() self._node_names = {} @property def state_vars(self): return list(self.var_updates['DEs'].keys())
[docs] def add_var(self, label: str, value: Any, vtype: str, **kwargs): unique_label = self._generate_unique_label(label) var = ComputeVar(name=unique_label, symbol=Symbol(unique_label), value=value, vtype=vtype, **kwargs) super().add_node(unique_label, node=var) return unique_label, self.nodes[unique_label]['node']
[docs] def add_op(self, inputs: Union[list, tuple], label: str, expr: Expr, func: Optional[Callable] = None, func_args: Optional[list] = None, backend_funcs: Optional[dict] = None, **kwargs): # add target node that contains result of operation unique_label = self._generate_unique_label(label) op = ComputeOp(name=unique_label, symbol=Symbol(unique_label), expr=expr, func=func, func_args=func_args, backend_funcs=backend_funcs, **kwargs) super().add_node(unique_label, node=op) # add edges from source nodes to target node for i, v in enumerate(inputs): super().add_edge(v, unique_label, key=i) return unique_label, self.nodes[unique_label]['node']
[docs] def add_var_update(self, var: str, update: str, differential_equation: bool = False): # store mapping between left-hand side variable and right-hand side update if differential_equation: self.var_updates['DEs'][var] = update else: self.var_updates['non-DEs'][var] = update # remember var and update node to ensure that they are not pruned during compilation self._eq_nodes.extend([var, update])
[docs] def get_var(self, var: str, from_backend: bool = False): v = self.nodes[var]['node'] if from_backend: return self.backend.get_var(v) return v
[docs] def get_op(self, op: str, **kwargs) -> dict: return self.backend.get_op(op, **kwargs)
[docs] def eval_graph(self): for n in self.var_updates['non-DEs'].values(): self.eval_subgraph(n) return self.eval_nodes(self.var_updates['DEs'].values())
[docs] def eval_nodes(self, nodes: Iterable): return [self.eval_node(n) for n in nodes]
[docs] def eval_node(self, n): inputs = tuple([self.eval_node(inp) for inp in self.predecessors(n)]) node = self.get_var(n) if isinstance(node, ComputeOp): return node.get_func()(*inputs) return node.value
[docs] def eval_subgraph(self, n): inputs = [] input_nodes = [node for node in self.predecessors(n)] for inp in input_nodes: inputs.append(self.eval_subgraph(inp)) self.remove_node(inp) node = self.get_var(n) if inputs: node.set_value(node.get_func()(*tuple(inputs))) return node.value
[docs] def remove_subgraph(self, n): for inp in self.predecessors(n): self.remove_subgraph(inp) self.remove_node(n)
[docs] def compile(self): # evaluate constant-based operations out_nodes = [node for node, out_degree in self.out_degree if out_degree == 0] for node in out_nodes: # process inputs of node for inp in self.predecessors(node): if self.get_var(inp).is_constant: self.eval_subgraph(inp) # evaluate node if all its inputs are constants if all([self.get_var(inp).is_constant for inp in self.predecessors(node)]) and node not in self._eq_nodes: self.eval_subgraph(node) # remove unconnected nodes and constants from graph self._prune() return self
[docs] def to_func(self, func_name: str, to_file: bool = True, dt_adapt: bool = True, dt: float = None, **kwargs): # finalize compute graph self.compile() # create state variable vector and state variable update vector ############################################################### variables = [] idx = 0 for var, update in self.var_updates['DEs'].items(): # extract left-hand side nodes from graph lhs, rhs = self._process_var_update(var, update) variables.append(lhs.value) # store information of the original, non-vectorized state variable vshape = sum(lhs.shape) if vshape > 1: self._state_var_indices[var] = (idx, idx+vshape) idx += vshape else: self._state_var_indices[var] = idx idx += 1 # add collected state variables to the backend try: state_vec = np.concatenate(variables, axis=0) except ValueError: try: state_vec = np.asarray(variables) except ValueError: state_vec = np.asarray([np.squeeze(v) for v in variables]) dtype = 'complex' if 'complex' in state_vec.dtype.name else 'float' state_var_key, y = self.add_var(label='y', vtype='state_var', value=state_vec, dtype=dtype) rhs_var_key = self._generate_vecfield_var(state_vec, dtype) try: t = self.get_var('t') except KeyError: _, t = self.add_var(label='t', vtype='state_var', value=0.0 if dt_adapt else 0, dtype='float' if dt_adapt else 'int', shape=()) self.backend.register_vars([t, y]) # When the backend's auto-07p path is active, compute symbolic ∂F/∂U # and ∂F/∂PAR *before* ``_to_str`` runs — ``_to_str`` consumes the # ``var_updates['non-DEs']`` dict during code generation, after which # ``_get_symbolic_rhs`` can no longer expand auxiliary variables (and # the parameter dependency they introduce, e.g. ``weight`` via # ``r_in = r*weight``, is missing from the symbolic vector field). # Skipping DDE models: auto-07p doesn't continue them natively. if kwargs.get('auto', False) and kwargs.get('auto_jac', True): try: jac_data = self._compute_symbolic_jacobian(include_dfdp=True) if not jac_data['is_dde']: kwargs['auto_jacobian'] = jac_data except Exception: # pragma: no cover pass # create a string containing all computations and variable updates represented by the compute graph func_args, code_gen = self._to_str() func_body = code_gen.generate() code_gen.code.clear() # generate function head add_hist_calls = self._state_var_hist and code_gen.add_hist_arg func_args = code_gen.generate_func_head(func_name=func_name, state_var=state_var_key, return_var=rhs_var_key, func_args=[self.get_var(arg) for arg in func_args], add_hist_func=add_hist_calls) # extract state variable histories for delayed interactions code_gen.add_linebreak() for var, delays in self._state_var_hist.items(): # extract index of variable in state vector idx = self._state_var_indices[var] # extract state variable history from backend-specific buffer if type(idx) is not int: if idx[1]-idx[0] < 2: idx = idx[0] for delay, v_hist in delays.items(): # type: ComputeVar, str code_gen.add_var_hist(lhs=v_hist, delay=delay, state_idx=idx, var=var, dt=dt, dt_adapt=dt_adapt) # add lines from function body after function head code_gen.add_linebreak() code_gen.add_code_line(func_body) code_gen.add_linebreak() # generate function tail self._generate_func_tail(code_gen, rhs_var_key) # generate the function (and write to file, optionally) func_args_tmp = func_args[4:] if add_hist_calls else func_args[3:] func = code_gen.generate_func(func_name=func_name, to_file=to_file, func_args=func_args_tmp, state_vars=self.state_vars, **kwargs) # OPTIONAL: write function arguments (state vectors and constants) to file c_fn = kwargs.pop('constants_file_name', None) if c_fn: arg_dict = {arg: self.get_var(arg).value for arg in func_args if arg != 'hist'} if code_gen.lags: arg_dict['lags'] = list(code_gen.lags.keys()) fn = f'{self.backend.fdir}/{c_fn}' if self.backend.fdir else c_fn code_gen.to_file(fn, **arg_dict) # finalize the function arguments fargs = [] for arg in func_args: if arg == 'hist': if 'hist' in kwargs: arg = kwargs.pop('hist') else: y_init = np.asarray(state_vec[:]) arg = code_gen.get_hist_func(y_init) fargs.append(arg) else: fargs.append(self.get_var(arg, from_backend=True)) return func, tuple(fargs), tuple(func_args), self._state_var_indices.copy()
[docs] def run(self, func: Callable, func_args: tuple, T: float, dt: float, dts: Optional[float] = None, outputs: Optional[dict] = None, **kwargs) -> dict: # pre-process outputs if outputs is None: outputs = {key: key for key in self.state_vars} for key in outputs.copy(): var = outputs.pop(key) outputs[key] = self._state_var_indices[var] # handle other arguments if dts is None: dts = dt solver = kwargs.pop('solver', 'euler') # call backend method results, times = self.backend.run(func=func, func_args=func_args, T=T, dt=dt, dts=dts, solver=solver, **kwargs) # set state variables to final simulated value for key in self.state_vars: var = self.get_var(key) idx = self._state_var_indices[key] var.set_value(np.reshape(self._index_state_var(results, idx)[-1], var.shape)) # reduce state recordings to requested state variables for key, idx in outputs.items(): outputs[key] = self._index_state_var(results, idx) outputs['time'] = times return outputs
[docs] def get_jacobian_func(self, func_name: str, to_file: bool = True, sparse: bool = False, dt_adapt: bool = True, dt: float = None, **kwargs) -> tuple: """Generate a function that evaluates the Jacobian of the vector field. For ODE systems the generated function has signature ``J(t, y, *params) -> ndarray (n, n)``. For DDE systems it returns ``(J0, [J_tau1, J_tau2, ...])`` where ``J0`` is the instantaneous Jacobian and ``J_tauK`` is the partial derivative with respect to ``y(t - tau_k)``. If ``sparse=True`` each matrix is a ``scipy.sparse.csr_matrix`` instead. The Jacobian entries are computed symbolically via ``sympy.diff`` using the reconstructed vector-field expressions from the compute graph. Vector-valued state variables fall back to zero blocks (a ``UserWarning`` lists the affected indices). Parameters ---------- func_name Name of the generated function. to_file Write source to a file (same semantics as ``to_func``). sparse Return ``scipy.sparse.csr_matrix`` instead of dense ``ndarray``. dt_adapt Whether the circuit uses adaptive time-stepping (controls ``t`` dtype). dt Fixed time-step (passed through to backend). Returns ------- tuple ``(func, args, arg_names, state_var_indices)`` — same structure as ``to_func``. """ import sympy as sp from warnings import warn # ── 1. Finalise graph and build state-vector structure (mirrors to_func) ── self.compile() variables, idx = [], 0 for var, update in self.var_updates['DEs'].items(): lhs, _ = self._process_var_update(var, update) variables.append(lhs.value) vshape = sum(lhs.shape) if vshape > 1: self._state_var_indices[var] = (idx, idx + vshape) idx += vshape else: self._state_var_indices[var] = idx idx += 1 n = idx try: state_vec = np.concatenate(variables, axis=0) except ValueError: try: state_vec = np.asarray(variables) except ValueError: state_vec = np.asarray([np.squeeze(v) for v in variables]) dtype = 'complex' if 'complex' in state_vec.dtype.name else 'float' state_var_key, y_var = self.add_var(label='y', vtype='state_var', value=state_vec, dtype=dtype) try: t_var = self.get_var('t') except KeyError: _, t_var = self.add_var(label='t', vtype='state_var', value=0.0 if dt_adapt else 0, dtype='float' if dt_adapt else 'int', shape=()) self.backend.register_vars([t_var, y_var]) # ── 2. Rebuild symbolic vector field ── f_exprs, y_syms, past_map, func_args, var_is_vector = self._get_symbolic_rhs() # map state sym → y-index sym_to_y_idx = {sym: self._state_var_indices[var] for var, sym in zip(self.var_updates['DEs'].keys(), y_syms)} is_dde = bool(past_map) start = self.backend._start_idx # 0 for Python, 1 for Julia/MATLAB # group past symbols by delay string for history-Jacobian computation # delay_str → [(fresh_sym, var_sym, state_idx)] delay_groups: dict = {} for (var_sym, delay_sym), fresh_sym in past_map.items(): d_str = str(delay_sym) if d_str not in delay_groups: delay_groups[d_str] = [] for sym, vidx in sym_to_y_idx.items(): if sym == var_sym: delay_groups[d_str].append((fresh_sym, var_sym, vidx)) break # fresh past sym → code string used in generated Jacobian expressions past_sym_to_str: dict = {} for d_str, group in delay_groups.items(): d_safe = d_str.replace('.', 'p').replace('-', 'm') for fresh_sym, _, vidx in group: if isinstance(vidx, tuple): past_sym_to_str[fresh_sym] = f'_yhist_{d_safe}[{vidx[0]+start}:{vidx[1]}]' else: past_sym_to_str[fresh_sym] = f'_yhist_{d_safe}[{vidx+start}]' # ── 3. Compute symbolic Jacobian entries via sympy.diff ── J0_entries: dict = {} # (i_row, j_col) → sympy Expr J_hist: dict = {} # delay_str → {(i_row, j_col) → sympy Expr} numerical_blocks = [] for d_str in delay_groups: J_hist[d_str] = {} i_row = 0 for f_i, yi_sym, fi_is_vec in zip(f_exprs, y_syms, var_is_vector): fi_idx = sym_to_y_idx[yi_sym] fi_nrows = (fi_idx[1] - fi_idx[0]) if isinstance(fi_idx, tuple) else 1 j_col = 0 for yj_sym, fj_is_vec in zip(y_syms, var_is_vector): fj_idx = sym_to_y_idx[yj_sym] fj_ncols = (fj_idx[1] - fj_idx[0]) if isinstance(fj_idx, tuple) else 1 if fi_is_vec or fj_is_vec: numerical_blocks.append((i_row, i_row + fi_nrows, j_col, j_col + fj_ncols)) else: d = sp.diff(f_i, yj_sym) if d != 0: J0_entries[(i_row, j_col)] = d j_col += fj_ncols # history Jacobians for d_str, group in delay_groups.items(): j_col = 0 for fresh_sym, _, fj_idx in group: fj_ncols = (fj_idx[1] - fj_idx[0]) if isinstance(fj_idx, tuple) else 1 if not fi_is_vec and fj_ncols == 1: d = sp.diff(f_i, fresh_sym) if d != 0: J_hist[d_str][(i_row, j_col)] = d j_col += fj_ncols i_row += fi_nrows if numerical_blocks: warn( f"get_jacobian_func: {len(numerical_blocks)} vector-valued Jacobian block(s) set to zero " f"(analytical differentiation not supported for vector DEs). " f"Affected y-index ranges: {numerical_blocks}", UserWarning ) # ── 4. Generate code ── code_gen = self.backend code_gen.code.clear() # Imports that the Jacobian assembly emits. Must be declared BEFORE # generate_func_head, which materialises imports into the source file. code_gen.declare_local_array_imports() # backend-specific (numpy / jax.numpy / ...) if sparse: if not getattr(code_gen, 'SUPPORTS_SPARSE_JACOBIAN', True): raise NotImplementedError( f"Backend `{type(code_gen).__name__}` does not support " "sparse=True for get_jacobian_func. The scipy.sparse " "csr_matrix constructor cannot consume the backend's array " "type at code-gen time. Use sparse=False and convert " "post-hoc if needed." ) code_gen.add_import("from scipy.sparse import csr_matrix") # determine return-variable name(s) for MATLAB function signature if J_hist: d_safes = [d_str.replace('.', 'p').replace('-', 'm') for d_str in J_hist] jk_names = [f'J_hist_{s}' for s in d_safes] return_var_name = f'J0' # Julia/MATLAB named outputs handled below else: jk_names = [] return_var_name = 'J0' func_args_objects = [self.get_var(a) for a in func_args] all_arg_names = code_gen.generate_func_head( func_name=func_name, state_var=state_var_key, return_var=return_var_name, func_args=func_args_objects, add_hist_func=is_dde and code_gen.add_hist_arg, ) # past-state extraction for DDE if is_dde: code_gen.add_linebreak() for d_str, group in delay_groups.items(): d_safe = d_str.replace('.', 'p').replace('-', 'm') code_gen.add_code_line(f"_yhist_{d_safe} = hist(t - {d_str})") # allocate instantaneous Jacobian (backend-aware: emits numpy `zeros` # for the default path, jnp.zeros for JaxBackend, etc.) code_gen.add_linebreak() dtype_str = code_gen._float_precision code_gen.emit_local_array_alloc('J0', (n, n), dtype_str) # fill non-zero J0 entries for (i_r, j_c), d_expr in sorted(J0_entries.items()): d_str_code = self._expr_to_jac_str(d_expr, sym_to_y_idx, {}) if d_str_code is None: code_gen.add_code_line( f"# WARNING: could not differentiate J0[{i_r},{j_c}] analytically — entry left as 0") else: code_gen.emit_local_array_assign('J0', (i_r + start, j_c + start), d_str_code) # history Jacobians code_gen.add_linebreak() for d_str, entries in J_hist.items(): d_safe = d_str.replace('.', 'p').replace('-', 'm') jk = f'J_hist_{d_safe}' code_gen.emit_local_array_alloc(jk, (n, n), dtype_str) for (i_r, j_c), d_expr in sorted(entries.items()): d_str_code = self._expr_to_jac_str(d_expr, sym_to_y_idx, past_sym_to_str) if d_str_code is None: code_gen.add_code_line( f"# WARNING: could not differentiate {jk}[{i_r},{j_c}] analytically — entry left as 0") else: code_gen.emit_local_array_assign(jk, (i_r + start, j_c + start), d_str_code) code_gen.add_linebreak() # return statement (replaces generate_func_tail so we control the return value) if sparse: j0_ret = "csr_matrix(J0)" jk_rets = [f"csr_matrix({jk})" for jk in jk_names] else: j0_ret = "J0" jk_rets = jk_names if jk_names: ret_str = f"{j0_ret}, [{', '.join(jk_rets)}]" else: ret_str = j0_ret code_gen.generate_func_tail(rhs_var=ret_str) # compile / write to file param_arg_names = [a for a in all_arg_names if a not in ('t', state_var_key, 'hist')] func = code_gen.generate_func( func_name=func_name, to_file=to_file, func_args=param_arg_names, state_vars=self.state_vars, **kwargs ) # assemble actual argument values fargs = [0.0, state_vec.copy()] if is_dde and code_gen.add_hist_arg: fargs.append(code_gen.get_hist_func(state_vec.copy())) for a in all_arg_names: if a not in ('t', state_var_key, 'hist'): fargs.append(self.get_var(a, from_backend=True)) return func, tuple(fargs), tuple(all_arg_names), self._state_var_indices.copy()
def _get_symbolic_rhs(self) -> tuple: """Reconstruct the full symbolic vector field from the compute graph. Must be called after ``compile()``. Returns ------- f_exprs : list[sympy.Expr] y_syms : list[sympy.Symbol] past_map : dict ``{(var_sym, delay_sym): fresh_sym}`` func_args : list[str] constant node names (function parameters) var_is_vector : list[bool] """ import sympy as sp # Build symbolic expressions for all non-DE (algebraic) variables so that # delayed inputs routed via buffer operators can be expanded into the DE # expressions and their past() calls can be detected. non_de_exprs: dict = {} non_de_args: list = [] for var, update in self.var_updates['non-DEs'].items(): nde_var = self.get_var(var) try: args, expr = self._node_to_expr(update) except Exception: continue non_de_exprs[nde_var.symbol] = expr non_de_args.extend(args) def _expand_non_de(expr): """Substitute non-DE symbols with their full expressions.""" changed = True while changed: changed = False subs = {} for sym in expr.free_symbols: if sym in non_de_exprs: subs[sym] = non_de_exprs[sym] if subs: new_expr = expr.subs(subs) if new_expr != expr: expr = new_expr changed = True return expr all_func_args = list(non_de_args) f_exprs, y_syms, past_map, var_is_vector = [], [], {}, [] for var, update in self.var_updates['DEs'].items(): lhs_node = self.get_var(var) args, expr = self._node_to_expr(update) all_func_args.extend(args) # expand non-DE symbols to reveal any past() calls expr = _expand_non_de(expr) expr, new_past = self._extract_past_terms(expr) past_map.update(new_past) f_exprs.append(expr) y_syms.append(lhs_node.symbol) var_is_vector.append(sum(lhs_node.shape) > 1) seen, ordered = set(), [] for a in all_func_args: if a not in seen: seen.add(a) ordered.append(a) return f_exprs, y_syms, past_map, ordered, var_is_vector def _compute_symbolic_jacobian(self, include_dfdp: bool = True) -> dict: """Compute symbolic ∂F/∂U and (optionally) ∂F/∂PAR dictionaries. Shared helper used by both :meth:`get_jacobian_func` and the auto-07p path of the Fortran backend (which inlines the entries into the ``FUNC`` wrapper guarded by ``IJAC > 0``). Parameters ---------- include_dfdp When True, also compute ∂F/∂PAR for every non-state, non-time argument the vector field depends on (used by auto-07p so it can be ``IJAC = 2``-compatible). Returns ------- dict with the following keys: ``state_var_indices`` ``{state_var_name: int_or_(int,int)}``: position(s) of each state variable in the flat ``y`` vector. ``f_exprs`` ``list[sympy.Expr]``: one expression per ``DE`` update, in the order they appear in ``self.var_updates['DEs']``. ``y_syms`` ``list[sympy.Symbol]``: state-variable symbols, parallel to ``f_exprs`` (so ``y_syms[i]`` is the LHS symbol of ``f_exprs[i]``). ``var_is_vector`` ``list[bool]``: True for each entry that has shape > 1. ``dfdu`` ``dict[(i, j) -> sympy.Expr]``: non-zero ∂F_i/∂U_j entries, indices into the flat state vector. ``param_syms`` ``dict[name -> sympy.Symbol]``: parameter symbols PyRates exposed to the vector field, in the order they appear in the model (excluding state vars and ``t``). ``dfdp`` ``dict[(i, name) -> sympy.Expr]``: non-zero ∂F_i/∂PAR_name entries — populated only when ``include_dfdp`` is True. ``is_dde`` bool — True iff the vector field references past states (``past(var, tau)``). Auto-07p doesn't natively continue DDEs, so callers may want to skip the analytical Jacobian in that case. ``vector_blocks`` Blocks ``(i_lo, i_hi, j_lo, j_hi)`` of the Jacobian that couldn't be differentiated analytically because they involve a vector-valued state variable. Caller decides whether to fall back to FD or just warn. """ import sympy as sp # ``compile`` calls ``_prune`` which removes nodes the graph hasn't # connected yet — including the ``dy`` node that ``to_func`` has just # registered with ``_generate_vecfield_var``. Only compile if the # graph hasn't been processed yet. ``_eq_nodes`` is populated by # ``eval_subgraph`` (invoked from compile) so it's a reliable # "already compiled" marker. if not self._eq_nodes: self.compile() # build state-vector indices (matches get_jacobian_func / to_func) idx = 0 state_var_indices = {} for var in self.var_updates['DEs']: lhs = self.get_var(var) vshape = sum(lhs.shape) if vshape > 1: state_var_indices[var] = (idx, idx + vshape) idx += vshape else: state_var_indices[var] = idx idx += 1 # symbolic vector field f_exprs, y_syms, past_map, func_args, var_is_vector = self._get_symbolic_rhs() sym_to_y_idx = {sym: state_var_indices[var] for var, sym in zip(self.var_updates['DEs'].keys(), y_syms)} # ∂F/∂U dfdu: dict = {} vector_blocks: list = [] i_row = 0 for f_i, yi_sym, fi_is_vec in zip(f_exprs, y_syms, var_is_vector): fi_idx = sym_to_y_idx[yi_sym] fi_nrows = (fi_idx[1] - fi_idx[0]) if isinstance(fi_idx, tuple) else 1 j_col = 0 for yj_sym, fj_is_vec in zip(y_syms, var_is_vector): fj_idx = sym_to_y_idx[yj_sym] fj_ncols = (fj_idx[1] - fj_idx[0]) if isinstance(fj_idx, tuple) else 1 if fi_is_vec or fj_is_vec: vector_blocks.append((i_row, i_row + fi_nrows, j_col, j_col + fj_ncols)) else: d = sp.diff(f_i, yj_sym) d = self._resolve_derivatives(d) if d != 0: dfdu[(i_row, j_col)] = d j_col += fj_ncols i_row += fi_nrows # ∂F/∂PAR — for every argument that's neither a state variable nor `t` param_syms: dict = {} dfdp: dict = {} if include_dfdp: state_var_names = set(self.var_updates['DEs'].keys()) for arg_name in func_args: if arg_name in state_var_names or arg_name == 't': continue try: v = self.get_var(arg_name) except KeyError: continue # Only include true parameters (constants); skip dy buffers, hist callables, etc. if v.vtype != 'constant': continue param_syms[arg_name] = v.symbol i_row = 0 for f_i, yi_sym, fi_is_vec in zip(f_exprs, y_syms, var_is_vector): fi_idx = sym_to_y_idx[yi_sym] fi_nrows = (fi_idx[1] - fi_idx[0]) if isinstance(fi_idx, tuple) else 1 if not fi_is_vec: for name, psym in param_syms.items(): d = sp.diff(f_i, psym) d = self._resolve_derivatives(d) if d != 0: dfdp[(i_row, name)] = d i_row += fi_nrows return { 'state_var_indices': state_var_indices, 'f_exprs': f_exprs, 'y_syms': y_syms, 'var_is_vector': var_is_vector, 'sym_to_y_idx': sym_to_y_idx, 'dfdu': dfdu, 'param_syms': param_syms, 'dfdp': dfdp, 'is_dde': bool(past_map), 'vector_blocks': vector_blocks, } def _extract_past_terms(self, expr) -> tuple: """Replace every ``past(var, delay)`` call in *expr* with a fresh Symbol. Returns ``(new_expr, past_map)`` where ``past_map`` maps ``(var_sym, delay_sym) → fresh_sym``. """ from sympy import Symbol past_map: dict = {} def _visit(e): if not e.args: return e if e.func.__name__ == 'past' and len(e.args) == 2: key = (e.args[0], e.args[1]) if key not in past_map: safe = str(e.args[1]).replace('.', 'p').replace('-', 'm') past_map[key] = Symbol(f'_past_{e.args[0]}_{safe}') return past_map[key] new_args = tuple(_visit(a) for a in e.args) if new_args != e.args: return e.func(*new_args) return e return _visit(expr), past_map def _resolve_derivatives(self, expr): """Replace ``Derivative(f(x), x)`` with known analytical forms. Currently handles: ``identity`` (pass-through), ``sigmoid``, and ``absv``. """ import sympy as sp from sympy import Derivative, Function # identity(x) = x → d/dx = 1 expr = expr.replace( lambda e: isinstance(e, Derivative) and e.expr.func.__name__ == 'identity', lambda e: sp.Integer(1) ) expr = expr.replace( lambda e: isinstance(e, Derivative) and e.expr.func.__name__ == 'sigmoid', lambda e: (lambda s: s * (1 - s))(Function('sigmoid')(e.expr.args[0])) ) expr = expr.replace( lambda e: isinstance(e, Derivative) and e.expr.func.__name__ == 'absv', lambda e: Function('sign')(e.expr.args[0]) ) return expr def _expr_to_jac_str(self, expr, sym_to_y_idx: dict, past_sym_to_str: dict): """Convert a symbolic Jacobian entry to a backend code string. Parameters ---------- expr : sympy.Expr sym_to_y_idx : dict ``{state_sym: int_or_tuple_idx}`` — maps state-variable symbols to their position in the flat state vector ``y``. past_sym_to_str : dict ``{fresh_past_sym: code_string}`` — maps past-state placeholders to the code that evaluates them (e.g. ``'_yhist_0p5[0]'``). Returns ------- str or None ``None`` signals that unevaluated ``Derivative`` nodes remain and a numerical fallback should be used for this entry. """ import sympy as sp from sympy import Derivative # resolve known analytical derivative rules (sigmoid, absv, …) expr = self._resolve_derivatives(expr) if expr.atoms(Derivative): return None start = self.backend._start_idx # build placeholder substitution: state sym / past sym → unique temp sym subs: dict = {} ph_to_code: dict = {} for i, (sym, idx) in enumerate(sym_to_y_idx.items()): ph = sp.Symbol(f'_ypl{i}_') subs[sym] = ph if isinstance(idx, tuple): ph_to_code[str(ph)] = f'y[{idx[0]+start}:{idx[1]}]' else: ph_to_code[str(ph)] = f'y[{idx+start}]' for i, (psym, code_str) in enumerate(past_sym_to_str.items()): ph = sp.Symbol(f'_ppl{i}_') subs[psym] = ph ph_to_code[str(ph)] = code_str expr_subst = expr.subs(subs) expr_str = str(expr_subst) # replace placeholders with actual code strings (longest first to avoid # partial substring collisions) for ph_str in sorted(ph_to_code.keys(), key=len, reverse=True): expr_str = expr_str.replace(ph_str, ph_to_code[ph_str]) return expr_str
[docs] def clear(self) -> None: """Deletes build directory and removes all compute graph nodes """ # delete network nodes and variables from the compute graph for n in list(self.nodes.keys()): self.remove_subgraph(n) self.var_updates.clear() self._state_var_indices.clear() self._eq_nodes.clear() # clear code generator self.backend.clear()
def _to_str(self): # preparations code_gen = self.backend # extract state variable from state vector rhs_indices_str = [] for var in self.state_vars: # extract index of variable in state vector idx = (self._state_var_indices[var],) # extract state variable from state vector rhs_idx, _ = code_gen.create_index_str(idx) code_gen.add_var_update(lhs=self.get_var(var), rhs=f"y{rhs_idx}") rhs_indices_str.append(idx) # get equation string and argument list for each non-DE node at the end of the compute graph hierarchy func_args2, delete_args1 = self._generate_update_equations(differential_equations=False) code_gen.add_linebreak() # get equation string and argument list for each DE node at the end of the compute graph hierarchy func_args1, delete_args2 = self._generate_update_equations(differential_equations=True, indices=rhs_indices_str) # remove unnecessary function arguments func_args = func_args1 + func_args2 for arg in delete_args1 + delete_args2: while arg in func_args: func_args.pop(func_args.index(arg)) return func_args, code_gen def _generate_update_equations(self, differential_equations: bool, indices: list = None) -> tuple: code_gen = self.backend # extract relevant compute graph nodes and bring them into the correct order nodes = self.var_updates['DEs' if differential_equations else 'non-DEs'] nodes, updates, def_vars, undef_vars = self._sort_var_updates(nodes=nodes, differential_equations=differential_equations) # collect right-hand side expression and all input variables to these expressions func_args, expressions, var_names, rhs_shapes, lhs_indices = undef_vars, [], [], [], [] for node, update in zip(nodes, updates): # collect shape of the right-hand side variable v = self.get_var(update) try: v_eval = self.eval_node(update) v.set_value(v_eval) except IndexError: pass rhs_shapes.append(v.shape) # collect expression and variables of right-hand side of equation expr_args, expr = self._node_to_expr(update) func_args.extend(expr_args) expr_str, expr_args, _, _ = self._expr_to_str(expr, apply=True) func_args.extend(expr_args) expressions.append(expr_str) # process left-hand side of equation var = self.get_var(node) if isinstance(var, ComputeOp): # process indexing of left-hand side variable idx_args, lhs = self._node_to_expr(node) if lhs.args[0].name not in def_vars: idx_args.append(lhs.args[0].name) func_args.extend(idx_args) _, idx_args, lhs_var, idx = self._expr_to_str(lhs, apply=False) func_args.extend(idx_args) else: # process normal update of left-hand side variable lhs_var = var.name idx = None var_names.append(lhs_var) lhs_indices.append(idx) # add the left-hand side assignments of the collected right-hand side expressions to the code generator if differential_equations: # differential equation (DE) update if not indices: raise ValueError('State variables need to be stored in a single state vector, for which the indices ' 'have to be passed to this method.') add_args = self._generate_vecfield(code_gen, indices, expressions, rhs_shapes, var_names) func_args = add_args + func_args else: # non-differential equation update if indices: raise ValueError('Indices to non-state variables should be defined in the respective equations, not' 'be passed to this method.') indices = lhs_indices # non-DE update stored in a single variable for target_var, expr, idx, shape in zip(var_names, expressions, indices, rhs_shapes): try: idx = self.get_var(idx) except (KeyError, AttributeError, TypeError): pass code_gen.add_var_update(lhs=self.get_var(target_var), rhs=expr, lhs_idx=idx, rhs_shape=shape) return func_args, def_vars def _generate_vecfield(self, code_gen, indices: list, expressions: list, rhs_shapes: list, lhs_vars: list) -> list: # DE updates stored in a state-vector dy = self.get_var("dy") for idx, expr, shape in zip(indices, expressions, rhs_shapes): code_gen.add_var_update(lhs=dy, rhs=expr, lhs_idx=idx, rhs_shape=shape) # add rhs var to function arguments return ['dy'] def _generate_vecfield_var(self, state_vec: np.ndarray, dtype: str) -> str: key, _ = self.add_var(label='dy', vtype='state_var', value=np.zeros_like(state_vec), dtype=dtype) return key def _generate_func_tail(self, code_gen, vecfield_key: str): code_gen.generate_func_tail(rhs_var=vecfield_key) def _node_to_expr(self, n: str, **kwargs) -> tuple: expr_args = [] node = self.get_var(n) # case I: node is a mathematical operation and its inputs need to be treated try: # process node inputs expr_info = {self.get_var(inp).symbol: self._node_to_expr(inp, **kwargs) for inp in self.predecessors(n)} # replace old inputs with its processed versions expr = node.expr for expr_old, (args, expr_new) in expr_info.items(): if expr_old != expr_new: expr = expr.replace(expr_old, expr_new) expr_args.extend(args) # Replace generic function calls with the backend-specific function calls. # We deliberately SKIP the rebind when the backend's call name matches # the original sympy name (the typical case for sympy stdlib functions # like ``exp``, ``sin``, ``cos``, ``log``, ...) — rebinding ``sp.exp`` # to ``Function('exp')`` would replace the sympy stdlib class with an # ``UndefinedFunction`` of the same name, losing sympy's built-in # differentiation rule. Without this guard, ``sp.diff(exp(x), x)`` # returns the unevaluated ``Derivative(exp(x), x)`` and the analytical # Jacobian path (``_compute_symbolic_jacobian``) breaks for any model # with a transcendental in its RHS. Renames that actually differ # (e.g. ``matmul`` → ``dot``, ``no_op`` → ``identity``) still rebind # as before; those cases never benefit from sympy.diff anyway. try: expr_old = expr.func.__name__ func_info = self.get_op(expr_old, shape=node.shape) new_call = func_info['call'] if new_call != expr_old: expr = expr.replace(expr.func, Function(new_call)) except (AttributeError, KeyError): pass # case II: node is a simple variable or constant except AttributeError: # add constants to the expression arguments list if node.is_constant: expr_args.append(n) expr = node.symbol elif 'dummy_constant' in node.name: val = float(np.squeeze(node.value)) expr = Symbol(str(val)) else: expr = node.symbol return expr_args, expr def _expr_to_str(self, expr: Any, expr_str: str = None, apply: bool = True, **kwargs) -> tuple: # preparations ############### # initializations index_args = [] func = "" idx = "" # ensure expression string exists if not expr_str: expr_str = str(expr) # transform expression arguments into strings expr_args = [] for arg in expr.args: expr_part, args, _, _ = self._expr_to_str(arg, **kwargs) expr_str = expr_str.replace(str(arg), expr_part) index_args.extend(args) expr_args.append(expr_part) var = str(expr_args[0]) if expr.args else "" # process indexing operations ############################# if 'index_1d(' in expr_str: # replace `index` calls with brackets-based indexing idx = self._get_var_idx(idx=(expr.args[1],), args=index_args, apply=apply, **kwargs) func = 'index_1d' elif 'index_2d(' in expr_str: # replace `2d_index` calls with brackets-based indexing idx = self._get_var_idx(idx=(expr.args[1], expr.args[2]), args=index_args, apply=apply, **kwargs) func = 'index_2d' elif 'index_range(' in expr_str: # replace `range_index` calls with brackets-based indexing idx = self._get_var_idx(idx=((expr.args[1], expr.args[2]),), args=index_args, apply=apply, **kwargs) func = 'index_range' elif 'index_axis(' in expr_str: # replace `axis_index` calls with brackets-based indexing if len(expr.args) < 2: idx = self._get_var_idx(idx=(':',), args=index_args, apply=apply, **kwargs) else: idx = self._get_var_idx(args=index_args, apply=apply, idx=tuple([':' for _ in range(expr.args[2])] + [f"{expr.args[1]}"]), **kwargs) func = "index_axis" # either apply the above indexing calls or return them if func and apply: replacement = self.backend.finalize_idx_str(var=self.get_var(var), idx=idx) expr_str = self._process_func_call(expr=expr_str, func=func, replacement=replacement) # handle other function calls ############################# if 'identity(' in expr_str: # replace `no_op` calls with first argument to the function call expr_str = self._process_func_call(expr=expr_str, func='identity', replacement=var) if 'past(' in expr_str: # replace past calls with the delayed version of the backend variable try: delay = self.get_var(expr.args[1].name) except AttributeError: delay = float(expr.args[1]) replacement = self._get_var_hist(var=var, delay=delay) expr_str = self._process_func_call(expr=expr_str, func="past", replacement=replacement) # backend-specific function call adjustments expr_str = self.backend.expr_to_str(expr_str, expr.args) return expr_str, index_args, var, idx def _process_var_update(self, var: str, update: str) -> tuple: # extract nodes lhs = self.get_var(var) rhs = self.eval_node(update) # extract common shape if lhs.shape == rhs.shape: return lhs, rhs try: rhs = rhs.reshape(lhs.shape) return lhs, rhs except ValueError: raise ValueError( f"Shapes of state variable {var} and its right-hand side update {self.get_var(update).expr} do not" " match.") def _sort_var_updates(self, nodes: dict, differential_equations: bool = True) -> tuple: # case I: for differential equations, do not perform any sorting if differential_equations: return list(nodes.keys()), list(nodes.values()), [], [] # case II: for non-differential equations, sort them according to their graph connections ######################################################################################### # step 1: ensure lhs-indexing operations are considered as well node_names, node_keys = [], [] for node in nodes: n = self.get_var(node) if type(n) is ComputeVar: node_names.append(node) else: node_names.append(list(self._get_inputs(node))[-1]) node_keys.append(node) keys, values, defined_vars, undefined_vars = [], [], [], [] n_nodes = len(nodes) while nodes: for node, update in nodes.copy().items(): # go through node inputs and check whether it depends on other equations to be evaluated first dependent, inp = False, "" for inp in self._get_inputs(update): if inp in node_names: idx = node_names.index(inp) if node_keys[idx] != node: dependent = True break # decide whether this equation can be evaluated now if dependent: continue else: idx = node_keys.index(node) n = node_names.pop(idx) node_keys.pop(idx) nodes.pop(node) keys.append(node) values.append(update) if isinstance(self.get_var(node), ComputeVar) and n not in undefined_vars: defined_vars.append(n) elif n not in undefined_vars: undefined_vars.append(n) # check whether the algorithm is stuck n_nodes_new = len(nodes) if n_nodes_new == n_nodes: break else: n_nodes = n_nodes_new # add mutually depended nodes if nodes: node_keys = list(nodes.keys()) keys.extend(node_keys) values.extend(list(nodes.values())) for n in node_names: if n not in undefined_vars: undefined_vars.append(n) return keys, values, defined_vars, undefined_vars def _get_inputs(self, n: str): inputs = [] for inp in self.predecessors(n): inputs.extend([inp] if isinstance(self.get_var(inp), ComputeVar) else self._get_inputs(inp)) return inputs def _get_var_idx(self, idx: tuple, args: list, apply: bool = True, **kwargs): # collect indexing variables where necessary new_idx = [] for idx_tmp in idx: try: v = self.get_var(idx_tmp.name if type(idx_tmp) is Symbol else idx_tmp) new_idx.append(v) except KeyError: new_idx.append(idx_tmp) # turn index into a backend-specific string idx_str, new_vars = self.backend.create_index_str(tuple(new_idx), apply=apply, **kwargs) # add new variables to graph and index arguments for key, v in new_vars.items(): if key != self.backend.idx_dummy_var: vlabel, _ = self.add_var(label='index', value=v, vtype='constant') idx_str = idx_str.replace(key, vlabel) args.append(vlabel) return idx_str def _get_var_hist(self, var: str, delay: Union[ComputeVar, float]): if var not in self._state_var_hist: self._state_var_hist[var] = dict() if delay not in self._state_var_hist[var]: var_hist = f'{var}_hist{len(self._state_var_hist[var])}' self.add_var(var_hist, value=self.get_var(var).value, vtype='variable') self._state_var_hist[var][delay] = var_hist else: var_hist = self._state_var_hist[var][delay] return var_hist def _prune(self): # remove all subgraphs that contain constants only for n in [node for node, out_degree in self.out_degree if out_degree == 0]: if self.get_var(n).is_constant and n not in self._eq_nodes: self.remove_subgraph(n) # remove all unconnected nodes for n in [node for node, out_degree in self.out_degree if out_degree == 0]: if self.in_degree(n) == 0 and n not in self._eq_nodes: self.remove_node(n) def _generate_unique_label(self, label: str) -> str: if label == "t": return label if label in self._node_names: n = self._node_names[label] if n == 0: label_new = f"{label}_v1" else: label_new = f"{label}_v{n + 1}" self._node_names[label] += 1 else: label_new = label self._node_names[label] = 0 return label_new @staticmethod def _index_state_var(y: np.ndarray, idx: Union[int, tuple, list]) -> np.ndarray: if type(idx) is tuple and idx[1] - idx[0] == 1: idx = (idx[0],) elif type(idx) is int: idx = (idx,) return y[:, idx] if len(idx) == 1 else y[:, idx[0]:idx[1]] @staticmethod def _process_func_call(expr: str, func: str, replacement: str): # identify start and end of the function call start = expr.find(f"{func}(") end = expr[start:].find(')') + 1 # replace part in expression string return expr.replace(expr[start:start + end], replacement)
[docs]class ComputeGraphBackProp(ComputeGraph): def __init__(self, backend: str, **kwargs): super().__init__(backend, **kwargs) self._vecfield_vars = [] self._vecfield_var_str = "" def _generate_vecfield_var(self, state_vec: np.ndarray, dtype: str): return "" def _generate_func_tail(self, code_gen, vecfield_key: str): code_gen.generate_func_tail(rhs_var=self._vecfield_var_str) def _generate_vecfield(self, code_gen, indices: list, expressions: list, rhs_shapes: list, lhs_vars: list) -> list: for lhs, expr in zip(lhs_vars, expressions): lhs_var = f"delta_{lhs}" code_gen.add_code_line(f"{lhs_var} = {expr}") self._vecfield_vars.append(lhs_var) if len(self._vecfield_vars) > 1: op_dict = self.backend.get_op("concatenate") self._vecfield_var_str = f"{op_dict['call']}([{','.join(v for v in self._vecfield_vars)}], 0)" else: self._vecfield_var_str = self._vecfield_vars.pop() return []