示例#1
0
def test_check_needless():
    """Check ``prior_checks`` on a prior dict that already has every entry.

    A prior is supplied for each of ``lengthscale``, ``variance``,
    ``obs_noise`` and ``latent``; the test asserts that the value returned
    by ``prior_checks`` compares equal to the dictionary that was passed in.
    """
    gp_posterior = Prior(kernel=RBF()) * Bernoulli()
    full_priors = {
        "lengthscale": tfd.Gamma(1.0, 1.0),
        "variance": tfd.Gamma(2.0, 2.0),
        "obs_noise": tfd.Gamma(3.0, 3.0),
        "latent": tfd.Normal(loc=0.0, scale=1.0),
    }
    checked = prior_checks(gp_posterior, full_priors)
    assert checked == full_priors
示例#2
0
def test_predictive_moment(n):
    """Check the shapes of the predictive moments of a Bernoulli likelihood.

    Args:
        n: Number of latent function values to draw. Supplied by the caller
            (presumably a pytest parametrisation outside this view).

    The function returned by ``predictive_moments`` is applied to ``n``
    latent means and ``n`` latent variances, and the mean and variance of
    the resulting random variable must both have shape ``(n,)``.
    """
    likelihood = Bernoulli()
    # Split the seed so the latent means and variances come from independent
    # draws; feeding one key to both calls would yield the same samples twice.
    mean_key, var_key = jr.split(jr.PRNGKey(123))
    fmean = jr.uniform(key=mean_key, shape=(n,)) * -1
    fvar = jr.uniform(key=var_key, shape=(n,))
    pred_mom_fn = predictive_moments(likelihood)
    rv = pred_mom_fn(fmean, fvar)
    mu = rv.mean()
    sigma = rv.variance()
    assert mu.shape == (n,)
    assert sigma.shape == (n,)
示例#3
0
def test_non_conjugate():
    """Check the marginal log-likelihood of a non-conjugate posterior.

    Builds a ``Prior * Bernoulli`` posterior, initialises and unconstrains
    its parameters, and asserts that ``marginal_ll`` returns a callable and
    that the ``negative=True`` variant evaluates to exactly ``-1`` times the
    default variant on the same inputs.
    """
    n = 20
    gp_posterior = Prior(kernel=RBF()) * Bernoulli()
    x = jnp.linspace(-1.0, 1.0, n).reshape(-1, 1)
    y = jnp.sin(x)

    # One latent value is initialised per input point.
    constrained = initialise(gp_posterior, n)
    config = get_defaults()
    unconstrainer, constrainer = build_all_transforms(constrained.keys(), config)
    params = unconstrainer(constrained)

    mll = marginal_ll(gp_posterior, transform=constrainer)
    assert isinstance(mll, Callable)

    neg_mll = marginal_ll(gp_posterior, transform=constrainer, negative=True)
    negated_value = neg_mll(params, x, y)
    plain_value = mll(params, x, y)
    assert negated_value == jnp.array(-1.0) * plain_value
示例#4
0
def test_non_conjugate_variance():
    """Check the shape of the predictive variance of a non-conjugate posterior.

    Ten sorted inputs are drawn uniformly from ``[-1, 1]`` and labelled with
    ``0.5 * sign(cos(3x + noise)) + 0.5``. The function returned by
    ``variance`` is then evaluated at 50 test points and must produce one
    value per test point.
    """
    # Use separate keys for the inputs and the label noise so the noise is
    # not generated from the same random stream as the inputs.
    x_key, noise_key = jr.split(jr.PRNGKey(123))
    x = jnp.sort(jr.uniform(x_key, shape=(10, 1), minval=-1.0, maxval=1.0), axis=0)
    noise = jr.normal(noise_key, shape=x.shape) * 0.05
    y = 0.5 * jnp.sign(jnp.cos(3 * x + noise)) + 0.5
    D = Dataset(X=x, y=y)
    xtest = jnp.linspace(-1.05, 1.05, 50).reshape(-1, 1)

    posterior = Prior(kernel=RBF()) * Bernoulli()
    params = initialise(posterior, x.shape[0])

    varf = variance(posterior, params, D)
    sigma = varf(xtest)
    assert sigma.shape == (xtest.shape[0],)
示例#5
0
def test_non_conjugate_mean():
    """Check the shape of the predictive mean of a non-conjugate posterior.

    Ten sorted inputs are drawn uniformly from ``[-1, 1]`` and labelled with
    ``0.5 * sign(cos(3x + noise)) + 0.5``. ``mean`` is evaluated at 50 test
    points and must return one value per test point.
    """
    # Use separate keys for the inputs and the label noise so the noise is
    # not generated from the same random stream as the inputs.
    x_key, noise_key = jr.split(jr.PRNGKey(123))
    x = jnp.sort(jr.uniform(x_key, shape=(10, 1), minval=-1.0, maxval=1.0),
                 axis=0)
    noise = jr.normal(noise_key, shape=x.shape) * 0.05
    y = 0.5 * jnp.sign(jnp.cos(3 * x + noise)) + 0.5
    xtest = jnp.linspace(-1.05, 1.05, 50).reshape(-1, 1)

    posterior = Prior(kernel=RBF()) * Bernoulli()
    params = initialise(posterior, x.shape[0])

    mu = mean(posterior, params, xtest, x, y)
    assert mu.shape == (xtest.shape[0], )
示例#6
0
def test_non_conjugate_rv(n):
    """Check the random variable produced for a non-conjugate posterior.

    Args:
        n: Number of training inputs and of sample points. Supplied by the
            caller (presumably a pytest parametrisation outside this view).

    Asserts that ``random_variable`` returns a callable and that calling it
    on the sample points yields a ``tfd.ProbitBernoulli`` instance.
    """
    # Use separate keys for the inputs and the label noise so the noise is
    # not generated from the same random stream as the inputs.
    x_key, noise_key = jr.split(jr.PRNGKey(123))
    posterior = Prior(kernel=RBF()) * Bernoulli()
    x = jnp.sort(jr.uniform(x_key, shape=(n, 1), minval=-1.0, maxval=1.0), axis=0)
    noise = jr.normal(noise_key, shape=x.shape) * 0.05
    y = 0.5 * jnp.sign(jnp.cos(3 * x + noise)) + 0.5
    D = Dataset(X=x, y=y)

    sample_points = jnp.linspace(-1.0, 1.0, num=n).reshape(-1, 1)

    # Only the kernel hyperparameters are given; ``complete`` builds the
    # parameter set that is passed on to ``random_variable``.
    hyperparams = {"lengthscale": jnp.array([1.0]), "variance": jnp.array([1.0])}
    params = complete(hyperparams, posterior, x.shape[0])
    rv = random_variable(posterior, params, D)
    assert isinstance(rv, Callable)
    fstar = rv(sample_points)
    assert isinstance(fstar, tfd.ProbitBernoulli)
示例#7
0
def test_checks():
    """Check ``prior_checks`` on a prior dict that only has ``lengthscale``.

    The test asserts that the returned mapping contains a ``"latent"`` key
    and does not contain a ``"variance"`` key.
    """
    gp_posterior = Prior(kernel=RBF()) * Bernoulli()
    partial_priors = {"lengthscale": jnp.array([1.0])}
    checked = prior_checks(gp_posterior, partial_priors)
    prior_names = checked.keys()
    assert "latent" in prior_names
    assert "variance" not in prior_names
示例#8
0
def test_non_conjugate_initialise(n):
    """Check the parameters that ``initialise`` builds for a non-conjugate posterior.

    Args:
        n: Number of data points passed to ``initialise``. Supplied by the
            caller (presumably a pytest parametrisation outside this view).

    The parameter keys must be exactly ``latent``, ``lengthscale`` and
    ``variance`` in alphabetical order, and the ``latent`` entry must have
    shape ``(n, 1)``.
    """
    gp_posterior = Prior(kernel=RBF()) * Bernoulli()
    params = initialise(gp_posterior, n)
    # Written out already sorted; the key order is part of the assertion.
    expected_names = ["latent", "lengthscale", "variance"]
    assert [*params.keys()] == expected_names
    assert params["latent"].shape == (n, 1)