Example #1
class LRGaussian(Inference):
    def __init__(self, base, base_args, base_kwargs, var_clamp=1e-6):
        super(LRGaussian, self).__init__()
        self.var_clamp = var_clamp
        self.dist = None

    def fit(self, mean, variance, cov_factor):
        # ensure variance >= var_clamp
        variance = torch.clamp(variance, self.var_clamp)

        # form a low rank (+ diagonal Gaussian) distribution when fitting
        self.dist = LowRankMultivariateNormal(loc=mean,
                                              cov_diag=variance,
                                              cov_factor=cov_factor.t())

    def sample(self, scale=0.5, seed=None):
        if seed is not None:
            torch.manual_seed(seed)

        # x = \mu + L'z
        unscaled_sample = self.dist.rsample()

        # x' = \sqrt(scale) * (x - \mu) + \mu
        scaled_sample = (scale**0.5) * (unscaled_sample -
                                        self.dist.loc) + self.dist.loc

        return scaled_sample

    def log_prob(self, sample):
        return self.dist.log_prob(sample)
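
For orientation, here is a minimal self-contained sketch of the fit/sample round trip that LRGaussian performs above; the dimension, rank, and scale value are illustrative assumptions, not taken from the original project.

import torch
from torch.distributions import LowRankMultivariateNormal

d, rank = 10, 3                      # parameter dimension and covariance rank (illustrative)
mean = torch.randn(d)                # mean of the approximate posterior
variance = torch.rand(d) * 1e-3      # diagonal variance estimate
cov_factor = torch.randn(rank, d)    # low-rank deviation matrix, one row per component

# mirror LRGaussian.fit: clamp the diagonal, then build the distribution
variance = torch.clamp(variance, 1e-6)
dist = LowRankMultivariateNormal(loc=mean, cov_factor=cov_factor.t(), cov_diag=variance)

# mirror LRGaussian.sample with scale=0.5: draw x ~ N(mu, Sigma), then shrink toward the mean
scale = 0.5
x = dist.rsample()
x_scaled = scale ** 0.5 * (x - dist.loc) + dist.loc
print(dist.log_prob(x_scaled).item())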
Example #2
    def fit(self, mean, variance, cov_factor):
        # ensure variance >= var_clamp
        variance = torch.clamp(variance, self.var_clamp)

        # form a low rank (+ diagonal Gaussian) distribution when fitting
        self.dist = LowRankMultivariateNormal(loc=mean,
                                              cov_diag=variance,
                                              cov_factor=cov_factor.t())
Example #3
    def forward(self, x, return_latent_rec=False):
        r"""
		Send `x` round trip and compute a loss.

		In more detail: Given `x`, compute :math:`q(z|x)` and sample:
		:math:`\hat{z} \sim q(z|x)` . Then compute :math:`\log p(x|\hat{z})`,
		the log-likelihood of `x`, the input, given :math:`\hat{z}`, the latent
		sample. We will also need the likelihood of :math:`\hat{z}` under the
		model's prior: :math:`p(\hat{z})`, and the entropy of the latent
		conditional distribution, :math:`\mathbb{H}[q(z|x)]` . ELBO can then be
		estimated as:

		.. math:: \frac{1}{N} \sum_{i=1}^N \mathbb{E}_{\hat{z} \sim q(z|x_i)}
			\log p(x_i,\hat{z}) + \mathbb{H}[q(z|x_i)]

		where :math:`N` denotes the number of samples from the data distribution
		and the expectation is estimated using a single latent sample,
		:math:`\hat{z}`. In practice, the outer expectation is estimated using
		minibatches.

		Parameters
		----------
		x : torch.Tensor
			A batch of samples from the data distribution (spectrograms).
			Shape: ``[batch_size, height=128, width=128]``
		return_latent_rec : bool, optional
			Whether to return latent means and reconstructions. Defaults to
			``False``.

		Returns
		-------
		loss : torch.Tensor
			Negative ELBO times the batch size. Shape: ``[]``
		latent : numpy.ndarray, if `return_latent_rec`
			Latent means. Shape: ``[batch_size, self.z_dim]``
		reconstructions : numpy.ndarray, if `return_latent_rec`
			Reconstructed means. Shape: ``[batch_size, height=128, width=128]``
		"""
        mu, u, d = self.encode(x)
        latent_dist = LowRankMultivariateNormal(mu, u, d)
        z = latent_dist.rsample()
        x_rec = self.decode(z)
        # E_{q(z|x)} log p(z)
        elbo = -0.5 * (torch.sum(torch.pow(z, 2)) +
                       self.z_dim * np.log(2 * np.pi))
        # E_{q(z|x)} log p(x|z)
        pxz_term = -0.5 * X_DIM * (np.log(2 * np.pi / self.model_precision))
        l2s = torch.sum(torch.pow(x.view(x.shape[0], -1) - x_rec, 2), dim=1)
        pxz_term = pxz_term - 0.5 * self.model_precision * torch.sum(l2s)
        elbo = elbo + pxz_term
        # H[q(z|x)]
        elbo = elbo + torch.sum(latent_dist.entropy())
        if return_latent_rec:
            return (-elbo, z.detach().cpu().numpy(),
                    x_rec.view(-1, X_SHAPE[0], X_SHAPE[1]).detach().cpu().numpy())
        return -elbo
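
As a stand-alone illustration of the latent distribution built in the forward pass above: the encoder is assumed to return a mean mu, a low-rank factor u of shape [batch, z_dim, rank], and a positive diagonal d. The concrete sizes, including the rank of 1, are made up here.

import torch
from torch.distributions import LowRankMultivariateNormal

batch_size, z_dim, rank = 5, 32, 1           # rank-1 factor is an assumption
mu = torch.randn(batch_size, z_dim)          # posterior means
u = torch.randn(batch_size, z_dim, rank)     # cov_factor: [batch, z_dim, rank]
d = torch.rand(batch_size, z_dim) + 1e-3     # cov_diag: positive, [batch, z_dim]

latent_dist = LowRankMultivariateNormal(mu, u, d)
z = latent_dist.rsample()                    # [batch, z_dim], reparameterized
entropy = latent_dist.entropy()              # [batch], the H[q(z|x)] term in the ELBO above
print(z.shape, entropy.shape)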
Example #4
    def _forward(self, Y: torch.Tensor, X: torch.Tensor,
                 design: torch.Tensor) -> torch.Tensor:
        """One forward pass.

        :param Y: a sample from the dataset
        :param X: normalized sample data
        :param design: the corresponding row of design matrix
        :return: the cost (elbo) of the current pass
        """
        if self._dset is None:
            raise Exception("the dataset is not provided")
        G = self._dset.get_n_features()
        C = self._dset.get_n_classes()
        N = Y.shape[0]

        Y_spread = Y.view(N, 1, G).repeat(1, C + 1, 1)

        delta_tilde = torch.exp(self._variables["log_delta"])
        mean = delta_tilde * self._data["rho"]
        mean2 = torch.mm(design, self._variables["mu"].T)  # (N x P) @ (P x G) -> (N x G)
        mean2 = mean2.view(-1, G, 1).repeat(1, 1, C + 1)
        mean = mean + mean2

        # now do the variance modelling
        p = torch.sigmoid(self._variables["p"])

        sigma = torch.exp(self._variables["log_sigma"])
        v1 = (self._data["rho"] * p).T * sigma
        v2 = torch.pow(sigma, 2) * (1 - torch.pow(self._data["rho"] * p, 2)).T

        # the trailing dimension of size 1 is the rank of the covariance factor
        v1 = v1.view(1, C + 1, G, 1).repeat(N, 1, 1, 1)
        v2 = v2.view(1, C + 1, G).repeat(N, 1, 1) + 1e-6

        dist = LowRankMultivariateNormal(loc=torch.exp(mean).permute(0, 2, 1),
                                         cov_factor=v1,
                                         cov_diag=v2)

        log_p_y_on_c = dist.log_prob(Y_spread)

        gamma, log_gamma = self._recog.forward(X)
        log_alpha = F.log_softmax(self._variables["alpha_logits"], dim=0)
        alpha = F.softmax(self._variables["alpha_logits"], dim=0)
        mix_prior = self._alpha_prior.log_prob(alpha)

        elbo = (gamma *
                (log_p_y_on_c + log_alpha - log_gamma)).sum() + mix_prior

        return -elbo
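
A reduced sketch of the batched, rank-1 construction used above: cov_factor carries a trailing dimension of size 1, the batch shape covers (sample, component) pairs, and log_prob scores each observation under every component at once. The sizes below are illustrative, not taken from the original model.

import torch
from torch.distributions import LowRankMultivariateNormal

N, C, G = 8, 4, 20                       # samples, mixture components, features (made up)
loc = torch.rand(N, C, G) + 0.1          # one mean vector per (sample, component)
cov_factor = torch.rand(N, C, G, 1)      # rank-1 factor; the trailing dim is the rank
cov_diag = torch.rand(N, C, G) + 1e-6    # strictly positive diagonal

dist = LowRankMultivariateNormal(loc=loc, cov_factor=cov_factor, cov_diag=cov_diag)

Y = torch.rand(N, G)
Y_spread = Y.view(N, 1, G).expand(N, C, G)   # score each observation under every component
log_p_y_on_c = dist.log_prob(Y_spread)       # shape [N, C]
print(log_p_y_on_c.shape)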
Example #5
    def forward(self):

        # covariances
        ZZ, ZX, tr = self.covariances()
        noise = positive(self._noise)

        # trace term
        Q, _, ridge = low_rank_factor(ZZ, ZX)
        trace = 0.5 * (tr - torch.einsum('ij,ij', Q, Q)) / noise**2

        # low rank MVN
        p = LowRankMultivariateNormal(self.zeros, Q.t(), self.ones * noise**2)

        # loss
        loss = -p.log_prob(self.Y) + trace
        return loss
Example #6
def test_lowrank_multivariate_normal() -> None:
    num_samples = 2000
    dim = 4
    rank = 3

    loc = np.arange(0, dim) / float(dim)
    cov_diag = np.eye(dim) * (np.arange(dim) / dim + 0.5)
    cov_factor = np.sqrt(np.ones((dim, rank)) * 0.2)
    Sigma = cov_factor @ cov_factor.T + cov_diag

    distr = LowRankMultivariateNormal(
        loc=torch.Tensor(loc.copy()),
        cov_diag=torch.Tensor(np.diag(cov_diag).copy()),
        cov_factor=torch.Tensor(cov_factor.copy()),
    )

    assert np.allclose(
        distr.covariance_matrix.numpy(), Sigma, atol=0.1, rtol=0.1
    ), f"did not match: sigma = {Sigma}, sigma_hat = {distr.covariance_matrix.numpy()}"

    samples = distr.sample((num_samples, ))

    loc_hat, cov_factor_hat, cov_diag_hat = maximum_likelihood_estimate_sgd(
        LowRankMultivariateNormalOutput(dim=dim,
                                        rank=rank,
                                        sigma_init=0.2,
                                        sigma_minimum=0.0),
        samples,
        learning_rate=0.01,
        num_epochs=10,
    )

    distr = LowRankMultivariateNormal(
        loc=torch.Tensor(loc_hat),
        cov_diag=torch.Tensor(cov_diag_hat),
        cov_factor=torch.Tensor(cov_factor_hat),
    )

    Sigma_hat = distr.covariance_matrix.numpy()

    assert np.allclose(
        loc_hat, loc, atol=0.2,
        rtol=0.1), f"mu did not match: loc = {loc}, loc_hat = {loc_hat}"

    assert np.allclose(
        Sigma_hat, Sigma, atol=0.1, rtol=0.1
    ), f"sigma did not match: sigma = {Sigma}, sigma_hat = {Sigma_hat}"
Example #7
 def mll(self):
     """Titsias's 2009 lower-bound"""
     a = self.active
     q = (self.K[a][:, a].cholesky().inverse() @ self.K[a]).T
     f = (LowRankMultivariateNormal(
         self.loc, q, self.sigma**2 * self.ones).log_prob(self.y) - 0.5 *
          (self.diag[~a].sum() - (q[~a]**2).sum()) / self.sigma**2)
     return f
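
Not part of the class above, but a useful sanity check on the construction: with q = (L^{-1} K[a])^T, where L is the Cholesky factor of K[a][:, a], the low-rank distribution has covariance K[:, a] @ K[a][:, a]^{-1} @ K[a] + sigma^2 * I, so its log_prob matches the dense MultivariateNormal. The toy kernel, the sizes, and the use of a triangular solve in place of an explicit inverse are assumptions of this sketch.

import torch
from torch.distributions import LowRankMultivariateNormal, MultivariateNormal

torch.manual_seed(0)
torch.set_default_dtype(torch.float64)

n, sigma = 50, 0.3
X = torch.linspace(0, 1, n).unsqueeze(-1)
K = torch.exp(-0.5 * (X - X.T) ** 2 / 0.1 ** 2)             # toy RBF kernel matrix
a = torch.zeros(n, dtype=torch.bool)
a[::5] = True                                               # "active" (inducing) points

L = torch.linalg.cholesky(K[a][:, a] + 1e-8 * torch.eye(int(a.sum())))
q = torch.linalg.solve_triangular(L, K[a], upper=False).T   # (n, m) low-rank factor

y = torch.randn(n)
low_rank = LowRankMultivariateNormal(torch.zeros(n), q, sigma**2 * torch.ones(n))
dense = MultivariateNormal(torch.zeros(n),
                           covariance_matrix=q @ q.T + sigma**2 * torch.eye(n))
print(torch.allclose(low_rank.log_prob(y), dense.log_prob(y)))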
Example #8
    def forward(self):

        # covariances
        ZZ, ZX, diag, Y = self.matrices()
        tr = diag.sum()
        noise = positive(self._noise)

        # trace term
        Q, _, ridge = low_rank_factor(ZZ, ZX)
        trace = 0.5 * (tr - torch.einsum('ij,ij', Q, Q)) / noise**2

        # low rank MVN
        p = LowRankMultivariateNormal(torch.zeros_like(Y), Q.t(),
                                      torch.ones_like(Y) * noise**2)

        # loss
        loss = -p.log_prob(Y) + trace
        return loss
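
A small aside on the trace term shared by this forward pass and Example #5 (toy values only): the einsum contraction is just the squared Frobenius norm of Q, i.e. tr(Q^T Q), the trace of the low-rank part of the covariance that gets subtracted from tr above.

import torch

Q = torch.randn(5, 12, dtype=torch.float64)    # made-up low-rank factor: rank 5, 12 points
frob = torch.einsum('ij,ij', Q, Q)
assert torch.allclose(frob, (Q ** 2).sum())
assert torch.allclose(frob, torch.trace(Q.t() @ Q))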
Example #9
 def forward(self, x, op='func'):
     self.covariance_loss = 0
     operation = {'func': 'func', 'grad': 'gradgrad'}
     if hasattr(self.cov, 'cov_factor'):
         Q, diag, self.covariance_loss = self.cov.cov_factor(
             x, operation=operation[op])
         return LowRankMultivariateNormal(
             self.mean(x, operation=op).view(-1), Q, diag)
     else:
         return MultivariateNormal(self.mean(x, operation=op).view(-1),
                                   covariance_matrix=self.cov(
                                       x, operation=operation[op]))
Example #10
 def dist(self):
     return LowRankMultivariateNormal(self._mean, self._w, self._log_std.exp())
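
A hedged usage sketch for a property like this, with made-up shapes: the returned distribution supports reparameterized sampling and log_prob directly. One caution from the torch API: cov_diag is the diagonal of the covariance (a variance), so if _log_std really stores log standard deviations, the diagonal would typically be exp(2 * log_std) rather than exp(log_std).

import torch
from torch.distributions import LowRankMultivariateNormal

dim, rank = 6, 2                           # illustrative sizes
mean = torch.zeros(dim)
w = torch.randn(dim, rank) * 0.1           # low-rank covariance factor
log_std = torch.full((dim,), -1.0)         # per-dimension log standard deviation (assumed)

dist = LowRankMultivariateNormal(mean, w, (2 * log_std).exp())   # exp(2*log_std) is the variance
sample = dist.rsample()                    # differentiable sample
logp = dist.log_prob(sample)
print(sample.shape, logp.shape)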