Пример #1
0
def test_box_beta_log_prob_simple():
    """
    Check BoxBeta.log_prob for a single box spanning [0, 1].

    The box is already the unit interval, so the point is compared against
    `_beta_log_prob` with no rescaling and no width correction.
    """
    # Entering the graph first and then opening the session in the same
    # `with` keeps the session bound to this fresh graph.
    with tf.Graph().as_default(), tf.Session() as session:
        box = BoxBeta(np.array([0]), np.array([1]), softplus=False)
        beta_params = np.array([[0.5, 0.5]])
        point = np.array([0.5])
        log_prob_op = box.log_prob(beta_params, point)
        got = session.run(log_prob_op)
        want = np.array([_beta_log_prob(0.5, 0.5, 0.5)])
        assert np.allclose(got, want, atol=1e-4)
Пример #2
0
def test_box_beta_log_prob():
    """
    Check BoxBeta.log_prob on a batch of two boxes.

    Box 0 spans [0, 1], so its point (0.4) is used unchanged. Box 1 spans
    [-2, 3] (width 5), so its point -0.5 maps to (-0.5 + 2) / 5 = 0.3 in the
    unit interval. Its expected value also subtracts log(5) to account for
    the box width.
    """
    with tf.Graph().as_default(), tf.Session() as session:
        lows = np.array([[0], [-2]])
        highs = np.array([[1], [3]])
        box = BoxBeta(lows, highs, softplus=False)

        # One row of Beta parameters and one point per box.
        beta_params = np.array([[[0.1, 0.5]], [[0.3, 0.7]]])
        points = np.array([[0.4], [-0.5]])
        got = session.run(box.log_prob(beta_params, points))

        unit_box_term = _beta_log_prob(0.1, 0.5, 0.4)
        wide_box_term = _beta_log_prob(0.3, 0.7, 0.3) - np.log(5)
        want = np.array([unit_box_term, wide_box_term])
        assert np.allclose(got, want, atol=1e-4)