コード例 #1
0
def test_train_valid_split_return_dataset(dataset):
    splitter = RandomSplitter()
    train, valid = splitter.train_valid_split(dataset, return_index=False)
    assert type(train) == NumpyTupleDataset
    assert type(valid) == NumpyTupleDataset
    assert len(train) == 9
    assert len(valid) == 1
コード例 #2
0
def test_train_valid_test_split(dataset):
    splitter = RandomSplitter()
    train_ind, valid_ind, test_ind = splitter.train_valid_test_split(dataset)
    assert type(train_ind) == numpy.ndarray
    assert train_ind.shape[0] == 8
    assert valid_ind.shape[0] == 1
    assert test_ind.shape[0] == 1
コード例 #3
0
def test_split_fail(dataset):
    splitter = RandomSplitter()
    with pytest.raises(AssertionError):
        train_ind, valid_ind, test_ind = splitter._split(dataset,
                                                         frac_train=0.4,
                                                         frac_valid=0.3,
                                                         frac_test=0.2)
コード例 #4
0
def test_split_fix_seed(dataset):
    splitter = RandomSplitter()
    train_ind1, valid_ind1, test_ind1 = splitter._split(dataset, seed=44)
    train_ind2, valid_ind2, test_ind2 = splitter._split(dataset, seed=44)

    assert numpy.array_equal(train_ind1, train_ind2)
    assert numpy.array_equal(valid_ind1, valid_ind2)
    assert numpy.array_equal(test_ind1, test_ind2)
コード例 #5
0
def test_train_valid_test_split_ndarray_return_dataset(ndarray_dataset):
    splitter = RandomSplitter()
    train, valid, test = splitter.train_valid_test_split(ndarray_dataset,
                                                         return_index=False)
    assert type(train) == numpy.ndarray
    assert type(valid) == numpy.ndarray
    assert type(test) == numpy.ndarray
    assert len(train) == 8
    assert len(valid) == 1
    assert len(test) == 1
コード例 #6
0
def test_split(dataset):
    splitter = RandomSplitter()
    train_ind, valid_ind, test_ind = splitter._split(dataset)
    assert type(train_ind) == numpy.ndarray
    assert train_ind.shape[0] == 8
    assert valid_ind.shape[0] == 1
    assert test_ind.shape[0] == 1

    train_ind, valid_ind, test_ind = splitter._split(dataset,
                                                     frac_train=0.5,
                                                     frac_valid=0.3,
                                                     frac_test=0.2)
    assert type(train_ind) == numpy.ndarray
    assert train_ind.shape[0] == 5
    assert valid_ind.shape[0] == 3
    assert test_ind.shape[0] == 2