def plot_trajectory(self, initials, duration, plot_durations=None, dt=None,
                    show=False, with_plot=True, with_return=False):
    """Simulate trajectories and draw the variable against the parameter(s).

    Parameters
    ----------
    initials
        Initial values. They are normalised by ``utils.check_initials``
        against the target variable names followed by the target
        parameter names.
    duration : int, float
        Simulation length. Any other type fails the assertion.
    plot_durations : optional
        Plotting window(s), normalised by ``utils.check_plot_durations``.
        Window ``i`` is read as ``(start, end)`` and divided by ``dt`` to
        get the slice indices of trajectory ``i``.
    dt : float, optional
        Integration step; ``bm.get_dt()`` is used when it is ``None``.
    show : bool
        Call ``plt.show()`` after plotting.
    with_plot : bool
        Whether to draw anything at all.
    with_return : bool
        Whether to return the simulation result.

    Returns
    -------
    The object returned by ``TrajectModel.run`` when ``with_return`` is
    true, otherwise ``None``.
    """
    utils.output('I am plotting the trajectory ...')

    # Normalise the inputs.
    initials = utils.check_initials(
        initials, self.target_var_names + self.target_par_names)
    assert isinstance(duration, (int, float))
    plot_durations = utils.check_plot_durations(plot_durations, duration, initials)

    # Integrate the system.
    if dt is None:
        dt = bm.get_dt()
    runner = utils.TrajectModel(initial_vars=initials,
                                integrals=self._std_integrators,
                                dt=dt)
    mon_res = runner.run(duration=duration)

    if with_plot:
        assert len(self.target_par_names) <= 2
        for idx, init_values in enumerate(zip(*initials.values())):
            # Legend text: every piece ends with ', ' and the final two
            # characters are trimmed afterwards.
            pieces = ''.join(f'{name}={init_values[pos]}, '
                             for pos, name in enumerate(self.target_var_names))
            legend = (f'$traj_{idx}$: ' + pieces)[:-2]

            # Convert the plotting window from time to sample indices.
            window = slice(int(plot_durations[idx][0] / dt),
                           int(plot_durations[idx][1] / dt))

            first_par = self.target_par_names[0]
            num_par = len(self.target_par_names)
            if num_par == 1:
                lines = plt.plot(mon_res[self.x_var][window, idx],
                                 mon_res[first_par][window, idx],
                                 label=legend)
            elif num_par == 2:
                second_par = self.target_par_names[1]
                lines = plt.plot(mon_res[self.x_var][window, idx],
                                 mon_res[first_par][window, idx],
                                 mon_res[second_par][window, idx],
                                 label=legend)
            else:
                raise ValueError
            utils.add_arrow(lines[0])

        plt.legend()
        if show:
            plt.show()

    if with_return:
        return mon_res
def plot_vector_field(self, show=False, with_plot=True, with_return=False):
    """Plot the vector field of the one-dimensional system.

    ``self.F_fx`` is evaluated on the resolution grid of ``self.x_var``
    and the result is drawn as a curve, together with a horizontal line
    at zero.

    Parameters
    ----------
    show : bool
        Call ``plt.show()`` after plotting.
    with_plot : bool
        Whether to draw the curve.
    with_return : bool
        Whether to return the evaluated derivative values.

    Returns
    -------
    numpy.ndarray or None
        The values of ``self.F_fx`` on the grid when ``with_return`` is
        true, otherwise ``None``.
    """
    utils.output('I am creating the vector field ...')

    # Derivative of the variable evaluated on its resolution grid.
    derivative = np.asarray(self.F_fx(self.resolutions[self.x_var]))

    if with_plot:
        curve_label = f"d{self.x_var}dt"
        plt.plot(np.asarray(self.resolutions[self.x_var]),
                 derivative,
                 color='lightcoral',
                 alpha=.7,
                 linewidth=4,
                 label=curve_label)
        plt.axhline(0)

        plt.xlabel(self.x_var)
        plt.ylabel(curve_label)
        margin = (self.lim_scale - 1.) / 2
        plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=margin))
        plt.legend()
        if show:
            plt.show()

    if with_return:
        return derivative
def plot_fixed_point(self, show=False, with_plot=True, with_return=False):
    """Find the fixed points of the 1D system and classify their stability.

    Fixed points come from ``self._get_fixed_points`` called on the
    resolution grid of ``self.x_var``. Each one is classified by passing
    ``self.F_dfxdx(x)`` to ``stability.stability_analysis``, reported
    through ``utils.output`` and, optionally, drawn on the line ``y = 0``.

    Parameters
    ----------
    show : bool
        Call ``plt.show()`` after plotting.
    with_plot : bool
        Whether to draw the fixed points.
    with_return : bool
        Whether to return the fixed points.

    Returns
    -------
    The fixed points as returned by ``self._get_fixed_points`` when
    ``with_return`` is true, otherwise ``None``.
    """
    utils.output('I am searching fixed points ...')

    fps, _, _ = self._get_fixed_points(self.resolutions[self.x_var])

    # One bucket per stability type so the plot can group points by type.
    by_type = {kind: [] for kind in stability.get_1d_stability_types()}
    for idx in range(len(fps)):
        point = fps[idx]
        fp_type = stability.stability_analysis(self.F_dfxdx(point))
        utils.output(
            f"Fixed point #{idx + 1} at {self.x_var}={point} is a {fp_type}.")
        by_type[fp_type].append(point)

    if with_plot:
        for fp_type, points in by_type.items():
            if not len(points):
                continue
            plt.plot(points,
                     [0] * len(points),
                     '.',
                     markersize=20,
                     **stability.plot_scheme[fp_type],
                     label=fp_type)
        plt.legend()
        if show:
            plt.show()

    if with_return:
        return fps
def plot_trajectory(self, initials, duration, plot_durations=None, dt=None,
                    show=False, with_plot=True, with_return=False):
    """Simulate trajectories and draw each variable against the parameter.

    Two figures are used, named after ``self.x_var`` and ``self.y_var``.
    In both, the first target parameter is on the horizontal axis.

    Parameters
    ----------
    initials
        Initial values. They are normalised by ``utils.check_initials``
        against the target variable names followed by the target
        parameter names.
    duration : int, float
        Simulation length. Any other type fails the assertion.
    plot_durations : optional
        Plotting window(s), normalised by ``utils.check_plot_durations``.
        Window ``i`` is read as ``(start, end)`` and divided by ``dt`` to
        get the slice indices of trajectory ``i``.
    dt : float, optional
        Integration step; ``bm.get_dt()`` is used when it is ``None``.
    show : bool
        Call ``plt.show()`` after plotting.
    with_plot : bool
        Whether to draw anything at all.
    with_return : bool
        Whether to return the simulation result.

    Returns
    -------
    The object returned by ``TrajectModel.run`` when ``with_return`` is
    true, otherwise ``None``.
    """
    utils.output('I am plotting the trajectory ...')

    # Normalise the inputs.
    initials = utils.check_initials(
        initials, self.target_var_names + self.target_par_names)
    assert isinstance(duration, (int, float))
    plot_durations = utils.check_plot_durations(plot_durations, duration, initials)

    # Integrate the system.
    if dt is None:
        dt = bm.get_dt()
    runner = utils.TrajectModel(initial_vars=initials,
                                integrals=self._std_integrators,
                                dt=dt)
    mon_res = runner.run(duration=duration)

    if with_plot:
        assert len(self.target_par_names) <= 1
        figure_vars = (self.x_var, self.y_var)

        for idx, init_values in enumerate(zip(*initials.values())):
            # Legend text: every piece ends with ', ' and the final two
            # characters are trimmed afterwards.
            pieces = ''.join(f'{name}={init_values[pos]}, '
                             for pos, name in enumerate(self.target_var_names))
            legend = (f'$traj_{idx}$: ' + pieces)[:-2]

            # Convert the plotting window from time to sample indices.
            window = slice(int(plot_durations[idx][0] / dt),
                           int(plot_durations[idx][1] / dt))

            # Same drawing recipe for the x-variable and y-variable figures.
            for var in figure_vars:
                plt.figure(var)
                lines = plt.plot(mon_res[self.target_par_names[0]][window, idx],
                                 mon_res[var][window, idx],
                                 label=legend)
                utils.add_arrow(lines[0])

        for var in figure_vars:
            plt.figure(var)
            plt.legend()

        if show:
            plt.show()

    if with_return:
        return mon_res
def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
                     tol_aux=1e-8, loss_screen=None):
    """Run the bifurcation analysis of the one-variable system.

    Candidate points are the flattened mesh of the variable grid and
    every target-parameter grid. ``self._get_fixed_points`` refines them,
    and the derivative from ``self.F_vmap_dfxdx`` at each fixed point is
    used to classify its stability.

    Parameters
    ----------
    with_plot : bool
        Whether to draw the bifurcation diagram: a 2D plot for one target
        parameter, a 3D scatter for two.
    show : bool
        Call ``plt.show()`` after plotting.
    with_return : bool
        Whether to return the computed results.
    tol_aux : float
        Forwarded to ``self._get_fixed_points``.
    loss_screen : optional
        Forwarded to ``self._get_fixed_points``.

    Returns
    -------
    tuple or None
        ``(fixed_points, pars, dfxdx)`` when ``with_return`` is true,
        otherwise ``None``.

    Raises
    ------
    errors.BrainPyError
        When plotting is requested and the number of target parameters is
        neither 1 nor 2.
    """
    utils.output('I am making bifurcation analysis ...')

    # Flattened mesh of the variable grid and all parameter grids.
    x_grid = self.resolutions[self.x_var]
    par_grids = tuple(self.resolutions[name] for name in self.target_par_names)
    mesh = bm.meshgrid(x_grid, *par_grids)
    flat = tuple(jnp.moveaxis(item.value, 0, 1).flatten() for item in mesh)
    candidates, pars = flat[0], flat[1:]

    fixed_points, _, pars = self._get_fixed_points(candidates,
                                                   *pars,
                                                   tol_aux=tol_aux,
                                                   loss_screen=loss_screen,
                                                   num_seg=len(x_grid))
    dfxdx = np.asarray(self.F_vmap_dfxdx(jnp.asarray(fixed_points), *pars))
    pars = tuple(np.asarray(p) for p in pars)

    if with_plot:
        num_par = len(self.target_pars)
        if num_par == 1:
            # Group fixed points by stability type.
            groups = {kind: {'p': [], 'x': []}
                      for kind in stability.get_1d_stability_types()}
            for p, x, dx in zip(pars[0], fixed_points, dfxdx):
                bucket = groups[stability.stability_analysis(dx)]
                bucket['p'].append(p)
                bucket['x'].append(x)

            plt.figure(self.x_var)
            for fp_type, points in groups.items():
                if not len(points['x']):
                    continue
                plt.plot(points['p'], points['x'], '.',
                         **stability.plot_scheme[fp_type],
                         label=fp_type)
            plt.xlabel(self.target_par_names[0])
            plt.ylabel(self.x_var)

            margin = (self.lim_scale - 1) / 2
            plt.xlim(*utils.rescale(self.target_pars[self.target_par_names[0]],
                                    scale=margin))
            plt.ylim(*utils.rescale(self.target_vars[self.x_var], scale=margin))

            plt.legend()
            if show:
                plt.show()

        elif num_par == 2:
            # Group fixed points by stability type.
            groups = {kind: {'p0': [], 'p1': [], 'x': []}
                      for kind in stability.get_1d_stability_types()}
            for p0, p1, x, dx in zip(pars[0], pars[1], fixed_points, dfxdx):
                bucket = groups[stability.stability_analysis(dx)]
                bucket['p0'].append(p0)
                bucket['p1'].append(p1)
                bucket['x'].append(x)

            fig = plt.figure(self.x_var)
            ax = fig.add_subplot(projection='3d')
            for fp_type, points in groups.items():
                if not len(points['x']):
                    continue
                ax.scatter(points['p0'], points['p1'], points['x'],
                           **stability.plot_scheme[fp_type],
                           label=fp_type)

            ax.set_xlabel(self.target_par_names[0])
            ax.set_ylabel(self.target_par_names[1])
            ax.set_zlabel(self.x_var)

            margin = (self.lim_scale - 1) / 2
            ax.set_xlim(*utils.rescale(self.target_pars[self.target_par_names[0]],
                                       scale=margin))
            ax.set_ylim(*utils.rescale(self.target_pars[self.target_par_names[1]],
                                       scale=margin))
            ax.set_zlim(*utils.rescale(self.target_vars[self.x_var], scale=margin))

            ax.grid(True)
            ax.legend()
            if show:
                plt.show()

        else:
            raise errors.BrainPyError(f'Cannot visualize co-dimension {len(self.target_pars)} '
                                      f'bifurcation.')

    if with_return:
        return fixed_points, pars, dfxdx
def plot_limit_cycle_by_sim(self, duration=100, with_plot=True, with_return=False,
                            plot_style=None, tol=0.001, show=False, dt=None,
                            offset=1.):
    """Detect limit cycles by simulating from points near the fixed points.

    The fixed points stored in ``self._fixed_points`` are shifted by
    ``offset`` and used as initial values. For every simulated trajectory
    ``utils.find_indexes_of_limit_cycle_max`` is applied to the
    ``self.x_var`` trace. When it reports a cycle, the minimum and maximum
    of both variables over that cycle are recorded together with the
    parameter values of the trajectory.

    Parameters
    ----------
    duration : int, float
        Simulation length passed to ``TrajectModel.run``.
    with_plot : bool
        Whether to draw the limit-cycle extrema on the figures named
        after the target variables.
    with_return : bool
        Whether to return the detected extrema.
    plot_style : dict, optional
        Extra keyword arguments for ``plt.plot``. The optional key
        ``'fmt'`` (default ``'.'``) is the format string used in the
        one-parameter case. The dict passed in is not modified.
    tol : float
        Forwarded to ``utils.find_indexes_of_limit_cycle_max``.
    show : bool
        Call ``plt.show()`` after plotting.
    dt : float, optional
        Integration step; ``bm.get_dt()`` is used when it is ``None``.
    offset : float
        Value added to the fixed points to build the initial values.

    Returns
    -------
    tuple or None
        ``(vs_limit_cycle, ps_limit_cycle)`` when ``with_return`` is true.
        ``vs_limit_cycle`` holds one ``{'min': array, 'max': array}`` dict
        per target variable and ``ps_limit_cycle`` one array per target
        parameter. ``None`` is returned when ``with_return`` is false or
        when no fixed points have been stored yet.

    Raises
    ------
    errors.AnalyzerError
        When plotting is requested and the number of target parameters is
        neither 1 nor 2.
    """
    utils.output('I am plotting the limit cycle ...')
    if self._fixed_points is None:
        utils.output('No fixed points found, you may call "plot_bifurcation(with_plot=True)" first.')
        return
    final_fps, final_pars = self._fixed_points

    dt = bm.get_dt() if dt is None else dt
    traject_model = utils.TrajectModel(
        initial_vars={self.x_var: final_fps[:, 0] + offset,
                      self.y_var: final_fps[:, 1] + offset},
        integrals={self.x_var: self.F_int_x, self.y_var: self.F_int_y},
        pars=dict(zip(self.target_par_names, final_pars.T)),
        dt=dt,
    )
    mon_res = traject_model.run(duration=duration)

    # Find limit cycles: one trajectory per column of the monitored data.
    vs_limit_cycle = tuple({'min': [], 'max': []} for _ in self.target_var_names)
    ps_limit_cycle = tuple([] for _ in self.target_par_names)
    for i in range(mon_res[self.x_var].shape[1]):
        data = mon_res[self.x_var][:, i]
        max_index = utils.find_indexes_of_limit_cycle_max(data, tol=tol)
        if max_index[0] == -1:
            # -1 is the helper's "no cycle found" marker.
            continue
        cycle = data[max_index[0]: max_index[1]]
        vs_limit_cycle[0]['max'].append(mon_res[self.x_var][max_index[1], i])
        vs_limit_cycle[0]['min'].append(cycle.min())
        cycle = mon_res[self.y_var][max_index[0]: max_index[1], i]
        vs_limit_cycle[1]['max'].append(mon_res[self.y_var][max_index[1], i])
        vs_limit_cycle[1]['min'].append(cycle.min())
        for j in range(len(self.target_par_names)):
            ps_limit_cycle[j].append(final_pars[i, j])
    vs_limit_cycle = tuple({k: np.asarray(v) for k, v in lm.items()}
                           for lm in vs_limit_cycle)
    ps_limit_cycle = tuple(np.array(p) for p in ps_limit_cycle)

    # Visualization.
    if with_plot:
        # Work on a copy: popping 'fmt' from the caller's dict would make a
        # reused style dict lose its format on the next call.
        style = dict() if plot_style is None else dict(plot_style)
        fmt = style.pop('fmt', '.')

        if len(self.target_par_names) == 2:
            # NOTE(review): three positional arrays are passed here, which
            # presumably relies on the current axes of figure ``var`` being
            # 3D (e.g. created by ``plot_bifurcation``) -- confirm.
            for i, var in enumerate(self.target_var_names):
                plt.figure(var)
                plt.plot(ps_limit_cycle[0], ps_limit_cycle[1],
                         vs_limit_cycle[i]['max'],
                         **style, label='limit cycle (max)')
                plt.plot(ps_limit_cycle[0], ps_limit_cycle[1],
                         vs_limit_cycle[i]['min'],
                         **style, label='limit cycle (min)')
                plt.legend()

        elif len(self.target_par_names) == 1:
            for i, var in enumerate(self.target_var_names):
                plt.figure(var)
                plt.plot(ps_limit_cycle[0], vs_limit_cycle[i]['max'], fmt,
                         **style, label='limit cycle (max)')
                plt.plot(ps_limit_cycle[0], vs_limit_cycle[i]['min'], fmt,
                         **style, label='limit cycle (min)')
                plt.legend()

        else:
            raise errors.AnalyzerError(
                f'Cannot plot limit cycles for {len(self.target_par_names)} '
                f'target parameters; only 1 or 2 are supported.')

        if show:
            plt.show()

    if with_return:
        return vs_limit_cycle, ps_limit_cycle
def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
                     tol_aux=1e-8, tol_unique=1e-2, tol_opt_candidate=None,
                     num_par_segments=1, num_fp_segment=1, nullcline_aux_filter=1.,
                     select_candidates='aux_rank', num_rank=100):
    r"""Make the bifurcation analysis.

    Parameters
    ----------
    with_plot: bool
      Whether plot the bifurcation figure.
    show: bool
      Whether show the figure.
    with_return: bool
      Whether return the computed bifurcation results.
    tol_aux: float
      The loss tolerance of auxiliary function :math:`f_{aux}` to confirm the
      fixed point. Default is 1e-8. Once :math:`f_{aux}(x_1) < \mathrm{tol\_aux}`,
      :math:`x_1` will be a fixed point.
    tol_unique: float
      The tolerance of distance between candidate fixed points to confirm they
      are the same. Default is 1e-2. If :math:`|x_1 - x_2| > \mathrm{tol\_unique}`,
      then :math:`x_1` and :math:`x_2` are unique fixed points. Otherwise,
      :math:`x_1` and :math:`x_2` will be treated as a same fixed point.
    tol_opt_candidate: float, optional
      The tolerance of auxiliary function :math:`f_{aux}` to select candidate
      initial points for fixed point optimization.
    num_par_segments: int, sequence of int
      How to segment parameters.
    num_fp_segment: int
      How to segment fixed points.
    nullcline_aux_filter: float
      Forwarded as ``fp_aux_filter`` to the nullcline point helpers when
      ``select_candidates`` is one of the nullcline-based methods. With
      ``select_candidates='aux_rank'`` it is only asserted to be positive.
    select_candidates: str
      The method to select candidate fixed points. It is ignored when the
      system can be converted to one equation. It can be:

      - ``fx-nullcline``: use the points of fx-nullcline.
      - ``fy-nullcline``: use the points of fy-nullcline.
      - ``nullclines``: use the points in both of fx-nullcline and fy-nullcline.
      - ``aux_rank``: use the minimal value of points for the auxiliary function.
    num_rank: int
      The number of candidates to be used to optimize the fixed points
      (only used with ``select_candidates='aux_rank'``).

    Returns
    -------
    results : tuple
      Returned only when ``with_return`` is true. A tuple of analyzed results:

      - fixed points: a 2D matrix with the shape of (num_point, num_var)
      - parameters: a 2D matrix with the shape of (num_point, num_par)
      - jacobians: a 3D tensors with the shape of (num_point, 2, 2)

    Raises
    ------
    ValueError
      If ``select_candidates`` is not a supported method, or if plotting is
      requested and the number of target parameters is neither 1 nor 2.
    """
    utils.output('I am making bifurcation analysis ...')

    if self._can_convert_to_one_eq():
        # The system reduces to one equation: scan the grid of the remaining
        # variable crossed with every parameter grid.
        if self.convert_type() == C.x_by_y:
            X = self.resolutions[self.y_var].value
        else:
            X = self.resolutions[self.x_var].value
        pars = tuple(self.resolutions[p].value for p in self.target_par_names)
        mesh_values = jnp.meshgrid(*((X,) + pars))
        mesh_values = tuple(jnp.moveaxis(v, 0, 1).flatten() for v in mesh_values)
        candidates = mesh_values[0]
        parameters = mesh_values[1:]

    else:
        if select_candidates == 'fx-nullcline':
            fx_nullclines = self._get_fx_nullcline_points(num_segments=num_par_segments,
                                                          fp_aux_filter=nullcline_aux_filter)
            candidates = fx_nullclines[0]
            parameters = fx_nullclines[1:]
        elif select_candidates == 'fy-nullcline':
            fy_nullclines = self._get_fy_nullcline_points(num_segments=num_par_segments,
                                                          fp_aux_filter=nullcline_aux_filter)
            candidates = fy_nullclines[0]
            parameters = fy_nullclines[1:]
        elif select_candidates == 'nullclines':
            fx_nullclines = self._get_fx_nullcline_points(num_segments=num_par_segments,
                                                          fp_aux_filter=nullcline_aux_filter)
            fy_nullclines = self._get_fy_nullcline_points(num_segments=num_par_segments,
                                                          fp_aux_filter=nullcline_aux_filter)
            candidates = jnp.vstack([fx_nullclines[0], fy_nullclines[0]])
            parameters = [jnp.concatenate([fx_nullclines[i], fy_nullclines[i]])
                          for i in range(1, len(fy_nullclines))]
        elif select_candidates == 'aux_rank':
            assert nullcline_aux_filter > 0.
            candidates, parameters = self._get_fp_candidates_by_aux_rank(
                num_segments=num_par_segments, num_rank=num_rank)
        else:
            raise ValueError(
                f'Unknown select_candidates "{select_candidates}", only supports '
                f'"fx-nullcline", "fy-nullcline", "nullclines" and "aux_rank".')

    # Refine the candidates into fixed points.
    candidates, _, parameters = self._get_fixed_points(candidates,
                                                       *parameters,
                                                       tol_aux=tol_aux,
                                                       tol_unique=tol_unique,
                                                       tol_opt_candidate=tol_opt_candidate,
                                                       num_segment=num_fp_segment)
    candidates = np.asarray(candidates)
    parameters = np.stack(tuple(np.asarray(p) for p in parameters)).T

    # Duplicates are removed separately for every distinct parameter row.
    utils.output('I am trying to filter out duplicate fixed points ...')
    final_fps = []
    final_pars = []
    for par in np.unique(parameters, axis=0):
        ids = np.where(np.all(parameters == par, axis=1))[0]
        fps, ids2 = utils.keep_unique(candidates[ids], tolerance=tol_unique)
        final_fps.append(fps)
        final_pars.append(parameters[ids[ids2]])
    final_fps = np.vstack(final_fps)  # with the shape of (num_point, num_var)
    final_pars = np.vstack(final_pars)  # with the shape of (num_point, num_par)
    jacobians = np.asarray(self.F_vmap_jacobian(jnp.asarray(final_fps), *final_pars.T))
    utils.output(f'{C.prefix}Found {len(final_fps)} fixed points.')

    # remember the fixed points for later limit cycle plotting
    self._fixed_points = (final_fps, final_pars)

    if with_plot:
        # bifurcation analysis of co-dimension 1
        if len(self.target_pars) == 1:
            container = {c: {'p': [], self.x_var: [], self.y_var: []}
                         for c in stability.get_2d_stability_types()}

            # Group fixed points by stability type.
            for p, xy, J in zip(final_pars, final_fps, jacobians):
                fp_type = stability.stability_analysis(J)
                container[fp_type]['p'].append(p[0])
                container[fp_type][self.x_var].append(xy[0])
                container[fp_type][self.y_var].append(xy[1])

            # One figure per target variable, named after the variable.
            for var in self.target_var_names:
                plt.figure(var)
                for fp_type, points in container.items():
                    if len(points['p']):
                        plot_style = stability.plot_scheme[fp_type]
                        plt.plot(points['p'], points[var], '.', **plot_style, label=fp_type)
                plt.xlabel(self.target_par_names[0])
                plt.ylabel(var)

                scale = (self.lim_scale - 1) / 2
                plt.xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=scale))
                plt.ylim(*utils.rescale(self.target_vars[var], scale=scale))

                plt.legend()
            if show:
                plt.show()

        # bifurcation analysis of co-dimension 2
        elif len(self.target_pars) == 2:
            container = {c: {'p0': [], 'p1': [], self.x_var: [], self.y_var: []}
                         for c in stability.get_2d_stability_types()}

            # Group fixed points by stability type.
            for p, xy, J in zip(final_pars, final_fps, jacobians):
                fp_type = stability.stability_analysis(J)
                container[fp_type]['p0'].append(p[0])
                container[fp_type]['p1'].append(p[1])
                container[fp_type][self.x_var].append(xy[0])
                container[fp_type][self.y_var].append(xy[1])

            # One 3D figure per target variable, named after the variable.
            for var in self.target_var_names:
                fig = plt.figure(var)
                ax = fig.add_subplot(projection='3d')
                for fp_type, points in container.items():
                    if len(points['p0']):
                        plot_style = stability.plot_scheme[fp_type]
                        xs = points['p0']
                        ys = points['p1']
                        zs = points[var]
                        ax.scatter(xs, ys, zs, **plot_style, label=fp_type)

                ax.set_xlabel(self.target_par_names[0])
                ax.set_ylabel(self.target_par_names[1])
                ax.set_zlabel(var)

                scale = (self.lim_scale - 1) / 2
                ax.set_xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=scale))
                ax.set_ylim(*utils.rescale(self.target_pars[self.target_par_names[1]], scale=scale))
                ax.set_zlim(*utils.rescale(self.target_vars[var], scale=scale))

                ax.grid(True)
                ax.legend()
            if show:
                plt.show()

        else:
            raise ValueError('Unknown length of parameters.')

    if with_return:
        return final_fps, final_pars, jacobians
def plot_limit_cycle_by_sim(self, initials, duration, tol=0.01, show=False, dt=None):
    """Simulate from the given initial values and plot any limit cycle found.

    Parameters
    ----------
    initials : list, tuple
      The initial value setting of the targets, normalised by
      ``utils.check_initials``.

      - It can be a tuple/list of floats to specify each value of dynamical
        variables (for example, ``(a, b)``).
      - It can also be a tuple/list of tuple to specify multiple
        initial values (for example, ``[(a1, b1), (a2, b2)]``).
    duration : int, float
      The running duration. Any other type fails the assertion.
    tol : float
      Forwarded to ``utils.find_indexes_of_limit_cycle_max``.
    show : bool
      Whether show or not.
    dt : float, optional
      Integration step; ``math.get_dt()`` is used when it is ``None``.
    """
    utils.output('I am plotting the limit cycle ...')

    # Normalise the inputs.
    initials = utils.check_initials(initials, self.target_var_names)
    assert isinstance(duration, (int, float))
    if dt is None:
        dt = math.get_dt()

    # Integrate both variables from every initial value at once.
    integrals = {self.x_var: self.F_int_x, self.y_var: self.F_int_y}
    runner = utils.TrajectModel(initial_vars=initials, integrals=integrals, dt=dt)
    mon_res = runner.run(duration=duration)

    # Each column of the monitored data belongs to one initial value.
    for col, init_values in enumerate(zip(*initials.values())):
        x_trace = mon_res[self.x_var][:, col]
        y_trace = mon_res[self.y_var][:, col]
        bounds = utils.find_indexes_of_limit_cycle_max(x_trace, tol=tol)
        if bounds[0] == -1:
            # -1 is the helper's "no cycle found" marker.
            utils.output(
                f'No limit cycle found for initial value {init_values}')
            continue
        cycle = slice(bounds[0], bounds[1])
        lines = plt.plot(x_trace[cycle], y_trace[cycle], label='limit cycle')
        utils.add_arrow(lines[0])

    # Axes decoration.
    plt.xlabel(self.x_var)
    plt.ylabel(self.y_var)
    margin = (self.lim_scale - 1.) / 2
    plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=margin))
    plt.ylim(*utils.rescale(self.target_vars[self.y_var], scale=margin))
    plt.legend()

    if show:
        plt.show()
def plot_trajectory(self, initials, duration, plot_durations=None, axes='v-v',
                    dt=None, show=False, with_plot=True, with_return=False,
                    **kwargs):
    """Plot trajectories according to the settings.

    Parameters
    ----------
    initials : list, tuple, dict
      The initial value setting of the targets, normalised by
      ``utils.check_initials``. It can be a tuple/list of floats to specify
      each value of dynamical variables (for example, ``(a, b)``). It can
      also be a tuple/list of tuple to specify multiple initial values
      (for example, ``[(a1, b1), (a2, b2)]``).
    duration : int, float
      The running duration. Any other type fails the assertion.
    plot_durations : tuple, list, optional
      The duration to plot. It can be a tuple with ``(start, end)``. It can
      also be a list of tuple ``[(start1, end1), (start2, end2)]`` to specify
      the plot duration for each initial value running.
    axes : str
      The axes to plot. It can be:

      - 'v-v': Plot the trajectory in the 'x_var'-'y_var' axis.
      - 't-v': Plot the trajectory in the 'time'-'var' axis.
    dt : float, optional
      Integration step; ``math.get_dt()`` is used when it is ``None``.
    show : bool
      Whether show or not.
    with_plot : bool
      Whether to draw the trajectories.
    with_return : bool
      Whether to return the simulation result.
    **kwargs
      Forwarded to every ``plt.plot`` call.

    Returns
    -------
    The object returned by ``TrajectModel.run`` when ``with_return`` is
    true, otherwise ``None``.

    Raises
    ------
    errors.AnalyzerError
      If ``axes`` is neither ``'v-v'`` nor ``'t-v'``.
    """
    utils.output('I am plotting the trajectory ...')

    if axes not in ('v-v', 't-v'):
        raise errors.AnalyzerError(
            f'Unknown axes "{axes}", only support "v-v" and "t-v".')
    phase_plane = axes == 'v-v'

    # Normalise the inputs.
    initials = utils.check_initials(initials, self.target_var_names)
    assert isinstance(duration, (int, float))
    plot_durations = utils.check_plot_durations(plot_durations, duration, initials)

    # Integrate both variables from every initial value at once.
    if dt is None:
        dt = math.get_dt()
    integrals = {self.x_var: self.F_int_x, self.y_var: self.F_int_y}
    runner = utils.TrajectModel(initial_vars=initials, integrals=integrals, dt=dt)
    mon_res = runner.run(duration=duration)

    if with_plot:
        for idx, init_values in enumerate(zip(*initials.values())):
            # Legend text: every piece ends with ', ' and the final two
            # characters are trimmed afterwards.
            pieces = ''.join(
                f'{name}={round(float(init_values[pos]), 4)}, '
                for pos, name in enumerate(self.target_var_names))
            legend = (f'$traj_{idx}$: ' + pieces)[:-2]

            # Convert the plotting window from time to sample indices.
            start = int(plot_durations[idx][0] / dt)
            end = int(plot_durations[idx][1] / dt)

            if phase_plane:
                lines = plt.plot(mon_res[self.x_var][start:end, idx],
                                 mon_res[self.y_var][start:end, idx],
                                 label=legend,
                                 **kwargs)
                utils.add_arrow(lines[0])
            else:
                # Time series: one curve per variable.
                for var in (self.x_var, self.y_var):
                    plt.plot(mon_res.ts[start:end],
                             mon_res[var][start:end, idx],
                             label=legend + f', {var}',
                             **kwargs)

        # Axes decoration.
        if phase_plane:
            plt.xlabel(self.x_var)
            plt.ylabel(self.y_var)
            margin = (self.lim_scale - 1.) / 2
            plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=margin))
            plt.ylim(*utils.rescale(self.target_vars[self.y_var], scale=margin))
            plt.legend()
        else:
            plt.legend(title='Initial values')

        if show:
            plt.show()

    if with_return:
        return mon_res
def plot_fixed_point(
    self,
    with_plot=True,
    with_return=False,
    show=False,
    tol_unique=1e-2,
    tol_aux=1e-8,
    tol_opt_screen=None,
    select_candidates='fx-nullcline',
    num_rank=100,
):
    """Plot the fixed point and analyze its stability.

    Parameters
    ----------
    with_plot : bool
        Whether to draw the fixed points.
    with_return : bool
        Whether to return the fixed points.
    show : bool
        Call ``plt.show()`` after plotting.
    tol_unique : float
        Forwarded to ``self._get_fixed_points`` and used as the tolerance
        of ``utils.keep_unique`` when removing duplicates.
    tol_aux : float
        Forwarded to ``self._get_fixed_points``.
    tol_opt_screen : float, optional
        Forwarded to ``self._get_fixed_points`` as ``tol_opt_candidate``.
    select_candidates : str
        How to choose the candidate points. It is ignored when the system
        can be converted to one equation. It can be:

        - ``fx-nullcline``: previously computed fx-nullcline points.
        - ``fy-nullcline``: previously computed fy-nullcline points.
        - ``nullclines``: the points of both nullclines.
        - ``aux_rank``: candidates from
          ``self._get_fp_candidates_by_aux_rank``.
    num_rank : int
        Forwarded to ``self._get_fp_candidates_by_aux_rank``.

    Returns
    -------
    The de-duplicated fixed points when ``with_return`` is true.
    ``None`` is returned when ``with_return`` is false or when there are
    no candidate points.

    Raises
    ------
    errors.AnalyzerError
        If a nullcline-based method is selected but no nullcline points
        are stored in ``self.analyzed_results``.
    ValueError
        If ``select_candidates`` is not a supported method.
    """
    utils.output('I am searching fixed points ...')

    if self._can_convert_to_one_eq():
        if self.convert_type() == C.x_by_y:
            candidates = self.resolutions[self.y_var].value
        else:
            candidates = self.resolutions[self.x_var].value
    elif select_candidates == 'aux_rank':
        candidates, _ = self._get_fp_candidates_by_aux_rank(num_rank=num_rank)
    else:
        # The nullcline-based methods differ only in which stored results
        # they read, identified by the prefix of the result key.
        if select_candidates == 'fx-nullcline':
            prefixes = (C.fx_nullcline_points,)
        elif select_candidates == 'fy-nullcline':
            prefixes = (C.fy_nullcline_points,)
        elif select_candidates == 'nullclines':
            # Both nullclines (the fy prefix used to be tested twice here,
            # which silently dropped every fx-nullcline point).
            prefixes = (C.fx_nullcline_points, C.fy_nullcline_points)
        else:
            raise ValueError(
                f'Unknown select_candidates "{select_candidates}", only supports '
                f'"fx-nullcline", "fy-nullcline", "nullclines" and "aux_rank".')

        candidates = [
            self.analyzed_results[key][0]
            for key in self.analyzed_results.keys()
            if any(key.startswith(prefix) for prefix in prefixes)
        ]
        if len(candidates) == 0:
            raise errors.AnalyzerError(
                f'No nullcline points are found, please call '
                f'".{self.plot_nullcline.__name__}()" first.')
        candidates = jnp.vstack(candidates)

    # Refine the candidates into fixed points.
    if len(candidates):
        fixed_points, _, _ = self._get_fixed_points(
            jnp.asarray(candidates),
            tol_aux=tol_aux,
            tol_unique=tol_unique,
            tol_opt_candidate=tol_opt_screen)
        utils.output(
            'I am trying to filter out duplicate fixed points ...')
        fixed_points = np.asarray(fixed_points)
        fixed_points, _ = utils.keep_unique(fixed_points, tolerance=tol_unique)
        utils.output(f'{C.prefix}Found {len(fixed_points)} fixed points.')
    else:
        utils.output(f'{C.prefix}Found no fixed points.')
        return

    # Stability analysis: group the fixed points by stability type.
    container = {
        a: {'x': [], 'y': []}
        for a in stability.get_2d_stability_types()
    }
    for i in range(len(fixed_points)):
        x = fixed_points[i, 0]
        y = fixed_points[i, 1]
        fp_type = stability.stability_analysis(self.F_jacobian(x, y))
        utils.output(
            f"{C.prefix}#{i + 1} {self.x_var}={x}, {self.y_var}={y} is a {fp_type}."
        )
        container[fp_type]['x'].append(x)
        container[fp_type]['y'].append(y)

    # Visualization.
    if with_plot:
        for fp_type, points in container.items():
            if len(points['x']):
                plot_style = stability.plot_scheme[fp_type]
                plt.plot(points['x'],
                         points['y'],
                         '.',
                         markersize=20,
                         **plot_style,
                         label=fp_type)
        plt.legend()
        if show:
            plt.show()

    if with_return:
        return fixed_points
def plot_nullcline(self, with_plot=True, with_return=False, y_style=None,
                   x_style=None, show=False, coords=None, tol_nullcline=1e-7):
    """Compute and plot the nullclines of both variables.

    Parameters
    ----------
    with_plot : bool
        Whether to draw the nullclines.
    with_return : bool
        Whether to return the nullcline points.
    y_style : dict, optional
        Keyword arguments for plotting the ``self.y_var`` nullcline. The
        optional key ``'fmt'`` (default ``'.'``) is used as the format
        string. The dict passed in is not modified.
    x_style : dict, optional
        Same as ``y_style`` but for the ``self.x_var`` nullcline.
    show : bool
        Call ``plt.show()`` after plotting.
    coords : dict, optional
        Looked up by variable name; the entry of each variable (``None``
        when missing) is forwarded as ``coords`` to the corresponding
        nullcline helper.
    tol_nullcline : float
        Forwarded as ``tol`` to the nullcline helpers.

    Returns
    -------
    dict or None
        When ``with_return`` is true, a dict mapping each variable name to
        the ``(x_values, y_values)`` arrays of its nullcline; otherwise
        ``None``.
    """
    utils.output('I am computing fx-nullcline ...')

    if coords is None:
        coords = dict()
    x_coord = coords.get(self.x_var, None)
    y_coord = coords.get(self.y_var, None)

    # Nullcline of the x variable.
    xy_values_in_fx, = self._get_fx_nullcline_points(coords=x_coord, tol=tol_nullcline)
    x_values_in_fx = np.asarray(xy_values_in_fx[:, 0])
    y_values_in_fx = np.asarray(xy_values_in_fx[:, 1])

    if with_plot:
        # Copy before popping 'fmt' so that a style dict reused by the
        # caller keeps its format for later calls.
        if x_style is None:
            x_style = dict(color='cornflowerblue', alpha=.7)
        else:
            x_style = dict(x_style)
        fmt = x_style.pop('fmt', '.')
        plt.plot(x_values_in_fx, y_values_in_fx, fmt, **x_style,
                 label=f"{self.x_var} nullcline")

    # Nullcline of the y variable.
    utils.output('I am computing fy-nullcline ...')
    xy_values_in_fy, = self._get_fy_nullcline_points(coords=y_coord, tol=tol_nullcline)
    x_values_in_fy = np.asarray(xy_values_in_fy[:, 0])
    y_values_in_fy = np.asarray(xy_values_in_fy[:, 1])

    if with_plot:
        if y_style is None:
            y_style = dict(color='lightcoral', alpha=.7)
        else:
            y_style = dict(y_style)
        fmt = y_style.pop('fmt', '.')
        plt.plot(x_values_in_fy, y_values_in_fy, fmt, **y_style,
                 label=f"{self.y_var} nullcline")

    if with_plot:
        plt.xlabel(self.x_var)
        plt.ylabel(self.y_var)
        scale = (self.lim_scale - 1.) / 2
        plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
        plt.ylim(*utils.rescale(self.target_vars[self.y_var], scale=scale))
        plt.legend()
        if show:
            plt.show()

    if with_return:
        return {
            self.x_var: (x_values_in_fx, y_values_in_fx),
            self.y_var: (x_values_in_fy, y_values_in_fy)
        }
def plot_vector_field(self, with_plot=True, with_return=False,
                      plot_method='streamplot', plot_style=None, show=False):
    """Plot the vector field.

    Parameters
    ----------
    with_plot: bool
      Whether to draw the vector field.
    with_return : bool
      Whether to return the evaluated vector field.
    show : bool
      Whether to call ``plt.show()`` after plotting.
    plot_method : str
      The method to plot the vector field. It can be "streamplot" or "quiver".
    plot_style : dict, optional
      The style for vector field plotting. The dict passed in is not modified.

      - For ``plot_method="streamplot"``, it can set the keywords like "density",
        "linewidth", "color", "arrowsize". When "linewidth" is absent (or
        ``None``) and the field has no NaN, the line width is scaled with the
        local speed. More settings please check
        https://matplotlib.org/api/_as_gen/matplotlib.pyplot.streamplot.html.
      - For ``plot_method="quiver"``, it can set the keywords like "color",
        "units", "angles", "scale". Arrows are normalised to unit length when
        the field has no NaN. More settings please check
        https://matplotlib.org/api/_as_gen/matplotlib.pyplot.quiver.html.

    Returns
    -------
    tuple or None
      ``(dx, dy)``, the values of ``self.F_fx`` and ``self.F_fy`` on the mesh
      of the two resolution grids, when ``with_return`` is true. The values
      are the raw derivatives regardless of ``plot_method``.

    Raises
    ------
    errors.AnalyzerError
      If plotting is requested with an unknown ``plot_method``.
    """
    utils.output('I am creating the vector field ...')

    # Evaluate both derivatives on the mesh of the two resolution grids.
    xs = self.resolutions[self.x_var]
    ys = self.resolutions[self.y_var]
    X, Y = bm.meshgrid(xs, ys)
    dx = self.F_fx(X, Y)
    dy = self.F_fy(X, Y)
    X, Y = np.asarray(X), np.asarray(Y)
    dx, dy = np.asarray(dx), np.asarray(dy)

    if with_plot:
        has_nan = bool(np.isnan(dx).any()) or bool(np.isnan(dy).any())

        if plot_method == 'quiver':
            style = dict(units='xy') if plot_style is None else dict(plot_style)
            # Normalise into separate arrays so that the returned field
            # does not depend on the plotting method.
            u, v = dx, dy
            if not has_nan:
                speed = np.sqrt(dx**2 + dy**2)
                u = dx / speed
                v = dy / speed
            plt.quiver(X, Y, u, v, **style)

        elif plot_method == 'streamplot':
            if plot_style is None:
                style = dict(arrowsize=1.2, density=1, color='thistle')
            else:
                style = dict(plot_style)
            # Take 'linewidth' out of the style: it is passed explicitly
            # below, and passing it twice raises a TypeError.
            linewidth = style.pop('linewidth', None)
            if linewidth is None and not has_nan:
                min_width, max_width = 0.5, 5.5
                speed = np.nan_to_num(np.sqrt(dx**2 + dy**2))
                linewidth = min_width + max_width * (speed / speed.max())
            plt.streamplot(X, Y, dx, dy, linewidth=linewidth, **style)

        else:
            raise errors.AnalyzerError(
                f'Unknown plot_method "{plot_method}", '
                f'only supports "quiver" and "streamplot".')

        plt.xlabel(self.x_var)
        plt.ylabel(self.y_var)
        if show:
            plt.show()

    if with_return:
        return dx, dy