def test_prior_sample(n, n_sample):
    """Sampling from an RBF prior yields an ``(n_sample, n)`` array."""
    rng = jr.PRNGKey(123)
    prior = Prior(kernel=RBF())
    inputs = jnp.linspace(-1.0, 1.0, num=n).reshape(-1, 1)
    hyperparams = initialise(RBF())
    draws = sample(rng, prior, hyperparams, inputs, n_samples=n_sample)
    expected_shape = (n_sample, inputs.shape[0])
    assert draws.shape == expected_shape
def test_prior_random_variable(n):
    """An RBF prior evaluated on a dataset is a full-covariance MVN."""
    prior = Prior(kernel=RBF())
    inputs = jnp.linspace(-1.0, 1.0, num=n).reshape(-1, 1)
    dataset = Dataset(X=inputs)
    hyperparams = initialise(RBF())
    prior_rv = random_variable(prior, hyperparams, dataset)
    assert isinstance(prior_rv, tfd.MultivariateNormalFullCovariance)
def test_call():
    """Calling an RBF kernel on two single points returns a scalar array.

    ``jnp.ndarray`` is used for the type check rather than ``jnp.DeviceArray``:
    ``DeviceArray`` was deprecated and later removed from JAX, whereas
    ``isinstance(..., jnp.ndarray)`` holds for concrete JAX arrays in both old
    and new releases.
    """
    kernel = RBF()
    params = initialise(kernel)
    x, y = jnp.array([[1.0]]), jnp.array([[0.5]])
    point_corr = kernel(x, y, params)
    assert isinstance(point_corr, jnp.ndarray)
    assert point_corr.shape == ()
def test_prior_mll():
    """Test that the MLL evaluation works with priors attached to the parameter values.

    The marginal log-likelihood of a conjugate RBF/Gaussian posterior is
    evaluated on unconstrained parameters, both without and with Gamma priors
    on the hyperparameters, and compared with reference values.
    """
    key = jr.PRNGKey(123)
    # NOTE(review): the same key draws both x and the observation noise; the
    # hard-coded reference values below depend on this, so splitting the key
    # would require regenerating them.
    x = jnp.sort(jr.uniform(key, minval=-5.0, maxval=5.0, shape=(100, 1)), axis=0)
    f = lambda x: jnp.sin(jnp.pi * x) / (jnp.pi * x)
    y = f(x) + jr.normal(key, shape=x.shape) * 0.1

    posterior = Prior(kernel=RBF()) * Gaussian()
    params = initialise(posterior)
    config = get_defaults()
    constrainer, unconstrainer = build_all_transforms(params.keys(), config)
    # Optimisation works in unconstrained space; the MLL maps back via `transform`.
    params = unconstrainer(params)
    mll = marginal_ll(posterior, transform=constrainer)

    priors = {
        "lengthscale": tfd.Gamma(1.0, 1.0),
        "variance": tfd.Gamma(2.0, 2.0),
        "obs_noise": tfd.Gamma(2.0, 2.0),
    }
    mll_eval = mll(params, x, y)
    mll_eval_priors = mll(params, x, y, priors)

    assert pytest.approx(mll_eval) == jnp.array(-103.28180663)
    assert pytest.approx(mll_eval_priors) == jnp.array(-105.509218857)
def _get_conjugate_posterior_params() -> tuple:
    """Build an RBF/Gaussian conjugate posterior and its initial parameters.

    Returns:
        A ``(params, posterior)`` tuple: the parameter dict produced by
        ``initialise(posterior)`` and the posterior formed as
        ``Prior(kernel=RBF()) * Gaussian()``.
    """
    kernel = RBF()
    prior = Prior(kernel=kernel)
    lik = Gaussian()
    posterior = prior * lik
    params = initialise(posterior)
    return params, posterior
def test_spectral():
    """Initialising a spectral posterior adds basis functions to the hyperparameters."""
    rng = jr.PRNGKey(123)
    spectral_kernel = to_spectral(RBF(), 10)
    posterior = Prior(kernel=spectral_kernel) * Gaussian()
    params = initialise(rng, posterior)
    expected_keys = sorted(["basis_fns", "obs_noise", "lengthscale", "variance"])
    assert list(params.keys()) == expected_keys
    assert params["basis_fns"].shape == (10, 1)
def test_posterior_random_variable(n):
    """A conjugate posterior queried at test points is a full-covariance MVN."""
    posterior = Prior(kernel=RBF()) * Gaussian()
    train_x = jnp.linspace(-1.0, 1.0, 10).reshape(-1, 1)
    train_y = jnp.sin(train_x)
    test_x = jnp.linspace(-1.0, 1.0, num=n).reshape(-1, 1)
    hyperparams = initialise(posterior)
    post_rv = random_variable(posterior, hyperparams, test_x, train_x, train_y)
    assert isinstance(post_rv, tfd.MultivariateNormalFullCovariance)
def test_constrain(likelihood):
    """Constraining keeps every parameter key and each value's dtype."""
    posterior = Prior(kernel=RBF()) * likelihood()
    params = initialise(posterior, 10)
    constrain = build_constrain(params.keys(), get_defaults())
    constrained = constrain(params)
    assert constrained.keys() == params.keys()
    for new_value, old_value in zip(constrained.values(), params.values()):
        assert new_value.dtype == old_value.dtype
def test_gram(dim):
    """The RBF Gram matrix is square with one row per input point."""
    column = jnp.linspace(-1.0, 1.0, num=10).reshape(-1, 1)
    # Tile the single column `dim` times to build a multi-dimensional input.
    inputs = jnp.hstack([column] * dim) if dim > 1 else column
    kernel = RBF()
    hyperparams = initialise(kernel)
    K = gram(kernel, inputs, hyperparams)
    n_rows, n_cols = K.shape[0], K.shape[1]
    assert n_rows == inputs.shape[0]
    assert n_rows == n_cols
def test_check_needless():
    """A prior dict that already covers every parameter is returned as-is."""
    full_priors = dict(
        lengthscale=tfd.Gamma(1.0, 1.0),
        variance=tfd.Gamma(2.0, 2.0),
        obs_noise=tfd.Gamma(3.0, 3.0),
        latent=tfd.Normal(loc=0.0, scale=1.0),
    )
    posterior = Prior(kernel=RBF()) * Bernoulli()
    checked = prior_checks(posterior, full_priors)
    assert checked == full_priors
def test_posterior_sample(n, n_sample):
    """Sampling a conjugate posterior yields an ``(n_sample, n)`` array."""
    rng = jr.PRNGKey(123)
    posterior = Prior(kernel=RBF()) * Gaussian()
    train_x = jnp.linspace(-1.0, 1.0, 10).reshape(-1, 1)
    train_y = jnp.sin(train_x)
    test_x = jnp.linspace(-1.0, 1.0, num=n).reshape(-1, 1)
    hyperparams = initialise(posterior)
    post_rv = random_variable(posterior, hyperparams, test_x, train_x, train_y)
    draws = sample(rng, post_rv, n_samples=n_sample)
    assert draws.shape == (n_sample, test_x.shape[0])
def test_posterior_random_variable(n):
    """The conjugate posterior RV is a callable that returns an MVN at test points."""
    posterior = Prior(kernel=RBF()) * Gaussian()
    train_x = jnp.linspace(-1.0, 1.0, 10).reshape(-1, 1)
    dataset = Dataset(X=train_x, y=jnp.sin(train_x))
    test_x = jnp.linspace(-1.0, 1.0, num=n).reshape(-1, 1)
    hyperparams = initialise(posterior)
    predictive = random_variable(posterior, hyperparams, dataset)
    assert isinstance(predictive, Callable)
    assert isinstance(predictive(test_x), tfd.MultivariateNormalFullCovariance)
def test_conjugate_variance():
    """The conjugate predictive covariance is ``(n_test, n_test)``."""
    rng = jr.PRNGKey(123)
    train_x = jr.uniform(rng, shape=(20, 1), minval=-3.0, maxval=3.0)
    train_y = jnp.sin(train_x)
    posterior = Prior(kernel=RBF()) * Gaussian()
    hyperparams = initialise(posterior)
    test_x = jnp.linspace(-3.0, 3.0, 30).reshape(-1, 1)
    cov = variance(posterior, hyperparams, test_x, train_x, train_y)
    n_test = test_x.shape[0]
    assert cov.shape == (n_test, n_test)
def test_non_conjugate_mean():
    """The Bernoulli-posterior predictive mean is one value per test point."""
    rng = jr.PRNGKey(123)
    train_x = jnp.sort(jr.uniform(rng, shape=(10, 1), minval=-1.0, maxval=1.0), axis=0)
    jitter = jr.normal(rng, shape=train_x.shape) * 0.05
    # Map sign(cos(.)) from {-1, 1} onto {0, 1} binary labels.
    train_y = 0.5 * jnp.sign(jnp.cos(3 * train_x + jitter)) + 0.5
    test_x = jnp.linspace(-1.05, 1.05, 50).reshape(-1, 1)
    posterior = Prior(kernel=RBF()) * Bernoulli()
    hyperparams = initialise(posterior, train_x.shape[0])
    mu = mean(posterior, hyperparams, test_x, train_x, train_y)
    assert mu.shape == (test_x.shape[0],)
def test_conjugate_mean():
    """The conjugate mean function returns ``(n_test, output_dim)`` values."""
    rng = jr.PRNGKey(123)
    train_x = jr.uniform(rng, shape=(20, 1), minval=-3.0, maxval=3.0)
    train_y = jnp.sin(train_x)
    dataset = Dataset(X=train_x, y=train_y)
    posterior = Prior(kernel=RBF()) * Gaussian()
    hyperparams = initialise(posterior)
    test_x = jnp.linspace(-3.0, 3.0, 30).reshape(-1, 1)
    mean_fn = mean(posterior, hyperparams, dataset)
    predicted = mean_fn(test_x)
    assert predicted.shape == (test_x.shape[0], train_y.shape[1])
def test_non_conjugate_variance():
    """The Bernoulli-posterior variance function returns one value per test point."""
    rng = jr.PRNGKey(123)
    train_x = jnp.sort(jr.uniform(rng, shape=(10, 1), minval=-1.0, maxval=1.0), axis=0)
    jitter = jr.normal(rng, shape=train_x.shape) * 0.05
    # Map sign(cos(.)) from {-1, 1} onto {0, 1} binary labels.
    train_y = 0.5 * jnp.sign(jnp.cos(3 * train_x + jitter)) + 0.5
    dataset = Dataset(X=train_x, y=train_y)
    test_x = jnp.linspace(-1.05, 1.05, 50).reshape(-1, 1)
    posterior = Prior(kernel=RBF()) * Bernoulli()
    hyperparams = initialise(posterior, train_x.shape[0])
    var_fn = variance(posterior, hyperparams, dataset)
    assert var_fn(test_x).shape == (test_x.shape[0],)
def test_pos_def(dim, ell, sigma):
    """A jittered RBF Gram matrix is positive definite.

    ``n`` evenly spaced points on [0, 1] form a single input column; for
    ``dim > 1`` that column is tiled ``dim`` times into an ``(n, dim)`` input.
    The Gram matrix plus ``1e-6 * I`` must have a strictly positive smallest
    eigenvalue.
    """
    n = 30
    x = jnp.linspace(0.0, 1.0, num=n).reshape(-1, 1)
    if dim > 1:
        # Tile the column into `dim` columns. `(x) * dim` is not a tuple: it
        # would scale the values and never create extra input dimensions.
        x = jnp.hstack([x] * dim)
    kern = RBF()
    params = {"lengthscale": jnp.array([ell]), "variance": jnp.array(sigma)}
    gram_matrix = gram(kern, x, params)
    jitter_matrix = I(n) * 1e-6
    gram_matrix += jitter_matrix
    # The Gram matrix is symmetric, so eigvalsh returns real eigenvalues;
    # eigvals would return a complex array, which has no meaningful ordering.
    min_eig = jnp.linalg.eigvalsh(gram_matrix).min()
    assert min_eig > 0
def test_non_conjugate():
    """The negative MLL of a Bernoulli posterior equals the negated MLL."""
    posterior = Prior(kernel=RBF()) * Bernoulli()
    n = 20
    x = jnp.linspace(-1.0, 1.0, n).reshape(-1, 1)
    y = jnp.sin(x)
    # Sized to the number of data points, as in the other non-conjugate tests.
    params = initialise(posterior, n)
    config = get_defaults()
    # build_all_transforms returns (constrainer, unconstrainer) in that order.
    constrainer, unconstrainer = build_all_transforms(params.keys(), config)
    params = unconstrainer(params)
    mll = marginal_ll(posterior, transform=constrainer)
    assert isinstance(mll, Callable)
    neg_mll = marginal_ll(posterior, transform=constrainer, negative=True)
    assert neg_mll(params, x, y) == jnp.array(-1.0) * mll(params, x, y)
def test_conjugate():
    """The negative MLL of a conjugate posterior on a Dataset equals the negated MLL."""
    posterior = Prior(kernel=RBF()) * Gaussian()
    x = jnp.linspace(-1.0, 1.0, 20).reshape(-1, 1)
    y = jnp.sin(x)
    D = Dataset(X=x, y=y)
    params = initialise(posterior)
    config = get_defaults()
    # build_all_transforms returns (constrainer, unconstrainer) in that order.
    constrainer, unconstrainer = build_all_transforms(params.keys(), config)
    params = unconstrainer(params)
    mll = marginal_ll(posterior, transform=constrainer)
    assert isinstance(mll, Callable)
    neg_mll = marginal_ll(posterior, transform=constrainer, negative=True)
    assert neg_mll(params, D) == jnp.array(-1.0) * mll(params, D)
def test_non_conjugate_rv(n):
    """A Bernoulli posterior RV is callable and returns a ProbitBernoulli."""
    rng = jr.PRNGKey(123)
    posterior = Prior(kernel=RBF()) * Bernoulli()
    train_x = jnp.sort(jr.uniform(rng, shape=(n, 1), minval=-1.0, maxval=1.0), axis=0)
    jitter = jr.normal(rng, shape=train_x.shape) * 0.05
    # Map sign(cos(.)) from {-1, 1} onto {0, 1} binary labels.
    train_y = 0.5 * jnp.sign(jnp.cos(3 * train_x + jitter)) + 0.5
    dataset = Dataset(X=train_x, y=train_y)
    test_x = jnp.linspace(-1.0, 1.0, num=n).reshape(-1, 1)
    hyperparams = {"lengthscale": jnp.array([1.0]), "variance": jnp.array([1.0])}
    params = complete(hyperparams, posterior, train_x.shape[0])
    predictive = random_variable(posterior, params, dataset)
    assert isinstance(predictive, Callable)
    assert isinstance(predictive(test_x), tfd.ProbitBernoulli)
def test_spectral_sample():
    """A spectral posterior with basis functions passed as static params yields an MVN."""
    rng = jr.PRNGKey(123)
    n_basis = 10
    train_x = jnp.linspace(-1.0, 1.0, 20).reshape(-1, 1)
    dataset = Dataset(X=train_x, y=jnp.sin(train_x))
    test_x = jnp.linspace(-1.0, 1.0, num=50).reshape(-1, 1)
    spectral_posterior = Prior(kernel=to_spectral(RBF(), n_basis)) * Gaussian()
    params = initialise(rng, spectral_posterior)
    # Move the basis functions out of the trainable params into static params.
    static_params = {"basis_fns": params.pop("basis_fns")}
    predictive = random_variable(spectral_posterior, params, dataset, static_params=static_params)
    posterior_rv = predictive(test_x)
    assert isinstance(posterior_rv, tfd.Distribution)
    assert isinstance(posterior_rv, tfd.MultivariateNormalFullCovariance)
def test_conjugate():
    """The spectral-posterior negative MLL equals the negated MLL and is a scalar."""
    key = jr.PRNGKey(123)
    kern = to_spectral(RBF(), 10)
    posterior = Prior(kernel=kern) * Gaussian()
    x = jnp.linspace(-1.0, 1.0, 20).reshape(-1, 1)
    y = jnp.sin(x)
    params = initialise(key, posterior)
    config = get_defaults()
    # build_all_transforms returns (constrainer, unconstrainer) in that order.
    constrainer, unconstrainer = build_all_transforms(params.keys(), config)
    params = unconstrainer(params)
    mll = marginal_ll(posterior, transform=constrainer)
    assert isinstance(mll, Callable)
    neg_mll = marginal_ll(posterior, transform=constrainer, negative=True)
    assert neg_mll(params, x, y) == jnp.array(-1.0) * mll(params, x, y)
    nmll = neg_mll(params, x, y)
    assert nmll.shape == ()
def test_build_all_transforms(likelihood):
    """``build_all_transforms`` agrees with the standalone transform builders.

    The first returned transform is checked against ``build_constrain`` and the
    second against ``build_unconstrain``. Both are applied to the same
    initialised parameters and compared by key set, value and dtype.
    """
    posterior = Prior(kernel=RBF()) * likelihood()
    params = initialise(posterior, 10)
    config = get_defaults()
    t1, t2 = build_all_transforms(params.keys(), config)

    constrainer = build_constrain(params.keys(), config)
    constrained = t1(params)
    constrained2 = constrainer(params)
    # Key sets must match, otherwise the pairwise zip below could silently
    # compare unrelated entries.
    assert constrained.keys() == constrained2.keys()
    for u, v in zip(constrained.values(), constrained2.values()):
        assert_array_equal(u, v)
        assert u.dtype == v.dtype

    unconstrainer = build_unconstrain(params.keys(), config)
    unconstrained = t2(params)
    unconstrained2 = unconstrainer(params)
    assert unconstrained.keys() == unconstrained2.keys()
    for u, v in zip(unconstrained.values(), unconstrained2.values()):
        assert_array_equal(u, v)
        assert u.dtype == v.dtype
def mogpe_checkpoint_to_numpy(config_file, ckpt_dir, data_file, expert_num=0):
    """Extract one mogpe gating function's sparse-GP parameters as NumPy values.

    Args:
        config_file: Model config passed to ``load_model_from_config_and_checkpoint``.
        ckpt_dir: Checkpoint directory passed to the same loader.
        data_file: File loaded with ``np.load``; its ``"x"`` entry supplies the
            model inputs ``X``. Presumably an ``.npz`` archive — TODO confirm.
        expert_num: Index into ``model.gating_network.gating_function_list``.

    Returns:
        ``(kernel, inducing_variable, mean_function, q_mu, q_sqrt, whiten)``:
        an ``RBF`` built from the variance/lengthscales of the gating kernel's
        first sub-kernel, the inducing inputs ``Z``, a constant ``0.0`` mean,
        the variational parameters ``q_mu``/``q_sqrt`` and the ``whiten`` flag.
    """
    # load data set; `with` closes the underlying NpzFile once X has been read
    with np.load(data_file) as data:
        X = data["x"]

    # configure mogpe model from checkpoint
    model = load_model_from_config_and_checkpoint(config_file, ckpt_dir, X=X)

    # select the gating function to use
    gating_func = model.gating_network.gating_function_list[expert_num]
    mean_function = 0.0  # mogpe gating functions have zero mean function
    whiten = gating_func.whiten

    # sparse GP parameters
    q_mu = gating_func.q_mu.numpy()
    q_sqrt = gating_func.q_sqrt.numpy()
    inducing_variable = gating_func.inducing_variable.inducing_variable.Z.numpy()

    # kernel parameters, read from the first sub-kernel
    variance = gating_func.kernel.kernels[0].variance.numpy()
    lengthscales = gating_func.kernel.kernels[0].lengthscales.numpy()
    kernel = RBF(variance=variance, lengthscales=lengthscales)
    return kernel, inducing_variable, mean_function, q_mu, q_sqrt, whiten
def test_spectral():
    """Multiplying a spectral prior by a Gaussian likelihood gives a SpectralPosterior."""
    spectral_kernel = to_spectral(RBF(), 10)
    spectral_posterior = Prior(kernel=spectral_kernel) * Gaussian()
    assert isinstance(spectral_posterior, SpectralPosterior)
def test_output(transformation, likelihood):
    """Each transformation builder returns a callable parameter map."""
    posterior = Prior(kernel=RBF()) * likelihood()
    param_keys = initialise(posterior, 10).keys()
    mapping = transformation(param_keys, get_defaults())
    assert isinstance(mapping, Callable)
def test_checks():
    """prior_checks adds a latent prior but does not invent a variance prior."""
    partial_priors = {"lengthscale": jnp.array([1.0])}
    posterior = Prior(kernel=RBF()) * Bernoulli()
    checked_keys = prior_checks(posterior, partial_priors).keys()
    assert "latent" in checked_keys
    assert "variance" not in checked_keys
def test_conjugate_posterior():
    """Prior times Gaussian likelihood yields a ConjugatePosterior."""
    likelihood = Gaussian()
    combined = Prior(kernel=RBF()) * likelihood
    assert isinstance(combined, ConjugatePosterior)
def test_non_conjugate_poster(likelihood):
    """Prior times a non-Gaussian likelihood yields a NonConjugatePosterior."""
    prior = Prior(kernel=RBF())
    assert isinstance(prior * likelihood(), NonConjugatePosterior)
def test_to_spectral(n_basis):
    """to_spectral turns an RBF into a stationary SpectralRBF with n_basis bases."""
    converted = to_spectral(RBF(), n_basis)
    assert isinstance(converted, SpectralRBF)
    assert converted.num_basis == n_basis
    assert converted.stationary