def test_multivariate_normal_batch_lazy(self, cuda=False):
    device = torch.device("cuda") if cuda else torch.device("cpu")
    for dtype in (torch.float, torch.double):
        mean = torch.tensor([0, 1, 2], device=device, dtype=dtype).repeat(2, 1)
        covmat = torch.diag(torch.tensor([1, 0.75, 1.5], device=device, dtype=dtype)).repeat(2, 1, 1)
        covmat_chol = torch.cholesky(covmat)
        mvn = MultivariateNormal(mean=mean, covariance_matrix=NonLazyTensor(covmat))
        self.assertTrue(torch.is_tensor(mvn.covariance_matrix))
        self.assertIsInstance(mvn.lazy_covariance_matrix, LazyTensor)
        self.assertAllClose(mvn.variance, torch.diagonal(covmat, dim1=-2, dim2=-1))
        self.assertAllClose(mvn._unbroadcasted_scale_tril, covmat_chol)
        mvn_plus1 = mvn + 1
        self.assertAllClose(mvn_plus1.mean, mvn.mean + 1)
        self.assertAllClose(mvn_plus1.covariance_matrix, mvn.covariance_matrix)
        self.assertAllClose(mvn_plus1._unbroadcasted_scale_tril, covmat_chol)
        mvn_times2 = mvn * 2
        self.assertAllClose(mvn_times2.mean, mvn.mean * 2)
        self.assertAllClose(mvn_times2.covariance_matrix, mvn.covariance_matrix * 4)
        self.assertAllClose(mvn_times2._unbroadcasted_scale_tril, covmat_chol * 2)
        mvn_divby2 = mvn / 2
        self.assertAllClose(mvn_divby2.mean, mvn.mean / 2)
        self.assertAllClose(mvn_divby2.covariance_matrix, mvn.covariance_matrix / 4)
        self.assertAllClose(mvn_divby2._unbroadcasted_scale_tril, covmat_chol / 2)
        # TODO: Add tests for entropy, log_prob, etc. - this is an issue b/c it
        # uses root_decomposition which is not very reliable
        # self.assertTrue(torch.allclose(mvn.entropy(), 4.3157 * torch.ones(2)))
        # self.assertTrue(
        #     torch.allclose(mvn.log_prob(torch.zeros(2, 3)), -4.8157 * torch.ones(2))
        # )
        # self.assertTrue(
        #     torch.allclose(mvn.log_prob(torch.zeros(2, 2, 3)), -4.8157 * torch.ones(2, 2))
        # )
        conf_lower, conf_upper = mvn.confidence_region()
        self.assertAllClose(conf_lower, mvn.mean - 2 * mvn.stddev)
        self.assertAllClose(conf_upper, mvn.mean + 2 * mvn.stddev)
        self.assertTrue(mvn.sample().shape == torch.Size([2, 3]))
        self.assertTrue(mvn.sample(torch.Size([2])).shape == torch.Size([2, 2, 3]))
        self.assertTrue(mvn.sample(torch.Size([2, 4])).shape == torch.Size([2, 4, 2, 3]))
def gaussian_tensordot(x, y, dims=0):
    """
    Computes the integral over two gaussians:

        `(x @ y)(a,c) = log(integral(exp(x(a,b) + y(b,c)), b))`,

    where `x` is a gaussian over variables (a,b), `y` is a gaussian over variables (b,c),
    (a,b,c) can each be sets of zero or more variables, and `dims` is the size of b.

    :param x: a Gaussian instance
    :param y: a Gaussian instance
    :param dims: number of variables to contract
    """
    assert isinstance(x, Gaussian)
    assert isinstance(y, Gaussian)
    na = x.dim() - dims
    nb = dims
    nc = y.dim() - dims
    assert na >= 0
    assert nb >= 0
    assert nc >= 0

    Paa, Pba, Pbb = (x.precision[..., :na, :na],
                     x.precision[..., na:, :na],
                     x.precision[..., na:, na:])
    Qbb, Qbc, Qcc = (y.precision[..., :nb, :nb],
                     y.precision[..., :nb, nb:],
                     y.precision[..., nb:, nb:])
    xa, xb = x.info_vec[..., :na], x.info_vec[..., na:]  # x.precision @ x.mean
    yb, yc = y.info_vec[..., :nb], y.info_vec[..., nb:]  # y.precision @ y.mean

    precision = pad(Paa, (0, nc, 0, nc)) + pad(Qcc, (na, 0, na, 0))
    info_vec = pad(xa, (0, nc)) + pad(yc, (na, 0))
    log_normalizer = x.log_normalizer + y.log_normalizer
    if nb > 0:
        B = pad(Pba, (0, nc)) + pad(Qbc, (na, 0))
        b = xb + yb
        # Pbb + Qbb needs to be positive definite, so that we can marginalize out `b` (to have a finite integral)
        L = torch.cholesky(Pbb + Qbb)
        LinvB = torch.triangular_solve(B, L, upper=False)[0]
        LinvBt = LinvB.transpose(-2, -1)
        Linvb = torch.triangular_solve(b.unsqueeze(-1), L, upper=False)[0]

        precision = precision - torch.matmul(LinvBt, LinvB)
        # NB: precision might not be invertible for getting mean = precision^-1 @ info_vec
        if na + nc > 0:
            info_vec = info_vec - torch.matmul(LinvBt, Linvb).squeeze(-1)
        logdet = torch.diagonal(L, dim1=-2, dim2=-1).log().sum(-1)
        diff = 0.5 * nb * math.log(2 * math.pi) + 0.5 * Linvb.squeeze(-1).pow(2).sum(-1) - logdet
        log_normalizer = log_normalizer + diff

    return Gaussian(log_normalizer, info_vec, precision)
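# Editor's note (worked form of the contraction above, not from the source): writing a
# Gaussian's exponent as v^T x - 0.5 * x^T P x with x = (u, b) and u = (a, c), and
# integrating out the shared block b with L = cholesky(P_bb), gives
#     log_normalizer += 0.5 * nb * log(2 * pi) + 0.5 * ||L^{-1} v_b||^2 - sum(log(diag(L)))
#     info_vec       -> v_u - P_ub @ P_bb^{-1} @ v_b
#     precision      -> P_uu - P_ub @ P_bb^{-1} @ P_bu      (a Schur complement)
# which is exactly the `if nb > 0` branch above, with P_bb = Pbb + Qbb, P_bu = B and v_b = b.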
def forward(self, inputs): r""" predicts energy """ result = super(MultiOutput, self).forward(inputs) # if self.uncertainty_own: # result["sigma"] = torch.nn.functional.softplus(result["y"][:,1]) # result["y"] = result["y"][:,0] if self.requires_dr: if self.return_stress: forces = -grad(result["y"], inputs[Structure.R], grad_outputs=torch.ones_like(result["y"]), create_graph=self.create_graph,retain_graph=self.training)[0] n_batch = inputs[Structure.R].size()[0] idx_m = torch.arange(n_batch, device=inputs[Structure.R].device, dtype=torch.long)[:, None, None] # Subtract positions of central atoms to get distance vectors #B,A,N,C = dist_vec.shape #dist_vec = dist_vec.view(B,A*N,C) pair_force = grad(result['y'], inputs['dist_vec'], grad_outputs=torch.ones_like(inputs['dist_vec']), create_graph=False)[0] #result['stress'] = torch.sum(dist_vec.mm(pair_force.T),(1,2)) / 2. x_chol = torch.cholesky(torch.einsum('bik,bjk->bij',inputs['_cell'],inputs['_cell'])) V = x_chol[:,0,0]*x_chol[:,1,1]*x_chol[:,2,2] result['stress'] = -torch.einsum('abcd,abch->adh', inputs['dist_vec'], pair_force)/2./V*1.60217662e3 elif self.return_hessian: forces = -grad(result["y"], inputs[Structure.R], grad_outputs=torch.ones_like(result["y"]), create_graph=self.create_graph, retain_graph=self.training)[0] result['hessian'] = -grad(forces, inputs[Structure.R], create_graph=False)[0] elif self.uncertainty: forces = -grad(result["y"], inputs[Structure.R], grad_outputs=torch.ones_like(result["y"]), create_graph=True,retain_graph=True)[0] forces_std = -grad(result["sigma"], inputs[Structure.R], grad_outputs=torch.ones_like(result["y"]), create_graph=self.training,retain_graph=self.training)[0] result['sigma_forces'] = forces_std.abs()#nn.functional.relu(forces_std) else: forces = -grad(result["y"], inputs[Structure.R], grad_outputs=torch.ones_like(result["y"]), create_graph=self.training,retain_graph=self.training)[0] result['dydx'] = forces return result
def forward(self, x, S): x = x.view(-1, self.x_dim) bsz = x.size(0) ### get w and \alpha and L(\theta) mu, logvar = self.encoder(x) q_phi = Normal(loc=mu, scale=torch.exp(0.5 * logvar)) z_q = q_phi.rsample((S, )) recon_batch = self.decoder(z_q) x_dist = Bernoulli(logits=recon_batch) log_lik = x_dist.log_prob(x).sum(-1) log_prior = self.prior.log_prob(z_q).sum(-1) log_q = q_phi.log_prob(z_q).sum(-1) log_w = log_lik + log_prior - log_q tmp_alpha = torch.logsumexp(log_w, dim=0).unsqueeze(0) alpha = torch.exp(log_w - tmp_alpha).detach() if self.version == 'v1': p_loss = -alpha * (log_lik + log_prior) ### get moment-matched proposal mu_r = alpha.unsqueeze(2) * z_q mu_r = mu_r.sum(0).detach() z_minus_mu_r = z_q - mu_r.unsqueeze(0) reshaped_diff = z_minus_mu_r.view(S * bsz, -1, 1) reshaped_diff_t = reshaped_diff.permute(0, 2, 1) outer = torch.bmm(reshaped_diff, reshaped_diff_t) outer = outer.view(S, bsz, self.z_dim, self.z_dim) Sigma_r = outer.mean(0) * S / (S - 1) Sigma_r = Sigma_r + torch.eye(self.z_dim).to(device) * 1e-6 ## ridging ### get v, \beta, and L(\phi) L = torch.cholesky(Sigma_r) r_phi = MultivariateNormal(loc=mu_r, scale_tril=L) z = r_phi.rsample((S, )) z_r = z.detach() recon_batch_r = self.decoder(z_r) x_dist_r = Bernoulli(logits=recon_batch_r) log_lik_r = x_dist_r.log_prob(x).sum(-1) log_prior_r = self.prior.log_prob(z_r).sum(-1) log_r = r_phi.log_prob(z_r) log_v = log_lik_r + log_prior_r - log_r tmp_beta = torch.logsumexp(log_v, dim=0).unsqueeze(0) beta = torch.exp(log_v - tmp_beta).detach() log_q = q_phi.log_prob(z_r).sum(-1) q_loss = -beta * log_q if self.version == 'v2': p_loss = -beta * (log_lik_r + log_prior_r) rem_loss = torch.sum(q_loss + p_loss, 0).sum() return rem_loss
def __init__(self, sigma: Union[float, torch.Tensor], opt: Optional[FalkonOptions] = None): super().__init__(self.kernel_name, opt) self.sigma, self.gaussian_type = self._get_sigma_kt(sigma) if self.gaussian_type == 'single': self.gamma = torch.tensor(-0.5 / (self.sigma.item()**2), dtype=torch.float64).item() else: self.gamma = torch.cholesky(self.sigma, upper=False) self.kernel_type = "l2-multi-distance"
def log_norm(self): 'Log-normalization constant given the current parameterization.' idxs = torch.arange(0, self.dim, dtype=self.mean.dtype, device=self.mean.device) L = torch.cholesky(self.scale_matrix, upper=False) logdet = 2 * torch.log(L.diag()).sum() return .5 * self.dof * logdet + .5 * self.dof * self.dim * math.log(2) \ + .25 * self.dim * (self.dim - 1) * math.log(math.pi) \ + torch.lgamma(.5 * (self.dof + 1 - idxs)).sum() \ + .5 * self.dim * torch.log(self.scale) \ + .5 * self.dim * math.log(2 * math.pi)
def cholesky_inverse(fish, momentum):
    """Computes fish^{-1} @ momentum via two triangular solves against the Cholesky
    factor of `fish`, avoiding an explicit matrix inverse."""
    lower = torch.cholesky(fish)
    # Solve L @ y = momentum, then L^T @ x = y, so that x = (L @ L^T)^{-1} @ momentum
    y = torch.triangular_solve(momentum.view(-1, 1), lower, upper=False,
                               transpose=False, unitriangular=False)[0]
    fish_inv_p = torch.triangular_solve(y, lower.t(), upper=True,
                                        transpose=False, unitriangular=False)[0]
    return fish_inv_p
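# Hypothetical usage sketch for cholesky_inverse above (editor's example; assumes a
# PyTorch version where torch.cholesky / torch.triangular_solve are available). Note the
# name clashes with torch.cholesky_inverse, which computes a full inverse from a factor.
A = torch.randn(5, 5)
fish = A @ A.t() + 5 * torch.eye(5)   # a positive definite "Fisher" matrix
momentum = torch.randn(5)
x = cholesky_inverse(fish, momentum)  # fish^{-1} @ momentum, shape (5, 1)
ref = torch.cholesky_solve(momentum.view(-1, 1), torch.cholesky(fish))
assert torch.allclose(x, ref, atol=1e-4)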
def solve_batched(self, lmbda): """ Solves rom on a batch of inputs lmbda :param lmbda: tensor of (positive) permeabilities batch_size x n_cells :return: sets self.solution_torch to batch_size x n_dof tensor of dof's of rom, and self.LU to LU decomposition """ self.stiffnessMatrix.assemble_torch(lmbda) self.rhs.assemble_torch(lmbda) self.stiffnessMatrix.cholesky_L = torch.cholesky( self.stiffnessMatrix.matrix_torch) self.solution_torch = torch.cholesky_solve( self.rhs.vector_torch, self.stiffnessMatrix.cholesky_L)
def forward(ctx, H, b): # don't crash training if cholesky decomp fails try: U = torch.cholesky(H) xs = torch.cholesky_solve(b, U) ctx.save_for_backward(U, xs) ctx.failed = False except Exception as e: print(e) ctx.failed = True xs = torch.zeros_like(b) return xs
def _sample_cholesky(self, kernel): """Sample from the GP prior via the Cholesky decomposition.""" num_total_points = kernel.shape[-1] # Calculate Cholesky, using double precision for better stability: cholesky = torch.cholesky(kernel.double()).float() # Sample a curve # [batch_size, y_size, num_total_points, 1] y_values = torch.matmul( cholesky, torch.randn([self._batch_size, self._y_size, num_total_points, 1])) # [batch_size, num_total_points, y_size] return y_values.squeeze(dim=3).permute((0, 2, 1))
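# Editor's standalone sketch of the same idea as _sample_cholesky above: draw one GP
# prior sample from an RBF kernel via its (jittered, double-precision) Cholesky factor.
x = torch.linspace(0, 1, 50)
K = torch.exp(-0.5 * (x[:, None] - x[None, :]) ** 2 / 0.1 ** 2) + 1e-6 * torch.eye(50)
prior_sample = torch.cholesky(K.double()).float() @ torch.randn(50, 1)  # shape (50, 1)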
def cholesky(device_config, K, Y): """ Solve linear system using a cholesky solver. Params: K - Covariance Matrix Y - Target Labels """ with torch.no_grad(): L = torch.cholesky(K, upper=False) solution = torch.cholesky_solve(Y, L, upper=False) return solution
def f(x, u): z = torch.cat([x, u], dim=-1) phi = self.model.encoder(z) if self.model.add_nom_features: phi_nom = self.model.phi_nom_fn(x, u) phi = self.model.augment_phi(phi, phi_nom) mu = (K @ phi).squeeze(-1).squeeze(-1) + self.f_nom(x, u) if not with_opt: return mu, sig * (1 + 0. * mu.unsqueeze(-1)) else: Lp = (torch.cholesky(SigEps * Linv) @ phi).squeeze(-1) return mu, sig * (1 + 0. * mu.unsqueeze(-1)), Lp
def _predict(self, input_new: TensorType, diag=True): """ Compute posterior p(f*|y), integrating out induced outputs' posterior. :return: (mean, var/cov) """ z = self.Z z.requires_grad_(False) num_inducing = z.size(0) dim_output = self.Y.size(1) # err = self.Y - self.mean_function(self.X) err = self.Y Kuf = self.kernel.K(z, self.X) # add jitter Kuu = self.kernel.K(z) + self.jitter * torch.eye(num_inducing, dtype=torch_dtype) Kus = self.kernel.K(z, input_new) L = torch.cholesky(Kuu) A = trtrs(Kuf, L) AAT = A @ A.t() / self.likelihood.variance.transform().expand_as(Kuu) B = AAT + torch.eye(num_inducing, dtype=torch_dtype) LB = torch.cholesky(B) # divide variance at the end c = trtrs(A @ err, LB) / self.likelihood.variance.transform() tmp1 = trtrs(Kus, L) tmp2 = trtrs(tmp1, LB) mean = tmp2.t() @ c if diag: var = self.kernel.Kdiag(input_new) - tmp1.pow(2).sum(0).squeeze() \ + tmp2.pow(2).sum(0).squeeze() # add kronecker product later for multi-output case else: var = self.kernel.K(input_new) + tmp2.t() @ tmp2 - tmp1.t() @ tmp1 # return mean + self.mean_function(input_new), var return mean, var
def scale_tril(self):
    # The following identity is used to improve the numerical stability of the
    # Cholesky decomposition (see http://www.gaussianprocess.org/gpml/, Section 3.4.3):
    #     W @ W.T + D = D^{1/2} @ (I + D^{-1/2} @ W @ W.T @ D^{-1/2}) @ D^{1/2}
    # The matrix "I + D^{-1/2} @ W @ W.T @ D^{-1/2}" has eigenvalues bounded from below by 1,
    # hence it is well-conditioned and safe to take Cholesky decomposition.
    n = self._event_shape[0]
    cov_diag_sqrt_unsqueeze = self._unbroadcasted_cov_diag.sqrt().unsqueeze(-1)
    Dinvsqrt_W = self._unbroadcasted_cov_factor / cov_diag_sqrt_unsqueeze
    K = torch.matmul(Dinvsqrt_W, Dinvsqrt_W.transpose(-1, -2)).contiguous()
    K.view(-1, n * n)[:, ::n + 1] += 1  # add identity matrix to K
    scale_tril = cov_diag_sqrt_unsqueeze * torch.cholesky(K)
    return scale_tril.expand(self._batch_shape + self._event_shape + self._event_shape)
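# Editor's sanity check of the identity used above (hypothetical standalone sketch,
# independent of the class): rescaling the Cholesky factor of
# K = I + D^{-1/2} W W^T D^{-1/2} by D^{1/2} yields a factor of W W^T + D.
W = torch.randn(4, 2)
D = torch.rand(4) + 0.5
Dinvsqrt_W = W / D.sqrt().unsqueeze(-1)
K = torch.eye(4) + Dinvsqrt_W @ Dinvsqrt_W.t()
scale_tril = D.sqrt().unsqueeze(-1) * torch.cholesky(K)
assert torch.allclose(scale_tril @ scale_tril.t(), W @ W.t() + torch.diag(D), atol=1e-5)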
def get_sps(self): """ Constructs the Sigma points used for propagation. :return: Sigma points :rtype: torch.Tensor """ cholcov = sqrt(self._lam + self._ndim) * torch.cholesky(self._cov) self._sps[..., 0, :] = self._mean self._sps[..., 1:self._ndim+1, :] = self._mean[..., None, :] + cholcov self._sps[..., self._ndim+1:, :] = self._mean[..., None, :] - cholcov return self._sps
def setUp(self): self.module = ETKFWeightsModule() self.state, self.obs = _create_matrices() innov = (self.obs['observations']-self.state.mean('ensemble')) innov = innov.values.reshape(-1) hx_perts = self.state.values.reshape(2, 1) obs_cov = self.obs['covariance'].values prepared_states = [innov, hx_perts, obs_cov] torch_states = [torch.from_numpy(s).float() for s in prepared_states] innov, hx_perts, obs_cov = torch_states obs_cinv = torch.cholesky(obs_cov).inverse() self.normed_perts = hx_perts @ obs_cinv self.normed_obs = (innov @ obs_cinv).view(1, 1)
def fit(self, X, labels): self.kern = torch.sum(torch.stack([ self.activation_fn(self.W_comp[i] * self.kernels[i](X)) + self.W_comp[i] * self.kernels[i](X) for i in range(self.nb_kernels) ]), dim=0) K = self.kern + torch.eye(self.kern.size()[0]).to(device) * self.lambda_reg L = torch.cholesky(K, upper=False) one_hot_y = F.one_hot(labels, num_classes = 10).type(torch.FloatTensor).to(device) #A, _ = torch.solve(kern, L) #V, _ = torch.solve(one_hot_y, L) #alpha = A.T @ V self.alpha = torch.cholesky_solve(one_hot_y, L, upper=False)
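# Editor's note: the cholesky_solve above yields self.alpha = (K + lambda_reg * I)^{-1} @ Y_onehot,
# i.e. the dual coefficients of kernel ridge regression fit against the one-hot labels.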
def _zvecmvn(self, mu, cov, X, theta, l, weights): d = l.numel() t = weights.numel() C = cov + torch.diag(l**2).repeat([t, 1, 1]) #(t,d,d) L = torch.cholesky(C, upper=False) #(t,d,d) Xm = X.repeat([t, 1, 1]) - mu.reshape([t, 1, d]) #(t,n,d) LX = utils.batch_trtrs(Xm.transpose(-2, -1), L, upper=False) #(t,d,n) expoent = -0.5 * torch.sum(LX**2, dim=1) #(t,n) det = torch.prod(1/l**2)*\ torch.prod(utils.batch_diag1(L),dim=1,keepdim=True)**2 #|I + A^-1B| (t,1) vec_ = theta / torch.sqrt(det) * torch.exp(expoent) #(t,n) zvec = (weights.reshape(-1, 1) * vec_).sum(dim=0) #(n,) return zvec
def generalized_eigenvalue_decomposition(a, b):
    """Solves the generalized eigenvalue problem a @ v = lambda * b @ v through the
    Cholesky decomposition of b.

    Returns eigenvalues and eigenvectors (ascending order).
    """
    cholesky = torch.cholesky(b)
    inv_cholesky = torch.inverse(cholesky)
    # Compute C = L^{-1} @ a @ L^{-H}
    cmat = inv_cholesky @ a @ inv_cholesky.conj().transpose(-1, -2)
    # Performing the eigenvalue decomposition
    e_val, e_vec = torch.symeig(cmat, eigenvectors=True)
    # Collecting the eigenvectors
    e_vec = torch.matmul(inv_cholesky.conj().transpose(-1, -2), e_vec)
    return e_val, e_vec
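# Hypothetical usage sketch for generalized_eigenvalue_decomposition above (editor's
# example): with symmetric a and positive definite b, each returned pair satisfies
# a @ v = lambda * b @ v.
a = torch.randn(4, 4)
a = a + a.t()
b = torch.randn(4, 4)
b = b @ b.t() + 4 * torch.eye(4)
e_val, e_vec = generalized_eigenvalue_decomposition(a, b)
assert torch.allclose(a @ e_vec, b @ e_vec @ torch.diag(e_val), atol=1e-3)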
def cholesky_iteration_loss(l_matrix, preconditioner, n_iter=8): """Cholesky iterations for positive (semi-)definite matrices. Krishnamoorthy, Aravindh, and Kenan Kocagoez. "Singular values using cholesky decomposition." arXiv preprint arXiv:1202.1490 (2012). """ # Algorithm 2. J_k = torch.sparse.mm(l_matrix, preconditioner) for _ in range(n_iter): R_k = torch.cholesky(J_k, upper=True) J_k = torch.mm(R_k, R_k.t()) sigma = sorted(J_k.diag()) return sigma[-1] / sigma[0]
def _cholesky(self, K): try: return torch.cholesky(K) except RuntimeError as e: print("ERROR:", e.args[0], file=sys.__stdout__) print("K =", K, file=sys.__stdout__) if K.isnan().any(): print("Kernel matrix has NaNs!", file=sys.__stdout__) if K.isinf().any(): print("Kernel matrix has infinities!", file=sys.__stdout__) print("Parameters:", file=sys.__stdout__) self.print_parameters(file=sys.__stdout__) raise CholeskyException(e.args[0], K, self)
def test_dist_to_funsor_mvn(batch_shape, event_size): loc = torch.randn(batch_shape + (event_size, )) cov = torch.randn(batch_shape + (event_size, 2 * event_size)) cov = cov.matmul(cov.transpose(-1, -2)) scale_tril = torch.cholesky(cov) d = dist.MultivariateNormal(loc, scale_tril=scale_tril) f = dist_to_funsor(d) assert isinstance(f, Funsor) value = d.sample() actual_log_prob = f(value=tensor_to_funsor(value, event_output=1)) expected_log_prob = tensor_to_funsor(d.log_prob(value)) assert_close(actual_log_prob, expected_log_prob)
def nll(self, X, y): m, S = self.forward(X, y) y = self.y const = -0.5 * self.N * torch.log(2 * torch.tensor(math.pi)) # data_fit = -0.5 * y.t() @ self.Kxx_noise_inv @ y data_fit = -0.5 * (y - m).t() @ S.inverse() @ (y - m) # complexity = -torch.trace(torch.cholesky(S)) complexity = -torch.log(torch.diag(torch.cholesky(self.Kxx_noise + torch.eye(self.N) * 1e-5))) complexity = complexity.sum() print( f"nll terms datafit : {data_fit.detach().numpy()},\n complexity : {complexity}" ) return data_fit + complexity + const
def __init__(self, loc, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None): if loc.dim() < 1: raise ValueError("loc must be at least one-dimensional.") if (covariance_matrix is not None) + (scale_tril is not None) + ( precision_matrix is not None) != 1: raise ValueError( "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified." ) loc_ = loc.unsqueeze(-1) # temporarily add dim on right if scale_tril is not None: if scale_tril.dim() < 2: raise ValueError( "scale_tril matrix must be at least two-dimensional, " "with optional leading batch dimensions") self.scale_tril, loc_ = torch.broadcast_tensors(scale_tril, loc_) elif covariance_matrix is not None: if covariance_matrix.dim() < 2: raise ValueError( "covariance_matrix must be at least two-dimensional, " "with optional leading batch dimensions") self.covariance_matrix, loc_ = torch.broadcast_tensors( covariance_matrix, loc_) else: if precision_matrix.dim() < 2: raise ValueError( "precision_matrix must be at least two-dimensional, " "with optional leading batch dimensions") self.precision_matrix, loc_ = torch.broadcast_tensors( precision_matrix, loc_) self.loc = loc_[..., 0] # drop rightmost dim self.normalizing_constant = torch.nn.Parameter(torch.tensor([1.])) batch_shape, event_shape = self.loc.shape[:-1], self.loc.shape[-1:] super(UnnormMVGaussian, self).__init__(batch_shape, event_shape, validate_args=validate_args) if scale_tril is not None: self._unbroadcasted_scale_tril = scale_tril else: if precision_matrix is not None: self.covariance_matrix = torch.inverse( precision_matrix).expand_as(loc_) self._unbroadcasted_scale_tril = torch.cholesky( self.covariance_matrix)
def psd_safe_cholesky(A, upper=False, out=None, jitter=None): """Compute the Cholesky decomposition of A. If A is only p.s.d, add a small jitter to the diagonal. Args: :attr:`A` (Tensor): The tensor to compute the Cholesky decomposition of :attr:`upper` (bool, optional): See torch.cholesky :attr:`out` (Tensor, optional): See torch.cholesky :attr:`jitter` (float, optional): The jitter to add to the diagonal of A in case A is only p.s.d. If omitted, chosen as 1e-6 (float) or 1e-8 (double) """ try: L = torch.cholesky(A, upper=upper, out=out) return L except RuntimeError as e: isnan = torch.isnan(A) if isnan.any(): raise NanError( f"cholesky_cpu: {isnan.sum().item()} of {A.numel()} elements of the {A.shape} tensor are NaN." ) if jitter is None: jitter = 1e-6 if A.dtype == torch.float32 else 1e-8 Aprime = A.clone() jitter_prev = 0 for i in range(3): jitter_new = jitter * (10 ** i) Aprime.diagonal(dim1=-2, dim2=-1).add_(jitter_new - jitter_prev) jitter_prev = jitter_new try: L = torch.cholesky(Aprime, upper=upper, out=out) warnings.warn(f"A not p.d., added jitter of {jitter_new} to the diagonal", NumericalWarning) return L except RuntimeError: continue raise e
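# Editor's usage sketch for psd_safe_cholesky above: a rank-deficient (merely positive
# semi-definite) matrix can make plain torch.cholesky fail; the helper then retries with
# a small diagonal jitter, so the factor still reproduces A to within that jitter.
x = torch.randn(5, 2)
A = x @ x.t()                 # rank 2, hence only p.s.d.
L = psd_safe_cholesky(A)      # may warn that jitter was added
assert torch.allclose(L @ L.t(), A, atol=1e-3)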
def test_psd_safe_cholesky_psd(self, cuda=False): device = torch.device("cuda") if cuda else torch.device("cpu") for dtype in (torch.float, torch.double): for batch_mode in (False, True): if batch_mode: A = self._gen_test_psd().to(device=device, dtype=dtype) else: A = self._gen_test_psd()[0].to(device=device, dtype=dtype) idx = torch.arange(A.shape[-1], device=A.device) # default values Aprime = A.clone() Aprime[..., idx, idx] += 1e-6 if A.dtype == torch.float32 else 1e-8 L_exp = torch.cholesky(Aprime) with warnings.catch_warnings(record=True) as w: L_safe = psd_safe_cholesky(A) self.assertTrue( any( issubclass(w_.category, RuntimeWarning) for w_ in w)) self.assertTrue( any("A not p.d., added jitter" in str(w_.message) for w_ in w)) self.assertTrue(torch.allclose(L_exp, L_safe)) # user-defined value Aprime = A.clone() Aprime[..., idx, idx] += 1e-2 L_exp = torch.cholesky(Aprime) with warnings.catch_warnings(record=True) as w: L_safe = psd_safe_cholesky(A, jitter=1e-2) self.assertTrue( any( issubclass(w_.category, RuntimeWarning) for w_ in w)) self.assertTrue( any("A not p.d., added jitter" in str(w_.message) for w_ in w)) self.assertTrue(torch.allclose(L_exp, L_safe))
def expected_sufficient_stats(self):
    '''Expected sufficient statistics given the current parameterization.

    For the random variable mu (vector), S (positive definite matrix)
    the sufficient statistics of the Normal-Wishart are given by:

                  +-- Dimension
                  |
    stats = (     v
        S * mu,             <-- D
        S,                  <-- D^2
        tr(S * mu * mu^T),  <-- 1
        ln |S|              <-- 1
    )

    For the standard parameters (m=mean, k=scale, W=W, v=dof) the expectation
    of the sufficient statistics is given by:

                      +-- Dimension
                      |
    exp_stats = (     v
        v * W * m,                          <-- D
        v * W,                              <-- D^2
        (D/k) + tr(v * W * m * m^T),        <-- 1
        ( \sum_i psi(.5 * (v + 1 - i)) ) \
            + D * ln 2 + ln |W|             <-- 1
    )

    Note: "tr" is the trace operator, "D" is the dimension of "m" and
    "psi" is the "digamma" function.

    '''
    idxs = torch.arange(0, self.dim, dtype=self.mean.dtype,
                        device=self.mean.device)
    L = torch.cholesky(self.scale_matrix, upper=False)
    logdet = torch.log(L.diag()).sum()
    mean_quad = torch.ger(self.mean, self.mean)
    exp_prec = self.dof * self.scale_matrix
    return torch.cat([
        exp_prec @ self.mean,
        exp_prec.reshape(-1),
        ((self.dim / self.scale) \
            + (exp_prec @ mean_quad).trace()).reshape(1),
        (torch.digamma(.5 * (self.dof + 1 - idxs)).sum() \
            + self.dim * math.log(2) + logdet).reshape(1)
    ])
def forward(self, X): A = self.alpha * torch.eye(self.n_features).type(torch.double) + self.beta * self.X_train.t() @ self.X_train L_A = torch.cholesky(self.jitter(A)) H1T_star = torch.triangular_solve(X.t(), L_A, upper=False)[0] # Have to compute the transpose because of the way triangular_solve/trtrs is written H1T_train = torch.triangular_solve(self.X_train.t(),L_A,upper=False)[0] # predictive mean Xm depends on the term $X S_N X^T$, which we compute through a Cholesky decomposition Xm = self.beta * H1T_star.t() @ H1T_train @ self.Y_train pred_mean = Xm # predictive mean # predictive covariance depends on the same terms XSX = H1T_star.t() @ H1T_star pred_covar = 1/self.beta*torch.eye(XSX.shape[0]).type(torch.double) + XSX return (pred_mean,pred_covar.diag())
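# Editor's note: the predictive equations above appear to be standard Bayesian linear
# regression (e.g. Bishop, PRML Section 3.3.2) with S_N = (alpha * I + beta * X^T X)^{-1}:
#     mean = beta * X_star @ S_N @ X_train^T @ Y_train
#     var  = 1/beta + diag(X_star @ S_N @ X_star^T)
# computed via triangular solves against the Cholesky factor of A = S_N^{-1}.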
def forward(self, X, y): """ Returns posterior mean and variance """ self.x, self.y, self.N = X, y, X.shape[-2] self.Kxx = self.kernel(X, X) # + torch.eye(self.N) * 1e-5 jitter = torch.eye(self.N) * self.noisestd**2 L = torch.cholesky(self.Kxx + jitter) a, _ = torch.solve(y.unsqueeze(0), L.unsqueeze(0)) alpha, _ = torch.solve(a, L.t().unsqueeze(0)) m = alpha.squeeze(0) S = L return m, S
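# For reference, a standalone textbook-style sketch (Rasmussen & Williams, Algorithm 2.1)
# of the Cholesky route to the GP posterior mean at the training inputs; editor's example,
# independent of the class above, using triangular solves rather than generic torch.solve.
def gp_posterior_mean(Kxx, y, noise_std):
    n = Kxx.shape[-1]
    L = torch.cholesky(Kxx + (noise_std ** 2) * torch.eye(n))
    a = torch.triangular_solve(y.view(-1, 1), L, upper=False)[0]
    alpha = torch.triangular_solve(a, L.t(), upper=True)[0]  # (Kxx + noise^2 I)^{-1} y
    return Kxx @ alpha                                       # posterior mean at the training points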
def reduce_func_singular_values(self,func): bs,c,h,w = self._shape padded_weight = F.pad(self.conv.weight,(0,h-3,0,w-3)) w_fft = torch.rfft(padded_weight, 2, onesided=False, normalized=False) # Lift to real valued space D = phi(w_fft).permute(2,3,0,1) Dt = D.permute(0, 1, 3, 2) #transpose of D lhs = torch.matmul(D, Dt) scale = lhs.data.norm().detach()/np.sqrt(np.prod(Dt.shape)) chol_output = torch.cholesky(lhs+1e-4*scale*torch.eye(lhs.size(-1)).to(lhs.device)) eigs = torch.diagonal(chol_output,dim1=-2,dim2=-1) logdet = (func(eigs).sum() / 2.0).expand(bs) # 1/4 \sum_{h,w} log det (DDt) return logdet
def __init__( self, mean: Tensor, cov: Tensor, seed: Optional[int] = None, inv_transform: bool = False, ) -> None: r"""Engine for qMC sampling from a multivariate Normal `N(\mu, \Sigma)`. Args: mean: The mean vector. cov: The covariance matrix. seed: The seed with which to seed the random number generator of the underlying SobolEngine. inv_transform: If True, use inverse transform instead of Box-Muller. """ # validate inputs if not cov.shape[0] == cov.shape[1]: raise ValueError("Covariance matrix is not square.") if not mean.shape[0] == cov.shape[0]: raise ValueError("Dimension mismatch between mean and covariance.") if not torch.allclose(cov, cov.transpose(-1, -2)): raise ValueError("Covariance matrix is not symmetric.") self._mean = mean self._normal_engine = NormalQMCEngine( d=mean.shape[0], seed=seed, inv_transform=inv_transform ) # compute Cholesky decomp; if it fails, do the eigendecomposition try: self._corr_matrix = torch.cholesky(cov).transpose(-1, -2) except RuntimeError: eigval, eigvec = torch.symeig(cov, eigenvectors=True) if not torch.all(eigval >= -1e-8): raise ValueError("Covariance matrix not PSD.") eigval_root = eigval.clamp_min(0.0).sqrt() self._corr_matrix = (eigvec * eigval_root).transpose(-1, -2)