# -*- coding: utf-8 -*-
# PyRates software framework for flexible implementation of neural
# network model_templates and simulations. See also:
# https://github.com/pyrates-neuroscience/PyRates
# Copyright (C) 2017-2018 the original authors (Richard Gast and
# Daniel Rose), the Max-Planck-Institute for Human Cognitive Brain
# Sciences ("MPI CBS") and contributors
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>
# Richard Gast and Daniel Rose et. al. in preparation
"""This module provides the backend class that should be used to set up any backend in pyrates.
# external imports
from typing import Any, Callable, Union, Iterable, Optional
from networkx import MultiDiGraph
from sympy import Symbol, Expr, Function
import numpy as np
# meta infos
__author__ = "Richard Gast"
__status__ = "development"
# compute graph classes #
# numpy-based node class
class ComputeNode:
    """Base class for adding variables to the compute graph.

    Creates a numpy array with additional attributes for variable identification/retrieval from the graph.
    Should be used as parent class for custom variable classes.

    Parameters
    ----------
    name
        Full name of the variable in the original NetworkGraph (including the node and operator it belongs to).
    symbol
        Symbolic representation of the node.
    dtype
        Data-type of the variable. For valid data-types, check the documentation of the backend in use.
    shape
        Shape of the variable.
    def_shape
        Fallback shape, used if `shape` is None or if the entries of `shape` sum up to less than 2.
    """

    __slots__ = ["name", "symbol", "dtype", "shape", "_value"]

    def __init__(self, name: str, symbol: "Union[Symbol, Expr, Function]", dtype: Optional[str] = None,
                 shape: Optional[tuple] = None, def_shape: Optional[tuple] = None):
        """Instantiates a basic node of a ComputeGraph instance."""
        self.name = name
        self.symbol = symbol
        self.shape = self._get_shape(shape, def_shape)
        self._value = np.zeros(self.shape)
        self.dtype = dtype

    def reshape(self, shape: tuple, **kwargs):
        """Reshape the stored value in place and return the node itself."""
        self._value = self.value.reshape(shape, **kwargs)
        # store the shape that numpy actually produced (resolves entries such as -1)
        self.shape = tuple(self._value.shape)
        return self

    def squeeze(self, axis=None):
        """Remove singleton dimensions from the stored value and return the node itself."""
        self._value = self.value.squeeze(axis=axis)
        self.shape = self._value.shape
        return self

    def set_value(self, v: Union[float, np.ndarray]):
        """Replace the stored value by `v` (cast to the node's dtype) and update the shape accordingly."""
        self._value = np.asarray(v, dtype=self.dtype)
        # read the shape from the converted array, since scalars and lists have no `shape` attribute
        self.shape = tuple(self._value.shape)

    @property
    def value(self):
        """Returns current value of BaseVar."""
        return self._value

    @property
    def is_constant(self):
        """Whether the node is a constant. Has to be implemented by each child class."""
        raise NotImplementedError("This method has to be defined by each child class.")

    @property
    def is_float(self):
        """True if the dtype string of the node contains 'float'."""
        return "float" in (self.dtype or "")

    @property
    def is_complex(self):
        """True if the dtype string of the node contains 'complex'."""
        return "complex" in (self.dtype or "")

    def _is_equal_to(self, v):
        """Return True if `v` provides all slots of this node with equal content."""
        for attr in self.__slots__:
            if not hasattr(v, attr):
                return False
            mine, theirs = getattr(self, attr), getattr(v, attr)
            if isinstance(mine, np.ndarray) or isinstance(theirs, np.ndarray):
                # element-wise `!=` on arrays has no unambiguous truth value
                if not np.array_equal(mine, theirs):
                    return False
            elif mine != theirs:
                return False
        return True

    def _get_value(self, value: Optional[Union[list, np.ndarray]] = None, dtype: Optional[str] = None,
                   shape: Optional[tuple] = None):
        """Defines initial value of variable.

        Parameters
        ----------
        value
            Initial value. May be None, a scalar, a list or an array-like object with a `shape` attribute.
        dtype
            Data-type that the returned array is cast to.
        shape
            Requested shape of the returned array.

        Returns
        -------
        np.ndarray
            Array holding the initial value.
        """
        # case I: create new array from shape and dtype
        if value is None:
            return np.zeros(shape=shape, dtype=dtype)

        # case II: transform values into an array
        if not hasattr(value, 'shape'):
            if not isinstance(value, list):
                return np.zeros(shape=shape, dtype=dtype) + value
            value = np.asarray(value, dtype=dtype)

        # case IV: just ensure the correct data type of the value array
        value = np.asarray(value, dtype=dtype)
        if shape is None:
            return value

        # case III: match given shape with the shape of the given value array
        if value.shape == shape:
            return value
        if sum(shape) < sum(value.shape):
            return value.squeeze()
        # insert a new axis wherever the requested shape has a singleton dimension
        idx = tuple(None if s == 1 else slice(None) for s in shape)
        return value[idx]

    def set_dtype(self, dtype: str = None):
        """Set the dtype of the node.

        If `dtype` is None and no dtype has been set so far, the dtype is inferred from the stored value
        ('float', 'complex' or 'int'). An already existing dtype is left untouched in that case.
        """
        if dtype is None:
            if not self.dtype:
                value_dtype = str(self.value.dtype)
                if 'float' in value_dtype:
                    self.dtype = 'float'
                elif 'complex' in value_dtype:
                    self.dtype = 'complex'
                else:
                    self.dtype = 'int'
        else:
            self.dtype = dtype

    @staticmethod
    def _get_shape(s: Union[tuple, None], s_def: tuple):
        """Return `s`, or the default `s_def` if `s` is None or its entries sum up to less than 2."""
        if s is None or sum(s) < 2:
            return s_def
        return s

    def __deepcopy__(self, memodict: dict):
        """Return a new ComputeNode with the same attributes and an independent copy of the value."""
        node = ComputeNode(name=self.name, symbol=self.symbol, dtype=self.dtype, shape=self.shape,
                           def_shape=self.shape)
        # copy the value of *this* node; the freshly created node only holds zeros
        node._value = np.array(self._value, copy=True)
        return node

    def __str__(self):
        return self.name

    def __hash__(self):
        return hash(str(self))
class ComputeVar(ComputeNode):
    """Class for variables and vector-valued constants in the ComputeGraph.

    Parameters
    ----------
    name
        Name of the variable.
    symbol
        Symbolic representation of the variable.
    vtype
        Type of the variable. The value 'constant' marks the variable as a constant.
    dtype
        Data-type of the variable.
    shape
        Shape of the variable.
    value
        Initial value of the variable. If None, the variable is initialized with zeros.
    def_shape
        Fallback shape that is passed on to the parent class.
    """

    __slots__ = ComputeNode.__slots__ + ["vtype"]

    def __init__(self, name: str, symbol: Union[Symbol, Expr, Function], vtype: str, dtype: Optional[str] = None,
                 shape: Optional[tuple] = None, value: Optional[Union[list, np.ndarray]] = None,
                 def_shape: Optional[tuple] = None):

        # set attributes
        super().__init__(name=name, symbol=symbol, dtype=dtype, shape=shape, def_shape=def_shape)
        self.vtype = vtype

        # adjust variable value
        self.set_value(self._get_value(value=value, shape=self.shape, dtype=self.dtype))

    @property
    def is_constant(self):
        """True if the variable type is 'constant'."""
        return self.vtype == 'constant'
class ComputeOp(ComputeNode):
    """Class for ComputeGraph nodes that represent mathematical operations.

    Parameters
    ----------
    name
        Name of the operation node.
    symbol
        Symbolic representation of the result of the operation.
    func
        Callable that evaluates the operation on the values of the input nodes.
    expr
        Symbolic expression of the operation.
    dtype
        Data-type of the result of the operation.
    shape
        Shape of the result of the operation.
    """

    __slots__ = ComputeNode.__slots__ + ["func", "expr"]

    def __init__(self, name: str, symbol: Union[Symbol, Expr, Function], func: Callable, expr: Expr,
                 dtype: Optional[str] = None, shape: Optional[tuple] = None):

        # set attributes
        super().__init__(name=name, symbol=symbol, dtype=dtype, shape=shape)
        self.func = func
        self.expr = expr

    @property
    def is_constant(self):
        """Operations are never treated as constants."""
        return False
# networkx-based graph class
class ComputeGraph(MultiDiGraph):
    """Creates a compute graph where nodes are all constants and variables of the network and edges are the mathematical
    operations linking those variables/constants together to form equations.

    Parameters
    ----------
    backend
        Name of the backend. Recognized names are 'tensorflow', 'torch', 'fortran', 'julia' and 'matlab';
        any other name selects the base backend.
    kwargs
        Keyword arguments that are passed on to the constructor of the backend.
    """

    def __init__(self, backend: str, **kwargs):

        # initialize the underlying networkx graph before any node is added
        super().__init__()

        # choose a backend (each backend module is only imported in the branch that requests it)
        if backend == 'tensorflow':
            from pyrates.backend.tensorflow import TensorflowBackend
            backend = TensorflowBackend
        elif backend == 'torch':
            from pyrates.backend.torch import TorchBackend
            backend = TorchBackend
        elif backend == 'fortran':
            from pyrates.backend.fortran import FortranBackend
            backend = FortranBackend
        elif backend == 'julia':
            from pyrates.backend.julia import JuliaBackend
            backend = JuliaBackend
        elif backend == 'matlab':
            from pyrates.backend.matlab import MatlabBackend
            backend = MatlabBackend
        else:
            # fallback: without this branch the base backend would override every selection above
            from pyrates.backend.base import BaseBackend
            backend = BaseBackend

        # backend-related attributes
        self.backend = backend(**kwargs)
        self.var_updates = {'DEs': dict(), 'non-DEs': dict()}
        self._eq_nodes = []
        self._state_var_indices = dict()
        self._state_var_hist = dict()
        self._node_names = {}
@property
def state_vars(self):
    """List of the names of all variables that are updated via a differential equation."""
    return list(self.var_updates['DEs'])
def add_var(self, label: str, value: Any, vtype: str, **kwargs):
    """Add a variable node to the graph.

    Parameters
    ----------
    label
        Requested name of the variable. A unique version of it is used as the key of the graph node.
    value
        Initial value of the variable.
    vtype
        Type of the variable, passed on to `ComputeVar`.
    kwargs
        Additional keyword arguments for `ComputeVar`.

    Returns
    -------
    tuple
        Unique label of the new graph node and the `ComputeVar` instance stored on it.
    """
    unique_label = self._generate_unique_label(label)
    var = ComputeVar(name=unique_label, symbol=Symbol(unique_label), value=value, vtype=vtype, **kwargs)
    super().add_node(unique_label, node=var)
    return unique_label, self.nodes[unique_label]['node']
def add_op(self, inputs: Union[list, tuple], label: str, expr: Expr, func: Callable, **kwargs):
    """Add an operation node to the graph and connect it to its inputs.

    Parameters
    ----------
    inputs
        Labels of the graph nodes that serve as inputs to the operation.
    label
        Requested name of the operation node. A unique version of it is used as the key of the graph node.
    expr
        Symbolic expression of the operation.
    func
        Callable that evaluates the operation.
    kwargs
        Additional keyword arguments for `ComputeOp`.

    Returns
    -------
    tuple
        Unique label of the new graph node and the `ComputeOp` instance stored on it.
    """
    # add target node that contains result of operation
    unique_label = self._generate_unique_label(label)
    op = ComputeOp(name=unique_label, symbol=Symbol(unique_label), func=func, expr=expr, **kwargs)
    super().add_node(unique_label, node=op)

    # add edges from source nodes to target node; the edge key stores the position of the input
    for i, v in enumerate(inputs):
        super().add_edge(v, unique_label, key=i)

    return unique_label, self.nodes[unique_label]['node']
def add_var_update(self, var: str, update: str, differential_equation: bool = False):
    """Register `update` as the right-hand side of the equation that updates `var`.

    Parameters
    ----------
    var
        Label of the graph node on the left-hand side of the equation.
    update
        Label of the graph node that holds the right-hand side of the equation.
    differential_equation
        If True, the update is stored as a differential equation, else as a non-differential equation.
    """
    # store mapping between left-hand side variable and right-hand side update
    if differential_equation:
        self.var_updates['DEs'][var] = update
    else:
        self.var_updates['non-DEs'][var] = update

    # remember var and update node to ensure that they are not pruned during compilation
    self._eq_nodes.extend([var, update])
def get_var(self, var: str, from_backend: bool = False):
    """Return the node object stored under the graph key `var`.

    If `from_backend` is True, the node is passed through `self.backend.get_var` first.
    """
    v = self.nodes[var]['node']
    if from_backend:
        return self.backend.get_var(v)
    return v
def get_op(self, op: str, **kwargs) -> dict:
    """Return the backend definition of the operation `op` (delegates to `self.backend.get_op`)."""
    return self.backend.get_op(op, **kwargs)
# Returns the evaluated right-hand sides of all differential equations (via `eval_nodes`).
# NOTE(review): the body of the loop over the non-DE updates is missing from this listing; presumably
# each non-DE node is evaluated before the DE nodes - confirm against the original source.
[docs] def eval_graph(self):
for n in self.var_updates['non-DEs'].values():
return self.eval_nodes(self.var_updates['DEs'].values())
def eval_nodes(self, nodes: Iterable):
    """Evaluate each of the given graph nodes and return the results as a list."""
    return [self.eval_node(n) for n in nodes]
def eval_node(self, n):
    """Evaluate the graph node `n`.

    Operation nodes are evaluated recursively: all predecessors are evaluated first and their results are
    passed to the function of the operation. All other nodes simply return their stored value.
    """
    node = self.get_var(n)
    if isinstance(node, ComputeOp):
        # inputs are only needed (and thus only evaluated) for operation nodes
        inputs = tuple(self.eval_node(inp) for inp in self.predecessors(n))
        return node.func(*inputs)
    return node.value
# Evaluates node `n` from the nodes that feed into it.
# NOTE(review): statements are missing from this listing - `inputs` is never filled inside the loop and
# the `if inputs:` branch has no body. Restore them from the original source before using this method.
[docs] def eval_subgraph(self, n):
inputs = []
# collect the predecessors up front (a list, not a live graph view)
input_nodes = [node for node in self.predecessors(n)]
for inp in input_nodes:
node = self.get_var(n)
if inputs:
return node.value
# Iterates over all predecessors of node `n`.
# NOTE(review): the loop body (and any statement following the loop) is missing from this listing, so the
# actual removal logic cannot be seen here - restore it from the original source.
[docs] def remove_subgraph(self, n):
for inp in self.predecessors(n):
# Processes all nodes without outgoing edges and returns the graph instance itself.
# NOTE(review): the bodies of the inner `for` loop and of both `if` statements, as well as the statement
# that follows the "remove unconnected nodes" comment, are missing from this listing. The code is left
# untouched; restore the missing statements from the original source.
[docs] def compile(self):
# evaluate constant-based operations
# nodes with an out-degree of zero are the end points of the graph hierarchy
out_nodes = [node for node, out_degree in self.out_degree if out_degree == 0]
for node in out_nodes:
# process inputs of node
for inp in self.predecessors(node):
if self.get_var(inp).is_constant:
# evaluate node if all its inputs are constants
# nodes registered in `self._eq_nodes` (left- and right-hand sides of equations) are excluded
if all([self.get_var(inp).is_constant for inp in self.predecessors(node)]) and node not in self._eq_nodes:
# remove unconnected nodes and constants from graph
return self
# Translates the compute graph into a single function that is generated by the backend.
# Returns a 4-tuple: the generated function, the tuple of its arguments, the tuple of the argument names
# and a copy of the mapping between state variables and their indices in the state vector.
# NOTE(review): several statements of this method are missing from this listing (flagged inline below),
# so the block is not executable as shown. The code is intentionally left untouched.
[docs] def to_func(self, func_name: str, to_file: bool = True, dt_adapt: bool = True, **kwargs):
# finalize compute graph
# NOTE(review): no statement follows the comment above in this listing.
# create state variable vector and state variable update vector
variables = []
idx = 0
for var, update in self.var_updates['DEs'].items():
# extract left-hand side nodes from graph
lhs, rhs = self._process_var_update(var, update)
# store information of the original, non-vectorized state variable
# NOTE(review): `variables` is never appended to in this listing, and an `else:` seems to be missing
# between the tuple-index branch and the scalar-index branch below.
vshape = sum(lhs.shape)
if vshape >= 1:
self._state_var_indices[var] = (idx, idx+vshape)
idx += vshape
self._state_var_indices[var] = idx
idx += 1
# add collected state variables to the backend
# NOTE(review): the `try:` that belongs to the `except ValueError:` below is missing.
state_vec = np.concatenate(variables, axis=0)
except ValueError:
state_vec = np.asarray(variables)
# the state vector is stored as graph variable `y`, its vector field via `_generate_vecfield_var`
dtype = 'complex' if 'complex' in state_vec.dtype.name else 'float'
state_var_key, y = self.add_var(label='y', vtype='state_var', value=state_vec, dtype=dtype)
rhs_var_key = self._generate_vecfield_var(state_vec, dtype)
# NOTE(review): the `try:` that belongs to the `except KeyError:` below is missing. The visible code
# adds a time variable `t` if the graph does not contain one yet.
t = self.get_var('t')
except KeyError:
_, t = self.add_var(label='t', vtype='state_var', value=0.0 if dt_adapt else 0,
dtype='float' if dt_adapt else 'int', shape=())
self.backend.register_vars([t, y])
# create a string containing all computations and variable updates represented by the compute graph
func_args, code_gen = self._to_str()
func_body = code_gen.generate()
# generate function head
add_hist_calls = self._state_var_hist and code_gen.add_hist_arg
# NOTE(review): the call below is not closed in this listing; its remaining arguments are missing.
func_args = code_gen.generate_func_head(func_name=func_name, state_var=state_var_key, return_var=rhs_var_key,
func_args=[self.get_var(arg) for arg in func_args],
# extract state variable histories for delayed interactions
for var, delays in self._state_var_hist.items():
# extract index of variable in state vector
idx = self._state_var_indices[var]
# extract state variable history from backend-specific buffer
# index ranges that span less than two entries are reduced to their first entry
if type(idx) is not int:
if idx[1]-idx[0] < 2:
idx = idx[0]
for delay, v_hist in delays.items(): # type: ComputeVar, str
code_gen.add_var_hist(lhs=v_hist, delay=delay, state_idx=idx, var=var)
# add lines from function body after function head
# NOTE(review): no statement follows the comment above in this listing; `func_body` is unused here.
# generate function tail
self._generate_func_tail(code_gen, rhs_var_key)
# generate the function (and write to file, optionally)
# the leading entries of `func_args` (4 with a history argument, else 3) are not passed on
func_args_tmp = func_args[4:] if add_hist_calls else func_args[3:]
func = code_gen.generate_func(func_name=func_name, to_file=to_file, func_args=func_args_tmp,
state_vars=self.state_vars, **kwargs)
# OPTIONAL: write function arguments (state vectors and constants) to file
# NOTE(review): `constants_file_name` is popped only after `**kwargs` has been passed to
# `generate_func` above, so that call receives it as well - confirm this is intended.
c_fn = kwargs.pop('constants_file_name', None)
if c_fn:
arg_dict = {arg: self.get_var(arg).value for arg in func_args if arg != 'hist'}
if code_gen.lags:
arg_dict['lags'] = list(code_gen.lags.keys())
fn = f'{self.backend.fdir}/{c_fn}' if self.backend.fdir else c_fn
code_gen.to_file(fn, **arg_dict)
# finalize the function arguments
# NOTE(review): the `else:` branches of the two `if` statements below appear to be missing.
fargs = []
for arg in func_args:
if arg == 'hist':
if 'hist' in kwargs:
arg = kwargs.pop('hist')
y_init = np.asarray(state_vec[:])
arg = code_gen.get_hist_func(y_init)
fargs.append(self.get_var(arg, from_backend=True))
return func, tuple(fargs), tuple(func_args), self._state_var_indices.copy()
def run(self, func: Callable, func_args: tuple, T: float, dt: float, dts: Optional[float] = None,
        outputs: Optional[dict] = None, **kwargs) -> dict:
    """Run the function `func` via the backend and collect the requested state variable recordings.

    Parameters
    ----------
    func
        Function to be passed to the `run` method of the backend.
    func_args
        Arguments of `func`.
    T
        Passed on to the backend as `T`.
    dt
        Passed on to the backend as `dt`.
    dts
        Passed on to the backend as `dts`. Defaults to `dt`.
    outputs
        Maps output names to state variable names. If None, all state variables are recorded under
        their own names. Note that a passed dictionary is modified in place and returned.
    kwargs
        Additional keyword arguments for the backend. The entry `solver` (default: 'euler') is extracted
        and passed on explicitly.

    Returns
    -------
    dict
        Recordings of the requested state variables plus the entry 'time'.
    """
    # pre-process outputs: replace state variable names by their indices in the state vector
    if outputs is None:
        outputs = {key: key for key in self.state_vars}
    for key in outputs.copy():
        var = outputs.pop(key)
        outputs[key] = self._state_var_indices[var]

    # handle other arguments
    if dts is None:
        dts = dt
    solver = kwargs.pop('solver', 'euler')

    # call backend method
    results, times = self.backend.run(func=func, func_args=func_args, T=T, dt=dt, dts=dts, solver=solver,
                                      **kwargs)

    # set state variables to final simulated value
    for key in self.state_vars:
        var = self.get_var(key)
        idx = self._state_var_indices[key]
        var.set_value(np.reshape(self._index_state_var(results, idx)[-1], var.shape))

    # reduce state recordings to requested state variables
    for key, idx in outputs.items():
        outputs[key] = self._index_state_var(results, idx)
    outputs['time'] = times

    return outputs
# NOTE(review): the body of the loop over the graph nodes and the statement that follows the
# "clear code generator" comment are missing from this listing - restore them from the original source.
[docs] def clear(self) -> None:
"""Deletes build directory and removes all compute graph nodes."""
# delete network nodes and variables from the compute graph
# iterates over a copy of the node keys, which allows nodes to be removed while looping
for n in list(self.nodes.keys()):
# clear code generator
# Adds all variable updates of the graph to the code generator (the backend) and returns the list of
# function arguments together with the code generator.
# NOTE(review): statements are missing from this listing - `rhs_indices_str` is never filled and the
# `while` loop at the end has no body. The code is left untouched.
def _to_str(self):
# preparations
code_gen = self.backend
# extract state variable from state vector
rhs_indices_str = []
for var in self.state_vars:
# extract index of variable in state vector
idx = (self._state_var_indices[var],)
# extract state variable from state vector
# each state variable is assigned its entries of the state vector `y`
rhs_idx, _ = code_gen.create_index_str(idx)
code_gen.add_var_update(lhs=self.get_var(var), rhs=f"y{rhs_idx}")
# get equation string and argument list for each non-DE node at the end of the compute graph hierarchy
func_args2, delete_args1 = self._generate_update_equations(differential_equations=False)
# get equation string and argument list for each DE node at the end of the compute graph hierarchy
func_args1, delete_args2 = self._generate_update_equations(differential_equations=True, indices=rhs_indices_str)
# remove unnecessary function arguments
# arguments of the DE updates come first in the combined list
func_args = func_args1 + func_args2
for arg in delete_args1 + delete_args2:
while arg in func_args:
return func_args, code_gen
# Adds the update equations of either all DE or all non-DE variables to the code generator.
# Returns the function arguments required by these equations and the variables defined by them.
# NOTE(review): many statements are missing from this listing (flagged inline), so the block is not
# executable as shown. The code is left untouched.
def _generate_update_equations(self, differential_equations: bool, indices: list = None) -> tuple:
code_gen = self.backend
# extract relevant compute graph nodes and bring them into the correct order
nodes = self.var_updates['DEs' if differential_equations else 'non-DEs']
# NOTE(review): the call below is not closed in this listing; its remaining arguments are missing.
nodes, updates, def_vars, undef_vars = self._sort_var_updates(nodes=nodes,
# collect right-hand side expression and all input variables to these expressions
func_args, expressions, var_names, rhs_shapes, lhs_indices = undef_vars, [], [], [], []
for node, update in zip(nodes, updates):
# collect shape of the right-hand side variable
# NOTE(review): the `try:` and the body of the `except IndexError:` clause are missing.
v = self.get_var(update)
v_eval = self.eval_node(update)
except IndexError:
# collect expression and variables of right-hand side of equation
expr_args, expr = self._node_to_expr(update)
expr_str, expr_args, _, _ = self._expr_to_str(expr, apply=True)
# process left-hand side of equation
# NOTE(review): the `else:` branch and the statements that fill `expressions`, `var_names`,
# `rhs_shapes`, `lhs_indices` and `func_args` are missing from this loop.
var = self.get_var(node)
if isinstance(var, ComputeOp):
# process indexing of left-hand side variable
idx_args, lhs = self._node_to_expr(node)
if lhs.args[0].name not in def_vars:
_, idx_args, lhs_var, idx = self._expr_to_str(lhs, apply=False)
# process normal update of left-hand side variable
lhs_var = var.name
idx = None
# add the left-hand side assignments of the collected right-hand side expressions to the code generator
# NOTE(review): the `else:` that separates the DE branch from the non-DE branch is missing.
if differential_equations:
# differential equation (DE) update
if not indices:
raise ValueError('State variables need to be stored in a single state vector, for which the indices '
'have to be passed to this method.')
add_args = self._generate_vecfield(code_gen, indices, expressions, rhs_shapes, var_names)
func_args = add_args + func_args
# non-differential equation update
if indices:
raise ValueError('Indices to non-state variables should be defined in the respective equations, not'
'be passed to this method.')
indices = lhs_indices
# non-DE update stored in a single variable
# NOTE(review): the `try:` and the body of the `except` clause are missing in the loop below.
for target_var, expr, idx, shape in zip(var_names, expressions, indices, rhs_shapes):
idx = self.get_var(idx)
except (KeyError, AttributeError, TypeError):
code_gen.add_var_update(lhs=self.get_var(target_var), rhs=expr, lhs_idx=idx, rhs_shape=shape)
return func_args, def_vars
def _generate_vecfield(self, code_gen, indices: list, expressions: list, rhs_shapes: list, lhs_vars: list) -> list:
    """Write all differential equation updates into the entries of the graph variable 'dy'.

    Parameters
    ----------
    code_gen
        Code generator that receives the variable updates.
    indices
        Indices into 'dy', one per expression.
    expressions
        Right-hand side expressions of the differential equations.
    rhs_shapes
        Shapes of the right-hand side expressions.
    lhs_vars
        Names of the left-hand side variables. Not used by this implementation.

    Returns
    -------
    list
        Names of the variables that have to be added to the function arguments.
    """
    # DE updates stored in a state-vector
    dy = self.get_var("dy")
    for idx, expr, shape in zip(indices, expressions, rhs_shapes):
        code_gen.add_var_update(lhs=dy, rhs=expr, lhs_idx=idx, rhs_shape=shape)

    # add rhs var to function arguments
    return ['dy']
def _generate_vecfield_var(self, state_vec: np.ndarray, dtype: str) -> str:
    """Add the variable 'dy' (zeros shaped like `state_vec`) to the graph and return its unique label."""
    key, _ = self.add_var(label='dy', vtype='state_var', value=np.zeros_like(state_vec), dtype=dtype)
    return key
# Hook that finalizes the generated function; receives the code generator and the key of the vector
# field variable.
# NOTE(review): the body of this method is missing from this listing - restore it from the original source.
def _generate_func_tail(self, code_gen, vecfield_key: str):
# Translates the graph node `n` into a symbolic expression and returns it together with the list of
# expression arguments.
# NOTE(review): statements are missing from this listing - the `try:` lines that belong to the two
# `except` clauses, the bodies of these clauses, the statements that fill `expr_args` and at least one
# `else:`. The code is left untouched.
def _node_to_expr(self, n: str, **kwargs) -> tuple:
expr_args = []
node = self.get_var(n)
# case I: node is a mathematical operation and its inputs need to be treated
# process node inputs
# maps the symbol of each input node to its recursively processed (arguments, expression) pair
expr_info = {self.get_var(inp).symbol: self._node_to_expr(inp, **kwargs) for inp in self.predecessors(n)}
# replace old inputs with its processed versions
expr = node.expr
for expr_old, (args, expr_new) in expr_info.items():
if expr_old != expr_new:
expr = expr.replace(expr_old, expr_new)
# replace generic function calls with the backend-specific function calls
expr_old = expr.func.__name__
func_info = self.get_op(expr_old, shape=node.shape)
expr = expr.replace(expr.func, Function(func_info['call']))
except (AttributeError, KeyError):
# case II: node is a simple variable or constant
except AttributeError:
# add constants to the expression arguments list
# nodes whose name contains 'dummy_constant' are replaced by the literal float value they hold
if node.is_constant:
expr = node.symbol
elif 'dummy_constant' in node.name:
expr = Symbol(str(float(node.value)))
expr = node.symbol
return expr_args, expr
# Translates a symbolic expression into a string and returns a 4-tuple: the expression string, the
# index arguments, the name of the first expression argument and the index string.
# Calls to `index_1d`, `index_2d`, `index_range`, `index_axis`, `identity` and `past` are replaced by
# their backend-specific string versions.
# NOTE(review): statements are missing from this listing (flagged inline). The code is left untouched.
def _expr_to_str(self, expr: Any, expr_str: str = None, apply: bool = True, **kwargs) -> tuple:
# preparations
# initializations
index_args = []
func = ""
idx = ""
# ensure expression string exists
if not expr_str:
expr_str = str(expr)
# transform expression arguments into strings
# NOTE(review): `expr_args` is never filled in this listing although `expr_args[0]` is read below.
expr_args = []
for arg in expr.args:
expr_part, args, _, _ = self._expr_to_str(arg, **kwargs)
expr_str = expr_str.replace(str(arg), expr_part)
var = str(expr_args[0]) if expr.args else ""
# process indexing operations
if 'index_1d(' in expr_str:
# replace `index` calls with brackets-based indexing
idx = self._get_var_idx(idx=(expr.args[1],), args=index_args, apply=apply, **kwargs)
func = 'index_1d'
elif 'index_2d(' in expr_str:
# replace `2d_index` calls with brackets-based indexing
idx = self._get_var_idx(idx=(expr.args[1], expr.args[2]), args=index_args, apply=apply, **kwargs)
func = 'index_2d'
elif 'index_range(' in expr_str:
# replace `range_index` calls with brackets-based indexing
idx = self._get_var_idx(idx=((expr.args[1], expr.args[2]),), args=index_args, apply=apply, **kwargs)
func = 'index_range'
elif 'index_axis(' in expr_str:
# replace `axis_index` calls with brackets-based indexing
# NOTE(review): the `else:` between the two assignments to `idx` below is missing.
if len(expr.args) < 2:
idx = self._get_var_idx(idx=(':',), args=index_args, apply=apply, **kwargs)
idx = self._get_var_idx(args=index_args, apply=apply,
idx=tuple([':' for _ in range(expr.args[2])] + [f"{expr.args[1]}"]), **kwargs)
func = "index_axis"
# either apply the above indexing calls or return them
if func and apply:
replacement = self.backend.finalize_idx_str(var=self.get_var(var), idx=idx)
expr_str = self._process_func_call(expr=expr_str, func=func, replacement=replacement)
# handle other function calls
if 'identity(' in expr_str:
# replace `no_op` calls with first argument to the function call
expr_str = self._process_func_call(expr=expr_str, func='identity', replacement=var)
if 'past(' in expr_str:
# replace past calls with the delayed version of the backend variable
# NOTE(review): the `try:` that belongs to the `except AttributeError:` below is missing. The
# visible fallback converts the second argument of `past` into a float delay.
delay = self.get_var(expr.args[1].name)
except AttributeError:
delay = float(expr.args[1])
replacement = self._get_var_hist(var=var, delay=delay)
expr_str = self._process_func_call(expr=expr_str, func="past", replacement=replacement)
# backend-specific function call adjustments
expr_str = self.backend.expr_to_str(expr_str, expr.args)
return expr_str, index_args, var, idx
def _process_var_update(self, var: str, update: str) -> tuple:
    """Return the state variable `var` and the evaluated right-hand side `update` with matching shapes.

    Parameters
    ----------
    var
        Label of the left-hand side graph node.
    update
        Label of the graph node that holds the right-hand side of the equation.

    Returns
    -------
    tuple
        The left-hand side node and the evaluated right-hand side, reshaped to the shape of the
        left-hand side if necessary.

    Raises
    ------
    ValueError
        If the right-hand side cannot be reshaped to the shape of the left-hand side.
    """
    # extract nodes
    lhs = self.get_var(var)
    rhs = self.eval_node(update)

    # extract common shape
    if lhs.shape == rhs.shape:
        return lhs, rhs
    try:
        rhs = rhs.reshape(lhs.shape)
    except ValueError as err:
        # fall back to the node label if the update node does not carry an expression
        rhs_expr = getattr(self.get_var(update), 'expr', update)
        raise ValueError(
            f"Shapes of state variable {var} and its right-hand side update {rhs_expr} do not"
            " match.") from err
    return lhs, rhs
# Brings the variable updates in `nodes` into an order in which they can be evaluated.
# Returns a 4-tuple: the left-hand side keys, the right-hand side updates, the variables that are
# defined by the updates and the variables that are not defined by them.
# NOTE(review): large parts of the sorting loop are missing from this listing (the bodies of most `if`
# statements and of the final `for` loop), so the algorithm cannot be verified from here. The code is
# left untouched.
def _sort_var_updates(self, nodes: dict, differential_equations: bool = True) -> tuple:
# case I: for differential equations, do not perform any sorting
if differential_equations:
return list(nodes.keys()), list(nodes.values()), [], []
# case II: for non-differential equations, sort them according to their graph connections
# step 1: ensure lhs-indexing operations are considered as well
node_names, node_keys = [], []
for node in nodes:
n = self.get_var(node)
if type(n) is ComputeVar:
keys, values, defined_vars, undefined_vars = [], [], [], []
n_nodes = len(nodes)
# the loop below runs until `nodes` is empty; it iterates over a copy of the dict
while nodes:
for node, update in nodes.copy().items():
# go through node inputs and check whether it depends on other equations to be evaluated first
dependent, inp = False, ""
for inp in self._get_inputs(update):
if inp in node_names:
idx = node_names.index(inp)
if node_keys[idx] != node:
dependent = True
# decide whether this equation can be evaluated now
if dependent:
idx = node_keys.index(node)
n = node_names.pop(idx)
if isinstance(self.get_var(node), ComputeVar) and n not in undefined_vars:
elif n not in undefined_vars:
# check whether the algorithm is stuck
# compares the number of remaining nodes before and after the pass
n_nodes_new = len(nodes)
if n_nodes_new == n_nodes:
n_nodes = n_nodes_new
# add mutually depended nodes
if nodes:
node_keys = list(nodes.keys())
for n in node_names:
if n not in undefined_vars:
return keys, values, defined_vars, undefined_vars
def _get_inputs(self, n: str):
    """Collect the labels of all variable nodes that feed (directly or via operations) into node `n`.

    Predecessors that hold a `ComputeVar` are collected directly, all other predecessors are resolved
    recursively.
    """
    inputs = []
    for inp in self.predecessors(n):
        if isinstance(self.get_var(inp), ComputeVar):
            inputs.append(inp)
        else:
            inputs.extend(self._get_inputs(inp))
    return inputs
# Turns the index `idx` into a backend-specific index string and returns it. Index variables that the
# backend creates in the process are added to the graph as constants labeled 'index'.
# NOTE(review): statements are missing from this listing - the `try:` and the body of the
# `except KeyError:` clause, as well as the statements that fill `new_idx` and `args`. The code is left
# untouched.
def _get_var_idx(self, idx: tuple, args: list, apply: bool = True, **kwargs):
# collect indexing variables where necessary
new_idx = []
for idx_tmp in idx:
# symbols are looked up in the graph via their name
v = self.get_var(idx_tmp.name if type(idx_tmp) is Symbol else idx_tmp)
except KeyError:
# turn index into a backend-specific string
idx_str, new_vars = self.backend.create_index_str(tuple(new_idx), apply=apply, **kwargs)
# add new variables to graph and index arguments
# the dummy index variable of the backend is skipped
for key, v in new_vars.items():
if key != self.backend.idx_dummy_var:
vlabel, _ = self.add_var(label='index', value=v, vtype='constant')
idx_str = idx_str.replace(key, vlabel)
return idx_str
def _get_var_hist(self, var: str, delay: Union[ComputeVar, float]):
    """Return the name of the graph variable that represents `var` delayed by `delay`.

    A new graph variable is created the first time a particular delay is requested for `var`.
    Subsequent requests for the same delay return the name that has been stored for it.
    """
    histories = self._state_var_hist.setdefault(var, dict())
    if delay not in histories:
        var_hist = f'{var}_hist{len(histories)}'
        self.add_var(var_hist, value=self.get_var(var).value, vtype='variable')
        histories[delay] = var_hist
    return histories[delay]
# Inspects all nodes without outgoing edges in two passes (see the inline comments).
# Nodes registered in `self._eq_nodes` are excluded in both passes.
# NOTE(review): the bodies of both `if` statements are missing from this listing, so the actual removal
# calls cannot be seen here - restore them from the original source.
def _prune(self):
# remove all subgraphs that contain constants only
for n in [node for node, out_degree in self.out_degree if out_degree == 0]:
if self.get_var(n).is_constant and n not in self._eq_nodes:
# remove all unconnected nodes
# the list of end nodes is rebuilt here instead of reusing the one from the first pass
for n in [node for node, out_degree in self.out_degree if out_degree == 0]:
if self.in_degree(n) == 0 and n not in self._eq_nodes:
def _generate_unique_label(self, label: str) -> str:
    """Return a version of `label` that has not been handed out by this method before.

    The first request for a label returns the label itself, every further request returns the label
    with the suffix `_v<k>`, where `k` counts the repeated requests. The label "t" is always returned
    unchanged.
    """
    if label == "t":
        return label
    if label not in self._node_names:
        # first request: hand out the label as it is and start counting
        self._node_names[label] = 0
        return label
    self._node_names[label] += 1
    return f"{label}_v{self._node_names[label]}"
@staticmethod
def _index_state_var(y: np.ndarray, idx: Union[int, tuple, list]) -> np.ndarray:
    """Extract the columns of the 2D array `y` that belong to a single state variable.

    Parameters
    ----------
    y
        Array whose second axis is indexed.
    idx
        Either a single column index or a (start, stop) pair.

    Returns
    -------
    np.ndarray
        2D array that contains the selected column(s) of `y`.
    """
    if isinstance(idx, (int, np.integer)):
        # also covers numpy integers, which are not instances of `int`
        idx = (int(idx),)
    elif isinstance(idx, tuple) and len(idx) == 2 and idx[1] - idx[0] == 1:
        idx = (idx[0],)
    # a one-element sequence selects a single column while keeping the result two-dimensional
    return y[:, idx] if len(idx) == 1 else y[:, idx[0]:idx[1]]
@staticmethod
def _process_func_call(expr: str, func: str, replacement: str):
    """Replace the first call to `func` in the expression string `expr` by `replacement`.

    The end of the call is found by matching parentheses, so nested calls inside the argument list
    are handled. Every occurrence of the identified call string is replaced. The expression is
    returned unchanged if it contains no (complete) call to `func`.
    """
    # identify start and end of the function call
    start = expr.find(f"{func}(")
    if start < 0:
        return expr
    depth = 0
    for pos in range(start + len(func), len(expr)):
        char = expr[pos]
        if char == '(':
            depth += 1
        elif char == ')':
            depth -= 1
            if depth == 0:
                # replace part in expression string
                return expr.replace(expr[start:pos + 1], replacement)
    return expr
# Variant of `ComputeGraph` that collects the right-hand sides of the differential equations in
# separate `delta_<variable>` code lines instead of writing them into a shared vector field variable.
# NOTE(review): statements are missing from this listing - the body of `_generate_func_tail`, the
# statement that fills `self._vecfield_vars` and an `else:` in `_generate_vecfield`. The code is left
# untouched.
[docs]class ComputeGraphBackProp(ComputeGraph):
def __init__(self, backend: str, **kwargs):
super().__init__(backend, **kwargs)
# names of the generated vector field variables and the string that combines them
self._vecfield_vars = []
self._vecfield_var_str = ""
# no separate vector field variable is added to the graph; an empty key is returned instead
def _generate_vecfield_var(self, state_vec: np.ndarray, dtype: str):
return ""
def _generate_func_tail(self, code_gen, vecfield_key: str):
# `indices` and `rhs_shapes` are not used by the visible code; no additional function arguments are returned
def _generate_vecfield(self, code_gen, indices: list, expressions: list, rhs_shapes: list, lhs_vars: list) -> list:
for lhs, expr in zip(lhs_vars, expressions):
lhs_var = f"delta_{lhs}"
code_gen.add_code_line(f"{lhs_var} = {expr}")
# multiple vector field variables are combined via the backend's "concatenate" operation
if len(self._vecfield_vars) > 1:
op_dict = self.backend.get_op("concatenate")
self._vecfield_var_str = f"{op_dict['call']}([{','.join(v for v in self._vecfield_vars)}], 0)"
self._vecfield_var_str = self._vecfield_vars.pop()
return []