def test_prior_mll():
    """
    Test that the MLL evaluation works with priors attached to the parameter values.

    The marginal log-likelihood of a conjugate RBF/Gaussian posterior is evaluated
    on noisy sinc data twice: once without priors and once with Gamma priors on
    ``lengthscale``, ``variance`` and ``obs_noise``. Both results are compared
    against hard-coded reference values.
    """
    key = jr.PRNGKey(123)
    x = jnp.sort(jr.uniform(key, minval=-5.0, maxval=5.0, shape=(100, 1)), axis=0)

    def f(inputs):
        # Normalised sinc: sin(pi * x) / (pi * x).
        return jnp.sin(jnp.pi * inputs) / (jnp.pi * inputs)

    # The same key drives both the inputs and the noise. The reference values
    # asserted below depend on exactly this data, so do not change the key usage.
    y = f(x) + jr.normal(key, shape=x.shape) * 0.1

    posterior = Prior(kernel=RBF()) * Gaussian()
    params = initialise(posterior)
    config = get_defaults()
    constrainer, unconstrainer = build_all_transforms(params.keys(), config)
    params = unconstrainer(params)
    mll = marginal_ll(posterior, transform=constrainer)

    priors = {
        "lengthscale": tfd.Gamma(1.0, 1.0),
        "variance": tfd.Gamma(2.0, 2.0),
        "obs_noise": tfd.Gamma(2.0, 2.0),
    }
    mll_eval = mll(params, x, y)
    mll_eval_priors = mll(params, x, y, priors)

    assert pytest.approx(mll_eval) == jnp.array(-103.28180663)
    assert pytest.approx(mll_eval_priors) == jnp.array(-105.509218857)
def init_svgp_sample():
    """Construct an ``SVGPSample`` model from the surrounding test configuration.

    ``input_dim``, ``output_dim``, ``num_inducing``, ``q_diag``, ``whiten`` and
    ``key`` are read from the enclosing scope; they are not defined here.
    With several outputs, one squared-exponential kernel per output is wrapped
    in ``SeparateIndependent``. Otherwise a single kernel is used.
    """
    mean_function = Constant(output_dim=output_dim)
    likelihood = Gaussian()

    def _se_kernel():
        # Unit lengthscale in every input dimension, variance 2.0.
        return SquaredExponential(
            lengthscales=jnp.ones(input_dim, dtype=jnp.float64),
            variance=2.0,
        )

    if output_dim <= 1:
        kernel = _se_kernel()
    else:
        kernel = SeparateIndependent([_se_kernel() for _ in range(output_dim)])

    inducing_variable = jax.random.uniform(key=key, shape=(num_inducing, input_dim))
    return SVGPSample(
        kernel,
        likelihood,
        inducing_variable,
        mean_function,
        num_latent_gps=output_dim,
        q_diag=q_diag,
        whiten=whiten,
    )
def _get_conjugate_posterior_params() -> tuple:
    """Build a default conjugate posterior together with its initial parameters.

    Returns:
        A ``(params, posterior)`` pair. ``posterior`` is
        ``Prior(kernel=RBF()) * Gaussian()`` and ``params`` is the result of
        ``initialise(posterior)``.
    """
    # The previous ``-> dict`` annotation was wrong: a 2-tuple is returned.
    prior = Prior(kernel=RBF())
    posterior = prior * Gaussian()
    params = initialise(posterior)
    return params, posterior
def test_posterior_random_variable(n):
    """The posterior random variable at ``n`` test points is a full-covariance MVN."""
    posterior = Prior(kernel=RBF()) * Gaussian()
    train_x = jnp.linspace(-1.0, 1.0, 10).reshape(-1, 1)
    train_y = jnp.sin(train_x)
    test_x = jnp.linspace(-1.0, 1.0, num=n).reshape(-1, 1)
    params = initialise(posterior)
    predictive = random_variable(posterior, params, test_x, train_x, train_y)
    assert isinstance(predictive, tfd.MultivariateNormalFullCovariance)
def test_spectral():
    """Initialising a 10-basis spectral posterior yields the expected parameter set."""
    key = jr.PRNGKey(123)
    num_basis = 10
    spectral_kernel = to_spectral(RBF(), num_basis)
    posterior = Prior(kernel=spectral_kernel) * Gaussian()
    params = initialise(key, posterior)

    expected_keys = sorted(["basis_fns", "obs_noise", "lengthscale", "variance"])
    assert list(params.keys()) == expected_keys
    assert params["basis_fns"].shape == (num_basis, 1)
def test_posterior_sample(n, n_sample):
    """Drawing from the posterior random variable gives ``(n_sample, n)`` samples."""
    key = jr.PRNGKey(123)
    posterior = Prior(kernel=RBF()) * Gaussian()
    train_x = jnp.linspace(-1.0, 1.0, 10).reshape(-1, 1)
    train_y = jnp.sin(train_x)
    test_x = jnp.linspace(-1.0, 1.0, num=n).reshape(-1, 1)
    params = initialise(posterior)
    predictive = random_variable(posterior, params, test_x, train_x, train_y)
    draws = sample(key, predictive, n_samples=n_sample)

    expected_shape = (n_sample, test_x.shape[0])
    assert draws.shape == expected_shape
def test_conjugate_variance():
    """The conjugate posterior ``variance`` at the test inputs has shape ``(n_test, n_test)``."""
    key = jr.PRNGKey(123)
    train_x = jr.uniform(key, shape=(20, 1), minval=-3.0, maxval=3.0)
    train_y = jnp.sin(train_x)
    posterior = Prior(kernel=RBF()) * Gaussian()
    params = initialise(posterior)
    test_x = jnp.linspace(-3.0, 3.0, 30).reshape(-1, 1)

    sigma = variance(posterior, params, test_x, train_x, train_y)
    n_test = test_x.shape[0]
    assert sigma.shape == (n_test, n_test)
def test_posterior_random_variable(n):
    """``random_variable`` on a ``Dataset`` returns a callable that yields a full-covariance MVN."""
    posterior = Prior(kernel=RBF()) * Gaussian()
    train_x = jnp.linspace(-1.0, 1.0, 10).reshape(-1, 1)
    train_y = jnp.sin(train_x)
    data = Dataset(X=train_x, y=train_y)
    test_x = jnp.linspace(-1.0, 1.0, num=n).reshape(-1, 1)
    params = initialise(posterior)

    predictive_fn = random_variable(posterior, params, data)
    assert isinstance(predictive_fn, Callable)

    predictive = predictive_fn(test_x)
    assert isinstance(predictive, tfd.MultivariateNormalFullCovariance)
def test_conjugate():
    """``marginal_ll`` returns a callable, and ``negative=True`` exactly negates it."""
    posterior = Prior(kernel=RBF()) * Gaussian()
    inputs = jnp.linspace(-1.0, 1.0, 20).reshape(-1, 1)
    targets = jnp.sin(inputs)
    params = initialise(posterior)
    defaults = get_defaults()
    # NOTE(review): confirm build_all_transforms returns (unconstrainer, constrainer)
    # in this order. The sign check below uses the same params and transform on
    # both sides, so it would not catch a swap.
    unconstrainer, constrainer = build_all_transforms(params.keys(), defaults)
    params = unconstrainer(params)

    log_ml = marginal_ll(posterior, transform=constrainer)
    assert isinstance(log_ml, Callable)

    neg_log_ml = marginal_ll(posterior, transform=constrainer, negative=True)
    assert neg_log_ml(params, inputs, targets) == jnp.array(-1.0) * log_ml(
        params, inputs, targets
    )
def test_conjugate_mean():
    """The posterior mean function evaluated at test inputs has shape ``(n_test, y.shape[1])``."""
    key = jr.PRNGKey(123)
    train_x = jr.uniform(key, shape=(20, 1), minval=-3.0, maxval=3.0)
    train_y = jnp.sin(train_x)
    data = Dataset(X=train_x, y=train_y)
    posterior = Prior(kernel=RBF()) * Gaussian()
    params = initialise(posterior)
    test_x = jnp.linspace(-3.0, 3.0, 30).reshape(-1, 1)

    mean_fn = mean(posterior, params, data)
    predicted = mean_fn(test_x)
    assert predicted.shape == (test_x.shape[0], train_y.shape[1])
def test_spectral_sample():
    """A spectral posterior with ``basis_fns`` passed as static params yields a full-covariance MVN."""
    key = jr.PRNGKey(123)
    num_basis = 10
    train_x = jnp.linspace(-1.0, 1.0, 20).reshape(-1, 1)
    train_y = jnp.sin(train_x)
    data = Dataset(X=train_x, y=train_y)
    test_x = jnp.linspace(-1.0, 1.0, num=50).reshape(-1, 1)

    posterior = Prior(kernel=to_spectral(RBF(), num_basis)) * Gaussian()
    params = initialise(key, posterior)

    # Move the basis functions out of ``params`` and pass them separately
    # through ``static_params``.
    basis = params["basis_fns"]
    del params["basis_fns"]
    static_params = {"basis_fns": basis}

    predictive_fn = random_variable(posterior, params, data, static_params=static_params)
    predictive = predictive_fn(test_x)
    assert isinstance(predictive, tfd.Distribution)
    assert isinstance(predictive, tfd.MultivariateNormalFullCovariance)
def test_spectral():
    """The spectral posterior's ``marginal_ll`` is callable, sign-flippable and scalar-valued."""
    key = jr.PRNGKey(123)
    spectral_kernel = to_spectral(RBF(), 10)
    posterior = Prior(kernel=spectral_kernel) * Gaussian()
    inputs = jnp.linspace(-1.0, 1.0, 20).reshape(-1, 1)
    data = Dataset(X=inputs, y=jnp.sin(inputs))
    params = initialise(key, posterior)
    defaults = get_defaults()
    # NOTE(review): confirm build_all_transforms returns (unconstrainer, constrainer)
    # in this order. The sign check below uses the same params and transform on
    # both sides, so it would not catch a swap.
    unconstrainer, constrainer = build_all_transforms(params.keys(), defaults)
    params = unconstrainer(params)

    log_ml = marginal_ll(posterior, transform=constrainer)
    assert isinstance(log_ml, Callable)

    neg_log_ml = marginal_ll(posterior, transform=constrainer, negative=True)
    neg_value = neg_log_ml(params, data)
    assert neg_value == jnp.array(-1.0) * log_ml(params, data)
    assert neg_value.shape == ()
def test_predict_f(
    self,
    input_dim,
    output_dim,
    num_data,
    num_inducing,
    q_diag,
    whiten,
    full_cov,
    full_output_cov,
    num_inducing_samples,
):
    """Check the shapes of the mean and covariance returned by ``predict_f``.

    ``S`` below stands for ``num_inducing_samples``. That leading dimension is
    absent when ``num_inducing_samples`` is ``None``.

    * mean: ``[S, num_data, output_dim]``
    * cov:  ``[S, output_dim, num_data, num_data]`` when ``full_cov``,
      otherwise ``[S, num_data, output_dim]``

    Raises:
        NotImplementedError: shape checks for ``full_output_cov=True`` have not
            been written yet. The prediction still runs before this is raised.
    """
    Xnew = jax.random.uniform(key, shape=(num_data, input_dim))
    mean_function = Constant(output_dim=output_dim)
    likelihood = Gaussian()
    if output_dim > 1:
        kernels = [
            SquaredExponential(
                lengthscales=jnp.ones(input_dim, dtype=jnp.float64), variance=2.0
            )
            for _ in range(output_dim)
        ]
        kernel = SeparateIndependent(kernels)
    else:
        kernel = SquaredExponential(
            lengthscales=jnp.ones(input_dim, dtype=jnp.float64), variance=2.0
        )
    inducing_variable = jax.random.uniform(key=key, shape=(num_inducing, input_dim))
    svgp = SVGPSample(
        kernel,
        likelihood,
        inducing_variable,
        mean_function,
        num_latent_gps=output_dim,
        q_diag=q_diag,
        whiten=whiten,
    )
    params = svgp.get_params()

    def predict_f(params, Xnew):
        return svgp.predict_f(
            params, Xnew, key, num_inducing_samples, full_cov, full_output_cov
        )

    var_predict_f = self.variant(predict_f)
    mean, cov = var_predict_f(params, Xnew)

    if full_output_cov:
        raise NotImplementedError("Need to add tests for full_output_cov=True")

    # Comparing whole shape tuples checks ndim and every dimension at once.
    sample_shape = () if num_inducing_samples is None else (num_inducing_samples,)
    assert mean.shape == (*sample_shape, num_data, output_dim)
    if full_cov:
        assert cov.shape == (*sample_shape, output_dim, num_data, num_data)
    else:
        assert cov.shape == (*sample_shape, num_data, output_dim)
def test_conjugate_posterior():
    """Multiplying a ``Prior`` by a ``Gaussian`` likelihood gives a ``ConjugatePosterior``."""
    posterior = Prior(kernel=RBF()) * Gaussian()
    assert isinstance(posterior, ConjugatePosterior)
def test_complete():
    """``complete`` extends a partial parameter dict to the full, sorted key set."""
    posterior = Prior(kernel=RBF()) * Gaussian()
    completed = complete({"lengthscale": jnp.array(1.0)}, posterior)
    expected_keys = sorted(["lengthscale", "variance", "obs_noise"])
    assert list(completed.keys()) == expected_keys
def test_spectral():
    """A spectral kernel prior combined with a Gaussian likelihood is a ``SpectralPosterior``."""
    spectral_prior = Prior(kernel=to_spectral(RBF(), 10))
    assert isinstance(spectral_prior * Gaussian(), SpectralPosterior)
def test_initialise():
    """``initialise`` on an RBF conjugate posterior yields the sorted default parameter keys."""
    posterior = Prior(kernel=RBF()) * Gaussian()
    expected_keys = sorted(["lengthscale", "variance", "obs_noise"])
    assert list(initialise(posterior).keys()) == expected_keys