# -*- coding: utf-8 -*-
# PyRates software framework for flexible implementation of neural
# network models and simulations. See also:
# https://github.com/pyrates-neuroscience/PyRates
# Copyright (C) 2017-2018 the original authors (Richard Gast and
# Daniel Rose), the Max-Planck-Institute for Human Cognitive Brain
# Sciences ("MPI CBS") and contributors
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>
# Richard Gast and Daniel Rose et al. in preparation
"""This module provides parser classes and functions to parse string-based equations into symbolic representations of
the mathematical operations they describe.
"""
# external imports
import math
import typing as tp
from numbers import Number
import numpy as np
# from pyparsing import Literal, CaselessLiteral, Word, Combine, Optional, \
# ZeroOrMore, Forward, nums, alphas, ParserElement
from sympy import Expr, Symbol, lambdify, sympify
# pyrates-internal imports
from pyrates.backend.computegraph import ComputeGraph, ComputeNode
# meta infos: module-level author and development-status metadata
__author__ = "Richard Gast"
__status__ = "development"
# expression parsers (lhs/rhs of an equation) #
class Algebra:
    """Holder for a pyparsing-style grammar that turns expression strings into a postfix expression stack.

    The default pyparsing grammar that used to be assembled in `__init__` is disabled in this module (the `pyparsing`
    import at the top of the file is commented out). A grammar can still be supplied via the `algebra` keyword
    argument; its parse actions are expected to call the `_push_*` methods of this instance.

    NOTE(review): the bodies of the `_push_*` parse actions were missing from this copy of the file and have been
    restored from their docstrings (classic pyparsing "push first / unary minus" pattern) — confirm against upstream.
    """

    # stack token that represents a multiplication with negative one (unary minus)
    # NOTE(review): token name follows the PyRates stack convention as far as can be told — confirm upstream
    _NEG_ONE = '-one'

    def __init__(self, **kwargs) -> None:
        """Instantiates expression parser.

        Parameters
        ----------
        kwargs
            Keyword arguments. Only `algebra` is used: a pre-built grammar object exposing a `parseString` method.
            Defaults to `None`, in which case `parseString` cannot be used.
        """

        # call super method
        super().__init__()

        # postfix stack that the `_push_*` parse actions write to
        self.expr_stack = []

        # user-supplied grammar; the built-in pyparsing grammar is disabled (see the version history for its
        # definition)
        self.algebra = kwargs.pop('algebra', None)

    def parseString(self, expr_str: str, **kwargs) -> list:
        """Parses `expr_str` with the configured grammar and returns the resulting expression stack.

        The camel-case name mirrors the pyparsing API, because `ExpressionParser` uses this method as its parse
        function for `parsing_method='pyrates'`.

        NOTE(review): `ExpressionParser.parse_expr` expects a sympy expression (it accesses `.is_number`), so the
        list returned here cannot be consumed there as-is.

        Parameters
        ----------
        expr_str
            Expression string to parse.
        kwargs
            Forwarded to the `parseString` method of the configured grammar.

        Returns
        -------
        list
            Tokens pushed by the grammar's parse actions while parsing `expr_str`.

        Raises
        ------
        NotImplementedError
            If no grammar was passed via the `algebra` keyword argument.
        """
        if self.algebra is None:
            raise NotImplementedError("The pyparsing-based `pyrates` parser is not available: no grammar was passed "
                                      "via the `algebra` keyword argument. Use parsing_method='sympy' instead.")
        self.expr_stack = []
        self.algebra.parseString(expr_str, **kwargs)
        return self.expr_stack

    def _push_first(self, strg, loc, toks):
        """Push tokens in first-to-last order to expression stack.

        Parse action (pyparsing signature `(string, location, tokens)`): appends the first matched token.
        """
        self.expr_stack.append(toks[0])

    def _push_neg(self, strg, loc, toks):
        """Push negative one multiplier if on first position in toks.
        """
        if toks and toks[0] == '-':
            self.expr_stack.append(self._NEG_ONE)

    def _push_neg_or_first(self, strg, loc, toks):
        """Push neg one multipler to expression stack if on first position in toks, else push toks from first-to-last.
        """
        if toks and toks[0] == '-':
            self.expr_stack.append(self._NEG_ONE)
        elif toks:
            self.expr_stack.append(toks[0])

    def _push_last(self, strg, loc, toks):
        """Push tokens in last-to-first order to expression stack.

        Parse action: appends the last matched token.
        """
        self.expr_stack.append(toks[-1])
class ExpressionParser:
    """Class for parsing mathematical expressions from a string format into a symbolic representation of the
    mathematical operation expressed by it.

    Parameters
    ----------
    expr_str
        Mathematical expression in string format.
    args
        Dictionary containing all variables and functions needed to evaluate the expression.
    cg
        ComputeGraph instance in which to parse all variables and operations.
        See `pyrates.backend.computegraph.ComputeGraph` for a full documentation of its methods and attributes.
    def_shape
        Default shape that is forwarded to `cg.add_var` whenever a variable is added to the compute graph.
    parsing_method
        Name of the parsing method to use. Valid options: `sympy` for a sympy-based parser, `pyrates` for the
        pyrates-internal, pyparsing-based parser.

    Attributes
    ----------
    lhs
        Left-hand side of the equation (whitespace removed, differential operator stripped).
    rhs
        Right-hand side of the equation. Holds the string until `parse_expr` is called; afterwards it holds the value
        returned by `cg.add_op` for the right-hand side operation (its first element is used as update source).
    lhs_key
        Name of the left-hand side variable (text before an opening parenthesis, whitespace removed).
    vars
        Dictionary containing the variables of an expression. Entries are replaced by compute graph nodes once they
        have been added to `cg`.
    expr_str
        String representation of the mathematical expression.
    expr_stack
        Result of `parse_func` for the side of the equation that is currently being processed.
    parse_func
        Function that will be used to deduce an operation stack/tree from a given expression.
    cg
        Passed ComputeGraph instance into which the expression will be parsed.
    """

    # class-level counter used to name constant right-hand sides (`dummy_constant_<n>`)
    _constant_counter = 0

    def __init__(self, expr_str: str, args: dict, cg: ComputeGraph, def_shape: tp.Optional[tuple] = None,
                 parsing_method: str = 'sympy') -> None:
        """Instantiates expression parser.

        Raises
        ------
        ValueError
            If `parsing_method` is neither `sympy` nor `pyrates`.
        """

        # main attributes
        self.vars = args.copy()
        self.expr_str = expr_str
        self.expr_stack = []
        self.cg = cg
        self._def_shape = def_shape

        # preprocess the expression string
        self.lhs, self.rhs, self._diff_eq, self._assign_type, self.lhs_key = self._preprocess_expr_str(expr_str)

        # define the parsing function
        # NOTE: `sympify` evaluates its input via `eval`; only trusted equation strings should be parsed
        if parsing_method == 'sympy':
            self.parse_func = sympify
        elif parsing_method == 'pyrates':
            a = Algebra()
            self.parse_func = a.parseString
        else:
            raise ValueError(f'Invalid identifier for the parsing method: {parsing_method}.')

    def parse_expr(self) -> dict:
        """Parses string-based mathematical expression/equation.

        The right-hand side is added to the compute graph first, then the left-hand side is linked to it.

        Returns
        -------
        dict
            Variables of the parsed equation.
        """

        # extract symbols and operations from equations right-hand side
        self.expr_stack = self.parse_func(self.rhs)
        if self.expr_stack.is_number:
            # wrap numeric right-hand sides into a no-op on a dedicated input variable
            c = f"dummy_constant_{self._constant_counter}"
            expr = f"no_op({c})"
            self.vars[c] = {'vtype': 'input',
                            'value': float(self.expr_stack),
                            'shape': ()}
            self.expr_stack = self.parse_func(expr)
            # increment the class-level counter (`self._constant_counter += 1` would only shadow it per instance)
            type(self)._constant_counter += 1
        if self.expr_stack.is_symbol:
            # a bare symbol has no arguments to parse, so wrap it into a no-op as well
            self.expr_stack = self.parse_func(f"no_op({self.rhs})")

        # parse rhs into backend
        self.rhs = self._parse_stack(self.expr_stack)

        # extract symbols and operations from left-hand side
        self.expr_stack = self.parse_func(self.lhs)

        # parse lhs into backend
        # NOTE(review): this call was missing from this copy of the file; `_update_lhs` is not called anywhere else
        self._update_lhs()

        return self.vars

    def _parse_stack(self, expr: Expr):
        """Recursively adds a parsed expression and all of its arguments to the compute graph.

        Returns
        -------
        tuple
            Result of `cg.add_op` for expressions with arguments, or `("", dict())` for argument-free expressions
            such as numeric constants.
        """

        if expr.args:

            # parse variables as nodes into compute graph
            inputs, func_args, old_args = [], [], []
            for arg in self._sort_expr_args(expr.args):

                if isinstance(arg, Symbol):

                    # case I: variables/constants
                    var = self.vars[arg.name]
                    if isinstance(var, ComputeNode):
                        # if parsed already, retrieve label from existing variable
                        label = var.name
                    else:
                        # if not parsed already, parse variable into backend
                        label, var = self.cg.add_var(label=arg.name, def_shape=self._def_shape, **var)
                        self.vars[arg.name] = var

                else:

                    # case II: mathematical expressions
                    label, var = self._parse_stack(arg)

                if isinstance(var, ComputeNode):
                    # store input to mathematical expression, if it is not a simple scalar
                    # NOTE(review): these statements were missing from this copy of the file and were restored from
                    # how `inputs`, `func_args` and `old_args` are used below — confirm against upstream.
                    inputs.append(label)
                    func_args.append(Symbol(label))
                    old_args.append(arg)

            # replace names of old expression arguments with new variable symbols
            replacements = {old: new for old, new in zip(old_args, func_args) if old != new}
            if replacements:
                expr = replace_in_expr(expr, replacements)

            # create callable function of the operation
            label = expr.func.__name__
            try:
                # use a backend implementation of the operation if the compute graph provides one
                v_tmp = self.cg.get_var(func_args[0].name)
                op = self.cg.get_op(label, shape=v_tmp.shape, dtype=v_tmp.dtype)
                backend_funcs = {label: op['func']}
            except (KeyError, IndexError):
                backend_funcs = dict()
            func = lambdify(func_args, expr=expr, modules=[backend_funcs, "numpy"])

            # parse mathematical operation into compute graph
            return self.cg.add_op(inputs, label=label, expr=expr, func=func)

        # for simple scalar constants, return empty string and dict
        return "", dict()

    def _preprocess_expr_str(self, expr: str) -> tuple:
        """Turns differential equations into simple algebraic equations using a certain solver scheme and extracts
        left-hand side, right-hand side and update type of the equation.

        Expressions without an assignment are treated as `x = <expr>` with a scalar float variable `x`.

        Parameters
        ----------
        expr
            Equation in string format.

        Returns
        -------
        tuple
            Contains left hand side, right hand side, differential-equation flag, left hand side update type and the
            name of the left hand side variable.

        Raises
        ------
        ValueError
            If a differential equation uses an assign type other than `=`.
        """

        # collect equation specifics
        ############################

        # split equation into lhs and rhs and assign type
        lhs, rhs, assign_type = split_equation(expr)
        if not assign_type:
            self.vars['x'] = {'vtype': 'variable', 'value': 0.0, 'dtype': 'float', 'shape': ()}
            return self._preprocess_expr_str(f"x = {expr}")

        # for the left-hand side, check whether it includes a differential operator
        if "d/dt" in lhs:
            # `d/dt * x` -> `x`
            diff_eq = True
            lhs_split = lhs.split('*')
            lhs = "".join(lhs_split[1:])
        elif "'" in lhs:
            # `x'` -> `x`
            diff_eq = True
            lhs = lhs.replace("'", "")
        elif "d" in lhs and "/dt" in lhs:
            # `dx/dt` -> `x`
            diff_eq = True
            lhs = lhs.split('/dt')[0]
            # `count` is passed positionally: `str.replace` only accepts it as a keyword since Python 3.13
            lhs = lhs.replace("d", "", 1)
        else:
            diff_eq = False

        # get clean name of lhs
        lhs_key = lhs.split('(')[0]
        lhs_key = lhs_key.replace(' ', '')
        lhs = lhs.replace(' ', '')

        # store equation specifics
        if diff_eq and assign_type != '=':
            raise ValueError(f'Wrong assignment method for equation: {expr}. '
                             f'A differential equation cannot be combined with an assign type other than `=`.')

        return lhs, rhs, diff_eq, assign_type, lhs_key

    def _update_lhs(self):
        """Applies update to left-hand side of equation. Links the left-hand side variable to the right-hand side
        operation in the compute graph (flagged as differential equation where applicable).
        """

        # update left-hand side of equation
        ###################################

        # receive left-hand side variable information
        lhs_key = self.lhs_key
        if self.expr_stack.is_symbol:

            # retrieve variable information
            v = self.vars[self.lhs_key]

            # create backend state variable if it does not exist already
            if not isinstance(v, ComputeNode):
                _, v = self.cg.add_var(label=lhs_key, def_shape=self._def_shape, **v)

        else:

            # parse left-hand side indexing operation
            lhs_key, v = self._parse_stack(self.expr_stack)

        # create mapping between left-hand side and right-hand side of the equation
        self.cg.add_var_update(v.name, self.rhs[0], differential_equation=self._diff_eq)
        if lhs_key in self.vars:
            self.vars[lhs_key] = v

    @staticmethod
    def _sort_expr_args(args: tuple) -> list:
        """Orders expression arguments for parsing: longest string representation first, and an argument that is
        contained in a later argument is moved behind that argument.

        NOTE(review): the re-ordering step after `prioritize` was missing from this copy of the file and has been
        reconstructed from the surrounding comments — confirm against upstream.
        """

        # sort arguments from longest to shortest expression
        arg_lengths = [len(str(arg)) for arg in args]
        args_sorted = [args[idx] for idx in np.argsort(arg_lengths)[::-1]]

        # add arguments that need to be treated with priority
        args_final = []
        while args_sorted:

            arg = args_sorted.pop(0)

            # check whether the position between arguments should be swapped
            prioritize = False
            arg_tmp = None
            if isinstance(arg, Expr):
                for arg_tmp in args_sorted:
                    # identical arguments are skipped, otherwise they would swap places forever
                    if arg_tmp != arg and len(arg_tmp.find(arg)):
                        prioritize = True
                        break

            if prioritize:
                # defer `arg` until after the argument that contains it
                idx = args_sorted.index(arg_tmp)
                args_sorted.insert(idx + 1, arg)
            else:
                args_final.append(arg)

        return args_final
# helper classes and functions #
def parse_equations(equations: list, equation_args: dict, cg: ComputeGraph, def_shape: tuple, **kwargs) -> dict:
    """Parses a system (list) of equations into the backend. Transforms differential equations into the appropriate set
    of right-hand side evaluations that can be solved later on.

    Parameters
    ----------
    equations
        Collection of equations that describe the dynamics of the nodes and edges, given as `(equation, scope)` pairs
        where `scope` is a `/`-separated path.
    equation_args
        Key-value pairs of arguments needed for parsing the equations. Keys are `/`-separated paths; a variable
        belongs to an equation if its path without the last element equals the equation scope. A key whose last
        element is `inputs` maps operator-local input names to keys of other entries of `equation_args`.
    cg
        ComputeGraph instance that all equations will be parsed into.
    def_shape
        Default shape of variables that are scalar. Can either be `(1,)` or `()`.
    kwargs
        Additional keyword arguments to be passed to `ExpressionParser` (e.g. `parsing_method`).

    Returns
    -------
    dict
        The updated equations args (in-place manipulation of all variables in equation_args happens during
        equation parsing).

    Raises
    ------
    KeyError
        If an input refers to a key that is missing from `equation_args`.
    """

    for eq, scope in equations:

        scope = scope.split('/')

        # extract operator variables from equation args
        op_args = {}
        update_vars = {}
        for key, var in equation_args.items():

            key_split = key.split('/')
            if key_split[:-1] != scope:
                continue
            var_name = key_split[-1]

            if var_name == 'inputs':

                # extract inputs from other variable scopes
                for in_key, inp in var.items():

                    # check whether input variable has been passed properly
                    if inp not in equation_args:
                        raise KeyError(inp)

                    # add input variable to operator arguments
                    op_args[in_key] = equation_args[inp]

                    # remember to update the variable entry in the variable collection later
                    update_vars[inp] = in_key

            else:

                # change the variable type of input variables that did not receive any extrinsic input
                try:
                    if var['vtype'] == 'input' and var_name not in equation_args['inputs']:
                        var['vtype'] = 'constant'
                except KeyError:
                    # no `vtype` entry or no global `inputs` collection
                    var['vtype'] = 'constant'
                except TypeError:
                    # variable is not a dictionary (e.g. already parsed into the compute graph): leave it untouched
                    pass

                # include variable information in operator arguments
                op_args[var_name] = var

                # remember to update the variable entry in the variable collection later
                update_vars[key] = var_name

        # parse equation into the compute graph
        parser = ExpressionParser(expr_str=eq, args=op_args, cg=cg, def_shape=def_shape, **kwargs)
        variables = parser.parse_expr()

        # store newly created backend variables
        for full_key, var_key in update_vars.items():
            if full_key in equation_args:
                equation_args[full_key] = variables[var_key]

    return equation_args
def split_equation(expr: str) -> tuple:
    """Splits an equation string into a left-hand side, right-hand side and an assign type.

    Compound assignments (`+=`, `-=`, `*=`, `/=`) take precedence over a plain `=`. A plain `=` is not treated as an
    assignment if it only occurs as part of a comparison operator (`<=`, `>=`, `==`, `!=`).

    Parameters
    ----------
    expr
        Equation string. Should contain a left-hand side and a right-hand side, separated by some form of assign symbol.

    Returns
    -------
    tuple
        left-hand side string, right-hand side string, assign operation string. If no assign operation was found,
        `("", "", False)` is returned.
    """

    # define assign types and explicit non-assign types
    assign_types = ['+=', '-=', '*=', '/=']
    not_assign_types = ['<=', '>=', '==', '!=']

    # look for compound assign types in expression
    for assign_type in assign_types:
        if assign_type in expr:
            lhs, rhs = _split_at_operator(expr, assign_type)
            return lhs, rhs, assign_type

    if '=' not in expr:
        return "", "", False

    # check if `=` symbol marks an assign operation or only belongs to a comparison
    for not_assign_type in not_assign_types:
        if not_assign_type in expr and '=' not in expr.replace(not_assign_type, ''):
            return "", "", False

    # standard assign
    lhs, rhs = _split_at_operator(expr, '=')
    return lhs, rhs, '='


def _split_at_operator(expr: str, op: str) -> tuple:
    """Splits `expr` once at `op`, consuming a surrounding space on either side if present (requires `op in expr`)."""
    for sep in (f' {op} ', f' {op}', f'{op} ', op):
        if sep in expr:
            lhs, rhs = expr.split(sep, maxsplit=1)
            return lhs, rhs
def replace(eq: str, term: str, replacement: str, rhs_only: tp.Optional[bool] = False,
            lhs_only: tp.Optional[bool] = False) -> str:
    """Replaces a term in an equation with a replacement term (safe replacement).

    Only stand-alone occurrences of `term` are replaced, i.e. occurrences delimited on both sides by the start/end of
    `eq` or by one of the characters ``-+=*/^<>!.%@[]():,`` or a space. This prevents e.g. replacing `a` inside `ab`.

    Parameters
    ----------
    eq
        Equation that includes the term.
    term
        Term that should be replaced. An empty term leaves `eq` unchanged.
    replacement
        Replacement for all occurrences of term.
    rhs_only
        If True, replacements will only be performed in right-hand side of the equation, i.e. for occurrences that
        are preceded by a `=` character.
    lhs_only
        If True, replacements will only be performed in left-hand side of the equation, i.e. for occurrences that are
        not preceded by a `=` character.

    Returns
    -------
    str
        The updated equation.
    """

    if not term:
        # `str.find("")` matches at every position, so the search below would never advance
        return eq

    # define operations/signs that are allowed to directly precede or follow term in eq
    allowed_follow_ops = '-+=*/^<>=!.%@[]():, '
    replace_everywhere = not rhs_only and not lhs_only

    pieces = []
    copied_until = 0
    idx = eq.find(term)

    # go through all appearances of term in eq
    while idx != -1:

        idx_follow_op = idx + len(term)

        # delimiters are checked against the full equation (not a left-over slice of it)
        delimited_before = idx == 0 or eq[idx - 1] in allowed_follow_ops
        delimited_after = idx_follow_op == len(eq) or eq[idx_follow_op] in allowed_follow_ops
        if delimited_before and delimited_after:
            on_rhs = "=" in eq[:idx]
            if replace_everywhere or (rhs_only and on_rhs) or (lhs_only and not on_rhs):
                pieces.append(eq[copied_until:idx])
                pieces.append(replacement)
                copied_until = idx_follow_op

        # jump to next appearance of term in eq
        idx = eq.find(term, idx_follow_op)

    # add rest of eq to new eq
    pieces.append(eq[copied_until:])
    return "".join(pieces)
def is_diff_eq(eq: str) -> bool:
    """Checks whether `eq` is a differential equation or not.

    An equation counts as differential if its left-hand side uses one of the notations `d/dt * x`, `x'` or `dx/dt`,
    or if its right-hand side uses `step_size` as a stand-alone term.

    Parameters
    ----------
    eq
        Equation string.

    Returns
    -------
    bool
        True, if `eq` is a differential equation.
    """

    lhs, rhs, _ = split_equation(eq)

    # same left-hand side notations as accepted by `ExpressionParser._preprocess_expr_str`
    if "d/dt" in lhs or "'" in lhs or ("d" in lhs and "/dt" in lhs):
        return True

    # right-hand sides that use the integration step size also describe differential equations
    return "dt_test" in replace(rhs, "step_size", "dt_test")
def var_in_expression(var: str, expr: str) -> bool:
    """Checks whether the variable `var` appears as a stand-alone term in the expression `expr`.

    An occurrence only counts if it is delimited on both sides by the start/end of `expr` or by one of the characters
    ``+-=*/^<>!.%@[]():,`` or a space, so that e.g. `a` is not found in `ab` or `v` in `v_e`.

    Parameters
    ----------
    var
        Variable name to look for.
    expr
        Expression string to search.

    Returns
    -------
    bool
        True, if `var` is part of `expr` (or identical to it).
    """

    if var == expr:
        return True
    if not var:
        # an empty name matches at every position and would never advance the search
        return False

    # define follow-up operations/signs that are allowed to directly precede or follow var in expr
    follow_ops = '+-=*/^<>=!.%@[]():, '
    n = len(expr)

    idx = expr.find(var)
    while idx != -1:
        idx_next = idx + len(var)
        # bounds are checked explicitly, so a variable at the very end of `expr` is detected as well
        prev_ok = idx == 0 or expr[idx - 1] in follow_ops
        next_ok = idx_next == n or expr[idx_next] in follow_ops
        if prev_ok and next_ok:
            return True
        idx = expr.find(var, idx_next)

    return False
def get_unique_label(label: str, labels: dict) -> tp.Tuple[str, dict]:
    """Returns a label that is unique with respect to `labels` and registers the request.

    Parameters
    ----------
    label
        Requested label.
    labels
        Maps each requested base label to the number of times it was requested before its latest request
        (first request: 0). Updated in place.

    Returns
    -------
    tp.Tuple[str, dict]
        `label` itself on its first request, otherwise `label` suffixed with `_num<k>` where `k` counts the repeated
        requests; and the updated `labels` dictionary.
    """
    if label not in labels:
        # first request: keep the label as-is and start counting
        labels[label] = 0
        return label, labels

    repeats = labels[label] + 1
    labels[label] = repeats
    return f"{label}_num{repeats}", labels
def replace_in_expr(expr: Expr, replacements: dict):
    """Replaces sub-expressions of `expr` according to the `old -> new` mapping `replacements`.

    All replacements are first applied simultaneously via `subs`; any `old` term that still occurs in the result is
    then replaced structurally via `replace`.

    Parameters
    ----------
    expr
        Symbolic (sympy) expression.
    replacements
        Mapping from sub-expressions to their replacements.

    Returns
    -------
    Expr
        The expression with all replacements applied.
    """
    result = expr.subs(replacements, simultaneous=True)
    for old, new in replacements.items():
        if result.count(old):
            result = result.replace(old, new)
    return result