예제 #1
0
    def visit_simple_expression(self, node):
        """
        Visits a single function call as stored in a simple expression and derives the type of the expression.

        The called function is resolved in the scope of the node and the result is stored in ``node.type``:

        - function cannot be resolved: an error is logged and the type is an ``ErrorTypeSymbol``;
        - ``PredefinedFunctions.CONVOLVE``: the type of the variable handed over as second argument; if that argument
          is not a variable or cannot be resolved, an error is logged and the type is an ``ErrorTypeSymbol``;
        - return type is a ``VoidTypeSymbol``: ``ErrorTypeSymbol`` (currently without a log message);
        - otherwise: the return type of the resolved function symbol.

        :param node: a simple expression
        :type node: ASTSimpleExpression
        :rtype: void
        """
        # NOTE: the message is formatted with type(node); formatting with tuple(node) would raise a TypeError for
        # non-iterable objects and hide the actual assertion failure.
        assert isinstance(node, ASTSimpleExpression), \
            '(PyNestML.Visitor.FunctionCallVisitor) No or wrong type of simple expression provided (%s)!' % type(node)
        assert (node.get_scope() is not None), \
            "(PyNestML.Visitor.FunctionCallVisitor) No scope found, run symboltable creator!"
        scope = node.get_scope()
        function_name = node.get_function_call().get_name()
        method_symbol = scope.resolve_to_symbol(function_name,
                                                SymbolKind.FUNCTION)
        # check if this function exists
        if method_symbol is None:
            code, message = Messages.get_could_not_resolve(function_name)
            Logger.log_message(code=code,
                               message=message,
                               error_position=node.get_source_position(),
                               log_level=LoggingLevel.ERROR)
            node.type = ErrorTypeSymbol()
            return
        return_type = method_symbol.get_return_type()
        # remember which expression this type belongs to
        return_type.referenced_object = node

        # convolve symbol does not have a return type set.
        # returns whatever type the second parameter is.
        if function_name == PredefinedFunctions.CONVOLVE:
            # Deviations from the assumptions made here are handled in the convolveCoco
            buffer_parameter = node.get_function_call().get_args()[1]

            # snake_case accessors, consistent with every other accessor used in this method
            if buffer_parameter.get_variable() is not None:
                buffer_name = buffer_parameter.get_variable().get_name()
                buffer_symbol_resolve = scope.resolve_to_symbol(
                    buffer_name, SymbolKind.VARIABLE)
                if buffer_symbol_resolve is not None:
                    node.type = buffer_symbol_resolve.get_type_symbol()
                    return

            # getting here means there is an error with the parameters to convolve
            code, message = Messages.get_convolve_needs_buffer_parameter()
            Logger.log_message(code=code,
                               message=message,
                               error_position=node.get_source_position(),
                               log_level=LoggingLevel.ERROR)
            node.type = ErrorTypeSymbol()
            return

        if isinstance(method_symbol.get_return_type(), VoidTypeSymbol):
            # TODO: no message is logged for a void function used as an expression; only the type is marked as
            # erroneous.
            node.type = ErrorTypeSymbol()
            return

        # if nothing special is handled, just get the expression type from the return type of the function
        node.type = return_type
예제 #2
0
 def endvisit_unit_type(self, node):
     """
     Derives the type symbol of a unit type node after its sub-nodes have been handled.

     An encapsulated node takes over the type symbol of its compound unit. For a power, division or multiplication
     the astropy units of the operands are combined, converted with ``handle_unit`` and the result is stored both
     on the node and in ``self.symbol``. If the left-hand side of a division or multiplication is a unit type
     whose type symbol is ``None`` or an ``ErrorTypeSymbol``, the node is typed as ``ErrorTypeSymbol`` and
     ``self.symbol`` is left untouched. A node with none of the four flags set is not modified.
     :param node: a single unit type object.
     :type node: ast_unit_type
     """
     if node.is_encapsulated:
         node.set_type_symbol(node.compound_unit.get_type_symbol())
         return

     if node.is_pow:
         base_symbol = node.base.get_type_symbol()
         exponent = node.exponent
         res = handle_unit(base_symbol.astropy_unit ** exponent)
     elif node.is_div or node.is_times:
         lhs = node.get_lhs()
         if isinstance(lhs, ASTUnitType):  # regard that lhs can be a numeric or a unit-type
             lhs_symbol = lhs.get_type_symbol()
             if lhs_symbol is None or isinstance(lhs_symbol, ErrorTypeSymbol):
                 # The operand could not be typed: propagate the error instead of dereferencing
                 # `.astropy_unit` on it. The same guard now covers division and multiplication.
                 node.set_type_symbol(ErrorTypeSymbol())
                 return
             lhs = lhs_symbol.astropy_unit
         rhs = node.get_rhs().get_type_symbol().astropy_unit
         # division is checked first, i.e. it wins if both flags should be set
         res = handle_unit(lhs / rhs if node.is_div else lhs * rhs)
     else:
         return

     node.set_type_symbol(res)
     self.symbol = res
예제 #3
0
    def visit_expression(self, node):
        """
        Visits a single expression containing a plus or minus operator and updates its type.

        The type is the result of adding or subtracting the operand types. If the operator is neither plus nor minus,
        or the resulting type is an ``ErrorTypeSymbol``, an error naming both operands and their types is logged.
        :param node: a single expression
        :type node: ast_expression
        """
        left_operand = node.get_lhs()
        right_operand = node.get_rhs()
        left_type = left_operand.type
        right_type = right_operand.type

        operator = node.get_binary_operator()

        # link each operand type back to the expression it was derived from
        left_type.referenced_object = left_operand
        right_type.referenced_object = right_operand

        # start from the error type; it only survives if no operator branch replaces it
        node.type = ErrorTypeSymbol()
        if operator.is_plus_op:
            node.type = left_type + right_type
        elif operator.is_minus_op:
            node.type = left_type - right_type

        if not isinstance(node.type, ErrorTypeSymbol):
            return

        code, message = Messages.get_binary_operation_type_could_not_be_derived(
            lhs=str(left_operand),
            operator=str(operator),
            rhs=str(right_operand),
            lhs_type=str(left_type.print_nestml_type()),
            rhs_type=str(right_type.print_nestml_type()))
        Logger.log_message(code=code,
                           message=message,
                           error_position=node.get_source_position(),
                           log_level=LoggingLevel.ERROR)
예제 #4
0
    def visit_simple_expression(self, node):
        """
        Visits a single variable as contained in a simple expression and derives its type.

        The variable is looked up in the scope of the node. On success the node receives the type symbol of the
        variable symbol; otherwise an error is logged and the node is typed as ``ErrorTypeSymbol``.
        :param node: a single simple expression
        :type node: ASTSimpleExpression
        """
        assert isinstance(node, ASTSimpleExpression), \
            '(PyNestML.Visitor.VariableVisitor) No or wrong type of simple expression provided (%s)!' % type(node)
        assert (node.get_scope() is not None), \
            '(PyNestML.Visitor.VariableVisitor) No scope found, run symboltable creator!'

        name = node.get_variable().get_name()
        symbol = node.get_scope().resolve_to_symbol(name, SymbolKind.VARIABLE)

        if symbol is None:
            # unknown variable: report it and mark the expression as erroneous
            message = 'Variable ' + str(node) + ' could not be resolved!'
            Logger.log_message(code=MessageCode.SYMBOL_NOT_RESOLVED,
                               error_position=node.get_source_position(),
                               message=message,
                               log_level=LoggingLevel.ERROR)
            node.type = ErrorTypeSymbol()
            return

        # the expression has the type of the variable it refers to
        node.type = symbol.get_type_symbol()
        node.type.referenced_object = node
예제 #5
0
    def visit_simple_expression(self, node):
        """
        Visits a single function call as stored in a simple expression and checks whether it generates a random
        number.

        Expressions without a function call are ignored. If the called function cannot be resolved, an error is logged
        and the node is typed as ``ErrorTypeSymbol``. If the called function is
        ``PredefinedFunctions.RANDOM_NORMAL``, ``self._norm_rng_is_used`` is set to ``True`` so that the necessary
        initialisers can be called at the right time in the generated code.
        :param node: a simple expression
        :type node: ASTSimpleExpression
        """
        # NOTE: the message is formatted with type(node); formatting with tuple(node) would raise a TypeError for
        # non-iterable objects and hide the actual assertion failure.
        assert isinstance(node, ASTSimpleExpression), \
            '(PyNestML.Visitor.FunctionCallVisitor) No or wrong type of simple expression provided (%s)!' % type(node)
        assert (node.get_scope() is not None), \
            "(PyNestML.Visitor.FunctionCallVisitor) No scope found, run symboltable creator!"
        scope = node.get_scope()
        if node.get_function_call() is None:
            return
        function_name = node.get_function_call().get_name()
        method_symbol = scope.resolve_to_symbol(function_name, SymbolKind.FUNCTION)

        # check if this function exists
        if method_symbol is None:
            code, message = Messages.get_could_not_resolve(function_name)
            Logger.log_message(code=code, message=message, error_position=node.get_source_position(),
                               log_level=LoggingLevel.ERROR)
            node.type = ErrorTypeSymbol()
            return

        if function_name == PredefinedFunctions.RANDOM_NORMAL:
            self._norm_rng_is_used = True
예제 #6
0
    def visit_expression(self, node):
        """
        Visits an expression which uses a binary logic operator and updates the type.

        If both operands are of boolean type, the expression is boolean. Otherwise an error is logged for the first
        operand that is not boolean (at the source position of that operand) and the expression is typed as
        ``ErrorTypeSymbol``.
        :param node: a single expression.
        :type node: ast_expression
        """
        lhs = node.get_lhs()
        rhs = node.get_rhs()
        lhs_type = lhs.type
        rhs_type = rhs.type

        lhs_type.referenced_object = lhs
        rhs_type.referenced_object = rhs

        lhs_is_boolean = isinstance(lhs_type, BooleanTypeSymbol)
        if lhs_is_boolean and isinstance(rhs_type, BooleanTypeSymbol):
            node.type = PredefinedTypes.get_boolean_type()
            return

        # Report the operand that actually violates the expectation: if the lhs is boolean, the rhs must be the
        # culprit; otherwise the lhs is the first non-boolean operand.
        if lhs_is_boolean:
            offending_operand, offending_type = rhs, rhs_type
        else:
            offending_operand, offending_type = lhs, lhs_type

        code, message = Messages.get_type_different_from_expected(
            BooleanTypeSymbol(), offending_type)
        # The position is taken from the operand itself rather than from `referenced_object`, which would point to
        # the rhs if both operands happen to share one type symbol instance.
        Logger.log_message(code=code,
                           message=message,
                           error_position=offending_operand.get_source_position(),
                           log_level=LoggingLevel.ERROR)
        node.type = ErrorTypeSymbol()
예제 #7
0
    def visit_simple_expression(self, node):
        """
        Visit a simple rhs and update the type of a numeric literal.

        If the rhs also carries a variable, the variable decides the type: its name is resolved as a variable first,
        then as a type, and an ``ErrorTypeSymbol`` is used if both lookups fail. Without a variable, a float literal
        yields the predefined real type and an int literal the predefined integer type; any other literal leaves the
        node untouched.
        :param node: a single meta_model node
        :type node: ast_node
        :return: no value returned, the type is updated in-place
        :rtype: void
        """
        assert node.get_scope() is not None, "Run symboltable creator."

        variable = node.get_variable()
        if variable is not None:
            # if variable is also set in this rhs, the var type overrides the literal
            scope = node.get_scope()
            name = variable.get_name()
            as_variable = scope.resolve_to_symbol(name, SymbolKind.VARIABLE)
            if as_variable is not None:
                node.type = as_variable.get_type_symbol()
            else:
                # not a variable: the name may still denote a type
                as_type = scope.resolve_to_symbol(name, SymbolKind.TYPE)
                node.type = as_type if as_type is not None else ErrorTypeSymbol()
            node.type.referenced_object = node
            return

        literal = node.get_numeric_literal()
        if isinstance(literal, float):
            node.type = PredefinedTypes.get_real_type()
        elif isinstance(literal, int):
            node.type = PredefinedTypes.get_integer_type()
        else:
            # no numeric literal of a known kind: nothing to update
            return
        node.type.referenced_object = node
예제 #8
0
 def unary_operation_not_defined_error(self, _operator):
     """
     Logs an error stating that the given unary operator is not defined for this symbol.

     The error is reported at the source position of ``self.referenced_object``.
     :param _operator: the unary operator that was applied; it is passed on to the message factory.
     :return: a new ``ErrorTypeSymbol``
     """
     # imported locally, as in the original implementation
     from pynestml.symbols.error_type_symbol import ErrorTypeSymbol
     error_type = ErrorTypeSymbol()
     code, message = Messages.get_unary_operation_not_defined(_operator, self.print_symbol())
     position = self.referenced_object.get_source_position()
     Logger.log_message(code=code, message=message, error_position=position, log_level=LoggingLevel.ERROR)
     return error_type
 def visit_unit_type(self, node):
     """
     Checks the numerator of a unit type division.

     If the node is a division whose left-hand side is an integer other than 1, the node is typed as
     ``ErrorTypeSymbol`` and an error is logged. All other nodes are left untouched.
     :param node: a single unit type object.
     :type node: ast_unit_type
     """
     if not (node.is_div and isinstance(node.lhs, int)):
         return
     if node.lhs == 1:
         # "1 / unit" is the only accepted integer numerator
         return

     from pynestml.symbols.error_type_symbol import ErrorTypeSymbol
     node.set_type_symbol(ErrorTypeSymbol())
     code, message = Messages.get_wrong_numerator(str(node))
     Logger.log_message(code=code,
                        message=message,
                        error_position=node.get_source_position(),
                        log_level=LoggingLevel.ERROR)
예제 #10
0
 def visit_expression(self, node):
     """
     Visits a single rhs for which no semantics have been implemented.

     The node is typed as ``ErrorTypeSymbol`` and a message is logged, but only with warning level.
     :param node: a single rhs
     :type node: ast_expression or ast_simple_expression
     """
     warning_text = ErrorStrings.message_no_semantics(self, str(node), node.get_source_position())
     node.type = ErrorTypeSymbol()
     # just warn though
     Logger.log_message(code=MessageCode.NO_SEMANTICS,
                        message=warning_text,
                        log_level=LoggingLevel.WARNING,
                        error_position=node.get_source_position())
예제 #11
0
    def visit_expression(self, expr):
        """
        Visits a single comparison operator expression and updates the type.

        The expression is boolean if both operands are numeric primitives, if both have the same numeric type, or if
        both are boolean. A comparison between a unit type and a numeric type is also typed as boolean, but a warning
        is logged. Every other combination is logged as an error and typed as ``ErrorTypeSymbol``.
        :param expr: an expression
        :type expr: ast_expression
        """
        left_operand = expr.get_lhs()
        right_operand = expr.get_rhs()
        left_type = left_operand.type
        right_type = right_operand.type

        left_type.referenced_object = left_operand
        right_type.referenced_object = right_operand

        # the three accepted combinations are tested one after another; later ones only if the earlier ones fail
        comparable = left_type.is_numeric_primitive() and right_type.is_numeric_primitive()
        if not comparable:
            comparable = left_type.equals(right_type) and left_type.is_numeric()
        if not comparable:
            comparable = isinstance(left_type, BooleanTypeSymbol) and isinstance(right_type, BooleanTypeSymbol)
        if comparable:
            expr.type = PredefinedTypes.get_boolean_type()
            return

        # Error message for any other operation
        unit_against_numeric = (isinstance(left_type, UnitTypeSymbol) and right_type.is_numeric()) \
            or (isinstance(right_type, UnitTypeSymbol) and left_type.is_numeric())
        error_msg = ErrorStrings.message_comparison(self, expr.get_source_position())

        if unit_against_numeric:
            # if the incompatibility exists between a unit and a numeric, the c++ will still be fine, just WARN
            expr.type = PredefinedTypes.get_boolean_type()
            code = MessageCode.SOFT_INCOMPATIBILITY
            level = LoggingLevel.WARNING
        else:
            # hard incompatibility, cannot recover in c++, ERROR
            expr.type = ErrorTypeSymbol()
            code = MessageCode.HARD_INCOMPATIBILITY
            level = LoggingLevel.ERROR

        Logger.log_message(message=error_msg,
                           code=code,
                           error_position=expr.get_source_position(),
                           log_level=level)
예제 #12
0
    def visit_expression(self, node):
        """
        Visits an rhs consisting of the ternary operator and updates its type.

        Rules, checked in this order:

        - the condition is not boolean: error, ``ErrorTypeSymbol``;
        - both alternatives have equal types: that type;
        - both alternatives are (different) unit types: real, with a warning;
        - one unit type and one numeric primitive: the unit type, with a warning;
        - two (different) numeric primitives: real;
        - anything else: error, ``ErrorTypeSymbol``.
        :param node: a single rhs
        :type node: ast_expression
        """
        condition_node = node.get_condition()
        true_node = node.get_if_true()
        false_node = node.get_if_not()
        condition_type, true_type, false_type = condition_node.type, true_node.type, false_node.type

        condition_type.referenced_object = condition_node
        true_type.referenced_object = true_node
        false_type.referenced_object = false_node

        def report(text, position, level):
            # every message of this visitor uses the same message code
            Logger.log_message(message=text,
                               error_position=position,
                               code=MessageCode.TYPE_DIFFERENT_FROM_EXPECTED,
                               log_level=level)

        # Condition must be a bool
        if not condition_type.equals(PredefinedTypes.get_boolean_type()):
            text = ErrorStrings.message_ternary(self, node.get_source_position())
            node.type = ErrorTypeSymbol()
            report(text, node.get_source_position(), LoggingLevel.ERROR)
            return

        # Alternatives match exactly -> any is valid
        if true_type.equals(false_type):
            node.type = true_type
            return

        true_is_unit = isinstance(true_type, UnitTypeSymbol)
        false_is_unit = isinstance(false_type, UnitTypeSymbol)

        # Both are units but not matching-> real WARN
        if true_is_unit and false_is_unit:
            text = ErrorStrings.message_ternary_mismatch(
                self, true_type.print_symbol(), false_type.print_symbol(), node.get_source_position())
            node.type = PredefinedTypes.get_real_type()
            report(text, true_type.referenced_object.get_source_position(), LoggingLevel.WARNING)
            return

        # one Unit and one numeric primitive and vice versa -> assume unit, WARN
        if (true_is_unit and false_type.is_numeric_primitive()) \
                or (false_is_unit and true_type.is_numeric_primitive()):
            text = ErrorStrings.message_ternary_mismatch(
                self, str(true_type), str(false_type), node.get_source_position())
            node.type = true_type if true_is_unit else false_type
            report(text, true_type.referenced_object.get_source_position(), LoggingLevel.WARNING)
            return

        # both are numeric primitives (and not equal) ergo one is real and one is integer -> real
        if true_type.is_numeric_primitive() and false_type.is_numeric_primitive():
            node.type = PredefinedTypes.get_real_type()
            return

        # if we get here it is an error
        text = ErrorStrings.message_ternary_mismatch(
            self, str(true_type), str(false_type), node.get_source_position())
        node.type = ErrorTypeSymbol()
        report(text, node.get_source_position(), LoggingLevel.ERROR)
예제 #13
0
    def visit_simple_expression(self, node):
        """
        Visits a single function call as stored in a simple expression and derives the type of the expression.

        The called function is resolved in the scope of the node and the result is stored in ``node.type``:

        - function cannot be resolved: an error is logged and the type is an ``ErrorTypeSymbol``;
        - templated return type: the type of the actual argument at the position of the first parameter declared
          with that template type; if the template-type consistency check reports a failure, the type is an
          ``ErrorTypeSymbol``;
        - ``PredefinedFunctions.CONVOLVE``: the type of the variable handed over as second argument; if that argument
          is not a variable or cannot be resolved, an error is logged and the type is an ``ErrorTypeSymbol``;
        - declared return type is a ``VoidTypeSymbol``: ``ErrorTypeSymbol`` (currently without a log message);
        - otherwise: the (possibly substituted) return type of the function.

        :param node: a simple expression
        :type node: ASTSimpleExpression
        :rtype: void
        :raises AssertionError: if the templated return type does not occur among the parameter types.
        """
        # NOTE: the message is formatted with type(node); formatting with tuple(node) would raise a TypeError for
        # non-iterable objects and hide the actual assertion failure.
        assert isinstance(node, ASTSimpleExpression), \
            '(PyNestML.Visitor.FunctionCallVisitor) No or wrong type of simple expression provided (%s)!' % type(node)
        assert (node.get_scope() is not None), \
            "(PyNestML.Visitor.FunctionCallVisitor) No scope found, run symboltable creator!"
        scope = node.get_scope()
        function_name = node.get_function_call().get_name()
        method_symbol = scope.resolve_to_symbol(function_name, SymbolKind.FUNCTION)
        # check if this function exists
        if method_symbol is None:
            code, message = Messages.get_could_not_resolve(function_name)
            Logger.log_message(code=code, message=message, error_position=node.get_source_position(),
                               log_level=LoggingLevel.ERROR)
            node.type = ErrorTypeSymbol()
            return
        return_type = method_symbol.get_return_type()

        if isinstance(return_type, TemplateTypeSymbol):
            # substitute the template by the type of the first actual argument declared with the same template
            for i, arg_type in enumerate(method_symbol.param_types):
                if arg_type == return_type:
                    return_type = node.get_function_call().get_args()[i].type
                    break

            if isinstance(return_type, TemplateTypeSymbol):
                # error: return type template not found among parameter type templates
                # raised explicitly (same exception type as before) so the check also runs under `python -O`
                raise AssertionError("Return type template of function '%s' not found among its parameter types"
                                     % function_name)

            # check for consistency among actual derived types for template parameters
            from pynestml.cocos.co_co_function_argument_template_types_consistent import CorrectTemplatedArgumentTypesVisitor
            template_types_visitor = CorrectTemplatedArgumentTypesVisitor()
            template_types_visitor._failure_occurred = False
            node.accept(template_types_visitor)
            if template_types_visitor._failure_occurred:
                return_type = ErrorTypeSymbol()

        # remember which expression this type belongs to
        return_type.referenced_object = node

        # convolve symbol does not have a return type set.
        # returns whatever type the second parameter is.
        if function_name == PredefinedFunctions.CONVOLVE:
            # Deviations from the assumptions made here are handled in the convolveCoco
            buffer_parameter = node.get_function_call().get_args()[1]

            if buffer_parameter.get_variable() is not None:
                buffer_name = buffer_parameter.get_variable().get_name()
                buffer_symbol_resolve = scope.resolve_to_symbol(buffer_name, SymbolKind.VARIABLE)
                if buffer_symbol_resolve is not None:
                    node.type = buffer_symbol_resolve.get_type_symbol()
                    return

            # getting here means there is an error with the parameters to convolve
            code, message = Messages.get_convolve_needs_buffer_parameter()
            Logger.log_message(code=code, message=message, error_position=node.get_source_position(),
                               log_level=LoggingLevel.ERROR)
            node.type = ErrorTypeSymbol()
            return

        # the declared return type is checked here, not the substituted one
        if isinstance(method_symbol.get_return_type(), VoidTypeSymbol):
            # TODO: no message is logged for a void function used as an expression; only the type is marked as
            # erroneous.
            node.type = ErrorTypeSymbol()
            return

        # if nothing special is handled, just get the expression type from the return type of the function
        node.type = return_type