Exemple #1
0
    def __init__(self, v0, delay_len, before_t0=0., t0=0., dt=None):
        """Allocate the delay buffer and fill it with the initial history.

        Parameters
        ----------
        v0 : array-like
            The value at ``t0``; its shape is the shape of one buffer slot.
        delay_len : float
            The delay length, in the same time unit as ``dt``.
        before_t0 : float, array-like, callable
            The history before ``t0``. A callable is called once for each past
            time point ``t0 - k * dt`` (``k = num_delay, ..., 1``); any other
            value is assigned to every history slot.
        t0 : float
            The start time.
        dt : float, optional
            The time step. Defaults to ``math.get_dt()``.
        """
        self.size = math.shape(v0)

        # the delay is discretized into `num_delay` whole steps (rounded up)
        self.delay_len = delay_len
        if dt is None:
            dt = math.get_dt()
        self.dt = dt
        n_delay = int(math.ceil(delay_len / self.dt))
        self.num_delay = n_delay

        # presumably the write / read indices into `data`, plus the clock
        self._delay_in = n_delay - 1
        self._delay_out = 0
        self.current_time = t0

        self.before_t0 = before_t0

        # `num_delay` history slots followed by one final slot holding `v0`
        self.data = math.zeros((n_delay + 1, ) + self.size)
        if not callable(before_t0):
            self.data[:-1] = before_t0
        else:
            for k in range(n_delay):
                self.data[k] = before_t0(t0 + (k - n_delay) * self.dt)
        self.data[-1] = v0
Exemple #2
0
  def __init__(self, size, freq, **kwargs):
    """Poisson spike source driven by one shared firing rate.

    Parameters
    ----------
    size : int, sequence of int
      Forwarded to the parent constructor.
    freq : float
      The firing rate (presumably in Hz -- confirm), stored as a
      one-element ``bm.Variable``.
    **kwargs
      Forwarded to the parent constructor.
    """
    super(PoissonNoise, self).__init__(size=size, **kwargs)

    # the rate is kept in a 1-element Variable (mutable model state)
    rate = bm.array([freq])
    self.freq = bm.Variable(rate)
    # time step divided by 1000 -- presumably ms -> s; TODO confirm units
    self.dt = bm.get_dt() / 1000.
    # one boolean spike flag per unit; `self.num` presumably comes from the parent
    silent = bm.zeros(self.num, dtype=bool)
    self.spike = bm.Variable(silent)
    self.rng = bm.random.RandomState()
Exemple #3
0
def ramp_input(c_start, c_end, duration, t_start=0, t_end=None, dt=None):
    """Get the gradually changed input current.

    The returned current is zero outside the ramp window and changes
    linearly from ``c_start`` to ``c_end`` (both values included) between
    ``t_start`` and ``t_end``.

    Parameters
    ----------
    c_start : float
        The current value at the beginning of the ramp.
    c_end : float
        The current value at the end of the ramp.
    duration : int, float
        The total duration.
    t_start : float
        The ramped current start time-point.
    t_end : float, optional
        The ramped current end time-point. Defaults to ``duration``.
    dt : float, int, optional
        The numerical precision. Defaults to ``math.get_dt()``.

    Returns
    -------
    current : math.ndarray
        The formatted current with ``ceil(duration / dt)`` time steps.

    Raises
    ------
    ValueError
        If ``t_start`` is later than ``t_end``, or if a non-empty ramp does
        not fit inside ``[0, duration]``.
    """
    dt = math.get_dt() if dt is None else dt
    t_end = duration if t_end is None else t_end

    num_step = int(np.ceil(duration / dt))
    p1 = int(np.ceil(t_start / dt))
    p2 = int(np.ceil(t_end / dt))
    # Fail early with a readable message; otherwise these cases surface as an
    # opaque error from ``linspace`` or from the slice assignment below.
    if p2 < p1:
        raise ValueError(f'"t_start" ({t_start}) must not be later than '
                         f'"t_end" ({t_end}).')
    if p2 > p1 and (p1 < 0 or p2 > num_step):
        raise ValueError(f'The ramp [{t_start}, {t_end}] must lie within '
                         f'[0, {duration}].')

    current = math.zeros(num_step, dtype=math.float_)
    current[p1:p2] = math.array(math.linspace(c_start, c_end, p2 - p1),
                                dtype=math.float_)
    return current
Exemple #4
0
def firing_rate(sp_matrix, width, dt=None):
    r"""Calculate the smoothed mean firing rate of a neuron group.

    This method is adopted from Brian2.

    The firing rate in trial :math:`k` is the spike count :math:`n_{k}^{sp}`
    in an interval of duration :math:`T` divided by :math:`T`:

    .. math::

        v_k = {n_k^{sp} \over T}

    Parameters
    ----------
    sp_matrix : math.JaxArray, np.ndarray
        The spike matrix, indexed as ``[time_step, neuron]``.
    width : int, float
        The width of the rectangular smoothing window in millisecond.
    dt : float, optional
        The time step of ``sp_matrix``. Defaults to ``math.get_dt()``.

    Returns
    -------
    rate : numpy.ndarray
        The population rate in Hz, smoothed with the given window; one value
        per time step.
    """
    spikes = np.asarray(sp_matrix)
    n_neurons = spikes.shape[1]
    # fraction of the population that fired at each time step
    pop_activity = spikes.sum(axis=1) / n_neurons
    if dt is None:
        dt = math.get_dt()
    # an odd number of samples keeps the 'same'-mode convolution centred
    half = int(width / 2 / dt)
    kernel = np.full(2 * half + 1, 1000 / width)
    return np.convolve(pop_activity, kernel, mode='same')
Exemple #5
0
  def __init__(self, f, var_type=None, dt=None, name=None, show_code=False):
    """Set up the shared state used to generate an integrator for ``f``.

    Parameters
    ----------
    f : callable
      The derivative function to integrate.
    var_type : optional
      The variable type; stored as-is.
    dt : int, float, optional
      The integration step. Defaults to ``math.get_dt()``.
    name : str, optional
      Forwarded to the parent constructor.
    show_code : bool
      Stored on the instance; presumably toggles printing of generated code.
    """
    super(ODEIntegrator, self).__init__(name=name)

    # integration step: an explicit argument wins over the global default
    if dt is None:
      dt = math.get_dt()
    self.dt = dt
    assert isinstance(self.dt, (int, float)), f'"dt" must be a float, but got {self.dt}'
    self.show_code = show_code

    # the derivative, registered under the shared `constants.F` key
    self.derivative = {constants.F: f}
    self.f = f
    self.integral = None  # not built here

    # split the signature of `f` into variables / parameters / all arguments
    variables, parameters, arguments = utils.get_args(f)
    self.variables = variables  # names before 't'
    self.parameters = parameters  # names after 't'
    # the generated function receives `dt` as a keyword defaulting to self.dt
    self.arguments = [*arguments, f'{constants.DT}={self.dt}']
    self.var_type = var_type

    # namespace and header line of the code to be generated
    self.code_scope = {constants.F: f}
    self.func_name = f_names(f)
    signature = ", ".join(self.arguments)
    self.code_lines = [f'def {self.func_name}({signature}):']
Exemple #6
0
  def plot_trajectory(self, initials, duration, plot_durations=None,
                      dt=None, show=False, with_plot=True, with_return=False):
    """Simulate trajectories and optionally plot ``x_var`` against the target parameter(s).

    Parameters
    ----------
    initials :
      Initial values, validated against ``target_var_names + target_par_names``
      by ``utils.check_initials``.
    duration : int, float
      The running duration.
    plot_durations : optional
      The time window(s) to plot, normalized by ``utils.check_plot_durations``.
    dt : float, optional
      The time step. Defaults to ``bm.get_dt()``; also used to turn
      ``plot_durations`` into sample indices.
    show : bool
      Whether to call ``plt.show()``.
    with_plot : bool
      Whether to draw the trajectories.
    with_return : bool
      Whether to return the monitor results.

    Returns
    -------
    mon_res
      The result of ``TrajectModel.run`` when ``with_return`` is True,
      otherwise ``None``.
    """
    utils.output('I am plotting the trajectory ...')

    # check the initial values
    initials = utils.check_initials(initials, self.target_var_names + self.target_par_names)

    # 2. format the running duration
    assert isinstance(duration, (int, float))

    # 3. format the plot duration
    plot_durations = utils.check_plot_durations(plot_durations, duration, initials)

    # 5. run the network
    dt = bm.get_dt() if dt is None else dt

    traject_model = utils.TrajectModel(initial_vars=initials, integrals=self._std_integrators, dt=dt)
    mon_res = traject_model.run(duration=duration)

    if with_plot:
      assert len(self.target_par_names) <= 2

      # plots: one trajectory per set of initial values (column `i` of the monitors)
      for i, initial in enumerate(zip(*list(initials.values()))):
        # legend
        legend = f'$traj_{i}$: '
        for j, key in enumerate(self.target_var_names):
          legend += f'{key}={initial[j]}, '
        legend = legend[:-2]  # drop the trailing ', '

        # visualization: convert the plot window from time to sample indices
        start = int(plot_durations[i][0] / dt)
        end = int(plot_durations[i][1] / dt)
        # NOTE(review): if `target_par_names` is empty this raises IndexError, so the
        # `else: raise ValueError` below is unreachable unless asserts are disabled (-O).
        p1_var = self.target_par_names[0]
        if len(self.target_par_names) == 1:
          lines = plt.plot(mon_res[self.x_var][start: end, i],
                           mon_res[p1_var][start: end, i], label=legend)
        elif len(self.target_par_names) == 2:
          p2_var = self.target_par_names[1]
          # NOTE(review): matplotlib parses three positional arrays as plot(x, y) plus
          # plot(z), not as a 3-D curve -- confirm the intended visualization.
          lines = plt.plot(mon_res[self.x_var][start: end, i],
                           mon_res[p1_var][start: end, i],
                           mon_res[p2_var][start: end, i],
                           label=legend)
        else:
          raise ValueError
        utils.add_arrow(lines[0])

      # # visualization of others
      # plt.xlabel(self.x_var)
      # plt.ylabel(self.target_par_names[0])
      # scale = (self.lim_scale - 1.) / 2
      # plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
      # plt.ylim(*utils.rescale(self.target_vars[self.target_par_names[0]], scale=scale))
      plt.legend()

      if show:
        plt.show()

    if with_return:
      return mon_res
Exemple #7
0
 def integral(*args, **kwargs):
     """Advance the state in ``args[0]`` by one exponential-Euler step.

     ``value_and_grad`` (from the enclosing scope) returns the linear
     coefficient and the derivative at ``args``. The update is
     ``x + dt * phi(dt * linear) * derivative`` with
     ``phi(z) = (exp(z) - 1) / z`` and ``phi = 1`` where ``linear == 0``.
     ``dt`` may be passed as a keyword; otherwise ``math.get_dt()`` is used.
     """
     assert len(args) > 0
     dt = kwargs.pop('dt', math.get_dt())
     linear, derivative = value_and_grad(*args, **kwargs)
     z = dt * linear
     # substitute the limit value 1 where the linear coefficient vanishes
     phi = math.where(linear == 0., math.ones_like(linear), (math.exp(z) - 1) / z)
     x = args[0]
     return x + dt * phi * derivative
Exemple #8
0
def run_model(run_func, times, report, dt=None, extra_func=None):
    """Run the model.

    The "run_func" can be the step run function of a dynamical system.

    Parameters
    ----------
    run_func : callable
        The step run function. It is called once per time point with a
        single ``(t, dt)`` tuple.
    times : sequence
        The model running times. Must support ``len()`` and indexing.
    report : float
        The fraction (between 0 and 1) of the total running length between
        two progress reports. A falsy value disables reporting.
    dt : int, float, optional
        The numerical integration step. Defaults to ``math.get_dt()``.
    extra_func : callable, optional
        If given, called as ``extra_func(t, dt)`` right before each
        ``run_func`` call.

    Returns
    -------
    running_time : float
        Wall-clock seconds spent in the simulation loop. When reporting is
        enabled, the first (compilation) step is timed separately and is not
        included.
    """

    # numerical integration step
    if dt is None:
        dt = math.get_dt()
    assert isinstance(dt, (int, float))

    # running function
    if extra_func is None:
        running_func = run_func
    else:

        def running_func(t_and_dt):
            extra_func(*t_and_dt)
            run_func(t_and_dt)

    # simulations
    run_length = len(times)
    if report and run_length > 0:
        # The first call usually triggers compilation; time it separately.
        t0 = time.time()
        running_func((times[0], dt))
        compile_time = time.time() - t0
        print('Compilation used {:.4f} s.'.format(compile_time))

        print("Start running ...")
        # At least 1: a small `run_length * report` would otherwise make the
        # modulo below divide by zero.
        report_gap = max(int(run_length * report), 1)
        t0 = time.time()
        for run_idx in range(1, run_length):
            running_func((times[run_idx], dt))
            if (run_idx + 1) % report_gap == 0:
                percent = (run_idx + 1) / run_length * 100
                print('Run {:.1f}% used {:.3f} s.'.format(
                    percent,
                    time.time() - t0))
        running_time = time.time() - t0
        print('Simulation is done in {:.3f} s.'.format(running_time))
        print()

    else:
        # also covers an empty `times`, where there is no first step to compile
        t0 = time.time()
        for run_idx in range(run_length):
            running_func((times[run_idx], dt))
        running_time = time.time() - t0

    return running_time
Exemple #9
0
def cross_correlation(spikes, bin, dt=None):
    r"""Calculate cross correlation index between neurons.

    The coherence [1]_ between two neurons i and j is measured by their
    cross-correlation of spike trains at zero time lag within a time bin
    of :math:`\Delta t = \tau`. More specifically, suppose that a long
    time interval T is divided into K small bins of :math:`\Delta t`
    (:math:`T / K = \tau`) and that two spike trains are given by
    :math:`X(l)=` 0 or 1, :math:`Y(l)=` 0 or 1, :math:`l=1,2, \ldots, K`.
    Thus, we define a coherence measure for the pair as:

    .. math::

        \kappa_{i j}(\tau)=\frac{\sum_{l=1}^{K} X(l) Y(l)}
        {\sqrt{\sum_{l=1}^{K} X(l) \sum_{l=1}^{K} Y(l)}}

    The population coherence measure :math:`\kappa(\tau)` is defined by the
    average of :math:`\kappa_{i j}(\tau)` over many pairs of neurons in the
    network.

    Parameters
    ----------
    spikes :
        The history of spike states of the neuron group, indexed as
        ``[time_step, neuron]``.
        It can be easily get via `StateMonitor(neu, ['spike'])`.
    bin : float, int
        The time bin to normalize spike states. Must be at least ``dt``.
    dt : float, optional
        The time precision. Defaults to ``math.get_dt()``.

    Returns
    -------
    cc_index : float
        The cross correlation value which represents the synchronization
        index. With fewer than two neurons there are no pairs and the result
        is NaN (the mean of an empty list).

    Raises
    ------
    ValueError
        If ``bin`` is shorter than one time step.

    References
    ----------
    .. [1] Wang, Xiao-Jing, and György Buzsáki. "Gamma oscillation by synaptic
           inhibition in a hippocampal interneuronal network model." Journal of
           neuroscience 16.20 (1996): 6402-6413.
    """
    spikes = np.asarray(spikes)
    dt = math.get_dt() if dt is None else dt
    bin_size = int(bin / dt)
    if bin_size < 1:
        raise ValueError(f'"bin" ({bin}) must be at least one time step '
                         f'("dt" = {dt}).')
    num_hist, num_neu = spikes.shape
    num_bin = int(np.ceil(num_hist / bin_size))
    if num_bin * bin_size != num_hist:
        # zero-pad the end so the history splits evenly into bins
        spikes = np.append(spikes,
                           np.zeros((num_bin * bin_size - num_hist, num_neu)),
                           axis=0)
    states = spikes.T.reshape((num_neu, num_bin, bin_size))
    # a bin is 1. when the neuron spiked at least once inside it
    # (np.float_ was removed in NumPy 2.0; np.float64 is the same type)
    states = (np.sum(states, axis=2) > 0.).astype(np.float64)
    all_k = [_cc(states, i, j)
             for i in range(num_neu)
             for j in range(i + 1, num_neu)]
    return np.mean(all_k)
Exemple #10
0
  def __init__(self, size, freqs, seed=None, name=None):
    """Poisson spike generator with a seedable random state.

    Parameters
    ----------
    size : int, sequence of int
      The group size; also normalized to a tuple in ``self.size``.
    freqs : float, array-like
      The firing frequencies; stored as-is.
    seed : int, optional
      Seed for the internal ``bm.random.RandomState``.
    name : str, optional
      Forwarded to the parent constructor.
    """
    super(PoissonInput, self).__init__(size=size, name=name)

    self.freqs = freqs
    # time step divided by 1000 -- presumably ms -> s; TODO confirm units
    self.dt = bm.get_dt() / 1000.
    if isinstance(size, int):
      self.size = (size,)
    else:
      self.size = tuple(size)
    self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
    # a far-past time, presumably meaning "has not spiked yet"
    long_ago = bm.ones(self.num) * -1e7
    self.t_last_spike = bm.Variable(long_ago)
    self.rng = bm.random.RandomState(seed=seed)
Exemple #11
0
  def plot_trajectory(self, initials, duration, plot_durations=None,
                      dt=None, show=False, with_plot=True, with_return=False):
    """Simulate trajectories and plot ``x_var`` and ``y_var`` against the target parameter.

    Two figures are used, identified by ``self.x_var`` and ``self.y_var``.

    Parameters
    ----------
    initials :
      Initial values, validated against ``target_var_names + target_par_names``
      by ``utils.check_initials``.
    duration : int, float
      The running duration.
    plot_durations : optional
      The time window(s) to plot, normalized by ``utils.check_plot_durations``.
    dt : float, optional
      The time step. Defaults to ``bm.get_dt()``; also used to turn
      ``plot_durations`` into sample indices.
    show : bool
      Whether to call ``plt.show()``.
    with_plot : bool
      Whether to draw the trajectories.
    with_return : bool
      Whether to return the monitor results.

    Returns
    -------
    mon_res
      The result of ``TrajectModel.run`` when ``with_return`` is True,
      otherwise ``None``.
    """
    utils.output('I am plotting the trajectory ...')

    # check the initial values
    initials = utils.check_initials(initials, self.target_var_names + self.target_par_names)

    # 2. format the running duration
    assert isinstance(duration, (int, float))

    # 3. format the plot duration
    plot_durations = utils.check_plot_durations(plot_durations, duration, initials)

    # 5. run the network
    dt = bm.get_dt() if dt is None else dt

    traject_model = utils.TrajectModel(initial_vars=initials, integrals=self._std_integrators, dt=dt)
    mon_res = traject_model.run(duration=duration)

    if with_plot:
      # NOTE(review): an empty `target_par_names` passes this assert but then fails
      # with IndexError at `self.target_par_names[0]` below.
      assert len(self.target_par_names) <= 1
      # plots: one trajectory per set of initial values (column `i` of the monitors)
      for i, initial in enumerate(zip(*list(initials.values()))):
        # legend
        legend = f'$traj_{i}$: '
        for j, key in enumerate(self.target_var_names):
          legend += f'{key}={initial[j]}, '
        legend = legend[:-2]  # drop the trailing ', '

        # convert the plot window from time to sample indices
        start = int(plot_durations[i][0] / dt)
        end = int(plot_durations[i][1] / dt)

        # visualization: parameter on the x-axis, each variable in its own figure
        plt.figure(self.x_var)
        lines = plt.plot(mon_res[self.target_par_names[0]][start: end, i],
                         mon_res[self.x_var][start: end, i],
                         label=legend)
        utils.add_arrow(lines[0])

        plt.figure(self.y_var)
        lines = plt.plot(mon_res[self.target_par_names[0]][start: end, i],
                         mon_res[self.y_var][start: end, i],
                         label=legend)
        utils.add_arrow(lines[0])

      # re-activate each figure to attach its legend
      plt.figure(self.x_var)
      plt.legend()
      plt.figure(self.y_var)
      plt.legend()

      if show:
        plt.show()

    if with_return:
      return mon_res
Exemple #12
0
    def __init__(self,
                 target,
                 monitors=None,
                 inputs=(),
                 dt=None,
                 jit=False,
                 dyn_vars=None,
                 numpy_mon_after_run=False):
        """Validate the run configuration and build the monitor and input step functions.

        Parameters
        ----------
        target : DynamicalSystem
            The system to run.
        monitors : None, list, tuple, dict, Monitor
            What to record. Lists/tuples/dicts are wrapped in a new
            ``Monitor``; an existing ``Monitor`` is re-targeted to this runner.
        inputs : optional
            Input specification, normalized by ``utils.check_and_format_inputs``
            against ``target``.
        dt : int, float, optional
            The time step. Defaults to ``math.get_dt()``.
        jit : bool
            Stored as ``self.jit``.
        dyn_vars : None, list, tuple, dict
            Variables changed during the run. Defaults to
            ``target.vars().unique()``; lists/tuples are keyed as ``_v0``, ``_v1``, ...
        numpy_mon_after_run : bool
            Stored on the instance; presumably converts monitor data to NumPy
            after running -- confirm.

        Raises
        ------
        RunningError
            If ``dt`` is not an int/float, ``target`` is not a
            ``DynamicalSystem``, or ``dyn_vars`` does not resolve to a dict.
        MonitorError
            If ``monitors`` has an unsupported type.
        """
        dt = math.get_dt() if dt is None else dt
        if not isinstance(dt, (int, float)):
            raise RunningError(f'"dt" must be scalar, but got {dt}')
        self.dt = dt
        self.jit = jit
        self.numpy_mon_after_run = numpy_mon_after_run

        # target
        if not isinstance(target, DynamicalSystem):
            raise RunningError(
                f'"target" must be an instance of {DynamicalSystem.__name__}, '
                f'but we got {type(target)}: {target}')
        self.target = target

        # dynamical changed variables
        if dyn_vars is None:
            dyn_vars = self.target.vars().unique()
        # give positional variables synthetic, stable names
        if isinstance(dyn_vars, (list, tuple)):
            dyn_vars = {f'_v{i}': v for i, v in enumerate(dyn_vars)}
        if not isinstance(dyn_vars, dict):
            raise RunningError(
                f'"dyn_vars" must be a dict, but we got {type(dyn_vars)}')
        self.dyn_vars = dyn_vars

        # monitors
        if monitors is None:
            self.mon = Monitor(target=self, variables=[])
        elif isinstance(monitors, (list, tuple, dict)):
            self.mon = Monitor(target=self, variables=monitors)
        elif isinstance(monitors, Monitor):
            self.mon = monitors
            self.mon.target = self
        else:
            raise MonitorError(f'"monitors" only supports list/tuple/dict/ '
                               f'instance of Monitor, not {type(monitors)}.')
        # order matters: the monitor is built before its step function is generated
        self.mon.build()  # build the monitor
        # Build the monitor function
        #   All the monitors are wrapped in a single function.
        self._monitor_step = self.build_monitors()

        # Build input function
        inputs = utils.check_and_format_inputs(host=target, inputs=inputs)
        self._input_step = self.build_inputs(inputs)

        # start simulation time
        self._start_t = None
Exemple #13
0
  def __init__(self, size, freq_mean, freq_var, t_interval, **kwargs):
    """Poisson stimulus whose frequency is kept as mutable state.

    Parameters
    ----------
    size : int, sequence of int
      Forwarded to the parent constructor.
    freq_mean, freq_var : float
      Presumably the mean / variance of the stimulus frequency -- confirm.
    t_interval : float
      Presumably the interval between frequency changes -- confirm.
    **kwargs
      Forwarded to the parent constructor.
    """
    super(PoissonStim, self).__init__(size=size, **kwargs)

    # configuration
    self.freq_mean, self.freq_var, self.t_interval = freq_mean, freq_var, t_interval
    # time step divided by 1000 -- presumably ms -> s; TODO confirm units
    self.dt = bm.get_dt() / 1000.

    # state: current frequency, time of its last change (far past), spike flags
    self.freq = bm.Variable(bm.zeros(1))
    long_ago = bm.ones(1) * -1e7
    self.freq_t_last_change = bm.Variable(long_ago)
    self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
    self.rng = bm.random.RandomState()
Exemple #14
0
def section_input(values, durations, dt=None, return_length=False):
    """Format an input current with different sections.

    For example, to get an input that is 0 between 0-100 ms and 1.
    between 100-200 ms:

    >>> section_input(values=[0, 1],
    >>>               durations=[100, 100])

    Parameters
    ----------
    values : list, np.ndarray
        The current values for each period duration. Each value may be a
        scalar or an array; the highest-rank shape defines the output shape.
    durations : list, np.ndarray
        The duration for each period. Must have the same length as ``values``.
    dt : float, optional
        The time step. Defaults to ``math.get_dt()``.
    return_length : bool
        Whether to also return the total duration.

    Returns
    -------
    current : math.ndarray
        The formatted current, when ``return_length`` is False.
    current_and_duration : tuple
        ``(current, total duration)``, when ``return_length`` is True.
    """
    assert len(durations) == len(values), f'"values" and "durations" must be the same length, while ' \
                                          f'we got {len(values)} != {len(durations)}.'

    dt = math.get_dt() if dt is None else dt

    # get input current shape, and duration
    I_duration = sum(durations)
    I_shape = ()
    for val in values:
        shape = math.shape(val)
        if len(shape) > len(I_shape):
            I_shape = shape

    # get the current
    # NOTE(review): the total length uses ceil() while each section uses int()
    # truncation, so float rounding (e.g. 0.3 / 0.1) can leave trailing zeros.
    start = 0
    I_current = math.zeros((int(np.ceil(I_duration / dt)), ) + I_shape,
                           dtype=math.float_)
    for c_size, duration in zip(values, durations):
        length = int(duration / dt)
        I_current[start:start + length] = c_size
        start += length

    if return_length:
        return I_current, I_duration
    else:
        return I_current
    def __init__(self, post, freq=8.):
        """Summed Poisson drive onto ``post``.

        Parameters
        ----------
        post :
            The target group; must expose ``num`` and an ``I`` attribute.
        freq : float
            The per-source firing rate (presumably Hz, given the ``/ 1000.``).
        """
        super(PoissonInput, self).__init__(size=(post.num, ))

        # per-step spike probability of one source
        p = freq * bm.get_dt() / 1000.
        self.prob = p
        n = post.num
        # mean and standard deviation of a Binomial(n, p) spike count
        self.loc = n * p
        self.scale = np.sqrt(n * p * (1 - p))
        self.weight = ExpSyn.exc_weight[0]
        self.post = post
        assert hasattr(post, 'I')

        # random state used to draw the input
        self.rng = bm.random.RandomState()
    def __init__(self, pops, freq=8.):
        """Summed Poisson drive shared across several populations.

        Parameters
        ----------
        pops : sequence
            The target groups; the total size is the sum of their ``num``.
        freq : float
            The per-source firing rate (presumably Hz, given the ``/ 1000.``).
        """
        super(PoissonInput2, self).__init__(size=sum(p.num for p in pops))

        self.pops = pops
        # per-step spike probability of one source
        p = freq * bm.get_dt() / 1000.
        n = self.num
        # looks like the usual rule of thumb for approximating
        # Binomial(n, p) by a normal distribution: n*p > 5 and n*(1-p) > 5
        assert (p * n > 5.) and (n * (1 - p) > 5)
        self.loc = n * p
        self.scale = np.sqrt(n * p * (1 - p))
        self.weight = ExpSyn.exc_weight[0]

        # random state used to draw the input
        self.rng = bm.random.RandomState()
Exemple #17
0
    def __init__(self, size, delay, dtype=None, dt=None, **kwargs):
        """Build the storage for a constant delay.

        Parameters
        ----------
        size : int, tuple, list
            The shape of one time slice of the delayed data.
        delay : int, float, array-like
            The delay time. An ``int``/``float`` gives one uniform delay; any
            other value is treated as a 1D array with one delay per element
            (only supported when ``size`` is 1D).
        dtype : optional
            The data type of the delay buffer.
        dt : float, optional
            The time step. Defaults to ``bm.get_dt()``.
        **kwargs
            Forwarded to the parent constructor.

        Raises
        ------
        ModelBuildError
            If ``size`` is not an int/tuple/list, or a non-uniform ``delay``
            is not 1D or does not match ``size``.
        NotImplementedError
            If a non-uniform delay is combined with a multi-dimensional ``size``.
        """
        # dt
        self.dt = bm.get_dt() if dt is None else dt

        # data size
        if isinstance(size, int): size = (size, )
        if not isinstance(size, (tuple, list)):
            raise ModelBuildError(
                f'"size" must a tuple/list of int, but we got {type(size)}: {size}'
            )
        self.size = tuple(size)

        # delay time length (kept exactly as given by the caller)
        self.delay = delay

        # data and operations
        if isinstance(delay, (int, float)):  # uniform delay
            self.uniform_delay = True
            # NOTE(review): `pm` is not defined in this block; presumably a
            # module-level alias (e.g. of the stdlib `math` module) -- confirm.
            self.num_step = int(pm.ceil(delay / self.dt)) + 1
            self.out_idx = bm.Variable(bm.array([0], dtype=bm.uint32))
            self.in_idx = bm.Variable(
                bm.array([self.num_step - 1], dtype=bm.uint32))
            self.data = bm.Variable(
                bm.zeros((self.num_step, ) + self.size, dtype=dtype))

        else:  # non-uniform delay
            self.uniform_delay = False
            if not len(self.size) == 1:
                raise NotImplementedError(
                    f'Currently, BrainPy only supports 1D heterogeneous '
                    f'delays, while we got the heterogeneous delay with '
                    f'{len(self.size)}-dimensions.')
            self.num = size2len(size)
            # Accept lists/tuples as well as arrays: the checks and the
            # arithmetic below need `.shape` and element-wise division.
            delay = bm.array(delay)
            if bm.ndim(delay) != 1:
                raise ModelBuildError(f'Only support a 1D non-uniform delay. '
                                      f'But we got {bm.ndim(delay)}D: {delay}')
            if delay.shape[0] != self.size[0]:
                raise ModelBuildError(
                    f"The first shape of the delay time size must "
                    f"be the same with the delay data size. But "
                    f"we got {delay.shape[0]} != {self.size[0]}")
            delay = bm.around(delay / self.dt)
            self.diag = bm.array(bm.arange(self.num), dtype=bm.int_)
            self.num_step = bm.array(delay, dtype=bm.uint32) + 1
            self.in_idx = bm.Variable(self.num_step - 1)
            self.out_idx = bm.Variable(bm.zeros(self.num, dtype=bm.uint32))
            # Use the normalized tuple `self.size`: `size` may still be a
            # list here, and `tuple + list` raises TypeError.
            self.data = bm.Variable(
                bm.zeros((self.num_step.max(), ) + self.size, dtype=dtype))

        super(ConstantDelay, self).__init__(**kwargs)
Exemple #18
0
def constant_input(I_and_duration, dt=None):
    """Format constant input in durations.

    For example, to get an input that is 0 between 0-100 ms and 1.
    between 100-200 ms:

    >>> import brainpy.math as bm
    >>> constant_input([(0, 100), (1, 100)])
    >>> constant_input([(bm.zeros(100), 100), (bm.random.rand(100), 100)])

    Parameters
    ----------
    I_and_duration : list
        ``(current size, duration)`` pairs, like
        ``[(Isize1, duration1), (Isize2, duration2)]``. Current sizes may be
        scalars or arrays; the highest-rank shape defines the output shape.
    dt : float, optional
        The time step. Defaults to ``math.get_dt()``.

    Returns
    -------
    current_and_duration : tuple
        (The formatted current, total duration)
    """
    if dt is None:
        dt = math.get_dt()

    # accumulate the total duration and collect every value's shape
    I_duration = 0.
    shapes = [()]
    for pair in I_and_duration:
        I_duration += pair[1]
        shapes.append(math.shape(pair[0]))
    # first shape with the highest rank
    I_shape = max(shapes, key=len)

    # fill consecutive segments of the time axis with each constant
    n_total = int(np.ceil(I_duration / dt))
    I_current = math.zeros((n_total, ) + I_shape, dtype=math.float_)
    cursor = 0
    for c_size, duration in I_and_duration:
        n_seg = int(duration / dt)
        I_current[cursor:cursor + n_seg] = c_size
        cursor += n_seg
    return I_current, I_duration
Exemple #19
0
def spike_input(sp_times, sp_lens, sp_sizes, duration, dt=None):
    """Format current input like a series of short-time spikes.

    For example, to generate a spike train at 10 ms, 20 ms, 30 ms, 200 ms and
    300 ms, where each spike lasts 1 ms with a current of 0.5:

    >>> spike_input(sp_times=[10, 20, 30, 200, 300],
    >>>             sp_lens=1.,  # can be a list to specify the spike length at each point
    >>>             sp_sizes=0.5,  # can be a list to specify the current size at each point
    >>>             duration=400.)

    Parameters
    ----------
    sp_times : list, tuple
        The spike time-points.
    sp_lens : int, float, list, tuple
        The length of each point-current; a scalar is used for every spike.
    sp_sizes : int, float, list, tuple
        The current sizes; a scalar is used for every spike.
    duration : int, float
        The total current duration.
    dt : float, optional
        The time step. Defaults to ``math.get_dt()``.

    Returns
    -------
    current : math.ndarray
        The formatted input current with ``ceil(duration / dt)`` time steps.
    """
    if dt is None:
        dt = math.get_dt()
    assert isinstance(sp_times, (list, tuple))
    n_spikes = len(sp_times)
    # broadcast scalar lengths / sizes to one entry per spike
    if isinstance(sp_lens, (float, int)):
        sp_lens = [sp_lens] * n_spikes
    if isinstance(sp_sizes, (float, int)):
        sp_sizes = [sp_sizes] * n_spikes

    current = math.zeros(int(np.ceil(duration / dt)), dtype=math.float_)
    for t_spike, length, amplitude in zip(sp_times, sp_lens, sp_sizes):
        onset = int(t_spike / dt)
        width = int(length / dt)
        current[onset:onset + width] = amplitude
    return current
Exemple #20
0
        def integral_func(*args, **kwargs):
            """Run every sub-integrator for one step.

            Positional arguments are bound, in order, to the names in
            ``all_vps``; keyword arguments are merged on top, and ``dt``
            falls back to ``math.get_dt()``. Returns the list of results for
            a ``JointEq`` derivative, otherwise only the first result.
            """
            params_in = Collector()
            for idx, value in enumerate(args):
                params_in[all_vps[idx]] = value
            params_in.update(kwargs)
            if 'dt' not in params_in:
                params_in['dt'] = math.get_dt()

            results = []
            for idx, int_fun in enumerate(integrals):
                names = arg_names[idx]
                # presumably the first name is the integrated variable (passed
                # positionally); the rest are forwarded as available keywords
                state = params_in[names[0]]
                extra = {arg: params_in[arg] for arg in names[1:] if arg in params_in}
                results.append(int_fun(state, **extra))

            if isinstance(self.f, joint_eq.JointEq):
                return results
            return results[0]
    def plot_trajectory(self,
                        initials,
                        duration,
                        plot_durations=None,
                        axes='v-v',
                        dt=None,
                        show=False,
                        with_plot=True,
                        with_return=False,
                        **kwargs):
        """Plot trajectories according to the settings.

        Parameters
        ----------
        initials : list, tuple, dict
            The initial value setting of the targets. It can be a tuple/list of floats to specify
            each value of dynamical variables (for example, ``(a, b)``). It can also be a
            tuple/list of tuple to specify multiple initial values (for example,
            ``[(a1, b1), (a2, b2)]``).
        duration : int, float
            The running duration. Same with the ``duration`` in ``NeuGroup.run()``.
            NOTE(review): earlier docs also advertised ``(t_start, t_end)`` tuples and
            per-initial-value lists, but the ``assert`` below only accepts an int/float --
            confirm which contract is intended.
        plot_durations : tuple, list, optional
            The duration to plot. It can be a tuple with ``(start, end)``. It can
            also be a list of tuple ``[(start1, end1), (start2, end2)]`` to specify
            the plot duration for each initial value running.
        axes : str
            The axes to plot. It can be:

             - 'v-v': Plot the trajectory in the 'x_var'-'y_var' axis.
             - 't-v': Plot the trajectory in the 'time'-'var' axis.
        dt : float, optional
            The time step of the trajectory run. Defaults to ``math.get_dt()``;
            it is also used to turn ``plot_durations`` into sample indices.
        show : bool
            Whether to call ``plt.show()`` at the end.
        with_plot : bool
            Whether to draw anything at all.
        with_return : bool
            Whether to return the monitor results.
        **kwargs
            Forwarded to every ``plt.plot`` call.

        Returns
        -------
        mon_res
            The result of ``TrajectModel.run`` when ``with_return`` is True,
            otherwise ``None``.

        Raises
        ------
        errors.AnalyzerError
            If ``axes`` is not ``'v-v'`` or ``'t-v'``.
        """

        utils.output('I am plotting the trajectory ...')

        if axes not in ['v-v', 't-v']:
            raise errors.AnalyzerError(
                f'Unknown axes "{axes}", only support "v-v" and "t-v".')

        # check the initial values
        initials = utils.check_initials(initials, self.target_var_names)

        # 2. format the running duration
        assert isinstance(duration, (int, float))

        # 3. format the plot duration
        plot_durations = utils.check_plot_durations(plot_durations, duration,
                                                    initials)

        # 5. run the network
        dt = math.get_dt() if dt is None else dt
        traject_model = utils.TrajectModel(initial_vars=initials,
                                           integrals={
                                               self.x_var: self.F_int_x,
                                               self.y_var: self.F_int_y
                                           },
                                           dt=dt)
        mon_res = traject_model.run(duration=duration)

        if with_plot:
            # plots: one trajectory per set of initial values (column `i` of the monitors)
            for i, initial in enumerate(zip(*list(initials.values()))):
                # legend
                legend = f'$traj_{i}$: '
                for j, key in enumerate(self.target_var_names):
                    legend += f'{key}={round(float(initial[j]), 4)}, '
                legend = legend[:-2]  # drop the trailing ', '

                # visualization: convert the plot window from time to sample indices
                start = int(plot_durations[i][0] / dt)
                end = int(plot_durations[i][1] / dt)
                if axes == 'v-v':
                    lines = plt.plot(mon_res[self.x_var][start:end, i],
                                     mon_res[self.y_var][start:end, i],
                                     label=legend,
                                     **kwargs)
                    utils.add_arrow(lines[0])
                else:
                    # 't-v': both variables against time on the same axes
                    plt.plot(mon_res.ts[start:end],
                             mon_res[self.x_var][start:end, i],
                             label=legend + f', {self.x_var}',
                             **kwargs)
                    plt.plot(mon_res.ts[start:end],
                             mon_res[self.y_var][start:end, i],
                             label=legend + f', {self.y_var}',
                             **kwargs)

            # visualization of others
            if axes == 'v-v':
                plt.xlabel(self.x_var)
                plt.ylabel(self.y_var)
                # widen the configured variable ranges by (lim_scale - 1) / 2 on each side
                scale = (self.lim_scale - 1.) / 2
                plt.xlim(
                    *utils.rescale(self.target_vars[self.x_var], scale=scale))
                plt.ylim(
                    *utils.rescale(self.target_vars[self.y_var], scale=scale))
                plt.legend()
            else:
                plt.legend(title='Initial values')

            if show:
                plt.show()

        if with_return:
            return mon_res
Exemple #22
0
def _format_dynamic_var(var, check_ndim=True, any_array=False):
    """Normalize one dynamical variable into a fresh ``{'ys', 'xs', 'legend'}`` dict.

  Returns ``(formatted_var, has_legend)``. The caller's dict is never mutated.
  When ``any_array`` is True, any array-like (not only ``np.ndarray`` /
  ``math.ndarray``) is accepted.
  """
    if isinstance(var, dict):
        assert 'ys' in var, 'Must provide "ys" item.'
        has_legend = 'legend' in var
        var = dict(var)  # copy so the caller's dict is left untouched
        var.setdefault('legend', None)
        var['ys'] = np.asarray(var['ys'])
    elif any_array or isinstance(var, (np.ndarray, math.ndarray)):
        var = {'ys': np.asarray(var), 'legend': None}
        has_legend = False
    else:
        raise ValueError(f'Unknown data type: {type(var)}')
    if check_ndim:
        assert np.ndim(var['ys']) == 2, "Dynamic variable must be 2D data."
    if 'xs' not in var:
        var['xs'] = np.arange(var['ys'].shape[1])
    return var, has_legend


def _format_static_var(var, check_ndim=True):
    """Normalize one static variable into a fresh ``{'ys', 'xs', 'legend'}`` dict.

  ``'data'`` is accepted as an alias of ``'ys'``. Returns
  ``(formatted_var, has_legend)``. The caller's dict is never mutated.
  """
    if isinstance(var, dict):
        assert 'ys' in var or 'data' in var, 'Must provide "ys" item.'
        has_legend = 'legend' in var
        var = dict(var)  # copy so the caller's dict is left untouched
        var.setdefault('legend', None)
        var['ys'] = np.asarray(var['ys'] if 'ys' in var else var['data'])
    elif isinstance(var, (np.ndarray, math.ndarray)):
        var = {'ys': np.asarray(var), 'legend': None}
        has_legend = False
    else:
        raise ValueError(f'Unknown data type: {type(var)}')
    if check_ndim:
        assert np.ndim(var['ys']) == 1, "Static variable must be 1D data."
    if 'xs' not in var:
        var['xs'] = np.arange(var['ys'].shape[0])
    return var, has_legend


def animate_1D(dynamical_vars,
               static_vars=(),
               dt=None,
               xlim=None,
               ylim=None,
               xlabel=None,
               ylabel=None,
               frame_delay=50.,
               frame_step=1,
               title_size=10,
               figsize=None,
               gif_dpi=None,
               video_fps=None,
               save_path=None,
               show=True,
               **kwargs):
    """Animation of one-dimensional data.

  Parameters
  ----------
  dynamical_vars : dict, np.ndarray, list of np.ndarray, list of dict
      The dynamical variables which will be animated. Arrays are 2D with the
      time step on the first axis. A dict must provide ``"ys"`` and may
      provide ``"xs"`` (defaults to ``arange(ys.shape[1])``) and ``"legend"``.
      All dynamical variables must have the same number of time steps.
  static_vars : dict, np.ndarray, list of np.ndarray, list of dict
      The static variables, drawn unchanged in every frame. A dict must
      provide ``"ys"`` (``"data"`` is accepted as an alias) and may provide
      ``"xs"`` (defaults to ``arange(ys.shape[0])``) and ``"legend"``.
  dt : float
      The numerical integration step. Defaults to ``math.get_dt()``.
  xlim : tuple
      The xlim.
  ylim : tuple
      The ylim. If ``None``, it is derived from the data range with a 2% margin.
  xlabel : str
      The xlabel.
  ylabel : str
      The ylabel.
  frame_delay : int, float
      The delay to show each frame.
  frame_step : int
      The step to show the potential. If `frame_step=3`, then each
      frame shows one of the every three steps.
  title_size : int
      The size of the title.
  figsize : None, tuple
      The size of the figure.
  gif_dpi : int
      Controls the dots per inch for the movie frames. This combined with
      the figure's size in inches controls the size of the movie. If
      ``None``, use defaults in matplotlib.
  video_fps : int
      Frames per second in the movie. Defaults to ``None``, which will use
      the animation's specified interval to set the frames per second.
  save_path : None, str
      The save path of the animation. Paths ending in ``gif`` are written with
      imagemagick, paths ending in ``mp4`` with ffmpeg; otherwise ``.mp4`` is
      appended.
  show : bool
      Whether show the animation (only used when ``save_path`` is ``None``).
  **kwargs
      Forwarded to every ``plt.plot`` call.

  Returns
  -------
  figure : plt.figure
      The created figure instance.
  """

    # check dt
    dt = math.get_dt() if dt is None else dt

    # check dynamical variables
    if isinstance(dynamical_vars, (tuple, list)):
        formatted = [_format_dynamic_var(var) for var in dynamical_vars]
    elif isinstance(dynamical_vars, dict):
        formatted = [_format_dynamic_var(dynamical_vars, check_ndim=False)]
    else:
        formatted = [_format_dynamic_var(dynamical_vars, any_array=True)]
    assert formatted, 'Must provide at least one dynamic variable.'
    final_dynamic_vars = [var for var, _ in formatted]
    has_legend = any(flag for _, flag in formatted)

    num_steps = {var['ys'].shape[0] for var in final_dynamic_vars}
    assert len(num_steps) == 1, 'Dynamic variables must have equal length.'
    num_step = num_steps.pop()

    # check static variables
    if isinstance(static_vars, (tuple, list)):
        formatted = [_format_static_var(var) for var in static_vars]
    elif isinstance(static_vars, (dict, np.ndarray, math.ndarray)):
        formatted = [_format_static_var(static_vars, check_ndim=False)]
    else:
        raise ValueError(f'Unknown static data type: {type(static_vars)}')
    final_static_vars = [var for var, _ in formatted]
    has_legend = has_legend or any(flag for _, flag in formatted)

    # ylim: the data range widened by 2% on each side
    if ylim is None:
        all_ys = [var['ys'] for var in final_dynamic_vars + final_static_vars]
        ylim_min = min(ys.min() for ys in all_ys)
        ylim_max = max(ys.max() for ys in all_ys)
        ylim_min = ylim_min * (0.98 if ylim_min > 0 else 1.02)
        ylim_max = ylim_max * (1.02 if ylim_max > 0 else 0.98)
        ylim = (ylim_min, ylim_max)

    # check figure (created only after the inputs are validated)
    fig = plt.figure(figsize=(figsize or (6, 6)), constrained_layout=True)
    gs = GridSpec(1, 1, figure=fig)
    fig.add_subplot(gs[0, 0])

    def frame(t):
        # Redraw the whole figure for time step ``t``.
        fig.clf()
        for dvar in final_dynamic_vars:
            plt.plot(dvar['xs'], dvar['ys'][t], label=dvar['legend'], **kwargs)
        for svar in final_static_vars:
            plt.plot(svar['xs'], svar['ys'], label=svar['legend'], **kwargs)
        if xlim is not None:
            plt.xlim(xlim[0], xlim[1])
        if has_legend:
            plt.legend()
        if xlabel:
            plt.xlabel(xlabel)
        if ylabel:
            plt.ylabel(ylabel)
        plt.ylim(ylim[0], ylim[1])
        fig.suptitle(t="Time: {:.2f} ms".format((t + 1) * dt),
                     fontsize=title_size,
                     fontweight='bold')
        return [fig.gca()]

    anim_result = animation.FuncAnimation(fig=fig,
                                          func=frame,
                                          frames=range(1, num_step, frame_step),
                                          init_func=None,
                                          interval=frame_delay,
                                          repeat_delay=3000)

    # save or show
    if save_path is None:
        if show:
            plt.show()
    else:
        logger.warning(f'Saving the animation into {save_path} ...')
        if save_path[-3:] == 'gif':
            anim_result.save(save_path, dpi=gif_dpi, writer='imagemagick')
        elif save_path[-3:] == 'mp4':
            anim_result.save(save_path,
                             writer='ffmpeg',
                             fps=video_fps,
                             bitrate=3000)
        else:
            anim_result.save(save_path + '.mp4',
                             writer='ffmpeg',
                             fps=video_fps,
                             bitrate=3000)
    return fig
Exemple #23
0
    def __init__(self,
                 target,
                 monitors=None,
                 inits=None,
                 args=None,
                 dyn_args=None,
                 dyn_vars=None,
                 jit=True,
                 dt=None,
                 numpy_mon_after_run=True,
                 progress_bar=True):
        """Prepare a runner that repeatedly steps an ``Integrator``.

        Parameters
        ----------
        target : Integrator
            The integrator to run.
        monitors : None, list, tuple, dict, Monitor
            Variables to record; every monitored name must be one of
            ``target.variables``.
        inits : None, list, tuple, dict
            Initial values of the integrator variables. A list/tuple is matched
            positionally against ``target.variables``. Every variable is
            allocated with the largest size found among the initial values.
        args : None, dict
            Static arguments of the integral function.
        dyn_args : None, dict
            Per-step arguments of the integral function; all values must have
            the same length. An empty dict is treated like ``None``.
        dyn_vars : None, list, tuple, dict
            Dynamically changed variables; defaults to ``target.vars().unique()``.
        jit : bool
            Whether to build the loop with ``math.make_loop``.
        dt : int, float
            The step size; defaults to ``math.get_dt()``.
        numpy_mon_after_run : bool
            Stored on the runner as ``self.numpy_mon_after_run``.
        progress_bar : bool
            Stored on the runner as ``self.progress_bar``.

        Raises
        ------
        RunningError
            If ``dt``, ``target``, ``dyn_args`` or ``dyn_vars`` are invalid.
        MonitorError
            If ``monitors`` is of an unsupported type or names an unknown variable.
        """
        super(IntegratorRunner, self).__init__()

        # parameters
        dt = math.get_dt() if dt is None else dt
        if not isinstance(dt, (int, float)):
            raise RunningError(f'"dt" must be scalar, but got {dt}')
        self.dt = dt
        self.jit = jit
        self.numpy_mon_after_run = numpy_mon_after_run
        self._pbar = None  # progress bar
        self.progress_bar = progress_bar

        # target
        if not isinstance(target, Integrator):
            raise RunningError(
                f'"target" must be an instance of {Integrator.__name__}, '
                f'but we got {type(target)}: {target}')
        self.target = target

        # arguments of the integral function
        self._static_args = Collector()
        if args is not None:
            assert isinstance(
                args, dict
            ), f'"args" must be a dict, but we get {type(args)}: {args}'
            self._static_args.update(args)
        self._dyn_args = Collector()
        if dyn_args is not None:
            assert isinstance(
                dyn_args, dict
            ), f'"dyn_args" must be a dict, but we get {type(dyn_args)}: {dyn_args}'
            sizes = np.unique([len(v) for v in dyn_args.values()])
            num_size = len(sizes)
            # An empty dict yields zero distinct sizes, which is valid; only
            # more than one distinct length is an error.
            if num_size > 1:
                raise RunningError(
                    f'All values in "dyn_args" should have the same length. But we got '
                    f'{num_size} different lengths: {sizes}')
            self._dyn_args.update(dyn_args)

        # dynamical changed variables
        if dyn_vars is None:
            dyn_vars = self.target.vars().unique()
        if isinstance(dyn_vars, (list, tuple)):
            dyn_vars = {f'_v{i}': v for i, v in enumerate(dyn_vars)}
        if not isinstance(dyn_vars, dict):
            raise RunningError(
                f'"dyn_vars" must be a dict, but we got {type(dyn_vars)}')
        self.dyn_vars = TensorCollector(dyn_vars)

        # monitors
        if monitors is None:
            self.mon = Monitor(target=self, variables=[])
        elif isinstance(monitors, (list, tuple, dict)):
            self.mon = Monitor(target=self, variables=monitors)
        elif isinstance(monitors, Monitor):
            self.mon = monitors
            self.mon.target = self
        else:
            raise MonitorError(f'"monitors" only supports list/tuple/dict/ '
                               f'instance of Monitor, not {type(monitors)}.')
        self.mon.build()  # build the monitor
        for k in self.mon.item_names:
            if k not in self.target.variables:
                raise MonitorError(
                    f'Variable "{k}" to monitor is not defined in the integrator {self.target}.'
                )

        # start simulation time
        self._start_t = None

        # Variables
        if inits is not None:
            if isinstance(inits, (list, tuple)):
                assert len(self.target.variables) == len(inits)
                inits = {
                    k: inits[i]
                    for i, k in enumerate(self.target.variables)
                }
            assert isinstance(inits, dict)
            sizes = np.unique([np.size(v) for v in list(inits.values())])
            # np.max() fails on an empty array, so an empty ``inits`` falls
            # back to the same size used when ``inits`` is None.
            max_size = int(np.max(sizes)) if len(sizes) > 0 else 1
        else:
            max_size = 1
            inits = dict()
        self.variables = TensorCollector({
            v: math.Variable(math.zeros(max_size))
            for v in self.target.variables
        })
        for k in inits.keys():
            self.variables[k][:] = inits[k]
        self.dyn_vars.update(self.variables)
        if len(self._dyn_args) > 0:
            # step counter used to index into the per-step arguments
            self.idx = math.Variable(math.zeros(1, dtype=math.int_))
            self.dyn_vars['_idx'] = self.idx

        # build the update step
        if jit:
            _loop_func = math.make_loop(
                self._step,
                dyn_vars=self.dyn_vars,
                out_vars={k: self.variables[k]
                          for k in self.mon.item_names})
        else:

            def _loop_func(t_and_dt):
                # Python fallback: step once per (t, dt) pair and collect the
                # monitored variables after every step.
                out_vars = {k: [] for k in self.mon.item_names}
                times, dts = t_and_dt
                for i in range(len(times)):
                    _t = times[i]
                    _dt = dts[i]
                    self._step([_t, _dt])
                    for k in self.mon.item_names:
                        out_vars[k].append(
                            math.as_device_array(self.variables[k]))
                out_vars = {
                    k: math.asarray(out_vars[k])
                    for k in self.mon.item_names
                }
                return out_vars

        self.step_func = _loop_func
Exemple #24
0
def animate_2D(values,
               net_size,
               dt=None,
               val_min=None,
               val_max=None,
               cmap=None,
               frame_delay=10,
               frame_step=1,
               title_size=10,
               figsize=None,
               gif_dpi=None,
               video_fps=None,
               save_path=None,
               show=True):
    """Animate the potentials of the neuron group.

  Parameters
  ----------
  values : np.ndarray
      The membrane potentials of the neuron group, with shape
      ``(num_step, num_neuron)``. Any array-like is accepted.
  net_size : tuple
      The size of the neuron group, ``(height, width)``. ``height * width``
      must equal ``num_neuron``.
  dt : float
      The time duration of each step. Defaults to ``math.get_dt()``.
  val_min : float, int
      The minimum of the potential. Defaults to the data minimum.
  val_max : float, int
      The maximum of the potential. Defaults to the data maximum.
  cmap : str
      The colormap.
  frame_delay : int, float
      The delay to show each frame.
  frame_step : int
      The step to show the potential. If `frame_step=3`, then each
      frame shows one of the every three steps.
  title_size : int
      The size of the title.
  figsize : None, tuple
      The size of the figure.
  gif_dpi : int
      Controls the dots per inch for the movie frames. This combined with
      the figure's size in inches controls the size of the movie. If
      ``None``, use defaults in matplotlib.
  video_fps : int
      Frames per second in the movie. Defaults to ``None``, which will use
      the animation's specified interval to set the frames per second.
  save_path : None, str
      The save path of the animation.
  show : bool
      Whether show the animation.

  Returns
  -------
  anim : animation.FuncAnimation
      The created animation function.

  Raises
  ------
  ValueError
      If ``values`` is not 2D or ``net_size`` does not match its second axis.
  """
    dt = math.get_dt() if dt is None else dt

    # Convert first so that lists / device arrays also expose ``.shape``.
    values = np.asarray(values)
    if values.ndim != 2:
        raise ValueError(f'"values" must be 2D data (num_step, num_neuron), '
                         f'but got shape {values.shape}.')
    num_step, num_neuron = values.shape
    height, width = net_size
    if height * width != num_neuron:
        raise ValueError(f'"net_size" {(height, width)} does not match the '
                         f'number of neurons {num_neuron}.')

    val_min = values.min() if val_min is None else val_min
    val_max = values.max() if val_max is None else val_max
    # one (height, width) image per time step
    images = values.reshape((num_step, height, width))

    figsize = figsize or (6, 6)

    fig = plt.figure(figsize=(figsize[0], figsize[1]), constrained_layout=True)
    gs = GridSpec(1, 1, figure=fig)
    fig.add_subplot(gs[0, 0])

    def frame(t):
        # Redraw the whole figure for time step ``t``.
        img = images[t]
        fig.clf()
        plt.pcolor(img, cmap=cmap, vmin=val_min, vmax=val_max)
        plt.colorbar()
        plt.axis('off')
        fig.suptitle(t="Time: {:.2f} ms".format((t + 1) * dt),
                     fontsize=title_size,
                     fontweight='bold')
        return [fig.gca()]

    anim = animation.FuncAnimation(fig=fig,
                                   func=frame,
                                   frames=list(range(1, num_step, frame_step)),
                                   init_func=None,
                                   interval=frame_delay,
                                   repeat_delay=3000)
    if save_path is None:
        if show:
            plt.show()
    else:
        logger.warning(f'Saving the animation into {save_path} ...')
        if save_path[-3:] == 'gif':
            anim.save(save_path, dpi=gif_dpi, writer='imagemagick')
        elif save_path[-3:] == 'mp4':
            anim.save(save_path, writer='ffmpeg', fps=video_fps, bitrate=3000)
        else:
            anim.save(save_path + '.mp4',
                      writer='ffmpeg',
                      fps=video_fps,
                      bitrate=3000)
    return anim
  def update(self, _t, _dt):
    """Advance the network state by one Euler step of size ``_dt``.

    The firing rate is the squared ``u`` divided by ``1 + k * sum(u**2)``.
    ``u`` is updated with the previous ``v``, then ``v`` is updated with the
    new ``u``. The external ``input`` is cleared at the end of the step.
    """
    u_squared = bm.square(self.u)
    normalizer = 1.0 + self.k * bm.sum(u_squared)
    self.r.value = u_squared / normalizer

    recurrent = bm.dot(self.conn_mat, self.r)
    du = (-self.u + recurrent + self.input - self.v) / self.tau
    self.u.value = self.u + du * _dt

    # note: this uses the freshly updated u
    dv = (-self.v + self.m * self.u) / self.tau_v
    self.v.value = self.v + dv * _dt

    self.input[:] = 0.


cann = CANN1D(num=512)

# Smooth tracking: the stimulus rests for dur1, moves linearly for dur2,
# then stays at its final position for dur3.
dur1, dur2, dur3 = 100., 2000., 500.
num1, num2, num3 = (int(d / bm.get_dt()) for d in (dur1, dur2, dur3))
final_pos = cann.a / cann.tau_v * 0.6 * dur2
position = bm.zeros(num1 + num2 + num3)
position[num1: num1 + num2] = bm.linspace(0., final_pos, num2)
position[num1 + num2:] = final_pos
position = position.reshape((-1, 1))
Iext = cann.get_stimulus_by_pos(position)
runner = bp.StructRunner(
  cann,
  inputs=('input', Iext, 'iter'),
  monitors=['u', 'v'],
  dyn_vars=cann.vars(),
)
runner(dur1 + dur2 + dur3)
bp.visualize.animate_1D(
  dynamical_vars=[
Exemple #26
0
def _wrap(wrapper, f, g, dt, sde_type, var_type, wiener_type, show_code,
          num_iter):
    """The base function to format a SRK method.

  Resolves default settings, validates them, and then either builds the
  integrator directly (when both ``f`` and ``g`` are given) or returns a
  one-argument decorator that waits for the missing function.

  Parameters
  ----------
  wrapper : callable
      The method-specific builder. It is called with the keyword arguments
      ``f``, ``g``, ``dt``, ``show_code``, ``sde_type``, ``var_type``,
      ``wiener_type`` and ``num_iter``.
  f : callable, optional
      The drift function of the SDE_INT.
  g : callable, optional
      The diffusion function of the SDE_INT.
  dt : float, optional
      The numerical precision. Defaults to ``math.get_dt()``.
  sde_type : str, optional
      "utils.ITO_SDE" : Ito's Stochastic Calculus.
      "utils.STRA_SDE" : Stratonovich's Stochastic Calculus.
      Defaults to ``constants.ITO_SDE``.
  var_type : str, optional
      "scalar" : with the shape of ().
      "population" : with the shape of (N,) or (N1, N2) or (N1, N2, ...).
      "system": with the shape of (d, ), (d, N), or (d, N1, N2).
      Defaults to ``constants.POP_VAR``.
  wiener_type : str, optional
      The type of the Wiener process. Must be one of
      ``constants.SUPPORTED_WIENER_TYPE``. Defaults to
      ``constants.SCALAR_WIENER``.
  show_code : bool, optional
      Whether show the formatted code. Defaults to ``False``.
  num_iter : int, optional
      Passed through to ``wrapper``. Defaults to ``10``.

  Returns
  -------
  numerical_func : callable
      The numerical function, or a decorator expecting the missing ``f``/``g``.

  Raises
  ------
  ValueError
      If neither ``f`` nor ``g`` is provided.
  """

    sde_type = constants.ITO_SDE if sde_type is None else sde_type
    assert sde_type in constants.SUPPORTED_INTG_TYPE, f'Currently, BrainPy only support SDE_INT types: ' \
                                                      f'{constants.SUPPORTED_INTG_TYPE}. But we got {sde_type}.'

    var_type = constants.POP_VAR if var_type is None else var_type
    assert var_type in constants.SUPPORTED_VAR_TYPE, f'Currently, BrainPy only supports variable types: ' \
                                                     f'{constants.SUPPORTED_VAR_TYPE}. But we got {var_type}.'

    wiener_type = constants.SCALAR_WIENER if wiener_type is None else wiener_type
    assert wiener_type in constants.SUPPORTED_WIENER_TYPE, f'Currently, BrainPy only supports Wiener ' \
                                                           f'Process types: {constants.SUPPORTED_WIENER_TYPE}. ' \
                                                           f'But we got {wiener_type}.'

    show_code = False if show_code is None else show_code
    dt = math.get_dt() if dt is None else dt
    num_iter = 10 if num_iter is None else num_iter

    def _build(drift, diffusion):
        # Single place where the method-specific builder is invoked.
        return wrapper(f=drift,
                       g=diffusion,
                       dt=dt,
                       show_code=show_code,
                       sde_type=sde_type,
                       var_type=var_type,
                       wiener_type=wiener_type,
                       num_iter=num_iter)

    if f is not None and g is not None:
        return _build(f, g)
    elif f is not None:
        # decorator form: the diffusion function is supplied later
        return lambda g: _build(f, g)
    elif g is not None:
        # decorator form: the drift function is supplied later
        return lambda f: _build(f, g)
    else:
        raise ValueError('Must provide "f" or "g".')
    def __init__(self, pre, post, prob, syn_type='e', conn_type=0):
        """Build a random synaptic projection from ``pre`` to ``post``.

        Parameters
        ----------
        pre : object
            Presynaptic group; must expose a ``spike`` attribute, ``num`` and ``size``.
        post : object
            Postsynaptic group; must expose an ``I`` attribute, ``num`` and ``size``.
        prob : float
            Connection probability, strictly between 0 and 1.
        syn_type : str
            ``'e'`` (excitatory) or ``'i'`` (inhibitory). This selects the delay
            distribution and whether weights are scaled by ``inh_weight_scale``.
        conn_type : int
            Connectivity construction strategy, one of 0, 1, 2, 3, 4.

        Raises
        ------
        ValueError
            If ``conn_type`` or ``syn_type`` is not recognized.
        """
        super(ExpSyn, self).__init__(pre=pre, post=post, conn=None)
        self.check_pre_attrs('spike')
        self.check_post_attrs('I')
        assert syn_type in ['e', 'i'], f'"syn_type" must be "e" or "i", but got {syn_type!r}.'
        # assert conn_type in [0, 1, 2, 3]
        assert 0. < prob < 1., f'"prob" must be in (0, 1), but got {prob}.'

        # parameters
        self.syn_type = syn_type
        self.conn_type = conn_type

        # connection
        if conn_type == 0:
            # number of synapses calculated with equation 3 from the article
            num = int(
                np.log(1.0 - prob) / np.log(1.0 -
                                            (1.0 / float(pre.num * post.num))))
            self.pre2post = bp.conn.ij2csr(
                pre_ids=np.random.randint(0, pre.num, num),
                post_ids=np.random.randint(0, post.num, num),
                num_pre=pre.num)
            self.num = self.pre2post[0].size
        elif conn_type == 1:
            # number of synapses calculated with equation 5 from the article
            self.pre2post = bp.conn.FixedProb(prob)(
                pre.size, post.size).require('pre2post')
            self.num = self.pre2post[0].size
        elif conn_type == 2:
            # explicit (pre, post) index pairs instead of a CSR structure
            self.num = int(prob * pre.num * post.num)
            self.pre_ids = bm.random.randint(0,
                                             pre.num,
                                             size=self.num,
                                             dtype=bm.uint32)
            self.post_ids = bm.random.randint(0,
                                              post.num,
                                              size=self.num,
                                              dtype=bm.uint32)
        elif conn_type in [3, 4]:
            self.pre2post = bp.conn.FixedProb(prob)(
                pre.size, post.size).require('pre2post')
            self.num = self.pre2post[0].size
            # largest number of consecutive entries between CSR row pointers
            self.max_post_conn = bm.diff(self.pre2post[1]).max()
        else:
            raise ValueError(f'Unknown "conn_type": {conn_type}. '
                             f'Must be one of 0, 1, 2, 3, 4.')

        # delay: one value per presynaptic neuron, drawn from a normal distribution
        if syn_type == 'e':
            self.delay = bm.random.normal(*self.exc_delay, size=pre.num)
        elif syn_type == 'i':
            self.delay = bm.random.normal(*self.inh_delay, size=pre.num)
        else:
            raise ValueError(f'Unknown "syn_type": {syn_type!r}. Must be "e" or "i".')
        # clip delays from below so that no delay is shorter than one time step
        self.delay = bm.where(self.delay < bm.get_dt(), bm.get_dt(),
                              self.delay)

        # weights: drawn from ``exc_weight`` for both types, negatives clipped
        # to 0, then scaled for inhibitory synapses
        self.weights = bm.random.normal(*self.exc_weight, size=self.num)
        self.weights = bm.where(self.weights < 0, 0., self.weights)
        if syn_type == 'i':
            self.weights *= self.inh_weight_scale

        # variables
        self.pre_sps = bp.ConstantDelay(pre.num, self.delay, bool)
Exemple #28
0
  def plot_limit_cycle_by_sim(self, duration=100, with_plot=True, with_return=False,
                              plot_style=None, tol=0.001, show=False, dt=None, offset=1.):
    """Find limit cycles by simulating from the previously found fixed points.

    Every fixed point stored in ``self._fixed_points`` is shifted by ``offset``
    and used as an initial condition. After simulating for ``duration``, the
    max/min of each detected cycle are collected per parameter value.

    Parameters
    ----------
    duration : int, float
      Simulation duration passed to ``TrajectModel.run``.
    with_plot : bool
      Whether to plot the max/min of the detected limit cycles.
    with_return : bool
      Whether to return the collected limit-cycle data.
    plot_style : dict, optional
      Extra keyword arguments for ``plt.plot``. The optional ``'fmt'`` entry
      (default ``'.'``) is used as the format string in the one-parameter case.
      The caller's dict is not modified.
    tol : float
      Tolerance passed to ``utils.find_indexes_of_limit_cycle_max``.
    show : bool
      Whether to call ``plt.show()``.
    dt : float, optional
      Simulation step; defaults to ``bm.get_dt()``.
    offset : float
      Value added to both fixed-point coordinates to form the initial state.

    Returns
    -------
    tuple or None
      ``(vs_limit_cycle, ps_limit_cycle)`` when ``with_return`` is True. Returns
      None early if no fixed points are available.
    """
    utils.output('I am plotting the limit cycle ...')
    if self._fixed_points is None:
      utils.output('No fixed points found, you may call "plot_bifurcation(with_plot=True)" first.')
      return

    final_fps, final_pars = self._fixed_points
    dt = bm.get_dt() if dt is None else dt
    traject_model = utils.TrajectModel(
      initial_vars={self.x_var: final_fps[:, 0] + offset, self.y_var: final_fps[:, 1] + offset},
      integrals={self.x_var: self.F_int_x, self.y_var: self.F_int_y},
      pars={p: v for p, v in zip(self.target_par_names, final_pars.T)},
      dt=dt
    )
    mon_res = traject_model.run(duration=duration)

    # find limit cycles (detected on the x variable, one trajectory per column)
    vs_limit_cycle = tuple({'min': [], 'max': []} for _ in self.target_var_names)
    ps_limit_cycle = tuple([] for _ in self.target_par_names)
    for i in range(mon_res[self.x_var].shape[1]):
      data = mon_res[self.x_var][:, i]
      max_index = utils.find_indexes_of_limit_cycle_max(data, tol=tol)
      if max_index[0] != -1:
        cycle = data[max_index[0]: max_index[1]]
        vs_limit_cycle[0]['max'].append(mon_res[self.x_var][max_index[1], i])
        vs_limit_cycle[0]['min'].append(cycle.min())
        cycle = mon_res[self.y_var][max_index[0]: max_index[1], i]
        vs_limit_cycle[1]['max'].append(mon_res[self.y_var][max_index[1], i])
        vs_limit_cycle[1]['min'].append(cycle.min())
        for j in range(len(self.target_par_names)):
          ps_limit_cycle[j].append(final_pars[i, j])
    vs_limit_cycle = tuple({k: np.asarray(v) for k, v in lm.items()} for lm in vs_limit_cycle)
    ps_limit_cycle = tuple(np.array(p) for p in ps_limit_cycle)

    # visualization
    if with_plot:
      # copy so that popping 'fmt' does not mutate the caller's dict
      plot_style = dict(plot_style) if plot_style is not None else dict()
      fmt = plot_style.pop('fmt', '.')

      if len(self.target_par_names) == 2:
        for i, var in enumerate(self.target_var_names):
          # NOTE(review): three positional data arrays assume the figure named
          # ``var`` already holds a 3D axes — confirm against plot_bifurcation.
          plt.figure(var)
          plt.plot(ps_limit_cycle[0], ps_limit_cycle[1], vs_limit_cycle[i]['max'],
                   **plot_style, label='limit cycle (max)')
          plt.plot(ps_limit_cycle[0], ps_limit_cycle[1], vs_limit_cycle[i]['min'],
                   **plot_style, label='limit cycle (min)')
          plt.legend()

      elif len(self.target_par_names) == 1:
        for i, var in enumerate(self.target_var_names):
          plt.figure(var)
          plt.plot(ps_limit_cycle[0], vs_limit_cycle[i]['max'], fmt,
                   **plot_style, label='limit cycle (max)')
          plt.plot(ps_limit_cycle[0], vs_limit_cycle[i]['min'], fmt,
                   **plot_style, label='limit cycle (min)')
          plt.legend()

      else:
        raise errors.AnalyzerError(
          f'Limit cycle visualization supports 1 or 2 target parameters, '
          f'but got {len(self.target_par_names)}: {self.target_par_names}.')

      if show:
        plt.show()

    if with_return:
      return vs_limit_cycle, ps_limit_cycle
Exemple #29
0
    def __init__(self,
                 f,
                 g,
                 dt=None,
                 name=None,
                 show_code=False,
                 var_type=None,
                 intg_type=None,
                 wiener_type=None):
        """Set up the state shared by SDE integrators.

        Parameters
        ----------
        f : callable
            The drift function.
        g : callable
            The diffusion function.
        dt : int, float, optional
            Step size; defaults to ``math.get_dt()``.
        name : str, optional
            Forwarded to the base-class constructor.
        show_code : bool
            Stored as ``self.show_code``.
        var_type, intg_type, wiener_type : str, optional
            Variable / integral / Wiener-process types. They default to
            ``SCALAR_VAR``, ``ITO_SDE`` and ``SCALAR_WIENER`` and are validated
            against the corresponding ``constants.SUPPORTED_*`` collections.

        Raises
        ------
        errors.IntegratorError
            If any of the three type settings is unsupported.
        """
        super(SDEIntegrator, self).__init__(name=name)

        # drift (f) and diffusion (g) functions
        self.derivative = {constants.F: f, constants.G: g}
        self.f, self.g = f, g

        # not built here; left as None
        self.integral = None

        # step size, falling back to the global default
        if dt is None:
            dt = math.get_dt()
        self.dt = dt
        assert isinstance(
            self.dt, (int, float)), f'"dt" must be a float, but got {self.dt}'

        # resolve defaults, then validate each type setting
        if intg_type is None:
            intg_type = constants.ITO_SDE
        if var_type is None:
            var_type = constants.SCALAR_VAR
        if wiener_type is None:
            wiener_type = constants.SCALAR_WIENER
        if intg_type not in constants.SUPPORTED_INTG_TYPE:
            raise errors.IntegratorError(
                f'Currently, BrainPy only support SDE_INT types: '
                f'{constants.SUPPORTED_INTG_TYPE}. But we got {intg_type}.')
        if var_type not in constants.SUPPORTED_VAR_TYPE:
            raise errors.IntegratorError(
                f'Currently, BrainPy only supports variable types: '
                f'{constants.SUPPORTED_VAR_TYPE}. But we got {var_type}.')
        if wiener_type not in constants.SUPPORTED_WIENER_TYPE:
            raise errors.IntegratorError(
                f'Currently, BrainPy only supports Wiener '
                f'Process types: {constants.SUPPORTED_WIENER_TYPE}. '
                f'But we got {wiener_type}.')
        self.var_type = var_type
        self.intg_type = intg_type
        self.wiener_type = wiener_type

        # split the drift signature into variables (before 't'),
        # parameters (after 't') and the full argument list plus ``dt``
        variables, parameters, arguments = utils.get_args(f)
        self.variables = variables
        self.parameters = parameters
        self.arguments = [*arguments, f'{constants.DT}={self.dt}']

        # random number generator exposed to the generated code as ``random``
        self.rng = math.random.RandomState()

        # names visible to the generated integration code
        self.code_scope = {
            constants.F: f,
            constants.G: g,
            'math': math,
            'random': self.rng,
        }

        # generated source starts with the function header
        self.func_name = f_names(f)
        self.code_lines = [
            f'def {self.func_name}({", ".join(self.arguments)}):'
        ]

        self.show_code = show_code
    def plot_limit_cycle_by_sim(self,
                                initials,
                                duration,
                                tol=0.01,
                                show=False,
                                dt=None):
        """Plot limit cycles found by simulating from the given initial values.

    Parameters
    ----------
    initials : list, tuple
        The initial value setting of the targets.

        - It can be a tuple/list of floats to specify each value of dynamical variables
          (for example, ``(a, b)``).
        - It can also be a tuple/list of tuple to specify multiple initial values (for
          example, ``[(a1, b1), (a2, b2)]``).
    duration : int, float
        The running duration; all trajectories are simulated for this time.
        (Tuple/list durations are not supported by this method.)
    tol : float
        Tolerance passed to ``utils.find_indexes_of_limit_cycle_max`` when
        detecting a cycle on the x variable.
    show : bool
        Whether show or not.
    dt : float, optional
        Simulation step; defaults to ``math.get_dt()``.
    """
        utils.output('I am plotting the limit cycle ...')

        # 1. format the initial values
        initials = utils.check_initials(initials, self.target_var_names)

        # 2. format the running duration
        assert isinstance(duration, (int, float)), \
            f'"duration" must be an int or float, but got {type(duration)}: {duration}'

        # 3. simulate all initial values at once
        dt = math.get_dt() if dt is None else dt
        traject_model = utils.TrajectModel(initial_vars=initials,
                                           integrals={
                                               self.x_var: self.F_int_x,
                                               self.y_var: self.F_int_y
                                           },
                                           dt=dt)
        mon_res = traject_model.run(duration=duration)

        # 4. detect and draw one cycle per trajectory (column)
        for init_i, initial in enumerate(zip(*list(initials.values()))):
            x_data = mon_res[self.x_var][:, init_i]
            y_data = mon_res[self.y_var][:, init_i]
            max_index = utils.find_indexes_of_limit_cycle_max(x_data, tol=tol)
            if max_index[0] != -1:
                x_cycle = x_data[max_index[0]:max_index[1]]
                y_cycle = y_data[max_index[0]:max_index[1]]
                lines = plt.plot(x_cycle, y_cycle, label='limit cycle')
                utils.add_arrow(lines[0])
            else:
                utils.output(
                    f'No limit cycle found for initial value {initial}')

        # 5. axes decoration
        plt.xlabel(self.x_var)
        plt.ylabel(self.y_var)
        scale = (self.lim_scale - 1.) / 2
        plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
        plt.ylim(*utils.rescale(self.target_vars[self.y_var], scale=scale))
        plt.legend()

        if show:
            plt.show()