Exemplo n.º 1
0
def get_linear_system(eqs, variables):
    """
    Convert equations into a linear system using sympy.

    Parameters
    ----------
    eqs : `Equations`
        The model equations.
    variables : dict
        The variables used to resolve names in the equations; passed on to
        ``eqs.get_substituted_expressions`` and ``str_to_sympy``.

    Returns
    -------
    (diff_eq_names, coefficients, constants) : (list of str, `sympy.Matrix`, `sympy.Matrix`)
        A tuple containing the variable names (`diff_eq_names`) corresponding
        to the rows of the matrix `coefficients` and the vector `constants`,
        representing the system of equations in the form M * X + B

    Raises
    ------
    UnsupportedEquationsException
        If the equations cannot be converted into an M * X + B form.
    """
    diff_eqs = eqs.get_substituted_expressions(variables)
    diff_eq_names = [name for name, _ in diff_eqs]
    n_equations = len(diff_eq_names)

    symbols = [Symbol(name, real=True) for name in diff_eq_names]

    coefficients = sp.zeros(n_equations)
    constants = sp.zeros(n_equations, 1)

    for row_idx, (name, expr) in enumerate(diff_eqs):
        # Start with the full (expanded) right-hand side; each pass of the
        # inner loop peels off the term that is linear in one state variable.
        remainder = str_to_sympy(expr.code, variables).expand()
        # The factor in front of a state variable must not contain any state
        # variable itself (otherwise the equation is not linear). It only
        # depends on the equation, so it is created once per row.
        factor_wildcard = Wild(f"c_{name}", exclude=symbols)
        for col_idx, symbol in enumerate(symbols):
            remainder = remainder.collect(symbol)
            # The rest may still contain the *other* state variables; they
            # are extracted in later iterations.
            constant_wildcard = Wild('c', exclude=[symbol])
            matches = remainder.match(factor_wildcard * symbol + constant_wildcard)
            if matches is None:
                raise UnsupportedEquationsException(
                    f"The expression '{expr}', "
                    f"defining the variable "
                    f"'{name}', could not be "
                    f"separated into linear "
                    f"components.")

            coefficients[row_idx, col_idx] = matches[factor_wildcard]
            remainder = matches[constant_wildcard]

        # What is left should no longer depend on any state variable
        constants[row_idx] = remainder

    return (diff_eq_names, coefficients, constants)
Exemplo n.º 2
0
def _check_for_locally_constant(expression, variables, dt_value, t_symbol):

    for arg in expression.args:
        if arg is t_symbol:
            # We found "t" -- if it is not the only argument of a locally
            # constant function we bail out
            func_name = str(expression.func)
            if not (func_name in variables
                    and variables[func_name].is_locally_constant(dt_value)):
                raise UnsupportedEquationsException(
                    ('t is used in a context '
                     'where we cannot '
                     'guarantee that it can be '
                     'considered locally '
                     'constant.'))
        else:
            _check_for_locally_constant(arg, variables, dt_value, t_symbol)
Exemplo n.º 3
0
    def translate(
        self, code, dtype
    ):  # TODO: it's not so nice we have to copy the contents of this function..
        '''
        Translates an abstract code block into the target language.

        This repeats the statement handling of the wrapped generator and then
        rewrites the resulting code so that the differential variables are
        integrated with GSL.

        Parameters
        ----------
        code : dict
            Mapping from block name to abstract code. The block stored under
            the key ``None`` is the one that is rewritten for GSL.
        dtype
            Passed on to ``make_statements``.

        Returns
        -------
        scalar_code, vector_code, kwds
            The result of ``self.generator.translate_statement_sequence``,
            with the GSL main code stored in ``scalar_code['GSL']`` and the
            GSL-specific template keywords added to ``kwds``.

        Raises
        ------
        ValueError
            If a variable name reserved for the GSL internal code is in use.
        UnsupportedEquationsException
            If the generated code uses indexing other than ``[0]`` or
            ``[_idx]``.
        AssertionError
            If some variable used by the vector code was not added to the
            ``_GSL_dataholder``.
        '''
        # first check if user code is not using variables that are also used by GSL
        reserved_variables = [
            '_dataholder', '_fill_y_vector', '_empty_y_vector',
            '_GSL_dataholder', '_GSL_y', '_GSL_func'
        ]
        if any(var in self.variables for var in reserved_variables):
            raise ValueError(("The variables %s are reserved for the GSL "
                              "internal code." % (str(reserved_variables))))

        # if the following statements are not added, Brian translates the
        # differential expressions in the abstract code for GSL to scalar statements
        # in the case no non-scalar variables are used in the expression
        diff_vars = self.find_differential_variables(code.values())
        self.add_gsl_variables_as_non_scalar(diff_vars)

        # add arrays we want to use in generated code before self.generator.translate() so
        # brian does namespace unpacking for us
        pointer_names = self.add_meta_variables(self.method_options)

        scalar_statements = {}
        vector_statements = {}
        # items()/values() instead of the Python 2-only iteritems()/itervalues()
        for ac_name, ac_code in code.items():
            statements = make_statements(ac_code,
                                         self.variables,
                                         dtype,
                                         optimise=True,
                                         blockname=ac_name)
            scalar_statements[ac_name], vector_statements[ac_name] = statements
        for vs in vector_statements.values():
            # Check that the statements are meaningful independent on the order of
            # execution (e.g. for synapses)
            try:
                if self.has_repeated_indices(
                        vs
                ):  # only do order dependence if there are repeated indices
                    check_for_order_independence(
                        vs, self.generator.variables,
                        self.generator.variable_indices)
            except OrderDependenceError:
                # If the abstract code is only one line, display it in full
                if len(vs) <= 1:
                    error_msg = 'Abstract code: "%s"\n' % vs[0]
                else:
                    error_msg = (
                        '%d lines of abstract code, first line is: '
                        '"%s"\n') % (len(vs), vs[0])
                # Report the problem instead of silently discarding the
                # message that was just built
                from brian2.utils.logger import get_logger
                get_logger(__name__).warn(
                    ('Came across an abstract code block that may not be '
                     'well-defined: the outcome may depend on the '
                     'order of execution. You can ignore this warning if '
                     'you are sure that the order of operations does not '
                     'matter. ' + error_msg))

        # save function names because self.generator.translate_statement_sequence
        # deletes these from self.variables but we need to know which identifiers
        # we can safely ignore (i.e. we can ignore the functions because they are
        # handled by the original generator)
        self.function_names = self.find_function_names()

        scalar_code, vector_code, kwds = self.generator.translate_statement_sequence(
            scalar_statements, vector_statements)

        ############ translate code for GSL

        # first check if any indexing other than '_idx' is used (currently not supported)
        # NOTE(review): only the first bracketed index on each line is inspected
        for code_list in list(scalar_code.values()) + list(vector_code.values()):
            # 'line' (not 'code') so that the `code` argument is not shadowed
            for line in code_list:
                m = re.search(r'\[(\w+)\]', line)
                if m is not None and m.group(1) not in ('0', '_idx'):
                    # imported locally, presumably to avoid a circular import
                    from brian2.stateupdaters.base import UnsupportedEquationsException
                    raise UnsupportedEquationsException(
                        ("Equations result in state "
                         "updater code with indexing "
                         "other than '_idx', which "
                         "is currently not supported "
                         "in combination with the "
                         "GSL stateupdater."))

        # differential variable specific operations
        to_replace = self.diff_var_to_replace(diff_vars)
        GSL_support_code = self.get_dimension_code(len(diff_vars))
        GSL_support_code += self.yvector_code(diff_vars)

        # analyze all needed variables; if not in self.variables: put in separate dic.
        # also keep track of variables needed for scalar statements and vector statements
        other_variables = self.find_undefined_variables(
            scalar_statements[None] + vector_statements[None])
        variables_in_scalar = self.find_used_variables(scalar_statements[None],
                                                       other_variables)
        variables_in_vector = self.find_used_variables(vector_statements[None],
                                                       other_variables)
        # so that _dataholder holds diff_vars as well, even if they don't occur
        # in the actual statements
        for var in diff_vars.keys():
            if var not in variables_in_vector:
                variables_in_vector[var] = self.variables[var]
        # lets keep track of the variables that eventually need to be added to
        # the _GSL_dataholder somehow. Stored as a list (a snapshot, as with
        # Python 2's keys()) rather than a live keys view.
        self.variables_to_be_processed = list(variables_in_vector.keys())

        # add code for _dataholder struct
        GSL_support_code = self.write_dataholder(
            variables_in_vector) + GSL_support_code
        # add e.g. _lio_1 --> _GSL_dataholder._lio_1 to replacer
        to_replace.update(
            self.to_replace_vector_vars(variables_in_vector,
                                        ignore=list(diff_vars.keys())))
        # write statements that unpack (python) namespace to _dataholder struct
        # or local namespace
        GSL_main_code = self.unpack_namespace(variables_in_vector,
                                              variables_in_scalar, ['t'])

        # rewrite actual calculations described by vector_code and put them in _GSL_func
        func_code = self.translate_one_statement_sequence(
            vector_statements[None], scalar=False)
        GSL_support_code += self.make_function_code(
            self.translate_vector_code(func_code, to_replace))
        scalar_func_code = self.translate_one_statement_sequence(
            scalar_statements[None], scalar=True)
        # rewrite scalar code, keep variables that are needed in scalar code normal
        # and add variables to _dataholder for vector_code
        GSL_main_code += '\n' + self.translate_scalar_code(
            scalar_func_code, variables_in_scalar, variables_in_vector)
        if len(self.variables_to_be_processed) > 0:
            raise AssertionError(
                ("Not all variables that will be used in the vector "
                 "code have been added to the _GSL_dataholder. This "
                 "might mean that the _GSL_func is using unitialized "
                 "variables."
                 "\nThe unprocessed variables "
                 "are: %s" % (str(self.variables_to_be_processed))))

        scalar_code['GSL'] = GSL_main_code
        kwds['define_GSL_scale_array'] = self.scale_array_code(
            diff_vars, self.method_options)
        kwds['n_diff_vars'] = len(diff_vars)
        kwds['GSL_settings'] = dict(self.method_options)
        kwds['GSL_settings']['integrator'] = self.integrator
        kwds['support_code_lines'] += GSL_support_code.split('\n')
        kwds['t_array'] = self.get_array_name(self.variables['t']) + '[0]'
        kwds['dt_array'] = self.get_array_name(self.variables['dt']) + '[0]'
        kwds['define_dt'] = 'dt' not in variables_in_scalar
        kwds['cpp_standalone'] = self.is_cpp_standalone()
        kwds.update(pointer_names)
        return scalar_code, vector_code, kwds
Exemplo n.º 4
0
    def __call__(self, equations, variables=None, method_options=None):
        '''
        Generate abstract code that solves each equation independently.

        Every differential equation is solved analytically with
        ``sympy.dsolve``. The integration constant, if present, is fixed by
        the variable's current value, and the solution is then shifted by
        one time step ``dt``. A deprecation warning is logged (once).

        Parameters
        ----------
        equations : `Equations`
            The model equations; stochastic equations are rejected.
        variables : dict, optional
            Variables used to resolve names in the equations (default: empty).
        method_options : dict, optional
            Checked with ``extract_method_options`` against an empty set of
            defaults (presumably any option is rejected -- confirm).

        Returns
        -------
        str
            One ``name = expression`` line per differential equation.

        Raises
        ------
        UnsupportedEquationsException
            For stochastic equations, equations without an explicit solution,
            or solutions whose integration constant cannot be determined.
        '''
        logger.warn(
            "The 'independent' state updater is deprecated and might be "
            "removed in future versions of Brian.",
            'deprecated_independent',
            once=True)
        extract_method_options(method_options, {})
        if equations.is_stochastic:
            raise UnsupportedEquationsException("Cannot solve stochastic "
                                                "equations with this state "
                                                "updater")
        variables = {} if variables is None else variables

        substituted = equations.get_substituted_expressions(variables)

        time = Symbol('t', real=True, positive=True)
        step = Symbol('dt', real=True, positive=True)
        start = Symbol('t0', real=True, positive=True)
        c1 = Symbol('C1')
        c2 = Symbol('C2')

        lines = []
        for var_name, expression in substituted:
            # The real=True assumption is required, otherwise sympy does not
            # consider this symbol equal to the one in the equation
            state = Symbol(var_name, real=True)
            func = sp.Function(var_name)
            func_of_time = func(time)
            rhs = str_to_sympy(expression.code, variables).subs(state,
                                                                func_of_time)
            ode = sp.Eq(sp.Derivative(func_of_time, time), rhs)

            # TODO: simplify=True sometimes fails with sympy 0.7.4, see:
            # https://github.com/sympy/sympy/issues/2666
            try:
                general = sp.dsolve(ode, func_of_time, simplify=True)
            except RuntimeError:
                general = sp.dsolve(ode, func_of_time, simplify=False)

            # Only explicit solutions of the form f(t) = ... are usable
            if not (getattr(general, 'lhs', None) == func_of_time):
                raise UnsupportedEquationsException(
                    f"Cannot explicitly solve: {str(ode)}")

            if general.has(c1):
                if general.has(c2):
                    raise UnsupportedEquationsException(
                        f'Too many constants in solution: {str(general)}'
                    )
                c1_values = sp.solve(general, c1)
                if len(c1_values) != 1:
                    raise UnsupportedEquationsException(
                        ("Couldn't solve for the constant "
                         "C1 in : %s ") % str(general))
                # Express C1 through the value of the variable at time t0
                c1_value = c1_values[0].subs(time, start).subs(func(start),
                                                              state)
                solution = general.rhs.subs('C1', c1_value)
            else:
                solution = general.rhs.subs(time, start).subs(func(start),
                                                             state)

            # Evaluate at t + dt, then rename t0 (time of the known value) to t
            solution = solution.subs(time, time + step).subs(start, time)
            # Simplification is best-effort only -- it sometimes raises
            try:
                solution = solution.simplify()
            except ValueError:
                pass

            lines.append(f"{var_name} = {sympy_to_str(solution)}")

        return '\n'.join(lines)
Exemplo n.º 5
0
    def __call__(self, equations, variables=None, method_options=None):
        '''
        Generate abstract code that integrates a linear system exactly.

        The equations are brought into the form ``dX/dt = M*X + B`` (via
        ``get_linear_system``) and advanced over one time step with the
        matrix exponential ``exp(M*dt)``.

        Parameters
        ----------
        equations : `Equations`
            The model equations; stochastic equations are rejected.
        variables : dict, optional
            Variables used to resolve names in the equations (default: empty).
        method_options : dict, optional
            Supports ``'simplify'`` (default ``True``): simplify the entries
            of the matrix exponential before generating code.

        Returns
        -------
        str
            Abstract code that first stores every updated value in a
            temporary ``_<name>`` and then copies the temporaries back, or
            the empty string if there are no differential equations.

        Raises
        ------
        UnsupportedEquationsException
            For stochastic equations, entries of M or B that are not
            guaranteed constant over a time step, systems sympy cannot
            solve, or solutions containing complex values.
        '''
        method_options = extract_method_options(method_options,
                                                {'simplify': True})

        if equations.is_stochastic:
            raise UnsupportedEquationsException("Cannot solve stochastic "
                                                "equations with this state "
                                                "updater.")
        if variables is None:
            variables = {}

        # dX/dt = M*X + B
        varnames, matrix, constants = get_linear_system(equations, variables)

        # Nothing to integrate (happens e.g. for a model such as 'v : 1')
        if matrix.shape == (0, 0):
            return ''

        # M and B must stay constant during a time step: "t" may only occur
        # as an argument of a locally constant function
        if 'dt' in variables:
            dt_value = variables['dt'].get_value()[0]
        else:
            dt_value = None
        for entry in itertools.chain(matrix, constants):
            if is_constant_over_dt(entry, variables, dt_value):
                continue
            raise UnsupportedEquationsException(
                f"Expression '{sympy_to_str(entry)}' is not guaranteed to be "
                f"constant over a time step.")

        state_symbols = [Symbol(name, real=True) for name in varnames]
        linear_solution = sp.solve_linear_system(matrix.row_join(constants),
                                                 *state_symbols)
        if (linear_solution is None
                or set(linear_solution.keys()) != set(state_symbols)):
            raise UnsupportedEquationsException("Cannot solve the given "
                                                "equations with this "
                                                "stateupdater.")
        # Column vector holding the solution of the augmented system [M | B]
        offset = sp.ImmutableMatrix([linear_solution[sym]
                                     for sym in state_symbols])

        dt = Symbol('dt', real=True, positive=True)
        try:
            propagator = (matrix * dt).exp()
        except NotImplementedError:
            raise UnsupportedEquationsException("Cannot solve the given "
                                                "equations with this "
                                                "stateupdater.")
        if method_options['simplify']:
            propagator = propagator.applyfunc(
                lambda elem: sp.factor_terms(sp.cancel(sp.signsimp(elem))))
        shift = sp.ImmutableMatrix(propagator * offset) - offset
        state = sp.MatrixSymbol('_S', len(varnames), 1)
        updates = (propagator * state + shift).as_explicit()

        # Replace the _S[i, 0] placeholders with the state variable names.
        # Results go to temporaries first, because later updates still need
        # the old values of the state variables.
        temporary_assignments = []
        for variable, rhs in zip(varnames, updates):
            if rhs.has(I, re, im):
                raise UnsupportedEquationsException(
                    "The solution to the linear system "
                    "contains complex values "
                    "which is currently not implemented.")
            for row_idx, varname in enumerate(varnames):
                rhs = rhs.subs(state[row_idx, 0], varname)
            temporary_assignments.append(f"_{variable} = {sympy_to_str(rhs)}")

        final_assignments = [f"{variable} = _{variable}"
                             for variable in varnames]
        return '\n'.join(temporary_assignments + final_assignments)
Exemplo n.º 6
0
    def __call__(self, equations, variables=None):
        '''
        Generate abstract code that solves each equation independently.

        Every differential equation is solved analytically with
        ``sympy.dsolve``. The integration constant, if present, is fixed by
        the variable's current value, and the solution is then shifted by
        one time step ``dt``.

        Parameters
        ----------
        equations : `Equations`
            The model equations; stochastic equations are rejected.
        variables : dict, optional
            Variables used to resolve names in the equations (default: empty).

        Returns
        -------
        str
            One ``name = expression`` line per differential equation.

        Raises
        ------
        UnsupportedEquationsException
            For stochastic equations, equations without an explicit solution,
            or solutions whose integration constant cannot be determined.
        '''
        if equations.is_stochastic:
            raise UnsupportedEquationsException('Cannot solve stochastic '
                                                'equations with this state '
                                                'updater')
        if variables is None:
            variables = {}

        diff_eqs = equations.get_substituted_expressions(variables)

        t = Symbol('t', real=True, positive=True)
        dt = Symbol('dt', real=True, positive=True)
        t0 = Symbol('t0', real=True, positive=True)
        first_constant = Symbol('C1')
        second_constant = Symbol('C2')
        # TODO: Shortcut for simple linear equations? Is all this effort really
        #       worth it?

        def _advance_one_step(name, expression):
            # Solve dname/dt = expression and return the value after one step.
            # The real=True assumption is required, otherwise sympy does not
            # consider this symbol equal to the one in the equation.
            var = Symbol(name, real=True)
            f = sp.Function(name)
            f_of_t = f(t)
            rhs = str_to_sympy(expression.code, variables).subs(var, f_of_t)
            diff_eq = sp.Eq(sp.Derivative(f_of_t, t), rhs)

            # TODO: simplify=True sometimes fails with sympy 0.7.4, see:
            # https://github.com/sympy/sympy/issues/2666
            try:
                general_solution = sp.dsolve(diff_eq, f_of_t, simplify=True)
            except RuntimeError:
                general_solution = sp.dsolve(diff_eq, f_of_t, simplify=False)

            # Only explicit solutions of the form f(t) = ... are usable
            if not getattr(general_solution, 'lhs', None) == f_of_t:
                raise UnsupportedEquationsException(
                    'Cannot explicitly solve: ' + str(diff_eq))

            if not general_solution.has(first_constant):
                solution = general_solution.rhs.subs(t, t0).subs(f(t0), var)
            elif general_solution.has(second_constant):
                raise UnsupportedEquationsException(
                    'Too many constants in solution: %s' %
                    str(general_solution))
            else:
                candidates = sp.solve(general_solution, first_constant)
                if len(candidates) != 1:
                    raise UnsupportedEquationsException(
                        ("Couldn't solve for the constant "
                         "C1 in : %s ") % str(general_solution))
                # Express C1 through the value of the variable at time t0
                constant = candidates[0].subs(t, t0).subs(f(t0), var)
                solution = general_solution.rhs.subs('C1', constant)

            # Evaluate at t + dt, then rename t0 (time of the known value) to t
            solution = solution.subs(t, t + dt).subs(t0, t)
            # Simplification is best-effort only -- it sometimes raises
            try:
                solution = solution.simplify()
            except ValueError:
                pass
            return solution

        return '\n'.join([name + ' = ' +
                          sympy_to_str(_advance_one_step(name, expression))
                          for name, expression in diff_eqs])
Exemplo n.º 7
0
    def __call__(self, equations, variables=None, simplify=True):
        '''
        Generate abstract code that integrates a linear system exactly.

        The equations are brought into the form ``dX/dt = M*X + B`` (via
        ``get_linear_system``) and advanced over one time step with the
        matrix exponential ``exp(M*dt)``.

        Parameters
        ----------
        equations : `Equations`
            The model equations; stochastic equations are rejected.
        variables : dict, optional
            Variables used to resolve names in the equations (default: empty).
            If it contains ``'dt'``, the entries of M and B are checked for
            time dependence.
        simplify : bool, optional
            Whether to simplify the entries of the matrix exponential
            (default: ``True``).

        Returns
        -------
        str
            Abstract code that first stores every updated value in a
            temporary ``_<name>`` and then copies the temporaries back, or
            the empty string if there are no differential equations.

        Raises
        ------
        UnsupportedEquationsException
            For stochastic equations, time-dependent entries of M or B, or
            systems sympy cannot solve.
        '''
        if equations.is_stochastic:
            raise UnsupportedEquationsException('Cannot solve stochastic '
                                                'equations with this state '
                                                'updater.')
        if variables is None:
            variables = {}

        # Get a representation of the ODE system in the form of
        # dX/dt = M*X + B
        varnames, matrix, constants = get_linear_system(equations, variables)

        # No differential equations, nothing to do (this occurs sometimes in the
        # test suite where the whole model is nothing more than something like
        # 'v : 1')
        if matrix.shape == (0, 0):
            return ''

        # Make sure that M and B are constant over a time step. This will
        # raise an error if we meet the symbol "t" anywhere except as an
        # argument of a locally constant function (only checked if dt is known)
        if 'dt' in variables:
            dt_value = variables['dt'].get_value()[0]
            t = Symbol('t', real=True, positive=True)
            for entry in itertools.chain(matrix, constants):
                _check_for_locally_constant(entry, variables, dt_value, t)

        symbols = [Symbol(variable, real=True) for variable in varnames]
        solution = sp.solve_linear_system(matrix.row_join(constants), *symbols)
        if solution is None or set(symbols) != set(solution.keys()):
            raise UnsupportedEquationsException('Cannot solve the given '
                                                'equations with this '
                                                'stateupdater.')
        # Column vector with the same shape as the state vector _S below
        b = sp.ImmutableMatrix([solution[symbol] for symbol in symbols])

        # Solve the system
        dt = Symbol('dt', real=True, positive=True)
        try:
            A = (matrix * dt).exp()
        except NotImplementedError:
            raise UnsupportedEquationsException('Cannot solve the given '
                                                'equations with this '
                                                'stateupdater.')
        if simplify:
            A = A.applyfunc(
                lambda x: sp.factor_terms(sp.cancel(sp.signsimp(x))))
        # Plain matrix product instead of ``A.dot(b)``: a dot product with a
        # non-vector matrix is deprecated legacy behaviour in sympy
        C = sp.ImmutableMatrix(A * b) - b
        _S = sp.MatrixSymbol('_S', len(varnames), 1)
        updates = (A * _S + C).as_explicit()

        # The solution contains _S[0, 0], _S[1, 0] etc. for the state variables,
        # replace them with the state variable names
        abstract_code = []
        for variable, update in zip(varnames, updates):
            rhs = update
            for row_idx, varname in enumerate(varnames):
                rhs = rhs.subs(_S[row_idx, 0], varname)

            # Do not overwrite the real state variables yet, the update step
            # of other state variables might still need the original values
            abstract_code.append('_' + variable + ' = ' + sympy_to_str(rhs))

        # Update the state variables
        for variable in varnames:
            abstract_code.append(
                '{variable} = _{variable}'.format(variable=variable))
        return '\n'.join(abstract_code)