Exemplo n.º 1
0
 def forward(self, x):
     """Return the negative log-likelihood of a batch ``x``.

     Imputes and ilr-transforms the input, encodes it to a latent mean,
     decodes that mean back, and sums three log-probability terms:
     multinomial reconstruction, logit likelihood under a
     factor-plus-identity Gaussian, and the latent prior.
     """
     latent = self.encoder(ilr(self.imputer(x), self.Psi))
     recon_mean = self.decoder(latent)
     # Low-rank-plus-identity covariance parameters for the logit term.
     factor_diag = torch.exp(self.variational_logvars)
     noise_var = torch.exp(self.log_sigma_sq)
     qdist = MultivariateNormalFactorIdentity(
         recon_mean, noise_var, factor_diag, self.decoder.weight)
     logits = self.Psi.t() @ self.eta.t()
     # Sum the three terms in the same order as the original loss.
     loglike = (Multinomial(logits=logits.t()).log_prob(x).mean()
                + qdist.log_prob(self.eta).mean()
                + Normal(self.zm, self.zI).log_prob(latent).mean())
     return -loglike
Exemplo n.º 2
0
    def test_log_prob(self):
        """log_prob should agree with torch's dense MultivariateNormal."""
        mean = torch.ones(self.d)
        # Dense reference covariance: W D W^T + s2 * I.
        cov = self.W @ torch.diag(self.D) @ self.W.t() + self.s2 * self.Id
        reference = MultivariateNormal(mean, covariance_matrix=cov)
        draws = reference.rsample([10000])
        expected = reference.log_prob(draws)

        factorized = MultivariateNormalFactorIdentity(
            mean, self.s2, self.D, self.W)
        actual = factorized.log_prob(draws)

        self.assertAlmostEqual(float(expected.mean()),
                               float(actual.mean()),
                               places=3)
Exemplo n.º 3
0
 def test_log_det(self):
     """log_det should match torch.slogdet on the dense covariance."""
     dist = MultivariateNormalFactorIdentity(
         torch.zeros(self.d), self.s2, self.D, self.W)
     dense_cov = dist.covariance_matrix
     # slogdet returns (sign, logabsdet); compare against the log-det.
     tt.assert_allclose(dist.log_det, torch.slogdet(dense_cov)[1])
Exemplo n.º 4
0
 def test_covariance_matrix(self):
     """covariance_matrix should equal W D W^T + s2 * I with shape (d, d)."""
     expected = self.W @ torch.diag(self.D) @ self.W.t() + self.s2 * self.Id
     dist = MultivariateNormalFactorIdentity(
         torch.zeros(self.d), self.s2, self.D, self.W)
     actual = dist.covariance_matrix
     self.assertEqual(actual.shape, (self.d, self.d))
     tt.assert_allclose(expected, actual)
Exemplo n.º 5
0
 def test_precision_matrix(self):
     """precision_matrix should approximate the dense matrix inverse."""
     dist = MultivariateNormalFactorIdentity(
         torch.zeros(self.d), self.s2, self.D, self.W)
     dense = self.W @ torch.diag(self.D) @ self.W.t() + self.s2 * self.Id
     # Loose tolerances: inversion error is expected to grow with d.
     tt.assert_allclose(torch.inverse(dense),
                        dist.precision_matrix,
                        rtol=1,
                        atol=1 / (math.sqrt(self.d)))
Exemplo n.º 6
0
 def test_rsample(self):
     """rsample draws should average to the location vector (all ones)."""
     dist = MultivariateNormalFactorIdentity(
         torch.ones(self.d), self.s2, self.D, self.W)
     draws = dist.rsample([10000])
     self.assertAlmostEqual(float(draws.mean()), 1, places=2)