def test_mat32():
    """Matern32 basics: stationarity, string form, equality, shared kernel suite."""
    # NOTE(review): another ``def test_mat32`` appears later in this file and
    # rebinds this name, so this version is shadowed once the module is
    # imported; consider renaming or removing one of the two.
    kernel = Matern32()

    # Stationarity flag and printed representation.
    assert kernel.stationary
    assert str(kernel) == 'Matern32()'

    # Two independently constructed Matern32 kernels must compare equal,
    # while a kernel of a different type (Linear) must not.
    assert Matern32() == Matern32()
    assert Matern32() != Linear()

    # Hand the kernel to the common battery of kernel checks.
    standard_kernel_tests(kernel)
def test_basic_arithmetic():
    """Check kernel sums, products, scaling, stretching and periodicity.

    Each identity is verified on a pair of random matrices (10x2 and 20x2)
    and on a pair of random scalars.

    NOTE(review): another ``test_basic_arithmetic`` appears later in this
    file and rebinds this name, so this definition is shadowed at import
    time; consider renaming or removing one of the two.
    """
    # Plain asserts replace the former nose-style ``yield ok, <cond>, <msg>``
    # checks: yield-based test generators were removed in pytest 4.0 (so the
    # checks never ran there), and a plain test function still runs under
    # nose. The original descriptions are kept as assertion messages.
    k1 = EQ()
    k2 = RQ(1e-1)
    k3 = Matern12()
    k4 = Matern32()
    k5 = Matern52()
    k6 = Delta()
    k7 = Linear()

    xs1 = np.random.randn(10, 2), np.random.randn(20, 2)
    xs2 = np.random.randn(), np.random.randn()

    # A single-argument call should match passing the same input twice.
    assert allclose(k6(xs1[0]), k6(xs1[0], xs1[0])), 'dispatch'

    # Products and sums of kernels evaluate pointwise.
    assert allclose((k1 * k2)(*xs1), k1(*xs1) * k2(*xs1)), 'prod'
    assert allclose((k1 * k2)(*xs2), k1(*xs2) * k2(*xs2)), 'prod 2'
    assert allclose((k3 + k4)(*xs1), k3(*xs1) + k4(*xs1)), 'sum'
    assert allclose((k3 + k4)(*xs2), k3(*xs2) + k4(*xs2)), 'sum 2'

    # Mixing kernels with plain numbers.
    assert allclose((5. * k5)(*xs1), 5. * k5(*xs1)), 'prod 3'
    assert allclose((5. * k5)(*xs2), 5. * k5(*xs2)), 'prod 4'
    assert allclose((5. + k7)(*xs1), 5. + k7(*xs1)), 'sum 3'
    assert allclose((5. + k7)(*xs2), 5. + k7(*xs2)), 'sum 4'

    # Stretching by 2 is checked against dividing both inputs by 2.
    assert allclose(k1.stretch(2.)(*xs1),
                    k1(xs1[0] / 2., xs1[1] / 2.)), 'stretch'
    assert allclose(k1.stretch(2.)(*xs2),
                    k1(xs2[0] / 2., xs2[1] / 2.)), 'stretch 2'

    # With period 1, shifting an input by 5 (whole periods) must not change
    # the kernel value.
    assert allclose(k1.periodic(1.)(*xs1),
                    k1.periodic(1.)(xs1[0], xs1[1] + 5.)), 'periodic'
    assert allclose(k1.periodic(1.)(*xs2),
                    k1.periodic(1.)(xs2[0], xs2[1] + 5.)), 'periodic 2'
def test_mat32():
    """Check the basic properties of the ``Matern32`` kernel.

    Covers stationarity, default ``var``/``length_scale``/``period``, the
    string form, equality, and the shared checks from ``kernel_generator``.
    """
    # NOTE(review): this definition rebinds an earlier ``test_mat32`` in this
    # file; consider renaming or removing one of the two.
    #
    # Plain asserts replace the former nose-style ``yield eq/neq, ...``
    # checks, because yield-based test generators were removed in pytest 4.0.
    k = Matern32()

    # Verify that the kernel has the right properties.
    assert k.stationary
    assert k.var == 1
    assert k.length_scale == 1
    assert k.period == np.inf
    assert str(k) == 'Matern32()'

    # Test equality.
    assert Matern32() == Matern32()
    assert Matern32() != Linear()

    # Standard tests. ``kernel_generator`` was previously re-yielded as nose
    # generator tests, so each case is either a callable or a tuple of
    # ``(callable, *args)``. Run each one directly, following nose's own
    # parsing rule for generated tests.
    for case in kernel_generator(k):
        if not isinstance(case, tuple):
            case = (case,)
        case[0](*case[1:])
def test_basic_arithmetic():
    """Kernel algebra must agree with the equivalent direct computation.

    Apart from the single-argument dispatch check, every identity is checked
    twice: first on a pair of random matrices (10x2 and 20x2), then on a pair
    of random scalars.
    """
    eq_k, rq_k, mat12, mat32, mat52, delta, lin = (
        EQ(),
        RQ(1e-1),
        Matern12(),
        Matern32(),
        Matern52(),
        Delta(),
        Linear(),
    )
    matrices = B.randn(10, 2), B.randn(20, 2)
    scalars = B.randn(), B.randn()

    # A single-argument call is compared against passing the input twice.
    approx(delta(matrices[0]), delta(matrices[0], matrices[0]))

    def holds(lhs, rhs):
        # Compare lhs with rhs on the matrices, then on the scalars. Within
        # each comparison the left-hand side is evaluated first.
        for x, y in (matrices, scalars):
            approx(lhs(x, y), rhs(x, y))

    # Products and sums of kernels evaluate pointwise.
    holds(lambda x, y: (eq_k * rq_k)(x, y),
          lambda x, y: eq_k(x, y) * rq_k(x, y))
    holds(lambda x, y: (mat12 + mat32)(x, y),
          lambda x, y: mat12(x, y) + mat32(x, y))

    # Mixing kernels with plain numbers.
    holds(lambda x, y: (5.0 * mat52)(x, y),
          lambda x, y: 5.0 * mat52(x, y))
    holds(lambda x, y: (5.0 + lin)(x, y),
          lambda x, y: 5.0 + lin(x, y))

    # Stretching by 2 is compared against dividing both inputs by 2.
    holds(lambda x, y: eq_k.stretch(2.0)(x, y),
          lambda x, y: eq_k(x / 2.0, y / 2.0))

    # With period 1, shifting the second input by 5 is compared against the
    # unshifted value.
    holds(lambda x, y: eq_k.periodic(1.0)(x, y),
          lambda x, y: eq_k.periodic(1.0)(x, y + 5.0))