Ejemplo n.º 1
0
    def plot_fixed_point(self, show=False, with_plot=True, with_return=False):
        """Search the fixed points, analyze their stability, and plot them.

        Each fixed point ``x`` is classified by passing ``self.F_dfxdx(x)`` to
        ``stability.stability_analysis`` and is drawn on the line ``y = 0``
        with the marker style ``stability.plot_scheme[fp_type]``.

        Parameters
        ----------
        show : bool
            Whether to call ``plt.show()`` after plotting.
        with_plot : bool
            Whether to draw the fixed points.
        with_return : bool
            Whether to return the fixed points.

        Returns
        -------
        fps
            The fixed points found by ``self._get_fixed_points`` when
            ``with_return`` is True; otherwise None.
        """
        utils.output('I am searching fixed points ...')

        # fixed points and stability analysis; the other two returned values
        # are not needed here
        fps, _, _ = self._get_fixed_points(self.resolutions[self.x_var])
        container = {a: [] for a in stability.get_1d_stability_types()}
        for i, x in enumerate(fps):
            fp_type = stability.stability_analysis(self.F_dfxdx(x))
            utils.output(
                f"Fixed point #{i + 1} at {self.x_var}={x} is a {fp_type}.")
            container[fp_type].append(x)

        # visualization
        if with_plot:
            for fp_type, points in container.items():
                if len(points):
                    plot_style = stability.plot_scheme[fp_type]
                    plt.plot(points, [0] * len(points),
                             '.',
                             markersize=20,
                             **plot_style,
                             label=fp_type)
            plt.legend()
            if show:
                plt.show()

        # return
        if with_return:
            return fps
Ejemplo n.º 2
0
    def plot_fixed_point(self, show=False):
        """Plot the fixed points and analyze their stability.

        The fixed points come from calling the function returned by
        ``self.get_f_fixed_point()``; each one is classified from the Jacobian
        computed by the function returned by ``self.get_f_jacobian()``.

        Parameters
        ----------
        show : bool
            Whether to call ``plt.show()`` after plotting.

        Returns
        -------
        results : tuple
            ``(x_values, y_values)``: the coordinates of the fixed points
            along ``self.x_var`` and ``self.y_var``.
        """
        print('plot fixed point ...')

        # function for fixed point solving
        f_fixed_point = self.get_f_fixed_point()
        x_values, y_values = f_fixed_point()

        # function for jacobian matrix
        f_jacobian = self.get_f_jacobian()

        # stability analysis
        # ------------------
        container = {
            a: {'x': [], 'y': []}
            for a in stability.get_2d_stability_types()
        }
        for i, (x, y) in enumerate(zip(x_values, y_values)):
            fp_type = stability.stability_analysis(f_jacobian(x, y))
            print(
                f"Fixed point #{i + 1} at {self.x_var}={x}, {self.y_var}={y} is a {fp_type}."
            )
            container[fp_type]['x'].append(x)
            container[fp_type]['y'].append(y)

        # visualization
        # -------------
        for fp_type, points in container.items():
            if len(points['x']):
                plot_style = stability.plot_scheme[fp_type]
                plt.plot(points['x'],
                         points['y'],
                         '.',
                         markersize=20,
                         **plot_style,
                         label=fp_type)
        plt.legend()
        if show:
            plt.show()

        return x_values, y_values
Ejemplo n.º 3
0
    def plot_fixed_point(self, show=False):
        """Plot the fixed points and analyze their stability.

        The fixed points come from calling the function returned by
        ``self.get_f_fixed_point()``; each one is classified from the
        derivative computed by the function returned by ``self.get_f_dfdx()``
        and drawn on the line ``y = 0``.

        Parameters
        ----------
        show : bool
            Whether to call ``plt.show()`` after plotting.

        Returns
        -------
        points : np.ndarray
            The fixed points.
        """
        print('plot fixed point ...')

        # 1. functions
        f_fixed_point = self.get_f_fixed_point()
        f_dfdx = self.get_f_dfdx()

        # 2. stability analysis
        x_values = f_fixed_point()
        container = {a: [] for a in stability.get_1d_stability_types()}
        for i, x in enumerate(x_values):
            fp_type = stability.stability_analysis(f_dfdx(x))
            print(f"Fixed point #{i + 1} at {self.x_var}={x} is a {fp_type}.")
            container[fp_type].append(x)

        # 3. visualization
        for fp_type, points in container.items():
            if len(points):
                plot_style = stability.plot_scheme[fp_type]
                plt.plot(points, [0] * len(points),
                         '.',
                         markersize=20,
                         **plot_style,
                         label=fp_type)
        plt.legend()
        if show:
            plt.show()

        return np.array(x_values)
Ejemplo n.º 4
0
  def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
                       tol_aux=1e-8, loss_screen=None):
    """Run a bifurcation analysis of the one-variable system.

    The grid ``self.resolutions[self.x_var]`` is meshed with the grid of every
    target parameter, and each mesh point is handed to
    ``self._get_fixed_points`` as a candidate. Every fixed point found is
    classified from its derivative given by ``self.F_vmap_dfxdx``.

    Parameters
    ----------
    with_plot : bool
      Whether to draw the bifurcation diagram (a 2D plot for one target
      parameter, a 3D scatter for two).
    show : bool
      Whether to call ``plt.show()`` after drawing.
    with_return : bool
      Whether to return the analysis results.
    tol_aux : float
      Forwarded to ``self._get_fixed_points``.
    loss_screen : optional
      Forwarded to ``self._get_fixed_points``.

    Returns
    -------
    tuple or None
      ``(fixed_points, pars, dfxdx)`` when ``with_return`` is True, where
      ``pars`` is a tuple of numpy arrays and ``dfxdx`` a numpy array.

    Raises
    ------
    errors.BrainPyError
      If plotting is requested and the number of target parameters is
      neither one nor two.
    """
    utils.output('I am making bifurcation analysis ...')

    x_grid = self.resolutions[self.x_var]
    grids = bm.meshgrid(x_grid, *tuple(self.resolutions[name] for name in self.target_par_names))
    flattened = tuple(jnp.moveaxis(grid.value, 0, 1).flatten() for grid in grids)
    fixed_points, _, pars = self._get_fixed_points(flattened[0], *flattened[1:],
                                                   tol_aux=tol_aux,
                                                   loss_screen=loss_screen,
                                                   num_seg=len(x_grid))
    dfxdx = np.asarray(self.F_vmap_dfxdx(jnp.asarray(fixed_points), *pars))
    pars = tuple(np.asarray(par) for par in pars)

    if with_plot:
      num_par = len(self.target_pars)
      if num_par not in (1, 2):
        raise errors.BrainPyError(f'Cannot visualize co-dimension {num_par} '
                                  f'bifurcation.')

      if num_par == 1:
        # co-dimension 1: parameter on the x axis, state on the y axis
        groups = {kind: {'p': [], 'x': []} for kind in stability.get_1d_stability_types()}
        for par, point, slope in zip(pars[0], fixed_points, dfxdx):
          group = groups[stability.stability_analysis(slope)]
          group['p'].append(par)
          group['x'].append(point)

        plt.figure(self.x_var)
        for kind, group in groups.items():
          if group['x']:
            plt.plot(group['p'], group['x'], '.', **stability.plot_scheme[kind], label=kind)
        par_name = self.target_par_names[0]
        plt.xlabel(par_name)
        plt.ylabel(self.x_var)

        scale = (self.lim_scale - 1) / 2
        plt.xlim(*utils.rescale(self.target_pars[par_name], scale=scale))
        plt.ylim(*utils.rescale(self.target_vars[self.x_var], scale=scale))

        plt.legend()
      else:
        # co-dimension 2: both parameters on the horizontal axes, state on z
        groups = {kind: {'p0': [], 'p1': [], 'x': []} for kind in stability.get_1d_stability_types()}
        for par0, par1, point, slope in zip(pars[0], pars[1], fixed_points, dfxdx):
          group = groups[stability.stability_analysis(slope)]
          group['p0'].append(par0)
          group['p1'].append(par1)
          group['x'].append(point)

        ax = plt.figure(self.x_var).add_subplot(projection='3d')
        for kind, group in groups.items():
          if group['x']:
            ax.scatter(group['p0'], group['p1'], group['x'], **stability.plot_scheme[kind], label=kind)

        name0, name1 = self.target_par_names[0], self.target_par_names[1]
        ax.set_xlabel(name0)
        ax.set_ylabel(name1)
        ax.set_zlabel(self.x_var)

        scale = (self.lim_scale - 1) / 2
        ax.set_xlim(*utils.rescale(self.target_pars[name0], scale=scale))
        ax.set_ylim(*utils.rescale(self.target_pars[name1], scale=scale))
        ax.set_zlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))

        ax.grid(True)
        ax.legend()

      if show:
        plt.show()

    if with_return:
      return fixed_points, pars, dfxdx
Ejemplo n.º 5
0
  def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
                       tol_aux=1e-8, tol_unique=1e-2, tol_opt_candidate=None,
                       num_par_segments=1, num_fp_segment=1, nullcline_aux_filter=1.,
                       select_candidates='aux_rank', num_rank=100):
    r"""Make the bifurcation analysis.

    Parameters
    ----------
    with_plot: bool
      Whether plot the bifurcation figure.
    show: bool
      Whether show the figure.
    with_return: bool
      Whether return the computed bifurcation results.
    tol_aux: float
      The loss tolerance of auxiliary function :math:`f_{aux}` to confirm the fixed
      point. Default is 1e-8. Once :math:`f_{aux}(x_1) < \mathrm{tol\_aux}`,
      :math:`x_1` will be a fixed point.
    tol_unique: float
      The tolerance of distance between candidate fixed points to confirm they are
      the same. Default is 1e-2. If :math:`|x_1 - x_2| > \mathrm{tol\_unique}`,
      then :math:`x_1` and :math:`x_2` are unique fixed points. Otherwise,
      :math:`x_1` and :math:`x_2` will be treated as a same fixed point.
    tol_opt_candidate: float, optional
      The tolerance of auxiliary function :math:`f_{aux}` to select candidate
      initial points for fixed point optimization.
    num_par_segments: int, sequence of int
      How to segment parameters.
    num_fp_segment: int
      How to segment fixed points.
    nullcline_aux_filter: float
      Passed as ``fp_aux_filter`` when collecting nullcline points (the
      ``'fx-nullcline'``, ``'fy-nullcline'`` and ``'nullclines'`` methods).
      It must be positive when ``select_candidates='aux_rank'``.
    select_candidates: str
      The method to select candidate fixed points. It can be:

      - ``fx-nullcline``: use the points of fx-nullcline.
      - ``fy-nullcline``: use the points of fy-nullcline.
      - ``nullclines``: use the points in both of fx-nullcline and fy-nullcline.
      - ``aux_rank``: use the minimal value of points for the auxiliary function.
    num_rank: int
      The number of candidates to be used to optimize the fixed points
      (only used by ``select_candidates='aux_rank'``).

    Returns
    -------
    results : tuple
      Only when ``with_return`` is True, return a tuple of analyzed results:

      - fixed points: a 2D matrix with the shape of (num_point, num_var)
      - parameters: a 2D matrix with the shape of (num_point, num_par)
      - jacobians: a 3D tensors with the shape of (num_point, 2, 2)

      When no fixed point is found, the three arrays have zero rows and
      nothing is plotted.

    Raises
    ------
    ValueError
      If ``select_candidates`` is unknown, or if plotting is requested for a
      number of target parameters other than one or two.
    """
    utils.output('I am making bifurcation analysis ...')

    # 1. candidate initial points and the parameter values attached to them
    if self._can_convert_to_one_eq():
      if self.convert_type() == C.x_by_y:
        X = self.resolutions[self.y_var].value
      else:
        X = self.resolutions[self.x_var].value
      pars = tuple(self.resolutions[p].value for p in self.target_par_names)
      mesh_values = jnp.meshgrid(*((X,) + pars))
      mesh_values = tuple(jnp.moveaxis(v, 0, 1).flatten() for v in mesh_values)
      candidates = mesh_values[0]
      parameters = mesh_values[1:]

    else:
      if select_candidates == 'fx-nullcline':
        fx_nullclines = self._get_fx_nullcline_points(num_segments=num_par_segments,
                                                      fp_aux_filter=nullcline_aux_filter)
        candidates = fx_nullclines[0]
        parameters = fx_nullclines[1:]
      elif select_candidates == 'fy-nullcline':
        fy_nullclines = self._get_fy_nullcline_points(num_segments=num_par_segments,
                                                      fp_aux_filter=nullcline_aux_filter)
        candidates = fy_nullclines[0]
        parameters = fy_nullclines[1:]
      elif select_candidates == 'nullclines':
        fx_nullclines = self._get_fx_nullcline_points(num_segments=num_par_segments,
                                                      fp_aux_filter=nullcline_aux_filter)
        fy_nullclines = self._get_fy_nullcline_points(num_segments=num_par_segments,
                                                      fp_aux_filter=nullcline_aux_filter)
        candidates = jnp.vstack([fx_nullclines[0], fy_nullclines[0]])
        parameters = [jnp.concatenate([fx_nullclines[i], fy_nullclines[i]])
                      for i in range(1, len(fy_nullclines))]
      elif select_candidates == 'aux_rank':
        assert nullcline_aux_filter > 0.
        candidates, parameters = self._get_fp_candidates_by_aux_rank(num_segments=num_par_segments,
                                                                     num_rank=num_rank)
      else:
        raise ValueError(f'Unknown "select_candidates": {select_candidates!r}. It should be '
                         f'one of "fx-nullcline", "fy-nullcline", "nullclines", "aux_rank".')

    # 2. optimize candidates into fixed points
    candidates, _, parameters = self._get_fixed_points(candidates,
                                                       *parameters,
                                                       tol_aux=tol_aux,
                                                       tol_unique=tol_unique,
                                                       tol_opt_candidate=tol_opt_candidate,
                                                       num_segment=num_fp_segment)
    candidates = np.asarray(candidates)
    parameters = np.stack(tuple(np.asarray(p) for p in parameters)).T

    if len(candidates) == 0:
      # Without any point the de-duplication below ends with `np.vstack([])`,
      # which raises; report the empty result instead.
      utils.output(f'{C.prefix}Found no fixed points.')
      num_var = len(self.target_var_names)
      final_fps = np.zeros((0, num_var))
      final_pars = np.zeros((0, parameters.shape[1]))
      jacobians = np.zeros((0, num_var, num_var))
      self._fixed_points = (final_fps, final_pars)
      if with_return:
        return final_fps, final_pars, jacobians
      return

    # 3. remove duplicated fixed points separately for each parameter setting
    utils.output('I am trying to filter out duplicate fixed points ...')
    final_fps = []
    final_pars = []
    for par in np.unique(parameters, axis=0):
      ids = np.where(np.all(parameters == par, axis=1))[0]
      fps, ids2 = utils.keep_unique(candidates[ids], tolerance=tol_unique)
      final_fps.append(fps)
      final_pars.append(parameters[ids[ids2]])
    final_fps = np.vstack(final_fps)  # with the shape of (num_point, num_var)
    final_pars = np.vstack(final_pars)  # with the shape of (num_point, num_par)
    jacobians = np.asarray(self.F_vmap_jacobian(jnp.asarray(final_fps), *final_pars.T))
    utils.output(f'{C.prefix}Found {len(final_fps)} fixed points.')

    # remember the fixed points for later limit cycle plotting
    self._fixed_points = (final_fps, final_pars)

    if with_plot:
      scale = (self.lim_scale - 1) / 2

      # bifurcation analysis of co-dimension 1
      if len(self.target_pars) == 1:
        container = {c: {'p': [], self.x_var: [], self.y_var: []}
                     for c in stability.get_2d_stability_types()}

        # fixed point
        for p, xy, J in zip(final_pars, final_fps, jacobians):
          fp_type = stability.stability_analysis(J)
          container[fp_type]['p'].append(p[0])
          container[fp_type][self.x_var].append(xy[0])
          container[fp_type][self.y_var].append(xy[1])

        # visualization: one figure per variable
        for var in self.target_var_names:
          plt.figure(var)
          for fp_type, points in container.items():
            if len(points['p']):
              plot_style = stability.plot_scheme[fp_type]
              plt.plot(points['p'], points[var], '.', **plot_style, label=fp_type)
          plt.xlabel(self.target_par_names[0])
          plt.ylabel(var)

          plt.xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=scale))
          plt.ylim(*utils.rescale(self.target_vars[var], scale=scale))

          plt.legend()
        if show:
          plt.show()

      # bifurcation analysis of co-dimension 2
      elif len(self.target_pars) == 2:
        container = {c: {'p0': [], 'p1': [], self.x_var: [], self.y_var: []}
                     for c in stability.get_2d_stability_types()}

        # fixed point
        for p, xy, J in zip(final_pars, final_fps, jacobians):
          fp_type = stability.stability_analysis(J)
          container[fp_type]['p0'].append(p[0])
          container[fp_type]['p1'].append(p[1])
          container[fp_type][self.x_var].append(xy[0])
          container[fp_type][self.y_var].append(xy[1])

        # visualization: one 3D figure per variable
        for var in self.target_var_names:
          fig = plt.figure(var)
          ax = fig.add_subplot(projection='3d')
          for fp_type, points in container.items():
            if len(points['p0']):
              plot_style = stability.plot_scheme[fp_type]
              xs = points['p0']
              ys = points['p1']
              zs = points[var]
              ax.scatter(xs, ys, zs, **plot_style, label=fp_type)

          ax.set_xlabel(self.target_par_names[0])
          ax.set_ylabel(self.target_par_names[1])
          ax.set_zlabel(var)
          ax.set_xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=scale))
          ax.set_ylim(*utils.rescale(self.target_pars[self.target_par_names[1]], scale=scale))
          ax.set_zlim(*utils.rescale(self.target_vars[var], scale=scale))
          ax.grid(True)
          ax.legend()
        if show:
          plt.show()

      else:
        raise ValueError('Unknown length of parameters.')

    if with_return:
      return final_fps, final_pars, jacobians
Ejemplo n.º 6
0
    def plot_fixed_point(
        self,
        with_plot=True,
        with_return=False,
        show=False,
        tol_unique=1e-2,
        tol_aux=1e-8,
        tol_opt_screen=None,
        select_candidates='fx-nullcline',
        num_rank=100,
    ):
        """Plot the fixed point and analyze its stability.

        Parameters
        ----------
        with_plot : bool
            Whether to draw the fixed points.
        with_return : bool
            Whether to return the fixed points.
        show : bool
            Whether to call ``plt.show()`` after plotting.
        tol_unique : float
            Distance under which two fixed points are treated as the same one
            (passed to ``self._get_fixed_points`` and ``utils.keep_unique``).
        tol_aux : float
            Loss tolerance of the auxiliary function used to confirm a fixed
            point (passed to ``self._get_fixed_points``).
        tol_opt_screen : float, optional
            Forwarded to ``self._get_fixed_points`` as ``tol_opt_candidate``.
        select_candidates : str
            How initial candidates are chosen when the system cannot be
            converted to one equation:

            - ``'fx-nullcline'``: stored points of the fx-nullcline.
            - ``'fy-nullcline'``: stored points of the fy-nullcline.
            - ``'nullclines'``: stored points of both nullclines.
            - ``'aux_rank'``: candidates ranked by the auxiliary function.
        num_rank : int
            Number of candidates used when ``select_candidates='aux_rank'``.

        Returns
        -------
        fixed_points : np.ndarray or None
            The de-duplicated fixed points, one row per point, when
            ``with_return`` is True. None is returned early when there are
            no candidates.

        Raises
        ------
        errors.AnalyzerError
            If a nullcline-based method is selected but no nullcline points
            are stored (``plot_nullcline`` has not been called).
        ValueError
            If ``select_candidates`` is not a supported method.
        """
        utils.output('I am searching fixed points ...')

        if self._can_convert_to_one_eq():
            if self.convert_type() == C.x_by_y:
                candidates = self.resolutions[self.y_var].value
            else:
                candidates = self.resolutions[self.x_var].value
        elif select_candidates == 'aux_rank':
            candidates, _ = self._get_fp_candidates_by_aux_rank(
                num_rank=num_rank)
        else:
            # prefixes of the stored nullcline keys used by each method;
            # 'nullclines' merges the fx- and the fy-nullcline points
            nullcline_prefixes = {
                'fx-nullcline': (C.fx_nullcline_points,),
                'fy-nullcline': (C.fy_nullcline_points,),
                'nullclines': (C.fx_nullcline_points, C.fy_nullcline_points),
            }
            if select_candidates not in nullcline_prefixes:
                raise ValueError(
                    f'Unknown "select_candidates": {select_candidates!r}. It should be one of '
                    f'"fx-nullcline", "fy-nullcline", "nullclines", "aux_rank".')
            prefixes = nullcline_prefixes[select_candidates]
            candidates = [
                self.analyzed_results[key][0]
                for key in self.analyzed_results.keys()
                if any(key.startswith(prefix) for prefix in prefixes)
            ]
            if len(candidates) == 0:
                raise errors.AnalyzerError(
                    f'No nullcline points are found, please call '
                    f'".{self.plot_nullcline.__name__}()" first.')
            candidates = jnp.vstack(candidates)

        # get fixed points
        if len(candidates):
            fixed_points, _, _ = self._get_fixed_points(
                jnp.asarray(candidates),
                tol_aux=tol_aux,
                tol_unique=tol_unique,
                tol_opt_candidate=tol_opt_screen)
            utils.output(
                'I am trying to filter out duplicate fixed points ...')
            fixed_points = np.asarray(fixed_points)
            fixed_points, _ = utils.keep_unique(fixed_points,
                                                tolerance=tol_unique)
            utils.output(f'{C.prefix}Found {len(fixed_points)} fixed points.')
        else:
            utils.output(f'{C.prefix}Found no fixed points.')
            return

        # stability analysis
        # ------------------
        container = {
            a: {'x': [], 'y': []}
            for a in stability.get_2d_stability_types()
        }
        for i, point in enumerate(fixed_points):
            x, y = point[0], point[1]
            fp_type = stability.stability_analysis(self.F_jacobian(x, y))
            utils.output(
                f"{C.prefix}#{i + 1} {self.x_var}={x}, {self.y_var}={y} is a {fp_type}."
            )
            container[fp_type]['x'].append(x)
            container[fp_type]['y'].append(y)

        # visualization
        # -------------
        if with_plot:
            for fp_type, points in container.items():
                if len(points['x']):
                    plot_style = stability.plot_scheme[fp_type]
                    plt.plot(points['x'],
                             points['y'],
                             '.',
                             markersize=20,
                             **plot_style,
                             label=fp_type)
            plt.legend()
            if show:
                plt.show()

        if with_return:
            return fixed_points
Ejemplo n.º 7
0
    def plot_bifurcation(self, show=False):
        """Plot the bifurcation diagram of the two-variable system.

        For every value (one target parameter) or every pair of values (two
        target parameters) taken from ``self.resolutions``, the fixed points
        returned by the function from ``self.get_f_fixed_point()`` are
        classified with the Jacobian from ``self.get_f_jacobian()`` and drawn
        in one figure per variable in ``self.dvar_names``.

        Parameters
        ----------
        show : bool
            Whether to call ``plt.show()`` after plotting.

        Returns
        -------
        container : dict
            Maps each stability type to a dict of lists with the parameter
            values (``'p'``, or ``'p0'`` and ``'p1'``) and the fixed point
            coordinates keyed by ``self.x_var`` and ``self.y_var``. The same
            dict is also stored in ``self.fixed_points``.

        Raises
        ------
        ValueError
            If the number of target parameters is neither 1 nor 2.
        """
        print('plot bifurcation ...')

        # functions
        f_fixed_point = self.get_f_fixed_point()
        f_jacobian = self.get_f_jacobian()

        # bifurcation analysis of co-dimension 1
        if len(self.target_pars) == 1:
            container = {
                c: {
                    'p': [],
                    self.x_var: [],
                    self.y_var: []
                }
                for c in stability.get_2d_stability_types()
            }

            # fixed point
            for p in self.resolutions[self.dpar_names[0]]:
                xs, ys = f_fixed_point(p)
                for x, y in zip(xs, ys):
                    dfdx = f_jacobian(x, y, p)
                    fp_type = stability.stability_analysis(dfdx)
                    container[fp_type]['p'].append(p)
                    container[fp_type][self.x_var].append(x)
                    container[fp_type][self.y_var].append(y)

            # visualization
            for var in self.dvar_names:
                plt.figure(var)
                for fp_type, points in container.items():
                    if len(points['p']):
                        plot_style = stability.plot_scheme[fp_type]
                        plt.plot(points['p'],
                                 points[var],
                                 '.',
                                 **plot_style,
                                 label=fp_type)
                plt.xlabel(self.dpar_names[0])
                plt.ylabel(var)

                # scale = (self.options.lim_scale - 1) / 2
                # plt.xlim(*utils.rescale(self.target_pars[self.dpar_names[0]], scale=scale))
                # plt.ylim(*utils.rescale(self.target_vars[var], scale=scale))

                plt.legend()
            if show:
                plt.show()

        # bifurcation analysis of co-dimension 2
        elif len(self.target_pars) == 2:
            container = {
                c: {
                    'p0': [],
                    'p1': [],
                    self.x_var: [],
                    self.y_var: []
                }
                for c in stability.get_2d_stability_types()
            }

            # fixed point
            for p0 in self.resolutions[self.dpar_names[0]]:
                for p1 in self.resolutions[self.dpar_names[1]]:
                    xs, ys = f_fixed_point(p0, p1)
                    for x, y in zip(xs, ys):
                        dfdx = f_jacobian(x, y, p0, p1)
                        fp_type = stability.stability_analysis(dfdx)
                        container[fp_type]['p0'].append(p0)
                        container[fp_type]['p1'].append(p1)
                        container[fp_type][self.x_var].append(x)
                        container[fp_type][self.y_var].append(y)

            # visualization
            for var in self.dvar_names:
                fig = plt.figure(var)
                # `Figure.gca()` no longer accepts keyword arguments
                # (Matplotlib >= 3.6); create the 3D axes explicitly.
                ax = fig.add_subplot(projection='3d')
                for fp_type, points in container.items():
                    if len(points['p0']):
                        plot_style = stability.plot_scheme[fp_type]
                        xs = points['p0']
                        ys = points['p1']
                        zs = points[var]
                        ax.scatter(xs, ys, zs, **plot_style, label=fp_type)

                ax.set_xlabel(self.dpar_names[0])
                ax.set_ylabel(self.dpar_names[1])
                ax.set_zlabel(var)

                # scale = (self.options.lim_scale - 1) / 2
                # ax.set_xlim(*utils.rescale(self.target_pars[self.dpar_names[0]], scale=scale))
                # ax.set_ylim(*utils.rescale(self.target_pars[self.dpar_names[1]], scale=scale))
                # ax.set_zlim(*utils.rescale(self.target_vars[var], scale=scale))

                ax.grid(True)
                ax.legend()
            if show:
                plt.show()

        else:
            raise ValueError('Unknown length of parameters.')

        self.fixed_points = container
        return container
Ejemplo n.º 8
0
    def plot_bifurcation(self, show=False):
        """Plot the bifurcation diagram of the one-variable system.

        For every value (one target parameter) or every pair of values (two
        target parameters) taken from ``self.resolutions``, the fixed points
        returned by the function from ``self.get_f_fixed_point()`` are
        classified with the derivative from ``self.get_f_dfdx()`` and drawn in
        the figure named ``self.x_var``.

        Parameters
        ----------
        show : bool
            Whether to call ``plt.show()`` after plotting.

        Returns
        -------
        container : dict
            Maps each stability type to a dict of lists with the parameter
            values (``'p'``, or ``'p0'`` and ``'p1'``) and the fixed points
            (``'x'``).

        Raises
        ------
        errors.ModelUseError
            If the number of target parameters is neither 1 nor 2.
        """
        print('plot bifurcation ...')

        f_fixed_point = self.get_f_fixed_point()
        f_dfdx = self.get_f_dfdx()

        if len(self.target_pars) == 1:
            container = {
                c: {
                    'p': [],
                    'x': []
                }
                for c in stability.get_1d_stability_types()
            }

            # fixed point
            par_a = self.dpar_names[0]
            for p in self.resolutions[par_a]:
                xs = f_fixed_point(p)
                for x in xs:
                    dfdx = f_dfdx(x, p)
                    fp_type = stability.stability_analysis(dfdx)
                    container[fp_type]['p'].append(p)
                    container[fp_type]['x'].append(x)

            # visualization
            plt.figure(self.x_var)
            for fp_type, points in container.items():
                if len(points['x']):
                    plot_style = stability.plot_scheme[fp_type]
                    plt.plot(points['p'],
                             points['x'],
                             '.',
                             **plot_style,
                             label=fp_type)
            plt.xlabel(par_a)
            plt.ylabel(self.x_var)

            # scale = (self.options.lim_scale - 1) / 2
            # plt.xlim(*utils.rescale(self.target_pars[self.dpar_names[0]], scale=scale))
            # plt.ylim(*utils.rescale(self.target_vars[self.x_var], scale=scale))

            plt.legend()
            if show:
                plt.show()

        elif len(self.target_pars) == 2:
            container = {
                c: {
                    'p0': [],
                    'p1': [],
                    'x': []
                }
                for c in stability.get_1d_stability_types()
            }

            # fixed point
            for p0 in self.resolutions[self.dpar_names[0]]:
                for p1 in self.resolutions[self.dpar_names[1]]:
                    xs = f_fixed_point(p0, p1)
                    for x in xs:
                        dfdx = f_dfdx(x, p0, p1)
                        fp_type = stability.stability_analysis(dfdx)
                        container[fp_type]['p0'].append(p0)
                        container[fp_type]['p1'].append(p1)
                        container[fp_type]['x'].append(x)

            # visualization
            fig = plt.figure(self.x_var)
            # `Figure.gca()` no longer accepts keyword arguments
            # (Matplotlib >= 3.6); create the 3D axes explicitly.
            ax = fig.add_subplot(projection='3d')
            for fp_type, points in container.items():
                if len(points['x']):
                    plot_style = stability.plot_scheme[fp_type]
                    xs = points['p0']
                    ys = points['p1']
                    zs = points['x']
                    ax.scatter(xs, ys, zs, **plot_style, label=fp_type)

            ax.set_xlabel(self.dpar_names[0])
            ax.set_ylabel(self.dpar_names[1])
            ax.set_zlabel(self.x_var)

            # scale = (self.options.lim_scale - 1) / 2
            # ax.set_xlim(*utils.rescale(self.target_pars[self.dpar_names[0]], scale=scale))
            # ax.set_ylim(*utils.rescale(self.target_pars[self.dpar_names[1]], scale=scale))
            # ax.set_zlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))

            ax.grid(True)
            ax.legend()
            if show:
                plt.show()

        else:
            raise errors.ModelUseError(
                f'Cannot visualize co-dimension {len(self.target_pars)} '
                f'bifurcation.')
        return container