コード例 #1
0
ファイル: test_graph.py プロジェクト: wangwanwang56/stheno
def test_shorthands():
    """Exercise the shorthand syntax that GPs expose.

    Three things are covered:

    * calling a GP at an input gives an ``At`` object that remembers the GP
      and the input, and formats as ``<gp>(<input>)``;
    * a ``Normal`` constructed directly (not through a GP call) does not serve
      as an input, so the input-only operations on it raise ``RuntimeError``;
    * ``p > s`` and ``p[i]`` render the same as ``p.stretch(s)`` and
      ``p.select(i)``.
    """
    graph = Graph()
    gp = GP(EQ(), graph=graph)
    template = '{}({})'

    # Calling the GP at an input produces a normal distribution that serves
    # as an input.
    at_one = gp(1)
    assert isinstance(at_one, At)
    assert type_parameter(at_one) is gp
    assert at_one.get() == 1
    assert str(gp(at_one)) == template.format(str(gp), str(at_one))
    assert repr(gp(at_one)) == template.format(repr(gp), repr(at_one))

    # A Normal built directly does not serve as an input, so each
    # input-only operation on it must raise.
    standalone = Normal(np.ones((1, 1)))
    failing_calls = (
        lambda: type_parameter(standalone),
        lambda: standalone.get(),
        lambda: gp | (standalone, 1),
    )
    for call in failing_calls:
        with pytest.raises(RuntimeError):
            call()

    # ``>`` and ``[]`` are shorthands for stretching and selection.
    fresh = GP(EQ(), graph=Graph())
    assert str(fresh > 2) == str(fresh.stretch(2))
    assert str(fresh[0]) == str(fresh.select(0))
コード例 #2
0
ファイル: test_graph.py プロジェクト: leekwoon/stheno
def test_stretching():
    """Test ``GP.stretch``: its string form and conditioning through it.

    Written with direct assertions instead of nose-style ``yield`` checks:
    pytest >= 4.0 no longer runs yield-based tests (they are collected as
    xfail and never executed), so the generator form silently checked nothing.
    """
    model = Graph()

    # Test construction: stretching by 1 shows up on both kernel and mean.
    p = GP(EQ(), TensorProductMean(lambda x: x ** 2), graph=model)
    assert str(p.stretch(1)) == 'GP(EQ() > 1, <lambda> > 1)'

    # Test case: the checks below expect observing ``p2 = p.stretch(5)`` at
    # ``x`` to pin down ``p`` at ``x / 5``, and vice versa at ``x * 5``.
    p = GP(EQ(), graph=model)
    p2 = p.stretch(5)

    n = 5
    x = np.linspace(0, 10, n)[:, None]
    y = p2(x).sample()

    # Condition ``p`` on observations of the stretched process.
    post = p.condition(p2(x), y)
    assert_allclose(post(x / 5).mean, y)
    # The posterior variance at the corresponding points must (numerically)
    # vanish: the observations are noiseless.
    assert abs_err(B.diag(post(x / 5).var)) <= 1e-10

    # Condition the stretched process on observations of ``p``.
    post = p2.condition(p(x), y)
    assert_allclose(post(x * 5).mean, y)
    assert abs_err(B.diag(post(x * 5).var)) <= 1e-10
コード例 #3
0
ファイル: test_graph.py プロジェクト: leekwoon/stheno
def test_shorthands():
    """Test the input shorthand ``p(x)`` and the ``>`` / ``[]`` operators.

    Written with direct checks instead of nose-style ``yield`` checks:
    pytest >= 4.0 no longer runs yield-based tests (they are collected as
    xfail and never executed), so the generator form silently checked nothing.
    """
    model = Graph()
    p = GP(EQ(), graph=model)

    # Construct a normal distribution that serves as an input.
    x = p(1)
    assert isinstance(x, At)
    assert type_parameter(x) is p
    assert x.get() == 1
    assert str(p(x)) == '{}({})'.format(str(p), str(x))
    assert repr(p(x)) == '{}({})'.format(repr(p), repr(x))

    # Construct a normal distribution that does not serve as an input; the
    # input-only operations must raise. ``raises`` is called with exactly the
    # arguments the yield form handed to the test runner.
    x = Normal(np.ones((1, 1)))
    raises(RuntimeError, lambda: type_parameter(x))
    raises(RuntimeError, lambda: x.get())
    raises(RuntimeError, lambda: p | (x, 1))

    # Test shorthands for stretching and selection.
    p = GP(EQ(), graph=Graph())
    assert str(p > 2) == str(p.stretch(2))
    assert str(p[0]) == str(p.select(0))