コード例 #1
0
 def __init__(self,
              df_path: str,
              forecast_total: int,
              use_real_precip=True,
              use_real_temp=True,
              target_supplied=True,
              interpolate=False,
              **kwargs):
     """
     Load a CSV of test data into ``self.original_df``.

     :param str df_path: Path to the CSV file to load.
     :param int forecast_total: Total number of forecast steps; stored as
         ``self.forecast_total``.
     :param bool use_real_precip: Flag stored as ``self.use_real_precip``.
     :param bool use_real_temp: Flag stored as ``self.use_real_temp``.
     :param bool target_supplied: Flag stored as ``self.target_supplied``.
     :param bool interpolate: If True, build the frame with
         ``fix_timezones(df_path)`` and then fill gaps with
         ``interpolate_missing_values``; otherwise load it with a plain
         ``pd.read_csv``.
     :param kwargs: Forwarded unchanged to the parent ``__init__``.
     """
     super().__init__(**kwargs)
     # Read the file only once: fix_timezones builds its own DataFrame from
     # the path, so a preceding pd.read_csv would be parsed and thrown away.
     if interpolate:
         self.original_df = fix_timezones(df_path)
         self.original_df = interpolate_missing_values(self.original_df)
     else:
         self.original_df = pd.read_csv(df_path)
     print("CSV Path below")
     print(df_path)
     self.forecast_total = forecast_total
     self.use_real_temp = use_real_temp
     self.use_real_precip = use_real_precip
     self.target_supplied = target_supplied
     # Convert back to datetime (pd.read_csv leaves it as text by default)
     # and record the current index as a regular column.
     self.original_df["datetime"] = self.original_df["datetime"].astype(
         "datetime64[ns]")
     self.original_df["original_index"] = self.original_df.index
コード例 #2
0
 def __init__(self,
              file_path: str,
              forecast_history: int,
              forecast_length: int,
              target_col: List,
              relevant_cols: List,
              scaling=None,
              start_stamp: int = 0,
              end_stamp: int = None,
              interpolate_param=True):
     """
     A data loader that takes a CSV file and properly batches for use in training/eval a PyTorch model.

     :param file_path: The path to the CSV file you wish to use.
     :param forecast_history: The length of the historical time series window used for forecasting.
     :param forecast_length: The number of time steps to forecast ahead
         (for transformer this must equal history_length).
     :param target_col: The target column or columns to predict. If you only have one, still use a
         list, e.g. ['cfs']. Only ``target_col[0]`` is used to fit ``self.targ_scaler``.
     :param relevant_cols: The column names to keep from the CSV; all other columns are dropped.
         When ``scaling`` is given it must include ``target_col[0]``.
     :param scaling: (highly recommended) If provided, should be an instance of a subclass of
         sklearn.base.BaseEstimator and sklearn.base.TransformerMixin (e.g. StandardScaler,
         MaxAbsScaler, MinMaxScaler). Note: without a scaler the loss is likely to explode and
         cause infinite loss, which will corrupt weights.
     :param start_stamp: Optional first row position (inclusive) to keep, counted in the frame
         after sorting by ``datetime``. Use to take only part of a CSV for training, validation
         or testing.
     :param end_stamp: Optional row position (exclusive) at which to stop, counted in the same
         sorted frame as ``start_stamp``. ``None`` means "to the end of the data".
     :param interpolate_param: If True, load via ``fix_timezones`` and fill gaps with
         ``interpolate_missing_values``; otherwise load with ``pd.read_csv``.
     :raises ValueError: If NaN values remain in the selected columns.
     """
     super().__init__()
     self.forecast_history = forecast_history
     self.forecast_length = forecast_length
     # TODO allow other filling methods
     print("interpolate should be below")
     if interpolate_param:
         print("now filling missing values")
         df = fix_timezones(file_path)
         df = interpolate_missing_values(df)
     else:
         df = pd.read_csv(file_path)
     print("Now loading and scaling " + file_path)
     self.df = df.sort_values(by='datetime')[relevant_cols]
     self.scale = None
     # Both stamps are positions in the same sorted frame, so slice once.
     # Slicing [start:] and then [:end] would reinterpret end_stamp relative
     # to start_stamp and keep rows start..start+end-1.
     if start_stamp != 0 or end_stamp is not None:
         self.df = self.df.iloc[start_stamp:end_stamp]
     if scaling is not None:
         self.scale = scaling
         temp_df = self.scale.fit_transform(self.df)
         # We define a second scaler to scale the end output
         # back to normal as models might not necessarily predict
         # other present time series values.
         # NOTE(review): the target scaler is built with default constructor
         # arguments, so options configured on `scaling` (e.g. a MinMaxScaler
         # feature_range) are not carried over — confirm this is intended.
         targ_scale_class = self.scale.__class__
         self.targ_scaler = targ_scale_class()
         self.targ_scaler.fit_transform(
             self.df[target_col[0]].values.reshape(-1, 1))
         self.df = pd.DataFrame(temp_df,
                                index=self.df.index,
                                columns=self.df.columns)
     # Raising a plain string is a TypeError in Python 3; use a real exception
     # so the message reaches the caller.
     if self.df.isna().to_numpy().any():
         raise ValueError(
             "Error nan values detected in data. Please run interpolate ffill or bfill on data"
         )
     self.targ_col = target_col
コード例 #3
0
 def test_tz_interpolate_fix(self):
     """
     Check timezone fixing and gap interpolation on the small river sample.

     First spot-check two values produced by ``fix_timezones``; then confirm
     that ``interpolate_missing_values`` leaves no nulls in the ``cfs`` and
     ``precip`` columns.
     """
     csv_file = os.path.join(self.test_data_path, "river_test_sm.csv")
     tz_df = fix_timezones(csv_file)
     first_row = tz_df.iloc[0]
     second_row = tz_df.iloc[1]
     self.assertEqual(first_row['cfs'], 0.0)
     self.assertEqual(second_row['tmpf'], 19.94)
     filled_df = interpolate_missing_values(tz_df)
     for column in ('cfs', 'precip'):
         null_count = sum(pd.isnull(filled_df[column]))
         self.assertEqual(0, null_count)