# Variables to process; the loop below handles each one in turn.
VARS = ['TMEAN']

for VAR in VARS:

    # Read the 3d (time, lat, lon) feature arrays for this variable.
    # The context manager closes the file as soon as both datasets are in memory.
    jra_feature_path = JRA_dir + 'JRA_{}_features_US_2015_2020.hdf'.format(VAR)
    with h5py.File(jra_feature_path, 'r') as hdf_io:
        PRISM_T = hdf_io['{}_4km'.format(VAR)][...]
        REGRID_T = hdf_io['{}_REGRID'.format(VAR)][...]

    # Output buffers: RESULT_CLEAN has the same shape as the input features,
    # RESULT_025 has one lon_025-shaped field per time step.
    shape_3d = REGRID_T.shape
    n_time = shape_3d[0]
    RESULT_CLEAN = np.zeros(shape_3d)
    RESULT_025 = np.zeros((n_time,) + lon_025.shape)

    for n in range(n_time):
        print('\t{}'.format(n))

        # Model prediction for time step n from the regridded field and elevation.
        X = (REGRID_T[n, ...], etopo_regrid)
        temp_unet = vu.pred_domain(X, land_mask, CGAN, param, method='norm_std')

        # Interpolate the prediction from the (lon_4km, lat_4km) grid onto the
        # (lon_025, lat_025) grid, then from there back onto (lon_4km, lat_4km).
        temp_025 = du.interp2d_wraper(
            lon_4km, lat_4km, temp_unet, lon_025, lat_025, method=interp_method)
        temp_4km = du.interp2d_wraper(
            lon_025, lat_025, temp_025, lon_4km, lat_4km, method=interp_method)

        RESULT_025[n, ...] = temp_025
        RESULT_CLEAN[n, ...] = temp_4km

    # Blank out every grid point selected by land_mask, at all time steps.
    RESULT_CLEAN[:, land_mask] = np.nan

    # Pair each saved array with its dataset label, then split the pairs into
    # the tuple/list form that du.save_hdf5 is called with.
    save_items = [
        ('lon_4km', lon_4km),
        ('lat_4km', lat_4km),
        ('{}_4km'.format(VAR), PRISM_T),
        ('{}_REGRID'.format(VAR), RESULT_CLEAN),
        ('{}_025'.format(VAR), RESULT_025),
        ('etopo_4km', etopo_4km),
        ('etopo_regrid', etopo_regrid),
    ]
    tuple_save = tuple(data for _, data in save_items)
    label_save = [label for label, _ in save_items]
    du.save_hdf5(tuple_save, label_save, out_dir=JRA_dir,
                 filename='JRA_US_{}_clean_2015_2020.hdf'.format(VAR))
    
    
    # Read the 3d (time, lat, lon) PRISM feature arrays for this variable,
    # keeping only the entries selected by ind_pred along the first axis.
    # The file is opened with a context manager so the handle is released even
    # if one of the reads raises.
    # NOTE(review): the original formatted every name in this section with
    # `var`, which the enclosing loop never binds (its loop variable is VAR);
    # VAR is used here instead -- confirm no separate global `var` was intended.
    prism_feature_path = PRISM_dir + 'PRISM_{}_features_2015_2020.hdf'.format(VAR)
    with h5py.File(prism_feature_path, 'r') as hdf_io:
        PRISM_T = hdf_io['{}_4km'.format(VAR)][ind_pred, ...]
        REGRID_T = hdf_io['{}_REGRID'.format(VAR)][ind_pred, ...]

    # Import the pre-trained seasonal UNET models together, before the date
    # loop, so each model file is loaded once instead of once per date.
    season_keys = ('djf', 'mam', 'jja', 'son')
    unet = {}
    for season_key in season_keys:
        unet[season_key] = keras.models.load_model(
            model_import_dir + 'UNET3_{}_{}.hdf'.format(VAR, season_key))

    # Calendar month -> key of the seasonal model applied to dates in that month.
    month_to_season = {
        12: 'djf', 1: 'djf', 2: 'djf',
        3: 'mam', 4: 'mam', 5: 'mam',
        6: 'jja', 7: 'jja', 8: 'jja',
        9: 'son', 10: 'son', 11: 'son',
    }

    # NOTE(review): RESULT_UNET is not allocated in this section; it is assumed
    # to already exist with one slot per entry of pred_list -- confirm.
    for n, date in enumerate(pred_list):
        X = (REGRID_T[n, ...], etopo_4km, etopo_regrid)
        print(date)

        # Predict with the model trained for the season this date falls in.
        season_model = unet[month_to_season[date.month]]
        temp_unet = vu.pred_domain(X, land_mask, season_model, param, method='norm_std')

        RESULT_UNET[n, ...] = temp_unet

    # Save the grids, the inputs and the predictions in one file.
    tuple_save = (lon_4km, lat_4km, PRISM_T, REGRID_T, RESULT_UNET)
    label_save = ['lon_4km', 'lat_4km', '{}_4km'.format(VAR), '{}_REGRID'.format(VAR), 'RESULT_UNET']
    du.save_hdf5(tuple_save, label_save, out_dir=save_dir, filename='PRISM_PRED_{}_test.hdf'.format(VAR))