예제 #1
0
def plot_tca_uncertainties(dir_out):
    """Map the TCA error standard deviations of SMAP and CLSM.

    Draws a 2 x 3 grid of EASE-grid maps: one row per sensor (SMAP, CLSM) and
    one column per signal component (anomalies, LF signal, HF signal). Each
    panel shows the ``ubrmse_grid_<mode>_m_<sensor>_tc_ASCAT_SMAP_CLSM``
    column of the validation result file, restricted to cells where all three
    pairwise TCA correlations exceed ``thres``. The figure is written to
    ``dir_out / 'tca_uncertainties.png'``.

    Parameters
    ----------
    dir_out : pathlib.Path
        Output directory; joined with ``/``.
    """
    sensors = ['ASCAT', 'SMAP', 'CLSM']

    res = pd.read_csv(
        '/Users/u0116961/Documents/work/MadKF/CLSM/SM_err_ratio/GEOSldas/sm_validation/Pcorr/result.csv',
        index_col=0)

    figsize = (15, 5)
    fontsize = 12
    cb = [0, 0.04]
    # Minimum pairwise TCA correlation for a cell to be plotted.
    thres = 0.2

    modes = ['anom_lst', 'anom_lt', 'anom_st']
    titles = ['Anomalies', 'LF signal', 'HF signal']
    # Raw strings: the LaTeX backslashes must not be parsed as escapes.
    labels = [
        r'$\widehat{std}(\epsilon_{\Theta,smap})$',
        r'$\widehat{std}(\epsilon_{\Theta,clsm})$'
    ]

    f = plt.figure(figsize=figsize)

    n = 0
    for s, l in zip(sensors[1:], labels):
        for mode, title in zip(modes, titles):
            n += 1

            plt.subplot(2, 3, n)

            tag = 'ubrmse_grid_' + mode + '_m_' + s + '_tc_ASCAT_SMAP_CLSM'

            # Require all three sensor pairs to be sufficiently correlated.
            r_asc_smap = res[f'r_grid_{mode}_p_ASCAT_SMAP']
            r_asc_clsm = res[f'r_grid_{mode}_p_ASCAT_CLSM']
            r_smap_clsm = res[f'r_grid_{mode}_p_SMAP_CLSM']
            ind_valid = res[(r_asc_smap > thres) & (r_asc_clsm > thres) &
                            (r_smap_clsm > thres)].index

            im_r = plot_ease_img(res.reindex(ind_valid),
                                 tag,
                                 fontsize=fontsize,
                                 cbrange=cb,
                                 cmap='viridis',
                                 print_mean=True)

            # Column titles on the top (SMAP) row, row labels on the first
            # (anomaly) column.
            if s == 'SMAP':
                plt.title(title, fontsize=fontsize)
            if mode == 'anom_lst':
                plt.ylabel(l, fontsize=fontsize)

    plot_centered_cbar(f, im_r, 3, fontsize=fontsize - 2, bottom=0.07)

    fout = dir_out / 'tca_uncertainties.png'
    f.savefig(fout, dpi=300, bbox_inches='tight')
    plt.close()
예제 #2
0
def plot_orthogonality_check(dir_out):
    """Map how far the anomaly TCA error departs from LF/HF orthogonality.

    For SMAP and CLSM, plots
    ``ubrmse(anom_lst) - sqrt(ubrmse(anom_lt)**2 + ubrmse(anom_st)**2)``,
    which is zero where the LF and HF error variances add up exactly to the
    anomaly error variance. Only cells whose three pairwise TCA correlations
    exceed ``thres`` for every signal component are shown. The figure is
    written to ``dir_out / 'orthogonality_verification.png'``.

    Parameters
    ----------
    dir_out : pathlib.Path
        Output directory; joined with ``/``.
    """
    sensors = ['ASCAT', 'SMAP', 'CLSM']

    res = pd.read_csv(
        '/Users/u0116961/Documents/work/MadKF/CLSM/SM_err_ratio/GEOSldas/sm_validation/Pcorr/result.csv',
        index_col=0)

    figsize = (11, 3)
    fontsize = 10
    cb = [-0.02, 0.02]
    # Minimum pairwise TCA correlation for a cell to be plotted.
    thres = 0.2

    f = plt.figure(figsize=figsize)

    # A cell is valid only if, for every signal component, all three
    # sensor pairs are sufficiently correlated.
    valid = pd.Series(True, index=res.index)
    for mode in ['anom_lt', 'anom_st', 'anom_lst']:
        r_asc_smap = res[f'r_grid_{mode}_p_ASCAT_SMAP']
        r_asc_clsm = res[f'r_grid_{mode}_p_ASCAT_CLSM']
        r_smap_clsm = res[f'r_grid_{mode}_p_SMAP_CLSM']
        valid &= ((r_asc_smap > thres) & (r_asc_clsm > thres) &
                  (r_smap_clsm > thres))

    ind_valid = res[valid].index

    for i, s in enumerate(sensors[1:]):

        plt.subplot(1, 2, i + 1)

        tag_lt = 'ubrmse_grid_anom_lt_m_' + s + '_tc_ASCAT_SMAP_CLSM'
        tag_st = 'ubrmse_grid_anom_st_m_' + s + '_tc_ASCAT_SMAP_CLSM'
        tag_lst = 'ubrmse_grid_anom_lst_m_' + s + '_tc_ASCAT_SMAP_CLSM'

        # Anomaly error minus the error implied by uncorrelated LF/HF errors.
        res['diff'] = (res[tag_lst] - np.sqrt(res[tag_lt]**2 + res[tag_st]**2))

        im_r = plot_ease_img(res.reindex(ind_valid),
                             'diff',
                             fontsize=fontsize + 2,
                             cbrange=cb,
                             cmap=cc.cm.bjy,
                             print_meanstd=True)

        plt.title(s, fontsize=fontsize)

    plot_centered_cbar(f,
                       im_r,
                       2,
                       fontsize=fontsize - 2,
                       bottom=0.00,
                       hspace=0.030,
                       pad=0.02,
                       wdth=0.04)

    fout = dir_out / 'orthogonality_verification.png'
    f.savefig(fout, dpi=300, bbox_inches='tight')
    plt.close()
예제 #3
0
def plot_ensvar_ratio(dir_out):
    """Map the ratio of DA observation-error variance to OL forecast variance.

    For each DA experiment in ``modes`` (columns) and each of the four
    observation species listed in ``labels`` (rows), plots in dB
    ``10 * log10(nanmean(obs_obsvar_DA / obs_fcstvar_OL))`` on the lat/lon
    grid. The 4 x 5 panel figure is written to
    ``dir_out / 'ensvar_ratio.png'``.

    Parameters
    ----------
    dir_out : pathlib.Path
        Output directory; joined with ``/``.
    """
    f = plt.figure(figsize=(25, 12))

    fontsize = 14
    cb = [-10, 10]
    cmap = cc.cm.bjy

    modes = ['4K', 'abs', 'anom_lst', 'anom_lt', 'anom_st']
    titles = [
        '4K benchmark', 'Total signal', 'Anomalies', 'LF signal', 'HF signal'
    ]
    # Row labels, indexed by observation species (second array axis).
    labels = [
        '(V-pol, Asc.)', '(V-pol, Dsc.)', '(H-pol, Asc.)', '(H-pol, Dsc.)'
    ]

    ios = [
        GEOSldas_io('ObsFcstAna',
                    exp=f'NLv4_M36_US_DA_SMAP_Pcorr_{mode}').timeseries
        for mode in modes
    ]
    io_ol = GEOSldas_io('ObsFcstAna', exp='NLv4_M36_US_OL_Pcorr').timeseries

    for i, (io_da, tit) in enumerate(zip(ios, titles)):
        for spc, label in enumerate(labels):
            # Mean over the first (presumably time) axis, skipping NaNs.
            tmp1 = io_da['obs_obsvar'][:, spc, :, :].values
            tmp2 = io_ol['obs_fcstvar'][:, spc, :, :].values
            avg = np.nanmean(tmp1 / tmp2, axis=0)

            # One panel per (species, experiment): rows are species,
            # columns are experiments. This must run inside the species
            # loop, otherwise only the last species is drawn.
            plt.subplot(4, 5, spc * 5 + i + 1)
            img = plot_latlon_img(10 * np.log10(avg),
                                  io_da.lon.values,
                                  io_da.lat.values,
                                  fontsize=fontsize,
                                  cbrange=cb,
                                  cmap=cmap,
                                  plot_cb=False)
            if spc == 0:
                plt.title(tit, fontsize=fontsize)
            if i == 0:
                plt.ylabel(label, fontsize=fontsize)

    plot_centered_cbar(f, img, 5, fontsize=fontsize - 2)

    fout = dir_out / 'ensvar_ratio.png'
    f.savefig(fout, dpi=300, bbox_inches='tight')
    plt.close()
예제 #4
0
def plot_uncertainty_ratios(dir_out):
    """Compare TCA error-variance ratios with ensemble-variance ratios.

    Top row (columns 2-4): ``10 * log10(P / R)`` where ``P`` and ``R`` are
    the squared TCA ubRMSE of CLSM and SMAP for each signal component.
    Bottom row (columns 1-4): ``10 * log10`` of the ratio of the open-loop
    ``obs_fcstvar`` to the DA run's ``obs_obsvar``, averaged with ``nanmean``
    over the first axis and then over the four species, for each DA
    experiment. Except for the 4K benchmark, cells are restricted to those
    whose three pairwise TCA correlations exceed ``thres``. The figure is
    written to ``dir_out / 'uncertainty_ratio.png'``.

    Parameters
    ----------
    dir_out : pathlib.Path
        Output directory; joined with ``/``.
    """
    res = pd.read_csv(
        '/Users/u0116961/Documents/work/MadKF/CLSM/SM_err_ratio/GEOSldas/sm_validation/Pcorr/result.csv',
        index_col=0)
    # The TCA correlations live in the same file. Only its r_grid_* columns
    # are read for masking, so one copy serves both purposes.
    res_tc = res

    tg = GEOSldas_io().grid.tilegrids
    res_cols = res.col.values - tg.loc['domain', 'i_offg']
    res_rows = res.row.values - tg.loc['domain', 'j_offg']

    figsize = (16, 4)
    fontsize = 10
    cb = [-10, 10]
    cmap = cc.cm.bjy
    # Minimum pairwise TCA correlation for a cell to be plotted.
    thres = 0.2

    modes = ['4K', 'anom_lst', 'anom_lt', 'anom_st']
    titles = ['4K benchmark', 'Anomalies', 'LF signal', 'HF signal']

    def valid_index(mode):
        """Return the cells where all three pairwise TCA correlations exceed thres."""
        r_asc_smap = res_tc[f'r_grid_{mode}_p_ASCAT_SMAP']
        r_asc_clsm = res_tc[f'r_grid_{mode}_p_ASCAT_CLSM']
        r_smap_clsm = res_tc[f'r_grid_{mode}_p_SMAP_CLSM']
        return res_tc[(r_asc_smap > thres) & (r_asc_clsm > thres) &
                      (r_smap_clsm > thres)].index

    ios = [
        GEOSldas_io('ObsFcstAna',
                    exp=f'NLv4_M36_US_DA_SMAP_Pcorr_{mode}').timeseries
        for mode in modes
    ]
    io_ol = GEOSldas_io('ObsFcstAna', exp='NLv4_M36_US_OL_Pcorr').timeseries

    f = plt.figure(figsize=figsize)

    # Top row: TCA-based ratio. The 4K benchmark (n == 0) has no TCA column.
    for n, (mode, title) in enumerate(zip(modes, titles)):
        if n == 0:
            continue

        plt.subplot(2, 4, n + 1)
        tagP = 'ubrmse_grid_' + mode + '_m_CLSM_tc_ASCAT_SMAP_CLSM'
        tagR = 'ubrmse_grid_' + mode + '_m_SMAP_tc_ASCAT_SMAP_CLSM'
        res['tmp'] = 10 * np.log10(res[tagP]**2 / res[tagR]**2)

        img = plot_ease_img(res.reindex(valid_index(mode)),
                            'tmp',
                            fontsize=fontsize,
                            cbrange=cb,
                            cmap=cmap,
                            plot_cb=False)
        plt.title(title, fontsize=fontsize)
        if n == 1:
            plt.ylabel('TCA unc. ratio', fontsize=fontsize)

    # Bottom row: ensemble-variance ratio for every experiment.
    for n, (mode, title, io_da) in enumerate(zip(modes, titles, ios)):

        ind_valid = res.index if mode == '4K' else valid_index(mode)

        avg = np.full(io_da['obs_obsvar'].shape[1:], np.nan)
        for spc in range(4):
            tmp1 = io_da['obs_obsvar'][:, spc, :, :].values
            tmp2 = io_ol['obs_fcstvar'][:, spc, :, :].values
            avg[spc, :, :] = np.nanmean(tmp2 / tmp1, axis=0)
        avg = np.nanmean(avg, axis=0)

        # Sample the (row, col) grid at each result cell, in domain indices.
        res['avg'] = 10 * np.log10(avg[res_rows, res_cols])

        plt.subplot(2, 4, n + 5)
        img = plot_ease_img(res.reindex(ind_valid),
                            'avg',
                            fontsize=fontsize,
                            cbrange=cb,
                            cmap=cmap,
                            plot_cb=False)
        if n == 0:
            # The 4K column has no top-row panel, so it is titled here.
            plt.title(title, fontsize=fontsize)
            plt.ylabel('Ens. var. ratio', fontsize=fontsize)

    plot_centered_cbar(f, img, 4, fontsize=fontsize, bottom=0.07)

    fout = dir_out / 'uncertainty_ratio.png'
    f.savefig(fout, dpi=300, bbox_inches='tight')
    plt.close()
예제 #5
0
def plot_predicted_skillgain(dir_out):
    """Predict the anomaly-skill gain of assimilating SMAP into CLSM from TCA errors.

    Error variances are taken as squared TCA ubRMSE values from the validation
    ``result.csv``: ``R_*`` for SMAP and ``P_*`` for CLSM, per signal
    component (``anom_lt``, ``anom_st``, ``anom_lst``). For each scheme an
    updated error variance ``P_upd`` is computed, converted to a predicted
    squared correlation ``R2upd = 1 / (1 + P_upd / SIG)``, and the skill gain
    is stored as ``sqrt(R2upd) - sqrt(R2)``. The gains are mapped in a 2 x 4
    figure that is shown with ``plt.show()`` rather than saved.

    Parameters
    ----------
    dir_out : pathlib.Path
        Not used by this function.
    """

    res = pd.read_csv(
        '/Users/u0116961/Documents/work/MadKF/CLSM/SM_err_ratio/GEOSldas/sm_validation/Pcorr/result.csv',
        index_col=0)

    # Squared TCA ubRMSE: error variances of SMAP (R) and CLSM (P).
    R_anom_lst = res[f'ubrmse_grid_anom_lst_m_SMAP_tc_ASCAT_SMAP_CLSM']**2
    P_anom_lst = res[f'ubrmse_grid_anom_lst_m_CLSM_tc_ASCAT_SMAP_CLSM']**2

    R_anom_lt = res[f'ubrmse_grid_anom_lt_m_SMAP_tc_ASCAT_SMAP_CLSM']**2
    P_anom_lt = res[f'ubrmse_grid_anom_lt_m_CLSM_tc_ASCAT_SMAP_CLSM']**2

    R_anom_st = res[f'ubrmse_grid_anom_st_m_SMAP_tc_ASCAT_SMAP_CLSM']**2
    P_anom_st = res[f'ubrmse_grid_anom_st_m_CLSM_tc_ASCAT_SMAP_CLSM']**2

    # LF/HF error cross-covariance from var(lst) = var(lt) + var(st) + 2 cov.
    # This assumes the anomaly error is the sum of the LF and HF errors
    # (TODO confirm).
    R_lt_st_arr = (R_anom_lst - (R_anom_lt + R_anom_st)) / 2
    P_lt_st_arr = (P_anom_lst - (P_anom_lt + P_anom_st)) / 2

    # The same CLSM cross-covariance expressed as a correlation coefficient.
    P_lt_st_arr_rho = P_lt_st_arr / np.sqrt(P_anom_lt * P_anom_st)

    # Baseline estimates
    # R2: CLSM anomaly squared correlation; SNR = R2 / (1 - R2);
    # SIG = SNR * P_anom_lst is used below as the signal variance.
    # R2 = res[f'r2_grid_abs_m_CLSM_tc_ASCAT_SMAP_CLSM']
    R2 = res[f'r2_grid_anom_lst_m_CLSM_tc_ASCAT_SMAP_CLSM']
    SNR = R2 / (1 - R2)
    SIG = SNR * P_anom_lst

    # Index i selects the scheme. 'anom_lst' appears twice because indices
    # 2 and 3 are two different ways of treating the anomalies.
    modes = ['anom_lt', 'anom_st', 'anom_lst', 'anom_lst']
    titles = [
        'LF signal', 'HF signal', 'Anomalies (lumped)', 'Anomalies (joint)'
    ]
    # Results are written to integer columns i and to '{i}_4K' / 'P_upd_{i}'.
    # The mode-named columns created here are not written below.
    result = pd.DataFrame(index=res.index, columns=modes)
    result['row'] = res.row
    result['col'] = res.col

    for i, mode in enumerate(modes):

        P = res[f'ubrmse_grid_{mode}_m_CLSM_tc_ASCAT_SMAP_CLSM']**2
        R = res[f'ubrmse_grid_{mode}_m_SMAP_tc_ASCAT_SMAP_CLSM']**2

        # Gain derived from the TCA error variances.
        K = P / (R + P)

        if i < 3:
            # NOTE(review): K_4K is computed but never used. The '{i}_4K'
            # columns are computed with the TCA gain K. Also, the update
            # K * R + (1 - K) * P differs from the (1 - K) * P used below.
            # Confirm intent.
            K_4K = P / (4**2 + P)

            P_upd = K * R + (1 - K) * P
            NSR_upd = P_upd / SIG
            R2upd = 1 / (1 + NSR_upd)

            result[f'{i}_4K'] = np.sqrt(R2upd) - np.sqrt(R2)

        if i < 2:
            # Single signal assimilation
            # R2 = res[f'r2_grid_{mode}_m_CLSM_tc_ASCAT_SMAP_CLSM']
            # NSR = (1 - R2) / R2
            # R2upd = 1 / (1 + (1 - K) * NSR)

            P_upd = (1 - K) * P
            NSR_upd = P_upd / SIG
            R2upd = 1 / (1 + NSR_upd)

            result[i] = np.sqrt(R2upd) - np.sqrt(R2)
            # Kept for the i == 3 combination below.
            result[f'P_upd_{i}'] = P_upd

        elif i == 2:
            # Joint signal assimilation
            # NOTE(review): titles[2] labels this panel 'Anomalies (lumped)'.
            # A single gain K (from the anom_lst variances) weights both the
            # LF and HF components, and P_upd = A S A^T uses a block-diagonal
            # S of SMAP and CLSM LF/HF (co)variances.
            result[i] = np.nan

            # Per-cell loop; prints one progress line per grid cell.
            for cnt, idx in enumerate(res.index):
                print(f'{cnt} / {len(res)}')

                R11 = R_anom_lt.loc[idx]
                R22 = R_anom_st.loc[idx]
                R12 = R_lt_st_arr.loc[idx]

                P11 = P_anom_lt.loc[idx]
                P22 = P_anom_st.loc[idx]
                P12 = P_lt_st_arr.loc[idx]

                S = np.matrix([[R11, R12, 0, 0], [R12, R22, 0, 0],
                               [0, 0, P11, P12], [0, 0, P12, P22]])

                A = np.matrix(
                    [K.loc[idx], K.loc[idx], 1 - K.loc[idx], 1 - K.loc[idx]])
                P_upd = (A * S * A.T)[0, 0]
                NSR_upd = P_upd / SIG.loc[idx]
                R2upd = 1 / (1 + NSR_upd)

                result.loc[idx, i] = np.sqrt(R2upd) - np.sqrt(R2.loc[idx])

        else:
            # i == 3 (titles[3], 'joint'): combine the separately updated LF
            # and HF error variances (i = 0, 1) using CLSM's LF/HF error
            # correlation coefficient.
            P_upd = result[f'P_upd_0'] + result[
                f'P_upd_1'] + 2 * P_lt_st_arr_rho * np.sqrt(
                    result[f'P_upd_0'] * result[f'P_upd_1'])
            NSR_upd = P_upd / SIG
            R2upd = 1 / (1 + NSR_upd)
            result[i] = np.sqrt(R2upd) - np.sqrt(R2)

    f = plt.figure(figsize=(23, 7))

    # Top row: skill gain for all four schemes.
    # Bottom row: the '_4K' variants of the first three.
    for i, title in enumerate(titles):
        plt.subplot(2, 4, i + 1)
        im = plot_ease_img(result,
                           i,
                           fontsize=12,
                           cbrange=[-0.2, 0.2],
                           cmap=cc.cm.bjy,
                           log_scale=False,
                           title=title,
                           plot_cb=False)

        if i < 3:
            plt.subplot(2, 4, i + 5)
            im = plot_ease_img(result,
                               f'{i}_4K',
                               fontsize=12,
                               cbrange=[-0.2, 0.2],
                               cmap=cc.cm.bjy,
                               log_scale=False,
                               title=title + ' (4K)',
                               plot_cb=False)

    plot_centered_cbar(f, im, 3, fontsize=10)

    plt.tight_layout()
    plt.show()
예제 #6
0
def plot_perturbations(dir_out):
    """Map the fitted observation-error estimates for each signal component.

    For each mode (anomalies, LF, HF), reads the ascending and descending
    ``SMOS_fit_Tb_{A,D}.bin`` files and plots ``err_Tbv`` / ``err_Tbh`` in a
    3 x 4 panel figure (rows = modes; columns = V/H pol x Asc/Dsc). Tiles
    are restricted to the LUT tiles and then to cells whose three pairwise
    TCA correlations exceed ``thres``. The figure is written to
    ``dir_out / 'perturbations.png'``.

    Parameters
    ----------
    dir_out : pathlib.Path
        Output directory; joined with ``/``.
    """
    root = Path(
        '/Users/u0116961/Documents/work/MadKF/CLSM/SM_err_ratio/GEOSldas/observation_perturbations/Pcorr'
    )
    res_tc = pd.read_csv(
        '/Users/u0116961/Documents/work/MadKF/CLSM/SM_err_ratio/GEOSldas/sm_validation/Pcorr/result.csv',
        index_col=0)
    io = GEOSldas_io('ObsFcstAna')
    io2 = LDASsa_io('ObsFcstAna')

    lut = pd.read_csv(Paths().lut, index_col=0)
    ind = np.vectorize(io.grid.colrow2tilenum)(lut.ease2_col,
                                               lut.ease2_row,
                                               local=False)

    dtype, hdr, length = template_error_Tb40()

    f = plt.figure(figsize=(22, 8))

    fontsize = 14
    cbrange = [0, 8]
    cmap = cc.cm.bjy_r
    # Minimum pairwise TCA correlation for a cell to be plotted.
    thres = 0.2

    modes = ['anom_lst', 'anom_lt', 'anom_st']
    titles = ['Anomalies', 'LF signal', 'HF signal']

    for i, (mode, title) in enumerate(zip(modes, titles)):

        imgA = io2.read_fortran_binary(root / mode / 'SMOS_fit_Tb_A.bin',
                                       dtype, hdr=hdr, length=length)
        imgD = io2.read_fortran_binary(root / mode / 'SMOS_fit_Tb_D.bin',
                                       dtype, hdr=hdr, length=length)

        # Shift the index by one to match the tile numbers in `ind`.
        imgA.index += 1
        imgD.index += 1

        # Cells where all three sensor pairs are sufficiently correlated,
        # converted from EASE (col, row) to tile numbers.
        r_asc_smap = res_tc[f'r_grid_{mode}_p_ASCAT_SMAP']
        r_asc_clsm = res_tc[f'r_grid_{mode}_p_ASCAT_CLSM']
        r_smap_clsm = res_tc[f'r_grid_{mode}_p_SMAP_CLSM']
        ind_valid = res_tc[(r_asc_smap > thres) & (r_asc_clsm > thres) &
                           (r_smap_clsm > thres)].index
        ind_valid = np.vectorize(io.grid.colrow2tilenum)(
            res_tc.loc[ind_valid, 'col'].values,
            res_tc.loc[ind_valid, 'row'].values,
            local=False)

        # Restrict to LUT tiles first, then to the valid TCA cells.
        imgA = imgA.reindex(ind).reindex(ind_valid)
        imgD = imgD.reindex(ind).reindex(ind_valid)

        panels = [
            (imgA, 'err_Tbv', r'$\hat{R}$ (V-pol, Asc.)'),
            (imgD, 'err_Tbv', r'$\hat{R}$ (V-pol, Dsc.)'),
            (imgA, 'err_Tbh', r'$\hat{R}$ (H-pol, Asc.)'),
            (imgD, 'err_Tbh', r'$\hat{R}$ (H-pol, Dsc.)'),
        ]
        for k, (img, var, panel_title) in enumerate(panels):
            plt.subplot(3, 4, i * 4 + k + 1)
            im_k = plot_ease_img2(img,
                                  var,
                                  cbrange=cbrange,
                                  cmap=cmap,
                                  io=io,
                                  plot_cmap=False)
            if i == 0:
                plt.title(panel_title, fontsize=fontsize)
            if k == 0:
                # First column: keep the mappable for the colorbar and
                # label the row.
                im = im_k
                plt.ylabel(title, fontsize=fontsize)

    plot_centered_cbar(f, im, 4, fontsize=fontsize - 4, bottom=0.07)

    f.savefig(dir_out / 'perturbations.png', dpi=300, bbox_inches='tight')
    plt.close()
예제 #7
0
def plot_ascat_eval_relative(res_path, dir_out):
    """Map ASCAT-based skill differences of DA runs relative to OL and 4K.

    Rows are the anomaly, LF and HF skill. For each reference run (OL, 4K)
    there are two columns: individual assimilation (a mode-specific run) and
    joint assimilation (``Pcorr_LTST``). Each panel shows
    ``ana_r_corr_<run>_<mode> - ana_r_corr_<ref>_<mode>`` after clipping
    both correlations to [0, 1], restricted to cells whose three pairwise
    TCA correlations exceed ``thres``. The figure is written to
    ``dir_out / 'ascat_eval_rel.png'``.

    Parameters
    ----------
    res_path : pathlib.Path
        Directory containing ``ascat_eval.csv``.
    dir_out : pathlib.Path
        Output directory; joined with ``/``.
    """
    refs = ['Pcorr_OL', 'Pcorr_4K']
    # Index 0: per-mode individual runs (indexed by row); index 1: joint run.
    runs = [['Pcorr_anom_lst', 'Pcorr_anom_lt_ScDY', 'Pcorr_anom_st_ScYH'],
            'Pcorr_LTST']

    titles = [
        'R$_{TC}$ - R$_{OL}$ (Individual)', 'R$_{TC}$ - R$_{OL}$ (Joint)',
        'R$_{TC}$ - R$_{4K}$ (Individual)', 'R$_{TC}$ - R$_{4K}$ (Joint)'
    ]

    res = pd.read_csv(res_path / 'ascat_eval.csv', index_col=0)

    res_tc = pd.read_csv(
        '/Users/u0116961/Documents/work/MadKF/CLSM/SM_err_ratio/GEOSldas/sm_validation/Pcorr/result.csv',
        index_col=0)

    modes = ['anom_lst', 'anom_lt', 'anom_st']
    labels = ['Anomaly skill', 'LF skill', 'HF skill']

    cb_r = [-0.2, 0.2]

    fontsize = 14
    # Minimum pairwise TCA correlation for a cell to be plotted.
    thres = 0.2

    f = plt.figure(figsize=(22, 8))

    for i, (m, label) in enumerate(zip(modes, labels)):

        # The validity mask depends only on the signal component.
        r_asc_smap = res_tc[f'r_grid_{m}_p_ASCAT_SMAP']
        r_asc_clsm = res_tc[f'r_grid_{m}_p_ASCAT_CLSM']
        r_smap_clsm = res_tc[f'r_grid_{m}_p_SMAP_CLSM']
        ind_valid = res_tc[(r_asc_smap > thres) & (r_asc_clsm > thres) &
                           (r_smap_clsm > thres)].index

        for j, ref in enumerate(refs):

            ref_col = f'ana_r_corr_{ref}_{m}'

            # Clip into [0, 1] by assigning back to the frame. Chained
            # indexing (res[c][mask] = v) does not write through under
            # pandas Copy-on-Write.
            res[ref_col] = res[ref_col].clip(0, 1)

            for k, run in enumerate(runs):

                r = run if k == 1 else run[i]

                ax = plt.subplot(3, 4, i * 4 + j * 2 + k + 1)

                col = f'ana_r_corr_{r}_{m}'
                res[col] = res[col].clip(0, 1)

                res['diff'] = res[col] - res[ref_col]

                title = titles[j * 2 + k] if i == 0 else ''
                ylabel = label if (j == 0 and k == 0) else ''

                im = plot_ease_img(res.reindex(ind_valid),
                                   'diff',
                                   title=title,
                                   cmap=cc.cm.bjy,
                                   cbrange=cb_r,
                                   fontsize=fontsize,
                                   print_mean=True,
                                   plot_cb=False)
                ax.set_ylabel(ylabel)

    # One shared colorbar for the whole figure (drawn once, after all rows).
    plot_centered_cbar(f, im, 4, fontsize=fontsize - 2, pad=0.02)

    f.savefig(dir_out / 'ascat_eval_rel.png', dpi=300, bbox_inches='tight')
    plt.close()
예제 #8
0
def plot_ascat_eval_absolute(res_path, dir_out):
    """Map ASCAT-based anomaly, LF and HF skill of OL and DA runs.

    Rows are the anomaly, LF and HF skill. Columns are open-loop, 4K
    benchmark, individual assimilation (a mode-specific run) and joint
    assimilation. Each panel shows ``ana_r_corr_<run>_<mode>`` clipped to
    [0, 1], restricted to cells whose three pairwise TCA correlations exceed
    ``thres``. The figure is written to ``dir_out / 'ascat_eval_abs.png'``.

    Parameters
    ----------
    res_path : pathlib.Path
        Directory containing ``ascat_eval.csv``.
    dir_out : pathlib.Path
        Output directory; joined with ``/``.
    """
    # Entry 2 holds the per-mode individual runs (indexed by row).
    runs = [
        'Pcorr_OL', 'Pcorr_4K',
        ['Pcorr_anom_lst', 'Pcorr_anom_lt_ScDY', 'Pcorr_anom_st_ScYH'],
        'Pcorr_LTST'
    ]
    titles = [
        'Open-loop', '4K Benchmark', 'Individual assimilation',
        'Joint assimilation'
    ]

    res = pd.read_csv(res_path / 'ascat_eval.csv', index_col=0)
    res_tc = pd.read_csv(
        '/Users/u0116961/Documents/work/MadKF/CLSM/SM_err_ratio/GEOSldas/sm_validation/Pcorr/result.csv',
        index_col=0)

    modes = ['anom_lst', 'anom_lt', 'anom_st']
    row_labels = ['Anomaly skill', 'LF skill', 'HF skill']

    cb_r = [0.6, 1]

    fontsize = 14
    # Minimum pairwise TCA correlation for a cell to be plotted.
    thres = 0.2

    f = plt.figure(figsize=(22, 8))

    for i, (m, row_label) in enumerate(zip(modes, row_labels)):

        # The validity mask depends only on the signal component.
        r_asc_smap = res_tc[f'r_grid_{m}_p_ASCAT_SMAP']
        r_asc_clsm = res_tc[f'r_grid_{m}_p_ASCAT_CLSM']
        r_smap_clsm = res_tc[f'r_grid_{m}_p_SMAP_CLSM']
        ind_valid = res_tc[(r_asc_smap > thres) & (r_asc_clsm > thres) &
                           (r_smap_clsm > thres)].index

        for cnt, (run, tit) in enumerate(zip(runs, titles)):

            r = run[i] if cnt == 2 else run

            ax = plt.subplot(3, 4, i * 4 + cnt + 1)

            col = f'ana_r_corr_{r}_{m}'

            # Clip into [0, 1] by assigning back to the frame. Chained
            # indexing (res[c][mask] = v) does not write through under
            # pandas Copy-on-Write.
            res[col] = res[col].clip(0, 1)

            title = tit if i == 0 else ''
            ylabel = row_label if cnt == 0 else ''

            im = plot_ease_img(res.reindex(ind_valid),
                               col,
                               title=title,
                               cmap='viridis',
                               cbrange=cb_r,
                               fontsize=fontsize,
                               print_mean=True,
                               plot_cb=False)
            ax.set_ylabel(ylabel)

    # One shared colorbar for the whole figure (drawn once, after all rows).
    plot_centered_cbar(f, im, 4, fontsize=fontsize - 2, pad=0.02)

    f.savefig(dir_out / 'ascat_eval_abs.png', dpi=300, bbox_inches='tight')
    plt.close()
예제 #9
0
def plot_filter_diagnostics(res_path, dir_out):
    """Map the ``innov_autocorr`` filter diagnostic for selected runs.

    Reads ``res_path / 'filter_diagnostics.nc'`` and, for two species
    (indices 0 and 2) and three runs (open-loop, 4K, anomalies), samples the
    variable at each validation cell. Cells are restricted to those whose
    three pairwise anomaly TCA correlations exceed ``thres``. The figure is
    written to ``dir_out / 'innov_autocorr.png'``.

    Parameters
    ----------
    res_path : pathlib.Path
        Directory containing ``filter_diagnostics.nc``.
    dir_out : pathlib.Path
        Output directory; joined with ``/``.
    """
    fname = res_path / 'filter_diagnostics.nc'

    res = pd.read_csv(
        '/Users/u0116961/Documents/work/MadKF/CLSM/SM_err_ratio/GEOSldas/sm_validation/Pcorr/result.csv',
        index_col=0)
    # The TCA correlations come from the same file; only its r_grid_* columns
    # are read for masking.
    res_tc = res

    tg = GEOSldas_io().grid.tilegrids
    res_cols = res.col.values - tg.loc['domain', 'i_offg']
    res_rows = res.row.values - tg.loc['domain', 'j_offg']

    # Keep cells where all three sensor pairs are sufficiently correlated.
    thres = 0.2
    r_asc_smap = res_tc['r_grid_anom_lst_p_ASCAT_SMAP']
    r_asc_clsm = res_tc['r_grid_anom_lst_p_ASCAT_CLSM']
    r_smap_clsm = res_tc['r_grid_anom_lst_p_SMAP_CLSM']
    ind_valid = res_tc[(r_asc_smap > thres) & (r_asc_clsm > thres) &
                       (r_smap_clsm > thres)].index

    fontsize = 14

    root = Path('/Users/u0116961/data_sets/GEOSldas_runs')
    # NOTE(review): glob order is filesystem-dependent. The run index used
    # below presumably must match the order used when the netCDF file was
    # written. Confirm.
    runs = [run.name for run in root.glob('*_DA_SMAP_*')]
    runs += ['NLv4_M36_US_OL_Pcorr', 'NLv4_M36_US_OL_noPcorr']

    # Position of the first run whose name contains each tag.
    tags = ['OL_Pcorr', 'Pcorr_4K', 'Pcorr_anom_lst', 'Pcorr_LTST']
    iters = [np.where([tag in run for run in runs])[0][0] for tag in tags]

    # Only three titles, so zip() below uses the first three runs.
    titles = ['Open-loop', '4K constant', 'Anomalies']
    # NOTE(review): elsewhere in this file species 0 is labelled V-pol and
    # species 2 H-pol; confirm the species order of this netCDF variable.
    labels = ['H pol. / Asc.', 'V pol. / Asc.']

    with Dataset(fname) as ds:

        var = 'innov_autocorr'
        cbrange = [0, 0.7]
        cmap = 'viridis'

        f = plt.figure(figsize=(19, 6))

        for j, (spc, label) in enumerate(zip([0, 2], labels)):
            for i, (it_tit, it) in enumerate(zip(titles, iters)):

                title = it_tit if j == 0 else ''

                plt.subplot(2, 3, j * 3 + i + 1)
                data = ds.variables[var][:, :, it, spc]

                # Sample the 2-D field at each result cell (domain indices).
                res['tmp'] = data[res_rows, res_cols]
                im = plot_ease_img(res.reindex(ind_valid),
                                   'tmp',
                                   fontsize=fontsize,
                                   cbrange=cbrange,
                                   cmap=cmap,
                                   title=title,
                                   plot_cb=False,
                                   print_meanstd=True)
                if i == 0:
                    plt.ylabel(label, fontsize=fontsize)

        plot_centered_cbar(f, im, 3, fontsize=fontsize - 2, bottom=0.07)
        fout = dir_out / f'{var}.png'
        f.savefig(fout, dpi=300, bbox_inches='tight')
        plt.close()