def test_basic_arithmetic():
    """Check kernel arithmetic, stretching and periodicity identities.

    Each identity is checked on a pair of random input arrays
    (``np.random.randn(10, 2)`` and ``np.random.randn(20, 2)``) and on a
    pair of random scalars.

    NOTE(review): this module defines ``test_basic_arithmetic`` a second time
    further down; that later definition replaces this one when the module is
    imported, so this body is not collected. Consider renaming or deleting it.
    """
    # Plain asserts instead of nose-style ``yield ok, cond, msg``: pytest
    # removed support for yield-based generator tests in 4.0.
    k1 = EQ()
    k2 = RQ(1e-1)
    k3 = Matern12()
    k4 = Matern32()
    k5 = Matern52()
    k6 = Delta()
    k7 = Linear()
    xs1 = np.random.randn(10, 2), np.random.randn(20, 2)
    xs2 = np.random.randn(), np.random.randn()

    # A one-argument call must agree with passing the same input twice.
    assert allclose(k6(xs1[0]), k6(xs1[0], xs1[0])), 'dispatch'

    # Product and sum of kernels evaluate to product and sum of evaluations.
    assert allclose((k1 * k2)(*xs1), k1(*xs1) * k2(*xs1)), 'prod'
    assert allclose((k1 * k2)(*xs2), k1(*xs2) * k2(*xs2)), 'prod 2'
    assert allclose((k3 + k4)(*xs1), k3(*xs1) + k4(*xs1)), 'sum'
    assert allclose((k3 + k4)(*xs2), k3(*xs2) + k4(*xs2)), 'sum 2'

    # Scaling by / adding a constant acts on the evaluated kernel.
    assert allclose((5. * k5)(*xs1), 5. * k5(*xs1)), 'prod 3'
    assert allclose((5. * k5)(*xs2), 5. * k5(*xs2)), 'prod 4'
    assert allclose((5. + k7)(*xs1), 5. + k7(*xs1)), 'sum 3'
    assert allclose((5. + k7)(*xs2), 5. + k7(*xs2)), 'sum 4'

    # Stretching by 2 is the same as dividing both inputs by 2.
    assert allclose(k1.stretch(2.)(*xs1),
                    k1(xs1[0] / 2., xs1[1] / 2.)), 'stretch'
    assert allclose(k1.stretch(2.)(*xs2),
                    k1(xs2[0] / 2., xs2[1] / 2.)), 'stretch 2'

    # Shifting one input by 5 (a multiple of the period 1) changes nothing.
    assert allclose(k1.periodic(1.)(*xs1),
                    k1.periodic(1.)(xs1[0], xs1[1] + 5.)), 'periodic'
    assert allclose(k1.periodic(1.)(*xs2),
                    k1.periodic(1.)(xs2[0], xs2[1] + 5.)), 'periodic 2'
def test_periodic():
    """Exercise ``periodic`` on kernels.

    Covers string form and stationarity, equality semantics, the standard
    kernel test-suite, further scaling/stretching, an array-valued period,
    and the zero-kernel shortcut.
    """
    periodic_eq = EQ().stretch(2).periodic(3)

    # String representation and stationarity of a stretched, periodic EQ.
    assert str(periodic_eq) == "(EQ() > 2) per 3"
    assert periodic_eq.stationary

    # Same base and same period compare equal; a different period or a
    # different base kernel does not.
    assert EQ().periodic(2) == EQ().periodic(2)
    assert EQ().periodic(2) != EQ().periodic(3)
    assert Matern12().periodic(2) != EQ().periodic(2)

    # Standard tests:
    standard_kernel_tests(periodic_eq)

    # Scaling and stretching the periodic kernel keeps it stationary.
    scaled = 5 * periodic_eq.stretch(5)
    assert scaled.stationary

    # Check passing in an array of periods (presumably one per input
    # dimension -- confirm); this only checks that evaluation does not raise.
    per_dim = EQ().periodic(np.array([1, 2]))
    per_dim(B.randn(10, 2))

    # Making the zero kernel periodic hands back the very same object.
    zero = ZeroKernel()
    assert zero.periodic(3) is zero
def test_basic_arithmetic():
    """Compare composite kernels against the same combination of their parts.

    Each identity runs first on a pair of random inputs drawn with
    ``B.randn(10, 2)`` / ``B.randn(20, 2)`` and then on a pair of random
    scalars drawn with ``B.randn()``.
    """
    eq, rq = EQ(), RQ(1e-1)
    m12, m32, m52 = Matern12(), Matern32(), Matern52()
    delta, linear = Delta(), Linear()

    matrices = B.randn(10, 2), B.randn(20, 2)
    scalars = B.randn(), B.randn()

    # A one-argument call must agree with passing the same input twice.
    approx(delta(matrices[0]), delta(matrices[0], matrices[0]))

    def on_both_pairs(check):
        # Apply `check` to the matrix pair, then to the scalar pair.
        for x, y in (matrices, scalars):
            check(x, y)

    # Product and sum of kernels evaluate to product and sum of evaluations.
    on_both_pairs(lambda x, y: approx((eq * rq)(x, y), eq(x, y) * rq(x, y)))
    on_both_pairs(
        lambda x, y: approx((m12 + m32)(x, y), m12(x, y) + m32(x, y))
    )

    # Scaling by / adding a constant acts on the evaluated kernel.
    on_both_pairs(lambda x, y: approx((5.0 * m52)(x, y), 5.0 * m52(x, y)))
    on_both_pairs(
        lambda x, y: approx((5.0 + linear)(x, y), 5.0 + linear(x, y))
    )

    # Stretching by 2 is the same as dividing both inputs by 2.
    on_both_pairs(
        lambda x, y: approx(eq.stretch(2.0)(x, y), eq(x / 2.0, y / 2.0))
    )

    # Shifting one input by 5 (a multiple of the period 1) changes nothing.
    on_both_pairs(
        lambda x, y: approx(
            eq.periodic(1.0)(x, y), eq.periodic(1.0)(x, y + 5.0)
        )
    )