Ejemplo n.º 1
0
def test_train_valid_test_split(dataset):
    """Check the index-returning form of ``train_valid_test_split``.

    Calls ``RandomSplitter().train_valid_test_split`` on the ``dataset``
    fixture with no extra arguments and verifies that the first returned
    object is a NumPy array and that the three returned objects hold
    8, 1 and 1 entries along their first axis.

    NOTE(review): the 8/1/1 expectation presumably corresponds to the
    splitter's default fractions applied to a 10-sample fixture -- confirm
    against the fixture and the splitter defaults.
    """
    splitter = RandomSplitter()
    train_ind, valid_ind, test_ind = splitter.train_valid_test_split(dataset)

    # isinstance is the idiomatic type check (pycodestyle E721) and still
    # passes for every value the former ``type(...) ==`` comparison accepted.
    assert isinstance(train_ind, numpy.ndarray)

    assert train_ind.shape[0] == 8
    assert valid_ind.shape[0] == 1
    assert test_ind.shape[0] == 1
Ejemplo n.º 2
0
def test_train_valid_test_split_ndarray_return_dataset(ndarray_dataset):
    """Check ``train_valid_test_split`` with ``return_index=False`` on arrays.

    Calls ``RandomSplitter().train_valid_test_split`` on the
    ``ndarray_dataset`` fixture with ``return_index=False`` and verifies
    that each of the three returned objects is a NumPy array and that
    their lengths are 8, 1 and 1 respectively.

    NOTE(review): the 8/1/1 expectation presumably corresponds to the
    splitter's default fractions applied to a 10-sample fixture -- confirm
    against the fixture and the splitter defaults.
    """
    splitter = RandomSplitter()
    train, valid, test = splitter.train_valid_test_split(
        ndarray_dataset, return_index=False)

    # isinstance is the idiomatic type check (pycodestyle E721) and still
    # passes for every value the former ``type(...) ==`` comparison accepted.
    assert isinstance(train, numpy.ndarray)
    assert isinstance(valid, numpy.ndarray)
    assert isinstance(test, numpy.ndarray)

    assert len(train) == 8
    assert len(valid) == 1
    assert len(test) == 1
Ejemplo n.º 3
0
def test_train_valid_test_split_return_dataset(dataset):
    """Check ``train_valid_test_split`` with ``return_index=False``.

    Calls ``RandomSplitter().train_valid_test_split`` on the ``dataset``
    fixture with ``return_index=False`` and verifies that each of the
    three returned objects is a ``NumpyTupleDataset`` and that their
    lengths are 8, 1 and 1 respectively.

    NOTE(review): the 8/1/1 expectation presumably corresponds to the
    splitter's default fractions applied to a 10-sample fixture -- confirm
    against the fixture and the splitter defaults.
    """
    splitter = RandomSplitter()
    train, valid, test = splitter.train_valid_test_split(
        dataset, return_index=False)

    # isinstance is the idiomatic type check (pycodestyle E721) and still
    # passes for every value the former ``type(...) ==`` comparison accepted.
    assert isinstance(train, NumpyTupleDataset)
    assert isinstance(valid, NumpyTupleDataset)
    assert isinstance(test, NumpyTupleDataset)

    assert len(train) == 8
    assert len(valid) == 1
    assert len(test) == 1