def test_lmvnpdf_spherical():
    """Check the 'spherical' covariance path of mixture.lmvnpdf.

    A spherical model is given one variance per component (``spherecv`` has
    shape (n_components, 1)). Repeating that variance across every feature
    yields the equivalent diagonal covariance, shape
    (n_components, n_features), which the naive diagonal reference can
    evaluate (the same layout test_lmvnpdf_diag passes to it).
    """
    n_features, n_components, n_obs = 2, 3, 10
    mu = rng.randint(10) * rng.rand(n_components, n_features)
    spherecv = rng.rand(n_components, 1) ** 2 + 1
    obs = rng.randint(10) * rng.rand(n_obs, n_features)
    # Repeat each component's single variance along the feature axis.
    # Reps of (n_features, 1) would instead stack whole copies of the
    # column vertically, giving an (n_features * n_components, 1) array
    # that does not match the (n_components, n_features) diagonal layout.
    cv = np.tile(spherecv, (1, n_features))
    reference = _naive_lmvnpdf_diag(obs, mu, cv)
    lpr = mixture.lmvnpdf(obs, mu, spherecv, 'spherical')
    assert_array_almost_equal(lpr, reference)
def test_lmvnpdf_full():
    """Check the 'full' covariance path of mixture.lmvnpdf.

    Each component's full covariance matrix is built as a diagonal matrix
    from its per-feature variances, so the log-density must agree with the
    naive diagonal reference evaluated on those same variances.
    """
    dim, n_comp, n_samples = 2, 3, 10

    # Draw order matters: rng is shared state, so keep the same sequence
    # of randint / rand calls.
    mean_scale = rng.randint(10)
    means = mean_scale * rng.rand(n_comp, dim)
    variances = (1.0 + rng.rand(n_comp, dim)) ** 2
    sample_scale = rng.randint(10)
    samples = sample_scale * rng.rand(n_samples, dim)

    # One (dim, dim) diagonal matrix per component.
    full_covars = np.array(list(map(np.diag, variances)))

    expected = _naive_lmvnpdf_diag(samples, means, variances)
    actual = mixture.lmvnpdf(samples, means, full_covars, 'full')
    assert_array_almost_equal(actual, expected)
def test_lmvnpdf_diag():
    """Compare mixture.lmvnpdf in 'diag' mode against a naive reference.

    The slow, naive _naive_lmvnpdf_diag is treated as ground truth for the
    vectorised implementation, using random means, per-feature variances
    and observations.
    """
    dim, n_comp, n_samples = 2, 3, 10

    # Keep the original sequence of rng draws (shared random state).
    mean_scale = rng.randint(10)
    means = mean_scale * rng.rand(n_comp, dim)
    # Offset by 1.0 before squaring so the variances stay away from zero
    # (assuming rng.rand samples from [0, 1)).
    variances = (1.0 + rng.rand(n_comp, dim)) ** 2
    sample_scale = rng.randint(10)
    samples = sample_scale * rng.rand(n_samples, dim)

    expected = _naive_lmvnpdf_diag(samples, means, variances)
    assert_array_almost_equal(
        mixture.lmvnpdf(samples, means, variances, 'diag'), expected)
# Parameter estimation!
def m_step(gamma, xi, data, sigma, mu, a, pi):
    """M-step: re-estimate the model parameters from the E-step posteriors.

    Parameters
    ----------
    gamma : array, shape (n_obs, n_states)
        Posterior state probabilities; gamma[t, i] for observation t.
    xi : array, 3-D, indexed xi[t, i, j]
        Posterior transition probabilities from state i to state j.
        Presumably it covers one step fewer than gamma -- TODO confirm
        against e_step.
    data : array-like, shape (n_obs, n_features)
        Observations, one row per time step.
    sigma : array, shape (n_states, n_features, n_features)
        Covariance matrices; overwritten in place.
    mu : array, shape (n_states, n_features)
        Means; overwritten in place.
    a : array, shape (n_states, n_states)
        Transition matrix; overwritten in place.
    pi : ignored
        Replaced by gamma[0]; kept only for call compatibility.

    Returns
    -------
    (pi, a, gamma, mu, sigma)
        The updated parameters. ``gamma`` is returned unchanged.
    """
    data = np.asarray(data)
    # Derive the number of states from gamma rather than the module-level K.
    n_states = gamma.shape[1]
    # New initial-state distribution: state posteriors at the first step.
    pi = gamma[0]
    for i in range(n_states):
        weights = gamma[:, i]
        total = weights.sum()

        # Transition row i: expected i->j transitions over the expected
        # number of transitions out of i. Normalising by the xi row mass
        # guarantees each row of ``a`` sums to one. The previous denominator
        # (gamma summed over every observation) also counted the final step,
        # which presumably has no outgoing entry in xi.
        outgoing = xi[:, i, :].sum(axis=0)
        a[i, :] = outgoing / outgoing.sum()

        # Update the mean first, so the covariance below is centred on the
        # new mean rather than the previous one.
        mu[i] = np.dot(weights, data) / total

        # Weighted scatter around the new mean, computed from scratch. The
        # previous code added onto the old sigma[i] instead of replacing it.
        diff = data - mu[i]
        sigma[i] = np.dot((weights[:, np.newaxis] * diff).T, diff) / total
    return pi, a, gamma, mu, sigma


# EM driver: alternate E- and M-steps on the module-level parameters.
max_iter = 1
old_gamma = None
for i in range(max_iter):
    alpha, beta, gamma, xi = e_step(data, a, pi, mu, sigma)
    pi, a, gamma, mu, sigma = m_step(gamma, xi, data, sigma, mu, a, pi)
    # Stop once the state posteriors barely move between iterations.
    # Previously the tolerance was 1e6, which (assuming gamma holds
    # probabilities) was effectively always met. old_gamma was also never
    # refreshed after the first iteration.
    if old_gamma is not None and ((gamma - old_gamma) ** 2).sum() < 1e-6:
        print("break at iteration %d" % i)
        break
    old_gamma = gamma.copy()

# Most probable state for each observation (argmax over the state axis).
q = gamma.argmax(axis=1)
# Log-densities of the data under the fitted per-state parameters and under
# ``gmm`` (defined elsewhere), presumably for comparison.
# NOTE(review): this rebinds ``a``, discarding the fitted transition matrix.
a = lmvnpdf(data, mu, sigma, "full")
b = lmvnpdf(data, gmm.means, gmm.covars, "full")