def test_raises_if_both_z_and_n_are_not_none(self):
    """Passing both explicit samples `z` and a sample count `n` is rejected.

    `_get_samples` is expected to raise a ValueError whose message contains
    'exactly one' when both `z` and `n` are supplied.
    """
    dist = tfd.Normal(loc=0., scale=1.)
    z = dist.sample(seed=42)
    n = 1
    seed = None
    # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12;
    # `assertRaisesRegex` is the supported spelling with identical semantics.
    with self.assertRaisesRegex(ValueError, 'exactly one'):
        _get_samples(dist, z, n, seed)
def test_returns_z_if_z_provided(self):
    """When only `z` is given, `_get_samples` yields samples shaped like `z`."""
    normal = tfd.Normal(loc=0., scale=1.)
    provided = normal.sample(10, seed=42)
    # Neither a sample count nor a seed is given; only `z` is supplied.
    samples = _get_samples(normal, provided, None, None)
    self.assertEqual((10, ), samples.shape)
def test_returns_n_samples_if_n_provided(self):
    """When only `n` is given, `_get_samples` draws `n` samples from `dist`."""
    normal = tfd.Normal(loc=0., scale=1.)
    count = 10
    # No explicit samples and no seed; `_get_samples` must draw them itself.
    drawn = _get_samples(normal, None, count, None)
    self.assertEqual((10, ), drawn.shape)