Example No. 1
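These snippets are test methods from TensorFlow's OneHotCategorical test suite (TF 1.x, contrib era), shown without their enclosing test class. They assume roughly the following imports and fixture; treat this as a sketch, since the exact module paths moved around across TF 1.x releases:

import numpy as np

from tensorflow.contrib.distributions.python.ops import onehot_categorical
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import kullback_leibler

# In the test class's setUp: self._rng = np.random.RandomState(42)
# (the seed value here is a placeholder, not taken from the original file).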
 def testSampleUnbiasedNonScalarBatch(self):
     with self.test_session() as sess:
         logits = self._rng.rand(4, 3, 2).astype(np.float32)
         dist = onehot_categorical.OneHotCategorical(logits=logits)
         n = int(3e3)
         x = dist.sample(n, seed=0)
         x = math_ops.cast(x, dtype=dtypes.float32)
         sample_mean = math_ops.reduce_mean(x, 0)
         x_centered = array_ops.transpose(x - sample_mean, [1, 2, 3, 0])
         sample_covariance = math_ops.matmul(
             x_centered, x_centered, adjoint_b=True) / n
         [
             sample_mean_,
             sample_covariance_,
             actual_mean_,
             actual_covariance_,
         ] = sess.run([
             sample_mean,
             sample_covariance,
             dist.probs,
             dist.covariance(),
         ])
         self.assertAllEqual([4, 3, 2], sample_mean.get_shape())
         self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.07)
         self.assertAllEqual([4, 3, 2, 2], sample_covariance.get_shape())
         self.assertAllClose(actual_covariance_,
                             sample_covariance_,
                             atol=0.,
                             rtol=0.10)
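For a one-hot categorical with class probabilities p, the mean is p and the covariance is diag(p) - p p^T, which is what the sampled moments above converge to. A minimal NumPy-only sketch of the same check (names and tolerance here are illustrative, not from the test):

import numpy as np

rng = np.random.RandomState(0)
p = np.array([0.2, 0.3, 0.5])            # class probabilities
x = rng.multinomial(1, p, size=3000)     # one-hot draws, shape (3000, 3)
x_centered = x - x.mean(axis=0)
sample_cov = x_centered.T.dot(x_centered) / len(x)
true_cov = np.diag(p) - np.outer(p, p)   # diag(p) - p p^T
assert np.allclose(sample_cov, true_cov, atol=0.02)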
Example No. 2
 def testSampleUnbiasedScalarBatch(self):
     with self.test_session() as sess:
         logits = self._rng.rand(3).astype(np.float32)
         dist = onehot_categorical.OneHotCategorical(logits=logits)
         n = int(1e4)
         x = dist.sample(n, seed=0)
         x = math_ops.cast(x, dtype=dtypes.float32)
         sample_mean = math_ops.reduce_mean(x, 0)  # elementwise mean
         x_centered = x - sample_mean
         sample_covariance = math_ops.matmul(
             x_centered, x_centered, adjoint_a=True) / n
         [
             sample_mean_,
             sample_covariance_,
             actual_mean_,
             actual_covariance_,
         ] = sess.run([
             sample_mean,
             sample_covariance,
             dist.probs,
             dist.covariance(),
         ])
         self.assertAllEqual([3], sample_mean.get_shape())
         self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.1)
         self.assertAllEqual([3, 3], sample_covariance.get_shape())
         self.assertAllClose(actual_covariance_,
                             sample_covariance_,
                             atol=0.,
                             rtol=0.1)
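The scalar-batch variant performs the same moment check without the batch transpose: with `adjoint_a=True`, the `matmul` computes x_centered^T . x_centered, so `sample_covariance` is the k-by-k average of (x - x_bar)(x - x_bar)^T over the n draws, compared against `dist.covariance()`, i.e. diag(p) - p p^T.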
Example No. 3
 def testEntropyWithBatch(self):
     logits = np.log([[0.2, 0.8], [0.6, 0.4]]) - 50.
     dist = onehot_categorical.OneHotCategorical(logits)
     with self.test_session():
         self.assertAllClose(dist.entropy().eval(), [
             -(0.2 * np.log(0.2) + 0.8 * np.log(0.8)),
             -(0.6 * np.log(0.6) + 0.4 * np.log(0.4))
         ])
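Subtracting 50 from every logit leaves the distribution unchanged, because softmax is shift-invariant; the expected entropies are therefore just -sum_i p_i log p_i per batch row. A quick NumPy check of the reference values (a sketch, independent of TensorFlow):

import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # shift-invariance: the -50 cancels here
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

logits = np.log([[0.2, 0.8], [0.6, 0.4]]) - 50.
p = softmax(logits)
entropy = -(p * np.log(p)).sum(axis=-1)
print(entropy)  # approximately [0.5004, 0.6730]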
Example No. 4
 def testLogits(self):
     p = np.array([0.2, 0.8], dtype=np.float32)
     logits = np.log(p) - 50.
     dist = onehot_categorical.OneHotCategorical(logits=logits)
     with self.test_session():
         self.assertAllEqual([2], dist.probs.get_shape())
         self.assertAllEqual([2], dist.logits.get_shape())
         self.assertAllClose(dist.probs.eval(), p)
         self.assertAllClose(dist.logits.eval(), logits)
Example No. 5
 def testPmf(self):
     # Check that the probabilities of samples correspond to their class probabilities.
     with self.test_session():
         logits = self._rng.random_sample(size=(8, 2, 10))
         prob = np.exp(logits) / np.sum(
             np.exp(logits), axis=-1, keepdims=True)
         dist = onehot_categorical.OneHotCategorical(logits=logits)
         np_sample = dist.sample().eval()
         np_prob = dist.prob(np_sample).eval()
         expected_prob = prob[np_sample.astype(bool)]
         self.assertAllClose(expected_prob, np_prob.flatten())
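The assertion encodes the pmf of a one-hot vector x: p(x) = prod_i p_i**x_i, which reduces to the probability of the hot class. A NumPy-only restatement (illustrative names, not the test's code):

import numpy as np

rng = np.random.RandomState(0)
logits = rng.random_sample((8, 2, 10))
p = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
x = np.eye(10)[p.argmax(axis=-1)]      # some one-hot batch, shape (8, 2, 10)
pmf = (p * x).sum(axis=-1)             # probability of the hot class
assert np.allclose(pmf, p[x.astype(bool)].reshape(8, 2))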
Example No. 6
 def testSample(self):
     with self.test_session():
         probs = [[[0.2, 0.8], [0.4, 0.6]]]
         dist = onehot_categorical.OneHotCategorical(
             math_ops.log(probs) - 50.)
         n = 100
         samples = dist.sample(n, seed=123)
         self.assertEqual(samples.dtype, dtypes.int32)
         sample_values = samples.eval()
         self.assertAllEqual([n, 1, 2, 2], sample_values.shape)
         self.assertFalse(np.any(sample_values < 0))
         self.assertFalse(np.any(sample_values > 1))
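Shape arithmetic worth noting: `sample(n)` prepends the number of draws, so the result shape is [n] + batch_shape + event_shape = [100, 1, 2, 2], and since draws are one-hot every entry is 0 or 1, which is exactly what the last two assertions verify.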
Example No. 7
 def testUnknownShape(self):
     with self.test_session():
         logits = array_ops.placeholder(dtype=dtypes.float32)
         dist = onehot_categorical.OneHotCategorical(logits)
         sample = dist.sample()
         # Will sample class 1.
         sample_value = sample.eval(feed_dict={logits: [-1000.0, 1000.0]})
         self.assertAllEqual([0, 1], sample_value)
         # Batch entry 0 will sample class 1, batch entry 1 will sample class 0.
         sample_value_batch = sample.eval(
             feed_dict={logits: [[-1000.0, 1000.0], [1000.0, -1000.0]]})
         self.assertAllEqual([[0, 1], [1, 0]], sample_value_batch)
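Because the feed logits are +/-1000, the softmax is numerically a point mass, so sampling becomes deterministic and exact equality assertions are safe even though the placeholder's shape is unknown until `eval` time.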
Example No. 8
 def testCategoricalCategoricalKL(self):
     def np_softmax(logits):
         exp_logits = np.exp(logits)
         return exp_logits / exp_logits.sum(axis=-1, keepdims=True)

     with self.test_session() as sess:
         for categories in [2, 10]:
             for batch_size in [1, 2]:
                 p_logits = self._rng.random_sample((batch_size, categories))
                 q_logits = self._rng.random_sample((batch_size, categories))
                 p = onehot_categorical.OneHotCategorical(logits=p_logits)
                 q = onehot_categorical.OneHotCategorical(logits=q_logits)
                 prob_p = np_softmax(p_logits)
                 prob_q = np_softmax(q_logits)
                 kl_expected = np.sum(
                     prob_p * (np.log(prob_p) - np.log(prob_q)), axis=-1)

                 kl_actual = kullback_leibler.kl(p, q)
                 kl_same = kullback_leibler.kl(p, p)
                 x = p.sample(int(2e4), seed=0)
                 x = math_ops.cast(x, dtype=dtypes.float32)
                 # Compute the empirical estimate of KL(p || q).
                 kl_sample = math_ops.reduce_mean(
                     p.log_prob(x) - q.log_prob(x), 0)

                 [kl_sample_, kl_actual_, kl_same_] = sess.run(
                     [kl_sample, kl_actual, kl_same])
                 self.assertEqual(kl_actual.get_shape(), (batch_size,))
                 self.assertAllClose(kl_same_, np.zeros_like(kl_expected))
                 self.assertAllClose(
                     kl_actual_, kl_expected, atol=0., rtol=1e-6)
                 self.assertAllClose(
                     kl_sample_, kl_expected, atol=1e-2, rtol=0.)
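The closed form under test is KL(p||q) = sum_i p_i * (log p_i - log q_i), with KL(p||p) = 0, and the Monte Carlo estimate E_{x~p}[log p(x) - log q(x)] should match it up to sampling error. A standalone NumPy sketch of the identity (not the test's code):

import numpy as np

def np_softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.RandomState(0)
prob_p = np_softmax(rng.random_sample((2, 10)))
prob_q = np_softmax(rng.random_sample((2, 10)))
kl_pq = (prob_p * (np.log(prob_p) - np.log(prob_q))).sum(axis=-1)
kl_pp = (prob_p * (np.log(prob_p) - np.log(prob_p))).sum(axis=-1)
assert (kl_pq >= 0).all() and np.allclose(kl_pp, 0.)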
Example No. 9
 def testSampleWithSampleShape(self):
     with self.test_session():
         probs = [[[0.2, 0.8], [0.4, 0.6]]]
         dist = onehot_categorical.OneHotCategorical(
             math_ops.log(probs) - 50.)
         samples = dist.sample((100, 100), seed=123)
         prob = dist.prob(samples)
         prob_val = prob.eval()
         self.assertAllClose([0.2**2 + 0.8**2],
                             [prob_val[:, :, :, 0].mean()],
                             atol=1e-2)
         self.assertAllClose([0.4**2 + 0.6**2],
                             [prob_val[:, :, :, 1].mean()],
                             atol=1e-2)
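The reference values come from E[p(X)] = sum_i p_i**2 for X ~ Categorical(p): 0.2**2 + 0.8**2 = 0.68 for the first component's distribution and 0.4**2 + 0.6**2 = 0.52 for the second. A quick Monte Carlo confirmation (a sketch, assuming only NumPy):

import numpy as np

rng = np.random.RandomState(0)
p = np.array([0.2, 0.8])
x = rng.multinomial(1, p, size=10000)      # one-hot draws
mean_prob = (x * p).sum(axis=-1).mean()    # estimates E[p(X)]
assert np.isclose(mean_prob, (p ** 2).sum(), atol=1e-2)  # 0.2**2 + 0.8**2 = 0.68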
Example No. 10
 def dist(self):
     return onehot_categorical.OneHotCategorical(logits=self._logits)
Example No. 11
 def testP(self):
     p = [0.2, 0.8]
     dist = onehot_categorical.OneHotCategorical(probs=p)
     with self.test_session():
         self.assertAllClose(p, dist.probs.eval())
         self.assertAllEqual([2], dist.logits.get_shape())
Example No. 12
def make_onehot_categorical(batch_shape, num_classes, dtype=dtypes.int32):
    logits = random_ops.random_uniform(
        list(batch_shape) + [num_classes], -10, 10, dtype=dtypes.float32) - 50.
    return onehot_categorical.OneHotCategorical(logits, dtype=dtype)

 def testEntropyNoBatch(self):
     logits = np.log([0.2, 0.8]) - 50.
     dist = onehot_categorical.OneHotCategorical(logits)
     with self.cached_session():
         self.assertAllClose(dist.entropy().eval(),
                             -(0.2 * np.log(0.2) + 0.8 * np.log(0.8)))
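A hypothetical use of the `make_onehot_categorical` helper, assuming TF 1.x graph mode; the `tf.Session` usage and the shape comment are mine, not from the original tests:

import tensorflow as tf

dist = make_onehot_categorical(batch_shape=[2], num_classes=3)
with tf.Session() as sess:
    samples = sess.run(dist.sample(5, seed=0))
    print(samples.shape)  # (5, 2, 3): [n] + batch_shape + [num_classes]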