def test_selection():
    """Check `GP.select`: its string form and that it selects input column 0."""
    # Construction: the string representation lists the selected indices.
    gp = GP(lambda x: x**2, EQ())
    assert str(gp.select(1)) == "GP(<lambda> : [1], EQ() : [1])"
    assert str(gp.select(1, 2)) == "GP(<lambda> : [1, 2], EQ() : [1, 2])"

    # `gp` takes 1D inputs; `gp_2d` takes 2D inputs and selects column 0.
    gp_2d = gp.select(0)
    x = B.linspace(0, 5, 10)
    # Two 2D inputs that share column 0 but differ in column 1.
    x_a = B.stack(x, B.randn(10), axis=1)
    x_b = B.stack(x, B.randn(10), axis=1)
    y = gp_2d(x).sample()

    # Observing `gp_2d` at either 2D input should pin down `gp` at `x`.
    for x_2d in (x_a, x_b):
        post = gp.measure | (gp_2d(x_2d), y)
        approx(post(gp(x)).mean, y)
        assert_equal_normals(post(gp(x)), post(gp_2d(x_2d)))

    # Conversely, observing `gp` should pin down `gp_2d` at both 2D inputs.
    post = gp.measure | (gp(x), y)
    approx(post(gp_2d(x_a)).mean, y)
    approx(post(gp_2d(x_b)).mean, y)
    assert_equal_normals(post(gp_2d(x_a)), post(gp(x)))
    assert_equal_normals(post(gp_2d(x_b)), post(gp(x)))
def test_additive_model():
    """Condition on two of `f1`, `f2`, `f1 + f2` and recover the third."""
    measure = Measure()
    f1 = GP(EQ(), measure=measure)
    f2 = GP(EQ(), measure=measure)
    f_sum = f1 + f2

    x = B.linspace(0, 5, 10)
    y1 = f1(x).sample()
    y2 = f2(x).sample()

    # The two components must be independent a priori.
    assert measure.kernels[f2, f1] == ZeroKernel()
    assert measure.kernels[f1, f2] == ZeroKernel()

    # Each case: (first observation, second observation, target, expected mean).
    cases = [
        ((f1(x), y1), (f2(x), y2), f_sum, y1 + y2),
        ((f2(x), y2), (f1(x), y1), f_sum, y1 + y2),
        ((f1(x), y1), (f_sum(x), y1 + y2), f2, y2),
        ((f_sum(x), y1 + y2), (f1(x), y1), f2, y2),
        ((f2(x), y2), (f_sum(x), y1 + y2), f1, y1),
        ((f_sum(x), y1 + y2), (f2(x), y2), f1, y1),
    ]
    for first, second, target, expected in cases:
        post = (measure | first) | second
        approx(post(target)(x).mean, expected)
def test_take_x():
    """`_take_x` raises `ValueError` for an FDD whose process the kernel was not built from."""
    measure = Measure()
    included = GP(EQ())
    excluded = GP(EQ())
    kernel = MultiOutputKernel(measure, included)
    with pytest.raises(ValueError):
        _take_x(kernel, excluded(B.linspace(0, 1, 10)), B.randn(10) > 0)
def test_multi_sample():
    """Jointly sample degenerate GPs; check shapes, values, and seeded reproducibility."""
    measure = Measure()
    means = (1, 2, 3)
    sizes = (5, 10, 15)
    # Constant mean `c` and kernel 0: every sample should equal `c`.
    gps = [GP(c, 0, measure=measure) for c in means]
    inputs = [B.linspace(0, 1, n) for n in sizes]
    fdds = tuple(gp(x) for gp, x in zip(gps, inputs))

    samples = measure.sample(*fdds)
    for gp, x, n, s in zip(gps, inputs, sizes, samples):
        assert B.shape(gp(x).sample()) == s.shape == (n, 1)
    for c, n, s in zip(means, sizes, samples):
        approx(s, c * B.ones(n, 1))

    # Sampling twice from the same seed must give identical joint samples.
    state, *first = measure.sample(B.create_random_state(np.float64, seed=0), *fdds)
    state, *second = measure.sample(B.create_random_state(np.float64, seed=0), *fdds)
    assert isinstance(state, B.RandomState)
    for a, b in zip(first, second):
        approx(a, b)
def test_logpdf():
    """Check `Measure.logpdf` for single and joint observations and for `Obs`/`PseudoObs`."""
    measure = Measure()
    f1 = GP(EQ(), measure=measure)
    f2 = GP(Exp(), measure=measure)
    f3 = f1 + f2

    x1 = B.linspace(0, 2, 5)
    x2 = B.linspace(1, 3, 6)
    x3 = B.linspace(2, 4, 7)
    y1, y2, y3 = measure.sample(f1(x1), f2(x2), f3(x3))

    # One observation, passed either unpacked or as a pair.
    expected = f1(x1).logpdf(y1)
    approx(expected, measure.logpdf(f1(x1), y1))
    approx(expected, measure.logpdf((f1(x1), y1)))

    # The joint log-density factorises by the product rule.
    post1 = measure | (f1(x1), y1)
    post2 = post1 | (f2(x2), y2)
    chained = (
        measure(f1)(x1).logpdf(y1)
        + post1(f2)(x2).logpdf(y2)
        + post2(f3)(x3).logpdf(y3)
    )
    approx(chained, measure.logpdf((f1(x1), y1), (f2(x2), y2), (f3(x3), y3)))

    # `Obs` and `PseudoObs` are accepted too.
    approx(measure.logpdf(Obs(f3(x3), y3)), f3(x3).logpdf(y3))
    approx(measure.logpdf(PseudoObs(f3(x3), f3(x3, 1), y3)), f3(x3, 1).logpdf(y3))
def test_fd_derivative():
    """A central finite difference of a GP fitted to sin(x) should recover cos(x)."""
    x = B.linspace(0, 10, 50)
    y = np.sin(x)
    gp = GP(0.7 * EQ().stretch(1.0))

    # Central difference with step `h`, built from shifted copies of `gp`.
    h = 1e-3
    d_gp = (gp.shift(-h) - gp.shift(h)) / (2 * h)

    post = gp.measure | (gp(x), y)
    approx(post(d_gp)(x).mean, np.cos(x)[:, None], atol=1e-4)
def test_shifting():
    """Check `GP.shift`: its string form and that `gp.shift(c)(x)` matches `gp(x - c)`."""
    gp = GP(lambda x: x**2, Linear())
    assert str(gp.shift(1)) == "GP(<lambda> shift 1, Linear() shift 1)"

    shifted = gp.shift(5)
    x = B.linspace(0, 5, 10)
    y = shifted(x).sample()
    # Observe the shifted process with noise level `B.epsilon`.
    post = gp.measure | (shifted(x, B.epsilon), y)
    assert_equal_normals(post(gp(x - 5)), post(shifted(x)))
    assert_equal_normals(post(gp(x)), post(shifted(x + 5)))
def test_stretching():
    """Check `GP.stretch`: its string form and that `gp.stretch(c)(x)` matches `gp(x / c)`."""
    gp = GP(lambda x: x**2, Linear())
    assert str(gp.stretch(1)) == "GP(<lambda> > 1, Linear() > 1)"

    stretched = gp.stretch(5)
    x = B.linspace(0, 5, 10)
    y = stretched(x).sample()
    # Observe the stretched process with noise level `B.epsilon`.
    post = gp.measure | (stretched(x, B.epsilon), y)
    assert_equal_normals(post(gp(x / 5)), post(stretched(x)))
    assert_equal_normals(post(gp(x)), post(stretched(x * 5)))
def test_sum_other():
    """Adding a constant or a function to a GP shifts its mean and keeps its kernel."""
    p = GP(TensorProductMean(lambda x: x ** 2), EQ())

    def five(y):
        return 5 * B.ones(B.shape(y)[0], 1)

    x = B.randn(5, 1)
    for p_sum in [
        # Add a numeric thing.
        p + 5.0,
        5.0 + p,
        p.measure.sum(GP(), p, 5.0),
        p.measure.sum(GP(), 5.0, p),
        # Add a function.
        p + five,
        five + p,
        p.measure.sum(GP(), p, five),
        p.measure.sum(GP(), five, p),
    ]:
        # A deterministic summand adds 5 to the mean and leaves the kernel unchanged.
        approx(p.mean(x) + 5.0, p_sum.mean(x))
        approx(p.kernel(x), p_sum.kernel(x))

    # Check that a `GP` cannot be summed with a `Normal`.
    with pytest.raises(NotFoundLookupError):
        p + Normal(np.eye(3))
    with pytest.raises(NotFoundLookupError):
        Normal(np.eye(3)) + p
def test_mul_other():
    """Multiplying a GP by 5 (constant or function) scales its mean by 5 and kernel by 25."""
    p = GP(TensorProductMean(lambda x: x ** 2), EQ())

    def five(y):
        return 5 * B.ones(B.shape(y)[0], 1)

    x = B.randn(5, 1)
    for p_mul in [
        # Multiply numeric thing.
        p * 5.0,
        5.0 * p,
        p.measure.mul(GP(), p, 5.0),
        p.measure.mul(GP(), 5.0, p),
        # Multiply with a function.
        p * five,
        five * p,
        p.measure.mul(GP(), p, five),
        p.measure.mul(GP(), five, p),
    ]:
        # The mean scales by 5 and the kernel by 5 ** 2 = 25.
        approx(5.0 * p.mean(x), p_mul.mean(x))
        approx(25.0 * p.kernel(x), p_mul.kernel(x))

    # Check that a `GP` cannot be multiplied with a `Normal`.
    with pytest.raises(NotFoundLookupError):
        p * Normal(np.eye(3))
    with pytest.raises(NotFoundLookupError):
        Normal(np.eye(3)) * p
def test_conditioning_consistency():
    """An explicit noise process and FDD noise 0.1 give the same posterior for `f`.

    They differ for the noise process itself.
    """
    measure = Measure()
    f = GP(EQ(), measure=measure)
    noise = GP(0.1 * Delta(), measure=measure)
    # Same kernel as `noise`, but a separate process in the measure.
    noise_copy = GP(noise.kernel, measure=measure)

    x = B.linspace(0, 5, 10)
    y = (f + noise)(x).sample()

    post_explicit = measure | ((f + noise)(x), y)
    post_implicit = measure | (f(x, 0.1), y)
    assert_equal_measures([f(x), (f + noise_copy)(x)], post_explicit, post_implicit)

    with pytest.raises(AssertionError):
        assert_equal_normals(
            post_explicit((f + noise)(x)),
            post_implicit((f + noise)(x)),
        )
def test_input_transform():
    """Check `GP.transform`: its string form and that it composes with the input."""
    gp = GP(lambda x: x**2, Linear())
    expected_str = "GP(<lambda> transform <lambda>, Linear() transform <lambda>)"
    assert str(gp.transform(lambda x: x)) == expected_str

    # `transformed(x)` should behave like `gp(sqrt(x))`.
    transformed = gp.transform(lambda x: B.sqrt(x))
    x = B.linspace(0, 5, 10)
    y = transformed(x).sample()
    post = gp.measure | (transformed(x, B.epsilon), y)
    assert_equal_normals(post(gp(B.sqrt(x))), post(transformed(x)))
    assert_equal_normals(post(gp(x)), post(transformed(x * x)))
def test_pseudoobs_kernel_call_count():
    """Check which kernel evaluations a `PseudoObs` posterior prediction triggers.

    A kernel that records its arguments is used. The test asserts the exact
    sequence of `pairwise` and `elwise` evaluations made while conditioning on
    pseudo-observations and computing marginals at a new input.
    """

    class TrackingEQ(Kernel):
        """Track the evaluations of this EQ kernel."""

    # Record the (flattened) argument pairs of every evaluation, in call order.
    pairwise_calls = []
    elwise_calls = []

    # Register `TrackingEQ` implementations with the `pairwise` / `elwise`
    # dispatchers. Both compute exp(-0.5 * squared distance), an EQ kernel.
    @pairwise.dispatch
    def pairwise_(k: TrackingEQ, x: B.Numeric, y: B.Numeric):
        pairwise_calls.append((B.flatten(x), B.flatten(y)))
        return B.exp(-0.5 * B.pw_dists2(x, y))

    @elwise.dispatch
    def elwise_(k: TrackingEQ, x: B.Numeric, y: B.Numeric):
        elwise_calls.append((B.flatten(x), B.flatten(y)))
        return B.exp(-0.5 * B.ew_dists2(x, y))

    # Construct some inputs.
    x_obs = B.linspace(0, 5, 10)
    y_obs = B.randn(10)
    x_ind = B.linspace(0, 5, 5)  # Inducing (pseudo-point) inputs.
    x_new = B.randn(1)

    # Perform a pseudo-point approximation
    p = GP(1, TrackingEQ())
    p_post = p | PseudoObs(p(x_ind), (p(x_obs, 0.1), y_obs))
    # The results are unused; this call only triggers the kernel evaluations.
    mean, var = p_post(x_new).marginals()

    # Check the calls: exactly these kernel blocks, in this order.
    approx(tuple(pairwise_calls), ((x_obs, x_ind), (x_ind, x_ind), (x_ind, x_new)))
    approx(tuple(elwise_calls), ((x_obs, x_obs), (x_new, x_new)))
def test_manual_new_gp():
    """Build `f_sum - f2` by hand with `Measure.add_gp`; it should reproduce `f1`."""
    measure = Measure()
    f1 = GP(1, EQ(), measure=measure)
    f2 = GP(2, EQ(), measure=measure)
    f_sum = f1 + f2

    # Mean and kernel of the difference of two processes, plus its cross-kernel
    # with any process `j` in the measure.
    diff_mean = measure.means[f_sum] - measure.means[f2]
    diff_kernel = (
        measure.kernels[f_sum]
        + measure.kernels[f2]
        - measure.kernels[f_sum, f2]
        - measure.kernels[f2, f_sum]
    )
    f1_manual = measure.add_gp(
        diff_mean,
        diff_kernel,
        lambda j: measure.kernels[f_sum, j] - measure.kernels[f2, j],
    )

    x = B.linspace(0, 10, 5)
    sample_ref, sample_manual = measure.sample(f1(x), f1_manual(x))
    approx(sample_ref, sample_manual, atol=1e-4)
def test_blr():
    """Linear regression as a GP: recover the sampled slope and intercept from noisy data."""
    measure = Measure()
    x = B.linspace(0, 10, 100)
    slope = GP(1, measure=measure)
    intercept = GP(1, measure=measure)
    line = slope * (lambda t: t) + intercept
    noisy_line = line + 1e-2 * GP(Delta(), measure=measure)

    # Draw observations together with the true slope and intercept.
    y_obs, slope_true, intercept_true = measure.sample(
        noisy_line(x), slope(0), intercept(0)
    )

    post = measure | (noisy_line(x), y_obs)
    approx(post(slope)(0).mean, slope_true, atol=5e-2)
    approx(post(intercept)(0).mean, intercept_true, atol=5e-2)
def test_default_measure():
    """A GP created inside `with Measure()` joins the innermost active measure."""
    with Measure() as outer:
        gp_outer_a = GP(EQ())
        with Measure() as inner:
            gp_inner = GP(EQ())
        gp_outer_b = GP(EQ())
    gp_free = GP(EQ())

    assert gp_outer_a.measure is outer
    assert gp_inner.measure is inner
    assert gp_outer_b.measure is outer
    # Outside both contexts, the GP belongs to neither measure.
    assert gp_free.measure is not outer
    assert gp_free.measure is not inner
def test_conditioning_prior():
    """Conditioning on zero observations leaves the prior's mean and kernel objects intact."""
    gp = GP(EQ())
    no_inputs = B.zeros(0, 1)
    no_outputs = B.zeros(0, 1)
    posterior_measure = gp.measure | (gp(no_inputs), no_outputs)
    assert posterior_measure(gp).mean is gp.mean
    assert posterior_measure(gp).kernel is gp.kernel
def test_conditioning_missing_data():
    """NaN observations are ignored: they give the same posterior as dropping those points."""
    gp = GP(1, EQ())
    x = B.linspace(0, 5, 10)
    y = gp(x).sample()

    # Mark the first few observations as missing.
    n_missing = 3
    y[:n_missing] = B.nan

    post_with_nans = gp | (gp(x), y)
    post_observed_only = gp | (gp(x[n_missing:]), y[n_missing:])
    assert_equal_normals(post_with_nans(x), post_observed_only(x))
def test_conditioning(generate_noise_tuple):
    """Check that every supported conditioning syntax gives the same posterior."""
    m = Measure()
    p1 = GP(EQ(), measure=m)
    p2 = GP(Exp(), measure=m)
    p_sum = p1 + p2

    # Sample some data to condition on.
    x1 = B.linspace(0, 2, 3)
    n1 = generate_noise_tuple(x1)
    y1 = p1(x1, *n1).sample()
    tup1 = (p1(x1, *n1), y1)
    x_sum = B.linspace(3, 5, 3)
    n_sum = generate_noise_tuple(x_sum)
    y_sum = p_sum(x_sum, *n_sum).sample()
    tup_sum = (p_sum(x_sum, *n_sum), y_sum)

    # Determine FDDs to check.
    x_check = B.linspace(0, 5, 5)
    fdds_check = [
        cross(p1, p2, p_sum)(x_check),
        p1(x_check),
        p2(x_check),
        p_sum(x_check),
    ]

    # One observation, via every supported syntax.
    assert_equal_measures(
        fdds_check,
        m.condition(*tup_sum),
        m.condition(tup_sum),
        m | tup_sum,
        m | (tup_sum,),
        m | Obs(*tup_sum),
        m | Obs(tup_sum),
    )
    # Two observations, via every supported syntax.
    assert_equal_measures(
        fdds_check,
        m.condition(tup1, tup_sum),
        m | (tup1, tup_sum),
        m | Obs(tup1, tup_sum),
    )

    # Check that conditioning gives an FDD and that it is consistent.
    post = m | tup1
    assert isinstance(post(p1(x1, 0.1)), FDD)
    # Compare the two FDDs directly. Elsewhere in this file,
    # `assert_equal_measures` takes a list of FDDs followed by measures, so
    # passing two FDDs to it does not perform this comparison.
    assert_equal_normals(post(p1(x1, 0.1)), post(p1)(x1, 0.1))
def test_sample_correct_measure():
    """`post.sample` samples under the posterior: draws at the observed point match the observation."""
    measure = Measure()
    gp = GP(1, EQ(), measure=measure)
    post = measure | (gp(0), 1)
    draws = post.sample(10, gp(0))
    approx(draws, B.ones(1, 10), atol=1e-4)
def test_stationarity():
    """Combinations of stationary GPs stay stationary; adding a `Linear` GP breaks it."""
    measure = Measure()
    f_eq = GP(EQ(), measure=measure)
    f_stretched = GP(EQ().stretch(2), measure=measure)
    f_periodic = GP(EQ().periodic(10), measure=measure)

    combined = f_eq + 2 * f_stretched
    assert combined.stationary

    combined = f_periodic + combined
    assert combined.stationary

    combined = combined + GP(Linear(), measure=measure)
    assert not combined.stationary
def test_conditioning_shorthand():
    """`GP.condition(...)` and `GP | (...)` agree, both once and when chained twice."""
    gp = GP(EQ())

    post_method = gp
    post_pipe = gp
    # Two successive rounds of conditioning on fresh prior samples.
    for lower, upper in ((0, 5), (10, 20)):
        x = B.linspace(lower, upper, 10)
        y = gp(x).sample()
        post_method = post_method.condition(gp(x), y)
        post_pipe = post_pipe | (gp(x), y)
        approx(post_method.mean(x), y)
        approx(post_pipe.mean(x), y)
def test_mo_batched():
    """Multi-output GP with batched inputs: check shapes, and that conditioning raises the logpdf."""
    x = B.randn(16, 10, 1)
    with Measure():
        gp = cross(GP(1, 2 * EQ().stretch(0.5)), GP(2, 2 * EQ().stretch(0.5)))

    y_prior = gp(x).sample()
    logpdf_prior = gp(x, 0.1).logpdf(y_prior)
    assert B.shape(logpdf_prior) == (16,)
    assert B.shape(y_prior) == (16, 20, 1)

    gp = gp | (gp(x), y_prior)
    y_post = gp(x).sample()
    logpdf_post = gp(x, 0.1).logpdf(y_prior)
    assert B.shape(y_post) == (16, 20, 1)
    assert B.shape(logpdf_post) == (16,)
    assert B.all(logpdf_post > logpdf_prior)
    approx(y_prior, y_post, atol=1e-5)
def test_multi_sample_basic():
    """Jointly sample degenerate GPs and check sample shapes and values.

    Renamed from `test_multi_sample`: a second `def` with that name rebinds the
    module attribute, so pytest would collect only one of the two definitions.
    """
    m = Measure()
    # Constant means 1, 2, 3 with kernel 0: samples equal the means.
    p1 = GP(1, 0, measure=m)
    p2 = GP(2, 0, measure=m)
    p3 = GP(3, 0, measure=m)

    x1 = B.linspace(0, 1, 5)
    x2 = B.linspace(0, 1, 10)
    x3 = B.linspace(0, 1, 15)

    s1, s2, s3 = m.sample(p1(x1), p2(x2), p3(x3))

    assert B.shape(p1(x1).sample()) == s1.shape == (5, 1)
    assert B.shape(p2(x2).sample()) == s2.shape == (10, 1)
    assert B.shape(p3(x3).sample()) == s3.shape == (15, 1)

    approx(s1, 1 * B.ones(5, 1))
    approx(s2, 2 * B.ones(10, 1))
    approx(s3, 3 * B.ones(15, 1))
def test_approximate_derivative():
    """Check `diff_approx` by fitting y = 2x, and conversely by observing the derivative."""
    p = GP(EQ().stretch(1.0))
    dp = p.diff_approx()
    x = B.linspace(0, 1, 100)
    y = 2 * x
    x_check = B.linspace(0.2, 0.8, 100)

    # Test conditioning on function: the derivative of 2x is 2.
    post = p.measure | (p(x), y)
    approx(post(dp)(x_check).mean, 2 * B.ones(100, 1), atol=1e-3)

    # Test conditioning on derivative: p(0) = 0 and p'(x) = 2x give p(x) = x ** 2.
    # This temporarily lowers the global `B.epsilon`. Restore it in `finally` so
    # a failing assertion does not leak the modified value into other tests.
    orig_epsilon = B.epsilon
    B.epsilon = 1e-10
    try:
        post = p.measure | ((p(0), 0), (dp(x), y))
        approx(post(p)(x_check).mean, x_check[:, None] ** 2, atol=1e-3)
    finally:
        B.epsilon = orig_epsilon
def test_conditioning_empty_observations(shape):
    """Conditioning on empty observations returns the prior's mean and kernel objects.

    `shape` is presumably a parametrized input shape with a zero-sized
    dimension; confirm against the fixture/parametrization.
    """
    gp = GP(1, EQ())
    inputs = B.randn(*shape)
    outputs = gp(inputs).sample()
    posterior = gp | (gp(inputs), outputs)
    assert posterior.mean is gp.mean
    assert posterior.kernel is gp.kernel
def test_marginal_credible_bounds_efficiency():
    """`marginal_credible_bounds` on a posterior at 10,000 points must finish in under a second."""
    # `perf_counter` is monotonic and high-resolution, which suits measuring a
    # duration. The wall-clock `time()` can jump if the system clock is adjusted.
    from time import perf_counter

    p = GP(EQ())
    x = B.linspace(0, 5, 5)
    y = p(x, 0.1).sample()
    p = p | (p(x, 0.1), y)

    # Check that the computation at 10_000 points takes less than one second.
    x = B.linspace(0, 5, 10_000)
    start = perf_counter()
    p(x, 0.2).marginal_credible_bounds()
    assert perf_counter() - start < 1
def test_approximate_multiplication():
    """Check products of GPs: predict `p1 * p2` from the factors, and a factor from the product."""
    m = Measure()
    p1 = GP(20, EQ(), measure=m)
    p2 = GP(20, EQ(), measure=m)
    p_prod = p1 * p2

    # Sample functions.
    x = B.linspace(0, 10, 50)
    s1, s2 = m.sample(p1(x), p2(x))

    # Perform product.
    post = m | ((p1(x), s1), (p2(x), s2))
    approx(post(p_prod)(x).mean, s1 * s2, rtol=1e-2)

    # Perform division. This temporarily lowers the global `B.epsilon`. Restore
    # it in `finally` so a failing assertion does not leak into other tests.
    cur_epsilon = B.epsilon
    B.epsilon = 1e-8
    try:
        post = m | ((p1(x), s1), (p_prod(x), s1 * s2))
        approx(post(p2)(x).mean, s2, rtol=1e-2)
    finally:
        B.epsilon = cur_epsilon
def test_corner_cases():
    """Check the error raised by each invalid GP operation."""
    # Created without `measure=`, so the two GPs are not in a shared measure.
    f = GP(EQ())
    g = GP(EQ())
    x = B.randn(10, 2)

    # Combining GPs from different measures is rejected.
    for combine in (lambda: f + g, lambda: f * g):
        with pytest.raises(AssertionError):
            combine()

    # A GP cannot be combined with an FDD.
    for combine in (lambda: f + f(x), lambda: f * f(x)):
        with pytest.raises(NotFoundLookupError):
            combine()

    # Accessing the measure of a bare `GP()` raises.
    with pytest.raises(RuntimeError):
        GP().measure
def test_summation_with_itself():
    """Adding a GP to itself five times gives 5 * gp: variance x25, mean 0, posterior mean x5."""
    gp = GP(EQ())
    gp_times_five = gp + gp + gp + gp + gp

    x = B.linspace(0, 10, 5)
    approx(gp_times_five(x).var, 25 * gp(x).var)
    approx(gp_times_five(x).mean, B.zeros(5, 1))

    observations = B.randn(5, 1)
    post = gp.measure | (gp(x), observations)
    approx(post(gp_times_five)(x).mean, 5 * observations)