示例#1
0
def test_sample_rows():
    """Row sampling must raise ``ValueError`` for a frame with a non-unique index."""
    # Re-keying the ratings on the 'user' column forces a non-unique index.
    non_unique = lktu.ml_test.ratings.set_index('user')
    with pytest.raises(ValueError):
        # Exhaust the result so the error surfaces even if it is raised lazily.
        list(xf.sample_rows(non_unique, partitions=5, size=1000))
示例#2
0
def test_sample_oversize():
    """Request 50 samples of 10000 rows and check every split covers the data.

    Each split's test and train parts must together account for every rating
    exactly once, with no (user, item) pair appearing in both parts.
    NOTE(review): presumably 50 * 10000 exceeds the number of ratings, hence
    "oversize" -- confirm against the test data set.
    """
    ratings = lktu.ml_test.ratings
    parts = list(xf.sample_rows(ratings, 50, 10000))
    assert len(parts) == 50

    for part in parts:
        test_frame = part.test
        train_frame = part.train

        # Row counts add up to the full data set.
        assert len(test_frame) + len(train_frame) == len(ratings)

        # The combined row labels reproduce the original index.
        combined = test_frame.index.union(train_frame.index)
        assert all(combined == ratings.index)

        # No (user, item) pair is shared between test and train.
        test_pairs = test_frame.set_index(['user', 'item']).index
        train_pairs = train_frame.set_index(['user', 'item']).index
        shared = test_pairs.intersection(train_pairs)
        assert len(shared) == 0
示例#3
0
def test_sample_non_disjoint():
    """Non-disjoint sampling yields valid splits whose test sets can overlap.

    Draws 10 samples of 1000 rows with ``disjoint=False`` and checks that:

    * every split has exactly 1000 test rows and covers all ratings,
    * within a split, test and train share no (user, item) pair,
    * at least one pair of *different* splits shares a test (user, item) pair.
    """
    ratings = lktu.ml_test.ratings
    splits = xf.sample_rows(ratings, partitions=10, size=1000, disjoint=False)
    splits = list(splits)
    assert len(splits) == 10

    for s in splits:
        assert len(s.test) == 1000
        assert len(s.test) + len(s.train) == len(ratings)
        test_idx = s.test.set_index(['user', 'item']).index
        train_idx = s.train.set_index(['user', 'item']).index
        assert len(test_idx.intersection(train_idx)) == 0

    # There are enough splits & items we should pick at least one duplicate.
    # Key on the full (user, item) pair: set_index('user', 'item') would pass
    # 'item' as the `drop` argument and index on user alone.
    test_idxs = [s.test.set_index(['user', 'item']).index for s in splits]
    # Compare only distinct splits; a split paired with itself always
    # intersects, which would make the assertion below trivially true.
    isizes = [len(i1.intersection(i2))
              for (i1, i2) in it.combinations(test_idxs, 2)]
    assert any(n > 0 for n in isizes)
示例#4
0
def test_sample_rows_more_smaller_parts():
    """Ten 500-row samples are valid splits with mutually disjoint test sets."""
    ratings = lktu.ml_test.ratings
    parts = list(xf.sample_rows(ratings, partitions=10, size=500))
    assert len(parts) == 10

    def pair_index(frame):
        # Index a frame by its (user, item) pairs for set-style comparison.
        return frame.set_index(['user', 'item']).index

    for part in parts:
        assert len(part.test) == 500
        assert len(part.test) + len(part.train) == len(ratings)
        # Within a split, test and train must not share any rating.
        overlap = pair_index(part.test).intersection(pair_index(part.train))
        assert len(overlap) == 0

    # Across different splits, the test sets must not share any rating.
    for left, right in it.product(parts, repeat=2):
        if left is not right:
            left_pairs = pair_index(left.test)
            right_pairs = pair_index(right.test)
            common = left_pairs.intersection(right_pairs)
            assert len(common) == 0