예제 #1
0
파일: ast_utils.py 프로젝트: nest/nestml
    def add_suffix_to_variable_name(cls, var_name: str, astnode: ASTNode, suffix: str, scope=None):
        """Append ``suffix`` to every occurrence of the variable ``var_name`` inside ``astnode``.

        The whole subtree below ``astnode`` is visited. Both bare ``ASTVariable`` nodes and simple
        expressions that consist of a single variable are inspected. A variable is renamed in place
        (via ``set_name``) when its name equals ``var_name``, is not ``"t"``, and does not already
        contain ``suffix``.

        :param var_name: name of the variable to rename
        :param astnode: root of the subtree to process; modified in place
        :param suffix: string appended to matching variable names
        :param scope: currently unused; kept only for backward compatibility with existing callers
        """
        def replace_var(_expr=None):
            if isinstance(_expr, ASTSimpleExpression) and _expr.is_variable():
                var = _expr.get_variable()
            elif isinstance(_expr, ASTVariable):
                var = _expr
            else:
                return

            name = var.get_name()
            if name == var_name and name != "t" and suffix not in name:
                var.set_name(name + suffix)

        astnode.accept(ASTHigherOrderVisitor(replace_var))
예제 #2
0
파일: ast_utils.py 프로젝트: nest/nestml
    def get_all_variables(cls, node: ASTNode) -> List[str]:
        """Make a list of the names of all variable symbols declared in ``node``.

        Every declaration below ``node`` is visited, and the symbol of its first declared variable is
        resolved as a variable in the declaration's scope. If that variable cannot be resolved, an
        error is logged and the declaration is skipped.

        :param node: root of the subtree to search; ``None`` is allowed
        :return: the symbol names in visiting order; an empty list if ``node`` is ``None``
        """
        class ASTVariablesFinderVisitor(ASTVisitor):

            def __init__(self):
                super(ASTVariablesFinderVisitor, self).__init__()
                # per-instance list; a class-level list would be shared by every instance
                self._variables = []

            def visit_declaration(self, node):
                # only the first variable of the declaration is looked up
                var_name = node.get_variables()[0].get_complete_name()
                symbol = node.get_scope().resolve_to_symbol(var_name, SymbolKind.VARIABLE)
                if symbol is None:
                    # report the same name that failed to resolve; the declaration is accessed via
                    # get_variables(), and the logger's node keyword is ``node``
                    code, message = Messages.get_variable_not_defined(var_name)
                    Logger.log_message(node=node, code=code, message=message,
                                       error_position=node.get_source_position(),
                                       log_level=LoggingLevel.ERROR)
                    return

                self._variables.append(symbol)

        if node is None:
            return []

        visitor = ASTVariablesFinderVisitor()
        node.accept(visitor)
        return [v.name for v in visitor._variables]
예제 #3
0
    def get_inline_expression_symbols(cls,
                                      ast: ASTNode) -> List[VariableSymbol]:
        """
        Gather the variable symbols of all inline expressions referenced inside the given AST node.

        Every ``ASTVariable`` below ``ast`` is collected. Names containing a prime (``'``) are skipped;
        presumably these are derivative references such as ``x'``. The remaining names are resolved as
        variables in the scope of ``ast``, and each resolved symbol flagged ``is_inline_expression`` is
        kept. A symbol is listed once per occurrence, so duplicates are possible.

        :param ast: a single AST node
        :return: a list of all inline expression variable symbols, in visiting order
        """
        from pynestml.visitors.ast_higher_order_visitor import ASTHigherOrderVisitor
        from pynestml.meta_model.ast_variable import ASTVariable

        variables_found = []

        def _remember_variable(visited):
            if isinstance(visited, ASTVariable):
                variables_found.append(visited)

        ast.accept(ASTHigherOrderVisitor(visit_funcs=_remember_variable))

        inline_symbols = []
        for variable in variables_found:
            complete_name = variable.get_complete_name()
            if '\'' in complete_name:
                continue
            candidate = ast.get_scope().resolve_to_symbol(complete_name, SymbolKind.VARIABLE)
            if candidate is not None and candidate.is_inline_expression:
                inline_symbols.append(candidate)
        return inline_symbols
예제 #4
0
파일: ast_utils.py 프로젝트: nest/nestml
    def replace_with_external_variable(cls, var_name, node: ASTNode, suffix, new_scope, alternate_name=None):
        """
        Replace all occurrences of variables (``ASTVariable``s) (e.g. ``post_trace'``) in the node with ``ASTExternalVariable``s, indicating that they are moved to the postsynaptic neuron.

        :param var_name: name of the variable whose occurrences are replaced
        :param node: root of the subtree to rewrite; modified in place
        :param suffix: appended to ``var_name`` to form the name of each new ``ASTExternalVariable``
        :param new_scope: passed to ``update_alt_scope`` of each new external variable
        :param alternate_name: if truthy, set as the alternate name of each new external variable
        :raises Exception: if a matching variable occurs in a context that cannot be rewritten
        """
        from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor

        def replace_var(_expr=None):
            if isinstance(_expr, ASTSimpleExpression) and _expr.is_variable():
                var = _expr.get_variable()
            elif isinstance(_expr, ASTVariable):
                var = _expr
            else:
                return

            if var.get_name() != var_name:
                return

            ast_ext_var = ASTExternalVariable(var.get_name() + suffix,
                                              differential_order=var.get_differential_order(),
                                              source_position=var.get_source_position())
            if alternate_name:
                ast_ext_var.set_alternate_name(alternate_name)

            ast_ext_var.update_alt_scope(new_scope)
            ast_ext_var.accept(ASTSymbolTableVisitor())

            if isinstance(_expr, ASTSimpleExpression):
                # a simple expression holding just this variable: swap the variable in place
                Logger.log_message(None, -1, "ASTSimpleExpression replacement made (var = " + str(
                    ast_ext_var.get_name()) + ") in expression: " + str(node.get_parent(_expr)), None, LoggingLevel.INFO)
                _expr.set_variable(ast_ext_var)
                return

            # a bare ASTVariable: the replacement has to be made in its parent. Look the parent up once,
            # before mutating it; after ``lhs`` is replaced, ``_expr`` is no longer attached to it.
            parent = node.get_parent(_expr)
            if isinstance(parent, ASTAssignment):
                parent.lhs = ast_ext_var
                Logger.log_message(None, -1, "ASTVariable replacement made in expression: "
                                   + str(parent), None, LoggingLevel.INFO)
            elif isinstance(parent, ASTSimpleExpression) and parent.is_variable():
                parent.set_variable(ast_ext_var)
            elif isinstance(parent, ASTDeclaration):
                # variable could occur on the left-hand side; ignore. Only replace if it occurs on the right-hand side.
                pass
            else:
                Logger.log_message(None, -1, "Error: unhandled use of variable "
                                   + var_name + " in expression " + str(_expr), None, LoggingLevel.ERROR)
                raise Exception("Unhandled use of variable " + var_name + " in expression " + str(_expr))

        node.accept(ASTHigherOrderVisitor(replace_var))
예제 #5
0
    def add_suffix_to_variable_name(cls, var_name: str, astnode: ASTNode,
                                    suffix: str):
        """
        Rename, in place, every variable called ``var_name`` below ``astnode`` by appending ``suffix``.

        Both bare ``ASTVariable`` nodes and simple expressions consisting of a single variable are
        inspected. Variables named ``"t"`` and variables whose name already contains ``suffix`` are
        never renamed.
        """
        def _rename_if_match(candidate=None):
            if isinstance(candidate, ASTSimpleExpression) and candidate.is_variable():
                target = candidate.get_variable()
            elif isinstance(candidate, ASTVariable):
                target = candidate
            else:
                return

            current = target.get_name()
            if suffix in current or current == "t" or current != var_name:
                return
            target.set_name(current + suffix)

        astnode.accept(ASTHigherOrderVisitor(_rename_if_match))
예제 #6
0
    def collect_variable_names_in_expression(
            cls, expr: ASTNode) -> List[ASTVariable]:
        """
        Collect every variable occurrence (``ASTVariable``) found below ``expr``.

        Despite the method name, the ``ASTVariable`` nodes themselves are returned, not their names.
        Bare ``ASTVariable`` nodes and simple expressions that consist of a single variable both
        contribute an entry.
        NOTE(review): if the visitor also descends into the variable held by such a simple expression,
        that variable is collected twice; confirm against ASTHigherOrderVisitor's traversal.

        :param expr: expression to collect the variables from
        :return: a list of variables, in visiting order
        """
        collected = []

        def _record(visited=None):
            if isinstance(visited, ASTSimpleExpression) and visited.is_variable():
                found = visited.get_variable()
            elif isinstance(visited, ASTVariable):
                found = visited
            else:
                return
            if found:
                collected.append(found)

        expr.accept(ASTHigherOrderVisitor(_record))
        return collected
예제 #7
0
파일: ast_utils.py 프로젝트: nest/nestml
    def get_all_variables_used_in_convolutions(cls, node: ASTNode, parent_node: ASTNode) -> List[str]:
        """Make a list of the names of inline expressions in ``node`` that convolve an input port.

        The method looks at every ``convolve(...)`` call below ``node``. A call counts when its second
        argument resolves to a variable symbol whose name matches an input port of ``parent_node``. For
        each such call, the enclosing ``ASTInlineExpression`` is looked up via ``parent_node.get_parent``
        and its variable name is recorded. Calls without such an enclosing inline expression are skipped.

        :param node: subtree to search; ``None`` is allowed
        :param parent_node: node providing ``get_input_blocks()`` and ``get_parent()``
        :return: the collected names in visiting order (duplicates possible); an empty list if ``node`` is ``None``
        """
        from pynestml.codegeneration.ast_transformers import ASTTransformers

        class ASTAllVariablesUsedInConvolutionVisitor(ASTVisitor):

            def __init__(self, parent_node):
                super(ASTAllVariablesUsedInConvolutionVisitor, self).__init__()
                # per-instance state instead of shared class attributes
                self._variables = []
                self.parent_node = parent_node

            def visit_function_call(self, node):
                if node.get_name() != 'convolve':
                    return
                symbol_buffer = node.get_scope().resolve_to_symbol(str(node.get_args()[1]),
                                                                   SymbolKind.VARIABLE)
                if symbol_buffer is None:
                    # unresolvable second argument: nothing to match against the input ports
                    return
                input_port = ASTTransformers.get_input_port_by_name(
                    self.parent_node.get_input_blocks(), symbol_buffer.name)
                if not input_port:
                    return

                # walk up from the call to the enclosing inline expression
                # XXX TODO also needs to accept normal ASTExpression, ASTAssignment?
                node_ = self.parent_node.get_parent(node)
                while node_ is not None and not isinstance(node_, ASTInlineExpression):
                    node_ = self.parent_node.get_parent(node_)
                if node_ is None:
                    # top of the tree reached without an inline expression; stop instead of looping
                    return
                self._variables.append(node_.get_variable_name())

        if node is None:
            return []

        visitor = ASTAllVariablesUsedInConvolutionVisitor(parent_node)
        node.accept(visitor)
        return visitor._variables
예제 #8
0
    def check_co_co(cls, neuron: ASTNode, after_ast_rewrite: bool):
        """
        Checks if this coco applies for the handed over neuron. Models which use not defined elements are not
        correct.

        The following errors are logged for each variable occurring in an expression of ``neuron``:
        - it resolves neither to a variable symbol nor to a type symbol (e.g. the unit ``mV``);
        - it is used before its declaration (parameters and internals are exempt);
        - it is used inside its own declaration (recursive definition, e.g. ``V_m mV = V_m + 1``).
        Variables in the invariant of a declaration are exempt from the last two checks. So are
        predefined symbols, spike/current input buffers and symbols whose source position was added by
        AST transformations. Finally, the left-hand side of every assignment is checked by
        ``ASTAssignedVariableDefinedVisitor``.

        :param neuron: a single neuron instance.
        :param after_ast_rewrite: passed on to ``ASTAssignedVariableDefinedVisitor``
        """
        # for each variable in all expressions, check if the variable has been defined previously
        expression_collector_visitor = ASTExpressionCollectorVisitor()
        neuron.accept(expression_collector_visitor)
        expressions = expression_collector_visitor.ret
        for expr in expressions:
            # required to handle invariants differently; it depends only on the expression,
            # so it is looked up once per expression rather than once per variable
            expr_par = neuron.get_parent(expr)
            for var in expr.get_variables():
                symbol = var.get_scope().resolve_to_symbol(
                    var.get_complete_name(), SymbolKind.VARIABLE)

                if symbol is None:
                    # check if this symbol is actually a type, e.g. "mV" in the expression "(1 + 2) * mV"
                    symbol = var.get_scope().resolve_to_symbol(
                        var.get_complete_name(), SymbolKind.TYPE)
                    if symbol is None:
                        # symbol has not been defined; neither as a variable name nor as a type symbol
                        code, message = Messages.get_variable_not_defined(
                            var.get_name())
                        Logger.log_message(
                            node=neuron,
                            code=code,
                            message=message,
                            log_level=LoggingLevel.ERROR,
                            error_position=var.get_source_position())
                # first check if it is part of an invariant
                # if it is the case, there is no "recursive" declaration
                # so check if the parent is a declaration and the expression the invariant
                elif isinstance(expr_par, ASTDeclaration) and expr_par.get_invariant() == expr:
                    # in this case its ok if it is recursive or defined later on
                    continue

                # now check if it has been defined before usage, except for predefined symbols, buffers and
                # variables added by the AST transformation functions
                elif (not symbol.is_predefined) \
                        and symbol.block_type not in (BlockType.INPUT_BUFFER_CURRENT, BlockType.INPUT_BUFFER_SPIKE) \
                        and not symbol.get_referenced_object().get_source_position().is_added_source_position():
                    decl_position = symbol.get_referenced_object().get_source_position()
                    # used before its declaration; parameters and internals may be defined after usage
                    if not decl_position.before(var.get_source_position()) \
                            and symbol.block_type not in [BlockType.PARAMETERS, BlockType.INTERNALS]:
                        code, message = Messages.get_variable_used_before_declaration(
                            var.get_name())
                        Logger.log_message(
                            node=neuron,
                            message=message,
                            error_position=var.get_source_position(),
                            code=code,
                            log_level=LoggingLevel.ERROR)
                    # now check that they are not defined recursively, e.g. V_m mV = V_m + 1
                    # todo: we should not check this for invariants
                    if decl_position.encloses(var.get_source_position()) \
                            and not decl_position.is_added_source_position():
                        code, message = Messages.get_variable_defined_recursively(
                            var.get_name())
                        Logger.log_message(
                            node=neuron,
                            code=code,
                            message=message,
                            error_position=decl_position,
                            log_level=LoggingLevel.ERROR)

        # now check for each assignment whether the left hand side variable is defined
        vis = ASTAssignedVariableDefinedVisitor(neuron, after_ast_rewrite)
        neuron.accept(vis)