def test_check_needless():
    """prior_checks returns a prior dict equal to its input when every entry is already supplied."""
    gamma = tfd.Gamma
    complete_prior = dict(
        lengthscale=gamma(1.0, 1.0),
        variance=gamma(2.0, 2.0),
        obs_noise=gamma(3.0, 3.0),
        latent=tfd.Normal(loc=0.0, scale=1.0),
    )
    model = Prior(kernel=RBF()) * Bernoulli()
    checked = prior_checks(model, complete_prior)
    assert checked == complete_prior
def test_predictive_moment(n):
    """predictive_moments(Bernoulli()) maps length-n latent mean/variance vectors to an
    object whose mean() and variance() are also length-n vectors."""
    likelihood = Bernoulli()
    # Independent subkeys: reusing a single key made fvar exactly equal to -fmean.
    mean_key, var_key = jr.split(jr.PRNGKey(123))
    fmean = -jr.uniform(key=mean_key, shape=(n,))
    fvar = jr.uniform(key=var_key, shape=(n,))
    pred_mom_fn = predictive_moments(likelihood)
    rv = pred_mom_fn(fmean, fvar)
    mu = rv.mean()
    sigma = rv.variance()
    assert mu.shape == (n,)
    assert sigma.shape == (n,)
def test_non_conjugate():
    """marginal_ll returns a callable, and negative=True yields its negation."""
    posterior = Prior(kernel=RBF()) * Bernoulli()
    n = 20
    x = jnp.linspace(-1.0, 1.0, n).reshape(-1, 1)
    # Bernoulli observations are 0/1 labels; jnp.sin(x) alone gave values in
    # (-1, 1), outside the likelihood's support.
    y = jnp.where(jnp.sin(x) > 0.0, 1.0, 0.0)
    # Size the latent parameters from n instead of repeating the literal.
    params = initialise(posterior, n)
    config = get_defaults()
    unconstrainer, constrainer = build_all_transforms(params.keys(), config)
    params = unconstrainer(params)
    mll = marginal_ll(posterior, transform=constrainer)
    assert isinstance(mll, Callable)
    neg_mll = marginal_ll(posterior, transform=constrainer, negative=True)
    # Tolerance-based comparison instead of exact float equality.
    assert jnp.allclose(neg_mll(params, x, y), -mll(params, x, y))
def test_non_conjugate_variance():
    """variance(posterior, params, D) returns a function giving one value per test input."""
    # Separate subkeys so the inputs and the label noise are independent draws.
    x_key, noise_key = jr.split(jr.PRNGKey(123))
    x = jnp.sort(jr.uniform(x_key, shape=(10, 1), minval=-1.0, maxval=1.0), axis=0)
    noise = jr.normal(noise_key, shape=x.shape) * 0.05
    # Maps sign(cos(.)) in {-1, 1} to labels in {0, 1} (0.5 only if cos is exactly 0).
    y = 0.5 * jnp.sign(jnp.cos(3 * x + noise)) + 0.5
    D = Dataset(X=x, y=y)
    xtest = jnp.linspace(-1.05, 1.05, 50).reshape(-1, 1)
    posterior = Prior(kernel=RBF()) * Bernoulli()
    params = initialise(posterior, x.shape[0])
    varf = variance(posterior, params, D)
    sigma = varf(xtest)
    assert sigma.shape == (xtest.shape[0],)
def test_non_conjugate_mean():
    """mean(posterior, params, xtest, x, y) returns one value per test input."""
    # Separate subkeys so the inputs and the label noise are independent draws.
    x_key, noise_key = jr.split(jr.PRNGKey(123))
    x = jnp.sort(jr.uniform(x_key, shape=(10, 1), minval=-1.0, maxval=1.0), axis=0)
    noise = jr.normal(noise_key, shape=x.shape) * 0.05
    # Maps sign(cos(.)) in {-1, 1} to labels in {0, 1} (0.5 only if cos is exactly 0).
    y = 0.5 * jnp.sign(jnp.cos(3 * x + noise)) + 0.5
    xtest = jnp.linspace(-1.05, 1.05, 50).reshape(-1, 1)
    posterior = Prior(kernel=RBF()) * Bernoulli()
    params = initialise(posterior, x.shape[0])
    mu = mean(posterior, params, xtest, x, y)
    assert mu.shape == (xtest.shape[0],)
def test_non_conjugate_rv(n):
    """random_variable(posterior, params, D) returns a callable that, evaluated at
    n test points, produces a tfd.ProbitBernoulli."""
    # Separate subkeys so the inputs and the label noise are independent draws.
    x_key, noise_key = jr.split(jr.PRNGKey(123))
    # Single name for the model; the former `f = posterior = ...` alias was redundant.
    posterior = Prior(kernel=RBF()) * Bernoulli()
    x = jnp.sort(jr.uniform(x_key, shape=(n, 1), minval=-1.0, maxval=1.0), axis=0)
    noise = jr.normal(noise_key, shape=x.shape) * 0.05
    # Maps sign(cos(.)) in {-1, 1} to labels in {0, 1} (0.5 only if cos is exactly 0).
    y = 0.5 * jnp.sign(jnp.cos(3 * x + noise)) + 0.5
    D = Dataset(X=x, y=y)
    sample_points = jnp.linspace(-1.0, 1.0, num=n).reshape(-1, 1)
    hyperparams = {"lengthscale": jnp.array([1.0]), "variance": jnp.array([1.0])}
    # NOTE(review): complete() presumably fills in the remaining parameters
    # (e.g. the latent vector) for x.shape[0] points — confirm against its definition.
    params = complete(hyperparams, posterior, x.shape[0])
    rv = random_variable(posterior, params, D)
    assert isinstance(rv, Callable)
    fstar = rv(sample_points)
    assert isinstance(fstar, tfd.ProbitBernoulli)
def test_checks():
    """Given only a lengthscale, prior_checks adds a 'latent' entry but no 'variance' entry."""
    posterior = Prior(kernel=RBF()) * Bernoulli()
    priors = prior_checks(posterior, {"lengthscale": jnp.array([1.0])})
    assert "latent" in priors
    assert "variance" not in priors
def test_non_conjugate_initialise(n):
    """initialise yields exactly the keys latent/lengthscale/variance, in that (sorted)
    order, with an (n, 1) latent array."""
    model = Prior(kernel=RBF()) * Bernoulli()
    params = initialise(model, n)
    expected_keys = ["latent", "lengthscale", "variance"]
    assert list(params) == expected_keys
    assert params["latent"].shape == (n, 1)