Пример #1
0
 def _setup_obs_handler(self):
     """Read the ratings CSV, sample it and keep a train/test split on ``self.state``.

     The training observations are afterwards passed through
     ``self._add_testing_obs_data`` and the result is stored back on them.
     """
     sampled = ObservationsDF(
         pd.read_csv(rating_csv_path),
         uid_col=self.user_id_col,
         iid_col=self.item_id_col,
     ).sample_observations(n_users=1000, n_items=1000)

     train_part, test_part = sampled.split_train_test(ratio=0.2, users_ratio=1.0)
     self.state.train_obs = train_part
     self.state.test_obs = test_part

     # add some fake data for sanity tests
     current_train_df = self.state.train_obs.df_obs
     self.state.train_obs.df_obs = self._add_testing_obs_data(current_train_df)
Пример #2
0
    def test_splits(self):
        """Exercise the regular, partial-users and timestamp-based train/test splits."""
        from ml_recsys_tools.data_handlers.interaction_handlers_base import ObservationsDF

        sampled = ObservationsDF(
            pd.read_csv(rating_csv_path),
            uid_col='userid',
            iid_col='itemid',
            timestamp_col='timestamp',
        ).sample_observations(n_users=1000, n_items=1000)

        test_fraction = 0.2

        # regular split; its result is also kept on self.state
        train_part, test_part = sampled.split_train_test(ratio=test_fraction)
        self._obs_split_data_check(sampled, train_part, test_part)
        self.state.train_obs = train_part
        self.state.test_obs = test_part

        # split for only some of the users
        users_fraction = 0.2
        train_part, test_part = sampled.split_train_test(
            ratio=test_fraction, users_ratio=users_fraction)
        self._obs_split_data_check(sampled, train_part, test_part)
        n_test_users = test_part.df_obs['userid'].nunique()
        n_train_users = train_part.df_obs['userid'].nunique()
        # ratio of distinct test users to distinct train users should match the request
        self.assertAlmostEqual(users_fraction, n_test_users / n_train_users, places=1)

        # split by timestamp: no test row may be earlier than the latest train row
        ts_col = sampled.timestamp_col
        train_part, test_part = sampled.split_train_test(
            ratio=test_fraction, time_split_column=ts_col)
        self._obs_split_data_check(sampled, train_part, test_part)
        earliest_test = test_part.df_obs[ts_col].min()
        latest_train = train_part.df_obs[ts_col].max()
        self.assertGreaterEqual(earliest_test, latest_train)
Пример #3
0
    def test_splits(self):
        """Run ``_split_tester`` on plain observations and on observations with item data."""
        ratings_df = pd.read_csv(rating_csv_path)
        column_names = {
            'uid_col': 'userid',
            'iid_col': 'itemid',
            'timestamp_col': 'timestamp',
        }

        # plain observations handler
        plain_obs = ObservationsDF(ratings_df, **column_names)
        plain_sample = plain_obs.sample_observations(n_users=1000, n_items=1000)
        self._split_tester(plain_sample)

        # observations handler that additionally receives the items table
        movies_df = pd.read_csv(movies_csv_path)
        featured_obs = ObsWithFeatures(
            df_obs=ratings_df,
            df_items=movies_df,
            item_id_col='itemid',
            **column_names
        )
        featured_sample = featured_obs.sample_observations(n_users=1000, n_items=1000)
        self._split_tester(featured_sample)