Beispiel #1
0
 def _get_transform(self, params):
     """Build an affine transform from the low-rank Gaussian guide parameters.

     Reads ``<prefix>_loc``, ``<prefix>_cov_factor`` and ``<prefix>_scale`` from
     ``params``. It then builds a ``LowRankMultivariateNormal`` whose factor rows
     are scaled by ``scale`` and whose diagonal is ``scale * scale``. The result
     is a ``MultivariateAffineTransform`` made from the location and that
     distribution's ``scale_tril``.
     """
     prefix = self.prefix
     mean = params['{}_loc'.format(prefix)]
     raw_factor = params['{}_cov_factor'.format(prefix)]
     scale = params['{}_scale'.format(prefix)]

     # Row i of the factor is multiplied by scale[i].
     scaled_factor = raw_factor * scale[..., None]
     lowrank = dist.LowRankMultivariateNormal(mean, scaled_factor, scale * scale)
     return MultivariateAffineTransform(mean, lowrank.scale_tril)
Beispiel #2
0
 def get_posterior(self, params):
     """
     Build the low-rank multivariate Normal posterior from guide parameters.

     ``params`` must contain ``<prefix>_loc``, ``<prefix>_cov_factor`` and
     ``<prefix>_scale``. Before the distribution is built, ``scale`` multiplies
     the rows of the covariance factor and, squared, becomes the diagonal term.
     """
     prefix = self.prefix
     mean = params["{}_loc".format(prefix)]
     raw_factor = params["{}_cov_factor".format(prefix)]
     scale = params["{}_scale".format(prefix)]
     return dist.LowRankMultivariateNormal(
         mean,
         raw_factor * scale[..., None],
         scale * scale,
     )
Beispiel #3
0
 def _sample_latent(self, base_dist, *args, **kwargs):
     """
     Sample the latent vector from a low-rank multivariate Normal guide.

     Registers three numpyro params:
     - ``<prefix>_loc``, initialised from ``self._init_latent``;
     - ``<prefix>_cov_factor``, zeros of shape ``(latent_size, rank)``;
     - ``<prefix>_scale``, ones.

     It then samples the site ``_<prefix>_latent``. When ``self.rank`` is None,
     the rank is ``round(sqrt(latent_size))``. ``base_dist`` and the positional
     ``args`` are not used. An optional ``sample_shape`` keyword (default ``()``)
     is forwarded to ``numpyro.sample``.
     """
     sample_shape = kwargs.pop('sample_shape', ())
     if self.rank is None:
         rank = int(round(self.latent_size ** 0.5))
     else:
         rank = self.rank
     dim = self.latent_size
     prefix = self.prefix

     # Param registration order is kept: loc, cov_factor, scale.
     loc = numpyro.param('{}_loc'.format(prefix), self._init_latent)
     raw_factor = numpyro.param('{}_cov_factor'.format(prefix), np.zeros((dim, rank)))
     scale = numpyro.param('{}_scale'.format(prefix), np.ones(dim))

     guide = dist.LowRankMultivariateNormal(loc, raw_factor * scale[..., None], scale * scale)
     return numpyro.sample("_{}_latent".format(prefix), guide, sample_shape=sample_shape)
Beispiel #4
0
 def _get_posterior(self, *args, **kwargs):
     """Return the low-rank multivariate Normal guide distribution.

     Registers three numpyro params:
     - ``<prefix>_loc``, initialised from ``self._init_latent``;
     - ``<prefix>_cov_factor``, zeros of shape ``(latent_dim, rank)``;
     - ``<prefix>_scale``, a positive-constrained vector filled with
       ``self._init_scale``.

     When ``self.rank`` is None, the rank is ``round(sqrt(latent_dim))``.
     The positional and keyword arguments are not used.
     """
     if self.rank is None:
         rank = int(round(self.latent_dim ** 0.5))
     else:
         rank = self.rank
     dim = self.latent_dim
     prefix = self.prefix

     loc = numpyro.param('{}_loc'.format(prefix), self._init_latent)
     raw_factor = numpyro.param('{}_cov_factor'.format(prefix), jnp.zeros((dim, rank)))
     init_scale = jnp.full(dim, self._init_scale)
     scale = numpyro.param('{}_scale'.format(prefix), init_scale,
                           constraint=constraints.positive)

     # Row i of the factor is multiplied by scale[i]; the diagonal is scale squared.
     return dist.LowRankMultivariateNormal(loc, raw_factor * scale[..., None], scale * scale)
Beispiel #5
0
    def to_numpyro(self, y=None):
        """Register the sparse-GP likelihood as the numpyro sample site ``"y"``.

        The distribution is a ``LowRankMultivariateNormal`` with:
        - mean ``self.mean(self.X)``;
        - covariance factor ``W`` and diagonal ``D``, as returned by
          ``fitc_precompute(self.X, self.X_u, self.obs_noise, self.kernel,
          jitter=self.jitter)``.

        Args:
            y: observed targets. If given, the site is conditioned on ``y``.
                Leading axes of ``y`` (all but the last) are treated as
                independent batches and folded into the event. If ``None``,
                the site is sampled instead.

        Returns:
            The value at sample site ``"y"``.
        """
        f_loc = self.mean(self.X)

        # Only W and D are needed here; the first output is discarded.
        _, W, D = fitc_precompute(
            self.X, self.X_u, self.obs_noise, self.kernel, jitter=self.jitter
        )
        gp_dist = dist.LowRankMultivariateNormal(loc=f_loc, cov_factor=W, cov_diag=D)

        if y is None:
            return numpyro.sample("y", gp_dist)

        # Condition on the ``y`` argument itself (not ``self.y``), so data passed
        # by the caller is not silently ignored.
        return numpyro.sample(
            "y",
            gp_dist.expand_by(y.shape[:-1]).to_event(y.ndim - 1),
            obs=y,
        )
def _lowrank_mvn_to_scipy(loc, cov_fac, cov_diag):
    """Return an ``osp.multivariate_normal`` equivalent to a low-rank numpyro MVN.

    A numpyro ``LowRankMultivariateNormal(loc, cov_fac, cov_diag)`` is built
    only to read its ``mean`` and ``covariance_matrix``. Those two values then
    parameterize ``osp.multivariate_normal``, where ``osp`` is presumably
    ``scipy.stats`` (NOTE(review): confirm the import alias).
    """
    lowrank = dist.LowRankMultivariateNormal(loc, cov_fac, cov_diag)
    return osp.multivariate_normal(
        mean=lowrank.mean,
        cov=lowrank.covariance_matrix,
    )
Beispiel #7
0
def SparseGP(X, y):
    """Sparse GP regression model with low-rank likelihood and a trace penalty.

    The model:
    - samples RBF kernel hyperparameters (``variance``, ``length_scale``) and
      the observation noise (``obs_noise``);
    - registers the inducing inputs ``x_u`` as a learnable param;
    - scores ``y`` under a ``LowRankMultivariateNormal`` with zero mean,
      covariance factor ``W = (Luu^{-1} Kuf).T`` and diagonal ``obs_noise``;
    - adds ``-tr(Kff - Qff) / (2 * obs_noise)`` to the log density with
      ``numpyro.factor``.

    It relies on names defined at module scope elsewhere: ``X_u_init`` (initial
    inducing inputs), ``jitter`` (added to the diagonal of ``Kuu``),
    ``rbf_kernel``, ``add_to_diagonal``, ``cholesky`` and ``solve_triangular``.

    Args:
        X: training inputs; the first axis indexes data points.
        y: observed targets. The last axis must have length ``X.shape[0]``.
            Any leading axes are treated as independent batches and folded into
            the event. ``None`` samples ``"y"`` instead of conditioning on it.

    Returns:
        The value at sample site ``"y"``.
    """
    n_samples = X.shape[0]
    X = numpyro.deterministic("X", X)

    # Priors on kernel hyperparameters and observation noise.
    η = numpyro.sample("variance", dist.HalfCauchy(scale=5.0))
    ℓ = numpyro.sample("length_scale", dist.Gamma(2.0, 1.0))
    σ = numpyro.sample("obs_noise", dist.HalfCauchy(scale=5.0))

    # Learnable inducing inputs, initialised from the module-level X_u_init.
    x_u = numpyro.param("x_u", init_value=X_u_init)

    # Zero mean function.
    f_loc = jnp.zeros(n_samples)

    # Qff = Kfu @ inv(Kuu) @ Kuf = W @ W.T, with W = (inv(Luu) @ Kuf).T
    Kuu = rbf_kernel(x_u, x_u, η, ℓ)
    Kuf = rbf_kernel(x_u, X, η, ℓ)
    # Diagonal jitter (module-level value) keeps the Cholesky factorization stable.
    Kuu = add_to_diagonal(Kuu, jitter)

    Luu = numpyro.deterministic("Luu", cholesky(Kuu, lower=True))

    # The "W" site records inv(Luu) @ Kuf *before* the transpose.
    W = numpyro.deterministic("W", solve_triangular(Luu, Kuf, lower=True)).T

    # Noise diagonal. σ enters cov_diag directly, so it acts as a variance.
    # NOTE(review): the site is named "G" although the variable is D; the name
    # is kept so existing traces and sample dicts stay compatible.
    D = numpyro.deterministic("G", jnp.ones(n_samples) * σ)

    # Trace penalty: -tr(Kff - Qff) / (2σ).
    # NOTE(review): the full N x N kernel is built only to read its diagonal,
    # which costs O(N^2) memory.
    Kffdiag = jnp.diag(rbf_kernel(X, X, η, ℓ))
    Qffdiag = jnp.power(W, 2).sum(axis=1)
    trace_term = (Kffdiag - Qffdiag).sum() / σ
    # Clamp small negative values caused by round-off. jnp.maximum is exactly
    # what jnp.clip(..., a_min=0.0) computes, without the deprecated keyword.
    trace_term = jnp.maximum(trace_term, 0.0)
    numpyro.factor("trace_term", -trace_term / 2.0)

    likelihood = dist.LowRankMultivariateNormal(loc=f_loc, cov_factor=W, cov_diag=D)
    if y is None:
        return numpyro.sample("y", likelihood)
    return numpyro.sample(
        "y",
        likelihood.expand_by(y.shape[:-1]).to_event(y.ndim - 1),
        obs=y,
    )