Exemplo n.º 1
0
 def create_empty_update_block(self):
     """
     Append a fresh update block, wrapping an empty statement block, to the body elements.

     May only be called while no update block is present; otherwise the assertion fails.
     """
     existing_blocks = self.get_update_blocks()
     assert existing_blocks is None or len(existing_blocks) == 0, \
         "create_empty_update_block() called although update block already present"
     # imported locally, exactly where the original imported it
     from pynestml.meta_model.ast_node_factory import ASTNodeFactory
     predefined_position = ASTSourceLocation.get_predefined_source_position
     # both new nodes are tagged with the predefined source position
     empty_block = ASTNodeFactory.create_ast_block([], predefined_position())
     update_block = ASTNodeFactory.create_ast_update_block(empty_block, predefined_position())
     body_elements = self.get_body().get_body_elements()
     body_elements.append(update_block)
 def setUp(self):
     """
     Prepare the global infrastructure used by the tests.

     Initializes the logger at INFO level and the symbol table with an all-zero source location,
     then registers the predefined units, types, variables and functions (in that order).
     """
     Logger.init_logger(LoggingLevel.INFO)
     origin = ASTSourceLocation(start_line=0, start_column=0, end_line=0, end_column=0)
     SymbolTable.initialize_symbol_table(origin)
     # the registration order is kept exactly as it was
     for register in (PredefinedUnits.register_units,
                      PredefinedTypes.register_types,
                      PredefinedVariables.register_variables,
                      PredefinedFunctions.register_functions):
         register()
Exemplo n.º 3
0
 def visit_while_stmt(self, node):
     """
     Check that the condition of a while statement is of boolean type.

     Nodes carrying the "added" source position are skipped. An error is logged when the type of the
     condition is an error type, or when it does not equal the predefined boolean type.
     :param node: a single while stmt
     :type node: ASTWhileStmt
     """
     node_position = node.get_source_position()
     if node_position.equals(ASTSourceLocation.get_added_source_position()):
         # added nodes are assumed to be correct, hence no type check
         return
     condition = node.get_condition()
     cond_type = condition.type
     if isinstance(cond_type, ErrorTypeSymbol):
         code, message = Messages.get_type_could_not_be_derived(condition)
     else:
         boolean_type = PredefinedTypes.get_boolean_type()
         if cond_type.equals(boolean_type):
             # condition is boolean, nothing to report
             return
         code, message = Messages.get_type_different_from_expected(boolean_type, cond_type)
     # both failure cases are reported at the position of the condition
     Logger.log_message(code=code,
                        message=message,
                        error_position=condition.get_source_position(),
                        log_level=LoggingLevel.ERROR)
Exemplo n.º 4
0
def add_assignment_to_update_block(assignment: ASTAssignment, neuron: ASTNeuron) -> ASTNeuron:
    """
    Wrap an assignment into a statement and append it to the update block of the neuron.

    If the neuron has no update block yet, an empty one is created first. The new statement nodes are
    tagged with the "added" source position and receive the scope of the update block's statement block.
    :param assignment: a single assignment
    :param neuron: a single neuron instance
    :return: the modified neuron
    """
    added_position = ASTSourceLocation.get_added_source_position
    wrapped = ASTNodeFactory.create_ast_small_stmt(assignment=assignment,
                                                   source_position=added_position())
    statement = ASTNodeFactory.create_ast_stmt(small_stmt=wrapped,
                                               source_position=added_position())
    if not neuron.get_update_blocks():
        neuron.create_empty_update_block()
    target_block = neuron.get_update_blocks().get_block()
    target_block.get_stmts().append(statement)
    # inner node first, then the enclosing statement
    for created_node in (wrapped, statement):
        created_node.update_scope(target_block.get_scope())
    return neuron
Exemplo n.º 5
0
def add_declaration_to_update_block(declaration: ASTDeclaration, neuron: ASTNeuron) -> ASTNeuron:
    """
    Wrap a declaration into a statement and append it to the update block of the neuron.

    An empty update block is created beforehand if the neuron does not have one. Both new statement
    nodes carry the "added" source position and are given the scope of the update block's statement block.
    :param declaration: ASTDeclaration node to add
    :param neuron: a single neuron instance
    :return: a modified neuron
    """
    inner_stmt = ASTNodeFactory.create_ast_small_stmt(
        declaration=declaration,
        source_position=ASTSourceLocation.get_added_source_position())
    outer_stmt = ASTNodeFactory.create_ast_stmt(
        small_stmt=inner_stmt,
        source_position=ASTSourceLocation.get_added_source_position())
    if not neuron.get_update_blocks():
        neuron.create_empty_update_block()
    update_body = neuron.get_update_blocks().get_block()
    update_body.get_stmts().append(outer_stmt)
    # propagate the block scope to the small statement and then to its wrapper
    inner_stmt.update_scope(update_body.get_scope())
    outer_stmt.update_scope(update_body.get_scope())
    return neuron
Exemplo n.º 6
0
 def get_source_position(self):
     """
     Return the source position stored on this element.

     Falls back to the predefined source position when none has been stored.
     :return: a source position object.
     :rtype: ASTSourceLocation
     """
     position = self.source_position
     if position is not None:
         return position
     return ASTSourceLocation.get_predefined_source_position()
Exemplo n.º 7
0
def create_source_pos(ctx):
    """
    Build a source location from the start and stop tokens of a parse context.

    :param ctx: a context object exposing ``start`` and ``stop``, each with ``line`` and ``column``
    :return: the object returned by ``ASTSourceLocation.make_ast_source_position`` for that span
    """
    first_token, last_token = ctx.start, ctx.stop
    span = {
        'start_line': first_token.line,
        'start_column': first_token.column,
        'end_line': last_token.line,
        'end_column': last_token.column,
    }
    return ASTSourceLocation.make_ast_source_position(**span)
Exemplo n.º 8
0
    def setUp(self) -> None:
        """
        Prepare the global infrastructure and the target path used by the tests.

        Registers the predefined units, types, functions and variables (in that order), initializes the
        symbol table with an all-zero source location and the logger at INFO level, and finally stores the
        resolved path of the ``target`` directory located one level above this file's directory.
        """
        for register in (PredefinedUnits.register_units,
                         PredefinedTypes.register_types,
                         PredefinedFunctions.register_functions,
                         PredefinedVariables.register_variables):
            register()
        origin = ASTSourceLocation(start_line=0, start_column=0, end_line=0, end_column=0)
        SymbolTable.initialize_symbol_table(origin)
        Logger.init_logger(LoggingLevel.INFO)

        test_dir = os.path.dirname(__file__)
        resolved_target = os.path.realpath(os.path.join(test_dir, os.pardir, 'target'))
        self.target_path = str(resolved_target)
Exemplo n.º 9
0
 def visit_for_stmt(self, node):
     """
     Check that both bounds of a for statement are of integer or real type.

     Nodes carrying the "added" source position are skipped. For the start bound and then the end bound,
     an error is logged when the type is an error type or is neither the integer nor the real type.
     :param node: a single for stmt
     :type node: ASTForStmt
     """
     node_position = node.get_source_position()
     if node_position.equals(ASTSourceLocation.get_added_source_position()):
         # added nodes are assumed to be correct, hence no type check
         return
     # the same check applies to the "from" bound first and to the "to" bound second
     for fetch_bound in (node.get_start_from, node.get_end_at):
         bound = fetch_bound()
         bound_type = bound.type
         if isinstance(bound_type, ErrorTypeSymbol):
             code, message = Messages.get_type_could_not_be_derived(bound)
         elif (bound_type.equals(PredefinedTypes.get_integer_type())
               or bound_type.equals(PredefinedTypes.get_real_type())):
             # integer or real: this bound is fine
             continue
         else:
             # the message names the integer type as the expected one, as before
             code, message = Messages.get_type_different_from_expected(
                 PredefinedTypes.get_integer_type(), bound_type)
         Logger.log_message(code=code,
                            message=message,
                            error_position=bound.get_source_position(),
                            log_level=LoggingLevel.ERROR)
Exemplo n.º 10
0
 def create_internal_block(cls, neuron):
     """
     Add an internals block to the handed over neuron if ``get_internals_blocks()`` returns None.

     The new block is tagged with the "added" source position, receives the scope of the neuron and is
     appended to the neuron's body elements.
     :param neuron: a single neuron
     :type neuron: ast_neuron
     :return: the modified neuron
     :rtype: ast_neuron
     """
     from pynestml.meta_model.ast_node_factory import ASTNodeFactory
     if neuron.get_internals_blocks() is not None:
         # an internals block is already there, leave the neuron untouched
         return neuron
     added_position = ASTSourceLocation.get_added_source_position()
     # positional flags are passed through unchanged: False, False, True
     internals_block = ASTNodeFactory.create_ast_block_with_variables(False, False, True, [],
                                                                      added_position)
     internals_block.update_scope(neuron.get_scope())
     body_elements = neuron.get_body().get_body_elements()
     body_elements.append(internals_block)
     return neuron
Exemplo n.º 11
0
 def create_state_block(cls, neuron):
     """
     Creates a single state block in the handed over neuron, unless one already exists.

     The previous version checked ``get_internals_blocks()`` (copied from ``create_internal_block``), so
     the decision depended on the wrong kind of block: a neuron with an internals block never got a
     state block, and a neuron without one could get a duplicate state block.
     :param neuron: a single neuron
     :type neuron: ast_neuron
     :return: the modified neuron
     :rtype: ast_neuron
     """
     # NOTE(review): assumes the neuron exposes get_state_blocks() returning None when absent,
     # mirroring get_internals_blocks() -- confirm against the neuron class
     if neuron.get_state_blocks() is not None:
         return neuron
     # local import since otherwise circular dependency
     from pynestml.meta_model.ast_node_factory import ASTNodeFactory
     # positional flags are passed through unchanged: True, False, False
     state = ASTNodeFactory.create_ast_block_with_variables(True, False, False, list(),
                                                            ASTSourceLocation.get_added_source_position())
     # NOTE(review): unlike create_internal_block, no scope is assigned to the new block -- confirm intended
     neuron.get_body().get_body_elements().append(state)
     return neuron
Exemplo n.º 12
0
    def visit_assignment(self, node):
        """
        Dispatch an assignment to the handler for direct or compound assignments.

        Nodes carrying the "added" source position are skipped without any check.
        :param node: a single assignment.
        :type node: ASTAssignment
        """
        from pynestml.meta_model.ast_assignment import ASTAssignment
        assert isinstance(node, ASTAssignment)

        node_position = node.get_source_position()
        if node_position.equals(ASTSourceLocation.get_added_source_position()):
            # added nodes are assumed to be correct, hence no type check
            return
        if not node.is_direct_assignment:
            # compound form, e.g. a *= b
            self.handle_compound_assignment(node)
            return
        # direct form, i.e. a = b
        self.handle_simple_assignment(node)
Exemplo n.º 13
0
 def visit_declaration(self, node):
     """
     Compare the declared type of a declaration with the type of its initializing expression.

     Declarations without an expression, and expressions carrying the "added" source position, are
     skipped. An error-typed expression is reported as a missing type; mismatching types are handed to
     the type caster.
     :param node: a single declaration.
     :type node: ASTDeclaration
     """
     assert isinstance(node, ASTDeclaration)
     if not node.has_expression():
         return
     rhs = node.get_expression()
     if rhs.get_source_position().equals(ASTSourceLocation.get_added_source_position()):
         # added nodes are assumed to be correct, hence no type check
         return
     lhs_type = node.get_data_type().get_type_symbol()
     rhs_type = rhs.type
     if isinstance(rhs_type, ErrorTypeSymbol):
         LoggingHelper.drop_missing_type_error(node)
     elif self.__types_do_not_match(lhs_type, rhs_type):
         TypeCaster.try_to_recover_or_error(lhs_type, rhs_type, rhs)
Exemplo n.º 14
0
    def is_conductance_based(self) -> bool:
        """
        Tell whether this element is conductance based, judged by the unit of its type symbol.

        The type symbol is tested for castability to the unit "S" and to the unit "A". If exactly one of
        the two casts is possible, the result of the "S" test is returned. If both or neither are
        possible, a warning is logged and False is returned.

        :return: True if conductance based, otherwise False.
        """
        # "S" is tested first, then "A"
        is_cond_based, is_curr_based = (
            self.type_symbol.is_castable_to(UnitTypeSymbol(unit=PredefinedUnits.get_unit(unit_name)))
            for unit_name in ("S", "A"))
        if is_cond_based != is_curr_based:
            return is_cond_based

        # ambiguous outcome: warn and fall back to "not conductance based"
        code, message = Messages.get_could_not_determine_cond_based(
            type_str=self.type_symbol.print_nestml_type(), name=self.name)
        Logger.log_message(node=None,
                           code=code,
                           message=message,
                           log_level=LoggingLevel.WARNING,
                           error_position=ASTSourceLocation.get_added_source_position())
        return False
Exemplo n.º 15
0
from pynestml.generated.PyNestMLParser import PyNestMLParser
from pynestml.symbol_table.symbol_table import SymbolTable
from pynestml.symbols.predefined_functions import PredefinedFunctions
from pynestml.symbols.predefined_types import PredefinedTypes
from pynestml.symbols.predefined_units import PredefinedUnits
from pynestml.symbols.predefined_variables import PredefinedVariables
from pynestml.utils.logger import LoggingLevel, Logger
from pynestml.visitors.ast_builder_visitor import ASTBuilderVisitor

# Set up the shared infrastructure once, when this module is imported: register the predefined
# units, types, functions and variables, then initialize the symbol table with an all-zero
# source location and the logger at INFO level.
PredefinedUnits.register_units()
PredefinedTypes.register_types()
PredefinedFunctions.register_functions()
PredefinedVariables.register_variables()
SymbolTable.initialize_symbol_table(
    ASTSourceLocation(start_line=0, start_column=0, end_line=0, end_column=0))
Logger.init_logger(LoggingLevel.INFO)


class ASTBuildingTest(unittest.TestCase):
    """
    Walks over the ``.nestml`` files in the ``models`` directory one level above this file's directory
    and opens each of them as a ``FileStream``.
    """
    def test(self):
        # list <directory of this file>/../models
        for filename in os.listdir(
                os.path.realpath(
                    os.path.join(os.path.dirname(__file__),
                                 os.path.join('..', 'models')))):
            # only model files are considered
            if filename.endswith(".nestml"):
                # NOTE(review): the trailing comma after print(...) looks like a Python 2 leftover; under
                # Python 3 it merely builds a discarded one-element tuple -- confirm and drop it
                print('Start creating AST for ' + filename + ' ...'),
                # the same ../models path is rebuilt here, this time including the file name
                input_file = FileStream(
                    os.path.join(
                        os.path.dirname(__file__),
                        os.path.join(os.path.join('..', 'models'), filename)))
Exemplo n.º 16
0
from pynestml.symbols.predefined_types import PredefinedTypes
from pynestml.symbols.predefined_units import PredefinedUnits
from pynestml.symbols.predefined_variables import PredefinedVariables
from pynestml.utils.ast_source_location import ASTSourceLocation
from pynestml.utils.logger import LoggingLevel, Logger
from pynestml.visitors.ast_builder_visitor import ASTBuilderVisitor
from pynestml.visitors.ast_visitor import ASTVisitor
from pynestml.visitors.comment_collector_visitor import CommentCollectorVisitor


# setups the infrastructure
PredefinedUnits.register_units()
PredefinedTypes.register_types()
PredefinedFunctions.register_functions()
PredefinedVariables.register_variables()
SymbolTable.initialize_symbol_table(ASTSourceLocation(start_line=0, start_column=0, end_line=0, end_column=0))
Logger.init_logger(LoggingLevel.ERROR)


class DocstringCommentException(Exception):
    """Exception type that the expected-failure docstring test is declared to raise."""


class DocstringCommentTest(unittest.TestCase):
    """Runs ``run_docstring_test`` once for the 'valid' case and once for the 'invalid' case."""

    def test_docstring_success(self):
        # the 'valid' case is expected to pass
        self.run_docstring_test('valid')

    # Expected failure: with strict=True an unexpected pass is reported as a failure, and
    # raises=... restricts the accepted failure to DocstringCommentException.
    @pytest.mark.xfail(strict=True, raises=DocstringCommentException)
    def test_docstring_failure(self):
        self.run_docstring_test('invalid')
Exemplo n.º 17
0
def log_set_added_source_position(node):
    """
    Tag the handed over node with the "added" source position.

    :param node: an object offering ``set_source_position``
    """
    added_position = ASTSourceLocation.get_added_source_position()
    node.set_source_position(added_position)
Exemplo n.º 18
0
 def visitExpression(self, ctx):
     """
     Build the AST node that corresponds to an expression parse context.

     A simple expression is delegated to ``visitSimpleExpression``. Otherwise the parts of the context
     are visited and one of three nodes is created: a (possibly encapsulated, negated or unary-prefixed)
     expression around ``ctx.term``, a compound expression from left operand, binary operator and right
     operand, or a ternary expression from condition and both alternatives.
     :param ctx: the expression parse context
     :return: the created AST node
     :raises RuntimeError: if the context matches none of the recognized shapes
     """
     # a plain simple expression is handed on directly
     simple_ctx = ctx.simpleExpression()
     if simple_ctx is not None:
         return self.visitSimpleExpression(simple_ctx)

     # parts of an encapsulated, unary-prefixed or negated expression
     is_encapsulated = bool(ctx.leftParentheses is not None and ctx.rightParentheses)
     unary_ctx = ctx.unaryOperator()
     unary_operator = None if unary_ctx is None else self.visit(unary_ctx)
     is_logical_not = ctx.logicalNot is not None
     expression = None if ctx.term is None else self.visit(ctx.term)

     # parts of a binary expression: left operand, operator, right operand
     lhs = None if ctx.left is None else self.visit(ctx.left)
     binary_operator = None
     # arithmetic operator tokens, tested in this order, with the factory flag each one sets
     arithmetic_tokens = (('powOp', 'is_pow_op'),
                          ('timesOp', 'is_times_op'),
                          ('divOp', 'is_div_op'),
                          ('moduloOp', 'is_modulo_op'),
                          ('plusOp', 'is_plus_op'),
                          ('minusOp', 'is_minus_op'))
     for token_name, factory_flag in arithmetic_tokens:
         token = getattr(ctx, token_name)
         if token is None:
             continue
         # the operator position starts and ends at the token itself
         token_pos = ASTSourceLocation.make_ast_source_position(start_line=token.line,
                                                                start_column=token.column,
                                                                end_line=token.line,
                                                                end_column=token.column)
         binary_operator = ASTNodeFactory.create_ast_arithmetic_operator(source_position=token_pos,
                                                                         **{factory_flag: True})
         break
     else:
         # no arithmetic token present: try bit, comparison and logical operators, in that order
         for operator_getter in (ctx.bitOperator, ctx.comparisonOperator, ctx.logicalOperator):
             operator_ctx = operator_getter()
             if operator_ctx is not None:
                 binary_operator = self.visit(operator_ctx)
                 break
     rhs = None if ctx.right is None else self.visit(ctx.right)

     # parts of a ternary expression
     condition = None if ctx.condition is None else self.visit(ctx.condition)
     if_true = None if ctx.ifTrue is None else self.visit(ctx.ifTrue)
     if_not = None if ctx.ifNot is None else self.visit(ctx.ifNot)

     source_pos = create_source_pos(ctx)
     # pick the node kind by which parts are present
     if expression is not None:
         return ASTNodeFactory.create_ast_expression(is_encapsulated=is_encapsulated,
                                                     is_logical_not=is_logical_not,
                                                     unary_operator=unary_operator,
                                                     expression=expression,
                                                     source_position=source_pos)
     if lhs is not None and rhs is not None and binary_operator is not None:
         return ASTNodeFactory.create_ast_compound_expression(lhs=lhs,
                                                              binary_operator=binary_operator,
                                                              rhs=rhs,
                                                              source_position=source_pos)
     if condition is not None and if_true is not None and if_not is not None:
         return ASTNodeFactory.create_ast_ternary_expression(condition=condition,
                                                             if_true=if_true,
                                                             if_not=if_not,
                                                             source_position=source_pos)
     raise RuntimeError('Type of rhs @%s,%s not recognized!' % (ctx.start.line, ctx.start.column))