Example #1
0
    def plot_trajectory(self,
                        initials,
                        duration,
                        plot_duration=None,
                        axes='v-v',
                        show=False):
        """Plot trajectories according to the settings.

        Parameters
        ----------
        initials : list, tuple, dict
            The initial value setting of the targets. It can be a tuple/list of floats to specify
            each value of dynamical variables (for example, ``(a, b)``). It can also be a
            tuple/list of tuple to specify multiple initial values (for example,
            ``[(a1, b1), (a2, b2)]``). A dict ``{var: value}`` or a list of such
            dicts is also accepted.
        duration : int, float, tuple, list
            The running duration. Same with the ``duration`` in ``NeuGroup.run()``.
            It can be a int/float (``t_end``) to specify the same running end time,
            or it can be a tuple/list of int/float (``(t_start, t_end)``) to specify
            the start and end simulation time. Or, it can be a list of tuple
            (``[(t1_start, t1_end), (t2_start, t2_end)]``) to specify the specific
            start and end simulation time for each initial value.
        plot_duration : tuple, list, optional
            The duration to plot, in the same absolute time as ``duration``.
            It can be a tuple with ``(start, end)``. It can also be a list of
            tuple ``[(start1, end1), (start2, end2)]`` to specify the plot
            duration for each initial value running. Defaults to ``duration``.
        axes : str
            The axes to plot. It can be:

                 - 'v-v'
                        Plot the trajectory in the 'x_var'-'y_var' axis.
                 - 't-v'
                        Plot the trajectory in the 'time'-'var' axis.
        show : bool
            Whether show or not.

        Raises
        ------
        errors.ModelUseError
            If ``axes`` is unknown, or if the number of ``duration`` /
            ``plot_duration`` entries differs from the number of initial values.
        ValueError
            If ``initials`` has an unsupported structure.
        """

        print('plot trajectory ...')

        if axes not in ['v-v', 't-v']:
            raise errors.ModelUseError(
                f'Unknown axes "{axes}", only support "v-v" and "t-v".')

        # 1. format the initial values into a list of {var_name: value} dicts
        if isinstance(initials, dict):
            initials = [initials]
        elif isinstance(initials, (list, tuple)) and len(initials) > 0:
            first = initials[0]
            if isinstance(first, (int, float)):
                # a single initial point given positionally, e.g. ``(a, b)``
                initials = [{
                    self.dvar_names[i]: v
                    for i, v in enumerate(initials)
                }]
            elif isinstance(first, dict):
                pass  # already a sequence of {var_name: value} dicts
            elif isinstance(first, (tuple, list)) and isinstance(
                    first[0], (int, float)):
                # several initial points, e.g. ``[(a1, b1), (a2, b2)]``
                initials = [{
                    self.dvar_names[i]: v
                    for i, v in enumerate(init)
                } for init in initials]
            else:
                raise ValueError(
                    f'Unknown structure of "initials": {initials}')
        else:
            raise ValueError(f'Unknown structure of "initials": {initials}')
        num_init = len(initials)

        # 2. format the running duration into one (t_start, t_end) per initial value
        if isinstance(duration, (int, float)):
            duration = [(0, duration) for _ in range(num_init)]
        elif isinstance(duration[0], (int, float)):
            duration = [duration for _ in range(num_init)]
        elif len(duration) != num_init:
            raise errors.ModelUseError(
                f'"duration" has {len(duration)} entries, but there are '
                f'{num_init} initial values.')

        # 3. format the plot duration into one (start, end) per initial value
        if plot_duration is None:
            plot_duration = duration
        if isinstance(plot_duration[0], (int, float)):
            plot_duration = [plot_duration for _ in range(num_init)]
        elif len(plot_duration) != num_init:
            raise errors.ModelUseError(
                f'"plot_duration" has {len(plot_duration)} entries, but there '
                f'are {num_init} initial values.')

        # 5. run the network
        for init_i, initial in enumerate(initials):
            traj_group = Trajectory(size=1,
                                    integrals=self.model.integrals,
                                    target_vars=initial,
                                    fixed_vars=self.fixed_vars,
                                    pars_update=self.pars_update,
                                    scope=self.model.scopes)

            #   5.2 run the model
            traj_group.run(
                duration=duration[init_i],
                report=False,
            )

            #   5.3 legend
            legend = f'$traj_{init_i}$: ' + ', '.join(
                f'{key}={initial[key]}' for key in self.dvar_names)

            #   5.4 trajectory
            # The run starts recording at its own start time, so the plot
            # window is converted to indices relative to that start; using
            # absolute times would slice past the end of the recorded data
            # whenever ``t_start != 0``.
            dt = backend.get_dt()
            run_start = duration[init_i][0]
            start = max(int((plot_duration[init_i][0] - run_start) / dt), 0)
            end = int((plot_duration[init_i][1] - run_start) / dt)

            #   5.5 visualization
            if axes == 'v-v':
                lines = plt.plot(traj_group.mon[self.x_var][start:end, 0],
                                 traj_group.mon[self.y_var][start:end, 0],
                                 label=legend)
                utils.add_arrow(lines[0])
            else:
                plt.plot(traj_group.mon.ts[start:end],
                         traj_group.mon[self.x_var][start:end, 0],
                         label=legend + f', {self.x_var}')
                plt.plot(traj_group.mon.ts[start:end],
                         traj_group.mon[self.y_var][start:end, 0],
                         label=legend + f', {self.y_var}')

        # 6. visualization
        if axes == 'v-v':
            plt.xlabel(self.x_var)
            plt.ylabel(self.y_var)
            scale = (self.options.lim_scale - 1.) / 2
            plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
            plt.ylim(*utils.rescale(self.target_vars[self.y_var], scale=scale))
            plt.legend()
        else:
            plt.legend(title='Initial values')

        if show:
            plt.show()
Example #2
0
    def plot_vector_field(self,
                          plot_method='streamplot',
                          plot_style=None,
                          show=False):
        """Plot the vector field.

        Parameters
        ----------
        plot_method : str
            The method to plot the vector field.
            It can be "streamplot" or "quiver".
        plot_style : dict, optional
            The style for vector field plotting.

            - For ``plot_method="streamplot"``, it can set the keywords "density",
              "linewidth", "color", "arrowsize". When "linewidth" is not given
              (and the field has no NaN), the line width scales with the local
              speed from "min_width" (default 0.5) up to
              "min_width + max_width" ("max_width" default 5.5). More settings please check
              https://matplotlib.org/api/_as_gen/matplotlib.pyplot.streamplot.html.
            - For ``plot_method="quiver"``, it can set the keywords "color",
              "units", "angles", "scale". Arrows are normalized to unit length
              when the field has no NaN. More settings please check
              https://matplotlib.org/api/_as_gen/matplotlib.pyplot.quiver.html.
        show : bool
            Whether show the figure.

        Returns
        -------
        result : tuple
            The raw ``dx``, ``dy`` values evaluated on the grid (never normalized).

        Raises
        ------
        errors.ModelUseError
            If evaluating the derivative functions fails with ``TypeError``
            (e.g. some variables are missing from "fixed_vars").
        ValueError
            If ``plot_method`` is unknown.
        """
        print('plot vector field ...')

        if plot_style is None:
            plot_style = dict()

        xs = self.resolutions[self.x_var]
        ys = self.resolutions[self.y_var]
        X, Y = np.meshgrid(xs, ys)

        # dx
        try:
            dx = self.get_f_dx()(X, Y)
        except TypeError as err:
            raise errors.ModelUseError(
                'Missing variables. Please check and set missing '
                'variables to "fixed_vars".') from err

        # dy
        try:
            dy = self.get_f_dy()(X, Y)
        except TypeError as err:
            raise errors.ModelUseError(
                'Missing variables. Please check and set missing '
                'variables to "fixed_vars".') from err

        has_nan = np.isnan(dx).any() or np.isnan(dy).any()

        # vector field
        if plot_method == 'quiver':
            styles = dict()
            styles['units'] = plot_style.get('units', 'xy')
            for key in ('color', 'angles', 'scale'):
                if key in plot_style:
                    styles[key] = plot_style[key]
            u, v = dx, dy
            if not has_nan:
                # normalize so arrows show direction only; zero vectors stay
                # zero instead of becoming 0/0 = NaN
                speed = np.sqrt(dx ** 2 + dy ** 2)
                speed = np.where(speed == 0., 1., speed)
                u, v = dx / speed, dy / speed
            plt.quiver(X, Y, u, v, **styles)
        elif plot_method == 'streamplot':
            styles = dict()
            styles['arrowsize'] = plot_style.get('arrowsize', 1.2)
            styles['density'] = plot_style.get('density', 1)
            styles['color'] = plot_style.get('color', 'thistle')
            linewidth = plot_style.get('linewidth', None)
            if (linewidth is None) and (not has_nan):
                min_width = plot_style.get('min_width', 0.5)
                max_width = plot_style.get('max_width', 5.5)
                speed = np.sqrt(dx ** 2 + dy ** 2)
                max_speed = speed.max()
                if max_speed > 0:
                    linewidth = min_width + max_width * speed / max_speed
                else:
                    # a completely still field: avoid dividing by zero
                    linewidth = min_width
            plt.streamplot(X, Y, dx, dy, linewidth=linewidth, **styles)
        else:
            raise ValueError(
                f'Unknown plot_method "{plot_method}", only supports "quiver" and "streamplot".'
            )

        plt.xlabel(self.x_var)
        plt.ylabel(self.y_var)

        if show:
            plt.show()

        return dx, dy
Example #3
0
    def plot_nullcline(self, numerical_setting=None, show=False):
        """Plot the nullcline.

        Parameters
        ----------
        numerical_setting : dict, optional
            Set the numerical method for solving nullclines, keyed by the
            function name of each equation group.
            For each function setting, it contains the following keywords:

                coords
                    The coordination setting, it can be 'var1-var2' (which means
                    for each possible value 'var1' the optimizer method will search
                    the zero root of 'var2') or 'var2-var1' (which means iterate each
                    'var2' and get the optimization results of 'var1').
                plot
                    It can be 'scatter' (default) or 'line'. Only used when the
                    nullcline is solved by the optimizer.

        show : bool
            Whether show the figure.

        Returns
        -------
        values : dict
            A dict with the format of ``{func1: (x_val, y_val), func2: (x_val, y_val)}``.

        Raises
        ------
        errors.ModelUseError
            If evaluating a symbolically solved nullcline fails with
            ``TypeError`` (e.g. variables missing from "fixed_vars").
        ValueError
            If a 'plot' setting is neither 'scatter' nor 'line'.
        """
        print('plot nullcline ...')

        if numerical_setting is None:
            numerical_setting = dict()
        x_setting = numerical_setting.get(self.x_eq_group.func_name, {})
        y_setting = numerical_setting.get(self.y_eq_group.func_name, {})
        x_coords = x_setting.get('coords', self.x_var + '-' + self.y_var)
        y_coords = y_setting.get('coords', self.x_var + '-' + self.y_var)
        x_plot_style = x_setting.get('plot', 'scatter')
        y_plot_style = y_setting.get('plot', 'scatter')

        xs = self.resolutions[self.x_var]
        ys = self.resolutions[self.y_var]

        def _evaluate(f, values):
            # A TypeError from a solved function means it needs variables
            # that were not provided.
            try:
                return f(values)
            except TypeError as err:
                raise errors.ModelUseError(
                    'Missing variables. Please check and set missing '
                    'variables to "fixed_vars".') from err

        def _solve_and_plot(var, get_y_by_x, get_x_by_y, get_optimizer,
                            coords, plot_style, style):
            # Solve one nullcline (symbolic y(x), then symbolic x(y), then the
            # numerical optimizer), plot it and return its (x, y) values.
            label = f"{var} nullcline"
            y_by_x = get_y_by_x()
            if y_by_x['status'] == 'sympy_success':
                y_values = _evaluate(y_by_x['f'], xs)
                plt.plot(xs, y_values, **style, label=label)
                return xs, y_values

            x_by_y = get_x_by_y()
            if x_by_y['status'] == 'sympy_success':
                x_values = _evaluate(x_by_y['f'], ys)
                plt.plot(x_values, ys, **style, label=label)
                return x_values, ys

            # optimization results
            optimizer = get_optimizer(coords)
            x_values, y_values = optimizer()
            if plot_style == 'scatter':
                plt.plot(x_values, y_values, '.', **style, label=label)
            elif plot_style == 'line':
                plt.plot(x_values, y_values, **style, label=label)
            else:
                raise ValueError(f'Unknown plot style: {plot_style}')
            return x_values, y_values

        # Nullcline of the y variable (uses the y equation's settings)
        x_values_in_y_eq, y_values_in_y_eq = _solve_and_plot(
            self.y_var,
            self.get_y_by_x_in_y_eq,
            self.get_x_by_y_in_y_eq,
            self.get_f_optimize_y_nullcline,
            y_coords,
            y_plot_style,
            dict(color='cornflowerblue', alpha=.7))

        # Nullcline of the x variable (uses the x equation's settings)
        x_values_in_x_eq, y_values_in_x_eq = _solve_and_plot(
            self.x_var,
            self.get_y_by_x_in_x_eq,
            self.get_x_by_y_in_x_eq,
            self.get_f_optimize_x_nullcline,
            x_coords,
            x_plot_style,
            dict(color='lightcoral', alpha=.7))

        # finally
        plt.xlabel(self.x_var)
        plt.ylabel(self.y_var)
        scale = (self.options.lim_scale - 1.) / 2
        plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
        plt.ylim(*utils.rescale(self.target_vars[self.y_var], scale=scale))
        plt.legend()
        if show:
            plt.show()

        return {
            self.x_eq_group.func_name: (x_values_in_x_eq, y_values_in_x_eq),
            self.y_eq_group.func_name: (x_values_in_y_eq, y_values_in_y_eq)
        }
Example #4
0
    def __init__(self,
                 steps,
                 monitors=None,
                 name=None,
                 host=None,
                 show_code=False):
        """Initialize the dynamic system.

        Parameters
        ----------
        steps : callable, list/tuple of callable, dict
            The step function(s). A function or a non-empty list/tuple of
            functions is stored in an ``OrderedDict`` keyed by ``__name__``;
            a dict is stored as given.
        monitors : list, optional
            Names to monitor; each must be an attribute of this object.
        name : str, optional
            A valid Python identifier. When omitted, ``DS<n>`` is generated
            from a module-level counter.
        host : object, optional
            The host of the data. Defaults to ``self``.
        show_code : bool
            Stored as ``self.show_code`` (presumably consulted by the runner
            to print generated code — confirm).

        Raises
        ------
        errors.ModelDefError
            If ``steps`` has an unsupported type, a monitored name is not an
            attribute, or ``target_backend`` is missing or malformed.
        errors.ModelUseError
            If ``name`` is not a valid identifier.
        """
        # host of the data
        # ----------------
        if host is None:
            host = self
        self.host = host

        # model
        # -----
        if callable(steps):
            self.steps = OrderedDict([(steps.__name__, steps)])
        elif isinstance(steps, (list, tuple)) and len(steps) > 0 and all(
                callable(step) for step in steps):
            self.steps = OrderedDict([(step.__name__, step) for step in steps])
        elif isinstance(steps, dict):
            self.steps = steps
        else:
            raise errors.ModelDefError(
                f'Unknown model type: {type(steps)}. Currently, BrainPy '
                f'only supports: function, list/tuple/dict of functions.')

        # name
        # ----
        if name is None:
            global _DynamicSystem_NO
            name = f'DS{_DynamicSystem_NO}'
            _DynamicSystem_NO += 1
        if not name.isidentifier():
            raise errors.ModelUseError(
                f'"{name}" isn\'t a valid identifier according to Python '
                f'language definition. Please choose another name.')
        self.name = name

        # monitors
        # ---------
        if monitors is None:
            monitors = []
        self.mon = Monitor(monitors)
        for var in self.mon['vars']:
            if not hasattr(self, var):
                raise errors.ModelDefError(
                    f"Item {var} isn't defined in model {self}, "
                    f"so it can not be monitored.")

        # runner
        # -------
        self.runner = backend.get_node_runner()(pop=self)

        # run function
        # ------------
        self.run_func = None

        # others
        # ---
        self.show_code = show_code
        if self.target_backend is None:
            raise errors.ModelDefError('Must define "target_backend".')
        if isinstance(self.target_backend, str):
            self._target_backend = (self.target_backend, )
        elif isinstance(self.target_backend, (tuple, list)):
            # every entry must be a backend name, and there must be at least one
            if len(self.target_backend) == 0 or not all(
                    isinstance(bk, str) for bk in self.target_backend):
                raise errors.ModelDefError(
                    '"target_backend" must be a list/tuple of string.')
            self._target_backend = tuple(self.target_backend)
        else:
            raise errors.ModelDefError(
                f'Unknown setting of "target_backend": {self.target_backend}')
Example #5
0
def set(backend=None,
        module_or_operations=None,
        node_runner=None,
        net_runner=None,
        dt=None):
    """Basic backend setting function.

    Using this function, users can set the backend they prefer. For backend
    which is unknown, users can provide `module_or_operations` to specify
    the operations needed. Also, users can customize the node runner, or the
    network runner, by providing the `node_runner` or `net_runner` keywords.
    The default numerical precision `dt` can also be set by this function.

    Parameters
    ----------
    backend : str
        The backend name. If ``None`` or equal to the current backend, only
        ``dt`` is applied.
    module_or_operations : module, dict, optional
        The module or a dict containing the necessary operations.
    node_runner : GeneralNodeRunner
        An instance of node runner.
    net_runner : GeneralNetRunner
        An instance of network runner.
    dt : float
        The numerical precision.

    Raises
    ------
    errors.ModelUseError
        If ``backend`` is unknown and ``module_or_operations`` is not given,
        or if ``module_or_operations`` is neither a module nor a dict. In both
        cases the current backend settings are left unchanged.
    """
    global _backend, _node_runner, _net_runner

    if dt is not None:
        set_dt(dt)

    if (backend is None) or (_backend == backend):
        return

    if backend == 'numpy':
        from .operators import bk_numpy

        node_runner = GeneralNodeRunner if node_runner is None else node_runner
        net_runner = GeneralNetRunner if net_runner is None else net_runner
        module_or_operations = bk_numpy if module_or_operations is None else module_or_operations

    elif backend == 'pytorch':
        from .operators import bk_pytorch

        node_runner = GeneralNodeRunner if node_runner is None else node_runner
        net_runner = GeneralNetRunner if net_runner is None else net_runner
        module_or_operations = bk_pytorch if module_or_operations is None else module_or_operations

    elif backend == 'tensorflow':
        from .operators import bk_tensorflow

        node_runner = GeneralNodeRunner if node_runner is None else node_runner
        net_runner = GeneralNetRunner if net_runner is None else net_runner
        module_or_operations = bk_tensorflow if module_or_operations is None else module_or_operations

    elif backend == 'numba':
        from .operators import bk_numba_cpu
        from .runners.numba_cpu_runner import NumbaCPUNodeRunner, set_numba_profile

        node_runner = NumbaCPUNodeRunner if node_runner is None else node_runner
        module_or_operations = bk_numba_cpu if module_or_operations is None else module_or_operations
        set_numba_profile(parallel=False)

    elif backend == 'numba-parallel':
        from .operators import bk_numba_cpu
        from .runners.numba_cpu_runner import NumbaCPUNodeRunner, set_numba_profile

        node_runner = NumbaCPUNodeRunner if node_runner is None else node_runner
        module_or_operations = bk_numba_cpu if module_or_operations is None else module_or_operations
        set_numba_profile(parallel=True)

    elif backend == 'numba-cuda':
        from .operators import bk_numba_cuda
        from .runners.numba_cuda_runner import NumbaCudaNodeRunner

        node_runner = NumbaCudaNodeRunner if node_runner is None else node_runner
        module_or_operations = bk_numba_cuda if module_or_operations is None else module_or_operations

    elif backend == 'jax':
        from .operators import bk_jax
        from .runners.jax_runner import JaxRunner

        node_runner = JaxRunner if node_runner is None else node_runner
        module_or_operations = bk_jax if module_or_operations is None else module_or_operations

    else:
        if module_or_operations is None:
            raise errors.ModelUseError(
                f'Backend "{backend}" is unknown, '
                f'please provide the "module_or_operations" '
                f'to specify the necessary computation units.')
        node_runner = GeneralNodeRunner if node_runner is None else node_runner

    # Validate before touching global state, so that a bad argument does not
    # leave the backend name switched while the operations are not.
    if not isinstance(module_or_operations, (ModuleType, dict)):
        raise errors.ModelUseError('"module_or_operations" must be a module '
                                   'or a dict of operations.')

    _backend = backend
    _node_runner = node_runner
    _net_runner = net_runner
    if isinstance(module_or_operations, ModuleType):
        set_ops_from_module(module_or_operations)
    else:
        set_ops(**module_or_operations)
Example #6
0
def format_net_level_inputs(inputs, run_length):
    """Format the inputs of a network.

    Parameters
    ----------
    inputs : tuple, list
        Either a single ``(target, key, value[, operation])`` tuple whose
        first item is a ``DynamicSystem``, or a tuple/list of such tuples.
    run_length : int
        The running length. A non-scalar value whose first dimension equals
        ``run_length`` is treated as a per-step series (``'iter'``); any other
        value is a fixed input (``'fix'``).

    Returns
    -------
    formatted_input : dict
        Maps each target's ``name`` to a list of
        ``(key, value, operation, data_type)`` tuples; ``operation``
        defaults to ``'+'``.

    Raises
    ------
    errors.ModelUseError
        If the input structure, key or operation is invalid.
    KeyError
        If an input target is not a ``DynamicSystem``.
    """
    from brainpy.simulation.dynamic_system import DynamicSystem

    # 1. format the inputs to standard
    #    formats and check the inputs
    if not isinstance(inputs, (tuple, list)):
        raise errors.ModelUseError('"inputs" must be a tuple/list.')
    if len(inputs) > 0 and not isinstance(inputs[0], (list, tuple)):
        if isinstance(inputs[0], DynamicSystem):
            inputs = [inputs]
        else:
            raise errors.ModelUseError('Unknown input structure. Only supports '
                                       '"(target, key, value, [operation])".')
    for inp in inputs:
        if not 3 <= len(inp) <= 4:
            raise errors.ModelUseError('For each target, you must specify '
                                       '"(target, key, value, [operation])".')
        if len(inp) == 4 and inp[3] not in SUPPORTED_INPUT_OPS:
            raise errors.ModelUseError(f'Input operation only supports '
                                       f'"{SUPPORTED_INPUT_OPS}", '
                                       f'not "{inp[3]}".')

    # 2. format inputs
    formatted_inputs = {}
    for inp in inputs:
        # target
        target = inp[0]
        if not isinstance(target, DynamicSystem):
            raise KeyError(f'Unknown input target: {str(target)}')

        # key
        key = inp[1]
        if not isinstance(key, str):
            raise errors.ModelUseError('For each input, input[1] must be a string '
                                       'to specify variable of the target.')
        if not hasattr(target, key):
            raise errors.ModelUseError(f'Target {target} does not have key {key}. '
                                       f'So, it can not assign input to it.')

        # value and data type
        val = inp[2]
        if isinstance(val, (int, float)):
            data_type = 'fix'
        else:
            shape = ops.shape(val)
            # a 0-d array has an empty shape and can only be a fixed value
            if len(shape) > 0 and shape[0] == run_length:
                data_type = 'iter'
            else:
                data_type = 'fix'

        # operation
        operation = inp[3] if len(inp) == 4 else '+'

        # final result
        formatted_inputs.setdefault(target.name, []).append(
            (key, val, operation, data_type))
    return formatted_inputs
Example #7
0
def format_pop_level_inputs(inputs, host, mon_length):
    """Format the inputs of a population.

    Parameters
    ----------
    inputs : tuple, list, optional
        Either a single ``(key, value[, operation])`` tuple whose first item
        is a string, or a tuple/list of such tuples. ``None`` means no input.
    host : Population
        The host which contains all data; every ``key`` must be one of its
        attributes.
    mon_length : int
        The monitor length. A non-scalar value whose first dimension equals
        ``mon_length`` is treated as a per-step series (``'iter'``); any other
        value is a fixed input (``'fix'``).

    Returns
    -------
    formatted_inputs : list
        A list of ``(key, value, operation, data_type)`` tuples;
        ``operation`` defaults to ``'+'``.

    Raises
    ------
    errors.ModelUseError
        If the input structure, key or operation is invalid.
    """
    if inputs is None:
        inputs = []
    if not isinstance(inputs, (tuple, list)):
        raise errors.ModelUseError('"inputs" must be a tuple/list.')
    if len(inputs) > 0 and not isinstance(inputs[0], (list, tuple)):
        if isinstance(inputs[0], str):
            inputs = [inputs]
        else:
            raise errors.ModelUseError('Unknown input structure, only support inputs '
                                       'with format of "(key, value, [operation])".')
    for inp in inputs:
        if not 2 <= len(inp) <= 3:
            raise errors.ModelUseError('For each target, you must specify "(key, value, [operation])".')
        if len(inp) == 3 and inp[2] not in SUPPORTED_INPUT_OPS:
            raise errors.ModelUseError(f'Input operation only supports '
                                       f'"{SUPPORTED_INPUT_OPS}", '
                                       f'not "{inp[2]}".')

    # format inputs
    # -------------
    formatted_inputs = []
    for inp in inputs:
        # key
        key = inp[0]
        if not isinstance(key, str):
            raise errors.ModelUseError('For each input, input[0] must be a string '
                                       'to specify variable of the target.')
        if not hasattr(host, key):
            raise errors.ModelUseError(f'Input target key "{key}" is not defined in {host}.')

        # value and data type
        val = inp[1]
        if isinstance(val, (int, float)):
            data_type = 'fix'
        else:
            shape = ops.shape(val)
            # a 0-d array has an empty shape and can only be a fixed value
            if len(shape) > 0 and shape[0] == mon_length:
                data_type = 'iter'
            else:
                data_type = 'fix'

        # operation (already validated above)
        operation = inp[2] if len(inp) == 3 else '+'

        # input
        formatted_inputs.append((key, val, operation, data_type))

    return formatted_inputs
Example #8
0
    def __init__(self,
                 integrals,
                 fast_vars,
                 slow_vars,
                 fixed_vars=None,
                 pars_update=None,
                 numerical_resolution=0.1,
                 options=None):
        """Initialize the fast-slow bifurcation analyzer.

        Parameters
        ----------
        integrals
            The integral function(s); converted with
            ``utils.transform_integrals_to_model``.
        fast_vars : dict
            ``{"var": [min, max]}`` for the fast subsystem (one or two vars).
        slow_vars : dict
            ``{"var": [min, max]}`` for the slow subsystem (at most two vars).
            Each slow variable is moved from the model's variables to its
            parameters.
        fixed_vars : dict, optional
            ``{"var": value}`` of variables held fixed.
        pars_update : dict, optional
            ``{"par": value}`` overriding model parameters or scope values.
        numerical_resolution : float
            Forwarded to the 1D/2D analyzer.
        options : optional
            Forwarded to the 1D/2D analyzer.

        Raises
        ------
        errors.ModelUseError
            On malformed arguments, more than two fast/slow variables, a slow
            variable that is not a model variable, or an unknown parameter in
            ``pars_update``.
        """
        # check "model"
        self.model = utils.transform_integrals_to_model(integrals)

        # check "fast_vars"
        if not isinstance(fast_vars, dict):
            raise errors.ModelUseError(
                '"fast_vars" must be a dict with the format of: '
                '{"Var A": [A_min, A_max],'
                ' "Var B": [B_min, B_max]}')
        self.fast_vars = fast_vars
        if len(fast_vars) > 2:
            raise errors.ModelUseError(
                "FastSlowBifurcation can only analyze the system with at most "
                "a two-variable fast subsystem.")

        # check "slow_vars"
        if not isinstance(slow_vars, dict):
            raise errors.ModelUseError(
                '"slow_vars" must be a dict with the format of: '
                '{"Variable A": [A_min, A_max], '
                '"Variable B": [B_min, B_max]}')
        self.slow_vars = slow_vars
        if len(slow_vars) > 2:
            raise errors.ModelUseError(
                "FastSlowBifurcation can only analyze the system with at most "
                "a two-variable slow subsystem.")
        # validate all slow vars first so the model is not partially mutated
        unknown = [key for key in self.slow_vars
                   if key not in self.model.variables]
        if unknown:
            raise errors.ModelUseError(
                f'{unknown} in "slow_vars" are not variables of the model.')
        # slow variables are treated as parameters of the fast subsystem
        for key in self.slow_vars:
            self.model.variables.remove(key)
            self.model.parameters.append(key)

        # check "fixed_vars"
        if fixed_vars is None:
            fixed_vars = dict()
        if not isinstance(fixed_vars, dict):
            raise errors.ModelUseError(
                '"fixed_vars" must be a dict the format of: '
                '{"Variable A": A_value, "Variable B": B_value}')
        self.fixed_vars = fixed_vars

        # check "pars_update"
        if pars_update is None:
            pars_update = dict()
        if not isinstance(pars_update, dict):
            raise errors.ModelUseError(
                '"pars_update" must be a dict the format of: '
                '{"Par A": A_value, "Par B": B_value}')
        for key in pars_update.keys():
            if (key not in self.model.scopes) and (
                    key not in self.model.parameters):
                raise errors.ModelUseError(
                    f'"{key}" is not a valid parameter in "{integrals}" model. '
                )
        self.pars_update = pars_update

        # bifurcation analysis
        if len(self.fast_vars) == 1:
            self.analyzer = _FastSlow1D(
                model_or_integrals=self.model,
                fast_vars=fast_vars,
                slow_vars=slow_vars,
                fixed_vars=fixed_vars,
                pars_update=pars_update,
                numerical_resolution=numerical_resolution,
                options=options)

        elif len(self.fast_vars) == 2:
            self.analyzer = _FastSlow2D(
                model_or_integrals=self.model,
                fast_vars=fast_vars,
                slow_vars=slow_vars,
                fixed_vars=fixed_vars,
                pars_update=pars_update,
                numerical_resolution=numerical_resolution,
                options=options)

        else:
            raise errors.ModelUseError(
                f'Cannot analyze {len(fast_vars)} dimensional fast system.')
Example #9
0
    def __init__(self,
                 integrals,
                 target_pars,
                 target_vars,
                 fixed_vars=None,
                 pars_update=None,
                 numerical_resolution=0.1,
                 options=None):
        """Initialize the bifurcation analyzer.

        Parameters
        ----------
        integrals
            The integral function(s); converted with
            ``utils.transform_integrals_to_model``.
        target_pars : dict
            ``{"par": [min, max]}`` for the bifurcation parameters (at most two).
        target_vars : dict
            ``{"var": [min, max]}`` for the analyzed variables (one or two).
        fixed_vars : dict, optional
            ``{"var": value}`` of variables held fixed.
        pars_update : dict, optional
            ``{"par": value}`` overriding model parameters or scope values.
        numerical_resolution : float
            Forwarded to the 1D/2D analyzer.
        options : optional
            Forwarded to the 1D/2D analyzer.

        Raises
        ------
        errors.ModelUseError
            On malformed arguments, more than two target parameters, an
            unknown parameter in ``pars_update``, or a number of target
            variables other than one or two.
        """
        # check "model"
        self.model = utils.transform_integrals_to_model(integrals)

        # check "target_pars"
        if not isinstance(target_pars, dict):
            raise errors.ModelUseError(
                '"target_pars" must be a dict with the format of: '
                '{"Parameter A": [A_min, A_max],'
                ' "Parameter B": [B_min, B_max]}')
        self.target_pars = target_pars
        if len(target_pars) > 2:
            raise errors.ModelUseError(
                "The number of parameters in bifurcation "
                "analysis cannot exceed 2.")

        # check "fixed_vars"
        if fixed_vars is None:
            fixed_vars = dict()
        if not isinstance(fixed_vars, dict):
            raise errors.ModelUseError(
                '"fixed_vars" must be a dict the format of: '
                '{"Variable A": A_value, "Variable B": B_value}')
        self.fixed_vars = fixed_vars

        # check "target_vars"
        if not isinstance(target_vars, dict):
            raise errors.ModelUseError(
                '"target_vars" must be a dict with the format of: '
                '{"Variable A": [A_min, A_max], "Variable B": [B_min, B_max]}')
        self.target_vars = target_vars

        # check "pars_update"
        if pars_update is None:
            pars_update = dict()
        if not isinstance(pars_update, dict):
            raise errors.ModelUseError(
                '"pars_update" must be a dict the format of: '
                '{"Par A": A_value, "Par B": B_value}')
        for key in pars_update.keys():
            if (key not in self.model.scopes) and (
                    key not in self.model.parameters):
                raise errors.ModelUseError(
                    f'"{key}" is not a valid parameter in "{integrals}". ')
        self.pars_update = pars_update

        # bifurcation analysis
        if len(self.target_vars) == 1:
            self.analyzer = _Bifurcation1D(
                model_or_integrals=self.model,
                target_pars=target_pars,
                target_vars=target_vars,
                fixed_vars=fixed_vars,
                pars_update=pars_update,
                numerical_resolution=numerical_resolution,
                options=options)

        elif len(self.target_vars) == 2:
            self.analyzer = _Bifurcation2D(
                model_or_integrals=self.model,
                target_pars=target_pars,
                target_vars=target_vars,
                fixed_vars=fixed_vars,
                pars_update=pars_update,
                numerical_resolution=numerical_resolution,
                options=options)

        else:
            raise errors.ModelUseError(
                f'Cannot analyze {len(self.target_vars)} dimensional system: '
                f'{self.target_vars}. Only 1 or 2 target variables are supported.')
Example #10
0
    def plot_bifurcation(self, show=False):
        """Plot the fixed points of the target variable against the target parameters.

        For every sampled parameter value in ``self.resolutions``, the fixed
        points returned by ``self.get_f_fixed_point()`` are classified with
        ``stability.stability_analysis`` applied to the derivative computed by
        ``self.get_f_dfdx()``. Each point is drawn with the style from
        ``stability.plot_scheme``. With one target parameter the plot is
        parameter vs. ``self.x_var``; with two it is a 3D scatter.

        Parameters
        ----------
        show : bool
            Whether to call ``plt.show()`` after drawing.

        Returns
        -------
        dict
            Maps each 1D stability type to the collected points:
            ``{'p': [...], 'x': [...]}`` for one target parameter, or
            ``{'p0': [...], 'p1': [...], 'x': [...]}`` for two.

        Raises
        ------
        errors.ModelUseError
            If ``self.target_pars`` does not contain one or two parameters.
        """
        print('plot bifurcation ...')

        n_pars = len(self.target_pars)
        # Fail fast, before building the fixed-point and derivative functions.
        if n_pars not in (1, 2):
            raise errors.ModelUseError(
                f'Cannot visualize co-dimension {n_pars} '
                f'bifurcation.')

        f_fixed_point = self.get_f_fixed_point()
        f_dfdx = self.get_f_dfdx()

        if n_pars == 1:
            container = {
                c: {'p': [], 'x': []}
                for c in stability.get_1d_stability_types()
            }

            # fixed points and their stability for every sampled parameter
            par_a = self.dpar_names[0]
            for p in self.resolutions[par_a]:
                for x in f_fixed_point(p):
                    fp_type = stability.stability_analysis(f_dfdx(x, p))
                    container[fp_type]['p'].append(p)
                    container[fp_type]['x'].append(x)

            # visualization
            plt.figure(self.x_var)
            for fp_type, points in container.items():
                if points['x']:
                    plot_style = stability.plot_scheme[fp_type]
                    plt.plot(points['p'],
                             points['x'],
                             '.',
                             **plot_style,
                             label=fp_type)
            plt.xlabel(par_a)
            plt.ylabel(self.x_var)

            # TODO: optionally rescale axis limits using ``self.options.lim_scale``.

            plt.legend()
            if show:
                plt.show()

        else:
            # The target variable is still 1D, so the 1D stability types apply.
            container = {
                c: {'p0': [], 'p1': [], 'x': []}
                for c in stability.get_1d_stability_types()
            }

            # fixed points and their stability on the 2D parameter grid
            for p0 in self.resolutions[self.dpar_names[0]]:
                for p1 in self.resolutions[self.dpar_names[1]]:
                    for x in f_fixed_point(p0, p1):
                        fp_type = stability.stability_analysis(f_dfdx(x, p0, p1))
                        container[fp_type]['p0'].append(p0)
                        container[fp_type]['p1'].append(p1)
                        container[fp_type]['x'].append(x)

            # visualization
            fig = plt.figure(self.x_var)
            # ``Figure.gca(projection=...)`` was removed in Matplotlib 3.6.
            # Reuse an existing 3D axes on this figure (as ``gca`` did), or add one.
            ax = next((a for a in fig.axes if a.name == '3d'), None)
            if ax is None:
                ax = fig.add_subplot(projection='3d')
            for fp_type, points in container.items():
                if points['x']:
                    plot_style = stability.plot_scheme[fp_type]
                    ax.scatter(points['p0'], points['p1'], points['x'],
                               **plot_style, label=fp_type)

            ax.set_xlabel(self.dpar_names[0])
            ax.set_ylabel(self.dpar_names[1])
            ax.set_zlabel(self.x_var)

            # TODO: optionally rescale axis limits using ``self.options.lim_scale``.

            ax.grid(True)
            ax.legend()
            if show:
                plt.show()

        return container