Code example #1
# ('sst', os.path.join(path_raw, 'sst_1979-2020_1_12_daily_1.0deg.nc'))]

lags = np.array([1])
tfreq = 2
min_area_in_degrees2 = 3  #10

list_import_ts = None  #[('PDO', os.path.join(data_dir, f'PDO_2y_rm_25-09-20_15hr.h5')),
#                   ('PDO1y', os.path.join(data_dir, 'PDO_1y_rm_25-11-20_17hr.h5'))]

list_for_MI = [
    BivariateMI(name='sst',
                func=class_BivariateMI.corr_map,
                alpha=alpha_corr,
                FDR_control=True,
                kwrgs_func={},
                group_split='together',
                distance_eps=500,
                min_area_in_degrees2=min_area_in_degrees2,
                calc_ts=calc_ts,
                selbox=(130, 260, -10, 60),
                lags=lags)
]

path_out_main = os.path.join(
    main_dir,
    f'publications/Vijverberg_Coumou_2022_NPJ/output/{target}{append_main}/')

rg = RGCPD(list_of_name_path=list_of_name_path,
           list_for_MI=list_for_MI,
           list_import_ts=list_import_ts,
           start_end_TVdate=start_end_TVdate,
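           # The excerpt cuts off mid-call; a hedged completion mirroring the
           # equivalent constructors in code examples #2 and #7 on this page
           # (start_end_date, tfreq and path_out_main assumed defined upstream):
           start_end_date=start_end_date,
           start_end_year=None,
           tfreq=tfreq,
           path_outmain=path_out_main)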
Code example #2
name_csv = f'skill_scores_tf{tfreq}.csv'

#%% run RGCPD
start_end_TVdate = ('06-01', '08-31')
start_end_date = ('1-1', '12-31')
list_of_name_path = [(cluster_label, TVpath),
                     ('sst',
                      os.path.join(path_raw,
                                   'sst_1979-2020_1_12_monthly_1.0deg.nc'))]

list_for_MI = [
    BivariateMI(name='sst',
                func=class_BivariateMI.corr_map,
                alpha=alpha_corr,
                FDR_control=True,
                kwrgs_func={},
                distance_eps=1700,
                min_area_in_degrees2=3,
                calc_ts=calc_ts,
                selbox=(130, 260, -10, 60),
                lags=corlags)
]
if calc_ts == 'region mean':
    s = ''
else:
    s = '_' + calc_ts.replace(' ', '')

path_out_main = os.path.join(
    main_dir, f'publications/paper2/output/{target}{s}{append_main}/')

rg = RGCPD(list_of_name_path=list_of_name_path,
           list_for_MI=list_for_MI,
Code example #3
    list_of_name_path = [(cluster_label, TVpathtemp),
                         ('z500',
                          os.path.join(path_raw,
                                       'z500_1979-2020_1_12_daily_2.5deg.nc'))]

    # Adjusted box upon request Reviewer 1:
    # z500_green_bb = (155,255,20,73) #: RW box
    # use_sign_pattern_z500 = False

    list_for_MI = [
        BivariateMI(name='z500',
                    func=class_BivariateMI.corr_map,
                    alpha=.05,
                    FDR_control=True,
                    distance_eps=600,
                    min_area_in_degrees2=5,
                    calc_ts='pattern cov',
                    selbox=z500_green_bb,
                    use_sign_pattern=use_sign_pattern_z500,
                    lags=np.array([0]),
                    n_cpu=2)
    ]

    rg1 = RGCPD(list_of_name_path=list_of_name_path,
                list_for_MI=list_for_MI,
                start_end_TVdate=start_end_TVdatet2mvsRW,
                start_end_date=None,
                start_end_year=None,
                tfreq=tfreq,
                path_outmain=path_out_main)
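
The constructor call above is complete; the test functions in code examples #5 and #12 follow it with roughly this preprocessing-to-correlation sequence. A minimal sketch for orientation (the 'random_5' split method is illustrative):

rg1.pp_precursors(detrend=True, anomaly=True, selbox=None)  # preprocess precursor fields
rg1.pp_TV()                       # post-process the target variable
rg1.traintest(method='random_5')  # define train/test splits
rg1.calc_corr_maps()              # correlation maps per precursor
rg1.cluster_list_MI()             # cluster significant gridcells into regions
rg1.get_ts_prec()                 # extract precursor time series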
Code example #4
# spring SST correlated with RW
alpha_corr = .05
TVpathERW = os.path.join(data_dir, '2020-10-29_13hr_45min_east_RW.h5')
start_end_TVdate = ('02-01', '05-31')
start_end_date = ('1-1', '12-31')
list_of_name_path = [('z500', TVpathERW),
                     ('sst',
                      os.path.join(path_raw,
                                   'sst_1979-2018_1_12_daily_1.0deg.nc'))]

list_for_MI = [
    BivariateMI(name='sst',
                func=class_BivariateMI.parcorr_map_time,
                alpha=alpha_corr,
                FDR_control=True,
                kwrgs_func={'precursor': True},
                distance_eps=1200,
                min_area_in_degrees2=10,
                calc_ts=calc_ts,
                selbox=(160, 260, 10, 60),
                lags=np.array([0]))
]

if calc_ts == 'region mean':
    s = ''
else:
    s = '_' + calc_ts.replace(' ', '')

path_out_main = os.path.join(
    main_dir, f'publications/paper2/output/easternRW{s}{append_main}/')

rgSST = RGCPD(list_of_name_path=list_of_name_path,
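              # Truncated in the excerpt; a hedged completion based on the
              # equivalent calls elsewhere on this page (tfreq and
              # path_out_main assumed defined upstream):
              list_for_MI=list_for_MI,
              start_end_TVdate=start_end_TVdate,
              start_end_date=start_end_date,
              tfreq=tfreq,
              path_outmain=path_out_main)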
Code example #5
def test_subseas_US_t2m_tigramite(alpha=0.05, tfreq=10, method='random_5',
                                  start_end_TVdate=('07-01', '08-31'),
                                  dailytomonths=False,
                                  TVdates_aggr=False,
                                  lags=np.array([1]),
                                  start_end_yr_precur=None,
                                  start_end_yr_target=None):
    #%%
    # define input: list_of_name_path = [('TVname', 'TVpath'), ('prec_name', 'prec_path')]
    # start_end_yr_target=None; start_end_yr_precur = None; lags = np.array([1]); TVdates_aggr=False; dailytomonths=False;
    # alpha=0.05; tfreq=10; method='random_5';start_end_TVdate=('07-01', '08-31');

    curr_dir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) # script directory
    main_dir = sep.join(curr_dir.split(sep)[:-1])
    path_test = os.path.join(main_dir, 'data')

    list_of_name_path = [(3, os.path.join(path_test, 'tf5_nc5_dendo_80d77.nc')),
                        ('sst', os.path.join(path_test,'sst_daily_1979-2018_5deg_Pacific_175_240E_25_50N.nc'))]

    # define analysis:
    list_for_MI = [BivariateMI(name='sst', func=class_BivariateMI.corr_map,
                              alpha=alpha, FDR_control=True, lags=lags,
                              distance_eps=700, min_area_in_degrees2=5,
                              dailytomonths=dailytomonths)]

    rg = RGCPD(list_of_name_path=list_of_name_path,
               list_for_MI=list_for_MI,
               start_end_TVdate=start_end_TVdate,
               tfreq=tfreq,
               path_outmain=os.path.join(main_dir,'data', 'test'),
               save=True)

    # if TVpath contains the xr.DataArray xrclustered, we can have a look at the spatial regions.
    rg.plot_df_clust()

    rg.pp_precursors(detrend=True, anomaly=True, selbox=None)


    # ### Post-processing Target Variable
    rg.pp_TV(TVdates_aggr=TVdates_aggr,
             kwrgs_core_pp_time={'dailytomonths':dailytomonths,
                                 'start_end_year':start_end_yr_target})


    rg.traintest(method=method)

    # check
    if not TVdates_aggr:
        check_dates_RV(rg.df_splits, rg.traintestgroups, start_end_TVdate)

    rg.kwrgs_load['start_end_year'] = start_end_yr_precur

    rg.calc_corr_maps()
    precur = rg.list_for_MI[0]

    rg.plot_maps_corr()

    rg.cluster_list_MI()

    # rg.quick_view_labels(mean=False)

    rg.get_ts_prec()
    try:
        import wrapper_PCMCI as wPCMCI
        if rg.df_data.columns.size <= 3:
            print('Skipping causal inference step')
        else:
            rg.PCMCI_df_data()

            rg.PCMCI_get_links(var=rg.TV.name, alpha_level=.05)
            rg.df_links

            rg.store_df_PCMCI()
    except ImportError:
        # raise(ModuleNotFoundError)
        print('Unable to load the Tigramite modules. To enable the causal '
              'inference features, install Tigramite from '
              'https://github.com/jakobrunge/tigramite/')
    #%%
    return rg
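
A hypothetical invocation of the test above; the argument values are illustrative and simply mirror the defaults:

rg = test_subseas_US_t2m_tigramite(alpha=0.05, tfreq=10, method='random_5',
                                   start_end_TVdate=('07-01', '08-31'),
                                   lags=np.array([1]))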
Code example #6
File: Hovmoller.py (project: semvijverberg/RGCPD)
if west_or_east == 'western':
    TVpathHM = '/Users/semvijverberg/surfdrive/Scripts/RGCPD/publications/paper2/output/west/1ts_0ff31_10jun-24aug_lag0-0_ts_no_train_test_splits1/2020-07-20_15hr_40min_df_data_z500_v200_dt1_0ff31_z500_145-325-20-62.h5'
elif west_or_east == 'eastern':
    TVpathHM = '/Users/semvijverberg/surfdrive/Scripts/RGCPD/publications/paper2/output/east/2ts_0ff31_10jun-24aug_lag0-0_ts_no_train_test_splits1/2020-07-20_15hr_22min_df_data_z500_v200_dt1_0ff31_z500_140-300-20-73.h5'


name_or_cluster_label = 'z500'
name_ds = f'0..0..{name_or_cluster_label}_sp'

list_of_name_path = [(name_or_cluster_label, TVpathHM),
                     ('z500',os.path.join(path_raw, 'z500hpa_1979-2018_1_12_daily_2.5deg.nc')),
                     ('v300',os.path.join(path_raw, 'v300hpa_1979-2018_1_12_daily_2.5deg.nc')),
                     ('sst', os.path.join(path_raw, 'sst_1979-2018_1_12_daily_1.0deg.nc'))]

list_for_MI   = [BivariateMI(name='z500', func=BivariateMI.corr_map,
                             kwrgs_func={'alpha':.01, 'FDR_control':True},
                             distance_eps=500, min_area_in_degrees2=1,
                             calc_ts='pattern cov', use_sign_pattern=True),
                 BivariateMI(name='v300', func=BivariateMI.corr_map,
                             kwrgs_func={'alpha':.01, 'FDR_control':True},
                             distance_eps=500, min_area_in_degrees2=1,
                             calc_ts='pattern cov', use_sign_pattern=True),
                 BivariateMI(name='sst', func=BivariateMI.corr_map,
                             kwrgs_func={'alpha':.01, 'FDR_control':True},
                             distance_eps=500, min_area_in_degrees2=1,
                             calc_ts='pattern cov', selbox=(0,360,0,73))]

rg = RGCPD(list_of_name_path=list_of_name_path,
           list_for_MI=list_for_MI,
           start_end_TVdate=start_end_TVdate,
           start_end_date=start_end_date,
           tfreq=tfreq, lags_i=np.array([0]),
Code example #7
append_main = ''

if precurname == 'sst':
    precursor = ('sst', os.path.join(path_raw, 'sst_1979-2018_1_12_daily_1.0deg.nc'))
elif precurname == 'z500':
    precursor = ('z500', os.path.join(path_raw, 'z500hpa_1979-2018_1_12_daily_2.5deg.nc'))

#%% run RGCPD
start_end_TVdate = ('06-01', '08-31')
start_end_date = ('1-1', '12-31')
list_of_name_path = [(cluster_label, TVpath),
                     precursor]

list_for_MI   = [BivariateMI(name=precurname, func=class_BivariateMI.parcorr_map_time,
                            alpha=alpha_corr, FDR_control=True,
                            kwrgs_func={'precursor':True},
                            distance_eps=1200, min_area_in_degrees2=10,
                            calc_ts=calc_ts, selbox=(0,360,-10,90),
                            lags=np.array([0]))]
if calc_ts == 'region mean':
    s = ''
else:
    s = '_' + calc_ts.replace(' ', '')

path_out_main = os.path.join(main_dir, f'publications/paper2/output/{precurname}{s}{append_main}/')

rg = RGCPD(list_of_name_path=list_of_name_path,
           list_for_MI=list_for_MI,
           list_import_ts=None,
           start_end_TVdate=start_end_TVdate,
           start_end_date=start_end_date,
           start_end_year=None,
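           # Truncated here; code example #2 suggests the call continues
           # roughly as follows (tfreq and path_out_main assumed defined upstream):
           tfreq=tfreq,
           path_outmain=path_out_main)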
Code example #8
                                # use_sign_pattern=True,
                                # lags=np.array([0]))]


rg_list = []
start_end_TVdates = [('12-01', '02-28'),
                     ('02-01', '05-30'),
                     ('06-01', '08-31')]
for start_end_TVdate in start_end_TVdates:

    list_for_EOFS = [EOF(name='z500', neofs=1, selbox=z500_green_bb,
                        n_cpu=1, start_end_date=start_end_TVdate)]

    list_for_MI   = [BivariateMI(name='z500', func=class_BivariateMI.corr_map,
                                alpha=.05, FDR_control=True,
                                distance_eps=600, min_area_in_degrees2=5,
                                calc_ts='pattern cov', selbox=z500_selbox,
                                use_sign_pattern=True,
                                lags = np.array([0]), n_cpu=2)]

    rg = RGCPD(list_of_name_path=list_of_name_path,
                list_for_MI=list_for_MI,
                list_for_EOFS=list_for_EOFS,
                list_import_ts=list_import_ts,
                start_end_TVdate=start_end_TVdate,
                start_end_date=start_end_date,
                start_end_year=None,
                tfreq=tfreq,
                path_outmain=path_out_main,
                append_pathsub='_' + name_ds)
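
    # The excerpt omits the collection step; presumably each seasonal rg is
    # stored (assumed, since rg_list is initialized before the loop):
    rg_list.append(rg)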

Code example #9
#%% Circulation vs temperature
list_of_name_path = [
    (cluster_label, TVpath),
    ('z500', os.path.join(path_raw, 'z500_1979-2020_1_12_daily_2.5deg.nc')),
    ('v300', os.path.join(path_raw, 'v300_1979-2020_1_12_daily_2.5deg.nc'))
]

lags = np.array([0])

list_for_MI = [
    BivariateMI(name='z500',
                func=class_BivariateMI.corr_map,
                alpha=.05,
                FDR_control=True,
                lags=lags,
                distance_eps=600,
                min_area_in_degrees2=5,
                calc_ts='pattern cov',
                selbox=(0, 360, -10, 90),
                use_sign_pattern=True),
    BivariateMI(name='v300',
                func=class_BivariateMI.corr_map,
                alpha=.05,
                FDR_control=True,
                lags=lags,
                distance_eps=600,
                min_area_in_degrees2=5,
                calc_ts='pattern cov',
                selbox=(0, 360, 20, 90),
                use_sign_pattern=True)
]
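
The excerpt ends after list_for_MI; instantiation would presumably follow the pattern shared by the other examples (start_end_TVdate, tfreq and path_out_main assumed defined upstream):

rg = RGCPD(list_of_name_path=list_of_name_path,
           list_for_MI=list_for_MI,
           start_end_TVdate=start_end_TVdate,
           tfreq=tfreq,
           path_outmain=path_out_main)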
Code example #10
    ('sst', os.path.join(path_raw, 'sst_1950-2019_1_12_monthly_1.0deg.nc')),
    # ('z500', os.path.join(path_raw, 'z500_1950-2019_1_12_monthly_1.0deg.nc')),
    ('smi',
     os.path.join(path_raw,
                  'SM_spi_gamma_01_1950-2019_1_12_monthly_1.0deg.nc'))
]
# ('swvl1', os.path.join(path_raw, 'swvl1_1950-2019_1_12_monthly_1.0deg.nc')),
# ('swvl1', os.path.join(path_raw, 'swvl1_1950-2019_1_12_monthly_1.0deg.nc'))]

list_for_MI = [
    BivariateMI(name='sst',
                func=class_BivariateMI.corr_map,
                alpha=alpha_corr,
                FDR_control=True,
                kwrgs_func={},
                group_lag=True,
                distance_eps=400,
                min_area_in_degrees2=7,
                calc_ts=calc_ts,
                selbox=GlobalBox,
                lags=corlags),
    BivariateMI(name='smi',
                func=class_BivariateMI.corr_map,
                alpha=alpha_corr,
                FDR_control=True,
                kwrgs_func={},
                group_lag=True,
                distance_eps=300,
                min_area_in_degrees2=4,
                calc_ts=calc_ts,
                selbox=USBox,
Code example #11
        kwrgs_func = {'lag_y':[1], 'lag_x':[1]}

# kwrgs_func['lagzxrelative'] = False

sst_dailytomonths = False
if sst_dailytomonths:
    list_of_name_path = [(name_or_cluster_label, TVpath),
                         ('sst', os.path.join(path_raw, 'sst_1979-2020_1_12_daily_1.0deg.nc'))]
else:
    list_of_name_path = [(name_or_cluster_label, TVpath),
                         ('sst', os.path.join(path_raw, 'sst_1979-2020_1_12_monthly_1.0deg.nc'))]


list_for_MI   = [BivariateMI(name='sst', func=func,
                            alpha=.05, FDR_control=True,
                            kwrgs_func=kwrgs_func,
                            distance_eps=1000, min_area_in_degrees2=1,
                            calc_ts='pattern cov', selbox=(130,260,-10,60),
                            lags=lags, dailytomonths=sst_dailytomonths)]



rg = RGCPD(list_of_name_path=list_of_name_path,
            list_for_MI=list_for_MI,
            start_end_TVdate=('05-01', '08-01'),
            start_end_date=None,
            start_end_year=(1980, 2020),
            tfreq=2,
            path_outmain=path_out_main,
            append_pathsub='_' + exper)

Code example #12
def test_US_t2m_tigramite(alpha=0.05,
                          tfreq=10,
                          method='TimeSeriesSplit_10',
                          start_end_TVdate=('07-01', '08-31'),
                          dailytomonths=False,
                          TVdates_aggr=False,
                          lags=np.array([1]),
                          start_end_yr_precur=None,
                          start_end_yr_target=None,
                          load_annual_target=False):
    #%%
    # define input: list_of_name_path = [('TVname', 'TVpath'), ('prec_name', 'prec_path')]
    # start_end_yr_target=None; start_end_yr_precur = None; lags = np.array([1]); TVdates_aggr=False; dailytomonths=False;
    # alpha=0.05; tfreq=10; method='leave_10';start_end_TVdate=('07-01', '08-31'); load_annual_target=False

    curr_dir = os.getcwd()
    if curr_dir.split(sep)[-1] == 'pytest' or curr_dir.split(
            sep)[-2] == 'RGCPD':
        main_dir = sep.join(curr_dir.split(sep)[:-1])
    else:
        main_dir = curr_dir
    path_test = os.path.join(main_dir, 'data')
    path_target = os.path.join(path_test, 'tf5_nc5_dendo_80d77.nc')

    if load_annual_target:
        ts = core_pp.import_ds_lazy(path_target)['ts'].sel(cluster=3)
        # calculate annual mean value of ts (silly, but just a test)
        df = ts.groupby(ts.time.dt.year).mean().to_dataframe()[['ts']]
        df.index = pd.to_datetime([f'{y}-01-01' for y in df.index])
        path_target = df  # path_target overwritten with df with one-val-per-yr
        dailytomonths = True

    list_of_name_path = [
        (3, path_target),
        ('sst',
         os.path.join(path_test,
                      'sst_daily_1979-2018_5deg_Pacific_175_240E_25_50N.nc'))
    ]

    # define analysis:
    list_for_MI = [
        BivariateMI(name='sst',
                    func=class_BivariateMI.corr_map,
                    alpha=alpha,
                    FDR_control=True,
                    lags=lags,
                    distance_eps=700,
                    min_area_in_degrees2=5,
                    dailytomonths=dailytomonths)
    ]

    rg = RGCPD(list_of_name_path=list_of_name_path,
               list_for_MI=list_for_MI,
               start_end_TVdate=start_end_TVdate,
               tfreq=tfreq,
               path_outmain=os.path.join(main_dir, 'data', 'test'),
               save=True)

    # if TVpath contains the xr.DataArray xrclustered, we can have a look at the spatial regions.
    # rg.plot_df_clust()

    rg.pp_precursors(detrend=True, anomaly=True, selbox=None)

    # ### Post-processing Target Variable
    rg.pp_TV(TVdates_aggr=TVdates_aggr,
             kwrgs_core_pp_time={
                 'dailytomonths': dailytomonths,
                 'start_end_year': start_end_yr_target
             })

    rg.traintest(method=method, gap_prior=1)
    rg.traintestgroups[rg.traintestgroups == 2]  # inspect one train/test group (no effect on the run)
    # check
    if not TVdates_aggr:
        check_dates_RV(rg.df_splits, rg.traintestgroups, start_end_TVdate)

    rg.kwrgs_load['start_end_year'] = start_end_yr_precur

    rg.calc_corr_maps()
    precur = rg.list_for_MI[0]

    rg.plot_maps_corr()

    rg.cluster_list_MI()

    # rg.quick_view_labels(mean=False)

    rg.get_ts_prec()
    try:
        from RGCPD import wrapper_PCMCI as wPCMCI
        if rg.df_data.columns.size <= 3:
            print('Skipping causal inference step')
        else:
            rg.PCMCI_df_data()

            rg.PCMCI_get_links(var=rg.TV.name, alpha_level=.05)
            rg.df_links

            rg.store_df_PCMCI()
    except ImportError:
        # raise(ModuleNotFoundError)
        print(
            'Unable to load the Tigramite modules. To enable the causal '
            'inference features, install Tigramite from '
            'https://github.com/jakobrunge/tigramite/')
    #%%
    return rg
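
A hypothetical call exercising the annual-target branch of the test above; the values are illustrative:

rg = test_US_t2m_tigramite(alpha=0.05, tfreq=10, method='TimeSeriesSplit_10',
                           load_annual_target=True)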
Code example #13
)]
RV = user_dir + '/surfdrive/output_RGCPD/easternUS/tf1_n_clusters5_q95_dendo_c378f.nc'
RV_EC = user_dir + '/surfdrive/output_RGCPD/easternUS_EC/EC_tas_tos_Northern/tf1_n_clusters5_q95_dendo_958dd.nc'

list_of_name_path = [(1, RV)]
# ('sm2', '/Users/semvijverberg/surfdrive/ERA5/input_raw/sm2_1979-2018_1_12_daily_1.0deg.nc'),
# ('sm3', '/Users/semvijverberg/surfdrive/ERA5/input_raw/sm3_1979-2018_1_12_daily_1.0deg.nc')]
# ('sst', '/Users/semvijverberg/surfdrive/ERA5/input_raw/sst_1979-2018_1_12_daily_1.0deg.nc')]

list_import_ts = CPPA_s30_14may

list_for_MI = [
    BivariateMI(name='sm2',
                func=BivariateMI.corr_map,
                kwrgs_func={
                    'alpha': .05,
                    'FDR_control': True
                },
                distance_eps=600,
                min_area_in_degrees2=5),
    BivariateMI(name='sm3',
                func=BivariateMI.corr_map,
                kwrgs_func={
                    'alpha': .05,
                    'FDR_control': True
                },
                distance_eps=600,
                min_area_in_degrees2=7)
]
# BivariateMI(name='sst', func=BivariateMI.corr_map,
#              kwrgs_func={'alpha':.001, 'FDR_control':True},
#              distance_eps=800, min_area_in_degrees2=5)]
Code example #14
def pipeline(lags, periodnames, load=False):
    #%%

    method = False
    SM_lags = lags.copy()
    for i, l in enumerate(SM_lags):
        orig = '-'.join(l[0].split('-')[:-1])
        repl = '-'.join(l[1].split('-')[:-1])
        SM_lags[i] = [l[0].replace(orig, repl), l[1]]

    SM = BivariateMI(name='smi',
                     filepath=filename_smi_pp,
                     func=class_BivariateMI.corr_map,
                     alpha=alpha_corr,
                     FDR_control=True,
                     kwrgs_func={},
                     distance_eps=250,
                     min_area_in_degrees2=3,
                     calc_ts='pattern cov',
                     selbox=USBox,
                     lags=SM_lags,
                     use_coef_wghts=True)

    load_SM = '{}_a{}_{}_{}_{}'.format(SM._name, SM.alpha, SM.distance_eps,
                                       SM.min_area_in_degrees2,
                                       periodnames[-1])

    loaded = SM.load_files(pathoutfull, load_SM)
    SM.prec_labels['lag'] = ('lag', periodnames)
    SM.corr_xr['lag'] = ('lag', periodnames)
    # SM.get_prec_ts(kwrgs_load={})
    # df_SM = pd.concat(SM.ts_corr, keys=range(len(SM.ts_corr)))

    TVpath = os.path.join(pathoutfull, f'df_output_{periodnames[-1]}.h5')
    z500_maps = []
    for i, periodname in enumerate(periodnames):
        lag = np.array(lags[i])
        _yrs = [int(l.split('-')[0]) for l in lag]
        if np.unique(_yrs).size > 1:  # crossing year
            crossyr = True
        else:
            crossyr = False
        start_end_TVdate = ('-'.join(lag[0].split('-')[1:]),
                            '-'.join(lag[1].split('-')[1:]))
        lag = np.array([start_end_TVdate])

        list_for_MI = [
            BivariateMI(name='z500',
                        func=class_BivariateMI.corr_map,
                        alpha=alpha_corr,
                        FDR_control=True,
                        kwrgs_func={},
                        distance_eps=250,
                        min_area_in_degrees2=3,
                        calc_ts='pattern cov',
                        selbox=(155, 355, 10, 80),
                        lags=lag,
                        group_split=True,
                        use_coef_wghts=True)
        ]

        name_ds = f'{periodname}..0..{target_dataset}_sp'
        list_of_name_path = [
            ('', TVpath),
            ('z500',
             os.path.join(path_raw, 'z500_1950-2019_1_12_monthly_1.0deg.nc'))
        ]

        start_end_year = (1951, 2019)
        if crossyr:
            TV_start_end_year = (start_end_year[0] + 1, 2019)
        else:
            TV_start_end_year = (start_end_year[0], 2019)
        kwrgs_core_pp_time = {'start_end_year': TV_start_end_year}

        rg = RGCPD(list_of_name_path=list_of_name_path,
                   list_for_MI=list_for_MI,
                   list_import_ts=None,
                   start_end_TVdate=start_end_TVdate,
                   start_end_date=None,
                   start_end_year=start_end_year,
                   tfreq=None,
                   path_outmain=path_out_main)
        rg.figext = '.png'

        rg.pp_precursors(detrend=[True, {
            'tp': False,
            'smi': False
        }],
                         anomaly=[True, {
                             'tp': False,
                             'smi': False
                         }])

        # detrending done prior in clustering_soybean
        rg.pp_TV(name_ds=name_ds,
                 detrend=False,
                 ext_annual_to_mon=False,
                 kwrgs_core_pp_time=kwrgs_core_pp_time)

        # if method.split('_')[0]=='leave':
        # rg.traintest(method, gap_prior=1, gap_after=1, seed=seed,
        # subfoldername=subfoldername)
        # else:
        rg.traintest(method, seed=seed, subfoldername=subfoldername)

        z500 = rg.list_for_MI[0]
        path_circ = os.path.join(rg.path_outsub1, 'circulation')
        os.makedirs(path_circ, exist_ok=True)
        load_z500 = '{}_a{}_{}_{}_{}'.format(z500._name, z500.alpha,
                                             z500.distance_eps,
                                             z500.min_area_in_degrees2,
                                             periodnames[-1])
        if load == 'maps' or load == 'all':
            loaded = z500.load_files(path_circ, load_z500)
        else:
            loaded = False
        if not hasattr(z500, 'corr_xr'):
            rg.calc_corr_maps('z500')
        # store forecast month
        months = {
            'JJ': 'August',
            'MJ': 'July',
            'AM': 'June',
            'MA': 'May',
            'FM': 'April',
            'JF': 'March',
            'SO': 'December',
            'DJ': 'February'
        }
        rg.fc_month = months[periodnames[-1]]

        z500_maps.append(z500.corr_xr)

        if loaded is False:
            z500.store_netcdf(path_circ, load_z500, add_hash=False)

    z500_maps = xr.concat(z500_maps, dim='lag')
    z500_maps['lag'] = ('lag', periodnames)
    #%%
    # merge maps
    xr_merge = xr.concat(
        [SM.corr_xr.mean('split'),
         z500_maps.drop_vars('split').squeeze()],
        dim='var')
    xr_merge['var'] = ('var', ['SM', 'z500'])
    xr_merge = xr_merge.sel(lag=periodnames[::-1])
    # get mask
    maskSM = RGCPD._get_sign_splits_masked(SM.corr_xr,
                                           min_detect=.1,
                                           mask=SM.corr_xr['mask'])[1]
    xr_mask = xr.concat(
        [maskSM, z500_maps['mask'].drop_vars('split').squeeze()], dim='var')
    xr_mask['var'] = ('var', ['SM', 'z500'])
    xr_mask = xr_mask.sel(lag=periodnames[::-1])

    month_d = {
        'AS': 'Aug-Sep mean',
        'JJ': 'June-July mean',
        'JA': 'July-Aug mean',
        'MJ': 'May-June mean',
        'AM': 'Apr-May mean',
        'MA': 'Mar-Apr mean',
        'FM': 'Feb-Mar mean',
        'JF': 'Jan-Feb mean',
        'DJ': 'Dec-Jan mean',
        'ND': 'Nov-Dec mean',
        'ON': 'Oct-Nov mean',
        'SO': 'Sep-Oct mean'
    }

    subtitles = np.array([month_d[l] for l in xr_merge.lag.values],
                         dtype='object')[::-1]
    subtitles = np.array([[s + ' SM vs yield' for s in subtitles[::-1]],
                          [s + ' z500 vs SM' for s in subtitles[::-1]]])
    # leadtime = intmon_d[rg.fc_month]
    # subtitles = [subtitles[i-1]+f' ({leadtime+i*2-1}-month lag)' for i in range(1,5)]
    kwrgs_plot = {
        'zoomregion': (170, 355, 15, 80),
        'hspace': -.1,
        'cbar_vert': .05,
        'subtitles': subtitles,
        'clevels': np.arange(-0.8, 0.9, .1),
        'clabels': np.arange(-.8, .9, .2),
        'units': 'Correlation',
        'y_ticks': np.arange(15, 75, 15),
        'x_ticks': np.arange(150, 310, 30)
    }
    fg = plot_maps.plot_corr_maps(xr_merge,
                                  xr_mask,
                                  col_dim='lag',
                                  row_dim='var',
                                  **kwrgs_plot)
    facecolorocean = '#caf0f8'
    facecolorland = 'white'
    for ax in fg.fig.axes[:-1]:
        ax.add_feature(plot_maps.cfeature.__dict__['LAND'],
                       facecolor=facecolorland,
                       zorder=0)
        ax.add_feature(plot_maps.cfeature.__dict__['OCEAN'],
                       facecolor=facecolorocean,
                       zorder=0)

    fg.fig.savefig(os.path.join(path_circ,
                                f'SM_vs_circ_{rg.fc_month}' + rg.figext),
                   bbox_inches='tight')

    # #%%
    # if hasattr(sst, 'prec_labels')==False and 'sst' in use_vars:
    #     rg.cluster_list_MI('sst')
    #     sst.group_small_cluster(distance_eps_sc=2000, eps_corr=0.4)

    #     sst.prec_labels['lag'] = ('lag', periodnames)
    #     sst.corr_xr['lag'] = ('lag', periodnames)
    #     rg.quick_view_labels('sst', min_detect_gc=.5, save=save,
    #                           append_str=periodnames[-1])
    #     plt.close()

    #%%
    return rg
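
A hypothetical invocation of pipeline(); the lag windows follow the ['YYYY-MM-DD', 'YYYY-MM-DD'] format the function parses, the last period name must be a key of the months dict above, and module-level names such as filename_smi_pp and pathoutfull are assumed to be set:

lags = [['2019-04-01', '2019-05-31'],
        ['2019-05-01', '2019-06-30']]
rg = pipeline(lags, periodnames=['AM', 'MJ'], load=False)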
Code example #15
    func = parcorr_z
elif exper == 'corr':
    func = corr_map
    kwrgs_func = {}
elif 'parcorrtime' in exper:
    if exper.split('_')[1] == 'target':
        kwrgs_func = {'lag_y':[1]}
    elif exper.split('_')[1] == 'precur':
        kwrgs_func = {'lag_x':[1]}
    elif exper.split('_')[1] == 'both':
        kwrgs_func = {'lag_y':[1], 'lag_x':[1]}
    func = parcorr_map_time

list_for_MI   = [BivariateMI(name='sst', func=func,
                            alpha=.05, FDR_control=True,
                            kwrgs_func=kwrgs_func,
                            distance_eps=1000, min_area_in_degrees2=1,
                            calc_ts='pattern cov', selbox=(130,260,-10,60),
                            lags=lags)]

list_of_name_path = [(name_or_cluster_label, TVpath),
                       ('sst', os.path.join(path_raw, 'sst_1979-2020_1_12_monthly_1.0deg.nc'))]

rg = RGCPD(list_of_name_path=list_of_name_path,
            list_for_MI=list_for_MI,
            start_end_TVdate=('05-01', '08-01'),
            start_end_date=None,
            start_end_year=(1980, 2020),
            tfreq=2,
            path_outmain=path_out_main,
            append_pathsub='_' + exper)
Code example #16
def define(list_of_name_path, TV_targetperiod, n_lags, kwrgs_MI, subfolder):

    #create lag list
    days_dict = {
        '01': '31',
        '02': '28',
        '03': '31',
        '04': '30',
        '05': '31',
        '06': '30',
        '07': '31',
        '08': '31',
        '09': '30',
        '10': '31',
        '11': '30',
        '12': '31'
    }

    target_month_str = TV_targetperiod[0][:2]  #derive month number
    if target_month_str[0] == '0':
        target_month = int(target_month_str[1])  # 01 or 02 ..
    else:
        target_month = int(target_month_str)  #10,11,12
    print(target_month)

    if target_month - (n_lags) <= 0:  #cross year?
        crossyr = True
        start_end_year = (1951, 2020)  #hardcoded
    else:
        crossyr = False
        start_end_year = None

    lags = []  #initialize empty lags list
    for i in range(n_lags):
        lag = []  #initialize empty lag list
        if not crossyr:  #if not crossyear with lags, do not add years to lags
            for j in range(1):  #start and end date
                if target_month - i - 1 < 10:
                    lag_month_str_start = '0' + str(
                        target_month - i - 1)  # 01 or 02 ..
                    lag_month_str_end = '0' + str(
                        target_month - i - 1)  # 01 or 02 ..
                else:
                    lag_month_str_start = str(target_month - i - 1)  #10,11,12
                    lag_month_str_end = str(target_month - i - 1)  #10,11,12
        else:  #if crossyear, do add years to lags (1950 and 2019)
            for j in range(1):  #start and end date
                if target_month - i - 1 <= 0:  #crossyear, lagged months in the year before
                    if target_month + 12 - i - 1 < 10:
                        lag_month_str_start = str(
                            start_end_year[0] - 1) + '-0' + str(
                                target_month + 12 - i - 1
                            )  #months in year before TV-targetperiod, 01, 02
                        lag_month_str_end = str(start_end_year[1] -
                                                1) + '-0' + str(target_month +
                                                                12 - i - 1)
                    else:
                        lag_month_str_start = str(
                            start_end_year[0] - 1) + '-' + str(
                                target_month + 12 - i - 1
                            )  #months in year before TV-targetperiod, 10,11,12
                        lag_month_str_end = str(start_end_year[1] -
                                                1) + '-' + str(target_month +
                                                               12 - i - 1)
                else:  #crossyear, but lagged months not in the year before, for instance tv_month 02, lag month 01
                    if target_month - i - 1 < 10:
                        lag_month_str_start = str(
                            start_end_year[0]) + '-0' + str(
                                target_month - i - 1)  # 01 or 02 ..
                        lag_month_str_end = str(
                            start_end_year[1]) + '-0' + str(
                                target_month - i - 1)  # 01 or 02 ..
                    else:
                        lag_month_str_start = str(
                            start_end_year[0]) + '-' + str(
                                target_month - i - 1)  #10,11,12
                        lag_month_str_end = str(start_end_year[1]) + '-' + str(
                            target_month - i - 1)  #10,11,12
        lag.append(lag_month_str_start + '-01')  #first day of month always 01
        lag_month_days_str_end = days_dict[
            lag_month_str_start[-2:]]  #get last day of month from dict
        lag.append(lag_month_str_end + '-' +
                   lag_month_days_str_end)  #concatenate days and months
        lags.append(lag)  #append to lags list
    print(lags)

    #list with input variables
    list_for_MI = [
        BivariateMI(
            name='sst',
            func=class_BivariateMI.corr_map,
            alpha=kwrgs_MI['alpha'],
            FDR_control=kwrgs_MI['FDR_control'],
            lags=np.array(lags),  # <- selecting time periods to aggregate
            distance_eps=kwrgs_MI['distance_eps'],
            min_area_in_degrees2=kwrgs_MI['min_area_in_degrees2'],
            n_jobs_clust=1),
        BivariateMI(
            name='swvl1_2',
            func=class_BivariateMI.corr_map,
            alpha=kwrgs_MI['alpha'],
            FDR_control=kwrgs_MI['FDR_control'],
            lags=np.array(lags),  # <- selecting time periods to aggregate
            distance_eps=kwrgs_MI['distance_eps'],
            min_area_in_degrees2=kwrgs_MI['min_area_in_degrees2'],
            n_jobs_clust=1)
    ]

    #initialize RGCPD class
    rg = RGCPD(
        list_of_name_path=list_of_name_path,
        list_for_MI=list_for_MI,
        tfreq=None,  # <- seasonal forecasting mode, set tfreq to None!
        start_end_TVdate=
        TV_targetperiod,  # <- defining target period (whole year)
        path_outmain=os.path.join(
            main_dir, f'Results/{subfolder}/{list_of_name_path[0][0]}'))

    #preprocess TV
    rg.pp_TV(TVdates_aggr=True,
             kwrgs_core_pp_time={
                 'start_end_year': start_end_year
             })  # <- start_end_TVdate defines the period aggregated over

    return rg, list_for_MI, lags, crossyr
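
A hypothetical call of define(); kwrgs_MI mirrors the dict used in the check_ts example further down this page, and list_of_name_path is assumed to be built as in the other examples:

kwrgs_MI = {'alpha': 0.01, 'FDR_control': True,
            'distance_eps': 500, 'min_area_in_degrees2': 5}
rg, list_for_MI, lags, crossyr = define(list_of_name_path,
                                        TV_targetperiod=('07-01', '07-31'),
                                        n_lags=3, kwrgs_MI=kwrgs_MI,
                                        subfolder='test')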
Code example #17
    'df_ts_paper2_clustercorr_{}.h5'.format(xrclustered.attrs['hash']))

functions_pp.store_hdf_df({'df_ts': df_ts}, file_path=TVpath)
#%% Calculate corr maps

list_xr = []
for point in df_ts.columns:
    list_of_name_path = [
        ('', TVpath),
        ('t2m',
         root_data + '/input_raw/t2m_US_1979-2020_1_12_daily_0.25deg.nc')
    ]
    list_for_MI = [
        BivariateMI(name='t2m',
                    func=class_BivariateMI.corr_map,
                    alpha=.05,
                    FDR_control=True,
                    lags=np.array([0]))
    ]

    rg = RGCPD(list_of_name_path=list_of_name_path,
               list_for_MI=list_for_MI,
               path_outmain=path_outmain,
               tfreq=15,
               start_end_TVdate=('06-01', '08-31'),
               save=False)
    rg.pp_precursors()
    rg.pp_TV(name_ds=point)
    rg.traintest(False)
    rg.calc_corr_maps()
    precur = rg.list_for_MI[0]
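
    # list_xr is initialized before the loop but the excerpt ends here;
    # presumably each point's correlation map is collected (assumed):
    list_xr.append(precur.corr_xr)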
Code example #18
def check_ts(agg_level):
    ncl_dict = {'high': 20, 'medium': 42, 'low': 135}
    clusters = np.arange(1, ncl_dict[agg_level] + 1)
    path_data = os.path.join(os.path.dirname(main_dir),
                             'Data')  # path of data sets
    kwrgs_MI = {
        'alpha': 0.01,
        'FDR_control': True,
        'distance_eps': 500,
        'min_area_in_degrees2': 5
    }  #some controls for bivariateMI
    targetperiods = [
        ('01-01', '01-31'), ('02-01', '02-28'), ('03-01', '03-31'),
        ('04-01', '04-30'), ('05-01', '05-31'), ('06-01', '06-30'),
        ('07-01', '07-31'), ('08-01', '08-31'), ('09-01', '09-30'),
        ('10-01', '10-31'), ('11-01', '11-30'), ('12-01', '12-31')
    ]
    n_lags = 3  #int, max 12

    allts = pd.DataFrame()
    for c in clusters:
        print(c)
        for t in targetperiods:
            if agg_level == 'high':
                list_of_name_path = [
                    (c, os.path.join(
                        path_data,
                        '[20]_dendo_52baa.nc')),  #for a single cluster!
                    ('sst',
                     os.path.join(path_data,
                                  'sst_1950-2020_1_12_monthly_1.0deg.nc')
                     ),  #sst = global
                    ('swvl1_2',
                     os.path.join(
                         path_data,
                         'swvl_1950-2020_1_12_monthly_1.0deg_mask_0N80N.nc'))
                ]  #swvl = global, summed over layer 1 and 2
            elif agg_level == 'medium':
                list_of_name_path = [
                    (c, os.path.join(
                        path_data,
                        '[42]_dendo_fca84.nc')),  #for a single cluster!
                    ('sst',
                     os.path.join(path_data,
                                  'sst_1950-2020_1_12_monthly_1.0deg.nc')
                     ),  #sst = global
                    ('swvl1_2',
                     os.path.join(
                         path_data,
                         'swvl_1950-2020_1_12_monthly_1.0deg_mask_0N80N.nc'))
                ]  #swvl = global, summed over layer 1 and 2
            elif agg_level == 'low':
                list_of_name_path = [
                    (c, os.path.join(
                        path_data,
                        '[135]_dendo_1c7fe.nc')),  #for a single cluster!
                    ('sst',
                     os.path.join(path_data,
                                  'sst_1950-2020_1_12_monthly_1.0deg.nc')
                     ),  #sst = global
                    ('swvl1_2',
                     os.path.join(
                         path_data,
                         'swvl_1950-2020_1_12_monthly_1.0deg_mask_0N80N.nc'))
                ]  #swvl = global, summed over layer 1 and 2

                #create lag list
            days_dict = {
                '01': '31',
                '02': '28',
                '03': '31',
                '04': '30',
                '05': '31',
                '06': '30',
                '07': '31',
                '08': '31',
                '09': '30',
                '10': '31',
                '11': '30',
                '12': '31'
            }

            target_month_str = t[0][:2]  #derive month number
            if target_month_str[0] == '0':
                target_month = int(target_month_str[1])  # 01 or 02 ..
            else:
                target_month = int(target_month_str)  #10,11,12

            if target_month - (n_lags) <= 0:  #cross year?
                crossyr = True
                start_end_year = (1951, 2020)  #hardcoded
            else:
                crossyr = False
                start_end_year = None

            lags = []  #initialize empty lags list
            for i in range(n_lags):
                lag = []  #initialize empty lag list
                if not crossyr:  #if not crossyear with lags, do not add years to lags
                    for j in range(1):  #start and end date
                        if target_month - i - 1 < 10:
                            lag_month_str_start = '0' + str(
                                target_month - i - 1)  # 01 or 02 ..
                            lag_month_str_end = '0' + str(
                                target_month - i - 1)  # 01 or 02 ..
                        else:
                            lag_month_str_start = str(target_month - i -
                                                      1)  #10,11,12
                            lag_month_str_end = str(target_month - i -
                                                    1)  #10,11,12
                else:  #if crossyear, do add years to lags (1950 and 2019)
                    for j in range(1):  #start and end date
                        if target_month - i - 1 <= 0:  #crossyear, lagged months in the year before
                            if target_month + 12 - i - 1 < 10:
                                lag_month_str_start = str(
                                    start_end_year[0] - 1
                                ) + '-0' + str(
                                    target_month + 12 - i - 1
                                )  #months in year before TV-targetperiod, 01, 02
                                lag_month_str_end = str(
                                    start_end_year[1] -
                                    1) + '-0' + str(target_month + 12 - i - 1)
                            else:
                                lag_month_str_start = str(
                                    start_end_year[0] - 1
                                ) + '-' + str(
                                    target_month + 12 - i - 1
                                )  #months in year before TV-targetperiod, 10,11,12
                                lag_month_str_end = str(
                                    start_end_year[1] -
                                    1) + '-' + str(target_month + 12 - i - 1)
                        else:  #crossyear, but lagged months not in the year before, for instance tv_month 02, lag month 01
                            if target_month - i - 1 < 10:
                                lag_month_str_start = str(
                                    start_end_year[0]) + '-0' + str(
                                        target_month - i - 1)  # 01 or 02 ..
                                lag_month_str_end = str(
                                    start_end_year[1]) + '-0' + str(
                                        target_month - i - 1)  # 01 or 02 ..
                            else:
                                lag_month_str_start = str(
                                    start_end_year[0]) + '-' + str(
                                        target_month - i - 1)  #10,11,12
                                lag_month_str_end = str(
                                    start_end_year[1]) + '-' + str(
                                        target_month - i - 1)  #10,11,12
                lag.append(lag_month_str_start +
                           '-01')  #first day of month always 01
                lag_month_days_str_end = days_dict[
                    lag_month_str_start[-2:]]  #get last day of month from dict
                lag.append(
                    lag_month_str_end + '-' +
                    lag_month_days_str_end)  #concatenate days and months
                lags.append(lag)  #append to lags list

            #list with input variables
            list_for_MI = [
                BivariateMI(
                    name='sst',
                    func=class_BivariateMI.corr_map,
                    alpha=kwrgs_MI['alpha'],
                    FDR_control=kwrgs_MI['FDR_control'],
                    lags=np.array(
                        lags),  # <- selecting time periods to aggregate
                    distance_eps=kwrgs_MI['distance_eps'],
                    min_area_in_degrees2=kwrgs_MI['min_area_in_degrees2']),
                BivariateMI(
                    name='swvl1_2',
                    func=class_BivariateMI.corr_map,
                    alpha=kwrgs_MI['alpha'],
                    FDR_control=kwrgs_MI['FDR_control'],
                    lags=np.array(
                        lags),  # <- selecting time periods to aggregate
                    distance_eps=kwrgs_MI['distance_eps'],
                    min_area_in_degrees2=kwrgs_MI['min_area_in_degrees2'])
            ]

            #initialize RGCPD class
            rg = RGCPD(
                list_of_name_path=list_of_name_path,
                list_for_MI=list_for_MI,
                tfreq=None,  # <- seasonal forecasting mode, set tfreq to None!
                start_end_TVdate=t,  # <- defining target period (whole year)
                path_outmain=os.path.join(main_dir, 'data'))

            #preprocess TV
            rg.pp_TV(TVdates_aggr=True,
                     kwrgs_core_pp_time={
                         'start_end_year': start_end_year
                     })  # <- start_end_TVdate defines the period aggregated over

            #update dates
            month = int(t[0][:2])
            delta = month - 1
            df = rg.df_fullts[:]
            date_list = df.index.get_level_values(0).shift(delta, freq='MS')
            df.set_index([date_list], inplace=True)

            #store
            if c - 1 == 0:
                allts = pd.concat([allts, df])  # DataFrame.append was removed in pandas 2.0
            elif c - 1 > 0 and month - 1 == 0:
                allts = allts.join(df, how='left')
            else:
                allts.update(df, join='left')

    datetimestamp = datetime.now()
    datetimestamp_str = datetimestamp.strftime("%Y-%m-%d_%H-%M-%S")
    path_data = os.path.join(os.path.dirname(main_dir),
                             'Data')  # path of data sets
    allts.to_csv(
        os.path.join(path_data, datetimestamp_str + '_cl_ts_' + agg_level +
                     '.csv'))  #save cluster time series to csv

    return allts


#check_ts('high')
Code example #19
name_rob_csv = 'robustness_SST_RW_T.csv'

if tfreq > 15: sst_green_bb = (140,240,-9,59) # (180, 240, 30, 60): original warm-code focus
if tfreq <= 15: sst_green_bb = (140,235,20,59) # same as for West

name_or_cluster_label = 'z500'
name_ds = f'0..0..{name_or_cluster_label}_sp'

#%% Circulation vs temperature
list_of_name_path = [(cluster_label, TVpathtemp),
                     ('z500', os.path.join(path_raw, 'z500hpa_1979-2018_1_12_daily_2.5deg.nc')),
                     ('SST', os.path.join(path_raw, 'sst_1979-2018_1_12_daily_1.0deg.nc'))]

list_for_MI   = [BivariateMI(name='z500', func=class_BivariateMI.corr_map,
                            alpha=.05, FDR_control=True,
                            distance_eps=600, min_area_in_degrees2=5,
                            calc_ts='pattern cov', selbox=z500_green_bb,
                            use_sign_pattern=False, lags = np.array([0])),
                 BivariateMI(name='SST', func=class_BivariateMI.corr_map,
                              alpha=.05, FDR_control=True,
                              distance_eps=500, min_area_in_degrees2=5,
                              calc_ts='pattern cov', selbox=sst_green_bb,  # (130,340,-10,60)
                              lags=np.array([0]))]
                 # BivariateMI(name='sm', func=class_BivariateMI.parcorr_map_time,
                 #            alpha=.05, FDR_control=True,
                 #            distance_eps=1200, min_area_in_degrees2=10,
                 #            calc_ts='region mean', selbox=(200,300,20,73),
                 #            lags=np.array([0]))]

rg = RGCPD(list_of_name_path=list_of_name_path,
            list_for_MI=list_for_MI,
Code example #20
def pipeline(lags, periodnames, use_vars=['sst', 'smi'], load=False):
    #%%
    if int(lags[0][0].split('-')[-2]) > 7:  # first month after july
        crossyr = True
    else:
        crossyr = False

    SM_lags = lags.copy()
    for i, l in enumerate(SM_lags):
        orig = '-'.join(l[0].split('-')[:-1])
        repl = '-'.join(l[1].split('-')[:-1])
        SM_lags[i] = [l[0].replace(orig, repl), l[1]]

    list_for_MI = [
        BivariateMI(name='sst',
                    func=class_BivariateMI.corr_map,
                    alpha=alpha_corr,
                    FDR_control=True,
                    kwrgs_func={},
                    distance_eps=250,
                    min_area_in_degrees2=3,
                    calc_ts=calc_ts,
                    selbox=GlobalBox,
                    lags=lags,
                    group_split=True,
                    use_coef_wghts=True),
        BivariateMI(name='smi',
                    func=class_BivariateMI.corr_map,
                    alpha=alpha_corr,
                    FDR_control=True,
                    kwrgs_func={},
                    distance_eps=200,
                    min_area_in_degrees2=3,
                    calc_ts='pattern cov',
                    selbox=USBox,
                    lags=SM_lags,
                    use_coef_wghts=True)
    ]

    rg = RGCPD(list_of_name_path=list_of_name_path,
               list_for_MI=list_for_MI,
               list_import_ts=None,
               start_end_TVdate=None,
               start_end_date=None,
               start_end_year=start_end_year,
               tfreq=None,
               path_outmain=path_out_main)
    rg.figext = '.png'

    subfoldername = target_dataset
    #list(np.array(start_end_year, str)))
    subfoldername += append_pathsub

    rg.pp_precursors(detrend=[True, {
        'tp': False,
        'smi': False
    }],
                     anomaly=[True, {
                         'tp': False,
                         'smi': False
                     }],
                     auto_detect_mask=[False, {
                         'swvl1': True,
                         'swvl2': True
                     }])
    if crossyr:
        TV_start_end_year = (start_end_year[0] + 1, 2019)
    else:
        TV_start_end_year = (start_end_year[0], 2019)

    kwrgs_core_pp_time = {'start_end_year': TV_start_end_year}
    rg.pp_TV(name_ds=name_ds,
             detrend={'method': 'linear'},
             ext_annual_to_mon=False,
             kwrgs_core_pp_time=kwrgs_core_pp_time)
    if method.split('_')[0] == 'leave':
        rg.traintest(method,
                     gap_prior=1,
                     gap_after=1,
                     seed=seed,
                     subfoldername=subfoldername)
    else:
        rg.traintest(method, seed=seed, subfoldername=subfoldername)

    #%%
    sst = rg.list_for_MI[0]
    if 'sst' in use_vars:
        load_sst = '{}_a{}_{}_{}_{}'.format(sst._name, sst.alpha,
                                            sst.distance_eps,
                                            sst.min_area_in_degrees2,
                                            periodnames[-1])
        if load:
            loaded = sst.load_files(rg.path_outsub1, load_sst)
        else:
            loaded = False
        if not hasattr(sst, 'corr_xr'):
            rg.calc_corr_maps('sst')
    #%%
    SM = rg.list_for_MI[1]
    if 'smi' in use_vars:
        load_SM = '{}_a{}_{}_{}_{}'.format(SM._name, SM.alpha, SM.distance_eps,
                                           SM.min_area_in_degrees2,
                                           periodnames[-1])
        if load:
            loaded = SM.load_files(rg.path_outsub1, load_SM)
        else:
            loaded = False
        if not hasattr(SM, 'corr_xr'):
            rg.calc_corr_maps('smi')

    #%%

    # sst.distance_eps = 250 ; sst.min_area_in_degrees2 = 4
    if not hasattr(sst, 'prec_labels') and 'sst' in use_vars:
        rg.cluster_list_MI('sst')

        # check if the west-Atlantic is a separate region; otherwise split region 1
        df_labels = find_precursors.labels_to_df(sst.prec_labels)
        dlat = df_labels['latitude'] - 29
        dlon = df_labels['longitude'] - 290
        zz = pd.concat([dlat.abs(), dlon.abs()], axis=1)
        if zz.query('latitude < 10 & longitude < 10').size == 0:
            print('Splitting region west-Atlantic')
            largest_regions = df_labels['n_gridcells'].idxmax()
            split = find_precursors.split_region_by_lonlat
            sst.prec_labels, _ = split(
                sst.prec_labels.copy(),
                label=int(largest_regions),
                kwrgs_mask_latlon={'upper_right': (263, 16)})

        merge = find_precursors.merge_labels_within_lonlatbox

        # # Ensure that what is in Atlantic is one precursor region
        lonlatbox = [263, 300, 17, 40]
        sst.prec_labels = merge(sst, lonlatbox)
        # Indonesia_oceans = [110, 150, 0, 10]
        # sst.prec_labels = merge(sst, Indonesia_oceans)
        Japanese_sea = [100, 150, 30, 50]
        sst.prec_labels = merge(sst, Japanese_sea)
        Mediterrenean_sea = [0, 45, 30, 50]
        sst.prec_labels = merge(sst, Mediterrenean_sea)
        East_Tropical_Atlantic = [330, 20, -10, 10]
        sst.prec_labels = merge(sst, East_Tropical_Atlantic)

    if 'sst' in use_vars:
        if loaded is False:
            sst.store_netcdf(rg.path_outsub1, load_sst, add_hash=False)
        sst.prec_labels['lag'] = ('lag', periodnames)
        sst.corr_xr['lag'] = ('lag', periodnames)
        rg.quick_view_labels('sst',
                             min_detect_gc=.5,
                             save=save,
                             append_str=periodnames[-1])

    #%%
    if not hasattr(SM, 'prec_labels') and 'smi' in use_vars:
        SM = rg.list_for_MI[1]
        rg.cluster_list_MI('smi')

        lonlatbox = [220, 240, 25, 55]  # eastern US
        SM.prec_labels = merge(SM, lonlatbox)
        lonlatbox = [270, 280, 25, 45]  # mid-US
        SM.prec_labels = merge(SM, lonlatbox)
    if 'smi' in use_vars:
        if loaded is False:
            SM.store_netcdf(rg.path_outsub1, load_SM, add_hash=False)
        SM.corr_xr['lag'] = ('lag', periodnames)
        SM.prec_labels['lag'] = ('lag', periodnames)
        rg.quick_view_labels('smi',
                             min_detect_gc=.5,
                             save=save,
                             append_str=periodnames[-1])
#%%

    rg.get_ts_prec()
    rg.df_data = rg.df_data.rename({rg.df_data.columns[0]: target_dataset},
                                   axis=1)

    # # fill first value of smi (NaN because of missing December when calc smi
    # # on month februari).
    # keys = [k for k in rg.df_data.columns if k.split('..')[-1]=='smi']
    # rg.df_data[keys] = rg.df_data[keys].fillna(value=0)

    #%% Causal Inference

    def feature_selection_CondDep(df_data,
                                  keys,
                                  z_keys=None,
                                  alpha_CI=.05,
                                  x_lag=0,
                                  z_lag=0):

        # Feature selection Cond. Dependence
        keys = list(keys)  # must be list
        if z_keys is None:
            z_keys = keys
        corr, pvals = wrapper_PCMCI.df_data_Parcorr(df_data.copy(),
                                                    keys=keys,
                                                    z_keys=z_keys,
                                                    z_lag=z_lag)
        # removing all keys that are Cond. Indep. in each trainingset
        keys_dict = dict(zip(range(rg.n_spl), [keys] * rg.n_spl))  # all vars
        for s in rg.df_splits.index.levels[0]:
            for k_i in keys:
                onekeyCI = (np.nan_to_num(pvals.loc[k_i][s], nan=alpha_CI) >
                            alpha_CI).mean() > 0
                keyisNaN = np.isnan(pvals.loc[k_i][s]).all()
                if onekeyCI or keyisNaN:
                    k_ = keys_dict[s].copy()
                    k_.pop(k_.index(k_i))
                    keys_dict[s] = k_

        return corr, pvals, keys_dict.copy()

    regress_autocorr_SM = False
    unique_keys = np.unique(
        ['..'.join(k.split('..')[1:]) for k in rg.df_data.columns[1:-2]])
    # select the causal regions from the analysis in Causal Inferred Precursors
    print('Start Causal Inference')
    list_pvals = []
    list_corr = []
    for k in unique_keys:
        z_keys = [z for z in rg.df_data.columns[1:-2] if k not in z]

        for mon in periodnames:
            keys = [mon + '..' + k]
            if regress_autocorr_SM and 'sm' in k:
                z_keys = [
                    z for z in rg.df_data.columns[1:-2] if keys[0] not in z
                ]

            if keys[0] not in rg.df_data.columns:
                continue
            out = feature_selection_CondDep(rg.df_data.copy(),
                                            keys=keys,
                                            z_keys=z_keys,
                                            alpha_CI=.05)
            corr, pvals, keys_dict = out
            list_pvals.append(pvals.groupby(level=0).max())  # the 'level' kwarg of .max()/.mean() was removed in pandas 2.0
            list_corr.append(corr.groupby(level=0).mean())

    rg.df_pvals = pd.concat(list_pvals, axis=0)
    rg.df_corr = pd.concat(list_corr, axis=0)

    return rg
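
A hypothetical invocation of pipeline(); the lag windows and period names are illustrative and follow the ['YYYY-MM-DD', 'YYYY-MM-DD'] format parsed at the top of the function, while module-level names (list_of_name_path, start_end_year, method, seed, etc.) are assumed to be set:

lags = [['2019-05-01', '2019-06-30'],
        ['2019-07-01', '2019-08-31']]
rg = pipeline(lags, periodnames=['MJ', 'JA'], use_vars=['sst', 'smi'], load=False)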