Exemplo n.º 1
0
    def draw_node(ax: pyplot.Axes,
                  node_info: Dict[str, Any]) -> List[pyplot.Artist]:
        """Draw a single node described by ``node_info`` onto ``ax``.

        A dot is always plotted at ``node_info['xy']``. Unless the node's
        ``label`` is ``'dead'``, a translucent circle of radius
        ``node_info['r']`` is added around it, with opacity
        ``power / total_power * 0.1``. When ``node_info['last_node']`` is
        truthy, a faint red segment is also drawn from the node to that point.

        Returns the list of artists that were created.
        """
        x, y = node_info['xy'][0], node_info['xy'][1]
        color = node_info['color']
        created = list(ax.plot(x, y, '.', color=color))

        # dead nodes only get the dot
        if node_info['label'] in ('dead', ):
            return created

        opacity = node_info['power'] / node_info['total_power'] * 0.1
        halo = pyplot.Circle(xy=node_info['xy'],
                             radius=node_info['r'],
                             alpha=opacity,
                             color=color)
        ax.add_artist(halo)
        created.append(halo)

        previous = node_info['last_node']
        if previous:
            created += ax.plot([x, previous[0]],
                               [y, previous[1]],
                               marker='_',
                               linewidth=1,
                               alpha=0.2,
                               color='red')
        return created
Exemplo n.º 2
0
def add_legend(axes: plt.Axes, plots: Sequence[lines.Line2D],
               **kwargs: Any) -> None:
    """
    Add a legend for ``plots`` to the axes object.

    Font sizes fall back to the module-level ``default['fontsize']`` and the
    frame line width is taken from ``default['borderwidth']``.

    Parameters
    ----------
    axes : matplotlib.pyplot.Axes
        The axes object.
    plots : Sequence[matplotlib.lines.Line2D]
        The plotted data containing the legends; each handle's
        ``get_label()`` is used as its legend entry.
    kwargs : Any
        Additional kwargs which will be passed to the axes.legend method.
        ``edgecolor`` may be supplied to override the default black frame
        colour.

    Returns
    -------
    None
    """
    labels = [plot.get_label() for plot in plots]
    kwargs.setdefault('fontsize', default['fontsize'])
    kwargs.setdefault('title_fontsize', default['fontsize'])
    # Pop edgecolor so a caller-supplied value does not collide with the
    # explicit keyword below (that combination used to raise TypeError).
    edgecolor = kwargs.pop('edgecolor', 'k')
    legend = axes.legend(plots, labels, edgecolor=edgecolor, **kwargs)
    frame = legend.get_frame()
    frame.set_linewidth(default['borderwidth'])
    frame.set_edgecolor(edgecolor)
    # Registering the legend as a plain artist keeps it on the axes even if
    # axes.legend() is called again later (presumably the intent here).
    axes.add_artist(legend)
Exemplo n.º 3
0
def plot_multipolygon(geom: Union[Multipolygon, Iterable[Multipolygon]], *args,
        ax: Axes=None, crs: CRS=None, **kwargs):
    """ Plot a Multipolygon geometry, projected to the coordinate system `crs`.

    For every vertex set returned by ``geom.get_vertices(crs=crs)``, a closed,
    read-only ``matplotlib.path.Path`` is built from its first element
    (``vertices[0]``), keeping only the first two columns. All paths are
    added to ``ax`` as a single ``PathCollection``.

    NOTE(review): the annotation also allows an iterable of Multipolygons,
    but only ``geom.get_vertices`` is called here. Confirm that iterables
    actually provide that method.

    Parameters
    ----------
    geom
        Geometry exposing ``get_vertices(crs=...)``.
    *args, **kwargs
        Forwarded to ``PathCollection``. ``facecolors`` defaults to
        ``"none"`` and ``edgecolors`` to ``"black"``.
    ax
        Target axes. Defaults to the current pyplot axes when omitted
        (previously an omitted ``ax`` crashed with AttributeError).
    crs
        Coordinate system passed to ``geom.get_vertices``.

    Returns
    -------
    matplotlib.collections.PathCollection
        The collection that was added to ``ax``.
    """
    kwargs.setdefault("facecolors", "none")
    kwargs.setdefault("edgecolors", "black")
    paths = [matplotlib.path.Path(vertices[0][:, :2], closed=True, readonly=True)
             for vertices in geom.get_vertices(crs=crs)]
    coll = matplotlib.collections.PathCollection(paths, *args, **kwargs)
    if ax is None:
        # Local import: pyplot is only needed for this fallback.
        from matplotlib import pyplot
        ax = pyplot.gca()
    ax.add_artist(coll)
    return coll
Exemplo n.º 4
0
def median_plot(axes: plt.Axes, x: np.ndarray, y: np.ndarray,  **kwargs):
    """Plot binned median and 16th/84th percentile curves of ``y`` against ``x``.

    A fixed black legend (84th percentile ``^``, median ``o``, 16th
    percentile ``v``) is placed in the lower-right corner and registered as
    an artist. The binned statistics come from
    ``utils.medians_2d(x, y, **kwargs)``. Three error-bar series are drawn,
    each using ``err_y`` as its vertical error.

    NOTE(review): the percentile curves reuse ``err_y``; confirm that
    ``medians_2d`` does not provide dedicated percentile errors.
    """
    handle_style = dict(color='k', linewidth=1, linestyle='-', markersize=3)
    legend_handles = [
        Line2D([], [], marker='^', label=r'$84^{th}$ percentile', **handle_style),
        Line2D([], [], marker='o', label=r'median', **handle_style),
        Line2D([], [], marker='v', label=r'$16^{th}$ percentile', **handle_style),
    ]
    legend = axes.legend(handles=legend_handles, loc='lower right', handlelength=2)
    axes.add_artist(legend)

    stats = utils.medians_2d(x, y, **kwargs)
    # Drawn in this order: median, 16th percentile, 84th percentile.
    series = (('median_y', 'o'), ('percent16_y', 'v'), ('percent84_y', '^'))
    for y_key, marker in series:
        axes.errorbar(stats['median_x'], stats[y_key], yerr=stats['err_y'],
                      marker=marker, ms=2, alpha=1, linestyle='-', capsize=0,
                      linewidth=0.5)
Exemplo n.º 5
0
    def __setup_subplot__(self, rhythm_loop: RhythmLoop, axes: plt.Axes, **kw):
        """Prepare ``axes`` for a circular rhythm plot and draw its background.

        The axes get an equal aspect ratio on a [0, 1] x [0, 1] viewport. Every
        other measure is shaded with a translucent gray wedge of radius 1.0.
        A white base circle (radius 0.3, centred at (0.5, 0.5)) is added next,
        followed by a gray wedge for every other pulse inside that circle.

        ``kw['n_pulses']`` gives the number of pulses around the circle.
        Returns the base circle artist.
        """
        # equal aspect so the circle is not stretched
        axes.axis('equal')
        # noinspection PyTypeChecker
        axes.axis([0, 1, 0, 1])

        circle_radius = 0.3
        circle_center = 0.5, 0.5

        unit = self.get_unit()
        n_pulses = kw['n_pulses']
        pulses_per_measure = int(rhythm_loop.get_measure_duration(unit))

        def pulse_angle(pulse):
            # degrees: pulse 0 sits at 90 (top), later pulses run clockwise
            return (90 - (pulse / n_pulses * 360)) % 360

        def shade(first_pulse, last_pulse, radius=circle_radius):
            # translucent gray wedge covering first_pulse .. last_pulse
            start_angle = pulse_angle(last_pulse)
            end_angle = pulse_angle(first_pulse)
            axes.add_artist(Wedge(circle_center, radius, start_angle, end_angle,
                                  fc=to_rgba("gray", 0.25)))

        try:
            n_measures = int(n_pulses / pulses_per_measure)
        except ZeroDivisionError:
            n_measures = 0

        # every other measure gets a large (radius 1.0) wedge
        for measure in range(0, n_measures, 2):
            shade(measure * pulses_per_measure,
                  (measure + 1) * pulses_per_measure,
                  radius=1.0)

        # white base circle, added after the measure wedges
        circle = plt.Circle(circle_center, circle_radius, fc="white")
        axes.add_artist(circle)

        # every other pulse gets a wedge of the base circle's radius
        for pulse in range(0, n_pulses, 2):
            shade(pulse, pulse + 1)

        return circle
Exemplo n.º 6
0
def draw_agent(state: typing.Union[torch.Tensor, np.ndarray],
               color: typing.Union[np.ndarray, str],
               env_axes: typing.Tuple[typing.Tuple[float, float],
                                      typing.Tuple[float, float]],
               ax: plt.Axes,
               is_robot: bool = False,
               scale: float = 1.0):
    """Add an agent icon at the agent's position on ``ax``.

    If the position ``(state[0], state[1])`` is not strictly inside the scene
    limits ``env_axes = ((x_min, x_max), (y_min, y_max))``, nothing is drawn
    and ``None`` is returned.

    Otherwise the icon ``third_party/visualization/robot.png`` (robot) or
    ``walking.png`` (pedestrian) is loaded and processed as follows:
    - Its black pixels (all RGB channels within 0.1 of zero) are recoloured
      to ``color``.
    - If ``state`` has more than two entries it is read as position+velocity
      ``(x, y, vx, vy)``, and the icon is mirrored left-right when
      ``-pi/2 < atan2(vy, vx) < pi/2``.
    - Pixels whose RGB sum is >= 2.0 (near white) are made transparent and
      all others opaque.

    The processed icon is added as an ``AnnotationBbox`` zoomed by ``scale``,
    and ``ax`` is returned.

    ``color`` may be an RGB array or a matplotlib colour string. Strings
    previously crashed when written into the float image.
    """
    if not (env_axes[0][0] < state[0] < env_axes[0][1]) or not (
            env_axes[1][0] < state[1] < env_axes[1][1]):
        return
    # isinstance also covers Tensor subclasses (e.g. nn.Parameter); .cpu()
    # makes GPU tensors convertible and is a no-op for CPU tensors.
    if isinstance(state, torch.Tensor):
        state = state.detach().cpu().numpy()
    if isinstance(color, str):
        # `from` import so the global name `matplotlib` is not shadowed locally.
        from matplotlib.colors import to_rgb
        color = to_rgb(color)

    # Read image (differentiate between robot and pedestrian).
    icon_name = "robot.png" if is_robot else "walking.png"
    image_path = mantrap.utility.io.build_os_path(
        os.path.join("third_party", "visualization", icon_name))
    image = matplotlib.image.imread(image_path)

    # Recolour black pixels: a pixel is black only if ALL of R, G, B are ~0
    # (the old per-channel test also caught e.g. pure red or blue pixels).
    is_black = np.all(np.isclose(image[:, :, 0:3], 0.0, atol=0.1), axis=2)
    image[is_black, 0:3] = color

    # Orient the icon along the velocity vector if one is given in the state.
    if state.size > 2:
        orientation = np.arctan2(state[3], state[2])  # theta = arctan2(vy/vx)
        if -np.pi / 2 < orientation < np.pi / 2:
            image = np.fliplr(image)

    # Transparent white background: opaque where the RGB sum is below 2.0.
    # Vectorised form of the former per-pixel loop (same result per pixel).
    rgb_sum = image[:, :, :3].sum(axis=2)
    image[:, :, -1] = np.where(rgb_sum < 2.0, 1.0, 0.0)

    # Draw image in axes using `OffsetImage`.
    image_box = matplotlib.offsetbox.OffsetImage(image, zoom=scale)
    ab = matplotlib.offsetbox.AnnotationBbox(image_box, (state[0], state[1]),
                                             frameon=False)
    ax.add_artist(ab)

    return ax
Exemplo n.º 7
0
    def __init__(self, x, y, ax: plt.Axes, *, radius=0.05, index=None):
        """Create a circular node at ``(x, y)`` on ``ax`` with a text label.

        The circle (line width 2, edge ``#84affa``, fill ``#c6e5ff``) is added
        to ``ax`` as a patch. The label text is ``index`` when given,
        otherwise ``str(id(self) % 100)``. ``tos`` and ``frs`` start as empty
        dicts (presumably outgoing/incoming links; verify with the callers).
        """
        self.circle = plt.Circle((x, y), radius, lw=2,
                                 edgecolor="#84affa", facecolor="#c6e5ff")
        self.ax = ax
        ax.add_patch(self.circle)

        label = str(id(self) % 100) if index is None else index
        self.index = plt.Text(x, y, label, size="small")
        ax.add_artist(self.index)

        self.tos = {}
        self.frs = {}
Exemplo n.º 8
0
    def __draw_track__(self, rhythm_track: Track, axes: plt.Axes, **kw):
        """Fill one grid cell per onset of ``rhythm_track``.

        Each cell is a 1x1 rectangle at column ``onset_pos`` and row
        ``kw['track_ix']``, offset by ``kw['setup_ret']['grid_box']``. It is
        filled with ``kw['color']`` and outlined in black. Onsets at or
        beyond the track duration (in ``self.unit``) are skipped.

        Returns a placeholder unit rectangle at the origin that is not added
        to the axes.
        """
        grid_box = kw['setup_ret']['grid_box']  # type: Rectangle2D

        extractor = self.get_feature_extractor("onset_positions")
        onset_positions = extractor.process(rhythm_track)
        n_steps = rhythm_track.get_duration(self.unit)
        track_ix = kw['track_ix']

        # a coarse plotter unit can push onsets past the grid; drop those
        visible_onsets = (pos for pos in onset_positions if pos < n_steps)
        for onset_pos in visible_onsets:
            cell = plt.Rectangle([grid_box.x + onset_pos, grid_box.y + track_ix],
                                 width=1,
                                 height=1,
                                 facecolor=kw['color'],
                                 edgecolor="black",
                                 linewidth=0.75,
                                 joinstyle="miter")
            axes.add_artist(cell)

        return plt.Rectangle([0, 0], 1, 1)
Exemplo n.º 9
0
    def to_plot(self, n: int = 200, ax: plt.Axes = None, arrows: bool = False):
        """
        Prism.to_plot(n)

        Scatter-plot ``n`` points sampled via ``self.get_point()`` in 3D.

        Parameters
        ----------
        n : int
            Number of points to draw.
        ax : Axes, optional
            Existing 3D axes to draw into. When omitted, a new figure with a
            3D subplot is created.
        arrows : bool
            When true, also draw one arrow from ``self.center`` along each of
            ``self.axis``, ``self.norm1`` and ``self.norm2``, each scaled by
            the matching entry of ``self.size``.

        Returns
        -------
        matplotlib.figure.Figure or None
            The newly created figure if ``ax`` was not given, otherwise None.
        """
        # Track whether we created the figure ourselves. ``ax`` is reassigned
        # below, so the former final ``if ax is None`` check never returned it.
        fig = None
        if ax is None:
            fig = plt.figure()
            ax = fig.add_subplot(111, projection="3d")

        points = [self.get_point() for _ in range(n)]
        ax.scatter([p[0] for p in points],
                   [p[1] for p in points],
                   [p[2] for p in points])
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.set_zlabel("z")

        if arrows:
            for axis, sz in zip((self.axis, self.norm1, self.norm2), self.size):
                ax.add_artist(Arrow3D(
                    self.center,
                    axis * sz + self.center,
                    mutation_scale=20,
                    lw=1,
                    arrowstyle="-|>",
                    color="k",
                ))

        return fig
Exemplo n.º 10
0
def plot_at_coordinates(images: np.ndarray,
                        coordinates: np.ndarray,
                        zoom: float,
                        ax: plt.Axes = None,
                        cmap: str = 'coolwarm'):
    """Place each image of ``images`` on ``ax``, centred at its coordinate.

    Parameters
    ----------
    images : np.ndarray
        Stack of images with ``images.ndim == 3``; ``images[i]`` is one
        image.
    coordinates : np.ndarray
        Array of shape ``(len(images), 2)``; row ``i`` is the data
        coordinate at which ``images[i]`` is centred.
    zoom : float
        Zoom factor forwarded to ``OffsetImage``.
    ax : plt.Axes, optional
        Target axes; defaults to the current axes (``plt.gca()``).
    cmap : str, optional
        Colormap used to render the images. Defaults to ``'coolwarm'``,
        the previously hard-coded value.

    Returns
    -------
    list
        The ``AnnotationBbox`` artists, as returned by ``ax.add_artist``.
    """
    assert images.ndim == 3 and coordinates.ndim == 2, (
        "expected a 3-D image stack and a 2-D coordinate array")
    assert images.shape[0] == coordinates.shape[0] and coordinates.shape[1] == 2, (
        "expected exactly one (x, y) coordinate per image")
    if ax is None:
        ax = plt.gca()

    artists = []
    for (x, y), to_plot in zip(coordinates, images):
        im = OffsetImage(to_plot, zoom=zoom, cmap=cmap)
        box = AnnotationBbox(im, (x, y), frameon=False)
        artists.append(ax.add_artist(box))
    return artists
Exemplo n.º 11
0
    def __setup_subplot__(self, rhythm_loop: RhythmLoop, axes: plt.Axes, **kw):
        """Draw the box-notation grid plus a syncopation panel next to it.

        The parent ``__setup_subplot__`` draws the box notation first. This
        method then adds a container placed ``self.spacing`` beyond the
        box-notation container along y. The container has one row per
        distinct natural duration of the time signature (in ``self.unit``),
        with each row labelled by its note value.

        Every syncopation from ``self.__get_sync_extractor__().process(...)``
        is drawn as a horizontal line with circle markers, in the row
        matching its duration. The colour is
        ``self.cyclic_syncopations_color`` when the syncopation wraps past the
        end of the loop, and ``self.syncopations_color`` otherwise.

        Finally the axes limits are re-centred vertically around both panels.
        The y axis is set in reversed order via ``set_ylim(*reversed(...))``.

        Returns the parent's setup result unchanged.

        NOTE(review): the recomputed ``viewport`` is applied to the axes
        limits but not written back into the returned dict. Confirm callers
        do not rely on ``result['viewport']`` matching the axes.
        """
        box_notation_setup_res = super().__setup_subplot__(
            rhythm_loop, axes, **kw)
        box_notation_viewport = box_notation_setup_res['viewport']
        box_notation_container = box_notation_setup_res['container']
        box_notation_grid = box_notation_setup_res['grid_box']
        box_notation_textbox = box_notation_setup_res['text_box']

        # compute the polyphonic syncopation vector
        poly_sync_extractor = self.__get_sync_extractor__()
        poly_sync_vector = poly_sync_extractor.process(rhythm_loop)

        # retrieve the levels on which syncopations could occur
        time_sig = rhythm_loop.get_time_signature()
        natural_duration_map = time_sig.get_natural_duration_map(self.unit)
        possible_sync_durations = sorted(set(natural_duration_map))
        sync_level_units = tuple(
            Unit.get(Fraction(d,
                              self.unit.get_note_value().denominator))
            for d in possible_sync_durations)  # type: tp.Tuple[Unit]
        n_sync_levels = len(possible_sync_durations)

        # main syncopations container (one row per sync level)
        sync_container = Rectangle2D(x=box_notation_container.x,
                                     y=box_notation_container.y_bounds[1] +
                                     self.spacing,
                                     width=box_notation_container.width,
                                     height=n_sync_levels)

        # syncopations grid box (columns aligned with the box notation grid)
        sync_grid_box = Rectangle2D(x=box_notation_grid.x,
                                    y=sync_container.y,
                                    width=box_notation_grid.width,
                                    height=box_notation_grid.height)

        # draw main syncopations container
        axes.add_artist(
            plt.Rectangle(sync_container.position,
                          sync_container.width,
                          sync_container.height,
                          fill=False,
                          edgecolor=self.line_color,
                          linewidth=self.line_width))

        # add horizontal lines (separating the sync level rows)
        for sync_level in range(1, n_sync_levels):
            line_y = sync_container.y + sync_level
            axes.add_artist(
                plt.Line2D(sync_container.x_bounds, [line_y, line_y],
                           color=self.line_color,
                           linewidth=self.line_width))

        # draw vertical header line (constant x) separating the label column
        # from the syncopation grid
        axes.add_artist(
            plt.Line2D([sync_grid_box.x, sync_grid_box.x],
                       sync_container.y_bounds,
                       color=self.line_color,
                       linewidth=self.line_width))

        # draw vertical grid lines (transparent)
        for step in range(1, sync_grid_box.width):
            line_x = sync_grid_box.x + step
            axes.add_artist(
                plt.Line2D([line_x, line_x],
                           sync_container.y_bounds,
                           color=to_rgba(self.line_color, 0.1),
                           linewidth=self.line_width))

        # labels are horizontally centred in the box notation's text column
        sync_level_label_x = box_notation_textbox.x + (
            box_notation_textbox.width / 2)

        # draw sync level labels
        for ix, sync_unit in enumerate(sync_level_units):
            sync_unit_label = str(sync_unit.get_note_value())
            axes.text(sync_level_label_x,
                      sync_container.y + ix + 0.5,
                      sync_unit_label,
                      verticalalignment="center",
                      horizontalalignment="center")

        n_steps = rhythm_loop.get_duration(self.unit, ceil=True)

        # draw the syncopations
        # assumes each entry is a 3-sequence whose last two items are the
        # from/to step positions (TODO confirm against the extractor)
        for syncopation in poly_sync_vector:
            pos_from, pos_to = syncopation[1:]
            cyclic = False

            # a syncopation ending before it starts wraps around the loop end
            while pos_to < pos_from:
                cyclic = True
                pos_to += n_steps

            sync_duration = pos_to - pos_from
            # raises ValueError if the duration is not one of the sync levels
            sync_level_ix = possible_sync_durations.index(sync_duration)
            line_x_bounds = sync_grid_box.x + pos_from + 0.5, sync_grid_box.x + pos_to + 0.5
            line_y = sync_grid_box.y + sync_level_ix + 0.5
            sync_line_color = self.cyclic_syncopations_color if cyclic else self.syncopations_color

            axes.add_artist(
                plt.Line2D(line_x_bounds, [line_y, line_y],
                           color=sync_line_color,
                           linewidth=self.line_width,
                           marker="o"))

        # area occupied by both the box notation container and the syncopations container
        combined_area = Rectangle2D(
            x=box_notation_container.x,
            y=box_notation_container.y,
            width=abs(box_notation_container.x_bounds[0] -
                      sync_container.x_bounds[1]),
            height=abs(box_notation_container.y_bounds[0] -
                       sync_container.y_bounds[1]))

        # re-adjust viewport (center the whole thing vertically)
        # NOTE(review): the vertical padding is derived from the viewport
        # *width*; presumably this keeps the viewport square. Confirm.
        pad_y = (box_notation_viewport.width - combined_area.height) / 2
        y_bounds = (combined_area.y_bounds[0] - pad_y,
                    combined_area.y_bounds[1] + pad_y)
        viewport = Rectangle2D(x=box_notation_viewport.x,
                               y=y_bounds[0],
                               width=box_notation_viewport.width,
                               height=abs(y_bounds[0] - y_bounds[1]))

        axes.set_xlim(viewport.x_bounds)
        # y limits passed in reversed order, i.e. the y axis is inverted
        axes.set_ylim(*reversed(viewport.y_bounds))

        return box_notation_setup_res
Exemplo n.º 12
0
def plot_box_notation_grid(axes: plt.Axes,
                           rhythm: RhythmLoop,
                           unit: UnitType,
                           position: Point2D = Point2D(0, 0),
                           textbox_width: int = 2,
                           line_width: int = 1,
                           line_color: str = "black",
                           beat_colors=(to_rgba("black",
                                                0.21), to_rgba("black",
                                                               0.09))):
    """Draw the empty box-notation grid for ``rhythm`` on ``axes``.

    Layout, in data units starting at ``position``:
    - A text column of width ``textbox_width`` holding abbreviated track
      names.
    - A grid with one column per step of ``rhythm`` (in ``unit``, rounded
      up) and one row per track.

    Beats are shaded by alternating between the two ``beat_colors`` (no
    shading when ``beat_colors`` is empty). The x and y axes are hidden.

    Returns
    -------
    dict
        ``'grid_box'``, ``'text_box'`` and ``'container'`` (``Rectangle2D``
        regions), plus ``'track_y_data'`` mapping each track name to its row
        index within the grid (0 = first track).
    """
    # hide axis
    for axis in [axes.xaxis, axes.yaxis]:
        axis.set_ticklabels([])
        axis.set_visible(False)

    common_height = rhythm.get_track_count()

    text_box = Rectangle2D(width=textbox_width,
                           height=common_height,
                           x=position.x,
                           y=position.y)

    grid_box = Rectangle2D(width=rhythm.get_duration(unit, ceil=True),
                           height=common_height,
                           x=text_box.x + text_box.width,
                           y=text_box.y)

    container = Rectangle2D(width=(text_box.width + grid_box.width),
                            height=grid_box.height,
                            x=text_box.x,
                            y=text_box.y)

    # draw beat rectangles, alternating the two shades within each measure
    if beat_colors:
        timesig = rhythm.get_time_signature()
        beat_unit = timesig.get_beat_unit()
        n_steps_per_beat = beat_unit.convert(1, unit, quantize=True)
        n_beats = rhythm.get_duration(beat_unit, ceil=True)

        for beat in range(n_beats):
            beat_in_measure = beat % timesig.numerator
            step = beat * n_steps_per_beat
            axes.add_artist(
                plt.Rectangle([grid_box.x + step, grid_box.y],
                              width=n_steps_per_beat,
                              height=grid_box.height,
                              facecolor=beat_colors[beat_in_measure % 2],
                              fill=True))

    # draw main box
    axes.add_artist(
        plt.Rectangle(container.position,
                      container.width,
                      container.height,
                      fill=False,
                      edgecolor=line_color,
                      linewidth=line_width))

    # add horizontal lines separating the track rows
    for track_ix in range(1, grid_box.height):
        # offset by the grid's y origin (previously, and wrongly, by position.x)
        line_y = grid_box.y + track_ix
        axes.add_artist(
            plt.Line2D(container.x_bounds, [line_y, line_y],
                       color=line_color,
                       linewidth=line_width))

    # add vertical lines (one per step, the first one closing the text column)
    for step in range(grid_box.width):
        step_x = grid_box.x + step
        axes.add_artist(
            plt.Line2D([step_x, step_x], grid_box.y_bounds,
                       color=line_color,
                       linewidth=line_width,
                       solid_capstyle="butt"))

    track_names = tuple(t.name for t in rhythm.get_track_iterator())
    track_ids = generate_abbreviations(track_names, max_abbreviation_len=3)
    track_y_positions = dict()
    text_x = text_box.x + (text_box.width / 2)

    # draw track name ids, vertically centred in their rows
    for track_ix, (track_name, track_id) in enumerate(zip(track_names, track_ids)):
        axes.text(text_x,
                  text_box.y + track_ix + 0.5,
                  track_id,
                  verticalalignment="center",
                  horizontalalignment="center")

        track_y_positions[track_name] = track_ix

    return {
        'grid_box': grid_box,
        'text_box': text_box,
        'container': container,
        'track_y_data': track_y_positions
    }
Exemplo n.º 13
0
def decorate_axes(
    ax: plt.Axes, image_attributes: ImageAttributes, galaxy_attributes: GalaxyAttributes
):
    """
    Decorate ``ax`` with the overlays enabled by ``image_attributes``.

    Each ``decorate_*`` flag adds one overlay:
    - ``decorate_radius``: a dashed circle at the galaxy radius, with the
      radius name just above it.
    - ``decorate_position``: the redshift and the centre position in Mpc at
      the lower right.
    - ``decorate_scalebar``: the radius value just below the circle.
    - ``decorate_image_type``: the particle type / visualisation title at
      the upper left, and the halo id and projection at the lower left.
    - ``decorate_masses``: the halo and stellar masses (solar masses) at the
      upper right.

    All text and the circle use ``image_attributes.text_color``.
    """

    def label(x, y, text, **options):
        # every text decoration is drawn in the configured text colour
        ax.text(x, y, text, color=image_attributes.text_color, **options)

    if image_attributes.decorate_radius:
        # dashed outline at the galaxy radius, name written just above it
        outline = Circle(
            galaxy_attributes.center[:2],
            radius=galaxy_attributes.radius,
            color=image_attributes.text_color,
            linestyle="dashed",
            fill=False,
        )
        label(
            galaxy_attributes.center[0],
            galaxy_attributes.center[1] + 1.025 * galaxy_attributes.radius,
            galaxy_attributes.radius_name,
            ha="center",
            va="bottom",
        )
        ax.add_artist(outline)

    if image_attributes.decorate_position:
        # centre position (in Mpc) and redshift, bottom-right in axes coords
        x, y, z = galaxy_attributes.center.to("Mpc").value
        position = f"[{x:.3g}, {y:.3g}, {z:.3g}] Mpc"
        redshift = f"$z={galaxy_attributes.redshift:3.3f}$"
        label(
            0.975,
            0.025,
            f"{redshift}\n{position}",
            ha="right",
            va="bottom",
            multialignment="right",
            transform=ax.transAxes,
        )

    if image_attributes.decorate_scalebar:
        # NOTE(review): only the radius value is written; no bar is drawn.
        label(
            galaxy_attributes.center[0],
            galaxy_attributes.center[1] - 1.025 * galaxy_attributes.radius,
            f"{latex_float(galaxy_attributes.radius)}",
            ha="center",
            va="top",
        )

    ptype_title = image_attributes.particle_type.replace("_", " ").title()
    visualise_title = image_attributes.visualise.replace("_", " ").title()

    if image_attributes.decorate_image_type:
        label(
            0.025,
            0.975,
            f"{ptype_title} {visualise_title}",
            ha="left",
            va="top",
            transform=ax.transAxes,
        )
        label(
            0.025,
            0.025,
            (
                f"Halo {galaxy_attributes.unique_id}\n"
                f"{image_attributes.projection.replace('on', ' on').title()}"
            ),
            ha="left",
            va="bottom",
            transform=ax.transAxes,
        )

    if image_attributes.decorate_masses:
        label(
            0.975,
            0.975,
            (
                f"$M_H$={latex_float(galaxy_attributes.halo_mass.to('Solar_Mass'))}\n"
                f"$M_*$={latex_float(galaxy_attributes.stellar_mass.to('Solar_Mass'))}"
            ),
            ha="right",
            va="top",
            transform=ax.transAxes,
        )
Exemplo n.º 14
0
def set_axes_properties(ax: plt.Axes, **kwargs) -> None:
    """
    Ease the configuration of a :class:`matplotlib.axes.Axes` object.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes.

    Keyword arguments
    -----------------
    fontsize : int
        Font size to use for the plot titles, and axes ticks and labels.
        Defaults to 12.
    title_center : str
        The center title. Defaults to an empty string.
    title_left : str
        The left title. Defaults to an empty string.
    title_right : str
        The right title. Defaults to an empty string.
    x_label : str
        The x-axis label. Defaults to an empty string.
    x_labelcolor : str
        Color of the x-axis label. Defaults to 'black'.
    x_lim : Sequence[int]
        Data limits for the x-axis. Defaults to `None`, i.e., the data limits
        will be left unchanged.
    invert_xaxis : bool
        `True` to make to invert the x-axis, `False` otherwise.
        Defaults to `False`.
    x_scale : str
        The x-axis scale. Defaults to 'linear'.
    x_ticks : sequence[float]
        Sequence of x-axis ticks location. Defaults to `None`.
    x_ticklabels : sequence[str]
        Sequence of x-axis ticks labels. Defaults to `None`.
    x_ticklabels_color : str
        Color for the x-axis ticks labels. Defaults to 'black'.
    x_ticklabels_rotation : float
        Rotation angle of the x-axis ticks labels. Defaults to 0.
    xaxis_minor_ticks_visible : bool
        `True` to show all ticks, either labelled or unlabelled,
        `False` to show only the labelled ticks. Defaults to `False`.
    xaxis_visible : bool
        `False` to make the x-axis invisible. Defaults to `True`.
    y_label : str
        The y-axis label. Defaults to an empty string.
    y_labelcolor : str
        Color of the y-axis label. Defaults to 'black'.
    y_lim : Sequence[int]
        Data limits for the y-axis. Defaults to `None`, i.e., the data limits
        will be left unchanged.
    invert_yaxis : bool
        `True` to make to invert the y-axis, `False` otherwise.
        Defaults to `False`.
    y_scale : str
        The y-axis scale. Defaults to 'linear'.
    y_ticks : sequence[float]
        Sequence of y-axis ticks location. Defaults to `None`.
    y_ticklabels : sequence[str]
        Sequence of y-axis ticks labels. Defaults to `None`.
    y_ticklabels_color : str
        Color for the y-axis ticks labels. Defaults to 'black'.
    y_ticklabels_rotation : float
        Rotation angle of the y-axis ticks labels. Defaults to 0.
    yaxis_minor_ticks_visible : bool
        `True` to show all ticks, either labelled or unlabelled,
        `False` to show only the labelled ticks. Defaults to :obj:`False`.
    yaxis_visible : bool
        :obj:`False` to make the y-axis invisible. Defaults to :obj:`True`.
    z_label : str
        The z-axis label. Defaults to an empty string.
    z_labelcolor : str
        Color of the z-axis label. Defaults to 'black'.
    z_lim : Sequence[int]
        Data limits for the z-axis. Defaults to :obj:`None`, i.e., the data limits
        will be left unchanged.
    invert_zaxis : bool
        :obj:`True` to make to invert the z-axis, :obj:`False` otherwise.
        Defaults to :obj:`False`.
    z_scale : str
        The z-axis scale. Defaults to 'linear'.
    z_ticks : sequence[float]
        Sequence of z-axis ticks location. Defaults to :obj:`None`.
    z_ticklabels : sequence[str]
        Sequence of z-axis ticks labels. Defaults to :obj:`None`.
    z_ticklabels_color : str
        Rotation angle of the z-axis ticks labels. Defaults to 0.
    z_ticklabels_rotation : float
        Color for the z-axis ticks labels. Defaults to 'black'.
    zaxis_minor_ticks_visible : bool
        :obj:`True` to show all ticks, either labelled or unlabelled,
        :obj:`False` to show only the labelled ticks. Defaults to :obj:`False`.
    zaxis_visible : bool
        :obj:`False` to make the z-axis invisible. Defaults to :obj:`True`.
    legend_on : bool
        :obj:`True` to show the legend, :obj:`False` otherwise. Defaults to :obj:`False`.
    legend_loc : str
        String specifying the location where the legend should be placed.
        Defaults to 'best'; please see :func:`matplotlib.pyplot.legend` for all
        the available options.
    legend_bbox_to_anchor : Sequence[float]
        4-items tuple defining the box used to place the legend. This is used in
        conjuction with `legend_loc` to allow arbitrary placement of the legend.
    legend_framealpha : float
        Legend transparency. It should be between 0 and 1; defaults to 0.5.
    legend_ncol : int
        Number of columns into which the legend labels should be arranged.
        Defaults to 1.
    text : str
        Text to be added to the figure as anchored text. Defaults to :obj:`None`,
        meaning that no text box is shown.
    text_loc : str
        String specifying the location where the text box should be placed.
        Defaults to 'upper right'; please see :class:`matplotlib.offsetbox.AnchoredText`
        for all the available options.
    grid_on : bool
        :obj:`True` to show the plot grid, :obj:`False` otherwise.
        Defaults to :obj:`False`.
    grid_properties : dict
        Keyword arguments specifying various settings of the plot grid.
    """
    fontsize = kwargs.get("fontsize", 12)
    # title
    title_center = kwargs.get("title_center", "")
    title_left = kwargs.get("title_left", "")
    title_right = kwargs.get("title_right", "")
    # x-axis
    x_label = kwargs.get("x_label", "")
    x_labelcolor = kwargs.get("x_labelcolor", "black")
    x_lim = kwargs.get("x_lim", None)
    invert_xaxis = kwargs.get("invert_xaxis", False)
    x_scale = kwargs.get("x_scale", "linear")
    x_ticks = kwargs.get("x_ticks", None)
    x_ticklabels = kwargs.get("x_ticklabels", None)
    x_ticklabels_color = kwargs.get("x_ticklabels_color", "black")
    x_ticklabels_rotation = kwargs.get("x_ticklabels_rotation", 0)
    x_tickformat = kwargs.get("x_tickformat", None)
    xaxis_minor_ticks_visible = kwargs.get("xaxis_minor_ticks_visible", False)
    xaxis_visible = kwargs.get("xaxis_visible", True)
    # y-axis
    y_label = kwargs.get("y_label", "")
    y_labelcolor = kwargs.get("y_labelcolor", "black")
    y_lim = kwargs.get("y_lim", None)
    invert_yaxis = kwargs.get("invert_yaxis", False)
    y_scale = kwargs.get("y_scale", "linear")
    y_ticks = kwargs.get("y_ticks", None)
    y_ticklabels = kwargs.get("y_ticklabels", None)
    y_ticklabels_color = kwargs.get("y_ticklabels_color", "black")
    y_ticklabels_rotation = kwargs.get("y_ticklabels_rotation", 0)
    y_tickformat = kwargs.get("y_tickformat", None)
    yaxis_minor_ticks_visible = kwargs.get("yaxis_minor_ticks_visible", False)
    yaxis_visible = kwargs.get("yaxis_visible", True)
    # legend
    legend_on = kwargs.get("legend_on", False)
    legend_loc = kwargs.get("legend_loc", "best")
    legend_bbox_to_anchor = kwargs.get("legend_bbox_to_anchor", None)
    legend_framealpha = kwargs.get("legend_framealpha", 0.5)
    legend_ncol = kwargs.get("legend_ncol", 1)
    legend_fontsize = kwargs.get("legend_fontsize", fontsize)
    # textbox
    text = kwargs.get("text", None)
    text_loc = kwargs.get("text_loc", "")
    # grid
    grid_on = kwargs.get("grid_on", False)
    grid_properties = kwargs.get("grid_properties", None)

    rcParams["font.size"] = fontsize
    # rcParams['text.usetex'] = True

    # plot titles
    if ax.get_title(loc="center") == "":
        ax.set_title(title_center, loc="center", fontsize=rcParams["font.size"] - 1)
    if ax.get_title(loc="left") == "":
        ax.set_title(title_left, loc="left", fontsize=rcParams["font.size"] - 1)
    if ax.get_title(loc="right") == "":
        ax.set_title(title_right, loc="right", fontsize=rcParams["font.size"] - 1)

    # axes labels
    if ax.get_xlabel() == "":
        ax.set(xlabel=x_label)
    if ax.get_ylabel() == "":
        ax.set(ylabel=y_label)

    # axes labelcolors
    if ax.get_xlabel() != "" and x_labelcolor != "":
        ax.xaxis.label.set_color(x_labelcolor)
    if ax.get_ylabel() != "" and y_labelcolor != "":
        ax.yaxis.label.set_color(y_labelcolor)

    # axes limits
    if x_lim is not None:
        ax.set_xlim(x_lim)
    if y_lim is not None:
        ax.set_ylim(y_lim)

    # invert the axes
    if invert_xaxis:
        ax.invert_xaxis()
    if invert_yaxis:
        ax.invert_yaxis()

    # axes scale
    if x_scale is not None:
        ax.set_xscale(x_scale)
    if y_scale is not None:
        ax.set_yscale(y_scale)

    # axes ticks
    if x_ticks is not None:
        ax.get_xaxis().set_ticks(x_ticks)
    if y_ticks is not None:
        ax.get_yaxis().set_ticks(y_ticks)

    # axes tick labels
    if x_ticklabels is not None:
        ax.get_xaxis().set_ticklabels(x_ticklabels)
    if y_ticklabels is not None:
        ax.get_yaxis().set_ticklabels(y_ticklabels)

    # axes tick labels color
    if x_ticklabels_color != "":
        ax.tick_params(axis="x", colors=x_ticklabels_color)
    if y_ticklabels_color != "":
        ax.tick_params(axis="y", colors=y_ticklabels_color)

    # axes tick format
    if x_tickformat is not None:
        ax.xaxis.set_major_formatter(FormatStrFormatter(x_tickformat))
    if y_tickformat is not None:
        ax.yaxis.set_major_formatter(FormatStrFormatter(y_tickformat))

    # axes tick labels rotation
    plt.xticks(rotation=x_ticklabels_rotation)
    plt.yticks(rotation=y_ticklabels_rotation)

    # unlabelled axes ticks
    if not xaxis_minor_ticks_visible:
        ax.get_xaxis().set_tick_params(which="minor", size=0)
        ax.get_xaxis().set_tick_params(which="minor", width=0)
    if not yaxis_minor_ticks_visible:
        ax.get_yaxis().set_tick_params(which="minor", size=0)
        ax.get_yaxis().set_tick_params(which="minor", width=0)

    # axes visibility
    if not xaxis_visible:
        ax.get_xaxis().set_visible(False)
    if not yaxis_visible:
        ax.get_yaxis().set_visible(False)

    # legend
    if legend_on:
        if legend_bbox_to_anchor is None:
            ax.legend(
                loc=legend_loc,
                framealpha=legend_framealpha,
                ncol=legend_ncol,
                fontsize=legend_fontsize,
            )
        else:
            ax.legend(
                loc=legend_loc,
                framealpha=legend_framealpha,
                ncol=legend_ncol,
                fontsize=legend_fontsize,
                bbox_to_anchor=legend_bbox_to_anchor,
            )

    # text box
    if text is not None:
        ax.add_artist(AnchoredText(text, loc=text_locations[text_loc]))

    # plot grid
    if grid_on:
        gps = grid_properties if grid_properties is not None else {}
        ax.grid(True, **gps)
    def plot_stacked_bar(self,
                         ax: plt.Axes,
                         width: float = 0.8,
                         plot_legend: bool = True):
        """Draw stacked choice/confidence bars with a node glyph under each bar.

        One bar is drawn per level in ``ExpConfig.glo_exp2``, at the matching
        position in ``self.x``. Each bar stacks four segments, bottom to top:
        (C, high), (C, low), (H, low), (H, high). A segment's height is the
        number of rows in ``self.df`` with that choice/confidence at that
        ``ground_truth`` level (matched as the string ``f'{level:.2f}'``),
        divided by ``len(self.df) / len(ExpConfig.glo_exp2)``. That is a
        within-level proportion only if trials are balanced across levels.
        TODO confirm.

        Below the axes (``clip_on=False``) a rounded box with black nodes and
        edges is drawn for every bar. An extra apex node is added whose height
        depends on the level; it appears only for positions ``x > 0``.

        Parameters
        ----------
        ax : matplotlib.pyplot.Axes
            Axes to draw into.
        width : float, optional
            Bar width in data units. It also determines the x-limits.
        plot_legend : bool, optional
            If True, attach a legend to ``ax``. If False, return the legend
            handles and labels instead so the caller can place them.

        Returns
        -------
        None or tuple of (list, list)
            ``None`` when ``plot_legend`` is True. Otherwise the bar handles and
            labels, reversed so the top segment comes first.
        """
        levels = ExpConfig.glo_exp2
        # Normaliser shared by every segment. It is loop-invariant, so it is
        # computed once.
        trials_per_level = len(self.df) / len(levels)
        p, labels = [], []
        bottom = np.zeros(len(levels))

        def whiten_color(color):
            # Move each RGB channel halfway towards 0.8 (a paler shade).
            return [c + (.8 - c) / 2 for c in to_rgb(color)]

        def segment(a, b):
            # Black line from point a to point b that may extend outside the axes.
            ax.plot([a[0], b[0]], [a[1], b[1]], color='k', clip_on=False)

        colormap = [('C', 'high', whiten_color('darkgoldenrod')),
                    ('C', 'low', whiten_color('goldenrod')),
                    ('H', 'low', whiten_color('green')),
                    ('H', 'high', whiten_color('darkgreen'))]
        for choice, confidence, color in colormap:
            # The choice/confidence mask does not depend on the level, so it is
            # built once per segment rather than once per level.
            mask = ((self.df['choice'] == choice)
                    & (self.df['confidence'] == confidence))
            y = [
                len(self.df[mask & (self.df['ground_truth'] == f'{s:.2f}')]) /
                trials_per_level for s in levels
            ]
            p.append(
                ax.bar(self.x,
                       y,
                       width=width,
                       bottom=bottom,
                       color=color,
                       alpha=0.85)[0])
            labels.append(f'{choice} {confidence}')
            bottom += y
        ax.set_xlim(-width / 2, self.x[-1] + width / 2)
        ax.set_xticks([])
        ax.set_ylim(0, 1)
        ax.set_ylabel(r'$P$(choice=$C\,|\,\bf{X}$)')

        # Axes size in display units. It is used to derive a y-scale factor
        # for the glyphs drawn below the axes.
        ax_w_px, ax_h_px = (ax.transAxes.transform((1, 1)) -
                            ax.transAxes.transform((0, 0)))
        w, r = .9, .7 / 12  # glyph box width and node radius (data-x units)
        # NOTE(review): uses the glyph width ``w`` rather than ``width``, even
        # though the x-limits come from ``width``. Confirm this is intended.
        aspect = ax_w_px / ((self.x[-1] + w) * ax_h_px - ax_w_px)
        for x in self.x:
            ax.add_patch(
                FancyBboxPatch((x - w / 2, -aspect),
                               w,
                               w * aspect,
                               boxstyle='Round,pad=0,rounding_size=0.05',
                               fc='#E6E6E6',
                               ec='#B3B3B3',
                               lw=1,
                               clip_on=False,
                               mutation_aspect=aspect))
            # Bottom row of three nodes, then one node in the middle row.
            nodes = [(x + offset, (-.9 + r) * aspect)
                     for offset in (-5 * r, 0, 5 * r)]
            nodes.append((x - 2.5 * r, -.72 * aspect))
            has_apex = x > 0
            if has_apex:
                # Apex height depends on this bar's level. The branch below
                # uses the same condition, so nodes[4] always exists there.
                nodes.append((x, (-.2 - .52 * levels[x]) * aspect))
            for node in nodes:
                ax.add_artist(
                    Ellipse(node, 2 * r, 2 * r * aspect, fc='k',
                            clip_on=False))
            # Duplicates the nodes[0]-nodes[3] segment drawn by segment() below.
            ax.add_patch(
                ConnectionPatch(nodes[0],
                                nodes[3],
                                'data',
                                'data',
                                clip_on=False))
            segment(nodes[0], nodes[3])
            segment(nodes[1], nodes[3])
            if has_apex:
                segment(nodes[2], nodes[4])
                segment(nodes[3], nodes[4])
                segment(nodes[4], (nodes[4][0], -.2 * aspect))
            else:
                segment(nodes[2], (nodes[2][0], -.2 * aspect))
                segment(nodes[3], (nodes[3][0], -.2 * aspect))
        ax.text(-width / 2, -0.15 * aspect, '$C$', ha='left', va='top')
        ax.text(self.x[-1] - width / 2,
                -0.15 * aspect,
                '$H$',
                ha='left',
                va='top')
        if plot_legend:
            # Build the legend on ``ax`` itself. plt.legend targets the current
            # axes, which need not be ``ax``, and add_artist then rejects a
            # legend that already belongs to another Axes. add_artist keeps
            # this legend drawn even if ax.legend() is called again later.
            ax.add_artist(ax.legend(p[::-1], labels[::-1], loc='lower left'))
        else:
            return p[::-1], labels[::-1]