def make_trivial_assignment(var, order, equations_block, is_shape=False):
    # type: (ASTVariable, int, ASTEquationsBlock, bool) -> ASTNode
    """
    Appends a trivial ODE (or ODE shape) for ``var`` to the handed over equations block.

    The created node has the form ``lhs = rhs`` where
      * ``lhs`` is a variable named ``var.get_name()`` with differential order ``order + 1``,
      * ``rhs`` is a simple expression referencing a variable named
        ``convert_variable_name_to_generator_notation(var).get_name()`` with differential order ``order``.
    All created nodes carry the "added" source position.

    :param var: the variable the assignment is built for
    :param order: differential order of the right-hand side variable (the left-hand side gets ``order + 1``)
    :param equations_block: the equations block whose declarations list receives the new node
    :param is_shape: if True an ODE shape node is created, otherwise an ODE equation node
    :return: the newly created and appended node
    """
    # These names are only referenced by the type comment above; kept as local imports.
    from pynestml.meta_model.ast_variable import ASTVariable
    from pynestml.meta_model.ast_equations_block import ASTEquationsBlock
    from pynestml.meta_model.ast_node import ASTNode
    lhs_variable = ASTNodeFactory.create_ast_variable(
        name=var.get_name(),
        differential_order=order + 1,
        source_position=ASTSourceLocation.get_added_source_position())
    rhs_variable = ASTNodeFactory.create_ast_variable(
        name=convert_variable_name_to_generator_notation(var).get_name(),
        differential_order=order,
        source_position=ASTSourceLocation.get_added_source_position())
    expression = ASTNodeFactory.create_ast_simple_expression(
        variable=rhs_variable,
        source_position=ASTSourceLocation.get_added_source_position())
    source_loc = ASTSourceLocation.get_added_source_position()
    if is_shape:
        node = ASTNodeFactory.create_ast_ode_shape(lhs=lhs_variable, rhs=expression, source_position=source_loc)
    else:
        node = ASTNodeFactory.create_ast_ode_equation(lhs=lhs_variable, rhs=expression, source_position=source_loc)
    equations_block.get_declarations().append(node)
    return node
def add_assignment_to_update_block(assignment, neuron):
    """
    Appends one assignment, wrapped as a statement, to the end of the neuron's update block.

    The assignment is wrapped into a small statement and then into a statement; both wrappers
    carry the "added" source position.
    :param assignment: the assignment node to add
    :param neuron: the neuron whose update block is extended
    :return: the same (modified) neuron instance
    """
    wrapped_assignment = ASTNodeFactory.create_ast_small_stmt(
        assignment=assignment,
        source_position=ASTSourceLocation.get_added_source_position())
    statement = ASTNodeFactory.create_ast_stmt(
        small_stmt=wrapped_assignment,
        source_position=ASTSourceLocation.get_added_source_position())
    update_statements = neuron.get_update_blocks().get_block().get_stmts()
    update_statements.append(statement)
    return neuron
def add_declaration_to_update_block(declaration, neuron):
    # type: (ASTDeclaration, ASTNeuron) -> ASTNeuron
    """
    Appends one declaration, wrapped as a statement, to the end of the neuron's update block.

    The declaration is wrapped into a small statement and then into a statement; both wrappers
    carry the "added" source position.
    :param declaration: the ASTDeclaration node to add
    :param neuron: the neuron whose update block is extended
    :return: the same (modified) neuron instance
    """
    small_statement = ASTNodeFactory.create_ast_small_stmt(
        declaration=declaration,
        source_position=ASTSourceLocation.get_added_source_position())
    statement = ASTNodeFactory.create_ast_stmt(
        small_stmt=small_statement,
        source_position=ASTSourceLocation.get_added_source_position())
    update_statements = neuron.get_update_blocks().get_block().get_stmts()
    update_statements.append(statement)
    return neuron
def add_assignment_to_update_block(assignment, neuron):
    """
    Appends a single assignment to the end of the update block of the given neuron.

    :param assignment: the assignment node to append
    :param neuron: the neuron whose update block receives the new statement
    :return: the same (modified) neuron instance
    """
    added_position = ASTSourceLocation.get_added_source_position
    # nested construction: the small statement (and its position) is built before the outer statement
    statement = ASTNodeFactory.create_ast_stmt(
        small_stmt=ASTNodeFactory.create_ast_small_stmt(assignment=assignment,
                                                        source_position=added_position()),
        source_position=added_position())
    neuron.get_update_blocks().get_block().get_stmts().append(statement)
    return neuron
def visit_while_stmt(self, node):
    """
    Checks that the condition of a while statement is of boolean type.

    Nodes carrying the "added" source position are skipped. If the condition's type could not be
    derived, or is not boolean, an ERROR-level message is logged at the condition's position.
    :param node: the while statement to check
    :type node: ASTWhileStmt
    """
    added_position = ASTSourceLocation.get_added_source_position()
    if node.get_source_position().equals(added_position):
        # no type checks are executed for added nodes, since we assume correctness
        return
    condition = node.get_condition()
    cond_type = condition.type
    if isinstance(cond_type, ErrorTypeSymbol):
        code, message = Messages.get_type_could_not_be_derived(condition)
    elif cond_type.equals(PredefinedTypes.get_boolean_type()):
        # condition is boolean: nothing to report
        return
    else:
        code, message = Messages.get_type_different_from_expected(
            PredefinedTypes.get_boolean_type(), cond_type)
    Logger.log_message(code=code,
                       message=message,
                       error_position=condition.get_source_position(),
                       log_level=LoggingLevel.ERROR)
def add_declaration_to_update_block(declaration, neuron):
    # type: (ASTDeclaration, ASTNeuron) -> ASTNeuron
    """
    Adds a single declaration to the end of the update block of the given neuron.

    :param declaration: the ASTDeclaration node to append
    :param neuron: the neuron whose update block receives the new statement
    :return: the same (modified) neuron instance
    """
    added_position = ASTSourceLocation.get_added_source_position
    # nested construction: the small statement (and its position) is built before the outer statement
    statement = ASTNodeFactory.create_ast_stmt(
        small_stmt=ASTNodeFactory.create_ast_small_stmt(declaration=declaration,
                                                        source_position=added_position()),
        source_position=added_position())
    neuron.get_update_blocks().get_block().get_stmts().append(statement)
    return neuron
def get_source_position(self):
    """
    Returns the source position of this element.

    Falls back to the predefined source position when no position has been set.
    :return: a source position object.
    :rtype: ASTSourceLocation
    """
    position = self.sourcePosition
    if position is None:
        return ASTSourceLocation.get_predefined_source_position()
    return position
def create_source_pos(ctx):
    """
    Returns a new source location spanning the handed over parser context.

    Used in order to avoid duplicating the start/stop token bookkeeping in every visitor method.
    :param ctx: a parser rule context (presumably an ANTLR ParserRuleContext) whose ``start`` and
                ``stop`` tokens provide line/column information
    :return: a new source location from the start token to the stop token
    :rtype: ASTSourceLocation
    """
    start = ctx.start
    stop = ctx.stop
    if stop is None:
        # The stop token can be missing when the rule matched no tokens (e.g. after error
        # recovery); fall back to the start token instead of failing with an AttributeError.
        stop = start
    return ASTSourceLocation.make_ast_source_position(start_line=start.line,
                                                      start_column=start.column,
                                                      end_line=stop.line,
                                                      end_column=stop.column)
def setUp(self):
    """
    Test fixture: initialise logging, the symbol table and all predefined symbols.

    Runs, in order: logger initialisation (INFO level), symbol table initialisation at a
    zero-length origin location, then registration of predefined units, types, variables
    and functions.
    """
    Logger.init_logger(LoggingLevel.INFO)
    origin = ASTSourceLocation(start_line=0, start_column=0, end_line=0, end_column=0)
    SymbolTable.initialize_symbol_table(origin)
    registrations = (PredefinedUnits.register_units,
                     PredefinedTypes.register_types,
                     PredefinedVariables.register_variables,
                     PredefinedFunctions.register_functions)
    for register in registrations:
        register()
def create_internal_block(cls, neuron):
    """
    Ensures the handed over neuron has an internals block.

    If ``neuron.get_internals_blocks()`` is None, an empty block created with
    ``ASTNodeFactory.create_ast_block_with_variables(False, False, True, False, [], <added position>)``
    is appended to the neuron's body elements; otherwise the neuron is left unchanged.
    :param neuron: a single neuron
    :type neuron: ast_neuron
    :return: the (possibly modified) neuron
    :rtype: ast_neuron
    """
    # imported locally, presumably to avoid a circular import at module load time
    from pynestml.meta_model.ast_node_factory import ASTNodeFactory
    if neuron.get_internals_blocks() is not None:
        return neuron
    empty_internals = ASTNodeFactory.create_ast_block_with_variables(
        False, False, True, False, [], ASTSourceLocation.get_added_source_position())
    neuron.get_body().get_body_elements().append(empty_internals)
    return neuron
def visit_for_stmt(self, node):
    """
    Checks that both bounds of a for statement are of integer or real type.

    Nodes carrying the "added" source position are skipped. For each bound (start, then end)
    whose type could not be derived, or is neither integer nor real, an ERROR-level message
    is logged at that bound's position.
    :param node: the for statement to check
    :type node: ASTForStmt
    """
    if node.get_source_position().equals(ASTSourceLocation.get_added_source_position()):
        # no type checks are executed for added nodes, since we assume correctness
        return

    def check_bound(bound):
        # a single bound expression must be of integer or real type
        bound_type = bound.type
        if isinstance(bound_type, ErrorTypeSymbol):
            code, message = Messages.get_type_could_not_be_derived(bound)
        elif (bound_type.equals(PredefinedTypes.get_integer_type())
              or bound_type.equals(PredefinedTypes.get_real_type())):
            return
        else:
            code, message = Messages.get_type_different_from_expected(
                PredefinedTypes.get_integer_type(), bound_type)
        Logger.log_message(code=code,
                           message=message,
                           error_position=bound.get_source_position(),
                           log_level=LoggingLevel.ERROR)

    check_bound(node.get_start_from())
    check_bound(node.get_end_at())
def setUp(self):
    """
    Test fixture: register predefined symbols, initialise the symbol table and logger,
    and compute the path of the ``target`` directory next to this file's parent directory.

    Sets ``self.target_path`` to the resolved path of ``<dir of this file>/../target``.
    """
    for register in (PredefinedUnits.register_units,
                     PredefinedTypes.register_types,
                     PredefinedFunctions.register_functions,
                     PredefinedVariables.register_variables):
        register()
    SymbolTable.initialize_symbol_table(
        ASTSourceLocation(start_line=0, start_column=0, end_line=0, end_column=0))
    Logger.init_logger(LoggingLevel.INFO)
    this_dir = os.path.dirname(__file__)
    target_dir = os.path.join(this_dir, os.pardir, 'target')
    self.target_path = str(os.path.realpath(target_dir))
def create_initial_values_block(cls, neuron):
    """
    Ensures the handed over neuron has an initial values block.

    If ``neuron.get_initial_blocks()`` is None, an empty block created with
    ``ASTNodeFactory.create_ast_block_with_variables(False, False, False, True, [], <added position>)``
    is appended to the neuron's body elements; otherwise the neuron is left unchanged.
    :param neuron: a single neuron
    :type neuron: ast_neuron
    :return: the (possibly modified) neuron
    :rtype: ast_neuron
    """
    # local import since otherwise circular dependency
    from pynestml.meta_model.ast_node_factory import ASTNodeFactory
    if neuron.get_initial_blocks() is not None:
        return neuron
    added_position = ASTSourceLocation.get_added_source_position()
    empty_initial_values = ASTNodeFactory.create_ast_block_with_variables(
        False, False, False, True, [], added_position)
    neuron.get_body().get_body_elements().append(empty_initial_values)
    return neuron
def visit_assignment(self, node):
    """
    Checks that the left-hand side and right-hand side of an assignment have matching types.

    Nodes carrying the "added" source position are skipped. Direct assignments (``a = b``) are
    handed to ``handle_simple_assignment``; all others (e.g. ``a *= b``) to ``handle_complex_assignment``.
    :param node: a single assignment.
    :type node: ASTAssignment
    """
    from pynestml.meta_model.ast_assignment import ASTAssignment
    assert isinstance(node, ASTAssignment)
    added_position = ASTSourceLocation.get_added_source_position()
    if node.get_source_position().equals(added_position):
        # no type checks are executed for added nodes, since we assume correctness
        return
    handler = (self.handle_simple_assignment if node.is_direct_assignment
               else self.handle_complex_assignment)
    handler(node)
def visit_declaration(self, node):
    """
    Checks that the declared type of a declaration matches the type of its initialising expression.

    Declarations without an expression, and those whose expression carries the "added" source
    position, are not checked. If the expression's type could not be derived, the error is reported
    via ``LoggingHelper.drop_missing_type_error``; if the types do not match, the mismatch is handed
    to ``TypeCaster.try_to_recover_or_error``.
    :param node: a single declaration.
    :type node: ASTDeclaration
    """
    assert isinstance(node, ASTDeclaration)
    if not node.has_expression():
        return
    rhs_expression = node.get_expression()
    if rhs_expression.get_source_position().equals(ASTSourceLocation.get_added_source_position()):
        # no type checks are executed for added nodes, since we assume correctness
        return
    lhs_type = node.get_data_type().get_type_symbol()
    rhs_type = rhs_expression.type
    if isinstance(rhs_type, ErrorTypeSymbol):
        LoggingHelper.drop_missing_type_error(node)
    elif self.__types_do_not_match(lhs_type, rhs_type):
        TypeCaster.try_to_recover_or_error(lhs_type, rhs_type, rhs_expression)
def make_trivial_assignment(var, order, equations_block, is_shape=False):
    # type: (ASTVariable, int, ASTEquationsBlock, bool) -> ASTNode
    """
    Appends a trivial ODE (or ODE shape) for ``var`` to the handed over equations block.

    The created node has the form ``lhs = rhs`` where
      * ``lhs`` is a variable named ``var.get_name()`` with differential order ``order + 1``,
      * ``rhs`` is a simple expression referencing a variable named
        ``convert_variable_name_to_generator_notation(var).get_name()`` with differential order ``order``.
    All created nodes carry the "added" source position.

    :param var: the variable the assignment is built for
    :param order: differential order of the right-hand side variable (the left-hand side gets ``order + 1``)
    :param equations_block: the equations block whose declarations list receives the new node
    :param is_shape: if True an ODE shape node is created, otherwise an ODE equation node
    :return: the newly created and appended node
    """
    # These names are only referenced by the type comment above; kept as local imports.
    from pynestml.meta_model.ast_variable import ASTVariable
    from pynestml.meta_model.ast_equations_block import ASTEquationsBlock
    from pynestml.meta_model.ast_node import ASTNode
    lhs_variable = ASTNodeFactory.create_ast_variable(
        name=var.get_name(),
        differential_order=order + 1,
        source_position=ASTSourceLocation.get_added_source_position())
    rhs_variable = ASTNodeFactory.create_ast_variable(
        name=convert_variable_name_to_generator_notation(var).get_name(),
        differential_order=order,
        source_position=ASTSourceLocation.get_added_source_position())
    expression = ASTNodeFactory.create_ast_simple_expression(
        variable=rhs_variable,
        source_position=ASTSourceLocation.get_added_source_position())
    source_loc = ASTSourceLocation.get_added_source_position()
    if is_shape:
        node = ASTNodeFactory.create_ast_ode_shape(lhs=lhs_variable, rhs=expression, source_position=source_loc)
    else:
        node = ASTNodeFactory.create_ast_ode_equation(lhs=lhs_variable, rhs=expression, source_position=source_loc)
    equations_block.get_declarations().append(node)
    return node
from pynestml.codegeneration.unit_converter import UnitConverter from pynestml.meta_model.ast_source_location import ASTSourceLocation from pynestml.symbol_table.symbol_table import SymbolTable from pynestml.symbols.predefined_functions import PredefinedFunctions from pynestml.symbols.predefined_types import PredefinedTypes from pynestml.symbols.predefined_units import PredefinedUnits from pynestml.symbols.predefined_variables import PredefinedVariables from pynestml.symbols.symbol import SymbolKind from pynestml.symbols.unit_type_symbol import UnitTypeSymbol from pynestml.utils.logger import Logger, LoggingLevel from pynestml.utils.messages import MessageCode from pynestml.utils.model_parser import ModelParser from pynestml.visitors.ast_visitor import ASTVisitor # minor setup steps required SymbolTable.initialize_symbol_table(ASTSourceLocation(start_line=0, start_column=0, end_line=0, end_column=0)) PredefinedUnits.register_units() PredefinedTypes.register_types() PredefinedVariables.register_variables() PredefinedFunctions.register_functions() class ExpressionTestVisitor(ASTVisitor): def endvisit_assignment(self, node): scope = node.get_scope() var_name = node.get_variable().get_name() _expr = node.get_expression() var_symbol = scope.resolve_to_symbol(var_name, SymbolKind.VARIABLE)
from pynestml.meta_model.ast_source_location import ASTSourceLocation from pynestml.symbol_table.scope import ScopeType from pynestml.symbol_table.symbol_table import SymbolTable from pynestml.symbols.predefined_functions import PredefinedFunctions from pynestml.symbols.predefined_types import PredefinedTypes from pynestml.symbols.predefined_units import PredefinedUnits from pynestml.symbols.predefined_variables import PredefinedVariables from pynestml.symbols.symbol import SymbolKind from pynestml.utils.logger import Logger, LoggingLevel from pynestml.utils.model_parser import ModelParser # minor setup steps required Logger.init_logger(LoggingLevel.NO) SymbolTable.initialize_symbol_table( ASTSourceLocation(start_line=0, start_column=0, end_line=0, end_column=0)) PredefinedUnits.register_units() PredefinedTypes.register_types() PredefinedVariables.register_variables() PredefinedFunctions.register_functions() class SymbolTableResolutionTest(unittest.TestCase): """ This test is used to check if the resolution of symbols works as expected. """ def test(self): model = ModelParser.parse_model( os.path.join( os.path.realpath( os.path.join(os.path.dirname(__file__), 'resources',
def log_set_added_source_position(node):
    """
    Marks the handed over node as added by setting its source position to the "added" position.

    :param node: the node whose source position is overwritten
    """
    added_position = ASTSourceLocation.get_added_source_position()
    node.set_source_position(added_position)
def visitExpression(self, ctx):
    """
    Builds the AST node for an expression context.

    The context is translated into exactly one of:
      * the result of ``visitSimpleExpression`` when the context wraps a simple expression,
      * an expression node for a term, possibly parenthesised, negated or prefixed by a unary operator,
      * a compound expression ``lhs <op> rhs`` (arithmetic, bit, comparison or logical operator),
      * a ternary expression ``condition ? if_true : if_not``.

    :param ctx: the expression context (presumably generated by the ANTLR parser)
    :return: the constructed AST node
    :raises RuntimeError: if the context matches none of the supported forms
    """
    # first check if it is a simple rhs
    if ctx.simpleExpression() is not None:
        return self.visitSimpleExpression(ctx.simpleExpression())

    # otherwise: an encapsulated, unary-prefixed or negated term
    is_encapsulated = ctx.leftParentheses is not None and ctx.rightParentheses is not None
    unary_operator = self.visit(ctx.unaryOperator()) if ctx.unaryOperator() is not None else None
    is_logical_not = ctx.logicalNot is not None
    expression = self.visit(ctx.term) if ctx.term is not None else None

    # or a combined one: evaluate lhs first, then the operator and finally rhs
    lhs = self.visit(ctx.left) if ctx.left is not None else None

    # arithmetic operators are single tokens; each maps to one factory flag (checked in this order)
    arithmetic_operators = (
        ('powOp', 'is_pow_op'),
        ('timesOp', 'is_times_op'),
        ('divOp', 'is_div_op'),
        ('moduloOp', 'is_modulo_op'),
        ('plusOp', 'is_plus_op'),
        ('minusOp', 'is_minus_op'),
    )
    binary_operator = None
    for token_attribute, operator_flag in arithmetic_operators:
        token = getattr(ctx, token_attribute)
        if token is not None:
            # start and end are both taken from the operator token's line and column
            operator_pos = ASTSourceLocation.make_ast_source_position(start_line=token.line,
                                                                      start_column=token.column,
                                                                      end_line=token.line,
                                                                      end_column=token.column)
            operator_kwargs = {operator_flag: True, 'source_position': operator_pos}
            binary_operator = ASTNodeFactory.create_ast_arithmetic_operator(**operator_kwargs)
            break
    else:
        # no arithmetic token: try the bit, comparison and logical operator sub-rules in turn
        if ctx.bitOperator() is not None:
            binary_operator = self.visit(ctx.bitOperator())
        elif ctx.comparisonOperator() is not None:
            binary_operator = self.visit(ctx.comparisonOperator())
        elif ctx.logicalOperator() is not None:
            binary_operator = self.visit(ctx.logicalOperator())

    rhs = self.visit(ctx.right) if ctx.right is not None else None

    # if it is not a term or a binary operation, it may be the ternary operator
    condition = self.visit(ctx.condition) if ctx.condition is not None else None
    if_true = self.visit(ctx.ifTrue) if ctx.ifTrue is not None else None
    if_not = self.visit(ctx.ifNot) if ctx.ifNot is not None else None

    source_pos = create_source_pos(ctx)
    # finally construct the corresponding rhs
    if expression is not None:
        return ASTNodeFactory.create_ast_expression(is_encapsulated=is_encapsulated,
                                                    is_logical_not=is_logical_not,
                                                    unary_operator=unary_operator,
                                                    expression=expression,
                                                    source_position=source_pos)
    if lhs is not None and rhs is not None and binary_operator is not None:
        return ASTNodeFactory.create_ast_compound_expression(lhs=lhs,
                                                             binary_operator=binary_operator,
                                                             rhs=rhs,
                                                             source_position=source_pos)
    if condition is not None and if_true is not None and if_not is not None:
        return ASTNodeFactory.create_ast_ternary_expression(condition=condition,
                                                            if_true=if_true,
                                                            if_not=if_not,
                                                            source_position=source_pos)
    raise RuntimeError('Type of rhs @%s,%s not recognized!' % (ctx.start.line, ctx.start.column))