def draw_node(ax: pyplot.Axes, node_info: Dict[str, Any]) -> List[pyplot.Artist]:
    """Draw one node described by ``node_info`` on ``ax``.

    The node is rendered as up to three artists:

    * a point marker at ``node_info['xy']`` in ``node_info['color']``;
    * unless ``node_info['label']`` is ``'dead'``, a translucent circle of
      radius ``node_info['r']`` whose alpha is
      ``node_info['power'] / node_info['total_power'] * 0.1``;
    * if ``node_info['last_node']`` is truthy, a faint red segment from the
      node to that point.

    Args:
        ax: Axes to draw on.
        node_info: Mapping with keys ``'xy'``, ``'color'``, ``'label'``,
            ``'last_node'`` and, for nodes whose label is not ``'dead'``,
            ``'r'``, ``'power'`` and ``'total_power'``.

    Returns:
        Every artist that was created, in drawing order.
    """
    xy = node_info['xy']
    drawn = list(ax.plot(xy[0], xy[1], '.', color=node_info['color']))

    # Dead nodes get no coverage circle.
    if node_info['label'] not in ('dead', ):
        halo = pyplot.Circle(
            xy=xy,
            radius=node_info['r'],
            alpha=node_info['power'] / node_info['total_power'] * 0.1,
            color=node_info['color'],
        )
        ax.add_artist(halo)
        drawn.append(halo)

    previous = node_info['last_node']
    if previous:
        drawn += ax.plot(
            [xy[0], previous[0]],
            [xy[1], previous[1]],
            marker='_',
            linewidth=1,
            alpha=0.2,
            color='red',
        )
    return drawn
def add_legend(axes: plt.Axes, plots: Sequence[lines.Line2D], **kwargs: Any) -> None:
    """
    Add a legend for ``plots`` to the axes object.

    The legend labels are taken from each plot's ``get_label()``. The legend
    frame gets a black edge whose line width is ``default['borderwidth']``.
    The legend is also registered with ``axes.add_artist`` so that it stays
    drawn even if ``axes.legend`` is called again later (a new call replaces
    the axes' current legend).

    Parameters
    ----------
    axes : matplotlib.pyplot.Axes
        The axes object.
    plots : Sequence[matplotlib.lines.Line2D]
        The plotted data containing the legend labels.
    kwargs : Any
        Additional kwargs which will be passed to the ``axes.legend`` method.
        ``fontsize`` and ``title_fontsize`` default to ``default['fontsize']``.

    Returns
    -------
    None
    """
    labels = [plot.get_label() for plot in plots]
    kwargs.setdefault('fontsize', default['fontsize'])
    kwargs.setdefault('title_fontsize', default['fontsize'])

    legend = axes.legend(plots, labels, edgecolor='k', **kwargs)
    frame = legend.get_frame()
    frame.set_linewidth(default['borderwidth'])
    frame.set_edgecolor('k')
    axes.add_artist(legend)
def plot_multipolygon(geom: Union[Multipolygon, Iterable[Multipolygon]], *args,
                      ax: Axes = None, crs: CRS = None, **kwargs):
    """
    Plot a Multipolygon geometry, projected to the coordinate system `crs`.

    For each entry returned by ``geom.get_vertices(crs=crs)``, a closed path is
    built from that entry's first vertex array, keeping only its first two
    columns. All paths are gathered into a single
    ``matplotlib.collections.PathCollection`` that is added to ``ax``.

    Parameters
    ----------
    geom
        Geometry exposing ``get_vertices(crs=...)``.
        NOTE(review): the annotation also allows an iterable of
        Multipolygons, but only ``geom.get_vertices`` is called here;
        confirm that plain iterables are actually supported.
    *args, **kwargs
        Forwarded to ``PathCollection``. ``facecolors`` defaults to
        ``"none"`` and ``edgecolors`` to ``"black"``.
    ax
        Axes to draw on. Defaults to the current axes when omitted.
    crs
        Coordinate system passed to ``get_vertices``.

    Returns
    -------
    matplotlib.collections.PathCollection
        The collection that was added to ``ax``.
    """
    if ax is None:
        # `ax` is optional in the signature; fall back to the current axes
        # instead of failing on `None.add_artist`.
        import matplotlib.pyplot as plt
        ax = plt.gca()

    kwargs.setdefault("facecolors", "none")
    kwargs.setdefault("edgecolors", "black")

    # Only vertices[0] of each polygon is used. Presumably this is the
    # exterior ring, so any further rings (holes?) are not drawn; TODO confirm.
    paths = [
        matplotlib.path.Path(vertices[0][:, :2], closed=True, readonly=True)
        for vertices in geom.get_vertices(crs=crs)
    ]
    coll = matplotlib.collections.PathCollection(paths, *args, **kwargs)
    ax.add_artist(coll)
    return coll
def median_plot(axes: plt.Axes, x: np.ndarray, y: np.ndarray, **kwargs):
    """Plot binned medians and 16th/84th percentiles of ``y`` against ``x``.

    A legend with black proxy lines (^ = 84th percentile, o = median,
    v = 16th percentile) is placed in the lower right and registered with
    ``axes.add_artist``. The statistics come from
    ``utils.medians_2d(x, y, **kwargs)``. The median and both percentile
    curves are then drawn with ``errorbar`` using the same ``err_y`` values.

    NOTE(review): the returned mapping is assumed to provide ``median_x``,
    ``median_y``, ``percent16_y``, ``percent84_y`` and ``err_y``, as read
    below; confirm against ``utils.medians_2d``.
    """
    legend_entries = (
        ('^', r'$84^{th}$ percentile'),
        ('o', r'median'),
        ('v', r'$16^{th}$ percentile'),
    )
    proxies = [
        Line2D([], [], color='k', marker=symbol, linewidth=1, linestyle='-',
               markersize=3, label=text)
        for symbol, text in legend_entries
    ]
    axes.add_artist(
        axes.legend(handles=proxies, loc='lower right', handlelength=2))

    stats = utils.medians_2d(x, y, **kwargs)
    curves = (('median_y', 'o'), ('percent16_y', 'v'), ('percent84_y', '^'))
    for key, symbol in curves:
        axes.errorbar(stats['median_x'], stats[key], yerr=stats['err_y'],
                      marker=symbol, ms=2, alpha=1, linestyle='-', capsize=0,
                      linewidth=0.5)
def __setup_subplot__(self, rhythm_loop: RhythmLoop, axes: plt.Axes, **kw):
    """Set up a circular subplot for ``rhythm_loop`` and return its main circle.

    Draws, in this order:

    * light-gray wedges (radius 1.0) behind every other measure, starting
      with the first one;
    * a white main circle of radius 0.3 centred at (0.5, 0.5);
    * light-gray wedges on the main circle for every other pulse.

    Pulse 0 is at the top (90 degrees) and pulses advance clockwise.

    Keyword args:
        n_pulses: Number of pulses around the circle (required).

    Returns:
        The ``plt.Circle`` artist of the main circle.
    """
    # Keep the aspect ratio square so the circles are not stretched.
    axes.axis('equal')
    # noinspection PyTypeChecker
    axes.axis([0, 1, 0, 1])

    circle_center = 0.5, 0.5
    circle_radius = 0.3

    def angle_of(pulse):
        # 90 degrees is the top of the circle; increasing pulses go clockwise.
        return (90 - (pulse / n_pulses * 360)) % 360

    def draw_wedge(first_pulse, last_pulse, center=circle_center,
                   radius=circle_radius, **wedge_kwargs):
        # Wedge sweeps counter-clockwise from theta1 to theta2, hence the
        # end pulse gives theta1 and the start pulse gives theta2.
        axes.add_artist(Wedge(center, radius, angle_of(last_pulse),
                              angle_of(first_pulse), **wedge_kwargs))

    unit = self.get_unit()
    n_pulses = kw['n_pulses']
    pulses_per_measure = int(rhythm_loop.get_measure_duration(unit))

    try:
        n_measures = int(n_pulses / pulses_per_measure)
    except ZeroDivisionError:
        n_measures = 0

    # Background: shade every other measure.
    for measure in range(0, n_measures, 2):
        draw_wedge(measure * pulses_per_measure,
                   (measure + 1) * pulses_per_measure,
                   radius=1.0, fc=to_rgba("gray", 0.25))

    main_circle = plt.Circle(circle_center, circle_radius, fc="white")
    axes.add_artist(main_circle)

    # Foreground: shade every other pulse on the main circle.
    for pulse in range(0, n_pulses, 2):
        draw_wedge(pulse, pulse + 1, fc=to_rgba("gray", 0.25))

    return main_circle
def draw_agent(state: typing.Union[torch.Tensor, np.ndarray], color: typing.Union[np.ndarray, str],
               env_axes: typing.Tuple[typing.Tuple[float, float], typing.Tuple[float, float]],
               ax: plt.Axes, is_robot: bool = False, scale: float = 1.0):
    """Add an agent icon (robot or pedestrian) at the agent's position.

    If the position lies outside the scene, nothing is drawn and ``None`` is
    returned.

    Args:
        state: Agent state, either ``[x, y]`` or ``[x, y, vx, vy]``. A torch
            tensor is detached, moved to CPU and converted to numpy.
        color: Colour applied to the icon's near-black pixels. Either an RGB
            array or any matplotlib colour string.
        env_axes: ``((x_min, x_max), (y_min, y_max))``. Positions on a bound
            count as outside.
        ax: Axes to draw on.
        is_robot: Use the robot icon instead of the pedestrian icon.
        scale: Zoom factor of the icon.

    Returns:
        ``ax`` when the agent was drawn, otherwise ``None``.
    """
    x_range, y_range = env_axes[0], env_axes[1]
    if not (x_range[0] < state[0] < x_range[1]) or not (y_range[0] < state[1] < y_range[1]):
        return
    if isinstance(state, torch.Tensor):
        # .cpu() is a no-op for CPU tensors and makes CUDA tensors work too.
        state = state.detach().cpu().numpy()

    # Read the icon (different images for robot and pedestrian).
    icon_name = "robot.png" if is_robot else "walking.png"
    image_path = mantrap.utility.io.build_os_path(
        os.path.join("third_party", "visualization", icon_name))
    image = matplotlib.image.imread(image_path)
    if image.shape[-1] == 3:
        # Ensure a dedicated alpha channel exists, so transparency is not
        # written into the blue channel of an RGB-only image.
        image = np.dstack([image, np.ones(image.shape[:2], dtype=image.dtype)])

    # Opaque where the original icon is dark (RGB sum < 2.0), transparent on
    # the white background. This mask is computed BEFORE recolouring so that a
    # bright `color` cannot make the agent itself transparent.
    opaque = np.sum(image[:, :, :3], axis=2) < 2.0

    # Recolour near-black pixels, i.e. pixels whose RGB channels are all within 0.1 of 0.
    if isinstance(color, str):
        from matplotlib.colors import to_rgb
        color = to_rgb(color)
    black = np.all(np.isclose(image[:, :, :3], 0.0, atol=0.1), axis=2)
    image[black, :3] = color
    image[:, :, -1] = opaque.astype(image.dtype)

    # Mirror the icon based on the velocity direction, if a velocity is given.
    if state.size >= 4:
        orientation = np.arctan2(state[3], state[2])  # theta = arctan2(vy, vx)
        if -np.pi / 2 < orientation < np.pi / 2:
            image = np.fliplr(image)

    # Draw the image in the axes using `OffsetImage`.
    image_box = matplotlib.offsetbox.OffsetImage(image, zoom=scale)
    ab = matplotlib.offsetbox.AnnotationBbox(image_box, (state[0], state[1]), frameon=False)
    ax.add_artist(ab)
    return ax
def __init__(self, x, y, ax: plt.Axes, *, radius=0.05, index=None):
    """Create a node drawn as a circle with a text label at ``(x, y)`` on ``ax``.

    Args:
        x, y: Centre of the circle and anchor of the label (data coordinates).
        ax: Axes the circle and label are added to.
        radius: Circle radius.
        index: Label text. Defaults to ``str(id(self) % 100)``.
    """
    self.circle = plt.Circle((x, y), radius, lw=2,
                             edgecolor="#84affa", facecolor="#c6e5ff")
    self.ax = ax
    ax.add_patch(self.circle)

    label = str(id(self) % 100) if index is None else index
    self.index = plt.Text(x, y, label, size="small")
    ax.add_artist(self.index)

    # Presumably outgoing ("to") and incoming ("from") connections; TODO confirm.
    self.tos, self.frs = {}, {}
def __draw_track__(self, rhythm_track: Track, axes: plt.Axes, **kw):
    """Fill one grid cell per onset of ``rhythm_track`` in its grid row.

    The grid box is read from ``kw['setup_ret']['grid_box']``, the row index
    from ``kw['track_ix']`` and the fill colour from ``kw['color']``. Onsets
    at or beyond the track duration are skipped. Such onsets can occur when
    the plotter unit is coarse and a position is pushed out of the grid.

    Returns:
        A unit ``plt.Rectangle`` that is not added to the axes (presumably a
        legend proxy; TODO confirm).
    """
    grid_box = kw['setup_ret']['grid_box']  # type: Rectangle2D
    extractor = self.get_feature_extractor("onset_positions")
    positions = extractor.process(rhythm_track)
    n_steps = rhythm_track.get_duration(self.unit)
    row = kw['track_ix']

    for pos in positions:
        if not pos < n_steps:
            continue
        cell = plt.Rectangle([grid_box.x + pos, grid_box.y + row],
                             width=1, height=1, facecolor=kw['color'],
                             edgecolor="black", linewidth=0.75,
                             joinstyle="miter")
        axes.add_artist(cell)

    return plt.Rectangle([0, 0], 1, 1)
def to_plot(self, n: int = 200, ax: plt.Axes = None, arrows: bool = False):
    """
    Prism.to_plot(n): scatter-plot ``n`` points from ``self.get_point()`` in 3D.

    Parameters
    ----------
    n : int
        Number of points to sample and plot.
    ax : matplotlib axes, optional
        A 3D axes to draw on. When omitted, a new figure with a 3D subplot is
        created.
    arrows : bool
        If true, also draw an arrow from ``self.center`` along each of
        ``self.axis``, ``self.norm1`` and ``self.norm2``, scaled by the
        corresponding entry of ``self.size``.

    Returns
    -------
    matplotlib.figure.Figure or None
        The newly created figure when ``ax`` was not given, otherwise ``None``.
    """
    # Remember whether we created the figure. `ax` is reassigned below, so
    # testing `ax is None` at the end would never return the figure.
    fig = None
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111, projection="3d")

    points = [self.get_point() for _ in range(n)]
    ax.scatter([p[0] for p in points],
               [p[1] for p in points],
               [p[2] for p in points])
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")

    if arrows:
        for direction, length in zip((self.axis, self.norm1, self.norm2), self.size):
            ax.add_artist(Arrow3D(
                self.center,
                direction * length + self.center,
                mutation_scale=20,
                lw=1,
                arrowstyle="-|>",
                color="k",
            ))

    return fig
def plot_at_coordinates(images: np.ndarray, coordinates: np.ndarray, zoom: float,
                        ax: plt.Axes = None):
    """Place each image of ``images`` at the matching (x, y) in ``coordinates``.

    Args:
        images: Stack of 2-D images, shape ``(N, H, W)``. Each image is
            shown with the ``'coolwarm'`` colormap.
        coordinates: Array of shape ``(N, 2)`` with one ``(x, y)`` per image.
        zoom: Zoom factor passed to ``OffsetImage``.
        ax: Target axes. Defaults to the current axes.

    Returns:
        The list of artists returned by ``ax.add_artist`` (one per image).
    """
    assert images.ndim == 3 and coordinates.ndim == 2
    assert images.shape[0] == coordinates.shape[0] and coordinates.shape[1] == 2
    target = plt.gca() if ax is None else ax

    return [
        target.add_artist(
            AnnotationBbox(OffsetImage(picture, zoom=zoom, cmap='coolwarm'),
                           (px, py), frameon=False))
        for (px, py), picture in zip(coordinates, images)
    ]
def __setup_subplot__(self, rhythm_loop: RhythmLoop, axes: plt.Axes, **kw):
    """Add a syncopation panel to the parent box-notation subplot.

    The parent setup is run first. A container with one row per possible
    syncopation duration (taken from the time signature's natural duration
    map) is then drawn at ``container.y_bounds[1] + self.spacing``. The rows
    are labelled with their note values. Each syncopation in the polyphonic
    syncopation vector is drawn as a line with round markers in its
    duration's row; syncopations that wrap around the loop get
    ``self.cyclic_syncopations_color``. Finally the view limits are reset so
    that both panels are centred vertically, with the y-axis inverted.

    Returns:
        The parent's setup result, unchanged.
    """
    setup_res = super().__setup_subplot__(rhythm_loop, axes, **kw)
    bn_viewport = setup_res['viewport']
    bn_container = setup_res['container']
    bn_grid = setup_res['grid_box']
    bn_textbox = setup_res['text_box']

    syncopations = self.__get_sync_extractor__().process(rhythm_loop)

    # Durations (in plotter units) on which a syncopation can occur, one row each.
    duration_map = rhythm_loop.get_time_signature().get_natural_duration_map(self.unit)
    sync_durations = sorted(set(duration_map))
    level_units = tuple(
        Unit.get(Fraction(duration, self.unit.get_note_value().denominator))
        for duration in sync_durations
    )  # type: tp.Tuple[Unit]
    n_levels = len(sync_durations)

    panel = Rectangle2D(x=bn_container.x,
                        y=bn_container.y_bounds[1] + self.spacing,
                        width=bn_container.width,
                        height=n_levels)
    panel_grid = Rectangle2D(x=bn_grid.x,
                             y=panel.y,
                             width=bn_grid.width,
                             height=bn_grid.height)

    # Panel outline.
    axes.add_artist(plt.Rectangle(panel.position, panel.width, panel.height,
                                  fill=False, edgecolor=self.line_color,
                                  linewidth=self.line_width))

    # Horizontal separators between syncopation levels.
    for level in range(1, n_levels):
        sep_y = panel.y + level
        axes.add_artist(plt.Line2D(panel.x_bounds, [sep_y, sep_y],
                                   color=self.line_color,
                                   linewidth=self.line_width))

    # Vertical line separating the label column from the grid.
    axes.add_artist(plt.Line2D([panel_grid.x, panel_grid.x], panel.y_bounds,
                               color=self.line_color,
                               linewidth=self.line_width))

    # Faint vertical grid lines, one per step.
    for step in range(1, panel_grid.width):
        grid_x = panel_grid.x + step
        axes.add_artist(plt.Line2D([grid_x, grid_x], panel.y_bounds,
                                   color=to_rgba(self.line_color, 0.1),
                                   linewidth=self.line_width))

    # Row labels (note value of each syncopation level).
    label_x = bn_textbox.x + (bn_textbox.width / 2)
    for row, level_unit in enumerate(level_units):
        axes.text(label_x, panel.y + row + 0.5, str(level_unit.get_note_value()),
                  verticalalignment="center", horizontalalignment="center")

    n_steps = rhythm_loop.get_duration(self.unit, ceil=True)

    for syncopation in syncopations:
        start, end = syncopation[1:]
        # A syncopation ending before it starts wraps around the loop end.
        wraps = end < start
        while end < start:
            end += n_steps
        row = sync_durations.index(end - start)
        x_bounds = panel_grid.x + start + 0.5, panel_grid.x + end + 0.5
        row_y = panel_grid.y + row + 0.5
        line_color = self.cyclic_syncopations_color if wraps else self.syncopations_color
        axes.add_artist(plt.Line2D(x_bounds, [row_y, row_y], color=line_color,
                                   linewidth=self.line_width, marker="o"))

    # Bounding area of the box notation plus the syncopation panel.
    combined = Rectangle2D(
        x=bn_container.x,
        y=bn_container.y,
        width=abs(bn_container.x_bounds[0] - panel.x_bounds[1]),
        height=abs(bn_container.y_bounds[0] - panel.y_bounds[1]))

    # Re-centre the viewport vertically around the combined area.
    pad_y = (bn_viewport.width - combined.height) / 2
    y_low = combined.y_bounds[0] - pad_y
    y_high = combined.y_bounds[1] + pad_y
    viewport = Rectangle2D(x=bn_viewport.x, y=y_low,
                           width=bn_viewport.width,
                           height=abs(y_low - y_high))

    axes.set_xlim(viewport.x_bounds)
    axes.set_ylim(*reversed(viewport.y_bounds))
    return setup_res
def plot_box_notation_grid(axes: plt.Axes, rhythm: RhythmLoop, unit: UnitType,
                           position: Point2D = Point2D(0, 0), textbox_width: int = 2,
                           line_width: int = 1, line_color: str = "black",
                           beat_colors=(to_rgba("black", 0.21), to_rgba("black", 0.09))):
    """Draw an empty box-notation grid for ``rhythm`` on ``axes``.

    Layout (in data units, one row per track, one column per step):
    a text column of ``textbox_width`` holding abbreviated track names,
    followed by a grid box ``rhythm.get_duration(unit, ceil=True)`` steps wide.
    Both start at ``position``. When ``beat_colors`` is non-empty, beats are
    shaded alternately with its first two colours (restarting each measure).

    Args:
        axes: Axes to draw on. Its x and y axes are hidden.
        rhythm: Rhythm whose tracks and duration define the grid.
        unit: Step unit of the grid.
        position: Top-left origin of the grid in data coordinates.
        textbox_width: Width of the track-name column.
        line_width: Width of the grid lines.
        line_color: Colour of the grid lines.
        beat_colors: Alternating beat shading colours, or falsy to disable.

    Returns:
        dict with ``'grid_box'``, ``'text_box'``, ``'container'`` (Rectangle2D)
        and ``'track_y_data'`` (track name -> row index within the grid).
    """
    # Hide both axes; everything is drawn with explicit artists.
    for axis in (axes.xaxis, axes.yaxis):
        axis.set_ticklabels([])
        axis.set_visible(False)

    common_height = rhythm.get_track_count()
    text_box = Rectangle2D(width=textbox_width, height=common_height,
                           x=position.x, y=position.y)
    grid_box = Rectangle2D(width=rhythm.get_duration(unit, ceil=True), height=common_height,
                           x=text_box.x + text_box.width, y=text_box.y)
    container = Rectangle2D(width=(text_box.width + grid_box.width), height=grid_box.height,
                            x=text_box.x, y=text_box.y)

    # Beat shading.
    if beat_colors:
        timesig = rhythm.get_time_signature()
        beat_unit = timesig.get_beat_unit()
        n_steps_per_beat = beat_unit.convert(1, unit, quantize=True)
        n_beats = rhythm.get_duration(beat_unit, ceil=True)
        for beat in range(n_beats):
            beat_in_measure = beat % timesig.numerator
            step = beat * n_steps_per_beat
            axes.add_artist(plt.Rectangle(
                [grid_box.x + step, grid_box.y], width=n_steps_per_beat,
                height=grid_box.height, facecolor=beat_colors[beat_in_measure % 2],
                fill=True))

    # Outer box.
    axes.add_artist(plt.Rectangle(container.position, container.width, container.height,
                                  fill=False, edgecolor=line_color, linewidth=line_width))

    # Horizontal separators between tracks. The offset must be the grid's
    # y origin; the previous code used position.x here.
    for track_ix in range(1, grid_box.height):
        line_y = grid_box.y + track_ix
        axes.add_artist(plt.Line2D(container.x_bounds, [line_y, line_y],
                                   color=line_color, linewidth=line_width))

    # Vertical step lines.
    for step in range(grid_box.width):
        step_x = grid_box.x + step
        axes.add_artist(plt.Line2D([step_x, step_x], grid_box.y_bounds,
                                   color=line_color, linewidth=line_width,
                                   solid_capstyle="butt"))

    track_names = tuple(t.name for t in rhythm.get_track_iterator())
    track_ids = generate_abbreviations(track_names, max_abbreviation_len=3)
    track_y_positions = dict()
    text_x = text_box.x + (text_box.width / 2)

    # Track name abbreviations, centred in their rows.
    for track_ix, (track_name, track_id) in enumerate(zip(track_names, track_ids)):
        axes.text(text_x, text_box.y + track_ix + 0.5, track_id,
                  verticalalignment="center", horizontalalignment="center")
        track_y_positions[track_name] = track_ix

    return {
        'grid_box': grid_box,
        'text_box': text_box,
        'container': container,
        'track_y_data': track_y_positions
    }
def decorate_axes(
    ax: plt.Axes, image_attributes: ImageAttributes, galaxy_attributes: GalaxyAttributes
):
    """
    Decorate the axes with the extra information requested in ``image_attributes``.

    Each decoration is switched on by a boolean flag of ``image_attributes``:

    * ``decorate_radius``: dashed circle of ``galaxy_attributes.radius`` around
      the centre, labelled with ``radius_name`` just above it;
    * ``decorate_position``: redshift and centre position (Mpc), bottom right;
    * ``decorate_scalebar``: the radius value written just below the circle;
    * ``decorate_image_type``: particle type and visualisation (top left), halo
      id and projection (bottom left);
    * ``decorate_masses``: halo and stellar mass, top right.

    All text is drawn in ``image_attributes.text_color``.
    """

    def put_text(x, y, s, **options):
        ax.text(x, y, s, color=image_attributes.text_color, **options)

    if image_attributes.decorate_radius:
        # Dashed circle marking the galaxy radius.
        circle = Circle(
            galaxy_attributes.center[:2],
            radius=galaxy_attributes.radius,
            color=image_attributes.text_color,
            linestyle="dashed",
            fill=False,
        )
        put_text(
            galaxy_attributes.center[0],
            galaxy_attributes.center[1] + 1.025 * galaxy_attributes.radius,
            galaxy_attributes.radius_name,
            ha="center",
            va="bottom",
        )
        ax.add_artist(circle)

    if image_attributes.decorate_position:
        # Centre position in Mpc and redshift.
        cx, cy, cz = galaxy_attributes.center.to("Mpc").value
        position = f"[{cx:.3g}, {cy:.3g}, {cz:.3g}] Mpc"
        redshift = f"$z={galaxy_attributes.redshift:3.3f}$"
        put_text(
            0.975,
            0.025,
            f"{redshift}\n{position}",
            ha="right",
            va="bottom",
            multialignment="right",
            transform=ax.transAxes,
        )

    if image_attributes.decorate_scalebar:
        put_text(
            galaxy_attributes.center[0],
            galaxy_attributes.center[1] - 1.025 * galaxy_attributes.radius,
            f"{latex_float(galaxy_attributes.radius)}",
            ha="center",
            va="top",
        )

    ptype_title = image_attributes.particle_type.replace("_", " ").title()
    visualise_title = image_attributes.visualise.replace("_", " ").title()

    if image_attributes.decorate_image_type:
        put_text(
            0.025,
            0.975,
            f"{ptype_title} {visualise_title}",
            ha="left",
            va="top",
            transform=ax.transAxes,
        )
        put_text(
            0.025,
            0.025,
            (
                f"Halo {galaxy_attributes.unique_id}\n"
                f"{image_attributes.projection.replace('on', ' on').title()}"
            ),
            ha="left",
            va="bottom",
            transform=ax.transAxes,
        )

    if image_attributes.decorate_masses:
        put_text(
            0.975,
            0.975,
            (
                f"$M_H$={latex_float(galaxy_attributes.halo_mass.to('Solar_Mass'))}\n"
                f"$M_*$={latex_float(galaxy_attributes.stellar_mass.to('Solar_Mass'))}"
            ),
            ha="right",
            va="top",
            transform=ax.transAxes,
        )
def set_axes_properties(ax: plt.Axes, **kwargs) -> None:
    """
    Ease the configuration of a :class:`matplotlib.axes.Axes` object.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes.

    Keyword arguments
    -----------------
    fontsize : int
        Font size to use for the plot titles, and axes ticks and labels.
        Defaults to 12. NOTE: this sets the global ``rcParams["font.size"]``,
        so it also affects later plots.
    title_center : str
        The center title. Only applied if the axes has no center title yet.
        Defaults to an empty string.
    title_left : str
        The left title. Only applied if the axes has no left title yet.
        Defaults to an empty string.
    title_right : str
        The right title. Only applied if the axes has no right title yet.
        Defaults to an empty string.
    x_label : str
        The x-axis label. Only applied if the axes has no x label yet.
        Defaults to an empty string.
    x_labelcolor : str
        Color of the x-axis label. Defaults to 'black'.
    x_lim : Sequence[int]
        Data limits for the x-axis. Defaults to :obj:`None`, i.e., the data
        limits will be left unchanged.
    invert_xaxis : bool
        :obj:`True` to invert the x-axis, :obj:`False` otherwise.
        Defaults to :obj:`False`.
    x_scale : str
        The x-axis scale. Defaults to 'linear'.
    x_ticks : sequence[float]
        Sequence of x-axis ticks location. Defaults to :obj:`None`.
    x_ticklabels : sequence[str]
        Sequence of x-axis ticks labels. Defaults to :obj:`None`.
    x_ticklabels_color : str
        Color for the x-axis ticks and tick labels. Defaults to 'black'.
    x_ticklabels_rotation : float
        Rotation angle of the x-axis ticks labels. Defaults to 0.
    x_tickformat : str
        Format string for a :class:`matplotlib.ticker.FormatStrFormatter`
        applied to the major x-axis ticks. Defaults to :obj:`None`.
    xaxis_minor_ticks_visible : bool
        :obj:`True` to show all ticks, either labelled or unlabelled,
        :obj:`False` to show only the labelled ticks. Defaults to :obj:`False`.
    xaxis_visible : bool
        :obj:`False` to make the x-axis invisible. Defaults to :obj:`True`.
    y_label : str
        The y-axis label. Only applied if the axes has no y label yet.
        Defaults to an empty string.
    y_labelcolor : str
        Color of the y-axis label. Defaults to 'black'.
    y_lim : Sequence[int]
        Data limits for the y-axis. Defaults to :obj:`None`, i.e., the data
        limits will be left unchanged.
    invert_yaxis : bool
        :obj:`True` to invert the y-axis, :obj:`False` otherwise.
        Defaults to :obj:`False`.
    y_scale : str
        The y-axis scale. Defaults to 'linear'.
    y_ticks : sequence[float]
        Sequence of y-axis ticks location. Defaults to :obj:`None`.
    y_ticklabels : sequence[str]
        Sequence of y-axis ticks labels. Defaults to :obj:`None`.
    y_ticklabels_color : str
        Color for the y-axis ticks and tick labels. Defaults to 'black'.
    y_ticklabels_rotation : float
        Rotation angle of the y-axis ticks labels. Defaults to 0.
    y_tickformat : str
        Format string for a :class:`matplotlib.ticker.FormatStrFormatter`
        applied to the major y-axis ticks. Defaults to :obj:`None`.
    yaxis_minor_ticks_visible : bool
        :obj:`True` to show all ticks, either labelled or unlabelled,
        :obj:`False` to show only the labelled ticks. Defaults to :obj:`False`.
    yaxis_visible : bool
        :obj:`False` to make the y-axis invisible. Defaults to :obj:`True`.
    z_label, z_labelcolor, z_lim, invert_zaxis, z_scale, z_ticks,
    z_ticklabels, z_ticklabels_color, z_ticklabels_rotation,
    zaxis_minor_ticks_visible, zaxis_visible
        Accepted for backward compatibility but currently ignored: this
        function does not configure a z-axis.
    legend_on : bool
        :obj:`True` to show the legend, :obj:`False` otherwise.
        Defaults to :obj:`False`.
    legend_loc : str
        String specifying the location where the legend should be placed.
        Defaults to 'best'; please see :func:`matplotlib.pyplot.legend` for
        all the available options.
    legend_bbox_to_anchor : Sequence[float]
        Tuple defining the box used to place the legend. This is used in
        conjunction with `legend_loc` to allow arbitrary placement of the
        legend. Defaults to :obj:`None`.
    legend_framealpha : float
        Legend transparency. It should be between 0 and 1; defaults to 0.5.
    legend_ncol : int
        Number of columns into which the legend labels should be arranged.
        Defaults to 1.
    legend_fontsize : int
        Font size of the legend. Defaults to `fontsize`.
    text : str
        Text to be added to the figure as anchored text. Defaults to
        :obj:`None`, meaning that no text box is shown.
    text_loc : str
        Key into the module-level ``text_locations`` mapping selecting where
        the text box should be placed. Defaults to 'upper right'; please see
        :class:`matplotlib.offsetbox.AnchoredText` for the location names.
    grid_on : bool
        :obj:`True` to show the plot grid, :obj:`False` otherwise.
        Defaults to :obj:`False`.
    grid_properties : dict
        Keyword arguments specifying various settings of the plot grid.
    """
    fontsize = kwargs.get("fontsize", 12)
    # title
    title_center = kwargs.get("title_center", "")
    title_left = kwargs.get("title_left", "")
    title_right = kwargs.get("title_right", "")
    # x-axis
    x_label = kwargs.get("x_label", "")
    x_labelcolor = kwargs.get("x_labelcolor", "black")
    x_lim = kwargs.get("x_lim", None)
    invert_xaxis = kwargs.get("invert_xaxis", False)
    x_scale = kwargs.get("x_scale", "linear")
    x_ticks = kwargs.get("x_ticks", None)
    x_ticklabels = kwargs.get("x_ticklabels", None)
    x_ticklabels_color = kwargs.get("x_ticklabels_color", "black")
    x_ticklabels_rotation = kwargs.get("x_ticklabels_rotation", 0)
    x_tickformat = kwargs.get("x_tickformat", None)
    xaxis_minor_ticks_visible = kwargs.get("xaxis_minor_ticks_visible", False)
    xaxis_visible = kwargs.get("xaxis_visible", True)
    # y-axis
    y_label = kwargs.get("y_label", "")
    y_labelcolor = kwargs.get("y_labelcolor", "black")
    y_lim = kwargs.get("y_lim", None)
    invert_yaxis = kwargs.get("invert_yaxis", False)
    y_scale = kwargs.get("y_scale", "linear")
    y_ticks = kwargs.get("y_ticks", None)
    y_ticklabels = kwargs.get("y_ticklabels", None)
    y_ticklabels_color = kwargs.get("y_ticklabels_color", "black")
    y_ticklabels_rotation = kwargs.get("y_ticklabels_rotation", 0)
    y_tickformat = kwargs.get("y_tickformat", None)
    yaxis_minor_ticks_visible = kwargs.get("yaxis_minor_ticks_visible", False)
    yaxis_visible = kwargs.get("yaxis_visible", True)
    # legend
    legend_on = kwargs.get("legend_on", False)
    legend_loc = kwargs.get("legend_loc", "best")
    legend_bbox_to_anchor = kwargs.get("legend_bbox_to_anchor", None)
    legend_framealpha = kwargs.get("legend_framealpha", 0.5)
    legend_ncol = kwargs.get("legend_ncol", 1)
    legend_fontsize = kwargs.get("legend_fontsize", fontsize)
    # textbox (default matches the documented 'upper right')
    text = kwargs.get("text", None)
    text_loc = kwargs.get("text_loc", "upper right")
    # grid
    grid_on = kwargs.get("grid_on", False)
    grid_properties = kwargs.get("grid_properties", None)

    # NOTE: global side effect, kept for backward compatibility.
    rcParams["font.size"] = fontsize
    title_fontsize = rcParams["font.size"] - 1

    # plot titles: only fill in titles that are not already set
    for loc, title in (("center", title_center), ("left", title_left), ("right", title_right)):
        if not ax.get_title(loc=loc):
            ax.set_title(title, loc=loc, fontsize=title_fontsize)

    # axes labels: only fill in labels that are not already set
    if not ax.get_xlabel():
        ax.set(xlabel=x_label)
    if not ax.get_ylabel():
        ax.set(ylabel=y_label)

    # axes label colors
    if ax.get_xlabel() and x_labelcolor != "":
        ax.xaxis.label.set_color(x_labelcolor)
    if ax.get_ylabel() and y_labelcolor != "":
        ax.yaxis.label.set_color(y_labelcolor)

    # axes limits
    if x_lim is not None:
        ax.set_xlim(x_lim)
    if y_lim is not None:
        ax.set_ylim(y_lim)

    # invert the axes
    if invert_xaxis:
        ax.invert_xaxis()
    if invert_yaxis:
        ax.invert_yaxis()

    # axes scale
    if x_scale is not None:
        ax.set_xscale(x_scale)
    if y_scale is not None:
        ax.set_yscale(y_scale)

    # axes ticks
    if x_ticks is not None:
        ax.get_xaxis().set_ticks(x_ticks)
    if y_ticks is not None:
        ax.get_yaxis().set_ticks(y_ticks)

    # axes tick labels
    if x_ticklabels is not None:
        ax.get_xaxis().set_ticklabels(x_ticklabels)
    if y_ticklabels is not None:
        ax.get_yaxis().set_ticklabels(y_ticklabels)

    # axes tick (label) colors
    if x_ticklabels_color != "":
        ax.tick_params(axis="x", colors=x_ticklabels_color)
    if y_ticklabels_color != "":
        ax.tick_params(axis="y", colors=y_ticklabels_color)

    # axes tick format
    if x_tickformat is not None:
        ax.xaxis.set_major_formatter(FormatStrFormatter(x_tickformat))
    if y_tickformat is not None:
        ax.yaxis.set_major_formatter(FormatStrFormatter(y_tickformat))

    # axes tick labels rotation: applied to `ax` itself. plt.xticks/yticks
    # would rotate the labels of the *current* axes, which need not be `ax`.
    ax.tick_params(axis="x", labelrotation=x_ticklabels_rotation)
    ax.tick_params(axis="y", labelrotation=y_ticklabels_rotation)

    # unlabelled (minor) axes ticks
    if not xaxis_minor_ticks_visible:
        ax.get_xaxis().set_tick_params(which="minor", size=0)
        ax.get_xaxis().set_tick_params(which="minor", width=0)
    if not yaxis_minor_ticks_visible:
        ax.get_yaxis().set_tick_params(which="minor", size=0)
        ax.get_yaxis().set_tick_params(which="minor", width=0)

    # axes visibility
    if not xaxis_visible:
        ax.get_xaxis().set_visible(False)
    if not yaxis_visible:
        ax.get_yaxis().set_visible(False)

    # legend
    if legend_on:
        legend_kwargs = dict(
            loc=legend_loc,
            framealpha=legend_framealpha,
            ncol=legend_ncol,
            fontsize=legend_fontsize,
        )
        if legend_bbox_to_anchor is not None:
            legend_kwargs["bbox_to_anchor"] = legend_bbox_to_anchor
        ax.legend(**legend_kwargs)

    # text box
    if text is not None:
        ax.add_artist(AnchoredText(text, loc=text_locations[text_loc]))

    # plot grid
    if grid_on:
        gps = grid_properties if grid_properties is not None else {}
        ax.grid(True, **gps)
def plot_stacked_bar(self, ax: plt.Axes, width: float = 0.8, plot_legend: bool = True):
    """Draw stacked choice/confidence bars with a small graph glyph below each bar.

    For every level ``s`` in ``ExpConfig.glo_exp2``, the bar at the matching
    position of ``self.x`` stacks, for each (choice, confidence) pair in
    ``colormap`` below, the number of rows of ``self.df`` with that choice,
    that confidence and ``ground_truth == f'{s:.2f}'``. Each count is divided
    by ``len(self.df) / len(ExpConfig.glo_exp2)``.

    Below the axes (``clip_on=False``) each bar gets a rounded box with a
    small node-and-edge glyph. NOTE(review): ``ExpConfig.glo_exp2[x]`` is
    indexed by the bar position, so ``self.x`` is assumed to be 0..N-1.

    Args:
        ax: Axes to draw on.
        width: Bar width.
        plot_legend: If true, add the legend to ``ax``; otherwise return the
            legend handles and labels instead.

    Returns:
        ``None`` if ``plot_legend`` is true, otherwise ``(handles, labels)``
        in reversed stacking order (top segment first).
    """
    p, labels = [], []
    n_levels = len(ExpConfig.glo_exp2)
    bottom = np.zeros(n_levels)
    rows_per_level = len(self.df) / n_levels

    def whiten_color(color):
        # Move each RGB channel halfway towards 0.8 (lighter, muted colour).
        return [c + (.8 - c) / 2 for c in to_rgb(color)]

    colormap = [('C', 'high', whiten_color('darkgoldenrod')),
                ('C', 'low', whiten_color('goldenrod')),
                ('H', 'low', whiten_color('green')),
                ('H', 'high', whiten_color('darkgreen'))]
    for choice, confidence, color in colormap:
        y = [
            len(self.df[(self.df['choice'] == choice)
                        & (self.df['confidence'] == confidence)
                        & (self.df['ground_truth'] == f'{s:.2f}')]) / rows_per_level
            for s in ExpConfig.glo_exp2
        ]
        p.append(ax.bar(self.x, y, width=width, bottom=bottom, color=color, alpha=0.85)[0])
        labels.append(f'{"C" if choice=="C" else "H"} {confidence}')
        bottom += y

    ax.set_xlim(-width / 2, self.x[-1] + width / 2)
    ax.set_xticks([])
    ax.set_ylim(0, 1)
    ax.set_ylabel(r'$P$(choice=$C\,|\,\bf{X}$)')

    # Axes size in display units; used so the glyph circles stay round.
    axes_w, axes_h = ax.transAxes.transform((1, 1)) - ax.transAxes.transform((0, 0))
    w, r = .9, .7 / 12
    aspect = axes_w / ((self.x[-1] + w) * axes_h - axes_w)

    for x in self.x:
        ax.add_patch(
            FancyBboxPatch((x - w / 2, -aspect), w, w * aspect,
                           boxstyle='Round,pad=0,rounding_size=0.05',
                           fc='#E6E6E6', ec='#B3B3B3', lw=1, clip_on=False,
                           mutation_aspect=aspect))
        nodes = [(x + offset, (-.9 + r) * aspect) for offset in (-5 * r, 0, 5 * r)]
        nodes.append((x - 2.5 * r, -.72 * aspect))
        if x > 0:
            nodes.append((x, (-.2 - .52 * ExpConfig.glo_exp2[x]) * aspect))
        for node in nodes:
            ax.add_artist(Ellipse(node, 2 * r, 2 * r * aspect, fc='k', clip_on=False))
        # NOTE(review): this patch duplicates the first ax.plot segment below.
        ax.add_patch(ConnectionPatch(nodes[0], nodes[3], 'data', 'data', clip_on=False))
        ax.plot([nodes[0][0], nodes[3][0]], [nodes[0][1], nodes[3][1]],
                color='k', clip_on=False)
        ax.plot([nodes[1][0], nodes[3][0]], [nodes[1][1], nodes[3][1]],
                color='k', clip_on=False)
        if x == 0:
            ax.plot([nodes[2][0], nodes[2][0]], [nodes[2][1], -.2 * aspect],
                    color='k', clip_on=False)
            ax.plot([nodes[3][0], nodes[3][0]], [nodes[3][1], -.2 * aspect],
                    color='k', clip_on=False)
        else:
            ax.plot([nodes[2][0], nodes[4][0]], [nodes[2][1], nodes[4][1]],
                    color='k', clip_on=False)
            ax.plot([nodes[3][0], nodes[4][0]], [nodes[3][1], nodes[4][1]],
                    color='k', clip_on=False)
            ax.plot([nodes[4][0], nodes[4][0]], [nodes[4][1], -.2 * aspect],
                    color='k', clip_on=False)

    ax.text(-width / 2, -0.15 * aspect, '$C$', ha='left', va='top')
    ax.text(self.x[-1] - width / 2, -0.15 * aspect, '$H$', ha='left', va='top')

    handles, legend_labels = p[::-1], labels[::-1]
    if plot_legend:
        # Build the legend on `ax` itself; plt.legend would target the
        # current axes, which need not be `ax`.
        ax.add_artist(ax.legend(handles, legend_labels, loc='lower left'))
    else:
        return handles, legend_labels