Example #1
    def __init__(self, v0, delay_len, before_t0=0., t0=0., dt=None):
        # size
        self.size = backend.shape(v0)

        # delay_len
        self.delay_len = delay_len
        self.dt = backend.get_dt() if dt is None else dt
        self.num_delay = int(math.ceil(delay_len / self.dt))

        # other variables
        self._delay_in = self.num_delay - 1
        self._delay_out = 0
        self.current_time = t0

        # before_t0
        self.before_t0 = before_t0

        # delay data
        self.data = backend.zeros((self.num_delay + 1, ) + self.size)
        if callable(before_t0):
            for i in range(self.num_delay):
                self.data[i] = before_t0(t0 + (i - self.num_delay) * self.dt)
        else:
            self.data[:-1] = before_t0
        self.data[-1] = v0
Example #2
def _base(A, B1, B2, C, f=None, tol=None, adaptive=None,
          dt=None, show_code=None, var_type=None):
    """

    Parameters
    ----------
    A : list
    B1 : list
    B2 : list
    C : list
    f : callable
    tol : float
    adaptive : bool
    dt : float
    show_code : bool
    var_type : str

    Returns
    -------

    """
    adaptive = False if (adaptive is None) else adaptive
    dt = backend.get_dt() if (dt is None) else dt
    tol = 0.1 if tol is None else tol
    show_code = False if show_code is None else show_code
    var_type = constants.POPU_VAR if var_type is None else var_type

    if f is None:
        return lambda f: adaptive_rk_wrapper(f, dt=dt, A=A, B1=B1, B2=B2, C=C, tol=tol,
                                             adaptive=adaptive, show_code=show_code,
                                             var_type=var_type)
    else:
        return adaptive_rk_wrapper(f, dt=dt, A=A, B1=B1, B2=B2, C=C, tol=tol,
                                   adaptive=adaptive, show_code=show_code,
                                   var_type=var_type)
Example #3
def ramp_current(c_start, c_end, duration, t_start=0, t_end=None, dt=None):
    """Get a linearly ramped input current.

    Parameters
    ----------
    c_start : float
        The minimum (or maximum) current size.
    c_end : float
        The maximum (or minimum) current size.
    duration : int, float
        The total duration.
    t_start : float
        The ramped current start time-point.
    t_end : float
        The ramped current end time-point. Defaults to ``duration``.
    dt : float
        The numerical precision. Defaults to ``backend.get_dt()``.

    Returns
    -------
    current : np.ndarray
        The formatted ramp current.
    """
    dt = backend.get_dt() if dt is None else dt
    t_end = duration if t_end is None else t_end

    current = np.zeros(int(np.ceil(duration / dt)))
    p1 = int(np.ceil(t_start / dt))
    p2 = int(np.ceil(t_end / dt))
    current[p1:p2] = np.linspace(c_start, c_end, p2 - p1)
    return current
Example #4
def rk2(f=None, show_code=None, dt=None, beta=None):
    """Runge–Kutta methods for ordinary differential equations.

    Generic second-order method.

    It has the characteristics of:

        - method stage = 2
        - method order = 2
        - Butcher Tables:

    .. math::

        \\begin{array}{c|cc}
            0 & 0 & 0 \\\\
            \\beta & \\beta & 0 \\\\
            \\hline & 1 - {1 \\over 2 \\beta} & {1 \\over 2 \\beta}
        \\end{array}
    """
    beta = 2 / 3 if beta is None else beta
    dt = backend.get_dt() if dt is None else dt
    show_code = False if show_code is None else show_code

    if f is None:
        return lambda f: wrapper_of_rk2(
            f, show_code=show_code, dt=dt, beta=beta)
    else:
        return wrapper_of_rk2(f, show_code=show_code, dt=dt, beta=beta)
Example #5
    def run(self, duration, inputs=(), report=False, report_percent=0.1):
        """The running function.

        Parameters
        ----------
        duration : float, int, tuple, list
            The running duration.
        inputs : list, tuple
            The model inputs with the format of ``[(key, value [operation])]``.
        report : bool
            Whether to report the running progress.
        report_percent : float
            The percent of progress to report.
        """

        # times
        # ------
        start, end = utils.check_duration(duration)
        times = backend.arange(start, end, backend.get_dt())
        run_length = backend.shape(times)[0]

        # build run function
        # ------------------
        self.run_func = self.build(inputs,
                                   inputs_is_formatted=False,
                                   mon_length=run_length,
                                   return_code=False)

        # run the model
        # -------------
        utils.run_model(self.run_func, times, report, report_percent)
        self.mon['ts'] = times
Example #6
def _wrap(wrapper, f, g, dt, sde_type, var_type, wiener_type, show_code, num_iter):
    """The base function to format an SRK method.

    Parameters
    ----------
    f : callable
        The drift function of the SDE.
    g : callable
        The diffusion function of the SDE.
    dt : float
        The numerical precision.
    sde_type : str
        "utils.ITO_SDE" : Ito's Stochastic Calculus.
        "utils.STRA_SDE" : Stratonovich's Stochastic Calculus.
    wiener_type : str
        The type of the Wiener process.
    var_type : str
        "scalar" : with the shape of ().
        "population" : with the shape of (N,) or (N1, N2) or (N1, N2, ...).
        "system": with the shape of (d, ), (d, N), or (d, N1, N2).
    show_code : bool
        Whether to show the formatted code.

    Returns
    -------
    numerical_func : callable
        The numerical function.
    """

    sde_type = constants.ITO_SDE if sde_type is None else sde_type
    assert sde_type in constants.SUPPORTED_SDE_TYPE, f'Currently, BrainPy only supports SDE types: ' \
                                                     f'{constants.SUPPORTED_SDE_TYPE}. But we got {sde_type}.'

    var_type = constants.POPU_VAR if var_type is None else var_type
    assert var_type in constants.SUPPORTED_VAR_TYPE, f'Currently, BrainPy only supports variable types: ' \
                                                     f'{constants.SUPPORTED_VAR_TYPE}. But we got {var_type}.'

    wiener_type = constants.SCALAR_WIENER if wiener_type is None else wiener_type
    assert wiener_type in constants.SUPPORTED_WIENER_TYPE, f'Currently, BrainPy only supports Wiener ' \
                                                           f'Process types: {constants.SUPPORTED_WIENER_TYPE}. ' \
                                                           f'But we got {wiener_type}.'

    show_code = False if show_code is None else show_code
    dt = backend.get_dt() if dt is None else dt
    num_iter = 10 if num_iter is None else num_iter

    if f is not None and g is not None:
        return wrapper(f=f, g=g, dt=dt, show_code=show_code, sde_type=sde_type,
                       var_type=var_type, wiener_type=wiener_type, num_iter=num_iter)

    elif f is not None:
        return lambda g: wrapper(f=f, g=g, dt=dt, show_code=show_code, sde_type=sde_type,
                                 var_type=var_type, wiener_type=wiener_type, num_iter=num_iter)

    elif g is not None:
        return lambda f: wrapper(f=f, g=g, dt=dt, show_code=show_code, sde_type=sde_type,
                                 var_type=var_type, wiener_type=wiener_type, num_iter=num_iter)

    else:
        raise ValueError('Must provide "f" or "g".')
Example #7
    def wrap(wrapper, f, g, dt, sde_type, var_type, wiener_type, show_code):
        """The base function to format an SRK method.

        Parameters
        ----------
        f : callable
            The drift function of the SDE.
        g : callable
            The diffusion function of the SDE.
        dt : float
            The numerical precision.
        sde_type : str
            "utils.ITO_SDE" : Ito's Stochastic Calculus.
            "utils.STRA_SDE" : Stratonovich's Stochastic Calculus.
        wiener_type : str
            The type of the Wiener process.
        var_type : str
            "scalar" : with the shape of ().
            "population" : with the shape of (N,) or (N1, N2) or (N1, N2, ...).
            "system": with the shape of (d, ), (d, N), or (d, N1, N2).
        show_code : bool
            Whether to show the formatted code.

        Returns
        -------
        numerical_func : callable
            The numerical function.
        """

        var_type = constants.POPU_VAR if var_type is None else var_type
        sde_type = constants.ITO_SDE if sde_type is None else sde_type
        wiener_type = constants.SCALAR_WIENER if wiener_type is None else wiener_type
        if var_type not in constants.SUPPORTED_VAR_TYPE:
            raise errors.IntegratorError(f'Currently, BrainPy only supports variable types: '
                                         f'{constants.SUPPORTED_VAR_TYPE}. But we got {var_type}.')
        if sde_type != constants.ITO_SDE:
            raise errors.IntegratorError(f'SRK method for SDEs with scalar noise only supports Ito SDE type, '
                                         f'but we got {sde_type} integral.')
        if wiener_type != constants.SCALAR_WIENER:
            raise errors.IntegratorError(f'SRK method for SDEs with scalar noise only supports scalar '
                                         f'Wiener Process, but we got "{wiener_type}" noise.')

        show_code = False if show_code is None else show_code
        dt = backend.get_dt() if dt is None else dt

        if f is not None and g is not None:
            return wrapper(f=f, g=g, dt=dt, show_code=show_code, sde_type=sde_type,
                           var_type=var_type, wiener_type=wiener_type)

        elif f is not None:
            return lambda g: wrapper(f=f, g=g, dt=dt, show_code=show_code, sde_type=sde_type,
                                     var_type=var_type, wiener_type=wiener_type)

        elif g is not None:
            return lambda f: wrapper(f=f, g=g, dt=dt, show_code=show_code, sde_type=sde_type,
                                     var_type=var_type, wiener_type=wiener_type)

        else:
            raise ValueError('Must provide "f" or "g".')
Example #8
def _base(A, B, C, f, show_code, dt):
    dt = backend.get_dt() if dt is None else dt
    show_code = False if show_code is None else show_code

    if f is None:
        return lambda f: rk_wrapper(
            f, show_code=show_code, dt=dt, A=A, B=B, C=C)
    else:
        return rk_wrapper(f, show_code=show_code, dt=dt, A=A, B=B, C=C)
Example #9
def cross_correlation(spikes, bin, dt=None):
    """Calculate cross correlation index between neurons.

    The coherence [1]_ between two neurons i and j is measured by their
    cross-correlation of spike trains at zero time lag within a time bin
    of :math:`\\Delta t = \\tau`. More specifically, suppose that a long
    time interval T is divided into small bins of :math:`\\Delta t` and
    that two spike trains are given by :math:`X(l)=` 0 or 1, :math:`Y(l)=` 0
    or 1, :math:`l=1,2, \\ldots, K(T / K=\\tau)`. Thus, we define a coherence
    measure for the pair as:

    .. math::

        \\kappa_{i j}(\\tau)=\\frac{\\sum_{l=1}^{K} X(l) Y(l)}
        {\\sqrt{\\sum_{l=1}^{K} X(l) \\sum_{l=1}^{K} Y(l)}}

    The population coherence measure :math:`\\kappa(\\tau)` is defined by the
    average of :math:`\\kappa_{i j}(\\tau)` over many pairs of neurons in the
    network.

    Parameters
    ----------
    spikes : np.ndarray
        The history of spike states of the neuron group.
        It can be easily obtained via `StateMonitor(neu, ['spike'])`.
    bin : float, int
        The time bin to normalize spike states.
    dt : float, optional
        The time precision.

    Returns
    -------
    cc_index : float
        The cross correlation value which represents the synchronization index.

    References
    ----------
    .. [1] Wang, Xiao-Jing, and György Buzsáki. "Gamma oscillation by synaptic
           inhibition in a hippocampal interneuronal network model." Journal of
           neuroscience 16.20 (1996): 6402-6413.
    """

    dt = backend.get_dt() if dt is None else dt
    bin_size = int(bin / dt)
    num_hist, num_neu = spikes.shape
    num_bin = int(np.ceil(num_hist / bin_size))
    if num_bin * bin_size != num_hist:
        spikes = np.append(spikes,
                           np.zeros((num_bin * bin_size - num_hist, num_neu)),
                           axis=0)
    states = spikes.T.reshape((num_neu, num_bin, bin_size))
    states = (np.sum(states, axis=2) > 0.).astype(np.float_)
    all_k = []
    for i in range(num_neu):
        for j in range(i + 1, num_neu):
            all_k.append(_cc(states, i, j))
    return np.mean(all_k)
Example #10
def exponential_euler(f, return_linear_term=False):
    dt = backend.get_dt()

    def int_f(x, t, *args):
        df, linear_part = f(x, t, *args)
        y = x + (backend.exp(linear_part * dt) - 1) / linear_part * df
        return y

    return int_f
Example #11
    def __init__(self, size, delay_time):
        if isinstance(size, int):
            size = (size, )
        self.size = tuple(size)
        self.delay_time = delay_time

        if isinstance(delay_time, (int, float)):
            self.uniform_delay = True
            self.delay_num_step = int(math.ceil(
                delay_time / backend.get_dt())) + 1
            self.delay_data = ops.zeros((self.delay_num_step, ) + self.size)
        else:
            if not len(self.size) == 1:
                raise NotImplementedError(
                    f'Currently, BrainPy only supports 1D heterogeneous delays; '
                    f'{len(self.size)}-dimensional heterogeneous delays are not implemented.'
                )
            self.num = size2len(size)
            if isinstance(delay_time, type(ops.as_tensor([1]))):
                assert ops.shape(delay_time) == self.size
            elif callable(delay_time):
                delay_time2 = ops.zeros(size)
                for i in range(size[0]):
                    delay_time2[i] = delay_time()
                delay_time = delay_time2
            else:
                raise NotImplementedError(
                    f'Currently, BrainPy does not support delay type '
                    f'of {type(delay_time)}: {delay_time}')
            self.uniform_delay = False
            delay = delay_time / backend.get_dt()
            dint = ops.as_tensor(delay_time / backend.get_dt(), dtype=int)
            ddiff = (delay - dint) >= 0.5
            self.delay_num_step = ops.as_tensor(delay + ddiff, dtype=int) + 1
            self.delay_data = ops.zeros((max(self.delay_num_step), ) + size)
            self.diag = ops.arange(self.num)

        self.delay_in_idx = self.delay_num_step - 1
        if self.uniform_delay:
            self.delay_out_idx = 0
        else:
            self.delay_out_idx = ops.zeros(self.num, dtype=int)
        self.name = None
Example #12
File: delay.py (project: FeynmanW/BrainPy)
    def __init__(self, size, delay_time):
        self.delay_time = delay_time
        self.delay_num_step = int(math.ceil(delay_time / backend.get_dt())) + 1
        self.delay_in_idx = 0
        self.delay_out_idx = self.delay_num_step - 1

        if isinstance(size, int):
            size = (size, )
        size = tuple(size)
        self.delay_data = backend.zeros((self.delay_num_step + 1, ) + size)
Example #13
def _base(A, B, C, f, show_code, dt, var_type, im_return):
    dt = backend.get_dt() if dt is None else dt
    show_code = False if show_code is None else show_code
    var_type = constants.SCALAR_VAR if var_type is None else var_type

    if f is None:
        return lambda f: general_rk_wrapper(f=f, show_code=show_code, dt=dt, A=A, B=B, C=C,
                                            var_type=var_type, im_return=im_return)
    else:
        return general_rk_wrapper(f=f, show_code=show_code, dt=dt, A=A, B=B, C=C,
                                  var_type=var_type, im_return=im_return)
Example #14
    def __init__(self, size, freqs, **kwargs):
        self.dt = backend.get_dt() / 1000.
        self.freqs = freqs
        self.size = (size,) if isinstance(size, int) else tuple(size)
        self.num = size2len(size)
        self.spike = ops.zeros(self.num, dtype=bool)
        self.t_last_spike = -1e7 * ops.ones(self.num)

        if backend.get_backend_name() == 'numba-cuda':
            super(PoissonInput, self).__init__(steps={'update': self.numba_cuda_update}, **kwargs)
        else:
            super(PoissonInput, self).__init__(steps={'update': self.non_numba_cuda_update}, **kwargs)
Example #15
def firing_rate(sp_matrix, width, window='gaussian'):
    """Calculate the mean firing rate of a neuron group.

    This method is adopted from Brian2.

    The firing rate in trial :math:`k` is the spike count :math:`n_{k}^{sp}`
    in an interval of duration :math:`T` divided by :math:`T`:

    .. math::

        v_k = {n_k^{sp} \\over T}

    Parameters
    ----------
    sp_matrix : bnp.ndarray
        The spike matrix which record spiking activities.
    width : int, float
        The width of the ``window`` in milliseconds.
    window : str
        The window to use for smoothing. It can be a string to choose a
        predefined window:

        - `flat`: a rectangular window,
        - `gaussian`: a Gaussian-shaped window.

        For the `Gaussian` window, the `width` parameter specifies the
        standard deviation of the Gaussian; the width of the actual window
        is `4 * width + dt`.
        For the `flat` window, the width of the actual window
        is `2 * width/2 + dt`.

    Returns
    -------
    rate : numpy.ndarray
        The population rate in Hz, smoothed with the given window.
    """
    # rate
    rate = np.sum(sp_matrix, axis=1)

    # window
    dt = backend.get_dt()
    if window == 'gaussian':
        width1 = 2 * width / dt
        width2 = int(np.around(width1))
        window = np.exp(-np.arange(-width2, width2 + 1)**2 / (width1**2 / 2))
    elif window == 'flat':
        width1 = int(width / 2 / dt) * 2 + 1
        window = np.ones(width1)
    else:
        raise ValueError('Unknown window type "{}".'.format(window))
    window = np.float_(window)

    return np.convolve(rate, window / sum(window), mode='same')
Example #16
def period_input(values, durations, dt=None, return_length=False):
    """Format an input current with different periods.

    For example:

    If you want to get an input whose amplitude is 0 between 0-100 ms
    and 1 between 100-200 ms:
    >>> import numpy as np
    >>> period_input(values=[0, 1],
    >>>              durations=[100, 100])

    Parameters
    ----------
    values : list, np.ndarray
        The current values for each period duration.
    durations : list, np.ndarray
        The duration for each period.
    dt : float
        The numerical precision. Defaults to ``backend.get_dt()``.
    return_length : bool
        Return the final duration length.

    Returns
    -------
    current_and_duration : np.ndarray, tuple
        The formatted current, or ``(current, total_duration)`` when
        ``return_length=True``.
    """
    assert len(durations) == len(values), f'"values" and "durations" must be the same length, while ' \
                                          f'we got {len(values)} != {len(durations)}.'

    dt = backend.get_dt() if dt is None else dt

    # get input current shape, and duration
    I_duration = sum(durations)
    I_shape = ()
    for val in values:
        shape = ops.shape(val)
        if len(shape) > len(I_shape):
            I_shape = shape

    # get the current
    start = 0
    I_current = ops.zeros((int(math.ceil(I_duration / dt)),) + I_shape)
    for c_size, duration in zip(values, durations):
        length = int(duration / dt)
        I_current[start: start + length] = c_size
        start += length

    if return_length:
        return I_current, I_duration
    else:
        return I_current
Example #17
def exponential_euler(f):
    dt = backend.get_dt()
    dt_sqrt = dt**0.5

    def int_f(x, t, *args):
        df, linear_part, g = f(x, t, *args)
        dW = backend.normal(0., 1., backend.shape(x))
        dg = dt_sqrt * g * dW
        exp = backend.exp(linear_part * dt)
        y1 = x + (exp - 1) / linear_part * df + exp * dg
        return y1

    return int_f
Example #18
def spike_current(points, lengths, sizes, duration, dt=None):
    """Format current input like a series of short-time spikes.

    For example:

    If you want to generate a spike train in which spikes arrive at 10 ms,
    20 ms, 30 ms, 200 ms, and 300 ms, each spike lasts 1 ms, and the spike
    current is 0.5, then you can use the following function:

    >>> spike_current(points=[10, 20, 30, 200, 300],
    >>>               lengths=1.,  # can be a list to specify the spike length at each point
    >>>               sizes=0.5,  # can be a list to specify the current size at each point
    >>>               duration=400.)

    Parameters
    ----------
    points : list, tuple
        The spike time-points. Must be an iterable object.
    lengths : int, float, list, tuple
        The length of each point-current, mimicking the spike durations.
    sizes : int, float, list, tuple
        The current sizes.
    duration : int, float
        The total current duration.
    dt : float
        The numerical precision. Defaults to ``backend.get_dt()``.

    Returns
    -------
    current : np.ndarray
        The formatted spike current.
    """
    dt = backend.get_dt() if dt is None else dt
    assert isinstance(points, (list, tuple))
    if isinstance(lengths, (float, int)):
        lengths = [lengths] * len(points)
    if isinstance(sizes, (float, int)):
        sizes = [sizes] * len(points)

    current = np.zeros(int(np.ceil(duration / dt)))
    for time, dur, size in zip(points, lengths, sizes):
        pp = int(time / dt)
        p_len = int(dur / dt)
        current[pp:pp + p_len] = size
    return current
Example #19
def constant_current(Iext, dt=None):
    """Format constant input in durations.

    For example:

    If you want to get an input whose amplitude is 0 between 0-100 ms
    and 1 between 100-200 ms:
    >>> constant_current([(0, 100), (1, 100)])
    >>> constant_current([(np.zeros(100), 100), (np.random.rand(100), 100)])

    Parameters
    ----------
    Iext : list
        This parameter receives the current size and the current
        duration pairs, like `[(size1, duration1), (size2, duration2)]`.
    dt : float
        The numerical precision. Defaults to ``backend.get_dt()``.

    Returns
    -------
    current_and_duration : tuple
        (The formatted current, total duration)
    """
    dt = backend.get_dt() if dt is None else dt

    # get input current dimension, shape, and duration
    I_duration = 0.
    I_dim = 0
    I_shape = ()
    for I in Iext:
        I_duration += I[1]
        dim = np.ndim(I[0])
        if dim > I_dim:
            I_dim = dim
            I_shape = np.shape(I[0])

    # get the current
    I_current = np.zeros((int(np.ceil(I_duration / dt)), ) + I_shape)
    start = 0
    for c_size, duration in Iext:
        length = int(duration / dt)
        I_current[start:start + length] = c_size
        start += length
    return I_current, I_duration
Example #20
def exponential_euler(f=None, show_code=None, dt=None, var_type=None, im_return=()):
    """First order, explicit exponential Euler method.

    For an ODE equation of the form

    .. math::

        y^{\\prime}=f(y), \\quad y(0)=y_{0}

    its schema is given by

    .. math::

        y_{n+1}= y_{n}+h \\varphi(hA) f (y_{n})

    where :math:`A=f^{\\prime}(y_{n})` and :math:`\\varphi(z)=\\frac{e^{z}-1}{z}`.

    For linear ODE system: :math:`y^{\\prime} = Ay + B`,
    the above equation is equal to

    .. math::

        y_{n+1}= y_{n}e^{hA}-B/A(1-e^{hA})

    Parameters
    ----------
    f : callable
        The derivative function.
    show_code : bool
        Whether to show the formatted code.
    dt : float
        The numerical precision.
    var_type : str
        The variable type.
    im_return : tuple
        Intermediate values to return.

    Returns
    -------
    func : callable
        The one-step numerical integrator function.
    """

    dt = backend.get_dt() if dt is None else dt
    show_code = False if show_code is None else show_code
    var_type = constants.SCALAR_VAR if var_type is None else var_type

    if f is None:
        return lambda f: exp_euler_wrapper(f, show_code=show_code, dt=dt,
                                           var_type=var_type, im_return=im_return)
    else:
        return exp_euler_wrapper(f, show_code=show_code, dt=dt,
                                 var_type=var_type, im_return=im_return)
Example #21
    def run(self, duration, inputs=(), report=False, report_percent=0.1):
        """Run the simulation for the given duration.

        This function provides the most convenient way to run the network.

        Parameters
        ----------
        duration : int, float, tuple, list
            The amount of simulation time to run for.
        inputs : list, tuple
            The receivers, external inputs and durations.
        report : bool
            Report the progress of the simulation.
        report_percent : float
            The fraction of the run between progress reports.
        """
        # preparation
        start, end = utils.check_duration(duration)
        dt = backend.get_dt()
        ts = backend.arange(start, end, dt)

        # build the network
        run_length = ts.shape[0]
        format_inputs = utils.format_net_level_inputs(inputs, run_length)
        net_runner = backend.get_net_runner()(all_nodes=self.all_nodes)
        self.run_func = net_runner.build(run_length=run_length,
                                         formatted_inputs=format_inputs,
                                         return_code=False,
                                         show_code=self.show_code)

        # run the network
        utils.run_model(self.run_func,
                        times=ts,
                        report=report,
                        report_percent=report_percent)

        # end
        self.t_start, self.t_end = start, end
        for obj in self.all_nodes.values():
            if len(obj.mon['vars']) > 0:
                obj.mon['ts'] = ts
Example #22
    def run(self, duration, report=False, report_percent=0.1):
        if isinstance(duration, (int, float)):
            duration = [0, duration]
        elif isinstance(duration, (tuple, list)):
            assert len(duration) == 2
            duration = tuple(duration)
        else:
            raise ValueError

        # get the times
        times = ops.arange(duration[0], duration[1], backend.get_dt())
        # reshape the monitor
        for key in self.mon.keys():
            self.mon[key] = ops.zeros((len(times), ) +
                                      ops.shape(self.mon[key])[1:])
        # run the model
        run_model(run_func=self.run_func,
                  times=times,
                  report=report,
                  report_percent=report_percent)
Example #23
File: utils.py (project: yult0821/BrainPy)
def run_model(run_func, times, report, report_percent):
    """Run the model.

    The "run_func" can be the step run function of a population, or a network.

    Parameters
    ----------
    run_func : callable
        The step run function.
    times : iterable
        The model running times.
    report : bool
        Whether to report the running progress.
    report_percent : float
        The percent of the total running length for each report.
    """
    run_length = len(times)
    dt = backend.get_dt()
    if report:
        t0 = time.time()
        for i, t in enumerate(times[:1]):
            run_func(_t=t, _i=i, _dt=dt)
        compile_time = time.time() - t0
        print('Compilation used {:.4f} s.'.format(compile_time))

        print("Start running ...")
        report_gap = int(run_length * report_percent)
        t0 = time.time()
        for run_idx in range(1, run_length):
            run_func(_t=times[run_idx], _i=run_idx, _dt=dt)
            if (run_idx + 1) % report_gap == 0:
                percent = (run_idx + 1) / run_length * 100
                print('Run {:.1f}% used {:.3f} s.'.format(percent, time.time() - t0))
        running_time = time.time() - t0
        print('Simulation is done in {:.3f} s.'.format(running_time))
        print()
        return running_time
    else:
        for run_idx in range(run_length):
            run_func(_t=times[run_idx], _i=run_idx, _dt=dt)
        return None
Example #24
def _base(A,
          B1,
          B2,
          C,
          f=None,
          tol=None,
          adaptive=None,
          im_return=(),
          dt=None,
          show_code=None,
          var_type=None):
    """

    Parameters
    ----------
    A : list
    B1 : list
    B2 : list
    C : list
    f : callable
    tol : float
    adaptive : bool
    im_return : list
        Intermediate value return.
    dt : float
    show_code : bool
    var_type : str

    Returns
    -------

    """
    adaptive = False if (adaptive is None) else adaptive
    dt = backend.get_dt() if (dt is None) else dt
    tol = 0.1 if tol is None else tol
    show_code = False if show_code is None else show_code
    var_type = constants.SCALAR_VAR if var_type is None else var_type

    if f is None:
        return lambda f: adaptive_rk_wrapper(f,
                                             dt=dt,
                                             A=A,
                                             B1=B1,
                                             B2=B2,
                                             C=C,
                                             tol=tol,
                                             adaptive=adaptive,
                                             show_code=show_code,
                                             var_type=var_type,
                                             im_return=im_return)
    else:
        return adaptive_rk_wrapper(f,
                                   dt=dt,
                                   A=A,
                                   B1=B1,
                                   B2=B2,
                                   C=C,
                                   tol=tol,
                                   adaptive=adaptive,
                                   show_code=show_code,
                                   var_type=var_type,
                                   im_return=im_return)
Example #25
    def ts(self):
        """Get the time points of the network."""
        return backend.arange(self.t_start, self.t_end, backend.get_dt())
Example #26
File: plots.py (project: yult0821/BrainPy)
def animate_2D(values,
               net_size,
               dt=None,
               val_min=None,
               val_max=None,
               cmap=None,
               frame_delay=10,
               frame_step=1,
               title_size=10,
               figsize=None,
               gif_dpi=None,
               video_fps=None,
               save_path=None,
               show=True):
    """Animate the potentials of the neuron group.

    Parameters
    ----------
    values : np.ndarray
        The membrane potentials of the neuron group.
    net_size : tuple
        The size of the neuron group.
    dt : float
        The time duration of each step.
    val_min : float, int
        The minimum of the potential.
    val_max : float, int
        The maximum of the potential.
    cmap : str
        The colormap.
    frame_delay : int, float
        The delay to show each frame.
    frame_step : int
        The step to show the potential. If `frame_step=3`, then only
        every third step is shown.
    title_size : int
        The size of the title.
    figsize : None, tuple
        The size of the figure.
    gif_dpi : int
        Controls the dots per inch for the movie frames. This combined with
        the figure's size in inches controls the size of the movie. If
        ``None``, use defaults in matplotlib.
    video_fps : int
        Frames per second in the movie. Defaults to ``None``, which will use
        the animation's specified interval to set the frames per second.
    save_path : None, str
        The save path of the animation.
    show : bool
        Whether to show the animation.

    Returns
    -------
    anim : animation.FuncAnimation
        The created animation function.
    """
    dt = backend.get_dt() if dt is None else dt
    num_step, num_neuron = values.shape
    height, width = net_size
    val_min = values.min() if val_min is None else val_min
    val_max = values.max() if val_max is None else val_max

    figsize = figsize or (6, 6)

    fig = plt.figure(figsize=(figsize[0], figsize[1]), constrained_layout=True)
    gs = GridSpec(1, 1, figure=fig)
    fig.add_subplot(gs[0, 0])

    def frame(t):
        img = values[t]
        fig.clf()
        plt.pcolor(img, cmap=cmap, vmin=val_min, vmax=val_max)
        plt.colorbar()
        plt.axis('off')
        fig.suptitle(t="Time: {:.2f} ms".format((t + 1) * dt),
                     fontsize=title_size,
                     fontweight='bold')
        return [fig.gca()]

    values = values.reshape((num_step, height, width))
    anim = animation.FuncAnimation(fig=fig,
                                   func=frame,
                                   frames=list(range(1, num_step, frame_step)),
                                   init_func=None,
                                   interval=frame_delay,
                                   repeat_delay=3000)
    if save_path is None:
        if show:
            plt.show()
    else:
        if save_path[-3:] == 'gif':
            anim.save(save_path, dpi=gif_dpi, writer='imagemagick')
        elif save_path[-3:] == 'mp4':
            anim.save(save_path, writer='ffmpeg', fps=video_fps, bitrate=3000)
        else:
            anim.save(save_path + '.mp4', writer='ffmpeg', fps=video_fps, bitrate=3000)
    return anim
Example #27
    def plot_trajectory(self,
                        initials,
                        duration,
                        plot_duration=None,
                        show=False):
        """Plot trajectories according to the settings.

        Parameters
        ----------
        initials : list, tuple
            The initial value setting of the targets. It can be a tuple/list of floats to specify
            each value of dynamical variables (for example, ``(a, b)``). It can also be a
            tuple/list of tuple to specify multiple initial values (for example,
            ``[(a1, b1), (a2, b2)]``).
        duration : int, float, tuple, list
            The running duration. Same with the ``duration`` in ``NeuGroup.run()``.
            It can be an int/float (``t_end``) to specify the same running end time,
            or it can be a tuple/list of int/float (``(t_start, t_end)``) to specify
            the start and end simulation time. Or, it can be a list of tuple
            (``[(t1_start, t1_end), (t2_start, t2_end)]``) to specify the specific
            start and end simulation time for each initial value.
        plot_duration : tuple/list of tuple, optional
            The duration to plot. It can be a tuple with ``(start, end)``. It can
            also be a list of tuple ``[(start1, end1), (start2, end2)]`` to specify
            the plot duration for each initial value running.
        show : bool
            Whether to show the figure.
        """
        print('plot trajectory ...')

        # 1. format the initial values
        all_vars = self.fast_var_names + self.slow_var_names
        if isinstance(initials, dict):
            initials = [initials]
        elif isinstance(initials, (list, tuple)):
            if isinstance(initials[0], (int, float)):
                initials = [{all_vars[i]: v for i, v in enumerate(initials)}]
            elif isinstance(initials[0], dict):
                initials = initials
            elif isinstance(initials[0], (tuple, list)) and isinstance(
                    initials[0][0], (int, float)):
                initials = [{all_vars[i]: v
                             for i, v in enumerate(init)} for init in initials]
            else:
                raise ValueError
        else:
            raise ValueError
        for initial in initials:
            if len(initial) != len(all_vars):
                raise errors.AnalyzerError(
                    f'Should provide initial values for all fast-slow variables '
                    f'({all_vars}), but we only got initial values for '
                    f'variables {list(initial.keys())}.')

        # 2. format the running duration
        if isinstance(duration, (int, float)):
            duration = [(0, duration) for _ in range(len(initials))]
        elif isinstance(duration[0], (int, float)):
            duration = [duration for _ in range(len(initials))]
        else:
            assert len(duration) == len(initials)

        # 3. format the plot duration
        if plot_duration is None:
            plot_duration = duration
        if isinstance(plot_duration[0], (int, float)):
            plot_duration = [plot_duration for _ in range(len(initials))]
        else:
            assert len(plot_duration) == len(initials)

        # 5. run the network
        for init_i, initial in enumerate(initials):
            traj_group = Trajectory(size=1,
                                    integrals=self.model.integrals,
                                    target_vars=initial,
                                    fixed_vars=self.fixed_vars,
                                    pars_update=self.pars_update,
                                    scope=self.model.scopes)
            traj_group.run(duration=duration[init_i], report=False)

            #   5.3 legend
            legend = f'$traj_{init_i}$: '
            for key in all_vars:
                legend += f'{key}={initial[key]}, '
            legend = legend[:-2]

            #   5.4 trajectory
            start = int(plot_duration[init_i][0] / backend.get_dt())
            end = int(plot_duration[init_i][1] / backend.get_dt())

            #   5.5 visualization
            for var_name in self.fast_var_names:
                s0 = traj_group.mon[self.slow_var_names[0]][start:end, 0]
                fast = traj_group.mon[var_name][start:end, 0]

                fig = plt.figure(var_name)
                if len(self.slow_var_names) == 1:
                    lines = plt.plot(s0, fast, label=legend)
                    utils.add_arrow(lines[0])
                    # middle = int(s0.shape[0] / 2)
                    # plt.arrow(s0[middle], fast[middle],
                    #           s0[middle + 1] - s0[middle], fast[middle + 1] - fast[middle],
                    #           shape='full')

                elif len(self.slow_var_names) == 2:
                    fig.gca(projection='3d')
                    s1 = traj_group.mon[self.slow_var_names[1]][start:end, 0]
                    plt.plot(s0, s1, fast, label=legend)
                else:
                    raise errors.AnalyzerError

        # 6. visualization
        for var_name in self.fast_vars.keys():
            fig = plt.figure(var_name)

            # scale = (self.lim_scale - 1.) / 2
            if len(self.slow_var_names) == 1:
                # plt.xlim(*utils.rescale(self.slow_vars[self.slow_var_names[0]], scale=scale))
                # plt.ylim(*utils.rescale(self.fast_vars[var_name], scale=scale))
                plt.xlabel(self.slow_var_names[0])
                plt.ylabel(var_name)
            elif len(self.slow_var_names) == 2:
                ax = fig.gca(projection='3d')
                # ax.set_xlim(*utils.rescale(self.slow_vars[self.slow_var_names[0]], scale=scale))
                # ax.set_ylim(*utils.rescale(self.slow_vars[self.slow_var_names[1]], scale=scale))
                # ax.set_zlim(*utils.rescale(self.fast_vars[var_name], scale=scale))
                ax.set_xlabel(self.slow_var_names[0])
                ax.set_ylabel(self.slow_var_names[1])
                ax.set_zlabel(var_name)

            plt.legend()

        if show:
            plt.show()
Example #28
    def plot_trajectory(self,
                        initials,
                        duration,
                        plot_duration=None,
                        axes='v-v',
                        show=False):
        """Plot trajectories according to the settings.

        Parameters
        ----------
        initials : list, tuple, dict
            The initial value setting of the targets. It can be a tuple/list of floats to specify
            each value of dynamical variables (for example, ``(a, b)``). It can also be a
            tuple/list of tuple to specify multiple initial values (for example,
            ``[(a1, b1), (a2, b2)]``).
        duration : int, float, tuple, list
            The running duration. Same with the ``duration`` in ``NeuGroup.run()``.
            It can be an int/float (``t_end``) to specify the same running end time,
            or it can be a tuple/list of int/float (``(t_start, t_end)``) to specify
            the start and end simulation time. Or, it can be a list of tuple
            (``[(t1_start, t1_end), (t2_start, t2_end)]``) to specify the specific
            start and end simulation time for each initial value.
        plot_duration : tuple, list, optional
            The duration to plot. It can be a tuple with ``(start, end)``. It can
            also be a list of tuple ``[(start1, end1), (start2, end2)]`` to specify
            the plot duration for each initial value running.
        axes : str
            The axes to plot. It can be:

                 - 'v-v'
                        Plot the trajectory in the 'x_var'-'y_var' axis.
                 - 't-v'
                        Plot the trajectory in the 'time'-'var' axis.
        show : bool
            Whether to show the figure.
        """

        print('plot trajectory ...')

        if axes not in ['v-v', 't-v']:
            raise errors.ModelUseError(
                f'Unknown axes "{axes}", only support "v-v" and "t-v".')

        # 1. format the initial values
        if isinstance(initials, dict):
            initials = [initials]
        elif isinstance(initials, (list, tuple)):
            if isinstance(initials[0], (int, float)):
                initials = [{
                    self.dvar_names[i]: v
                    for i, v in enumerate(initials)
                }]
            elif isinstance(initials[0], dict):
                initials = initials
            elif isinstance(initials[0], (tuple, list)) and isinstance(
                    initials[0][0], (int, float)):
                initials = [{
                    self.dvar_names[i]: v
                    for i, v in enumerate(init)
                } for init in initials]
            else:
                raise ValueError
        else:
            raise ValueError

        # 2. format the running duration
        if isinstance(duration, (int, float)):
            duration = [(0, duration) for _ in range(len(initials))]
        elif isinstance(duration[0], (int, float)):
            duration = [duration for _ in range(len(initials))]
        else:
            assert len(duration) == len(initials)

        # 3. format the plot duration
        if plot_duration is None:
            plot_duration = duration
        if isinstance(plot_duration[0], (int, float)):
            plot_duration = [plot_duration for _ in range(len(initials))]
        else:
            assert len(plot_duration) == len(initials)

        # 5. run the network
        for init_i, initial in enumerate(initials):
            traj_group = Trajectory(size=1,
                                    integrals=self.model.integrals,
                                    target_vars=initial,
                                    fixed_vars=self.fixed_vars,
                                    pars_update=self.pars_update,
                                    scope=self.model.scopes)

            #   5.2 run the model
            traj_group.run(
                duration=duration[init_i],
                report=False,
            )

            #   5.3 legend
            legend = f'$traj_{init_i}$: '
            for key in self.dvar_names:
                legend += f'{key}={initial[key]}, '
            legend = legend[:-2]

            #   5.4 trajectory
            start = int(plot_duration[init_i][0] / backend.get_dt())
            end = int(plot_duration[init_i][1] / backend.get_dt())

            #   5.5 visualization
            if axes == 'v-v':
                lines = plt.plot(traj_group.mon[self.x_var][start:end, 0],
                                 traj_group.mon[self.y_var][start:end, 0],
                                 label=legend)
                utils.add_arrow(lines[0])
            else:
                plt.plot(traj_group.mon.ts[start:end],
                         traj_group.mon[self.x_var][start:end, 0],
                         label=legend + f', {self.x_var}')
                plt.plot(traj_group.mon.ts[start:end],
                         traj_group.mon[self.y_var][start:end, 0],
                         label=legend + f', {self.y_var}')

        # 6. visualization
        if axes == 'v-v':
            plt.xlabel(self.x_var)
            plt.ylabel(self.y_var)
            scale = (self.options.lim_scale - 1.) / 2
            plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
            plt.ylim(*utils.rescale(self.target_vars[self.y_var], scale=scale))
            plt.legend()
        else:
            plt.legend(title='Initial values')

        if show:
            plt.show()
Example #29
File: plots.py (project: yult0821/BrainPy)
def animate_1D(dynamical_vars,
               static_vars=(),
               dt=None,
               xlim=None,
               ylim=None,
               xlabel=None,
               ylabel=None,
               frame_delay=50.,
               frame_step=1,
               title_size=10,
               figsize=None,
               gif_dpi=None,
               video_fps=None,
               save_path=None,
               show=True):
    """Animation of one-dimensional data.

    Parameters
    ----------
    dynamical_vars : dict, np.ndarray, list of np.ndarray, list of dict
        The dynamical variables which will be animated.
    static_vars : dict, np.ndarray, list of np.ndarray, list of dict
        The static variables.
    dt : float
        The numerical integration step.
    xlim : tuple
        The xlim.
    ylim : tuple
        The ylim.
    xlabel : str
        The xlabel.
    ylabel : str
        The ylabel.
    frame_delay : int, float
        The delay to show each frame.
    frame_step : int
        The step to show the potential. If `frame_step=3`, then only
        every third step is shown.
    title_size : int
        The size of the title.
    figsize : None, tuple
        The size of the figure.
    gif_dpi : int
        Controls the dots per inch for the movie frames. This combined with
        the figure's size in inches controls the size of the movie. If
        ``None``, use defaults in matplotlib.
    video_fps : int
        Frames per second in the movie. Defaults to ``None``, which will use
        the animation's specified interval to set the frames per second.
    save_path : None, str
        The save path of the animation.
    show : bool
        Whether to show the animation.

    Returns
    -------
    figure : plt.figure
        The created figure instance.
    """

    # check dt
    dt = backend.get_dt() if dt is None else dt

    # check figure
    fig = plt.figure(figsize=(figsize or (6, 6)), constrained_layout=True)
    gs = GridSpec(1, 1, figure=fig)
    fig.add_subplot(gs[0, 0])

    # check dynamical variables
    final_dynamic_vars = []
    lengths = []
    has_legend = False
    if isinstance(dynamical_vars, (tuple, list)):
        for var in dynamical_vars:
            if isinstance(var, dict):
                assert 'ys' in var, 'Must provide "ys" item.'
                if 'legend' not in var:
                    var['legend'] = None
                else:
                    has_legend = True
                if 'xs' not in var:
                    var['xs'] = np.arange(var['ys'].shape[1])
            elif isinstance(var, np.ndarray):
                var = {'ys': var,
                       'xs': np.arange(var.shape[1]),
                       'legend': None}
            else:
                raise ValueError(f'Unknown data type: {type(var)}')
            assert np.ndim(var['ys']) == 2, "Dynamic variable must be 2D data."
            lengths.append(var['ys'].shape[0])
            final_dynamic_vars.append(var)
    elif isinstance(dynamical_vars, np.ndarray):
        assert np.ndim(dynamical_vars) == 2, "Dynamic variable must be 2D data."
        lengths.append(dynamical_vars.shape[0])
        final_dynamic_vars.append({'ys': dynamical_vars,
                                   'xs': np.arange(dynamical_vars.shape[1]),
                                   'legend': None})
    elif isinstance(dynamical_vars, dict):
        assert 'ys' in dynamical_vars, 'Must provide "ys" item.'
        if 'legend' not in dynamical_vars:
            dynamical_vars['legend'] = None
        else:
            has_legend = True
        if 'xs' not in dynamical_vars:
            dynamical_vars['xs'] = np.arange(dynamical_vars['ys'].shape[1])
        lengths.append(dynamical_vars['ys'].shape[0])
        final_dynamic_vars.append(dynamical_vars)
    else:
        raise ValueError(f'Unknown dynamical data type: {type(dynamical_vars)}')
    lengths = np.array(lengths)
    assert np.all(lengths == lengths[0]), 'Dynamic variables must have equal length.'

    # check static variables
    final_static_vars = []
    if isinstance(static_vars, (tuple, list)):
        for var in static_vars:
            # normalize static variables to the same 'xs'/'ys' layout as
            # dynamic ones, so the ylim computation and frame() below work
            if isinstance(var, dict):
                assert 'ys' in var, 'Must provide "ys" item.'
                if 'legend' not in var:
                    var['legend'] = None
                else:
                    has_legend = True
                if 'xs' not in var:
                    var['xs'] = np.arange(var['ys'].shape[0])
            elif isinstance(var, np.ndarray):
                var = {'ys': var, 'xs': np.arange(var.shape[0]), 'legend': None}
            else:
                raise ValueError(f'Unknown data type: {type(var)}')
            assert np.ndim(var['ys']) == 1, "Static variable must be 1D data."
            final_static_vars.append(var)
    elif isinstance(static_vars, np.ndarray):
        final_static_vars.append({'ys': static_vars,
                                  'xs': np.arange(static_vars.shape[0]),
                                  'legend': None})
    elif isinstance(static_vars, dict):
        assert 'ys' in static_vars, 'Must provide "ys" item.'
        if 'legend' not in static_vars:
            static_vars['legend'] = None
        else:
            has_legend = True
        if 'xs' not in static_vars:
            static_vars['xs'] = np.arange(static_vars['ys'].shape[0])
        final_static_vars.append(static_vars)

    else:
        raise ValueError(f'Unknown static data type: {type(static_vars)}')

    # ylim
    if ylim is None:
        ylim_min = np.inf
        ylim_max = -np.inf
        for var in final_dynamic_vars + final_static_vars:
            if var['ys'].max() > ylim_max:
                ylim_max = var['ys'].max()
            if var['ys'].min() < ylim_min:
                ylim_min = var['ys'].min()
        if ylim_min > 0:
            ylim_min = ylim_min * 0.98
        else:
            ylim_min = ylim_min * 1.02
        if ylim_max > 0:
            ylim_max = ylim_max * 1.02
        else:
            ylim_max = ylim_max * 0.98
        ylim = (ylim_min, ylim_max)

    def frame(t):
        fig.clf()
        for dvar in final_dynamic_vars:
            plt.plot(dvar['xs'], dvar['ys'][t], label=dvar['legend'])
        for svar in final_static_vars:
            plt.plot(svar['xs'], svar['ys'], label=svar['legend'])
        if xlim is not None:
            plt.xlim(xlim[0], xlim[1])
        if has_legend:
            plt.legend()
        if xlabel:
            plt.xlabel(xlabel)
        if ylabel:
            plt.ylabel(ylabel)
        plt.ylim(ylim[0], ylim[1])
        fig.suptitle(t="Time: {:.2f} ms".format((t + 1) * dt),
                     fontsize=title_size,
                     fontweight='bold')
        return [fig.gca()]

    anim_result = animation.FuncAnimation(fig=fig,
                                          func=frame,
                                          frames=range(1, lengths[0], frame_step),
                                          init_func=None,
                                          interval=frame_delay,
                                          repeat_delay=3000)

    # save or show
    if save_path is None:
        if show: plt.show()
    else:
        if save_path[-3:] == 'gif':
            anim_result.save(save_path, dpi=gif_dpi, writer='imagemagick')
        elif save_path[-3:] == 'mp4':
            anim_result.save(save_path, writer='ffmpeg', fps=video_fps, bitrate=3000)
        else:
            anim_result.save(save_path + '.mp4', writer='ffmpeg', fps=video_fps, bitrate=3000)
    return fig
Example #30
    def get_integral_step(diff_eq, *args):
        dt = backend.get_dt()
        f_expressions = diff_eq.get_f_expressions(
            substitute_vars=diff_eq.var_name)

        # code lines
        code_lines = [str(expr) for expr in f_expressions[:-1]]

        # get the linear system using sympy
        f_res = f_expressions[-1]
        df_expr = ast_analysis.str2sympy(f_res.code).expr.expand()
        s_df = sympy.Symbol(f"{f_res.var_name}")
        code_lines.append(f'{s_df.name} = {ast_analysis.sympy2str(df_expr)}')
        var = sympy.Symbol(diff_eq.var_name, real=True)

        # get df part
        s_linear = sympy.Symbol(f'_{diff_eq.var_name}_linear')
        s_linear_exp = sympy.Symbol(f'_{diff_eq.var_name}_linear_exp')
        s_df_part = sympy.Symbol(f'_{diff_eq.var_name}_df_part')
        if df_expr.has(var):
            # linear
            linear = sympy.collect(df_expr, var, evaluate=False)[var]
            code_lines.append(
                f'{s_linear.name} = {ast_analysis.sympy2str(linear)}')
            # linear exponential
            linear_exp = sympy.exp(linear * dt)
            code_lines.append(
                f'{s_linear_exp.name} = {ast_analysis.sympy2str(linear_exp)}')
            # df part
            df_part = (s_linear_exp - 1) / s_linear * s_df
            code_lines.append(
                f'{s_df_part.name} = {ast_analysis.sympy2str(df_part)}')

        else:
            # linear exponential
            code_lines.append(f'{s_linear_exp.name} = sqrt({dt})')
            # df part
            code_lines.append(
                f'{s_df_part.name} = {ast_analysis.sympy2str(dt * s_df)}')

        # get dg part
        if diff_eq.is_stochastic:
            # dW
            noise = f'_normal_like_({diff_eq.var_name})'
            code_lines.append(f'_{diff_eq.var_name}_dW = {noise}')
            # expressions of the stochastic part
            g_expressions = diff_eq.get_g_expressions()
            code_lines.extend([str(expr) for expr in g_expressions[:-1]])
            g_expr = g_expressions[-1].code
            # get the dg_part
            s_dg_part = sympy.Symbol(f'_{diff_eq.var_name}_dg_part')
            code_lines.append(
                f'_{diff_eq.var_name}_dg_part = {g_expr} * _{diff_eq.var_name}_dW'
            )
        else:
            s_dg_part = 0

        # update expression
        update = var + s_df_part + s_dg_part * s_linear_exp

        # The actual update step
        code_lines.append(
            f'{diff_eq.var_name} = {ast_analysis.sympy2str(update)}')
        return_expr = ', '.join([diff_eq.var_name] +
                                diff_eq.return_intermediates)
        code_lines.append(f'_res = {return_expr}')

        # final
        code = '\n'.join(code_lines)
        subs_dict = {
            arg: f'_{arg}'
            for arg in diff_eq.func_args + diff_eq.expr_names
        }
        code = tools.word_replace(code, subs_dict)
        return code