def replace_inline_expressions_through_defining_expressions(self, definitions, inline_expressions):
    # type: (list(ASTOdeEquation), list(ASTInlineExpression)) -> list(ASTOdeEquation)
    """
    Replaces symbols from `inline_expressions` in `definitions` with corresponding defining expressions from
    `inline_expressions`.

    For each inline expression, every match of its name in a definition's right-hand side is textually replaced
    by ``(<defining expression>)``. The result is re-parsed, receives the inline expression's scope, and is
    processed by the symbol table visitor. Nodes whose source position is flagged as "added" then get the
    inline expression's source position.

    :param definitions: A sorted list of ODE equations whose right-hand sides are rewritten (**updated in-place**).
    :param inline_expressions: A sorted list of inline expressions which must be replaced in `definitions`.
    :return: The `definitions` list. Expressions in `definitions` don't depend on inline_expressions from
        `inline_expressions`.
    """
    for m in inline_expressions:
        source_position = m.get_source_position()
        # Pattern and replacement text depend only on `m`, so build them once instead of once per target.
        matcher = re.compile(self._variable_matching_template.format(m.get_variable_name()))
        replacement = "(" + str(m.get_expression()) + ")"

        def log_set_source_position(node):
            if node.get_source_position().is_added_source_position():
                node.set_source_position(source_position)

        for target in definitions:
            # A callable replacement inserts the text literally; a plain string would have backslashes
            # interpreted as escapes / group references by the regex engine.
            target_definition = matcher.sub(lambda _match: replacement, str(target.get_rhs()))
            target.rhs = ModelParser.parse_expression(target_definition)
            target.update_scope(m.get_scope())
            target.accept(ASTSymbolTableVisitor())
            target.accept(ASTHigherOrderVisitor(visit_funcs=log_set_source_position))

    return definitions
def replace_convolve_calls_with_buffers_(self, neuron, equations_block, kernel_buffers):
    r"""
    Substitute every ``convolve(kernel[']^n, spike_input_port)`` call in `equations_block` by the matching
    buffer variable, e.g. ``g_E__X__spikes_exc[__d]^n`` for a kernel named ``g_E`` and a spike input port
    named ``spikes_exc``.

    The two ``convolve`` arguments may appear in either order: if the first argument resolves to a spike
    input buffer, it is treated as the port and the second argument as the kernel. Calls on a delta kernel
    are replaced by the numeric literal ``0`` instead of a buffer variable.

    :param neuron: neuron used to look up kernels by name
    :param equations_block: block whose simple expressions are rewritten in place
    :param kernel_buffers: not used by this method
    """
    def substitute_convolve(expr=None):
        if not expr.is_function_call() or expr.get_function_call().get_name() != "convolve":
            return

        call = expr.get_function_call()
        args = call.get_args()
        kernel_arg, port_arg = args[0], args[1]
        first_symbol = args[0].get_scope().resolve_to_symbol(args[0].get_variable().name, SymbolKind.VARIABLE)
        if first_symbol.block_type == BlockType.INPUT_BUFFER_SPIKE:
            # written as convolve(spike_input_port, kernel): swap the roles
            kernel_arg, port_arg = port_arg, kernel_arg

        kernel_variable = kernel_arg.get_variable()
        port_variable = port_arg.get_variable()
        kernel = neuron.get_kernel_by_name(kernel_variable.get_name())
        expr.set_function_call(None)
        buffer_name = construct_kernel_X_spike_buf_name(
            kernel_variable.get_name(), port_variable, kernel_variable.get_differential_order() - 1)

        if is_delta_kernel(kernel):
            # delta kernel are treated separately, and should be kept out of the dynamics (computing derivates etc.) --> set to zero
            expr.set_variable(None)
            expr.set_numeric_literal(0)
            return

        buffer_variable = ASTVariable(buffer_name)
        buffer_variable.set_source_position(expr.get_source_position())
        expr.set_variable(buffer_variable)

    equations_block.accept(ASTHigherOrderVisitor(
        lambda node: substitute_convolve(node) if isinstance(node, ASTSimpleExpression) else True))
def replace_variable_names_in_expressions(cls, neuron: ASTNeuron, solver_dicts: List[dict]) -> None:
    """
    Replace all occurrences of variables names in NESTML format (e.g. ``g_ex$''``) with the ode-toolbox
    formatted variable name (e.g. ``g_ex__DOLLAR__d__d``).

    Variables aliasing convolutions should already have been covered by replace_convolution_aliasing_inlines().

    A variable is renamed only if ``cls.variable_in_solver()`` reports its processed name as present in
    `solver_dicts`; the renamed variable gets differential order 0.

    :param neuron: the neuron whose AST is rewritten in place
    :param solver_dicts: solver results used to decide which variables to rename
    """
    def replace_var(_expr=None):
        if isinstance(_expr, ASTSimpleExpression) and _expr.is_variable():
            var = _expr.get_variable()
            # The processed name is both the lookup key and the new name, so compute it only once.
            processed_name = cls.to_ode_toolbox_processed_name(var.get_complete_name())
            if cls.variable_in_solver(processed_name, solver_dicts):
                ast_variable = ASTVariable(processed_name, differential_order=0)
                ast_variable.set_source_position(var.get_source_position())
                _expr.set_variable(ast_variable)
        elif isinstance(_expr, ASTVariable):
            processed_name = cls.to_ode_toolbox_processed_name(_expr.get_complete_name())
            if cls.variable_in_solver(processed_name, solver_dicts):
                _expr.set_name(processed_name)
                _expr.set_differential_order(0)

    neuron.accept(ASTHigherOrderVisitor(replace_var))
def replace_rhs_variable(cls, expr: ASTExpression, variable_name_to_replace: str, kernel_var: ASTVariable,
                         spike_buf: ASTInputPort):
    """
    Rewrite references to a kernel variable inside the definition of kernel dynamics.

    Each ``ASTSimpleExpression`` (exact type) in `expr` whose variable is named `variable_name_to_replace`
    gets a new variable whose name is built by ``cls.construct_kernel_X_spike_buf_name`` from
    `kernel_var` and `spike_buf`. The new variable keeps the original differential order and source position.

    :param expr: expression in which to replace the variables (modified in place)
    :param variable_name_to_replace: variable name to replace in the expression
    :param kernel_var: kernel variable instance
    :param spike_buf: input port instance
    """
    def swap_in_buffer_variable(node):
        matches = (type(node) is ASTSimpleExpression
                   and node.is_variable()
                   and node.get_variable().get_name() == variable_name_to_replace)
        if not matches:
            return

        old_variable = node.get_variable()
        order = old_variable.get_differential_order()
        buffer_name = cls.construct_kernel_X_spike_buf_name(
            kernel_var.get_name(), spike_buf, order - 1, diff_order_symbol="'")
        replacement = ASTVariable(buffer_name, order)
        replacement.set_source_position(old_variable.get_source_position())
        node.set_variable(replacement)

    expr.accept(ASTHigherOrderVisitor(visit_funcs=swap_in_buffer_variable))
def make_functions_self_contained(self, functions):
    # type: (list(ASTOdeFunction)) -> list(ASTOdeFunction)
    """
    Make function definitions self contained, i.e. without any references to other functions from `functions`.

    TODO: it should be a method inside of the ASTOdeFunction
    TODO by KP: this should be done by means of a visitor

    For every `source` function, each match of its name in every function's defining expression (including
    its own) is textually replaced by ``(<source expression>)``. The result is re-parsed. Nodes whose source
    position is flagged as "added" then get the source function's source position.

    :param functions: A sorted list with entries ASTOdeFunction (**updated in-place**).
    :return: The `functions` list. Defining expressions don't depend on each other.
    """
    for source in functions:
        source_position = source.get_source_position()
        # The pattern depends only on `source`, so compile it once instead of once per target.
        matcher = re.compile(self._variable_matching_template.format(source.get_variable_name()))

        def log_set_source_position(node):
            if node.get_source_position().is_added_source_position():
                node.set_source_position(source_position)

        for target in functions:
            # Re-read the source expression per target: `source` is itself one of the targets and its
            # expression may already have been rewritten earlier in this inner loop.
            replacement = "(" + str(source.get_expression()) + ")"
            # A callable replacement inserts the text literally (no backslash/group-reference processing).
            target_definition = matcher.sub(lambda _match: replacement, str(target.get_expression()))
            target.expression = ModelParser.parse_expression(target_definition)
            target.expression.accept(ASTHigherOrderVisitor(visit_funcs=log_set_source_position))

    return functions
def replace_convolution_aliasing_inlines(cls, neuron: ASTNeuron) -> None:
    """
    Replace all occurrences of kernel names (e.g. ``I_dend`` and ``I_dend'`` for a definition involving a
    second-order kernel ``inline kernel I_dend = convolve(kern_name, spike_buf)``) with the ODE-toolbox
    generated variable ``kern_name__X__spike_buf``.

    Only inline expressions whose defining expression is a single variable containing ``__X__`` are used as
    aliases. For an occurrence of differential order ``n``, ``__d`` is appended ``n`` times to the new name and
    the differential order is reset to 0.

    :param neuron: the neuron whose AST is rewritten in place; if it has no equations block, nothing is done
    """
    def replace_var(_expr, replace_var_name: str, replace_with_var_name: str):
        if isinstance(_expr, ASTSimpleExpression) and _expr.is_variable():
            var = _expr.get_variable()
            if var.get_name() == replace_var_name:
                ast_variable = ASTVariable(replace_with_var_name + '__d' * var.get_differential_order(),
                                           differential_order=0)
                ast_variable.set_source_position(var.get_source_position())
                _expr.set_variable(ast_variable)

        elif isinstance(_expr, ASTVariable):
            var = _expr
            if var.get_name() == replace_var_name:
                var.set_name(replace_with_var_name + '__d' * var.get_differential_order())
                var.set_differential_order(0)

    equations_block = neuron.get_equations_block()
    if equations_block is None:
        return

    for decl in equations_block.get_declarations():
        if not isinstance(decl, ASTInlineExpression):
            continue

        expr = decl.get_expression()
        # Guard with is_variable(): a simple expression such as a function call may also contain "__X__"
        # in its text but has no variable to take the name from.
        if not (isinstance(expr, ASTSimpleExpression) and expr.is_variable() and '__X__' in str(expr)):
            continue

        alias_name = decl.get_variable_name()
        replace_with_var_name = expr.get_variable().get_name()
        neuron.accept(ASTHigherOrderVisitor(
            lambda x, _old=alias_name, _new=replace_with_var_name: replace_var(x, _old, _new)))
def get_alias_symbols(cls, ast):
    """
    Collect the symbols of all functions (aliases) referenced by variables inside the handed over node.

    Variables whose complete name contains a prime (``'``) are skipped. Every other variable is resolved in
    the node's scope, and the resolved symbol is kept if it is flagged ``is_function``.

    :param ast: a single meta_model node
    :type ast: AST_
    :return: the alias variable symbols, in visiting order (one entry per matching variable occurrence)
    :rtype: list(VariableSymbol)
    """
    from pynestml.visitors.ast_higher_order_visitor import ASTHigherOrderVisitor

    variables = []
    ast.accept(ASTHigherOrderVisitor(
        visit_funcs=lambda node: variables.append(node) if isinstance(node, ASTVariable) else None))

    alias_symbols = []
    for variable in variables:
        if '\'' in variable.get_complete_name():
            continue
        candidate = ast.get_scope().resolve_to_symbol(variable.get_complete_name(), SymbolKind.VARIABLE)
        if candidate is None or not candidate.is_function:
            continue
        alias_symbols.append(candidate)
    return alias_symbols
def check_co_co(cls, node):
    """
    Check the usage of the ``delta`` function in the handed over neuron.

    For every ``delta(...)`` call appearing as a simple expression, an error is logged if:
    - it does not have exactly one argument that is a simple expression holding the variable ``t``, or
    - its parent node is not an ``ASTKernel``.

    :param node: a single neuron instance.
    :type node: ASTNeuron
    """
    def has_single_t_argument(args):
        if len(args) != 1:
            return False
        only_arg = args[0]
        return (type(only_arg) is ASTSimpleExpression
                and only_arg.get_variable() is not None
                and only_arg.get_variable().name == "t")

    def check_simple_delta(expr=None):
        if not expr.is_function_call() or expr.get_function_call().get_name() != "delta":
            return

        delta_call = expr.get_function_call()
        parent = node.get_parent(expr)

        if not has_single_t_argument(delta_call.get_args()):
            code, message = Messages.delta_function_one_arg(delta_call)
            Logger.log_message(code=code, message=message, error_position=expr.get_source_position(),
                               log_level=LoggingLevel.ERROR)

        if type(parent) is not ASTKernel:
            code, message = Messages.delta_function_cannot_be_mixed()
            Logger.log_message(code=code, message=message, error_position=expr.get_source_position(),
                               log_level=LoggingLevel.ERROR)

    node.accept(ASTHigherOrderVisitor(
        lambda candidate: check_simple_delta(candidate) if isinstance(candidate, ASTSimpleExpression) else True))
def replace_variable_names_in_expressions(self, neuron, solver_dicts):
    """
    Replace all occurrences of variables names in NESTML format (e.g. ``g_ex$''``) with the ode-toolbox
    formatted variable name (e.g. ``g_ex__DOLLAR__d__d``).

    A variable is renamed only if ``variable_in_solver()`` reports its processed name as present in
    `solver_dicts`; the renamed variable gets differential order 0.

    :param neuron: the neuron whose AST is rewritten in place
    :param solver_dicts: solver results used to decide which variables to rename
    """
    def replace_var(_expr=None):
        if isinstance(_expr, ASTSimpleExpression) and _expr.is_variable():
            var = _expr.get_variable()
            # The processed name is both the lookup key and the new name, so compute it only once.
            processed_name = to_ode_toolbox_processed_name(var.get_complete_name())
            if variable_in_solver(processed_name, solver_dicts):
                ast_variable = ASTVariable(processed_name, differential_order=0)
                ast_variable.set_source_position(var.get_source_position())
                _expr.set_variable(ast_variable)
        elif isinstance(_expr, ASTVariable):
            processed_name = to_ode_toolbox_processed_name(_expr.get_complete_name())
            if variable_in_solver(processed_name, solver_dicts):
                _expr.set_name(processed_name)
                _expr.set_differential_order(0)

    neuron.accept(ASTHigherOrderVisitor(replace_var))
def make_inline_expressions_self_contained(self, inline_expressions: List[ASTInlineExpression]
                                           ) -> List[ASTInlineExpression]:
    """
    Make inline_expressions self contained, i.e. without any references to other inline_expressions.

    TODO: it should be a method inside of the ASTInlineExpression
    TODO: this should be done by means of a visitor

    For every `source` inline expression, each match of its name in every inline expression (including its
    own) is textually replaced by ``(<source expression>)``. The result is re-parsed, receives the source's
    scope, and is processed by the symbol table visitor.

    :param inline_expressions: A sorted list with entries ASTInlineExpression (**updated in-place**).
    :return: A list with ASTInlineExpressions. Defining expressions don't depend on each other.
    """
    for source in inline_expressions:
        source_position = source.get_source_position()
        # The pattern depends only on `source`, so compile it once instead of once per target.
        matcher = re.compile(self._variable_matching_template.format(source.get_variable_name()))

        def log_set_source_position(node):
            if node.get_source_position().is_added_source_position():
                node.set_source_position(source_position)

        for target in inline_expressions:
            # Re-read the source expression per target: `source` is itself one of the targets and its
            # expression may already have been rewritten earlier in this inner loop.
            replacement = "(" + str(source.get_expression()) + ")"
            # A callable replacement inserts the text literally (no backslash/group-reference processing).
            target_definition = matcher.sub(lambda _match: replacement, str(target.get_expression()))
            target.expression = ModelParser.parse_expression(target_definition)
            target.expression.update_scope(source.get_scope())
            target.expression.accept(ASTSymbolTableVisitor())
            target.expression.accept(ASTHigherOrderVisitor(visit_funcs=log_set_source_position))

    return inline_expressions
def get_inline_expression_symbols(cls, ast: ASTNode) -> List[VariableSymbol]:
    """
    Collect the inline expression symbols referenced by variables in the handed over AST node.

    Variables whose complete name contains a prime (``'``) are skipped. Every other variable is resolved in
    the node's scope, and the resolved symbol is kept if it is flagged ``is_inline_expression``.

    :param ast: a single AST node
    :return: the inline expression variable symbols, in visiting order (one entry per matching occurrence)
    """
    from pynestml.visitors.ast_higher_order_visitor import ASTHigherOrderVisitor
    from pynestml.meta_model.ast_variable import ASTVariable

    found_variables = []

    def remember_variable(candidate):
        if isinstance(candidate, ASTVariable):
            found_variables.append(candidate)

    ast.accept(ASTHigherOrderVisitor(visit_funcs=remember_variable))

    inline_symbols = []
    for variable in found_variables:
        complete_name = variable.get_complete_name()
        if '\'' in complete_name:
            continue
        resolved = ast.get_scope().resolve_to_symbol(complete_name, SymbolKind.VARIABLE)
        if resolved is not None and resolved.is_inline_expression:
            inline_symbols.append(resolved)
    return inline_symbols
def replace_inline_expressions_through_defining_expressions(cls, definitions: Sequence[ASTOdeEquation],
                                                            inline_expressions: Sequence[ASTInlineExpression]
                                                            ) -> Sequence[ASTOdeEquation]:
    """
    Replaces symbols from `inline_expressions` in `definitions` with corresponding defining expressions from
    `inline_expressions`.

    Each match of an inline expression's name in a definition's right-hand side is replaced by
    ``(<defining expression>)``. The result is re-parsed, receives the inline expression's scope, and is
    processed by the symbol table visitor.

    :param definitions: A list of ODE definitions (**updated in-place**).
    :param inline_expressions: A list of inline expression definitions.
    :return: A list of updated ODE definitions (same as the ``definitions`` parameter).
    """
    for m in inline_expressions:
        source_position = m.get_source_position()
        # Pattern and replacement text depend only on `m`, so build them once instead of once per target.
        matcher = re.compile(cls._variable_matching_template.format(m.get_variable_name()))
        replacement = "(" + str(m.get_expression()) + ")"

        def log_set_source_position(node):
            if node.get_source_position().is_added_source_position():
                node.set_source_position(source_position)

        for target in definitions:
            # A callable replacement inserts the text literally; a plain string would have backslashes
            # interpreted as escapes / group references by the regex engine.
            target_definition = matcher.sub(lambda _match: replacement, str(target.get_rhs()))
            target.rhs = ModelParser.parse_expression(target_definition)
            target.update_scope(m.get_scope())
            target.accept(ASTSymbolTableVisitor())
            target.accept(ASTHigherOrderVisitor(visit_funcs=log_set_source_position))

    return definitions
def add_suffix_to_variable_name(cls, var_name: str, astnode: ASTNode, suffix: str, scope=None):
    """
    Add a suffix to the variable with the given name, recursively throughout `astnode`.

    Variables named ``t`` and variables whose name already contains `suffix` are left unchanged. The second
    case stops a variable from being suffixed twice when it is reached both directly and through its
    enclosing simple expression.

    :param var_name: name of the variable to rename
    :param astnode: root of the subtree to process (modified in place)
    :param suffix: string appended to the variable name
    :param scope: currently unused; kept for backward compatibility with existing callers
    """
    def replace_var(_expr=None):
        if isinstance(_expr, ASTSimpleExpression) and _expr.is_variable():
            var = _expr.get_variable()
        elif isinstance(_expr, ASTVariable):
            var = _expr
        else:
            return

        name = var.get_name()
        if suffix not in name and name != "t" and name == var_name:
            var.set_name(name + suffix)

    astnode.accept(ASTHigherOrderVisitor(replace_var))
def replace_with_external_variable(cls, var_name, node: ASTNode, suffix, new_scope, alternate_name=None):
    """
    Replace all occurrences of variables (``ASTVariable``s) (e.g. ``post_trace'``) in the node with
    ``ASTExternalVariable``s, indicating that they are moved to the postsynaptic neuron.

    :param var_name: name of the variable whose occurrences are replaced
    :param node: root AST node to traverse (modified in place)
    :param suffix: appended to the variable name to form the external variable's name
    :param new_scope: passed to ``update_alt_scope()`` of each new external variable
    :param alternate_name: if truthy, set as the alternate name of each new external variable
    :raises Exception: if a matching ``ASTVariable`` occurs in a context that is not an assignment,
        a variable simple expression or a declaration
    """
    def replace_var(_expr=None):
        if isinstance(_expr, ASTSimpleExpression) and _expr.is_variable():
            var = _expr.get_variable()
        elif isinstance(_expr, ASTVariable):
            var = _expr
        else:
            return

        if var.get_name() != var_name:
            return

        ast_ext_var = ASTExternalVariable(var.get_name() + suffix,
                                          differential_order=var.get_differential_order(),
                                          source_position=var.get_source_position())
        if alternate_name:
            ast_ext_var.set_alternate_name(alternate_name)

        ast_ext_var.update_alt_scope(new_scope)
        from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor
        ast_ext_var.accept(ASTSymbolTableVisitor())

        # Look the parent up once, *before* mutating the tree: after e.g. an assignment's lhs has been
        # replaced, the old node is detached and a second lookup would no longer find it.
        parent = node.get_parent(_expr)

        if isinstance(_expr, ASTSimpleExpression):
            Logger.log_message(None, -1, "ASTSimpleExpression replacement made (var = " + str(
                ast_ext_var.get_name()) + ") in expression: " + str(parent), None, LoggingLevel.INFO)
            _expr.set_variable(ast_ext_var)
            return

        # here, _expr is an ASTVariable
        if isinstance(parent, ASTAssignment):
            parent.lhs = ast_ext_var
            Logger.log_message(None, -1, "ASTVariable replacement made in expression: " + str(parent),
                               None, LoggingLevel.INFO)
        elif isinstance(parent, ASTSimpleExpression) and parent.is_variable():
            parent.set_variable(ast_ext_var)
        elif isinstance(parent, ASTDeclaration):
            # variable could occur on the left-hand side; ignore. Only replace if it occurs on the right-hand side.
            pass
        else:
            error_message = "Error: unhandled use of variable " + var_name + " in expression " + str(_expr)
            Logger.log_message(None, -1, error_message, None, LoggingLevel.INFO)
            raise Exception(error_message)

    node.accept(ASTHigherOrderVisitor(lambda x: replace_var(x)))
def replace_rhs_variable(expr, variable_name_to_replace, kernel_var, spike_buf):
    """
    Rewrite references to a kernel variable inside `expr` (modified in place).

    Each ``ASTSimpleExpression`` (exact type) whose variable is named `variable_name_to_replace` gets a new
    variable named by ``construct_kernel_X_spike_buf_name`` from `kernel_var` and `spike_buf`. The new
    variable keeps the original differential order and source position.
    """
    def rename_occurrence(candidate):
        if type(candidate) is not ASTSimpleExpression:
            return
        if not candidate.is_variable():
            return
        current = candidate.get_variable()
        if current.get_name() != variable_name_to_replace:
            return

        order = current.get_differential_order()
        renamed = ASTVariable(
            construct_kernel_X_spike_buf_name(kernel_var.get_name(), spike_buf, order - 1, diff_order_symbol="'"),
            order)
        renamed.set_source_position(current.get_source_position())
        candidate.set_variable(renamed)

    expr.accept(ASTHigherOrderVisitor(visit_funcs=rename_occurrence))
def add_suffix_to_variable_name(cls, var_name: str, astnode: ASTNode, suffix: str):
    """
    Append `suffix` to every variable named `var_name` found anywhere below `astnode` (modified in place).

    The variable ``t`` and names that already contain `suffix` are never changed.
    """
    def variable_of(candidate):
        if isinstance(candidate, ASTSimpleExpression) and candidate.is_variable():
            return candidate.get_variable()
        if isinstance(candidate, ASTVariable):
            return candidate
        return None

    def rename(candidate=None):
        variable = variable_of(candidate)
        if variable is None:
            return
        current_name = variable.get_name()
        if suffix in current_name or current_name == "t" or current_name != var_name:
            return
        variable.set_name(current_name + suffix)

    astnode.accept(ASTHigherOrderVisitor(lambda x: rename(x)))
def get_function_calls(cls, ast_node, function_list):
    """
    Collect every function call below `ast_node` whose name appears in `function_list`.

    :param ast_node: a single meta_model node
    :type ast_node: ASTNode
    :param function_list: a list of function names
    :type function_list: list(str)
    :return: the matching function calls, in visiting order
    :rtype: list(ASTFunctionCall)
    """
    from pynestml.visitors.ast_higher_order_visitor import ASTHigherOrderVisitor
    from pynestml.meta_model.ast_function_call import ASTFunctionCall

    matching_calls = []

    def collect(candidate):
        if isinstance(candidate, ASTFunctionCall) and candidate.get_name() in function_list:
            matching_calls.append(candidate)
            return None
        return True

    ast_node.accept(ASTHigherOrderVisitor(visit_funcs=collect))
    return matching_calls
def get_cond_sum_function_calls(cls, node):
    """
    Collect all function calls below `node` whose name equals ``PredefinedFunctions.COND_SUM``.

    :param node: a single meta_model node
    :type node: AST_
    :return: the matching function calls, in visiting order
    :rtype: list(ASTFunctionCall)
    """
    cond_sum_calls = []

    def collect_cond_sum(candidate):
        if not isinstance(candidate, ASTFunctionCall):
            return
        if candidate.get_name() == PredefinedFunctions.COND_SUM:
            cond_sum_calls.append(candidate)

    node.accept(ASTHigherOrderVisitor(collect_cond_sum))
    return cond_sum_calls
def collect_variable_names_in_expression(expr):
    """
    Collect the ``ASTVariable`` objects that occur in `expr`.

    A variable is collected both when it is visited directly and when it is visited as the variable held by
    an ``ASTSimpleExpression``. The same variable may therefore be collected more than once.

    :param expr: AST node to search
    :return: list of ``ASTVariable`` instances in visiting order
    """
    def variable_held_by(candidate):
        if isinstance(candidate, ASTSimpleExpression) and candidate.is_variable():
            return candidate.get_variable()
        if isinstance(candidate, ASTVariable):
            return candidate
        return None

    found = []

    def collect(candidate=None):
        variable = variable_held_by(candidate)
        if variable:
            found.append(variable)

    expr.accept(ASTHigherOrderVisitor(collect))
    return found
def get_function_call(cls, ast, function_name):
    """
    Collect all calls of the function named `function_name` below the given node.

    :param ast: a single node
    :type ast: ast_node
    :param function_name: the name of the function
    :type function_name: str
    :return: the matching function calls, in visiting order
    :rtype: list(ASTFunctionCall)
    """
    from pynestml.visitors.ast_higher_order_visitor import ASTHigherOrderVisitor

    calls = []

    def collect(candidate):
        if not isinstance(candidate, ASTFunctionCall):
            return
        if candidate.get_name() != function_name:
            return
        calls.append(candidate)

    ast.accept(ASTHigherOrderVisitor(collect, []))
    return calls
def get_all(cls, ast, node_type):
    """
    Collect every node of type `node_type` (``isinstance`` semantics) in the tree spanned by `ast`.

    :param ast: a single meta_model node
    :type ast: AST_
    :param node_type: the type
    :type node_type: AST_
    :return: all matching nodes, in visiting order
    :rtype: list(AST_)
    """
    from pynestml.visitors.ast_higher_order_visitor import ASTHigherOrderVisitor

    matches = []
    ast.accept(ASTHigherOrderVisitor(
        visit_funcs=lambda candidate: matches.append(candidate) if isinstance(candidate, node_type) else None))
    return matches
def collect_variable_names_in_expression(cls, expr: ASTNode) -> List[ASTVariable]:
    """
    Collect the ``ASTVariable`` objects that occur in `expr`.

    A variable is collected both when it is visited directly and when it is visited as the variable held by
    an ``ASTSimpleExpression``. The same variable may therefore be collected more than once.

    :param expr: expression to collect the variables from
    :return: a list of variables, in visiting order
    """
    variables = []

    def collect(candidate=None):
        found = None
        if isinstance(candidate, ASTVariable):
            found = candidate
        elif isinstance(candidate, ASTSimpleExpression) and candidate.is_variable():
            found = candidate.get_variable()
        if found:
            variables.append(found)

    expr.accept(ASTHigherOrderVisitor(lambda x: collect(x)))
    return variables
def replace_function_call_through_first_argument(cls, ast, function_name_to_replace):
    """
    Replaces all occurrences of the handed over function call by its first argument.

    Only function calls held by an ``ASTSimpleExpression`` are considered. The simple expression's function
    call is removed and replaced by the variable of the call's first argument.

    :param ast: a single ast node (modified in place)
    :type ast: AST_
    :param function_name_to_replace: the function call to replace. Either an ``ASTFunctionCall``, matched
        with ``==`` as before, or a function name given as ``str``, which matches every call of that name.
    :type function_name_to_replace: ASTFunctionCall or str
    """
    if isinstance(function_name_to_replace, str):
        # Despite the parameter name, the original only compared against an ASTFunctionCall; a plain
        # name could never match. Support names explicitly.
        def matches(call):
            return call.get_name() == function_name_to_replace
    else:
        def matches(call):
            return call == function_name_to_replace

    def replace_call_with_first_argument(_expr=None):
        if _expr.is_function_call() and matches(_expr.get_function_call()):
            first_arg = _expr.get_function_call().get_args()[0].get_variable()
            _expr.set_function_call(None)
            _expr.set_variable(first_arg)

    ast.accept(ASTHigherOrderVisitor(
        lambda x: replace_call_with_first_argument(x) if isinstance(x, ASTSimpleExpression) else True))
def add_suffix_to_variable_names(cls, astnode: Union[ASTNode, List], suffix: str):
    """
    Add a suffix to all variable names, recursively throughout `astnode`.

    Variables named ``t`` and variables whose name already contains `suffix` are left unchanged. The second
    case stops a variable from being suffixed twice when it is reached both directly and through its
    enclosing simple expression.

    :param astnode: an AST node, or an iterable of AST nodes that are each processed in turn (modified in place)
    :param suffix: string appended to the variable names
    """
    if not isinstance(astnode, ASTNode):
        for node in astnode:
            ASTUtils.add_suffix_to_variable_names(node, suffix)
        return

    def replace_var(_expr=None):
        if isinstance(_expr, ASTSimpleExpression) and _expr.is_variable():
            var = _expr.get_variable()
        elif isinstance(_expr, ASTVariable):
            var = _expr
        else:
            return

        name = var.get_name()
        if suffix not in name and name != "t":
            var.set_name(name + suffix)

    astnode.accept(ASTHigherOrderVisitor(replace_var))
def parse_while_stmt(cls, string):
    # type: (str) -> ASTWhileStmt
    """
    Parse `string` with the ``whileStmt`` grammar rule and build the corresponding AST node.

    ``log_set_added_source_position`` is applied to the resulting tree through an ``ASTHigherOrderVisitor``.

    :param string: NESTML source text of a while statement
    :return: the built while-statement node
    """
    builder, parser = tokenize(string)
    while_stmt = builder.visit(parser.whileStmt())
    while_stmt.accept(ASTHigherOrderVisitor(visit_funcs=log_set_added_source_position))
    return while_stmt
def parse_variable(cls, string):
    # type: (str) -> ASTVariable
    """
    Parse `string` with the ``variable`` grammar rule and build the corresponding AST node.

    ``log_set_added_source_position`` is applied to the resulting tree through an ``ASTHigherOrderVisitor``.

    :param string: NESTML source text of a variable reference
    :return: the built variable node
    """
    builder, parser = tokenize(string)
    variable = builder.visit(parser.variable())
    variable.accept(ASTHigherOrderVisitor(visit_funcs=log_set_added_source_position))
    return variable
def parse_update_block(cls, string):
    # type: (str) -> ASTUpdateBlock
    """
    Parse `string` with the ``updateBlock`` grammar rule and build the corresponding AST node.

    ``log_set_added_source_position`` is applied to the resulting tree through an ``ASTHigherOrderVisitor``.

    :param string: NESTML source text of an update block
    :return: the built update-block node
    """
    builder, parser = tokenize(string)
    update_block = builder.visit(parser.updateBlock())
    update_block.accept(ASTHigherOrderVisitor(visit_funcs=log_set_added_source_position))
    return update_block
def parse_unary_operator(cls, string):
    # type: (str) -> ASTUnaryOperator
    """
    Parse `string` with the ``unaryOperator`` grammar rule and build the corresponding AST node.

    ``log_set_added_source_position`` is applied to the resulting tree through an ``ASTHigherOrderVisitor``.

    :param string: NESTML source text of a unary operator
    :return: the built unary-operator node
    """
    builder, parser = tokenize(string)
    unary_operator = builder.visit(parser.unaryOperator())
    unary_operator.accept(ASTHigherOrderVisitor(visit_funcs=log_set_added_source_position))
    return unary_operator
def parse_simple_expression(cls, string):
    # type: (str) -> ASTSimpleExpression
    """
    Parse `string` with the ``simpleExpression`` grammar rule and build the corresponding AST node.

    ``log_set_added_source_position`` is applied to the resulting tree through an ``ASTHigherOrderVisitor``.

    :param string: NESTML source text of a simple expression
    :return: the built simple-expression node
    """
    builder, parser = tokenize(string)
    simple_expression = builder.visit(parser.simpleExpression())
    simple_expression.accept(ASTHigherOrderVisitor(visit_funcs=log_set_added_source_position))
    return simple_expression
def parse_ode_equation(cls, string):
    # type: (str) -> ASTOdeEquation
    """
    Parse `string` with the ``odeEquation`` grammar rule and build the corresponding AST node.

    ``log_set_added_source_position`` is applied to the resulting tree through an ``ASTHigherOrderVisitor``.

    :param string: NESTML source text of an ODE equation
    :return: the built ODE-equation node
    """
    builder, parser = tokenize(string)
    ode_equation = builder.visit(parser.odeEquation())
    ode_equation.accept(ASTHigherOrderVisitor(visit_funcs=log_set_added_source_position))
    return ode_equation