예제 #1
0
def test_constants_sympy():
    '''
    Check that the symbolic constants ``inf``, ``pi`` and ``e`` survive a
    round trip through ``str_to_sympy`` and ``sympy_to_str`` with their
    mathematical meaning intact.
    '''
    cases = [('1.0/inf', '0'),
             ('sin(pi)', '0'),
             ('log(e)', '1')]
    for source_expr, expected in cases:
        assert sympy_to_str(str_to_sympy(source_expr)) == expected
    def get_substituted_expressions(self,
                                    variables=None,
                                    include_subexpressions=False):
        '''
        Return the equations with subexpressions substituted into them.

        Each differential equation (and, if requested, each subexpression)
        is returned as a ``(varname, expr)`` tuple in which references to
        subexpression variables have been replaced by their defining
        expressions. Substitutions accumulate while walking
        ``self.ordered``, so only subexpressions that appear earlier in that
        order are substituted into a given equation.

        Parameters
        ----------
        variables : dict, optional
            A mapping of variable names to `Variable`/`Function` objects.
        include_subexpressions : bool
            Whether the substituted subexpressions should be returned as
            well. Defaults to ``False``.

        Returns
        -------
        expr_tuples : list of (str, `CodeString`)
            ``(varname, expr)`` tuples where ``expr`` is a `CodeString`
            object with the subexpression variables substituted by their
            respective expressions.

        Raises
        ------
        AssertionError
            If an equation that has an expression is neither a
            subexpression nor a differential equation.

        Notes
        -----
        The full list is computed only once and stored in
        ``self._substituted_expressions``. Later calls reuse it, so the
        ``variables`` argument only has an effect on the first call.
        '''
        if self._substituted_expressions is None:
            # Install the (still empty) cache immediately and fill it in place
            self._substituted_expressions = []
            collected = self._substituted_expressions
            replacements = {}
            for equation in self.ordered:
                if equation.expr is None:
                    # Parameters have no defining expression
                    continue

                sympy_expr = str_to_sympy(equation.expr.code, variables)
                sympy_expr = sympy_expr.xreplace(replacements)
                expression = Expression(sympy_to_str(sympy_expr))

                if equation.type == SUBEXPRESSION:
                    # The symbol's assumptions (integer vs. real) have to
                    # match the ones in the parsed expressions, otherwise
                    # xreplace would not recognise it -- presumably the same
                    # assumptions str_to_sympy uses.
                    if equation.var_type == INTEGER:
                        symbol = sympy.Symbol(equation.varname, integer=True)
                    else:
                        symbol = sympy.Symbol(equation.varname, real=True)
                    replacements[symbol] = str_to_sympy(expression.code,
                                                        variables)
                elif equation.type != DIFFERENTIAL_EQUATION:
                    raise AssertionError('Unknown equation type %s' %
                                         equation.type)
                collected.append((equation.varname, expression))

        if include_subexpressions:
            return self._substituted_expressions
        return [(varname, expr)
                for varname, expr in self._substituted_expressions
                if self[varname].type == DIFFERENTIAL_EQUATION]
 def _latex(self, *args):
     '''
     Return a LaTeX string for this single equation.

     Differential equations are rendered with the time derivative of the
     variable on the left-hand side, subexpressions as ``name = rhs`` and
     parameters as the bare variable name. Any other equation type gives
     ``None``. Extra positional arguments are ignored.
     '''
     eq_type = self.type
     if eq_type == DIFFERENTIAL_EQUATION:
         name_latex = sympy.latex(self.varname)
         rhs_latex = sympy.latex(str_to_sympy(self.expr.code))
         return (r'\frac{\mathrm{d}' + name_latex +
                 r'}{\mathrm{d}t} = ' + rhs_latex)
     if eq_type == SUBEXPRESSION:
         name_latex = sympy.latex(self.varname)
         rhs_latex = sympy.latex(str_to_sympy(self.expr.code))
         return name_latex + ' = ' + rhs_latex
     if eq_type == PARAMETER:
         return sympy.latex(self.varname)
     return None
    def __init__(self, code=None, sympy_expression=None):
        '''
        Create an expression from a code string or from a sympy expression.

        Exactly one of the two arguments has to be provided.

        Parameters
        ----------
        code : str, optional
            The expression as a string. It is parsed once with
            ``str_to_sympy`` only to report syntax errors early; the parsed
            result is discarded.
        sympy_expression : sympy expression, optional
            A sympy expression, turned into a string with ``sympy_to_str``.

        Raises
        ------
        TypeError
            If neither or both of ``code`` and ``sympy_expression`` are given.
        '''
        has_code = code is not None
        has_sympy = sympy_expression is not None
        if not (has_code or has_sympy):
            raise TypeError('Have to provide either a string or a sympy expression')
        if has_code and has_sympy:
            raise TypeError('Provide a string expression or a sympy expression, not both')

        if has_code:
            # Validation only: raises for syntactically incorrect expressions
            str_to_sympy(code)
        else:
            code = sympy_to_str(sympy_expression)
        super(Expression, self).__init__(code=code)
 def _latex(self, *args):
     '''
     Return all equations as a LaTeX ``align*`` environment.

     Each row holds the left-hand side (the time derivative for
     differential equations, the bare variable name otherwise), the
     right-hand side for everything except parameters, the unit of the
     variable and, if present, its flags. Extra positional arguments are
     ignored.
     '''
     rows = []
     for eq in self._equations.values():
         # SingleEquation._latex is deliberately not reused: these rows
         # contain '&' alignment markers so that all equations line up.
         varname = sympy.Symbol(eq.varname)
         if eq.type == DIFFERENTIAL_EQUATION:
             lhs = r'\frac{\mathrm{d}' + sympy.latex(
                 varname) + r'}{\mathrm{d}t}'
         else:
             # Normal equation or parameter
             lhs = varname
         if eq.flags:
             flag_str = ', flags: ' + ', '.join(eq.flags)
         else:
             flag_str = ''
         if eq.type == PARAMETER:
             # Parameters have no right-hand side
             row = r'%s &&& \text{(unit: $%s$%s)}' % (
                 sympy.latex(lhs), sympy.latex(get_unit(eq.dim)), flag_str)
         else:
             rhs = str_to_sympy(eq.expr.code)
             row = r'%s &= %s && \text{(unit of $%s$: $%s$%s)}' % (
                 sympy.latex(lhs), sympy.latex(rhs), sympy.latex(varname),
                 sympy.latex(get_unit(eq.dim)), flag_str)
         rows.append(row)
     return r'\begin{align*}' + (r'\\' +
                                 '\n').join(rows) + r'\end{align*}'
 def _repr_pretty_(self, p, cycle):
     '''
     IPython pretty-printing hook.

     Parses ``self.code`` with ``str_to_sympy`` and hands the result to the
     IPython printer ``p``, so that sympy's own pretty printing is used.

     Raises
     ------
     AssertionError
         If IPython signals a reference cycle (``cycle`` is true).
     '''
     if not cycle:
         sympy_expr = str_to_sympy(self.code)
         p.pretty(sympy_expr)
         return
     raise AssertionError('Cyclical call of CodeString._repr_pretty')
def get_conditionally_linear_system(eqs, variables=None):
    '''
    Convert equations into a linear system using sympy.

    Parameters
    ----------
    eqs : `Equations`
        The model equations.
    variables : dict-like, optional
        The `Variable`/`Function` objects used when substituting
        subexpressions and when parsing the equations.

    Returns
    -------
    coefficients : dict of (sympy expression, sympy expression) tuples
        For every variable x, a tuple (M, B) containing the coefficients M and
        B (as sympy expressions) for M * x + B. Neither M nor B contains x.

    Raises
    ------
    ValueError
        If one of the equations cannot be converted into a M * x + B form
        with M and B independent of x (e.g. ``x**2 + x`` or ``x*sin(x)``).

    Examples
    --------
    >>> from angela2 import Equations
    >>> eqs = Equations("""
    ... dv/dt = (-v + w**2.0) / tau : 1
    ... dw/dt = -w / tau : 1
    ... """)
    >>> system = get_conditionally_linear_system(eqs)
    >>> print(system['v'])
    (-1/tau, w**2.0/tau)
    >>> print(system['w'])
    (-1/tau, 0)

    '''
    diff_eqs = eqs.get_substituted_expressions(variables)

    coefficients = {}
    for name, expr in diff_eqs:
        var = sp.Symbol(name, real=True)
        s_expr = str_to_sympy(expr.code, variables).expand()

        if not s_expr.has(var):
            # The right-hand side does not depend on the variable: M = 0
            coefficients[name] = (0, s_expr)
            continue

        # Factor out the variable. With evaluate=False, collect returns a
        # dict mapping every power of var (and 1 for the remaining terms)
        # to its coefficient, e.g. {var: M, 1: B} for M*var + B.
        collected = sp.collect(s_expr, var, evaluate=False)
        factor = collected.get(var)
        constant = collected.get(1, 0)
        # Keys other than var and 1 (e.g. var**2) are non-linear terms that
        # would otherwise be dropped silently; coefficients that still
        # contain var (e.g. var*sin(var) or var + exp(var)) are not linear
        # in var either.
        unexpected_keys = [key for key in collected
                           if key != var and key != 1]
        if (factor is None or unexpected_keys or
                sp.sympify(factor).has(var) or
                sp.sympify(constant).has(var)):
            raise ValueError(('The expression "%s", defining the variable %s, '
                              'could not be separated into linear components') %
                             (expr, name))
        coefficients[name] = (factor, constant)

    return coefficients
예제 #8
0
    def replace_func(self, x, t, expr, temp_vars, eq_symbols,
                     stochastic_variable=None):
        '''
        Replace a single occurrence of ``f(x, t)`` or ``g(x, t)``.

        `expr` is the non-stochastic (for ``f``) or stochastic (for ``g``)
        part of the right-hand side of a differential equation. Every state
        variable in `eq_symbols` is replaced by the expression `x` (with
        ``__x`` standing for that variable), and time is replaced by `t`.
        Intermediate variables in `x` are renamed to variable-specific
        names; when `stochastic_variable` is given, its name is appended as
        well.

        For example, in the `rk2` integrator, the second step involves the
        calculation of ``f(k/2 + x, dt/2 + t)``.  If the variable is ``v``
        and `expr` is ``-v / tau``, this will result in
        ``-(_k_v/2 + v)/tau``.

        Raises
        ------
        ValueError
            If `expr` cannot be parsed.
        '''
        try:
            result = str_to_sympy(str(expr))
        except SympifyError as ex:
            raise ValueError('Error parsing the expression "%s": %s' %
                             (expr, str(ex)))

        for var in eq_symbols:
            # Variable-specific names for the intermediate variables, e.g.
            # '_k_v' for the temporary 'k' and the state variable 'v'
            # (followed by '_' and the noise variable for stochastic parts)
            temp_var_replacements = {}
            for temp_var in temp_vars:
                temp_symbol = self.symbols[temp_var]
                new_name = temp_var + '_' + var
                if stochastic_variable is not None:
                    new_name = new_name + '_' + stochastic_variable
                temp_var_replacements[temp_symbol] = _symbol(new_name)

            # Specialise 'x' for this variable: '__x' becomes the variable
            # itself and the temporaries get their specific names
            x_replacement = x.subs(self.symbols['__x'],
                                   eq_symbols[var]).subs(temp_var_replacements)
            result = result.subs(eq_symbols[var], x_replacement)

        # Only if the description uses something other than plain "t" (i.e.
        # "__t") does time have to be replaced, mapping "__t" back to "t"
        time_placeholder = self.symbols['__t']
        if t != time_placeholder:
            result = result.subs(SYMBOLS['t'], t)
            result = result.replace(time_placeholder, SYMBOLS['t'])

        return result
예제 #9
0
def get_linear_system(eqs, variables):
    '''
    Convert equations into a linear system using sympy.

    Parameters
    ----------
    eqs : `Equations`
        The model equations.
    variables : dict-like
        The `Variable`/`Function` objects used when substituting
        subexpressions and when parsing the equations.

    Returns
    -------
    (diff_eq_names, coefficients, constants) : (list of str, `sympy.Matrix`, `sympy.Matrix`)
        A tuple containing the variable names (`diff_eq_names`) corresponding
        to the rows of the matrix `coefficients` and the vector `constants`,
        representing the system of equations in the form M * X + B

    Raises
    ------
    UnsupportedEquationsException
        If the equations cannot be converted into an M * X + B form.
    '''
    diff_eqs = eqs.get_substituted_expressions(variables)
    diff_eq_names = [name for name, _ in diff_eqs]

    symbols = [Symbol(name, real=True) for name in diff_eq_names]

    coefficients = sp.zeros(len(diff_eq_names))
    constants = sp.zeros(len(diff_eq_names), 1)

    for row_idx, (name, expr) in enumerate(diff_eqs):
        remainder = str_to_sympy(expr.code, variables).expand()
        # A coefficient must not depend on any of the state variables; this
        # wildcard only depends on the row, so build it once per row
        factor_wildcard = Wild('c_' + name, exclude=symbols)

        for col_idx, symbol in enumerate(symbols):
            remainder = remainder.collect(symbol)
            # Peel off the term linear in `symbol`; the rest may still
            # contain the symbols of later columns
            constant_wildcard = Wild('c', exclude=[symbol])
            one_pattern = factor_wildcard * symbol + constant_wildcard
            matches = remainder.match(one_pattern)
            if matches is None:
                raise UnsupportedEquationsException(('The expression "%s", '
                                                     'defining the variable '
                                                     '%s, could not be '
                                                     'separated into linear '
                                                     'components.') %
                                                    (expr, name))

            coefficients[row_idx, col_idx] = matches[factor_wildcard]
            remainder = matches[constant_wildcard]

        # Every symbol has been excluded from the remainder by now, so what
        # is left is the constant term B of this row
        constants[row_idx] = remainder

    return (diff_eq_names, coefficients, constants)
예제 #10
0
    def __init__(self, description, stochastic=None, custom_check=None):
        '''
        Parse a state updater description into statements and an output.

        Parameters
        ----------
        description : str
            The textual description of the integration scheme, parsed with
            ``ExplicitStateUpdater.DESCRIPTION``.
        stochastic : optional
            Stored as ``self.stochastic``.
        custom_check : callable, optional
            Stored as ``self.custom_check``.

        Raises
        ------
        SyntaxError
            If ``description`` cannot be parsed; the error carries the line
            text, column and line number reported by the parser.
        AssertionError
            If the parser returns an element that is neither a statement nor
            an output.
        '''
        self._description = description
        self.stochastic = stochastic
        self.custom_check = custom_check

        try:
            parsed = ExplicitStateUpdater.DESCRIPTION.parseString(description,
                                                                  parseAll=True)
        except ParseException as p_exc:
            ex = SyntaxError('Parsing failed: ' + str(p_exc.msg))
            ex.text = str(p_exc.line)
            ex.offset = p_exc.column
            ex.lineno = p_exc.lineno
            raise ex

        self.statements = []
        self.symbols = SYMBOLS.copy()
        for element in parsed:
            expression = str_to_sympy(element.expression)
            # Rename every name used in the description to a reserved
            # double-underscore name so it cannot clash with user-defined
            # variables or functions ('dt' is the only name kept as is)
            expression = expression.subs(sympy.Function('f'),
                                         self.symbols['__f'])
            expression = expression.subs(sympy.Function('g'),
                                         self.symbols['__g'])
            original_symbols = list(expression.atoms(sympy.Symbol))
            renamed_symbols = [symbol if symbol.name == 'dt'
                               else _symbol('__' + symbol.name)
                               for symbol in original_symbols]
            for old_symbol, new_symbol in zip(original_symbols,
                                              renamed_symbols):
                expression = expression.subs(old_symbol, new_symbol)

            self.symbols.update({symbol.name: symbol
                                 for symbol in renamed_symbols})

            element_name = element.getName()
            if element_name == 'statement':
                self.statements.append(('__' + element.identifier, expression))
            elif element_name == 'output':
                self.output = expression
            else:
                raise AssertionError('Unknown element name: %s' %
                                     element_name)
def check_subexpressions(group, equations, run_namespace):
    '''
    Check the subexpressions in the equations for stateful functions.

    Every equation of type ``SUBEXPRESSION`` in ``equations`` is resolved in
    the context of ``group``, and an error is raised if its expression
    refers to a stateful function (e.g. ``rand()``). This function itself
    does not look at the equation flags; subexpressions carrying the
    "constant over dt" flag are presumably expected not to reach this check
    as subexpressions -- TODO confirm with the callers.

    Parameters
    ----------
    group : `Group`
        The group providing the context.
    equations : `Equations`
        The equations to check.
    run_namespace : dict
        The run namespace for resolving variables.

    Raises
    ------
    SyntaxError
        For subexpressions that refer to stateful functions.
    '''
    for eq in equations.ordered:
        if eq.type != SUBEXPRESSION:
            continue
        # Check whether the expression is stateful (most commonly by
        # referring to rand() or randn())
        variables = group.resolve_all(
            eq.identifiers,
            run_namespace,
            # No warnings for the user here: warnings will be raised
            # in create_runner_codeobj
            user_identifiers=set())
        expression = str_to_sympy(eq.expr.code, variables=variables)

        if is_stateful(expression, variables):
            # Note the space before "flag" -- it was missing previously,
            # producing "...'constant over dt'flag."
            raise SyntaxError(
                "The subexpression '{}' refers to a stateful "
                "function (e.g. rand()). Such expressions "
                "should only be evaluated once per timestep, "
                "add the 'constant over dt' "
                "flag.".format(eq.varname))
예제 #12
0
    def __call__(self, equations, variables=None, method_options=None):
        '''
        Integrate the equations by solving each of them analytically.

        Every differential equation is solved on its own with
        ``sp.dsolve``; other variables occurring in its right-hand side are
        left as plain symbols. The updater is deprecated and logs a warning
        on every call (once, via the logger's ``once`` option).

        Parameters
        ----------
        equations : `Equations`
            The model equations; stochastic equations are rejected.
        variables : dict-like, optional
            The `Variable` objects for the model, used when parsing the
            equations. ``None`` is replaced by an empty dict.
        method_options : dict, optional
            Passed through ``extract_method_options`` with an empty set of
            defaults.

        Returns
        -------
        str
            One ``name = expression`` statement per differential equation,
            giving the value of the variable one time step ``dt`` later.

        Raises
        ------
        UnsupportedEquationsException
            If the equations are stochastic or no explicit solution with a
            single integration constant can be found.
        '''
        logger.warn("The 'independent' state updater is deprecated and might be "
                    "removed in future versions of angela.",
                    'deprecated_independent', once=True)
        method_options = extract_method_options(method_options, {})
        if equations.is_stochastic:
            raise UnsupportedEquationsException('Cannot solve stochastic '
                                                'equations with this state '
                                                'updater')
        variables = {} if variables is None else variables

        diff_eqs = equations.get_substituted_expressions(variables)

        t = Symbol('t', real=True, positive=True)
        dt = Symbol('dt', real=True, positive=True)
        t0 = Symbol('t0', real=True, positive=True)
        c1 = Symbol('C1')
        c2 = Symbol('C2')

        lines = []
        for name, expression in diff_eqs:
            rhs = str_to_sympy(expression.code, variables)

            # The real=True assumption is required, otherwise sympy does not
            # consider this symbol identical to the one in the equation
            var = Symbol(name, real=True)
            func = sp.Function(name)
            rhs = rhs.subs(var, func(t))
            diff_eq = sp.Eq(sp.Derivative(func(t), t), rhs)
            # TODO: simplify=True sometimes fails with 0.7.4, see:
            # https://github.com/sympy/sympy/issues/2666
            try:
                general_solution = sp.dsolve(diff_eq, func(t), simplify=True)
            except RuntimeError:
                general_solution = sp.dsolve(diff_eq, func(t), simplify=False)

            # Only an explicit solution of the form func(t) = ... is usable
            if getattr(general_solution, 'lhs', None) != func(t):
                raise UnsupportedEquationsException('Cannot explicitly solve: '
                                                    + str(diff_eq))

            if not general_solution.has(c1):
                solution = general_solution.rhs.subs(t, t0).subs(func(t0), var)
            else:
                # Determine C1 from "var" being the value at time "t0"
                if general_solution.has(c2):
                    raise UnsupportedEquationsException('Too many constants in solution: %s' % str(general_solution))
                constant_solution = sp.solve(general_solution, c1)
                if len(constant_solution) != 1:
                    raise UnsupportedEquationsException(("Couldn't solve for the constant "
                                                         "C1 in : %s ") % str(general_solution))
                constant = constant_solution[0].subs(t, t0).subs(func(t0), var)
                solution = general_solution.rhs.subs(c1, constant)

            # Advance by one time step: evaluate at t + dt, starting from t
            solution = solution.subs(t, t + dt).subs(t0, t)
            # Simplification is optional -- it sometimes raises an error
            try:
                solution = solution.simplify()
            except ValueError:
                pass

            lines.append(name + ' = ' + sympy_to_str(solution))

        return '\n'.join(lines)
예제 #13
0
    def __call__(self, eqs, variables=None, method_options=None):
        '''
        Apply a state updater description to model equations.
        
        Parameters
        ----------
        eqs : `Equations`
            The equations describing the model
        variables: dict-like, optional
            The `Variable` objects for the model. They are passed to the
            custom check (if any) and to
            `Equations.get_substituted_expressions`, and ``dt`` is looked up
            in them when equations flagged as having multiplicative noise
            are checked. For that check, ``None`` counts as an empty
            mapping.
        method_options : dict, optional
            Additional options to the state updater (not used at the moment
            for the explicit state updaters).

        Returns
        -------
        str
            The abstract code for one integration step, one statement per
            line.

        Raises
        ------
        UnsupportedEquationsException
            If the equations are stochastic but this updater has no
            stochastic support, or if they really contain multiplicative
            noise that this updater cannot integrate.

        Examples
        --------
        >>> from angela2 import *
        >>> eqs = Equations('dv/dt = -v / tau : volt')
        >>> print(euler(eqs))
        _v = -dt*v/tau + v
        v = _v
        >>> print(rk4(eqs))
        __k_1_v = -dt*v/tau
        __k_2_v = -dt*(__k_1_v/2 + v)/tau
        __k_3_v = -dt*(__k_2_v/2 + v)/tau
        __k_4_v = -dt*(__k_3_v + v)/tau
        _v = __k_1_v/6 + __k_2_v/3 + __k_3_v/3 + __k_4_v/6 + v
        v = _v
        '''
        # The result is unused; the call presumably validates the options
        method_options = extract_method_options(method_options, {})
        # Non-stochastic numerical integrators should work for all equations,
        # except for stochastic equations
        if eqs.is_stochastic and self.stochastic is None:
            raise UnsupportedEquationsException('Cannot integrate '
                                                'stochastic equations with '
                                                'this state updater.')
        if self.custom_check:
            self.custom_check(eqs, variables)
        # The final list of statements
        statements = []

        stochastic_variables = eqs.stochastic_variables

        # The variables for the intermediate results in the state updater
        # description, e.g. the variable k in rk2
        intermediate_vars = [var for var, _ in self.statements]

        # A dictionary mapping all the variables in the equations to their
        # sympy representations
        eq_variables = dict(((var, _symbol(var)) for var in eqs.eq_names))

        # Generate the random numbers for the stochastic variables
        for stochastic_variable in stochastic_variables:
            statements.append(stochastic_variable + ' = ' + 'dt**.5 * randn()')

        substituted_expressions = eqs.get_substituted_expressions(variables)

        # Process the intermediate statements in the stateupdater description
        for intermediate_var, intermediate_expr in self.statements:

            # Split the expression into a non-stochastic and a stochastic part
            non_stochastic_expr, stochastic_expr = split_expression(intermediate_expr)

            # Execute the statement by appropriately replacing the functions f
            # and g and the variable x for every equation in the model.
            # We use the model equations where the subexpressions have
            # already been substituted into the model equations.
            for var, expr in substituted_expressions:
                for xi in stochastic_variables:
                    RHS = self._generate_RHS(eqs, var, eq_variables, intermediate_vars,
                                             expr, non_stochastic_expr,
                                             stochastic_expr, xi)
                    statements.append(intermediate_var+'_'+var+'_'+xi+' = '+RHS)
                if not stochastic_variables:   # no stochastic variables
                    RHS = self._generate_RHS(eqs, var, eq_variables, intermediate_vars,
                                             expr, non_stochastic_expr,
                                             stochastic_expr)
                    statements.append(intermediate_var+'_'+var+' = '+RHS)

        # Process the "return" line of the stateupdater description
        non_stochastic_expr, stochastic_expr = split_expression(self.output)

        if eqs.is_stochastic and (self.stochastic != 'multiplicative' and
                                  eqs.stochastic_type == 'multiplicative'):
            # The equations are marked as having multiplicative noise and the
            # current state updater does not support such equations. However,
            # it is possible that the equations do not use multiplicative noise
            # at all. They could depend on time via a function that is constant
            # over a single time step (most likely, a TimedArray). In that case
            # we can integrate the equations.
            # `variables` defaults to None: use an empty mapping instead of
            # failing with a TypeError on the membership test below.
            known_variables = {} if variables is None else variables
            dt_value = (known_variables['dt'].get_value()[0]
                        if 'dt' in known_variables else None)
            for _, expr in substituted_expressions:
                _, stoch = expr.split_stochastic()
                if stoch is None:
                    continue
                # There could be more than one stochastic variable (e.g. xi_1, xi_2)
                for stoch_expr in stoch.values():
                    sympy_expr = str_to_sympy(stoch_expr.code)
                    # The equation really has multiplicative noise, if it depends
                    # on time (and not only via a function that is constant
                    # over dt), or if it depends on another variable defined
                    # via differential equations.
                    if (not is_constant_over_dt(sympy_expr, known_variables,
                                                dt_value)
                            or len(stoch_expr.identifiers & eqs.diff_eq_names)):
                        raise UnsupportedEquationsException('Cannot integrate '
                                                            'equations with '
                                                            'multiplicative noise with '
                                                            'this state updater.')

        # Assign a value to all the model variables described by differential
        # equations
        for var, expr in substituted_expressions:
            RHS = self._generate_RHS(eqs, var, eq_variables, intermediate_vars,
                                     expr, non_stochastic_expr, stochastic_expr,
                                     stochastic_variables)
            statements.append('_' + var + ' = ' + RHS)

        # Assign everything to the final variables
        for var, _ in substituted_expressions:
            statements.append(var + ' = ' + '_' + var)

        return '\n'.join(statements)
예제 #14
0
def make_statements(code, variables, dtype, optimise=True, blockname=''):
    '''
    make_statements(code, variables, dtype, optimise=True, blockname='')

    Turn a series of abstract code statements into Statement objects, inferring
    whether each line is a set/declare operation, whether the variables are
    constant or not, and handling the cacheing of subexpressions.

    Parameters
    ----------
    code : str
        A (multi-line) string of statements.
    variables : dict-like
        A dictionary of with `Variable` and `Function` objects for every
        identifier used in the `code`.
    dtype : `dtype`
        The data type to use for temporary variables
    optimise : bool, optional
        Whether to optimise expressions, including
        pulling out loop invariant expressions and putting them in new
        scalar constants. Defaults to ``False``, since this function is also
        used just to in contexts where we are not interested by this kind of
        optimisation. For the main code generation stage, its value is set by
        the `codegen.loop_invariant_optimisations` preference.
    blockname : str, optional
        A name for the block (used to name intermediate variables to avoid
        name clashes when multiple blocks are used together)
    Returns
    -------
    scalar_statements, vector_statements : (list of `Statement`, list of `Statement`)
        Lists with statements that are to be executed once and statements that
        are to be executed once for every neuron/synapse/... (or in a vectorised
        way)

    Notes
    -----
    If ``optimise`` is ``True``, then the
    ``scalar_statements`` may include newly introduced scalar constants that
    have been identified as loop-invariant and have therefore been pulled out
    of the vector statements. The resulting statements will also use augmented
    assignments where possible, i.e. a statement such as ``w = w + 1`` will be
    replaced by ``w += 1``. Also, statements involving booleans will have
    additional information added to them (see `Statement` for details)
    describing how the statement can be reformulated as a sequence of if/then
    statements. Calls `~angela2.codegen.optimisation.optimise_statements`.
    '''
    code = strip_empty_lines(deindent(code))
    lines = re.split(r'[;\n]', code)
    lines = [LineInfo(code=line) for line in lines if len(line)]
    # Do a copy so we can add stuff without altering the original dict
    variables = dict(variables)
    # we will do inference to work out which lines are := and which are =
    defined = set(k for k, v in variables.items()
                  if not isinstance(v, AuxiliaryVariable))
    for line in lines:
        statement = None
        # parse statement into "var op expr"
        var, op, expr, comment = parse_statement(line.code)
        if var in variables and isinstance(variables[var], Subexpression):
            raise SyntaxError("Illegal line '{line}' in abstract code. "
                              "Cannot write to subexpression "
                              "'{var}'.".format(line=line.code, var=var))
        if op == '=':
            if var not in defined:
                op = ':='
                defined.add(var)
                if var not in variables:
                    annotated_ast = angela_ast(expr, variables)
                    is_scalar = annotated_ast.scalar
                    if annotated_ast.dtype == 'boolean':
                        use_dtype = bool
                    elif annotated_ast.dtype == 'integer':
                        use_dtype = int
                    else:
                        use_dtype = dtype
                    new_var = AuxiliaryVariable(var,
                                                dtype=use_dtype,
                                                scalar=is_scalar)
                    variables[var] = new_var
            elif not variables[var].is_boolean:
                sympy_expr = str_to_sympy(expr, variables)
                if variables[var].is_integer:
                    sympy_var = sympy.Symbol(var, integer=True)
                else:
                    sympy_var = sympy.Symbol(var, real=True)
                try:
                    collected = sympy.collect(sympy_expr,
                                              sympy_var,
                                              exact=True,
                                              evaluate=False)
                except AttributeError:
                    # If something goes wrong during collection, e.g. collect
                    # does not work for logical expressions
                    collected = {1: sympy_expr}

                if (len(collected) == 2
                        and set(collected.keys()) == {1, sympy_var}
                        and collected[sympy_var] == 1):
                    # We can replace this statement by a += assignment
                    statement = Statement(var,
                                          '+=',
                                          sympy_to_str(collected[1]),
                                          comment,
                                          dtype=variables[var].dtype,
                                          scalar=variables[var].scalar)
                elif len(collected) == 1 and sympy_var in collected:
                    # We can replace this statement by a *= assignment
                    statement = Statement(var,
                                          '*=',
                                          sympy_to_str(collected[sympy_var]),
                                          comment,
                                          dtype=variables[var].dtype,
                                          scalar=variables[var].scalar)
        if statement is None:
            statement = Statement(var,
                                  op,
                                  expr,
                                  comment,
                                  dtype=variables[var].dtype,
                                  scalar=variables[var].scalar)

        line.statement = statement
        # for each line will give the variable being written to
        line.write = var
        # each line will give a set of variables which are read
        line.read = get_identifiers_recursively([expr], variables)

    # All writes to scalar variables must happen before writes to vector
    # variables
    scalar_write_done = False
    for line in lines:
        stmt = line.statement
        if stmt.op != ':=' and variables[
                stmt.var].scalar and scalar_write_done:
            raise SyntaxError(
                ('All writes to scalar variables in a code block '
                 'have to be made before writes to vector '
                 'variables. Illegal write to %s.') % line.write)
        elif not variables[stmt.var].scalar:
            scalar_write_done = True

    # all variables which are written to at some point in the code block
    # used to determine whether they should be const or not
    all_write = set(line.write for line in lines)

    # backwards compute whether or not variables will be read again
    # note that will_read for a line gives the set of variables it will read
    # on the current line or subsequent ones. will_write gives the set of
    # variables that will be written after the current line
    will_read = set()
    will_write = set()
    for line in lines[::-1]:
        will_read = will_read.union(line.read)
        line.will_read = will_read.copy()
        line.will_write = will_write.copy()
        will_write.add(line.write)

    subexpressions = dict((name, val) for name, val in variables.items()
                          if isinstance(val, Subexpression))
    # Check that no scalar subexpression refers to a vectorised function
    # (e.g. rand()) -- otherwise it would be differently interpreted depending
    # on whether it is used in a scalar or a vector context (i.e., even though
    # the subexpression is supposed to be scalar, it would be vectorised when
    # used as part of non-scalar expressions)
    for name, subexpr in subexpressions.items():
        if subexpr.scalar:
            identifiers = get_identifiers(subexpr.expr)
            for identifier in identifiers:
                if (identifier in variables and getattr(
                        variables[identifier], 'auto_vectorise', False)):
                    raise SyntaxError(('The scalar subexpression {} refers to '
                                       'the implicitly vectorised function {} '
                                       '-- this is not allowed since it leads '
                                       'to different interpretations of this '
                                       'subexpression depending on whether it '
                                       'is used in a scalar or vector '
                                       'context.').format(name, identifier))

    # sort subexpressions into an order so that subexpressions that don't depend
    # on other subexpressions are first
    subexpr_deps = dict(
        (name, [dep for dep in subexpr.identifiers if dep in subexpressions])
        for name, subexpr in subexpressions.items())
    sorted_subexpr_vars = topsort(subexpr_deps)

    statements = []

    # none are yet defined (or declared)
    subdefined = dict((name, None) for name in subexpressions)
    for line in lines:
        stmt = line.statement
        read = line.read
        write = line.write
        will_read = line.will_read
        will_write = line.will_write
        # update/define all subexpressions needed by this statement
        for var in sorted_subexpr_vars:
            if var not in read:
                continue

            subexpression = subexpressions[var]
            # if already defined/declared
            if subdefined[var] == 'constant':
                continue
            elif subdefined[var] == 'variable':
                op = '='
                constant = False
            else:
                op = ':='
                # check if the referred variables ever change
                ids = subexpression.identifiers
                constant = all(v not in will_write for v in ids)
                subdefined[var] = 'constant' if constant else 'variable'

            statement = Statement(var,
                                  op,
                                  subexpression.expr,
                                  comment='',
                                  dtype=variables[var].dtype,
                                  constant=constant,
                                  subexpression=True,
                                  scalar=variables[var].scalar)
            statements.append(statement)

        var, op, expr, comment = stmt.var, stmt.op, stmt.expr, stmt.comment

        # constant only if we are declaring a new variable and we will not
        # write to it again
        constant = op == ':=' and var not in will_write
        statement = Statement(var,
                              op,
                              expr,
                              comment,
                              dtype=variables[var].dtype,
                              constant=constant,
                              scalar=variables[var].scalar)
        statements.append(statement)

    scalar_statements = [s for s in statements if s.scalar]
    vector_statements = [s for s in statements if not s.scalar]

    if optimise and prefs.codegen.loop_invariant_optimisations:
        scalar_statements, vector_statements = optimise_statements(
            scalar_statements,
            vector_statements,
            variables,
            blockname=blockname)

    return scalar_statements, vector_statements
def test_automatic_augmented_assignments():
    '''
    Check that statements which can be expressed as augmented assignments are
    rewritten by `make_statements` (into ``+=`` or ``*=``), and that all other
    statements are left as plain assignments.

    The right-hand sides are compared via sympy, so the check does not depend
    on the order of terms or other purely syntactic differences.
    '''
    variables = {
        'x': ArrayVariable('x', owner=None, size=10, device=device),
        'y': ArrayVariable('y', owner=None, size=10, device=device),
        # The variable's name has to match its dictionary key
        'z': ArrayVariable('z', owner=None, size=10, device=device),
        # Use the builtin ``bool``: the ``np.bool`` alias was deprecated in
        # NumPy 1.20 and removed in NumPy 1.24
        'b': ArrayVariable('b',
                           owner=None,
                           size=10,
                           dtype=bool,
                           device=device),
        'clip': DEFAULT_FUNCTIONS['clip'],
        'inf': DEFAULT_CONSTANTS['inf']
    }
    statements = [
        # examples that should be rewritten
        # Note that using our approach, we will never get -= or /= but always
        # the equivalent += or *= statements
        ('x = x + 1.0', 'x += 1.0'),
        ('x = 2.0 * x', 'x *= 2.0'),
        ('x = x - 3.0', 'x += -3.0'),
        ('x = x/2.0', 'x *= 0.5'),
        ('x = y + (x + 1.0)', 'x += y + 1.0'),
        ('x = x + x', 'x *= 2.0'),
        ('x = x + y + z', 'x += y + z'),
        # examples that should not be rewritten
        ('x = 1.0/x', 'x = 1.0/x'),
        ('x = 1.0', 'x = 1.0'),
        ('x = 2.0*(x + 1.0)', 'x = 2.0*(x + 1.0)'),
        ('x = clip(x + y, 0.0, inf)', 'x = clip(x + y, 0.0, inf)'),
        ('b = b or False', 'b = b or False')
    ]
    for orig, rewritten in statements:
        scalar, vector = make_statements(orig, variables, np.float32)
        try:  # we augment the assertion error with the original statement
            assert len(
                scalar
            ) == 0, 'Did not expect any scalar statements but got ' + str(
                scalar)
            assert len(
                vector
            ) == 1, 'Did expect a single statement but got ' + str(vector)
            statement = vector[0]
            expected_var, expected_op, expected_expr, _ = parse_statement(
                rewritten)
            assert expected_var == statement.var, 'expected write to variable %s, not to %s' % (
                expected_var, statement.var)
            assert expected_op == statement.op, 'expected operation %s, not %s' % (
                expected_op, statement.op)
            # Compare the two expressions using sympy to allow for different order etc.
            sympy_expected = str_to_sympy(expected_expr)
            sympy_actual = str_to_sympy(statement.expr)
            assert sympy_expected == sympy_actual, (
                'RHS expressions "%s" and "%s" are not identical' %
                (sympy_to_str(sympy_expected), sympy_to_str(sympy_actual)))
        except AssertionError as ex:
            # Chain the original error so its traceback is not lost
            raise AssertionError(
                'Transformation for statement "%s" gave an unexpected result: %s'
                % (orig, str(ex))) from ex
    def __init__(self, morphology=None, model=None, threshold=None,
                 refractory=False, reset=None, events=None,
                 threshold_location=None,
                 dt=None, clock=None, order=0, Cm=0.9 * uF / cm ** 2, Ri=150 * ohm * cm,
                 name='spatialneuron*', dtype=None, namespace=None,
                 method=('exact', 'exponential_euler', 'rk2', 'heun'),
                 method_options=None):
        '''
        Create a multi-compartment neuron from a morphology and model equations.

        Parameters
        ----------
        morphology : `Morphology`
            The morphology of the neuron. A deep copy is stored as
            ``self.morphology``. A flattened version is stored as
            ``self.flat_morphology``. The number of compartments is taken from
            ``morphology.total_compartments``.
        model : str or `Equations`
            The model equations. They have to define the transmembrane
            current ``Im``, without any flags. Variables carrying the
            ``point current`` flag must have units of amp. They are added to
            ``Im`` after division by the compartment ``area``.
        threshold : str, optional
            The threshold condition.
        threshold_location : optional
            Restricts the threshold condition to a single compartment. This is
            either a compartment index, or an object with an ``_indices()``
            method that resolves to exactly one index. It requires
            ``threshold`` to be given.
        Cm : `Quantity`, optional
            Membrane capacitance per area (farad/meter**2). Defaults to
            0.9 uF/cm**2.
        Ri : `Quantity`, optional
            Resistivity (ohm*meter) used for the cable equations. Defaults to
            150 ohm*cm.
        refractory, reset, events, dt, clock, order, name, dtype, namespace, method, method_options
            Passed on to `NeuronGroup.__init__`. ``method`` and ``order`` are
            also passed to the `SpatialStateUpdater`.

        Raises
        ------
        TypeError
            Raised in any of these cases:

            - ``model`` is neither a string nor an `Equations` object.
            - ``Im`` is missing or has flags.
            - ``Im`` cannot be differentiated with respect to ``v``.
            - ``threshold_location`` is given without a ``threshold``.
        AttributeError
            If ``threshold_location`` resolves to more than one compartment.
        '''

        # #### Prepare and validate equations
        if isinstance(model, str):
            model = Equations(model)
        if not isinstance(model, Equations):
            raise TypeError(('model has to be a string or an Equations '
                             'object, is "%s" instead.') % type(model))

        # Insert the threshold mechanism at the specified location
        if threshold_location is not None:
            if hasattr(threshold_location,
                       '_indices'):  # assuming this is a method
                threshold_location = threshold_location._indices()
                # for now, only a single compartment allowed
                if len(threshold_location) == 1:
                    threshold_location = threshold_location[0]
                else:
                    raise AttributeError(('Threshold can only be applied on a '
                                          'single location'))
            # Without this check, the string concatenation below would fail
            # with an unhelpful str/NoneType TypeError
            if threshold is None:
                raise TypeError('A threshold_location was specified, but no '
                                'threshold condition was given.')
            threshold = '(' + threshold + ') and (i == ' + str(threshold_location) + ')'

        # Check flags (we have point currents)
        model.check_flags({DIFFERENTIAL_EQUATION: ('point current',),
                           PARAMETER: ('constant', 'shared', 'linked', 'point current'),
                           SUBEXPRESSION: ('shared', 'point current',
                                           'constant over dt')})
        #: The original equations as specified by the user (i.e. before
        #: inserting point-currents into the membrane equation, before adding
        #: all the internally used variables and constants, etc.).
        self.user_equations = model

        # Separate subexpressions depending whether they are considered to be
        # constant over a time step or not (this would also be done by the
        # NeuronGroup initializer later, but this would give incorrect results
        # for the linearity check)
        model, constant_over_dt = extract_constant_subexpressions(model)

        # Extract membrane equation
        if 'Im' in model:
            if len(model['Im'].flags):
                raise TypeError('Cannot specify any flags for the transmembrane '
                                'current Im.')
            membrane_expr = model['Im'].expr  # the membrane equation
        else:
            raise TypeError('The transmembrane current Im must be defined')

        model_equations = []
        # Insert point currents in the membrane equation
        for eq in model.values():
            if eq.varname == 'Im':
                continue  # ignore -- handled separately
            if 'point current' in eq.flags:
                fail_for_dimension_mismatch(eq.dim, amp,
                                            "Point current " + eq.varname + " should be in amp")
                # A point current (amp) is turned into a current density by
                # dividing by the compartment area
                membrane_expr = Expression(
                    str(membrane_expr.code) + '+' + eq.varname + '/area')
                # Re-create the equation without the 'point current' flag
                eq = SingleEquation(eq.type, eq.varname, eq.dim, expr=eq.expr,
                                    flags=list(set(eq.flags)-{'point current'}))
            model_equations.append(eq)

        model_equations.append(SingleEquation(SUBEXPRESSION, 'Im',
                                              dimensions=(amp/meter**2).dim,
                                              expr=membrane_expr))
        model_equations.append(SingleEquation(PARAMETER, 'v', volt.dim))
        model = Equations(model_equations)

        ###### Process model equations (Im) to extract total conductance and the remaining current
        # Expand expressions in the membrane equation
        for var, expr in model.get_substituted_expressions(include_subexpressions=True):
            if var == 'Im':
                Im_expr = expr
                break
        else:
            raise AssertionError('Model equations did not contain Im!')

        # Differentiate Im with respect to v
        Im_sympy_exp = str_to_sympy(Im_expr.code)
        v_sympy = sp.Symbol('v', real=True)
        diffed = sp.diff(Im_sympy_exp, v_sympy)

        # Any derivative that sympy could not evaluate means Im depends on v
        # through something it cannot differentiate
        unevaled_derivatives = diffed.atoms(sp.Derivative)
        if len(unevaled_derivatives):
            raise TypeError('Cannot take the derivative of "{Im}" with respect '
                            'to v.'.format(Im=Im_expr.code))

        # Write Im = I0 - gtot*v with gtot = -dIm/dv and I0 = Im - (dIm/dv)*v.
        # The decomposition is exact when Im is linear in v.
        gtot_str = sympy_to_str(sp.simplify(-diffed))
        I0_str = sympy_to_str(sp.simplify(Im_sympy_exp - diffed*v_sympy))

        # A plain '0' has no units; attach them so the equations below are
        # dimensionally consistent
        if gtot_str == '0':
            gtot_str += '*siemens/meter**2'
        if I0_str == '0':
            I0_str += '*amp/meter**2'
        gtot_str = "gtot__private=" + gtot_str + ": siemens/meter**2"
        I0_str = "I0__private=" + I0_str + ": amp/meter**2"

        model += Equations(gtot_str + "\n" + I0_str)

        # Insert morphology (store a copy)
        self.morphology = copy.deepcopy(morphology)

        # Flatten the morphology
        self.flat_morphology = FlatMorphology(morphology)

        # Equations for morphology
        # TODO: check whether Cm and Ri are already in the equations
        #       no: should be shared instead of constant
        #       yes: should be constant (check)
        eqs_constants = Equations("""
        length : meter (constant)
        distance : meter (constant)
        area : meter**2 (constant)
        volume : meter**3
        Ic : amp/meter**2
        diameter : meter (constant)
        Cm : farad/meter**2 (constant)
        Ri : ohm*meter (constant, shared)
        r_length_1 : meter (constant)
        r_length_2 : meter (constant)
        time_constant = Cm/gtot__private : second
        space_constant = (2/pi)**(1.0/3.0) * (area/(1/r_length_1 + 1/r_length_2))**(1.0/6.0) /
                         (2*(Ri*gtot__private)**(1.0/2.0)) : meter
        """)
        if self.flat_morphology.has_coordinates:
            eqs_constants += Equations('''
            x : meter (constant)
            y : meter (constant)
            z : meter (constant)
            ''')

        NeuronGroup.__init__(self, morphology.total_compartments,
                             model=model + eqs_constants,
                             method_options=method_options,
                             threshold=threshold, refractory=refractory,
                             reset=reset, events=events,
                             method=method, dt=dt, clock=clock, order=order,
                             namespace=namespace, dtype=dtype, name=name)
        # Parameters and intermediate variables for solving the cable equations
        # Note that some of these variables could have meaningful physical
        # units (e.g. _v_star is in volt, _I0_all is in amp/meter**2 etc.) but
        # since these variables should never be used in user code, we don't
        # assign them any units
        self.variables.add_arrays(['_ab_star0', '_ab_star1', '_ab_star2',
                                   '_b_plus', '_b_minus',
                                   '_v_star', '_u_plus', '_u_minus',
                                   '_v_previous', '_c',
                                   # The following two are only necessary for
                                   # C code where we cannot deal with scalars
                                   # and arrays interchangeably:
                                   '_I0_all', '_gtot_all'],
                                  size=self.N, read_only=True)

        self.Cm = Cm
        self.Ri = Ri
        # These explict assignments will load the morphology values from disk
        # in standalone mode
        self.distance_ = self.flat_morphology.distance
        self.length_ = self.flat_morphology.length
        self.area_ = self.flat_morphology.area
        self.diameter_ = self.flat_morphology.diameter
        self.r_length_1_ = self.flat_morphology.r_length_1
        self.r_length_2_ = self.flat_morphology.r_length_2
        if self.flat_morphology.has_coordinates:
            self.x_ = self.flat_morphology.x
            self.y_ = self.flat_morphology.y
            self.z_ = self.flat_morphology.z

        # Performs numerical integration step
        self.add_attribute('diffusion_state_updater')
        self.diffusion_state_updater = SpatialStateUpdater(self, method,
                                                           clock=self.clock,
                                                           order=order)

        # Update v after the gating variables to obtain consistent Ic and Im
        self.diffusion_state_updater.order = 1

        # Creation of contained_objects that do the work
        self.contained_objects.extend([self.diffusion_state_updater])

        # Subexpressions flagged 'constant over dt' get their own updater
        if len(constant_over_dt):
            self.subexpression_updater = SubexpressionUpdater(self,
                                                              constant_over_dt)
            self.contained_objects.append(self.subexpression_updater)
    def split_stochastic(self):
        '''
        Split the expression into a stochastic and non-stochastic part.

        Splits the expression into a tuple of one `Expression` object f (the
        non-stochastic part) and a dictionary mapping stochastic variables
        to `Expression` objects. For example, an expression of the form
        ``f + g * xi_1 + h * xi_2`` would be returned as:
        ``(f, {'xi_1': g, 'xi_2': h})``
        Note that the `Expression` objects for the stochastic parts do not
        include the stochastic variable itself.

        Returns
        -------
        (f, d) : (`Expression`, dict)
            A tuple of an `Expression` object and a dictionary. The first
            element is the non-stochastic part of the equation, or ``0.0`` if
            there is none. The dictionary maps stochastic variables (``xi``
            or names starting with ``xi_``) to `Expression` objects. If no
            stochastic variable is present in the code string, the tuple
            ``(self, None)`` is returned with the unchanged `Expression`
            object.

        Raises
        ------
        AssertionError
            If the part determined to be non-stochastic still contains a
            stochastic symbol.
        ValueError
            If collecting the expanded expression yields a factor that is
            neither ``1`` nor a single stochastic symbol (e.g. ``xi**2``).
        '''
        stochastic_variables = [identifier for identifier in self.identifiers
                                if identifier == 'xi'
                                or identifier.startswith('xi_')]

        # No stochastic variable
        if not stochastic_variables:
            return (self, None)

        stochastic_symbols = [sympy.Symbol(variable, real=True)
                              for variable in stochastic_variables]

        # Note that collect only works properly if the expression is expanded
        collected = str_to_sympy(self.code).expand().collect(stochastic_symbols,
                                                             evaluate=False)

        f_expr = None
        stochastic_expressions = {}
        for var, s_expr in collected.items():
            expr = Expression(sympy_expression=s_expr)
            if var == 1:
                # Terms without a stochastic factor must not hide one inside
                # (e.g. in a nested sub-expression that collect did not split)
                if any(s_expr.has(s) for s in stochastic_symbols):
                    raise AssertionError(('Error when separating expression '
                                          '"%s" into stochastic and non-'
                                          'stochastic term: non-stochastic '
                                          'part was determined to be "%s" but '
                                          'contains a stochastic symbol' % (self.code,
                                                                            s_expr)))
                f_expr = expr
            elif var in stochastic_symbols:
                stochastic_expressions[str(var)] = expr
            else:
                raise ValueError(('Expression "%s" cannot be separated into '
                                  'stochastic and non-stochastic '
                                  'term') % self.code)

        if f_expr is None:
            # Purely stochastic expression: the deterministic part is zero
            f_expr = Expression('0.0')

        return f_expr, stochastic_expressions