Source code for pyrates.backend.julia.julia_backend

# -*- coding: utf-8 -*-
#
#
# PyRates software framework for flexible implementation of neural
# network model_templates and simulations. See also:
# https://github.com/pyrates-neuroscience/PyRates
#
# Copyright (C) 2017-2018 the original authors (Richard Gast and
# Daniel Rose), the Max-Planck-Institute for Human Cognitive Brain
# Sciences ("MPI CBS") and contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>
#
# CITATION:
#
# Richard Gast and Daniel Rose et. al. in preparation

"""Wraps Julia such that its low-level functions can be used by PyRates to create and simulate a compute graph.
"""

# pyrates internal _imports
from ..base import BaseBackend
from .._one_based import OneBasedCodegenMixin
from ..computegraph import ComputeVar
from .julia_funcs import julia_funcs

# external _imports
from typing import Optional, Dict, List, Callable, Iterable, Union, Tuple
import numpy as np

# meta infos
__author__ = "Richard Gast"
__status__ = "development"


# backend classes
#################


[docs]class JuliaBackend(OneBasedCodegenMixin, BaseBackend): # Adds two Julia-native paths on top of the inherited euler/heun/scipy: # 'julia_ode' → DifferentialEquations.jl ODEProblem # 'julia_dde' → DifferentialEquations.jl DDEProblem (MethodOfSteps) # ``_validate_solver`` is overridden below to accept arbitrary # "julia_*" identifiers (any string containing 'julia') so users can pick # specific algorithms via the existing ``method=`` keyword. SUPPORTED_SOLVERS = ('euler', 'heun', 'scipy', 'julia_ode', 'julia_dde') def __init__(self, ops: Optional[Dict[str, str]] = None, imports: Optional[List[str]] = None, **kwargs ) -> None: """Instantiates Julia backend. """ # add user-provided operations to function dict julia_ops = julia_funcs.copy() if ops: julia_ops.update(ops) # set default float precision to float64 kwargs["float_precision"] = "float64" # call parent method super().__init__(ops=julia_ops, imports=imports, file_ending='.jl', start_idx=1, add_hist_arg=True, **kwargs) # define julia-specific imports self._imports.pop(0) self._imports.append("using LinearAlgebra") # set up pyjulia from julia.api import Julia jl = Julia(runtime=kwargs.pop('julia_path'), compiled_modules=False) from julia import Main self._jl = Main self._no_vectorization = ["*(", "interp(", "hist("] self._fcall = None self._is_dde = False self._lags = [] def _format_assignment(self, lhs: str, rhs: str, indexed: bool) -> str: """Add Julia's broadcast prefix ``@.`` for indexed / shaped assignments, except for a few RHS expressions where broadcasting would be wrong (matrix multiplication, ``interp``, ``hist``). """ if indexed and not any(rhs.startswith(expr) for expr in self._no_vectorization): rhs = f"@. {rhs}" return f"{lhs} = {rhs}"
[docs] def add_var_hist(self, lhs: str, delay: Union[ComputeVar, float], state_idx: Union[int, tuple], **kwargs): self._is_dde = True if isinstance(delay, float): self._lags.append(delay) elif type(delay) is ComputeVar and np.ndim(delay.value) == 0: self._lags.append(float(delay.value)) idx = self._process_idx(state_idx) d = self._process_delay(delay) self.add_code_line(f"{lhs} = hist((), t-{d}; idxs={idx})")
[docs] def get_hist_func(self, y: np.ndarray, t0: float = 0.0): # Julia DDE pre-history: constant initial condition for t <= t0. # During integration DifferentialEquations.jl provides its own # interpolant as `h`; this function is only called for t < t0. self._jl.eval(f"y_init = {y.tolist()}") hist = """ function hist(p, t; idxs=nothing) return idxs == nothing ? y_init : y_init[idxs] end """ self._jl.eval(hist) return self._jl.hist
[docs] def generate_func_tail(self, rhs_var: str = 'dy'): self.add_code_line(f"return {rhs_var}") self.remove_indent() self.add_code_line("end")
[docs] def generate_func(self, func_name: str, to_file: bool = True, **kwargs): self._fcall = func_name kwargs.pop('julia_ode', None) # legacy flags, now auto-generated kwargs.pop('julia_dde', None) # append a DifferentialEquations.jl-compatible wrapper function. # For DDEs the wrapper exposes `h` (the solver interpolant) so the # RHS can call hist(p, t-d; idxs=i). For plain ODEs it is omitted. self.add_linebreak() if self._is_dde: self.add_code_line(f"function {func_name}_julia(dy, y, h, p, t)") self.add_indent() self.add_code_line(f"return {func_name}(t, y, h, dy, p...)") self.remove_indent() self.add_code_line("end") else: self.add_code_line(f"function {func_name}_julia(dy, y, p, t)") self.add_indent() self.add_code_line(f"return {func_name}(t, y, dy, p...)") self.remove_indent() self.add_code_line("end") func_str = self.generate() if to_file: # save rhs function to file file = f'{self.fdir}/{self._fname}{self._fend}' if self.fdir else f"{self._fname}{self._fend}" with open(file, 'w') as f: f.writelines(func_str) # import all functions from file into Julia Main self._jl.include(file) else: # execute the function string directly self._jl.eval(func_str) # return the main RHS function object (wrapper is accessed via _fcall in _solve) rhs_eval = getattr(self._jl, func_name) return self._apply_decorator(rhs_eval, **kwargs)
def _validate_solver(self, solver: str) -> None: # Accept any "julia*" string (the dispatch below uses ``'julia' in # solver``); fall through to the base check otherwise so plain # ``euler`` / ``heun`` / ``scipy`` keep working. if 'julia' in solver: return super()._validate_solver(solver) def _solve(self, solver: str, func: Callable, args: tuple, T: float, dt: float, dts: float, y0: np.ndarray, t0: np.ndarray, times: np.ndarray, **kwargs) -> np.ndarray: self._validate_solver(solver) if 'julia' in solver: # solve via DifferentialEquations.jl self._jl.eval('using DifferentialEquations') # retrieve the pre-generated DifferentialEquations.jl-compatible wrapper wrapper = getattr(self._jl, f'{self._fcall}_julia') method = kwargs.pop('method', 'Tsit5') atol, rtol = kwargs.pop('atol', 1e-6), kwargs.pop('rtol', 1e-3) # auto-detect DDE: _is_dde is set when add_var_hist was called, # or the user can still explicitly request it with solver='julia_dde' is_dde = self._is_dde or 'dde' in solver if is_dde: # args layout: (hist_julia_func, dy_zeros, p1, p2, ...) # hist_julia_func → h argument of DDEProblem # dy_zeros → internal buffer, not passed to Julia # p1, p2, ... → parameter tuple for DDEProblem hist_func = args[0] params = args[2:] # pass constant_lags for discontinuity tracking when available lags = kwargs.pop('constant_lags', self._lags if self._lags else None) solve_kwargs = dict(saveat=times, atol=atol, rtol=rtol) if lags: model = self._jl.DDEProblem(wrapper, y0, hist_func, [float(t0), T], params, constant_lags=list(lags)) else: model = self._jl.DDEProblem(wrapper, y0, hist_func, [float(t0), T], params) jl_solver = self._jl.MethodOfSteps(getattr(self._jl, method)()) results = self._jl.solve(model, jl_solver, **solve_kwargs) else: # args layout: (dy_zeros, p1, p2, ...) # dy_zeros is skipped; p1, p2, ... form the parameter tuple params = args[1:] model = self._jl.ODEProblem(wrapper, y0, [float(t0), T], params) jl_solver = getattr(self._jl, method)() results = self._jl.solve(model, jl_solver, saveat=times, atol=atol, rtol=rtol) results = np.asarray(results).T else: # non-julia solver — fall back to Python/scipy solvers results = super()._solve(solver=solver, func=func, args=args, T=T, dt=dt, dts=dts, y0=y0, t0=t0, times=times, **kwargs) return results def _add_func_call(self, name: str, args: Iterable, return_var: str = 'dy'): self.add_code_line(f"function {name}({','.join(args)})")
# `_process_idx`, `create_index_str`, `expr_to_str`, and `get_var` are # inherited from OneBasedCodegenMixin (also shared with MatlabBackend).