def test_raises_if_both_z_and_n_are_not_none(self):
     """Check that `_get_samples` rejects a call given both `z` and `n`.

     A sample drawn from a standard Normal is passed as `z` together with
     `n = 1`; the call must raise a ValueError whose message matches
     'exactly one'.
     """
     with self.cached_session():
         dist = normal_lib.Normal(loc=0., scale=1.)
         z = dist.sample(seed=42)
         n = 1
         seed = None
         # `assertRaisesRegex` is the supported spelling; the old
         # `assertRaisesRegexp` alias is deprecated and was removed from
         # unittest in Python 3.12.
         with self.assertRaisesRegex(ValueError, 'exactly one'):
             _get_samples(dist, z, n, seed)
예제 #2
0
 def test_raises_if_both_z_and_n_are_not_none(self):
   """Check that `_get_samples` rejects a call given both `z` and `n`.

   A sample drawn from a standard Normal is passed as `z` together with
   `n = 1`; the call must raise a ValueError whose message matches
   'exactly one'.
   """
   with self.test_session():
     dist = distributions.Normal(loc=0., scale=1.)
     z = dist.sample(seed=42)
     n = 1
     seed = None
     # `assertRaisesRegex` is the supported spelling; the old
     # `assertRaisesRegexp` alias is deprecated and was removed from
     # unittest in Python 3.12.
     with self.assertRaisesRegex(ValueError, 'exactly one'):
       _get_samples(dist, z, n, seed)
예제 #3
0
 def test_raises_if_both_z_and_n_are_none(self):
   """Check that `_get_samples` rejects a call given neither `z` nor `n`.

   Both `z` and `n` are passed as None; the call must raise a ValueError
   whose message matches 'exactly one'.
   """
   with self.test_session():
     dist = distributions.Normal(loc=0., scale=1.)
     z = None
     n = None
     seed = None
     # `assertRaisesRegex` is the supported spelling; the old
     # `assertRaisesRegexp` alias is deprecated and was removed from
     # unittest in Python 3.12.
     with self.assertRaisesRegex(ValueError, 'exactly one'):
       _get_samples(dist, z, n, seed)
예제 #4
0
 def test_raises_if_both_z_and_n_are_none(self):
   """Check that `_get_samples` rejects a call given neither `z` nor `n`.

   Both `z` and `n` are passed as None; the call must raise a ValueError
   whose message matches 'exactly one'.
   """
   with self.test_session():
     dist = normal_lib.Normal(loc=0., scale=1.)
     z = None
     n = None
     seed = None
     # `assertRaisesRegex` is the supported spelling; the old
     # `assertRaisesRegexp` alias is deprecated and was removed from
     # unittest in Python 3.12.
     with self.assertRaisesRegex(ValueError, 'exactly one'):
       _get_samples(dist, z, n, seed)
예제 #5
0
 def test_returns_z_if_z_provided(self):
   """Check the result shape when pre-drawn samples are passed as `z`.

   Ten samples are drawn from a standard Normal and handed to
   `_get_samples` with `n` and `seed` both None; the returned value must
   report a static shape of (10,).
   """
   with self.test_session():
     normal = distributions.Normal(loc=0., scale=1.)
     drawn = normal.sample(10, seed=42)
     # Third and fourth positional arguments (n, seed) are left as None.
     result = _get_samples(normal, drawn, None, None)
     self.assertEqual((10,), result.get_shape())
 def test_returns_z_if_z_provided(self):
     """Check the result shape when pre-drawn samples are passed as `z`.

     Ten samples are drawn from a standard Normal and handed to
     `_get_samples` with `n` and `seed` both None; the returned value must
     report a static shape of (10,).
     """
     with self.cached_session():
         normal = normal_lib.Normal(loc=0., scale=1.)
         drawn = normal.sample(10, seed=42)
         # Third and fourth positional arguments (n, seed) are left as None.
         result = _get_samples(normal, drawn, None, None)
         self.assertEqual((10, ), result.get_shape())
예제 #7
0
 def test_returns_n_samples_if_n_provided(self):
   """Check the result shape when only a sample count `n` is given.

   `_get_samples` is called on a standard Normal with `z` None, `n` equal
   to 10 and `seed` None; the returned value must report a static shape
   of (10,).
   """
   with self.test_session():
     normal = distributions.Normal(loc=0., scale=1.)
     num_samples = 10
     # Second and fourth positional arguments (z, seed) are left as None.
     result = _get_samples(normal, None, num_samples, None)
     self.assertEqual((10,), result.get_shape())
예제 #8
0
 def test_returns_n_samples_if_n_provided(self):
   """Check the result shape when only a sample count `n` is given.

   `_get_samples` is called on a standard Normal with `z` None, `n` equal
   to 10 and `seed` None; the returned value must report a static shape
   of (10,).
   """
   with self.test_session():
     normal = normal_lib.Normal(loc=0., scale=1.)
     num_samples = 10
     # Second and fourth positional arguments (z, seed) are left as None.
     result = _get_samples(normal, None, num_samples, None)
     self.assertEqual((10,), result.get_shape())