예제 #1
0
def truncated_countplot(
    x   : pd.Series,
    val : Any = 'mode',
    ax  : plt.Axes = None
    ) -> plt.Axes:
    """
    Truncated count plot to visualize more values when one dominates.

    The y-axis is clipped to 1.4x the count of the most frequent value
    *other than* ``val``, so the remaining bars stay readable. The clipped
    bar is marked with an arrow, and a text label gives its real count and
    its share of the total.

    Arguments:
        x :
            Data Series
        val :
            Value to truncate in count plot. 'mode' will truncate the data
            mode. None draws a plain, untruncated count plot.
        ax :
            matplotlib Axes object to draw plot onto. A new figure and axes
            are created when omitted.
    Returns:
        ax :
            Returns the Axes object with the plot drawn onto it
    """
    # Setup Axes
    if ax is None:
        _, ax = plt.subplots()
    ax.set_xlabel(x.name)
    ax.set_ylabel('Counts')

    if val is None:
        sns.countplot(x=x, ax=ax)
        # Return the axes here too, as promised by the signature.
        return ax

    # isinstance guard: avoids element-wise comparison when `val` is array-like.
    if isinstance(val, str) and val == 'mode':
        val = x.mode().iloc[0]

    # Plot and truncate
    splot = sns.countplot(x=x, ax=ax)
    other_counts = x[x != val].value_counts()  # sorted by count, descending
    if other_counts.empty:
        # Every value equals `val`: there is nothing to zoom in on.
        return ax
    ymax = other_counts.iloc[0] * 1.4
    ax.set_ylim(0, ymax)

    # Annotate truncated bin
    xticklabels = [tick.get_text() for tick in ax.get_xticklabels()]
    val_ibin = xticklabels.index(str(val))
    val_bin = splot.patches[val_ibin]
    xloc = val_bin.get_x() + 0.5 * val_bin.get_width()
    yloc = ymax
    ax.annotate('', xy=(xloc, 0), xytext=(xloc, yloc), xycoords='data',
                arrowprops=dict(arrowstyle='<-', color='black', lw=4)
               )
    val_count = (x == val).sum()
    val_perc = val_count / len(x)
    ax.annotate(f'{val} (count={val_count}; {val_perc:.0%} of total)',
                xy=(0.5, 0), xytext=(0.5, 0.9), xycoords='axes fraction',
                ha='center'
               )

    return ax
예제 #2
0
def _radar(df: pd.DataFrame,
           ax: plt.Axes,
           label: str,
           all_tags: Sequence[str],
           color: str,
           alpha: float = 0.2,
           edge_alpha: float = 0.85,
           zorder: int = 2,
           edge_style: str = '-'):
    """Plot utility for generating the underlying radar plot.

    Draws one closed polygon on the polar axes ``ax``. Its radius along each
    spoke is the mean ``score`` of the rows of ``df`` carrying that ``tag``.

    Args:
      df: Frame with a ``'tag'`` column and a numeric ``'score'`` column.
      ax: Polar axes to draw on.
      label: Legend label for the polygon. It is also quoted in the message
        printed when a tag has no score.
      all_tags: Tags placed on the spokes, in order.
      color: Colour of the outline and the fill.
      alpha: Opacity of the filled area.
      edge_alpha: Opacity of the outline.
      zorder: Drawing order of both outline and fill.
      edge_style: Matplotlib linestyle of the outline.
    """
    # Average only the score column. On pandas >= 2, `.mean()` over the whole
    # frame raises on non-numeric columns instead of silently dropping them.
    tmp = df.groupby('tag', as_index=False)['score'].mean()

    values = []
    for curr_tag in all_tags:
        score = 0.
        selected = tmp[tmp['tag'] == curr_tag]
        if len(selected) == 1:
            # .iloc[0]: float() on a one-element Series is deprecated.
            score = float(selected['score'].iloc[0])
        else:
            print('{} bsuite scores found for tag {!r} with setting {!r}. '
                  'Replacing with zero.'.format(len(selected), curr_tag,
                                                label))
        values.append(score)
    values = np.maximum(values, 0.05)  # don't let radar collapse to 0.
    values = np.concatenate((values, [values[0]]))

    spoke_angles = np.linspace(0, 2 * np.pi, len(all_tags), endpoint=False)
    # Repeat the first angle so the drawn polygon closes on itself.
    angles = np.concatenate((spoke_angles, [spoke_angles[0]]))

    ax.plot(angles,
            values,
            '-',
            linewidth=5,
            label=label,
            c=color,
            alpha=edge_alpha,
            zorder=zorder,
            linestyle=edge_style)
    ax.fill(angles, values, alpha=alpha, color=color, zorder=zorder)
    # One grid line and one label per tag. Passing the closing duplicate angle
    # too would give N+1 ticks for N labels, which matplotlib >= 3.5 rejects.
    ax.set_thetagrids(spoke_angles * 180 / np.pi,
                      [_tag_pretify(tag) for tag in all_tags],
                      fontsize=18)

    # To avoid text on top of gridlines, we flip horizontalalignment
    # based on label location
    text_angles = np.rad2deg(spoke_angles)
    for tick_label, angle in zip(ax.get_xticklabels(), text_angles):
        if 90 <= angle <= 270:
            tick_label.set_horizontalalignment('right')
        else:
            tick_label.set_horizontalalignment('left')
예제 #3
0
파일: main.py 프로젝트: RDelg/rl-book
    def plot_policy(
        ax: plt.Axes,
        title: str = "Value",
    ):
        """Draw the learner's policy as an annotated heat map on ``ax``.

        The policy grid is flipped upside down before plotting. Colours span
        the environment's action range. Only the first and last index are
        labelled on each axis, and every cell shows its value as text.

        Args:
            ax: Axes to draw on.
            title: Axes title.
        """
        grid = np.flipud(learner.policy)
        ax.imshow(
            grid,
            cmap=plt.get_cmap("Spectral_r"),
            vmin=env.act_space.min,
            vmax=env.act_space.max,
        )

        n_obs = learner.obs_space_range
        # Label only the two extreme positions. The y tick positions are
        # reversed to undo the vertical flip of the grid.
        lo, hi = np.arange(n_obs)[[0, -1]]
        edge_ticks = [lo, hi]

        ax.set_xticks(edge_ticks)
        ax.set_yticks(np.flip(edge_ticks))
        ax.set_xticklabels(edge_ticks)
        ax.set_yticklabels(edge_ticks)

        # Tilt the x labels and anchor them at their right end.
        plt.setp(ax.get_xticklabels(),
                 rotation=45,
                 ha="right",
                 rotation_mode="anchor")

        # Overlay each cell (row-major order) with its value.
        for row, col in np.ndindex(n_obs, n_obs):
            ax.text(col, row, grid[row, col],
                    ha="center", va="center", color="b", fontsize=8)

        ax.set_title(title)
예제 #4
0
def full_extent(fig: plt.Figure, ax: plt.Axes, *extras, pad=0.01):
    """Return the bounding box of ``ax`` together with its decorations.

    The box covers:

    - the axes itself, its title, both axis labels and all tick labels;
    - every artist in ``extras``;
    - the axis labels and tick labels of those extras that have them.

    The union is grown by the relative factor ``1 + pad`` in both directions.
    It is then converted from display units to inches using ``fig``'s DPI
    transform.
    """
    # Text extents are undefined until the figure has been drawn once.
    ax.figure.canvas.draw()

    items = [ax, ax.title, ax.xaxis.label, ax.yaxis.label, *extras]
    items.extend(ax.get_xticklabels())
    items.extend(ax.get_yticklabels())
    items.extend(extra.xaxis.label for extra in extras
                 if hasattr(extra, 'xaxis'))
    items.extend(extra.yaxis.label for extra in extras
                 if hasattr(extra, 'yaxis'))
    for extra in extras:
        if hasattr(extra, 'get_xticklabels'):
            items.extend(extra.get_xticklabels())
    for extra in extras:
        if hasattr(extra, 'get_yticklabels'):
            items.extend(extra.get_yticklabels())

    extents = [item.get_window_extent() for item in items
               if hasattr(item, 'get_window_extent')]
    padded = Bbox.union(extents).expanded(1.0 + pad, 1.0 + pad)
    return padded.transformed(fig.dpi_scale_trans.inverted())
예제 #5
0
파일: styling.py 프로젝트: semir2/scirpy
def style_axes(
    ax: plt.Axes,
    title: str = "",
    legend_title: str = "",
    xlab: str = "",
    ylab: str = "",
    title_loc: Literal["center", "left", "right"] = "center",
    title_pad: float = None,
    title_fontsize: int = 10,
    label_fontsize: int = 8,
    tick_fontsize: int = 8,
    change_xticks: bool = True,
    add_legend: bool = True,
) -> None:
    """Style an axes object.

    Sets the title and axis labels and optionally restyles the x tick labels.
    Hides the top and right spines, and optionally adds a legend to the right
    of the plot.

    Parameters
    ----------
    ax
        Axis object to style.
    title
        Figure title.
    legend_title
        Figure legend title.
    xlab
        Label for the x axis.
    ylab
        Label for the y axis.
    title_loc
        Position of the plot title (can be {'center', 'left', 'right'}).
    title_pad
        Padding of the plot title.
    title_fontsize
        Font size of the plot title.
    label_fontsize
        Font size of the axis labels and of the legend title.
    tick_fontsize
        Font size of the x tick labels (only applied when `change_xticks` is
        True) and of the legend entries.
    change_xticks
        If True, rotate the x tick labels by 30 degrees (right-aligned), apply
        `tick_fontsize` to them and hide the x tick marks.
    add_legend
        If True, draw a frameless legend to the right of the axes. The axes
        is then moved to [0.1, 0.3, 0.6, 0.55] (figure fraction) to make room.
    """
    ax.set_title(
        title, fontdict={"fontsize": title_fontsize}, pad=title_pad, loc=title_loc
    )
    ax.set_xlabel(xlab, fontsize=label_fontsize)
    ax.set_ylabel(ylab, fontsize=label_fontsize)

    if change_xticks:
        ax.set_xticklabels(
            ax.get_xticklabels(), fontsize=tick_fontsize, rotation=30, ha="right"
        )
        # Zero-length ticks: keep the labels but hide the tick marks.
        ax.get_xaxis().set_tick_params(length=0)

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    if add_legend:
        ax.legend(
            title=legend_title,
            loc="upper left",
            bbox_to_anchor=(1.2, 1),
            title_fontsize=label_fontsize,
            fontsize=tick_fontsize,
            frameon=False,
        )
        # Shrink the axes so the legend outside it stays within the figure.
        ax.set_position([0.1, 0.3, 0.6, 0.55])
예제 #6
0
def matrix(values: np.ndarray,
           row_labels: Sequence[str] = None,
           col_labels: Sequence[str] = None,
           row_seps: Union[int, Collection[int]] = None,
           col_seps: Union[int, Collection[int]] = None,
           cmap="RdBu",
           fontcolor_thresh=0.5,
           norm: plt.Normalize = None,
           text_len=4,
           omit_leading_zero=False,
           trailing_zeros=False,
           grid=True,
           angle_left=False,
           cbar=True,
           cbar_label: str = None,
           ax: plt.Axes = None,
           figsize: Tuple[int, int] = None,
           cellsize=0.65,
           title: str = None):
    """Draw ``values`` as an annotated heat map.

    Every non-NaN cell is labelled with its value, formatted by
    ``_format_value``. The text is black or white, whichever contrasts better
    with the cell colour.

    Args:
        values: 2-D array of cell values. NaN cells are left unlabelled.
        row_labels: Tick labels for the rows. Empty strings get no tick.
        col_labels: Tick labels for the columns. Empty strings get no tick.
        row_seps: Row index (or indices) before which a blank separator row
            is inserted (``np.insert`` semantics).
        col_seps: Same as ``row_seps``, for columns.
        cmap: Colormap name or object, resolved through ``get_cmap``.
        fontcolor_thresh: Cells whose background luma is below this value
            get white text; the others get black text.
        norm: Optional normalisation passed to ``matshow``.
        text_len, omit_leading_zero, trailing_zeros: Forwarded to
            ``_format_value`` for the cell text and the colorbar labels.
        grid: Draw white lines between cells.
        angle_left: Rotate the row labels by 40 degrees.
        cbar: Add a colorbar.
        cbar_label: Label of the colorbar.
        ax: Axes to draw into. A new figure is created when omitted.
        figsize: Size of the new figure. Defaults to ``cellsize`` per cell,
            made 20% wider when a colorbar is drawn.
        cellsize: Cell size used to derive the default ``figsize``.
        title: Optional axes title.
    """
    cmap = get_cmap(cmap)

    # Create figure if necessary.
    if ax is None:
        if figsize is None:
            # Note the extra width factor for the colorbar.
            figsize = (cellsize * values.shape[1] * (1.2 if cbar else 1),
                       cellsize * values.shape[0])
        ax = plt.figure(figsize=figsize).gca()

    # Set title if applicable.
    if title is not None:
        ax.set_title(title)

    # np.insert casts the NaN separator to the array's dtype. For integer
    # input that raises ("cannot convert float NaN to integer"); for bool it
    # silently becomes True. Promote to float first in both cases.
    if ((row_seps is not None or col_seps is not None)
            and not np.issubdtype(values.dtype, np.inexact)):
        values = values.astype(float)

    if row_seps is not None:
        values = np.insert(values, row_seps, np.nan, axis=0)
        if row_labels is not None:
            row_labels = np.insert(row_labels, row_seps, "")
    if col_seps is not None:
        values = np.insert(values, col_seps, np.nan, axis=1)
        if col_labels is not None:
            col_labels = np.insert(col_labels, col_seps, "")

    # Plot the heatmap.
    im = ax.matshow(values, cmap=cmap, norm=norm)

    # Plot the text annotations showing each cell's value.
    norm_values = im.norm(values)
    for row, col in product(range(values.shape[0]), range(values.shape[1])):
        val = values[row, col]
        if not np.isnan(val):
            # Pick the text colour from the perceived brightness (Rec. 601
            # luma) of the cell background.
            bg_color = cmap(norm_values[row, col])[:3]
            luma = (0.299 * bg_color[0] + 0.587 * bg_color[1]
                    + 0.114 * bg_color[2])
            color = "white" if luma < fontcolor_thresh else "black"

            # Plot cell text.
            annotation = _format_value(val, text_len, omit_leading_zero,
                                       trailing_zeros)
            ax.text(col,
                    row,
                    annotation,
                    ha="center",
                    va="center",
                    color=color)

    # Add ticks and labels; empty labels (e.g. separators) get no tick.
    if col_labels is None:
        ax.set_xticks([])
    else:
        col_labels = np.asarray(col_labels)
        labeled_cols = np.where(col_labels)[0]
        ax.set_xticks(labeled_cols)
        ax.set_xticklabels(col_labels[labeled_cols])
    if row_labels is None:
        ax.set_yticks([])
    else:
        row_labels = np.asarray(row_labels)
        labeled_rows = np.where(row_labels)[0]
        ax.set_yticks(labeled_rows)
        ax.set_yticklabels(row_labels[labeled_rows])

    ax.tick_params(which="major", bottom=False)

    plt.setp(ax.get_xticklabels(),
             rotation=40,
             ha="left",
             rotation_mode="anchor")

    # Turn off spines.
    for spine in ax.spines.values():
        spine.set_visible(False)

    # Rotate the left labels if applicable.
    if angle_left:
        plt.setp(ax.get_yticklabels(),
                 rotation=40,
                 ha="right",
                 rotation_mode="anchor")

    # Create the white grid if applicable.
    if grid:
        # Extra ticks required to avoid glitch.
        xticks = np.concatenate([[-0.56],
                                 np.arange(values.shape[1] + 1) - 0.5,
                                 [values.shape[1] - 0.44]])
        yticks = np.concatenate([[-0.56],
                                 np.arange(values.shape[0] + 1) - 0.5,
                                 [values.shape[0] - 0.44]])
        ax.set_xticks(xticks, minor=True)
        ax.set_yticks(yticks, minor=True)
        ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
        ax.tick_params(which="minor", bottom=False, top=False, left=False)

    # Create the colorbar if applicable.
    if cbar:
        bar = ax.figure.colorbar(im, ax=ax)
        bar.ax.set_ylabel(cbar_label)
        fmt = bar.ax.yaxis.get_major_formatter()
        if isinstance(fmt, FixedFormatter):
            # Re-format the pre-rendered tick strings with _format_value so
            # the colorbar matches the cell annotations.
            fmt.seq = [
                _format_value(_parse_tick_label(label), text_len,
                              omit_leading_zero, trailing_zeros)
                if label else "" for label in fmt.seq
            ]


def _parse_tick_label(label):
    """Turn a rendered tick string such as ``$2\\times10^{-1}$`` into a number.

    Steps:

    1. ``times`` becomes ``*`` and ``^`` becomes ``**``.
    2. The Unicode minus sign matplotlib uses for negative numbers is mapped
       back to ASCII ``-``, since ``eval`` rejects it.
    3. Lowercase letters, ``$``, backslashes and braces are stripped.
    4. The remainder is evaluated as a Python expression.
    """
    # NOTE(review): eval runs on tick text generated by matplotlib, not on
    # external input, but it is fragile (e.g. "1e-05" becomes "1-05", which
    # is a SyntaxError). A dedicated parser would be safer.
    expr = (label.replace("times", "*")
                 .replace("^", "**")
                 .replace("\N{MINUS SIGN}", "-"))
    return eval(re.sub(r"[a-z$\\{}]", "", expr))
예제 #7
0
    def plot(self,
             plot_these: List[np.ndarray],
             ax: plt.Axes = None,
             fout: str = None,
             poly=True,
             root: str = "",
             title: str = None,
             aspect: int = 3,
             ticks: List[int] = None,
             grid: bool = False):
        """
        Plot the lattice coordinates or lattice as an array

        Arguments:
            plot_these - up to 4 point sets. Arrays whose second dimension is
                2 or 3 are plotted as coordinates directly; anything else is
                first converted with ``self.array_to_coords``
            ax - an axes; if omitted, a new figure is created and shown
            fout - saves the plot's figure as a png with the name of `fout`
            poly - True will plot the initial polygon vertices (taken from
                ``plot_these[0]``)
            root - prefix (e.g. a directory ending in a path separator)
                prepended to the file name; defaults to the current directory
            title - this is the title of the plot
            aspect - this is the aspect ratio of the x and y axes
            ticks - [x, y] spacing of the major ticks, default [10, 5]
            grid - True draws a grid on the axes
        """
        assert len(
            plot_these) <= 4, "Max no. of plots on one axis reached, 4 or less"
        # None instead of a mutable list literal as the default.
        if ticks is None:
            ticks = [10, 5]

        # Remember whether we own the figure. `ax` is rebound below, so
        # re-testing it at the end (as before) could never trigger plt.show().
        created_ax = ax is None
        if created_ax:
            _, ax = plt.subplots()
        if grid:
            ax.grid(True)
        # tick_params also applies to ticks created after the locators change.
        ax.tick_params(labelsize=4)
        ax.xaxis.set_major_locator(MultipleLocator(ticks[0]))
        ax.yaxis.set_major_locator(MultipleLocator(ticks[1]))
        ax.set_xlabel("No. of nucleotides")
        ax.set_ylabel("No. of strands")

        point_style = itertools.cycle(["ko", "b.", "r.", "cP"])
        point_size = itertools.cycle([0.5, 2.5])
        for points in plot_these:
            if np.shape(points)[1] not in [2, 3]:  # if array
                nodes = self.array_to_coords(points)
            else:
                nodes = points
            # Lattice sites then crossover sites
            ax.plot(nodes[:, 0],
                    nodes[:, 1],
                    next(point_style),
                    ms=next(point_size),
                    alpha=0.25)

        if poly:
            self.plotPolygon(ax, plot_these[0], coords=True)
        if title:
            ax.set_title(f"{title}")

        # Act on `ax` and its figure, not on whatever is "current" in pyplot.
        ax.set_aspect(aspect)
        if fout:
            ax.figure.savefig(f"{root}{fout}.png", dpi=500)
        if created_ax:
            plt.show()
예제 #8
0
def cases_and_deaths(
    data: pd.DataFrame,
    dates: bool = False,
    ax: plt.Axes = None,
    smooth: bool = True,
    cases: str = "cases",
    deaths: str = "deaths",
    tight_layout=False,
    **kwargs,
) -> plt.Axes:
    """
    A simple chart showing observed new cases and deaths as vertical bars
    and, optionally, smoothed-out versions of these curves.

    Args:
        data:
            A dataframe with ["cases", "deaths"] columns. New values are
            obtained with ``diff``, so these are presumably cumulative totals.
        dates:
            If True, show dates instead of days in the x-axis.
        ax:
            An explicit matplotlib axes.
        smooth:
            If True, superimpose a plot of a smoothed-out version of the cases
            and deaths curves.
        cases:
        deaths:
            Name of the cases/deaths columns in the dataframe.
        tight_layout:
            If True, call ``tight_layout()`` on the resulting figure.
        **kwargs:
            Forwarded to ``DataFrame.plot.bar``. ``alpha`` defaults to 0.5.
            When ``ylim`` is not given, a lower limit is derived from the
            daily deaths.

    Returns:
        The axes the chart was drawn on.
    """

    if not dates:
        data = data.reset_index(drop=True)

    # Smoothed data
    col_names = {cases: _("Cases"), deaths: _("Deaths")}
    if smooth:
        from pydemic import fitting as fit

        smoothed = pd.DataFrame(
            {
                _("{} (smooth)").format(col_names[cases]):
                fit.smoothed_diff(data[cases]),
                _("{} (smooth)").format(col_names[deaths]):
                fit.smoothed_diff(data[deaths]),
            },
            index=data.index,
        )
        ax = smoothed.plot(legend=False, lw=2, ax=ax)

    # Prepare cases dataframe and plot it
    kwargs.setdefault("alpha", 0.5)
    new_cases = data.diff().fillna(0)
    new_cases = new_cases.rename(col_names, axis=1)

    if "ylim" not in kwargs:
        # Look the deaths column up by name so the column order of `data`
        # does not matter.
        new_deaths = new_cases[col_names[deaths]]
        positive = new_deaths[new_deaths > 0]
        # With no positive entries the mean would be NaN and int() would
        # raise; fall back to a lower limit of 10**0 == 1.
        exp = np.log10(positive).mean() if len(positive) else 0.0
        exp = min(10, int(exp / 2))
        kwargs["ylim"] = (10**exp, None)
    ax = new_cases.plot.bar(width=1.0, ax=ax, **kwargs)

    # Fix xticks: keep only every `periods`-th tick and label.
    periods = 7 if dates else 10
    xticks = ax.get_xticks()
    labels = ax.get_xticklabels()
    ax.set_xticks(xticks[::periods])
    ax.set_xticklabels(labels[::periods])
    ax.tick_params("x", rotation=0)
    # NOTE(review): this overrides the lower limit of any `ylim` computed or
    # passed above; confirm that is intended.
    ax.set_ylim(1, None)
    if tight_layout:
        fig = ax.get_figure()
        fig.tight_layout()
    return ax