예제 #1
0
def test_bad_args2():
    """Verify that kfold() refuses invalid `nrows` / `nsplits` combinations.

    Each case is handed to `assert_value_error`, which is defined outside
    this snippet (presumably it checks that the callable raises a
    ValueError carrying the given message -- confirm against the helper).
    """
    invalid_calls = [
        (-1, 1, "Argument `nrows` in kfold() cannot be negative"),
        (3, -3, "Argument `nsplits` in kfold() cannot be negative"),
        (5, 0, "The number of splits cannot be less than two"),
        (1, 2, "The number of splits cannot exceed the number of rows"),
    ]
    for rows, folds, expected_message in invalid_calls:
        # Bind the loop values as defaults so the callable does not depend
        # on when the helper chooses to invoke it.
        assert_value_error(
            lambda rows=rows, folds=folds: kfold(nrows=rows, nsplits=folds),
            expected_message)
예제 #2
0
def test_kfold_k_2(seed):
    """Check the two-way split for a randomly chosen number of rows.

    With two folds, the result is expected to be two pairs built from the
    lower half `[0, mid)` and the upper half `[mid, nrows)`, each pair
    being the other one with its elements swapped.
    """
    random.seed(seed)
    nrows = 2 + int(random.expovariate(0.01))
    mid = nrows // 2
    lower = range(0, mid)
    upper = range(mid, nrows)
    expected = [(upper, lower), (lower, upper)]

    result = kfold(nrows=nrows, nsplits=2)
    assert result == expected
예제 #3
0
def test_kfold_k_3(seed):
    """Check the three-way split for a randomly chosen number of rows.

    The outer folds are expected to be plain range pairs. The middle fold's
    second element is the contiguous range `[cut1, cut2)`, while its first
    element must be a one-column `dt.Frame` listing every remaining row
    index (the rows before `cut1` followed by the rows from `cut2` on).
    """
    random.seed(seed)
    nrows = 3 + int(random.expovariate(0.01))
    cut1 = nrows // 3
    cut2 = 2 * nrows // 3

    folds = kfold(nrows=nrows, nsplits=3)
    assert len(folds) == 3

    # First and last folds: the remainder is itself contiguous.
    assert folds[0] == (range(cut1, nrows), range(0, cut1))
    middle = folds[1]
    assert middle[1] == range(cut1, cut2)
    assert folds[2] == (range(0, cut2), range(cut2, nrows))

    # Middle fold: the remainder has a gap, so it arrives as a Frame.
    remainder = middle[0]
    assert isinstance(remainder, dt.Frame)
    assert remainder.nrows == cut1 + nrows - cut2
    expected_indices = [*range(0, cut1), *range(cut2, nrows)]
    assert remainder.to_list()[0] == expected_indices
예제 #4
0
def test_bad_args2():
    """Verify that kfold() raises ValueError for out-of-range arguments.

    Every entry pairs an (nrows, nsplits) combination with the regular
    expression that the resulting error message must match.
    """
    invalid_calls = [
        (-1, 1,
         r"Argument nrows in function datatable.kfold\(\) cannot be negative"),
        (3, -3,
         r"Argument nsplits in function datatable.kfold\(\) cannot be negative"),
        (5, 0,
         "The number of splits cannot be less than two"),
        (1, 2,
         "The number of splits cannot exceed the number of rows"),
    ]
    for rows, folds, pattern in invalid_calls:
        with pytest.raises(ValueError, match=pattern):
            kfold(nrows=rows, nsplits=folds)
예제 #5
0
def test_kfold_k_any(seed):
    """Check kfold() for a random fold count and a random number of rows.

    The first and last folds are expected to be plain range pairs. Every
    interior fold `j` must have the contiguous range
    `[j*n // k, (j+1)*n // k)` as its second element and, as its first
    element, a one-column `dt.Frame` holding all the other row indices in
    increasing order.
    """
    random.seed(seed)
    # The draw order matters for reproducibility: fold count first, then rows.
    nsplits = 2 + int(random.expovariate(0.01))
    nrows = nsplits + int(random.expovariate(0.0001))

    folds = kfold(nrows=nrows, nsplits=nsplits)
    assert len(folds) == nsplits

    first_stop = nrows // nsplits
    last_start = nrows * (nsplits - 1) // nsplits
    assert folds[0] == (range(first_stop, nrows), range(0, first_stop))
    assert folds[-1] == (range(0, last_start), range(last_start, nrows))

    # Interior folds only; the range is empty when nsplits == 2.
    for idx in range(1, nsplits - 1):
        fold = folds[idx]
        lo = idx * nrows // nsplits
        hi = (idx + 1) * nrows // nsplits
        assert fold[1] == range(lo, hi)

        remainder = fold[0]
        assert isinstance(remainder, dt.Frame)
        assert remainder.nrows + len(fold[1]) == nrows
        assert remainder.ncols == 1
        expected_indices = [*range(0, lo), *range(hi, nrows)]
        assert remainder.to_list()[0] == expected_indices
예제 #6
0
def test_kfold_api():
    """Exercise the calling convention of kfold() with malformed calls.

    Each entry pairs a zero-argument callable making one bad call with the
    message passed to `assert_type_error`. That helper is defined outside
    this snippet (presumably it checks that a TypeError with this message
    is raised -- confirm against the helper).
    """
    # The calls are spelled out literally so that the positional/keyword
    # shape of every invocation is exactly what is being tested.
    bad_calls = [
        (lambda: kfold(),
         "Required parameter `nrows` is missing"),
        (lambda: kfold(nrows=5),
         "Required parameter `nsplits` is missing"),
        (lambda: kfold(nrows=5, nsplits=2, seed=12345),
         "kfold() got an unexpected keyword argument `seed`"),
        (lambda: kfold(5, 2),
         "kfold() takes no positional arguments, but 2 were given"),
        (lambda: kfold(nrows=5, nsplits=3.3),
         "Argument `nsplits` in kfold() should be an integer"),
        (lambda: kfold(nrows=None, nsplits=7),
         "Argument `nrows` in kfold() should be an integer"),
    ]
    for call, expected_message in bad_calls:
        assert_type_error(call, expected_message)
예제 #7
0
def test_kfold_simple():
    """Two rows in two folds: each pair is the other one with rows swapped."""
    row0 = range(0, 1)
    row1 = range(1, 2)
    expected = [(row1, row0), (row0, row1)]
    result = kfold(nrows=2, nsplits=2)
    assert result == expected
예제 #8
0
def test_kfold_api():
    """Exercise the calling convention of kfold() with malformed calls.

    Each entry pairs the regular expression that the TypeError message
    must match with a zero-argument callable making the offending call.
    """
    # The calls are spelled out literally so that the positional/keyword
    # shape of every invocation is exactly what is being tested.
    bad_calls = [
        ("Required parameter nrows is missing",
         lambda: kfold()),
        ("Required parameter nsplits is missing",
         lambda: kfold(nrows=5)),
        (r"kfold\(\) got an unexpected keyword argument seed",
         lambda: kfold(nrows=5, nsplits=2, seed=12345)),
        (r"kfold\(\) takes no positional arguments, but 2 were given",
         lambda: kfold(5, 2)),
        (r"Argument nsplits in function datatable.kfold\(\) should be an integer",
         lambda: kfold(nrows=5, nsplits=3.3)),
        (r"Argument nrows in function datatable.kfold\(\) should be an integer",
         lambda: kfold(nrows=None, nsplits=7)),
    ]
    for pattern, call in bad_calls:
        with pytest.raises(TypeError, match=pattern):
            call()