def _get_transform(self, params):
    """Build the affine transform matching the low-rank normal posterior.

    Looks up the ``<prefix>_loc``, ``<prefix>_cov_factor`` and
    ``<prefix>_scale`` entries of ``params`` (``<prefix>`` being
    ``self.prefix``), forms a ``LowRankMultivariateNormal`` from them and
    wraps its ``scale_tril`` in a ``MultivariateAffineTransform``.

    :param dict params: mapping that holds the three entries named above.
    :return: ``MultivariateAffineTransform(loc, scale_tril)``.
    """
    prefix = self.prefix
    mean = params['{}_loc'.format(prefix)]
    raw_factor = params['{}_cov_factor'.format(prefix)]
    diag_scale = params['{}_scale'.format(prefix)]

    # The stored factor is rescaled row-wise by `scale`; the diagonal part
    # of the covariance is `scale` squared.
    low_rank = dist.LowRankMultivariateNormal(
        mean,
        raw_factor * diag_scale[..., None],
        diag_scale * diag_scale,
    )
    return MultivariateAffineTransform(mean, low_rank.scale_tril)
def get_posterior(self, params):
    """
    Returns a low-rank multivariate Normal posterior distribution.

    The distribution is assembled from the ``<prefix>_loc``,
    ``<prefix>_cov_factor`` and ``<prefix>_scale`` entries of ``params``,
    where ``<prefix>`` is ``self.prefix``.

    :param dict params: mapping that holds the three entries named above.
    :return: a ``dist.LowRankMultivariateNormal`` instance.
    """
    prefix = self.prefix
    mean = params["{}_loc".format(prefix)]
    raw_factor = params["{}_cov_factor".format(prefix)]
    diag_scale = params["{}_scale".format(prefix)]

    # Rescale each row of the factor by `scale`; the diagonal term is
    # `scale` squared.
    scaled_factor = raw_factor * diag_scale[..., None]
    diag_cov = diag_scale * diag_scale
    return dist.LowRankMultivariateNormal(mean, scaled_factor, diag_cov)
def _sample_latent(self, base_dist, *args, **kwargs):
    """Register the guide parameters and sample the latent site.

    Declares the ``<prefix>_loc``, ``<prefix>_cov_factor`` and
    ``<prefix>_scale`` parameters (``<prefix>`` being ``self.prefix``),
    builds a ``LowRankMultivariateNormal`` from them and samples the
    ``_<prefix>_latent`` site from it.

    ``base_dist`` and ``*args`` are accepted but not used here. The only
    keyword consumed is ``sample_shape`` (default ``()``), which is
    forwarded to ``numpyro.sample``.

    :return: the value returned by ``numpyro.sample`` for the latent site.
    """
    sample_shape = kwargs.pop('sample_shape', ())

    # When no rank was configured, fall back to the rounded square root of
    # the latent size.
    if self.rank is None:
        rank = int(round(self.latent_size ** 0.5))
    else:
        rank = self.rank

    prefix = self.prefix
    mean = numpyro.param('{}_loc'.format(prefix), self._init_latent)
    raw_factor = numpyro.param(
        '{}_cov_factor'.format(prefix),
        np.zeros((self.latent_size, rank)),
    )
    diag_scale = numpyro.param(
        '{}_scale'.format(prefix),
        np.ones(self.latent_size),
    )

    posterior = dist.LowRankMultivariateNormal(
        mean,
        raw_factor * diag_scale[..., None],
        diag_scale * diag_scale,
    )
    site_name = "_{}_latent".format(prefix)
    return numpyro.sample(site_name, posterior, sample_shape=sample_shape)
def _get_posterior(self, *args, **kwargs):
    """Register the guide parameters and return the low-rank normal posterior.

    Declares ``<prefix>_loc`` (initialised from ``self._init_latent``),
    ``<prefix>_cov_factor`` (zeros of shape ``(latent_dim, rank)``) and
    ``<prefix>_scale`` (filled with ``self._init_scale`` and constrained to
    be positive), where ``<prefix>`` is ``self.prefix``. Positional and
    keyword arguments are accepted but not used.

    :return: a ``dist.LowRankMultivariateNormal`` built from those parameters.
    """
    # When no rank was configured, fall back to the rounded square root of
    # the latent dimension.
    if self.rank is None:
        rank = int(round(self.latent_dim ** 0.5))
    else:
        rank = self.rank

    prefix = self.prefix
    mean = numpyro.param('{}_loc'.format(prefix), self._init_latent)
    raw_factor = numpyro.param(
        '{}_cov_factor'.format(prefix),
        jnp.zeros((self.latent_dim, rank)),
    )
    diag_scale = numpyro.param(
        '{}_scale'.format(prefix),
        jnp.full(self.latent_dim, self._init_scale),
        constraint=constraints.positive,
    )

    # Rescale each row of the factor by `scale`; the diagonal term is
    # `scale` squared.
    scaled_factor = raw_factor * diag_scale[..., None]
    diag_cov = diag_scale * diag_scale
    return dist.LowRankMultivariateNormal(mean, scaled_factor, diag_cov)
def to_numpyro(self, y=None):
    """Emit the ``"y"`` sample site of the sparse GP as a NumPyro statement.

    The mean comes from ``self.mean(self.X)``; the low-rank factor ``W`` and
    diagonal ``D`` come from ``fitc_precompute`` applied to ``self.X``,
    ``self.X_u``, ``self.obs_noise`` and ``self.kernel`` with
    ``self.jitter``.

    :param y: observed targets. When given, the ``"y"`` site is conditioned
        on this value, with the distribution expanded by ``y.shape[:-1]``
        and its batch dimensions turned into event dimensions. When
        ``None``, ``"y"`` is sampled unconditioned.
    :return: the value returned by ``numpyro.sample`` for the ``"y"`` site.
    """
    f_loc = self.mean(self.X)
    _, W, D = fitc_precompute(
        self.X, self.X_u, self.obs_noise, self.kernel, jitter=self.jitter
    )
    likelihood = dist.LowRankMultivariateNormal(
        loc=f_loc, cov_factor=W, cov_diag=D
    )

    if y is None:
        return numpyro.sample("y", likelihood)

    # Condition on the `y` that was passed in. The previous code tested the
    # argument against None but then read `self.y`, silently ignoring the
    # caller's value.
    return numpyro.sample(
        "y",
        likelihood.expand_by(y.shape[:-1]).to_event(y.ndim - 1),
        obs=y,
    )
def _lowrank_mvn_to_scipy(loc, cov_fac, cov_diag):
    """Turn low-rank normal parameters into an ``osp.multivariate_normal``.

    A ``dist.LowRankMultivariateNormal`` is built from the arguments, and
    its ``mean`` and ``covariance_matrix`` are handed to
    ``osp.multivariate_normal``.

    :param loc: location passed to ``LowRankMultivariateNormal``.
    :param cov_fac: covariance factor passed to ``LowRankMultivariateNormal``.
    :param cov_diag: covariance diagonal passed to ``LowRankMultivariateNormal``.
    :return: the object returned by ``osp.multivariate_normal``.
    """
    low_rank = dist.LowRankMultivariateNormal(loc, cov_fac, cov_diag)
    return osp.multivariate_normal(
        mean=low_rank.mean,
        cov=low_rank.covariance_matrix,
    )
def SparseGP(X, y):
    """NumPyro model for a sparse Gaussian process with inducing inputs.

    Sites registered, in order: deterministic ``"X"``; samples
    ``"variance"``, ``"length_scale"`` and ``"obs_noise"``; param ``"x_u"``;
    deterministics ``"Luu"``, ``"W"`` and ``"G"``; factor ``"trace_term"``;
    and the observed sample ``"y"``.

    Relies on names defined outside this function: ``X_u_init``, ``jitter``,
    ``rbf_kernel``, ``add_to_diagonal``, ``cholesky`` and
    ``solve_triangular``.

    :param X: inputs; ``X.shape[0]`` is taken as the number of samples.
    :param y: observed targets used to condition the ``"y"`` site.
    :return: the value returned by ``numpyro.sample`` for the ``"y"`` site.
    """
    num_points = X.shape[0]
    X = numpyro.deterministic("X", X)

    # Priors on the kernel hyperparameters and the observation noise.
    # (They could instead be point-estimated through `numpyro.param` with a
    # positive constraint.)
    variance = numpyro.sample("variance", dist.HalfCauchy(scale=5.0))
    length_scale = numpyro.sample("length_scale", dist.Gamma(2.0, 1.0))
    obs_noise = numpyro.sample("obs_noise", dist.HalfCauchy(scale=5.0))

    # Inducing inputs are a learnable parameter.
    inducing = numpyro.param("x_u", init_value=X_u_init)

    # Mean function: zero.
    mean = np.zeros(num_points)

    # Low-rank term: with Luu the Cholesky factor of (Kuu + jitter * I),
    # W = (inv(Luu) @ Kuf).T, so that Qff = Kfu @ inv(Kuu) @ Kuf = W @ W.T.
    k_uu = rbf_kernel(inducing, inducing, variance, length_scale)
    k_uf = rbf_kernel(inducing, X, variance, length_scale)
    k_uu = add_to_diagonal(k_uu, jitter)

    chol_uu = numpyro.deterministic("Luu", cholesky(k_uu, lower=True))
    # The recorded "W" site holds the un-transposed solve; the transpose is
    # what is used below.
    low_rank_factor = numpyro.deterministic(
        "W", solve_triangular(chol_uu, k_uf, lower=True)
    ).T

    # Diagonal term: the observation noise repeated for every sample.
    noise_diag = numpyro.deterministic("G", jnp.ones(num_points) * obs_noise)

    # Trace term: tr(Kff - Qff) / noise, entering the log density as -t / 2.
    k_ff_diag = jnp.diag(rbf_kernel(X, X, variance, length_scale))
    q_ff_diag = jnp.sum(jnp.power(low_rank_factor, 2), axis=1)
    trace_term = (k_ff_diag - q_ff_diag).sum() / obs_noise
    # Guard against small negative values caused by numerical error.
    trace_term = jnp.clip(trace_term, a_min=0.0)
    numpyro.factor("trace_term", -trace_term / 2.0)

    # Observed site: low-rank normal likelihood conditioned on y.
    likelihood = dist.LowRankMultivariateNormal(
        loc=mean, cov_factor=low_rank_factor, cov_diag=noise_diag
    )
    return numpyro.sample(
        "y",
        likelihood.expand_by(y.shape[:-1]).to_event(y.ndim - 1),
        obs=y,
    )