def test_adversarial_training_range():
    """
    Verify that ``fit`` raises a ValueError from the input range check
    implemented in adversarial_training.py.
    """
    ensemble = torchensemble.AdversarialTrainingClassifier(
        estimator=MLP_clf,
        n_estimators=2,
        cuda=False,
    )
    ensemble.set_optimizer("Adam")

    # Wrap the module-level training tensors in a loader (mini-batches of 2).
    loader = DataLoader(TensorDataset(X_train, y_train_clf), batch_size=2)

    # Fitting on these samples is expected to fail the range check.
    expected_fragment = "input range of samples passed to adversarial"
    with pytest.raises(ValueError) as err_info:
        ensemble.fit(loader)
    assert expected_fragment in str(err_info.value)


def test_adversarial_training():
    """
    This unit test checks that ``fit`` of AdversarialTrainingClassifier
    raises a ValueError for an invalid number of epochs, an invalid epsilon,
    and an invalid log interval.
    """
    model = torchensemble.AdversarialTrainingClassifier(estimator=MLP_clf,
                                                        n_estimators=2,
                                                        cuda=False)

    # Build the loader locally: the sibling test's ``train_loader`` is a local
    # variable of that function and is not visible here.
    train = TensorDataset(X_train, y_train_clf)
    train_loader = DataLoader(train, batch_size=2)

    # Epochs
    with pytest.raises(ValueError) as excinfo:
        model.fit(train_loader, epochs=-1)
    assert "number of training epochs" in str(excinfo.value)

    # Epsilon
    with pytest.raises(ValueError) as excinfo:
        model.fit(train_loader, epsilon=2)
    assert "step used to generate adversarial samples" in str(excinfo.value)

    # Log interval
    with pytest.raises(ValueError) as excinfo:
        model.fit(train_loader, log_interval=-1)
    assert "number of batches to wait" in str(excinfo.value)