コード例 #1
0
def test_input_transform():
    """Yield checks for kernels whose inputs are transformed."""
    # NOTE: `k` keeps its name throughout because the lambdas below look it
    # up lazily when they are called.
    k = Linear().transform(lambda x, c: x - 5)

    # Expected properties: not stationary; the others raise on access.
    yield eq, k.stationary, False
    yield raises, RuntimeError, lambda: k.length_scale
    yield raises, RuntimeError, lambda: k.var
    yield raises, RuntimeError, lambda: k.period

    def identity(x):
        return x

    def square(x):
        return x**2

    # Equality must take both the transform and the wrapped kernel into
    # account.
    yield eq, EQ().transform(identity), EQ().transform(identity)
    yield neq, EQ().transform(identity), EQ().transform(square)
    yield neq, EQ().transform(identity), Matern12().transform(identity)

    # Generic kernel checks.
    for check in kernel_generator(k):
        yield check

    # Compare against applying the transforms to the inputs by hand.
    k = Linear()
    x1 = np.random.randn(10, 2)
    x2 = np.random.randn(10, 2)

    squared_both = k.transform(lambda x, c: x**2)
    squared_shifted = k.transform(lambda x, c: x**2, lambda x, c: x - 5)

    yield assert_allclose, k(x1**2, x2**2), squared_both(x1, x2)
    yield assert_allclose, k(x1**2, x2 - 5), squared_shifted(x1, x2)
コード例 #2
0
def test_shifted():
    """Check stationarity, equality and evaluation of shifted kernels."""
    same_shift = ShiftedKernel(2 * EQ(), 5)

    # With one shift shared by both inputs the kernel is expected to remain
    # stationary.
    assert same_shift.stationary

    # Equality depends on the shift as well as on the wrapped kernel.
    assert Linear().shift(2) == Linear().shift(2)
    assert Linear().shift(2) != Linear().shift(3)
    assert Linear().shift(2) != DecayingKernel(1, 1).shift(2)

    # Generic kernel checks.
    standard_kernel_tests(same_shift)

    # With a different shift per input, stationarity is expected to be lost.
    different_shifts = (2 * EQ()).shift(5, 6)
    assert not different_shifts.stationary

    # Shifting the kernel should match shifting the inputs by hand.
    # NOTE(review): the result of `allclose` is not asserted here;
    # presumably the helper raises on mismatch -- confirm.
    x1 = np.random.randn(10, 2)
    x2 = np.random.randn(5, 2)
    linear = Linear()
    allclose(linear.shift(5)(x1, x2), linear(x1 - 5, x2 - 5))

    # An array-valued shift must be accepted without raising.
    vector_shift = Linear().shift(np.array([1, 2]))
    vector_shift(np.random.randn(10, 2))
コード例 #3
0
ファイル: test_graph.py プロジェクト: leekwoon/stheno
def test_approximate_multiplication():
    """Yield checks for the approximate product of two GPs on one graph.

    The product process is conditioned on samples of both factors, then one
    factor is recovered from the other factor and the product. Finally,
    multiplying processes from different graphs is expected to raise
    ``RuntimeError``.
    """
    model = Graph()

    # Construct model.
    p1 = GP(EQ(), 20, graph=model)
    p2 = GP(EQ(), 20, graph=model)
    p_prod = p1 * p2
    x = np.linspace(0, 10, 50)

    # Sample functions.
    s1, s2 = model.sample(p1(x), p2(x))

    # Infer product.
    post = p_prod.condition((p1(x), s1), (p2(x), s2))
    yield le, rel_err(post(x).mean, s1 * s2), 1e-2

    # Perform division with a temporarily overridden `B.epsilon`. The
    # `finally` clause restores the previous value even if conditioning
    # raises or the generator is closed early, so the override cannot leak
    # into other tests.
    cur_epsilon = B.epsilon
    B.epsilon = 1e-8
    try:
        post = p2.condition((p1(x), s1), (p_prod(x), s1 * s2))
        yield le, rel_err(post(x).mean, s2), 1e-2
    finally:
        B.epsilon = cur_epsilon

    # Check graph check: processes from different graphs cannot be combined.
    model2 = Graph()
    p3 = GP(EQ(), graph=model2)
    yield raises, RuntimeError, lambda: p3 * p1
コード例 #4
0
def test_corner_cases():
    """Check error handling and naming corner cases of kernel elements."""
    # Out-of-range indexing must raise.
    with pytest.raises(IndexError):
        EQ().stretch(1)[1]
    with pytest.raises(IndexError):
        (EQ() + RQ(1))[2]

    # Every algebraic operation is expected to reject plain numbers.
    operations = (mul, add, stretch, transform, differentiate, select, shift)
    for operation in operations:
        with pytest.raises(RuntimeError):
            operation(1, 1)

    # Representation and naming.
    assert repr(EQ()) == str(EQ())
    assert EQ().__name__ == 'EQ'

    # The generic wrapper and join elements are expected to leave `display`
    # unimplemented.
    with pytest.raises(NotImplementedError):
        WrappedElement(1).display(1, lambda x: x)
    with pytest.raises(NotImplementedError):
        JoinElement(1, 2).display(1, 2, lambda x: x)
コード例 #5
0
ファイル: test_graph.py プロジェクト: cj401/stheno
def test_properties():
    """Check the kernel properties reported by sums of GPs."""
    model = Graph()

    unit = GP(EQ(), graph=model)
    stretched = GP(EQ().stretch(2), graph=model)
    periodic = GP(EQ().periodic(10), graph=model)

    # Sum of two non-periodic stationary processes.
    total = unit + 2 * stretched

    expected_length_scale = (1 + 2 * 2**2) / (1 + 2**2)
    assert total.stationary == True, 'stationary'
    assert total.var == 1 + 2**2, 'var'
    # NOTE(review): the result of `allclose` is not asserted here;
    # presumably the helper raises on mismatch -- confirm.
    allclose(total.length_scale, expected_length_scale)
    assert total.period == np.inf, 'period'

    assert periodic.period == 10, 'period'

    # Adding a periodic process.
    total = periodic + total

    assert total.stationary == True, 'stationary 2'
    assert total.var == 1 + 2**2 + 1, 'var 2'
    assert total.period == np.inf, 'period 2'

    # Adding a process with a linear kernel.
    total = total + GP(Linear(), graph=model)

    assert total.stationary == False, 'stationary 3'
コード例 #6
0
def test_input_transform():
    """Check the string form of input-transformed kernels and means."""
    single = EQ().transform(lambda x: x)
    assert str(single) == 'EQ() transform <lambda>'

    double = EQ().transform(lambda x: x, lambda x: x)
    assert str(double) == 'EQ() transform (<lambda>, <lambda>)'

    # Constant elements are expected to print unchanged after a transform.
    zero_kernel = ZeroKernel().transform(lambda x: x)
    assert str(zero_kernel) == '0'
    one_mean = OneMean().transform(lambda x: x)
    assert str(one_mean) == '1'
コード例 #7
0
def test_input_transform():
    """Check properties, equality and values of input-transformed kernels."""
    shifted_input = Linear().transform(lambda x: x - 5)

    # The transformed kernel is expected to be non-stationary.
    assert not shifted_input.stationary

    def identity(x):
        return x

    def square(x):
        return x**2

    # Equality must take both the transform and the wrapped kernel into
    # account.
    assert EQ().transform(identity) == EQ().transform(identity)
    assert EQ().transform(identity) != EQ().transform(square)
    assert EQ().transform(identity) != Matern12().transform(identity)

    # Generic kernel checks.
    standard_kernel_tests(shifted_input)

    # Compare against applying the transforms to the inputs by hand.
    linear = Linear()
    x1 = B.randn(10, 2)
    x2 = B.randn(10, 2)

    squared_both = linear.transform(lambda x: x**2)
    squared_shifted = linear.transform(lambda x: x**2, lambda x: x - 5)

    approx(linear(x1**2, x2**2), squared_both(x1, x2))
    approx(linear(x1**2, x2 - 5), squared_shifted(x1, x2))
コード例 #8
0
ファイル: test_kernel.py プロジェクト: wubizhi/stheno
def test_periodic():
    """Check properties, equality and evaluation of periodic kernels."""
    base = EQ().stretch(2).periodic(3)

    assert base.stationary
    assert base.length_scale == 2
    assert base.period == 3
    assert base.var == 1

    # Equality depends on the period and on the wrapped kernel.
    assert EQ().periodic(2) == EQ().periodic(2)
    assert EQ().periodic(2) != EQ().periodic(3)
    assert Matern12().periodic(2) != EQ().periodic(2)

    # Generic kernel checks.
    standard_kernel_tests(base)

    # After stretching by 5 and scaling by 5, the length scale and period
    # are expected to be five times larger and the variance to be 5.
    scaled = 5 * base.stretch(5)

    assert scaled.stationary
    assert scaled.length_scale == 10
    assert scaled.period == 15
    assert scaled.var == 5

    # An array-valued period must be accepted without raising.
    vector_period = EQ().periodic(np.array([1, 2]))
    vector_period(np.random.randn(10, 2))
コード例 #9
0
def test_periodic():
    """Yield checks for properties, equality and values of periodic kernels."""
    base = EQ().stretch(2).periodic(3)

    yield eq, base.stationary, True
    yield eq, base.length_scale, 2
    yield eq, base.period, 3
    yield eq, base.var, 1

    # Equality depends on the period and on the wrapped kernel.
    yield eq, EQ().periodic(2), EQ().periodic(2)
    yield neq, EQ().periodic(2), EQ().periodic(3)
    yield neq, Matern12().periodic(2), EQ().periodic(2)

    # Generic kernel checks.
    for check in kernel_generator(base):
        yield check

    # After stretching by 5 and scaling by 5, the length scale and period
    # are expected to be five times larger and the variance to be 5.
    scaled = 5 * base.stretch(5)

    yield eq, scaled.stationary, True
    yield eq, scaled.length_scale, 10
    yield eq, scaled.period, 15
    yield eq, scaled.var, 5

    # A list-valued period must be accepted: the kernel itself is yielded as
    # the callable to run on random inputs.
    list_period = EQ().periodic([1, 2])
    yield list_period, np.random.randn(10, 2)
コード例 #10
0
ファイル: test_graph.py プロジェクト: leekwoon/stheno
def test_properties():
    """Yield checks for the kernel properties reported by sums of GPs."""
    model = Graph()

    unit = GP(EQ(), graph=model)
    stretched = GP(EQ().stretch(2), graph=model)
    periodic = GP(EQ().periodic(10), graph=model)

    # Sum of two non-periodic stationary processes.
    total = unit + 2 * stretched

    yield eq, total.stationary, True, 'stationary'
    yield eq, total.var, 1 + 2 ** 2, 'var'
    expected_length_scale = (1 + 2 * 2 ** 2) / (1 + 2 ** 2)
    yield assert_allclose, total.length_scale, expected_length_scale
    yield eq, total.period, np.inf, 'period'

    yield eq, periodic.period, 10, 'period'

    # Adding a periodic process.
    total = periodic + total

    yield eq, total.stationary, True, 'stationary 2'
    yield eq, total.var, 1 + 2 ** 2 + 1, 'var 2'
    yield eq, total.period, np.inf, 'period 2'

    # Adding a process with a linear kernel.
    total = total + GP(Linear(), graph=model)

    yield eq, total.stationary, False, 'stationary 3'
コード例 #11
0
ファイル: test_field.py プロジェクト: leekwoon/stheno
def test_input_transform():
    """Yield checks for the string form of input-transformed elements."""
    single = EQ().transform(lambda x: x)
    yield eq, str(single), 'EQ() transform <lambda>'

    double = EQ().transform(lambda x: x, lambda x: x)
    yield eq, str(double), 'EQ() transform (<lambda>, <lambda>)'

    # Constant elements are expected to print unchanged after a transform.
    zero_kernel = ZeroKernel().transform(lambda x: x)
    yield eq, str(zero_kernel), '0'
    one_mean = OneMean().transform(lambda x: x)
    yield eq, str(one_mean), '1'
コード例 #12
0
ファイル: test_eis.py プロジェクト: mhavasi/stheno
def test_compare_noisy_kernel_and_additive_component_kernel():
    """Yield checks comparing `NoisyKernel` with an
    `AdditiveComponentKernel` built from the same two kernels."""

    def noisy_kernel_copy(latent_kernel, noise_kernel):
        # Two named components, of which only 'wiggly' is marked as latent.
        components = {
            Component('wiggly'): latent_kernel,
            Component('noise'): noise_kernel
        }
        return AdditiveComponentKernel(components,
                                       latent=[Component('wiggly')])

    x = np.random.randn(10, 2)
    k1 = EQ()
    k2 = RQ(1e-2)
    kn1 = NoisyKernel(k1, k2)
    kn2 = noisy_kernel_copy(k1, k2)

    # Both constructions must agree on every latent/observed combination.
    labelled_kernels = (
        (kn1, ('noisy kernel 1', 'noisy kernel 2',
               'noisy kernel 3', 'noisy kernel 4')),
        (kn2, ('noisy copy 1', 'noisy copy 2',
               'noisy copy 3', 'noisy copy 4')),
    )
    for kernel, labels in labelled_kernels:
        yield ok, allclose(kernel(Latent(x)), k1(x)), labels[0]
        yield ok, allclose(kernel(Latent(x), Observed(x)), k1(x)), labels[1]
        yield ok, allclose(kernel(Observed(x), Latent(x)), k1(x)), labels[2]
        yield ok, allclose(kernel(Observed(x)), (k1 + k2)(x)), labels[3]

    # Cross terms that involve the noise component directly.
    noise = Component('noise')
    yield ok, allclose(kn2(Latent(x), noise(x)), ZeroKernel()(x)), \
        'additive 1'
    yield ok, allclose(kn2(noise(x), Latent(x)), ZeroKernel()(x)), \
        'additive 2'
    yield ok, allclose(kn2(Observed(x), noise(x)), k2(x)), 'additive 3'
    yield ok, allclose(kn2(noise(x), Observed(x)), k2(x)), 'additive 4'

    # String representation.
    yield eq, str(NoisyKernel(EQ(), RQ(1))), 'NoisyKernel(EQ(), RQ(1))'
コード例 #13
0
ファイル: test_field.py プロジェクト: leekwoon/stheno
def test_parentheses():
    """Yield a check that nested kernel expressions are parenthesised
    as expected when printed."""
    left = reversed(Linear() * Linear() +
                    2 * EQ().stretch(1).periodic(2) +
                    RQ(3).periodic(4))
    right = EQ().stretch(1) + EQ()
    expected = ('(Reversed(Linear()) * Reversed(Linear()) + '
                '2 * ((EQ() > 1) per 2) + RQ(3) per 4) * (EQ() > 1 + EQ())')
    yield eq, str(left * right), expected
コード例 #14
0
ファイル: test_kernel.py プロジェクト: wubizhi/stheno
def test_input_transform():
    """Check properties, equality and values of input-transformed kernels."""
    shifted_input = Linear().transform(lambda x: x - 5)

    # Expected properties: not stationary; the others raise on access.
    assert not shifted_input.stationary
    for property_name in ('length_scale', 'var', 'period'):
        with pytest.raises(RuntimeError):
            getattr(shifted_input, property_name)

    def identity(x):
        return x

    def square(x):
        return x ** 2

    # Equality must take both the transform and the wrapped kernel into
    # account.
    assert EQ().transform(identity) == EQ().transform(identity)
    assert EQ().transform(identity) != EQ().transform(square)
    assert EQ().transform(identity) != Matern12().transform(identity)

    # Generic kernel checks.
    standard_kernel_tests(shifted_input)

    # Compare against applying the transforms to the inputs by hand.
    # NOTE(review): the results of `allclose` are not asserted here;
    # presumably the helper raises on mismatch -- confirm.
    linear = Linear()
    x1 = np.random.randn(10, 2)
    x2 = np.random.randn(10, 2)

    squared_both = linear.transform(lambda x: x ** 2)
    squared_shifted = linear.transform(lambda x: x ** 2, lambda x: x - 5)

    allclose(linear(x1 ** 2, x2 ** 2), squared_both(x1, x2))
    allclose(linear(x1 ** 2, x2 - 5), squared_shifted(x1, x2))
コード例 #15
0
ファイル: test_graph.py プロジェクト: wangwanwang56/stheno
def test_shorthands():
    """Check the `p(x)` input shorthand and the `>` / `[]` GP operators."""
    model = Graph()
    p = GP(EQ(), graph=model)

    # Calling the process on an input is expected to give an `At` whose type
    # parameter is the process and which wraps the given value.
    at_one = p(1)
    assert isinstance(at_one, At)
    assert type_parameter(at_one) is p
    assert at_one.get() == 1
    assert str(p(at_one)) == '{}({})'.format(str(p), str(at_one))
    assert repr(p(at_one)) == '{}({})'.format(repr(p), repr(at_one))

    # A plain `Normal` is not such an input, so these operations must raise.
    plain = Normal(np.ones((1, 1)))
    with pytest.raises(RuntimeError):
        type_parameter(plain)
    with pytest.raises(RuntimeError):
        plain.get()
    with pytest.raises(RuntimeError):
        p | (plain, 1)

    # `>` and indexing should print like explicit stretching and selection.
    fresh = GP(EQ(), graph=Graph())
    assert str(fresh > 2) == str(fresh.stretch(2))
    assert str(fresh[0]) == str(fresh.select(0))
コード例 #16
0
def test_shifted():
    """Check stationarity, equality and evaluation of shifted kernels."""
    same_shift = ShiftedKernel(2 * EQ(), 5)

    # With one shift shared by both inputs the kernel is expected to remain
    # stationary.
    assert same_shift.stationary

    # Equality depends on the shift as well as on the wrapped kernel.
    assert Linear().shift(2) == Linear().shift(2)
    assert Linear().shift(2) != Linear().shift(3)
    assert Linear().shift(2) != DecayingKernel(1, 1).shift(2)

    # Generic kernel checks.
    standard_kernel_tests(same_shift)

    # With a different shift per input, stationarity is expected to be lost.
    different_shifts = (2 * EQ()).shift(5, 6)
    assert not different_shifts.stationary

    # Shifting the kernel should match shifting the inputs by hand.
    x1 = B.randn(10, 2)
    x2 = B.randn(5, 2)
    linear = Linear()
    approx(linear.shift(5)(x1, x2), linear(x1 - 5, x2 - 5))

    # An array-valued shift must be accepted without raising.
    vector_shift = Linear().shift(np.array([1, 2]))
    vector_shift(B.randn(10, 2))
コード例 #17
0
ファイル: test_kernel.py プロジェクト: wubizhi/stheno
def test_shifted():
    """Check properties, equality and evaluation of shifted kernels."""
    same_shift = ShiftedKernel(2 * EQ(), 5)

    # With one shift shared by both inputs, all properties are available.
    assert same_shift.stationary
    assert same_shift.length_scale == 1
    assert same_shift.period == np.inf
    assert same_shift.var == 2

    # Equality depends on the shift as well as on the wrapped kernel.
    assert Linear().shift(2) == Linear().shift(2)
    assert Linear().shift(2) != Linear().shift(3)
    assert Linear().shift(2) != DecayingKernel(1, 1).shift(2)

    # Generic kernel checks.
    standard_kernel_tests(same_shift)

    # With a different shift per input, the kernel is expected to be
    # non-stationary and its length scale to raise on access.
    different_shifts = (2 * EQ()).shift(5, 6)

    assert not different_shifts.stationary
    with pytest.raises(RuntimeError):
        different_shifts.length_scale
    assert different_shifts.period == np.inf
    assert different_shifts.var == 2

    # Shifting the kernel should match shifting the inputs by hand.
    # NOTE(review): the result of `allclose` is not asserted here;
    # presumably the helper raises on mismatch -- confirm.
    x1 = np.random.randn(10, 2)
    x2 = np.random.randn(5, 2)
    linear = Linear()
    allclose(linear.shift(5)(x1, x2), linear(x1 - 5, x2 - 5))

    # An array-valued shift must be accepted without raising.
    vector_shift = Linear().shift(np.array([1, 2]))
    vector_shift(np.random.randn(10, 2))
コード例 #18
0
def test_reversal():
    """Check evaluation, properties and equality of reversed kernels."""
    x1 = np.random.randn(10, 2)
    x2 = np.random.randn(5, 2)
    x3 = np.random.randn()

    def compare(kernel, flip):
        # Compare `kernel` with `flip(kernel)` on one input, a scalar input
        # and two inputs (same order, and swapped with a transpose).
        allclose(kernel(x1), flip(kernel)(x1))
        allclose(kernel(x3), flip(kernel)(x3))
        allclose(kernel(x1, x2), flip(kernel)(x1, x2))
        allclose(kernel(x1, x2), flip(kernel)(x2, x1).T)

    # Run the comparison for a single and a double reversal, on two
    # different kernels.
    for candidate in [EQ(), Linear()]:
        compare(candidate, reversed)
        compare(candidate, lambda kernel: reversed(reversed(kernel)))

    # Stationarity is expected to carry over to the reversed kernel.
    reversed_eq = reversed(EQ())
    assert reversed_eq.stationary

    reversed_linear = reversed(Linear())
    assert not reversed_linear.stationary
    assert str(reversed_linear) == 'Reversed(Linear())'

    # Equality depends on the reversal and on the wrapped kernel.
    assert reversed(Linear()) == reversed(Linear())
    assert reversed(Linear()) != Linear()
    assert reversed(Linear()) != reversed(EQ())
    assert reversed(Linear()) != reversed(DecayingKernel(1, 1))

    # Generic kernel checks.
    standard_kernel_tests(reversed_linear)
コード例 #19
0
def test_shifted():
    """Yield checks for properties, equality and values of shifted kernels."""
    same_shift = ShiftedKernel(2 * EQ(), 5)

    # With one shift shared by both inputs, all properties are available.
    yield eq, same_shift.stationary, True
    yield eq, same_shift.length_scale, 1
    yield eq, same_shift.period, np.inf
    yield eq, same_shift.var, 2

    # Equality depends on the shift as well as on the wrapped kernel.
    yield eq, Linear().shift(2), Linear().shift(2)
    yield neq, Linear().shift(2), Linear().shift(3)
    yield neq, Linear().shift(2), DecayingKernel(1, 1).shift(2)

    # Generic kernel checks.
    for check in kernel_generator(same_shift):
        yield check

    # NOTE: from here on the name `k` is reused on purpose, because the
    # lambda below looks it up lazily when it is called.
    k = (2 * EQ()).shift(5, 6)

    yield eq, k.stationary, False
    yield raises, RuntimeError, lambda: k.length_scale
    yield eq, k.period, np.inf
    yield eq, k.var, 2

    # Shifting the kernel should match shifting the inputs by hand.
    x1 = np.random.randn(10, 2)
    x2 = np.random.randn(5, 2)
    k = Linear()
    yield assert_allclose, k.shift(5)(x1, x2), k(x1 - 5, x2 - 5)

    # A list-valued shift must be accepted: the kernel itself is yielded as
    # the callable to run on random inputs.
    k = Linear().shift([1, 2])
    yield k, np.random.randn(10, 2)
コード例 #20
0
ファイル: test_graph.py プロジェクト: leekwoon/stheno
def test_derivative():
    """Yield checks for the derivative of a GP.

    First checks the string form of a differentiated process. Then, under
    the TensorFlow backend, conditions a derivative process on function
    values and a process on derivative values, and compares the posterior
    means with the expected closed forms.
    """
    # Test construction:
    p = GP(EQ(), TensorProductMean(lambda x: x ** 2), graph=Graph())
    yield eq, str(p.diff(1)), 'GP(d(1) EQ(), d(1) <lambda>)'

    # Test case. The backend switch is global state, so it is undone in a
    # `finally` clause: a failure below must not leave later tests running
    # on the TensorFlow backend.
    B.backend_to_tf()
    try:
        s = B.Session()
        try:
            model = Graph()
            x = np.linspace(0, 1, 100)[:, None]
            y = 2 * x

            p = GP(EQ(), graph=model)
            dp = p.diff()

            # Test conditioning on function.
            yield le, abs_err(s.run(dp.condition(p(x), y)(x).mean - 2)), 1e-3

            # Test conditioning on derivative.
            post = p.condition(
                (B.cast(0., np.float64), B.cast(0., np.float64)),
                (dp(x), y)
            )
            yield le, abs_err(s.run(post(x).mean - x ** 2)), 1e-3
        finally:
            # Always release the session, also on failure.
            s.close()
    finally:
        B.backend_to_np()
コード例 #21
0
def test_case_additive_model():
    """Check conditioning in an additive model of two GPs in one measure."""
    m = Measure()
    p1 = GP(EQ(), measure=m)
    p2 = GP(EQ(), measure=m)
    p_sum = p1 + p2

    x = B.linspace(0, 5, 10)
    y1 = p1(x).sample()
    y2 = p2(x).sample()

    # The two summands are expected to have zero cross-kernels.
    assert m.kernels[p2, p1] == ZeroKernel()
    assert m.kernels[p1, p2] == ZeroKernel()

    def check(first, second, target, expected):
        # Condition on `first`, then on `second`, and compare the posterior
        # mean of `target` at `x` with `expected`.
        post = (m | first) | second
        approx(post(target)(x).mean, expected)

    # Observing both summands determines the sum, in either order.
    check((p1(x), y1), (p2(x), y2), p_sum, y1 + y2)
    check((p2(x), y2), (p1(x), y1), p_sum, y1 + y2)

    # Observing the first summand and the sum determines the second summand.
    check((p1(x), y1), (p_sum(x), y1 + y2), p2, y2)
    check((p_sum(x), y1 + y2), (p1(x), y1), p2, y2)

    # Observing the second summand and the sum determines the first summand.
    check((p2(x), y2), (p_sum(x), y1 + y2), p1, y1)
    check((p_sum(x), y1 + y2), (p2(x), y2), p1, y1)
コード例 #22
0
def test_selection():
    """Check the string form and conditioning of dimension-selected GPs."""
    # String form of selecting one or several dimensions.
    printed = GP(TensorProductMean(lambda x: x**2), EQ())
    assert str(printed.select(1)) == "GP(<lambda> : [1], EQ() : [1])"
    assert str(printed.select(1, 2)) == "GP(<lambda> : [1, 2], EQ() : [1, 2])"

    # `p` takes the inputs directly; `p2` only looks at their first column.
    p = GP(EQ())  # 1D
    p2 = p.select(0)  # 2D

    # Two 2D inputs that share the first column but differ in the second.
    x = B.linspace(0, 5, 10)
    x21 = B.stack(x, B.randn(10), axis=1)
    x22 = B.stack(x, B.randn(10), axis=1)
    y = p2(x).sample()

    # Conditioning `p2` on either 2D input should pin down `p` at `x`.
    for x2d in (x21, x22):
        post = p.measure | (p2(x2d), y)
        approx(post(p(x)).mean, y)
        assert_equal_normals(post(p(x)), post(p2(x2d)))

    # Conversely, conditioning `p` at `x` should pin down `p2` at both.
    post = p.measure | (p(x), y)
    for x2d in (x21, x22):
        approx(post(p2(x2d)).mean, y)
    for x2d in (x21, x22):
        assert_equal_normals(post(p2(x2d)), post(p(x)))
コード例 #23
0
ファイル: test_graph.py プロジェクト: leekwoon/stheno
def test_selection():
    """Yield checks for the string form and conditioning of selected GPs."""
    model = Graph()

    # String form of selecting one or several dimensions.
    printed = GP(EQ(), TensorProductMean(lambda x: x ** 2), graph=model)
    yield eq, str(printed.select(1)), 'GP(EQ() : [1], <lambda> : [1])'
    yield eq, str(printed.select(1, 2)), \
        'GP(EQ() : [1, 2], <lambda> : [1, 2])'

    # `p` takes the inputs directly; `p2` only looks at their first column.
    p = GP(EQ(), graph=model)  # 1D
    p2 = p.select(0)  # 2D

    # Two 2D inputs that share the first column but differ in the second.
    n = 5
    x = np.linspace(0, 10, n)[:, None]
    x1 = np.concatenate((x, np.random.randn(n, 1)), axis=1)
    x2 = np.concatenate((x, np.random.randn(n, 1)), axis=1)
    y = p2(x).sample()

    # Conditioning on `p2` at either 2D input should pin down `p` at `x`.
    for x2d in (x1, x2):
        post = p.condition(p2(x2d), y)
        yield assert_allclose, post(x).mean, y
        yield le, abs_err(B.diag(post(x).var)), 1e-10

    # Conversely, conditioning on `p` at `x` should pin down `p2` at both.
    post = p2.condition(p(x), y)
    for x2d in (x1, x2):
        yield assert_allclose, post(x2d).mean, y
    for x2d in (x1, x2):
        yield le, abs_err(B.diag(post(x2d).var)), 1e-10
コード例 #24
0
ファイル: test_kernel.py プロジェクト: wubizhi/stheno
def test_nested_derivatives(dtype):
    """Check that twice-nested `EQ` derivatives are not NaN at entry [0, 0].

    Args:
        dtype: Data type passed on to `B.randn` for the random inputs.
    """
    x = B.randn(dtype, 10, 2)

    # Differentiate twice with respect to the same input dimension, for each
    # of the two input dimensions in turn.
    for dim in (0, 1):
        res = EQ().diff(dim, dim).diff(dim, dim)(x)
        # NOTE(review): `~` only acts as logical negation on array-like
        # booleans; on a built-in `bool` it yields a truthy integer --
        # confirm that `B.isnan` never returns a plain `bool` here.
        assert ~B.isnan(res[0, 0])
コード例 #25
0
ファイル: test_field.py プロジェクト: leekwoon/stheno
def test_derivative():
    """Yield checks for the string form of differentiated elements."""
    yield eq, str(EQ().diff(0)), 'd(0) EQ()'
    yield eq, str(EQ().diff(0, 1)), 'd(0, 1) EQ()'

    # Derivatives of constant kernels and means are expected to print as '0'.
    for constant_element in (ZeroKernel, OneKernel, ZeroMean, OneMean):
        yield eq, str(constant_element().diff(0)), '0'
コード例 #26
0
def test_posterior_kernel():
    """Check the properties of `PosteriorKernel` and run the generic checks."""
    # NOTE(review): the meaning of the constructor arguments is not visible
    # here; presumably three kernels, input locations and a kernel matrix --
    # confirm against `PosteriorKernel`.
    locations = B.randn(5, 2)
    matrix = EQ()(B.randn(5, 1))
    posterior = PosteriorKernel(EQ(), EQ(), EQ(), locations, matrix)

    # Expected properties: not stationary, with a fixed string form.
    assert not posterior.stationary
    assert str(posterior) == "PosteriorKernel()"

    # Generic kernel checks, restricted to one pair of input shapes.
    standard_kernel_tests(posterior, shapes=[((10, 2), (5, 2))])
コード例 #27
0
def test_derivative():
    """Check the string form of differentiated kernels and means."""
    assert str(EQ().diff(0)) == 'd(0) EQ()'
    assert str(EQ().diff(0, 1)) == 'd(0, 1) EQ()'

    # Derivatives of constant kernels and means are expected to print as '0'.
    for constant_element in (ZeroKernel, OneKernel, ZeroMean, OneMean):
        assert str(constant_element().diff(0)) == '0'
コード例 #28
0
ファイル: test_graph.py プロジェクト: mhavasi/stheno
def test_observations_and_conditioning():
    """Yield checks that the different ways of building observations and of
    conditioning on them give the same results."""
    model = Graph()
    p1 = GP(EQ(), graph=model)
    p2 = GP(EQ(), graph=model)
    p = p1 + p2
    x = np.linspace(0, 5, 10)
    y = p(x).sample()
    y1 = p1(x).sample()

    # A single observation, given as `p(x)` or as `x` with `ref=p`.
    obs1 = Obs(p(x), y)
    obs2 = Obs(x, y, ref=p)
    yield assert_allclose, obs1.y, obs2.y
    yield assert_allclose, obs1.K_x, obs2.K_x

    # Two observations, with the same two spellings for the first one.
    obs3 = Obs((p(x), y), (p1(x), y1))
    obs4 = Obs((x, y), (p1(x), y1), ref=p)
    yield assert_allclose, obs3.y, obs4.y
    yield assert_allclose, obs3.K_x, obs4.K_x

    def same_mean_and_var(reference, *others):
        # Yield mean and variance comparisons of each of `others` against
        # `reference`.
        for other in others:
            yield assert_allclose, reference.mean, other.mean
            yield assert_allclose, reference.var, other.var

    # Every way of conditioning on the single observation.
    single_observation = [
        p.condition(x, y)(x),
        p.condition(p(x), y)(x),
        (p | (x, y))(x),
        (p | (p(x), y))(x),
        p.condition(obs1)(x),
        p.condition(obs2)(x),
        (p | obs1)(x),
        (p | obs2)(x),
    ]
    for check in same_mean_and_var(*single_observation):
        yield check

    # Every way of conditioning on the two observations.
    two_observations = [
        p.condition((x, y), (p1(x), y1))(x),
        p.condition((p(x), y), (p1(x), y1))(x),
        (p | [(x, y), (p1(x), y1)])(x),
        (p | [(p(x), y), (p1(x), y1)])(x),
        p.condition(obs3)(x),
        p.condition(obs4)(x),
        (p | obs3)(x),
        (p | obs4)(x),
    ]
    for check in same_mean_and_var(*two_observations):
        yield check

    # Conditioning several processes at once should match conditioning each
    # of them separately.
    joint_p1, joint_p2, joint_p = (p1, p2, p) | obs1
    joint = (joint_p1(x), joint_p2(x), joint_p(x))
    separate = ((p1 | obs1)(x), (p2 | obs1)(x), (p | obs1)(x))

    for together, alone in zip(joint, separate):
        yield assert_allclose, together.mean, alone.mean
        yield assert_allclose, together.var, alone.var

    # Inputs without an associated process and without `ref` must be
    # rejected.
    yield raises, ValueError, lambda: Obs(0, 0)
    yield raises, ValueError, lambda: Obs((0, 0), (0, 0))
    yield raises, ValueError, lambda: SparseObs(0, p, (0, 0))
コード例 #29
0
ファイル: test_graph.py プロジェクト: leekwoon/stheno
def test_construction():
    """Yield checks for evaluating a GP's mean and kernel and for how the
    kernel and mean arguments of `GP` are resolved."""
    model = Graph()
    p = GP(EQ(), graph=model)

    x = np.random.randn(10, 1)
    c = Cache()

    # The mean must accept `x` or `p(x)`, optionally followed by a cache.
    # Each yielded tuple is a callable plus its arguments.
    yield p.mean, x
    yield p.mean, p(x)
    yield p.mean, x, c
    yield p.mean, p(x), c

    # The kernel and its element-wise form must accept one or two inputs,
    # each given as `x` or `p(x)`, optionally followed by a cache.
    for evaluate in (p.kernel, p.kernel.elwise):
        yield evaluate, x
        yield evaluate, p(x)
        yield evaluate, x, c
        yield evaluate, p(x), c
        yield evaluate, x, x
        yield evaluate, p(x), x
        yield evaluate, x, p(x)
        yield evaluate, p(x), p(x)
        yield evaluate, x, x, c
        yield evaluate, p(x), x, c
        yield evaluate, x, p(x), c
        yield evaluate, p(x), p(x), c

    # Test resolution of kernel and mean.
    k = EQ()
    m = TensorProductMean(lambda x: x ** 2)

    # Constructor arguments paired with the expected type of the mean.
    mean_cases = (
        ((k,), ZeroMean),
        ((k, 5), ScaledMean),
        ((k, 1), OneMean),
        ((k, 0), ZeroMean),
        ((k, m), TensorProductMean),
    )
    for args, expected_type in mean_cases:
        yield assert_instance, GP(*args, graph=model).mean, expected_type

    # Kernel argument paired with the expected type of the kernel.
    kernel_cases = (
        (k, EQ),
        (5, ScaledKernel),
        (1, OneKernel),
        (0, ZeroKernel),
    )
    for kernel_arg, expected_type in kernel_cases:
        yield assert_instance, GP(kernel_arg, graph=model).kernel, \
            expected_type

    # The finite-dimensional distribution at `x` should carry the kernel
    # matrix and the mean evaluated at `x`.
    at_x = GP(k, m, graph=model)(x)
    yield assert_allclose, at_x.var, k(x)
    yield assert_allclose, at_x.mean, m(x)
コード例 #30
0
ファイル: test_graph.py プロジェクト: leekwoon/stheno
def test_corner_cases():
    """Yield checks for invalid GP arithmetic and the GP string form."""
    first_graph = Graph()
    second_graph = Graph()
    on_first = GP(EQ(), graph=first_graph)
    on_second = GP(EQ(), graph=second_graph)
    inputs = np.random.randn(10, 2)

    # Processes on different graphs cannot be added.
    yield raises, RuntimeError, lambda: on_first + on_second

    # A process cannot be combined with its own evaluation at inputs.
    yield raises, NotImplementedError, lambda: on_first + on_first(inputs)
    yield raises, NotImplementedError, lambda: on_first * on_first(inputs)

    # String form: kernel followed by the mean.
    yield eq, str(GP(EQ(), graph=first_graph)), 'GP(EQ(), 0)'