예제 #1
0
def test_box_beta_log_prob_simple():
    """
    Check BoxBeta.log_prob for a single point in the unit box [0, 1].

    On [0, 1] no rescaling is applied, so the expected value is the plain
    Beta log density computed by ``_beta_log_prob``.
    """
    low = np.array([0])
    high = np.array([1])
    beta_params = np.array([[0.5, 0.5]])
    points = np.array([0.5])
    # Build the op in a fresh graph so the test does not leak state into
    # (or pick up state from) the process-wide default graph.
    with tf.Graph().as_default(), tf.Session() as sess:
        dist = BoxBeta(low, high, softplus=False)
        actual = sess.run(dist.log_prob(beta_params, points))
    expected = np.array([_beta_log_prob(0.5, 0.5, 0.5)])
    assert np.allclose(actual, expected, atol=1e-4)
예제 #2
0
def test_box_beta_log_prob():
    """
    Check BoxBeta.log_prob on a batch of two boxes with known densities.

    The first box is [0, 1], so its expected value is the plain Beta log
    density at 0.4. The second box is [-2, 3] (width 5): the point -0.5
    corresponds to 0.3 in unit coordinates. The rescaling contributes a
    -log(5) term to the log density.
    """
    lows = np.array([[0], [-2]])
    highs = np.array([[1], [3]])
    beta_params = np.array([[[0.1, 0.5]], [[0.3, 0.7]]])
    points = np.array([[0.4], [-0.5]])
    # Fresh graph + session, mirroring the other TF-based tests.
    with tf.Graph().as_default(), tf.Session() as sess:
        dist = BoxBeta(lows, highs, softplus=False)
        actual = sess.run(dist.log_prob(beta_params, points))
    expected = np.array([
        _beta_log_prob(0.1, 0.5, 0.4),
        _beta_log_prob(0.3, 0.7, 0.3) - np.log(5),
    ])
    assert np.allclose(actual, expected, atol=1e-4)
예제 #3
0
 def test_generic_simple(self):
     """
     Run the full DistributionTester suite on a BoxBeta over the unit
     interval [0, 1] (i.e. no scaling or shifting).
     """
     unit_box = BoxBeta(np.array([0]), np.array([1]))
     DistributionTester(self, unit_box).test_all()
예제 #4
0
def test_box_beta_generic():
    """
    Run the full DistributionTester suite on a batch of rescaled BoxBetas.

    The bounds include a narrow box ([7, 7.1]) and boxes with negative
    lower bounds. A large batch size is passed to the tester.
    """
    lows = np.array([[-3, 7, 1], [1, 2, 3]])
    highs = np.array([[5, 7.1, 3], [2, 3.1, 4]])
    box_dist = BoxBeta(lows, highs)
    DistributionTester(box_dist, batch_size=400000).test_all()
예제 #5
0
def test_box_beta_generic_simple():
    """
    Run the full DistributionTester suite on a BoxBeta over the unit
    interval [0, 1] (i.e. an unscaled distribution).
    """
    unit_box = BoxBeta(np.array([0]), np.array([1]))
    DistributionTester(unit_box).test_all()
예제 #6
0
 def test_generic(self):
     """
     Run the full DistributionTester suite on a batch of rescaled BoxBetas,
     using a large batch size for the tester.
     """
     lows = np.array([[-3, 7, 1], [1, 2, 3]])
     highs = np.array([[5, 7.1, 3], [2, 3.1, 4]])
     box_dist = BoxBeta(lows, highs)
     DistributionTester(self, box_dist, batch_size=400000).test_all()