def stitch_yearly_files(ds_list):
    """Stitch the discontinuities between consecutive yearly datasets.

    ``ds_list`` is a list of multi-field yearly datasets in time order. For
    every consecutive pair, the year of each dataset is taken as the median of
    its ``time.dt.year``. The first dataset is restricted to times from
    ``{first_year}-12-31T18:00`` on. The second is restricted to times up to
    ``{second_year}-01-01T06:00``.

    Each field of the combined window is stitched with ``stitch_two_cols``
    (the method is chosen per field by ``_stitch_method``). The stitched
    values are then written back *in place* into both datasets. Pairs whose
    windows have no common time are skipped with a warning.

    Returns the same list object. Lists with fewer than two datasets are
    returned unchanged.

    Raises ValueError if a field has no known stitching method.
    """
    if len(ds_list) < 2:
        # Nothing to stitch; also avoids IndexError on an empty list.
        return ds_list
    fields = list(ds_list[0].data_vars)
    for i in range(len(ds_list) - 1):
        first, second = ds_list[i], ds_list[i + 1]
        first_year = int(first.time.dt.year.median().item())
        second_year = int(second.time.dt.year.median().item())
        first_ds = first.sel(time=slice(
            '{}-12-31T18:00'.format(first_year), str(second_year)))
        second_ds = second.sel(time=slice(
            str(first_year), '{}-01-01T06:00'.format(second_year)))
        if dim_intersection([first_ds, second_ds], 'time') is None:
            logger.warning('skipping stitching years {} and {}...'.format(
                first_year, second_year))
            continue
        logger.info('stitching years {} and {}'.format(
            first_year, second_year))
        time = xr.concat([first_ds.time, second_ds.time], 'time')
        time = pd.to_datetime(get_unique_index(time).values)
        st_list = []
        for field in fields:
            df = first_ds[field].to_dataframe()
            df.columns = ['first']
            df = df.reindex(time)
            df['second'] = second_ds[field].to_dataframe()
            dfs = stitch_two_cols(
                df, method=_stitch_method(field))['stitched_signal']
            dfs.index.name = 'time'
            st = dfs.to_xarray()
            st.name = field
            st_list.append(st)
        # merge to all fields:
        st_ds = xr.merge(st_list)
        # replace stitched values in the first and then the second dataset
        # (``.loc`` assignment mutates the datasets held in ds_list):
        for ds in (first, second):
            common_time = dim_intersection([ds, st_ds])
            vals_rpl = st_ds.sel(time=common_time)
            for field in ds.data_vars:
                ds[field].loc[{'time': common_time}] = vals_rpl[field]
    return ds_list


def _stitch_method(field):
    """Return the ``stitch_two_cols`` method name to use for ``field``."""
    if field in ['X', 'Y', 'Z']:
        return 'simple_mean'
    if field in ['GradNorth', 'GradEast', 'WetZ']:
        return 'smooth_mean'
    if 'error' in field:
        return 'error_mean'
    # Previously an unknown field silently reused the previous field's method
    # (or crashed with UnboundLocalError); fail loudly instead.
    raise ValueError(
        'no stitching method defined for field {!r}'.format(field))
def align_synoptic_class_with_pw(path):
    """Build per-station synoptic-class series aligned with the GNSS PW data.

    Loads ``GNSS_PW_thresh_50_homogenized.nc`` from ``path`` and keeps only
    the non-``_error`` variables (one per station). The synoptic ``class``
    series (from 1996 on) is forward-filled onto a 5-minute grid. For each
    station it is restricted to the times shared with that station's PW, then
    reindexed with ``xr_reindex_with_date_range``. The result is stored as
    ``{station}_class``.

    Times without a class are filled with 0 and the dataset is cast to
    ``int8``. It is saved as ``GNSS_synoptic_class.nc`` in ``path`` and
    returned.
    """
    import xarray as xr
    from aux_gps import dim_intersection
    from aux_gps import save_ncfile
    from aux_gps import xr_reindex_with_date_range
    from synoptic_procedures import read_synoptic_classification
    pw = xr.load_dataset(path / 'GNSS_PW_thresh_50_homogenized.nc')
    pw = pw[[x for x in pw if '_error' not in x]]
    syn = read_synoptic_classification(report=False).to_xarray()
    syn = syn['class']
    syn = syn.sel(time=slice('1996', None))
    syn = syn.resample(time='5T').ffill()
    ds_list = []
    for sta in pw:
        print('aligning station {} with synoptics'.format(sta))
        new_time = dim_intersection([pw[sta], syn])
        syn_da = xr.DataArray(syn.sel(time=new_time))
        syn_da.name = '{}_class'.format(sta)
        syn_da = xr_reindex_with_date_range(syn_da)
        ds_list.append(syn_da)
    ds = xr.merge(ds_list)
    # Fill gaps BEFORE casting: NaN has no int8 representation, so casting
    # first produced undefined values and made the later fillna a no-op.
    ds = ds.fillna(0)
    ds = ds.astype('int8')
    filename = 'GNSS_synoptic_class.nc'
    save_ncfile(ds, path, filename)
    return ds
def pw_mlh_to_df(pw_new, mlh_site):
    """Join a PW series and an MLH series into one hourly DataFrame.

    Both inputs are cut down to the timestamps returned by
    ``dim_intersection``. The PW values form the frame, and the MLH values are
    added as a column named after ``mlh_site``. The frame is then reindexed
    onto a regular 1-hour grid spanning its first to last timestamp, so
    missing hours become NaN. Its index is named ``'time'``.
    """
    common = dim_intersection([pw_new, mlh_site])
    mlh = mlh_site.sel(time=common)
    frame = pw_new.sel(time=common).to_dataframe()
    frame[mlh.name] = mlh.to_dataframe()
    hourly_grid = pd.date_range(frame.index.min(), frame.index.max(),
                                freq='1H')
    frame = frame.reindex(hourly_grid)
    frame.index.name = 'time'
    return frame
def read_BD_matfile(path=ceil_path, plot=True, month=None, add_syn=True):
    """Read the Beit-Dagan ceilometer mixing-layer-height (MLH) .mat file.

    Loads variable ``pblBD4shlomi`` from ``path / 'PBL_BD_LST.mat'``. In each
    row, the first three columns are joined with '-' and parsed as a date. The
    remaining columns are half-hourly values placed on 48 timestamps starting
    30 minutes after that date. The series is shifted by -2 hours (the code
    comment says this switches to UTC).

    The data become a DataArray named ``'BD'`` with Beit-Dagan metadata
    attributes, reindexed onto a 30-minute grid. Synoptic classification for
    2015-2016 is forward-filled to 30 minutes and aligned with it.

    Parameters
    ----------
    path : pathlib.Path
        Directory containing ``PBL_BD_LST.mat``.
    plot : bool
        If True, plot June-October of 2015 and 2016 in two panels. The figure
        is saved as ``MLH-BD_syn.png`` under ``savefig_path``.
    month :
        Not used by this function.
    add_syn : bool
        If True, colour the plot by synoptic ``class_abbr`` and return a
        Dataset with ``'BD'`` and ``'syn'``; otherwise return the DataArray.
    """
    from scipy.io import loadmat
    import pandas as pd
    from aux_gps import xr_reindex_with_date_range
    import matplotlib.pyplot as plt
    from aux_gps import dim_intersection
    from synoptic_procedures import read_synoptic_classification

    mdata = loadmat(path / 'PBL_BD_LST.mat')['pblBD4shlomi']
    date_cols = mdata[:, :3].astype(str)
    pbl = mdata[:, 3:]
    days = [pd.to_datetime(yy + '-' + mm + '-' + dd) for yy, mm, dd in date_cols]

    day_frames = []
    for row, day in zip(pbl, days):
        first_stamp = day + pd.Timedelta(0.5, unit='H')
        stamps = pd.date_range(first_stamp, periods=48, freq='30T')
        day_frames.append(pd.DataFrame(row, index=stamps))
    df = pd.concat(day_frames)
    df.columns = ['MLH']
    df.index.name = 'time'
    # switch to UTC:
    df.index = df.index - pd.Timedelta(2, unit='H')

    da = df.to_xarray()['MLH']
    da.name = 'BD'
    da.attrs.update({
        'full_name': 'Mixing Layer Height',
        'name': 'MLH',
        'units': 'm',
        'station_full_name': 'Beit Dagan',
        'lon': 34.81,
        'lat': 32.00,
        'alt': 34,
    })
    da = xr_reindex_with_date_range(da, freq='30T')

    # add synoptic data:
    syn = read_synoptic_classification().to_xarray()
    syn = syn.sel(time=slice('2015', '2016')).resample(time='30T').ffill()
    syn_da = syn.sel(time=dim_intersection([da, syn]))
    syn_da = xr_reindex_with_date_range(syn_da, freq='30T')

    if plot:
        years = ('2015', '2016')
        year_frames = [da.sel(time=year).to_dataframe() for year in years]
        fig, axes = plt.subplots(2, 1, sharey=True, sharex=False,
                                 figsize=(15, 10))
        if add_syn:
            cmap = plt.get_cmap("tab10")
            syn_df = syn_da.to_dataframe()
        for year, bd, panel in zip(years, year_frames, axes):
            window = ['{}-06'.format(year), '{}-10'.format(year)]
            if not add_syn:
                bd.plot(ax=panel, xlim=window)
                continue
            bd['synoptics'] = syn_df.loc[year, 'class_abbr']
            labels = []
            for color_idx, (label, group) in enumerate(bd.groupby('synoptics')):
                labels.append(label)
                d = xr_reindex_with_date_range(group['BD'].to_xarray(),
                                               freq='30T')
                d.to_dataframe().plot(x_compat=True, ms=10,
                                      color=cmap(color_idx), ax=panel,
                                      xlim=window)
            panel.legend(labels)
        for panel in axes.flatten():
            panel.set_ylabel('MLH [m]')
            panel.set_xlabel('UTC')
            panel.grid()
        fig.tight_layout()
        fig.suptitle('MLH from Beit-Dagan ceilometer for 2015 and 2016')
        plt.savefig(savefig_path / 'MLH-BD_syn.png', orientation='portrait')

    if not add_syn:
        return da
    ds = da.to_dataset(name='BD')
    ds['syn'] = syn_da['class_abbr']
    return ds
def align_pw_mlh(path=work_yuval, ceil_path=ceil_path, site='tela',
                 interpolate=None, plot=True, dt_range_str='2015'):
    """Align a GNSS PWV station with its ceilometer MLH site on an hourly grid.

    Parameters
    ----------
    path : pathlib.Path
        Directory holding ``GNSS_PW_thresh_50_homogenized.nc``.
    ceil_path : pathlib.Path
        Directory holding ``MLH_from_ceilometers.nc``.
    site : str
        GNSS station name. The matching MLH variable is
        ``pw_mlh_dict[site]``.
    interpolate : optional
        When not None, both series are also cubic-interpolated with
        ``interpolate_na(..., max_gap=interpolate)``. The interpolated
        version is then plotted and returned.
    plot : bool
        Plot both series on twin y-axes and save a PNG under
        ``savefig_path``.
    dt_range_str : str or None
        Label passed to ``DataFrame.loc`` to restrict the output (e.g. a
        year). None keeps the full range.

    Returns
    -------
    xr.Dataset
        The PW and MLH variables on the hourly grid, with the original
        DataArray attributes copied back.
    """
    import xarray as xr
    from aux_gps import dim_intersection
    from aux_gps import xr_reindex_with_date_range
    import pandas as pd
    import matplotlib.pyplot as plt

    def pw_mlh_to_df(pw_new, mlh_site):
        newtime = dim_intersection([pw_new, mlh_site])
        MLH = mlh_site.sel(time=newtime)
        PW = pw_new.sel(time=newtime)
        df = PW.to_dataframe()
        df[MLH.name] = MLH.to_dataframe()
        new_time = pd.date_range(df.index.min(), df.index.max(), freq='1H')
        df = df.reindex(new_time)
        df.index.name = 'time'
        return df

    mlh = xr.load_dataset(ceil_path / 'MLH_from_ceilometers.nc')
    mlh_site = xr_reindex_with_date_range(mlh[pw_mlh_dict.get(site)],
                                          freq='1H')
    if interpolate is not None:
        print('interpolating ceil-site {} with max-gap of {}.'.format(
            pw_mlh_dict.get(site), interpolate))
        attrs = mlh_site.attrs
        mlh_site_inter = mlh_site.interpolate_na('time', max_gap=interpolate,
                                                 method='cubic')
        mlh_site_inter.attrs = attrs
    # Honour the ``path`` argument (it was ignored before), load only the
    # requested station and close the file once the data is in memory.
    with xr.open_dataset(path / 'GNSS_PW_thresh_50_homogenized.nc') as pw:
        pw_new = pw[site].load()
    mlh_ref = mlh_site_inter if interpolate is not None else mlh_site
    newtime = dim_intersection([pw_new, mlh_ref])
    pw_new = pw_new.sel(time=newtime)
    pw_new = xr_reindex_with_date_range(pw_new, freq='1H')
    if interpolate is not None:
        print('interpolating pw-site {} with max-gap of {}.'.format(
            site, interpolate))
        attrs = pw_new.attrs
        pw_new_inter = pw_new.interpolate_na('time', max_gap=interpolate,
                                             method='cubic')
        pw_new_inter.attrs = attrs
    df = pw_mlh_to_df(pw_new, mlh_site)
    df_inter = None
    if interpolate is not None:
        df_inter = pw_mlh_to_df(pw_new_inter, mlh_site_inter)
    if dt_range_str is not None:
        df = df.loc[dt_range_str, :]
        # Restrict the interpolated frame too, so the returned data matches
        # the requested range in both modes.
        if df_inter is not None:
            df_inter = df_inter.loc[dt_range_str, :]
    if plot:
        fig, ax = plt.subplots(figsize=(18, 5))
        if df_inter is not None:
            df_inter[pw_new.name].plot(style='b--', ax=ax)
            # secondary_y draws on the twin axis stored as ax.right_ax
            df_inter[mlh_site.name].plot(style='r--', secondary_y=True, ax=ax)
        ax = df[pw_new.name].plot(style='b-', marker='o', ax=ax, ms=5)
        ax_twin = df[mlh_site.name].plot(style='r-', marker='s',
                                         secondary_y=True, ax=ax, ms=5)
        if df_inter is not None:
            ax.legend(ax.get_lines() + ax.right_ax.get_lines(), [
                'PWV {} max interpolation'.format(interpolate), 'PWV',
                'MLH {} max interpolation'.format(interpolate), 'MLH'
            ], loc='best')
        else:
            ax.legend([ax.get_lines()[0], ax.right_ax.get_lines()[0]],
                      ['PWV', 'MLH'], loc='upper center')
        ax.set_title('MLH {} site and PWV {} site'.format(
            pw_mlh_dict.get(site), site))
        ax.set_xlim(df.dropna().index.min(), df.dropna().index.max())
        ax.set_ylabel('PWV [mm]', color='b')
        ax_twin.set_ylabel('MLH [m]', color='r')
        ax.tick_params(axis='y', colors='b')
        ax_twin.tick_params(axis='y', colors='r')
        ax.grid(True, which='both', axis='x')
        fig.tight_layout()
        if interpolate is not None:
            filename = '{}-{}_{}_time_series_{}_max_gap_interpolation.png'.format(
                site, pw_mlh_dict.get(site), dt_range_str, interpolate)
        else:
            filename = '{}-{}_{}_time_series.png'.format(
                site, pw_mlh_dict.get(site), dt_range_str)
        plt.savefig(savefig_path / filename, orientation='portrait')
    ds = (df_inter if df_inter is not None else df).to_xarray()
    ds[pw_new.name].attrs.update(pw_new.attrs)
    ds[mlh_site.name].attrs.update(mlh_site.attrs)
    return ds
def twin_hourly_mean_plot(pw, mlh, month=8, ax=None, title=True,
                          leg_loc='best', unit='pts', sample_rate=24,
                          fontsize=14):
    """Plot the mean diurnal cycle (+-1 std) of PWV and MLH on twin y-axes.

    Parameters
    ----------
    pw, mlh : xr.DataArray
        Time series with a ``time`` dimension. ``mlh.attrs`` must contain
        ``'station_full_name'``; ``pw.name`` is used as the GNSS site label.
    month : int or None
        If given, both series are restricted to that calendar month (all
        years) with NaNs dropped. If None, they are restricted to their common
        times.
    ax : matplotlib Axes, optional
        Axes for PWV (a new figure is created if None). MLH goes on a twin.
    title : bool
        Add a descriptive title.
    leg_loc : str
        Legend location.
    unit : {'pts', 'days'}
        Report sample counts in the legend as points or as days.
    sample_rate : int
        Samples per day used for ``unit='days'``.
    fontsize : int
        Font size for labels, ticks and the info box.

    Returns
    -------
    (ax, twin)
        The PWV axes and the MLH twin axes.

    Raises
    ------
    ValueError
        If ``unit`` is neither ``'pts'`` nor ``'days'``.
    """
    from aux_gps import dim_intersection
    import matplotlib.pyplot as plt
    from calendar import month_abbr
    if unit not in ('pts', 'days'):
        # Previously this surfaced later as an UnboundLocalError.
        raise ValueError(
            "unit must be 'pts' or 'days', got {!r}".format(unit))
    # first run multi-year month mean:
    if month is not None:
        pw = pw.sel(time=pw['time.month'] == month).dropna('time')
        mlh = mlh.sel(time=mlh['time.month'] == month).dropna('time')
    else:
        newtime = dim_intersection([pw, mlh], 'time')
        pw = pw.sel(time=newtime)
        mlh = mlh.sel(time=newtime)
    pw_hour = pw.groupby('time.hour').mean()
    pw_std = pw.groupby('time.hour').std()
    pw_hour_plus = (pw_hour + pw_std).values
    pw_hour_minus = (pw_hour - pw_std).values
    mlh_hour = mlh.groupby('time.hour').mean()
    mlh_std = mlh.groupby('time.hour').std()
    mlh_hour_minus = (mlh_hour - mlh_std).values
    mlh_hour_plus = (mlh_hour + mlh_std).values
    mlhyears = [mlh.time.dt.year.min().item(), mlh.time.dt.year.max().item()]
    pwyears = [pw.time.dt.year.min().item(), pw.time.dt.year.max().item()]
    mlh_month = mlh.time.dt.month.to_dataframe()['month'].value_counts(
    ).index[0]
    pw_size = pw.dropna('time').size
    mlh_size = mlh.dropna('time').size
    if unit == 'pts':
        pw_pts = pw_size
        mlh_pts = mlh_size
    else:
        # With a month selected, PW is not intersected with the MLH times and
        # keeps its own sampling; the original code assumes 288 samples/day
        # for it (presumably 5-minute data). The 288 divisor previously leaked
        # into unit='pts' as well.
        pw_per_day = 288 if month is not None else sample_rate
        pw_pts = int(pw_size / pw_per_day)
        mlh_pts = int(mlh_size / sample_rate)
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 8))
    red = 'tab:red'
    blue = 'tab:blue'
    pwln = pw_hour.plot(color=blue, marker='s', ax=ax)
    ax.fill_between(pw_hour.hour.values, pw_hour_minus, pw_hour_plus,
                    color=blue, alpha=0.5)
    twin = ax.twinx()
    mlhln = mlh_hour.plot(color=red, marker='o', ax=twin)
    twin.fill_between(mlh_hour.hour.values, mlh_hour_minus, mlh_hour_plus,
                      color=red, alpha=0.5)
    if month is None:
        pw_label = 'PWV: {}-{} ({} {})'.format(pwyears[0], pwyears[1],
                                               pw_pts, unit)
        mlh_label = 'MLH: {}-{} ({} {})'.format(mlhyears[0], mlhyears[1],
                                                mlh_pts, unit)
    else:
        pw_label = 'PWV: {}-{}, {} ({} {})'.format(pwyears[0], pwyears[1],
                                                   month_abbr[mlh_month],
                                                   pw_pts, unit)
        mlh_label = 'MLH: {}-{}, {} ({} {})'.format(mlhyears[0], mlhyears[1],
                                                    month_abbr[mlh_month],
                                                    mlh_pts, unit)
    ax.legend(pwln + mlhln, [pw_label, mlh_label], loc=leg_loc)
    ax.tick_params(axis='y', colors=blue, labelsize=fontsize)
    twin.tick_params(axis='y', colors=red, labelsize=fontsize)
    ax.set_ylabel('PWV [mm]', color=blue, fontsize=fontsize)
    twin.set_ylabel('MLH [m]', color=red, fontsize=fontsize)
    ax.set_xticks([x for x in range(24)])
    ax.set_xlabel('Hour of day [UTC]', fontsize=fontsize)
    mlh_name = mlh.attrs['station_full_name'].replace('_', '-')
    textstr = '{}, {}'.format(mlh_name, pw.name.upper())
    props = dict(boxstyle='round', facecolor='white', alpha=0.5)
    ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=fontsize,
            verticalalignment='top', bbox=props)
    if title:
        ax.set_title(
            'The diurnal cycle of {} Mixing Layer Height and {} GNSS site PWV'.
            format(mlh_name, pw.name.upper()))
    return ax, twin
def scatter_plot_pw_mlh(pw, mlh, diurnal=False, ax=None, title=True,
                        leg_loc='best', month=None):
    """Scatter MLH against PWV and overlay a least-squares linear fit.

    Parameters
    ----------
    pw, mlh : xr.DataArray
        Time series. ``pw.name`` must be a station in
        ``produce_geo_gnss_solved_stations`` (its ``alt`` is shown in the
        title). ``mlh.attrs`` must hold ``'station_full_name'`` and ``'alt'``.
    diurnal : bool
        If True, compare hourly means (grouped by ``time.hour``) instead of
        the raw, time-intersected samples.
    ax : matplotlib Axes, optional
        Target axes; a new 10x10 figure is created if None.
    title : bool
        Add a title with both site names and altitudes.
    leg_loc : str
        Legend location.
    month : int or None
        With ``diurnal``, restrict BOTH series to this calendar month before
        averaging. Otherwise the series are first restricted to their common
        times.

    Returns
    -------
    matplotlib Axes
        The axes with the scatter, fit line, legend and stats box
        (n, R^2, slope).
    """
    from aux_gps import dim_intersection
    import xarray as xr
    import numpy as np
    from sklearn.metrics import r2_score
    import matplotlib.pyplot as plt
    from PW_stations import produce_geo_gnss_solved_stations
    df = produce_geo_gnss_solved_stations(plot=False)
    pw_alt = df.loc[pw.name, 'alt']
    pw_attrs = pw.attrs
    mlh_attrs = mlh.attrs
    if diurnal:
        if month is not None:
            # Filter both series so the hourly means describe the same month
            # (previously only PW was filtered).
            pw = pw.sel(time=pw['time.month'] == month)
            mlh = mlh.sel(time=mlh['time.month'] == month)
        else:
            newtime = dim_intersection([pw, mlh], 'time')
            pw = pw.sel(time=newtime)
            mlh = mlh.sel(time=newtime)
        pw = pw.groupby('time.hour').mean()
        pw.attrs = pw_attrs
        mlh = mlh.groupby('time.hour').mean()
        mlh.attrs = mlh_attrs
    else:
        newtime = dim_intersection([pw, mlh], 'time')
        pw = pw.sel(time=newtime)
        mlh = mlh.sel(time=newtime)
    ds = xr.merge([pw, mlh])
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 10))
    data_handle = ds.plot.scatter(x=pw.name, y=mlh.name, ax=ax)
    coefs = np.polyfit(pw.values, mlh.values, 1)
    x = np.linspace(pw.min().item(), pw.max().item(), 100)
    y = np.polyval(coefs, x)
    r2 = r2_score(mlh.values, np.polyval(coefs, pw.values))
    (fit_handle,) = ax.plot(x, y, color='tab:red')
    ax.set_xlabel('PWV [mm]')
    ax.set_ylabel('MLH [m]')
    # Pass handles explicitly: labels-only legend() pairs labels with artists
    # in an order that differs between matplotlib versions.
    ax.legend([fit_handle, data_handle], ['linear fit', 'data'], loc=leg_loc)
    textstr = '\n'.join([
        'n={}'.format(pw.size), r'R$^2$={:.2f}'.format(r2),
        'slope={:.1f} m/mm'.format(coefs[0])
    ])
    props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
    ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=10,
            verticalalignment='top', bbox=props)
    mlh_name = mlh.attrs['station_full_name'].replace('_', '-')
    if title:
        ax.set_title(
            '{} ({:.0f} m) GNSS site PW vs. {} ({:.0f} m) Mixing Layer Height'.
            format(pw.name.upper(), pw_alt, mlh_name, mlh.attrs['alt']))
    return ax
def process_data_from_uwyo_sounding(path, st_num):
    """Create Tm, TPW and surface temperature from a UWyo sounding station.

    Reads ``ALL_{st_num}_soundings.nc`` (full soundings) and
    ``PW_{st_num}_soundings.nc`` (station-calculated IPW) from ``path`` and
    keeps their common times. For each sounding it computes:

    - surface temperature: ``TEMP`` at ``mpoint=0`` + 273.15, in K;
    - precipitable water via ``precipitable_water``;
    - ``Tm`` via ``Tm``.

    The result also gets ``season`` and ``hour`` labels ('noon' / 'midnight'
    for hours 12 / 0). Times with any NaN are dropped. The dataset is written
    with zlib compression to ``path / 'station_{st_num}_sounding_pw_Ts_Tk.nc'``.

    Returns None.
    """
    import xarray as xr
    from aux_gps import dim_intersection
    import logging
    logger = logging.getLogger('uwyo')
    pw_file = 'PW_{}_soundings.nc'.format(st_num)
    all_file = 'ALL_{}_soundings.nc'.format(st_num)
    # Context managers close both source files once the needed slice is
    # loaded into memory (previously the handles were never closed).
    with xr.open_dataarray(path / all_file) as da_src, \
            xr.open_dataarray(path / pw_file) as pw_src:
        new_time = dim_intersection([da_src, pw_src], 'time', dropna=False)
        logger.info('loaded {}'.format(pw_file))
        logger.info('loaded {}'.format(all_file))
        da = da_src.sel(time=new_time)
        pw = pw_src.sel(time=new_time)
        pw.load()
        da.load()
    logger.info('calculating pw and tm for station {}'.format(st_num))
    ts_list = []
    tpw_list = []
    tm_list = []
    for date in da.time:
        sounding = da.sel(time=date)
        ts_list.append(sounding.sel(var='TEMP', mpoint=0) + 273.15)
        tpw_list.append(precipitable_water(sounding))
        tm_list.append(Tm(sounding))
    tpw = xr.DataArray(tpw_list, dims='time')
    tm = xr.DataArray(tm_list, dims='time')
    tm.attrs[
        'description'] = 'mean atmospheric temperature calculated by water vapor pressure weights'
    tm.attrs['units'] = 'K'
    ts = xr.concat(ts_list, 'time')
    ts.attrs[
        'description'] = 'Surface temperature from {} station soundings'.format(
            st_num)
    ts.attrs['units'] = 'K'
    result = pw.to_dataset(name='pw')
    result['tpw'] = tpw
    result['tm'] = tm
    result['ts'] = ts
    result['tpw'].attrs[
        'description'] = 'station {} percipatable water calculated from sounding by me'.format(
            st_num)
    result['tpw'].attrs['units'] = 'mm'
    result['season'] = result['time.season']
    result['hour'] = result['time.hour'].astype(str)
    result['hour'] = result.hour.where(result.hour != '12', 'noon')
    result['hour'] = result.hour.where(result.hour != '0', 'midnight')
    result = result.dropna('time')
    filename = 'station_{}_sounding_pw_Ts_Tk.nc'.format(st_num)
    logger.info('saving {} to {}'.format(filename, path))
    comp = dict(zlib=True, complevel=9)  # best compression
    encoding = {var: comp for var in result}
    result.to_netcdf(path / filename, 'w', encoding=encoding)
    logger.info('Done!')
    return
def evaluate_sin_to_tmsearies(
        da,
        time_dim='time',
        plot=True,
        params_up=None,
        func='sine',
        bounds=None,
        just_eval=False):
    """Fit a sine (optionally on a trend) to a time series and evaluate it.

    The independent variable is days since the first non-NaN sample
    (Julian-date difference). Available models (``func``):

    - ``'sine'``: ``amp*sin(2*pi*(x/period + 1/phase)) + offset``
    - ``'sine_on_linear'``: the sine plus ``a*x/365.25``
    - ``'sine_on_quad'``: the sine plus ``a*(x/365.25)**2 + b*x/365.25``

    Parameters
    ----------
    da : xr.DataArray
        Series to fit; NaNs along ``time_dim`` are ignored for the fit.
    time_dim : str
        Name of the time dimension.
    plot : bool
        Plot the data and the model as dots.
    params_up : dict or None
        Overrides for the initial parameters. The defaults are
        ``amp=6, period=365, phase=30, offset=253, a=0, b=0``; None keeps them.
    func : str
        One of the model names above.
    bounds : any
        Any non-None value switches on the hard-coded parameter bounds.
    just_eval : bool
        Skip fitting and evaluate the model with the (possibly overridden)
        initial parameters.

    Returns
    -------
    xr.DataArray
        Model values on every ``time_dim`` step of ``da``. The mean residual
        and RMSE are printed.

    Raises
    ------
    ValueError
        If ``func`` is not a known model.
    """
    import pandas as pd
    import xarray as xr
    import numpy as np
    from aux_gps import dim_intersection
    from scipy.optimize import curve_fit
    from sklearn.metrics import mean_squared_error

    def sine(time, amp, period, phase, offset):
        f = amp * np.sin(2 * np.pi * (time / period + 1.0 / phase)) + offset
        return f

    def sine_on_linear(time, amp, period, phase, offset, a):
        f = a * time / 365.25 + amp * \
            np.sin(2 * np.pi * (time / period + 1.0 / phase)) + offset
        return f

    def sine_on_quad(time, amp, period, phase, offset, a, b):
        f = a * (time / 365.25) ** 2.0 + b * time / 365.25 + amp * \
            np.sin(2 * np.pi * (time / period + 1.0 / phase)) + offset
        return f

    # func -> (model, number of leading parameters, description, curve_fit kw)
    models = {
        'sine': (
            sine, 4,
            'Model chosen: y = amp * sin (2*pi*(x/T + 1/phi)) + offset',
            {'ftol': 1e-9, 'xtol': 1e-9}),
        'sine_on_linear': (
            sine_on_linear, 5,
            'Model chosen: y = a * x + amp * sin (2*pi*(x/T + 1/phi)) + offset',
            {'xtol': 1e-11, 'ftol': 1e-11}),
        'sine_on_quad': (
            sine_on_quad, 6,
            'Model chosen: y = a * x^2 + b * x + amp * sin (2*pi*(x/T + 1/phi)) + offset',
            {}),
    }
    if func not in models:
        raise ValueError('unknown func {!r}; expected one of {}'.format(
            func, list(models)))
    # Report format per parameter, in model-argument order.
    report_fmt = ('amp: {:.4f} +- {:.2f}', 'period: {:.2f} +- {:.2f}',
                  'phase: {:.2f} +- {:.2f}', 'offset: {:.2f} +- {:.2f}',
                  'a: {:.7f} +- {:.2f}', 'b: {:.7f} +- {:.2f}')

    params = {
        'amp': 6.0,
        'period': 365.0,
        'phase': 30.0,
        'offset': 253.0,
        'a': 0.0,
        'b': 0.0
    }
    if params_up is not None:
        params.update(params_up)
    print(params)
    lower = {key: -np.inf for key in params}
    upper = {key: np.inf for key in params}
    if bounds is not None:
        lower['phase'] = 0.01
        upper['phase'] = 0.05
        lower['offset'] = 46.2
        upper['offset'] = 46.3
        lower['a'] = 0.0001
        upper['a'] = 0.002
        upper['amp'] = 0.04
        lower['amp'] = 0.02
    lower = list(lower.values())
    upper = list(upper.values())
    p0 = list(params.values())

    da_no_nans = da.dropna(time_dim)
    jul = pd.to_datetime(da_no_nans[time_dim].values).to_julian_date()
    # Keep the origin before shifting: after ``jul - t0``, jul[0] is 0, so
    # reusing jul[0] afterwards left jul_with_nans unshifted (a bug).
    t0 = jul[0]
    jul = jul - t0
    jul_with_nans = pd.to_datetime(da[time_dim].values).to_julian_date() - t0
    ydata = da_no_nans.values

    model, n_par, description, fit_kw = models[func]
    print(description)
    if just_eval:
        # No fit: evaluate at the supplied/initial parameters (this used to
        # raise NameError because the fitted values were never defined).
        popt = p0[:n_par]
    else:
        # Slice by parameter count so extra keys in params_up cannot
        # misalign p0/bounds with the model signature.
        popt, pcov = curve_fit(model, jul, ydata, p0=p0[:n_par],
                               bounds=(lower[:n_par], upper[:n_par]),
                               **fit_kw)
        perr = np.sqrt(np.diag(pcov))
        for fmt, value, err in zip(report_fmt, popt, perr):
            print(fmt.format(value, err))
    new = model(jul_with_nans, *popt)

    new_da = xr.DataArray(new, dims=[time_dim])
    new_da[time_dim] = da[time_dim]
    resid = new_da - da
    rmean = np.mean(resid)
    new_time = dim_intersection([da, new_da], time_dim)
    rmse = np.sqrt(
        mean_squared_error(
            da.sel({time_dim: new_time}).values,
            new_da.sel({time_dim: new_time}).values))
    print('MEAN : {}'.format(rmean))
    print('RMSE : {}'.format(rmse))
    if plot:
        da.plot.line(marker='.', linewidth=0., figsize=(20, 5))
        new_da.plot.line(marker='.', linewidth=0.)
    return new_da