def test_log_normal_spherical(self):
    """
    Test the log-normal probabilities for spherical covariance.

    Compares ``log_normal(x, means, covs, 'spherical')`` against reference
    ``torch.distributions.MultivariateNormal`` densities, one per component,
    whose covariance matrix is ``covs[i] * I`` (a diagonal of ``D`` copies of
    ``covs[i]``). The reference matrix is indexed ``expected[n, s]`` =
    log-density of sample ``n`` under component ``s``.
    """
    N = 100
    S = 50
    D = 10

    # A private, seeded generator makes the inputs reproducible across runs
    # without perturbing the global RNG state that other tests rely on.
    generator = torch.Generator().manual_seed(0)
    means = torch.randn(S, D, generator=generator)
    covs = torch.rand(S, generator=generator)
    x = torch.randn(N, D, generator=generator)

    distributions = [
        dist.MultivariateNormal(mean, torch.diag(cov.expand(D)))
        for mean, cov in zip(means, covs)
    ]

    # Each reference distribution scores the whole batch at once (shape (N,));
    # stacking the S results along dim=1 yields the (N, S) matrix.
    expected = torch.stack([d.log_prob(x) for d in distributions], dim=1)
    predicted = log_normal(x, means, covs, 'spherical')

    if not torch.allclose(expected, predicted, atol=1e-03, rtol=1e-05):
        # Only computed on failure, to give a useful diagnostic.
        max_diff = (expected - predicted).abs().max().item()
        self.fail(
            f"log_normal (spherical) mismatch: max abs diff {max_diff}")
def test_log_responsibilities(self):
    """
    Test the log responsibilities with the help of Sklearn.

    Builds a spherical ``GaussianMixture`` by hand with the same means,
    variances and mixing weights, takes its log responsibilities as the
    reference, and compares them against
    ``log_responsibilities(log_normal(x, means, covs, 'spherical'), prior)``.
    """
    N = 16384
    S = 2048
    D = 128

    # A private, seeded generator makes the inputs reproducible across runs
    # without perturbing the global RNG state that other tests rely on.
    generator = torch.Generator().manual_seed(0)
    means = torch.randn(S, D, generator=generator)
    covs = torch.rand(S, generator=generator)
    x = torch.randn(N, D, generator=generator)
    prior = torch.rand(S, generator=generator)
    prior /= prior.sum()

    # Configure the reference model directly instead of fitting it. For the
    # 'spherical' covariance type each component has a single precision
    # Cholesky factor, 1 / sqrt(variance).
    mixture = GaussianMixture(S, covariance_type='spherical')
    mixture.means_ = means.numpy()
    mixture.precisions_cholesky_ = np.sqrt(1 / covs.numpy())
    mixture.weights_ = prior.numpy()

    # The second return value holds the per-sample, per-component
    # log responsibilities.
    # pylint: disable=protected-access
    _, expected = mixture._estimate_log_prob_resp(x.numpy())
    expected = torch.from_numpy(expected)

    probs = log_normal(x, means, covs, 'spherical')
    predicted = log_responsibilities(probs, prior)

    if not torch.allclose(expected, predicted, atol=1e-03, rtol=1e-05):
        # Only computed on failure: the matrices are large, so avoid an extra
        # temporary on the passing path.
        max_diff = (expected - predicted).abs().max().item()
        self.fail(
            f"log_responsibilities mismatch: max abs diff {max_diff}")