def get_delta_factors_(self, neuron, equations_block):
    r"""
    Collect the multiplicative factors that accompany delta-kernel convolutions.

    For every occurrence of a convolution of the form `x^(n) = a * convolve(kernel, inport) + ...` where `kernel` is a delta function, the element `(x^(n), inport) --> a` is added to the result.

    :param neuron: object used to look up a kernel by name through ``get_kernel_by_name()``
    :param equations_block: object whose ``get_ode_equations()`` yields the ODEs to inspect
    :return: dict keyed by ``(lhs variable, input port variable)``; each value is a string made of the terms of the expanded right-hand side that contain the convolve call, with that call replaced by 1, joined by " + "
    """
    factors_by_target = {}

    for equation in equations_block.get_ode_equations():
        lhs_var = equation.get_lhs()
        rhs = equation.get_rhs()

        for call in OdeTransformer.get_convolve_function_calls(rhs):
            assert len(call.args) == 2, "convolve() function call should have precisely two arguments: kernel and spike buffer"

            kernel_arg = call.args[0]
            kernel_name = kernel_arg.get_variable().get_name()
            if not is_delta_kernel(neuron.get_kernel_by_name(kernel_name)):
                # only delta kernels contribute a factor
                continue

            port_var = call.args[1].get_variable()

            # Expand the right-hand side so that it becomes a plain sum of terms;
            # the convolve() call can then be located term by term.
            expanded_rhs = sympy.expand(sympy.parsing.sympy_parser.parse_expr(str(rhs)))
            conv_symbol = sympy.parsing.sympy_parser.parse_expr(str(call))

            matching_terms = [
                str(term.replace(conv_symbol, 1))
                for term in sympy.Add.make_args(expanded_rhs)
                if term.find(conv_symbol)
            ]
            factors_by_target[(lhs_var, port_var)] = " + ".join(matching_terms)

    return factors_by_target
def replace_function_call_through_var(_expr=None):
    """
    Rewrite a ``convolve(...)`` call held by ``_expr`` in place.

    Expressions that are not a function call named "convolve" are left untouched. For a convolve call, the function call is removed from ``_expr`` and replaced either by the numeric literal 0 (delta kernel) or by a variable named after the (kernel, spike input port) pair.

    Note: ``neuron`` is not a parameter; it is read from a scope outside this function.

    :param _expr: expression node to inspect and, if applicable, modify in place
    """
    if not _expr.is_function_call():
        return
    if _expr.get_function_call().get_name() != "convolve":
        return

    call = _expr.get_function_call()
    call_args = call.get_args()
    kernel_arg, port_arg = call_args[0], call_args[1]

    # The two arguments may be given in either order: when the first one
    # resolves to a spike input buffer, swap the roles.
    first_symbol = call_args[0].get_scope().resolve_to_symbol(
        call_args[0].get_variable().name, SymbolKind.VARIABLE)
    if first_symbol.block_type == BlockType.INPUT_BUFFER_SPIKE:
        kernel_arg, port_arg = port_arg, kernel_arg

    kernel_var = kernel_arg.get_variable()
    spike_input_port = port_arg.get_variable()
    kernel = neuron.get_kernel_by_name(kernel_var.get_name())

    _expr.set_function_call(None)
    buffer_name = construct_kernel_X_spike_buf_name(
        kernel_var.get_name(), spike_input_port, kernel_var.get_differential_order() - 1)

    if not is_delta_kernel(kernel):
        replacement = ASTVariable(buffer_name)
        replacement.set_source_position(_expr.get_source_position())
        _expr.set_variable(replacement)
        return

    # delta kernels are treated separately, and should be kept out of the dynamics (computing derivatives etc.) --> set to zero
    _expr.set_variable(None)
    _expr.set_numeric_literal(0)
def get_spike_update_expressions(self, neuron: ASTNeuron, kernel_buffers, solver_dicts, delta_factors) -> List[ASTAssignment]:
    """
    Generate the equations that update the dynamical variables when incoming spikes arrive. To be invoked after ode-toolbox.

    For example, a resulting `assignment_str` could be "I_kernel_in += (in_spikes/nS) * 1". The values are taken from the initial values for each corresponding dynamical variable, either from ode-toolbox or directly from user specification in the model.

    Note that for kernels, the "initial value" obtained from the ode-toolbox result actually contains the increment upon spike arrival, rather than the initial value of the corresponding ODE dimension.

    :param neuron: the neuron whose scope is used to resolve input ports and to scope the generated assignments
    :param kernel_buffers: iterable of ``(kernel, spike_input_port)`` pairs
    :param solver_dicts: ode-toolbox results, passed through to the lookup helpers
    :param delta_factors: dict mapping ``(variable, inport)`` to a factor string, for delta kernels
    :return: list of parsed assignments, kernel updates first, then delta-kernel updates
    """
    unit_strings = ("1.", "1.0", "1")
    zero_strings = ("0", "0.", "0.0")

    def _to_scoped_assignment(assignment_str):
        # Parse the statement, give it the neuron's scope and run the symbol table visitor on it.
        ast_assignment = ModelParser.parse_assignment(assignment_str)
        ast_assignment.update_scope(neuron.get_scope())
        ast_assignment.accept(ASTSymbolTableVisitor())
        return ast_assignment

    spike_updates = []

    for kernel, spike_input_port in kernel_buffers:
        port_symbol = neuron.get_scope().resolve_to_symbol(str(spike_input_port), SymbolKind.VARIABLE)
        if port_symbol is None:
            # the port cannot be resolved in this neuron: nothing to update
            continue

        buffer_type = port_symbol.get_type_symbol()

        if is_delta_kernel(kernel):
            # delta kernels are handled through `delta_factors` below
            continue

        for kernel_var in kernel.get_variables():
            for var_order in range(get_kernel_var_order_from_ode_toolbox_result(kernel_var.get_name(), solver_dicts)):
                kernel_spike_buf_name = construct_kernel_X_spike_buf_name(
                    kernel_var.get_name(), spike_input_port, var_order)
                expr = get_initial_value_from_ode_toolbox_result(kernel_spike_buf_name, solver_dicts)
                # str() is required here: kernel_var is a variable object, and
                # concatenating it directly would raise TypeError instead of
                # reporting the missing initial value.
                assert expr is not None, "Initial value not found for kernel " + str(kernel_var)
                expr = str(expr)
                if expr in zero_strings:
                    continue    # skip adding the statement if we're only adding zero

                assignment_str = kernel_spike_buf_name + " += "
                assignment_str += "(" + str(spike_input_port) + ")"
                if expr not in unit_strings:
                    assignment_str += " * (" + \
                        self._printer.print_expression(ModelParser.parse_expression(expr)) + ")"

                # divide by the port's type unless it prints as a plain 1
                type_str = buffer_type.print_nestml_type()
                if type_str not in unit_strings:
                    assignment_str += " / (" + type_str + ")"

                spike_updates.append(_to_scoped_assignment(assignment_str))

    for key, factor in delta_factors.items():
        var = key[0]
        inport = key[1]
        # one tick per differential order below that of `var`
        assignment_str = var.get_name() + "'" * (var.get_differential_order() - 1) + " += "
        if factor not in unit_strings:
            assignment_str += "(" + self._printer.print_expression(ModelParser.parse_expression(factor)) + ") * "
        assignment_str += str(inport)

        spike_updates.append(_to_scoped_assignment(assignment_str))

    return spike_updates
def transform_ode_and_kernels_to_json(self, neuron: ASTNeuron, parameters_block, kernel_buffers):
    """
    Converts AST node to a JSON representation suitable for passing to ode-toolbox.

    Each kernel has to be generated for each spike buffer convolve in which it occurs, e.g. if the NESTML model code contains the statements

        convolve(G, ex_spikes)
        convolve(G, in_spikes)

    then `kernel_buffers` will contain the pairs `(G, ex_spikes)` and `(G, in_spikes)`, from which two ODEs will be generated, with dynamical state (variable) names `G__X__ex_spikes` and `G__X__in_spikes`.

    :param neuron: the neuron whose equations block and initial values are converted
    :param parameters_block: block providing the parameter declarations, or None if there are none
    :param kernel_buffers: iterable of ``(kernel, spike_input_port)`` pairs
    :return: Dict with the keys "dynamics" (list of entries, each with "expression" and "initial_values") and "parameters"
    """
    gsl_converter = ODEToolboxReferenceConverter()
    gsl_printer = UnitlessExpressionPrinter(gsl_converter)

    dynamics = []
    equations_block = neuron.get_equations_block()

    for equation in equations_block.get_ode_equations():
        lhs_var = equation.get_lhs()
        # n.b. includes single quotation marks to indicate differential order
        lhs = to_ode_toolbox_name(lhs_var.get_complete_name())
        rhs = gsl_printer.print_expression(equation.get_rhs())
        entry = {"expression": lhs + " = " + rhs, "initial_values": {}}

        symbol_name = lhs_var.get_name()
        for order in range(lhs_var.get_differential_order()):
            iv_symbol_name = symbol_name + "'" * order
            initial_value_expr = neuron.get_initial_value(iv_symbol_name)
            if initial_value_expr:
                entry["initial_values"][to_ode_toolbox_name(iv_symbol_name)] = \
                    gsl_printer.print_expression(initial_value_expr)

        dynamics.append(entry)

    # write a copy for each (kernel, spike buffer) combination
    for kernel, spike_input_port in kernel_buffers:
        if is_delta_kernel(kernel):
            # delta function -- skip passing this to ode-toolbox
            continue

        for kernel_var in kernel.get_variables():
            expr = get_expr_from_kernel_var(kernel, kernel_var.get_complete_name())
            kernel_order = kernel_var.get_differential_order()
            kernel_X_spike_buf_name_ticks = construct_kernel_X_spike_buf_name(
                kernel_var.get_name(), spike_input_port, kernel_order, diff_order_symbol="'")

            # must run before str(expr) below; the return value is not used, so
            # this presumably rewrites `expr` in place -- TODO confirm
            replace_rhs_variables(expr, kernel_buffers)

            entry = {
                "expression": kernel_X_spike_buf_name_ticks + " = " + str(expr),
                "initial_values": {},
            }

            # initial values need to be declared for order 1 up to kernel order (e.g. none for kernel function f(t) = ...; 1 for kernel ODE f'(t) = ...; 2 for f''(t) = ... and so on)
            for order in range(kernel_order):
                iv_sym_name_ode_toolbox = construct_kernel_X_spike_buf_name(
                    kernel_var.get_name(), spike_input_port, order, diff_order_symbol="'")
                symbol_name_ = kernel_var.get_name() + "'" * order
                symbol = equations_block.get_scope().resolve_to_symbol(
                    symbol_name_, SymbolKind.VARIABLE)
                assert symbol is not None, "Could not find initial value for variable " + symbol_name_
                initial_value_expr = symbol.get_declaring_expression()
                assert initial_value_expr is not None, "No initial value found for variable name " + symbol_name_
                entry["initial_values"][iv_sym_name_ode_toolbox] = \
                    gsl_printer.print_expression(initial_value_expr)

            dynamics.append(entry)

    parameters = {}
    if parameters_block is not None:
        for decl in parameters_block.get_declarations():
            for var in decl.variables:
                parameters[var.get_complete_name()] = \
                    gsl_printer.print_expression(decl.get_expression())

    # key order matches the original construction: "dynamics" first, then "parameters"
    return {"dynamics": dynamics, "parameters": parameters}