예제 #1
0
 def test_sample_noise_with_gradients_raise(self):
     """Check that an unrecognized noise name raises ValueError.

     The result is deliberately not unpacked inside the ``assertRaises``
     block: a failed tuple unpacking raises ValueError as well, which
     would let the test pass even if the call itself accepted the name.
     """
     with self.assertRaises(ValueError):
         perturbations.sample_noise_with_gradients('unknown', (3, 2, 4))
예제 #2
0
 def test_sample_noise_with_gradients(self, noise):
     """Check that samples and gradients both have the requested shape.

     Args:
       noise: Noise specifier forwarded unchanged as the first argument of
         ``perturbations.sample_noise_with_gradients``.
     """
     expected_shape = (3, 2, 4)
     outputs = perturbations.sample_noise_with_gradients(
         noise, expected_shape)
     samples, gradients = outputs
     # Same check for each output, samples first and gradients second.
     for value in (samples, gradients):
         self.assertAllEqual(value.shape, expected_shape)