def em_step(self, y, Lambda, Psi_diag):
    """Run one EM iteration: Equations 5 and 6 in Ghahramani and Hinton (1996).

    :param y: Observations of shape (p, n): one column per sample, where
        ``p`` is the number of observed features and ``n`` the number of
        samples (``n`` is the divisor in the Psi update).
    :param Lambda: Current loading matrix Lambda (presumably (p, k) with
        ``k = self.latent_dim`` -- TODO confirm against callers).
    :param Psi_diag: Current diagonal of the noise covariance Psi.
    :return: Tuple ``(Lambda_new, Psi_diag_new)`` holding the maximum
        likelihood parameters for this step.
    """
    k = self.latent_dim
    # Only the sample count is needed. Unpacking still rejects a non-2-D y.
    _, n = y.shape

    # E-step: compute expected moments for latent variable z.
    # -------------------------------------------------------
    Ez = self.E_z_given_y(Lambda, Psi_diag, y)
    # Presumably the sum over samples of E[z z^T | y_i], as Equation 5
    # requires (k x k, since it is inverted below) -- TODO confirm.
    Ezz = self.E_zzT_given_y(Lambda, Psi_diag, y, k)

    # M-step: compute optimal Lambda and Psi.
    # ---------------------------------------
    # Equation 5: Lambda_new = (sum_i y_i E[z|y_i]^T) (sum_i E[zz^T|y_i])^{-1}.
    Lambda_new = LA.sum_outers(y, Ez) @ inv(Ezz)

    # Equation 6: Psi_new = 1/n diag(sum_i y_i y_i^T - Lambda_new E[z|y_i] y_i^T).
    # This must use Lambda_new (not the incoming Lambda), per G&H 1996.
    Psi_rterm = LA.sum_outers(y, y) - LA.sum_outers(Lambda_new @ Ez, y)
    Psi_diag_new = 1. / n * diag(Psi_rterm)

    return Lambda_new, Psi_diag_new
def test_Psi_rterm(self):
    """Check that the looped and vectorised Psi right-hand terms agree.

    Both versions use the fixture Lambda ``self.L``. The reference sums
    ``y_i y_i^T - (L E[z|y_i]) y_i^T`` one column of ``self.y`` at a time.
    The vectorised version goes through ``LA.sum_outers``.
    """
    # Reference result: accumulate column by column with the single-sample model.
    looped = torch.zeros(self.p, self.p)
    for obs in self.y.t():
        ez_obs = self.pccas.E_z_given_y(self.L, self.P_diag, obs)
        looped += outer(obs, obs) - outer(self.L @ ez_obs, obs)

    # Vectorised result: all columns at once with the batched model.
    ez_all = self.pccav.E_z_given_y(self.L, self.P_diag, self.y)
    yy_term = LA.sum_outers(self.y, self.y)
    lzy_term = LA.sum_outers(self.L @ ez_all, self.y)
    vectorised = yy_term - lzy_term

    self.assertTrue(torch.allclose(looped, vectorised, atol=0.01))
def test_Lambda_lterm(self):
    """Check that the looped and vectorised Lambda left-hand terms agree.

    The reference sums ``y_i E[z|y_i]^T`` one column of ``self.y`` at a
    time. The vectorised version goes through ``LA.sum_outers``.
    """
    # Reference result: (p, k) accumulator filled column by column.
    looped = torch.zeros(self.p, self.k)
    for obs in self.y.t():
        ez_obs = self.pccas.E_z_given_y(self.L, self.P_diag, obs)
        looped += outer(obs, ez_obs)

    # Vectorised result: all columns at once with the batched model.
    ez_all = self.pccav.E_z_given_y(self.L, self.P_diag, self.y)
    vectorised = LA.sum_outers(self.y, ez_all)

    agree = torch.allclose(looped, vectorised, atol=0.0001)
    self.assertTrue(agree)