def test_diag_inv_speed(self):
    """linalg.diag_inv on a diagonal matrix should beat a general torch.inverse."""
    # Broadcasting randn(1000) across eye(1000) yields a diagonal matrix
    # with random diagonal entries.
    A = torch.eye(1000) * torch.randn(1000)

    # perf_counter is monotonic and high-resolution. time.time() is
    # wall-clock time that can jump and has coarse resolution on some
    # platforms, which makes short timings like these unreliable.
    start = time.perf_counter()
    torch.inverse(A)
    d1 = time.perf_counter() - start

    start = time.perf_counter()
    linalg.diag_inv(A)
    d2 = time.perf_counter() - start

    # According to my tests, d1 is roughly 2 orders of magnitude slower.
    # assertGreater reports both durations if the comparison fails.
    self.assertGreater(d1, d2)
def neg_log_likelihood(self, y, Lambda, Psi_diag):
    """Appendix A (p. 5) in Ghahramani and Hinton (1996).

    Sums the per-observation terms in an explicit loop over the columns of
    ``y``. Constant terms of the log likelihood are omitted.

    :param y: (p x n)-dimensional observations, one observation per column.
    :param Lambda: Current value for the (p x k) Lambda parameter.
    :param Psi_diag: Current value for the Psi parameter, given as the 1-D
        vector of its diagonal entries (presumably all positive, since Psi
        is a diagonal covariance -- confirm with callers).
    :return: The negative log likelihood of the parameters given y, as a
        Python float.
    """
    assert len(Psi_diag.shape) == 1, "Psi_diag must be a 1-D vector"
    k = self.latent_dim
    _, n = y.shape

    # Neither term depends on the observation, so compute them once
    # instead of once per column of y.
    Psi_inv = diag(LA.diag_inv(Psi_diag))
    Lambda_T_Psi_inv_Lambda = Lambda.t() @ Psi_inv @ Lambda

    rterm_sum = 0
    for yi in y.t():
        Ezi = self.E_z_given_y(Lambda, Psi_diag, yi)
        Ezzi = self.E_zzT_given_y(Lambda, Psi_diag, yi, k)
        A = 1/2. * yi @ Psi_inv @ yi
        B = yi @ Psi_inv @ Lambda @ Ezi
        C = 1/2. * tr(Lambda_T_Psi_inv_Lambda @ Ezzi)
        rterm_sum += A - B + C

    # For a diagonal Psi, log|Psi| is the sum of the logs of its diagonal.
    # Summing logs avoids det() underflowing to 0 (or overflowing to inf)
    # for large p, which would turn the log into -inf/inf.
    logdet = -n/2. * Psi_diag.log().sum()
    ll = (logdet - rterm_sum).item()
    return -ll
def neg_log_likelihood(self, y, Lambda=None, Psi_diag=None):
    """Appendix A (p. 5) in Ghahramani and Hinton (1996).

    Vectorized over all observations (columns of ``y``). Constant terms of
    the log likelihood are omitted.

    :param y: (p x n)-dimensional observations, one observation per column.
    :param Lambda: Current value for the (p x k) Lambda parameter. If both
        Lambda and Psi_diag are omitted, both come from
        ``self.tile_params()``.
    :param Psi_diag: Current value for the Psi parameter, given as the 1-D
        vector of its diagonal entries (presumably all positive, since Psi
        is a diagonal covariance -- confirm with callers).
    :return: The negative log likelihood of the parameters given y, as a
        Python float.
    :raises ValueError: if exactly one of Lambda and Psi_diag is given.
    """
    _, n = y.shape
    k = self.latent_dim
    if Lambda is None and Psi_diag is None:
        Lambda, Psi_diag = self.tile_params()
    elif Lambda is None or Psi_diag is None:
        # Previously this fell through with a None parameter and failed
        # later with an unrelated-looking error.
        raise ValueError(
            "Lambda and Psi_diag must be given together or both omitted.")

    Ez = self.E_z_given_y(Lambda, Psi_diag, y).t()
    Ezz = self.E_zzT_given_y(Lambda, Psi_diag, y, k).t()
    inv_Psi = diag(LA.diag_inv(Psi_diag))

    # sum(diag(y^T @ W)) == sum(y * W) for a (p x n) W. This gives the same
    # totals as summing the diagonals of the (n x n) products, without
    # materialising those n x n matrices.
    A = 1/2. * (y * (inv_Psi @ y)).sum()
    B = (y * (inv_Psi @ Lambda @ Ez.t())).sum()
    C = 1/2. * tr(Lambda.t() @ inv_Psi @ Lambda @ Ezz)
    rterm_sum = A - B + C

    # For a diagonal Psi, log|Psi| is the sum of the logs of its diagonal.
    # Summing logs avoids det() underflowing to 0 (or overflowing to inf)
    # for large p.
    logdet = -n/2. * Psi_diag.log().sum()
    ll = (logdet - rterm_sum).item()
    return -ll
def test_diag_inv_accuracy(self):
    """linalg.diag_inv of a diagonal matrix A should satisfy A @ A^-1 ~= I."""
    # A private, seeded generator makes the test reproducible without
    # touching the global RNG state that other tests may depend on.
    gen = torch.Generator().manual_seed(0)
    # Broadcasting randn(20) across eye(20) yields a diagonal matrix with
    # random diagonal entries.
    A = torch.eye(20) * torch.randn(20, generator=gen)
    A_inv = linalg.diag_inv(A)
    self.assertTrue(
        torch.allclose(A @ A_inv, torch.eye(len(A)), atol=0.01))