def test_spectral():
    """Initialise a Gaussian-likelihood posterior over ``to_spectral(RBF(), 10)``.

    Expects the parameter dict to contain ``basis_fns``, ``lengthscale``,
    ``obs_noise`` and ``variance`` (keys in sorted order), with ``basis_fns``
    of shape ``(10, 1)``.
    """
    n_basis = 10
    spectral_kernel = to_spectral(RBF(), n_basis)
    conjugate_posterior = Prior(kernel=spectral_kernel) * Gaussian()
    hyperparams = initialise(jr.PRNGKey(123), conjugate_posterior)

    # Comparing against a sorted list also pins the dict's key order.
    expected_keys = sorted(["basis_fns", "obs_noise", "lengthscale", "variance"])
    assert list(hyperparams.keys()) == expected_keys
    assert hyperparams["basis_fns"].shape == (n_basis, 1)
def test_prior():
    """Initialising a bare RBF prior returns a dict of kernel hyperparameters.

    Expects exactly ``lengthscale`` and ``variance``, with keys in sorted order.
    """
    p = Prior(kernel=RBF())
    params = initialise(p)
    # Check the container type first so a non-dict result fails with a clear
    # assertion instead of an AttributeError on ``.keys()`` below.
    assert isinstance(params, dict)
    # Comparing against a sorted list also pins the dict's key order.
    assert list(params.keys()) == sorted(["lengthscale", "variance"])
def test_non_conjugate_initialise(n):
    """Check that a Bernoulli-likelihood posterior gains a ``latent`` parameter.

    ``n`` is forwarded to ``initialise`` as its second argument, and the test
    expects ``latent`` to have shape ``(n, 1)``. ``n`` is presumably supplied
    by a parametrize decorator or fixture not visible here (confirm).
    """
    rbf_prior = Prior(kernel=RBF())
    bernoulli_posterior = rbf_prior * Bernoulli()
    hyperparams = initialise(bernoulli_posterior, n)

    # Comparing against a sorted list also pins the dict's key order.
    expected_keys = sorted(["lengthscale", "variance", "latent"])
    assert list(hyperparams.keys()) == expected_keys
    assert hyperparams["latent"].shape == (n, 1)
def test_dtype(lik):
    """Every initialised parameter of ``RBF prior * lik()`` has dtype float64.

    ``lik`` is called with no arguments to build the likelihood. It is
    presumably supplied via a parametrize decorator not visible here (confirm).

    NOTE(review): JAX only produces float64 arrays when 64-bit mode
    (``jax_enable_x64``) is enabled. This is presumably configured elsewhere
    in the test module (confirm).
    """
    posterior = Prior(kernel=RBF()) * lik()
    params = initialise(posterior, 10)
    for name, value in params.items():
        # Name the offending parameter so a failure is immediately diagnosable.
        assert value.dtype == jnp.float64, (
            f"parameter {name!r} has dtype {value.dtype}, expected float64"
        )
def test_initialise():
    """Initialise a Gaussian-likelihood posterior over an RBF kernel.

    Expects exactly ``lengthscale``, ``obs_noise`` and ``variance``, with keys
    in sorted order.
    """
    params = initialise(Prior(kernel=RBF()) * Gaussian())
    wanted = sorted(["lengthscale", "variance", "obs_noise"])
    # Comparing against a sorted list also pins the dict's key order.
    assert [*params.keys()] == wanted