Example #1
def make_plot(summary: Dataset, ax: Axes, func: str, metric: str,
              plot_warmup: bool) -> None:

    color = build_color_dict(summary.coords["optimizer"].values.tolist())
    optimizers = summary.coords["optimizer"].values
    start = 0 if plot_warmup else 10

    for optimizer in optimizers:
        curr_ds = summary.sel({
            "function": func,
            "optimizer": optimizer,
            "objective": cc.VISIBLE_TO_OPT
        })

        if len(curr_ds.coords[cc.ITER].values) <= start:
            # Not enough trials to make a plot.
            continue

        ax.fill_between(
            curr_ds.coords[cc.ITER].values[start:],
            curr_ds[f"{metric} LB"].values[start:],
            curr_ds[f"{metric} UB"].values[start:],
            color=color[optimizer],
            alpha=0.5,
        )
        ax.plot(
            curr_ds.coords["iter"].values[start:],
            curr_ds[metric].values[start:],
            color=color[optimizer],
            label=optimizer,
        )

    ax.set_xlabel("Budget", fontsize=10)
    ax.set_ylabel(f"{metric.capitalize()} score", fontsize=10)
    ax.grid(alpha=0.2)
Example #2
def plot_profiles(loss_functions: List[List[float]],
                  ax1: Axes,
                  ax2: Axes,
                  ax4: Axes,
                  x_eor: Optional[float] = None,
                  eor_fontsize: int = 13) -> None:
    loss_functions = np.array(loss_functions)

    avg_profile = get_avg_profile(loss_functions)
    med_profile = get_med_profile(loss_functions)
    std_profile = get_std_profile(loss_functions)
    dor_profile = get_dynamic_order_profile(loss_functions)
    eor_profile = get_ends_order_profile(loss_functions)

    rack = list(range(len(avg_profile)))
    colour = next(ax2._get_lines.prop_cycler)["color"]
    ax1.plot(rack, avg_profile, color=colour)
    ax1.plot(rack, med_profile, '--', color=colour)
    ax1.fill_between(rack,
                     avg_profile + std_profile,
                     avg_profile - std_profile,
                     alpha=0.3,
                     facecolor=colour)
    ax2.plot(rack[1:], dor_profile, color=colour)
    if x_eor is not None:
        ax4.text(x_eor - 0.07,
                 0.2,
                 f"{eor_profile:.2f}",
                 color=colour,
                 fontsize=eor_fontsize)
Example #3
 def plot_delay(self, ax: Axes, sweeps: Sequence[ThresholdSweep], color):
     centers = [sweep.at_max_F1().rel_delays_median for sweep in sweeps]
     ax.plot(self.num_delays, centers, self.linestyle, color=color)
     if self.plot_IQR:
         low = [sweep.at_max_F1().rel_delays_Q1 for sweep in sweeps]
         high = [sweep.at_max_F1().rel_delays_Q3 for sweep in sweeps]
         ax.fill_between(self.num_delays, low, high, alpha=0.2, color=color)
Example #4
    def render_density(ax: Axes,
                       data,
                       color,
                       name,
                       *,
                       levels=5,
                       line_width=1,
                       alpha=0.6):
        """
        ## Render a density plot from data
        """

        # Mean line
        ln = ax.plot(data[:, 0],
                     data[:, 5],
                     lw=line_width,
                     color=color,
                     alpha=1,
                     label=name)

        # Other percentiles
        for i in range(1, levels):
            ax.fill_between(data[:, 0],
                            data[:, 5 - i],
                            data[:, 5 + i],
                            color=color,
                            lw=0,
                            alpha=alpha**i)

        return ln
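A minimal usage sketch for the render_density helper above. Hedged assumptions: the column layout (x in column 0, percentile curves in columns 1-9 with the median in column 5) is inferred from how the function indexes data, the function is callable as a plain function despite its indentation in the source, and all values below are synthetic.

import numpy as np
import matplotlib.pyplot as plt

# Synthetic percentile matrix: column 0 is x, column 5 is the median curve,
# columns 1-9 spread symmetrically around it.
x = np.linspace(0, 10, 200)
median = np.sin(x)
data = np.column_stack([x] + [median + 0.15 * (i - 5) for i in range(1, 11)])

fig, ax = plt.subplots()
render_density(ax, data, color="tab:blue", name="example", levels=5)
ax.legend()
plt.show()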
Example #5
 def shade_under_PR_curves(self, ax: Axes):
     """
     Fill area under each PR-curve with a light shade. Plot shades with
     highest AUC at the bottom (i.e. first, most hidden).
     """
     tups = zip(self.threshold_sweeps, self.colors)
     tups = sorted(tups, key=lambda tup: rank_higher_AUC_lower(tup[0]))
     for sweep, color in tups:
         fc = set_hls_values(color, l=0.95)
         ax.fill_between(sweep.recall, sweep.precision, color=fc)
Example #6
def _plot_anomalies(axes: Axes,
                    anomaly_points: AnomalyPoints,
                    selection_width: int = 2000) -> None:
    ymin, ymax = axes.get_ylim()
    selection_half = selection_width // 2

    for anomaly in (np.arange(i - selection_half, i + selection_half - 1)
                    for i in anomaly_points):
        y1 = [ymin] * len(anomaly)
        y2 = [ymax] * len(anomaly)
        axes.fill_between(anomaly, y1, y2, facecolor='g', alpha=.05)
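A short, hedged sketch of how _plot_anomalies might be driven, assuming AnomalyPoints is simply a sequence of integer sample indices into the plotted signal; the signal and indices below are synthetic.

import numpy as np
import matplotlib.pyplot as plt

signal = np.random.randn(20_000).cumsum()
anomaly_indices = [5_000, 12_500]  # hypothetical anomaly positions

fig, ax = plt.subplots()
ax.plot(signal, lw=0.5)  # sets the y-limits used for the shaded bands
_plot_anomalies(ax, anomaly_indices, selection_width=2000)
plt.show()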
Example #7
def plot_profile_quantiles(ax: Axes, dataset: xr.Dataset) -> Axes:
    """
    Plot elevation quantiles profile
    """

    x = dataset.distance.values
    profile = dataset.profile.values
    density = dataset.density.values
    count = np.sum(density, axis=1)

    parts = np.split(
        np.column_stack([x, profile]),
        np.where(count == 0)[0])

    for k, part in enumerate(parts):

        if k == 0:
            xk = part[:, 0]
            profilek = part[:, 1:]
        else:
            xk = part[1:, 0]
            profilek = part[1:, 1:]

        if xk.size > 0:

            ax.fill_between(
                xk, profilek[:, 0], profilek[:, 4],
                facecolor='#b9d8e6',
                alpha=0.2,
                interpolate=True)

            ax.plot(
                xk, profilek[:, 0], "gray", xk, profilek[:, 4],
                "gray",
                linewidth=0.5,
                linestyle='--')

            ax.fill_between(
                xk, profilek[:, 1], profilek[:, 3],
                facecolor='#48638a',
                alpha=0.5,
                interpolate=True)

            ax.plot(
                xk, profilek[:, 2],
                "#48638a",
                linewidth=1.2)

    ax.set_xlabel("Distance from reference axis (m)")

    return ax
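A hedged sketch of the xarray Dataset that plot_profile_quantiles appears to expect: "profile" holding five elevation quantiles per distance step (the code fills between columns 0/4 and 1/3 and draws column 2 as the median line) and "density" holding per-bin sample counts whose zero row-sum marks a gap in the swath. The data below is entirely synthetic.

import numpy as np
import xarray as xr
import matplotlib.pyplot as plt

x = np.linspace(0.0, 1000.0, 201)
base = 50.0 + 10.0 * np.sin(x / 150.0)
spread = np.array([-4.0, -1.5, 0.0, 1.5, 4.0])  # assumed quantile offsets
profile = base[:, None] + spread[None, :]
density = np.full((x.size, 5), 3.0)
density[90:110, :] = 0.0  # a stretch with no samples -> gap in the plot

dataset = xr.Dataset(
    {
        "profile": (("distance", "quantile"), profile),
        "density": (("distance", "quantile"), density),
    },
    coords={"distance": x},
)

fig, ax = plt.subplots()
plot_profile_quantiles(ax, dataset)
plt.show()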
Example #8
def draw_code_churn(
    axis: axes.Axes,
    project_name: str,
    commit_map: CommitMap,
    revision_selector: tp.Callable[[ShortCommitHash], bool] = lambda x: True,
    sort_df: tp.Callable[
        [pd.DataFrame],
        pd.DataFrame] = lambda data: data.sort_values(by=['time_id'])
) -> None:
    """
    Draws a churn plot onto an axis, showing insertions in green and deletions
    in red.

    Args:
        axis: axis to plot on
        project_name: name of the project to plot churn for
        commit_map: CommitMap for the given project (by project_name)
        revision_selector: takes a revision string and returns True if this rev
                           should be included
        sort_df: function that returns a sorted data frame to plot
    """
    code_churn = build_repo_churn_table(project_name, commit_map)

    code_churn = code_churn[code_churn.apply(
        lambda x: revision_selector(x['revision']), axis=1)]

    code_churn = sort_df(code_churn)

    revision_strs = code_churn.time_id.astype(
        str) + '-' + code_churn.revision.map(lambda x: x.short_hash)

    clipped_insertions = [
        x if x < CODE_CHURN_INSERTION_LIMIT else 1.3 *
        CODE_CHURN_INSERTION_LIMIT for x in code_churn.insertions
    ]
    clipped_deletions = [
        -x if x < CODE_CHURN_DELETION_LIMIT else -1.3 *
        CODE_CHURN_DELETION_LIMIT for x in code_churn.deletions
    ]

    axis.set_ylim(-CODE_CHURN_DELETION_LIMIT, CODE_CHURN_INSERTION_LIMIT)
    axis.fill_between(revision_strs, clipped_insertions, 0, facecolor='green')
    axis.fill_between(
        revision_strs,
        # deletions were negated above so they show up as negative additions
        clipped_deletions,
        0,
        facecolor='red')
Example #9
def cum_res(preds: NumArray, targets: NumArray, ax: Axes = None) -> Axes:
    """Plot the empirical cumulative distribution for the residuals (y - mu).

    Args:
        preds (array): Numpy array of predictions.
        targets (array): Numpy array of targets.
        ax (Axes, optional): matplotlib Axes on which to plot. Defaults to None.

    Returns:
        ax: The plot's matplotlib Axes.
    """
    if ax is None:
        ax = plt.gca()

    res = np.sort(preds - targets)

    n_data = len(res)

    # Plot the empirical distribution
    ax.plot(res, np.arange(n_data) / n_data * 100)

    # Fill the 90% coverage region
    # TODO may look better to add drop downs instead
    low = int(0.05 * (n_data - 1) + 0.5)
    up = int(0.95 * (n_data - 1) + 0.5)
    ax.fill_between(res[low:up], (np.arange(n_data) / n_data * 100)[low:up],
                    alpha=0.3)

    # Get robust (and symmetrical) x axis limits
    delta_low = res[low] - res[int(0.97 * low)]
    delta_up = res[int(1.03 * up)] - res[up]
    delta_max = max(delta_low, delta_up)
    lim = max(abs(res[up] + delta_max), abs(res[low] - delta_max))

    ax.set(xlim=(-lim, lim), ylim=(0, 100))

    # Add some visual guidelines
    ax.plot((0, 0), (0, 100), "--", color="grey", alpha=0.4)
    ax.plot((ax.get_xlim()[0], 0), (50, 50), "--", color="grey", alpha=0.4)

    # Label the plot
    ax.set(xlabel="Residual", ylabel="Percentile", title="Cumulative Residual")
    ax.legend(frameon=False)

    return ax
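A quick synthetic call of cum_res, assuming NumArray is an alias for a NumPy array as the docstring suggests.

import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
targets = rng.normal(size=1_000)
preds = targets + rng.normal(scale=0.3, size=1_000)  # noisy predictions

ax = cum_res(preds, targets)
plt.show()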
Example #10
    def render_density(self, ax: Axes, data, color, name, *,
                       levels=5,
                       line_width=1,
                       alpha=0.6):
        ln = ax.plot(data[:, 0], data[:, 5],
                     lw=line_width,
                     color=color,
                     alpha=1,
                     label=name)

        for i in range(1, levels):
            ax.fill_between(
                data[:, 0],
                data[:, 5 - i],
                data[:, 5 + i],
                color=color,
                lw=0,
                alpha=alpha ** i)

        return ln
Example #11
def fill(h1: Histogram1D, ax: Axes, **kwargs):
    """Fill plot of 1D histogram."""
    show_stats = kwargs.pop("show_stats", False)
    # show_values = kwargs.pop("show_values", False)
    density = kwargs.pop("density", False)
    cumulative = kwargs.pop("cumulative", False)
    kwargs["label"] = kwargs.get("label", h1.name)

    data = get_data(h1, cumulative=cumulative, density=density)
    _apply_xy_lims(ax, h1, data, kwargs)
    _add_ticks(ax, h1, kwargs)
    _add_labels(ax, h1, kwargs)

    ax.fill_between(h1.bin_centers, 0, data, **kwargs)

    if show_stats:
        _add_stats_box(h1, ax, stats=show_stats)
    # if show_values:
    #     _add_values(ax, h1, data)
    return ax
Example #12
def draw_code_churn_for_revisions(axis: axes.Axes, project_name: str,
                                  commit_map: CommitMap,
                                  revisions: tp.List[FullCommitHash]) -> None:
    """
    Draws a churn plot onto an axis, showing insertions in green and deletions
    in red.

    The churn is calculated as the diff between two successive revisions in
    the ``revisions`` list.

    Args:
        axis: axis to plot on
        project_name: name of the project to plot churn for
        commit_map: CommitMap for the given project (by project_name)
        revisions: list of revisions used to calculate the churn data
    """

    churn_data = build_revisions_churn_table(project_name, commit_map,
                                             revisions)
    clipped_insertions = [
        x if x < CODE_CHURN_INSERTION_LIMIT else 1.3 *
        CODE_CHURN_INSERTION_LIMIT for x in churn_data.insertions
    ]
    clipped_deletions = [
        -x if x < CODE_CHURN_DELETION_LIMIT else -1.3 *
        CODE_CHURN_DELETION_LIMIT for x in churn_data.deletions
    ]
    revision_strs: tp.List[str] = [rev.short_hash for rev in revisions]

    axis.set_ylim(-CODE_CHURN_DELETION_LIMIT, CODE_CHURN_INSERTION_LIMIT)
    axis.fill_between(revision_strs, clipped_insertions, 0, facecolor='green')
    axis.fill_between(
        revision_strs,
        # deletions were negated above so they show up as negative additions
        clipped_deletions,
        0,
        facecolor='red')
    tick_labels = churn_data.time_id.astype(
        str) + '-' + churn_data.revision.map(lambda x: x.short_hash)
    axis.set_xticks(axis.get_xticks())
    axis.set_xticklabels(tick_labels)
Example #13
    def annotate_between_curve(annotation: str,
                               x: np.ndarray,
                               y_lower: np.ndarray,
                               y_upper: np.ndarray,
                               ax: Axes,
                               mark_area: bool = False):
        """Annotate between the curve.

        Args:
            annotation (str): the annotation between the curve.
            x (np.ndarray): independent variable.
            y_lower (np.ndarray): lower bound of the curve.
            y_upper (np.ndarray): upper bound of the curve.
            ax (Axes): axis of the plot.
            mark_area (bool, optional): If True mark the area. Default to False.
        """
        y_diff = y_upper - y_lower
        label_index = np.argmax(y_diff)
        if label_index > 0.95 * len(x) or label_index < 0.05 * len(x):
            label_index = int(0.4 * len(x))
        label_x = x[label_index]
        label_y = 0.5 * (y_lower[label_index] + y_upper[label_index])
        if label_y < y_upper[label_index]:
            ax.text(label_x,
                    label_y,
                    annotation,
                    color='dodgerblue',
                    horizontalalignment='center',
                    verticalalignment='center',
                    size=20)

        if mark_area:
            ax.fill_between(x,
                            y_lower,
                            y_upper,
                            linestyle='--',
                            where=y_lower < y_upper,
                            edgecolor='black',
                            facecolor="none",
                            alpha=0.5)
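A minimal sketch calling annotate_between_curve as a plain function (in the source it sits indented, presumably inside a class); the curves are synthetic and chosen so the widest gap lies near the middle of the x-range.

import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 2 * np.pi, 200)
gap = 0.2 + 0.3 * np.exp(-((x - np.pi) ** 2))
y_lower = np.sin(x) - gap
y_upper = np.sin(x) + gap

fig, ax = plt.subplots()
ax.plot(x, y_lower, x, y_upper, color="dodgerblue", lw=0.8)
annotate_between_curve("widest gap", x, y_lower, y_upper, ax, mark_area=True)
plt.show()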
Example #14
def naked_MAB_regret_plot(axes: Axes,
                          xs_dict,
                          cut_time,
                          band_alpha=0.1,
                          legend=False,
                          hide_ylabel=False,
                          adjust_ymax=1,
                          hide_yticklabels=False,
                          **_kwargs):
    for i, (name, value_matrix) in enumerate(xs_dict.items()):
        mean_x = np.mean(value_matrix, axis=0)
        sd_x = np.std(value_matrix, axis=0)
        lower, upper = mean_x - sd_x, mean_x + sd_x

        time_points = sparse_index(with_default(cut_time, len(mean_x)), 200)
        axes.plot(time_points,
                  mean_x[time_points],
                  lw=1,
                  label=name.split(' ')[0] if '(TS)' in name else None,
                  color=COLORS[i],
                  linestyle='-' if '(TS)' in name else '--')
        axes.fill_between(time_points,
                          lower[time_points],
                          upper[time_points],
                          color=COLORS[i],
                          alpha=band_alpha,
                          lw=0)

    if legend:
        axes.legend(loc=2, frameon=False)
    if not hide_ylabel:
        axes.set_ylabel('Cum. Regrets')
        axes.get_yaxis().set_label_coords(-0.15, 0.5)
    if adjust_ymax != 1:
        ymin, ymax = axes.get_ylim()
        axes.set_ylim(ymin, ymax * adjust_ymax)
    if hide_yticklabels:
        axes.set_yticklabels([])
Example #15
def make_plot(
    summary: pd.DataFrame,
    ax: Axes,
    optimizer: str,
    metric: str,
    plot_warmup: bool,
    color: np.ndarray,
) -> None:

    start = 0 if plot_warmup else 10
    argpos = summary.best_mean.expanding().apply(np.argmin).astype(int)
    best_found = summary.best_mean.values[argpos.values]
    sdev = summary.best_std.values[argpos.values]

    if len(best_found) <= start:
        return

    ax.fill_between(
        np.arange(len(best_found))[start:],
        (best_found - sdev)[start:],
        (best_found + sdev)[start:],
        color=color,
        alpha=0.25,
        step="mid",
    )

    ax.plot(
        np.arange(len(best_found))[start:],
        best_found[start:],
        color=color,
        label=optimizer,
        drawstyle="steps-mid",
    )

    ax.set_xlabel("Budget", fontsize=10)
    ax.set_ylabel(f"Validation {metric.upper()}", fontsize=10)
    ax.grid(alpha=0.2)
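A hedged usage sketch for this make_plot variant, assuming summary carries per-iteration best_mean and best_std columns (names taken from the code) and that an RGB array is an acceptable color; all values are synthetic.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

rng = np.random.default_rng(1)
summary = pd.DataFrame({
    "best_mean": np.minimum.accumulate(rng.uniform(0.2, 1.0, 50)),
    "best_std": rng.uniform(0.01, 0.05, 50),
})

fig, ax = plt.subplots()
make_plot(summary, ax, optimizer="random-search", metric="rmse",
          plot_warmup=False, color=np.array([0.2, 0.4, 0.8]))
ax.legend()
plt.show()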
Example #16
def plot_model(axes: Axes, soil, fit_result: ModelResult):

    # position chi square text
    text_height = 0.02
    x_location = 0.7
    y_ORG = 0.7
    y_MIN = y_ORG - text_height
    y_UNC = y_MIN - text_height
    locations = ((x_location, y_ORG), (x_location, y_MIN), (x_location, y_UNC))
    CHI_SQUARE_LOCATION = dict(zip(SOILS, locations))

    data_kws = {
        'color': COLORS[soil],
        'marker': MARKERS[soil],
        'markersize': 12,
    }

    fit_kws = {
        'color': COLORS[soil],
        'linewidth': 3,
    }

    model = fit_result.model
    parameters = fit_result.params

    X = numpy.arange(DAYS_TO_FIT[0], DAYS_TO_FIT[-1], 1 / 24)  # 24 time points for each day in time_range
    y_fit = model.eval(t=X, params=parameters)

    lines = fit_result.plot_fit(ax=axes, numpoints=480, data_kws=data_kws, fit_kws=fit_kws)

    dely = fit_result.eval_uncertainty(sigma=1, t=X)
    axes.fill_between(X, y_fit - dely, y_fit + dely, color="#ABABAB")

    reduced_chi_square = get_chi_square(fit_result)
    x = CHI_SQUARE_LOCATION[soil][0]
    y = CHI_SQUARE_LOCATION[soil][1]
    axes.text(x, y, str(reduced_chi_square), transform=axes.transAxes)
Example #17
def best_model_plot(data_ax: Axes, model: VariationalRegressor,
                    method: str) -> Axes:
    x_plot = plot_dataset[0]
    colour = get_colour_for_method(method)
    best_mean = model.predictive_mean(x_plot, method).flatten()
    best_std = model.predictive_std(x_plot, method).flatten()
    with_prior = getattr(model, "prior_std", None) is not None
    if with_prior:
        prior_std = model.prior_std(x_plot).flatten()
        # Prior plot
        data_ax.fill_between(
            x_plot.flatten(),
            best_mean + 1.96 * prior_std,
            best_mean - 1.96 * prior_std,
            facecolor=colours["grey"],
            alpha=0.65,
            label=r"$\mu_{%s}(x) \pm 1.96 \,\sigma_{%s}(x)$" %
            (method, method),
        )

    # Model plot
    data_ax.plot(x_plot, best_mean, "-", color=colour, alpha=0.55)
    data_ax.fill_between(
        x_plot.flatten(),
        best_mean + 1.96 * best_std,
        best_mean - 1.96 * best_std,
        facecolor=colour,
        alpha=0.3,
        label=r"$\mu_{%s}(x) \pm 1.96 \,\sigma_{prior}(x)$" % method,
    )

    data_ax.legend(bbox_to_anchor=(1.05, 1),
                   loc="upper left",
                   borderaxespad=0.0)

    return data_ax
Example #18
    def plot_z_trend_histogram(self,
                               axis: Axes = None,
                               polar: bool = True,
                               normed: bool = True) -> None:

        if axis is None:
            axis = self.figure.add_subplot(111)

        cluster = Cluster(simulation_name=self.simulation.simulation_name,
                          clusterID=0,
                          redshift='z000p000')
        aperture_float = self.get_apertures(cluster)[
            self.aperture_id] / cluster.r200

        if not os.path.isfile(
                os.path.join(
                    self.path,
                    f'redshift_rot0rot4_histogram_aperture_{self.aperture_id}.npy'
                )):
            warnings.warn(
                f"File redshift_rot0rot4_histogram_aperture_{self.aperture_id}.npy not found."
            )
            print("self.make_simhist() activated.")
            self.make_simhist()

        print(
            f"Retrieving npy files: redshift_rot0rot4_histogram_aperture_{self.aperture_id}.npy"
        )
        sim_hist = np.load(os.path.join(
            self.path,
            f'redshift_rot0rot4_histogram_aperture_{self.aperture_id}.npy'),
                           allow_pickle=True)
        sim_hist = np.asarray(sim_hist)

        if normed:
            norm_factor = np.sum(self.simulation.sample_completeness)
            sim_hist[2] /= norm_factor
            sim_hist[3] /= norm_factor
            y_label = r"Sample fraction"
        else:
            y_label = r"Number of samples"

        items_labels = f""" REDSHIFT TRENDS - HISTOGRAM
							Number of clusters: {self.simulation.totalClusters:d}
							$z$ = 0.0 - 1.8
							Total samples: {np.sum(self.simulation.sample_completeness):d} $\equiv N_\mathrm{{clusters}} \cdot N_\mathrm{{redshifts}}$
							Aperture radius = {aperture_float:.2f} $R_{{200\ true}}$"""
        print(items_labels)

        sim_colors = {
            'ceagle': 'pink',
            'celr_e': 'lime',
            'celr_b': 'orange',
            'macsis': 'aqua',
        }

        axis.axvline(90, linestyle='--', color='k', alpha=0.5, linewidth=2)
        axis.step(sim_hist[0],
                  sim_hist[2],
                  color=sim_colors[self.simulation.simulation_name],
                  where='mid')
        axis.fill_between(sim_hist[0],
                          sim_hist[2] + sim_hist[3],
                          sim_hist[2] - sim_hist[3],
                          step='mid',
                          color=sim_colors[self.simulation.simulation_name],
                          alpha=0.2,
                          edgecolor='none',
                          linewidth=0)

        axis.set_ylabel(y_label, size=25)
        axis.set_xlabel(
            r"$\Delta \theta \equiv (\mathbf{L}_\mathrm{gas},\mathrm{\widehat{CoP}},\mathbf{L}_\mathrm{stars})$\quad[degrees]",
            size=25)
        axis.set_xlim(0, 180)
        axis.set_ylim(0, 0.1)
        axis.text(0.03,
                  0.97,
                  items_labels,
                  horizontalalignment='left',
                  verticalalignment='top',
                  transform=axis.transAxes,
                  size=15)

        if polar:
            inset_axis = self.figure.add_axes([0.75, 0.65, 0.25, 0.25],
                                              projection='polar')
            inset_axis.patch.set_alpha(0)  # Transparent background
            inset_axis.set_theta_zero_location('N')
            inset_axis.set_thetamin(0)
            inset_axis.set_thetamax(180)
            inset_axis.set_xticks(np.pi / 180. *
                                  np.linspace(0, 180, 5, endpoint=True))
            inset_axis.set_yticks([])
            inset_axis.step(sim_hist[0] / 180 * np.pi,
                            sim_hist[2],
                            color=sim_colors[self.simulation.simulation_name],
                            where='mid')
            inset_axis.fill_between(
                sim_hist[0] / 180 * np.pi,
                sim_hist[2] + sim_hist[3],
                sim_hist[2] - sim_hist[3],
                step='mid',
                color=sim_colors[self.simulation.simulation_name],
                alpha=0.2,
                edgecolor='none',
                linewidth=0)

        patch_ceagle = Patch(facecolor=sim_colors['ceagle'],
                             label='C-EAGLE',
                             edgecolor='k',
                             linewidth=1)
        patch_celre = Patch(facecolor=sim_colors['celr_e'],
                            label='CELR-E',
                            edgecolor='k',
                            linewidth=1)
        patch_celrb = Patch(facecolor=sim_colors['celr_b'],
                            label='CELR-B',
                            edgecolor='k',
                            linewidth=1)
        patch_macsis = Patch(facecolor=sim_colors['macsis'],
                             label='MACSIS',
                             edgecolor='k',
                             linewidth=1)

        leg2 = axis.legend(
            handles=[patch_ceagle, patch_celre, patch_celrb, patch_macsis],
            loc='lower center',
            handlelength=1,
            fontsize=20)
        axis.add_artist(leg2)
Example #19
def plot_learning_curve(
    estimator: Estimator,
    x: pd.DataFrame,
    y: DataType,
    cv: int = 5,
    scoring: str = "default",
    n_jobs: int = -1,
    train_sizes: Sequence = np.linspace(0.1, 1.0, 5),
    ax: Axes = None,
    random_state: int = None,
    title: str = "Learning Curve",
    **kwargs,
) -> Axes:
    """
    Generates a :func:`~sklearn.model_selection.learning_curve` plot,
    used to determine model performance as a function of the number of training
    examples.

    Illustrates whether or not the number of training examples is the performance
    bottleneck. Also used to diagnose underfitting or overfitting by seeing how
    training set and validation set performance differ.

    Parameters
    ----------
    estimator: sklearn-compatible estimator
        An instance of a sklearn estimator
    x: pd.DataFrame
        DataFrame of features
    y: pd.Series or np.Array
        Target values to predict
    cv: int
        Number of CV iterations to run. Uses a :class:`~sklearn.model_selection.StratifiedKFold` if
        `estimator` is a classifier - otherwise a :class:`~sklearn.model_selection.KFold` is used.
    scoring: str
        Metric to use in scoring - must be a scikit-learn compatible
        :ref:`scoring method<sklearn:scoring_parameter>`
    n_jobs: int
        Number of jobs to use in parallelizing the estimator fitting and scoring
    train_sizes: Sequence of floats
        Percentage intervals of data to use when training
    ax: plt.Axes
        The plot will be drawn on the passed ax - otherwise a new figure and ax will be created.
    random_state: int
        Random state to use in CV splitting
    title: str
        Title to be used on the plot
    kwargs: dict
        Passed along to matplotlib line plots

    Returns
    -------
    plt.Axes
    """
    if ax is None:
        fig, ax = plt.subplots()

    if scoring == "default":
        scoring = "accuracy" if is_classifier(estimator) else "r2"

    train_sizes, train_scores, test_scores = learning_curve(
        estimator,
        x,
        y,
        train_sizes=train_sizes,
        cv=cv,
        scoring=scoring,
        n_jobs=n_jobs,
        random_state=random_state,
    )

    train_scores_mean = np.mean(train_scores, axis=1)
    train_scores_std = np.std(train_scores, axis=1)
    test_scores_mean = np.mean(test_scores, axis=1)
    test_scores_std = np.std(test_scores, axis=1)

    ax.fill_between(
        train_sizes,
        train_scores_mean - train_scores_std,
        train_scores_mean + train_scores_std,
        alpha=0.1,
    )
    ax.fill_between(
        train_sizes,
        test_scores_mean - test_scores_std,
        test_scores_mean + test_scores_std,
        alpha=0.1,
    )
    ax.plot(
        train_sizes, train_scores_mean, label=f"Training {scoring.title()}", **kwargs
    )
    ax.plot(
        train_sizes,
        test_scores_mean,
        label=f"Cross-validated {scoring.title()}",
        **kwargs,
    )
    ax.legend(loc="best")
    ax.set_ylabel(f"{scoring.title()} Score")
    ax.set_xlabel("Number of Examples Used")
    ax.set_title(title)

    return ax
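A possible call pattern for plot_learning_curve, assuming the module-level sklearn/pandas imports used above are in place; the dataset and estimator here are only illustrative.

import pandas as pd
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier

data = load_breast_cancer()
x = pd.DataFrame(data.data, columns=data.feature_names)
y = data.target

ax = plot_learning_curve(RandomForestClassifier(n_estimators=50), x, y,
                         cv=3, random_state=0)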
Example #20
def plot_ridgeline(
        data: Sequence[Dict[str, Sequence[float]]],
        ax: Axes,
        colors: Sequence[Color],
        fill_alpha: float = 0.2,
        point_alpha: float = 1,
        num_bins: int = 20,
        overlap: float = 0.2,
        horizontal_extra: float = 0.2,
        jitter: Optional[float] = None,
        data_bounds: Optional[Tuple[float, float]] = None,
        fontsize: float = 36,
        ) -> None:
    '''Plot data as a ridgeline plot.

    Args:
        data: List of mappings from strings to sequences of values. One
            distribution will be plotted for each sequence of values and
            labeled with the corresponding string. Each mapping
            represents a replicate.
        ax: Axes on which to plot.
        colors: The colors to plot each replicate in.
        fill_alpha: Alpha value to control transparency of fill color.
        point_alpha: Alpha value to control transparency of point color.
        num_bins: Number of equally-sized bins into which the values
            will be grouped. The averages of the values in each bin
            determine the curve height at the middle of the bin on the
            x-axis. This is equivalent to the bins on a histogram.
        overlap: The fraction by which distributions are allowed to
            overlap. For example, with an overlap of ``0.2``, the first
            distribution would be at :math:`y=0` and the next
            distribution would be at :math:`y=0.8`. Note that since not
            all distributions will reach a height of :math:`1`, setting
            a nonzero overlap does not guarantee that any distributions
            will actually overlap.
        horizontal_extra: The fraction of the range of ``data`` which
            should be added to the negative and positive x axis. This
            adds extra space to the left and right of distributions.
        jitter: To let the viewer distinguish between multiple nearby
            points, we randomly add small values to point positions.
            ``jitter`` is the maximum absolute value of these
            perturbations, which are uniformly chosen from the range
            :math:`(-jitter, jitter)`. If ``None``, ``jitter`` will be
            set to one tenth of the maximum value in ``data``. To
            disable jitter, set ``jitter=0``.
        data_bounds: The minimum and maximum values present in the data.
            If ``None``, the range is calculated from ``data``. This may
            be useful when ``data`` is only some of the data that will
            be plotted, e.g. with replicates.
        fontsize: The size to use for all text.
    '''
    if data_bounds is None:
        flat_data = flatten([data_elem.values() for data_elem in data])
        data_min = min(flat_data)
        data_max = max(flat_data)
    else:
        data_min, data_max = data_bounds
    data_range = data_max - data_min
    if jitter is None:
        jitter = data_range / 10
    extra = data_range * horizontal_extra
    x_values = np.linspace(
        data_min - extra, data_max + extra, num_bins)
    y_values, density_curves, offset = _calculate_density_curves(
        data, x_values, overlap)
    num_replicates = len(density_curves[list(y_values.keys())[0]])
    assert len(colors) == num_replicates

    for i, y_label in enumerate(y_values):
        y = y_values[y_label]
        assert len(density_curves[y_label]) == num_replicates
        for j, density_curve in enumerate(density_curves[y_label]):
            color = colors[j]
            zorder = num_replicates - i + 1
            # Below, we cast colors to strings to satisfy mypy, which
            # doesn't understand using tuples as colors.
            ax.plot(
                x_values, density_curve + y,
                color=cast(str, color), linestyle='-', zorder=zorder)
            # Mypy incorrectly complains that Axes has no fill_between
            # attribute.
            ax.fill_between(  # type: ignore
                x_values, np.ones(num_bins) * y, density_curve + y,
                color=cast(str, color), zorder=-1, alpha=fill_alpha)
            points = np.array(
                data[j][y_label], dtype=float)  # type: ignore
            # We disable pylint's no-member check because pylint doesn't
            # recognize that np.random.random is valid.
            # pylint: disable=no-member
            points += (
                np.random.random(len(points)) - 0.5  # type: ignore
            ) * 2 * jitter
            # pylint: enable=no-member
            ax.scatter(points,  # type: ignore
                np.ones(len(points)) * y - offset,
                color=cast(str, color), marker='|', s=20,
                linewidths=0.1, zorder=zorder, alpha=point_alpha)
    ax.set_yticks(list(y_values.values()))
    ax.set_yticklabels(list(y_values.keys()))
    ax.tick_params(  # type: ignore
        axis='both', which='major', labelsize=fontsize)
Example #21
def plotWF(qcl: QCLayers,
           plotType: str = 'mode',
           fillPlot: Union[bool, float] = False,
           pickedStates: Iterable = set(),
           showPeriod: bool = True,
           axes: Axes = None):
    r"""Plot the wavefunctions of qcl.
    The wavefunctions are scaled by :func:`scaleWF`.

    Parameters
    ----------
    qcl :
        The QCLayers to plot
    plotType :
        Can be 'mode' or 'wf', determining whether it plots the mode
        (:math:`\psi^2`) or the wavefunction itself (:math:`\psi`).
    fillPlot :
        Whether to fill under the wavefunctions. If it's `False` or `None`
        nothing is filled; otherwise it should be a float smaller than 1 giving
        the transparency of the fill color.
    pickedStates :
        A set of state indices that should be plotted in thick black color.
    showPeriod :
        Whether to emphasize the recognized period of the wavefunctions.
    axes :
        The axes to plot the figure on.

    Returns
    -------
    A list of plotted data

    """
    if axes is None:
        axes = gca()
    colors = config['wf_colors']
    wfs = scaleWF(qcl, plotType)
    # filter almost zero part
    starts = np.argmax(abs(wfs) > config["wf_almost_zero"], axis=1)
    ends = np.argmax(abs(wfs[:, ::-1]) > config["wf_almost_zero"], axis=1)
    showPop = False
    if showPeriod:
        qcl.period_recognize()
        showPop = qcl.status == 'solved-full'
    if showPop:
        qcl.period_map_build()
        vmin = 0
        vmax = np.ceil(np.max(qcl.population) * 10) / 10
        popMap = cm.ScalarMappable(cmNorm(vmin=vmin, vmax=vmax), 'plasma')
    for n in range(len(qcl.eigenEs)):
        ls = '-'
        if n in pickedStates:
            color = 'k'
            lw = config['default_lw'] * 2
        else:
            if showPop:
                if qcl.periodMap[n] is not None:
                    color = popMap.to_rgba(qcl.state_population(n))
                else:
                    color = 'g'
            else:
                color = colors[n % len(colors)]
            if qcl.status == 'basis':
                lw = config['default_lw']
            else:
                if showPeriod and n in qcl.periodIdx:
                    # lw = 1 if n in qcl.unBound else 1.5
                    lw = config['default_lw']
                    if n in qcl.unBound:
                        ls = (0, (0.5, 0.5))
                else:
                    lw = config['default_lw'] / 2
        x = qcl.xPoints[starts[n]:-ends[n]]
        y = wfs[n, starts[n]:-ends[n]] + qcl.eigenEs[n]
        axes.plot(x, y, lw=lw, ls=ls, color=color)
        if fillPlot:
            axes.fill_between(x,
                              y,
                              qcl.eigenEs[n],
                              facecolor=color,
                              alpha=fillPlot)
    if showPop:
        colorbar_axes = axes.inset_axes([0.03, 0.01, 0.5, 0.02])
        axes.figure.colorbar(popMap,
                             cax=colorbar_axes,
                             orientation='horizontal',
                             label='electron population')
        colorbar_axes.xaxis.set_ticks_position('top')
        colorbar_axes.xaxis.set_label_position('top')
    return wfs
Example #22
def qq_gaussian(
    y_true: NumArray,
    y_pred: NumArray,
    y_std: NumArray | dict[str, NumArray],
    ax: Axes = None,
) -> Axes:
    """Plot the Gaussian quantile-quantile (Q-Q) plot of one (passed as array)
    or multiple (passed as dict) sets of uncertainty estimates for a single
    pair of ground truth targets `y_true` and model predictions `y_pred`.

    Overconfidence relative to a Gaussian distribution is visualized as shaded
    areas below the parity line, underconfidence (oversized uncertainties) as
    shaded areas above the parity line.

    The measure of calibration is how well the uncertainty percentiles conform
    to those of a normal distribution.

    Inspired by https://git.io/JufOz.
    Info on Q-Q plots: https://wikipedia.org/wiki/Q-Q_plot

    Args:
        y_true (array): ground truth targets
        y_pred (array): model predictions
        y_std (array | dict[str, array]): model uncertainties
        ax (Axes): matplotlib Axes on which to plot. Defaults to None.

    Returns:
        ax: The plot's matplotlib Axes.
    """
    if ax is None:
        ax = plt.gca()

    if isinstance(y_std, np.ndarray):
        y_std = {"std": y_std}

    res = np.abs(y_pred - y_true)
    resolution = 100

    lines = []  # collect plotted lines to show second legend with miscalibration areas
    for key, std in y_std.items():

        z_scored = (np.array(res) / std).reshape(-1, 1)

        exp_proportions = np.linspace(0, 1, resolution)
        gaussian_upper_bound = norm.ppf(0.5 + exp_proportions / 2)
        obs_proportions = np.mean(z_scored <= gaussian_upper_bound, axis=0)

        [line] = ax.plot(
            exp_proportions, obs_proportions, linewidth=2, alpha=0.8, label=key
        )
        ax.fill_between(
            exp_proportions, y1=obs_proportions, y2=exp_proportions, alpha=0.2
        )
        miscal_area = np.trapz(
            np.abs(obs_proportions - exp_proportions), dx=1 / resolution
        )
        lines.append([line, miscal_area])

    # identity line
    ax.axline((0, 0), (1, 1), alpha=0.5, zorder=0, linestyle="dashed", color="black")

    ax.set(xlim=(0, 1), ylim=(0, 1))
    ax.set(xlabel="Theoretical Quantile", ylabel="Observed Quantile")

    legend1 = ax.legend(loc="upper left", frameon=False)
    # Multiple legends on the same axes:
    # https://matplotlib.org/3.3.3/tutorials/intermediate/legend_guide.html#multiple-legends-on-the-same-axes
    ax.add_artist(legend1)

    lines, areas = zip(*lines)

    if len(lines) > 1:
        legend2 = ax.legend(
            lines,
            [f"{area:.2f}" for area in areas],
            title="Miscalibration areas",
            loc="lower right",
            ncol=2,
            frameon=False,
        )
        legend2._legend_box.align = "left"  # https://stackoverflow.com/a/44620643
    else:
        ax.legend(
            lines,
            [f"Miscalibration area: {areas[0]:.2f}"],
            loc="lower right",
            frameon=False,
        )

    return ax
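A synthetic example for qq_gaussian with one roughly calibrated and one overconfident uncertainty set (NumArray assumed to be a NumPy array alias; scipy.stats.norm assumed imported at module level as in the snippet).

import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(42)
y_true = rng.normal(size=500)
y_pred = y_true + rng.normal(scale=0.5, size=500)
y_std = np.full(500, 0.5)

ax = qq_gaussian(y_true, y_pred, {"calibrated": y_std,
                                  "overconfident": y_std / 3})
plt.show()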
Example #23
def plot_timeseries_histograms(axes: Axes,
                               data: pd.DataFrame,
                               bins: Union[str, int, np.ndarray,
                                           Sequence[Union[int,
                                                          float]]] = "auto",
                               colormap: Colormap = cm.Blues,
                               **plot_kwargs) -> Axes:  # pragma: no cover
    """Generate a heat-map-like plot for time-series sample data.

    The kind of input this function expects can be obtained from an
    XArray object as follows:

    .. code::

        data = az_post_trace.posterior_predictive.Y_t[chain_idx].loc[
            {"dt": slice(t1, t2)}
        ]
        data = data.to_dataframe().Y_t.unstack(level=0)

    Parameters
    ==========
    axes
        The Matplotlib axes to use for plotting.
    data
        The sample data to be plotted.  This should be in "wide" format: i.e.
        the index should be "time" and the columns should correspond to each
        sample.
    bins
        The `bins` parameter passed to ``np.histogram``.
    colormap
        The Matplotlib colormap used to show relative frequencies within bins.
    plot_kwargs
        Keywords passed to ``fill_between``.

    """
    index = data.index
    y_samples = data.values

    n_t = len(index)

    # generate histograms and bins
    list_of_hist, list_of_bins = [], []
    for t in range(n_t):
        # TODO: determine proper range=(np.min(Y_t), np.max(Y_t))
        hist, bins_ = np.histogram(y_samples[t], bins=bins, density=True)
        if np.sum(hist > 0) == 1:
            hist, bins_ = np.array([1.0]), np.array([bins_[0], bins_[-1]])
        list_of_hist.append(hist)
        list_of_bins.append(bins_)

    if axes is None:
        _, (axes) = plt.subplots(nrows=1,
                                 ncols=1,
                                 sharex=True,
                                 figsize=(12, 4))
        axes.plot(index,
                  np.mean(y_samples, axis=1),
                  alpha=0.0,
                  drawstyle="steps")

    for t in range(n_t):
        mask = index == index[t]
        hist, bins_ = list_of_hist[t], list_of_bins[t]
        # normalize bin weights for plotting
        hist = hist / np.max(hist) * 0.85 if len(hist) > 1 else hist
        n = len(hist)
        # construct predictive arrays to plot
        y_t_ = np.tile(bins_, (n_t, 1))
        # include consecutive time points to create grid-ish steps
        if t > 0:
            mask = np.logical_or(mask, index == index[t - 1])
        for i in range(n):
            color_val = hist[i]
            color = colormap(color_val) if color_val else (1, 1, 1, 1)
            plot_kwargs.setdefault("step", "pre")
            axes.fill_between(
                index,
                y_t_[:, i],
                y_t_[:, i + 1],
                where=mask,
                color=color,
                **plot_kwargs,
            )

    return axes
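A sketch of feeding plot_timeseries_histograms with synthetic "wide" sample data (time index as rows, one column per posterior sample), mirroring the preliminary mean plot the function itself makes when no axes is passed; values are random.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import cm

rng = np.random.default_rng(0)
index = pd.date_range("2021-01-01", periods=40, freq="D")
samples = np.cumsum(rng.normal(size=(40, 200)), axis=0)
data = pd.DataFrame(samples, index=index)

fig, ax = plt.subplots(figsize=(12, 4))
ax.plot(index, samples.mean(axis=1), alpha=0.0, drawstyle="steps")
plot_timeseries_histograms(ax, data, bins=15, colormap=cm.Blues)
plt.show()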
Example #24
def spacegroup_hist(
    data: Sequence[int | str] | pd.Series,
    show_counts: bool = True,
    xticks: Literal["all", "crys_sys_edges"] | int = 20,
    include_missing: bool = False,
    ax: Axes = None,
    **kwargs: Any,
) -> Axes:
    """Plot a histogram of spacegroups shaded by crystal system.

    Args:
        data (list[int | str] | pd.Series): A sequence (list, tuple, pd.Series) of
            space group strings or numbers (from 1 - 230) or pymatgen structures.
        show_counts (bool, optional): Whether to count the number of items
            in each crystal system. Defaults to True.
        xticks ('all' | 'crys_sys_edges' | int, optional): Where to add x-ticks. An
            integer will add ticks below that number of tallest bars. Defaults to 20.
            'all' will show below all bars, 'crys_sys_edges' only at the edge from one
            crystal system to another.
        include_missing (bool, optional): Whether to include 0-height bars for
            space groups missing from the data. Currently only implemented for
            numbers, not symbols. Defaults to False.
        ax (Axes, optional): matplotlib Axes on which to plot. Defaults to None.
        kwargs: Keywords passed to pd.Series.plot.bar().

    Returns:
        ax: The plot's matplotlib Axes.
    """
    if ax is None:
        ax = plt.gca()

    if isinstance(next(iter(data)), Structure):
        # if 1st sequence item is structure, assume all are
        series = pd.Series(struct.get_space_group_info()[1]
                           for struct in data  # type: ignore
                           )
    else:
        series = pd.Series(data)

    df = pd.DataFrame(series.value_counts(sort=False))
    df.columns = ["counts"]

    crys_colors = {
        "triclinic": "red",
        "monoclinic": "teal",
        "orthorhombic": "blue",
        "tetragonal": "green",
        "trigonal": "orange",
        "hexagonal": "purple",
        "cubic": "yellow",
    }

    if df.index.is_numeric():  # assume index is space group numbers
        if include_missing:
            df = df.reindex(range(1, 231), fill_value=0)
        else:
            df = df.sort_index()
        df["crystal_sys"] = [get_crystal_sys(x) for x in df.index]
        ax.set(xlim=(0, 230))
        xlabel = "International Spacegroup Number"

    else:  # assume index is space group symbols
        # TODO: figure how to implement include_missing for space group symbols
        # if include_missing:
        #     idx = [SpaceGroup.from_int_number(x).symbol for x in range(1, 231)]
        #     df = df.reindex(idx, fill_value=0)
        df["crystal_sys"] = [SpaceGroup(x).crystal_system for x in df.index]

        # sort df by crystal system going from smallest to largest spacegroup numbers
        # e.g. triclinic (1-2) comes first, cubic (195-230) last
        sys_order = dict(zip(crys_colors, range(len(crys_colors))))
        df = df.loc[df.crystal_sys.map(sys_order).sort_values().index]

        xlabel = "International Spacegroup Symbol"

    ax.set(xlabel=xlabel, ylabel="Count")

    kwargs["width"] = kwargs.get("width", 0.9)  # set default bar width
    # make plot
    df.counts.plot.bar(figsize=[16, 4], ax=ax, **kwargs)

    # https://matplotlib.org/3.1.1/gallery/lines_bars_and_markers/fill_between_demo
    trans = transforms.blended_transform_factory(ax.transData, ax.transAxes)

    # count rows per crystal system
    crys_sys_counts = df.groupby("crystal_sys").sum("counts")

    # sort by key order in dict crys_colors
    crys_sys_counts = crys_sys_counts.loc[[
        x for x in crys_colors if x in crys_sys_counts.index
    ]]

    crys_sys_counts["width"] = df.value_counts("crystal_sys")
    ax.set_title("Totals per crystal system",
                 fontdict={"fontsize": 18},
                 pad=30)
    crys_sys_counts["color"] = pd.Series(crys_colors)

    x0 = 0
    for cryst_sys, count, width, color in crys_sys_counts.itertuples():
        x1 = x0 + width

        for patch in ax.patches[0 if x0 == 1 else x0:x1 + 1]:
            patch.set_facecolor(color)

        text_kwds = dict(transform=trans, horizontalalignment="center")
        ax.text(
            *[(x0 + x1) / 2, 0.95],
            cryst_sys,
            rotation=90,
            verticalalignment="top",
            fontdict={"fontsize": 14},
            **text_kwds,
        )
        if show_counts:
            ax.text(
                *[(x0 + x1) / 2, 1.02],
                f"{count:,} ({count/len(data):.0%})",
                fontdict={"fontsize": 12},
                **text_kwds,
            )

        ax.fill_between(
            [x0 - 0.5, x1 - 0.5],
            *[0, 1],
            facecolor=color,
            alpha=0.1,
            transform=trans,
            edgecolor="black",
        )
        x0 += width

    ax.yaxis.grid(True)
    ax.xaxis.grid(False)

    if xticks == "crys_sys_edges" or isinstance(xticks, int):

        if isinstance(xticks, int):
            # get x_locs of n=xticks tallest bars
            x_indices = df.reset_index().sort_values("counts").tail(
                xticks).index
        else:
            # add x_locs at the edges between crystal systems
            x_indices = crys_sys_counts.width.cumsum()

        majorLocator = FixedLocator(x_indices)

        ax.xaxis.set_major_locator(majorLocator)
    plt.xticks(rotation=90)

    return ax
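A possible call of spacegroup_hist with a plain list of space group numbers, assuming the helpers it relies on (get_crystal_sys, FixedLocator, transforms, plt, pd) are already imported in its module; the numbers are random and only exercise the plot.

import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
spacegroups = rng.integers(1, 231, size=500).tolist()

ax = spacegroup_hist(spacegroups, show_counts=True, xticks=20)
plt.show()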
Example #25
def plot_roc(
    nR_S1: Union[list, np.ndarray],
    nR_S2: Union[list, np.ndarray],
    fitModel: Optional[dict] = None,
    ax: Axes = None,
) -> Axes:
    """Function to plot type2 ROC curve from observed an estimated data fit.

    Parameters
    ----------
    nR_S1 : 1d array-like
        Number of ratings for signal 1 (correct and incorrect).
    nR_S2 : 1d array-like
        Number of ratings for signal 2 (correct and incorrect).
    fitModel : dict or None
        Dictionary returned by :py:func:`metadpy.std.metad()`. If
        provided, the estimated ratings will be plotted together with the
        observed data.
    ax : `Matplotlib.Axes` or None
        Where to draw the plot. Default is `None` (create a new figure).

    Returns
    -------
    ax : :class:`matplotlib.axes.Axes`
        The matplotlib axes containing the plot.

    Examples
    --------
    """
    if fitModel is None:

        nRatings = int(len(nR_S1) / 2)

        # Find incorrect observed ratings
        I_nR_rS2 = nR_S1[nRatings:]
        I_nR_rS1 = np.flip(nR_S2[:nRatings])
        I_nR = I_nR_rS2 + I_nR_rS1

        # Find correct observed ratings
        C_nR_rS2 = nR_S2[nRatings:]
        C_nR_rS1 = np.flip(nR_S1[:nRatings])
        C_nR = C_nR_rS2 + C_nR_rS1

        # Calculate type 2 hits and false alarms
        obs_FAR2_rS2, obs_HR2_rS2, obs_FAR2_rS1, obs_HR2_rS1, obs_FAR2, obs_HR2 = (
            [],
            [],
            [],
            [],
            [],
            [],
        )
        for i in range(nRatings):
            obs_FAR2_rS2.append(sum(I_nR_rS2[i:]) / sum(I_nR_rS2))
            obs_HR2_rS2.append(sum(C_nR_rS2[i:]) / sum(C_nR_rS2))
            obs_FAR2_rS1.append(sum(I_nR_rS1[i:]) / sum(I_nR_rS1))
            obs_HR2_rS1.append(sum(C_nR_rS1[i:]) / sum(C_nR_rS1))
            obs_FAR2.append(sum(I_nR[i:]) / sum(I_nR))
            obs_HR2.append(sum(C_nR[i:]) / sum(C_nR))

        obs_FAR2.append(0)
        obs_HR2.append(0)

        if ax is None:
            fig, ax = plt.subplots(1, 1, figsize=(5, 5))

        ax.plot([0, 1], [0, 1], "--", color="gray")
        ax.fill_between(x=obs_FAR2, y1=obs_HR2, color="lightgray", alpha=0.5)
        ax.plot(obs_FAR2,
                obs_HR2,
                "ko-",
                linewidth=1.5,
                markersize=12,
                label="Observed")
        ax.set_title("Type 2 ROC curve")
        ax.set_ylabel("Type 2 P(correct)")
        ax.set_xlabel("Type 2 P(incorrect)")

    else:
        if not isinstance(fitModel, dict):
            raise ValueError("You should provided a dictionnary. "
                             "See metadpy.std.metad() for help.")
        if ax is None:
            fig, ax = plt.subplots(1, 2, figsize=(10, 5))

        # Stimulus 1
        ax[0].plot([0, 1], [0, 1], "--", color="gray")
        ax[0].plot(
            fitModel["obs_FAR2_rS1"],
            fitModel["obs_HR2_rS1"],
            "ko-",
            linewidth=1.5,
            markersize=12,
            label="Observed",
        )
        ax[0].plot(
            fitModel["est_FAR2_rS1"],
            fitModel["est_HR2_rS1"],
            "bo-",
            linewidth=1.5,
            markersize=6,
            label="Estimated",
        )
        ax[0].set_title("Stimulus 1")
        ax[0].set_ylabel("Type 2 Hit Rate")
        ax[0].set_xlabel("Type 2 False Alarm Rate")

        # Stimulus 2
        ax[1].plot([0, 1], [0, 1], "--", color="gray")
        ax[1].plot(
            fitModel["obs_FAR2_rS2"],
            fitModel["obs_HR2_rS2"],
            "ko-",
            linewidth=1.5,
            markersize=12,
            label="Observed",
        )
        ax[1].plot(
            fitModel["est_FAR2_rS2"],
            fitModel["est_HR2_rS2"],
            "bo-",
            linewidth=1.5,
            markersize=6,
            label="Estimated",
        )
        ax[1].set_title("Stimulus 2")
        ax[1].set_ylabel("Type 2 Hit Rate")
        ax[1].set_xlabel("Type 2 False Alarm Rate")

    return ax
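An illustrative call of plot_roc without a model fit, assuming the usual layout of 2 * nRatings response counts per stimulus; the counts below are made up.

import numpy as np
import matplotlib.pyplot as plt

nR_S1 = np.array([52, 32, 35, 37, 26, 12, 4, 2])  # hypothetical rating counts
nR_S2 = np.array([2, 5, 15, 22, 33, 38, 40, 45])

ax = plot_roc(nR_S1, nR_S2)
plt.show()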
Example #26
def plot_validation_curve(
    estimator: Estimator,
    x: pd.DataFrame,
    y: DataType,
    param_name: str,
    param_range: Sequence,
    cv: int = 5,
    scoring: str = "default",
    n_jobs: int = -1,
    ax: Axes = None,
    title: str = "",
    **kwargs,
) -> Axes:
    """
    Plots a :func:`~sklearn.model_selection.validation_curve`, graphing the impact
    of changing a hyperparameter on the scoring metric.

    This lets us examine how a hyperparameter affects
    over/underfitting by examining train/test performance
    with different values of the hyperparameter.


    Parameters
    ----------
    estimator: sklearn-compatible estimator
        An instance of a sklearn estimator

    x: pd.DataFrame
        DataFrame of features

    y: pd.Series or np.Array
        Target values to predict

    param_name: str
        Name of hyperparameter to plot

    param_range: Sequence
        The individual values to plot for `param_name`

    cv: int
        Number of CV iterations to run. Uses a :class:`~sklearn.model_selection.StratifiedKFold` if
        `estimator` is a classifier - otherwise a :class:`~sklearn.model_selection.KFold` is used.

    scoring: str
        Metric to use in scoring - must be a scikit-learn compatible
        :ref:`scoring method<sklearn:scoring_parameter>`

    n_jobs: int
        Number of jobs to use in parallelizing the estimator fitting and scoring

    ax: plt.Axes
        The plot will be drawn on the passed ax - otherwise a new figure and ax will be created.

    title: str
        Title to be used on the plot

    kwargs: dict
        Passed along to matplotlib line plots

    Returns
    -------
    plt.Axes
    """
    if scoring == "default":
        scoring = "accuracy" if is_classifier(estimator) else "r2"

    train_scores, test_scores = validation_curve(
        estimator=estimator,
        X=x,
        y=y,
        param_name=param_name,
        param_range=param_range,
        cv=cv,
        scoring=scoring,
        n_jobs=n_jobs,
    )

    train_scores_mean = np.mean(train_scores, axis=1)
    train_scores_std = np.std(train_scores, axis=1)
    test_scores_mean = np.mean(test_scores, axis=1)
    test_scores_std = np.std(test_scores, axis=1)

    if ax is None:
        fig, ax = plt.subplots()

    title = "Validation Curve" if title is None else title

    ax.plot(
        param_range, train_scores_mean, label=f"Training {scoring.title()}", **kwargs
    )
    ax.fill_between(
        param_range,
        train_scores_mean + train_scores_std,
        train_scores_mean - train_scores_std,
        alpha=0.2,
    )

    ax.plot(param_range, test_scores_mean, label=f"Test {scoring.title()}", **kwargs)
    ax.fill_between(
        param_range,
        test_scores_mean + test_scores_std,
        test_scores_mean - test_scores_std,
        alpha=0.2,
    )

    ax.set_title(title)
    ax.set_ylabel(f"{scoring.title()} Score")
    ax.set_xlabel(f"{param_name.replace('_', ' ').title()}")
    ax.legend(loc="best")
    return ax
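A possible call pattern for plot_validation_curve, sweeping a single hyperparameter; the dataset and estimator are illustrative and the sklearn imports used above are assumed to be present.

import pandas as pd
from sklearn.datasets import load_wine
from sklearn.ensemble import RandomForestClassifier

data = load_wine()
x = pd.DataFrame(data.data, columns=data.feature_names)
y = data.target

ax = plot_validation_curve(
    RandomForestClassifier(),
    x,
    y,
    param_name="max_depth",
    param_range=[2, 4, 6, 8, 10],
    cv=3,
)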
Example #27
def plot_fit(vis_data: ndarray,
             ax: Axes,
             xy_kwargs: dict = None,
             xycalc_kwargs: dict = None,
             xydiff_kwargs: dict = None,
             xyzero_kwargs: dict = None,
             fill_kwargs: dict = None,
             yzero: ndarray = None):
    """Visualize the fit.

    Parameters
    ----------
    vis_data : ndarray
        The data to visualize. The first three rows are the independent variable,
        the dependent variable, and the fitted values.

    ax : Axes
        The axes to plot on.

    xy_kwargs : dict
        The kwargs for plotting the y vs. x curve.

    xycalc_kwargs : dict
        The kwargs for plotting the ycalc vs. x curve.

    xydiff_kwargs : dict
        The kwargs for plotting the ydiff vs. x curve.

    xyzero_kwargs : dict
        The kwargs for plotting the yzero vs. x curve.

    fill_kwargs : dict
        The kwargs for filling in the area between ydiff and yzero.

    yzero : ndarray
        The base line corresponding to the zero value in ydiff.
    """
    # use the default value if None
    if xyzero_kwargs is None:
        xyzero_kwargs = {}
    if xydiff_kwargs is None:
        xydiff_kwargs = {}
    if xycalc_kwargs is None:
        xycalc_kwargs = {}
    if xy_kwargs is None:
        xy_kwargs = {}
    if fill_kwargs is None:
        fill_kwargs = {}
    # split data
    if len(vis_data.shape) != 2:
        raise ValueError('Invalid data shape: {}. Need 2D data array.'.format(
            vis_data.shape))
    if vis_data.shape[0] < 3:
        raise ValueError('Invalid data dimension: {}. Need dim >= 3'.format(
            vis_data.shape[0]))
    x, y, ycalc = vis_data[:3]
    ydiff = y - ycalc
    # shift ydiff
    if yzero is None:
        yzero = get_yzero(y, ycalc, ydiff)
    ydiff += yzero
    # circle data curve
    _xy_kwargs = {'fillstyle': 'none', 'label': 'data'}
    _xy_kwargs.update(xy_kwargs)
    data_line, = ax.plot(x, y, 'o', **_xy_kwargs)
    # solid calculation curve
    _xycalc_kwargs = {
        'label': 'fit',
        'color': complimentary(data_line.get_color())
    }
    _xycalc_kwargs.update(xycalc_kwargs)
    ax.plot(x, ycalc, '-', **_xycalc_kwargs)
    # dash zero difference curve
    _xyzero_kwargs = {'color': 'grey'}
    _xyzero_kwargs.update(xyzero_kwargs)
    ax.plot(x, yzero, '--', **_xyzero_kwargs)
    # solid shifted difference curve
    _xydiff_kwargs = {'label': 'residuals', 'color': data_line.get_color()}
    _xydiff_kwargs.update(xydiff_kwargs)
    diff_line, = ax.plot(x, ydiff, '-', **_xydiff_kwargs)
    # fill in area between curves
    if fill_kwargs.pop('fill', True):
        _fill_kwargs = {'alpha': 0.4, 'color': diff_line.get_color()}
        _fill_kwargs.update(fill_kwargs)
        ax.fill_between(x, ydiff, yzero, **_fill_kwargs)
    return ax