def testP(self):
    """Check that `probs` is the softmax of the logits with static shape [3]."""
    logits = [2.0, 3.0, -4.0]
    temperature = 1.0

    # Reference softmax computed with NumPy.
    unnormalized = np.exp(logits)
    expected_p = unnormalized / np.sum(unnormalized)

    dist = relaxed_onehot_categorical.ExpRelaxedOneHotCategorical(
        temperature, logits)
    probs = dist.probs

    with self.cached_session():
        self.assertAllClose(expected_p, probs.eval())
        self.assertAllEqual([3], probs.get_shape())
def testUnknownShape(self):
    """Check sample shapes when the logits placeholder declares no shape.

    The logits are supplied only at evaluation time through `feed_dict`, so
    the shapes are asserted on the evaluated samples: `[3]` for a single
    draw and `[5, 3]` for `sample(5)`.
    """
    # One cached_session() context covers both graph construction and
    # evaluation; per the TensorFlow docs, cached_session() reuses the same
    # session within a test.
    with self.cached_session():
        logits_pl = array_ops.placeholder(dtypes.float32)
        temperature = 1.0
        dist = relaxed_onehot_categorical.ExpRelaxedOneHotCategorical(
            temperature, logits_pl)

        feed_dict = {logits_pl: [.3, .1, .4]}
        self.assertAllEqual(
            [3], dist.sample().eval(feed_dict=feed_dict).shape)
        self.assertAllEqual(
            [5, 3], dist.sample(5).eval(feed_dict=feed_dict).shape)
def testPdf(self):
    """Check `prob` against the closed-form ExpConcrete density at a sample."""
    logits = [.3, .1, .4]
    temperature = .4
    k = len(logits)

    # Class probabilities: softmax of the logits.
    exp_logits = np.exp(logits)
    p = exp_logits / np.sum(exp_logits)

    dist = relaxed_onehot_categorical.ExpRelaxedOneHotCategorical(
        temperature, logits)

    with self.cached_session():
        x = dist.sample().eval()

        # Analytical ExpConcrete density presented in Maddison et al. 2016:
        #   gamma(k) * temperature**(k - 1) * prod_i(w_i / sum_j w_j)
        # with w_i = p_i * exp(-temperature * x_i).
        weighted = p * np.exp(-temperature * x)
        normalized = weighted / np.sum(weighted)
        scale = gamma(k) * np.power(temperature, k - 1)
        expected_pdf = scale * np.prod(normalized)

        pdf = dist.prob(x).eval()
        self.assertAllClose(expected_pdf, pdf)