Example #1
0
    def add_pressure_ticks(self,
                           ax: matplotlib.axes.Axes,
                           interval: float = 1.0,
                           label: str = None):
        """Add a secondary x-axis on top of ``ax`` labelled in static pressure.

        The twin axis shares the x-range of ``ax`` (interpreted as volume) and
        places its ticks with ``PofVLocator`` so that they fall on multiples of
        ``interval`` in the static pressure P(V). The pressure function is
        built from the (volume, energy) points of
        ``self.calculator.qha_input.volumes``.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes whose x-axis is volume (presumably in Å^3, given the
            ``_from_ang3`` conversion — TODO confirm).
        interval : float
            Spacing of the major pressure ticks (presumably GPa, given
            ``_to_gpa``). Minor ticks are placed every ``interval / 5``.
            Must be positive.
        label : str, optional
            Label for the secondary (pressure) axis; omitted if empty/None.

        Raises
        ------
        ValueError
            If ``interval`` is not a positive number.
        """
        # ``not > 0`` also rejects NaN; log10 below needs a positive value.
        if not interval > 0:
            raise ValueError(f"interval must be positive, got {interval!r}")

        v_static, f_static = zip(
            *[(v_data.volume, v_data.energy)
              for v_data in self.calculator.qha_input.volumes])
        # Build the static P(V) function once; the locator and formatter
        # call ``static_p_of_v`` repeatedly and should not rebuild it each time.
        p_of_v_static = get_static_p_of_v(v_static, f_static)

        def static_p_of_v(v):
            # Convert the axis volume, evaluate P(V), convert the pressure.
            return _to_gpa(p_of_v_static(_from_ang3(v)))

        # Decimals needed to print multiples of ``interval``. ``floor`` is
        # required: ``int`` truncates toward zero, so e.g. interval=0.5 would
        # get 0 decimals instead of 1.
        ndec = max(-int(numpy.floor(numpy.log10(interval))), 0)

        # Remember pyplot's current axes; creating the twin axes may change
        # it, and it is restored at the end.
        __ax = plt.gca()

        _ax = ax.twiny()
        _ax.set_xlim(*ax.get_xlim())
        _ax.xaxis.set_major_locator(
            PofVLocator(static_p_of_v, p_interval=interval))
        _ax.xaxis.set_minor_locator(
            PofVLocator(static_p_of_v, p_interval=interval / 5))
        _ax.xaxis.set_major_formatter(PofVFormatter(static_p_of_v, ndec=ndec))

        if label:
            _ax.set_xlabel(label)

        plt.sca(__ax)
Example #2
0
def _autoscale(ax: matplotlib.axes.Axes,
               axis: str = "y",
               sides: str = "both",
               margin: float = 0.1) -> None:
    """Autoscale the x or y axis of ``ax`` to the data inside the other axis' limits.

    The limits of the *other* axis are treated as fixed (e.g. set manually).
    Every collection and line on ``ax`` is inspected (via ``_get_xy`` and
    ``_update_limts``) to find the range of the data lying within those fixed
    limits. The requested axis is then set to that range, padded on each
    adjusted side by ``margin`` times the range.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to rescale.
    axis : {"x", "y"}
        Axis whose limits are recomputed.
    sides : {"both", "min", "max"}
        Which limit(s) to change. With "min" or "max", the other limit keeps
        its current value.
    margin : float
        Padding as a fraction of the data range.

    Raises
    ------
    ValueError
        If ``axis`` or ``sides`` is not one of the allowed values.
    """
    # Explicit validation: ``assert`` statements are removed under ``python -O``.
    if axis not in ("x", "y"):
        raise ValueError(f"axis must be 'x' or 'y', got {axis!r}")
    if sides not in ("both", "min", "max"):
        raise ValueError(
            f"sides must be 'both', 'min' or 'max', got {sides!r}")

    # Everything below depends only on ``axis``; pick it once, not per artist.
    if axis == "y":
        set_lim, get_lim = ax.set_ylim, ax.get_ylim
        cur_fixed_limit = ax.get_xlim()
    else:
        set_lim, get_lim = ax.set_xlim, ax.get_xlim
        cur_fixed_limit = ax.get_ylim()

    low, high = np.inf, -np.inf
    for artist in ax.collections + ax.lines:
        if axis == "y":
            fixed, dependent = _get_xy(artist)
        else:
            dependent, fixed = _get_xy(artist)
        low, high = _update_limts(low, high, fixed, dependent, cur_fixed_limit)

    if low == np.inf and high == -np.inf:
        # No data within the fixed range: leave the limits untouched.
        return
    # Sanity check: a half-updated range would produce infinite limits.
    assert low != np.inf and high != -np.inf

    pad = margin * (high - low)
    new_min = (low - pad) if sides in ("both", "min") else get_lim()[0]
    new_max = (high + pad) if sides in ("both", "max") else get_lim()[1]
    set_lim(new_min, new_max)
def zoom_x_and_save(fig: matplotlib.figure.Figure, ax: matplotlib.axes.Axes,
                    figbase: str, plot_ext: str,
                    xzoom: List[Tuple[float, float]]) -> None:
    """
    Zoom in on subregions of the x-axis and save the figure.

    One file is written per entry of ``xzoom``, named
    ``figbase + ".sub<i>" + plot_ext`` with ``<i>`` counting from 1. The
    original x-limits are restored afterwards, even if saving fails.

    Arguments
    ---------
    fig : matplotlib.figure.Figure
        Figure to be processed.
    ax : matplotlib.axes.Axes
        Axes to be processed.
    figbase : str
        Base name of the figure to be saved.
    plot_ext : str
        File extension of the figure to be saved; appended verbatim after
        the ``.sub<i>`` suffix.
    xzoom : List[Tuple[float, float]]
        ``(xmin, xmax)`` pairs giving the x-range of each zoomed view.
    """
    xmin, xmax = ax.get_xlim()
    try:
        for index, zoom in enumerate(xzoom, start=1):
            ax.set_xlim(xmin=zoom[0], xmax=zoom[1])
            savefig(fig, figbase + ".sub" + str(index) + plot_ext)
    finally:
        # Always leave the axes as we found them.
        ax.set_xlim(xmin=xmin, xmax=xmax)
Example #4
0
def _BetterCDF(data: List[float],
               ax: matplotlib.axes.Axes):
    """Draw a descending empirical distribution curve of ``data`` on ``ax``.

    Assumes the x-limits of ``ax`` are already set to the desired
    (min, max). The curve starts at 1 at the left x-limit, drops by
    ``1 / len(data)`` at each sorted sample, and ends at 0 at the right
    x-limit.

    If the largest value(s) equal 1 (e.g. a quality score of 1), no drop is
    drawn for them. The curve instead stays at ``n_ones / len(data)`` up to
    the right x-limit. Empty ``data`` draws a single step from 1 to 0.

    NOTE(review): another ``_BetterCDF`` is defined right after this one and
    rebinds the name, so this version is not reachable through the module
    namespace. Confirm which one is intended.

    Returns
    -------
    The list of lines returned by ``ax.step``.
    """
    data = np.sort(data)
    x_axis_min, x_axis_max = ax.get_xlim()
    n_points = len(data)
    # Check n_points first so empty input does not index data[-1].
    has_quality_1_point = n_points > 0 and data[-1] == 1
    if has_quality_1_point:
        # Don't draw a drop-off for the trailing point(s) with quality 1.
        n_ones = int(np.count_nonzero(data == data[-1]))
        data = np.hstack((
            [x_axis_min],
            data[:n_points - n_ones],
            [x_axis_max],
        ))
        # True division replaces the removed ``np.float`` (NumPy >= 1.24).
        ys = np.hstack((
            [1],
            np.arange(n_points - 1, n_ones - 1, -1) / n_points,
            [n_ones / n_points],
        ))
    else:
        data = np.hstack(([x_axis_min], data, [x_axis_max]))
        ys = np.hstack((
            [1],
            np.arange(n_points - 1, -1, -1) / n_points,
            [0],
        ))
    return ax.step(data, ys, where='post')
def _BetterCDF(data_list: List[float], ax: matplotlib.axes.Axes):
    """Draw the empirical CDF of ``data_list`` as a step plot across ``ax``'s x-range.

    Assumes the x-limits of ``ax`` are already set to the desired (min, max).
    The curve is 0 from the left x-limit to the first sorted sample, rises by
    ``1 / len(data_list)`` at each sample, and stays at 1 up to the right
    x-limit.

    Returns
    -------
    Whatever ``ax.step`` returns (the list of drawn lines).
    """
    left, right = ax.get_xlim()
    ordered = np.sort(data_list)
    count = len(ordered)
    xs = np.hstack(([left], ordered, [right]))
    # Build the descending fractions framed by 1 and 0, then flip them with
    # ``1 -`` to obtain the ascending CDF.
    remaining = np.arange(count - 1, -1, -1) / count
    ys = 1 - np.hstack(([1], remaining, [0]))
    return ax.step(xs, ys, where='post')
Example #6
0
def draw_elements(
    ax: mpl.axes.Axes,
    lattice: Lattice,
    *,
    labels: bool = True,
    location: str = "top",
):
    """Draw the elements of a lattice onto a matplotlib axes.

    Consecutive equal elements of ``lattice.sequence`` are merged and drawn as
    one coloured rectangle. The colour comes from ``ELEMENT_COLOR``; element
    types without an entry are skipped. Only the part of the lattice inside
    the current x-limits is drawn, and the y-limits are extended to make room
    for the strip of rectangles.

    Parameters
    ----------
    ax : mpl.axes.Axes
        Axes to draw on; its x-axis is the position along the lattice.
    lattice : Lattice
        Lattice whose ``sequence`` is drawn.
    labels : bool
        If True, annotate ``Dipole`` and ``Quadrupole`` elements with their
        names, alternating above and below the strip.
    location : str
        ``"top"`` centres the strip on the (raised) top edge of the axes.
        Any other value places it in a band added below the data, with a
        black baseline.
    """
    x_min, x_max = ax.get_xlim()
    y_min, y_max = ax.get_ylim()
    rect_height = 0.05 * (y_max - y_min)
    if location == "top":
        y0 = y_max = y_max + rect_height
    else:
        y0 = y_min - rect_height
        y_min -= 3 * rect_height
        # Draw on ``ax`` itself: ``plt.hlines`` would target pyplot's current
        # axes, which need not be ``ax``.
        ax.hlines(y0, x_min, x_max, color="black", linewidth=1)
    ax.set_ylim(y_min, y_max)

    sign = -1  # flipped before each label so labels alternate sides
    start = end = 0
    for element, group in groupby(lattice.sequence):
        start = end
        end += element.length * sum(1 for _ in group)
        if end <= x_min:
            continue
        if start >= x_max:
            break

        try:
            color = ELEMENT_COLOR[type(element)]
        except KeyError:
            # No colour registered for this element type: not drawn.
            continue

        y0_local = y0
        if isinstance(element, Dipole) and element.angle < 0:
            # Shift negative-angle dipoles so they can be told apart.
            y0_local += rect_height / 4

        # Clip the rectangle horizontally to the visible x-range.
        ax.add_patch(
            plt.Rectangle(
                (max(start, x_min), y0_local - 0.5 * rect_height),
                min(end, x_max) - max(start, x_min),
                rect_height,
                facecolor=color,
                clip_on=False,
                zorder=10,
            ))
        if labels and type(element) in {Dipole, Quadrupole}:
            sign = -sign
            ax.annotate(
                element.name,
                xy=(0.5 * (start + end), y0 + sign * rect_height),
                fontsize=FONT_SIZE,
                ha="center",
                va="bottom" if sign > 0 else "top",
                annotation_clip=False,
                zorder=11,
            )
def zoom_xy_and_save(fig: matplotlib.figure.Figure,
                     ax: matplotlib.axes.Axes,
                     figbase: str,
                     plot_ext: str,
                     xyzoom: List[Tuple[float, float, float, float]],
                     scale: float = 1000) -> None:
    """
    Zoom in on subregions in x,y-space and save the figure.

    All zoomed views share one window size. It is the smallest window with
    the current axes' y/x span ratio that can contain every requested region.
    Each view is centred on its region and saved to
    ``figbase + ".sub<i>" + plot_ext`` (``<i>`` counting from 1). The
    original limits are restored afterwards, even if saving fails.

    Arguments
    ---------
    fig : matplotlib.figure.Figure
        Figure to be processed.
    ax : matplotlib.axes.Axes
        Axes to be processed.
    figbase : str
        Base name of the figure to be saved.
    plot_ext : str
        File extension of the figure to be saved; appended verbatim after
        the ``.sub<i>`` suffix.
    xyzoom : List[Tuple[float, float, float, float]]
        ``(xmin, xmax, ymin, ymax)`` of each region to zoom into. These values
        are divided by ``scale`` before being applied to the axes.
    scale : float
        Indicates whether the axes are in m (1) or km (1000).
    """
    xmin, xmax = ax.get_xlim()
    ymin, ymax = ax.get_ylim()

    # Keep every zoomed view at the span ratio of the full view.
    xy_ratio = (ymax - ymin) / (xmax - xmin)

    # Common window width: wide enough for the widest region, or for the
    # tallest region once converted to a width through ``xy_ratio``.
    dx_zoom = 0
    for zoom in xyzoom:
        dx = zoom[1] - zoom[0]
        dy = zoom[3] - zoom[2]
        if dy < xy_ratio * dx:
            dx_zoom = max(dx_zoom, dx)  # x range limiting
        else:
            dx_zoom = max(dx_zoom, dy / xy_ratio)  # y range limiting
    dy_zoom = dx_zoom * xy_ratio

    try:
        for index, zoom in enumerate(xyzoom, start=1):
            x0 = (zoom[0] + zoom[1]) / 2
            y0 = (zoom[2] + zoom[3]) / 2
            ax.set_xlim(xmin=(x0 - dx_zoom / 2) / scale,
                        xmax=(x0 + dx_zoom / 2) / scale)
            ax.set_ylim(ymin=(y0 - dy_zoom / 2) / scale,
                        ymax=(y0 + dy_zoom / 2) / scale)
            savefig(fig, figbase + ".sub" + str(index) + plot_ext)
    finally:
        # Always leave the axes as we found them.
        ax.set_xlim(xmin=xmin, xmax=xmax)
        ax.set_ylim(ymin=ymin, ymax=ymax)
Example #8
0
    def default_vertices(self, ax: matplotlib.axes.Axes) -> tuple:
        """
        Default to rectangle that has a quarter-width/height border.

        Returns the four corners ``((x1, y1), (x1, y2), (x2, y2), (x2, y1))``
        of a rectangle inset from the current axes limits by a quarter of the
        x-span and y-span on each side.
        """
        xlims = ax.get_xlim()
        ylims = ax.get_ylim()
        w = np.diff(xlims)
        h = np.diff(ylims)
        # True division: ``//`` floors the inset to 0 for spans smaller than 4
        # (e.g. limits (0, 1)) and is asymmetric for inverted (negative) spans.
        x1, x2 = xlims + w / 4 * np.array([1, -1])
        y1, y2 = ylims + h / 4 * np.array([1, -1])

        return ((x1, y1), (x1, y2), (x2, y2), (x2, y1))
Example #9
0
def _plot_data(ax: mpl.axes.Axes, data: PlotData) -> Optional[List[mpl.lines.Line2D]]:
    """Draw one ``PlotData`` object on ``ax`` according to its display attributes.

    Supported combinations of data type and ``data.display_attributes``:

    * ``XYData`` / ``TimeSeries`` with line, scatter, bar or filled-line
      attributes (a ``TimeSeries`` is plotted against its integer index);
    * ``TradeSet`` with scatter attributes;
    * ``TradeBarSeries`` with candlestick attributes;
    * ``BucketedValues`` with box-plot attributes;
    * ``XYZData`` with surface or contour attributes.

    Returns
    -------
    The artist created for line, scatter and bar plots, so callers can build
    legends. This is the single ``Line2D`` from ``ax.plot`` for line plots,
    otherwise whatever ``ax.scatter`` / ``ax.bar`` returned. Returns ``None``
    for the other plot types.

    Raises
    ------
    TypeError
        If the data / display-attribute combination is not supported.
        (``TypeError`` is an ``Exception`` subclass, so existing
        ``except Exception`` handlers still catch it.)
    """
    disp = data.display_attributes
    unsupported = f'unknown plot combination: {type(data)} {type(disp)}'
    lines = None  # Return line objects so we can add legends

    if isinstance(data, (XYData, TimeSeries)):
        if isinstance(data, XYData):
            x, y = data.x, data.y
        else:
            # Time series are plotted against their integer position.
            x, y = np.arange(len(data.timestamps)), data.values

        if isinstance(disp, LinePlotAttributes):
            lines, = ax.plot(x, y, linestyle=disp.line_type, linewidth=disp.line_width, color=disp.color)
            if disp.marker is not None:  # type: ignore
                ax.scatter(x, y, marker=disp.marker, c=disp.marker_color, s=disp.marker_size, zorder=100)
        elif isinstance(disp, ScatterPlotAttributes):
            lines = ax.scatter(x, y, marker=disp.marker, c=disp.marker_color, s=disp.marker_size, zorder=100)
        elif isinstance(disp, BarPlotAttributes):
            lines = ax.bar(x, y, color=disp.color)  # type: ignore
        elif isinstance(disp, FilledLinePlotAttributes):
            # Replace NaN/inf with finite numbers, then fill positive and
            # negative parts separately in their own colours.
            x, y = np.nan_to_num(x), np.nan_to_num(y)
            pos_values = np.where(y > 0, y, 0)
            neg_values = np.where(y < 0, y, 0)
            ax.fill_between(x, pos_values, color=disp.positive_color, step='post', linewidth=0.0)
            ax.fill_between(x, neg_values, color=disp.negative_color, step='post', linewidth=0.0)
        else:
            raise TypeError(unsupported)

        # For scatter and filled line, xlim and ylim does not seem to get set automatically
        if isinstance(disp, (ScatterPlotAttributes, FilledLinePlotAttributes)):
            xmin, xmax = _adjust_axis_limit(ax.get_xlim(), x)
            if not np.isnan(xmin) and not np.isnan(xmax):
                ax.set_xlim((xmin, xmax))

            ymin, ymax = _adjust_axis_limit(ax.get_ylim(), y)
            if not np.isnan(ymin) and not np.isnan(ymax):
                ax.set_ylim((ymin, ymax))

    elif isinstance(data, TradeSet) and isinstance(disp, ScatterPlotAttributes):
        lines = ax.scatter(np.arange(len(data.timestamps)), data.values, marker=disp.marker,
                           c=disp.marker_color, s=disp.marker_size, zorder=100)
    elif isinstance(data, TradeBarSeries) and isinstance(disp, CandleStickPlotAttributes):
        draw_candlestick(ax, np.arange(len(data.timestamps)), data.o, data.h, data.l, data.c, data.v,
                         data.vwap, colorup=disp.colorup, colordown=disp.colordown)
    elif isinstance(data, BucketedValues) and isinstance(disp, BoxPlotAttributes):
        draw_boxplot(
            ax, data.bucket_names, data.bucket_values, disp.proportional_widths, disp.notched,  # type: ignore
            disp.show_outliers, disp.show_means, disp.show_all)  # type: ignore
    elif isinstance(data, XYZData) and isinstance(disp, (SurfacePlotAttributes, ContourPlotAttributes)):
        display_type: str = 'contour' if isinstance(disp, ContourPlotAttributes) else 'surface'
        draw_3d_plot(ax, data.x, data.y, data.z, display_type, disp.marker, disp.marker_size,
                     disp.marker_color, disp.interpolation, disp.cmap)
    else:
        raise TypeError(unsupported)

    return lines
Example #10
0
def draw_sub_lattices(
    ax: mpl.axes.Axes,
    lattice: Lattice,
    *,
    labels: bool = True,
    location: str = "top",
):
    """Mark the boundaries of ``lattice``'s children on ``ax`` and label sub-lattices.

    The x-ticks are replaced by the child boundaries lying within the current
    x-limits, and a dashed light-gray grid is turned on. If ``labels`` is
    true, every child that is itself a ``Lattice`` and overlaps the visible
    x-range is annotated with its name, centred on its visible part. The
    label sits near the top of the axes for ``location="top"``; otherwise it
    goes in a band added below the data.
    """
    x_min, x_max = ax.get_xlim()
    # Cumulative boundaries: child i spans position_list[i] .. position_list[i + 1].
    length_gen = [0.0, *(obj.length for obj in lattice.children)]
    position_list = np.add.accumulate(length_gen)
    i_min = np.searchsorted(position_list, x_min)
    i_max = np.searchsorted(position_list, x_max, side="right")
    ticks = position_list[i_min:i_max]
    ax.set_xticks(ticks)
    ax.grid(color=Color.LIGHT_GRAY, linestyle="--", linewidth=1)

    if not labels:
        return

    y_min, y_max = ax.get_ylim()
    height = 0.08 * (y_max - y_min)
    if location == "top":
        y0 = y_max - height
    else:
        y0, y_min = y_min - height / 3, y_min - height

    ax.set_ylim(y_min, y_max)
    # Take each child's extent from the boundary array so that ``start``
    # advances for every child, including the ones skipped below. Otherwise
    # labels after a skipped child would be centred on the wrong span.
    for obj, start, end in zip(lattice.children, position_list[:-1], position_list[1:]):
        if not isinstance(obj, Lattice) or start >= x_max or end <= x_min:
            continue

        x0 = 0.5 * (max(start, x_min) + min(end, x_max))
        ax.annotate(
            obj.name,
            xy=(x0, y0),
            fontsize=FONT_SIZE + 2,
            fontstyle="oblique",
            va="center",
            ha="center",
            clip_on=True,
            zorder=102,
        )
def plot_topk_cost(ax: mpl.axes.Axes,
                   experiment_name: str,
                   eval_metric: str,
                   pool_size: int,
                   plot_kwargs: Optional[Dict[str, Any]] = None) -> mpl.axes.Axes:
    """
    Replicates Figure 2 in [CITE PAPER].

    For every method in ``COST_METHOD_NAME_DICT``, loads the saved
    ``eval_metric`` curve and plots it against the fraction of the pool
    sampled (``step * LOG_FREQ / pool_size``).

    Parameters
    ===
    ax: mpl.axes.Axes.
        Axes to draw on.
    experiment_name: str.
        Experimental results were written to files under a directory named using experiment_name.
    eval_metric: str.
        Takes value from ['avg_num_agreement', 'mrr']
    pool_size: int.
        Total size of pool from which samples were drawn.
    plot_kwargs : dict, optional.
        Keyword arguments passed to the plot, overriding DEFAULT_PLOT_KWARGS.
    Returns
    ===
    ax : The matplotlib Axes that was drawn on.
    Raises
    ===
    ValueError
        If COST_METHOD_NAME_DICT is empty. Nothing is plotted in that case,
        and there is no curve to derive the x-limit from.
    """
    # Copy so neither DEFAULT_PLOT_KWARGS nor the caller's dict is mutated.
    _plot_kwargs = DEFAULT_PLOT_KWARGS.copy()
    _plot_kwargs.update(plot_kwargs or {})

    metric_eval = None
    for method in COST_METHOD_NAME_DICT:
        metric_eval = np.load(
            RESULTS_DIR + experiment_name + ('/%s_%s_top1_pseudocount1.0.npy' % (method, eval_metric)))
        x = np.arange(len(metric_eval)) * LOG_FREQ / pool_size
        ax.plot(x, metric_eval, label=COST_METHOD_NAME_DICT[method], **_plot_kwargs)

    if metric_eval is None:
        raise ValueError("COST_METHOD_NAME_DICT is empty; nothing to plot")

    # The x-range ends at the last logged step of the last curve loaded.
    cutoff = len(metric_eval) - 1
    ax.set_xlim(0, cutoff * LOG_FREQ / pool_size)
    ax.set_ylim(0, 1.0)
    xmin, xmax = ax.get_xlim()
    ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1))
    # Five evenly spaced ticks exactly covering the limits. The previous
    # arange with an absolute +0.001 tolerance could add ticks past xmax on
    # small ranges, and set_ticks then widens the view to include them.
    ax.xaxis.set_ticks(np.linspace(xmin, xmax, 5))
    ax.yaxis.set_ticks(np.arange(0, 1.01, 0.20))
    ax.tick_params(pad=0.25, length=1.5)

    return ax
Example #12
0
    def apply(self, axes: matplotlib.axes.Axes,
              figure: matplotlib.figure.Figure):
        """Apply the stored grid, log-scale, axis-limit and DPI settings.

        Log scales are applied before the limits are read. Any limit left as
        ``None`` on this object therefore keeps the value the axes report
        after the scale change. The DPI is set only when ``self.dpi`` is
        truthy and ``figure`` is not ``None``.
        """
        axes.grid(self.grid)
        for enabled, set_scale in ((self.logx, axes.set_xscale),
                                   (self.logy, axes.set_yscale)):
            if enabled:
                set_scale("log")

        current_x = axes.get_xlim()
        current_y = axes.get_ylim()

        def _override(configured, fallback):
            # ``None`` means "not configured": keep the axes' own value.
            return fallback if configured is None else configured

        axes.set_xlim(xmin=_override(self.xmin, current_x[0]),
                      xmax=_override(self.xmax, current_x[1]))
        axes.set_ylim(ymin=_override(self.ymin, current_y[0]),
                      ymax=_override(self.ymax, current_y[1]))

        if self.dpi and figure is not None:
            figure.set_dpi(self.dpi)
Example #13
0
def plot_vlines(
    ax: matplotlib.axes.Axes,
    vlines: preprocessing.NamedDates,
    alignment: str,
) -> None:
    """ Helper function for marking special events with labeled vertical lines.

    Only dates within the current x-limits are drawn. Each one gets a dotted
    gray vertical line and a vertical (rotated 90°) gray label. Labels are
    shortened with ``textwrap.shorten`` to at most 20 characters and placed
    near the top or bottom of the axes.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        the subplot to draw into
    vlines : dict of { datetime : label }
        the dates and labels for the lines
    alignment : str
        one of { "top", "bottom" }

    Raises
    ------
    ValueError
        If ``alignment`` is not "top" or "bottom". This is checked before
        drawing, even when no date falls inside the x-limits.
    """
    ymin, ymax = ax.get_ylim()
    # The label anchor height does not depend on the date: compute it once.
    if alignment == 'top':
        y = ymin + 0.98 * (ymax - ymin)
    elif alignment == 'bottom':
        y = ymin + 0.02 * (ymax - ymin)
    else:
        raise ValueError(f"Unsupported alignment: '{alignment}'")

    xmin, xmax = ax.get_xlim()
    for x, label in vlines.items():
        # Compare in axis units (e.g. dates converted to floats).
        if not xmin <= ax.xaxis.convert_xunits(x) <= xmax:
            continue
        label = textwrap.shorten(label, width=20, placeholder="...")
        ax.axvline(x, color="gray", linestyle=":")
        ax.text(
            x, y,
            # Trailing newline: presumably to offset the rotated text from
            # the line itself.
            s=f'{label}\n',
            color="gray",
            rotation=90,
            horizontalalignment="center",
            verticalalignment=alignment,
        )
def equal_axlim(axs: mpl.axes.Axes, mode: str = 'union') -> None:
    """Make x/y axes limits the same.

    Parameters
    ----------
    axs : mpl.axes.Axes
        `Axes` instance whose limits are to be adjusted.
    mode : str
        How do we adjust the limits? Options:
            'union'
                Limits include old ranges of both x and y axes, *default*.
            'intersect'
                Limits only include values in both ranges.
            'x'
                Set y limits to x limits.
            'y'
                Set x limits to y limits.

    Raises
    ------
    ValueError
        If `mode` is not one of the options above.

    Notes
    -----
    'union' and 'intersect' combine the lower limits with each other and the
    upper limits with each other, so they assume neither axis is inverted.
    With 'intersect', non-overlapping ranges yield a lower limit above the
    upper one.
    """
    xlim = axs.get_xlim()
    ylim = axs.get_ylim()
    modes = {
        'union': (min(xlim[0], ylim[0]), max(xlim[1], ylim[1])),
        'intersect': (max(xlim[0], ylim[0]), min(xlim[1], ylim[1])),
        'x': xlim,
        'y': ylim
    }
    if mode not in modes:
        raise ValueError(f"Unknown mode '{mode}'. Should be one of: "
                         "'union', 'intersect', 'x', 'y'.")
    new_lim = modes[mode]
    axs.set_xlim(new_lim)
    axs.set_ylim(new_lim)
Example #15
0
def plot_topk_accuracy(ax: mpl.axes.Axes,
                       experiment_name: str,
                       topk: int,
                       eval_metric: str,
                       pool_size: int,
                       threshold: float,
                       plot_kwargs: Optional[Dict[str, Any]] = None,
                       plot_informed: bool = False) -> mpl.axes.Axes:
    """
    Replicates Figure 2 in [CITE PAPER].

    Plots the mean (over axis 0 of each saved array) ``eval_metric`` curve of
    several methods against the fraction of the pool sampled. The x-axis is
    cropped around the point where the benchmark curve exceeds ``threshold``.

    Parameters
    ===
    ax: mpl.axes.Axes.
        Axes to draw on.
    experiment_name: str.
        Experimental results were written to files under a directory named using experiment_name.
    topk: int.
        Chooses the legend labels. When 1, ``METHOD_NAME_DICT`` is used (or
        the informed-TS names when ``plot_informed``); otherwise
        ``TOPK_METHOD_NAME_DICT``.
    eval_metric: str.
        Takes value from ['avg_num_agreement', 'mrr']
    pool_size: int.
        Total size of pool from which samples were drawn.
    threshold: float.
        Sets the x-axis cutoff. The cutoff is 1.2 times the first logged step
        (searching from step 10 onward) at which the benchmark curve exceeds
        this value, capped at the last step. If the value is never exceeded
        there, the whole curve is shown.
    plot_kwargs : dict, optional.
        Keyword arguments passed to the plot, overriding DEFAULT_PLOT_KWARGS.
    plot_informed: bool.
        If True, compare informative vs. uninformative TS (benchmark
        'ts_informed'). Otherwise compare the non-active, TS, epsilon-greedy
        and Bayesian-UCB methods (benchmark 'ts_uniform').
    Returns
    ===
    ax : The matplotlib Axes that was drawn on.
    """
    # Copy so neither DEFAULT_PLOT_KWARGS nor the caller's dict is mutated.
    _plot_kwargs = DEFAULT_PLOT_KWARGS.copy()
    _plot_kwargs.update(plot_kwargs or {})

    if plot_informed:
        benchmark = 'ts_informed'
        method_list = {
            'ts_informed': 'TS (informative)',
            'ts_uniform': 'TS (uninformative)',
        }
    else:
        benchmark = 'ts_uniform'
        # A tuple rather than a set: plotting (and legend) order must not
        # depend on string-hash randomisation between runs.
        method_list = (
            'non-active_no_prior', 'ts_uniform', 'epsilon_greedy_no_prior',
            'bayesian_ucb_no_prior',
        )

    for method in method_list:
        # NOTE(review): unlike plot_topk_cost there is no '/' between
        # experiment_name and the file name; presumably experiment_name ends
        # with a separator here. Verify against the result layout.
        metric_eval = np.load(RESULTS_DIR + experiment_name +
                              ('%s_%s.npy' %
                               (eval_metric, method))).mean(axis=0)
        x = np.arange(len(metric_eval)) * LOG_FREQ / pool_size
        if topk != 1:
            label = TOPK_METHOD_NAME_DICT[method]
        elif plot_informed:
            label = method_list[method]
        else:
            label = METHOD_NAME_DICT[method]
        ax.plot(x,
                metric_eval,
                label=label,
                color=COLOR[method],
                **_plot_kwargs)

        if method == benchmark:
            last = len(metric_eval) - 1
            # First step (ignoring the first 10) above the threshold. The old
            # list.index(True) raised ValueError when the threshold was only
            # exceeded within those first 10 steps.
            above = np.flatnonzero(metric_eval[10:] > threshold)
            if above.size:
                cutoff = min(int((above[0] + 10) * 1.2), last)
            else:
                cutoff = last

    ax.set_xlim(0, cutoff * LOG_FREQ / pool_size)
    ax.set_ylim(0, 1.0)
    xmin, xmax = ax.get_xlim()
    ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1))
    # Five evenly spaced ticks exactly covering the limits. An arange with an
    # absolute +0.001 tolerance could place ticks past xmax on small ranges,
    # and set_ticks then widens the view to include them.
    ax.xaxis.set_ticks(np.linspace(xmin, xmax, 5))
    ax.yaxis.set_ticks(np.arange(0, 1.01, 0.20))
    ax.tick_params(pad=0.25, length=1.5)

    return ax
Example #16
0
 def plot(self, ax: mpl.axes.Axes):
     """Plot the styled object onto a matplotlib axes.

     The array from ``self.pad_to`` is sized from the current axis spans
     (absolute value, truncated to int). It is drawn with ``ax.imshow`` using
     ``self.cmap`` and the style's ``"alpha"`` (0.5 if unset).
     """
     x_lo, x_hi = ax.get_xlim()
     y_lo, y_hi = ax.get_ylim()
     x_span = int(abs(x_hi - x_lo))
     y_span = int(abs(y_hi - y_lo))
     # NOTE(review): the y-span is passed first and the x-span second, which
     # suggests pad_to takes (rows, cols) to match imshow's (y, x) layout.
     # Confirm against pad_to's signature.
     padded = self.pad_to(y_span, x_span)
     ax.imshow(padded, cmap=self.cmap, alpha=self.style.get("alpha", 0.5))
Example #17
0
    def plot(
        self,
        x_label: str = "Method 1",
        y_label: str = "Method 2",
        title: str = None,
        line_reference: bool = True,
        line_CI: bool = True,
        legend: bool = True,
        square: bool = False,
        ax: matplotlib.axes.Axes = None,
        point_kws: Optional[Dict] = None,
        color_regr: Optional[str] = None,
        alpha_regr: Optional[float] = None,
    ) -> matplotlib.axes.Axes:
        """Plot regression result

        Draws the paired measurements (``self.method1`` vs ``self.method2``)
        as a scatter plot. The fitted line ``y = intercept + slope * x`` is
        drawn across the current x-limits. When the fit result has at least
        three slope/intercept columns, the band between columns 1 and 2 is
        shaded as the confidence region.

        Parameters
        ----------
        x_label : str, optional
            The label which is added to the X-axis. (default: "Method 1")
        y_label : str, optional
            The label which is added to the Y-axis. (default: "Method 2")
        title : str, optional
            Title of the regression plot. If None is provided, no title will be plotted.
        line_reference : bool, optional
            If True, a grey dashed "Reference" line is drawn along the diagonal
            of the axes, from the lower-left to the upper-right corner in axes
            coordinates. It coincides with y=x only when the x- and y-limits
            are equal. (default: True)
        line_CI : bool, optional
            If True, dashed lines are plotted at the boundaries of the
            confidence interval, when one is available. (default: True)
        legend : bool, optional
            If True, will provide a legend containing the computed regression equation.
            (default: True)
        square : bool, optional
            If True, set the Axes aspect to "equal", so one unit on the x-axis
            has the same length as one unit on the y-axis. (default: False)
        ax : matplotlib.axes.Axes, optional
            matplotlib axis object, if not passed, uses gca()
        point_kws : Optional[Dict], optional
            Keyword arguments for the scatter plot of the points, overriding
            ``DEFAULT_POINT_KWS``.
        color_regr : Optional[str], optional
            color for regression line and CI area
        alpha_regr : Optional[float], optional
            alpha for regression CI area

        Returns
        ------------------
        matplotlib.axes.Axes
            axes object with the plot
        """
        if ax is None:
            ax = plt.gca()

        # Set scatter plot keywords to defaults and apply override
        pkws = self.DEFAULT_POINT_KWS.copy()
        pkws.update(point_kws or {})

        # Get regression parameters
        slope = self.result["slope"]
        intercept = self.result["intercept"]

        # plot individual points
        ax.scatter(self.method1, self.method2, **pkws)

        # Reference line along the axes diagonal (axes coordinates, so it
        # does not influence autoscaling).
        if line_reference:
            ax.plot(
                [0, 1],
                [0, 1],
                label="Reference",
                color="grey",
                linestyle="--",
                transform=ax.transAxes,
            )

        # Evaluate the fit at the x-limits: one row per x value, one column
        # per slope/intercept entry. Column 0 is the estimate; columns 1 and 2
        # are presumably the lower/upper confidence bounds. Verify against how
        # ``self.result`` is built.
        xvals = np.array(ax.get_xlim())
        yvals = xvals[:, None] * slope + intercept

        # Plot regression line 0
        ax.plot(
            xvals,
            yvals[:, 0],
            label=
            f"{y_label} = {intercept[0]:.2f} + {slope[0]:.2f} * {x_label}",
            color=color_regr,
            linestyle="-",
        )

        # Plot confidence region
        if yvals.shape[1] > 2:
            ax.fill_between(
                xvals,
                yvals[:, 1],
                yvals[:, 2],
                color=color_regr or self.DEFAULT_REGRESSION_KWS["color"],
                alpha=alpha_regr or self.DEFAULT_REGRESSION_KWS["alpha"],
            )
            if line_CI:
                ax.plot(xvals, yvals[:, 1], linestyle="--")
                ax.plot(xvals, yvals[:, 2], linestyle="--")

        # Set axes labels
        ax.set(
            xlabel=x_label or "",
            ylabel=y_label or "",
            title=title or "",
        )

        if legend:
            ax.legend(loc="upper left", frameon=False)

        if square:
            ax.set_aspect("equal")
        return ax