Example #1
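Both test methods below are excerpted from a unittest test class; here is a minimal sketch of the module context they assume (the import path of the functions under test is not shown in these excerpts, so it is left as a comment rather than guessed):

import unittest  # the methods below are defined on a unittest.TestCase subclass

import numpy as np
import torch
import torch.distributions as dist
from sklearn.mixture import GaussianMixture

# log_normal and log_responsibilities are the functions under test; their module
# path is not given in these excerpts, so import them from the package being tested.
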
    def test_log_normal_spherical(self):
        """
        Test the log-normal probabilities for spherical covariance.
        """
        N = 100
        S = 50
        D = 10

        means = torch.randn(S, D)
        covs = torch.rand(S)
        x = torch.randn(N, D)

        distributions = [
            dist.MultivariateNormal(means[i],
                                    torch.diag(covs[i].clone().expand(D)))
            for i in range(S)
        ]

        expected = []
        for item in x:
            e_item = []
            for d in distributions:
                e_item.append(d.log_prob(item))
            expected.append(e_item)
        expected = torch.as_tensor(expected)

        predicted = log_normal(x, means, covs, 'spherical')

        self.assertTrue(
            torch.allclose(expected, predicted, atol=1e-03, rtol=1e-05))
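For reference, the expected values in this test are simply the log density of each of the N points under each of the S spherical Gaussians, arranged as an (N, S) tensor. The vectorised sketch below shows that quantity directly; it is an assumption about log_normal's contract for covariance_type='spherical', not its actual implementation, and spherical_log_normal is a hypothetical helper name:

import math

import torch

def spherical_log_normal(x, means, covs):
    # x: (N, D) data, means: (S, D) component means, covs: (S,) one variance per component.
    # Returns an (N, S) tensor whose entry [n, s] is log N(x_n | means_s, covs_s * I):
    #   -D/2 * log(2 * pi * covs_s) - ||x_n - means_s||^2 / (2 * covs_s)
    d = x.size(1)
    sq_dist = torch.cdist(x, means) ** 2  # pairwise squared distances, shape (N, S)
    return -0.5 * (d * torch.log(2 * math.pi * covs) + sq_dist / covs)

If log_normal behaves as the test implies, spherical_log_normal(x, means, covs) should be close (up to floating-point error) to log_normal(x, means, covs, 'spherical').
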
Example #2
    def test_log_responsibilities(self):
        """
        Test the log responsibilities with the help of scikit-learn.
        """
        N = 16384
        S = 2048
        D = 128

        means = torch.randn(S, D)
        covs = torch.rand(S)
        x = torch.randn(N, D)
        prior = torch.rand(S)
        prior /= prior.sum()
        mixture = GaussianMixture(S, covariance_type='spherical')
        mixture.means_ = means.numpy()
        mixture.precisions_cholesky_ = np.sqrt(1 / covs.numpy())
        mixture.weights_ = prior.numpy()

        # pylint: disable=protected-access
        _, expected = mixture._estimate_log_prob_resp(x.numpy())
        expected = torch.from_numpy(expected)

        probs = log_normal(x, means, covs, 'spherical')
        predicted = log_responsibilities(probs, prior)

        self.assertTrue(
            torch.allclose(expected, predicted, atol=1e-03, rtol=1e-05))
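The reference values here come from scikit-learn: GaussianMixture._estimate_log_prob_resp returns a pair whose second element is the per-sample log responsibilities, i.e. the log posterior over mixture components. Given per-component log densities and mixture weights, the same quantity can be written with a numerically stable log-sum-exp. The sketch below is an assumption about what log_responsibilities computes, inferred from how the test calls it; log_responsibilities_sketch is a hypothetical helper name:

import torch

def log_responsibilities_sketch(log_probs, prior):
    # log_probs: (N, S) log densities per point and component, prior: (S,) mixture weights.
    # Bayes' rule in log space:
    #   log r[n, s] = log prior[s] + log_probs[n, s]
    #                 - logsumexp_k(log prior[k] + log_probs[n, k])
    weighted = log_probs + torch.log(prior)
    return weighted - torch.logsumexp(weighted, dim=1, keepdim=True)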