예제 #1
0
 def test_raises_if_both_z_and_n_are_not_none(self):
     """Passing both `z` and `n` to `_get_samples` must raise ValueError."""
     dist = tfd.Normal(loc=0., scale=1.)
     z = dist.sample(seed=42)
     n = 1
     seed = None
     # `assertRaisesRegexp` is a deprecated alias that was removed from
     # `unittest.TestCase` in Python 3.12; `assertRaisesRegex` is the
     # supported spelling and takes the same arguments.
     with self.assertRaisesRegex(ValueError, 'exactly one'):
         _get_samples(dist, z, n, seed)
예제 #2
0
 def test_raises_if_both_z_and_n_are_not_none(self):
   """Passing both `z` and `n` to `_get_samples` must raise ValueError."""
   dist = tfd.Normal(loc=0., scale=1.)
   z = dist.sample(seed=42)
   n = 1
   seed = None
   # `assertRaisesRegexp` is a deprecated alias that was removed from
   # `unittest.TestCase` in Python 3.12; `assertRaisesRegex` is the
   # supported spelling and takes the same arguments.
   with self.assertRaisesRegex(ValueError, 'exactly one'):
     _get_samples(dist, z, n, seed)
예제 #3
0
 def test_raises_if_both_z_and_n_are_none(self):
   """Leaving both `z` and `n` as None must raise ValueError."""
   # NOTE(review): `test_session` is reportedly deprecated in later
   # TensorFlow releases in favor of `cached_session`/`session` -- confirm
   # against the TF version this file targets before switching.
   with self.test_session():
     dist = normal_lib.Normal(loc=0., scale=1.)
     z = None
     n = None
     seed = None
     # `assertRaisesRegexp` is a deprecated alias that was removed from
     # `unittest.TestCase` in Python 3.12; `assertRaisesRegex` is the
     # supported spelling and takes the same arguments.
     with self.assertRaisesRegex(ValueError, 'exactly one'):
       _get_samples(dist, z, n, seed)
예제 #4
0
 def test_returns_z_if_z_provided(self):
     """With `z` supplied and `n` left as None, the result has shape (10,)."""
     standard_normal = tfd.Normal(loc=0., scale=1.)
     provided = standard_normal.sample(10, seed=42)
     # Only `z` is given; `n` and `seed` are both passed as None.
     result = _get_samples(standard_normal, provided, None, None)
     self.assertEqual((10, ), result.get_shape())
예제 #5
0
 def test_returns_n_samples_if_n_provided(self):
     """With `n=10` and no `z`, the result has shape (10,)."""
     standard_normal = tfd.Normal(loc=0., scale=1.)
     num_draws = 10
     # Only `n` is given; `z` and `seed` are both passed as None.
     result = _get_samples(standard_normal, None, num_draws, None)
     self.assertEqual((10, ), result.get_shape())
예제 #6
0
 def test_returns_z_if_z_provided(self):
   """With `z` supplied and `n` left as None, the result has shape (10,)."""
   standard_normal = tfd.Normal(loc=0., scale=1.)
   provided = standard_normal.sample(10, seed=42)
   # Only `z` is given; `n` and `seed` are both passed as None.
   result = _get_samples(standard_normal, provided, None, None)
   self.assertEqual((10,), result.shape)
예제 #7
0
 def test_returns_n_samples_if_n_provided(self):
   """With `n=10` and no `z`, the result has shape (10,)."""
   standard_normal = tfd.Normal(loc=0., scale=1.)
   num_draws = 10
   # Only `n` is given; `z` and `seed` are both passed as None.
   result = _get_samples(standard_normal, None, num_draws, None)
   self.assertEqual((10,), result.shape)