Beispiel #1
0
def test_prior_mll():
    """
    Test that the MLL evaluation works with priors attached to the parameter values.
    """
    key = jr.PRNGKey(123)
    # Sorted 1-D inputs in [-5, 5]; noisy sinc-shaped targets.
    x = jnp.sort(jr.uniform(key, minval=-5.0, maxval=5.0, shape=(100, 1)),
                 axis=0)
    f = lambda x: jnp.sin(jnp.pi * x) / (jnp.pi * x)
    y = f(x) + jr.normal(key, shape=x.shape) * 0.1
    posterior = Prior(kernel=RBF()) * Gaussian()

    params = initialise(posterior)
    config = get_defaults()
    constrainer, unconstrainer = build_all_transforms(params.keys(), config)
    # Evaluate on the unconstrained space; the MLL applies the constrainer
    # transform internally before computing the likelihood.
    params = unconstrainer(params)

    mll = marginal_ll(posterior, transform=constrainer)

    # Gamma priors attached to each hyperparameter.
    priors = {
        "lengthscale": tfd.Gamma(1.0, 1.0),
        "variance": tfd.Gamma(2.0, 2.0),
        "obs_noise": tfd.Gamma(2.0, 2.0),
    }
    mll_eval = mll(params, x, y)
    mll_eval_priors = mll(params, x, y, priors)

    # Regression values: adding priors shifts the objective by the log-prior density.
    assert pytest.approx(mll_eval) == jnp.array(-103.28180663)
    assert pytest.approx(mll_eval_priors) == jnp.array(-105.509218857)
Beispiel #2
0
 def init_svgp_sample():
     """Construct an SVGPSample model for testing.

     NOTE(review): relies on free variables (output_dim, input_dim, key,
     num_inducing, q_diag, whiten) presumably bound by an enclosing
     fixture or closure -- confirm against the surrounding file.
     """
     mean_function = Constant(output_dim=output_dim)
     likelihood = Gaussian()
     if output_dim > 1:
         # One squared-exponential kernel per output dimension.
         kernels = [
             SquaredExponential(
                 lengthscales=jnp.ones(input_dim, dtype=jnp.float64),
                 variance=2.0,
             ) for _ in range(output_dim)
         ]
         kernel = SeparateIndependent(kernels)
     else:
         kernel = SquaredExponential(lengthscales=jnp.ones(
             input_dim, dtype=jnp.float64),
                                     variance=2.0)
     # Random inducing locations uniformly placed in the unit hypercube.
     inducing_variable = jax.random.uniform(key=key,
                                            shape=(num_inducing,
                                                   input_dim))
     return SVGPSample(
         kernel,
         likelihood,
         inducing_variable,
         mean_function,
         num_latent_gps=output_dim,
         q_diag=q_diag,
         whiten=whiten,
     )
def _get_conjugate_posterior_params() -> tuple:
    """Build a conjugate GP posterior and its initial parameters.

    Returns:
        A ``(params, posterior)`` tuple: the initialised parameter dict and
        the posterior formed by an RBF prior times a Gaussian likelihood.

    Note: the original annotation said ``-> dict`` but the function returns
    a two-tuple; the annotation has been corrected.
    """
    kernel = RBF()
    prior = Prior(kernel=kernel)
    lik = Gaussian()
    posterior = prior * lik
    params = initialise(posterior)
    return params, posterior
Beispiel #4
0
def test_posterior_random_variable(n):
    """The posterior evaluated at query points is a full-covariance MVN."""
    posterior = Prior(kernel=RBF()) * Gaussian()
    inputs = jnp.linspace(-1.0, 1.0, 10).reshape(-1, 1)
    targets = jnp.sin(inputs)
    query = jnp.linspace(-1.0, 1.0, num=n).reshape(-1, 1)
    hyperparams = initialise(posterior)
    rv = random_variable(posterior, hyperparams, query, inputs, targets)
    assert isinstance(rv, tfd.MultivariateNormalFullCovariance)
Beispiel #5
0
def test_spectral():
    """Spectral initialisation exposes basis functions alongside the usual hyperparameters."""
    rng = jr.PRNGKey(123)
    spectral_kernel = to_spectral(RBF(), 10)
    post = Prior(kernel=spectral_kernel) * Gaussian()
    hyperparams = initialise(rng, post)
    expected_keys = sorted(["basis_fns", "obs_noise", "lengthscale", "variance"])
    assert list(hyperparams.keys()) == expected_keys
    # 10 basis frequencies for a 1-D input.
    assert hyperparams["basis_fns"].shape == (10, 1)
Beispiel #6
0
def test_posterior_sample(n, n_sample):
    """Posterior draws have shape (n_sample, number of query points)."""
    rng = jr.PRNGKey(123)
    posterior = Prior(kernel=RBF()) * Gaussian()
    inputs = jnp.linspace(-1.0, 1.0, 10).reshape(-1, 1)
    targets = jnp.sin(inputs)
    query = jnp.linspace(-1.0, 1.0, num=n).reshape(-1, 1)
    hyperparams = initialise(posterior)
    dist = random_variable(posterior, hyperparams, query, inputs, targets)
    draws = sample(rng, dist, n_samples=n_sample)
    assert draws.shape == (n_sample, query.shape[0])
Beispiel #7
0
def test_conjugate_variance():
    """Predictive covariance is a square matrix over the test points."""
    rng = jr.PRNGKey(123)
    train_x = jr.uniform(rng, shape=(20, 1), minval=-3.0, maxval=3.0)
    train_y = jnp.sin(train_x)

    post = Prior(kernel=RBF()) * Gaussian()
    hyperparams = initialise(post)

    test_x = jnp.linspace(-3.0, 3.0, 30).reshape(-1, 1)
    cov = variance(post, hyperparams, test_x, train_x, train_y)
    n_test = test_x.shape[0]
    assert cov.shape == (n_test, n_test)
Beispiel #8
0
def test_posterior_random_variable(n):
    """random_variable over a Dataset returns a callable producing an MVN."""
    posterior = Prior(kernel=RBF()) * Gaussian()
    inputs = jnp.linspace(-1.0, 1.0, 10).reshape(-1, 1)
    targets = jnp.sin(inputs)
    data = Dataset(X=inputs, y=targets)
    query = jnp.linspace(-1.0, 1.0, num=n).reshape(-1, 1)
    hyperparams = initialise(posterior)
    rv = random_variable(posterior, hyperparams, data)
    # The Dataset form returns a predictive function, not a distribution.
    assert isinstance(rv, Callable)
    predictive = rv(query)
    assert isinstance(predictive, tfd.MultivariateNormalFullCovariance)
Beispiel #9
0
def test_conjugate():
    """The negative marginal log-likelihood equals the negated MLL."""
    post = Prior(kernel=RBF()) * Gaussian()

    train_x = jnp.linspace(-1.0, 1.0, 20).reshape(-1, 1)
    train_y = jnp.sin(train_x)
    hyperparams = initialise(post)
    config = get_defaults()
    unconstrainer, constrainer = build_all_transforms(hyperparams.keys(), config)
    hyperparams = unconstrainer(hyperparams)
    mll = marginal_ll(post, transform=constrainer)
    assert isinstance(mll, Callable)
    neg_mll = marginal_ll(post, transform=constrainer, negative=True)
    negated = jnp.array(-1.0) * mll(hyperparams, train_x, train_y)
    assert neg_mll(hyperparams, train_x, train_y) == negated
Beispiel #10
0
def test_conjugate_mean():
    """The posterior mean has shape (n_test, output_dim)."""
    rng = jr.PRNGKey(123)
    train_x = jr.uniform(rng, shape=(20, 1), minval=-3.0, maxval=3.0)
    train_y = jnp.sin(train_x)
    data = Dataset(X=train_x, y=train_y)

    post = Prior(kernel=RBF()) * Gaussian()
    hyperparams = initialise(post)

    test_x = jnp.linspace(-3.0, 3.0, 30).reshape(-1, 1)
    mean_fn = mean(post, hyperparams, data)
    mu = mean_fn(test_x)
    assert mu.shape == (test_x.shape[0], train_y.shape[1])
Beispiel #11
0
def test_spectral_sample():
    """A spectral posterior yields a full-covariance MVN at query points."""
    rng = jr.PRNGKey(123)
    num_basis = 10
    train_x = jnp.linspace(-1.0, 1.0, 20).reshape(-1, 1)
    train_y = jnp.sin(train_x)
    data = Dataset(X=train_x, y=train_y)
    query = jnp.linspace(-1.0, 1.0, num=50).reshape(-1, 1)
    post = Prior(kernel=to_spectral(RBF(), num_basis)) * Gaussian()
    hyperparams = initialise(rng, post)
    # Basis functions are held fixed: move them from the trainable
    # parameters into the static-parameter dict.
    static = {"basis_fns": hyperparams.pop("basis_fns")}
    posterior_rv = random_variable(post, hyperparams, data, static_params=static)(query)
    assert isinstance(posterior_rv, tfd.Distribution)
    assert isinstance(posterior_rv, tfd.MultivariateNormalFullCovariance)
Beispiel #12
0
def test_spectral():
    """Spectral MLL: the negative flag flips the sign and the output is scalar."""
    rng = jr.PRNGKey(123)
    spectral = to_spectral(RBF(), 10)
    post = Prior(kernel=spectral) * Gaussian()
    train_x = jnp.linspace(-1.0, 1.0, 20).reshape(-1, 1)
    train_y = jnp.sin(train_x)
    data = Dataset(X=train_x, y=train_y)
    hyperparams = initialise(rng, post)
    config = get_defaults()
    unconstrainer, constrainer = build_all_transforms(hyperparams.keys(), config)
    hyperparams = unconstrainer(hyperparams)
    mll = marginal_ll(post, transform=constrainer)
    assert isinstance(mll, Callable)
    neg_mll = marginal_ll(post, transform=constrainer, negative=True)
    assert neg_mll(hyperparams, data) == jnp.array(-1.0) * mll(hyperparams, data)
    nmll_value = neg_mll(hyperparams, data)
    assert nmll_value.shape == ()
Beispiel #13
0
    def test_predict_f(
        self,
        input_dim,
        output_dim,
        num_data,
        num_inducing,
        q_diag,
        whiten,
        full_cov,
        full_output_cov,
        num_inducing_samples,
    ):
        """Check shapes of SVGPSample.predict_f outputs.

        Exercises both the plain predictive (``num_inducing_samples is None``)
        and the sampled-inducing-point predictive, each with and without the
        full covariance. ``full_output_cov=True`` is not yet covered and
        raises NotImplementedError.
        """
        Xnew = jax.random.uniform(key, shape=(num_data, input_dim))

        mean_function = Constant(output_dim=output_dim)
        likelihood = Gaussian()
        if output_dim > 1:
            # One SE kernel per output, combined as independent latents.
            kernels = [
                SquaredExponential(
                    lengthscales=jnp.ones(input_dim, dtype=jnp.float64), variance=2.0
                )
                for _ in range(output_dim)
            ]
            kernel = SeparateIndependent(kernels)
        else:
            kernel = SquaredExponential(
                lengthscales=jnp.ones(input_dim, dtype=jnp.float64), variance=2.0
            )
        inducing_variable = jax.random.uniform(key=key, shape=(num_inducing, input_dim))

        svgp = SVGPSample(
            kernel,
            likelihood,
            inducing_variable,
            mean_function,
            num_latent_gps=output_dim,
            q_diag=q_diag,
            whiten=whiten,
        )

        params = svgp.get_params()

        def predict_f(params, Xnew):
            return svgp.predict_f(
                params, Xnew, key, num_inducing_samples, full_cov, full_output_cov
            )

        # self.variant wraps the function per the test's variant configuration
        # (presumably chex's jit/no-jit variants -- confirm against the class).
        var_predict_f = self.variant(predict_f)
        mean, cov = var_predict_f(params, Xnew)

        if num_inducing_samples is None:
            if not full_output_cov:
                # mean: [num_data, output_dim]
                assert mean.ndim == 2
                assert mean.shape[0] == num_data
                assert mean.shape[1] == output_dim
                if full_cov:
                    # cov: [output_dim, num_data, num_data]
                    assert cov.ndim == 3
                    assert cov.shape[0] == output_dim
                    assert cov.shape[1] == cov.shape[2] == num_data
                else:
                    # cov: [num_data, output_dim] (marginal variances)
                    assert cov.ndim == 2
                    assert cov.shape[0] == num_data
                    assert cov.shape[1] == output_dim
            else:
                raise NotImplementedError("Need to add tests for full_output_cov=True")
        else:
            if not full_output_cov:
                # mean: [num_inducing_samples, num_data, output_dim]
                assert mean.ndim == 3
                assert mean.shape[0] == num_inducing_samples
                assert mean.shape[1] == num_data
                assert mean.shape[2] == output_dim
                if full_cov:
                    # cov: [num_inducing_samples, output_dim, num_data, num_data]
                    assert cov.ndim == 4
                    assert cov.shape[0] == num_inducing_samples
                    assert cov.shape[1] == output_dim
                    assert cov.shape[2] == cov.shape[3] == num_data
                else:
                    # cov: [num_inducing_samples, num_data, output_dim]
                    assert cov.ndim == 3
                    assert cov.shape[0] == num_inducing_samples
                    assert cov.shape[1] == num_data
                    assert cov.shape[2] == output_dim
            else:
                raise NotImplementedError("Need to add tests for full_output_cov=True")
Beispiel #14
0
def test_conjugate_posterior():
    """An RBF prior times a Gaussian likelihood yields a ConjugatePosterior."""
    gp_prior = Prior(kernel=RBF())
    likelihood = Gaussian()
    posterior = gp_prior * likelihood
    assert isinstance(posterior, ConjugatePosterior)
Beispiel #15
0
def test_complete():
    """complete() fills in every hyperparameter missing from a partial dict."""
    post = Prior(kernel=RBF()) * Gaussian()
    partial = {"lengthscale": jnp.array(1.0)}
    completed = complete(partial, post)
    expected_keys = sorted(["lengthscale", "variance", "obs_noise"])
    assert list(completed.keys()) == expected_keys
Beispiel #16
0
def test_spectral():
    """A spectral kernel with a Gaussian likelihood yields a SpectralPosterior."""
    post = Prior(kernel=to_spectral(RBF(), 10)) * Gaussian()
    assert isinstance(post, SpectralPosterior)
Beispiel #17
0
def test_initialise():
    """Initialisation exposes exactly the expected hyperparameter keys."""
    post = Prior(kernel=RBF()) * Gaussian()
    hyperparams = initialise(post)
    expected_keys = sorted(["lengthscale", "variance", "obs_noise"])
    assert list(hyperparams.keys()) == expected_keys