def test_class_gradient(self):
        """Check ``class_gradient`` called without a label on MNIST test data.

        Builds a ``PyTorchClassifier`` from the fixture's model, loss function
        and optimizer, computes class gradients for the first ``NB_TEST`` test
        images, and verifies that the result has shape
        ``(NB_TEST, 10, 1, 28, 28)`` and that its sum is non-zero.
        """
        # Get MNIST; only the test images are needed here.
        (_, _), (x_test, _), _, _ = load_mnist()
        x_test = x_test[:NB_TEST]
        # Swap axes 1 and 3 so each sample matches the (1, 28, 28) input shape
        # passed to the classifier below.
        # NOTE(review): assumes load_mnist returns (N, 28, 28, 1) -- confirm.
        x_test = np.swapaxes(x_test, 1, 3)

        # Test gradient
        ptc = PyTorchClassifier(None, self._model, self._loss_fn,
                                self._optimizer, (1, 28, 28), (10, ))
        grads = ptc.class_gradient(x_test)

        # assertEqual/assertNotEqual report the offending values on failure,
        # unlike assertTrue on a pre-computed boolean.
        self.assertEqual(grads.shape, (NB_TEST, 10, 1, 28, 28))
        self.assertNotEqual(np.sum(grads), 0)
# Example no. 2
0
    def test_class_gradient_target(self):
        """Check ``class_gradient`` for one requested label on MNIST test data.

        Builds a fresh model via ``_model_setup_module``, wraps it in a
        ``PyTorchClassifier``, computes class gradients with ``label=3`` for
        the first ``NB_TEST`` test images, and verifies that the result has
        shape ``(NB_TEST, 1, 1, 28, 28)`` and that its sum is non-zero.
        """
        # Get MNIST; only the test images are needed here.
        (_, _), (x_test, _), _, _ = load_mnist()
        x_test = x_test[:NB_TEST]
        # Swap axes 1 and 3 so each sample matches the (1, 28, 28) input shape
        # passed to the classifier below.
        # NOTE(review): assumes load_mnist returns (N, 28, 28, 1) -- confirm.
        x_test = np.swapaxes(x_test, 1, 3)

        # Create model
        model, loss_fn, optimizer = self._model_setup_module()

        # Test gradient
        ptc = PyTorchClassifier((0, 1), model, loss_fn, optimizer, (1, 28, 28),
                                10)
        grads = ptc.class_gradient(x_test, label=3)

        # assertEqual/assertNotEqual report the offending values on failure,
        # unlike assertTrue on a pre-computed boolean.
        self.assertEqual(grads.shape, (NB_TEST, 1, 1, 28, 28))
        self.assertNotEqual(np.sum(grads), 0)