コード例 #1
0
def get_kernel(kernel_name, input_shape):
    """Build the kernel registered under ``kernel_name``.

    Recognized names are ``"white_noise"``, ``"linear"``, ``"polynomial"``,
    ``"exp_quad"``, ``"rat_quad"`` and the Matern variants ``"matern12"``,
    ``"matern32"``, ``"matern52"`` and ``"matern72"``. The Matern variants use
    ``nu`` = 0.5, 1.5, 2.5 and 3.5 respectively.

    Args:
        kernel_name: Name of the kernel to construct.
        input_shape: Forwarded unchanged as ``input_shape=`` to the kernel
            constructor.

    Returns:
        A freshly constructed kernel instance from the ``kernels`` module.

    Raises:
        ValueError: If ``kernel_name`` matches none of the names above.
    """
    # Factories are zero-argument callables. Only the kernel class whose name
    # matches is looked up and instantiated, just like the original if/elif
    # chain.
    factories = (
        ("white_noise", lambda: kernels.WhiteNoise(input_shape=input_shape)),
        ("linear", lambda: kernels.Linear(input_shape=input_shape)),
        ("polynomial", lambda: kernels.Polynomial(input_shape=input_shape)),
        ("exp_quad", lambda: kernels.ExpQuad(input_shape=input_shape)),
        ("rat_quad", lambda: kernels.RatQuad(input_shape=input_shape)),
        ("matern12", lambda: kernels.Matern(input_shape=input_shape, nu=0.5)),
        ("matern32", lambda: kernels.Matern(input_shape=input_shape, nu=1.5)),
        ("matern52", lambda: kernels.Matern(input_shape=input_shape, nu=2.5)),
        ("matern72", lambda: kernels.Matern(input_shape=input_shape, nu=3.5)),
    )
    # Scan in order with ``==`` rather than using a dict lookup. This keeps
    # the original comparison semantics, so unhashable names still end in
    # ValueError instead of TypeError.
    for name, make in factories:
        if kernel_name == name:
            return make()
    raise ValueError(f"Kernel name '{kernel_name}' not recognized.")
コード例 #2
0
def test_add(kernel: kernels.Kernel):
    """Adding a WhiteNoise kernel with ``+`` must yield a shape-preserving SumKernel."""
    noise = kernels.WhiteNoise(input_shape=kernel.input_shape)
    combined = kernel + noise

    assert isinstance(combined, SumKernel)
    # The sum must report the same input and output shapes as the left operand.
    assert combined.input_shape == kernel.input_shape
    assert combined.output_shape == kernel.output_shape
コード例 #3
0
def test_sum_kernel_evaluation(kernel: kernels.Kernel, x0: np.ndarray):
    """SumKernel.matrix must equal the sum of its two components' matrices."""
    noise = kernels.WhiteNoise(input_shape=kernel.input_shape)
    summed = SumKernel(kernel, noise)

    # Evaluate in the same order as before: the combined kernel first, then
    # each component.
    actual = summed.matrix(x0)
    expected = kernel.matrix(x0) + noise.matrix(x0)
    np.testing.assert_allclose(actual, expected)
コード例 #4
0
def test_product_kernel_shape_mismatch_raises_error():
    """ProductKernel must raise ValueError when its factors' input shapes differ."""
    with pytest.raises(ValueError):
        # Construction stays inside the ``raises`` context, as in the
        # original, so a ValueError from either constructor also satisfies
        # the test.
        scalar_input = kernels.WhiteNoise(input_shape=())
        vector_input = kernels.WhiteNoise(input_shape=(1, ))
        ProductKernel(scalar_input, vector_input)
コード例 #5
0
def test_non_scalar_raises_error():
    """ScaledKernel must raise TypeError when given a non-scalar ``scalar`` argument."""
    with pytest.raises(TypeError):
        # Operands are built inside the context, matching the original, so
        # any TypeError raised while preparing them is also accepted.
        base = kernels.WhiteNoise(input_shape=())
        vector = np.array([0, 1])
        ScaledKernel(kernel=base, scalar=vector)