Example #1
    def testLogits(self):
        logits = [-42., 42.]
        dist = bernoulli.Bernoulli(logits=logits)
        with self.test_session():
            self.assertAllClose(logits, dist.logits.eval())

        with self.test_session():
            self.assertAllClose(scipy.special.expit(logits), dist.p.eval())

        p = [0.01, 0.99, 0.42]
        dist = bernoulli.Bernoulli(p=p)
        with self.test_session():
            self.assertAllClose(scipy.special.logit(p), dist.logits.eval())
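The logits/probs round trip asserted above is just the sigmoid/logit pair from scipy.special. A minimal standalone NumPy/SciPy sketch of the same relationship, using illustrative logits chosen so that neither end saturates:

import numpy as np
from scipy.special import expit, logit

logits = np.array([-2., 0.5, 3.])
probs = expit(logits)                             # sigmoid: p = 1 / (1 + exp(-logits))
np.testing.assert_allclose(logit(probs), logits)  # logit(p) = log(p / (1 - p))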
Example #2
 def testEntropyWithBatch(self):
     p = [[0.1, 0.7], [0.2, 0.6]]
     dist = bernoulli.Bernoulli(p=p, validate_args=False)
     with self.test_session():
         self.assertAllClose(dist.entropy().eval(),
                             [[entropy(0.1), entropy(0.7)],
                              [entropy(0.2), entropy(0.6)]])
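The `entropy` reference used in this assertion is defined outside the excerpt; it is presumably the binary entropy of a Bernoulli(p) in nats. A minimal sketch of such a helper, under that assumption:

import numpy as np

def entropy(p):
    # Binary entropy in nats: H(p) = -p*log(p) - (1 - p)*log(1 - p)
    q = 1. - p
    return -p * np.log(p) - q * np.log(q)

# e.g. entropy(0.1) ~= 0.3251 and entropy(0.7) ~= 0.6109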
Example #3
 def sample(self, time, outputs, state, name=None):
     with ops.name_scope(name, "ScheduledOutputTrainingHelperSample",
                         [time, outputs, state]):
         sampler = bernoulli.Bernoulli(probs=self._sampling_probability)
         return math_ops.cast(
             sampler.sample(sample_shape=self.batch_size, seed=self._seed),
             dtypes.bool)
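Conceptually, the helper above draws one coin flip per batch element and casts it to a boolean mask that decides whether to sample from the model's outputs at this step. A framework-free sketch of the same idea, with a hypothetical `sampling_probability` and `batch_size`:

import numpy as np

rng = np.random.default_rng(seed=42)
sampling_probability = 0.25   # hypothetical value
batch_size = 8                # hypothetical value

# One Bernoulli(sampling_probability) draw per batch element, as a bool mask.
sample_mask = rng.random(batch_size) < sampling_probability   # shape (batch_size,)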
Example #4
    def testBernoulliBernoulliKL(self):
        with self.test_session() as sess:
            batch_size = 6
            a_p = np.array([0.5] * batch_size, dtype=np.float32)
            b_p = np.array([0.4] * batch_size, dtype=np.float32)

            a = bernoulli.Bernoulli(p=a_p)
            b = bernoulli.Bernoulli(p=b_p)

            kl = kullback_leibler.kl(a, b)
            kl_val = sess.run(kl)

            kl_expected = (a_p * np.log(a_p / b_p) + (1. - a_p) * np.log(
                (1. - a_p) / (1. - b_p)))

            self.assertEqual(kl.get_shape(), (batch_size, ))
            self.assertAllClose(kl_val, kl_expected)
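As a sanity check on the closed form used for `kl_expected`, KL(Bernoulli(0.5) || Bernoulli(0.4)) works out to about 0.0204 nats, which is the value expected for every entry of the batch above. A small standalone sketch:

import numpy as np

def bernoulli_kl(p_a, p_b):
    # KL(Bernoulli(p_a) || Bernoulli(p_b)) in nats.
    return (p_a * np.log(p_a / p_b)
            + (1. - p_a) * np.log((1. - p_a) / (1. - p_b)))

print(bernoulli_kl(0.5, 0.4))   # ~0.0204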
Example #5
 def testPmfInvalid(self):
   p = [0.1, 0.2, 0.7]
   with self.test_session():
     dist = bernoulli.Bernoulli(probs=p, validate_args=True)
     with self.assertRaisesOpError("must be non-negative."):
       dist.prob([1, 1, -1]).eval()
     with self.assertRaisesOpError("is not less than or equal to 1."):
       dist.prob([2, 0, 1]).eval()
Example #6
 def testBroadcasting(self):
     with self.test_session():
         p = array_ops.placeholder(dtypes.float32)
         dist = bernoulli.Bernoulli(p=p)
         self.assertAllClose(np.log(0.5), dist.log_pmf(1).eval({p: 0.5}))
         self.assertAllClose(np.log([0.5, 0.5, 0.5]),
                             dist.log_pmf([1, 1, 1]).eval({p: 0.5}))
         self.assertAllClose(np.log([0.5, 0.5, 0.5]),
                             dist.log_pmf(1).eval({p: [0.5, 0.5, 0.5]}))
Example #7
 def testPmfCorrectBroadcastDynamicShape(self):
     with self.test_session():
         p = array_ops.placeholder(dtype=dtypes.float32)
         dist = bernoulli.Bernoulli(p=p)
         event1 = [1, 0, 1]
         event2 = [[1, 0, 1]]
         self.assertAllClose(
             dist.pmf(event1).eval({p: [0.2, 0.3, 0.4]}), [0.2, 0.7, 0.4])
         self.assertAllClose(
             dist.pmf(event2).eval({p: [0.2, 0.3, 0.4]}), [[0.2, 0.7, 0.4]])
Example #8
  def testPmfShapes(self):
    with self.test_session():
      p = array_ops.placeholder(dtypes.float32, shape=[None, 1])
      dist = bernoulli.Bernoulli(probs=p)
      self.assertEqual(2, len(dist.log_prob(1).eval({p: [[0.5], [0.5]]}).shape))

    with self.test_session():
      dist = bernoulli.Bernoulli(probs=0.5)
      self.assertEqual(2, len(dist.log_prob([[1], [1]]).eval().shape))

    with self.test_session():
      dist = bernoulli.Bernoulli(probs=0.5)
      self.assertEqual((), dist.log_prob(1).get_shape())
      self.assertEqual((1), dist.log_prob([1]).get_shape())
      self.assertEqual((2, 1), dist.log_prob([[1], [1]]).get_shape())

    with self.test_session():
      dist = bernoulli.Bernoulli(probs=[[0.5], [0.5]])
      self.assertEqual((2, 1), dist.log_prob(1).get_shape())
Example #9
  def testInvalidP(self):
    invalid_ps = [1.01, 2.]
    for p in invalid_ps:
      with self.test_session():
        with self.assertRaisesOpError("probs has components greater than 1"):
          dist = bernoulli.Bernoulli(probs=p, validate_args=True)
          dist.probs.eval()

    invalid_ps = [-0.01, -3.]
    for p in invalid_ps:
      with self.test_session():
        with self.assertRaisesOpError("Condition x >= 0"):
          dist = bernoulli.Bernoulli(probs=p, validate_args=True)
          dist.probs.eval()

    valid_ps = [0.0, 0.5, 1.0]
    for p in valid_ps:
      with self.test_session():
        dist = bernoulli.Bernoulli(probs=p)
        self.assertEqual(p, dist.probs.eval())  # Should not fail
Example #10
 def testSampleN(self):
   with self.test_session():
     p = [0.2, 0.6]
     dist = bernoulli.Bernoulli(probs=p)
     n = 100000
     samples = dist.sample(n)
     samples.set_shape([n, 2])
     self.assertEqual(samples.dtype, dtypes.int32)
     sample_values = samples.eval()
     self.assertTrue(np.all(sample_values >= 0))
     self.assertTrue(np.all(sample_values <= 1))
     # Note that the standard error for the sample mean is ~ sqrt(p * (1 - p) /
     # n). This means that the tolerance is very sensitive to the value of p
     # as well as n.
     self.assertAllClose(p, np.mean(sample_values, axis=0), atol=1e-2)
     self.assertEqual(set([0, 1]), set(sample_values.flatten()))
     # In this test we're just interested in verifying there isn't a crash
     # owing to mismatched types. b/30940152
     dist = bernoulli.Bernoulli(np.log([.2, .4]))
     self.assertAllEqual((1, 2), dist.sample(1, seed=42).get_shape().as_list())
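The tolerance comment above can be made concrete: with n = 100000, the standard error of the sample mean is sqrt(p * (1 - p) / n), roughly 0.0013 for p = 0.2 and 0.0015 for p = 0.6, so atol=1e-2 leaves a margin of about 6 to 8 standard errors. A small sketch of that calculation:

import numpy as np

n = 100000
for p in [0.2, 0.6]:
    stderr = np.sqrt(p * (1. - p) / n)   # standard error of the sample mean
    print(p, stderr, 1e-2 / stderr)      # atol=1e-2 is ~6-8 standard errors wide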
Example #11
 def testSampleActsLikeSampleN(self):
   with self.test_session() as sess:
     p = [0.2, 0.6]
     dist = bernoulli.Bernoulli(probs=p)
     n = 1000
     seed = 42
     self.assertAllEqual(
         dist.sample(n, seed).eval(), dist.sample(n, seed).eval())
     n = array_ops.placeholder(dtypes.int32)
     sample1, sample2 = sess.run([dist.sample(n, seed), dist.sample(n, seed)],
                                 feed_dict={n: 1000})
     self.assertAllEqual(sample1, sample2)
Example #12
 def testVarianceAndStd(self):
     var = lambda p: p * (1. - p)
     with self.test_session():
         p = [[0.2, 0.7], [0.5, 0.4]]
         dist = bernoulli.Bernoulli(p=p)
         self.assertAllClose(
             dist.variance().eval(),
             np.array([[var(0.2), var(0.7)], [var(0.5), var(0.4)]],
                      dtype=np.float32))
         self.assertAllClose(
             dist.stddev().eval(),
             np.array(
                 [[np.sqrt(var(0.2)), np.sqrt(var(0.7))],
                  [np.sqrt(var(0.5)), np.sqrt(var(0.4))]],
                 dtype=np.float32))
Example #13
  def _testPmf(self, **kwargs):
    dist = bernoulli.Bernoulli(**kwargs)
    with self.test_session():
      # pylint: disable=bad-continuation
      xs = [
          0,
          [1],
          [1, 0],
          [[1, 0]],
          [[1, 0], [1, 1]],
      ]
      expected_pmfs = [
          [[0.8, 0.6], [0.7, 0.4]],
          [[0.2, 0.4], [0.3, 0.6]],
          [[0.2, 0.6], [0.3, 0.4]],
          [[0.2, 0.6], [0.3, 0.4]],
          [[0.2, 0.6], [0.3, 0.6]],
      ]
      # pylint: enable=bad-continuation

      for x, expected_pmf in zip(xs, expected_pmfs):
        self.assertAllClose(dist.prob(x).eval(), expected_pmf)
        self.assertAllClose(dist.log_prob(x).eval(), np.log(expected_pmf))
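The expected values above follow from ordinary broadcasting of the event `x` against the batch of probabilities. A minimal NumPy reference computation, assuming the batch p = [[0.2, 0.4], [0.3, 0.6]] inferred from the expected pmfs (the actual kwargs are supplied by the callers of `_testPmf`, not shown here):

import numpy as np

p = np.array([[0.2, 0.4], [0.3, 0.6]])   # inferred from the expected pmfs above

def ref_pmf(x, p):
    # Broadcast the event against the batch, then take p where x == 1
    # and (1 - p) where x == 0.
    x, p = np.broadcast_arrays(np.asarray(x, dtype=p.dtype), p)
    return np.where(x == 1., p, 1. - p)

print(ref_pmf([1, 0], p))                # [[0.2, 0.6], [0.3, 0.4]]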
Example #14
 def testBoundaryConditions(self):
     with self.test_session():
         dist = bernoulli.Bernoulli(p=1.0)
         self.assertAllClose(np.nan, dist.log_pmf(0).eval())
         self.assertAllClose([np.nan], [dist.log_pmf(1).eval()])
Example #15
 def testP(self):
     p = [0.2, 0.4]
     dist = bernoulli.Bernoulli(p=p)
     with self.test_session():
         self.assertAllClose(p, dist.p.eval())
Example #16
 def testEntropyNoBatch(self):
     p = 0.2
     dist = bernoulli.Bernoulli(p=p)
     with self.test_session():
         self.assertAllClose(dist.entropy().eval(), entropy(p))
Example #17
def make_bernoulli(batch_shape, dtype=dtypes.int32):
    p = np.random.uniform(size=list(batch_shape))
    p = constant_op.constant(p, dtype=dtypes.float32)
    return bernoulli.Bernoulli(p=p, dtype=dtype)
Example #18
 def testMean(self):
     with self.test_session():
         p = np.array([[0.2, 0.7], [0.5, 0.4]], dtype=np.float32)
         dist = bernoulli.Bernoulli(p=p)
         self.assertAllEqual(dist.mean().eval(), p)