Source code for pyrates.backend.base.base_backend

# -*- coding: utf-8 -*-
#
#
# PyRates software framework for flexible implementation of neural
# network model_templates and simulations. See also:
# https://github.com/pyrates-neuroscience/PyRates
#
# Copyright (C) 2017-2018 the original authors (Richard Gast and
# Daniel Rose), the Max-Planck-Institute for Human Cognitive Brain
# Sciences ("MPI CBS") and contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>
#
# CITATION:
#
# Richard Gast and Daniel Rose et. al. in preparation

"""Contains wrapper classes for different backends that are needed by the parser module.

A new backend needs to implement the following methods:
- __init__
- run
- add_var
- add_op
- add_layer

Currently supported backends:
- Numpy: BaseBackend.
- Torch: TorchBackend.
- Fortran: FortranBackend (experimental).
- Julia: JuliaBackend.
- Matlab: MatlabBackend.

"""

# pyrates internal _imports
from ..computegraph import ComputeVar
from .base_funcs import base_funcs
from .. import PyRatesException

# external _imports
from typing import Optional, Dict, List, Union, Tuple, Callable, Iterable
import numpy as np
import os, sys, importlib, hashlib, types as _types
from shutil import rmtree


# ---------------------------------------------------------------------------
# Module-level cache for compiled RHS modules, keyed by the SHA-256 hex digest
# of the source string passed to ``compile()``; the values are the module
# objects the source was executed into.  Lets parameter sweeps and the
# ``get_run_func`` / ``get_jacobian_func`` paths skip the compile+exec round
# trip when the same model is regenerated identically (which is the common
# case during optimisation loops).  Cleared explicitly via
# ``clear_compile_cache()`` below — there is no automatic eviction, so the
# cache grows by one entry per distinct source string until it is cleared.
# ---------------------------------------------------------------------------
_compiled_module_cache: Dict[str, _types.ModuleType] = {}


def clear_compile_cache() -> None:
    """Drop all cached compiled RHS modules.

    Empties the module-level ``_compiled_module_cache`` dictionary in place.
    Call this if you have generated many distinct models in a long-running
    process and want to reclaim memory.

    Has no effect on already-returned callables; only future
    :meth:`BaseBackend.generate_func` calls are affected, because they will
    no longer find a cached module and have to compile their source again.
    """
    _compiled_module_cache.clear()
# Helper Functions and Classes
##############################

import bisect
class DDEHistory:
    """Interpolating history buffer for delay-differential equations.

    Instances are callables: ``hist(t)`` returns the state at time ``t``,
    linearly interpolated between the two recorded samples that bracket
    ``t``. Queries at or before the first recorded time return the initial
    state; queries at or after the most recently recorded time return the
    most recently recorded state.

    Storage
    -------
    States are kept in a single pre-allocated 2-D numpy buffer with one row
    per recorded step, so :meth:`update` copies into an existing row instead
    of allocating a new array per step. Times are kept in a plain Python
    list, which is what ``bisect.bisect_right`` searches in ``__call__``.

    Parameters
    ----------
    y0
        Initial state vector. Its shape and dtype define the shape and dtype
        of every row of the buffer.
    t0
        Initial time. Any query ``t <= t0`` returns the initial state.
    max_steps
        Optional fixed number of buffer rows (at least one; the initial state
        occupies the first row). When given, :meth:`update` raises
        ``IndexError`` once every row is in use. When ``None`` (default) the
        buffer starts with ``_INITIAL_CAPACITY`` rows and is enlarged by
        ``_GROW_FACTOR`` whenever it is full.
    """

    _INITIAL_CAPACITY = 1024
    _GROW_FACTOR = 2

    def __init__(self, y0: np.ndarray, t0: float = 0.0, max_steps: Optional[int] = None):
        first_state = np.asarray(y0)

        # a missing cap means "start small and grow on demand"
        self._growable = max_steps is None
        n_rows = self._INITIAL_CAPACITY if self._growable else max(int(max_steps), 1)

        # times stay in a list so that bisect can search them directly
        self._t = [float(t0)]

        # one row per recorded step; row 0 holds the initial condition
        self._y = np.empty((n_rows, *first_state.shape), dtype=first_state.dtype)
        self._y[0] = first_state
        self._n = 1

    def update(self, t: float, y: np.ndarray) -> None:
        """Record state ``y`` at time ``t``.

        ``y`` is copied into the next free buffer row, so the caller may
        overwrite its own array afterwards.

        Raises
        ------
        IndexError
            If the buffer is full and was created with a fixed ``max_steps``.
        """
        if self._n >= len(self._y):
            if not self._growable:
                raise IndexError(
                    f"DDEHistory: exceeded max_steps={len(self._y)}; "
                    "increase the bound or omit max_steps to allow growth."
                )
            self._grow()

        self._t.append(float(t))
        # assigning into a row copies the values of y
        self._y[self._n] = y
        self._n += 1

    def _grow(self) -> None:
        """Replace the buffer by a larger one that keeps all recorded rows."""
        enlarged = np.empty(
            (len(self._y) * self._GROW_FACTOR,) + self._y.shape[1:],
            dtype=self._y.dtype,
        )
        enlarged[:self._n] = self._y[:self._n]
        self._y = enlarged

    def __call__(self, t: float) -> np.ndarray:
        """Return the (linearly interpolated) state at time ``t``."""
        t = float(t)
        recorded = self._t

        # outside the recorded range: clamp to the first / latest sample
        if t <= recorded[0]:
            return self._y[0]
        if t >= recorded[-1]:
            return self._y[self._n - 1]

        # `lower` is the last sample with time <= t, `upper` the next one
        lower = bisect.bisect_right(recorded, t) - 1
        upper = lower + 1
        t_lower = recorded[lower]
        t_upper = recorded[upper]
        weight = (t - t_lower) / (t_upper - t_lower)
        return self._y[lower] + weight * (self._y[upper] - self._y[lower])
class CodeGen:
    """Minimal line-based Python source builder.

    Collects code lines in ``self.code`` and tracks the current indentation
    level in ``self.lvl``; every added line is prefixed with one tab per
    level.
    """

    def __init__(self):
        self.code = []
        self.lvl = 0

    def generate(self):
        """Return all collected lines joined by newlines as one string."""
        return '\n'.join(self.code)

    def add_code_line(self, code_str):
        """Append ``code_str`` at the current indentation level.

        A multi-line string is split at ``'\\n'`` and each piece is added as
        a separate, individually indented line.
        """
        indentation = "\t" * self.lvl
        self.code.extend(indentation + piece for piece in code_str.split('\n'))

    def add_linebreak(self):
        """Append an empty line."""
        self.code.append("")

    def add_indent(self):
        """Increase the indentation level by one."""
        self.lvl = self.lvl + 1

    def remove_indent(self):
        """Decrease the indentation level by one.

        Raises
        ------
        SyntaxError
            If the indentation level is already zero.
        """
        if self.lvl == 0:
            raise SyntaxError(
                "Error in generation of network function file: A net negative indentation was requested."
            )
        self.lvl = self.lvl - 1

    def clear(self):
        """Remove all collected code lines (the indentation level is kept)."""
        del self.code[:]
#######################################
# classes for backend functionalities #
#######################################
class BaseBackend(CodeGen):
    """Default backend class.

    Transforms all network equations into their numpy equivalents. Based on a
    Python code generator (see :class:`CodeGen`).
    """

    # ----------------------------------------------------------------------
    # Class attribute: which `solver=` values does this backend accept?
    # Subclasses override (or extend) this tuple. Validated early via
    # ``_validate_solver`` so the user sees a clear error instead of the
    # current "the call silently took a wrong branch and produced numpy
    # output" failure mode.
    # ----------------------------------------------------------------------
    SUPPORTED_SOLVERS: Tuple[str, ...] = ('euler', 'heun', 'scipy')

    # ----------------------------------------------------------------------
    # Class attribute: ops whose `def` strings must NOT be appended to the
    # generated module's helper-funcs list (they are language primitives or
    # pure-Python no-ops handled directly by the parser). Subclasses can
    # extend this with e.g. ``_no_funcs = BaseBackend._no_funcs + ("foo",)``.
    # ----------------------------------------------------------------------
    _no_funcs: Tuple[str, ...] = ("identity", "index_1d", "index_2d", "index_range", "index_axis")

    # ----------------------------------------------------------------------
    # Feature flags consulted by ComputeGraph at code-gen time so users get
    # a clear, early error rather than a cryptic JIT-trace failure.
    #
    # ``SUPPORTS_SPARSE_JACOBIAN``: ``get_jacobian_func(..., sparse=True)``
    #   emits ``csr_matrix(J0)``. ``scipy.sparse.csr_matrix`` does not
    #   understand JAX tracers, so the JAX backend sets this to False.
    #
    # ``SUPPORTS_EDGE_DELAY_BUFFER``: the discrete-delay path generates a
    #   ring-buffer that's updated in-place each step via
    #   ``buf[:] = roll(buf, 1); buf[0] = current``. JAX arrays are
    #   immutable; the equivalent functional update produces a new array
    #   that's never threaded back into the solver's state, so the buffer
    #   stays at its initial value forever. JaxBackend sets this False;
    #   users should switch to the DDEHistory / ``past(x, tau)`` path.
    # ----------------------------------------------------------------------
    SUPPORTS_SPARSE_JACOBIAN: bool = True
    SUPPORTS_EDGE_DELAY_BUFFER: bool = True

    def __init__(self,
                 ops: Optional[Dict[str, str]] = None,
                 imports: Optional[List[str]] = None,
                 **kwargs
                 ) -> None:
        """Instantiates the standard, numpy-based backend.

        Parameters
        ----------
        ops
            Optional operation definitions that are merged into (and may
            override entries of) a copy of ``base_funcs``.
        imports
            Optional import statements that are registered via
            :meth:`add_import` in addition to the default numpy import.
        kwargs
            Recognised keys (all optional): ``add_hist_arg`` (default
            ``True``), ``float_precision`` (``'float32'``), ``int_precision``
            (``'int32'``), ``idx_left`` (``'['``), ``idx_right`` (``']'``),
            ``start_idx`` (``0``), ``file_name`` (``'pyrates_run'``) and
            ``file_ending`` (``'.py'``, only consulted when ``file_name``
            carries no ending of its own).
        """

        # call to super method (initializes code generator)
        super().__init__()

        # definition of usable math operations
        self._funcs = base_funcs.copy()
        if ops:
            self._funcs.update(ops)
        self._helper_funcs = []

        # definition of extrinsic function _imports
        self._imports = ["from numpy import pi, sqrt"]
        if imports:
            for imp in imports:
                self.add_import(imp)

        # public attributes
        self.add_hist_arg = kwargs.pop('add_hist_arg', True)
        self.lags = {}
        self.idx_dummy_var = "temporary_pyrates_var_index"

        # private attributes
        self._float_precision = kwargs.pop('float_precision', 'float32')
        self._int_precision = kwargs.pop('int_precision', 'int32')
        self._idx_left = kwargs.pop('idx_left', '[')
        self._idx_right = kwargs.pop('idx_right', ']')
        self._start_idx = kwargs.pop('start_idx', 0)
        # `_no_funcs` lives on the class (see top of class body). Do not
        # shadow it here — subclasses extend the class attribute.

        # Tracks which ComputeVars have already had `_start_idx` baked into
        # their `.value` by `_process_idx`. Prevents double-application when
        # the same ComputeVar is processed more than once (review §4.3).
        # Keyed by ``id(ComputeVar)`` — safe because ComputeVars live as long
        # as the backend itself (owned by the ComputeGraph that wraps it).
        self._offsetted_var_ids: set = set()

        # file-creation-related attributes
        fdir, *fname = self.get_fname(kwargs.pop('file_name', 'pyrates_run'))
        cwdir = os.getcwd()
        # Register each directory on sys.path at most once: backends are
        # re-instantiated many times in sweeps / optimisation loops, and
        # unconditional appends would let sys.path grow without bound.
        if cwdir not in sys.path:
            sys.path.append(cwdir)
        if fdir:
            fdir = f"{cwdir}/{fdir}"
            if fdir not in sys.path:
                sys.path.append(fdir)
        self.fdir = fdir
        self._fname = fname[0]
        self._fend = f".{fname[1]}" if len(fname) > 1 else kwargs.pop('file_ending', '.py')
[docs] def get_var(self, v: ComputeVar): if v.is_float or v.is_complex: dtype = self._float_precision if 'complex' in dtype and v.name in ['t', 'time']: dtype = f'float{dtype[7:]}' else: dtype = self._int_precision result = np.asarray(v.value, dtype=dtype) # Squeeze single-element constants to 0-d scalars. # PyRates stores scalar parameters internally as shape (1,) but # numpy 2.3+ raises an error when a (1,) array is assigned to a scalar # state-vector slot (e.g. dy[i] = (1,)_param * expr). if result.shape == (1,) and v.vtype == 'constant': result = result.squeeze() return result
[docs] def get_op(self, name: str, **kwargs) -> dict: # retrieve function information from backend definitions func_info = self._get_func_info(name, **kwargs) func_name = func_info['call'] # add extrinsic function imports if necessary if 'imports' in func_info: for imp in func_info['imports']: *in_path, in_func = imp.split('.') self.add_import(f"from {'.'.join(in_path)} import {in_func}") if 'def' in func_info: # extract the provided function definition func_str = func_info['def'] # remember the function definition string for file creation if func_str not in self._helper_funcs and func_name not in self._no_funcs: self._helper_funcs.append(func_str) if 'func' in func_info: # extract the provided callable func = func_info['func'] else: # extract the provided function definition func_str = func_info['def'] # make _imports available to function for imp in self._imports: try: exec(imp, globals()) except SyntaxError: pass # evaluate the function string to receive a callable exec(func_str, globals()) func = globals().pop(func_name) return {'func': func, 'call': func_name}
def add_var_update(self, lhs: ComputeVar, rhs: str, lhs_idx: Optional[str] = None, rhs_shape: Optional[tuple] = ()):
    """Add the code line that assigns ``rhs`` to ``lhs``.

    If ``lhs_idx`` is given, the index string created from it is appended to
    the name of ``lhs``. The line itself is rendered by
    :meth:`_format_assignment`, which is told whether the assignment is
    indexed (``lhs_idx`` given) or has a non-empty ``rhs_shape``.
    """
    target = lhs.name
    if lhs_idx:
        index_str, _ = self.create_index_str(lhs_idx, apply=True)
        target = f"{target}{index_str}"

    needs_broadcast = bool(lhs_idx or rhs_shape)
    line = self._format_assignment(target, rhs, needs_broadcast)
    self.add_code_line(line)
def _format_assignment(self, lhs: str, rhs: str, indexed: bool) -> str: """Render a single ``lhs = rhs`` assignment for the target language. Subclasses override this hook to splice in language-specific syntax (Julia's broadcast prefix ``@.``, Matlab's ``vectorize`` + trailing ``;``, …) instead of the old pop-and-rewrite pattern that ``super().add_var_update`` then ``self.code.pop() / line.split(' = ')`` used to implement. ``indexed`` is True iff the assignment writes to an index expression (``lhs[idx] = rhs``) or the rhs has non-scalar shape — i.e. the cases where broadcasting matters. """ return f"{lhs} = {rhs}" # ------------------------------------------------------------------ # Jacobian-emission hooks — used by ComputeGraph.get_jacobian_func to # emit local-array allocations and per-element assignments that match # the target language's mutation semantics. The default emits plain # numpy code; JaxBackend overrides for functional ``.at[i, j].set(...)`` # updates (and importing zeros from jax.numpy instead of numpy). # # ``declare_local_array_imports`` must be called *before* # ``generate_func_head`` because that method materialises ``self._imports`` # into source-file lines. Adding an import inside # ``emit_local_array_alloc`` (which runs after the function head) would be # too late — the import section is already closed by then. # ------------------------------------------------------------------
[docs] def declare_local_array_imports(self) -> None: """Register the imports needed by :meth:`emit_local_array_alloc`.""" self.add_import("from numpy import zeros")
def emit_local_array_alloc(self, name: str, shape: tuple, dtype: str) -> None:
    """Emit ``name = zeros(shape, dtype=dtype)``.

    The emitted line calls ``zeros``, so :meth:`declare_local_array_imports`
    must have been called earlier for the generated code to find that name.
    """
    allocation = f"zeros({shape}, dtype='{dtype}')"
    self.add_code_line(f"{name} = {allocation}")
def emit_local_array_assign(self, name: str, indices: tuple, expr: str) -> None:
    """Emit ``name[i, j, ...] = expr`` for a previously-allocated array.

    ``indices`` are converted with ``str`` and joined by ``', '``.
    """
    subscript = ', '.join(map(str, indices))
    self.add_code_line(f"{name}[{subscript}] = {expr}")
def add_var_hist(self, lhs: str, delay: Union[ComputeVar, float], state_idx: str, dt: Optional[float] = None, dt_adapt: bool = True, **kwargs):
    """Add the code line that reads a delayed state from ``hist``.

    Emits ``lhs = hist(<time> - <delay>)[<state index>]``. ``<time>`` is
    ``t`` scaled by ``dt`` (formatted as ``.10e``) when ``dt`` is given and
    ``dt_adapt`` is False, and plain ``t`` otherwise.
    """
    state = self._process_idx(state_idx)
    lag = self._process_delay(delay)

    scale_time = dt is not None and not dt_adapt
    time_expr = f"t*{dt:.10e}" if scale_time else "t"
    self.add_code_line(f"{lhs} = hist({time_expr}-{lag})[{state}]")
[docs] def add_import(self, line: str): if line not in self._imports: self._imports.append(line)
def create_index_str(self, idx: Union[str, int, tuple], separator: str = ',', apply: bool = True,
                     **kwargs) -> Tuple[str, dict]:
    """Translate ``idx`` into an index string of the backend language.

    A string containing ``separator`` is split into a tuple of indices.
    Every index is passed through :meth:`_process_idx`; multiple indices are
    joined by ``separator``. With ``apply=True`` the result is wrapped in the
    backend's index brackets.

    Returns
    -------
    Tuple[str, dict]
        The index string and an (always empty) dictionary.
    """
    # preprocess idx
    if type(idx) is str and separator in idx:
        idx = tuple(idx.split(separator))

    if type(idx) is tuple:
        # case: multiple indices
        index = separator.join(f"{self._process_idx(part, **kwargs)}" for part in idx)
    else:
        # case: single index
        index = self._process_idx(idx, **kwargs)

    if apply:
        index = f"{self._idx_left}{index}{self._idx_right}"
    return index, dict()
def get_fname(self, f: str) -> tuple:
    """Split the file name ``f`` into directory, base name and file ending.

    Returns
    -------
    tuple
        ``(directory, name, ending)`` if ``f`` contains a ``.``, else
        ``(directory, name)``. ``directory`` is the ``/``-joined path in
        front of the name (empty string if there is none).

    Raises
    ------
    ValueError
        If ``f`` contains more than one ``.``.
    """
    pieces = f.split('.')
    if len(pieces) > 2:
        raise ValueError(f'File name {f} has wrong format. Only one `.` can be used to separate file name from '
                         f'file ending.')

    # pieces[0] is the full name if there is no ending
    *directories, name = pieces[0].split('/')
    result = ('/'.join(directories), name)
    if len(pieces) == 2:
        result += (pieces[1],)
    return result
def generate_func_head(self, func_name: str, state_var: str = 'y', return_var: str = 'dy', func_args: list = None,
                       add_hist_func: Optional[bool] = None):
    """Generate the function header for the RHS file.

    Writes the registered imports, the remembered helper function
    definitions and the ``def`` line of the function, then increases the
    indentation level for the function body.

    ``add_hist_func``
        Whether to include the ``hist`` callable in the generated function
        signature. If ``None`` (default), falls back to the backend's
        ``add_hist_arg`` constructor flag.

    Returns
    -------
    list
        The argument names of the generated function: ``t``, the state
        variable, optionally ``hist``, followed by the names of
        ``func_args`` without duplicates, in order of first appearance.
    """
    include_hist = self.add_hist_arg if add_hist_func is None else add_hist_func

    # leading arguments that every generated function receives
    signature = ['t', state_var]
    if include_hist:
        signature.append('hist')

    # remaining arguments: unique names, ordered by first occurrence
    names = [arg.name for arg in func_args] if func_args else []
    signature.extend(dict.fromkeys(names))

    # add _imports at beginning of file
    if self._imports:
        for statement in self._imports:
            self.add_code_line(statement)
        self.add_linebreak()

    # add definitions of helper functions after the _imports
    if self._helper_funcs:
        for definition in self._helper_funcs:
            self.add_code_line(definition)
        self.add_linebreak()

    # add function header
    self.add_linebreak()
    self._add_func_call(name=func_name, args=signature, return_var=return_var)
    self.add_indent()

    return signature
def generate_func_tail(self, rhs_var: str = 'dy'):
    """Close the generated function.

    Adds the ``return`` statement for ``rhs_var`` and decreases the
    indentation level by one.
    """
    return_statement = "return {}".format(rhs_var)
    self.add_code_line(return_statement)
    self.remove_indent()
def generate_func(self, func_name: str, to_file: bool = True, **kwargs):
    """Compile the generated source and return the function ``func_name``.

    Parameters
    ----------
    func_name
        Name of the function to extract from the compiled source.
    to_file
        If True, the source is also written to
        ``<fdir>/<fname><file ending>`` and the module is registered in
        ``sys.modules`` under ``self._fname``.
    kwargs
        Passed on to :meth:`_apply_decorator` (``decorator``,
        ``decorator_kwargs``).
    """
    # generate the current function string via the code generator
    source = self.generate()

    # Write the source to disk first, so that it can be inspected even when
    # the compiled module is served from the cache.
    if to_file:
        stem = f'{self.fdir}/{self._fname}' if self.fdir else self._fname
        src_path = f'{stem}{self._fend}'
        with open(src_path, 'w') as handle:
            handle.write(source)
    else:
        src_path = f'<pyrates:{self._fname}>'

    # Look the source up in the SHA-256-keyed module cache. A hit provides a
    # module into which this exact source string has been executed already,
    # so both compile() and exec() are skipped.
    digest = hashlib.sha256(source.encode('utf-8')).hexdigest()
    module = _compiled_module_cache.get(digest)
    if module is None:
        module = _types.ModuleType(self._fname)
        module.__file__ = src_path
        # Compile from the in-memory source string and execute it inside the
        # fresh module. Nothing is imported from disk here, so the .pyc
        # bytecode cache is never consulted.
        code_object = compile(source, src_path, 'exec')
        exec(code_object, module.__dict__)
        _compiled_module_cache[digest] = module

    # Always (re-)register the module under the current file name: a cached
    # module may have been registered under another name, or the entry may
    # have been removed by .clear().
    if to_file:
        sys.modules[self._fname] = module

    generated = module.__dict__[func_name]
    return self._apply_decorator(generated, **kwargs)
@staticmethod def _apply_decorator(rhs_eval: Callable, **kwargs) -> Callable: """Optionally wrap the generated function with a user-supplied decorator. Looks for ``decorator`` (callable) and ``decorator_kwargs`` (dict) in ``kwargs``; pops them out and applies ``decorator(rhs_eval, **decorator_kwargs)`` when present. Returns ``rhs_eval`` unchanged otherwise. Centralised here so backend subclasses with their own ``generate_func`` don't each have to repeat the same four lines. """ decorator = kwargs.pop('decorator', None) if decorator: decorator_kwargs = kwargs.pop('decorator_kwargs', dict()) rhs_eval = decorator(rhs_eval, **decorator_kwargs) return rhs_eval
[docs] def run(self, func: Callable, func_args: tuple, T: float, dt: float, dts: float, solver: str, **kwargs) -> tuple: # initial values t0 = func_args[0] y0 = func_args[1] # use a safer way to generate time points (endpoint=False ensures times match Euler step indices) step = dts if dts else dt n_time_points = round(T/step) times = np.linspace(0.0, T, num=n_time_points, endpoint=False) # perform simulation results = self._solve(solver=solver, func=func, args=func_args[2:], T=T, dt=dt, dts=dts, y0=y0, t0=t0, times=times, **kwargs) return results, times
def clear(self):
    """Remove everything that was created during code generation.

    Clears the collected code lines, deletes the generated file (or, if a
    file directory was configured, that whole directory) and unregisters the
    generated module from ``sys.modules``. Files or directories that do not
    exist (any more) are ignored, so calling this method repeatedly is safe.
    """

    # clear code generator
    super().clear()

    # remove files and directories that have been created during simulation process
    if self.fdir:
        # tolerate a missing directory, consistent with the single-file branch below
        try:
            rmtree(self.fdir)
        except FileNotFoundError:
            pass
    else:
        try:
            os.remove(f"{self._fname}{self._fend}")
        except FileNotFoundError:
            pass

    # delete loaded modules from the system
    sys.modules.pop(self._fname, None)
@staticmethod
def to_file(fn: str, **kwargs):
    """Save all keyword arguments as arrays to the file ``fn`` via ``numpy.savez``."""
    np.savez(fn, **kwargs)
[docs] @staticmethod def register_vars(variables: list): pass
[docs] @staticmethod def finalize_idx_str(var: ComputeVar, idx: str): return f"{var.name}{idx}"
@staticmethod
def expr_to_str(expr: str, args: tuple):
    """Return the string form of ``expr``.

    The default backend returns ``expr`` unchanged and ignores ``args``.
    """
    return expr
@staticmethod
def get_hist_func(y: np.ndarray, t0: float = 0.0) -> DDEHistory:
    """Create the history callable for delay-differential equations.

    Returns a :class:`DDEHistory` that is initialised with state ``y`` at
    time ``t0``.
    """
    return DDEHistory(y, t0=t0)
def _get_func_info(self, name: str, **kwargs): return self._funcs[name] def _process_idx(self, idx: Union[Tuple[int, int], int, str, ComputeVar], **kwargs) -> str: if type(idx) is ComputeVar: # Idempotent offset: bake `self._start_idx` into `idx.value` only # the first time we see this particular ComputeVar. The previous # implementation re-added the offset on every call, which forced # Julia (and indirectly Matlab) to temporarily toggle # `self._start_idx = 0` around calls that could touch an already # processed var — fragile and easy to break (review §4.3). if self._start_idx and id(idx) not in self._offsetted_var_ids: idx.set_value(idx.value + self._start_idx) self._offsetted_var_ids.add(id(idx)) return idx.name if type(idx) is tuple: return f"{idx[0] + self._start_idx}:{idx[1]}" if type(idx) is int: return f"{idx + self._start_idx}" try: return self._process_idx(int(idx), **kwargs) except (TypeError, ValueError): return idx def _process_delay(self, delay: Union[ComputeVar, float]) -> str: return f"{delay}[{self._start_idx}]" if type(delay) is ComputeVar and delay.shape else f"{delay}" def _validate_solver(self, solver: str) -> None: """Raise a helpful error if the requested solver is not supported. Called by every ``_solve`` override (in this class and subclasses) so users get a clear "solver X is not supported by backend Y; supported are Z" message instead of the previous fall-through into a generic ``PyRatesException`` (or, worse, a silent dispatch into a method that happened to share a prefix). """ if solver not in self.SUPPORTED_SOLVERS: raise PyRatesException( f"Backend `{type(self).__name__}` does not support solver " f"`{solver}`. Supported solvers: {list(self.SUPPORTED_SOLVERS)}." 
) def _solve(self, solver: str, func: Callable, args: tuple, T: float, dt: float, dts: float, y0: np.ndarray, t0: np.ndarray, times: np.ndarray, **kwargs) -> np.ndarray: self._validate_solver(solver) if solver == 'euler': return self._solve_euler(func, args, T, dt, dts, y0, t0) if solver == 'heun': return self._solve_heun(func, args, T, dt, dts, y0, t0) # solver == 'scipy' if len(args) > 0 and isinstance(args[0], DDEHistory): return self._solve_scipy_dde(func, args, T, dt, y0, t0, times, **kwargs) return self._solve_scipy(func, args, T, dt, y0, t0, times, **kwargs) def _add_func_call(self, name: str, args: Iterable, return_var: str = 'dy'): self.add_code_line(f"def {name}({','.join(args)}):") @staticmethod def _solve_euler(func: Callable, args: tuple, T: float, dt: float, dts: float, y: np.ndarray, t0): # preparations for fixed step-size integration idx = 0 steps = int(np.round(T / dt)) store_steps = int(np.round(T / dts)) store_step = int(np.round(dts / dt)) # state_rec is fully overwritten row-by-row in the loop below before any # row is read, so np.empty is safe and avoids the zero-fill (which # dominates startup for large state vectors). state_rec = np.empty((store_steps, y.shape[0]) if y.shape else (store_steps, 1), dtype=y.dtype) has_dde = len(args) > 0 and isinstance(args[0], DDEHistory) # solve ivp for forward Euler method. Storage cadence is driven by the # iteration counter `i`, not by the wall-clock step number — using # ``step % store_step == t0`` (the previous formulation) silently # produces zero stored samples whenever ``t0 >= store_step``. 
for i in range(steps): if i % store_step == 0: state_rec[idx, :] = y idx += 1 step = i + t0 rhs = func(step, y, *args) y += dt * rhs if has_dde: args[0].update((i + 1) * dt, y) return state_rec @staticmethod def _solve_heun(func: Callable, args: tuple, T: float, dt: float, dts: float, y: np.ndarray, t0): # preparations for fixed step-size integration idx = 0 steps = int(np.round(T / dt)) store_steps = int(np.round(T / dts)) store_step = int(np.round(dts / dt)) # state_rec is fully overwritten row-by-row in the loop below before any # row is read, so np.empty is safe and avoids the zero-fill (which # dominates startup for large state vectors). state_rec = np.empty((store_steps, y.shape[0]) if y.shape else (store_steps, 1), dtype=y.dtype) has_dde = len(args) > 0 and isinstance(args[0], DDEHistory) # solve ivp via Heun's method. See `_solve_euler` for the rationale # behind the iteration-counter-based storage condition. for i in range(steps): if i % store_step == 0: state_rec[idx, :] = y idx += 1 step = i + t0 rhs = func(step, y, *args) y_0 = y + dt * rhs y += dt/2 * (rhs + func(step, y_0, *args)) if has_dde: args[0].update((i + 1) * dt, y) return state_rec @staticmethod def _solve_scipy_dde(func: Callable, args: tuple, T: float, dt: float, y: np.ndarray, t0: np.ndarray, times: np.ndarray, **kwargs): from scipy.integrate import ode hist = args[0] kwargs.pop('method', None) def rhs(t, y_): return func(t, y_, *args) solver = ode(rhs).set_integrator('dopri5', first_step=dt, nsteps=50000) solver.set_initial_value(y, float(t0)) def solout(t, y_): # DDEHistory.update copies y_ into its pre-allocated buffer. 
hist.update(t, y_) return 0 solver.set_solout(solout) state_rec = np.zeros((len(times), y.shape[0]), dtype=y.dtype) for i, t_out in enumerate(times): if not solver.successful(): break solver.integrate(t_out) state_rec[i, :] = solver.y return state_rec @staticmethod def _solve_scipy(func: Callable, args: tuple, T: float, dt: float, y: np.ndarray, t0: np.ndarray, times: np.ndarray, **kwargs): # solve ivp via scipy methods (solvers of various orders with adaptive step-size) from scipy.integrate import solve_ivp kwargs['t_eval'] = times # call scipy solver results = solve_ivp(fun=func, t_span=(t0, T), y0=y, first_step=dt, args=args, **kwargs) return results['y'].T