Source code for pyrates.backend.computegraph

# -*- coding: utf-8 -*-
#
#
# PyRates software framework for flexible implementation of neural
# network model_templates and simulations. See also:
# https://github.com/pyrates-neuroscience/PyRates
#
# Copyright (C) 2017-2018 the original authors (Richard Gast and
# Daniel Rose), the Max-Planck-Institute for Human Cognitive Brain
# Sciences ("MPI CBS") and contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>
#
# CITATION:
#
# Richard Gast and Daniel Rose et. al. in preparation

"""This module provides the backend class that should be used to set up any backend in pyrates.
"""

# external _imports
from typing import Any, Callable, Union, Iterable, Optional
from networkx import MultiDiGraph
from sympy import Symbol, Expr, Function
import numpy as np

# meta infos
__author__ = "Richard Gast"
__status__ = "development"


#########################
# compute graph classes #
#########################


# numpy-based node class
[docs]class ComputeNode: """Base class for adding variables to the compute graph. Creates a numpy array with additional attributes for variable identification/retrieval from graph. Should be used as parent class for custom variable classes. Parameters ---------- name Full name of the variable in the original NetworkGraph (including the node and operator it belongs to). dtype Data-type of the variable. For valid data-types, check the documentation of the backend in use. shape Shape of the variable. """ __slots__ = ["name", "symbol", "dtype", "shape", "_value"] def __init__(self, name: str, symbol: Union[Symbol, Expr, Function], dtype: Optional[str] = None, shape: Optional[tuple] = None, def_shape: Optional[tuple] = None): """Instantiates a basic node of a ComputeGraph instance. """ self.name = name self.symbol = symbol self.shape = self._get_shape(shape, def_shape) self._value = np.zeros(self.shape) self.dtype = dtype self.set_dtype()
[docs] def reshape(self, shape: tuple, **kwargs): self._value = self.value.reshape(shape, **kwargs) self.shape = shape return self
[docs] def squeeze(self, axis=None): self._value = self.value.squeeze(axis=axis) self.shape = self._value.shape return self
[docs] def set_value(self, v: Union[float, np.ndarray]): self._value = np.asarray(v, dtype=self.dtype) self.shape = tuple(v.shape)
@property def value(self): """Returns current value of BaseVar. """ return self._value @property def is_constant(self): raise NotImplementedError("This method has to be defined by each child class.") @property def is_float(self): return "float" in self.dtype @property def is_complex(self): return "complex" in self.dtype def _is_equal_to(self, v): for attr in self.__slots__: if not hasattr(v, attr) or getattr(v, attr) != getattr(self, attr): return False return True def _get_value(self, value: Optional[Union[list, np.ndarray]] = None, dtype: Optional[str] = None, shape: Optional[tuple] = None): """Defines initial value of variable. """ # case I: create new array from shape and dtype if value is None: return np.zeros(shape=shape, dtype=dtype) # case II: transform values into an array if not hasattr(value, 'shape'): if type(value) is list: return self._get_value(value=np.asarray(value, dtype=dtype), dtype=dtype, shape=shape) return np.zeros(shape=shape, dtype=dtype) + value # case III: match given shape with the shape of the given value array if shape is not None: value = np.asarray(value, dtype=dtype) if value.shape == shape: return value if sum(shape) < sum(value.shape): return value.squeeze() idx = ",".join("None" if s == 1 else ":" for s in shape) return eval(f'value[{idx}]') # case IV: just ensure the correct data type of the value array return np.asarray(value, dtype=dtype)
[docs] def set_dtype(self, dtype: str = None): if dtype is None: if not self.dtype: if 'float' in str(self.value.dtype): self.dtype = 'float' elif 'complex' in str(self.value.dtype): self.dtype = 'complex' else: self.dtype = 'int' else: self.dtype = dtype
@staticmethod def _get_shape(s: Union[tuple, None], s_def: tuple): if s is None or sum(s) < 2: return s_def return s def __deepcopy__(self, memodict: dict): node = ComputeNode(name=self.name, symbol=self.symbol, dtype=self.dtype, shape=self.shape) node._value = np.zeros_like(node._value) + node._value return node def __str__(self): return self.name def __hash__(self): return hash(str(self))
[docs]class ComputeVar(ComputeNode): """Class for variables and vector-valued constants in the ComputeGraph. """ __slots__ = ComputeNode.__slots__ + ["vtype"] def __init__(self, name: str, symbol: Union[Symbol, Expr, Function], vtype: str, dtype: Optional[str] = None, shape: Optional[str] = None, value: Optional[Union[list, np.ndarray]] = None, def_shape: Optional[tuple] = None): # set attributes super().__init__(name=name, symbol=symbol, dtype=dtype, shape=shape, def_shape=def_shape) self.vtype = vtype # adjust variable value self.set_value(self._get_value(value=value, shape=self.shape, dtype=self.dtype)) @property def is_constant(self): return self.vtype == 'constant'
[docs]class ComputeOp(ComputeNode): """Class for ComputeGraph nodes that represent mathematical operations. """ __slots__ = ComputeNode.__slots__ + ["func", "expr"] def __init__(self, name: str, symbol: Union[Symbol, Expr, Function], func: Callable, expr: Expr, dtype: Optional[str] = None, shape: Optional[str] = None): # set attributes super().__init__(name=name, symbol=symbol, dtype=dtype, shape=shape) self.func = func self.expr = expr @property def is_constant(self): return False
# networkx-based graph class
[docs]class ComputeGraph(MultiDiGraph): """Creates a compute graph where nodes are all constants and variables of the network and edges are the mathematical operations linking those variables/constants together to form equations. """ def __init__(self, backend: str, **kwargs): super().__init__() # choose a backend if backend == 'tensorflow': from pyrates.backend.tensorflow import TensorflowBackend backend = TensorflowBackend elif backend == 'torch': from pyrates.backend.torch import TorchBackend backend = TorchBackend elif backend == 'fortran': from pyrates.backend.fortran import FortranBackend backend = FortranBackend elif backend == 'julia': from pyrates.backend.julia import JuliaBackend backend = JuliaBackend elif backend == 'matlab': from pyrates.backend.matlab import MatlabBackend backend = MatlabBackend else: from pyrates.backend.base import BaseBackend backend = BaseBackend # backend-related attributes self.backend = backend(**kwargs) self.var_updates = {'DEs': dict(), 'non-DEs': dict()} self._eq_nodes = [] self._state_var_indices = dict() self._state_var_hist = dict() self._node_names = {} @property def state_vars(self): return list(self.var_updates['DEs'].keys())
[docs] def add_var(self, label: str, value: Any, vtype: str, **kwargs): unique_label = self._generate_unique_label(label) var = ComputeVar(name=unique_label, symbol=Symbol(unique_label), value=value, vtype=vtype, **kwargs) super().add_node(unique_label, node=var) return unique_label, self.nodes[unique_label]['node']
[docs] def add_op(self, inputs: Union[list, tuple], label: str, expr: Expr, func: Callable, **kwargs): # add target node that contains result of operation unique_label = self._generate_unique_label(label) op = ComputeOp(name=unique_label, symbol=Symbol(unique_label), func=func, expr=expr, **kwargs) super().add_node(unique_label, node=op) # add edges from source nodes to target node for i, v in enumerate(inputs): super().add_edge(v, unique_label, key=i) return unique_label, self.nodes[unique_label]['node']
[docs] def add_var_update(self, var: str, update: str, differential_equation: bool = False): # store mapping between left-hand side variable and right-hand side update if differential_equation: self.var_updates['DEs'][var] = update else: self.var_updates['non-DEs'][var] = update # remember var and update node to ensure that they are not pruned during compilation self._eq_nodes.extend([var, update])
[docs] def get_var(self, var: str, from_backend: bool = False): v = self.nodes[var]['node'] if from_backend: return self.backend.get_var(v) return v
[docs] def get_op(self, op: str, **kwargs) -> dict: return self.backend.get_op(op, **kwargs)
[docs] def eval_graph(self): for n in self.var_updates['non-DEs'].values(): self.eval_subgraph(n) return self.eval_nodes(self.var_updates['DEs'].values())
[docs] def eval_nodes(self, nodes: Iterable): return [self.eval_node(n) for n in nodes]
[docs] def eval_node(self, n): inputs = tuple([self.eval_node(inp) for inp in self.predecessors(n)]) node = self.get_var(n) if isinstance(node, ComputeOp): return node.func(*inputs) return node.value
[docs] def eval_subgraph(self, n): inputs = [] input_nodes = [node for node in self.predecessors(n)] for inp in input_nodes: inputs.append(self.eval_subgraph(inp)) self.remove_node(inp) node = self.get_var(n) if inputs: node.set_value(node.func(*tuple(inputs))) return node.value
[docs] def remove_subgraph(self, n): for inp in self.predecessors(n): self.remove_subgraph(inp) self.remove_node(n)
[docs] def compile(self): # evaluate constant-based operations out_nodes = [node for node, out_degree in self.out_degree if out_degree == 0] for node in out_nodes: # process inputs of node for inp in self.predecessors(node): if self.get_var(inp).is_constant: self.eval_subgraph(inp) # evaluate node if all its inputs are constants if all([self.get_var(inp).is_constant for inp in self.predecessors(node)]) and node not in self._eq_nodes: self.eval_subgraph(node) # remove unconnected nodes and constants from graph self._prune() return self
[docs] def to_func(self, func_name: str, to_file: bool = True, dt_adapt: bool = True, **kwargs): # finalize compute graph self.compile() # create state variable vector and state variable update vector ############################################################### variables = [] idx = 0 for var, update in self.var_updates['DEs'].items(): # extract left-hand side nodes from graph lhs, rhs = self._process_var_update(var, update) variables.append(lhs.value) # store information of the original, non-vectorized state variable vshape = sum(lhs.shape) if vshape >= 1: self._state_var_indices[var] = (idx, idx+vshape) idx += vshape else: self._state_var_indices[var] = idx idx += 1 # add collected state variables to the backend try: state_vec = np.concatenate(variables, axis=0) except ValueError: state_vec = np.asarray(variables) dtype = 'complex' if 'complex' in state_vec.dtype.name else 'float' state_var_key, y = self.add_var(label='y', vtype='state_var', value=state_vec, dtype=dtype) rhs_var_key = self._generate_vecfield_var(state_vec, dtype) try: t = self.get_var('t') except KeyError: _, t = self.add_var(label='t', vtype='state_var', value=0.0 if dt_adapt else 0, dtype='float' if dt_adapt else 'int', shape=()) self.backend.register_vars([t, y]) # create a string containing all computations and variable updates represented by the compute graph func_args, code_gen = self._to_str() func_body = code_gen.generate() code_gen.code.clear() # generate function head add_hist_calls = self._state_var_hist and code_gen.add_hist_arg func_args = code_gen.generate_func_head(func_name=func_name, state_var=state_var_key, return_var=rhs_var_key, func_args=[self.get_var(arg) for arg in func_args], add_hist_func=add_hist_calls) # extract state variable histories for delayed interactions code_gen.add_linebreak() for var, delays in self._state_var_hist.items(): # extract index of variable in state vector idx = self._state_var_indices[var] # extract state variable history from backend-specific buffer if type(idx) is not int: if idx[1]-idx[0] < 2: idx = idx[0] for delay, v_hist in delays.items(): # type: ComputeVar, str code_gen.add_var_hist(lhs=v_hist, delay=delay, state_idx=idx, var=var) # add lines from function body after function head code_gen.add_linebreak() code_gen.add_code_line(func_body) code_gen.add_linebreak() # generate function tail self._generate_func_tail(code_gen, rhs_var_key) # generate the function (and write to file, optionally) func_args_tmp = func_args[4:] if add_hist_calls else func_args[3:] func = code_gen.generate_func(func_name=func_name, to_file=to_file, func_args=func_args_tmp, state_vars=self.state_vars, **kwargs) # OPTIONAL: write function arguments (state vectors and constants) to file c_fn = kwargs.pop('constants_file_name', None) if c_fn: arg_dict = {arg: self.get_var(arg).value for arg in func_args if arg != 'hist'} if code_gen.lags: arg_dict['lags'] = list(code_gen.lags.keys()) fn = f'{self.backend.fdir}/{c_fn}' if self.backend.fdir else c_fn code_gen.to_file(fn, **arg_dict) # finalize the function arguments fargs = [] for arg in func_args: if arg == 'hist': if 'hist' in kwargs: arg = kwargs.pop('hist') else: y_init = np.asarray(state_vec[:]) arg = code_gen.get_hist_func(y_init) fargs.append(arg) else: fargs.append(self.get_var(arg, from_backend=True)) return func, tuple(fargs), tuple(func_args), self._state_var_indices.copy()
[docs] def run(self, func: Callable, func_args: tuple, T: float, dt: float, dts: Optional[float] = None, outputs: Optional[dict] = None, **kwargs) -> dict: # pre-process outputs if outputs is None: outputs = {key: key for key in self.state_vars} for key in outputs.copy(): var = outputs.pop(key) outputs[key] = self._state_var_indices[var] # handle other arguments if dts is None: dts = dt solver = kwargs.pop('solver', 'euler') # call backend method results, times = self.backend.run(func=func, func_args=func_args, T=T, dt=dt, dts=dts, solver=solver, **kwargs) # set state variables to final simulated value for key in self.state_vars: var = self.get_var(key) idx = self._state_var_indices[key] var.set_value(np.reshape(self._index_state_var(results, idx)[-1], var.shape)) # reduce state recordings to requested state variables for key, idx in outputs.items(): outputs[key] = self._index_state_var(results, idx) outputs['time'] = times return outputs
[docs] def clear(self) -> None: """Deletes build directory and removes all compute graph nodes """ # delete network nodes and variables from the compute graph for n in list(self.nodes.keys()): self.remove_subgraph(n) self.var_updates.clear() self._state_var_indices.clear() self._eq_nodes.clear() # clear code generator self.backend.clear()
def _to_str(self): # preparations code_gen = self.backend # extract state variable from state vector rhs_indices_str = [] for var in self.state_vars: # extract index of variable in state vector idx = (self._state_var_indices[var],) # extract state variable from state vector rhs_idx, _ = code_gen.create_index_str(idx) code_gen.add_var_update(lhs=self.get_var(var), rhs=f"y{rhs_idx}") rhs_indices_str.append(idx) # get equation string and argument list for each non-DE node at the end of the compute graph hierarchy func_args2, delete_args1 = self._generate_update_equations(differential_equations=False) code_gen.add_linebreak() # get equation string and argument list for each DE node at the end of the compute graph hierarchy func_args1, delete_args2 = self._generate_update_equations(differential_equations=True, indices=rhs_indices_str) # remove unnecessary function arguments func_args = func_args1 + func_args2 for arg in delete_args1 + delete_args2: while arg in func_args: func_args.pop(func_args.index(arg)) return func_args, code_gen def _generate_update_equations(self, differential_equations: bool, indices: list = None) -> tuple: code_gen = self.backend # extract relevant compute graph nodes and bring them into the correct order nodes = self.var_updates['DEs' if differential_equations else 'non-DEs'] nodes, updates, def_vars, undef_vars = self._sort_var_updates(nodes=nodes, differential_equations=differential_equations) # collect right-hand side expression and all input variables to these expressions func_args, expressions, var_names, rhs_shapes, lhs_indices = undef_vars, [], [], [], [] for node, update in zip(nodes, updates): # collect shape of the right-hand side variable v = self.get_var(update) try: v_eval = self.eval_node(update) v.set_value(v_eval) except IndexError: pass rhs_shapes.append(v.shape) # collect expression and variables of right-hand side of equation expr_args, expr = self._node_to_expr(update) func_args.extend(expr_args) expr_str, expr_args, _, _ = self._expr_to_str(expr, apply=True) func_args.extend(expr_args) expressions.append(expr_str) # process left-hand side of equation var = self.get_var(node) if isinstance(var, ComputeOp): # process indexing of left-hand side variable idx_args, lhs = self._node_to_expr(node) if lhs.args[0].name not in def_vars: idx_args.append(lhs.args[0].name) func_args.extend(idx_args) _, idx_args, lhs_var, idx = self._expr_to_str(lhs, apply=False) func_args.extend(idx_args) else: # process normal update of left-hand side variable lhs_var = var.name idx = None var_names.append(lhs_var) lhs_indices.append(idx) # add the left-hand side assignments of the collected right-hand side expressions to the code generator if differential_equations: # differential equation (DE) update if not indices: raise ValueError('State variables need to be stored in a single state vector, for which the indices ' 'have to be passed to this method.') add_args = self._generate_vecfield(code_gen, indices, expressions, rhs_shapes, var_names) func_args = add_args + func_args else: # non-differential equation update if indices: raise ValueError('Indices to non-state variables should be defined in the respective equations, not' 'be passed to this method.') indices = lhs_indices # non-DE update stored in a single variable for target_var, expr, idx, shape in zip(var_names, expressions, indices, rhs_shapes): try: idx = self.get_var(idx) except (KeyError, AttributeError, TypeError): pass code_gen.add_var_update(lhs=self.get_var(target_var), rhs=expr, lhs_idx=idx, rhs_shape=shape) return func_args, def_vars def _generate_vecfield(self, code_gen, indices: list, expressions: list, rhs_shapes: list, lhs_vars: list) -> list: # DE updates stored in a state-vector dy = self.get_var("dy") for idx, expr, shape in zip(indices, expressions, rhs_shapes): code_gen.add_var_update(lhs=dy, rhs=expr, lhs_idx=idx, rhs_shape=shape) # add rhs var to function arguments return ['dy'] def _generate_vecfield_var(self, state_vec: np.ndarray, dtype: str) -> str: key, _ = self.add_var(label='dy', vtype='state_var', value=np.zeros_like(state_vec), dtype=dtype) return key def _generate_func_tail(self, code_gen, vecfield_key: str): code_gen.generate_func_tail(rhs_var=vecfield_key) def _node_to_expr(self, n: str, **kwargs) -> tuple: expr_args = [] node = self.get_var(n) # case I: node is a mathematical operation and its inputs need to be treated try: # process node inputs expr_info = {self.get_var(inp).symbol: self._node_to_expr(inp, **kwargs) for inp in self.predecessors(n)} # replace old inputs with its processed versions expr = node.expr for expr_old, (args, expr_new) in expr_info.items(): if expr_old != expr_new: expr = expr.replace(expr_old, expr_new) expr_args.extend(args) # replace generic function calls with the backend-specific function calls try: expr_old = expr.func.__name__ func_info = self.get_op(expr_old, shape=node.shape) expr = expr.replace(expr.func, Function(func_info['call'])) except (AttributeError, KeyError): pass # case II: node is a simple variable or constant except AttributeError: # add constants to the expression arguments list if node.is_constant: expr_args.append(n) expr = node.symbol elif 'dummy_constant' in node.name: expr = Symbol(str(float(node.value))) else: expr = node.symbol return expr_args, expr def _expr_to_str(self, expr: Any, expr_str: str = None, apply: bool = True, **kwargs) -> tuple: # preparations ############### # initializations index_args = [] func = "" idx = "" # ensure expression string exists if not expr_str: expr_str = str(expr) # transform expression arguments into strings expr_args = [] for arg in expr.args: expr_part, args, _, _ = self._expr_to_str(arg, **kwargs) expr_str = expr_str.replace(str(arg), expr_part) index_args.extend(args) expr_args.append(expr_part) var = str(expr_args[0]) if expr.args else "" # process indexing operations ############################# if 'index_1d(' in expr_str: # replace `index` calls with brackets-based indexing idx = self._get_var_idx(idx=(expr.args[1],), args=index_args, apply=apply, **kwargs) func = 'index_1d' elif 'index_2d(' in expr_str: # replace `2d_index` calls with brackets-based indexing idx = self._get_var_idx(idx=(expr.args[1], expr.args[2]), args=index_args, apply=apply, **kwargs) func = 'index_2d' elif 'index_range(' in expr_str: # replace `range_index` calls with brackets-based indexing idx = self._get_var_idx(idx=((expr.args[1], expr.args[2]),), args=index_args, apply=apply, **kwargs) func = 'index_range' elif 'index_axis(' in expr_str: # replace `axis_index` calls with brackets-based indexing if len(expr.args) < 2: idx = self._get_var_idx(idx=(':',), args=index_args, apply=apply, **kwargs) else: idx = self._get_var_idx(args=index_args, apply=apply, idx=tuple([':' for _ in range(expr.args[2])] + [f"{expr.args[1]}"]), **kwargs) func = "index_axis" # either apply the above indexing calls or return them if func and apply: replacement = self.backend.finalize_idx_str(var=self.get_var(var), idx=idx) expr_str = self._process_func_call(expr=expr_str, func=func, replacement=replacement) # handle other function calls ############################# if 'identity(' in expr_str: # replace `no_op` calls with first argument to the function call expr_str = self._process_func_call(expr=expr_str, func='identity', replacement=var) if 'past(' in expr_str: # replace past calls with the delayed version of the backend variable try: delay = self.get_var(expr.args[1].name) except AttributeError: delay = float(expr.args[1]) replacement = self._get_var_hist(var=var, delay=delay) expr_str = self._process_func_call(expr=expr_str, func="past", replacement=replacement) # backend-specific function call adjustments expr_str = self.backend.expr_to_str(expr_str, expr.args) return expr_str, index_args, var, idx def _process_var_update(self, var: str, update: str) -> tuple: # extract nodes lhs = self.get_var(var) rhs = self.eval_node(update) # extract common shape if lhs.shape == rhs.shape: return lhs, rhs try: rhs = rhs.reshape(lhs.shape) return lhs, rhs except ValueError: raise ValueError( f"Shapes of state variable {var} and its right-hand side update {self.get_var(update).expr} do not" " match.") def _sort_var_updates(self, nodes: dict, differential_equations: bool = True) -> tuple: # case I: for differential equations, do not perform any sorting if differential_equations: return list(nodes.keys()), list(nodes.values()), [], [] # case II: for non-differential equations, sort them according to their graph connections ######################################################################################### # step 1: ensure lhs-indexing operations are considered as well node_names, node_keys = [], [] for node in nodes: n = self.get_var(node) if type(n) is ComputeVar: node_names.append(node) else: node_names.append(list(self._get_inputs(node))[-1]) node_keys.append(node) keys, values, defined_vars, undefined_vars = [], [], [], [] n_nodes = len(nodes) while nodes: for node, update in nodes.copy().items(): # go through node inputs and check whether it depends on other equations to be evaluated first dependent, inp = False, "" for inp in self._get_inputs(update): if inp in node_names: idx = node_names.index(inp) if node_keys[idx] != node: dependent = True break # decide whether this equation can be evaluated now if dependent: continue else: idx = node_keys.index(node) n = node_names.pop(idx) node_keys.pop(idx) nodes.pop(node) keys.append(node) values.append(update) if isinstance(self.get_var(node), ComputeVar) and n not in undefined_vars: defined_vars.append(n) elif n not in undefined_vars: undefined_vars.append(n) # check whether the algorithm is stuck n_nodes_new = len(nodes) if n_nodes_new == n_nodes: break else: n_nodes = n_nodes_new # add mutually depended nodes if nodes: node_keys = list(nodes.keys()) keys.extend(node_keys) values.extend(list(nodes.values())) for n in node_names: if n not in undefined_vars: undefined_vars.append(n) return keys, values, defined_vars, undefined_vars def _get_inputs(self, n: str): inputs = [] for inp in self.predecessors(n): inputs.extend([inp] if isinstance(self.get_var(inp), ComputeVar) else self._get_inputs(inp)) return inputs def _get_var_idx(self, idx: tuple, args: list, apply: bool = True, **kwargs): # collect indexing variables where necessary new_idx = [] for idx_tmp in idx: try: v = self.get_var(idx_tmp.name if type(idx_tmp) is Symbol else idx_tmp) new_idx.append(v) except KeyError: new_idx.append(idx_tmp) # turn index into a backend-specific string idx_str, new_vars = self.backend.create_index_str(tuple(new_idx), apply=apply, **kwargs) # add new variables to graph and index arguments for key, v in new_vars.items(): if key != self.backend.idx_dummy_var: vlabel, _ = self.add_var(label='index', value=v, vtype='constant') idx_str = idx_str.replace(key, vlabel) args.append(vlabel) return idx_str def _get_var_hist(self, var: str, delay: Union[ComputeVar, float]): if var not in self._state_var_hist: self._state_var_hist[var] = dict() if delay not in self._state_var_hist[var]: var_hist = f'{var}_hist{len(self._state_var_hist[var])}' self.add_var(var_hist, value=self.get_var(var).value, vtype='variable') self._state_var_hist[var][delay] = var_hist else: var_hist = self._state_var_hist[var][delay] return var_hist def _prune(self): # remove all subgraphs that contain constants only for n in [node for node, out_degree in self.out_degree if out_degree == 0]: if self.get_var(n).is_constant and n not in self._eq_nodes: self.remove_subgraph(n) # remove all unconnected nodes for n in [node for node, out_degree in self.out_degree if out_degree == 0]: if self.in_degree(n) == 0 and n not in self._eq_nodes: self.remove_node(n) def _generate_unique_label(self, label: str) -> str: if label == "t": return label if label in self._node_names: n = self._node_names[label] if n == 0: label_new = f"{label}_v1" else: label_new = f"{label}_v{n + 1}" self._node_names[label] += 1 else: label_new = label self._node_names[label] = 0 return label_new @staticmethod def _index_state_var(y: np.ndarray, idx: Union[int, tuple, list]) -> np.ndarray: if type(idx) is tuple and idx[1] - idx[0] == 1: idx = (idx[0],) elif type(idx) is int: idx = (idx,) return y[:, idx] if len(idx) == 1 else y[:, idx[0]:idx[1]] @staticmethod def _process_func_call(expr: str, func: str, replacement: str): # identify start and end of the function call start = expr.find(f"{func}(") end = expr[start:].find(')') + 1 # replace part in expression string return expr.replace(expr[start:start + end], replacement)
[docs]class ComputeGraphBackProp(ComputeGraph): def __init__(self, backend: str, **kwargs): super().__init__(backend, **kwargs) self._vecfield_vars = [] self._vecfield_var_str = "" def _generate_vecfield_var(self, state_vec: np.ndarray, dtype: str): return "" def _generate_func_tail(self, code_gen, vecfield_key: str): code_gen.generate_func_tail(rhs_var=self._vecfield_var_str) def _generate_vecfield(self, code_gen, indices: list, expressions: list, rhs_shapes: list, lhs_vars: list) -> list: for lhs, expr in zip(lhs_vars, expressions): lhs_var = f"delta_{lhs}" code_gen.add_code_line(f"{lhs_var} = {expr}") self._vecfield_vars.append(lhs_var) if len(self._vecfield_vars) > 1: op_dict = self.backend.get_op("concatenate") self._vecfield_var_str = f"{op_dict['call']}([{','.join(v for v in self._vecfield_vars)}], 0)" else: self._vecfield_var_str = self._vecfield_vars.pop() return []