def test_temporal_split(self):
    """Check temporal_split with test_rate=0.2 and by_user=False.

    Verifies the fixture shape, the shape of the split output, the number
    of rows flagged "train" / "validate" / "test", and the first four
    column values of every "validate" and "test" row.
    """
    fixture = generate_temporal_testdata()
    # The raw fixture is expected to be 10 rows by 5 columns.
    for axis, expected_size in enumerate((10, 5)):
        self.assertEqual(fixture.shape[axis], expected_size)

    split = temporal_split(fixture, test_rate=0.2, by_user=False)
    # Row count is unchanged; the output has one more column than the
    # input (presumably the flag column -- confirm against temporal_split).
    for axis, expected_size in enumerate((10, 6)):
        self.assertEqual(split.shape[axis], expected_size)

    # Partition by flag value and check how many rows landed in each part.
    expected_counts = (("train", 6), ("validate", 2), ("test", 2))
    parts = {}
    for flag, expected_rows in expected_counts:
        part = split[split[DEFAULT_FLAG_COL] == flag]
        self.assertEqual(part.shape[0], expected_rows)
        parts[flag] = part

    # Expected values of the first four columns, row by row (positional).
    expected_cells = (
        ("validate", ((1, 3, 400, 50), (1, 4, 410, 40))),
        ("test", ((2, 5, 500, 60), (2, 5, 500, 10))),
    )
    for flag, expected_rows in expected_cells:
        part = parts[flag]
        for row_pos, expected_row in enumerate(expected_rows):
            for col_pos, expected_value in enumerate(expected_row):
                self.assertEqual(part.iloc[row_pos, col_pos], expected_value)
def test_temporal_split_by_user(self):
    """Check temporal_split with test_rate=0.1 and by_user=True.

    Verifies the fixture shape, the shape of the split output, the number
    of rows flagged "train" / "validate" / "test", and the first four
    column values of every "validate" and "test" row.
    """
    fixture = generate_data_by_user()
    # The raw fixture is expected to be 8 rows by 5 columns.
    for axis, expected_size in enumerate((8, 5)):
        self.assertEqual(fixture.shape[axis], expected_size)

    split = temporal_split(fixture, test_rate=0.1, by_user=True)
    # Row count is unchanged; the output has one more column than the
    # input (presumably the flag column -- confirm against temporal_split).
    for axis, expected_size in enumerate((8, 6)):
        self.assertEqual(split.shape[axis], expected_size)

    # Partition by flag value and check how many rows landed in each part.
    expected_counts = (("train", 4), ("validate", 2), ("test", 2))
    parts = {}
    for flag, expected_rows in expected_counts:
        part = split[split[DEFAULT_FLAG_COL] == flag]
        self.assertEqual(part.shape[0], expected_rows)
        parts[flag] = part

    # Expected values of the first four columns, row by row (positional).
    expected_cells = (
        ("validate", ((0, 2, 300, 20), (1, 5, 600, 40))),
        ("test", ((0, 2, 300, 30), (1, 5, 600, 60))),
    )
    for flag, expected_rows in expected_cells:
        part = parts[flag]
        for row_pos, expected_row in enumerate(expected_rows):
            for col_pos, expected_value in enumerate(expected_row):
                self.assertEqual(part.iloc[row_pos, col_pos], expected_value)