def add_suffix_to_variable_name(cls, var_name: str, astnode: ASTNode, suffix: str, scope=None):
    """
    Add a suffix to every occurrence of the variable named ``var_name``, recursively throughout ``astnode``.

    Both variable-valued ``ASTSimpleExpression`` nodes (via their wrapped variable) and bare ``ASTVariable`` nodes
    are renamed in place. A variable is left untouched if its name is ``"t"`` or already contains ``suffix``.

    :param var_name: exact name of the variable to rename
    :param astnode: root of the subtree to rewrite (modified in place)
    :param suffix: string appended to the variable name
    :param scope: accepted for backward compatibility; currently unused
    """
    def replace_var(_expr=None):
        if isinstance(_expr, ASTSimpleExpression) and _expr.is_variable():
            var = _expr.get_variable()
        elif isinstance(_expr, ASTVariable):
            var = _expr
        else:
            return

        name = var.get_name()
        # the "suffix not in" guard avoids appending the suffix twice when the same variable
        # is reached both through its wrapping simple expression and as a node of its own
        if suffix not in name and name != "t" and name == var_name:
            var.set_name(name + suffix)

    astnode.accept(ASTHigherOrderVisitor(replace_var))
def get_all_variables(cls, node: ASTNode) -> List[str]:
    """
    Make a list of all variable symbol names that are declared in ``node``.

    Every variable of every declaration in ``node`` is resolved in the declaration's scope. Variables that cannot be
    resolved are reported via ``Logger`` at ERROR level and skipped.

    :param node: the AST node to search; ``None`` yields an empty list
    :return: the names of the resolved variable symbols, in traversal order
    """
    class ASTVariablesFinderVisitor(ASTVisitor):
        def __init__(self):
            super(ASTVariablesFinderVisitor, self).__init__()
            # instance attribute (not a mutable class attribute) so each visitor owns its result list
            self._variables = []

        def visit_declaration(self, node):
            # a declaration may introduce several variables at once (e.g. ``a, b real = 0``)
            for variable in node.get_variables():
                symbol = node.get_scope().resolve_to_symbol(variable.get_complete_name(), SymbolKind.VARIABLE)
                if symbol is None:
                    code, message = Messages.get_variable_not_defined(variable.get_complete_name())
                    Logger.log_message(node=node, code=code, message=message,
                                       error_position=node.get_source_position(),
                                       log_level=LoggingLevel.ERROR)
                    continue

                self._variables.append(symbol)

    if node is None:
        return []

    visitor = ASTVariablesFinderVisitor()
    node.accept(visitor)
    return [symbol.name for symbol in visitor._variables]
def get_inline_expression_symbols(cls, ast: ASTNode) -> List[VariableSymbol]:
    """
    Collect the inline-expression variable symbols referenced anywhere inside ``ast``.

    Every ``ASTVariable`` in the subtree is gathered. Names containing a prime (``'``, i.e. derivatives) are skipped.
    Each remaining name is resolved as a variable in the scope of ``ast``, and the resolved symbol is kept if it is
    flagged as an inline expression. A symbol may appear multiple times if it is referenced multiple times.

    :param ast: a single AST node
    :return: a list of all inline expression variable symbols
    """
    from pynestml.visitors.ast_higher_order_visitor import ASTHigherOrderVisitor
    from pynestml.meta_model.ast_variable import ASTVariable

    found_variables = list()

    def _gather(candidate):
        if isinstance(candidate, ASTVariable):
            found_variables.append(candidate)

    ast.accept(ASTHigherOrderVisitor(visit_funcs=_gather))

    inline_symbols = list()
    for variable in found_variables:
        full_name = variable.get_complete_name()
        if "'" in full_name:
            continue
        resolved = ast.get_scope().resolve_to_symbol(full_name, SymbolKind.VARIABLE)
        if resolved is None or not resolved.is_inline_expression:
            continue
        inline_symbols.append(resolved)
    return inline_symbols
def replace_with_external_variable(cls, var_name, node: ASTNode, suffix, new_scope, alternate_name=None):
    """
    Replace all occurrences of variables (``ASTVariable``s) (e.g. ``post_trace'``) in the node with
    ``ASTExternalVariable``s, indicating that they are moved to the postsynaptic neuron.

    Occurrences on the left-hand side of a declaration are left alone; only uses are replaced.

    :param var_name: name of the variable whose occurrences are replaced
    :param node: root of the subtree to rewrite (modified in place)
    :param suffix: appended to ``var_name`` to form the name of each new external variable
    :param new_scope: passed to ``update_alt_scope`` of each new external variable
    :param alternate_name: if truthy, set as the alternate name of each new external variable
    :raises Exception: if the variable occurs in a syntactic position this function cannot rewrite
    """
    def replace_var(_expr=None):
        if isinstance(_expr, ASTSimpleExpression) and _expr.is_variable():
            var = _expr.get_variable()
        elif isinstance(_expr, ASTVariable):
            var = _expr
        else:
            return

        if var.get_name() != var_name:
            return

        ast_ext_var = ASTExternalVariable(var.get_name() + suffix,
                                          differential_order=var.get_differential_order(),
                                          source_position=var.get_source_position())
        if alternate_name:
            ast_ext_var.set_alternate_name(alternate_name)

        ast_ext_var.update_alt_scope(new_scope)
        from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor
        ast_ext_var.accept(ASTSymbolTableVisitor())

        # Look the parent up once, *before* mutating the tree. After a replacement, ``_expr`` is no longer
        # reachable from ``node``, so a later lookup (e.g. for the log message) would report the wrong parent.
        parent = node.get_parent(_expr)

        if isinstance(_expr, ASTSimpleExpression) and _expr.is_variable():
            Logger.log_message(None, -1, "ASTSimpleExpression replacement made (var = " + str(
                ast_ext_var.get_name()) + ") in expression: " + str(parent), None, LoggingLevel.INFO)
            _expr.set_variable(ast_ext_var)
            return

        # from here on ``_expr`` is a bare ``ASTVariable``
        if isinstance(parent, ASTAssignment):
            parent.lhs = ast_ext_var
            Logger.log_message(None, -1, "ASTVariable replacement made in expression: "
                               + str(parent), None, LoggingLevel.INFO)
        elif isinstance(parent, ASTSimpleExpression) and parent.is_variable():
            parent.set_variable(ast_ext_var)
        elif isinstance(parent, ASTDeclaration):
            # variable could occur on the left-hand side; ignore. Only replace if it occurs on the right-hand side.
            pass
        else:
            Logger.log_message(None, -1, "Error: unhandled use of variable " + var_name + " in expression " + str(
                _expr), None, LoggingLevel.INFO)
            raise Exception("Unhandled use of variable " + var_name + " in expression " + str(_expr))

    node.accept(ASTHigherOrderVisitor(replace_var))
def add_suffix_to_variable_name(cls, var_name: str, astnode: ASTNode, suffix: str, scope=None):
    """
    Add a suffix to every occurrence of the variable named ``var_name``, recursively throughout ``astnode``.

    Both variable-valued ``ASTSimpleExpression`` nodes (via their wrapped variable) and bare ``ASTVariable`` nodes
    are renamed in place. A variable is left untouched if its name is ``"t"`` or already contains ``suffix``.

    :param var_name: exact name of the variable to rename
    :param astnode: root of the subtree to rewrite (modified in place)
    :param suffix: string appended to the variable name
    :param scope: unused; accepted so that callers passing ``scope=`` keep working regardless of which
        definition of this method is in effect
    """
    def replace_var(_expr=None):
        if isinstance(_expr, ASTSimpleExpression) and _expr.is_variable():
            var = _expr.get_variable()
        elif isinstance(_expr, ASTVariable):
            var = _expr
        else:
            return

        name = var.get_name()
        # the "suffix not in" guard prevents a second rename of an already-suffixed variable
        if suffix not in name and name != "t" and name == var_name:
            var.set_name(name + suffix)

    astnode.accept(ASTHigherOrderVisitor(replace_var))
def collect_variable_names_in_expression(cls, expr: ASTNode) -> List[ASTVariable]:
    """
    Collect the variable occurrences (``ASTVariable`` nodes) found while traversing ``expr``.

    For a variable-valued ``ASTSimpleExpression``, its wrapped variable is collected. A bare ``ASTVariable`` node is
    collected as-is. Results are in traversal order and are not de-duplicated. The same variable may therefore
    appear more than once, presumably once via its wrapping simple expression and once as a node of its own.
    Despite the method name, ``ASTVariable`` objects are returned, not name strings.

    :param expr: expression to collect the variables from
    :return: a list of variables
    """
    found = []

    def _collect(candidate=None):
        if isinstance(candidate, ASTSimpleExpression) and candidate.is_variable():
            variable = candidate.get_variable()
        elif isinstance(candidate, ASTVariable):
            variable = candidate
        else:
            return
        if variable:
            found.append(variable)

    expr.accept(ASTHigherOrderVisitor(_collect))
    return found
def get_all_variables_used_in_convolutions(cls, node: ASTNode, parent_node: ASTNode) -> List[str]:
    """
    Make a list of all variable symbol names that are in ``node`` and used in a convolution.

    For every ``convolve(...)`` call whose second argument resolves to a variable that names an input port of
    ``parent_node``, the name of the enclosing inline expression is collected. Calls with no enclosing inline
    expression (for example, directly inside an ODE) and calls whose buffer cannot be resolved are skipped.

    :param node: the AST node to search; ``None`` yields an empty list
    :param parent_node: the model containing ``node``; used for input-port lookup and parent navigation
    :return: names of the inline expressions containing qualifying convolutions (may contain duplicates)
    """
    from pynestml.codegeneration.ast_transformers import ASTTransformers

    class ASTAllVariablesUsedInConvolutionVisitor(ASTVisitor):
        def __init__(self, node, parent_node):
            super(ASTAllVariablesUsedInConvolutionVisitor, self).__init__()
            self.node = node
            self.parent_node = parent_node
            # instance attribute (not a mutable class attribute) so each visitor owns its result list
            self._variables = []

        def visit_function_call(self, node):
            if node.get_name() != 'convolve':
                return

            symbol_buffer = node.get_scope().resolve_to_symbol(str(node.get_args()[1]), SymbolKind.VARIABLE)
            if symbol_buffer is None:
                # the unresolved name is not a known variable, so there is no input port to look up
                return

            input_port = ASTTransformers.get_input_port_by_name(
                self.parent_node.get_input_blocks(), symbol_buffer.name)
            if not input_port:
                return

            # Walk up from the convolve() call to the nearest enclosing inline expression. Stop once no parent is
            # left: previously, the loop never terminated when no such ancestor existed.
            # XXX TODO also needs to accept normal ASTExpression, ASTAssignment?
            node_ = self.parent_node.get_parent(node)
            while node_ is not None and not isinstance(node_, ASTInlineExpression):
                node_ = self.parent_node.get_parent(node_)

            if node_ is not None:
                self._variables.append(node_.get_variable_name())

    if node is None:
        return []

    visitor = ASTAllVariablesUsedInConvolutionVisitor(node, parent_node)
    node.accept(visitor)
    return visitor._variables
def check_co_co(cls, neuron: ASTNode, after_ast_rewrite: bool):
    """
    Checks if this coco applies for the handed over neuron. Models which use not defined elements are not correct.

    Every variable that occurs in an expression of ``neuron`` must resolve either to a variable symbol or to a type
    symbol (e.g. a unit such as ``mV``). Variables in a declaration's invariant are exempt from the further checks.
    Each remaining user-declared variable is checked in two ways:

    * It must be declared before it is used. Parameters and internals are exempt.
    * It must not be defined in terms of itself.

    Predefined symbols, input buffers and variables added by AST transformations are exempt from both checks.
    Violations are reported via ``Logger`` at ERROR level. Finally, the left-hand sides of assignments are checked
    with ``ASTAssignedVariableDefinedVisitor``.

    :param neuron: a single neuron instance.
    :type neuron: ast_neuron
    :param after_ast_rewrite: passed through to ``ASTAssignedVariableDefinedVisitor``
    """
    # for each variable in all expressions, check if the variable has been defined previously
    expression_collector_visitor = ASTExpressionCollectorVisitor()
    neuron.accept(expression_collector_visitor)
    expressions = expression_collector_visitor.ret
    for expr in expressions:
        # The parent is required to handle invariants differently. It depends only on the expression, not on the
        # variable, so look it up (a tree search) once per expression.
        expr_par = neuron.get_parent(expr)
        for var in expr.get_variables():
            symbol = var.get_scope().resolve_to_symbol(var.get_complete_name(), SymbolKind.VARIABLE)

            if symbol is None:
                # check if this symbol is actually a type, e.g. "mV" in the expression "(1 + 2) * mV"
                symbol = var.get_scope().resolve_to_symbol(var.get_complete_name(), SymbolKind.TYPE)
                if symbol is None:
                    # symbol has not been defined; neither as a variable name nor as a type symbol
                    code, message = Messages.get_variable_not_defined(var.get_name())
                    Logger.log_message(node=neuron, code=code, message=message, log_level=LoggingLevel.ERROR,
                                       error_position=var.get_source_position())

            # first check if it is part of an invariant
            # if it is the case, there is no "recursive" declaration
            # so check if the parent is a declaration and the expression the invariant
            elif isinstance(expr_par, ASTDeclaration) and expr_par.get_invariant() == expr:
                # in this case it's ok if it is recursive or defined later on
                continue

            # now check if it has been defined before usage, except for predefined symbols, buffers and variables
            # added by the AST transformation functions
            elif (not symbol.is_predefined
                  and symbol.block_type not in (BlockType.INPUT_BUFFER_CURRENT, BlockType.INPUT_BUFFER_SPIKE)
                  and not symbol.get_referenced_object().get_source_position().is_added_source_position()):
                decl_position = symbol.get_referenced_object().get_source_position()

                # except for parameters and internals, those can be defined after
                if (not decl_position.before(var.get_source_position())
                        and symbol.block_type not in (BlockType.PARAMETERS, BlockType.INTERNALS)):
                    code, message = Messages.get_variable_used_before_declaration(var.get_name())
                    Logger.log_message(node=neuron, message=message, error_position=var.get_source_position(),
                                       code=code, log_level=LoggingLevel.ERROR)

                # now check that they are not defined recursively, e.g. V_m mV = V_m + 1
                # (decl_position is known not to be an added source position: the elif above excludes it)
                # todo: we should not check this for invariants
                if decl_position.encloses(var.get_source_position()):
                    code, message = Messages.get_variable_defined_recursively(var.get_name())
                    Logger.log_message(node=neuron, code=code, message=message,
                                       error_position=decl_position, log_level=LoggingLevel.ERROR)

    # now check for each assignment whether the left hand side variable is defined
    vis = ASTAssignedVariableDefinedVisitor(neuron, after_ast_rewrite)
    neuron.accept(vis)