def mll(ds: Dataset):
    """Numpyro model scoring a dataset under a Gaussian-process marginal likelihood.

    Each entry of ``numpyro_params`` becomes a numpyro site:

    - Entries whose ``"param_type"`` is ``"prior"`` are drawn with
      ``numpyro.sample`` from their ``"prior"``.
    - All other entries become ``numpyro.param`` sites, initialised from
      ``"init_value"`` and constrained by ``"constraint"``.

    The targets are then observed under a multivariate normal. Its mean comes
    from ``gp.prior.mean_function``. Its covariance is the kernel Gram matrix
    plus ``obs_noise`` times the identity.

    ``numpyro_params``, ``gp``, ``gram``, ``I``, ``cholesky``, ``numpyro`` and
    ``dist`` are resolved from the surrounding scope, which is not shown here.

    Args:
        ds: Dataset whose ``X`` attribute holds the inputs and ``y`` the targets.

    Returns:
        The result of the observed ``numpyro.sample("y", ...)`` site.
    """
    inputs, targets = ds.X, ds.y

    # Register every hyperparameter as either a random (prior) or learnable site.
    site_values = {}
    for site_name, spec in numpyro_params.items():
        if spec["param_type"] != "prior":
            site_values[site_name] = numpyro.param(
                name=site_name,
                init_value=spec["init_value"],
                constraint=spec["constraint"],
            )
            continue
        site_values[site_name] = numpyro.sample(name=site_name, fn=spec["prior"])

    # NOTE(review): the observations below are squeezed but the mean is not.
    # If the mean function returns an (n, 1) column, MultivariateNormal would
    # pick up an unintended batch dimension. Confirm its output shape.
    mean_vector = gp.prior.mean_function(inputs)

    # Kernel Gram matrix with observation noise added on the diagonal.
    n_points = inputs.shape[0]
    covariance = gram(gp.prior.kernel, inputs, site_values) + site_values[
        "obs_noise"
    ] * I(n_points)

    # The lower Cholesky factor parameterises the Gaussian's covariance.
    lower_factor = cholesky(covariance, lower=True)
    likelihood = dist.MultivariateNormal(loc=mean_vector, scale_tril=lower_factor)
    return numpyro.sample("y", likelihood, obs=targets.squeeze())
def test_gram(dim):
    """Check that the RBF Gram matrix is square with one row per input point.

    Args:
        dim: Number of input columns; the 1-D grid is tiled ``dim`` times.
    """
    grid = jnp.linspace(-1.0, 1.0, num=10).reshape(-1, 1)
    inputs = jnp.hstack([grid] * dim) if dim > 1 else grid

    kernel = RBF()
    gram_matrix = gram(kernel, inputs, initialise(kernel))

    n_rows = gram_matrix.shape[0]
    assert n_rows == inputs.shape[0]
    assert n_rows == gram_matrix.shape[1]
def test_pos_def(dim, ell, sigma):
    """Check that the RBF Gram matrix plus a small jitter is positive definite.

    Args:
        dim: Number of input columns; the 1-D grid is tiled ``dim`` times.
        ell: Lengthscale passed to the kernel.
        sigma: Variance passed to the kernel.
    """
    n = 30
    x = jnp.linspace(0.0, 1.0, num=n).reshape(-1, 1)
    if dim > 1:
        # Tile the column ``dim`` times to get shape (n, dim). The previous
        # ``jnp.hstack((x) * dim)`` did not do this: ``(x)`` is not a tuple, so
        # it scaled the values by ``dim`` and flattened them to shape (n,).
        x = jnp.hstack([x] * dim)
    kern = RBF()
    params = {"lengthscale": jnp.array([ell]), "variance": jnp.array(sigma)}
    gram_matrix = gram(kern, x, params)
    # Small diagonal jitter so round-off cannot push tiny eigenvalues below 0.
    gram_matrix += I(n) * 1e-6
    # The Gram matrix is symmetric, so use the symmetric/Hermitian solver.
    # It returns real eigenvalues; eigvals would return complex ones.
    min_eig = jnp.linalg.eigvalsh(gram_matrix).min()
    assert min_eig > 0