Exemple #1
0
 def testSampleWithIntegerOutcomes(self):
     """Sample statistics over integer outcomes track mean() and stddev().

     Draws 5000 samples with a fixed seed and checks that the empirical
     mean and standard deviation are within 0.1 of the values reported
     by the distribution.
     """
     num_draws = 5000
     support = self._build_tensor([1, 2])
     weights = self._build_tensor([0.2, 0.8])
     distribution = finite_discrete.FiniteDiscrete(
         support, probs=weights, validate_args=True)
     draws = self.evaluate(distribution.sample(num_draws, seed=1234))

     empirical_mean = np.mean(draws)
     self.assertAllClose(empirical_mean, distribution.mean(), atol=0.1)
     empirical_std = np.std(draws)
     self.assertAllClose(empirical_std, distribution.stddev(), atol=0.1)
Exemple #2
0
 def testOutcomesNotStrictlyIncreasingRaises(self):
     """Outcomes containing repeated values are rejected with validation on."""
     expected_message = 'outcomes is not strictly increasing.'
     repeated_outcomes = self._build_tensor([1.0, 1.0, 2.0, 2.0])
     uniform_probs = self._build_tensor([0.25, 0.25, 0.25, 0.25])
     # Both construction and evaluation stay inside the context manager, so
     # the error is accepted from either step.
     with self.assertRaisesWithPredicateMatch(Exception, expected_message):
         bad_dist = finite_discrete.FiniteDiscrete(
             repeated_outcomes, probs=uniform_probs, validate_args=True)
         self.evaluate(bad_dist.outcomes)
Exemple #3
0
 def testModeWithIntegerOutcomes(self):
     """mode() is scalar and equals the integer outcome with the top prob."""
     support = self._build_tensor([1, 2, 3])
     weights = self._build_tensor([0.3, 0.1, 0.6])
     distribution = finite_discrete.FiniteDiscrete(
         support, probs=weights, validate_args=True)

     most_likely = distribution.mode()

     scalar_shape = ()
     self.assertAllEqual(scalar_shape, self._get_shape(most_likely))
     # Outcome 3 carries the largest probability (0.6).
     self.assertAllEqual(3, most_likely)
Exemple #4
0
 def testRankOfOutcomesLargerThanOneRaises(self):
     """A 2x2 outcomes tensor is rejected with validation on."""
     expected_message = 'Rank of outcomes must be 1.'
     matrix_outcomes = self._build_tensor([[1.0, 2.0], [3.0, 4.0]])
     even_probs = self._build_tensor([0.5, 0.5])
     # Both construction and evaluation stay inside the context manager, so
     # the error is accepted from either step.
     with self.assertRaisesWithPredicateMatch(Exception, expected_message):
         bad_dist = finite_discrete.FiniteDiscrete(
             matrix_outcomes, probs=even_probs, validate_args=True)
         self.evaluate(bad_dist.outcomes)
 def testMean(self):
     """mean() of a two-outcome, equal-probability distribution is 1.5."""
     support = self._build_tensor([1.0, 2.0])
     weights = self._build_tensor([0.5, 0.5])
     distribution = finite_discrete.FiniteDiscrete(
         support, probs=weights, validate_args=True)

     actual_mean = distribution.mean()

     scalar_shape = ()
     self.assertAllEqual(scalar_shape, self._get_shape(actual_mean))
     expected_mean = 1.5  # 0.5 * 1.0 + 0.5 * 2.0
     self.assertAllClose(expected_mean, actual_mean)
Exemple #6
0
 def testCDFWithDifferentAtol(self):
     """cdf() results at near-outcome inputs depend on the `atol` argument.

     The query points 0.095 and 0.395 sit 0.005 below the outcomes 0.1 and
     0.4. With atol=0.001 the expected CDF rows are 0.0 and 0.1; with
     atol=0.01 they are 0.0 and 0.3.
     """
     support = self._build_tensor([0.1, 0.2, 0.4, 0.8])
     weights = self._build_tensor([0.0, 0.1, 0.2, 0.7])
     queries = self._build_tensor([[0.095, 0.095], [0.395, 0.395]])

     # (atol, expected cdf) pairs, checked in this order.
     cases = (
         (0.001, [[0.0, 0.0], [0.1, 0.1]]),
         (0.01, [[0.0, 0.0], [0.3, 0.3]]),
     )
     for tolerance, expected_cdf in cases:
         distribution = finite_discrete.FiniteDiscrete(
             support, probs=weights, atol=tolerance, validate_args=True)
         actual_cdf = distribution.cdf(queries)
         self.assertAllEqual((2, 2), self._get_shape(actual_cdf))
         self.assertAllClose(expected_cdf, actual_cdf)
Exemple #7
0
 def testSizeOfOutcomesIsZeroRaises(self):
     """Empty outcomes are rejected with validation on."""
     expected_message = 'Size of outcomes must be greater than 0.'
     empty_outcomes = self._build_tensor([])
     empty_probs = self._build_tensor([])
     # Both construction and evaluation stay inside the context manager, so
     # the error is accepted from either step.
     with self.assertRaisesWithPredicateMatch(Exception, expected_message):
         bad_dist = finite_discrete.FiniteDiscrete(
             empty_outcomes, probs=empty_probs, validate_args=True)
         self.evaluate(bad_dist.outcomes)
Exemple #8
0
 def testMode(self):
     """mode() is scalar and equals the float outcome with the top prob."""
     support = self._build_tensor([1.0, 2.0, 3.0])
     weights = self._build_tensor([0.3, 0.1, 0.6])
     distribution = finite_discrete.FiniteDiscrete(
         support, probs=weights, validate_args=True)

     most_likely = distribution.mode()

     scalar_shape = ()
     self.assertAllEqual(scalar_shape, self._get_shape(most_likely))
     # Outcome 3.0 carries the largest probability (0.6).
     self.assertAllClose(3.0, most_likely)
Exemple #9
0
 def testCDF(self):
     """cdf() at a scalar outcome accumulates the probs up to that outcome."""
     support = self._build_tensor([0.1, 0.2, 0.4, 0.8])
     weights = self._build_tensor([0.0, 0.1, 0.2, 0.7])
     distribution = finite_discrete.FiniteDiscrete(
         support, probs=weights, validate_args=True)

     actual_cdf = distribution.cdf(0.4)

     scalar_shape = ()
     self.assertAllEqual(scalar_shape, self._get_shape(actual_cdf))
     expected_cdf = 0.3  # 0.0 + 0.1 + 0.2
     self.assertAllClose(expected_cdf, actual_cdf)
Exemple #10
0
 def testPMF(self):
     """prob() of a scalar input returns one value per row of batched probs."""
     support = self._build_tensor([1.0, 2.0, 4.0, 8.0])
     batched_weights = self._build_tensor([
         [0.0, 0.1, 0.2, 0.7],
         [0.5, 0.3, 0.2, 0.0],
     ])
     distribution = finite_discrete.FiniteDiscrete(
         support, probs=batched_weights, validate_args=True)

     actual_prob = distribution.prob(8.0)

     batch_shape = (2, )
     self.assertAllEqual(batch_shape, self._get_shape(actual_prob))
     # 8.0 is the last outcome: probability 0.7 in row 0 and 0.0 in row 1.
     self.assertAllClose([0.7, 0.0], actual_prob)
Exemple #11
0
 def testSample(self):
     """Batched sampling has shape (5000, 2) and matches mean()/stddev().

     Per-column empirical mean and standard deviation must be within 0.1
     of the values reported by the distribution.
     """
     num_draws = 5000
     support = self._build_tensor([1.0, 2.0])
     batched_weights = self._build_tensor([[0.2, 0.8], [0.8, 0.2]])
     distribution = finite_discrete.FiniteDiscrete(
         support, probs=batched_weights, validate_args=True)
     draws = self.evaluate(distribution.sample(num_draws, seed=1234))

     self.assertAllEqual((num_draws, 2), self._get_shape(draws))
     # Reduce over the sample axis so each batch member is compared separately.
     empirical_mean = np.mean(draws, axis=0)
     self.assertAllClose(empirical_mean, distribution.mean(), atol=0.1)
     empirical_std = np.std(draws, axis=0)
     self.assertAllClose(empirical_std, distribution.stddev(), atol=0.1)
Exemple #12
0
 def testEntropy(self):
     """entropy() matches -sum(p * log(p)) computed per row with NumPy."""
     support = self._build_tensor([1, 2, 3, 4])
     probs = np.array([
         [0.125, 0.125, 0.25, 0.5],
         [0.25, 0.25, 0.25, 0.25],
     ])
     weights = self._build_tensor(probs)
     distribution = finite_discrete.FiniteDiscrete(
         support, probs=weights, validate_args=True)

     actual_entropy = distribution.entropy()

     batch_shape = (2, )
     self.assertAllEqual(batch_shape, self._get_shape(actual_entropy))
     expected_entropy = np.sum(-probs * np.log(probs), axis=1)
     self.assertAllClose(expected_entropy, actual_entropy)
Exemple #13
0
 def testMean(self):
     """mean() returns one value per row of batched probs."""
     support = self._build_tensor([1.0, 2.0])
     batched_weights = self._build_tensor([[0.5, 0.5], [0.2, 0.8]])
     # Row 0: 0.5 * 1 + 0.5 * 2 = 1.5; row 1: 0.2 * 1 + 0.8 * 2 = 1.8.
     expected_means = [1.5, 1.8]
     distribution = finite_discrete.FiniteDiscrete(
         support, probs=batched_weights, validate_args=True)

     actual_means = distribution.mean()

     batch_shape = (2, )
     self.assertAllEqual(batch_shape, self._get_shape(actual_means))
     self.assertAllClose(expected_means, actual_means)
Exemple #14
0
 def testInequalLastDimRaises(self):
     """Two outcomes paired with three probs are rejected with validation on."""
     expected_message = (
         'Last dimension of outcomes and probs must be equal size')
     two_outcomes = self._build_tensor([1.0, 2.0])
     three_probs = self._build_tensor([0.25, 0.25, 0.5])
     # Both construction and evaluation stay inside the context manager, so
     # the error is accepted from either step.
     with self.assertRaisesWithPredicateMatch(Exception, expected_message):
         bad_dist = finite_discrete.FiniteDiscrete(
             two_outcomes, probs=three_probs, validate_args=True)
         self.evaluate(bad_dist.outcomes)
Exemple #15
0
 def testCDFWithBatchSampleShape(self):
     """cdf() of a (3, 2) input keeps that shape, including off-outcome inputs.

     The queries include values just below the smallest outcome (0.0999)
     and just above the largest outcome (0.8001).
     """
     support = self._build_tensor([0.1, 0.2, 0.4, 0.8])
     weights = self._build_tensor([0.0, 0.1, 0.2, 0.7])
     queries = self._build_tensor([
         [0.0999, 0.1],
         [0.2, 0.4],
         [0.8, 0.8001],
     ])
     distribution = finite_discrete.FiniteDiscrete(
         support, probs=weights, validate_args=True)

     actual_cdf = distribution.cdf(queries)

     self.assertAllEqual((3, 2), self._get_shape(actual_cdf))
     expected_cdf = [[0.0, 0.0], [0.1, 0.3], [1.0, 1.0]]
     self.assertAllClose(expected_cdf, actual_cdf)
Exemple #16
0
 def testPMFWithIntegerOutcomes(self):
     """prob() over a (5, 1) integer input gives 0 for values not in outcomes."""
     support = self._build_tensor([1, 2, 4, 8])
     weights = self._build_tensor([0.0, 0.1, 0.2, 0.7])
     # 3 is not one of the outcomes; the remaining queries are.
     queries = self._build_tensor([[1], [2], [3], [4], [8]])
     distribution = finite_discrete.FiniteDiscrete(
         support, probs=weights, validate_args=True)

     actual_prob = distribution.prob(queries)

     self.assertAllEqual((5, 1), self._get_shape(actual_prob))
     expected_prob = [[0.0], [0.1], [0.0], [0.2], [0.7]]
     self.assertAllClose(expected_prob, actual_prob)
Exemple #17
0
 def testPMFWithBatchSampleShape(self):
     """prob() over a (5, 1) float input gives 0 for values not in outcomes."""
     support = self._build_tensor([1.0, 2.0, 4.0, 8.0])
     weights = self._build_tensor([0.0, 0.1, 0.2, 0.7])
     # 3.0 is not one of the outcomes; the remaining queries are.
     queries = self._build_tensor([[1.0], [2.0], [3.0], [4.0], [8.0]])
     distribution = finite_discrete.FiniteDiscrete(
         support, probs=weights, validate_args=True)

     actual_prob = distribution.prob(queries)

     self.assertAllEqual((5, 1), self._get_shape(actual_prob))
     expected_prob = [[0.0], [0.1], [0.0], [0.2], [0.7]]
     self.assertAllClose(expected_prob, actual_prob)
Exemple #18
0
 def testCDFWithIntegerOutcomes(self):
     """cdf() over the integers 0..9 is a step function of the outcomes."""
     support = self._build_tensor([1, 2, 4, 8])
     weights = self._build_tensor([0.0, 0.1, 0.2, 0.7])
     queries = self._build_tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
     distribution = finite_discrete.FiniteDiscrete(
         support, probs=weights, validate_args=True)

     actual_cdf = distribution.cdf(queries)

     self.assertAllEqual((10, ), self._get_shape(actual_cdf))
     # The expected value changes only at the outcomes 2, 4 and 8; outcome 1
     # has probability 0.0, so the value stays 0.0 there.
     expected_cdf = [0.0, 0.0, 0.1, 0.1, 0.3, 0.3, 0.3, 0.3, 1.0, 1.0]
     self.assertAllClose(expected_cdf, actual_cdf)
Exemple #19
0
 def testShape(self):
     """Rank-1 logits give empty batch and event shapes.

     The static batch shape is only checked when `self.use_static_shape`
     is truthy; the remaining three shape checks always run.
     """
     support = self._build_tensor([0.0, 0.2, 0.3, 0.5])
     log_odds = self._build_tensor([-0.1, 0.0, 0.1, 0.2])
     distribution = finite_discrete.FiniteDiscrete(
         support, logits=log_odds, validate_args=True)

     scalar_shape = []
     if self.use_static_shape:
         self.assertAllEqual(scalar_shape, distribution.batch_shape)
     self.assertAllEqual(scalar_shape, distribution.batch_shape_tensor())
     self.assertAllEqual(scalar_shape, distribution.event_shape)
     self.assertAllEqual(scalar_shape, distribution.event_shape_tensor())
Exemple #20
0
 def testStddevAndVariance(self):
     """stddev() and variance() return one value per row of batched probs."""
     support = self._build_tensor([1.0, 2.0])
     batched_weights = self._build_tensor([[0.5, 0.5], [0.2, 0.8]])
     distribution = finite_discrete.FiniteDiscrete(
         support, probs=batched_weights, validate_args=True)
     batch_shape = (2, )

     actual_stddev = distribution.stddev()
     self.assertAllEqual(batch_shape, self._get_shape(actual_stddev))
     self.assertAllClose([0.5, 0.4], actual_stddev)

     # The expected variances are the squares of the standard deviations.
     actual_variance = distribution.variance()
     self.assertAllEqual(batch_shape, self._get_shape(actual_variance))
     self.assertAllClose([0.25, 0.16], actual_variance)
Exemple #21
0
 def testParamTensorFromProbs(self):
     """logits_parameter()/probs_parameter() agree with the `probs` input.

     logits_parameter() is compared against log(probs) and
     probs_parameter() against probs, both at a relative tolerance of 1e-4
     with zero absolute tolerance.
     """
     support = self._build_tensor([0.1, 0.2, 0.4])
     x = tf.constant([0.1, 0.5, 0.4])
     d = finite_discrete.FiniteDiscrete(support, probs=x, validate_args=True)

     expected_logits, actual_logits = self.evaluate(
         [tf.math.log(x), d.logits_parameter()])
     self.assertAllClose(expected_logits, actual_logits, atol=0, rtol=1e-4)

     expected_probs, actual_probs = self.evaluate([x, d.probs_parameter()])
     self.assertAllClose(expected_probs, actual_probs, atol=0, rtol=1e-4)
Exemple #22
0
 def testShapes(self):
     """Batch shape follows the leading dims of logits; event shape is empty.

     Logits are random, so only shapes are asserted. The static batch
     shape is only checked when `self.use_static_shape` is truthy.
     """
     outcomes = [0.0, 0.2, 0.3, 0.5]
     outcomes_tensor = self._build_tensor(outcomes)
     num_outcomes = len(outcomes)
     scalar_shape = []

     for expected_batch in ([1], [2], [3, 4, 5]):
         logits_shape = list(expected_batch) + [num_outcomes]
         random_logits = np.random.uniform(-1, 1, size=logits_shape)
         distribution = finite_discrete.FiniteDiscrete(
             outcomes_tensor,
             logits=self._build_tensor(random_logits),
             validate_args=True)

         if self.use_static_shape:
             self.assertAllEqual(expected_batch, distribution.batch_shape)
         self.assertAllEqual(
             expected_batch, distribution.batch_shape_tensor())
         self.assertAllEqual(scalar_shape, distribution.event_shape)
         self.assertAllEqual(
             scalar_shape, distribution.event_shape_tensor())
Exemple #23
0
 def testPMF(self):
     """prob() of scalar inputs is scalar and matches the configured probs."""
     support = self._build_tensor([1.0, 2.0, 4.0, 8.0])
     weights = self._build_tensor([0.0, 0.1, 0.2, 0.7])
     distribution = finite_discrete.FiniteDiscrete(
         support, probs=weights, validate_args=True)

     # (query, expected probability) pairs, checked in this order.
     cases = (
         (4.0, 0.2),
         # Outcome with zero probability.
         (1.0, 0.0),
         # Input that is not in the list of possible outcomes.
         (3.0, 0.0),
     )
     scalar_shape = ()
     for query, expected_prob in cases:
         actual_prob = distribution.prob(query)
         self.assertAllEqual(scalar_shape, self._get_shape(actual_prob))
         self.assertAllClose(expected_prob, actual_prob)