def test_constants_sympy():
    '''
    Check that the symbolic constants ``inf``, ``pi`` and ``e`` survive a
    round trip through sympy with their mathematical meaning intact.
    '''
    cases = [('1.0/inf', '0'),
             ('sin(pi)', '0'),
             ('log(e)', '1')]
    for source, expected in cases:
        assert sympy_to_str(str_to_sympy(source)) == expected
def __str__(self):
    '''
    Render the object as text: the class name, then (if there are any) the
    intermediate ``var = expr`` statements, then the output expression.
    '''
    header = '%s\n' % self.__class__.__name__
    body = ''
    if len(self.statements) != 0:
        rendered = '\n'.join(name + ' = ' + sympy_to_str(value)
                             for name, value in self.statements)
        body = 'Intermediate statements:\n' + rendered + '\n'
    return header + body + 'Output:\n' + sympy_to_str(self.output)
def get_substituted_expressions(self, variables=None,
                                include_subexpressions=False):
    '''
    Return a list of ``(varname, expr)`` tuples, containing all differential
    equations (and optionally subexpressions) with all the subexpression
    variables substituted with the respective expressions.

    The result is computed on the first call and cached in
    ``self._substituted_expressions``. Later calls reuse the cache, so a
    different `variables` argument passed on a later call has no effect.

    Parameters
    ----------
    variables : dict, optional
        A mapping of variable names to `Variable`/`Function` objects.
    include_subexpressions : bool
        Whether also to return substituted subexpressions. Defaults to
        ``False``.

    Returns
    -------
    expr_tuples : list of (str, `CodeString`)
        A list of ``(varname, expr)`` tuples, where ``expr`` is a
        `CodeString` object with all subexpression variables substituted with
        the respective expression.

    Raises
    ------
    AssertionError
        If an equation with an expression is neither a subexpression nor a
        differential equation.
    '''
    if self._substituted_expressions is None:
        # Build the list locally and only store it once it is complete.
        # Otherwise an exception raised half-way (e.g. a syntax error in one
        # expression) would leave a partial list in the cache, and later
        # calls would silently return it as if it were complete.
        substituted = []
        substitutions = {}
        for eq in self.ordered:
            # Skip parameters (they have no expression)
            if eq.expr is None:
                continue

            new_sympy_expr = str_to_sympy(
                eq.expr.code, variables).xreplace(substitutions)
            new_str_expr = sympy_to_str(new_sympy_expr)
            expr = Expression(new_str_expr)

            if eq.type == SUBEXPRESSION:
                # The symbol's assumptions must match the ones str_to_sympy
                # uses, otherwise xreplace would not find it in later
                # expressions
                if eq.var_type == INTEGER:
                    sympy_var = sympy.Symbol(eq.varname, integer=True)
                else:
                    sympy_var = sympy.Symbol(eq.varname, real=True)
                # Subexpressions are processed in order, so later equations
                # see this one already expanded
                substitutions.update(
                    {sympy_var: str_to_sympy(expr.code, variables)})
                substituted.append((eq.varname, expr))
            elif eq.type == DIFFERENTIAL_EQUATION:
                # a differential equation that we have to check
                substituted.append((eq.varname, expr))
            else:
                raise AssertionError('Unknown equation type %s' % eq.type)
        self._substituted_expressions = substituted

    if include_subexpressions:
        return self._substituted_expressions
    else:
        return [(name, expr)
                for name, expr in self._substituted_expressions
                if self[name].type == DIFFERENTIAL_EQUATION]
def __call__(self, equations, variables=None, method_options=None):
    '''
    Generate abstract code for one exponential-Euler step of `equations`.

    Every equation must be conditionally linear. ``get_conditionally_linear_system``
    yields, for each variable ``x``, a pair ``(A, B)`` (presumably the
    coefficients of ``dx/dt = A*x + B``). The generated update is:

    * ``x + dt*B`` if ``A == 0``;
    * ``(x + B/A)*exp(A*dt) - B/A`` if ``B != 0``, with ``B/A`` computed
      once into an auxiliary variable ``_BA_x``;
    * ``x*exp(A*dt)`` otherwise.

    New values are written to ``_x`` temporaries and copied back only at the
    end, so every update reads the values from the start of the step.

    Raises
    ------
    UnsupportedEquationsException
        For stochastic equations, or when ``get_conditionally_linear_system``
        raises ``ValueError``.
    '''
    # No method options are supported by this updater (empty defaults)
    method_options = extract_method_options(method_options, {})
    if equations.is_stochastic:
        raise UnsupportedEquationsException('Cannot solve stochastic '
                                            'equations with this state '
                                            'updater.')

    try:
        system = get_conditionally_linear_system(equations, variables)
    except ValueError:
        raise UnsupportedEquationsException('Can only solve conditionally '
                                            'linear systems with this '
                                            'state updater.')

    dt_sym = sp.Symbol('dt')
    lines = []
    for name, (coef_a, coef_b) in system.items():
        x_sym = sp.Symbol(name)
        if coef_a == 0:
            new_value = x_sym + dt_sym * coef_b
        elif coef_b != 0:
            # Store B/A in a temporary so it is only evaluated once
            ratio_name = '_BA_' + name
            ratio_sym = sp.Symbol(ratio_name)
            lines.append(ratio_name + ' = ' + sympy_to_str(coef_b / coef_a))
            new_value = (x_sym + ratio_sym) * sp.exp(coef_a * dt_sym) - ratio_sym
        else:
            new_value = x_sym * sp.exp(coef_a * dt_sym)
        lines.append('_{var} = {expr}'.format(var=name,
                                              expr=sympy_to_str(new_value)))

    # Copy the temporaries back into the real state variables
    lines.extend('{var} = _{var}'.format(var=name) for name in system)
    return '\n'.join(lines)
def __init__(self, code=None, sympy_expression=None):
    '''
    Create an expression from either a code string or a sympy expression
    (exactly one of them must be given).

    A string is parsed once with ``str_to_sympy`` only to surface syntax
    errors early; the parsed result is discarded. A sympy expression is
    converted to its string form with ``sympy_to_str``.

    Raises
    ------
    TypeError
        If neither or both of `code` and `sympy_expression` are given.
    '''
    have_code = code is not None
    have_sympy = sympy_expression is not None
    if not (have_code or have_sympy):
        raise TypeError('Have to provide either a string or a sympy expression')
    if have_code and have_sympy:
        raise TypeError('Provide a string expression or a sympy expression, not both')

    if have_code:
        # Validation only -- raises for syntactically incorrect expressions
        str_to_sympy(code)
    else:
        code = sympy_to_str(sympy_expression)
    super(Expression, self).__init__(code=code)
def __init__(self, morphology=None, model=None, threshold=None,
             refractory=False, reset=None, events=None,
             threshold_location=None,
             dt=None, clock=None, order=0, Cm=0.9 * uF / cm ** 2, Ri=150 * ohm * cm,
             name='spatialneuron*', dtype=None, namespace=None,
             method=('exact', 'exponential_euler', 'rk2', 'heun'),
             method_options=None):
    '''
    Initialise a spatially extended (multi-compartment) neuron group.

    The model must define the transmembrane current ``Im`` without flags.
    Variables flagged as ``point current`` must have units of amp; each one
    is added to the membrane equation as ``+ <name>/area``. The expanded
    ``Im`` is then split into ``gtot__private = -dIm/dv`` and
    ``I0__private = Im - dIm/dv*v``. This gives
    ``Im = I0__private - gtot__private*v``, which is exact when ``Im`` is
    linear in ``v``. These terms are presumably consumed by
    `SpatialStateUpdater` (TODO confirm).

    Parameters
    ----------
    morphology
        The neuron morphology. A deep copy is stored in ``self.morphology``
        and a `FlatMorphology` of it in ``self.flat_morphology``. Its
        ``total_compartments`` sets the number of elements of the group.
    model : str or `Equations`
        The model equations. They must define ``Im``.
    threshold : str, optional
        Passed on to `NeuronGroup`. If `threshold_location` is given, it is
        combined with ``(i == <location>)`` by string concatenation, so it
        must then be a string.
    threshold_location : optional
        Compartment index, or an object with an ``_indices()`` method that
        must yield exactly one index.
    Cm, Ri
        Specific capacitance and intracellular resistivity, stored in the
        ``Cm`` (constant) and ``Ri`` (constant, shared) variables.
    method
        Passed on to both `NeuronGroup` and `SpatialStateUpdater`.
    refractory, reset, events, dt, clock, order, namespace, dtype, name, method_options
        Passed on to `NeuronGroup.__init__`.

    Raises
    ------
    TypeError
        If `model` is neither a string nor `Equations`, if ``Im`` is missing
        or has flags, or if ``Im`` cannot be differentiated with respect to
        ``v``.
    AttributeError
        If `threshold_location` resolves to more than one compartment.
    '''
    # #### Prepare and validate equations
    if isinstance(model, str):
        model = Equations(model)
    if not isinstance(model, Equations):
        raise TypeError(('model has to be a string or an Equations '
                         'object, is "%s" instead.') % type(model))

    # Insert the threshold mechanism at the specified location
    if threshold_location is not None:
        if hasattr(threshold_location,
                   '_indices'):  # assuming this is a method
            threshold_location = threshold_location._indices()
            # for now, only a single compartment allowed
            if len(threshold_location) == 1:
                threshold_location = threshold_location[0]
            else:
                raise AttributeError(('Threshold can only be applied on a '
                                      'single location'))
        # NOTE(review): fails with TypeError if threshold is None while a
        # threshold_location is given
        threshold = '(' + threshold + ') and (i == ' + str(threshold_location) + ')'

    # Check flags (we have point currents)
    model.check_flags({DIFFERENTIAL_EQUATION: ('point current',),
                       PARAMETER: ('constant', 'shared', 'linked',
                                   'point current'),
                       SUBEXPRESSION: ('shared', 'point current',
                                       'constant over dt')})
    #: The original equations as specified by the user (i.e. before
    #: inserting point-currents into the membrane equation, before adding
    #: all the internally used variables and constants, etc.).
    self.user_equations = model

    # Separate subexpressions depending whether they are considered to be
    # constant over a time step or not (this would also be done by the
    # NeuronGroup initializer later, but this would give incorrect results
    # for the linearity check)
    model, constant_over_dt = extract_constant_subexpressions(model)

    # Extract membrane equation
    if 'Im' in model:
        if len(model['Im'].flags):
            raise TypeError('Cannot specify any flags for the transmembrane '
                            'current Im.')
        membrane_expr = model['Im'].expr  # the membrane equation
    else:
        raise TypeError('The transmembrane current Im must be defined')

    model_equations = []
    # Insert point currents in the membrane equation
    for eq in model.values():
        if eq.varname == 'Im':
            continue  # ignore -- handled separately
        if 'point current' in eq.flags:
            fail_for_dimension_mismatch(eq.dim, amp,
                                        "Point current " + eq.varname + " should be in amp")
            # Divide by the compartment area to turn the point current (amp)
            # into a current density, matching Im (amp/meter**2)
            membrane_expr = Expression(
                str(membrane_expr.code) + '+' + eq.varname + '/area')
            # Re-create the equation without the 'point current' flag, which
            # has now been handled
            eq = SingleEquation(eq.type, eq.varname, eq.dim, expr=eq.expr,
                                flags=list(set(eq.flags)-{'point current'}))
        model_equations.append(eq)

    model_equations.append(SingleEquation(SUBEXPRESSION, 'Im',
                                          dimensions=(amp/meter**2).dim,
                                          expr=membrane_expr))
    model_equations.append(SingleEquation(PARAMETER, 'v', volt.dim))
    model = Equations(model_equations)

    ###### Process model equations (Im) to extract total conductance and the remaining current
    # Expand expressions in the membrane equation
    for var, expr in model.get_substituted_expressions(include_subexpressions=True):
        if var == 'Im':
            Im_expr = expr
            break
    else:
        raise AssertionError('Model equations did not contain Im!')

    # Differentiate Im with respect to v
    Im_sympy_exp = str_to_sympy(Im_expr.code)
    v_sympy = sp.Symbol('v', real=True)
    diffed = sp.diff(Im_sympy_exp, v_sympy)

    # Unevaluated Derivative objects mean sympy could not differentiate
    # some part of Im (e.g. an unknown function of v)
    unevaled_derivatives = diffed.atoms(sp.Derivative)
    if len(unevaled_derivatives):
        raise TypeError('Cannot take the derivative of "{Im}" with respect '
                        'to v.'.format(Im=Im_expr.code))

    # Linear decomposition Im = I0 - gtot*v (exact when Im is linear in v)
    gtot_str = sympy_to_str(sp.simplify(-diffed))
    I0_str = sympy_to_str(sp.simplify(Im_sympy_exp - diffed*v_sympy))

    # A bare '0' carries no units; attach them explicitly so the resulting
    # equations have the declared dimensions
    if gtot_str == '0':
        gtot_str += '*siemens/meter**2'
    if I0_str == '0':
        I0_str += '*amp/meter**2'
    gtot_str = "gtot__private=" + gtot_str + ": siemens/meter**2"
    I0_str = "I0__private=" + I0_str + ": amp/meter**2"

    model += Equations(gtot_str + "\n" + I0_str)

    # Insert morphology (store a copy)
    self.morphology = copy.deepcopy(morphology)

    # Flatten the morphology
    self.flat_morphology = FlatMorphology(morphology)

    # Equations for morphology
    # TODO: check whether Cm and Ri are already in the equations
    #       no: should be shared instead of constant
    #       yes: should be constant (check)
    eqs_constants = Equations("""
    length : meter (constant)
    distance : meter (constant)
    area : meter**2 (constant)
    volume : meter**3
    Ic : amp/meter**2
    diameter : meter (constant)
    Cm : farad/meter**2 (constant)
    Ri : ohm*meter (constant, shared)
    r_length_1 : meter (constant)
    r_length_2 : meter (constant)
    time_constant = Cm/gtot__private : second
    space_constant = (2/pi)**(1.0/3.0) * (area/(1/r_length_1 + 1/r_length_2))**(1.0/6.0) / (2*(Ri*gtot__private)**(1.0/2.0)) : meter
    """)

    if self.flat_morphology.has_coordinates:
        eqs_constants += Equations('''
        x : meter (constant)
        y : meter (constant)
        z : meter (constant)
        ''')

    # One group element per compartment of the morphology
    NeuronGroup.__init__(self, morphology.total_compartments,
                         model=model + eqs_constants,
                         method_options=method_options,
                         threshold=threshold, refractory=refractory,
                         reset=reset, events=events,
                         method=method, dt=dt, clock=clock, order=order,
                         namespace=namespace, dtype=dtype, name=name)

    # Parameters and intermediate variables for solving the cable equations
    # Note that some of these variables could have meaningful physical
    # units (e.g. _v_star is in volt, _I0_all is in amp/meter**2 etc.) but
    # since these variables should never be used in user code, we don't
    # assign them any units
    self.variables.add_arrays(['_ab_star0', '_ab_star1', '_ab_star2',
                               '_b_plus', '_b_minus',
                               '_v_star', '_u_plus', '_u_minus',
                               '_v_previous',
                               '_c',
                               # The following two are only necessary for
                               # C code where we cannot deal with scalars
                               # and arrays interchangeably:
                               '_I0_all', '_gtot_all'], size=self.N,
                              read_only=True)

    # Assigned after NeuronGroup.__init__, which creates these variables
    self.Cm = Cm
    self.Ri = Ri
    # These explict assignments will load the morphology values from disk
    # in standalone mode
    self.distance_ = self.flat_morphology.distance
    self.length_ = self.flat_morphology.length
    self.area_ = self.flat_morphology.area
    self.diameter_ = self.flat_morphology.diameter
    self.r_length_1_ = self.flat_morphology.r_length_1
    self.r_length_2_ = self.flat_morphology.r_length_2
    if self.flat_morphology.has_coordinates:
        self.x_ = self.flat_morphology.x
        self.y_ = self.flat_morphology.y
        self.z_ = self.flat_morphology.z

    # Performs numerical integration step
    self.add_attribute('diffusion_state_updater')
    self.diffusion_state_updater = SpatialStateUpdater(self, method,
                                                       clock=self.clock,
                                                       order=order)

    # Update v after the gating variables to obtain consistent Ic and Im
    # (this overrides the `order` passed to the constructor above)
    self.diffusion_state_updater.order = 1

    # Creation of contained_objects that do the work
    self.contained_objects.extend([self.diffusion_state_updater])

    # Subexpressions marked 'constant over dt' get their own updater
    if len(constant_over_dt):
        self.subexpression_updater = SubexpressionUpdater(self,
                                                          constant_over_dt)
        self.contained_objects.append(self.subexpression_updater)
def __call__(self, equations, variables=None, method_options=None):
    '''
    Generate abstract code that advances each differential equation by one
    time step ``dt`` using its analytical solution from ``sympy.dsolve``.

    Deprecated (a warning is logged on use). Each equation is solved on its
    own: only its own variable is replaced by a function of ``t``, so all
    other symbols are treated as constants by ``dsolve``. The single
    integration constant ``C1`` is fixed by using the current value of the
    variable as the initial condition at ``t0``.

    Raises
    ------
    UnsupportedEquationsException
        For stochastic equations, when ``dsolve`` yields no explicit solution,
        when the solution has more than one constant, or when ``C1`` cannot
        be solved for uniquely.
    '''
    logger.warn("The 'independent' state updater is deprecated and might be "
                "removed in future versions of angela.",
                'deprecated_independent', once=True)

    method_options = extract_method_options(method_options, {})
    if equations.is_stochastic:
        raise UnsupportedEquationsException('Cannot solve stochastic '
                                            'equations with this state '
                                            'updater')
    if variables is None:
        variables = {}

    diff_eqs = equations.get_substituted_expressions(variables)

    t = Symbol('t', real=True, positive=True)
    dt = Symbol('dt', real=True, positive=True)
    t0 = Symbol('t0', real=True, positive=True)
    const_1 = Symbol('C1')
    const_2 = Symbol('C2')

    lines = []
    for name, expression in diff_eqs:
        parsed = str_to_sympy(expression.code, variables)
        # real=True is required, otherwise sympy does not consider this
        # symbol a match for the one appearing in the parsed equation
        state = Symbol(name, real=True)
        func = sp.Function(name)
        ode = sp.Eq(sp.Derivative(func(t), t), parsed.subs(state, func(t)))

        # TODO: simplify=True sometimes fails with 0.7.4, see:
        # https://github.com/sympy/sympy/issues/2666
        try:
            general_solution = sp.dsolve(ode, func(t), simplify=True)
        except RuntimeError:
            general_solution = sp.dsolve(ode, func(t), simplify=False)

        # Only explicit solutions of the form f(t) = ... are usable
        if not (getattr(general_solution, 'lhs', None) == func(t)):
            raise UnsupportedEquationsException('Cannot explicitly solve: ' + str(ode))

        if not general_solution.has(const_1):
            solution = general_solution.rhs.subs(t, t0).subs(func(t0), state)
        else:
            if general_solution.has(const_2):
                raise UnsupportedEquationsException('Too many constants in solution: %s' % str(general_solution))
            # Express C1 through the initial value "state" at time "t0"
            candidates = sp.solve(general_solution, const_1)
            if len(candidates) != 1:
                raise UnsupportedEquationsException(("Couldn't solve for the constant "
                                                     "C1 in : %s ") % str(general_solution))
            initial = candidates[0].subs(t, t0).subs(func(t0), state)
            solution = general_solution.rhs.subs(const_1, initial)

        # Evaluate the solution one time step later, starting from the current time
        solution = solution.subs(t, t + dt).subs(t0, t)

        # Simplification is best-effort only -- it sometimes raises an error
        try:
            solution = solution.simplify()
        except ValueError:
            pass

        lines.append(name + ' = ' + sympy_to_str(solution))

    return '\n'.join(lines)
def __call__(self, equations, variables=None, method_options=None):
    '''
    Generate abstract code for the exact one-step update of a linear ODE
    system ``dX/dt = M*X + B`` with ``M`` and ``B`` constant over a time step.

    Let ``b`` be the solution of the linear system given by the augmented
    matrix ``[M | constants]``. The update is
    ``X_new = exp(M*dt)*X + (exp(M*dt)*b - b)``. New values are first written
    to ``_x`` temporaries and copied back at the end.

    The ``simplify`` method option (default ``True``) controls whether the
    entries of ``exp(M*dt)`` are simplified.

    Raises
    ------
    UnsupportedEquationsException
        For stochastic equations, non-constant coefficients, unsolvable
        systems, a non-computable matrix exponential, or complex-valued
        solutions.
    '''
    method_options = extract_method_options(method_options,
                                            {'simplify': True})

    if equations.is_stochastic:
        raise UnsupportedEquationsException('Cannot solve stochastic '
                                            'equations with this state '
                                            'updater.')
    if variables is None:
        variables = {}

    # dX/dt = M*X + B
    varnames, matrix, constants = get_linear_system(equations, variables)

    # Nothing to integrate (e.g. a model that only consists of 'v : 1')
    if matrix.shape == (0, 0):
        return ''

    dt_value = None
    if 'dt' in variables:
        dt_value = variables['dt'].get_value()[0]

    # Every coefficient has to be constant over a time step; "t" may only
    # appear as the argument of a locally constant function
    for coefficient in itertools.chain(matrix, constants):
        if not is_constant_over_dt(coefficient, variables, dt_value):
            raise UnsupportedEquationsException(
                ('Expression "{}" is not guaranteed to be constant over a '
                 'time step').format(sympy_to_str(coefficient)))

    unknowns = [Symbol(varname, real=True) for varname in varnames]
    linear_solution = sp.solve_linear_system(matrix.row_join(constants),
                                             *unknowns)
    if linear_solution is None or set(unknowns) != set(linear_solution.keys()):
        raise UnsupportedEquationsException('Cannot solve the given '
                                            'equations with this '
                                            'stateupdater.')
    offset = sp.ImmutableMatrix([linear_solution[u] for u in unknowns])

    dt_sym = Symbol('dt', real=True, positive=True)
    try:
        propagator = (matrix * dt_sym).exp()
    except NotImplementedError:
        raise UnsupportedEquationsException('Cannot solve the given '
                                            'equations with this '
                                            'stateupdater.')
    if method_options['simplify']:
        propagator = propagator.applyfunc(
            lambda entry: sp.factor_terms(sp.cancel(sp.signsimp(entry))))
    shift = sp.ImmutableMatrix(propagator * offset) - offset

    # _S stands in for the state vector; its entries are replaced by the
    # state variable names below
    _S = sp.MatrixSymbol('_S', len(varnames), 1)
    new_values = (propagator * _S + shift).as_explicit()

    abstract_code = []
    for target, new_value in zip(varnames, new_values):
        if new_value.has(I, re, im):
            raise UnsupportedEquationsException('The solution to the linear system '
                                                'contains complex values '
                                                'which is currently not implemented.')
        for row, source in enumerate(varnames):
            new_value = new_value.subs(_S[row, 0], source)
        # Write to a temporary: other updates may still need the old values
        abstract_code.append('_' + target + ' = ' + sympy_to_str(new_value))

    abstract_code.extend('{variable} = _{variable}'.format(variable=target)
                         for target in varnames)
    return '\n'.join(abstract_code)
def test_automatic_augmented_assignments():
    '''
    Check that statements which can be expressed as augmented assignments
    (``x = x + ...``, ``x = c*x``) are rewritten as such by `make_statements`,
    and that all other statements are left untouched. Right-hand sides are
    compared with sympy, so that term order and formatting do not matter.
    '''
    variables = {
        'x': ArrayVariable('x', owner=None, size=10, device=device),
        'y': ArrayVariable('y', owner=None, size=10, device=device),
        'z': ArrayVariable('z', owner=None, size=10, device=device),
        # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` denotes
        # the same dtype and works with every NumPy version
        'b': ArrayVariable('b', owner=None, size=10, dtype=bool,
                           device=device),
        'clip': DEFAULT_FUNCTIONS['clip'],
        'inf': DEFAULT_CONSTANTS['inf']
    }
    statements = [
        # examples that should be rewritten
        # Note that using our approach, we will never get -= or /= but always
        # the equivalent += or *= statements
        ('x = x + 1.0', 'x += 1.0'),
        ('x = 2.0 * x', 'x *= 2.0'),
        ('x = x - 3.0', 'x += -3.0'),
        ('x = x/2.0', 'x *= 0.5'),
        ('x = y + (x + 1.0)', 'x += y + 1.0'),
        ('x = x + x', 'x *= 2.0'),
        ('x = x + y + z', 'x += y + z'),
        # the written variable does not have to come first
        ('x = y + x + z', 'x += y + z'),
        # examples that should not be rewritten
        ('x = 1.0/x', 'x = 1.0/x'),
        ('x = 1.0', 'x = 1.0'),
        ('x = 2.0*(x + 1.0)', 'x = 2.0*(x + 1.0)'),
        ('x = clip(x + y, 0.0, inf)', 'x = clip(x + y, 0.0, inf)'),
        ('b = b or False', 'b = b or False')
    ]
    for orig, rewritten in statements:
        scalar, vector = make_statements(orig, variables, np.float32)
        try:
            # we augment the assertion error with the original statement
            assert len(scalar) == 0, (
                'Did not expect any scalar statements but got ' + str(scalar))
            assert len(vector) == 1, (
                'Did expect a single statement but got ' + str(vector))
            statement = vector[0]
            expected_var, expected_op, expected_expr, _ = parse_statement(rewritten)
            assert expected_var == statement.var, (
                'expected write to variable %s, not to %s' % (expected_var,
                                                              statement.var))
            assert expected_op == statement.op, (
                'expected operation %s, not %s' % (expected_op, statement.op))
            # Compare the two expressions using sympy to allow for different
            # order etc.
            sympy_expected = str_to_sympy(expected_expr)
            sympy_actual = str_to_sympy(statement.expr)
            assert sympy_expected == sympy_actual, (
                'RHS expressions "%s" and "%s" are not identical' % (
                    sympy_to_str(sympy_expected), sympy_to_str(sympy_actual)))
        except AssertionError as ex:
            raise AssertionError(
                'Transformation for statement "%s" gave an unexpected result: %s' % (
                    orig, str(ex)))
def make_statements(code, variables, dtype, optimise=True, blockname=''):
    '''
    make_statements(code, variables, dtype, optimise=True, blockname='')

    Turn a series of abstract code statements into Statement objects, inferring
    whether each line is a set/declare operation, whether the variables are
    constant or not, and handling the caching of subexpressions.

    Parameters
    ----------
    code : str
        A (multi-line) string of statements. Statements can be separated by
        newlines or semicolons. Segments that are empty or contain only
        whitespace (e.g. after a trailing ``;``) are ignored.
    variables : dict-like
        A dictionary with `Variable` and `Function` objects for every
        identifier used in the `code`. It is copied before new auxiliary
        variables are added, so the caller's dictionary is not modified.
    dtype : `dtype`
        The data type to use for temporary variables
    optimise : bool, optional
        Whether to optimise expressions, including pulling out loop invariant
        expressions and putting them in new scalar constants. Defaults to
        ``True``. The optimisation is only applied if, in addition, the
        `codegen.loop_invariant_optimisations` preference is enabled. Pass
        ``False`` in contexts where this kind of optimisation is not wanted
        (this function is also used outside the main code generation stage).
    blockname : str, optional
        A name for the block (used to name intermediate variables to avoid
        name clashes when multiple blocks are used together)

    Returns
    -------
    scalar_statements, vector_statements : (list of `Statement`, list of `Statement`)
        Lists with statements that are to be executed once and statements that
        are to be executed once for every neuron/synapse/... (or in a vectorised
        way)

    Raises
    ------
    SyntaxError
        If the code writes to a subexpression, writes to a scalar variable
        after a vector variable, or if a scalar subexpression refers to an
        implicitly vectorised function.

    Notes
    -----
    If ``optimise`` is ``True`` (and the preference is enabled), then the
    ``scalar_statements`` may include newly introduced scalar constants that
    have been identified as loop-invariant and have therefore been pulled out
    of the vector statements. The resulting statements will also use augmented
    assignments where possible, i.e. a statement such as ``w = w + 1`` will be
    replaced by ``w += 1``. Also, statements involving booleans will have
    additional information added to them (see `Statement` for details)
    describing how the statement can be reformulated as a sequence of if/then
    statements. Calls `~angela2.codegen.optimisation.optimise_statements`.
    '''
    code = strip_empty_lines(deindent(code))
    lines = re.split(r'[;\n]', code)
    # Skip whitespace-only segments (e.g. "x = 1; "), which are not statements
    lines = [LineInfo(code=line) for line in lines if line.strip()]
    # Do a copy so we can add stuff without altering the original dict
    variables = dict(variables)
    # we will do inference to work out which lines are := and which are =
    defined = set(k for k, v in variables.items()
                  if not isinstance(v, AuxiliaryVariable))

    for line in lines:
        statement = None
        # parse statement into "var op expr"
        var, op, expr, comment = parse_statement(line.code)
        if var in variables and isinstance(variables[var], Subexpression):
            raise SyntaxError("Illegal line '{line}' in abstract code. "
                              "Cannot write to subexpression "
                              "'{var}'.".format(line=line.code,
                                                var=var))
        if op == '=':
            if var not in defined:
                # First write to an unknown variable: this is a declaration
                op = ':='
                defined.add(var)
                if var not in variables:
                    annotated_ast = angela_ast(expr, variables)
                    is_scalar = annotated_ast.scalar
                    if annotated_ast.dtype == 'boolean':
                        use_dtype = bool
                    elif annotated_ast.dtype == 'integer':
                        use_dtype = int
                    else:
                        use_dtype = dtype
                    new_var = AuxiliaryVariable(var, dtype=use_dtype,
                                                scalar=is_scalar)
                    variables[var] = new_var
            elif not variables[var].is_boolean:
                # Try to rewrite "x = x + a" as "x += a" and "x = a*x" as
                # "x *= a", by collecting the terms in x
                sympy_expr = str_to_sympy(expr, variables)
                if variables[var].is_integer:
                    sympy_var = sympy.Symbol(var, integer=True)
                else:
                    sympy_var = sympy.Symbol(var, real=True)
                try:
                    collected = sympy.collect(sympy_expr, sympy_var,
                                              exact=True, evaluate=False)
                except AttributeError:
                    # If something goes wrong during collection, e.g. collect
                    # does not work for logical expressions
                    collected = {1: sympy_expr}

                if (len(collected) == 2 and
                        set(collected.keys()) == {1, sympy_var} and
                        collected[sympy_var] == 1):
                    # We can replace this statement by a += assignment
                    statement = Statement(var, '+=',
                                          sympy_to_str(collected[1]),
                                          comment,
                                          dtype=variables[var].dtype,
                                          scalar=variables[var].scalar)
                elif len(collected) == 1 and sympy_var in collected:
                    # We can replace this statement by a *= assignment
                    statement = Statement(var, '*=',
                                          sympy_to_str(collected[sympy_var]),
                                          comment,
                                          dtype=variables[var].dtype,
                                          scalar=variables[var].scalar)
        if statement is None:
            statement = Statement(var, op, expr, comment,
                                  dtype=variables[var].dtype,
                                  scalar=variables[var].scalar)

        line.statement = statement
        # for each line will give the variable being written to
        line.write = var
        # each line will give a set of variables which are read
        line.read = get_identifiers_recursively([expr], variables)

    # All writes to scalar variables must happen before writes to vector
    # variables
    scalar_write_done = False
    for line in lines:
        stmt = line.statement
        if stmt.op != ':=' and variables[stmt.var].scalar and scalar_write_done:
            raise SyntaxError(('All writes to scalar variables in a code block '
                               'have to be made before writes to vector '
                               'variables. Illegal write to %s.') % line.write)
        elif not variables[stmt.var].scalar:
            scalar_write_done = True

    # backwards compute whether or not variables will be read again
    # note that will_read for a line gives the set of variables it will read
    # on the current line or subsequent ones. will_write gives the set of
    # variables that will be written after the current line
    will_read = set()
    will_write = set()
    for line in lines[::-1]:
        will_read = will_read.union(line.read)
        line.will_read = will_read.copy()
        line.will_write = will_write.copy()
        will_write.add(line.write)

    subexpressions = dict((name, val) for name, val in variables.items()
                          if isinstance(val, Subexpression))
    # Check that no scalar subexpression refers to a vectorised function
    # (e.g. rand()) -- otherwise it would be differently interpreted depending
    # on whether it is used in a scalar or a vector context (i.e., even though
    # the subexpression is supposed to be scalar, it would be vectorised when
    # used as part of non-scalar expressions)
    for name, subexpr in subexpressions.items():
        if subexpr.scalar:
            identifiers = get_identifiers(subexpr.expr)
            for identifier in identifiers:
                if (identifier in variables and
                        getattr(variables[identifier], 'auto_vectorise', False)):
                    raise SyntaxError(('The scalar subexpression {} refers to '
                                       'the implicitly vectorised function {} '
                                       '-- this is not allowed since it leads '
                                       'to different interpretations of this '
                                       'subexpression depending on whether it '
                                       'is used in a scalar or vector '
                                       'context.').format(name, identifier))

    # sort subexpressions into an order so that subexpressions that don't depend
    # on other subexpressions are first
    subexpr_deps = dict((name, [dep for dep in subexpr.identifiers
                                if dep in subexpressions])
                        for name, subexpr in subexpressions.items())
    sorted_subexpr_vars = topsort(subexpr_deps)

    statements = []

    # none are yet defined (or declared)
    subdefined = dict((name, None) for name in subexpressions)
    for line in lines:
        stmt = line.statement
        read = line.read
        will_write = line.will_write
        # update/define all subexpressions needed by this statement
        for var in sorted_subexpr_vars:
            if var not in read:
                continue

            subexpression = subexpressions[var]
            # if already defined/declared
            if subdefined[var] == 'constant':
                continue
            elif subdefined[var] == 'variable':
                op = '='
                constant = False
            else:
                op = ':='
                # check if the referred variables ever change
                ids = subexpression.identifiers
                constant = all(v not in will_write for v in ids)
                subdefined[var] = 'constant' if constant else 'variable'

            statement = Statement(var, op, subexpression.expr, comment='',
                                  dtype=variables[var].dtype,
                                  constant=constant,
                                  subexpression=True,
                                  scalar=variables[var].scalar)
            statements.append(statement)

        var, op, expr, comment = stmt.var, stmt.op, stmt.expr, stmt.comment
        # constant only if we are declaring a new variable and we will not
        # write to it again
        constant = op == ':=' and var not in will_write
        statement = Statement(var, op, expr, comment,
                              dtype=variables[var].dtype,
                              constant=constant,
                              scalar=variables[var].scalar)
        statements.append(statement)

    scalar_statements = [s for s in statements if s.scalar]
    vector_statements = [s for s in statements if not s.scalar]

    if optimise and prefs.codegen.loop_invariant_optimisations:
        scalar_statements, vector_statements = optimise_statements(scalar_statements,
                                                                   vector_statements,
                                                                   variables,
                                                                   blockname=blockname)

    return scalar_statements, vector_statements
def _generate_RHS(self, eqs, var, eq_symbols, temp_vars, expr,
                  non_stochastic_expr, stochastic_expr,
                  stochastic_variable=()):
    '''
    Helper function used in `__call__`. Generates the right hand side of an
    abstract code statement by appropriately replacing f, g and t. For
    example, given a differential equation ``dv/dt = (-v + I) / tau`` (i.e.
    `var` is ``v`` and `expr` is ``(-v + I) / tau``) together with the `rk2`
    step ``return x + dt*f(x + k/2, t + dt/2)`` (i.e. `non_stochastic_expr`
    is ``x + dt*f(x + k/2, t + dt/2)`` and `stochastic_expr` is ``None``),
    produces ``v + dt*(-v - _k_v/2 + I + _k_I/2)/tau``.

    The non-stochastic part is only generated if `non_stochastic_expr` is
    given. The stochastic part is only generated if both the equation (via
    ``expr.split_stochastic()``) and the state updater step have one. All
    generated terms are added up, starting from ``0``, and the sum is returned
    as a string (via ``sympy_to_str``). `eqs` is not used by this
    implementation.
    '''
    # Note: in the following we are silently ignoring the case that a
    # state updater does not care about either the non-stochastic or the
    # stochastic part of an equation. We do trust state updaters to
    # correctly specify their own abilities (i.e. they do not claim to
    # support stochastic equations but actually just ignore the stochastic
    # part). We can't really check the issue here, as we are only dealing
    # with one line of the state updater description. It is perfectly valid
    # to write the euler update as:
    #     non_stochastic = dt * f(x, t)
    #     stochastic = dt**.5 * g(x, t) * xi
    #     return x + non_stochastic + stochastic
    #
    # In the above case, we'll deal with lines which do not define either
    # the stochastic or the non-stochastic part.
    non_stochastic, stochastic = expr.split_stochastic()

    if non_stochastic_expr is not None:
        # We do have a non-stochastic part in the state updater description
        non_stochastic_results = self._non_stochastic_part(eq_symbols,
                                                           non_stochastic,
                                                           non_stochastic_expr,
                                                           stochastic_variable,
                                                           temp_vars, var)
    else:
        non_stochastic_results = []

    if not (stochastic is None or stochastic_expr is None):
        # We do have a stochastic part in the state
        # updater description
        stochastic_results = self._stochastic_part(eq_symbols,
                                                   stochastic,
                                                   stochastic_expr,
                                                   stochastic_variable,
                                                   temp_vars, var)
    else:
        stochastic_results = []

    RHS = sympy.Number(0)
    # All the parts (one non-stochastic and potentially more than one
    # stochastic part) are combined with addition
    for non_stochastic_result in non_stochastic_results:
        RHS += non_stochastic_result
    for stochastic_result in stochastic_results:
        RHS += stochastic_result

    return sympy_to_str(RHS)