Ejemplo n.º 1
0
 def test_natural_gradient(self):
     """
     Compare NaturalSoftmax gradients to a reference natural gradient.

     For every size from 3 through 8, random parameters are drawn and the
     gradient of the CategoricalSoftmax log probability is multiplied by
     the pseudo-inverse of the Hessian of the KL divergence. The gradient
     of the NaturalSoftmax log probability must be close to that product.
     """
     with tf.Graph().as_default(), tf.Session() as sess:
         for num_options in range(3, 9):
             natural = NaturalSoftmax(num_options, epsilon=0)
             plain = CategoricalSoftmax(num_options)

             # A single row of random parameters, stacked into a batch of one.
             param_vec = tf.constant(
                 np.random.normal(size=(num_options, )), dtype=tf.float64)
             batch = tf.stack([param_vec])

             # The lone sample always selects index 1.
             chosen = np.zeros((1, num_options))
             chosen[0, 1] = 1
             sample_batch = tf.constant(chosen, dtype=tf.float64)

             # Hessian of the KL divergence with respect to the parameter
             # row; the first argument is held fixed via stop_gradient.
             divergence = plain.kl_divergence(tf.stop_gradient(batch), batch)
             hessian_op = tf.hessians(divergence, param_vec)[0]
             curvature = sess.run(hessian_op)

             plain_grad_op = tf.gradients(
                 plain.log_prob(batch, sample_batch), batch)[0][0]
             plain_grad = sess.run(plain_grad_op)

             # Reference value: plain gradient times the pseudo-inverse.
             inverse = np.linalg.pinv(curvature)
             expected = np.matmul(np.array([plain_grad]), inverse)[0]

             natural_grad_op = tf.gradients(
                 natural.log_prob(batch, sample_batch), batch)[0][0]
             actual = sess.run(natural_grad_op)

             self.assertTrue(np.allclose(actual, expected))
Ejemplo n.º 2
0
def test_cat_softmax_generic():
    """
    Exercise CategoricalSoftmax through the full DistributionTester suite.

    The distribution under test is built with 7 as its first argument and
    low=2.
    """
    DistributionTester(CategoricalSoftmax(7, low=2)).test_all()
Ejemplo n.º 3
0
 def test_generic(self):
     """
     Exercise CategoricalSoftmax through the full DistributionTester suite.

     This test case is handed to the tester together with a distribution
     built with 7 as its first argument and low=2.
     """
     DistributionTester(self, CategoricalSoftmax(7, low=2)).test_all()
Ejemplo n.º 4
0
def test_nat_softmax_log_prob():
    """
    Check that NaturalSoftmax and CategoricalSoftmax agree on log_prob.

    Both distributions are evaluated on the same random parameters and the
    same randomly chosen one-hot samples; the results must be close.
    """
    num_options = 7
    batch_size = 15
    with tf.Graph().as_default(), tf.Session() as sess:
        natural = NaturalSoftmax(num_options)
        param_batch = tf.constant(
            np.random.normal(size=(batch_size, num_options)),
            dtype=tf.float64)
        # One randomly selected index per row, encoded as a one-hot vector.
        picks = [random.randrange(num_options) for _ in range(batch_size)]
        one_hots = tf.one_hot(picks, num_options, dtype=tf.float64)

        natural_log_probs = natural.log_prob(param_batch, one_hots)
        actual = sess.run(natural_log_probs)

        reference = CategoricalSoftmax(num_options)
        expected = sess.run(reference.log_prob(param_batch, one_hots))

        assert np.allclose(actual, expected)