def test_sample_rows():
    """Row sampling must reject a dataframe whose index is not unique."""
    # Re-indexing the ratings by user forces a non-unique index.
    non_unique = lktu.ml_test.ratings.set_index('user')

    with pytest.raises(ValueError):
        # Iterate so the error surfaces even if the splits are produced lazily.
        for _ in xf.sample_rows(non_unique, partitions=5, size=1000):
            pass
def test_sample_oversize():
    """Ask for 50 partitions of 10000 rows and check every split is a clean partition.

    NOTE(review): the name suggests 50 x 10000 exceeds the number of available
    ratings -- confirm against the test data set.
    """
    ratings = lktu.ml_test.ratings
    splits = list(xf.sample_rows(ratings, 50, 10000))
    assert len(splits) == 50

    key_cols = ['user', 'item']
    for split in splits:
        test, train = split.test, split.train

        # Train and test together must account for every rating row ...
        assert len(test) + len(train) == len(ratings)
        assert all(test.index.union(train.index) == ratings.index)

        # ... and must not share any (user, item) pair.
        test_keys = test.set_index(key_cols).index
        train_keys = train.set_index(key_cols).index
        assert len(test_keys.intersection(train_keys)) == 0
def test_sample_non_disjoint():
    """Non-disjoint sampling yields valid splits whose test sets may overlap.

    Each of the 10 splits must hold exactly 1000 test rows, partition the
    ratings between train and test, and share no (user, item) pair between
    its own train and test sets.  Because ``disjoint=False``, at least one
    pair of *different* splits is expected to share a test row.
    """
    ratings = lktu.ml_test.ratings
    splits = xf.sample_rows(ratings, partitions=10, size=1000, disjoint=False)
    splits = list(splits)
    assert len(splits) == 10

    test_indexes = []
    for s in splits:
        assert len(s.test) == 1000
        assert len(s.test) + len(s.train) == len(ratings)
        # Key on the (user, item) pair; the column names must be passed as a
        # list, since a second positional argument is not a second key.
        test_idx = s.test.set_index(['user', 'item']).index
        train_idx = s.train.set_index(['user', 'item']).index
        assert len(test_idx.intersection(train_idx)) == 0
        test_indexes.append(test_idx)

    # There are enough splits & items we should pick at least one duplicate.
    # Only distinct pairs of splits are compared: a split always overlaps
    # itself, which would make this assertion pass vacuously.
    overlap_sizes = [
        len(i1.intersection(i2))
        for (i1, i2) in it.combinations(test_indexes, 2)
    ]
    assert any(n > 0 for n in overlap_sizes)
def test_sample_rows_more_smaller_parts():
    """Sample 10 partitions of 500 rows and check the test sets never overlap."""
    ratings = lktu.ml_test.ratings
    splits = list(xf.sample_rows(ratings, partitions=10, size=500))
    assert len(splits) == 10

    def pair_index(frame):
        # Index a frame by its (user, item) pair so rows can be compared as sets.
        return frame.set_index(['user', 'item']).index

    # Each split on its own: right size, full coverage, train/test disjoint.
    for part in splits:
        assert len(part.test) == 500
        assert len(part.test) + len(part.train) == len(ratings)
        assert len(pair_index(part.test).intersection(pair_index(part.train))) == 0

    # Across splits: no two different splits may share a test row.
    for left, right in it.product(splits, repeat=2):
        if left is right:
            continue
        shared = pair_index(left.test).intersection(pair_index(right.test))
        assert len(shared) == 0