def test_fit_init_random(self):
        """Fit a GMM with the fixture's initialisation, then with 'random'.

        For each initialisation method a fresh ``RiemannianEM`` is built on
        ``self.metric`` with ``self.n_gaussian`` components and fitted on
        ``self.data``. The fitted mixture coefficients and variances must lie
        strictly between 0 and 1, and every fitted mean must belong to
        ``self.space``.
        """
        for init_method in (self.initialisation_method, "random"):
            estimator = RiemannianEM(
                metric=self.metric,
                n_gaussians=self.n_gaussian,
                initialisation_method=init_method,
            )
            fitted_means, fitted_variances, weights = estimator.fit(self.data)

            # Mixture weights: strictly inside (0, 1).
            self.assertTrue((weights < 1).all())
            self.assertTrue((weights > 0).all())

            # Variances: strictly inside (0, 1).
            self.assertTrue((fitted_variances < 1).all())
            self.assertTrue((fitted_variances > 0).all())

            # Means must be points of the test space.
            self.assertTrue(self.space.belongs(fitted_means).all())
# Example #2
0
def expectation_maximisation_poincare_ball():
    """Apply EM algorithm on three random data clusters.

    Three clusters of ``n_samples`` points each are drawn uniformly in
    ``dim`` coordinates, concatenated, and fitted with a three-component
    ``RiemannianEM`` using the metric of a ``PoincareBall`` of the same
    dimension and the ``'random'`` initialisation method. The fitted
    mixture is then passed to ``plot_gaussian_mixture_distribution``.

    Returns
    -------
    plot
        The value returned by ``plot_gaussian_mixture_distribution``, which
        is called with ``save_path='result.png'``.
    """
    dim = 2
    n_samples = 5
    n_clusters = 3

    cluster_1 = gs.random.uniform(low=0.2, high=0.6, size=(n_samples, dim))
    cluster_2 = gs.random.uniform(low=-0.6, high=-0.2, size=(n_samples, dim))
    cluster_3 = gs.random.uniform(low=-0.3, high=0, size=(n_samples, dim))
    # Negate the first coordinate so the third cluster does not sit between
    # the other two on the diagonal.
    cluster_3[:, 0] = -cluster_3[:, 0]

    data = gs.concatenate((cluster_1, cluster_2, cluster_3), axis=0)

    # Build the manifold from `dim` so it always matches the data's dimension
    # instead of repeating the literal 2.
    manifold = PoincareBall(dim=dim)
    metric = manifold.metric

    em = RiemannianEM(n_gaussians=n_clusters,
                      metric=metric,
                      initialisation_method='random')

    means, variances, mixture_coefficients = em.fit(data=data)

    # Plot result
    plot = plot_gaussian_mixture_distribution(data,
                                              mixture_coefficients,
                                              means,
                                              variances,
                                              plot_precision=100,
                                              save_path='result.png',
                                              metric=metric)

    return plot
    def test_fit_init_random_sphere(self):
        """Fit a two-component GMM on von Mises-Fisher samples on a sphere.

        Two points are drawn with ``random_uniform`` on ``Hypersphere(2)``
        and used as the ``mu`` of two von Mises-Fisher samples (kappa=20,
        140 points each). The estimator, built with the fixture's
        initialisation method, is fitted on the concatenated samples. The
        fitted coefficients and variances must lie strictly between 0 and 1,
        and the fitted means must belong to the sphere.
        """
        sphere = Hypersphere(2)
        estimator = RiemannianEM(
            metric=sphere.metric,
            n_gaussians=2,
            initialisation_method=self.initialisation_method,
        )

        centers = sphere.random_uniform(2)
        first_sample = sphere.random_von_mises_fisher(
            mu=centers[0], kappa=20, n_samples=140
        )
        second_sample = sphere.random_von_mises_fisher(
            mu=centers[1], kappa=20, n_samples=140
        )
        samples = gs.concatenate((first_sample, second_sample), axis=0)

        fitted_means, fitted_variances, weights = estimator.fit(samples)

        # Mixture weights: strictly inside (0, 1).
        self.assertTrue((weights < 1).all())
        self.assertTrue((weights > 0).all())

        # Variances: strictly inside (0, 1).
        self.assertTrue((fitted_variances < 1).all())
        self.assertTrue((fitted_variances > 0).all())

        # Fitted means must be points of the sphere.
        self.assertTrue(sphere.belongs(fitted_means).all())