def testBernoulli(self):
    """Check that the sample mean of `extensions.bernoulli` matches the requested mean.

    Draws a [13, 34, 29] block of samples with mean 0.23 and asserts that
    their average is within an absolute tolerance of 0.1 of that mean.
    """
    expected_mean = 0.23
    sample_shape = [13, 34, 29]
    tolerance = 0.1
    # NOTE(review): the leading 123 is presumably a seed -- confirm against
    # the signature of `extensions.bernoulli`.
    samples = extensions.bernoulli(123, expected_mean, sample_shape)
    empirical_mean = np.mean(samples)
    self.assertAllClose(expected_mean, empirical_mean, atol=tolerance)
def testBernoulliWrongShapePredicateMatch(self):
    """Check that a mean/shape mismatch raises InvalidArgumentError.

    A two-element `mean` is passed together with `shape == [3]`; the call is
    expected to raise `tf.errors.InvalidArgumentError` whose message matches
    the regex "Incompatible shapes".

    This method carries its own name because another method in this file is
    also called `testBernoulliWrongShape`. Two definitions with the same
    name in one class body leave only the later one bound, so the earlier
    test would silently never run.
    """
    mean = [0.1, 0.2]
    shape = [3]
    # NOTE(review): the expected message text is pinned by the regex below;
    # confirm it is the wording produced in every execution mode this suite
    # runs under.
    with self.assertRaisesWithPredicateMatch(
        tf.errors.InvalidArgumentError, r"Incompatible shapes"
    ):
        extensions.bernoulli(123, mean, shape)
def testBernoulliWrongShape(self):
    """Check that a mean/shape mismatch is reported as an incompatible-shapes error.

    A two-element `mean` is passed together with `shape == [3]`; the call
    must raise inside the `assertRaisesIncompatibleShapesError` context.
    """
    per_element_means = [0.1, 0.2]
    requested_shape = [3]
    with self.assertRaisesIncompatibleShapesError():
        extensions.bernoulli(123, per_element_means, requested_shape)