コード例 #1
0
    def test_check_params(self):
        """Constructing GaussianAugmentation with invalid arguments raises ValueError."""
        rejected_configs = (
            # negative ratio while augmenting
            {"augmentation": True, "ratio": -1},
            # clip_values with three entries instead of a pair
            {"clip_values": (0, 1, 2)},
            # clip_values given in reversed order (first bound above second)
            {"clip_values": (1, 0)},
        )
        for config in rejected_configs:
            with self.assertRaises(ValueError):
                GaussianAugmentation(**config)

    def test_failure_augmentation_fit_predict(self):
        """`augmentation=True` with incompatible apply_fit/apply_predict flags raises ValueError."""
        # Augmenting at predict time only (fit disabled) must be rejected.
        with self.assertRaises(ValueError) as context:
            _ = GaussianAugmentation(augmentation=True, apply_fit=False, apply_predict=True)

        # assertIn (instead of assertTrue(... in ...)) shows both strings when it fails.
        self.assertIn(
            "If `augmentation` is `True`, then `apply_fit` must be `True` and `apply_predict`"
            " must be `False`.",
            str(context.exception),
        )

        # Disabling both fit and predict while augmenting must also be rejected.
        with self.assertRaises(ValueError) as context:
            _ = GaussianAugmentation(augmentation=True, apply_fit=False, apply_predict=False)

        self.assertIn(
            "If `augmentation` is `True`, then `apply_fit` and `apply_predict` can't be both `False`.",
            str(context.exception),
        )

 def test_multiple_size(self):
     """A ratio of 3.5 grows 4 samples to int(4.5 * 4) rows and leaves the input untouched."""
     x = np.arange(12).reshape((4, 3))
     x_original = x.copy()
     ga = GaussianAugmentation(ratio=3.5)
     x_new, _ = ga(x)
     # Expected row count encoded by this test: int((1 + ratio) * n) = int(4.5 * 4) = 18.
     self.assertEqual(int(4.5 * x.shape[0]), x_new.shape[0])
     # The call must not modify the input array in place; x is an integer array,
     # so compare exactly (element-wise diagnostics on failure) rather than with a tolerance.
     np.testing.assert_array_equal(x, x_original)

    def test_labels(self):
        """Labels passed alongside x are augmented in step with the samples."""
        x = np.arange(12).reshape((4, 3))
        y = np.arange(8).reshape((4, 2))

        ga = GaussianAugmentation()
        x_new, new_y = ga(x, y)
        # Both outputs must have 8 rows. Presumably the default ratio doubles the
        # 4 input samples (confirm against GaussianAugmentation defaults). Separate
        # assertEqual calls report the offending value, unlike assertTrue on a
        # chained comparison.
        self.assertEqual(x_new.shape[0], 8)
        self.assertEqual(new_y.shape[0], 8)
        # Per-sample shapes of both samples and labels are preserved.
        self.assertEqual(x_new.shape[1:], x.shape[1:])
        self.assertEqual(new_y.shape[1:], y.shape[1:])

 def test_no_augmentation(self):
     """With augmentation disabled, the output keeps x's shape but its values differ."""
     inputs = np.arange(12).reshape((4, 3))
     outputs, _ = GaussianAugmentation(augmentation=False)(inputs)
     # No extra samples are appended...
     self.assertEqual(inputs.shape, outputs.shape)
     # ...but at least one element must have changed.
     self.assertFalse(np.array_equal(inputs, outputs))

 def test_small_size(self):
     """A fractional ratio of 0.4 on 5 samples yields 7 rows of the same width."""
     samples = np.arange(15).reshape((5, 3))
     augmented, _ = GaussianAugmentation(ratio=0.4)(samples)
     # Expected: the 5 original rows plus 2 more (0.4 * 5), each still 3 wide.
     self.assertEqual(augmented.shape, (7, 3))