def test__compute_logdet_array(self):
    """Compare spherical and diagonal log-determinants with fixture values.

    A mixture is built with a DCT transform plus the fixture's
    ``D_indices`` and ``mask``. ``_compute_logdet_array`` is evaluated on
    the fixture's spherical and diagonal covariances, and each result is
    checked against the precomputed ``correct_logdet_*`` array on
    ``self.td``.
    """
    td = self.td
    model = GaussianMixture(
        n_components=3,
        num_feat_full=5,
        num_feat_comp=3,
        num_feat_shared=1,
        num_samp=4,
        transform='dct',
        D_indices=td.D_indices,
        mask=td.mask,
    )
    # Evaluate both covariance types first (spherical, then diag) and only
    # afterwards run the assertions.
    spherical_result, diag_result = (
        model._compute_logdet_array(td.spherical_covariances, 'spherical'),
        model._compute_logdet_array(td.diagonal_covariances, 'diag'),
    )
    self.assertArrayEqual(td.correct_logdet_spherical, spherical_result)
    self.assertArrayEqual(td.correct_logdet_diag, diag_result)
    def test__compute_logdet_array_spherical(self):
        """Check the spherical logdet under compression against a value
        computed here.

        Redundant with test__compute_logdet_array, but was implemented to
        confirm that test is correct. With no transform, mask or D_indices
        and num_feat_comp < num_feat_full, the expected logdet of each
        component is ``num_feat_comp * log(variance)``, repeated for every
        one of the ``num_samp`` samples.
        """
        cov_type = 'spherical'
        rs = np.random.RandomState(10)
        gmm = GaussianMixture(n_components=3,
                              num_feat_full=5,
                              num_feat_comp=3,
                              num_feat_shared=2,
                              num_samp=4,
                              transform=None,
                              mask=None,
                              D_indices=None,
                              covariance_type=cov_type,
                              random_state=rs)
        gmm.fit_sparsifier(X=self.td.X)
        # One random variance in [0, 1) per component. NOTE: ``rs`` is also
        # the model's random_state, so fit_sparsifier may advance it first.
        covariances = rs.rand(gmm.n_components)

        logdet_test = gmm._compute_logdet_array(covariances, cov_type)
        # log det(c * I_k) = k * log(c) with k = num_feat_comp; the same
        # value applies to every sample, hence one tiled row per sample.
        logdet_true = gmm.num_feat_comp * np.log(covariances)
        logdet_true = np.tile(logdet_true, (gmm.num_samp, 1))
        self.assertArrayEqual(logdet_test, logdet_true)