def truncated_countplot(
    x: pd.Series,
    val: Any = 'mode',
    ax: plt.Axes = None,
) -> plt.Axes:
    """
    Truncated count plot to visualize more values when one dominates.

    The y-axis is capped at 1.4x the largest count among the values other
    than ``val``. The truncated bar is marked with an arrow and a text
    annotation giving its true count and its share of the total.

    Arguments:
        x   : Data Series to count.
        val : Value to truncate in count plot. 'mode' (default) truncates the
              data mode; None draws an ordinary, untruncated count plot.
        ax  : matplotlib Axes object to draw plot onto. A new figure/axes is
              created when omitted.

    Returns:
        ax : Returns the Axes object with the plot drawn onto it
    """
    # Setup Axes
    if ax is None:
        _fig, ax = plt.subplots()
    ax.set_xlabel(x.name)
    ax.set_ylabel('Counts')

    if isinstance(val, str) and val == 'mode':
        modes = x.mode()
        # mode() is empty for an empty or all-NaN series: nothing to truncate.
        val = None if modes.empty else modes.iloc[0]

    # Counts of every value except the truncated one, largest first.
    other_counts = None if val is None else x[x != val].value_counts()
    if other_counts is None or other_counts.empty:
        # Either truncation was not requested, or no other value exists to
        # scale the axis against: fall back to a plain count plot.
        sns.countplot(x=x, ax=ax)
        return ax

    # Plot and truncate: cap the y-axis 40% above the tallest remaining bar
    # (value_counts() sorts descending, so iloc[0] is the maximum).
    splot = sns.countplot(x=x, ax=ax)
    ymax = other_counts.iloc[0] * 1.4
    ax.set_ylim(0, ymax)

    # Annotate truncated bin: find its bar through the category tick labels.
    tick_texts = [tick.get_text() for tick in ax.get_xticklabels()]
    val_bin = splot.patches[tick_texts.index(str(val))]
    xloc = val_bin.get_x() + 0.5 * val_bin.get_width()
    ax.annotate('', xy=(xloc, 0), xytext=(xloc, ymax), xycoords='data',
                arrowprops=dict(arrowstyle='<-', color='black', lw=4))

    val_count = (x == val).sum()
    val_perc = val_count / len(x)
    ax.annotate(f'{val} (count={val_count}; {val_perc:.0%} of total)',
                xy=(0.5, 0), xytext=(0.5, 0.9), xycoords='axes fraction',
                ha='center')
    return ax
def _radar(df: pd.DataFrame, ax: plt.Axes, label: str, all_tags: Sequence[str],
           color: str, alpha: float = 0.2, edge_alpha: float = 0.85,
           zorder: int = 2, edge_style: str = '-'):
    """Plot utility for generating the underlying radar plot.

    Draws one filled, closed polygon on ``ax`` with one vertex per entry of
    ``all_tags``. Each vertex's radius is the mean ``score`` of the rows of
    ``df`` whose ``tag`` equals that entry, floored at 0.05. Tags with no rows
    are reported on stdout and plotted as zero (i.e. at the floor).

    Args:
      df: Frame with ``'tag'`` and ``'score'`` columns.
      ax: Axes to draw on; ``set_thetagrids`` below requires a polar Axes.
      label: Legend label of the polygon, also used in the missing-tag message.
      all_tags: Tags placed around the circle, in order.
      color: Colour of the outline and the fill.
      alpha: Fill transparency.
      edge_alpha: Outline transparency.
      zorder: Drawing order of both outline and fill.
      edge_style: Matplotlib linestyle of the outline.
    """
    # Mean score per tag. Only 'score' is aggregated, so extra non-numeric
    # columns in `df` cannot make `mean` fail (pandas >= 2 raises on them).
    mean_scores = df.groupby('tag')['score'].mean()

    values = []
    for curr_tag in all_tags:
        matches = mean_scores[mean_scores.index == curr_tag]
        score = 0.
        if len(matches) == 1:
            score = float(matches.iloc[0])
        else:
            print('{} bsuite scores found for tag {!r} with setting {!r}. '
                  'Replacing with zero.'.format(len(matches), curr_tag, label))
        values.append(score)

    values = np.maximum(values, 0.05)  # don't let radar collapse to 0.
    angles = np.linspace(0, 2 * np.pi, len(all_tags), endpoint=False)
    # Repeat the first point so the outline and the fill form a closed polygon.
    closed_values = np.concatenate((values, [values[0]]))
    closed_angles = np.concatenate((angles, [angles[0]]))

    # The linestyle comes only from `edge_style`. An extra '-' format string
    # would conflict with it and trigger Matplotlib's redundancy warning.
    ax.plot(closed_angles, closed_values, linewidth=5, label=label, c=color,
            alpha=edge_alpha, zorder=zorder, linestyle=edge_style)
    ax.fill(closed_angles, closed_values, alpha=alpha, color=color,
            zorder=zorder)

    # Exactly one grid line and one label per tag. Also passing the closing
    # duplicate angle gives one more tick than labels, which recent
    # Matplotlib versions reject with a ValueError.
    tick_degrees = np.rad2deg(angles)
    ax.set_thetagrids(tick_degrees, [_tag_pretify(tag) for tag in all_tags],
                      fontsize=18)

    # To avoid text on top of gridlines, we flip horizontalalignment
    # based on label location
    for tick_label, angle in zip(ax.get_xticklabels(), tick_degrees):
        if 90 <= angle <= 270:
            tick_label.set_horizontalalignment('right')
        else:
            tick_label.set_horizontalalignment('left')
def plot_policy(
    ax: plt.Axes,
    title: str = "Value",
):
    """Render ``learner.policy`` as a colour-coded grid annotated with its values.

    Reads the module-level globals ``learner`` (``policy`` and
    ``obs_space_range``) and ``env`` (``act_space.min`` / ``act_space.max``
    set the colour range). The policy array is flipped vertically before it
    is drawn. Only the first and last observation indices get tick marks.

    Args:
        ax: Axes to draw on.
        title: Title placed on the axes.
    """
    policy_img = np.flipud(learner.policy)
    colour_map = plt.get_cmap("Spectral_r")
    ax.imshow(policy_img, cmap=colour_map,
              vmin=env.act_space.min, vmax=env.act_space.max)

    # Label only the two extreme indices instead of every cell. The y ticks
    # are reversed because the image rows were flipped above.
    all_indices = np.arange(learner.obs_space_range)
    edge_ticks = [all_indices[0], all_indices[-1]]
    ax.set_xticks(edge_ticks)
    ax.set_yticks(np.flip(edge_ticks))
    ax.set_xticklabels(edge_ticks)
    ax.set_yticklabels(edge_ticks)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Write each cell's value at its centre, row by row.
    n_states = learner.obs_space_range
    for row, col in np.ndindex(n_states, n_states):
        ax.text(col, row, policy_img[row, col], ha="center", va="center",
                color="b", fontsize=8)

    ax.set_title(title)
def full_extent(fig: plt.Figure, ax: plt.Axes, *extras, pad=0.01):
    """
    Get the full extent of an axes, including axes labels, tick labels, and
    titles.

    Extra artists (for example a colorbar's Axes) may be passed positionally.
    Their own extent is included, and for those that have axis labels or tick
    labels, those labels are included too.

    Args:
        fig: Figure whose ``dpi_scale_trans`` converts the result to inches.
        ax: Axes to measure.
        *extras: Additional artists to include in the bounding box.
        pad: Fractional growth of the union box (0.01 -> 1% wider and taller).

    Returns:
        The union bounding box, expanded by ``pad`` and expressed in inches.
    """
    # For text objects, we need to draw the figure first, otherwise the extents
    # are undefined.
    ax.figure.canvas.draw()

    items = [ax.title, ax.xaxis.label, ax.yaxis.label, *extras]
    items += [*ax.get_xticklabels(), *ax.get_yticklabels()]
    for extra in extras:
        if hasattr(extra, 'xaxis'):
            items.append(extra.xaxis.label)
        if hasattr(extra, 'yaxis'):
            items.append(extra.yaxis.label)
        if hasattr(extra, 'get_xticklabels'):
            items.extend(extra.get_xticklabels())
        if hasattr(extra, 'get_yticklabels'):
            items.extend(extra.get_yticklabels())

    # Matplotlib reports a dummy unit box at the display origin for invisible
    # text, which would drag the union down to the figure's bottom-left
    # corner. Skip invisible artists; the axes itself is always included.
    extents = [ax.get_window_extent()]
    extents += [
        item.get_window_extent()
        for item in items
        if hasattr(item, 'get_window_extent')
        and getattr(item, 'get_visible', lambda: True)()
    ]

    bbox = Bbox.union(extents)
    bbox = bbox.expanded(1.0 + pad, 1.0 + pad)
    return bbox.transformed(fig.dpi_scale_trans.inverted())
def style_axes(
    ax: plt.Axes,
    title: str = "",
    legend_title: str = "",
    xlab: str = "",
    ylab: str = "",
    title_loc: Literal["center", "left", "right"] = "center",
    title_pad: float = None,
    title_fontsize: int = 10,
    label_fontsize: int = 8,
    tick_fontsize: int = 8,
    change_xticks: bool = True,
    add_legend: bool = True,
) -> None:
    """Style an axes object.

    Sets the title and the axis labels and hides the top and right spines.
    Optionally restyles the x tick labels and adds a legend outside the plot
    area. Finally it moves the axes to a fixed position within the figure.

    Parameters
    ----------
    ax
        Axis object to style.
    title
        Figure title.
    legend_title
        Figure legend title.
    xlab
        Label for the x axis.
    ylab
        Label for the y axis.
    title_loc
        Position of the plot title (can be {'center', 'left', 'right'}).
    title_pad
        Padding of the plot title; ``None`` uses Matplotlib's default.
    title_fontsize
        Font size of the plot title.
    label_fontsize
        Font size of the axis labels and of the legend title.
    tick_fontsize
        Font size of the x tick labels and of the legend entries.
    change_xticks
        If True, set the x tick labels to ``tick_fontsize``, rotate them by
        30 degrees (right-aligned), and hide the x tick marks.
    add_legend
        If True, draw a frameless legend to the right of the axes.
    """
    ax.set_title(
        title, fontdict={"fontsize": title_fontsize}, pad=title_pad, loc=title_loc
    )
    ax.set_xlabel(xlab, fontsize=label_fontsize)
    ax.set_ylabel(ylab, fontsize=label_fontsize)

    if change_xticks:
        # tick_params restyles current and future ticks without touching the
        # tick formatter. set_xticklabels(ax.get_xticklabels()) would instead
        # freeze the current label texts behind a FixedFormatter.
        ax.tick_params(
            axis="x", labelsize=tick_fontsize, labelrotation=30, length=0
        )
        for tick_label in ax.get_xticklabels():
            tick_label.set_horizontalalignment("right")

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    if add_legend:
        ax.legend(
            title=legend_title,
            loc="upper left",
            bbox_to_anchor=(1.2, 1),
            title_fontsize=label_fontsize,
            fontsize=tick_fontsize,
            frameon=False,
        )
    ax.set_position([0.1, 0.3, 0.6, 0.55])
def matrix(values: np.ndarray,
           row_labels: Sequence[str] = None,
           col_labels: Sequence[str] = None,
           row_seps: Union[int, Collection[int]] = None,
           col_seps: Union[int, Collection[int]] = None,
           cmap="RdBu",
           fontcolor_thresh=0.5,
           norm: plt.Normalize = None,
           text_len=4,
           omit_leading_zero=False,
           trailing_zeros=False,
           grid=True,
           angle_left=False,
           cbar=True,
           cbar_label: str = None,
           ax: plt.Axes = None,
           figsize: Tuple[int, int] = None,
           cellsize=0.65,
           title: str = None):
    """Draw ``values`` as a heatmap with each cell's value written inside it.

    Args:
        values: 2-D array of cell values. NaN cells get no text.
        row_labels, col_labels: Tick label per row/column. Entries that are
            empty strings get no tick.
        row_seps, col_seps: Index/indices (as accepted by ``np.insert``)
            before which an empty separator row/column of NaN is inserted.
        cmap: Colormap name or object, resolved through ``get_cmap``.
        fontcolor_thresh: Cells whose background luma is below this threshold
            get white text; all other cells get black text.
        norm: Optional normalization passed to ``matshow``.
        text_len, omit_leading_zero, trailing_zeros: Forwarded to
            ``_format_value`` when formatting cell and colorbar labels.
        grid: Draw white lines between cells.
        angle_left: Rotate the row labels by 40 degrees.
        cbar: Add a colorbar; ``cbar_label`` becomes its y-axis label.
        ax: Axes to draw on. When omitted, a new figure is created, sized by
            ``figsize`` or, if that is None, by ``cellsize`` per cell.
        title: Optional axes title.

    Returns:
        The Axes the matrix was drawn on.
    """
    cmap = get_cmap(cmap)

    # Create figure if necessary.
    if ax is None:
        if figsize is None:
            # Note the extra width factor for the colorbar.
            figsize = (cellsize * values.shape[1] * (1.2 if cbar else 1),
                       cellsize * values.shape[0])
        ax = plt.figure(figsize=figsize).gca()

    # Set title if applicable.
    if title is not None:
        ax.set_title(title)

    # Separator cells are NaN. np.insert casts the inserted value to the
    # array's dtype, so NaN raises ValueError for integer arrays and silently
    # becomes True for bool arrays; switch such arrays to float first.
    if row_seps is not None or col_seps is not None:
        as_array = np.asarray(values)
        if as_array.dtype.kind in "biu":
            values = as_array.astype(float)
    if row_seps is not None:
        values = np.insert(values, row_seps, np.nan, axis=0)
        if row_labels is not None:
            row_labels = np.insert(row_labels, row_seps, "")
    if col_seps is not None:
        values = np.insert(values, col_seps, np.nan, axis=1)
        if col_labels is not None:
            col_labels = np.insert(col_labels, col_seps, "")

    # Plot the heatmap.
    im = ax.matshow(values, cmap=cmap, norm=norm)

    # Plot the text annotations showing each cell's value.
    norm_values = im.norm(values)
    for row, col in product(range(values.shape[0]), range(values.shape[1])):
        val = values[row, col]
        if np.isnan(val):
            continue
        # Choose black or white text from the background's luma (Rec. 601
        # weights) so that the annotation stays readable.
        bg_color = cmap(norm_values[row, col])[:3]
        luma = 0.299 * bg_color[0] + 0.587 * bg_color[1] + 0.114 * bg_color[2]
        color = "white" if luma < fontcolor_thresh else "black"
        # Plot cell text.
        annotation = _format_value(val, text_len, omit_leading_zero,
                                   trailing_zeros)
        ax.text(col, row, annotation, ha="center", va="center", color=color)

    # Add ticks and labels; only non-empty labels get a tick.
    if col_labels is None:
        ax.set_xticks([])
    else:
        col_labels = np.asarray(col_labels)
        labeled_cols = np.where(col_labels)[0]
        ax.set_xticks(labeled_cols)
        ax.set_xticklabels(col_labels[labeled_cols])
    if row_labels is None:
        ax.set_yticks([])
    else:
        row_labels = np.asarray(row_labels)
        labeled_rows = np.where(row_labels)[0]
        ax.set_yticks(labeled_rows)
        ax.set_yticklabels(row_labels[labeled_rows])
    ax.tick_params(which="major", bottom=False)
    plt.setp(ax.get_xticklabels(), rotation=40, ha="left",
             rotation_mode="anchor")

    # Turn off spines.
    for spine in ax.spines.values():
        spine.set_visible(False)

    # Rotate the left labels if applicable.
    if angle_left:
        plt.setp(ax.get_yticklabels(), rotation=40, ha="right",
                 rotation_mode="anchor")

    # Create the white grid if applicable.
    if grid:
        # Extra ticks required to avoid glitch.
        xticks = np.concatenate([[-0.56], np.arange(values.shape[1] + 1) - 0.5,
                                 [values.shape[1] - 0.44]])
        yticks = np.concatenate([[-0.56], np.arange(values.shape[0] + 1) - 0.5,
                                 [values.shape[0] - 0.44]])
        ax.set_xticks(xticks, minor=True)
        ax.set_yticks(yticks, minor=True)
        ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
        ax.tick_params(which="minor", bottom=False, top=False, left=False)

    # Create the colorbar if applicable.
    if cbar:
        bar = ax.figure.colorbar(im, ax=ax)
        bar.ax.set_ylabel(cbar_label)
        fmt = bar.ax.yaxis.get_major_formatter()
        if isinstance(fmt, FixedFormatter):
            # Re-format fixed colorbar labels the same way as the cells. The
            # label text is presumably mathtext such as "$2\times10^{-1}$",
            # which the substitutions turn into "2*10**-1" before evaluating.
            # NOTE(review): eval() on formatter text is only safe while these
            # labels never originate from untrusted input.
            fmt.seq = [
                _format_value(
                    eval(
                        re.sub(r"[a-z$\\{}]", "",
                               label.replace("times", "*").replace("^", "**"))),
                    text_len, omit_leading_zero, trailing_zeros)
                if label else ""
                for label in fmt.seq
            ]

    return ax
def plot(self, plot_these: List[np.ndarray], ax: plt.Axes = None,
         fout: str = None, poly=True, root: str = "", title: str = None,
         aspect: int = 3, ticks: list = None, grid: bool = False):
    """
    Plot the lattice coordinates or lattice as an array

    Arguments:
        plot_these - up to 4 sets of lattice and crossover coordinates. An
            entry whose second dimension is not 2 or 3 is treated as a lattice
            array and converted with ``self.array_to_coords``.
        ax - an axes; when omitted a new figure is created and shown at the end
        fout - saves the plot as a png with the name of `fout`
        poly - True will plot the initial polygon vertices (taken from the
            first entry of ``plot_these``)
        root - directory (path prefix) to save file in, defaults to saving in
            current directory
        title - this is the title of the plot
        aspect - this is the aspect ratio of the x and y axes
        ticks - [x, y] major tick spacing; defaults to [10, 5]
        grid - True draws grid lines on the axes
    """
    assert len(
        plot_these) <= 4, "Max no. of plots on one axis reached, 4 or less"
    if ticks is None:
        ticks = [10, 5]

    # Remember whether the figure was created here: only then is it shown.
    created_ax = ax is None
    if created_ax:
        _fig, ax = plt.subplots()

    # Everything below acts on `ax` explicitly. The pyplot "current" axes can
    # be a different axes when the caller passes one in.
    if grid:
        ax.grid(True)
    ax.tick_params(axis="both", labelsize=4)
    ax.xaxis.set_major_locator(MultipleLocator(ticks[0]))
    ax.yaxis.set_major_locator(MultipleLocator(ticks[1]))
    ax.set_xlabel("No. of nucleotides")
    ax.set_ylabel("No. of strands")

    point_style = itertools.cycle(["ko", "b.", "r.", "cP"])
    point_size = itertools.cycle([0.5, 2.5])
    for points in plot_these:
        if np.shape(points)[1] not in [2, 3]:  # if array
            nodes = self.array_to_coords(points)
        else:
            nodes = points
        # Lattice sites then crossover sites
        ax.plot(nodes[:, 0], nodes[:, 1], next(point_style),
                ms=next(point_size), alpha=0.25)

    if poly and len(plot_these) > 0:
        self.plotPolygon(ax, plot_these[0], coords=True)

    if title:
        ax.set_title(f"{title}")
    ax.set_aspect(aspect)
    if fout:
        ax.figure.savefig(f"{root}{fout}.png", dpi=500)
    if created_ax:
        plt.show()
def cases_and_deaths(
    data: pd.DataFrame,
    dates: bool = False,
    ax: plt.Axes = None,
    smooth: bool = True,
    cases: str = "cases",
    deaths: str = "deaths",
    tight_layout=False,
    **kwargs,
) -> plt.Axes:
    """
    A simple chart showing observed new cases and deaths as vertical bars,
    optionally overlaid with a smoothed-out version of each curve.

    Args:
        data:
            A dataframe with ["cases", "deaths"] columns (cumulative counts;
            new counts are obtained with ``diff``).
        dates:
            If True, show dates instead of days in the x-axis.
        ax:
            An explicit matplotlib axes.
        smooth:
            If True, superimpose a plot of a smoothed-out version of the cases
            and deaths curves.
        cases:
        deaths:
            Name of the cases/deaths columns in the dataframe.
        tight_layout:
            If True, call ``tight_layout`` on the resulting figure.
        **kwargs:
            Forwarded to ``DataFrame.plot.bar``. ``alpha`` defaults to 0.5.
            When ``ylim`` is not given, one is derived from the new deaths.

    Returns:
        The Axes holding the chart.
    """
    if not dates:
        data = data.reset_index(drop=True)

    col_names = {cases: _("Cases"), deaths: _("Deaths")}

    # Smoothed data
    if smooth:
        from pydemic import fitting as fit

        smoothed = pd.DataFrame(
            {
                _("{} (smooth)").format(col_names[cases]): fit.smoothed_diff(data[cases]),
                _("{} (smooth)").format(col_names[deaths]): fit.smoothed_diff(data[deaths]),
            },
            index=data.index,
        )
        ax = smoothed.plot(legend=False, lw=2, ax=ax)

    # Prepare cases dataframe and plot it
    kwargs.setdefault("alpha", 0.5)
    new_cases = data.diff().fillna(0)
    new_cases = new_cases.rename(col_names, axis=1)

    if "ylim" not in kwargs:
        # Lower limit: 10 ** (half the mean order of magnitude of the positive
        # daily deaths), capped at 10 ** 10. The second column is assumed to
        # hold deaths, as documented for `data`.
        new_deaths = new_cases.iloc[:, 1]
        positive_deaths = new_deaths[new_deaths > 0]
        if positive_deaths.empty:
            # log10 of an empty selection averages to NaN and int(NaN) raises.
            exp = 0
        else:
            exp = min(10, int(np.log10(positive_deaths).mean() / 2))
        kwargs["ylim"] = (10**exp, None)
    ax = new_cases.plot.bar(width=1.0, ax=ax, **kwargs)

    # Fix xticks: label only every `periods`-th bar.
    periods = 7 if dates else 10
    xticks = ax.get_xticks()
    labels = ax.get_xticklabels()
    ax.set_xticks(xticks[::periods])
    ax.set_xticklabels(labels[::periods])
    ax.tick_params("x", rotation=0)
    # NOTE(review): this resets the lower y limit to 1, overriding the `ylim`
    # computed above or supplied by the caller; confirm this is intended.
    ax.set_ylim(1, None)

    if tight_layout:
        fig = ax.get_figure()
        fig.tight_layout()
    return ax