예제 #1
0
def plot_scatter_STD(ds_dict, varn, year_start, year_end, fut_year_start, fut_year_end):
    """Draw the area-averaged detrended standard deviation of each dataset.

    For each key of ``ds_dict`` the plotted quantity is::

        cos_lat_weighted_mean(
            ut.linear_detrend(da, axis='year').std('year')).data

    where ``da`` is ``ds_dict[key][varn]`` restricted with
    ``.sel(year=slice(start, end))``.

    * ``'ERA20C'`` (solid) and ``'BEST'`` (dashed) are drawn as vertical
      lines at the value for ``year_start``..``year_end``.
    * ``'CESM'``, ``'MPI'``, ``'CanESM'`` and ``'CMIP5'`` are drawn as
      scatter points: x is the ``year_start``..``year_end`` value, y is the
      ``fut_year_start``..``fut_year_end`` value.
    * Any other key is skipped.

    Everything is drawn onto the current pyplot axes; nothing is returned.
    """
    palette = {
        'CESM': 'r',
        'MPI': 'g',
        'CanESM': 'y',
        'ERA20C': 'k',
        'BEST': 'k',
        'CMIP5': 'b',
    }
    # Reference datasets are shown as vertical lines; value = line style.
    reference_linestyles = {'ERA20C': 'solid', 'BEST': 'dashed'}
    scattered = ('CESM', 'MPI', 'CanESM', 'CMIP5')

    def area_mean_detrended_std(field, first_year, last_year):
        # Restrict to the period, remove a linear trend along 'year',
        # take the std over years, then apply the cos(lat)-weighted mean.
        window = field.sel(year=slice(first_year, last_year))
        spread = ut.linear_detrend(window, axis='year').std('year')
        return cos_lat_weighted_mean(spread).data

    for name in ds_dict:
        if name in reference_linestyles:
            plt.axvline(
                area_mean_detrended_std(ds_dict[name][varn], year_start, year_end),
                color=palette[name],
                linestyle=reference_linestyles[name],
                linewidth=1,
                label=name)
        elif name in scattered:
            historical = area_mean_detrended_std(
                ds_dict[name][varn], year_start, year_end)
            future = area_mean_detrended_std(
                ds_dict[name][varn], fut_year_start, fut_year_end)
            plt.scatter(
                historical,
                future,
                s=5,
                color=palette[name],
                label=name)
예제 #2
0
def plot_fit_STD(ds_dict,
                 varn,
                 year_start,
                 year_end,
                 fut_year_start,
                 fut_year_end,
                 keys=None,
                 **kwargs):
    """Plot a least-squares line of future vs. reference detrended std.

    For every dataset in ``keys`` the quantity::

        cos_lat_weighted_mean(
            ut.linear_detrend(da, axis='year').std('year')).data

    is computed twice, where ``da`` is ``ds_dict[key][varn]`` restricted
    with ``.sel(year=slice(start, end))``. The ``year_start``..``year_end``
    values are the x values and the ``fut_year_start``..``fut_year_end``
    values are the y values. All values from all keys are pooled (each
    result is flattened, so scalar and multi-element results both work), a
    linear regression of y on x is fitted with ``linregress``, and the
    fitted line is drawn with ``plt.plot`` between the smallest and largest
    x value.

    Parameters
    ----------
    ds_dict : mapping
        Datasets by name; ``ds_dict[key][varn]`` must support
        ``.sel(year=slice(...))``.
    varn : hashable
        Variable selected from each dataset.
    year_start, year_end :
        Bounds of the reference period (x axis).
    fut_year_start, fut_year_end :
        Bounds of the future period (y axis).
    keys : str or iterable of str, optional
        Datasets to pool. Defaults to ``['CESM', 'MPI', 'CanESM', 'CMIP5']``;
        a single string is treated as a one-element list.
    **kwargs
        Forwarded to ``plt.plot`` for the regression line.

    Returns
    -------
    The correlation coefficient (``rvalue``) reported by ``linregress``.

    Raises
    ------
    ValueError
        If no values were collected (e.g. ``keys`` is empty).
    """
    if keys is None:
        keys = ['CESM', 'MPI', 'CanESM', 'CMIP5']
    elif isinstance(keys, str):
        keys = [keys]

    def _area_mean_detrended_std(da, first_year, last_year):
        # Detrend along 'year', take the std over years, then area-average.
        # np.ravel turns a 0-d result (no extra dimension left after the
        # area mean) into a length-1 array; list() on a 0-d array raises
        # "TypeError: iteration over a 0-d array".
        spread = ut.linear_detrend(
            da.sel(year=slice(first_year, last_year)), axis='year').std('year')
        return np.ravel(cos_lat_weighted_mean(spread).data)

    xx, yy = [], []
    for key in keys:
        da = ds_dict[key][varn]
        xx.extend(_area_mean_detrended_std(da, year_start, year_end))
        yy.extend(_area_mean_detrended_std(da, fut_year_start, fut_year_end))

    if not xx:
        # Fail early with a clear message instead of deep inside
        # linregress / np.min on empty input.
        raise ValueError(
            'plot_fit_STD: no values to fit (empty `keys` or empty selections)')

    slope, intercept, r_value, _, _ = linregress(xx, yy)
    x = np.array([np.min(xx), np.max(xx)])
    plt.plot(x, slope * x + intercept, **kwargs)
    return r_value