Esempio n. 1
0
    def plot_vector_field(self, show=False, with_plot=True, with_return=False):
        """Plot the vector field ``dx/dt`` of the one-dimensional system.

        Parameters
        ----------
        show : bool
            Whether to call ``plt.show()`` after plotting.
        with_plot : bool
            Whether to draw the figure at all.
        with_return : bool
            Whether to return the computed ``dx/dt`` values.

        Returns
        -------
        np.ndarray or None
            The ``dx/dt`` values evaluated on the resolution grid of ``x_var``
            when ``with_return`` is True, otherwise None.
        """
        utils.output('I am creating the vector field ...')

        # Evaluate the derivative of the x variable on its resolution grid.
        grid = self.resolutions[self.x_var]
        dx_values = np.asarray(self.F_fx(grid))

        if with_plot:
            y_label = f"d{self.x_var}dt"
            plt.plot(np.asarray(grid), dx_values,
                     color='lightcoral', alpha=.7, linewidth=4,
                     label=y_label)
            plt.axhline(0)
            plt.xlabel(self.x_var)
            plt.ylabel(y_label)
            half_margin = (self.lim_scale - 1.) / 2
            plt.xlim(*utils.rescale(self.target_vars[self.x_var],
                                    scale=half_margin))
            plt.legend()
            if show:
                plt.show()

        return dx_values if with_return else None
Esempio n. 2
0
    def plot_vector_field(self, show=False):
        """Plot the vector field ``dx/dt`` against the x variable.

        Parameters
        ----------
        show : bool
            Whether to call ``plt.show()`` after plotting.

        Returns
        -------
        results : np.ndarray
            The dx values evaluated on the resolution grid of ``x_var``.

        Raises
        ------
        errors.ModelUseError
            If evaluating the derivative raises ``TypeError`` (some variables
            were not provided through "fixed_vars").
        """
        print('plot vector field ...')

        # 1. evaluate dx/dt on the resolution grid of the x variable
        try:
            dx_values = self.get_f_dx()(self.resolutions[self.x_var])
        except TypeError:
            raise errors.ModelUseError(
                'Missing variables. Please check and set missing '
                'variables to "fixed_vars".')

        # 2. draw the curve together with the dx/dt = 0 reference line
        y_label = f"d{self.x_var}dt"
        plt.plot(self.resolutions[self.x_var], dx_values,
                 color='lightcoral', alpha=.7, linewidth=4, label=y_label)
        plt.axhline(0)

        plt.xlabel(self.x_var)
        plt.ylabel(y_label)
        half_margin = (self.options.lim_scale - 1.) / 2
        plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=half_margin))
        plt.legend()
        if show:
            plt.show()
        return dx_values
Esempio n. 3
0
    def plot_limit_cycle_by_sim(self,
                                initials,
                                duration,
                                tol=0.001,
                                show=False):
        """Plot limit cycles found by simulating the model from each initial value.

        Each initial value is simulated with a ``Trajectory``. Then
        ``utils.find_indexes_of_limit_cycle_max`` is applied to the x trace to
        locate one period of a limit cycle. That segment is plotted in the
        ``x_var``-``y_var`` plane.

        Parameters
        ----------
        initials : dict, list, tuple
            The initial value setting of the targets. Accepted forms:

            - a dict ``{var_name: value}``;
            - a tuple/list of numbers giving each dynamical variable, in the
              order of ``self.dvar_names`` (for example, ``(a, b)``);
            - a tuple/list of dicts;
            - a tuple/list of tuples for multiple initial values (for example,
              ``[(a1, b1), (a2, b2)]``).
        duration : int, float, tuple, list
            The running duration. Same with the ``duration`` in ``NeuGroup.run()``.
            It can be an int/float (``t_end``) to specify the same running end time,
            or it can be a tuple/list of int/float (``(t_start, t_end)``) to specify
            the start and end simulation time. Or, it can be a list of tuple
            (``[(t1_start, t1_end), (t2_start, t2_end)]``) to specify the specific
            start and end simulation time for each initial value.
        tol : float
            Tolerance forwarded to ``utils.find_indexes_of_limit_cycle_max``.
        show : bool
            Whether show or not.

        Raises
        ------
        ValueError
            If ``initials`` is empty or not in one of the accepted forms.
        """
        import numbers

        print('plot limit cycle ...')

        # 1. format the initial values into a list of ``{var_name: value}`` dicts
        if isinstance(initials, dict):
            initials = [initials]
        elif isinstance(initials, (list, tuple)):
            if len(initials) == 0:
                raise ValueError('"initials" must not be empty.')
            first = initials[0]
            # ``numbers.Real`` also accepts NumPy integer/float scalars
            if isinstance(first, numbers.Real):
                # a single initial point given as ``(a, b)``
                initials = [{
                    self.dvar_names[i]: v
                    for i, v in enumerate(initials)
                }]
            elif isinstance(first, dict):
                pass  # already a sequence of ``{var_name: value}`` dicts
            elif (isinstance(first, (tuple, list)) and len(first) > 0
                  and isinstance(first[0], numbers.Real)):
                # several initial points given as ``[(a1, b1), (a2, b2)]``
                initials = [{
                    self.dvar_names[i]: v
                    for i, v in enumerate(init)
                } for init in initials]
            else:
                raise ValueError(f'Unsupported format of "initials": {initials!r}')
        else:
            raise ValueError(f'Unsupported type of "initials": {type(initials)}')

        # 2. format the running duration into one ``(t_start, t_end)`` per initial value
        if isinstance(duration, numbers.Real):
            duration = [(0, duration) for _ in range(len(initials))]
        elif isinstance(duration[0], numbers.Real):
            duration = [duration for _ in range(len(initials))]
        else:
            assert len(duration) == len(initials), \
                'Each initial value needs its own "duration".'

        # 3. simulate every initial value and look for one period of a limit cycle
        for init_i, initial in enumerate(initials):
            traj_group = Trajectory(size=1,
                                    integrals=self.model.integrals,
                                    target_vars=initial,
                                    fixed_vars=self.fixed_vars,
                                    pars_update=self.pars_update,
                                    scope=self.model.scopes)
            traj_group.run(duration=duration[init_i], report=False)

            x_data = traj_group.mon[self.x_var][:, 0]
            y_data = traj_group.mon[self.y_var][:, 0]
            # a first index of -1 is the helper's "no cycle found" marker
            max_index = utils.find_indexes_of_limit_cycle_max(x_data, tol=tol)
            if max_index[0] != -1:
                x_cycle = x_data[max_index[0]:max_index[1]]
                y_cycle = y_data[max_index[0]:max_index[1]]
                lines = plt.plot(x_cycle, y_cycle, label='limit cycle')
                utils.add_arrow(lines[0])
            else:
                print(f'No limit cycle found for initial value {initial}')

        # 4. axes decoration
        plt.xlabel(self.x_var)
        plt.ylabel(self.y_var)
        scale = (self.options.lim_scale - 1.) / 2
        plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
        plt.ylim(*utils.rescale(self.target_vars[self.y_var], scale=scale))
        plt.legend()

        if show:
            plt.show()
Esempio n. 4
0
    def plot_trajectory(self,
                        initials,
                        duration,
                        plot_duration=None,
                        axes='v-v',
                        show=False):
        """Plot trajectories according to the settings.

        Parameters
        ----------
        initials : list, tuple, dict
            The initial value setting of the targets. Accepted forms:

            - a dict ``{var_name: value}``;
            - a tuple/list of numbers giving each dynamical variable, in the
              order of ``self.dvar_names`` (for example, ``(a, b)``);
            - a tuple/list of dicts;
            - a tuple/list of tuples for multiple initial values (for example,
              ``[(a1, b1), (a2, b2)]``).
        duration : int, float, tuple, list
            The running duration. Same with the ``duration`` in ``NeuGroup.run()``.
            It can be an int/float (``t_end``) to specify the same running end time,
            or it can be a tuple/list of int/float (``(t_start, t_end)``) to specify
            the start and end simulation time. Or, it can be a list of tuple
            (``[(t1_start, t1_end), (t2_start, t2_end)]``) to specify the specific
            start and end simulation time for each initial value.
        plot_duration : tuple, list, optional
            The duration to plot, in the same (absolute) time as ``duration``.
            It can be a tuple with ``(start, end)``. It can also be a list of
            tuple ``[(start1, end1), (start2, end2)]`` to specify the plot
            duration for each initial value running. Defaults to ``duration``.
        axes : str
            The axes to plot. It can be:

                 - 'v-v'
                        Plot the trajectory in the 'x_var'-'y_var' axis.
                 - 't-v'
                        Plot the trajectory in the 'time'-'var' axis.
        show : bool
            Whether show or not.

        Raises
        ------
        errors.ModelUseError
            If ``axes`` is not 'v-v' or 't-v'.
        ValueError
            If ``initials`` is empty or not in one of the accepted forms.
        """
        import numbers

        import numpy as np

        print('plot trajectory ...')

        if axes not in ['v-v', 't-v']:
            raise errors.ModelUseError(
                f'Unknown axes "{axes}", only support "v-v" and "t-v".')

        # 1. format the initial values into a list of ``{var_name: value}`` dicts
        if isinstance(initials, dict):
            initials = [initials]
        elif isinstance(initials, (list, tuple)):
            if len(initials) == 0:
                raise ValueError('"initials" must not be empty.')
            first = initials[0]
            # ``numbers.Real`` also accepts NumPy integer/float scalars
            if isinstance(first, numbers.Real):
                # a single initial point given as ``(a, b)``
                initials = [{
                    self.dvar_names[i]: v
                    for i, v in enumerate(initials)
                }]
            elif isinstance(first, dict):
                pass  # already a sequence of ``{var_name: value}`` dicts
            elif (isinstance(first, (tuple, list)) and len(first) > 0
                  and isinstance(first[0], numbers.Real)):
                # several initial points given as ``[(a1, b1), (a2, b2)]``
                initials = [{
                    self.dvar_names[i]: v
                    for i, v in enumerate(init)
                } for init in initials]
            else:
                raise ValueError(f'Unsupported format of "initials": {initials!r}')
        else:
            raise ValueError(f'Unsupported type of "initials": {type(initials)}')

        # 2. format the running duration into one ``(t_start, t_end)`` per initial value
        if isinstance(duration, numbers.Real):
            duration = [(0, duration) for _ in range(len(initials))]
        elif isinstance(duration[0], numbers.Real):
            duration = [duration for _ in range(len(initials))]
        else:
            assert len(duration) == len(initials), \
                'Each initial value needs its own "duration".'

        # 3. format the plot duration into one ``(start, end)`` per initial value
        if plot_duration is None:
            plot_duration = duration
        if isinstance(plot_duration[0], numbers.Real):
            plot_duration = [plot_duration for _ in range(len(initials))]
        else:
            assert len(plot_duration) == len(initials), \
                'Each initial value needs its own "plot_duration".'

        # 4. simulate each initial value and draw its trajectory
        for init_i, initial in enumerate(initials):
            traj_group = Trajectory(size=1,
                                    integrals=self.model.integrals,
                                    target_vars=initial,
                                    fixed_vars=self.fixed_vars,
                                    pars_update=self.pars_update,
                                    scope=self.model.scopes)
            traj_group.run(duration=duration[init_i], report=False)

            legend = f'$traj_{init_i}$: ' + ', '.join(
                f'{key}={initial[key]}' for key in self.dvar_names)

            # Locate the plot window on the recorded time axis ``mon.ts`` (the
            # same axis drawn for 't-v'). Dividing the absolute times by ``dt``
            # assumed recording starts at t=0. That selects the wrong (or an
            # empty) slice when the run starts at ``t_start > 0``.
            ts = np.asarray(traj_group.mon.ts)
            start = int(np.searchsorted(ts, plot_duration[init_i][0]))
            end = int(np.searchsorted(ts, plot_duration[init_i][1]))

            if axes == 'v-v':
                lines = plt.plot(traj_group.mon[self.x_var][start:end, 0],
                                 traj_group.mon[self.y_var][start:end, 0],
                                 label=legend)
                utils.add_arrow(lines[0])
            else:
                plt.plot(traj_group.mon.ts[start:end],
                         traj_group.mon[self.x_var][start:end, 0],
                         label=legend + f', {self.x_var}')
                plt.plot(traj_group.mon.ts[start:end],
                         traj_group.mon[self.y_var][start:end, 0],
                         label=legend + f', {self.y_var}')

        # 5. axes decoration
        if axes == 'v-v':
            plt.xlabel(self.x_var)
            plt.ylabel(self.y_var)
            scale = (self.options.lim_scale - 1.) / 2
            plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
            plt.ylim(*utils.rescale(self.target_vars[self.y_var], scale=scale))
            plt.legend()
        else:
            plt.legend(title='Initial values')

        if show:
            plt.show()
Esempio n. 5
0
    def plot_nullcline(self, numerical_setting=None, show=False):
        """Plot the nullcline.

        For each variable the nullcline is obtained as follows:

        1. the sympy-derived ``y = f(x)`` solution, if available;
        2. otherwise the sympy-derived ``x = f(y)`` solution, if available;
        3. otherwise the numerical optimizer configured by ``numerical_setting``.

        Parameters
        ----------
        numerical_setting : dict, optional
            Set the numerical method for solving nullclines, keyed by the
            equation's function name. For each function setting, it contains
            the following keywords:

                coords
                    The coordination setting, it can be 'var1-var2' (which means
                    for each possible value 'var1' the optimizer method will search
                    the zero root of 'var2') or 'var2-var1' (which means iterate each
                    'var2' and get the optimization results of 'var1').
                plot
                    It can be 'scatter' (default) or 'line'. It only applies
                    when the nullcline of that function is found numerically.

        show : bool
            Whether show the figure.

        Returns
        -------
        values : dict
            A dict with the format of ``{func1: (x_val, y_val), func2: (x_val, y_val)}``.

        Raises
        ------
        errors.ModelUseError
            If a sympy-derived nullcline function cannot be evaluated because
            variables are missing from "fixed_vars".
        ValueError
            If a numerically-solved nullcline has an unknown ``plot`` style.
        """
        print('plot nullcline ...')

        if numerical_setting is None:
            numerical_setting = dict()
        x_setting = numerical_setting.get(self.x_eq_group.func_name, {})
        y_setting = numerical_setting.get(self.y_eq_group.func_name, {})
        x_coords = x_setting.get('coords', self.x_var + '-' + self.y_var)
        y_coords = y_setting.get('coords', self.x_var + '-' + self.y_var)
        x_plot_style = x_setting.get('plot', 'scatter')
        y_plot_style = y_setting.get('plot', 'scatter')

        xs = self.resolutions[self.x_var]
        ys = self.resolutions[self.y_var]

        def _evaluate(func, values):
            # A TypeError here means the generated function still expects
            # variables that the user did not fix.
            try:
                return func(values)
            except TypeError as err:
                raise errors.ModelUseError(
                    'Missing variables. Please check and set missing '
                    'variables to "fixed_vars".') from err

        def _solve_and_plot(var, style, get_y_by_x, get_x_by_y,
                            get_optimizer, coords, plot_style):
            # Compute and draw one nullcline; return its ``(x_values, y_values)``.
            label = f"{var} nullcline"
            y_by_x = get_y_by_x()
            if y_by_x['status'] == 'sympy_success':
                y_values = _evaluate(y_by_x['f'], xs)
                plt.plot(xs, y_values, **style, label=label)
                return xs, y_values

            x_by_y = get_x_by_y()
            if x_by_y['status'] == 'sympy_success':
                x_values = _evaluate(x_by_y['f'], ys)
                plt.plot(x_values, ys, **style, label=label)
                return x_values, ys

            # optimization results
            x_values, y_values = get_optimizer(coords)()
            if plot_style == 'scatter':
                plt.plot(x_values, y_values, '.', **style, label=label)
            elif plot_style == 'line':
                plt.plot(x_values, y_values, **style, label=label)
            else:
                raise ValueError(f'Unknown plot style: {plot_style}')
            return x_values, y_values

        # Nullcline of the y variable (uses the y function's own settings)
        x_values_in_y_eq, y_values_in_y_eq = _solve_and_plot(
            self.y_var,
            dict(color='cornflowerblue', alpha=.7),
            self.get_y_by_x_in_y_eq,
            self.get_x_by_y_in_y_eq,
            self.get_f_optimize_y_nullcline,
            y_coords,
            y_plot_style)

        # Nullcline of the x variable (uses the x function's own settings)
        x_values_in_x_eq, y_values_in_x_eq = _solve_and_plot(
            self.x_var,
            dict(color='lightcoral', alpha=.7),
            self.get_y_by_x_in_x_eq,
            self.get_x_by_y_in_x_eq,
            self.get_f_optimize_x_nullcline,
            x_coords,
            x_plot_style)

        # finally
        plt.xlabel(self.x_var)
        plt.ylabel(self.y_var)
        scale = (self.options.lim_scale - 1.) / 2
        plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
        plt.ylim(*utils.rescale(self.target_vars[self.y_var], scale=scale))
        plt.legend()
        if show:
            plt.show()

        return {
            self.x_eq_group.func_name: (x_values_in_x_eq, y_values_in_x_eq),
            self.y_eq_group.func_name: (x_values_in_y_eq, y_values_in_y_eq)
        }
Esempio n. 6
0
  def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
                       tol_aux=1e-8, loss_screen=None):
    """Make the bifurcation analysis of the one-dimensional system.

    Candidate points come from the full mesh of the ``x_var`` resolution and every
    target-parameter resolution. They are refined by ``self._get_fixed_points``
    and classified by ``stability.stability_analysis`` applied to ``dfx/dx``.

    Parameters
    ----------
    with_plot: bool
      Whether plot the bifurcation figure: a 2D plot for one target parameter,
      a 3D scatter for two target parameters.
    show: bool
      Whether call ``plt.show()`` after plotting.
    with_return: bool
      Whether return the computed bifurcation results.
    tol_aux: float
      Forwarded to ``self._get_fixed_points``.
    loss_screen: optional
      Forwarded to ``self._get_fixed_points``.

    Returns
    -------
    results : tuple, optional
      ``(fixed_points, pars, dfxdx)`` when ``with_return`` is True. ``pars`` is a
      tuple of NumPy arrays, one per target parameter.

    Raises
    ------
    errors.BrainPyError
      If ``with_plot`` is True and there are more than two target parameters.
    """
    utils.output('I am making bifurcation analysis ...')

    # Brute-force mesh over the variable and all target parameters, flattened.
    x_grid = self.resolutions[self.x_var]
    par_grids = tuple(self.resolutions[name] for name in self.target_par_names)
    flat = tuple(jnp.moveaxis(mesh.value, 0, 1).flatten()
                 for mesh in bm.meshgrid(x_grid, *par_grids))
    fixed_points, _, pars = self._get_fixed_points(flat[0], *flat[1:],
                                                   tol_aux=tol_aux,
                                                   loss_screen=loss_screen,
                                                   num_seg=len(x_grid))
    dfxdx = np.asarray(self.F_vmap_dfxdx(jnp.asarray(fixed_points), *pars))
    pars = tuple(np.asarray(p) for p in pars)

    if with_plot:
      num_pars = len(self.target_pars)
      if num_pars == 1:
        # group fixed points by stability type
        container = {kind: {'p': [], 'x': []}
                     for kind in stability.get_1d_stability_types()}
        for p, x, dx in zip(pars[0], fixed_points, dfxdx):
          bucket = container[stability.stability_analysis(dx)]
          bucket['p'].append(p)
          bucket['x'].append(x)

        plt.figure(self.x_var)
        for kind, points in container.items():
          if points['x']:
            plt.plot(points['p'], points['x'], '.',
                     **stability.plot_scheme[kind], label=kind)
        plt.xlabel(self.target_par_names[0])
        plt.ylabel(self.x_var)

        margin = (self.lim_scale - 1) / 2
        plt.xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=margin))
        plt.ylim(*utils.rescale(self.target_vars[self.x_var], scale=margin))
        plt.legend()

      elif num_pars == 2:
        # group fixed points by stability type
        container = {kind: {'p0': [], 'p1': [], 'x': []}
                     for kind in stability.get_1d_stability_types()}
        for p0, p1, x, dx in zip(pars[0], pars[1], fixed_points, dfxdx):
          bucket = container[stability.stability_analysis(dx)]
          bucket['p0'].append(p0)
          bucket['p1'].append(p1)
          bucket['x'].append(x)

        ax = plt.figure(self.x_var).add_subplot(projection='3d')
        for kind, points in container.items():
          if points['x']:
            ax.scatter(points['p0'], points['p1'], points['x'],
                       **stability.plot_scheme[kind], label=kind)

        ax.set_xlabel(self.target_par_names[0])
        ax.set_ylabel(self.target_par_names[1])
        ax.set_zlabel(self.x_var)

        margin = (self.lim_scale - 1) / 2
        ax.set_xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=margin))
        ax.set_ylim(*utils.rescale(self.target_pars[self.target_par_names[1]], scale=margin))
        ax.set_zlim(*utils.rescale(self.target_vars[self.x_var], scale=margin))

        ax.grid(True)
        ax.legend()

      else:
        raise errors.BrainPyError(f'Cannot visualize co-dimension {num_pars} '
                                  f'bifurcation.')

      # only reached from the one- or two-parameter branches
      if show:
        plt.show()

    return (fixed_points, pars, dfxdx) if with_return else None
Esempio n. 7
0
  def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
                       tol_aux=1e-8, tol_unique=1e-2, tol_opt_candidate=None,
                       num_par_segments=1, num_fp_segment=1, nullcline_aux_filter=1.,
                       select_candidates='aux_rank', num_rank=100):
    r"""Make the bifurcation analysis.

    Parameters
    ----------
    with_plot: bool
      Whether plot the bifurcation figure.
    show: bool
      Whether show the figure.
    with_return: bool
      Whether return the computed bifurcation results.
    tol_aux: float
      The loss tolerance of auxiliary function :math:`f_{aux}` to confirm the fixed
      point. Default is 1e-8. Once :math:`f_{aux}(x_1) < \mathrm{tol\_aux}`,
      :math:`x_1` will be a fixed point.
    tol_unique: float
      The tolerance of distance between candidate fixed points to confirm they are
      the same. Default is 1e-2. If :math:`|x_1 - x_2| > \mathrm{tol\_unique}`,
      then :math:`x_1` and :math:`x_2` are unique fixed points. Otherwise,
      :math:`x_1` and :math:`x_2` will be treated as a same fixed point.
    tol_opt_candidate: float, optional
      The tolerance of auxiliary function :math:`f_{aux}` to select candidate
      initial points for fixed point optimization.
    num_par_segments: int, sequence of int
      How to segment parameters.
    num_fp_segment: int
      How to segment fixed points.
    nullcline_aux_filter: float
      Forwarded as ``fp_aux_filter`` to the nullcline point search when
      ``select_candidates`` is ``'fx-nullcline'``, ``'fy-nullcline'`` or
      ``'nullclines'``. It must be positive when ``select_candidates`` is
      ``'aux_rank'``.
    select_candidates: str
      The method to select candidate fixed points. It only applies when the
      system cannot be converted to one equation. It can be:

      - ``fx-nullcline``: use the points of fx-nullcline.
      - ``fy-nullcline``: use the points of fy-nullcline.
      - ``nullclines``: use the points in both of fx-nullcline and fy-nullcline.
      - ``aux_rank``: use the minimal value of points for the auxiliary function.
    num_rank: int
      The number of candidates, ranked by the auxiliary function, used to
      optimize the fixed points when ``select_candidates='aux_rank'``.

    Returns
    -------
    results : tuple
      Return a tuple of analyzed results (empty arrays when no fixed point is found):

      - fixed points: a 2D matrix with the shape of (num_point, num_var)
      - parameters: a 2D matrix with the shape of (num_point, num_par)
      - jacobians: a 3D tensors with the shape of (num_point, 2, 2)

    Raises
    ------
    ValueError
      If ``select_candidates`` is unknown, or if plotting is requested for more
      than two target parameters.
    """
    utils.output('I am making bifurcation analysis ...')

    if self._can_convert_to_one_eq():
      # the system reduces to one equation: brute-force a grid over the
      # remaining variable and all target parameters
      if self.convert_type() == C.x_by_y:
        X = self.resolutions[self.y_var].value
      else:
        X = self.resolutions[self.x_var].value
      pars = tuple(self.resolutions[p].value for p in self.target_par_names)
      mesh_values = jnp.meshgrid(*((X,) + pars))
      mesh_values = tuple(jnp.moveaxis(v, 0, 1).flatten() for v in mesh_values)
      candidates = mesh_values[0]
      parameters = mesh_values[1:]

    else:
      if select_candidates == 'fx-nullcline':
        fx_nullclines = self._get_fx_nullcline_points(num_segments=num_par_segments,
                                                      fp_aux_filter=nullcline_aux_filter)
        candidates = fx_nullclines[0]
        parameters = fx_nullclines[1:]
      elif select_candidates == 'fy-nullcline':
        fy_nullclines = self._get_fy_nullcline_points(num_segments=num_par_segments,
                                                      fp_aux_filter=nullcline_aux_filter)
        candidates = fy_nullclines[0]
        parameters = fy_nullclines[1:]
      elif select_candidates == 'nullclines':
        fx_nullclines = self._get_fx_nullcline_points(num_segments=num_par_segments,
                                                      fp_aux_filter=nullcline_aux_filter)
        fy_nullclines = self._get_fy_nullcline_points(num_segments=num_par_segments,
                                                      fp_aux_filter=nullcline_aux_filter)
        candidates = jnp.vstack([fx_nullclines[0], fy_nullclines[0]])
        parameters = [jnp.concatenate([fx_nullclines[i], fy_nullclines[i]])
                      for i in range(1, len(fy_nullclines))]
      elif select_candidates == 'aux_rank':
        assert nullcline_aux_filter > 0.
        candidates, parameters = self._get_fp_candidates_by_aux_rank(num_segments=num_par_segments,
                                                                     num_rank=num_rank)
      else:
        raise ValueError(f'Unknown "select_candidates": {select_candidates!r}. Only '
                         f'"fx-nullcline", "fy-nullcline", "nullclines" and '
                         f'"aux_rank" are supported.')
    candidates, _, parameters = self._get_fixed_points(candidates,
                                                       *parameters,
                                                       tol_aux=tol_aux,
                                                       tol_unique=tol_unique,
                                                       tol_opt_candidate=tol_opt_candidate,
                                                       num_segment=num_fp_segment)
    candidates = np.asarray(candidates)
    parameters = np.stack(tuple(np.asarray(p) for p in parameters)).T
    utils.output('I am trying to filter out duplicate fixed points ...')
    final_fps = []
    final_pars = []
    if parameters.shape[0] > 0:
      for par in np.unique(parameters, axis=0):
        # de-duplicate the fixed points found under exactly this parameter set
        ids = np.where(np.all(parameters == par, axis=1))[0]
        fps, ids2 = utils.keep_unique(candidates[ids], tolerance=tol_unique)
        final_fps.append(fps)
        final_pars.append(parameters[ids[ids2]])
    if final_fps:
      final_fps = np.vstack(final_fps)  # with the shape of (num_point, num_var)
      final_pars = np.vstack(final_pars)  # with the shape of (num_point, num_par)
      jacobians = np.asarray(self.F_vmap_jacobian(jnp.asarray(final_fps), *final_pars.T))
    else:
      # No fixed point survived. ``np.vstack([])`` would raise, so return empty
      # arrays with the documented shapes instead.
      num_var = len(self.target_var_names)
      final_fps = np.zeros((0, num_var))
      final_pars = np.zeros((0, parameters.shape[1]))
      jacobians = np.zeros((0, num_var, num_var))
    utils.output(f'{C.prefix}Found {len(final_fps)} fixed points.')

    # remember the fixed points for later limit cycle plotting
    self._fixed_points = (final_fps, final_pars)

    if with_plot:
      # bifurcation analysis of co-dimension 1
      if len(self.target_pars) == 1:
        container = {c: {'p': [], self.x_var: [], self.y_var: []}
                     for c in stability.get_2d_stability_types()}

        # fixed point
        for p, xy, J in zip(final_pars, final_fps, jacobians):
          fp_type = stability.stability_analysis(J)
          container[fp_type]['p'].append(p[0])
          container[fp_type][self.x_var].append(xy[0])
          container[fp_type][self.y_var].append(xy[1])

        # visualization
        for var in self.target_var_names:
          plt.figure(var)
          for fp_type, points in container.items():
            if len(points['p']):
              plot_style = stability.plot_scheme[fp_type]
              plt.plot(points['p'], points[var], '.', **plot_style, label=fp_type)
          plt.xlabel(self.target_par_names[0])
          plt.ylabel(var)

          scale = (self.lim_scale - 1) / 2
          plt.xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=scale))
          plt.ylim(*utils.rescale(self.target_vars[var], scale=scale))

          plt.legend()
        if show:
          plt.show()

      # bifurcation analysis of co-dimension 2
      elif len(self.target_pars) == 2:
        container = {c: {'p0': [], 'p1': [], self.x_var: [], self.y_var: []}
                     for c in stability.get_2d_stability_types()}

        # fixed point
        for p, xy, J in zip(final_pars, final_fps, jacobians):
          fp_type = stability.stability_analysis(J)
          container[fp_type]['p0'].append(p[0])
          container[fp_type]['p1'].append(p[1])
          container[fp_type][self.x_var].append(xy[0])
          container[fp_type][self.y_var].append(xy[1])

        # visualization
        for var in self.target_var_names:
          fig = plt.figure(var)
          ax = fig.add_subplot(projection='3d')
          for fp_type, points in container.items():
            if len(points['p0']):
              plot_style = stability.plot_scheme[fp_type]
              xs = points['p0']
              ys = points['p1']
              zs = points[var]
              ax.scatter(xs, ys, zs, **plot_style, label=fp_type)

          ax.set_xlabel(self.target_par_names[0])
          ax.set_ylabel(self.target_par_names[1])
          ax.set_zlabel(var)
          scale = (self.lim_scale - 1) / 2
          ax.set_xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=scale))
          ax.set_ylim(*utils.rescale(self.target_pars[self.target_par_names[1]], scale=scale))
          ax.set_zlim(*utils.rescale(self.target_vars[var], scale=scale))
          ax.grid(True)
          ax.legend()
        if show:
          plt.show()

      else:
        raise ValueError('Unknown length of parameters.')

    if with_return:
      return final_fps, final_pars, jacobians
Esempio n. 8
0
    def plot_limit_cycle_by_sim(self,
                                initials,
                                duration,
                                tol=0.01,
                                show=False,
                                dt=None):
        """Simulate from each initial value and plot any limit cycle found.

        Parameters
        ----------
        initials : list, tuple, dict
            The initial value setting of the targets. It is normalized by
            ``utils.check_initials`` against ``self.target_var_names``.

            - It can be a tuple/list of floats to specify each value of dynamical
              variables (for example, ``(a, b)``).
            - It can also be a tuple/list of tuple to specify multiple initial
              values (for example, ``[(a1, b1), (a2, b2)]``).
        duration : int, float
            The running duration (``t_end``) shared by all initial values. Only
            a single number is accepted here; tuples/lists are rejected by an
            assertion.
        tol : float
            Tolerance forwarded to ``utils.find_indexes_of_limit_cycle_max``.
        show : bool
            Whether show or not.
        dt : float, optional
            Integration step. Defaults to ``math.get_dt()``.
        """
        utils.output('I am plotting the limit cycle ...')

        # normalize the initial values and validate the run length
        initials = utils.check_initials(initials, self.target_var_names)
        assert isinstance(duration, (int, float))
        if dt is None:
            dt = math.get_dt()

        # simulate all initial values together; column i belongs to initial i
        model = utils.TrajectModel(
            initial_vars=initials,
            integrals={self.x_var: self.F_int_x, self.y_var: self.F_int_y},
            dt=dt)
        monitors = model.run(duration=duration)

        for column, initial in enumerate(zip(*list(initials.values()))):
            x_trace = monitors[self.x_var][:, column]
            y_trace = monitors[self.y_var][:, column]
            bounds = utils.find_indexes_of_limit_cycle_max(x_trace, tol=tol)
            # a first index of -1 is the helper's "no cycle found" marker
            if bounds[0] == -1:
                utils.output(
                    f'No limit cycle found for initial value {initial}')
                continue
            period = slice(bounds[0], bounds[1])
            lines = plt.plot(x_trace[period], y_trace[period], label='limit cycle')
            utils.add_arrow(lines[0])

        # axes decoration
        plt.xlabel(self.x_var)
        plt.ylabel(self.y_var)
        margin = (self.lim_scale - 1.) / 2
        plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=margin))
        plt.ylim(*utils.rescale(self.target_vars[self.y_var], scale=margin))
        plt.legend()

        if show:
            plt.show()
Esempio n. 9
0
    def plot_trajectory(self,
                        initials,
                        duration,
                        plot_durations=None,
                        axes='v-v',
                        dt=None,
                        show=False,
                        with_plot=True,
                        with_return=False,
                        **kwargs):
        """Plot trajectories according to the settings.

        Parameters
        ----------
        initials : list, tuple, dict
            The initial value setting of the targets. It can be a tuple/list of
            floats to specify each value of dynamical variables (for example,
            ``(a, b)``). It can also be a tuple/list of tuple to specify
            multiple initial values (for example, ``[(a1, b1), (a2, b2)]``).
        duration : int, float
            The running end time. Only a scalar is currently accepted; it is
            passed to ``TrajectModel.run(duration=...)``.
        plot_durations : tuple, list, optional
            The duration to plot. It can be a tuple with ``(start, end)``. It
            can also be a list of tuple ``[(start1, end1), (start2, end2)]`` to
            specify the plot duration for each initial value running.
        axes : str
            The axes to plot. It can be:

             - 'v-v': Plot the trajectory in the 'x_var'-'y_var' axis.
             - 't-v': Plot the trajectory in the 'time'-'var' axis.
        dt : float, optional
            The integration step. Defaults to ``math.get_dt()``.
        show : bool
            Whether to call ``plt.show()`` after plotting.
        with_plot : bool
            Whether to draw the trajectories.
        with_return : bool
            Whether to return the monitored running results.
        **kwargs
            Extra keyword arguments forwarded to every ``plt.plot`` call.

        Returns
        -------
        mon_res
            The object returned by ``TrajectModel.run`` when ``with_return``
            is True; otherwise ``None``.

        Raises
        ------
        errors.AnalyzerError
            If ``axes`` is neither ``'v-v'`` nor ``'t-v'``.
        """

        utils.output('I am plotting the trajectory ...')

        if axes not in ['v-v', 't-v']:
            raise errors.AnalyzerError(
                f'Unknown axes "{axes}", only support "v-v" and "t-v".')

        # 1. check the initial values
        initials = utils.check_initials(initials, self.target_var_names)

        # 2. format the running duration (only a scalar end time is supported)
        assert isinstance(duration, (int, float)), (
            f'"duration" must be an int or float, but got {type(duration)}.')

        # 3. format the plot duration
        plot_durations = utils.check_plot_durations(plot_durations, duration,
                                                    initials)

        # 4. run the network
        dt = math.get_dt() if dt is None else dt
        traject_model = utils.TrajectModel(initial_vars=initials,
                                           integrals={
                                               self.x_var: self.F_int_x,
                                               self.y_var: self.F_int_y
                                           },
                                           dt=dt)
        mon_res = traject_model.run(duration=duration)

        if with_plot:
            # 5. plot one curve (or two, for 't-v') per initial condition
            for i, initial in enumerate(zip(*initials.values())):
                # legend, e.g. "$traj_0$: x=1.0, y=2.0"
                legend = f'$traj_{i}$: ' + ', '.join(
                    f'{key}={round(float(initial[j]), 4)}'
                    for j, key in enumerate(self.target_var_names))

                # Convert the plot time window into sample indices. Rounding
                # (instead of int() truncation) avoids losing a sample to
                # float error, e.g. 0.3 / 0.1 == 2.9999999999999996.
                start = int(round(plot_durations[i][0] / dt))
                end = int(round(plot_durations[i][1] / dt))
                if axes == 'v-v':
                    lines = plt.plot(mon_res[self.x_var][start:end, i],
                                     mon_res[self.y_var][start:end, i],
                                     label=legend,
                                     **kwargs)
                    utils.add_arrow(lines[0])
                else:
                    plt.plot(mon_res.ts[start:end],
                             mon_res[self.x_var][start:end, i],
                             label=legend + f', {self.x_var}',
                             **kwargs)
                    plt.plot(mon_res.ts[start:end],
                             mon_res[self.y_var][start:end, i],
                             label=legend + f', {self.y_var}',
                             **kwargs)

            # 6. axes decoration
            if axes == 'v-v':
                plt.xlabel(self.x_var)
                plt.ylabel(self.y_var)
                scale = (self.lim_scale - 1.) / 2
                plt.xlim(
                    *utils.rescale(self.target_vars[self.x_var], scale=scale))
                plt.ylim(
                    *utils.rescale(self.target_vars[self.y_var], scale=scale))
                plt.legend()
            else:
                plt.legend(title='Initial values')

            if show:
                plt.show()

        if with_return:
            return mon_res
Esempio n. 10
0
    def plot_nullcline(self,
                       with_plot=True,
                       with_return=False,
                       y_style=None,
                       x_style=None,
                       show=False,
                       coords=None,
                       tol_nullcline=1e-7):
        """Compute (and optionally plot) the nullclines of both variables.

        Parameters
        ----------
        with_plot : bool
            Whether to draw the nullcline points.
        with_return : bool
            Whether to return the computed nullcline points.
        y_style : dict, optional
            Keyword arguments for the ``plt.plot`` call of the y-nullcline.
            An optional ``'fmt'`` entry (default ``'.'``) is used as the
            format string. The caller's dict is never modified.
        x_style : dict, optional
            Same as ``y_style``, for the x-nullcline.
        show : bool
            Whether to call ``plt.show()`` at the end (only with ``with_plot``).
        coords : dict, optional
            Maps a variable name to the value forwarded as ``coords=`` to the
            corresponding ``_get_f*_nullcline_points`` helper.
        tol_nullcline : float
            Tolerance forwarded to the nullcline point finders.

        Returns
        -------
        dict
            Only when ``with_return`` is True:
            ``{x_var: (xs, ys), y_var: (xs, ys)}`` holding the points of the
            x-nullcline and the y-nullcline respectively.
        """
        utils.output('I am computing fx-nullcline ...')

        coords = dict() if coords is None else coords
        x_coord = coords.get(self.x_var, None)
        y_coord = coords.get(self.y_var, None)

        # Nullcline of the x variable
        xy_values_in_fx, = self._get_fx_nullcline_points(coords=x_coord,
                                                         tol=tol_nullcline)
        x_values_in_fx = np.asarray(xy_values_in_fx[:, 0])
        y_values_in_fx = np.asarray(xy_values_in_fx[:, 1])

        if with_plot:
            # Copy before popping 'fmt' so the caller's dict is not mutated
            # (otherwise 'fmt' would silently disappear on the next call).
            x_style = (dict(color='cornflowerblue', alpha=.7)
                       if x_style is None else dict(x_style))
            fmt = x_style.pop('fmt', '.')
            plt.plot(x_values_in_fx,
                     y_values_in_fx,
                     fmt,
                     **x_style,
                     label=f"{self.x_var} nullcline")

        # Nullcline of the y variable
        utils.output('I am computing fy-nullcline ...')
        xy_values_in_fy, = self._get_fy_nullcline_points(coords=y_coord,
                                                         tol=tol_nullcline)
        x_values_in_fy = np.asarray(xy_values_in_fy[:, 0])
        y_values_in_fy = np.asarray(xy_values_in_fy[:, 1])

        if with_plot:
            # Same defensive copy as for x_style.
            y_style = (dict(color='lightcoral', alpha=.7)
                       if y_style is None else dict(y_style))
            fmt = y_style.pop('fmt', '.')
            plt.plot(x_values_in_fy,
                     y_values_in_fy,
                     fmt,
                     **y_style,
                     label=f"{self.y_var} nullcline")

            # axes decoration
            plt.xlabel(self.x_var)
            plt.ylabel(self.y_var)
            scale = (self.lim_scale - 1.) / 2
            plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
            plt.ylim(*utils.rescale(self.target_vars[self.y_var], scale=scale))
            plt.legend()
            if show:
                plt.show()

        if with_return:
            return {
                self.x_var: (x_values_in_fx, y_values_in_fx),
                self.y_var: (x_values_in_fy, y_values_in_fy)
            }