def add_kernel(self, kernel: ASTKernel) -> None:
    """
    Adds the handed over kernel to the equations block and updates the kernel's scope to the scope of that block.
    :param kernel: a single kernel declaration.
    """
    equations_block = self.get_equations_block()
    assert equations_block is not None
    equations_block.get_declarations().append(kernel)
    # Take the scope from the very block the kernel was appended to, so that the block which is checked,
    # the block which is modified and the block which provides the scope are guaranteed to be the same object.
    kernel.update_scope(equations_block.get_scope())
def get_expr_from_kernel_var(cls, kernel: ASTKernel, var_name: str) -> Union[ASTExpression, ASTSimpleExpression]:
    """
    Get the expression that is paired with the given kernel variable.
    :param kernel: the kernel whose variables and expressions are searched (paired up in order).
    :param var_name: the name to look for; it is compared against ``get_complete_name()`` of each variable.
    :return: the expression belonging to the first variable whose complete name equals ``var_name``.
    :raises AssertionError: if no variable of the kernel has the requested name (and, while assertions are
                            enabled, if ``var_name`` is not a string).
    """
    assert isinstance(var_name, str)
    for var, expr in zip(kernel.get_variables(), kernel.get_expressions()):
        if var.get_complete_name() == var_name:
            return expr
    # Raise explicitly instead of using ``assert False``: assert statements are removed when Python runs with
    # -O, which would make this function silently return None for an unknown variable name.
    raise AssertionError("variable name not found in kernel")
def add_kernel_to_variable(cls, kernel: ASTKernel):
    r"""
    Registers the kernel's expressions as the defining equations of the variable symbols it refers to.

    For every variable of the kernel and every differential order below the variable's own order, the symbol
    named ``<name>`` followed by that many ``'`` characters is resolved in the kernel's scope, gets the
    corresponding expression set as its ODE/kernel definition, is marked with the variable type ``KERNEL`` and
    is written back to the scope. E.g. for a definition `g'' = ...` the symbols `g` and `g'` are updated.

    If a symbol cannot be resolved, an error is logged at the kernel's source position and processing stops.

    :param kernel: a single kernel object.
    """
    variables = kernel.get_variables()
    is_single_variable = len(variables) == 1
    if is_single_variable and variables[0].get_differential_order() == 0:
        # "direct function of time" specification: it defines no ODE, hence there is nothing to update
        return

    for variable, expression in zip(variables, kernel.get_expressions()):
        base_name = variable.get_name()
        for order in range(variable.get_differential_order()):
            scope = kernel.get_scope()
            symbol_name = base_name + "'" * order
            symbol = scope.resolve_to_symbol(symbol_name, SymbolKind.VARIABLE)

            if symbol is None:
                code, message = Messages.get_no_variable_found(variable.get_name_of_lhs())
                Logger.log_message(code=code, message=message, error_position=kernel.get_source_position(),
                                   log_level=LoggingLevel.ERROR)
                return

            symbol.set_ode_or_kernel(expression)
            symbol.set_variable_type(VariableType.KERNEL)
            scope.update_variable_symbol(symbol)
def create_ast_kernel(cls, variables=None, expressions=None, source_position=None):
    # type: (ASTVariable,ASTSimpleExpression|ASTExpression,ASTSourceLocation) -> ASTKernel
    """
    Creates a new ASTKernel from the handed over parts.
    :param variables: forwarded as the first positional argument of the ASTKernel constructor.
    :param expressions: forwarded as the second positional argument of the ASTKernel constructor.
    :param source_position: forwarded as the ``source_position`` keyword argument.
    :return: the newly constructed kernel object.
    """
    new_kernel = ASTKernel(variables, expressions, source_position=source_position)
    return new_kernel
def is_delta_kernel(cls, kernel: ASTKernel) -> bool:
    """
    Checks whether a kernel definition, or an expression handed over in its place, refers to the predefined
    ``delta`` function.

    Two shapes are recognised: a simple expression that is a function call resolving to ``delta``, and a
    compound expression whose right-hand side is such a function call.

    :param kernel: either an ASTKernel (in which case its single expression is inspected) or an expression.
    :return: a truthy value if one of the two shapes matches, a falsy value otherwise. An ASTKernel that
             defines more than one variable is never treated as a delta kernel.
    """

    def _calls_delta(simple_expr):
        # only look at the function call if the expression actually is one
        return simple_expr.is_function_call() \
            and _resolves_to_delta(simple_expr.get_function_call())

    def _resolves_to_delta(function_call):
        resolved = function_call.get_scope().resolve_to_symbol(function_call.get_name(), SymbolKind.FUNCTION)
        return resolved == PredefinedFunctions.name2function["delta"]

    if type(kernel) is ASTKernel:
        if len(kernel.get_variables()) != 1:
            # delta kernel not allowed if more than one variable is defined in this kernel
            return False
        candidate = kernel.get_expressions()[0]
    else:
        # not a kernel definition: the argument itself is the expression to inspect
        candidate = kernel

    plain_delta = type(candidate) is ASTSimpleExpression and _calls_delta(candidate)

    multiplied_delta = type(candidate) is ASTExpression \
        and type(candidate.get_rhs()) is ASTSimpleExpression \
        and _calls_delta(candidate.get_rhs())

    return plain_delta or multiplied_delta