示例#1
0
    def compute_nudging_term(self, date, model_state):
        """Build the SSH and relative-vorticity nudging terms at ``date``.

        Observations, nudging coefficients and smoothing widths are fetched
        with ``bfn_get_data_at_t`` for both ``self.dict_proj_ssh`` and
        ``self.dict_proj_rv``. Each observation slice then contributes
        ``coeff * (obs - model)`` on the pixels where the observation is not
        NaN; the model field is smoothed with ``gaussian_filter`` beforehand
        when the slice comes with a strictly positive sigma.

        Parameters
        ----------
        date :
            Forwarded unchanged to ``bfn_get_data_at_t``.
        model_state :
            Object whose ``getvar(0)`` returns the SSH field.

        Returns
        -------
        dict
            ``{'ssh': ..., 'rv': ...}``. Both arrays are shaped like the SSH
            field, scaled by ``self.scalenudg[0]`` / ``self.scalenudg[1]``,
            and every entry equal to zero after scaling is set to NaN.
        """
        # SSH is the first variable stored in the model state
        ssh = model_state.getvar(0)

        # Observations, nudging coefficients and smoothing widths at `date`
        obs_ssh, nudging_coeff_ssh, sigma_ssh = bfn_get_data_at_t(
            date, self.dict_proj_ssh)
        obs_rv, nudging_coeff_rv, sigma_rv = bfn_get_data_at_t(
            date, self.dict_proj_rv)

        def _accumulate(total, field, obs, coeff, sigma):
            # Add, in place, the contribution of every observation slice
            for k in range(len(obs)):
                valid = ~np.isnan(obs[k])
                if not np.any(valid):
                    continue
                # Optionally smooth the model field (spectral nudging)
                smoothed = field.copy()
                if sigma[k] is not None and sigma[k] > 0:
                    smoothed = gaussian_filter(smoothed, sigma=sigma[k])
                total[valid] += coeff[k, valid] * \
                    (obs[k, valid] - smoothed[valid])

        term_ssh = np.zeros_like(ssh)
        term_rv = np.zeros_like(ssh)

        # Nudging towards relative vorticity
        if obs_rv is not None and np.any(np.isfinite(obs_rv)):
            rv = switchvar.ssh2rv(
                ssh, self.lon2d, self.lat2d, name_grd=self.name_grd)
            _accumulate(term_rv, rv, obs_rv, nudging_coeff_rv, sigma_rv)

        # Nudging towards ssh
        if obs_ssh is not None and np.any(np.isfinite(obs_ssh)):
            _accumulate(term_ssh, ssh, obs_ssh, nudging_coeff_ssh, sigma_ssh)

        # Scale each term, then flag the pixels left at zero with NaN
        N = {'ssh': term_ssh * self.scalenudg[0],
             'rv': term_rv * self.scalenudg[1]}
        for key in ('ssh', 'rv'):
            N[key][N[key] == 0] = np.nan

        # Optional debug figure of the SSH term
        if self.flag_plot > 3:
            plt.figure()
            plt.suptitle('Nudging coefficient')
            plt.pcolormesh(self.lon2d, self.lat2d, N['ssh'])
            plt.colorbar()
            plt.show()

        return N
示例#2
0
def bfn_merge_projections(varname, sat_info_list, obs_file_list,
                          lon2d, lat2d,
                          flag_plot=None,
                          nudging_coeff_list=None, dist_scale=None):
    """Merge observation files and express them on the model grid.

    Two situations are handled:

    * exactly one observation whose ``kind`` is ``'fullSSH'`` or
      ``'fullRV'``: the field is read directly, must be on the model grid
      (the process exits otherwise), and gets a uniform nudging coefficient;
    * anything else: the observations of all files are concatenated, masked
      and handed to ``bfn_project_obsvar_to_state_grid``.

    Parameters
    ----------
    varname : str
        ``'ssh'`` or ``'relvort'``.
    sat_info_list : list
        One descriptor per observation file. The attributes read here are
        ``kind``, ``name_obs_lon``, ``name_obs_lat``, ``name_obs_var`` and
        ``name_obs_xac``.
    obs_file_list : list
        Paths opened with ``xr.open_dataset``, in the same order as
        ``sat_info_list``.
    lon2d, lat2d : array
        Model grid coordinates.
    flag_plot : int, optional
        A debug figure is shown when this is greater than 1.
    nudging_coeff_list : list, optional
        One nudging coefficient per observation. Defaults to 1 for each.
    dist_scale : optional
        Forwarded to ``bfn_project_obsvar_to_state_grid``. It is replaced by
        ``2 * dist_threshold`` whenever ``nudging_coeff_list`` is None.

    Returns
    -------
    proj_var, proj_nudging_coeff
        Observed variable and nudging coefficient on the model grid.
    """

    if len(sat_info_list) == 1 and \
            sat_info_list[0].kind in ['fullSSH', 'fullRV']:
        # A full field is provided, no need to compute tapering
        sat_info = sat_info_list[0]
        with xr.open_dataset(obs_file_list[0]) as ncin:
            lonobs = ncin[sat_info.name_obs_lon].values % 360
            latobs = ncin[sat_info.name_obs_lat].values
            varobs = ncin[sat_info.name_obs_var[0]].values
        if len(varobs.shape) == 3:
            if varobs.shape[0] > 1:
                print('Warning: the full field provided has several '
                      'timestep, we take the first one')
            varobs = varobs[0]
        if varname == 'relvort' and sat_info.kind == 'fullSSH':
            proj_var = switchvar.ssh2rv(varobs, lonobs, latobs)
        else:
            proj_var = varobs

        # Recent NumPy raises on incompatible shapes instead of returning a
        # scalar, so treat that case as a grid mismatch as well.
        try:
            grid_differs = bool(np.any(lonobs != lon2d)
                                or np.any(latobs != lat2d))
        except ValueError:
            grid_differs = True
        if grid_differs:
            print('ERROR: When providing ' + sat_info.kind +
                  ' observations, grid has to be the same as the model one')
            # Non-zero status so that callers can detect the failure
            sys.exit(1)

        # Same default coefficient as in the projection branch below
        coeff = 1 if nudging_coeff_list is None else nudging_coeff_list[0]
        proj_nudging_coeff = coeff * np.ones_like(proj_var)

    else:
        # Construct KD tree for projection
        grnd_pix_tree, dist_threshold = \
            bfn_construct_ground_pixel_tree(lon2d, lat2d)
        # NOTE(review): dist_scale is only defaulted together with
        # nudging_coeff_list; a caller-provided dist_scale is overwritten
        # in that case. Kept as is -- confirm this is intended.
        if nudging_coeff_list is None:
            nudging_coeff_list = [1 for _ in range(len(sat_info_list))]
            dist_scale = 2*dist_threshold

        # Initialization
        lonobs, latobs, varobs, nudging_coeff = \
            [np.array([]) for _ in range(4)]
        lon = lat = var = obs_values = None

        # Merge observations
        for sat_info, obs_file, K in zip(sat_info_list, obs_file_list,
                                         nudging_coeff_list):
            # Relative vorticity can only be derived for 2D data, which
            # requires the 'xac' variable
            need_rv = (varname == 'relvort'
                       and sat_info.name_obs_xac is not None)

            # Read everything needed while the dataset is still open
            with xr.open_dataset(obs_file) as ncin:
                lon = ncin[sat_info.name_obs_lon].values
                lat = ncin[sat_info.name_obs_lat].values
                var = [ncin[var_].values for var_ in sat_info.name_obs_var]
                xac = ncin[sat_info.name_obs_xac].values if need_rv else None

            if need_rv:
                obs_values = switchvar.ssh2rv(var[0], lon, lat, xac=xac)
            elif varname == 'ssh':
                if sat_info.kind == 'CMEMS':
                    obs_values = var[0] + var[1]  # SLA + MDT
                else:
                    obs_values = var[0]  # SSH
            else:
                print('Warning: name of nudging variable not recongnized!!')
                # Skip the whole observation so that coordinates,
                # coefficients and values keep the same length
                continue

            # Merging
            lonobs = np.append(lonobs, lon.ravel())
            latobs = np.append(latobs, lat.ravel())
            nudging_coeff = np.append(nudging_coeff, K * np.ones_like(lon))
            varobs = np.append(varobs, obs_values.ravel())

        # Create mask: NaNs are replaced by a huge value so that they fall
        # beyond the threshold and get masked along with the outliers
        mask = varobs.copy()
        mask[np.isnan(mask)] = 1e19
        varobs = np.ma.masked_where(np.abs(mask) > 50, varobs)

        # Clean memory
        del var, obs_values, mask, lon, lat

        # Perform projection
        proj_var, proj_nudging_coeff = \
            bfn_project_obsvar_to_state_grid(varobs, nudging_coeff,
                                             lonobs, latobs,
                                             grnd_pix_tree,
                                             dist_threshold,
                                             lon2d.shape[0],
                                             lon2d.shape[1],
                                             dist_scale)

    # Debug
    if flag_plot is not None and flag_plot > 1:

        params = {
            'font.size': 20,
            'axes.labelsize': 15,
            'axes.titlesize': 20,
            'xtick.labelsize': 12,
            'ytick.labelsize': 12,
            'legend.fontsize': 20,
            'legend.handlelength': 2,
            'lines.linewidth': 4
            }

        plt.rcParams.update(params)

        fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(20, 7))
        if len(lonobs.shape) == 2:
            # Gridded observations
            im0 = ax0.pcolormesh(lonobs, latobs, varobs, shading='auto')
        else:
            # Flattened observations: scatter them over the model domain
            im0 = ax0.scatter(lonobs, latobs, c=varobs)
            ax0.set_xlim(lon2d.min(), lon2d.max())
            ax0.set_ylim(lat2d.min(), lat2d.max())
        cbar = plt.colorbar(im0, ax=ax0)
        cbar.ax.set_title("m")
        ax0.set_title('Available observations')
        im1 = ax1.pcolormesh(lon2d, lat2d, proj_var, shading='auto')
        cbar = plt.colorbar(im1, ax=ax1)
        cbar.ax.set_title("m")
        ax1.set_title('Projected observations')
        im2 = ax2.pcolormesh(lon2d, lat2d, proj_nudging_coeff,
                             cmap='Spectral_r', shading='auto')
        cbar = plt.colorbar(im2, ax=ax2)
        ax2.set_title('Nudging term')
        plt.show()

    return proj_var, proj_nudging_coeff