Example #1
0
def process_mpoint_da_with_station_num(path=sound_path,
                                       station='08001',
                                       k_iqr=1):
    """Compute Ts, Tm and Tpw from a station's sounding file and save them.

    Reads the first file in ``path`` matching ``'ALL*<station>*.nc'``,
    derives surface temperature (Ts), water-vapor mean atmospheric
    temperature (Tm) and precipitable water (Tpw) with
    ``calculate_ts_tm_tpw_from_mpoint_da``, filters outliers along time with
    ``keep_iqr`` and writes the result as a zlib-compressed netCDF file named
    ``station_<station>_soundings_ts_tm_tpw_<yr_min>-<yr_max>.nc`` to
    ``path``.

    Parameters
    ----------
    path : pathlib.Path
        Directory holding the input sounding file; the output file is
        written to the same directory.
    station : str
        Station number used in the input search pattern and output filename.
    k_iqr : float
        Passed as ``k`` to ``keep_iqr``.

    Returns
    -------
    xarray.Dataset
        ``Tm``, ``Ts`` and ``Tpw`` along the ``sound_time`` dimension.

    Raises
    ------
    FileNotFoundError
        If no file in ``path`` matches the station pattern.
    """
    from aux_gps import path_glob
    import xarray as xr
    from aux_gps import keep_iqr
    # Search the ``path`` argument; previously the global ``sound_path`` was
    # used here, so a non-default ``path`` only changed where output went.
    files = path_glob(path, 'ALL*{}*.nc'.format(station))
    if not files:
        raise FileNotFoundError(
            'no sounding file matching ALL*{}*.nc in {}'.format(station, path))
    da = xr.open_dataarray(files[0])
    ts, tm, tpw = calculate_ts_tm_tpw_from_mpoint_da(da)
    # (variable name, values, unit, descriptive name) in output order:
    fields = [
        ('Tm', tm, 'K', 'Water vapor mean atmospheric temperature'),
        ('Ts', ts, 'K', 'Surface temperature'),
        ('Tpw', tpw, 'mm', 'precipitable_water'),
    ]
    ds = xr.Dataset()
    for name, values, unit, long_name in fields:
        ds[name] = xr.DataArray(values, dims=['time'], name=name,
                                attrs={'unit': unit, 'name': long_name})
    ds['time'] = da.time
    ds = keep_iqr(ds, k=k_iqr, dim='time')
    yr_min = ds.time.min().dt.year.item()
    yr_max = ds.time.max().dt.year.item()
    ds = ds.rename({'time': 'sound_time'})
    filename = 'station_{}_soundings_ts_tm_tpw_{}-{}.nc'.format(
        station, yr_min, yr_max)
    print('saving {} to {}'.format(filename, path))
    comp = dict(zlib=True, complevel=9)  # best compression
    encoding = {var: comp for var in ds}
    ds.to_netcdf(path / filename, 'w', encoding=encoding)
    print('Done!')
    return ds
Example #2
0
def get_pressure_lapse_rate(path=ims_path, model='LR', plot=False):
    """Fit the time-mean station pressure against station altitude.

    Reads ``IMS_BP_israeli_10mins.nc`` from ``path`` and IQR-filters each
    station variable with ``keep_iqr``. It then averages each station over
    time and fits the means against each variable's ``station_alt``
    attribute with ``linear_fit_using_scipy_da_ts``.

    Parameters
    ----------
    path : pathlib.Path
        Directory containing ``IMS_BP_israeli_10mins.nc``.
    model : str
        Fit model name forwarded to ``linear_fit_using_scipy_da_ts``.
    plot : bool
        If True, plot the station means (blue dots) and the fitted line
        (red). Also annotate the plot with ``1000 * slope`` as a lapse rate
        in hPa/km; the axes are labelled in m and hPa.

    Returns
    -------
    The ``results`` object returned by ``linear_fit_using_scipy_da_ts``; it
    is indexed here by ``'slope'`` and ``'intercept'``.
    """
    from aux_gps import linear_fit_using_scipy_da_ts
    import matplotlib.pyplot as plt
    import xarray as xr
    from aux_gps import keep_iqr
    # Honour the ``path`` argument (the global ``ims_path`` was used before).
    bp = xr.load_dataset(path / 'IMS_BP_israeli_10mins.nc')
    bps = [keep_iqr(bp[x]) for x in bp]
    bp = xr.merge(bps)
    mean_p = bp.mean('time').to_array('alt')
    mean_p.name = 'mean_pressure'
    # Relabel the 'alt' dimension (station names) with station altitudes:
    alts = [bp[x].attrs['station_alt'] for x in bp.data_vars]
    mean_p['alt'] = alts
    _, results = linear_fit_using_scipy_da_ts(mean_p,
                                              model=model,
                                              slope_factor=1,
                                              not_time=True)
    slope = results['slope']
    inter = results['intercept']
    modeled_var = slope * mean_p['alt'] + inter
    if plot:
        _, ax = plt.subplots()
        modeled_var.plot(ax=ax, color='r')
        mean_p.plot.line(linewidth=0., marker='o', ax=ax, color='b')
        # slope is per metre of altitude; scale by 1000 for per km:
        textstr = 'Pressure lapse rate: {:.1f} hPa/km'.format(1000 * slope)
        props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
        # place a text box near the top-centre in axes coords
        ax.text(0.5,
                0.95,
                textstr,
                transform=ax.transAxes,
                fontsize=12,
                verticalalignment='top',
                bbox=props)
        ax.set_xlabel('Height a.s.l [m]')
        ax.set_ylabel('Mean Pressure [hPa]')
    return results
def prepare_station_to_pw_comparison(path=aero_path,
                                     gis_path=gis_path,
                                     station='boker',
                                     mm_anoms=False):
    """Return a station's AERONET water-vapour series in mm.

    Loads all stations with ``load_all_station`` and takes the
    ``'WV(cm)_935nm-AOD'`` field of ``station``. It then IQR-filters the
    field with ``keep_iqr`` and multiplies it by 10 to convert cm to mm.
    Optionally it converts the series to monthly-mean anomalies.

    Parameters
    ----------
    path : pathlib.Path
        AERONET data path forwarded to ``load_all_station``.
    gis_path : pathlib.Path
        GIS path forwarded to ``load_all_station``.
    station : str
        Station key in the dict returned by ``load_all_station``.
    mm_anoms : bool
        If True, resample to month-start means and anomalize them with
        ``anomalize_xr(freq='MS')``.

    Returns
    -------
    xarray.DataArray or None
        The series named ``station`` with AERONET source attrs, or None
        (after printing a message) if the station or the field is missing.
    """
    from aux_gps import keep_iqr
    from aux_gps import anomalize_xr
    field = 'WV(cm)_935nm-AOD'
    # Honour the ``path`` argument (the global ``aero_path`` was used before).
    ds_dict = load_all_station(path=path, gis_path=gis_path)
    # Look up the station and the field separately so the message is accurate:
    try:
        station_ds = ds_dict[station]
    except KeyError:
        print('station {} not found'.format(station))
        return
    try:
        da = station_ds[field]
    except KeyError as e:
        print('station {} has no {} field'.format(station, e))
        return
    da = keep_iqr(da)
    # convert cm to mm:
    da = da * 10
    da.name = station
    if mm_anoms:
        da_mm = da.resample(time='MS').mean()
        da = anomalize_xr(da_mm, freq='MS')
    da.attrs['data_source'] = 'AERONET'
    da.attrs['data_field'] = field
    da.attrs['units'] = 'mm'
    return da
def read_gipsyx_all_yearly_files(load_path,
                                 savepath=None,
                                 iqr_k=3.0,
                                 plot=False):
    """Read, stitch and clean all yearly post-processed PPP GipsyX solutions.

    All ``*.nc`` files in ``load_path`` whose name contains ``'ppp_post'``
    are read in sorted order. Each year-boundary discontinuity is stitched,
    and the yearly datasets are merged into one multi-field time-series
    dataset.

    The merged dataset is then cleaned:

    * attribute keys are shortened to the text after the last ``'>'``;
    * ``desc`` attributes are renamed to ``full_name``;
    * duplicate time stamps are dropped;
    * outliers are removed with an IQR filter;
    * fields whose errors are NaN are filtered out;
    * X/Y/Z are transformed to lat/lon/alt;
    * time is reindexed to a 5-minute frequency.

    Parameters
    ----------
    load_path : pathlib.Path
        Directory with the yearly ``*ppp_post*.nc`` files. The station name
        is the filename prefix before the first ``'_'``.
    savepath : pathlib.Path, optional
        If given, the result is saved there (zlib level 9) as
        ``<station>_PPP_<ymin>-<ymax>.nc``.
    iqr_k : float
        ``k`` passed to ``keep_iqr`` (with quartiles 0.25 and 0.75).
    plot : bool
        If True, call ``plot_gipsy_field`` on the result.

    Returns
    -------
    xarray.Dataset

    Raises
    ------
    FileNotFoundError
        If no ``ppp_post`` file is found in ``load_path``.
    ValueError
        If a field has no known stitching method.
    """
    from aux_gps import path_glob
    import xarray as xr
    from aux_gps import get_unique_index
    from aux_gps import dim_intersection
    import pandas as pd
    from aux_gps import filter_nan_errors
    from aux_gps import keep_iqr
    from aux_gps import xr_reindex_with_date_range
    from aux_gps import transform_ds_to_lat_lon_alt
    import logging

    logger = logging.getLogger('gipsyx_post_proccesser')

    def stitch_method(field):
        """Return the ``stitch_two_cols`` method name used for ``field``."""
        if field in ['X', 'Y', 'Z']:
            return 'simple_mean'
        if field in ['GradNorth', 'GradEast', 'WetZ']:
            return 'smooth_mean'
        if 'error' in field:
            return 'error_mean'
        # Previously such a field silently reused the last loop's method.
        raise ValueError(
            'no stitching method defined for field {}'.format(field))

    def stitch_yearly_files(ds_list):
        """Stitch the discontinuity between each pair of consecutive years.

        The stitched values are written in place into both datasets of each
        pair, and the (mutated) list is returned.
        """
        fields = list(ds_list[0].data_vars)
        # Pairs share objects with ds_list, so in-place edits carry forward
        # to the next pair exactly as with index-based access.
        for first_full, second_full in zip(ds_list[:-1], ds_list[1:]):
            first_year = int(first_full.time.dt.year.median().item())
            second_year = int(second_full.time.dt.year.median().item())
            # overlap window around the new year (Dec 31 18:00 - Jan 1 06:00):
            first_ds = first_full.sel(time=slice(
                '{}-12-31T18:00'.format(first_year), str(second_year)))
            second_ds = second_full.sel(time=slice(
                str(first_year), '{}-01-01T06:00'.format(second_year)))
            if dim_intersection([first_ds, second_ds], 'time') is None:
                logger.warning('skipping stitching years {} and {}...'.format(
                    first_year, second_year))
                continue
            logger.info('stitching years {} and {}'.format(
                first_year, second_year))
            time = xr.concat([first_ds.time, second_ds.time], 'time')
            time = pd.to_datetime(get_unique_index(time).values)
            st_list = []
            for field in fields:
                df = first_ds[field].to_dataframe()
                df.columns = ['first']
                df = df.reindex(time)
                df['second'] = second_ds[field].to_dataframe()
                method = stitch_method(field)
                dfs = stitch_two_cols(df, method=method)['stitched_signal']
                dfs.index.name = 'time'
                st = dfs.to_xarray()
                st.name = field
                st_list.append(st)
            # merge to all fields:
            st_ds = xr.merge(st_list)
            # write the stitched values back into both yearly datasets:
            for target in (first_full, second_full):
                common_time = dim_intersection([target, st_ds])
                vals_rpl = st_ds.sel(time=common_time)
                for field in target.data_vars:
                    target[field].loc[{'time': common_time}] = vals_rpl[field]
        return ds_list

    files = sorted(path_glob(load_path, '*.nc'))
    ds_list = []
    station = None
    for file in files:
        filename = file.name
        if 'ppp_post' not in filename:
            continue
        # Take the station only from files actually used (previously the
        # last globbed file set it, even if it was not a ppp_post file).
        station = filename.split('_')[0]
        logger.info('reading {}'.format(filename))
        ds_list.append(xr.open_dataset(file))
    if not ds_list:
        raise FileNotFoundError(
            'no ppp_post *.nc files found in {}'.format(load_path))
    # now loop over ds_list and stitch yearly discontinuities:
    ds_list = stitch_yearly_files(ds_list)
    logger.info('merging all years...')
    ds = xr.merge(ds_list)
    logger.info('fixing meta-data...')
    for da in ds.data_vars:
        # keep only the part after the last '>' of every attribute key:
        ds[da].attrs = {key.split('>')[-1]: val
                        for key, val in ds[da].attrs.items()}
        if 'desc' in ds[da].attrs:
            ds[da].attrs['full_name'] = ds[da].attrs.pop('desc')
    logger.info('dropping duplicates time stamps...')
    ds = get_unique_index(ds)
    # clean with IQR all fields:
    logger.info('removing outliers with IQR of {}...'.format(iqr_k))
    ds = keep_iqr(ds, dim='time', qlow=0.25, qhigh=0.75, k=iqr_k)
    # filter the fields based on their errors not being NaNs:
    logger.info('filtering out fields if their errors are NaN...')
    ds = filter_nan_errors(ds, error_str='_error', dim='time')
    logger.info('transforming X, Y, Z coords to lat, lon and alt...')
    ds = transform_ds_to_lat_lon_alt(ds, ['X', 'Y', 'Z'], '_error', 'time')
    logger.info(
        'reindexing fields with 5 mins frequency(i.e., inserting NaNs)')
    ds = xr_reindex_with_date_range(ds, 'time', '5min')
    ds.attrs['station'] = station
    if plot:
        plot_gipsy_field(ds, None)
    if savepath is not None:
        comp = dict(zlib=True, complevel=9)  # best compression
        encoding = {var: comp for var in ds.data_vars}
        ymin = ds.time.min().dt.year.item()
        ymax = ds.time.max().dt.year.item()
        new_filename = '{}_PPP_{}-{}.nc'.format(station, ymin, ymax)
        ds.to_netcdf(savepath / new_filename, 'w', encoding=encoding)
        logger.info('{} was saved to {}'.format(new_filename, savepath))
    logger.info('Done!')
    return ds
Example #5
0
def produce_seasonal_trend_breakdown_time_series_from_jpl_gipsyx_site(station='bshm',
                                                                      path=jpl_path,
                                                                      var='V', k=2,
                                                                      verbose=True,
                                                                      plot=True):
    """Decompose a JPL GipsyX station series into trend, harmonics and residual.

    The ``var`` field is taken from the station time series read from
    ``path/'time_series'`` and reindexed to daily frequency. It is then:

    1. IQR-filtered with ``keep_iqr``, unless ``k`` is None;
    2. detrended with a LOESS curve;
    3. fitted with annual (cpy=1) and semiannual (cpy=2) harmonics.

    The residual is the detrended series minus both harmonics. If breakpoint
    estimates exist for the station in ``path/'jpl_break_estimates.nc'``,
    they are added as well.

    Parameters
    ----------
    station : str
        Station name; upper-cased to select it in the breakpoints file.
    path : pathlib.Path
        JPL data directory.
    var : str
        Variable of the station time series to decompose.
    k : float or None
        ``k`` for ``keep_iqr``; None skips the IQR filter.
    verbose : bool
        Print progress messages.
    plot : bool
        Plot the components (breakpoints as red vertical lines) and the
        monthly harmonic fit.

    Returns
    -------
    xarray.Dataset
        Variables ``<station>_<var>``, ``..._trend``, ``..._annual``,
        ``..._semiannual`` and ``..._residual``. When breakpoints are found,
        it also holds ``<station>_<var>_breakpoints``.
    """
    import xarray as xr
    from aux_gps import harmonic_da_ts
    from aux_gps import loess_curve
    from aux_gps import keep_iqr
    from aux_gps import get_unique_index
    from aux_gps import xr_reindex_with_date_range
    from aux_gps import decimal_year_to_datetime
    import matplotlib.pyplot as plt
    if verbose:
        print('producing seasonal time series for {} station {}'.format(station, var))
    ds = read_time_series_jpl_gipsyx_site(station=station,
                                          path=path/'time_series', verbose=verbose)
    da_ts = ds[var]
    da_ts = xr_reindex_with_date_range(get_unique_index(da_ts), freq='D')
    # NOTE(review): the inferred frequency is discarded; kept only as a sanity
    # call on the daily time index.
    xr.infer_freq(da_ts['time'])
    if k is not None:
        da_ts = keep_iqr(da_ts, k=k)
    da_ts.name = '{}_{}'.format(station, var)
    # detrend:
    trend = loess_curve(da_ts, plot=False)['mean']
    trend.name = da_ts.name + '_trend'
    trend = xr_reindex_with_date_range(trend, freq='D')
    da_ts_detrended = da_ts - trend
    if verbose:
        print('detrended by loess.')
    da_ts_detrended.name = da_ts.name + '_detrended'
    # harmonic fits, cpy=1 (annual) and cpy=2 (semiannual):
    harm = harmonic_da_ts(da_ts_detrended.dropna('time'), n=2, grp='month',
                          return_ts_fit=True, verbose=verbose)
    harm = xr_reindex_with_date_range(harm, time_dim='time', freq='D')
    harm1 = harm.sel(cpy=1).reset_coords(drop=True)
    harm1.name = da_ts.name + '_annual'
    harm1.attrs = {key: val for key, val in harm1.attrs.items() if '_1' in key}
    harm2 = harm.sel(cpy=2).reset_coords(drop=True)
    harm2.name = da_ts.name + '_semiannual'
    harm2.attrs = {key: val for key, val in harm2.attrs.items() if '_2' in key}
    resid = da_ts_detrended - harm1 - harm2
    resid.name = da_ts.name + '_residual'
    ds = xr.merge([da_ts, trend, harm1, harm2, resid])
    # Load breakpoints; a missing station or variable raises KeyError.
    # no_bp is set before the try so it is always bound (previously it was
    # only set when verbose=True, so plotting with verbose=False crashed).
    no_bp = True
    try:
        # honour ``path`` (the global ``jpl_path`` was used before);
        # load_dataset reads the file into memory and closes it.
        breakpoints = xr.load_dataset(
            path/'jpl_break_estimates.nc').sel(station=station.upper())[var]
        df = breakpoints.dropna('year')['year'].to_dataframe()
        df['dt'] = df['year'].apply(decimal_year_to_datetime)
        df['dt'] = df['dt'].round('D')
        bp_da = df.set_index(df['dt'])['dt'].to_xarray()
        bp_da = bp_da.rename({'dt': 'time'})
        ds['{}_{}_breakpoints'.format(station, var)] = bp_da
        no_bp = False
    except KeyError:
        if verbose:
            print('no breakpoints found for {}!'.format(station))
    if plot:
        dst = ds[[x for x in ds if 'breakpoints' not in x]]
        axes = dst.to_dataframe().plot(subplots=True, figsize=(20, 20), color='k')
        for ax in axes:
            ax.grid()
            ax.set_ylabel('[mm]')
            if not no_bp:
                for bp in df['dt']:
                    ax.axvline(bp, color='red')
        plt.tight_layout()
        _, ax = plt.subplots(figsize=(7, 7))
        harm_mm = harmonic_da_ts(da_ts_detrended.dropna('time'), n=2, grp='month',
                                 return_ts_fit=False, verbose=verbose)
        mean_key = '{}_mean'.format(station)
        harm_mm['{}_{}_detrended'.format(station, var)].plot.line(ax=ax, linewidth=0, marker='o', color='k')
        harm_mm[mean_key].sel(cpy=1).plot.line(ax=ax, marker=None, color='tab:red')
        harm_mm[mean_key].sel(cpy=2).plot.line(ax=ax, marker=None, color='tab:blue')
        harm_mm[mean_key].sum('cpy').plot.line(ax=ax, marker=None, color='tab:purple')
        ax.grid()
    return ds