示例#1
0
def _add_margins(ax: plt.Axes, plot_data: np.ndarray, cutoff_lo: float, cutoff_hi: float, orient: str, margin: float):
    """Pad the value axis where data exceeds a cutoff and snap the other axis to a 0.5 grid.

    With ``orient == 'v'`` the y axis is the value axis and the x limits are
    rounded; for any other ``orient`` the roles of x and y are swapped.

    :param ax: Axes whose limits are modified in place.
    :param plot_data: Values whose minimum/maximum are compared with the cutoffs.
    :param cutoff_lo: The lower value-axis limit is extended only when
        ``plot_data`` contains a value below this.
    :param cutoff_hi: The upper value-axis limit is extended only when
        ``plot_data`` contains a value above this.
    :param orient: ``'v'`` to treat y as the value axis, anything else for x.
    :param margin: Fraction of the original value-axis span added on each
        side that gets extended.
    """
    def _quantized_abs_ceil(x, q=0.5):
        # Round away from zero to the next multiple of q, keeping the sign.
        return np.ceil(np.abs(x) / q) * q * np.sign(x)

    if orient == 'v':
        old_extents, lim_setter = ax.get_ylim(), ax.set_ylim
        ax.set_xlim(*map(_quantized_abs_ceil, ax.get_xlim()))
    else:
        old_extents, lim_setter = ax.get_xlim(), ax.set_xlim
        ax.set_ylim(*map(_quantized_abs_ceil, ax.get_ylim()))

    # Use a plain scalar for the padding: np.diff() would return a 1-element
    # array, which then ends up being passed on as an axis limit.
    pad = margin * (old_extents[1] - old_extents[0])

    # Both extensions are measured from the limits captured above, so the
    # lower and upper padding are the same size even when both apply.
    if np.min(plot_data) < cutoff_lo:
        lim_setter(old_extents[0] - pad, None)
    if np.max(plot_data) > cutoff_hi:
        lim_setter(None, old_extents[1] + pad)
示例#2
0
    def __plot_bunches(self,
                       fig: plt.Figure,
                       ax: plt.Axes,
                       point: FiniteMetricVertex,
                       name: str = "u") -> None:
        """
        Plot all points and highlight the bunch of the given point on
        the provided figure/axes.

        The axes are cleared first. Every vertex in ``self.vertices`` is
        drawn, the given point is marked and annotated, a circle is drawn
        around it for each non-``None`` entry ``self.p[point][i]`` with
        ``1 <= i < self.k``, and the members of ``self.B[point]`` are
        highlighted.

        :param fig: The matplotlib figure to plot on.
        :param ax: The matplotlib axes to plot on.
        :param point: The vertex whose bunches we wish to plot.
        :param name: The name to use to label the vertex/bunches.
        """
        ax.cla()

        # Plot every vertex as a small black dot
        ax.scatter([v.i for v in self.vertices], [v.j for v in self.vertices],
                   s=4,
                   color="black",
                   marker=".",
                   label="Points")

        # Plot and label the point itself
        ax.scatter([point.i], [point.j],
                   s=12,
                   color="red",
                   marker="*",
                   label=name)
        ax.annotate(name, (point.i, point.j), color="red")

        # Force the xlim and ylim to become fixed, so the circles added
        # below cannot rescale the view
        ax.set_xlim(*ax.get_xlim())
        ax.set_ylim(*ax.get_ylim())

        # For the current point, mark and label its p_i s and add circles.
        # Each non-None entry is indexed as (vertex, radius): element 0
        # supplies the annotation position, element 1 the circle radius.
        p_i = [self.p[point][i] for i in range(self.k)]
        for i in range(1, self.k):
            witness = p_i[i]
            if witness is None:
                continue
            ax.annotate("p_{}({})".format(i, name),
                        (witness[0].i, witness[0].j),
                        xytext=(5, 5),
                        textcoords="offset pixels",
                        color="violet")
            ax.add_patch(plt.Circle((point.i, point.j), witness[1],
                                    fill=False))

        # Plot the points in the bunch
        bunch = list(self.B[point])
        ax.scatter([w.i for w in bunch], [w.j for w in bunch],
                   s=12,
                   color="lime",
                   marker="*",
                   label="B({})".format(name))

        ax.set_title("Bunch B({})".format(name))
        ax.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left')
        # Lay out the figure that was passed in; plt.tight_layout() would
        # act on whichever figure pyplot currently considers active.
        fig.tight_layout()
        fig.show()
示例#3
0
def plot_p_ch_vs_ev(ev_cond,
                    p_ch,
                    style='pred',
                    ax: plt.Axes = None,
                    **kwargs) -> plt.Line2D:
    """
    Plot choice probability against evidence and style the axes.

    ev_cond is reduced to one value per condition (mean over axis 1) and
    p_ch to [condition, ch] (sum over axis 1) before plotting.

    @param ev_cond: [condition] or [condition, frame]
    @type ev_cond: torch.Tensor
    @param p_ch: [condition, ch] or [condition, rt_frame, ch]
    @type p_ch: torch.Tensor
    @param style: forwarded to get_kw_plot together with kwargs
    @param ax: axes to draw on; plt.gca() is used when None
    @return: the value returned by ax.plot
    """
    if ax is None:
        ax = plt.gca()

    if ev_cond.ndim != 1:
        # NOTE(review): a 3-dim ev_cond is first passed through npt.p2st and
        # its first entry kept - confirm what p2st does with the caller.
        if ev_cond.ndim == 3:
            ev_cond = npt.p2st(ev_cond)[0]
        assert ev_cond.ndim == 2
        ev_cond = ev_cond.mean(1)

    if p_ch.ndim != 2:
        assert p_ch.ndim == 3
        p_ch = p_ch.sum(1)

    kwargs = get_kw_plot(style, **kwargs)

    # presumably sumto1 normalises over the last axis - TODO confirm
    p_normalized = npt.sumto1(p_ch, -1)
    xy_values = npys(ev_cond, npt.p2st(p_normalized)[1])
    h = ax.plot(*xy_values, **kwargs)

    plt2.box_off(ax=ax)
    x_lim = ax.get_xlim()
    plt2.detach_axis('x', amin=x_lim[0], amax=x_lim[1], ax=ax)
    plt2.detach_axis('y', amin=0, amax=1, ax=ax)

    ax.set_yticks([0, 0.5, 1])
    ax.set_yticklabels(['0', '', '1'])
    ax.set_xlabel('evidence')
    ax.set_ylabel(r"$\mathrm{P}(z=1 \mid c)$")
    return h
示例#4
0
    def draw_LX_bounds(self, ax: plt.Axes, redshifts_on: bool = True):
        """Shade the luminosity range covered by ``self.bins`` and mark bin edges.

        Stores ``self.hconv = 0.70 / self.h`` as a side effect. Each entry of
        ``self.bins`` is unpacked as (luminosity_min, luminosity_max,
        redshift_min, redshift_max); luminosities are multiplied by
        ``1e44 * self.hconv**2`` before being drawn.

        :param ax: Axes to draw the span, lines and labels on.
        :param redshifts_on: When true, every second bin is labelled with its
            redshift range.
        """
        self.hconv = 0.70 / self.h

        def scaled(luminosity):
            # Same left-to-right product as writing the factors out inline.
            return luminosity * 1e44 * self.hconv**2

        edge_style = dict(color='lime', linewidth=1, alpha=0.1)

        lowest_edge = scaled(self.bins[0][0])
        highest_edge = scaled(self.bins[-1][1])
        ax.axhspan(lowest_edge,
                   highest_edge,
                   facecolor='lime',
                   linewidth=0,
                   alpha=0.2)
        ax.axhline(lowest_edge, **edge_style)

        for index, bin_bounds in enumerate(self.bins):
            (luminosity_min, luminosity_max, redshift_min,
             redshift_max) = bin_bounds

            ax.axhline(scaled(luminosity_max), **edge_style)

            # Print redshift bounds once every 2 bins to avoid clutter.
            if index % 2 != 0 or not redshifts_on:
                continue

            # Label sits at the geometric mean of the bin's luminosity bounds.
            label_y = scaled(
                10**(0.5 * np.log10(luminosity_min * luminosity_max)))
            ax.text(10**ax.get_xlim()[0],
                    label_y,
                    f"$z$ = {redshift_min:.3f} - {redshift_max:.3f}",
                    horizontalalignment='left',
                    verticalalignment='center',
                    color='k',
                    alpha=0.3)
def ageify_axis(ax: plt.Axes):
    """
    Restyle an axis to match the AGE group standard.

    All four spines are made visible. The lower and left axes get ticks that
    point both inwards and outwards; a twin axis is created for the top and
    for the right side, each with unlabelled ticks pointing inwards and with
    limits copied from the original axis at call time. The global rcParams
    are switched to a serif font for normal text and math text (for
    publications).

    Parameters
    ----------
    ax : plt.Axes
        Given axis that should be changed.

    Returns
    -------
    ax : plt.Axes
        Original axis object.
    ax_top : plt.Axes
        Created twin axis representing the top axis.
    ax_right : plt.Axes
        Created twin axis representing the right axis.
    """
    for side in ("top", "bottom", "left", "right"):
        ax.spines[side].set_visible(True)

    ax_right = ax.twinx()
    ax_top = ax.twiny()

    main_ticks = dict(direction="inout",
                      labeltop=False,
                      labelright=False,
                      bottom=True,
                      left=True)
    ax.tick_params(**main_ticks)

    # Right twin: inward ticks without labels, same y range as the original.
    ax_right.tick_params(right=True, direction="in", labelright=False)
    ax_right.set_ylim(ax.get_ylim())

    # Top twin: inward ticks without labels, same x range as the original.
    ax_top.tick_params(top=True, direction="in", labeltop=False)
    ax_top.set_xlim(ax.get_xlim())

    plt.rcParams['font.family'] = 'serif'
    plt.rcParams['mathtext.fontset'] = 'dejavuserif'

    return ax, ax_top, ax_right
示例#6
0
def get_axes_limits(ax: Axes, ax_crs: CRS, crs: CRS = SphericalEarth):
    """ Get the limits of the window covered by an Axes in another coordinate
    system.

    Each edge of the axes window is searched with a bounded scalar minimiser
    for the point whose transformed coordinate is most extreme, because the
    extreme need not lie at a corner once the edge is transformed.

    :param ax: Axes whose current x/y limits define the window.
    :param ax_crs: Object whose ``transform(crs, x, y)`` converts a point of
        the axes into ``crs``; the result is indexed as (x, y).
    :param crs: Target coordinate system.
    :return: ``(xmin, xmax, ymin, ymax)`` in ``crs``.
    """
    xl, xr = ax.get_xlim()
    yb, yt = ax.get_ylim()

    def to_crs(x, y):
        return ax_crs.transform(crs, x, y)

    # Minimize bottom spine
    x_ = scipy.optimize.fminbound(lambda x: to_crs(x, yb)[1], xl, xr)
    ymin = to_crs(x_, yb)[1]

    # Maximize top spine (minimise the negated coordinate)
    x_ = scipy.optimize.fminbound(lambda x: -to_crs(x, yt)[1], xl, xr)
    ymax = to_crs(x_, yt)[1]

    # Minimize left spine
    y_ = scipy.optimize.fminbound(lambda y: to_crs(xl, y)[0], yb, yt)
    xmin = to_crs(xl, y_)[0]

    # Maximize right spine (minimise the negated coordinate)
    y_ = scipy.optimize.fminbound(lambda y: -to_crs(xr, y)[0], yb, yt)
    xmax = to_crs(xr, y_)[0]
    return xmin, xmax, ymin, ymax
示例#7
0
def _get_aspect_ratio(ax: plt.Axes) -> float:
    """Return the width/height ratio of the axes' current data limits.

    Falls back to 1.0 when the height of the y range is not greater than
    1e-9 in absolute value, so the division is never by (near) zero.
    """
    x_low, x_high = ax.get_xlim()
    y_low, y_high = ax.get_ylim()
    width = x_high - x_low
    height = y_high - y_low
    return width / height if abs(height) > 1e-9 else 1.0
示例#8
0
文件: plt2.py 项目: yulkang/pylabyk
def sameaxes(ax: Union[AxesArray, GridAxes],
             ax0: plt.Axes = None,
             xy='xy',
             lim=None):
    """
    Match the chosen limits of axes in ax to ax0's (if given) or the max range.
    Also consider: ax1.get_shared_x_axes().join(ax1, ax2)
    Optionally followed by ax1.set_xticklabels([]); ax2.autoscale()
    See: https://stackoverflow.com/a/42974975/2565317
    :param ax: np.ndarray (as from subplotRCs) or list of axes.
    :param ax0: a scalar axes to match limits to. if None (default),
    match the maximum range among axes in ax.
    :param xy: 'x'|'y'|'xy'(default). Each character is processed in turn;
    any character other than 'x' is treated as 'y'.
    :param lim: if given, used as the limits for every requested axis
    instead of ax0's limits or the max range.
    :return: [[min, max]] of limits. If xy='xy', contains two pairs.
    """
    if isinstance(ax, np.ndarray) or isinstance(ax, GridAxes):
        ax = ax.flatten()

    def cat_lims(lims):
        return np.concatenate([np.array(v1).reshape(1, 2) for v1 in lims])

    def is_inverted(axes1, use_x):
        # Prefer the Axis-level getter; fall back to the Axes-level method
        # when the Axis object does not provide get_inverted().
        try:
            axis = axes1.get_xaxis() if use_x else axes1.get_yaxis()
            return axis.get_inverted()
        except AttributeError:
            if use_x:
                return axes1.xaxis_inverted()
            return axes1.yaxis_inverted()

    lims_res = []
    for xy1 in xy:
        use_x = xy1 == 'x'

        if lim is not None:
            lims0 = lim
        elif ax0 is not None:
            lims0 = ax0.get_xlim() if use_x else ax0.get_ylim()
        else:
            lims = cat_lims([
                ax1.get_xlim() if use_x else ax1.get_ylim() for ax1 in ax
            ])
            # The direction is taken from the first axes only; on an inverted
            # axis the widest range is (largest first value, smallest second).
            if is_inverted(ax[0], use_x):
                lims0 = [np.max(lims[:, 0]), np.min(lims[:, 1])]
            else:
                lims0 = [np.min(lims[:, 0]), np.max(lims[:, 1])]

        for ax1 in ax:
            if use_x:
                ax1.set_xlim(lims0)
            else:
                ax1.set_ylim(lims0)
        lims_res.append(lims0)
    return lims_res
示例#9
0
def timedXAxis(ax: plt.Axes, stepSize: float = 60000) -> plt.Axes:
    """Label the x axis with "minutes:seconds" ticks instead of plain milliseconds.

    Ticks are placed every ``stepSize`` from 0 up to (not including) the
    current upper x limit, and the x margin is set to 0. Returns ``ax``.
    """
    upper_limit = ax.get_xlim()[1]
    ticks = np.arange(0, upper_limit, stepSize)

    def clock_label(t):
        # presumably RAConst converts milliseconds to minutes / seconds -
        # verify against RAConst
        minutes = int(RAConst.mSecToMin(t))
        seconds = int(RAConst.mSecToSec(t) % 60)
        return f"{minutes}:{seconds:02d}"

    ax.set_xticks(ticks)
    ax.set_xticklabels([clock_label(t) for t in ticks])
    ax.margins(x=0)
    return ax
示例#10
0
def center_axis(axes: plt.Axes, which='y'):
    """Make the chosen axis limits symmetric around zero.

    The new limits are ``(-m, m)`` where ``m`` is the larger absolute value
    of the current limits, so the lower limit is always set to ``-m``.

    :param axes: Axes to modify in place.
    :param which: ``'y'`` (default), ``'x'`` or ``'both'``. Any other value
        leaves the axes unchanged.
    """
    # 'both' used to be an empty branch; it now centres the two axes.
    if which in ('y', 'both'):
        max_abs = np.max(np.abs(axes.get_ylim()))
        axes.set_ylim(-max_abs, max_abs)
    if which in ('x', 'both'):
        max_abs = np.max(np.abs(axes.get_xlim()))
        axes.set_xlim(-max_abs, max_abs)
示例#11
0
def get_axes_extent(ax: Axes, ax_crs: CRS, crs: CRS=SphericalEarth):
    """ Get the extent of an Axes in geographical (or other) coordinates.

    The four corners of the current axes window are transformed with
    ``ax_crs.transform(crs, x, y)`` and returned as a Polygon in ``crs``,
    ordered lower-left, lower-right, upper-right, upper-left.
    """
    x_left, x_right = ax.get_xlim()
    y_bottom, y_top = ax.get_ylim()

    corners = (
        (x_left, y_bottom),
        (x_right, y_bottom),
        (x_right, y_top),
        (x_left, y_top),
    )
    vertices = [ax_crs.transform(crs, x, y) for x, y in corners]
    return Polygon(vertices, crs=crs)
示例#12
0
def get_axes_extent(ax: Axes, ax_crs: CRS, crs: CRS = SphericalEarth):
    """ Get the extent of an Axes in geographical (or other) coordinates.

    The four corners of the current axes window are transformed with
    ``ax_crs.transform(crs, x, y)`` and returned as a Polygon in ``crs``,
    ordered lower-left, lower-right, upper-right, upper-left.
    """
    x_left, x_right = ax.get_xlim()
    y_bottom, y_top = ax.get_ylim()

    corners = (
        (x_left, y_bottom),
        (x_right, y_bottom),
        (x_right, y_top),
        (x_left, y_top),
    )
    vertices = [ax_crs.transform(crs, x, y) for x, y in corners]
    return Polygon(vertices, crs=crs)
示例#13
0
    def __plot_p_i(self,
                   fig: plt.Figure,
                   ax: plt.Axes,
                   point: FiniteMetricVertex,
                   name: str = "u") -> None:
        """
        Plot all points and highlight the witnesses p_i for the given point
        along with corresponding rings on the given figure and axes.

        The axes are cleared first. Each set in ``self.A`` is drawn as its
        own scatter series, the given point is marked and annotated, and a
        circle is drawn around it for each non-``None`` entry
        ``self.p[point][i]`` with ``1 <= i < self.k``.

        :param fig: The matplotlib figure to plot on.
        :param ax: The matplotlib axes to plot on.
        :param point: The vertex whose witnesses/rings we wish to plot.
        :param name: The name to use to label the vertex/bunches.
        """
        ax.cla()

        # Plot all points and color by set A_i (one scatter series per set)
        for i, a_i in enumerate(self.A):
            ax.scatter([v.i for v in a_i], [v.j for v in a_i],
                       s=8,
                       marker="o",
                       label="A_{}".format(i))

        # Plot and label the point itself
        ax.scatter([point.i], [point.j],
                   s=12,
                   color="red",
                   marker="*",
                   label=name)
        ax.annotate(name, (point.i, point.j), color="red")

        # Force the xlim and ylim to become fixed, so the circles added
        # below cannot rescale the view
        ax.set_xlim(*ax.get_xlim())
        ax.set_ylim(*ax.get_ylim())

        # For the current point, mark and label its p_i s and add circles.
        # Each non-None entry is indexed as (vertex, radius): element 0
        # supplies the annotation position, element 1 the circle radius.
        p_i = [self.p[point][i] for i in range(self.k)]
        for i in range(1, self.k):
            witness = p_i[i]
            if witness is None:
                continue
            ax.annotate("p_{}({})".format(i, name),
                        (witness[0].i, witness[0].j),
                        xytext=(5, 5),
                        textcoords="offset pixels",
                        color="violet")
            ax.add_patch(plt.Circle((point.i, point.j), witness[1],
                                    fill=False))

        ax.set_title("Witnesses p_i({}) and rings.".format(name))
        ax.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left')
        # Lay out the figure that was passed in; plt.tight_layout() would
        # act on whichever figure pyplot currently considers active.
        fig.tight_layout()
        fig.show()
示例#14
0
def make_dual_axis(ax: plt.Axes = None, axis='x', unit='nm', minor_ticks=True):
    """Add a twin axis whose ticks show ``1e7 / value`` of the primary axis.

    The twin axis gets the same limits as the primary one; only the tick
    positions and tick labels differ. Ticks are placed at round values of
    the converted quantity, using the coarsest spacing from a fixed list
    that yields more than four major intervals, or the finest spacing in
    the list when none does.

    :param ax: Primary axes; ``plt.gca()`` is used when ``None``.
    :param axis: ``'x'`` to add the twin on top (``twiny``), ``'y'`` to add
        it on the right (``twinx``).
    :param unit: ``'nm'`` or ``'cm'`` selects the label of the twin axis;
        any other value leaves it unlabelled.
    :param minor_ticks: When true, minor ticks are enabled on the primary
        axes.
    :raises ValueError: If ``axis`` is neither ``'x'`` nor ``'y'``.
    """
    if ax is None:
        ax = plt.gca()

    if axis == 'x':
        pseudo_ax = ax.twiny()
        limits = ax.get_xlim()
        pseudo_ax.set_xlim(limits)
        sub_axis = pseudo_ax.xaxis
    elif axis == 'y':
        pseudo_ax = ax.twinx()
        limits = ax.get_ylim()
        pseudo_ax.set_ylim(limits)
        sub_axis = pseudo_ax.yaxis
    else:
        raise ValueError('axis must be either x or y.')

    # Converted limits in ascending order, so the tick ranges below are
    # non-empty whether the primary axis runs upwards or downwards.
    lo, hi = sorted(1e7 / np.array(limits))

    def conv(x, y):
        return '%.0f' % (1e7 / x)

    sub_axis.set_major_formatter(plt.FuncFormatter(conv))

    major = [1000, 500, 200, 100, 50]
    minor = [200, 100, 50, 25, 10]
    for x, m in zip(major, minor):
        a, b = math.ceil(lo / x), math.ceil(hi / x)
        if b - a > 4:
            break
    # If no spacing gave more than four intervals, x/m/a/b still hold the
    # values of the finest spacing, which is used as a fallback instead of
    # failing on unbound tick arrays.
    ticks = np.arange(a * x, b * x, x)
    a, b = math.floor(lo / m), math.floor(hi / m)
    min_ticks = np.arange(a * m, b * m, m)

    # A tick at 0 has no position on the primary scale (1e7 / 0).
    ticks = ticks[ticks != 0]
    min_ticks = min_ticks[min_ticks != 0]

    sub_axis.set_ticks(1e7 / ticks)
    sub_axis.set_ticks(1e7 / min_ticks, minor=True)
    if minor_ticks:
        ax.minorticks_on()

    # Compare by value, and set the axis label text: Axis.set_label() only
    # sets the artist label used for legends.
    if unit == 'nm':
        sub_axis.set_label_text('Wavelengths [nm]')
    elif unit == 'cm':
        sub_axis.set_label_text('Wavenumber [1/cm]')
示例#15
0
def add_lines(axes: plt.Axes,
              xs=None,
              ys=None,
              colors=None,
              **kwargs) -> plt.Axes:
    """
    Draw vertical and/or horizontal lines spanning the current axes range

    The axis limits read before drawing are restored afterwards, so the
    added lines do not change the visible range.

    Args:
        axes: axes to draw on
        xs: x position(s) of vertical lines; a single value or an iterable
        ys: y position(s) of horizontal lines; a single value or an iterable
        colors: colors cycled through line by line (Xs first, then Ys if
            both are given); defaults to dark green, dark orange, dark red
        **kwargs: passed on to ``vlines`` / ``hlines``

    Returns:
        plt.Axes: the axes that was passed in
    """
    palette = ['darkgreen', 'darkorange', 'darkred'] if colors is None else colors
    n_drawn = 0

    if xs is not None:
        # Strings and non-iterables count as a single position.
        if isinstance(xs, str) or not hasattr(xs, '__iter__'):
            xs = [xs]
        y_low, y_high = axes.get_ylim()
        for x_pos in xs:
            axes.vlines(x=x_pos,
                        ymin=y_low,
                        ymax=y_high,
                        colors=palette[n_drawn % len(palette)],
                        **kwargs)
            n_drawn += 1
        axes.set_ylim(ymin=y_low, ymax=y_high)

    if ys is not None:
        if isinstance(ys, str) or not hasattr(ys, '__iter__'):
            ys = [ys]
        x_low, x_high = axes.get_xlim()
        for y_pos in ys:
            axes.hlines(y=y_pos,
                        xmin=x_low,
                        xmax=x_high,
                        colors=palette[n_drawn % len(palette)],
                        **kwargs)
            n_drawn += 1
        axes.set_xlim(xmin=x_low, xmax=x_high)

    return axes
示例#16
0
def flip_axis(ax: plt.Axes, axis: str = "x") -> None:
    """Reverse the direction of one axis by swapping its two limits.

    Parameters
    ----------
    ax : Axes
        Axes object with axis to flip.
    axis : str, optional
        Which axis to flip (case-insensitive), by default "x".

    Raises
    ------
    ValueError
        If `axis` is neither "x" nor "y".
    """
    which = axis.lower()
    if which == "x":
        ax.set_xlim(ax.get_xlim()[::-1])
    elif which == "y":
        ax.set_ylim(ax.get_ylim()[::-1])
    else:
        raise ValueError("`axis` must be 'x' or 'y'")
示例#17
0
def _plot_single_image_stats(image: np.ndarray, mask: np.ndarray, z_slice: int,
                             image_axes: Axes, hist_axes: Axes,
                             box_axes: Axes) -> None:
    """Draw a box plot, a histogram and one slice of an image on three axes.

    :param image: Image array; the slice shown is ``image[z_slice, :, :]``.
    :param mask: If not None, only values where ``mask > 0`` enter the box
        plot and histogram; otherwise all values are used.
    :param z_slice: Index along the first axis of the slice to display.
    :param image_axes: Axes for the image slice.
    :param hist_axes: Axes for the 30-bin histogram.
    :param box_axes: Axes for the horizontal box plot.
    """
    if mask is None:
        values = image.flatten()
    else:
        values = image[mask > 0].flatten()

    box_axes.boxplot(values, notch=False, vert=False, sym=".", whis=[5, 95])
    hist_axes.hist(values, bins=30)
    image_axes.imshow(image[z_slice, :, :], cmap="Greys_r")
    image_axes.set_xticks([])
    image_axes.set_yticks([])

    # The histogram limits represent the full data range, so reuse them for
    # the box plot (which may show a smaller range if no outliers are drawn).
    hist_left, hist_right = hist_axes.get_xlim()
    box_axes.set_xlim(left=hist_left, right=hist_right)

    box_axes.set_xticks([])  # Ticks match those of histogram anyway
    box_axes.set_yticks([])  # don't need that 1 tick mark
    hist_axes.set_yticks([])  # Number of voxels is not relevant
示例#18
0
def make_dual_axis(ax: plt.Axes = None, axis='x', unit='nm', minor_ticks=True):
    """Add a twin axis whose ticks show ``1e7 / value`` of the primary axis.

    The twin axis gets the same limits as the primary one; only the tick
    positions and tick labels differ. Ticks are placed at round values of
    the converted quantity, using the coarsest spacing from a fixed list
    that yields more than four major intervals, or the finest spacing in
    the list when none does.

    :param ax: Primary axes; ``plt.gca()`` is used when ``None``.
    :param axis: ``'x'`` to add the twin on top (``twiny``), ``'y'`` to add
        it on the right (``twinx``).
    :param unit: ``'nm'`` or ``'cm'`` selects the label of the twin axis;
        any other value leaves it unlabelled.
    :param minor_ticks: When true, minor ticks are enabled on the primary
        axes.
    :raises ValueError: If ``axis`` is neither ``'x'`` nor ``'y'``.
    """
    if ax is None:
        ax = plt.gca()

    if axis == 'x':
        pseudo_ax = ax.twiny()
        limits = ax.get_xlim()
        pseudo_ax.set_xlim(limits)
        sub_axis = pseudo_ax.xaxis
    elif axis == 'y':
        pseudo_ax = ax.twinx()
        limits = ax.get_ylim()
        pseudo_ax.set_ylim(limits)
        sub_axis = pseudo_ax.yaxis
    else:
        raise ValueError('axis must be either x or y.')

    # Converted limits in ascending order, so the tick ranges below are
    # non-empty whether the primary axis runs upwards or downwards.
    lo, hi = sorted(1e7 / np.array(limits))

    def conv(x, y):
        return '%.0f' % (1e7 / x)

    sub_axis.set_major_formatter(plt.FuncFormatter(conv))

    major = [1000, 500, 200, 100, 50]
    minor = [200, 100, 50, 25, 10]
    for x, m in zip(major, minor):
        a, b = math.ceil(lo / x), math.ceil(hi / x)
        if b - a > 4:
            break
    # If no spacing gave more than four intervals, x/m/a/b still hold the
    # values of the finest spacing, which is used as a fallback instead of
    # failing on unbound tick arrays.
    ticks = np.arange(a * x, b * x, x)
    a, b = math.floor(lo / m), math.floor(hi / m)
    min_ticks = np.arange(a * m, b * m, m)

    # A tick at 0 has no position on the primary scale (1e7 / 0).
    ticks = ticks[ticks != 0]
    min_ticks = min_ticks[min_ticks != 0]

    sub_axis.set_ticks(1e7 / ticks)
    sub_axis.set_ticks(1e7 / min_ticks, minor=True)
    if minor_ticks:
        ax.minorticks_on()

    # Compare by value, and set the axis label text: Axis.set_label() only
    # sets the artist label used for legends.
    if unit == 'nm':
        sub_axis.set_label_text('Wavelengths [nm]')
    elif unit == 'cm':
        sub_axis.set_label_text('Wavenumber [1/cm]')
示例#19
0
文件: plt2.py 项目: Gravifer/pylabyk
def patch_chance_level(level=None,
                       signs=(-1, 1),
                       ax: plt.Axes = None,
                       xy='y',
                       color=(0.7, 0.7, 0.7)):
    """
    Add filled rectangles marking a chance-level band behind the plot.

    One rectangle is added per entry of signs. Along the axis named by xy
    it starts at 0 with extent level * sign on a linear scale, or at 1 with
    extent level * 10**sign on a log scale; along the other axis it spans
    that axis' current limits.

    :param level: size of the band; np.log(10.) when None.
    :param signs: one rectangle is drawn for each sign.
    :param ax: axes to draw on; plt.gca() when None.
    :param xy: 'y' (default) or 'x', the axis the level is measured on.
    :param color: face color of the rectangles.
    :return: list of the added Rectangle patches.
    :raises ValueError: if xy is neither 'x' nor 'y'.
    """
    if level is None:
        level = np.log(10.)
    if ax is None:
        ax = plt.gca()

    hs = []
    for sign in signs:
        if xy == 'y':
            scale = ax.yaxis.get_scale()
        elif xy == 'x':
            scale = ax.xaxis.get_scale()
        else:
            # Without this, `rect` below would be unbound.
            raise ValueError("xy must be either 'x' or 'y'.")

        if scale == 'log':
            vmin = 1.
            level1 = level * 10**sign
        else:
            vmin = 0.
            level1 = level * sign

        if xy == 'y':
            # Band measured along y, spanning the full current x range.
            lim = ax.get_xlim()
            origin = [lim[0], vmin]
            width, height = lim[1] - lim[0], level1
        else:
            # Band measured along x, spanning the full current y range.
            lim = ax.get_ylim()
            origin = [vmin, lim[0]]
            width, height = level1, lim[1] - lim[0]

        rect = mpl.patches.Rectangle(origin,
                                     width,
                                     height,
                                     linewidth=0,
                                     fc=color,
                                     zorder=-1)
        ax.add_patch(rect)
        hs.append(rect)
    return hs
def add_sh_order_lines(ax: plt.Axes,
                       order=None,
                       args_dict=None,
                       x_flag=True,
                       y_flag=True):
    """Draw red separator lines at positions derived from ``sh.nm2i(n, n) + 0.5``.

    :param ax: Axes to draw on.
    :param order: Number of lines; when ``None`` it is taken from
        ``sh.i2nm`` applied to the floored upper x limit.
        NOTE(review): presumably the spherical-harmonic order of the last
        visible index - confirm against ``sh.i2nm``.
    :param args_dict: Extra keyword arguments for ``axvline``/``axhline``.
    :param x_flag: Draw the vertical lines.
    :param y_flag: Draw the horizontal lines.
    """
    line_kwargs = {} if args_dict is None else args_dict
    from src.utils.sphere import sh

    if order is None:
        last_index = np.floor(ax.get_xlim()[1])
        order = sh.i2nm(last_index)[0]

    degrees = np.arange(order)
    # Both arguments of nm2i are the same array (n == m).
    positions = sh.nm2i(degrees, degrees) + 0.5
    for position in positions:
        if x_flag:
            ax.axvline(position, color='red', **line_kwargs)
        if y_flag:
            ax.axhline(position, color='red', **line_kwargs)
示例#21
0
def get_axes_limits(ax: Axes, ax_crs: CRS, crs: CRS=SphericalEarth):
    """ Get the limits of the window covered by an Axes in another coordinate
    system.

    Each edge of the axes window is searched with a bounded scalar minimiser
    for its most extreme transformed coordinate. Returns
    ``(xmin, xmax, ymin, ymax)`` in ``crs``.
    """
    x_left, x_right = ax.get_xlim()
    y_bottom, y_top = ax.get_ylim()

    def along_bottom(x):
        return ax_crs.transform(crs, x, y_bottom)[1]

    def along_top_negated(x):
        return -ax_crs.transform(crs, x, y_top)[1]

    def along_left(y):
        return ax_crs.transform(crs, x_left, y)[0]

    def along_right_negated(y):
        return -ax_crs.transform(crs, x_right, y)[0]

    # Lowest point of the bottom spine
    best_x = scipy.optimize.fminbound(along_bottom, x_left, x_right)
    ymin = ax_crs.transform(crs, best_x, y_bottom)[1]

    # Highest point of the top spine (minimum of the negated coordinate)
    best_x = scipy.optimize.fminbound(along_top_negated, x_left, x_right)
    ymax = ax_crs.transform(crs, best_x, y_top)[1]

    # Leftmost point of the left spine
    best_y = scipy.optimize.fminbound(along_left, y_bottom, y_top)
    xmin = ax_crs.transform(crs, x_left, best_y)[0]

    # Rightmost point of the right spine (minimum of the negated coordinate)
    best_y = scipy.optimize.fminbound(along_right_negated, y_bottom, y_top)
    xmax = ax_crs.transform(crs, x_right, best_y)[0]

    return xmin, xmax, ymin, ymax
示例#22
0
    def plot_observations(self, axes: plt.Axes):
        """Overlay reference observations and a power-law line on ``axes``.

        Draws, in this order: the line ``1.40 * r**1.1`` across the current
        x limits (labelled 'VKB (2005)'), the combined ``Pratt2010`` profile
        and the ``Sun2009`` data, the latter two as grey markers with
        asymmetric error bars.

        :param axes: Axes to draw on.
        """
        def asymmetric_error(lower, central, upper):
            # Rows are (central - lower, upper - central), the 2 x n layout
            # passed as yerr below.
            return np.array(list(zip(central - lower, upper - central))).T

        sun_observations = Sun2009()
        (r_r500, S_S500_50, S_S500_10,
         S_S500_90) = sun_observations.get_shortcut()

        rexcess = Pratt2010(n_radial_bins=21)
        bin_median, bin_perc16, bin_perc84 = rexcess.combine_entropy_profiles(
            m500_limits=(1e14 * Solar_Mass, 5e14 * Solar_Mass),
            k500_rescale=True)

        radii = np.array([*axes.get_xlim()])
        axes.plot(radii,
                  1.40 * radii**1.1,
                  c='grey',
                  ls='--',
                  label='VKB (2005)')

        # Styling shared by both observational data sets.
        shared_style = dict(markersize=2,
                            color='grey',
                            ecolor='lightgray',
                            elinewidth=0.7,
                            capsize=0)

        axes.errorbar(rexcess.radial_bins,
                      bin_median,
                      yerr=asymmetric_error(bin_perc16, bin_median,
                                            bin_perc84),
                      fmt='o',
                      label=rexcess.citation,
                      **shared_style)
        axes.errorbar(r_r500,
                      S_S500_50,
                      yerr=asymmetric_error(S_S500_10, S_S500_50, S_S500_90),
                      fmt='^',
                      label=sun_observations.citation,
                      **shared_style)
def adjust_axes_lim(ax: plt.Axes, min_x_thresh: float, max_x_thresh: float,
                    min_y_thresh: float, max_y_thresh: float):
    """
    Widen the axes limits to the given thresholds without ever shrinking them.

    A limit is replaced by its threshold only when the threshold lies at or
    beyond it (compared after rounding to 3 decimals); otherwise the current
    limit is kept and a message is printed. Padding around plot elements is
    part of the current limits, so remove it beforehand if it should not
    count.
    :param ax: Adjust range on axes object
    :param min_x_thresh: ax x_min must be at least that small.
    :param max_x_thresh: ax x_max must be at least that big.
    :param min_y_thresh: ax y_min must be at least that small.
    :param max_y_thresh: ax y_max must be at least that big.
    """
    y_low, y_high = ax.get_ylim()
    x_low, x_high = ax.get_xlim()

    def lower_bound(current, thresh, message):
        # Take the threshold unless the axis already reaches further down.
        if round(current, 3) >= round(thresh, 3):
            return thresh
        print(message, current, 'instead of', thresh)
        return current

    def upper_bound(current, thresh, message):
        # Take the threshold unless the axis already reaches further up.
        if round(current, 3) <= round(thresh, 3):
            return thresh
        print(message, current, 'instead of', thresh)
        return current

    y_low = lower_bound(y_low, min_y_thresh, 'min y was set to')
    y_high = upper_bound(y_high, max_y_thresh, 'max y was set to')
    x_low = lower_bound(x_low, min_x_thresh, 'min x was set to')
    x_high = upper_bound(x_high, max_x_thresh, 'max x was set to')

    ax.set_ylim(y_low, y_high)
    ax.set_xlim(x_low, x_high)
def set_new_axlim(axs: plt.Axes,
                  data: np.ndarray,
                  err: _ty.Optional[np.ndarray] = None,
                  *,
                  yaxis: bool = True,
                  reset: bool = False,
                  buffer: float = 0.05) -> None:
    """Set limits on one axis from `calc_axlim`, optionally keeping the old range.

    Unless `reset` is true, the new limits are merged with the current ones,
    so the visible range can only grow.

    Parameters
    ----------
    axs : mpl.axes.Axes
        `Axes` instance whose limits are to be adjusted.
    data : np.ndarray (n)
        Array of numbers plotted along the axis.
    err : None|np.ndarray (n,)|(2,n), optional
        Error bars for data, by default: `None`.
    yaxis : bool, optional
        Are we modifying the y axis? By default: `True`.
    reset : bool, optional
        Do we ignore the existing axis limits? By default: `False`.
    buffer : float, optional
        Fractional padding around data, by default `0.05`.
    """
    if yaxis:
        scale, read_lim, write_lim = (axs.get_yscale(), axs.get_ylim,
                                      axs.set_ylim)
    else:
        scale, read_lim, write_lim = (axs.get_xscale(), axs.get_xlim,
                                      axs.set_xlim)

    new_lim = calc_axlim(data, err, scale == 'log', buffer)
    if not reset:
        current = read_lim()
        new_lim = min(new_lim[0], current[0]), max(new_lim[1], current[1])
    write_lim(new_lim)
示例#25
0
def add_percent_axis(ax: plt.Axes, data_size, flip_axis: bool = False) -> plt.Axes:
    """
    Add a twin axis that shows the counts of ``ax`` as percentages.

    The twin's ticks and limits are those of ``ax`` scaled by
    ``100 / data_size``, and its grid is switched off.

    Args:
        ax: Plot axes to add the percentage axis to
        data_size: Total count used to normalize to percentages
        flip_axis: Whether the count plot had its axes flipped; if so the
            percentages go on a twin x axis, otherwise on a twin y axis

    Returns:
        Twin axis that percentages were added to
    """
    if flip_axis:
        ax_perc = ax.twiny()
        ax_perc.set_xticks(100 * ax.get_xticks() / data_size)
        low, high = ax.get_xlim()
        ax_perc.set_xlim((
            100.0 * (float(low) / data_size),
            100.0 * (float(high) / data_size),
        ))
        perc_axis = ax_perc.xaxis
    else:
        ax_perc = ax.twinx()
        ax_perc.set_yticks(100 * ax.get_yticks() / data_size)
        low, high = ax.get_ylim()
        ax_perc.set_ylim((
            100.0 * (float(low) / data_size),
            100.0 * (float(high) / data_size),
        ))
        perc_axis = ax_perc.yaxis

    perc_axis.set_major_formatter(mtick.PercentFormatter())
    perc_axis.set_tick_params(labelsize=10)
    ax_perc.grid(False)
    return ax_perc
示例#26
0
def _add_doubling_time_lines(
    fig: plt.Figure,
    ax: plt.Axes,
    *,
    stage: DiseaseStage,
    count: Counting,
    x_axis: Columns.XAxis,
) -> None:
    """Add doubling time lines to the given plot

    On a log-scale graph, doubling time lines originate from a point near the
    lower-left and show how fast the number of cases (per capita) would grow if it
    doubled every n days.

    Nothing is drawn unless ``x_axis`` is ``Columns.XAxis.DAYS_SINCE_OUTBREAK``.
    When lines are drawn, the x limits are pinned to their current values, the
    upper y limit is raised to 1.1 in axes coordinates, and one dashed line plus
    one rotated text label is added per doubling time.

    :param fig: The figure containing the plots; its canvas renderer is used to
    measure the size of each text label
    :type fig: plt.Figure
    :param ax: The axes object we wish to annotate
    :type ax: plt.Axes
    :param x_axis: The column to be used for the x-axis of the graph. We only add
    doubling time lines for graphs plotted against days since outbreak (and not actual
    days, as doubling time lines don't make sense then because there is no common
    origin to speak of)
    :type x_axis: Columns.XAxis
    :param stage: The disease stage we are plotting
    :type stage: DiseaseStage
    :param count: The count method used
    :type count: Counting
    """

    DiseaseStage.verify(stage)
    Counting.verify(count)
    Columns.XAxis.verify(x_axis)

    # For ease of computation, everything will be in axes coordinate system
    # Variable names beginning with "ac" refer to axis coords and "dc" to data coords
    # {ac,dc}_{x,y}_{min,max} refer to the coordinates of the doubling-time lines
    if x_axis is Columns.XAxis.DAYS_SINCE_OUTBREAK:
        # Create transformation from data coords to axes coords
        # This composes two transforms, data -> fig and (axes -> fig)^(-1)
        dc_to_ac = ax.transData + ax.transAxes.inverted()

        dc_x_lower_lim, dc_x_upper_lim = ax.get_xlim()
        dc_y_lower_lim, dc_y_upper_lim = ax.get_ylim()

        # Adding stuff causes the axis to resize itself, and we have to stop it
        # from doing so (by setting it back to its original size)
        ax.set_xlim(dc_x_lower_lim, dc_x_upper_lim)

        # Also need to add back margin: the new upper y limit is the data value
        # that currently sits at y = 1.1 in axes coords (10% above the top edge)
        dc_y_upper_lim = dc_to_ac.inverted().transform((0, 1.1))[1]
        ax.set_ylim(dc_y_lower_lim, dc_y_upper_lim)

        # Getting min x,y bounds of lines is easy: every line starts at day 0 and
        # at the threshold value looked up for this stage/count combination
        dc_x_min = 0
        dc_y_min = CaseInfo.get_info_item_for(InfoField.THRESHOLD,
                                              stage=stage,
                                              count=count)

        ac_x_min, ac_y_min = dc_to_ac.transform((dc_x_min, dc_y_min))

        # Getting max x,y bounds is trickier due to needing to use the maximum
        # extent of the graph area
        # The top right corner of the graph is (1, 1) in axes coords; the padding
        # that keeps the texts' boxes from clipping the axes is added further down
        ac_x_upper_lim = ac_y_upper_lim = 1

        doubling_times = [1, 2, 3, 4, 7, 14]  # days (x-axis units)
        for dt in doubling_times:
            # Simple math: assuming dc_y_max := dc_y_upper_lim, then if
            # dc_y_max = dc_y_min * 2**((dc_x_max-dc_x_min)/dt),
            # then...
            dc_x_max = dc_x_min + dt * np.log2(dc_y_upper_lim / dc_y_min)
            ac_x_max, ac_y_max = dc_to_ac.transform((dc_x_max, dc_y_upper_lim))

            # We try to use ac_y_max=1 by default, and if that leads to too long a line
            # (sticking out through the right side of the graph) then we use ac_x_max=1
            # instead and compute ac_y_max accordingly
            if ac_x_max > ac_x_upper_lim:
                dc_y_max = dc_y_min * 2**((dc_x_upper_lim - dc_x_min) / dt)
                ac_x_max, ac_y_max = dc_to_ac.transform(
                    (dc_x_upper_lim, dc_y_max))
                edge = EdgeGuide.RIGHT
            else:
                edge = EdgeGuide.TOP

            # Plot the lines themselves (endpoints are given in axes coords, hence
            # the explicit transform)
            ax.plot(
                [ac_x_min, ac_x_max],
                [ac_y_min, ac_y_max],
                transform=ax.transAxes,
                color="0.0",
                alpha=0.7,
                dashes=(1, 2),
                linewidth=1,
            )

            # Annotate lines with associated doubling times

            # Get text to annotate with: whole multiples of 7 days are labelled in
            # weeks, everything else in days (with a plural "s" where needed)
            n_weeks, weekday = divmod(dt, 7)
            if weekday == 0:
                annot_text_str = f"{n_weeks} week"
                if n_weeks != 1:
                    annot_text_str += "s"
            else:
                annot_text_str = f"{dt} day"
                if dt != 1:
                    annot_text_str += "s"

            # Semi-transparent white box without border behind each label
            text_props = {
                "bbox": {
                    "fc": "1.0",
                    "pad": 0,
                    # "edgecolor": "1.0",
                    "alpha": 0.7,
                    "lw": 0,
                }
            }

            # Plot in a temporary location just to get the text box size; we'll move and
            # rotate later
            plotted_text = ax.text(0,
                                   0,
                                   annot_text_str,
                                   text_props,
                                   transform=ax.transAxes)

            # Slope and angle of the line as it appears in axes coords
            ac_line_slope = (ac_y_max - ac_y_min) / (ac_x_max - ac_x_min)
            ac_text_angle_rad = np.arctan(ac_line_slope)
            sin_ac_angle = np.sin(ac_text_angle_rad)
            cos_ac_angle = np.cos(ac_text_angle_rad)

            # Get the unrotated text box bounds
            ac_text_box = plotted_text.get_window_extent(
                fig.canvas.get_renderer()).transformed(ax.transAxes.inverted())
            ac_text_width = ac_text_box.x1 - ac_text_box.x0
            ac_text_height = ac_text_box.y1 - ac_text_box.y0

            # Compute the width and height of the upright rectangle bounding the rotated
            # text box in axis coordinates
            # Simple geometry (a decent high school math problem)
            # We cheat a bit; to create some padding between the rotated text box and
            # the axes, we can add the padding directly to the width and height of the
            # upright rectangle bounding the rotated text box
            # This works because the origin of the rotated text box is in the lower left
            # corner of the upright bounding rectangle, so anything added to these
            # dimensions gets added to the top and right, pushing it away from the axes
            # and producing the padding we want
            # If we wanted to do this the "right" way we'd *redo* the calculations above
            # but with ac_x_upper_lim = ac_y_upper_lim = 1 - padding
            PADDING = 0.005
            ac_rot_text_width = ((ac_text_width * cos_ac_angle) +
                                 (ac_text_height * sin_ac_angle) + PADDING)
            ac_rot_text_height = ((ac_text_width * sin_ac_angle) +
                                  (ac_text_height * cos_ac_angle) + PADDING)

            # Perpendicular distance from text to corresponding line
            AC_DIST_FROM_LINE = 0.005
            # Get text box origin relative to line upper endpoint
            EdgeGuide.verify(edge)
            if edge is EdgeGuide.RIGHT:
                # Account for bit of overhang; when slanted, top left corner of the
                # text box extends left of the bottom left corner, which is its origin
                # Subtracting that bit of overhang (height * sin(theta)) gets us the
                # x-origin
                # This only applies to the x coord; the bottom left corner of the text
                # box is also the bottom of the rotated rectangle
                ac_text_origin_x = ac_x_max - (ac_rot_text_width -
                                               ac_text_height * sin_ac_angle)
                ac_text_origin_y = (
                    ac_y_min + (ac_text_origin_x - ac_x_min) * ac_line_slope +
                    AC_DIST_FROM_LINE / cos_ac_angle)

            # If text box is in very top right of graph, it may use only the right
            # edge of the graph as a guide and hence clip through the top; if that
            # happens, it's effectively the same situation as using the top edge from
            # the start
            # ac_text_origin_y is only assigned in the RIGHT branch above, so the TOP
            # test has to come first to keep the second operand from being evaluated
            if (edge is EdgeGuide.TOP  # Must go first to short-circuit
                    or ac_text_origin_y + ac_rot_text_height > ac_y_upper_lim):
                ac_text_origin_y = ac_y_upper_lim - ac_rot_text_height
                ac_text_origin_x = (
                    ac_x_min - AC_DIST_FROM_LINE / sin_ac_angle +
                    (ac_text_origin_y - ac_y_min) / ac_line_slope)

            # set_x and set_y work in axis coordinates (the text was created with
            # transform=ax.transAxes)
            plotted_text.set_x(ac_text_origin_x)
            plotted_text.set_y(ac_text_origin_y)
            plotted_text.set_horizontalalignment("left")
            plotted_text.set_verticalalignment("bottom")
            plotted_text.set_rotation(ac_text_angle_rad * 180 /
                                      np.pi)  # takes degrees
            plotted_text.set_rotation_mode("anchor")
示例#27
0
def draw_categorical(
    plot_type: str,
    ax: plt.Axes,
    data: Union[list, np.ndarray, to.Tensor, pd.DataFrame],
    x_label: Optional[Union[str, Sequence[str]]],
    y_label: Optional[str],
    vline_level: float = None,
    vline_label: str = "approx. solved",
    palette=None,
    title: str = None,
    show_legend: bool = True,
    legend_kwargs: dict = None,
    plot_kwargs: dict = None,
) -> plt.Figure:
    """
    Create a box or violin plot for a list of data arrays or a pandas DataFrame.
    The plot is neither shown nor saved.

    If you want to order the 4th element to the 2nd position in terms of colors use

    .. code-block:: python

        palette.insert(1, palette.pop(3))

    .. note::
        If you want to have a tight layout, it is best to pass axes of a figure with `tight_layout=True` or
        `constrained_layout=True`.

    :param plot_type: type of categorical plot, pass box or violin (case-insensitive)
    :param ax: axis of the figure to plot on
    :param data: list of data sets to plot as separate boxes
    :param x_label: labels for the categories on the x-axis, if `data` is not given as a `DataFrame`. Pass `None` to
                    remove the x-axis ticks
    :param y_label: label for the y-axis, pass `None` to set no label
    :param vline_level: if not `None` (default) add a dashed line at the given level. Despite the parameter name, the
                        line is drawn with `axhline`, i.e. horizontally at this y-value
    :param vline_label: legend label for that line
    :param palette: seaborn color palette, pass `None` to use the default palette (only forwarded to the violin plot)
    :param show_legend: if `True` the legend is shown, useful when handling multiple subplots
    :param title: title displayed above the figure, set to None to suppress the title
    :param legend_kwargs: keyword arguments forwarded to pyplot's `legend()` function, e.g. `loc='best'`
    :param plot_kwargs: keyword arguments forwarded to seaborn's `boxplot()` or `violinplot()` function. An `alpha`
                        entry is removed and applied to the drawn patches afterwards
    :return: handle to the figure that contains `ax`
    """
    plot_type = plot_type.lower()
    if plot_type not in ["box", "violin"]:
        raise pyrado.ValueErr(given=plot_type, eq_constraint="box or violin")
    if not isinstance(data, (list, to.Tensor, np.ndarray, pd.DataFrame)):
        raise pyrado.TypeErr(
            given=data,
            expected_type=[list, to.Tensor, np.ndarray, pd.DataFrame])

    # Set defaults which can be overwritten
    plot_kwargs = merge_dicts([dict(alpha=1),
                               plot_kwargs])  # by default no transparency
    # The transparency is applied manually after plotting, so take it out of the forwarded kwargs
    alpha = plot_kwargs.pop("alpha")
    legend_kwargs = dict() if legend_kwargs is None else legend_kwargs
    palette = sns.color_palette() if palette is None else palette

    # Preprocess: bring everything that is not a DataFrame into a 2-dim numpy array first
    if isinstance(data, pd.DataFrame):
        df = data
    else:
        if isinstance(data, list):
            data = np.array(data)
        elif isinstance(data, to.Tensor):
            data = data.detach().cpu().numpy()
        if x_label is not None and len(x_label) != data.shape[1]:
            raise pyrado.ShapeErr(given=data, expected_match=x_label)
        df = pd.DataFrame(data, columns=x_label)

    if data.shape[0] < data.shape[1]:
        print_cbt(
            f"Less data samples {data.shape[0]} then data dimensions {data.shape[1]}",
            "y",
            bright=True)

    # Plot
    if plot_type == "box":
        ax = sns.boxplot(data=df, ax=ax, **plot_kwargs)

    elif plot_type == "violin":
        # NOTE(review): these defaults put an `alpha` entry back into the kwargs forwarded to seaborn, although it
        # was popped above -- confirm that the installed seaborn version accepts it
        plot_kwargs = merge_dicts([
            dict(alpha=0.3, scale="count", inner="box", bw=0.3, cut=0),
            plot_kwargs
        ])
        ax = sns.violinplot(data=df, ax=ax, palette=palette, **plot_kwargs)

        # Plot larger circles for medians (need to memorize the limits since the scatter call may change them)
        medians = df.median().to_numpy()
        left, right = ax.get_xlim()
        locs = ax.get_xticks()
        ax.scatter(locs,
                   medians,
                   marker="o",
                   s=30,
                   zorder=3,
                   color="white",
                   edgecolors="black")
        ax.set_xlim((left, right))

    # Postprocess: apply the requested transparency to the drawn shapes
    if alpha < 1 and plot_type == "box":
        for patch in ax.artists:
            r, g, b, _ = patch.get_facecolor()
            patch.set_facecolor((r, g, b, alpha))
    elif alpha < 1 and plot_type == "violin":
        for violin in ax.collections[::2]:
            violin.set_alpha(alpha)

    if vline_level is not None:
        # Add dashed line to mark a threshold
        ax.axhline(vline_level, c="k", ls="--", lw=1.0, label=vline_label)

    if x_label is None:
        ax.get_xaxis().set_ticks([])

    if y_label is not None:
        ax.set_ylabel(y_label)

    if show_legend:
        ax.legend(**legend_kwargs)

    if title is not None:
        ax.set_title(title)

    # Return the figure that owns the axis we drew on; `plt.gcf()` would hand back an unrelated figure whenever
    # `ax` does not belong to the currently active one
    return ax.get_figure()
示例#28
0
def _level_number_variance(
    unfolded: ndarray,
    data: pd.DataFrame,
    title: str = "Level Number Variance",
    mode: PlotMode = "block",
    outfile: Path = None,
    ensembles: List[str] = ["poisson", "goe", "gue", "gse"],
    fig: Figure = None,
    axes: Axes = None,
) -> PlotResult:
    """Plot the computed level number variance against the various expected number
    level variance curves for the classical ensembles.

    Parameters
    ----------
    unfolded: ndarray
        The unfolded eigenvalues. Not read by this function; kept in the
        signature for callers.

    data: DataFrame
        `data` argument is pd.DataFrame({"L": L_vals, "sigma": sigma}), where sigma
        are the values computed from
        observables.levelvariance.level_number_variance

    title: string
        The plot title string

    mode: "block" (default) | "noblock" | "save" | "return"
        If "block", call plot.plot() and display plot in a blocking fashion.
        If "noblock", attempt to generate plot in nonblocking fashion.
        If "save", save plot to pathlib Path specified in `outfile` argument
        If "return", return (fig, axes), the matplotlib figure and axes object
        for modification.

    outfile: Path
        If mode="save", save generated plot to Path specified in `outfile` argument.
        Intermediate directories will be created if needed.

    ensembles: ["poisson", "goe", "gue", "gse"]
        Which ensembles to display the expected number level variance curves for
        comparison against. The argument is only read, never modified.

    fig: Figure
        If provided with `axes`, configure plotting with the provided `fig`
        object instead of creating a new figure. Useful for creating subplots.

    axes: Axes
        If provided with `fig`, plot to the provided `axes` object. Useful for
        creating subplots.


    Returns
    -------
    (fig, axes): (Figure, Axes)
        The handles to the matplotlib objects, only if `mode` is "return".
    """
    _configure_sbn_style()
    fig, axes = _setup_plotting(fig, axes)
    df = pd.DataFrame(data, columns=["L", "sigma"])
    sbn.scatterplot(x="L", y="sigma", data=df, ax=axes)
    selected = set(ensembles)

    L = df["L"]
    p, y = np.pi, np.euler_gamma
    # The L values are used as-is (no rescaling by the mean level spacing)
    s = L

    def exact(x: float) -> float:
        """Evaluate the GOE expression numerically (used below for L < 10)."""

        def f1(r: float) -> Any:
            return (np.sin(r) / r) ** 2

        # re-arranging the formula for sici from
        # https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.sici.html
        # to match Mehta (2004) p595, A.38, we get:
        int_2 = y + np.log(2 * p * x) - sici(2 * p * x)[1]
        # sici(x) returns the pair (sine integral, cosine integral) evaluated at x
        int_3 = (sici(np.inf)[0] - sici(p * x)[0]) ** 2
        t1 = 4 * x / p
        t2 = 2 / p ** 2
        t3 = t2 / 2
        res = (
            t1 * quad(f1, p * x, np.inf, limit=100)[0] + t2 * int_2 - 0.25 + t3 * int_3
        )
        return float(res)

    if "poisson" in selected:
        # NOTE(review): the textbook Poisson number variance is L; the factor 1/2
        # is kept exactly as before -- confirm it is intended.
        poisson = L / 2  # too large very often
        poisson = axes.plot(L, poisson, label="Poisson")
        plt.setp(poisson, color="#08FD4F")
    if "goe" in selected:
        goe = np.zeros(s.shape)
        with warnings.catch_warnings():  # ignore integration, divide-by-zero warnings
            warnings.simplefilter("ignore")
            # Compare the iterated value itself rather than `L[i]`: `L` is a pandas
            # Series, so `L[i]` is a *label* lookup that fails (or hits the wrong
            # row) as soon as `data` carries a non-default index.
            for i, s_val in enumerate(s):
                if s_val < 10:
                    goe[i] = exact(s_val)
                else:
                    # large-L closed form
                    goe[i] = (2 / (p ** 2)) * (
                        np.log(2 * p * s_val) + y + 1 - (p ** 2) / 8
                    )
        goe = axes.plot(L, goe, label="Gaussian Orthogonal")
        plt.setp(goe, color="#FD8208")
    if "gue" in selected:
        gue = (1 / (p ** 2)) * (np.log(2 * p * s) + y + 1)
        gue = axes.plot(L, gue, label="Gaussian Unitary")
        plt.setp(gue, color="#0066FF")
    if "gse" in selected:
        gse = (1 / (2 * (p ** 2))) * (np.log(4 * p * s) + y + 1 + (p ** 2) / 8)
        gse = axes.plot(L, gse, label="Gaussian Symplectic")
        plt.setp(gse, color="#EA00FF")

    axes.set(
        title=title,
        xlabel="L",
        ylabel="Sigma^2(L)",
    )
    axes.legend().set_visible(True)
    return _handle_plot_mode(mode, fig, axes, outfile)
示例#29
0
def _spectral_rigidity(
    unfolded: Optional[ndarray],
    data: pd.DataFrame,
    title: str = "Spectral Rigidity",
    mode: PlotMode = "block",
    outfile: Path = None,
    ensembles: List[str] = ["poisson", "goe", "gue", "gse"],
    fig: Figure = None,
    axes: Axes = None,
) -> PlotResult:
    """Plot the computed spectral rigidity against the various expected spectral
    rigidity curves for the classical ensembles.

    Parameters
    ----------
    unfolded: ndarray, optional
        The unfolded eigenvalues. If given, the L values are divided by the mean
        difference of consecutive entries before being inserted into the
        GOE/GUE/GSE formulas; if `None`, the L values are used directly.

    data: DataFrame
        Table with the columns "L" (x values) and "delta" (computed spectral
        rigidity at each L), e.g. pd.DataFrame({"L": L_vals, "delta": delta3}).

    title: string
        The plot title string

    mode: "block" (default) | "noblock" | "save" | "return"
        If "block", call plot.plot() and display plot in a blocking fashion.
        If "noblock", attempt to generate plot in nonblocking fashion.
        If "save", save plot to pathlib Path specified in `outfile` argument
        If "return", return (fig, axes), the matplotlib figure and axes object
        for modification.

    outfile: Path
        If mode="save", save generated plot to Path specified in `outfile` argument.
        Intermediate directories will be created if needed.

    ensembles: ["poisson", "goe", "gue", "gse"]
        Which ensembles to display the expected spectral rigidity curves for
        comparison against. The argument is only read, never modified.

    fig: Figure
        If provided with `axes`, configure plotting with the provided `fig`
        object instead of creating a new figure. Useful for creating subplots.

    axes: Axes
        If provided with `fig`, plot to the provided `axes` object. Useful for
        creating subplots.


    Returns
    -------
    (fig, axes): (Figure, Axes)
        The handles to the matplotlib objects, only if `mode` is "return".
    """
    _configure_sbn_style()
    fig, axes = _setup_plotting(fig, axes)
    df = pd.DataFrame(data, columns=["L", "delta"])
    sbn.scatterplot(x="L", y="delta", data=df, ax=axes)
    selected = set(ensembles)

    L = df["L"]
    p, y = np.pi, np.euler_gamma

    # see pg 290 of Mehta (2004) for definition of s
    s = L / np.mean(unfolded[1:] - unfolded[:-1]) if unfolded is not None else L

    # (key, legend label, line color, curve). The curves are callables so that a
    # formula is only evaluated for ensembles that were actually requested; the
    # tuple order fixes the drawing (and therefore legend) order.
    # NOTE(review): the textbook Poisson rigidity is L / 15; the extra factor 1/2
    # is kept exactly as before -- confirm it is intended.
    expected_curves = (
        ("poisson", "Poisson", "#08FD4F",
         lambda: L / 15 / 2),
        ("goe", "Gaussian Orthogonal", "#FD8208",
         lambda: (1 / (p ** 2)) * (np.log(2 * p * s) + y - 5 / 4 - (p ** 2) / 8)),
        ("gue", "Gaussian Unitary", "#0066FF",
         lambda: (1 / (2 * (p ** 2))) * (np.log(2 * p * s) + y - 5 / 4)),
        ("gse", "Gaussian Symplectic", "#EA00FF",
         lambda: (1 / (4 * (p ** 2))) * (np.log(4 * p * s) + y - 5 / 4 + (p ** 2) / 8)),
    )
    for key, label, color, curve in expected_curves:
        if key not in selected:
            continue
        lines = axes.plot(L, curve(), label=label)
        plt.setp(lines, color=color)

    axes.set(title=title, xlabel="L", ylabel="∆3(L)")
    axes.legend().set_visible(True)
    return _handle_plot_mode(mode, fig, axes, outfile)
示例#30
0
def label_ticks(xs: Iterable[float], ys: Iterable[float], ax: Axes=None,
        map_crs: CRS=Cartesian, graticule_crs: CRS=SphericalEarth,
        textargs=None, tickargs=None,
        xformatter=None, yformatter=None):
    """ Label graticule lines, returning a list of Text objects.

    For every spine of the axes (bottom, top, left, right) and every requested
    graticule value that lies between the two corners of that spine, the
    position along the spine where the graticule line crosses it is found with
    a root search; a tick marker and a text label are drawn there. Finally the
    regular matplotlib ticks of both axes are removed.

    Parameters
    ----------
    xs : Iterable[float],
    ys : Iterable[float]
        Easting and northing components of labels, in `graticule_crs`. Any
        iterable is accepted, including one-shot generators.
    ax : Axes, optional
        Axes to draw to (default current Axes)
    map_crs : karta.crs.CRS, optional
        CRS giving the display projection (default Cartesian)
    graticule_crs : karta.crs.CRS, optional
        CRS giving the graticule/label projection (default SphericalEarth)
    textargs : dict, optional
        Keyword arguments to pass to plt.text
    tickargs : dict, optional
        Keyword arguments to pass to plt.plot
    xformatter : callable, optional
        function that given an easting/longitude returns a label
    yformatter : callable, optional
        function that given a northing/latitude returns a label

    Returns
    -------
    list
        The created Text objects: all easting labels first, then all northing
        labels, each group ordered bottom, top, left, right spine.
    """
    if ax is None:
        # Honour the documented default. Imported locally so this function does
        # not depend on how the surrounding module imports pyplot.
        from matplotlib.pyplot import gca
        ax = gca()

    if textargs is None:
        textargs = dict()

    if tickargs is None:
        tickargs = dict(marker="+", mew=2, ms=14, mfc="k", mec="k", ls="none")

    if xformatter is None:
        xformatter = lambda x: "{0} E".format(x)

    if yformatter is None:
        yformatter = lambda y: "{0} N".format(y)

    # Both coordinate sets are scanned once per spine, so a one-shot iterable
    # (e.g. a generator) must be materialized first.
    xs = list(xs)
    ys = list(ys)

    # Corners of the axes in graticule coordinates; consecutive corners bound
    # the bottom, right and top spines, and the last and first bound the left.
    bbox = get_axes_extent(ax, map_crs, graticule_crs)

    xmin, xmax = sorted(ax.get_xlim())
    ymin, ymax = sorted(ax.get_ylim())

    def crossing(component, target, fixed, along_x):
        """ Return the position along a spine at which graticule component
        `component` (0: easting, 1: northing) equals `target`. `fixed` is the
        map coordinate that is constant on that spine. """
        if along_x:
            def residual(xt):
                return map_crs.transform(graticule_crs, xt, fixed)[component] - target
            return froot(residual, xmin, xmax)

        def residual(yt):
            return map_crs.transform(graticule_crs, fixed, yt)[component] - target
        return froot(residual, ymin, ymax)

    # (first corner, second corner, fixed map coordinate, spine runs along x?)
    spines = (
        (0, 1, ymin, True),     # bottom spine
        (2, 3, ymax, True),     # top spine
        (0, 3, xmin, False),    # left spine
        (1, 2, xmax, False),    # right spine
    )

    # Find tick locations as (x position, y position, label text)
    xticks, yticks = [], []
    graticules = (
        (0, xs, xformatter, xticks),
        (1, ys, yformatter, yticks),
    )
    for corner_a, corner_b, fixed, along_x in spines:
        for component, values, formatter, found in graticules:
            for value in values:
                if not isbetween(value, bbox[corner_a][component],
                                 bbox[corner_b][component]):
                    continue
                position = crossing(component, value, fixed, along_x)
                if along_x:
                    found.append((position, fixed, formatter(value)))
                else:
                    found.append((fixed, position, formatter(value)))

    # Update map
    txts = []
    for tick_x, tick_y, label in xticks + yticks:
        ax.plot(tick_x, tick_y, **tickargs)
        txts.append(ax.text(tick_x, tick_y, label, **textargs))

    ax.set_xticks([])
    ax.set_yticks([])
    return txts
示例#31
0
    def overlay_entropy_profiles(self,
                                 axes: plt.Axes = None,
                                 r_units: str = 'r500',
                                 k_units: str = 'K500adi',
                                 vkb05_line: bool = True,
                                 color: str = 'k',
                                 alpha: float = 1.,
                                 markersize: float = 1,
                                 linewidth: float = 0.5) -> None:
        """
        Draw the median entropy at six characteristic radii, with 16th-84th
        percentile error bars in both directions, onto an axes.

        The points correspond to the attribute pairs ``r_*`` / ``K_*`` for the
        suffixes ``_500``, ``_1000``, ``_1500``, ``_2500``, ``_0p15r500`` and
        ``_30kpc`` (the last two radii are derived from ``r_500 * 0.15`` and a
        constant 30 kpc).

        :param axes: axes to draw on. If ``None``, a new log-log figure is
            created, labelled, given a legend and shown with ``plt.show()``.
        :param r_units: radial unit, one of ``'r500'``, ``'r2500'``, ``'kpc'``.
        :param k_units: entropy unit, ``'K500adi'`` (values divided by
            ``self.K_500_adi``) or ``'keVcm^2'`` (values left unscaled; a grey
            band spanning the range of ``self.K_500_adi`` is drawn instead).
        :param vkb05_line: draw the ``1.40 * r**1.1 / self.hconv`` line; only
            done when ``r_units == 'r500'`` and ``k_units == 'K500adi'``,
            otherwise an explanatory message is printed.
        :param color: color of points, error bars and guide lines when drawing
            on a supplied axes.
        :param alpha: opacity used together with ``color``.
        :param markersize: marker size of the points.
        :param linewidth: line width of the error bars.
        :raises ValueError: if ``k_units`` or ``r_units`` is not recognised.
        """

        def _percentile_summary(values):
            # (16th, 50th, 84th) NaN-aware percentiles and the number of non-NaN entries
            summary = (np.nanpercentile(values, 16),
                       np.nanpercentile(values, 50),
                       np.nanpercentile(values, 84))
            return summary, np.count_nonzero(~np.isnan(values))

        stand_alone = axes is None
        if stand_alone:
            _, axes = plt.subplots()
            axes.loglog()
            axes.set_xlabel(f'$r$ [{r_units}]')
            axes.set_ylabel(f'$K$ [${k_units}$]')
            axes.axvline(1, linestyle=':', color=color, alpha=alpha)

        suffixes = ['_500', '_1000', '_1500', '_2500', '_0p15r500', '_30kpc']

        # Set-up entropy data
        if k_units == 'K500adi':
            K_conv = 1 / self.K_500_adi
            axes.axhline(1, linestyle=':', color=color, alpha=alpha)
        elif k_units == 'keVcm^2':
            K_conv = np.ones_like(self.K_500_adi)
            axes.fill_between(np.array(axes.get_xlim()),
                              y1=np.nanmin(self.K_500_adi),
                              y2=np.nanmax(self.K_500_adi),
                              facecolor='k',
                              alpha=0.3)
        else:
            raise ValueError("Conversion unit unknown.")

        K_stat = dict()
        for suffix in suffixes:
            data = np.multiply(getattr(self, 'K' + suffix), K_conv)
            K_stat['K' + suffix], K_stat['num' + suffix] = _percentile_summary(data)

        # Set-up radial distance data
        if r_units == 'r500':
            r_conv = 1 / self.r_500
        elif r_units == 'r2500':
            r_conv = 1 / self.r_2500
        elif r_units == 'kpc':
            r_conv = np.ones_like(self.r_2500)
        else:
            raise ValueError("Conversion unit unknown.")

        def _radial_summary(radii):
            # Convert to the requested unit; in scaled-entropy mode, entries
            # without a valid K_500_adi are excluded from the statistics.
            data = np.multiply(radii, r_conv)
            if k_units == 'K500adi':
                data[np.isnan(self.K_500_adi)] = np.nan
            return _percentile_summary(data)

        r_stat = dict()
        for suffix in ['_500', '_1000', '_1500', '_2500']:
            r_stat['r' + suffix], r_stat['num' + suffix] = _radial_summary(
                getattr(self, 'r' + suffix))
        r_stat['r_0p15r500'], r_stat['num_0p15r500'] = _radial_summary(
            self.r_500 * 0.15)
        r_stat['r_30kpc'], r_stat['num_30kpc'] = _radial_summary(
            np.ones_like(self.r_2500) * 30 * unyt.kpc)

        for suffix in suffixes:
            x_low, x, x_hi = r_stat['r' + suffix]
            y_low, y, y_hi = K_stat['K' + suffix]
            num_objects = f"{r_stat['num' + suffix]}, {K_stat['num' + suffix]}"
            point_label = f"r{suffix:.<17s} Num(x,y) = {num_objects}"

            # Matplotlib expects asymmetric errors as [[lower], [upper]]: the
            # distance down to the 16th percentile comes first, the distance up
            # to the 84th percentile second.
            xerr = [[x - x_low], [x_hi - x]]
            yerr = [[y - y_low], [y_hi - y]]

            if stand_alone:
                axes.scatter(x, y, label=point_label, s=markersize)
                axes.errorbar(x,
                              y,
                              yerr=yerr,
                              xerr=xerr,
                              ls='none',
                              ms=markersize,
                              lw=linewidth)
            else:
                axes.scatter(x, y, color=color, alpha=alpha, s=markersize)
                axes.errorbar(x,
                              y,
                              yerr=yerr,
                              xerr=xerr,
                              ls='none',
                              ecolor=color,
                              alpha=alpha,
                              ms=markersize,
                              lw=linewidth)

        if vkb05_line:
            if r_units == 'r500' and k_units == 'K500adi':
                r = np.linspace(*axes.get_xlim(), 31)
                k = 1.40 * r**1.1 / self.hconv
                axes.plot(r, k, linestyle='--', color=color, alpha=alpha)
            else:
                print((
                    "The VKB05 adiabatic threshold should be plotted only when both "
                    "axes are in scaled units, since the line is calibrated on an NFW "
                    "profile with self-similar halos with an average concentration of "
                    "c_500 ~ 4.2 for the objects in the Sun et al. (2009) sample."
                ))

        if k_units == 'K500adi':
            r_r500, S_S500_50, S_S500_10, S_S500_90 = self.get_shortcut()

            # Draw on the axes we were given; the pyplot state machine would
            # target whatever axes happens to be current instead.
            axes.fill_between(r_r500,
                              S_S500_10,
                              S_S500_90,
                              color='grey',
                              alpha=0.5,
                              linewidth=0)
            axes.plot(r_r500, S_S500_50, c='k')

        if stand_alone:
            axes.legend()
            plt.show()
示例#32
0
def plot_p_ch_vs_ev(
        ev_cond: Union[torch.Tensor, np.ndarray],
        n_ch: Union[torch.Tensor, np.ndarray],
        style='pred',
        ax: plt.Axes = None,
        dim_rel=0,
        group_dcond_irr: Iterable[Iterable[int]] = None,
        cmap: Union[str, Callable] = 'cool',
        kw_plot=(),
) -> tuple:
    """
    Plot the proportion of choice 1 on the relevant dimension against the
    evidence on that dimension, one curve per (absolute) condition on the
    other dimension.

    @param ev_cond: [condition, dim] or [condition, frame, dim, (mean, var)]
    @type ev_cond: torch.Tensor
    @param n_ch: [condition, ch] or [condition, rt_frame, ch]
    @type n_ch: torch.Tensor
    @param style: forwarded to get_kw_plot() to pick the plotting keywords
    @param ax: axes to draw on; plt.gca() is used when None
    @param dim_rel: index of the dimension used for the x axis and whose
        choices are counted; the other dimension comes from
        consts.get_odim(dim_rel)
    @param group_dcond_irr: if not None, forwarded to group_conds() to regroup
        the conditions of the other dimension before plotting
    @param cmap: either a colormap name, resolved with
        plt.get_cmap(cmap, n_curves), or a callable such that
        cmap(n_curves)(i) gives the color of the i-th curve
    @param kw_plot: extra plotting keywords (anything accepted by dict());
        merged into the keywords passed to get_kw_plot()
    @return: (hs, conds_irr), where hs[cond_irr] is whatever ax.plot()
        returned for that curve (so hs[cond_irr][0] is the Line2D) and
        conds_irr holds the condition values of the other dimension
    """
    if ax is None:
        ax = plt.gca()
    if ev_cond.ndim != 2:
        assert ev_cond.ndim == 4
        # Collapse the frame axis, then keep the first output of npt.p2st
        # NOTE(review): presumably moves the (mean, var) axis first and keeps
        # the mean -- confirm against npt.p2st.
        ev_cond = npt.p2st(ev_cond.mean(1))[0]
    if n_ch.ndim != 2:
        assert n_ch.ndim == 3
        # Pool the counts over the rt_frame axis
        n_ch = n_ch.sum(1)

    ev_cond = npy(ev_cond)
    n_ch = npy(n_ch)
    n_cond_all = n_ch.shape[0]

    # For every condition, the choice on the relevant dimension that each
    # flattened choice column corresponds to; flattened to align with n_ch.
    ch_rel = np.repeat(np.array(consts.CHS[dim_rel])[None, :], n_cond_all, 0)
    n_ch = n_ch.reshape([-1])
    ch_rel = ch_rel.reshape([-1])

    dim_irr = consts.get_odim(dim_rel)
    conds_rel, dcond_rel = np.unique(ev_cond[:, dim_rel], return_inverse=True)
    # The other dimension is grouped by absolute evidence
    conds_irr, dcond_irr = np.unique(np.abs(ev_cond[:, dim_irr]),
                                     return_inverse=True)

    if group_dcond_irr is not None:
        conds_irr, dcond_irr = group_conds(conds_irr, dcond_irr,
                                           group_dcond_irr)

    n_conds = [len(conds_rel), len(conds_irr)]

    # Sum the counts into a [choice, cond_irr, cond_rel] table
    n_ch_rel = npg.aggregate(
        np.stack([
            ch_rel,
            np.repeat(dcond_irr[:, None], consts.N_CH_FLAT, 1).flatten(),
            np.repeat(dcond_rel[:, None], consts.N_CH_FLAT, 1).flatten(),
        ]), n_ch, 'sum', [consts.N_CH, n_conds[1], n_conds[0]])
    # Proportion of choice 1 among all choices, per [cond_irr, cond_rel]
    p_ch_rel = n_ch_rel[1] / n_ch_rel.sum(0)

    # The colormap does not depend on the curve, so build it once.
    # isinstance() also accepts str subclasses, unlike an exact type check.
    if isinstance(cmap, str):
        palette = plt.get_cmap(cmap, n_conds[1])
    else:
        palette = cmap(n_conds[1])

    hs = []
    for dcond_irr1, p_ch1 in enumerate(p_ch_rel):
        kw1 = get_kw_plot(style, color=palette(dcond_irr1), **dict(kw_plot))
        hs.append(ax.plot(conds_rel, p_ch1, **kw1))

    plt2.box_off(ax=ax)
    x_lim = ax.get_xlim()
    plt2.detach_axis('x', amin=x_lim[0], amax=x_lim[1], ax=ax)
    plt2.detach_axis('y', amin=0, amax=1, ax=ax)
    ax.set_yticks([0, 0.5, 1])
    ax.set_yticklabels(['0', '', '1'])
    ax.set_xlabel('evidence')
    ax.set_ylabel(r"$\mathrm{P}(z=1 \mid c)$")

    return hs, conds_irr
示例#33
0
def show_latent(
    seq: Sequence,
    ax: plt.Axes = None,
    bounds: Optional[Tuple] = None,
    colors: Sequence = None,
    show_bars: bool = True,
    bar_width: float = 0.1,
    bar_location: str = "top",
    show_vlines: bool = True,
    vline_kws: Optional[dict] = None,
    shift: float = 0,
):
    """ Annotate a plot with the time course of a discrete latent state.

    Each run of constant state can be drawn as a colored bar placed just outside
    the current vertical extent of the axes (the y limits are widened to make
    room), and each change of state can be marked with a vertical line.

    Parameters
    ----------
    seq
        Sequence of latent-state identities, one entry per time step. Nothing is
        drawn when it is empty.
    ax
        Axes to draw in. Defaults to `plt.gca()`.
    bounds
        Tuple `(t0, t1)`; only the part of the sequence with `t >= t0` and
        `t < t1` is shown. When `None`, the current x limits of `ax` are used.
    colors
        Colors for the identities, indexed by identity modulo the number of
        colors. Defaults to the colors of Matplotlib's property cycle.
    show_bars
        Whether to draw the colored state bars.
    bar_width
        Height of the bars as a fraction of the y range of the axes, measured
        when the function is called.
    bar_location
        Either "top" or "bottom".
    show_vlines
        Whether to draw a vertical line at every transition.
    vline_kws
        Keywords forwarded to `axvline`. Missing entries default to a thin,
        dotted, black line. The dictionary passed in is not modified.
    shift
        Horizontal offset added to the position of all bars and lines.

    Raises
    ------
    ValueError
        If bars are requested and `bar_location` is neither "top" nor "bottom".
    """
    # nothing to show for an empty sequence
    if len(seq) == 0:
        return

    # fill in whatever the caller left unspecified
    if ax is None:
        ax = plt.gca()
    if colors is None:
        colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    if bounds is None:
        bounds = ax.get_xlim()

    # indices at which the state differs from the previous time step
    transitions = np.diff(seq).nonzero()[0] + 1

    # position (within `transitions`) of the first one at or after the left bound
    in_view = np.flatnonzero(transitions + shift >= bounds[0])
    first_visible = in_view[0] if in_view.size else None

    if show_vlines and first_visible is not None:
        # work on a copy so the caller's dictionary is left alone
        line_style = {} if vline_kws is None else copy.copy(vline_kws)
        for key, fallback in (("ls", ":"), ("lw", 0.5), ("c", "k")):
            line_style.setdefault(key, fallback)

        for transition in transitions[first_visible:]:
            position = transition + shift
            if position >= bounds[1]:
                # transitions are sorted, so everything further is out of view
                break
            ax.axvline(position, **line_style)

    if not show_bars:
        return

    # bar height in data coordinates, from the limits as they are right now
    y_low, y_high = ax.get_ylim()
    bar_height = (y_high - y_low) * bar_width

    # bar position, and the limits needed to keep it visible
    if bar_location == "top":
        bar_y = y_high
        new_ylim = (y_low, bar_y + bar_height)
    elif bar_location == "bottom":
        bar_y = y_low - bar_height
        new_ylim = (bar_y, y_high)
    else:
        raise ValueError("Unknown bar location option.")

    n_steps = len(seq)
    n_transitions = len(transitions)
    right_edge = bounds[1] - shift

    # walk through the runs of constant state, one rectangle per run
    run_start = max(bounds[0] - shift, 0)
    # with no visible transition, point past the end so a single run is drawn
    cursor = first_visible if first_visible is not None else n_transitions + 1
    while run_start + shift < bounds[1] and int(run_start) < n_steps:
        state = seq[int(run_start)]

        # the run ends at the next transition, or with the sequence itself
        run_end = transitions[cursor] if cursor < n_transitions else n_steps
        run_end = min(run_end, right_edge)

        if run_end > run_start:
            ax.add_patch(
                patches.Rectangle(
                    (run_start + shift, bar_y),
                    run_end - run_start,
                    bar_height,
                    edgecolor="none",
                    facecolor=colors[state % len(colors)],
                ))

        cursor += 1
        run_start = run_end

    # make room for the bars
    ax.set_ylim(*new_ylim)
示例#34
0
def plot_expression_by_distance(
    ax: plt.Axes,
    data: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
    feature: Optional[str] = None,
    include_background: bool = True,
    curve_label: Optional[str] = None,
    flavor: str = "normal",
    color_scheme: Optional[Dict[str, str]] = None,
    ratio_order: List[str] = ["central", "portal"],
    list_flavor_choices: bool = False,
    feature_type: str = "Feature",
    distance_scale_factor: float = 1.0,
    **kwargs,
) -> None:
    """Generate feature by distance plots

    Function for seamless production of feature
    by distance plots.

    Parameters:
    ----------

    ax : plt.Axes
        axes object to plot data in
    data : Tuple[np.ndarray,np.ndarray,np.ndarray,np.ndarray]
        Tuple of data to plot, should be in form :
        (xs,ys,y_hat,std_err), the output which
        utils.smooth_fit produces.
    feature : Optional[str] (None)
        Name of plotted feature, set to None to exclude this
        information.
    include_background : bool (True)
        Set to True to include data points used
        to fit the smoothed data
    curve_label : Optional[str] (None)
        label of plotted data. Set to None to exclude
        legend.
    flavor : str = "normal",
        flavor of data, choose between 'normal' or
        'logodds'.
    color_scheme : Optional[Dict[str,str]] (None)
        dictionary providing the color scheme to be used.
        Include 'background':'color' to change color
        original data, include 'envelope':'color' to set
        color of envelope, include 'feature_class':'color' to
        set color of class.
    ratio_order : List[str] (["central","portal"])
        if logodds flavor is used, then specify which
        element was nominator (first element) and
        denominator (second element).
    list_flavor_choices : bool (False)
        set to True to list flavor choices
    feature_type : str ("Feature")
        Name of feature to plot, will be prepended to title
        as Feature : X. Set to None to only plot name X. Set
        to None to exclude feature type from being indicated
        in title and y-axis.
    distance_scale_factor : float (1.0)
        scaling factor to multiply distances with

    Returns:
    -------

    Tuple with Matplotlib Figure and Axes object, containing
    feature by distance plots.

    """

    flavors = ["normal", "logodds", "single_vein"]
    if list_flavor_choices:
        print("Flavors to choose from are : {}".format(', '.join(flavors)))
        return None
    if flavor not in flavors:
        raise ValueError("Not a valid flavor")

    if len(data) != 4:
        raise ValueError("Data must be (xs,ys,ys_hat,stderr)")

    if color_scheme is None:
        color_scheme = {}

    scaled_distances = data[0] * (distance_scale_factor if \
                                  flavor != "logodds" else 1.0)

    if include_background:
        ax.scatter(
            scaled_distances,
            data[1],
            s=1,
            c=color_scheme.get("background", "gray"),
            alpha=0.4,
        )

    ax.fill_between(
        scaled_distances,
        data[2] - data[3],
        data[2] + data[3],
        alpha=0.2,
        color=color_scheme.get("envelope", "grey"),
    )

    ax.set_title(
        "{} : {}".format(("" if feature_type is None else feature_type),
                         ("" if feature is None else feature)),
        fontsize=kwargs.get("title_font_size", kwargs.get("fontsize", 15)),
    )
    ax.set_ylabel(
        "{} Value".format(("" if feature_type is None else feature_type)),
        fontsize=kwargs.get("label_font_size", kwargs.get("fontsize", 15)),
    )

    if flavor == "normal":
        unit = ("" if "distance_unit" not in\
                kwargs.keys() else " [{}]".format(kwargs["distance_unit"]))

        ax.set_xlabel(
            "Distance to vein{}".format(unit),
            fontsize=kwargs.get("label_font_size", kwargs.get("fontsize", 15)),
        )

    if flavor == "logodds":

        x_min, x_max = ax.get_xlim()

        ax.axvspan(
            xmin=x_min,
            xmax=0,
            color=color_scheme.get(ratio_order[0], "red"),
            alpha=0.2,
        )

        ax.axvspan(
            xmin=0,
            xmax=x_max,
            color=color_scheme.get(ratio_order[1], "blue"),
            alpha=0.2,
        )

        d1 = ratio_order[0][0]
        d2 = ratio_order[1][0]
        ax.set_xlabel(
            r"$\log(d_{}) - \log(d_{})$".format(d1, d2),
            fontsize=kwargs.get("label_font_size", kwargs.get("fontsize", 15)),
        )

    ax.plot(
        scaled_distances,
        data[2],
        c=color_scheme.get("fitted", "black"),
        linewidth=2,
        label=("none" if curve_label is None else curve_label),
    )

    if "tick_fontsize" in kwargs.keys():
        ax.tick_params(axis="both",
                       which="major",
                       labelsize=kwargs["tick_fontsize"])