def get_wrf_pw_at_dsea_gnss_coord(path=des_path, work_path=work_yuval, point=None):
    """Extract the WRF PWV time series at the dsea GNSS coordinate
    (or at a user-supplied (lat, lon) point)."""
    from PW_stations import produce_geo_gnss_solved_stations
    import xarray as xr
    from aux_gps import get_nearest_lat_lon_for_xy
    from aux_gps import path_glob
    from aux_gps import get_unique_index
    df = produce_geo_gnss_solved_stations(path=work_path / 'gis', plot=False)
    dsea_point = df.loc['dsea'][['lat', 'lon']].astype(float).values
    files = path_glob(path, 'pw_wrfout*.nc')
    pw_list = []
    for file in files:
        pw_all = xr.load_dataset(file)
        freq = xr.infer_freq(pw_all['Time'])
        print(freq)
        if point is not None:
            print('looking for {} at wrf.'.format(point))
            dsea_point = point
        loc = get_nearest_lat_lon_for_xy(pw_all['XLAT'], pw_all['XLONG'],
                                         dsea_point)
        print(loc)
        pw = pw_all.isel(south_north=loc[0][0], west_east=loc[0][1])
        pw_list.append(pw)
    pw_ts = xr.concat(pw_list, 'Time')
    pw_ts = get_unique_index(pw_ts, dim='Time')
    return pw_ts
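
# Illustrative sketch (not part of the original module): a plain-numpy version
# of the nearest-grid-cell lookup that aux_gps.get_nearest_lat_lon_for_xy is
# assumed to perform on the 2-D WRF XLAT/XLONG fields. The helper name and its
# exact return format (a list of index pairs) are assumptions.
def _example_nearest_grid_point(lat2d, lon2d, point):
    """Return the (row, col) index of the grid cell closest to point=(lat, lon)."""
    import numpy as np
    lat2d = np.asarray(lat2d)
    lon2d = np.asarray(lon2d)
    # squared lat/lon distance (no great-circle correction) is enough to pick
    # the nearest cell on a fine grid:
    dist2 = (lat2d - point[0]) ** 2 + (lon2d - point[1]) ** 2
    return np.unravel_index(np.argmin(dist2), dist2.shape)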
def replace_fields_in_ds(dss, da_repl, field='WetZ', verbose=None):
    """Replace the overlapping field segments of dss (and then some) with the
    stitched signal from da_repl. Be careful with the choices for field."""
    from aux_gps import get_unique_index
    import xarray as xr
    import logging
    logger = logging.getLogger('gipsyx_post_proccesser')
    if verbose == 0:
        logger.info('replacing {} field.'.format(field))
    # choose the field from the bigger dss:
    nums = sorted(
        list(
            set([
                int(x.split('-')[1]) for x in dss
                if x.split('-')[0] == field
            ])))
    ds = dss[['{}-{}'.format(field, i) for i in nums]]
    da_list = []
    for i, _ in enumerate(ds):
        if i == len(ds) - 1:
            break
        first = ds['{}-{}'.format(field, i)]
        time0 = list(set(first.dims))[0]
        second = ds['{}-{}'.format(field, i + 1)]
        time1 = list(set(second.dims))[0]
        try:
            min_time = first.dropna(time0)[time0].min()
            max_time = second.dropna(time1)[time1].max()
        except ValueError:
            if verbose == 1:
                logger.warning('item {}, {} - {} is lonely'.format(
                    field, i, i + 1))
            continue
        try:
            da = da_repl.sel(time=slice(min_time, max_time))
        except KeyError:
            if verbose == 1:
                logger.warning('item {}, {} - {} is lonely'.format(
                    field, i, i + 1))
            continue
        if verbose == 1:
            logger.info('processing {} and {}'.format(first.name, second.name))
        # utime = dim_union([first, second], 'time')
        first_time = set(first.dropna(time0)[time0].values).difference(
            set(da.time.values))
        second_time = set(second.dropna(time1)[time1].values).difference(
            set(da.time.values))
        first = first.sel({time0: list(first_time)})
        second = second.sel({time1: list(second_time)})
        first = first.rename({time0: 'time'})
        second = second.rename({time1: 'time'})
        da_list.append(xr.concat([first, da, second], 'time'))
    da_final = xr.concat(da_list, 'time')
    da_final = da_final.sortby('time')
    da_final.name = field
    da_final.attrs = da_repl.attrs
    da_final = get_unique_index(da_final, 'time')
    return da_final
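
# Illustrative sketch (not part of the original module): dropping duplicate
# time stamps, which is what aux_gps.get_unique_index is assumed to do in the
# functions above and below. Only the first occurrence per stamp is kept.
def _example_drop_duplicate_times(da, dim='time'):
    """Keep only the first occurrence of each value along dim."""
    import numpy as np
    _, idx = np.unique(da[dim], return_index=True)
    return da.isel({dim: idx})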
def stitch_yearly_files(ds_list):
    """Input is a list of yearly multi-field datasets; output is the same list
    but with the year-boundary discontinuities stitched."""
    fields = [x for x in ds_list[0].data_vars]
    for i, dss in enumerate(ds_list):
        if i == len(ds_list) - 1:
            break
        first_year = int(ds_list[i].time.dt.year.median().item())
        second_year = int(ds_list[i + 1].time.dt.year.median().item())
        first_ds = ds_list[i].sel(time=slice(
            '{}-12-31T18:00'.format(first_year), str(second_year)))
        second_ds = ds_list[i + 1].sel(time=slice(
            str(first_year), '{}-01-01T06:00'.format(second_year)))
        if dim_intersection([first_ds, second_ds], 'time') is None:
            logger.warning('skipping stitching years {} and {}...'.format(
                first_year, second_year))
            continue
        else:
            logger.info('stitching years {} and {}'.format(
                first_year, second_year))
        time = xr.concat([first_ds.time, second_ds.time], 'time')
        time = pd.to_datetime(get_unique_index(time).values)
        st_list = []
        for field in fields:
            df = first_ds[field].to_dataframe()
            df.columns = ['first']
            df = df.reindex(time)
            df['second'] = second_ds[field].to_dataframe()
            if field in ['X', 'Y', 'Z']:
                method = 'simple_mean'
            elif field in ['GradNorth', 'GradEast', 'WetZ']:
                method = 'smooth_mean'
            elif 'error' in field:
                method = 'error_mean'
            dfs = stitch_two_cols(df, method=method)['stitched_signal']
            dfs.index.name = 'time'
            st = dfs.to_xarray()
            st.name = field
            st_list.append(st)
        # merge to all fields:
        st_ds = xr.merge(st_list)
        # replace stitched values to first ds and second ds:
        first_time = dim_intersection([ds_list[i], st_ds])
        vals_rpl = st_ds.sel(time=first_time)
        for field in ds_list[i].data_vars:
            ds_list[i][field].loc[{'time': first_time}] = vals_rpl[field]
        second_time = dim_intersection([ds_list[i + 1], st_ds])
        vals_rpl = st_ds.sel(time=second_time)
        for field in ds_list[i + 1].data_vars:
            ds_list[i + 1][field].loc[{
                'time': second_time
            }] = vals_rpl[field]
    return ds_list
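
# Illustrative sketch (not part of the original module): one plausible
# 'smooth_mean' blend of the two overlapping yearly columns built above, i.e.
# a linear weight ramp across the overlap so the stitched signal transitions
# from the first year into the second. stitch_two_cols lives elsewhere in the
# project; this is only a guess at its behaviour, not its implementation.
def _example_smooth_mean_stitch(df):
    """df has columns 'first' and 'second' on a common time index; returns a
    blended Series analogous to the 'stitched_signal' column."""
    import numpy as np
    w = np.linspace(0.0, 1.0, len(df))   # weight 0 -> 'first', 1 -> 'second'
    stitched = (1.0 - w) * df['first'] + w * df['second']
    # outside the overlap one column is NaN; fall back to whichever exists:
    return stitched.fillna(df['first']).fillna(df['second'])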
def assemble_WRF_pwv(path=des_path, work_path=work_yuval, radius=1):
    from PW_stations import produce_geo_gnss_solved_stations
    import xarray as xr
    from aux_gps import save_ncfile
    from aux_gps import get_nearest_lat_lon_for_xy
    from aux_gps import get_unique_index
    df = produce_geo_gnss_solved_stations(path=work_path / 'gis', plot=False)
    dsea_point = df.loc['dsea'][['lat', 'lon']].astype(float).values
    if radius is not None:
        point = None
    else:
        point = dsea_point
    wrf_pw = read_all_WRF_GNSS_files(path, var='pw', point=point)
    wrf_pw8 = xr.load_dataarray(
        path / 'pw_wrfout_d04_2014-08-08_40lev.nc').sel(Time='2014-08-08')
    wrf_pw16 = xr.load_dataarray(
        path / 'pw_wrfout_d04_2014-08-16_40lev.nc').sel(Time='2014-08-16')
    wrf_pw_8_16 = xr.concat([wrf_pw8, wrf_pw16], 'Time')
    print('looking for {} at wrf.'.format(dsea_point))
    loc = get_nearest_lat_lon_for_xy(wrf_pw_8_16['XLAT'], wrf_pw_8_16['XLONG'],
                                     dsea_point)
    print(loc)
    if radius is not None:
        print('getting {} radius around {}.'.format(radius, dsea_point))
        lat_islice = [loc[0][0] - radius, loc[0][0] + radius + 1]
        lon_islice = [loc[0][1] - radius, loc[0][1] + radius + 1]
        wrf_pw_8_16 = wrf_pw_8_16.isel(south_north=slice(*lat_islice),
                                       west_east=slice(*lon_islice))
        loc = get_nearest_lat_lon_for_xy(wrf_pw['XLAT'], wrf_pw['XLONG'],
                                         dsea_point)
        lat_islice = [loc[0][0] - radius, loc[0][0] + radius + 1]
        lon_islice = [loc[0][1] - radius, loc[0][1] + radius + 1]
        wrf_pw = wrf_pw.isel(south_north=slice(*lat_islice),
                             west_east=slice(*lon_islice))
    else:
        wrf_pw_8_16 = wrf_pw_8_16.isel(south_north=loc[0][0],
                                       west_east=loc[0][1])
    wrf_pw = xr.concat([wrf_pw, wrf_pw_8_16], 'Time')
    wrf_pw = wrf_pw.rename({'Time': 'time'})
    wrf_pw = wrf_pw.sortby('time')
    wrf_pw = get_unique_index(wrf_pw)
    if wrf_pw.attrs['projection'] is not None:
        wrf_pw.attrs['projection'] = wrf_pw.attrs['projection'].proj4()
    if radius is not None:
        filename = 'pwv_wrf_dsea_gnss_radius_{}_2014-08.nc'.format(radius)
    else:
        filename = 'pwv_wrf_dsea_gnss_point_2014-08.nc'
    save_ncfile(wrf_pw, des_path, filename)
    return wrf_pw
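
# Illustrative sketch (not part of the original module): selecting the
# (2*radius+1) x (2*radius+1) box of WRF grid cells around a centre index,
# clipped at the grid edges. The isel calls above assume the point is far
# enough from the domain boundary that no clipping is needed.
def _example_radius_box(da, row, col, radius=1):
    lat_slice = slice(max(row - radius, 0),
                      min(row + radius + 1, da.sizes['south_north']))
    lon_slice = slice(max(col - radius, 0),
                      min(col + radius + 1, da.sizes['west_east']))
    return da.isel(south_north=lat_slice, west_east=lon_slice)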
def read_surface_pressure(path=des_path, dem_path=work_yuval / 'AW3D30'):
    """Surface pressure taken from the Ein Gedi spa station
    (31.417313616189308, 35.378962961491474)."""
    import pandas as pd
    from aux_gps import path_glob
    from aux_gps import get_unique_index
    import xarray as xr
    awd = xr.open_rasterio(dem_path / 'israel_dem.tif')
    awd = awd.squeeze(drop=True)
    alt = awd.sel(x=35.3789, y=31.4173, method='nearest').item()
    file = path_glob(path, 'EBS1_*_pressure.txt')[0]
    df = pd.read_csv(file)
    df['Time'] = pd.to_datetime(df['Time'], format='%d-%b-%Y %H:%M:%S')
    df = df.set_index('Time')
    df.index.name = 'time'
    da = df.to_xarray()['Press']
    da = get_unique_index(da)
    da.attrs['station_alt'] = alt
    da.attrs['lat'] = 31.4173
    da.attrs['lon'] = 35.3789
    return da
def produce_pwv_from_dsea_axis_station(path=axis_path, ims_path=ims_path):
    """Use axis_path = work_yuval/dsea_gispyx for the original SOI-APN dsea
    station."""
    import xarray as xr
    from aux_gps import transform_ds_to_lat_lon_alt
    from aux_gps import get_unique_index
    ds = xr.load_dataset(path / 'smoothFinal_2014.nc').squeeze()
    ds = get_unique_index(ds)
    # for now cut:
    if 'axis' in path.as_posix():
        ds = ds.sel(time=slice(None, '2014-08-12'))
    ds = transform_ds_to_lat_lon_alt(ds)
    axis_zwd = ds['WetZ']
    ts = xr.open_dataset(ims_path / 'IMS_TD_israeli_10mins.nc')['SEDOM']
    axis_pwv = produce_pwv_from_zwd_with_ts_tm_from_deserve(ts=ts,
                                                            zwd=axis_zwd)
    if 'axis' in path.as_posix():
        axis_pwv.name = 'AXIS-DSEA'
    else:
        axis_pwv.name = 'SOI-DSEA'
    axis_pwv.attrs['lat'] = ds['lat'].values[0]
    axis_pwv.attrs['lon'] = ds['lon'].values[0]
    axis_pwv.attrs['alt'] = ds['alt'].values[0]
    return axis_pwv
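
# Illustrative sketch (not part of the original module): the standard ZWD->PWV
# conversion with the Bevis et al. (1992) Ts->Tm proxy and textbook
# refractivity constants. produce_pwv_from_zwd_with_ts_tm_from_deserve is
# assumed to do something similar but with a DESERVE-campaign Ts-Tm relation;
# the constants below are not necessarily the ones it uses.
def _example_zwd_to_pwv(zwd_cm, ts_k):
    """zwd_cm: zenith wet delay [cm]; ts_k: surface temperature [K];
    returns PWV [mm]."""
    k2_prime = 22.1          # K hPa^-1
    k3 = 3.739e5             # K^2 hPa^-1
    rho_w = 1000.0           # liquid water density [kg m^-3]
    r_v = 461.5              # water vapor gas constant [J kg^-1 K^-1]
    tm = 0.72 * ts_k + 70.0  # Bevis Ts -> Tm proxy [K]
    # dimensionless conversion factor, roughly 0.15-0.16:
    pi_factor = 1e8 / (rho_w * r_v * (k3 / tm + k2_prime))
    return pi_factor * zwd_cm * 10.0   # cm of water -> mm of PWV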
def read_gipsyx_all_yearly_files(load_path, savepath=None, iqr_k=3.0,
                                 plot=False):
    """Read, stitch and clean all yearly post-processed PPP gipsyx solutions
    and concat them to a multi-field time-series dataset."""
    from aux_gps import path_glob
    import xarray as xr
    from aux_gps import get_unique_index
    from aux_gps import dim_intersection
    import pandas as pd
    from aux_gps import filter_nan_errors
    from aux_gps import keep_iqr
    from aux_gps import xr_reindex_with_date_range
    from aux_gps import transform_ds_to_lat_lon_alt
    import logging

    def stitch_yearly_files(ds_list):
        """Input is a list of yearly multi-field datasets; output is the same
        list but with the year-boundary discontinuities stitched."""
        fields = [x for x in ds_list[0].data_vars]
        for i, dss in enumerate(ds_list):
            if i == len(ds_list) - 1:
                break
            first_year = int(ds_list[i].time.dt.year.median().item())
            second_year = int(ds_list[i + 1].time.dt.year.median().item())
            first_ds = ds_list[i].sel(time=slice(
                '{}-12-31T18:00'.format(first_year), str(second_year)))
            second_ds = ds_list[i + 1].sel(time=slice(
                str(first_year), '{}-01-01T06:00'.format(second_year)))
            if dim_intersection([first_ds, second_ds], 'time') is None:
                logger.warning('skipping stitching years {} and {}...'.format(
                    first_year, second_year))
                continue
            else:
                logger.info('stitching years {} and {}'.format(
                    first_year, second_year))
            time = xr.concat([first_ds.time, second_ds.time], 'time')
            time = pd.to_datetime(get_unique_index(time).values)
            st_list = []
            for field in fields:
                df = first_ds[field].to_dataframe()
                df.columns = ['first']
                df = df.reindex(time)
                df['second'] = second_ds[field].to_dataframe()
                if field in ['X', 'Y', 'Z']:
                    method = 'simple_mean'
                elif field in ['GradNorth', 'GradEast', 'WetZ']:
                    method = 'smooth_mean'
                elif 'error' in field:
                    method = 'error_mean'
                dfs = stitch_two_cols(df, method=method)['stitched_signal']
                dfs.index.name = 'time'
                st = dfs.to_xarray()
                st.name = field
                st_list.append(st)
            # merge to all fields:
            st_ds = xr.merge(st_list)
            # replace stitched values to first ds and second ds:
            first_time = dim_intersection([ds_list[i], st_ds])
            vals_rpl = st_ds.sel(time=first_time)
            for field in ds_list[i].data_vars:
                ds_list[i][field].loc[{'time': first_time}] = vals_rpl[field]
            second_time = dim_intersection([ds_list[i + 1], st_ds])
            vals_rpl = st_ds.sel(time=second_time)
            for field in ds_list[i + 1].data_vars:
                ds_list[i + 1][field].loc[{
                    'time': second_time
                }] = vals_rpl[field]
        return ds_list

    logger = logging.getLogger('gipsyx_post_proccesser')
    files = sorted(path_glob(load_path, '*.nc'))
    ds_list = []
    for file in files:
        filename = file.as_posix().split('/')[-1]
        station = file.as_posix().split('/')[-1].split('_')[0]
        if 'ppp_post' not in filename:
            continue
        logger.info('reading {}'.format(filename))
        dss = xr.open_dataset(file)
        ds_list.append(dss)
    # now loop over ds_list and stitch yearly discontinuities:
    ds_list = stitch_yearly_files(ds_list)
    logger.info('merging all years...')
    ds = xr.merge(ds_list)
    logger.info('fixing meta-data...')
    for da in ds.data_vars:
        old_keys = [x for x in ds[da].attrs.keys()]
        vals = [x for x in ds[da].attrs.values()]
        new_keys = [x.split('>')[-1] for x in old_keys]
        ds[da].attrs = dict(zip(new_keys, vals))
        if 'desc' in ds[da].attrs.keys():
            ds[da].attrs['full_name'] = ds[da].attrs.pop('desc')
    logger.info('dropping duplicate time stamps...')
    ds = get_unique_index(ds)
    # clean with IQR all fields:
    logger.info('removing outliers with IQR of {}...'.format(iqr_k))
    ds = keep_iqr(ds, dim='time', qlow=0.25, qhigh=0.75, k=iqr_k)
    # filter the fields based on their errors not being NaNs:
    logger.info('filtering out fields if their errors are NaN...')
    ds = filter_nan_errors(ds, error_str='_error', dim='time')
    logger.info('transforming X, Y, Z coords to lat, lon and alt...')
    ds = transform_ds_to_lat_lon_alt(ds, ['X', 'Y', 'Z'], '_error', 'time')
    logger.info(
        'reindexing fields with 5 mins frequency (i.e., inserting NaNs)')
    ds = xr_reindex_with_date_range(ds, 'time', '5min')
    ds.attrs['station'] = station
    if plot:
        plot_gipsy_field(ds, None)
    if savepath is not None:
        comp = dict(zlib=True, complevel=9)  # best compression
        encoding = {var: comp for var in ds.data_vars}
        ymin = ds.time.min().dt.year.item()
        ymax = ds.time.max().dt.year.item()
        new_filename = '{}_PPP_{}-{}.nc'.format(station, ymin, ymax)
        ds.to_netcdf(savepath / new_filename, 'w', encoding=encoding)
        logger.info('{} was saved to {}'.format(new_filename, savepath))
    logger.info('Done!')
    return ds
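
# Illustrative sketch (not part of the original module): an interquartile-range
# outlier mask along the time dimension, which is what aux_gps.keep_iqr with
# qlow=0.25, qhigh=0.75 and factor k is assumed to do (values outside
# [Q1 - k*IQR, Q3 + k*IQR] become NaN).
def _example_keep_iqr(da, dim='time', k=3.0):
    q1 = da.quantile(0.25, dim=dim).drop_vars('quantile')
    q3 = da.quantile(0.75, dim=dim).drop_vars('quantile')
    iqr = q3 - q1
    return da.where((da >= q1 - k * iqr) & (da <= q3 + k * iqr))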
def produce_seasonal_trend_breakdown_time_series_from_jpl_gipsyx_site(
        station='bshm', path=jpl_path, var='V', k=2, verbose=True, plot=True):
    import xarray as xr
    from aux_gps import harmonic_da_ts
    from aux_gps import loess_curve
    from aux_gps import keep_iqr
    from aux_gps import get_unique_index
    from aux_gps import xr_reindex_with_date_range
    from aux_gps import decimal_year_to_datetime
    import matplotlib.pyplot as plt
    if verbose:
        print('producing seasonal time series for {} station {}'.format(
            station, var))
    ds = read_time_series_jpl_gipsyx_site(station=station,
                                          path=path / 'time_series',
                                          verbose=verbose)
    # dyear = ds['decimal_year']
    da_ts = ds[var]
    da_ts = xr_reindex_with_date_range(get_unique_index(da_ts), freq='D')
    xr.infer_freq(da_ts['time'])
    if k is not None:
        da_ts = keep_iqr(da_ts, k=k)
    da_ts.name = '{}_{}'.format(station, var)
    # detrend:
    trend = loess_curve(da_ts, plot=False)['mean']
    trend.name = da_ts.name + '_trend'
    trend = xr_reindex_with_date_range(trend, freq='D')
    da_ts_detrended = da_ts - trend
    if verbose:
        print('detrended by loess.')
    da_ts_detrended.name = da_ts.name + '_detrended'
    # harmonic cpy fits:
    harm = harmonic_da_ts(da_ts_detrended.dropna('time'), n=2, grp='month',
                          return_ts_fit=True, verbose=verbose)
    harm = xr_reindex_with_date_range(harm, time_dim='time', freq='D')
    harm1 = harm.sel(cpy=1).reset_coords(drop=True)
    harm1.name = da_ts.name + '_annual'
    harm1_keys = [x for x in harm1.attrs.keys() if '_1' in x]
    harm1.attrs = dict(zip(harm1_keys, [harm1.attrs[x] for x in harm1_keys]))
    harm2 = harm.sel(cpy=2).reset_coords(drop=True)
    harm2.name = da_ts.name + '_semiannual'
    harm2_keys = [x for x in harm2.attrs.keys() if '_2' in x]
    harm2.attrs = dict(zip(harm2_keys, [harm2.attrs[x] for x in harm2_keys]))
    resid = da_ts_detrended - harm1 - harm2
    resid.name = da_ts.name + '_residual'
    ds = xr.merge([da_ts, trend, harm1, harm2, resid])
    # load breakpoints:
    try:
        breakpoints = xr.open_dataset(
            jpl_path / 'jpl_break_estimates.nc').sel(
                station=station.upper())[var]
        df = breakpoints.dropna('year')['year'].to_dataframe()
        # load seasonal coeffs:
        df['dt'] = df['year'].apply(decimal_year_to_datetime)
        df['dt'] = df['dt'].round('D')
        bp_da = df.set_index(df['dt'])['dt'].to_xarray()
        bp_da = bp_da.rename({'dt': 'time'})
        ds['{}_{}_breakpoints'.format(station, var)] = bp_da
        no_bp = False
    except KeyError:
        if verbose:
            print('no breakpoints found for {}!'.format(station))
        no_bp = True
    # seas = xr.load_dataset(
    #     jpl_path/'jpl_seasonal_estimates.nc').sel(station=station.upper())
    # ac1, as1, ac2, as2 = seas[var].values
    # # build seasonal time series:
    # annual = xr.DataArray(ac1*np.cos(dyear*2*np.pi)+as1 *
    #                       np.sin(dyear*2*np.pi), dims=['time'])
    # annual['time'] = da_ts['time']
    # annual.name = '{}_{}_annual'.format(station, var)
    # annual.attrs['units'] = 'mm'
    # annual.attrs['long_name'] = 'annual mode'
    # semiannual = xr.DataArray(ac2*np.cos(dyear*4*np.pi)+as2 *
    #                           np.sin(dyear*4*np.pi), dims=['time'])
    # semiannual['time'] = da_ts['time']
    # semiannual.name = '{}_{}_semiannual'.format(station, var)
    # semiannual.attrs['units'] = 'mm'
    # semiannual.attrs['long_name'] = 'semiannual mode'
    # ds = xr.merge([annual, semiannual, da_ts])
    if plot:
        # plt.figure(figsize=(20, 20))
        dst = ds[[x for x in ds if 'breakpoints' not in x]]
        axes = dst.to_dataframe().plot(subplots=True, figsize=(20, 20),
                                       color='k')
        [ax.grid() for ax in axes]
        [ax.set_ylabel('[mm]') for ax in axes]
        if not no_bp:
            for bp in df['dt']:
                [ax.axvline(bp, color='red') for ax in axes]
        plt.tight_layout()
        fig, ax = plt.subplots(figsize=(7, 7))
        harm_mm = harmonic_da_ts(da_ts_detrended.dropna('time'), n=2,
                                 grp='month', return_ts_fit=False,
                                 verbose=verbose)
        harm_mm['{}_{}_detrended'.format(station, var)].plot.line(
            ax=ax, linewidth=0, marker='o', color='k')
        harm_mm['{}_mean'.format(station)].sel(cpy=1).plot.line(
            ax=ax, marker=None, color='tab:red')
        harm_mm['{}_mean'.format(station)].sel(cpy=2).plot.line(
            ax=ax, marker=None, color='tab:blue')
        harm_mm['{}_mean'.format(station)].sum('cpy').plot.line(
            ax=ax, marker=None, color='tab:purple')
        ax.grid()
    return ds
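
# Illustrative sketch (not part of the original module): a plain least-squares
# fit of annual (1 cpy) and semiannual (2 cpy) harmonics to a detrended daily
# series, roughly what aux_gps.harmonic_da_ts(n=2) is assumed to estimate.
def _example_annual_semiannual_fit(t_days, y):
    """t_days: float days since an arbitrary epoch; y: detrended values
    (no NaNs). Returns the fitted series and the coefficient vector
    [offset, cos1, sin1, cos2, sin2]."""
    import numpy as np
    w = 2.0 * np.pi / 365.25
    X = np.column_stack([np.ones_like(t_days),
                         np.cos(w * t_days), np.sin(w * t_days),
                         np.cos(2 * w * t_days), np.sin(2 * w * t_days)])
    coef, *_ = np.linalg.lstsq(X, y, rcond=None)
    return X @ coef, coef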
def read_hydrographs(path=hydro_path):
    from aux_gps import path_glob
    import pandas as pd
    import xarray as xr
    from aux_gps import get_unique_index
    files = path_glob(path, 'hydro_graph*.txt')
    df_list = []
    for file in files:
        print(file)
        df = pd.read_csv(file, header=0, sep=',')
        df.columns = [
            'id', 'name', 'time', 'tide_height[m]', 'flow[m^3/sec]',
            'data_type', 'flow_type', 'record_type', 'record_code'
        ]
        df['time'] = pd.to_datetime(df['time'], dayfirst=True)
        df['tide_height[m]'] = df['tide_height[m]'].astype(float)
        df['flow[m^3/sec]'] = df['flow[m^3/sec]'].astype(float)
        # translate the Hebrew category labels (avoid chained assignment):
        df.loc[df['data_type'].str.contains('מדודים', na=False),
               'data_type'] = 'measured'
        df.loc[df['data_type'].str.contains('משוחזרים', na=False),
               'data_type'] = 'reconstructed'
        df.loc[df['flow_type'].str.contains('תקין', na=False),
               'flow_type'] = 'normal'
        df.loc[df['flow_type'].str.contains('גאות', na=False),
               'flow_type'] = 'tide'
        df.loc[df['record_type'].str.contains('נקודה פנימית', na=False),
               'record_type'] = 'inner_point'
        df.loc[df['record_type'].str.contains('התחלת קטע', na=False),
               'record_type'] = 'section_begining'
        df.loc[df['record_type'].str.contains('סיום קטע', na=False),
               'record_type'] = 'section_ending'
        df_list.append(df)
    df = pd.concat(df_list)
    dfs = [x for _, x in df.groupby('id')]
    ds_list = []
    meta_df = read_hydro_metadata(path, gis_path, False)
    for df in dfs:
        st_id = df['id'].iloc[0]
        st_name = df['name'].iloc[0]
        print('processing station number: {}, {}'.format(st_id, st_name))
        meta = meta_df[meta_df['id'] == st_id]
        ds = xr.Dataset()
        df.set_index('time', inplace=True)
        attrs = {}
        attrs['station_name'] = st_name
        if not meta.empty:
            attrs['lon'] = meta.lon.values.item()
            attrs['lat'] = meta.lat.values.item()
            attrs['alt'] = meta.alt.values.item()
            attrs['drainage_basin_area'] = meta.area.values.item()
            attrs['active'] = meta.active.values.item()
        attrs['units'] = 'm'
        tide_height = df['tide_height[m]'].to_xarray()
        tide_height.name = 'HS_{}_tide_height'.format(st_id)
        tide_height.attrs = attrs
        flow = df['flow[m^3/sec]'].to_xarray()
        flow.name = 'HS_{}_flow'.format(st_id)
        attrs['units'] = 'm^3/sec'
        flow.attrs = attrs
        ds['{}'.format(tide_height.name)] = tide_height
        ds['{}'.format(flow.name)] = flow
        ds_list.append(ds)
    dsu = [get_unique_index(x) for x in ds_list]
    print('merging...')
    ds = xr.merge(dsu)
    filename = 'hydro_graphs.nc'
    print('saving {} to {}'.format(filename, path))
    comp = dict(zlib=True, complevel=9)  # best compression
    encoding = {var: comp for var in ds.data_vars}
    ds.to_netcdf(path / filename, 'w', encoding=encoding)
    print('Done!')
    return ds
def read_tides(path=hydro_path):
    from aux_gps import path_glob
    import pandas as pd
    import xarray as xr
    from aux_gps import get_unique_index
    files = path_glob(path, 'tide_report*.xlsx')
    df_list = []
    for file in files:
        df = pd.read_excel(file, header=4)
        df.drop(df.columns[len(df.columns) - 1], axis=1, inplace=True)
        df.columns = [
            'id', 'name', 'hydro_year', 'tide_start_hour', 'tide_start_date',
            'tide_end_hour', 'tide_end_date', 'tide_duration',
            'tide_max_hour', 'tide_max_date', 'max_height',
            'max_flow[m^3/sec]', 'tide_vol[MCM]'
        ]
        df = df[~df.hydro_year.isnull()]
        df['id'] = df['id'].astype(int)
        df['tide_start'] = pd.to_datetime(
            df['tide_start_date'], dayfirst=True) + pd.to_timedelta(
                df['tide_start_hour'].add(':00'), unit='m', errors='coerce')
        df['tide_end'] = pd.to_datetime(
            df['tide_end_date'], dayfirst=True) + pd.to_timedelta(
                df['tide_end_hour'].add(':00'), unit='m', errors='coerce')
        df['tide_max'] = pd.to_datetime(
            df['tide_max_date'], dayfirst=True) + pd.to_timedelta(
                df['tide_max_hour'].add(':00'), unit='m', errors='coerce')
        df['tide_duration'] = pd.to_timedelta(df['tide_duration'] + ':00',
                                              unit='m', errors='coerce')
        # values reported as "<x" are set to 0 (avoid chained assignment):
        df.loc[df['max_flow[m^3/sec]'].str.contains('<', na=False),
               'max_flow[m^3/sec]'] = 0
        df.loc[df['tide_vol[MCM]'].str.contains('<', na=False),
               'tide_vol[MCM]'] = 0
        df['max_flow[m^3/sec]'] = df['max_flow[m^3/sec]'].astype(float)
        df['tide_vol[MCM]'] = df['tide_vol[MCM]'].astype(float)
        to_drop = [
            'tide_start_hour', 'tide_start_date', 'tide_end_hour',
            'tide_end_date', 'tide_max_hour', 'tide_max_date'
        ]
        df = df.drop(to_drop, axis=1)
        df_list.append(df)
    df = pd.concat(df_list)
    dfs = [x for _, x in df.groupby('id')]
    ds_list = []
    meta_df = read_hydro_metadata(path, gis_path, False)
    for df in dfs:
        st_id = df['id'].iloc[0]
        st_name = df['name'].iloc[0]
        print('processing station number: {}, {}'.format(st_id, st_name))
        meta = meta_df[meta_df['id'] == st_id]
        ds = xr.Dataset()
        df.set_index('tide_start', inplace=True)
        attrs = {}
        attrs['station_name'] = st_name
        if not meta.empty:
            attrs['lon'] = meta.lon.values.item()
            attrs['lat'] = meta.lat.values.item()
            attrs['alt'] = meta.alt.values.item()
            attrs['drainage_basin_area'] = meta.area.values.item()
            attrs['active'] = meta.active.values.item()
        attrs['units'] = 'm'
        max_height = df['max_height'].to_xarray()
        max_height.name = 'TS_{}_max_height'.format(st_id)
        max_height.attrs = attrs
        max_flow = df['max_flow[m^3/sec]'].to_xarray()
        max_flow.name = 'TS_{}_max_flow'.format(st_id)
        attrs['units'] = 'm^3/sec'
        max_flow.attrs = attrs
        attrs['units'] = 'MCM'
        tide_vol = df['tide_vol[MCM]'].to_xarray()
        tide_vol.name = 'TS_{}_tide_vol'.format(st_id)
        tide_vol.attrs = attrs
        attrs.pop('units')
        # tide_start = df['tide_start'].to_xarray()
        # tide_start.name = 'TS_{}_tide_start'.format(st_id)
        # tide_start.attrs = attrs
        tide_end = df['tide_end'].to_xarray()
        tide_end.name = 'TS_{}_tide_end'.format(st_id)
        tide_end.attrs = attrs
        tide_max = df['tide_max'].to_xarray()
        tide_max.name = 'TS_{}_tide_max'.format(st_id)
        tide_max.attrs = attrs
        ds['{}'.format(max_height.name)] = max_height
        ds['{}'.format(max_flow.name)] = max_flow
        ds['{}'.format(tide_vol.name)] = tide_vol
        # ds['{}'.format(tide_start.name)] = tide_start
        ds['{}'.format(tide_end.name)] = tide_end
        ds['{}'.format(tide_max.name)] = tide_max
        ds_list.append(ds)
    dsu = [get_unique_index(x, dim='tide_start') for x in ds_list]
    print('merging...')
    ds = xr.merge(dsu)
    filename = 'hydro_tides.nc'
    print('saving {} to {}'.format(filename, path))
    comp = dict(zlib=True, complevel=9)  # best compression
    encoding = {var: comp for var in ds.data_vars}
    ds.to_netcdf(path / filename, 'w', encoding=encoding)
    print('Done!')
    return ds
def plot_figure_4(physical_file=phys_soundings, model='LR',
                  times=['2007', '2019'], save=True):
    """Plot the Ts-Tm relationship of the Bet-Dagan radiosonde station."""
    import xarray as xr
    import matplotlib.pyplot as plt
    import seaborn as sns
    from PW_stations import ML_Switcher
    from sklearn.metrics import mean_squared_error
    from sklearn.metrics import r2_score
    from aux_gps import get_unique_index
    import numpy as np
    from mpl_toolkits.axes_grid1.inset_locator import inset_axes
    # sns.set_style('whitegrid')
    pds = xr.open_dataset(physical_file)
    pds = pds[['Tm', 'Ts']]
    pds = get_unique_index(pds, 'sound_time')
    pds = pds.sel(sound_time=slice(*times))
    fig, ax = plt.subplots(1, 1, figsize=(7, 7))
    pds.plot.scatter(x='Ts', y='Tm', marker='.', s=100., linewidth=0,
                     alpha=0.5, ax=ax)
    ax.grid()
    ml = ML_Switcher()
    fit_model = ml.pick_model(model)
    X = pds.Ts.values.reshape(-1, 1)
    y = pds.Tm.values
    fit_model.fit(X, y)
    predict = fit_model.predict(X)
    coef = fit_model.coef_[0]
    inter = fit_model.intercept_
    ax.plot(X, predict, c='r')
    bevis_tm = pds.Ts.values * 0.72 + 70.0
    ax.plot(pds.Ts.values, bevis_tm, c='purple')
    ax.legend([
        'OLS ({:.2f}, {:.2f})'.format(coef, inter),
        'Bevis 1992 et al. (0.72, 70.0)'
    ])
    # ax.set_xlabel('Surface Temperature [K]')
    # ax.set_ylabel('Water Vapor Mean Atmospheric Temperature [K]')
    ax.set_xlabel('Ts [K]')
    ax.set_ylabel('Tm [K]')
    ax.set_ylim(265, 320)
    axin1 = inset_axes(ax, width="40%", height="40%", loc=2)
    resid = predict - y
    sns.distplot(resid, bins=50, color='k', label='residuals', ax=axin1,
                 kde=False,
                 hist_kws={"linewidth": 1, "alpha": 0.5, "color": "k"})
    axin1.yaxis.tick_right()
    rmean = np.mean(resid)
    rmse = np.sqrt(mean_squared_error(y, predict))
    r2 = r2_score(y, predict)
    axin1.axvline(rmean, color='r', linestyle='dashed', linewidth=1)
    # axin1.set_xlabel('Residual distribution[K]')
    textstr = '\n'.join([
        'n={}'.format(pds.Ts.size),
        'RMSE: ', '{:.2f} K'.format(rmse),
        r'R$^2$: {:.2f}'.format(r2)
    ])
    props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
    axin1.text(0.05, 0.95, textstr, transform=axin1.transAxes, fontsize=10,
               verticalalignment='top', bbox=props)
    # axin1.text(0.2, 0.9, 'n={}'.format(pds.Ts.size),
    #            verticalalignment='top', horizontalalignment='center',
    #            transform=axin1.transAxes, color='k', fontsize=10)
    # axin1.text(0.78, 0.9, 'RMSE: {:.2f} K'.format(rmse),
    #            verticalalignment='top', horizontalalignment='center',
    #            transform=axin1.transAxes, color='k', fontsize=10)
    axin1.set_xlim(-15, 15)
    fig.tight_layout()
    filename = 'Bet_dagan_ts_tm_fit.png'
    caption(
        'Water vapor mean temperature (Tm) vs. surface temperature (Ts) of '
        'the Bet-Dagan radiosonde station. An ordinary least squares linear '
        'fit (red) yields a residual distribution with an RMSE of 4 K. The '
        'Bevis (1992) model is plotted (purple) for comparison.')
    if save:
        plt.savefig(savefig_path / filename, bbox_inches='tight')
    return
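
# Illustrative sketch (not part of the original module): the weighted-mean
# atmospheric temperature Tm computed from a radiosonde profile,
# Tm = integral(e/T dz) / integral(e/T^2 dz), which is how Tm values such as
# those in a physical-soundings file are typically derived (the exact routine
# used to build phys_soundings is not shown here).
def _example_tm_from_profile(e_hpa, t_k, z_m):
    """e_hpa: water vapor partial pressure [hPa]; t_k: temperature [K];
    z_m: heights [m]; all 1-D arrays ordered along the profile."""
    import numpy as np
    num = np.trapz(e_hpa / t_k, z_m)
    den = np.trapz(e_hpa / t_k ** 2, z_m)
    return num / den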