def test_selection():
    """Test ``GP.select``: its string form and conditioning through it.

    ``p2 = p.select(0)`` is expected to look only at column 0 of its input.
    Observations of ``p2`` at 2D inputs whose first column is ``x`` must
    therefore pin down ``p`` at ``x`` (posterior mean equal to the data,
    negligible variance), and conversely.

    Written with plain ``assert`` statements: pytest (used elsewhere in this
    file) does not execute yield-based generator tests.
    """
    model = Graph()

    # Test construction: the selected indices are shown on kernel and mean.
    p = GP(EQ(), TensorProductMean(lambda x: x ** 2), graph=model)
    assert str(p.select(1)) == 'GP(EQ() : [1], <lambda> : [1])'
    assert str(p.select(1, 2)) == 'GP(EQ() : [1, 2], <lambda> : [1, 2])'

    # Test case: a 1D process and its view that selects column 0.
    p = GP(EQ(), graph=model)
    p2 = p.select(0)

    n = 5
    x = np.linspace(0, 10, n)[:, None]
    # Two 2D input sets that share first column ``x``; the random second
    # column is expected to be ignored by ``p2``.
    x1 = np.concatenate((x, np.random.randn(n, 1)), axis=1)
    x2 = np.concatenate((x, np.random.randn(n, 1)), axis=1)
    y = p2(x).sample()

    # Conditioning the 1D process on the selected view must interpolate ``y``.
    for x_2d in (x1, x2):
        post = p.condition(p2(x_2d), y)
        assert_allclose(post(x).mean, y)
        assert abs_err(B.diag(post(x).var)) <= 1e-10

    # Conditioning the selected view on the 1D process must interpolate ``y``
    # at any 2D inputs whose first column is ``x``.
    post = p2.condition(p(x), y)
    for x_2d in (x1, x2):
        assert_allclose(post(x_2d).mean, y)
        assert abs_err(B.diag(post(x_2d).var)) <= 1e-10
def test_shorthands():
    """Exercise the call, ``>`` and ``[]`` shorthands of a GP.

    Calling a GP on a value gives an ``At`` object that remembers both the GP
    and the value. A ``Normal`` built directly must raise ``RuntimeError``
    wherever such an object is expected. ``process > 2`` and ``process[0]``
    must print like ``process.stretch(2)`` and ``process.select(0)``.
    """
    graph = Graph()
    process = GP(EQ(), graph=graph)

    # Calling the GP wraps the value in an object that points back to it.
    wrapped = process(1)
    assert isinstance(wrapped, At)
    assert type_parameter(wrapped) is process
    assert wrapped.get() == 1
    shown = str(process(wrapped))
    assert shown == '{}({})'.format(str(process), str(wrapped))
    shown = repr(process(wrapped))
    assert shown == '{}({})'.format(repr(process), repr(wrapped))

    # A Normal that was not produced by calling a GP cannot act as an input.
    plain = Normal(np.ones((1, 1)))
    pytest.raises(RuntimeError, type_parameter, plain)
    pytest.raises(RuntimeError, lambda: plain.get())
    pytest.raises(RuntimeError, lambda: process | (plain, 1))

    # Operator shorthands must agree with their named-method counterparts.
    fresh = GP(EQ(), graph=Graph())
    assert str(fresh > 2) == str(fresh.stretch(2))
    assert str(fresh[0]) == str(fresh.select(0))
# NOTE(review): this redefines an earlier ``test_shorthands`` in this file that
# performs the same checks, so it shadows that one (pyflakes F811). Only one of
# the two should be kept.
def test_shorthands():
    """Test the ``p(x)`` input shorthand and the ``>``/``[]`` shorthands.

    Written with plain ``assert``/``pytest.raises`` rather than nose-style
    ``yield`` checks, because pytest does not execute generator tests. As a
    yield test, this definition (which shadows the earlier one) meant the
    shorthand checks never ran.
    """
    model = Graph()
    p = GP(EQ(), graph=model)

    # Construct a normal distribution that serves as an input.
    x = p(1)
    assert isinstance(x, At)
    assert type_parameter(x) is p
    assert x.get() == 1
    assert str(p(x)) == '{}({})'.format(str(p), str(x))
    assert repr(p(x)) == '{}({})'.format(repr(p), repr(x))

    # Construct a normal distribution that does not serve as an input.
    x = Normal(np.ones((1, 1)))
    with pytest.raises(RuntimeError):
        type_parameter(x)
    with pytest.raises(RuntimeError):
        x.get()
    with pytest.raises(RuntimeError):
        p | (x, 1)

    # Test shorthands for stretching and selection.
    p = GP(EQ(), graph=Graph())
    assert str(p > 2) == str(p.stretch(2))
    assert str(p[0]) == str(p.select(0))