Ejemplo n.º 1
0
def test_kfold_random_stable(seed):
    """Check that a fixed seed gives the same splits at a halved thread count.

    The number of folds and the number of rows are drawn from exponential
    distributions after seeding ``random`` with ``seed``. The splits are
    computed once with the current settings and once inside a context where
    ``nthreads`` is set to half of its current value, and the two results
    are compared value by value.
    """
    random.seed(seed)
    nfolds = 2 + int(random.expovariate(0.2))
    nrows = nfolds + int(random.expovariate(0.000001))
    baseline = kfold_random(nrows=nrows, nsplits=nfolds, seed=seed)
    with dt.options.context(nthreads=dt.options.nthreads // 2):
        rerun = kfold_random(nrows=nrows, nsplits=nfolds, seed=seed)
    assert len(baseline) == len(rerun) == nfolds
    for fold_a, fold_b in zip(baseline, rerun):
        # Each fold is expected to hold exactly two frames.
        assert len(fold_a) == len(fold_b) == 2
        for frame_a, frame_b in zip(fold_a, fold_b):
            column_a = frame_a.to_list()[0]
            column_b = frame_b.to_list()[0]
            assert column_a == column_b
Ejemplo n.º 2
0
def test_kfold_random_bad_args():
    """Check the error messages produced for invalid argument values.

    Each case pairs a set of keyword arguments with the message passed to
    ``assert_value_error`` (presumably the helper verifies that a ValueError
    with that message is raised -- confirm against its definition).
    """
    invalid_calls = [
        ({"nrows": -1, "nsplits": 1},
         "Argument `nrows` in kfold_random() cannot be negative"),
        ({"nrows": 3, "nsplits": -3},
         "Argument `nsplits` in kfold_random() cannot be negative"),
        ({"nrows": 10, "nsplits": 3, "seed": -1},
         "Argument `seed` in kfold_random() cannot be negative"),
        ({"nrows": 5, "nsplits": 0},
         "The number of splits cannot be less than two"),
        ({"nrows": 1, "nsplits": 2},
         "The number of splits cannot exceed the number of rows"),
    ]
    for kwargs, expected_message in invalid_calls:
        # Bind kwargs as a default so the lambda does not depend on the
        # loop variable at call time.
        assert_value_error(lambda kwargs=kwargs: kfold_random(**kwargs),
                           expected_message)
Ejemplo n.º 3
0
def test_kfold_random_any(seed):
    """Validate the structure of kfold_random output for random sizes.

    The number of folds and rows are drawn from exponential distributions
    after seeding ``random`` with ``seed``. For every (train, test) pair the
    test checks the frame types, the single int32 column, the sizes, that
    the indices are sorted and unique, and that train and test together
    cover ``range(nrows)``. Finally, the test parts of all folds combined
    must contain every row index exactly once.
    """
    random.seed(seed)
    nfolds = 2 + int(random.expovariate(0.01))
    nrows = nfolds + int(random.expovariate(0.0001))
    folds = kfold_random(nrows=nrows, nsplits=nfolds, seed=seed)
    assert isinstance(folds, list) and len(folds) == nfolds

    every_index = set(range(nrows))
    min_test_size = nrows // nfolds
    held_out = []
    for fold in folds:
        assert isinstance(fold, tuple) and len(fold) == 2
        train_frame, test_frame = fold
        assert isinstance(train_frame, dt.Frame)
        assert isinstance(test_frame, dt.Frame)
        assert train_frame.stypes == (dt.int32,)
        assert test_frame.stypes == (dt.int32,)
        assert train_frame.ncols == 1
        assert test_frame.ncols == 1
        assert train_frame.nrows + test_frame.nrows == nrows
        # Test folds may differ in size by at most one row.
        assert test_frame.nrows in (min_test_size, min_test_size + 1)
        train_indices = train_frame.to_list()[0]
        test_indices = test_frame.to_list()[0]
        assert train_indices == sorted(train_indices)
        assert test_indices == sorted(test_indices)
        unique_train = set(train_indices)
        unique_test = set(test_indices)
        # No index may be repeated within either part.
        assert len(unique_train) == train_frame.nrows
        assert len(unique_test) == test_frame.nrows
        assert unique_train.union(unique_test) == every_index
        held_out.extend(test_indices)

    # Across all folds, each row must be held out exactly once.
    assert sorted(held_out) == list(range(nrows))
Ejemplo n.º 4
0
def test_kfold_random_2_2():
    """Check the smallest case: two rows divided into two splits.

    Each split must be a pair of 1x1 frames, and the two splits must be
    mirror images of each other: the row used for training in one split is
    the test row in the other.
    """
    folds = kfold_random(nrows=2, nsplits=2)
    assert isinstance(folds, list) and len(folds) == 2
    for fold in folds:
        assert isinstance(fold, tuple) and len(fold) == 2
    for fold in folds:
        assert fold[0].shape == (1, 1) and fold[1].shape == (1, 1)

    (train0, test0), (train1, test1) = folds
    first = train0[0, 0]
    assert first in (0, 1)
    other = 1 - first
    assert test0[0, 0] == other
    assert train1[0, 0] == other
    assert test1[0, 0] == first
Ejemplo n.º 5
0
def test_kfold_random_bad_args():
    """Check that invalid argument values raise ValueError.

    Each case pairs the keyword arguments for ``kfold_random`` with the
    regular expression that the error message must match.
    """
    invalid_calls = [
        ({"nrows": -1, "nsplits": 1},
         r"Argument nrows in function datatable.kfold_random\(\) cannot be negative"),
        ({"nrows": 3, "nsplits": -3},
         r"Argument nsplits in function datatable.kfold_random\(\) cannot be negative"),
        ({"nrows": 10, "nsplits": 3, "seed": -1},
         r"Argument seed in function datatable.kfold_random\(\) cannot be negative"),
        ({"nrows": 5, "nsplits": 0},
         "The number of splits cannot be less than two"),
        ({"nrows": 1, "nsplits": 2},
         "The number of splits cannot exceed the number of rows"),
    ]
    for kwargs, pattern in invalid_calls:
        with pytest.raises(ValueError, match=pattern):
            kfold_random(**kwargs)
Ejemplo n.º 6
0
def test_kfold_random_api():
    """Check the error messages for calls that misuse the kfold_random API.

    Covers missing required parameters, positional arguments, and arguments
    of the wrong type. Each case is handed to ``assert_type_error``
    (presumably the helper verifies that a TypeError with that message is
    raised -- confirm against its definition).
    """
    misuse_cases = [
        (lambda: kfold_random(),
         "Required parameter `nrows` is missing"),
        (lambda: kfold_random(nrows=5, seed=12345678),
         "Required parameter `nsplits` is missing"),
        (lambda: kfold_random(5, 2),
         "kfold_random() takes no positional arguments, but 2 were given"),
        (lambda: kfold_random(nrows=5, nsplits=3.3),
         "Argument `nsplits` in kfold_random() should be an integer"),
        (lambda: kfold_random(nrows=None, nsplits=7),
         "Argument `nrows` in kfold_random() should be an integer"),
        (lambda: kfold_random(nrows=5, nsplits=2, seed="boo"),
         "Argument `seed` in kfold_random() should be an integer"),
    ]
    for bad_call, expected_message in misuse_cases:
        assert_type_error(bad_call, expected_message)
Ejemplo n.º 7
0
def test_kfold_random_api():
    """Check that misuse of the kfold_random API raises TypeError.

    Covers missing required parameters, positional arguments, and arguments
    of the wrong type; each error message must match the given pattern.
    """
    def rejects(pattern, bad_call):
        # Run the call and require a TypeError whose message matches.
        with pytest.raises(TypeError, match=pattern):
            bad_call()

    rejects("Required parameter nrows is missing",
            lambda: kfold_random())
    rejects("Required parameter nsplits is missing",
            lambda: kfold_random(nrows=5, seed=12345678))
    rejects(r"kfold_random\(\) takes no positional arguments, but 2 were given",
            lambda: kfold_random(5, 2))
    rejects(r"Argument nsplits in function datatable.kfold_random\(\) should be an integer",
            lambda: kfold_random(nrows=5, nsplits=3.3))
    rejects(r"Argument nrows in function datatable.kfold_random\(\) should be an integer",
            lambda: kfold_random(nrows=None, nsplits=7))
    rejects(r"Argument seed in function datatable.kfold_random\(\) should be an integer",
            lambda: kfold_random(nrows=5, nsplits=2, seed="boo"))