Beispiel #1
0
def test_basic_arithmetic():
    """Yield checks that kernel arithmetic matches arithmetic on kernel values.

    Every yielded item is an ``(ok, result, label)`` tuple. Apart from the
    dispatch check, each identity is evaluated first on a pair of matrix
    inputs and then on a pair of scalar inputs.
    """
    eq, rq = EQ(), RQ(1e-1)
    mat12, mat32, mat52 = Matern12(), Matern32(), Matern52()
    delta, linear = Delta(), Linear()
    matrices = np.random.randn(10, 2), np.random.randn(20, 2)
    scalars = np.random.randn(), np.random.randn()

    # Calling a kernel with one input should equal calling it with that
    # input passed twice.
    first = matrices[0]
    yield ok, allclose(delta(first), delta(first, first)), 'dispatch'

    # Each entry: (labels for the matrix and scalar case, left-hand side,
    # right-hand side). Lambdas keep every evaluation lazy, so work is only
    # done when the corresponding check is actually requested.
    identities = [
        (('prod', 'prod 2'),
         lambda x, y: (eq * rq)(x, y),
         lambda x, y: eq(x, y) * rq(x, y)),
        (('sum', 'sum 2'),
         lambda x, y: (mat12 + mat32)(x, y),
         lambda x, y: mat12(x, y) + mat32(x, y)),
        (('prod 3', 'prod 4'),
         lambda x, y: (5. * mat52)(x, y),
         lambda x, y: 5. * mat52(x, y)),
        (('sum 3', 'sum 4'),
         lambda x, y: (5. + linear)(x, y),
         lambda x, y: 5. + linear(x, y)),
        (('stretch', 'stretch 2'),
         lambda x, y: eq.stretch(2.)(x, y),
         lambda x, y: eq(x / 2., y / 2.)),
        (('periodic', 'periodic 2'),
         lambda x, y: eq.periodic(1.)(x, y),
         lambda x, y: eq.periodic(1.)(x, y + 5.)),
    ]

    for labels, lhs, rhs in identities:
        for label, (x, y) in zip(labels, (matrices, scalars)):
            yield ok, allclose(lhs(x, y), rhs(x, y)), label
Beispiel #2
0
def test_periodic():
    """Exercise the ``periodic`` transform of kernels."""
    kernel = EQ().stretch(2).periodic(3)

    # String form and stationarity of the stretched, periodic kernel.
    assert str(kernel) == "(EQ() > 2) per 3"
    assert kernel.stationary

    # Equality depends on both the period and the wrapped kernel.
    assert EQ().periodic(2) == EQ().periodic(2)
    assert EQ().periodic(2) != EQ().periodic(3)
    assert Matern12().periodic(2) != EQ().periodic(2)

    # Run the shared kernel test suite on it.
    standard_kernel_tests(kernel)

    # Stretching and scaling the periodic kernel keeps it stationary.
    scaled = 5 * kernel.stretch(5)
    assert scaled.stationary

    # An array-valued period is accepted and the kernel can be evaluated.
    vector_period = EQ().periodic(np.array([1, 2]))
    vector_period(B.randn(10, 2))

    # Making the zero kernel periodic returns the very same object.
    zero = ZeroKernel()
    assert zero.periodic(3) is zero
Beispiel #3
0
def test_basic_arithmetic():
    """Check that kernel arithmetic matches arithmetic on kernel values.

    Apart from the dispatch check, every identity is verified on a pair of
    matrix inputs and then on a pair of scalar inputs.
    """
    eq, rq = EQ(), RQ(1e-1)
    mat12, mat32, mat52 = Matern12(), Matern32(), Matern52()
    delta, linear = Delta(), Linear()
    matrices = B.randn(10, 2), B.randn(20, 2)
    scalars = B.randn(), B.randn()

    def check_both(lhs, rhs):
        # Compare both sides on the matrix pair first, then the scalar pair.
        for x, y in (matrices, scalars):
            approx(lhs(x, y), rhs(x, y))

    # Calling a kernel with one input should equal calling it with that
    # input passed twice.
    first = matrices[0]
    approx(delta(first), delta(first, first))

    # Products and sums of kernels.
    check_both(
        lambda x, y: (eq * rq)(x, y),
        lambda x, y: eq(x, y) * rq(x, y),
    )
    check_both(
        lambda x, y: (mat12 + mat32)(x, y),
        lambda x, y: mat12(x, y) + mat32(x, y),
    )

    # Arithmetic between a constant and a kernel.
    check_both(
        lambda x, y: (5.0 * mat52)(x, y),
        lambda x, y: 5.0 * mat52(x, y),
    )
    check_both(
        lambda x, y: (5.0 + linear)(x, y),
        lambda x, y: 5.0 + linear(x, y),
    )

    # Stretching by 2.0 is compared against dividing both inputs by 2.0.
    check_both(
        lambda x, y: eq.stretch(2.0)(x, y),
        lambda x, y: eq(x / 2.0, y / 2.0),
    )

    # With period 1.0, shifting the second input by 5.0 should change nothing.
    check_both(
        lambda x, y: eq.periodic(1.0)(x, y),
        lambda x, y: eq.periodic(1.0)(x, y + 5.0),
    )