def test_sample_noise_with_gradients_raise(self):
    """An unrecognized noise name must make the sampler raise ValueError.

    The call is deliberately not unpacked: tuple unpacking of the wrong
    arity also raises ValueError. Doing it inside ``assertRaises`` would let
    the test pass even if the sampler accepted the bad name.
    """
    with self.assertRaises(ValueError):
        perturbations.sample_noise_with_gradients('unknown', (3, 2, 4))
def test_sample_noise_with_gradients(self, noise):
    """Samples and gradients both come back with the requested shape.

    Args:
      noise: name of the noise distribution to sample from; presumably
        supplied by a parameterization decorator not visible here.
    """
    expected_shape = (3, 2, 4)
    result = perturbations.sample_noise_with_gradients(noise, expected_shape)
    noise_samples, noise_gradients = result
    for produced in (noise_samples, noise_gradients):
        self.assertAllEqual(produced.shape, expected_shape)