示例#1
0
def test_mvn_to_mvn():
    """Collapsing a single-component GMM must return that component unchanged.

    A GMM with one component and prior 1 is built from a 1-D mean and a
    1x1 covariance; ``to_mvn`` is expected to yield exactly that mean and
    covariance.
    """
    expected_mean = np.full((1, 1), 123.0)
    expected_cov = np.full((1, 1, 1), 4.0)

    single = GMM(
        n_components=1,
        priors=np.ones(1),
        means=expected_mean,
        covariances=expected_cov,
    )
    collapsed = single.to_mvn()

    # Index 0 selects the only component's parameters.
    assert_array_almost_equal(collapsed.mean, expected_mean[0])
    assert_array_almost_equal(collapsed.covariance, expected_cov[0])
示例#2
0
def test_2_components_to_mvn():
    """Collapse a two-component GMM into a single Gaussian.

    The collapsed mean is the prior-weighted average of the component means:
    0.25 * [1, 2] + 0.75 * [3, 4] = [2.5, 3.5].

    Moment matching (law of total covariance) gives the collapsed covariance
    sum_k p_k * (Sigma_k + (mu_k - mu)(mu_k - mu)^T). Here Sigma_k = I, and the
    deviations are [-1.5, -1.5] and [0.5, 0.5]. That yields
    I + (0.25 * 2.25 + 0.75 * 0.25) * ones((2, 2)) = I + 0.75 * ones((2, 2)).
    """
    priors = np.array([0.25, 0.75])
    means = np.array([[1.0, 2.0], [3.0, 4.0]])
    covs = np.array([
        [[1.0, 0.0], [0.0, 1.0]],
        [[1.0, 0.0], [0.0, 1.0]],
    ])
    # n_components must agree with the number of supplied priors/means/
    # covariances (2); passing 1 here would contradict those parameters.
    gmm = GMM(n_components=2, priors=priors, means=means, covariances=covs)
    mvn = gmm.to_mvn()
    assert_array_almost_equal(mvn.mean, np.array([2.5, 3.5]))
    assert_array_almost_equal(
        mvn.covariance, np.array([[1.75, 0.75], [0.75, 1.75]]))
示例#3
0
def test_gmm_to_mvn_vs_mvn():
    """Check that a GMM fitted to ``X`` and collapsed to one Gaussian matches a
    single MVN fitted directly to ``X``.

    ``X`` is a module-level dataset defined outside this function.
    """
    rng = check_random_state(0)

    # Fit the mixture first; the same random state object is then handed to
    # the MVN, so this ordering is kept deliberately.
    mixture = GMM(n_components=2, random_state=rng)
    mixture.from_samples(X)
    collapsed = mixture.to_mvn()

    reference = MVN(random_state=rng)
    reference.from_samples(X)

    assert_array_almost_equal(collapsed.mean, reference.mean)
    # NOTE(review): covariance is only compared to 3 decimals, presumably
    # because the EM fit does not reproduce the sample covariance exactly;
    # confirm.
    assert_array_almost_equal(
        collapsed.covariance,
        reference.covariance,
        decimal=3,
    )