def visit_ode_function(self, node):
    """
    Private method: Handles a single ode-function by registering a variable symbol for it in the scope of
    the node and handing that scope on to the data type and the expression of the node.
    :param node: a single ode-function.
    :type node: ast_ode_function
    """
    scope = node.get_scope()
    # resolve the declared data type to the matching predefined type symbol
    type_visitor = ASTDataTypeVisitor()
    node.get_data_type().accept(type_visitor)
    resolved_type = PredefinedTypes.get_type(type_visitor.result)
    # the ode-function is stored as a function-flagged variable of the equations block
    ode_symbol = VariableSymbol(
        element_reference=node,
        scope=scope,
        name=node.get_variable_name(),
        block_type=BlockType.EQUATION,
        declaring_expression=node.get_expression(),
        is_predefined=False,
        is_function=True,
        is_recordable=node.is_recordable,
        type_symbol=resolved_type,
        variable_type=VariableType.VARIABLE,
    )
    ode_symbol.set_comment(node.get_comment())
    scope.add_symbol(ode_symbol)
    # finally pass the scope down to the children
    node.get_data_type().update_scope(scope)
    node.get_expression().update_scope(scope)
def visit_declaration(self, node: ASTDeclaration) -> None:
    """
    Private method: Used to visit a single declaration, create one variable symbol per declared variable
    and update the scopes of the children of the declaration.

    The enclosing block type is looked up once, and only if ``self.block_type_stack`` is not empty. Without
    an enclosing block the symbols get ``block_type=None``, are recordable only if the node itself is marked
    recordable, and carry no initial value.
    :param node: a declaration AST node
    :return: the scope is updated without a return value.
    """
    expression = node.get_expression() if node.has_expression() else None
    visitor = ASTDataTypeVisitor()
    node.get_data_type().accept(visitor)
    type_name = visitor.result

    # Guard the stack access before any use of the block type: previously top() was queried for the
    # recordable flag and for the initial value ahead of the is_empty() check.
    block_type = None
    if not self.block_type_stack.is_empty():
        block_type = self.block_type_stack.top()
    in_state_block = block_type == BlockType.STATE

    # all declarations in the state block are recordable
    is_recordable = node.is_recordable or in_state_block
    # only declarations in the state block keep their expression as initial value
    init_value = node.get_expression() if in_state_block else None

    # split the decorators in the AST up into namespace decorators and other decorators
    decorators = []
    namespace_decorators = {}
    for d in node.get_decorators():
        if isinstance(d, ASTNamespaceDecorator):
            namespace_decorators[str(d.get_namespace())] = str(d.get_name())
        else:
            decorators.append(d)

    # now for each declared variable create a symbol and update the scope
    for var in node.get_variables():
        var.update_scope(node.get_scope())
        # the type is requested anew for every variable, exactly as before
        type_symbol = PredefinedTypes.get_type(type_name)
        vector_parameter = var.get_vector_parameter()
        symbol = VariableSymbol(element_reference=node,
                                scope=node.get_scope(),
                                name=var.get_complete_name(),
                                block_type=block_type,
                                declaring_expression=expression,
                                is_predefined=False,
                                is_inline_expression=False,
                                is_recordable=is_recordable,
                                type_symbol=type_symbol,
                                initial_value=init_value,
                                vector_parameter=vector_parameter,
                                variable_type=VariableType.VARIABLE,
                                decorators=decorators,
                                namespace_decorators=namespace_decorators)
        symbol.set_comment(node.get_comment())
        node.get_scope().add_symbol(symbol)
        var.set_type_symbol(Either.value(type_symbol))

    # the data type
    node.get_data_type().update_scope(node.get_scope())
    # the rhs update
    if node.has_expression():
        node.get_expression().update_scope(node.get_scope())
    # the invariant update
    if node.has_invariant():
        node.get_invariant().update_scope(node.get_scope())
def endvisit_input_port(self, node):
    """
    Registers a buffer symbol for an input port after the port has been visited. Ports without a data
    type are left untouched.
    :param node: the input port that has just been visited.
    """
    if node.has_datatype():
        port_type = node.get_datatype().get_type_symbol()
        port_type.is_buffer = True  # flag the type as describing a buffer
        port_symbol = VariableSymbol(
            element_reference=node,
            scope=node.get_scope(),
            name=node.get_name(),
            block_type=BlockType.INPUT,
            vector_parameter=node.get_index_parameter(),
            is_predefined=False,
            is_inline_expression=False,
            is_recordable=False,
            type_symbol=port_type,
            variable_type=VariableType.BUFFER,
        )
        port_symbol.set_comment(node.get_comment())
        node.get_scope().add_symbol(port_symbol)
def contains_convolve_call(cls, variable: VariableSymbol) -> bool:
    """
    Checks whether the expression declaring the given variable symbol contains a call to convolve().
    :param variable: the variable symbol whose declaring expression is inspected.
    :return: True if one of the function calls of the declaring expression is named
             PredefinedFunctions.CONVOLVE; False otherwise, also if there is no declaring expression.
    """
    declaring_expression = variable.get_declaring_expression()
    if not declaring_expression:
        return False
    return any(call.get_name() == PredefinedFunctions.CONVOLVE
               for call in declaring_expression.get_function_calls())
def get_numeric_vector_size(cls, variable: VariableSymbol) -> int:
    """
    Determines the numeric size of a vector variable. If the size parameter resolves to a variable symbol,
    the numeric literal of that symbol's declaring expression is returned; otherwise the size parameter
    itself is converted with int().
    :param variable: vector variable
    :return: the size of the vector as a numerical value
    """
    size_parameter = variable.get_vector_parameter()
    size_variable = ASTVariable(size_parameter, scope=variable.get_corresponding_scope())
    size_symbol = size_variable.get_scope().resolve_to_symbol(size_variable.get_complete_name(),
                                                              SymbolKind.VARIABLE)
    if size_symbol is None:
        # nothing resolved under that name: the parameter is taken as the size itself
        return int(size_parameter)
    # the size is given by a variable: read the value from its right-hand side
    return size_symbol.get_declaring_expression().get_numeric_literal()
def visit_declaration(self, node):
    """
    Private method: Processes a single declaration. For every declared variable a variable symbol is
    added to the scope of the declaration; afterwards that scope is handed on to the data type and, if
    present, to the expression and the invariant.
    :param node: a declaration object.
    :type node: ast_declaration
    :return: nothing, only the scope is updated.
    :rtype: void
    """
    scope = node.get_scope()
    declaring_expression = node.get_expression() if node.has_expression() else None
    type_visitor = ASTDataTypeVisitor()
    node.get_data_type().accept(type_visitor)
    declared_type_name = type_visitor.result
    current_block = self.block_type_stack.top()
    # declarations in the state and initial values blocks are always recordable
    recordable = node.is_recordable or current_block in (BlockType.STATE, BlockType.INITIAL_VALUES)
    # only the initial values block stores the expression as initial value
    if current_block == BlockType.INITIAL_VALUES:
        initial_value = node.get_expression()
    else:
        initial_value = None
    size_parameter = node.get_size_parameter()
    # one symbol per declared variable
    for declared_var in node.get_variables():
        declared_var.update_scope(scope)
        var_type = PredefinedTypes.get_type(declared_type_name)
        var_symbol = VariableSymbol(
            element_reference=node,
            scope=scope,
            name=declared_var.get_complete_name(),
            block_type=current_block,
            declaring_expression=declaring_expression,
            is_predefined=False,
            is_function=node.is_function,
            is_recordable=recordable,
            type_symbol=var_type,
            initial_value=initial_value,
            vector_parameter=size_parameter,
            variable_type=VariableType.VARIABLE,
        )
        var_symbol.set_comment(node.get_comment())
        scope.add_symbol(var_symbol)
        declared_var.set_type_symbol(Either.value(var_type))
    # children: data type, right-hand side and invariant
    node.get_data_type().update_scope(scope)
    if node.has_expression():
        node.get_expression().update_scope(scope)
    if node.has_invariant():
        node.get_invariant().update_scope(scope)
def print_vector_declaration(self, variable: VariableSymbol) -> str:
    """
    Builds the resize statement that declares a vector variable: origin and symbol name, followed by a
    ``.resize(...)`` call taking the printed size parameter and the printed declaring expression.
    :param variable: Vector variable
    :return: the corresponding vector declaration statement
    """
    assert isinstance(variable, VariableSymbol), \
        '(PyNestML.CodeGeneration.Printer) No or wrong type of variable symbol provided (%s)!' % type(variable)
    # the pieces are evaluated left to right, in the order they appear in the statement
    pieces = [
        self.print_origin(variable),
        variable.get_symbol_name(),
        ".resize(",
        self.print_vector_size_parameter(variable),
        ", ",
        self.print_expression(variable.get_declaring_expression()),
        ");",
    ]
    return "".join(pieces)
def endvisit_input_line(self, node):
    """
    Registers a buffer symbol for an input line after the line has been visited. Spike lines use their
    own data type if they have one and 'nS' otherwise; all other lines use 'pA'.
    :param node: the input line that has just been visited.
    """
    if node.is_spike():
        line_block = BlockType.INPUT_BUFFER_SPIKE
        if node.has_datatype():
            line_type = node.get_datatype().get_type_symbol()
        else:
            line_type = PredefinedTypes.get_type('nS')
    else:
        line_block = BlockType.INPUT_BUFFER_CURRENT
        line_type = PredefinedTypes.get_type('pA')
    line_type.is_buffer = True  # flag the type as describing a buffer
    line_symbol = VariableSymbol(
        element_reference=node,
        scope=node.get_scope(),
        name=node.get_name(),
        block_type=line_block,
        vector_parameter=node.get_index_parameter(),
        is_predefined=False,
        is_function=False,
        is_recordable=False,
        type_symbol=line_type,
        variable_type=VariableType.BUFFER,
    )
    line_symbol.set_comment(node.get_comment())
    node.get_scope().add_symbol(line_symbol)
def endvisit_function(self, node):
    """
    Completes a function after its body has been visited: the function symbol and the scope are taken
    from their stacks, every parameter is registered as a local variable symbol in that scope and its
    type is added to the function symbol, the return type is set, and the block type stack is popped.
    :param node: the function that has just been visited.
    """
    function_symbol = self.symbol_stack.pop()
    function_scope = self.scope_stack.pop()
    assert isinstance(function_symbol, FunctionSymbol), 'Not a function symbol'
    for parameter in node.get_parameters():
        # the type name is not stored as written in the model, so a sub-visitor computes it
        parameter_visitor = ASTDataTypeVisitor()
        parameter.get_data_type().accept(parameter_visitor)
        parameter_type_name = parameter_visitor.result
        # record the parameter type in the function symbol
        function_symbol.add_parameter_type(PredefinedTypes.get_type(parameter_type_name))
        parameter.update_scope(function_scope)
        # the parameter itself becomes a local variable of the function scope
        parameter_symbol = VariableSymbol(
            element_reference=parameter,
            scope=function_scope,
            name=parameter.get_name(),
            block_type=BlockType.LOCAL,
            is_predefined=False,
            is_function=False,
            is_recordable=False,
            type_symbol=PredefinedTypes.get_type(parameter_type_name),
            variable_type=VariableType.VARIABLE,
        )
        assert isinstance(function_scope, Scope)
        function_scope.add_symbol(parameter_symbol)
    if not node.has_return_type():
        return_type = PredefinedTypes.get_void_type()
    else:
        return_visitor = ASTDataTypeVisitor()
        node.get_return_type().accept(return_visitor)
        return_type = PredefinedTypes.get_type(return_visitor.result)
    function_symbol.set_return_type(return_type)
    # leave the block of the function
    self.block_type_stack.pop()
def visit_declaration(self, node):
    """
    Private method: Visits a single declaration, adds a variable symbol for each declared variable to the
    scope of the declaration and passes that scope on to the data type, the expression and the invariant
    where they exist.
    :param node: a declaration object.
    :type node: ast_declaration
    :return: nothing, only the scope is updated.
    :rtype: void
    """
    rhs = None
    if node.has_expression():
        rhs = node.get_expression()
    data_type_visitor = ASTDataTypeVisitor()
    node.get_data_type().accept(data_type_visitor)
    name_of_type = data_type_visitor.result
    enclosing_block = self.block_type_stack.top()
    in_initial_values = enclosing_block == BlockType.INITIAL_VALUES
    # everything declared in the state or initial values block is recordable
    recordable = node.is_recordable or enclosing_block == BlockType.STATE or in_initial_values
    start_value = node.get_expression() if in_initial_values else None
    size_parameter = node.get_size_parameter()
    target_scope = node.get_scope()
    # register a symbol for each declared variable
    for variable in node.get_variables():
        variable.update_scope(target_scope)
        variable_type = PredefinedTypes.get_type(name_of_type)
        new_symbol = VariableSymbol(element_reference=node, scope=target_scope,
                                    name=variable.get_complete_name(),
                                    block_type=enclosing_block,
                                    declaring_expression=rhs, is_predefined=False,
                                    is_function=node.is_function,
                                    is_recordable=recordable,
                                    type_symbol=variable_type,
                                    initial_value=start_value,
                                    vector_parameter=size_parameter,
                                    variable_type=VariableType.VARIABLE)
        new_symbol.set_comment(node.get_comment())
        target_scope.add_symbol(new_symbol)
        variable.set_type_symbol(Either.value(variable_type))
    # hand the scope to the data type
    node.get_data_type().update_scope(target_scope)
    # ... to the right-hand side
    if node.has_expression():
        node.get_expression().update_scope(target_scope)
    # ... and to the invariant
    if node.has_invariant():
        node.get_invariant().update_scope(target_scope)
def __register_time_constant(cls):
    """
    Registers the predefined variable t, typed as 'ms', under the key cls.TIME_CONSTANT.
    """
    cls.name2variable[cls.TIME_CONSTANT] = VariableSymbol(
        name='t',
        block_type=BlockType.STATE,
        is_predefined=True,
        type_symbol=PredefinedTypes.get_type('ms'),
        variable_type=VariableType.VARIABLE,
    )
def __register_euler_constant(cls):
    """
    Registers the predefined variable e, typed as real, under the key cls.E_CONSTANT.
    """
    cls.name2variable[cls.E_CONSTANT] = VariableSymbol(
        name='e',
        block_type=BlockType.STATE,
        is_predefined=True,
        type_symbol=PredefinedTypes.get_real_type(),
        variable_type=VariableType.VARIABLE,
    )
def visit_ode_shape(self, node):
    """
    Private method: Handles a single ode-shape. A shape symbol is added to the scope only if the variable
    has differential order 0 and its complete name does not resolve to a variable symbol yet; the scope is
    passed on to the variable and the expression in every case.
    :param node: a single ode-shape.
    :type node: ast_ode_shape
    """
    shape_variable = node.get_variable()
    if shape_variable.get_differential_order() == 0:
        # the lookup only happens for variables of order 0
        known_symbol = node.get_scope().resolve_to_symbol(shape_variable.get_complete_name(),
                                                          SymbolKind.VARIABLE)
        if known_symbol is None:
            shape_symbol = VariableSymbol(
                element_reference=node,
                scope=node.get_scope(),
                name=node.get_variable().get_name(),
                block_type=BlockType.EQUATION,
                declaring_expression=node.get_expression(),
                is_predefined=False,
                is_function=False,
                is_recordable=True,
                type_symbol=PredefinedTypes.get_real_type(),
                variable_type=VariableType.SHAPE,
            )
            shape_symbol.set_comment(node.get_comment())
            node.get_scope().add_symbol(shape_symbol)
    node.get_variable().update_scope(node.get_scope())
    node.get_expression().update_scope(node.get_scope())
def __register_predefined_type_variables(cls):
    """
    Adds one type variable to cls.name2variable for every name known to PredefinedTypes, e.g., mV and
    integer.
    """
    type_names = PredefinedTypes.get_types().keys()
    for type_name in type_names:
        cls.name2variable[type_name] = VariableSymbol(
            name=type_name,
            block_type=BlockType.PREDEFINED,
            is_predefined=True,
            type_symbol=PredefinedTypes.get_type(type_name),
            variable_type=VariableType.TYPE,
        )
def print_vector_size_parameter(self, variable: VariableSymbol) -> str:
    """
    Prints the size parameter of a vector variable. If the parameter resolves to a variable symbol, the
    printed origin of that symbol is put in front of it; otherwise the parameter is printed as it is.
    :param variable: Vector variable
    :return: vector size parameter
    """
    size_parameter = variable.get_vector_parameter()
    size_variable = ASTVariable(size_parameter, scope=variable.get_corresponding_scope())
    size_symbol = size_variable.get_scope().resolve_to_symbol(size_variable.get_complete_name(),
                                                              SymbolKind.VARIABLE)
    # a resolved symbol means the size is a variable and needs its origin; otherwise no prefix
    origin_prefix = "" if size_symbol is None else self.print_origin(size_symbol)
    return origin_prefix + size_parameter
def visit_kernel(self, node):
    """
    Private method: Handles a single kernel. Variables and expressions are processed pairwise: a kernel
    symbol is added for a variable of differential order 0 whose complete name does not resolve to a
    variable symbol yet, and the scope is passed on to every variable and expression.
    :param node: a kernel.
    :type node: ASTKernel
    """
    pairs = zip(node.get_variables(), node.get_expressions())
    for kernel_variable, kernel_expression in pairs:
        order_is_zero = kernel_variable.get_differential_order() == 0
        # the lookup is skipped for higher-order variables
        if order_is_zero and node.get_scope().resolve_to_symbol(kernel_variable.get_complete_name(),
                                                                SymbolKind.VARIABLE) is None:
            kernel_symbol = VariableSymbol(
                element_reference=node,
                scope=node.get_scope(),
                name=kernel_variable.get_name(),
                block_type=BlockType.EQUATION,
                declaring_expression=kernel_expression,
                is_predefined=False,
                is_inline_expression=False,
                is_recordable=True,
                type_symbol=PredefinedTypes.get_real_type(),
                variable_type=VariableType.KERNEL,
            )
            kernel_symbol.set_comment(node.get_comment())
            node.get_scope().add_symbol(kernel_symbol)
        kernel_variable.update_scope(node.get_scope())
        kernel_expression.update_scope(node.get_scope())