Example #1
0
  def plot_trajectory(self, initials, duration, plot_durations=None,
                      dt=None, show=False, with_plot=True, with_return=False):
    """Simulate trajectories and plot ``x_var`` against the target parameter(s).

    Parameters
    ----------
    initials
      Initial values, validated by ``utils.check_initials`` against
      ``self.target_var_names + self.target_par_names``.
    duration : int, float
      The total running duration.
    plot_durations : optional
      The ``(start, end)`` window(s) to plot, one per initial value; normalized
      by ``utils.check_plot_durations``.
    dt : float, optional
      The integration step. Defaults to ``bm.get_dt()``.
    show : bool
      Whether to call ``plt.show()`` after plotting.
    with_plot : bool
      Whether to draw the trajectories.
    with_return : bool
      Whether to return the monitored results.

    Returns
    -------
    The monitor results of ``utils.TrajectModel.run`` when ``with_return`` is
    True, otherwise None.

    Raises
    ------
    AssertionError
      If ``duration`` is not an int/float, or (when plotting) more than two
      target parameters are set.
    ValueError
      If plotting is requested while no target parameter is set.
    """
    utils.output('I am plotting the trajectory ...')

    # 1. check the initial values (variables followed by parameters)
    initials = utils.check_initials(initials, self.target_var_names + self.target_par_names)

    # 2. format the running duration
    assert isinstance(duration, (int, float))

    # 3. format the plot duration
    plot_durations = utils.check_plot_durations(plot_durations, duration, initials)

    # 4. run the network
    dt = bm.get_dt() if dt is None else dt
    traject_model = utils.TrajectModel(initial_vars=initials, integrals=self._std_integrators, dt=dt)
    mon_res = traject_model.run(duration=duration)

    if with_plot:
      par_names = self.target_par_names
      assert len(par_names) <= 2
      if len(par_names) == 0:
        # Without this check the first `par_names[0]` access fails with a bare IndexError.
        raise ValueError('At least one target parameter is required to plot the trajectory.')

      for i, initial in enumerate(zip(*list(initials.values()))):
        # legend: only the variable part of the initial value is shown
        legend = f'$traj_{i}$: ' + ', '.join(
          f'{key}={initial[j]}' for j, key in enumerate(self.target_var_names))

        # visualization
        start = int(plot_durations[i][0] / dt)
        end = int(plot_durations[i][1] / dt)
        x_data = mon_res[self.x_var][start: end, i]
        p1_data = mon_res[par_names[0]][start: end, i]
        if len(par_names) == 1:
          lines = plt.plot(x_data, p1_data, label=legend)
        else:
          # NOTE(review): `plt.plot` with three data arrays is parsed as two 2D
          # series ((x, p1) and p2 against its index), not as a 3D curve; a 3D
          # axes is presumably intended here — confirm before relying on it.
          p2_data = mon_res[par_names[1]][start: end, i]
          lines = plt.plot(x_data, p1_data, p2_data, label=legend)
        utils.add_arrow(lines[0])

      plt.legend()

      if show:
        plt.show()

    if with_return:
      return mon_res
    def plot_vector_field(self, show=False, with_plot=True, with_return=False):
        """Plot the vector field (``dx/dt``) of the 1D system.

        ``self.F_fx`` is evaluated on the resolution grid of the x variable and,
        optionally, drawn against x together with the horizontal line y = 0.

        Parameters
        ----------
        show : bool
            Whether to call ``plt.show()`` after drawing.
        with_plot : bool
            Whether to draw the curve.
        with_return : bool
            Whether to return the evaluated derivative values.

        Returns
        -------
        numpy.ndarray or None
            The derivative values on the grid when ``with_return`` is True.
        """
        utils.output('I am creating the vector field ...')

        grid = self.resolutions[self.x_var]
        derivative = np.asarray(self.F_fx(grid))

        if with_plot:
            curve_name = f"d{self.x_var}dt"
            plt.plot(np.asarray(grid), derivative,
                     color='lightcoral', alpha=.7, linewidth=4,
                     label=curve_name)
            plt.axhline(0)
            plt.xlabel(self.x_var)
            plt.ylabel(curve_name)
            margin = (self.lim_scale - 1.) / 2
            plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=margin))
            plt.legend()
            if show:
                plt.show()

        return derivative if with_return else None
    def plot_fixed_point(self, show=False, with_plot=True, with_return=False):
        """Find the fixed points of the 1D system, classify and plot them.

        Candidates are taken from the resolution grid of the x variable; each
        fixed point is classified by ``stability.stability_analysis`` applied to
        the derivative ``self.F_dfxdx`` at that point.

        Parameters
        ----------
        show : bool
            Whether to call ``plt.show()`` after drawing.
        with_plot : bool
            Whether to draw the fixed points on the line y = 0.
        with_return : bool
            Whether to return the fixed points.

        Returns
        -------
        The fixed points found by ``self._get_fixed_points`` when
        ``with_return`` is True, otherwise None.
        """
        utils.output('I am searching fixed points ...')

        fixed_points, _, _ = self._get_fixed_points(self.resolutions[self.x_var])

        # Group the fixed points by stability type.
        by_type = {kind: [] for kind in stability.get_1d_stability_types()}
        for idx, point in enumerate(fixed_points):
            kind = stability.stability_analysis(self.F_dfxdx(point))
            utils.output(
                f"Fixed point #{idx + 1} at {self.x_var}={point} is a {kind}.")
            by_type[kind].append(point)

        if with_plot:
            for kind, points in by_type.items():
                if not points:
                    continue
                plt.plot(points, [0] * len(points), '.', markersize=20,
                         **stability.plot_scheme[kind], label=kind)
            plt.legend()
            if show:
                plt.show()

        if with_return:
            return fixed_points
Example #4
0
  def plot_trajectory(self, initials, duration, plot_durations=None,
                      dt=None, show=False, with_plot=True, with_return=False):
    """Simulate trajectories and plot each variable against the first target parameter.

    Two figures are used, named after ``self.x_var`` and ``self.y_var``; for
    every initial condition, each receives the curve of its variable versus
    the first target parameter, with a direction arrow.

    Parameters
    ----------
    initials
      Initial values, validated by ``utils.check_initials`` against
      ``self.target_var_names + self.target_par_names``.
    duration : int, float
      The total running duration.
    plot_durations : optional
      The ``(start, end)`` window(s) to plot; normalized by
      ``utils.check_plot_durations``.
    dt : float, optional
      The integration step. Defaults to ``bm.get_dt()``.
    show : bool
      Whether to call ``plt.show()`` after plotting.
    with_plot : bool
      Whether to draw the trajectories.
    with_return : bool
      Whether to return the monitored results.

    Returns
    -------
    The monitor results when ``with_return`` is True, otherwise None.
    """
    utils.output('I am plotting the trajectory ...')

    initials = utils.check_initials(initials, self.target_var_names + self.target_par_names)
    assert isinstance(duration, (int, float))
    plot_durations = utils.check_plot_durations(plot_durations, duration, initials)

    if dt is None:
      dt = bm.get_dt()
    model = utils.TrajectModel(initial_vars=initials, integrals=self._std_integrators, dt=dt)
    mon_res = model.run(duration=duration)

    if with_plot:
      assert len(self.target_par_names) <= 1
      panels = (self.x_var, self.y_var)
      for traj_id, values in enumerate(zip(*list(initials.values()))):
        label = (f'$traj_{traj_id}$: ' + ''.join(
          f'{name}={values[k]}, ' for k, name in enumerate(self.target_var_names)))[:-2]
        begin = int(plot_durations[traj_id][0] / dt)
        stop = int(plot_durations[traj_id][1] / dt)

        for var in panels:
          plt.figure(var)
          lines = plt.plot(mon_res[self.target_par_names[0]][begin: stop, traj_id],
                           mon_res[var][begin: stop, traj_id],
                           label=label)
          utils.add_arrow(lines[0])

      for var in panels:
        plt.figure(var)
        plt.legend()

      if show:
        plt.show()

    if with_return:
      return mon_res
Example #5
0
  def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
                       tol_aux=1e-8, loss_screen=None):
    """Bifurcation analysis of a 1D system over one or two target parameters.

    Fixed points are searched on the flattened mesh of the x-variable resolution
    and every target-parameter resolution. Each point is classified from the
    derivative returned by ``self.F_vmap_dfxdx`` at that point.

    Parameters
    ----------
    with_plot : bool
      Whether to draw the diagram: parameter vs. x for one target parameter,
      a 3D scatter (p0, p1, x) for two.
    show : bool
      Whether to call ``plt.show()`` after drawing.
    with_return : bool
      Whether to return the analysis results.
    tol_aux : float
      Forwarded to ``self._get_fixed_points``.
    loss_screen : optional
      Forwarded to ``self._get_fixed_points``.

    Returns
    -------
    tuple or None
      ``(fixed_points, pars, dfxdx)`` when ``with_return`` is True; ``pars`` is a
      tuple of NumPy arrays (one per target parameter) and ``dfxdx`` a NumPy array.

    Raises
    ------
    errors.BrainPyError
      If plotting is requested for more than two target parameters.
    """
    utils.output('I am making bifurcation analysis ...')

    # Flattened (x, p0[, p1]) candidate mesh.
    x_grid = self.resolutions[self.x_var]
    par_grids = [self.resolutions[name] for name in self.target_par_names]
    flat_mesh = [jnp.moveaxis(m.value, 0, 1).flatten()
                 for m in bm.meshgrid(x_grid, *par_grids)]

    fixed_points, _, pars = self._get_fixed_points(flat_mesh[0], *flat_mesh[1:],
                                                   tol_aux=tol_aux,
                                                   loss_screen=loss_screen,
                                                   num_seg=len(x_grid))
    dfxdx = np.asarray(self.F_vmap_dfxdx(jnp.asarray(fixed_points), *pars))
    pars = tuple(np.asarray(p) for p in pars)

    if with_plot:
      codim = len(self.target_pars)
      if codim == 1:
        groups = {kind: {'p': [], 'x': []} for kind in stability.get_1d_stability_types()}
        for p, x, slope in zip(pars[0], fixed_points, dfxdx):
          group = groups[stability.stability_analysis(slope)]
          group['p'].append(p)
          group['x'].append(x)

        plt.figure(self.x_var)
        for kind, group in groups.items():
          if group['x']:
            plt.plot(group['p'], group['x'], '.', **stability.plot_scheme[kind], label=kind)
        plt.xlabel(self.target_par_names[0])
        plt.ylabel(self.x_var)

        margin = (self.lim_scale - 1) / 2
        plt.xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=margin))
        plt.ylim(*utils.rescale(self.target_vars[self.x_var], scale=margin))

        plt.legend()
        if show:
          plt.show()

      elif codim == 2:
        groups = {kind: {'p0': [], 'p1': [], 'x': []}
                  for kind in stability.get_1d_stability_types()}
        for p0, p1, x, slope in zip(pars[0], pars[1], fixed_points, dfxdx):
          group = groups[stability.stability_analysis(slope)]
          group['p0'].append(p0)
          group['p1'].append(p1)
          group['x'].append(x)

        ax = plt.figure(self.x_var).add_subplot(projection='3d')
        for kind, group in groups.items():
          if group['x']:
            ax.scatter(group['p0'], group['p1'], group['x'],
                       **stability.plot_scheme[kind], label=kind)

        ax.set_xlabel(self.target_par_names[0])
        ax.set_ylabel(self.target_par_names[1])
        ax.set_zlabel(self.x_var)

        margin = (self.lim_scale - 1) / 2
        ax.set_xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=margin))
        ax.set_ylim(*utils.rescale(self.target_pars[self.target_par_names[1]], scale=margin))
        ax.set_zlim(*utils.rescale(self.target_vars[self.x_var], scale=margin))

        ax.grid(True)
        ax.legend()
        if show:
          plt.show()

      else:
        raise errors.BrainPyError(f'Cannot visualize co-dimension {len(self.target_pars)} '
                                  f'bifurcation.')

    if with_return:
      return fixed_points, pars, dfxdx
Example #6
0
  def plot_limit_cycle_by_sim(self, duration=100, with_plot=True, with_return=False,
                              plot_style=None, tol=0.001, show=False, dt=None, offset=1.):
    """Detect limit cycles by simulating from points near the stored fixed points.

    Uses ``self._fixed_points`` (set by ``plot_bifurcation``). Each fixed point,
    shifted by ``offset`` in both variables, is simulated for ``duration`` under
    its own parameter values. When ``utils.find_indexes_of_limit_cycle_max``
    detects a cycle in the x trace, the max and min of x and y over that cycle
    window are recorded for the parameter set.

    Parameters
    ----------
    duration : int, float
      The running duration of each simulation.
    with_plot : bool
      Whether to draw the limit-cycle max/min against the parameter(s).
    with_return : bool
      Whether to return the detected values.
    plot_style : dict, optional
      Extra keyword arguments for the plotting calls. The key ``'fmt'``
      (default ``'.'``) is used as the format string. The caller's dict is
      not modified.
    tol : float
      Tolerance forwarded to ``utils.find_indexes_of_limit_cycle_max``.
    show : bool
      Whether to call ``plt.show()`` after plotting.
    dt : float, optional
      The integration step. Defaults to ``bm.get_dt()``.
    offset : float
      Shift added to both coordinates of each fixed point to build the initial values.

    Returns
    -------
    tuple or None
      ``(vs_limit_cycle, ps_limit_cycle)`` when ``with_return`` is True and fixed
      points are available: one ``{'min', 'max'}`` dict of arrays per target
      variable, and one array per target parameter.

    Raises
    ------
    errors.AnalyzerError
      If plotting is requested for a number of target parameters other than 1 or 2.
    """
    utils.output('I am plotting the limit cycle ...')
    if self._fixed_points is None:
      utils.output('No fixed points found, you may call "plot_bifurcation(with_plot=True)" first.')
      return

    final_fps, final_pars = self._fixed_points
    dt = bm.get_dt() if dt is None else dt
    traject_model = utils.TrajectModel(
      initial_vars={self.x_var: final_fps[:, 0] + offset, self.y_var: final_fps[:, 1] + offset},
      integrals={self.x_var: self.F_int_x, self.y_var: self.F_int_y},
      pars={p: v for p, v in zip(self.target_par_names, final_pars.T)},
      dt=dt
    )
    mon_res = traject_model.run(duration=duration)

    # find limit cycles: detection uses the x trace only; the same cycle window
    # is then applied to y.
    cycle_vars = (self.x_var, self.y_var)
    vs_limit_cycle = tuple({'min': [], 'max': []} for _ in self.target_var_names)
    ps_limit_cycle = tuple([] for _ in self.target_par_names)
    for i in range(mon_res[self.x_var].shape[1]):
      max_index = utils.find_indexes_of_limit_cycle_max(mon_res[self.x_var][:, i], tol=tol)
      if max_index[0] == -1:
        continue  # no cycle detected for this parameter set
      for k, var in enumerate(cycle_vars):
        trace = mon_res[var]
        vs_limit_cycle[k]['max'].append(trace[max_index[1], i])
        vs_limit_cycle[k]['min'].append(trace[max_index[0]: max_index[1], i].min())
      for j in range(len(self.target_par_names)):
        ps_limit_cycle[j].append(final_pars[i, j])
    vs_limit_cycle = tuple({k: np.asarray(v) for k, v in lm.items()} for lm in vs_limit_cycle)
    ps_limit_cycle = tuple(np.array(p) for p in ps_limit_cycle)

    # visualization
    if with_plot:
      # copy so that popping 'fmt' does not mutate the caller's dict
      plot_style = dict() if plot_style is None else dict(plot_style)
      fmt = plot_style.pop('fmt', '.')
      num_pars = len(self.target_par_names)

      if num_pars == 2:
        for i, var in enumerate(self.target_var_names):
          fig = plt.figure(var)
          # Reuse the 3D axes created by a co-dimension 2 `plot_bifurcation`
          # if present; a 2D `plt.plot` would misread the third data array as
          # a separate series instead of the z coordinate.
          ax = next((a for a in fig.axes if a.name == '3d'), None)
          if ax is None:
            ax = fig.add_subplot(projection='3d')
          ax.plot(ps_limit_cycle[0], ps_limit_cycle[1], vs_limit_cycle[i]['max'], fmt,
                  **plot_style, label='limit cycle (max)')
          ax.plot(ps_limit_cycle[0], ps_limit_cycle[1], vs_limit_cycle[i]['min'], fmt,
                  **plot_style, label='limit cycle (min)')
          ax.legend()

      elif num_pars == 1:
        for i, var in enumerate(self.target_var_names):
          plt.figure(var)
          plt.plot(ps_limit_cycle[0], vs_limit_cycle[i]['max'], fmt,
                   **plot_style, label='limit cycle (max)')
          plt.plot(ps_limit_cycle[0], vs_limit_cycle[i]['min'], fmt,
                   **plot_style, label='limit cycle (min)')
          plt.legend()

      else:
        raise errors.AnalyzerError(f'Cannot visualize limit cycles for {num_pars} '
                                   f'target parameters, only 1 or 2 are supported.')

      if show:
        plt.show()

    if with_return:
      return vs_limit_cycle, ps_limit_cycle
Example #7
0
  def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
                       tol_aux=1e-8, tol_unique=1e-2, tol_opt_candidate=None,
                       num_par_segments=1, num_fp_segment=1, nullcline_aux_filter=1.,
                       select_candidates='aux_rank', num_rank=100):
    r"""Make the bifurcation analysis.

    Parameters
    ----------
    with_plot: bool
      Whether plot the bifurcation figure.
    show: bool
      Whether show the figure.
    with_return: bool
      Whether return the computed bifurcation results.
    tol_aux: float
      The loss tolerance of auxiliary function :math:`f_{aux}` to confirm the fixed
      point. Default is 1e-8. Once :math:`f_{aux}(x_1) < \mathrm{tol\_aux}`,
      :math:`x_1` will be a fixed point.
    tol_unique: float
      The tolerance of distance between candidate fixed points to confirm they are
      the same. Default is 1e-2. If :math:`|x_1 - x_2| > \mathrm{tol\_unique}`,
      then :math:`x_1` and :math:`x_2` are unique fixed points. Otherwise,
      :math:`x_1` and :math:`x_2` will be treated as a same fixed point.
    tol_opt_candidate: float, optional
      The tolerance of auxiliary function :math:`f_{aux}` to select candidate
      initial points for fixed point optimization.
    num_par_segments: int, sequence of int
      How to segment parameters.
    num_fp_segment: int
      How to segment fixed points.
    nullcline_aux_filter: float
      Forwarded as ``fp_aux_filter`` to the nullcline point search when
      ``select_candidates`` is ``'fx-nullcline'``, ``'fy-nullcline'`` or
      ``'nullclines'``. Must be positive when ``select_candidates='aux_rank'``.
    select_candidates: str
      The method to select candidate fixed points (ignored when the system can
      be converted to one equation). It can be:

      - ``fx-nullcline``: use the points of fx-nullcline.
      - ``fy-nullcline``: use the points of fy-nullcline.
      - ``nullclines``: use the points in both of fx-nullcline and fy-nullcline.
      - ``aux_rank``: use the minimal value of points for the auxiliary function.
    num_rank: int
      The number of candidates (ranked by the auxiliary function) used to
      optimize the fixed points when ``select_candidates='aux_rank'``.

    Returns
    -------
    results : tuple
      Return a tuple of analyzed results:

      - fixed points: a 2D matrix with the shape of (num_point, num_var)
      - parameters: a 2D matrix with the shape of (num_point, num_par)
      - jacobians: a 3D tensors with the shape of (num_point, 2, 2)

      When no fixed point is found, empty arrays with these layouts are returned.

    Raises
    ------
    ValueError
      If ``select_candidates`` is unknown, or plotting is requested for more
      than two target parameters.
    """
    utils.output('I am making bifurcation analysis ...')

    # 1. collect candidate (state, parameter) pairs
    if self._can_convert_to_one_eq():
      if self.convert_type() == C.x_by_y:
        X = self.resolutions[self.y_var].value
      else:
        X = self.resolutions[self.x_var].value
      pars = tuple(self.resolutions[p].value for p in self.target_par_names)
      mesh_values = jnp.meshgrid(*((X,) + pars))
      mesh_values = tuple(jnp.moveaxis(v, 0, 1).flatten() for v in mesh_values)
      candidates = mesh_values[0]
      parameters = mesh_values[1:]

    elif select_candidates == 'fx-nullcline':
      fx_nullclines = self._get_fx_nullcline_points(num_segments=num_par_segments,
                                                    fp_aux_filter=nullcline_aux_filter)
      candidates = fx_nullclines[0]
      parameters = fx_nullclines[1:]
    elif select_candidates == 'fy-nullcline':
      fy_nullclines = self._get_fy_nullcline_points(num_segments=num_par_segments,
                                                    fp_aux_filter=nullcline_aux_filter)
      candidates = fy_nullclines[0]
      parameters = fy_nullclines[1:]
    elif select_candidates == 'nullclines':
      fx_nullclines = self._get_fx_nullcline_points(num_segments=num_par_segments,
                                                    fp_aux_filter=nullcline_aux_filter)
      fy_nullclines = self._get_fy_nullcline_points(num_segments=num_par_segments,
                                                    fp_aux_filter=nullcline_aux_filter)
      candidates = jnp.vstack([fx_nullclines[0], fy_nullclines[0]])
      parameters = [jnp.concatenate([fx_nullclines[i], fy_nullclines[i]])
                    for i in range(1, len(fy_nullclines))]
    elif select_candidates == 'aux_rank':
      # NOTE(review): `nullcline_aux_filter` is not used by this branch; the
      # assertion is kept for backward compatibility — confirm intent.
      assert nullcline_aux_filter > 0.
      candidates, parameters = self._get_fp_candidates_by_aux_rank(num_segments=num_par_segments,
                                                                   num_rank=num_rank)
    else:
      raise ValueError(f'Unknown "select_candidates": {select_candidates!r}. Supported are '
                       f'"fx-nullcline", "fy-nullcline", "nullclines" and "aux_rank".')

    # 2. optimize the candidates into fixed points
    candidates, _, parameters = self._get_fixed_points(candidates,
                                                       *parameters,
                                                       tol_aux=tol_aux,
                                                       tol_unique=tol_unique,
                                                       tol_opt_candidate=tol_opt_candidate,
                                                       num_segment=num_fp_segment)
    if len(candidates) == 0:
      # Nothing to deduplicate or plot; `np.vstack([])` below would raise.
      utils.output(f'{C.prefix}Found no fixed points.')
      self._fixed_points = None
      if with_return:
        num_var = len(self.target_var_names)
        return (np.empty((0, num_var)),
                np.empty((0, len(self.target_par_names))),
                np.empty((0, num_var, num_var)))
      return

    candidates = np.asarray(candidates)
    parameters = np.stack(tuple(np.asarray(p) for p in parameters)).T  # (num_point, num_par)

    # 3. remove duplicated fixed points within each parameter setting
    utils.output('I am trying to filter out duplicate fixed points ...')
    final_fps = []
    final_pars = []
    for par in np.unique(parameters, axis=0):
      ids = np.where(np.all(parameters == par, axis=1))[0]
      fps, ids2 = utils.keep_unique(candidates[ids], tolerance=tol_unique)
      final_fps.append(fps)
      final_pars.append(parameters[ids[ids2]])
    final_fps = np.vstack(final_fps)  # with the shape of (num_point, num_var)
    final_pars = np.vstack(final_pars)  # with the shape of (num_point, num_par)
    jacobians = np.asarray(self.F_vmap_jacobian(jnp.asarray(final_fps), *final_pars.T))
    utils.output(f'{C.prefix}Found {len(final_fps)} fixed points.')

    # remember the fixed points for later limit cycle plotting
    self._fixed_points = (final_fps, final_pars)

    if with_plot:
      # bifurcation analysis of co-dimension 1
      if len(self.target_pars) == 1:
        container = {c: {'p': [], self.x_var: [], self.y_var: []}
                     for c in stability.get_2d_stability_types()}

        # fixed point
        for p, xy, J in zip(final_pars, final_fps, jacobians):
          fp_type = stability.stability_analysis(J)
          container[fp_type]['p'].append(p[0])
          container[fp_type][self.x_var].append(xy[0])
          container[fp_type][self.y_var].append(xy[1])

        # visualization
        scale = (self.lim_scale - 1) / 2
        for var in self.target_var_names:
          plt.figure(var)
          for fp_type, points in container.items():
            if len(points['p']):
              plot_style = stability.plot_scheme[fp_type]
              plt.plot(points['p'], points[var], '.', **plot_style, label=fp_type)
          plt.xlabel(self.target_par_names[0])
          plt.ylabel(var)
          plt.xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=scale))
          plt.ylim(*utils.rescale(self.target_vars[var], scale=scale))
          plt.legend()
        if show:
          plt.show()

      # bifurcation analysis of co-dimension 2
      elif len(self.target_pars) == 2:
        container = {c: {'p0': [], 'p1': [], self.x_var: [], self.y_var: []}
                     for c in stability.get_2d_stability_types()}

        # fixed point
        for p, xy, J in zip(final_pars, final_fps, jacobians):
          fp_type = stability.stability_analysis(J)
          container[fp_type]['p0'].append(p[0])
          container[fp_type]['p1'].append(p[1])
          container[fp_type][self.x_var].append(xy[0])
          container[fp_type][self.y_var].append(xy[1])

        # visualization
        scale = (self.lim_scale - 1) / 2
        for var in self.target_var_names:
          fig = plt.figure(var)
          ax = fig.add_subplot(projection='3d')
          for fp_type, points in container.items():
            if len(points['p0']):
              plot_style = stability.plot_scheme[fp_type]
              ax.scatter(points['p0'], points['p1'], points[var], **plot_style, label=fp_type)

          ax.set_xlabel(self.target_par_names[0])
          ax.set_ylabel(self.target_par_names[1])
          ax.set_zlabel(var)
          ax.set_xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=scale))
          ax.set_ylim(*utils.rescale(self.target_pars[self.target_par_names[1]], scale=scale))
          ax.set_zlim(*utils.rescale(self.target_vars[var], scale=scale))
          ax.grid(True)
          ax.legend()
        if show:
          plt.show()

      else:
        raise ValueError('Unknown length of parameters.')

    if with_return:
      return final_fps, final_pars, jacobians
    def plot_limit_cycle_by_sim(self,
                                initials,
                                duration,
                                tol=0.01,
                                show=False,
                                dt=None):
        """Simulate from the given initial values and draw any detected limit cycle.

        Parameters
        ----------
        initials : list, tuple
            The initial value setting of the targets.

            - A tuple/list of floats gives one value per dynamical variable
              (for example, ``(a, b)``).
            - A tuple/list of tuples gives several initial values (for example,
              ``[(a1, b1), (a2, b2)]``).
        duration : int, float
            The running duration; must be an int or a float.
        tol : float
            Tolerance forwarded to ``utils.find_indexes_of_limit_cycle_max``.
        show : bool
            Whether to call ``plt.show()`` at the end.
        dt : float, optional
            The integration step. Defaults to ``math.get_dt()``.
        """
        utils.output('I am plotting the limit cycle ...')

        initials = utils.check_initials(initials, self.target_var_names)
        assert isinstance(duration, (int, float))

        if dt is None:
            dt = math.get_dt()
        dynamics = {self.x_var: self.F_int_x, self.y_var: self.F_int_y}
        model = utils.TrajectModel(initial_vars=initials, integrals=dynamics, dt=dt)
        mon_res = model.run(duration=duration)

        # One simulated column per initial value; a cycle is detected on the x trace.
        for idx, init_value in enumerate(zip(*list(initials.values()))):
            xs = mon_res[self.x_var][:, idx]
            ys = mon_res[self.y_var][:, idx]
            bounds = utils.find_indexes_of_limit_cycle_max(xs, tol=tol)
            if bounds[0] == -1:
                utils.output(
                    f'No limit cycle found for initial value {init_value}')
                continue
            window = slice(bounds[0], bounds[1])
            lines = plt.plot(xs[window], ys[window], label='limit cycle')
            utils.add_arrow(lines[0])

        plt.xlabel(self.x_var)
        plt.ylabel(self.y_var)
        margin = (self.lim_scale - 1.) / 2
        plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=margin))
        plt.ylim(*utils.rescale(self.target_vars[self.y_var], scale=margin))
        plt.legend()

        if show:
            plt.show()
    def plot_trajectory(self,
                        initials,
                        duration,
                        plot_durations=None,
                        axes='v-v',
                        dt=None,
                        show=False,
                        with_plot=True,
                        with_return=False,
                        **kwargs):
        """Simulate the 2D system from the given initial values and plot the trajectories.

        Parameters
        ----------
        initials : list, tuple, dict
            The initial value setting of the targets: a tuple/list of floats
            (for example ``(a, b)``) or a tuple/list of tuples for several
            initial values (for example ``[(a1, b1), (a2, b2)]``).
        duration : int, float
            The running duration; must be an int or a float.
        plot_durations : tuple, list, optional
            The window to plot, ``(start, end)``, or one window per initial value
            ``[(start1, end1), (start2, end2)]``.
        axes : str
            ``'v-v'`` plots ``x_var`` against ``y_var`` (phase plane);
            ``'t-v'`` plots both variables against time.
        dt : float, optional
            The integration step. Defaults to ``math.get_dt()``.
        show : bool
            Whether to call ``plt.show()`` after plotting.
        with_plot : bool
            Whether to draw the trajectories.
        with_return : bool
            Whether to return the monitored results.
        **kwargs
            Extra keyword arguments forwarded to every ``plt.plot`` call.

        Returns
        -------
        The monitor results when ``with_return`` is True, otherwise None.

        Raises
        ------
        errors.AnalyzerError
            If ``axes`` is neither ``'v-v'`` nor ``'t-v'``.
        """
        utils.output('I am plotting the trajectory ...')

        if axes not in ('v-v', 't-v'):
            raise errors.AnalyzerError(
                f'Unknown axes "{axes}", only support "v-v" and "t-v".')

        initials = utils.check_initials(initials, self.target_var_names)
        assert isinstance(duration, (int, float))
        plot_durations = utils.check_plot_durations(plot_durations, duration,
                                                    initials)

        if dt is None:
            dt = math.get_dt()
        dynamics = {self.x_var: self.F_int_x, self.y_var: self.F_int_y}
        model = utils.TrajectModel(initial_vars=initials, integrals=dynamics, dt=dt)
        mon_res = model.run(duration=duration)

        if with_plot:
            phase_plane = axes == 'v-v'
            for idx, values in enumerate(zip(*list(initials.values()))):
                label = (f'$traj_{idx}$: ' + ''.join(
                    f'{name}={round(float(values[k]), 4)}, '
                    for k, name in enumerate(self.target_var_names)))[:-2]
                lo = int(plot_durations[idx][0] / dt)
                hi = int(plot_durations[idx][1] / dt)
                if phase_plane:
                    lines = plt.plot(mon_res[self.x_var][lo:hi, idx],
                                     mon_res[self.y_var][lo:hi, idx],
                                     label=label,
                                     **kwargs)
                    utils.add_arrow(lines[0])
                else:
                    for var in (self.x_var, self.y_var):
                        plt.plot(mon_res.ts[lo:hi],
                                 mon_res[var][lo:hi, idx],
                                 label=label + f', {var}',
                                 **kwargs)

            if phase_plane:
                plt.xlabel(self.x_var)
                plt.ylabel(self.y_var)
                margin = (self.lim_scale - 1.) / 2
                plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=margin))
                plt.ylim(*utils.rescale(self.target_vars[self.y_var], scale=margin))
                plt.legend()
            else:
                plt.legend(title='Initial values')

            if show:
                plt.show()

        if with_return:
            return mon_res
Example #10
0
    def plot_fixed_point(
        self,
        with_plot=True,
        with_return=False,
        show=False,
        tol_unique=1e-2,
        tol_aux=1e-8,
        tol_opt_screen=None,
        select_candidates='fx-nullcline',
        num_rank=100,
    ):
        """Search the fixed points of the 2D system, analyze their stability and plot them.

        Parameters
        ----------
        with_plot : bool
            Whether to draw the fixed points (one marker style per stability type).
        with_return : bool
            Whether to return the fixed points.
        show : bool
            Whether to call ``plt.show()`` after drawing.
        tol_unique : float
            Distance tolerance under which two fixed points are treated as the same.
        tol_aux : float
            Forwarded to ``self._get_fixed_points``.
        tol_opt_screen : float, optional
            Forwarded to ``self._get_fixed_points`` as ``tol_opt_candidate``.
        select_candidates : str
            How to choose the initial candidates when the system cannot be
            converted to one equation:

            - ``'fx-nullcline'``: fx-nullcline points cached in
              ``self.analyzed_results`` (call ``plot_nullcline`` first).
            - ``'fy-nullcline'``: cached fy-nullcline points.
            - ``'nullclines'``: cached points of both nullclines.
            - ``'aux_rank'``: candidates ranked by the auxiliary function.
        num_rank : int
            Number of candidates used when ``select_candidates='aux_rank'``.

        Returns
        -------
        numpy.ndarray or None
            The unique fixed points (column 0 is x, column 1 is y) when
            ``with_return`` is True. None when no candidate is available.

        Raises
        ------
        errors.AnalyzerError
            If a nullcline method is selected but no nullcline points are cached.
        ValueError
            If ``select_candidates`` is unknown.
        """
        utils.output('I am searching fixed points ...')

        if self._can_convert_to_one_eq():
            if self.convert_type() == C.x_by_y:
                candidates = self.resolutions[self.y_var].value
            else:
                candidates = self.resolutions[self.x_var].value
        elif select_candidates == 'aux_rank':
            candidates, _ = self._get_fp_candidates_by_aux_rank(num_rank=num_rank)
        elif select_candidates in ('fx-nullcline', 'fy-nullcline', 'nullclines'):
            # Key prefixes of the cached nullcline points to use;
            # 'nullclines' takes the points of BOTH nullclines.
            if select_candidates == 'fx-nullcline':
                prefixes = (C.fx_nullcline_points,)
            elif select_candidates == 'fy-nullcline':
                prefixes = (C.fy_nullcline_points,)
            else:
                prefixes = (C.fx_nullcline_points, C.fy_nullcline_points)
            candidates = [
                self.analyzed_results[key][0]
                for key in self.analyzed_results.keys()
                if key.startswith(prefixes)
            ]
            if len(candidates) == 0:
                raise errors.AnalyzerError(
                    f'No nullcline points are found, please call '
                    f'".{self.plot_nullcline.__name__}()" first.')
            candidates = jnp.vstack(candidates)
        else:
            raise ValueError(
                f'Unknown "select_candidates": {select_candidates!r}. Supported are '
                f'"fx-nullcline", "fy-nullcline", "nullclines" and "aux_rank".')

        # get fixed points
        if len(candidates) == 0:
            utils.output(f'{C.prefix}Found no fixed points.')
            return
        fixed_points, _, _ = self._get_fixed_points(
            jnp.asarray(candidates),
            tol_aux=tol_aux,
            tol_unique=tol_unique,
            tol_opt_candidate=tol_opt_screen)
        utils.output('I am trying to filter out duplicate fixed points ...')
        fixed_points, _ = utils.keep_unique(np.asarray(fixed_points),
                                            tolerance=tol_unique)
        utils.output(f'{C.prefix}Found {len(fixed_points)} fixed points.')

        # stability analysis
        # ------------------
        container = {a: {'x': [], 'y': []} for a in stability.get_2d_stability_types()}
        for i, point in enumerate(fixed_points):
            x, y = point[0], point[1]
            fp_type = stability.stability_analysis(self.F_jacobian(x, y))
            utils.output(
                f"{C.prefix}#{i + 1} {self.x_var}={x}, {self.y_var}={y} is a {fp_type}."
            )
            container[fp_type]['x'].append(x)
            container[fp_type]['y'].append(y)

        # visualization
        # -------------
        if with_plot:
            for fp_type, points in container.items():
                if len(points['x']):
                    plot_style = stability.plot_scheme[fp_type]
                    plt.plot(points['x'],
                             points['y'],
                             '.',
                             markersize=20,
                             **plot_style,
                             label=fp_type)
            plt.legend()
            if show:
                plt.show()

        if with_return:
            return fixed_points
Example #11
0
    def plot_nullcline(self,
                       with_plot=True,
                       with_return=False,
                       y_style=None,
                       x_style=None,
                       show=False,
                       coords=None,
                       tol_nullcline=1e-7):
        """Plot the nullclines of the two target variables.

        Parameters
        ----------
        with_plot : bool
            Whether to draw the nullclines with matplotlib.
        with_return : bool
            Whether to return the computed nullcline points.
        y_style : dict, optional
            Keyword arguments forwarded to ``plt.plot`` for the y-nullcline.
            An optional ``'fmt'`` key gives the format string (default ``'.'``).
            The dict passed in is not modified.
        x_style : dict, optional
            Same as ``y_style`` but for the x-nullcline.
        show : bool
            Whether to call ``plt.show()`` at the end (only when plotting).
        coords : dict, optional
            Mapping from variable name to the value forwarded as ``coords=``
            to the corresponding nullcline solver (its exact format is defined
            by ``_get_fx_nullcline_points`` / ``_get_fy_nullcline_points``).
        tol_nullcline : float
            Tolerance forwarded to the nullcline solvers as ``tol=``.

        Returns
        -------
        dict, optional
            When ``with_return`` is True:
            ``{x_var: (xs, ys), y_var: (xs, ys)}`` with the point coordinates
            of the x- and y-nullclines respectively.
        """
        utils.output('I am computing fx-nullcline ...')

        if coords is None:
            coords = dict()
        x_coord = coords.get(self.x_var, None)
        y_coord = coords.get(self.y_var, None)

        # Nullcline of the x variable
        xy_values_in_fx, = self._get_fx_nullcline_points(coords=x_coord,
                                                         tol=tol_nullcline)
        x_values_in_fx = np.asarray(xy_values_in_fx[:, 0])
        y_values_in_fx = np.asarray(xy_values_in_fx[:, 1])

        if with_plot:
            # Work on a copy: popping 'fmt' from the caller's dict would make a
            # reused style dict lose its format string after the first call.
            if x_style is None:
                x_style = dict(color='cornflowerblue', alpha=.7)
            else:
                x_style = dict(x_style)
            fmt = x_style.pop('fmt', '.')
            plt.plot(x_values_in_fx,
                     y_values_in_fx,
                     fmt,
                     **x_style,
                     label=f"{self.x_var} nullcline")

        # Nullcline of the y variable
        utils.output('I am computing fy-nullcline ...')
        xy_values_in_fy, = self._get_fy_nullcline_points(coords=y_coord,
                                                         tol=tol_nullcline)
        x_values_in_fy = np.asarray(xy_values_in_fy[:, 0])
        y_values_in_fy = np.asarray(xy_values_in_fy[:, 1])

        if with_plot:
            # Same copy-before-pop protection as for the x style.
            if y_style is None:
                y_style = dict(color='lightcoral', alpha=.7)
            else:
                y_style = dict(y_style)
            fmt = y_style.pop('fmt', '.')
            plt.plot(x_values_in_fy,
                     y_values_in_fy,
                     fmt,
                     **y_style,
                     label=f"{self.y_var} nullcline")

            # Axis labels and limits widened by `lim_scale` around the target ranges.
            plt.xlabel(self.x_var)
            plt.ylabel(self.y_var)
            scale = (self.lim_scale - 1.) / 2
            plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
            plt.ylim(*utils.rescale(self.target_vars[self.y_var], scale=scale))
            plt.legend()
            if show:
                plt.show()

        if with_return:
            return {
                self.x_var: (x_values_in_fx, y_values_in_fx),
                self.y_var: (x_values_in_fy, y_values_in_fy)
            }
Example #12
0
    def plot_vector_field(self,
                          with_plot=True,
                          with_return=False,
                          plot_method='streamplot',
                          plot_style=None,
                          show=False):
        """Plot the vector field.

        Parameters
        ----------
        with_plot: bool
            Whether to draw the vector field.
        with_return : bool
            Whether to return ``(dx, dy)``.
        show : bool
            Whether to call ``plt.show()`` after plotting.
        plot_method : str
            The method to plot the vector field. It can be "streamplot" or "quiver".
        plot_style : dict, optional
            The style for vector field plotting. The dict passed in is not modified.

            - For ``plot_method="streamplot"``, it can set the keywords like "density",
              "linewidth", "color", "arrowsize". More settings please check
              https://matplotlib.org/api/_as_gen/matplotlib.pyplot.streamplot.html.
              If "linewidth" is not given, widths are scaled by the local speed.
            - For ``plot_method="quiver"``, it can set the keywords like "color",
              "units", "angles", "scale". More settings please check
              https://matplotlib.org/api/_as_gen/matplotlib.pyplot.quiver.html.

        Returns
        -------
        tuple of np.ndarray, optional
            ``(dx, dy)`` evaluated on the ``(x, y)`` resolution grid when
            ``with_return`` is True. With ``plot_method="quiver"`` and
            ``with_plot=True`` these are the unit-normalized directions used
            for drawing (kept for backward compatibility).

        Raises
        ------
        errors.AnalyzerError
            If ``with_plot`` is True and ``plot_method`` is unknown.
        """
        utils.output('I am creating the vector field ...')

        # Evaluate both derivative functions on the (x, y) grid.
        xs = self.resolutions[self.x_var]
        ys = self.resolutions[self.y_var]
        X, Y = bm.meshgrid(xs, ys)
        dx = self.F_fx(X, Y)
        dy = self.F_fy(X, Y)
        X, Y = np.asarray(X), np.asarray(Y)
        dx, dy = np.asarray(dx), np.asarray(dy)

        if with_plot:  # plot vector fields
            has_nan = np.isnan(dx).any() or np.isnan(dy).any()
            if plot_method == 'quiver':
                plot_style = dict(units='xy') if plot_style is None else dict(plot_style)
                if not has_nan:
                    # Normalize to unit length so arrows show direction only.
                    # Zero-speed grid points give 0/0 = nan (quiver skips them);
                    # silence the RuntimeWarning this would otherwise emit.
                    speed = np.sqrt(dx ** 2 + dy ** 2)
                    with np.errstate(divide='ignore', invalid='ignore'):
                        dx = dx / speed
                        dy = dy / speed
                plt.quiver(X, Y, dx, dy, **plot_style)
            elif plot_method == 'streamplot':
                if plot_style is None:
                    plot_style = dict(arrowsize=1.2,
                                      density=1,
                                      color='thistle')
                else:
                    plot_style = dict(plot_style)
                # Pop (not get): leaving 'linewidth' in plot_style would pass it
                # twice to streamplot and raise a TypeError.
                linewidth = plot_style.pop('linewidth', None)
                if linewidth is None and not has_nan:
                    min_width, max_width = 0.5, 5.5
                    speed = np.nan_to_num(np.sqrt(dx ** 2 + dy ** 2))
                    max_speed = speed.max()
                    if max_speed > 0:
                        linewidth = min_width + max_width * (speed / max_speed)
                    else:
                        # All-zero field: avoid 0/0 and use a uniform width.
                        linewidth = min_width
                plt.streamplot(X, Y, dx, dy, linewidth=linewidth, **plot_style)
            else:
                raise errors.AnalyzerError(
                    f'Unknown plot_method "{plot_method}", '
                    f'only supports "quiver" and "streamplot".')

            plt.xlabel(self.x_var)
            plt.ylabel(self.y_var)
            if show:
                plt.show()

        if with_return:  # return vector fields
            return dx, dy