Exemplo n.º 1
0
def get_streamflow_at(lon=-100., lat=50., data_source_base_dir="",
                      period=None, varname=default_varname_mappings.STREAMFLOW):
    """
    Read ``varname`` interpolated to a single point, caching the result on disk.

    The result is pickled into the ``point_data_cache`` folder (relative to the current
    working directory). The cache file name is built from the variable name, the point
    coordinates, the period limits and the SHA-224 digest of ``data_source_base_dir``.

    :param lon: longitude of the point of interest
    :param lat: latitude of the point of interest
    :param data_source_base_dir: base folder of the model samples (string, it is encoded for hashing)
    :param period: object exposing ``start`` and ``end``; required in practice, since both
        attributes are read to build the cache file name
    :param varname: internal name of the variable to read
    :return: the object returned by ``DataManager.read_data_for_period_and_interpolate``
        (or its unpickled copy when the cache file already exists)
    """
    cache_dir = Path("point_data_cache")
    cache_dir.mkdir(parents=True, exist_ok=True)

    bd_sha = hashlib.sha224(data_source_base_dir.encode()).hexdigest()

    cache_file = cache_dir / f"{varname}_lon{lon}_lat{lat}_{period.start}-{period.end}_{bd_sha}.bin"

    if cache_file.exists():
        # NOTE: unpickling is only safe because the cache is produced by this very function;
        # do not point the cache folder at files coming from an untrusted source.
        with cache_file.open("rb") as cache_in:
            return pickle.load(cache_in)

    vname_to_level_erai = {
        T_AIR_2M: VerticalLevel(1, level_kinds.HYBRID),
        U_WE: VerticalLevel(1, level_kinds.HYBRID),
        V_SN: VerticalLevel(1, level_kinds.HYBRID),
    }

    # work on a copy so the shared default mapping is never modified
    vname_map = {}
    vname_map.update(vname_map_CRCM5)

    store_config = {
        DataManager.SP_BASE_FOLDER: data_source_base_dir,
        DataManager.SP_DATASOURCE_TYPE: data_source_types.SAMPLES_FOLDER_FROM_CRCM_OUTPUT,
        DataManager.SP_INTERNAL_TO_INPUT_VNAME_MAPPING: vname_map,
        DataManager.SP_LEVEL_MAPPING: vname_to_level_erai,
        DataManager.SP_OFFSET_MAPPING: vname_to_offset_CRCM5,
        DataManager.SP_MULTIPLIER_MAPPING: vname_to_multiplier_CRCM5,
        DataManager.SP_VARNAME_TO_FILENAME_PREFIX_MAPPING: vname_to_fname_prefix_CRCM5,
    }

    dm = DataManager(store_config=store_config)

    # the interpolation routine takes arrays of targets, here a single point
    lons_ = np.asarray([lon])
    lats_ = np.asarray([lat])

    data = dm.read_data_for_period_and_interpolate(
        period=period, varname_internal=varname,
        lons_target=lons_, lats_target=lats_
    )

    # close (and therefore flush) the cache file deterministically
    with cache_file.open("wb") as cache_out:
        pickle.dump(data, cache_out)
    return data
Exemplo n.º 2
0
def __get_maximum_storage_and_corresponding_dates(start_year:int, end_year:int, data_manager:DataManager, storage_varname=""):
    """
    Return the min/max/avg dataset of ``storage_varname`` for the given years, using a NetCDF cache.

    The cache file lives in the current working directory and is named after the year
    limits and the variable name. When it exists it is opened and returned; otherwise the
    statistics are requested from ``data_manager``, merged into one dataset, saved and returned.

    :param start_year: first year of the period
    :param end_year: last year of the period
    :param data_manager: object providing ``get_min_max_avg_for_period``
    :param storage_varname: internal name of the storage variable
    :return: xarray dataset
    """
    nc_cache = Path("cache_{}-{}_calculate_flood_storage_{}.nc".format(start_year, end_year, storage_varname))

    # reuse a previous result when one is on disk
    if nc_cache.exists():
        return xarray.open_dataset(str(nc_cache))

    stats = data_manager.get_min_max_avg_for_period(
        start_year=start_year, end_year=end_year, varname_internal=storage_varname
    )

    merged = xarray.merge(list(stats.values()))
    merged.to_netcdf(str(nc_cache))
    return merged
Exemplo n.º 3
0
def __get_maximum_storage_and_corresponding_dates(start_year: int,
                                                  end_year: int,
                                                  data_manager: DataManager,
                                                  storage_varname=""):
    """
    Load (or compute and cache) the min/max/avg fields of ``storage_varname``.

    The NetCDF cache file is looked up in the current working directory; its name
    encodes the year limits and the variable name.

    :param start_year: first year of the period
    :param end_year: last year of the period
    :param data_manager: object providing ``get_min_max_avg_for_period``
    :param storage_varname: internal name of the storage variable
    :return: xarray dataset, either opened from the cache or freshly merged
    """
    cache_path = Path("cache_{}-{}_calculate_flood_storage_{}.nc".format(
        start_year, end_year, storage_varname))

    if not cache_path.exists():
        # nothing cached yet: compute, merge into a single dataset and persist it
        fields = data_manager.get_min_max_avg_for_period(
            start_year=start_year,
            end_year=end_year,
            varname_internal=storage_varname)

        dataset = xarray.merge(list(fields.values()))
        dataset.to_netcdf(str(cache_path))
        return dataset

    # the variables were calculated already
    return xarray.open_dataset(str(cache_path))
Exemplo n.º 4
0
def calculate_lake_effect_snowfall(label_to_config, period=None):
    """
    Run the lake effect snowfall calculation for every configured data source.

    :param label_to_config: mapping label -> store configuration; a configuration may carry
        an "out_folder" entry, otherwise results go to the current folder (".")
    :param period:  The period of interest defined by the start and the end year of the period (inclusive);
        must expose a ``months_of_interest`` attribute
    """

    assert hasattr(period, "months_of_interest")

    for source_label, source_config in label_to_config.items():
        manager = DataManager(store_config=source_config)

        # fall back to the current folder when no output folder is configured
        target_folder = source_config["out_folder"] if "out_folder" in source_config else "."

        calculate_enh_lakeffect_snowfall_for_a_datasource(
            data_mngr=manager,
            label=source_label,
            period=period,
            out_folder=target_folder)
def main(label_to_data_path: dict, var_pairs: list,
         periods_info: CcPeriodsInfo,
         vname_display_names=None,
         season_to_months: dict = None,
         cur_label=common_params.crcm_nemo_cur_label,
         fut_label=common_params.crcm_nemo_fut_label,
         hles_region_mask=None, lakes_mask=None):
    """
    Compute and plot seasonal correlation fields for each variable pair, for the current and future runs.

    One panel is drawn per (variable pair, run label, season). The 0-level contour of the mean
    seasonal "TT" field is overlaid on each panel, so "TT" has to be among the variables
    listed in ``var_pairs``. The figure is saved into ``common_params.img_folder``.

    :param label_to_data_path: maps ``cur_label`` and ``fut_label`` to the base folders of the two runs
    :param var_pairs: pairs of internal variable names; each pair is used as a key into the
        correlation results, so pairs have to be hashable (e.g. tuples)
    :param periods_info: provides the year limits through ``get_cur_year_limits()`` and
        ``get_fut_year_limits()``
    :param vname_display_names: optional mapping variable name -> label used in the row titles;
        the raw variable name is used for variables that are not in the mapping
    :param season_to_months: mapping season name -> months; required in practice (it is
        iterated and measured with ``len``) despite the ``None`` default
    :param cur_label: label of the current climate run
    :param fut_label: label of the future climate run
    :param hles_region_mask: boolean mask of the region of interest (it is inverted with ``~``);
        when omitted, a mask selecting every grid point is used
    :param lakes_mask: passed through to ``calculate_correlations_and_pvalues``
    """
    # get a flat list of all the required variable names (unique, in order of first appearance)
    varnames = []
    for vpair in var_pairs:
        for v in vpair:
            if v not in varnames:
                varnames.append(v)

    print(f"Considering {varnames}, based on {var_pairs}")

    if vname_display_names is None:
        vname_display_names = {}

    varname_mapping = {v: v for v in varnames}
    level_mapping = {v: VerticalLevel(0) for v in
                     varnames}  # Does not really make a difference, since all variables are 2d

    comon_store_config = {
        DataManager.SP_DATASOURCE_TYPE: data_source_types.ALL_VARS_IN_A_FOLDER_IN_NETCDF_FILES,
        DataManager.SP_INTERNAL_TO_INPUT_VNAME_MAPPING: varname_mapping,
        DataManager.SP_LEVEL_MAPPING: level_mapping
    }

    cur_dm = DataManager(
        store_config=dict({DataManager.SP_BASE_FOLDER: label_to_data_path[cur_label]}, **comon_store_config)
    )

    fut_dm = DataManager(
        store_config=dict({DataManager.SP_BASE_FOLDER: label_to_data_path[fut_label]}, **comon_store_config)
    )

    cur_start_yr, cur_end_year = periods_info.get_cur_year_limits()
    fut_start_yr, fut_end_year = periods_info.get_fut_year_limits()

    # load coordinates in memory (cur_dm.lons / cur_dm.lats are used for plotting below)
    cur_dm.read_data_for_period(Period(datetime(cur_start_yr, 1, 1), datetime(cur_start_yr, 1, 2)), varname_internal=varnames[0])

    # get the data: label -> variable name -> season -> year -> field
    label_to_vname_to_season_to_data = {
        cur_label: {}, fut_label: {}
    }

    for vname in varnames:
        cur_means = cur_dm.get_seasonal_means(start_year=cur_start_yr, end_year=cur_end_year,
                                              season_to_months=season_to_months, varname_internal=vname)

        fut_means = fut_dm.get_seasonal_means(start_year=fut_start_yr, end_year=fut_end_year,
                                              season_to_months=season_to_months, varname_internal=vname)

        label_to_vname_to_season_to_data[cur_label][vname] = cur_means
        label_to_vname_to_season_to_data[fut_label][vname] = fut_means

    if hles_region_mask is None:
        # Take any field to get the grid shape: the nesting is label -> vname -> season -> year,
        # so all four levels have to be descended before reaching an array.
        first_season = next(iter(season_to_months))
        year_to_field = label_to_vname_to_season_to_data[cur_label][varnames[0]][first_season]
        data_field = next(iter(year_to_field.values()))
        # boolean, because the mask is inverted with "~" when plotting
        hles_region_mask = np.ones_like(data_field, dtype=bool)

    correlation_data = calculate_correlations_and_pvalues(var_pairs, label_to_vname_to_season_to_data,
                                                          season_to_months=season_to_months,
                                                          region_of_interest_mask=hles_region_mask,
                                                          lats=cur_dm.lats, lakes_mask=lakes_mask)

    # Calculate mean seasonal temperature
    label_to_season_to_tt_mean = {}
    for label, vname_to_season_to_data in label_to_vname_to_season_to_data.items():
        label_to_season_to_tt_mean[label] = {}
        for season, yearly_data in vname_to_season_to_data["TT"].items():
            label_to_season_to_tt_mean[label][season] = np.mean(list(yearly_data.values()), axis=0)

    # do the plotting: one column per season, one row per (variable pair, label)
    fig = plt.figure()

    ncols = len(season_to_months)
    nrows = len(var_pairs) * len(label_to_vname_to_season_to_data)

    gs = GridSpec(nrows, ncols, wspace=0, hspace=0)

    for col, season in enumerate(season_to_months):
        row = 0

        for vpair in var_pairs:
            for label in sorted(label_to_vname_to_season_to_data):
                ax = fig.add_subplot(gs[row, col], projection=cartopy.crs.PlateCarree())

                r, pv = correlation_data[vpair][label][season]

                # undefined correlations are shown as 0, points outside of the region are masked
                r[np.isnan(r)] = 0
                r = np.ma.masked_where(~hles_region_mask, r)
                ax.set_facecolor("0.75")

                # hide the ticks
                ax.xaxis.set_major_locator(NullLocator())
                ax.yaxis.set_major_locator(NullLocator())

                im = ax.pcolormesh(cur_dm.lons, cur_dm.lats, r, cmap=cm.get_cmap("bwr", 11), vmin=-1, vmax=1)

                # add 0 deg line
                cs = ax.contour(cur_dm.lons, cur_dm.lats, label_to_season_to_tt_mean[label][season], levels=[0,],
                                linewidths=1, colors="k")
                ax.set_extent([cur_dm.lons[0, 0], cur_dm.lons[-1, -1], cur_dm.lats[0, 0], cur_dm.lats[-1, -1]])

                ax.background_patch.set_facecolor("0.75")

                if row == 0:
                    # season name above the top row only
                    ax.text(0.5, 1.05, season, transform=ax.transAxes,
                            va="bottom", ha="center", multialignment="center")

                if col == 0:
                    # row title to the left of the first column; fall back to the raw
                    # variable name when no display name was provided
                    display_name = vname_display_names.get(vpair[1], vpair[1])
                    ax.text(-0.05, 0.5, f"HLES\nvs {display_name}\n{label}",
                            va="center", ha="right",
                            multialignment="center",
                            rotation=90,
                            transform=ax.transAxes)

                # A colorbar axes is created for every panel so all panels keep the same size,
                # but only the one of the bottom-right panel is left visible.
                divider = make_axes_locatable(ax)
                ax_cb = divider.new_horizontal(size="5%", pad=0.1, axes_class=plt.Axes)
                fig.add_axes(ax_cb)
                cb = plt.colorbar(im, extend="both", cax=ax_cb)

                if row < nrows - 1 or col < ncols - 1:
                    cb.ax.set_visible(False)

                row += 1

    img_dir = common_params.img_folder
    img_dir.mkdir(exist_ok=True)

    img_file = img_dir / "hles_tt_pr_correlation_fields_cur_and_fut_mean_ice_fraction.png"
    fig.savefig(str(img_file), **common_params.image_file_options)
Exemplo n.º 6
0
def main():
    """
    Compare simulated climatological quantiles with spatially aggregated Daymet observations.

    For every simulation in ``sim_paths`` and every variable in ``varnames_list`` the
    climatological quantiles are computed for the model output and for the observations
    (the latter interpolated to the model grid). The model, observed and difference fields
    are plotted with ``plot_monthly_panels`` into the "nei_validation" folder, and the fields
    gathered for all simulations are finally passed to ``plot_area_avg``.
    """
    # dask.set_options(pool=ThreadPool(20))
    img_folder = Path("nei_validation")
    img_folder.mkdir(parents=True, exist_ok=True)

    start_year = 1980
    end_year = 1998

    var_name_to_rolling_window_days = {
        default_varname_mappings.T_AIR_2M_DAILY_MIN: 5,
        default_varname_mappings.T_AIR_2M_DAILY_MAX: 5,
        default_varname_mappings.TOTAL_PREC: 29
    }

    var_name_to_percentile = {
        default_varname_mappings.T_AIR_2M_DAILY_MIN: 0.9,
        default_varname_mappings.T_AIR_2M_DAILY_MAX: 0.1,
        default_varname_mappings.TOTAL_PREC: 0.9,
    }

    # needed for the 3hourly temperature model outputs, when Tmin and Tmax daily are not available
    var_name_to_daily_agg_func = {
        default_varname_mappings.TOTAL_PREC: np.mean,
        default_varname_mappings.T_AIR_2M_DAILY_MAX: np.max,
        default_varname_mappings.T_AIR_2M_DAILY_MIN: np.min,
        default_varname_mappings.T_AIR_2M_DAILY_AVG: np.mean
    }

    # applied to the model fields only; variables that are not listed are left unscaled
    model_vname_to_multiplier = {
        default_varname_mappings.TOTAL_PREC: 1000 * 24 * 3600
    }

    WC_044_DEFAULT_LABEL = "WC_0.44deg_default"
    WC_044_CTEM_FRSOIL_DYNGLA_LABEL = "WC_0.44deg_ctem+frsoil+dyngla"
    WC_011_CTEM_FRSOIL_DYNGLA_LABEL = "WC_0.11deg_ctem+frsoil+dyngla"

    sim_paths = OrderedDict()
    sim_paths[WC_011_CTEM_FRSOIL_DYNGLA_LABEL] = Path(
        "/snow3/huziy/NEI/WC/NEI_WC0.11deg_Crr1/Samples")
    sim_paths[WC_044_DEFAULT_LABEL] = Path(
        "/snow3/huziy/NEI/WC/NEI_WC0.44deg_default/Samples")
    sim_paths[WC_044_CTEM_FRSOIL_DYNGLA_LABEL] = Path(
        "/snow3/huziy/NEI/WC/debug_NEI_WC0.44deg_Crr1/Samples")

    mod_spatial_scales = OrderedDict([(WC_044_DEFAULT_LABEL, 0.44),
                                      (WC_044_CTEM_FRSOIL_DYNGLA_LABEL, 0.44),
                                      (WC_011_CTEM_FRSOIL_DYNGLA_LABEL, 0.11)])

    # -- daymet daily (initial spatial res)
    # daymet_vname_to_path = {
    #     "prcp": "/snow3/huziy/Daymet_daily/daymet_v3_prcp_*_na.nc4",
    #     "tavg": "/snow3/huziy/Daymet_daily/daymet_v3_tavg_*_na.nc4",
    #     "tmin": "/snow3/huziy/Daymet_daily/daymet_v3_tmin_*_na.nc4",
    #     "tmax": "/snow3/huziy/Daymet_daily/daymet_v3_tmax_*_na.nc4",
    # }

    # -- daymet daily (spatially aggregated)
    daymet_vname_to_path = {
        default_varname_mappings.TOTAL_PREC:
        "/snow3/huziy/Daymet_daily_derivatives/daymet_spatial_agg_prcp_10x10",
        default_varname_mappings.T_AIR_2M_DAILY_AVG:
        "/snow3/huziy/Daymet_daily_derivatives/daymet_spatial_agg_tavg_10x10",
        default_varname_mappings.T_AIR_2M_DAILY_MIN:
        "/snow3/huziy/Daymet_daily_derivatives/daymet_spatial_agg_tmin_10x10",
        default_varname_mappings.T_AIR_2M_DAILY_MAX:
        "/snow3/huziy/Daymet_daily_derivatives/daymet_spatial_agg_tmax_10x10",
    }

    # the daily min and max are both derived from the same model temperature variable
    daymet_vname_to_model_vname_internal = {
        default_varname_mappings.T_AIR_2M_DAILY_MIN:
        default_varname_mappings.T_AIR_2M,
        default_varname_mappings.T_AIR_2M_DAILY_MAX:
        default_varname_mappings.T_AIR_2M,
        default_varname_mappings.TOTAL_PREC:
        default_varname_mappings.TOTAL_PREC,
    }

    plot_utils.apply_plot_params(font_size=8)

    # observations
    obs_spatial_scale = 0.1  # 10x10 aggregation from ~0.01 daymet data

    varnames_list = [
        default_varname_mappings.TOTAL_PREC,
        default_varname_mappings.T_AIR_2M_DAILY_MIN,
        default_varname_mappings.T_AIR_2M_DAILY_MAX
    ]

    data_dict = {vn: {} for vn in varnames_list}
    bias_dict = {vn: {} for vn in varnames_list}

    # calculate the percentiles for each simulation and obs data (obs data interpolated to the model grid)
    for model_label, base_dir in sim_paths.items():
        # model outputs manager
        dm = DataManager(
            store_config={
                DataManager.SP_BASE_FOLDER:
                base_dir,
                DataManager.SP_DATASOURCE_TYPE:
                data_source_types.SAMPLES_FOLDER_FROM_CRCM_OUTPUT,
                DataManager.SP_INTERNAL_TO_INPUT_VNAME_MAPPING:
                default_varname_mappings.vname_map_CRCM5,
                DataManager.SP_LEVEL_MAPPING:
                default_varname_mappings.vname_to_level_map,
                DataManager.SP_VARNAME_TO_FILENAME_PREFIX_MAPPING:
                default_varname_mappings.vname_to_fname_prefix_CRCM5
            })

        for vname_daymet in varnames_list:

            obs_manager = DataManager(
                store_config={
                    DataManager.SP_BASE_FOLDER:
                    daymet_vname_to_path[vname_daymet],
                    DataManager.SP_DATASOURCE_TYPE:
                    data_source_types.ALL_VARS_IN_A_FOLDER_IN_NETCDF_FILES,
                    DataManager.SP_INTERNAL_TO_INPUT_VNAME_MAPPING:
                    default_varname_mappings.daymet_vname_mapping,
                    DataManager.SP_LEVEL_MAPPING: {}
                })

            vname_model = daymet_vname_to_model_vname_internal[vname_daymet]

            nd_rw = var_name_to_rolling_window_days[vname_daymet]
            q = var_name_to_percentile[vname_daymet]
            daily_agg_func = var_name_to_daily_agg_func[vname_daymet]

            # model data
            # TODO: change for the number of summer days
            mod = dm.compute_climatological_quantiles(
                start_year=start_year,
                end_year=end_year,
                daily_agg_func=daily_agg_func,
                rolling_mean_window_days=nd_rw,
                q=q,
                varname_internal=vname_model)

            mod = mod * model_vname_to_multiplier.get(vname_model, 1)

            data_source_mod = f"{model_label}_ndrw{nd_rw}_q{q}_vn{vname_daymet}_{start_year}-{end_year}"

            # obs data: the number of neighbours follows the ratio of the model and
            # observation grid spacings, with at least one neighbour
            nneighbors = int(mod_spatial_scales[model_label] /
                             obs_spatial_scale)
            nneighbors = max(nneighbors, 1)

            obs = obs_manager.compute_climatological_quantiles(
                start_year=start_year,
                end_year=end_year,
                daily_agg_func=
                daily_agg_func,  # does not have effect for daymet data because it is daily
                rolling_mean_window_days=nd_rw,
                q=q,
                varname_internal=vname_daymet,
                lons_target=mod.coords["lon"].values,
                lats_target=mod.coords["lat"].values,
                nneighbors=nneighbors)

            # only use model data wherever the obs is not null
            mod = mod.where(obs.notnull())

            data_source_obs = f"DAYMETaggfor_{model_label}_ndrw{nd_rw}_q{q}_vn{vname_daymet}_{start_year}-{end_year}"

            data_source_diff = f"{model_label}vsDAYMET_ndrw{nd_rw}_q{q}_vn{vname_daymet}_{start_year}-{end_year}"

            # computed once, used both for the line plots and for the difference maps
            bias = mod - obs

            # save data for line plots
            data_dict[vname_daymet][data_source_mod] = mod
            data_dict[vname_daymet][data_source_obs] = obs
            bias_dict[vname_daymet][data_source_mod] = bias

            bmap = dm.get_basemap(varname_internal=vname_model,
                                  resolution="i",
                                  area_thresh=area_thresh_km2)

            # plot model data
            plot_monthly_panels(mod,
                                bmap,
                                img_dir=str(img_folder),
                                data_label=data_source_mod,
                                color_levels=clevs["mean"][vname_model],
                                cmap=cmaps["mean"][vname_model])

            # plot obs data
            plot_monthly_panels(obs,
                                bmap,
                                img_dir=str(img_folder),
                                data_label=data_source_obs,
                                color_levels=clevs["mean"][vname_model],
                                cmap=cmaps["mean"][vname_model])

            # plot the model - obs differences (separate color levels and colormap)
            plot_monthly_panels(bias,
                                bmap,
                                img_dir=str(img_folder),
                                data_label=data_source_diff,
                                color_levels=clevs["mean"][vname_model +
                                                           "diff"],
                                cmap=cmaps["mean"][vname_model + "diff"])

    for vn in data_dict:

        # nothing was computed for this variable
        if not data_dict[vn]:
            continue

        plot_area_avg(data_dict[vn],
                      bias_dict[vn],
                      panel_titles=(vn, ""),
                      img_dir=img_folder / "extremes_1d")
Exemplo n.º 7
0
def get_seasonal_sst_from_crcm5_outputs(sim_label,
                                        start_year=1980,
                                        end_year=2010,
                                        season_to_months=None,
                                        lons_target=None,
                                        lats_target=None):
    """
    Return the multi-year mean lake water temperature field for each season.

    The data are read with a DataManager from a fixed samples folder. When ``lons_target`` is
    given, each seasonal field is mapped to the target points by a nearest-neighbour lookup in
    the data manager's kd-tree and reshaped to ``lons_target.shape``.

    :param sim_label: label stored in the run configuration
    :param start_year: first year of the period
    :param end_year: last year of the period
    :param season_to_months: mapping season name -> months; it is iterated, so it is required
        in practice despite the ``None`` default
    :param lons_target: optional array of target longitudes
    :param lats_target: array of target latitudes; used only when ``lons_target`` is given
    :return: dict season name -> mean field
    """

    from lake_effect_snow.default_varname_mappings import T_AIR_2M, U_WE, V_SN
    from lake_effect_snow.base_utils import VerticalLevel
    from rpn import level_kinds
    from lake_effect_snow import default_varname_mappings
    from data.robust import data_source_types

    from data.robust.data_manager import DataManager

    run_cfg = RunConfig(
        data_path=
        "/RECH2/huziy/coupling/GL_440x260_0.1deg_GL_with_Hostetler/Samples_selected",
        start_year=start_year,
        end_year=end_year,
        label=sim_label)

    # the three atmospheric variables share the same kind of level, the lake temperature does not
    levels_by_vname = {
        name: VerticalLevel(1, level_kinds.HYBRID)
        for name in (T_AIR_2M, U_WE, V_SN)
    }
    levels_by_vname[default_varname_mappings.LAKE_WATER_TEMP] = VerticalLevel(
        1, level_kinds.ARBITRARY)

    # copy, so the shared default mapping is not handed over directly
    name_mapping = {}
    name_mapping.update(default_varname_mappings.vname_map_CRCM5)

    dm = DataManager(store_config={
        "base_folder": run_cfg.data_path,
        "data_source_type":
        data_source_types.SAMPLES_FOLDER_FROM_CRCM_OUTPUT_VNAME_IN_FNAME,
        "varname_mapping": name_mapping,
        "level_mapping": levels_by_vname,
        "offset_mapping": default_varname_mappings.vname_to_offset_CRCM5,
        "multiplier_mapping":
        default_varname_mappings.vname_to_multiplier_CRCM5,
    })

    yearly_means_by_season = dm.get_seasonal_means(
        start_year=start_year,
        end_year=end_year,
        season_to_months=season_to_months,
        varname_internal=default_varname_mappings.LAKE_WATER_TEMP)

    # average the yearly fields of each season
    seasonal_means = {
        season_name:
        np.array(list(yearly_means_by_season[season_name].values())).mean(axis=0)
        for season_name in season_to_months
    }

    # no target grid requested: the fields stay on the source grid
    if lons_target is None:
        return seasonal_means

    # nearest-neighbour interpolation to the target points
    xt, yt, zt = lat_lon.lon_lat_to_cartesian(lons_target.flatten(),
                                              lats_target.flatten())
    dists, inds = dm.get_kdtree().query(list(zip(xt, yt, zt)))

    for season_name in season_to_months:
        flat_field = seasonal_means[season_name].flatten()
        seasonal_means[season_name] = flat_field[inds].reshape(lons_target.shape)

    return seasonal_means
def main():
    """
    Plot meridional means of simulated and observed (Daymet) climatological quantiles.

    For every simulation in ``sim_paths`` and every variable in ``varnames_list`` the
    climatological quantiles are computed for the model output and for the observations
    (interpolated to the model grid), both are cropped to ``subregion``, and the fields and
    their differences are collected. ``plot_meridional_mean`` is then called for the whole
    year and for each season, with images going to "nei_validation/meridional_avg".
    """
    # dask.set_options(pool=ThreadPool(20))
    img_folder = Path("nei_validation/meridional_avg")
    img_folder.mkdir(parents=True, exist_ok=True)

    start_year = 1980
    end_year = 2010

    subregion = SubRegionByLonLatCorners(lleft={"lon": -128, "lat": 46}, uright={"lon": -113, "lat": 55})

    season_to_months = {
        "DJF": [12, 1, 2],
        "MAM": range(3, 6),
        "JJA": range(6, 9),
        "SON": range(9, 12)
    }

    var_name_to_rolling_window_days = {
        default_varname_mappings.T_AIR_2M_DAILY_MIN: 5,
        default_varname_mappings.T_AIR_2M_DAILY_MAX: 5,
        default_varname_mappings.TOTAL_PREC: 29
    }

    var_name_to_percentile = {
        default_varname_mappings.T_AIR_2M_DAILY_MIN: 0.9,
        default_varname_mappings.T_AIR_2M_DAILY_MAX: 0.1,
        default_varname_mappings.TOTAL_PREC: 0.9,
    }

    # needed for the 3hourly temperature model outputs, when Tmin and Tmax daily are not available
    var_name_to_daily_agg_func = {
        default_varname_mappings.TOTAL_PREC: np.mean,
        default_varname_mappings.T_AIR_2M_DAILY_MAX: np.max,
        default_varname_mappings.T_AIR_2M_DAILY_MIN: np.min,
        default_varname_mappings.T_AIR_2M_DAILY_AVG: np.mean
    }

    var_name_to_display_units = {
        default_varname_mappings.TOTAL_PREC: "mm/day",
        default_varname_mappings.T_AIR_2M_DAILY_MAX: r"$^\circ$C",
        default_varname_mappings.T_AIR_2M_DAILY_MIN: r"$^\circ$C",
        default_varname_mappings.T_AIR_2M_DAILY_AVG: r"$^\circ$C"
    }

    # applied to the model fields only; variables that are not listed are left unscaled
    model_vname_to_multiplier = {
        default_varname_mappings.TOTAL_PREC: 1000 * 24 * 3600
    }

    WC_044_DEFAULT_LABEL = "WC_044_default"
    WC_044_CTEM_FRSOIL_DYNGLA_LABEL = "WC_044_modified"
    WC_011_CTEM_FRSOIL_DYNGLA_LABEL = "WC_011_modified"

    sim_paths = OrderedDict()
    sim_paths[WC_044_DEFAULT_LABEL] = Path("/snow3/huziy/NEI/WC/NEI_WC0.44deg_default/Samples")
    sim_paths[WC_044_CTEM_FRSOIL_DYNGLA_LABEL] = Path("/snow3/huziy/NEI/WC/NEI_WC0.44deg_Crr1/Samples")
    sim_paths[WC_011_CTEM_FRSOIL_DYNGLA_LABEL] = Path("/snow3/huziy/NEI/WC/NEI_WC0.11deg_Crr1/Samples")

    elevation_paths = OrderedDict()
    elevation_paths[WC_044_DEFAULT_LABEL] = "/snow3/huziy/NEI/WC/NEI_WC0.44deg_default/geophys_CORDEX_NA_0.44d_filled_hwsd_dpth_om_MODIS_Glacier_v2_newdirs"
    elevation_paths[WC_044_CTEM_FRSOIL_DYNGLA_LABEL] = "/snow3/huziy/NEI/WC/NEI_WC0.44deg_Crr1/geophys_CORDEX_NA_0.44d_filled_hwsd_dpth_om_MODIS_Glacier_v2_dirs_hshedsfix_CTEM_FRAC_GlVolFix"
    elevation_paths[WC_011_CTEM_FRSOIL_DYNGLA_LABEL] = "/snow3/huziy/NEI/WC/NEI_WC0.11deg_Crr1/nei_geophy_wc_011.rpn"

    mod_spatial_scales = OrderedDict([
        (WC_044_DEFAULT_LABEL, 0.44),
        (WC_044_CTEM_FRSOIL_DYNGLA_LABEL, 0.44),
        (WC_011_CTEM_FRSOIL_DYNGLA_LABEL, 0.11)
    ])

    # -- daymet daily (initial spatial res)
    # daymet_vname_to_path = {
    #     "prcp": "/snow3/huziy/Daymet_daily/daymet_v3_prcp_*_na.nc4",
    #     "tavg": "/snow3/huziy/Daymet_daily/daymet_v3_tavg_*_na.nc4",
    #     "tmin": "/snow3/huziy/Daymet_daily/daymet_v3_tmin_*_na.nc4",
    #     "tmax": "/snow3/huziy/Daymet_daily/daymet_v3_tmax_*_na.nc4",
    # }

    # -- daymet daily (spatially aggregated)
    daymet_vname_to_path = {
        default_varname_mappings.TOTAL_PREC: "/snow3/huziy/Daymet_daily_derivatives/daymet_spatial_agg_prcp_10x10",
        default_varname_mappings.T_AIR_2M_DAILY_AVG: "/snow3/huziy/Daymet_daily_derivatives/daymet_spatial_agg_tavg_10x10",
        default_varname_mappings.T_AIR_2M_DAILY_MIN: "/snow3/huziy/Daymet_daily_derivatives/daymet_spatial_agg_tmin_10x10",
        default_varname_mappings.T_AIR_2M_DAILY_MAX: "/snow3/huziy/Daymet_daily_derivatives/daymet_spatial_agg_tmax_10x10",
    }

    # the daily min and max are both derived from the same model temperature variable
    daymet_vname_to_model_vname_internal = {
        default_varname_mappings.T_AIR_2M_DAILY_MIN: default_varname_mappings.T_AIR_2M,
        default_varname_mappings.T_AIR_2M_DAILY_MAX: default_varname_mappings.T_AIR_2M,
        default_varname_mappings.TOTAL_PREC: default_varname_mappings.TOTAL_PREC,
    }

    plot_utils.apply_plot_params(font_size=14)

    # observations
    obs_spatial_scale = 0.1  # 10x10 aggregation from ~0.01 daymet data

    varnames_list = [
        default_varname_mappings.TOTAL_PREC,
        default_varname_mappings.T_AIR_2M_DAILY_MIN,
        default_varname_mappings.T_AIR_2M_DAILY_MAX
    ]

    data_dict = {vn: {} for vn in varnames_list}
    bias_dict = {vn: {} for vn in varnames_list}

    # created once, from the first simulation and variable processed
    bmap = None
    # calculate the percentiles for each simulation and obs data (obs data interpolated to the model grid)
    for model_label, base_dir in sim_paths.items():
        # model outputs manager
        dm = DataManager(
            store_config={
                DataManager.SP_BASE_FOLDER: base_dir,
                DataManager.SP_DATASOURCE_TYPE: data_source_types.SAMPLES_FOLDER_FROM_CRCM_OUTPUT,
                DataManager.SP_INTERNAL_TO_INPUT_VNAME_MAPPING: default_varname_mappings.vname_map_CRCM5,
                DataManager.SP_LEVEL_MAPPING: default_varname_mappings.vname_to_level_map,
                DataManager.SP_VARNAME_TO_FILENAME_PREFIX_MAPPING: default_varname_mappings.vname_to_fname_prefix_CRCM5
            }
        )

        for vname_daymet in varnames_list:

            obs_manager = DataManager(
                store_config={
                    DataManager.SP_BASE_FOLDER: daymet_vname_to_path[vname_daymet],
                    DataManager.SP_DATASOURCE_TYPE: data_source_types.ALL_VARS_IN_A_FOLDER_IN_NETCDF_FILES,
                    DataManager.SP_INTERNAL_TO_INPUT_VNAME_MAPPING: default_varname_mappings.daymet_vname_mapping,
                    DataManager.SP_LEVEL_MAPPING: {}
                }
            )

            vname_model = daymet_vname_to_model_vname_internal[vname_daymet]

            nd_rw = var_name_to_rolling_window_days[vname_daymet]
            q = var_name_to_percentile[vname_daymet]
            daily_agg_func = var_name_to_daily_agg_func[vname_daymet]

            # model data
            mod = dm.compute_climatological_quantiles(start_year=start_year, end_year=end_year,
                                                      daily_agg_func=daily_agg_func,
                                                      rolling_mean_window_days=nd_rw,
                                                      q=q,
                                                      varname_internal=vname_model)

            mod = mod * model_vname_to_multiplier.get(vname_model, 1)

            data_source_mod = f"{model_label}_ndrw{nd_rw}_q{q}_vn{vname_daymet}_{start_year}-{end_year}"

            # obs data: the number of neighbours follows the ratio of the model and
            # observation grid spacings, with at least one neighbour
            nneighbors = int(mod_spatial_scales[model_label] / obs_spatial_scale)
            nneighbors = max(nneighbors, 1)

            obs = obs_manager.compute_climatological_quantiles(start_year=start_year,
                                                               end_year=end_year,
                                                               daily_agg_func=daily_agg_func,  # does not have effect for daymet data because it is daily
                                                               rolling_mean_window_days=nd_rw,
                                                               q=q,
                                                               varname_internal=vname_daymet,
                                                               lons_target=mod.coords["lon"].values,
                                                               lats_target=mod.coords["lat"].values,
                                                               nneighbors=nneighbors)

            # only use model data wherever the obs is not null
            mod = mod.where(obs.notnull())

            data_source_obs = f"DAYMETaggfor_{model_label}_ndrw{nd_rw}_q{q}_vn{vname_daymet}_{start_year}-{end_year}"

            # crop both fields to the index box of the subregion (corner indices are inclusive)
            _, ij_ll, ij_ur = subregion.to_mask(mod.coords["lon"].values, mod.coords["lat"].values)

            mod = mod[:, ij_ll[0]:ij_ur[0] + 1, ij_ll[1]:ij_ur[1] + 1]
            obs = obs[:, ij_ll[0]:ij_ur[0] + 1, ij_ll[1]:ij_ur[1] + 1]

            # set the units to display them during plotting
            mod.attrs["units"] = var_name_to_display_units[vname_daymet]
            obs.attrs["units"] = var_name_to_display_units[vname_daymet]

            # save data for line plots
            data_dict[vname_daymet][data_source_mod] = mod
            data_dict[vname_daymet][data_source_obs] = obs
            bias_dict[vname_daymet][data_source_mod] = mod - obs

            if bmap is None:
                bmap = dm.get_basemap(varname_internal=vname_model, resolution="i", area_thresh=area_thresh_km2)

    # Panel titles explaining what the graphs mean
    vn_to_title = {
        default_varname_mappings.TOTAL_PREC: "PR90",
        default_varname_mappings.T_AIR_2M_DAILY_MIN: "TN90",
        default_varname_mappings.T_AIR_2M_DAILY_MAX: "TX10"
    }

    elev_field_name = "ME"
    meridional_mean_elev_dict = get_meridional_avg_elevation(geo_path_dict=elevation_paths,
                                                             subregion=subregion,
                                                             elev_field_name=elev_field_name)

    topo_map = get_topo_map(geo_path=elevation_paths[WC_011_CTEM_FRSOIL_DYNGLA_LABEL], elev_field_name=elev_field_name)

    for vn in data_dict:

        # nothing was computed for this variable
        if not data_dict[vn]:
            continue

        # annual plot, with the topography map
        plot_meridional_mean(data_dict[vn], bias_dict[vn], panel_titles=(vn_to_title[vn] + " (annual)", ""),
                             img_dir=img_folder, bmap=bmap, meridional_elev_dict=meridional_mean_elev_dict,
                             map_topo=topo_map)

        # seasonal plots; the legend is requested for a single (variable, season) combination
        for sname, months in season_to_months.items():
            plot_meridional_mean(data_dict[vn], bias_dict[vn], panel_titles=("", ""),
                                 img_dir=img_folder, bmap=bmap,
                                 months=months, season_name=sname,
                                 meridional_elev_dict=meridional_mean_elev_dict,
                                 map_topo=None, plot_values=False,
                                 lon_min=236, lon_max=247,
                                 plot_legend=(vn == default_varname_mappings.T_AIR_2M_DAILY_MAX) and (sname == "SON")
                                 )
def main():
    """
    Validate simulated streamflow against station observations.

    One panel per station is drawn, showing the observed and the modelled
    daily climatologies, and the figure is saved into ``img_folder``.

    Workflow:
      1. read flow directions, accumulation areas and coordinates from the
         model output file and match the stations to model grid points;
      2. shorten the validation period if the observations end earlier;
      3. read the simulated streamflow ("STFA") year by year and resample it
         to daily means;
      4. compute the daily climatologies of the observed and modelled series
         over the years listed by ``get_list_of_complete_years`` that fall
         into the validation period;
      5. plot and save the figure.

    If no station has data in the validation period, nothing is plotted.
    """
    direction_file_path = Path(
        "/RECH2/huziy/BC-MH/bc_mh_044deg/Samples/bc_mh_044deg_198001/pm1980010100_00000000p"
    )

    sim_label = "mh_0.44"

    start_year = 1981
    end_year = 2010

    streamflow_internal_name = "streamflow"
    selected_station_ids = constants.selected_station_ids_for_streamflow_validation

    # ======================================================

    # 365 consecutive days of 2001 serve as the common time axis of the climatologies
    day = timedelta(days=1)
    t0 = datetime(2001, 1, 1)
    stamp_dates = [t0 + i * day for i in range(365)]
    print("stamp dates range {} ... {}".format(stamp_dates[0],
                                               stamp_dates[-1]))

    lake_fraction = None

    # establish the correspondence between the stations and model grid points
    with RPN(str(direction_file_path)) as r:
        assert isinstance(r, RPN)
        fldir = r.get_first_record_for_name("FLDR")
        flow_acc_area = r.get_first_record_for_name("FAA")
        lons, lats = r.get_longitudes_and_latitudes_for_the_last_read_rec()
        # lake_fraction = r.get_first_record_for_name("LF1")

    cell_manager = CellManager(fldir,
                               lons2d=lons,
                               lats2d=lats,
                               accumulation_area_km2=flow_acc_area)
    stations = stfl_stations.load_stations_from_csv(
        selected_ids=selected_station_ids)
    station_to_model_point = cell_manager.get_model_points_for_stations(
        station_list=stations, lake_fraction=lake_fraction, nneighbours=8)

    # Update the end year if required: there is no point in reading model
    # data after the last year available in the observations.
    # ``default=-1`` keeps stations without any complete year from raising.
    max_year_st = max(
        (max(station.get_list_of_complete_years(), default=-1)
         for station in station_to_model_point),
        default=-1)

    if end_year > max_year_st:
        print("Updated end_year to {}, because no obs data after...".format(
            max_year_st))
        end_year = max_year_st

    # read model data
    mod_data_manager = DataManager(
        store_config={
            "varname_mapping": {
                streamflow_internal_name: "STFA"
            },
            "base_folder": str(direction_file_path.parent.parent),
            "data_source_type":
            data_source_types.SAMPLES_FOLDER_FROM_CRCM_OUTPUT,
            "level_mapping": {
                streamflow_internal_name:
                VerticalLevel(-1, level_type=level_kinds.ARBITRARY)
            },
            "offset_mapping": vname_to_offset_CRCM5,
            "filename_prefix_mapping": {
                streamflow_internal_name: "pm"
            }
        })

    station_to_model_data = defaultdict(list)
    for year in range(start_year, end_year + 1):
        # the period covers exactly one calendar year
        start = Pendulum(year, 1, 1)
        p_test = Period(start, start.add(years=1).subtract(microseconds=1))
        stfl_mod = mod_data_manager.read_data_for_period(
            p_test, streamflow_internal_name)

        # convert to daily
        # NOTE(review): this is the legacy ``how=`` form of xarray's resample;
        # confirm that the installed xarray version still accepts it.
        stfl_mod = stfl_mod.resample("D",
                                     "t",
                                     how="mean",
                                     closed="left",
                                     keep_attrs=True)

        assert isinstance(stfl_mod, xr.DataArray)

        for station, model_point in station_to_model_point.items():
            assert isinstance(model_point, ModelPoint)
            ts1 = stfl_mod[:, model_point.ix, model_point.jy].to_series()
            station_to_model_data[station].append(
                pd.Series(index=stfl_mod.t.values, data=ts1))

    # concatenate the timeseries for each point, if required
    if end_year - start_year + 1 > 1:
        for station in station_to_model_data:
            station_to_model_data[station] = pd.concat(
                station_to_model_data[station])
    else:
        for station in station_to_model_data:
            station_to_model_data[station] = station_to_model_data[station][0]

    # calculate observed climatology (stations ordered from north to south)
    station_to_climatology = OrderedDict()
    for s in sorted(station_to_model_point,
                    key=lambda st: st.latitude,
                    reverse=True):
        assert isinstance(s, Station)
        print(s.id, len(s.get_list_of_complete_years()))

        # Check if there are continuous years for the selected period
        common_years = set(s.get_list_of_complete_years()).intersection(
            set(range(start_year, end_year + 1)))
        if len(common_years) > 0:
            _, station_to_climatology[
                s] = s.get_daily_climatology_for_complete_years_with_pandas(
                    stamp_dates=stamp_dates, years=common_years)

            # the modelled climatology uses the same years as the observed one
            _, station_to_model_data[
                s] = pandas_utils.get_daily_climatology_from_pandas_series(
                    station_to_model_data[s],
                    stamp_dates,
                    years_of_interest=common_years)

        else:
            print(
                "Skipping {}, since it does not have enough data during the period of interest"
                .format(s.id))

    # Without at least one station the grid below would have zero rows and
    # the legend lookup would fail, so stop here.
    if not station_to_climatology:
        print("No stations with data for {}-{}, nothing to plot".format(
            start_year, end_year))
        return

    # ---- Do the plotting ----
    ncols = 4

    # number of rows needed to fit all the panels (ceiling division)
    nrows, n_leftover = divmod(len(station_to_climatology), ncols)
    if n_leftover:
        nrows += 1

    axes_list = []
    plot_utils.apply_plot_params(width_cm=8 * ncols,
                                 height_cm=8 * nrows,
                                 font_size=8)
    fig = plt.figure()
    gs = GridSpec(nrows=nrows, ncols=ncols)

    for i, (s, clim) in enumerate(station_to_climatology.items()):
        assert isinstance(s, Station)

        row, col = divmod(i, ncols)

        print(row, col, nrows, ncols)

        # normalize by the drainage area
        if s.drainage_km2 is not None:
            station_to_model_data[
                s] *= s.drainage_km2 / station_to_model_point[
                    s].accumulation_area

        if s.id in constants.stations_to_greyout:
            ax = fig.add_subplot(gs[row, col], facecolor="0.45")
        else:
            ax = fig.add_subplot(gs[row, col])

        assert isinstance(ax, Axes)

        ax.plot(stamp_dates, clim, color="k", lw=2, label="Obs.")
        ax.plot(stamp_dates,
                station_to_model_data[s],
                color="r",
                lw=2,
                label="Mod.")
        # month labels are centered (15th), minor ticks mark the month starts
        ax.xaxis.set_major_formatter(FuncFormatter(format_month_label))
        ax.xaxis.set_major_locator(MonthLocator(bymonthday=15))
        ax.xaxis.set_minor_locator(MonthLocator(bymonthday=1))
        ax.grid()

        ax.annotate(s.get_pp_name(),
                    xy=(1.02, 1),
                    xycoords="axes fraction",
                    horizontalalignment="left",
                    verticalalignment="top",
                    fontsize=8,
                    rotation=-90)

        # extend the x-axis to the last day of the last month
        last_date = stamp_dates[-1]
        last_date = last_date.replace(
            day=calendar.monthrange(last_date.year, last_date.month)[1])

        ax.set_xlim(stamp_dates[0].replace(day=1), last_date)

        _, ymax = ax.get_ylim()
        ax.set_ylim(0, ymax)

        # raw strings: "\c" in "$^\circ$" is not a valid escape sequence
        if s.drainage_km2 is not None:
            ax.set_title(
                r"{}: ({:.1f}$^\circ$E, {:.1f}$^\circ$N, DA={:.0f} km$^2$)".
                format(s.id, s.longitude, s.latitude, s.drainage_km2))
        else:
            ax.set_title(
                r"{}: ({:.1f}$^\circ$E, {:.1f}$^\circ$N, DA not used)".format(
                    s.id, s.longitude, s.latitude))
        axes_list.append(ax)

    # plot the legend
    axes_list[-1].legend()

    img_folder.mkdir(parents=True, exist_ok=True)

    fig.tight_layout()
    img_file = img_folder / "{}_{}-{}_{}.png".format(
        sim_label, start_year, end_year, "-".join(
            sorted(s.id for s in station_to_climatology)))

    print("Saving {}".format(img_file))
    fig.savefig(str(img_file), bbox_inches="tight", dpi=300)
def main():
    """
    Plot seasonal mean fields of two simulations and their differences.

    For each variable of interest a figure is produced with one row per
    simulation (CRCM5_HL and CRCM5_NEMO) and a last row showing the
    NEMO - HL difference; columns correspond to seasons. Differences with a
    Welch t-test p-value above ``p_crit`` are masked. One png per variable
    is saved into ``img_folder``.
    """
    start_year = 1980
    end_year = 2009

    HL_LABEL = "CRCM5_HL"
    NEMO_LABEL = "CRCM5_NEMO"

    # critical p-value for the ttest aka significance level
    # (with 1 no difference is masked as non-significant)
    p_crit = 1

    vars_of_interest = [
        # T_AIR_2M,
        # TOTAL_PREC,
        # SWE,
        default_varname_mappings.LATENT_HF,
        default_varname_mappings.SENSIBLE_HF,
        default_varname_mappings.LWRAD_DOWN,
        default_varname_mappings.SWRAD_DOWN
        #       LAKE_ICE_FRACTION
    ]

    coastline_width = 0.3

    vname_to_seasonmonths_map = {
        SWE: OrderedDict([("November", [11]),
                          ("December", [12]),
                          ("January", [1, ])]),
        LAKE_ICE_FRACTION: OrderedDict([
            ("December", [12]),
            ("January", [1, ]),
            ("February", [2, ]),
            ("March", [3, ]),
            ("April", [4, ])]),
        T_AIR_2M: season_to_months,
        TOTAL_PREC: season_to_months,
    }

    # variables without a dedicated entry use the default seasons
    for vname in vars_of_interest:
        if vname not in vname_to_seasonmonths_map:
            vname_to_seasonmonths_map[vname] = season_to_months

    sim_configs = {
        HL_LABEL: RunConfig(data_path="/RECH2/huziy/coupling/GL_440x260_0.1deg_GL_with_Hostetler/Samples_selected",
                            start_year=start_year, end_year=end_year, label=HL_LABEL),

        NEMO_LABEL: RunConfig(data_path="/RECH2/huziy/coupling/coupled-GL-NEMO1h_30min/selected_fields",
                              start_year=start_year, end_year=end_year, label=NEMO_LABEL),
    }

    # fixes the order of the rows in the figures
    sim_labels = [HL_LABEL, NEMO_LABEL]

    vname_to_level = {
        T_AIR_2M: VerticalLevel(1, level_kinds.HYBRID),
        U_WE: VerticalLevel(1, level_kinds.HYBRID),
        V_SN: VerticalLevel(1, level_kinds.HYBRID),
        default_varname_mappings.LATENT_HF: VerticalLevel(5, level_kinds.ARBITRARY),
        default_varname_mappings.SENSIBLE_HF: VerticalLevel(5, level_kinds.ARBITRARY),
    }

    # Get the land_fraction for masking (only applied to SWE below).
    # Any failure to read it propagates to the caller.
    first_ts_file = Path(sim_configs[HL_LABEL].data_path).parent / "pm1979010100_00000000p"
    land_fraction = get_land_fraction(first_timestep_file=first_ts_file)

    # Calculations

    # prepare params for interpolation
    lons_t, lats_t, bsmap = get_target_lons_lats_basemap(sim_configs[HL_LABEL])

    # get a subdomain of the simulation domain
    # (the bounds are grid indices, hence the conversion to int)
    nx, ny = lons_t.shape
    iss = IndexSubspace(i_start=20, j_start=10, i_end=int(nx // 1.5), j_end=int(ny / 1.8))
    # just to change basemap limits
    lons_t, lats_t, bsmap = get_target_lons_lats_basemap(sim_configs[HL_LABEL], sub_space=iss)

    xt, yt, zt = lat_lon.lon_lat_to_cartesian(lons_t.flatten(), lats_t.flatten())

    vname_map = {}
    vname_map.update(default_varname_mappings.vname_map_CRCM5)

    # Read and calculate simulated seasonal means
    mod_label_to_vname_to_season_to_std = {}
    mod_label_to_vname_to_season_to_nobs = {}

    sim_data = defaultdict(dict)
    for label, r_config in sim_configs.items():

        store_config = {
            "base_folder": r_config.data_path,
            "data_source_type": data_source_types.SAMPLES_FOLDER_FROM_CRCM_OUTPUT_VNAME_IN_FNAME,
            "varname_mapping": vname_map,
            "level_mapping": vname_to_level,
            "offset_mapping": default_varname_mappings.vname_to_offset_CRCM5,
            "multiplier_mapping": default_varname_mappings.vname_to_multiplier_CRCM5,
        }

        dm = DataManager(store_config=store_config)

        mod_label_to_vname_to_season_to_std[label] = {}
        mod_label_to_vname_to_season_to_nobs[label] = {}

        # nearest-neighbour indices source grid -> target grid, computed once per simulation
        interp_indices = None
        for vname in vars_of_interest:

            # --
            end_year_for_current_var = end_year
            if vname == SWE:
                end_year_for_current_var = min(1996, end_year)

            # --
            seas_to_year_to_mean = dm.get_seasonal_means(varname_internal=vname,
                                                         start_year=start_year,
                                                         end_year=end_year_for_current_var,
                                                         season_to_months=vname_to_seasonmonths_map[vname])

            # get the climatology
            seas_to_clim = {seas: np.array(list(y_to_means.values())).mean(axis=0) for seas, y_to_means in
                            seas_to_year_to_mean.items()}

            # the dict is shared with sim_data, so the interpolated fields
            # assigned below are what the plotting code sees
            sim_data[label][vname] = seas_to_clim

            if interp_indices is None:
                _, interp_indices = dm.get_kdtree().query(list(zip(xt, yt, zt)))

            season_to_std = {}
            mod_label_to_vname_to_season_to_std[label][vname] = season_to_std

            season_to_nobs = {}
            mod_label_to_vname_to_season_to_nobs[label][vname] = season_to_nobs

            for season in seas_to_clim:
                interpolated_field = seas_to_clim[season].flatten()[interp_indices].reshape(lons_t.shape)
                seas_to_clim[season] = interpolated_field

                # calculate standard deviations of the interpolated fields
                season_to_std[season] = np.asarray([field.flatten()[interp_indices].reshape(lons_t.shape) for field in
                                                    seas_to_year_to_mean[season].values()]).std(axis=0)

                # calculate numobs for the ttest
                season_to_nobs[season] = np.ones_like(lons_t) * len(seas_to_year_to_mean[season])

    # Plotting: interpolate to the same grid and plot obs and biases
    xx, yy = bsmap(lons_t, lats_t)
    # longitudes are shifted to -180..180 before being passed to maskoceans
    lons_t[lons_t > 180] -= 360

    # make sure the output folder exists before saving the figures
    img_folder.mkdir(parents=True, exist_ok=True)

    for vname in vars_of_interest:

        field_mask = maskoceans(lons_t, lats_t, np.zeros_like(lons_t), inlands=vname in [SWE]).mask
        field_mask_lakes = maskoceans(lons_t, lats_t, np.zeros_like(lons_t), inlands=True).mask

        plot_utils.apply_plot_params(width_cm=11 * len(vname_to_seasonmonths_map[vname]), height_cm=20, font_size=8)

        fig = plt.figure()

        # one row per simulation and the last row for the differences
        nrows = len(sim_configs) + 1
        ncols = len(vname_to_seasonmonths_map[vname])
        gs = GridSpec(nrows=nrows, ncols=ncols)

        # plot the fields
        for current_row, sim_label in enumerate(sim_labels):
            for col, season in enumerate(vname_to_seasonmonths_map[vname]):

                field = sim_data[sim_label][vname][season]

                ax = fig.add_subplot(gs[current_row, col])

                if current_row == 0:
                    ax.set_title(season)

                clevs = get_clevs(vname)
                if clevs is not None:
                    bnorm = BoundaryNorm(clevs, len(clevs) - 1)
                    cmap = cm.get_cmap("viridis", len(clevs) - 1)
                else:
                    cmap = "viridis"
                    bnorm = None

                the_mask = field_mask_lakes if vname in [T_AIR_2M, TOTAL_PREC, SWE] else field_mask
                to_plot = np.ma.masked_where(the_mask, field) * internal_name_to_multiplier[vname]

                # temporary plot the actual values
                cs = bsmap.contourf(xx, yy, to_plot, ax=ax, levels=clevs, cmap=cmap, norm=bnorm, extend="both")
                bsmap.drawcoastlines(linewidth=coastline_width)
                bsmap.colorbar(cs, ax=ax)

                if col == 0:
                    ax.set_ylabel("{}".format(sim_label))

        # plot differences between the fields
        for col, season in enumerate(vname_to_seasonmonths_map[vname]):

            field = sim_data[NEMO_LABEL][vname][season] - sim_data[HL_LABEL][vname][season]

            ax = fig.add_subplot(gs[-1, col])

            clevs = get_clevs(vname + "biasdiff")
            if clevs is not None:
                bnorm = BoundaryNorm(clevs, len(clevs) - 1)
                cmap = cm.get_cmap("bwr", len(clevs) - 1)
            else:
                cmap = "bwr"
                bnorm = None

            to_plot = field * internal_name_to_multiplier[vname]
            # to_plot = np.ma.masked_where(field_mask, field) * internal_name_to_multiplier[vname]

            # ttest (Welch, unequal variances) between the two simulations
            a = sim_data[NEMO_LABEL][vname][season]
            std_a = mod_label_to_vname_to_season_to_std[NEMO_LABEL][vname][season]
            nobs_a = mod_label_to_vname_to_season_to_nobs[NEMO_LABEL][vname][season]

            b = sim_data[HL_LABEL][vname][season]
            std_b = mod_label_to_vname_to_season_to_std[HL_LABEL][vname][season]
            nobs_b = mod_label_to_vname_to_season_to_nobs[HL_LABEL][vname][season]

            _, p = ttest_ind_from_stats(mean1=a, std1=std_a, nobs1=nobs_a,
                                        mean2=b, std2=std_b, nobs2=nobs_b, equal_var=False)

            # Mask non-significant differences as given by the ttest
            to_plot = np.ma.masked_where(p > p_crit, to_plot)

            # mask the points with not sufficient land fraction
            if land_fraction is not None and vname in [SWE, ]:
                to_plot = np.ma.masked_where(land_fraction < 0.05, to_plot)

            # print("land fractions for large differences ", land_fraction[to_plot > 30])

            cs = bsmap.contourf(xx, yy, to_plot, ax=ax, extend="both", levels=clevs, cmap=cmap, norm=bnorm)
            bsmap.drawcoastlines(linewidth=coastline_width)
            bsmap.colorbar(cs, ax=ax)

            if col == 0:
                ax.set_ylabel("{}\n-\n{}".format(NEMO_LABEL, HL_LABEL))

        fig.tight_layout()

        # save a figure per variable
        img_file = "seasonal_differences_noobs_{}_{}_{}-{}.png".format(vname,
                                                                       "-".join(vname_to_seasonmonths_map[vname]),
                                                                       start_year, end_year)
        img_file = img_folder.joinpath(img_file)

        fig.savefig(str(img_file), dpi=300)

        plt.close(fig)
def main(vars_of_interest=None):
    """
    Plot seasonal mean observations and the model biases with respect to them.

    For each variable a figure is produced: the first row shows the observed
    seasonal means, the following rows show (model - obs) for each simulation
    in ``sim_labels``. Fields are shown only inside the region of interest
    derived from the basins shapefile; biases with a Welch t-test p-value
    above ``p_crit`` are masked. One png per variable is saved into
    ``img_folder``.

    :param vars_of_interest: internal names of the variables to process,
        if None, only TOTAL_PREC is processed
    """
    # Validation with CRU (temp, precip) and CMC SWE

    # obs_data_path = Path("/RESCUE/skynet3_rech1/huziy/obs_data_for_HLES/interploated_to_the_same_grid/GL_0.1_452x260/anusplin+_interpolated_tt_pr.nc")
    obs_data_path = Path("/HOME/huziy/skynet3_rech1/obs_data/mh_churchill_nelson_obs_fields")
    # when True, observed precipitation is converted from mm/month to mm/day
    CRU_PRECIP = True
    mm_per_month_to_mm_per_day = 1. / (365.25 / 12)

    sim_id = "mh_0.44"
    add_shp_files = [
        default_domains.MH_BASINS_PATH,
        constants.upstream_station_boundaries_shp_path[sim_id]
    ]

    start_year = 1981
    end_year = 2009

    MODEL_LABEL = "CRCM5 (0.44)"
    # critical p-value for the ttest aka significance level
    # p_crit = 0.05
    # (with 1 no bias is masked as non-significant)
    p_crit = 1

    coastlines_width = 0.3

    vars_of_interest_default = [
        # T_AIR_2M,
        TOTAL_PREC,
        # SWE,
        # LAKE_ICE_FRACTION
    ]

    if vars_of_interest is None:
        vars_of_interest = vars_of_interest_default

    vname_to_seasonmonths_map = {
        SWE: OrderedDict([("DJF", [12, 1, 2])]),
        T_AIR_2M: season_to_months,
        TOTAL_PREC: OrderedDict([("Annual", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])])  # season_to_months,
    }

    sim_configs = {
        MODEL_LABEL: RunConfig(data_path="/RECH2/huziy/BC-MH/bc_mh_044deg/Samples",
                               start_year=start_year, end_year=end_year, label=MODEL_LABEL),
    }

    grid_config = default_domains.bc_mh_044

    sim_labels = [MODEL_LABEL, ]

    vname_to_level = {
        T_AIR_2M: VerticalLevel(1, level_kinds.HYBRID),
        U_WE: VerticalLevel(1, level_kinds.HYBRID),
        V_SN: VerticalLevel(1, level_kinds.HYBRID),
        SWE: VerticalLevel(-1, level_kinds.ARBITRARY)
    }

    # names of the variables in the observation files
    vname_map = {
        default_varname_mappings.TOTAL_PREC: "pre",
        default_varname_mappings.T_AIR_2M: "tmp",
        default_varname_mappings.SWE: "SWE"
    }

    filename_prefix_mapping = {
        default_varname_mappings.SWE: "pm",
        default_varname_mappings.TOTAL_PREC: "pm",
        default_varname_mappings.T_AIR_2M: "dm"
    }

    # Try to get the land_fraction for masking if necessary (best effort).
    # NOTE(review): the value is not used further down in this function.
    land_fraction = None
    try:
        land_fraction = get_land_fraction(sim_configs[MODEL_LABEL])
    except Exception as err:
        # do not fail the whole script, but do not hide the problem either
        print("Could not read the land fraction: {}".format(err))

    # Calculations

    # prepare params for interpolation
    lons_t, lats_t, bsmap = get_target_lons_lats_basemap(sim_configs[MODEL_LABEL])

    # replace the basemap by the one zoomed on the basins of interest
    bsmap, reg_of_interest_mask = grid_config.get_basemap_using_shape_with_polygons_of_interest(
        lons=lons_t, lats=lats_t,
        shp_path=default_domains.MH_BASINS_PATH,
        mask_margin=2, resolution="i")

    xt, yt, zt = lat_lon.lon_lat_to_cartesian(lons_t.flatten(), lats_t.flatten())

    obs_multipliers = default_varname_mappings.vname_to_multiplier_CRCM5.copy()

    # Read and calculate observed seasonal means
    store_config = {
        "base_folder": obs_data_path.parent if not obs_data_path.is_dir() else obs_data_path,
        "data_source_type": data_source_types.ALL_VARS_IN_A_FOLDER_IN_NETCDF_FILES_OPEN_EACH_FILE_SEPARATELY,
        "varname_mapping": vname_map,
        "level_mapping": vname_to_level,
        "offset_mapping": default_varname_mappings.vname_to_offset_CRCM5,
        "multiplier_mapping": obs_multipliers,
    }

    obs_dm = DataManager(store_config=store_config)
    obs_data = {}

    # need to save it for ttesting
    obs_vname_to_season_to_std = {}
    obs_vname_to_season_to_nobs = {}

    # nearest-neighbour indices obs grid -> target grid, computed once
    interp_indices = None
    for vname in vars_of_interest:
        # --
        end_year_for_current_var = end_year
        if vname == SWE:
            end_year_for_current_var = min(1996, end_year)

        # --
        seas_to_year_to_mean = obs_dm.get_seasonal_means(varname_internal=vname,
                                                         start_year=start_year,
                                                         end_year=end_year_for_current_var,
                                                         season_to_months=vname_to_seasonmonths_map[vname])

        seas_to_clim = {seas: np.array(list(y_to_means.values())).mean(axis=0)
                        for seas, y_to_means in seas_to_year_to_mean.items()}

        # convert precip from mm/month (CRU) to mm/day
        convert_cru_precip = vname in [TOTAL_PREC] and CRU_PRECIP
        if convert_cru_precip:
            for seas in seas_to_clim:
                seas_to_clim[seas] *= mm_per_month_to_mm_per_day
                seas_to_clim[seas] = np.ma.masked_where(np.isnan(seas_to_clim[seas]), seas_to_clim[seas])

                print("{}: min={}, max={}".format(seas, seas_to_clim[seas].min(), seas_to_clim[seas].max()))

        # the dict is shared with obs_data, so the interpolated fields
        # assigned below are what the rest of the code sees
        obs_data[vname] = seas_to_clim

        if interp_indices is None:
            _, interp_indices = obs_dm.get_kdtree().query(list(zip(xt, yt, zt)))

        # need for ttests
        season_to_std = {}
        obs_vname_to_season_to_std[vname] = season_to_std

        season_to_nobs = {}
        obs_vname_to_season_to_nobs[vname] = season_to_nobs

        for season in seas_to_clim:
            seas_to_clim[season] = seas_to_clim[season].flatten()[interp_indices].reshape(lons_t.shape)

            # save the yearly means for ttesting
            season_to_std[season] = np.asarray([field.flatten()[interp_indices].reshape(lons_t.shape)
                                                for field in seas_to_year_to_mean[season].values()]).std(axis=0)

            # the yearly means are still in the original units, so the
            # standard deviation needs the same conversion as the mean
            if convert_cru_precip:
                season_to_std[season] = season_to_std[season] * mm_per_month_to_mm_per_day

            season_to_nobs[season] = np.ones_like(lons_t) * len(seas_to_year_to_mean[season])

    # Read and calculate simulated seasonal mean biases
    mod_label_to_vname_to_season_to_std = {}
    mod_label_to_vname_to_season_to_nobs = {}

    # multiplier of 1 for everything except precipitation
    model_data_multipliers = defaultdict(lambda: 1)
    model_data_multipliers[TOTAL_PREC] = 1000 * 24 * 3600

    sim_data = defaultdict(dict)
    for label, r_config in sim_configs.items():

        store_config = {
            "base_folder": r_config.data_path,
            "data_source_type": data_source_types.SAMPLES_FOLDER_FROM_CRCM_OUTPUT,
            "varname_mapping": default_varname_mappings.vname_map_CRCM5,
            "level_mapping": vname_to_level,
            "offset_mapping": default_varname_mappings.vname_to_offset_CRCM5,
            "multiplier_mapping": model_data_multipliers,
            "filename_prefix_mapping": filename_prefix_mapping
        }

        dm = DataManager(store_config=store_config)

        mod_label_to_vname_to_season_to_std[label] = {}
        mod_label_to_vname_to_season_to_nobs[label] = {}

        # the model grid differs from the obs grid: recompute the indices
        interp_indices = None
        for vname in vars_of_interest:

            # --
            end_year_for_current_var = end_year
            if vname == SWE:
                end_year_for_current_var = min(1996, end_year)

            # --
            seas_to_year_to_mean = dm.get_seasonal_means(varname_internal=vname,
                                                         start_year=start_year,
                                                         end_year=end_year_for_current_var,
                                                         season_to_months=vname_to_seasonmonths_map[vname])

            # get the climatology
            seas_to_clim = {seas: np.array(list(y_to_means.values())).mean(axis=0)
                            for seas, y_to_means in seas_to_year_to_mean.items()}

            sim_data[label][vname] = seas_to_clim

            if interp_indices is None:
                _, interp_indices = dm.get_kdtree().query(list(zip(xt, yt, zt)))

            season_to_std = {}
            mod_label_to_vname_to_season_to_std[label][vname] = season_to_std

            season_to_nobs = {}
            mod_label_to_vname_to_season_to_nobs[label][vname] = season_to_nobs

            for season in seas_to_clim:
                interpolated_field = seas_to_clim[season].flatten()[interp_indices].reshape(lons_t.shape)
                # sim_data stores biases (model - obs), not the model fields
                seas_to_clim[season] = interpolated_field - obs_data[vname][season]

                # calculate standard deviations of the interpolated fields
                season_to_std[season] = np.asarray([field.flatten()[interp_indices].reshape(lons_t.shape)
                                                    for field in seas_to_year_to_mean[season].values()]).std(axis=0)

                # calculate numobs for the ttest
                season_to_nobs[season] = np.ones_like(lons_t) * len(seas_to_year_to_mean[season])

    xx, yy = bsmap(lons_t, lats_t)
    # longitudes are shifted to -180..180 before being passed to maskoceans
    lons_t[lons_t > 180] -= 360

    ocean_mask = maskoceans(lons_t, lats_t, np.zeros_like(lons_t)).mask

    img_folder.mkdir(parents=True, exist_ok=True)

    # the first shapefile (basins) is drawn over the obs panels,
    # readshapefile expects the path without the extension
    obs_shp_index = 0
    obs_shp_path = str(add_shp_files[obs_shp_index])[:-4]

    for vname in vars_of_interest:

        # The ocean mask is applied only to SWE. It is selected per variable,
        # so that the result does not depend on the order of the variables.
        if vname in [SWE]:
            field_mask = ocean_mask
        else:
            field_mask = np.zeros_like(ocean_mask, dtype=bool)

        # Plotting: interpolate to the same grid and plot obs and biases
        plot_utils.apply_plot_params(width_cm=32 / 4 * (len(vname_to_seasonmonths_map[vname])),
                                     height_cm=25 / 3.0 * (len(sim_configs) + 1),
                                     font_size=8 * len(vname_to_seasonmonths_map[vname]))

        fig = plt.figure()

        # fig.suptitle(internal_name_to_title[vname] + "\n")

        nrows = len(sim_configs) + 2
        ncols = len(vname_to_seasonmonths_map[vname])
        gs = GridSpec(nrows=nrows, ncols=ncols)

        # Plot the obs fields
        current_row = 0
        for col, season in enumerate(vname_to_seasonmonths_map[vname]):
            field = obs_data[vname][season]
            ax = fig.add_subplot(gs[current_row, col])
            ax.set_title(season)

            to_plot = np.ma.masked_where(field_mask, field) * internal_name_to_multiplier[vname]
            clevs = get_clevs(vname)

            to_plot = np.ma.masked_where(~reg_of_interest_mask, to_plot)

            # the colormap comes from internal_name_to_cmap, only the norm depends on the levels
            bnorm = BoundaryNorm(clevs, len(clevs) - 1) if clevs is not None else None

            bsmap.drawmapboundary(fill_color="0.75")

            # cs = bsmap.contourf(xx, yy, to_plot, ax=ax, levels=get_clevs(vname), norm=bnorm, cmap=cmap)
            cs = bsmap.pcolormesh(xx, yy, to_plot, ax=ax, norm=bnorm, cmap=internal_name_to_cmap[vname])

            bsmap.drawcoastlines(linewidth=coastlines_width)
            # bsmap.drawstates(linewidth=0.1)
            # bsmap.drawcountries(linewidth=0.2)
            bsmap.colorbar(cs, ax=ax)

            bsmap.readshapefile(obs_shp_path, "field_{}".format(obs_shp_index), linewidth=0.5, color="m")

            if col == 0:
                ax.set_ylabel("Obs")

        # plot the biases
        for sim_label in sim_labels:
            current_row += 1
            for col, season in enumerate(vname_to_seasonmonths_map[vname]):

                field = sim_data[sim_label][vname][season]

                ax = fig.add_subplot(gs[current_row, col])

                clevs = get_clevs(vname + "bias")
                if clevs is not None:
                    bnorm = BoundaryNorm(clevs, len(clevs) - 1)
                    cmap = cm.get_cmap("bwr", len(clevs) - 1)
                else:
                    cmap = "bwr"
                    bnorm = None

                to_plot = np.ma.masked_where(field_mask, field) * internal_name_to_multiplier[vname]

                # ttest (Welch, unequal variances) between the model and the obs
                # Calculate the simulation data back from biases
                a = sim_data[sim_label][vname][season] + obs_data[vname][season]
                std_a = mod_label_to_vname_to_season_to_std[sim_label][vname][season]
                nobs_a = mod_label_to_vname_to_season_to_nobs[sim_label][vname][season]

                b = obs_data[vname][season]
                std_b = obs_vname_to_season_to_std[vname][season]
                nobs_b = obs_vname_to_season_to_nobs[vname][season]

                _, p = ttest_ind_from_stats(mean1=a, std1=std_a, nobs1=nobs_a,
                                            mean2=b, std2=std_b, nobs2=nobs_b, equal_var=False)

                # Mask non-significant differences as given by the ttest
                to_plot = np.ma.masked_where(p > p_crit, to_plot)

                # only focus on the basins of interest
                to_plot = np.ma.masked_where(~reg_of_interest_mask, to_plot)

                # cs = bsmap.contourf(xx, yy, to_plot, ax=ax, extend="both", levels=get_clevs(vname + "bias"), cmap=cmap, norm=bnorm)

                bsmap.drawmapboundary(fill_color="0.75")

                cs = bsmap.pcolormesh(xx, yy, to_plot, ax=ax, cmap=cmap, norm=bnorm)
                bsmap.drawcoastlines(linewidth=coastlines_width)
                bsmap.colorbar(cs, ax=ax, extend="both")

                # the remaining shapefiles are drawn over the bias panels
                for i, shp in enumerate(add_shp_files[1:], start=1):
                    bsmap.readshapefile(str(shp)[:-4], "field_{}".format(i), linewidth=0.5, color="k")

                if col == 0:
                    ax.set_ylabel("{}\n-\nObs.".format(sim_label))

        fig.tight_layout()

        # save a figure per variable
        img_file = "seasonal_biases_{}_{}_{}-{}.png".format(vname,
                                                            "-".join(vname_to_seasonmonths_map[vname]),
                                                            start_year, end_year)

        img_file = img_folder / img_file
        fig.savefig(str(img_file), bbox_inches="tight", dpi=300)

        plt.close(fig)
def main(label_to_data_path: dict, varnames=None, season_to_months: dict=None,
         cur_label="", fut_label="",
         vname_to_mask: dict=None, vname_display_names:dict=None,
         pval_crit=0.1, periods_info: CcPeriodsInfo=None,
         vars_info: dict=None):

    """
    Compute the changes between the current and the future simulation for each variable
    and season, and plot them as a panel of maps (rows: variables, columns: seasons).

    Besides the requested variables, the numbers of HLES days ("hles_snow_days") and of
    CAO days ("cao_days") are always computed and added to the panel.
    The figure is saved to common_params.img_folder as cc_<fut_label>-<cur_label>.png.

    :param pval_crit: critical p-value (not used currently, p-values are computed but not plotted)
    :param vars_info: optional per-variable options; the keys read here are "accumulation",
        "multiplier", "offset", "vmin", "vmax", "mask" and "cmap". Variables without an entry
        are plotted as is, with the default colormap and automatic color limits
    :param label_to_data_path: maps a simulation label to the folder with its netcdf files
    :param varnames: internal names of the variables to process (the list is not modified)
    :param season_to_months: maps a season name to the months of the season
    :param cur_label: key of label_to_data_path for the current climate simulation
    :param fut_label: key of label_to_data_path for the future climate simulation
    :param vname_to_mask: - to mask everything except the region of interest (not used currently)
    :param vname_display_names: optional mapping from a variable name to the label of its row
    :param periods_info: source of the year limits for both periods and of the season lengths
    """

    if vname_display_names is None:
        vname_display_names = {}

    if vars_info is None:
        vars_info = {}

    # Work on a copy: the derived variables are added to the list below and
    # that should not be visible to the caller (repeated calls would break otherwise)
    varnames = list(varnames) if varnames is not None else []

    varname_mapping = {v: v for v in varnames}
    level_mapping = {v: VerticalLevel(0) for v in varnames} # Does not really make a difference, since all variables are 2d

    comon_store_config = {
        DataManager.SP_DATASOURCE_TYPE: data_source_types.ALL_VARS_IN_A_FOLDER_IN_NETCDF_FILES,
        DataManager.SP_INTERNAL_TO_INPUT_VNAME_MAPPING: varname_mapping,
        DataManager.SP_LEVEL_MAPPING: level_mapping
    }

    cur_dm = DataManager(
        store_config=dict({DataManager.SP_BASE_FOLDER: label_to_data_path[cur_label]}, **comon_store_config)
    )

    fut_dm = DataManager(
        store_config=dict({DataManager.SP_BASE_FOLDER: label_to_data_path[fut_label]}, **comon_store_config)
    )

    # get the data and do calculations
    var_to_season_to_data = {}

    cur_start_yr, cur_end_year = periods_info.get_cur_year_limits()
    fut_start_yr, fut_end_year = periods_info.get_fut_year_limits()

    for vname in varnames:
        cur_means = cur_dm.get_seasonal_means(start_year=cur_start_yr, end_year=cur_end_year,
                                              season_to_months=season_to_months, varname_internal=vname)

        fut_means = fut_dm.get_seasonal_means(start_year=fut_start_yr, end_year=fut_end_year,
                                              season_to_months=season_to_months, varname_internal=vname)

        # convert means to the accumulators (if required): mean * number of days in the season
        opts = vars_info.get(vname, {})
        if opts.get("accumulation"):
            for seas_name, months in season_to_months.items():
                cur_means[seas_name] = {y: f * periods_info.get_numdays_for_season(y, month_list=months) for y, f in cur_means[seas_name].items()}
                fut_means[seas_name] = {y: f * periods_info.get_numdays_for_season(y, month_list=months) for y, f in fut_means[seas_name].items()}


        var_to_season_to_data[vname] = calculate_change_and_pvalues(cur_means, fut_means, percentages=False)


    # add hles days (shown in the second row of the panel)
    hles_days_varname = "hles_snow_days"
    varnames.insert(1, hles_days_varname)
    cur_means = cur_dm.get_mean_number_of_hles_days(start_year=cur_start_yr, end_year=cur_end_year,
                                                    season_to_months=season_to_months,
                                                    hles_vname="hles_snow")


    fut_means = fut_dm.get_mean_number_of_hles_days(start_year=fut_start_yr, end_year=fut_end_year,
                                                     season_to_months=season_to_months,
                                                     hles_vname="hles_snow")

    var_to_season_to_data[hles_days_varname] = calculate_change_and_pvalues(cur_means, fut_means, percentages=False)


    # add CAO days (shown in the last row of the panel)
    cao_ndays_varname = "cao_days"
    varnames.append(cao_ndays_varname)

    cur_means = cur_dm.get_mean_number_of_cao_days(start_year=cur_start_yr, end_year=cur_end_year,
                                                    season_to_months=season_to_months,
                                                    temperature_vname="TT")


    fut_means = fut_dm.get_mean_number_of_cao_days(start_year=fut_start_yr, end_year=fut_end_year,
                                                     season_to_months=season_to_months,
                                                     temperature_vname="TT")

    var_to_season_to_data[cao_ndays_varname] = calculate_change_and_pvalues(cur_means, fut_means, percentages=False)



    # Plotting
    # panel grid dimensions
    ncols = len(season_to_months)
    nrows = len(varnames)

    gs = GridSpec(nrows, ncols, wspace=0, hspace=0)
    fig = plt.figure()

    for col, seas_name in enumerate(season_to_months):
        for row, vname in enumerate(varnames):

            ax = fig.add_subplot(gs[row, col], projection=cartopy.crs.PlateCarree())


            # identify variable names
            if col == 0:
                ax.set_ylabel(vname_display_names.get(vname, vname))

            # pv (p-values) is not plotted currently
            cc, pv = var_to_season_to_data[vname][seas_name]
            to_plot = cc

            print(f"Plotting {vname} for {seas_name}.")

            # the plotting options are optional for a variable
            opts = vars_info.get(vname, {})
            vmin = None
            vmax = None
            if vname in vars_info:
                to_plot = to_plot * opts["multiplier"] + opts["offset"]

                vmin = opts["vmin"]
                vmax = opts["vmax"]

                # the points where the mask is False are hidden
                if "mask" in opts:
                    to_plot = np.ma.masked_where(~opts["mask"], to_plot)


            ax.set_facecolor("0.75")

            # hide the ticks
            ax.xaxis.set_major_locator(NullLocator())
            ax.yaxis.set_major_locator(NullLocator())

            cmap = opts.get("cmap", cm.get_cmap("bwr", 11))

            im = ax.pcolormesh(cur_dm.lons, cur_dm.lats, to_plot,
                               cmap=cmap, vmin=vmin, vmax=vmax)



            line_color = "k"
            ax.add_feature(common_params.LAKES_50m, facecolor="none", edgecolor=line_color, linewidth=0.5)
            ax.add_feature(common_params.COASTLINE_50m, facecolor="none", edgecolor=line_color, linewidth=0.5)
            ax.add_feature(common_params.RIVERS_50m, facecolor="none", edgecolor=line_color, linewidth=0.5)
            ax.set_extent([cur_dm.lons[0, 0], cur_dm.lons[-1, -1], cur_dm.lats[0, 0], cur_dm.lats[-1, -1]])

            # a colorbar axes is created for each panel (keeps the panels of the same size),
            # but it is shown only for the last column
            divider = make_axes_locatable(ax)
            ax_cb = divider.new_horizontal(size="5%", pad=0.1, axes_class=plt.Axes)
            fig.add_axes(ax_cb)
            cb = plt.colorbar(im, extend="both", cax=ax_cb)


            # set season titles
            if row == 0:
                ax.text(0.5, 1.05, seas_name, va="bottom", ha="center", multialignment="center", transform=ax.transAxes)


            if col < ncols - 1:
                cb.ax.set_visible(False)

    # Save the figure in file
    img_folder = common_params.img_folder
    img_folder.mkdir(parents=True, exist_ok=True)

    img_file = img_folder / f"cc_{fut_label}-{cur_label}.png"

    fig.savefig(str(img_file), **common_params.image_file_options)
def main(label_to_data_path: dict,
         varnames=None,
         season_to_months: dict = None,
         cur_label="",
         fut_label="",
         vname_to_mask: dict = None,
         vname_display_names: dict = None,
         pval_crit=0.1,
         periods_info: CcPeriodsInfo = None,
         vars_info: dict = None):
    """
    Compute the changes between the current and the future simulation for each variable
    and season, and plot them as a panel of maps (rows: variables, columns: seasons).

    Besides the requested variables, the numbers of HLES days ("hles_snow_days") and of
    CAO days ("cao_days") are always computed and added to the panel.
    The figure is saved to common_params.img_folder as cc_<fut_label>-<cur_label>.png.

    :param pval_crit: critical p-value (not used currently, p-values are computed but not plotted)
    :param vars_info: optional per-variable options; the keys read here are "accumulation",
        "multiplier", "offset", "vmin", "vmax", "mask" and "cmap". Variables without an entry
        are plotted as is, with the default colormap and automatic color limits
    :param label_to_data_path: maps a simulation label to the folder with its netcdf files
    :param varnames: internal names of the variables to process (the list is not modified)
    :param season_to_months: maps a season name to the months of the season
    :param cur_label: key of label_to_data_path for the current climate simulation
    :param fut_label: key of label_to_data_path for the future climate simulation
    :param vname_to_mask: - to mask everything except the region of interest (not used currently)
    :param vname_display_names: optional mapping from a variable name to the label of its row
    :param periods_info: source of the year limits for both periods and of the season lengths
    """

    if vname_display_names is None:
        vname_display_names = {}

    if vars_info is None:
        vars_info = {}

    # Work on a copy: the derived variables are added to the list below and
    # that should not be visible to the caller (repeated calls would break otherwise)
    varnames = list(varnames) if varnames is not None else []

    varname_mapping = {v: v for v in varnames}
    level_mapping = {
        v: VerticalLevel(0)
        for v in varnames
    }  # Does not really make a difference, since all variables are 2d

    comon_store_config = {
        DataManager.SP_DATASOURCE_TYPE:
        data_source_types.ALL_VARS_IN_A_FOLDER_IN_NETCDF_FILES,
        DataManager.SP_INTERNAL_TO_INPUT_VNAME_MAPPING: varname_mapping,
        DataManager.SP_LEVEL_MAPPING: level_mapping
    }

    cur_dm = DataManager(store_config=dict(
        {DataManager.SP_BASE_FOLDER: label_to_data_path[cur_label]}, **
        comon_store_config))

    fut_dm = DataManager(store_config=dict(
        {DataManager.SP_BASE_FOLDER: label_to_data_path[fut_label]}, **
        comon_store_config))

    # get the data and do calculations
    var_to_season_to_data = {}

    cur_start_yr, cur_end_year = periods_info.get_cur_year_limits()
    fut_start_yr, fut_end_year = periods_info.get_fut_year_limits()

    for vname in varnames:
        cur_means = cur_dm.get_seasonal_means(
            start_year=cur_start_yr,
            end_year=cur_end_year,
            season_to_months=season_to_months,
            varname_internal=vname)

        fut_means = fut_dm.get_seasonal_means(
            start_year=fut_start_yr,
            end_year=fut_end_year,
            season_to_months=season_to_months,
            varname_internal=vname)

        # convert means to the accumulators (if required): mean * number of days in the season
        opts = vars_info.get(vname, {})
        if opts.get("accumulation"):
            for seas_name, months in season_to_months.items():
                cur_means[seas_name] = {
                    y: f *
                    periods_info.get_numdays_for_season(y, month_list=months)
                    for y, f in cur_means[seas_name].items()
                }
                fut_means[seas_name] = {
                    y: f *
                    periods_info.get_numdays_for_season(y, month_list=months)
                    for y, f in fut_means[seas_name].items()
                }

        var_to_season_to_data[vname] = calculate_change_and_pvalues(
            cur_means, fut_means, percentages=False)

    # add hles days (shown in the second row of the panel)
    hles_days_varname = "hles_snow_days"
    varnames.insert(1, hles_days_varname)
    cur_means = cur_dm.get_mean_number_of_hles_days(
        start_year=cur_start_yr,
        end_year=cur_end_year,
        season_to_months=season_to_months,
        hles_vname="hles_snow")

    fut_means = fut_dm.get_mean_number_of_hles_days(
        start_year=fut_start_yr,
        end_year=fut_end_year,
        season_to_months=season_to_months,
        hles_vname="hles_snow")

    var_to_season_to_data[hles_days_varname] = calculate_change_and_pvalues(
        cur_means, fut_means, percentages=False)

    # add CAO days (shown in the last row of the panel)
    cao_ndays_varname = "cao_days"
    varnames.append(cao_ndays_varname)

    cur_means = cur_dm.get_mean_number_of_cao_days(
        start_year=cur_start_yr,
        end_year=cur_end_year,
        season_to_months=season_to_months,
        temperature_vname="TT")

    fut_means = fut_dm.get_mean_number_of_cao_days(
        start_year=fut_start_yr,
        end_year=fut_end_year,
        season_to_months=season_to_months,
        temperature_vname="TT")

    var_to_season_to_data[cao_ndays_varname] = calculate_change_and_pvalues(
        cur_means, fut_means, percentages=False)

    # Plotting
    # panel grid dimensions
    ncols = len(season_to_months)
    nrows = len(varnames)

    gs = GridSpec(nrows, ncols, wspace=0, hspace=0)
    fig = plt.figure()

    for col, seas_name in enumerate(season_to_months):
        for row, vname in enumerate(varnames):

            ax = fig.add_subplot(gs[row, col],
                                 projection=cartopy.crs.PlateCarree())

            # identify variable names
            if col == 0:
                ax.set_ylabel(vname_display_names.get(vname, vname))

            # pv (p-values) is not plotted currently
            cc, pv = var_to_season_to_data[vname][seas_name]
            to_plot = cc

            print(f"Plotting {vname} for {seas_name}.")

            # the plotting options are optional for a variable
            opts = vars_info.get(vname, {})
            vmin = None
            vmax = None
            if vname in vars_info:
                to_plot = to_plot * opts["multiplier"] + opts["offset"]

                vmin = opts["vmin"]
                vmax = opts["vmax"]

                # the points where the mask is False are hidden
                if "mask" in opts:
                    to_plot = np.ma.masked_where(~opts["mask"], to_plot)

            ax.set_facecolor("0.75")

            # hide the ticks
            ax.xaxis.set_major_locator(NullLocator())
            ax.yaxis.set_major_locator(NullLocator())

            cmap = opts.get("cmap", cm.get_cmap("bwr", 11))

            im = ax.pcolormesh(cur_dm.lons,
                               cur_dm.lats,
                               to_plot,
                               cmap=cmap,
                               vmin=vmin,
                               vmax=vmax)

            line_color = "k"
            ax.add_feature(common_params.LAKES_50m,
                           facecolor="none",
                           edgecolor=line_color,
                           linewidth=0.5)
            ax.add_feature(common_params.COASTLINE_50m,
                           facecolor="none",
                           edgecolor=line_color,
                           linewidth=0.5)
            ax.add_feature(common_params.RIVERS_50m,
                           facecolor="none",
                           edgecolor=line_color,
                           linewidth=0.5)
            ax.set_extent([
                cur_dm.lons[0, 0], cur_dm.lons[-1, -1], cur_dm.lats[0, 0],
                cur_dm.lats[-1, -1]
            ])

            # a colorbar axes is created for each panel (keeps the panels of the same size),
            # but it is shown only for the last column
            divider = make_axes_locatable(ax)
            ax_cb = divider.new_horizontal(size="5%",
                                           pad=0.1,
                                           axes_class=plt.Axes)
            fig.add_axes(ax_cb)
            cb = plt.colorbar(im, extend="both", cax=ax_cb)

            # set season titles
            if row == 0:
                ax.text(0.5,
                        1.05,
                        seas_name,
                        va="bottom",
                        ha="center",
                        multialignment="center",
                        transform=ax.transAxes)

            if col < ncols - 1:
                cb.ax.set_visible(False)

    # Save the figure in file
    img_folder = common_params.img_folder
    img_folder.mkdir(parents=True, exist_ok=True)

    img_file = img_folder / f"cc_{fut_label}-{cur_label}.png"

    fig.savefig(str(img_file), **common_params.image_file_options)
Exemplo n.º 14
0
def main():
    """
    Compare the modelled and the observed daily streamflow climatologies at the selected
    gauging stations and save the panel plot (one panel per station) to a png file.

    The stations are matched to the model grid points using the flow directions and
    the accumulation areas read from the file at direction_file_path.
    Nothing is plotted if no station has complete years within the period of interest.
    """
    direction_file_path = Path("/RECH2/huziy/BC-MH/bc_mh_044deg/Samples/bc_mh_044deg_198001/pm1980010100_00000000p")

    sim_label = "mh_0.44"

    start_year = 1981
    end_year = 2010

    streamflow_internal_name = "streamflow"
    selected_staion_ids = constants.selected_station_ids_for_streamflow_validation

    # ======================================================

    # dates of a non-leap year used as the time axis of the climatologies
    day = timedelta(days=1)
    t0 = datetime(2001, 1, 1)
    stamp_dates = [t0 + i * day for i in range(365)]
    print("stamp dates range {} ... {}".format(stamp_dates[0], stamp_dates[-1]))


    lake_fraction = None

    # establish the correspondence between the stations and model grid points
    with RPN(str(direction_file_path)) as r:
        assert isinstance(r, RPN)
        fldir = r.get_first_record_for_name("FLDR")
        flow_acc_area = r.get_first_record_for_name("FAA")
        lons, lats = r.get_longitudes_and_latitudes_for_the_last_read_rec()

    cell_manager = CellManager(fldir, lons2d=lons, lats2d=lats, accumulation_area_km2=flow_acc_area)
    stations = stfl_stations.load_stations_from_csv(selected_ids=selected_staion_ids)
    station_to_model_point = cell_manager.get_model_points_for_stations(station_list=stations, lake_fraction=lake_fraction,
                                                                        nneighbours=8)


    # Update the end year if required
    # (-1 is used for the stations without complete years and when there are no stations at all)
    max_year_st = max(
        (max(station.get_list_of_complete_years(), default=-1) for station in station_to_model_point),
        default=-1
    )

    if end_year > max_year_st:
        print("Updated end_year to {}, because no obs data after...".format(max_year_st))
        end_year = max_year_st



    # read model data
    mod_data_manager = DataManager(
        store_config={
            "varname_mapping": {streamflow_internal_name: "STFA"},
            "base_folder": str(direction_file_path.parent.parent),
            "data_source_type": data_source_types.SAMPLES_FOLDER_FROM_CRCM_OUTPUT,
            "level_mapping": {streamflow_internal_name: VerticalLevel(-1, level_type=level_kinds.ARBITRARY)},
            "offset_mapping": vname_to_offset_CRCM5,
            "filename_prefix_mapping": {streamflow_internal_name: "pm"}
    })


    # the model data are read year by year
    station_to_model_data = defaultdict(list)
    for year in range(start_year, end_year + 1):
        start = Pendulum(year, 1, 1)
        p_test = Period(start, start.add(years=1).subtract(microseconds=1))
        stfl_mod = mod_data_manager.read_data_for_period(p_test, streamflow_internal_name)

        # convert to daily
        # NOTE(review): this is the legacy form of the xarray resample call (positional dim and how=),
        # confirm that it is supported by the xarray version in use
        stfl_mod = stfl_mod.resample("D", "t", how="mean", closed="left", keep_attrs=True)

        assert isinstance(stfl_mod, xr.DataArray)

        for station, model_point in station_to_model_point.items():
            assert isinstance(model_point, ModelPoint)
            ts1 = stfl_mod[:, model_point.ix, model_point.jy].to_series()
            station_to_model_data[station].append(pd.Series(index=stfl_mod.t.values, data=ts1))



    # concatenate the timeseries for each point, if required
    if end_year - start_year + 1 > 1:
        for station in station_to_model_data:
            station_to_model_data[station] = pd.concat(station_to_model_data[station])
    else:
        for station in station_to_model_data:
            station_to_model_data[station] = station_to_model_data[station][0]



    # calculate observed climatology (stations are ordered by decreasing latitude)
    station_to_climatology = OrderedDict()
    for s in sorted(station_to_model_point, key=lambda st: st.latitude, reverse=True):
        assert isinstance(s, Station)
        print(s.id, len(s.get_list_of_complete_years()))

        # Check if there are continuous years for the selected period
        common_years = set(s.get_list_of_complete_years()).intersection(set(range(start_year, end_year + 1)))
        if common_years:
            _, station_to_climatology[s] = s.get_daily_climatology_for_complete_years_with_pandas(stamp_dates=stamp_dates,
                                                                                                  years=common_years)

            # the modelled climatology is calculated over the same years as the observed one
            _, station_to_model_data[s] = pandas_utils.get_daily_climatology_from_pandas_series(station_to_model_data[s],
                                                                                                stamp_dates,
                                                                                                years_of_interest=common_years)


        else:
            print("Skipping {}, since it does not have enough data during the period of interest".format(s.id))


    # a grid with 0 rows cannot be created and there would be no axes for the legend
    if not station_to_climatology:
        print("No stations with data during the period of interest, nothing to plot.")
        return


    # ---- Do the plotting ----
    ncols = 4

    # number of rows required to fit all the stations
    nrows = len(station_to_climatology) // ncols
    nrows += int(not (len(station_to_climatology) % ncols == 0))

    axes_list = []
    plot_utils.apply_plot_params(width_cm=8 * ncols, height_cm=8 * nrows, font_size=8)
    fig = plt.figure()
    gs = GridSpec(nrows=nrows, ncols=ncols)


    for i, (s, clim) in enumerate(station_to_climatology.items()):
        assert isinstance(s, Station)

        row = i // ncols
        col = i % ncols

        print(row, col, nrows, ncols)

        # normalize by the drainage area
        if s.drainage_km2 is not None:
            station_to_model_data[s] *= s.drainage_km2 / station_to_model_point[s].accumulation_area

        if s.id in constants.stations_to_greyout:
            ax = fig.add_subplot(gs[row, col], facecolor="0.45")
        else:
            ax = fig.add_subplot(gs[row, col])

        assert isinstance(ax, Axes)

        ax.plot(stamp_dates, clim, color="k", lw=2, label="Obs.")
        ax.plot(stamp_dates, station_to_model_data[s], color="r", lw=2, label="Mod.")
        ax.xaxis.set_major_formatter(FuncFormatter(format_month_label))
        ax.xaxis.set_major_locator(MonthLocator(bymonthday=15))
        ax.xaxis.set_minor_locator(MonthLocator(bymonthday=1))
        ax.grid()


        ax.annotate(s.get_pp_name(), xy=(1.02, 1), xycoords="axes fraction",
                    horizontalalignment="left", verticalalignment="top", fontsize=8, rotation=-90)


        # the x axis spans from the first day of the first month to the last day of the last month
        last_date = stamp_dates[-1]
        last_date = last_date.replace(day=calendar.monthrange(last_date.year, last_date.month)[1])

        ax.set_xlim(stamp_dates[0].replace(day=1), last_date)


        # start the y axis from 0, the upper limit stays as selected by matplotlib
        _, ymax = ax.get_ylim()
        ax.set_ylim(0, ymax)


        # raw strings: "\c" is not a valid escape sequence in a regular string literal
        if s.drainage_km2 is not None:
            ax.set_title(r"{}: ({:.1f}$^\circ$E, {:.1f}$^\circ$N, DA={:.0f} km$^2$)".format(s.id, s.longitude, s.latitude, s.drainage_km2))
        else:
            ax.set_title(
                r"{}: ({:.1f}$^\circ$E, {:.1f}$^\circ$N, DA not used)".format(s.id, s.longitude, s.latitude))
        axes_list.append(ax)

    # plot the legend
    axes_list[-1].legend()


    # img_folder is not defined in this function, presumably a module-level Path -- TODO confirm
    img_folder.mkdir(parents=True, exist_ok=True)

    fig.tight_layout()
    img_file = img_folder / "{}_{}-{}_{}.png".format(sim_label, start_year, end_year, "-".join(sorted(s.id for s in station_to_climatology)))

    print("Saving {}".format(img_file))
    fig.savefig(str(img_file), bbox_inches="tight", dpi=300)
Exemplo n.º 15
0
def main():
    """
    Validate the NEI West Coast simulations against the spatially aggregated daymet data.

    For each simulation and each variable the climatological quantiles are computed for
    the model and for the observations (on the model grid), cropped to the subregion and
    then passed, together with the biases (model - obs), to plot_meridional_mean: once
    without a season and once for each season.
    """
    # short alias for the module with the variable names
    vm = default_varname_mappings

    # dask.set_options(pool=ThreadPool(20))
    img_folder = Path("nei_validation/meridional_avg")
    img_folder.mkdir(parents=True, exist_ok=True)

    pval_crit = 0.1  # not used below

    start_year, end_year = 1980, 2010

    subregion = SubRegionByLonLatCorners(lleft={"lon": -128, "lat": 46},
                                         uright={"lon": -113, "lat": 55})

    season_to_months = {
        "DJF": [12, 1, 2],
        "MAM": range(3, 6),
        "JJA": range(6, 9),
        "SON": range(9, 12)
    }

    # TT_min and TT_max mean daily min and maximum temperatures
    var_names = [vm.T_AIR_2M_DAILY_MAX, vm.T_AIR_2M_DAILY_MIN, vm.TOTAL_PREC]  # not used below

    # width of the rolling mean window (in days) for each variable
    var_name_to_rolling_window_days = {
        vm.T_AIR_2M_DAILY_MIN: 5,
        vm.T_AIR_2M_DAILY_MAX: 5,
        vm.TOTAL_PREC: 29
    }

    # the quantile to compute for each variable
    var_name_to_percentile = {
        vm.T_AIR_2M_DAILY_MIN: 0.9,
        vm.T_AIR_2M_DAILY_MAX: 0.1,
        vm.TOTAL_PREC: 0.9,
    }

    # needed for the 3hourly temperature model outputs, when Tmin and Tmax daily are not available
    var_name_to_daily_agg_func = {
        vm.TOTAL_PREC: np.mean,
        vm.T_AIR_2M_DAILY_MAX: np.max,
        vm.T_AIR_2M_DAILY_MIN: np.min,
        vm.T_AIR_2M_DAILY_AVG: np.mean
    }

    deg_c = r"$^\circ$C"
    var_name_to_display_units = {
        vm.TOTAL_PREC: "mm/day",
        vm.T_AIR_2M_DAILY_MAX: deg_c,
        vm.T_AIR_2M_DAILY_MIN: deg_c,
        vm.T_AIR_2M_DAILY_AVG: deg_c
    }

    # applied to the model fields only, the variables not listed are not scaled
    model_vname_to_multiplier = {vm.TOTAL_PREC: 1000 * 24 * 3600}

    # simulation labels
    label_044_default = "WC_044_default"
    label_044_modified = "WC_044_modified"
    label_011_modified = "WC_011_modified"

    sim_paths = OrderedDict([
        (label_044_default, Path("/snow3/huziy/NEI/WC/NEI_WC0.44deg_default/Samples")),
        (label_044_modified, Path("/snow3/huziy/NEI/WC/NEI_WC0.44deg_Crr1/Samples")),
        (label_011_modified, Path("/snow3/huziy/NEI/WC/NEI_WC0.11deg_Crr1/Samples")),
    ])

    # files with the elevation field of each simulation
    elevation_paths = OrderedDict([
        (label_044_default,
         "/snow3/huziy/NEI/WC/NEI_WC0.44deg_default/geophys_CORDEX_NA_0.44d_filled_hwsd_dpth_om_MODIS_Glacier_v2_newdirs"),
        (label_044_modified,
         "/snow3/huziy/NEI/WC/NEI_WC0.44deg_Crr1/geophys_CORDEX_NA_0.44d_filled_hwsd_dpth_om_MODIS_Glacier_v2_dirs_hshedsfix_CTEM_FRAC_GlVolFix"),
        (label_011_modified,
         "/snow3/huziy/NEI/WC/NEI_WC0.11deg_Crr1/nei_geophy_wc_011.rpn"),
    ])

    # grid spacing of each simulation
    mod_spatial_scales = OrderedDict()
    mod_spatial_scales[label_044_default] = 0.44
    mod_spatial_scales[label_044_modified] = 0.44
    mod_spatial_scales[label_011_modified] = 0.11

    # -- daymet daily (spatially aggregated)
    # (the files at the initial spatial resolution are /snow3/huziy/Daymet_daily/daymet_v3_<var>_*_na.nc4)
    daymet_vname_to_path = {
        vm.TOTAL_PREC: "/snow3/huziy/Daymet_daily_derivatives/daymet_spatial_agg_prcp_10x10",
        vm.T_AIR_2M_DAILY_AVG: "/snow3/huziy/Daymet_daily_derivatives/daymet_spatial_agg_tavg_10x10",
        vm.T_AIR_2M_DAILY_MIN: "/snow3/huziy/Daymet_daily_derivatives/daymet_spatial_agg_tmin_10x10",
        vm.T_AIR_2M_DAILY_MAX: "/snow3/huziy/Daymet_daily_derivatives/daymet_spatial_agg_tmax_10x10",
    }

    # the model variable to read for each of the daymet variables
    daymet_vname_to_model_vname_internal = {
        vm.T_AIR_2M_DAILY_MIN: vm.T_AIR_2M,
        vm.T_AIR_2M_DAILY_MAX: vm.T_AIR_2M,
        vm.TOTAL_PREC: vm.TOTAL_PREC,
    }

    plot_utils.apply_plot_params(font_size=14)

    # observations
    obs_spatial_scale = 0.1  # 10x10 aggregation from ~0.01 daymet data

    varnames_list = [vm.TOTAL_PREC, vm.T_AIR_2M_DAILY_MIN, vm.T_AIR_2M_DAILY_MAX]

    data_dict = {}
    bias_dict = {}
    for vn in varnames_list:
        data_dict[vn] = {}
        bias_dict[vn] = {}

    bmap = None
    # calculate the percentiles for each simulation and obs data (obs data interpolated to the model grid)
    for model_label, base_dir in sim_paths.items():
        # model outputs manager
        model_store_config = {
            DataManager.SP_BASE_FOLDER: base_dir,
            DataManager.SP_DATASOURCE_TYPE: data_source_types.SAMPLES_FOLDER_FROM_CRCM_OUTPUT,
            DataManager.SP_INTERNAL_TO_INPUT_VNAME_MAPPING: vm.vname_map_CRCM5,
            DataManager.SP_LEVEL_MAPPING: vm.vname_to_level_map,
            DataManager.SP_VARNAME_TO_FILENAME_PREFIX_MAPPING: vm.vname_to_fname_prefix_CRCM5
        }
        dm = DataManager(store_config=model_store_config)

        for vname_daymet in varnames_list:

            obs_store_config = {
                DataManager.SP_BASE_FOLDER: daymet_vname_to_path[vname_daymet],
                DataManager.SP_DATASOURCE_TYPE: data_source_types.ALL_VARS_IN_A_FOLDER_IN_NETCDF_FILES,
                DataManager.SP_INTERNAL_TO_INPUT_VNAME_MAPPING: vm.daymet_vname_mapping,
                DataManager.SP_LEVEL_MAPPING: {}
            }
            obs_manager = DataManager(store_config=obs_store_config)

            vname_model = daymet_vname_to_model_vname_internal[vname_daymet]

            nd_rw = var_name_to_rolling_window_days[vname_daymet]
            q = var_name_to_percentile[vname_daymet]
            daily_agg_func = var_name_to_daily_agg_func[vname_daymet]

            # the arguments shared by the model and the obs computations
            common_kwargs = dict(start_year=start_year, end_year=end_year,
                                 daily_agg_func=daily_agg_func,
                                 rolling_mean_window_days=nd_rw, q=q)

            # model data
            mod = dm.compute_climatological_quantiles(varname_internal=vname_model, **common_kwargs)
            mod = mod * model_vname_to_multiplier.get(vname_model, 1)

            # the part shared by the keys of the model, obs and diff datasets
            key_suffix = f"ndrw{nd_rw}_q{q}_vn{vname_daymet}_{start_year}-{end_year}"

            data_source_mod = f"{model_label}_{key_suffix}"

            # obs data: the number of the obs points to use for a model point, at least 1
            nneighbors = max(int(mod_spatial_scales[model_label] / obs_spatial_scale), 1)

            # daily_agg_func does not have effect for daymet data because it is daily
            obs = obs_manager.compute_climatological_quantiles(
                varname_internal=vname_daymet,
                lons_target=mod.coords["lon"].values,
                lats_target=mod.coords["lat"].values,
                nneighbors=nneighbors,
                **common_kwargs)

            # only use model data wherever the obs is not null
            mod = mod.where(obs.notnull())

            data_source_obs = f"DAYMETaggfor_{model_label}_{key_suffix}"

            data_source_diff = f"{model_label}vsDAYMET_{key_suffix}"  # not used below

            # crop both fields to the subregion (the mask itself is not needed)
            _, ij_ll, ij_ur = subregion.to_mask(mod.coords["lon"].values,
                                                mod.coords["lat"].values)

            i_range = slice(ij_ll[0], ij_ur[0] + 1)
            j_range = slice(ij_ll[1], ij_ur[1] + 1)

            mod = mod[:, i_range, j_range]
            obs = obs[:, i_range, j_range]

            # set the units to display them during plotting
            display_units = var_name_to_display_units[vname_daymet]
            mod.attrs["units"] = display_units
            obs.attrs["units"] = display_units

            # save data for line plots
            data_dict[vname_daymet][data_source_mod] = mod
            data_dict[vname_daymet][data_source_obs] = obs
            bias_dict[vname_daymet][data_source_mod] = mod - obs

            # the basemap is created only once, for the first simulation and variable
            if bmap is None:
                bmap = dm.get_basemap(varname_internal=vname_model,
                                      resolution="i",
                                      area_thresh=area_thresh_km2)

    # Titles explaining what the graphs show
    vn_to_title = {
        vm.TOTAL_PREC: "PR90",
        vm.T_AIR_2M_DAILY_MIN: "TN90",
        vm.T_AIR_2M_DAILY_MAX: "TX10"
    }

    elev_field_name = "ME"
    meridional_mean_elev_dict = get_meridional_avg_elevation(
        geo_path_dict=elevation_paths,
        subregion=subregion,
        elev_field_name=elev_field_name)

    # the topography map is taken from the geophysical file of the 0.11 deg simulation
    topo_map = get_topo_map(geo_path=elevation_paths[label_011_modified],
                            elev_field_name=elev_field_name)

    for vn, vn_data in data_dict.items():

        # nothing to plot for the variable
        if not vn_data:
            continue

        vn_bias = bias_dict[vn]

        plot_meridional_mean(vn_data,
                             vn_bias,
                             panel_titles=(vn_to_title[vn] + " (annual)", ""),
                             img_dir=img_folder,
                             bmap=bmap,
                             meridional_elev_dict=meridional_mean_elev_dict,
                             map_topo=topo_map)

        for sname, months in season_to_months.items():
            # the legend is shown only on the SON panel of the daily maximum temperature
            show_legend = (vn == vm.T_AIR_2M_DAILY_MAX) and (sname == "SON")

            plot_meridional_mean(vn_data,
                                 vn_bias,
                                 panel_titles=("", ""),
                                 img_dir=img_folder,
                                 bmap=bmap,
                                 months=months,
                                 season_name=sname,
                                 meridional_elev_dict=meridional_mean_elev_dict,
                                 map_topo=None,
                                 plot_values=False,
                                 lon_min=236,
                                 lon_max=247,
                                 plot_legend=show_legend)
def main():
    """
    Validate two CRCM5 simulations (HL and NEMO) against gridded observations.

    For each variable listed in ``vars_of_interest`` the function:

    1. reads the observed seasonal means, averages them over the years and
       remaps them to a subdomain of the HL simulation grid using the indices
       returned by a kd-tree query;
    2. does the same for each simulation in ``sim_configs`` and converts the
       simulated climatologies into biases (simulation minus observations);
    3. draws one figure with the observed fields, the biases and the
       NEMO-minus-HL differences, masking the points where the p-value of the
       t-test (``equal_var=False``) exceeds ``p_crit``;
    4. saves the figure as a PNG file inside ``img_folder``.

    All the inputs (paths, years, seasons) are hard-coded below and nothing is
    returned.

    NOTE(review): ``img_folder`` is not defined in this function; it is expected
    to be available at module level - confirm.
    """

    obs_data_path = Path("/RESCUE/skynet3_rech1/huziy/obs_data_for_HLES/interploated_to_the_same_grid/GL_0.1_452x260/anusplin+_interpolated_tt_pr.nc")

    start_year = 1980
    end_year = 2010

    HL_LABEL = "CRCM5_HL"
    NEMO_LABEL = "CRCM5_NEMO"

    # critical p-value for the ttest aka significance level
    p_crit = 0.1

    # Other variables with seasons configured below: T_AIR_2M, TOTAL_PREC, SWE
    vars_of_interest = [
        LAKE_ICE_FRACTION
    ]

    coastline_width = 0.3

    vname_to_seasonmonths_map = {
        SWE: OrderedDict([("November", [11]),
                          ("December", [12]),
                          ("January", [1, ])]),
        LAKE_ICE_FRACTION: OrderedDict([
            ("February", [2, ]),
            ("March", [3, ]), ]),
        T_AIR_2M: season_to_months,
        TOTAL_PREC: OrderedDict([
            ("Winter", [12, 1, 2]),
            ("Summer", [6, 7, 8]),
        ])
    }

    sim_configs = {

        HL_LABEL: RunConfig(data_path="/RECH2/huziy/coupling/GL_440x260_0.1deg_GL_with_Hostetler/Samples_selected",
                            start_year=start_year, end_year=end_year, label=HL_LABEL),

        NEMO_LABEL: RunConfig(data_path="/RECH2/huziy/coupling/coupled-GL-NEMO1h_30min/selected_fields",
                              start_year=start_year, end_year=end_year, label=NEMO_LABEL),
    }

    sim_labels = [HL_LABEL, NEMO_LABEL]

    vname_to_level = {
        T_AIR_2M: VerticalLevel(1, level_kinds.HYBRID),
        U_WE: VerticalLevel(1, level_kinds.HYBRID),
        V_SN: VerticalLevel(1, level_kinds.HYBRID),
    }

    # The land fraction is used further down to mask the SWE differences.
    # Any error raised while reading it propagates to the caller (the previous
    # try/except around these lines only re-raised the exception).
    first_ts_file = Path(sim_configs[HL_LABEL].data_path).parent / "pm1979010100_00000000p"
    land_fraction = get_land_fraction(first_timestep_file=first_ts_file)

    # Calculations

    # prepare params for interpolation
    lons_t, lats_t, bsmap = get_target_lons_lats_basemap(sim_configs[HL_LABEL])

    # get a subdomain of the simulation domain
    nx, ny = lons_t.shape
    # integer division for both limits: "ny / 2" is a float in Python 3
    iss = IndexSubspace(i_start=20, j_start=20, i_end=nx // 2, j_end=ny // 2)
    # just to change basemap limits
    lons_t, lats_t, bsmap = get_target_lons_lats_basemap(sim_configs[HL_LABEL], sub_space=iss, resolution="i", area_thresh=2000)

    xt, yt, zt = lat_lon.lon_lat_to_cartesian(lons_t.flatten(), lats_t.flatten())

    def to_target_grid(field, indices):
        # Pick the source points selected by the kd-tree query and give the
        # result the shape of the target grid
        return field.flatten()[indices].reshape(lons_t.shape)

    def get_norm_and_cmap(clevs, cmap_name):
        # Discrete colormap and norm when the contour levels are known,
        # otherwise the continuous colormap without a norm
        if clevs is None:
            return None, cmap_name
        return BoundaryNorm(clevs, len(clevs) - 1), cm.get_cmap(cmap_name, len(clevs) - 1)

    vname_map = {}
    vname_map.update(default_varname_mappings.vname_map_CRCM5)

    # Read and calculate observed seasonal means
    store_config = {
        "base_folder": obs_data_path.parent,
        "data_source_type": data_source_types.ALL_VARS_IN_A_FOLDER_IN_NETCDF_FILES_OPEN_EACH_FILE_SEPARATELY,
        "varname_mapping": vname_map,
        "level_mapping": vname_to_level,
        "offset_mapping": default_varname_mappings.vname_to_offset_CRCM5,
        "multiplier_mapping": default_varname_mappings.vname_to_multiplier_CRCM5,
    }

    obs_dm = DataManager(store_config=store_config)
    obs_data = {}

    # need to save it for ttesting
    obs_vname_to_season_to_std = {}
    obs_vname_to_season_to_nobs = {}

    interp_indices = None
    for vname in vars_of_interest:
        # the period is cut at 1996 for SWE
        end_year_for_current_var = end_year
        if vname == SWE:
            end_year_for_current_var = min(1996, end_year)

        seas_to_year_to_mean = obs_dm.get_seasonal_means(varname_internal=vname,
                                                         start_year=start_year,
                                                         end_year=end_year_for_current_var,
                                                         season_to_months=vname_to_seasonmonths_map[vname])

        seas_to_clim = {seas: np.array(list(y_to_means.values())).mean(axis=0)
                        for seas, y_to_means in seas_to_year_to_mean.items()}
        obs_data[vname] = seas_to_clim

        # the kd-tree is queried after the first read and the indices are reused
        if interp_indices is None:
            _, interp_indices = obs_dm.get_kdtree().query(list(zip(xt, yt, zt)))

        # need for ttests
        season_to_std = {}
        obs_vname_to_season_to_std[vname] = season_to_std

        season_to_nobs = {}
        obs_vname_to_season_to_nobs[vname] = season_to_nobs

        for season in seas_to_clim:
            # only the values are replaced, so iterating over the same dict is safe
            seas_to_clim[season] = to_target_grid(seas_to_clim[season], interp_indices)

            # interannual standard deviation of the remapped yearly means
            season_to_std[season] = np.asarray(
                [to_target_grid(field, interp_indices)
                 for field in seas_to_year_to_mean[season].values()]).std(axis=0)

            # number of years, as a field, for the ttest
            season_to_nobs[season] = np.ones_like(lons_t) * len(seas_to_year_to_mean[season])

    # Read and calculate simulated seasonal mean biases
    mod_label_to_vname_to_season_to_std = {}
    mod_label_to_vname_to_season_to_nobs = {}

    sim_data = defaultdict(dict)
    for label, r_config in sim_configs.items():

        store_config = {
            "base_folder": r_config.data_path,
            "data_source_type": data_source_types.SAMPLES_FOLDER_FROM_CRCM_OUTPUT_VNAME_IN_FNAME,
            "varname_mapping": vname_map,
            "level_mapping": vname_to_level,
            "offset_mapping": default_varname_mappings.vname_to_offset_CRCM5,
            "multiplier_mapping": default_varname_mappings.vname_to_multiplier_CRCM5,
        }

        dm = DataManager(store_config=store_config)

        mod_label_to_vname_to_season_to_std[label] = {}
        mod_label_to_vname_to_season_to_nobs[label] = {}

        # the indices are recomputed for each simulation
        interp_indices = None
        for vname in vars_of_interest:

            # same period as used for the observations
            end_year_for_current_var = end_year
            if vname == SWE:
                end_year_for_current_var = min(1996, end_year)

            seas_to_year_to_mean = dm.get_seasonal_means(varname_internal=vname,
                                                         start_year=start_year,
                                                         end_year=end_year_for_current_var,
                                                         season_to_months=vname_to_seasonmonths_map[vname])

            # get the climatology
            seas_to_clim = {seas: np.array(list(y_to_means.values())).mean(axis=0)
                            for seas, y_to_means in seas_to_year_to_mean.items()}

            sim_data[label][vname] = seas_to_clim

            if interp_indices is None:
                _, interp_indices = dm.get_kdtree().query(list(zip(xt, yt, zt)))

            season_to_std = {}
            mod_label_to_vname_to_season_to_std[label][vname] = season_to_std

            season_to_nobs = {}
            mod_label_to_vname_to_season_to_nobs[label][vname] = season_to_nobs

            for season in seas_to_clim:
                # sim_data holds biases (simulation - observations) from here on
                interpolated_field = to_target_grid(seas_to_clim[season], interp_indices)
                seas_to_clim[season] = interpolated_field - obs_data[vname][season]

                # calculate standard deviations of the interpolated fields
                season_to_std[season] = np.asarray(
                    [to_target_grid(field, interp_indices)
                     for field in seas_to_year_to_mean[season].values()]).std(axis=0)

                # calculate numobs for the ttest
                season_to_nobs[season] = np.ones_like(lons_t) * len(seas_to_year_to_mean[season])

    # Plotting: interpolate to the same grid and plot obs and biases

    # project the coordinates before the longitudes are shifted to [-180, 180]
    xx, yy = bsmap(lons_t, lats_t)
    lons_t[lons_t > 180] -= 360

    draw_only_first_sim_biases = True
    for vname in vars_of_interest:

        seasons = vname_to_seasonmonths_map[vname]

        field_mask = maskoceans(lons_t, lats_t, np.zeros_like(lons_t), inlands=vname in [SWE]).mask
        field_mask_lakes = maskoceans(lons_t, lats_t, np.zeros_like(lons_t), inlands=True).mask
        the_mask = field_mask_lakes if vname in [T_AIR_2M, TOTAL_PREC, SWE] else field_mask

        # rows: obs, biases (for one or for all the simulations), NEMO - HL
        nrows = len(sim_configs) + 2 - 1 * int(draw_only_first_sim_biases)
        ncols = len(seasons)

        plot_utils.apply_plot_params(width_cm=8 * len(seasons), height_cm=4.5 * nrows, font_size=8)
        fig = plt.figure()

        gs = GridSpec(nrows=nrows, ncols=ncols, hspace=0.2, wspace=0.02)

        extend = "both" if vname not in [LAKE_ICE_FRACTION] else "neither"

        # Plot the obs fields
        current_row = 0
        for col, season in enumerate(seasons):
            field = obs_data[vname][season]
            ax = fig.add_subplot(gs[current_row, col])

            to_plot = np.ma.masked_where(the_mask, field) * internal_name_to_multiplier[vname]
            clevs = get_clevs(vname)
            bnorm, cmap = get_norm_and_cmap(clevs, "viridis")

            cs = bsmap.contourf(xx, yy, to_plot, ax=ax, levels=clevs, norm=bnorm, cmap=cmap)
            bsmap.drawcoastlines(linewidth=coastline_width)
            cb = bsmap.colorbar(cs, ax=ax, location="bottom")

            ax.set_frame_on(vname not in [LAKE_ICE_FRACTION, ])

            # the colorbar is shown only for the first column
            cb.ax.set_visible(col == 0)

            if col == 0:
                ax.set_ylabel("Obs")

        # plot the biases
        for sim_label in sim_labels:
            current_row += 1
            for col, season in enumerate(seasons):

                field = sim_data[sim_label][vname][season]

                ax = fig.add_subplot(gs[current_row, col])

                clevs = get_clevs(vname + "bias")
                bnorm, cmap = get_norm_and_cmap(clevs, "bwr")

                to_plot = np.ma.masked_where(the_mask, field) * internal_name_to_multiplier[vname]

                # ttest
                a = sim_data[sim_label][vname][season] + obs_data[vname][season]  # Calculate the simulation data back from biases
                std_a = mod_label_to_vname_to_season_to_std[sim_label][vname][season]
                nobs_a = mod_label_to_vname_to_season_to_nobs[sim_label][vname][season]

                b = obs_data[vname][season]
                std_b = obs_vname_to_season_to_std[vname][season]
                nobs_b = obs_vname_to_season_to_nobs[vname][season]

                _, p = ttest_ind_from_stats(mean1=a, std1=std_a, nobs1=nobs_a,
                                            mean2=b, std2=std_b, nobs2=nobs_b, equal_var=False)

                # Mask non-significant differences as given by the ttest
                to_plot = np.ma.masked_where(p > p_crit, to_plot)

                cs = bsmap.contourf(xx, yy, to_plot, ax=ax, extend=extend, levels=clevs, cmap=cmap, norm=bnorm)
                bsmap.drawcoastlines(linewidth=coastline_width)
                cb = bsmap.colorbar(cs, ax=ax, location="bottom")

                ax.set_frame_on(vname not in [LAKE_ICE_FRACTION, ])
                cb.ax.set_visible(False)

                if col == 0:
                    ax.set_ylabel("{}\n-\nObs.".format(sim_label))

            # draw biases only for the first simulation
            if draw_only_first_sim_biases:
                break

        # plot differences between the biases
        current_row += 1
        for col, season in enumerate(seasons):

            field = sim_data[NEMO_LABEL][vname][season] - sim_data[HL_LABEL][vname][season]

            ax = fig.add_subplot(gs[current_row, col])

            clevs = get_clevs(vname + "bias")
            bnorm, cmap = get_norm_and_cmap(clevs, "bwr")

            # no ocean/lake mask is applied to the differences
            to_plot = field * internal_name_to_multiplier[vname]

            # ttest
            a = sim_data[NEMO_LABEL][vname][season] + obs_data[vname][season]  # Calculate the simulation data back from biases
            std_a = mod_label_to_vname_to_season_to_std[NEMO_LABEL][vname][season]
            nobs_a = mod_label_to_vname_to_season_to_nobs[NEMO_LABEL][vname][season]

            b = sim_data[HL_LABEL][vname][season] + obs_data[vname][season]  # Calculate the simulation data back from biases
            std_b = mod_label_to_vname_to_season_to_std[HL_LABEL][vname][season]
            nobs_b = mod_label_to_vname_to_season_to_nobs[HL_LABEL][vname][season]

            _, p = ttest_ind_from_stats(mean1=a, std1=std_a, nobs1=nobs_a,
                                        mean2=b, std2=std_b, nobs2=nobs_b, equal_var=False)

            # Mask non-significant differences as given by the ttest
            to_plot = np.ma.masked_where(p > p_crit, to_plot)

            # mask the points with not sufficient land fraction
            if land_fraction is not None and vname in [SWE, ]:
                to_plot = np.ma.masked_where(land_fraction < 0.1, to_plot)

            cs = bsmap.contourf(xx, yy, to_plot, ax=ax, extend=extend, levels=clevs, cmap=cmap, norm=bnorm)
            bsmap.drawcoastlines(linewidth=coastline_width)
            cb = bsmap.colorbar(cs, ax=ax, location="bottom")

            # the season names are written over the last row only
            ax.text(0.99, 1.1, season, va="top", ha="right", fontsize=16, transform=ax.transAxes)

            cb.ax.set_visible(col == 0)

            assert isinstance(ax, Axes)
            ax.set_frame_on(False)

            if col == 0:
                ax.set_ylabel("{}\n-\n{}".format(NEMO_LABEL, HL_LABEL))

        # save a figure per variable
        img_file = "seasonal_biases_{}_{}_{}-{}.png".format(vname,
                                                            "-".join(seasons),
                                                            start_year, end_year)
        img_file = img_folder.joinpath(img_file)

        fig.savefig(str(img_file), dpi=300, bbox_inches="tight")

        plt.close(fig)
Exemplo n.º 17
0
def main():
    """
    Plot the river storage statistics of the current and of the future climate
    simulations (CanESM2 driven) for the region given by a shape file.

    The data of each simulation are obtained with
    ``__get_maximum_storage_and_corresponding_dates`` and plotted with
    ``__plot_vals`` as deviations from the bankfull storage. All the paths
    and periods are hard-coded; nothing is returned.
    """
    region_of_interest_shp = "data/shp/mtl_flood_2017_basins/02JKL_SDA_Ottawa.shp"

    current_simlabel = "GL_Current_CanESM2"
    future_simlabel = "GL_Future_CanESM2"

    river_storage_varname = "SWSR"
    lake_storage_varname = "SWSL"

    # (start year, end year) of each simulation, current climate first
    label_to_years = OrderedDict([
        (current_simlabel, (1989, 2010)),
        (future_simlabel, (2079, 2100)),
    ])

    base_sim_dir = Path("/snow3/huziy/NEI/GL/GL_CC_CanESM2_RCP85")
    label_to_sim_path = OrderedDict([
        (current_simlabel, base_sim_dir / "coupled-GL-current_CanESM2" / "Samples"),
        (future_simlabel, base_sim_dir / "coupled-GL-future_CanESM2" / "Samples"),
    ])

    # mappings shared by both data managers
    storage_varnames = (river_storage_varname, lake_storage_varname)
    varname_mapping = {name: name for name in storage_varnames}
    vname_to_fname_prefix = {name: "pm" for name in storage_varnames}
    level_mapping = {river_storage_varname: VerticalLevel(-1)}

    def create_data_manager(sim_label):
        # Data manager reading the samples folder of the given simulation
        return DataManager(
            store_config={
                "base_folder": str(label_to_sim_path[sim_label]),
                "data_source_type": data_source_types.SAMPLES_FOLDER_FROM_CRCM_OUTPUT,
                "varname_to_filename_prefix_mapping": vname_to_fname_prefix,
                "varname_mapping": varname_mapping,
                "level_mapping": level_mapping
            })

    # both managers are created before any data is read
    label_to_dm = OrderedDict(
        (sim_label, create_data_manager(sim_label)) for sim_label in label_to_years)

    label_to_ds = OrderedDict()
    for sim_label, (first_year, last_year) in label_to_years.items():
        label_to_ds[sim_label] = __get_maximum_storage_and_corresponding_dates(
            first_year,
            last_year,
            data_manager=label_to_dm[sim_label],
            storage_varname=river_storage_varname)

    # get constant in time geophysical fields
    bf_storage = __read_bankfull_storage()

    # coordinates and basemap are shared by the two plots
    lons, lats, bmap = __get_lons_lats_basemap_from_rpn(
        resolution="i", region_of_interest_shp=region_of_interest_shp)

    # plot the current climate values first, then the future ones
    for sim_label, (first_year, last_year) in label_to_years.items():
        __plot_vals(label_to_ds[sim_label],
                    bmap,
                    lons,
                    lats,
                    label="storage_{}-{}".format(first_year, last_year),
                    storage_var_name=river_storage_varname,
                    bankfull_storage=bf_storage,
                    region_of_interest_shp=region_of_interest_shp,
                    plot_deviations_from_bankfull_storage=True)
Exemplo n.º 18
0
def get_streamflow_at(lon=-100.,
                      lat=50.,
                      data_source_base_dir="",
                      period=None,
                      varname=default_varname_mappings.STREAMFLOW):
    """
    Read a variable for the given period and interpolate it to a single point.

    Uses caching: the result is pickled into the ``point_data_cache`` folder
    (created in the current working directory when missing) and is loaded
    from there on the next calls with the same arguments.

    :param lon: longitude of the point of interest
    :param lat: latitude of the point of interest
    :param data_source_base_dir: folder with the model outputs; its sha224 hash
        is a part of the cache file name
    :param period: object with the ``start`` and ``end`` attributes (they are
        used in the cache file name); it is passed as is to
        ``DataManager.read_data_for_period_and_interpolate``. Required in spite
        of the ``None`` default.
    :param varname: internal name of the variable to read
    :return: whatever ``DataManager.read_data_for_period_and_interpolate``
        returns for the point (or its cached copy)
    """
    cache_dir = Path("point_data_cache")
    cache_dir.mkdir(parents=True, exist_ok=True)

    bd_sha = hashlib.sha224(data_source_base_dir.encode()).hexdigest()

    cache_file = cache_dir / f"{varname}_lon{lon}_lat{lat}_{period.start}-{period.end}_{bd_sha}.bin"

    # NOTE: the cache files are pickles, only load the ones written by this function
    if cache_file.exists():
        # close the file as soon as the data are loaded
        with cache_file.open("rb") as cache_in:
            return pickle.load(cache_in)

    vname_to_level_erai = {
        T_AIR_2M: VerticalLevel(1, level_kinds.HYBRID),
        U_WE: VerticalLevel(1, level_kinds.HYBRID),
        V_SN: VerticalLevel(1, level_kinds.HYBRID),
    }

    vname_map = {}
    vname_map.update(vname_map_CRCM5)

    store_config = {
        DataManager.SP_BASE_FOLDER:
        data_source_base_dir,
        DataManager.SP_DATASOURCE_TYPE:
        data_source_types.SAMPLES_FOLDER_FROM_CRCM_OUTPUT,
        DataManager.SP_INTERNAL_TO_INPUT_VNAME_MAPPING:
        vname_map,
        DataManager.SP_LEVEL_MAPPING:
        vname_to_level_erai,
        DataManager.SP_OFFSET_MAPPING:
        vname_to_offset_CRCM5,
        DataManager.SP_MULTIPLIER_MAPPING:
        vname_to_multiplier_CRCM5,
        DataManager.SP_VARNAME_TO_FILENAME_PREFIX_MAPPING:
        vname_to_fname_prefix_CRCM5,
    }

    dm = DataManager(store_config=store_config)

    # the target "grid" consists of the single point of interest
    lons_ = np.asarray([lon])
    lats_ = np.asarray([lat])

    data = dm.read_data_for_period_and_interpolate(period=period,
                                                   varname_internal=varname,
                                                   lons_target=lons_,
                                                   lats_target=lats_)

    # the context manager flushes and closes the cache file
    with cache_file.open("wb") as cache_out:
        pickle.dump(data, cache_out)
    return data
Exemplo n.º 19
0
def calculate_lake_effect_snowfall_each_year_in_parallel(
        label_to_config,
        period=None,
        months_of_interest=None,
        nprocs_to_use=None):
    """
    Run ``enh_lakeffect_snfall_calculator_proc`` for each year of the period
    and for each data source, using a pool of processes when more than one
    process is requested.

    :param label_to_config: mapping from a label to the store config used to
        create a ``DataManager``; the optional "out_folder" key of the config
        gives the output folder ("." by default)
    :param period:  The period of interest defined by the start and the end year of the period (inclusive)
    :param months_of_interest: when given, it is saved to
        ``period.months_of_interest``; otherwise the period has to provide
        this attribute already
    :param nprocs_to_use: number of processes; by default 75% of the cpu
        count, limited by the number of years in the period (at least 1)
    :raises AssertionError: if the months of interest are not defined
    """

    if months_of_interest is not None:
        period.months_of_interest = months_of_interest

    assert hasattr(period, "months_of_interest")

    for label, the_config in label_to_config.items():
        data_manager = DataManager(store_config=the_config)

        print(the_config)

        out_folder = Path(the_config["out_folder"] if "out_folder" in the_config else ".")

        # Create the output folder (with its parents) if it does not exist;
        # exist_ok covers the case when another process creates it in between
        proc_name = multiprocessing.current_process().name
        if out_folder.exists():
            print("{}: {} already exists".format(proc_name, out_folder))
        else:
            out_folder.mkdir(parents=True, exist_ok=True)
            print("{}: {} created".format(proc_name, out_folder))

        if nprocs_to_use is None:
            # Use a fraction of the available processes
            nprocs_to_use = max(int(multiprocessing.cpu_count() * 0.75), 1)
            # No need for more processes than there is of years
            nprocs_to_use = min(nprocs_to_use, period.in_years())
            # make sure that nprocs_to_use is not 0
            nprocs_to_use = max(1, nprocs_to_use)

        print("Using {} processes for parallelization".format(nprocs_to_use))

        # Construct the input params for each process
        in_data = []

        for start in period.range("years"):
            # one sub-period per year, it is not allowed to go beyond the end of the period
            end_date = start.add(
                months=len(period.months_of_interest)).subtract(seconds=1)
            end_date = min(end_date, period.end)
            p = Period(start=start, end=end_date)
            p.months_of_interest = period.months_of_interest
            in_data.append([
                data_manager, label,
                [p.start, p.end, period.months_of_interest], out_folder
            ])
        print(in_data)

        if nprocs_to_use > 1:
            # map() returns when all the years are processed, then the context
            # manager releases the worker processes (a new pool is created for each label)
            with Pool(processes=nprocs_to_use) as pool:
                pool.map(enh_lakeffect_snfall_calculator_proc, in_data)
        else:
            for current_in_data in in_data:
                enh_lakeffect_snfall_calculator_proc(current_in_data)

        # drop the references before going to the next label
        del in_data
        del data_manager
def get_seasonal_sst_from_crcm5_outputs(sim_label, start_year=1980, end_year=2010, season_to_months=None,
                                        lons_target=None, lats_target=None):
    """
    Calculate the seasonal means of the lake water temperature
    (``default_varname_mappings.LAKE_WATER_TEMP``) averaged over the years
    ``start_year``..``end_year``, using a hard-coded folder of CRCM5 samples.

    :param sim_label: label assigned to the run configuration
    :param start_year: first year passed to ``DataManager.get_seasonal_means``
    :param end_year: last year passed to ``DataManager.get_seasonal_means``
    :param season_to_months: mapping from a season name to its months, required
        in spite of the ``None`` default
    :param lons_target: when not None, the fields are remapped to the target
        points using the kd-tree of the data manager
    :param lats_target: latitudes corresponding to ``lons_target``
    :return: dictionary {season: mean field}
    """
    from lake_effect_snow.default_varname_mappings import T_AIR_2M, U_WE, V_SN
    from lake_effect_snow.base_utils import VerticalLevel
    from rpn import level_kinds
    from lake_effect_snow import default_varname_mappings
    from data.robust import data_source_types
    from data.robust.data_manager import DataManager

    run_cfg = RunConfig(data_path="/RECH2/huziy/coupling/GL_440x260_0.1deg_GL_with_Hostetler/Samples_selected",
                        start_year=start_year, end_year=end_year, label=sim_label)

    hybrid_level = level_kinds.HYBRID
    level_by_vname = {
        T_AIR_2M: VerticalLevel(1, hybrid_level),
        U_WE: VerticalLevel(1, hybrid_level),
        V_SN: VerticalLevel(1, hybrid_level),
        default_varname_mappings.LAKE_WATER_TEMP: VerticalLevel(1, level_kinds.ARBITRARY)
    }

    # a copy, so that the default mapping is not shared with the manager
    internal_to_input_vname = dict(default_varname_mappings.vname_map_CRCM5)

    manager = DataManager(store_config={
        "base_folder": run_cfg.data_path,
        "data_source_type": data_source_types.SAMPLES_FOLDER_FROM_CRCM_OUTPUT_VNAME_IN_FNAME,
        "varname_mapping": internal_to_input_vname,
        "level_mapping": level_by_vname,
        "offset_mapping": default_varname_mappings.vname_to_offset_CRCM5,
        "multiplier_mapping": default_varname_mappings.vname_to_multiplier_CRCM5,
    })

    yearly_means_by_season = manager.get_seasonal_means(
        start_year=start_year, end_year=end_year,
        season_to_months=season_to_months,
        varname_internal=default_varname_mappings.LAKE_WATER_TEMP)

    # average the yearly fields of each season
    result = {
        season: np.array(list(yearly_means_by_season[season].values())).mean(axis=0)
        for season in season_to_months
    }

    # nothing to interpolate without the target coordinates
    if lons_target is None:
        return result

    x_t, y_t, z_t = lat_lon.lon_lat_to_cartesian(lons_target.flatten(), lats_target.flatten())
    _, nearest = manager.get_kdtree().query(list(zip(x_t, y_t, z_t)))

    target_shape = lons_target.shape
    for season in season_to_months:
        result[season] = result[season].flatten()[nearest].reshape(target_shape)

    return result
Exemplo n.º 21
0
def main(label_to_data_path: dict,
         var_pairs: list,
         periods_info: CcPeriodsInfo,
         vname_display_names=None,
         season_to_months: dict = None,
         cur_label=common_params.crcm_nemo_cur_label,
         fut_label=common_params.crcm_nemo_fut_label,
         hles_region_mask=None,
         lakes_mask=None):
    """
    Plot the fields of correlations between the pairs of variables for the
    current and the future climate, one column per season and one row per
    (pair, simulation label), and save the figure to ``common_params.img_folder``.

    :param label_to_data_path: mapping from a simulation label to the folder
        with its netcdf files, has to contain ``cur_label`` and ``fut_label``
    :param var_pairs: pairs of variable names to correlate; each pair is used
        as a key of the result of ``calculate_correlations_and_pvalues``
    :param periods_info: provides the year limits of the current and of the
        future periods
    :param vname_display_names: optional mapping from a variable name to the
        name shown on the plot; the variable name itself is shown for the
        missing entries
    :param season_to_months: mapping from a season name to its months,
        required in spite of the ``None`` default
    :param cur_label: label of the current climate simulation
    :param fut_label: label of the future climate simulation
    :param hles_region_mask: mask of the region of interest, the correlations
        are hidden where it is False/0; all the points are used by default
    :param lakes_mask: passed as is to ``calculate_correlations_and_pvalues``

    The seasonal means of the "TT" variable are used for the 0 degree contour,
    so "TT" has to be present in ``var_pairs``.
    """
    # get a flat list of all the required variable names (unique)
    varnames = []
    for vpair in var_pairs:
        for v in vpair:
            if v not in varnames:
                varnames.append(v)

    print(f"Considering {varnames}, based on {var_pairs}")

    if vname_display_names is None:
        vname_display_names = {}

    varname_mapping = {v: v for v in varnames}
    level_mapping = {
        v: VerticalLevel(0)
        for v in varnames
    }  # Does not really make a difference, since all variables are 2d

    comon_store_config = {
        DataManager.SP_DATASOURCE_TYPE:
        data_source_types.ALL_VARS_IN_A_FOLDER_IN_NETCDF_FILES,
        DataManager.SP_INTERNAL_TO_INPUT_VNAME_MAPPING: varname_mapping,
        DataManager.SP_LEVEL_MAPPING: level_mapping
    }

    cur_dm = DataManager(store_config=dict(
        {DataManager.SP_BASE_FOLDER: label_to_data_path[cur_label]},
        **comon_store_config))

    fut_dm = DataManager(store_config=dict(
        {DataManager.SP_BASE_FOLDER: label_to_data_path[fut_label]},
        **comon_store_config))

    cur_start_yr, cur_end_year = periods_info.get_cur_year_limits()
    fut_start_yr, fut_end_year = periods_info.get_fut_year_limits()

    # load coordinates in memory
    cur_dm.read_data_for_period(Period(datetime(cur_start_yr, 1, 1),
                                       datetime(cur_start_yr, 1, 2)),
                                varname_internal=varnames[0])

    # get the data: label -> varname -> season -> year -> field
    label_to_vname_to_season_to_data = {cur_label: {}, fut_label: {}}

    for vname in varnames:
        cur_means = cur_dm.get_seasonal_means(
            start_year=cur_start_yr,
            end_year=cur_end_year,
            season_to_months=season_to_months,
            varname_internal=vname)

        fut_means = fut_dm.get_seasonal_means(
            start_year=fut_start_yr,
            end_year=fut_end_year,
            season_to_months=season_to_months,
            varname_internal=vname)

        label_to_vname_to_season_to_data[cur_label][vname] = cur_means
        label_to_vname_to_season_to_data[fut_label][vname] = fut_means

    if hles_region_mask is None:
        # No mask given: select all the points. The shape is taken from the
        # first available field (first variable, first season, first year).
        first_season = next(iter(season_to_months))
        year_to_field = label_to_vname_to_season_to_data[cur_label][varnames[0]][first_season]
        data_field = next(iter(year_to_field.values()))
        # boolean dtype, since the mask is inverted with "~" for plotting
        hles_region_mask = np.ones(np.shape(data_field), dtype=bool)

    correlation_data = calculate_correlations_and_pvalues(
        var_pairs,
        label_to_vname_to_season_to_data,
        season_to_months=season_to_months,
        region_of_interest_mask=hles_region_mask,
        lats=cur_dm.lats,
        lakes_mask=lakes_mask)

    # Calculate mean seasonal temperature
    label_to_season_to_tt_mean = {}
    for label, vname_to_season_to_data in label_to_vname_to_season_to_data.items():
        label_to_season_to_tt_mean[label] = {}
        for season, yearly_data in vname_to_season_to_data["TT"].items():
            label_to_season_to_tt_mean[label][season] = np.mean(
                list(yearly_data.values()), axis=0)

    # Points to hide on the maps. The explicit conversion to bool makes "~" a
    # logical negation also for the 0/1 integer masks.
    outside_region = ~np.asarray(hles_region_mask, dtype=bool)

    # do the plotting
    fig = plt.figure()

    ncols = len(season_to_months)
    nrows = len(var_pairs) * len(label_to_vname_to_season_to_data)

    gs = GridSpec(nrows, ncols, wspace=0, hspace=0)

    for col, season in enumerate(season_to_months):
        row = 0

        for vpair in var_pairs:
            for label in sorted(label_to_vname_to_season_to_data):
                ax = fig.add_subplot(gs[row, col],
                                     projection=cartopy.crs.PlateCarree())

                r, _ = correlation_data[vpair][label][season]

                # undefined correlations are shown as 0 (the array is modified in place)
                r[np.isnan(r)] = 0
                r = np.ma.masked_where(outside_region, r)
                ax.set_facecolor("0.75")

                # hide the ticks
                ax.xaxis.set_major_locator(NullLocator())
                ax.yaxis.set_major_locator(NullLocator())

                im = ax.pcolormesh(cur_dm.lons,
                                   cur_dm.lats,
                                   r,
                                   cmap=cm.get_cmap("bwr", 11),
                                   vmin=-1,
                                   vmax=1)

                # add 0 deg line
                ax.contour(cur_dm.lons,
                           cur_dm.lats,
                           label_to_season_to_tt_mean[label][season],
                           levels=[
                               0,
                           ],
                           linewidths=1,
                           colors="k")
                ax.set_extent([
                    cur_dm.lons[0, 0], cur_dm.lons[-1, -1], cur_dm.lats[0, 0],
                    cur_dm.lats[-1, -1]
                ])

                ax.background_patch.set_facecolor("0.75")

                if row == 0:
                    # the season name over the top row
                    ax.text(0.5,
                            1.05,
                            season,
                            transform=ax.transAxes,
                            va="bottom",
                            ha="center",
                            multialignment="center")

                if col == 0:
                    # row label to the left of the first column; fall back to
                    # the variable name when there is no display name for it
                    display_name = vname_display_names.get(vpair[1], vpair[1])
                    ax.text(
                        -0.05,
                        0.5,
                        f"HLES\nvs {display_name}\n{label}",
                        va="center",
                        ha="right",
                        multialignment="center",
                        rotation=90,
                        transform=ax.transAxes)

                # An axes for the colorbar is added to each panel, so all the
                # panels have the same size; it is visible only in the last one
                divider = make_axes_locatable(ax)
                ax_cb = divider.new_horizontal(size="5%",
                                               pad=0.1,
                                               axes_class=plt.Axes)
                fig.add_axes(ax_cb)
                cb = plt.colorbar(im, extend="both", cax=ax_cb)

                if row < nrows - 1 or col < ncols - 1:
                    cb.ax.set_visible(False)

                row += 1

    img_dir = common_params.img_folder
    img_dir.mkdir(exist_ok=True)

    img_file = img_dir / "hles_tt_pr_correlation_fields_cur_and_fut_mean_ice_fraction.png"
    fig.savefig(str(img_file), **common_params.image_file_options)

    # release the figure once it is saved
    plt.close(fig)
Exemplo n.º 22
0
def main():
    """
    Compare simulated temperature and precipitation quantiles with Daymet observations.

    For every simulation in ``sim_paths`` and every variable in ``varnames_list``:
      * climatological quantiles are computed from the model output;
      * the same quantiles are computed from the Daymet data, requested at the
        lon/lat points of the model field;
      * the model field is kept only where the observations are not null;
      * monthly panels are plotted for the model, the observations and model - obs.
    After all simulations are processed, area-average plots are produced per variable.

    The following names are used but not defined here and must be in scope:
    area_thresh_km2, clevs, cmaps, plot_monthly_panels, plot_area_avg.
    :return:
    """
    img_folder = Path("nei_validation")
    img_folder.mkdir(parents=True, exist_ok=True)

    start_year = 1980
    end_year = 1998

    # rolling window width (in days) handed to compute_climatological_quantiles
    var_name_to_rolling_window_days = {
        default_varname_mappings.T_AIR_2M_DAILY_MIN: 5,
        default_varname_mappings.T_AIR_2M_DAILY_MAX: 5,
        default_varname_mappings.TOTAL_PREC: 29,
    }

    # quantile level (q) requested for each variable
    var_name_to_percentile = {
        default_varname_mappings.T_AIR_2M_DAILY_MIN: 0.9,
        default_varname_mappings.T_AIR_2M_DAILY_MAX: 0.1,
        default_varname_mappings.TOTAL_PREC: 0.9,
    }

    # needed for the 3hourly temperature model outputs, when Tmin and Tmax daily are not available
    var_name_to_daily_agg_func = {
        default_varname_mappings.TOTAL_PREC: np.mean,
        default_varname_mappings.T_AIR_2M_DAILY_MAX: np.max,
        default_varname_mappings.T_AIR_2M_DAILY_MIN: np.min,
        default_varname_mappings.T_AIR_2M_DAILY_AVG: np.mean,
    }

    # factor applied to the model quantiles (variables not listed are left unchanged)
    # NOTE(review): 1000 * 24 * 3600 presumably converts m/s to mm/day -- confirm
    model_vname_to_multiplier = {
        default_varname_mappings.TOTAL_PREC: 1000 * 24 * 3600,
    }

    WC_044_DEFAULT_LABEL = "WC_0.44deg_default"
    WC_044_CTEM_FRSOIL_DYNGLA_LABEL = "WC_0.44deg_ctem+frsoil+dyngla"
    WC_011_CTEM_FRSOIL_DYNGLA_LABEL = "WC_0.11deg_ctem+frsoil+dyngla"

    # simulation label -> folder with the model samples (processed in this order)
    sim_paths = OrderedDict()
    sim_paths[WC_011_CTEM_FRSOIL_DYNGLA_LABEL] = Path("/snow3/huziy/NEI/WC/NEI_WC0.11deg_Crr1/Samples")
    sim_paths[WC_044_DEFAULT_LABEL] = Path("/snow3/huziy/NEI/WC/NEI_WC0.44deg_default/Samples")
    sim_paths[WC_044_CTEM_FRSOIL_DYNGLA_LABEL] = Path("/snow3/huziy/NEI/WC/debug_NEI_WC0.44deg_Crr1/Samples")

    # simulation label -> grid spacing, used only to derive the number of obs neighbours
    mod_spatial_scales = OrderedDict([
        (WC_044_DEFAULT_LABEL, 0.44),
        (WC_044_CTEM_FRSOIL_DYNGLA_LABEL, 0.44),
        (WC_011_CTEM_FRSOIL_DYNGLA_LABEL, 0.11),
    ])

    # -- daymet daily (spatially aggregated)
    # (the files at the initial spatial resolution are /snow3/huziy/Daymet_daily/daymet_v3_<var>_*_na.nc4)
    daymet_vname_to_path = {
        default_varname_mappings.TOTAL_PREC: "/snow3/huziy/Daymet_daily_derivatives/daymet_spatial_agg_prcp_10x10",
        default_varname_mappings.T_AIR_2M_DAILY_AVG: "/snow3/huziy/Daymet_daily_derivatives/daymet_spatial_agg_tavg_10x10",
        default_varname_mappings.T_AIR_2M_DAILY_MIN: "/snow3/huziy/Daymet_daily_derivatives/daymet_spatial_agg_tmin_10x10",
        default_varname_mappings.T_AIR_2M_DAILY_MAX: "/snow3/huziy/Daymet_daily_derivatives/daymet_spatial_agg_tmax_10x10",
    }

    # observed variable -> model variable it is compared with
    daymet_vname_to_model_vname_internal = {
        default_varname_mappings.T_AIR_2M_DAILY_MIN: default_varname_mappings.T_AIR_2M,
        default_varname_mappings.T_AIR_2M_DAILY_MAX: default_varname_mappings.T_AIR_2M,
        default_varname_mappings.TOTAL_PREC: default_varname_mappings.TOTAL_PREC,
    }

    plot_utils.apply_plot_params(font_size=8)

    # observations
    obs_spatial_scale = 0.1  # 10x10 aggregation from ~0.01 daymet data

    varnames_list = [
        default_varname_mappings.TOTAL_PREC,
        default_varname_mappings.T_AIR_2M_DAILY_MIN,
        default_varname_mappings.T_AIR_2M_DAILY_MAX,
    ]

    # variable -> {data source label -> field}; collected for the area-average plots
    data_dict = {vn: {} for vn in varnames_list}
    bias_dict = {vn: {} for vn in varnames_list}

    # calculate the percentiles for each simulation and obs data (obs data interpolated to the model grid)
    for model_label, base_dir in sim_paths.items():
        # model outputs manager
        dm = DataManager(
            store_config={
                DataManager.SP_BASE_FOLDER: base_dir,
                DataManager.SP_DATASOURCE_TYPE: data_source_types.SAMPLES_FOLDER_FROM_CRCM_OUTPUT,
                DataManager.SP_INTERNAL_TO_INPUT_VNAME_MAPPING: default_varname_mappings.vname_map_CRCM5,
                DataManager.SP_LEVEL_MAPPING: default_varname_mappings.vname_to_level_map,
                DataManager.SP_VARNAME_TO_FILENAME_PREFIX_MAPPING: default_varname_mappings.vname_to_fname_prefix_CRCM5
            }
        )

        # number of obs points used per model point: ratio of the grid spacings, at least 1
        nneighbors = max(int(mod_spatial_scales[model_label] / obs_spatial_scale), 1)

        for vname_daymet in varnames_list:

            obs_manager = DataManager(
                store_config={
                    DataManager.SP_BASE_FOLDER: daymet_vname_to_path[vname_daymet],
                    DataManager.SP_DATASOURCE_TYPE: data_source_types.ALL_VARS_IN_A_FOLDER_IN_NETCDF_FILES,
                    DataManager.SP_INTERNAL_TO_INPUT_VNAME_MAPPING: default_varname_mappings.daymet_vname_mapping,
                    DataManager.SP_LEVEL_MAPPING: {}
                }
            )

            vname_model = daymet_vname_to_model_vname_internal[vname_daymet]

            nd_rw = var_name_to_rolling_window_days[vname_daymet]
            q = var_name_to_percentile[vname_daymet]
            daily_agg_func = var_name_to_daily_agg_func[vname_daymet]

            # model data
            # TODO: change for the number of summer days
            mod = dm.compute_climatological_quantiles(start_year=start_year, end_year=end_year,
                                                      daily_agg_func=daily_agg_func,
                                                      rolling_mean_window_days=nd_rw,
                                                      q=q,
                                                      varname_internal=vname_model)

            mod = mod * model_vname_to_multiplier.get(vname_model, 1)

            # obs data, requested at the coordinates of the model field
            obs = obs_manager.compute_climatological_quantiles(start_year=start_year,
                                                               end_year=end_year,
                                                               daily_agg_func=daily_agg_func,  # does not have effect for daymet data because it is daily
                                                               rolling_mean_window_days=nd_rw,
                                                               q=q,
                                                               varname_internal=vname_daymet,
                                                               lons_target=mod.coords["lon"].values,
                                                               lats_target=mod.coords["lat"].values,
                                                               nneighbors=nneighbors)

            # only use model data wherever the obs is not null
            mod = mod.where(obs.notnull())

            # computed once and shared by the area-average bookkeeping and the difference panels
            bias = mod - obs

            # labels identify the data in the plots and in the dictionaries below
            data_source_mod = f"{model_label}_ndrw{nd_rw}_q{q}_vn{vname_daymet}_{start_year}-{end_year}"
            data_source_obs = f"DAYMETaggfor_{model_label}_ndrw{nd_rw}_q{q}_vn{vname_daymet}_{start_year}-{end_year}"
            data_source_diff = f"{model_label}vsDAYMET_ndrw{nd_rw}_q{q}_vn{vname_daymet}_{start_year}-{end_year}"

            # save data for line plots
            data_dict[vname_daymet][data_source_mod] = mod
            data_dict[vname_daymet][data_source_obs] = obs
            bias_dict[vname_daymet][data_source_mod] = bias

            bmap = dm.get_basemap(varname_internal=vname_model, resolution="i", area_thresh=area_thresh_km2)

            # model and obs panels share the colour levels and the colormap
            mean_levels = clevs["mean"][vname_model]
            mean_cmap = cmaps["mean"][vname_model]

            # plot model data
            plot_monthly_panels(mod, bmap, img_dir=str(img_folder), data_label=data_source_mod,
                                color_levels=mean_levels, cmap=mean_cmap)

            # plot obs data
            plot_monthly_panels(obs, bmap, img_dir=str(img_folder), data_label=data_source_obs,
                                color_levels=mean_levels, cmap=mean_cmap)

            # plot model - obs
            plot_monthly_panels(bias, bmap, img_dir=str(img_folder), data_label=data_source_diff,
                                color_levels=clevs["mean"][vname_model + "diff"],
                                cmap=cmaps["mean"][vname_model + "diff"])

    for vn, fields in data_dict.items():

        # nothing was collected for this variable
        if not fields:
            continue

        plot_area_avg(fields, bias_dict[vn], panel_titles=(vn, ""), img_dir=img_folder / "extremes_1d")
def main(field_list=None, start_year=1980, end_year=2010, label_to_simpath=None,
         merge_chunks=False):
    """
    Export model output fields of one or several simulations through DataManager.export_to_netcdf.

    :param field_list: names of the fields to export; a built-in list is used when None
    :param start_year: forwarded to DataManager.export_to_netcdf
    :param end_year: forwarded to DataManager.export_to_netcdf
    :param label_to_simpath: mapping {simulation label: path to the folder with the model samples};
        a single hard-coded simulation is used when None
    :param merge_chunks: forwarded to DataManager.export_to_netcdf
    :return:
    """
    global_metadata = OrderedDict([
        ("source_dir", ""),
        ("project", "CNRCWP, NEI"),
        ("website", "http://cnrcwp.ca"),
        ("converted_on", pendulum.now().to_day_datetime_string()),
    ])

    if field_list is None:
        field_list = ["PR", "AD", "AV", "GIAC",
                      "GIML", "GLD", "GLF", "GSAB",
                      "GSAC", "GSML", "GVOL", "GWDI",
                      "GWST", "GZ", "HR", "HU", "I1", "I2", "I4",
                      "I5", "MS", "N3", "N4", "P0", "PN", "S6", "SD",
                      "STFL", "SWSL", "SWSR", "T5", "T9", "TDRA", "TJ", "TRAF", "UD", "VD"]

    # widths of the soil layers, converted below by get_tops_and_bots_of_soil_layers
    soil_level_widths = [0.1, 0.2, 0.3, 0.4, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,
                         1.0, 3.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0]

    subgrid_regions_levels = "lev=1: soil; lev=2: glacier; lev=3: water; lev=4:sea ice; lev=5: aggregated; lev=6: urban; lev=7: lakes."

    # per-field attributes handed to export_to_netcdf as field_metadata
    metadata = {
        "PR": {"long_name": "total precipitation", "units": "mm/day", "description": "total precipitation"},
        "AD": {"units": "W/m**2", "description": "ACCUMULATION OF FDSI(IR ENERGY FLUX TOWARDS GROUND)"},
        "AV": {"units": "W/m**2", "description": "ACCUMULATION OF FV(SURFACE LATENT FLUX)"},
        "DN": {"units": "kg/m**3", "description": "SNOW DENSITY"},
        "FN": {"description": "TOTAL CLOUDS"},
        "GIAC": {"units": "mm weq/s", "description": "ACCUMUL. OF GLACIER ICE ACCUMULATION [MM WEQ/S]"},
        "GIML": {"units": "mm weq/s", "description": "ACCUMUL. OF GLACIER ICE MELT [MM WEQ/S]"},
        "GLD": {"units": "m", "description": "MEAN GLACIER DEPTH FOR WHOLE GRID BOX [M ICE]"},
        "GLF": {"units": "", "description": "GLACIER FRACTION WRT WHOLE GRID"},
        "GSAB": {"units": "mm weq/s", "description": "ACCUMUL. OF SNOW ABLATION ON GLACIER [MM WEQ/S]"},
        "GSAC": {"units": "mm weq/s", "description": "ACCUMUL. OF SNOW ACCUMUL. ON GLACIER [MM WEQ/S]"},
        "GSML": {"units": "mm weq/s", "description": "ACCUMUL. OF SNOW MELT ON GLACIER [MM WEQ/S]"},
        "GVOL": {"units": "m**3 ice", "description": "GLACIER VOLUME FOR WHOLE GRID BOX [M3 ICE]"},
        "GWDI": {"units": "m**3/s", "description": "GROUND WATER DISCHARGE , M**3/S"},
        "GWST": {"units": "m**3", "description": "GROUND WATER STORE , M**3"},
        "GZ": {"units": "dam", "description": "GEOPOTENTIAL HEIGHT"},
        "HR": {"units": "", "description": "RELATIVE HUMIDITY"},
        "HU": {"units": "kg/kg", "description": "SPECIFIC HUMIDITY"},
        "I1": {"units": "m**3/m**3", "description": "SOIL VOLUMETRIC WATER CONTENTS"},
        "I2": {"units": "m**3/m**3", "description": "SOIL VOLUMETRIC ICE CONTENTS"},
        "I4": {"units": "kg/m**2", "description": "WATER IN THE SNOW PACK"},
        "I5": {"units": "kg/m**2", "description": "SNOW MASS"},
        "MS": {"units": "kg/(m**2 * s)", "description": "MELTING SNOW FROM SNOWPACK"},
        "N3": {"units": "mm/day",
               "description": "ACCUM. OF SOLID PRECIP. USED BY LAND SURFACE SCHEMES (LAGGS 1 TIME STEP FROM PR)"},
        "N4": {"units": "W/m**2", "description": "ACCUM. OF SOLAR RADATION"},
        "P0": {"units": "hPa", "description": "SURFACE PRESSURE"},
        "PN": {"units": "hPa", "description": "SEA LEVEL PRESSURE"},
        "S6": {"units": "", "description": "FRACTIONAL COVERAGE FOR SNOW"},
        "SD": {"units": "cm", "description": "SNOW DEPTH"},
        "STFL": {"units": "m**3/s", "description": "SURF. WATER STREAMFLOW IN M**3/S"},
        "SWSL": {"units": "m**3", "description": "SURF. WATER STORE (LAKE), M**3"},
        "SWSR": {"units": "m**3", "description": "SURF. WATER STORE (RIVER), M**3"},
        "T5": {"units": "K", "description": "MIN TEMPERATURE OVER LAST 24.0 HRS"},
        "T9": {"units": "K", "description": "MAX TEMPERATURE OVER LAST 24.0 HRS"},
        "TDRA": {"units": "kg/(m**2 * s)", "description": "ACCUM. OF BASE DRAINAGE"},
        "TJ": {"units": "K", "description": "SCREEN LEVEL TEMPERATURE"},
        "TRAF": {"units": "kg/(m**2 * s)", "description": "ACCUM. OF TOTAL SURFACE RUNOFF"},
        "UD": {"units": "knots", "description": "SCREEN LEVEL X-COMPONENT OF WIND"},
        "VD": {"units": "knots", "description": "SCREEN LEVEL Y-COMPONENT OF WIND"},
        "TT": {"units": "degC", "description": "Air temperature"},
    }

    # add descriptions of subgrid fraction levels
    for vn in ("TRAF", "TDRA", "SD"):
        metadata[vn]["description"] += ", " + subgrid_regions_levels

    # both soil moisture fields share the same layer boundaries
    soil_levels_map = get_tops_and_bots_of_soil_layers(soil_level_widths)
    vname_to_soil_layers = {"I1": soil_levels_map, "I2": soil_levels_map}

    # work on copies, so the shared CRCM5 default mappings are not modified
    offsets = copy(vname_to_offset_CRCM5)
    multipliers = copy(vname_to_multiplier_CRCM5)
    multipliers["PR"] = 1000 * 24 * 3600  # convert M/s to mm/day ()
    multipliers["N3"] = multipliers["PR"]  # M/s to mm/day

    vname_to_fname_prefix = dict(vname_to_fname_prefix_CRCM5)
    vname_to_fname_prefix.update({
        "PR": "pm",
        "HU": "dp",
        "HR": "dp",
        "GZ": "dp",
        "P0": "dm",
        "PN": "dm",
        "TT": "dm",
        "SN": "pm"
    })

    # fields without a known file name prefix fall back to "pm"
    for vn in field_list:
        vname_to_fname_prefix.setdefault(vn, "pm")

    vname_to_level = {
        T_AIR_2M: VerticalLevel(1, level_kinds.HYBRID),
        U_WE: VerticalLevel(1, level_kinds.HYBRID),
        V_SN: VerticalLevel(1, level_kinds.HYBRID),
    }

    # the requested fields are addressed by their own names (identity mapping)
    vname_map = dict(vname_map_CRCM5)
    vname_map.update({vn: vn for vn in field_list})

    if label_to_simpath is None:
        label_to_simpath = OrderedDict()
        label_to_simpath["WC044_modified"] = "/snow3/huziy/NEI/WC/debug_NEI_WC0.44deg_Crr1/Samples"
        # label_to_simpath["WC011_modified"] = "/snow3/huziy/NEI/WC/NEI_WC0.11deg_Crr1/Samples"

    for label, simpath in label_to_simpath.items():
        # str() leaves string paths unchanged and lets callers pass pathlib.Path values
        # NOTE(review): assumes export_to_netcdf writes global_metadata values as file attributes -- confirm
        global_metadata["source_dir"] = str(simpath)

        store_config = {
            DataManager.SP_BASE_FOLDER: simpath,
            DataManager.SP_DATASOURCE_TYPE: data_source_types.SAMPLES_FOLDER_FROM_CRCM_OUTPUT,
            DataManager.SP_INTERNAL_TO_INPUT_VNAME_MAPPING: vname_map,
            DataManager.SP_LEVEL_MAPPING: vname_to_level,
            DataManager.SP_OFFSET_MAPPING: offsets,
            DataManager.SP_MULTIPLIER_MAPPING: multipliers,
            DataManager.SP_VARNAME_TO_FILENAME_PREFIX_MAPPING: vname_to_fname_prefix,
        }

        dm = DataManager(store_config=store_config)

        dm.export_to_netcdf(start_year=start_year, end_year=end_year,
                            field_names=field_list, label=label,
                            field_metadata=metadata, global_metadata=global_metadata,
                            field_to_soil_layers=vname_to_soil_layers,
                            merge_chunks=merge_chunks)
Exemplo n.º 24
0
def main():
    """
    Plot seasonal climatologies of observed fields together with the biases of two simulations.

    For each variable in ``vars_of_interest`` one figure is saved with:
      * row 0: the observed climatology of the seasonal maxima;
      * one row per simulation in ``sim_labels``: simulated minus observed climatology;
      * last row: difference between the biases of the two simulations (NEMO - HL).
    All fields are brought to the grid returned by get_target_lons_lats_basemap for the
    HL simulation before they are compared.

    The following names are used but not defined here and must be in scope:
    season_to_months, img_folder, internal_name_to_title, internal_name_to_multiplier, get_clevs.
    :return:
    """

    obs_data_path = Path("/RESCUE/skynet3_rech1/huziy/obs_data_for_HLES/interploated_to_the_same_grid/GL_0.1_452x260/anusplin+_interpolated_tt_pr.nc")

    start_year = 1980
    end_year = 2010

    HL_LABEL = "CRCM5_HL"
    NEMO_LABEL = "CRCM5_NEMO"


    vars_of_interest = [
        LAKE_ICE_FRACTION,
    ]

    # simulation label -> run configuration (data folder and period)
    sim_configs = {

        HL_LABEL: RunConfig(data_path="/RECH2/huziy/coupling/GL_440x260_0.1deg_GL_with_Hostetler/Samples_selected",
                  start_year=start_year, end_year=end_year, label=HL_LABEL),

        NEMO_LABEL: RunConfig(data_path="/RECH2/huziy/coupling/coupled-GL-NEMO1h_30min/selected_fields",
                  start_year=start_year, end_year=end_year, label=NEMO_LABEL),
    }

    # fixes the order of the bias rows in the figure
    sim_labels = [HL_LABEL, NEMO_LABEL]

    vname_to_level = {
        T_AIR_2M: VerticalLevel(1, level_kinds.HYBRID),
        U_WE: VerticalLevel(1, level_kinds.HYBRID),
        V_SN: VerticalLevel(1, level_kinds.HYBRID),
    }


    # Calculations

    # prepare params for interpolation: the target grid is taken from the HL simulation
    # and its points are converted to Cartesian coordinates for the KD-tree queries below
    lons_t, lats_t, bsmap = get_target_lons_lats_basemap(sim_configs[HL_LABEL])
    xt, yt, zt = lat_lon.lon_lat_to_cartesian(lons_t.flatten(), lats_t.flatten())


    vname_map = {}
    vname_map.update(default_varname_mappings.vname_map_CRCM5)



    # Read and calculate observed seasonal means
    # (the data manager is pointed at the folder that contains obs_data_path)
    store_config = {
            "base_folder": obs_data_path.parent,
            "data_source_type": data_source_types.ALL_VARS_IN_A_FOLDER_IN_NETCDF_FILES_OPEN_EACH_FILE_SEPARATELY,
            "varname_mapping": vname_map,
            "level_mapping": vname_to_level,
            "offset_mapping": default_varname_mappings.vname_to_offset_CRCM5,
            "multiplier_mapping": default_varname_mappings.vname_to_multiplier_CRCM5,
    }

    obs_dm = DataManager(store_config=store_config)
    # variable -> {season -> observed climatology on the target grid}
    obs_data = {}

    # computed on the first variable and then reused for all variables of this data manager
    interp_indices = None
    for vname in vars_of_interest:
        # -- the period for SWE is capped at 1996
        end_year_for_current_var = end_year
        if vname == SWE:
            end_year_for_current_var = min(1996, end_year)

        # --
        seas_to_year_to_max = obs_dm.get_seasonal_maxima(varname_internal=vname,
                                                     start_year=start_year,
                                                     end_year=end_year_for_current_var,
                                                     season_to_months=season_to_months)

        # climatology: average of the yearly seasonal maxima over the years
        seas_to_clim = {seas: np.array(list(y_to_means.values())).mean(axis=0) for seas, y_to_means in seas_to_year_to_max.items()}
        obs_data[vname] = seas_to_clim

        # indices of the source grid points selected for each target point
        # NOTE(review): presumably the nearest neighbours -- confirm against get_kdtree()
        if interp_indices is None:
            _, interp_indices = obs_dm.get_kdtree().query(list(zip(xt, yt, zt)))

        # seas_to_clim is the dict already stored in obs_data, so this updates obs_data in place
        for season in seas_to_clim:
            seas_to_clim[season] = seas_to_clim[season].flatten()[interp_indices].reshape(lons_t.shape)

    # Read and calculate simulated seasonal mean biases
    # label -> {variable -> {season -> simulated minus observed climatology on the target grid}}
    sim_data = defaultdict(dict)
    for label, r_config in sim_configs.items():

        store_config = {
                "base_folder": r_config.data_path,
                "data_source_type": data_source_types.SAMPLES_FOLDER_FROM_CRCM_OUTPUT_VNAME_IN_FNAME,
                "varname_mapping": vname_map,
                "level_mapping": vname_to_level,
                "offset_mapping": default_varname_mappings.vname_to_offset_CRCM5,
                "multiplier_mapping": default_varname_mappings.vname_to_multiplier_CRCM5,
        }


        dm = DataManager(store_config=store_config)


        # reset for every simulation: the indices are queried from this simulation's own KD-tree
        interp_indices = None
        for vname in vars_of_interest:

            # -- same period as used for the observations (SWE capped at 1996)
            end_year_for_current_var = end_year
            if vname == SWE:
                end_year_for_current_var = min(1996, end_year)

            # --
            seas_to_year_to_max = dm.get_seasonal_maxima(varname_internal=vname,
                                                           start_year=start_year,
                                                           end_year=end_year_for_current_var,
                                                           season_to_months=season_to_months)

            # get the climatology
            seas_to_clim = {seas: np.array(list(y_to_means.values())).mean(axis=0) for seas, y_to_means in seas_to_year_to_max.items()}

            sim_data[label][vname] = seas_to_clim

            if interp_indices is None:
                _, interp_indices = dm.get_kdtree().query(list(zip(xt, yt, zt)))

            # bring to the target grid and subtract the obs; updates the dict stored in sim_data in place
            for season in seas_to_clim:
                seas_to_clim[season] = seas_to_clim[season].flatten()[interp_indices].reshape(lons_t.shape) - obs_data[vname][season]







    # Plotting: interpolate to the same grid and plot obs and biases
    plot_utils.apply_plot_params(width_cm=32, height_cm=20, font_size=8)



    # project first: the longitudes are modified in place right after
    xx, yy = bsmap(lons_t, lats_t)
    lons_t[lons_t > 180] -= 360
    # True where maskoceans left the point unmasked; those points are hidden in all the panels
    # NOTE(review): presumably this hides land and keeps the water points -- confirm
    field_mask = ~maskoceans(lons_t, lats_t, np.zeros_like(lons_t)).mask

    for vname in vars_of_interest:

        fig = plt.figure()

        fig.suptitle(internal_name_to_title[vname] + "\n")

        # rows: obs + one per simulation + difference of the biases; columns: seasons
        nrows = len(sim_configs) + 2
        ncols = len(season_to_months)
        gs = GridSpec(nrows=nrows, ncols=ncols)



        # Plot the obs fields
        current_row = 0
        for col, season in enumerate(season_to_months):
            field = obs_data[vname][season]
            ax = fig.add_subplot(gs[current_row, col])
            ax.set_title(season)

            to_plot = np.ma.masked_where(field_mask, field) * internal_name_to_multiplier[vname]
            clevs = get_clevs(vname)

            # discrete colormap when explicit levels are available, default "jet" otherwise
            if clevs is not None:
                bnorm = BoundaryNorm(clevs, len(clevs) - 1)
                cmap = cm.get_cmap("jet", len(clevs) - 1)
            else:
                cmap = "jet"
                bnorm = None

            cs = bsmap.contourf(xx, yy, to_plot, ax=ax, levels=get_clevs(vname), norm=bnorm, cmap=cmap)
            # no ax is passed here
            # NOTE(review): presumably the coastlines go to the current axes, i.e. the subplot just added -- confirm
            bsmap.drawcoastlines()
            bsmap.colorbar(cs, ax=ax)

            if col == 0:
                ax.set_ylabel("Obs")



        # plot the biases
        for sim_label in sim_labels:
            current_row += 1
            for col, season in enumerate(season_to_months):

                field = sim_data[sim_label][vname][season]

                ax = fig.add_subplot(gs[current_row, col])

                # levels for biases are looked up under the key vname + "bias"
                clevs = get_clevs(vname + "bias")
                if clevs is not None:
                    bnorm = BoundaryNorm(clevs, len(clevs) - 1)
                    cmap = cm.get_cmap("bwr", len(clevs) - 1)
                else:
                    cmap = "bwr"
                    bnorm = None

                to_plot = np.ma.masked_where(field_mask, field) * internal_name_to_multiplier[vname]
                cs = bsmap.contourf(xx, yy, to_plot, ax=ax, extend="both", levels=get_clevs(vname + "bias"), cmap=cmap, norm=bnorm)
                bsmap.drawcoastlines()
                bsmap.colorbar(cs, ax=ax)

                if col == 0:
                    ax.set_ylabel("{}\n-\nObs.".format(sim_label))


        # plot differences between the biases
        current_row += 1
        for col, season in enumerate(season_to_months):

            # the obs term is subtracted from both biases, so only NEMO - HL remains
            field = sim_data[NEMO_LABEL][vname][season] - sim_data[HL_LABEL][vname][season]

            ax = fig.add_subplot(gs[current_row, col])

            # levels for the bias differences are looked up under the key vname + "biasdiff"
            clevs = get_clevs(vname + "biasdiff")
            if clevs is not None:
                bnorm = BoundaryNorm(clevs, len(clevs) - 1)
                cmap = cm.get_cmap("bwr", len(clevs) - 1)
            else:
                cmap = "bwr"
                bnorm = None

            to_plot = np.ma.masked_where(field_mask, field) * internal_name_to_multiplier[vname]
            cs = bsmap.contourf(xx, yy, to_plot, ax=ax, extend="both", levels=get_clevs(vname + "biasdiff"), cmap=cmap, norm=bnorm)
            bsmap.drawcoastlines()
            bsmap.colorbar(cs, ax=ax)

            if col == 0:
                ax.set_ylabel("{}\n-\n{}".format(NEMO_LABEL, HL_LABEL))


        fig.tight_layout()

        # save a figure per variable; the file name contains the variable, the seasons and the period
        img_file = "seasonal_biases_{}_{}_{}-{}.png".format(vname,
                                                            "-".join([s for s in season_to_months]),
                                                            start_year, end_year)
        img_file = img_folder.joinpath(img_file)

        fig.savefig(str(img_file))

        # release the figure before the next variable is processed
        plt.close(fig)