def test_build_dataset(digits, with_y, expected):
    """Check the arity of the first item yielded by ``build_dataset``.

    ``build_dataset`` is called with features only, or with features and
    labels when ``with_y`` is truthy. The first element produced by iterating
    the result must have exactly ``expected`` components.
    """
    X_train, _, y_train, _ = digits
    model = TorchShallowNeuralClassifier()
    # Pass labels only for the supervised variant of the fixture.
    build_args = (X_train, y_train) if with_y else (X_train,)
    dataset = model.build_dataset(*build_args)
    first_item = next(iter(dataset))
    assert len(first_item) == expected
def test_build_dataset_input_dim(digits, early_stopping):
    """Check that ``build_dataset`` records the feature dimensionality.

    The test builds a dataset from the training split, with or without early
    stopping. It then asserts that the model's ``input_dim`` equals the number
    of columns in ``X_train``.
    """
    X_train, _, y_train, _ = digits
    mod = TorchShallowNeuralClassifier(early_stopping=early_stopping)
    # The return value is not inspected; the call matters only because it is
    # expected to set ``mod.input_dim`` as a side effect.
    mod.build_dataset(X_train, y_train)
    assert mod.input_dim == X_train.shape[1]