Beispiel #1
0
    def em_step(self, y, Lambda, Psi_diag):
        """Run a single EM iteration and return the re-estimated parameters.

        Implements the updates of Equations 5 and 6 in Ghahramani and
        Hinton (1996).

        :param y:        Observations as a 2-D array. Its shape is unpacked
                         into two values and the second one normalises the
                         Psi update, so columns are presumably the individual
                         samples, i.e. (n_features, n_samples) -- TODO confirm.
        :param Lambda:   Current Lambda, used for the E-step moments.
        :param Psi_diag: Current Psi_diag, used for the E-step moments.
        :return:         Tuple ``(Lambda_new, Psi_diag_new)`` holding the
                         updated parameters for this step.
        """
        latent_dim = self.latent_dim
        # Two-name unpacking also rejects inputs that are not 2-D.
        _, n_obs = y.shape

        # E-step: posterior moments of the latent variable z given y.
        z_mean = self.E_z_given_y(Lambda, Psi_diag, y)
        zz_mean = self.E_zzT_given_y(Lambda, Psi_diag, y, latent_dim)

        # M-step, Lambda (Equation 5): summed outer products of y with E[z|y],
        # right-multiplied by the inverse of the second-moment term.
        Lambda_new = LA.sum_outers(y, z_mean) @ inv(zz_mean)

        # M-step, Psi (Equation 6): this has to be evaluated with the freshly
        # updated Lambda_new, not with the incoming Lambda.
        yy_sum = LA.sum_outers(y, y)
        recon_sum = LA.sum_outers(Lambda_new @ z_mean, y)
        Psi_diag_new = 1. / n_obs * diag(yy_sum - recon_sum)

        return Lambda_new, Psi_diag_new
Beispiel #2
0
    def test_Psi_rterm(self):
        """Compare the batched Psi right-hand term with a per-sample loop.

        The looped value is accumulated with ``self.pccas`` one slice of
        ``self.y.t()`` at a time; the batched value is computed with
        ``self.pccav`` and ``LA.sum_outers`` on the whole of ``self.y``.
        Both must agree within an absolute tolerance of 0.01.
        """
        # Reference value: accumulate in place, one observation per iteration.
        looped = torch.zeros(self.p, self.p)
        for y_col in self.y.t():
            z_col = self.pccas.E_z_given_y(self.L, self.P_diag, y_col)
            looped += outer(y_col, y_col) - outer(self.L @ z_col, y_col)

        # Batched value: same expression evaluated on the full data at once.
        z_all = self.pccav.E_z_given_y(self.L, self.P_diag, self.y)
        yy_sum = LA.sum_outers(self.y, self.y)
        recon_sum = LA.sum_outers(self.L @ z_all, self.y)
        batched = yy_sum - recon_sum

        agrees = torch.allclose(looped, batched, atol=0.01)
        self.assertTrue(agrees)
Beispiel #3
0
    def test_Lambda_lterm(self):
        """Compare the batched Lambda left-hand term with a per-sample loop.

        The looped value sums ``outer(yi, E[z|yi])`` over the slices of
        ``self.y.t()`` using ``self.pccas``; the batched value comes from
        ``LA.sum_outers`` applied to ``self.y`` and the ``self.pccav``
        expectation. Both must agree within an absolute tolerance of 0.0001.
        """
        # Reference value: accumulate in place, one observation per iteration.
        looped = torch.zeros(self.p, self.k)
        for y_col in self.y.t():
            z_col = self.pccas.E_z_given_y(self.L, self.P_diag, y_col)
            looped += outer(y_col, z_col)

        # Batched value: the same sum computed from the full data at once.
        z_all = self.pccav.E_z_given_y(self.L, self.P_diag, self.y)
        batched = LA.sum_outers(self.y, z_all)

        agrees = torch.allclose(looped, batched, atol=0.0001)
        self.assertTrue(agrees)