コード例 #1
def test_input_transform():
    """Check input-transformed kernels: properties, equality and values."""

    def identity(x):
        return x

    def square(x):
        return x**2

    def shift_down(x):
        return x - 5

    # A transformed `Linear` kernel is not stationary.
    assert not Linear().transform(shift_down).stationary

    # Equality requires both the same base kernel and the same transform.
    assert EQ().transform(identity) == EQ().transform(identity)
    assert EQ().transform(identity) != EQ().transform(square)
    assert EQ().transform(identity) != Matern12().transform(identity)

    # Transforming the kernel must match transforming the inputs directly.
    base = Linear()
    x1 = B.randn(10, 2)
    x2 = B.randn(10, 2)
    # A single transform is applied to both inputs.
    approx(base(x1**2, x2**2), base.transform(square)(x1, x2))
    # Two transforms are applied to the first and second input respectively.
    approx(base(x1**2, x2 - 5), base.transform(square, shift_down)(x1, x2))
コード例 #2
ファイル: test_kernel.py プロジェクト: wubizhi/stheno
def test_fixed_delta():
    """Check `FixedDelta`: properties, equality and evaluation."""
    noises = B.rand(3)
    k = FixedDelta(noises)

    # Properties.
    assert k.stationary
    assert k.var == 1
    assert k.length_scale == 0
    assert k.period == np.inf
    assert str(k) == 'FixedDelta()'

    # Equality is determined by the noise vector.
    assert FixedDelta(noises) == FixedDelta(noises)
    assert FixedDelta(noises) != FixedDelta(2 * noises)
    assert FixedDelta(noises) != EQ()

    def check(n, expected_self, expected_self_elwise):
        """Evaluate `k` on `n` random points and compare with expectations."""
        first = B.randn(n)
        second = B.randn(n)
        allclose(k(first), expected_self)
        allclose(k.elwise(first), expected_self_elwise)
        # Cross evaluations between two independent draws are expected to be
        # zero.
        allclose(k(first, second), B.zeros(n, n))
        allclose(k.elwise(first, second), B.zeros(n, 1))

    # With 5 points (not matching the 3 noises) every evaluation is zero.
    check(5, B.zeros(5, 5), B.zeros(5, 1))
    # With 3 points the noises appear on the diagonal.
    check(3, B.diag(noises), B.uprank(noises))
コード例 #3
ファイル: test_kernel.py プロジェクト: wubizhi/stheno
def standard_kernel_tests(k, shapes=None, dtype=np.float64):
    """Run generic consistency checks on the kernel `k`.

    For every pair of input shapes, verify that `k` is consistent with its
    reversal, and that `k.elwise` agrees with both the diagonal of the kernel
    matrix and the base-class `Kernel.elwise` implementation.

    Args:
        k: Kernel under test.
        shapes (list, optional): Pairs `(shape1, shape2)` of input shapes.
            Defaults to a mix of two-dimensional, one-dimensional and scalar
            shapes.
        dtype (dtype, optional): Data type of the random inputs. Defaults to
            `np.float64`.
    """
    if shapes is None:
        shapes = [
            ((10, 2), (5, 2)),
            ((10, 1), (5, 1)),
            ((10,), (5,)),
            ((10,), ()),
            ((), (5,)),
            ((), ()),
        ]

    for left_shape, right_shape in shapes:
        left = B.randn(dtype, *left_shape)
        right = B.randn(dtype, *right_shape)

        # Swapping the arguments of the reversed kernel must give the
        # transpose of the original kernel matrix.
        allclose(k(left, right), reversed(k)(right, left).T)

        # Redraw `right` with the shape of `left`, so that `k(left, right)` is
        # square and its diagonal lines up with `k.elwise(left, right)`.
        right = B.randn(dtype, *left_shape)
        allclose(k.elwise(left, right)[:, 0], B.diag(k(left, right)))
        allclose(k.elwise(left, right), Kernel.elwise(k, left, right))

        # The element-wise computation is more accurate, so allow a
        # discrepancy a bit larger than the square root of machine epsilon.
        allclose(k.elwise(left)[:, 0], B.diag(k(left)),
                 desc='', atol=1e-6, rtol=1e-6)
        allclose(k.elwise(left), Kernel.elwise(k, left))
コード例 #4
コード例 #5
def test_reversal():
    """Check reversed kernels: evaluation, properties and equality."""
    x1 = B.randn(10, 2)
    x2 = B.randn(5, 2)
    x3 = B.randn()

    # For a stationary (EQ) and a non-stationary (Linear) kernel, both the
    # reversed and the doubly reversed kernel must reproduce the original.
    for k in [EQ(), Linear()]:
        for variant in (reversed(k), reversed(reversed(k))):
            approx(k(x1), variant(x1))
            approx(k(x3), variant(x3))
            approx(k(x1, x2), variant(x1, x2))
            approx(k(x1, x2), variant(x2, x1).T)

    # Reversal keeps stationarity as it was.
    assert reversed(EQ()).stationary
    k = reversed(Linear())
    assert not k.stationary
    assert str(k) == "Reversed(Linear())"

    # Equality depends on the wrapped kernel and on the reversal itself.
    assert reversed(Linear()) == reversed(Linear())
    for other in (Linear(), reversed(EQ()), reversed(DecayingKernel(1, 1))):
        assert reversed(Linear()) != other
コード例 #6
def test_shifted():
    """Check shifted kernels: properties, equality and evaluation."""
    # A single shift keeps the kernel stationary.
    assert ShiftedKernel(2 * EQ(), 5).stationary

    # Equality depends on both the shift and the base kernel.
    assert Linear().shift(2) == Linear().shift(2)
    assert Linear().shift(2) != Linear().shift(3)
    assert Linear().shift(2) != DecayingKernel(1, 1).shift(2)

    # Passing two different shifts breaks stationarity.
    assert not (2 * EQ()).shift(5, 6).stationary

    # Shifting the kernel must match shifting the inputs directly.
    a = B.randn(10, 2)
    b = B.randn(5, 2)
    base = Linear()
    approx(base.shift(5)(a, b), base(a - 5, b - 5))

    # The shift may also be a vector, here with one entry per input column.
    Linear().shift(np.array([1, 2]))(B.randn(10, 2))
コード例 #7
def test_posterior_kernel():
    """Check `PosteriorKernel`: properties and the standard kernel tests."""
    z = B.randn(5, 2)
    k_matrix = EQ()(B.randn(5, 1))
    k = PosteriorKernel(EQ(), EQ(), EQ(), z, k_matrix)

    assert not k.stationary
    assert str(k) == "PosteriorKernel()"

    standard_kernel_tests(k, shapes=[((10, 2), (5, 2))])
コード例 #8
def test_corrective_kernel():
    """Check `CorrectiveKernel`: properties and the standard kernel tests."""

    def gram(m):
        # `m m^T` is symmetric positive semi-definite.
        return m.dot(m.T)

    first = B.randn(3, 3)
    second = B.randn(3, 3)
    a, b = gram(first), gram(second)
    z = B.randn(3, 2)
    k = CorrectiveKernel(EQ(), EQ(), z, a, b)

    assert not k.stationary
    assert str(k) == "CorrectiveKernel()"

    standard_kernel_tests(k, shapes=[((10, 2), (5, 2))])
コード例 #9
def test_periodic():
    """Check periodic kernels: properties, equality and special cases."""
    k = EQ().stretch(2).periodic(3)
    assert str(k) == "(EQ() > 2) per 3"
    assert k.stationary

    # Equality depends on both the period and the base kernel.
    assert EQ().periodic(2) == EQ().periodic(2)
    assert EQ().periodic(2) != EQ().periodic(3)
    assert Matern12().periodic(2) != EQ().periodic(2)

    # Scaling and stretching a periodic kernel keeps it stationary.
    assert (5 * k.stretch(5)).stationary

    # The period may also be a vector, here with one entry per input column.
    EQ().periodic(np.array([1, 2]))(B.randn(10, 2))

    # Making the zero kernel periodic returns the very same object.
    zero = ZeroKernel()
    assert zero.periodic(3) is zero
コード例 #10
def test_selection():
    """Check kernels restricted to input dimensions via `select`."""
    assert (2 * EQ().stretch(5)).select(0).stationary

    # Equality depends on both the selected dimensions and the base kernel.
    assert EQ().select(0) == EQ().select(0)
    assert EQ().select(0) != EQ().select(1)
    assert EQ().select(0) != Matern12().select(0)

    # A single list of indices preserves stationarity...
    stationary = [
        (2 * EQ().stretch(5)).select([2, 3]),
        (2 * EQ().stretch(np.array([1, 2, 3]))).select([0, 2]),
        (2 * EQ().periodic(np.array([1, 2, 3]))).select([1, 2]),
    ]
    for kernel in stationary:
        assert kernel.stationary

    # ...but passing two different index lists breaks it.
    non_stationary = [
        (2 * EQ().stretch(np.array([1, 2, 3]))).select([0, 2], [1, 2]),
        (2 * EQ().periodic(np.array([1, 2, 3]))).select([0, 2], [1, 2]),
    ]
    for kernel in non_stationary:
        assert not kernel.stationary

    # Selecting columns must match evaluating on the sliced inputs.
    x = B.randn(10, 3)
    approx(EQ().select([1, 2])(x), EQ()(x[:, [1, 2]]))
コード例 #11
ファイル: test_kernel.py プロジェクト: wubizhi/stheno
def test_nested_derivatives(dtype):
    """Check that nested derivatives of `EQ()` do not produce NaNs.

    Args:
        dtype: Data type of the random inputs (presumably supplied by a pytest
            fixture or parametrization defined elsewhere — confirm).
    """
    x = B.randn(dtype, 10, 2)

    # Differentiate twice with respect to the same input dimension, first for
    # dimension 0 and then for dimension 1.
    for i in (0, 1):
        res = EQ().diff(i, i).diff(i, i)(x)
        # Use logical `not` rather than bitwise `~`: on a Python `bool`, `~`
        # yields -1 or -2, both truthy, so `assert ~...` could never fail.
        assert not B.isnan(res[0, 0])
コード例 #12
def test_zero():
    """Check `ZeroKernel`: evaluation, properties and equality."""
    zero = ZeroKernel()
    n1, n2 = 10, 5
    x1 = B.randn(n1, 2)
    x2 = B.randn(n2, 2)

    # Every entry of the kernel matrix is zero.
    approx(zero(x1, x2), np.zeros((n1, n2)))

    # Properties.
    assert zero.stationary
    assert str(zero) == "0"

    # Equality.
    assert ZeroKernel() == ZeroKernel()
    assert ZeroKernel() != Linear()
コード例 #13
ファイル: test_kernel.py プロジェクト: wubizhi/stheno
def test_derivative_eq():
    """Check derivatives of `EQ()` against closed-form references."""
    k = EQ()
    x1 = B.randn(tf.float64, 10, 1)
    x2 = B.randn(tf.float64, 5, 1)

    def check(args, first, second):
        """Compare derivatives of `k(*args)` with their references.

        `first` and `second` are the inputs that `k(*args)` compares.
        """
        second_t = B.transpose(second)
        value = k(*args)

        # Derivative with respect to the first input.
        allclose(k.diff(0, None)(*args), -value * (first - second_t))
        # Derivative with respect to the second input.
        allclose(k.diff(None, 0)(*args), -value * (second_t - first))
        # Derivative with respect to both inputs; `diff(0)` must agree with
        # `diff(0, 0)`.
        ref = value * (1 - (first - second_t) ** 2)
        allclose(k.diff(0, 0)(*args), ref)
        allclose(k.diff(0)(*args), ref)

    check((x1, x2), x1, x2)
    check((x1,), x1, x1)
コード例 #14
ファイル: test_kernel.py プロジェクト: wubizhi/stheno
def test_derivative_linear():
    """Check derivatives of `Linear()` against closed-form references."""
    k = Linear()
    x1 = B.randn(tf.float64, 10, 1)
    x2 = B.randn(tf.float64, 5, 1)

    def check(args, first, second, shape):
        """Compare derivatives of `k(*args)` with references of `shape`."""
        ones = B.ones(tf.float64, *shape)
        # Derivative with respect to the first input: the transposed second
        # input, broadcast over the rows.
        allclose(k.diff(0, None)(*args), ones * B.transpose(second))
        # Derivative with respect to the second input: the first input,
        # broadcast over the columns.
        allclose(k.diff(None, 0)(*args), ones * first)
        # The mixed derivative is identically one; `diff(0)` must agree with
        # `diff(0, 0)`.
        allclose(k.diff(0, 0)(*args), ones)
        allclose(k.diff(0)(*args), ones)

    check((x1, x2), x1, x2, (10, 5))
    check((x1,), x1, x1, (10, 10))
コード例 #15
def test_basic_arithmetic():
    """Check that kernel arithmetic matches arithmetic on kernel values."""
    eq = EQ()
    rq = RQ(1e-1)
    m12 = Matern12()
    m32 = Matern32()
    m52 = Matern52()
    delta = Delta()
    linear = Linear()
    matrices = B.randn(10, 2), B.randn(20, 2)
    scalars = B.randn(), B.randn()

    # Evaluating on one argument equals evaluating it against itself.
    approx(delta(matrices[0]), delta(matrices[0], matrices[0]))

    for a, b in (matrices, scalars):
        approx((eq * rq)(a, b), eq(a, b) * rq(a, b))
        approx((m12 + m32)(a, b), m12(a, b) + m32(a, b))
        approx((5.0 * m52)(a, b), 5.0 * m52(a, b))
        approx((5.0 + linear)(a, b), 5.0 + linear(a, b))
        approx(eq.stretch(2.0)(a, b), eq(a / 2.0, b / 2.0))
        # Shifting by a multiple of the period leaves the value unchanged.
        approx(eq.periodic(1.0)(a, b), eq.periodic(1.0)(a, b + 5.0))
コード例 #16
def test_stretched():
    """Check stretched kernels: properties and equality."""
    assert EQ().stretch(2).stationary

    # Equality depends on both the stretch and the base kernel.
    assert EQ().stretch(2) == EQ().stretch(2)
    assert EQ().stretch(2) != EQ().stretch(3)
    assert EQ().stretch(2) != Matern12().stretch(2)

    # Passing two stretch factors breaks stationarity.
    assert not EQ().stretch(1, 2).stationary

    # The stretch may also be a vector, here with one entry per input column.
    EQ().stretch(np.array([1, 2]))(B.randn(10, 2))
コード例 #17
        # NOTE(review): excerpt from a loop body; `k`, `x1` and `x2` are bound
        # above this excerpt.
        # Check against the fallback brute-force computation in `Kernel.elwise`.
        approx(k.elwise(x1, x2), Kernel.elwise(k, x1, x2))
        # The element-wise computation is more accurate, which is why we allow a
        # discrepancy a bit larger than the square root of the machine epsilon.
        approx(k.elwise(x1)[:, 0], B.diag(k(x1)), atol=1e-6, rtol=1e-6)
        # Single-input `elwise` must also match the fallback computation.
        approx(k.elwise(x1), Kernel.elwise(k, x1))

def test_construction(x1, x2):
    """Smoke-test that `EQ()` can be evaluated on `x1` and `x2`."""
    kernel = EQ()
    kernel(x1, x2)
    kernel.elwise(x1, x2)
    # Test `MultiInput` construction.
    # NOTE(review): no code follows the comment above in this excerpt.
コード例 #18
ファイル: test_kernel.py プロジェクト: wubizhi/stheno
    # NOTE(review): excerpt from a test body; `k` is bound above this excerpt,
    # presumably as `Delta()` given the `str(k)` check below — confirm.
    # Verify that the kernel has the right properties.
    assert k.stationary
    assert k.var == 1
    assert k.length_scale == 0
    assert k.period == np.inf
    assert str(k) == 'Delta()'

    # Check equality: a `Delta` with a different `epsilon` must compare unequal.
    assert Delta() == Delta()
    assert Delta() != Delta(epsilon=k.epsilon * 10)
    assert Delta() != EQ()

@pytest.mark.parametrize('x1_x2', [(np.array(0), np.array(1)),
                                   (B.randn(10), B.randn(5)),
                                   (B.randn(10, 1), B.randn(5, 1)),
                                   (B.randn(10, 2), B.randn(5, 2))])
def test_delta_evaluations(x1_x2):
    """Check `Delta` on scalar, vector and matrix inputs.

    `k(x1)` must be the identity and `k(x1, x2)` must be all zeros. The
    parametrized pairs are expected to contain no repeated points, either
    within `x1` or between `x1` and `x2`.
    """
    k = Delta()
    x1, x2 = x1_x2
    # Number of points in each input. NOTE(review): assumes `B.uprank` lifts
    # scalars and vectors so that the first dimension counts points — confirm.
    n1 = B.shape(B.uprank(x1))[0]
    n2 = B.shape(B.uprank(x2))[0]

    # Check that each point matches only itself and that no point of `x1`
    # matches a point of `x2`.
    allclose(k(x1), B.eye(n1))
    allclose(k(x1, x2), B.zeros(n1, n2))

    # Standard tests: