Example #1
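All five examples exercise GPJax's marginal log-likelihood machinery. A plausible set of shared imports is sketched below; the exact module paths are assumptions and shifted between early GPJax releases (e.g. optimizers lived in jax.experimental.optimizers in older JAX versions).

# Assumed shared imports for the examples below; module paths are guesses
# and varied across early GPJax releases.
from typing import Callable

import jax.numpy as jnp
import jax.random as jr
import pytest
import tensorflow_probability.substrates.jax.distributions as tfd
from jax import jit, value_and_grad
from jax.example_libraries import optimizers

from gpjax import Dataset, Prior, initialise
from gpjax.config import get_defaults
from gpjax.kernels import RBF, to_spectral
from gpjax.likelihoods import Bernoulli, Gaussian
from gpjax.objectives import marginal_ll
from gpjax.parameters import build_all_transforms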
def test_non_conjugate():
    # Non-conjugate model: GP prior combined with a Bernoulli likelihood.
    posterior = Prior(kernel=RBF()) * Bernoulli()
    n = 20
    x = jnp.linspace(-1.0, 1.0, n).reshape(-1, 1)
    y = jnp.sin(x)
    # Non-conjugate posteriors take the dataset size so that the latent
    # function values can be initialised.
    params = initialise(posterior, n)
    config = get_defaults()
    # Build bijections between the constrained and unconstrained parameter spaces.
    unconstrainer, constrainer = build_all_transforms(params.keys(), config)
    params = unconstrainer(params)
    mll = marginal_ll(posterior, transform=constrainer)
    assert isinstance(mll, Callable)
    # The negative flag should simply flip the objective's sign.
    neg_mll = marginal_ll(posterior, transform=constrainer, negative=True)
    assert neg_mll(params, x, y) == jnp.array(-1.0) * mll(params, x, y)
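A Bernoulli likelihood models binary outcomes, so a classification-style variant of the target above would threshold the signal. A minimal sketch (the 0/1 label encoding is an assumption; some implementations expect ±1 labels instead):

y_binary = jnp.where(jnp.sin(x) > 0.0, 1.0, 0.0)  # hypothetical binary targets in {0, 1}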
Example #2
def test_conjugate():
    # Conjugate model: GP prior combined with a Gaussian likelihood.
    posterior = Prior(kernel=RBF()) * Gaussian()

    x = jnp.linspace(-1.0, 1.0, 20).reshape(-1, 1)
    y = jnp.sin(x)
    D = Dataset(X=x, y=y)
    params = initialise(posterior)
    config = get_defaults()
    unconstrainer, constrainer = build_all_transforms(params.keys(), config)
    params = unconstrainer(params)
    mll = marginal_ll(posterior, transform=constrainer)
    assert isinstance(mll, Callable)
    # The negative objective should be the exact sign-flip of the positive one.
    neg_mll = marginal_ll(posterior, transform=constrainer, negative=True)
    assert neg_mll(params, D) == jnp.array(-1.0) * mll(params, D)
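Because the returned objective is a pure function of (params, D), it can be jit-compiled for repeated evaluation, exactly as Example #5 below does during optimisation. A sketch, assuming Dataset is a JAX-compatible pytree:

from jax import jit

fast_mll = jit(mll)        # compiled on first call, then cheap to re-evaluate
val = fast_mll(params, D)  # agrees with mll(params, D) up to floating-point tolerance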
Example #3
def test_conjugate():
    key = jr.PRNGKey(123)
    # Approximate the RBF kernel with a 10-frequency spectral representation.
    kern = to_spectral(RBF(), 10)
    posterior = Prior(kernel=kern) * Gaussian()
    x = jnp.linspace(-1.0, 1.0, 20).reshape(-1, 1)
    y = jnp.sin(x)
    # Spectral kernels carry randomly sampled frequencies, hence the PRNG key.
    params = initialise(key, posterior)
    config = get_defaults()
    unconstrainer, constrainer = build_all_transforms(params.keys(), config)
    params = unconstrainer(params)
    mll = marginal_ll(posterior, transform=constrainer)
    assert isinstance(mll, Callable)
    neg_mll = marginal_ll(posterior, transform=constrainer, negative=True)
    assert neg_mll(params, x, y) == jnp.array(-1.0) * mll(params, x, y)
    # The objective must evaluate to a scalar so it can be differentiated.
    nmll = neg_mll(params, x, y)
    assert nmll.shape == ()
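The scalar-shape assertion matters because jax.grad only differentiates scalar-valued functions. A sketch of taking gradients of the objective with respect to the parameter dictionary:

from jax import grad

g = grad(neg_mll)(params, x, y)  # pytree of gradients with the same structure as params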
Example #4
def test_prior_mll():
    """
    Test that the MLL evaluation works with priors attached to the parameter values.
    """
    key = jr.PRNGKey(123)
    x = jnp.sort(jr.uniform(key, minval=-5.0, maxval=5.0, shape=(100, 1)),
                 axis=0)
    f = lambda x: jnp.sin(jnp.pi * x) / (jnp.pi * x)  # noisy sinc target
    y = f(x) + jr.normal(key, shape=x.shape) * 0.1
    posterior = Prior(kernel=RBF()) * Gaussian()

    params = initialise(posterior)
    config = get_defaults()
    constrainer, unconstrainer = build_all_transforms(params.keys(), config)
    params = unconstrainer(params)

    mll = marginal_ll(posterior, transform=constrainer)

    # Gamma hyperpriors over the kernel and likelihood parameters.
    priors = {
        "lengthscale": tfd.Gamma(1.0, 1.0),
        "variance": tfd.Gamma(2.0, 2.0),
        "obs_noise": tfd.Gamma(2.0, 2.0),
    }
    mll_eval = mll(params, x, y)
    mll_eval_priors = mll(params, x, y, priors)

    # Adding log-prior mass shifts the objective away from the plain MLL value.
    assert pytest.approx(mll_eval) == jnp.array(-103.28180663)
    assert pytest.approx(mll_eval_priors) == jnp.array(-105.509218857)
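The prior-aware evaluation adds each parameter's log prior density to the marginal log-likelihood, which is why mll_eval_priors differs from mll_eval. A minimal sketch of that penalty (a hypothetical helper, not GPJax's internal implementation; it presumably acts on the constrained parameter values):

def log_prior_density(params, priors):
    # Sum each parameter's log density under its prior.
    return sum(prior.log_prob(params[name]).sum() for name, prior in priors.items())

# mll(params, x, y, priors) then behaves like
# mll(params, x, y) + log_prior_density(constrained_params, priors).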
Example #5
def fit(posterior, nits, data, configs):
    params = initialise(posterior)
    constrainer, unconstrainer = build_all_transforms(params.keys(), configs)
    # Optimise in the unconstrained space, as in the examples above.
    params = unconstrainer(params)

    # Jit-compile the negative marginal log-likelihood for fast repeated evaluation.
    mll = jit(marginal_ll(posterior, transform=constrainer, negative=True))

    opt_init, opt_update, get_params = optimizers.adam(step_size=0.05)
    opt_state = opt_init(params)

    def step(i, opt_state):
        p = get_params(opt_state)
        v, g = value_and_grad(mll)(p, data)
        return opt_update(i, g, opt_state), v

    for i in range(nits):
        opt_state, neg_mll_estimate = step(i, opt_state)
    print(f"{posterior.prior.kernel.name} GP's negative marginal log-likelihood: {neg_mll_estimate: .2f}")

    # Map the optimised parameters back to their constrained space.
    final_params = constrainer(get_params(opt_state))
    return final_params
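A hypothetical call to fit, reusing the conjugate setup from Example #2 (the Dataset-style data argument matches the mll(p, data) call inside step):

posterior = Prior(kernel=RBF()) * Gaussian()
x = jnp.linspace(-1.0, 1.0, 20).reshape(-1, 1)
D = Dataset(X=x, y=jnp.sin(x))
learned_params = fit(posterior, nits=500, data=D, configs=get_defaults())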