def test_cv_splits_no_stratify():
    """Check non-stratified cross-validation splits produced by ``cv_splits``.

    Verifies that, for every fold, the train and test parts are disjoint and
    together cover all ``total`` examples; that the test parts over all folds
    cover every example exactly once; and that invalid requests (stratified
    mode without ``Y``, an out-of-range fold, more folds than instances)
    raise an exception.
    """
    total = 10
    n_folds = 5
    cver = cv_splits(total, n_folds, stratify=False)

    # Each fold contains the whole set without repetitions
    for fold in range(n_folds):
        train, test = cver(fold)
        in_train = set(train)
        in_test = set(test)
        assert len(in_train & in_test) == 0, "There must be no repetition in train/test"
        assert len(in_train | in_test) == total, "All the examples are in the split"

    # Each example is exactly once in test
    tested = list(chain.from_iterable(cver(fold)[1] for fold in range(n_folds)))
    assert len(set(tested)) == len(tested), "Each example must be tested exactly once"
    assert len(tested) == total, "Each example must be in test at least once"

    # We get an exception if we try to get stratified sampling without providing Y
    with pytest.raises(Exception):
        cv_splits(10, n_folds)

    # We get an exception if we try to get a nonexistent fold.
    # The splitter must be built with stratify=False *outside* the raises
    # block: with the default (stratified) mode and no Y, construction itself
    # raises (checked just above), which would make this check pass without
    # ever looking up fold 11.
    with pytest.raises(Exception):
        cver(11)

    # We get an exception if we try to get more folds than instances
    with pytest.raises(Exception):
        cv_splits(5, 6, stratify=False)
def test_cv_splits_stratified():
    """Check stratified splits and the banned_train / banned_test options of ``cv_splits``.

    * With a balanced binary ``Y``, the mean fraction of positives in the
      training parts must be close to 0.5.
    * Indices passed as ``banned_train`` must never appear in any fold's
      training part.
    * Indices passed as ``banned_test`` must never appear in any fold's
      test part.
    """
    total = 10
    n_folds = 5

    # Stratify test
    Y = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
    cver = cv_splits(total, n_folds, Y=Y)
    balancy_train = []
    for fold in range(n_folds):
        # Fetch the split once so numerator and denominator are computed
        # from the same training set.
        train = cver(fold)[0]
        balancy_train.append(np.sum(Y[train]) / float(len(train)))
    # the average balancy should be around 0.5
    mean_balancy = np.mean(np.array(balancy_train))
    assert 0.48 < mean_balancy < 0.52

    # Banned train
    cver = cv_splits(total, n_folds, stratify=False, banned_train=np.array([1]))
    for fold in range(n_folds):
        assert 1 not in cver(fold)[0]

    # Banned train 2 (list instead of array; uses 20 examples rather than total)
    cver = cv_splits(20, n_folds, stratify=False, banned_train=[0, 1])
    for fold in range(n_folds):
        train = cver(fold)[0]
        assert 1 not in train
        assert 0 not in train

    # Banned test
    cver = cv_splits(total, n_folds, stratify=False, banned_test=np.array([1]))
    for fold in range(n_folds):
        assert 1 not in cver(fold)[1]

    # Banned test 2 (list instead of array)
    cver = cv_splits(total, n_folds, stratify=False, banned_test=[0, 1])
    for fold in range(n_folds):
        test = cver(fold)[1]
        assert 1 not in test
        assert 0 not in test