Example #1
0
def test_grouping():
    """Check how ``str`` renders grouped kernel expressions.

    Each entry pairs a kernel expression with its expected string form,
    covering scales, stretches, shifts, products, sums and reversals.
    """
    expectations = [
        # Scales: nested scalar factors are multiplied together.
        (5 * EQ(), '5 * EQ()'),
        (5 * (5 * EQ()), '25 * EQ()'),
        # Stretches: successive stretch factors are multiplied.
        (EQ().stretch(5), 'EQ() > 5'),
        (EQ().stretch(5).stretch(5), 'EQ() > 25'),
        # Shifts: successive shift amounts are added.
        (Linear().shift(5), 'Linear() shift 5'),
        (Linear().shift(5).shift(5), 'Linear() shift 10'),
        # Products:
        ((5 * EQ()) * (5 * EQ()), '25 * EQ()'),
        ((5 * (EQ() * EQ())) * (5 * EQ() * EQ()), '25 * EQ() * EQ()'),
        ((5 * RQ(1)) * (5 * RQ(2)), '25 * RQ(1) * RQ(2)'),
        # Sums: identical terms are combined, distinct ones are kept.
        ((5 * EQ()) + (5 * EQ()), '10 * EQ()'),
        (EQ() + (5 * EQ()), '6 * EQ()'),
        ((5 * EQ()) + EQ(), '6 * EQ()'),
        ((EQ() + EQ()), '2 * EQ()'),
        ((5 * (EQ() * EQ())) + (5 * (EQ() * EQ())), '10 * EQ() * EQ()'),
        ((5 * RQ(1)) + (5 * RQ(2)), '5 * RQ(1) + 5 * RQ(2)'),
        # Reversal distributes over sums and products.
        (reversed(Linear() + EQ()), 'Reversed(Linear()) + EQ()'),
        (reversed(Linear() * EQ()), 'Reversed(Linear()) * EQ()'),
    ]
    for expression, rendered in expectations:
        assert str(expression) == rendered
Example #2
0
def test_grouping():
    """Check how ``str`` renders grouped kernel expressions.

    This was converted from a nose-style generator test (``yield eq, ...``).
    pytest >= 4.0 no longer executes generator tests, so each check is now
    invoked directly. The helper, the arguments and the order are unchanged.
    """
    # Scales:
    eq(str(5 * EQ()), '5 * EQ()')
    eq(str(5 * (5 * EQ())), '25 * EQ()')

    # Stretches:
    eq(str(EQ().stretch(5)), 'EQ() > 5')
    eq(str(EQ().stretch(5).stretch(5)), 'EQ() > 25')

    # Shifts:
    eq(str(Linear().shift(5)), 'Linear() shift 5')
    eq(str(Linear().shift(5).shift(5)), 'Linear() shift 10')

    # Products:
    eq(str((5 * EQ()) * (5 * EQ())), '25 * EQ()')
    eq(str((5 * (EQ() * EQ())) * (5 * EQ() * EQ())),
       '25 * EQ() * EQ()')
    eq(str((5 * RQ(1)) * (5 * RQ(2))), '25 * RQ(1) * RQ(2)')

    # Sums:
    eq(str((5 * EQ()) + (5 * EQ())), '10 * EQ()')
    eq(str(EQ() + (5 * EQ())), '6 * EQ()')
    eq(str((5 * EQ()) + EQ()), '6 * EQ()')
    eq(str((EQ() + EQ())), '2 * EQ()')
    eq(str((5 * (EQ() * EQ())) + (5 * (EQ() * EQ()))),
       '10 * EQ() * EQ()')
    eq(str((5 * RQ(1)) + (5 * RQ(2))), '5 * RQ(1) + 5 * RQ(2)')

    # Reversal:
    eq(str(reversed(Linear() + EQ())), 'Reversed(Linear()) + EQ()')
    eq(str(reversed(Linear() * EQ())), 'Reversed(Linear()) * EQ()')
Example #3
0
def test_parentheses():
    """Check that ``str`` inserts parentheses where grouping requires them.

    This was converted from a nose-style generator test. pytest >= 4.0 does
    not execute ``yield`` tests, so the ``eq`` check is called directly.
    """
    kernel = (reversed(Linear() * Linear() +
                       2 * EQ().stretch(1).periodic(2) +
                       RQ(3).periodic(4))) * \
             (EQ().stretch(1) + EQ())
    eq(str(kernel),
       '(Reversed(Linear()) * Reversed(Linear()) + '
       '2 * ((EQ() > 1) per 2) + RQ(3) per 4) * (EQ() > 1 + EQ())')
Example #4
0
def test_broadcast():
    """Check element-wise ``broadcast`` of tuples and its use when
    composing stretches and shifts."""
    add = operator.add

    # Matching lengths, or a length-1 tuple on either side.
    for left, right, combined in [((1, 2, 3), (2, 3, 4), (3, 5, 7)),
                                  ((1, ), (2, 3, 4), (3, 4, 5)),
                                  ((1, 2, 3), (2, ), (3, 4, 5))]:
        assert broadcast(add, left, right) == combined

    # Tuples of lengths 2 and 3 cannot be combined.
    with pytest.raises(ValueError):
        broadcast(add, (1, 2), (1, 2, 3))

    # Composing stretches multiplies and composing shifts adds, with a
    # scalar broadcast against a pair in either order.
    assert str(EQ().stretch(2).stretch(1, 3)) == 'EQ() > (2, 6)'
    assert str(EQ().stretch(1, 3).stretch(2)) == 'EQ() > (2, 6)'
    assert str(Linear().shift(2).shift(1, 3)) == 'Linear() shift (3, 5)'
    assert str(Linear().shift(1, 3).shift(2)) == 'Linear() shift (3, 5)'
Example #5
0
def test_broadcast():
    """Check element-wise ``broadcast`` of tuples and its use when
    composing stretches and shifts.

    This was converted from a nose-style generator test. pytest >= 4.0 does
    not execute ``yield`` tests, so each helper (``eq``, ``raises``) is now
    called directly with the same arguments.
    """
    eq(broadcast(operator.add, (1, 2, 3), (2, 3, 4)), (3, 5, 7))
    eq(broadcast(operator.add, (1, ), (2, 3, 4)), (3, 4, 5))
    eq(broadcast(operator.add, (1, 2, 3), (2, )), (3, 4, 5))
    # Tuples of lengths 2 and 3 cannot be combined.
    raises(ValueError, lambda: broadcast(operator.add, (1, 2),
                                         (1, 2, 3)))
    eq(str(EQ().stretch(2).stretch(1, 3)), 'EQ() > (2, 6)')
    eq(str(EQ().stretch(1, 3).stretch(2)), 'EQ() > (2, 6)')
    eq(str(Linear().shift(2).shift(1, 3)), 'Linear() shift (3, 5)')
    eq(str(Linear().shift(1, 3).shift(2)), 'Linear() shift (3, 5)')
Example #6
0
def test_linear():
    """Check the properties, string form, equality and standard behaviour
    of the ``Linear`` kernel."""
    linear = Linear()

    # Linear() is expected to be non-stationary and to print its name.
    assert not linear.stationary
    assert str(linear) == 'Linear()'

    # Two fresh Linear() instances compare equal; a different kernel
    # type does not.
    assert Linear() == Linear()
    assert Linear() != EQ()

    # Shared battery of kernel checks.
    standard_kernel_tests(linear)
Example #7
0
def test_reversal():
    """Check that single and double reversal reproduce the original kernel
    values, and check properties, string form and equality of reversed
    kernels."""
    x1 = np.random.randn(10, 2)
    x2 = np.random.randn(5, 2)
    x3 = np.random.randn()

    # Use one stationary kernel (EQ) and one non-stationary kernel (Linear).
    for base in [EQ(), Linear()]:
        # A single reversal and a double reversal must both agree with
        # ``base`` on matrix inputs, scalar inputs and swapped inputs.
        for variant in (reversed(base), reversed(reversed(base))):
            allclose(base(x1), variant(x1))
            allclose(base(x3), variant(x3))
            allclose(base(x1, x2), variant(x1, x2))
            allclose(base(x1, x2), variant(x2, x1).T)

    # Reversal keeps the stationarity of the wrapped kernel.
    assert reversed(EQ()).stationary

    reversed_linear = reversed(Linear())
    assert not reversed_linear.stationary
    assert str(reversed_linear) == 'Reversed(Linear())'

    # Equality holds only for the same reversed kernel.
    assert reversed(Linear()) == reversed(Linear())
    for other in (Linear(), reversed(EQ()),
                  reversed(DecayingKernel(1, 1))):
        assert reversed(Linear()) != other

    # Shared battery of kernel checks.
    standard_kernel_tests(reversed_linear)
Example #8
0
def test_factors():
    """Check ``num_factors`` and ``factor(i)`` on product kernels,
    including out-of-range indices.

    This was converted from a nose-style generator test. pytest >= 4.0 does
    not execute ``yield`` tests. Calling ``eq``/``raises`` directly also
    means each ``lambda: k.factor(..)`` runs against the ``k`` in effect at
    that point, not whatever ``k`` is rebound to later.
    """
    k = EQ() * Linear()
    eq(k.num_factors, 2)
    eq(str(k.factor(0)), 'EQ()')
    eq(str(k.factor(1)), 'Linear()')
    raises(IndexError, lambda: k.factor(2))

    k = (EQ() + EQ()) * Delta() * (RQ(1) + Linear())
    eq(k.num_factors, 4)
    eq(str(k.factor(0)), '2')
    eq(str(k.factor(1)), 'EQ()')
    eq(str(k.factor(2)), 'Delta()')
    eq(str(k.factor(3)), 'RQ(1) + Linear()')
    raises(IndexError, lambda: k.factor(4))
    raises(IndexError, lambda: EQ().factor(1))
Example #9
0
def test_properties():
    """Check stationarity, variance, length scale and period of sums of
    scaled GPs defined on a shared graph.

    Boolean properties are truth-tested directly rather than compared with
    ``== True`` / ``== False`` (PEP 8, flake8 E712).
    """
    model = Graph()

    p1 = GP(EQ(), graph=model)
    p2 = GP(EQ().stretch(2), graph=model)
    p3 = GP(EQ().periodic(10), graph=model)

    p = p1 + 2 * p2

    assert p.stationary, 'stationary'
    assert p.var == 1 + 2**2, 'var'
    allclose(p.length_scale, (1 + 2 * 2**2) / (1 + 2**2))
    assert p.period == np.inf, 'period'

    assert p3.period == 10, 'period'

    p = p3 + p

    assert p.stationary, 'stationary 2'
    assert p.var == 1 + 2**2 + 1, 'var 2'
    assert p.period == np.inf, 'period 2'

    p = p + GP(Linear(), graph=model)

    # Adding a non-stationary (Linear) component removes stationarity.
    assert not p.stationary, 'stationary 3'
Example #10
0
def test_basic_arithmetic():
    """Check that kernel arithmetic (products, sums, scaling, adding a
    constant, stretching, periodicisation) matches the corresponding
    arithmetic on evaluated kernels, for matrix and scalar inputs.

    This was converted from a nose-style generator test. pytest >= 4.0 does
    not execute ``yield`` tests, so each ``ok`` check is now called directly
    with the same arguments and message.
    """
    k1 = EQ()
    k2 = RQ(1e-1)
    k3 = Matern12()
    k4 = Matern32()
    k5 = Matern52()
    k6 = Delta()
    k7 = Linear()
    xs1 = np.random.randn(10, 2), np.random.randn(20, 2)
    xs2 = np.random.randn(), np.random.randn()

    # Calling with one input should match calling with that input twice.
    ok(allclose(k6(xs1[0]), k6(xs1[0], xs1[0])), 'dispatch')
    ok(allclose((k1 * k2)(*xs1), k1(*xs1) * k2(*xs1)), 'prod')
    ok(allclose((k1 * k2)(*xs2), k1(*xs2) * k2(*xs2)), 'prod 2')
    ok(allclose((k3 + k4)(*xs1), k3(*xs1) + k4(*xs1)), 'sum')
    ok(allclose((k3 + k4)(*xs2), k3(*xs2) + k4(*xs2)), 'sum 2')
    ok(allclose((5. * k5)(*xs1), 5. * k5(*xs1)), 'prod 3')
    ok(allclose((5. * k5)(*xs2), 5. * k5(*xs2)), 'prod 4')
    ok(allclose((5. + k7)(*xs1), 5. + k7(*xs1)), 'sum 3')
    ok(allclose((5. + k7)(*xs2), 5. + k7(*xs2)), 'sum 4')
    ok(allclose(k1.stretch(2.)(*xs1), k1(xs1[0] / 2.,
                                         xs1[1] / 2.)), 'stretch')
    ok(allclose(k1.stretch(2.)(*xs2), k1(xs2[0] / 2.,
                                         xs2[1] / 2.)), 'stretch 2')
    # With period 1, shifting an input by 5 must not change the value.
    ok(allclose(
        k1.periodic(1.)(*xs1),
        k1.periodic(1.)(xs1[0], xs1[1] + 5.)), 'periodic')
    ok(allclose(
        k1.periodic(1.)(*xs2),
        k1.periodic(1.)(xs2[0], xs2[1] + 5.)), 'periodic 2')
Example #11
0
def test_properties():
    """Check stationarity, variance, length scale and period of sums of
    scaled GPs defined on a shared graph.

    This was converted from a nose-style generator test. pytest >= 4.0 does
    not execute ``yield`` tests, so ``eq``/``assert_allclose`` are now
    called directly with the same arguments.
    """
    model = Graph()

    p1 = GP(EQ(), graph=model)
    p2 = GP(EQ().stretch(2), graph=model)
    p3 = GP(EQ().periodic(10), graph=model)

    p = p1 + 2 * p2

    eq(p.stationary, True, 'stationary')
    eq(p.var, 1 + 2 ** 2, 'var')
    assert_allclose(p.length_scale,
                    (1 + 2 * 2 ** 2) / (1 + 2 ** 2))
    eq(p.period, np.inf, 'period')

    eq(p3.period, 10, 'period')

    p = p3 + p

    eq(p.stationary, True, 'stationary 2')
    eq(p.var, 1 + 2 ** 2 + 1, 'var 2')
    eq(p.period, np.inf, 'period 2')

    p = p + GP(Linear(), graph=model)

    eq(p.stationary, False, 'stationary 3')
Example #12
0
def test_shifting():
    """Check ``str`` of shifted kernels and means.

    This was converted from a nose-style generator test. pytest >= 4.0 does
    not execute ``yield`` tests, so each ``eq`` check is called directly.
    """
    # Kernels:
    eq(str(ZeroKernel().shift(5)), '0')
    eq(str(EQ().shift(5)), 'EQ()')
    eq(str(Linear().shift(5)), 'Linear() shift 5')
    eq(str((5 * EQ()).shift(5)), '5 * EQ()')
    eq(str((5 * Linear()).shift(5)), '(5 * Linear()) shift 5')

    # Means: the function name ``mean`` appears in the expected strings,
    # so it must not be renamed.
    def mean(x):
        return x

    m = TensorProductMean(mean)
    eq(str(ZeroMean().shift(5)), '0')
    eq(str(m.shift(5)), 'mean shift 5')
    eq(str(m.shift(5).shift(5)), 'mean shift 10')
    eq(str((5 * m).shift(5)), '(5 * mean) shift 5')
Example #13
0
def test_shifting():
    """Check ``str`` of shifted kernels and means.

    For ``EQ()`` and ``5 * EQ()`` the shift does not appear in the output.
    For ``Linear()`` and tensor-product means it does, and successive
    shifts are summed.
    """
    kernel_expectations = [
        (ZeroKernel().shift(5), '0'),
        (EQ().shift(5), 'EQ()'),
        (Linear().shift(5), 'Linear() shift 5'),
        ((5 * EQ()).shift(5), '5 * EQ()'),
        ((5 * Linear()).shift(5), '(5 * Linear()) shift 5'),
    ]
    for shifted_kernel, rendered in kernel_expectations:
        assert str(shifted_kernel) == rendered

    # The name ``mean`` appears in the expected strings below, so this
    # nested function must keep it.
    def mean(x):
        return x

    wrapped = TensorProductMean(mean)
    mean_expectations = [
        (ZeroMean().shift(5), '0'),
        (wrapped.shift(5), 'mean shift 5'),
        (wrapped.shift(5).shift(5), 'mean shift 10'),
        ((5 * wrapped).shift(5), '(5 * mean) shift 5'),
    ]
    for shifted_mean, rendered in mean_expectations:
        assert str(shifted_mean) == rendered
Example #14
0
def test_linear():
    """Check properties, string form, equality and the standard kernel
    checks for ``Linear()``.

    This was converted from a nose-style generator test. pytest >= 4.0 does
    not execute ``yield`` tests, so each helper is now called directly, and
    the checks produced by ``kernel_generator`` are run in place instead of
    being re-yielded.
    """
    k = Linear()

    # Verify that the kernel has the right properties.
    eq(k.stationary, False)
    raises(RuntimeError, lambda: k.var)
    raises(RuntimeError, lambda: k.length_scale)
    eq(k.period, np.inf)
    eq(str(k), 'Linear()')

    # Test equality.
    eq(Linear(), Linear())
    neq(Linear(), EQ())

    # Standard tests.
    # NOTE(review): assumes kernel_generator yields nose-style items, i.e.
    # either a bare callable or a ``(callable, *args)`` tuple; confirm
    # against its definition.
    for case in kernel_generator(k):
        if callable(case):
            case()
        else:
            case[0](*case[1:])
Example #15
0
def test_shifted():
    """Check stationarity, equality, standard behaviour and evaluation of
    shifted kernels, including an array-valued shift."""
    scalar_shift = ShiftedKernel(2 * EQ(), 5)
    assert scalar_shift.stationary

    # Equality depends on both the shift amount and the wrapped kernel.
    assert Linear().shift(2) == Linear().shift(2)
    for other in (Linear().shift(3), DecayingKernel(1, 1).shift(2)):
        assert Linear().shift(2) != other

    # Shared battery of kernel checks.
    standard_kernel_tests(scalar_shift)

    # The two-valued shift (5, 6) is expected to be non-stationary.
    assert not (2 * EQ()).shift(5, 6).stationary

    # Shifting by 5 equals evaluating the base kernel on inputs minus 5.
    left = np.random.randn(10, 2)
    right = np.random.randn(5, 2)
    base = Linear()
    allclose(base.shift(5)(left, right), base(left - 5, right - 5))

    # An array-valued shift must be usable on (10, 2) inputs.
    Linear().shift(np.array([1, 2]))(np.random.randn(10, 2))
Example #16
0
def test_shifted():
    """Check properties, equality, standard behaviour and evaluation of
    shifted kernels, including a list-valued shift.

    This was converted from a nose-style generator test. pytest >= 4.0 does
    not execute ``yield`` tests. Calling the helpers directly also makes
    ``lambda: k.length_scale`` run against the ``(2 * EQ()).shift(5, 6)``
    kernel it was written for, not whatever ``k`` is rebound to later.
    """
    k = ShiftedKernel(2 * EQ(), 5)

    eq(k.stationary, True)
    eq(k.length_scale, 1)
    eq(k.period, np.inf)
    eq(k.var, 2)

    # Test equality.
    eq(Linear().shift(2), Linear().shift(2))
    neq(Linear().shift(2), Linear().shift(3))
    neq(Linear().shift(2), DecayingKernel(1, 1).shift(2))

    # Standard tests.
    # NOTE(review): assumes kernel_generator yields nose-style items, i.e.
    # either a bare callable or a ``(callable, *args)`` tuple; confirm.
    for case in kernel_generator(k):
        if callable(case):
            case()
        else:
            case[0](*case[1:])

    k = (2 * EQ()).shift(5, 6)

    eq(k.stationary, False)
    raises(RuntimeError, lambda: k.length_scale)
    eq(k.period, np.inf)
    eq(k.var, 2)

    # Check computation.
    x1 = np.random.randn(10, 2)
    x2 = np.random.randn(5, 2)
    k = Linear()
    assert_allclose(k.shift(5)(x1, x2), k(x1 - 5, x2 - 5))

    # Check passing in a list: the shifted kernel must be callable.
    k = Linear().shift([1, 2])
    k(np.random.randn(10, 2))
Example #17
0
def test_shifted():
    """Check stationarity, equality, standard behaviour and evaluation of
    shifted kernels, including an array-valued shift."""
    scalar_shift = ShiftedKernel(2 * EQ(), 5)

    # A scalar shift of 2 * EQ() is expected to stay stationary.
    assert scalar_shift.stationary

    # Equality depends on both the shift amount and the wrapped kernel.
    assert Linear().shift(2) == Linear().shift(2)
    for other in (Linear().shift(3), DecayingKernel(1, 1).shift(2)):
        assert Linear().shift(2) != other

    # Shared battery of kernel checks.
    standard_kernel_tests(scalar_shift)

    # The two-valued shift (5, 6) is expected to be non-stationary.
    assert not (2 * EQ()).shift(5, 6).stationary

    # Shifting by 5 equals evaluating the base kernel on inputs minus 5.
    left = B.randn(10, 2)
    right = B.randn(5, 2)
    base = Linear()
    approx(base.shift(5)(left, right), base(left - 5, right - 5))

    # An array-valued shift must be usable on (10, 2) inputs.
    Linear().shift(np.array([1, 2]))(B.randn(10, 2))
Example #18
0
def test_shifted():
    """Check properties, equality, standard behaviour and evaluation of
    shifted kernels, including an array-valued shift."""
    scalar_shift = ShiftedKernel(2 * EQ(), 5)

    # A scalar shift of 2 * EQ() keeps its summary properties.
    assert scalar_shift.stationary
    assert scalar_shift.length_scale == 1
    assert scalar_shift.period == np.inf
    assert scalar_shift.var == 2

    # Equality depends on both the shift amount and the wrapped kernel.
    assert Linear().shift(2) == Linear().shift(2)
    for other in (Linear().shift(3), DecayingKernel(1, 1).shift(2)):
        assert Linear().shift(2) != other

    # Shared battery of kernel checks.
    standard_kernel_tests(scalar_shift)

    # With the two-valued shift (5, 6), stationarity is lost and the
    # length scale becomes unavailable; period and variance are kept.
    pair_shift = (2 * EQ()).shift(5, 6)
    assert not pair_shift.stationary
    with pytest.raises(RuntimeError):
        pair_shift.length_scale
    assert pair_shift.period == np.inf
    assert pair_shift.var == 2

    # Shifting by 5 equals evaluating the base kernel on inputs minus 5.
    left = np.random.randn(10, 2)
    right = np.random.randn(5, 2)
    base = Linear()
    allclose(base.shift(5)(left, right), base(left - 5, right - 5))

    # An array-valued shift must be usable on (10, 2) inputs.
    Linear().shift(np.array([1, 2]))(np.random.randn(10, 2))
Example #19
0
def test_linear():
    """Check properties, string form, equality and standard behaviour of
    the ``Linear`` kernel."""
    linear = Linear()

    # Non-stationary: reading ``var`` or ``length_scale`` must raise,
    # while the period is infinite.
    assert not linear.stationary
    for attribute in ('var', 'length_scale'):
        with pytest.raises(RuntimeError):
            getattr(linear, attribute)
    assert linear.period == np.inf
    assert str(linear) == 'Linear()'

    # Two fresh Linear() instances compare equal; EQ() does not.
    assert Linear() == Linear()
    assert Linear() != EQ()

    # Shared battery of kernel checks.
    standard_kernel_tests(linear)
Example #20
0
def test_factors():
    """Check ``num_factors`` and ``factor(i)`` on product kernels,
    including out-of-range indices."""
    product = EQ() * Linear()
    assert product.num_factors == 2
    for position, rendered in enumerate(['EQ()', 'Linear()']):
        assert str(product.factor(position)) == rendered
    with pytest.raises(IndexError):
        product.factor(2)

    # The sum EQ() + EQ() is expected to contribute two factors:
    # the constant '2' and 'EQ()'.
    product = (EQ() + EQ()) * Delta() * (RQ(1) + Linear())
    assert product.num_factors == 4
    expected_factors = ['2', 'EQ()', 'Delta()', 'RQ(1) + Linear()']
    for position, rendered in enumerate(expected_factors):
        assert str(product.factor(position)) == rendered
    with pytest.raises(IndexError):
        product.factor(4)

    # A non-product kernel has no factor at index 1.
    with pytest.raises(IndexError):
        EQ().factor(1)
Example #21
0
def test_terms():
    """Check ``num_terms`` and ``term(i)`` on a sum kernel, including
    out-of-range indices.

    This was converted from a nose-style generator test. pytest >= 4.0 does
    not execute ``yield`` tests, so ``eq``/``raises`` are called directly.
    """
    k = EQ() + EQ() * Linear() + RQ(1) * RQ(2) + Delta()
    eq(k.num_terms, 4)
    eq(str(k.term(0)), 'EQ()')
    eq(str(k.term(1)), 'EQ() * Linear()')
    eq(str(k.term(2)), 'RQ(1) * RQ(2)')
    eq(str(k.term(3)), 'Delta()')
    raises(IndexError, lambda: k.term(4))
    raises(IndexError, lambda: EQ().term(1))
Example #22
0
def test_terms():
    """Check ``num_terms`` and ``term(i)`` on a sum kernel, including
    out-of-range indices."""
    total = EQ() + EQ() * Linear() + RQ(1) * RQ(2) + Delta()
    expected_terms = ['EQ()', 'EQ() * Linear()', 'RQ(1) * RQ(2)', 'Delta()']

    assert total.num_terms == 4
    for position, rendered in enumerate(expected_terms):
        assert str(total.term(position)) == rendered

    # Index one past the last term, and index 1 of a single-term kernel.
    with pytest.raises(IndexError):
        total.term(4)
    with pytest.raises(IndexError):
        EQ().term(1)
Example #23
0
def test_cancellations_zero():
    """Check that multiplying by 0/1, adding 0 and combining with
    ``ZeroKernel`` simplify in the ``str`` output.

    This was converted from a nose-style generator test. pytest >= 4.0 does
    not execute ``yield`` tests, so each ``eq`` check is called directly.
    """
    # With constants:
    eq(str(1 * EQ()), 'EQ()')
    eq(str(EQ() * 1), 'EQ()')
    eq(str(0 * EQ()), '0')
    eq(str(EQ() * 0), '0')
    eq(str(0 + EQ()), 'EQ()')
    eq(str(EQ() + 0), 'EQ()')
    eq(str(0 + OneMean()), '1')
    eq(str(OneMean() + 0), '1')

    # Adding to zero:
    eq(str(0 + ZeroKernel()), '0')
    eq(str(ZeroKernel() + 0), '0')
    eq(str(1 + ZeroKernel()), '1')
    eq(str(ZeroKernel() + 1), '1')
    eq(str(2 + ZeroKernel()), '2 * 1')
    eq(str(ZeroKernel() + 2), '2 * 1')

    # Sums:
    eq(str(EQ() + EQ()), '2 * EQ()')
    eq(str(ZeroKernel() + EQ()), 'EQ()')
    eq(str(EQ() + ZeroKernel()), 'EQ()')
    eq(str(ZeroKernel() + ZeroKernel()), '0')

    # Products:
    eq(str(EQ() * EQ()), 'EQ() * EQ()')
    eq(str(ZeroKernel() * EQ()), '0')
    eq(str(EQ() * ZeroKernel()), '0')
    eq(str(ZeroKernel() * ZeroKernel()), '0')

    # Scales:
    eq(str(5 * ZeroKernel()), '0')
    eq(str(ZeroKernel() * 5), '0')
    eq(str(EQ() * 5), '5 * EQ()')
    eq(str(5 * EQ()), '5 * EQ()')

    # Stretches:
    eq(str(ZeroKernel().stretch(5)), '0')
    eq(str(EQ().stretch(5)), 'EQ() > 5')

    # Periodicisations:
    eq(str(ZeroKernel().periodic(5)), '0')
    eq(str(EQ().periodic(5)), 'EQ() per 5')

    # Reversals:
    eq(str(reversed(ZeroKernel())), '0')
    eq(str(reversed(EQ())), 'EQ()')
    eq(str(reversed(Linear())), 'Reversed(Linear())')

    # Integration:
    eq(str(EQ() * EQ() + ZeroKernel() * EQ()), 'EQ() * EQ()')
    eq(str(EQ() * ZeroKernel() + ZeroKernel() * EQ()), '0')
Example #24
0
def test_input_transform():
    """Check properties, equality, standard behaviour and evaluation of
    kernels whose inputs are transformed by functions."""
    transformed = Linear().transform(lambda x: x - 5)

    # Reading any of these summary properties must raise.
    assert not transformed.stationary
    for attribute in ('length_scale', 'var', 'period'):
        with pytest.raises(RuntimeError):
            getattr(transformed, attribute)

    def identity(x):
        return x

    def square(x):
        return x ** 2

    # Equality depends on both the transform and the wrapped kernel.
    assert EQ().transform(identity) == EQ().transform(identity)
    assert EQ().transform(identity) != EQ().transform(square)
    assert EQ().transform(identity) != Matern12().transform(identity)

    # Shared battery of kernel checks.
    standard_kernel_tests(transformed)

    # With one function, it is applied to both inputs. With two, the first
    # applies to the first input and the second to the second input.
    base = Linear()
    x1 = np.random.randn(10, 2)
    x2 = np.random.randn(10, 2)

    squared = base.transform(lambda x: x ** 2)
    mixed = base.transform(lambda x: x ** 2, lambda x: x - 5)

    allclose(base(x1 ** 2, x2 ** 2), squared(x1, x2))
    allclose(base(x1 ** 2, x2 - 5), mixed(x1, x2))
Example #25
0
def test_input_transform():
    """Check properties, equality, standard behaviour and evaluation of
    kernels whose inputs are transformed by two-argument ``(x, c)``
    functions.

    This was converted from a nose-style generator test. pytest >= 4.0 does
    not execute ``yield`` tests. Calling the helpers directly also makes
    each ``lambda: k.<prop>`` run against the transformed kernel it was
    written for, not the ``Linear()`` that ``k`` is rebound to later.
    """
    k = Linear().transform(lambda x, c: x - 5)

    eq(k.stationary, False)
    raises(RuntimeError, lambda: k.length_scale)
    raises(RuntimeError, lambda: k.var)
    raises(RuntimeError, lambda: k.period)

    def f1(x):
        return x

    def f2(x):
        return x**2

    # Test equality.
    eq(EQ().transform(f1), EQ().transform(f1))
    neq(EQ().transform(f1), EQ().transform(f2))
    neq(EQ().transform(f1), Matern12().transform(f1))

    # Standard tests.
    # NOTE(review): assumes kernel_generator yields nose-style items, i.e.
    # either a bare callable or a ``(callable, *args)`` tuple; confirm.
    for case in kernel_generator(k):
        if callable(case):
            case()
        else:
            case[0](*case[1:])

    # Test computation of the kernel.
    k = Linear()
    x1, x2 = np.random.randn(10, 2), np.random.randn(10, 2)

    k2 = k.transform(lambda x, c: x**2)
    k3 = k.transform(lambda x, c: x**2, lambda x, c: x - 5)

    assert_allclose(k(x1**2, x2**2), k2(x1, x2))
    assert_allclose(k(x1**2, x2 - 5), k3(x1, x2))
Example #26
0
def test_input_transform():
    """Check stationarity, equality, standard behaviour and evaluation of
    kernels whose inputs are transformed by functions."""
    transformed = Linear().transform(lambda x: x - 5)

    # Transforming the inputs of Linear() is expected to stay
    # non-stationary.
    assert not transformed.stationary

    def identity(x):
        return x

    def square(x):
        return x**2

    # Equality depends on both the transform and the wrapped kernel.
    assert EQ().transform(identity) == EQ().transform(identity)
    assert EQ().transform(identity) != EQ().transform(square)
    assert EQ().transform(identity) != Matern12().transform(identity)

    # Shared battery of kernel checks.
    standard_kernel_tests(transformed)

    # With one function, it is applied to both inputs. With two, the first
    # applies to the first input and the second to the second input.
    base = Linear()
    x1 = B.randn(10, 2)
    x2 = B.randn(10, 2)

    squared = base.transform(lambda x: x**2)
    mixed = base.transform(lambda x: x**2, lambda x: x - 5)

    approx(base(x1**2, x2**2), squared(x1, x2))
    approx(base(x1**2, x2 - 5), mixed(x1, x2))
Example #27
0
def test_eq():
    """Check stationarity, string form, equality and standard behaviour
    of the ``EQ`` kernel."""
    eq_kernel = EQ()

    # EQ() is expected to be stationary and to print its name.
    assert eq_kernel.stationary
    assert str(eq_kernel) == "EQ()"

    # Two fresh EQ() instances compare equal; Linear() does not.
    assert EQ() == EQ()
    assert EQ() != Linear()

    # Shared battery of kernel checks.
    standard_kernel_tests(eq_kernel)
Example #28
0
def test_mat52():
    """Check stationarity, string form, equality and standard behaviour
    of the ``Matern52`` kernel."""
    matern = Matern52()

    # Matern52() is expected to be stationary and to print its name.
    assert matern.stationary
    assert str(matern) == 'Matern52()'

    # Two fresh Matern52() instances compare equal; Linear() does not.
    assert Matern52() == Matern52()
    assert Matern52() != Linear()

    # Shared battery of kernel checks.
    standard_kernel_tests(matern)
Example #29
0
def test_derivative_linear():
    """Check derivatives of ``Linear()`` with respect to its first input,
    its second input and both inputs against closed-form references,
    for two-input and one-input calls."""
    kernel = Linear()
    dtype = tf.float64
    x1 = B.randn(dtype, 10, 1)
    x2 = B.randn(dtype, 5, 1)

    # Derivative with respect to the first input.
    first = kernel.diff(0, None)
    allclose(first(x1, x2), B.ones(dtype, 10, 5) * B.transpose(x2))
    allclose(first(x1), B.ones(dtype, 10, 10) * B.transpose(x1))

    # Derivative with respect to the second input.
    second = kernel.diff(None, 0)
    allclose(second(x1, x2), B.ones(dtype, 10, 5) * x1)
    allclose(second(x1), B.ones(dtype, 10, 10) * x1)

    # Derivative with respect to both inputs: a matrix of ones, whether
    # requested as diff(0, 0) or diff(0).
    for inputs, shape in (((x1, x2), (10, 5)), ((x1,), (10, 10))):
        ones = B.ones(dtype, *shape)
        allclose(kernel.diff(0, 0)(*inputs), ones)
        allclose(kernel.diff(0)(*inputs), ones)
Example #30
0
def test_product():
    """Check stationarity, equality and standard behaviour of product
    kernels."""
    product = (2 * EQ().stretch(10)) * (3 * RQ(1e-2).stretch(20))

    # The product of these two scaled, stretched stationary kernels is
    # expected to be stationary.
    assert product.stationary

    # Equality ignores factor order but depends on which factors appear.
    assert EQ() * Linear() == EQ() * Linear()
    assert EQ() * Linear() == Linear() * EQ()
    for other in (EQ() * RQ(1e-1), RQ(1e-1) * Linear()):
        assert EQ() * Linear() != other

    # Shared battery of kernel checks.
    standard_kernel_tests(product)