import torch
from torch.distributions import LowRankMultivariateNormal


class LRGaussian(Inference):

    def __init__(self, base, base_args, base_kwargs, var_clamp=1e-6):
        super(LRGaussian, self).__init__()
        self.var_clamp = var_clamp
        self.dist = None

    def fit(self, mean, variance, cov_factor):
        # ensure variance >= var_clamp
        variance = torch.clamp(variance, self.var_clamp)
        # form a low rank (+ diagonal) Gaussian distribution when fitting
        self.dist = LowRankMultivariateNormal(loc=mean, cov_diag=variance,
                                              cov_factor=cov_factor.t())

    def sample(self, scale=0.5, seed=None):
        if seed is not None:
            torch.manual_seed(seed)
        # x = \mu + L'z
        unscaled_sample = self.dist.rsample()
        # x' = \sqrt(scale) * (x - \mu) + \mu
        scaled_sample = (scale ** 0.5) * (unscaled_sample - self.dist.loc) + self.dist.loc
        return scaled_sample

    def log_prob(self, sample):
        return self.dist.log_prob(sample)
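# Minimal usage sketch for the LRGaussian posterior above (not from the
# original source): it assumes the `Inference` base class and the imports
# above are available, and uses hypothetical sizes. Note that `fit`
# transposes `cov_factor`, so it is passed here as a (rank x d) matrix.
d, rank = 10, 3
posterior = LRGaussian(base=None, base_args=(), base_kwargs={})
posterior.fit(
    mean=torch.zeros(d),               # posterior mean over d parameters
    variance=torch.rand(d),            # diagonal variance (clamped to >= var_clamp)
    cov_factor=torch.randn(rank, d),   # low-rank deviation matrix
)
theta = posterior.sample(scale=0.5, seed=0)   # deviations shrunk by sqrt(0.5)
logp = posterior.log_prob(theta)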
def forward(self, x, return_latent_rec=False):
    """
    Send `x` round trip and compute a loss.

    In more detail: Given `x`, compute :math:`q(z|x)` and sample:
    :math:`\hat{z} \sim q(z|x)`. Then compute :math:`\log p(x|\hat{z})`,
    the log-likelihood of `x`, the input, given :math:`\hat{z}`, the latent
    sample. We will also need the likelihood of :math:`\hat{z}` under the
    model's prior: :math:`p(\hat{z})`, and the entropy of the latent
    conditional distribution, :math:`\mathbb{H}[q(z|x)]`. ELBO can then be
    estimated as:

    .. math:: \\frac{1}{N} \sum_{i=1}^N \mathbb{E}_{\hat{z} \sim q(z|x_i)}
        \log p(x_i,\hat{z}) + \mathbb{H}[q(z|x_i)]

    where :math:`N` denotes the number of samples from the data distribution
    and the expectation is estimated using a single latent sample,
    :math:`\hat{z}`. In practice, the outer expectation is estimated using
    minibatches.

    Parameters
    ----------
    x : torch.Tensor
        A batch of samples from the data distribution (spectrograms).
        Shape: ``[batch_size, height=128, width=128]``
    return_latent_rec : bool, optional
        Whether to return latent means and reconstructions. Defaults to
        ``False``.

    Returns
    -------
    loss : torch.Tensor
        Negative ELBO times the batch size. Shape: ``[]``
    latent : numpy.ndarray, if `return_latent_rec`
        Latent means. Shape: ``[batch_size, self.z_dim]``
    reconstructions : numpy.ndarray, if `return_latent_rec`
        Reconstructed means. Shape: ``[batch_size, height=128, width=128]``
    """
    mu, u, d = self.encode(x)
    latent_dist = LowRankMultivariateNormal(mu, u, d)
    z = latent_dist.rsample()
    x_rec = self.decode(z)
    # E_{q(z|x)} p(z)
    elbo = -0.5 * (torch.sum(torch.pow(z, 2)) + self.z_dim * np.log(2 * np.pi))
    # E_{q(z|x)} p(x|z)
    pxz_term = -0.5 * X_DIM * (np.log(2 * np.pi / self.model_precision))
    l2s = torch.sum(torch.pow(x.view(x.shape[0], -1) - x_rec, 2), dim=1)
    pxz_term = pxz_term - 0.5 * self.model_precision * torch.sum(l2s)
    elbo = elbo + pxz_term
    # H[q(z|x)]
    elbo = elbo + torch.sum(latent_dist.entropy())
    if return_latent_rec:
        return -elbo, z.detach().cpu().numpy(), \
            x_rec.view(-1, X_SHAPE[0], X_SHAPE[1]).detach().cpu().numpy()
    return -elbo
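# Standalone sketch (not from the original source) of the latent posterior
# construction used in the VAE forward pass above, with hypothetical sizes, to
# make the LowRankMultivariateNormal parameterization explicit:
# covariance = u @ u.transpose(-1, -2) + diag_embed(d).
import torch
from torch.distributions import LowRankMultivariateNormal

batch, z_dim = 5, 8
mu = torch.zeros(batch, z_dim)       # posterior means from the encoder
u = torch.randn(batch, z_dim, 1)     # rank-1 covariance factor
d = torch.rand(batch, z_dim) + 0.1   # positive diagonal term

latent_dist = LowRankMultivariateNormal(mu, u, d)
z = latent_dist.rsample()            # reparameterized sample, shape [batch, z_dim]
entropy = latent_dist.entropy()      # one entropy per batch element, summed in the ELBO
cov = latent_dist.covariance_matrix  # u @ u^T + diag(d) for each batch element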
def _forward(self, Y: torch.Tensor, X: torch.Tensor, design: torch.Tensor) -> torch.Tensor:
    """One forward pass.

    :param Y: a sample from the dataset
    :param X: normalized sample data
    :param design: the corresponding row of the design matrix
    :return: the cost (elbo) of the current pass
    """
    if self._dset is None:
        raise Exception("the dataset is not provided")
    G = self._dset.get_n_features()
    C = self._dset.get_n_classes()
    N = Y.shape[0]
    Y_spread = Y.view(N, 1, G).repeat(1, C + 1, 1)

    delta_tilde = torch.exp(self._variables["log_delta"])
    mean = delta_tilde * self._data["rho"]
    mean2 = torch.mm(design, self._variables["mu"].T)  # N x P * P x G
    mean2 = mean2.view(-1, G, 1).repeat(1, 1, C + 1)
    mean = mean + mean2

    # now do the variance modelling
    p = torch.sigmoid(self._variables["p"])
    sigma = torch.exp(self._variables["log_sigma"])
    v1 = (self._data["rho"] * p).T * sigma
    v2 = torch.pow(sigma, 2) * (1 - torch.pow(self._data["rho"] * p, 2)).T
    v1 = v1.view(1, C + 1, G, 1).repeat(N, 1, 1, 1)  # extra 1 is the "rank"
    v2 = v2.view(1, C + 1, G).repeat(N, 1, 1) + 1e-6

    dist = LowRankMultivariateNormal(loc=torch.exp(mean).permute(0, 2, 1),
                                     cov_factor=v1, cov_diag=v2)
    log_p_y_on_c = dist.log_prob(Y_spread)

    gamma, log_gamma = self._recog.forward(X)
    log_alpha = F.log_softmax(self._variables["alpha_logits"], dim=0)
    alpha = F.softmax(self._variables["alpha_logits"], dim=0)
    mix_prior = self._alpha_prior.log_prob(alpha)

    elbo = (gamma * (log_p_y_on_c + log_alpha - log_gamma)).sum() + mix_prior
    return -elbo
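# Shape sketch (not from the original source, hypothetical sizes) for the
# batched rank-1 distribution built in `_forward` above: batch_shape is
# [N, C + 1], event_shape is [G], cov_factor carries a trailing rank
# dimension of 1, and log_prob over Y_spread yields one log-density per
# (sample, class) pair.
import torch
from torch.distributions import LowRankMultivariateNormal

N, C, G = 4, 3, 6
loc = torch.zeros(N, C + 1, G)
cov_factor = torch.ones(N, C + 1, G, 1)   # rank-1 factor per (sample, class)
cov_diag = torch.ones(N, C + 1, G)        # strictly positive diagonal
dist = LowRankMultivariateNormal(loc, cov_factor, cov_diag)

Y_spread = torch.zeros(N, C + 1, G)       # each sample repeated across classes
log_p_y_on_c = dist.log_prob(Y_spread)    # shape [N, C + 1]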
def forward(self):
    # covariances
    ZZ, ZX, tr = self.covariances()
    noise = positive(self._noise)
    # trace term
    Q, _, ridge = low_rank_factor(ZZ, ZX)
    trace = 0.5 * (tr - torch.einsum('ij,ij', Q, Q)) / noise**2
    # low rank MVN
    p = LowRankMultivariateNormal(self.zeros, Q.t(), self.ones * noise**2)
    # loss
    loss = -p.log_prob(self.Y) + trace
    return loss
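# Hypothetical sketch (not from the original source) of what `low_rank_factor`
# is assumed to return in the sparse-GP forward pass above: a matrix Q with
# Q.T @ Q equal to the Nystrom approximation K_xz K_zz^{-1} K_zx, so that
# LowRankMultivariateNormal(zeros, Q.t(), ones * noise**2) has covariance
# Q.T @ Q + noise**2 * I and the trace term penalizes the discarded diagonal.
import torch

def rbf(a, b, lengthscale=1.0):
    # squared-exponential kernel between the rows of a and b
    d2 = (a.unsqueeze(-2) - b.unsqueeze(-3)).pow(2).sum(-1)
    return torch.exp(-0.5 * d2 / lengthscale**2)

X = torch.randn(50, 1)                  # training inputs
Z = torch.randn(10, 1)                  # inducing inputs
ZZ = rbf(Z, Z) + 1e-6 * torch.eye(10)   # K_zz with jitter
ZX = rbf(Z, X)                          # K_zx
diag = torch.ones(50)                   # diag(K_xx); equal to 1 for an RBF kernel

L = torch.linalg.cholesky(ZZ)
Q = torch.linalg.solve_triangular(L, ZX, upper=False)  # Q = L^{-1} K_zx
trace_gap = diag.sum() - torch.einsum('ij,ij', Q, Q)   # tr(K_xx - Q^T Q) >= 0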
def test_lowrank_multivariate_normal() -> None:
    num_samples = 2000
    dim = 4
    rank = 3

    loc = np.arange(0, dim) / float(dim)
    cov_diag = np.eye(dim) * (np.arange(dim) / dim + 0.5)
    cov_factor = np.sqrt(np.ones((dim, rank)) * 0.2)
    Sigma = cov_factor @ cov_factor.T + cov_diag

    distr = LowRankMultivariateNormal(
        loc=torch.Tensor(loc.copy()),
        cov_diag=torch.Tensor(np.diag(cov_diag).copy()),
        cov_factor=torch.Tensor(cov_factor.copy()),
    )

    assert np.allclose(
        distr.covariance_matrix.numpy(), Sigma, atol=0.1, rtol=0.1
    ), f"did not match: sigma = {Sigma}, sigma_hat = {distr.covariance_matrix.numpy()}"

    samples = distr.sample((num_samples,))

    loc_hat, cov_factor_hat, cov_diag_hat = maximum_likelihood_estimate_sgd(
        LowRankMultivariateNormalOutput(
            dim=dim, rank=rank, sigma_init=0.2, sigma_minimum=0.0
        ),
        samples,
        learning_rate=0.01,
        num_epochs=10,
    )

    distr = LowRankMultivariateNormal(
        loc=torch.Tensor(loc_hat),
        cov_diag=torch.Tensor(cov_diag_hat),
        cov_factor=torch.Tensor(cov_factor_hat),
    )
    Sigma_hat = distr.covariance_matrix.numpy()

    assert np.allclose(
        loc_hat, loc, atol=0.2, rtol=0.1
    ), f"mu did not match: loc = {loc}, loc_hat = {loc_hat}"
    assert np.allclose(
        Sigma_hat, Sigma, atol=0.1, rtol=0.1
    ), f"sigma did not match: sigma = {Sigma}, sigma_hat = {Sigma_hat}"
def mll(self):
    """Titsias's 2009 lower bound."""
    a = self.active
    q = (self.K[a][:, a].cholesky().inverse() @ self.K[a]).T
    f = (LowRankMultivariateNormal(self.loc, q, self.sigma**2 * self.ones).log_prob(self.y)
         - 0.5 * (self.diag[~a].sum() - (q[~a]**2).sum()) / self.sigma**2)
    return f
def forward(self):
    # covariances
    ZZ, ZX, diag, Y = self.matrices()
    tr = diag.sum()
    noise = positive(self._noise)
    # trace term
    Q, _, ridge = low_rank_factor(ZZ, ZX)
    trace = 0.5 * (tr - torch.einsum('ij,ij', Q, Q)) / noise**2
    # low rank MVN
    p = LowRankMultivariateNormal(torch.zeros_like(Y), Q.t(), torch.ones_like(Y) * noise**2)
    # loss
    loss = -p.log_prob(Y) + trace
    return loss
def forward(self, x, op='func'):
    self.covariance_loss = 0
    operation = {'func': 'func', 'grad': 'gradgrad'}
    if hasattr(self.cov, 'cov_factor'):
        Q, diag, self.covariance_loss = self.cov.cov_factor(x, operation=operation[op])
        return LowRankMultivariateNormal(self.mean(x, operation=op).view(-1), Q, diag)
    else:
        return MultivariateNormal(self.mean(x, operation=op).view(-1),
                                  covariance_matrix=self.cov(x, operation=operation[op]))
def dist(self):
    return LowRankMultivariateNormal(self._mean, self._w, self._log_std.exp())