예제 #1
0
파일: flatsym.py 프로젝트: jrkerns/pylinac
 def _plot_image(self, axis: plt.Axes=None, title: str=''):
     """Show ``self.array`` with a red crosshair at the analysed profile positions.

     Parameters
     ----------
     axis
         Axes to draw into. A fresh figure/axes pair is created when omitted.
     title
         Text used as the axes title.
     """
     plt.ioff()
     if axis is None:
         _, axis = plt.subplots()
     axis.imshow(self.array, cmap=get_dicom_cmap())
     # The stored positions are multiplied by the array extent, i.e. they appear
     # to be fractions of the image size (NOTE(review): confirm).
     row_px = self.positions['vertical'] * self.array.shape[0]
     col_px = self.positions['horizontal'] * self.array.shape[1]
     axis.axhline(row_px, color='r')  # y
     axis.axvline(col_px, color='r')  # x
     _remove_ticklabels(axis)
     axis.set_title(title)
예제 #2
0
파일: flatsym.py 프로젝트: jrkerns/pylinac
 def _plot_flatness(self, direction: str, axis: plt.Axes=None):
     """Plot the flatness profile for *direction* with its extrema and edges marked.

     Red horizontal lines mark the profile max/min; green dash-dot vertical
     lines mark the left/right profile positions.

     Parameters
     ----------
     direction
         Key into ``self.flatness`` (matched case-insensitively); also titles the plot.
     axis
         Axes to draw into. A fresh figure/axes pair is created when omitted.
     """
     plt.ioff()
     if axis is None:
         _, axis = plt.subplots()
     flat = self.flatness[direction.lower()]
     axis.set_title(f"{direction.capitalize()} Flatness")
     axis.plot(flat['profile'].values)
     _remove_ticklabels(axis)
     for level_key in ('profile max', 'profile min'):
         axis.axhline(flat[level_key], color='r')
     for edge_key in ('profile left', 'profile right'):
         axis.axvline(flat[edge_key], color='g', linestyle='-.')
예제 #3
0
def generate_distance_plot(distances: List[float], similarity_cutoff: float, filename: Optional[Path] = None, ax: plt.Axes = None):
    """Show the spread of computed pairwise distances as a histogram.

    Parameters
    ----------
    distances
        Pairwise distances to histogram (20 bins, with a rug). Must be
        non-empty: its maximum sets the upper x-limit.
    similarity_cutoff
        Threshold drawn as a red vertical line.
    filename
        When given, the figure owning ``ax`` is saved there; otherwise
        ``plt.show()`` is called.
    ax
        Axes to draw into. A new 12x10 figure is created when omitted.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(12, 10))
    # Lay out / save the figure that actually owns ``ax``; pyplot's "current"
    # figure can be a different one when the caller supplies ``ax``.
    fig = ax.figure

    # NOTE(review): seaborn.distplot is deprecated upstream; histplot + rugplot
    # is the replacement, but changes defaults, so it is kept here.
    seaborn.distplot(distances, ax=ax, kde=False, rug=True, bins=20)
    ax.axvline(similarity_cutoff, color='red')

    ax.set_title("Pairwise distances between each pair of trajectories")
    ax.set_xlabel("Distance")
    ax.set_ylabel("Count")
    ax.set_xlim(0, max(distances))
    fig.tight_layout()
    if filename:
        fig.savefig(filename)
    else:
        plt.show()
예제 #4
0
def plot_ghi_curves(
        clearsky_ghi: np.ndarray,
        station_ghi: np.ndarray,
        pred_ghi: typing.Optional[np.ndarray],
        window_start: datetime.datetime,
        window_end: datetime.datetime,
        sample_step: datetime.timedelta,
        horiz_offset: datetime.timedelta,
        ax: plt.Axes,
        station_name: typing.Optional[typing.AnyStr] = None,
        station_color: typing.Optional[typing.AnyStr] = None,
        current_time: typing.Optional[datetime.datetime] = None,
) -> plt.Axes:
    """Plots a set of GHI curves and returns the associated matplotlib axes object.

    This function is used in ``draw_daily_ghi`` and ``preplot_live_ghi_curves`` to create simple
    graphs of GHI curves (clearsky, measured, predicted).

    Args:
        clearsky_ghi: 1-D clearsky GHI values, one per timestamp of
            ``pd.date_range(window_start, window_end, freq=sample_step)``.
        station_ghi: 1-D measured GHI values, same length as ``clearsky_ghi``.
        pred_ghi: optional 1-D predicted GHI values (same length), plotted at
            the same timestamps shifted by ``horiz_offset``.
        window_start: first timestamp of the sampled window.
        window_end: last timestamp of the sampled window.
        sample_step: spacing between consecutive samples.
        horiz_offset: time shift applied to the x positions of ``pred_ghi``.
        ax: axes to draw into; it is also returned.
        station_name: optional name appended to the measured-curve label.
        station_color: optional color for the measured curve.
        current_time: when given, a red vertical line labelled "current" marks it.

    Returns:
        The ``ax`` that was drawn into.
    """
    assert clearsky_ghi.ndim == 1 and station_ghi.ndim == 1 and clearsky_ghi.size == station_ghi.size
    assert pred_ghi is None or (pred_ghi.ndim == 1 and clearsky_ghi.size == pred_ghi.size)
    major_tick_locator = matplotlib.dates.HourLocator(interval=4)
    minor_tick_locator = matplotlib.dates.HourLocator(interval=1)
    datetime_fmt = matplotlib.dates.DateFormatter("%H:%M")
    real_range = pd.date_range(window_start, window_end, freq=sample_step)
    xrange_real = matplotlib.dates.date2num([d.to_pydatetime() for d in real_range])
    if current_time is not None:
        ax.axvline(x=matplotlib.dates.date2num(current_time), color="r", label="current")
    measured_label = f"measured ({station_name})" if station_name else "measured"
    ax.plot(xrange_real, clearsky_ghi, ":", label="clearsky")
    if station_color is not None:
        ax.plot(xrange_real, station_ghi, linestyle="solid", color=station_color, label=measured_label)
    else:
        ax.plot(xrange_real, station_ghi, linestyle="solid", label=measured_label)
    offset_range = pd.date_range(window_start + horiz_offset, window_end + horiz_offset, freq=sample_step)
    xrange_offset = matplotlib.dates.date2num([d.to_pydatetime() for d in offset_range])
    if pred_ghi is not None:
        ax.plot(xrange_offset, pred_ghi, ".-", label="predicted")
    ax.xaxis.set_major_locator(major_tick_locator)
    ax.xaxis.set_major_formatter(datetime_fmt)
    ax.xaxis.set_minor_locator(minor_tick_locator)
    # Number of samples per hour; the view skips ``trim`` samples at the start
    # and stops ``trim - 1`` samples before the end of the window.
    hour_offset = datetime.timedelta(hours=1) // sample_step
    trim = max(hour_offset - 1, 0)
    last_idx = len(xrange_real) - 1
    # Clamp both indices: with hourly or coarser sampling (hour_offset <= 1)
    # the unclamped ``-hour_offset + 1`` / ``hour_offset - 1`` indices wrap to
    # the wrong end of the array and collapse or invert the x-range.
    start_idx = trim
    end_idx = min(last_idx + 1 - trim, last_idx)
    ax.set_xlim(xrange_real[start_idx], xrange_real[end_idx])
    ax.format_xdata = matplotlib.dates.DateFormatter("%Y-%m-%d %H:%M")
    ax.grid(True)
    return ax
예제 #5
0
def _plot_cluster(ax: plt.Axes, cluster: pd.Series) -> None:
    """Shade the index span of *cluster* and outline both ends with thin dashed lines.

    Clusters with fewer than two entries are left undrawn.
    """
    if len(cluster) < 2:
        return
    span_start = cluster.index[0]
    span_end = cluster.index[-1]
    grey = "#D1D3D4"
    ax.axvspan(
        xmin=span_start,
        xmax=span_end,
        alpha=0.25,
        edgecolor="None",
        facecolor=grey,
        zorder=2.5,
    )
    for edge in (span_start, span_end):
        ax.axvline(edge, ls="--", lw=0.5, color=grey, zorder=5)
예제 #6
0
def scatter_peaks_no_peaks(
    top_eco: pd.Series,
    top_naked: pd.Series,
    non_top_eco: pd.Series,
    non_top_naked: pd.Series,
    ax: plt.Axes = None,
):
    """Scatter chromatin vs. naked values for the "top" points and all others.

    Each group gets a mean line on both axes (C0 for the rest, C1 for the
    top points), and the Pearson R of each group is written on the axes.

    Parameters
    ----------
    top_eco, top_naked
        Chromatin / naked values of the "Open ATAC" points. Must be 1-D
        (Series-like): ``.mean()`` has to yield a scalar for ``axvline`` /
        ``axhline``, and each becomes one column of a DataFrame below.
    non_top_eco, non_top_naked
        Chromatin / naked values of the remaining points (same requirements).
    ax
        Axes to draw into. A new 12x12 figure is created when omitted.

    Returns
    -------
    The axes that were drawn into.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(12, 12))
    ax.set_xlabel("Chromatin")
    ax.set_ylabel("Naked")
    ax.scatter(
        non_top_eco,
        non_top_naked,
        alpha=0.2,
        label="All Points",
    )
    ax.scatter(top_eco, top_naked, label="Open ATAC")

    ax.axvline(non_top_eco.mean(), color="C0")
    ax.axvline(top_eco.mean(), color="C1")
    ax.axhline(non_top_naked.mean(), color="C0")
    ax.axhline(top_naked.mean(), color="C1")

    ax.legend(
        loc="upper right",
        frameon=False,
        shadow=False,
    )
    # Pair each chromatin series with its naked series in one DataFrame so
    # that dropna() removes a row missing in either, keeping the two
    # columns aligned for pearsonr.
    top = pd.DataFrame({"chrom": top_eco, "naked": top_naked}).dropna(axis=0)
    all_ = pd.DataFrame({
        "chrom": non_top_eco,
        "naked": non_top_naked
    }).dropna(axis=0)
    r_top, _ = scipy.stats.pearsonr(top.loc[:, "chrom"], top.loc[:, "naked"])
    r_all, _ = scipy.stats.pearsonr(all_.loc[:, "chrom"], all_.loc[:, "naked"])
    ax.text(0.01,
            0.8,
            f"R (top) = {r_top} \nR (rest) = {r_all}",
            transform=ax.transAxes)
    return ax
def add_sh_order_lines(ax: plt.Axes,
                       order=None,
                       args_dict=None,
                       x_flag=True,
                       y_flag=True):
    """Draw guide lines (red by default) at the boundaries between SH orders.

    For each degree ``n`` in ``range(order)``, a line is drawn at
    ``sh.nm2i(n, n) + 0.5``.

    Parameters
    ----------
    ax : plt.Axes
        Axes to draw on. Presumably its axes index flattened spherical-harmonic
        coefficients (NOTE(review): confirm against callers).
    order : int, optional
        Number of degrees to mark. When omitted it is derived from the current
        upper x-limit via ``sh.i2nm``.
    args_dict : dict, optional
        Extra keyword arguments for ``axvline`` / ``axhline``. Entries here
        override the default ``color='red'``.
    x_flag, y_flag : bool
        Whether to draw vertical / horizontal lines.
    """
    from src.utils.sphere import sh

    # Merge instead of passing color= alongside **args_dict: a caller-supplied
    # 'color' would otherwise raise "got multiple values for keyword argument".
    line_kwargs = {'color': 'red', **(args_dict or {})}

    if order is None:
        order = sh.i2nm(np.floor(ax.get_xlim()[1]))[0]

    degrees = np.arange(order)
    locs = sh.nm2i(degrees, degrees) + 0.5
    for loc in locs:
        if x_flag:
            ax.axvline(loc, **line_kwargs)
        if y_flag:
            ax.axhline(loc, **line_kwargs)
예제 #8
0
    def plot_residual(self, ax: plt.Axes = None) -> plt.Axes:
        """Draw a funnel plot of residuals against their standard errors.

        All points are drawn in grey. Rows whose study id contains ``"fill"``
        are overdrawn in teal, and rows with ``outlier == 1`` get red crosses.
        The shaded funnel spans +/- 1.96 x standard error around zero. The
        y-axis runs from the 99th-percentile standard error (bottom) up to 0
        (top). The title reports the Egger mean / sd / p-value from
        ``self.se_model``.

        Args:
            ax: Axes to draw on. When ``None`` (the default), a new figure and
                axes are created.

        Returns:
            The axes that were drawn on.
        """
        # compute the residual and observation standard deviation
        residual = self.df["residual"]
        obs_se = self.df["residual_se"]
        max_obs_se = np.quantile(obs_se, 0.99)
        # na=False: a missing study id counts as "not filled" instead of
        # producing NaN, which cannot be used as a boolean mask.
        fill_index = self.df[self.model.cwdata.col_study_id].str.contains(
            "fill", na=False)
        outlier_index = self.df.outlier == 1

        # create funnel plot
        ax = plt.subplots()[1] if ax is None else ax
        ax.set_ylim(max_obs_se, 0.0)
        ax.scatter(residual, obs_se, color="gray", alpha=0.4)
        if fill_index.any():
            ax.scatter(residual[fill_index],
                       obs_se[fill_index],
                       color="#008080",
                       alpha=0.7)
        ax.scatter(residual[outlier_index],
                   obs_se[outlier_index],
                   color='red',
                   marker='x',
                   alpha=0.4)
        ax.fill_betweenx([0.0, max_obs_se], [0.0, -1.96 * max_obs_se],
                         [0.0, 1.96 * max_obs_se],
                         color='#B0E0E6',
                         alpha=0.4)
        ax.plot([0, -1.96 * max_obs_se], [0.0, max_obs_se],
                linewidth=1,
                color='#87CEFA')
        ax.plot([0.0, 1.96 * max_obs_se], [0.0, max_obs_se],
                linewidth=1,
                color='#87CEFA')
        ax.axvline(0.0, color='k', linewidth=1, linestyle='--')
        ax.set_xlabel("residual")
        ax.set_ylabel("ln_rr_se")
        ax.set_title(
            f"{self.name}: egger_mean={self.se_model['mean']: .3f}, "
            f"egger_sd={self.se_model['sd']: .3f}, "
            f"egger_pval={self.se_model['pval']: .3f}",
            loc="left")
        return ax
예제 #9
0
def add_tukey_marks(
    data: pd.Series,
    ax: plt.Axes,
    annot: bool = True,
    iqr_color: str = "r",
    fence_color: str = "k",
    fence_style: str = "--",
    annot_quarts: bool = False,
) -> plt.Axes:
    """Shade the interquartile range and draw Tukey fences on a histogram-like plot.

    Args:
        data (pd.Series): Values used to compute the quartiles and fences.
        ax (plt.Axes): Axes to draw on.
        annot (bool, optional): Write "IQR" and "Fence" labels just above the
            current top y-limit. Defaults to True.
        iqr_color (str, optional): Color of the shaded IQR band. Defaults to "r".
        fence_color (str, optional): Color of the fence lines. Defaults to "k".
        fence_style (str, optional): Line style of the fences. Defaults to "--".
        annot_quarts (bool, optional): Also label Q1 and Q3; only takes effect
            when ``annot`` is True. Defaults to False.

    Returns:
        plt.Axes: The same ``ax``, annotated.
    """
    first_q = data.quantile(0.25)
    third_q = data.quantile(0.75)
    ax.axvspan(first_q, third_q, color=iqr_color, alpha=0.2)
    band_mid = first_q + ((third_q - first_q) / 2)
    fence_lo, fence_hi = outliers.tukey_fences(data)
    for fence in (fence_lo, fence_hi):
        ax.axvline(fence, c=fence_color, ls=fence_style)
    label_y = ax.get_ylim()[1] * 1.01
    if not annot:
        return ax
    labels = [(band_mid, "IQR")]
    if annot_quarts:
        labels.extend([(first_q, "Q1"), (third_q, "Q3")])
    labels.extend([(fence_hi, "Fence"), (fence_lo, "Fence")])
    for x_pos, text in labels:
        ax.text(x_pos, label_y, text, ha="center")
    return ax
예제 #10
0
 def _plot_image(self, axis: plt.Axes = None, title: str = ''):
     """Show the image with red lines bounding the horizontal and vertical profile bands.

     Each band is ``position +/- width / 2``, scaled by the image extent.
     Positions and widths appear to be fractions of the image size
     (NOTE(review): confirm).

     Parameters
     ----------
     axis
         Axes to draw into. A fresh figure/axes pair is created when omitted.
     title
         Text used as the axes title.
     """
     plt.ioff()
     if axis is None:
         _, axis = plt.subplots()
     pixels = self.image.array
     axis.imshow(pixels, cmap=get_dicom_cmap())

     def _band_edges(direction, extent):
         # (low, high) edge of the profile band for *direction*, in pixels
         centre = self.positions[direction]
         half_width = self.widths[direction] / 2
         return (centre - half_width) * extent, (centre + half_width) * extent

     # horizontal profile band: its edges are rows -> horizontal lines
     for row in _band_edges('horizontal', pixels.shape[0]):
         axis.axhline(row, color='r')
     # vertical profile band: its edges are columns -> vertical lines
     for col in _band_edges('vertical', pixels.shape[1]):
         axis.axvline(col, color='r')
     _remove_ticklabels(axis)
     axis.set_title(title)
예제 #11
0
파일: flatsym.py 프로젝트: jrkerns/pylinac
 def _plot_symmetry(self, direction: str, axis: plt.Axes=None):
     """Plot the symmetry profile for *direction* with edge/CAX markers and its mirror.

     Green dash-dot lines mark the profile edges and a magenta line marks the
     FWXM centre (CAX). If a symmetry array is available, its right half is
     plotted on a twin y-axis. The profile reversed and shifted by twice the
     CAX offset from the array centre is overlaid, i.e. approximately mirrored
     about the CAX.

     Parameters
     ----------
     direction
         Key into ``self.symmetry`` (matched case-insensitively); also titles the plot.
     axis
         Axes to draw into. A fresh figure/axes pair is created when omitted.
     """
     import numpy as np

     plt.ioff()
     if axis is None:
         _, axis = plt.subplots()
     data = self.symmetry[direction.lower()]
     axis.set_title(direction.capitalize() + " Symmetry")
     axis.plot(data['profile'].values)
     # plot lines
     cax_idx = data['profile'].fwxm_center()
     axis.axvline(data['profile left'], color='g', linestyle='-.')
     axis.axvline(data['profile right'], color='g', linestyle='-.')
     axis.axvline(cax_idx, color='m', linestyle='-.')
     # plot symmetry array
     # ``not (array == 0)`` is ambiguous for a multi-element NumPy array, so
     # test for an actual array instead. A scalar (presumably a 0 placeholder
     # when no array was computed - confirm) is skipped.
     sym_array = data['array']
     if np.ndim(sym_array) > 0:
         twin_axis = axis.twinx()
         twin_axis.plot(range(cax_idx, data['profile right']), sym_array[int(round(len(sym_array)/2)):])
         twin_axis.set_ylabel("Symmetry (%)")
     _remove_ticklabels(axis)
     # plot profile mirror
     central_idx = int(round(data['profile'].values.size / 2))
     offset = cax_idx - central_idx
     mirror_vals = data['profile'].values[::-1]
     axis.plot(data['profile']._indices + 2 * offset, mirror_vals)
예제 #12
0
 def _plot_symmetry(self, direction: str, axis: plt.Axes=None):
     """Plot the symmetry profile for *direction* with edge/CAX markers and its mirror.

     Green dash-dot lines mark the profile edges and a magenta line marks the
     FWXM centre (CAX). If a symmetry array is available, its right half is
     plotted on a twin y-axis. The profile reversed and shifted by twice the
     CAX offset from the array centre is overlaid, i.e. approximately mirrored
     about the CAX.

     Parameters
     ----------
     direction
         Key into ``self.symmetry`` (matched case-insensitively); also titles the plot.
     axis
         Axes to draw into. A fresh figure/axes pair is created when omitted.
     """
     import numpy as np

     plt.ioff()
     if axis is None:
         _, axis = plt.subplots()
     data = self.symmetry[direction.lower()]
     axis.set_title(direction.capitalize() + " Symmetry")
     axis.plot(data['profile'].values)
     # plot lines
     cax_idx = data['profile'].fwxm_center()
     axis.axvline(data['profile left'], color='g', linestyle='-.')
     axis.axvline(data['profile right'], color='g', linestyle='-.')
     axis.axvline(cax_idx, color='m', linestyle='-.')
     # plot symmetry array
     # ``not (array == 0)`` is ambiguous for a multi-element NumPy array, so
     # test for an actual array instead. A scalar (presumably a 0 placeholder
     # when no array was computed - confirm) is skipped.
     sym_array = data['array']
     if np.ndim(sym_array) > 0:
         twin_axis = axis.twinx()
         twin_axis.plot(range(cax_idx, data['profile right']), sym_array[int(round(len(sym_array)/2)):])
         twin_axis.set_ylabel("Symmetry (%)")
     _remove_ticklabels(axis)
     # plot profile mirror
     central_idx = int(round(data['profile'].values.size / 2))
     offset = cax_idx - central_idx
     mirror_vals = data['profile'].values[::-1]
     axis.plot(data['profile']._indices + 2 * offset, mirror_vals)
예제 #13
0
def add_geodesic_grid(ax: plt.Axes, manifold: Stereographic, line_width=0.1):
    """Overlay a dashed grey grid of geodesics on *ax* for a stereographic manifold.

    The grid is an x/y crosshair through the origin plus, for each step
    ``i * grid_interval_size`` along one crosshair direction, a geodesic
    started at that point and heading in the other crosshair direction.
    Grid spacing is always derived from the circumference ``2 * pi * R``.

    Parameters
    ----------
    ax : plt.Axes
        Axes to draw on.
    manifold : Stereographic
        Manifold providing ``k``, ``radius``, ``dtype``, ``projx``, ``dist0``
        and ``geodesic_unit``. Points are built as 2-tuples, so a 2-D
        manifold is assumed (TODO confirm).
    line_width : float
        Line width used for every grid line.

    Returns
    -------
    None. Lines are drawn onto ``ax``.
    """
    import math
    # define geodesic grid parameters
    N_EVALS_PER_GEODESIC = 10000
    STYLE = "--"
    COLOR = "gray"
    LINE_WIDTH = line_width

    # get manifold properties
    K = manifold.k.item()
    R = manifold.radius.item()

    # get maximal numerical distance to origin on manifold
    if K < 0:
        # create point on R
        r = torch.tensor((R, 0.0), dtype=manifold.dtype)
        # project point on R into valid range (epsilon border)
        r = manifold.projx(r)
        # determine distance from origin
        max_dist_0 = manifold.dist0(r).item()
    else:
        # K >= 0: pi * R, half the circumference computed below
        max_dist_0 = math.pi * R
    # adjust line interval for spherical geometry
    circumference = 2 * math.pi * R

    # determine reasonable number of geodesics
    # choose the grid interval size always as if we'd be in spherical
    # geometry, such that the grid interpolates smoothly and evenly
    # divides the sphere circumference
    n_geodesics_per_circumference = 4 * 6  # multiple of 4!
    # NOTE(review): for K >= 0 this is half the per-circumference count (12),
    # so start offsets run up to 11/24 of the circumference. The name
    # "per quadrant" does not describe that - confirm intent.
    n_geodesics_per_quadrant = n_geodesics_per_circumference // 2
    grid_interval_size = circumference / n_geodesics_per_circumference
    if K < 0:
        # K < 0: only as many grid steps as fit within max_dist_0
        n_geodesics_per_quadrant = int(max_dist_0 / grid_interval_size)

    # create time evaluation array for geodesics
    # (symmetric interval [min_t, -min_t]; [:, None] makes it a column)
    if K < 0:
        min_t = -1.2 * max_dist_0
    else:
        min_t = -circumference / 2.0
    t = torch.linspace(min_t, -min_t, N_EVALS_PER_GEODESIC)[:, None]

    # define a function to plot the geodesics
    # (gv is transposed so each coordinate row becomes one ax.plot argument;
    # assumes gv has shape (N_EVALS_PER_GEODESIC, 2) - TODO confirm)
    def plot_geodesic(gv):
        ax.plot(*gv.t().numpy(), STYLE, color=COLOR, linewidth=LINE_WIDTH)

    # define geodesic directions
    # NOTE(review): u_x is (0, 1) and u_y is (1, 0); the names seem to refer
    # to the geodesic family they generate rather than the vector direction.
    u_x = torch.tensor((0.0, 1.0))
    u_y = torch.tensor((1.0, 0.0))

    # add origin x/y-crosshair
    o = torch.tensor((0.0, 0.0))
    if K < 0:
        x_geodesic = manifold.geodesic_unit(t, o, u_x)
        y_geodesic = manifold.geodesic_unit(t, o, u_y)
        plot_geodesic(x_geodesic)
        plot_geodesic(y_geodesic)
    else:
        # add the crosshair manually for the sproj of sphere
        # because the lines tend to get thicker if plotted
        # as done for K<0
        ax.axvline(0, linestyle=STYLE, color=COLOR, linewidth=LINE_WIDTH)
        ax.axhline(0, linestyle=STYLE, color=COLOR, linewidth=LINE_WIDTH)

    # add geodesics per quadrant
    for i in range(1, n_geodesics_per_quadrant):
        i = torch.as_tensor(float(i))
        # determine start of geodesic on x/y-crosshair
        # (x: i grid steps from the origin along u_y; y: same along u_x)
        x = manifold.geodesic_unit(i * grid_interval_size, o, u_y)
        y = manifold.geodesic_unit(i * grid_interval_size, o, u_x)

        # compute point on geodesics
        x_geodesic = manifold.geodesic_unit(t, x, u_x)
        y_geodesic = manifold.geodesic_unit(t, y, u_y)

        # plot geodesics
        plot_geodesic(x_geodesic)
        plot_geodesic(y_geodesic)
        if K < 0:
            # for K < 0 also draw the point-reflected (negated) copies to
            # cover the opposite side of the crosshair
            plot_geodesic(-x_geodesic)
            plot_geodesic(-y_geodesic)
예제 #14
0
    def overlay_entropy_profiles(self,
                                 axes: plt.Axes = None,
                                 r_units: str = 'r500',
                                 k_units: str = 'K500adi',
                                 vkb05_line: bool = True,
                                 color: str = 'k',
                                 alpha: float = 1.,
                                 markersize: float = 1,
                                 linewidth: float = 0.5) -> None:
        """Overlay median entropy vs. radius at several apertures, with percentile error bars.

        For each aperture (500, 1000, 1500, 2500, 0.15 r500, 30 kpc), one
        point is drawn at the median radius and median entropy over all
        objects. Asymmetric error bars span the 16th-84th percentiles (NaNs
        ignored).

        Args:
            axes: Axes to draw on. When ``None``, a new log-log figure is
                created, labelled, given a legend and shown ("stand-alone").
                In stand-alone mode points use the default color cycle and
                carry labels; otherwise ``color`` / ``alpha`` are applied.
            r_units: Radial units: ``'r500'``, ``'r2500'`` or ``'kpc'``.
            k_units: Entropy units: ``'K500adi'`` (divided by
                ``self.K_500_adi``) or ``'keVcm^2'``.
            vkb05_line: Draw the VKB05 adiabatic threshold. It is only drawn
                for ``r500`` / ``K500adi`` units; otherwise a note is printed.
            color: Color of overlay points, error bars and reference lines.
            alpha: Transparency of overlay points, error bars and reference lines.
            markersize: Marker size for points and error bars.
            linewidth: Error-bar line width.

        Raises:
            ValueError: If ``k_units`` or ``r_units`` is not recognised.
        """

        def _percentile_band(values):
            # (16th, 50th, 84th) percentiles, ignoring NaNs
            return (np.nanpercentile(values, 16),
                    np.nanpercentile(values, 50),
                    np.nanpercentile(values, 84))

        stand_alone = False
        if axes is None:
            stand_alone = True
            _, axes = plt.subplots()
            axes.loglog()
            axes.set_xlabel(f'$r$ [{r_units}]')
            axes.set_ylabel(f'$K$ [${k_units}$]')
            axes.axvline(1, linestyle=':', color=color, alpha=alpha)

        # Set-up entropy data
        fields = [
            'K_500', 'K_1000', 'K_1500', 'K_2500', 'K_0p15r500', 'K_30kpc'
        ]
        K_stat = dict()
        if k_units == 'K500adi':
            K_conv = 1 / getattr(self, 'K_500_adi')
            axes.axhline(1, linestyle=':', color=color, alpha=alpha)
        elif k_units == 'keVcm^2':
            K_conv = np.ones_like(getattr(self, 'K_500_adi'))
            axes.fill_between(np.array(axes.get_xlim()),
                              y1=np.nanmin(self.K_500_adi),
                              y2=np.nanmax(self.K_500_adi),
                              facecolor='k',
                              alpha=0.3)
        else:
            raise ValueError("Conversion unit unknown.")
        for field in fields:
            data = np.multiply(getattr(self, field), K_conv)
            K_stat[field] = _percentile_band(data)
            # 'K_500' -> 'num_500', 'K_0p15r500' -> 'num_0p15r500', ...
            K_stat[field.replace('K', 'num')] = np.count_nonzero(~np.isnan(data))

        # Set-up radial distance data
        if r_units == 'r500':
            r_conv = 1 / getattr(self, 'r_500')
        elif r_units == 'r2500':
            r_conv = 1 / getattr(self, 'r_2500')
        elif r_units == 'kpc':
            r_conv = np.ones_like(getattr(self, 'r_2500'))
        else:
            raise ValueError("Conversion unit unknown.")
        radii = {
            field: np.multiply(getattr(self, field), r_conv)
            for field in ['r_500', 'r_1000', 'r_1500', 'r_2500']
        }
        radii['r_0p15r500'] = np.multiply(getattr(self, 'r_500') * 0.15, r_conv)
        radii['r_30kpc'] = np.multiply(
            np.ones_like(getattr(self, 'r_2500')) * 30 * unyt.kpc, r_conv)
        r_stat = dict()
        for field, data in radii.items():
            if k_units == 'K500adi':
                # drop objects that have no K_500_adi value
                data[np.isnan(self.K_500_adi)] = np.nan
            r_stat[field] = _percentile_band(data)
            # only the leading 'r': 'r_0p15r500' -> 'num_0p15r500'
            r_stat[field.replace('r', 'num', 1)] = np.count_nonzero(~np.isnan(data))

        for suffix in [
                '_500', '_1000', '_1500', '_2500', '_0p15r500', '_30kpc'
        ]:
            x_low, x, x_hi = r_stat['r' + suffix]
            y_low, y, y_hi = K_stat['K' + suffix]
            # matplotlib reads a (2, N) error as [[lower errors], [upper errors]],
            # so the distance below the median must come first.
            xerr = [[x - x_low], [x_hi - x]]
            yerr = [[y - y_low], [y_hi - y]]
            num_objects = f"{r_stat['num' + suffix]}, {K_stat['num' + suffix]}"
            point_label = f"r{suffix:.<17s} Num(x,y) = {num_objects}"
            if stand_alone:
                axes.scatter(x, y, label=point_label, s=markersize)
                axes.errorbar(x,
                              y,
                              yerr=yerr,
                              xerr=xerr,
                              ls='none',
                              ms=markersize,
                              lw=linewidth)
            else:
                axes.scatter(x, y, color=color, alpha=alpha, s=markersize)
                axes.errorbar(x,
                              y,
                              yerr=yerr,
                              xerr=xerr,
                              ls='none',
                              ecolor=color,
                              alpha=alpha,
                              ms=markersize,
                              lw=linewidth)

        if vkb05_line:
            if r_units == 'r500' and k_units == 'K500adi':
                r = np.linspace(*axes.get_xlim(), 31)
                k = 1.40 * r**1.1 / self.hconv
                axes.plot(r, k, linestyle='--', color=color, alpha=alpha)
            else:
                print((
                    "The VKB05 adiabatic threshold should be plotted only when both "
                    "axes are in scaled units, since the line is calibrated on an NFW "
                    "profile with self-similar halos with an average concentration of "
                    "c_500 ~ 4.2 for the objects in the Sun et al. (2009) sample."
                ))

        if k_units == 'K500adi':
            r_r500, S_S500_50, S_S500_10, S_S500_90 = self.get_shortcut()

            # Draw on ``axes`` explicitly: pyplot's current axes need not be
            # the one passed in by the caller.
            axes.fill_between(r_r500,
                              S_S500_10,
                              S_S500_90,
                              color='grey',
                              alpha=0.5,
                              linewidth=0)
            axes.plot(r_r500, S_S500_50, c='k')

        if stand_alone:
            axes.legend()
            plt.show()
예제 #15
0
 def _draw_truth_offdiag(ax: plt.Axes, truth_x, truth_y, **kwargs):
     """Mark a true ``(truth_x, truth_y)`` point on *ax* with crossing lines.

     ``truth_x`` becomes a vertical line and ``truth_y`` a horizontal one.
     ``kwargs`` are forwarded unchanged to both line calls. Presumably used on
     an off-diagonal panel of a pair/corner plot (NOTE(review): confirm).
     """
     for draw_line, value in ((ax.axvline, truth_x), (ax.axhline, truth_y)):
         draw_line(value, **kwargs)
예제 #16
0
def show_latent(
    seq: Sequence,
    ax: plt.Axes = None,
    bounds: Optional[Tuple] = None,
    colors: Sequence = None,
    show_bars: bool = True,
    bar_width: float = 0.1,
    bar_location: str = "top",
    show_vlines: bool = True,
    vline_kws: Optional[dict] = None,
    shift: float = 0,
):
    """ Display a bar plot showing how the latent state changes with time.

    The bars are drawn either above or below the current extents of the plot, expanding
    the y limits appropriately.

    Parameters
    ----------
    seq
        Sequence indicating the latent state. Element `i` is the state at time point
        `i`. States are used as integer indices into `colors` (modulo its length).
    ax
        Axes in which to draw the bars. If not given, `plt.gca()` is used.
    bounds
        If not `None`, this should be a tuple `(t0, t1)` such that the latent state is
        shown only for time points `t >= t0` and `t < t1`. If this is `None`, the
        extents are inferred from the current axis limits.
    colors
        Sequence of colors to use for the identities. By default Matplotlib's default
        color cycle is used.
    show_bars
        If `True`, colored bars are drawn to indicate the current state.
    bar_width
        Width of the bars, given as a fraction of the vertical extent of the plot. Note
        that this is calculated at the moment the function is called.
    bar_location
        Location of the bars. This can be "top" or "bottom".
    show_vlines
        If `True`, vertical lines are drawn to show transition points.
    vline_kws
        Keywords to pass to `axvline`. Defaults for `ls`, `lw` and `c` are filled in
        only where the caller did not provide them; the caller's dict is not modified.
    shift
        Amount by which to shift bars and lines to the right (towards higher values).

    Returns
    -------
    None. Drawing happens in place on `ax`; nothing is drawn when `seq` is empty.

    Raises
    ------
    ValueError
        If `show_bars` is true and `bar_location` is neither "top" nor "bottom".
    """
    # handle trivial case
    if len(seq) == 0:
        return

    # handle defaults
    if ax is None:
        ax = plt.gca()
    if colors is None:
        prop_cycle = plt.rcParams["axes.prop_cycle"]
        colors = prop_cycle.by_key()["color"]
    if bounds is None:
        bounds = ax.get_xlim()

    # find transition points
    # (index of the first sample of each new state, i.e. where seq[i] != seq[i - 1])
    transitions = np.diff(seq).nonzero()[0] + 1

    # find the first transition in the given range
    # (idx0 stays None when every transition lies left of bounds[0])
    visible_mask = transitions + shift >= bounds[0]
    if np.any(visible_mask):
        idx0 = visible_mask.argmax()
    else:
        idx0 = None

    if show_vlines and idx0 is not None:
        # set up the vline parameters
        # (shallow copy so setdefault below does not touch the caller's dict)
        if vline_kws is not None:
            crt_vline_kws = copy.copy(vline_kws)
        else:
            crt_vline_kws = {}
        crt_vline_kws.setdefault("ls", ":")
        crt_vline_kws.setdefault("lw", 0.5)
        crt_vline_kws.setdefault("c", "k")

        for transition in transitions[idx0:]:
            if transition + shift >= bounds[1]:
                break
            ax.axvline(transition + shift, **crt_vline_kws)

    if show_bars:
        # find how big the bar is in data coordinates...
        yl = ax.get_ylim()
        yrange = yl[1] - yl[0]
        bar_width_data = yrange * bar_width

        # ...and where to place it
        if bar_location == "top":
            bar_y = yl[1]
            # adjust limits
            yl = (yl[0], bar_y + bar_width_data)
        elif bar_location == "bottom":
            bar_y = yl[0] - bar_width_data
            # adjust limits
            yl = (bar_y, yl[1])
        else:
            raise ValueError("Unknown bar location option.")

        # start drawing!
        # Walk segment by segment in unshifted coordinates: [x0, x1) has the
        # constant state seq[int(x0)], and x1 is the next transition, clipped
        # to the end of seq and to the right bound.
        x0 = max(bounds[0] - shift, 0)
        if idx0 is not None:
            next_idx = idx0
        else:
            # no transitions in view: one segment running to the end of seq
            next_idx = len(transitions) + 1
        while x0 + shift < bounds[1] and int(x0) < len(seq):
            crt_id = seq[int(x0)]
            x1 = transitions[next_idx] if next_idx < len(transitions) else len(
                seq)
            x1 = min(x1, bounds[1] - shift)
            # skip empty segments (e.g. a transition exactly at x0)
            if x1 > x0:
                patch = patches.Rectangle(
                    (x0 + shift, bar_y),
                    x1 - x0,
                    bar_width_data,
                    edgecolor="none",
                    facecolor=colors[crt_id % len(colors)],
                )
                ax.add_patch(patch)

            next_idx += 1
            x0 = x1

        # adjust limits
        ax.set_ylim(*yl)
예제 #17
0
 def plot_apogee(self, ax: plt.Axes):
     """Mark ``self.time_apogee`` on *ax* with a thin dashed vertical line.

     Returns the ``Line2D`` created by ``axvline`` (labelled ``'apogee'``).
     """
     apogee_x = self.time_apogee.to_datetime()
     return ax.axvline(x=apogee_x, label='apogee', linestyle='--', linewidth=1)