Exemple #1
0
def plot_histogram(axes: mpl.axes.Axes, hist, popt, bin_edges: np.ndarray,
                   axis: str):
    """Plot one axis projection of a hit histogram with a Gaussian overlay.

    The histogram is drawn normalised (``density=True``) and the curve
    ``norm.pdf(bin_centers, *popt)`` is drawn on top of it. For ``axis='y'``
    the histogram is drawn horizontally and the curve is plotted with its
    coordinates swapped, so the projection can be placed beside a 2-D plot.

    Parameters
    ----------
    axes : mpl.axes.Axes
        The axes to draw into.
    hist : np.ndarray
        The histogrammed hits, one weight per bin.
    popt : np.ndarray
        Parameters forwarded to ``norm.pdf`` for the overlaid Gaussian.
    bin_edges : np.ndarray
        The edges of the bins used to build ``hist``.
    axis : str
        Either ``'x'`` (horizontal plot) or ``'y'`` (vertical plot).

    Raises
    ------
    ValueError
        If ``axis`` is neither ``'x'`` nor ``'y'``.
    """
    centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    left_edges = bin_edges[:-1]

    if axis == 'y':
        # vertical projection: histogram bars grow along x, curve is (pdf, y)
        axes.set_ylim([bin_edges[0], bin_edges[-1]])
        axes.hist(left_edges, bin_edges, weights=hist, density=True,
                  orientation='horizontal')
        axes.plot(norm.pdf(centers, *popt), centers)
        return

    if axis != 'x':
        raise ValueError("axis has to be either 'x' or 'y'")

    axes.set_xlim([bin_edges[0], bin_edges[-1]])
    axes.hist(left_edges, bin_edges, weights=hist, density=True)
    axes.plot(centers, norm.pdf(centers, *popt))
Exemple #2
0
def volcano_plot(df: pd.DataFrame, ax: matplotlib.axes.Axes) -> matplotlib.axes.Axes:
    '''Generate a volcano plot of differential expression results.

    Each row is drawn as a point at (``log2_fc``, ``nlogq``), coloured by the
    ``significant`` column. If that column is missing it is derived as
    ``q_val < 0.05`` on a copy of ``df``, so the caller's DataFrame is left
    unchanged. The x-axis is fixed to [-6, 6] and the legend is removed.

    Parameters
    ----------
    df : pd.DataFrame
        differential expression output from `diffex_multifactor`. Must
        contain ``log2_fc`` and ``nlogq`` plus either ``significant`` or
        ``q_val``.
    ax : matplotlib.axes.Axes
        Axes to draw the scatter plot into.

    Returns
    -------
    matplotlib.axes.Axes
        The same ``ax`` that was passed in.
    '''
    if 'significant' not in df.columns:
        print('Adding significance cutoff at alpha=0.05')
        # assign() returns a new frame; writing df['significant'] would
        # silently add a column to the caller's DataFrame.
        df = df.assign(significant=df['q_val'] < 0.05)

    n_colors = len(np.unique(df['significant']))
    sns.scatterplot(data=df, x='log2_fc', y='nlogq', hue='significant',
                    linewidth=0.,
                    alpha=0.3,
                    ax=ax,
                    palette=sns.hls_palette(n_colors)[::-1])
    ax.set_xlim((-6, 6))
    ax.set_ylabel(r'$-\log_{10}$ q-value')
    ax.set_xlabel(r'$\log_2$ (Old / Young)')
    # get_legend() returns None when no legend was created
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()
    return ax
def zoom_x_and_save(fig: matplotlib.figure.Figure, ax: matplotlib.axes.Axes,
                    figbase: str, plot_ext: str,
                    xzoom: List[Tuple[float, float]]) -> None:
    """
    Zoom in on subregions of the x-axis and save the figure.

    One file is written per zoom window, named
    ``figbase + ".sub<k>" + plot_ext`` with ``k`` counting from 1. The
    original x-limits are restored afterwards, even if saving fails.

    Arguments
    ---------
    fig : matplotlib.figure.Figure
        Figure to be processed.
    ax : matplotlib.axes.Axes
        Axes to be processed.
    figbase : str
        Base name of the figure to be saved.
    plot_ext : str
        File extension of the figure to be saved. It is appended verbatim, so
        presumably includes the leading dot.
    xzoom : List[Tuple[float, float]]
        (xmin, xmax) pairs giving the x-ranges to zoom into.
    """
    xmin, xmax = ax.get_xlim()
    try:
        for ix, window in enumerate(xzoom, start=1):
            ax.set_xlim(xmin=window[0], xmax=window[1])
            savefig(fig, f"{figbase}.sub{ix}{plot_ext}")
    finally:
        # restore the full view so later plotting on this axes is unaffected
        ax.set_xlim(xmin=xmin, xmax=xmax)
def zoom_xy_and_save(fig: matplotlib.figure.Figure,
                     ax: matplotlib.axes.Axes,
                     figbase: str,
                     plot_ext: str,
                     xyzoom: List[Tuple[float, float, float, float]],
                     scale: float = 1000) -> None:
    """
    Zoom in on subregions in x,y-space and save the figure.

    All zoom windows share one size so the saved figures have a common
    scale. The window keeps the aspect ratio of the current axes limits and
    is the smallest such window that contains every requested region when
    centred on it. One file is written per region, named
    ``figbase + ".sub<k>" + plot_ext`` with ``k`` counting from 1. The
    original limits are restored afterwards, even if saving fails.

    Arguments
    ---------
    fig : matplotlib.figure.Figure
        Figure to be processed.
    ax : matplotlib.axes.Axes
        Axes to be processed.
    figbase : str
        Base name of the figure to be saved.
    plot_ext : str
        File extension of the figure to be saved. It is appended verbatim, so
        presumably includes the leading dot.
    xyzoom : List[Tuple[float, float, float, float]]
        (xmin, xmax, ymin, ymax) regions to zoom into, in unscaled units.
        They are divided by ``scale`` before being applied to the axes.
    scale: float
        Indicates whether the axes are in m (1) or km (1000).
    """
    xmin, xmax = ax.get_xlim()
    ymin, ymax = ax.get_ylim()
    xy_ratio = (ymax - ymin) / (xmax - xmin)

    # Common window width: for each region take the wider of its own width
    # and its height converted to a width through the axes aspect ratio.
    dx_zoom = 0
    for window in xyzoom:
        dx = window[1] - window[0]
        dy = window[3] - window[2]
        if dy < xy_ratio * dx:
            # x range limiting
            dx_zoom = max(dx_zoom, dx)
        else:
            # y range limiting
            dx_zoom = max(dx_zoom, dy / xy_ratio)
    dy_zoom = dx_zoom * xy_ratio

    try:
        for ix, window in enumerate(xyzoom, start=1):
            x0 = (window[0] + window[1]) / 2
            y0 = (window[2] + window[3]) / 2
            ax.set_xlim(xmin=(x0 - dx_zoom / 2) / scale,
                        xmax=(x0 + dx_zoom / 2) / scale)
            ax.set_ylim(ymin=(y0 - dy_zoom / 2) / scale,
                        ymax=(y0 + dy_zoom / 2) / scale)
            savefig(fig, f"{figbase}.sub{ix}{plot_ext}")
    finally:
        # restore the full view so later plotting on this axes is unaffected
        ax.set_xlim(xmin=xmin, xmax=xmax)
        ax.set_ylim(ymin=ymin, ymax=ymax)
Exemple #5
0
 def _draw_curve(self, ax: matplotlib.axes.Axes, rt_buffer: float) -> None:
     """Plot the EIC trace and shade the area under it between RT min and RT max.

     The trace is drawn only when the compound has EIC data with at least one
     retention time. The x-limits are always set: they span ``self.rt_range``
     and ``self.rt_peak``, widened by ``rt_buffer`` on both sides.
     """
     eic = self.compound["data"]["eic"]
     has_trace = eic is not None and "rt" in eic and len(eic["rt"]) > 0
     if has_trace:
         # fill_between requires a data point at each end of the range,
         # so interpolated points are inserted at the range boundaries
         rt_values, intensities = add_interp_at(eic["rt"], eic["intensity"], self.rt_range)
         ax.plot(rt_values, intensities)
         utils.fill_under(ax, rt_values, intensities, between=self.rt_range, color="c", alpha=0.3)
     left = min(self.rt_range[0], self.rt_peak) - rt_buffer
     right = max(self.rt_range[1], self.rt_peak) + rt_buffer
     ax.set_xlim(left, right)
Exemple #6
0
def customize_ax(ax: matplotlib.axes.Axes,
                 title=None,
                 xlabel=None,
                 ylabel=None,
                 xlim=None,
                 ylim=None,
                 invert_yaxis=False,
                 xticks_maj_freq=None,
                 xticks_min_freq=None,
                 yticks_maj_freq=None,
                 yticks_min_freq=None,
                 with_hline=False,
                 hline_height=None,
                 hline_color='r',
                 hline_style='--'):
    """
    : Use to customize a plot with labels, limits, ticks and an optional hline.
    : Every option left as None (or False) is skipped, so only the requested
    : properties of ``ax`` are changed.
    :
    : ax (matplotlib.axes.Axes): plot to customize.
    : title, xlabel, ylabel: text for the title and the axis labels.
    : xlim, ylim: limits passed to ``ax.set_xlim`` / ``ax.set_ylim``.
    : invert_yaxis: if True, flip the y-axis.
    : xticks_maj_freq, xticks_min_freq, yticks_maj_freq, yticks_min_freq:
    :     spacing of major/minor ticks via ``ticker.MultipleLocator``.
    : with_hline: if True, draw a horizontal line across ``ax``.
    : hline_height: y position of the line. Defaults to half of the larger
    :     y-limit of ``ax``, read after the limits above were applied.
    : hline_color, hline_style: colour and linestyle of the line.
    """
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)

    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)

    if invert_yaxis:
        ax.invert_yaxis()

    if title is not None:
        ax.set_title(title)

    if xticks_maj_freq is not None:
        ax.xaxis.set_major_locator(ticker.MultipleLocator(xticks_maj_freq))

    if xticks_min_freq is not None:
        ax.xaxis.set_minor_locator(ticker.MultipleLocator(xticks_min_freq))

    if yticks_maj_freq is not None:
        ax.yaxis.set_major_locator(ticker.MultipleLocator(yticks_maj_freq))

    if yticks_min_freq is not None:
        ax.yaxis.set_minor_locator(ticker.MultipleLocator(yticks_min_freq))

    if with_hline:
        if hline_height is None:
            # Read the limits of *this* axes; plt.ylim() would query the
            # current pyplot axes, which need not be ``ax``.
            hline_height = max(ax.get_ylim()) / 2
        ax.axhline(y=hline_height, color=hline_color, linestyle=hline_style)
Exemple #7
0
def _plot_data(ax: mpl.axes.Axes, data: PlotData) -> Optional[List[mpl.lines.Line2D]]:
    """Draw one ``PlotData`` object onto ``ax`` according to its display attributes.

    Dispatch is on the pair (type of ``data``, type of
    ``data.display_attributes``). ``TimeSeries``, ``TradeSet`` and
    ``TradeBarSeries`` are plotted against integer positions
    (``np.arange(len(data.timestamps))``), not against the timestamps.

    Args:
        ax: axes to draw into.
        data: the data to plot; ``data.display_attributes`` selects the plot type.

    Returns:
        The artist created for line, scatter and bar plots, so callers can build
        legends; ``None`` for filled-line, candlestick, box and 3-D plots.
        NOTE(review): for line plots this is a single ``Line2D`` (unpacked from
        ``ax.plot``), not a list, despite the return annotation.

    Raises:
        Exception: if the (data type, display attributes type) combination is
            not supported.
    """
    
    x, y = None, None
    
    lines = None  # Return line objects so we can add legends
    
    disp = data.display_attributes
    
    if isinstance(data, XYData) or isinstance(data, TimeSeries):
        # XYData carries its own x values; TimeSeries is plotted against bar index
        x, y = (data.x, data.y) if isinstance(data, XYData) else (np.arange(len(data.timestamps)), data.values)
        if isinstance(disp, LinePlotAttributes):
            # single-element unpack: keep the Line2D itself rather than the list
            lines, = ax.plot(x, y, linestyle=disp.line_type, linewidth=disp.line_width, color=disp.color)
            if disp.marker is not None:  # type: ignore
                ax.scatter(x, y, marker=disp.marker, c=disp.marker_color, s=disp.marker_size, zorder=100)
        elif isinstance(disp, ScatterPlotAttributes):
            lines = ax.scatter(x, y, marker=disp.marker, c=disp.marker_color, s=disp.marker_size, zorder=100)
        elif isinstance(disp, BarPlotAttributes):
            lines = ax.bar(x, y, color=disp.color)  # type: ignore
        elif isinstance(disp, FilledLinePlotAttributes):
            # NaNs become 0 so fill_between only sees finite values
            x, y = np.nan_to_num(x), np.nan_to_num(y)
            # split into positive and negative parts so each gets its own colour
            pos_values = np.where(y > 0, y, 0)
            neg_values = np.where(y < 0, y, 0)
            ax.fill_between(x, pos_values, color=disp.positive_color, step='post', linewidth=0.0)
            ax.fill_between(x, neg_values, color=disp.negative_color, step='post', linewidth=0.0)
        else:
            raise Exception(f'unknown plot combination: {type(data)} {type(disp)}')
            
        # For scatter and filled line, xlim and ylim does not seem to get set automatically
        # (_adjust_axis_limit is defined elsewhere; NaN results leave the limits untouched)
        if isinstance(disp, ScatterPlotAttributes) or isinstance(disp, FilledLinePlotAttributes):
            xmin, xmax = _adjust_axis_limit(ax.get_xlim(), x)
            if not np.isnan(xmin) and not np.isnan(xmax): ax.set_xlim((xmin, xmax))

            ymin, ymax = _adjust_axis_limit(ax.get_ylim(), y)
            if not np.isnan(ymin) and not np.isnan(ymax): ax.set_ylim((ymin, ymax))
                
    elif isinstance(data, TradeSet) and isinstance(disp, ScatterPlotAttributes):
        lines = ax.scatter(np.arange(len(data.timestamps)), data.values, marker=disp.marker, c=disp.marker_color, s=disp.marker_size, zorder=100)
    elif isinstance(data, TradeBarSeries) and isinstance(disp, CandleStickPlotAttributes):
        draw_candlestick(ax, np.arange(len(data.timestamps)), data.o, data.h, data.l, data.c, data.v, data.vwap, colorup=disp.colorup, colordown=disp.colordown)
    elif isinstance(data, BucketedValues) and isinstance(disp, BoxPlotAttributes):
        draw_boxplot(
            ax, data.bucket_names, data.bucket_values, disp.proportional_widths, disp.notched,  # type: ignore
            disp.show_outliers, disp.show_means, disp.show_all)  # type: ignore
    elif isinstance(data, XYZData) and (isinstance(disp, SurfacePlotAttributes) or isinstance(disp, ContourPlotAttributes)):
        # the same 3-D helper draws both contour and surface plots
        display_type: str = 'contour' if isinstance(disp, ContourPlotAttributes) else 'surface'
        draw_3d_plot(ax, data.x, data.y, data.z, display_type, disp.marker, disp.marker_size, 
                     disp.marker_color, disp.interpolation, disp.cmap)
    else:
        raise Exception(f'unknown plot combination: {type(data)} {type(disp)}')

    return lines
Exemple #8
0
    def plot(self, ax: matplotlib.axes.Axes):
        """Draw the Bland-Altman plot onto ``ax``.

        Draws the following onto ``ax``:

        - the per-pair differences (``self.diff``) against the pair means
          (``self.mean``);
        - the mean difference as a solid line;
        - the limits of agreement ``mean_diff ± loa_sd`` as dashed lines;
        - optional confidence-interval bands (when ``self.CI`` is set);
        - an optional grey zero reference line;
        - value labels at the right edge.

        It then applies the configured limits, axis titles and graph title.
        Entries in ``self.point_kws`` override the default point style
        (size 20, alpha 0.6, ``self.color_points``).

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes to draw into.
        """
        # individual points; user options override the defaults instead of
        # raising "got multiple values for keyword argument"
        point_kws = {'s': 20, 'alpha': 0.6, 'color': self.color_points}
        point_kws.update(self.point_kws or {})
        ax.scatter(self.mean, self.diff, **point_kws)

        # mean difference and SD lines
        upper = self.mean_diff + self.loa_sd
        lower = self.mean_diff - self.loa_sd
        ax.axhline(self.mean_diff, color=self.color_mean, linestyle='-')
        ax.axhline(upper, color=self.color_loa, linestyle='--')
        ax.axhline(lower, color=self.color_loa, linestyle='--')

        if self.reference:
            ax.axhline(0, color='grey', linestyle='-', alpha=0.4)

        # confidence intervals (if requested)
        if self.CI is not None:
            ax.axhspan(self.CI_mean[0], self.CI_mean[1], color=self.color_mean, alpha=0.2)
            ax.axhspan(self.CI_upper[0], self.CI_upper[1], color=self.color_loa, alpha=0.2)
            ax.axhspan(self.CI_lower[0], self.CI_lower[1], color=self.color_loa, alpha=0.2)

        # text in graph: x in axes coordinates (right edge), y in data coordinates
        trans: transforms.Transform = transforms.blended_transform_factory(
            ax.transAxes, ax.transData)
        # vertical gap between a line and its labels (2.4 % of loa * sd_diff)
        offset: float = (((self.loa * self.sd_diff) * 2) / 100) * 1.2
        labels = (
            (self.mean_diff + offset, 'Mean', 'bottom'),
            (self.mean_diff - offset, f'{self.mean_diff:.2f}', 'top'),
            (upper + offset, f'+{self.loa:.2f} SD', 'bottom'),
            (upper - offset, f'{upper:.2f}', 'top'),
            (lower - offset, f'-{self.loa:.2f} SD', 'top'),
            (lower + offset, f'{lower:.2f}', 'bottom'),
        )
        for y, text, valign in labels:
            ax.text(0.98, y, text, ha="right", va=valign, transform=trans)

        # transform graphs
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)

        # set X and Y limits
        if self.xlim is not None:
            ax.set_xlim(self.xlim[0], self.xlim[1])
        if self.ylim is not None:
            ax.set_ylim(self.ylim[0], self.ylim[1])

        # graph labels
        ax.set_ylabel(self.y_title)
        ax.set_xlabel(self.x_title)
        if self.graph_title is not None:
            ax.set_title(self.graph_title)
Exemple #9
0
    def set_ax_lims(self, ax: mpl.axes.Axes, xlims: tuple = None,
                    ylims: tuple = None, yshade: list = None):
        """Apply optional axis limits and grey shaded x-intervals to ``ax``.

        Keyword arguments:
        ax -- matplotlib.axes.Axes object to modify in place
        xlims -- (xmin, xmax) passed to ``ax.set_xlim``; skipped if None
        ylims -- (ymin, ymax) passed to ``ax.set_ylim``; skipped if None
        yshade -- sequence of (start, end) pairs; each interval on the x-axis
                  is shaded over the full height with ``ax.axvspan``
                  (grey, alpha 0.75); skipped if None
        """
        for limits, apply_limits in ((xlims, ax.set_xlim), (ylims, ax.set_ylim)):
            if limits is not None:
                apply_limits(limits)

        if yshade is None:
            return
        for interval in yshade:
            ax.axvspan(interval[0], interval[1], color='grey', alpha=0.75)
Exemple #10
0
def plot_astrometric_residuals(ax: matplotlib.axes.Axes,
                               xs: np.ndarray,
                               ys: np.ndarray) -> None:
    """
    Plot the astrometric residual field of a set of points as arrows.

    Arrows are drawn in data units (``scale_units='xy'``, ``scale=1``) and
    centred on their anchor points. A reference key labelled
    'residual = 0.1 arcsec' is placed at data position (0.0, 1.8). The axes
    then get RA/Dec labels, fixed limits and an equal aspect ratio.

    Parameters
    ----------
    ax:
        Matplotlib axis in which to plot
    xs:
        2-D array of arrow anchor positions; column 0 is used as x (RA) and
        column 1 as y (Dec)
    ys:
        2-D array of residual vectors; column 0 is the x-component and
        column 1 the y-component

    Returns
    -------
    None
    """
    arrows = ax.quiver(
        xs[:, 0], xs[:, 1], ys[:, 0], ys[:, 1],
        scale=1,
        alpha=1,
        angles='uv',
        headlength=5,
        headwidth=3,
        headaxislength=4,
        minlength=0,
        pivot='middle',
        scale_units='xy',
        width=0.002,
        color='#001146',
    )
    ax.quiverkey(
        arrows, 0.0, 1.8, 0.1, 'residual = 0.1 arcsec',
        coordinates='data',
        labelpos='N',
        color='darkred',
        labelcolor='darkred',
    )

    ax.set_xlabel('RA [degrees]')
    ax.set_ylabel('Dec [degrees]')
    ax.set_xlim(-1.95, 1.95)
    ax.set_ylim(-1.9, 2.0)
    ax.set_aspect('equal')
def plot_topk_cost(ax: mpl.axes.Axes,
                   experiment_name: str,
                   eval_metric: str,
                   pool_size: int,
                   plot_kwargs: Optional[Dict[str, Any]] = None) -> mpl.axes.Axes:
    """
    Replicates Figure 2 in [CITE PAPER].

    For every method in ``COST_METHOD_NAME_DICT`` the saved metric curve is
    loaded from
    ``RESULTS_DIR + experiment_name + '/<method>_<eval_metric>_top1_pseudocount1.0.npy'``.
    It is plotted against the fraction of the pool that has been labelled
    (``step * LOG_FREQ / pool_size``).

    Parameters
    ===
    ax: mpl.axes.Axes.
        Axes to draw into.
    experiment_name: str.
        Experimental results were written to files under a directory named using experiment_name.
    eval_metric: str.
        Takes value from ['avg_num_agreement', 'mrr']
    pool_size: int.
        Total size of pool from which samples were drawn.
    plot_kwargs : dict, optional.
        Keyword arguments passed to the plot, merged over ``DEFAULT_PLOT_KWARGS``.
    Returns
    ===
    ax : The same matplotlib Axes, with the curves drawn.

    Raises
    ===
    ValueError
        If ``COST_METHOD_NAME_DICT`` is empty, so there is nothing to plot.
    """

    _plot_kwargs = DEFAULT_PLOT_KWARGS.copy()
    # None default avoids a shared mutable {} default argument
    _plot_kwargs.update(plot_kwargs or {})

    metric_eval = None
    for method, label in COST_METHOD_NAME_DICT.items():
        metric_eval = np.load(
            RESULTS_DIR + experiment_name + ('/%s_%s_top1_pseudocount1.0.npy' % (method, eval_metric)))
        x = np.arange(len(metric_eval)) * LOG_FREQ / pool_size
        ax.plot(x, metric_eval, label=label, **_plot_kwargs)

    if metric_eval is None:
        raise ValueError("COST_METHOD_NAME_DICT is empty; nothing to plot")

    # x-range is taken from the last loaded curve
    cutoff = len(metric_eval) - 1
    ax.set_xlim(0, cutoff * LOG_FREQ / pool_size)
    ax.set_ylim(0, 1.0)
    xmin, xmax = ax.get_xlim()
    # dividing by slightly more than 4 keeps xmax itself within arange's range
    step = ((xmax - xmin) / 4.0001)
    ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1))
    ax.xaxis.set_ticks(np.arange(xmin, xmax + 0.001, step))
    ax.yaxis.set_ticks(np.arange(0, 1.01, 0.20))
    ax.tick_params(pad=0.25, length=1.5)

    return ax
def set_bbox(
    ax: matplotlib.axes.Axes,
    bbox: Tuple[float, float, float, float],
    scale: float = 1000,
) -> None:
    """
    Set the x- and y-limits of an axes object from a bounding box.

    Arguments
    ---------
    ax : matplotlib.axes.Axes
        Axes object to be adjusted in place.
    bbox : Tuple[float, float, float, float]
        Boundary limits ordered (xmin, ymin, xmax, ymax); unit m.
    scale: float
        Divisor applied to every limit: 1 if the axes are in m, 1000 if
        they are in km.
    """
    x_lo, y_lo, x_hi, y_hi = (bbox[k] / scale for k in range(4))
    ax.set_xlim(xmin=x_lo, xmax=x_hi)
    ax.set_ylim(ymin=y_lo, ymax=y_hi)
Exemple #13
0
def plot_step_analyzer(axes: matplotlib.axes.Axes, result: dict, title: str,
                       legends: list, colorset: int):
    """Plot step-response traces with their reference, threshold and timing markers.

    Each state ``i`` in ``result['states']`` is drawn against
    ``result['times']``. When ``result['references'][i]`` is non-zero,
    horizontal lines are added for:

    - the reference (dashed);
    - ``reference ± threshold`` (dash-dotted);
    - ``reference + overshoot`` (dotted).

    Positive rise and settle times get vertical dashed / dash-dotted lines.
    Finally the x-range is set to span ``result['times']`` and the figure
    layout is tightened.

    Parameters
    ----------
    axes : matplotlib.axes.Axes
        Axes to draw into.
    result : dict
        Must contain 'times', 'states', 'references', 'thresholds',
        'overshoots', 'risetimes' and 'settletimes'; all except 'times' are
        indexed per state.
    title : str
        Axes title; not set if empty.
    legends : list
        One label per state; when empty/None no labels or legend are drawn.
    colorset : int
        Index of the colour palette to use (0 or 1). Colours repeat when
        there are more states than palette entries.
    """
    palettes = [
        ('red', 'green', 'blue'),  # TODO: add more colors
        ('tomato', 'lightgreen', 'steelblue'),
    ]
    palette = palettes[colorset]

    for i, state in enumerate(result['states']):
        # cycle the palette; indexing it directly raised IndexError beyond 3 states
        color = palette[i % len(palette)]

        if legends:
            axes.plot(result['times'], state, color=color, label=legends[i])
        else:
            axes.plot(result['times'], state, color=color)

        reference = result['references'][i]
        # a reference of exactly 0.0 is treated as "no reference"
        if reference != 0.0:
            threshold = result['thresholds'][i]
            axes.axhline(reference, linestyle='--', color=color)
            axes.axhline(reference - threshold, linestyle='-.', color=color)
            axes.axhline(reference + threshold, linestyle='-.', color=color)
            axes.axhline(reference + result['overshoots'][i],
                         linestyle=':',
                         color=color)
        if result['risetimes'][i] > 0.0:
            axes.axvline(result['risetimes'][i], linestyle='--', color=color)
        if result['settletimes'][i] > 0.0:
            axes.axvline(result['settletimes'][i], linestyle='-.', color=color)

    if legends:
        axes.legend()

    if title:
        axes.set_title(title)

    axes.set_xlim(result['times'][0], result['times'][-1])
    axes.figure.tight_layout()
def plot_ic(
    ax: matplotlib.axes.Axes,
    ic: IonChromatogram,
    minutes: bool = False,
    **kwargs,
) -> List[Line2D]:
    """
    Plots an Ion Chromatogram.

    :param ax: The axes to plot the IonChromatogram on.
    :param ic: Ion chromatogram m/z channels for plotting.
    :param minutes: Whether the x-axis should be plotted in minutes. Default :py:obj:`False` (plotted in seconds)
    :no-default minutes:

    :Other Parameters: :class:`matplotlib.lines.Line2D` properties.
        Used to specify properties like a line label (for auto legends),
        linewidth, antialiasing, marker face color.

        .. code-block:: python

            >>> plot_ic(ax, im.get_ic_at_index(5), label='IC @ Index 5', linewidth=2)

        See :class:`matplotlib.lines.Line2D` for the list of possible keyword arguments.

    :return: A list of Line2D objects representing the plotted data.

    :raises TypeError: If ``ic`` is not an :class:`IonChromatogram`.
    """

    if not isinstance(ic, IonChromatogram):
        raise TypeError("'ic' must be an IonChromatogram")

    time_list = ic.time_list

    if minutes:
        time_list = [time / 60 for time in time_list]

    plot = ax.plot(time_list, ic.intensity_array, **kwargs)

    # Set axis ranges; min()/max() would raise ValueError on an empty
    # chromatogram, so the x-range is only set when there is data.
    if len(time_list):
        ax.set_xlim(min(time_list), max(time_list))
    ax.set_ylim(bottom=0)

    return plot
Exemple #15
0
def plot_ece_samples(ax: mpl.axes.Axes,
                     ground_truth_ece: float,
                     frequentist_ece,
                     samples_posterior: np.ndarray,
                     plot_kwargs: Optional[Dict[str, Any]] = None) -> mpl.axes.Axes:
    """
    Compare frequentist and Bayesian ECE estimates against the ground truth.

    The x-axis is fixed to [0, 0.3] with ticks every 0.1.

    :param ax: axes to draw into.
    :param ground_truth_ece: float; drawn as a black vertical line.
    :param frequentist_ece: scalar or array-like. A scalar (Python or NumPy,
        int or float) is drawn as a blue vertical line; an array as a blue
        histogram.
    :param samples_posterior: posterior ECE samples, drawn as a red histogram.
    :param plot_kwargs: extra keyword arguments merged over
        ``DEFAULT_PLOT_KWARGS`` and passed to every plotting call.
    :return: ``ax``.
    """
    _plot_kwargs = DEFAULT_PLOT_KWARGS.copy()
    # None default avoids a shared mutable {} default argument
    _plot_kwargs.update(plot_kwargs or {})
    # np.ndim() == 0 also accepts ints and numpy scalars such as float32,
    # which isinstance(..., float) would have sent to ax.hist.
    if np.ndim(frequentist_ece) == 0:
        ax.axvline(x=frequentist_ece,
                   label='Frequentist',
                   color='blue',
                   **_plot_kwargs)
    else:
        ax.hist(frequentist_ece,
                color='blue',
                alpha=0.7,
                label='Frequentist',
                **_plot_kwargs)
    ax.hist(samples_posterior,
            color='red',
            label='Bayesian',
            alpha=0.7,
            **_plot_kwargs)
    ax.axvline(x=ground_truth_ece,
               label='Ground truth',
               color='black',
               **_plot_kwargs)

    ax.set_xlim(0, 0.3)
    ax.set_xticks([0.0, 0.1, 0.2, 0.3])

    return ax
Exemple #16
0
    def apply(self, axes: matplotlib.axes.Axes,
              figure: matplotlib.figure.Figure):
        """Apply the stored grid, log-scale, axis-limit and dpi settings.

        Grid and log scales are applied before the current limits are read.
        Each of ``self.xmin``/``xmax``/``ymin``/``ymax`` left as ``None``
        keeps the corresponding current limit. The dpi is set only when
        ``self.dpi`` is truthy and ``figure`` is not None.
        """
        axes.grid(self.grid)
        if self.logx:
            axes.set_xscale("log")
        if self.logy:
            axes.set_yscale("log")

        cur_left, cur_right = axes.get_xlim()
        cur_bottom, cur_top = axes.get_ylim()
        axes.set_xlim(
            xmin=cur_left if self.xmin is None else self.xmin,
            xmax=cur_right if self.xmax is None else self.xmax,
        )
        axes.set_ylim(
            ymin=cur_bottom if self.ymin is None else self.ymin,
            ymax=cur_top if self.ymax is None else self.ymax,
        )

        if not self.dpi or figure is None:
            return
        figure.set_dpi(self.dpi)
Exemple #17
0
def plot_testcount_forecast(
    result: pandas.Series,
    m: preprocessing.fbprophet.Prophet,
    forecast: pandas.DataFrame,
    considered_holidays: preprocessing.NamedDates, *,
    ax: matplotlib.axes.Axes=None
) -> matplotlib.axes.Axes:
    """ Helper function for plotting the detailed testcount forecasting result.

    Draws the Prophet forecast from the first ``ds`` of the model history
    onwards and marks ``considered_holidays`` with vertical lines. It then
    overlays ``result`` in orange. The y-axis starts at 0 and the x-axis at
    2020-03-01.

    Parameters
    ----------
    result : pandas.Series
        the date-indexed series of smoothed/predicted testcounts
    m : fbprophet.Prophet
        the prophet model
    forecast : pandas.DataFrame
        contains the prophet model prediction
    considered_holidays : preprocessing.NamedDates
        the holidays that were used in the model; drawn via ``plot_vlines``
    ax : optional, matplotlib.axes.Axes
        an existing subplot to use; a new 13.4 x 6 inch figure is created if None

    Returns
    -------
    ax : matplotlib.axes.Axes
        the (created) subplot that was plotted into
    """
    if ax is None:
        _, ax = pyplot.subplots(figsize=(13.4, 6))
    # only plot the forecast from the first training date onwards
    first_training_date = m.history.set_index('ds').index[0]
    m.plot(forecast[forecast.ds >= first_training_date], ax=ax)
    ax.set_ylim(bottom=0)
    ax.set_xlim(pandas.to_datetime('2020-03-01'))
    plot_vlines(ax, considered_holidays, alignment='bottom')
    # the empty scatter/plot calls only create legend handles; presumably their
    # colours match what m.plot draws for training data and prediction
    ax.legend(frameon=False, loc='upper left', handles=[
        ax.scatter([], [], color='black', label='training data'),
        ax.plot([], [], color='blue', label='prediction')[0],
        ax.plot(result.index, result.values, color='orange', label='result')[0],
    ])
    ax.set_ylabel('total tests')
    ax.set_xlabel('')
    return ax
Exemple #18
0
def plot_ic(ax: matplotlib.axes.Axes, ic: IonChromatogram, minutes: bool = False, **kwargs) -> List[Line2D]:
    """
    Plots an Ion Chromatogram

    :param ax: The axes to plot the IonChromatogram on
    :type ax: matplotlib.axes.Axes
    :param ic: Ion Chromatograms m/z channels for plotting
    :type ic: pyms.IonChromatogram.IonChromatogram
    :param minutes: Whether the x-axis should be plotted in minutes. Default False (plotted in seconds)
    :type minutes: bool, optional

    :Other Parameters: :class:`matplotlib.lines.Line2D` properties.
        Used to specify properties like a line label (for auto legends),
        linewidth, antialiasing, marker face color.

        Example::

        >>> plot_ic(ax, im.get_ic_at_index(5), label='IC @ Index 5', linewidth=2)

        See https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.lines.Line2D.html
        for the list of possible kwargs

    :return: A list of Line2D objects representing the plotted data.
    :rtype: list of :class:`matplotlib.lines.Line2D`

    :raises TypeError: If ``ic`` is not an IonChromatogram.
    """

    if not isinstance(ic, IonChromatogram):
        raise TypeError("'ic' must be an IonChromatogram")

    time_list = ic.time_list
    if minutes:
        time_list = [time / 60 for time in time_list]

    plot = ax.plot(time_list, ic.intensity_array, **kwargs)

    # Set axis ranges from the *plotted* times. Using ic.time_list here put the
    # x-limits in seconds while the data was in minutes, hiding the trace.
    # min()/max() would raise on an empty chromatogram, so skip that case.
    if len(time_list):
        ax.set_xlim(min(time_list), max(time_list))
    ax.set_ylim(bottom=0)

    return plot
def equal_axlim(axs: mpl.axes.Axes, mode: str = 'union') -> None:
    """Make x/y axes limits the same.

    Parameters
    ----------
    axs : mpl.axes.Axes
        `Axes` instance whose limits are to be adjusted.
    mode : str
        How do we adjust the limits? Options:
            'union'
                Limits include old ranges of both x and y axes, *default*.
            'intersect'
                Limits only include values in both ranges. If the ranges do
                not overlap, the resulting limits are inverted (lower > upper).
            'x'
                Set y limits to x limits.
            'y'
                Set x limits to y limits.

    Raises
    ------
    ValueError
        If `mode` is not one of the options above.
    """
    xlim = axs.get_xlim()
    ylim = axs.get_ylim()
    modes = {
        'union': (min(xlim[0], ylim[0]), max(xlim[1], ylim[1])),
        'intersect': (max(xlim[0], ylim[0]), min(xlim[1], ylim[1])),
        'x': xlim,
        'y': ylim
    }
    if mode not in modes:
        raise ValueError(f"Unknown mode '{mode}'. Should be one of: "
                         "'union', 'intersect', 'x', 'y'.")
    new_lim = modes[mode]
    axs.set_xlim(new_lim)
    axs.set_ylim(new_lim)
Exemple #20
0
def floor_plan(
    ax: mpl.axes.Axes,
    lattice: Lattice,
    *,
    start_angle: float = 0,
    labels: bool = True,
):
    """Draw a 2-D floor plan of ``lattice`` onto ``ax``.

    Consecutive equal entries of ``lattice.sequence`` are merged by
    ``groupby`` (presumably ``itertools.groupby``) into one segment. Its
    length is the element length times the run length.

    - Drifts are drawn as thin black lines; other elements as thick lines
      coloured by ``ELEMENT_COLOR[type(element)]``.
    - Dipoles are drawn as arcs that turn the beam direction by
      ``element.k0 * length`` radians.
    - All other elements are straight segments along the current direction.

    The axes limits are set to the bounding box of the origin and all segment
    end points, plus a margin of 1 % of the larger extent.
    NOTE(review): arcs can bulge beyond that box.

    Args:
        ax: axes to draw into; its aspect ratio is set to "equal".
        lattice: lattice whose ``sequence`` is drawn.
        start_angle: initial beam direction in radians (0 points along +x).
        labels: if True, annotate dipoles and quadrupoles with their
            ``name``, alternating sides of the beam line.

    Returns:
        ``ax``.
    """
    ax.set_aspect("equal")
    codes = Path.MOVETO, Path.LINETO
    current_angle = start_angle
    start = np.zeros(2)
    end = np.zeros(2)
    x_min = y_min = 0
    x_max = y_max = 0
    sign = 1  # flipped before each label so labels alternate sides
    for element, group in groupby(lattice.sequence):
        start = end.copy()
        # merged length of the run of identical consecutive elements
        length = element.length * sum(1 for _ in group)
        if isinstance(element, Drift):
            color = Color.BLACK
            line_width = 1
        else:
            color = ELEMENT_COLOR[type(element)]
            line_width = 6

        # TODO: refactor current angle
        angle = 0
        if isinstance(element, Dipole):
            angle = element.k0 * length
            # NOTE(review): raises ZeroDivisionError if k0 (or length) is 0
            radius = length / angle
            # chord of an arc turning by `angle`, in the local beam frame,
            # then rotated into the global frame by the current direction
            vec = radius * np.array([np.sin(angle), 1 - np.cos(angle)])
            sin = np.sin(current_angle)
            cos = np.cos(current_angle)
            rot = np.array([[cos, -sin], [sin, cos]])
            end += rot @ vec

            # circle centre lies `radius` to the left of the beam direction
            angle_center = current_angle + 0.5 * np.pi
            center = start + radius * np.array(
                [np.cos(angle_center),
                 np.sin(angle_center)])
            diameter = 2 * radius
            # rotating the arc's frame by -90 deg lets theta1/theta2 be given
            # as the beam directions at the start/end of the bend
            arc_angle = -90
            theta1 = current_angle * 180 / np.pi
            theta2 = (current_angle + angle) * 180 / np.pi
            # keep theta1 < theta2 for bends in the negative direction
            if angle < 0:
                theta1, theta2 = theta2, theta1

            line = patches.Arc(
                center,
                width=diameter,
                height=diameter,
                angle=arc_angle,
                theta1=theta1,
                theta2=theta2,
                color=color,
                linewidth=line_width,
            )
            current_angle += angle
        else:
            end += length * np.array(
                [np.cos(current_angle),
                 np.sin(current_angle)])
            line = patches.PathPatch(Path((start, end), codes),
                                     color=color,
                                     linewidth=line_width)

        # only segment end points feed the bounding box
        x_min = min(x_min, end[0])
        y_min = min(y_min, end[1])
        x_max = max(x_max, end[0])
        y_max = max(y_max, end[1])

        ax.add_patch(line)  # TODO: currently splitted elements get drawn twice

        if labels and isinstance(element, (Dipole, Quadrupole)):
            # normal to the element's mid-direction (angle is 0 for quadrupoles)
            angle_center = (current_angle - 0.5 * angle) + 0.5 * np.pi
            sign = -sign
            # place the label 0.5 units off the chord midpoint, alternating sides
            center = 0.5 * (start + end) + 0.5 * sign * np.array(
                [np.cos(angle_center),
                 np.sin(angle_center)])
            ax.annotate(
                element.name,
                xy=center,
                fontsize=6,
                ha="center",
                va="center",
                # rotation=(current_angle * 180 / np.pi -90) % 180,
                annotation_clip=False,
                zorder=11,
            )

    margin = 0.01 * max((x_max - x_min), (y_max - y_min))
    ax.set_xlim(x_min - margin, x_max + margin)
    ax.set_ylim(y_min - margin, y_max + margin)
    return ax
Exemple #21
0
    def plot(
        self,
        x_label: str = "Mean of methods",
        y_label: str = "Difference between methods",
        graph_title: str = None,
        reference: bool = False,
        xlim: Tuple = None,
        ylim: Tuple = None,
        color_mean: str = "#008bff",
        color_loa: str = "#FF7000",
        color_points: str = "#000000",
        point_kws: Dict = None,
        ci_alpha: float = 0.2,
        loa_linestyle: str = "--",
        ax: matplotlib.axes.Axes = None,
        mean_linestyle: str = None,
    ):
        """Provide a method comparison using Bland-Altman plotting.
        This is an Axis-level function which will draw the Bland-Altman plot
        onto the current active Axis object unless ``ax`` is provided.
        Parameters
        ----------
        x_label : str, optional
            The label which is added to the X-axis. Defaults to
            "Mean of methods".
        y_label : str, optional
            The label which is added to the Y-axis. Defaults to
            "Difference between methods".
        graph_title : str, optional
            Title of the Bland-Altman plot.
            If None is provided, no title will be plotted.
        reference : bool, optional
            If True, a grey reference line at y=0 will be plotted in the Bland-Altman.
        xlim : list, optional
            Minimum and maximum limits for X-axis. Should be provided as list or tuple.
            If not set, matplotlib will decide its own bounds.
        ylim : list, optional
            Minimum and maximum limits for Y-axis. Should be provided as list or tuple.
            If not set, matplotlib will decide its own bounds.
        color_mean : str, optional
            Color of the mean difference line that will be plotted.
        color_loa : str, optional
            Color of the limit of agreement lines that will be plotted.
        color_points : str, optional
            Color of the individual differences that will be plotted.
            NOTE(review): this argument is not read in this method; point
            styling comes from ``DEFAULT_POINTS_KWS`` and ``point_kws``.
        point_kws : dict of key, value mappings, optional
            Additional keyword arguments for `plt.scatter`, merged over
            ``DEFAULT_POINTS_KWS``.
        ci_alpha: float, optional
            Alpha value of the confidence interval.
        loa_linestyle: str, optional
            Linestyle of the limit of agreement lines. Also used for the mean
            line unless ``mean_linestyle`` is given.
        ax : matplotlib Axes, optional
            Axes in which to draw the plot, otherwise use the currently-active
            Axes.
        mean_linestyle: str, optional
            Linestyle of the mean difference line. Defaults to
            ``loa_linestyle``.

        Returns
        -------
        ax : matplotlib Axes
            Axes object with the Bland-Altman plot.
        """

        ax = ax if ax is not None else plt.gca()
        if mean_linestyle is None:
            mean_linestyle = loa_linestyle

        pkws = self.DEFAULT_POINTS_KWS.copy()
        pkws.update(point_kws or {})

        # Get parameters
        mean, mean_CI = self.result["mean"], self.result["mean_CI"]
        loa_upper, loa_upper_CI = self.result["loa_upper"], self.result[
            "loa_upper_CI"]
        loa_lower, loa_lower_CI = self.result["loa_lower"], self.result[
            "loa_lower_CI"]
        sd_diff = self.result["sd_diff"]

        # individual points
        ax.scatter(self.mean, self.diff, **pkws)

        # mean difference and SD lines
        ax.axhline(mean, color=color_mean, linestyle=mean_linestyle)
        ax.axhline(loa_upper, color=color_loa, linestyle=loa_linestyle)
        ax.axhline(loa_lower, color=color_loa, linestyle=loa_linestyle)

        if reference:
            ax.axhline(0, color="grey", linestyle="-", alpha=0.4)

        # confidence intervals (if requested)
        if self.CI is not None:
            ax.axhspan(*mean_CI, color=color_mean, alpha=ci_alpha)
            ax.axhspan(*loa_upper_CI, color=color_loa, alpha=ci_alpha)
            ax.axhspan(*loa_lower_CI, color=color_loa, alpha=ci_alpha)

        # text in graph: x in axes coordinates (right edge), y in data coordinates
        trans: transforms.Transform = transforms.blended_transform_factory(
            ax.transAxes, ax.transData)
        # vertical gap between a line and its labels (2.4 % of loa * sd_diff)
        offset: float = (((self.loa * sd_diff) * 2) / 100) * 1.2
        annotations = (
            (mean + offset, "Mean", "bottom"),
            (mean - offset, f"{mean:.2f}", "top"),
            (loa_upper + offset, f"+{self.loa:.2f} SD", "bottom"),
            (loa_upper - offset, f"{loa_upper:.2f}", "top"),
            (loa_lower - offset, f"-{self.loa:.2f} SD", "top"),
            (loa_lower + offset, f"{loa_lower:.2f}", "bottom"),
        )
        for y, text, valign in annotations:
            ax.text(0.98, y, text, ha="right", va=valign, transform=trans)

        # transform graphs
        ax.spines["right"].set_visible(False)
        ax.spines["top"].set_visible(False)

        # set X and Y limits
        if xlim is not None:
            ax.set_xlim(xlim[0], xlim[1])
        if ylim is not None:
            ax.set_ylim(ylim[0], ylim[1])

        # graph labels
        ax.set(xlabel=x_label, ylabel=y_label, title=graph_title)

        return ax
Exemple #22
0
def plot_topk_accuracy(ax: mpl.axes.Axes,
                       experiment_name: str,
                       topk: int,
                       eval_metric: str,
                       pool_size: int,
                       threshold: float,
                       plot_kwargs: Dict[str, Any] = {},
                       plot_informed: bool = False) -> mpl.axes.Axes:
    """
    Replicates Figure 2 in [CITE PAPER].

    For each method, loads the saved ``<eval_metric>_<method>.npy`` array,
    averages it over its first axis and plots it against the fraction of the
    pool that has been labeled. The x-axis is then zoomed to a little past
    the point where the benchmark method first exceeds `threshold`.

    Parameters
    ----------
    ax : mpl.axes.Axes
        Axes to draw into.
    experiment_name : str
        Experimental results were written to files under a directory named
        using experiment_name. It is concatenated directly between
        RESULTS_DIR and the file name, so presumably it must carry its own
        trailing path separator.
    topk : int
        Chooses the legend labels: TOPK_METHOD_NAME_DICT when ``topk != 1``;
        otherwise METHOD_NAME_DICT (or the informed-mode labels when
        `plot_informed` is True).
    eval_metric : str
        Takes value from ['avg_num_agreement', 'mrr'].
    pool_size : int
        Total size of pool from which samples were drawn.
    threshold : float
        Metric value the benchmark curve must exceed to set the x-axis cutoff.
    plot_kwargs : dict
        Keyword arguments passed to the plot; they override
        DEFAULT_PLOT_KWARGS. Only read (never mutated), so the shared ``{}``
        default is safe.
    plot_informed : bool
        If True, compare informative vs. uninformative Thompson sampling
        (benchmark 'ts_informed'); otherwise compare the uninformed methods
        (benchmark 'ts_uniform').

    Returns
    -------
    mpl.axes.Axes
        The same `ax`, with the curves, limits and ticks set.
    """

    _plot_kwargs = DEFAULT_PLOT_KWARGS.copy()
    _plot_kwargs.update(plot_kwargs)

    if plot_informed:
        benchmark = 'ts_informed'
        informed_labels = {
            'ts_informed': 'TS (informative)',
            'ts_uniform': 'TS (uninformative)',
        }
        methods = tuple(informed_labels)
    else:
        benchmark = 'ts_uniform'
        informed_labels = {}
        # A tuple rather than a set: set iteration order over strings varies
        # between processes (hash randomization), which reshuffled the
        # drawing and legend order from run to run.
        methods = ('non-active_no_prior', 'ts_uniform',
                   'epsilon_greedy_no_prior', 'bayesian_ucb_no_prior')

    for method in methods:
        metric_eval = np.load(RESULTS_DIR + experiment_name +
                              f'{eval_metric}_{method}.npy').mean(axis=0)
        # One point per LOG_FREQ samples, expressed as a fraction of the pool.
        x = np.arange(len(metric_eval)) * LOG_FREQ / pool_size

        if topk != 1:
            label = TOPK_METHOD_NAME_DICT[method]
        elif plot_informed:
            label = informed_labels[method]
        else:
            label = METHOD_NAME_DICT[method]

        ax.plot(x,
                metric_eval,
                label=label,
                color=COLOR[method],
                **_plot_kwargs)

        if method == benchmark:
            last = len(metric_eval) - 1
            # Skip the first 10 logged points (presumably to ignore early
            # fluctuations) and find where the benchmark first exceeds the
            # threshold. Crossings that occur only inside those first 10
            # points used to raise ValueError; they now fall back to the
            # full range, just like a curve that never crosses.
            crossings = np.flatnonzero(metric_eval[10:] > threshold)
            if crossings.size:
                # Show 20% beyond the first crossing, clamped to the data.
                cutoff = min(int((crossings[0] + 10) * 1.2), last)
            else:
                cutoff = last

    ax.set_xlim(0, cutoff * LOG_FREQ / pool_size)
    ax.set_ylim(0, 1.0)
    xmin, xmax = ax.get_xlim()
    # Dividing by slightly more than 4 yields five ticks, the last one just
    # inside xmax, so it is not dropped by floating-point rounding.
    step = (xmax - xmin) / 4.0001
    # x values are fractions of the pool, so render them as percentages.
    ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1))
    ax.xaxis.set_ticks(np.arange(xmin, xmax + 0.001, step))
    ax.yaxis.set_ticks(np.arange(0, 1.01, 0.20))
    ax.tick_params(pad=0.25, length=1.5)

    return ax