Example #1
def get_wrf_pw_at_dsea_gnss_coord(path=des_path,
                                  work_path=work_yuval,
                                  point=None):
    from PW_stations import produce_geo_gnss_solved_stations
    import xarray as xr
    from aux_gps import get_nearest_lat_lon_for_xy
    from aux_gps import path_glob
    from aux_gps import get_unique_index
    df = produce_geo_gnss_solved_stations(path=work_path / 'gis', plot=False)
    dsea_point = df.loc['dsea'][['lat', 'lon']].astype(float).values
    files = path_glob(path, 'pw_wrfout*.nc')
    pw_list = []
    for file in files:
        pw_all = xr.load_dataset(file)
        freq = xr.infer_freq(pw_all['Time'])
        print(freq)
        if point is not None:
            print('looking for {} in WRF.'.format(point))
            dsea_point = point
        loc = get_nearest_lat_lon_for_xy(pw_all['XLAT'], pw_all['XLONG'],
                                         dsea_point)
        print(loc)
        pw = pw_all.isel(south_north=loc[0][0], west_east=loc[0][1])
        pw_list.append(pw)
    pw_ts = xr.concat(pw_list, 'Time')
    pw_ts = get_unique_index(pw_ts, dim='Time')
    return pw_ts
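
A minimal usage sketch, assuming des_path and work_yuval are module-level pathlib.Path objects and the pw_wrfout*.nc files exist under des_path; the (lat, lon) tuple passed via point is hypothetical:

# nearest-gridpoint WRF PW series at the DSEA GNSS coordinate:
pw_dsea = get_wrf_pw_at_dsea_gnss_coord()
# the same extraction at an arbitrary (lat, lon) point:
pw_custom = get_wrf_pw_at_dsea_gnss_coord(point=(31.42, 35.38))
print(pw_dsea['Time'].size)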
Example #2
def replace_fields_in_ds(dss, da_repl, field='WetZ', verbose=None):
    """replaces dss overlapping field(and then some) with the stiched signal
    fron da_repl. be carful with the choices for field"""
    from aux_gps import get_unique_index
    import xarray as xr
    import logging
    logger = logging.getLogger('gipsyx_post_proccesser')
    if verbose == 0:
        logger.info('replacing {} field.'.format(field))
    # choose the field from the bigger dss:
    nums = sorted(
        list(
            set([
                int(x.split('-')[1]) for x in dss if x.split('-')[0] == field
            ])))
    ds = dss[['{}-{}'.format(field, i) for i in nums]]
    da_list = []
    for i, _ in enumerate(ds):
        if i == len(ds) - 1:
            break
        first = ds['{}-{}'.format(field, i)]
        time0 = list(set(first.dims))[0]
        second = ds['{}-{}'.format(field, i + 1)]
        time1 = list(set(second.dims))[0]
        try:
            min_time = first.dropna(time0)[time0].min()
            max_time = second.dropna(time1)[time1].max()
        except ValueError:
            if verbose == 1:
                logger.warning('item {}, {} - {} is lonely'.format(
                    field, i, i + 1))
            continue
        try:
            da = da_repl.sel(time=slice(min_time, max_time))
        except KeyError:
            if verbose == 1:
                logger.warning('item {}, {} - {} is lonely'.format(
                    field, i, i + 1))
            continue
        if verbose == 1:
            logger.info('processing {} and {}'.format(first.name, second.name))
        # utime = dim_union([first, second], 'time')
        first_time = set(first.dropna(time0)[time0].values).difference(
            set(da.time.values))
        second_time = set(second.dropna(time1)[time1].values).difference(
            set(da.time.values))
        first = first.sel({time0: list(first_time)})
        second = second.sel({time1: list(second_time)})
        first = first.rename({time0: 'time'})
        second = second.rename({time1: 'time'})
        da_list.append(xr.concat([first, da, second], 'time'))
    da_final = xr.concat(da_list, 'time')
    da_final = da_final.sortby('time')
    da_final.name = field
    da_final.attrs = da_repl.attrs
    da_final = get_unique_index(da_final, 'time')
    return da_final
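
A hedged usage sketch, assuming dss holds per-segment variables named 'WetZ-0', 'WetZ-1', ... (each on its own time dimension) and wetz_stitched is the stitched DataArray on a 'time' dimension; both file names are hypothetical:

import xarray as xr
dss = xr.open_dataset('dsea_wetz_segments.nc')              # hypothetical per-segment dataset
wetz_stitched = xr.open_dataarray('dsea_wetz_stitched.nc')  # hypothetical stitched signal
wetz = replace_fields_in_ds(dss, wetz_stitched, field='WetZ', verbose=1)
print(wetz.name, wetz['time'].size)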
Example #3
def stitch_yearly_files(ds_list):
    """Input is a list of yearly multi-field datasets; output is the same
    list but with stitched discontinuities."""
    import logging
    import pandas as pd
    import xarray as xr
    from aux_gps import dim_intersection
    from aux_gps import get_unique_index
    # stitch_two_cols is assumed to be defined in the enclosing module
    logger = logging.getLogger('gipsyx_post_proccesser')
    fields = [x for x in ds_list[0].data_vars]
    for i, dss in enumerate(ds_list):
        if i == len(ds_list) - 1:
            break
        first_year = int(ds_list[i].time.dt.year.median().item())
        second_year = int(ds_list[i + 1].time.dt.year.median().item())
        first_ds = ds_list[i].sel(time=slice(
            '{}-12-31T18:00'.format(first_year), str(second_year)))
        second_ds = ds_list[i + 1].sel(time=slice(
            str(first_year), '{}-01-01T06:00'.format(second_year)))
        if dim_intersection([first_ds, second_ds], 'time') is None:
            logger.warning('skipping stitching years {} and {}...'.format(
                first_year, second_year))
            continue
        else:
            logger.info('stitching years {} and {}'.format(
                first_year, second_year))
        time = xr.concat([first_ds.time, second_ds.time], 'time')
        time = pd.to_datetime(get_unique_index(time).values)
        st_list = []
        for field in fields:
            df = first_ds[field].to_dataframe()
            df.columns = ['first']
            df = df.reindex(time)
            df['second'] = second_ds[field].to_dataframe()
            if field in ['X', 'Y', 'Z']:
                method = 'simple_mean'
            elif field in ['GradNorth', 'GradEast', 'WetZ']:
                method = 'smooth_mean'
            elif 'error' in field:
                method = 'error_mean'
            dfs = stitch_two_cols(df, method=method)['stitched_signal']
            dfs.index.name = 'time'
            st = dfs.to_xarray()
            st.name = field
            st_list.append(st)
        # merge to all fields:
        st_ds = xr.merge(st_list)
        # replace stitched values to first ds and second ds:
        first_time = dim_intersection([ds_list[i], st_ds])
        vals_rpl = st_ds.sel(time=first_time)
        for field in ds_list[i].data_vars:
            ds_list[i][field].loc[{'time': first_time}] = vals_rpl[field]
        second_time = dim_intersection([ds_list[i + 1], st_ds])
        vals_rpl = st_ds.sel(time=second_time)
        for field in ds_list[i + 1].data_vars:
            ds_list[i + 1][field].loc[{
                'time': second_time
            }] = vals_rpl[field]
    return ds_list
Example #4
def assemble_WRF_pwv(path=des_path, work_path=work_yuval, radius=1):
    from PW_stations import produce_geo_gnss_solved_stations
    import xarray as xr
    from aux_gps import save_ncfile
    from aux_gps import get_nearest_lat_lon_for_xy
    from aux_gps import get_unique_index
    df = produce_geo_gnss_solved_stations(path=work_path / 'gis', plot=False)
    dsea_point = df.loc['dsea'][['lat', 'lon']].astype(float).values
    if radius is not None:
        point = None
    else:
        point = dsea_point
    wrf_pw = read_all_WRF_GNSS_files(path, var='pw', point=point)
    wrf_pw8 = xr.load_dataarray(
        path / 'pw_wrfout_d04_2014-08-08_40lev.nc').sel(Time='2014-08-08')
    wrf_pw16 = xr.load_dataarray(
        path / 'pw_wrfout_d04_2014-08-16_40lev.nc').sel(Time='2014-08-16')
    wrf_pw_8_16 = xr.concat([wrf_pw8, wrf_pw16], 'Time')
    print('looking for {} in WRF.'.format(dsea_point))
    loc = get_nearest_lat_lon_for_xy(wrf_pw_8_16['XLAT'], wrf_pw_8_16['XLONG'],
                                     dsea_point)
    print(loc)
    if radius is not None:
        print('getting a {}-cell radius around {}.'.format(radius, dsea_point))
        lat_islice = [loc[0][0] - radius, loc[0][0] + radius + 1]
        lon_islice = [loc[0][1] - radius, loc[0][1] + radius + 1]
        wrf_pw_8_16 = wrf_pw_8_16.isel(south_north=slice(*lat_islice),
                                       west_east=slice(*lon_islice))
        loc = get_nearest_lat_lon_for_xy(wrf_pw['XLAT'], wrf_pw['XLONG'],
                                         dsea_point)
        lat_islice = [loc[0][0] - radius, loc[0][0] + radius + 1]
        lon_islice = [loc[0][1] - radius, loc[0][1] + radius + 1]
        wrf_pw = wrf_pw.isel(south_north=slice(*lat_islice),
                             west_east=slice(*lon_islice))
    else:
        wrf_pw_8_16 = wrf_pw_8_16.isel(south_north=loc[0][0],
                                       west_east=loc[0][1])
    wrf_pw = xr.concat([wrf_pw, wrf_pw_8_16], 'Time')
    wrf_pw = wrf_pw.rename({'Time': 'time'})
    wrf_pw = wrf_pw.sortby('time')
    wrf_pw = get_unique_index(wrf_pw)
    if wrf_pw.attrs['projection'] is not None:
        wrf_pw.attrs['projection'] = wrf_pw.attrs['projection'].proj4()
    if radius is not None:
        filename = 'pwv_wrf_dsea_gnss_radius_{}_2014-08.nc'.format(radius)
    else:
        filename = 'pwv_wrf_dsea_gnss_point_2014-08.nc'
    save_ncfile(wrf_pw, path, filename)
    return wrf_pw
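
A usage sketch, assuming the module-level des_path and work_yuval paths are set and the WRF output files referenced above are present:

# 3x3 grid-cell box (radius=1) around the DSEA GNSS site, saved with a radius-tagged filename:
wrf_box = assemble_WRF_pwv(radius=1)
# nearest single grid cell only:
wrf_point = assemble_WRF_pwv(radius=None)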
Example #5
def read_surface_pressure(path=des_path, dem_path=work_yuval / 'AW3D30'):
    """taken from ein gedi spa 31.417313616189308, 35.378962961491474"""
    import pandas as pd
    from aux_gps import path_glob
    from aux_gps import get_unique_index
    import xarray as xr
    awd = xr.open_rasterio(dem_path / 'israel_dem.tif')
    awd = awd.squeeze(drop=True)
    alt = awd.sel(x=35.3789, y=31.4173, method='nearest').item()
    file = path_glob(path, 'EBS1_*_pressure.txt')[0]
    df = pd.read_csv(file)
    df['Time'] = pd.to_datetime(df['Time'], format='%d-%b-%Y %H:%M:%S')
    df = df.set_index('Time')
    df.index.name = 'time'
    da = df.to_xarray()['Press']
    da = get_unique_index(da)
    da.attrs['station_alt'] = alt
    da.attrs['lat'] = 31.4173
    da.attrs['lon'] = 35.3789
    return da
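
A usage sketch, assuming des_path contains the EBS1_*_pressure.txt file and work_yuval / 'AW3D30' / 'israel_dem.tif' is available:

pressure = read_surface_pressure()
print(pressure.attrs['station_alt'], float(pressure.mean()))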
Example #6
def produce_pwv_from_dsea_axis_station(path=axis_path, ims_path=ims_path):
    """use axis_path = work_yuval/dsea_gispyx for original soi-apn dsea station"""
    import xarray as xr
    from aux_gps import transform_ds_to_lat_lon_alt
    from aux_gps import get_unique_index
    ds = xr.load_dataset(path / 'smoothFinal_2014.nc').squeeze()
    ds = get_unique_index(ds)
    # for now cut:
    if 'axis' in path.as_posix():
        ds = ds.sel(time=slice(None, '2014-08-12'))
    ds = transform_ds_to_lat_lon_alt(ds)
    axis_zwd = ds['WetZ']
    ts = xr.open_dataset(ims_path / 'IMS_TD_israeli_10mins.nc')['SEDOM']
    axis_pwv = produce_pwv_from_zwd_with_ts_tm_from_deserve(ts=ts,
                                                            zwd=axis_zwd)
    if 'axis' in path.as_posix():
        axis_pwv.name = 'AXIS-DSEA'
    else:
        axis_pwv.name = 'SOI-DSEA'
    axis_pwv.attrs['lat'] = ds['lat'].values[0]
    axis_pwv.attrs['lon'] = ds['lon'].values[0]
    axis_pwv.attrs['alt'] = ds['alt'].values[0]
    return axis_pwv
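
A usage sketch, assuming axis_path, work_yuval and ims_path are defined in the module; the second call follows the docstring and points at the original SOI-APN solution directory:

pwv_axis = produce_pwv_from_dsea_axis_station()                                 # named 'AXIS-DSEA'
pwv_soi = produce_pwv_from_dsea_axis_station(path=work_yuval / 'dsea_gispyx')   # named 'SOI-DSEA'
print(pwv_axis.name, pwv_axis.attrs['alt'])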
Example #7
def read_gipsyx_all_yearly_files(load_path,
                                 savepath=None,
                                 iqr_k=3.0,
                                 plot=False):
    """read, stitch and clean all yearly post proccessed ppp gipsyx solutions
    and concat them to a multiple fields time-series dataset"""
    from aux_gps import path_glob
    import xarray as xr
    from aux_gps import get_unique_index
    from aux_gps import dim_intersection
    import pandas as pd
    from aux_gps import filter_nan_errors
    from aux_gps import keep_iqr
    from aux_gps import xr_reindex_with_date_range
    from aux_gps import transform_ds_to_lat_lon_alt
    import logging

    def stitch_yearly_files(ds_list):
        """input is multiple field yearly dataset list and output is the same
        but with stitched discontinuieties"""
        fields = [x for x in ds_list[0].data_vars]
        for i, dss in enumerate(ds_list):
            if i == len(ds_list) - 1:
                break
            first_year = int(ds_list[i].time.dt.year.median().item())
            second_year = int(ds_list[i + 1].time.dt.year.median().item())
            first_ds = ds_list[i].sel(time=slice(
                '{}-12-31T18:00'.format(first_year), str(second_year)))
            second_ds = ds_list[i + 1].sel(time=slice(
                str(first_year), '{}-01-01T06:00'.format(second_year)))
            if dim_intersection([first_ds, second_ds], 'time') is None:
                logger.warning('skipping stitching years {} and {}...'.format(
                    first_year, second_year))
                continue
            else:
                logger.info('stitching years {} and {}'.format(
                    first_year, second_year))
            time = xr.concat([first_ds.time, second_ds.time], 'time')
            time = pd.to_datetime(get_unique_index(time).values)
            st_list = []
            for field in fields:
                df = first_ds[field].to_dataframe()
                df.columns = ['first']
                df = df.reindex(time)
                df['second'] = second_ds[field].to_dataframe()
                if field in ['X', 'Y', 'Z']:
                    method = 'simple_mean'
                elif field in ['GradNorth', 'GradEast', 'WetZ']:
                    method = 'smooth_mean'
                elif 'error' in field:
                    method = 'error_mean'
                dfs = stitch_two_cols(df, method=method)['stitched_signal']
                dfs.index.name = 'time'
                st = dfs.to_xarray()
                st.name = field
                st_list.append(st)
            # merge to all fields:
            st_ds = xr.merge(st_list)
            # replace stitched values to first ds and second ds:
            first_time = dim_intersection([ds_list[i], st_ds])
            vals_rpl = st_ds.sel(time=first_time)
            for field in ds_list[i].data_vars:
                ds_list[i][field].loc[{'time': first_time}] = vals_rpl[field]
            second_time = dim_intersection([ds_list[i + 1], st_ds])
            vals_rpl = st_ds.sel(time=second_time)
            for field in ds_list[i + 1].data_vars:
                ds_list[i + 1][field].loc[{
                    'time': second_time
                }] = vals_rpl[field]
        return ds_list

    logger = logging.getLogger('gipsyx_post_proccesser')
    files = sorted(path_glob(load_path, '*.nc'))
    ds_list = []
    for file in files:
        filename = file.as_posix().split('/')[-1]
        station = file.as_posix().split('/')[-1].split('_')[0]
        if 'ppp_post' not in filename:
            continue
        logger.info('reading {}'.format(filename))
        dss = xr.open_dataset(file)
        ds_list.append(dss)
    # now loop over ds_list and stitch yearly discontinuities:
    ds_list = stitch_yearly_files(ds_list)
    logger.info('merging all years...')
    ds = xr.merge(ds_list)
    logger.info('fixing meta-data...')
    for da in ds.data_vars:
        old_keys = [x for x in ds[da].attrs.keys()]
        vals = [x for x in ds[da].attrs.values()]
        new_keys = [x.split('>')[-1] for x in old_keys]
        ds[da].attrs = dict(zip(new_keys, vals))
        if 'desc' in ds[da].attrs.keys():
            ds[da].attrs['full_name'] = ds[da].attrs.pop('desc')
    logger.info('dropping duplicate time stamps...')
    ds = get_unique_index(ds)
    # clean with IQR all fields:
    logger.info('removing outliers with IQR of {}...'.format(iqr_k))
    ds = keep_iqr(ds, dim='time', qlow=0.25, qhigh=0.75, k=iqr_k)
    # filter the fields based on their errors not being NaNs:
    logger.info('filtering out fields if their errors are NaN...')
    ds = filter_nan_errors(ds, error_str='_error', dim='time')
    logger.info('transforming X, Y, Z coords to lat, lon and alt...')
    ds = transform_ds_to_lat_lon_alt(ds, ['X', 'Y', 'Z'], '_error', 'time')
    logger.info(
        'reindexing fields to 5 min frequency (i.e., inserting NaNs)')
    ds = xr_reindex_with_date_range(ds, 'time', '5min')
    ds.attrs['station'] = station
    if plot:
        plot_gipsy_field(ds, None)
    if savepath is not None:
        comp = dict(zlib=True, complevel=9)  # best compression
        encoding = {var: comp for var in ds.data_vars}
        ymin = ds.time.min().dt.year.item()
        ymax = ds.time.max().dt.year.item()
        new_filename = '{}_PPP_{}-{}.nc'.format(station, ymin, ymax)
        ds.to_netcdf(savepath / new_filename, 'w', encoding=encoding)
        logger.info('{} was saved to {}'.format(new_filename, savepath))
    logger.info('Done!')
    return ds
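
A usage sketch, assuming load_path points at a directory of yearly '*_ppp_post_*.nc' files for a single station (the directory below is hypothetical):

from pathlib import Path
load_path = Path('/data/gipsyx/dsea')  # hypothetical directory of yearly ppp_post solutions
ds = read_gipsyx_all_yearly_files(load_path, savepath=load_path, iqr_k=3.0, plot=False)
print(ds.attrs['station'], list(ds.data_vars)[:5])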
Example #8
def produce_seasonal_trend_breakdown_time_series_from_jpl_gipsyx_site(station='bshm',
                                                                      path=jpl_path,
                                                                      var='V', k=2,
                                                                      verbose=True,
                                                                      plot=True):
    import xarray as xr
    from aux_gps import harmonic_da_ts
    from aux_gps import loess_curve
    from aux_gps import keep_iqr
    from aux_gps import get_unique_index
    from aux_gps import xr_reindex_with_date_range
    from aux_gps import decimal_year_to_datetime
    import matplotlib.pyplot as plt
    if verbose:
        print('producing seasonal time series for station {}, field {}'.format(station, var))
    ds = read_time_series_jpl_gipsyx_site(station=station,
                                          path=path/'time_series', verbose=verbose)
    # dyear = ds['decimal_year']
    da_ts = ds[var]
    da_ts = xr_reindex_with_date_range(get_unique_index(da_ts), freq='D')
    xr.infer_freq(da_ts['time'])
    if k is not None:
        da_ts = keep_iqr(da_ts, k=k)
    da_ts.name = '{}_{}'.format(station, var)
    # detrend:
    trend = loess_curve(da_ts, plot=False)['mean']
    trend.name = da_ts.name + '_trend'
    trend = xr_reindex_with_date_range(trend, freq='D')
    da_ts_detrended = da_ts - trend
    if verbose:
        print('detrended by loess.')
    da_ts_detrended.name = da_ts.name + '_detrended'
    # harmonic cpy fits:
    harm = harmonic_da_ts(da_ts_detrended.dropna('time'), n=2, grp='month',
                          return_ts_fit=True, verbose=verbose)
    harm = xr_reindex_with_date_range(harm, time_dim='time', freq='D')
    harm1 = harm.sel(cpy=1).reset_coords(drop=True)
    harm1.name = da_ts.name + '_annual'
    harm1_keys = [x for x in harm1.attrs.keys() if '_1' in x]
    harm1.attrs = dict(zip(harm1_keys, [harm1.attrs[x] for x in harm1_keys]))
    harm2 = harm.sel(cpy=2).reset_coords(drop=True)
    harm2.name = da_ts.name + '_semiannual'
    harm2_keys = [x for x in harm2.attrs.keys() if '_2' in x]
    harm2.attrs = dict(zip(harm2_keys, [harm2.attrs[x] for x in harm2_keys]))
    resid = da_ts_detrended - harm1 - harm2
    resid.name = da_ts.name + '_residual'
    ds = xr.merge([da_ts, trend, harm1, harm2, resid])
    # load breakpoints:
    try:
        breakpoints = xr.open_dataset(
            jpl_path/'jpl_break_estimates.nc').sel(station=station.upper())[var]
        df = breakpoints.dropna('year')['year'].to_dataframe()
    # load seasonal coeffs:
        df['dt'] = df['year'].apply(decimal_year_to_datetime)
        df['dt'] = df['dt'].round('D')
        bp_da = df.set_index(df['dt'])['dt'].to_xarray()
        bp_da = bp_da.rename({'dt': 'time'})
        ds['{}_{}_breakpoints'.format(station, var)] = bp_da
        no_bp = False
    except KeyError:
        if verbose:
            print('no breakpoints found for {}!'.format(station))
        # set the flag even when verbose is off, otherwise no_bp is undefined below:
        no_bp = True
    # seas = xr.load_dataset(
    #     jpl_path/'jpl_seasonal_estimates.nc').sel(station=station.upper())
    # ac1, as1, ac2, as2 = seas[var].values
    # # build seasonal time series:
    # annual = xr.DataArray(ac1*np.cos(dyear*2*np.pi)+as1 *
    #                       np.sin(dyear*2*np.pi), dims=['time'])
    # annual['time'] = da_ts['time']
    # annual.name = '{}_{}_annual'.format(station, var)
    # annual.attrs['units'] = 'mm'
    # annual.attrs['long_name'] = 'annual mode'
    # semiannual = xr.DataArray(ac2*np.cos(dyear*4*np.pi)+as2 *
    #                           np.sin(dyear*4*np.pi), dims=['time'])
    # semiannual['time'] = da_ts['time']
    # semiannual.name = '{}_{}_semiannual'.format(station, var)
    # semiannual.attrs['units'] = 'mm'
    # semiannual.attrs['long_name'] = 'semiannual mode'
    # ds = xr.merge([annual, semiannual, da_ts])
    if plot:
        # plt.figure(figsize=(20, 20))
        dst = ds[[x for x in ds if 'breakpoints' not in x]]
        axes = dst.to_dataframe().plot(subplots=True, figsize=(20, 20), color='k')
        [ax.grid() for ax in axes]
        [ax.set_ylabel('[mm]') for ax in axes]
        if not no_bp:
            for bp in df['dt']:
                [ax.axvline(bp, color='red') for ax in axes]
        plt.tight_layout()
        fig, ax = plt.subplots(figsize=(7, 7))
        harm_mm = harmonic_da_ts(da_ts_detrended.dropna('time'), n=2, grp='month',
                                 return_ts_fit=False, verbose=verbose)
        harm_mm['{}_{}_detrended'.format(station, var)].plot.line(ax=ax, linewidth=0, marker='o', color='k')
        harm_mm['{}_mean'.format(station)].sel(cpy=1).plot.line(ax=ax, marker=None, color='tab:red')
        harm_mm['{}_mean'.format(station)].sel(cpy=2).plot.line(ax=ax, marker=None, color='tab:blue')
        harm_mm['{}_mean'.format(station)].sum('cpy').plot.line(ax=ax, marker=None, color='tab:purple')
        ax.grid()
    return ds
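
A usage sketch, assuming jpl_path is set and contains a time_series/ subdirectory plus jpl_break_estimates.nc:

ds_bshm = produce_seasonal_trend_breakdown_time_series_from_jpl_gipsyx_site(
    station='bshm', var='V', k=2, verbose=True, plot=False)
print(list(ds_bshm.data_vars))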
Example #9
def read_hydrographs(path=hydro_path):
    from aux_gps import path_glob
    import pandas as pd
    import xarray as xr
    from aux_gps import get_unique_index
    files = path_glob(path, 'hydro_graph*.txt')
    df_list = []
    for file in files:
        print(file)
        df = pd.read_csv(file, header=0, sep=',')
        df.columns = [
            'id', 'name', 'time', 'tide_height[m]', 'flow[m^3/sec]',
            'data_type', 'flow_type', 'record_type', 'record_code'
        ]
        df['time'] = pd.to_datetime(df['time'], dayfirst=True)
        df['tide_height[m]'] = df['tide_height[m]'].astype(float)
        df['flow[m^3/sec]'] = df['flow[m^3/sec]'].astype(float)
        # boolean-mask .loc assignment avoids chained-indexing pitfalls:
        df.loc[df['data_type'].str.contains('מדודים', na=False),
               'data_type'] = 'measured'
        df.loc[df['data_type'].str.contains('משוחזרים', na=False),
               'data_type'] = 'reconstructed'
        df.loc[df['flow_type'].str.contains('תקין', na=False),
               'flow_type'] = 'normal'
        df.loc[df['flow_type'].str.contains('גאות', na=False),
               'flow_type'] = 'tide'
        df.loc[df['record_type'].str.contains('נקודה פנימית', na=False),
               'record_type'] = 'inner_point'
        df.loc[df['record_type'].str.contains('התחלת קטע', na=False),
               'record_type'] = 'section_begining'
        df.loc[df['record_type'].str.contains('סיום קטע', na=False),
               'record_type'] = 'section_ending'
        df_list.append(df)
    df = pd.concat(df_list)
    dfs = [x for _, x in df.groupby('id')]
    ds_list = []
    meta_df = read_hydro_metadata(path, gis_path, False)
    for df in dfs:
        st_id = df['id'].iloc[0]
        st_name = df['name'].iloc[0]
        print('processing station number: {}, {}'.format(st_id, st_name))
        meta = meta_df[meta_df['id'] == st_id]
        ds = xr.Dataset()
        df.set_index('time', inplace=True)
        attrs = {}
        attrs['station_name'] = st_name
        if not meta.empty:
            attrs['lon'] = meta.lon.values.item()
            attrs['lat'] = meta.lat.values.item()
            attrs['alt'] = meta.alt.values.item()
            attrs['drainage_basin_area'] = meta.area.values.item()
            attrs['active'] = meta.active.values.item()
        attrs['units'] = 'm'
        tide_height = df['tide_height[m]'].to_xarray()
        tide_height.name = 'HS_{}_tide_height'.format(st_id)
        # copy attrs so the later units change does not leak into tide_height:
        tide_height.attrs = dict(attrs)
        flow = df['flow[m^3/sec]'].to_xarray()
        flow.name = 'HS_{}_flow'.format(st_id)
        attrs['units'] = 'm^3/sec'
        flow.attrs = dict(attrs)
        ds['{}'.format(tide_height.name)] = tide_height
        ds['{}'.format(flow.name)] = flow
        ds_list.append(ds)
    dsu = [get_unique_index(x) for x in ds_list]
    print('merging...')
    ds = xr.merge(dsu)
    filename = 'hydro_graphs.nc'
    print('saving {} to {}'.format(filename, path))
    comp = dict(zlib=True, complevel=9)  # best compression
    encoding = {var: comp for var in ds.data_vars}
    ds.to_netcdf(path / filename, 'w', encoding=encoding)
    print('Done!')
    return ds
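
A usage sketch, assuming hydro_path holds the hydro_graph*.txt exports and that read_hydro_metadata and gis_path are available in the module:

hydro_ds = read_hydrographs()
flow_vars = [v for v in hydro_ds.data_vars if v.endswith('_flow')]
print('{} hydrograph flow series loaded'.format(len(flow_vars)))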
Example #10
def read_tides(path=hydro_path):
    from aux_gps import path_glob
    import pandas as pd
    import xarray as xr
    from aux_gps import get_unique_index
    files = path_glob(path, 'tide_report*.xlsx')
    df_list = []
    for file in files:
        df = pd.read_excel(file, header=4)
        df.drop(df.columns[len(df.columns) - 1], axis=1, inplace=True)
        df.columns = [
            'id', 'name', 'hydro_year', 'tide_start_hour', 'tide_start_date',
            'tide_end_hour', 'tide_end_date', 'tide_duration', 'tide_max_hour',
            'tide_max_date', 'max_height', 'max_flow[m^3/sec]', 'tide_vol[MCM]'
        ]
        df = df[~df.hydro_year.isnull()]
        df['id'] = df['id'].astype(int)
        df['tide_start'] = pd.to_datetime(
            df['tide_start_date'], dayfirst=True) + pd.to_timedelta(
                df['tide_start_hour'].add(':00'), unit='m', errors='coerce')
        df['tide_end'] = pd.to_datetime(
            df['tide_end_date'], dayfirst=True) + pd.to_timedelta(
                df['tide_end_hour'].add(':00'), unit='m', errors='coerce')
        df['tide_max'] = pd.to_datetime(
            df['tide_max_date'], dayfirst=True) + pd.to_timedelta(
                df['tide_max_hour'].add(':00'), unit='m', errors='coerce')
        df['tide_duration'] = pd.to_timedelta(df['tide_duration'] + ':00',
                                              unit='m',
                                              errors='coerce')
        # boolean-mask .loc assignment avoids chained-indexing pitfalls:
        df.loc[df['max_flow[m^3/sec]'].str.contains('<', na=False),
               'max_flow[m^3/sec]'] = 0
        df.loc[df['tide_vol[MCM]'].str.contains('<', na=False),
               'tide_vol[MCM]'] = 0
        df['max_flow[m^3/sec]'] = df['max_flow[m^3/sec]'].astype(float)
        df['tide_vol[MCM]'] = df['tide_vol[MCM]'].astype(float)
        to_drop = [
            'tide_start_hour', 'tide_start_date', 'tide_end_hour',
            'tide_end_date', 'tide_max_hour', 'tide_max_date'
        ]
        df = df.drop(to_drop, axis=1)
        df_list.append(df)
    df = pd.concat(df_list)
    dfs = [x for _, x in df.groupby('id')]
    ds_list = []
    meta_df = read_hydro_metadata(path, gis_path, False)
    for df in dfs:
        st_id = df['id'].iloc[0]
        st_name = df['name'].iloc[0]
        print('processing station number: {}, {}'.format(st_id, st_name))
        meta = meta_df[meta_df['id'] == st_id]
        ds = xr.Dataset()
        df.set_index('tide_start', inplace=True)
        attrs = {}
        attrs['station_name'] = st_name
        if not meta.empty:
            attrs['lon'] = meta.lon.values.item()
            attrs['lat'] = meta.lat.values.item()
            attrs['alt'] = meta.alt.values.item()
            attrs['drainage_basin_area'] = meta.area.values.item()
            attrs['active'] = meta.active.values.item()
        attrs['units'] = 'm'
        max_height = df['max_height'].to_xarray()
        max_height.name = 'TS_{}_max_height'.format(st_id)
        # copy attrs so later units changes do not leak between arrays:
        max_height.attrs = dict(attrs)
        max_flow = df['max_flow[m^3/sec]'].to_xarray()
        max_flow.name = 'TS_{}_max_flow'.format(st_id)
        attrs['units'] = 'm^3/sec'
        max_flow.attrs = dict(attrs)
        attrs['units'] = 'MCM'
        tide_vol = df['tide_vol[MCM]'].to_xarray()
        tide_vol.name = 'TS_{}_tide_vol'.format(st_id)
        tide_vol.attrs = dict(attrs)
        attrs.pop('units')
        #        tide_start = df['tide_start'].to_xarray()
        #        tide_start.name = 'TS_{}_tide_start'.format(st_id)
        #        tide_start.attrs = attrs
        tide_end = df['tide_end'].to_xarray()
        tide_end.name = 'TS_{}_tide_end'.format(st_id)
        tide_end.attrs = dict(attrs)
        tide_max = df['tide_max'].to_xarray()
        tide_max.name = 'TS_{}_tide_max'.format(st_id)
        tide_max.attrs = dict(attrs)
        ds['{}'.format(max_height.name)] = max_height
        ds['{}'.format(max_flow.name)] = max_flow
        ds['{}'.format(tide_vol.name)] = tide_vol
        #         ds['{}'.format(tide_start.name)] = tide_start
        ds['{}'.format(tide_end.name)] = tide_end
        ds['{}'.format(tide_max.name)] = tide_max
        ds_list.append(ds)
    dsu = [get_unique_index(x, dim='tide_start') for x in ds_list]
    print('merging...')
    ds = xr.merge(dsu)
    filename = 'hydro_tides.nc'
    print('saving {} to {}'.format(filename, path))
    comp = dict(zlib=True, complevel=9)  # best compression
    encoding = {var: comp for var in ds.data_vars}
    ds.to_netcdf(path / filename, 'w', encoding=encoding)
    print('Done!')
    return ds
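
A usage sketch, assuming hydro_path holds the tide_report*.xlsx exports and the hydro metadata is readable as above:

tides_ds = read_tides()
max_flow_vars = [v for v in tides_ds.data_vars if v.endswith('_max_flow')]
print('{} tide-event stations loaded'.format(len(max_flow_vars)))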
Example #11
def plot_figure_4(physical_file=phys_soundings,
                  model='LR',
                  times=['2007', '2019'],
                  save=True):
    """plot ts-tm relashonship"""
    import xarray as xr
    import matplotlib.pyplot as plt
    import seaborn as sns
    from PW_stations import ML_Switcher
    from sklearn.metrics import mean_squared_error
    from sklearn.metrics import r2_score
    from aux_gps import get_unique_index
    import numpy as np
    from mpl_toolkits.axes_grid1.inset_locator import inset_axes
    # sns.set_style('whitegrid')
    pds = xr.open_dataset(physical_file)
    pds = pds[['Tm', 'Ts']]
    pds = get_unique_index(pds, 'sound_time')
    pds = pds.sel(sound_time=slice(*times))
    fig, ax = plt.subplots(1, 1, figsize=(7, 7))
    pds.plot.scatter(x='Ts',
                     y='Tm',
                     marker='.',
                     s=100.,
                     linewidth=0,
                     alpha=0.5,
                     ax=ax)
    ax.grid()
    ml = ML_Switcher()
    fit_model = ml.pick_model(model)
    X = pds.Ts.values.reshape(-1, 1)
    y = pds.Tm.values
    fit_model.fit(X, y)
    predict = fit_model.predict(X)
    coef = fit_model.coef_[0]
    inter = fit_model.intercept_
    ax.plot(X, predict, c='r')
    bevis_tm = pds.Ts.values * 0.72 + 70.0
    ax.plot(pds.Ts.values, bevis_tm, c='purple')
    ax.legend([
        'OLS ({:.2f}, {:.2f})'.format(coef, inter),
        'Bevis et al. 1992 (0.72, 70.0)'
    ])
    #    ax.set_xlabel('Surface Temperature [K]')
    #    ax.set_ylabel('Water Vapor Mean Atmospheric Temperature [K]')
    ax.set_xlabel('Ts [K]')
    ax.set_ylabel('Tm [K]')
    ax.set_ylim(265, 320)
    axin1 = inset_axes(ax, width="40%", height="40%", loc=2)
    resid = predict - y
    sns.distplot(resid,
                 bins=50,
                 color='k',
                 label='residuals',
                 ax=axin1,
                 kde=False,
                 hist_kws={
                     "linewidth": 1,
                     "alpha": 0.5,
                     "color": "k"
                 })
    axin1.yaxis.tick_right()
    rmean = np.mean(resid)
    rmse = np.sqrt(mean_squared_error(y, predict))
    r2 = r2_score(y, predict)
    axin1.axvline(rmean, color='r', linestyle='dashed', linewidth=1)
    # axin1.set_xlabel('Residual distribution[K]')
    textstr = '\n'.join([
        'n={}'.format(pds.Ts.size), 'RMSE: ', '{:.2f} K'.format(rmse),
        r'R$^2$: {:.2f}'.format(r2)
    ])
    props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
    axin1.text(0.05,
               0.95,
               textstr,
               transform=axin1.transAxes,
               fontsize=10,
               verticalalignment='top',
               bbox=props)
    #    axin1.text(0.2, 0.9, 'n={}'.format(pds.Ts.size),
    #               verticalalignment='top', horizontalalignment='center',
    #               transform=axin1.transAxes, color='k', fontsize=10)
    #    axin1.text(0.78, 0.9, 'RMSE: {:.2f} K'.format(rmse),
    #               verticalalignment='top', horizontalalignment='center',
    #               transform=axin1.transAxes, color='k', fontsize=10)
    axin1.set_xlim(-15, 15)
    fig.tight_layout()
    filename = 'Bet_dagan_ts_tm_fit.png'
    caption(
        'Water vapor mean temperature (Tm) vs. surface temperature (Ts) of the Bet-Dagan radiosonde station. The ordinary least squares linear fit (red) yields a residual distribution with an RMSE of 4 K. The Bevis et al. (1992) model is plotted (purple) for comparison.'
    )
    if save:
        plt.savefig(savefig_path / filename, bbox_inches='tight')
    return
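
A usage sketch, assuming phys_soundings, savefig_path and the caption helper are defined in the module and the sounding file holds Ts and Tm on a sound_time dimension; save=False skips writing the figure:

plot_figure_4(times=['2007', '2019'], model='LR', save=False)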