def plot_histogram(axes: mpl.axes.Axes, hist, popt, bin_edges: np.ndarray, axis: str):
    """Draw one axis projection of a hit histogram with a Gaussian overlay.

    The pre-binned counts in ``hist`` are redrawn as a density-normalised
    histogram. The normal PDF parametrised by ``popt`` is drawn on top of it,
    evaluated at the bin centres. For ``axis == 'y'`` the bars are drawn
    horizontally, and the curve is drawn with x and y swapped.

    Parameters
    ----------
    axes : mpl.axes.Axes
        Axes to draw into.
    hist : np.ndarray
        Counts per bin, used as histogram weights.
    popt : np.ndarray
        Extra positional arguments forwarded to ``norm.pdf`` after the sample
        points (presumably ``(mean, std)`` -- confirm against the caller).
    bin_edges : np.ndarray
        Edges of the bins used for ``hist``.
    axis : str
        ``'x'`` for a vertical histogram, ``'y'`` for a horizontal one.

    Raises
    ------
    ValueError
        If ``axis`` is neither ``'x'`` nor ``'y'``.
    """
    centers = (bin_edges[1:] + bin_edges[:-1]) / 2
    if axis not in ('x', 'y'):
        raise ValueError("axis has to be either 'x' or 'y'")

    bounds = [bin_edges[0], bin_edges[-1]]
    hist_options = dict(weights=hist, density=True)
    if axis == 'y':
        axes.set_ylim(bounds)
        axes.hist(bin_edges[:-1], bin_edges, orientation='horizontal', **hist_options)
        gaussian = norm.pdf(centers, *popt)
        axes.plot(gaussian, centers)
    else:
        axes.set_xlim(bounds)
        axes.hist(bin_edges[:-1], bin_edges, **hist_options)
        gaussian = norm.pdf(centers, *popt)
        axes.plot(centers, gaussian)
def volcano_plot(df: pd.DataFrame, ax: matplotlib.axes.Axes) -> matplotlib.axes.Axes:
    '''Generate a volcano plot (log2 fold change vs. -log10 q-value).

    Parameters
    ----------
    df : pd.DataFrame
        differential expression output from `diffex_multifactor`. The columns
        ``log2_fc`` and ``nlogq`` are plotted. If there is no ``significant``
        column, one is derived as ``q_val < 0.05`` and written into ``df``
        in place.
    ax : matplotlib.axes.Axes
        Axes to draw into.

    Returns
    -------
    matplotlib.axes.Axes
        The same ``ax``, with points coloured by significance and no legend.
    '''
    if 'significant' not in df.columns:
        print('Adding significance cutoff at alpha=0.05')
        # NOTE: this mutates the caller's DataFrame; kept for backward compatibility.
        df['significant'] = df['q_val'] < 0.05
    # One colour per distinct significance value actually present.
    n_colors = len(np.unique(df['significant']))
    sns.scatterplot(data=df, x='log2_fc', y='nlogq', hue='significant',
                    linewidth=0., alpha=0.3, ax=ax,
                    palette=sns.hls_palette(n_colors)[::-1])
    ax.set_xlim((-6, 6))
    ax.set_ylabel(r'$-\log_{10}$ q-value')
    ax.set_xlabel(r'$\log_2$ (Old / Young)')
    # get_legend() returns None when no legend exists; don't crash in that case.
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()
    return ax
def zoom_x_and_save(fig: matplotlib.figure.Figure, ax: matplotlib.axes.Axes, figbase: str, plot_ext: str, xzoom: List[Tuple[float, float]]) -> None:
    """
    Zoom in on subregions of the x-axis and save one figure per subregion.

    The i-th (1-based) window is saved to ``figbase + ".sub<i>" + plot_ext``.
    The x-limits that were active on entry are restored afterwards, even if
    saving one of the figures raises.

    Arguments
    ---------
    fig : matplotlib.figure.Figure
        Figure to be processed.
    ax : matplotlib.axes.Axes
        Axes to be processed.
    figbase : str
        Base name of the figure to be saved.
    plot_ext : str
        File extension of the figure to be saved (appended verbatim).
    xzoom : List[Tuple[float, float]]
        (xmin, xmax) windows to zoom into.
    """
    xmin, xmax = ax.get_xlim()
    try:
        for ix, window in enumerate(xzoom, start=1):
            ax.set_xlim(xmin=window[0], xmax=window[1])
            figfile = figbase + ".sub" + str(ix) + plot_ext
            savefig(fig, figfile)
    finally:
        # Always leave the axes as we found them.
        ax.set_xlim(xmin=xmin, xmax=xmax)
def zoom_xy_and_save(fig: matplotlib.figure.Figure, ax: matplotlib.axes.Axes, figbase: str, plot_ext: str, xyzoom: List[Tuple[float, float, float, float]], scale: float = 1000) -> None:
    """
    Zoom in on subregions in x,y-space and save one figure per subregion.

    Every saved view has the same size and the same height/width ratio as the
    current view. The size is the smallest one that fits the largest
    requested region, and each view is centred on its region. The i-th
    (1-based) view is saved to ``figbase + ".sub<i>" + plot_ext``. The
    original limits are restored afterwards, even if saving raises.

    Arguments
    ---------
    fig : matplotlib.figure.Figure
        Figure to be processed.
    ax : matplotlib.axes.Axes
        Axes to be processed.
    figbase : str
        Base name of the figure to be saved.
    plot_ext : str
        File extension of the figure to be saved (appended verbatim).
    xyzoom : List[Tuple[float, float, float, float]]
        (xmin, xmax, ymin, ymax) regions to zoom into, in unscaled units (m).
    scale: float
        Indicates whether the axes are in m (1) or km (1000).
    """
    xmin, xmax = ax.get_xlim()
    ymin, ymax = ax.get_ylim()
    # Height/width ratio of the current view, preserved by every zoom window.
    xy_ratio = (ymax - ymin) / (xmax - xmin)

    # Find one window width large enough for every requested region.
    dx_zoom = 0
    for box in xyzoom:
        dx = box[1] - box[0]
        dy = box[3] - box[2]
        if dy < xy_ratio * dx:
            # x range limiting
            dx_zoom = max(dx_zoom, dx)
        else:
            # y range limiting
            dx_zoom = max(dx_zoom, dy / xy_ratio)
    dy_zoom = dx_zoom * xy_ratio

    try:
        for ix, box in enumerate(xyzoom, start=1):
            x0 = (box[0] + box[1]) / 2
            y0 = (box[2] + box[3]) / 2
            ax.set_xlim(xmin=(x0 - dx_zoom / 2) / scale, xmax=(x0 + dx_zoom / 2) / scale)
            ax.set_ylim(ymin=(y0 - dy_zoom / 2) / scale, ymax=(y0 + dy_zoom / 2) / scale)
            figfile = figbase + ".sub" + str(ix) + plot_ext
            savefig(fig, figfile)
    finally:
        # Always leave the axes as we found them.
        ax.set_xlim(xmin=xmin, xmax=xmax)
        ax.set_ylim(ymin=ymin, ymax=ymax)
def _draw_curve(self, ax: matplotlib.axes.Axes, rt_buffer: float) -> None:
    """Plot the EIC trace, shade it between RT min and RT max, and set x-limits.

    The trace is drawn only when ``self.compound["data"]["eic"]`` holds a
    non-empty ``"rt"`` sequence. The x-limits cover both the RT range and
    ``self.rt_peak``, padded by ``rt_buffer`` on each side.
    """
    eic = self.compound["data"]["eic"]
    has_trace = eic is not None and "rt" in eic and len(eic["rt"]) > 0
    if has_trace:
        # fill_between needs a sample exactly at each end of the range, so
        # interpolated points are inserted there first.
        rt_values, intensities = add_interp_at(eic["rt"], eic["intensity"], self.rt_range)
        ax.plot(rt_values, intensities)
        utils.fill_under(ax, rt_values, intensities, between=self.rt_range, color="c", alpha=0.3)

    left = min(self.rt_range[0], self.rt_peak) - rt_buffer
    right = max(self.rt_range[1], self.rt_peak) + rt_buffer
    ax.set_xlim(left, right)
def customize_ax(ax: matplotlib.axes.Axes, title=None, xlabel=None, ylabel=None, xlim=None, ylim=None,
                 invert_yaxis=False, xticks_maj_freq=None, xticks_min_freq=None, yticks_maj_freq=None,
                 yticks_min_freq=None, with_hline=False, hline_height=None, hline_color='r', hline_style='--'):
    """
    Customize a plot with labels, limits, tick spacing and an optional horizontal line.

    : ax (matplotlib.axes.Axes): plot to customize.
    : title, xlabel, ylabel: applied only when not None.
    : xlim, ylim: passed to ax.set_xlim / ax.set_ylim when not None.
    : invert_yaxis: invert the y-axis when True (applied after ylim).
    : xticks_maj_freq, xticks_min_freq, yticks_maj_freq, yticks_min_freq:
      spacing of a ticker.MultipleLocator on the matching axis, when not None.
    : with_hline: draw a horizontal line at hline_height. When hline_height is
      None it defaults to half of the larger y-limit of ``ax``.
    : hline_color, hline_style: colour and linestyle of that line.
    """
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)
    if invert_yaxis:
        ax.invert_yaxis()
    if title is not None:
        ax.set_title(title)
    if xticks_maj_freq is not None:
        ax.xaxis.set_major_locator(ticker.MultipleLocator(xticks_maj_freq))
    if xticks_min_freq is not None:
        ax.xaxis.set_minor_locator(ticker.MultipleLocator(xticks_min_freq))
    if yticks_maj_freq is not None:
        ax.yaxis.set_major_locator(ticker.MultipleLocator(yticks_maj_freq))
    if yticks_min_freq is not None:
        ax.yaxis.set_minor_locator(ticker.MultipleLocator(yticks_min_freq))
    if with_hline:
        if hline_height is None:
            # Read the limits of *this* axes; plt.ylim() would query pyplot's
            # current axes, which need not be ``ax``.
            hline_height = max(ax.get_ylim()) / 2
        ax.axhline(y=hline_height, color=hline_color, linestyle=hline_style)
def _plot_data(ax: mpl.axes.Axes, data: PlotData) -> Optional[List[mpl.lines.Line2D]]:
    """Draw ``data`` on ``ax`` according to ``data.display_attributes``.

    Supported (data, display) combinations:

    * XYData / TimeSeries with line, scatter, bar or filled-line attributes.
      TimeSeries are plotted against their integer position, not timestamps.
    * TradeSet with scatter attributes.
    * TradeBarSeries with candlestick attributes.
    * BucketedValues with box-plot attributes.
    * XYZData with surface or contour attributes.

    Returns the artist created for legend purposes, or None. Note: despite
    the annotation this is a single Line2D for line plots, the scatter
    collection for scatter plots, or the bar container for bar plots.

    Raises
    ------
    Exception
        For any other combination of data and display attributes.
    """
    disp = data.display_attributes
    artist = None  # handed back so the caller can build legends

    def scatter_markers(xs, ys):
        # Marker overlay shared by the line, scatter and trade-set branches.
        return ax.scatter(xs, ys, marker=disp.marker, c=disp.marker_color, s=disp.marker_size, zorder=100)

    if isinstance(data, (XYData, TimeSeries)):
        if isinstance(data, XYData):
            x, y = data.x, data.y
        else:
            x, y = np.arange(len(data.timestamps)), data.values

        if isinstance(disp, LinePlotAttributes):
            (artist,) = ax.plot(x, y, linestyle=disp.line_type, linewidth=disp.line_width, color=disp.color)
            if disp.marker is not None:  # type: ignore
                scatter_markers(x, y)
        elif isinstance(disp, ScatterPlotAttributes):
            artist = scatter_markers(x, y)
        elif isinstance(disp, BarPlotAttributes):
            artist = ax.bar(x, y, color=disp.color)  # type: ignore
        elif isinstance(disp, FilledLinePlotAttributes):
            x, y = np.nan_to_num(x), np.nan_to_num(y)
            positive = np.where(y > 0, y, 0)
            negative = np.where(y < 0, y, 0)
            ax.fill_between(x, positive, color=disp.positive_color, step='post', linewidth=0.0)
            ax.fill_between(x, negative, color=disp.negative_color, step='post', linewidth=0.0)
        else:
            raise Exception(f'unknown plot combination: {type(data)} {type(disp)}')

        # Scatter and filled-line plots don't seem to autoscale, so widen the
        # limits by hand (skipping any NaN result).
        if isinstance(disp, (ScatterPlotAttributes, FilledLinePlotAttributes)):
            lo, hi = _adjust_axis_limit(ax.get_xlim(), x)
            if not (np.isnan(lo) or np.isnan(hi)):
                ax.set_xlim((lo, hi))
            lo, hi = _adjust_axis_limit(ax.get_ylim(), y)
            if not (np.isnan(lo) or np.isnan(hi)):
                ax.set_ylim((lo, hi))
    elif isinstance(data, TradeSet) and isinstance(disp, ScatterPlotAttributes):
        artist = scatter_markers(np.arange(len(data.timestamps)), data.values)
    elif isinstance(data, TradeBarSeries) and isinstance(disp, CandleStickPlotAttributes):
        draw_candlestick(ax, np.arange(len(data.timestamps)), data.o, data.h, data.l, data.c, data.v, data.vwap,
                         colorup=disp.colorup, colordown=disp.colordown)
    elif isinstance(data, BucketedValues) and isinstance(disp, BoxPlotAttributes):
        draw_boxplot(ax, data.bucket_names, data.bucket_values, disp.proportional_widths, disp.notched,  # type: ignore
                     disp.show_outliers, disp.show_means, disp.show_all)  # type: ignore
    elif isinstance(data, XYZData) and isinstance(disp, (SurfacePlotAttributes, ContourPlotAttributes)):
        display_type: str = 'contour' if isinstance(disp, ContourPlotAttributes) else 'surface'
        draw_3d_plot(ax, data.x, data.y, data.z, display_type, disp.marker, disp.marker_size,
                     disp.marker_color, disp.interpolation, disp.cmap)
    else:
        raise Exception(f'unknown plot combination: {type(data)} {type(disp)}')
    return artist
def plot(self, ax: matplotlib.axes.Axes):
    """Draw the Bland-Altman plot for this comparison onto ``ax``.

    The plot contains:

    * the individual points (``self.mean`` against ``self.diff``);
    * the mean-difference line and the limits of agreement at
      ``self.mean_diff +/- self.loa_sd``;
    * an optional grey y=0 reference line;
    * confidence-interval bands when ``self.CI`` is not None;
    * value labels at the right edge.

    Limits, axis labels and the title are applied last.

    ``self.point_kws`` (may be None) overrides the default point style
    ``s=20, alpha=0.6, color=self.color_points``.
    """
    # individual points: defaults first, caller-supplied keywords win
    # (previously duplicated keys or a None point_kws raised TypeError)
    point_style = {'s': 20, 'alpha': 0.6, 'color': self.color_points}
    point_style.update(self.point_kws or {})
    ax.scatter(self.mean, self.diff, **point_style)

    upper = self.mean_diff + self.loa_sd
    lower = self.mean_diff - self.loa_sd

    # mean difference and limits-of-agreement lines
    ax.axhline(self.mean_diff, color=self.color_mean, linestyle='-')
    ax.axhline(upper, color=self.color_loa, linestyle='--')
    ax.axhline(lower, color=self.color_loa, linestyle='--')
    if self.reference:
        ax.axhline(0, color='grey', linestyle='-', alpha=0.4)

    # confidence intervals (if requested)
    if self.CI is not None:
        ax.axhspan(self.CI_mean[0], self.CI_mean[1], color=self.color_mean, alpha=0.2)
        ax.axhspan(self.CI_upper[0], self.CI_upper[1], color=self.color_loa, alpha=0.2)
        ax.axhspan(self.CI_lower[0], self.CI_lower[1], color=self.color_loa, alpha=0.2)

    # Labels: x in axes coordinates (0.98 = near the right edge), y in data coordinates.
    trans: transforms.Transform = transforms.blended_transform_factory(ax.transAxes, ax.transData)
    # Gap between each line and its labels: 1.2% of the full LoA width (2 * loa * sd_diff).
    offset: float = (((self.loa * self.sd_diff) * 2) / 100) * 1.2
    ax.text(0.98, self.mean_diff + offset, 'Mean', ha="right", va="bottom", transform=trans)
    ax.text(0.98, self.mean_diff - offset, f'{self.mean_diff:.2f}', ha="right", va="top", transform=trans)
    ax.text(0.98, upper + offset, f'+{self.loa:.2f} SD', ha="right", va="bottom", transform=trans)
    ax.text(0.98, upper - offset, f'{upper:.2f}', ha="right", va="top", transform=trans)
    ax.text(0.98, lower - offset, f'-{self.loa:.2f} SD', ha="right", va="top", transform=trans)
    ax.text(0.98, lower + offset, f'{lower:.2f}', ha="right", va="bottom", transform=trans)

    # transform graphs
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    # set X and Y limits
    if self.xlim is not None:
        ax.set_xlim(self.xlim[0], self.xlim[1])
    if self.ylim is not None:
        ax.set_ylim(self.ylim[0], self.ylim[1])

    # graph labels
    ax.set_ylabel(self.y_title)
    ax.set_xlabel(self.x_title)
    if self.graph_title is not None:
        ax.set_title(self.graph_title)
def set_ax_lims(self, ax: mpl.axes.Axes, xlims: tuple = None, ylims: tuple = None, yshade: list = None):
    """Apply optional axis limits and grey vertical shading to ``ax``.

    Keyword arguments:
    ax -- matplotlib.axes.Axes object to modify
    xlims -- (xmin, xmax) passed to ``ax.set_xlim``; skipped when None
    ylims -- (ymin, ymax) passed to ``ax.set_ylim``; skipped when None
    yshade -- iterable of (start, end) x-intervals. Despite the name, each one
              is shaded over the full axes height with ``ax.axvspan`` (grey,
              alpha 0.75); skipped when None.
    """
    for limits, apply_limits in ((xlims, ax.set_xlim), (ylims, ax.set_ylim)):
        if limits is not None:
            apply_limits(limits)

    if yshade is None:
        return
    for window in yshade:
        ax.axvspan(window[0], window[1], color='grey', alpha=0.75)
def plot_astrometric_residuals(ax: matplotlib.axes.Axes, xs: np.ndarray, ys: np.ndarray) -> None:
    """
    Draw a quiver plot of astrometric residual vectors at a set of positions.

    Parameters
    ----------
    ax:
        Matplotlib axis in which to plot.
    xs:
        Array of positions. Column 0 is drawn on the RA axis and column 1 on
        the Dec axis.
    ys:
        Array of residual vectors. Columns 0 and 1 are the x- and
        y-components of the arrow placed at the matching row of ``xs``.

    Returns
    -------
    None
    """
    arrows = ax.quiver(
        xs[:, 0], xs[:, 1], ys[:, 0], ys[:, 1],
        scale=1,
        scale_units='xy',
        angles='uv',
        pivot='middle',
        width=0.002,
        headlength=5,
        headwidth=3,
        headaxislength=4,
        minlength=0,
        alpha=1,
        color='#001146',
    )
    # Reference arrow of length 0.1 in data units, placed at (0.0, 1.8) in data coordinates.
    ax.quiverkey(arrows, 0.0, 1.8, 0.1, 'residual = 0.1 arcsec', coordinates='data',
                 labelpos='N', color='darkred', labelcolor='darkred')

    ax.set_xlabel('RA [degrees]')
    ax.set_ylabel('Dec [degrees]')
    ax.set_xlim(-1.95, 1.95)
    ax.set_ylim(-1.9, 2.0)
    ax.set_aspect('equal')
def plot_topk_cost(ax: mpl.axes.Axes, experiment_name: str, eval_metric: str, pool_size: int,
                   plot_kwargs: Optional[Dict[str, Any]] = None) -> mpl.axes.Axes:
    """
    Plot one cost curve per method in COST_METHOD_NAME_DICT (cf. Figure 2 in [CITE PAPER]).

    Parameters
    ===
    ax: mpl.axes.Axes. Axes to draw into.
    experiment_name: str. Experimental results were written to files under a
        directory named using experiment_name.
    eval_metric: str. Takes value from ['avg_num_agreement', 'mrr']
    pool_size: int. Total size of pool from which samples were drawn; the x-axis
        shows the fraction of the pool (evaluation index * LOG_FREQ / pool_size).
    plot_kwargs : dict, optional. Keyword arguments passed to the plot, on top
        of DEFAULT_PLOT_KWARGS.

    Returns
    ===
    ax : The same Axes, with the curves drawn.
    """
    # None instead of a shared mutable {} default; the dict is only read anyway.
    _plot_kwargs = DEFAULT_PLOT_KWARGS.copy()
    _plot_kwargs.update(plot_kwargs or {})

    for method in COST_METHOD_NAME_DICT:
        metric_eval = np.load(
            RESULTS_DIR + experiment_name + ('/%s_%s_top1_pseudocount1.0.npy' % (method, eval_metric)))
        x = np.arange(len(metric_eval)) * LOG_FREQ / pool_size
        ax.plot(x, metric_eval, label=COST_METHOD_NAME_DICT[method], **_plot_kwargs)

    # The x-range comes from the last curve loaded; presumably all methods have
    # the same length -- TODO confirm.
    cutoff = len(metric_eval) - 1
    ax.set_xlim(0, cutoff * LOG_FREQ / pool_size)
    ax.set_ylim(0, 1.0)

    xmin, xmax = ax.get_xlim()
    # Dividing by slightly more than 4 keeps xmax inside np.arange's half-open range.
    step = ((xmax - xmin) / 4.0001)
    ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1))
    ax.xaxis.set_ticks(np.arange(xmin, xmax + 0.001, step))
    ax.yaxis.set_ticks(np.arange(0, 1.01, 0.20))
    ax.tick_params(pad=0.25, length=1.5)
    return ax
def set_bbox(
    ax: matplotlib.axes.Axes,
    bbox: Tuple[float, float, float, float],
    scale: float = 1000,
) -> None:
    """
    Set the x/y limits of an axes object from a bounding box given in metres.

    Arguments
    ---------
    ax : matplotlib.axes.Axes
        Axes object to be adjusted.
    bbox : Tuple[float, float, float, float]
        Boundary limits ordered (xmin, ymin, xmax, ymax); unit m.
    scale: float
        Every limit is divided by this value: 1 for axes in m, 1000 for km.
    """
    left, right = bbox[0] / scale, bbox[2] / scale
    bottom, top = bbox[1] / scale, bbox[3] / scale
    ax.set_xlim(xmin=left, xmax=right)
    ax.set_ylim(ymin=bottom, ymax=top)
def plot_step_analyzer(axes: matplotlib.axes.Axes, result: dict, title: str, legends: list, colorset: int):
    """Plot step-response traces together with their analysis markers.

    Each trace ``result['states'][i]`` is plotted against ``result['times']``.
    When its reference is non-zero, the following horizontal lines are added:

    * the reference (dashed);
    * reference +/- threshold (dash-dot);
    * reference + overshoot (dotted).

    A dashed vertical line marks a positive rise time and a dash-dot vertical
    line marks a positive settling time. The x-axis is clipped to the time
    range, and ``tight_layout`` is called on the parent figure.

    Parameters
    ----------
    axes : matplotlib.axes.Axes
        Axes to draw into.
    result : dict
        Needs 'times', 'states', 'references', 'thresholds', 'overshoots',
        'risetimes' and 'settletimes'. All except 'times' are indexed per state.
    title : str
        Axes title; skipped when falsy.
    legends : list
        One label per state; no labels or legend when falsy.
    colorset : int
        Index into the built-in palettes. Colours are reused cyclically when
        there are more states than colours in the palette.
    """
    colors = [
        ('red', 'green', 'blue'),  # TODO: add more colors
        ('tomato', 'lightgreen', 'steelblue'),
    ]
    palette = colors[colorset]
    n = len(result['states'])
    for i in range(n):
        # Cycle instead of raising IndexError when there are > len(palette) states.
        color = palette[i % len(palette)]
        label_kwargs = {'label': legends[i]} if legends else {}
        axes.plot(result['times'], result['states'][i], color=color, **label_kwargs)

        reference = result['references'][i]
        if reference != 0.0:
            axes.axhline(reference, linestyle='--', color=color)
            axes.axhline(reference - result['thresholds'][i], linestyle='-.', color=color)
            axes.axhline(reference + result['thresholds'][i], linestyle='-.', color=color)
            axes.axhline(reference + result['overshoots'][i], linestyle=':', color=color)
        if result['risetimes'][i] > 0.0:
            axes.axvline(result['risetimes'][i], linestyle='--', color=color)
        if result['settletimes'][i] > 0.0:
            axes.axvline(result['settletimes'][i], linestyle='-.', color=color)

    if legends:
        axes.legend()
    if title:
        axes.set_title(title)
    axes.set_xlim(result['times'][0], result['times'][-1])
    axes.figure.tight_layout()
def plot_ic(
        ax: matplotlib.axes.Axes,
        ic: IonChromatogram,
        minutes: bool = False,
        **kwargs,
) -> List[Line2D]:
    """
    Plots an Ion Chromatogram.

    :param ax: The axes to plot the IonChromatogram on.
    :param ic: Ion chromatogram m/z channels for plotting.
    :param minutes: Whether the x-axis should be plotted in minutes. Default :py:obj:`False` (plotted in seconds)
    :no-default minutes:

    :Other Parameters: :class:`matplotlib.lines.Line2D` properties.
        Used to specify properties like a line label (for auto legends),
        linewidth, antialiasing, marker face color.

        .. code-block:: python

            >>> plot_ic(ax, im.get_ic_at_index(5), label='IC @ Index 5', linewidth=2)

        See :class:`matplotlib.lines.Line2D` for the list of possible keyword arguments.

    :return: A list of Line2D objects representing the plotted data.
    """
    if not isinstance(ic, IonChromatogram):
        raise TypeError("'ic' must be an IonChromatogram")

    # Seconds are the native unit; convert only when minutes were requested.
    times = [t / 60 for t in ic.time_list] if minutes else ic.time_list
    lines = ax.plot(times, ic.intensity_array, **kwargs)

    # Fit the x-axis to the plotted time span and anchor intensity at zero.
    ax.set_xlim(min(times), max(times))
    ax.set_ylim(bottom=0)
    return lines
def plot_ece_samples(ax: mpl.axes.Axes, ground_truth_ece: float, frequentist_ece,
                     samples_posterior: np.ndarray,
                     plot_kwargs: Optional[Dict[str, Any]] = None) -> mpl.axes.Axes:
    """
    Compare frequentist and Bayesian (posterior-sample) ECE estimates with the ground truth.

    :param ax: Axes to draw into.
    :param ground_truth_ece: float; drawn as a black vertical line.
    :param frequentist_ece: scalar or array. A scalar (any 0-d value, including
        ints and numpy scalars) is drawn as a blue vertical line; an array is
        drawn as a blue histogram.
    :param samples_posterior: posterior ECE samples, drawn as a red histogram.
    :param plot_kwargs: extra keyword arguments (on top of DEFAULT_PLOT_KWARGS)
        passed to every axvline/hist call.
    :return: the same Axes, with the x-axis fixed to [0, 0.3].
    """
    # None instead of a shared mutable {} default; the dict is only read anyway.
    _plot_kwargs = DEFAULT_PLOT_KWARGS.copy()
    _plot_kwargs.update(plot_kwargs or {})
    # np.ndim treats int, np.float32 and 0-d arrays as scalars too;
    # isinstance(..., float) sent those to hist().
    if np.ndim(frequentist_ece) == 0:
        ax.axvline(x=frequentist_ece, label='Frequentist', color='blue', **_plot_kwargs)
    else:
        ax.hist(frequentist_ece, color='blue', alpha=0.7, label='Frequentist', **_plot_kwargs)
    ax.hist(samples_posterior, color='red', label='Bayesian', alpha=0.7, **_plot_kwargs)
    ax.axvline(x=ground_truth_ece, label='Ground truth', color='black', **_plot_kwargs)
    ax.set_xlim(0, 0.3)
    ax.set_xticks([0.0, 0.1, 0.2, 0.3])
    return ax
def apply(self, axes: matplotlib.axes.Axes, figure: matplotlib.figure.Figure):
    """Apply the stored grid, log-scale, axis-limit and DPI settings.

    The log scales are switched first, and only then are the current limits
    read. Each of ``self.xmin/xmax/ymin/ymax`` that is None keeps the value
    the axes report at that point. DPI is set only when ``self.dpi`` is truthy
    and ``figure`` is not None.
    """
    axes.grid(self.grid)
    for enabled, set_scale in ((self.logx, axes.set_xscale), (self.logy, axes.set_yscale)):
        if enabled:
            set_scale("log")

    def choose(override, current):
        # Stored setting wins; None means "keep what the axes already have".
        return current if override is None else override

    cur_xmin, cur_xmax = axes.get_xlim()
    cur_ymin, cur_ymax = axes.get_ylim()
    axes.set_xlim(xmin=choose(self.xmin, cur_xmin), xmax=choose(self.xmax, cur_xmax))
    axes.set_ylim(ymin=choose(self.ymin, cur_ymin), ymax=choose(self.ymax, cur_ymax))

    if self.dpi and figure is not None:
        figure.set_dpi(self.dpi)
def plot_testcount_forecast(
    result: pandas.Series,
    m: preprocessing.fbprophet.Prophet,
    forecast: pandas.DataFrame,
    considered_holidays: preprocessing.NamedDates,
    *,
    ax: matplotlib.axes.Axes=None
) -> matplotlib.axes.Axes:
    """
    Helper function for plotting the detailed testcount forecasting result.

    Parameters
    ----------
    result : pandas.Series
        the date-indexed series of smoothed/predicted testcounts
    m : fbprophet.Prophet
        the prophet model
    forecast : pandas.DataFrame
        contains the prophet model prediction
    considered_holidays : preprocessing.NamedDates
        the holidays that were used in the model; drawn as vertical lines
    ax : optional, matplotlib.axes.Axes
        an existing subplot to use; a new 13.4x6 figure is created when None

    Returns
    -------
    ax : matplotlib.axes.Axes
        the (created) subplot that was plotted into
    """
    if ax is None:
        _, ax = pyplot.subplots(figsize=(13.4, 6))

    # Only show the forecast from the first training date onwards.
    m.plot(forecast[forecast.ds >= m.history.set_index('ds').index[0]], ax=ax)
    ax.set_ylim(bottom=0)
    ax.set_xlim(pandas.to_datetime('2020-03-01'))
    plot_vlines(ax, considered_holidays, alignment='bottom')
    ax.legend(frameon=False, loc='upper left', handles=[
        # Empty artists only serve as legend proxies for prophet's own styling.
        ax.scatter([], [], color='black', label='training data'),
        ax.plot([], [], color='blue', label='prediction')[0],
        # This one really draws the result series.
        ax.plot(result.index, result.values, color='orange', label='result')[0],
    ])
    ax.set_ylabel('total tests')
    ax.set_xlabel('')
    return ax
def plot_ic(ax: matplotlib.axes.Axes, ic: IonChromatogram, minutes: bool = False, **kwargs) -> List[Line2D]:
    """
    Plots an Ion Chromatogram

    :param ax: The axes to plot the IonChromatogram on
    :type ax: matplotlib.axes.Axes
    :param ic: Ion Chromatograms m/z channels for plotting
    :type ic: pyms.IonChromatogram.IonChromatogram
    :param minutes: Whether the x-axis should be plotted in minutes. Default False (plotted in seconds)
    :type minutes: bool, optional

    :Other Parameters: :class:`matplotlib.lines.Line2D` properties.
        Used to specify properties like a line label (for auto legends),
        linewidth, antialiasing, marker face color.

        Example::

        >>> plot_ic(ax, im.get_ic_at_index(5), label='IC @ Index 5', linewidth=2)

        See https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.lines.Line2D.html
        for the list of possible kwargs

    :return: A list of Line2D objects representing the plotted data.
    :rtype: list of :class:`matplotlib.lines.Line2D`

    :raises TypeError: if ``ic`` is not an IonChromatogram.
    """
    if not isinstance(ic, IonChromatogram):
        raise TypeError("'ic' must be an IonChromatogram")

    time_list = ic.time_list
    if minutes:
        time_list = [time / 60 for time in time_list]

    plot = ax.plot(time_list, ic.intensity_array, **kwargs)

    # Set axis ranges from the *plotted* times (previously the seconds-based
    # ic.time_list was used even when plotting in minutes).
    ax.set_xlim(min(time_list), max(time_list))
    ax.set_ylim(bottom=0)

    return plot
def equal_axlim(axs: mpl.axes.Axes, mode: str = 'union') -> None:
    """Make x/y axes limits the same.

    Parameters
    ----------
    axs : mpl.axes.Axes
        `Axes` instance whose limits are to be adjusted.
    mode : str
        How do we adjust the limits? Options:
            'union'
                Limits include old ranges of both x and y axes, *default*.
            'intersect'
                Limits only include values in both ranges.
            'x'
                Set y limits to x limits.
            'y'
                Set x limits to y limits.

    Raises
    ------
    ValueError
        If `mode` is not one of the options above.

    Notes
    -----
    Limits are combined element-wise as (lower, upper) pairs. For 'intersect'
    with non-overlapping ranges the result has lower > upper, which inverts
    both axes.
    """
    xlim = axs.get_xlim()
    ylim = axs.get_ylim()
    modes = {
        'union': (min(xlim[0], ylim[0]), max(xlim[1], ylim[1])),
        'intersect': (max(xlim[0], ylim[0]), min(xlim[1], ylim[1])),
        'x': xlim,
        'y': ylim,
    }
    if mode not in modes:
        raise ValueError(f"Unknown mode '{mode}'. Should be one of: "
                         "'union', 'intersect', 'x', 'y'.")
    new_lim = modes[mode]
    axs.set_xlim(new_lim)
    axs.set_ylim(new_lim)
def floor_plan(
    ax: mpl.axes.Axes,
    lattice: Lattice,
    *,
    start_angle: float = 0,
    labels: bool = True,
) -> mpl.axes.Axes:
    """Draw a top-view floor plan of ``lattice`` onto ``ax`` and return ``ax``.

    Runs of consecutive equal elements in ``lattice.sequence`` (grouped with
    ``groupby``) are drawn as one piece with the combined length:

    * Dipoles are drawn as arcs (``patches.Arc``) and turn the running
      heading by ``k0 * length``.
    * All other elements are drawn as straight segments (``patches.PathPatch``).
    * Drifts are thin black lines; other elements are thick (width 6) in
      ``ELEMENT_COLOR[type(element)]``.

    When ``labels`` is True, dipoles and quadrupoles are annotated with their
    name, alternating sides of the beam line. The axes use an equal aspect
    ratio, and the limits are fitted to the segment end points plus a 1% margin.

    Parameters
    ----------
    ax : mpl.axes.Axes
        Axes to draw into.
    lattice : Lattice
        Lattice whose ``sequence`` is drawn.
    start_angle : float, keyword-only
        Initial heading in radians (0 points along +x).
    labels : bool, keyword-only
        Whether to annotate dipoles/quadrupoles with ``element.name``.
    """
    ax.set_aspect("equal")
    codes = Path.MOVETO, Path.LINETO
    current_angle = start_angle  # running beam heading in radians
    start = np.zeros(2)
    end = np.zeros(2)
    # Bounding box of all segment end points, used for the final limits.
    x_min = y_min = 0
    x_max = y_max = 0
    sign = 1  # flipped per label so labels alternate sides
    for element, group in groupby(lattice.sequence):
        start = end.copy()
        # Consecutive identical elements are merged into one drawn piece.
        length = element.length * sum(1 for _ in group)
        if isinstance(element, Drift):
            color = Color.BLACK
            line_width = 1
        else:
            color = ELEMENT_COLOR[type(element)]
            line_width = 6

        # TODO: refactor current angle
        angle = 0  # bend angle of this piece; stays 0 for non-dipoles
        if isinstance(element, Dipole):
            angle = element.k0 * length
            # NOTE(review): divides by zero if a dipole has k0 == 0 -- confirm
            # that cannot happen.
            radius = length / angle
            # Chord of the arc in the element's local frame, then rotated by
            # the current heading.
            vec = radius * np.array([np.sin(angle), 1 - np.cos(angle)])
            sin = np.sin(current_angle)
            cos = np.cos(current_angle)
            rot = np.array([[cos, -sin], [sin, cos]])
            end += rot @ vec

            # Circle centre lies a distance `radius` perpendicular (+90 deg)
            # to the current heading.
            angle_center = current_angle + 0.5 * np.pi
            center = start + radius * np.array(
                [np.cos(angle_center), np.sin(angle_center)])
            # NOTE(review): for negative bends radius, and hence diameter, is
            # negative -- confirm patches.Arc handles negative width/height
            # as intended.
            diameter = 2 * radius
            arc_angle = -90  # rotation of the Arc patch, in degrees
            theta1 = current_angle * 180 / np.pi
            theta2 = (current_angle + angle) * 180 / np.pi
            if angle < 0:
                # Arc draws counter-clockwise from theta1 to theta2.
                theta1, theta2 = theta2, theta1

            line = patches.Arc(
                center,
                width=diameter,
                height=diameter,
                angle=arc_angle,
                theta1=theta1,
                theta2=theta2,
                color=color,
                linewidth=line_width,
            )
            current_angle += angle
        else:
            end += length * np.array(
                [np.cos(current_angle), np.sin(current_angle)])
            line = patches.PathPatch(Path((start, end), codes),
                                     color=color,
                                     linewidth=line_width)

        # Only end points are tracked, so an arc bulging past its end points
        # may be clipped by the final limits.
        x_min = min(x_min, end[0])
        y_min = min(y_min, end[1])
        x_max = max(x_max, end[0])
        y_max = max(y_max, end[1])
        ax.add_patch(line)

        # TODO: currently splitted elements get drawn twice
        if labels and isinstance(element, (Dipole, Quadrupole)):
            # Perpendicular to the heading at the middle of the piece.
            angle_center = (current_angle - 0.5 * angle) + 0.5 * np.pi
            sign = -sign
            # Offset 0.5 (data units) from the chord midpoint, alternating side.
            center = 0.5 * (start + end) + 0.5 * sign * np.array(
                [np.cos(angle_center), np.sin(angle_center)])
            ax.annotate(
                element.name,
                xy=center,
                fontsize=6,
                ha="center",
                va="center",
                # rotation=(current_angle * 180 / np.pi -90) % 180,
                annotation_clip=False,
                zorder=11,
            )

    # 1% of the larger extent on every side.
    margin = 0.01 * max((x_max - x_min), (y_max - y_min))
    ax.set_xlim(x_min - margin, x_max + margin)
    ax.set_ylim(y_min - margin, y_max + margin)
    return ax
def plot(
    self,
    x_label: str = "Mean of methods",
    y_label: str = "Difference between methods",
    graph_title: str = None,
    reference: bool = False,
    xlim: Tuple = None,
    ylim: Tuple = None,
    color_mean: str = "#008bff",
    color_loa: str = "#FF7000",
    color_points: str = "#000000",
    point_kws: Dict = None,
    ci_alpha: float = 0.2,
    loa_linestyle: str = "--",
    ax: matplotlib.axes.Axes = None,
):
    """Provide a method comparison using Bland-Altman plotting.

    This is an Axis-level function which will draw the Bland-Altman plot
    onto the current active Axis object unless ``ax`` is provided.

    Parameters
    ----------
    x_label : str, optional
        The label which is added to the X-axis. Defaults to
        "Mean of methods"; pass None for no label.
    y_label : str, optional
        The label which is added to the Y-axis. Defaults to
        "Difference between methods"; pass None for no label.
    graph_title : str, optional
        Title of the Bland-Altman plot.
        If None is provided, no title will be plotted.
    reference : bool, optional
        If True, a grey reference line at y=0 will be plotted in the Bland-Altman.
    xlim : list, optional
        Minimum and maximum limits for X-axis. Should be provided as list or tuple.
        If not set, matplotlib will decide its own bounds.
    ylim : list, optional
        Minimum and maximum limits for Y-axis. Should be provided as list or tuple.
        If not set, matplotlib will decide its own bounds.
    color_mean : str, optional
        Color of the mean difference line that will be plotted.
    color_loa : str, optional
        Color of the limit of agreement lines that will be plotted.
    color_points : str, optional
        Color of the individual differences that will be plotted. Not read by
        this method; the point colour comes from ``point_kws`` and
        ``self.DEFAULT_POINTS_KWS``.
    point_kws : dict of key, value mappings, optional
        Additional keyword arguments for `plt.scatter`, overriding
        ``self.DEFAULT_POINTS_KWS``.
    ci_alpha: float, optional
        Alpha value of the confidence interval.
    loa_linestyle: str, optional
        Linestyle of the limit of agreement lines. The mean line is always solid.
    ax : matplotlib Axes, optional
        Axes in which to draw the plot, otherwise use the currently-active
        Axes.

    Returns
    -------
    ax : matplotlib Axes
        Axes object with the Bland-Altman plot.
    """
    ax = ax or plt.gca()

    pkws = self.DEFAULT_POINTS_KWS.copy()
    pkws.update(point_kws or {})

    # Get parameters
    mean, mean_CI = self.result["mean"], self.result["mean_CI"]
    loa_upper, loa_upper_CI = self.result["loa_upper"], self.result["loa_upper_CI"]
    loa_lower, loa_lower_CI = self.result["loa_lower"], self.result["loa_lower_CI"]
    sd_diff = self.result["sd_diff"]

    # individual points
    ax.scatter(self.mean, self.diff, **pkws)

    # mean difference (solid, as documented: loa_linestyle applies to the
    # limits of agreement only) and limits-of-agreement lines
    ax.axhline(mean, color=color_mean, linestyle="-")
    ax.axhline(loa_upper, color=color_loa, linestyle=loa_linestyle)
    ax.axhline(loa_lower, color=color_loa, linestyle=loa_linestyle)

    if reference:
        ax.axhline(0, color="grey", linestyle="-", alpha=0.4)

    # confidence intervals (if requested)
    if self.CI is not None:
        ax.axhspan(*mean_CI, color=color_mean, alpha=ci_alpha)
        ax.axhspan(*loa_upper_CI, color=color_loa, alpha=ci_alpha)
        ax.axhspan(*loa_lower_CI, color=color_loa, alpha=ci_alpha)

    # Labels: x in axes coordinates (0.98 = near the right edge), y in data coordinates.
    trans: transforms.Transform = transforms.blended_transform_factory(
        ax.transAxes, ax.transData)
    # Gap between each line and its labels: 1.2% of the full LoA width (2 * loa * sd_diff).
    offset: float = (((self.loa * sd_diff) * 2) / 100) * 1.2
    text_labels = [
        (mean + offset, "Mean", "bottom"),
        (mean - offset, f"{mean:.2f}", "top"),
        (loa_upper + offset, f"+{self.loa:.2f} SD", "bottom"),
        (loa_upper - offset, f"{loa_upper:.2f}", "top"),
        (loa_lower - offset, f"-{self.loa:.2f} SD", "top"),
        (loa_lower + offset, f"{loa_lower:.2f}", "bottom"),
    ]
    for y, text, va in text_labels:
        ax.text(0.98, y, text, ha="right", va=va, transform=trans)

    # transform graphs
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)

    # set X and Y limits
    if xlim is not None:
        ax.set_xlim(xlim[0], xlim[1])
    if ylim is not None:
        ax.set_ylim(ylim[0], ylim[1])

    # graph labels (None -> empty text)
    ax.set(xlabel=x_label, ylabel=y_label, title=graph_title)

    return ax
def plot_topk_accuracy(ax: mpl.axes.Axes, experiment_name: str, topk: int, eval_metric: str, pool_size: int,
                       threshold: float, plot_kwargs: Optional[Dict[str, Any]] = None,
                       plot_informed: bool = False) -> mpl.axes.Axes:
    """
    Plot mean top-k accuracy curves per method (cf. Figure 2 in [CITE PAPER]).

    Parameters
    ===
    ax: mpl.axes.Axes. Axes to draw into.
    experiment_name: str. Experimental results were written to files under a
        directory named using experiment_name.
    topk: int. With topk == 1 labels come from METHOD_NAME_DICT (or the
        informed/uninformed names); otherwise from TOPK_METHOD_NAME_DICT.
    eval_metric: str. Takes value from ['avg_num_agreement', 'mrr']
    pool_size: int. Total size of pool from which samples were drawn; the x-axis
        shows the fraction of the pool (evaluation index * LOG_FREQ / pool_size).
    threshold: float. The x-range is cut at 1.2x the first evaluation (from index
        10 on) where the benchmark curve exceeds this value; the full range is
        kept if it never does.
    plot_kwargs : dict, optional. Keyword arguments passed to the plot, on top
        of DEFAULT_PLOT_KWARGS.
    plot_informed: bool. Compare informative vs. uninformative TS instead of
        the default set of methods.

    Returns
    ===
    ax : The same Axes, with the curves drawn.
    """
    # None instead of a shared mutable {} default; the dict is only read anyway.
    _plot_kwargs = DEFAULT_PLOT_KWARGS.copy()
    _plot_kwargs.update(plot_kwargs or {})

    if plot_informed:
        benchmark = 'ts_informed'
        method_list = {
            'ts_informed': 'TS (informative)',
            'ts_uniform': 'TS (uninformative)',
        }
    else:
        benchmark = 'ts_uniform'
        # A tuple (not a set) so the plotting/legend order is deterministic.
        method_list = (
            'non-active_no_prior',
            'ts_uniform',
            'epsilon_greedy_no_prior',
            'bayesian_ucb_no_prior',
        )

    for method in method_list:
        # NOTE(review): no path separator is inserted after experiment_name;
        # presumably it already ends with '/' -- confirm.
        metric_eval = np.load(RESULTS_DIR + experiment_name + ('%s_%s.npy' % (eval_metric, method))).mean(axis=0)
        x = np.arange(len(metric_eval)) * LOG_FREQ / pool_size
        if topk == 1:
            label = method_list[method] if plot_informed else METHOD_NAME_DICT[method]
        else:
            label = TOPK_METHOD_NAME_DICT[method]
        ax.plot(x, metric_eval, label=label, color=COLOR[method], **_plot_kwargs)

        if method == benchmark:
            last = len(metric_eval) - 1
            # First evaluation (ignoring the first 10) where the benchmark
            # exceeds the threshold; None when it never does there (the old
            # .index(True) raised ValueError if it only did so before index 10).
            crossing = next(
                (i for i, value in enumerate(metric_eval[10:], start=10) if value > threshold), None)
            cutoff = last if crossing is None else min(int(crossing * 1.2), last)

    ax.set_xlim(0, cutoff * LOG_FREQ / pool_size)
    ax.set_ylim(0, 1.0)
    xmin, xmax = ax.get_xlim()
    # Dividing by slightly more than 4 keeps xmax inside np.arange's half-open range.
    step = ((xmax - xmin) / 4.0001)
    ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1))
    ax.xaxis.set_ticks(np.arange(xmin, xmax + 0.001, step))
    ax.yaxis.set_ticks(np.arange(0, 1.01, 0.20))
    ax.tick_params(pad=0.25, length=1.5)
    return ax