def plot_ba(experiment, **kwargs):
    # Operate on cached data only.
    get_experiment_split_data.check_in_store(experiment)
    X_train, X_test, y_train, y_test = get_experiment_split_data(experiment)

    # Operate on cached data only.
    get_endog_exog_mask.check_in_store(experiment)
    master_mask = get_endog_exog_mask(experiment)[2]

    check_master_masks(master_mask)

    # Operate on cached fitted models only.
    get_model(X_train, y_train, cache_check=True)

    predicted_test = threading_get_model_predict(
        X_train=X_train,
        y_train=y_train,
        predict_X=X_test,
    )

    ba_plotting(
        *get_ba_plotting_data(predicted_test, y_test, master_mask),
        figure_saver=map_figure_saver(sub_directory=experiment.name),
        **get_aux0_aux1_kwargs(y_test, master_mask),
        filename=f"{experiment.name}_ba_prediction",
    )
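
# A minimal usage sketch (hypothetical call); `plot_ba` operates on cached data
# and fitted models only, so it raises if the required cache entries are missing:
#
#     plot_ba(Experiment["15VEG_FAPAR"])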
def loco_calc(experiment, cache_check=False, **kwargs):
    """Calculate LOCO values.

    Args:
        experiment (str): Experiment (e.g. 'ALL').
        cache_check (bool): Whether to check for cached data exclusively.

    """
    # Operate on cached data only.
    get_experiment_split_data.check_in_store(experiment)
    X_train, X_test, y_train, y_test = get_experiment_split_data(experiment)

    loco_results = optional_client_call(
        calculate_loco,
        dict(
            rf=DaskRandomForestRegressor(**param_dict),
            X_train=X_train,
            y_train=y_train,
            X_test=X_test,
            y_test=y_test,
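            # "" presumably requests the baseline fit with no column left out,
            # against which the single-column omissions are compared.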
            leave_out=("", *selected_features[experiment]),
            local_n_jobs=(1 if (get_ncpus() < 4) else (get_ncpus() - 2)),
        ),
        cache_check=cache_check,
        add_client=True,
    )[0]

    if cache_check:
        return IN_STORE
    return loco_results
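
# A hypothetical invocation; LOCO (leave-one-column-out) refits the forest once
# per left-out column, which is far more expensive than a single fit:
#
#     loco_results = loco_calc("ALL")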
def pfi_calc(experiment, cache_check=False, **kwargs):
    """Calculate PFIs for both training and test data.

    Args:
        experiment (str): Experiment (e.g. 'ALL').
        cache_check (bool): Whether to check for cached data exclusively.

    """
    # Operate on cached data only.
    get_experiment_split_data.check_in_store(experiment)
    X_train, X_test, y_train, y_test = get_experiment_split_data(experiment)

    # Operate on cached fitted models only.
    get_model(X_train, y_train, cache_check=True)
    rf = get_model(X_train, y_train)

    # Test data.
    pfi_test_args = (rf, X_test, y_test)
    if cache_check:
        calculate_pfi.check_in_store(*pfi_test_args)

    # Train data.
    pfi_train_args = (rf, X_train, y_train)
    if cache_check:
        return calculate_pfi.check_in_store(*pfi_train_args)

    return {
        "train": calculate_pfi(*pfi_train_args),
        "test": calculate_pfi(*pfi_test_args),
    }
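
# A hypothetical invocation; permutation feature importance (PFI) shuffles one
# column at a time and records the resulting drop in score:
#
#     pfis = pfi_calc("ALL")
#     train_pfi, test_pfi = pfis["train"], pfis["test"]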

def shap_values(experiment, index, kind, cache_check=False, **kwargs):
    # Operate on cached data only.
    get_experiment_split_data.check_in_store(experiment)
    X_train, X_test, y_train, y_test = get_experiment_split_data(experiment)

    if kind == "train":
        shap_params = get_shap_params(X_train)
        X = X_train
    elif kind == "test":
        shap_params = get_shap_params(X_test)
        X = X_test
    else:
        raise ValueError(f"Unknown kind '{kind}'.")

    # Operate on cached fitted models only.
    get_model(X_train, y_train, cache_check=True)
    rf = get_model(X_train, y_train)

    calc_shap_args = (
        rf,
        X.iloc[
            index * shap_params["job_samples"] : (index + 1)
            * shap_params["job_samples"]
        ],
    )

    if cache_check:
        return get_shap_values.check_in_store(*calc_shap_args)

    return get_shap_values(*calc_shap_args)
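
# The `index` argument partitions the rows of X into consecutive blocks of
# `job_samples` rows so that SHAP values can be computed by independent jobs,
# e.g. (hypothetical numbers):
#
#     >>> job_samples = 4
#     >>> [(i * job_samples, (i + 1) * job_samples) for i in range(3)]
#     [(0, 4), (4, 8), (8, 12)]
#
# `.iloc` clips the final block if it runs past the end of X.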

def plot_single_1d_ale(experiment, column, ax, verbose=False):
    # Operate on cached data only.
    get_experiment_split_data.check_in_store(experiment)
    X_train, X_test, y_train, y_test = get_experiment_split_data(experiment)

    # Operate on cached fitted models only.
    get_model(X_train, y_train, cache_check=True)
    model = get_model(X_train, y_train)

    save_ale_1d(
        model,
        X_train,
        column,
        train_response=y_train,
        figure_saver=None,
        verbose=verbose,
        monte_carlo_rep=100,
        monte_carlo_ratio=get_frac_train_nr_samples(Experiment["15VEG_FAPAR"],
                                                    0.1),
        ax=ax,
        ale_factor_exp=plotting_configuration.ale_factor_exps[column.parent],
        x_ndigits=plotting_configuration.ndigits.get(column.parent, 2),
        x_skip=(
            skip_14
            if (experiment, column)
            == (Experiment["15VEG_FAPAR_MON"], variable.DRY_DAY_PERIOD[3])
            else 4
        ),
    )

def combination_fit(combination, split_index, cache_check=False, **kwargs):
    # Get training and test data for all variables.
    get_experiment_split_data.check_in_store(Experiment.ALL)
    X_all, _, y, _ = get_experiment_split_data(Experiment.ALL)

    if cache_check:
        return fit_combination.check_in_store(X_all, y, combination, split_index)
    return fit_combination(X_all, y, combination, split_index)
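
# A hypothetical invocation; `combination` is a collection of feature variables
# and `split_index` presumably selects which train/test split to fit on:
#
#     result = combination_fit(selected_combination, 0)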

def get_experiment_model_scores(experiment, cache_check=False, **kwargs):
    # Operate on cached data only.
    get_experiment_split_data.check_in_store(experiment)
    X_train, X_test, y_train, y_test = get_experiment_split_data(experiment)

    # Operate on cached fitted models only.
    get_model(X_train, y_train, cache_check=True)
    model = get_model(X_train, y_train)

    if cache_check:
        return get_model_scores.check_in_store(model, X_test, X_train, y_test,
                                               y_train)
    return get_model_scores(model, X_test, X_train, y_test, y_train)
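
# Note that the scores returned by `get_model_scores` include (as used in
# `plot_score_groups` below) the "train_r2", "test_r2", and "oob_r2" entries.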

def fit_experiment_model(experiment, cache_check=False, **kwargs):
    if cache_check:
        get_experiment_split_data.check_in_store(experiment)
    X_train, X_test, y_train, y_test = get_experiment_split_data(experiment)

    if cache_check:
        return get_model(X_train=X_train, y_train=y_train, cache_check=True)
    model = get_model(
        X_train=X_train,
        y_train=y_train,
        parallel_backend_call=(
            # Use local threading backend - avoid the Dask backend.
            partial(parallel_backend, "threading", n_jobs=get_ncpus())),
    )
    return model
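
# A hypothetical invocation; per the comment above, the local threading backend
# is forced so that the fit does not run through a possibly-active Dask backend:
#
#     model = fit_experiment_model(Experiment["ALL"])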

def plot_1d_ale(experiment, column, single=False, verbose=False, **kwargs):
    exp_figure_saver = figure_saver(sub_directory=experiment.name)

    # Operate on cached data only.
    get_experiment_split_data.check_in_store(experiment)
    X_train, X_test, y_train, y_test = get_experiment_split_data(experiment)

    # Operate on cached fitted models only.
    get_model(X_train, y_train, cache_check=True)
    model = get_model(X_train, y_train)

    save_ale_1d(
        model,
        X_train,
        column,
        train_response=y_train,
        figure_saver=exp_figure_saver,
        verbose=verbose,
        monte_carlo_rep=200,
        monte_carlo_ratio=0.1,
    )

def plot_obs_pred_comp(experiment, **kwargs):
    # Operate on cached data/models only.
    get_experiment_split_data.check_in_store(experiment)
    X_train, X_test, y_train, y_val = get_experiment_split_data(experiment)
    get_model(X_train, y_train, cache_check=True)

    get_endog_exog_mask.check_in_store(experiment)
    master_mask = get_endog_exog_mask(experiment)[2]

    check_master_masks(master_mask)

    u_pre = threading_get_model_predict(
        X_train=X_train,
        y_train=y_train,
        predict_X=X_test,
    )

    obs_pred_diff_cube = get_obs_pred_diff_cube(y_val, u_pre, master_mask)

    with map_figure_saver(sub_directory=experiment.name)(
            f"{experiment.name}_obs_pred_comp", sub_directory="predictions"):
        disc_cube_plot(
            obs_pred_diff_cube,
            fig=plt.figure(figsize=(5.1, 2.3)),
            cmap="BrBG",
            cmap_midpoint=0,
            cmap_symmetric=False,
            bin_edges=[-0.01, -0.001, -1e-4, 0, 0.001, 0.01, 0.02],
            extend="both",
            cbar_format=get_sci_format(ndigits=0),
            cbar_pad=0.025,
            cbar_label="Ob. - Pr.",
            **get_aux0_aux1_kwargs(y_val, master_mask),
            loc=(0.83, 0.14),
            height=0.055,
            aspect=1,
            spacing=0.06 * 0.2,
        )
warnings.filterwarnings("ignore", ".*guessing contiguous bounds.*")

warnings.filterwarnings(
    "ignore", 'Setting feature_perturbation = "tree_path_dependent".*')

if __name__ == "__main__":
    experiment = Experiment["15VEG_FAPAR"]

    # Operate on cached model / data only.
    get_endog_exog_mask.check_in_store(experiment)
    endog_data, _, master_mask = get_endog_exog_mask(experiment)

    check_master_masks(master_mask)

    get_experiment_split_data.check_in_store(experiment)
    X_train, X_test, y_train, y_test = get_experiment_split_data(experiment)

    get_model(X_train, y_train, cache_check=True)
    rf = get_model(X_train, y_train)

    get_shap_values.check_in_store(rf=rf, X=X_test)
    shap_values = get_shap_values(rf=rf, X=X_test)

    # Analysis / plotting parameters.
    diff_threshold = 0.5
    ptp_threshold_factor = 0.12  # relative to the mean

    chosen_lags = tuple(lag for lag in variable.lags if lag <= 9)
    assert list(chosen_lags) == sorted(chosen_lags)

    map_shap_results = calculate_2d_masked_shap_values(X_train,
def multi_model_ale_plot(*args, verbose=False, **kwargs):
    # Experiments for which data will be plotted.
    experiments = [
        Experiment["ALL"],
        Experiment["TOP15"],
        Experiment["CURR"],
        Experiment["BEST15"],
        Experiment["15VEG_FAPAR"],
        Experiment["15VEG_LAI"],
        Experiment["15VEG_VOD"],
        Experiment["15VEG_SIF"],
        Experiment["CURRDD_FAPAR"],
        Experiment["CURRDD_LAI"],
        Experiment["CURRDD_VOD"],
        Experiment["CURRDD_SIF"],
    ]

    # Operate on cached data/models only.
    experiment_masks = []
    plotting_experiment_data = {}

    for experiment in tqdm(experiments, desc="Loading data"):
        get_data(experiment, cache_check=True)
        get_experiment_split_data.check_in_store(experiment)
        X_train, X_test, y_train, y_test = get_experiment_split_data(
            experiment)
        get_model(X_train, y_train, cache_check=True)

        experiment_masks.append(get_endog_exog_mask(experiment)[2])
        plotting_experiment_data[experiment] = dict(
            model=get_model(X_train, y_train),
            X_train=X_train,
        )

    # Ensure masks are aligned.
    check_master_masks(*experiment_masks)

    lags = (0, 1, 3, 6, 9)

    for comp_vars in [[variable.FAPAR, variable.LAI],
                      [variable.SIF, variable.VOD]]:
        fig, axes = plt.subplots(5, 2, sharex="col", figsize=(7.0, 5.8))

        # Create general legend labels (with 'X' in place of FAPAR, LAI, etc.).
        mod_exp_plot_kwargs = deepcopy(experiment_plot_kwargs)
        for plot_kwargs in mod_exp_plot_kwargs.values():
            if plot_kwargs["label"].startswith("15VEG_"):
                plot_kwargs["label"] = "15VEG_X"
            elif plot_kwargs["label"].startswith("CURRDD_"):
                plot_kwargs["label"] = "CURRDD_X"

        x_factor_exp = 0
        x_factor = 10**x_factor_exp
        # x_factor_str = rf"$10^{{{x_factor_exp}}}$"

        y_factor_exp = -4
        y_factor = 10**y_factor_exp
        y_factor_str = rf"$10^{{{y_factor_exp}}}$"

        multi_model_ale_1d(
            comp_vars[0],
            plotting_experiment_data,
            mod_exp_plot_kwargs,
            verbose=verbose,
            legend_bbox=(0.5, 1.01),
            fig=fig,
            axes=axes[:, 0:1],
            lags=lags,
            x_ndigits=2,
            x_factor=x_factor,
            x_rotation=0,
            y_ndigits=0,
            y_factor=y_factor,
        )
        multi_model_ale_1d(
            comp_vars[1],
            plotting_experiment_data,
            experiment_plot_kwargs,
            verbose=verbose,
            legend=False,
            fig=fig,
            axes=axes[:, 1:2],
            lags=lags,
            x_ndigits=2,
            x_factor=x_factor,
            x_rotation=0,
            y_ndigits=0,
            y_factor=y_factor,
        )

        for ax in axes[:, 1]:
            ax.set_ylabel("")
        for ax in axes[:, 0]:
            lag_match = re.search(r"(\dM)", ax.get_xlabel())
            if lag_match:
                lag_m = f" {lag_match.group(1)}"
            else:
                lag_m = ""
            ax.set_ylabel(f"ALE{lag_m} ({y_factor_str} BA)")
        for ax in axes.flatten():
            ax.set_xlabel("")

        for ax, var in zip(axes[-1], comp_vars):
            assert x_factor_exp == 0
            ax.set_xlabel(
                f"{shorten_features(str(var))} ({variable.units[var]})")

        for ax, title in zip(axes.flatten(), ascii_lowercase):
            ax.text(0.5, 1.05, f"({title})", transform=ax.transAxes)

        margin = 0.4

        for ax in axes.ravel():
            ax.set_xlim(-margin, 20 + margin)

        fig.tight_layout(h_pad=0.4)
        fig.align_labels()

        figure_saver.save_figure(
            fig,
            f"{'__'.join(map(shorten_features, map(str, comp_vars)))}_ale_comp",
            sub_directory="ale_comp",
        )


if __name__ == "__main__":
    # Relevant if called with the command 'cx1' instead of 'local'.
    cx1_kwargs = dict(walltime="24:00:00", ncpus=1, mem="7GB")

    # Get training and test data for all variables.
    get_experiment_split_data.check_in_store(Experiment.ALL)
    X_train, X_test, y_train, y_test = get_experiment_split_data(Experiment.ALL)

    shifts = (0, 1, 3, 6, 9)
    assert all(shift in variable.lags for shift in shifts)

    veg_lags = tuple(
        tuple(
            var_factory[shift]
            for var_factory in variable.feature_categories[
                variable.Category.VEGETATION
            ]
        )
        for shift in shifts
    )
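
    # `veg_lags` holds one tuple per shift, each containing every vegetation
    # variable at that shift, e.g. ((FAPAR, LAI, VOD, SIF), (FAPAR 1M, LAI 1M,
    # VOD 1M, SIF 1M), ...) if those are the vegetation features.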
def plot_2d_ale(experiment, single=False, nargs=None, verbose=False, **kwargs):
    exp_figure_saver = figure_saver(sub_directory=experiment.name)

    # Operate on cached data only.
    get_experiment_split_data.check_in_store(experiment)
    X_train, X_test, y_train, y_test = get_experiment_split_data(experiment)

    # Operate on cached fitted models only.
    get_model(X_train, y_train, cache_check=True)
    model = get_model(X_train, y_train)

    columns_list = list(combinations(X_train.columns, 2))

    # Deterministic sorting with FAPAR & FAPAR 1M and FAPAR & DRY_DAY_PERIOD at the
    # front since these are used in the paper.

    def get_combination_value(column_combination):
        # Handle special cases first.
        if (
            variable.FAPAR[0] in column_combination
            and variable.FAPAR[1] in column_combination
        ):
            return -1000
        elif (
            variable.FAPAR[0] in column_combination
            and variable.DRY_DAY_PERIOD[0] in column_combination
        ):
            return -999
        out = ""
        for var in column_combination:
            out += str(var.rank) + str(var.shift)
        return int(out)
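
    # For example (hypothetical ranks): a pair with (rank, shift) of (2, 0) and
    # (5, 3) maps to int("20" + "53") = 2053, yielding a deterministic ordering
    # after the two special-cased pairs above.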

    columns_list = sorted(columns_list, key=get_combination_value)

    def param_iter():
        for columns in columns_list:
            for plot_samples in [True, False]:
                yield columns, plot_samples

    if single:
        total = 1
    elif nargs:
        total = nargs
    else:
        total = 2 * len(columns_list)

    for columns, plot_samples in tqdm(
        islice(param_iter(), None, total),
        desc=f"2D ALE plotting ({experiment})",
        total=total,
        disable=not verbose,
    ):
        save_ale_2d(
            experiment=experiment,
            model=model,
            train_set=X_train,
            features=columns,
            n_jobs=get_ncpus(),
            include_first_order=True,
            plot_samples=plot_samples,
            figure_saver=exp_figure_saver,
            ale_factor_exp=plotting_configuration.ale_factor_exps.get(
                (columns[0].parent, columns[1].parent), -2
            ),
            x_factor_exp=plotting_configuration.factor_exps.get(columns[0].parent, 0),
            x_ndigits=plotting_configuration.ndigits.get(columns[0].parent, 2),
            y_factor_exp=plotting_configuration.factor_exps.get(columns[1].parent, 0),
            y_ndigits=plotting_configuration.ndigits.get(columns[1].parent, 2),
        )
        plt.close("all")
def plot_obs_pred_bin(experiment, **kwargs):
    # Operate on cached data/models only.
    get_experiment_split_data.check_in_store(experiment)
    X_train, X_test, y_train, u_val = get_experiment_split_data(experiment)
    get_model(X_train, y_train, cache_check=True)

    u_pre = threading_get_model_predict(
        X_train=X_train,
        y_train=y_train,
        predict_X=X_test,
    )

    min_non_zero_val = u_val[u_val > 0].min()

    x_edges = np.append(0, np.geomspace(min_non_zero_val, 1, 100))
    y_edges = np.geomspace(np.min(u_pre), np.max(u_pre), 100 + 1)
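
    # The x bins prepend an explicit [0, min-positive) bin ahead of the
    # geometric spacing because observed BA contains exact zeros; predictions
    # are presumably strictly positive, so plain geometric y bins suffice.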

    h = np.histogram2d(u_val, u_pre, bins=[x_edges, y_edges])[0]

    fig, ax = plt.subplots(figsize=(6, 4), dpi=200)
    img = ax.pcolor(
        x_edges,
        y_edges,
        h.T,
        norm=LogNorm(),
    )

    # Plot the diagonal 1:1 line.
    diag = np.geomspace(
        max(min(u_val), min(u_pre)), min(max(u_val), max(u_pre)), 200
    )
    plt.plot(diag, diag, linestyle="--", c="C3", lw=2)

    ax.set_xscale("symlog",
                  linthresh=min_non_zero_val,
                  linscale=2e-1,
                  subs=range(2, 10))
    ax.set_yscale("log")

    def offset_sci_format(x, *args, **kwargs):
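        # Pad selected tick labels with spaces, presumably to stop the 0 and
        # 1e-5 labels colliding around the symlog linear/log transition.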
        canon = get_sci_format(ndigits=0, trim_leading_one=True)(x, None)
        if np.isclose(x, 1e-5):
            return " " * 6 + canon
        elif np.isclose(x, 0):
            return canon + " " * 3
        return canon

    ax.xaxis.set_major_formatter(
        ticker.FuncFormatter(lambda x, pos: offset_sci_format(x)))
    ax.yaxis.set_major_formatter(
        ticker.FuncFormatter(get_sci_format(ndigits=0, trim_leading_one=True)))

    ax.set_xlabel("Observed (BA)")
    ax.set_ylabel("Predicted (BA)")

    ax.set_axisbelow(True)
    ax.grid(zorder=0)

    fig.colorbar(
        img,
        shrink=0.7,
        aspect=30,
        format=get_sci_format(ndigits=0, trim_leading_one=True),
        pad=0.02,
        label="samples",
    )
    figure_saver(sub_directory=experiment.name).save_figure(
        fig, f"{experiment.name}_obs_pred_bin", sub_directory="predictions"
    )

def get_experiment_data(experiment, cache_check=False, **kwargs):
    if cache_check:
        get_experiment_split_data.check_in_store(experiment)
    X_train, X_test, y_train, y_test = get_experiment_split_data(experiment)
    return X_train, X_test, y_train, y_test
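
# A hypothetical invocation; unlike most helpers here, this simply returns the
# cached train/test split itself:
#
#     X_train, X_test, y_train, y_test = get_experiment_data(Experiment["ALL"])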
def plot_multi_ale(experiment, verbose=False, **kwargs):
    exp_figure_saver = figure_saver(sub_directory=experiment.name)

    # Operate on cached data only.
    get_experiment_split_data.check_in_store(experiment)
    X_train, X_test, y_train, y_test = get_experiment_split_data(experiment)

    # Operate on cached fitted models only.
    get_model(X_train, y_train, cache_check=True)
    model = get_model(X_train, y_train)

    fig, axes = plt.subplots(1, 2, figsize=(7.05, 2.8))

    expected_veg = tuple(
        map(
            itemgetter(0),
            variable.feature_categories[variable.Category.VEGETATION],
        )
    )

    matched = [f for f in expected_veg if f in X_train.columns]

    if len(matched) == 0:
        raise ValueError(f"Could not find one of {expected_veg} in {X_train.columns}.")
    elif len(matched) > 1:
        raise ValueError(
            f"Found more than one of {tuple(map(str, expected_veg))} in "
            f"{X_train.columns}: {matched}"
        )
    features = (matched[0].parent, variable.DRY_DAY_PERIOD)

    ale_factor_exp = -3
    x_factor_exp = 0

    for feature_factory, ax, title in zip(
        tqdm(features, desc="Processing features"),
        axes,
        ("(a)", "(b)"),
    ):
        multi_ale_1d(
            model=model,
            X_train=X_train,
            features=[feature_factory[lag] for lag in variable.lags[:5]],
            train_response=y_train,
            fig=fig,
            ax=ax,
            verbose=verbose,
            monte_carlo_rep=100,
            monte_carlo_ratio=get_frac_train_nr_samples(Experiment["15VEG_FAPAR"], 0.1),
            legend=False,
            ale_factor_exp=ale_factor_exp,
            x_factor_exp=x_factor_exp,
            x_ndigits=plotting_configuration.ndigits.get(feature_factory, 2),
            x_skip=4,
            x_rotation=0,
        )
        ax.set_title(title)
        ax.set_xlabel(
            f"{shorten_features(str(feature_factory))} ({variable.units[feature_factory]})"
            if x_factor_exp == 0
            else (
                f"{feature_factory} ($10^{{{x_factor_exp}}}$ "
                f"{variable.units[feature_factory]})"
            ),
        )

    axes[1].set_ylabel("")

    # Inset axis to emphasize the low dry-day period (DD) region.

    ax2 = inset_axes(
        axes[1],
        width=2.155,
        height=1.55,
        loc="lower left",
        bbox_to_anchor=(0.019, 0.225),
        bbox_transform=axes[1].transAxes,
    )
    # Plot the DD data again on the inset axis.
    multi_ale_1d(
        model=model,
        X_train=X_train,
        features=[features[1][lag] for lag in variable.lags[:5]],
        train_response=y_train,
        fig=fig,
        ax=ax2,
        verbose=verbose,
        monte_carlo_rep=100,
        monte_carlo_ratio=get_frac_train_nr_samples(Experiment["15VEG_FAPAR"], 0.1),
        legend=False,
        ale_factor_exp=ale_factor_exp,
    )

    ax2.set_xlim(0, 17.5)
    ax2.set_ylim(-1.5e-3, 2e-3)

    ax2.xaxis.set_major_formatter(ticker.ScalarFormatter())
    ax2.yaxis.set_major_formatter(ticker.ScalarFormatter())
    ax2.tick_params(axis="both", which="both", length=0)
    plt.setp(ax2.get_xticklabels(), visible=False)
    plt.setp(ax2.get_yticklabels(), visible=False)

    ax2.set_ylabel("")
    ax2.set_xlabel("")
    ax2.grid(True)

    mark_inset(axes[1], ax2, loc1=4, loc2=2, fc="none", ec="0.3")

    # Move the first (left) axis to the right.
    orig_bbox = axes[0].get_position()
    axes[0].set_position(
        [orig_bbox.xmin + 0.021, orig_bbox.ymin, orig_bbox.width, orig_bbox.height]
    )

    # Explicitly set the x-axis labels' positions so they line up horizontally.
    y_min = 1
    for ax in axes:
        bbox = ax.get_position()
        if bbox.ymin < y_min:
            y_min = bbox.ymin
    for ax in axes:
        bbox = ax.get_position()
        mean_x = (bbox.xmin + bbox.xmax) / 2.0
        # NOTE - Decrease the negative offset to move the label upwards.
        ax.xaxis.set_label_coords(mean_x, y_min - 0.1, transform=fig.transFigure)

    # Plot the legend in between the two axes.
    axes[1].legend(
        loc="center",
        ncol=5,
        bbox_to_anchor=(
            np.mean(
                [
                    axes[0].get_position().xmax,
                    axes[1].get_position().xmin,
                ]
            ),
            0.932,
        ),
        bbox_transform=fig.transFigure,
        handletextpad=0.25,
        columnspacing=0.5,
    )

    exp_figure_saver.save_figure(
        fig,
        f'{experiment.name}_{"__".join(map(shorten_features, map(str, features)))}_ale_shifts',
        sub_directory="multi_ale",
        transparent=False,
    )
def plot_score_groups(experiments, **kwargs):
    scores = {}
    for experiment in experiments:
        # Operate on cached data only.
        get_experiment_split_data.check_in_store(experiment)
        X_train, X_test, y_train, y_test = get_experiment_split_data(experiment)

        # Operate on cached fitted models only.
        get_model(X_train, y_train, cache_check=True)
        model = get_model(X_train, y_train)

        # Cached scores only.
        get_model_scores.check_in_store(model, X_test, X_train, y_test, y_train)
        scores[experiment] = get_model_scores(model, X_test, X_train, y_test, y_train)

    # Sort scores based on the validation R2 score.
    sort_indices = np.argsort([score["test_r2"] for score in scores.values()])[::-1]

    # Sorted values.
    s_train_r2s = np.array([score["train_r2"] for score in scores.values()])[
        sort_indices
    ]
    s_validation_r2s = np.array([score["test_r2"] for score in scores.values()])[
        sort_indices
    ]
    s_oob_r2s = np.array([score["oob_r2"] for score in scores.values()])[sort_indices]

    # Adapted from: https://matplotlib.org/gallery/subplots_axes_and_figures/broken_axis.html

    # Ratio of training R2 range to validation R2 range.
    train_validation_ratio = np.ptp(s_train_r2s) / np.ptp(s_validation_r2s)
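    # Used below as the gridspec height ratio so that one unit of R² spans the
    # same vertical distance in both panels.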

    fig = plt.figure(figsize=(4, 2.2), dpi=200)

    all_ax = fig.add_subplot(1, 1, 1)
    all_ax.set_ylabel(r"$\mathrm{R}^2$", labelpad=29)
    all_ax.set_xticks([])
    all_ax.set_yticks([])
    all_ax.set_frame_on(False)  # So we don't get black bars through the 'broken' gap.

    # Break the y-axis into 2 parts.
    # fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=(6, 3.5))
    ax1, ax2 = fig.subplots(
        2, 1, sharex=True, gridspec_kw=dict(height_ratios=[train_validation_ratio, 1])
    )
    fig.subplots_adjust(hspace=0.05)  # adjust space between axes

    # Plot train and validation R2s.

    train_kwargs = dict(linestyle="", marker="x", c="C1", label="train")
    ax1.plot(s_train_r2s, **train_kwargs)

    validation_kwargs = dict(linestyle="", marker="o", c="C0", label="validation")
    ax2.plot(s_validation_r2s, **validation_kwargs)

    oob_kwargs = dict(linestyle="", marker="^", c="C2", label="train OOB")
    ax2.plot(s_oob_r2s, **oob_kwargs)

    ax2.set_yticks(np.arange(0.575, 0.7 + 0.01, 0.025))

    ax2.legend(
        handles=[
            Line2D([0], [0], **kwargs)
            for kwargs in (train_kwargs, validation_kwargs, oob_kwargs)
        ],
        loc="lower left",
    )

    ylim_1 = ax1.get_ylim()
    ylim_2 = ax2.get_ylim()

    margin_f = (0.22, 0.05)  # Two-sided relative margin addition.
    ax1.set_ylim(
        [
            op(ylim_val, factor * np.ptp(ylim_1))
            for ylim_val, factor, op in zip(ylim_1, margin_f, (sub, add))
        ]
    )
    ax2.set_ylim(
        [
            op(ylim_val, factor * np.ptp(ylim_1) / train_validation_ratio)
            for ylim_val, factor, op in zip(ylim_2, margin_f, (sub, add))
        ]
    )
    # ax2.set_ylim(ylim_2[0], ylim_2[1] + margin_f * np.ptp(ylim_1) / train_validation_ratio)

    # hide the spines between ax and ax2
    ax1.spines["bottom"].set_visible(False)
    ax2.spines["top"].set_visible(False)
    ax1.xaxis.tick_top()
    ax1.tick_params(labeltop=False)  # don't put tick labels at the top
    ax1.xaxis.set_ticks_position("none")  # hide top ticks themselves (not just labels)

    ax2.xaxis.tick_bottom()

    ax2.set_xticks(list(range(len(experiments))))
    ax2.set_xticklabels(
        list(np.array(list(map(attrgetter("name"), scores)))[sort_indices]),
        rotation=45,
        ha="right",
    )
    ax2.tick_params(axis="x", which="major", pad=0)

    # Now, let's turn towards the cut-out slanted lines.
    # We create line objects in axes coordinates, in which (0,0), (0,1),
    # (1,0), and (1,1) are the four corners of the axes.
    # The slanted lines themselves are markers at those locations, such that the
    # lines keep their angle and position, independent of the axes size or scale
    # Finally, we need to disable clipping.

    d = 0.5  # proportion of vertical to horizontal extent of the slanted line
    kwargs = dict(
        marker=[(-1, -d), (1, d)],
        markersize=8,
        linestyle="none",
        color="k",
        mec="k",
        mew=1,
        clip_on=False,
    )
    ax1.plot([0, 1], [0, 0], transform=ax1.transAxes, **kwargs)
    ax2.plot([0, 1], [1, 1], transform=ax2.transAxes, **kwargs)

    for ax in (ax1, ax2):
        ax.set_xticks(list(range(len(experiments))))

    figure_saver.save_figure(fig, "model_comp_scores")

logger = logging.getLogger(__name__)
enable_logging(level="WARNING")

warnings.filterwarnings("ignore", ".*Collapsing a non-contiguous coordinate.*")
warnings.filterwarnings("ignore", ".*DEFAULT_SPHERICAL_EARTH_RADIUS.*")
warnings.filterwarnings("ignore", ".*guessing contiguous bounds.*")

warnings.filterwarnings(
    "ignore", 'Setting feature_perturbation = "tree_path_dependent".*'
)


if __name__ == "__main__":
    # Only carry out the analysis on the ALL model.
    X_train, X_test, y_train, y_test = get_experiment_split_data(Experiment.ALL)

    client = get_client(fallback=False)

    parameters_RF = {
        "n_estimators": [500, 1000],
        "max_depth": [16, 18],
        "min_samples_split": [2, 3],
        "min_samples_leaf": [1, 2, 3],
        "max_features": ["auto"],
        "ccp_alpha": np.linspace(0, 4e-9, 2),
    }

    results = fit_dask_sub_est_random_search_cv(
        DaskRandomForestRegressor(**default_param_dict),
        X_train.values,

def prediction_comparisons():
    """Compare ALL and CURR predictions."""
    experiments = [Experiment.ALL, Experiment.CURR]
    # Operate on cached data/models only.

    experiment_data = {}
    experiment_models = {}

    for experiment in experiments:
        get_data(experiment, cache_check=True)
        get_experiment_split_data.check_in_store(experiment)
        X_train, X_test, y_train, y_test = get_experiment_split_data(experiment)
        get_model(X_train, y_train, cache_check=True)

        experiment_data[experiment] = get_endog_exog_mask(experiment)
        experiment_models[experiment] = get_model(X_train, y_train)

    # Ensure masks are aligned.
    check_master_masks(*(data[2] for data in experiment_data.values()))

    master_mask = next(iter(experiment_data.values()))[2]

    # Record predictions and errors.
    experiment_predictions = {}
    experiment_errors = {}
    map_experiment_predictions = {}
    map_experiment_errors = {}

    for experiment in experiments:
        X_train, X_test, y_train, y_test = get_experiment_split_data(experiment)
        predicted_test = threading_get_model_predict(
            X_train=X_train,
            y_train=y_train,
            predict_X=X_test,
        )

        print("Experiment:", experiment.name)
        print("mean observed test:", np.mean(y_test.values))
        print("mean predicted test:", np.mean(predicted_test))
        print("lowest observed test:", np.min(y_test.values))
        print(
            "fraction of times this occurs:",
            np.sum(y_test.values == np.min(y_test.values)) / y_test.values.size,
        )
        print("lowest test prediction:", np.min(predicted_test))

        experiment_predictions[experiment] = predicted_test
        experiment_errors[experiment] = y_test.values - predicted_test

        map_experiment_predictions[experiment] = get_mm_data(
            experiment_predictions[experiment], master_mask, kind="val"
        )
        map_experiment_errors[experiment] = get_mm_data(
            experiment_errors[experiment], master_mask, kind="val"
        )

    error_mag_diff = np.abs(map_experiment_errors[experiments[1]]) - np.abs(
        map_experiment_errors[experiments[0]]
    )

    y_test = get_experiment_split_data(experiment)[3]

    rel_error_mag_diff = np.mean(error_mag_diff, axis=0) / np.mean(
        get_mm_data(y_test.values, master_mask, kind="val"), axis=0
    )
    all_rel = get_unmasked(rel_error_mag_diff)

    print(f"% >0: {100 * np.sum(all_rel > 0) / all_rel.size:0.1f}")
    print(f"% <0: {100 * np.sum(all_rel < 0) / all_rel.size:0.1f}")

    fig, ax, cbar = disc_cube_plot(
        dummy_lat_lon_cube(rel_error_mag_diff),
        bin_edges=(-0.5, 0, 0.5),
        extend="both",
        cmap="PiYG",
        cmap_midpoint=0,
        cmap_symmetric=False,
        cbar_label=f"<|Err({experiments[1].name})| - |Err({experiments[0].name})|> / <Ob.>",
        cbar_shrink=0.3,
        cbar_aspect=15,
        cbar_extendfrac=0.1,
        cbar_pad=0.02,
        cbar_format=None,
        **get_aux0_aux1_kwargs(y_test, master_mask),
        loc=(0.79, 0.14),
        height=0.05,
        aspect=1.25,
        spacing=0.06 * 0.2,
    )
    cbar.ax.yaxis.label.set_size(7)

    map_figure_saver.save_figure(
        fig, f"rel_error_mag_diff_{'_'.join(map(attrgetter('name'), experiments))}"
    )