Ejemplo n.º 1
0
def model_transform(model):
    """Normalize ``model`` into a :class:`NumDSWrapper` for dynamics analysis.

    Parameters
    ----------
    model : NumDSWrapper, ODEIntegrator, list, tuple, dict, DynamicalSystem
        - A ``NumDSWrapper`` is returned unchanged.
        - A single ``ODEIntegrator`` is treated as a one-element list.
        - A list/tuple of ``ODEIntegrator``, or a dict whose values are
          ``ODEIntegrator`` instances.
        - A ``DynamicalSystem``, whose unique ``ODEIntegrator`` instances
          (``model.ints().subset(ODEIntegrator).unique()``) are used.

    Returns
    -------
    NumDSWrapper
        Wrapper holding the integrators, their variables (in integrator
        order) and the union of their parameters (in first-seen order).

    Raises
    ------
    errors.AnalyzerError
        If no integrator is found, an element is not an ``ODEIntegrator``,
        an integrator does not have exactly one variable, or two integrators
        define the same variable name.
    errors.UnsupportedError
        If ``model`` has an unsupported type.
    """
    if isinstance(model, NumDSWrapper):
        return model
    if isinstance(model, ODEIntegrator):
        model = [model]

    # 1. collect and validate the integrators
    if isinstance(model, (list, tuple, dict)):
        integrals = tuple(model.values() if isinstance(model, dict) else model)
        if len(integrals) == 0:
            raise errors.AnalyzerError(f'Found no integrators: {model}')
        for intg in integrals:
            if not isinstance(intg, ODEIntegrator):
                raise errors.AnalyzerError(
                    f'Must be the instance of {ODEIntegrator}, but got {intg}.')
    elif isinstance(model, DynamicalSystem):
        integrals = tuple(model.ints().subset(ODEIntegrator).unique().values())
        # Same check as for explicit containers: an empty wrapper is useless.
        if len(integrals) == 0:
            raise errors.AnalyzerError(f'Found no integrators: {model}')
    else:
        raise errors.UnsupportedError(
            f'Dynamics analysis by symbolic approach only supports '
            f'list/tuple/dict of {ODEIntegrator} or {DynamicalSystem}, '
            f'but we got: {type(model)}: {str(model)}')

    # 2. gather variables and parameters. Dicts are used as insertion-ordered
    # sets so the resulting order is deterministic (the iteration order of a
    # ``set`` of strings depends on the interpreter's hash seed).
    variables = {}
    parameters = {}
    for integral in integrals:
        if len(integral.variables) != 1:
            raise errors.AnalyzerError(
                f'Only supports one {ODEIntegrator.__name__} one variable, '
                f'but we got {len(integral.variables)} variables in {integral}.'
            )
        var = integral.variables[0]
        if var in variables:
            raise errors.AnalyzerError(
                f'Variable name {var} has been defined before. '
                f'Please change another name.')
        variables[var] = None
        # NOTE(review): the first entry of ``parameters`` is skipped, as in
        # the original code; presumably it is the time argument — confirm.
        parameters.update(dict.fromkeys(integral.parameters[1:]))

    # 3. form a dynamic model; every parameter is also a parameter to update
    return NumDSWrapper(integrals=integrals,
                        variables=list(variables),
                        parameters=list(parameters),
                        pars_update=set(parameters))
Ejemplo n.º 2
0
  def __init__(self, var_name, variables, expressions, derivative_expr, scope, func_name):
    """Store the parsed pieces of one differential function.

    Parameters
    ----------
    var_name : str
      Name of the differential variable.
    variables : sequence of str
      Left-hand-side names, paired positionally with ``expressions``.
    expressions : sequence of str
      Right-hand-side code of each intermediate assignment.
    derivative_expr : str
      Code of the derivative result (wrapped as the ``'_f_res_'`` expression).
    scope
      The function scope, stored as ``self.func_scope``.
    func_name : str
      Name of the differential function.

    Raises
    ------
    errors.AnalyzerError
      If a name appears more than once in ``variables``.
    """
    self.func_name = func_name
    self.func_scope = scope
    # the time argument name is always "t"
    self.var_name, self.t_name = var_name, 't'

    # wrap every (name, code) pair and the final derivative expression
    named_pairs = zip(variables, expressions)
    self.expressions = [Expression(lhs, rhs) for lhs, rhs in named_pairs]
    self.f_expr = Expression('_f_res_', derivative_expr)

    # each left-hand-side name must be unique; report the first repeated one
    occurrences = Counter(variables)
    for lhs in occurrences:
      times = occurrences[lhs]
      if times > 1:
        raise errors.AnalyzerError(
          f'Found "{lhs}" {times} times. Please assign each expression '
          f'in differential function with a unique name. ')
Ejemplo n.º 3
0
 def __init__(self,
              model,
              target_vars,
              fixed_vars=None,
              target_pars=None,
              pars_update=None,
              resolutions=None,
              **kwargs):
     """Create a one-dimensional phase plane analyzer.

     Every argument is forwarded unchanged to the parent analyzer.
     ``target_pars`` must be ``None`` or empty, because phase plane
     analysis does not sweep parameters.

     Raises
     ------
     errors.AnalyzerError
         If a non-empty ``target_pars`` is given.
     """
     has_target_pars = target_pars is not None and len(target_pars) != 0
     if has_target_pars:
         raise errors.AnalyzerError(
             f'Phase plane analysis does not support "target_pars". '
             f'While we detect "target_pars={target_pars}".')
     parent_init = super(PhasePlane1D, self).__init__
     parent_init(model=model,
                 target_vars=target_vars,
                 fixed_vars=fixed_vars,
                 target_pars=target_pars,
                 pars_update=pars_update,
                 resolutions=resolutions,
                 **kwargs)
Ejemplo n.º 4
0
def separate_variables(func_or_code):
    r"""Separate the expressions in a differential equation for each variable.

    For example, take the HH neuron model:

    >>> eq_code = '''
    ... def derivative(V, m, h, n, t, C, gNa, ENa, gK, EK, gL, EL, Iext):
    ...     alpha = 0.1 * (V + 40) / (1 - bp.math.exp(-(V + 40) / 10))
    ...     beta = 4.0 * bp.math.exp(-(V + 65) / 18)
    ...     dmdt = alpha * (1 - m) - beta * m
    ...
    ...     alpha = 0.07 * bp.math.exp(-(V + 65) / 20.)
    ...     beta = 1 / (1 + bp.math.exp(-(V + 35) / 10))
    ...     dhdt = alpha * (1 - h) - beta * h
    ...
    ...     alpha = 0.01 * (V + 55) / (1 - bp.math.exp(-(V + 55) / 10))
    ...     beta = 0.125 * bp.math.exp(-(V + 65) / 80)
    ...     dndt = alpha * (1 - n) - beta * n
    ...
    ...     I_Na = (gNa * m ** 3.0 * h) * (V - ENa)
    ...     I_K = (gK * n ** 4.0) * (V - EK)
    ...     I_leak = gL * (V - EL)
    ...     dVdt = (- I_Na - I_K - I_leak + Iext) / C
    ...
    ...     return dVdt, dmdt, dhdt, dndt
    ... '''
    >>> separate_variables(eq_code)
    {'code_lines_for_returns': {'dVdt': ['I_Na = gNa * m ** 3.0 * h * (V - ENa)\n',
                                         'I_K = gK * n ** 4.0 * (V - EK)\n',
                                         'I_leak = gL * (V - EL)\n',
                                         'dVdt = (-I_Na - I_K - I_leak + Iext) / C\n'],
                                'dhdt': ['alpha = 0.07 * bp.math.exp(-(V + 65) / 20.0)\n',
                                         'beta = 1 / (1 + bp.math.exp(-(V + 35) / 10))\n',
                                         'dhdt = alpha * (1 - h) - beta * h\n'],
                                'dmdt': ['alpha = 0.1 * (V + 40) / (1 - '
                                         'bp.math.exp(-(V + 40) / 10))\n',
                                         'beta = 4.0 * bp.math.exp(-(V + 65) / 18)\n',
                                         'dmdt = alpha * (1 - m) - beta * m\n'],
                                'dndt': ['alpha = 0.01 * (V + 55) / (1 - '
                                         'bp.math.exp(-(V + 55) / 10))\n',
                                         'beta = 0.125 * bp.math.exp(-(V + 65) / 80)\n',
                                         'dndt = alpha * (1 - n) - beta * n\n']},
     'expressions_for_returns': {'dVdt': ['gNa * m ** 3.0 * h * (V - ENa)',
                                          'gK * n ** 4.0 * (V - EK)',
                                          'gL * (V - EL)',
                                          '(-I_Na - I_K - I_leak + Iext) / C'],
                                 'dhdt': ['0.07 * bp.math.exp(-(V + 65) / 20.0)',
                                          '1 / (1 + bp.math.exp(-(V + 35) / 10))',
                                          'alpha * (1 - h) - beta * h'],
                                 'dmdt': ['0.1 * (V + 40) / (1 - '
                                          'bp.math.exp(-(V + 40) / 10))',
                                          '4.0 * bp.math.exp(-(V + 65) / 18)',
                                          'alpha * (1 - m) - beta * m'],
                                 'dndt': ['0.01 * (V + 55) / (1 - '
                                          'bp.math.exp(-(V + 55) / 10))',
                                          '0.125 * bp.math.exp(-(V + 65) / 80)',
                                          'alpha * (1 - n) - beta * n']},
     'variables_for_returns': {'dVdt': [['I_Na'], ['I_K'], ['I_leak'], ['dVdt']],
                               'dhdt': [['alpha'], ['beta'], ['dhdt']],
                               'dmdt': [['alpha'], ['beta'], ['dmdt']],
                               'dndt': [['alpha'], ['beta'], ['dndt']]}}

    Parameters
    ----------
    func_or_code : callable, str
        The callable function or the function code. Lambda functions are
        rejected because their source cannot be analyzed.

    Returns
    -------
    analysis : tools.DictPlus
        With the keys ``code_lines_for_returns``, ``variables_for_returns``
        and ``expressions_for_returns``; each maps a returned expression to
        the assignments it depends on, in source order.

    Raises
    ------
    errors.AnalyzerError
        If ``func_or_code`` is a lambda function.
    """
    if callable(func_or_code):
        if tools.is_lambda_function(func_or_code):
            raise errors.AnalyzerError(
                f'Cannot analyze lambda function: {func_or_code}.')
        func_or_code = tools.deindent(inspect.getsource(func_or_code))
    assert isinstance(func_or_code, str)
    analyser = DiffEqReader()
    analyser.visit(ast.parse(func_or_code))

    returns = analyser.returns
    variables = analyser.variables
    right_exprs = analyser.rights
    code_lines = analyser.code_lines

    code_lines_for_returns = OrderedDict()
    variables_for_returns = OrderedDict()
    expressions_for_returns = OrderedDict()

    length = len(variables)
    # ``fromkeys`` keeps the first occurrence of each returned name, so a
    # name returned twice is analyzed only once.
    for r in OrderedDict.fromkeys(returns):
        # identifiers still needed (not yet defined) to compute ``r``
        required = set(tools.get_identifiers(r))
        lines, targets, exprs = [], [], []
        # Walk from the last assignment to the first (negative indices, as
        # in the original implementation) so that each dependency binds to
        # its latest definition before use.
        for rid in range(-1, -length - 1, -1):
            dep = [v for v in variables[rid] if v in required]
            if not dep:
                continue
            lines.append(code_lines[rid])
            targets.append(variables[rid])
            exprs.append(right_exprs[rid])
            # ``difference_update`` instead of ``remove``: a name assigned
            # twice on one line (``a, a = ...``) must not raise KeyError.
            required.difference_update(dep)
            required.update(tools.get_identifiers(right_exprs[rid]))
        # the walk was bottom-up; restore source order
        code_lines_for_returns[r] = lines[::-1]
        variables_for_returns[r] = targets[::-1]
        expressions_for_returns[r] = exprs[::-1]

    analysis = tools.DictPlus(
        code_lines_for_returns=code_lines_for_returns,
        variables_for_returns=variables_for_returns,
        expressions_for_returns=expressions_for_returns,
    )
    return analysis
Ejemplo n.º 5
0
    def plot_trajectory(self,
                        initials,
                        duration,
                        plot_durations=None,
                        axes='v-v',
                        dt=None,
                        show=False,
                        with_plot=True,
                        with_return=False,
                        **kwargs):
        """Plot trajectories according to the settings.

        Parameters
        ----------
        initials : list, tuple, dict
            The initial value setting of the targets, checked by
            ``utils.check_initials`` against ``self.target_var_names``. It can
            be a tuple/list of floats to specify each value of dynamical
            variables (for example, ``(a, b)``). It can also be a tuple/list
            of tuple to specify multiple initial values (for example,
            ``[(a1, b1), (a2, b2)]``).
        duration : int, float
            The running duration (``t_end``). All initial values are
            simulated for the same duration; tuple/list durations are not
            supported here.
        plot_durations : tuple, list, optional
            The duration to plot. It can be a tuple with ``(start, end)``. It
            can also be a list of tuple ``[(start1, end1), (start2, end2)]``
            to specify the plot duration for each initial value running.
        axes : str
            The axes to plot. It can be:

             - 'v-v': Plot the trajectory in the 'x_var'-'y_var' axis.
             - 't-v': Plot the trajectory in the 'time'-'var' axis.
        dt : float, optional
            The integration step. Defaults to ``math.get_dt()``.
        show : bool
            Whether to call ``plt.show()`` at the end.
        with_plot : bool
            Whether to draw the trajectories.
        with_return : bool
            Whether to return the monitor results of the simulation.
        **kwargs
            Extra keyword arguments forwarded to every ``plt.plot`` call.

        Returns
        -------
        mon_res
            The result of ``utils.TrajectModel.run``; only returned when
            ``with_return`` is True.

        Raises
        ------
        errors.AnalyzerError
            If ``axes`` is neither "v-v" nor "t-v".
        """

        utils.output('I am plotting the trajectory ...')

        if axes not in ['v-v', 't-v']:
            raise errors.AnalyzerError(
                f'Unknown axes "{axes}", only support "v-v" and "t-v".')

        # 1. check the initial values
        initials = utils.check_initials(initials, self.target_var_names)

        # 2. format the running duration
        assert isinstance(duration, (int, float)), (
            f'"duration" must be an int or a float, but we got {duration!r}.')

        # 3. format the plot duration
        plot_durations = utils.check_plot_durations(plot_durations, duration,
                                                    initials)

        # 4. run the network
        dt = math.get_dt() if dt is None else dt
        traject_model = utils.TrajectModel(initial_vars=initials,
                                           integrals={
                                               self.x_var: self.F_int_x,
                                               self.y_var: self.F_int_y
                                           },
                                           dt=dt)
        mon_res = traject_model.run(duration=duration)

        if with_plot:
            # plots
            for i, initial in enumerate(zip(*list(initials.values()))):
                # legend
                legend = f'$traj_{i}$: '
                for j, key in enumerate(self.target_var_names):
                    legend += f'{key}={round(float(initial[j]), 4)}, '
                legend = legend[:-2]

                # visualization; round instead of truncating so that float
                # error does not shift the index (int(0.3 / 0.1) == 2)
                start = int(round(plot_durations[i][0] / dt))
                end = int(round(plot_durations[i][1] / dt))
                if axes == 'v-v':
                    lines = plt.plot(mon_res[self.x_var][start:end, i],
                                     mon_res[self.y_var][start:end, i],
                                     label=legend,
                                     **kwargs)
                    utils.add_arrow(lines[0])
                else:
                    plt.plot(mon_res.ts[start:end],
                             mon_res[self.x_var][start:end, i],
                             label=legend + f', {self.x_var}',
                             **kwargs)
                    plt.plot(mon_res.ts[start:end],
                             mon_res[self.y_var][start:end, i],
                             label=legend + f', {self.y_var}',
                             **kwargs)

            # visualization of others
            if axes == 'v-v':
                plt.xlabel(self.x_var)
                plt.ylabel(self.y_var)
                scale = (self.lim_scale - 1.) / 2
                plt.xlim(
                    *utils.rescale(self.target_vars[self.x_var], scale=scale))
                plt.ylim(
                    *utils.rescale(self.target_vars[self.y_var], scale=scale))
                plt.legend()
            else:
                plt.legend(title='Initial values')

            if show:
                plt.show()

        if with_return:
            return mon_res
Ejemplo n.º 6
0
    def plot_fixed_point(
        self,
        with_plot=True,
        with_return=False,
        show=False,
        tol_unique=1e-2,
        tol_aux=1e-8,
        tol_opt_screen=None,
        select_candidates='fx-nullcline',
        num_rank=100,
    ):
        """Search the fixed points, analyze their stability and plot them.

        Parameters
        ----------
        with_plot : bool
            Whether to draw the fixed points, colored by stability type.
        with_return : bool
            Whether to return the fixed points.
        show : bool
            Whether to call ``plt.show()`` after plotting.
        tol_unique : float
            Tolerance forwarded to ``self._get_fixed_points`` and used by
            ``utils.keep_unique`` to merge duplicate fixed points.
        tol_aux : float
            Forwarded to ``self._get_fixed_points``.
        tol_opt_screen : float, optional
            Forwarded to ``self._get_fixed_points`` as ``tol_opt_candidate``.
        select_candidates : str
            How to choose the initial candidates when the system cannot be
            converted to one equation:

            - 'fx-nullcline': the stored fx-nullcline points;
            - 'fy-nullcline': the stored fy-nullcline points;
            - 'nullclines': the stored points of both nullclines;
            - 'aux_rank': candidates from ``self._get_fp_candidates_by_aux_rank``.

            The nullcline options require ``plot_nullcline()`` to be called first.
        num_rank : int
            Number of candidates requested when ``select_candidates='aux_rank'``.

        Returns
        -------
        fixed_points : np.ndarray
            The unique fixed points, only when ``with_return`` is True. Nothing
            is returned when no candidate is available.

        Raises
        ------
        errors.AnalyzerError
            If a nullcline option is selected but no nullcline points are stored.
        ValueError
            If ``select_candidates`` is unknown.
        """
        utils.output('I am searching fixed points ...')

        def _nullcline_candidates(*prefixes):
            # Stack the stored nullcline points whose result key starts with
            # any of ``prefixes``.
            points = [
                self.analyzed_results[key][0]
                for key in self.analyzed_results.keys()
                if key.startswith(prefixes)
            ]
            if len(points) == 0:
                raise errors.AnalyzerError(
                    f'No nullcline points are found, please call '
                    f'".{self.plot_nullcline.__name__}()" first.')
            return jnp.vstack(points)

        if self._can_convert_to_one_eq():
            if self.convert_type() == C.x_by_y:
                candidates = self.resolutions[self.y_var].value
            else:
                candidates = self.resolutions[self.x_var].value
        elif select_candidates == 'fx-nullcline':
            candidates = _nullcline_candidates(C.fx_nullcline_points)
        elif select_candidates == 'fy-nullcline':
            candidates = _nullcline_candidates(C.fy_nullcline_points)
        elif select_candidates == 'nullclines':
            # Both nullclines (previously the fy prefix was tested twice, so
            # fx-nullcline points were never used).
            candidates = _nullcline_candidates(C.fx_nullcline_points,
                                               C.fy_nullcline_points)
        elif select_candidates == 'aux_rank':
            candidates, _ = self._get_fp_candidates_by_aux_rank(
                num_rank=num_rank)
        else:
            raise ValueError(
                f'Unknown "select_candidates": {select_candidates!r}. Only '
                f'supports "fx-nullcline", "fy-nullcline", "nullclines" '
                f'and "aux_rank".')

        # get fixed points
        if len(candidates):
            fixed_points, _, _ = self._get_fixed_points(
                jnp.asarray(candidates),
                tol_aux=tol_aux,
                tol_unique=tol_unique,
                tol_opt_candidate=tol_opt_screen)
            utils.output(
                'I am trying to filter out duplicate fixed points ...')
            fixed_points = np.asarray(fixed_points)
            fixed_points, _ = utils.keep_unique(fixed_points,
                                                tolerance=tol_unique)
            utils.output(f'{C.prefix}Found {len(fixed_points)} fixed points.')
        else:
            utils.output(f'{C.prefix}Found no fixed points.')
            return

        # stability analysis
        # ------------------
        container = {
            a: {
                'x': [],
                'y': []
            }
            for a in stability.get_2d_stability_types()
        }
        for i in range(len(fixed_points)):
            x = fixed_points[i, 0]
            y = fixed_points[i, 1]
            fp_type = stability.stability_analysis(self.F_jacobian(x, y))
            utils.output(
                f"{C.prefix}#{i + 1} {self.x_var}={x}, {self.y_var}={y} is a {fp_type}."
            )
            container[fp_type]['x'].append(x)
            container[fp_type]['y'].append(y)

        # visualization
        # -------------
        if with_plot:
            for fp_type, points in container.items():
                if len(points['x']):
                    plot_style = stability.plot_scheme[fp_type]
                    plt.plot(points['x'],
                             points['y'],
                             '.',
                             markersize=20,
                             **plot_style,
                             label=fp_type)
            plt.legend()
            if show:
                plt.show()

        if with_return:
            return fixed_points
Ejemplo n.º 7
0
    def plot_vector_field(self,
                          with_plot=True,
                          with_return=False,
                          plot_method='streamplot',
                          plot_style=None,
                          show=False):
        """Plot the vector field.

        Parameters
        ----------
        with_plot : bool
            Whether to draw the vector field.
        with_return : bool
            Whether to return the derivatives evaluated on the grid.
        plot_method : str
            The method to plot the vector field. It can be "streamplot" or "quiver".
        plot_style : dict, optional
            The style for vector field plotting. The dict is copied and never
            modified.

            - For ``plot_method="streamplot"``, it can set the keywords like "density",
              "linewidth", "color", "arrowsize". More settings please check
              https://matplotlib.org/api/_as_gen/matplotlib.pyplot.streamplot.html.
            - For ``plot_method="quiver"``, it can set the keywords like "color",
              "units", "angles", "scale". More settings please check
              https://matplotlib.org/api/_as_gen/matplotlib.pyplot.quiver.html.
        show : bool
            Whether to call ``plt.show()`` after plotting.

        Returns
        -------
        dx, dy : np.ndarray
            Only when ``with_return`` is True: the raw (unnormalized) values
            of ``self.F_fx`` and ``self.F_fy`` on the grid, whatever
            ``plot_method`` is.

        Raises
        ------
        errors.AnalyzerError
            If ``plot_method`` is unknown and ``with_plot`` is True.
        """
        utils.output('I am creating the vector field ...')

        # get vector fields
        xs = self.resolutions[self.x_var]
        ys = self.resolutions[self.y_var]
        X, Y = bm.meshgrid(xs, ys)
        dx = self.F_fx(X, Y)
        dy = self.F_fy(X, Y)
        X, Y = np.asarray(X), np.asarray(Y)
        dx, dy = np.asarray(dx), np.asarray(dy)

        if with_plot:  # plot vector fields
            has_nan = np.isnan(dx).any() or np.isnan(dy).any()
            if plot_method == 'quiver':
                style = dict(units='xy') if plot_style is None else dict(plot_style)
                # normalized copies for drawing; the returned dx/dy stay raw
                u, v = dx, dy
                if not has_nan:
                    speed = np.sqrt(dx ** 2 + dy ** 2)
                    # zero-speed points (fixed points) become NaN arrows,
                    # which quiver skips; silence the 0/0 warning
                    with np.errstate(divide='ignore', invalid='ignore'):
                        u, v = dx / speed, dy / speed
                plt.quiver(X, Y, u, v, **style)
            elif plot_method == 'streamplot':
                if plot_style is None:
                    style = dict(arrowsize=1.2, density=1, color='thistle')
                else:
                    style = dict(plot_style)
                # pop so "linewidth" is not passed twice to plt.streamplot
                linewidth = style.pop('linewidth', None)
                if linewidth is None and not has_nan:
                    # widths range from min_width to min_width + max_width
                    min_width, max_width = 0.5, 5.5
                    speed = np.nan_to_num(np.sqrt(dx ** 2 + dy ** 2))
                    max_speed = speed.max()
                    if max_speed > 0:  # all-zero field: keep default width
                        linewidth = min_width + max_width * (speed / max_speed)
                plt.streamplot(X, Y, dx, dy, linewidth=linewidth, **style)
            else:
                raise errors.AnalyzerError(
                    f'Unknown plot_method "{plot_method}", '
                    f'only supports "quiver" and "streamplot".')

            plt.xlabel(self.x_var)
            plt.ylabel(self.y_var)
            if show:
                plt.show()

        if with_return:  # return vector fields
            return dx, dy
Ejemplo n.º 8
0
    def plot_trajectory(self,
                        initials,
                        duration,
                        plot_duration=None,
                        show=False):
        """Plot trajectories of the fast-slow system according to the settings.

        One figure is used per fast variable (named after it). With one slow
        variable the trajectory is drawn in the (slow, fast) plane; with two
        slow variables it is drawn in 3D.

        Parameters
        ----------
        initials : dict, list, tuple
            The initial value setting of all fast and slow variables (fast
            variables first). It can be a dict, a list of dicts, a tuple/list
            of numbers to specify each value of dynamical variables (for
            example, ``(a, b)``), or a tuple/list of tuples to specify
            multiple initial values (for example, ``[(a1, b1), (a2, b2)]``).
        duration : int, float, tuple, list
            The running duration. Same with the ``duration`` in ``NeuGroup.run()``.
            It can be a number (``t_end``) to specify the same running end time,
            or a tuple/list of numbers (``(t_start, t_end)``) to specify
            the start and end simulation time. Or, it can be a list of tuple
            (``[(t1_start, t1_end), (t2_start, t2_end)]``) to specify the specific
            start and end simulation time for each initial value.
        plot_duration : tuple/list of tuple, optional
            The duration to plot. It can be a tuple with ``(start, end)``. It can
            also be a list of tuple ``[(start1, end1), (start2, end2)]`` to specify
            the plot duration for each initial value running. Defaults to
            ``duration``.
        show : bool
            Whether show or not.

        Raises
        ------
        ValueError
            If ``initials`` has an unsupported format.
        errors.AnalyzerError
            If some variables lack an initial value, or there are more than
            two slow variables.
        """
        import numbers

        def _axes_3d(fig):
            # Reuse the latest 3D axes of ``fig`` (or create one) and make it
            # current. ``Figure.gca(projection=...)`` was removed in
            # Matplotlib 3.6.
            for ax in reversed(fig.axes):
                if ax.name == '3d':
                    break
            else:
                ax = fig.add_subplot(projection='3d')
            fig.sca(ax)
            return ax

        print('plot trajectory ...')

        # 1. format the initial values (numbers.Real also accepts numpy scalars)
        all_vars = self.fast_var_names + self.slow_var_names
        if isinstance(initials, dict):
            initials = [initials]
        elif isinstance(initials, (list, tuple)) and len(initials) > 0:
            first = initials[0]
            if isinstance(first, numbers.Real):
                initials = [{all_vars[i]: v for i, v in enumerate(initials)}]
            elif isinstance(first, dict):
                pass  # already a sequence of dicts
            elif (isinstance(first, (tuple, list)) and len(first) > 0
                  and isinstance(first[0], numbers.Real)):
                initials = [{all_vars[i]: v
                             for i, v in enumerate(init)} for init in initials]
            else:
                raise ValueError(
                    f'Unsupported format of "initials": {initials}.')
        else:
            raise ValueError(
                f'"initials" must be a dict or a non-empty list/tuple, '
                f'but we got {initials}.')
        for initial in initials:
            if len(initial) != len(all_vars):
                raise errors.AnalyzerError(
                    f'Should provide all fast-slow variables ({all_vars}) '
                    f' initial values, but we only get initial values for '
                    f'variables {list(initial.keys())}.')

        # 2. format the running duration
        if isinstance(duration, numbers.Real):
            duration = [(0, duration) for _ in range(len(initials))]
        elif isinstance(duration[0], numbers.Real):
            duration = [duration for _ in range(len(initials))]
        else:
            assert len(duration) == len(initials), (
                'Each initial value needs its own duration.')

        # 3. format the plot duration
        if plot_duration is None:
            plot_duration = duration
        if isinstance(plot_duration[0], numbers.Real):
            plot_duration = [plot_duration for _ in range(len(initials))]
        else:
            assert len(plot_duration) == len(initials), (
                'Each initial value needs its own plot duration.')

        # 4. run the network
        dt = backend.get_dt()
        for init_i, initial in enumerate(initials):
            traj_group = Trajectory(size=1,
                                    integrals=self.model.integrals,
                                    target_vars=initial,
                                    fixed_vars=self.fixed_vars,
                                    pars_update=self.pars_update,
                                    scope=self.model.scopes)
            traj_group.run(duration=duration[init_i], report=False)

            #   4.1 legend
            legend = f'$traj_{init_i}$: '
            for key in all_vars:
                legend += f'{key}={initial[key]}, '
            legend = legend[:-2]

            #   4.2 trajectory; round instead of truncating so that float
            #   error does not shift the index (int(0.3 / 0.1) == 2)
            start = int(round(plot_duration[init_i][0] / dt))
            end = int(round(plot_duration[init_i][1] / dt))

            #   4.3 visualization
            for var_name in self.fast_var_names:
                s0 = traj_group.mon[self.slow_var_names[0]][start:end, 0]
                fast = traj_group.mon[var_name][start:end, 0]

                fig = plt.figure(var_name)
                if len(self.slow_var_names) == 1:
                    lines = plt.plot(s0, fast, label=legend)
                    utils.add_arrow(lines[0])
                elif len(self.slow_var_names) == 2:
                    ax = _axes_3d(fig)
                    s1 = traj_group.mon[self.slow_var_names[1]][start:end, 0]
                    ax.plot(s0, s1, fast, label=legend)
                else:
                    raise errors.AnalyzerError(
                        f'Only supports one or two slow variables, but we '
                        f'got {self.slow_var_names}.')

        # 5. axis labels and legends
        for var_name in self.fast_vars.keys():
            fig = plt.figure(var_name)

            if len(self.slow_var_names) == 1:
                plt.xlabel(self.slow_var_names[0])
                plt.ylabel(var_name)
            elif len(self.slow_var_names) == 2:
                ax = _axes_3d(fig)
                ax.set_xlabel(self.slow_var_names[0])
                ax.set_ylabel(self.slow_var_names[1])
                ax.set_zlabel(var_name)

            plt.legend()

        if show:
            plt.show()
Ejemplo n.º 9
0
    def plot_limit_cycle_by_sim(self,
                                var,
                                duration=100,
                                inputs=(),
                                plot_style=None,
                                tol=0.001,
                                show=False):
        """Find limit cycles by simulation and plot their max/min values.

        Every unstable node and unstable focus stored in ``self.fixed_points``
        (by ``plot_bifurcation()``) is used as an initial condition. The
        system is simulated for ``duration``, and
        ``utils.find_indexes_of_limit_cycle_max`` decides whether the trace of
        ``var`` ends on a cycle.

        Parameters
        ----------
        var : str
            The variable to analyze; must be ``self.x_var`` or ``self.y_var``.
        duration : int, float
            Simulation duration passed to ``Trajectory.run``.
        inputs : tuple
            Currently unused.
        plot_style : dict, optional
            Keyword arguments for the plot calls. The special key ``"fmt"``
            (default ``"."``) is used as the format string. The caller's dict
            is not modified.
        tol : float
            Tolerance forwarded to ``utils.find_indexes_of_limit_cycle_max``.
        show : bool
            Whether to call ``plt.show()``.

        Notes
        -----
        The results are stored in ``self.fixed_points['limit_cycle']``: the
        key ``var`` maps to ``{'max': [...], 'min': [...]}``, and the keys
        ``'p'`` (one parameter) or ``'p0'``/``'p1'`` (two parameters) hold the
        parameter values of the found cycles.

        Raises
        ------
        errors.AnalyzerError
            If ``plot_bifurcation()`` was not called, or ``var`` is invalid.
        ValueError
            If the number of bifurcation parameters is not 1 or 2.
        """
        print('plot limit cycle ...')

        if self.fixed_points is None:
            raise errors.AnalyzerError(
                'Please call "plot_bifurcation()" before "plot_limit_cycle_by_sim()".'
            )
        # copy, so popping "fmt" does not mutate the caller's dict
        plot_style = dict() if plot_style is None else dict(plot_style)
        fmt = plot_style.pop('fmt', '.')

        if var not in [self.x_var, self.y_var]:
            raise errors.AnalyzerError(
                f'"var" must be one of {[self.x_var, self.y_var]}, '
                f'but we got "{var}".')

        num_pars = len(self.dpar_names)
        if num_pars not in (1, 2):
            raise ValueError(
                f'Only supports one or two bifurcation parameters, but we '
                f'got {self.dpar_names}.')

        # unstable nodes and foci are the starting points of the simulation
        all_xs, all_ys, all_p0, all_p1 = [], [], [], []
        for fp_type in (stability.UNSTABLE_NODE_2D, stability.UNSTABLE_FOCUS_2D):
            points = self.fixed_points[fp_type]
            all_xs.extend(points[self.x_var])
            all_ys.extend(points[self.y_var])
            if num_pars == 1:
                all_p0.extend(points['p'])
            else:
                all_p0.extend(points['p0'])
                all_p1.extend(points['p1'])

        # format points
        all_xs = np.array(all_xs)
        all_ys = np.array(all_ys)
        all_p0 = np.array(all_p0)
        all_p1 = np.array(all_p1)

        # fixed variables, plus the bifurcation parameters of every point
        fixed_vars = dict(self.fixed_vars)
        fixed_vars[self.dpar_names[0]] = all_p0
        if num_pars == 2:
            fixed_vars[self.dpar_names[1]] = all_p1

        # initialize neuron group
        length = all_xs.shape[0]
        traj_group = Trajectory(size=length,
                                integrals=self.model.integrals,
                                target_vars={
                                    self.x_var: all_xs,
                                    self.y_var: all_ys
                                },
                                fixed_vars=fixed_vars,
                                pars_update=self.pars_update,
                                scope=self.model.scopes)
        traj_group.run(duration=duration)

        # find limit cycles
        limit_cycle_max = []
        limit_cycle_min = []
        p0_limit_cycle = []
        p1_limit_cycle = []
        for i in range(length):
            data = traj_group.mon[var][:, i]
            max_index = utils.find_indexes_of_limit_cycle_max(data, tol=tol)
            if max_index[0] != -1:
                x_cycle = data[max_index[0]:max_index[1]]
                limit_cycle_max.append(data[max_index[1]])
                limit_cycle_min.append(x_cycle.min())
                p0_limit_cycle.append(all_p0[i])
                if num_pars == 2:
                    p1_limit_cycle.append(all_p1[i])
        p0_limit_cycle = np.array(p0_limit_cycle)
        p1_limit_cycle = np.array(p1_limit_cycle)

        # Store everything in one dict (previously the max/min results were
        # overwritten by the parameter values). The parameter keys are
        # written last, so they win if ``var`` happens to share their name.
        results = {var: {'max': limit_cycle_max, 'min': limit_cycle_min}}
        if num_pars == 2:
            results['p0'] = p0_limit_cycle
            results['p1'] = p1_limit_cycle
        else:
            results['p'] = p0_limit_cycle
        self.fixed_points['limit_cycle'] = results

        # visualization
        if len(limit_cycle_max):
            fig = plt.figure(var)
            if num_pars == 2:
                # (p0, p1, var) needs a 3D axes; reuse the existing one if any
                for ax in reversed(fig.axes):
                    if ax.name == '3d':
                        break
                else:
                    ax = fig.add_subplot(projection='3d')
                fig.sca(ax)
                ax.plot(p0_limit_cycle,
                        p1_limit_cycle,
                        limit_cycle_max,
                        fmt,
                        **plot_style,
                        label='limit cycle (max)')
                ax.plot(p0_limit_cycle,
                        p1_limit_cycle,
                        limit_cycle_min,
                        fmt,
                        **plot_style,
                        label='limit cycle (min)')
            else:
                plt.plot(p0_limit_cycle,
                         limit_cycle_max,
                         fmt,
                         **plot_style,
                         label='limit cycle (max)')
                plt.plot(p0_limit_cycle,
                         limit_cycle_min,
                         fmt,
                         **plot_style,
                         label='limit cycle (min)')
            plt.legend()

        if show:
            plt.show()

        del traj_group
        gc.collect()