Source code for pyrates.ir.operator_graph

# -*- coding: utf-8 -*-
#
#
# PyRates software framework for flexible implementation of neural 
# network model_templates and simulations. See also:
# https://github.com/pyrates-neuroscience/PyRates
# 
# Copyright (C) 2017-2018 the original authors (Richard Gast and 
# Daniel Rose), the Max-Planck-Institute for Human Cognitive Brain 
# Sciences ("MPI CBS") and contributors
# 
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# 
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>
# 
# CITATION:
# 
# Richard Gast and Daniel Rose et al. in preparation
"""
"""
from typing import Iterator, Dict, List

import numpy as _np
from networkx import DiGraph, find_cycle, NetworkXNoCycle

from pyrates.backend import PyRatesException
from pyrates.ir.operator import OperatorIR

__author__ = "Daniel Rose"
__status__ = "Development"


class OperatorGraph(DiGraph):
    """Intermediate representation for nodes and edges.

    Every graph node stands for one operator and carries the attributes
    ``operator``, ``inputs`` and ``label``. A directed edge ``a -> b`` is
    added whenever the output variable of operator ``a`` is listed among the
    inputs of operator ``b``.
    """

    def __init__(self, operators: Dict[str, OperatorIR] = None, template: str = ""):
        """Build the graph from a mapping of operator keys to operators.

        Parameters
        ----------
        operators
            Mapping from operator key to operator. Each operator must provide
            the attributes `inputs` (iterable of variable names) and `output`
            (a variable name). `None` is treated as an empty mapping.
        template
            Currently not used by this constructor.

        Raises
        ------
        PyRatesException
            If the operators form a cyclic dependency graph.
        """
        super().__init__()

        if operators is None:
            operators = {}

        # compute hash from incoming operators. Different order of operators in input might lead to different hash.
        self._h = hash(tuple(operators.values()))

        # collect all information about output variables of operators
        #############################################################
        all_outputs = {}  # type: Dict[str, List[str]]
        for key, operator in operators.items():
            # add operator as node to local operator_graph
            inputs = {var: dict(sources=set(), reduce_dim=True) for var in operator.inputs}
            self.add_node(key, operator=operator, inputs=inputs, label=key)

            # register this operator as a producer of its output variable
            all_outputs.setdefault(operator.output, []).append(key)

        # link outputs to inputs
        ########################
        for label, data in self.nodes(data=True):
            for in_var in data["operator"].inputs:
                # an input without any producing operator keeps an empty `sources` set and gets no incoming edge
                for predecessor in all_outputs.get(in_var, ()):
                    # `sources` is a set: unordered, but nothing can be added twice
                    data["inputs"][in_var]["sources"].add(predecessor)
                    self.add_edge(predecessor, label)

        # check for cycle in operator graph
        ###################################
        try:
            find_cycle(self)
        except NetworkXNoCycle:
            pass
        else:
            raise PyRatesException("Found cyclic operator graph. Cycles are not allowed for operators within one node "
                                   "or edge.")

    def __hash__(self):
        """Return the hash that was computed from the operators at construction time."""
        return self._h

    def getitem_from_iterator(self, key: str, key_iter: Iterator[str]):
        """
        Helper function for Python magic __getitem__. Accepts an iterator that yields string keys. If `key_iter`
        is exhausted, the operator stored under `key` is returned. If it yields one more key, the properties of the
        variable with that name belonging to the operator are returned.

        Parameters
        ----------
        key
            Key of the operator (node) to look up.
        key_iter
            Iterator over remaining keys; at most one further item is consumed.

        Returns
        -------
        item
            operator or variable properties
        """
        try:
            var = next(key_iter)
        except StopIteration:
            # no variable specified, so we return an operator
            item = self.nodes[key]["operator"]
        else:
            # variable specified, so we return variable properties instead
            item = self.nodes[key]["operator"].variables[var]

        return item

    def __iter__(self):
        """Return an iterator over `(label, data)` pairs of all operators in the operator graph.

        `data` is the attribute dictionary of the respective node. Note that this differs from a plain `DiGraph`,
        which iterates over node keys only.
        """
        return iter(self.nodes(data=True))

    def operators(self, get_ops=False, get_vals=False):
        """Alias for self.nodes, with optional generators over node contents.

        Parameters
        ----------
        get_ops
            If True, yield the operator stored on each node.
        get_vals
            If True, yield the `values` attribute of each node. The constructor of this class does not set that
            attribute, so it has to be attached to the nodes by other code before this option can be used.

        Returns
        -------
        Generator of `(label, operator, values)`, `(label, operator)` or `(label, values)` tuples, depending on the
        flags; the plain node view `self.nodes` if both flags are False.
        """
        if get_ops and get_vals:
            return ((data["label"], data["operator"], data["values"]) for _, data in self.nodes(data=True))
        elif get_ops:
            # yield the operator object itself; the node key would only duplicate the label
            return ((data["label"], data["operator"]) for _, data in self.nodes(data=True))
        elif get_vals:
            return ((data["label"], data["values"]) for _, data in self.nodes(data=True))
        else:
            return self.nodes
class VectorizedOperatorGraph(DiGraph):
    """Alternate version of `OperatorGraph` that is produced during vectorization. Contents of this version are not
    particularly protected and the instance is not cached."""

    def __init__(self, op_graph: OperatorGraph = None, values: dict = None):
        """Create the graph, optionally from an existing operator graph.

        Parameters
        ----------
        op_graph
            Graph to take nodes and edges from. Nodes that carry an `operator` attribute are unpacked into the
            attributes `inputs`, `equations`, `variables` and `output`; nodes without it are copied as they are.
            If `None`, an empty graph is created.
        values
            Dictionary of the form `{op_key: {var_key: value}}`. It must hold an entry for every node of `op_graph`
            that carries an `operator` attribute.
        """
        super().__init__()

        if op_graph is None:
            return

        # iterate the node view explicitly instead of the graph itself, so that the loop receives (key, data)
        # pairs no matter how (or whether) the source graph overrides `__iter__`
        for op_key, data in op_graph.nodes(data=True):
            try:
                op = data["operator"]
            except KeyError:
                # need to add this for the copy capability
                self.add_operator(op_key, **data)
            else:
                self.add_operator(op_key,
                                  inputs=data["inputs"],
                                  equations=list(op.equations),
                                  variables=op.variables.to_dict(),
                                  output=op.output)

                # retrieve values from value dict and pass them into variable dictionary of operator
                op_values = values[op_key]
                op_vars = self.operators[op_key]["variables"]

                for var_key, value in op_values.items():
                    var = op_vars[var_key]
                    if var["vtype"] == "input_variable":
                        var["value"] = value
                        var["vtype"] = "input"
                    else:
                        # all other values are stored as lists, so that further values can be appended later
                        var["value"] = value if isinstance(value, list) else [value]

        self.add_edges_from(op_graph.edges)

    def add_operator(self, *args, **kwargs):
        """Alias for `self.add_node`"""
        self.add_node(*args, **kwargs)

    def getitem_from_iterator(self, key: str, key_iter: Iterator[str]):
        """
        Helper function for Python magic __getitem__. Accepts an iterator that yields string keys. If `key_iter`
        is exhausted, the attribute dictionary of the operator stored under `key` is returned. If it yields one
        more key, the properties of the variable with that name belonging to the operator are returned.

        Parameters
        ----------
        key
            Key of the operator (node) to look up.
        key_iter
            Iterator over remaining keys; at most one further item is consumed.

        Returns
        -------
        item
            operator or variable properties
        """
        try:
            var = next(key_iter)
        except StopIteration:
            # no variable specified, so we return an operator
            item = self.nodes[key]
        else:
            # variable specified, so we return variable properties instead
            item = self.nodes[key]["variables"][var]

        return item

    @property
    def operators(self):
        """Alias for self.nodes"""
        return self.nodes

    def append_values(self, value_dict: dict) -> dict:
        """Append value along vector dimension of operators.

        Parameters
        ----------
        value_dict
            Dictionary of the form `{op_key: {var_key: value}}` with the values to append.

        Returns
        -------
        dict
            Dictionary containing the indices of the appended variable values in the overall vectorized variables.
            Keys are `(op_key, var_key)` tuples, values are `(start, stop)` tuples taken from the first dimension
            of the variable shape before and after appending.

        Raises
        ------
        ValueError
            If a multi-element value does not match the shape of the entries that are already stored.
        """
        var_ranges = {}

        for op_key, variables_updates in value_dict.items():
            original_variables = self.nodes[op_key]["variables"]

            for var_key, value in variables_updates.items():
                var = original_variables[var_key]
                shape = _np.shape(value)
                shape_sum = _np.sum(shape)

                if shape_sum > 1:
                    # the new value becomes one entry along the first axis, so it has to match the shape of the
                    # existing entries (everything after the first axis)
                    internal_shape = _np.shape(var["value"])[1:]
                    if shape != internal_shape:
                        raise ValueError(f"Inconsistent dimensions of variable {var}. Dimension of value to add: "
                                         f"{shape} Internal dimension of vectorized value array: "
                                         f"{internal_shape}")
                    var["value"].append(value)
                elif shape_sum == 0:
                    # case, if shape == ()
                    var["value"].append(value)
                else:
                    var["value"].extend(value)

                # also recompute shape
                old_shape = var["shape"]
                var["shape"] = _np.shape(var["value"])
                var_ranges[(op_key, var_key)] = (old_shape[0], var['shape'][0])

        return var_ranges

    @property
    def var_lengths(self) -> dict:
        """Dictionary containing the length of each variable of each operator.

        Keys are `(op_key, var_key)` tuples. The length is the first entry of the variable's `shape`; an empty
        shape counts as length 1.
        """
        var_lengths = {}

        for op in self.operators:
            for var, data in self.nodes[op]['variables'].items():
                var_lengths[(op, var)] = data['shape'][0] if data['shape'] else 1

        return var_lengths