def test_classifier_match(self):
    """Check that a trainer built from a single FGM attack exposes that one attack, tied to the trainer's classifier."""
    fgsm = FastGradientMethod(self.classifier)
    trainer = AdversarialTrainer(self.classifier, fgsm)

    # Exactly one attack was handed to the trainer, so exactly one must be registered.
    self.assertEqual(len(trainer.attacks), 1)

    # The registered attack's estimator must compare equal to the classifier the trainer reports.
    first_attack_estimator = trainer.attacks[0].estimator
    self.assertEqual(first_attack_estimator, trainer.get_classifier())
class AdversarialTrainerMadryPGD(Trainer):
    """
    Class performing adversarial training following Madry's Protocol.

    | Paper link: https://arxiv.org/abs/1706.06083

    | Please keep in mind the limitations of defences. While adversarial training is widely regarded as a promising,
        principled approach to making classifiers more robust (see https://arxiv.org/abs/1802.00420), very careful
        evaluations are required to assess its effectiveness case by case (see https://arxiv.org/abs/1902.06705).
    """

    # NOTE(review): a class with this same name is defined again later in this file, and that later
    # definition rebinds the name at import time, so this version is not reachable by name. Confirm which
    # version is intended and remove the other.

    def __init__(self,
                 classifier,
                 eps=0.03,
                 eps_step=0.008,
                 max_iter=7,
                 ratio=1.0):
        """
        Create a PGD-based adversarial trainer.

        The default eps/eps_step look tuned for inputs scaled to [0, 1] (TODO confirm).

        :param classifier: Classifier to train adversarially.
        :param eps: Maximum perturbation, forwarded to ``ProjectedGradientDescent``.
        :param eps_step: Attack step size per iteration, forwarded to ``ProjectedGradientDescent``.
        :param max_iter: Maximum number of PGD iterations, forwarded to ``ProjectedGradientDescent``.
        :param ratio: Forwarded unchanged to ``AdversarialTrainer`` as its ``ratio`` argument.
        """
        # Initialise the Trainer base class, as the other definition of this class in this file does,
        # so any state the base class sets up for ``classifier`` exists on this instance.
        super(AdversarialTrainerMadryPGD, self).__init__(classifier=classifier)

        # PGD adversary that generates the adversarial samples used during training.
        self.attack = ProjectedGradientDescent(
            classifier,
            eps=eps,
            eps_step=eps_step,
            max_iter=max_iter,
        )

        # All training work is delegated to a generic AdversarialTrainer driven by the PGD attack.
        self.trainer = AdversarialTrainer(classifier, self.attack, ratio=ratio)

    def fit(self, x, y, **kwargs):
        """
        Train the classifier adversarially on in-memory data.

        :param x: Training data.
        :param y: Labels for the training data.
        :param kwargs: Extra arguments forwarded unchanged to ``AdversarialTrainer.fit``.
        """
        self.trainer.fit(x, y, **kwargs)

    def fit_generator(self, generator, nb_epochs, **kwargs):
        """
        Train the classifier adversarially on data produced by a generator.

        :param generator: Data generator, forwarded to ``AdversarialTrainer.fit_generator``.
        :param nb_epochs: Number of training epochs.
        :param kwargs: Extra arguments forwarded unchanged to ``AdversarialTrainer.fit_generator``.
        """
        self.trainer.fit_generator(generator, nb_epochs=nb_epochs, **kwargs)

    def get_classifier(self):
        """
        Return the classifier being trained.

        :return: The classifier reported by the wrapped ``AdversarialTrainer``.
        """
        return self.trainer.get_classifier()
class AdversarialTrainerMadryPGD(Trainer):
    """
    Class performing adversarial training following Madry's Protocol.

    | Paper link: https://arxiv.org/abs/1706.06083

    | Please keep in mind the limitations of defences. While adversarial training is widely regarded as a promising,
        principled approach to making classifiers more robust (see https://arxiv.org/abs/1802.00420), very careful
        evaluations are required to assess its effectiveness case by case (see https://arxiv.org/abs/1902.06705).
    """

    def __init__(
        self,
        classifier: "ClassifierGradients",
        nb_epochs: int = 391,
        batch_size: int = 128,
        eps: float = 8.0,
        eps_step: float = 2.0,
        max_iter: int = 7,
        num_random_init: Union[bool, int] = True,
    ) -> None:
        """
        Create an :class:`.AdversarialTrainerMadryPGD` instance.

        Default values are for CIFAR-10 in pixel range 0-255.

        :param classifier: Classifier to train adversarially.
        :param nb_epochs: Number of training epochs (default for ``fit``; can be overridden per call).
        :param batch_size: Size of the batch on which adversarial samples are generated (default for ``fit``;
            can be overridden per call).
        :param eps: Maximum perturbation that the attacker can introduce.
        :param eps_step: Attack step size (input variation) at each iteration.
        :param max_iter: The maximum number of iterations.
        :param num_random_init: Number of random initialisations within the epsilon ball. For num_random_init=0
            starting at the original input.
        """
        super(AdversarialTrainerMadryPGD, self).__init__(classifier=classifier)  # type: ignore
        self.batch_size = batch_size
        self.nb_epochs = nb_epochs

        # Setting up adversary and perform adversarial training:
        self.attack = ProjectedGradientDescent(
            classifier, eps=eps, eps_step=eps_step, max_iter=max_iter, num_random_init=num_random_init,
        )

        # ratio=1.0 is fixed here; presumably this means every training sample is replaced by its adversarial
        # counterpart, as in Madry's protocol -- TODO confirm against AdversarialTrainer's definition of ratio.
        self.trainer = AdversarialTrainer(classifier, self.attack, ratio=1.0)  # type: ignore

    def fit(self, x: np.ndarray, y: np.ndarray, validation_data: Optional[np.ndarray] = None, **kwargs) -> None:
        """
        Train a model adversarially. See class documentation for more information on the exact procedure.

        :param x: Training data.
        :param y: Labels for the training data.
        :param validation_data: Validation data, passed unchanged to ``AdversarialTrainer.fit``.
        :param kwargs: Dictionary of framework-specific arguments. ``nb_epochs`` and ``batch_size`` may be given
            here to override, for this call only, the values set in the constructor.
        """
        # Fill in the constructor defaults only when the caller did not supply them. Passing them explicitly
        # alongside **kwargs would raise "TypeError: got multiple values for keyword argument" whenever a caller
        # included either key in kwargs. ``kwargs`` is a fresh dict per call, so this does not mutate caller state.
        kwargs.setdefault("nb_epochs", self.nb_epochs)
        kwargs.setdefault("batch_size", self.batch_size)
        self.trainer.fit(x, y, validation_data=validation_data, **kwargs)

    def get_classifier(self) -> "Classifier":
        """
        Return the classifier being trained.

        :return: The classifier reported by the wrapped ``AdversarialTrainer``.
        """
        return self.trainer.get_classifier()