def get_linear_system(eqs, variables):
    """
    Express the differential equations as a linear system using sympy.

    Every right-hand side is expanded and then matched, one state variable
    at a time, against the pattern ``factor * variable + rest``. The factors
    fill one row of the coefficient matrix; whatever is left after all
    state variables have been split off becomes the constant term.

    Parameters
    ----------
    eqs : `Equations`
        The model equations.
    variables : dict
        Handed on unchanged to ``eqs.get_substituted_expressions`` and to
        `str_to_sympy` when the individual expressions are parsed.

    Returns
    -------
    (diff_eq_names, coefficients, constants) : (list of str, `sympy.Matrix`, `sympy.Matrix`)
        A tuple containing the variable names (`diff_eq_names`)
        corresponding to the rows of the matrix `coefficients` and the
        vector `constants`, representing the system of equations in the
        form M * X + B

    Raises
    ------
    UnsupportedEquationsException
        If the equations cannot be converted into an M * X + B form.
    """
    substituted = eqs.get_substituted_expressions(variables)
    diff_eq_names = [eq_name for eq_name, _ in substituted]
    n_eqs = len(diff_eq_names)
    symbols = [Symbol(eq_name, real=True) for eq_name in diff_eq_names]

    coefficients = sp.zeros(n_eqs)
    constants = sp.zeros(n_eqs, 1)

    for row, (name, expr) in enumerate(substituted):
        # The factor in front of a state variable must not contain any state
        # variable itself, otherwise the equation would not be linear
        factor = Wild(f"c_{name}", exclude=symbols)
        remainder = str_to_sympy(expr.code, variables).expand()

        for col, symbol in enumerate(symbols):
            rest = Wild('c', exclude=[symbol])
            found = remainder.collect(symbol).match(factor * symbol + rest)
            if found is None:
                raise UnsupportedEquationsException(
                    f"The expression '{expr}', "
                    f"defining the variable "
                    f"'{name}', could not be "
                    f"separated into linear "
                    f"components.")

            coefficients[row, col] = found[factor]
            # Continue with the part that did not involve this variable
            remainder = found[rest]

        # The remaining constant should be a true constant
        constants[row] = remainder

    return diff_eq_names, coefficients, constants
def _check_for_locally_constant(expression, variables, dt_value, t_symbol):
    """
    Raise an error if ``t`` is used where it cannot be considered constant.

    The expression tree is walked recursively via ``expression.args``. The
    time symbol is only accepted as a direct argument of a function whose
    name is a key of `variables` and whose entry reports
    ``is_locally_constant(dt_value)`` as true.

    Parameters
    ----------
    expression : sympy expression
        The expression to check (anything providing ``args`` and ``func``).
    variables : dict
        Mapping from names to objects offering ``is_locally_constant``.
    dt_value : object
        Passed on unchanged to ``is_locally_constant``.
    t_symbol : sympy symbol
        The symbol representing time.

    Raises
    ------
    UnsupportedEquationsException
        If ``t`` occurs on its own or as an argument of anything that is not
        a locally constant function.
    """
    message = ('t is used in a context '
               'where we cannot '
               'guarantee that it can be '
               'considered locally '
               'constant.')

    # An expression that is nothing but "t" has no arguments to iterate over,
    # so it has to be rejected before descending into the tree
    if expression == t_symbol:
        raise UnsupportedEquationsException(message)

    for arg in expression.args:
        # Compare by equality, not identity: two structurally identical
        # symbols are not guaranteed to be the very same object
        if arg == t_symbol:
            # We found "t" -- if it is not an argument of a locally
            # constant function we bail out
            func_name = str(expression.func)
            if not (func_name in variables and
                    variables[func_name].is_locally_constant(dt_value)):
                raise UnsupportedEquationsException(message)
        else:
            _check_for_locally_constant(arg, variables, dt_value, t_symbol)
def translate(self, code, dtype):  # TODO: it's not so nice we have to copy the contents of this function..
    '''
    Translates an abstract code block into the target language.

    Parameters
    ----------
    code : dict
        Mapping from block names to abstract code. The statements stored
        under the key ``None`` are the ones rewritten for GSL.
    dtype : object
        Passed on unchanged to ``make_statements``.

    Returns
    -------
    (scalar_code, vector_code, kwds) : tuple
        The result of ``self.generator.translate_statement_sequence``, with
        the GSL main code stored in ``scalar_code['GSL']`` and the GSL
        specific entries added to ``kwds``.

    Raises
    ------
    ValueError
        If ``self.variables`` contains a name reserved for the GSL code.
    UnsupportedEquationsException
        If the generated code uses an index other than ``0`` or ``_idx``.
    AssertionError
        If ``self.variables_to_be_processed`` is not empty after the scalar
        code has been translated.
    '''
    # first check if user code is not using variables that are also used by GSL
    reserved_variables = ['_dataholder', '_fill_y_vector', '_empty_y_vector',
                          '_GSL_dataholder', '_GSL_y', '_GSL_func']
    if any(var in self.variables for var in reserved_variables):
        raise ValueError(("The variables %s are reserved for the GSL "
                          "internal code." % (str(reserved_variables))))

    # if the following statements are not added, Brian translates the
    # differential expressions in the abstract code for GSL to scalar statements
    # in the case no non-scalar variables are used in the expression
    diff_vars = self.find_differential_variables(code.values())
    self.add_gsl_variables_as_non_scalar(diff_vars)

    # add arrays we want to use in generated code before self.generator.translate() so
    # brian does namespace unpacking for us
    pointer_names = self.add_meta_variables(self.method_options)

    scalar_statements = {}
    vector_statements = {}
    # items()/values() (instead of iteritems()/itervalues()) work on both
    # Python 2 and Python 3
    for ac_name, ac_code in code.items():
        statements = make_statements(ac_code, self.variables, dtype,
                                     optimise=True, blockname=ac_name)
        scalar_statements[ac_name], vector_statements[ac_name] = statements

    for vs in vector_statements.values():
        # Check that the statements are meaningful independent on the order of
        # execution (e.g. for synapses)
        try:
            # only do order dependence if there are repeated indices
            if self.has_repeated_indices(vs):
                check_for_order_independence(vs,
                                             self.generator.variables,
                                             self.generator.variable_indices)
        except OrderDependenceError:
            # If the abstract code is only one line, display it in full
            if len(vs) <= 1:
                error_msg = 'Abstract code: "%s"\n' % vs[0]
            else:
                error_msg = ('%d lines of abstract code, first line is: '
                             '"%s"\n') % (len(vs), vs[0])
            # NOTE(review): the message is built but never reported, i.e. an
            # order dependence is silently ignored here -- confirm whether
            # it should be logged as a warning.

    # save function names because self.generator.translate_statement_sequence
    # deletes these from self.variables but we need to know which identifiers
    # we can safely ignore (i.e. we can ignore the functions because they are
    # handled by the original generator)
    self.function_names = self.find_function_names()

    scalar_code, vector_code, kwds = self.generator.translate_statement_sequence(
        scalar_statements, vector_statements)

    ############ translate code for GSL

    # first check if any indexing other than '_idx' is used (currently not supported)
    all_code_lists = list(scalar_code.values()) + list(vector_code.values())
    for code_list in all_code_lists:
        for code_line in code_list:
            m = re.search(r'\[(\w+)\]', code_line)
            if m is not None and m.group(1) not in ('0', '_idx'):
                # import here to avoid circular import
                from brian2.stateupdaters.base import UnsupportedEquationsException
                raise UnsupportedEquationsException(("Equations result in state "
                                                     "updater code with indexing "
                                                     "other than '_idx', which "
                                                     "is currently not supported "
                                                     "in combination with the "
                                                     "GSL stateupdater."))

    # differential variable specific operations
    to_replace = self.diff_var_to_replace(diff_vars)
    GSL_support_code = self.get_dimension_code(len(diff_vars))
    GSL_support_code += self.yvector_code(diff_vars)

    # analyze all needed variables; if not in self.variables: put in separate dic.
    # also keep track of variables needed for scalar statements and vector statements
    other_variables = self.find_undefined_variables(
        scalar_statements[None] + vector_statements[None])
    variables_in_scalar = self.find_used_variables(scalar_statements[None],
                                                   other_variables)
    variables_in_vector = self.find_used_variables(vector_statements[None],
                                                   other_variables)

    # so that _dataholder holds diff_vars as well, even if they don't occur
    # in the actual statements
    for var in diff_vars:
        if var not in variables_in_vector:
            variables_in_vector[var] = self.variables[var]

    # lets keep track of the variables that eventually need to be added to
    # the _GSL_dataholder somehow (stored as a list, not as a view tied to
    # the dictionary)
    self.variables_to_be_processed = list(variables_in_vector.keys())

    # add code for _dataholder struct
    GSL_support_code = self.write_dataholder(variables_in_vector) + GSL_support_code
    # add e.g. _lio_1 --> _GSL_dataholder._lio_1 to replacer
    to_replace.update(self.to_replace_vector_vars(variables_in_vector,
                                                  ignore=list(diff_vars.keys())))
    # write statements that unpack (python) namespace to _dataholder struct
    # or local namespace
    GSL_main_code = self.unpack_namespace(variables_in_vector,
                                          variables_in_scalar, ['t'])

    # rewrite actual calculations described by vector_code and put them in _GSL_func
    func_code = self.translate_one_statement_sequence(vector_statements[None],
                                                      scalar=False)
    GSL_support_code += self.make_function_code(
        self.translate_vector_code(func_code, to_replace))
    scalar_func_code = self.translate_one_statement_sequence(
        scalar_statements[None], scalar=True)
    # rewrite scalar code, keep variables that are needed in scalar code normal
    # and add variables to _dataholder for vector_code
    GSL_main_code += '\n' + self.translate_scalar_code(scalar_func_code,
                                                       variables_in_scalar,
                                                       variables_in_vector)
    if len(self.variables_to_be_processed) > 0:
        raise AssertionError(("Not all variables that will be used in the vector "
                              "code have been added to the _GSL_dataholder. This "
                              "might mean that the _GSL_func is using unitialized "
                              "variables."
                              "\nThe unprocessed variables "
                              "are: %s" % (str(self.variables_to_be_processed))))

    scalar_code['GSL'] = GSL_main_code
    kwds['define_GSL_scale_array'] = self.scale_array_code(diff_vars,
                                                           self.method_options)
    kwds['n_diff_vars'] = len(diff_vars)
    kwds['GSL_settings'] = dict(self.method_options)
    kwds['GSL_settings']['integrator'] = self.integrator
    kwds['support_code_lines'] += GSL_support_code.split('\n')
    kwds['t_array'] = self.get_array_name(self.variables['t']) + '[0]'
    kwds['dt_array'] = self.get_array_name(self.variables['dt']) + '[0]'
    kwds['define_dt'] = 'dt' not in variables_in_scalar
    kwds['cpp_standalone'] = self.is_cpp_standalone()
    for key, value in pointer_names.items():
        kwds[key] = value
    return scalar_code, vector_code, kwds
def __call__(self, equations, variables=None, method_options=None):
    """
    Generate abstract code that solves each equation on its own with sympy.

    Every differential equation is handed to ``sympy.dsolve`` separately;
    the integration constant is fixed by using the current value of the
    variable as the initial condition, and the solution is then evaluated
    one time step ``dt`` ahead.

    Parameters
    ----------
    equations : `Equations`
        The model equations.
    variables : dict, optional
        Passed on to ``get_substituted_expressions`` and `str_to_sympy`;
        an empty dictionary is used if ``None``.
    method_options : dict, optional
        Passed to `extract_method_options` together with an empty
        dictionary of defaults.

    Returns
    -------
    code : str
        One line ``name = expression`` per differential equation.

    Raises
    ------
    UnsupportedEquationsException
        If the equations are stochastic, or if an equation has no explicit
        solution with exactly one integration constant.
    """
    logger.warn(
        "The 'independent' state updater is deprecated and might be "
        "removed in future versions of Brian.",
        'deprecated_independent',
        once=True,
    )
    # The return value is not used; presumably this call only validates the
    # given options -- TODO confirm
    extract_method_options(method_options, {})

    if equations.is_stochastic:
        raise UnsupportedEquationsException("Cannot solve stochastic "
                                            "equations with this state "
                                            "updater")
    if variables is None:
        variables = {}

    substituted = equations.get_substituted_expressions(variables)

    t = Symbol('t', real=True, positive=True)
    dt = Symbol('dt', real=True, positive=True)
    t0 = Symbol('t0', real=True, positive=True)
    c1 = Symbol('C1')
    c2 = Symbol('C2')

    statements = []
    for name, expression in substituted:
        parsed = str_to_sympy(expression.code, variables)

        # We have to be careful and use the real=True assumption as well,
        # otherwise sympy doesn't consider the symbol a match to the content
        # of the equation
        var = Symbol(name, real=True)
        f = sp.Function(name)
        f_t = f(t)
        diff_eq = sp.Eq(sp.Derivative(f_t, t), parsed.subs(var, f_t))

        # TODO: simplify=True sometimes fails with 0.7.4, see:
        # https://github.com/sympy/sympy/issues/2666
        try:
            general_solution = sp.dsolve(diff_eq, f_t, simplify=True)
        except RuntimeError:
            general_solution = sp.dsolve(diff_eq, f_t, simplify=False)

        # Only a solution of the form "f(t) = ..." is an explicit solution
        is_explicit = getattr(general_solution, 'lhs', None) == f_t
        if not is_explicit:
            raise UnsupportedEquationsException(
                f"Cannot explicitly solve: {str(diff_eq)}")

        if not general_solution.has(c1):
            solution = general_solution.rhs.subs(t, t0).subs(f(t0), var)
        else:
            # Solve for C1 (assuming "var" as the initial value and "t0" as
            # time)
            if general_solution.has(c2):
                raise UnsupportedEquationsException(
                    f'Too many constants in solution: {str(general_solution)}'
                )
            candidates = sp.solve(general_solution, c1)
            if len(candidates) != 1:
                raise UnsupportedEquationsException(
                    ("Couldn't solve for the constant "
                     "C1 in : %s ") % str(general_solution))
            constant = candidates[0].subs(t, t0).subs(f(t0), var)
            solution = general_solution.rhs.subs('C1', constant)

        # Evaluate the expression for one timestep
        solution = solution.subs(t, t + dt).subs(t0, t)

        # Simplification is best effort only -- it sometimes raises an error
        try:
            solution = solution.simplify()
        except ValueError:
            pass

        statements.append(f"{name} = {sympy_to_str(solution)}")

    return '\n'.join(statements)
def __call__(self, equations, variables=None, method_options=None):
    """
    Generate abstract code for the exact solution of a linear system.

    The equations are brought into the form ``dX/dt = M*X + B``; the update
    for one time step is then ``A*X + (A*b - b)`` with ``A = exp(M*dt)`` and
    ``b`` the solution returned by ``sympy.solve_linear_system`` for the
    augmented matrix ``(M | B)``.

    Parameters
    ----------
    equations : `Equations`
        The model equations.
    variables : dict, optional
        The variables of the model; an empty dictionary is used if ``None``.
    method_options : dict, optional
        Options handed to `extract_method_options`; the only default given
        here is ``simplify=True``.

    Returns
    -------
    code : str
        The abstract code, or an empty string if there are no differential
        equations.

    Raises
    ------
    UnsupportedEquationsException
        If the equations are stochastic, not constant over a time step,
        cannot be solved, or lead to complex values.
    """
    method_options = extract_method_options(method_options,
                                            {'simplify': True})

    if equations.is_stochastic:
        raise UnsupportedEquationsException("Cannot solve stochastic "
                                            "equations with this state "
                                            "updater.")
    if variables is None:
        variables = {}

    # Get a representation of the ODE system in the form of
    # dX/dt = M*X + B
    varnames, matrix, constants = get_linear_system(equations, variables)

    # No differential equations, nothing to do (this occurs sometimes in the
    # test suite where the whole model is nothing more than something like
    # 'v : 1')
    if matrix.shape == (0, 0):
        return ''

    # Make sure that the matrix M is constant, i.e. it only contains
    # external variables or constant variables

    # Check for time dependence
    if 'dt' in variables:
        dt_value = variables['dt'].get_value()[0]
    else:
        dt_value = None

    # Every entry of M and B has to be constant over a time step
    for entry in itertools.chain(matrix, constants):
        if is_constant_over_dt(entry, variables, dt_value):
            continue
        raise UnsupportedEquationsException(
            f"Expression '{sympy_to_str(entry)}' is not guaranteed to be "
            f"constant over a time step.")

    symbols = [Symbol(variable, real=True) for variable in varnames]
    augmented = matrix.row_join(constants)
    solution = sp.solve_linear_system(augmented, *symbols)
    if solution is None or set(symbols) != set(solution.keys()):
        raise UnsupportedEquationsException("Cannot solve the given "
                                            "equations with this "
                                            "stateupdater.")
    b = sp.ImmutableMatrix([solution[symbol] for symbol in symbols])

    # Solve the system
    dt = Symbol('dt', real=True, positive=True)
    try:
        A = (matrix * dt).exp()
    except NotImplementedError:
        raise UnsupportedEquationsException("Cannot solve the given "
                                            "equations with this "
                                            "stateupdater.")

    if method_options['simplify']:
        def _simplify_entry(x):
            return sp.factor_terms(sp.cancel(sp.signsimp(x)))

        A = A.applyfunc(_simplify_entry)

    C = sp.ImmutableMatrix(A * b) - b
    _S = sp.MatrixSymbol('_S', len(varnames), 1)
    updates = (A * _S + C).as_explicit()

    # The solution contains _S[0, 0], _S[1, 0] etc. for the state variables,
    # replace them with the state variable names
    placeholders = [(_S[row_idx, 0], varname)
                    for row_idx, varname in enumerate(varnames)]

    abstract_code = []
    for variable, update in zip(varnames, updates):
        rhs = update
        if rhs.has(I, re, im):
            raise UnsupportedEquationsException(
                "The solution to the linear system "
                "contains complex values "
                "which is currently not implemented.")
        for placeholder, varname in placeholders:
            rhs = rhs.subs(placeholder, varname)
        # Do not overwrite the real state variables yet, the update step
        # of other state variables might still need the original values
        abstract_code.append(f"_{variable} = {sympy_to_str(rhs)}")

    # Update the state variables
    abstract_code.extend(f"{variable} = _{variable}"
                         for variable in varnames)

    return '\n'.join(abstract_code)
def __call__(self, equations, variables=None):
    """
    Generate abstract code that solves each equation on its own with sympy.

    Every differential equation is handed to ``sympy.dsolve`` separately;
    the integration constant is fixed by using the current value of the
    variable as the initial condition, and the solution is then evaluated
    one time step ``dt`` ahead.

    Parameters
    ----------
    equations : `Equations`
        The model equations.
    variables : dict, optional
        Passed on to ``get_substituted_expressions`` and `str_to_sympy`;
        an empty dictionary is used if ``None``.

    Returns
    -------
    code : str
        One line ``name = expression`` per differential equation.

    Raises
    ------
    UnsupportedEquationsException
        If the equations are stochastic, or if an equation has no explicit
        solution with exactly one integration constant.
    """
    if equations.is_stochastic:
        raise UnsupportedEquationsException('Cannot solve stochastic '
                                            'equations with this state '
                                            'updater')
    if variables is None:
        variables = {}

    diff_eqs = equations.get_substituted_expressions(variables)

    t = Symbol('t', real=True, positive=True)
    dt = Symbol('dt', real=True, positive=True)
    t0 = Symbol('t0', real=True, positive=True)

    # TODO: Shortcut for simple linear equations? Is all this effort really
    #       worth it?

    code = []
    for name, expression in diff_eqs:
        rhs = str_to_sympy(expression.code, variables)

        # We have to be careful and use the real=True assumption as well,
        # otherwise sympy doesn't consider the symbol a match to the content
        # of the equation
        var = Symbol(name, real=True)
        f = sp.Function(name)
        rhs = rhs.subs(var, f(t))
        derivative = sp.Derivative(f(t), t)
        diff_eq = sp.Eq(derivative, rhs)

        # TODO: simplify=True sometimes fails with 0.7.4, see:
        # https://github.com/sympy/sympy/issues/2666
        try:
            general_solution = sp.dsolve(diff_eq, f(t), simplify=True)
        except RuntimeError:
            general_solution = sp.dsolve(diff_eq, f(t), simplify=False)

        # Check whether this is an explicit solution
        if not getattr(general_solution, 'lhs', None) == f(t):
            raise UnsupportedEquationsException(
                'Cannot explicitly solve: ' + str(diff_eq))

        # Solve for C1 (assuming "var" as the initial value and "t0" as time)
        if general_solution.has(Symbol('C1')):
            if general_solution.has(Symbol('C2')):
                raise UnsupportedEquationsException(
                    'Too many constants in solution: %s' % str(general_solution))
            constant_solution = sp.solve(general_solution, Symbol('C1'))
            if len(constant_solution) != 1:
                raise UnsupportedEquationsException(
                    ("Couldn't solve for the constant "
                     "C1 in : %s ") % str(general_solution))
            constant = constant_solution[0].subs(t, t0).subs(f(t0), var)
            solution = general_solution.rhs.subs('C1', constant)
        else:
            solution = general_solution.rhs.subs(t, t0).subs(f(t0), var)

        # Evaluate the expression for one timestep
        solution = solution.subs(t, t + dt).subs(t0, t)

        # Simplification is best effort only -- it sometimes raises an error
        try:
            solution = solution.simplify()
        except ValueError:
            pass

        code.append(name + ' = ' + sympy_to_str(solution))

    return '\n'.join(code)
def __call__(self, equations, variables=None, simplify=True):
    """
    Generate abstract code for the exact solution of a linear system.

    The equations are brought into the form ``dX/dt = M*X + B``; the update
    for one time step is built from ``A = exp(M*dt)`` and the solution ``b``
    returned by ``sympy.solve_linear_system`` for the augmented matrix
    ``(M | B)``.

    Parameters
    ----------
    equations : `Equations`
        The model equations.
    variables : dict, optional
        The variables of the model; an empty dictionary is used if ``None``.
    simplify : bool, optional
        Whether to simplify the entries of the matrix exponential. Defaults
        to ``True``.

    Returns
    -------
    code : str
        The abstract code, or an empty string if there are no differential
        equations.

    Raises
    ------
    UnsupportedEquationsException
        If the equations are stochastic, use ``t`` in a way that is not
        locally constant, or cannot be solved.
    """
    if equations.is_stochastic:
        raise UnsupportedEquationsException('Cannot solve stochastic '
                                            'equations with this state '
                                            'updater.')
    if variables is None:
        variables = {}

    # Get a representation of the ODE system in the form of
    # dX/dt = M*X + B
    varnames, matrix, constants = get_linear_system(equations, variables)

    # No differential equations, nothing to do (this occurs sometimes in the
    # test suite where the whole model is nothing more than something like
    # 'v : 1')
    if matrix.shape == (0, 0):
        return ''

    # Make sure that the matrix M is constant, i.e. it only contains
    # external variables or constant variables

    # Check for time dependence (only done when "dt" is a known variable)
    if 'dt' in variables:
        dt_value = variables['dt'].get_value()[0]
        t = Symbol('t', real=True, positive=True)
        # This will raise an error if we meet the symbol "t" anywhere
        # except as an argument of a locally constant function
        for entry in itertools.chain(matrix, constants):
            _check_for_locally_constant(entry, variables, dt_value, t)

    symbols = [Symbol(variable, real=True) for variable in varnames]
    augmented = matrix.row_join(constants)
    solution = sp.solve_linear_system(augmented, *symbols)
    if solution is None or set(symbols) != set(solution.keys()):
        raise UnsupportedEquationsException('Cannot solve the given '
                                            'equations with this '
                                            'stateupdater.')
    b = sp.ImmutableMatrix([solution[symbol]
                            for symbol in symbols]).transpose()

    # Solve the system
    dt = Symbol('dt', real=True, positive=True)
    try:
        A = (matrix * dt).exp()
    except NotImplementedError:
        raise UnsupportedEquationsException('Cannot solve the given '
                                            'equations with this '
                                            'stateupdater.')

    if simplify:
        def _simplify_entry(x):
            return sp.factor_terms(sp.cancel(sp.signsimp(x)))

        A = A.applyfunc(_simplify_entry)

    C = sp.ImmutableMatrix([A.dot(b)]) - b
    _S = sp.MatrixSymbol('_S', len(varnames), 1)
    updates = (A * _S + C.transpose()).as_explicit()

    # The solution contains _S[0, 0], _S[1, 0] etc. for the state variables,
    # replace them with the state variable names
    placeholders = [(_S[row_idx, 0], varname)
                    for row_idx, varname in enumerate(varnames)]

    abstract_code = []
    for variable, update in zip(varnames, updates):
        rhs = update
        for placeholder, varname in placeholders:
            rhs = rhs.subs(placeholder, varname)
        # Do not overwrite the real state variables yet, the update step
        # of other state variables might still need the original values
        abstract_code.append('_' + variable + ' = ' + sympy_to_str(rhs))

    # Update the state variables
    abstract_code.extend('{variable} = _{variable}'.format(variable=variable)
                         for variable in varnames)

    return '\n'.join(abstract_code)