def test_raises_if_both_z_and_n_are_not_none(self):
    """_get_samples must raise ValueError when both z and n are supplied."""
    dist = tfd.Normal(loc=0., scale=1.)
    z = dist.sample(seed=42)
    n = 1
    seed = None
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
    # which was removed from unittest in Python 3.12.
    with self.assertRaisesRegex(ValueError, 'exactly one'):
        _get_samples(dist, z, n, seed)
def test_raises_if_both_z_and_n_are_none(self):
    """_get_samples must raise ValueError when neither z nor n is supplied."""
    # Built with tfd.Normal, like the sibling tests. No session context is
    # needed because the expected ValueError is raised by the Python call
    # itself, not by evaluating a tensor.
    dist = tfd.Normal(loc=0., scale=1.)
    z = None
    n = None
    seed = None
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
    # which was removed from unittest in Python 3.12.
    with self.assertRaisesRegex(ValueError, 'exactly one'):
        _get_samples(dist, z, n, seed)
def test_returns_z_if_z_provided_static_shape(self):
    """Supplying z yields samples whose static shape (get_shape()) is (10,).

    This method has its own name because the class also defines a method
    called test_returns_z_if_z_provided (checking `.shape`). Reusing that
    name would silently rebind it, and this test would never run.
    """
    dist = tfd.Normal(loc=0., scale=1.)
    z = dist.sample(10, seed=42)
    n = None
    seed = None
    z = _get_samples(dist, z, n, seed)
    self.assertEqual((10, ), z.get_shape())
def test_returns_n_samples_if_n_provided_static_shape(self):
    """Supplying n=10 yields samples whose static shape (get_shape()) is (10,).

    This method has its own name because the class also defines a method
    called test_returns_n_samples_if_n_provided (checking `.shape`). Reusing
    that name would silently rebind it, and this test would never run.
    """
    dist = tfd.Normal(loc=0., scale=1.)
    z = None
    n = 10
    seed = None
    z = _get_samples(dist, z, n, seed)
    self.assertEqual((10, ), z.get_shape())
def test_returns_z_if_z_provided(self):
    """Passing pre-drawn samples z (and no n) returns a tensor of shape (10,)."""
    normal = tfd.Normal(loc=0., scale=1.)
    provided = normal.sample(10, seed=42)
    result = _get_samples(normal, provided, None, None)
    self.assertEqual((10,), result.shape)
def test_returns_n_samples_if_n_provided(self):
    """Passing n=10 (and no z) draws a tensor of shape (10,) from the dist."""
    normal = tfd.Normal(loc=0., scale=1.)
    drawn = _get_samples(normal, None, 10, None)
    self.assertEqual((10,), drawn.shape)