def test_raises_if_both_z_and_n_are_not_none(self):
     """_get_samples must reject a call that supplies both `z` and `n`.

     Expects a ValueError whose message matches the regex 'exactly one'.
     """
     dist = tfd.Normal(loc=0., scale=1.)
     z = dist.sample(seed=42)
     n = 1
     seed = None
     # assertRaisesRegexp is a deprecated alias (removed from unittest in
     # Python 3.12); assertRaisesRegex is the supported spelling.
     with self.assertRaisesRegex(ValueError, 'exactly one'):
         _get_samples(dist, z, n, seed)
 def test_returns_z_if_z_provided(self):
     """With only `z` given (n and seed are None), the result keeps z's shape.

     Ten pre-drawn samples go in, so the result should have shape (10,).
     """
     normal = tfd.Normal(loc=0., scale=1.)
     provided = normal.sample(10, seed=42)
     result = _get_samples(normal, provided, None, None)
     self.assertEqual((10, ), result.shape)
 def test_returns_n_samples_if_n_provided(self):
     """With only `n` given (z and seed are None), `n` samples come back.

     Requests ten draws and expects a result of shape (10,).
     """
     num_draws = 10
     normal = tfd.Normal(loc=0., scale=1.)
     samples = _get_samples(normal, None, num_draws, None)
     self.assertEqual((num_draws, ), samples.shape)