# -*- coding: utf-8 -*-
#
#
# PyRates software framework for flexible implementation of neural
# network models and simulations. See also:
# https://github.com/pyrates-neuroscience/PyRates
#
# Copyright (C) 2017-2018 the original authors (Richard Gast and
# Daniel Rose), the Max-Planck-Institute for Human Cognitive Brain
# Sciences ("MPI CBS") and contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>
#
# CITATION:
#
# Richard Gast and Daniel Rose et al. in preparation
"""Wraps Julia such that its low-level functions can be used by PyRates to create and simulate a compute graph.
"""
# pyrates internal imports
import sys
from ..base import BaseBackend
from ..computegraph import ComputeVar
from .julia_funcs import julia_funcs
# external imports
from typing import Optional, Dict, List, Callable, Iterable, Union, Tuple
import numpy as np
# meta infos
__author__ = "Richard Gast"
__status__ = "development"
# backend classes
#################
class JuliaBackend(BaseBackend):
    """Backend that emits Julia source code and evaluates it through pyjulia.

    The backend is configured for 1-based indexing (`start_idx=1`) and `.jl` files. Generated code is loaded into
    the `Main` module of an embedded Julia session, which is stored on `self._jl`.
    """

    def __init__(self,
                 ops: Optional[Dict[str, str]] = None,
                 imports: Optional[List[str]] = None,
                 **kwargs
                 ) -> None:
        """Instantiates Julia backend.

        Parameters
        ----------
        ops
            Additional operations. They are merged into a copy of the default `julia_funcs` and override entries
            with the same key.
        imports
            Import statements, forwarded to the parent constructor.
        kwargs
            Additional keyword arguments forwarded to the parent constructor. `float_precision` is always
            overwritten with "float64". `julia_path` is required and handed to pyjulia as `runtime`.

        Raises
        ------
        KeyError
            If `julia_path` is not among the keyword arguments.
        """

        # add user-provided operations to function dict
        julia_ops = julia_funcs.copy()
        if ops:
            julia_ops.update(ops)

        # set default float precision to float64
        kwargs["float_precision"] = "float64"

        # call parent method
        super().__init__(ops=julia_ops, imports=imports, file_ending='.jl', start_idx=1, add_hist_arg=True, **kwargs)

        # define julia-specific imports
        # NOTE(review): the first import set up by the parent is dropped here, presumably because it is a
        # Python-specific default - confirm against BaseBackend.
        self._imports.pop(0)
        self._imports.append("using LinearAlgebra")

        # set up pyjulia; the runtime is initialized before `julia.Main` is imported
        from julia.api import Julia
        jl = Julia(runtime=kwargs.pop('julia_path'), compiled_modules=False)
        from julia import Main
        self._jl = Main

        # right-hand sides starting with one of these prefixes are not wrapped in Julia's `@.` macro
        self._no_vectorization = ["*(", "interp("]

        # name of the most recently generated function, used by the Julia solver wrappers in `_solve`
        self._fcall = None

    def get_var(self, v: ComputeVar):
        """Returns the backend value of a variable, converting values without extent to Python scalars.

        Parameters
        ----------
        v
            Variable that is resolved via the parent's `get_var`.

        Returns
        -------
        The value returned by the parent method if the entries of its shape sum to a positive number, else that
        value converted to `float`, `complex` or `int`, depending on its dtype name.
        """
        v = super().get_var(v)

        # a positive shape sum means that `v` holds array data, which is passed on unchanged
        if sum(v.shape) > 0:
            return v

        dtype = v.dtype.name
        if 'float' in dtype:
            return float(v)
        if 'complex' in dtype:
            return complex(np.real(v), np.imag(v))
        return int(v)

    def add_var_update(self, lhs: ComputeVar, rhs: str, lhs_idx: Optional[str] = None,
                       rhs_shape: Optional[tuple] = ()):
        """Adds a variable update to the code and vectorizes it with `@.` where applicable.

        The line is first generated by the parent method. If `rhs_shape` or `lhs_idx` is truthy, that line is
        removed again and re-added with `@. ` prepended to its right-hand side, unless the right-hand side starts
        with one of the prefixes in `self._no_vectorization`.

        Parameters
        ----------
        lhs
            Variable that is updated.
        rhs
            Right-hand side expression of the update.
        lhs_idx
            Optional index into the left-hand side variable.
        rhs_shape
            Shape of the right-hand side expression.
        """
        super().add_var_update(lhs=lhs, rhs=rhs, lhs_idx=lhs_idx, rhs_shape=rhs_shape)
        if rhs_shape or lhs_idx:
            line = self.code.pop()
            # split at the first assignment only, such that ' = ' inside the expression cannot break unpacking
            target, expr = line.split(' = ', 1)
            if not expr.startswith(tuple(self._no_vectorization)):
                expr = f"@. {expr}"
            self.add_code_line(f"{target} = {expr}")

    def add_var_hist(self, lhs: str, delay: Union[ComputeVar, float], state_idx: Union[int, tuple], **kwargs):
        """Adds a code line that assigns a delayed state, obtained from a call to `hist`, to `lhs`.

        Parameters
        ----------
        lhs
            Name of the variable that receives the delayed state.
        delay
            Delay that is subtracted from the time `t` in the generated call.
        state_idx
            Index into the state vector, passed to `hist` via its `idxs` keyword.
        kwargs
            Not used by this method.
        """
        idx = self._process_idx(state_idx)
        d = self._process_delay(delay)
        self.add_code_line(f"{lhs} = hist((), t-{d}; idxs={idx})")

    def get_hist_func(self, y: np.ndarray):
        """Defines a Julia history function that returns the given state irrespective of the requested time.

        Parameters
        ----------
        y
            State that is stored as the global `y_init` in the Julia session.

        Returns
        -------
        Handle to the Julia function `hist(p, t; idxs=nothing)`, which returns `y_init` if `idxs` is `nothing` and
        `y_init[idxs]` otherwise.
        """
        self._jl.eval(f"y_init = {y.tolist()}")
        hist = """
        function hist(p, t; idxs=nothing)
            return idxs == nothing ? y_init : y_init[idxs]
        end
        """
        self._jl.eval(hist)
        return self._jl.hist

    def create_index_str(self, idx: Union[str, int, tuple], separator: str = ',', apply: bool = True,
                         **kwargs) -> Tuple[str, dict]:
        """Creates an index string via the parent method, using 0-based indexing if `apply` is false.

        Parameters
        ----------
        idx
            Index to be processed.
        separator
            Separator, forwarded to the parent method.
        apply
            If false, `self._start_idx` is set to 0 for the duration of the parent call and to 1 afterwards.
        kwargs
            Additional keyword arguments forwarded to the parent method.

        Returns
        -------
        Tuple[str, dict]
            The index string and index dictionary returned by the parent method.
        """
        if apply:
            return super().create_index_str(idx, separator, apply, **kwargs)

        self._start_idx = 0
        try:
            return super().create_index_str(idx, separator, apply, **kwargs)
        finally:
            # restore 1-based indexing even if the parent call raises
            self._start_idx = 1

    def generate_func_tail(self, rhs_var: str = 'dy'):
        """Closes the function definition: returns `rhs_var`, removes one indent level and adds `end`.

        Parameters
        ----------
        rhs_var
            Name of the variable returned by the generated function.
        """
        self.add_code_line(f"return {rhs_var}")
        self.remove_indent()
        self.add_code_line("end")

    def generate_func(self, func_name: str, to_file: bool = True, **kwargs):
        """Generates the Julia function from the collected code and loads it into the Julia session.

        Parameters
        ----------
        func_name
            Name of the generated function. It is stored in `self._fcall`.
        to_file
            If true, the code is written to a file which is then included by Julia. Otherwise, the code string is
            evaluated directly.
        kwargs
            Optional keyword arguments:
            `julia_ode`: if truthy, a wrapper `{func_name}_julia(dy, y, p, t)` is appended to the code.
            `julia_dde`: if truthy, a wrapper `{func_name}_julia(dy, y, h, p, t)` is appended to the code.
            `decorator`: callable that is applied to the loaded function.
            `decorator_kwargs`: keyword arguments passed to the decorator.

        Returns
        -------
        The result of including/evaluating the code in Julia, passed through the decorator if one was provided.
        """
        self._fcall = func_name

        # add wrappers with the argument order used in the `*_julia` signatures below
        if kwargs.pop('julia_ode', False):
            self.add_linebreak()
            self.add_code_line(f"function {func_name}_julia(dy, y, p, t)")
            self.add_indent()
            self.add_code_line(f"return {func_name}(t, y, dy, p...)")
            self.remove_indent()
            self.add_code_line("end")
        if kwargs.pop('julia_dde', False):
            self.add_linebreak()
            self.add_code_line(f"function {func_name}_julia(dy, y, h, p, t)")
            self.add_indent()
            self.add_code_line(f"return {func_name}(t, y, h, dy, p...)")
            self.remove_indent()
            self.add_code_line("end")

        # generate the current function string via the code generator
        func_str = self.generate()

        if to_file:
            # save rhs function to file (the context manager closes the file)
            path = f'{self.fdir}/{self._fname}{self._fend}' if self.fdir else f"{self._fname}{self._fend}"
            with open(path, 'w', encoding='utf-8') as f:
                f.writelines(func_str)
            # import function from file
            rhs_eval = self._jl.include(path)
        else:
            # just execute the function string, without writing it to file
            rhs_eval = self._jl.eval(func_str)

        # apply function decorator
        decorator = kwargs.pop('decorator', None)
        if decorator:
            decorator_kwargs = kwargs.pop('decorator_kwargs', dict())
            rhs_eval = decorator(rhs_eval, **decorator_kwargs)

        return rhs_eval

    def _solve(self, solver: str, func: Callable, args: tuple, T: float, dt: float, dts: float, y0: np.ndarray,
               t0: np.ndarray, times: np.ndarray, **kwargs) -> np.ndarray:
        """Solves the initial value problem, via DifferentialEquations.jl if `solver` contains 'julia'.

        Parameters
        ----------
        solver
            Name of the solver. If it contains 'julia', DifferentialEquations.jl is used and a delay differential
            equation problem is set up if it additionally contains 'dde'. Otherwise, the parent method is called.
        func
            Right-hand side function. Only used by the parent method; the Julia branch calls the function named
            `self._fcall` instead.
        args
            Additional function arguments. In the DDE case, `args[0]` is passed as history function and `args[2:]`
            as parameters. In the ODE case, `args[1:]` are passed as parameters.
            NOTE(review): the skipped entry is presumably the pre-allocated `dy` buffer - confirm against caller.
        T
            End of the integration interval `[0.0, T]`.
        dt, dts, t0
            Only forwarded to the parent method.
        y0
            Initial state.
        times
            Time points at which the solution is saved (`saveat`).
        kwargs
            For the Julia branch: `method` (name of the solver algorithm, default 'Tsit5'), `atol` (default 1e-6)
            and `rtol` (default 1e-3). For other solvers, all keyword arguments go to the parent method.

        Returns
        -------
        np.ndarray
            Transposed array of the Julia solution, or the result of the parent method.
        """
        if 'julia' not in solver:
            # non-julia solver
            return super()._solve(solver=solver, func=func, args=args, T=T, dt=dt, dts=dts, y0=y0, t0=t0,
                                  times=times, **kwargs)

        # solve via DifferentialEquations.jl
        self._jl.eval('using DifferentialEquations')
        alg = getattr(self._jl, kwargs.pop('method', 'Tsit5'))
        atol, rtol = kwargs.pop('atol', 1e-6), kwargs.pop('rtol', 1e-3)

        if 'dde' in solver:
            # define wrapper function with the argument order expected by DDEProblem
            wrapper = f"""
            function julia_dderun(du,u,h,p,t)
                return {self._fcall}(t,u,h,du,p...)
            end
            """
            self._jl.eval(wrapper)
            model = self._jl.DDEProblem(self._jl.julia_dderun, y0, args[0], [0.0, T], args[2:])
            alg_instance = self._jl.MethodOfSteps(alg())
        else:
            # define wrapper function with the argument order expected by ODEProblem
            wrapper = f"""
            function julia_oderun(du,u,p,t)
                return {self._fcall}(t,u,du,p...)
            end
            """
            self._jl.eval(wrapper)
            model = self._jl.ODEProblem(self._jl.julia_oderun, y0, [0.0, T], args[1:])
            alg_instance = alg()

        # DifferentialEquations.jl names its tolerance options `abstol` and `reltol`
        results = self._jl.solve(model, alg_instance, saveat=times, abstol=atol, reltol=rtol)
        return np.asarray(results).T

    def _add_func_call(self, name: str, args: Iterable, return_var: str = 'dy'):
        """Adds the Julia function header `function name(arg1,arg2,...)` to the code.

        Parameters
        ----------
        name
            Name of the function.
        args
            Names of the function arguments.
        return_var
            Not used by this method.
        """
        self.add_code_line(f"function {name}({','.join(args)})")

    def _process_idx(self, idx: Union[Tuple[int, int], int, str, ComputeVar], **kwargs) -> str:
        """Processes an index via the parent method, temporarily using 0-based indexing for two special cases.

        Parameters
        ----------
        idx
            Index to be processed. Strings of the form 'a:b' are split, both parts are processed with
            `self._start_idx = 0` and converted to `int`, and the resulting tuple is processed again. A
            `ComputeVar` named 't' whose value is at least `self._start_idx` is processed with
            `self._start_idx = 0`. Everything else goes directly to the parent method.
        kwargs
            Additional keyword arguments forwarded to the parent method.

        Returns
        -------
        str
            The processed index.
        """
        if type(idx) is str and idx != ':' and ':' in idx:
            start, stop = idx.split(':')
            self._start_idx = 0
            try:
                start = int(self._process_idx(start))
                stop = int(self._process_idx(stop))
            finally:
                # restore 1-based indexing even if processing of a slice boundary fails
                self._start_idx = 1
            return self._process_idx((start, stop))

        if type(idx) is ComputeVar and idx.name == "t" and idx.value >= self._start_idx:
            self._start_idx = 0
            try:
                return super()._process_idx(idx=idx, **kwargs)
            finally:
                self._start_idx = 1

        return super()._process_idx(idx=idx, **kwargs)

    @staticmethod
    def expr_to_str(expr: str, args: tuple):
        """Replaces every power operator `**` in the expression with `^`.

        Parameters
        ----------
        expr
            Expression string.
        args
            Not used by this method.

        Returns
        -------
        str
            The expression with all occurrences of `**` replaced by `^`.
        """
        # a single pass suffices: replacing `**` by `^` cannot create a new `**`
        return expr.replace('**', '^')