def plot_trajectory(self, initials, duration, plot_duration=None, axes='v-v', show=False):
    """Plot trajectories according to the settings.

    Parameters
    ----------
    initials : list, tuple, dict
        The initial value setting of the targets. It can be a tuple/list of
        floats to specify each value of dynamical variables (for example,
        ``(a, b)``). It can also be a tuple/list of tuple to specify multiple
        initial values (for example, ``[(a1, b1), (a2, b2)]``).
    duration : int, float, tuple, list
        The running duration. Same with the ``duration`` in ``NeuGroup.run()``.
        It can be a int/float (``t_end``) to specify the same running end time,
        or it can be a tuple/list of int/float (``(t_start, t_end)``) to specify
        the start and end simulation time. Or, it can be a list of tuple
        (``[(t1_start, t1_end), (t2_start, t2_end)]``) to specify the specific
        start and end simulation time for each initial value.
    plot_duration : tuple, list, optional
        The duration to plot, expressed in the same (absolute) time as
        ``duration``. It can be a tuple with ``(start, end)``. It can also be a
        list of tuple ``[(start1, end1), (start2, end2)]`` to specify the plot
        duration for each initial value running. Defaults to ``duration``.
    axes : str
        The axes to plot. It can be:

        - 'v-v'
            Plot the trajectory in the 'x_var'-'y_var' axis.
        - 't-v'
            Plot the trajectory in the 'time'-'var' axis.
    show : bool
        Whether show or not.
    """
    print('plot trajectory ...')

    if axes not in ['v-v', 't-v']:
        raise errors.ModelUseError(
            f'Unknown axes "{axes}", only support "v-v" and "t-v".')

    # 1. format the initial values into a sequence of {var_name: value} dicts
    if isinstance(initials, dict):
        initials = [initials]
    elif isinstance(initials, (list, tuple)):
        if isinstance(initials[0], (int, float)):
            # a single initial point given positionally, in `dvar_names` order
            initials = [{self.dvar_names[i]: v for i, v in enumerate(initials)}]
        elif isinstance(initials[0], dict):
            pass  # already a sequence of dicts
        elif isinstance(initials[0], (tuple, list)) and isinstance(initials[0][0], (int, float)):
            # several initial points, each given positionally
            initials = [{self.dvar_names[i]: v for i, v in enumerate(init)}
                        for init in initials]
        else:
            raise ValueError(f'Unknown format of "initials": {initials}')
    else:
        raise ValueError(f'Unknown format of "initials": {initials}')

    # 2. format the running duration into one (start, end) pair per initial value
    if isinstance(duration, (int, float)):
        duration = [(0, duration) for _ in range(len(initials))]
    elif isinstance(duration[0], (int, float)):
        duration = [duration for _ in range(len(initials))]
    else:
        assert len(duration) == len(initials), (
            '"duration" must provide one (start, end) pair per initial value.')

    # 3. format the plot duration (defaults to the whole running duration)
    if plot_duration is None:
        plot_duration = duration
    if isinstance(plot_duration[0], (int, float)):
        plot_duration = [plot_duration for _ in range(len(initials))]
    else:
        assert len(plot_duration) == len(initials), (
            '"plot_duration" must provide one (start, end) pair per initial value.')

    dt = backend.get_dt()

    # 4. run the model for every initial value and draw it
    for init_i, initial in enumerate(initials):
        traj_group = Trajectory(size=1,
                                integrals=self.model.integrals,
                                target_vars=initial,
                                fixed_vars=self.fixed_vars,
                                pars_update=self.pars_update,
                                scope=self.model.scopes)

        # 4.1 run the model
        traj_group.run(duration=duration[init_i], report=False)

        # 4.2 legend
        legend = f'$traj_{init_i}$: ' + ', '.join(
            f'{key}={initial[key]}' for key in self.dvar_names)

        # 4.3 plot window -> monitor indices. The monitor records from the
        # run's own start time, so the (absolute) plot window is shifted by
        # that start before converting to indices; clamp to avoid negative
        # indices wrapping around.
        run_start = duration[init_i][0]
        start = max(int((plot_duration[init_i][0] - run_start) / dt), 0)
        end = max(int((plot_duration[init_i][1] - run_start) / dt), 0)

        # 4.4 visualization
        if axes == 'v-v':
            lines = plt.plot(traj_group.mon[self.x_var][start:end, 0],
                             traj_group.mon[self.y_var][start:end, 0],
                             label=legend)
            utils.add_arrow(lines[0])
        else:
            plt.plot(traj_group.mon.ts[start:end],
                     traj_group.mon[self.x_var][start:end, 0],
                     label=legend + f', {self.x_var}')
            plt.plot(traj_group.mon.ts[start:end],
                     traj_group.mon[self.y_var][start:end, 0],
                     label=legend + f', {self.y_var}')

    # 5. figure decoration
    if axes == 'v-v':
        plt.xlabel(self.x_var)
        plt.ylabel(self.y_var)
        scale = (self.options.lim_scale - 1.) / 2
        plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
        plt.ylim(*utils.rescale(self.target_vars[self.y_var], scale=scale))
        plt.legend()
    else:
        plt.legend(title='Initial values')

    if show:
        plt.show()
def plot_vector_field(self, plot_method='streamplot', plot_style=None, show=False):
    """Plot the vector field.

    Parameters
    ----------
    plot_method : str
        The method to plot the vector field. It can be "streamplot" or "quiver".
    plot_style : dict, optional
        The style for vector field plotting.

        - For ``plot_method="streamplot"``, it can set the keywords like
          "density", "linewidth", "color", "arrowsize", plus "min_width" and
          "max_width" which control the speed-dependent line width used when
          "linewidth" is not given. More settings please check
          https://matplotlib.org/api/_as_gen/matplotlib.pyplot.streamplot.html.
        - For ``plot_method="quiver"``, it can set the keywords like "color",
          "units", "angles", "scale". More settings please check
          https://matplotlib.org/api/_as_gen/matplotlib.pyplot.quiver.html.
    show : bool
        Whether show the figure.

    Returns
    -------
    result : tuple
        The ``dx``, ``dy`` values (normalized to unit length in "quiver" mode
        when they contain no NaN).
    """
    print('plot vector field ...')

    if plot_style is None:
        plot_style = dict()

    xs = self.resolutions[self.x_var]
    ys = self.resolutions[self.y_var]
    X, Y = np.meshgrid(xs, ys)

    # dx
    try:
        dx = self.get_f_dx()(X, Y)
    except TypeError as e:
        raise errors.ModelUseError(
            'Missing variables. Please check and set missing '
            'variables to "fixed_vars".') from e

    # dy
    try:
        dy = self.get_f_dy()(X, Y)
    except TypeError as e:
        raise errors.ModelUseError(
            'Missing variables. Please check and set missing '
            'variables to "fixed_vars".') from e

    has_nan = np.isnan(dx).any() or np.isnan(dy).any()

    # vector field
    if plot_method == 'quiver':
        styles = dict()
        styles['units'] = plot_style.get('units', 'xy')
        # forward the other documented quiver keywords only when provided
        for key in ('color', 'angles', 'scale'):
            if key in plot_style:
                styles[key] = plot_style[key]
        if not has_nan:
            # draw direction only: every arrow has unit length
            speed = np.sqrt(dx ** 2 + dy ** 2)
            dx = dx / speed
            dy = dy / speed
        plt.quiver(X, Y, dx, dy, **styles)

    elif plot_method == 'streamplot':
        styles = dict()
        styles['arrowsize'] = plot_style.get('arrowsize', 1.2)
        styles['density'] = plot_style.get('density', 1)
        styles['color'] = plot_style.get('color', 'thistle')
        linewidth = plot_style.get('linewidth', None)
        if (linewidth is None) and (not has_nan):
            min_width = plot_style.get('min_width', 0.5)
            max_width = plot_style.get('max_width', 5.5)
            speed = np.sqrt(dx ** 2 + dy ** 2)
            max_speed = speed.max()
            # line width grows with the local speed; skip when the field is
            # identically zero to avoid a 0/0 division
            if max_speed > 0:
                linewidth = min_width + max_width * speed / max_speed
        plt.streamplot(X, Y, dx, dy, linewidth=linewidth, **styles)

    else:
        raise ValueError(
            f'Unknown plot_method "{plot_method}", only supports "quiver" and "streamplot".')

    plt.xlabel(self.x_var)
    plt.ylabel(self.y_var)

    if show:
        plt.show()

    return dx, dy
def plot_nullcline(self, numerical_setting=None, show=False):
    """Plot the nullcline.

    Parameters
    ----------
    numerical_setting : dict, optional
        Set the numerical method for solving nullclines, keyed by the name of
        the derivative function. For each function setting, it contains the
        following keywords:

            coords
                The coordination setting, it can be 'var1-var2' (which means
                for each possible value 'var1' the optimizer method will search
                the zero root of 'var2') or 'var2-var1' (which means iterate
                each 'var2' and get the optimization results of 'var1').
            plot
                It can be 'scatter' (default) or 'line'. It applies to the
                nullcline of that function's variable when the nullcline is
                obtained by numerical optimization.
    show : bool
        Whether show the figure.

    Returns
    -------
    values : dict
        A dict with the format of ``{func1: (x_val, y_val), func2: (x_val, y_val)}``.
    """
    print('plot nullcline ...')

    if numerical_setting is None:
        numerical_setting = dict()
    x_setting = numerical_setting.get(self.x_eq_group.func_name, {})
    y_setting = numerical_setting.get(self.y_eq_group.func_name, {})
    x_coords = x_setting.get('coords', self.x_var + '-' + self.y_var)
    y_coords = y_setting.get('coords', self.x_var + '-' + self.y_var)
    x_plot_style = x_setting.get('plot', 'scatter')
    y_plot_style = y_setting.get('plot', 'scatter')

    xs = self.resolutions[self.x_var]
    ys = self.resolutions[self.y_var]

    # Nullcline of the y variable (zero set of the y equation)
    y_style = dict(color='cornflowerblue', alpha=.7)
    y_by_x = self.get_y_by_x_in_y_eq()
    if y_by_x['status'] == 'sympy_success':
        try:
            y_values_in_y_eq = y_by_x['f'](xs)
        except TypeError as e:
            raise errors.ModelUseError(
                'Missing variables. Please check and set missing '
                'variables to "fixed_vars".') from e
        x_values_in_y_eq = xs
        plt.plot(xs, y_values_in_y_eq, **y_style, label=f"{self.y_var} nullcline")
    else:
        x_by_y = self.get_x_by_y_in_y_eq()
        if x_by_y['status'] == 'sympy_success':
            try:
                x_values_in_y_eq = x_by_y['f'](ys)
            except TypeError as e:
                raise errors.ModelUseError(
                    'Missing variables. Please check and set missing '
                    'variables to "fixed_vars".') from e
            y_values_in_y_eq = ys
            plt.plot(x_values_in_y_eq, ys, **y_style, label=f"{self.y_var} nullcline")
        else:
            # optimization results; plotted with the y equation's own style
            optimizer = self.get_f_optimize_y_nullcline(y_coords)
            x_values_in_y_eq, y_values_in_y_eq = optimizer()

            if y_plot_style == 'scatter':
                plt.plot(x_values_in_y_eq, y_values_in_y_eq, '.', **y_style,
                         label=f"{self.y_var} nullcline")
            elif y_plot_style == 'line':
                plt.plot(x_values_in_y_eq, y_values_in_y_eq, **y_style,
                         label=f"{self.y_var} nullcline")
            else:
                raise ValueError(f'Unknown plot style: {y_plot_style}')

    # Nullcline of the x variable (zero set of the x equation)
    x_style = dict(color='lightcoral', alpha=.7)
    y_by_x = self.get_y_by_x_in_x_eq()
    if y_by_x['status'] == 'sympy_success':
        try:
            y_values_in_x_eq = y_by_x['f'](xs)
        except TypeError as e:
            raise errors.ModelUseError(
                'Missing variables. Please check and set missing '
                'variables to "fixed_vars".') from e
        x_values_in_x_eq = xs
        plt.plot(xs, y_values_in_x_eq, **x_style, label=f"{self.x_var} nullcline")
    else:
        x_by_y = self.get_x_by_y_in_x_eq()
        if x_by_y['status'] == 'sympy_success':
            try:
                x_values_in_x_eq = x_by_y['f'](ys)
            except TypeError as e:
                raise errors.ModelUseError(
                    'Missing variables. Please check and set missing '
                    'variables to "fixed_vars".') from e
            y_values_in_x_eq = ys
            plt.plot(x_values_in_x_eq, ys, **x_style, label=f"{self.x_var} nullcline")
        else:
            # optimization results; plotted with the x equation's own style
            optimizer = self.get_f_optimize_x_nullcline(x_coords)
            x_values_in_x_eq, y_values_in_x_eq = optimizer()

            if x_plot_style == 'scatter':
                plt.plot(x_values_in_x_eq, y_values_in_x_eq, '.', **x_style,
                         label=f"{self.x_var} nullcline")
            elif x_plot_style == 'line':
                plt.plot(x_values_in_x_eq, y_values_in_x_eq, **x_style,
                         label=f"{self.x_var} nullcline")
            else:
                raise ValueError(f'Unknown plot style: {x_plot_style}')

    # finally
    plt.xlabel(self.x_var)
    plt.ylabel(self.y_var)
    scale = (self.options.lim_scale - 1.) / 2
    plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
    plt.ylim(*utils.rescale(self.target_vars[self.y_var], scale=scale))
    plt.legend()

    if show:
        plt.show()

    return {
        self.x_eq_group.func_name: (x_values_in_x_eq, y_values_in_x_eq),
        self.y_eq_group.func_name: (x_values_in_y_eq, y_values_in_y_eq),
    }
def __init__(self, steps, monitors=None, name=None, host=None, show_code=False):
    """Initialize the dynamic system.

    Parameters
    ----------
    steps : callable, list, tuple, dict
        The step function(s): a single function, a non-empty list/tuple of
        functions (keyed by their ``__name__``), or a dict of name -> function.
    monitors : list, optional
        The names of the attributes to monitor; each must exist on ``self``.
    name : str, optional
        A valid Python identifier; auto-generated as ``DS<n>`` when omitted.
    host : object, optional
        The object holding the data; defaults to ``self``.
    show_code : bool
        Stored on ``self.show_code``.

    Raises
    ------
    errors.ModelDefError
        For unsupported ``steps``, unknown monitored items, or an invalid
        ``target_backend`` class attribute.
    errors.ModelUseError
        When ``name`` is not a valid identifier.
    """
    global _DynamicSystem_NO

    # host of the data
    # ----------------
    if host is None:
        host = self
    self.host = host

    # model
    # -----
    if callable(steps):
        self.steps = OrderedDict([(steps.__name__, steps)])
    elif isinstance(steps, (list, tuple)) and len(steps) > 0 and callable(steps[0]):
        self.steps = OrderedDict([(step.__name__, step) for step in steps])
    elif isinstance(steps, dict):
        self.steps = steps
    else:
        raise errors.ModelDefError(
            f'Unknown model type: {type(steps)}. Currently, BrainPy '
            f'only supports: function, list/tuple/dict of functions.')

    # name
    # ----
    if name is None:
        name = f'DS{_DynamicSystem_NO}'
        _DynamicSystem_NO += 1
    if not name.isidentifier():
        raise errors.ModelUseError(
            f'"{name}" isn\'t a valid identifier according to Python '
            f'language definition. Please choose another name.')
    self.name = name

    # monitors
    # ---------
    if monitors is None:
        monitors = []
    self.mon = Monitor(monitors)
    for var in self.mon['vars']:
        if not hasattr(self, var):
            raise errors.ModelDefError(
                f"Item {var} isn't defined in model {self}, "
                f"so it can not be monitored.")

    # runner
    # -------
    self.runner = backend.get_node_runner()(pop=self)

    # run function
    # ------------
    self.run_func = None

    # others
    # ---
    self.show_code = show_code
    if self.target_backend is None:
        raise errors.ModelDefError('Must define "target_backend".')
    if isinstance(self.target_backend, str):
        self._target_backend = (self.target_backend,)
    elif isinstance(self.target_backend, (tuple, list)):
        # every entry (not just the first) must be a backend name, and an
        # empty sequence would leave the model without any usable backend
        if (len(self.target_backend) == 0
                or not all(isinstance(b, str) for b in self.target_backend)):
            raise errors.ModelDefError(
                '"target_backend" must be a list/tuple of string.')
        self._target_backend = tuple(self.target_backend)
    else:
        raise errors.ModelDefError(
            f'Unknown setting of "target_backend": {self.target_backend}')
def set(backend=None, module_or_operations=None, node_runner=None, net_runner=None, dt=None):
    """Basic backend setting function.

    Using this function, users can set the backend they prefer. For backend
    which is unknown, users can provide `module_or_operations` to specify
    the operations needed. Also, users can customize the node runner, or the
    network runner, by providing the `node_runner` or `net_runner` keywords.
    The default numerical precision `dt` can also be set by this function.

    Parameters
    ----------
    backend : str
        The backend name.
    module_or_operations : module, dict, optional
        The module or the a dict containing necessary operations.
    node_runner : GeneralNodeRunner
        An instance of node runner.
    net_runner : GeneralNetRunner
        An instance of network runner.
    dt : float
        The numerical precision.

    Raises
    ------
    errors.ModelUseError
        If the backend is unknown and no operations are given, or if
        ``module_or_operations`` is neither a module nor a dict. In the
        latter case the previously active backend is left untouched.
    """
    if dt is not None:
        set_dt(dt)

    if (backend is None) or (_backend == backend):
        return

    global_vars = globals()
    if backend == 'numpy':
        from .operators import bk_numpy

        node_runner = GeneralNodeRunner if node_runner is None else node_runner
        net_runner = GeneralNetRunner if net_runner is None else net_runner
        module_or_operations = bk_numpy if module_or_operations is None else module_or_operations

    elif backend == 'pytorch':
        from .operators import bk_pytorch

        node_runner = GeneralNodeRunner if node_runner is None else node_runner
        net_runner = GeneralNetRunner if net_runner is None else net_runner
        module_or_operations = bk_pytorch if module_or_operations is None else module_or_operations

    elif backend == 'tensorflow':
        from .operators import bk_tensorflow

        node_runner = GeneralNodeRunner if node_runner is None else node_runner
        net_runner = GeneralNetRunner if net_runner is None else net_runner
        module_or_operations = bk_tensorflow if module_or_operations is None else module_or_operations

    elif backend == 'numba':
        from .operators import bk_numba_cpu
        from .runners.numba_cpu_runner import NumbaCPUNodeRunner, set_numba_profile

        node_runner = NumbaCPUNodeRunner if node_runner is None else node_runner
        module_or_operations = bk_numba_cpu if module_or_operations is None else module_or_operations
        set_numba_profile(parallel=False)

    elif backend == 'numba-parallel':
        from .operators import bk_numba_cpu
        from .runners.numba_cpu_runner import NumbaCPUNodeRunner, set_numba_profile

        node_runner = NumbaCPUNodeRunner if node_runner is None else node_runner
        module_or_operations = bk_numba_cpu if module_or_operations is None else module_or_operations
        set_numba_profile(parallel=True)

    elif backend == 'numba-cuda':
        from .operators import bk_numba_cuda
        from .runners.numba_cuda_runner import NumbaCudaNodeRunner

        node_runner = NumbaCudaNodeRunner if node_runner is None else node_runner
        module_or_operations = bk_numba_cuda if module_or_operations is None else module_or_operations

    elif backend == 'jax':
        from .operators import bk_jax
        from .runners.jax_runner import JaxRunner

        node_runner = JaxRunner if node_runner is None else node_runner
        module_or_operations = bk_jax if module_or_operations is None else module_or_operations

    else:
        if module_or_operations is None:
            raise errors.ModelUseError(f'Backend "{backend}" is unknown, '
                                       f'please provide the "module_or_operations" '
                                       f'to specify the necessary computation units.')
        # like the other General* backends, pair the general node runner
        # with the general network runner unless the user overrides it
        node_runner = GeneralNodeRunner if node_runner is None else node_runner
        net_runner = GeneralNetRunner if net_runner is None else net_runner

    # Validate before mutating module state, so that a bad argument does not
    # leave `_backend`/`_node_runner`/`_net_runner` switched to a backend
    # whose operations were never installed.
    if not isinstance(module_or_operations, (ModuleType, dict)):
        raise errors.ModelUseError('"module_or_operations" must be a module '
                                   'or a dict of operations.')

    global_vars['_backend'] = backend
    global_vars['_node_runner'] = node_runner
    global_vars['_net_runner'] = net_runner

    if isinstance(module_or_operations, ModuleType):
        set_ops_from_module(module_or_operations)
    else:
        set_ops(**module_or_operations)
def format_net_level_inputs(inputs, run_length):
    """Format the inputs of a network.

    Parameters
    ----------
    inputs : tuple
        The inputs. Either a single ``(target, key, value, [operation])``
        tuple or a tuple/list of them.
    run_length : int
        The running length. A value whose leading dimension equals this
        length is treated as a per-step ('iter') input.

    Returns
    -------
    formatted_input : dict
        The formatted input: ``{target_name: [(key, value, operation, data_type), ...]}``
        where ``data_type`` is ``'fix'`` or ``'iter'``.

    Raises
    ------
    errors.ModelUseError
        For a malformed input structure, an unsupported operation, a
        non-string key, or a key the target does not define.
    KeyError
        When an input target is not a ``DynamicSystem``.
    """
    from brainpy.simulation.dynamic_system import DynamicSystem

    # 1. format the inputs to standard
    #    formats and check the inputs
    if not isinstance(inputs, (tuple, list)):
        raise errors.ModelUseError('"inputs" must be a tuple/list.')
    if len(inputs) > 0 and not isinstance(inputs[0], (list, tuple)):
        if isinstance(inputs[0], DynamicSystem):
            inputs = [inputs]
        else:
            raise errors.ModelUseError('Unknown input structure. Only supports '
                                       '"(target, key, value, [operation])".')
    for inp in inputs:
        if not 3 <= len(inp) <= 4:
            raise errors.ModelUseError('For each target, you must specify '
                                       '"(target, key, value, [operation])".')
        if len(inp) == 4 and inp[3] not in SUPPORTED_INPUT_OPS:
            raise errors.ModelUseError(f'Input operation only supports '
                                       f'"{SUPPORTED_INPUT_OPS}", '
                                       f'not "{inp[3]}".')

    # 2. format inputs
    formatted_inputs = {}
    for inp in inputs:
        # target
        target = inp[0]
        if not isinstance(target, DynamicSystem):
            raise KeyError(f'Unknown input target: {str(target)}')
        target_name = target.name

        # key
        key = inp[1]
        if not isinstance(key, str):
            raise errors.ModelUseError('For each input, input[1] must be a string '
                                       'to specify variable of the target.')
        if not hasattr(target, key):
            raise errors.ModelUseError(f'Target {target} does not have key {key}. '
                                       f'So, it can not assign input to it.')

        # value and data type
        val = inp[2]
        if isinstance(val, (int, float)):
            data_type = 'fix'
        else:
            shape = ops.shape(val)
            # 0-d arrays / non-Python scalars have an empty shape and are
            # constant inputs; only a leading axis of `run_length` is per-step
            if len(shape) > 0 and shape[0] == run_length:
                data_type = 'iter'
            else:
                data_type = 'fix'

        # operation
        operation = inp[3] if len(inp) == 4 else '+'

        # final result
        formatted_inputs.setdefault(target_name, []).append(
            (key, val, operation, data_type))
    return formatted_inputs
def format_pop_level_inputs(inputs, host, mon_length):
    """Format the inputs of a population.

    Parameters
    ----------
    inputs : tuple, list
        The inputs of the population. Either a single
        ``(key, value, [operation])`` tuple or a tuple/list of them.
        ``None`` means no inputs.
    host : Population
        The host which contains all data.
    mon_length : int
        The monitor length. A value whose leading dimension equals this
        length is treated as a per-step ('iter') input.

    Returns
    -------
    formatted_inputs : tuple, list
        The formatted inputs of the population, as a list of
        ``(key, value, operation, data_type)`` with ``data_type`` in
        ``{'fix', 'iter'}``.

    Raises
    ------
    errors.ModelUseError
        For a malformed input structure, an unsupported operation, a
        non-string key, or a key the host does not define.
    """
    if inputs is None:
        inputs = []
    if not isinstance(inputs, (tuple, list)):
        raise errors.ModelUseError('"inputs" must be a tuple/list.')
    if len(inputs) > 0 and not isinstance(inputs[0], (list, tuple)):
        if isinstance(inputs[0], str):
            inputs = [inputs]
        else:
            raise errors.ModelUseError('Unknown input structure, only support inputs '
                                       'with format of "(key, value, [operation])".')
    for inp in inputs:
        if not 2 <= len(inp) <= 3:
            raise errors.ModelUseError('For each target, you must specify "(key, value, [operation])".')
        if len(inp) == 3 and inp[2] not in SUPPORTED_INPUT_OPS:
            raise errors.ModelUseError(f'Input operation only supports '
                                       f'"{SUPPORTED_INPUT_OPS}", '
                                       f'not "{inp[2]}".')

    # format inputs
    # -------------
    formatted_inputs = []
    for inp in inputs:
        # key
        if not isinstance(inp[0], str):
            raise errors.ModelUseError('For each input, input[0] must be a string '
                                       'to specify variable of the target.')
        key = inp[0]
        if not hasattr(host, key):
            raise errors.ModelUseError(f'Input target key "{key}" is not defined in {host}.')

        # value and data type
        val = inp[1]
        if isinstance(val, (int, float)):
            data_type = 'fix'
        else:
            shape = ops.shape(val)
            # 0-d arrays / non-Python scalars have an empty shape and are
            # constant inputs; only a leading axis of `mon_length` is per-step
            if len(shape) > 0 and shape[0] == mon_length:
                data_type = 'iter'
            else:
                data_type = 'fix'

        # operation (already validated against SUPPORTED_INPUT_OPS above)
        operation = inp[2] if len(inp) == 3 else '+'

        # input
        formatted_inputs.append((key, val, operation, data_type))
    return formatted_inputs
def __init__(self, integrals, fast_vars, slow_vars, fixed_vars=None,
             pars_update=None, numerical_resolution=0.1, options=None):
    """Set up a fast-slow bifurcation analysis.

    Parameters
    ----------
    integrals
        The integral function(s), converted with
        ``utils.transform_integrals_to_model``.
    fast_vars : dict
        ``{var_name: [min, max]}`` for the fast subsystem (one or two entries).
    slow_vars : dict
        ``{var_name: [min, max]}`` for the slow variables (at most two). These
        are removed from the model's variables and treated as parameters.
    fixed_vars : dict, optional
        ``{var_name: value}`` of variables held fixed.
    pars_update : dict, optional
        ``{par_name: value}``; every key must be a model parameter or in scope.
    numerical_resolution : float
        Forwarded to the underlying analyzer.
    options : optional
        Forwarded to the underlying analyzer.

    Raises
    ------
    errors.ModelUseError
        For badly-typed arguments, too many fast/slow variables, or unknown
        parameters in ``pars_update``.
    """
    # check "model"
    self.model = utils.transform_integrals_to_model(integrals)

    # check "fast_vars"
    if not isinstance(fast_vars, dict):
        raise errors.ModelUseError('"fast_vars" must a dict with the format of: '
                                   '{"Var A": [A_min, A_max],'
                                   ' "Var B": [B_min, B_max]}')
    self.fast_vars = fast_vars
    if len(fast_vars) > 2:
        raise errors.ModelUseError("FastSlowBifurcation can only analyze the system with at "
                                   "most two-variable fast subsystem.")

    # check "slow_vars"
    if not isinstance(slow_vars, dict):
        raise errors.ModelUseError('"slow_vars" must a dict with the format of: '
                                   '{"Variable A": [A_min, A_max], '
                                   '"Variable B": [B_min, B_max]}')
    self.slow_vars = slow_vars
    if len(slow_vars) > 2:
        raise errors.ModelUseError("FastSlowBifurcation can only analyze the system with at "
                                   "most two-variable slow subsystem.")
    # slow variables are frozen and swept like parameters
    for key in self.slow_vars:
        self.model.variables.remove(key)
        self.model.parameters.append(key)

    # check "fixed_vars"
    if fixed_vars is None:
        fixed_vars = dict()
    if not isinstance(fixed_vars, dict):
        raise errors.ModelUseError('"fixed_vars" must be a dict the format of: '
                                   '{"Variable A": A_value, "Variable B": B_value}')
    self.fixed_vars = fixed_vars

    # check "pars_update"
    if pars_update is None:
        pars_update = dict()
    if not isinstance(pars_update, dict):
        raise errors.ModelUseError('"pars_update" must be a dict the format of: '
                                   '{"Par A": A_value, "Par B": B_value}')
    for key in pars_update.keys():
        if (key not in self.model.scopes) and (key not in self.model.parameters):
            raise errors.ModelUseError(f'"{key}" is not a valid parameter in "{integrals}" model. ')
    self.pars_update = pars_update

    # bifurcation analysis, dispatched on the fast subsystem dimension
    if len(self.fast_vars) == 1:
        analyzer_cls = _FastSlow1D
    elif len(self.fast_vars) == 2:
        analyzer_cls = _FastSlow2D
    else:
        raise errors.ModelUseError(f'Cannot analyze {len(fast_vars)} dimensional fast system.')
    self.analyzer = analyzer_cls(model_or_integrals=self.model,
                                 fast_vars=fast_vars,
                                 slow_vars=slow_vars,
                                 fixed_vars=fixed_vars,
                                 pars_update=pars_update,
                                 numerical_resolution=numerical_resolution,
                                 options=options)
def __init__(self, integrals, target_pars, target_vars, fixed_vars=None,
             pars_update=None, numerical_resolution=0.1, options=None):
    """Set up a bifurcation analysis.

    Parameters
    ----------
    integrals
        The integral function(s), converted with
        ``utils.transform_integrals_to_model``.
    target_pars : dict
        ``{par_name: [min, max]}`` of the swept parameters (at most two).
    target_vars : dict
        ``{var_name: [min, max]}`` of the analyzed variables (one or two).
    fixed_vars : dict, optional
        ``{var_name: value}`` of variables held fixed.
    pars_update : dict, optional
        ``{par_name: value}``; every key must be a model parameter or in scope.
    numerical_resolution : float
        Forwarded to the underlying analyzer.
    options : optional
        Forwarded to the underlying analyzer.

    Raises
    ------
    errors.ModelUseError
        For badly-typed arguments, too many parameters, unknown parameters in
        ``pars_update``, or an unsupported number of target variables.
    """
    # check "model"
    self.model = utils.transform_integrals_to_model(integrals)

    # check "target_pars"
    if not isinstance(target_pars, dict):
        raise errors.ModelUseError('"target_pars" must a dict with the format of: '
                                   '{"Parameter A": [A_min, A_max],'
                                   ' "Parameter B": [B_min, B_max]}')
    self.target_pars = target_pars
    if len(target_pars) > 2:
        raise errors.ModelUseError("The number of parameters in bifurcation "
                                   "analysis cannot exceed 2.")

    # check "fixed_vars"
    if fixed_vars is None:
        fixed_vars = dict()
    if not isinstance(fixed_vars, dict):
        raise errors.ModelUseError('"fixed_vars" must be a dict the format of: '
                                   '{"Variable A": A_value, "Variable B": B_value}')
    self.fixed_vars = fixed_vars

    # check "target_vars"
    if not isinstance(target_vars, dict):
        raise errors.ModelUseError('"target_vars" must a dict with the format of: '
                                   '{"Variable A": [A_min, A_max], "Variable B": [B_min, B_max]}')
    self.target_vars = target_vars

    # check "pars_update"
    if pars_update is None:
        pars_update = dict()
    if not isinstance(pars_update, dict):
        raise errors.ModelUseError('"pars_update" must be a dict the format of: '
                                   '{"Par A": A_value, "Par B": B_value}')
    for key in pars_update.keys():
        if (key not in self.model.scopes) and (key not in self.model.parameters):
            raise errors.ModelUseError(f'"{key}" is not a valid parameter in "{integrals}". ')
    self.pars_update = pars_update

    # bifurcation analysis, dispatched on the number of target variables
    if len(self.target_vars) == 1:
        analyzer_cls = _Bifurcation1D
    elif len(self.target_vars) == 2:
        analyzer_cls = _Bifurcation2D
    else:
        raise errors.ModelUseError(
            f'Cannot analyze {len(self.target_vars)} dimensional system: {self.target_vars}')
    self.analyzer = analyzer_cls(model_or_integrals=self.model,
                                 target_pars=target_pars,
                                 target_vars=target_vars,
                                 fixed_vars=fixed_vars,
                                 pars_update=pars_update,
                                 numerical_resolution=numerical_resolution,
                                 options=options)
def plot_bifurcation(self, show=False):
    """Plot the fixed points of a 1D system against one or two parameters.

    Every fixed point found on the parameter grid (``self.resolutions``) is
    classified with ``stability.stability_analysis`` on its derivative and
    drawn in that stability type's style. One parameter gives a 2D scatter,
    two parameters give a 3D scatter.

    Parameters
    ----------
    show : bool
        Whether show the figure.

    Returns
    -------
    container : dict
        ``{stability_type: {'p': [...], 'x': [...]}}`` for one parameter, or
        ``{stability_type: {'p0': [...], 'p1': [...], 'x': [...]}}`` for two.

    Raises
    ------
    errors.ModelUseError
        When more than two parameters are targeted.
    """
    print('plot bifurcation ...')

    f_fixed_point = self.get_f_fixed_point()
    f_dfdx = self.get_f_dfdx()

    if len(self.target_pars) == 1:
        container = {c: {'p': [], 'x': []}
                     for c in stability.get_1d_stability_types()}

        # fixed point
        par_a = self.dpar_names[0]
        for p in self.resolutions[par_a]:
            for x in f_fixed_point(p):
                fp_type = stability.stability_analysis(f_dfdx(x, p))
                container[fp_type]['p'].append(p)
                container[fp_type]['x'].append(x)

        # visualization
        plt.figure(self.x_var)
        for fp_type, points in container.items():
            if len(points['x']):
                plot_style = stability.plot_scheme[fp_type]
                plt.plot(points['p'], points['x'], '.', **plot_style, label=fp_type)
        plt.xlabel(par_a)
        plt.ylabel(self.x_var)
        plt.legend()
        if show:
            plt.show()

    elif len(self.target_pars) == 2:
        container = {c: {'p0': [], 'p1': [], 'x': []}
                     for c in stability.get_1d_stability_types()}

        # fixed point
        for p0 in self.resolutions[self.dpar_names[0]]:
            for p1 in self.resolutions[self.dpar_names[1]]:
                for x in f_fixed_point(p0, p1):
                    fp_type = stability.stability_analysis(f_dfdx(x, p0, p1))
                    container[fp_type]['p0'].append(p0)
                    container[fp_type]['p1'].append(p1)
                    container[fp_type]['x'].append(x)

        # visualization
        fig = plt.figure(self.x_var)
        # `Figure.gca(projection=...)` was deprecated in Matplotlib 3.4 and
        # removed in 3.6; `add_subplot` is the supported way to get 3D axes.
        ax = fig.add_subplot(111, projection='3d')
        for fp_type, points in container.items():
            if len(points['x']):
                plot_style = stability.plot_scheme[fp_type]
                ax.scatter(points['p0'], points['p1'], points['x'],
                           **plot_style, label=fp_type)
        ax.set_xlabel(self.dpar_names[0])
        ax.set_ylabel(self.dpar_names[1])
        ax.set_zlabel(self.x_var)
        ax.grid(True)
        ax.legend()
        if show:
            plt.show()

    else:
        raise errors.ModelUseError(f'Cannot visualize co-dimension {len(self.target_pars)} '
                                   f'bifurcation.')
    return container