Пример #1
0
def _explained_variance_plot(model, ax: mpl.axes.Axes, cutoff: float = 0.9):
    n = len(model.explained_variance_ratio_)
    _x = np.arange(1, n + 1)
    _ycum = np.cumsum(model.explained_variance_ratio_)
    best_index = np.where(_ycum > cutoff)[0]
    # modify in case we dont have one
    best_index = best_index[0] if best_index.shape[0] > 0 else n - 1
    # calculate AUC
    auc = np.trapz(_ycum, _x / n)
    # plot
    ax.plot(_x, _ycum, "x-")
    # plot best point
    ax.scatter(
        [_x[best_index]],
        [_ycum[best_index]],
        facecolors="None",
        edgecolors="red",
        s=100,
        label="n=%d, auc=%.3f" % (_x[best_index], auc),
    )
    # plot 0 to 1 line
    ax.plot([1, n], [0, 1], "k--")
    ax.set_xlabel("N\n(Best proportion: %.3f)" % (_x[best_index] / (n + 1)))
    ax.set_ylabel("Explained variance (ratio)\n(cutoff=%.2f)" % cutoff)
    ax.grid()
    ax.legend()
Пример #2
0
    def plot_modes(self, ax: matplotlib.axes.Axes, n: int = 0, iq: int = 0):
        r'''Plotting the phonon modes and its derivatives

        One curve per mode is drawn against ``self.v_array``. For ``n == 0`` the
        input frequencies of every volume in ``self.qha_input`` are also scattered
        against ``self.volumes``. At the first q point (``iq == 0``) the first
        three modes are skipped (presumably the acoustic modes at Gamma; confirm).

        :param ax: the axes to draw on
        :param n: the order of derivatives to be plotted:
            :math:`n = 0` for :math:`\omega_{qm}(V)`,
            :math:`n = 1` for :math:`\gamma_{qm}(V)`,
            :math:`n = 2` for :math:`V\frac{\partial\gamma_{qm}(V)}{\partial V}`
        :param iq: the index of :math:`q` point to be plotted
        :raises ValueError: if ``n`` is not 0, 1 or 2
        '''
        if n == 0:
            w_arrays = self.calculator.freq_array[:, iq, :]
        elif n in (1, 2):
            # mode_gamma[0] holds gamma, mode_gamma[1] its volume derivative.
            w_arrays = self.calculator.mode_gamma[n - 1][:, iq, :]
        else:
            # Without this, w_arrays would be unbound below (UnboundLocalError).
            raise ValueError(f'n must be 0, 1 or 2, got {n!r}')

        # Same mode selection for the curves and the scatter points.
        first_mode = 3 if iq == 0 else 0
        modes = range(first_mode, self.calculator.np)

        for k in modes:
            ax.plot(self.v_array, w_arrays[:, k])

        if n != 0:
            return

        for k in modes:
            freqs = numpy.array([
                volume.q_points[iq].modes[k]
                for volume in self.qha_input.volumes
            ])
            ax.scatter(self.volumes, freqs, s=10)
Пример #3
0
def _best_eigenvector_plot(
    x, y, labels: pd.Index, ax: mpl.axes.Axes, nk: Tuple[int, int] = (6, 5)
):
    n_samples, n_pcs = nk

    ax.scatter(x, y)
    ax.hlines(0, -0.5, n_pcs - 0.5, linestyle="--")
    annotate(x, y, labels, ax=ax, word_shorten=15)
    ax.set_ylabel("Eigenvector")
    ax.grid()
Пример #4
0
def _plot_data(ax: mpl.axes.Axes, data: PlotData) -> Optional[List[mpl.lines.Line2D]]:
    '''Draw one data series onto ``ax`` according to its display attributes.

    Dispatches on the pair (type of ``data``, type of ``data.display_attributes``):

    * ``XYData`` / ``TimeSeries`` with line, scatter, bar or filled-line attributes
      (a ``TimeSeries`` is plotted against its integer position, not its timestamps)
    * ``TradeSet`` with scatter attributes (also plotted against integer position)
    * ``TradeBarSeries`` with candlestick attributes, via ``draw_candlestick``
    * ``BucketedValues`` with box-plot attributes, via ``draw_boxplot``
    * ``XYZData`` with surface or contour attributes, via ``draw_3d_plot``

    Returns:
        An artist usable for legends: the single ``Line2D`` from ``ax.plot`` for line
        plots, the return value of ``ax.scatter`` for scatter plots, or of ``ax.bar``
        for bar plots. ``None`` for every other combination.
        NOTE(review): the annotation says ``List[Line2D]`` but a single artist (never a
        list) is returned; confirm what callers expect.

    Raises:
        Exception: if the data / display-attribute combination is not supported.
    '''
    x, y = None, None

    lines = None  # Return line objects so we can add legends

    disp = data.display_attributes

    if isinstance(data, XYData) or isinstance(data, TimeSeries):
        # Time series have no x values of their own; use their integer positions.
        x, y = (data.x, data.y) if isinstance(data, XYData) else (np.arange(len(data.timestamps)), data.values)
        if isinstance(disp, LinePlotAttributes):
            # Unpack the one-element list returned by ax.plot into a single Line2D.
            lines, = ax.plot(x, y, linestyle=disp.line_type, linewidth=disp.line_width, color=disp.color)
            if disp.marker is not None:  # type: ignore
                ax.scatter(x, y, marker=disp.marker, c=disp.marker_color, s=disp.marker_size, zorder=100)
        elif isinstance(disp, ScatterPlotAttributes):
            lines = ax.scatter(x, y, marker=disp.marker, c=disp.marker_color, s=disp.marker_size, zorder=100)
        elif isinstance(disp, BarPlotAttributes):
            lines = ax.bar(x, y, color=disp.color)  # type: ignore
        elif isinstance(disp, FilledLinePlotAttributes):
            # NaNs become 0 so fill_between gets finite values; the axis limits below
            # are then computed from these zero-filled arrays as well.
            x, y = np.nan_to_num(x), np.nan_to_num(y)
            # Split into positive and negative parts so each gets its own colour.
            pos_values = np.where(y > 0, y, 0)
            neg_values = np.where(y < 0, y, 0)
            ax.fill_between(x, pos_values, color=disp.positive_color, step='post', linewidth=0.0)
            ax.fill_between(x, neg_values, color=disp.negative_color, step='post', linewidth=0.0)
        else:
            raise Exception(f'unknown plot combination: {type(data)} {type(disp)}')

        # For scatter and filled line, xlim and ylim does not seem to get set automatically
        if isinstance(disp, ScatterPlotAttributes) or isinstance(disp, FilledLinePlotAttributes):
            # Only apply a limit when _adjust_axis_limit produced finite bounds.
            xmin, xmax = _adjust_axis_limit(ax.get_xlim(), x)
            if not np.isnan(xmin) and not np.isnan(xmax): ax.set_xlim((xmin, xmax))

            ymin, ymax = _adjust_axis_limit(ax.get_ylim(), y)
            if not np.isnan(ymin) and not np.isnan(ymax): ax.set_ylim((ymin, ymax))

    elif isinstance(data, TradeSet) and isinstance(disp, ScatterPlotAttributes):
        lines = ax.scatter(np.arange(len(data.timestamps)), data.values, marker=disp.marker, c=disp.marker_color, s=disp.marker_size, zorder=100)
    elif isinstance(data, TradeBarSeries) and isinstance(disp, CandleStickPlotAttributes):
        draw_candlestick(ax, np.arange(len(data.timestamps)), data.o, data.h, data.l, data.c, data.v, data.vwap, colorup=disp.colorup, colordown=disp.colordown)
    elif isinstance(data, BucketedValues) and isinstance(disp, BoxPlotAttributes):
        draw_boxplot(
            ax, data.bucket_names, data.bucket_values, disp.proportional_widths, disp.notched,  # type: ignore
            disp.show_outliers, disp.show_means, disp.show_all)  # type: ignore
    elif isinstance(data, XYZData) and (isinstance(disp, SurfacePlotAttributes) or isinstance(disp, ContourPlotAttributes)):
        display_type: str = 'contour' if isinstance(disp, ContourPlotAttributes) else 'surface'
        draw_3d_plot(ax, data.x, data.y, data.z, display_type, disp.marker, disp.marker_size,
                     disp.marker_color, disp.interpolation, disp.cmap)
    else:
        raise Exception(f'unknown plot combination: {type(data)} {type(disp)}')

    return lines
Пример #5
0
    def plot(self, ax: matplotlib.axes.Axes):
        """Draw the Bland-Altman plot onto ``ax``.

        Draws, in order:

        * the points (``self.mean`` against ``self.diff``);
        * a solid line at the mean difference;
        * dashed lines at ``mean_diff +/- loa_sd``;
        * an optional zero reference line;
        * optional confidence-interval bands;
        * a name and value label beside each horizontal line;
        * finally the axis limits, axis labels and title.

        :param ax: the axes to draw on
        """
        # individual points
        ax.scatter(self.mean, self.diff, s=20, alpha=0.6, color=self.color_points,
                   **self.point_kws)

        upper_loa = self.mean_diff + self.loa_sd
        lower_loa = self.mean_diff - self.loa_sd

        # mean difference and limits-of-agreement lines
        ax.axhline(self.mean_diff, color=self.color_mean, linestyle='-')
        ax.axhline(upper_loa, color=self.color_loa, linestyle='--')
        ax.axhline(lower_loa, color=self.color_loa, linestyle='--')

        if self.reference:
            ax.axhline(0, color='grey', linestyle='-', alpha=0.4)

        # confidence intervals (if requested)
        if self.CI is not None:
            ax.axhspan(self.CI_mean[0], self.CI_mean[1], color=self.color_mean, alpha=0.2)
            ax.axhspan(self.CI_upper[0], self.CI_upper[1], color=self.color_loa, alpha=0.2)
            ax.axhspan(self.CI_lower[0], self.CI_lower[1], color=self.color_loa, alpha=0.2)

        # Text uses axes coordinates horizontally (x=0.98 is just inside the right
        # edge) and data coordinates vertically, so it follows the lines.
        trans: matplotlib.transforms.Transform = transforms.blended_transform_factory(
            ax.transAxes, ax.transData)
        # Vertical gap between each line and its labels, proportional to loa * sd_diff.
        offset: float = (((self.loa * self.sd_diff) * 2) / 100) * 1.2
        # (y position, text, vertical alignment). The upper lines have their name
        # above and value below; the lower limit is mirrored.
        labels = (
            (self.mean_diff + offset, 'Mean', 'bottom'),
            (self.mean_diff - offset, f'{self.mean_diff:.2f}', 'top'),
            (upper_loa + offset, f'+{self.loa:.2f} SD', 'bottom'),
            (upper_loa - offset, f'{upper_loa:.2f}', 'top'),
            (lower_loa - offset, f'-{self.loa:.2f} SD', 'top'),
            (lower_loa + offset, f'{lower_loa:.2f}', 'bottom'),
        )
        for y_pos, text, valign in labels:
            ax.text(0.98, y_pos, text, ha="right", va=valign, transform=trans)

        # transform graphs
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)

        # set X and Y limits
        if self.xlim is not None:
            ax.set_xlim(self.xlim[0], self.xlim[1])
        if self.ylim is not None:
            ax.set_ylim(self.ylim[0], self.ylim[1])

        # graph labels
        ax.set_ylabel(self.y_title)
        ax.set_xlabel(self.x_title)
        if self.graph_title is not None:
            ax.set_title(self.graph_title)
Пример #6
0
def add_legend(ax: mpl.axes.Axes, labels: list, cmap1="cool", cmap2="bwr", s=20):
    """Add a legend of up to four proxy scatter markers above ``ax``.

    The entries are, in order:

    1. ``+`` at the top of ``cmap1``;
    2. ``+`` at the bottom of ``cmap1``;
    3. ``x`` at the top of ``cmap2``;
    4. ``x`` at the bottom of ``cmap2``.

    Labels are matched to these entries in order. Extra labels are ignored, and
    fewer labels produce fewer entries.

    Args:
        ax: Axes to add the legend to.
        labels: Legend labels.
        cmap1: Colormap (name or instance) for the ``+`` entries.
        cmap2: Colormap (name or instance) for the ``x`` entries.
        s: Base marker size (``+`` markers use ``2 * s``, ``x`` markers ``s + 5``).
    """
    def _get_cmap(cmap):
        # ``matplotlib.cm.get_cmap`` was deprecated in 3.7 and removed in 3.9;
        # use the colormap registry when this Matplotlib provides it.
        registry = getattr(mpl, "colormaps", None)
        if registry is not None and hasattr(registry, "get_cmap"):
            return registry.get_cmap(cmap)
        return mpl.cm.get_cmap(cmap)

    cmp1 = _get_cmap(cmap1)
    cmp2 = _get_cmap(cmap2)
    colors = (cmp1(1.0), cmp1(0.0), cmp2(1.0), cmp2(0.0))
    markers = ("+", "+", "x", "x")
    sizes = (s * 2, s * 2, s + 5, s + 5)
    # Empty scatters act purely as legend proxies; nothing is drawn on the axes.
    for label, c, m, size in zip(labels, colors, markers, sizes):
        ax.scatter([], [], color=c, marker=m, label=label, s=size)
    ax.legend(
        bbox_to_anchor=(0.0, 1.2, 1.0, 0.102),
        loc="lower left",
        ncol=2,
        mode="expand",
        borderaxespad=0.0,
    )
    def update_preview(self, list_obj: wx.ListBox, axes: matplotlib.axes.Axes):
        """
        Redraw the preview axes for the entry currently selected in ``list_obj``.

        The axes are always cleared and have their axis decorations switched off.
        When something is selected, a single large square marker is drawn at
        ``(1, 1)`` using the selected string as its colour.

        :param list_obj: The list whose selection is previewed
        :param axes: The preview axes to redraw
        """
        axes.clear()
        axes.axis("off")
        chosen_colour = list_obj.GetStringSelection()
        if chosen_colour != '':
            axes.scatter(1, 1, s=400, color=chosen_colour, marker='s')
Пример #8
0
def plot_2d_scatterplot(x: np.ndarray,
                        y: np.ndarray,
                        z: np.ndarray,
                        ax: matplotlib.axes.Axes,
                        colorbar: matplotlib.colorbar.Colorbar = None,
                        **kwargs) -> AxesTuple:
    """
    Make a 2D scatterplot of the data. ``**kwargs`` are passed to matplotlib's
    scatter used for the plotting. By default the data will be rasterized
    in any vector plot if more than ``qc.config.plotting.rasterize_threshold``
    points are supplied. This can be overridden by supplying the `rasterized`
    kwarg.

    String-valued ``z`` is treated as categorical. Each distinct string gets
    one integer colour code (in sorted order), and the colorbar shows one
    tick per category, labelled with that string.

    Args:
        x: The x values
        y: The y values
        z: The z values (numbers, or strings for categorical data)
        ax: The axis to plot onto
        colorbar: The colorbar to plot into

    Returns:
        The matplotlib axis handles for plot and colorbar
    """
    if 'rasterized' in kwargs:
        rasterized = kwargs.pop('rasterized')
    else:
        rasterized = len(z) > qc.config.plotting.rasterize_threshold

    # Guard the z[0] probe so empty input does not raise IndexError.
    z_is_string_valued = len(z) > 0 and isinstance(z[0], str)

    if z_is_string_valued:
        # One code per distinct string, so points in the same category share a
        # colour and the colorbar ticks can be labelled per category.
        z_categories, colour_values = np.unique(np.asarray(z), return_inverse=True)
    else:
        colour_values = z

    mappable = ax.scatter(x=x, y=y, c=colour_values, rasterized=rasterized, **kwargs)

    if colorbar is not None:
        colorbar = ax.figure.colorbar(mappable, ax=ax, cax=colorbar.ax)
    else:
        colorbar = ax.figure.colorbar(mappable, ax=ax)

    if z_is_string_valued:
        colorbar.set_ticks(list(range(len(z_categories))))
        colorbar.set_ticklabels([str(name) for name in z_categories])

    return ax, colorbar
Пример #9
0
def plot_2d_scatterplot(x: np.ndarray,
                        y: np.ndarray,
                        z: np.ndarray,
                        ax: matplotlib.axes.Axes,
                        colorbar: matplotlib.colorbar.Colorbar = None,
                        **kwargs) -> AxesTuple:
    """
    Scatter ``(x, y)`` coloured by ``z`` and attach a colorbar.

    Extra keyword arguments are forwarded unchanged to ``ax.scatter``.

    Args:
        x: The x values
        y: The y values
        z: The z values, used as the point colours
        ax: The axis to plot onto
        colorbar: An existing colorbar whose axes should be drawn into; when
            omitted, a new colorbar is placed next to ``ax``

    Returns:
        Tuple of ``ax`` and the newly created colorbar
    """
    points = ax.scatter(x=x, y=y, c=z, **kwargs)
    colorbar_kwargs = {'ax': ax}
    if colorbar is not None:
        # Reuse the existing colorbar's axes instead of taking new space.
        colorbar_kwargs['cax'] = colorbar.ax
    new_colorbar = ax.figure.colorbar(points, **colorbar_kwargs)
    return ax, new_colorbar
Пример #10
0
    def _do_scatter(self, a: mpl.axes.Axes, x: np.ndarray, y: np.ndarray):
        """Create density scatter plot

        Parameters
        ----------
        a
            Axes to draw on
        x, y
            data
        """
        dens = self._calc_kde(x, y, x, y)
        sort_idx = np.argsort(dens)
        dens = dens[sort_idx]
        x = x[sort_idx]
        y = y[sort_idx]

        a.scatter(x, y, c=dens, cmap="viridis", marker=".")
Пример #11
0
def plot_delivered_vaccines_quantity(df_delivered: pd.DataFrame,
                                     ax: mp.axes.Axes) -> ResultValue:
    """Plot the cumulative number of delivered vaccine doses by delivery date.

    Doses (``numero_dosi``) are summed per delivery date (``data_consegna``)
    and accumulated over time. The running total is drawn on ``ax`` as dots
    joined by a line, with date tick labels and a legend.

    Args:
        df_delivered: Deliveries with at least the ``data_consegna`` and
            ``numero_dosi`` columns. The frame is not modified.
        ax: Axes to draw on.

    Returns:
        On success, ``ResultOk`` wrapping the list returned by ``ax.plot``.
        If any exception is raised, ``ResultKo`` wrapping that exception.
    """
    log = logging.getLogger('plot_delivered_vaccines_quantity')
    log.info(" >>")
    rv: ResultValue = ResultKo(Exception("Error"))
    try:
        line_label = "Dosi consegnate - somma"
        line_color = "#ff5733"

        # groupby sorts its keys, so the caller's frame need not be sorted in
        # place. Selecting the dose column keeps non-numeric columns out of sum().
        by_date = (df_delivered.groupby("data_consegna")["numero_dosi"]
                   .sum()
                   .reset_index())
        by_date["cumulata"] = by_date["numero_dosi"].cumsum()

        x_del = by_date["data_consegna"]
        y_del = by_date["cumulata"]

        remove_tick_lines('x', ax)
        remove_tick_lines('y', ax)
        set_axes_common_properties(ax, no_grid=True)
        ax.xaxis.set_major_formatter(mdates.DateFormatter("%d/%m/%y"))
        ax.xaxis.set_minor_formatter(mdates.DateFormatter("%d/%m"))
        ax.xaxis.set_major_locator(mdates.DayLocator(interval=2))

        ax.scatter(x_del, y_del, s=30, marker='.')
        # Colour comes only from ``color``. A 'b-' format string would define a
        # second colour, and Matplotlib warns about the redundancy.
        line = ax.plot(x_del,
                       y_del,
                       '-',
                       linewidth=2,
                       color=line_color,
                       label=line_label)

        # Rotate the labels produced by the DateFormatter above. set_xticklabels()
        # would instead assign the raw delivery dates, in sequence, to whatever
        # ticks DayLocator chose, so labels would not match their positions.
        ax.tick_params(axis='x', labelrotation=80)

        handles, _ = ax.get_legend_handles_labels()
        # NOTE(review): ``handles`` already holds the plotted line with this label,
        # so the legend lists it twice (line and patch); confirm this is intended.
        handles.append(mpatches.Patch(color=line_color, label=line_label))
        # Draw the legend on ``ax``; plt.legend() targets the current axes, which
        # need not be ``ax``.
        ax.legend(handles=handles, loc='upper left')

        rv = ResultOk(line)

    except Exception as ex:
        log.error("Exception caught - %s", ex)
        rv = ResultKo(ex)
    log.info(" <<")
    return rv
Пример #12
0
def plot_2d_scatterplot(x: np.ndarray,
                        y: np.ndarray,
                        z: np.ndarray,
                        ax: matplotlib.axes.Axes,
                        colorbar: matplotlib.colorbar.Colorbar = None,
                        **kwargs: Any) -> AxesTuple:
    """
    Make a 2D scatterplot of the data. ``**kwargs`` are passed to matplotlib's
    scatter used for the plotting. By default the data will be rasterized
    in any vector plot if more than 5000 points are supplied. This can be
    overridden by supplying the `rasterized` kwarg.

    String-valued ``z`` is treated as categorical. The colormap (given by name
    or instance, default viridis) is resampled into one discrete colour per
    distinct string, and the colorbar is labelled with the strings.

    Args:
        x: The x values
        y: The y values
        z: The z values
        ax: The axis to plot onto
        colorbar: The colorbar to plot into

    Returns:
        The matplotlib axis handles for plot and colorbar
    """
    if 'rasterized' in kwargs:
        rasterized = kwargs.pop('rasterized')
    else:
        rasterized = len(z) > qc.config.plotting.rasterize_threshold

    # Guard the z[0] probe so empty input does not raise IndexError.
    z_is_stringy = len(z) > 0 and isinstance(z[0], str)

    cmap = kwargs.pop('cmap', None)

    if z_is_stringy:
        # NOTE(review): the tick labels below assume _strings_as_ints numbers
        # the strings 0..N-1 in np.unique's (sorted) order; confirm.
        z_strings = np.unique(z)
        z = _strings_as_ints(z)
        n_strings = len(z_strings)
        # Honour a colormap given by name as well as by instance.
        name = cmap if isinstance(cmap, str) else getattr(cmap, 'name', 'viridis')
        try:
            # Matplotlib >= 3.6; ``matplotlib.cm.get_cmap`` was removed in 3.9.
            cmap = matplotlib.colormaps[name].resampled(n_strings)
        except AttributeError:
            cmap = matplotlib.cm.get_cmap(name, n_strings)

    mappable = ax.scatter(x=x,
                          y=y,
                          c=z,
                          rasterized=rasterized,
                          cmap=cmap,
                          **kwargs)

    if colorbar is not None:
        colorbar = ax.figure.colorbar(mappable, ax=ax, cax=colorbar.ax)
    else:
        colorbar = ax.figure.colorbar(mappable, ax=ax)

    if z_is_stringy:
        # Assuming codes 0..N-1, the colour range is split into N bands of
        # width (N - 1) / N; put each tick in the middle of its band.
        f = (n_strings - 1) / n_strings
        colorbar.set_ticks([(n + 0.5) * f for n in range(n_strings)])
        colorbar.set_ticklabels(z_strings)

    return ax, colorbar
Пример #13
0
def kde2d(X: Union[np.ndarray, Series, List, Tuple],
          Y: Union[np.ndarray, Series, List, Tuple],
          c: str = "red",
          ax: mpl.axes.Axes = None,
          fill: bool = False,
          with_scatter: bool = False,
          **contour_kwargs):
    """Draw a 2D kernel-density estimate of paired data as contours.

    Missing values are removed pairwise with ``remove_na`` before the density
    grid is computed with ``density``. The grid is drawn over the data range,
    padded by 1/15 of the span on each side.

    Parameters
    ----------
    X, Y
        Paired samples of equal length.
    c
        Colour of the contours and of the scatter points.
    ax
        Axes to draw on; a new 8x5 figure is created when ``None``.
    fill
        Draw filled contours (``contourf``) instead of labelled contour lines.
    with_scatter
        Also scatter the raw points. Filled contours and points are then drawn
        at alpha 0.5.
    **contour_kwargs
        Extra keyword arguments for ``contour`` / ``contourf``. They override
        the defaults set here (``extent``, ``colors`` and, for filled contours,
        ``alpha``).

    Returns
    -------
    ax
        The axes drawn on.
    """
    instance_check((X, Y), (list, tuple, np.ndarray, Series))
    instance_check(c, str)
    instance_check((fill, with_scatter), bool)
    # ``ax`` defaults to None, so only type-check it when one was supplied.
    if ax is not None:
        instance_check(ax, mpl.axes.Axes)
    arrays_equal_size(X, Y)

    # calculate density
    _X, _Y = remove_na(np.asarray(X), np.asarray(Y), paired=True)

    H = density(_X, _Y)
    offx = np.abs(_X.max() - _X.min()) / 15.0
    offy = np.abs(_Y.max() - _Y.min()) / 15.0
    extent = (_X.min() - offx, _X.max() + offx, _Y.min() - offy, _Y.max() + offy)
    _alpha = 0.5 if with_scatter else 1.0

    if ax is None:
        _, ax = plt.subplots(figsize=(8, 5))

    # Contour functions take ``colors``; a ``color`` keyword is not a contour
    # option, so it would never be applied.
    if fill:
        fill_kwargs = {"extent": extent, "colors": c, "alpha": _alpha, **contour_kwargs}
        ax.contourf(H, **fill_kwargs)
    else:
        line_kwargs = {"extent": extent, "colors": c, **contour_kwargs}
        cset = ax.contour(H, **line_kwargs)
        ax.clabel(cset, inline=True, fontsize=10)

    if with_scatter:
        ax.scatter(_X, _Y, c=c, alpha=_alpha)

    return ax
Пример #14
0
def reachable_zone_scatter(
        ax: matplotlib.axes.Axes,
        mount_model: MountModel,
        axis_0_west_limit: float = 110,
        axis_0_east_limit: float = 110,
    ) -> None:
    """Generate a scatter plot showing the reachable zone.

    This function assumes an equatorial mount with limits on the right ascension axis.
    A 20 x 20 grid of encoder positions is sampled for each meridian side and
    converted to topocentric coordinates with ``mount_model``. The points are
    plotted as (azimuth in radians, 90 - altitude), one labelled series per side.

    Args:
        ax: Axes object this function will plot on. This should be generated by `make_sky_plot()`.
        mount_model: Mount model from which this plot will be generated.
        axis_0_west_limit: Western limit on axis 0 in degrees from the meridian.
        axis_0_east_limit: Eastern limit on axis 0 in degrees from the meridian.
    """
    # The limits are given relative to the meridian; the encoder angle range
    # is centred on 180 degrees.
    encoder_0_start = 180 - axis_0_west_limit
    encoder_0_stop = 180 + axis_0_east_limit
    # Axis-0 samples are the same for both meridian sides.
    axis_0 = np.linspace(encoder_0_start, encoder_0_stop, 20)

    for meridian_side in MeridianSide:
        if meridian_side == MeridianSide.EAST:
            axis_1 = np.linspace(0, 180, 20)
        else:
            axis_1 = np.linspace(180, 360, 20)
        grid_0, grid_1 = np.meshgrid(axis_0, axis_1)

        az_deg, alt_deg = [], []
        for enc_0, enc_1 in zip(grid_0.ravel(), grid_1.ravel()):
            topo = mount_model.encoders_to_topocentric(
                MountEncoderPositions(
                    Longitude(enc_0*u.deg),
                    Longitude(enc_1*u.deg),
                )
            )
            az_deg.append(topo.az.deg)
            alt_deg.append(topo.alt.deg)

        ax.scatter(
            np.radians(np.array(az_deg)),
            90.0 - np.array(alt_deg),
            label=meridian_side.name.title(),
        )
Пример #15
0
def hist2d_scatter(x,
                   y,
                   bg_x,
                   bg_y,
                   axis: matplotlib.axes.Axes,
                   dataset_name: str,
                   normalize=True,
                   marker_color: str = 'k',
                   bins=300,
                   colormap_name: str = 'jet',
                   color_bar=False,
                   hist2dkw=None,
                   scatterkw=None):
    """Plots 2D histogram with two parameters (e.g. BxGSM and ByGSM).

    Args:
        x, y (array_like): Values of the TPAs for the parameter that will be
            plotted on the x- or y-axis. This data will correspond to the dots
            in the plot.
        bg_x, bg_y (array_like): Values of the IMF over the period of the
            dataset. These will form the background (colored tiles) of the plot.
        axis: Axes to draw on.
        dataset_name: Legend label for the scatter points.
        normalize: Passed to ``hist2d`` as ``density``.
        marker_color: A single colour, or a numpy array with one colour per
            point (filtered with the same NaN mask as ``x``/``y``).
        bins: Number of histogram bins, passed to ``hist2d``.
        colormap_name: Name of the background colormap. Its lowest colour is
            also used as the axes background.
        color_bar: Add a colorbar for the histogram.
        hist2dkw, scatterkw: Extra keyword arguments for ``hist2d`` / ``scatter``.

    Returns:
        ``(scatter, counts, xedges, yedges, im)``.
    """
    # None instead of shared mutable ``{}`` defaults.
    hist2dkw = {} if hist2dkw is None else hist2dkw
    scatterkw = {} if scatterkw is None else scatterkw

    # ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9; prefer the registry.
    registry = getattr(matplotlib, "colormaps", None)
    if registry is not None and hasattr(registry, "get_cmap"):
        colormap = registry.get_cmap(colormap_name)
    else:
        colormap = cm.get_cmap(colormap_name)

    axis.axhline(0, color='grey', zorder=1)
    axis.axvline(0, color='grey', zorder=1)

    # asarray so boolean-mask indexing also works for plain lists.
    x, y = np.asarray(x), np.asarray(y)
    bg_x, bg_y = np.asarray(bg_x), np.asarray(bg_y)
    omit_index = np.isnan(x) | np.isnan(y)
    x = x[~omit_index]
    y = y[~omit_index]
    bg_omit_index = np.isnan(bg_x) | np.isnan(bg_y)
    bg_x = bg_x[~bg_omit_index]
    bg_y = bg_y[~bg_omit_index]
    counts, xedges, yedges, im = axis.hist2d(bg_x,
                                             bg_y,
                                             bins=bins,
                                             cmap=colormap,
                                             density=normalize,
                                             zorder=0,
                                             **hist2dkw)

    if color_bar:
        cbar = plt.colorbar(im, ax=axis)
        cbar.set_label('IMF probability distribution',
                       rotation=270,
                       labelpad=10)

    scatter = axis.scatter(x,
                           y,
                           s=30,
                           marker='P',
                           edgecolors='w',
                           linewidth=0.5,
                           label=dataset_name,
                           c=marker_color[~omit_index] if isinstance(
                               marker_color, np.ndarray) else marker_color,
                           zorder=2,
                           **scatterkw)

    # Fill empty bins' surroundings with the colormap's lowest colour.
    axis.set_facecolor(colormap(0))
    axis.legend(loc='upper left')
    return scatter, counts, xedges, yedges, im
Пример #16
0
def draw_3d_plot(ax: mpl.axes.Axes,
                 x: np.ndarray,
                 y: np.ndarray,
                 z: np.ndarray,
                 plot_type: str,
                 marker: str = 'X',
                 marker_size: int = 50,
                 marker_color: str = 'red',
                 interpolation: str = 'linear',
                 cmap: str = 'viridis') -> None:
    '''Draw a 3d plot.  See XYZData class for explanation of arguments

    The scattered ``(x, y, z)`` samples are interpolated with ``griddata`` onto a
    regular 50 x 50 grid spanning the data. Grid points where the interpolation
    yields NaN are drawn as 0. ``plot_type`` selects a ``'surface'`` or a
    ``'contour'`` plot. The original samples are overlaid unless ``marker`` is
    None, and a colorbar is added for ``cmap``.

    >>> x = np.random.rand(10)
    >>> y = np.random.rand(10)
    >>> z = x ** 2 + y ** 2
    >>> if has_display():
    ...    fig, ax = plt.subplots()
    ...    draw_3d_plot(ax, x = x, y = y, z = z, plot_type = 'contour', interpolation = 'linear')
    '''
    grid_x = np.linspace(min(x), max(x))
    grid_y = np.linspace(min(y), max(y))
    X, Y = np.meshgrid(grid_x, grid_y)
    Z = np.nan_to_num(
        griddata((x, y), z, (grid_x[None, :], grid_y[:, None]), method=interpolation))

    show_markers = marker is not None
    if plot_type == 'surface':
        ax.plot_surface(X, Y, Z, cmap=cmap)
        if show_markers:
            ax.scatter(x, y, z, marker=marker, s=marker_size, c=marker_color)
    elif plot_type == 'contour':
        contour_lines = ax.contour(X, Y, Z, linewidths=0.5, colors='k')
        # Label every other level to keep the plot readable.
        ax.clabel(contour_lines, contour_lines.levels[::2], fmt="%.3g", inline=1)
        ax.contourf(X, Y, Z, cmap=cmap)
        if show_markers:
            ax.scatter(x, y, marker=marker, s=marker_size, c=marker_color, zorder=10)
    else:
        raise Exception(f'unknown plot type: {plot_type}')

    # Stand-alone mappable so the colorbar reflects the full range of Z.
    colour_scale = cm.ScalarMappable(cmap=cmap)
    colour_scale.set_array(Z)
    plt.colorbar(colour_scale, ax=ax)
Пример #17
0
def draw_candlestick(
        ax: mpl.axes.Axes,
        index: np.ndarray,
        o: np.ndarray,
        h: np.ndarray,
        l: np.ndarray,  # noqa: E741: ignore # l ambiguous
        c: np.ndarray,
        v: Optional[np.ndarray],
        vwap: Optional[np.ndarray],
        colorup: str = 'darkgreen',
        colordown: str = '#F2583E') -> None:
    '''Draw candlesticks given parallel numpy arrays of o, h, l, c, v values.  v is optional.
        See TradeBarSeries class __init__ for argument descriptions.

    Candle bodies are 0.5 x-units wide and centred on ``index``. Candles with a
    valid (non-NaN) open and close where close < open use ``colordown``. All
    others, including candles with a NaN open or close, use ``colorup``. The
    high/low wicks are drawn with colour 'k' at each candle's centre.

    If ``v`` has at least one non-NaN value, a volume bar chart is added in a new
    axes below ``ax`` (25% of its size, sharing its x axis). Bars are coloured by
    comparing close and open after replacing NaNs with 0. If ``vwap`` is given,
    it is scattered as orange dots above the candles.
    '''
    width = 0.5

    # Have to do volume first because of a mpl bug with axes fonts if we use make_axes_locatable after plotting on top axis
    if v is not None and not np.isnan(v).all():
        divider = make_axes_locatable(ax)
        vol_ax = divider.append_axes('bottom', size='25%', sharex=ax, pad=0)
        _c = np.nan_to_num(c)
        _o = np.nan_to_num(o)
        pos = _c >= _o
        neg = _c < _o
        vol_ax.bar(index[pos], v[pos], color=colorup, width=width)
        vol_ax.bar(index[neg], v[neg], color=colordown, width=width)

    offset = width / 2.0

    # mask marks "down" candles: valid open/close with close < open. The comparison
    # is only evaluated where both values are non-NaN.
    mask = ~np.isnan(c) & ~np.isnan(o)
    mask[mask] &= c[mask] < o[mask]

    left = index - offset
    # NOTE(review): in both branches ``bottom`` ends up as the larger of open/close
    # and ``top`` the smaller; presumably harmless if draw_poly only needs the two
    # edges. Confirm against draw_poly.
    bottom = np.where(mask, o, c)
    top = np.where(mask, c, o)
    right = left + width

    # Candle bodies. The last argument is presumably the z-order (bodies 100,
    # wicks 1); confirm against draw_poly.
    draw_poly(ax, left[mask], bottom[mask], top[mask], right[mask], colordown,
              'k', 100)
    draw_poly(ax, left[~mask], bottom[~mask], top[~mask], right[~mask],
              colorup, 'k', 100)
    # Wicks: zero-width polygons at the candle centre spanning low to high.
    draw_poly(ax, left + offset, l, h, left + offset, 'k', 'k', 1)
    if vwap is not None:
        ax.scatter(index, vwap, marker='o', color='orange', zorder=110)
Пример #18
0
def plot_2d_scatterplot(x: np.ndarray, y: np.ndarray, z: np.ndarray,
                        ax: matplotlib.axes.Axes) -> AxesTuple:
    """
    Scatter ``(x, y)`` coloured by ``z`` and add a colorbar next to ``ax``.

    Args:
        x: The x values
        y: The y values
        z: The z values, used as the point colours
        ax: The axis to plot onto

    Returns:
        Tuple of ``ax`` and the newly created colorbar
    """
    points = ax.scatter(x=x, y=y, c=z)
    return ax, ax.figure.colorbar(points, ax=ax)
Пример #19
0
def plot_testcount_forecast(
    result: pandas.Series,
    m: preprocessing.fbprophet.Prophet,
    forecast: pandas.DataFrame,
    considered_holidays: preprocessing.NamedDates, *,
    ax: matplotlib.axes.Axes=None
) -> matplotlib.axes.Axes:
    """ Helper function for plotting the detailed testcount forecasting result.

    Parameters
    ----------
    result : pandas.Series
        the date-indexed series of smoothed/predicted testcounts
    m : fbprophet.Prophet
        the prophet model
    forecast : pandas.DataFrame
        contains the prophet model prediction
    considered_holidays : preprocessing.NamedDates
        the holidays that were used in the model (presumably a mapping of
        datetime to name), drawn as vertical lines via ``plot_vlines``
    ax : optional, matplotlib.axes.Axes
        an existing subplot to use; a new 13.4 x 6 figure is created when omitted

    Returns
    -------
    ax : matplotlib.axes.Axes
        the (created) subplot that was plotted into
    """
    if ax is None:
        _, ax = pyplot.subplots(figsize=(13.4, 6))
    # Only show the forecast from the first training date onwards.
    first_training_day = m.history.set_index('ds').index[0]
    m.plot(forecast[forecast.ds >= first_training_day], ax=ax)
    ax.set_ylim(bottom=0)
    ax.set_xlim(pandas.to_datetime('2020-03-01'))
    plot_vlines(ax, considered_holidays, alignment='bottom')
    # The empty scatter/line are legend proxies, presumably mirroring the styling
    # of what ``m.plot`` drew. The ``result`` line is drawn here and is its own handle.
    ax.legend(frameon=False, loc='upper left', handles=[
        ax.scatter([], [], color='black', label='training data'),
        ax.plot([], [], color='blue', label='prediction')[0],
        ax.plot(result.index, result.values, color='orange', label='result')[0],
    ])
    ax.set_ylabel('total tests')
    ax.set_xlabel('')
    return ax
Пример #20
0
def draw_3d_plot(ax: mpl.axes.Axes,
                 x: np.ndarray,
                 y: np.ndarray,
                 z: np.ndarray,
                 plot_type: str = 'contour',
                 marker: str = 'X',
                 marker_size: int = 50,
                 marker_color: str = 'red',
                 interpolation: str = 'linear',
                 cmap: matplotlib.colors.Colormap = matplotlib.cm.RdBu_r,
                 min_level: float = math.nan,
                 max_level: float = math.nan) -> None:
    '''Draw a 3d plot.  See XYZData class for explanation of arguments

    The scattered samples ``(x, y, z)`` are interpolated with ``griddata``
    onto a regular 50 x 50 grid spanning the x/y ranges, then drawn either as
    a surface (``plot_type='surface'``) or as filled contours
    (``plot_type='contour'``). A colorbar is added in both cases.

    Raises:
        ValueError: if ``plot_type`` is neither 'surface' nor 'contour'.

    >>> x = np.random.rand(10)
    >>> y = np.random.rand(10)
    >>> z = x ** 2 + y ** 2
    >>> if has_display():
    ...    fig, ax = plt.subplots()
    ...    draw_3d_plot(ax, x = x, y = y, z = z, plot_type = 'contour', interpolation = 'linear');
    '''
    xi = np.linspace(min(x), max(x))
    yi = np.linspace(min(y), max(y))
    X, Y = np.meshgrid(xi, yi)
    # Grid points griddata cannot interpolate (e.g. outside the convex hull
    # of the samples for 'linear') come back as NaN.
    Z = griddata((x, y), z, (xi[None, :], yi[:, None]), method=interpolation)

    if plot_type == 'surface':
        # plot_surface needs finite values everywhere, so fill the gaps with 0.
        Z = np.nan_to_num(Z)
        ax.plot_surface(X, Y, Z, cmap=cmap)
        if marker is not None:
            ax.scatter(x, y, z, marker=marker, s=marker_size, c=marker_color)
        m = cm.ScalarMappable(cmap=cmap)
        m.set_array(Z)
        plt.colorbar(m, ax=ax)

    elif plot_type == 'contour':
        # Rebuild the colormap from its N discrete colors. Calling
        # LinearSegmentedColormap.from_list explicitly also works when the
        # input is a ListedColormap (e.g. viridis), which has no from_list.
        cmaplist = [cmap(i) for i in range(cmap.N)]
        cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
            'Custom cmap', cmaplist, cmap.N)
        # Mask (instead of zero-filling) cells without an interpolated value
        # so they stay blank and do not pull the default levels towards 0.
        Z = np.ma.masked_array(Z, mask=~np.isfinite(Z))
        if math.isnan(min_level): min_level = np.min(Z)
        if math.isnan(max_level): max_level = np.max(Z)
        # Split [min_level, max_level) into cmap.N equal bins and force 0 to
        # be a bin boundary so the colorbar always shows where 0 lies.
        bounds = np.arange(min_level, max_level,
                           (max_level - min_level) / cmap.N)
        idx = np.searchsorted(bounds, 0)
        if idx == len(bounds) or bounds[idx] != 0:
            # skip the insert when 0 already is a boundary (no zero-width bin)
            bounds = np.insert(bounds, idx, 0)
        norm = BoundaryNorm(bounds, cmap.N)
        cs = ax.contourf(X, Y, Z, cmap=cmap, norm=norm)

        if marker is not None:
            finite = np.isfinite(z)
            # Use the same norm as the contours so marker colors match the fill.
            ax.scatter(x[finite],
                       y[finite],
                       marker=marker,
                       s=marker_size,
                       c=z[finite],
                       zorder=10,
                       cmap=cmap,
                       norm=norm)
        LABEL_SIZE = 16
        ax.tick_params(axis='both', which='major', labelsize=LABEL_SIZE)
        ax.tick_params(axis='both', which='minor', labelsize=LABEL_SIZE)
        cbar = plt.colorbar(cs, ax=ax)
        cbar.ax.tick_params(labelsize=LABEL_SIZE)

    else:
        raise ValueError(f'unknown plot type: {plot_type}')
Пример #21
0
    def plot(
        self,
        x_label: str = "Mean of methods",
        y_label: str = "Difference between methods",
        graph_title: str = None,
        reference: bool = False,
        xlim: Tuple = None,
        ylim: Tuple = None,
        color_mean: str = "#008bff",
        color_loa: str = "#FF7000",
        color_points: str = "#000000",
        point_kws: Dict = None,
        ci_alpha: float = 0.2,
        loa_linestyle: str = "--",
        ax: matplotlib.axes.Axes = None,
    ):
        """Provide a method comparison using Bland-Altman plotting.

        This is an Axis-level function which will draw the Bland-Altman plot
        onto the current active Axis object unless ``ax`` is provided.

        Parameters
        ----------
        x_label : str, optional
            The label which is added to the X-axis. If None is provided, the
            X-axis is left without a label.
        y_label : str, optional
            The label which is added to the Y-axis. If None is provided, the
            Y-axis is left without a label.
        graph_title : str, optional
            Title of the Bland-Altman plot.
            If None is provided, no title will be plotted.
        reference : bool, optional
            If True, a grey reference line at y=0 will be plotted in the Bland-Altman.
        xlim : tuple, optional
            Minimum and maximum limits for X-axis. Should be provided as list or tuple.
            If not set, matplotlib will decide its own bounds.
        ylim : tuple, optional
            Minimum and maximum limits for Y-axis. Should be provided as list or tuple.
            If not set, matplotlib will decide its own bounds.
        color_mean : str, optional
            Color of the mean difference line (and its confidence band).
        color_loa : str, optional
            Color of the limit of agreement lines (and their confidence bands).
        color_points : str, optional
            Color of the individual differences that will be plotted.
            NOTE(review): this argument is currently not applied; the point
            style comes from ``DEFAULT_POINTS_KWS`` updated with ``point_kws``.
        point_kws : dict of key, value mappings, optional
            Additional keyword arguments for `plt.scatter`.
        ci_alpha: float, optional
            Alpha value of the confidence interval bands, which are only drawn
            when ``self.CI`` is not None.
        loa_linestyle: str, optional
            Linestyle of the mean difference and limit of agreement lines.
        ax : matplotlib Axes, optional
            Axes in which to draw the plot, otherwise use the currently-active
            Axes.

        Returns
        -------
        ax : matplotlib Axes
            Axes object with the Bland-Altman plot.
        """

        ax = ax or plt.gca()

        pkws = self.DEFAULT_POINTS_KWS.copy()
        pkws.update(point_kws or {})

        # Get parameters
        mean, mean_CI = self.result["mean"], self.result["mean_CI"]
        loa_upper, loa_upper_CI = self.result["loa_upper"], self.result[
            "loa_upper_CI"]
        loa_lower, loa_lower_CI = self.result["loa_lower"], self.result[
            "loa_lower_CI"]
        sd_diff = self.result["sd_diff"]

        # individual points
        ax.scatter(self.mean, self.diff, **pkws)

        # mean difference and SD lines
        ax.axhline(mean, color=color_mean, linestyle=loa_linestyle)
        ax.axhline(loa_upper, color=color_loa, linestyle=loa_linestyle)
        ax.axhline(loa_lower, color=color_loa, linestyle=loa_linestyle)

        if reference:
            ax.axhline(0, color="grey", linestyle="-", alpha=0.4)

        # confidence intervals (if requested)
        if self.CI is not None:
            ax.axhspan(*mean_CI, color=color_mean, alpha=ci_alpha)
            ax.axhspan(*loa_upper_CI, color=color_loa, alpha=ci_alpha)
            ax.axhspan(*loa_lower_CI, color=color_loa, alpha=ci_alpha)

        # Text annotations: x in axes coordinates (right edge), y in data
        # coordinates (next to the corresponding horizontal line).
        trans: transforms.Transform = transforms.blended_transform_factory(
            ax.transAxes, ax.transData)
        # Vertical gap between a line and its label, in data units:
        # 1.2 percent of the distance between the two limits of agreement.
        offset: float = (((self.loa * sd_diff) * 2) / 100) * 1.2
        # (y position, text, vertical alignment) for each annotation.
        annotations = (
            (mean + offset, "Mean", "bottom"),
            (mean - offset, f"{mean:.2f}", "top"),
            (loa_upper + offset, f"+{self.loa:.2f} SD", "bottom"),
            (loa_upper - offset, f"{loa_upper:.2f}", "top"),
            (loa_lower - offset, f"-{self.loa:.2f} SD", "top"),
            (loa_lower + offset, f"{loa_lower:.2f}", "bottom"),
        )
        for y_pos, text, valign in annotations:
            ax.text(
                0.98,
                y_pos,
                text,
                ha="right",
                va=valign,
                transform=trans,
            )

        # transform graphs
        ax.spines["right"].set_visible(False)
        ax.spines["top"].set_visible(False)

        # set X and Y limits
        if xlim is not None:
            ax.set_xlim(xlim[0], xlim[1])
        if ylim is not None:
            ax.set_ylim(ylim[0], ylim[1])

        # graph labels
        ax.set(xlabel=x_label, ylabel=y_label, title=graph_title)

        return ax
Пример #22
0
    def plot(
        self,
        figsize: Tuple[float, float] = (8, 5),
        same_plot: bool = False,
        hide_cells: bool = False,
        perc: Tuple[float, float] = None,
        abs_prob_cmap: mcolors.ListedColormap = cm.viridis,
        cell_color: str = "black",
        lineage_color: str = "black",
        alpha: float = 0.8,
        lineage_alpha: float = 0.2,
        title: Optional[str] = None,
        size: int = 15,
        lw: float = 2,
        cbar: bool = True,
        margins: float = 0.015,
        xlabel: str = "pseudotime",
        ylabel: str = "expression",
        conf_int: bool = True,
        lineage_probability: bool = False,
        lineage_probability_conf_int: Union[bool, float] = False,
        lineage_probability_color: Optional[str] = None,
        dpi: int = None,
        fig: mpl.figure.Figure = None,
        ax: mpl.axes.Axes = None,
        return_fig: bool = False,
        save: Optional[str] = None,
        **kwargs,
    ) -> Optional[mpl.figure.Figure]:
        """
        Plot the smoothed gene expression.

        Parameters
        ----------
        figsize
            Size of the figure.
        same_plot
            Whether to plot all trends in the same plot.
        hide_cells
            Whether to hide the cells.
        perc
            Percentile by which to clip the absorption probabilities.
        abs_prob_cmap
            Colormap to use when coloring in the absorption probabilities.
        cell_color
            Color for the cells when not coloring absorption probabilities.
        lineage_color
            Color for the lineage.
        alpha
            Alpha channel for cells.
        lineage_alpha
            Alpha channel for lineage confidence intervals.
        title
            Title of the plot.
        size
            Size of the points.
        lw
            Line width for the smoothed values.
        cbar
            Whether to show colorbar.
        margins
            Margins around the plot.
        xlabel
            Label on the x-axis.
        ylabel
            Label on the y-axis.
        conf_int
            Whether to show the confidence interval.
        lineage_probability
            Whether to show smoothed lineage probability as a dashed line.
            Note that this will require 1 additional model fit.
        lineage_probability_conf_int
            Whether to compute and show smoothed lineage probability confidence interval.
            If :paramref:`self` is :class:`cellrank.ul.models.GAMR`, it can also specify the confidence level,
            the default is `0.95`. Only used when ``lineage_probability=True``.
        lineage_probability_color
            Color to use when plotting the smoothed ``lineage_probability``.
            If `None`, it's the same as ``lineage_color``. Only used when ``lineage_probability=True``.
        dpi
            Dots per inch.
        fig
            Figure to use together with ``ax``. If `None` while ``ax`` is given, the figure of ``ax``
            is used. Ignored when ``ax`` is `None`, in which case a new figure is created.
        ax: :class:`matplotlib.axes.Axes`
            Ax to use, if `None`, create a new one.
        return_fig
            If `True`, return the figure object.
        save
            Filename where to save the plot. If `None`, just shows the plots.
        **kwargs
            Keyword arguments for :meth:`matplotlib.axes.Axes.legend`, e.g. to disable the legend, specify ``loc=None``.
            Only available when ``lineage_probability=True``. The special key ``scaler`` is removed first and,
            if present, replaces the default scaler applied to the expression values.

        Returns
        -------
        %(just_plots)s
        """

        if self.y_test is None:
            raise RuntimeError("Run `.predict()` first.")

        if ax is None:
            fig, ax = plt.subplots(figsize=figsize, constrained_layout=True)
        elif fig is None:
            # Draw into the caller's axes instead of silently replacing them.
            fig = ax.figure

        if dpi is not None:
            fig.set_dpi(dpi)

        # The interval can only be drawn if one was computed.
        conf_int = conf_int and self.conf_int is not None
        # Cells need x, weights and y to be drawn.
        hide_cells = (hide_cells or self.x_all is None or self.w_all is None
                      or self.y_all is None)

        lineage_probability_color = (lineage_color
                                     if lineage_probability_color is None else
                                     lineage_probability_color)

        # Must be popped before ``kwargs`` is forwarded to ``ax.legend``.
        scaler = kwargs.pop(
            "scaler",
            self._create_scaler(
                lineage_probability,
                show_conf_int=conf_int,
            ),
        )

        if lineage_probability:
            if ylabel in ("expression", self._gene):
                ylabel = f"scaled {ylabel}"

        vmin, vmax = None, None
        if not hide_cells:
            vmin, vmax = _minmax(self.w_all, perc)
            _ = ax.scatter(
                self.x_all.squeeze(),
                scaler(self.y_all.squeeze()),
                c=cell_color if same_plot or np.allclose(self.w_all, 1.0) else
                self.w_all.squeeze(),
                s=size,
                cmap=abs_prob_cmap,
                vmin=vmin,
                vmax=vmax,
                alpha=alpha,
            )

        if title is None:
            title = (f"{self._gene} @ {self._lineage}"
                     if self._lineage is not None else f"{self._gene}")

        ax.plot(self.x_test,
                scaler(self.y_test),
                color=lineage_color,
                lw=lw,
                label=title)

        # ``title`` is never None at this point.
        ax.set_title(title)
        if ylabel is not None:
            ax.set_ylabel(ylabel)
        if xlabel is not None:
            ax.set_xlabel(xlabel)

        ax.margins(margins)

        if conf_int:
            ax.fill_between(
                self.x_test.squeeze(),
                scaler(self.conf_int[:, 0]),
                scaler(self.conf_int[:, 1]),
                alpha=lineage_alpha,
                color=lineage_color,
                linestyle="--",
            )

        if (lineage_probability and not isinstance(self, FittedModel)
                and not np.allclose(self.w, 1.0)):
            from cellrank.pl._utils import _is_any_gam_mgcv

            # Refit a copy of this model on the lineage weights instead of
            # the expression values.
            model = deepcopy(self)
            model._y = self._reshape_and_retype(self.w).copy()
            model = model.fit()

            if not lineage_probability_conf_int:
                y = model.predict()
            elif _is_any_gam_mgcv(model):
                # NOTE(review): assumes predict(level=...) populates
                # ``model.conf_int`` for mgcv GAMs — confirm.
                y = model.predict(
                    level=lineage_probability_conf_int if isinstance(
                        lineage_probability_conf_int, float) else 0.95)
            else:
                y = model.predict()
                model.confidence_interval()

            # Draw the interval for every model type, not only the non-mgcv ones.
            if lineage_probability_conf_int:
                ax.fill_between(
                    model.x_test.squeeze(),
                    model.conf_int[:, 0],
                    model.conf_int[:, 1],
                    alpha=lineage_alpha,
                    color=lineage_probability_color,
                    linestyle="--",
                )

            handle = ax.plot(
                model.x_test,
                y,
                color=lineage_probability_color,
                lw=lw,
                linestyle="--",
                zorder=-1,
                label="probability",
            )

            if kwargs.get("loc", "best") is not None:
                ax.legend(handles=handle, **kwargs)

        if (cbar and not hide_cells and not same_plot
                and not np.allclose(self.w_all, 1.0)):
            norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", size="2%", pad=0.1)
            _ = mpl.colorbar.ColorbarBase(
                cax,
                norm=norm,
                cmap=abs_prob_cmap,
                ticks=np.linspace(norm.vmin, norm.vmax, 5),
            )

        if save is not None:
            save_fig(fig, save)

        if return_fig:
            return fig
Пример #23
0
def plot_2d_gauss_fit(axes: mpl.axes.Axes, popt_2d, x_centers: np.ndarray,
                      y_centers: np.ndarray):
    """Overlay a fitted 2D Gaussian on a hitmap.

    Draws contour lines of ``ft.gauss_2d(*popt_2d)`` evaluated on the grid
    spanned by ``x_centers`` and ``y_centers``, then marks the point
    ``(popt_2d[0], popt_2d[1])`` (presumably the fitted centre) in dark red.
    """
    grid_x, grid_y = np.meshgrid(x_centers, y_centers)
    gaussian = ft.gauss_2d(*popt_2d)
    surface = gaussian(grid_x, grid_y)
    axes.contour(grid_x, grid_y, surface, cmap='coolwarm')
    centre_x, centre_y = popt_2d[0], popt_2d[1]
    axes.scatter(centre_x, centre_y, color='darkred')
Пример #24
0
def fate(
    adata: AnnData,
    x: int = 0,
    y: int = 1,
    basis: str = "pca",
    color: str = "ntr",
    ax: matplotlib.axes.Axes = None,
    save_show_or_return: str = "show",
    save_kwargs: dict = None,
    **kwargs: dict
):
    """Draw the predicted integration paths on the low-dimensional embedding.

    Parameters
    ----------
        adata: :class:`~anndata.AnnData`
            an Annodata object; ``adata.uns["fate_<basis>"]`` (or ``"fate"`` when ``basis`` is None)
            must hold the ``"prediction"`` and ``"t"`` entries to draw.
        x: `int` (default: `0`)
            The column index of the low dimensional embedding for the x-axis.
        y: `int` (default: `1`)
            The column index of the low dimensional embedding for the y-axis.
        basis: `str`
            The reduced dimension.
        color: `string` (default: `ntr`)
            Any column names or gene expression, etc. that will be used for coloring cells.
        ax: `matplotlib.Axis` (optional, default `None`)
            The matplotlib axes object where new plots will be added to. Only applicable to drawing a single component.
        save_show_or_return: `str` {'save', 'show', 'return', 'both', 'all'} (default: `show`)
            Whether to save, show or return the figure. If "both", it will save and plot the figure at the same time. If
            "all", the figure will be saved, displayed and the associated axis will be returned.
        save_kwargs: `dict` (default: `None`)
            A dictionary that will be passed to the save_fig function. By default (None or empty) the save_fig
            function will use {"path": None, "prefix": 'kinetic_curves', "dpi": None, "ext": 'pdf',
            "transparent": True, "close": True, "verbose": True} as its parameters ("close" is only True for
            "save"). Otherwise you can provide a dictionary that modifies those keys according to your needs.
        kwargs:
            Additional arguments passed to pl.scatters or plt.scatters.

    Returns
    -------
        result:
            The matplotlib axis with the plot when ``save_show_or_return`` is "return" or "all", otherwise None.
    """

    import matplotlib.pyplot as plt

    ax = scatters(adata, basis=basis, color=color, save_show_or_return="return", ax=ax, **kwargs)

    fate_key = "fate" if basis is None else "fate_" + basis
    lap_dict = adata.uns[fate_key]

    # Each prediction is drawn as points coloured by its time values plus a
    # black connecting line, using only the selected (x, y) columns.
    for prediction, times in zip(lap_dict["prediction"], lap_dict["t"]):
        ax.scatter(*prediction[:, [x, y]].T, c=map2color(times))
        ax.plot(*prediction[:, [x, y]].T, c="k")

    # Independent checks so "both"/"all" can save AND show (AND return).
    if save_show_or_return in ["save", "both", "all"]:
        s_kwargs = {
            "path": None,
            # NOTE(review): prefix looks copy-pasted from kinetic-curve plots.
            "prefix": "kinetic_curves",
            "dpi": None,
            "ext": "pdf",
            "transparent": True,
            # Keep the figure open when it is going to be shown afterwards.
            "close": save_show_or_return == "save",
            "verbose": True,
        }
        s_kwargs = update_dict(s_kwargs, save_kwargs or {})

        save_fig(**s_kwargs)
    if save_show_or_return in ["show", "both", "all"]:
        plt.tight_layout()
        plt.show()
    if save_show_or_return in ["return", "all"]:
        return ax
Пример #25
0
def plot_vaccinations_by_time(df: pd.DataFrame,
                              df_delivered: pd.DataFrame,
                              ax: mp.axes.Axes,
                              wich: str = "first") -> ResultValue:
    """Plot first-dose vaccinations over time.

    Draws the cumulative number of first doses per administration date on
    ``ax`` and the daily number on a twin y-axis. It overlays the delivered
    quantities via ``plot_delivered_vaccines_quantity`` and adds one legend
    for all three lines.

    Args:
        df: administrations; needs the ``data_somministrazione`` (date) and
            ``prima_dose`` (first doses) columns.
        df_delivered: deliveries, forwarded to
            ``plot_delivered_vaccines_quantity``.
        ax: axis to draw on; a twin axis sharing its x-axis is added.
        wich: currently unused; only first doses are plotted.

    Returns:
        ``ResultOk(True)`` on success. Returns the failing result of
        ``plot_delivered_vaccines_quantity`` if it reports an error, or
        ``ResultKo`` wrapping any exception raised while plotting.
    """
    log = logging.getLogger('plot_vaccinations_by_time')
    log.info(" >>")
    try:
        ln_one_color = "#f08814"
        ln_two_color = "#92b7e9"
        ln_one_label = "Cumulata numero vaccinazioni"
        ln_two_label = "Distribuzione giornaliera"

        # Aggregate only the column that is plotted instead of every column.
        daily_first_doses = df.groupby("data_somministrazione")["prima_dose"].sum()
        x = daily_first_doses.index.values
        y = daily_first_doses
        y_cum_sum = daily_first_doses.cumsum()

        set_axes_common_properties(ax, no_grid=False)
        ax.get_yaxis().set_major_formatter(
            mp.ticker.FuncFormatter(lambda x, p: format(int(x), ',')))

        remove_tick_lines('x', ax)
        remove_tick_lines('y', ax)

        ax.set_xticks(x)
        ax.set_xticklabels(x, rotation=80)
        ax.xaxis.set_major_formatter(mdates.DateFormatter("%d/%m/%y"))
        ax.xaxis.set_minor_formatter(mdates.DateFormatter("%d/%m"))
        ax.xaxis.set_major_locator(mdates.DayLocator(interval=2))
        ax.set_ylabel(ln_one_label, fontsize=14)
        ax.set_xlabel("Data", fontsize=14)
        ax.set_title("Vaccinazioni nel tempo - prima dose", fontsize=18)
        ax.tick_params(axis='y', colors=ln_one_color)
        ax.yaxis.label.set_color(ln_one_color)

        # Cumulative first doses on the primary axis.
        ax.scatter(x, y_cum_sum, color=ln_one_color, s=30, marker='.')
        # fmt only sets the line style; the colour comes from ``color`` (a
        # colour in fmt as well would make matplotlib warn about redundancy).
        ln_one = ax.plot(x,
                         y_cum_sum,
                         '-',
                         linewidth=2,
                         color=ln_one_color,
                         label=ln_one_label)

        result = plot_delivered_vaccines_quantity(df_delivered, ax)
        if result.is_in_error():
            log.error(result())
            return result
        line_three = result()

        # Daily first doses on a twin axis with its own y scale.
        ax_dec = ax.twinx()

        remove_tick_lines('y', ax_dec)
        remove_tick_lines('x', ax_dec)

        set_axes_common_properties(ax_dec, no_grid=True)

        ax_dec.scatter(x, y, color=ln_two_color, s=30, marker='.')
        ln_two = ax_dec.plot(x,
                             y,
                             '-',
                             linewidth=2,
                             color=ln_two_color,
                             label=ln_two_label)

        ax_dec.set_ylabel(ln_two_label, fontsize=14)
        ax_dec.yaxis.label.set_color(ln_two_color)
        ax_dec.tick_params(axis='y', colors=ln_two_color)

        # One legend for lines living on two different axes.
        lns = ln_one + ln_two + line_three
        labs = [l.get_label() for l in lns]
        ax.legend(lns, labs, loc='upper left')

    except Exception as ex:
        # Result-style boundary: report the failure (with traceback) and
        # return it instead of propagating.
        log.exception("Exception caught - {ex}".format(ex=ex))
        return ResultKo(ex)
    log.info(" <<")
    return ResultOk(True)
Пример #26
0
 def plot_misses_against_hits(ax: mpl.axes.Axes, x: Sequence[int], y: Sequence[int], **kwargs) -> mpl.collections.PathCollection:
     """Scatter cache hits (``y``) against cache misses (``x``) on ``ax``.

     Extra keyword arguments are forwarded to ``Axes.scatter``; marker edges
     are always disabled. Returns the collection created by the scatter.
     """
     for set_label, text in ((ax.set_xlabel, "Cache misses"),
                             (ax.set_ylabel, "Cache hits")):
         set_label(text)
     collection = ax.scatter(x, y, edgecolors="none", **kwargs)
     return collection
Пример #27
0
    def plot(
        self,
        figsize: Tuple[float, float] = (15, 10),
        same_plot: bool = False,
        hide_cells: bool = False,
        perc: Tuple[float, float] = None,
        abs_prob_cmap: mcolors.ListedColormap = cm.viridis,
        cell_color: str = "black",
        lineage_color: str = "black",
        alpha: float = 0.8,
        lineage_alpha: float = 0.2,
        title: Optional[str] = None,
        size: int = 15,
        lw: float = 2,
        show_cbar: bool = True,
        margins: float = 0.015,
        xlabel: str = "pseudotime",
        ylabel: str = "expression",
        show_conf_int: bool = True,
        dpi: int = None,
        fig: mpl.figure.Figure = None,
        ax: mpl.axes.Axes = None,
        return_fig: bool = False,
        save: Optional[str] = None,
    ) -> Optional[mpl.figure.Figure]:
        """
        Plot the smoothed gene expression.

        Parameters
        ----------
        figsize
            Size of the figure.
        same_plot
            Whether to plot all trends in the same plot.
        hide_cells
            Whether to hide the cells.
        perc
            Percentile by which to clip the absorption probabilities.
        abs_prob_cmap
            Colormap to use when coloring in the absorption probabilities.
        cell_color
            Color for the cells when not coloring absorption probabilities.
        lineage_color
            Color for the lineage.
        alpha
            Alpha channel for cells.
        lineage_alpha
            Alpha channel for lineage confidence intervals.
        title
            Title of the plot.
        size
            Size of the points.
        lw
            Line width for the smoothed values.
        show_cbar
            Whether to show colorbar.
        margins
            Margins around the plot.
        xlabel
            Label on the x-axis.
        ylabel
            Label on the y-axis.
        show_conf_int
            Whether to show the confidence interval.
        dpi
            Dots per inch.
        fig
            Figure to use together with ``ax``. If `None` while ``ax`` is given, the figure of ``ax``
            is used. Ignored when ``ax`` is `None`, in which case a new figure is created.
        ax: :class:`matplotlib.axes.Axes`
            Ax to use, if `None`, create a new one.
        return_fig
            If `True`, return the figure object.
        save
            Filename where to save the plot. If `None`, just shows the plots.

        Returns
        -------
        %(just_plots)s
        """

        if ax is None:
            fig, ax = plt.subplots(figsize=figsize, constrained_layout=True)
        elif fig is None:
            # Draw into the caller's axes instead of silently replacing them.
            fig = ax.figure

        if dpi is not None:
            fig.set_dpi(dpi)

        vmin, vmax = None, None
        if not hide_cells:
            # Clip the colour range on the same weights the cells are coloured by.
            vmin, vmax = _minmax(self.w_all, perc)
            _ = ax.scatter(
                self.x_all.squeeze(),
                self.y_all.squeeze(),
                c=cell_color if same_plot or np.allclose(self.w_all, 1.0) else
                self.w_all.squeeze(),
                s=size,
                cmap=abs_prob_cmap,
                vmin=vmin,
                vmax=vmax,
                alpha=alpha,
            )

        if title is None:
            title = f"{self._gene} @ {self._lineage}"

        _ = ax.plot(self.x_test,
                    self.y_test,
                    color=lineage_color,
                    lw=lw,
                    label=title)

        ax.set_title(title)
        ax.set_ylabel(ylabel)
        ax.set_xlabel(xlabel)

        ax.margins(margins)

        if show_conf_int and self.conf_int is not None:
            ax.fill_between(
                self.x_test.squeeze(),
                self.conf_int[:, 0],
                self.conf_int[:, 1],
                alpha=lineage_alpha,
                color=lineage_color,
                linestyle="--",
            )

        # A colorbar only makes sense when cells are coloured by varying weights.
        if (show_cbar and not hide_cells and not same_plot
                and not np.allclose(self.w_all, 1)):
            norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", size="2.5%", pad=0.1)
            _ = mpl.colorbar.ColorbarBase(cax,
                                          norm=norm,
                                          cmap=abs_prob_cmap,
                                          label="absorption probability")

        if save is not None:
            save_fig(fig, save)

        if return_fig:
            return fig
Пример #28
0
    def plot(
        self,
        x_label: str = "Method 1",
        y_label: str = "Method 2",
        title: str = None,
        line_reference: bool = True,
        line_CI: bool = True,
        legend: bool = True,
        square: bool = False,
        ax: matplotlib.axes.Axes = None,
        point_kws: Optional[Dict] = None,
        color_regr: Optional[str] = None,
        alpha_regr: Optional[float] = None,
    ) -> matplotlib.axes.Axes:
        """Plot regression result

        Parameters
        ----------
        x_label : str, optional
            The label which is added to the X-axis. (default: "Method 1")
        y_label : str, optional
            The label which is added to the Y-axis. (default: "Method 2")
        title : str, optional
            Title of the regression plot. If None is provided, no title will be plotted.
        line_reference : bool, optional
            If True, a grey dashed reference line is drawn from the lower-left
            to the upper-right corner of the axes; it coincides with y=x only
            when both axes have the same limits.
            (default: True)
        line_CI : bool, optional
            If True, dashed lines will be plotted at the boundaries of the confidence
            intervals.
            (default: True)
        legend : bool, optional
            If True, will provide a legend containing the computed regression equation.
            (default: True)
        square : bool, optional
            If True, set the Axes aspect to "equal" so each cell will be
            square-shaped. (default: False)
        ax : matplotlib.axes.Axes, optional
            matplotlib axis object, if not passed, uses gca()
        point_kws : Optional[Dict], optional
            Additional keywords to plt
        color_regr : Optional[str], optional
            color for regression line and CI area; if None, the line uses the
            default color cycle and the CI area uses DEFAULT_REGRESSION_KWS
        alpha_regr : Optional[float], optional
            alpha for regression CI area; if None, DEFAULT_REGRESSION_KWS is used

        Returns
        ------------------
        matplotlib.axes.Axes
            axes object with the plot
        """
        ax = ax or plt.gca()

        # Set scatter plot keywords to defaults and apply override
        pkws = self.DEFAULT_POINT_KWS.copy()
        pkws.update(point_kws or {})

        # Get regression parameters (index 0 is the estimate; indices 1 and 2,
        # when present, are the confidence bounds)
        slope = self.result["slope"]
        intercept = self.result["intercept"]

        # plot individual points
        ax.scatter(self.method1, self.method2, **pkws)

        # plot reference line (axes diagonal, drawn in axes coordinates)
        # NOTE(review): this is y=x only when xlim == ylim — confirm intent.
        if line_reference:
            ax.plot(
                [0, 1],
                [0, 1],
                label="Reference",
                color="grey",
                linestyle="--",
                transform=ax.transAxes,
            )

        # Evaluate every regression variant at the current x-limits
        xvals = np.array(ax.get_xlim())
        yvals = xvals[:, None] * slope + intercept

        # Plot regression line 0
        ax.plot(
            xvals,
            yvals[:, 0],
            label=
            f"{y_label} = {intercept[0]:.2f} + {slope[0]:.2f} * {x_label}",
            color=color_regr,
            linestyle="-",
        )

        # Plot confidence region
        if yvals.shape[1] > 2:
            ax.fill_between(
                xvals,
                yvals[:, 1],
                yvals[:, 2],
                color=color_regr or self.DEFAULT_REGRESSION_KWS["color"],
                # ``is None`` so that an explicit alpha of 0 is honoured.
                alpha=(self.DEFAULT_REGRESSION_KWS["alpha"]
                       if alpha_regr is None else alpha_regr),
            )
            if line_CI:
                ax.plot(xvals, yvals[:, 1], linestyle="--")
                ax.plot(xvals, yvals[:, 2], linestyle="--")

        # Set axes labels
        ax.set(
            xlabel=x_label or "",
            ylabel=y_label or "",
            title=title or "",
        )

        if legend:
            ax.legend(loc="upper left", frameon=False)

        if square:
            ax.set_aspect("equal")
        return ax
Пример #29
0
def visualize_one_dataset(dataset: data.Dataset, ax: matplotlib.axes.Axes):
    """Scatter every ``(coordinate, label)`` sample of ``dataset`` onto ``ax``.

    Each coordinate is unpacked into an (x, y) pair; label 0 is drawn with
    colour "#bada55" and marker "+", label 1 with "#55bada" and ".". Any
    other label raises ``KeyError``. One ``scatter`` call is made per sample.
    """
    colours = {0: "#bada55", 1: "#55bada"}
    glyphs = {0: "+", 1: "."}
    for point, target in dataset:
        px, py = point
        key = target.item()
        ax.scatter(px, py, c=colours[key], marker=glyphs[key])