Exemplo n.º 1
0
def plot_latency_throughput(df: pd.DataFrame, ax: plt.Axes) -> None:
    """Plot latency against throughput, one point per client count.

    Rows of ``df`` are grouped by ``num_clients``.

    - Throughput is averaged over the runs that reach at least half of the
      group's peak throughput.
    - Latency is averaged over the runs whose latency is at most 5x the
      group's lowest latency.
    - A band of +/- one standard deviation of the (filtered) throughput is
      shaded around the curve in the same colour.

    The raw per-group data and the aggregated series are printed to stdout.

    Args:
        df: Results with ``num_clients``, ``throughput`` and ``latency``
            columns.
        ax: Axes to draw on.
    """
    def fast_runs(g: pd.DataFrame) -> pd.Series:
        # Throughput samples that reach at least half of the group's maximum.
        tput = g['throughput']
        return tput[tput >= 0.5 * tput.max()]

    def quick_runs(g: pd.DataFrame) -> pd.Series:
        # Latency samples that are at most 5x the group's minimum.
        lat = g['latency']
        return lat[lat <= 5 * lat.min()]

    by_clients = df.groupby('num_clients')
    for clients, runs in by_clients:
        print(f'# {clients}')
        print(runs[['throughput', 'latency']])

    mean_tput = by_clients.apply(lambda g: fast_runs(g).mean()).sort_index()
    mean_lat = by_clients.apply(lambda g: quick_runs(g).mean()).sort_index()
    std_tput = by_clients.apply(lambda g: fast_runs(g).std()).sort_index()
    print(f'throughput = {mean_tput}.')
    print(f'latency = {mean_lat}.')
    print()

    curve = ax.plot(mean_tput, mean_lat, '.-', linewidth=2)[0]
    lower, upper = mean_tput - std_tput, mean_tput + std_tput
    ax.fill_betweenx(mean_lat,
                     lower,
                     upper,
                     color=curve.get_color(),
                     alpha=0.25)
Exemplo n.º 2
0
def save_one2dspec(ax: plt.Axes, spec: np.ndarray, edges: Tuple[np.ndarray,
                                                                np.ndarray],
                   traces: List[np.ndarray], fwhms: List[np.ndarray]) -> None:
    """Draw a 2D spectrum with its edges and optional traces onto ``ax``.

    Args:
        ax: Axes to draw on. Its frame and ticks are switched off.
        spec: 2D image, shown in grey scale with a ZScale-interval
            normalisation.
        edges: ``(all_left, all_right)`` arrays indexed as ``[row, k]``.
            Column ``k`` of ``all_left`` is drawn in green and column ``k`` of
            ``all_right`` in red, each against the row index of ``spec``.
        traces: Optional list of 1D arrays (x position per row), drawn in
            orange. May be None.
        fwhms: Widths paired with ``traces``; each band spans
            ``trace - fwhm`` to ``trace + fwhm``.
    """
    all_left, all_right = edges

    norm = ImageNormalize(spec, interval=ZScaleInterval())
    ax.imshow(spec, origin='upper', norm=norm, cmap='gray')

    # Turn the frame off on *this* axes; ``plt.axis`` would act on pyplot's
    # current axes, which is not necessarily ``ax``.
    ax.axis('off')

    # One y coordinate per image row; edges/traces supply the x per row.
    rows = np.arange(spec.shape[0])

    for i in range(all_left.shape[1]):
        ax.plot(all_left[:, i], rows, 'green', lw=1)
        ax.plot(all_right[:, i], rows, 'red', lw=1)

    if traces is not None:
        for trace, fwhm in zip(traces, fwhms):
            ax.plot(trace, rows, 'orange', lw=1)
            ax.fill_betweenx(rows,
                             trace - fwhm,
                             trace + fwhm,
                             color='orange',
                             alpha=0.2)
Exemplo n.º 3
0
def plot_lt(df: pd.DataFrame, ax: plt.Axes, marker: str, label: str) -> None:
    """Plot mean latency against mean throughput, one point per client count.

    Throughput is plotted in thousands. A band of +/- one (sample) standard
    deviation of throughput is shaded in the line's colour.

    Args:
        df: Results with ``num_clients``, ``throughput`` and ``latency``
            columns.
        ax: Axes to draw on.
        marker: Matplotlib format string for the line (e.g. ``'-o'``).
        label: Legend label of the line.
    """
    grouped = df.groupby('num_clients')
    # Pass the reductions by name. Handing np.mean / np.std to ``agg`` is
    # deprecated in pandas 2.x, and once the deprecation lands np.std would be
    # applied directly (population std, ddof=0) instead of the current sample
    # std (ddof=1).
    throughput = grouped['throughput'].agg('mean').sort_index() / 1000
    throughput_std = grouped['throughput'].agg('std').sort_index() / 1000
    latency = grouped['latency'].agg('mean').sort_index()
    line = ax.plot(throughput, latency, marker, label=label, linewidth=2)[0]
    ax.fill_betweenx(latency,
                     throughput - throughput_std,
                     throughput + throughput_std,
                     color=line.get_color(),
                     alpha=0.25)
def plot_lt(df: pd.DataFrame, ax: plt.Axes, title: str) -> None:
    """Plot one latency-throughput curve per system configuration.

    Rows are grouped by shard/replica/proxy counts and server, aggregator and
    replica options. Each group is printed and then plotted as mean latency
    against mean throughput (divided by 1000) per ``num_clients``. A band of
    +/- one sample standard deviation of throughput is shaded around it. Each
    curve uses the next marker from the module-level ``MARKERS`` iterator.

    Args:
        df: Benchmark results containing the grouping columns plus
            ``num_clients``, ``throughput`` and ``latency``.
        ax: Axes to draw on.
        title: Axes title.
    """
    grouped = df.groupby([
        'num_shards', 'num_replicas', 'num_proxy_replicas',
        'server_options.push_size', 'server_options.push_period',
        'aggregator_options.num_shard_cuts_per_proposal',
        'replica_options.unsafe_yolo_execution'
    ])
    for (name, group) in grouped:
        print(f'## {name}')
        print(group[['throughput', 'latency']])

        by_clients = group.groupby('num_clients')
        # Named reductions: np.mean/np.std callables in ``agg`` are deprecated
        # in pandas 2.x (np.std would later switch to ddof=0).
        throughput = by_clients['throughput'].agg('mean').sort_index() / 1000
        throughput_std = by_clients['throughput'].agg(
            'std').sort_index() / 1000
        latency = by_clients['latency'].agg('mean').sort_index()
        line = ax.plot(throughput,
                       latency,
                       '-',
                       marker=next(MARKERS),
                       label=name,
                       linewidth=2)[0]
        ax.fill_betweenx(latency,
                         throughput - throughput_std,
                         throughput + throughput_std,
                         color=line.get_color(),
                         alpha=0.25)

    # Decorate the axes once, after all curves are drawn. ``ax.grid()`` with no
    # arguments *toggles* the grid, so calling it once per group left the grid
    # hidden whenever the number of groups was even.
    ax.grid(True)
    ax.set_title(title)
    # NOTE(review): throughput is divided by 1000 above, but this label
    # states units of 100,000 — confirm which scale is intended.
    ax.set_xlabel('Throughput (100,000 commands per second)')
    ax.set_ylabel('Latency\n(milliseconds)')
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
Exemplo n.º 5
0
    def plot_residual(self, ax: plt.Axes = None) -> plt.Axes:
        """Draw a funnel plot of residuals against their standard errors.

        Markers:

        - Every row of ``self.df`` is drawn as a grey dot.
        - Rows whose study id contains ``"fill"`` are over-plotted in teal.
        - Rows with ``outlier == 1`` are marked with a red cross.

        The shaded cone spans ``0 +/- 1.96 * se`` up to the 99th percentile
        of the standard errors. The y axis runs from that percentile (bottom)
        to 0 (top). The title reports ``self.se_model``'s mean, sd and
        p-value.

        Args:
            ax: Axes to draw on. If None, a new figure and axes are created.

        Returns:
            The axes that was drawn on.
        """
        # compute the residual and observation standard deviation
        residual = self.df["residual"]
        obs_se = self.df["residual_se"]
        max_obs_se = np.quantile(obs_se, 0.99)
        # na=False treats missing study ids as "not filled" instead of
        # producing NaNs, which cannot be used as a boolean mask.
        fill_index = self.df[self.model.cwdata.col_study_id].str.contains(
            "fill", na=False)

        # create funnel plot
        ax = plt.subplots()[1] if ax is None else ax
        ax.set_ylim(max_obs_se, 0.0)
        ax.scatter(residual, obs_se, color="gray", alpha=0.4)
        if fill_index.sum() > 0:
            ax.scatter(residual[fill_index],
                       obs_se[fill_index],
                       color="#008080",
                       alpha=0.7)
        ax.scatter(residual[self.df.outlier == 1],
                   obs_se[self.df.outlier == 1],
                   color='red',
                   marker='x',
                   alpha=0.4)
        # 95% cone: |residual| <= 1.96 * se.
        ax.fill_betweenx([0.0, max_obs_se], [0.0, -1.96 * max_obs_se],
                         [0.0, 1.96 * max_obs_se],
                         color='#B0E0E6',
                         alpha=0.4)
        ax.plot([0, -1.96 * max_obs_se], [0.0, max_obs_se],
                linewidth=1,
                color='#87CEFA')
        ax.plot([0.0, 1.96 * max_obs_se], [0.0, max_obs_se],
                linewidth=1,
                color='#87CEFA')
        ax.axvline(0.0, color='k', linewidth=1, linestyle='--')
        ax.set_xlabel("residual")
        ax.set_ylabel("ln_rr_se")
        ax.set_title(
            f"{self.name}: egger_mean={self.se_model['mean']: .3f}, "
            f"egger_sd={self.se_model['sd']: .3f}, "
            f"egger_pval={self.se_model['pval']: .3f}",
            loc="left")
        return ax
def plot_lt(df: pd.DataFrame, ax: plt.Axes, title: str) -> None:
    """Plot one latency-throughput curve per MultiPaxos configuration.

    Rows are grouped by the replica/proxy-leader/acceptor counts, the write
    size and the leader flush setting. Each group is printed and then plotted
    as mean latency against mean throughput (divided by 1000) per
    ``num_clients``. A band of +/- one sample standard deviation of
    throughput is shaded around it. Each curve uses the next marker from the
    module-level ``MARKERS`` iterator.

    Args:
        df: Benchmark results containing the grouping columns plus
            ``num_clients``, ``throughput`` and ``latency``.
        ax: Axes to draw on.
        title: Axes title.
    """
    grouped = df.groupby([
        'num_replicas',
        'num_proxy_leaders',
        'num_acceptor_groups',
        'num_acceptors_per_group',
        'workload.write_size_mean',
        'leader_options.flush_phase2as_every_n'
    ])
    for (name, group) in grouped:
        print(f'## {name}')
        print(group[['throughput', 'latency']])

        by_clients = group.groupby('num_clients')
        # Named reductions: np.mean/np.std callables in ``agg`` are deprecated
        # in pandas 2.x (np.std would later switch to ddof=0).
        throughput = by_clients['throughput'].agg('mean').sort_index() / 1000
        throughput_std = by_clients['throughput'].agg('std').sort_index() / 1000
        latency = by_clients['latency'].agg('mean').sort_index()
        line = ax.plot(throughput, latency, '-', marker=next(MARKERS),
                       label=name, linewidth=2)[0]
        ax.fill_betweenx(latency,
                         throughput - throughput_std,
                         throughput + throughput_std,
                         color=line.get_color(),
                         alpha=0.25)

    ax.set_title(title)
    # NOTE(review): throughput is divided by 1000 above, but this label
    # states units of 100,000 — confirm which scale is intended.
    ax.set_xlabel('Throughput (100,000 commands per second)')
    ax.set_ylabel('Latency\n(milliseconds)')
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    # Positional form: the ``b=`` keyword was renamed to ``visible`` in
    # matplotlib 3.5 and no longer exists in current releases.
    ax.grid(True)
Exemplo n.º 7
0
def plot_section(rfst,
                 channel="PRF",
                 timelimits: list or tuple or None = None,
                 epilimits: list or tuple or None = None,
                 scalingfactor: float = 2.0,
                 ax: plt.Axes = None,
                 line: bool = True,
                 linewidth: float = 0.25,
                 outputfile: str or None = None,
                 title: str or None = None,
                 show: bool = True,
                 format: str = None):
    """Creates plot of a receiver function section as a function
    of epicentral distance.

    Each trace is drawn at its epicentral distance on the x axis. Negative
    amplitudes are filled blue and positive amplitudes red. Time (or depth)
    runs downwards on the y axis.

    Parameters
    ----------
    rfst : :class:`pyglimer.RFStream`
        Stream of receiver functions. Traces are plotted in the order given
        by ``rfst.sort(keys=['distance'])``.
    channel : str
        Used only in the default title and the error message; the stream is
        not filtered by channel. Defaults to "PRF".
    timelimits : list or tuple or None
        y axis limits (len(list)==2): time in seconds, or depth in km when
        the traces are not of type 'time'. If `None`, the full trace is
        plotted. Default None.
    epilimits : list or tuple or None
        x axis epicentral distance limits in degrees (len(list)==2).
        If `None`, matplotlib's automatic limits are kept. Default None.
    scalingfactor : float
        sets the scale for the traces. Could be automated in
        future functions (something like mean distance between
        traces). Defaults to 2.0
    ax : `matplotlib.pyplot.Axes`, optional
        Can define an axes to plot the RF into. Defaults to None.
        If None, new figure is created.
    line : bool
        plots black line of the actual RF. Defaults to True
    linewidth : float
        sets linewidth of individual traces. Defaults to 0.25
    outputfile : str or None
        If set, the figure containing ``ax`` is saved to this path
        (300 dpi, transparent background). Defaults to None.
    title : str or None
        Plot title. If None, "<channel> component" is used.
    show : bool
        If True and ``outputfile`` is not set, ``plt.show()`` is called.
        Defaults to True.
    format : str or None
        File format passed to ``savefig``. Defaults to None.

    Returns
    -------
    ax : `matplotlib.pyplot.Axes`

    Raises
    ------
    ValueError
        If the stream contains no receiver functions.
    """
    set_mpl_params()

    # Create figure if no axes is specified
    if ax is None:
        plt.figure(figsize=(8, 6))
        ax = plt.axes(zorder=9999999)

    # NOTE: selecting one component via ``rfst.sort(channel=channel)`` no
    # longer works, so every trace in the stream is plotted.
    rfst_chan = rfst.sort(keys=['distance'])

    if not len(rfst_chan):
        raise ValueError(
            'There are no receiver functions of channel %s in the RFStream.' %
            channel)

    # Plot traces; later (more distant) traces get a lower zorder, so nearer
    # traces are drawn on top.
    for _i, rf in enumerate(rfst_chan):
        if rf.stats.type == 'time':
            # Time relative to the onset.
            times = rf.times() - (rf.stats.onset - rf.stats.starttime)
            if rf.stats.phase[-1] == 'S':
                times = np.flip(times)
        else:
            times = rf.stats.pp_depth

        # Trace offset to its epicentral distance on the x axis.
        rftmp = rf.data * scalingfactor \
            + rf.stats.distance
        ax.fill_betweenx(times,
                         rf.stats.distance,
                         rftmp,
                         where=rftmp < rf.stats.distance,
                         interpolate=True,
                         color=(0.2, 0.2, 0.7),
                         zorder=-_i,
                         alpha=.8)
        ax.fill_betweenx(times,
                         rf.stats.distance,
                         rftmp,
                         where=rftmp > rf.stats.distance,
                         interpolate=True,
                         color=(0.9, 0.2, 0.2),
                         zorder=-_i - 0.1,
                         alpha=.8)
        if line:
            ax.plot(rftmp, times, 'k', lw=linewidth, zorder=-_i + 0.1)

    # Set limits on ``ax`` itself rather than pyplot's current axes, which is
    # a different axes whenever the caller passes one in.
    if epilimits is not None:
        ax.set_xlim(epilimits)

    if timelimits is None:
        # ``times`` belongs to the last (most distant) trace.
        if rfst[0].stats.type == 'time':
            ylim0 = 0
        else:
            ylim0 = times[0]
        # NOTE(review): for depth sections this shifts the upper limit by
        # times[0]; confirm that this is intended.
        ylim1 = times[-1] + ylim0
        ax.set_ylim(ylim0, ylim1)
    else:
        ax.set_ylim(timelimits)
    ax.invert_yaxis()

    # Set labels
    ax.set_xlabel(r"$\Delta$ [$^{\circ}$]")
    if rfst[0].stats.type == 'time':
        ax.set_ylabel(r"Time [s]")
    else:
        ax.set_ylabel(r"Depth [km]")

    # Set title
    if title is not None:
        ax.set_title(title)
    else:
        ax.set_title("%s component" % channel)

    # Save the figure that actually holds ``ax``, or show it.
    if outputfile:
        ax.figure.savefig(outputfile, dpi=300, transparent=True,
                          format=format)
    elif show:
        plt.show()
    return ax
Exemplo n.º 8
0
def plot_single_rf(rf,
                   tlim: list or tuple or None = None,
                   ylim: list or tuple or None = None,
                   depth: np.ndarray or None = None,
                   ax: plt.Axes = None,
                   outputdir: str = None,
                   pre_fix: str = None,
                   post_fix: str = None,
                   format: str = 'pdf',
                   clean: bool = False,
                   std: np.ndarray = None,
                   flipxy: bool = False):
    """Creates plot of a single receiver function

    Positive amplitudes are filled red, negative amplitudes blue, and the
    trace itself is drawn as a thin black line.

    Parameters
    ----------
    rf : :class:`pyglimer.RFTrace`
        single receiver function trace
    tlim: list or tuple or None
        Axis limits (len(list)==2): time in seconds if
        ``rf.stats.type == 'time'``, otherwise depth in km.
        If `None`, the axis runs from 0 to the last sample.
        Default None.
    ylim: list or tuple or None
        Amplitude axis limits. If `None`, ± 1.1 * absmax of the data.
        Default None.
    depth: :class:`numpy.ndarray`, optional
        1D array of depths, one per sample of ``rf.data``, used as the depth
        axis when ``rf.stats.type`` is not 'time'. If None, a default grid
        from -10 km to ``maxz`` in steps of ``res`` is used.
    ax : `matplotlib.pyplot.Axes`, optional
        Can define an axes to plot the RF into. Defaults to None.
        If None, new figure is created.
    outputdir : str, optional
        If set, saves the plot to
        ``<outputdir>/<pre_fix>_<id>_<starttime>_<post_fix>.<format>``.
        If None, the plot is shown with ``plt.show()``. Defaults to None.
    pre_fix : str, optional
        prepend filename
    post_fix : str, optional
        append to filename
    format : str, optional
        File format / extension of the saved plot. Defaults to 'pdf'.
    clean: bool, optional
        If True, clears out all axes and plots RF only.
        Defaults to False.
    std: np.ndarray, optional
        Plots ``data ± std`` as dashed lines. Intended for station stacks;
        provide the std as a numpy array (can be easily computed from the
        output of :meth:`~pyglimer.rf.create.RFStream.bootstrap`)
    flipxy: bool, optional
        Plot Depth/Time on the Y-Axis and amplitude on the x-axis. Defaults
        to False.

    Returns
    -------
    ax : `matplotlib.pyplot.Axes`
    """
    set_mpl_params()

    # Get figure/axes dimensions (inches)
    own_figure = ax is None
    if own_figure:
        if flipxy:
            height, width = 8, 3
        else:
            width, height = 10, 2.5
        plt.figure(figsize=(width, height))
        ax = plt.axes(zorder=9999999)
    else:
        # Measure the supplied axes in its own figure's coordinates.
        fig = ax.figure
        bbox = ax.get_window_extent().transformed(
            fig.dpi_scale_trans.inverted())
        width, height = bbox.width, bbox.height

    # The ratio ensures that the text
    # is perfectly distanced from top left/right corner
    ratio = width / height

    # Use times depending on phase and moveout correction
    ydata = rf.data
    if rf.stats.type == 'time':
        # Time relative to the onset.
        times = rf.times() - (rf.stats.onset - rf.stats.starttime)
        if rf.stats.phase[-1] == 'S':
            times = np.flip(times)
            ydata = np.flip(-rf.data)
    elif depth is not None:
        # Use the caller's depth axis (this argument used to be ignored).
        times = np.asarray(depth)
    else:
        times = np.hstack(
            ((np.arange(-10, 0, .1)), np.arange(0, maxz + res, res)))

    # Fill styles shared by both orientations.
    positive = dict(where=ydata > 0, interpolate=True,
                    color=(0.9, 0.2, 0.2), alpha=.8)
    negative = dict(where=ydata < 0, interpolate=True,
                    color=(0.2, 0.2, 0.7), alpha=.8)

    # Plot stuff into axes
    if flipxy:
        if std is not None:
            ax.plot(ydata - std, times, 'k--', lw=0.75)
            ax.plot(ydata + std, times, 'k--', lw=0.75)
        ax.fill_betweenx(times, 0, ydata, **positive)
        ax.fill_betweenx(times, 0, ydata, **negative)
        ax.plot(ydata, times, 'k', lw=0.75)

        # Set limits; by default skip everything before t (or z) = 0.
        if tlim is None:
            ax.set_ylim(0, times[-1])
        else:
            ax.set_ylim(tlim)

        if ylim is None:
            absmax = 1.1 * np.max(np.abs(ydata))
            ax.set_xlim([-absmax, absmax])
        else:
            ax.set_xlim(ylim)
        ax.invert_yaxis()
    else:
        if std is not None:
            ax.plot(times, ydata - std, 'k--', lw=0.75)
            ax.plot(times, ydata + std, 'k--', lw=0.75)
        ax.fill_between(times, 0, ydata, **positive)
        ax.fill_between(times, 0, ydata, **negative)
        ax.plot(times, ydata, 'k', lw=0.75)

        # Set limits; by default skip everything before t (or z) = 0.
        if tlim is None:
            ax.set_xlim(0, times[-1])
        else:
            ax.set_xlim(tlim)

        if ylim is None:
            absmax = 1.1 * np.max(np.abs(ydata))
            ax.set_ylim([-absmax, absmax])
        else:
            ax.set_ylim(ylim)

    # Removes top/right axes spines. If you want the whole thing, comment
    # or remove
    remove_topright()

    # Plot RF only
    if clean:
        remove_all()
    else:
        if rf.stats.type == 'time':
            if flipxy:
                ax.set_ylabel("Conversion Time [s]", rotation=90)
            else:
                ax.set_xlabel("Conversion Time [s]")
        else:
            if flipxy:
                ax.set_ylabel("Conversion Depth [km]", rotation=90)
            else:
                ax.set_xlabel("Conversion Depth [km]")
        if flipxy:
            ax.set_xlabel("A    ", rotation=0)
        else:
            ax.set_ylabel("A    ", rotation=0)

        # Start time in station stack does not make sense
        if rf.stats.type == 'stastack':
            text = rf.get_id()
        else:
            text = rf.stats.starttime.isoformat(sep=" ") + "\n" + rf.get_id()
        ax.text(0.995,
                1.0 - 0.005 * ratio,
                text,
                transform=ax.transAxes,
                horizontalalignment="right",
                verticalalignment="top")

    # Only use tight layout if not part of plot.
    if own_figure:
        plt.tight_layout()

    # Output the receiver function using its station name and starttime
    if outputdir is not None:
        prefix = "" if pre_fix is None else pre_fix + "_"
        suffix = "" if post_fix is None else "_" + post_fix
        stamp = rf.stats.starttime.strftime('%Y%m%dT%H%M%S')
        filename = os.path.join(
            outputdir, f"{prefix}{rf.get_id()}_{stamp}{suffix}.{format}")
        ax.figure.savefig(filename, format=format, transparent=True)
    else:
        plt.show()

    return ax
Exemplo n.º 9
0
def plot_section(rfst,
                 channel="PRF",
                 timelimits: list or tuple or None = None,
                 epilimits: list or tuple or None = None,
                 scalingfactor: float = 2.0,
                 ax: plt.Axes = None,
                 line: bool = True,
                 linewidth: float = 0.25,
                 outputdir: str or None = None,
                 title: str or None = None,
                 show: bool = True):
    """Creates plot of a receiver function section as a function
    of epicentral distance.

    Each trace is drawn at its epicentral distance on the x axis. Negative
    amplitudes are filled blue and positive amplitudes red. Time (or depth)
    runs downwards on the y axis.

    Parameters
    ----------
    rfst : :class:`pyglimer.RFStream`
        Stream of receiver functions. Traces are plotted in the order given
        by ``rfst.sort(keys=['distance'])``.
    channel : str
        Used in the title and the output file name only; the stream is not
        filtered by channel. Defaults to "PRF".
    timelimits : list or tuple or None
        y axis limits (len(list)==2): time in seconds, or depth in km when
        the traces are not of type 'time'. If `None`, the full trace is
        plotted. Default None.
    epilimits : list or tuple or None
        x axis epicentral distance limits in degrees (len(list)==2).
        If `None`, matplotlib's automatic limits are kept. Default None.
    scalingfactor : float
        sets the scale for the traces. Could be automated in
        future functions (something like mean distance between
        traces). Defaults to 2.0
    ax : `matplotlib.pyplot.Axes`, optional
        Can define an axes to plot the RF into. Defaults to None.
        If None, new figure is created.
    line : bool
        plots black line of the actual RF. Defaults to True
    linewidth : float
        sets linewidth of individual traces. Defaults to 0.25
    outputdir : str, optional
        If set, saves ``channel_<channel>.pdf`` into this directory.
        If None, plot will be shown instantly. Defaults to None.
    title : str or None
        If set, the title is "<title> - <channel>"; otherwise
        "<channel> component".
    show : bool
        Accepted for interface compatibility; currently not used.

    Returns
    -------
    ax : `matplotlib.pyplot.Axes`

    Raises
    ------
    ValueError
        If the stream contains no receiver functions.
    """
    # set_mpl_params()

    # Create figure if no axes is specified. ``plt.gca`` no longer accepts
    # keyword arguments (removed in matplotlib 3.6), so create the axes
    # explicitly.
    if ax is None:
        plt.figure(figsize=(10, 15))
        ax = plt.axes(zorder=999999)

    # NOTE: the stream is not filtered by ``channel``; all traces are used.
    rfst_chan = rfst.sort(keys=['distance'])

    if not len(rfst_chan):
        raise ValueError(
            'There are no receiver functions of channel %s in the RFStream.' %
            channel)

    # Plot traces; later (more distant) traces get a lower zorder, so nearer
    # traces are drawn on top.
    for _i, rf in enumerate(rfst_chan):
        if rf.stats.type == 'time':
            # Time relative to the onset.
            times = rf.times() - (rf.stats.onset - rf.stats.starttime)
            if rf.stats.phase == 'S':
                times = np.flip(times)
        else:
            times = np.hstack(((np.arange(-10, 0,
                                          .1)), np.arange(0, maxz + res, res)))
        # Trace offset to its epicentral distance on the x axis.
        rftmp = rf.data * scalingfactor \
            + rf.stats.distance
        ax.fill_betweenx(times,
                         rf.stats.distance,
                         rftmp,
                         where=rftmp < rf.stats.distance,
                         interpolate=True,
                         color=(0.2, 0.2, 0.7),
                         zorder=-_i)
        ax.fill_betweenx(times,
                         rf.stats.distance,
                         rftmp,
                         where=rftmp > rf.stats.distance,
                         interpolate=True,
                         color=(0.9, 0.2, 0.2),
                         zorder=-_i - 0.1)
        if line:
            ax.plot(rftmp, times, 'k', lw=linewidth, zorder=-_i + 0.1)

    # Set limits on ``ax`` itself rather than pyplot's current axes, which is
    # a different axes whenever the caller passes one in.
    if epilimits is not None:
        ax.set_xlim(epilimits)

    if timelimits is None:
        # ``times`` belongs to the last (most distant) trace.
        if rfst[0].stats.type == 'time':
            ylim0 = 0
        else:
            ylim0 = times[0]
        # NOTE(review): for depth sections this shifts the upper limit by
        # times[0]; confirm that this is intended.
        ylim1 = times[-1] + ylim0
        ax.set_ylim(ylim0, ylim1)
    else:
        ax.set_ylim(timelimits)
    ax.invert_yaxis()

    # Set labels
    ax.set_xlabel(r"$\Delta$ [$^{\circ}$]")
    if rfst[0].stats.type == 'time':
        ax.set_ylabel(r"Time [s]")
    else:
        ax.set_ylabel(r"Depth [km]")

    # Set title
    if title is not None:
        ax.set_title(title + " - %s" % channel)
    else:
        ax.set_title("%s component" % channel)

    # Show, or save the figure that actually holds ``ax``.
    if outputdir is None:
        plt.show()
    else:
        outputfilename = os.path.join(outputdir, "channel_%s.pdf" % channel)
        ax.figure.savefig(outputfilename, format="pdf")
    return ax
Exemplo n.º 10
0
def plot_echelle(
    data: az.InferenceData,
    group="posterior",
    kind: str = "full",
    delta_nu: Optional[float] = None,
    quantiles: Optional[List[float]] = None,
    observed: Union[bool, str] = "auto",
    use_alpha: bool = True,
    ax: plt.Axes = None,
    **kwargs,
) -> plt.Axes:
    """Plot an echelle diagram of the data.

    Choose to plot the full model, the background model or the glitchless
    model. This is compatible with data from inference on models like
    :class:`GlitchModel`.

    Args:
        data (az.InferenceData): Inference data object.
        group (str): One of ['posterior', 'prior']. Defaults to 'posterior'.
        kind (str): One of ['full', 'glitchless', 'background'] (case
            insensitive). Defaults to 'full', which plots the full model for
            nu. Use 'glitchless' to plot the model without the glitch
            components. Use 'background' to plot the background component of
            the model.
        delta_nu (float, optional): Large frequency separation to modulo by.
            If None, the median of ``delta_nu`` is used. For the posterior it
            comes from ``data[group]``; for the prior it comes from the
            predictive group.
        quantiles (iterable, optional): Quantiles to plot as confidence
            intervals. Entry ``i`` is paired with entry ``-i - 1``; the
            middle entry of an odd-length list gets no band. If None,
            defaults to [0.16, 0.84] (the 68% interval). Pass an empty list
            to plot no confidence intervals.
        observed (bool or str): Whether to plot observed data. Default is
            "auto", which plots observed data when group is "posterior".
        use_alpha (bool): Whether to use the alpha channel for transparency.
            If False, bands are shaded with a lightened solid color.
        ax (matplotlib.axes.Axes): Axis on which to plot the echelle. If
            None, a new figure and axis are created.
        **kwargs: Keyword arguments passed to :func:`matplotlib.pyplot.plot`
            for the model line. A ``label`` kwarg overrides the default
            "<kind> model" label.

    Raises:
        ValueError: If kind is not valid.

    Returns:
        matplotlib.axes.Axes: Axis on which the echelle is plot.
    """
    if ax is None:
        _, ax = plt.subplots()

    if quantiles is None:
        quantiles = [0.16, 0.84]

    if observed == "auto":
        observed = group == "posterior"

    predictive = _validate_predictive_group(data, group)
    dim = ("chain", "draw")  # sample dimensions reduced over for statistics

    if delta_nu is None:
        if group == "prior":  # there is currently no separate prior group
            delta_nu = predictive["delta_nu"].median().to_numpy()
        else:
            delta_nu = data[group]["delta_nu"].median().to_numpy()

    nu = data.observed_data.nu
    nu_err = data.constant_data.nu_err
    n_pred = predictive.n_pred

    if observed:
        # Plot observed - prior predictive should be independent of obs
        ax.errorbar(
            nu % delta_nu,
            nu,
            xerr=nu_err,
            color="k",
            marker="o",
            linestyle="none",
            label="observed",
        )

    kindl = kind.lower()
    if kindl == "full":
        y = predictive["nu_pred"]
    elif kindl == "background":
        y = predictive["nu_bkg_pred"]
    elif kindl == "glitchless":
        # Full model minus whichever glitch components are present.
        y = (predictive["nu_pred"] - predictive.get("dnu_he_pred", 0.0) -
             predictive.get("dnu_cz_pred", 0.0))
        # xarray arithmetic does not keep attrs by default; restore the unit
        # used for the x-axis label.
        y.attrs["unit"] = predictive["nu_pred"].attrs["unit"]
    else:
        raise ValueError(f"Kind '{kindl}' is not one of " +
                         "['full', 'background', 'glitchless'].")
    label = f"{kindl} model"

    y_mod = (y - n_pred * delta_nu) % delta_nu
    y_med = y.median(dim=dim)
    label = kwargs.pop("label", label)
    (line, ) = ax.plot(
        y_mod.median(dim=dim),
        y_med,
        label=label,
        **kwargs,
    )

    y_mod_quant = y_mod.quantile(quantiles, dim=dim)
    num_quant = len(quantiles) // 2
    num_alphas = num_quant * 2 + 1
    alphas = np.linspace(0.1, 0.5, num_alphas)
    base_color = line.get_color()

    if use_alpha:
        colors = [base_color] * num_alphas
    else:
        # Mimic alpha by lightening the base color
        colors = [_lighten_color(base_color, 1.5 * a) for a in alphas]
        alphas = [None] * num_alphas  # reset alphas to None

    # Outermost pair first; each band spans quantiles[i] .. quantiles[-i - 1].
    for i in range(num_quant):
        delta = quantiles[-i - 1] - quantiles[i]
        ax.fill_betweenx(
            y_med,
            y_mod_quant[i],
            y_mod_quant[-i - 1],
            color=colors[2 * i + 1],
            alpha=alphas[2 * i + 1],
            label=f"{delta:.1%} CI",
        )

    x_unit = u.Unit(y.attrs.get("unit", ""))
    ax.set_xlabel(r"$\nu$ modulo " + f"{delta_nu:.2f} " +
                  f"({x_unit.to_string(format='latex_inline')})")

    y_unit = u.Unit(nu.attrs.get("unit", "uHz"))
    ax.set_ylabel(r"$\nu$ " + f"({y_unit.to_string(format='latex_inline')})")

    ax.legend()

    return ax