def test_build_dataset(digits, with_y, expected):
    """Check the arity of the first item yielded by ``build_dataset``.

    ``digits`` unpacks into four parts; the first and third are passed to
    ``build_dataset`` as features and labels, and the other two are unused.
    When ``with_y`` is truthy the labels are passed along with the
    features, otherwise only the features are passed. The first element
    drawn from the resulting dataset must have length ``expected``.

    NOTE(review): ``with_y`` and ``expected`` are supplied by a fixture or
    parametrization outside this view.
    """
    features, _, labels, _ = digits
    model = TorchShallowNeuralClassifier()
    # Same positional call either way; only the presence of labels differs.
    build_args = (features, labels) if with_y else (features,)
    first_item = next(iter(model.build_dataset(*build_args)))
    assert len(first_item) == expected
def test_build_dataset_input_dim(digits, early_stopping):
    """Check that ``build_dataset`` leaves ``input_dim`` equal to the feature count.

    A classifier is constructed with the given ``early_stopping`` setting and
    ``build_dataset`` is called on the training features and labels.
    Afterwards, ``mod.input_dim`` must equal ``X_train.shape[1]``.

    NOTE(review): ``early_stopping`` is supplied by a fixture or
    parametrization outside this view.
    """
    X_train, _, y_train, _ = digits
    mod = TorchShallowNeuralClassifier(early_stopping=early_stopping)
    # The returned dataset is not needed. The call is made only for its side
    # effect, since the assertion below expects it to set ``mod.input_dim``.
    mod.build_dataset(X_train, y_train)
    assert mod.input_dim == X_train.shape[1]