def test_add_parameter():
    """Adding a parameter registers both its name and a `custom_<name>` bijector entry."""
    base = get_defaults()
    updated = add_parameter(base, ("test", tfb.Identity()))
    transforms = updated.transformations
    # The parameter name and its generated bijector key must both be present.
    for name in ("test", "custom_test"):
        assert name in transforms
    # The parameter name points at the generated key, which holds a bijector.
    assert transforms["test"] == "custom_test"
    assert isinstance(transforms["custom_test"], tfb.Bijector)
def test_prior_mll():
    """Test that the MLL evaluation works with priors attached to the parameter values.

    Evaluates the marginal log-likelihood once without and once with Gamma
    priors on the kernel/likelihood hyperparameters, and compares both against
    fixed reference values.
    """
    key = jr.PRNGKey(123)
    x = jnp.sort(jr.uniform(key, minval=-5.0, maxval=5.0, shape=(100, 1)), axis=0)

    def f(x):
        return jnp.sin(jnp.pi * x) / (jnp.pi * x)

    # NOTE: the same key is reused for the noise draw. The reference values
    # asserted below were computed from exactly this data, so splitting the
    # key would require regenerating them.
    y = f(x) + jr.normal(key, shape=x.shape) * 0.1

    posterior = Prior(kernel=RBF()) * Gaussian()
    params = initialise(posterior)
    config = get_defaults()
    constrainer, unconstrainer = build_all_transforms(params.keys(), config)
    # Evaluate the MLL at unconstrained parameters; `marginal_ll` is handed the
    # constraining transform to map them back.
    params = unconstrainer(params)
    mll = marginal_ll(posterior, transform=constrainer)

    priors = {
        "lengthscale": tfd.Gamma(1.0, 1.0),
        "variance": tfd.Gamma(2.0, 2.0),
        "obs_noise": tfd.Gamma(2.0, 2.0),
    }
    mll_eval = mll(params, x, y)
    mll_eval_priors = mll(params, x, y, priors)

    assert pytest.approx(mll_eval) == jnp.array(-103.28180663)
    assert pytest.approx(mll_eval_priors) == jnp.array(-105.509218857)
def test_constrain(likelihood):
    """`build_constrain` keeps every parameter key and each value's dtype."""
    gp_posterior = Prior(kernel=RBF()) * likelihood()
    initial = initialise(gp_posterior, 10)
    defaults = get_defaults()
    constrain = build_constrain(initial.keys(), defaults)
    mapped = constrain(initial)

    assert mapped.keys() == initial.keys()
    # Values are paired positionally, in each dict's iteration order.
    for after, before in zip(mapped.values(), initial.values()):
        assert after.dtype == before.dtype
def test_non_conjugate():
    """The negative MLL of a Bernoulli (non-conjugate) posterior is the negated MLL."""
    posterior = Prior(kernel=RBF()) * Bernoulli()
    n = 20
    x = jnp.linspace(-1.0, 1.0, n).reshape(-1, 1)
    y = jnp.sin(x)
    # Use `n` so the number of initialised latent values tracks the data size.
    params = initialise(posterior, n)
    config = get_defaults()
    # `build_all_transforms` returns (constrainer, unconstrainer), in that order.
    constrainer, unconstrainer = build_all_transforms(params.keys(), config)
    params = unconstrainer(params)

    mll = marginal_ll(posterior, transform=constrainer)
    assert isinstance(mll, Callable)
    neg_mll = marginal_ll(posterior, transform=constrainer, negative=True)
    assert neg_mll(params, x, y) == jnp.array(-1.0) * mll(params, x, y)
def test_conjugate():
    """The negative MLL of a Gaussian (conjugate) posterior is the negated MLL."""
    posterior = Prior(kernel=RBF()) * Gaussian()
    x = jnp.linspace(-1.0, 1.0, 20).reshape(-1, 1)
    y = jnp.sin(x)
    D = Dataset(X=x, y=y)
    params = initialise(posterior)
    config = get_defaults()
    # `build_all_transforms` returns (constrainer, unconstrainer), in that order.
    constrainer, unconstrainer = build_all_transforms(params.keys(), config)
    params = unconstrainer(params)

    mll = marginal_ll(posterior, transform=constrainer)
    assert isinstance(mll, Callable)
    neg_mll = marginal_ll(posterior, transform=constrainer, negative=True)
    assert neg_mll(params, D) == jnp.array(-1.0) * mll(params, D)
def test_conjugate():
    """With a spectral RBF kernel, the negative MLL equals the negated MLL and is a scalar."""
    # NOTE(review): if this lives in the same module as another
    # `test_conjugate`, this definition shadows it and pytest only collects
    # this one — consider renaming (e.g. `test_spectral_conjugate`).
    key = jr.PRNGKey(123)
    kern = to_spectral(RBF(), 10)
    posterior = Prior(kernel=kern) * Gaussian()
    x = jnp.linspace(-1.0, 1.0, 20).reshape(-1, 1)
    y = jnp.sin(x)
    params = initialise(key, posterior)
    config = get_defaults()
    # `build_all_transforms` returns (constrainer, unconstrainer), in that order.
    constrainer, unconstrainer = build_all_transforms(params.keys(), config)
    params = unconstrainer(params)

    mll = marginal_ll(posterior, transform=constrainer)
    assert isinstance(mll, Callable)
    neg_mll = marginal_ll(posterior, transform=constrainer, negative=True)
    assert neg_mll(params, x, y) == jnp.array(-1.0) * mll(params, x, y)

    # The objective must reduce to a 0-d value.
    nmll = neg_mll(params, x, y)
    assert nmll.shape == ()
def test_build_all_transforms(likelihood):
    """`build_all_transforms` agrees with `build_constrain` and `build_unconstrain`.

    The first returned transform must match `build_constrain` and the second
    must match `build_unconstrain`: same keys, equal values, same dtypes.
    """
    posterior = Prior(kernel=RBF()) * likelihood()
    params = initialise(posterior, 10)
    config = get_defaults()
    t1, t2 = build_all_transforms(params.keys(), config)

    # Constraining direction.
    constrainer = build_constrain(params.keys(), config)
    constrained = t1(params)
    constrained2 = constrainer(params)
    # Compare the two different results (not one result with itself).
    assert constrained.keys() == constrained2.keys()
    # Pair values by key so that differing dict orders cannot mis-pair them.
    for k in constrained:
        assert_array_equal(constrained[k], constrained2[k])
        assert constrained[k].dtype == constrained2[k].dtype

    # Unconstraining direction.
    unconstrained = t2(params)
    unconstrainer = build_unconstrain(params.keys(), config)
    unconstrained2 = unconstrainer(params)
    assert unconstrained.keys() == unconstrained2.keys()
    for k in unconstrained:
        assert_array_equal(unconstrained[k], unconstrained2[k])
        assert unconstrained[k].dtype == unconstrained2[k].dtype
def test_output(transformation, likelihood):
    """A transform builder applied to the parameter keys yields a callable."""
    gp = Prior(kernel=RBF()) * likelihood()
    initial = initialise(gp, 10)
    defaults = get_defaults()
    built = transformation(initial.keys(), defaults)
    assert isinstance(built, Callable)
def test_get_defaults():
    """The default config and its `transformations` section are both ConfigDicts."""
    defaults = get_defaults()
    assert isinstance(defaults, ConfigDict)
    transforms = defaults.transformations
    assert isinstance(transforms, ConfigDict)