  def testCategoricalCategoricalKL(self):
    def np_softmax(logits):
      exp_logits = np.exp(logits)
      return exp_logits / exp_logits.sum(axis=-1, keepdims=True)

    with self.cached_session() as sess:
      for categories in [2, 10]:
        for batch_size in [1, 2]:
          p_logits = self._rng.random_sample((batch_size, categories))
          q_logits = self._rng.random_sample((batch_size, categories))
          p = onehot_categorical.OneHotCategorical(logits=p_logits)
          q = onehot_categorical.OneHotCategorical(logits=q_logits)
          prob_p = np_softmax(p_logits)
          prob_q = np_softmax(q_logits)
          kl_expected = np.sum(
              prob_p * (np.log(prob_p) - np.log(prob_q)), axis=-1)

          kl_actual = kullback_leibler.kl_divergence(p, q)
          kl_same = kullback_leibler.kl_divergence(p, p)
          x = p.sample(int(2e4), seed=0)
          x = math_ops.cast(x, dtype=dtypes.float32)
          # Compute empirical KL(p||q).
          kl_sample = math_ops.reduce_mean(p.log_prob(x) - q.log_prob(x), 0)

          [kl_sample_, kl_actual_, kl_same_] = sess.run([kl_sample, kl_actual,
                                                         kl_same])
          self.assertEqual(kl_actual.get_shape(), (batch_size,))
          self.assertAllClose(kl_same_, np.zeros_like(kl_expected))
          self.assertAllClose(kl_actual_, kl_expected, atol=0., rtol=1e-6)
          self.assertAllClose(kl_sample_, kl_expected, atol=1e-2, rtol=0.)
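
  # Note: the closed-form KL between two categoricals over the same K classes
  # is KL(p||q) = sum_k p_k * (log p_k - log q_k); `kl_expected` above computes
  # exactly this, and `kl_sample` cross-checks it with a Monte Carlo estimate.
  # A standalone NumPy sketch of the identity (illustrative only):
  #
  #   p = np.array([0.1, 0.9])
  #   q = np.array([0.5, 0.5])
  #   np.sum(p * (np.log(p) - np.log(q)))  # ~0.368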

  def testSampleUnbiasedScalarBatch(self):
    with self.cached_session() as sess:
      logits = self._rng.rand(3).astype(np.float32)
      dist = onehot_categorical.OneHotCategorical(logits=logits)
      n = int(1e4)
      x = dist.sample(n, seed=0)
      x = math_ops.cast(x, dtype=dtypes.float32)
      sample_mean = math_ops.reduce_mean(x, 0)  # elementwise mean
      x_centered = x - sample_mean
      sample_covariance = math_ops.matmul(
          x_centered, x_centered, adjoint_a=True) / n
      [
          sample_mean_,
          sample_covariance_,
          actual_mean_,
          actual_covariance_,
      ] = sess.run([
          sample_mean,
          sample_covariance,
          dist.probs,
          dist.covariance(),
      ])
      self.assertAllEqual([3], sample_mean.get_shape())
      self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.1)
      self.assertAllEqual([3, 3], sample_covariance.get_shape())
      self.assertAllClose(
          actual_covariance_, sample_covariance_, atol=0., rtol=0.1)
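
  # Note: for a one-hot draw x with class probabilities p, E[x] = p and
  # Cov(x) = diag(p) - p p^T, which is what `dist.covariance()` is expected
  # to return; the test above compares it to the empirical covariance.
  # A standalone NumPy sketch of the identity (illustrative only):
  #
  #   p = np.array([0.2, 0.3, 0.5])
  #   cov = np.diag(p) - np.outer(p, p)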

  def testSampleUnbiasedNonScalarBatch(self):
    with self.cached_session() as sess:
      logits = self._rng.rand(4, 3, 2).astype(np.float32)
      dist = onehot_categorical.OneHotCategorical(logits=logits)
      n = int(3e3)
      x = dist.sample(n, seed=0)
      x = math_ops.cast(x, dtype=dtypes.float32)
      sample_mean = math_ops.reduce_mean(x, 0)
      # Put the sample dimension last so matmul computes batched outer
      # products: [4, 3, 2, n] @ [4, 3, n, 2] -> [4, 3, 2, 2].
      x_centered = array_ops.transpose(x - sample_mean, [1, 2, 3, 0])
      sample_covariance = math_ops.matmul(
          x_centered, x_centered, adjoint_b=True) / n
      [
          sample_mean_,
          sample_covariance_,
          actual_mean_,
          actual_covariance_,
      ] = sess.run([
          sample_mean,
          sample_covariance,
          dist.probs,
          dist.covariance(),
      ])
      self.assertAllEqual([4, 3, 2], sample_mean.get_shape())
      self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.07)
      self.assertAllEqual([4, 3, 2, 2], sample_covariance.get_shape())
      self.assertAllClose(
          actual_covariance_, sample_covariance_, atol=0., rtol=0.10)

  def testEntropyNoBatch(self):
    logits = np.log([0.2, 0.8]) - 50.
    dist = onehot_categorical.OneHotCategorical(logits)
    with self.cached_session():
      self.assertAllClose(
          dist.entropy().eval(),
          -(0.2 * np.log(0.2) + 0.8 * np.log(0.8)))
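
  # Note: the entropy of a categorical is H(p) = -sum_k p_k * log p_k, and
  # shifting every logit by the same constant (here -50) leaves the normalized
  # distribution, and hence the entropy, unchanged.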

  def testLogits(self):
    p = np.array([0.2, 0.8], dtype=np.float32)
    logits = np.log(p) - 50.
    dist = onehot_categorical.OneHotCategorical(logits=logits)
    with self.cached_session():
      self.assertAllEqual([2], dist.probs.get_shape())
      self.assertAllEqual([2], dist.logits.get_shape())
      self.assertAllClose(dist.probs.eval(), p)
      self.assertAllClose(dist.logits.eval(), logits)
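
  # Note: `probs` is the softmax of `logits`, so subtracting a constant (-50)
  # from every logit leaves the normalized probabilities untouched and
  # `dist.probs` recovers `p` up to float32 precision.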

  def testPmf(self):
    # Check that the probabilities of samples match their class probabilities.
    with self.cached_session():
      logits = self._rng.random_sample(size=(8, 2, 10))
      prob = np.exp(logits) / np.sum(np.exp(logits), axis=-1, keepdims=True)
      dist = onehot_categorical.OneHotCategorical(logits=logits)
      np_sample = dist.sample().eval()
      np_prob = dist.prob(np_sample).eval()
      # Use the one-hot sample as a boolean mask to pick out, for each batch
      # entry, the probability of the sampled class.
      expected_prob = prob[np_sample.astype(np.bool_)]
      self.assertAllClose(expected_prob, np_prob.flatten())

  def testSampleWithSampleShape(self):
    with self.cached_session():
      probs = [[[0.2, 0.8], [0.4, 0.6]]]
      dist = onehot_categorical.OneHotCategorical(math_ops.log(probs) - 50.)
      samples = dist.sample((100, 100), seed=123)
      prob = dist.prob(samples)
      prob_val = prob.eval()
      self.assertAllClose([0.2**2 + 0.8**2], [prob_val[:, :, :, 0].mean()],
                          atol=1e-2)
      self.assertAllClose([0.4**2 + 0.6**2], [prob_val[:, :, :, 1].mean()],
                          atol=1e-2)
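
  # Note: for a batch entry with class probabilities p, the probability
  # assigned to a random sample is p_i with probability p_i, so its
  # expectation is sum_i p_i**2 (e.g. 0.2**2 + 0.8**2 = 0.68 above).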

  def testSample(self):
    with self.cached_session():
      probs = [[[0.2, 0.8], [0.4, 0.6]]]
      dist = onehot_categorical.OneHotCategorical(math_ops.log(probs) - 50.)
      n = 100
      samples = dist.sample(n, seed=123)
      self.assertEqual(samples.dtype, dtypes.int32)
      sample_values = samples.eval()
      self.assertAllEqual([n, 1, 2, 2], sample_values.shape)
      self.assertFalse(np.any(sample_values < 0))
      self.assertFalse(np.any(sample_values > 1))

  def testUnknownShape(self):
    with self.cached_session():
      logits = array_ops.placeholder(dtype=dtypes.float32)
      dist = onehot_categorical.OneHotCategorical(logits)
      sample = dist.sample()
      # Will sample class 1.
      sample_value = sample.eval(feed_dict={logits: [-1000.0, 1000.0]})
      self.assertAllEqual([0, 1], sample_value)
      # Batch entry 0 will sample class 1, batch entry 1 will sample class 0.
      sample_value_batch = sample.eval(
          feed_dict={logits: [[-1000.0, 1000.0], [1000.0, -1000.0]]})
      self.assertAllEqual([[0, 1], [1, 0]], sample_value_batch)

  def testP(self):
    p = [0.2, 0.8]
    dist = onehot_categorical.OneHotCategorical(probs=p)
    with self.cached_session():
      self.assertAllClose(p, dist.probs.eval())
      self.assertAllEqual([2], dist.logits.get_shape())


def make_onehot_categorical(batch_shape, num_classes, dtype=dtypes.int32):
  logits = random_ops.random_uniform(
      list(batch_shape) + [num_classes], -10, 10, dtype=dtypes.float32) - 50.
  return onehot_categorical.OneHotCategorical(logits, dtype=dtype)
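

# A minimal usage sketch for `make_onehot_categorical` (illustrative only;
# the shapes and seed below are assumptions, not taken from the tests above):
#
#   dist = make_onehot_categorical(batch_shape=[5], num_classes=3)
#   sample = dist.sample(seed=42)  # one int32 one-hot draw of shape [5, 3]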