Source code for pyrates.ir.node

# -*- coding: utf-8 -*-
#
#
# PyRates software framework for flexible implementation of neural
# network models and simulations. See also:
# https://github.com/pyrates-neuroscience/PyRates
# 
# Copyright (C) 2017-2018 the original authors (Richard Gast and
# Daniel Rose), the Max Planck Institute for Human Cognitive and Brain
# Sciences ("MPI CBS") and contributors
# 
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# 
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>
# 
# CITATION:
# 
# Richard Gast and Daniel Rose et al. in preparation
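"""Node-level intermediate representation (IR): ``NodeIR``, ``VectorizedNodeIR`` and the node cache helpers
``cache_func`` / ``clear_ir_caches``.
"""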
from typing import Iterator, Callable

from pyrates.ir.abc import AbstractBaseIR
from pyrates.ir.operator_graph import OperatorGraph, VectorizedOperatorGraph
from pyrates.backend.parser import get_unique_label

__author__ = "Daniel Rose"
__status__ = "Development"

# Module-level caches shared by ``cache_func`` and emptied (in place) by ``clear_ir_caches``.
# hash(OperatorGraph) -> node IR created for that operator graph
node_cache = dict()
# hash(OperatorGraph) -> the OperatorGraph itself (stored alongside the node in node_cache)
op_cache = dict()
# label registry handed to ``get_unique_label`` so that newly created nodes get unique labels
node_labels = dict()


def clear_ir_caches():
    """Reset the module-level IR caches ``node_cache``, ``op_cache`` and ``node_labels``.

    The dictionaries are emptied in place (not rebound), so any other reference to them sees the reset as well.
    """
    for cache in (node_cache, op_cache, node_labels):
        cache.clear()
def _rename_value_keys(values: dict, changed_labels: dict) -> None:
    """Rename keys of ``values`` in place following the ``old name -> new name`` mapping in ``changed_labels``.

    All old keys are popped before any new key is written, so swapped or chained renames (e.g. ``a -> b`` together
    with ``b -> a``) do not overwrite each other's values. Old names that are not present in ``values`` are skipped;
    ``None`` or empty ``values`` are left untouched.
    """
    if not values or not changed_labels:
        return
    moved = {new_name: values.pop(old_name) for old_name, new_name in changed_labels.items() if old_name in values}
    values.update(moved)


def cache_func(label: str, operators: dict, values: dict = None, template: str = None, ir_class: Callable = None,
               **kwargs):
    """Return a node IR for ``operators``, reusing (and extending) a cached node with an identical operator graph.

    Nodes are cached under ``hash(OperatorGraph(operators))``. On a cache hit (unless ``vectorize=False`` is passed),
    the cached node is extended by a new ``NodeIR`` carrying ``values``. On a miss, a new node of type ``ir_class``
    is created under a unique label and stored in the cache.

    Parameters
    ----------
    label
        Requested node label. On a cache miss it is made unique via ``get_unique_label``.
    operators
        Operators of the node, passed to ``OperatorGraph``. ``None`` is treated as an empty dict.
    values
        Variable values of the node. On a cache hit, its keys are renamed in place from the incoming operator names
        to the names of the matching cached operators.
    template
        Template reference passed on to the created IR objects.
    ir_class
        Class used to create a new node on a cache miss. Defaults to ``VectorizedNodeIR``.
    kwargs
        ``vectorize`` (default ``True``) is popped and controls whether the cache is consulted. Remaining keyword
        arguments are forwarded to ``node.extend`` (cache hit) or to ``ir_class`` (cache miss).

    Returns
    -------
    tuple
        ``(node, changed_labels, var_ranges)``: the cached or newly created node; a dict mapping incoming operator
        names to cached operator names (empty on a miss); and the variable index ranges. On a hit these are whatever
        ``node.extend`` returns; on a miss they are ``{var: (0, length)}`` taken from ``node.op_graph.var_lengths``.
    """
    if operators is None:
        operators = {}
    if ir_class is None:
        # cached nodes must support ``extend``; VectorizedNodeIR (defined later in this module) does
        ir_class = VectorizedNodeIR

    # compute hash from incoming operators. Different order of operators in input might lead to different hash.
    op_graph = OperatorGraph(operators)
    h = hash(op_graph)

    changed_labels = dict()
    vectorize = kwargs.pop('vectorize', True)

    # Test for a cache entry explicitly instead of wrapping the whole hit path in ``try/except KeyError``: otherwise
    # a KeyError raised while extending the cached node would be misread as a cache miss and silently replace it.
    if vectorize and h in node_cache:
        node = node_cache[h]

        # change operator labels if necessary: map each incoming operator name to the name of an identical
        # (equal hash) cached operator; if several match, the last one wins
        for name, op in operators.items():
            op_hash = hash(op)
            # NOTE(review): this unpacking assumes iterating the cached OperatorGraph yields (name, data) pairs,
            # whereas NodeIR.__iter__ documents iteration over an operator graph as yielding operator labels only
            # -- verify (e.g. ``op_cache[h].nodes(data=True)`` if OperatorGraph is a networkx graph).
            for cached_name, cached_op in op_cache[h]:
                if op_hash == hash(cached_op["operator"]):
                    changed_labels[name] = cached_name

        _rename_value_keys(values, changed_labels)

        # extend cached node
        var_ranges = node.extend(NodeIR(label, operators=op_graph, values=values, template=template), **kwargs)
    else:
        # cache miss (or vectorization disabled): create a new node and cache it
        # NOTE(review): with vectorize=False this still overwrites any cached entry for ``h`` -- confirm intended.
        op_cache[h] = op_graph
        label, labels_tmp = get_unique_label(label, node_labels)
        node_labels.update(labels_tmp)
        node = ir_class(label, operators=op_graph, values=values, template=template, **kwargs)
        node_cache[h] = node
        var_ranges = {key: (0, val) for key, val in node.op_graph.var_lengths.items()}

    return node, changed_labels, var_ranges
class NodeIR(AbstractBaseIR):
    """IR of a single (non-vectorized) node: an operator graph plus the node's variable values."""

    __slots__ = ["_op_graph", "values"]

    def __init__(self, label: str, operators: OperatorGraph, values: dict = None, template: str = None):
        super().__init__(label, template)
        self._op_graph = operators  # operator graph is exposed read-only via the ``op_graph`` property
        self.values = values

    @property
    def op_graph(self):
        """The ``OperatorGraph`` this node was created with."""
        return self._op_graph

    @property
    def operators(self):
        """Shortcut to ``self.op_graph.operators``."""
        graph = self.op_graph
        return graph.operators

    def getitem_from_iterator(self, key: str, key_iter: Iterator[str]):
        """Delegate item lookup to ``self.op_graph.getitem_from_iterator``."""
        graph = self.op_graph
        return graph.getitem_from_iterator(key, key_iter)

    def __iter__(self):
        """Iterate over the operator labels of the underlying operator graph."""
        graph = self.op_graph
        return iter(graph)

    def __hash__(self):
        """Hashing node IRs is deliberately unsupported."""
        raise NotImplementedError
class VectorizedNodeIR(AbstractBaseIR):
    """Vectorized counterpart of ``NodeIR``.

    Holds a ``VectorizedOperatorGraph`` built from the given operators and values, and can be extended by further
    nodes (see ``extend``) whose values are appended to its variable vectors. ``length`` counts how many nodes the
    vector currently represents.
    """

    # NOTE(review): ``_var_lengths`` is declared but never assigned within this class.
    __slots__ = ["op_graph", "length", "_var_lengths"]

    def __init__(self, label: str, operators: OperatorGraph, values: dict = None, template: str = None):
        super().__init__(label, template)
        self.op_graph = VectorizedOperatorGraph(operators, values=values)

        # save current length of this node vector.
        self.length = 1

    def getitem_from_iterator(self, key: str, key_iter: Iterator[str]):
        """Alias for self.op_graph.getitem_from_iterator"""

        return self.op_graph.getitem_from_iterator(key, key_iter)

    def __iter__(self):
        """Return an iterator containing all operator labels in the operator graph."""
        return iter(self.op_graph)

    @property
    def operators(self):
        """Shortcut to ``self.op_graph.operators``."""
        return self.op_graph.operators

    def __hash__(self):
        """Hashing node IRs is deliberately unsupported."""
        raise NotImplementedError

    def extend(self, node: NodeIR) -> dict:
        """ Extend variables vectors by values from one additional node.

        Parameters
        ----------
        node
            A node whose values are used to extend the vector dimension of this vectorized node.

        Returns
        -------
        dict
            Dictionary containing the indices of the appended variable values in the overall vectorized variables
            (as returned by ``self.op_graph.append_values``).
        """

        # add values to respective lists in collapsed node
        var_ranges = self.op_graph.append_values(node.values)

        # one more node is now represented by this vector
        self.length += 1

        return var_ranges

    def __len__(self):
        """Return the number of nodes represented by this vectorized node.

        Returns
        -------
        int
            ``self.length``: 1 after construction, incremented by every call to ``extend``.
        """
        return self.length

    def add_op(self, op_key: str, inputs: dict, output: str, equations: list, variables: dict):
        """Add an operator to the internal operator graph.

        Thin wrapper around ``self.op_graph.add_operator``; nothing is stored on the node itself.

        Parameters
        ----------
        op_key
            Name of operator to be added
        inputs
            dictionary defining input variables of the operator
        output
            string defining name of single output variable
        equations
            list of equations (strings)
        variables
            dictionary describing variables

        Returns
        -------
        None
        """

        # add operator to op_graph
        self.op_graph.add_operator(op_key, inputs=inputs, output=output, equations=equations,
                                   variables=variables)

    def add_op_edge(self, source_op_key: str, target_op_key: str, **attr):
        """ Add an edge between two operators of the internal operator graph (alias of ``self.op_graph.add_edge``).

        Parameters
        ----------
        source_op_key
            Name of the source operator.
        target_op_key
            Name of the target operator.
        attr
            Edge attributes, forwarded unchanged to ``self.op_graph.add_edge``.

        Returns
        -------
        None
        """
        self.op_graph.add_edge(source_op_key, target_op_key, **attr)