Esempio n. 1
0
def test_cv_splits_no_stratify():
    """Check ``cv_splits`` when stratification is disabled.

    Verifies that the train and test parts of every fold are disjoint and
    together cover all ``total`` examples, that across folds every example
    appears in test exactly once, and that invalid requests raise.
    """

    total = 10
    n_folds = 5
    cver = cv_splits(total, n_folds, stratify=False)

    # Each fold contains the whole set without repetitions
    for fold in range(n_folds):
        train, trest = cver(fold)
        in_train = set(train)
        in_trest = set(trest)
        assert len(in_train & in_trest) == 0, "There must be no repetition in train/test"
        assert len(in_train | in_trest) == total, "All the examples are in the split"

    # Each example is exactly once in test
    in_trest = list(chain.from_iterable(cver(fold)[1] for fold in range(n_folds)))
    assert len(set(in_trest)) == len(in_trest), "Each example must be tested exactly once"
    assert len(in_trest) == total, "Each example must be in test at least once"

    # We get an exception if we try to get stratified sampling without providing Y
    with pytest.raises(Exception):
        cv_splits(total, n_folds)

    # We get an exception if we try to get an unexistent fold.
    # Reuse the non-stratified splitter built above: constructing a stratified
    # one without Y inside the raises-block would raise by itself (see the
    # previous check) and the fold lookup would never be exercised.
    with pytest.raises(Exception):
        cver(11)

    # We get an exception if we try to get more folds than instances
    with pytest.raises(Exception):
        cv_splits(5, 6, stratify=False)
Esempio n. 2
0
def test_cv_splits_stratified():
    """Check stratified splitting and the banned-index options of ``cv_splits``.

    With a perfectly balanced binary ``Y`` the mean fraction of positives in
    the train parts must stay close to 0.5.  Indices passed as
    ``banned_train`` / ``banned_test`` must never appear in the train / test
    part of any fold, whether given as a numpy array or as a plain list.
    """

    total = 10
    n_folds = 5

    # Stratify test
    Y = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
    cver = cv_splits(total, n_folds, Y=Y)
    balancy_train = []
    for fold in range(n_folds):
        # Fetch the train part once and reuse it for both the sum and the size
        train = cver(fold)[0]
        balancy_train.append(np.sum(Y[train]) / float(len(train)))
    # the average balancy should be around 0.5
    mean_balancy = np.mean(np.array(balancy_train))
    assert 0.48 < mean_balancy
    assert 0.52 > mean_balancy

    # Banned train
    cver = cv_splits(total, n_folds, stratify=False, banned_train=np.array([1]))
    for fold in range(n_folds):
        assert 1 not in cver(fold)[0]

    # Banned train 2 (a larger set, banned indices given as a plain list)
    cver = cv_splits(20, n_folds, stratify=False, banned_train=[0, 1])
    for fold in range(n_folds):
        train = cver(fold)[0]
        assert 1 not in train
        assert 0 not in train

    # Banned test
    cver = cv_splits(total, n_folds, stratify=False, banned_test=np.array([1]))
    for fold in range(n_folds):
        assert 1 not in cver(fold)[1]

    # Banned test 2 (banned indices given as a plain list)
    cver = cv_splits(total, n_folds, stratify=False, banned_test=[0, 1])
    for fold in range(n_folds):
        trest = cver(fold)[1]
        assert 1 not in trest
        assert 0 not in trest