def create_M36_M09_lut():
    """Create a nearest-neighbour look-up table from the M36 to the M09 grid.

    For every tile in the M36 tile-coordinate file, the M09 tile with the
    closest ``com_lat`` / ``com_lon`` coordinates is determined. Distance is
    the squared Euclidean distance in coordinate space; there is no
    great-circle or date-line handling.

    The result is written to a CSV file that is indexed like the M36
    tile-coordinate table and holds, in column ``ind09``, the zero-based row
    position of the nearest tile within the M09 tile-coordinate table.
    """

    fout = '/Users/u0116961/data_sets/GEOSldas_runs/LUT_M36_M09_US.csv'

    fname36 = '/Users/u0116961/data_sets/GEOSldas_runs/NLv4_M36_US_SMAP_TB_OL.ldas_tilecoord.bin'
    fname09 = '/Users/u0116961/data_sets/GEOSldas_runs/US_M09_SMAP_OL.ldas_tilecoord.bin'

    io = LDAS_io(exp='US_M36_SMAP_OL')
    dtype, hdr, length = get_template('tilecoord')

    tc36 = io.read_fortran_binary(fname36,
                                  dtype,
                                  hdr=hdr,
                                  length=length,
                                  reg_ftags=True)
    tc09 = io.read_fortran_binary(fname09,
                                  dtype,
                                  hdr=hdr,
                                  length=length,
                                  reg_ftags=True)

    # Work on plain float64 arrays: np.argmin on an ndarray always returns a
    # position (np.argmin on a pandas Series is version dependent), and the
    # distance precision no longer depends on dtype promotion rules.
    lat09 = np.asarray(tc09['com_lat'], dtype=float)
    lon09 = np.asarray(tc09['com_lon'], dtype=float)
    lat36 = np.asarray(tc36['com_lat'], dtype=float)
    lon36 = np.asarray(tc36['com_lon'], dtype=float)

    n_tiles = len(tc36)
    ind09 = np.full(n_tiles, -9999, dtype=np.int64)

    # One M36 tile at a time: a full (M36 x M09) distance matrix could be
    # prohibitively large.
    for pos, (idx, lat, lon) in enumerate(zip(tc36.index, lat36, lon36)):
        print('%i / %i' % (idx, n_tiles))
        ind09[pos] = np.argmin((lat09 - lat)**2 + (lon09 - lon)**2)

    tc36['ind09'] = ind09

    tc36['ind09'].to_csv(fout)
def extract_smap_ldas_data():
    """Collocate SMAP soil moisture retrievals with LDAS surface soil moisture.

    All ``*.h5`` files below the SMAP root directory are scanned. The
    observation time is taken from the file name and rounded to 3 hours.
    Retrievals whose ``retrieval_qual_flag`` equals 0 or 8 are matched, via
    their EASE row/column indices, to the tiles listed in the point-list file.
    For each matched time step the LDAS ``sm_surface`` field of the same tiles
    is read from the corresponding ``ldas_tile_inst_out`` binary file.

    Files for which no LDAS output exists, or which contain no retrieval over
    the requested tiles, are skipped; time steps that never receive data are
    dropped from the output.

    Outputs (written to the platform-specific output directory):
        dates.csv
            Julian date per retained time step, indexed by timestamp,
            without header.
        soil_moisture_smap_ldas.mat
            MATLAB file with the variables ``smap`` and ``ldas``, each of
            shape (n_dates, n_tiles); entries without data are NaN.
    """

    # Path specifications
    if platform.system() == 'Darwin':
        pointlist_file = '/Users/u0116961/data_sets/SMAP/4frederike/tileind_id_lat_lon_Dry_Chaco_tight.txt'
        smap_root = Path('/Users/u0116961/data_sets/SMAP/SPL2SMP.006')
        ldas_root_tc = Path(
            '/Users/u0116961/data_sets/LDASsa_runs/GLOB_M36_7Thv_TWS_ensin_FOV0_M2/output/SMAP_EASEv2_M36_US/rc_out'
        )
        ldas_root_data = Path('/Users/u0116961/data_sets/SMAP/4frederike')
        out_path = Path('/Users/u0116961/data_sets/SMAP/4frederike')
    else:
        pointlist_file = '/staging/leuven/stg_00024/MSC_TMP/frederike/RTM/RTM_CALI/tileind_id_lat_lon_Dry_Chaco_tight.txt'
        smap_root = Path(
            '/staging/leuven/stg_00024/l_data/obs_satellite/SMAP/SPL2SMP.006')
        ldas_root_tc = Path(
            '/scratch/leuven/314/vsc31402/output/GLOB_M36_7Thv_TWS_ensin_FOV0_M2/output/SMAP_EASEv2_M36_GLOB/rc_out'
        )
        ldas_root_data = Path(
            '/scratch/leuven/314/vsc31402/output/GLOB_M36_7Thv_TWS_ensin_FOV0_M2/output/SMAP_EASEv2_M36_GLOB/cat/ens_avg'
        )
        out_path = Path('/staging/leuven/stg_00024/OUTPUT/alexg/4frederike')
    fbase = 'GLOB_M36_7Thv_TWS_ensin_FOV0_M2_run_1.ens_avg.ldas_tile_inst_out'

    # Get SMAP data path and observation dates
    # (the timestamp sits at a fixed offset from the end of the file name)
    files = sorted(smap_root.glob('**/*.h5'))
    dates = pd.to_datetime([str(f)[-29:-14] for f in files]).round('3h')
    dates_u = dates.unique()

    # Setup LDAS interface and get tile_coord LUT
    io = LDAS_io(exp='US_M36_SMOS40_TB_MadKF_DA_it614')
    dtype, _, _ = get_template('xhourly_inst')
    tc_file = ldas_root_tc / 'GLOB_M36_7Thv_TWS_ensin_FOV0_M2.ldas_tilecoord.bin'
    tc = io.read_params('tilecoord', fname=tc_file)
    n_tiles = len(tc)

    # Extract only the relevant domain, as specified by GdL
    plist = pd.read_csv(pointlist_file,
                        names=['tile_idx', 'tile_id', 'lat', 'lon'],
                        sep='\t')
    tc = tc.reindex(plist.tile_idx.values)

    # Row/column keys of the requested tiles and their position within `tc`.
    # They do not depend on the SMAP file, so build them once.
    rowcols_list = [f'{r:03d}{c:03d}' for r, c in zip(tc.j_indg, tc.i_indg)]
    ind_dict_list = {key: pos for pos, key in enumerate(rowcols_list)}

    # Create empty array to be filled w. SMAP and LDAS data
    # (last axis: 0 = SMAP, 1 = LDAS)
    res_arr = np.full((len(dates_u), len(tc), 2), np.nan)

    ind_valid = []  # Keep only dates with valid data!
    for cnt, (f, date) in enumerate(zip(files, dates)):
        print(f'Processing file {cnt} / {len(files)}...')

        # Without matching LDAS output nothing is stored for this file,
        # so check that first and avoid parsing the SMAP file needlessly.
        stamp = date.strftime('%Y%m%d_%H%Mz.bin')
        fname = ldas_root_data / f'Y{date.year}' / f'M{date.month:02d}' / f'{fbase}.{stamp}'
        if not fname.exists():
            continue

        # Read the datasets completely and mask in memory; point selections
        # directly on HDF5 datasets are slow for long index lists.
        with h5py.File(f, mode='r') as arr:
            grp = arr['Soil_Moisture_Retrieval_Data']
            qf = grp['retrieval_qual_flag'][:]
            idx = np.where((qf == 0) | (qf == 8))
            row = grp['EASE_row_index'][:][idx]
            col = grp['EASE_column_index'][:][idx]
            sm = grp['soil_moisture'][:][idx]

        rowcols_smap = [f'{r:03d}{c:03d}' for r, c in zip(row, col)]
        ind_dict_smap = {key: pos for pos, key in enumerate(rowcols_smap)}

        # Grid cells present both in this SMAP file and in the tile list
        inter = ind_dict_smap.keys() & ind_dict_list.keys()
        if not inter:
            continue

        inds_smap = np.array([ind_dict_smap[x] for x in inter])
        inds_list = np.array([ind_dict_list[x] for x in inter])
        srt = np.argsort(inds_smap)

        data_ldas = io.read_fortran_binary(fname,
                                           dtype=dtype,
                                           length=n_tiles)
        # Shift to a 1-based index; presumably this matches the tile indices
        # used as index of `tc` -- verify against the point-list convention.
        data_ldas.index += 1

        dt_idx = np.where(dates_u == date)[0][0]
        if dt_idx not in ind_valid:
            ind_valid.append(dt_idx)
        res_arr[dt_idx, inds_list[srt], 0] = sm[inds_smap[srt]]
        res_arr[dt_idx, inds_list[srt], 1] = data_ldas.reindex(
            tc.iloc[inds_list[srt]].index)['sm_surface'].values

    res_arr = res_arr[ind_valid, :, :]
    dates_u = dates_u[ind_valid]

    # Save date information
    pd.Series(dates_u.to_julian_date(),
              index=dates_u).to_csv(out_path / 'dates.csv', header=False)

    # Save output for Matlab
    mat_vars = {'smap': res_arr[:, :, 0], 'ldas': res_arr[:, :, 1]}
    sio.savemat(out_path / 'soil_moisture_smap_ldas.mat', mat_vars)
# Example #3
# 0
def plot_improvement_vs_uncertainty_update(iteration, err_iteration=531):
    """Plot the change in in-situ R2 against the mean Tb error estimate.

    For each validation station in ``iter_<iteration>/validation/insitu_TCA.csv``
    the difference between the ``R2_model_DA_madkf_*`` and
    ``R2_model_DA_const_err_*`` columns is plotted against the average of the
    four gap-filled SMOS Tb error fields (ascending/descending, ``err_Tbh`` /
    ``err_Tbv``) at the station's tile, together with a linear least-squares
    fit. One panel is drawn per variable (``sm_surface``, ``sm_rootzone``;
    rows) and mode (``absolute``, ``shortterm``, ``longterm``; columns).

    The figure is saved to
    ``/work/MadKF/CLSM/iter_<iteration>/validation/plots/gain_vs_err.png``.

    Parameters
    ----------
    iteration : int
        Iteration whose validation results are plotted and under whose
        directory the figure is stored.
    err_iteration : int, optional
        Iteration from which the Tb error files are read (default: 531).
    """

    io = LDAS_io()

    root = Path('/work/MadKF/CLSM/iter_%i' % iteration)

    res = pd.read_csv(root / 'validation' / 'insitu_TCA.csv')
    res.index = res.network
    tilenum = np.vectorize(io.grid.colrow2tilenum)(res.ease_col.values,
                                                   res.ease_row.values)

    # The error files come from a (possibly different) fixed iteration. Keep
    # this path separate from `root` so that the figure is stored with the
    # validation results of `iteration`.
    err_dir = (Path('/work/MadKF/CLSM/iter_%i' % err_iteration) /
               'absolute' / 'error_files' / 'gapfilled')
    fA = err_dir / 'SMOS_fit_Tb_A.bin'
    fD = err_dir / 'SMOS_fit_Tb_D.bin'

    dtype, hdr, length = template_error_Tb40()
    imgA = io.read_fortran_binary(fA, dtype, hdr=hdr, length=length)
    imgD = io.read_fortran_binary(fD, dtype, hdr=hdr, length=length)
    # Shift to a 1-based index so that rows can be selected by tile number
    # (assumes colrow2tilenum returns 1-based tile numbers -- TODO confirm).
    imgA.index += 1
    imgD.index += 1

    # Average over both orbit directions and both polarizations
    perts = (imgA.loc[tilenum, 'err_Tbh'].values +
             imgA.loc[tilenum, 'err_Tbv'].values +
             imgD.loc[tilenum, 'err_Tbh'].values +
             imgD.loc[tilenum, 'err_Tbv'].values) / 4

    fontsize = 14
    fit_x = np.arange(12)

    f = plt.figure(figsize=(20, 12))

    for i, var in enumerate(['sm_surface', 'sm_rootzone']):
        for j, mode in enumerate(['absolute', 'shortterm', 'longterm']):

            tag1 = 'R2_model_DA_madkf_' + mode + '_' + var
            tag2 = 'R2_model_DA_const_err_' + mode + '_' + var
            dR2 = (res[tag1] - res[tag2]).values

            ax = f.add_subplot(2, 3, j + 1 + i * 3)
            ax.axhline(color='black', linestyle='--', linewidth=1)
            ax.plot(perts, dR2, 'o', color='orange', markeredgecolor='black',
                    markeredgewidth=0.5, markersize=6)

            # A line fit needs at least two points that are finite in both
            # x and y; otherwise only the scatter is shown.
            valid = np.isfinite(perts) & np.isfinite(dR2)
            if valid.sum() >= 2:
                slope, intercept = np.polyfit(perts[valid], dR2[valid], 1)
                ax.plot(fit_x, slope * fit_x + intercept, '--',
                        color='black', linewidth=3)

            ax.set_xlim(0, 11)
            ax.set_ylim(-1, 1)

            # Titles on the top row only, x tick labels on the bottom row only
            if i == 0:
                ax.set_title(mode, fontsize=fontsize)
                ax.tick_params(labelbottom=False)
            else:
                ax.tick_params(axis='x', labelsize=fontsize - 2)

            # y label and y tick labels on the left column only
            if j == 0:
                ax.set_ylabel(var, fontsize=fontsize)
                ax.tick_params(axis='y', labelsize=fontsize - 2)
            else:
                ax.tick_params(labelleft=False)

    fout = root / 'validation' / 'plots' / 'gain_vs_err.png'
    fout.parent.mkdir(parents=True, exist_ok=True)
    f.savefig(fout, dpi=300, bbox_inches='tight')
    plt.close(f)