示例#1
0
def align_hilo_ylim(ax1: plt.Axes, ax2: plt.Axes):
    """Rescale the y-limits of two axes so that their visible ticks line up.

    Each axis keeps its own tick values, but its limits are recomputed so that
    the lowest visible tick sits at the same relative height on both axes, and
    likewise the highest visible tick.  The padding used below the lowest and
    above the highest tick is the larger of the two axes' current paddings.

    If either axis has fewer than two distinct visible ticks, or the combined
    padding leaves no room for the tick span, the limits cannot be aligned and
    both axes are left unchanged.

    :param ax1: First axes (e.g. the left axis of a ``twinx`` pair).
    :param ax2: Second axes.
    """
    (a1, b1), (a2, b2) = ax1.get_ylim(), ax2.get_ylim()

    # Ticks that currently lie inside the view limits.
    t1 = np.asarray([y for y in ax1.get_yticks() if a1 <= y <= b1], dtype=float)
    t2 = np.asarray([y for y in ax2.get_yticks() if a2 <= y <= b2], dtype=float)
    if t1.size == 0 or t2.size == 0:
        return
    span1 = t1.max() - t1.min()
    span2 = t2.max() - t2.min()
    if span1 <= 0 or span2 <= 0:
        # A single tick gives no scale to align against.
        return

    # Relative position (0 = bottom, 1 = top) of each visible tick.
    r1 = (t1 - a1) / (b1 - a1)
    r2 = (t2 - a2) / (b2 - a2)

    # Shared fraction of axes height below the lowest / above the highest tick.
    lo = max(r1.min(), r2.min())
    hi = max(1 - r1.max(), 1 - r2.max())
    inner = 1 - hi - lo
    if inner <= 0:
        return

    # Data units per unit of axes height, chosen so the tick span fills `inner`.
    f1 = span1 / inner
    f2 = span2 / inner
    ax1.set_ylim(t1.min() - lo * f1, t1.max() + hi * f1)
    ax2.set_ylim(t2.min() - lo * f2, t2.max() + hi * f2)
示例#2
0
def align_twinx_ticks(ax_left: plt.Axes, ax_right: plt.Axes) -> np.ndarray:
    """
    Return tick values for the right axis that match the left axis' ticks.

    The left axis' y-ticks are passed to ``linear_mapping`` together with the
    left y-limits (source range) and the right y-limits (target range).

    There's no easy way of aligning ticks nor a good general solution; this is
    a simple linear approach.
    """
    source_range = ax_left.get_ylim()
    target_range = ax_right.get_ylim()
    left_ticks = ax_left.get_yticks()
    return linear_mapping(source_range, target_range, left_ticks)
示例#3
0
def _add_margins(ax: plt.Axes, plot_data: np.ndarray, cutoff_lo: float, cutoff_hi: float, orient: str, margin: float):
    def _quantized_abs_ceil(x, q=0.5):
        return np.ceil(np.abs(x) / q) * q * np.sign(x)

    if orient == 'v':
        old_extents, lim_setter = ax.get_ylim(), ax.set_ylim
        ax.set_xlim(*list(map(_quantized_abs_ceil, ax.get_xlim())))
    else:
        old_extents, lim_setter = ax.get_xlim(), ax.set_xlim
        ax.set_ylim(*list(map(_quantized_abs_ceil, ax.get_ylim())))

    if np.min(plot_data) < cutoff_lo:
        lim_setter([old_extents[0] - margin * np.diff(old_extents), None])
    if np.max(plot_data) > cutoff_hi:
        lim_setter([None, old_extents[1] + margin * np.diff(old_extents)])
示例#4
0
def _annotate_hdi(ax: plt.Axes, p: np.array, theta: np.array, hdi_mass: float):
    """Draw the HDI of one grid distribution on ``ax``.

    Draws a horizontal bar at the height returned by ``hdi_of_grid`` between
    the smallest and largest ``theta`` inside the HDI, vertical legs from 0 up
    to that height, the two end-point values (rounded to 3 decimals) above the
    bar, and a centred "<mass>% HDI" label at half the current y-limit.
    """
    hdi_info = hdi_of_grid(p, cred_mass=hdi_mass)
    height = hdi_info["height"]
    in_hdi = theta[hdi_info["indices"]]
    xmin, xmax = in_hdi.min(), in_hdi.max()
    ax.hlines(y=height, xmin=xmin, xmax=xmax)
    ax.vlines(x=xmin, ymin=0, ymax=height)
    ax.vlines(x=xmax, ymin=0, ymax=height)
    for x_edge in (xmin, xmax):
        ax.text(x_edge,
                height * 1.3,
                round(x_edge, 3),
                verticalalignment='bottom',
                horizontalalignment='center')
    ax.text((xmin + xmax) / 2,
            ax.get_ylim()[1] * 0.5,
            f"{hdi_mass * 100:.0f}% HDI",
            verticalalignment='center',
            horizontalalignment='center')


def plot_hdi_info(prior_plot: plt.Axes, posterior_plot: plt.Axes,
                  p_theta: np.array, p_theta_given_data: np.array,
                  theta: np.array, hdi_mass: float):
    """Annotate the prior and posterior plots with their HDI.

    :param prior_plot: Axes showing the prior ``p_theta``.
    :param posterior_plot: Axes showing the posterior ``p_theta_given_data``.
    :param p_theta: Prior probabilities on the ``theta`` grid.
    :param p_theta_given_data: Posterior probabilities on the ``theta`` grid.
    :param theta: Grid of parameter values.
    :param hdi_mass: Credible mass of the HDI (e.g. 0.95).
    """
    _annotate_hdi(prior_plot, p_theta, theta, hdi_mass)
    _annotate_hdi(posterior_plot, p_theta_given_data, theta, hdi_mass)
示例#5
0
def plot_summary_stats(ax: plt.Axes,
                       name: str,
                       accepted_s: [float],
                       s_obs: float,
                       s_hat: float,
                       dim=0) -> plt.Axes:
    """
    DESCRIPTION
    plot values of a summary statistic generated by accepted parameter values during sampling,
    with vertical markers for the observed and the fitted-model values.

    PARAMETERS
    ax (plt.Axes) - axes to plot on.
    name (str) - name of summary statistic.
    accepted_s ([float]) - values of summary statistic from sampling
    s_obs (float) - summary statistic value from true model (green marker).
    s_hat (float) - summary statistic value of fitted model (orange marker).
    dim (int) - the legend is only drawn when dim == 0 (presumably so a grid of
        such plots gets a single legend).

    RETURNS
    plt.Axes - axes on which plot was made
    """
    ax.hist(accepted_s)
    # Markers run from 0 to the top of the y-range set by the histogram.
    ymax = ax.get_ylim()[1]
    ax.vlines(s_obs, ymin=0, ymax=ymax, colors="green", label="s_obs")
    ax.vlines(s_hat, ymin=0, ymax=ymax, colors="orange", label="From Fitted")

    ax.set_xlabel(name)
    ax.set_title("Accepted {}".format(name))
    if dim == 0:
        ax.legend()
    ax.margins(0)

    return ax
示例#6
0
    def __plot_bunches(self,
                       fig: plt.Figure,
                       ax: plt.Axes,
                       point: FiniteMetricVertex,
                       name: str = "u") -> None:
        """
        Plot all points and highlight the bunch of the given point on
        the provided figure/axes.

        Draws every vertex in black, the point itself as a red star annotated
        with ``name``, a circle around the point of radius ``self.p[point][i][1]``
        for each i >= 1 whose entry is not None (with its witness annotated),
        and the members of ``self.B[point]`` as lime stars.

        :param fig: The matplotlib figure to plot on.
        :param ax: The matplotlib axes to plot on.
        :param point: The vertex whose bunches we wish to plot.
        :param name: The name to use to label the vertex/bunches.
        """
        ax.cla()

        # Plot all points in black
        ax.scatter([v.i for v in self.vertices], [v.j for v in self.vertices],
                   s=4,
                   color="black",
                   marker=".",
                   label="Points")

        # Plot and label the point itself
        ax.scatter([point.i], [point.j],
                   s=12,
                   color="red",
                   marker="*",
                   label=name)
        ax.annotate(name, (point.i, point.j), color="red")

        # Fix the current limits (set_*lim turns autoscaling off) so the
        # circles added below do not rescale the view.
        ax.set_xlim(*ax.get_xlim())
        ax.set_ylim(*ax.get_ylim())

        # For the current point, mark and label its p_i s and add circles;
        # each non-None p_i[i] is a (witness vertex, radius) pair.
        p_i = [self.p[point][i] for i in range(self.k)]
        for i in range(1, self.k):
            if p_i[i] is None:
                continue
            ax.annotate("p_{}({})".format(i, name), (p_i[i][0].i, p_i[i][0].j),
                        xytext=(5, 5),
                        textcoords="offset pixels",
                        color="violet")
            circ = plt.Circle((point.i, point.j), p_i[i][1], fill=False)
            ax.add_patch(circ)

        # Plot the points in the bunch
        B = list(self.B[point])
        ax.scatter([w.i for w in B], [w.j for w in B],
                   s=12,
                   color="lime",
                   marker="*",
                   label="B({})".format(name))

        ax.set_title("Bunch B({})".format(name))
        ax.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left')
        # Lay out the figure we were given, not whichever figure is current.
        fig.tight_layout()
        fig.show()
def ageify_axis(ax: plt.Axes):
    """
    Restyle an axis to follow the AGE group plotting standard.

    All four spines are made visible. Ticks on ``ax`` point both inward and
    outward ("inout"), and its top/right tick labels are hidden. Two twin axes
    are created so that the right and top edges carry inward-pointing ticks
    without labels; their limits are copied from ``ax`` once, at call time.

    Side effect: the global ``plt.rcParams`` are changed so that normal text
    and math text use serif fonts (for publications).

    Parameters
    ----------
    ax : plt.Axes
        Axis that should be restyled.

    Returns
    -------
    ax : plt.Axes
        The original axis object.
    ax_top : plt.Axes
        Twin axis created with ``ax.twiny()``, representing the top axis.
    ax_right : plt.Axes
        Twin axis created with ``ax.twinx()``, representing the right axis.
    """
    for side in ("top", "bottom", "left", "right"):
        ax.spines[side].set_visible(True)

    right_twin = ax.twinx()
    top_twin = ax.twiny()

    ax.tick_params(direction="inout", labeltop=False, labelright=False,
                   bottom=True, left=True)

    right_twin.tick_params(right=True, direction="in", labelright=False)
    right_twin.set_ylim(ax.get_ylim())

    top_twin.tick_params(top=True, direction="in", labeltop=False)
    top_twin.set_xlim(ax.get_xlim())

    plt.rcParams['font.family'] = 'serif'
    plt.rcParams['mathtext.fontset'] = 'dejavuserif'

    return (ax, top_twin, right_twin)
示例#8
0
def get_axes_limits(ax: Axes, ax_crs: CRS, crs: CRS = SphericalEarth):
    """ Get the limits of the window covered by an Axes in another coordinate
    system.

    Each edge of the Axes' current view (in ``ax_crs`` coordinates) is
    transformed into ``crs`` and searched with ``scipy.optimize.fminbound`` for
    its extreme value. The bottom edge gives ``ymin``, the top edge ``ymax``,
    the left edge ``xmin`` and the right edge ``xmax``.

    NOTE(review): fminbound is a bounded local minimiser, so a transformed
    edge that is not unimodal could yield a local rather than a global extreme.

    :return: ``(xmin, xmax, ymin, ymax)`` in ``crs`` coordinates.
    """
    xl, xr = ax.get_xlim()
    yb, yt = ax.get_ylim()

    # Minimize bottom spine
    x_ = scipy.optimize.fminbound(lambda x: ax_crs.transform(crs, x, yb)[1],
                                  xl, xr)
    ymin = ax_crs.transform(crs, x_, yb)[1]

    # Maximize top spine
    x_ = scipy.optimize.fminbound(lambda x: -ax_crs.transform(crs, x, yt)[1],
                                  xl, xr)
    ymax = ax_crs.transform(crs, x_, yt)[1]

    # Minimize left spine
    y_ = scipy.optimize.fminbound(lambda y: ax_crs.transform(crs, xl, y)[0],
                                  yb, yt)
    xmin = ax_crs.transform(crs, xl, y_)[0]

    # Maximize right spine
    y_ = scipy.optimize.fminbound(lambda y: -ax_crs.transform(crs, xr, y)[0],
                                  yb, yt)
    xmax = ax_crs.transform(crs, xr, y_)[0]
    return xmin, xmax, ymin, ymax
示例#9
0
def _get_aspect_ratio(ax: plt.Axes) -> float:
    minx, maxx = ax.get_xlim()
    miny, maxy = ax.get_ylim()
    data_width, data_height = maxx - minx, maxy - miny
    if abs(data_height) > 1e-9:
        return data_width / data_height
    return 1.0
示例#10
0
def plot_marginals(
    infd: az.InferenceData,
    var_name: str,
    ax: plt.Axes,
    true_values: pd.Series = None,
):
    """Plot marginal ability parameters against true values.

    For each ``skater_name`` a horizontal line spans the 10%-90% posterior
    quantiles of ``var_name`` (pooled over chains and draws).  Rows are sorted
    by their 90% quantile and spread evenly over the current y-limits, with the
    skater names as y tick labels.

    :param infd: Inference data whose ``posterior`` holds ``var_name`` with a
        ``skater_name`` dimension.
    :param var_name: Name of the posterior variable to summarise.
    :param ax: Axes to draw on.
    :param true_values: Optional true values, drawn as red ticks.  Assumed to be
        indexed by skater name (TODO confirm against callers).
    :return: ``ax``.
    """
    qs = infd.posterior[[var_name]].quantile([0.1, 0.9], dim=("chain", "draw"))
    qs = qs.sortby(qs[var_name].sel(quantile=0.9))
    # `sizes` maps dimension name -> length; mapping access via `dims` is deprecated.
    y = np.linspace(*ax.get_ylim(), qs.sizes["skater_name"])
    ax.set_yticks(y)
    ax.set_yticklabels(qs["skater_name"].values)
    ax.hlines(
        y,
        qs[var_name].sel(quantile=0.1),
        qs[var_name].sel(quantile=0.9),
        color="tab:blue",
        label="90% marginal interval",
    )
    if true_values is not None:
        # `qs` was reordered by sortby; a (dims, Series) tuple is attached
        # positionally, so align the Series to the sorted skater order first.
        truth = true_values.reindex(qs["skater_name"].values).to_numpy()
        qs["truth"] = (("skater_name", ), truth)
        q10 = qs[var_name].sel(quantile=0.1)
        q90 = qs[var_name].sel(quantile=0.9)
        # Parenthesised: `&` binds tighter than `<` / `>`.
        qs["truth_in_interval"] = (qs["truth"] < q90) & (qs["truth"] > q10)
        ax.scatter(qs["truth"],
                   y,
                   marker="|",
                   color="red",
                   label="True ability")
    ax.legend(frameon=False)
    return ax
示例#11
0
def plot_marginals(
    infd: az.InferenceData,
    var_name,
    ax: plt.Axes,
    true_values: pd.Series = None,
):
    """Plot marginal ability parameters against true values.

    The posterior draws of ``var_name`` are converted to a table with one
    column per parameter entry.  Each entry gets a horizontal line from its
    10% to its 90% quantile.  Entries are ordered by the 90% quantile, spread
    evenly over the current y-limits, and labelled with their index.  When
    ``true_values`` is given (aligned by index), it is drawn as red ticks.
    Returns ``ax``.
    """
    draws = infd.posterior[var_name].to_series().unstack()
    intervals = draws.quantile([0.1, 0.9]).T
    intervals["truth"] = true_values
    intervals["truth_in_interval"] = (
        (intervals["truth"] > intervals[0.1]) & (intervals["truth"] < intervals[0.9])
    )
    intervals = intervals.sort_values(0.9)

    lower, upper = intervals[0.1], intervals[0.9]
    positions = np.linspace(*ax.get_ylim(), len(intervals))
    ax.set_yticks(positions)
    ax.set_yticklabels(intervals.index)
    ax.hlines(positions, lower, upper, color="tab:blue", label="90% marginal interval")

    if true_values is not None:
        ax.scatter(intervals["truth"], positions, marker="|", color="red",
                   label="True ability")
    ax.legend(frameon=False)
    return ax
示例#12
0
def append_loc_to_fig(ax: plt.Axes, dt_list: list, label: str = "g") -> dict:
    """
    Mark time locations (e.g. arterial blood gas samples) with vertical lines.

    Every datetime in ``dt_list`` gets a blue vertical line and a blue text tag
    placed at the top of the current y-range. The tag is ``label`` followed
    by the 1-based position (``g1``, ``g2``, ...).

    Parameters
    ----------
    ax : plt.Axes
        The axis to draw on.
    dt_list : list
        Datetime values to mark.
    label : str, optional (default is 'g')
        Prefix of each tag.

    Returns
    -------
    dict
        Maps the 0-based position in ``dt_list`` to the matplotlib date number
        of that line (note: keys are 0-based while the tags are 1-based).
    """
    locations = {}
    for position, date_num in enumerate(mdates.date2num(dt_list)):
        tag = label + str(position + 1)
        ax.axvline(date_num, color="tab:blue")
        top = ax.get_ylim()[1]
        ax.text(date_num, top, tag, color="tab:blue")
        locations[position] = date_num
    return locations
示例#13
0
文件: plt2.py 项目: yulkang/pylabyk
def sameaxes(ax: Union[AxesArray, GridAxes],
             ax0: plt.Axes = None,
             xy='xy',
             lim=None):
    """
    Match the chosen limits of axes in ax to ax0's (if given) or the max range.
    Also consider: ax1.get_shared_x_axes().join(ax1, ax2)
    Optionally followed by ax1.set_xticklabels([]); ax2.autoscale()
    See: https://stackoverflow.com/a/42974975/2565317
    :param ax: np.ndarray (as from subplotRCs) or list of axes.
    :param ax0: a scalar axes to match limits to. if None (default),
    match the maximum range among axes in ax.
    :param xy: 'x'|'y'|'xy'(default)
    :param lim: explicit [min, max] applied to every axes in ax; takes
    precedence over ax0, and the same pair is used for each letter of xy.
    :return: [[min, max]] of limits. If xy='xy', contains two pairs.
    """
    if isinstance(ax, np.ndarray) or isinstance(ax, GridAxes):
        ax = ax.flatten()

    def cat_lims(lims):
        # Stack (min, max) pairs into an (n_axes, 2) array.
        return np.concatenate([np.array(v1).reshape(1, 2) for v1 in lims])

    lims_res = []
    for xy1 in xy:
        if lim is not None:
            lims0 = lim
        elif ax0 is not None:
            lims0 = ax0.get_xlim() if xy1 == 'x' else ax0.get_ylim()
        else:
            if xy1 == 'x':
                lims = cat_lims([ax1.get_xlim() for ax1 in ax])
                try:
                    is_inverted = ax[0].get_xaxis().get_inverted()
                except AttributeError:
                    # Older matplotlib: only the Axes-level query exists.
                    is_inverted = ax[0].xaxis_inverted()
            else:
                lims = cat_lims([ax1.get_ylim() for ax1 in ax])
                try:
                    is_inverted = ax[0].get_yaxis().get_inverted()
                except AttributeError:
                    is_inverted = ax[0].yaxis_inverted()
            # Union of all ranges, following the first axes' direction.
            if is_inverted:
                lims0 = [np.max(lims[:, 0]), np.min(lims[:, 1])]
            else:
                lims0 = [np.min(lims[:, 0]), np.max(lims[:, 1])]
        if xy1 == 'x':
            for ax1 in ax:
                ax1.set_xlim(lims0)
        else:
            for ax1 in ax:
                ax1.set_ylim(lims0)
        lims_res.append(lims0)
    return lims_res
示例#14
0
def get_axes_extent(ax: Axes, ax_crs: CRS, crs: CRS=SphericalEarth):
    """ Get the extent of an Axes in geographical (or other) coordinates.

    The four corners of the Axes' current view (given in ``ax_crs``) are
    transformed into ``crs``. They are returned as a ``Polygon`` ordered
    lower-left, lower-right, upper-right, upper-left.
    """
    x_left, x_right = ax.get_xlim()
    y_bottom, y_top = ax.get_ylim()
    corners = ((x_left, y_bottom), (x_right, y_bottom),
               (x_right, y_top), (x_left, y_top))
    vertices = [ax_crs.transform(crs, x, y) for x, y in corners]
    return Polygon(vertices, crs=crs)
示例#15
0
def get_axes_extent(ax: Axes, ax_crs: CRS, crs: CRS = SphericalEarth):
    """ Get the extent of an Axes in geographical (or other) coordinates.

    Transforms the corners of the current view from ``ax_crs`` into ``crs``,
    walking counter-clockwise from the lower-left corner, and wraps them in a
    ``Polygon`` tagged with ``crs``.
    """
    (xl, xr), (yb, yt) = ax.get_xlim(), ax.get_ylim()
    vertices = []
    for x, y in zip((xl, xr, xr, xl), (yb, yb, yt, yt)):
        vertices.append(ax_crs.transform(crs, x, y))
    return Polygon(vertices, crs=crs)
示例#16
0
def center_axis(axes: plt.Axes, which='y'):
    """Make an axis range symmetric about zero.

    The new limits are ``(-m, m)``, where ``m`` is the larger absolute value of
    the current limits, so the current view stays visible.  The new limits are
    always ascending, so an inverted axis becomes non-inverted.

    :param axes: Axes to modify in place.
    :param which: ``'y'`` (default), ``'x'`` or ``'both'``.
    :raises ValueError: if ``which`` is not one of the values above.
    """
    if which not in ('x', 'y', 'both'):
        raise ValueError("which must be 'x', 'y' or 'both'")
    if which in ('y', 'both'):
        max_abs = np.max(np.abs(axes.get_ylim()))
        axes.set_ylim(-max_abs, max_abs)
    if which in ('x', 'both'):
        max_abs = np.max(np.abs(axes.get_xlim()))
        axes.set_xlim(-max_abs, max_abs)
示例#17
0
    def __plot_p_i(self,
                   fig: plt.Figure,
                   ax: plt.Axes,
                   point: FiniteMetricVertex,
                   name: str = "u") -> None:
        """
        Plot all points and highlight the witnesses p_i for the given point
        along with corresponding rings on the given figure and axes.

        Each set in ``self.A`` is scattered as its own series labelled
        ``A_i``. The point is drawn as a red star annotated with ``name``.
        For each i >= 1 whose entry is not None, the witness is annotated and
        a circle of radius ``self.p[point][i][1]`` is drawn around the point.

        :param fig: The matplotlib figure to plot on.
        :param ax: The matplotlib axes to plot on.
        :param point: The vertex whose witnesses/rings we wish to plot.
        :param name: The name to use to label the vertex/bunches.
        """
        ax.cla()

        # Plot all points and color by set A_i
        for i, a_i in enumerate(self.A):
            ax.scatter([v.i for v in a_i], [v.j for v in a_i],
                       s=8,
                       marker="o",
                       label="A_{}".format(i))

        # Plot and label the point itself
        ax.scatter([point.i], [point.j],
                   s=12,
                   color="red",
                   marker="*",
                   label=name)
        ax.annotate(name, (point.i, point.j), color="red")

        # Fix the current limits (set_*lim turns autoscaling off) so the
        # circles added below do not rescale the view.
        ax.set_xlim(*ax.get_xlim())
        ax.set_ylim(*ax.get_ylim())

        # For the current point, mark and label its p_i s and add circles;
        # each non-None p_i[i] is a (witness vertex, radius) pair.
        p_i = [self.p[point][i] for i in range(self.k)]
        for i in range(1, self.k):
            if p_i[i] is None:
                continue
            ax.annotate("p_{}({})".format(i, name), (p_i[i][0].i, p_i[i][0].j),
                        xytext=(5, 5),
                        textcoords="offset pixels",
                        color="violet")
            circ = plt.Circle((point.i, point.j), p_i[i][1], fill=False)
            ax.add_patch(circ)

        ax.set_title("Witnesses p_i({}) and rings.".format(name))
        ax.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left')
        # Lay out the figure we were given, not whichever figure is current.
        fig.tight_layout()
        fig.show()
示例#18
0
def make_dual_axis(ax: plt.Axes = None, axis='x', unit='nm', minor_ticks=True):
    """Add a twin axis that labels the same positions in reciprocal units.

    Tick labels on the twin show ``1e7 / value`` (the nm <-> cm^-1 conversion).
    Major/minor tick spacings are chosen from round values in the converted
    units: the coarsest major spacing that gives more than 4 ticks, or the
    finest one if none does.

    :param ax: Axes to decorate; defaults to ``plt.gca()``.
    :param axis: ``'x'`` (twin via ``twiny``) or ``'y'`` (twin via ``twinx``).
    :param unit: ``'nm'`` labels the twin "Wavelengths [nm]", ``'cm'`` labels it
        "Wavenumber [1/cm]"; any other value leaves it unlabelled.
    :param minor_ticks: If true, turn on minor ticks of ``ax`` itself.
    :raises ValueError: if ``axis`` is neither ``'x'`` nor ``'y'``.
    """
    if ax is None:
        ax = plt.gca()

    if axis == 'x':
        pseudo_ax = ax.twiny()
        limits = ax.get_xlim()
        pseudo_ax.set_xlim(limits)
        sub_axis = pseudo_ax.xaxis
    elif axis == 'y':
        pseudo_ax = ax.twinx()
        limits = ax.get_ylim()
        pseudo_ax.set_ylim(limits)
        sub_axis = pseudo_ax.yaxis
    else:
        raise ValueError('axis must be either x or y.')

    # Visible range in converted units, ascending whether or not the axis is
    # inverted (np.arange below needs start < stop).
    lower, upper = sorted(1e7 / np.array(limits))

    def conv(x, y):
        # Formatter: reciprocal value; 0 has no reciprocal.
        if x == 0:
            return ''
        return '%.0f' % (1e7 / x)

    sub_axis.set_major_formatter(plt.FuncFormatter(conv))

    major = [1000, 500, 200, 100, 50]
    minor = [200, 100, 50, 25, 10]
    for step, minor_step in zip(major, minor):
        if math.ceil(upper / step) - math.ceil(lower / step) > 4:
            break
    # Without a break, step/minor_step keep the finest spacing (50/10).
    ticks = np.arange(math.ceil(lower / step) * step,
                      math.ceil(upper / step) * step,
                      step)
    min_ticks = np.arange(math.floor(lower / minor_step) * minor_step,
                          math.floor(upper / minor_step) * minor_step,
                          minor_step)

    sub_axis.set_ticks(1e7 / ticks)
    sub_axis.set_ticks(1e7 / min_ticks, minor=True)
    if minor_ticks:
        ax.minorticks_on()
    # Axis.set_label only sets the artist (legend) label; set_label_text
    # sets the visible axis label.
    if unit == 'nm':
        sub_axis.set_label_text('Wavelengths [nm]')
    elif unit == 'cm':
        sub_axis.set_label_text('Wavenumber [1/cm]')
示例#19
0
def tune_axes(
    ax: plt.Axes,
    axr: plt.Axes,
    variable: str,
    binning: Tuple[int, float, float],
    logy: bool = False,
    linscale: float = 1.4,
    logscale: float = 100,
) -> None:
    """tune up the axes properties

    Labels the axes via ``set_labels`` (passing the bin width rounded to two
    decimals) and pins both axes' x-range to the binning range. The top of
    the main axes' y-range is then multiplied by ``linscale``, with the bottom
    at 0. When ``logy`` is set, the main axes are switched to log scale, the
    top is multiplied by ``logscale`` and the bottom is set to 10.

    Parameters
    ----------
    ax : :obj:`matplotlib.axes.Axes`
        the main stack axes
    axr : :obj:`matplotlib.axes.Axes`
        the ratio axes
    variable : str
        the name for the variable that is histogrammed
    binning : tuple(int, float, float)
        the number of bins and the start and stop on the x-axis
    logy : bool
        set the yscale to log
    linscale : float
        the factor to scale up the y-axis when linear
    logscale : float
        the factor to scale up the y-axis when log
    """
    n_bins, x_start, x_stop = binning
    bin_width = round((x_stop - x_start) / n_bins, 2)
    set_labels(ax, axr, variable, width=bin_width)
    for axes in (ax, axr):
        axes.set_xlim([x_start, x_stop])
    if not logy:
        ax.set_ylim([0, ax.get_ylim()[1] * linscale])
        return
    ax.set_yscale("log")
    ax.set_ylim([10, ax.get_ylim()[1] * logscale])
示例#20
0
def plot_parameter_posterior(ax: plt.Axes,
                             name: str,
                             accepted_parameter: [float],
                             predicted_val: float,
                             prior: "stats.Distribution",
                             dim=0,
                             weights=None) -> plt.Axes:
    """
    DESCRIPTION
    plot posterior of a parameter: the prior pdf, a density histogram of the
    accepted values, a (weighted) Gaussian KDE of them and a vertical line at
    the predicted value.

    PARAMETERS
    ax (plt.Axes) - axes to plot on.
    name (str) - name of parameter.
    accepted_parameter ([float]) - values of parameter which were accepted during sampling
        (a list or a 1-D array).
    predicted_val (float) - predicted value for parameter (likely mean of `accepted_parameter`)
    prior (stats.Distribution) - prior used when sampling for parameter; must provide
        `ppf` and `pdf` (e.g. a frozen scipy.stats distribution).
    dim (int) - the legend is only drawn when dim == 0.
    weights ([float]) - optional KDE weights, one per accepted value; uniform when
        omitted or empty.

    RETURNS
    plt.Axes - axes on which plot was made
    """
    values = np.asarray(accepted_parameter, dtype=float)
    # Explicit None/length check: truthiness of an ndarray is ambiguous.
    if weights is None or len(weights) == 0:
        weights = [1 / len(accepted_parameter)] * len(accepted_parameter)

    # plot prior over a grid covering both the samples and the prior's 1%-99% range
    x = np.linspace(min(values.min(), prior.ppf(.01)),
                    max(values.max(), prior.ppf(.99)), 100)
    ax.plot(x, prior.pdf(x), "k-", lw=2, label='Prior')

    # plot accepted points
    ax.hist(accepted_parameter, density=True)

    # plot smooth posterior (ie KDE); `stats.kde` is a deprecated namespace
    density = stats.gaussian_kde(values, weights=weights)
    ax.plot(x, density(x), "--", lw=2, c="orange", label="Posterior KDE")

    ymax = ax.get_ylim()[1]
    ax.vlines(predicted_val,
              ymin=0,
              ymax=ymax,
              colors="orange",
              label="Prediction")
    ax.set_xlabel(name)
    ax.set_title("Posterior for {}".format(name))
    if dim == 0:
        ax.legend()
    ax.margins(0)

    return ax
示例#21
0
def add_lines(axes: plt.Axes,
              xs=None,
              ys=None,
              colors=None,
              **kwargs) -> plt.Axes:
    """
    Add vertical (at ``xs``) and/or horizontal (at ``ys``) lines to a chart.

    Vertical lines span the current y-limits and horizontal lines the current
    x-limits; those limits are re-applied afterwards so the view does not
    grow. A single value or a string is treated as a one-element list.

    Args:
        axes: axes to draw on
        xs: x position(s) of vertical lines
        ys: y position(s) of horizontal lines
        colors: colors cycled over all lines, vertical lines first, then
            horizontal (default: darkgreen, darkorange, darkred)
        **kwargs: passed to ``vlines`` / ``hlines``

    Returns:
        plt.Axes
    """
    palette = ['darkgreen', 'darkorange', 'darkred'] if colors is None else colors

    def _listify(values):
        if isinstance(values, str) or not hasattr(values, '__iter__'):
            return [values]
        return values

    drawn = 0
    if xs is not None:
        bottom, top = axes.get_ylim()
        for x in _listify(xs):
            axes.vlines(x=x, ymin=bottom, ymax=top,
                        colors=palette[drawn % len(palette)], **kwargs)
            drawn += 1
        axes.set_ylim(ymin=bottom, ymax=top)

    if ys is not None:
        left, right = axes.get_xlim()
        for y in _listify(ys):
            axes.hlines(y=y, xmin=left, xmax=right,
                        colors=palette[drawn % len(palette)], **kwargs)
            drawn += 1
        axes.set_xlim(xmin=left, xmax=right)

    return axes
示例#22
0
def flip_axis(ax: plt.Axes, axis: str = "x") -> None:
    """Reverse one axis so that it runs in the opposite direction.

    The current limits of the chosen axis are swapped in place.

    Parameters
    ----------
    ax : Axes
        Axes object whose axis should be flipped.
    axis : str, optional
        ``"x"`` (default) or ``"y"``, case-insensitive.

    Raises
    ------
    ValueError
        If ``axis`` is neither "x" nor "y".
    """
    key = axis.lower()
    if key not in ("x", "y"):
        raise ValueError("`axis` must be 'x' or 'y'")
    if key == "x":
        getter, setter = ax.get_xlim, ax.set_xlim
    else:
        getter, setter = ax.get_ylim, ax.set_ylim
    start, end = getter()
    setter((end, start))
示例#23
0
    def _plot_ax_outliers(axes: plt.Axes, ax_data: pd.Series, extents: np.ndarray):
        """Draw outlier markers on one axes.

        Relies on names from the enclosing scope: ``plotter``, ``cutoff_lo``,
        ``cutoff_hi``, ``orient``, ``padding``, ``fmt``, ``group_names``,
        ``_plot_outliers`` and ``_plot_group_outliers``.

        For a KDE plot (``plotter == sns.kdeplot``), the values strictly below
        ``cutoff_lo`` or above ``cutoff_hi`` are passed to ``_plot_outliers``.
        ``group`` is set to half the current y-span (the length-1 array from
        ``np.diff``), and ``axes`` is returned.  Otherwise outliers are drawn
        via ``_plot_group_outliers``, once per group name when there are
        several, or in a single call otherwise, and ``None`` is returned.
        """
        if plotter == sns.kdeplot:
            half_span = .5 * np.diff(axes.get_ylim())
            values = ax_data.values
            outside = np.logical_or(cutoff_lo > values, values > cutoff_hi)
            _plot_outliers(axes, values[outside], orient=orient, group=half_span,
                           padding=padding, plot_extents=extents, fmt=fmt)
            return axes

        if group_names and len(group_names) != 1:
            for group_idx, group_name in enumerate(group_names):
                _plot_group_outliers(ax_data, extents, group_idx=group_idx,
                                     group_name=group_name, axes=axes)
            return None

        _plot_group_outliers(ax_data, extents, axes=axes)
        return None
示例#24
0
def make_dual_axis(ax: plt.Axes = None, axis='x', unit='nm', minor_ticks=True):
    """Add a twin axis that labels the same positions in reciprocal units.

    Tick labels on the twin show ``1e7 / value`` (the nm <-> cm^-1 conversion).
    Major/minor tick spacings are chosen from round values in the converted
    units: the coarsest major spacing that gives more than 4 ticks, or the
    finest one if none does.

    :param ax: Axes to decorate; defaults to ``plt.gca()``.
    :param axis: ``'x'`` (twin via ``twiny``) or ``'y'`` (twin via ``twinx``).
    :param unit: ``'nm'`` labels the twin "Wavelengths [nm]", ``'cm'`` labels it
        "Wavenumber [1/cm]"; any other value leaves it unlabelled.
    :param minor_ticks: If true, turn on minor ticks of ``ax`` itself.
    :raises ValueError: if ``axis`` is neither ``'x'`` nor ``'y'``.
    """
    if ax is None:
        ax = plt.gca()

    if axis == 'x':
        pseudo_ax = ax.twiny()
        limits = ax.get_xlim()
        pseudo_ax.set_xlim(limits)
        sub_axis = pseudo_ax.xaxis
    elif axis == 'y':
        pseudo_ax = ax.twinx()
        limits = ax.get_ylim()
        pseudo_ax.set_ylim(limits)
        sub_axis = pseudo_ax.yaxis
    else:
        raise ValueError('axis must be either x or y.')

    # Visible range in converted units, ascending whether or not the axis is
    # inverted (np.arange below needs start < stop).
    lower, upper = sorted(1e7 / np.array(limits))

    def conv(x, y):
        # Formatter: reciprocal value; 0 has no reciprocal.
        if x == 0:
            return ''
        return '%.0f' % (1e7 / x)

    sub_axis.set_major_formatter(plt.FuncFormatter(conv))

    major = [1000, 500, 200, 100, 50]
    minor = [200, 100, 50, 25, 10]
    for step, minor_step in zip(major, minor):
        if math.ceil(upper / step) - math.ceil(lower / step) > 4:
            break
    # Without a break, step/minor_step keep the finest spacing (50/10).
    ticks = np.arange(math.ceil(lower / step) * step,
                      math.ceil(upper / step) * step,
                      step)
    min_ticks = np.arange(math.floor(lower / minor_step) * minor_step,
                          math.floor(upper / minor_step) * minor_step,
                          minor_step)

    sub_axis.set_ticks(1e7 / ticks)
    sub_axis.set_ticks(1e7 / min_ticks, minor=True)
    if minor_ticks:
        ax.minorticks_on()
    # Axis.set_label only sets the artist (legend) label; set_label_text
    # sets the visible axis label.
    if unit == 'nm':
        sub_axis.set_label_text('Wavelengths [nm]')
    elif unit == 'cm':
        sub_axis.set_label_text('Wavenumber [1/cm]')
示例#25
0
def add_labels(ax: plt.Axes,
               labels: list = None,
               vertical_offsets: list = None,
               patch_num: list = None,
               fontsize: int = 14,
               rotation: int = 0,
               skip_zero: bool = False,
               format_str: str = "{:.2f}x",
               label_color: str = "#2f2f2f"):
    """
    Used to add labels above barplots.

    :param ax: current axis, it is assumed that each ax.Patch is a bar over which we want to add a label
    :param labels: optional labels, one per labelled bar (aligned with the selected
      patches when patch_num is given). If not present, add the bar height
    :param vertical_offsets: additional vertical offset for each labelled bar.
      Useful when displaying error bars (see @get_upper_ci_size), and for fine tuning.
      Defaults to 5% of the current upper y-limit
    :param patch_num: indices of patches to which we add labels, if some of them should be skipped
    :param fontsize: size of each label
    :param rotation: rotation of the labels (e.g. 90°)
    :param skip_zero: if True, don't put a label over the first labelled bar
    :param format_str: format of each label, by default use speedup (e.g. 2.10x)
    :param label_color: hexadecimal color used for labels

    Falsy labels (e.g. a bar height of 0) are not drawn.
    """
    if patch_num is None or len(patch_num) == 0:
        patches = list(ax.patches)
    else:
        selected = set(patch_num)
        patches = [p for i, p in enumerate(ax.patches) if i in selected]

    if vertical_offsets is None or len(vertical_offsets) == 0:
        # 5% of the upper y-limit above each bar, by default
        vertical_offsets = [ax.get_ylim()[1] * 0.05] * len(patches)
    if labels is None or len(labels) == 0:
        # Heights of the bars actually labelled, so labels[i] matches patches[i].
        labels = [p.get_height() for p in patches]

    for i, p in enumerate(patches):
        if labels[i] and (i > 0 or not skip_zero):
            ax.text(p.get_x() + p.get_width() / 2.,
                    vertical_offsets[i] + p.get_height(),
                    format_str.format(labels[i]),
                    fontsize=fontsize,
                    color=label_color,
                    ha='center',
                    va='bottom',
                    rotation=rotation)
示例#26
0
文件: plt2.py 项目: Gravifer/pylabyk
def patch_chance_level(level=None,
                       signs=(-1, 1),
                       ax: plt.Axes = None,
                       xy='y',
                       color=(0.7, 0.7, 0.7)):
    """Shade a band around chance level behind the plotted data.

    One rectangle is drawn for each sign in ``signs``. It starts at the origin
    of the ``xy`` axis (0 on a linear scale, 1 on a log scale) and extends by
    ``level * sign`` (linear) or ``level * 10**sign`` (log) data units along
    that axis. It spans the full current range of the other axis. Rectangles
    have no edge and ``zorder=-1``.

    :param level: Extent of the band; defaults to ``log(10)``.
    :param signs: Signs (directions) for which a rectangle is drawn.
    :param ax: Axes to draw on; defaults to ``plt.gca()``.
    :param xy: ``'y'`` for a horizontal band, ``'x'`` for a vertical one.
    :param color: Face colour of the rectangles.
    :return: List of the added ``Rectangle`` patches.
    :raises ValueError: if ``xy`` is neither ``'x'`` nor ``'y'``.
    """
    if level is None:
        level = np.log(10.)
    if xy not in ('x', 'y'):
        raise ValueError("xy must be either 'x' or 'y'")
    if ax is None:
        ax = plt.gca()

    # Axis along which the band extends from its origin.
    value_axis = ax.yaxis if xy == 'y' else ax.xaxis
    hs = []
    for sign in signs:
        if value_axis.get_scale() == 'log':
            vmin, level1 = 1., level * 10**sign
        else:
            vmin, level1 = 0., level * sign

        # Re-read the span each iteration: add_patch may rescale the view.
        if xy == 'y':
            lim = ax.get_xlim()
            rect = mpl.patches.Rectangle([lim[0], vmin],
                                         lim[1] - lim[0],
                                         level1,
                                         linewidth=0,
                                         fc=color,
                                         zorder=-1)
        else:
            lim = ax.get_ylim()
            rect = mpl.patches.Rectangle([vmin, lim[0]],
                                         level1,
                                         lim[1] - lim[0],
                                         linewidth=0,
                                         fc=color,
                                         zorder=-1)
        ax.add_patch(rect)
        hs.append(rect)
    return hs
示例#27
0
    def plot_bound(self,
                   t_all: Sequence[float] = None,
                   ax: plt.Axes = None,
                   **kwargs) -> plt.Line2D:
        """Plot the bound b(t) against time.

        After plotting, the lower y-limit is set to -5% of the upper limit.
        ``plt2.detach_axis`` is then called for 'y' and 'x' (with ``amin=0``),
        followed by ``plt2.box_off()``.

        :param t_all: Time points; defaults to ``self.t_all``.
        :param ax: Axes to draw on; defaults to ``plt.gca()``.
        :param kwargs: Passed to ``ax.plot`` after ``argsutil.kwdefault`` with
            ``color='k'`` and ``linestyle='-'`` (presumably as defaults).
        :return: Whatever ``ax.plot`` returns (for matplotlib, a list of
            ``Line2D``), despite the ``plt.Line2D`` annotation.
        """
        ax = plt.gca() if ax is None else ax
        t_all = self.t_all if t_all is None else t_all

        style = argsutil.kwdefault(kwargs, color='k', linestyle='-')
        handles = ax.plot(*npys(t_all, self.get_bound(t_all)), **style)
        ax.set_xlabel('time (s)')
        ax.set_ylabel(r"$b(t)$")

        top = ax.get_ylim()[1]
        ax.set_ylim(ymin=-top * 0.05)
        for axis_name in ('y', 'x'):
            plt2.detach_axis(axis_name, amin=0)
        plt2.box_off()
        return handles
示例#28
0
def add_tukey_marks(
    data: pd.Series,
    ax: plt.Axes,
    annot: bool = True,
    iqr_color: str = "r",
    fence_color: str = "k",
    fence_style: str = "--",
    annot_quarts: bool = False,
) -> plt.Axes:
    """Add IQR box and fences to a histogram-like plot.

    Shades the interquartile range and draws vertical lines at the fences
    returned by ``outliers.tukey_fences``. Optional text annotations are
    placed just above (1.01x) the current upper y-limit.

    Args:
        data (pd.Series): Data for calculating IQR and fences.
        ax (plt.Axes): Axes to annotate.
        annot (bool, optional): Write "IQR" and "Fence" labels. Defaults to True.
        iqr_color (str, optional): Color of shaded IQR box. Defaults to "r".
        fence_color (str, optional): Fence line color. Defaults to "k".
        fence_style (str, optional): Fence line style. Defaults to "--".
        annot_quarts (bool, optional): Also annotate Q1 and Q3 (only when
            ``annot`` is true). Defaults to False.

    Returns:
        plt.Axes: Annotated Axes object.
    """
    first_quartile = data.quantile(0.25)
    third_quartile = data.quantile(0.75)
    ax.axvspan(first_quartile, third_quartile, color=iqr_color, alpha=0.2)
    iqr_midpoint = first_quartile + ((third_quartile - first_quartile) / 2)

    lower_fence, upper_fence = outliers.tukey_fences(data)
    for fence in (lower_fence, upper_fence):
        ax.axvline(fence, c=fence_color, ls=fence_style)

    label_y = ax.get_ylim()[1] * 1.01
    if not annot:
        return ax

    ax.text(iqr_midpoint, label_y, "IQR", ha="center")
    if annot_quarts:
        ax.text(first_quartile, label_y, "Q1", ha="center")
        ax.text(third_quartile, label_y, "Q3", ha="center")
    ax.text(upper_fence, label_y, "Fence", ha="center")
    ax.text(lower_fence, label_y, "Fence", ha="center")
    return ax
示例#29
0
def plot_positive_negative_bars(ax: plt.Axes,
                                values: pd.Series,
                                positive_dict: dict = None,
                                negative_dict: dict = None,
                                title='Significant Spearman Correlation',
                                x_label='Spearman Correlation'):
    """Draw a horizontal bar chart with positive and negative values styled apart.

    Bars are ordered by ascending value and labeled with the index of
    ``values``. Entries >= 0 are drawn with ``positive_dict`` and entries < 0
    with ``negative_dict``. Both are passed as keyword arguments to
    ``ax.barh`` after being combined with the defaults
    ``{'color': 'green', 'height': 0.2}`` / ``{'color': 'red', 'height': 0.2}``
    through ``set_default_parameters``.

    Parameters
    ----------
    ax : plt.Axes
        Axes to draw on.
    values : pd.Series
        Values to plot; the index supplies the y tick labels.
    positive_dict, negative_dict : dict, optional
        Extra ``barh`` keyword arguments for the positive / negative bars.
    title : str, optional
        Axes title.
    x_label : str, optional
        Label for the x axis.

    Returns
    -------
    plt.Axes
        The same ``ax``.
    """
    if positive_dict is None:
        positive_dict = {}
    if negative_dict is None:
        negative_dict = {}

    default_positive_dict = {'color': 'green', 'height': 0.2}
    default_negative_dict = {'color': 'red', 'height': 0.2}
    positive_dict = set_default_parameters(positive_dict,
                                           default_positive_dict)
    negative_dict = set_default_parameters(negative_dict,
                                           default_negative_dict)

    sorted_values = values.sort_values()

    y_position = np.arange(len(sorted_values))
    # Each bar is drawn twice (once per series) with the other sign zeroed out.
    # NaN entries fail both comparisons and become 0 in both series.
    positive_values = sorted_values.where(sorted_values >= 0, 0)
    ax.barh(y_position, positive_values, **positive_dict)
    negative_values = sorted_values.where(sorted_values < 0, 0)
    ax.barh(y_position, negative_values, **negative_dict)

    ax.set_yticks(y_position)
    ax.set_yticklabels(sorted_values.index)

    # With only a few bars, push both y-limits away from zero so the bars do
    # not look overly thick.
    fat_bar_number = 5
    if len(y_position) < fat_bar_number:
        ax.set_ylim([
            i + np.sign(i) * (fat_bar_number - len(y_position))
            for i in ax.get_ylim()
        ])
    ax.set_title(title)
    ax.set_xlabel(x_label)
    return ax
示例#30
0
def get_axes_limits(ax: Axes, ax_crs: CRS, crs: CRS=SphericalEarth):
    """ Get the limits of the window covered by an Axes in another coordinate
    system.

    Each edge of the Axes window is searched with
    ``scipy.optimize.fminbound`` for the extreme value of the matching
    coordinate after transforming from `ax_crs` to `crs`. That is a bounded
    scalar minimization, so a local extremum along each edge is found.

    Parameters
    ----------
    ax : Axes
        Axes whose current x/y limits define the window (in `ax_crs`).
    ax_crs : CRS
        Coordinate system of the Axes data coordinates.
    crs : CRS, optional
        Coordinate system the limits are returned in (default SphericalEarth).

    Returns
    -------
    tuple
        ``(xmin, xmax, ymin, ymax)`` in `crs`.
    """
    # fminbound raises ValueError when its lower bound exceeds the upper bound,
    # so order the limits; this also makes inverted axes work.
    xl, xr = sorted(ax.get_xlim())
    yb, yt = sorted(ax.get_ylim())

    # Minimize y along the lower horizontal edge
    x_ = scipy.optimize.fminbound(lambda x: ax_crs.transform(crs, x, yb)[1], xl, xr)
    ymin = ax_crs.transform(crs, x_, yb)[1]

    # Maximize y along the upper horizontal edge
    x_ = scipy.optimize.fminbound(lambda x: -ax_crs.transform(crs, x, yt)[1], xl, xr)
    ymax = ax_crs.transform(crs, x_, yt)[1]

    # Minimize x along the left vertical edge
    y_ = scipy.optimize.fminbound(lambda y: ax_crs.transform(crs, xl, y)[0], yb, yt)
    xmin = ax_crs.transform(crs, xl, y_)[0]

    # Maximize x along the right vertical edge
    y_ = scipy.optimize.fminbound(lambda y: -ax_crs.transform(crs, xr, y)[0], yb, yt)
    xmax = ax_crs.transform(crs, xr, y_)[0]
    return xmin, xmax, ymin, ymax
示例#31
0
文件: pl.py 项目: joshua-gould/scanpy
    def _plot_scores(ax: plt.Axes,
                     scores: np.ndarray,
                     scale: str,
                     title: str,
                     threshold=None):
        """Draw a density histogram of doublet scores on ``ax``.

        Scores are binned using 50 evenly spaced edges over [0, 1] and drawn in
        gray without bar outlines. The y-axis uses ``scale``. When
        ``threshold`` is given, a thin black vertical line marks it, spanning
        the y-range in effect right after the histogram is drawn.
        """
        edges = np.linspace(0, 1, 50)
        ax.hist(scores, edges, color='gray', linewidth=0, density=True)
        ax.set_yscale(scale)

        # Pin the current y-range explicitly so the threshold line below,
        # drawn across exactly this range, does not widen it via autoscaling.
        frozen_ylim = ax.get_ylim()
        ax.set_ylim(frozen_ylim)

        if threshold is not None:
            ax.plot(np.ones(2) * threshold, frozen_ylim, c='black', linewidth=1)

        ax.set_xlabel('Doublet score')
        ax.set_ylabel('Prob. density')
        ax.set_title(title)
def adjust_axes_lim(ax: plt.Axes, min_x_thresh: float, max_x_thresh: float,
                    min_y_thresh: float, max_y_thresh: float):
    """
    Widen the axes limits so each axis spans at least [min_thresh, max_thresh].

    For every bound, the threshold replaces the current limit unless the
    current limit already lies beyond it. The comparison is made after
    rounding both values to 3 decimals. In that case the current limit is kept
    and a message is printed. Any padding already present around plot elements
    counts toward the current limits, so the axes should be tightened before
    calling this.
    :param ax: Adjust range on axes object
    :param min_x_thresh: ax x_min must be at least that small.
    :param max_x_thresh: ax x_max must be at least that big.
    :param min_y_thresh: ax y_min must be at least that small.
    :param max_y_thresh: ax y_max must be at least that big.
    """
    def _pick_low(current, thresh, message):
        # Take the threshold unless the current lower limit is already smaller.
        if round(current, 3) >= round(thresh, 3):
            return thresh
        print(message, current, 'instead of', thresh)
        return current

    def _pick_high(current, thresh, message):
        # Take the threshold unless the current upper limit is already larger.
        if round(current, 3) <= round(thresh, 3):
            return thresh
        print(message, current, 'instead of', thresh)
        return current

    current_y = ax.get_ylim()
    current_x = ax.get_xlim()

    # Tuple elements evaluate left to right, keeping the y-then-x message order.
    new_y = (_pick_low(current_y[0], min_y_thresh, 'min y was set to'),
             _pick_high(current_y[1], max_y_thresh, 'max y was set to'))
    new_x = (_pick_low(current_x[0], min_x_thresh, 'min x was set to'),
             _pick_high(current_x[1], max_x_thresh, 'max x was set to'))

    ax.set_ylim(*new_y)
    ax.set_xlim(*new_x)
def set_new_axlim(axs: plt.Axes,
                  data: np.ndarray,
                  err: _ty.Optional[np.ndarray] = None,
                  *,
                  yaxis: bool = True,
                  reset: bool = False,
                  buffer: float = 0.05) -> None:
    """Set axes limits that will show all data, including existing.

    Parameters
    ----------
    axs : mpl.axes.Axes
        `Axes` instance whose limits are to be adjusted.
    data : np.ndarray (n)
        Array of numbers plotted along the axis.
    err : None|np.ndarray (n,)|(2,n), optional
        Error bars for data, by default: `None`.
    yaxis : bool, optional
        Are we modifying the y axis? By default: `True`.
    reset : bool, optional
        Do we ignore the existing axis limits? By default: `False`.
    buffer : float, optional
        Fractional padding around data, by default `0.05`.

    Notes
    -----
    New limits come from ``calc_axlim(data, err, log, buffer)``, where ``log``
    is whether the chosen axis uses a log scale. Unless `reset`, they are
    widened to also cover the current limits, and an inverted axis stays
    inverted.
    """
    log = (axs.get_yscale() if yaxis else axs.get_xscale()) == 'log'
    # assumes calc_axlim returns (low, high) — TODO confirm
    lim = calc_axlim(data, err, log, buffer)
    if not reset:
        axlim = axs.get_ylim() if yaxis else axs.get_xlim()
        # An inverted axis reports (high, low). Union against the numeric
        # range, otherwise the existing span is lost and the axis un-inverted.
        cur_lo, cur_hi = sorted(axlim)
        lim = min(lim[0], cur_lo), max(lim[1], cur_hi)
        if axlim[0] > axlim[1]:
            lim = lim[::-1]  # keep the existing inversion
    if yaxis:
        axs.set_ylim(lim)
    else:
        axs.set_xlim(lim)
示例#34
0
def label_ticks(xs: Iterable[float], ys: Iterable[float], ax: Axes=None,
        map_crs: CRS=Cartesian, graticule_crs: CRS=SphericalEarth,
        textargs=None, tickargs=None,
        xformatter=None, yformatter=None):
    """ Label graticule lines, returning a list of Text objects.

    For every spine of the Axes, the points where each requested graticule
    line crosses it are found with `froot`. A tick marker is plotted and a
    text label written at each crossing, and the regular Axes ticks are
    removed.

    Parameters
    ----------
    xs : Iterable[float],
    ys : Iterable[float]
        Easting and northing components of labels, in `graticule_crs`.
        Each is consumed once, so generators are accepted.
    ax : Axes, optional
        Axes to draw to (default current Axes)
    map_crs : karta.crs.CRS, optional
        CRS giving the display projection (default Cartesian)
    graticule_crs : karta.crs.CRS, optional
        CRS giving the graticule/label projection (default SphericalEarth)
    textargs : dict, optional
        Keyword arguments to pass to plt.text
    tickargs : dict, optional
        Keyword arguments to pass to plt.plot
    xformatter : callable, optional
        function that given an easting/longitude returns a label
    yformatter : callable, optional
        function that given a northing/latitude returns a label

    Returns
    -------
    list
        Text objects created by ``ax.text``: easting labels first, then
        northing labels.
    """
    if ax is None:
        ax = plt.gca()

    if textargs is None:
        textargs = dict()

    if tickargs is None:
        tickargs = dict(marker="+", mew=2, ms=14, mfc="k", mec="k", ls="none")

    if xformatter is None:
        xformatter = lambda x: "{0} E".format(x)

    if yformatter is None:
        yformatter = lambda y: "{0} N".format(y)

    # Each coordinate sequence is scanned once per spine, so materialize it.
    xs = list(xs)
    ys = list(ys)

    # Find tick locations. bbox is indexed bbox[corner][component]; presumably
    # the window corners in graticule_crs (see get_axes_extent) — bottom, right, top, left
    bbox = get_axes_extent(ax, map_crs, graticule_crs)

    xmin, xmax = sorted(ax.get_xlim())
    ymin, ymax = sorted(ax.get_ylim())

    def _horizontal_spine(y_spine, values, comp, c0, c1, fmt):
        # Crossings of each graticule value with the horizontal spine at
        # y_spine: solve transform(x, y_spine)[comp] == value for x.
        found = []
        for v in values:
            if isbetween(v, bbox[c0][comp], bbox[c1][comp]):
                xt = froot(lambda t: map_crs.transform(graticule_crs, t, y_spine)[comp] - v,
                           xmin, xmax)
                found.append((xt, y_spine, fmt(v)))
        return found

    def _vertical_spine(x_spine, values, comp, c0, c1, fmt):
        # Crossings of each graticule value with the vertical spine at
        # x_spine: solve transform(x_spine, y)[comp] == value for y.
        found = []
        for v in values:
            if isbetween(v, bbox[c0][comp], bbox[c1][comp]):
                yt = froot(lambda t: map_crs.transform(graticule_crs, x_spine, t)[comp] - v,
                           ymin, ymax)
                found.append((x_spine, yt, fmt(v)))
        return found

    # Spine order: bottom, top, left, right (component 0 = easting, 1 = northing)
    xticks = (_horizontal_spine(ymin, xs, 0, 0, 1, xformatter)
              + _horizontal_spine(ymax, xs, 0, 2, 3, xformatter)
              + _vertical_spine(xmin, xs, 0, 0, 3, xformatter)
              + _vertical_spine(xmax, xs, 0, 1, 2, xformatter))
    yticks = (_horizontal_spine(ymin, ys, 1, 0, 1, yformatter)
              + _horizontal_spine(ymax, ys, 1, 2, 3, yformatter)
              + _vertical_spine(xmin, ys, 1, 0, 3, yformatter)
              + _vertical_spine(xmax, ys, 1, 1, 2, yformatter))

    # Update map
    txts = []
    for px, py, label in xticks + yticks:
        ax.plot(px, py, **tickargs)
        txts.append(ax.text(px, py, label, **textargs))

    ax.set_xticks([])
    ax.set_yticks([])
    return txts