def test_mvn_to_mvn():
    """Collapsing a one-component GMM must give back that component unchanged."""
    mu = np.full((1, 1), 123.0)
    sigma = np.full((1, 1, 1), 4.0)
    single = GMM(n_components=1, priors=np.ones(1), means=mu, covariances=sigma)

    collapsed = single.to_mvn()

    # With a single component there is nothing to mix: mean and covariance
    # are exactly those of the only component.
    assert_array_almost_equal(collapsed.mean, mu[0])
    assert_array_almost_equal(collapsed.covariance, sigma[0])
def test_2_components_to_mvn():
    """Collapse a two-component GMM into a single Gaussian.

    The mixture mean is the prior-weighted average of the component means.
    The mixture covariance follows the law of total covariance: the weighted
    within-component covariances plus the spread of the component means.
    """
    priors = np.array([0.25, 0.75])
    means = np.array([[1.0, 2.0], [3.0, 4.0]])
    covs = np.array([
        [[1.0, 0.0], [0.0, 1.0]],
        [[1.0, 0.0], [0.0, 1.0]],
    ])
    # n_components must match the number of priors / means / covariances
    # supplied (previously this was wrongly 1).
    gmm = GMM(n_components=2, priors=priors, means=means, covariances=covs)
    mvn = gmm.to_mvn()

    # 0.25 * [1, 2] + 0.75 * [3, 4] = [2.5, 3.5]
    assert_array_almost_equal(mvn.mean, np.array([2.5, 3.5]))
    # I + 0.25 * 0.75 * d d^T with d = mu_1 - mu_2 = [-2, -2]
    # -> I + 0.1875 * [[4, 4], [4, 4]] = [[1.75, 0.75], [0.75, 1.75]]
    assert_array_almost_equal(
        mvn.covariance, np.array([[1.75, 0.75], [0.75, 1.75]]))
def test_gmm_to_mvn_vs_mvn():
    """A GMM fitted to X and then collapsed should match an MVN fitted to X."""
    rng = check_random_state(0)

    mixture = GMM(n_components=2, random_state=rng)
    mixture.from_samples(X)
    collapsed = mixture.to_mvn()

    # Same random-state object is reused, in the same order as before.
    direct = MVN(random_state=rng)
    direct.from_samples(X)

    assert_array_almost_equal(collapsed.mean, direct.mean)
    # The mixture fit only approximates the sample covariance, hence the
    # looser tolerance.
    assert_array_almost_equal(
        collapsed.covariance, direct.covariance, decimal=3)