def test_kfold_random_stable(seed):
    """
    Check that kfold_random() gives the same answer under a different
    `nthreads` setting.

    The same (nrows, nsplits, seed) request is issued twice: once with the
    current options, and once inside a context where `nthreads` is set to
    half of its current value (integer division). Both results must contain
    `nsplits` pairs whose frames hold identical values.
    """
    random.seed(seed)
    nsplits = 2 + int(random.expovariate(0.2))
    nrows = nsplits + int(random.expovariate(0.000001))

    baseline = kfold_random(nrows=nrows, nsplits=nsplits, seed=seed)
    reduced_nthreads = dt.options.nthreads // 2
    with dt.options.context(nthreads=reduced_nthreads):
        rerun = kfold_random(nrows=nrows, nsplits=nsplits, seed=seed)

    assert len(baseline) == len(rerun) == nsplits
    for pair1, pair2 in zip(baseline, rerun):
        assert len(pair1) == len(pair2) == 2
        # Compare the single data column of each corresponding frame.
        for frame1, frame2 in zip(pair1, pair2):
            column1 = frame1.to_list()[0]
            column2 = frame2.to_list()[0]
            assert column1 == column2
def test_kfold_random_bad_args():
    """
    Check that invalid argument values passed to kfold_random() are reported
    through assert_value_error() with the expected messages.

    NOTE(review): another function with this same name is defined later in
    this file; in Python the later definition replaces this one when the
    module is imported -- confirm which version is meant to be kept.
    """
    invalid_calls = [
        (dict(nrows=-1, nsplits=1),
         "Argument `nrows` in kfold_random() cannot be negative"),
        (dict(nrows=3, nsplits=-3),
         "Argument `nsplits` in kfold_random() cannot be negative"),
        (dict(nrows=10, nsplits=3, seed=-1),
         "Argument `seed` in kfold_random() cannot be negative"),
        (dict(nrows=5, nsplits=0),
         "The number of splits cannot be less than two"),
        (dict(nrows=1, nsplits=2),
         "The number of splits cannot exceed the number of rows"),
    ]
    for kwargs, message in invalid_calls:
        # Bind the current kwargs as a default so the callable does not
        # depend on the loop variable.
        assert_value_error(lambda kw=kwargs: kfold_random(**kw), message)
def test_kfold_random_any(seed):
    """
    Validate the structure of kfold_random() output for random sizes.

    The number of splits and number of rows are drawn from the seeded `random`
    module. The result must be a list of `nsplits` (train, test) tuples, where:

      - each of train/test is a single-column int32 `dt.Frame`;
      - train and test together have `nrows` rows, and the test frame has
        either `nrows // nsplits` or `nrows // nsplits + 1` rows;
      - values in each frame are sorted and contain no duplicates;
      - the union of train and test values is exactly `range(nrows)`;
      - the test values of all folds combined are exactly `range(nrows)`.
    """
    random.seed(seed)
    nsplits = 2 + int(random.expovariate(0.01))
    nrows = nsplits + int(random.expovariate(0.0001))
    folds = kfold_random(nrows=nrows, nsplits=nsplits, seed=seed)
    assert isinstance(folds, list) and len(folds) == nsplits

    every_row = set(range(nrows))
    small_fold = nrows // nsplits
    test_rows_seen = []
    for fold in folds:
        assert isinstance(fold, tuple) and len(fold) == 2
        train, test = fold
        for frame in (train, test):
            assert isinstance(frame, dt.Frame)
            assert frame.stypes == (dt.int32,)
            assert frame.ncols == 1
        assert train.nrows + test.nrows == nrows
        assert test.nrows in [small_fold, small_fold + 1]

        train_rows = train.to_list()[0]
        test_rows = test.to_list()[0]
        assert train_rows == sorted(train_rows)
        assert test_rows == sorted(test_rows)

        # A set as large as the frame means the frame has no repeated values.
        unique_train = set(train_rows)
        unique_test = set(test_rows)
        assert len(unique_train) == train.nrows
        assert len(unique_test) == test.nrows
        assert (unique_train | unique_test) == every_row

        test_rows_seen.extend(test_rows)

    assert sorted(test_rows_seen) == list(range(nrows))
def test_kfold_random_2_2():
    """
    Check the smallest case visible here: 2 rows split into 2 folds.

    Each fold must be a (train, test) tuple of 1x1 frames. Whatever row `a`
    (0 or 1) is the train row of the first fold, the other row must be its
    test row, and the second fold must have the two roles swapped.
    """
    splits = kfold_random(nrows=2, nsplits=2)
    assert isinstance(splits, list) and len(splits) == 2
    for pair in splits:
        assert isinstance(pair, tuple) and len(pair) == 2
    for pair in splits:
        assert pair[0].shape == (1, 1) and pair[1].shape == (1, 1)

    (train0, test0), (train1, test1) = splits
    a = train0[0, 0]
    assert a in (0, 1)
    other = 1 - a
    assert test0[0, 0] == other
    assert train1[0, 0] == other
    assert test1[0, 0] == a
def test_kfold_random_bad_args():
    """
    Check that invalid argument values passed to kfold_random() raise
    ValueError with a message matching the expected pattern.

    Each entry below pairs the keyword arguments of one bad call with the
    regular expression handed to `pytest.raises(..., match=...)`.
    """
    bad_calls = [
        (dict(nrows=-1, nsplits=1),
         r"Argument nrows in function datatable.kfold_random\(\) cannot be negative"),
        (dict(nrows=3, nsplits=-3),
         r"Argument nsplits in function datatable.kfold_random\(\) cannot be negative"),
        (dict(nrows=10, nsplits=3, seed=-1),
         r"Argument seed in function datatable.kfold_random\(\) cannot be negative"),
        (dict(nrows=5, nsplits=0),
         "The number of splits cannot be less than two"),
        (dict(nrows=1, nsplits=2),
         "The number of splits cannot exceed the number of rows"),
    ]
    for kwargs, pattern in bad_calls:
        with pytest.raises(ValueError, match=pattern):
            kfold_random(**kwargs)
def test_kfold_random_api():
    """
    Check that calling kfold_random() with missing, positional, or wrongly
    typed arguments is reported through assert_type_error() with the expected
    messages.

    NOTE(review): another function with this same name is defined later in
    this file; in Python the later definition replaces this one when the
    module is imported -- confirm which version is meant to be kept.
    """
    # Each case: (positional args, keyword args, expected message).
    misuse_cases = [
        ((), dict(),
         "Required parameter `nrows` is missing"),
        ((), dict(nrows=5, seed=12345678),
         "Required parameter `nsplits` is missing"),
        ((5, 2), dict(),
         "kfold_random() takes no positional arguments, but 2 were given"),
        ((), dict(nrows=5, nsplits=3.3),
         "Argument `nsplits` in kfold_random() should be an integer"),
        ((), dict(nrows=None, nsplits=7),
         "Argument `nrows` in kfold_random() should be an integer"),
        ((), dict(nrows=5, nsplits=2, seed="boo"),
         "Argument `seed` in kfold_random() should be an integer"),
    ]
    for args, kwargs, message in misuse_cases:
        # Bind the current arguments as defaults so the callable does not
        # depend on the loop variables.
        assert_type_error(
            lambda a=args, kw=kwargs: kfold_random(*a, **kw),
            message,
        )
def test_kfold_random_api():
    """
    Check that calling kfold_random() with missing, positional, or wrongly
    typed arguments raises TypeError with a message matching the expected
    pattern.

    Each entry below holds the positional arguments, the keyword arguments,
    and the regular expression handed to `pytest.raises(..., match=...)`.
    """
    misuse_cases = [
        ((), dict(),
         "Required parameter nrows is missing"),
        ((), dict(nrows=5, seed=12345678),
         "Required parameter nsplits is missing"),
        ((5, 2), dict(),
         r"kfold_random\(\) takes no positional arguments, but 2 were given"),
        ((), dict(nrows=5, nsplits=3.3),
         r"Argument nsplits in function datatable.kfold_random\(\) should be an integer"),
        ((), dict(nrows=None, nsplits=7),
         r"Argument nrows in function datatable.kfold_random\(\) should be an integer"),
        ((), dict(nrows=5, nsplits=2, seed="boo"),
         r"Argument seed in function datatable.kfold_random\(\) should be an integer"),
    ]
    for args, kwargs, pattern in misuse_cases:
        with pytest.raises(TypeError, match=pattern):
            kfold_random(*args, **kwargs)