Example 1: KL divergence between two Categorical distributions, checked against a NumPy softmax reference.
    def testCategoricalCategoricalKL(self):
        def np_softmax(logits):
            exp_logits = np.exp(logits)
            return exp_logits / exp_logits.sum(axis=-1, keepdims=True)

        with self.test_session() as sess:
            for categories in [2, 4]:
                for batch_size in [1, 10]:
                    a_logits = np.random.randn(batch_size, categories)
                    b_logits = np.random.randn(batch_size, categories)

                    a = categorical.Categorical(logits=a_logits)
                    b = categorical.Categorical(logits=b_logits)

                    kl = kullback_leibler.kl(a, b)
                    kl_val = sess.run(kl)
                    # Make sure KL(a||a) is 0
                    kl_same = sess.run(kullback_leibler.kl(a, a))

                    prob_a = np_softmax(a_logits)
                    prob_b = np_softmax(b_logits)
                    kl_expected = np.sum(prob_a *
                                         (np.log(prob_a) - np.log(prob_b)),
                                         axis=-1)

                    self.assertEqual(kl.get_shape(), (batch_size, ))
                    self.assertAllClose(kl_val, kl_expected)
                    self.assertAllClose(kl_same, np.zeros_like(kl_expected))
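For reference, the quantity the NumPy check above computes is the standard categorical KL divergence (the mathematical identity, not a TF API detail):

\mathrm{KL}(a \,\|\, b) = \sum_i p_a(i)\,\bigl(\log p_a(i) - \log p_b(i)\bigr),
\qquad p(i) = \frac{e^{\ell_i}}{\sum_j e^{\ell_j}}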
Example 2: the gradient of Categorical's entropy with respect to the logits matches the gradient of the explicit softmax entropy.
    def testEntropyGradient(self):
        with self.test_session() as sess:
            logits = constant_op.constant([[1., 2., 3.], [2., 5., 1.]])

            probabilities = nn_ops.softmax(logits)
            log_probabilities = nn_ops.log_softmax(logits)
            true_entropy = -math_ops.reduce_sum(
                probabilities * log_probabilities, axis=-1)

            categorical_distribution = categorical.Categorical(p=probabilities)
            categorical_entropy = categorical_distribution.entropy()

            # Gradients of both entropy expressions with respect to the logits.
            true_entropy_g = gradients_impl.gradients(true_entropy, [logits])
            categorical_entropy_g = gradients_impl.gradients(
                categorical_entropy, [logits])

            res = sess.run({
                "true_entropy": true_entropy,
                "categorical_entropy": categorical_entropy,
                "true_entropy_g": true_entropy_g,
                "categorical_entropy_g": categorical_entropy_g
            })
            self.assertAllClose(res["true_entropy"],
                                res["categorical_entropy"])
            self.assertAllClose(res["true_entropy_g"],
                                res["categorical_entropy_g"])
Example 3: broadcasting of prob() across sample shapes and the distribution's batch shape.
    def testLogPMFBroadcasting(self):
        with self.test_session():
            histograms = [[[0.2, 0.8], [0.4, 0.6]]]
            dist = categorical.Categorical(math_ops.log(histograms) - 50.)

            prob = dist.prob(1)
            self.assertAllClose([[0.8, 0.6]], prob.eval())

            prob = dist.prob([1])
            self.assertAllClose([[0.8, 0.6]], prob.eval())

            prob = dist.prob([0, 1])
            self.assertAllClose([[0.2, 0.6]], prob.eval())

            prob = dist.prob([[0, 1]])
            self.assertAllClose([[0.2, 0.6]], prob.eval())

            prob = dist.prob([[[0, 1]]])
            self.assertAllClose([[[0.2, 0.6]]], prob.eval())

            prob = dist.prob([[1, 0], [0, 1]])
            self.assertAllClose([[0.8, 0.4], [0.2, 0.6]], prob.eval())

            prob = dist.prob([[[1, 1], [1, 0]], [[1, 0], [0, 1]]])
            self.assertAllClose(
                [[[0.8, 0.6], [0.8, 0.4]], [[0.8, 0.4], [0.2, 0.6]]],
                prob.eval())
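A NumPy analogue of the [[1, 0], [0, 1]] case above, assuming only NumPy (the variable names are illustrative, not the TF code path): the index array broadcasts against the batch shape [1, 2], and each index picks one class probability out of the matching histogram.

import numpy as np

histograms = np.array([[[0.2, 0.8], [0.4, 0.6]]])  # batch shape [1, 2], 2 classes
idx = np.array([[1, 0], [0, 1]])                   # sample indices, shape [2, 2]
hist = np.broadcast_to(histograms, (2, 2, 2))      # broadcast batch against indices
prob = np.take_along_axis(hist, idx[..., None], axis=-1)[..., 0]
print(prob)                                        # [[0.8, 0.4], [0.2, 0.6]]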
Example 4: entropy of a batch of two distributions.
    def testEntropyWithBatch(self):
        logits = np.log([[0.2, 0.8], [0.6, 0.4]]) - 50.
        dist = categorical.Categorical(logits)
        with self.test_session():
            self.assertAllClose(dist.entropy().eval(), [
                -(0.2 * np.log(0.2) + 0.8 * np.log(0.8)),
                -(0.6 * np.log(0.6) + 0.4 * np.log(0.4))
            ])
Example 5: constructing from logits and reading back both p and logits.
    def testLogits(self):
        p = np.array([0.2, 0.8], dtype=np.float32)
        logits = np.log(p) - 50.
        dist = categorical.Categorical(logits=logits)
        with self.test_session():
            self.assertAllEqual([2], dist.p.get_shape())
            self.assertAllEqual([2], dist.logits.get_shape())
            self.assertAllClose(dist.p.eval(), p)
            self.assertAllClose(dist.logits.eval(), logits)
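The "- 50." shift that appears throughout these tests does not change the distribution: softmax is invariant to adding a constant to every logit. A quick NumPy check (not part of the test):

import numpy as np

p = np.array([0.2, 0.8], dtype=np.float32)
logits = np.log(p) - 50.
q = np.exp(logits - logits.max())
q = q / q.sum()
print(np.allclose(p, q))  # True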
Example 6: scheduled sampling in a seq2seq training helper, drawing from Categorical(logits=outputs) for selected batch entries and emitting -1 elsewhere.
    def sample(self, time, outputs, state, name=None):
        with ops.name_scope(name, "ScheduledEmbeddingTrainingHelperSample",
                            [time, outputs, state]):
            # Return -1s where we did not sample, and sample_ids elsewhere
            select_sample_noise = random_ops.random_uniform(
                [self.batch_size], seed=self._scheduling_seed)
            select_sample = (self._sampling_probability > select_sample_noise)
            sample_id_sampler = categorical.Categorical(logits=outputs)
            return array_ops.where(select_sample,
                                   sample_id_sampler.sample(seed=self._seed),
                                   array_ops.tile([-1], [self.batch_size]))
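A hedged NumPy sketch of the same select-or-(-1) logic, with a Gumbel-max draw standing in for Categorical.sample(); all names here (batch_size, vocab, sampling_probability) are illustrative and not part of the helper's API:

import numpy as np

rng = np.random.default_rng(0)
batch_size, vocab = 4, 3
outputs = rng.normal(size=(batch_size, vocab))     # decoder logits at this step
sampling_probability = 0.5
select = rng.uniform(size=batch_size) < sampling_probability
gumbel = -np.log(-np.log(rng.uniform(size=outputs.shape)))
sample_ids = np.argmax(outputs + gumbel, axis=-1)  # Gumbel-max categorical draw
print(np.where(select, sample_ids, -1))            # -1 marks "did not sample"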
Example 7: log_prob() output shapes for a distribution with no batch dimension.
    def testLogPMFShapeNoBatch(self):
        histograms = [0.2, 0.8]
        dist = categorical.Categorical(math_ops.log(histograms))

        log_prob = dist.log_prob(0)
        self.assertEqual(0, log_prob.get_shape().ndims)
        self.assertAllEqual([], log_prob.get_shape())

        log_prob = dist.log_prob([[[1, 1], [1, 0]], [[1, 0], [0, 1]]])
        self.assertEqual(3, log_prob.get_shape().ndims)
        self.assertAllEqual([2, 2, 2], log_prob.get_shape())
Example 8: sampling through a placeholder whose shape is unknown at graph construction time.
    def testUnknownShape(self):
        with self.test_session():
            logits = array_ops.placeholder(dtype=dtypes.float32)
            dist = categorical.Categorical(logits)
            sample = dist.sample()
            # Will sample class 1.
            sample_value = sample.eval(feed_dict={logits: [-1000.0, 1000.0]})
            self.assertEqual(1, sample_value)

            # Batch entry 0 will sample class 1, batch entry 1 will sample class 0.
            sample_value_batch = sample.eval(
                feed_dict={logits: [[-1000.0, 1000.0], [1000.0, -1000.0]]})
            self.assertAllEqual([1, 0], sample_value_batch)
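The extreme logits make each draw deterministic: softmax([-1000., 1000.]) underflows to [0., 1.] in floating point, so index 1 is certain. A quick NumPy check, independent of TF:

import numpy as np

logits = np.array([-1000.0, 1000.0])
p = np.exp(logits - logits.max())
p = p / p.sum()
print(p)  # [0. 1.] to machine precision, so every draw returns index 1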
Example 9: log_prob() output shapes for a batched distribution.
    def testLogPMFShape(self):
        with self.test_session():
            # shape [1, 2, 2]
            histograms = [[[0.2, 0.8], [0.4, 0.6]]]
            dist = categorical.Categorical(math_ops.log(histograms))

            log_prob = dist.log_prob([0, 1])
            self.assertEqual(2, log_prob.get_shape().ndims)
            self.assertAllEqual([1, 2], log_prob.get_shape())

            log_prob = dist.log_prob([[[1, 1], [1, 0]], [[1, 0], [0, 1]]])
            self.assertEqual(3, log_prob.get_shape().ndims)
            self.assertAllEqual([2, 2, 2], log_prob.get_shape())
Example 10: sampling with a multi-dimensional sample shape and checking the mean probability of the drawn classes.
    def testSampleWithSampleShape(self):
        with self.test_session():
            histograms = [[[0.2, 0.8], [0.4, 0.6]]]
            dist = categorical.Categorical(math_ops.log(histograms) - 50.)
            samples = dist.sample((100, 100), seed=123)
            prob = dist.prob(samples)
            prob_val = prob.eval()
            self.assertAllClose([0.2**2 + 0.8**2],
                                [prob_val[:, :, :, 0].mean()],
                                atol=1e-2)
            self.assertAllClose([0.4**2 + 0.6**2],
                                [prob_val[:, :, :, 1].mean()],
                                atol=1e-2)
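The asserted means follow from E[p(X)] = sum_i p_i**2 for X ~ Categorical(p); a quick NumPy check of the two constants (not part of the test):

import numpy as np

for p in ([0.2, 0.8], [0.4, 0.6]):
    print(np.sum(np.square(p)))  # 0.68, then 0.52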
Example 11: empirical sample frequencies match the underlying histograms.
    def testSample(self):
        with self.test_session():
            histograms = [[[0.2, 0.8], [0.4, 0.6]]]
            dist = categorical.Categorical(math_ops.log(histograms) - 50.)
            n = 10000
            samples = dist.sample(n, seed=123)
            samples.set_shape([n, 1, 2])
            self.assertEqual(samples.dtype, dtypes.int32)
            sample_values = samples.eval()
            self.assertFalse(np.any(sample_values < 0))
            self.assertFalse(np.any(sample_values > 1))
            self.assertAllClose([[0.2, 0.4]],
                                np.mean(sample_values == 0, axis=0),
                                atol=1e-2)
            self.assertAllClose([[0.8, 0.6]],
                                np.mean(sample_values == 1, axis=0),
                                atol=1e-2)
Example 12: pmf() of a single unbatched distribution.
    def testPMFNoBatch(self):
        histograms = [0.2, 0.8]
        dist = categorical.Categorical(math_ops.log(histograms) - 50.)
        with self.test_session():
            self.assertAllClose(dist.pmf(0).eval(), 0.2)
Example 13: constructing from probabilities p and checking the logits' shape.
    def testP(self):
        p = [0.2, 0.8]
        dist = categorical.Categorical(p=p)
        with self.test_session():
            self.assertAllClose(p, dist.p.eval())
            self.assertAllEqual([2], dist.logits.get_shape())
Example 14: a helper that builds a Categorical with random logits for a given batch shape.
def make_categorical(batch_shape, num_classes, dtype=dtypes.int32):
    logits = random_ops.random_uniform(
        list(batch_shape) + [num_classes], -10, 10, dtype=dtypes.float32) - 50.
    return categorical.Categorical(logits, dtype=dtype)
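A hedged usage sketch for this helper, assuming the same contrib imports used elsewhere on this page; the call below is illustrative, not part of the original snippet:

dist = make_categorical(batch_shape=[3, 4], num_classes=5)
samples = dist.sample(10)  # int32 class indices with shape [10, 3, 4]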
Example 15: log_pmf() for a batch of two distributions.
    def testLogPMF(self):
        logits = np.log([[0.2, 0.8], [0.6, 0.4]]) - 50.
        dist = categorical.Categorical(logits)
        with self.test_session():
            self.assertAllClose(
                dist.log_pmf([0, 1]).eval(), np.log([0.2, 0.4]))
Example 16: the mode is the highest-probability class of each batch member.
    def testMode(self):
        with self.test_session():
            histograms = [[[0.2, 0.8], [0.6, 0.4]]]
            dist = categorical.Categorical(math_ops.log(histograms) - 50.)
            self.assertAllEqual(dist.mode().eval(), [[1, 0]])
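The mode is the per-distribution argmax over each histogram; a quick NumPy check:

import numpy as np

histograms = np.array([[[0.2, 0.8], [0.6, 0.4]]])
print(np.argmax(histograms, axis=-1))  # [[1 0]]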
Example 17: pmf() for a batch of two distributions.
    def testPMFWithBatch(self):
        histograms = [[0.2, 0.8], [0.6, 0.4]]
        dist = categorical.Categorical(math_ops.log(histograms) - 50.)
        with self.test_session():
            self.assertAllClose(dist.pmf([0, 1]).eval(), [0.2, 0.4])
Example 18: entropy of a single unbatched distribution.
    def testEntropyNoBatch(self):
        logits = np.log([0.2, 0.8]) - 50.
        dist = categorical.Categorical(logits)
        with self.test_session():
            self.assertAllClose(dist.entropy().eval(),
                                -(0.2 * np.log(0.2) + 0.8 * np.log(0.8)))