示例#1
0
def test_spectral():
    """Initialising a posterior over ``to_spectral(RBF(), 10)`` yields spectral params.

    Builds a Gaussian-likelihood posterior on top of the spectral kernel,
    initialises it with a fixed PRNG key, and checks both the (ordered) set of
    parameter names and the shape of the ``basis_fns`` array.
    """
    n_basis = 10
    rng_key = jr.PRNGKey(123)
    spectral_kernel = to_spectral(RBF(), n_basis)
    conjugate_posterior = Prior(kernel=spectral_kernel) * Gaussian()
    hyperparams = initialise(rng_key, conjugate_posterior)

    # The keys are compared as an ordered list, so this also pins their order.
    expected_names = sorted(["basis_fns", "obs_noise", "lengthscale", "variance"])
    assert list(hyperparams.keys()) == expected_names
    assert hyperparams["basis_fns"].shape == (n_basis, 1)
示例#2
0
def test_prior():
    """A bare RBF prior initialises to a dict holding lengthscale and variance."""
    prior = Prior(kernel=RBF())
    hyperparams = initialise(prior)

    wanted = sorted(['lengthscale', 'variance'])
    # Ordered comparison: key order is part of what is being checked.
    assert list(hyperparams.keys()) == wanted
    assert isinstance(hyperparams, dict)
示例#3
0
def test_non_conjugate_initialise(n):
    """A Bernoulli posterior initialises a ``latent`` array of shape ``(n, 1)``.

    ``n`` is passed straight through to ``initialise``; presumably it is
    supplied by pytest parametrisation outside this view — TODO confirm.
    """
    bernoulli_posterior = Prior(kernel=RBF()) * Bernoulli()
    hyperparams = initialise(bernoulli_posterior, n)

    expected_names = sorted(["lengthscale", "variance", "latent"])
    assert list(hyperparams.keys()) == expected_names
    latent = hyperparams["latent"]
    assert latent.shape == (n, 1)
示例#4
0
def test_dtype(lik):
    """Every parameter initialised for ``Prior(kernel=RBF()) * lik()`` is float64.

    ``lik`` is called with no arguments to build the likelihood; presumably it
    is supplied by pytest parametrisation outside this view — TODO confirm.
    """
    posterior = Prior(kernel=RBF()) * lik()
    params = initialise(posterior, 10)
    # Guard against a vacuous pass: the loop below asserts nothing on an
    # empty mapping.
    assert params, "initialise returned no parameters"
    for name, value in params.items():
        # Report the offending parameter so a failure is diagnosable.
        assert value.dtype == jnp.float64, f"{name!r} has dtype {value.dtype}"
示例#5
0
def test_initialise():
    """A Gaussian-likelihood RBF posterior initialises the three expected params."""
    conjugate_posterior = Prior(kernel=RBF()) * Gaussian()
    hyperparams = initialise(conjugate_posterior)

    expected_names = sorted(["lengthscale", "variance", "obs_noise"])
    # Ordered comparison: key order is part of what is being checked.
    assert list(hyperparams.keys()) == expected_names