Example #1
0
    def test_sample_von_mises_fisher(self):
        """
        Check that the maximum likelihood estimates of the mean and
        concentration parameter are close to the real values. A first
        estimation of the concentration parameter is obtained by a
        closed-form expression and improved through the Newton method.
        """
        dim = 2
        # Dimension of the ambient space R^p in which S^dim is embedded.
        p = dim + 1
        n_points = 1000000
        sphere = Hypersphere(dim)

        # check mean value for concentrated distribution
        # NOTE(review): assumes random_von_mises_fisher samples around the
        # north pole [0, 0, 1] by default — confirm against Hypersphere.
        kappa = 1000000
        points = sphere.random_von_mises_fisher(kappa, n_points)
        sum_points = gs.sum(points, axis=0)
        expected = gs.array([0., 0., 1.])
        # The mean direction is estimated as the normalised resultant vector.
        result = sum_points / gs.linalg.norm(sum_points)
        self.assertTrue(
            gs.allclose(result, expected, atol=MEAN_ESTIMATION_TOL))

        # check concentration parameter for dispersed distribution
        kappa = 1
        points = sphere.random_von_mises_fisher(kappa, n_points)
        sum_points = gs.sum(points, axis=0)
        # Mean resultant length R_bar.
        mean_norm = gs.linalg.norm(sum_points) / n_points
        # Closed-form initial guess: R_bar * (p - R_bar^2) / (1 - R_bar^2).
        kappa_estimate = (mean_norm * (p - mean_norm**2)
                          / (1. - mean_norm**2))

        # Refine with Newton's method on A_p(kappa) - R_bar = 0, where
        # A_p(kappa) = I_{p/2}(kappa) / I_{p/2 - 1}(kappa) is a ratio of
        # modified Bessel functions of the first kind and its derivative is
        # A_p'(kappa) = 1 - A_p^2 - (p - 1) * A_p / kappa.
        n_steps = 100
        for _ in range(n_steps):
            ratio = (scipy.special.iv(p / 2., kappa_estimate)
                     / scipy.special.iv(p / 2. - 1., kappa_estimate))
            derivative = 1. - ratio**2 - (p - 1.) * ratio / kappa_estimate
            kappa_estimate -= (ratio - mean_norm) / derivative
        expected = kappa
        result = kappa_estimate
        self.assertTrue(
            gs.allclose(result, expected, atol=KAPPA_ESTIMATION_TOL))
Example #2
0
 def test_optimal_quantization(self):
     """
     Check that optimal quantization yields the same result as
     the Karcher flow algorithm when we look for one center.
     """
     dim = 2
     n_points = 1000
     n_centers = 1
     sphere = Hypersphere(dim)
     points = sphere.random_von_mises_fisher(kappa=10, n_samples=n_points)
     mean = sphere.metric.mean(points)
     # Only the centers are compared; the weights, clusters and iteration
     # count returned alongside them are not needed here.
     centers, _, _, _ = sphere.metric.optimal_quantization(
         points=points, n_centers=n_centers)
     # Normalise the distance by the sample's diameter so the tolerance is
     # relative to the spread of the data rather than absolute.
     error = sphere.metric.dist(mean, centers)
     diameter = sphere.metric.diameter(points)
     result = error / diameter
     expected = 0.0
     self.assertTrue(
         gs.allclose(result, expected, atol=OPTIMAL_QUANTIZATION_TOL))
Example #3
0
def main():
    """Fit a tangent PCA to von Mises-Fisher samples on S2 and plot it.

    Samples 140 points from a von Mises-Fisher distribution (kappa=15) on
    the 2-sphere, fits a ``TangentPCA`` with two components at the mean of
    the samples, prints the projected coordinates of the first five points,
    and shows a figure with the explained-variance ratio per component
    (left) and the data, the mean and the geodesics along the principal
    components (right).
    """
    n_components = 2
    fig = plt.figure(figsize=(15, 5))

    sphere = Hypersphere(dimension=2)

    data = sphere.random_von_mises_fisher(kappa=15, n_samples=140)
    mean = sphere.metric.mean(data)

    tpca = TangentPCA(metric=sphere.metric, n_components=n_components)
    tpca = tpca.fit(data, base_point=mean)
    tangent_projected_data = tpca.transform(data)

    # One geodesic per principal component, shot from the mean and sampled
    # at n_steps parameter values in [-1, 1].
    n_steps = 100
    t = np.linspace(-1, 1, n_steps)
    geodesic_points = [
        sphere.metric.geodesic(
            initial_point=mean, initial_tangent_vec=component)(t)
        for component in tpca.components_]

    print('Coordinates of the Log of the first 5 data points at the mean, '
          'projected on the principal components:')
    print(tangent_projected_data[:5])

    # Derive the x ticks from the fitted ratios instead of hard-coding the
    # number of components, so the two cannot drift apart.
    variance_ratios = tpca.explained_variance_ratio_
    xticks = np.arange(1, len(variance_ratios) + 1)
    ax_var = fig.add_subplot(121)
    ax_var.xaxis.set_ticks(xticks)
    ax_var.set_title('Explained variance')
    ax_var.set_xlabel('Number of Principal Components')
    ax_var.set_ylim((0, 1))
    ax_var.plot(xticks, variance_ratios)

    ax = fig.add_subplot(122, projection="3d")

    visualization.plot(mean, ax, space='S2', color='darkgreen', s=10)
    for points in geodesic_points:
        visualization.plot(points, ax, space='S2', linewidth=2)
    visualization.plot(data, ax, space='S2', color='black', alpha=0.7)

    plt.show()