Example #1
 def test_multivariate_normal_batch_lazy(self, cuda=False):
     device = torch.device("cuda") if cuda else torch.device("cpu")
     for dtype in (torch.float, torch.double):
         mean = torch.tensor([0, 1, 2], device=device,
                             dtype=dtype).repeat(2, 1)
         covmat = torch.diag(
             torch.tensor([1, 0.75, 1.5], device=device,
                          dtype=dtype)).repeat(2, 1, 1)
         covmat_chol = torch.cholesky(covmat)
         mvn = MultivariateNormal(mean=mean,
                                  covariance_matrix=NonLazyTensor(covmat))
         self.assertTrue(torch.is_tensor(mvn.covariance_matrix))
         self.assertIsInstance(mvn.lazy_covariance_matrix, LazyTensor)
         self.assertAllClose(mvn.variance,
                             torch.diagonal(covmat, dim1=-2, dim2=-1))
         self.assertAllClose(mvn._unbroadcasted_scale_tril, covmat_chol)
         mvn_plus1 = mvn + 1
         self.assertAllClose(mvn_plus1.mean, mvn.mean + 1)
         self.assertAllClose(mvn_plus1.covariance_matrix,
                             mvn.covariance_matrix)
         self.assertAllClose(mvn_plus1._unbroadcasted_scale_tril,
                             covmat_chol)
         mvn_times2 = mvn * 2
         self.assertAllClose(mvn_times2.mean, mvn.mean * 2)
         self.assertAllClose(mvn_times2.covariance_matrix,
                             mvn.covariance_matrix * 4)
         self.assertAllClose(mvn_times2._unbroadcasted_scale_tril,
                             covmat_chol * 2)
         mvn_divby2 = mvn / 2
         self.assertAllClose(mvn_divby2.mean, mvn.mean / 2)
         self.assertAllClose(mvn_divby2.covariance_matrix,
                             mvn.covariance_matrix / 4)
         self.assertAllClose(mvn_divby2._unbroadcasted_scale_tril,
                             covmat_chol / 2)
         # TODO: Add tests for entropy, log_prob, etc. - this is an issue b/c it
         # uses root_decomposition, which is not very reliable
         # self.assertTrue(torch.allclose(mvn.entropy(), 4.3157 * torch.ones(2)))
         # self.assertTrue(
         #     torch.allclose(mvn.log_prob(torch.zeros(2, 3)), -4.8157 * torch.ones(2))
         # )
         # self.assertTrue(
         #     torch.allclose(mvn.log_prob(torch.zeros(2, 2, 3)), -4.8157 * torch.ones(2, 2))
         # )
         conf_lower, conf_upper = mvn.confidence_region()
         self.assertAllClose(conf_lower, mvn.mean - 2 * mvn.stddev)
         self.assertAllClose(conf_upper, mvn.mean + 2 * mvn.stddev)
         self.assertTrue(mvn.sample().shape == torch.Size([2, 3]))
         self.assertTrue(
             mvn.sample(torch.Size([2])).shape == torch.Size([2, 2, 3]))
         self.assertTrue(
             mvn.sample(torch.Size([2, 4])).shape == torch.Size(
                 [2, 4, 2, 3]))
Example #2
File: gaussian.py Project: xidulu/pyro
def gaussian_tensordot(x, y, dims=0):
    """
    Computes the integral over two gaussians:

        `(x @ y)(a,c) = log(integral(exp(x(a,b) + y(b,c)), b))`,

    where `x` is a gaussian over variables (a,b), `y` is a gaussian over variables
    (b,c), (a,b,c) can each be sets of zero or more variables, and `dims` is the size of b.

    :param x: a Gaussian instance
    :param y: a Gaussian instance
    :param dims: number of variables to contract
    """
    assert isinstance(x, Gaussian)
    assert isinstance(y, Gaussian)
    na = x.dim() - dims
    nb = dims
    nc = y.dim() - dims
    assert na >= 0
    assert nb >= 0
    assert nc >= 0

    Paa, Pba, Pbb = x.precision[..., :na, :na], x.precision[
        ..., na:, :na], x.precision[..., na:, na:]
    Qbb, Qbc, Qcc = y.precision[..., :nb, :nb], y.precision[
        ..., :nb, nb:], y.precision[..., nb:, nb:]
    xa, xb = x.info_vec[..., :na], x.info_vec[..., na:]  # x.precision @ x.mean
    yb, yc = y.info_vec[..., :nb], y.info_vec[..., nb:]  # y.precision @ y.mean

    precision = pad(Paa, (0, nc, 0, nc)) + pad(Qcc, (na, 0, na, 0))
    info_vec = pad(xa, (0, nc)) + pad(yc, (na, 0))
    log_normalizer = x.log_normalizer + y.log_normalizer
    if nb > 0:
        B = pad(Pba, (0, nc)) + pad(Qbc, (na, 0))
        b = xb + yb

        # Pbb + Qbb needs to be positive definite, so that we can marginalize out `b` (to have a finite integral)
        L = torch.cholesky(Pbb + Qbb)
        LinvB = torch.triangular_solve(B, L, upper=False)[0]
        LinvBt = LinvB.transpose(-2, -1)
        Linvb = torch.triangular_solve(b.unsqueeze(-1), L, upper=False)[0]

        precision = precision - torch.matmul(LinvBt, LinvB)
        # NB: precision might not be invertible for getting mean = precision^-1 @ info_vec
        if na + nc > 0:
            info_vec = info_vec - torch.matmul(LinvBt, Linvb).squeeze(-1)
        logdet = torch.diagonal(L, dim1=-2, dim2=-1).log().sum(-1)
        diff = 0.5 * nb * math.log(
            2 * math.pi) + 0.5 * Linvb.squeeze(-1).pow(2).sum(-1) - logdet
        log_normalizer = log_normalizer + diff

    return Gaussian(log_normalizer, info_vec, precision)
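Below is a minimal usage sketch of the contraction above. It assumes the Gaussian class used by this function is importable (e.g. from pyro.ops.gaussian) with the same (log_normalizer, info_vec, precision) constructor order seen in the return statement; the matrices are synthetic.

import torch
from pyro.ops.gaussian import Gaussian  # assumed import path for the class used above

def random_gaussian(dim):
    # Build a well-conditioned Gaussian in information form.
    m = torch.randn(dim, dim)
    precision = m @ m.t() + dim * torch.eye(dim)
    return Gaussian(torch.zeros(()), torch.randn(dim), precision)

x = random_gaussian(3)                 # over variables (a, b) with |a| = 2, |b| = 1
y = random_gaussian(2)                 # over variables (b, c) with |b| = 1, |c| = 1
xy = gaussian_tensordot(x, y, dims=1)  # integrate out the shared variable b
print(xy.dim())                        # 3, i.e. the remaining variables (a, c)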
Example #3
    def forward(self, inputs):
        r"""
        predicts energy
        """
        result = super(MultiOutput, self).forward(inputs)
        
#        if self.uncertainty_own:
#            result["sigma"] = torch.nn.functional.softplus(result["y"][:,1])
#            result["y"]     = result["y"][:,0]
            
        if self.requires_dr:
            if self.return_stress:
                forces = -grad(result["y"], inputs[Structure.R],
                           grad_outputs=torch.ones_like(result["y"]),
                           create_graph=self.create_graph,retain_graph=self.training)[0]
                n_batch = inputs[Structure.R].size()[0]
                idx_m = torch.arange(n_batch, device=inputs[Structure.R].device, dtype=torch.long)[:,
                        None, None]

                # Subtract positions of central atoms to get distance vectors
                #B,A,N,C = dist_vec.shape
                #dist_vec = dist_vec.view(B,A*N,C)
                pair_force = grad(result['y'], inputs['dist_vec'],
                                  grad_outputs=torch.ones_like(inputs['dist_vec']),
                                  create_graph=False)[0]
                #result['stress'] = torch.sum(dist_vec.mm(pair_force.T),(1,2)) / 2.
                x_chol =  torch.cholesky(torch.einsum('bik,bjk->bij',inputs['_cell'],inputs['_cell']))
                V = x_chol[:,0,0]*x_chol[:,1,1]*x_chol[:,2,2]
                result['stress'] = -torch.einsum('abcd,abch->adh', inputs['dist_vec'], pair_force)/2./V*1.60217662e3
            elif self.return_hessian:
                forces = -grad(result["y"], inputs[Structure.R],
                               grad_outputs=torch.ones_like(result["y"]),
                               create_graph=self.create_graph, retain_graph=self.training)[0]
                result['hessian'] = -grad(forces, inputs[Structure.R], create_graph=False)[0]
                
            elif self.uncertainty:
                forces = -grad(result["y"], inputs[Structure.R],
                               grad_outputs=torch.ones_like(result["y"]),
                               create_graph=True,retain_graph=True)[0]
                               
                forces_std   = -grad(result["sigma"], inputs[Structure.R],
                                   grad_outputs=torch.ones_like(result["y"]),
                                   create_graph=self.training,retain_graph=self.training)[0]
                
                result['sigma_forces'] = forces_std.abs()#nn.functional.relu(forces_std) 
            else:
                forces = -grad(result["y"], inputs[Structure.R],
                               grad_outputs=torch.ones_like(result["y"]),
                               create_graph=self.training,retain_graph=self.training)[0]
            result['dydx'] = forces

        return result
Example #4
File: models.py Project: mukami12/REM
    def forward(self, x, S):
        x = x.view(-1, self.x_dim)
        bsz = x.size(0)

        ### get w and \alpha and L(\theta)
        mu, logvar = self.encoder(x)
        q_phi = Normal(loc=mu, scale=torch.exp(0.5 * logvar))
        z_q = q_phi.rsample((S, ))
        recon_batch = self.decoder(z_q)
        x_dist = Bernoulli(logits=recon_batch)
        log_lik = x_dist.log_prob(x).sum(-1)
        log_prior = self.prior.log_prob(z_q).sum(-1)
        log_q = q_phi.log_prob(z_q).sum(-1)
        log_w = log_lik + log_prior - log_q
        tmp_alpha = torch.logsumexp(log_w, dim=0).unsqueeze(0)
        alpha = torch.exp(log_w - tmp_alpha).detach()
        if self.version == 'v1':
            p_loss = -alpha * (log_lik + log_prior)

        ### get moment-matched proposal
        mu_r = alpha.unsqueeze(2) * z_q
        mu_r = mu_r.sum(0).detach()
        z_minus_mu_r = z_q - mu_r.unsqueeze(0)
        reshaped_diff = z_minus_mu_r.view(S * bsz, -1, 1)
        reshaped_diff_t = reshaped_diff.permute(0, 2, 1)
        outer = torch.bmm(reshaped_diff, reshaped_diff_t)
        outer = outer.view(S, bsz, self.z_dim, self.z_dim)
        Sigma_r = outer.mean(0) * S / (S - 1)
        Sigma_r = Sigma_r + torch.eye(self.z_dim).to(device) * 1e-6  ## ridging

        ### get v, \beta, and L(\phi)
        L = torch.cholesky(Sigma_r)
        r_phi = MultivariateNormal(loc=mu_r, scale_tril=L)

        z = r_phi.rsample((S, ))
        z_r = z.detach()
        recon_batch_r = self.decoder(z_r)
        x_dist_r = Bernoulli(logits=recon_batch_r)
        log_lik_r = x_dist_r.log_prob(x).sum(-1)
        log_prior_r = self.prior.log_prob(z_r).sum(-1)
        log_r = r_phi.log_prob(z_r)
        log_v = log_lik_r + log_prior_r - log_r
        tmp_beta = torch.logsumexp(log_v, dim=0).unsqueeze(0)
        beta = torch.exp(log_v - tmp_beta).detach()
        log_q = q_phi.log_prob(z_r).sum(-1)
        q_loss = -beta * log_q

        if self.version == 'v2':
            p_loss = -beta * (log_lik_r + log_prior_r)

        rem_loss = torch.sum(q_loss + p_loss, 0).sum()
        return rem_loss
Example #5
    def __init__(self,
                 sigma: Union[float, torch.Tensor],
                 opt: Optional[FalkonOptions] = None):
        super().__init__(self.kernel_name, opt)

        self.sigma, self.gaussian_type = self._get_sigma_kt(sigma)

        if self.gaussian_type == 'single':
            self.gamma = torch.tensor(-0.5 / (self.sigma.item()**2),
                                      dtype=torch.float64).item()
        else:
            self.gamma = torch.cholesky(self.sigma, upper=False)
            self.kernel_type = "l2-multi-distance"
Example #6
 def log_norm(self):
     'Log-normalization constant given the current parameterization.'
     idxs = torch.arange(0,
                         self.dim,
                         dtype=self.mean.dtype,
                         device=self.mean.device)
     L = torch.cholesky(self.scale_matrix, upper=False)
     logdet = 2 * torch.log(L.diag()).sum()
     return .5 * self.dof * logdet + .5 * self.dof * self.dim * math.log(2) \
            + .25 * self.dim * (self.dim - 1) * math.log(math.pi) \
            + torch.lgamma(.5 * (self.dof + 1 - idxs)).sum() \
            + .5 * self.dim * torch.log(self.scale) \
            + .5 * self.dim * math.log(2 * math.pi)
Example #7
def cholesky_inverse(fish, momentum):
    lower = torch.cholesky(fish)
    y = torch.triangular_solve(momentum.view(-1, 1),
                               lower,
                               upper=False,
                               transpose=False,
                               unitriangular=False)[0]
    fish_inv_p = torch.triangular_solve(y,
                                        lower.t(),
                                        upper=True,
                                        transpose=False,
                                        unitriangular=False)[0]
    return fish_inv_p
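A quick numeric check of the two triangular solves above, as a standalone sketch on a small synthetic SPD matrix: the result should match fish^{-1} @ momentum computed directly.

import torch

A = torch.randn(5, 5, dtype=torch.double)
fish = A @ A.t() + 5 * torch.eye(5, dtype=torch.double)   # SPD by construction
momentum = torch.randn(5, dtype=torch.double)
expected = torch.inverse(fish) @ momentum                  # direct solve with an explicit inverse
result = cholesky_inverse(fish, momentum).squeeze(-1)      # Cholesky + two triangular solves
assert torch.allclose(result, expected)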
Example #8
 def solve_batched(self, lmbda):
     """
     Solves rom on a batch of inputs lmbda
     :param lmbda: tensor of (positive) permeabilities batch_size x n_cells
     :return: sets self.solution_torch to batch_size x n_dof tensor of dof's of rom,
     and self.LU to LU decomposition
     """
     self.stiffnessMatrix.assemble_torch(lmbda)
     self.rhs.assemble_torch(lmbda)
     self.stiffnessMatrix.cholesky_L = torch.cholesky(
         self.stiffnessMatrix.matrix_torch)
     self.solution_torch = torch.cholesky_solve(
         self.rhs.vector_torch, self.stiffnessMatrix.cholesky_L)
Example #9
    def forward(ctx, H, b):
        # don't crash training if cholesky decomp fails
        try:
            U = torch.cholesky(H)
            xs = torch.cholesky_solve(b, U)
            ctx.save_for_backward(U, xs)
            ctx.failed = False
        except Exception as e:
            print(e)
            ctx.failed = True
            xs = torch.zeros_like(b)

        return xs
Example #10
    def _sample_cholesky(self, kernel):
        """Sample from the GP prior via the Cholesky decomposition."""
        num_total_points = kernel.shape[-1]
        # Calculate Cholesky, using double precision for better stability:
        cholesky = torch.cholesky(kernel.double()).float()

        # Sample a curve
        # [batch_size, y_size, num_total_points, 1]
        y_values = torch.matmul(
            cholesky,
            torch.randn([self._batch_size, self._y_size, num_total_points, 1]))
        # [batch_size, num_total_points, y_size]
        return y_values.squeeze(dim=3).permute((0, 2, 1))
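The same idea as a standalone sketch, using a hypothetical RBF kernel rather than the class's own kernel argument: with K = L @ L.T, the product L @ eps with eps ~ N(0, I) is a draw from N(0, K).

import torch

x = torch.linspace(-2.0, 2.0, 50, dtype=torch.double)
K = torch.exp(-0.5 * (x[:, None] - x[None, :]) ** 2)  # RBF kernel matrix
K = K + 1e-6 * torch.eye(50, dtype=torch.double)      # jitter for numerical stability
L = torch.cholesky(K)
samples = L @ torch.randn(50, 3, dtype=torch.double)  # three draws from the GP prior
print(samples.shape)                                  # torch.Size([50, 3])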
Example #11
def cholesky(device_config, K, Y):
    """
    Solve linear system using a cholesky solver.
    Params:
        K - Covariance Matrix
        Y - Target Labels
    """

    with torch.no_grad():
        L = torch.cholesky(K, upper=False)
        solution = torch.cholesky_solve(Y, L, upper=False)

    return solution
Example #12
            def f(x, u):
                z = torch.cat([x, u], dim=-1)
                phi = self.model.encoder(z)
                if self.model.add_nom_features:
                    phi_nom = self.model.phi_nom_fn(x, u)
                    phi = self.model.augment_phi(phi, phi_nom)
                mu = (K @ phi).squeeze(-1).squeeze(-1) + self.f_nom(x, u)

                if not with_opt:
                    return mu, sig * (1 + 0. * mu.unsqueeze(-1))
                else:
                    Lp = (torch.cholesky(SigEps * Linv) @ phi).squeeze(-1)
                    return mu, sig * (1 + 0. * mu.unsqueeze(-1)), Lp
Example #13
    def _predict(self, input_new: TensorType, diag=True):
        """
        Compute posterior p(f*|y), integrating out induced outputs' posterior.

        :return: (mean, var/cov)
        """

        z = self.Z
        z.requires_grad_(False)

        num_inducing = z.size(0)
        dim_output = self.Y.size(1)

        # err = self.Y - self.mean_function(self.X)
        err = self.Y
        Kuf = self.kernel.K(z, self.X)
        # add jitter
        Kuu = self.kernel.K(z) + self.jitter * torch.eye(num_inducing, 
            dtype=torch_dtype)
        Kus = self.kernel.K(z, input_new)
        L = torch.cholesky(Kuu)
        A = trtrs(Kuf, L)
        AAT = A @ A.t() / self.likelihood.variance.transform().expand_as(Kuu)
        B = AAT + torch.eye(num_inducing, dtype=torch_dtype)
        LB = torch.cholesky(B)
        # divide variance at the end
        c = trtrs(A @ err, LB) / self.likelihood.variance.transform()
        tmp1 = trtrs(Kus, L)
        tmp2 = trtrs(tmp1, LB)
        mean = tmp2.t() @ c

        if diag:
            var = self.kernel.Kdiag(input_new) - tmp1.pow(2).sum(0).squeeze() \
                  + tmp2.pow(2).sum(0).squeeze()
            # add kronecker product later for multi-output case
        else:
            var = self.kernel.K(input_new) + tmp2.t() @ tmp2 - tmp1.t() @ tmp1
        # return mean + self.mean_function(input_new), var
        return mean, var
Example #14
 def scale_tril(self):
     # The following identity is used to increase the numerical stability of the
     # Cholesky decomposition (see http://www.gaussianprocess.org/gpml/, Section 3.4.3):
     #     W @ W.T + D = D^{1/2} @ (I + D^{-1/2} @ W @ W.T @ D^{-1/2}) @ D^{1/2}
     # The matrix `I + D^{-1/2} @ W @ W.T @ D^{-1/2}` has eigenvalues bounded from below by 1,
     # hence it is well-conditioned and safe to take the Cholesky decomposition of.
     n = self._event_shape[0]
     cov_diag_sqrt_unsqueeze = self._unbroadcasted_cov_diag.sqrt().unsqueeze(-1)
     Dinvsqrt_W = self._unbroadcasted_cov_factor / cov_diag_sqrt_unsqueeze
     K = torch.matmul(Dinvsqrt_W, Dinvsqrt_W.transpose(-1, -2)).contiguous()
     K.view(-1, n * n)[:, ::n + 1] += 1  # add identity matrix to K
     scale_tril = cov_diag_sqrt_unsqueeze * torch.cholesky(K)
     return scale_tril.expand(self._batch_shape + self._event_shape + self._event_shape)
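A small standalone check of the identity in the comment, with synthetic W and D rather than the class's own attributes: the constructed scale_tril should reproduce W @ W.T + D.

import torch

n, k = 6, 2
W = torch.randn(n, k, dtype=torch.double)
D = torch.rand(n, dtype=torch.double) + 0.5  # positive diagonal part
Dsqrt = D.sqrt().unsqueeze(-1)
K = (W / Dsqrt) @ (W / Dsqrt).t()            # D^{-1/2} @ W @ W.T @ D^{-1/2}
K.view(-1)[:: n + 1] += 1                    # add the identity to K in place
scale_tril = Dsqrt * torch.cholesky(K)       # D^{1/2} @ chol(K)
assert torch.allclose(scale_tril @ scale_tril.t(), W @ W.t() + torch.diag(D))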
Example #15
    def get_sps(self):
        """
        Constructs the Sigma points used for propagation.
        :return: Sigma points
        :rtype: torch.Tensor
        """
        cholcov = sqrt(self._lam + self._ndim) * torch.cholesky(self._cov)

        self._sps[..., 0, :] = self._mean
        self._sps[..., 1:self._ndim+1, :] = self._mean[..., None, :] + cholcov
        self._sps[..., self._ndim+1:, :] = self._mean[..., None, :] - cholcov

        return self._sps
Example #16
 def setUp(self):
     self.module = ETKFWeightsModule()
     self.state, self.obs = _create_matrices()
     innov = (self.obs['observations']-self.state.mean('ensemble'))
     innov = innov.values.reshape(-1)
     hx_perts = self.state.values.reshape(2, 1)
     obs_cov = self.obs['covariance'].values
     prepared_states = [innov, hx_perts, obs_cov]
     torch_states = [torch.from_numpy(s).float() for s in prepared_states]
     innov, hx_perts, obs_cov = torch_states
     obs_cinv = torch.cholesky(obs_cov).inverse()
     self.normed_perts = hx_perts @ obs_cinv
     self.normed_obs = (innov @ obs_cinv).view(1, 1)
Example #17
	def fit(self, X, labels):
		self.kern = torch.sum(torch.stack([
			self.activation_fn(self.W_comp[i] * self.kernels[i](X)) + self.W_comp[i] * self.kernels[i](X)
			for i in range(self.nb_kernels)
			]), dim=0)
		K = self.kern + torch.eye(self.kern.size()[0]).to(device) * self.lambda_reg
		L = torch.cholesky(K, upper=False)
		one_hot_y = F.one_hot(labels, num_classes = 10).type(torch.FloatTensor).to(device)

		#A, _ = torch.solve(kern, L)
		#V, _ = torch.solve(one_hot_y, L)
		#alpha = A.T @ V
		self.alpha = torch.cholesky_solve(one_hot_y, L, upper=False)
Example #18
 def _zvecmvn(self, mu, cov, X, theta, l, weights):
     d = l.numel()
     t = weights.numel()
     C = cov + torch.diag(l**2).repeat([t, 1, 1])  #(t,d,d)
     L = torch.cholesky(C, upper=False)  #(t,d,d)
     Xm = X.repeat([t, 1, 1]) - mu.reshape([t, 1, d])  #(t,n,d)
     LX = utils.batch_trtrs(Xm.transpose(-2, -1), L, upper=False)  #(t,d,n)
     expoent = -0.5 * torch.sum(LX**2, dim=1)  #(t,n)
     det = torch.prod(1/l**2)*\
           torch.prod(utils.batch_diag1(L),dim=1,keepdim=True)**2 #|I + A^-1B| (t,1)
     vec_ = theta / torch.sqrt(det) * torch.exp(expoent)  #(t,n)
     zvec = (weights.reshape(-1, 1) * vec_).sum(dim=0)  #(n,)
     return zvec
Example #19
def generalized_eigenvalue_decomposition(a, b):
    """Solves the generalized eigenvalue decomposition through Cholesky decomposition.
    Returns eigen values and eigen vectors (ascending order).
    """
    cholesky = torch.cholesky(b)
    inv_cholesky = torch.inverse(cholesky)
    # Compute C matrix: L^{-1} @ A @ L^{-T}
    cmat = inv_cholesky @ a @ inv_cholesky.conj().transpose(-1, -2)
    # Performing the eigenvalue decomposition
    e_val, e_vec = torch.symeig(cmat, eigenvectors=True)
    # Collecting the eigenvectors
    e_vec = torch.matmul(inv_cholesky.conj().transpose(-1, -2), e_vec)
    return e_val, e_vec
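A brief sanity check of the routine above, as a sketch on small synthetic symmetric matrices with b positive definite: each returned pair should satisfy a @ v = lambda * (b @ v).

import torch

m1 = torch.randn(4, 4, dtype=torch.double)
m2 = torch.randn(4, 4, dtype=torch.double)
a = m1 @ m1.t()                                          # symmetric positive semi-definite
b = m2 @ m2.t() + 4 * torch.eye(4, dtype=torch.double)   # symmetric positive definite
e_val, e_vec = generalized_eigenvalue_decomposition(a, b)
# Eigenvectors are the columns of e_vec, ordered by ascending e_val.
assert torch.allclose(a @ e_vec, (b @ e_vec) * e_val)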
Example #20
def cholesky_iteration_loss(l_matrix, preconditioner, n_iter=8):
    """Cholesky iterations for positive (semi-)definite matrices.

    Krishnamoorthy, Aravindh, and Kenan Kocagoez. "Singular values using
    cholesky decomposition." arXiv preprint arXiv:1202.1490 (2012).
    """
    # Algorithm 2.
    J_k = torch.sparse.mm(l_matrix, preconditioner)
    for _ in range(n_iter):
        R_k = torch.cholesky(J_k, upper=True)
        J_k = torch.mm(R_k, R_k.t())
    sigma = sorted(J_k.diag())
    return sigma[-1] / sigma[0]
Example #21
 def _cholesky(self, K):
     try:
         return torch.cholesky(K)
     except RuntimeError as e:
         print("ERROR:", e.args[0], file=sys.__stdout__)
         print("K =", K, file=sys.__stdout__)
         if K.isnan().any():
             print("Kernel matrix has NaNs!", file=sys.__stdout__)
         if K.isinf().any():
             print("Kernel matrix has infinities!", file=sys.__stdout__)
         print("Parameters:", file=sys.__stdout__)
         self.print_parameters(file=sys.__stdout__)
         raise CholeskyException(e.args[0], K, self)
Example #22
def test_dist_to_funsor_mvn(batch_shape, event_size):
    loc = torch.randn(batch_shape + (event_size, ))
    cov = torch.randn(batch_shape + (event_size, 2 * event_size))
    cov = cov.matmul(cov.transpose(-1, -2))
    scale_tril = torch.cholesky(cov)
    d = dist.MultivariateNormal(loc, scale_tril=scale_tril)
    f = dist_to_funsor(d)
    assert isinstance(f, Funsor)

    value = d.sample()
    actual_log_prob = f(value=tensor_to_funsor(value, event_output=1))
    expected_log_prob = tensor_to_funsor(d.log_prob(value))
    assert_close(actual_log_prob, expected_log_prob)
Example #23
 def nll(self, X, y):
     m, S = self.forward(X, y)
     y = self.y
     const = -0.5 * self.N * torch.log(2 * torch.tensor(math.pi))
     # data_fit = -0.5 * y.t() @ self.Kxx_noise_inv @ y
     data_fit = -0.5 * (y - m).t() @ S.inverse() @ (y - m)
     # complexity = -torch.trace(torch.cholesky(S))
     complexity = -torch.log(torch.diag(torch.cholesky(self.Kxx_noise + torch.eye(self.N) * 1e-5)))
     complexity = complexity.sum()
     print(
         f"nll terms datafit : {data_fit.detach().numpy()},\n complexity : {complexity}"
     )
     return data_fit + complexity + const
Example #24
    def __init__(self,
                 loc,
                 covariance_matrix=None,
                 precision_matrix=None,
                 scale_tril=None,
                 validate_args=None):
        if loc.dim() < 1:
            raise ValueError("loc must be at least one-dimensional.")
        if (covariance_matrix is not None) + (scale_tril is not None) + (
                precision_matrix is not None) != 1:
            raise ValueError(
                "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."
            )

        loc_ = loc.unsqueeze(-1)  # temporarily add dim on right
        if scale_tril is not None:
            if scale_tril.dim() < 2:
                raise ValueError(
                    "scale_tril matrix must be at least two-dimensional, "
                    "with optional leading batch dimensions")
            self.scale_tril, loc_ = torch.broadcast_tensors(scale_tril, loc_)
        elif covariance_matrix is not None:
            if covariance_matrix.dim() < 2:
                raise ValueError(
                    "covariance_matrix must be at least two-dimensional, "
                    "with optional leading batch dimensions")
            self.covariance_matrix, loc_ = torch.broadcast_tensors(
                covariance_matrix, loc_)
        else:
            if precision_matrix.dim() < 2:
                raise ValueError(
                    "precision_matrix must be at least two-dimensional, "
                    "with optional leading batch dimensions")
            self.precision_matrix, loc_ = torch.broadcast_tensors(
                precision_matrix, loc_)
        self.loc = loc_[..., 0]  # drop rightmost dim
        self.normalizing_constant = torch.nn.Parameter(torch.tensor([1.]))

        batch_shape, event_shape = self.loc.shape[:-1], self.loc.shape[-1:]
        super(UnnormMVGaussian, self).__init__(batch_shape,
                                               event_shape,
                                               validate_args=validate_args)

        if scale_tril is not None:
            self._unbroadcasted_scale_tril = scale_tril
        else:
            if precision_matrix is not None:
                self.covariance_matrix = torch.inverse(
                    precision_matrix).expand_as(loc_)
            self._unbroadcasted_scale_tril = torch.cholesky(
                self.covariance_matrix)
Example #25
def psd_safe_cholesky(A, upper=False, out=None, jitter=None):
    """Compute the Cholesky decomposition of A. If A is only p.s.d, add a small jitter to the diagonal.
    Args:
        :attr:`A` (Tensor):
            The tensor to compute the Cholesky decomposition of
        :attr:`upper` (bool, optional):
            See torch.cholesky
        :attr:`out` (Tensor, optional):
            See torch.cholesky
        :attr:`jitter` (float, optional):
            The jitter to add to the diagonal of A in case A is only p.s.d. If omitted, chosen
            as 1e-6 (float) or 1e-8 (double)
    """
    try:
        L = torch.cholesky(A, upper=upper, out=out)
        return L
    except RuntimeError as e:
        isnan = torch.isnan(A)
        if isnan.any():
            raise NanError(
                f"cholesky_cpu: {isnan.sum().item()} of {A.numel()} elements of the {A.shape} tensor are NaN."
            )

        if jitter is None:
            jitter = 1e-6 if A.dtype == torch.float32 else 1e-8
        Aprime = A.clone()
        jitter_prev = 0
        for i in range(3):
            jitter_new = jitter * (10 ** i)
            Aprime.diagonal(dim1=-2, dim2=-1).add_(jitter_new - jitter_prev)
            jitter_prev = jitter_new
            try:
                L = torch.cholesky(Aprime, upper=upper, out=out)
                warnings.warn(f"A not p.d., added jitter of {jitter_new} to the diagonal", NumericalWarning)
                return L
            except RuntimeError:
                continue
        raise e
Example #26
 def test_psd_safe_cholesky_psd(self, cuda=False):
     device = torch.device("cuda") if cuda else torch.device("cpu")
     for dtype in (torch.float, torch.double):
         for batch_mode in (False, True):
             if batch_mode:
                 A = self._gen_test_psd().to(device=device, dtype=dtype)
             else:
                 A = self._gen_test_psd()[0].to(device=device, dtype=dtype)
             idx = torch.arange(A.shape[-1], device=A.device)
             # default values
             Aprime = A.clone()
             Aprime[..., idx,
                    idx] += 1e-6 if A.dtype == torch.float32 else 1e-8
             L_exp = torch.cholesky(Aprime)
             with warnings.catch_warnings(record=True) as w:
                 L_safe = psd_safe_cholesky(A)
                 self.assertTrue(
                     any(
                         issubclass(w_.category, RuntimeWarning)
                         for w_ in w))
                 self.assertTrue(
                     any("A not p.d., added jitter" in str(w_.message)
                         for w_ in w))
             self.assertTrue(torch.allclose(L_exp, L_safe))
             # user-defined value
             Aprime = A.clone()
             Aprime[..., idx, idx] += 1e-2
             L_exp = torch.cholesky(Aprime)
             with warnings.catch_warnings(record=True) as w:
                 L_safe = psd_safe_cholesky(A, jitter=1e-2)
                 self.assertTrue(
                     any(
                         issubclass(w_.category, RuntimeWarning)
                         for w_ in w))
                 self.assertTrue(
                     any("A not p.d., added jitter" in str(w_.message)
                         for w_ in w))
             self.assertTrue(torch.allclose(L_exp, L_safe))
Example #27
    def expected_sufficient_stats(self):
        '''Expected sufficient statistics given the current 
        parameterization. 

        For the random variable mu (vector), S (positive definite 
        matrix) the sufficient statistics of the Normal-Wishart are 
        given by: 

                                    +-- Dimension
                                    |
        stats = (                   v  
            S * mu,             <-- D
            S,                  <-- D^2
            tr(S * mu * mu^T),  <-- 1
            ln |S|              <-- 1
        )

        For the standard parameters (m=mean, k=scale, W=W, v=dof) 
        expectation of the sufficient statistics is given by:

                                                    +-- Dimension 
                                                    | 
        exp_stats = (                               v
            v * W * m,                          <-- D   
            v * W,                              <-- D^2
            (D/k) + tr(v * W * m * m^T),        <-- 1
            ( \sum_i psi(.5 * (v + 1 - i)) ) \  
             + D * ln 2 + ln |W|                <-- 1
        )

        Note: "tr" is the trace operator, "D" is the dimenion of "m" 
            and "psi" is the "digamma" function.

        '''
        idxs = torch.arange(0,
                            self.dim,
                            dtype=self.mean.dtype,
                            device=self.mean.device)
        L = torch.cholesky(self.scale_matrix, upper=False)
        logdet = torch.log(L.diag()).sum()
        mean_quad = torch.ger(self.mean, self.mean)
        exp_prec = self.dof * self.scale_matrix
        return torch.cat([
            exp_prec @ self.mean,
            exp_prec.reshape(-1),
            ((self.dim / self.scale) \
                + (exp_prec @ mean_quad).trace()).reshape(1),
            (torch.digamma(.5 * (self.dof + 1 - idxs)).sum() \
                + self.dim * math.log(2) + logdet).reshape(1)
        ])
Example #28
    def forward(self, X):
        A = self.alpha * torch.eye(self.n_features).type(torch.double) + self.beta * self.X_train.t() @ self.X_train
        L_A = torch.cholesky(self.jitter(A))
        H1T_star = torch.triangular_solve(X.t(), L_A, upper=False)[0] # Have to compute the transpose because of the way triangular_solve/trtrs is written
        H1T_train = torch.triangular_solve(self.X_train.t(),L_A,upper=False)[0]

        # predictive mean Xm depends on the term $X S_N X^T$, which we compute through a Cholesky decomposition
        Xm =  self.beta * H1T_star.t() @ H1T_train @ self.Y_train
        pred_mean = Xm # predictive mean

        # predictive covariance depends on the same terms
        XSX = H1T_star.t() @ H1T_star
        pred_covar = 1/self.beta*torch.eye(XSX.shape[0]).type(torch.double) + XSX
        return (pred_mean,pred_covar.diag())
Example #29
    def forward(self, X, y):
        """
        Returns posterior mean and variance
        """
        self.x, self.y, self.N = X, y, X.shape[-2]

        self.Kxx = self.kernel(X, X)  # + torch.eye(self.N) * 1e-5
        jitter = torch.eye(self.N) * self.noisestd**2
        L = torch.cholesky(self.Kxx + jitter)
        a, _ = torch.solve(y.unsqueeze(0), L.unsqueeze(0))
        alpha, _ = torch.solve(a, L.t().unsqueeze(0))
        m = alpha.squeeze(0)
        S = L
        return m, S
Example #30
 def reduce_func_singular_values(self,func):
     bs,c,h,w = self._shape
     padded_weight = F.pad(self.conv.weight,(0,h-3,0,w-3))
     w_fft = torch.rfft(padded_weight, 2, onesided=False, normalized=False)
     # Lift to real valued space
     D = phi(w_fft).permute(2,3,0,1)
     Dt = D.permute(0, 1, 3, 2) #transpose of D
     lhs = torch.matmul(D, Dt)
     scale = lhs.data.norm().detach()/np.sqrt(np.prod(Dt.shape))
     chol_output = torch.cholesky(lhs+1e-4*scale*torch.eye(lhs.size(-1)).to(lhs.device))
     eigs = torch.diagonal(chol_output,dim1=-2,dim2=-1)
     logdet = (func(eigs).sum() / 2.0).expand(bs)
     # 1/4 \sum_{h,w} log det (DDt)
     return logdet
Example #31
    def __init__(
        self,
        mean: Tensor,
        cov: Tensor,
        seed: Optional[int] = None,
        inv_transform: bool = False,
    ) -> None:
        r"""Engine for qMC sampling from a multivariate Normal `N(\mu, \Sigma)`.

        Args:
            mean: The mean vector.
            cov: The covariance matrix.
            seed: The seed with which to seed the random number generator of the
                underlying SobolEngine.
            inv_transform: If True, use inverse transform instead of Box-Muller.
        """
        # validate inputs
        if not cov.shape[0] == cov.shape[1]:
            raise ValueError("Covariance matrix is not square.")
        if not mean.shape[0] == cov.shape[0]:
            raise ValueError("Dimension mismatch between mean and covariance.")
        if not torch.allclose(cov, cov.transpose(-1, -2)):
            raise ValueError("Covariance matrix is not symmetric.")
        self._mean = mean
        self._normal_engine = NormalQMCEngine(
            d=mean.shape[0], seed=seed, inv_transform=inv_transform
        )
        # compute Cholesky decomp; if it fails, do the eigendecomposition
        try:
            self._corr_matrix = torch.cholesky(cov).transpose(-1, -2)
        except RuntimeError:
            eigval, eigvec = torch.symeig(cov, eigenvectors=True)
            if not torch.all(eigval >= -1e-8):
                raise ValueError("Covariance matrix not PSD.")
            eigval_root = eigval.clamp_min(0.0).sqrt()
            self._corr_matrix = (eigvec * eigval_root).transpose(-1, -2)