def testLogits(self):
  logits = [-42., 42.]
  dist = bernoulli.Bernoulli(logits=logits)
  with self.test_session():
    self.assertAllClose(logits, dist.logits.eval())
  with self.test_session():
    self.assertAllClose(scipy.special.expit(logits), dist.p.eval())
  p = [0.01, 0.99, 0.42]
  dist = bernoulli.Bernoulli(p=p)
  with self.test_session():
    self.assertAllClose(scipy.special.logit(p), dist.logits.eval())

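# For reference, the logits/probs relationship exercised above: logits are the
# log-odds log(p / (1 - p)), and p = sigmoid(logits). A minimal standalone
# NumPy sketch (probs_to_logits is a hypothetical helper for illustration,
# not part of the API under test):
import numpy as np
import scipy.special

def probs_to_logits(p):
  # Log-odds of p; the inverse of scipy.special.expit (the sigmoid).
  return np.log(p) - np.log1p(-p)

assert np.allclose(probs_to_logits(0.99), scipy.special.logit(0.99))
assert np.allclose(scipy.special.expit(probs_to_logits(0.42)), 0.42)
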
def testEntropyWithBatch(self):
  p = [[0.1, 0.7], [0.2, 0.6]]
  dist = bernoulli.Bernoulli(p=p, validate_args=False)
  with self.test_session():
    self.assertAllClose(dist.entropy().eval(),
                        [[entropy(0.1), entropy(0.7)],
                         [entropy(0.2), entropy(0.6)]])

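# The entropy() helper used above is not defined in this excerpt; it is
# assumed to be a module-level reference implementation of the Bernoulli
# entropy in nats. A minimal sketch consistent with how the tests use it:
def entropy(p):
  q = 1. - p
  return -q * np.log(q) - p * np.log(p)
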
def sample(self, time, outputs, state, name=None):
  with ops.name_scope(name, "ScheduledOutputTrainingHelperSample",
                      [time, outputs, state]):
    # Draw an independent Bernoulli(sampling_probability) decision for each
    # batch entry; see the sketch after this function.
    sampler = bernoulli.Bernoulli(probs=self._sampling_probability)
    return math_ops.cast(
        sampler.sample(sample_shape=self.batch_size, seed=self._seed),
        dtypes.bool)

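# For intuition: the helper above draws one Bernoulli(sampling_probability)
# decision per batch entry, and True means "sample from the model's own
# outputs at this step" (scheduled sampling). A standalone NumPy sketch of
# the same decision rule (the probability and batch size are illustrative
# values, not taken from the code above):
rng = np.random.RandomState(seed=42)
sampling_probability = 0.25
batch_size = 4
# Boolean vector of shape [batch_size]; each entry is True with
# probability sampling_probability.
sample_decision = rng.uniform(size=batch_size) < sampling_probability
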
def testBernoulliBernoulliKL(self):
  with self.test_session() as sess:
    batch_size = 6
    a_p = np.array([0.5] * batch_size, dtype=np.float32)
    b_p = np.array([0.4] * batch_size, dtype=np.float32)
    a = bernoulli.Bernoulli(p=a_p)
    b = bernoulli.Bernoulli(p=b_p)
    kl = kullback_leibler.kl(a, b)
    kl_val = sess.run(kl)
    # Closed-form KL divergence between two Bernoulli distributions.
    kl_expected = (a_p * np.log(a_p / b_p) +
                   (1. - a_p) * np.log((1. - a_p) / (1. - b_p)))
    self.assertEqual(kl.get_shape(), (batch_size,))
    self.assertAllClose(kl_val, kl_expected)

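# The expected value above is the closed-form KL divergence between two
# Bernoullis: KL(a || b) = p_a log(p_a / p_b) + (1 - p_a) log((1 - p_a) /
# (1 - p_b)). A standalone sketch of the same formula (bernoulli_kl is a
# hypothetical helper for illustration, not part of the library):
def bernoulli_kl(p_a, p_b):
  return (p_a * np.log(p_a / p_b) +
          (1. - p_a) * np.log((1. - p_a) / (1. - p_b)))

assert np.allclose(bernoulli_kl(0.5, 0.4),
                   0.5 * np.log(0.5 / 0.4) + 0.5 * np.log(0.5 / 0.6))
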
def testPmfInvalid(self):
  p = [0.1, 0.2, 0.7]
  with self.test_session():
    dist = bernoulli.Bernoulli(probs=p, validate_args=True)
    with self.assertRaisesOpError("must be non-negative."):
      dist.prob([1, 1, -1]).eval()
    with self.assertRaisesOpError("is not less than or equal to 1."):
      dist.prob([2, 0, 1]).eval()

def testBroadcasting(self):
  with self.test_session():
    p = array_ops.placeholder(dtypes.float32)
    dist = bernoulli.Bernoulli(p=p)
    self.assertAllClose(np.log(0.5), dist.log_pmf(1).eval({p: 0.5}))
    self.assertAllClose(np.log([0.5, 0.5, 0.5]),
                        dist.log_pmf([1, 1, 1]).eval({p: 0.5}))
    self.assertAllClose(np.log([0.5, 0.5, 0.5]),
                        dist.log_pmf(1).eval({p: [0.5, 0.5, 0.5]}))

def testPmfCorrectBroadcastDynamicShape(self):
  with self.test_session():
    p = array_ops.placeholder(dtype=dtypes.float32)
    dist = bernoulli.Bernoulli(p=p)
    event1 = [1, 0, 1]
    event2 = [[1, 0, 1]]
    self.assertAllClose(
        dist.pmf(event1).eval({p: [0.2, 0.3, 0.4]}), [0.2, 0.7, 0.4])
    self.assertAllClose(
        dist.pmf(event2).eval({p: [0.2, 0.3, 0.4]}), [[0.2, 0.7, 0.4]])

def testPmfShapes(self):
  with self.test_session():
    p = array_ops.placeholder(dtypes.float32, shape=[None, 1])
    dist = bernoulli.Bernoulli(probs=p)
    self.assertEqual(2, len(dist.log_prob(1).eval({p: [[0.5], [0.5]]}).shape))
  with self.test_session():
    dist = bernoulli.Bernoulli(probs=0.5)
    self.assertEqual(2, len(dist.log_prob([[1], [1]]).eval().shape))
  with self.test_session():
    dist = bernoulli.Bernoulli(probs=0.5)
    self.assertEqual((), dist.log_prob(1).get_shape())
    # Compare against the 1-tuple (1,); a bare (1) is just the integer 1.
    self.assertEqual((1,), dist.log_prob([1]).get_shape())
    self.assertEqual((2, 1), dist.log_prob([[1], [1]]).get_shape())
  with self.test_session():
    dist = bernoulli.Bernoulli(probs=[[0.5], [0.5]])
    self.assertEqual((2, 1), dist.log_prob(1).get_shape())

def testInvalidP(self):
  invalid_ps = [1.01, 2.]
  for p in invalid_ps:
    with self.test_session():
      with self.assertRaisesOpError("probs has components greater than 1"):
        dist = bernoulli.Bernoulli(probs=p, validate_args=True)
        dist.probs.eval()

  invalid_ps = [-0.01, -3.]
  for p in invalid_ps:
    with self.test_session():
      with self.assertRaisesOpError("Condition x >= 0"):
        dist = bernoulli.Bernoulli(probs=p, validate_args=True)
        dist.probs.eval()

  valid_ps = [0.0, 0.5, 1.0]
  for p in valid_ps:
    with self.test_session():
      dist = bernoulli.Bernoulli(probs=p)
      self.assertEqual(p, dist.probs.eval())  # Should not fail

def testSampleN(self):
  with self.test_session():
    p = [0.2, 0.6]
    dist = bernoulli.Bernoulli(probs=p)
    n = 100000
    samples = dist.sample(n)
    samples.set_shape([n, 2])
    self.assertEqual(samples.dtype, dtypes.int32)
    sample_values = samples.eval()
    self.assertTrue(np.all(sample_values >= 0))
    self.assertTrue(np.all(sample_values <= 1))
    # Note that the standard error for the sample mean is
    # ~ sqrt(p * (1 - p) / n). This means that the tolerance is very
    # sensitive to the value of p as well as n.
    self.assertAllClose(p, np.mean(sample_values, axis=0), atol=1e-2)
    self.assertEqual(set([0, 1]), set(sample_values.flatten()))
    # In this test we're just interested in verifying there isn't a crash
    # owing to mismatched types. b/30940152
    dist = bernoulli.Bernoulli(np.log([.2, .4]))
    self.assertAllEqual((1, 2), dist.sample(1, seed=42).get_shape().as_list())

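# Worked check of the tolerance note in testSampleN (illustrative arithmetic,
# not part of the test): for n = 100000 the standard error of the sample mean
# is sqrt(0.2 * 0.8 / 1e5) ~= 0.0013 for p = 0.2 and sqrt(0.6 * 0.4 / 1e5)
# ~= 0.0015 for p = 0.6, so atol=1e-2 leaves roughly 6 to 8 standard errors
# of headroom.
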
def testSampleActsLikeSampleN(self):
  with self.test_session() as sess:
    p = [0.2, 0.6]
    dist = bernoulli.Bernoulli(probs=p)
    n = 1000
    seed = 42
    self.assertAllEqual(
        dist.sample(n, seed).eval(), dist.sample(n, seed).eval())
    n = array_ops.placeholder(dtypes.int32)
    # Unpack into two distinct names; assigning both results to the same
    # variable would make the equality check vacuous.
    sample1, sample2 = sess.run([dist.sample(n, seed), dist.sample(n, seed)],
                                feed_dict={n: 1000})
    self.assertAllEqual(sample1, sample2)

def testVarianceAndStd(self):
  var = lambda p: p * (1. - p)
  with self.test_session():
    p = [[0.2, 0.7], [0.5, 0.4]]
    dist = bernoulli.Bernoulli(p=p)
    self.assertAllClose(
        dist.variance().eval(),
        np.array([[var(0.2), var(0.7)], [var(0.5), var(0.4)]],
                 dtype=np.float32))
    self.assertAllClose(
        dist.stddev().eval(),
        np.array([[np.sqrt(var(0.2)), np.sqrt(var(0.7))],
                  [np.sqrt(var(0.5)), np.sqrt(var(0.4))]],
                 dtype=np.float32))

def _testPmf(self, **kwargs):
  # kwargs parameterize a (2, 2) batch of Bernoullis with
  # p = [[0.2, 0.4], [0.3, 0.6]]; each event x below broadcasts against it.
  dist = bernoulli.Bernoulli(**kwargs)
  with self.test_session():
    # pylint: disable=bad-continuation
    xs = [
        0,
        [1],
        [1, 0],
        [[1, 0]],
        [[1, 0], [1, 1]],
    ]
    expected_pmfs = [
        [[0.8, 0.6], [0.7, 0.4]],
        [[0.2, 0.4], [0.3, 0.6]],
        [[0.2, 0.6], [0.3, 0.4]],
        [[0.2, 0.6], [0.3, 0.4]],
        [[0.2, 0.6], [0.3, 0.6]],
    ]
    # pylint: enable=bad-continuation

    for x, expected_pmf in zip(xs, expected_pmfs):
      self.assertAllClose(dist.prob(x).eval(), expected_pmf)
      self.assertAllClose(dist.log_prob(x).eval(), np.log(expected_pmf))

def testBoundaryConditions(self):
  with self.test_session():
    dist = bernoulli.Bernoulli(p=1.0)
    # log_pmf is computed from the logits, which are infinite at p == 1,
    # so both endpoints evaluate to NaN.
    self.assertAllClose(np.nan, dist.log_pmf(0).eval())
    self.assertAllClose([np.nan], [dist.log_pmf(1).eval()])

def testP(self):
  p = [0.2, 0.4]
  dist = bernoulli.Bernoulli(p=p)
  with self.test_session():
    self.assertAllClose(p, dist.p.eval())

def testEntropyNoBatch(self):
  p = 0.2
  dist = bernoulli.Bernoulli(p=p)
  with self.test_session():
    self.assertAllClose(dist.entropy().eval(), entropy(p))

def make_bernoulli(batch_shape, dtype=dtypes.int32):
  p = np.random.uniform(size=list(batch_shape))
  p = constant_op.constant(p, dtype=dtypes.float32)
  return bernoulli.Bernoulli(p=p, dtype=dtype)

def testMean(self):
  with self.test_session():
    p = np.array([[0.2, 0.7], [0.5, 0.4]], dtype=np.float32)
    dist = bernoulli.Bernoulli(p=p)
    self.assertAllEqual(dist.mean().eval(), p)