Ejemplo n.º 1
0
def compare_swe_diff_for_era40_driven():
    """Plot the DJF SWE difference (model - obs) for two ERA40-driven runs.

    One figure is produced per simulation folder in ``paths``:
    the mean of model variable ``I5`` (read from the DJF climatology files,
    ocean points masked) minus the observed SWE mean for 1981-1997
    (months 12, 1, 2) interpolated onto the model grid. Coastlines and the
    permafrost-kind contours are drawn on top, and each figure is saved as
    ``swe_<folder basename>.png`` in the current working directory.
    """
    b, lons2d, lats2d = draw_regions.get_basemap_and_coords(llcrnrlat=40.0, llcrnrlon=-145, urcrnrlon=-10)

    # shift longitudes given in (180, 360] to the equivalent negative values
    lons2d[lons2d > 180] -= 360

    x, y = b(lons2d, lats2d)

    # observation period and season (DJF)
    start_year = 1981
    end_year = 1997
    the_months = [12, 1, 2]

    swe_obs_manager = SweDataManager(var_name="SWE")
    swe_obs = swe_obs_manager.get_mean(start_year, end_year, months=the_months)
    print("Calculated obs swe")

    swe_obs_interp = swe_obs_manager.interpolate_data_to(swe_obs, lons2d, lats2d)

    # contour levels for the model - obs difference
    levels_diff = np.arange(-100, 110, 10)

    paths = ["data/CORDEX/na/era40_1", "data/CORDEX/na/era40_2"]
    # NOTE(review): both prefixes are passed for every folder; presumably
    # get_mean_2d_from_climatologies only picks up the files present in each
    # folder -- confirm.
    prefixes = [
        "pmNorthAmerica_0.44deg_ERA40-Int_{0}_1958-1961".format("DJF"),
        "pmNorthAmerica_0.44deg_ERA40-Int2_{0}_1958-1961".format("DJF"),
    ]
    pf_kinds = draw_regions.get_permafrost_mask(lons2d, lats2d)

    for the_path in paths:
        base = os.path.basename(the_path)
        fig = plt.figure()
        ax = plt.gca()

        swe_model_era = CRCMDataManager.get_mean_2d_from_climatologies(path=the_path,
                                                                       var_name="I5",
                                                                       file_prefixes=prefixes)
        swe_model_era = maskoceans(lons2d, lats2d, swe_model_era)

        # model (ERA40 driven) - obs
        img = b.contourf(x, y, swe_model_era - swe_obs_interp, levels=levels_diff)
        draw_colorbar(fig, img, ax=ax, boundaries=levels_diff)
        ax.set_title("Model ({0} 1958-1961) - Obs.".format(base))

        b.drawcoastlines(ax=ax, linewidth=0.2)
        b.contour(x, y, pf_kinds, ax=ax, colors="k")
        fig.savefig("swe_{0}.png".format(base))
Ejemplo n.º 2
0
def compare_vars(vname_model="TT", vname_obs="tmp", r_config=None,
                 season_to_months=None,
                 obs_path=None, nx_agg=5, ny_agg=5, bmp_info_agg=None,
                 diff_axes_list=None, obs_axes_list=None,
                 model_axes_list=None, bmp_info_model=None,
                 mask_shape_file=None):
    """
    Compare seasonal model climatologies with observations and plot the biases.

    The model climatology of ``vname_model`` is aggregated with
    ``aggregate_array`` (``nx_agg`` x ``ny_agg``). For ``PR`` the aggregated
    field is additionally multiplied by 1.0e3 * 24 * 3600. The observed seasonal
    means of ``vname_obs`` for ``r_config.start_year``..``r_config.end_year``
    are interpolated onto the aggregated grid (``nneighbours=1``). The bias
    (model - obs) per season is passed to ``plot_seasonal_mean_biases``.
    Per season, min/max/mean of the unmasked bias and the mean bias as a
    percentage of the mean observed value are printed.

    if obs_axes_list is not None (and vname_model is "I5"), plot observation data in those;
    the last element of obs_axes_list receives the colorbar.

    :param mask_shape_file: shapefile used to build a mask on the aggregated grid;
        observations are masked where that mask is > 0.5
    :param bmp_info_model: basemap info native to the model
    :param model_axes_list: Axes to plot model outputs
    :param vname_model: model variable name
    :param vname_obs: observation variable name; if None it is derived from
        vname_model ("TT" -> "tmp", "PR" -> "pre")
    :param r_config: run configuration; its start_year/end_year set the obs period
    :param season_to_months: mapping season name -> months
    :param obs_path: path to the observation data (optional for CRU)
    :param nx_agg: aggregation factor along x for the model fields
    :param ny_agg: aggregation factor along y for the model fields
    :param bmp_info_agg: basemap info of the aggregated grid
    :param diff_axes_list: if it is None the plots for each variable is done in separate figures
    :return: whatever plot_seasonal_mean_biases returns
    """

    if vname_obs is None:
        vname_model_to_vname_obs = {"TT": "tmp", "PR": "pre"}
        vname_obs = vname_model_to_vname_obs[vname_model]

    seasonal_clim_fields_model = analysis.get_seasonal_climatology_for_runconfig(run_config=r_config,
                                                                                 varname=vname_model, level=0,
                                                                                 season_to_months=season_to_months)

    season_to_clim_fields_model_agg = OrderedDict()
    for season, field in seasonal_clim_fields_model.items():
        print(field.shape)
        season_to_clim_fields_model_agg[season] = aggregate_array(field, nagg_x=nx_agg, nagg_y=ny_agg)
        if vname_model == "PR":
            # presumably converts m/s to mm/day -- TODO confirm the model units
            season_to_clim_fields_model_agg[season] *= 1.0e3 * 24 * 3600

    if vname_obs in ["SWE", ]:
        obs_manager = SweDataManager(path=obs_path, var_name=vname_obs)
    elif obs_path is None:
        obs_manager = CRUDataManager(var_name=vname_obs)
    else:
        obs_manager = CRUDataManager(var_name=vname_obs, path=obs_path)

    seasonal_clim_fields_obs = obs_manager.get_seasonal_means(season_name_to_months=season_to_months,
                                                              start_year=r_config.start_year,
                                                              end_year=r_config.end_year)

    seasonal_clim_fields_obs_interp = OrderedDict()
    # Derive the mask from a shapefile if provided
    if mask_shape_file is not None:
        the_mask = get_mask(bmp_info_agg.lons, bmp_info_agg.lats, shp_path=mask_shape_file)
    else:
        the_mask = np.zeros_like(bmp_info_agg.lons)

    for season, obs_field in seasonal_clim_fields_obs.items():
        obs_field = obs_manager.interpolate_data_to(obs_field,
                                                    lons2d=bmp_info_agg.lons,
                                                    lats2d=bmp_info_agg.lats,
                                                    nneighbours=1)

        obs_field = np.ma.masked_where(the_mask > 0.5, obs_field)

        seasonal_clim_fields_obs_interp[season] = obs_field

    season_to_err = OrderedDict()
    print("-------------var: {} (PE with CRU)---------------------".format(vname_model))
    for season in seasonal_clim_fields_obs_interp:
        obs_interp = np.ma.masked_where(np.isnan(seasonal_clim_fields_obs_interp[season]),
                                        seasonal_clim_fields_obs_interp[season])
        seasonal_clim_fields_obs_interp[season] = obs_interp
        season_to_err[season] = season_to_clim_fields_model_agg[season] - obs_interp

        if vname_model in ["I5"]:
            lons = bmp_info_agg.lons.copy()
            lons[lons > 180] -= 360
            season_to_err[season] = maskoceans(lons, bmp_info_agg.lats, season_to_err[season])

        err_field = season_to_err[season]
        # getmaskarray always returns a full boolean array (even for np.ma.nomask),
        # so the selections below are element-wise; ``~field.mask`` is a scalar
        # when nothing is masked and would not select element-wise.
        valid = ~np.ma.getmaskarray(err_field)
        good_vals = np.ma.getdata(err_field)[valid]

        if good_vals.size == 0:
            # min()/max() of an empty selection would raise ValueError
            print("{}: no unmasked points, statistics skipped".format(season))
            continue

        print("{}: min={}; max={}; avg={}".format(season,
                                                  good_vals.min(),
                                                  good_vals.max(),
                                                  good_vals.mean()))

        print("---------percetages --- CRU ---")
        obs_mean = obs_interp[valid].mean()
        print("{}: {} %".format(season, good_vals.mean() / obs_mean * 100))

    cs = plot_seasonal_mean_biases(season_to_error_field=season_to_err,
                                   varname=vname_model,
                                   basemap_info=bmp_info_agg,
                                   axes_list=diff_axes_list)

    if obs_axes_list is not None and vname_model in ["I5"]:

        clevs = [0, 50, 60, 70, 80, 90, 100, 150, 200, 250, 300, 350, 400, 500]
        cs_obs = None
        xx, yy = bmp_info_agg.get_proj_xy()
        lons = bmp_info_agg.lons.copy()
        lons[lons > 180] -= 360

        # model-grid coordinates, computed lazily on the first model panel
        lons_model = None
        xx_model, yy_model = None, None

        norm = BoundaryNorm(clevs, 256)
        for col, (season, obs_field) in enumerate(seasonal_clim_fields_obs_interp.items()):

            # Observed fields
            ax = obs_axes_list[col]

            if bmp_info_agg.should_draw_basin_boundaries:
                bmp_info_agg.basemap.readshapefile(BASIN_BOUNDARIES_SHP[:-4], "basin", ax=ax)

            to_plot = maskoceans(lons, bmp_info_agg.lats, obs_field)
            cs_obs = bmp_info_agg.basemap.contourf(xx, yy, to_plot, levels=clevs, ax=ax, norm=norm, extend="max")

            bmp_info_agg.basemap.drawcoastlines(ax=ax, linewidth=0.3)

            ax.set_title(season)

            # Model outputs (native model grid), drawn with the obs levels/norm/cmap
            if model_axes_list is not None:
                ax = model_axes_list[col]

                if bmp_info_agg.should_draw_basin_boundaries:
                    bmp_info_agg.basemap.readshapefile(BASIN_BOUNDARIES_SHP[:-4], "basin", ax=ax)

                if lons_model is None:
                    lons_model = bmp_info_model.lons.copy()
                    lons_model[lons_model > 180] -= 360
                    xx_model, yy_model = bmp_info_model.basemap(lons_model, bmp_info_model.lats)

                model_field = seasonal_clim_fields_model[season]

                to_plot = maskoceans(lons_model, bmp_info_model.lats, model_field)
                bmp_info_agg.basemap.contourf(xx_model, yy_model, to_plot, levels=cs_obs.levels, ax=ax,
                                              norm=cs_obs.norm, cmap=cs_obs.cmap, extend="max")

                bmp_info_agg.basemap.drawcoastlines(ax=ax, linewidth=0.3)

        # the last obs axes is used as the colorbar axes
        plt.colorbar(cs_obs, cax=obs_axes_list[-1])

    return cs
Ejemplo n.º 3
0
def main():
    """Compare DJF 1987 SWE of a low- and a high-resolution CRCM5 run with observations.

    Observed SWE is interpolated onto both model grids (``nneighbours=1``).
    Relative errors (model - obs) / obs are computed where obs > 0 and are
    masked everywhere else. The high-resolution field is also upscaled to the
    low-resolution grid. Four comparison figures are made with
    ``plot_and_compare_2fields`` and then shown with ``plt.show()``.
    """

    def relative_error(model_field, obs_field):
        # (model - obs) / obs where obs > 0; every other point (obs <= 0 or NaN) is masked
        err = model_field - obs_field
        positive = obs_field > 0
        err[positive] /= obs_field[positive]
        return np.ma.masked_where(~positive, err)

    swe_obs_manager = SweDataManager(var_name="SWE")

    low_res_path = "/home/huziy/skynet3_exec1/from_guillimin/quebec_86x86_0.5deg_without_lakes_v3"
    managerLowRes = Crcm5ModelDataManager(samples_folder_path=low_res_path,
                                          file_name_prefix="pm",
                                          all_files_in_samples_folder=True,
                                          need_cell_manager=True)

    high_res_path = "/home/huziy/skynet3_exec1/from_guillimin/quebec_highres_spinup_12_month_without_lakes_v3"
    managerHighRes = Crcm5ModelDataManager(samples_folder_path=high_res_path,
                                           file_name_prefix="pm",
                                           all_files_in_samples_folder=True,
                                           need_cell_manager=True)

    start_year = 1987
    end_year = 1987
    months = [1, 2, 12]
    rot_lat_lon = RotatedLatLon(lon1=-68, lat1=52, lon2=16.65, lat2=0.0)

    # "omerc" projection whose corners are the corners of the high-resolution grid
    basemap = Basemap(projection="omerc",
                      llcrnrlon=managerHighRes.lons2D[0, 0],
                      llcrnrlat=managerHighRes.lats2D[0, 0],
                      urcrnrlon=managerHighRes.lons2D[-1, -1],
                      urcrnrlat=managerHighRes.lats2D[-1, -1],
                      lat_1=rot_lat_lon.lat1,
                      lat_2=rot_lat_lon.lat2,
                      lon_1=rot_lat_lon.lon1,
                      lon_2=rot_lat_lon.lon2,
                      no_rot=True)

    swe_obs = swe_obs_manager.get_mean(start_year, end_year, months=months)

    obs_ihr = swe_obs_manager.interpolate_data_to(swe_obs,
                                                  managerHighRes.lons2D,
                                                  managerHighRes.lats2D,
                                                  nneighbours=1)

    obs_ilr = swe_obs_manager.interpolate_data_to(swe_obs,
                                                  managerLowRes.lons2D,
                                                  managerLowRes.lats2D,
                                                  nneighbours=1)

    lowResSwe = managerLowRes.get_mean_field(start_year,
                                             end_year,
                                             months=months,
                                             var_name="I5")
    lowResErr = relative_error(lowResSwe, obs_ilr)

    highResSwe = managerHighRes.get_mean_field(start_year,
                                               end_year,
                                               months=months,
                                               var_name="I5")
    highResErr = relative_error(highResSwe, obs_ihr)

    # the upscaled field is compared with obs_ilr, i.e. it is on the low-res grid,
    # which is why manager2=managerLowRes is used for it below
    upscaledHighResSwe = upscale(managerHighRes, managerLowRes, highResSwe)
    upscaledHighResErr = relative_error(upscaledHighResSwe, obs_ilr)

    plot_and_compare_2fields(lowResSwe,
                             "low res",
                             upscaledHighResSwe,
                             "high res (upscaled)",
                             basemap=basemap,
                             manager1=managerLowRes,
                             manager2=managerLowRes)

    plot_and_compare_2fields(lowResErr,
                             "low res err",
                             upscaledHighResErr,
                             "high res (upscaled) err",
                             basemap=basemap,
                             manager1=managerLowRes,
                             manager2=managerLowRes,
                             clevs=np.arange(-1, 1.2, 0.2))

    plot_and_compare_2fields(lowResSwe,
                             "low res",
                             highResSwe,
                             "high res",
                             basemap=basemap,
                             manager1=managerLowRes,
                             manager2=managerHighRes)

    plot_and_compare_2fields(lowResErr,
                             "low res err",
                             highResErr,
                             "high res err",
                             basemap=basemap,
                             manager1=managerLowRes,
                             manager2=managerHighRes,
                             clevs=np.arange(-1, 1.2, 0.2))

    plt.show()
Ejemplo n.º 4
0
def main():
    """Compare DJF 1987 SWE of a low- and a high-resolution CRCM5 run with observations.

    Observed SWE is interpolated onto both model grids (``nneighbours=1``).
    Relative errors (model - obs) / obs are computed where obs > 0 and are
    masked everywhere else. The high-resolution field is also upscaled to the
    low-resolution grid. Four comparison figures are made with
    ``plot_and_compare_2fields`` and then shown with ``plt.show()``.
    """

    def relative_error(model_field, obs_field):
        # (model - obs) / obs where obs > 0; every other point (obs <= 0 or NaN) is masked
        err = model_field - obs_field
        positive = obs_field > 0
        err[positive] /= obs_field[positive]
        return np.ma.masked_where(~positive, err)

    swe_obs_manager = SweDataManager(var_name="SWE")

    low_res_path = "/home/huziy/skynet3_exec1/from_guillimin/quebec_86x86_0.5deg_without_lakes_v3"
    managerLowRes = Crcm5ModelDataManager(samples_folder_path=low_res_path,
                                          file_name_prefix="pm",
                                          all_files_in_samples_folder=True,
                                          need_cell_manager=True)

    high_res_path = "/home/huziy/skynet3_exec1/from_guillimin/quebec_highres_spinup_12_month_without_lakes_v3"
    managerHighRes = Crcm5ModelDataManager(samples_folder_path=high_res_path,
                                           file_name_prefix="pm",
                                           all_files_in_samples_folder=True,
                                           need_cell_manager=True)

    start_year = 1987
    end_year = 1987
    months = [1, 2, 12]
    rot_lat_lon = RotatedLatLon(lon1=-68, lat1=52, lon2=16.65, lat2=0.0)

    # "omerc" projection whose corners are the corners of the high-resolution grid
    basemap = Basemap(projection="omerc",
                      llcrnrlon=managerHighRes.lons2D[0, 0],
                      llcrnrlat=managerHighRes.lats2D[0, 0],
                      urcrnrlon=managerHighRes.lons2D[-1, -1],
                      urcrnrlat=managerHighRes.lats2D[-1, -1],
                      lat_1=rot_lat_lon.lat1,
                      lat_2=rot_lat_lon.lat2,
                      lon_1=rot_lat_lon.lon1,
                      lon_2=rot_lat_lon.lon2,
                      no_rot=True)

    swe_obs = swe_obs_manager.get_mean(start_year, end_year, months=months)

    obs_ihr = swe_obs_manager.interpolate_data_to(swe_obs, managerHighRes.lons2D,
                                                  managerHighRes.lats2D, nneighbours=1)

    obs_ilr = swe_obs_manager.interpolate_data_to(swe_obs, managerLowRes.lons2D,
                                                  managerLowRes.lats2D, nneighbours=1)

    lowResSwe = managerLowRes.get_mean_field(start_year, end_year, months=months, var_name="I5")
    lowResErr = relative_error(lowResSwe, obs_ilr)

    highResSwe = managerHighRes.get_mean_field(start_year, end_year, months=months, var_name="I5")
    highResErr = relative_error(highResSwe, obs_ihr)

    # the upscaled field is compared with obs_ilr, i.e. it is on the low-res grid,
    # which is why manager2=managerLowRes is used for it below
    upscaledHighResSwe = upscale(managerHighRes, managerLowRes, highResSwe)
    upscaledHighResErr = relative_error(upscaledHighResSwe, obs_ilr)

    plot_and_compare_2fields(lowResSwe, "low res", upscaledHighResSwe, "high res (upscaled)", basemap=basemap,
                             manager1=managerLowRes, manager2=managerLowRes)

    plot_and_compare_2fields(lowResErr, "low res err", upscaledHighResErr, "high res (upscaled) err", basemap=basemap,
                             manager1=managerLowRes, manager2=managerLowRes, clevs=np.arange(-1, 1.2, 0.2))

    plot_and_compare_2fields(lowResSwe, "low res", highResSwe, "high res", basemap=basemap,
                             manager1=managerLowRes, manager2=managerHighRes)

    plot_and_compare_2fields(lowResErr, "low res err", highResErr, "high res err", basemap=basemap,
                             manager1=managerLowRes, manager2=managerHighRes, clevs=np.arange(-1, 1.2, 0.2))

    plt.show()
Ejemplo n.º 5
0
def main():
    """Build a 2x3 panel figure validating DJF SWE against observations.

    Panels (row, column):
      (0, 0) observed SWE 1981-1997 on its native grid, in the model projection;
      (1, 0) observed SWE interpolated to the model grid;
      (0, 1) ERA40-driven model SWE (1958-1977 climatology);
      (0, 2) ERA40-driven model - interpolated obs;
      (1, 1) GCM-driven (CanHisto E2, 1979-1997) model SWE;
      (1, 2) GCM-driven model - interpolated obs.
    Coastlines and permafrost-kind contours are drawn on every panel and the
    figure is saved as ``swe_validation.png``.
    """
    b, lons2d, lats2d = draw_regions.get_basemap_and_coords(llcrnrlat=40.0, llcrnrlon=-145, urcrnrlon=-10)

    # shift longitudes given in (180, 360] to the equivalent negative values
    lons2d[lons2d > 180] -= 360

    x, y = b(lons2d, lats2d)

    # observation period and season (DJF)
    start_year = 1981
    end_year = 1997
    the_months = [12, 1, 2]

    # contour levels: absolute SWE and model - obs differences
    levels = [10] + list(range(20, 120, 20)) + [150, 200, 300, 500, 1000]
    levels_diff = np.arange(-100, 110, 10)
    # matplotlib.cm.get_cmap was removed in matplotlib 3.9; pyplot.get_cmap(name, lut)
    # returns the same colormap resampled to len(levels) colors.
    cmap = plt.get_cmap("jet_r", len(levels))
    norm = colors.BoundaryNorm(levels, cmap.N)

    swe_obs_manager = SweDataManager(var_name="SWE")
    swe_obs = swe_obs_manager.get_mean(start_year, end_year, months=the_months)
    print("Calculated obs swe")

    swe_obs_interp = swe_obs_manager.interpolate_data_to(swe_obs, lons2d, lats2d, nneighbours=1)

    gs = gridspec.GridSpec(2, 3)
    fig = plt.figure()
    coast_line_width = 0.25
    axes_list = []

    # obs on its own grid but in the model's projection
    ax = fig.add_subplot(gs[0, 0])
    axes_list.append(ax)
    x_obs, y_obs = b(swe_obs_manager.lons2d, swe_obs_manager.lats2d)
    swe_obs = maskoceans(swe_obs_manager.lons2d, swe_obs_manager.lats2d, swe_obs)
    img = b.contourf(x_obs, y_obs, swe_obs, levels=levels, norm=norm, cmap=cmap)
    draw_colorbar(fig, img, ax=ax)
    ax.set_title("Obs native grid")

    # obs interpolated to the model grid
    ax = fig.add_subplot(gs[1, 0])
    axes_list.append(ax)
    swe_obs_interp = maskoceans(lons2d, lats2d, swe_obs_interp)
    img = b.contourf(x, y, swe_obs_interp, levels=levels, norm=norm, cmap=cmap)
    draw_colorbar(fig, img, ax=ax)
    ax.set_title("Obs interpolated \n to model grid")

    # model (ERA40 driven)
    ax = fig.add_subplot(gs[0, 1])
    axes_list.append(ax)

    prefixes = ["pmNorthAmerica_0.44deg_ERA40-Int_{0}_1958-1977".format(m) for m in ["Dec", "Jan", "Feb"]]
    swe_model_era = CRCMDataManager.get_mean_2d_from_climatologies(path="data/CORDEX/na/means_month/era40_driven",
                                                                   var_name="I5", file_prefixes=prefixes)
    swe_model_era = maskoceans(lons2d, lats2d, swe_model_era)
    img = b.contourf(x, y, swe_model_era, levels=levels, norm=norm, cmap=cmap)
    draw_colorbar(fig, img, ax=ax)
    ax.set_title("Model (ERA40 driven 1958-1977)")

    # model (ERA40 driven) - obs
    ax = fig.add_subplot(gs[0, 2])
    axes_list.append(ax)
    img = b.contourf(x, y, swe_model_era - swe_obs_interp, levels=levels_diff)
    draw_colorbar(fig, img, ax=ax, boundaries=levels_diff)
    ax.set_title("Model (ERA40 driven 1958-1977) - Obs.")

    # model (GCM driven, E2)
    ax = fig.add_subplot(gs[1, 1])
    axes_list.append(ax)
    path = "/skynet1_exec2/separovi/results/North_America/tests_E/all/means_season"
    prefix = "pmNorthAmerica_0.44deg_CanHistoE2_A1979-1997"
    suffixes = ["djf"]
    swe_model_gcm = CRCMDataManager.get_mean_2d_from_climatologies(path=path, file_prefixes=[prefix],
                                                                   file_suffixes=suffixes, var_name="I5")
    swe_model_gcm = maskoceans(lons2d, lats2d, swe_model_gcm)
    print("model: min = {0}; max = {1}".format(np.ma.min(swe_model_gcm), np.ma.max(swe_model_gcm)))
    img = b.contourf(x, y, swe_model_gcm, levels=levels, norm=norm, cmap=cmap)
    # absolute SWE panel: same colorbar setup as the other absolute panels
    # (previously it was given the difference boundaries by mistake)
    draw_colorbar(fig, img, ax=ax)
    ax.set_title("Model (GCM driven, E2, 1979-1997)")

    # model (GCM driven) - obs
    ax = fig.add_subplot(gs[1, 2])
    axes_list.append(ax)
    img = b.contourf(x, y, np.ma.array(swe_model_gcm) - swe_obs_interp, levels=levels_diff)
    # difference panel: same colorbar boundaries as the ERA40 difference panel
    draw_colorbar(fig, img, ax=ax, boundaries=levels_diff)
    ax.set_title("Model (GCM driven) - Obs.")

    # common elements on every panel
    pf_kinds = draw_regions.get_permafrost_mask(lons2d, lats2d)
    for the_ax in axes_list:
        b.drawcoastlines(ax=the_ax, linewidth=coast_line_width)
        b.contour(x, y, pf_kinds, ax=the_ax, colors="k")

    gs.tight_layout(fig, h_pad=0.9, w_pad=18)
    fig.savefig("swe_validation.png")
Ejemplo n.º 6
0
def validate_using_monthly_diagnostics():
    """Plot DJF SWE differences (model - obs) of three CORDEX simulations.

    For each simulation in ``sim_names`` (ERA40, MPI and CanESM driven), the
    1980-1996 mean of ``I5`` over months 12, 1, 2 is compared with the observed
    SWE mean, interpolated to the model grid (``nneighbours=1``). The
    difference is masked where the permafrost mask is <= 0 or >= 2. Each
    simulation is drawn as one row of a 3-row figure with ``pcolormesh``
    (vmin=-100, vmax=100). Every panel gets coastlines and outlines of the
    shapefile polygons whose "EXTENT" attribute is "C". A shared colorbar
    (units "mm") sits in the right column. The figure is saved as
    ``swe_validation_era_mpi_canesm_djf.png``.
    """
    start_year = 1980
    end_year = 1996

    sim_names = ["ERA40", "MPI", "CanESM"]
    simname_to_path = {
        "ERA40": "/home/huziy/skynet1_rech3/cordex/CORDEX_DIAG/era40_driven_b1",
        "MPI": "/home/huziy/skynet1_rech3/cordex/CORDEX_DIAG/NorthAmerica_0.44deg_MPI_B1",
        "CanESM": "/home/huziy/skynet1_rech3/cordex/CORDEX_DIAG/NorthAmerica_0.44deg_CanESM_B1"
    }

    # grid coordinates are read from a file of the ERA40-driven simulation
    coord_file = os.path.join(simname_to_path["ERA40"], "pmNorthAmerica_0.44deg_ERA40-Int_B1_200812_moyenne")
    basemap, lons2d, lats2d = draw_regions.get_basemap_and_coords(resolution="c",
                                                                  file_path=coord_file,
                                                                  llcrnrlat=45.0, llcrnrlon=-145,
                                                                  urcrnrlon=-20, urcrnrlat=74,
                                                                  anchor="W")
    assert isinstance(basemap, Basemap)

    # shift longitudes given in (180, 360] to the equivalent negative values
    lons2d[lons2d > 180] -= 360

    swe_obs_manager = SweDataManager(var_name="SWE")
    swe_obs = swe_obs_manager.get_mean(start_year, end_year, months=[12, 1, 2])
    swe_obs = swe_obs_manager.interpolate_data_to(swe_obs, lons2d, lats2d, nneighbours=1)

    x, y = basemap(lons2d, lats2d)

    permafrost_mask = draw_regions.get_permafrost_mask(lons2d, lats2d)
    # only points with 0 < permafrost_mask < 2 are kept in the plots
    mask_cond = (permafrost_mask <= 0) | (permafrost_mask >= 2)

    fig = plt.figure()
    assert isinstance(fig, Figure)

    cmap = my_colormaps.get_red_blue_colormap(ncolors=10)
    # column 0: one map per simulation; column 1: shared colorbar
    gs = gridspec.GridSpec(3, 2, width_ratios=[1, 0.1], hspace=0, wspace=0,
                           left=0.05, bottom=0.01, top=0.95)

    all_axes = []
    all_img = []
    for row, name in enumerate(sim_names):
        dm = CRCMDataManager(data_folder=simname_to_path[name])
        swe_mod = dm.get_mean_over_months_of_2d_var(start_year, end_year, months=[12, 1, 2], var_name="I5")

        delta = swe_mod - swe_obs
        ax = fig.add_subplot(gs[row, 0])
        assert isinstance(ax, Axes)
        delta = np.ma.masked_where(mask_cond, delta)
        img = basemap.pcolormesh(x, y, delta, cmap=cmap, vmin=-100, vmax=100)
        if row == 0:
            ax.set_title("SWE, Mod - Obs, ({0} - {1}), DJF\n".format(start_year, end_year))
        all_axes.append(ax)
        all_img.append(img)

    # coastlines and permafrost zone outlines on every panel
    for the_ax in all_axes:
        assert isinstance(the_ax, Axes)
        basemap.drawcoastlines(ax=the_ax, linewidth=0.5)
        basemap.readshapefile("data/pf_4/permafrost8_wgs84/permaice", name="zone",
                              ax=the_ax, linewidth=1.5, drawbounds=False)

        for nshape, seg in enumerate(basemap.zone):
            if basemap.zone_info[nshape]["EXTENT"] != "C":
                continue
            # a new Polygon per axes: an artist can belong to only one axes
            poly = mpl.patches.Polygon(seg, edgecolor="k", facecolor="none", zorder=10, lw=1.5)
            the_ax.add_patch(poly)

    cax = fig.add_subplot(gs[:, 1])
    assert isinstance(cax, Axes)
    cax.set_anchor("W")
    cax.set_aspect(30)

    formatter = FuncFormatter(
        lambda val, pos: "{0: <6}".format(str(val))
    )
    fig.colorbar(all_img[0], ax=cax, cax=cax, extend="both", format=formatter)

    cax.set_title("mm")
    print("aspect = ", cax.get_aspect())

    fig.savefig("swe_validation_era_mpi_canesm_djf.png")
def compare_vars(vname_model="TT",
                 vname_obs="tmp",
                 r_config=None,
                 season_to_months=None,
                 obs_path=None,
                 nx_agg_model=5,
                 ny_agg_model=5,
                 bmp_info_agg=None,
                 diff_axes_list=None,
                 obs_axes_list=None,
                 model_axes_list=None,
                 bmp_info_model=None,
                 mask_shape_file=None,
                 nx_agg_obs=1,
                 ny_agg_obs=1):
    """
    Compare seasonal model climatologies with observations and plot the biases.

    The model climatology of ``vname_model`` is aggregated with
    ``aggregate_array`` (``nx_agg_model`` x ``ny_agg_model``). For ``PR`` the
    aggregated field is additionally multiplied by 1.0e3 * 24 * 3600. The
    observed seasonal means of ``vname_obs`` for
    ``r_config.start_year``..``r_config.end_year`` are interpolated onto the
    aggregated grid with ``nneighbours = nx_agg_obs * ny_agg_obs``. The bias
    (model - obs) per season is passed to ``plot_seasonal_mean_biases``.
    Per season, min/max/mean of the unmasked bias and the mean bias as a
    percentage of the mean observed value are printed.

    if obs_axes_list is not None (and vname_model is "I5"), plot observation data in those;
    the last element of obs_axes_list receives the colorbar.

    :param mask_shape_file: shapefile used to build a mask on the aggregated grid;
        observations are masked where that mask is > 0.5
    :param bmp_info_model: basemap info native to the model
    :param model_axes_list: Axes to plot model outputs
    :param vname_model: model variable name
    :param vname_obs: observation variable name; if None it is derived from
        vname_model ("TT" -> "tmp", "PR" -> "pre")
    :param r_config: run configuration; its start_year/end_year set the obs period
    :param season_to_months: mapping season name -> months
    :param obs_path: path to the observation data (optional for CRU)
    :param nx_agg_model: aggregation factor along x for the model fields
    :param ny_agg_model: aggregation factor along y for the model fields
    :param bmp_info_agg: basemap info of the aggregated grid
    :param diff_axes_list: if it is None the plots for each variable is done in separate figures
    :param nx_agg_obs: together with ny_agg_obs, sets nneighbours for the obs interpolation
    :param ny_agg_obs: see nx_agg_obs
    :return: whatever plot_seasonal_mean_biases returns
    """

    if vname_obs is None:
        vname_model_to_vname_obs = {"TT": "tmp", "PR": "pre"}
        vname_obs = vname_model_to_vname_obs[vname_model]

    seasonal_clim_fields_model = analysis.get_seasonal_climatology_for_runconfig(
        run_config=r_config,
        varname=vname_model,
        level=0,
        season_to_months=season_to_months)

    season_to_clim_fields_model_agg = OrderedDict()
    for season, field in seasonal_clim_fields_model.items():
        print(field.shape)
        season_to_clim_fields_model_agg[season] = aggregate_array(
            field, nagg_x=nx_agg_model, nagg_y=ny_agg_model)
        if vname_model == "PR":
            # presumably converts m/s to mm/day -- TODO confirm the model units
            season_to_clim_fields_model_agg[season] *= 1.0e3 * 24 * 3600

    if vname_obs in ["SWE", ]:
        obs_manager = SweDataManager(path=obs_path, var_name=vname_obs)
    elif obs_path is None:
        obs_manager = CRUDataManager(var_name=vname_obs)
    else:
        obs_manager = CRUDataManager(var_name=vname_obs, path=obs_path)

    seasonal_clim_fields_obs = obs_manager.get_seasonal_means(
        season_name_to_months=season_to_months,
        start_year=r_config.start_year,
        end_year=r_config.end_year)

    seasonal_clim_fields_obs_interp = OrderedDict()
    # Derive the mask from a shapefile if provided
    if mask_shape_file is not None:
        the_mask = get_mask(bmp_info_agg.lons,
                            bmp_info_agg.lats,
                            shp_path=mask_shape_file)
    else:
        the_mask = np.zeros_like(bmp_info_agg.lons)

    for season, obs_field in seasonal_clim_fields_obs.items():
        obs_field = obs_manager.interpolate_data_to(obs_field,
                                                    lons2d=bmp_info_agg.lons,
                                                    lats2d=bmp_info_agg.lats,
                                                    nneighbours=nx_agg_obs *
                                                    ny_agg_obs)

        obs_field = np.ma.masked_where(the_mask > 0.5, obs_field)

        seasonal_clim_fields_obs_interp[season] = obs_field

    season_to_err = OrderedDict()
    print("-------------var: {} (PE with CRU)---------------------".format(
        vname_model))
    for season in seasonal_clim_fields_obs_interp:
        obs_interp = np.ma.masked_where(
            np.isnan(seasonal_clim_fields_obs_interp[season]),
            seasonal_clim_fields_obs_interp[season])
        seasonal_clim_fields_obs_interp[season] = obs_interp
        season_to_err[season] = season_to_clim_fields_model_agg[season] - obs_interp

        if vname_model in ["I5"]:
            lons = bmp_info_agg.lons.copy()
            lons[lons > 180] -= 360
            season_to_err[season] = maskoceans(lons, bmp_info_agg.lats,
                                               season_to_err[season])

        err_field = season_to_err[season]
        # getmaskarray always returns a full boolean array (even for np.ma.nomask),
        # so the selections below are element-wise; ``~field.mask`` is a scalar
        # when nothing is masked and would not select element-wise.
        valid = ~np.ma.getmaskarray(err_field)
        good_vals = np.ma.getdata(err_field)[valid]

        if good_vals.size == 0:
            # min()/max() of an empty selection would raise ValueError
            print("{}: no unmasked points, statistics skipped".format(season))
            continue

        print("{}: min={}; max={}; avg={}".format(season, good_vals.min(),
                                                  good_vals.max(),
                                                  good_vals.mean()))

        print("---------percetages --- CRU ---")
        obs_mean = obs_interp[valid].mean()
        print("{}: {} %".format(season, good_vals.mean() / obs_mean * 100))

    cs = plot_seasonal_mean_biases(season_to_error_field=season_to_err,
                                   varname=vname_model,
                                   basemap_info=bmp_info_agg,
                                   axes_list=diff_axes_list)

    if obs_axes_list is not None and vname_model in ["I5"]:

        clevs = [0, 50, 60, 70, 80, 90, 100, 150, 200, 250, 300, 350, 400, 500]
        cs_obs = None
        xx, yy = bmp_info_agg.get_proj_xy()
        lons = bmp_info_agg.lons.copy()
        lons[lons > 180] -= 360

        # model-grid coordinates, computed lazily on the first model panel
        lons_model = None
        xx_model, yy_model = None, None

        norm = BoundaryNorm(clevs, 256)
        for col, (season, obs_field) in enumerate(
                seasonal_clim_fields_obs_interp.items()):

            # Observed fields
            ax = obs_axes_list[col]

            if bmp_info_agg.should_draw_basin_boundaries:
                bmp_info_agg.basemap.readshapefile(BASIN_BOUNDARIES_SHP[:-4],
                                                   "basin",
                                                   ax=ax)

            to_plot = maskoceans(lons, bmp_info_agg.lats, obs_field)
            cs_obs = bmp_info_agg.basemap.contourf(xx,
                                                   yy,
                                                   to_plot,
                                                   levels=clevs,
                                                   ax=ax,
                                                   norm=norm,
                                                   extend="max")

            bmp_info_agg.basemap.drawcoastlines(ax=ax, linewidth=0.3)

            ax.set_title(season)

            # Model outputs (native model grid), drawn with the obs levels/norm/cmap
            if model_axes_list is not None:
                ax = model_axes_list[col]

                if bmp_info_agg.should_draw_basin_boundaries:
                    bmp_info_agg.basemap.readshapefile(
                        BASIN_BOUNDARIES_SHP[:-4], "basin", ax=ax)

                if lons_model is None:
                    lons_model = bmp_info_model.lons.copy()
                    lons_model[lons_model > 180] -= 360
                    xx_model, yy_model = bmp_info_model.basemap(
                        lons_model, bmp_info_model.lats)

                model_field = seasonal_clim_fields_model[season]

                to_plot = maskoceans(lons_model, bmp_info_model.lats,
                                     model_field)
                bmp_info_agg.basemap.contourf(xx_model,
                                              yy_model,
                                              to_plot,
                                              levels=cs_obs.levels,
                                              ax=ax,
                                              norm=cs_obs.norm,
                                              cmap=cs_obs.cmap,
                                              extend="max")

                bmp_info_agg.basemap.drawcoastlines(ax=ax, linewidth=0.3)

        # the last obs axes is used as the colorbar axes
        plt.colorbar(cs_obs, cax=obs_axes_list[-1])

    return cs