Source code for pyrates.backend.torch.torch_backend

# -*- coding: utf-8 -*-
#
#
# PyRates software framework for flexible implementation of neural
# network model_templates and simulations. See also:
# https://github.com/pyrates-neuroscience/PyRates
#
# Copyright (C) 2017-2018 the original authors (Richard Gast and
# Daniel Rose), the Max-Planck-Institute for Human Cognitive Brain
# Sciences ("MPI CBS") and contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>
#
# CITATION:
#
# Richard Gast and Daniel Rose et. al. in preparation

"""Wraps torch such that it's low-level functions can be used by PyRates to create and simulate a compute graph.
"""

# pyrates internal _imports
from ..base import BaseBackend
from ..computegraph import ComputeVar
from .torch_funcs import torch_funcs

# external _imports
import torch
import numpy as np
from typing import Callable, Optional, Dict, List

# meta infos
__author__ = "Richard Gast"
__status__ = "development"


#######################################
# classes for backend functionalities #
#######################################


[docs]class TorchBackend(BaseBackend): # `heun` falls back to BaseBackend._solve_heun, which writes to a numpy # state-record buffer and would silently sever the autograd graph. # Declare it unsupported until a tensor-native heun is added. SUPPORTED_SOLVERS = ('euler', 'scipy') def __init__(self, ops: Optional[Dict[str, str]] = None, imports: Optional[List[str]] = None, **kwargs ) -> None: """Instantiates PyTorch backend. """ # add user-provided operations to function dict torch_ops = torch_funcs.copy() if ops: torch_ops.update(ops) # ensure that long is the standard integer type if 'int_precision' in kwargs: print(f"Warning: User-provided integer precision `{kwargs.pop('int_precision')}` will be ignored, since the" f"torch backend requires integer precision `int64` for some indexing operations.") kwargs['int_precision'] = 'int64' # call parent method super().__init__(ops=torch_ops, imports=imports, **kwargs) self._imports[0] = "from torch import pi, sqrt"
[docs] def get_var(self, v: ComputeVar): return torch.from_numpy(super().get_var(v))
def _solve_scipy(self, func: Callable, args: tuple, T: float, dt: float, y: torch.Tensor, t0: torch.Tensor, times: torch.Tensor, **kwargs): # solve ivp via scipy methods (solvers of various orders with adaptive step-size) from scipy.integrate import solve_ivp kwargs['t_eval'] = times dtype = self._torch_float_dtype(y) # wrapper to rhs function: use torch.as_tensor for a zero-copy view of # the numpy arrays scipy hands us (a copy only occurs when the dtypes # differ, matching the previous torch.tensor() behavior). def f(t, y): rhs = func(torch.as_tensor(t, dtype=dtype), torch.as_tensor(y, dtype=dtype), *args) return rhs.numpy() # call scipy solver results = solve_ivp(fun=f, t_span=(t0, T), y0=y, first_step=dt, **kwargs) return results['y'].T def _solve_scipy_dde(self, func: Callable, args: tuple, T: float, dt: float, y: torch.Tensor, t0: torch.Tensor, times: np.ndarray, **kwargs): """DDE integration via scipy.integrate.ode. Closes the silent-fallback flagged in review §1.2: when a model with ``delay`` edges is run on TorchBackend with ``solver='scipy'``, the BaseBackend implementation called the compiled torch function with raw numpy inputs (it happened to work for purely-arithmetic RHS but was undefined for tensor ops). This override wraps the RHS with the same :code:`torch.as_tensor` boundary conversion used by :meth:`_solve_scipy`. Note: autograd does not flow through the scipy step; the returned trajectory is plain numpy. A tensor-native DDE solver would be needed for differentiable simulation. """ from scipy.integrate import ode from ..base.base_backend import DDEHistory hist = args[0] if not isinstance(hist, DDEHistory): raise TypeError("_solve_scipy_dde expects args[0] to be a DDEHistory instance.") kwargs.pop('method', None) dtype = self._torch_float_dtype(y) # rhs wrapper that crosses the numpy ↔ torch boundary cleanly def rhs(t, y_): out = func(torch.as_tensor(t, dtype=dtype), torch.as_tensor(y_, dtype=dtype), *args) return out.numpy() if hasattr(out, 'numpy') else np.asarray(out) y_np = y.numpy() if hasattr(y, 'numpy') else np.asarray(y) solver = ode(rhs).set_integrator('dopri5', first_step=dt, nsteps=50000) solver.set_initial_value(y_np, float(t0)) def solout(t, y_): # DDEHistory.update copies y_ into its pre-allocated buffer. hist.update(t, y_) return 0 solver.set_solout(solout) # np.zeros (not np.empty): the loop may break early on integrator # failure, leaving unwritten rows that must remain defined. state_rec = np.zeros((len(times), y_np.shape[0]), dtype=y_np.dtype) for i, t_out in enumerate(times): if not solver.successful(): break solver.integrate(t_out) state_rec[i, :] = solver.y return state_rec def _torch_float_dtype(self, y: torch.Tensor) -> torch.dtype: """Resolve the torch dtype used inside scipy-bridged RHS wrappers.""" if y.dtype.is_complex: return y.dtype try: return getattr(torch, self._float_precision) except AttributeError: return torch.get_default_dtype() @staticmethod def _solve_euler(func: Callable, args: tuple, T: float, dt: float, dts: float, y: torch.Tensor, t0: int ) -> np.ndarray: """Forward-Euler integration with tensor-valued state. Mirrors :code:`BaseBackend._solve_euler`'s control flow: the write cursor (``idx``) and the time index (``step``) are separate variables. The previous implementation conflated them, which broke any nonzero ``t0`` (review §4.1) — the first stored sample landed at ``state_rec[t0, :]`` instead of ``state_rec[0, :]``. DDE history updates from :code:`BaseBackend._solve_euler` are not replicated here because :class:`DDEHistory.update` calls :code:`y.copy()`, which is not a method on torch tensors. A tensor-native DDE+Euler path would need its own ring buffer; until then DDE simulation on the torch backend should use ``solver='scipy'`` (see :meth:`_solve_scipy_dde` below). """ # preparations for fixed step-size integration idx = 0 steps = int(np.round(T / dt)) store_steps = int(np.round(T / dts)) store_step = int(np.round(dts / dt)) # state_rec is fully overwritten row-by-row; torch.empty skips the # zero-fill (safe now that the idx-shadow bug is fixed). state_rec = torch.empty((store_steps, y.shape[0]) if y.shape else (store_steps, 1), dtype=y.dtype) # solve ivp via forward Euler. Storage cadence is driven by the # iteration counter `i` rather than the wall-clock step number — see # BaseBackend._solve_euler for the rationale (review §4.2). t0_int = int(t0) for i in range(steps): if i % store_step == 0: state_rec[idx, :] = y idx += 1 step = i + t0_int rhs = func(step, y, *args) y += dt * rhs return state_rec.numpy()