Example #1
def kl_divergence_multivariate_normal(a_mean,
                                      a_scale_tril,
                                      b_mean,
                                      b_scale_tril,
                                      lower=True):
    def log_abs_determinant(scale_tril_arg):
        diag_scale_tril = np.diagonal(scale_tril_arg, axis1=-2, axis2=-1)
        return 2 * np.sum(np.log(diag_scale_tril), axis=-1)

    def squared_frobenius_norm(x):
        """Helper to make KL calculation slightly more readable."""
        return np.sum(np.square(x), axis=(-2, -1))

    if b_scale_tril.shape[0] == 1:
        tiles = tuple([b_mean.shape[0]] +
                      [1 for _ in range(len(b_scale_tril.shape) - 1)])
        scale_tril = np.tile(b_scale_tril, tiles)
    else:
        scale_tril = b_scale_tril

    b_inv_a = solve_triangular(scale_tril, a_scale_tril, lower=lower)
    kl = 0.5 * (log_abs_determinant(b_scale_tril) -
                log_abs_determinant(a_scale_tril) - a_scale_tril.shape[-1] +
                squared_frobenius_norm(b_inv_a) + squared_frobenius_norm(
                    solve_triangular(scale_tril,
                                     (b_mean - a_mean)[..., np.newaxis],
                                     lower=lower)))

    return kl
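A minimal usage sketch for the function above, assuming np is jax.numpy and solve_triangular is jax.scipy.linalg.solve_triangular (the excerpt does not show its imports): with identical distributions for a and b, the divergence should be zero up to floating point.

import jax.numpy as np
from jax.scipy.linalg import solve_triangular

a_mean = np.array([[0.0, 0.0], [1.0, -1.0]])          # batch of two 2-D means
a_scale_tril = np.broadcast_to(np.eye(2), (2, 2, 2))  # identity scale factors
kl = kl_divergence_multivariate_normal(a_mean, a_scale_tril,
                                       a_mean, a_scale_tril)
print(kl)  # ~[0., 0.]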
Example #2
 def posterior_sample(self, key, sample, X_star, **kwargs):
     # Fetch training data
     batch = kwargs['batch']
     X = batch['X']
     # Fetch params
     var = sample['kernel_var']
     length = sample['kernel_length']
     beta = sample['beta']
     eta = sample['eta']
     theta = np.concatenate([var, length])
     # Compute kernels
     K_xx = self.kernel(X, X, theta) + np.eye(X.shape[0]) * 1e-8
     k_pp = self.kernel(X_star, X_star,
                        theta) + np.eye(X_star.shape[0]) * 1e-8
     k_pX = self.kernel(X_star, X, theta)
     L = cholesky(K_xx, lower=True)
     f = np.matmul(L, eta) + beta
     tmp_1 = solve_triangular(L.T, solve_triangular(L, f, lower=True))
     tmp_2 = solve_triangular(L.T, solve_triangular(L, k_pX.T, lower=True))
     # Compute predictive mean
     mu = np.matmul(k_pX, tmp_1)
     cov = k_pp - np.matmul(k_pX, tmp_2)
     std = np.sqrt(np.clip(np.diag(cov), a_min=0.))
     sample = mu + std * random.normal(key, mu.shape)
     return mu, sample
Example #3
 def posterior_sample(self, key, sample, X_star, **kwargs):
     # Fetch training data
     norm_const = kwargs['norm_const']
     batch = kwargs['batch']
     X, y = batch['X'], batch['y']
     # Fetch params
     var = sample['kernel_var']
     length = sample['kernel_length']
     noise = sample['noise_var']
     params = np.concatenate(
         [np.array([var]),
          np.array(length),
          np.array([noise])])
     theta = params[:-1]
     # Compute kernels
     k_pp = self.kernel(X_star, X_star,
                        theta) + np.eye(X_star.shape[0]) * (noise + 1e-8)
     k_pX = self.kernel(X_star, X, theta)
     L = self.compute_cholesky(params, batch)
     alpha = solve_triangular(L.T, solve_triangular(L, y, lower=True))
     beta = solve_triangular(L.T, solve_triangular(L, k_pX.T, lower=True))
     # Compute predictive mean, std
     mu = np.matmul(k_pX, alpha)
     cov = k_pp - np.matmul(k_pX, beta)
     std = np.sqrt(np.clip(np.diag(cov), a_min=0.))
     sample = mu + std * random.normal(key, mu.shape)
     mu = mu * norm_const['sigma_y'] + norm_const['mu_y']
     sample = sample * norm_const['sigma_y'] + norm_const['mu_y']
     return mu, sample
Example #4
 def solve_tri(A, B, lower=True, from_left=True, transp_L=False):
     if not from_left:
         return sla.solve_triangular(A.T,
                                     B.T,
                                     trans=transp_L,
                                     lower=not lower).T
     else:
         return sla.solve_triangular(A, B, trans=transp_L, lower=lower)
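A quick sanity check for the wrapper, assuming sla is scipy.linalg and np is numpy (imports not shown in the excerpt); from_left=False solves X @ A = B instead of A @ X = B by transposing both sides.

import numpy as np
import scipy.linalg as sla

A = np.tril(np.random.rand(3, 3)) + 3 * np.eye(3)  # well-conditioned lower triangular
B = np.random.rand(2, 3)
X = solve_tri(A, B, lower=True, from_left=False)   # solves X @ A = B
print(np.allclose(X @ A, B))                       # True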
Example #5
 def __phi_H_2(self, x, p, xtilde, ptilde):
     ptilde = ptilde - 0.5 * self.__epsilon * self.__hamiltonian.jacobian_at(
         xtilde, p)
     L = self.__target.metric(xtilde)
     # x = x + 0.5*self.__epsilon*np.linalg.solve(L @ L.T, p)
     x = x + 0.5 * self.__epsilon * sla.solve_triangular(
         L.T, sla.solve_triangular(L, p, lower=True), lower=False)
     return x, p, xtilde, ptilde
Example #6
    def evaluate(self):
        K = self.model.kernel.function(self.model.X,
                                       self.model.parameters)\
            + jnp.eye(self.N) * (self.model.parameters["noise"] + 1e-8)

        self.L = cholesky(K, lower=True)
        self.alpha = solve_triangular(
            self.L.T, solve_triangular(self.L, self.model.y, lower=True))
Example #7
def mvn_kl(mu_a, L_a, mu_b, L_b):
    def squared_frobenius_norm(x):
        return np.sum(np.square(x))

    b_inv_a = solve_triangular(L_b, L_a, lower=True)
    kl_div = (np.sum(np.log(np.diag(L_b))) - np.sum(np.log(np.diag(L_a))) +
              0.5 * (-L_a.shape[-1] + squared_frobenius_norm(b_inv_a) +
                     squared_frobenius_norm(
                         solve_triangular(L_b, mu_b[:, None] - mu_a[:, None],
                                          lower=True))))
    return kl_div
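As a quick check, the 1-D case can be compared against the scalar Gaussian KL formula log(sigma_b/sigma_a) + (sigma_a^2 + (mu_a - mu_b)^2) / (2 sigma_b^2) - 1/2; this sketch assumes np is jax.numpy and solve_triangular is jax.scipy.linalg.solve_triangular.

import jax.numpy as np
from jax.scipy.linalg import solve_triangular

mu_a, mu_b = np.array([0.5]), np.array([0.0])
L_a, L_b = np.array([[2.0]]), np.array([[1.0]])
ref = np.log(1.0 / 2.0) + (2.0 ** 2 + 0.5 ** 2) / (2 * 1.0 ** 2) - 0.5
print(mvn_kl(mu_a, L_a, mu_b, L_b), ref)  # both ~0.932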
Example #8
    def _pred_factorize(self, xtest):

        Kux = self.kernel.cross_covariance(self.x_u, xtest)
        Ws = solve_triangular(self.Luu, Kux, lower=True)
        # pack
        pack = jnp.concatenate([self.W_Dinv_y, Ws], axis=1)
        Linv_pack = solve_triangular(self.L, pack, lower=True)
        # unpack
        Linv_W_Dinv_y = Linv_pack[:, : self.W_Dinv_y.shape[1]]
        Linv_Ws = Linv_pack[:, self.W_Dinv_y.shape[1] :]

        return Ws, Linv_W_Dinv_y, Linv_Ws
Example #9
def _pred_factorize(params, xtest):

    Kux = rbf_kernel(params["x_u"], xtest, params["variance"],
                     params["length_scale"])
    Ws = solve_triangular(params["Luu"], Kux, lower=True)
    # pack
    pack = jnp.concatenate([params["W_Dinv_y"], Ws], axis=1)
    Linv_pack = solve_triangular(params["L"], pack, lower=True)
    # unpack
    Linv_W_Dinv_y = Linv_pack[:, :params["W_Dinv_y"].shape[1]]
    Linv_Ws = Linv_pack[:, params["W_Dinv_y"].shape[1]:]

    return Ws, Linv_W_Dinv_y, Linv_Ws
Example #10
File: ops.py Project: ordabayevy/funsor
def _triangular_solve(x, y, upper=False, transpose=False):
    assert np.ndim(x) >= 2 and np.ndim(y) >= 2
    n, m = x.shape[-2:]
    assert y.shape[-2:] == (n, n)
    # NB: JAX requires x and y have the same batch_shape
    batch_shape = lax.broadcast_shapes(x.shape[:-2], y.shape[:-2])
    x = np.broadcast_to(x, batch_shape + (n, m))
    if y.shape[:-2] == batch_shape:
        return solve_triangular(y, x, trans=int(transpose), lower=not upper)

    # The following procedure handles the case: y.shape = (i, 1, n, n), x.shape = (..., i, j, n, m)
    # because we don't want to broadcast y to the shape (i, j, n, n).
    # We are going to make x have shape (..., 1, j, m, i, 1, n) to apply batched triangular_solve
    dx = x.ndim
    prepend_ndim = dx - y.ndim  # ndim of ... part
    # Reshape x with the shape (..., 1, i, j, 1, n, m)
    x_new_shape = batch_shape[:prepend_ndim]
    for (sy, sx) in zip(y.shape[:-2], batch_shape[prepend_ndim:]):
        x_new_shape += (sx // sy, sy)
    x_new_shape += (
        n,
        m,
    )
    x = np.reshape(x, x_new_shape)
    # Permute x to make it have shape (..., 1, j, m, i, 1, n)
    batch_ndim = x.ndim - 2
    permute_dims = (tuple(range(prepend_ndim)) +
                    tuple(range(prepend_ndim, batch_ndim, 2)) +
                    (batch_ndim + 1, ) +
                    tuple(range(prepend_ndim + 1, batch_ndim, 2)) +
                    (batch_ndim, ))
    x = np.transpose(x, permute_dims)
    x_permute_shape = x.shape

    # reshape to (-1, i, 1, n)
    x = np.reshape(x, (-1, ) + y.shape[:-1])
    # permute to (i, 1, n, -1)
    x = np.moveaxis(x, 0, -1)

    sol = solve_triangular(y, x, trans=int(transpose),
                           lower=not upper)  # shape: (i, 1, n, -1)
    sol = np.moveaxis(sol, -1, 0)  # shape: (-1, i, 1, n)
    sol = np.reshape(sol, x_permute_shape)  # shape: (..., 1, j, m, i, 1, n)

    # now we permute back to x_new_shape = (..., 1, i, j, 1, n, m)
    permute_inv_dims = tuple(range(prepend_ndim))
    for i in range(y.ndim - 2):
        permute_inv_dims += (prepend_ndim + i, dx + i - 1)
    permute_inv_dims += (sol.ndim - 1, prepend_ndim + y.ndim - 2)
    sol = np.transpose(sol, permute_inv_dims)
    return sol.reshape(batch_shape + (n, m))
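A shape-level sketch of the slow path above, assuming np is jax.numpy, lax is jax.lax, and solve_triangular is jax.scipy.linalg.solve_triangular; y has batch shape (i, 1) while x has batch shape (..., i, j), and the result matches an explicitly broadcast solve.

import jax
import jax.numpy as np
from jax import lax
from jax.scipy.linalg import solve_triangular

ky, kx = jax.random.split(jax.random.PRNGKey(0))
y = np.tril(jax.random.normal(ky, (2, 1, 4, 4))) + 4 * np.eye(4)  # (i, 1, n, n)
x = jax.random.normal(kx, (5, 2, 3, 4, 2))                        # (..., i, j, n, m)
sol = _triangular_solve(x, y, upper=False)
ref = solve_triangular(np.broadcast_to(y, (5, 2, 3, 4, 4)), x, lower=True)
print(sol.shape, np.allclose(sol, ref, atol=1e-4))  # (5, 2, 3, 4, 2) True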
Example #11
    def log_likelihood(self, params):
        self.model.set_parameters(params)
        kx = self.model.kernel.function(
            self.model.X, params) + jnp.eye(self.N) * (params["noise"] + 1e-8)
        L = cholesky(kx, lower=True)

        alpha = solve_triangular(L.T,
                                 solve_triangular(L, self.model.y, lower=True))
        W_logdet = 2. * jnp.sum(jnp.log(jnp.diag(L)))
        log_marginal = 0.5 * (-self.model.y.size * log_2_pi -
                              self.model.y.shape[1] * W_logdet -
                              jnp.sum(alpha * self.model.y))

        return log_marginal
Example #12
    def _predict(self,
                 xtest,
                 full_covariance: bool = False,
                 noiseless: bool = True):

        # Calculate the Mean
        K_x = self.kernel.cross_covariance(xtest, self.X)
        μ = jnp.dot(K_x, self.weights)

        # calculate covariance
        v = solve_triangular(self.Lff, K_x.T, lower=True)

        if full_covariance:

            K_xx = self.kernel.gram(xtest)

            if not noiseless:
                K_xx = add_to_diagonal(K_xx, self.obs_noise)

            Σ = K_xx - v.T @ v

            return μ, Σ

        else:

            K_xx = self.kernel.diag(xtest)

            σ = K_xx - jnp.sum(jnp.square(v), axis=0)

            if not noiseless:
                σ += self.obs_noise

            return μ, σ
Example #13
    def mll(params: dict,
            x: jnp.DeviceArray,
            y: jnp.DeviceArray,
            static_params: dict = None):
        params = transform(params)
        if static_params:
            params = concat_dictionaries(params, static_params)
        m = gp.prior.kernel.num_basis
        phi = gp.prior.kernel._build_phi(x, params)
        A = (params["variance"] / m) * jnp.matmul(
            jnp.transpose(phi), phi) + params["obs_noise"] * I(2 * m)

        RT = jnp.linalg.cholesky(A)
        R = jnp.transpose(RT)

        RtiPhit = solve_triangular(RT, jnp.transpose(phi), lower=True)
        # Rtiphity=RtiPhit*y_tr;
        Rtiphity = jnp.matmul(RtiPhit, y)

        out = (0.5 / params["obs_noise"] *
               (jnp.sum(jnp.square(y)) -
                params["variance"] / m * jnp.sum(jnp.square(Rtiphity))))
        n = x.shape[0]

        out += (jnp.sum(jnp.log(jnp.diag(R))) +
                (n / 2.0 - m) * jnp.log(params["variance"]) +
                n / 2 * jnp.log(2 * jnp.pi))
        constant = jnp.array(-1.0) if negative else jnp.array(1.0)
        return constant * out.reshape()
Example #14
 def final_fn(state, regularize=False):
     """
     :param state: Current state of the scheme.
     :param bool regularize: Whether to adjust diagonal for numerical stability.
     :return: a triple of estimated covariance, the square root of precision, and
         the inverse of that square root.
     """
     mean, m2, n = state
     # XXX it is not necessary to check for the case n=1
     cov = m2 / (n - 1)
     if regularize:
         # Regularization from Stan
         scaled_cov = (n / (n + 5)) * cov
         shrinkage = 1e-3 * (5 / (n + 5))
         if diagonal:
             cov = scaled_cov + shrinkage
         else:
             cov = scaled_cov + shrinkage * jnp.identity(mean.shape[0])
     if jnp.ndim(cov) == 2:
         # copy the implementation of distributions.util.cholesky_of_inverse here
         tril_inv = jnp.swapaxes(
             jnp.linalg.cholesky(cov[..., ::-1, ::-1])[..., ::-1, ::-1], -2,
             -1)
         identity = jnp.identity(cov.shape[-1])
         cov_inv_sqrt = solve_triangular(tril_inv, identity, lower=True)
     else:
         tril_inv = jnp.sqrt(cov)
         cov_inv_sqrt = jnp.reciprocal(tril_inv)
     return cov, cov_inv_sqrt, tril_inv
Example #15
def get_cond_params(learned_params: dict,
                    x: Array,
                    y: Array,
                    jitter: float = 1e-5) -> dict:

    params = deepcopy(learned_params)
    n_samples = x.shape[0]

    # calculate the cholesky factorization
    Kuu = rbf_kernel(params["x_u"], params["x_u"], params["variance"],
                     params["length_scale"])
    Kuu = add_to_diagonal(Kuu, jitter)
    Luu = cholesky(Kuu, lower=True)

    Kuf = rbf_kernel(params["x_u"], x, params["variance"],
                     params["length_scale"])

    W = solve_triangular(Luu, Kuf, lower=True)
    D = np.ones(n_samples) * params["obs_noise"]

    W_Dinv = W / D
    K = W_Dinv @ W.T
    K = add_to_diagonal(K, 1.0)

    L = cholesky(K, lower=True)

    # mean function
    y_residual = y  # mean function
    y_2D = y_residual.reshape(-1, n_samples).T
    W_Dinv_y = W_Dinv @ y_2D

    return {"Luu": Luu, "W_Dinv_y": W_Dinv_y, "L": L}
Example #16
def solve_via_cholesky(k_chol, y):
    """Solves a positive definite linear system via a Cholesky decomposition.

    Args:
        k_chol: The Cholesky factor of the matrix to solve. A lower triangular
            matrix, perhaps more commonly known as L.
        y: The right-hand side vector to solve for.
    """

    # Solve Ls = y
    s = spl.solve_triangular(k_chol, y, lower=True)

    # Solve L^T b = s
    b = spl.solve_triangular(k_chol.T, s, lower=False)

    return b
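A sanity check, assuming spl is scipy.linalg.

import numpy as np
import scipy.linalg as spl

A = np.array([[4.0, 1.0], [1.0, 3.0]])  # symmetric positive definite
y = np.array([1.0, 2.0])
k_chol = np.linalg.cholesky(A)          # lower factor L
print(np.allclose(solve_via_cholesky(k_chol, y), np.linalg.solve(A, y)))  # True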
Example #17
def sample_theta_space(X, Y, active_dims, msq, lam, eta1, xisq, c, sigma):
    P, N, M = X.shape[1], X.shape[0], len(active_dims)
    # the total number of coefficients we return
    num_coefficients = P + M * (M - 1) // 2

    probe = jnp.zeros((2 * P + 2 * M * (M - 1), P))
    vec = jnp.zeros((num_coefficients, 2 * P + 2 * M * (M - 1)))
    start1 = 0
    start2 = 0

    for dim in range(P):
        probe = jax.ops.index_update(probe, jax.ops.index[start1:start1 + 2,
                                                          dim],
                                     jnp.array([1.0, -1.0]))
        vec = jax.ops.index_update(vec, jax.ops.index[start2,
                                                      start1:start1 + 2],
                                   jnp.array([0.5, -0.5]))
        start1 += 2
        start2 += 1

    for dim1 in active_dims:
        for dim2 in active_dims:
            if dim1 >= dim2:
                continue
            probe = jax.ops.index_update(
                probe, jax.ops.index[start1:start1 + 4, dim1],
                jnp.array([1.0, 1.0, -1.0, -1.0]))
            probe = jax.ops.index_update(
                probe, jax.ops.index[start1:start1 + 4, dim2],
                jnp.array([1.0, -1.0, 1.0, -1.0]))
            vec = jax.ops.index_update(
                vec, jax.ops.index[start2, start1:start1 + 4],
                jnp.array([0.25, -0.25, -0.25, 0.25]))
            start1 += 4
            start2 += 1

    eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq
    kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))

    kX = kappa * X
    kprobe = kappa * probe

    k_xx = kernel(kX, kX, eta1, eta2, c) + sigma**2 * jnp.eye(N)
    L = cho_factor(k_xx, lower=True)[0]
    k_probeX = kernel(kprobe, kX, eta1, eta2, c)
    k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)

    mu = jnp.matmul(k_probeX, cho_solve((L, True), Y))
    mu = jnp.sum(mu * vec, axis=-1)

    Linv_k_probeX = solve_triangular(L, jnp.transpose(k_probeX), lower=True)
    covar = k_prbprb - jnp.matmul(jnp.transpose(Linv_k_probeX), Linv_k_probeX)
    covar = jnp.matmul(vec, jnp.matmul(covar, jnp.transpose(vec)))

    # sample from N(mu, covar)
    L = jnp.linalg.cholesky(covar)
    sample = mu + jnp.matmul(L, np.random.randn(num_coefficients))

    return sample
Example #18
File: SBL.py Project: remykusters/modax
def update_posterior(gram, XT_y, prior):
    alpha, beta = prior
    L = cholesky(jnp.diag(alpha) + beta * gram, lower=True)
    R = solve_triangular(L, jnp.eye(alpha.shape[0]), check_finite=False, lower=True)
    sigma = jnp.dot(R.T, R)
    mean = beta * jnp.dot(sigma, XT_y)

    return mean, sigma
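A small usage sketch, assuming jnp is jax.numpy and cholesky/solve_triangular come from jax.scipy.linalg; gram is X.T @ X and XT_y is X.T @ y for a design matrix X, and (alpha, beta) are the usual sparse-Bayesian-learning precision parameters.

import jax.numpy as jnp
from jax.scipy.linalg import cholesky, solve_triangular

X = jnp.array([[1.0, 0.0], [1.0, 1.0], [1.0, 2.0]])
y = jnp.array([0.1, 1.1, 1.9])
prior = (1e-2 * jnp.ones(2), 25.0)  # (alpha, beta)
mean, sigma = update_posterior(X.T @ X, X.T @ y, prior)
print(mean)  # ~ the least-squares solution [0.13, 0.9]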
Example #19
File: util.py Project: while519/numpyro
def cholesky_of_inverse(matrix):
    # This formulation only takes the inverse of a triangular matrix
    # which is more numerically stable.
    # Refer to:
    # https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril
    tril_inv = jnp.swapaxes(jnp.linalg.cholesky(matrix[..., ::-1, ::-1])[..., ::-1, ::-1], -2, -1)
    identity = jnp.broadcast_to(jnp.identity(matrix.shape[-1]), tril_inv.shape)
    return solve_triangular(tril_inv, identity, lower=True)
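A quick verification, assuming jnp is jax.numpy and solve_triangular is jax.scipy.linalg.solve_triangular: the returned factor L satisfies L @ L.T == inv(matrix).

import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

precision = jnp.array([[2.0, 0.3], [0.3, 1.0]])
L = cholesky_of_inverse(precision)
print(jnp.allclose(L @ L.T, jnp.linalg.inv(precision), atol=1e-5))  # True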
Example #20
 def posterior_sample(self, key, sample, X_star, **kwargs):
     # Fetch training data
     batch = kwargs['batch']
     XL, XH = batch['XL'], batch['XH']
     NL, NH = XL.shape[0], XH.shape[0]
     # Fetch params
     var_L = sample['kernel_var_L']
     var_H = sample['kernel_var_H']
     length_L = sample['kernel_length_L']
     length_H = sample['kernel_length_H']
     beta_L = sample['beta_L']
     beta_H = sample['beta_H']
     eta_L = sample['eta_L']
     eta_H = sample['eta_H']
     rho = sample['rho']
     theta_L = np.concatenate([var_L, length_L])
     theta_H = np.concatenate([var_H, length_H])
     beta = np.concatenate([beta_L*np.ones(NL), beta_H*np.ones(NH)])
     eta = np.concatenate([eta_L, eta_H])
     # Compute kernels
     k_pp = rho**2 * self.kernel(X_star, X_star, theta_L) + \
                     self.kernel(X_star, X_star, theta_H) + \
                     np.eye(X_star.shape[0])*1e-8
     psi1 = rho*self.kernel(X_star, XL, theta_L)
     psi2 = rho**2 * self.kernel(X_star, XH, theta_L) + \
                     self.kernel(X_star, XH, theta_H)
     k_pX = np.hstack((psi1, psi2))
     # Compute K_xx
     K_LL = self.kernel(XL, XL, theta_L) + np.eye(NL)*1e-8
     K_LH = rho*self.kernel(XL, XH, theta_L)
     K_HH = rho**2 * self.kernel(XH, XH, theta_L) + \
                     self.kernel(XH, XH, theta_H) + np.eye(NH)*1e-8
     K_xx = np.vstack((np.hstack((K_LL, K_LH)),
                       np.hstack((K_LH.T, K_HH))))
     L = cholesky(K_xx, lower=True)
     # Sample latent function
     f = np.matmul(L, eta) + beta
     tmp_1 = solve_triangular(L.T, solve_triangular(L, f, lower=True))
     tmp_2 = solve_triangular(L.T, solve_triangular(L, k_pX.T, lower=True))
     # Compute predictive mean
     mu = np.matmul(k_pX, tmp_1)
     cov = k_pp - np.matmul(k_pX, tmp_2)
     std = np.sqrt(np.clip(np.diag(cov), a_min=0.))
     sample = mu + std * random.normal(key, mu.shape)
     return mu, sample
Example #21
 def build_rv(test_points: Array):
     Kfx = cross_covariance(gp.prior.kernel, X, test_points, params)
     Kxx = gram(gp.prior.kernel, test_points, params)
     A = solve_triangular(L, Kfx.T, lower=True)
     latent_var = Kxx - jnp.sum(jnp.square(A), -2)
     latent_mean = jnp.matmul(A.T, nu)
     lvar = jnp.diag(latent_var)
     moment_fn = predictive_moments(gp.likelihood)
     return moment_fn(latent_mean.ravel(), lvar)
Example #22
def isPD_and_invert(M):
    L = np.linalg.cholesky(M)
    if np.isnan(np.sum(L)):
        return False, None
    L_inverse = sla.solve_triangular(L,
                                     np.eye(len(L)),
                                     lower=True,
                                     check_finite=False)
    return True, L_inverse.T.dot(L_inverse)
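Usage, assuming np is numpy and sla is scipy.linalg; the second return value is the inverse of M, assembled from the triangular inverse.

import numpy as np
import scipy.linalg as sla

M = np.array([[2.0, 0.5], [0.5, 1.0]])
ok, M_inv = isPD_and_invert(M)
print(ok, np.allclose(M_inv, np.linalg.inv(M)))  # True True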
Example #23
def random_variable(
    gp: SpectralPosterior,
    params: dict,
    train_inputs: Array,
    train_outputs: Array,
    test_inputs: Array,
    static_params: dict = None,
) -> tfd.Distribution:
    params = concat_dictionaries(params, static_params)
    m = gp.prior.kernel.num_basis
    w = params["basis_fns"] / params["lengthscale"]
    phi = gp.prior.kernel._build_phi(train_inputs, params)

    A = (params["variance"] / m) * jnp.matmul(jnp.transpose(phi), phi) + params["obs_noise"] * I(
        2 * m
    )

    RT = jnp.linalg.cholesky(A)
    R = jnp.transpose(RT)

    RtiPhit = solve_triangular(RT, jnp.transpose(phi), lower=True)
    # Rtiphity=RtiPhit*y_tr;
    Rtiphity = jnp.matmul(RtiPhit, train_outputs)

    alpha = params["variance"] / m * solve_triangular(R, Rtiphity, lower=False)

    phistar = jnp.matmul(test_inputs, jnp.transpose(w))
    # phistar = [cos(phistar) sin(phistar)];                              % test design matrix
    phistar = jnp.hstack([jnp.cos(phistar), jnp.sin(phistar)])
    # out1(beg_chunk:end_chunk) = phistar*alfa;                           % Predictive mean
    mean = jnp.matmul(phistar, alpha)

    RtiPhistart = solve_triangular(RT, jnp.transpose(phistar), lower=True)
    PhiRistar = jnp.transpose(RtiPhistart)
    cov = (
        params["obs_noise"]
        * params["variance"]
        / m
        * jnp.matmul(PhiRistar, jnp.transpose(PhiRistar))
        + I(test_inputs.shape[0]) * 1e-6
    )
    return tfd.MultivariateNormalFullCovariance(mean.squeeze(), cov)
Example #24
 def log_mvnormal(x, mean, cov):
     L = jnp.linalg.cholesky(cov)
     dx = x - mean
     dx = solve_triangular(L, dx, lower=True)
     # maha = dx @ jnp.linalg.solve(cov, dx)
     maha = dx @ dx
     # 0.5 * log|cov|, via the Cholesky factor
     logdet = jnp.sum(jnp.log(jnp.diag(L)))
     log_prob = -0.5 * x.size * jnp.log(2. * jnp.pi) - logdet - 0.5 * maha
     return log_prob
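With the log applied to the Cholesky diagonal (fixed above), the density matches jax.scipy.stats.multivariate_normal.logpdf; a minimal check assuming jnp is jax.numpy and solve_triangular is jax.scipy.linalg.solve_triangular.

import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.scipy.stats import multivariate_normal

x = jnp.array([0.3, -0.2])
mean = jnp.zeros(2)
cov = jnp.array([[1.5, 0.2], [0.2, 0.8]])
print(log_mvnormal(x, mean, cov),
      multivariate_normal.logpdf(x, mean, cov))  # same value twice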
Example #25
def _batch_mahalanobis(bL, bx):
    if bL.shape[:-1] == bx.shape:
        # no need to use the below optimization procedure
        solve_bL_bx = solve_triangular(bL, bx[..., None],
                                       lower=True).squeeze(-1)
        return np.sum(np.square(solve_bL_bx), -1)

    # NB: The following procedure handles the case: bL.shape = (i, 1, n, n), bx.shape = (i, j, n)
    # because we don't want to broadcast bL to the shape (i, j, n, n).

    # Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
    # we are going to make bx have shape (..., 1, j,  i, 1, n) to apply batched tril_solve
    sample_ndim = bx.ndim - bL.ndim + 1  # size of sample_shape
    out_shape = np.shape(bx)[:-1]  # shape of output
    # Reshape bx with the shape (..., 1, i, j, 1, n)
    bx_new_shape = out_shape[:sample_ndim]
    for (sL, sx) in zip(bL.shape[:-2], out_shape[sample_ndim:]):
        bx_new_shape += (sx // sL, sL)
    bx_new_shape += (-1, )
    bx = np.reshape(bx, bx_new_shape)
    # Permute bx to make it have shape (..., 1, j, i, 1, n)
    permute_dims = (tuple(range(sample_ndim)) +
                    tuple(range(sample_ndim, bx.ndim - 1, 2)) +
                    tuple(range(sample_ndim + 1, bx.ndim - 1, 2)) +
                    (bx.ndim - 1, ))
    bx = np.transpose(bx, permute_dims)

    # reshape to (-1, i, 1, n)
    xt = np.reshape(bx, (-1, ) + bL.shape[:-1])
    # permute to (i, 1, n, -1)
    xt = np.moveaxis(xt, 0, -1)
    solve_bL_bx = solve_triangular(bL, xt, lower=True)  # shape: (i, 1, n, -1)
    M = np.sum(solve_bL_bx**2, axis=-2)  # shape: (i, 1, -1)
    # permute back to (-1, i, 1)
    M = np.moveaxis(M, -1, 0)
    # reshape back to (..., 1, j, i, 1)
    M = np.reshape(M, bx.shape[:-1])
    # permute back to (..., 1, i, j, 1)
    permute_inv_dims = tuple(range(sample_ndim))
    for i in range(bL.ndim - 2):
        permute_inv_dims += (sample_ndim + i, len(out_shape) + i)
    M = np.transpose(M, permute_inv_dims)
    return np.reshape(M, out_shape)
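A check of the batched path against a direct solve, assuming np is jax.numpy and solve_triangular is jax.scipy.linalg.solve_triangular; bL has batch shape (i, 1) while bx has batch shape (..., i, j).

import jax
import jax.numpy as np
from jax.scipy.linalg import solve_triangular

kL, kx = jax.random.split(jax.random.PRNGKey(0))
bL = np.tril(jax.random.normal(kL, (2, 1, 3, 3))) + 3 * np.eye(3)  # (i, 1, n, n)
bx = jax.random.normal(kx, (5, 2, 4, 3))                           # (..., i, j, n)
M = _batch_mahalanobis(bL, bx)
bL_full = np.broadcast_to(bL, (5, 2, 4, 3, 3))
ref = np.sum(np.square(
    solve_triangular(bL_full, bx[..., None], lower=True).squeeze(-1)), -1)
print(M.shape, np.allclose(M, ref, atol=1e-4))  # (5, 2, 4) True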
Example #26
    def build_rv(test_points: Array):
        N = test_points.shape[0]
        phistar = jnp.matmul(test_points, jnp.transpose(w))
        phistar = jnp.hstack([jnp.cos(phistar), jnp.sin(phistar)])
        mean = jnp.matmul(phistar, alpha)

        RtiPhistart = solve_triangular(RT, jnp.transpose(phistar), lower=True)
        PhiRistar = jnp.transpose(RtiPhistart)
        cov = (params["obs_noise"] * params["variance"] / m *
               jnp.matmul(PhiRistar, jnp.transpose(PhiRistar)) + I(N) * 1e-6)
        return tfd.MultivariateNormalFullCovariance(mean.squeeze(), cov)
Example #27
 def precision_matrix(self):
     # We use "Woodbury matrix identity" to take advantage of low rank form::
     #     inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D)
     # where :math:`C` is the capacitance matrix.
     Wt_Dinv = (np.swapaxes(self.cov_factor, -1, -2)
                / np.expand_dims(self.cov_diag, axis=-2))
     A = solve_triangular(self._capacitance_tril, Wt_Dinv, lower=True)
     # TODO: find a better solution to create a diagonal matrix
     inverse_cov_diag = np.reciprocal(self.cov_diag)
     diag_embed = inverse_cov_diag[..., np.newaxis] * np.identity(self.loc.shape[-1])
     return diag_embed - np.matmul(np.swapaxes(A, -1, -2), A)
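A standalone sketch of the Woodbury computation the method performs, with hypothetical small matrices and np as jax.numpy; the capacitance matrix is C = I + W.T @ inv(D) @ W with lower Cholesky factor capacitance_tril.

import jax.numpy as np
from jax.scipy.linalg import cholesky, solve_triangular

W = np.array([[1.0, 0.0], [0.5, 1.0], [0.2, 0.3]])  # cov_factor, shape (3, 2)
D = np.array([0.5, 1.0, 2.0])                       # cov_diag
Wt_Dinv = W.T / D
capacitance_tril = cholesky(np.eye(2) + Wt_Dinv @ W, lower=True)
A = solve_triangular(capacitance_tril, Wt_Dinv, lower=True)
precision = np.diag(1.0 / D) - A.T @ A
print(np.allclose(precision, np.linalg.inv(W @ W.T + np.diag(D)), atol=1e-5))  # True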
Example #28
    def meanf(test_inputs: Array) -> Array:
        Kfx = cross_covariance(gp.prior.kernel, X, test_inputs, param)
        Kxx = gram(gp.prior.kernel, test_inputs, param)
        A = solve_triangular(L, Kfx.T, lower=True)
        latent_var = Kxx - jnp.sum(jnp.square(A), -2)
        latent_mean = jnp.matmul(A.T, nu)

        lvar = jnp.diag(latent_var)

        moment_fn = predictive_moments(gp.likelihood)
        pred_rv = moment_fn(latent_mean.ravel(), lvar)
        return pred_rv.mean()
Example #29
    def predict(self, Xnew):
        Kx = self.model.kernel.cov(self.model.X, Xnew)
        mu = jnp.dot(Kx.T, self.alpha)

        Kxx = self.model.kernel.cov(Xnew, Xnew)

        tmp = solve_triangular(self.L, Kx, lower=True)

        var = Kxx - jnp.dot(tmp.T,
                            tmp) + jnp.eye(Xnew.shape[0]) * self.model.variance

        return mu, var
Example #30
def sample_initial_states(rng, data, num_chain=4, algorithm="chmc"):
    """Sample initial states from prior."""
    init_states = []
    for _ in range(num_chain):
        u = sample_from_prior(rng, data)
        if algorithm == "chmc":
            chol_covar = onp.linalg.cholesky(covar_func(u, data))
            n = sla.solve_triangular(chol_covar, data["y_obs"], lower=True)
            q = onp.concatenate((u, onp.asarray(n)))
        else:
            q = onp.asarray(u)
        init_states.append(q)
    return init_states