コード例 #1
0
    def __call__(self, equations, variables=None):
        '''
        Generate abstract code integrating a conditionally linear system.

        For every state variable ``x`` the system provides a pair ``(A, B)``
        and the update over one time step ``dt`` is computed from the
        closed-form solution of ``dx/dt = A*x + B`` with ``A`` and ``B``
        held fixed during the step.

        Parameters
        ----------
        equations : `Equations`
            The equations to integrate.
        variables : dict, optional
            Not used by this implementation.

        Returns
        -------
        code : str
            Abstract code that first stores all new values in temporaries
            ``_x`` and then copies them back into the state variables.
        '''
        system = get_conditionally_linear_system(equations)

        # The time step symbol is the same for every variable
        s_dt = sp.Symbol('dt')
        code = []
        # ``items()`` works on Python 2 and 3 (``iteritems`` is Python 2 only)
        for var, (A, B) in system.items():
            s_var = sp.Symbol(var)
            if A == 0:
                # No linear term: plain forward step
                update_expression = s_var + s_dt * B
            elif B != 0:
                BA = B / A
                # Avoid calculating B/A twice
                BA_name = '_BA_' + var
                s_BA = sp.Symbol(BA_name)
                code.append(BA_name + ' = ' + sympy_to_str(BA))
                update_expression = (s_var + s_BA) * sp.exp(A * s_dt) - s_BA
            else:
                update_expression = s_var * sp.exp(A * s_dt)

            # The actual update step goes into a temporary so that the
            # updates of the other variables still see the old value
            code.append('_{var} = {expr}'.format(
                var=var, expr=sympy_to_str(update_expression)))

        # Replace all the variables with their updated value
        code.extend('{var} = _{var}'.format(var=var) for var in system)

        return '\n'.join(code)
コード例 #2
0
 def __call__(self, equations, variables=None):
     '''
     Generate abstract code integrating a conditionally linear system.

     Each state variable ``x`` comes with a pair ``(A, B)``; the code
     advances it by one time step ``dt`` using the closed-form solution of
     ``dx/dt = A*x + B`` with ``A`` and ``B`` held fixed during the step.

     Parameters
     ----------
     equations : `Equations`
         The equations to integrate.
     variables : dict, optional
         Not used by this implementation.

     Returns
     -------
     code : str
         Abstract code writing temporaries ``_x`` first and copying them back
         into the state variables at the end.
     '''
     system = get_conditionally_linear_system(equations)

     s_dt = sp.Symbol('dt')  # same symbol for every variable
     code = []
     # ``items()`` works on Python 2 and 3 (``iteritems`` is Python 2 only)
     for var, (A, B) in system.items():
         s_var = sp.Symbol(var)
         if A == 0:
             update_expression = s_var + s_dt * B
         elif B != 0:
             BA = B / A
             # Avoid calculating B/A twice
             BA_name = '_BA_' + var
             s_BA = sp.Symbol(BA_name)
             code.append(BA_name + ' = ' + sympy_to_str(BA))
             update_expression = (s_var + s_BA)*sp.exp(A*s_dt) - s_BA
         else:
             update_expression = s_var*sp.exp(A*s_dt)

         # The actual update step (into a temporary, other updates may still
         # need the old value)
         code.append('_{var} = {expr}'.format(
             var=var, expr=sympy_to_str(update_expression)))

     # Replace all the variables with their updated value
     code.extend('{var} = _{var}'.format(var=var) for var in system)

     return '\n'.join(code)
コード例 #3
0
def test_constants_sympy():
    '''
    Check that the symbolic constants ``inf``, ``pi`` and ``e`` are
    interpreted correctly when round-tripping through sympy.
    '''
    cases = [('1.0/inf', '0'),
             ('sin(pi)', '0'),
             ('log(e)', '1')]
    for expression, expected in cases:
        assert sympy_to_str(str_to_sympy(expression)) == expected
コード例 #4
0
ファイル: test_functions.py プロジェクト: flinz/brian2
def test_constants_sympy():
    '''
    Symbolic constants must survive a sympy round trip and simplify to the
    expected literal.
    '''
    def roundtrip(code):
        return sympy_to_str(str_to_sympy(code))

    assert roundtrip('1.0/inf') == '0'
    assert roundtrip('sin(pi)') == '0'
    assert roundtrip('log(e)') == '1'
コード例 #5
0
ファイル: codestrings.py プロジェクト: msGenDev/brian2
    def split_stochastic(self):
        '''
        Split the expression into a stochastic and a non-stochastic part.

        An expression of the form ``f + g * xi_1 + h * xi_2`` is returned as
        ``(f, {'xi_1': g, 'xi_2': h})``. The first element is an `Expression`
        for the non-stochastic part. The second is a dictionary mapping each
        stochastic variable to the `Expression` that multiplies it. The
        stochastic variable itself is not part of that `Expression`.

        Returns
        -------
        (f, d) : (`Expression`, dict)
            The non-stochastic `Expression` and the dictionary described
            above. Stochastic variables are ``xi`` or names starting with
            ``xi_``. If there are none, ``(self, None)`` is returned.

        Raises
        ------
        ValueError
            If the expanded expression does not match the pattern
            ``f + w_1 * xi_1 + w_2 * xi_2 + ...``.
        '''
        xi_names = [name for name in self.identifiers
                    if name == 'xi' or name.startswith('xi_')]
        if not xi_names:
            return (self, None)

        expanded = self.sympy_expr.expand()

        xi_symbols = [sympy.Symbol(name, real=True) for name in xi_names]
        # None of the wildcards may match a stochastic symbol, so every
        # noise term has to end up in its own coefficient
        deterministic = sympy.Wild('f', exclude=xi_symbols)
        coefficients = [sympy.Wild('w_' + name, exclude=xi_symbols)
                        for name in xi_names]
        pattern = deterministic
        for xi_symbol, coefficient in zip(xi_symbols, coefficients):
            pattern += coefficient * xi_symbol
        found = expanded.match(pattern)

        if found is None:
            raise ValueError(
                ('Expression "%s" cannot be separated into stochastic '
                 'and non-stochastic term') % self.code)

        f_expr = Expression(sympy_to_str(found[deterministic]))
        stochastic_expressions = {}
        for name, coefficient in zip(xi_names, coefficients):
            stochastic_expressions[name] = Expression(
                sympy_to_str(found[coefficient]))

        return (f_expr, stochastic_expressions)
コード例 #6
0
ファイル: exact.py プロジェクト: yger/brian2
    def __call__(self, equations, variables=None):
        '''
        Generate abstract code that integrates a linear ODE system exactly.

        The equations are brought into the form ``dX/dt = M*X + B``. The
        update over one time step ``dt`` is then built symbolically from the
        matrix exponential ``exp(M*dt)``.

        Parameters
        ----------
        equations : `Equations`
            The equations to integrate.
        variables : dict, optional
            Mapping of names to variable objects. It is used to check that
            ``M`` is constant and to insert the values of scalar constant
            variables directly into the generated code.

        Returns
        -------
        code : str
            Abstract code that first assigns ``_x`` temporaries for all state
            variables and then copies them back.

        Raises
        ------
        ValueError
            If the coefficient matrix contains non-constant symbols.
        '''
        if variables is None:
            variables = {}

        # Get a representation of the ODE system in the form of
        # dX/dt = M*X + B
        varnames, matrix, constants = get_linear_system(equations)

        # Make sure that the matrix M is constant, i.e. it only contains
        # external variables or constant variables
        matrix_atoms = set.union(*(el.atoms() for el in matrix))
        non_constant = _non_constant_symbols(matrix_atoms, variables)
        if len(non_constant):
            raise ValueError(('The coefficient matrix for the equations '
                              'contains the symbols %s, which are not '
                              'constant.') % str(non_constant))

        # Solve M*b = B (the augmented matrix [M | B]) for b
        symbols = [Symbol(variable, real=True) for variable in varnames]
        solution = sp.solve_linear_system(matrix.row_join(constants), *symbols)
        b = sp.ImmutableMatrix([solution[symbol] for symbol in symbols]).transpose()

        # Solve the system
        dt = Symbol('dt', real=True, positive=True)
        A = (matrix * dt).exp()
        C = sp.ImmutableMatrix([A.dot(b)]) - b
        _S = sp.MatrixSymbol('_S', len(varnames), 1)
        updates = A * _S + C.transpose()
        try:
            # In sympy 0.7.3, we have to explicitly convert it to a single matrix
            # In sympy 0.7.2, it is already a matrix (which doesn't have an
            # as_explicit method)
            updates = updates.as_explicit()
        except AttributeError:
            pass

        # The solution contains _S[0, 0], _S[1, 0] etc. for the state variables,
        # replace them with the state variable names
        abstract_code = []
        for variable, update in zip(varnames, updates):
            rhs = update
            for row_idx, varname in enumerate(varnames):
                rhs = rhs.subs(_S[row_idx, 0], varname)
            # Insert the values of scalar constants directly
            identifiers = get_identifiers(sympy_to_str(rhs))
            for identifier in identifiers:
                if identifier in variables:
                    var = variables[identifier]
                    if var.scalar and var.constant:
                        float_val = var.get_value()
                        rhs = rhs.xreplace({Symbol(identifier, real=True): Float(float_val)})

            # Do not overwrite the real state variables yet, the update step
            # of other state variables might still need the original values
            abstract_code.append('_' + variable + ' = ' + sympy_to_str(rhs))

        # Update the state variables
        for variable in varnames:
            abstract_code.append('{variable} = _{variable}'.format(variable=variable))
        return '\n'.join(abstract_code)
コード例 #7
0
def test_automatic_augmented_assignments():
    '''
    Check that statements which could be written as augmented assignments
    are rewritten that way, and that all other statements are left alone.

    The right-hand sides are compared with sympy, so the order of terms does
    not matter.
    '''
    variables = {
        'x': ArrayVariable('x', owner=None, size=10, device=device),
        'y': ArrayVariable('y', owner=None, size=10, device=device),
        # The variable's name has to match its key (previously 'y' by mistake)
        'z': ArrayVariable('z', owner=None, size=10, device=device),
        'b': ArrayVariable('b', owner=None, size=10, dtype=bool,
                           device=device),
        'clip': DEFAULT_FUNCTIONS['clip'],
        'inf': DEFAULT_CONSTANTS['inf']
    }
    statements = [
        # examples that should be rewritten
        # Note that using our approach, we will never get -= or /= but always
        # the equivalent += or *= statements
        ('x = x + 1.0', 'x += 1.0'),
        ('x = 2.0 * x', 'x *= 2.0'),
        ('x = x - 3.0', 'x += -3.0'),
        ('x = x/2.0', 'x *= 0.5'),
        ('x = y + (x + 1.0)', 'x += y + 1.0'),
        ('x = x + x', 'x *= 2.0'),
        ('x = x + y + z', 'x += y + z'),
        # examples that should not be rewritten
        ('x = 1.0/x', 'x = 1.0/x'),
        ('x = 1.0', 'x = 1.0'),
        ('x = 2.0*(x + 1.0)', 'x = 2.0*(x + 1.0)'),
        ('x = clip(x + y, 0.0, inf)', 'x = clip(x + y, 0.0, inf)'),
        ('b = b or False', 'b = b or False')
    ]
    for orig, rewritten in statements:
        scalar, vector = make_statements(orig, variables, np.float32)
        try:  # we augment the assertion error with the original statement
            assert len(scalar) == 0, (
                'Did not expect any scalar statements but got ' + str(scalar))
            assert len(vector) == 1, (
                'Did expect a single statement but got ' + str(vector))
            statement = vector[0]
            expected_var, expected_op, expected_expr, _ = parse_statement(
                rewritten)
            assert expected_var == statement.var, (
                'expected write to variable %s, not to %s' % (expected_var,
                                                              statement.var))
            assert expected_op == statement.op, (
                'expected operation %s, not %s' % (expected_op, statement.op))
            # Compare the two expressions using sympy to allow for different order etc.
            sympy_expected = str_to_sympy(expected_expr)
            sympy_actual = str_to_sympy(statement.expr)
            assert sympy_expected == sympy_actual, (
                'RHS expressions "%s" and "%s" are not identical' %
                (sympy_to_str(sympy_expected), sympy_to_str(sympy_actual)))
        except AssertionError as ex:
            raise AssertionError(
                'Transformation for statement "%s" gave an unexpected result: %s'
                % (orig, str(ex)))
コード例 #8
0
ファイル: codestrings.py プロジェクト: Kwartke/brian2
    def split_stochastic(self):
        '''
        Separate the stochastic terms of the expression from the rest.

        For example, ``f + g * xi_1 + h * xi_2`` becomes
        ``(f, {'xi_1': g, 'xi_2': h})``. The dictionary values are the
        `Expression` factors of each stochastic variable and do not contain
        the variable itself.

        Returns
        -------
        (f, d) : (`Expression`, dict)
            The non-stochastic part and a mapping from each stochastic
            variable (``xi`` or ``xi_*``) to its factor. The result is
            ``(self, None)`` when no stochastic variable occurs.

        Raises
        ------
        ValueError
            If the expression cannot be matched as ``f + sum(w_i * xi_i)``.
        '''
        noise_vars = []
        for name in self.identifiers:
            if name == 'xi' or name.startswith('xi_'):
                noise_vars.append(name)

        if not noise_vars:
            return (self, None)

        expr = self.sympy_expr.expand()

        noise_syms = [sympy.Symbol(name, real=True) for name in noise_vars]

        rest = sympy.Wild('f', exclude=noise_syms)  # non-stochastic part
        weights = [sympy.Wild('w_' + name, exclude=noise_syms)
                   for name in noise_vars]
        template = rest
        for sym, weight in zip(noise_syms, weights):
            template += weight * sym
        result = expr.match(template)

        if result is None:
            raise ValueError(('Expression "%s" cannot be separated into stochastic '
                              'and non-stochastic term') % self.code)

        f_expr = Expression(sympy_to_str(result[rest]))
        stochastic_expressions = {name: Expression(sympy_to_str(result[weight]))
                                  for name, weight in zip(noise_vars, weights)}

        return (f_expr, stochastic_expressions)
コード例 #9
0
    def __str__(self):
        '''
        Describe the object: the class name, the intermediate statements
        (one ``var = expr`` per line, if any) and the output expression.
        '''
        parts = ['%s\n' % self.__class__.__name__]

        if len(self.statements) > 0:
            parts.append('Intermediate statements:\n')
            parts.append('\n'.join(var + ' = ' + sympy_to_str(expr)
                                   for var, expr in self.statements))
            parts.append('\n')

        parts.append('Output:\n')
        parts.append(sympy_to_str(self.output))
        return ''.join(parts)
コード例 #10
0
ファイル: explicit.py プロジェクト: boddmg/brian2
 def __str__(self):
     '''
     Return a multi-line description: the class name, any intermediate
     ``var = expr`` statements, and the output expression.
     '''
     lines = [self.__class__.__name__]
     if len(self.statements) > 0:
         lines.append('Intermediate statements:')
         lines.extend(var + ' = ' + sympy_to_str(expr)
                      for var, expr in self.statements)
     lines.append('Output:')
     lines.append(sympy_to_str(self.output))
     return '\n'.join(lines)
コード例 #11
0
ファイル: equations.py プロジェクト: yger/brian2
    def _get_substituted_expressions(self):
        '''
        Return the differential equations with all static equation variables
        replaced by their defining expressions.

        Equations are processed in ``self.ordered``. Each static equation is
        first substituted itself and then recorded as a substitution for all
        following equations.

        Returns
        -------
        expr_tuples : list of (str, `CodeString`)
            ``(varname, expr)`` pairs for the differential equations only,
            with every static equation variable substituted.

        Raises
        ------
        AssertionError
            If an equation has an unknown type.
        '''
        static_values = {}
        differential = []
        for eq in self.ordered:
            if eq.expr is None:
                # Parameters have no expression to substitute
                continue

            substituted = Expression(
                sympy_to_str(eq.expr.sympy_expr.subs(static_values)))

            if eq.type == STATIC_EQUATION:
                static_values[sympy.Symbol(eq.varname, real=True)] = substituted.sympy_expr
            elif eq.type == DIFFERENTIAL_EQUATION:
                differential.append((eq.varname, substituted))
            else:
                raise AssertionError('Unknown equation type %s' % eq.type)

        return differential
コード例 #12
0
ファイル: exact.py プロジェクト: ttxtea/brian2
    def __call__(self, equations, variables=None):
        '''
        Generate abstract code that integrates a linear ODE system exactly.

        The system is brought into the form ``dX/dt = M*X + B`` and advanced
        by one step ``dt`` with the matrix exponential ``exp(M*dt)``.

        Parameters
        ----------
        equations : `Equations`
            The equations to integrate.
        variables : dict, optional
            Mapping of names to variable objects, used to check that ``M`` is
            constant. If it contains ``'dt'``, every entry of ``M`` and ``B``
            is also checked for a dependence on ``t``.

        Returns
        -------
        code : str
            Abstract code assigning ``_x`` temporaries for all state
            variables and then copying them back.

        Raises
        ------
        ValueError
            If the coefficient matrix contains non-constant symbols.
        '''
        if variables is None:
            variables = {}

        # Get a representation of the ODE system in the form of
        # dX/dt = M*X + B
        varnames, matrix, constants = get_linear_system(equations)

        # Make sure that the matrix M is constant, i.e. it only contains
        # external variables or constant variables
        t = Symbol('t', real=True, positive=True)
        matrix_atoms = set.union(*(el.atoms() for el in matrix))
        non_constant = _non_constant_symbols(matrix_atoms, variables, t)
        if len(non_constant):
            raise ValueError(('The coefficient matrix for the equations '
                              'contains the symbols %s, which are not '
                              'constant.') % str(non_constant))

        # Check for time dependence
        dt_var = variables.get('dt', None)
        if dt_var is not None:
            dt_value = dt_var.get_value()  # loop-invariant, fetch once
            # This will raise an error if we meet the symbol "t" anywhere
            # except as an argument of a locally constant function
            for entry in itertools.chain(matrix, constants):
                _check_for_locally_constant(entry, variables, dt_value, t)

        # Solve M*b = B (the augmented matrix [M | B]) for b
        symbols = [Symbol(variable, real=True) for variable in varnames]
        solution = sp.solve_linear_system(matrix.row_join(constants), *symbols)
        b = sp.ImmutableMatrix([solution[symbol] for symbol in symbols]).transpose()

        # Solve the system
        dt = Symbol('dt', real=True, positive=True)
        A = (matrix * dt).exp()
        C = sp.ImmutableMatrix([A.dot(b)]) - b
        _S = sp.MatrixSymbol('_S', len(varnames), 1)
        updates = A * _S + C.transpose()
        try:
            # In sympy 0.7.3, we have to explicitly convert it to a single matrix
            # In sympy 0.7.2, it is already a matrix (which doesn't have an
            # as_explicit method)
            updates = updates.as_explicit()
        except AttributeError:
            pass

        # The solution contains _S[0, 0], _S[1, 0] etc. for the state variables,
        # replace them with the state variable names
        abstract_code = []
        for variable, update in zip(varnames, updates):
            rhs = update
            for row_idx, varname in enumerate(varnames):
                rhs = rhs.subs(_S[row_idx, 0], varname)

            # Do not overwrite the real state variables yet, the update step
            # of other state variables might still need the original values
            abstract_code.append('_' + variable + ' = ' + sympy_to_str(rhs))

        # Update the state variables
        for variable in varnames:
            abstract_code.append('{variable} = _{variable}'.format(variable=variable))
        return '\n'.join(abstract_code)
コード例 #13
0
ファイル: equations.py プロジェクト: msGenDev/brian2
    def _get_substituted_expressions(self):
        '''
        Return the differential equations with every subexpression variable
        replaced by its (already substituted) defining expression.

        Returns
        -------
        expr_tuples : list of (str, `CodeString`)
            ``(varname, expr)`` pairs for the differential equations, in the
            order of ``self.ordered``.

        Raises
        ------
        AssertionError
            If an equation has an unknown type.
        '''
        known = {}
        result = []
        for eq in self.ordered:
            if eq.expr is None:  # parameter, nothing to substitute
                continue

            expr = Expression(sympy_to_str(eq.expr.sympy_expr.subs(known)))

            if eq.type == SUBEXPRESSION:
                # Later equations see this subexpression already expanded
                known[sympy.Symbol(eq.varname, real=True)] = expr.sympy_expr
            elif eq.type == DIFFERENTIAL_EQUATION:
                result.append((eq.varname, expr))
            else:
                raise AssertionError('Unknown equation type %s' % eq.type)

        return result
コード例 #14
0
ファイル: codestrings.py プロジェクト: yzerlaut/brian2
    def __init__(self, code=None, sympy_expression=None):
        '''
        Create an expression from exactly one of a code string or a sympy
        expression. A sympy expression is converted to a string first.

        Raises
        ------
        TypeError
            If neither or both of ``code`` and ``sympy_expression`` are given.
        '''
        has_code = code is not None
        has_sympy = sympy_expression is not None
        if not (has_code or has_sympy):
            raise TypeError('Have to provide either a string or a sympy expression')
        if has_code and has_sympy:
            raise TypeError('Provide a string expression or a sympy expression, not both')

        super(Expression, self).__init__(
            code=code if has_code else sympy_to_str(sympy_expression))
コード例 #15
0
ファイル: equations.py プロジェクト: rohithvarma3000/brian2
    def get_substituted_expressions(self,
                                    variables=None,
                                    include_subexpressions=False):
        '''
        Return a list of ``(varname, expr)`` tuples, containing all
        differential equations (and optionally subexpressions) with all the
        subexpression variables substituted with the respective expressions.

        Parameters
        ----------
        variables : dict, optional
            A mapping of variable names to `Variable`/`Function` objects.
        include_subexpressions : bool
            Whether also to return substituted subexpressions. Defaults to
            ``False``.

        Returns
        -------
        expr_tuples : list of (str, `CodeString`)
            A list of ``(varname, expr)`` tuples, where ``expr`` is a
            `CodeString` object with all subexpression variables substituted
            with the respective expression.

        Raises
        ------
        AssertionError
            If an equation has an unknown type.

        Notes
        -----
        The result is cached in ``self._substituted_expressions`` after the
        first successful call. Later calls reuse it, so ``variables`` only has
        an effect on the first call (NOTE(review): confirm this is intended).
        '''
        if self._substituted_expressions is None:
            # Build into a local list and store it only once it is complete:
            # if an exception escaped half-way, a partially filled cache would
            # otherwise be returned silently by every later call.
            substituted = []
            substitutions = {}
            for eq in self.ordered:
                # Skip parameters
                if eq.expr is None:
                    continue

                new_sympy_expr = str_to_sympy(
                    eq.expr.code, variables).xreplace(substitutions)
                new_str_expr = sympy_to_str(new_sympy_expr)
                expr = Expression(new_str_expr)

                if eq.type == SUBEXPRESSION:
                    # The symbol must carry the same assumptions as the one
                    # created when parsing, otherwise xreplace does not match
                    if eq.var_type == INTEGER:
                        sympy_var = sympy.Symbol(eq.varname, integer=True)
                    else:
                        sympy_var = sympy.Symbol(eq.varname, real=True)
                    substitutions[sympy_var] = str_to_sympy(expr.code, variables)
                    substituted.append((eq.varname, expr))
                elif eq.type == DIFFERENTIAL_EQUATION:
                    #  a differential equation that we have to check
                    substituted.append((eq.varname, expr))
                else:
                    raise AssertionError('Unknown equation type %s' % eq.type)
            self._substituted_expressions = substituted

        if include_subexpressions:
            return self._substituted_expressions
        else:
            return [(name, expr)
                    for name, expr in self._substituted_expressions
                    if self[name].type == DIFFERENTIAL_EQUATION]
コード例 #16
0
    def __call__(self, equations, variables=None, method_options=None):
        '''
        Generate abstract code integrating a conditionally linear system.

        Every state variable ``x`` has a pair ``(A, B)``. The update over one
        step ``dt`` uses the closed-form solution of ``dx/dt = A*x + B`` with
        ``A`` and ``B`` held fixed during the step.

        Parameters
        ----------
        equations : `Equations`
            The equations to integrate.
        variables : dict, optional
            Passed on to `get_conditionally_linear_system`.
        method_options : dict, optional
            Passed to `extract_method_options` with an empty set of defaults.

        Returns
        -------
        code : str
            Abstract code writing temporaries ``_x`` first and copying them back
            into the state variables at the end.

        Raises
        ------
        UnsupportedEquationsException
            If the equations are stochastic or not conditionally linear.
        '''
        # The result is not used below; presumably the call validates the
        # options (NOTE(review): confirm)
        method_options = extract_method_options(method_options, {})
        if equations.is_stochastic:
            raise UnsupportedEquationsException('Cannot solve stochastic '
                                                'equations with this state '
                                                'updater.')

        # Try whether the equations are conditionally linear
        try:
            system = get_conditionally_linear_system(equations, variables)
        except ValueError:
            raise UnsupportedEquationsException('Can only solve conditionally '
                                                'linear systems with this '
                                                'state updater.')

        s_dt = sp.Symbol('dt')  # same symbol for every variable
        code = []
        # ``items()`` works on Python 2 and 3 (``iteritems`` is Python 2 only)
        for var, (A, B) in system.items():
            s_var = sp.Symbol(var)
            if A == 0:
                update_expression = s_var + s_dt * B
            elif B != 0:
                BA = B / A
                # Avoid calculating B/A twice
                BA_name = '_BA_' + var
                s_BA = sp.Symbol(BA_name)
                code.append(BA_name + ' = ' + sympy_to_str(BA))
                update_expression = (s_var + s_BA) * sp.exp(A * s_dt) - s_BA
            else:
                update_expression = s_var * sp.exp(A * s_dt)

            # The actual update step, into a temporary so that other updates
            # still see the old value
            code.append('_{var} = {expr}'.format(
                var=var, expr=sympy_to_str(update_expression)))

        # Replace all the variables with their updated value
        code.extend('{var} = _{var}'.format(var=var) for var in system)

        return '\n'.join(code)
コード例 #17
0
 def __str__(self):
     '''
     Render the class name, the optional intermediate statements and the
     output expression as a multi-line string.
     '''
     header = f'{self.__class__.__name__}\n'
     body = ''
     if len(self.statements) > 0:
         rendered = [f"{var} = {sympy_to_str(expr)}"
                     for var, expr in self.statements]
         body = 'Intermediate statements:\n' + '\n'.join(rendered) + '\n'
     return header + body + 'Output:\n' + sympy_to_str(self.output)
コード例 #18
0
    def __call__(self, equations, variables=None):
        '''
        Generate abstract code integrating a conditionally linear system.

        Each state variable ``x`` comes with a pair ``(A, B)``. It is advanced
        by one step ``dt`` using the closed-form solution of
        ``dx/dt = A*x + B`` with ``A`` and ``B`` held fixed during the step.

        Parameters
        ----------
        equations : `Equations`
            The equations to integrate.
        variables : dict, optional
            Passed on to `get_conditionally_linear_system`.

        Returns
        -------
        code : str
            Abstract code assigning ``_x`` temporaries first and copying them
            back into the state variables at the end.

        Raises
        ------
        UnsupportedEquationsException
            If the equations are stochastic or not conditionally linear.
        '''
        if equations.is_stochastic:
            raise UnsupportedEquationsException('Cannot solve stochastic '
                                                'equations with this state '
                                                'updater.')

        # Try whether the equations are conditionally linear
        try:
            system = get_conditionally_linear_system(equations, variables)
        except ValueError:
            raise UnsupportedEquationsException('Can only solve conditionally '
                                                'linear systems with this '
                                                'state updater.')

        s_dt = sp.Symbol('dt')  # same symbol for every variable
        code = []
        # ``items()`` works on Python 2 and 3 (``iteritems`` is Python 2 only)
        for var, (A, B) in system.items():
            s_var = sp.Symbol(var)
            if A == 0:
                update_expression = s_var + s_dt * B
            elif B != 0:
                BA = B / A
                # Avoid calculating B/A twice
                BA_name = '_BA_' + var
                s_BA = sp.Symbol(BA_name)
                code.append(BA_name + ' = ' + sympy_to_str(BA))
                update_expression = (s_var + s_BA)*sp.exp(A*s_dt) - s_BA
            else:
                update_expression = s_var*sp.exp(A*s_dt)

            # The actual update step, into a temporary so that other updates
            # still see the old value
            code.append('_{var} = {expr}'.format(
                var=var, expr=sympy_to_str(update_expression)))

        # Replace all the variables with their updated value
        code.extend('{var} = _{var}'.format(var=var) for var in system)

        return '\n'.join(code)
コード例 #19
0
ファイル: test_parsing.py プロジェクト: achilleas-k/brian2
def test_sympytools():
    '''
    Round-tripping through `str_to_sympy` and `sympy_to_str` should give back
    the original expression, ignoring whitespace.

    Note that this test is quite fragile since sympy might rearrange the
    order of symbols.
    '''
    expressions = [
        'randn()',                 # argumentless function
        'x + sin(2.0*pi*freq*t)',  # expression with a constant
        'c * userfun(t + x)',      # non-sympy function
    ]

    for expr in expressions:
        expr2 = sympy_to_str(str_to_sympy(expr))
        original_compact = expr.replace(' ', '')
        roundtrip_compact = expr2.replace(' ', '')
        assert original_compact == roundtrip_compact, '%s != %s' % (expr, expr2)
コード例 #20
0
    def __init__(self, code=None, sympy_expression=None):
        '''
        Create an expression from exactly one of a code string or a sympy
        expression.

        A sympy expression is converted to a string. A code string is parsed
        once with `str_to_sympy` so that syntax errors are raised immediately;
        the parsed result is discarded.

        Raises
        ------
        TypeError
            If neither or both of ``code`` and ``sympy_expression`` are given.
        '''
        if code is None:
            if sympy_expression is None:
                raise TypeError('Have to provide either a string or a sympy expression')
            code = sympy_to_str(sympy_expression)
        elif sympy_expression is not None:
            raise TypeError('Provide a string expression or a sympy expression, not both')
        else:
            str_to_sympy(code)  # validation only
        super(Expression, self).__init__(code=code)
コード例 #21
0
ファイル: test_parsing.py プロジェクト: squiba/brian2
def test_sympytools():
    '''
    ``sympy_to_str(str_to_sympy(x))`` should equal ``x`` up to whitespace.

    This is fragile by nature: sympy is free to reorder symbols.
    '''
    def squeeze(text):
        return text.replace(' ', '')

    for expr in ('randn()',                 # argumentless function
                 'x + sin(2.0*pi*freq*t)',  # expression with a constant
                 'c * userfun(t + x)'):     # non-sympy function
        expr2 = sympy_to_str(str_to_sympy(expr))
        assert squeeze(expr) == squeeze(expr2), '%s != %s' % (expr, expr2)
コード例 #22
0
def get_sensitivity_init(group, parameters, param_init):
    """
    Calculate the initial values for the sensitivity parameters (necessary if
    initial values are functions of parameters).

    Parameters
    ----------
    group : `NeuronGroup`
        The group of neurons that will be simulated.
    parameters : list of str
        Names of the parameters that are fit.
    param_init : dict
        The dictionary with expressions to initialize the model variables.
        Only string values are considered; other values are skipped.

    Returns
    -------
    sensitivity_init : dict
        Dictionary of expressions to initialize the sensitivity
        parameters, keyed ``S_{var_name}_{parameter}``. Only parameters with
        a non-zero derivative get an entry.

    Raises
    ------
    NotImplementedError
        If an initialization refers to a subexpression. Also raised if a
        non-zero sensitivity belongs to a parameter that is a subexpression
        in ``group``.
    """
    sensitivity_dict = {}
    for var_name, expr in param_init.items():
        if not isinstance(expr, str):
            continue
        # Validate all identifiers first. The derivatives below depend only on
        # the expression, so they are computed once per initialization
        # instead of once per identifier.
        for identifier in get_identifiers(expr):
            if (identifier in group.variables
                    and getattr(group.variables[identifier], 'type',
                                None) == SUBEXPRESSION):
                raise NotImplementedError('Initializations that refer to a '
                                          'subexpression are currently not '
                                          'supported')
        sympy_expr = str_to_sympy(expr)
        for parameter in parameters:
            diffed = sympy_expr.diff(str_to_sympy(parameter))
            if diffed == sympy.S.Zero:
                continue
            if getattr(group.variables[parameter], 'type',
                       None) == SUBEXPRESSION:
                raise NotImplementedError(
                    'Sensitivity '
                    f'S_{var_name}_{parameter} '
                    'is initialized to a non-zero '
                    'value, but it has been '
                    'removed from the equations. '
                    'Set optimize=False to avoid '
                    'this.')
            sensitivity_dict[f'S_{var_name}_{parameter}'] = sympy_to_str(diffed)
    return sensitivity_dict
コード例 #23
0
ファイル: codestrings.py プロジェクト: ZeitgeberH/brian2
    def __init__(self, code=None, sympy_expression=None):
        '''
        Create an expression from exactly one of a code string or a sympy
        expression. The missing representation is derived from the given one,
        so both ``code`` and ``sympy_expr`` end up set.

        Raises
        ------
        TypeError
            If neither or both of ``code`` and ``sympy_expression`` are given.
        '''
        missing_code = code is None
        missing_sympy = sympy_expression is None
        if missing_code and missing_sympy:
            raise TypeError('Have to provide either a string or a sympy expression')
        if not missing_code and not missing_sympy:
            raise TypeError('Provide a string expression or a sympy expression, not both')

        # Exactly one of the two is given at this point
        if missing_code:
            code = sympy_to_str(sympy_expression)
        else:
            sympy_expression = str_to_sympy(code)

        super(Expression, self).__init__(code=code)

        # : The expression as a sympy object
        self.sympy_expr = sympy_expression
コード例 #24
0
ファイル: equations.py プロジェクト: brian-team/brian2
    def get_substituted_expressions(self, variables=None,
                                    include_subexpressions=False):
        '''
        Return a list of ``(varname, expr)`` tuples, containing all
        differential equations (and optionally subexpressions) with all the
        subexpression variables substituted with the respective expressions.

        Parameters
        ----------
        variables : dict, optional
            A mapping of variable names to `Variable`/`Function` objects.
        include_subexpressions : bool
            Whether also to return substituted subexpressions. Defaults to
            ``False``.

        Returns
        -------
        expr_tuples : list of (str, `CodeString`)
            A list of ``(varname, expr)`` tuples, where ``expr`` is a
            `CodeString` object with all subexpression variables substituted
            with the respective expression.

        Raises
        ------
        AssertionError
            If an equation has an unknown type.

        Notes
        -----
        The result is cached in ``self._substituted_expressions`` after the
        first successful call, so ``variables`` only matters on that call
        (NOTE(review): confirm this is intended).
        '''
        if self._substituted_expressions is None:
            # Store the cache only once it is complete: a partially filled
            # list left behind by an exception would otherwise be returned
            # by every later call.
            substituted = []
            substitutions = {}
            for eq in self.ordered:
                # Skip parameters
                if eq.expr is None:
                    continue

                new_sympy_expr = str_to_sympy(eq.expr.code, variables).xreplace(substitutions)
                new_str_expr = sympy_to_str(new_sympy_expr)
                expr = Expression(new_str_expr)

                if eq.type == SUBEXPRESSION:
                    substitutions[sympy.Symbol(eq.varname, real=True)] = str_to_sympy(expr.code, variables)
                    substituted.append((eq.varname, expr))
                elif eq.type == DIFFERENTIAL_EQUATION:
                    #  a differential equation that we have to check
                    substituted.append((eq.varname, expr))
                else:
                    raise AssertionError('Unknown equation type %s' % eq.type)
            self._substituted_expressions = substituted

        if include_subexpressions:
            return self._substituted_expressions
        else:
            return [(name, expr) for name, expr in self._substituted_expressions
                    if self[name].type == DIFFERENTIAL_EQUATION]
コード例 #25
0
ファイル: test_parsing.py プロジェクト: zeph1yr/brian2
def test_sympytools():
    # Round-tripping an expression through sympy and back should reproduce
    # the original string (ignoring whitespace).
    #
    # This check is fragile: sympy is free to reorder the symbols of an
    # expression, which would make the comparison fail.
    expressions = [
        'randn()',  # function that takes no arguments
        'x + sin(pi*freq*t)',  # expression containing a constant
        'c * userfun(t + x)',  # function unknown to sympy
        'abs(x) + ceil(y)',  # functions whose sympy names differ
        'inf',  # constant whose sympy name differs
        'not(b)'  # boolean expression
    ]

    def _without_spaces(text):
        return text.replace(' ', '')

    for original in expressions:
        roundtripped = sympy_to_str(str_to_sympy(original))
        assert _without_spaces(original) == _without_spaces(roundtripped), \
            '%s != %s' % (original, roundtripped)
コード例 #26
0
    def __init__(self, morphology=None, model=None, threshold=None,
                 refractory=False, reset=None, events=None,
                 threshold_location=None,
                 dt=None, clock=None, order=0, Cm=0.9 * uF / cm ** 2, Ri=150 * ohm * cm,
                 name='spatialneuron*', dtype=None, namespace=None,
                 method=('linear', 'exponential_euler', 'rk2', 'heun')):
        '''
        Create a spatially extended neuron from a morphology and a model.

        Parameters
        ----------
        morphology : `Morphology`
            The neuron's morphology. A deep copy is stored in
            ``self.morphology``. A flattened version (``self.flat_morphology``)
            provides the per-compartment geometry.
        model : str or `Equations`
            The model equations. They have to define the transmembrane current
            ``Im``, which must be linear in ``v``. Equations flagged
            ``point current`` must be in amp. They are added to ``Im``
            divided by ``area``.
        threshold_location : int or object with ``_indices()``, optional
            If given, the threshold condition is restricted to this single
            compartment.
        Cm : `Quantity`, optional
            Specific membrane capacitance, assigned to the ``Cm`` variable.
        Ri : `Quantity`, optional
            Intracellular resistivity, assigned to the ``Ri`` variable.

        All remaining arguments are passed on to `NeuronGroup.__init__`.
        `method` and `order` are also passed to `SpatialStateUpdater`.

        Raises
        ------
        TypeError
            If `model` is neither a string nor an `Equations` object, if it
            does not define ``Im``, or if ``Im`` is not linear in ``v``.
        AttributeError
            If `threshold_location` selects more than one compartment.
        AssertionError
            If ``Im`` is unexpectedly missing from the substituted
            expressions.
        '''
        # #### Prepare and validate equations
        if isinstance(model, basestring):
            model = Equations(model)
        if not isinstance(model, Equations):
            raise TypeError(('model has to be a string or an Equations '
                             'object, is "%s" instead.') % type(model))

        # Insert the threshold mechanism at the specified location
        if threshold_location is not None:
            if hasattr(threshold_location,
                       '_indices'):  # assuming this is a method
                threshold_location = threshold_location._indices()
                # for now, only a single compartment allowed
                if len(threshold_location) == 1:
                    threshold_location = threshold_location[0]
                else:
                    raise AttributeError(('Threshold can only be applied on a '
                                          'single location'))
            # NOTE(review): assumes `threshold` is a string at this point
            threshold = '(' + threshold + ') and (i == ' + str(threshold_location) + ')'

        # Check flags (we have point currents)
        model.check_flags({DIFFERENTIAL_EQUATION: ('point current',),
                           PARAMETER: ('constant', 'shared', 'linked', 'point current'),
                           SUBEXPRESSION: ('shared', 'point current')})

        # Add the membrane potential
        model += Equations('''
        v:volt # membrane potential
        ''')

        # Extract membrane equation
        if 'Im' in model:
            membrane_eq = model['Im']  # the membrane equation
        else:
            raise TypeError('The transmembrane current Im must be defined')

        # Insert point currents in the membrane equation
        for eq in model.itervalues():
            if 'point current' in eq.flags:
                fail_for_dimension_mismatch(eq.unit, amp,
                                            "Point current " + eq.varname + " should be in amp")
                eq.flags.remove('point current')
                membrane_eq.expr = Expression(
                    str(membrane_eq.expr.code) + '+' + eq.varname + '/area')

        ###### Process model equations (Im) to extract total conductance and the remaining current
        # Check conditional linearity with respect to v
        # Match to _A*v+_B
        v_symbol = sp.Symbol('v', real=True)
        wildcard = sp.Wild('_A', exclude=[v_symbol])
        constant_wildcard = sp.Wild('_B', exclude=[v_symbol])
        pattern = wildcard * v_symbol + constant_wildcard

        # Expand expressions in the membrane equation. Without
        # include_subexpressions, get_substituted_expressions only returns
        # differential equations, so Im is temporarily marked as one. Its
        # type is restored even if the substitution fails.
        Im_expr = None
        membrane_eq.type = DIFFERENTIAL_EQUATION
        try:
            # Use a dedicated loop name: rebinding the symbol used by the
            # pattern would make collect() below factor out the wrong variable
            for varname, expr in model.get_substituted_expressions():
                if varname == 'Im':
                    Im_expr = expr
                    break
        finally:
            membrane_eq.type = SUBEXPRESSION
        if Im_expr is None:
            raise AssertionError('Model equations did not contain Im!')

        # Factor out v so that the pattern _A*v + _B can match sums of
        # several conductance terms
        s_expr = sp.collect(str_to_sympy(Im_expr.code).expand(), v_symbol)
        matches = s_expr.match(pattern)

        if matches is None:
            raise TypeError("The membrane current must be linear with respect to v")
        a, b = (matches[wildcard],
                matches[constant_wildcard])

        # Extracts the total conductance from Im, and the remaining current
        minusa_str, b_str = sympy_to_str(-a), sympy_to_str(b)
        # Add correct units if necessary (a bare 0 carries no units)
        if minusa_str == '0':
            minusa_str += '*siemens/meter**2'
        if b_str == '0':
            b_str += '*amp/meter**2'
        gtot_str = "gtot__private=" + minusa_str + ": siemens/meter**2"
        I0_str = "I0__private=" + b_str + ": amp/meter**2"
        model += Equations(gtot_str + "\n" + I0_str)

        # Insert morphology (store a copy)
        self.morphology = copy.deepcopy(morphology)

        # Flatten the morphology
        self.flat_morphology = FlatMorphology(morphology)

        # Equations for morphology
        # TODO: check whether Cm and Ri are already in the equations
        #       no: should be shared instead of constant
        #       yes: should be constant (check)
        eqs_constants = Equations("""
        length : meter (constant)
        distance : meter (constant)
        area : meter**2 (constant)
        volume : meter**3
        diameter : meter (constant)
        Cm : farad/meter**2 (constant)
        Ri : ohm*meter (constant, shared)
        r_length_1 : meter (constant)
        r_length_2 : meter (constant)
        time_constant = Cm/gtot__private : second
        space_constant = (2/pi)**(1.0/3.0) * (area/(1/r_length_1 + 1/r_length_2))**(1.0/6.0) /
                         (2*(Ri*gtot__private)**(1.0/2.0)) : meter
        """)
        if self.flat_morphology.has_coordinates:
            eqs_constants += Equations('''
            x : meter (constant)
            y : meter (constant)
            z : meter (constant)
            ''')

        NeuronGroup.__init__(self, morphology.total_compartments,
                             model=model + eqs_constants,
                             threshold=threshold, refractory=refractory,
                             reset=reset, events=events,
                             method=method, dt=dt, clock=clock, order=order,
                             namespace=namespace, dtype=dtype, name=name)
        # Parameters and intermediate variables for solving the cable equations
        # Note that some of these variables could have meaningful physical
        # units (e.g. _v_star is in volt, _I0_all is in amp/meter**2 etc.) but
        # since these variables should never be used in user code, we don't
        # assign them any units
        self.variables.add_arrays(['_ab_star0', '_ab_star1', '_ab_star2',
                                   '_a_minus0', '_a_minus1', '_a_minus2',
                                   '_a_plus0', '_a_plus1', '_a_plus2',
                                   '_b_plus', '_b_minus',
                                   '_v_star', '_u_plus', '_u_minus',
                                   # The following three are for solving the
                                   # three tridiag systems in parallel
                                   '_c1', '_c2', '_c3',
                                   # The following two are only necessary for
                                   # C code where we cannot deal with scalars
                                   # and arrays interchangeably:
                                   '_I0_all', '_gtot_all'], unit=1,
                                  size=self.N, read_only=True)

        self.Cm = Cm
        self.Ri = Ri
        # These explict assignments will load the morphology values from disk
        # in standalone mode
        self.distance_ = self.flat_morphology.distance
        self.length_ = self.flat_morphology.length
        self.area_ = self.flat_morphology.area
        self.diameter_ = self.flat_morphology.diameter
        self.r_length_1_ = self.flat_morphology.r_length_1
        self.r_length_2_ = self.flat_morphology.r_length_2
        if self.flat_morphology.has_coordinates:
            self.x_ = self.flat_morphology.x
            self.y_ = self.flat_morphology.y
            self.z_ = self.flat_morphology.z

        # Performs numerical integration step
        self.add_attribute('diffusion_state_updater')
        self.diffusion_state_updater = SpatialStateUpdater(self, method,
                                                           clock=self.clock,
                                                           order=order)

        # Creation of contained_objects that do the work
        self.contained_objects.extend([self.diffusion_state_updater])
コード例 #27
0
ファイル: test_parsing.py プロジェクト: achilleas-k/brian2
 def evaluator(expr, ns):
     # Render the sympy expression back to a string and evaluate it in ``ns``.
     return eval(sympy_to_str(expr), ns)
コード例 #28
0
ファイル: spatialneuron.py プロジェクト: brian-team/brian2
    def __init__(self, morphology=None, model=None, threshold=None,
                 refractory=False, reset=None, events=None,
                 threshold_location=None,
                 dt=None, clock=None, order=0, Cm=0.9 * uF / cm ** 2, Ri=150 * ohm * cm,
                 name='spatialneuron*', dtype=None, namespace=None,
                 method=('exact', 'exponential_euler', 'rk2', 'heun'),
                 method_options=None):
        '''
        Create a spatially extended neuron from a morphology and a model.

        Parameters
        ----------
        morphology : `Morphology`
            The neuron's morphology. A deep copy is stored in
            ``self.morphology``. A flattened version (``self.flat_morphology``)
            provides the per-compartment geometry.
        model : str or `Equations`
            The model equations. They must define the transmembrane current
            ``Im`` without any flags. Equations flagged ``point current``
            must be in amp. They are added to ``Im`` divided by ``area``.
        threshold_location : int or object with ``_indices()``, optional
            If given, the threshold condition is additionally restricted to
            this single compartment (``i == threshold_location``).
        Cm : `Quantity`, optional
            Specific membrane capacitance, assigned to the ``Cm`` variable.
        Ri : `Quantity`, optional
            Intracellular resistivity, assigned to the ``Ri`` variable.
        method_options : dict, optional
            Passed on to `NeuronGroup.__init__`.

        All remaining arguments are passed on to `NeuronGroup.__init__`.
        `method` and `order` are also given to `SpatialStateUpdater`.

        Raises
        ------
        TypeError
            If `model` is neither a string nor an `Equations` object, if
            ``Im`` is missing or has flags, or if sympy cannot differentiate
            ``Im`` with respect to ``v``.
        AttributeError
            If `threshold_location` selects more than one compartment.
        AssertionError
            If ``Im`` is unexpectedly missing from the substituted
            expressions.
        '''

        # #### Prepare and validate equations
        if isinstance(model, basestring):
            model = Equations(model)
        if not isinstance(model, Equations):
            raise TypeError(('model has to be a string or an Equations '
                             'object, is "%s" instead.') % type(model))

        # Insert the threshold mechanism at the specified location
        if threshold_location is not None:
            if hasattr(threshold_location,
                       '_indices'):  # assuming this is a method
                threshold_location = threshold_location._indices()
                # for now, only a single compartment allowed
                if len(threshold_location) == 1:
                    threshold_location = threshold_location[0]
                else:
                    raise AttributeError(('Threshold can only be applied on a '
                                          'single location'))
            # NOTE(review): assumes `threshold` is a string here; with the
            # default threshold=None this concatenation would raise TypeError
            threshold = '(' + threshold + ') and (i == ' + str(threshold_location) + ')'

        # Check flags (we have point currents)
        model.check_flags({DIFFERENTIAL_EQUATION: ('point current',),
                           PARAMETER: ('constant', 'shared', 'linked', 'point current'),
                           SUBEXPRESSION: ('shared', 'point current',
                                           'constant over dt')})
        #: The original equations as specified by the user (i.e. before
        #: inserting point-currents into the membrane equation, before adding
        #: all the internally used variables and constants, etc.).
        self.user_equations = model

        # Separate subexpressions depending whether they are considered to be
        # constant over a time step or not (this would also be done by the
        # NeuronGroup initializer later, but this would give incorrect results
        # for the linearity check)
        model, constant_over_dt = extract_constant_subexpressions(model)

        # Extract membrane equation
        if 'Im' in model:
            if len(model['Im'].flags):
                raise TypeError('Cannot specify any flags for the transmembrane '
                                'current Im.')
            membrane_expr = model['Im'].expr  # the membrane equation
        else:
            raise TypeError('The transmembrane current Im must be defined')

        # Rebuild the equations without mutating the user's objects: point
        # currents are stripped of their flag and added to Im (divided by
        # the compartment area)
        model_equations = []
        # Insert point currents in the membrane equation
        for eq in model.itervalues():
            if eq.varname == 'Im':
                continue  # ignore -- handled separately
            if 'point current' in eq.flags:
                fail_for_dimension_mismatch(eq.dim, amp,
                                            "Point current " + eq.varname + " should be in amp")
                membrane_expr = Expression(
                    str(membrane_expr.code) + '+' + eq.varname + '/area')
                eq = SingleEquation(eq.type, eq.varname, eq.dim, expr=eq.expr,
                                    flags=list(set(eq.flags)-set(['point current'])))
            model_equations.append(eq)

        model_equations.append(SingleEquation(SUBEXPRESSION, 'Im',
                                              dimensions=(amp/meter**2).dim,
                                              expr=membrane_expr))
        model_equations.append(SingleEquation(PARAMETER, 'v', volt.dim))
        model = Equations(model_equations)

        ###### Process model equations (Im) to extract total conductance and the remaining current
        # Expand expressions in the membrane equation
        for var, expr in model.get_substituted_expressions(include_subexpressions=True):
            if var == 'Im':
                Im_expr = expr
                break
        else:
            raise AssertionError('Model equations did not contain Im!')

        # Differentiate Im with respect to v
        Im_sympy_exp = str_to_sympy(Im_expr.code)
        v_sympy = sp.Symbol('v', real=True)
        diffed = sp.diff(Im_sympy_exp, v_sympy)

        # sympy leaves Derivative objects behind when it cannot evaluate them
        unevaled_derivatives = diffed.atoms(sp.Derivative)
        if len(unevaled_derivatives):
            raise TypeError('Cannot take the derivative of "{Im}" with respect '
                            'to v.'.format(Im=Im_expr.code))

        # Linearize Im in v: gtot = -dIm/dv and I0 = Im - v*dIm/dv, so that
        # Im == I0 - gtot*v
        gtot_str = sympy_to_str(sp.simplify(-diffed))
        I0_str = sympy_to_str(sp.simplify(Im_sympy_exp - diffed*v_sympy))

        # A bare 0 carries no units; append them so the unit check passes
        if gtot_str == '0':
            gtot_str += '*siemens/meter**2'
        if I0_str == '0':
            I0_str += '*amp/meter**2'
        gtot_str = "gtot__private=" + gtot_str + ": siemens/meter**2"
        I0_str = "I0__private=" + I0_str + ": amp/meter**2"

        model += Equations(gtot_str + "\n" + I0_str)

        # Insert morphology (store a copy)
        self.morphology = copy.deepcopy(morphology)

        # Flatten the morphology
        self.flat_morphology = FlatMorphology(morphology)

        # Equations for morphology
        # TODO: check whether Cm and Ri are already in the equations
        #       no: should be shared instead of constant
        #       yes: should be constant (check)
        eqs_constants = Equations("""
        length : meter (constant)
        distance : meter (constant)
        area : meter**2 (constant)
        volume : meter**3
        Ic : amp/meter**2
        diameter : meter (constant)
        Cm : farad/meter**2 (constant)
        Ri : ohm*meter (constant, shared)
        r_length_1 : meter (constant)
        r_length_2 : meter (constant)
        time_constant = Cm/gtot__private : second
        space_constant = (2/pi)**(1.0/3.0) * (area/(1/r_length_1 + 1/r_length_2))**(1.0/6.0) /
                         (2*(Ri*gtot__private)**(1.0/2.0)) : meter
        """)
        if self.flat_morphology.has_coordinates:
            eqs_constants += Equations('''
            x : meter (constant)
            y : meter (constant)
            z : meter (constant)
            ''')

        NeuronGroup.__init__(self, morphology.total_compartments,
                             model=model + eqs_constants,
                             method_options=method_options,
                             threshold=threshold, refractory=refractory,
                             reset=reset, events=events,
                             method=method, dt=dt, clock=clock, order=order,
                             namespace=namespace, dtype=dtype, name=name)
        # Parameters and intermediate variables for solving the cable equations
        # Note that some of these variables could have meaningful physical
        # units (e.g. _v_star is in volt, _I0_all is in amp/meter**2 etc.) but
        # since these variables should never be used in user code, we don't
        # assign them any units
        self.variables.add_arrays(['_ab_star0', '_ab_star1', '_ab_star2',
                                   '_b_plus', '_b_minus',
                                   '_v_star', '_u_plus', '_u_minus',
                                   '_v_previous', '_c',
                                   # The following two are only necessary for
                                   # C code where we cannot deal with scalars
                                   # and arrays interchangeably:
                                   '_I0_all', '_gtot_all'],
                                  size=self.N, read_only=True)

        self.Cm = Cm
        self.Ri = Ri
        # These explict assignments will load the morphology values from disk
        # in standalone mode
        self.distance_ = self.flat_morphology.distance
        self.length_ = self.flat_morphology.length
        self.area_ = self.flat_morphology.area
        self.diameter_ = self.flat_morphology.diameter
        self.r_length_1_ = self.flat_morphology.r_length_1
        self.r_length_2_ = self.flat_morphology.r_length_2
        if self.flat_morphology.has_coordinates:
            self.x_ = self.flat_morphology.x
            self.y_ = self.flat_morphology.y
            self.z_ = self.flat_morphology.z

        # Performs numerical integration step
        self.add_attribute('diffusion_state_updater')
        self.diffusion_state_updater = SpatialStateUpdater(self, method,
                                                           clock=self.clock,
                                                           order=order)

        # Update v after the gating variables to obtain consistent Ic and Im
        # (note: this overrides the `order` passed in above)
        self.diffusion_state_updater.order = 1

        # Creation of contained_objects that do the work
        self.contained_objects.extend([self.diffusion_state_updater])

        # Subexpressions flagged "constant over dt" get their own updater
        if len(constant_over_dt):
            self.subexpression_updater = SubexpressionUpdater(self,
                                                              constant_over_dt)
            self.contained_objects.append(self.subexpression_updater)
コード例 #29
0
    def _generate_RHS(self, eqs, var, eq_symbols, temp_vars, expr,
                      non_stochastic_expr, stochastic_expr,
                      stochastic_variable=()):
        '''
        Helper for `__call__`: build the right-hand side of one abstract code
        statement by substituting f, g and t in a state updater line.

        Example: for the equation ``dv/dt = (-v + I) / tau`` (i.e. `var` is
        ``v`` and `expr` is ``(-v + I) / tau``) and the `rk2` step
        ``return x + dt*f(x +  k/2, t + dt/2)`` (i.e. `non_stochastic_expr`
        is ``x + dt*f(x +  k/2, t + dt/2)`` and `stochastic_expr` is
        ``None``), the result is ``v + dt*(-v - _k_v/2 + I + _k_I/2)/tau``.

        The `eqs` argument is not used by this helper.

        Returns
        -------
        rhs : str
            The sum of all non-stochastic and stochastic contributions,
            converted to a string with `sympy_to_str`.
        '''
        # A state updater line may describe only the non-stochastic or only
        # the stochastic part of an update. For example, Euler can be written
        # as
        #     non_stochastic = dt * f(x, t)
        #     stochastic = dt**.5 * g(x, t) * xi
        #     return x + non_stochastic + stochastic
        # Such lines are handled here by silently skipping the missing part.
        # State updaters are trusted to declare their abilities correctly
        # (e.g. not to claim stochastic support while ignoring the stochastic
        # part); this cannot be checked from a single line of the description.
        non_stochastic, stochastic = expr.split_stochastic()

        # A part contributes only if it exists both in the equation and in
        # the state updater description
        use_non_stochastic = (non_stochastic is not None and
                              non_stochastic_expr is not None)
        use_stochastic = (stochastic is not None and
                          stochastic_expr is not None)

        non_stochastic_terms = []
        if use_non_stochastic:
            non_stochastic_terms = self._non_stochastic_part(eq_symbols,
                                                             non_stochastic,
                                                             non_stochastic_expr,
                                                             stochastic_variable,
                                                             temp_vars, var)

        stochastic_terms = []
        if use_stochastic:
            stochastic_terms = self._stochastic_part(eq_symbols,
                                                     stochastic,
                                                     stochastic_expr,
                                                     stochastic_variable,
                                                     temp_vars, var)

        # One non-stochastic and possibly several stochastic terms, all summed
        RHS = sympy.Number(0)
        for terms in (non_stochastic_terms, stochastic_terms):
            for term in terms:
                RHS += term

        return sympy_to_str(RHS)
コード例 #30
0
def get_sensitivity_equations(group,
                              parameters,
                              namespace=None,
                              level=1,
                              optimize=True):
    """
    Build the equations for the sensitivity variables of a group.

    For every state variable ``x`` and every fitted parameter ``p``, a
    sensitivity variable ``S_x_p`` is introduced. The vector of sensitivities
    for ``p`` evolves as ``dS/dt = J*S + df/dp``, where ``J`` is the Jacobian
    of the differential equations' right-hand sides with respect to the
    state variables.

    Parameters
    ----------
    group : `NeuronGroup`
        The group of neurons that will be simulated.
    parameters : list of str
        Names of the fitted parameters.
    namespace : dict, optional
        Namespace used to look up the parameters' units. If not given, the
        caller's local namespace (see `level`), updated with the group's
        namespace, is used.
    level : `int`, optional
        How many frames further up the stack to look for the namespace.
    optimize : bool, optional
        If ``True`` (the default), a sensitivity variable whose equation
        stays at zero when initialized at zero (e.g.
        ``dS_x_y/dt = -S_x_y/tau``) is written as a constant zero instead of
        a differential equation. This saves computation but is wrong in the
        rare case where such a variable must start at a non-zero value.

    Returns
    -------
    sensitivity_eqs : `Equations`
        The equations for the sensitivity variables.

    Raises
    ------
    AssertionError
        If a parameter is found neither in the namespace nor in the group's
        variables.
    """
    if namespace is None:
        namespace = get_local_namespace(level)
        namespace.update(group.namespace)

    eqs = group.equations
    diff_eqs = eqs.get_substituted_expressions(group.variables)
    state_names = [state_name for state_name, _ in diff_eqs]

    # Right-hand sides f of dX/dt = f(X) and their Jacobian J w.r.t. X
    rhs_vector = sympy.Matrix([str_to_sympy(expression.code)
                               for _, expression in diff_eqs])
    jacobian = rhs_vector.jacobian([str_to_sympy(state_name)
                                    for state_name in state_names])

    # Symbols and symbolic dynamics of the sensitivities, per parameter
    per_parameter = []
    for parameter in parameters:
        df_dp = rhs_vector.jacobian([str_to_sympy(parameter)])
        sens_symbols = [str_to_sympy(f'S_{diff_eq_name}_{parameter}')
                        for diff_eq_name in state_names]
        per_parameter.append((parameter, sens_symbols,
                              jacobian * sympy.Matrix(sens_symbols) + df_dp))

    lines = []
    for param, sens_symbols, sens_rhs in per_parameter:
        for symbol, rhs_expr, orig_var in zip(sens_symbols, sens_rhs,
                                              state_names):
            state_dim = eqs[orig_var].dim
            if param in namespace:
                param_dim = namespace[param].dim
            elif param in group.variables:
                param_dim = group.variables[param].dim
            else:
                raise AssertionError(
                    f'Parameter {param} neither in namespace nor variables')
            dim = state_dim / param_dim
            unit = '1' if dim.is_dimensionless else repr(dim)
            lhs = sympy_to_str(symbol)

            if optimize and rhs_expr.subs(symbol, sympy.S.Zero) == sympy.S.Zero:
                # Stays at zero forever: a constant is enough
                if unit == '1':
                    lines.append(f'{lhs} = 0 : {unit}')
                else:
                    lines.append(f'{lhs} = 0*{unit} : {unit}')
                continue

            rhs = sympy_to_str(rhs_expr)
            if rhs == '0':  # a bare zero would not have the right units
                rhs = f'0*{unit}/second'
            lines.append('d{lhs}/dt = {rhs} : {unit}'.format(
                lhs=lhs, rhs=rhs, unit=unit))
    return Equations('\n'.join(lines))
コード例 #31
0
ファイル: test_codegen.py プロジェクト: brian-team/brian2
def test_automatic_augmented_assignments():
    # We test that statements that could be rewritten as augmented assignments
    # are correctly rewritten (using sympy to test for symbolic equality)
    variables = {
        'x': ArrayVariable('x', owner=None, size=10,
                           device=device),
        'y': ArrayVariable('y', owner=None, size=10,
                           device=device),
        # The variable's own name has to match its key ('z', not 'y')
        'z': ArrayVariable('z', owner=None, size=10,
                           device=device),
        # Builtin bool instead of np.bool, which was removed in NumPy 1.24
        'b': ArrayVariable('b', owner=None, size=10,
                           dtype=bool, device=device),
        'clip': DEFAULT_FUNCTIONS['clip'],
        'inf': DEFAULT_CONSTANTS['inf']
    }
    statements = [
        # examples that should be rewritten
        # Note that using our approach, we will never get -= or /= but always
        # the equivalent += or *= statements
        ('x = x + 1', 'x += 1'),
        ('x = 2 * x', 'x *= 2'),
        ('x = x - 3', 'x += -3'),
        ('x = x/2', 'x *= 0.5'),
        ('x = y + (x + 1)', 'x += y + 1'),
        ('x = x + x', 'x *= 2'),
        ('x = x + y + z', 'x += y + z'),
        # examples that should not be rewritten
        ('x = 1/x', 'x = 1/x'),
        ('x = 1', 'x = 1'),
        ('x = 2*(x + 1)', 'x = 2*(x + 1)'),
        ('x = clip(x + y, 0, inf)', 'x = clip(x + y, 0, inf)'),
        ('b = b or False', 'b = b or False')
    ]
    for orig, rewritten in statements:
        scalar, vector = make_statements(orig, variables, np.float32)
        try:  # we augment the assertion error with the original statement
            assert len(scalar) == 0, 'Did not expect any scalar statements but got ' + str(scalar)
            assert len(vector) == 1, 'Did expect a single statement but got ' + str(vector)
            statement = vector[0]
            expected_var, expected_op, expected_expr, _ = parse_statement(rewritten)
            assert expected_var == statement.var, 'expected write to variable %s, not to %s' % (expected_var, statement.var)
            assert expected_op == statement.op, 'expected operation %s, not %s' % (expected_op, statement.op)
            # Compare the two expressions using sympy to allow for different order etc.
            sympy_expected = str_to_sympy(expected_expr)
            sympy_actual = str_to_sympy(statement.expr)
            assert sympy_expected == sympy_actual, ('RHS expressions "%s" and "%s" are not identical' % (sympy_to_str(sympy_expected),
                                                                                                         sympy_to_str(sympy_actual)))
        except AssertionError as ex:
            raise AssertionError('Transformation for statement "%s" gave an unexpected result: %s' % (orig, str(ex)))
コード例 #32
0
ファイル: test_parsing.py プロジェクト: treestreamymw/brian2
 def evaluator(expr, ns):
     # Evaluate a sympy expression in a copy of ``ns`` (the caller's dict is
     # left untouched). The copy also provides ``floor``, which is used to
     # implement floor division.
     code = sympy_to_str(expr)
     namespace = dict(ns, floor=DEFAULT_FUNCTIONS['floor'])
     return eval(code, namespace)
コード例 #33
0
ファイル: test_parsing.py プロジェクト: brian-team/brian2
def test_sympy_infinity():
    # Regression test for github issue #1061: infinity (with either sign)
    # must survive the string -> sympy -> string round trip unchanged.
    for text in ('inf', '-inf'):
        assert sympy_to_str(str_to_sympy(text)) == text
コード例 #34
0
    def __call__(self, equations, variables=None, method_options=None):
        '''
        Generate abstract code that integrates a linear ODE system exactly.

        The system is brought into the form ``dX/dt = M*X + B``. With ``b``
        solving ``M*b = B``, the exact one-step update is
        ``X' = exp(M*dt)*(X + b) - b``. All new values are first stored in
        ``_<name>`` temporaries and only then copied back, so every update
        sees the old state.

        Parameters
        ----------
        equations : `Equations`
            The (non-stochastic) model equations.
        variables : dict, optional
            Mapping of identifiers to `Variable`/`Function` objects. If it
            contains ``dt``, its value is used for the constancy check.
        method_options : dict, optional
            Supports ``'simplify'`` (default ``True``), which simplifies
            the entries of ``exp(M*dt)``.

        Returns
        -------
        code : str
            The abstract code. It is ``''`` if there are no differential
            equations.

        Raises
        ------
        UnsupportedEquationsException
            Raised in any of these cases:

            - The equations are stochastic.
            - An entry of ``M`` or ``B`` is not constant over a time step.
            - The linear system cannot be solved.
            - sympy cannot compute the matrix exponential.
            - The solution contains complex values.
        '''
        method_options = extract_method_options(method_options,
                                                {'simplify': True})

        if equations.is_stochastic:
            raise UnsupportedEquationsException('Cannot solve stochastic '
                                                'equations with this state '
                                                'updater.')
        variables = {} if variables is None else variables

        # dX/dt = M*X + B
        varnames, matrix, constants = get_linear_system(equations, variables)

        # A model without differential equations (e.g. only 'v : 1') needs
        # no update code at all
        if matrix.shape == (0, 0):
            return ''

        if 'dt' in variables:
            dt_value = variables['dt'].get_value()[0]
        else:
            dt_value = None

        # Every coefficient has to be constant over one time step; this
        # rejects e.g. a "t" that does not appear inside a locally constant
        # function
        for entry in itertools.chain(matrix, constants):
            if is_constant_over_dt(entry, variables, dt_value):
                continue
            raise UnsupportedEquationsException(
                ('Expression "{}" is not guaranteed to be constant over a '
                 'time step').format(sympy_to_str(entry)))

        state_symbols = [Symbol(name, real=True) for name in varnames]
        solution = sp.solve_linear_system(matrix.row_join(constants),
                                          *state_symbols)
        if solution is None or set(state_symbols) != set(solution.keys()):
            raise UnsupportedEquationsException('Cannot solve the given '
                                                'equations with this '
                                                'stateupdater.')
        offset = sp.ImmutableMatrix([solution[symbol]
                                     for symbol in state_symbols])

        dt = Symbol('dt', real=True, positive=True)
        try:
            propagator = (matrix * dt).exp()
        except NotImplementedError:
            raise UnsupportedEquationsException('Cannot solve the given '
                                                'equations with this '
                                                'stateupdater.')
        if method_options['simplify']:
            propagator = propagator.applyfunc(
                lambda element: sp.factor_terms(sp.cancel(sp.signsimp(element))))
        shift = sp.ImmutableMatrix(propagator * offset) - offset
        state_vector = sp.MatrixSymbol('_S', len(varnames), 1)
        updates = (propagator * state_vector + shift).as_explicit()

        statements = []
        for variable, update in zip(varnames, updates):
            if update.has(I, re, im):
                raise UnsupportedEquationsException(
                    'The solution to the linear system '
                    'contains complex values '
                    'which is currently not implemented.')
            # Replace the placeholders _S[i, 0] by the state variable names
            rhs = update
            for row, state_name in enumerate(varnames):
                rhs = rhs.subs(state_vector[row, 0], state_name)
            # Write to a temporary: other updates still need the old values
            statements.append('_' + variable + ' = ' + sympy_to_str(rhs))

        # Only now overwrite the actual state variables
        statements.extend('{variable} = _{variable}'.format(variable=variable)
                          for variable in varnames)
        return '\n'.join(statements)
コード例 #35
0
    def __call__(self, equations, variables=None, method_options=None):
        '''
        Generate abstract code that solves each differential equation
        analytically (via sympy's ``dsolve``), one equation at a time.

        Each equation is solved for its own variable only. Every other symbol
        in its right-hand side is treated as a constant. The solution is
        expressed with the current value as initial condition at time ``t``
        and evaluated at ``t + dt``.

        Parameters
        ----------
        equations : `Equations`
            The model equations; stochastic equations are rejected.
        variables : dict, optional
            Mapping of identifiers to `Variable`/`Function` objects, used for
            substitution and sympy conversion. Defaults to an empty dict.
        method_options : dict, optional
            Checked with `extract_method_options` against an empty set of
            defaults (presumably rejecting any option -- confirm).

        Returns
        -------
        code : str
            One ``name = <solution>`` line per differential equation.

        Raises
        ------
        UnsupportedEquationsException
            Raised in any of these cases:

            - The equations are stochastic.
            - ``dsolve`` does not return an explicit solution ``f(t) = ...``.
            - The solution has more than one integration constant.
            - The constant ``C1`` cannot be solved for uniquely.
        '''
        logger.warn(
            "The 'independent' state updater is deprecated and might be "
            "removed in future versions of Brian.",
            'deprecated_independent',
            once=True)
        extract_method_options(method_options, {})
        if equations.is_stochastic:
            raise UnsupportedEquationsException('Cannot solve stochastic '
                                                'equations with this state '
                                                'updater')
        if variables is None:
            variables = {}

        diff_eqs = equations.get_substituted_expressions(variables)

        t = Symbol('t', real=True, positive=True)
        dt = Symbol('dt', real=True, positive=True)
        t0 = Symbol('t0', real=True, positive=True)

        code = []
        for name, expression in diff_eqs:
            rhs = str_to_sympy(expression.code, variables)

            # We have to be careful and use the real=True assumption as well,
            # otherwise sympy doesn't consider the symbol a match to the content
            # of the equation
            var = Symbol(name, real=True)
            f = sp.Function(name)
            # Only this equation's own variable becomes a function of t
            rhs = rhs.subs(var, f(t))
            derivative = sp.Derivative(f(t), t)
            diff_eq = sp.Eq(derivative, rhs)
            # TODO: simplify=True sometimes fails with 0.7.4, see:
            # https://github.com/sympy/sympy/issues/2666
            try:
                general_solution = sp.dsolve(diff_eq, f(t), simplify=True)
            except RuntimeError:
                general_solution = sp.dsolve(diff_eq, f(t), simplify=False)
            # Check whether this is an explicit solution
            if not getattr(general_solution, 'lhs', None) == f(t):
                raise UnsupportedEquationsException(
                    'Cannot explicitly solve: ' + str(diff_eq))
            # Solve for C1 (assuming "var" as the initial value and "t0" as time)
            if general_solution.has(Symbol('C1')):
                if general_solution.has(Symbol('C2')):
                    raise UnsupportedEquationsException(
                        'Too many constants in solution: %s' %
                        str(general_solution))
                constant_solution = sp.solve(general_solution, Symbol('C1'))
                if len(constant_solution) != 1:
                    raise UnsupportedEquationsException(
                        ("Couldn't solve for the constant "
                         "C1 in : %s ") % str(general_solution))
                constant = constant_solution[0].subs(t, t0).subs(f(t0), var)
                solution = general_solution.rhs.subs('C1', constant)
            else:
                solution = general_solution.rhs.subs(t, t0).subs(f(t0), var)
            # Evaluate the expression for one timestep
            solution = solution.subs(t, t + dt).subs(t0, t)
            # only try simplifying it -- it sometimes raises an error
            try:
                solution = solution.simplify()
            except ValueError:
                pass

            # NOTE(review): the state variable is overwritten directly (no
            # "_name" temporary), so later lines see already-updated values
            # of earlier variables; presumably acceptable only for truly
            # independent equations
            code.append(name + ' = ' + sympy_to_str(solution))

        return '\n'.join(code)
コード例 #36
0
 def evaluator(expr, ns):
     # Turn the sympy expression into code and evaluate it in namespace ns.
     source = sympy_to_str(expr)
     result = eval(source, ns)
     return result
コード例 #37
0
ファイル: translation.py プロジェクト: ttxtea/brian2
def make_statements(code, variables, dtype):
    '''
    Turn a series of abstract code statements into Statement objects, inferring
    whether each line is a set/declare operation, whether the variables are
    constant or not, and handling the caching of subexpressions.

    Parameters
    ----------
    code : str
        A (multi-line) string of statements. Statements are separated by
        newlines or semicolons.
    variables : dict-like
        A dictionary with `Variable` and `Function` objects for every
        identifier used in the `code`. The mapping is copied before new
        auxiliary variables are added to it.
    dtype : `dtype`
        The data type to use for temporary variables

    Returns
    -------
    scalar_statements, vector_statements : (list of `Statement`, list of `Statement`)
        Lists with statements that are to be executed once and statements that
        are to be executed once for every neuron/synapse/... (or in a vectorised
        way)

    Raises
    ------
    SyntaxError
        If a scalar variable is written to after a vector variable has already
        been written to in the same code block.

    Notes
    -----
    If the ``codegen.loop_invariant_optimisations`` preference is set, the
    `scalar_statements` may include newly introduced scalar constants that
    have been identified as loop-invariant and have therefore been pulled out
    of the vector statements. The resulting statements will also use augmented
    assignments where possible, i.e. a statement such as ``w = w + 1`` will be
    replaced by ``w += 1`` and ``w = 2*w`` by ``w *= 2`` (only for already
    defined, non-boolean variables).
    '''
    code = strip_empty_lines(deindent(code))
    # Statements are separated by newlines or semicolons
    lines = re.split(r'[;\n]', code)
    lines = [LineInfo(code=line) for line in lines if len(line)]
    if DEBUG:
        print 'INPUT CODE:'
        print code
    # Do a copy so we can add stuff without altering the original dict
    variables = dict(variables)
    # we will do inference to work out which lines are := and which are =
    defined = set(k for k, v in variables.iteritems()
                  if not isinstance(v, AuxiliaryVariable))
    for line in lines:
        statement = None
        # parse statement into "var op expr"
        var, op, expr, comment = parse_statement(line.code)
        if op == '=':
            if var not in defined:
                # First assignment to a not yet defined name: turn it into a
                # declaration and register a temporary variable for it (unless
                # it is already known, e.g. as an AuxiliaryVariable)
                op = ':='
                defined.add(var)
                if var not in variables:
                    is_scalar = is_scalar_expression(expr, variables)
                    new_var = AuxiliaryVariable(
                        var,
                        Unit(1),  # doesn't matter here
                        dtype=dtype,
                        scalar=is_scalar)
                    variables[var] = new_var
            elif not variables[var].is_boolean:
                # Assignment to an existing non-boolean variable: try to
                # rewrite it as an augmented assignment ('+=' or '*=')
                sympy_expr = str_to_sympy(expr)
                sympy_var = sympy.Symbol(var, real=True)
                try:
                    collected = sympy.collect(sympy_expr,
                                              sympy_var,
                                              exact=True,
                                              evaluate=False)
                except AttributeError:
                    # If something goes wrong during collection, e.g. collect
                    # does not work for logical expressions
                    collected = {1: sympy_expr}

                # `collected` maps sympy_var (and 1, for the remaining terms)
                # to their coefficients. The expression has the form
                # "var + <rest>" if both keys are present and var's
                # coefficient is exactly 1.
                if (len(collected) == 2
                        and set(collected.keys()) == {1, sympy_var}
                        and collected[sympy_var] == 1):
                    # We can replace this statement by a += assignment
                    statement = Statement(var,
                                          '+=',
                                          sympy_to_str(collected[1]),
                                          comment,
                                          dtype=variables[var].dtype,
                                          scalar=variables[var].scalar)
                elif len(collected) == 1 and sympy_var in collected:
                    # The expression has the form "<factor> * var"
                    # We can replace this statement by a *= assignment
                    statement = Statement(var,
                                          '*=',
                                          sympy_to_str(collected[sympy_var]),
                                          comment,
                                          dtype=variables[var].dtype,
                                          scalar=variables[var].scalar)
        if statement is None:
            # No augmented assignment applies, keep the (possibly ':=') op
            statement = Statement(var,
                                  op,
                                  expr,
                                  comment,
                                  dtype=variables[var].dtype,
                                  scalar=variables[var].scalar)

        line.statement = statement
        # for each line will give the variable being written to
        line.write = var
        # each line will give a set of variables which are read
        line.read = get_identifiers_recursively([expr], variables)

    # All writes to scalar variables must happen before writes to vector
    # variables. Despite its name, scalar_write_done becomes True as soon as
    # a vector variable has been written to; declarations (':=') are exempt.
    scalar_write_done = False
    for line in lines:
        stmt = line.statement
        if stmt.op != ':=' and variables[
                stmt.var].scalar and scalar_write_done:
            raise SyntaxError(
                ('All writes to scalar variables in a code block '
                 'have to be made before writes to vector '
                 'variables. Illegal write to %s.') % line.write)
        elif not variables[stmt.var].scalar:
            scalar_write_done = True

    if DEBUG:
        print 'PARSED STATEMENTS:'
        for line in lines:
            print line.statement, 'Read:' + str(
                line.read), 'Write:' + line.write

    # all variables which are written to at some point in the code block
    # (in this function, this set is only used for the debug output below)
    all_write = set(line.write for line in lines)

    if DEBUG:
        print 'ALL WRITE:', all_write

    # backwards compute whether or not variables will be read again
    # note that will_read for a line gives the set of variables it will read
    # on the current line or subsequent ones. will_write gives the set of
    # variables that will be written after the current line
    will_read = set()
    will_write = set()
    for line in lines[::-1]:
        will_read = will_read.union(line.read)
        line.will_read = will_read.copy()
        line.will_write = will_write.copy()
        will_write.add(line.write)

    if DEBUG:
        print 'WILL READ/WRITE:'
        for line in lines:
            print line.statement, 'Read:' + str(
                line.will_read), 'Write:' + str(line.will_write)

    # generate caching statements for common subexpressions
    # cached subexpressions need to be recomputed whenever they are to be used
    # on the next line, and currently invalid (meaning that the current value
    # stored in the subexpression variable is no longer accurate because one
    # of the variables appearing in it has changed). All subexpressions start
    # as invalid, and are invalidated whenever one of the variables appearing
    # in the RHS changes value.
    subexpressions = dict((name, val) for name, val in variables.items()
                          if isinstance(val, Subexpression))
    # sort subexpressions into an order so that subexpressions that don't depend
    # on other subexpressions are first
    subexpr_deps = dict((name, [dep for dep in subexpr.identifiers if dep in subexpressions]) for \
                                                            name, subexpr in subexpressions.items())
    sorted_subexpr_vars = topsort(subexpr_deps)

    if DEBUG:
        print 'SUBEXPRESSIONS:', subexpressions.keys()
    statements = []
    # all start as invalid
    valid = dict((name, False) for name in subexpressions.keys())
    # none are yet defined (or declared)
    subdefined = dict((name, False) for name in subexpressions.keys())
    for line in lines:
        stmt = line.statement
        # Only `read` and `will_write` are used in the loop body below
        read = line.read
        write = line.write
        will_read = line.will_read
        will_write = line.will_write
        # check that all subexpressions in expr are valid, and if not
        # add a definition/set its value, and set it to be valid
        # scan through in sorted order so that recursive subexpression dependencies
        # are handled in the right order
        for var in sorted_subexpr_vars:
            if var not in read:
                continue
            # if subexpression, and invalid
            if not valid.get(var, True):  # all non-subexpressions are valid
                subexpression = subexpressions[var]
                # if already defined/declared
                if subdefined[var]:
                    op = '='
                    constant = False
                else:
                    op = ':='
                    subdefined[var] = True
                    # set to constant only if we will not write to it again
                    constant = var not in will_write
                    # check all subvariables are not written to again as well
                    if constant:
                        ids = subexpression.identifiers
                        constant = all(v not in will_write for v in ids)
                valid[var] = True
                statement = Statement(var,
                                      op,
                                      subexpression.expr,
                                      comment='',
                                      dtype=variables[var].dtype,
                                      constant=constant,
                                      subexpression=True,
                                      scalar=variables[var].scalar)
                statements.append(statement)
        var, op, expr, comment = stmt.var, stmt.op, stmt.expr, stmt.comment
        # invalidate any subexpressions including var, recursively
        # we do this by having a set of variables that are invalid that we
        # start with the changed var and increase by any subexpressions we
        # find that have a dependency on something in the invalid set. We
        # go through in sorted subexpression order so that the invalid set
        # is increased in the right order
        invalid = {var}
        for subvar in sorted_subexpr_vars:
            spec = subexpressions[subvar]
            spec_ids = set(spec.identifiers)
            if spec_ids.intersection(invalid):
                valid[subvar] = False
                invalid.add(subvar)
        # constant only if we are declaring a new variable and we will not
        # write to it again
        constant = op == ':=' and var not in will_write
        statement = Statement(var,
                              op,
                              expr,
                              comment,
                              dtype=variables[var].dtype,
                              constant=constant,
                              scalar=variables[var].scalar)
        statements.append(statement)

    if DEBUG:
        print 'OUTPUT STATEMENTS:'
        for stmt in statements:
            print stmt

    scalar_statements = [s for s in statements if s.scalar]
    vector_statements = [s for s in statements if not s.scalar]

    if prefs.codegen.loop_invariant_optimisations:
        # Pull loop-invariant parts of the vector statements out into new
        # scalar constants
        scalar_constants, vector_statements = apply_loop_invariant_optimisations(
            vector_statements, variables, dtype)
        scalar_statements.extend(scalar_constants)

    return scalar_statements, vector_statements
コード例 #38
0
ファイル: exact.py プロジェクト: boddmg/brian2
    def __call__(self, equations, variables=None, simplify=True):
        '''
        Generate abstract code for the exact integration of a linear system of
        differential equations.

        The system is brought into the form ``dX/dt = M*X + B`` and advanced
        over one time step ``dt`` using the matrix exponential ``exp(M*dt)``.

        Parameters
        ----------
        equations : `Equations`
            The model equations (passed to ``get_linear_system``).
        variables : dict-like, optional
            The variables referred to in the equations. If it contains
            ``'dt'``, the matrix and the constant terms are checked with
            ``_check_for_locally_constant``.
        simplify : bool, optional
            Whether to simplify the matrix exponential. Defaults to ``True``.

        Returns
        -------
        code : str
            Abstract code that first stores the updated values in
            ``_<varname>`` and then copies them into the state variables. An
            empty string if there are no differential equations.
        '''
        if variables is None:
            variables = {}

        # Get a representation of the ODE system in the form of
        # dX/dt = M*X + B
        varnames, matrix, constants = get_linear_system(equations)

        # No differential equations, nothing to do (this occurs sometimes in the
        # test suite where the whole model is nothing more than something like
        # 'v : 1')
        if matrix.shape == (0, 0):
            return ''

        # Make sure that M and B are constant over a time step, i.e. they only
        # contain external variables or constant variables. This check needs
        # the value of dt and is therefore skipped if dt is unknown.
        if 'dt' in variables:
            dt_value = variables['dt'].get_value()[0]
            t = Symbol('t', real=True, positive=True)
            # This will raise an error if we meet the symbol "t" anywhere
            # except as an argument of a locally constant function
            for entry in itertools.chain(matrix, constants):
                _check_for_locally_constant(entry, variables, dt_value, t)

        # b solves M*b = B (augmented matrix [M | B]), so that the exact
        # update is X(t+dt) = exp(M*dt)*X(t) + (exp(M*dt) - 1)*b
        symbols = [Symbol(variable, real=True) for variable in varnames]
        solution = sp.solve_linear_system(matrix.row_join(constants), *symbols)
        b = sp.ImmutableMatrix([solution[symbol] for symbol in symbols]).transpose()

        # Solve the system
        dt = Symbol('dt', real=True, positive=True)
        A = (matrix * dt).exp()
        if simplify:
            A.simplify()
        C = sp.ImmutableMatrix([A.dot(b)]) - b
        _S = sp.MatrixSymbol('_S', len(varnames), 1)
        # C is constructed as a single row (assuming A.dot(b) yields one entry
        # per row of A), transpose it into a column to match A * _S
        updates = A * _S + C.transpose()
        try:
            # In sympy 0.7.3, we have to explicitly convert it to a single matrix
            # In sympy 0.7.2, it is already a matrix (which doesn't have an
            # as_explicit method)
            updates = updates.as_explicit()
        except AttributeError:
            pass

        # The solution contains _S[0, 0], _S[1, 0] etc. for the state variables,
        # replace them with the state variable names
        abstract_code = []
        for variable, update in zip(varnames, updates):
            rhs = update
            for row_idx, varname in enumerate(varnames):
                rhs = rhs.subs(_S[row_idx, 0], varname)

            # Do not overwrite the real state variables yet, the update step
            # of other state variables might still need the original values
            abstract_code.append('_' + variable + ' = ' + sympy_to_str(rhs))

        # Update the state variables
        for variable in varnames:
            abstract_code.append('{variable} = _{variable}'.format(variable=variable))
        return '\n'.join(abstract_code)
コード例 #39
0
ファイル: exact.py プロジェクト: brian-team/brian2
    def __call__(self, equations, variables=None, method_options=None):
        '''
        Generate abstract code for the exact integration of a linear system of
        differential equations.

        The system is brought into the form ``dX/dt = M*X + B`` and advanced
        over one time step ``dt`` using the matrix exponential ``exp(M*dt)``.

        Parameters
        ----------
        equations : `Equations`
            The model equations. Stochastic equations are not supported.
        variables : dict-like, optional
            The variables referred to in the equations; used to check that M
            and B are constant over a time step (``'dt'`` is read from here if
            present).
        method_options : dict, optional
            Validated by ``extract_method_options``. The ``'simplify'`` option
            (default ``True``) controls whether the entries of the matrix
            exponential are simplified.

        Returns
        -------
        code : str
            Abstract code that first stores the updated values in
            ``_<varname>`` and then copies them into the state variables. An
            empty string if there are no differential equations.

        Raises
        ------
        UnsupportedEquationsException
            If the equations are stochastic, M or B are not guaranteed to be
            constant over a time step, the system cannot be solved, or the
            solution contains complex values.
        '''
        method_options = extract_method_options(method_options,
                                                {'simplify': True})

        if equations.is_stochastic:
            raise UnsupportedEquationsException('Cannot solve stochastic '
                                                'equations with this state '
                                                'updater.')
        if variables is None:
            variables = {}

        # Get a representation of the ODE system in the form of
        # dX/dt = M*X + B
        varnames, matrix, constants = get_linear_system(equations, variables)

        # No differential equations, nothing to do (this occurs sometimes in the
        # test suite where the whole model is nothing more than something like
        # 'v : 1')
        if matrix.shape == (0, 0):
            return ''

        # Make sure that M and B are constant over a time step, i.e. they only
        # contain external variables or constant variables
        dt_value = variables['dt'].get_value()[0] if 'dt' in variables else None

        # This will raise an error if we meet the symbol "t" anywhere
        # except as an argument of a locally constant function
        for entry in itertools.chain(matrix, constants):
            if not is_constant_over_dt(entry, variables, dt_value):
                raise UnsupportedEquationsException(
                    ('Expression "{}" is not guaranteed to be constant over a '
                     'time step').format(sympy_to_str(entry)))

        # b solves M*b = B (augmented matrix [M | B]), so that the exact
        # update is X(t+dt) = exp(M*dt)*X(t) + (exp(M*dt) - 1)*b
        symbols = [Symbol(variable, real=True) for variable in varnames]
        solution = sp.solve_linear_system(matrix.row_join(constants), *symbols)
        if solution is None or set(symbols) != set(solution.keys()):
            raise UnsupportedEquationsException('Cannot solve the given '
                                                'equations with this '
                                                'stateupdater.')
        b = sp.ImmutableMatrix([solution[symbol] for symbol in symbols])

        # Solve the system
        dt = Symbol('dt', real=True, positive=True)
        try:
            A = (matrix * dt).exp()
        except NotImplementedError:
            raise UnsupportedEquationsException('Cannot solve the given '
                                                'equations with this '
                                                'stateupdater.')
        if method_options['simplify']:
            A = A.applyfunc(lambda x:
                            sp.factor_terms(sp.cancel(sp.signsimp(x))))
        C = sp.ImmutableMatrix(A * b) - b
        _S = sp.MatrixSymbol('_S', len(varnames), 1)
        updates = A * _S + C
        updates = updates.as_explicit()

        # The solution contains _S[0, 0], _S[1, 0] etc. for the state variables,
        # replace them with the state variable names
        abstract_code = []
        for variable, update in zip(varnames, updates):
            rhs = update
            if rhs.has(I, re, im):
                raise UnsupportedEquationsException('The solution to the linear system '
                                                    'contains complex values '
                                                    'which is currently not implemented.')
            for row_idx, varname in enumerate(varnames):
                rhs = rhs.subs(_S[row_idx, 0], varname)

            # Do not overwrite the real state variables yet, the update step
            # of other state variables might still need the original values
            abstract_code.append('_' + variable + ' = ' + sympy_to_str(rhs))

        # Update the state variables
        for variable in varnames:
            abstract_code.append('{variable} = _{variable}'.format(variable=variable))
        return '\n'.join(abstract_code)
コード例 #40
0
ファイル: test_parsing.py プロジェクト: treestreamymw/brian2
def test_sympy_infinity():
    # See github issue #1061: positive and negative infinity have to survive
    # a round trip through sympy unchanged
    for text in ['inf', '-inf']:
        assert sympy_to_str(str_to_sympy(text)) == text
コード例 #41
0
ファイル: exact.py プロジェクト: brian-team/brian2
    def __call__(self, equations, variables=None, method_options=None):
        '''
        Generate abstract code that integrates every differential equation on
        its own with ``sympy.dsolve`` (deprecated state updater).

        Each equation ``dv/dt = f(v, t)`` is solved symbolically, the
        integration constant ``C1`` is fixed by the value of ``v`` at the start
        of the time step, and the solution is evaluated ``dt`` later.

        Parameters
        ----------
        equations : `Equations`
            The model equations. Stochastic equations are not supported.
        variables : dict-like, optional
            Variables used when substituting and parsing the expressions.
        method_options : dict, optional
            Passed to ``extract_method_options`` with an empty set of
            defaults.

        Returns
        -------
        code : str
            One ``v = <solution>`` line per differential equation.

        Raises
        ------
        UnsupportedEquationsException
            If the equations are stochastic or cannot be solved explicitly, or
            if the solution has more than one integration constant or no
            unique value for ``C1``.
        '''
        logger.warn("The 'independent' state updater is deprecated and might be "
                    "removed in future versions of Brian.",
                    'deprecated_independent', once=True)
        method_options = extract_method_options(method_options, {})
        if equations.is_stochastic:
            raise UnsupportedEquationsException('Cannot solve stochastic '
                                                'equations with this state '
                                                'updater')
        variables = {} if variables is None else variables

        substituted = equations.get_substituted_expressions(variables)

        # Current time, time step and (arbitrary) start time of the solution
        t = Symbol('t', real=True, positive=True)
        dt = Symbol('dt', real=True, positive=True)
        t0 = Symbol('t0', real=True, positive=True)
        c1 = Symbol('C1')
        c2 = Symbol('C2')

        lines = []
        for varname, expression in substituted:
            rhs = str_to_sympy(expression.code, variables)
            # The real=True assumption is needed, otherwise sympy does not
            # consider the symbol a match to the one in the parsed expression
            state = Symbol(varname, real=True)
            func = sp.Function(varname)
            f_t = func(t)
            ode = sp.Eq(sp.Derivative(f_t, t), rhs.subs(state, f_t))
            # TODO: simplify=True sometimes fails with sympy 0.7.4, see:
            # https://github.com/sympy/sympy/issues/2666
            try:
                general = sp.dsolve(ode, f_t, simplify=True)
            except RuntimeError:
                general = sp.dsolve(ode, f_t, simplify=False)
            # Only explicit solutions of the form f(t) = ... can be used
            if not getattr(general, 'lhs', None) == f_t:
                raise UnsupportedEquationsException('Cannot explicitly solve: '
                                                    + str(ode))
            # Express the solution in terms of the value `state` at time t0
            if not general.has(c1):
                one_step = general.rhs.subs(t, t0).subs(func(t0), state)
            else:
                if general.has(c2):
                    raise UnsupportedEquationsException('Too many constants in solution: %s' % str(general))
                c1_candidates = sp.solve(general, c1)
                if len(c1_candidates) != 1:
                    raise UnsupportedEquationsException(("Couldn't solve for the constant "
                                                         "C1 in : %s ") % str(general))
                c1_value = c1_candidates[0].subs(t, t0).subs(func(t0), state)
                one_step = general.rhs.subs('C1', c1_value)
            # Advance by one time step: t -> t + dt, then t0 -> t
            one_step = one_step.subs(t, t + dt).subs(t0, t)
            # Simplification is optional -- it sometimes raises a ValueError
            try:
                one_step = one_step.simplify()
            except ValueError:
                pass

            lines.append(varname + ' = ' + sympy_to_str(one_step))

        return '\n'.join(lines)
コード例 #42
0
ファイル: exact.py プロジェクト: ttxtea/brian2
    def __call__(self, equations, variables=None):
        '''
        Generate abstract code that solves each differential equation
        independently and exactly with ``sympy.dsolve``.

        Parameters
        ----------
        equations : `Equations`
            The model equations. Stochastic equations are not supported.
        variables : dict-like, optional
            Variables used (via ``_non_constant_symbols``) to determine which
            symbols in the equations are not constant.

        Returns
        -------
        code : str
            One ``v = <solution>`` line per differential equation, giving the
            value at ``t + dt`` in terms of the value at ``t``.

        Raises
        ------
        ValueError
            If the equations are stochastic, refer to non-constant variables
            other than the equation's own variable, or cannot be solved
            explicitly, or if the solution has more than one integration
            constant or no unique value for ``C1``.
        '''
        if variables is None:
            variables = {}

        if equations.is_stochastic:
            raise ValueError('Cannot solve stochastic equations with this state updater')

        diff_eqs = equations.substituted_expressions

        t = Symbol('t', real=True, positive=True)
        dt = Symbol('dt', real=True, positive=True)
        t0 = Symbol('t0', real=True, positive=True)
        # TODO: Shortcut for simple linear equations? Is all this effort really
        #       worth it?

        code = []
        for name, expression in diff_eqs:
            rhs = expression.sympy_expr
            # The equation may only depend on its own variable and constants
            non_constant = (_non_constant_symbols(rhs.atoms(), variables, t) -
                            {name})
            if non_constant:
                raise ValueError(('Equation for %s referred to non-constant '
                                  'variables %s') % (name, str(non_constant)))
            # We have to be careful and use the real=True assumption as well,
            # otherwise sympy doesn't consider the symbol a match to the content
            # of the equation
            var = Symbol(name, real=True)
            f = sp.Function(name)
            rhs = rhs.subs(var, f(t))
            derivative = sp.Derivative(f(t), t)
            diff_eq = sp.Eq(derivative, rhs)
            # TODO: simplify=True sometimes fails with 0.7.4, see:
            # https://github.com/sympy/sympy/issues/2666
            try:
                general_solution = sp.dsolve(diff_eq, f(t), simplify=True)
            except RuntimeError:
                general_solution = sp.dsolve(diff_eq, f(t), simplify=False)
            # Check whether this is an explicit solution
            if not getattr(general_solution, 'lhs', None) == f(t):
                raise ValueError('Cannot explicitly solve: ' + str(diff_eq))
            # seems to happen sometimes in sympy 0.7.5
            if getattr(general_solution, 'rhs', None) == sp.nan:
                raise ValueError('Cannot explicitly solve: ' + str(diff_eq))
            # Solve for C1 (assuming "var" as the initial value and "t0" as time)
            if general_solution.has(Symbol('C1')):
                if general_solution.has(Symbol('C2')):
                    raise ValueError('Too many constants in solution: %s' % str(general_solution))
                constant_solution = sp.solve(general_solution, Symbol('C1'))
                if len(constant_solution) != 1:
                    raise ValueError(("Couldn't solve for the constant "
                                      "C1 in : %s ") % str(general_solution))
                constant = constant_solution[0].subs(t, t0).subs(f(t0), var)
                solution = general_solution.rhs.subs('C1', constant)
            else:
                solution = general_solution.rhs.subs(t, t0).subs(f(t0), var)
            # Evaluate the expression for one timestep
            solution = solution.subs(t, t + dt).subs(t0, t)
            # only try simplifying it -- it sometimes raises an error
            try:
                solution = solution.simplify()
            except ValueError:
                pass

            code.append(name + ' = ' + sympy_to_str(solution))

        return '\n'.join(code)
コード例 #43
0
ファイル: test_parsing.py プロジェクト: brian-team/brian2
 def evaluator(expr, ns):
     """Evaluate the sympy expression ``expr`` as Python code.

     The expression is evaluated in a copy of ``ns`` (the caller's mapping is
     left untouched) that additionally provides ``floor``, which is used to
     implement floor division.
     """
     code_string = sympy_to_str(expr)
     namespace = dict(ns, floor=DEFAULT_FUNCTIONS['floor'])
     return eval(code_string, namespace)
コード例 #44
0
ファイル: exact.py プロジェクト: francesconero/brian2
    def __call__(self, equations, variables=None):
        '''
        Generate abstract code that solves each differential equation
        independently and exactly with ``sympy.dsolve``.

        Parameters
        ----------
        equations : `Equations`
            The model equations. Stochastic equations are not supported.
        variables : dict-like, optional
            Variables used (via ``_non_constant_symbols``) to determine which
            symbols in the equations are not constant.

        Returns
        -------
        code : str
            One ``v = <solution>`` line per differential equation, giving the
            value at ``t + dt`` in terms of the value at ``t``.

        Raises
        ------
        ValueError
            If the equations are stochastic, refer to non-constant variables
            other than the equation's own variable, or cannot be solved
            explicitly (including a NaN right-hand side from ``dsolve``), or
            if the solution has more than one integration constant or no
            unique value for ``C1``.
        '''
        if variables is None:
            variables = {}

        if equations.is_stochastic:
            raise ValueError('Cannot solve stochastic equations with this state updater')

        diff_eqs = equations.substituted_expressions

        t = Symbol('t', real=True, positive=True)
        dt = Symbol('dt', real=True, positive=True)
        t0 = Symbol('t0', real=True, positive=True)
        # TODO: Shortcut for simple linear equations? Is all this effort really
        #       worth it?

        code = []
        for name, expression in diff_eqs:
            rhs = expression.sympy_expr
            # The equation may only depend on its own variable and constants
            # (set([...]) instead of a set literal: Python 2.6 is supported)
            non_constant = _non_constant_symbols(rhs.atoms(), variables, t) - set([name])
            if non_constant:
                raise ValueError(('Equation for %s referred to non-constant '
                                  'variables %s') % (name, str(non_constant)))
            # We have to be careful and use the real=True assumption as well,
            # otherwise sympy doesn't consider the symbol a match to the content
            # of the equation
            var = Symbol(name, real=True)
            f = sp.Function(name)
            rhs = rhs.subs(var, f(t))
            derivative = sp.Derivative(f(t), t)
            diff_eq = sp.Eq(derivative, rhs)
            # TODO: simplify=True sometimes fails with 0.7.4, see:
            # https://github.com/sympy/sympy/issues/2666
            try:
                general_solution = sp.dsolve(diff_eq, f(t), simplify=True)
            except (RuntimeError, AttributeError):  #AttributeError seems to be raised on Python 2.6
                general_solution = sp.dsolve(diff_eq, f(t), simplify=False)
            # Check whether this is an explicit solution
            if not getattr(general_solution, 'lhs', None) == f(t):
                raise ValueError('Cannot explicitly solve: ' + str(diff_eq))
            # dsolve has been seen to return NaN as the solution (sympy 0.7.5);
            # do not silently generate "v = nan"
            if getattr(general_solution, 'rhs', None) == sp.nan:
                raise ValueError('Cannot explicitly solve: ' + str(diff_eq))
            # Solve for C1 (assuming "var" as the initial value and "t0" as time).
            # Use Basic.has() to search the expression tree -- the ``in``
            # operator is not part of sympy's expression API.
            if general_solution.has(Symbol('C1')):
                if general_solution.has(Symbol('C2')):
                    raise ValueError('Too many constants in solution: %s' % str(general_solution))
                constant_solution = sp.solve(general_solution, Symbol('C1'))
                if len(constant_solution) != 1:
                    raise ValueError(("Couldn't solve for the constant "
                                      "C1 in : %s ") % str(general_solution))
                constant = constant_solution[0].subs(t, t0).subs(f(t0), var)
                solution = general_solution.rhs.subs('C1', constant)
            else:
                solution = general_solution.rhs.subs(t, t0).subs(f(t0), var)
            # Evaluate the expression for one timestep
            solution = solution.subs(t, t + dt).subs(t0, t)
            # only try simplifying it -- it sometimes raises an error
            try:
                solution = solution.simplify()
            except ValueError:
                pass

            code.append(name + ' = ' + sympy_to_str(solution))

        return '\n'.join(code)
コード例 #45
0
ファイル: translation.py プロジェクト: rcaze/brian2
def make_statements(code, variables, dtype, optimise=True):
    '''
    Turn a series of abstract code statements into Statement objects, inferring
    whether each line is a set/declare operation, whether the variables are
    constant or not, and handling the caching of subexpressions.

    Parameters
    ----------
    code : str
        A (multi-line) string of statements. Statements can be separated by
        newlines or by semicolons.
    variables : dict-like
        A dictionary with `Variable` and `Function` objects for every
        identifier used in the `code`. It is copied, the original is not
        modified.
    dtype : `dtype`
        The data type to use for temporary variables
    optimise : bool, optional
        Whether to optimise expressions, including pulling out loop invariant
        expressions and putting them in new scalar constants. Defaults to
        ``True``; even then, the optimisation is only applied if the
        `codegen.loop_invariant_optimisations` preference is set as well.
        Pass ``False`` in contexts that are not interested in this kind of
        optimisation.

    Returns
    -------
    scalar_statements, vector_statements : (list of `Statement`, list of `Statement`)
        Lists with statements that are to be executed once and statements that
        are to be executed once for every neuron/synapse/... (or in a vectorised
        way)

    Raises
    ------
    SyntaxError
        If a non-declaration write to a scalar variable follows a statement
        writing to a vector variable.

    Notes
    -----
    If ``optimise`` is ``True``, then the
    ``scalar_statements`` may include newly introduced scalar constants that
    have been identified as loop-invariant and have therefore been pulled out
    of the vector statements. The resulting statements will also use augmented
    assignments where possible, i.e. a statement such as ``w = w + 1`` will be
    replaced by ``w += 1``. Also, statements involving booleans will have
    additional information added to them (see `Statement` for details)
    describing how the statement can be reformulated as a sequence of if/then
    statements. Calls `~brian2.codegen.optimisation.optimise_statements`.
    '''
    code = strip_empty_lines(deindent(code))
    # Statements may be separated by newlines or by semicolons
    lines = re.split(r'[;\n]', code)
    lines = [LineInfo(code=line) for line in lines if len(line)]
    # Do a copy so we can add stuff without altering the original dict
    variables = dict(variables)
    # we will do inference to work out which lines are := and which are =
    defined = set(k for k, v in variables.iteritems()
                  if not isinstance(v, AuxiliaryVariable))
    for line in lines:
        statement = None
        # parse statement into "var op expr"
        var, op, expr, comment = parse_statement(line.code)
        if op == '=':
            if var not in defined:
                # First assignment to this name: it becomes a declaration
                op = ':='
                defined.add(var)
                if var not in variables:
                    is_scalar = is_scalar_expression(expr, variables)
                    new_var = AuxiliaryVariable(var, Unit(1), # doesn't matter here
                                                dtype=dtype, scalar=is_scalar)
                    variables[var] = new_var
            elif not variables[var].is_boolean:
                # Try to rewrite "var = var + x" as "var += x" and
                # "var = x * var" as "var *= x"
                sympy_expr = str_to_sympy(expr, variables)
                sympy_var = sympy.Symbol(var, real=True)
                try:
                    collected = sympy.collect(sympy_expr, sympy_var,
                                              exact=True, evaluate=False)
                except AttributeError:
                    # If something goes wrong during collection, e.g. collect
                    # does not work for logical expressions
                    collected = {1: sympy_expr}

                if (len(collected) == 2 and
                        set(collected.keys()) == {1, sympy_var} and
                        collected[sympy_var] == 1):
                    # We can replace this statement by a += assignment
                    statement = Statement(var, '+=',
                                          sympy_to_str(collected[1]),
                                          comment,
                                          dtype=variables[var].dtype,
                                          scalar=variables[var].scalar)
                elif len(collected) == 1 and sympy_var in collected:
                    # We can replace this statement by a *= assignment
                    statement = Statement(var, '*=',
                                          sympy_to_str(collected[sympy_var]),
                                          comment,
                                          dtype=variables[var].dtype,
                                          scalar=variables[var].scalar)
        if statement is None:
            statement = Statement(var, op, expr, comment,
                                  dtype=variables[var].dtype,
                                  scalar=variables[var].scalar)

        line.statement = statement
        # for each line will give the variable being written to
        line.write = var
        # each line will give a set of variables which are read
        line.read = get_identifiers_recursively([expr], variables)

    # All writes to scalar variables must happen before writes to vector
    # variables
    scalar_write_done = False
    for line in lines:
        stmt = line.statement
        if stmt.op != ':=' and variables[stmt.var].scalar and scalar_write_done:
            raise SyntaxError(('All writes to scalar variables in a code block '
                               'have to be made before writes to vector '
                               'variables. Illegal write to %s.') % line.write)
        elif not variables[stmt.var].scalar:
            scalar_write_done = True

    # backwards compute whether or not variables will be read again
    # note that will_read for a line gives the set of variables it will read
    # on the current line or subsequent ones. will_write gives the set of
    # variables that will be written after the current line
    will_read = set()
    will_write = set()
    for line in lines[::-1]:
        will_read = will_read.union(line.read)
        line.will_read = will_read.copy()
        line.will_write = will_write.copy()
        will_write.add(line.write)

    # generate caching statements for common subexpressions
    # cached subexpressions need to be recomputed whenever they are to be used
    # on the next line, and currently invalid (meaning that the current value
    # stored in the subexpression variable is no longer accurate because one
    # of the variables appearing in it has changed). All subexpressions start
    # as invalid, and are invalidated whenever one of the variables appearing
    # in the RHS changes value.
    subexpressions = dict((name, val) for name, val in variables.items() if isinstance(val, Subexpression))
    # sort subexpressions into an order so that subexpressions that don't depend
    # on other subexpressions are first
    subexpr_deps = dict((name, [dep for dep in subexpr.identifiers if dep in subexpressions]) for \
                                                            name, subexpr in subexpressions.items())
    sorted_subexpr_vars = topsort(subexpr_deps)

    statements = []
    # all start as invalid
    valid = dict((name, False) for name in subexpressions.keys())
    # none are yet defined (or declared)
    subdefined = dict((name, False) for name in subexpressions.keys())
    for line in lines:
        stmt = line.statement
        read = line.read
        will_write = line.will_write
        # check that all subexpressions in expr are valid, and if not
        # add a definition/set its value, and set it to be valid
        # scan through in sorted order so that recursive subexpression dependencies
        # are handled in the right order
        for var in sorted_subexpr_vars:
            if var not in read:
                continue
            # if subexpression, and invalid
            if not valid.get(var, True): # all non-subexpressions are valid
                subexpression = subexpressions[var]
                # if already defined/declared
                if subdefined[var]:
                    op = '='
                    constant = False
                else:
                    op = ':='
                    subdefined[var] = True
                    # set to constant only if we will not write to it again
                    constant = var not in will_write
                    # check all subvariables are not written to again as well
                    if constant:
                        ids = subexpression.identifiers
                        constant = all(v not in will_write for v in ids)
                valid[var] = True
                statement = Statement(var, op, subexpression.expr, comment='',
                                      dtype=variables[var].dtype,
                                      constant=constant,
                                      subexpression=True,
                                      scalar=variables[var].scalar)
                statements.append(statement)
        var, op, expr, comment = stmt.var, stmt.op, stmt.expr, stmt.comment
        # invalidate any subexpressions including var, recursively
        # we do this by having a set of variables that are invalid that we
        # start with the changed var and increase by any subexpressions we
        # find that have a dependency on something in the invalid set. We
        # go through in sorted subexpression order so that the invalid set
        # is increased in the right order
        invalid = {var}
        for subvar in sorted_subexpr_vars:
            spec = subexpressions[subvar]
            spec_ids = set(spec.identifiers)
            if spec_ids.intersection(invalid):
                valid[subvar] = False
                invalid.add(subvar)
        # constant only if we are declaring a new variable and we will not
        # write to it again
        constant = op == ':=' and var not in will_write
        statement = Statement(var, op, expr, comment,
                              dtype=variables[var].dtype,
                              constant=constant,
                              scalar=variables[var].scalar)
        statements.append(statement)

    scalar_statements = [s for s in statements if s.scalar]
    vector_statements = [s for s in statements if not s.scalar]

    if optimise and prefs.codegen.loop_invariant_optimisations:
        scalar_statements, vector_statements = optimise_statements(scalar_statements,
                                                                   vector_statements,
                                                                   variables)

    return scalar_statements, vector_statements
コード例 #46
0
ファイル: explicit.py プロジェクト: yger/brian2
    def _generate_RHS(self, eqs, var, eq_symbols, temp_vars, expr,
                      non_stochastic_expr, stochastic_expr):
        '''
        Helper function used in `__call__`. Generates the right hand side of
        an abstract code statement by appropriately replacing f, g and t.
        For example, given a differential equation ``dv/dt = (-v + I) / tau``
        (i.e. `var` is ``v`` and `expr` is ``(-v + I) / tau``) together with
        the `rk2` step ``return x + dt*f(x +  k/2, t + dt/2)``
        (i.e. `non_stochastic_expr` is
        ``x + dt*f(x +  k/2, t + dt/2)`` and `stochastic_expr` is ``None``),
        produces ``v + dt*(-v - _k_v/2 + I + _k_I/2)/tau``.

        Parameters
        ----------
        eqs : object
            Not used by this method; kept so the call signature stays stable.
        var : str
            Name of the state variable the statement is generated for.
        eq_symbols : dict
            Maps state variable names to their sympy symbols (presumably all
            state variables of the equations -- TODO confirm against caller).
        temp_vars : iterable of str
            Names of the intermediate variables of the state updater
            description (e.g. ``k``); each must be a key of ``self.symbols``.
            It is iterated several times, so it must not be a one-shot
            iterator.
        expr : object
            The right-hand side of the equation for `var`; it must provide
            ``split_stochastic()``, returning the non-stochastic part and a
            mapping from each stochastic variable to its factor (either may be
            ``None``).
        non_stochastic_expr, stochastic_expr : sympy expression or None
            The non-stochastic (using ``f``) and stochastic (using ``g`` and
            ``dW``) part of the state updater line.

        Returns
        -------
        RHS : str
            The non-stochastic part followed by all stochastic parts, each
            converted to a string and joined with ``' + '``. Empty if no part
            applies.
        '''

        def var_specific_temp_vars(state_var):
            # Map each generic temporary variable symbol to its counterpart
            # for the given state variable, e.g. 'k' -> '_k_v' for 'v'
            return dict((self.symbols[temp_var],
                         _symbol('_' + temp_var + '_' + state_var))
                        for temp_var in temp_vars)

        def replace_func(x, t, func_expr):
            '''
            Used to replace a single occurrence of ``f(x, t)`` or ``g(x, t)``:
            `func_expr` is the non-stochastic (in the case of ``f``) or
            stochastic part (``g``) of the expression defining the
            right-hand-side of the differential equation. Every state variable
            is replaced with the value given as `x` and the symbol ``t`` by the
            value given as `t`. Intermediate variables are replaced with their
            state-variable-specific counterparts as well.

            For example, in the `rk2` integrator, the second step involves the
            calculation of ``f(k/2 + x, dt/2 + t)``.  If the state variable is
            ``v`` and `func_expr` is ``-v / tau``, this will result in
            ``-(_k_v/2 + v)/tau``.
            '''
            try:
                s_expr = str_to_sympy(str(func_expr))
            except SympifyError as ex:
                raise ValueError('Error parsing the expression "%s": %s' %
                                 (func_expr, str(ex)))

            for state_var in eq_symbols:
                # In the expression given as 'x', replace 'x' by the state
                # variable and all the temporary variables by their
                # variable-specific counterparts.
                x_replacement = x.subs(self.symbols['x'], eq_symbols[state_var])
                x_replacement = x_replacement.subs(var_specific_temp_vars(state_var))

                # Replace the state variable in the expression by the new `x`
                # expression
                s_expr = s_expr.subs(eq_symbols[state_var], x_replacement)

            # Directly substitute the 't' expression for the symbol t, there
            # are no temporary variables to consider here.
            return s_expr.subs(self.symbols['t'], t)

        # Note: in the following we are silently ignoring the case that a
        # state updater does not care about either the non-stochastic or the
        # stochastic part of an equation. We do trust state updaters to
        # correctly specify their own abilities (i.e. they do not claim to
        # support stochastic equations but actually just ignore the stochastic
        # part). We can't really check the issue here, as we are only dealing
        # with one line of the state updater description. It is perfectly valid
        # to write the euler update as:
        #     non_stochastic = dt * f(x, t)
        #     stochastic = dt**.5 * g(x, t) * xi
        #     return x + non_stochastic + stochastic
        #
        # In the above case, we'll deal with lines which do not define either
        # the stochastic or the non-stochastic part.

        non_stochastic, stochastic = expr.split_stochastic()

        # We do have a non-stochastic part in our equation and in the state
        # updater description
        if non_stochastic is not None and non_stochastic_expr is not None:
            # Replace the f(x, t) part
            replace_f = lambda x, t: replace_func(x, t, non_stochastic)
            non_stochastic_result = non_stochastic_expr.replace(self.symbols['f'],
                                                                replace_f)
            # Replace x by the respective variable
            non_stochastic_result = non_stochastic_result.subs(self.symbols['x'],
                                                               eq_symbols[var])
            # Replace intermediate variables
            non_stochastic_result = non_stochastic_result.subs(var_specific_temp_vars(var))
        else:
            non_stochastic_result = None

        # We do have a stochastic part in our equation and in the state updater
        # description
        stochastic_results = []
        if stochastic is not None and stochastic_expr is not None:
            # Same for every stochastic variable, so compute it only once
            temp_var_replacements = var_specific_temp_vars(var)
            # We potentially have more than one stochastic variable
            for xi in stochastic:
                g_expr = stochastic[xi]
                # Replace the g(x, t)*xi part. The lambda is consumed by
                # `replace` within this iteration, so closing over g_expr is safe.
                replace_g = lambda x, t: replace_func(x, t, g_expr)
                stochastic_result = stochastic_expr.replace(self.symbols['g'],
                                                            replace_g)

                # Replace x and dW by the respective variables
                stochastic_result = stochastic_result.subs(self.symbols['x'],
                                                           eq_symbols[var])
                stochastic_result = stochastic_result.subs(self.symbols['dW'], xi)

                # Replace intermediate variables
                stochastic_result = stochastic_result.subs(temp_var_replacements)

                stochastic_results.append(stochastic_result)

        # All the parts (one non-stochastic and potentially more than one
        # stochastic part) are combined with addition
        rhs_parts = []
        if non_stochastic_result is not None:
            rhs_parts.append(sympy_to_str(non_stochastic_result))
        rhs_parts.extend(sympy_to_str(result) for result in stochastic_results)

        return ' + '.join(rhs_parts)
コード例 #47
0
def make_statements(code, variables, dtype, optimise=True, blockname=''):
    '''
    Turn a series of abstract code statements into Statement objects, inferring
    whether each line is a set/declare operation, whether the variables are
    constant or not, and inserting the definitions of the subexpressions that
    each statement reads.

    Parameters
    ----------
    code : str
        A (multi-line) string of statements. Statements can be separated by
        newlines or by semicolons.
    variables : dict-like
        A dictionary with `Variable` and `Function` objects for every
        identifier used in the `code`. It is copied, the original is not
        modified.
    dtype : `dtype`
        The data type to use for temporary variables
    optimise : bool, optional
        Whether to optimise expressions, including pulling out loop invariant
        expressions and putting them in new scalar constants. Defaults to
        ``True``; even then, the optimisation is only applied if the
        `codegen.loop_invariant_optimisations` preference is set as well.
        Pass ``False`` in contexts that are not interested in this kind of
        optimisation.
    blockname : str, optional
        A name for the block (used to name intermediate variables to avoid
        name clashes when multiple blocks are used together)

    Returns
    -------
    scalar_statements, vector_statements : (list of `Statement`, list of `Statement`)
        Lists with statements that are to be executed once and statements that
        are to be executed once for every neuron/synapse/... (or in a vectorised
        way)

    Raises
    ------
    SyntaxError
        If a non-declaration write to a scalar variable follows a statement
        writing to a vector variable.

    Notes
    -----
    Subexpressions are not cached between statements: every subexpression a
    statement reads is (re-)computed directly before that statement.

    If ``optimise`` is ``True``, then the
    ``scalar_statements`` may include newly introduced scalar constants that
    have been identified as loop-invariant and have therefore been pulled out
    of the vector statements. The resulting statements will also use augmented
    assignments where possible, i.e. a statement such as ``w = w + 1`` will be
    replaced by ``w += 1``. Also, statements involving booleans will have
    additional information added to them (see `Statement` for details)
    describing how the statement can be reformulated as a sequence of if/then
    statements. Calls `~brian2.codegen.optimisation.optimise_statements`.
    '''
    code = strip_empty_lines(deindent(code))
    # Statements may be separated by newlines or by semicolons
    lines = re.split(r'[;\n]', code)
    lines = [LineInfo(code=line) for line in lines if len(line)]
    # Do a copy so we can add stuff without altering the original dict
    variables = dict(variables)
    # we will do inference to work out which lines are := and which are =
    defined = set(k for k, v in variables.iteritems()
                  if not isinstance(v, AuxiliaryVariable))
    for line in lines:
        statement = None
        # parse statement into "var op expr"
        var, op, expr, comment = parse_statement(line.code)
        if op == '=':
            if var not in defined:
                # First assignment to this name: it becomes a declaration
                op = ':='
                defined.add(var)
                if var not in variables:
                    is_scalar = is_scalar_expression(expr, variables)
                    new_var = AuxiliaryVariable(var,
                                                dtype=dtype,
                                                scalar=is_scalar)
                    variables[var] = new_var
            elif not variables[var].is_boolean:
                # Try to rewrite "var = var + x" as "var += x" and
                # "var = x * var" as "var *= x"
                sympy_expr = str_to_sympy(expr, variables)
                sympy_var = sympy.Symbol(var, real=True)
                try:
                    collected = sympy.collect(sympy_expr,
                                              sympy_var,
                                              exact=True,
                                              evaluate=False)
                except AttributeError:
                    # If something goes wrong during collection, e.g. collect
                    # does not work for logical expressions
                    collected = {1: sympy_expr}

                if (len(collected) == 2
                        and set(collected.keys()) == {1, sympy_var}
                        and collected[sympy_var] == 1):
                    # We can replace this statement by a += assignment
                    statement = Statement(var,
                                          '+=',
                                          sympy_to_str(collected[1]),
                                          comment,
                                          dtype=variables[var].dtype,
                                          scalar=variables[var].scalar)
                elif len(collected) == 1 and sympy_var in collected:
                    # We can replace this statement by a *= assignment
                    statement = Statement(var,
                                          '*=',
                                          sympy_to_str(collected[sympy_var]),
                                          comment,
                                          dtype=variables[var].dtype,
                                          scalar=variables[var].scalar)
        if statement is None:
            statement = Statement(var,
                                  op,
                                  expr,
                                  comment,
                                  dtype=variables[var].dtype,
                                  scalar=variables[var].scalar)

        line.statement = statement
        # for each line will give the variable being written to
        line.write = var
        # each line will give a set of variables which are read
        line.read = get_identifiers_recursively([expr], variables)

    # All writes to scalar variables must happen before writes to vector
    # variables
    scalar_write_done = False
    for line in lines:
        stmt = line.statement
        if (stmt.op != ':=' and variables[stmt.var].scalar
                and scalar_write_done):
            raise SyntaxError(
                ('All writes to scalar variables in a code block '
                 'have to be made before writes to vector '
                 'variables. Illegal write to %s.') % line.write)
        elif not variables[stmt.var].scalar:
            scalar_write_done = True

    # backwards compute whether or not variables will be read again
    # note that will_read for a line gives the set of variables it will read
    # on the current line or subsequent ones. will_write gives the set of
    # variables that will be written after the current line
    will_read = set()
    will_write = set()
    for line in lines[::-1]:
        will_read = will_read.union(line.read)
        line.will_read = will_read.copy()
        line.will_write = will_write.copy()
        will_write.add(line.write)

    subexpressions = dict((name, val) for name, val in variables.items()
                          if isinstance(val, Subexpression))
    # sort subexpressions into an order so that subexpressions that don't depend
    # on other subexpressions are first
    subexpr_deps = dict((name, [dep for dep in subexpr.identifiers if dep in subexpressions]) for \
                                                            name, subexpr in subexpressions.items())
    sorted_subexpr_vars = topsort(subexpr_deps)

    statements = []

    # none are yet defined (or declared)
    subdefined = dict((name, False) for name in subexpressions.keys())
    for line in lines:
        stmt = line.statement
        read = line.read
        will_write = line.will_write
        # update/define all subexpressions needed by this statement
        for var in sorted_subexpr_vars:
            if var not in read:
                continue

            subexpression = subexpressions[var]
            # if already defined/declared
            if subdefined[var]:
                op = '='
                constant = False
            else:
                op = ':='
                subdefined[var] = True
                # set to constant only if we will not write to it again
                constant = var not in will_write
                # check all subvariables are not written to again as well
                if constant:
                    ids = subexpression.identifiers
                    constant = all(v not in will_write for v in ids)

            statement = Statement(var,
                                  op,
                                  subexpression.expr,
                                  comment='',
                                  dtype=variables[var].dtype,
                                  constant=constant,
                                  subexpression=True,
                                  scalar=variables[var].scalar)
            statements.append(statement)

        var, op, expr, comment = stmt.var, stmt.op, stmt.expr, stmt.comment

        # constant only if we are declaring a new variable and we will not
        # write to it again
        constant = op == ':=' and var not in will_write
        statement = Statement(var,
                              op,
                              expr,
                              comment,
                              dtype=variables[var].dtype,
                              constant=constant,
                              scalar=variables[var].scalar)
        statements.append(statement)

    scalar_statements = [s for s in statements if s.scalar]
    vector_statements = [s for s in statements if not s.scalar]

    if optimise and prefs.codegen.loop_invariant_optimisations:
        scalar_statements, vector_statements = optimise_statements(
            scalar_statements,
            vector_statements,
            variables,
            blockname=blockname)

    return scalar_statements, vector_statements
コード例 #48
0
ファイル: spatialneuron.py プロジェクト: rsantana-isg/brian2
    def __init__(self,
                 morphology=None,
                 model=None,
                 threshold=None,
                 refractory=False,
                 reset=None,
                 events=None,
                 threshold_location=None,
                 dt=None,
                 clock=None,
                 order=0,
                 Cm=0.9 * uF / cm**2,
                 Ri=150 * ohm * cm,
                 name='spatialneuron*',
                 dtype=None,
                 namespace=None,
                 method=('linear', 'exponential_euler', 'rk2', 'heun')):

        # #### Prepare and validate equations
        if isinstance(model, basestring):
            model = Equations(model)
        if not isinstance(model, Equations):
            raise TypeError(('model has to be a string or an Equations '
                             'object, is "%s" instead.') % type(model))

        # Insert the threshold mechanism at the specified location
        if threshold_location is not None:
            if hasattr(threshold_location,
                       '_indices'):  # assuming this is a method
                threshold_location = threshold_location._indices()
                # for now, only a single compartment allowed
                if len(threshold_location) == 1:
                    threshold_location = threshold_location[0]
                else:
                    raise AttributeError(('Threshold can only be applied on a '
                                          'single location'))
            threshold = '(' + threshold + ') and (i == ' + str(
                threshold_location) + ')'

        # Check flags (we have point currents)
        model.check_flags({
            DIFFERENTIAL_EQUATION: ('point current', ),
            PARAMETER: ('constant', 'shared', 'linked', 'point current'),
            SUBEXPRESSION: ('shared', 'point current', 'constant over dt')
        })
        #: The original equations as specified by the user (i.e. before
        #: inserting point-currents into the membrane equation, before adding
        #: all the internally used variables and constants, etc.).
        self.user_equations = model

        # Separate subexpressions depending whether they are considered to be
        # constant over a time step or not (this would also be done by the
        # NeuronGroup initializer later, but this would give incorrect results
        # for the linearity check)
        model, constant_over_dt = extract_constant_subexpressions(model)

        # Extract membrane equation
        if 'Im' in model:
            if len(model['Im'].flags):
                raise TypeError(
                    'Cannot specify any flags for the transmembrane '
                    'current Im.')
            membrane_expr = model['Im'].expr  # the membrane equation
        else:
            raise TypeError('The transmembrane current Im must be defined')

        model_equations = []
        # Insert point currents in the membrane equation
        for eq in model.itervalues():
            if eq.varname == 'Im':
                continue  # ignore -- handled separately
            if 'point current' in eq.flags:
                fail_for_dimension_mismatch(
                    eq.dim, amp,
                    "Point current " + eq.varname + " should be in amp")
                membrane_expr = Expression(
                    str(membrane_expr.code) + '+' + eq.varname + '/area')
                eq = SingleEquation(
                    eq.type,
                    eq.varname,
                    eq.dim,
                    expr=eq.expr,
                    flags=list(set(eq.flags) - set(['point current'])))
            model_equations.append(eq)

        model_equations.append(
            SingleEquation(SUBEXPRESSION,
                           'Im',
                           dimensions=(amp / meter**2).dim,
                           expr=membrane_expr))
        model_equations.append(SingleEquation(PARAMETER, 'v', volt.dim))
        model = Equations(model_equations)

        ###### Process model equations (Im) to extract total conductance and the remaining current
        # Expand expressions in the membrane equation
        for var, expr in model.get_substituted_expressions(
                include_subexpressions=True):
            if var == 'Im':
                Im_expr = expr
                break
        else:
            raise AssertionError('Model equations did not contain Im!')

        # Differentiate Im with respect to v
        Im_sympy_exp = str_to_sympy(Im_expr.code)
        v_sympy = sp.Symbol('v', real=True)
        diffed = sp.diff(Im_sympy_exp, v_sympy)

        unevaled_derivatives = diffed.atoms(sp.Derivative)
        if len(unevaled_derivatives):
            raise TypeError(
                'Cannot take the derivative of "{Im}" with respect '
                'to v.'.format(Im=Im_expr.code))

        gtot_str = sympy_to_str(sp.simplify(-diffed))
        I0_str = sympy_to_str(sp.simplify(Im_sympy_exp - diffed * v_sympy))

        if gtot_str == '0':
            gtot_str += '*siemens/meter**2'
        if I0_str == '0':
            I0_str += '*amp/meter**2'
        gtot_str = "gtot__private=" + gtot_str + ": siemens/meter**2"
        I0_str = "I0__private=" + I0_str + ": amp/meter**2"

        model += Equations(gtot_str + "\n" + I0_str)

        # Insert morphology (store a copy)
        self.morphology = copy.deepcopy(morphology)

        # Flatten the morphology
        self.flat_morphology = FlatMorphology(morphology)

        # Equations for morphology
        # TODO: check whether Cm and Ri are already in the equations
        #       no: should be shared instead of constant
        #       yes: should be constant (check)
        eqs_constants = Equations("""
        length : meter (constant)
        distance : meter (constant)
        area : meter**2 (constant)
        volume : meter**3
        Ic : amp/meter**2
        diameter : meter (constant)
        Cm : farad/meter**2 (constant)
        Ri : ohm*meter (constant, shared)
        r_length_1 : meter (constant)
        r_length_2 : meter (constant)
        time_constant = Cm/gtot__private : second
        space_constant = (2/pi)**(1.0/3.0) * (area/(1/r_length_1 + 1/r_length_2))**(1.0/6.0) /
                         (2*(Ri*gtot__private)**(1.0/2.0)) : meter
        """)
        if self.flat_morphology.has_coordinates:
            eqs_constants += Equations('''
            x : meter (constant)
            y : meter (constant)
            z : meter (constant)
            ''')

        NeuronGroup.__init__(self,
                             morphology.total_compartments,
                             model=model + eqs_constants,
                             threshold=threshold,
                             refractory=refractory,
                             reset=reset,
                             events=events,
                             method=method,
                             dt=dt,
                             clock=clock,
                             order=order,
                             namespace=namespace,
                             dtype=dtype,
                             name=name)
        # Parameters and intermediate variables for solving the cable equations
        # Note that some of these variables could have meaningful physical
        # units (e.g. _v_star is in volt, _I0_all is in amp/meter**2 etc.) but
        # since these variables should never be used in user code, we don't
        # assign them any units
        self.variables.add_arrays(
            [
                '_ab_star0',
                '_ab_star1',
                '_ab_star2',
                '_a_minus0',
                '_a_minus1',
                '_a_minus2',
                '_a_plus0',
                '_a_plus1',
                '_a_plus2',
                '_b_plus',
                '_b_minus',
                '_v_star',
                '_u_plus',
                '_u_minus',
                '_v_previous',
                # The following three are for solving the
                # three tridiag systems in parallel
                '_c1',
                '_c2',
                '_c3',
                # The following two are only necessary for
                # C code where we cannot deal with scalars
                # and arrays interchangeably:
                '_I0_all',
                '_gtot_all'
            ],
            size=self.N,
            read_only=True)

        self.Cm = Cm
        self.Ri = Ri
        # These explict assignments will load the morphology values from disk
        # in standalone mode
        self.distance_ = self.flat_morphology.distance
        self.length_ = self.flat_morphology.length
        self.area_ = self.flat_morphology.area
        self.diameter_ = self.flat_morphology.diameter
        self.r_length_1_ = self.flat_morphology.r_length_1
        self.r_length_2_ = self.flat_morphology.r_length_2
        if self.flat_morphology.has_coordinates:
            self.x_ = self.flat_morphology.x
            self.y_ = self.flat_morphology.y
            self.z_ = self.flat_morphology.z

        # Performs numerical integration step
        self.add_attribute('diffusion_state_updater')
        self.diffusion_state_updater = SpatialStateUpdater(self,
                                                           method,
                                                           clock=self.clock,
                                                           order=order)

        # Update v after the gating variables to obtain consistent Ic and Im
        self.diffusion_state_updater.order = 1

        # Creation of contained_objects that do the work
        self.contained_objects.extend([self.diffusion_state_updater])

        if len(constant_over_dt):
            self.subexpression_updater = SubexpressionUpdater(
                self, constant_over_dt)
            self.contained_objects.append(self.subexpression_updater)
コード例 #49
0
    def _generate_RHS(self, eqs, var, eq_symbols, temp_vars, expr,
                      non_stochastic_expr, stochastic_expr):
        '''
        Helper function used in `__call__`. Generates the right hand side of
        an abstract code statement by appropriately replacing f, g and t.
        For example, given a differential equation ``dv/dt = (-v + I) / tau``
        (i.e. `var` is ``v`` and `expr` is ``(-v + I) / tau``) together with
        the `rk2` step ``return x + dt*f(x +  k/2, t + dt/2)``
        (i.e. `non_stochastic_expr` is
        ``x + dt*f(x +  k/2, t + dt/2)`` and `stochastic_expr` is ``None``),
        produces ``v + dt*(-v - _k_v/2 + I + _k_I/2)/tau``.

        Parameters
        ----------
        eqs
            Not used in this method's body.
        var : str
            Name of the state variable whose update is being generated.
        eq_symbols : dict
            Maps state variable names to the sympy symbols used for them.
        temp_vars : iterable of str
            Names of the temporary variables of the state updater
            description; each is looked up in ``self.symbols``.
        expr
            Right-hand side of the differential equation for `var`; it is
            split via ``expr.split_stochastic()`` into a non-stochastic part
            and (presumably) a mapping from noise symbols to their
            coefficient expressions.
        non_stochastic_expr, stochastic_expr
            Sympy expressions for the current line of the state updater
            description (either may be ``None``).

        Returns
        -------
        str
            The non-stochastic part and every stochastic part, converted to
            strings and joined with ``' + '`` (an empty string if neither
            part applies).

        Raises
        ------
        ValueError
            If a part of `expr` cannot be parsed by ``str_to_sympy``.
        '''
        def temp_var_replacements(state_var):
            # Map each generic temporary variable of the description (e.g.
            # ``k``) to the symbol named ``temp_var + '_' + state_var``.
            return {self.symbols[temp_var]: _symbol(temp_var + '_' + state_var)
                    for temp_var in temp_vars}

        def replace_func(x, t, part_expr):
            '''
            Used to replace a single occurrence of ``f(x, t)`` or ``g(x, t)``:
            `part_expr` is the non-stochastic (in the case of ``f``) or
            stochastic part (``g``) of the expression defining the
            right-hand side of the differential equation. Every state
            variable in `eq_symbols` is replaced by the value given as `x`,
            and ``t`` by the value given as `t`. Temporary variables in `x`
            are replaced by their variable-specific counterparts.

            For example, in the `rk2` integrator, the second step involves
            the calculation of ``f(k/2 + x, dt/2 + t)``. If the state
            variable is ``v`` and `part_expr` is ``-v / tau``, this results
            in ``-(_k_v/2 + v)/tau``.
            '''
            try:
                s_expr = str_to_sympy(str(part_expr))
            except SympifyError as ex:
                raise ValueError('Error parsing the expression "%s": %s' %
                                 (part_expr, str(ex)))

            for state_var in eq_symbols:
                # In the expression given as 'x', replace 'x' by the state
                # variable and all temporary variables by their
                # variable-specific counterparts ...
                x_replacement = x.subs(self.symbols['__x'],
                                       eq_symbols[state_var])
                x_replacement = x_replacement.subs(
                    temp_var_replacements(state_var))
                # ... then substitute that for the state variable itself.
                s_expr = s_expr.subs(eq_symbols[state_var], x_replacement)

            # Directly substitute the 't' expression for the symbol t, there
            # are no temporary variables to consider here.
            return s_expr.subs(self.symbols['__t'], t)

        # Note: in the following we are silently ignoring the case that a
        # state updater does not care about either the non-stochastic or the
        # stochastic part of an equation. We do trust state updaters to
        # correctly specify their own abilities (i.e. they do not claim to
        # support stochastic equations but actually just ignore the stochastic
        # part). We can't really check the issue here, as we are only dealing
        # with one line of the state updater description. It is perfectly valid
        # to write the euler update as:
        #     non_stochastic = dt * f(x, t)
        #     stochastic = dt**.5 * g(x, t) * xi
        #     return x + non_stochastic + stochastic
        #
        # In the above case, we'll deal with lines which do not define either
        # the stochastic or the non-stochastic part.

        non_stochastic, stochastic = expr.split_stochastic()

        # Only handle the non-stochastic part if both the equation and the
        # state updater description have one.
        non_stochastic_result = None
        if non_stochastic is not None and non_stochastic_expr is not None:
            # Replace the f(x, t) part
            non_stochastic_result = non_stochastic_expr.replace(
                self.symbols['__f'],
                lambda x, t: replace_func(x, t, non_stochastic))
            # Replace x by the respective variable, then the intermediate
            # variables by their variable-specific counterparts
            non_stochastic_result = non_stochastic_result.subs(
                self.symbols['__x'], eq_symbols[var])
            non_stochastic_result = non_stochastic_result.subs(
                temp_var_replacements(var))

        # Same for the stochastic part; there may be several stochastic
        # variables, each contributing one term.
        stochastic_results = []
        if stochastic is not None and stochastic_expr is not None:
            for xi in stochastic:
                # ``replace`` evaluates the callable during this call, so the
                # lambda sees the current ``xi``.
                stochastic_result = stochastic_expr.replace(
                    self.symbols['__g'],
                    lambda x, t: replace_func(x, t, stochastic[xi]))
                # Replace x and xi by the respective variables
                stochastic_result = stochastic_result.subs(
                    self.symbols['__x'], eq_symbols[var])
                stochastic_result = stochastic_result.subs(
                    self.symbols['__dW'], xi)
                # Replace intermediate variables
                stochastic_result = stochastic_result.subs(
                    temp_var_replacements(var))
                stochastic_results.append(stochastic_result)

        # All the parts (one non-stochastic and potentially more than one
        # stochastic part) are combined with addition
        rhs_parts = []
        if non_stochastic_result is not None:
            rhs_parts.append(sympy_to_str(non_stochastic_result))
        rhs_parts.extend(sympy_to_str(result) for result in stochastic_results)

        return ' + '.join(rhs_parts)
コード例 #50
0
    def __init__(self,
                 morphology=None,
                 model=None,
                 threshold=None,
                 refractory=False,
                 reset=None,
                 threshold_location=None,
                 dt=None,
                 clock=None,
                 order=0,
                 Cm=0.9 * uF / cm**2,
                 Ri=150 * ohm * cm,
                 name='spatialneuron*',
                 dtype=None,
                 namespace=None,
                 method=('linear', 'exponential_euler', 'rk2', 'heun')):
        '''
        Build a spatially extended neuron from `morphology` and `model`.

        The membrane current ``Im`` defined in `model` (plus any
        ``point current`` variables divided by ``area``) must be linear in
        ``v``. It is decomposed as ``Im = -gtot__private*v + I0__private``,
        and both terms are added to the model as subexpressions.

        Parameters
        ----------
        morphology
            The morphology; its compartments become the neurons of the group.
        model : str or Equations
            Model equations; must define the subexpression ``Im``.
        threshold_location
            If given, the threshold condition is restricted to this
            compartment index. Objects with an ``_indices()`` method must
            resolve to a single compartment.
        Cm, Ri
            Specific membrane capacitance and intracellular resistivity
            assigned to the ``Cm``/``Ri`` variables.
        threshold, refractory, reset, dt, clock, order, name, dtype,
        namespace, method
            Passed on to ``NeuronGroup.__init__``; `method`, `clock` and
            `order` are also used for the ``SpatialStateUpdater``.

        Raises
        ------
        TypeError
            If `model` is neither a string nor an ``Equations`` object, if
            ``Im`` is not defined, or if ``Im`` is not linear in ``v``.
        AttributeError
            If `threshold_location` refers to more than one compartment.
        '''
        # #### Prepare and validate equations
        if isinstance(model, basestring):
            model = Equations(model)
        if not isinstance(model, Equations):
            raise TypeError(('model has to be a string or an Equations '
                             'object, is "%s" instead.') % type(model))

        # Insert the threshold mechanism at the specified location
        if threshold_location is not None:
            if hasattr(threshold_location,
                       '_indices'):  # assuming this is a method
                threshold_location = threshold_location._indices()
                # for now, only a single compartment allowed
                if len(threshold_location) == 1:
                    threshold_location = threshold_location[0]
                else:
                    raise AttributeError(('Threshold can only be applied on a '
                                          'single location'))
            threshold = '(' + threshold + ') and (i == ' + str(
                threshold_location) + ')'

        # Check flags (we have point currents)
        model.check_flags({
            DIFFERENTIAL_EQUATION: ('point current', ),
            PARAMETER: ('constant', 'shared', 'linked', 'point current'),
            SUBEXPRESSION: ('shared', 'point current')
        })

        # Add the membrane potential
        model += Equations('''
        v:volt # membrane potential
        ''')

        # Extract membrane equation
        if 'Im' in model:
            membrane_eq = model['Im']  # the membrane equation
        else:
            raise TypeError('The transmembrane current Im must be defined')

        # Insert point currents (in amp) into the membrane equation as
        # current densities by dividing them by the compartment area.
        # NOTE(review): this mutates the equation objects in place
        # (flags and Im's expression); they may be shared with a
        # caller-supplied Equations object -- confirm this is acceptable.
        for eq in model.itervalues():
            if 'point current' in eq.flags:
                fail_for_dimension_mismatch(
                    eq.unit, amp,
                    "Point current " + eq.varname + " should be in amp")
                eq.flags.remove('point current')
                membrane_eq.expr = Expression(
                    str(membrane_eq.expr.code) + '+' + eq.varname + '/area')

        ###### Process model equations (Im) to extract total conductance and the remaining current
        # Check conditional linearity with respect to v
        # Match to _A*v+_B
        v_symbol = sp.Symbol('v', real=True)
        wildcard = sp.Wild('_A', exclude=[v_symbol])
        constant_wildcard = sp.Wild('_B', exclude=[v_symbol])
        pattern = wildcard * v_symbol + constant_wildcard

        # Expand expressions in the membrane equation. Im is temporarily
        # treated as a differential equation because
        # _get_substituted_expressions returns substituted expressions for
        # differential equations; its type is restored even if that fails.
        # The loop variable must not reuse the name of the sympy symbol for
        # v, which is needed below.
        Im_expr = None
        membrane_eq.type = DIFFERENTIAL_EQUATION
        try:
            for varname, expr in model._get_substituted_expressions():
                if varname == 'Im':
                    Im_expr = expr
                    break
        finally:
            membrane_eq.type = SUBEXPRESSION
        if Im_expr is None:
            raise AssertionError('Model equations did not contain Im!')

        # Factor out the membrane potential (same symbol as in `pattern`)
        s_expr = sp.collect(Im_expr.sympy_expr.expand(), v_symbol)
        matches = s_expr.match(pattern)

        if matches is None:
            raise TypeError("The membrane current must be linear with respect to v")
        a, b = (matches[wildcard], matches[constant_wildcard])

        # Extracts the total conductance from Im, and the remaining current
        minusa_str, b_str = sympy_to_str(-a), sympy_to_str(b)
        # Add correct units if necessary
        if minusa_str == '0':
            minusa_str += '*siemens/meter**2'
        if b_str == '0':
            b_str += '*amp/meter**2'
        gtot_str = "gtot__private=" + minusa_str + ": siemens/meter**2"
        I0_str = "I0__private=" + b_str + ": amp/meter**2"
        model += Equations(gtot_str + "\n" + I0_str)

        # Equations for morphology
        # TODO: check whether Cm and Ri are already in the equations
        #       no: should be shared instead of constant
        #       yes: should be constant (check)
        eqs_constants = Equations("""
        diameter : meter (constant)
        length : meter (constant)
        x : meter (constant)
        y : meter (constant)
        z : meter (constant)
        distance : meter (constant)
        area : meter**2 (constant)
        Cm : farad/meter**2 (constant)
        Ri : ohm*meter (constant, shared)
        space_constant = (diameter/(4*Ri*gtot__private))**.5 : meter # Not so sure about the name

        ### Parameters and intermediate variables for solving the cable equation
        ab_star0 : siemens/meter**2
        ab_plus0 : siemens/meter**2
        ab_minus0 : siemens/meter**2
        ab_star1 : siemens/meter**2
        ab_plus1 : siemens/meter**2
        ab_minus1 : siemens/meter**2
        ab_star2 : siemens/meter**2
        ab_plus2 : siemens/meter**2
        ab_minus2 : siemens/meter**2
        b_plus : siemens/meter**2
        b_minus : siemens/meter**2
        v_star : volt
        u_plus : 1
        u_minus : 1
        # The following two are only necessary for C code where we cannot deal
        # with scalars and arrays interchangeably
        gtot_all : siemens/meter**2
        I0_all : amp/meter**2
        """)
        # Possibilities for the name: characteristic_length, electrotonic_length, length_constant, space_constant

        # Insert morphology
        self.morphology = morphology

        # Link morphology variables to neuron's state variables
        self.morphology_data = MorphologyData(len(morphology))
        self.morphology.compress(self.morphology_data)

        NeuronGroup.__init__(self,
                             len(morphology),
                             model=model + eqs_constants,
                             threshold=threshold,
                             refractory=refractory,
                             reset=reset,
                             method=method,
                             dt=dt,
                             clock=clock,
                             order=order,
                             namespace=namespace,
                             dtype=dtype,
                             name=name)

        self.Cm = Cm
        self.Ri = Ri
        # TODO: View instead of copy for runtime?
        self.diameter_ = self.morphology_data.diameter
        self.distance_ = self.morphology_data.distance
        self.length_ = self.morphology_data.length
        self.area_ = self.morphology_data.area
        self.x_ = self.morphology_data.x
        self.y_ = self.morphology_data.y
        self.z_ = self.morphology_data.z

        # Performs numerical integration step
        self.add_attribute('diffusion_state_updater')
        self.diffusion_state_updater = SpatialStateUpdater(self,
                                                           method,
                                                           clock=self.clock,
                                                           order=order)

        # Creation of contained_objects that do the work
        self.contained_objects.extend([self.diffusion_state_updater])
コード例 #51
0
ファイル: spatialneuron.py プロジェクト: Kwartke/brian2
    def __init__(self, morphology=None, model=None, threshold=None,
                 refractory=False, reset=None,
                 threshold_location=None,
                 dt=None, clock=None, order=0, Cm=0.9 * uF / cm ** 2, Ri=150 * ohm * cm,
                 name='spatialneuron*', dtype=None, namespace=None,
                 method=('linear', 'exponential_euler', 'rk2', 'milstein')):
        '''
        Build a spatially extended neuron from `morphology` and `model`.

        The membrane current ``Im`` defined in `model` (plus any
        ``point current`` variables divided by ``area``) must be linear in
        ``v``. It is decomposed as ``Im = -gtot__private*v + I0__private``,
        and both terms are added to the model as subexpressions.

        Parameters
        ----------
        morphology
            The morphology; its compartments become the neurons of the group.
        model : str or Equations
            Model equations; must define the subexpression ``Im``.
        threshold_location
            If given, the threshold condition is restricted to this
            compartment index. Objects with an ``_indices()`` method must
            resolve to a single compartment.
        Cm, Ri
            Specific membrane capacitance and intracellular resistivity
            assigned to the ``Cm``/``Ri`` variables.
        threshold, refractory, reset, dt, clock, order, name, dtype,
        namespace, method
            Passed on to ``NeuronGroup.__init__``; `method`, `clock` and
            `order` are also used for the ``SpatialStateUpdater``.

        Raises
        ------
        TypeError
            If `model` is neither a string nor an ``Equations`` object, if
            ``Im`` is not defined, or if ``Im`` is not linear in ``v``.
        AttributeError
            If `threshold_location` refers to more than one compartment.
        '''
        # #### Prepare and validate equations
        if isinstance(model, basestring):
            model = Equations(model)
        if not isinstance(model, Equations):
            raise TypeError(('model has to be a string or an Equations '
                             'object, is "%s" instead.') % type(model))

        # Insert the threshold mechanism at the specified location
        if threshold_location is not None:
            if hasattr(threshold_location,
                       '_indices'):  # assuming this is a method
                threshold_location = threshold_location._indices()
                # for now, only a single compartment allowed
                if len(threshold_location) == 1:
                    threshold_location = threshold_location[0]
                else:
                    raise AttributeError(('Threshold can only be applied on a '
                                          'single location'))
            threshold = '(' + threshold + ') and (i == ' + str(threshold_location) + ')'

        # Check flags (we have point currents)
        model.check_flags({DIFFERENTIAL_EQUATION: ('point current',),
                           PARAMETER: ('constant', 'shared', 'linked', 'point current'),
                           SUBEXPRESSION: ('shared', 'point current')})

        # Add the membrane potential
        model += Equations('''
        v:volt # membrane potential
        ''')

        # Extract membrane equation
        if 'Im' in model:
            membrane_eq = model['Im']  # the membrane equation
        else:
            raise TypeError('The transmembrane current Im must be defined')

        # Insert point currents (in amp) into the membrane equation as
        # current densities by dividing them by the compartment area.
        # NOTE(review): this mutates the equation objects in place
        # (flags and Im's expression); they may be shared with a
        # caller-supplied Equations object -- confirm this is acceptable.
        for eq in model.itervalues():
            if 'point current' in eq.flags:
                fail_for_dimension_mismatch(eq.unit, amp,
                                            "Point current " + eq.varname + " should be in amp")
                eq.flags.remove('point current')
                membrane_eq.expr = Expression(
                    str(membrane_eq.expr.code) + '+' + eq.varname + '/area')

        ###### Process model equations (Im) to extract total conductance and the remaining current
        # Check conditional linearity with respect to v
        # Match to _A*v+_B
        v_symbol = sp.Symbol('v', real=True)
        wildcard = sp.Wild('_A', exclude=[v_symbol])
        constant_wildcard = sp.Wild('_B', exclude=[v_symbol])
        pattern = wildcard * v_symbol + constant_wildcard

        # Expand expressions in the membrane equation. Im is temporarily
        # treated as a differential equation because
        # _get_substituted_expressions returns substituted expressions for
        # differential equations; its type is restored even if that fails.
        # The loop variable must not reuse the name of the sympy symbol for
        # v, which is needed below.
        Im_expr = None
        membrane_eq.type = DIFFERENTIAL_EQUATION
        try:
            for varname, expr in model._get_substituted_expressions():
                if varname == 'Im':
                    Im_expr = expr
                    break
        finally:
            membrane_eq.type = SUBEXPRESSION
        if Im_expr is None:
            raise AssertionError('Model equations did not contain Im!')

        # Factor out the membrane potential (same symbol as in `pattern`)
        s_expr = sp.collect(Im_expr.sympy_expr.expand(), v_symbol)
        matches = s_expr.match(pattern)

        if matches is None:
            raise TypeError("The membrane current must be linear with respect to v")
        a, b = (matches[wildcard],
                matches[constant_wildcard])

        # Extracts the total conductance from Im, and the remaining current
        minusa_str, b_str = sympy_to_str(-a), sympy_to_str(b)
        # Add correct units if necessary
        if minusa_str == '0':
            minusa_str += '*siemens/meter**2'
        if b_str == '0':
            b_str += '*amp/meter**2'
        gtot_str = "gtot__private=" + minusa_str + ": siemens/meter**2"
        I0_str = "I0__private=" + b_str + ": amp/meter**2"
        model += Equations(gtot_str + "\n" + I0_str)

        # Equations for morphology
        # TODO: check whether Cm and Ri are already in the equations
        #       no: should be shared instead of constant
        #       yes: should be constant (check)
        eqs_constants = Equations("""
        diameter : meter (constant)
        length : meter (constant)
        x : meter (constant)
        y : meter (constant)
        z : meter (constant)
        distance : meter (constant)
        area : meter**2 (constant)
        Cm : farad/meter**2 (constant)
        Ri : ohm*meter (constant, shared)
        space_constant = (diameter/(4*Ri*gtot__private))**.5 : meter # Not so sure about the name

        ### Parameters and intermediate variables for solving the cable equation
        ab_star0 : siemens/meter**2
        ab_plus0 : siemens/meter**2
        ab_minus0 : siemens/meter**2
        ab_star1 : siemens/meter**2
        ab_plus1 : siemens/meter**2
        ab_minus1 : siemens/meter**2
        ab_star2 : siemens/meter**2
        ab_plus2 : siemens/meter**2
        ab_minus2 : siemens/meter**2
        b_plus : siemens/meter**2
        b_minus : siemens/meter**2
        v_star : volt
        u_plus : 1
        u_minus : 1
        """)
        # Possibilities for the name: characteristic_length, electrotonic_length, length_constant, space_constant

        # Insert morphology
        self.morphology = morphology

        # Link morphology variables to neuron's state variables
        self.morphology_data = MorphologyData(len(morphology))
        self.morphology.compress(self.morphology_data)

        NeuronGroup.__init__(self, len(morphology), model=model + eqs_constants,
                             threshold=threshold, refractory=refractory,
                             reset=reset,
                             method=method, dt=dt, clock=clock, order=order,
                             namespace=namespace, dtype=dtype, name=name)

        self.Cm = Cm
        self.Ri = Ri
        # TODO: View instead of copy for runtime?
        self.diameter_ = self.morphology_data.diameter
        self.distance_ = self.morphology_data.distance
        self.length_ = self.morphology_data.length
        self.area_ = self.morphology_data.area
        self.x_ = self.morphology_data.x
        self.y_ = self.morphology_data.y
        self.z_ = self.morphology_data.z

        # Performs numerical integration step
        self.add_attribute('diffusion_state_updater')
        self.diffusion_state_updater = SpatialStateUpdater(self, method,
                                                           clock=self.clock,
                                                           order=order)

        # Creation of contained_objects that do the work
        self.contained_objects.extend([self.diffusion_state_updater])
コード例 #52
0
    def __init__(self,
                 morphology=None,
                 model=None,
                 threshold=None,
                 refractory=False,
                 reset=None,
                 events=None,
                 threshold_location=None,
                 dt=None,
                 clock=None,
                 order=0,
                 Cm=0.9 * uF / cm**2,
                 Ri=150 * ohm * cm,
                 name='spatialneuron*',
                 dtype=None,
                 namespace=None,
                 method=('linear', 'exponential_euler', 'rk2', 'heun')):

        # #### Prepare and validate equations
        if isinstance(model, basestring):
            model = Equations(model)
        if not isinstance(model, Equations):
            raise TypeError(('model has to be a string or an Equations '
                             'object, is "%s" instead.') % type(model))

        # Insert the threshold mechanism at the specified location
        if threshold_location is not None:
            if hasattr(threshold_location,
                       '_indices'):  # assuming this is a method
                threshold_location = threshold_location._indices()
                # for now, only a single compartment allowed
                if len(threshold_location) == 1:
                    threshold_location = threshold_location[0]
                else:
                    raise AttributeError(('Threshold can only be applied on a '
                                          'single location'))
            threshold = '(' + threshold + ') and (i == ' + str(
                threshold_location) + ')'

        # Check flags (we have point currents)
        model.check_flags({
            DIFFERENTIAL_EQUATION: ('point current', ),
            PARAMETER: ('constant', 'shared', 'linked', 'point current'),
            SUBEXPRESSION: ('shared', 'point current')
        })

        # Add the membrane potential
        model += Equations('''
        v:volt # membrane potential
        ''')

        # Extract membrane equation
        if 'Im' in model:
            membrane_eq = model['Im']  # the membrane equation
        else:
            raise TypeError('The transmembrane current Im must be defined')

        # Insert point currents in the membrane equation
        for eq in model.itervalues():
            if 'point current' in eq.flags:
                fail_for_dimension_mismatch(
                    eq.unit, amp,
                    "Point current " + eq.varname + " should be in amp")
                eq.flags.remove('point current')
                membrane_eq.expr = Expression(
                    str(membrane_eq.expr.code) + '+' + eq.varname + '/area')

        ###### Process model equations (Im) to extract total conductance and the remaining current
        # Check conditional linearity with respect to v
        # Match to _A*v+_B
        var = sp.Symbol('v', real=True)
        wildcard = sp.Wild('_A', exclude=[var])
        constant_wildcard = sp.Wild('_B', exclude=[var])
        pattern = wildcard * var + constant_wildcard

        # Expand expressions in the membrane equation
        membrane_eq.type = DIFFERENTIAL_EQUATION
        for var, expr in model.get_substituted_expressions():
            if var == 'Im':
                Im_expr = expr
        membrane_eq.type = SUBEXPRESSION

        # Factor out the variable
        s_expr = sp.collect(str_to_sympy(Im_expr.code).expand(), var)
        matches = s_expr.match(pattern)

        if matches is None:
            raise TypeError, "The membrane current must be linear with respect to v"
        a, b = (matches[wildcard], matches[constant_wildcard])

        # Extracts the total conductance from Im, and the remaining current
        minusa_str, b_str = sympy_to_str(-a), sympy_to_str(b)
        # Add correct units if necessary
        if minusa_str == '0':
            minusa_str += '*siemens/meter**2'
        if b_str == '0':
            b_str += '*amp/meter**2'
        gtot_str = "gtot__private=" + minusa_str + ": siemens/meter**2"
        I0_str = "I0__private=" + b_str + ": amp/meter**2"
        model += Equations(gtot_str + "\n" + I0_str)

        # Insert morphology (store a copy)
        self.morphology = copy.deepcopy(morphology)

        # Flatten the morphology
        self.flat_morphology = FlatMorphology(morphology)

        # Equations for morphology
        # TODO: check whether Cm and Ri are already in the equations
        #       no: should be shared instead of constant
        #       yes: should be constant (check)
        eqs_constants = Equations("""
        length : meter (constant)
        distance : meter (constant)
        area : meter**2 (constant)
        volume : meter**3
        diameter : meter (constant)
        Cm : farad/meter**2 (constant)
        Ri : ohm*meter (constant, shared)
        r_length_1 : meter (constant)
        r_length_2 : meter (constant)
        time_constant = Cm/gtot__private : second
        space_constant = (2/pi)**(1.0/3.0) * (area/(1/r_length_1 + 1/r_length_2))**(1.0/6.0) /
                         (2*(Ri*gtot__private)**(1.0/2.0)) : meter
        """)
        if self.flat_morphology.has_coordinates:
            eqs_constants += Equations('''
            x : meter (constant)
            y : meter (constant)
            z : meter (constant)
            ''')

        NeuronGroup.__init__(self,
                             morphology.total_compartments,
                             model=model + eqs_constants,
                             threshold=threshold,
                             refractory=refractory,
                             reset=reset,
                             events=events,
                             method=method,
                             dt=dt,
                             clock=clock,
                             order=order,
                             namespace=namespace,
                             dtype=dtype,
                             name=name)
        # Parameters and intermediate variables for solving the cable equations
        # Note that some of these variables could have meaningful physical
        # units (e.g. _v_star is in volt, _I0_all is in amp/meter**2 etc.) but
        # since these variables should never be used in user code, we don't
        # assign them any units
        self.variables.add_arrays(
            [
                '_ab_star0',
                '_ab_star1',
                '_ab_star2',
                '_a_minus0',
                '_a_minus1',
                '_a_minus2',
                '_a_plus0',
                '_a_plus1',
                '_a_plus2',
                '_b_plus',
                '_b_minus',
                '_v_star',
                '_u_plus',
                '_u_minus',
                # The following three are for solving the
                # three tridiag systems in parallel
                '_c1',
                '_c2',
                '_c3',
                # The following two are only necessary for
                # C code where we cannot deal with scalars
                # and arrays interchangeably:
                '_I0_all',
                '_gtot_all'
            ],
            unit=1,
            size=self.N,
            read_only=True)

        self.Cm = Cm
        self.Ri = Ri
        # These explict assignments will load the morphology values from disk
        # in standalone mode
        self.distance_ = self.flat_morphology.distance
        self.length_ = self.flat_morphology.length
        self.area_ = self.flat_morphology.area
        self.diameter_ = self.flat_morphology.diameter
        self.r_length_1_ = self.flat_morphology.r_length_1
        self.r_length_2_ = self.flat_morphology.r_length_2
        if self.flat_morphology.has_coordinates:
            self.x_ = self.flat_morphology.x
            self.y_ = self.flat_morphology.y
            self.z_ = self.flat_morphology.z

        # Performs numerical integration step
        self.add_attribute('diffusion_state_updater')
        self.diffusion_state_updater = SpatialStateUpdater(self,
                                                           method,
                                                           clock=self.clock,
                                                           order=order)

        # Creation of contained_objects that do the work
        self.contained_objects.extend([self.diffusion_state_updater])