# -*- coding: utf-8 -*-
#
#
# PyRates software framework for flexible implementation of neural
# network model_templates and simulations. See also:
# https://github.com/pyrates-neuroscience/PyRates
#
# Copyright (C) 2017-2018 the original authors (Richard Gast and
# Daniel Rose), the Max-Planck-Institute for Human Cognitive Brain
# Sciences ("MPI CBS") and contributors
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>
#
# CITATION:
#
# Richard Gast and Daniel Rose et. al. in preparation
"""This module provides parser classes and functions to parse string-based equations into symbolic representations of
operations.
"""
# external _imports
import math
import typing as tp
from numbers import Number
import numpy as np
# from pyparsing import Literal, CaselessLiteral, Word, Combine, Optional, \
# ZeroOrMore, Forward, nums, alphas, ParserElement
from sympy import Expr, Symbol, lambdify, sympify
# pyrates internal _imports
from pyrates.backend.computegraph import ComputeGraph, ComputeNode
# meta infos
__author__ = "Richard Gast"
__status__ = "development"
###############################################
# expression parsers (lhs/rhs of an equation) #
###############################################
[docs]class Algebra:
def __init__(self, **kwargs) -> None:
"""Instantiates expression parser.
"""
# call super method
super().__init__()
self.algebra = kwargs.pop('algebra', None)
# define algebra
################
# if not self.algebra:
#
# # general symbols
# point = Literal(".")
# comma = Literal(",")
# colon = Literal(":")
# e = CaselessLiteral("E")
# pi = CaselessLiteral("PI")
#
# # parentheses
# par_l = Literal("(")
# par_r = Literal(")").setParseAction(self._push_first)
# idx_l = Literal("[")
# idx_r = Literal("]")
#
# # basic mathematical operations
# plus = Literal("+")
# minus = Literal("-")
# mult = Literal("*")
# div = Literal("/")
# mod = Literal("%")
# dot = Literal("@")
# exp_1 = Literal("^")
# exp_2 = Combine(mult + mult)
# transp = Combine(point + Literal("T"))
# inv = Combine(point + Literal("I"))
#
# # numeric types
# num_float = Combine(Word("-" + nums, nums) +
# Optional(point + Optional(Word(nums))) +
# Optional(e + Word("-" + nums, nums)))
# num_int = Word("-" + nums, nums)
#
# # variables and functions
# name = Word(alphas, alphas + nums + "_$")
# func_name = Combine(name + par_l, adjacent=True)
#
# # math operation groups
# op_add = plus | minus
# op_mult = mult | div | dot | mod
# op_exp = exp_1 | exp_2 | inv | transp
#
# # logical operations
# greater = Literal(">")
# less = Literal("<")
# equal = Combine(Literal("=") + Literal("="))
# unequal = Combine(Literal("!") + Literal("="))
# greater_equal = Combine(Literal(">") + Literal("="))
# less_equal = Combine(Literal("<") + Literal("="))
#
# # logical operations group
# op_logical = greater_equal | less_equal | unequal | equal | less | greater
#
# # pre-allocations
# self.algebra = Forward()
# exponential = Forward()
# index_multiples = Forward()
#
# # basic organization units
# index_start = idx_l.setParseAction(self._push_first)
# index_end = idx_r.setParseAction(self._push_first)
# index_comb = colon.setParseAction(self._push_first)
# arg_comb = comma.setParseAction(self._push_first)
# arg_tuple = par_l + ZeroOrMore(self.algebra.suppress() + Optional(arg_comb)) + par_r
# func_arg = arg_tuple | self.algebra.suppress()
#
# # basic computation unit
# atom = (func_name + Optional(func_arg.suppress()) + ZeroOrMore(arg_comb.suppress() + func_arg.suppress()) +
# par_r.suppress() | name | pi | e | num_float | num_int).setParseAction(self._push_neg_or_first) | \
# (par_l.setParseAction(self._push_last) + self.algebra.suppress() + par_r).setParseAction(self._push_neg)
#
# # apply indexing to atoms
# indexed = (Optional(minus) + atom).setParseAction(self._push_neg) + \
# ZeroOrMore((index_start + index_multiples + index_end))
# index_base = (self.algebra.suppress() | index_comb)
# index_full = index_base + ZeroOrMore((index_comb + index_base)) + ZeroOrMore(index_comb)
# index_multiples << index_full + ZeroOrMore((arg_comb + index_full))
#
# # hierarchical relationships between mathematical and logical operations
# boolean = indexed + Optional((op_logical + indexed).setParseAction(self._push_first))
# exponential << boolean + ZeroOrMore((op_exp + Optional(exponential)).setParseAction(self._push_first))
# factor = exponential + ZeroOrMore((op_mult + exponential).setParseAction(self._push_first))
# expr = factor + ZeroOrMore((op_add + factor).setParseAction(self._push_first))
# self.algebra << expr
def _push_first(self, strg, loc, toks):
"""Push tokens in first-to-last order to expression stack.
"""
self.expr_stack.append(toks[0])
def _push_neg(self, strg, loc, toks):
"""Push negative one multiplier if on first position in toks.
"""
if toks and toks[0] == '-':
self.expr_stack.append('-one')
def _push_neg_or_first(self, strg, loc, toks):
"""Push neg one multipler to expression stack if on first position in toks, else push toks from first-to-last.
"""
if toks and toks[0] == '-':
self.expr_stack.append('-one')
else:
self.expr_stack.append(toks[0])
def _push_last(self, strg, loc, toks):
"""Push tokens in last-to-first order to expression stack.
"""
self.expr_stack.append(toks[-1])
[docs]class ExpressionParser:
"""Class for parsing mathematical expressions from a string format into a symbolic representation of the
mathematical operation expressed by it.
Parameters
----------
expr_str
Mathematical expression in string format.
args
Dictionary containing all variables and functions needed to evaluate the expression.
cg
ComputeGraph instance in which to parse all variables and operations.
See `pyrates.backend.computegraph.ComputeGraph` for a full documentation of its methods and attributes.
parsing_method
Name of the parsing method to use. Valid options: `sympy` for a sympy-based parser, `pyrates` for the
pyrates-internal, pyparsing-based parser.
Attributes
----------
lhs
Boolean, indicates whether expression is left-hand side or right-hand side of an equation
rhs
PyRatesOp for the evaluation of the right-hand side of the equation
vars
Dictionary containing the variables of an expression
expr_str
String representation of the mathematical expression
expr_stack
List representation of the syntax tree of the (parsed) mathematical expression.
parse_func
Function that will be used to deduce an operation stack/tree from a given expression.
cg
Passed ComputeGraph instance into which the expression will be parsed.
"""
_constant_counter = 0
def __init__(self, expr_str: str, args: dict, cg: ComputeGraph, def_shape: tp.Optional[tuple] = None,
parsing_method: str = 'sympy') -> None:
"""Instantiates expression parser.
"""
# main attributes
self.vars = args.copy()
self.expr_str = expr_str
self.expr_stack = []
self.cg = cg
self._def_shape = def_shape
# preprocess the expression string
self.lhs, self.rhs, self._diff_eq, self._assign_type, self.lhs_key = self._preprocess_expr_str(expr_str)
# define the parsing function
if parsing_method == 'sympy':
self.parse_func = sympify
elif parsing_method == 'pyrates':
a = Algebra()
self.parse_func = a.parseString
else:
raise ValueError(f'Invalid identifier for the parsing method: {parsing_method}.')
[docs] def parse_expr(self) -> dict:
"""Parses string-based mathematical expression/equation.
Returns
-------
dict
Variables of the parsed equation.
"""
# extract symbols and operations from equations right-hand side
self.expr_stack = self.parse_func(self.rhs)
if self.expr_stack.is_number:
c = f"dummy_constant_{self._constant_counter}"
expr = f"no_op({c})"
self.vars[c] = {'vtype': 'input',
'value': float(self.expr_stack),
'shape': ()}
self.expr_stack = self.parse_func(expr)
self._constant_counter += 1
if self.expr_stack.is_symbol:
self.expr_stack = self.parse_func(f"no_op({self.rhs})")
# parse rhs into backend
self.rhs = self._parse_stack(self.expr_stack)
# extract symbols and operations from left-hand side
self.expr_stack = self.parse_func(self.lhs)
# parse lhs into backend
self._update_lhs()
return self.vars
def _parse_stack(self, expr: Expr):
if expr.args:
# parse variables as nodes into compute graph
inputs, func_args, old_args = [], [], []
for arg in self._sort_expr_args(expr.args):
if isinstance(arg, Symbol):
# case I: variables/constants
var = self.vars[arg.name]
if isinstance(var, ComputeNode):
# if parsed already, retrieve label from existing variable
label = var.name
else:
# if not parsed already, parse variable into backend
label, var = self.cg.add_var(label=arg.name, def_shape=self._def_shape, **var)
self.vars[arg.name] = var
else:
# case II: mathematical expressions
label, var = self._parse_stack(arg)
if isinstance(var, ComputeNode):
# store input to mathematical expression, if it is not a simple scalar
inputs.append(label)
func_args.append(var.symbol)
old_args.append(arg)
# replace names of old expression arguments with new variable symbols
replacements = {old: new for old, new in zip(old_args, func_args) if old != new}
if replacements:
expr = replace_in_expr(expr, replacements)
# create callable function of the operation
label = expr.func.__name__
try:
v_tmp = self.cg.get_var(func_args[0].name)
op = self.cg.get_op(label, shape=v_tmp.shape, dtype=v_tmp.dtype)
backend_funcs = {label: op['func']}
except (KeyError, IndexError):
backend_funcs = dict()
func = lambdify(func_args, expr=expr, modules=[backend_funcs, "numpy"])
# parse mathematical operation into compute graph
return self.cg.add_op(inputs, label=label, expr=expr, func=func)
else:
# for simple scalar constants, return empty string and dict
return "", dict()
def _preprocess_expr_str(self, expr: str) -> tuple:
"""Turns differential equations into simple algebraic equations using a certain solver scheme and extracts
left-hand side, right-hand side and update type of the equation.
Parameters
----------
expr
Equation in string format.
Returns
-------
tuple
Contains left hand side, right hand side and left hand side update type
"""
# collect equation specifics
############################
# split equation into lhs and rhs and assign type
lhs, rhs, assign_type = split_equation(expr)
if not assign_type:
self.vars['x'] = {'vtype': 'variable', 'value': 0.0, 'dtype': 'float', 'shape': ()}
return self._preprocess_expr_str(f"x = {expr}")
# for the left-hand side, check whether it includes a differential operator
if "d/dt" in lhs:
diff_eq = True
lhs_split = lhs.split('*')
lhs = "".join(lhs_split[1:])
elif "'" in lhs:
diff_eq = True
lhs = lhs.replace("'", "")
elif "d" in lhs and "/dt" in lhs:
diff_eq = True
lhs = lhs.split('/dt')[0]
lhs = lhs.replace("d", "", count=1)
else:
diff_eq = False
# get clean name of lhs
lhs_key = lhs.split('(')[0]
lhs_key = lhs_key.replace(' ', '')
lhs = lhs.replace(' ', '')
# store equation specifics
if diff_eq and assign_type != '=':
raise ValueError(f'Wrong assignment method for equation: {expr}. '
f'A differential equation cannot be combined with an assign type other than `=`.')
return lhs, rhs, diff_eq, assign_type, lhs_key
def _update_lhs(self):
"""Applies update to left-hand side of equation. For differential equations, different solving schemes are
available.
"""
# update left-hand side of equation
###################################
# receive left-hand side variable information
lhs_key = self.lhs_key
if self.expr_stack.is_symbol:
# retrieve variable information
v = self.vars[self.lhs_key]
# create backend state variable if it does not exist already
if not isinstance(v, ComputeNode):
_, v = self.cg.add_var(label=lhs_key, def_shape=self._def_shape, **v)
else:
# parse left-hand side indexing operation
lhs_key, v = self._parse_stack(self.expr_stack)
# create mapping between left-hand side and right-hand side of the equation
self.cg.add_var_update(v.name, self.rhs[0], differential_equation=self._diff_eq)
if lhs_key in self.vars:
self.vars[lhs_key] = v
@staticmethod
def _sort_expr_args(args: tuple) -> list:
# sort arguments from longest to shortest expression
arg_lengths = [len(str(arg)) for arg in args]
args_sorted = [args[idx] for idx in np.argsort(arg_lengths)[::-1]]
# add arguments that need to be treated with priority
args_final = []
while args_sorted:
arg = args_sorted.pop(0)
# check whether the position between arguments should be swapped
prioritize = False
if isinstance(arg, Expr):
for arg_tmp in args_sorted:
if len(arg_tmp.find(arg)):
prioritize = True
break
if prioritize:
idx = args_sorted.index(arg_tmp)
args_final.append(args_sorted.pop(idx))
else:
args_final.append(arg)
return args_final
################################
# helper classes and functions #
################################
[docs]def parse_equations(equations: list, equation_args: dict, cg: ComputeGraph, def_shape: tuple, **kwargs) -> dict:
"""Parses a system (list) of equations into the backend. Transforms differential equations into the appropriate set
of right-hand side evaluations that can be solved later on.
Parameters
----------
equations
Collection of equations that describe the dynamics of the nodes and edges.
equation_args
Key-value pairs of arguments needed for parsing the equations.
cg
ComputeGraph instance that all equations will be parsed into.
def_shape
Default shape of variables that are scalar. Can either be `(1,)` or `()`.
kwargs
Additional keyword arguments to be passed to the backend methods.
Returns
-------
dict
The updated equations args (in-place manipulation of all variables in equation_args happens during
equation parsing).
"""
for eq, scope in equations:
scope = scope.split('/')
# parse arguments
#################
# extract operator variables from equation args
op_args = {}
in_vars = []
update_vars = {}
for key, var in equation_args.copy().items():
key_split = key.split('/')
if key_split[:-1] == scope:
var_name = key_split[-1]
if var_name == 'inputs':
# extract inputs from other variable scopes
for in_key, inp in var.items():
# check whether input variable has been passed properly
if inp not in equation_args:
raise KeyError(inp)
# extract input variable
inp_tmp = equation_args[inp]
# add input variable to operator arguments
op_args[in_key] = inp_tmp
# remember to update the variable entry in the variable collection later
update_vars[inp] = in_key
# add variable to operator inputs
in_vars.append(in_key)
elif var_name not in in_vars:
# change the variable type of input variables that did not receive any extrinsic input
try:
if var['vtype'] == 'input' and var_name not in equation_args['inputs']:
var['vtype'] = 'constant'
except KeyError:
var['vtype'] = 'constant'
except TypeError:
pass
# include variable information in operator arguments
op_args[var_name] = var
# remember to update the variable entry in the variable collection later
update_vars[key] = var_name
# parse equation
################
instantaneous = is_diff_eq(eq) is False
# initialize parser
parser = ExpressionParser(expr_str=eq, args=op_args, cg=cg, def_shape=def_shape, **kwargs)
# parse expression into compute graph
variables = parser.parse_expr()
# store newly created backend variables
for full_key, var_key in update_vars.items():
if full_key in equation_args:
equation_args[full_key] = variables[var_key]
return equation_args
[docs]def split_equation(expr: str) -> tuple:
"""Splits an equation string into a left-hand side, right-and side and an assign type.
Parameters
----------
expr
Equation string. Should contain a left-hand side and a right-hand side, separated by some form of assign symbol.
Returns
-------
tuple
left-hand side string, right-hand side string, assign operation string.
"""
# define assign types and explicit non-assign types
assign_types = ['+=', '-=', '*=', '/=']
not_assign_types = ['<=', '>=', '==', '!=']
lhs, rhs, assign_type, found_assign_type = "", "", "", False
# look for assign types in expression
for assign_type in assign_types:
if assign_type in expr:
# split expression via assign symbol
if f' {assign_type} ' in expr:
lhs, rhs = expr.split(f' {assign_type} ', maxsplit=1)
elif f' {assign_type}' in expr:
lhs, rhs = expr.split(f' {assign_type}', maxsplit=1)
elif f'{assign_type} ' in expr:
lhs, rhs = expr.split(f'{assign_type} ', maxsplit=1)
else:
lhs, rhs = expr.split(assign_type, maxsplit=1)
found_assign_type = True
break
elif '=' in expr:
# assume standard assign
assign_type = '='
assign = True
# check if `=` symbol marks an assign operation or not
for not_assign_type in not_assign_types:
if not_assign_type in expr:
expr_tmp = expr.replace(not_assign_type, '')
if '=' not in expr_tmp:
assign = False
if assign:
# split equation via `=` symbol
if f' = ' in expr:
lhs, rhs = expr.split(f' = ', maxsplit=1)
elif f' {assign_type}' in expr:
lhs, rhs = expr.split(f' =', maxsplit=1)
elif f'{assign_type} ' in expr:
lhs, rhs = expr.split(f'= ', maxsplit=1)
else:
lhs, rhs = expr.split(f"=", maxsplit=1)
found_assign_type = True
break
if not found_assign_type:
return lhs, rhs, False
return lhs, rhs, assign_type
[docs]def replace(eq: str, term: str, replacement: str, rhs_only: tp.Optional[bool] = False,
lhs_only: tp.Optional[bool] = False) -> str:
"""Replaces a term in an equation with a replacement term (save replacement).
Parameters
----------
eq
Equation that includes the term.
term
Term that should be replaced.
replacement
Replacement for all occurences of term.
rhs_only
If True, replacements will only be performed in right-hand side of the equation.
lhs_only
IF True, replacements will only be performed in left-hand side of the equation.
Returns
-------
str
The updated equation.
"""
# define follow-up operations/signs that are allowed to follow directly after term in eq
allowed_follow_ops = '-+=*/^<>=!.%@[]():, '
# replace every proper appearance of term in eq with replacement
################################################################
eq_new = ""
idx = eq.find(term)
# go through all appearances of term in eq
while idx != -1:
# get idx of sign that follows after term
idx_follow_op = idx+len(term)
# if it is an allowed sign, replace term, else not
replaced = False
if ((idx_follow_op < len(eq) and eq[idx_follow_op] in allowed_follow_ops) and
(idx == 0 or eq[idx-1] in allowed_follow_ops)) or \
(idx_follow_op == len(eq) and eq[idx-1] in allowed_follow_ops):
eq_part = eq[:idx]
if (rhs_only and "=" in eq_part) or (lhs_only and "=" not in eq_part) or (not rhs_only and not lhs_only):
eq_new += f"{eq_part}{replacement}"
replaced = True
if not replaced:
eq_new += f"{eq[:idx_follow_op]}"
# jump to next appearance of term in eq
eq = eq[idx_follow_op:]
idx = eq.find(term)
# add rest of eq to new eq
eq_new += eq
return eq_new
[docs]def is_diff_eq(eq: str) -> bool:
"""Checks whether `eq` is a differential equation or not.
Parameters
----------
eq
Equation string.
Returns
-------
bool
True, if `eq` is a differential equation.
"""
lhs, rhs, _ = split_equation(eq)
if "d/dt" in lhs:
de = True
elif "'" in lhs:
de = True
elif "dt_test" in replace(rhs, "step_size", "dt_test"):
de = True
else:
de = False
return de
[docs]def var_in_expression(var: str, expr: str) -> bool:
if var == expr:
return True
# define follow-up operations/signs that are allowed to follow directly after term in eq
follow_ops = '+-=*/^<>=!.%@[]():, '
start = 0
n = len(expr)
exists_in_expr = False
while start < n:
try:
# find variable in string and extract the follow-up sign
idx = expr[start:].find(var)
if idx == -1:
break
idx += start
idx_next = idx + len(var)
next_sign = expr[idx_next]
prev_sign = expr[idx - 1]
# decide whether variable actually exists in expression
exists_in_expr = ((idx_next < n and next_sign in follow_ops) and
(idx == 0 or prev_sign in follow_ops)) or (idx_next == n and prev_sign in follow_ops)
if exists_in_expr:
break
start = idx_next
except IndexError:
break
return exists_in_expr
[docs]def get_unique_label(label: str, labels: dict) -> tp.Tuple[str, dict]:
if label in labels:
n = labels[label]
if n == 0:
label_new = f"{label}_num1"
else:
label_new = f"{label}_num{n+1}"
labels[label] += 1
else:
label_new = label
labels[label] = 0
return label_new, labels
[docs]def replace_in_expr(expr: Expr, replacements: dict):
expr = expr.subs(replacements, simultaneous=True)
for arg_old in replacements:
if expr.count(arg_old):
expr = expr.replace(arg_old, replacements[arg_old])
return expr