def MultiChannelRateDistortion(W, device='cpu', eps=0.1):
    '''W in shape (N, C, H, W).'''
    m, c, _, _ = W.shape
    W = W.reshape(m, c, -1).permute(1, 2, 0)
    mean = W.mean(dim=2)
    W = W - mean.unsqueeze(dim=2)
    # W in shape (c, n, m)
    n = W.shape[1]
    I = torch.eye(n).to(device)
    a = n / (m * eps**2)
    # rate = torch.logdet(I.unsqueeze(0) + a * W.matmul(W.transpose(2, 1))).sum() / (m * n * c) * (m + n)
    # rate = rate + torch.log(1 + mean.mul(mean).sum(dim=1) / eps**2).sum() / (m * c)
    rate = torch.logdet(I.unsqueeze(0) + a * W.matmul(W.transpose(2, 1))).mean()
    rate = rate / 2.
    return rate
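# Hypothetical usage sketch for MultiChannelRateDistortion (not from the original source):
# it evaluates the coding-rate term 1/2 * logdet(I + n/(m*eps^2) * W W^T) per channel and
# averages over channels. The shapes below are illustrative only.
import torch

conv_weight = torch.randn(64, 32, 3, 3)  # (N, C, H, W) filter bank
rate = MultiChannelRateDistortion(conv_weight, device='cpu', eps=0.1)
print(rate.item())  # scalar coding-rate estimate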
def forward(self, z, reverse=False):
    # shape
    batch_size, group_size, n_of_groups = z.size()

    W = self.conv.weight.squeeze()

    if reverse:
        if not hasattr(self, 'W_inverse'):
            # Reverse computation
            W_inverse = W.float().inverse()
            W_inverse = Variable(W_inverse[..., None])
            if (z.type() == 'torch.cuda.HalfTensor'
                    or z.type() == 'torch.HalfTensor'):
                W_inverse = W_inverse.half()
            self.W_inverse = W_inverse
        z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
        return z
    else:
        # Forward computation
        log_det_W = batch_size * n_of_groups * torch.logdet(W.float())
        z = self.conv(z)
        return z, log_det_W
def test_inv_quad_logdet(self):
    # Forward
    lazy_tensor = self.create_lazy_tensor(with_solves=True, with_logdet=True)
    evaluated = self.evaluate_lazy_tensor(lazy_tensor)
    flattened_evaluated = evaluated.view(-1, *lazy_tensor.matrix_shape)

    vecs = lazy_tensor.eager_rhss[0].clone().detach().requires_grad_(True)
    vecs_copy = lazy_tensor.eager_rhss[0].clone().detach().requires_grad_(True)

    with gpytorch.settings.num_trace_samples(128), warnings.catch_warnings(record=True) as ws:
        res_inv_quad, res_logdet = lazy_tensor.inv_quad_logdet(inv_quad_rhs=vecs, logdet=True)
        self.assertFalse(any(issubclass(w.category, ExtraComputationWarning) for w in ws))
    res = res_inv_quad + res_logdet

    actual_inv_quad = evaluated.inverse().matmul(vecs_copy).mul(vecs_copy).sum(-2).sum(-1)
    actual_logdet = torch.cat(
        [torch.logdet(flattened_evaluated[i]).unsqueeze(0) for i in range(lazy_tensor.batch_shape.numel())]
    ).view(lazy_tensor.batch_shape)
    actual = actual_inv_quad + actual_logdet

    diff = (res - actual).abs() / actual.abs().clamp(1, math.inf)
    self.assertLess(diff.max().item(), 15e-2)
def forward(self, z, reverse=False):
    # shape
    batch_size, group_size, n_of_groups = z.size()

    W = self.conv.weight.squeeze()

    if reverse:
        if not hasattr(self, 'W_inverse'):
            # Reverse computation
            W_inverse = W.float().inverse()
            W_inverse = Variable(W_inverse[..., None])
            if "HalfTensor" in z.type():
                W_inverse = W_inverse.half()
            self.W_inverse = W_inverse
            self.conv.weight.data = self.W_inverse
        z = self.conv(z)
        return z
    else:
        # Forward computation
        log_det_W = batch_size * n_of_groups * torch.logdet(W)
        z = self.conv(z)
        return z, log_det_W
def forward(self, x, x_mask=None, reverse=False, **kwargs):  # pylint: disable=unused-argument
    """
    Shapes:
        - x: :math:`[B, C, T]`
        - x_mask: :math:`[B, 1, T]`
    """
    b, c, t = x.size()
    assert c % self.num_splits == 0
    if x_mask is None:
        x_mask = 1
        x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t
    else:
        x_len = torch.sum(x_mask, [1, 2])

    x = x.view(b, 2, c // self.num_splits, self.num_splits // 2, t)
    x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.num_splits, c // self.num_splits, t)

    if reverse:
        if self.weight_inv is not None:
            weight = self.weight_inv
        else:
            weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
        logdet = None
    else:
        weight = self.weight
        if self.no_jacobian:
            logdet = 0
        else:
            logdet = torch.logdet(self.weight) * (c / self.num_splits) * x_len  # [b]

    weight = weight.view(self.num_splits, self.num_splits, 1, 1)
    z = F.conv2d(x, weight)

    z = z.view(b, 2, self.num_splits // 2, c // self.num_splits, t)
    z = z.permute(0, 1, 3, 2, 4).contiguous().view(b, c, t) * x_mask
    return z, logdet
def lpdf_real_dirichlet(r, lpdf_p):
    """
    Remember to perform `r -= r.max()` before zeroing out the gradient.
    """
    # The rank is dim(r) - 1, since the remaining parameter
    # must be one minus the sum of the other parameters.
    rank = r.size().numel() - 1
    J = torch.empty([rank, rank], dtype=torch.float64)
    p = torch.softmax(r, 0)
    for i in range(rank):
        for j in range(i + 1):
            if i == j:
                # J[i, j] = torch.exp(x[i]) * (sum_x - torch.exp(x[j])) / (sum_x ** 2)
                J[i, i] = p[i] * (1 - p[i])
            else:
                # tmp = torch.exp(x[i] + x[j]) / (sum_x ** 2)
                tmp = -p[i] * p[j]
                J[i, j] = tmp
                J[j, i] = tmp
    return lpdf_p(p) + torch.logdet(J)
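# Hypothetical sanity check (not from the original source): the analytic Jacobian block
# built above, J[i, j] = p_i * (delta_ij - p_j), should match the top-left (rank x rank)
# block of the autograd Jacobian of softmax(r).
import torch

r = torch.randn(4, dtype=torch.float64)
rank = r.numel() - 1
p = torch.softmax(r, 0)
J_auto = torch.autograd.functional.jacobian(lambda v: torch.softmax(v, 0), r)
J_analytic = torch.diag(p[:rank]) - torch.outer(p[:rank], p[:rank])
assert torch.allclose(J_auto[:rank, :rank], J_analytic, atol=1e-10)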
def test_inv_quad_logdet_no_reduce(self):
    # Forward
    lazy_tensor = self.create_lazy_tensor()
    evaluated = self.evaluate_lazy_tensor(lazy_tensor)
    flattened_evaluated = evaluated.view(-1, *lazy_tensor.matrix_shape)

    vecs = torch.randn(*lazy_tensor.batch_shape, lazy_tensor.size(1), 3, requires_grad=True)
    vecs_copy = vecs.clone().detach_().requires_grad_(True)

    with gpytorch.settings.num_trace_samples(128):
        res_inv_quad, res_logdet = lazy_tensor.inv_quad_logdet(
            inv_quad_rhs=vecs, logdet=True, reduce_inv_quad=False
        )
    res = res_inv_quad.sum(-1) + res_logdet

    actual_inv_quad = evaluated.inverse().matmul(vecs_copy).mul(vecs_copy).sum(-2).sum(-1)
    actual_logdet = torch.cat(
        [torch.logdet(flattened_evaluated[i]).unsqueeze(0) for i in range(lazy_tensor.batch_shape.numel())]
    ).view(lazy_tensor.batch_shape)
    actual = actual_inv_quad + actual_logdet

    diff = (res - actual).abs() / actual.abs().clamp(1, math.inf)
    self.assertLess(diff.max().item(), 15e-2)
def forward(self, z, reverse=False):
    # shape
    batch_size, group_size, n_of_groups = z.size()

    W = self.conv.weight.squeeze()

    if reverse:
        if not hasattr(self, 'set'):
            # Reverse computation
            W_inverse = W.float().inverse()
            W_inverse = W_inverse[..., None]
            self.W_inverse = W_inverse
        z = torch.nn.functional.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
        return z
    else:
        # Forward computation
        log_det_W = batch_size * n_of_groups * torch.logdet(W)
        z = self.conv(z)
        return z, log_det_W
def forward(self, post: Posterior, comp: Tensor) -> Tensor:
    r"""Calculate the approximated log evidence, i.e., log(P(D|theta)).

    Args:
        post: Training posterior distribution from self.model.
        comp: Comparison pairs; see PairwiseGP.__init__ for more details.

    Returns:
        The approximated evidence, i.e., the marginal log likelihood.
    """
    model = self.model
    if comp is not model.comparisons:
        raise RuntimeError("Must train on training data")

    f_max = post.mean
    log_posterior = model._posterior_f(f_max)
    part1 = -log_posterior

    part2 = model.covar @ model.likelihood_hess
    eye = torch.eye(
        part2.size(-1), dtype=model.datapoints.dtype, device=model.datapoints.device
    ).expand(part2.shape)
    part2 = part2 + eye
    part2 = -0.5 * torch.logdet(part2)

    evidence = part1 + part2

    # Sum up mll first so that when adding prior probs it won't
    # propagate and double count
    evidence = evidence.sum()

    # Add log probs of priors on the (functions of) parameters
    for _, prior, closure, _ in self.named_priors():
        evidence = evidence.add(prior.log_prob(closure()).sum())

    return evidence
def JointEntropy(X, Y, device='cpu', eps=0.1):
    '''
    This is the alpha version of matrix MR.
    X should be in shape R(M, C1, H, W);
    Y should be in shape R(M, C2, h, w).
    '''
    m = X.shape[0]
    X = X.reshape(m, -1)
    Y = Y.reshape(m, -1)
    X = torch.cat((X, Y), dim=1)
    mean = X.mean(dim=0)
    X = X - mean.unsqueeze(dim=0)
    # X in shape (m, n)
    I = torch.eye(m).to(device)
    a = X.shape[1] / (m * eps**2)
    rate = torch.logdet(I + a * X.matmul(X.T))
    torch.cuda.empty_cache()
    return rate / 2
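# Hypothetical numerical check (not from the original source): JointEntropy uses the
# m x m Gram matrix X X^T. By the Weinstein-Aronszajn identity,
# logdet(I_m + a * X X^T) == logdet(I_n + a * X^T X), so the smaller Gram matrix can be
# used when m << n (or vice versa).
import torch

m, n, a = 8, 100, 0.3
X = torch.randn(m, n, dtype=torch.float64)
lhs = torch.logdet(torch.eye(m, dtype=torch.float64) + a * X @ X.T)
rhs = torch.logdet(torch.eye(n, dtype=torch.float64) + a * X.T @ X)
assert torch.allclose(lhs, rhs)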
def forward(self, z, reverse=False):
    # shape
    z = z.permute(0, 2, 1)
    batch_size, group_size, n_of_groups = z.size()

    W = self.conv.weight.squeeze()

    if reverse:
        # if not hasattr(self, 'W_inverse'):
        # Reverse computation
        W_inverse = W.float().inverse()
        W_inverse = Variable(W_inverse[..., None])
        if z.type() == 'torch.cuda.HalfTensor':
            W_inverse = W_inverse.half()
        self.W_inverse = W_inverse
        z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
        z = z.permute(0, 2, 1)
        return z
    else:
        # Forward computation
        log_det_W = torch.ones_like(z[:, 0, :]) * torch.logdet(W)
        z = self.conv(z)
        z = z.permute(0, 2, 1)
        return z, log_det_W
def forward(self, z, reverse=False):
    batch_size, group_size, n_of_groups = z.size()
    # Here, group_size refers to the channel dimension of the data, and each matrix along
    # the channel dimension (i.e. [:, i, :]) is multiplied by the W matrix. n_of_groups is
    # the number of groups, i.e. *how many* such matrices there are, each of which is
    # multiplied by W.
    # The larger group_size is, the more thorough the "mixing" of the variables is before
    # going back to the AC layer. In the extreme case of group_size=1, the variables are
    # never permuted before going back to AC, and the flow would work terribly because
    # nothing would change order, so the same values would go into the WN at each step of
    # flow. In the other extreme, where n_of_groups=1, mixing is maximized, but we run the
    # risk of our W matrix being too big and possibly getting numerical instability when we
    # try to invert it, since we're not using very high precision to represent it (float32).
    W = self.conv.weight.squeeze()

    if reverse:
        if not hasattr(self, 'W_inv'):
            W_inv = W.float().inverse()
            W_inv = Variable(W_inv[..., None])
            self.W_inv = W_inv
        z = F.conv1d(z, self.W_inv, bias=None, stride=1, padding=0)
        return z
    else:
        log_det_w = batch_size * n_of_groups * torch.logdet(W)
        z = self.conv(z)
        return z, log_det_w
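# Hypothetical standalone sketch (not from the original source) of the invertible 1x1
# convolution used above: applying W as a 1x1 conv over channels and then its inverse
# recovers the input, and the Jacobian log-determinant of the forward pass is
# batch_size * n_of_groups * logdet(W).
import torch
import torch.nn.functional as F

batch_size, group_size, n_of_groups = 2, 8, 16
W = torch.linalg.qr(torch.randn(group_size, group_size))[0]  # random orthogonal init
if torch.det(W) < 0:
    W[:, 0] = -W[:, 0]  # flip sign so det(W) = +1 and logdet(W) is finite

z = torch.randn(batch_size, group_size, n_of_groups)
z_fwd = F.conv1d(z, W[..., None])                # mix channels with W
z_rec = F.conv1d(z_fwd, W.inverse()[..., None])  # undo the mixing with W^{-1}
assert torch.allclose(z, z_rec, atol=1e-4)

# Contribution to the flow's log-likelihood for the whole batch.
log_det_W = batch_size * n_of_groups * torch.logdet(W)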
def MultiChannelRateDistortion_Label(W, Pi, device='cpu', eps=0.1):
    m, c, _, _ = W.shape
    W = W.reshape(m, c, -1).permute(1, 2, 0)
    k, _ = Pi.shape
    n = W.shape[1]
    I = torch.eye(n).to(device)
    # W in shape (c, n, m)
    rate = 0
    for i in range(k):
        trPi = Pi[i].sum() + 1e-8
        a = n / (trPi * eps**2)
        W_k = W.matmul(torch.diag(Pi[i]))
        mean = W_k.sum(dim=2) / trPi
        W_k = W_k - mean.unsqueeze(2)
        rate = rate + torch.logdet(
            I.unsqueeze(0) + a * W_k.matmul(torch.diag(Pi[i])).matmul(W_k.transpose(2, 1))
        ).sum() / (m * n * c) * trPi
    rate = rate / 2
    return rate
def _get_logdet_loss(self, M, delta=1e-5):
    G = _get_Gram(M)
    return torch.logdet(G + delta * torch.eye(G.size(-1)).repeat(G.size(0), 1, 1)).mean()  # .cuda()
def discriminative_loss(self, Z: torch.tensor, gamma: float = 1) -> torch.tensor:
    m, p = Z.shape
    identity = torch.eye(p).cuda()
    return torch.logdet(identity + gamma * p / (m * self.eps) * Z.T.matmul(Z)) / 2
def forward(self) -> Tuple[torch.Tensor, int]:
    unstable = 0
    # get sigma/mean of each level
    sigma_w, mu_w = self.implied_sigma_mu()
    sigma_l2, mu_l2 = self.implied_sigma_mu(suffix="_l2")
    # decompose into 11, 12, 21, 22
    sigma_b, sigma_xx, sigma_yx, mu_b, mu_x = self._split(sigma_l2, mu_l2)
    # cluster FIML -2 * logL (without constants)
    loss = torch.zeros(1, dtype=mu_w.dtype, device=mu_w.device)
    # go through each cluster separately
    data_ys_available = ~torch.isnan(self.data_ys)
    cache_S_ij = {}
    cache_S_j_R_j = {}
    sigma_b_logdet = None
    sigma_b_inv = None
    if not self.naive_implementation:
        sigma_b_logdet = torch.logdet(sigma_b)
        sigma_b_inv = torch.inverse(sigma_b)

    for cluster_slice, batches in self.missing_patterns:
        # get cluster data and define R_j for current cluster j
        cluster_x = self.data_xs[cluster_slice.start, :]
        R_j_index = ~torch.isnan(cluster_x)
        no_cluster = ~R_j_index.any()

        # cache
        key = (
            tuple(R_j_index.tolist()),
            tuple([tuple(x) for x in data_ys_available[cluster_slice, :].tolist()]),
        )
        sigma_j_logdet, sigma_j_inv = cache_S_j_R_j.get(key, (None, None))

        # define S_ij and S_j
        S_ijs = []
        eye_w = torch.eye(mu_w.shape[0], dtype=mu_w.dtype, device=mu_w.device)
        lambda_ijs_logdet_sum = 0.0
        lambda_ijs_inv = []
        A_j = torch.zeros_like(sigma_w)
        for batch_slice in batches:
            size = batch_slice.stop - batch_slice.start
            available = data_ys_available[batch_slice.start]
            S_ij = eye_w[available, :]
            S_ijs.extend([S_ij] * size)
            if self.naive_implementation or sigma_j_logdet is not None:
                continue
            key_S_ij = tuple(available.tolist())
            lambda_ij_inv, lambda_ij_logdet, a_j = cache_S_ij.get(
                key_S_ij, (None, None, None))
            if lambda_ij_inv is None:
                lambda_ij = sigma_w  # no missing data
                if S_ij.shape[0] != eye_w.shape[0]:  # missing data
                    lambda_ij = S_ij.mm(sigma_w.mm(S_ij.t()))
                lambda_ij_inv = torch.inverse(lambda_ij)
                lambda_ij_logdet = torch.logdet(lambda_ij)
                if S_ij.shape[0] != eye_w.shape[0]:  # missing data
                    a_j = S_ij.t().mm(lambda_ij_inv.mm(S_ij))
                else:
                    a_j = lambda_ij_inv
                cache_S_ij[key_S_ij] = lambda_ij_inv, lambda_ij_logdet, a_j
            lambda_ijs_inv.extend([lambda_ij_inv] * size)
            lambda_ijs_logdet_sum = lambda_ijs_logdet_sum + lambda_ij_logdet * size
            A_j = A_j + a_j * size
        S_j = torch.cat(S_ijs, dim=0)

        # means
        y_j = torch.cat([
            self.data_ys[cluster_slice, :][data_ys_available[cluster_slice, :]][:, None],
            cluster_x[R_j_index, None],
        ])
        mu_y = mu_w + mu_b
        mu_j = torch.cat([S_j.mm(mu_y), mu_x[R_j_index]])
        mean_diff = y_j - mu_j
        G_yj = mean_diff.mm(mean_diff.t())

        if sigma_j_logdet is None and not self.naive_implementation:
            sigma_b_inv_A_j = sigma_b_inv + A_j
            B_j = torch.inverse(sigma_b_inv_A_j)
            C_j = eye_w - A_j.mm(B_j)
            D_j = C_j.mm(A_j)
            lambda_inv = block_diag(lambda_ijs_inv)
            V_j_inv = lambda_inv - lambda_inv.mm(
                S_j.mm(B_j.mm(S_j.t().mm(lambda_inv))))
            if no_cluster:  # no cluster
                sigma_11_j = V_j_inv
                sigma_21_j = torch.empty(0, device=sigma_11_j.device, dtype=sigma_11_j.dtype)
                sigma_22_1 = torch.empty([0, 0], device=sigma_11_j.device, dtype=sigma_11_j.dtype)
                sigma_22_inv = sigma_21_j
            else:  # normal case
                sigma_22_1 = (sigma_xx - sigma_yx.t().mm(
                    D_j.mm(sigma_yx)))[R_j_index, :][:, R_j_index]
                sigma_22_inv = torch.inverse(sigma_22_1)
                sigma_jyx = S_j.mm(sigma_yx[:, R_j_index])
                sigma_11_j = (V_j_inv.mm(
                    sigma_jyx.mm(sigma_22_inv.mm(sigma_jyx.t().mm(V_j_inv)))) + V_j_inv)
                sigma_21_j = -sigma_22_inv.mm(sigma_jyx.t().mm(V_j_inv))
            sigma_j_inv = torch.cat(
                [
                    torch.cat([sigma_11_j, sigma_21_j]),
                    torch.cat([sigma_21_j.t(), sigma_22_inv]),
                ],
                dim=1,
            )
            sigma_j_logdet = (lambda_ijs_logdet_sum + sigma_b_logdet +
                              torch.logdet(sigma_b_inv_A_j) + torch.logdet(sigma_22_1))
            cache_S_j_R_j[key] = (sigma_j_logdet, sigma_j_inv)
        elif sigma_j_logdet is None:
            # naive
            sigma_j = S_j.mm(sigma_b.mm(S_j.t())) + block_diag(
                [S_ij.mm(sigma_w.mm(S_ij.t())) for S_ij in S_ijs])
            if not no_cluster:
                sigma_j_12 = S_j.mm(sigma_yx[:, R_j_index])
                sigma_j_21 = sigma_j_12.t()
                sigma_j_22 = sigma_xx[R_j_index, :][:, R_j_index]
                sigma_j = torch.cat(
                    [
                        torch.cat([sigma_j, sigma_j_21]),
                        torch.cat([sigma_j_12, sigma_j_22]),
                    ],
                    dim=1,
                )
            sigma_j_logdet = torch.logdet(sigma_j)
            sigma_j_inv = torch.inverse(sigma_j)
            cache_S_j_R_j[key] = (sigma_j_logdet, sigma_j_inv)

        loss_current = sigma_j_logdet + torch.trace(sigma_j_inv.mm(G_yj))
        unstable += loss_current.detach().item() < 0
        loss = loss + loss_current.clamp(min=0.0)

    return loss, unstable
def Gaussian_NLLH(pred, cov, label):
    # Assumes the bare names are the torch functions, e.g.
    # from torch import inverse, logdet, matmul, mean, sub
    loss = mean(
        matmul(
            sub(label, pred).unsqueeze(1),
            matmul(inverse(cov), sub(label, pred).unsqueeze(2)),
        )
        + logdet(cov)
    )
    return loss
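# Hypothetical usage sketch (not from the original source): Gaussian_NLLH averages the
# Mahalanobis term plus logdet(cov), i.e. -2 * log N(label; pred, cov) without the
# d*log(2*pi) constant. The check below compares it against
# torch.distributions.MultivariateNormal up to that constant, using a shared covariance.
import math
import torch
from torch import inverse, logdet, matmul, mean, sub  # names assumed by Gaussian_NLLH

B, d = 5, 3
pred = torch.randn(B, d, dtype=torch.float64)
label = torch.randn(B, d, dtype=torch.float64)
A = torch.randn(d, d, dtype=torch.float64)
cov = A @ A.T + d * torch.eye(d, dtype=torch.float64)  # shared SPD covariance

nllh = Gaussian_NLLH(pred, cov, label)
ref = (-2 * torch.distributions.MultivariateNormal(pred, cov).log_prob(label)).mean() \
    - d * math.log(2 * math.pi)
assert torch.allclose(nllh, ref)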
def log_det_other(x):
    return torch.logdet(x)
def general_kl_divergence(mu_1, cov_1, mu_2, cov_2):
    mu_diff = (mu_1 - mu_2).view(-1, 1)
    cov_2_inverse = torch.inverse(cov_2)
    return -0.5 * (torch.logdet(cov_1) - torch.logdet(cov_2) + mu_1.size(0)
                   - torch.trace(cov_1 @ cov_2_inverse)
                   - mu_diff.t() @ cov_2_inverse @ mu_diff)
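# Hypothetical sanity check (not from the original source): general_kl_divergence should
# agree with KL(N(mu_1, cov_1) || N(mu_2, cov_2)) as computed by torch.distributions.
import torch
from torch.distributions import MultivariateNormal, kl_divergence

d = 4
mu_1, mu_2 = torch.randn(d, dtype=torch.float64), torch.randn(d, dtype=torch.float64)
A = torch.randn(d, d, dtype=torch.float64)
B = torch.randn(d, d, dtype=torch.float64)
cov_1 = A @ A.T + d * torch.eye(d, dtype=torch.float64)
cov_2 = B @ B.T + d * torch.eye(d, dtype=torch.float64)

ours = general_kl_divergence(mu_1, cov_1, mu_2, cov_2)
ref = kl_divergence(MultivariateNormal(mu_1, cov_1), MultivariateNormal(mu_2, cov_2))
assert torch.allclose(ours.squeeze(), ref)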
def _logdetgrad(self, z, x):
    """Returns logdet|dz/dx|."""

    with torch.enable_grad():
        if (self.brute_force or not self.training) and (x.ndimension() == 2 and x.shape[1] <= 10):
            x = x.requires_grad_(True)
            z = z.requires_grad_(True)
            Fx = x + self.nnet_x(x)
            Jx = batch_jacobian(Fx, x)
            logdet_x = torch.logdet(Jx)
            Fz = z + self.nnet_z(z)
            Jz = batch_jacobian(Fz, z)
            logdet_z = torch.logdet(Jz)
            return (logdet_x - logdet_z).view(-1, 1)

        if self.n_dist == 'geometric':
            geom_p = torch.sigmoid(self.geom_p).item()
            sample_fn = lambda m: geometric_sample(geom_p, m)
            rcdf_fn = lambda k, offset: geometric_1mcdf(geom_p, k, offset)
        elif self.n_dist == 'poisson':
            lamb = self.lamb.item()
            sample_fn = lambda m: poisson_sample(lamb, m)
            rcdf_fn = lambda k, offset: poisson_1mcdf(lamb, k, offset)

        if self.training:
            if self.n_power_series is None:
                # Unbiased estimation.
                lamb = self.lamb.item()
                n_samples = sample_fn(self.n_samples)
                n_power_series = max(n_samples) + self.n_exact_terms
                coeff_fn = lambda k: 1 / rcdf_fn(k, self.n_exact_terms) * \
                    sum(n_samples >= k - self.n_exact_terms) / len(n_samples)
            else:
                # Truncated estimation.
                n_power_series = self.n_power_series
                coeff_fn = lambda k: 1.
        else:
            # Unbiased estimation with more exact terms.
            lamb = self.lamb.item()
            n_samples = sample_fn(self.n_samples)
            n_power_series = max(n_samples) + self.n_exact_terms_test
            coeff_fn = lambda k: 1 / rcdf_fn(k, self.n_exact_terms_test) * \
                sum(n_samples >= k - self.n_exact_terms_test) / len(n_samples)

        if not self.exact_trace:
            ####################################
            # Power series with trace estimator.
            ####################################
            # vareps_x = torch.randn_like(x)
            # vareps_z = torch.randn_like(z)
            vareps_x = torch.distributions.bernoulli.Bernoulli(
                torch.Tensor([0.5])).sample(x.shape).reshape(x.shape).to(x) * 2 - 1
            vareps_z = torch.distributions.bernoulli.Bernoulli(
                torch.Tensor([0.5])).sample(z.shape).reshape(z.shape).to(z) * 2 - 1

            # Choose the type of estimator.
            if self.training and self.neumann_grad:
                estimator_fn = neumann_logdet_estimator
            else:
                estimator_fn = basic_logdet_estimator

            # Do backprop-in-forward to save memory.
            if self.training and self.grad_in_forward:
                logdet_x = mem_eff_wrapper(
                    estimator_fn, self.nnet_x, x, n_power_series, vareps_x, coeff_fn, self.training
                )
                logdet_z = mem_eff_wrapper(
                    estimator_fn, self.nnet_z, z, n_power_series, vareps_z, coeff_fn, self.training
                )
                logdetgrad = logdet_x - logdet_z
            else:
                x = x.requires_grad_(True)
                z = z.requires_grad_(True)
                Fx = self.nnet_x(x)
                Fz = self.nnet_z(z)
                logdet_x = estimator_fn(Fx, x, n_power_series, vareps_x, coeff_fn, self.training)
                logdet_z = estimator_fn(Fz, z, n_power_series, vareps_z, coeff_fn, self.training)
                logdetgrad = logdet_x - logdet_z
        else:
            ############################################
            # Power series with exact trace computation.
            ############################################
            x = x.requires_grad_(True)
            z = z.requires_grad_(True)
            Fx = self.nnet_x(x)
            Jx = batch_jacobian(Fx, x)
            logdetJx = batch_trace(Jx)
            Jx_k = Jx
            for k in range(2, n_power_series + 1):
                Jx_k = torch.bmm(Jx, Jx_k)
                logdetJx = logdetJx + (-1)**(k + 1) / k * coeff_fn(k) * batch_trace(Jx_k)
            Fz = self.nnet_z(z)
            Jz = batch_jacobian(Fz, z)
            logdetJz = batch_trace(Jz)
            Jz_k = Jz
            for k in range(2, n_power_series + 1):
                Jz_k = torch.bmm(Jz, Jz_k)
                logdetJz = logdetJz + (-1)**(k + 1) / k * coeff_fn(k) * batch_trace(Jz_k)
            logdetgrad = logdetJx - logdetJz

        if self.training and self.n_power_series is None:
            self.last_n_samples.copy_(torch.tensor(n_samples).to(self.last_n_samples))
            estimator = logdetgrad.detach()
            self.last_firmom.copy_(torch.mean(estimator).to(self.last_firmom))
            self.last_secmom.copy_(torch.mean(estimator**2).to(self.last_secmom))

        return logdetgrad.view(-1, 1)
def _logits_proposal_posterior(
    means_pp: Tensor,
    precisions_pp: Tensor,
    covariances_pp: Tensor,
    logits_p: Tensor,
    means_p: Tensor,
    precisions_p: Tensor,
    logits_d: Tensor,
    means_d: Tensor,
    precisions_d: Tensor,
):
    """
    Return the component weights (i.e. logits) of the proposal posterior.

    Args:
        means_pp: Means of the proposal posterior.
        precisions_pp: Precision matrices of the proposal posterior.
        covariances_pp: Covariance matrices of the proposal posterior.
        logits_p: Component weights (i.e. logits) of the proposal distribution.
        means_p: Means of the proposal distribution.
        precisions_p: Precision matrices of the proposal distribution.
        logits_d: Component weights (i.e. logits) of the density estimator.
        means_d: Means of the density estimator.
        precisions_d: Precision matrices of the density estimator.

    Returns:
        Component weights of the proposal posterior. L*K terms.
    """
    num_comps_p = precisions_p.shape[1]
    num_comps_d = precisions_d.shape[1]

    # Compute log(alpha_i * beta_j)
    logits_p_rep = logits_p.repeat_interleave(num_comps_d, dim=1)
    logits_d_rep = logits_d.repeat(1, num_comps_p)
    logit_factors = logits_p_rep + logits_d_rep

    # Compute sqrt(det()/(det()*det()))
    logdet_covariances_pp = torch.logdet(covariances_pp)
    logdet_covariances_p = -torch.logdet(precisions_p)
    logdet_covariances_d = -torch.logdet(precisions_d)

    # Repeat the proposal and density estimator terms such that there are LK terms.
    # Same trick as has been used above.
    logdet_covariances_p_rep = logdet_covariances_p.repeat_interleave(num_comps_d, dim=1)
    logdet_covariances_d_rep = logdet_covariances_d.repeat(1, num_comps_p)

    log_sqrt_det_ratio = 0.5 * (
        logdet_covariances_pp - (logdet_covariances_p_rep + logdet_covariances_d_rep)
    )

    # Compute for proposal, density estimator, and proposal posterior:
    # mu_i.T * P_i * mu_i
    exponent_p = batched_mixture_vmv(precisions_p, means_p)
    exponent_d = batched_mixture_vmv(precisions_d, means_d)
    exponent_pp = batched_mixture_vmv(precisions_pp, means_pp)

    # Extend proposal and density estimator exponents to get LK terms.
    exponent_p_rep = exponent_p.repeat_interleave(num_comps_d, dim=1)
    exponent_d_rep = exponent_d.repeat(1, num_comps_p)
    exponent = -0.5 * (exponent_p_rep + exponent_d_rep - exponent_pp)

    logits_pp = logit_factors + log_sqrt_det_ratio + exponent

    return logits_pp
print("running tests") N = 4 batch = 1 _dtype = t.float64 tmp1 = t.randn(N, N, dtype=_dtype) tmp1 = tmp1 @ tmp1.t() tmp2 = t.randn(N, N, dtype=_dtype) tmp2 = tmp2 @ tmp2.t() #### Test KFac Kpt = KFac(PositiveDefiniteMatrix(tmp1), PositiveDefiniteMatrix(tmp2)) Knp = t.tensor(np.kron(tmp1.numpy(), tmp2.numpy())) assert t.allclose(Kpt.full(), Knp) assert t.allclose(Kpt.logdet(), t.logdet(Kpt.full())) #### Test other m1 = Scale(N, 0.5, dtype=_dtype) m2 = PositiveDefiniteMatrix(tmp1) m3 = m2.chol() m4 = m3.t() m5 = Product(m1, m2, m3) def tests(W): #### Test matrix-vector operations x = t.randn(batch, N, 1, dtype=_dtype) # Test mm assert t.allclose(W.full() @ x, W(x))
def log_normalizer(self) -> th.FloatTensor:
    return th.logdet(self.L + th.eye(self.N))
def _logits_posterior(
    means_post: Tensor,
    precisions_post: Tensor,
    covariances_post: Tensor,
    logits_pp: Tensor,
    means_pp: Tensor,
    precisions_pp: Tensor,
    logits_d: Tensor,
    means_d: Tensor,
    precisions_d: Tensor,
):
    r"""
    Return the component weights (i.e. logits) of the MoG posterior.

    $\alpha_k^\prime = \frac{\alpha_k \exp(-0.5 c_k)}{\sum_j \alpha_j \exp(-0.5 c_j)}$
    with
    $c_k = \log\det(S_k) - \log\det(S_0) - \log\det(S_k^\prime)
           + m_k^T P_k m_k - m_0^T P_0 m_0 - m_k^{\prime T} P_k^\prime m_k^\prime$
    (see eqs. (25, 26) in Appendix C of [1])

    Args:
        means_post: Means of the posterior.
        precisions_post: Precision matrices of the posterior.
        covariances_post: Covariance matrices of the posterior.
        logits_pp: Component weights (i.e. logits) of the proposal prior.
        means_pp: Means of the proposal prior.
        precisions_pp: Precision matrices of the proposal prior.
        logits_d: Component weights (i.e. logits) of the density estimator.
        means_d: Means of the density estimator.
        precisions_d: Precision matrices of the density estimator.

    Returns:
        Component weights of the proposal posterior.
    """
    num_comps_pp = precisions_pp.shape[1]
    num_comps_d = precisions_d.shape[1]

    # Compute the ratio of the logits similar to eq (10) in Appendix A.1 of [2]
    logits_pp_rep = logits_pp.repeat_interleave(num_comps_d, dim=1)
    logits_d_rep = logits_d.repeat(1, num_comps_pp)
    logit_factors = logits_d_rep - logits_pp_rep

    # Compute the log-determinants
    logdet_covariances_post = torch.logdet(covariances_post)
    logdet_covariances_pp = -torch.logdet(precisions_pp)
    logdet_covariances_d = -torch.logdet(precisions_d)

    # Repeat the proposal and density estimator terms such that there are LK terms.
    # Same trick as has been used above.
    logdet_covariances_pp_rep = logdet_covariances_pp.repeat_interleave(num_comps_d, dim=1)
    logdet_covariances_d_rep = logdet_covariances_d.repeat(1, num_comps_pp)

    log_sqrt_det_ratio = 0.5 * (  # similar to eq (14) in Appendix A.1 of [2]
        logdet_covariances_post + logdet_covariances_pp_rep - logdet_covariances_d_rep
    )

    # Compute for proposal, density estimator, and proposal posterior:
    exponent_pp = utils.batched_mixture_vmv(
        precisions_pp, means_pp  # m_0 in eq (26) in Appendix C of [1]
    )
    exponent_d = utils.batched_mixture_vmv(
        precisions_d, means_d  # m_k in eq (26) in Appendix C of [1]
    )
    exponent_post = utils.batched_mixture_vmv(
        precisions_post, means_post  # m_k^\prime in eq (26) in Appendix C of [1]
    )

    # Extend proposal and density estimator exponents to get LK terms.
    exponent_pp_rep = exponent_pp.repeat_interleave(num_comps_d, dim=1)
    exponent_d_rep = exponent_d.repeat(1, num_comps_pp)
    exponent = -0.5 * (
        exponent_d_rep - exponent_pp_rep - exponent_post  # eq (26) in [1]
    )

    logits_post = logit_factors + log_sqrt_det_ratio + exponent

    return logits_post
def entropy(self):
    simple_tst = StudentT_torch(self.df)
    H = self.coeff * torch.logdet(self.S) + self.d * simple_tst.entropy()
    return H
def get_laplace_logdet_term(self):
    inverse_posterior_covariance = self.get_inverse_posterior_covariance()
    logdet_term = -0.5 * torch.logdet(inverse_posterior_covariance)
    return logdet_term
def logdet(self):
    return torch.logdet(self)
def logdet_torch(A):
    # return torch.log(torch.det(A))
    return torch.logdet(A)
def blogdet(A):
    return torch.stack([torch.logdet(a) for a in A])
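# Hypothetical note (not from the original source): torch.logdet already accepts batched
# input of shape (*, n, n), so the explicit Python loop in blogdet is equivalent to a
# single batched call. A quick check:
import torch

A = torch.randn(6, 5, 5, dtype=torch.float64)
A = A @ A.transpose(-1, -2) + 5 * torch.eye(5, dtype=torch.float64)  # batch of SPD matrices
assert torch.allclose(blogdet(A), torch.logdet(A))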
def log_prob(self, X):
    logdetX = torch.logdet(X)
    expon = torch.trace(torch.matmul(self.W_inv, X))
    return self.C + 0.5 * (self.df - self.p - 1) * logdetX - 0.5 * expon