# Example no. 1
# 0
    def mll(ds: Dataset):
        """Numpyro model of the GP marginal likelihood for the data in ``ds``.

        Every entry of the enclosing ``numpyro_params`` mapping is registered
        with numpyro. Entries whose ``"param_type"`` is ``"prior"`` become
        sample sites drawn from their ``"prior"`` distribution. All other
        entries become ``numpyro.param`` sites built from ``"init_value"`` and
        ``"constraint"``. The resulting values are passed to the kernel's Gram
        computation, and the targets are conditioned on as the observed
        site ``"y"``.

        Args:
            ds: dataset exposing inputs ``ds.X`` and targets ``ds.y``.

        Returns:
            The value of the observed ``"y"`` sample site.

        Raises:
            KeyError: if ``numpyro_params`` has no ``"obs_noise"`` entry.
        """
        inputs, targets = ds.X, ds.y

        def _site(site_name, spec):
            # Register one hyperparameter as either a prior sample site or a
            # learnable (constrained) parameter site.
            if spec["param_type"] != "prior":
                return numpyro.param(
                    name=site_name,
                    init_value=spec["init_value"],
                    constraint=spec["constraint"],
                )
            return numpyro.sample(name=site_name, fn=spec["prior"])

        # A dict comprehension preserves iteration order, so the sites are
        # registered in the same order as they appear in ``numpyro_params``.
        hyperparams = {
            site_name: _site(site_name, spec)
            for site_name, spec in numpyro_params.items()
        }

        # NOTE(review): MultivariateNormal expects a 1-D ``loc`` of length n;
        # confirm mean_function returns shape (n,) rather than (n, 1).
        mean_vec = gp.prior.mean_function(inputs)

        # Kernel Gram matrix with ``obs_noise`` times the identity added.
        cov = gram(gp.prior.kernel, inputs, hyperparams)
        cov += hyperparams["obs_noise"] * I(inputs.shape[0])

        # Lower Cholesky factor parameterises the Gaussian likelihood.
        chol = cholesky(cov, lower=True)
        likelihood = dist.MultivariateNormal(loc=mean_vec, scale_tril=chol)
        return numpyro.sample("y", likelihood, obs=targets.squeeze())
# Example no. 2
# 0
def test_gram(dim):
    """Check that an RBF Gram matrix is square, with one row per input point.

    ``dim`` is the number of input columns. It is presumably supplied by
    pytest parametrisation (not visible here).
    """
    base = jnp.linspace(-1.0, 1.0, num=10).reshape(-1, 1)
    # Repeat the single column ``dim`` times to build an (n, dim) input.
    inputs = jnp.hstack([base] * dim) if dim > 1 else base
    kernel = RBF()
    kernel_params = initialise(kernel)
    K = gram(kernel, inputs, kernel_params)
    n_points = inputs.shape[0]
    assert K.shape[0] == n_points
    assert K.shape[1] == K.shape[0]
# Example no. 3
# 0
def test_pos_def(dim, ell, sigma):
    """Check that an RBF Gram matrix plus a small jitter is positive definite.

    ``dim`` (number of input columns), ``ell`` (lengthscale) and ``sigma``
    (variance) are presumably supplied by pytest parametrisation (not
    visible here).
    """
    n = 30
    x = jnp.linspace(0.0, 1.0, num=n).reshape(-1, 1)
    if dim > 1:
        # Stack ``dim`` copies of the column to get an (n, dim) input. The
        # argument must be a list/tuple: ``(x) * dim`` is just ``x * dim``,
        # which would scale the values and collapse x to shape (n,).
        x = jnp.hstack([x] * dim)
    kern = RBF()
    params = {"lengthscale": jnp.array([ell]), "variance": jnp.array(sigma)}

    gram_matrix = gram(kern, x, params)
    # Small diagonal jitter guards against round-off pushing the smallest
    # eigenvalue of a near-singular Gram matrix to <= 0.
    jitter_matrix = I(n) * 1e-6
    gram_matrix += jitter_matrix
    # The Gram matrix is symmetric, so eigvalsh returns real eigenvalues.
    # eigvals would return complex values and is CPU-only in JAX.
    min_eig = jnp.linalg.eigvalsh(gram_matrix).min()
    assert min_eig > 0