def test_adversarial_training_range():
    """
    This unit test checks the input range check in adversarial_training.py.

    Fitting an ``AdversarialTrainingClassifier`` on ``X_train`` /
    ``y_train_clf`` is expected to raise ``ValueError`` whose message mentions
    the input range of samples passed to adversarial training.

    NOTE(review): assumes the module-level ``X_train`` lies outside the range
    accepted by adversarial_training.py -- confirm against the fixture.
    """
    ensemble = torchensemble.AdversarialTrainingClassifier(
        estimator=MLP_clf,
        n_estimators=2,
        cuda=False,
    )
    ensemble.set_optimizer("Adam")

    # Wrap the training tensors in a loader (batch_size=2, default ordering).
    loader = DataLoader(TensorDataset(X_train, y_train_clf), batch_size=2)

    # ``match`` performs a regex search on str(exception); the fragment has no
    # regex metacharacters, so this is a plain substring check.
    with pytest.raises(
        ValueError, match="input range of samples passed to adversarial"
    ):
        ensemble.fit(loader)
def test_adversarial_training():
    """
    Check that ``AdversarialTrainingClassifier.fit`` rejects invalid
    hyper-parameters with ``ValueError``.

    Each case passes a single bad keyword argument to ``fit``. The resulting
    error message must contain the paired fragment. Cases run in order and
    the test stops at the first one that does not raise as expected.
    """
    ensemble = torchensemble.AdversarialTrainingClassifier(
        estimator=MLP, n_estimators=2, cuda=False
    )

    # (invalid keyword arguments for ``fit``, expected message fragment)
    invalid_cases = (
        ({"epochs": -1}, "number of training epochs"),
        ({"epsilon": 2}, "step used to generate adversarial samples"),
        ({"log_interval": -1}, "number of batches to wait"),
    )

    # NOTE(review): ``train_loader`` is not defined in this function; it
    # presumably comes from module scope outside this view -- confirm.
    # NOTE(review): no optimizer is set here, so the test presumably relies on
    # ``fit`` validating its arguments before training begins -- confirm.
    for bad_kwargs, fragment in invalid_cases:
        # The fragments contain no regex metacharacters, so ``match`` acts as
        # a substring check on str(exception).
        with pytest.raises(ValueError, match=fragment):
            ensemble.fit(train_loader, **bad_kwargs)