예제 #1
0
 def testBernoulli(self):
   """Empirical mean of a seeded Bernoulli sample should be near `mean`.

   Draws 13 * 34 * 29 samples with seed 123 and probability 0.23, then checks
   the sample average against 0.23 with an absolute tolerance of 0.1.
   """
   target_probability = 0.23
   sample_shape = [13, 34, 29]
   tolerance = 0.1
   draws = extensions.bernoulli(123, target_probability, sample_shape)
   observed_mean = np.mean(draws)
   self.assertAllClose(target_probability, observed_mean, atol=tolerance)
예제 #2
0
 def testBernoulliWrongShape(self):
   """A 2-element `mean` with requested shape [3] must be rejected.

   Expects `tf.errors.InvalidArgumentError` whose message matches the regex
   "Incompatible shapes".
   """
   mismatched_mean = [0.1, 0.2]
   requested_shape = [3]
   with self.assertRaisesWithPredicateMatch(
       tf.errors.InvalidArgumentError, r"Incompatible shapes"):
     extensions.bernoulli(123, mismatched_mean, requested_shape)
예제 #3
0
 def testBernoulliWrongShape(self):
   """A 2-element `mean` with requested shape [3] must raise a shape error."""
   mismatched_mean, requested_shape = [0.1, 0.2], [3]
   # Context manager provided by the test base class; presumably
   # tf.test.TestCase.assertRaisesIncompatibleShapesError — confirm.
   with self.assertRaisesIncompatibleShapesError():
     extensions.bernoulli(123, mismatched_mean, requested_shape)