def test_logpdf():
    """Check `Measure.logpdf` against marginal logpdfs, the chain rule, and observation objects."""
    measure = Measure()
    f1 = GP(EQ(), measure=measure)
    f2 = GP(Exp(), measure=measure)
    noise = GP(Delta(), measure=measure)
    f3 = f1 + f2

    xs1 = B.linspace(0, 2, 5)
    xs2 = B.linspace(1, 3, 6)
    xs3 = B.linspace(2, 4, 7)
    ys1, ys2, ys3 = measure.sample(f1(xs1), f2(xs2), f3(xs3))

    # A single observed process: both the positional form and the tuple form
    # must agree with the FDD's own logpdf.
    approx(f1(xs1).logpdf(ys1), measure.logpdf(f1(xs1), ys1))
    approx(f1(xs1).logpdf(ys1), measure.logpdf((f1(xs1), ys1)))

    # Decompose the joint density by successive conditioning:
    # log p(y1, y2, y3) = log p(y1) + log p(y2 | y1) + log p(y3 | y1, y2).
    prior = measure
    given_first = prior | (f1(xs1), ys1)
    given_first_two = given_first | (f2(xs2), ys2)
    chained = (
        prior(f1)(xs1).logpdf(ys1)
        + given_first(f2)(xs2).logpdf(ys2)
        + given_first_two(f3)(xs3).logpdf(ys3)
    )
    joint = measure.logpdf((f1(xs1), ys1), (f2(xs2), ys2), (f3(xs3), ys3))
    approx(chained, joint)

    # `Measure.logpdf` must also accept `Obs` and `SparseObs` instances directly.
    dense_obs = Obs(f3(xs3), ys3)
    approx(measure.logpdf(dense_obs), f3(xs3).logpdf(ys3))
    sparse_obs = SparseObs(f3(xs3), noise, f3(xs3), ys3)
    approx(measure.logpdf(sparse_obs), (f3 + noise)(xs3).logpdf(ys3))
def test_case_blr():
    """Bayesian linear regression: recover a sampled slope and intercept from noisy data."""
    measure = Measure()
    inputs = B.linspace(0, 10, 100)
    slope = GP(1, measure=measure)
    intercept = GP(1, measure=measure)
    line = slope * (lambda t: t) + intercept
    noisy_line = line + 1e-2 * GP(Delta(), measure=measure)

    # Jointly draw the noisy observations together with the true slope and intercept.
    observed, slope_true, intercept_true = measure.sample(
        noisy_line(inputs), slope(0), intercept(0)
    )

    # Condition on the observations; the posterior means should land near the
    # values that generated the data.
    posterior = measure | (noisy_line(inputs), observed)
    approx(posterior(slope)(0).mean, slope_true, atol=5e-2)
    approx(posterior(intercept)(0).mean, intercept_true, atol=5e-2)
def test_multi_sample():
    """Jointly sample zero-kernel GPs with constant means and check shapes and values."""
    measure = Measure()
    means = (1, 2, 3)
    sizes = (5, 10, 15)
    procs = [GP(mean, 0, measure=measure) for mean in means]
    inputs = [B.linspace(0, 1, size) for size in sizes]
    first, second, third = measure.sample(
        *(proc(x) for proc, x in zip(procs, inputs))
    )
    samples = (first, second, third)

    # A marginal sample and the matching joint sample are both column vectors.
    for proc, x, sample, size in zip(procs, inputs, samples, sizes):
        assert B.shape(proc(x).sample()) == sample.shape == (size, 1)

    # With a zero kernel, every sampled value should equal the constant mean.
    for sample, mean, size in zip(samples, means, sizes):
        approx(sample, mean * B.ones(size, 1))
def test_approximate_multiplication():
    """Check conditioning through the product of two GPs.

    Conditioning on both factors should pin the product's posterior mean.
    Conditioning on one factor and the product should recover the other factor.
    """
    m = Measure()
    p1 = GP(20, EQ(), measure=m)
    p2 = GP(20, EQ(), measure=m)
    p_prod = p1 * p2

    # Sample functions.
    x = B.linspace(0, 10, 50)
    s1, s2 = m.sample(p1(x), p2(x))

    # Perform product.
    post = m | ((p1(x), s1), (p2(x), s2))
    approx(post(p_prod)(x).mean, s1 * s2, rtol=1e-2)

    # Perform division. `B.epsilon` is a global setting that is temporarily
    # lowered for this step only (presumably the numerical jitter — confirm).
    # Restore it in `finally` so a failing assertion or an exception does not
    # leak the modified value into other tests.
    cur_epsilon = B.epsilon
    B.epsilon = 1e-8
    try:
        post = m | ((p1(x), s1), (p_prod(x), s1 * s2))
        approx(post(p2)(x).mean, s2, rtol=1e-2)
    finally:
        B.epsilon = cur_epsilon