Example #1
    def test_diag_inv_speed(self):
        # A is diagonal: random entries on the diagonal, zeros elsewhere.
        A = torch.eye(1000) * torch.randn(1000)

        s1 = time.time()
        torch.inverse(A)        # dense inverse, ignores the diagonal structure
        d1 = time.time() - s1

        s2 = time.time()
        linalg.diag_inv(A)      # diagonal-aware inverse
        d2 = time.time() - s2

        # In my tests, the dense inverse (d1) is roughly two orders of
        # magnitude slower than the diagonal-aware one (d2).
        self.assertTrue(d1 > d2)
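The helper linalg.diag_inv itself is not shown in these examples. The speed gap measured above is what you would expect if it inverts only the diagonal instead of running a full dense inversion. A minimal sketch of such a helper, assuming it accepts either a 2-D diagonal matrix or a 1-D tensor of diagonal entries (that signature is an assumption, not the actual library API):

    import torch

    def diag_inv(A):
        # Hypothetical sketch; the real linalg.diag_inv used in these
        # examples may differ. Inverting a diagonal matrix only requires
        # elementwise reciprocals of its diagonal.
        if A.dim() == 2:
            return torch.diag(1.0 / torch.diagonal(A))
        return 1.0 / A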
Example #2
    def neg_log_likelihood(self, y, Lambda, Psi_diag):
        """Appendix A (p. 5) in Ghahramani and Hinton (1996).

        :param y:        (p x n)-dimensional observations; each column is one observation.
        :param Lambda:   Current value for Lambda parameter.
        :param Psi_diag: Current value for Psi parameter.
        :return:         The negative log likelihood of the parameters given y.
        """
        assert len(Psi_diag.shape) == 1
        k = self.latent_dim
        p, n = y.shape
        # Psi is diagonal, so its inverse is cheap; compute it once, outside
        # the loop.
        Psi_inv = diag(LA.diag_inv(Psi_diag))
        rterm_sum = 0

        for yi in y.t():
            # Posterior moments of the latents given this observation.
            Ezi  = self.E_z_given_y(Lambda, Psi_diag, yi)
            Ezzi = self.E_zzT_given_y(Lambda, Psi_diag, yi, k)

            # Per-observation terms of the expected log-likelihood
            # (Appendix A): quadratic term, cross term, trace term.
            A = 1/2. * yi @ Psi_inv @ yi
            B = yi @ Psi_inv @ Lambda @ Ezi
            C = 1/2. * tr(Lambda.t() @ Psi_inv @ Lambda @ Ezzi)
            rterm_sum += A - B + C

        # -n/2 * log|Psi|; Psi is diagonal, so its determinant is the product
        # of the entries of Psi_diag.
        logdet = -n/2. * log(det(diag(Psi_diag)))

        ll = (logdet - rterm_sum).item()
        nll = -ll
        return nll
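Spelled out, the quantity this method returns (dropping the additive constant that does not depend on Lambda or Psi) mirrors the expected log-likelihood in Appendix A, negated:

    \mathrm{NLL}(\Lambda, \Psi)
        = \frac{n}{2}\,\log\lvert\Psi\rvert
        + \sum_{i=1}^{n} \Big(
              \tfrac{1}{2}\, y_i^{\top} \Psi^{-1} y_i
            - y_i^{\top} \Psi^{-1} \Lambda \,\mathrm{E}[z_i \mid y_i]
            + \tfrac{1}{2}\,\operatorname{tr}\!\big(\Lambda^{\top} \Psi^{-1} \Lambda \,\mathrm{E}[z_i z_i^{\top} \mid y_i]\big)
          \Big)

A, B, and C in the loop are the three summands, and logdet is the first term with its sign flipped.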
Example #3
    def neg_log_likelihood(self, y, Lambda=None, Psi_diag=None):
        """Appendix A (p. 5) in Ghahramani and Hinton (1996).

        :param y:        (p x n)-dimensional observations; each column is one observation.
        :param Lambda:   Current value for Lambda parameter.
        :param Psi_diag: Current value for Psi parameter.
        :return:         The negative log likelihood of the parameters given y.
        """
        p, n = y.shape
        k = self.latent_dim

        if Lambda is None and Psi_diag is None:
            Lambda, Psi_diag = self.tile_params()

        # Posterior moments of the latents for all n observations at once.
        Ez  = self.E_z_given_y(Lambda, Psi_diag, y).t()
        Ezz = self.E_zzT_given_y(Lambda, Psi_diag, y, k).t()

        inv_Psi = diag(LA.diag_inv(Psi_diag))

        # Vectorized counterparts of the per-observation terms in Example #2:
        # A and B collect one value per observation along the diagonals, and
        # C handles the trace term in a single product.
        A = 1/2. * diag(y.t() @ inv_Psi @ y)
        B = diag(y.t() @ inv_Psi @ Lambda @ Ez.t())
        C = 1/2. * tr(Lambda.t() @ inv_Psi @ Lambda @ Ezz)
        rterm_sum = (A - B).sum() + C

        logdet = -n/2. * log(det(diag(Psi_diag)))

        ll = (logdet - rterm_sum).item()
        nll = -ll
        return nll
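The vectorization in Example #3 leans on the fact that the diagonal of Y^T M Y collects the per-column quadratic forms y_i^T M y_i, so one matrix product replaces the per-observation loop of Example #2. A standalone check of that identity (variable names here are illustrative, not taken from the examples):

    import torch

    p, n = 4, 6
    Y = torch.randn(p, n)                  # columns are observations
    M = torch.diag(1.0 / torch.rand(p))    # stand-in for a diagonal Psi^{-1}

    # Diagonal of one big product vs. an explicit per-observation loop.
    vectorized = torch.diag(Y.t() @ M @ Y)
    looped = torch.stack([yi @ M @ yi for yi in Y.t()])
    assert torch.allclose(vectorized, looped)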
Example #4
    def test_diag_inv_accuracy(self):
        # A is diagonal, so diag_inv(A) should return its true inverse.
        A = torch.eye(20) * torch.randn(20)
        A_inv = linalg.diag_inv(A)
        self.assertTrue(
            torch.allclose(A @ A_inv, torch.eye(len(A)), atol=0.01))