def testDTypes(self):
    """Build sample and log_prob ops for several floating-point dtypes.

    Nothing here calls `.eval()` on the results, so this exercises op
    construction for each dtype rather than the numerical values.
    """
    float_dtypes = (dtypes.float16, dtypes.float32, dtypes.float64)
    with self.cached_session():
        for float_dtype in float_dtypes:
            random_logits = random_ops.random_uniform(
                shape=[3, 3], dtype=float_dtype)
            relaxed = relaxed_onehot_categorical.RelaxedOneHotCategorical(
                temperature=0.5, logits=random_logits)
            drawn = relaxed.sample()
            relaxed.log_prob(drawn)
def testLogits(self):
    """The base distribution should carry the logits it was built with."""
    expected_logits = [2.0, 3.0, -4.0]
    relaxed = relaxed_onehot_categorical.RelaxedOneHotCategorical(
        1.0, expected_logits)
    with self.cached_session():
        # check p for ExpRelaxed base distribution
        base_logits = relaxed._distribution.logits
        self.assertAllClose(expected_logits, base_logits.eval())
        self.assertAllEqual([3], base_logits.get_shape())
def make_relaxed_categorical(batch_shape, num_classes, dtype=dtypes.float32):
    """Build a RelaxedOneHotCategorical with random logits and temperatures.

    Args:
        batch_shape: Iterable of ints; shape of the temperature tensor and
            leading dimensions of the logits tensor.
        num_classes: Size of the trailing dimension of the logits tensor.
        dtype: dtype used for the random logits and temperatures, and
            forwarded to the distribution constructor.

    Returns:
        The `RelaxedOneHotCategorical` built from those random tensors.
    """
    batch_dims = list(batch_shape)
    # Sampled between -10 and 10, then shifted down by 50.
    logits = random_ops.random_uniform(
        batch_dims + [num_classes], -10, 10, dtype=dtype) - 50.
    # Draw temperatures in the requested dtype (previously hard-coded to
    # float32) so they match the logits when a non-default dtype is passed;
    # mixing float32 temperatures with e.g. float64 logits is expected to
    # fail with a dtype mismatch once the two are combined.
    temperatures = random_ops.random_uniform(
        batch_dims, 0.1, 10, dtype=dtype)
    return relaxed_onehot_categorical.RelaxedOneHotCategorical(
        temperatures, logits, dtype=dtype)
def testSample(self):
    """Samples should be shaped as the sample shape plus the logits shape."""
    temperature = 1.4
    with self.cached_session():
        # Each case pairs the logits with the shape of a single draw.
        cases = (
            # single logit vector
            ([.3, .1, .4], [3]),
            # multiple distributions (rank-2 logits)
            ([[2.0, 3.0, -4.0], [.3, .1, .4]], [2, 3]),
            # multiple distributions (rank-3 logits with a size-1 dimension)
            (np.random.uniform(size=(4, 1, 3)).astype(np.float32),
             [4, 1, 3]),
        )
        for logits, single_draw_shape in cases:
            dist = relaxed_onehot_categorical.RelaxedOneHotCategorical(
                temperature, logits)
            self.assertAllEqual(
                single_draw_shape, dist.sample().eval().shape)
            self.assertAllEqual(
                [5] + single_draw_shape, dist.sample(5).eval().shape)
def testPdf(self):
    """Compare `prob` on a drawn sample against a NumPy closed form."""

    def analytical_pdf(x, temperature, logits):
        """Analytical density of RelaxedOneHotCategorical, one per row."""
        # Column vector so per-row temperatures broadcast across classes.
        temperature = np.reshape(temperature, (-1, 1))
        if x.ndim == 1:
            x = x[np.newaxis, :]
        num_classes = logits.shape[1]
        exp_logits = np.exp(logits)
        p = exp_logits / np.sum(exp_logits, axis=1, keepdims=True)
        norm_const = (
            gamma(num_classes) * np.power(temperature, num_classes - 1))
        weighted_sum = np.sum(
            p / np.power(x, temperature), axis=1, keepdims=True)
        weighted_prod = np.prod(
            p / np.power(x, temperature + 1), axis=1, keepdims=True)
        return (norm_const
                * np.power(weighted_sum, -num_classes)
                * weighted_prod)

    with self.cached_session():
        cases = (
            # scalar temperature, single row of logits
            (.4, np.array([[.3, .1, .4]]).astype(np.float32)),
            # variable batch size
            (np.array([0.4, 2.3]).astype(np.float32),
             np.array([[.3, .1, .4], [.6, -.1, 2.]]).astype(np.float32)),
        )
        for temperature, logits in cases:
            dist = relaxed_onehot_categorical.RelaxedOneHotCategorical(
                temperature, logits)
            x = dist.sample().eval()
            pdf = dist.prob(x).eval()
            expected_pdf = analytical_pdf(x, temperature, logits)
            self.assertAllClose(expected_pdf.flatten(), pdf, rtol=1e-4)