def stitch_yearly_files(ds_list):
    """Stitch the discontinuities between consecutive yearly datasets.

    Input is a list of multi-field yearly datasets and the output is the
    same list in which the samples around every boundary between dataset
    ``i`` and ``i + 1`` were replaced by the signal returned from
    ``stitch_two_cols``.

    The stitching method is chosen from the field name:
    ``X``/``Y``/``Z`` -> ``'simple_mean'``,
    ``GradNorth``/``GradEast``/``WetZ`` -> ``'smooth_mean'``,
    any name containing ``'error'`` -> ``'error_mean'``.

    Parameters
    ----------
    ds_list : list
        Yearly datasets sharing the variables of ``ds_list[0]``; presumably
        ordered chronologically - TODO confirm with callers.

    Returns
    -------
    list
        The same list object; its datasets are modified in place.

    Raises
    ------
    ValueError
        If a field name matches none of the stitching rules above.
    """
    # fewer than two datasets: there is no boundary to stitch (and an empty
    # list has no first dataset to read the field names from)
    if len(ds_list) < 2:
        return ds_list
    fields = list(ds_list[0].data_vars)
    for i in range(len(ds_list) - 1):
        first_year = int(ds_list[i].time.dt.year.median().item())
        second_year = int(ds_list[i + 1].time.dt.year.median().item())
        # windows around the year boundary in each of the two datasets:
        first_ds = ds_list[i].sel(time=slice(
            '{}-12-31T18:00'.format(first_year), str(second_year)))
        second_ds = ds_list[i + 1].sel(time=slice(
            str(first_year), '{}-01-01T06:00'.format(second_year)))
        if dim_intersection([first_ds, second_ds], 'time') is None:
            logger.warning('skipping stitching years {} and {}...'.format(
                first_year, second_year))
            continue
        logger.info('stitching years {} and {}'.format(
            first_year, second_year))
        time = xr.concat([first_ds.time, second_ds.time], 'time')
        time = pd.to_datetime(get_unique_index(time).values)
        st_list = []
        for field in fields:
            df = first_ds[field].to_dataframe()
            df.columns = ['first']
            df = df.reindex(time)
            df['second'] = second_ds[field].to_dataframe()
            if field in ['X', 'Y', 'Z']:
                method = 'simple_mean'
            elif field in ['GradNorth', 'GradEast', 'WetZ']:
                method = 'smooth_mean'
            elif 'error' in field:
                method = 'error_mean'
            else:
                # without this, the method of the previous field would be
                # reused silently (or the name would be unbound)
                raise ValueError(
                    'no stitching method known for field {!r}'.format(field))
            dfs = stitch_two_cols(df, method=method)['stitched_signal']
            dfs.index.name = 'time'
            st = dfs.to_xarray()
            st.name = field
            st_list.append(st)
        # merge to all fields:
        st_ds = xr.merge(st_list)
        # replace stitched values to first ds and second ds:
        first_time = dim_intersection([ds_list[i], st_ds])
        vals_rpl = st_ds.sel(time=first_time)
        for field in ds_list[i].data_vars:
            ds_list[i][field].loc[{'time': first_time}] = vals_rpl[field]
        second_time = dim_intersection([ds_list[i + 1], st_ds])
        vals_rpl = st_ds.sel(time=second_time)
        for field in ds_list[i + 1].data_vars:
            ds_list[i + 1][field].loc[{
                'time': second_time
            }] = vals_rpl[field]
    return ds_list
def align_synoptic_class_with_pw(path):
    """Align the synoptic classification with every GNSS PW station.

    Reads ``GNSS_PW_thresh_50_homogenized.nc`` from *path* and keeps the
    variables whose name does not contain ``'_error'``. The synoptic
    ``class`` series (from 1996 on) is forward-filled to a 5-minute grid.
    For each station, the class on the times shared with that station's PW
    series is stored as ``<station>_class``. The merged result is saved as
    ``GNSS_synoptic_class.nc`` in *path* and returned.

    Times without a class are stored as 0; the stored dtype is int8.
    """
    import xarray as xr
    from aux_gps import dim_intersection
    from aux_gps import save_ncfile
    from aux_gps import xr_reindex_with_date_range
    pw = xr.load_dataset(path / 'GNSS_PW_thresh_50_homogenized.nc')
    pw = pw[[x for x in pw if '_error' not in x]]
    syn = read_synoptic_classification(report=False).to_xarray()
    syn = syn['class']
    syn = syn.sel(time=slice('1996', None))
    syn = syn.resample(time='5T').ffill()
    ds_list = []
    for sta in pw:
        print('aligning station {} with synoptics'.format(sta))
        new_time = dim_intersection([pw[sta], syn])
        syn_da = xr.DataArray(syn.sel(time=new_time))
        syn_da.name = '{}_class'.format(sta)
        syn_da = xr_reindex_with_date_range(syn_da)
        ds_list.append(syn_da)
    ds = xr.merge(ds_list)
    # Fill the gaps left by the reindex/merge BEFORE casting. Casting NaN to
    # an integer dtype produces arbitrary values, and fillna on an integer
    # array does nothing.
    ds = ds.fillna(0).astype('int8')
    filename = 'GNSS_synoptic_class.nc'
    save_ncfile(ds, path, filename)
    return ds
Beispiel #3
0
 def pw_mlh_to_df(pw_new, mlh_site):
     """Put PW and MLH on their common times into one hourly DataFrame.

     Both series are restricted to the times returned by
     ``dim_intersection``. The MLH values are added as a column named after
     ``mlh_site.name``. The frame is then reindexed to a 1-hour range from
     its first to its last timestamp, so missing hours become NaN rows.
     The index is named ``'time'``.
     """
     common_times = dim_intersection([pw_new, mlh_site])
     mlh_common = mlh_site.sel(time=common_times)
     pw_common = pw_new.sel(time=common_times)
     frame = pw_common.to_dataframe()
     frame[mlh_common.name] = mlh_common.to_dataframe()
     hourly_index = pd.date_range(frame.index.min(),
                                  frame.index.max(),
                                  freq='1H')
     frame = frame.reindex(hourly_index)
     frame.index.name = 'time'
     return frame
Beispiel #4
0
def read_BD_matfile(path=ceil_path, plot=True, month=None, add_syn=True):
    """Read the Beit-Dagan ceilometer mixing layer height (MLH) mat-file.

    Each row of ``pblBD4shlomi`` in ``PBL_BD_LST.mat`` holds a date (the
    first three columns, joined as ``c0-c1-c2``) followed by 48 values.
    Those values are placed on a 30-minute grid that starts 30 minutes after
    that date. All times are then shifted by -2 hours (commented in the code
    as a switch to UTC). The series is reindexed to a regular 30-minute
    range. Synoptic classes for 2015-2016 are forward-filled to the same
    grid and restricted to the MLH times.

    Parameters
    ----------
    path : pathlib.Path
        Directory that contains ``PBL_BD_LST.mat``.
    plot : bool
        If True, plot June-October of 2015 and 2016 (one colour per synoptic
        class when *add_syn*) and save ``MLH-BD_syn.png`` to
        ``savefig_path``.
    month :
        Accepted for compatibility but currently unused.
    add_syn : bool
        If True, return a Dataset with ``BD`` and ``syn`` (``class_abbr``);
        otherwise return the ``BD`` DataArray.
    """
    from scipy.io import loadmat
    import pandas as pd
    from aux_gps import xr_reindex_with_date_range
    import matplotlib.pyplot as plt
    from aux_gps import dim_intersection
    from synoptic_procedures import read_synoptic_classification
    mat_file = path / 'PBL_BD_LST.mat'
    # older files stored this array under the key 'PBL_BD_LST'
    mdata = loadmat(mat_file)['pblBD4shlomi']
    date_cols = mdata[:, :3].astype(str)
    mlh_rows = mdata[:, 3:]
    days = [
        pd.to_datetime(row[0] + '-' + row[1] + '-' + row[2])
        for row in date_cols
    ]
    half_hour = pd.Timedelta(0.5, unit='H')
    chunks = [
        pd.DataFrame(mlh_rows[k],
                     index=pd.date_range(day + half_hour,
                                         periods=48,
                                         freq='30T'))
        for k, day in enumerate(days)
    ]
    df = pd.concat(chunks)
    df.columns = ['MLH']
    df.index.name = 'time'
    # switch to UTC:
    df.index = df.index - pd.Timedelta(2, unit='H')
    da = df.to_xarray()['MLH']
    da.name = 'BD'
    da.attrs.update({
        'full_name': 'Mixing Layer Height',
        'name': 'MLH',
        'units': 'm',
        'station_full_name': 'Beit Dagan',
        'lon': 34.81,
        'lat': 32.00,
        'alt': 34,
    })
    da = xr_reindex_with_date_range(da, freq='30T')
    # add synoptic data:
    syn = read_synoptic_classification().to_xarray()
    syn = syn.sel(time=slice('2015', '2016')).resample(time='30T').ffill()
    syn_da = syn.sel(time=dim_intersection([da, syn]))
    syn_da = xr_reindex_with_date_range(syn_da, freq='30T')
    if plot:
        years = ('2015', '2016')
        xlims = {
            '2015': ['2015-06', '2015-10'],
            '2016': ['2016-06', '2016-10'],
        }
        frames = [da.sel(time=year).to_dataframe() for year in years]
        fig, axes = plt.subplots(2,
                                 1,
                                 sharey=True,
                                 sharex=False,
                                 figsize=(15, 10))
        if add_syn:
            cmap = plt.get_cmap("tab10")
            syn_df = syn_da.to_dataframe()
        for ax, year, frame in zip(axes, years, frames):
            if not add_syn:
                frame.plot(ax=ax, xlim=xlims[year])
                continue
            frame['synoptics'] = syn_df.loc[year, 'class_abbr']
            labels = []
            grouped = frame.groupby('synoptics')
            for color_idx, (label, group) in enumerate(grouped):
                labels.append(label)
                series = xr_reindex_with_date_range(group['BD'].to_xarray(),
                                                    freq='30T')
                series.to_dataframe().plot(x_compat=True,
                                           ms=10,
                                           color=cmap(color_idx),
                                           ax=ax,
                                           xlim=xlims[year])
            ax.legend(labels)
        for ax in axes.flatten():
            ax.set_ylabel('MLH [m]')
            ax.set_xlabel('UTC')
            ax.grid()
        fig.tight_layout()
        fig.suptitle('MLH from Beit-Dagan ceilometer for 2015 and 2016')
        plt.savefig(savefig_path / 'MLH-BD_syn.png', orientation='portrait')
    if not add_syn:
        return da
    ds = da.to_dataset(name='BD')
    ds['syn'] = syn_da['class_abbr']
    return ds
Beispiel #5
0
def align_pw_mlh(path=work_yuval,
                 ceil_path=ceil_path,
                 site='tela',
                 interpolate=None,
                 plot=True,
                 dt_range_str='2015'):
    """Align a GNSS PW station with its paired ceilometer MLH site (hourly).

    Parameters
    ----------
    path : pathlib.Path
        Directory containing ``GNSS_PW_thresh_50_homogenized.nc``.
    ceil_path : pathlib.Path
        Directory containing ``MLH_from_ceilometers.nc``.
    site : str
        GNSS station, one of 'tela', 'klhv', 'jslm', 'nzrt', 'yrcm'. The MLH
        variable is looked up as ``pw_mlh_dict.get(site)``.
    interpolate : optional
        If not None, used as ``max_gap`` for a cubic ``interpolate_na`` of
        both series.
    plot : bool
        Plot both series (and the interpolated ones) and save a png to
        ``savefig_path``.
    dt_range_str : str or None
        ``.loc`` selection applied to the non-interpolated frame.

    Returns
    -------
    xarray.Dataset
        PW and MLH with their original attrs. When *interpolate* is given,
        the interpolated data is returned, without the *dt_range_str*
        selection.
    """
    import xarray as xr
    from aux_gps import dim_intersection
    from aux_gps import xr_reindex_with_date_range
    import pandas as pd
    import matplotlib.pyplot as plt

    def pw_mlh_to_df(pw_new, mlh_site):
        newtime = dim_intersection([pw_new, mlh_site])
        MLH = mlh_site.sel(time=newtime)
        PW = pw_new.sel(time=newtime)
        df = PW.to_dataframe()
        df[MLH.name] = MLH.to_dataframe()
        new_time = pd.date_range(df.index.min(), df.index.max(), freq='1H')
        df = df.reindex(new_time)
        df.index.name = 'time'
        return df

    mlh = xr.load_dataset(ceil_path / 'MLH_from_ceilometers.nc')
    mlh_site = xr_reindex_with_date_range(mlh[pw_mlh_dict.get(site)],
                                          freq='1H')
    if interpolate is not None:
        print('interpolating ceil-site {} with max-gap of {}.'.format(
            pw_mlh_dict.get(site), interpolate))
        attrs = mlh_site.attrs
        mlh_site_inter = mlh_site.interpolate_na('time',
                                                 max_gap=interpolate,
                                                 method='cubic')
        mlh_site_inter.attrs = attrs
    # Read from the *path* argument (the global default was hard-coded
    # before). Load the station subset into memory so the file can be
    # closed.
    with xr.open_dataset(path / 'GNSS_PW_thresh_50_homogenized.nc') as pw_file:
        pw = pw_file[['tela', 'klhv', 'jslm', 'nzrt', 'yrcm']].load()
    pw_new = pw[site]
    if interpolate is not None:
        newtime = dim_intersection([pw_new, mlh_site_inter])
    else:
        newtime = dim_intersection([pw_new, mlh_site])
    pw_new = pw_new.sel(time=newtime)
    pw_new = xr_reindex_with_date_range(pw_new, freq='1H')
    if interpolate is not None:
        print('interpolating pw-site {} with max-gap of {}.'.format(
            site, interpolate))
        attrs = pw_new.attrs
        pw_new_inter = pw_new.interpolate_na('time',
                                             max_gap=interpolate,
                                             method='cubic')
        pw_new_inter.attrs = attrs
    df = pw_mlh_to_df(pw_new, mlh_site)
    if interpolate is not None:
        df_inter = pw_mlh_to_df(pw_new_inter, mlh_site_inter)
    if dt_range_str is not None:
        df = df.loc[dt_range_str, :]
    if plot:
        fig, ax = plt.subplots(figsize=(18, 5))
        if interpolate is not None:
            df_inter[pw_new.name].plot(style='b--', ax=ax)
            # same ax as above since it's automatically added on the right
            df_inter[mlh_site.name].plot(style='r--', secondary_y=True, ax=ax)
        ax = df[pw_new.name].plot(style='b-', marker='o', ax=ax, ms=5)
        # same ax as above since it's automatically added on the right
        ax_twin = df[mlh_site.name].plot(style='r-',
                                         marker='s',
                                         secondary_y=True,
                                         ax=ax,
                                         ms=5)
        if interpolate is not None:
            ax.legend(ax.get_lines() + ax.right_ax.get_lines(), [
                'PWV {} max interpolation'.format(interpolate), 'PWV',
                'MLH {} max interpolation'.format(interpolate), 'MLH'
            ],
                      loc='best')
        else:
            ax.legend([ax.get_lines()[0],
                       ax.right_ax.get_lines()[0]], ['PWV', 'MLH'],
                      loc='upper center')
        ax.set_title('MLH {} site and PWV {} site'.format(
            pw_mlh_dict.get(site), site))
        ax.set_xlim(df.dropna().index.min(), df.dropna().index.max())
        ax.set_ylabel('PWV [mm]', color='b')
        ax_twin.set_ylabel('MLH [m]', color='r')
        ax.tick_params(axis='y', colors='b')
        ax_twin.tick_params(axis='y', colors='r')
        ax.grid(True, which='both', axis='x')
        fig.tight_layout()
        if interpolate is not None:
            filename = '{}-{}_{}_time_series_{}_max_gap_interpolation.png'.format(
                site, pw_mlh_dict.get(site), dt_range_str, interpolate)
        else:
            filename = '{}-{}_{}_time_series.png'.format(
                site, pw_mlh_dict.get(site), dt_range_str)
        plt.savefig(savefig_path / filename, orientation='portrait')
    out_df = df_inter if interpolate is not None else df
    ds = out_df.to_xarray()
    ds[pw_new.name].attrs.update(pw_new.attrs)
    ds[mlh_site.name].attrs.update(mlh_site.attrs)
    return ds
Beispiel #6
0
def twin_hourly_mean_plot(pw,
                          mlh,
                          month=8,
                          ax=None,
                          title=True,
                          leg_loc='best',
                          unit='pts',
                          sample_rate=24,
                          fontsize=14):
    """Plot the mean diurnal cycle (+-1 std) of PW and MLH on twin y-axes.

    Parameters
    ----------
    pw, mlh : xarray.DataArray
        Time series; ``mlh`` must carry ``attrs['station_full_name']``.
    month : int or None
        If given, both series are restricted to that calendar month
        (multi-year). If None, both are restricted to their common times.
    ax : matplotlib Axes, optional
        Axes to draw on; a new figure is created when None.
    title : bool
        Add a descriptive title.
    leg_loc : str
        Legend location.
    unit : {'pts', 'days'}
        How the sample count in the legend is expressed. For 'days', the
        counts are divided by *sample_rate*. In the *month* case, the PW
        count is divided by 288 instead (see comment below).
    sample_rate : int
        Samples per day used for the 'days' conversion.
    fontsize : int
        Font size of labels, ticks and the station text box.

    Returns
    -------
    (ax, twin)
        The PW axes and the twin MLH axes.

    Raises
    ------
    ValueError
        If *unit* is not 'pts' or 'days'.
    """
    from aux_gps import dim_intersection
    import matplotlib.pyplot as plt
    from calendar import month_abbr
    if unit not in ('pts', 'days'):
        raise ValueError("unit must be 'pts' or 'days', got {!r}".format(unit))
    # first run multi-year month mean:
    if month is not None:
        pw = pw.sel(time=pw['time.month'] == month).dropna('time')
        mlh = mlh.sel(time=mlh['time.month'] == month).dropna('time')
    else:
        newtime = dim_intersection([pw, mlh], 'time')
        pw = pw.sel(time=newtime)
        mlh = mlh.sel(time=newtime)
    pw_hour = pw.groupby('time.hour').mean()
    pw_std = pw.groupby('time.hour').std()
    pw_hour_plus = (pw_hour + pw_std).values
    pw_hour_minus = (pw_hour - pw_std).values
    mlh_hour = mlh.groupby('time.hour').mean()
    mlh_std = mlh.groupby('time.hour').std()
    mlh_hour_minus = (mlh_hour - mlh_std).values
    mlh_hour_plus = (mlh_hour + mlh_std).values
    mlhyears = [mlh.time.dt.year.min().item(), mlh.time.dt.year.max().item()]
    pwyears = [pw.time.dt.year.min().item(), pw.time.dt.year.max().item()]
    # most frequent month in the MLH record (used in the legend labels)
    mlh_month = mlh.time.dt.month.to_dataframe()['month'].value_counts(
    ).index[0]
    if unit == 'pts':
        pw_pts = pw.dropna('time').size
        mlh_pts = mlh.dropna('time').size
    else:
        pw_pts = int(pw.dropna('time').size / sample_rate)
        mlh_pts = int(mlh.dropna('time').size / sample_rate)
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 8))
    red = 'tab:red'
    blue = 'tab:blue'
    pwln = pw_hour.plot(color=blue, marker='s', ax=ax)
    ax.fill_between(pw_hour.hour.values,
                    pw_hour_minus,
                    pw_hour_plus,
                    color=blue,
                    alpha=0.5)
    twin = ax.twinx()
    mlhln = mlh_hour.plot(color=red, marker='o', ax=twin)
    twin.fill_between(mlh_hour.hour.values,
                      mlh_hour_minus,
                      mlh_hour_plus,
                      color=red,
                      alpha=0.5)
    if month is None:
        pw_label = 'PWV: {}-{} ({} {})'.format(pwyears[0], pwyears[1], pw_pts,
                                               unit)
        mlh_label = 'MLH: {}-{} ({} {})'.format(mlhyears[0], mlhyears[1],
                                                mlh_pts, unit)
    else:
        if unit == 'days':
            # In the month case PW is not intersected with MLH, so it keeps
            # its own resolution. 288 samples/day presumably corresponds to
            # 5-minute PW data - TODO confirm. Previously this override was
            # also applied for unit='pts', which labelled a day count "pts".
            pw_pts = int(pw.dropna('time').size / 288)
        pw_label = 'PWV: {}-{}, {} ({} {})'.format(pwyears[0], pwyears[1],
                                                   month_abbr[mlh_month],
                                                   pw_pts, unit)
        mlh_label = 'MLH: {}-{}, {} ({} {})'.format(mlhyears[0], mlhyears[1],
                                                    month_abbr[mlh_month],
                                                    mlh_pts, unit)
    ax.legend(pwln + mlhln, [pw_label, mlh_label], loc=leg_loc)
    ax.tick_params(axis='y', colors=blue, labelsize=fontsize)
    twin.tick_params(axis='y', colors=red, labelsize=fontsize)
    ax.set_ylabel('PWV [mm]', color=blue, fontsize=fontsize)
    twin.set_ylabel('MLH [m]', color=red, fontsize=fontsize)
    ax.set_xticks([x for x in range(24)])
    ax.set_xlabel('Hour of day [UTC]', fontsize=fontsize)
    mlh_name = mlh.attrs['station_full_name'].replace('_', '-')
    textstr = '{}, {}'.format(mlh_name, pw.name.upper())
    props = dict(boxstyle='round', facecolor='white', alpha=0.5)
    ax.text(0.05,
            0.95,
            textstr,
            transform=ax.transAxes,
            fontsize=fontsize,
            verticalalignment='top',
            bbox=props)
    if title:
        ax.set_title(
            'The diurnal cycle of {} Mixing Layer Height and {} GNSS site PWV'.
            format(mlh_name, pw.name.upper()))
    return ax, twin
Beispiel #7
0
def scatter_plot_pw_mlh(pw,
                        mlh,
                        diurnal=False,
                        ax=None,
                        title=True,
                        leg_loc='best',
                        month=None):
    """Scatter MLH against PW and overlay a least-squares linear fit.

    Parameters
    ----------
    pw, mlh : xarray.DataArray
        Time series. ``pw.name`` must be a station in
        ``produce_geo_gnss_solved_stations``. ``mlh`` must carry
        ``attrs['station_full_name']`` and ``attrs['alt']``.
    diurnal : bool
        If True, compare hourly (time-of-day) means instead of raw samples.
    ax : matplotlib Axes, optional
    title : bool
    leg_loc : str
    month : int, optional
        With *diurnal*, restrict both series to this calendar month before
        averaging.

    Returns
    -------
    matplotlib Axes
        The axes with the scatter, the fit line, n / R^2 / slope annotations
        and (optionally) the title.
    """
    from aux_gps import dim_intersection
    import xarray as xr
    import numpy as np
    from sklearn.metrics import r2_score
    import matplotlib.pyplot as plt
    from PW_stations import produce_geo_gnss_solved_stations
    df = produce_geo_gnss_solved_stations(plot=False)
    pw_alt = df.loc[pw.name, 'alt']
    pw_attrs = pw.attrs
    mlh_attrs = mlh.attrs
    if diurnal:
        if month is not None:
            # filter BOTH series so the hourly means cover the same month
            # (previously only pw was filtered)
            pw = pw.sel(time=pw['time.month'] == month)
            mlh = mlh.sel(time=mlh['time.month'] == month)
        else:
            newtime = dim_intersection([pw, mlh], 'time')
            pw = pw.sel(time=newtime)
            mlh = mlh.sel(time=newtime)
        pw = pw.groupby('time.hour').mean()
        pw.attrs = pw_attrs
        mlh = mlh.groupby('time.hour').mean()
        mlh.attrs = mlh_attrs
    else:
        newtime = dim_intersection([pw, mlh], 'time')
        pw = pw.sel(time=newtime)
        mlh = mlh.sel(time=newtime)
    ds = xr.merge([pw, mlh])
    if ax is None:
        fig, ax = plt.subplots(figsize=(10, 10))
    ds.plot.scatter(pw.name, mlh.name, ax=ax)
    coefs = np.polyfit(pw.values, mlh.values, 1)
    x = np.linspace(pw.min().item(), pw.max().item(), 100)
    y = np.polyval(coefs, x)
    r2 = r2_score(mlh.values, np.polyval(coefs, pw.values))
    ax.plot(x, y, color='tab:red')
    ax.set_xlabel('PWV [mm]')
    ax.set_ylabel('MLH [m]')
    ax.legend(['linear fit', 'data'], loc=leg_loc)
    textstr = '\n'.join([
        'n={}'.format(pw.size), r'R$^2$={:.2f}'.format(r2),
        'slope={:.1f} m/mm'.format(coefs[0])
    ])
    props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
    ax.text(0.05,
            0.95,
            textstr,
            transform=ax.transAxes,
            fontsize=10,
            verticalalignment='top',
            bbox=props)
    mlh_name = mlh.attrs['station_full_name'].replace('_', '-')
    if title:
        ax.set_title(
            '{} ({:.0f} m) GNSS site PW vs. {} ({:.0f} m) Mixing Layer Height'.
            format(pw.name.upper(), pw_alt, mlh_name, mlh.attrs['alt']))
    return ax
Beispiel #8
0
def process_data_from_uwyo_sounding(path, st_num):
    """Create tm, tpw from sounding station and also add surface temp and
    station-calculated ipw.

    Reads ``ALL_<st_num>_soundings.nc`` and ``PW_<st_num>_soundings.nc``
    from *path* and restricts both to their common times. For every
    sounding it computes:

    * ``ts``: the ``TEMP`` value at ``mpoint=0`` plus 273.15, in K;
    * ``tpw``: the result of ``precipitable_water``, in mm;
    * ``tm``: the result of ``Tm``, in K.

    ``season`` and ``hour`` columns are added; the hour is stored as a
    string, with '12' replaced by 'noon' and '0' by 'midnight'. Rows with
    any NaN are dropped, and the result is written (zlib level 9) to
    ``station_<st_num>_sounding_pw_Ts_Tk.nc`` in *path*.

    Returns None.
    """
    import xarray as xr
    from aux_gps import dim_intersection
    import logging
    logger = logging.getLogger('uwyo')
    pw_file = 'PW_{}_soundings.nc'.format(st_num)
    all_file = 'ALL_{}_soundings.nc'.format(st_num)
    # Open both files in a context manager so their handles are released
    # once the common-time subsets are loaded into memory.
    with xr.open_dataarray(path / all_file) as da_src, \
            xr.open_dataarray(path / pw_file) as pw_src:
        new_time = dim_intersection([da_src, pw_src], 'time', dropna=False)
        logger.info('loaded {}'.format(pw_file))
        logger.info('loaded {}'.format(all_file))
        da = da_src.sel(time=new_time)
        pw = pw_src.sel(time=new_time)
        pw.load()
        da.load()
    logger.info('calculating pw and tm for station {}'.format(st_num))
    ts_list = []
    tpw_list = []
    tm_list = []
    for date in da.time:
        ts_list.append(da.sel(var='TEMP', mpoint=0, time=date) + 273.15)
        tpw_list.append(precipitable_water(da.sel(time=date)))
        tm_list.append(Tm(da.sel(time=date)))
    tpw = xr.DataArray(tpw_list, dims='time')
    tm = xr.DataArray(tm_list, dims='time')
    tm.attrs[
        'description'] = 'mean atmospheric temperature calculated by water vapor pressure weights'
    tm.attrs['units'] = 'K'
    ts = xr.concat(ts_list, 'time')
    ts.attrs[
        'description'] = 'Surface temperature from {} station soundings'.format(
            st_num)
    ts.attrs['units'] = 'K'
    result = pw.to_dataset(name='pw')
    result['tpw'] = tpw
    result['tm'] = tm
    result['ts'] = ts
    result['tpw'].attrs[
        'description'] = 'station {} percipatable water calculated from sounding by me'.format(
            st_num)
    result['tpw'].attrs['units'] = 'mm'
    result['season'] = result['time.season']
    result['hour'] = result['time.hour'].astype(str)
    result['hour'] = result.hour.where(result.hour != '12', 'noon')
    result['hour'] = result.hour.where(result.hour != '0', 'midnight')
    result = result.dropna('time')
    filename = 'station_{}_sounding_pw_Ts_Tk.nc'.format(st_num)
    logger.info('saving {} to {}'.format(filename, path))
    comp = dict(zlib=True, complevel=9)  # best compression
    encoding = {var: comp for var in result}
    result.to_netcdf(path / filename, 'w', encoding=encoding)
    logger.info('Done!')
    return
def evaluate_sin_to_tmsearies(
        da,
        time_dim='time',
        plot=True,
        params_up=None,
        func='sine',
        bounds=None,
        just_eval=False):
    """Fit (or just evaluate) a sinusoidal model to a time series.

    Time is expressed in days since the first non-NaN sample (a Julian-date
    difference). The fit uses only the non-NaN samples. The model is then
    evaluated on every time of *da*.

    Parameters
    ----------
    da : xarray.DataArray
        Series along *time_dim*.
    time_dim : str
        Name of the time dimension.
    plot : bool
        Plot the data and the model.
    params_up : dict, optional
        Overrides for the initial-guess / evaluation parameters (keys
        ``amp``, ``period``, ``phase``, ``offset``, ``a``, ``b``). ``None``
        keeps the defaults.
    func : {'sine', 'sine_on_linear', 'sine_on_quad'}
        Model to use.
    bounds : optional
        If not None, a fixed hard-coded set of bounds is applied to the fit
        (the value itself is not used).
    just_eval : bool
        If True, skip the fit and evaluate the model with the given params.

    Returns
    -------
    xarray.DataArray
        The model evaluated on ``da[time_dim]``.

    Raises
    ------
    ValueError
        If *func* is not one of the known models.
    """
    import pandas as pd
    import xarray as xr
    import numpy as np
    from aux_gps import dim_intersection
    from scipy.optimize import curve_fit
    from sklearn.metrics import mean_squared_error

    def sine(time, amp, period, phase, offset):
        f = amp * np.sin(2 * np.pi * (time / period + 1.0 / phase)) + offset
        return f

    def sine_on_linear(time, amp, period, phase, offset, a):
        f = a * time / 365.25 + amp * \
            np.sin(2 * np.pi * (time / period + 1.0 / phase)) + offset
        return f

    def sine_on_quad(time, amp, period, phase, offset, a, b):
        f = a * (time / 365.25) ** 2.0 + b * time / 365.25 + amp * \
            np.sin(2 * np.pi * (time / period + 1.0 / phase)) + offset
        return f

    # model name -> (function, number of parameters, description,
    #                extra curve_fit keyword arguments)
    models = {
        'sine': (sine, 4,
                 'Model chosen: y = amp * sin (2*pi*(x/T + 1/phi)) + offset',
                 dict(ftol=1e-9, xtol=1e-9)),
        'sine_on_linear': (
            sine_on_linear, 5,
            'Model chosen: y = a * x + amp * sin (2*pi*(x/T + 1/phi)) + offset',
            dict(xtol=1e-11, ftol=1e-11)),
        'sine_on_quad': (
            sine_on_quad, 6,
            'Model chosen: y = a * x^2 + b * x + amp * sin (2*pi*(x/T + 1/phi)) + offset',
            {}),
    }
    if func not in models:
        raise ValueError('func must be one of {}, got {!r}'.format(
            sorted(models), func))
    # parameter order shared by all models (the first n are used)
    param_names = ['amp', 'period', 'phase', 'offset', 'a', 'b']
    value_fmt = {'amp': '{:.4f}', 'a': '{:.7f}', 'b': '{:.7f}'}

    params = {
        'amp': 6.0,
        'period': 365.0,
        'phase': 30.0,
        'offset': 253.0,
        'a': 0.0,
        'b': 0.0
    }
    if params_up is not None:
        params.update(params_up)
    print(params)
    lower = dict.fromkeys(params, -np.inf)
    upper = dict.fromkeys(params, np.inf)
    if bounds is not None:
        lower['phase'] = 0.01
        upper['phase'] = 0.05
        lower['offset'] = 46.2
        upper['offset'] = 46.3
        lower['a'] = 0.0001
        upper['a'] = 0.002
        upper['amp'] = 0.04
        lower['amp'] = 0.02
    da_no_nans = da.dropna(time_dim)
    jul = pd.to_datetime(da_no_nans[time_dim].values).to_julian_date()
    # Keep the reference date: after re-basing, jul[0] is 0, so it can no
    # longer be used to shift the full time axis onto the same origin.
    jul0 = jul[0]
    jul = jul - jul0
    jul_with_nans = pd.to_datetime(
        da[time_dim].values).to_julian_date() - jul0
    ydata = da_no_nans.values

    model, n_par, description, fit_kws = models[func]
    names = param_names[:n_par]
    print(description)
    if just_eval:
        # evaluate with the supplied parameters instead of fitting
        coeffs = [params[name] for name in names]
    else:
        popt, pcov = curve_fit(model,
                               jul,
                               ydata,
                               p0=[params[name] for name in names],
                               bounds=([lower[name] for name in names],
                                       [upper[name] for name in names]),
                               **fit_kws)
        coeffs = list(popt)
        perr = np.sqrt(np.diag(pcov))
        for name, value, err in zip(names, coeffs, perr):
            fmt = value_fmt.get(name, '{:.2f}')
            print(('{}: ' + fmt + ' +- {:.2f}').format(name, value, err))
    new = model(jul_with_nans, *coeffs)
    new_da = xr.DataArray(new, dims=[time_dim])
    new_da[time_dim] = da[time_dim]
    resid = new_da - da
    rmean = np.mean(resid)
    new_time = dim_intersection([da, new_da], time_dim)
    rmse = np.sqrt(
        mean_squared_error(
            da.sel({
                time_dim: new_time
            }).values,
            new_da.sel({
                time_dim: new_time
            }).values))
    print('MEAN : {}'.format(rmean))
    print('RMSE : {}'.format(rmse))
    if plot:
        da.plot.line(marker='.', linewidth=0., figsize=(20, 5))
        new_da.plot.line(marker='.', linewidth=0.)
    return new_da