예제 #1
0
 def testP(self):
     """Checks that `probs` equals the softmax of the logits.

     Also verifies that the static shape of `probs` is [3], one entry per
     logit.
     """
     logits = [2.0, 3.0, -4.0]
     temperature = 1.0
     dist = relaxed_onehot_categorical.ExpRelaxedOneHotCategorical(
         temperature, logits)
     # Softmax of the logits, computed in NumPy as the reference value.
     exp_logits = np.exp(logits)
     expected_p = exp_logits / np.sum(exp_logits)
     with self.cached_session():
         probs = dist.probs
         self.assertAllClose(expected_p, probs.eval())
         self.assertAllEqual([3], probs.get_shape())
예제 #2
0
 def testUnknownShape(self):
     """Checks sample shapes when the logits' shape is only known at run time.

     The logits come from a float32 placeholder created without a shape, so
     the event size is only determined by the value fed in at eval time.
     """
     # A single session context is enough; the original re-entered the same
     # cached session in a redundant nested `with` block.
     with self.cached_session():
         logits_pl = array_ops.placeholder(dtypes.float32)
         temperature = 1.0
         dist = relaxed_onehot_categorical.ExpRelaxedOneHotCategorical(
             temperature, logits_pl)
         feed_dict = {logits_pl: [.3, .1, .4]}
         # One draw has the shape of the fed logits: [3].
         self.assertAllEqual(
             [3],
             dist.sample().eval(feed_dict=feed_dict).shape)
         # Requesting 5 draws prepends a sample dimension: [5, 3].
         self.assertAllEqual(
             [5, 3],
             dist.sample(5).eval(feed_dict=feed_dict).shape)
예제 #3
0
 def testPdf(self):
     """Compares `prob` on a drawn sample against the closed-form density.

     The reference is the ExpConcrete density from Maddison et al. (2016):
         Gamma(k) * t**(k - 1) * prod_i(w_i / sum_j w_j),
     with w_i = p_i * exp(-t * x_i), where t is the temperature, p is the
     softmax of the logits, x is the sample and k is the number of classes.
     """
     logits = [.3, .1, .4]
     temperature = .4
     num_classes = len(logits)
     exp_logits = np.exp(logits)
     probs = exp_logits / np.sum(exp_logits)
     dist = relaxed_onehot_categorical.ExpRelaxedOneHotCategorical(
         temperature, logits)
     with self.cached_session():
         sample = dist.sample().eval()
         weighted = probs * np.exp(-temperature * sample)
         # Normalizing constant Gamma(k) * t**(k - 1) of the density.
         normalizer = (gamma(num_classes) *
                       np.power(temperature, num_classes - 1))
         expected_pdf = normalizer * np.prod(weighted / np.sum(weighted))
         self.assertAllClose(expected_pdf, dist.prob(sample).eval())