Exemplo n.º 1
    def test_rotate_points(self):
        """Check ``utils.rotate_points`` on single and batched inputs.

        Three properties are asserted, using points drawn on
        ``Hypersphere(2)`` and the reference pole ``[1, 0, 0]``:

        - rotating the pole towards a target yields the target;
        - rotating points towards the pole leaves them unchanged;
        - in a batch whose first row is the pole, that first row is
          mapped onto the target.
        """
        space = Hypersphere(2)
        target = space.random_uniform()
        pole = gs.array([1., 0., 0.])

        # Single point: the pole must land exactly on the target.
        self.assertAllClose(utils.rotate_points(pole, target), target)

        # Rotating towards the pole itself must leave the points as they are.
        samples = space.random_uniform(10)
        self.assertAllClose(utils.rotate_points(samples, pole), samples)

        # Batched input: prepend the pole and check where that row ends up.
        stacked = gs.concatenate([pole[None, :], samples])
        rotated = utils.rotate_points(stacked, target)
        self.assertAllClose(rotated[0], target)
Exemplo n.º 2
    def random_von_mises_fisher(
            self, mu=None, kappa=10, n_samples=1, max_iter=100):
        """Sample with the von Mises-Fisher distribution.

        This distribution corresponds to the maximum entropy distribution
        given a mean. In dimension 2, a closed form expression is available.
        In larger dimension, rejection sampling is used according to [Wood94]_

        References
        ----------
        .. [Wood94]   Wood, Andrew T. A. “Simulation of the von Mises Fisher
                      Distribution.” Communications in Statistics - Simulation
                      and Computation, June 27, 2007.

        Parameters
        ----------
        mu : array-like, shape=[dim + 1]
            Mean parameter of the distribution. When given, the samples are
            rotated with ``utils.rotate_points(sample, mu)``; when None, the
            samples are returned unrotated.
            Optional, default: None.
        kappa : float
            Kappa parameter of the von Mises distribution.
            Optional, default: 10.
        n_samples : int
            Number of samples.
            Optional, default: 1.
        max_iter : int
            Maximum number of trials in the rejection algorithm. In case it
            is reached, the current number of samples < n_samples is returned.
            Optional, default: 100.

        Returns
        -------
        point : array-like, shape=[n_samples, dim + 1]
            Points sampled on the sphere in extrinsic coordinates
            in Euclidean space of dimension dim + 1. When ``n_samples`` is 1,
            only the first row is returned.
        """
        # Local import so the block does not rely on a module-level import
        # that is outside of this method.
        import logging

        dim = self.dim

        if dim == 2:
            # Closed form: draw a uniform direction in the (y, z) plane ...
            angle = 2. * gs.pi * gs.random.rand(n_samples)
            angle = gs.to_ndarray(angle, to_ndim=2, axis=1)
            unit_vector = gs.hstack((gs.cos(angle), gs.sin(angle)))
            scalar = gs.random.rand(n_samples)

            # ... and the first coordinate from a uniform variate.
            coord_x = 1. + 1. / kappa * gs.log(
                scalar + (1. - scalar) * gs.exp(gs.array(-2. * kappa)))
            coord_x = gs.to_ndarray(coord_x, to_ndim=2, axis=1)
            # Scale the planar direction so that each sample has unit norm.
            coord_yz = gs.sqrt(1. - coord_x ** 2) * unit_vector
            sample = gs.hstack((coord_x, coord_yz))

        else:
            # rejection sampling in the general case
            sqrt = gs.sqrt(4 * kappa ** 2. + dim ** 2)
            envelop_param = (-2 * kappa + sqrt) / dim
            node = (1. - envelop_param) / (1. + envelop_param)
            correction = kappa * node + dim * gs.log(1. - node ** 2)

            n_accepted, n_iter = 0, 0
            result = []
            while (n_accepted < n_samples) and (n_iter < max_iter):
                # Only draw as many candidates as are still missing, so the
                # total number of accepted values never exceeds n_samples.
                n_missing = n_samples - n_accepted
                sym_beta = beta.rvs(dim / 2, dim / 2, size=n_missing)
                sym_beta = gs.cast(sym_beta, node.dtype)
                coord_x = (1 - (1 + envelop_param) * sym_beta) / (
                    1 - (1 - envelop_param) * sym_beta)
                accept_tol = gs.random.rand(n_missing)
                criterion = (
                    kappa * coord_x
                    + dim * gs.log(1 - node * coord_x)
                    - correction) > gs.log(accept_tol)
                # Keep the candidates that passed the acceptance test.
                result.append(coord_x[criterion])
                n_accepted += gs.sum(criterion)
                n_iter += 1
            if n_accepted < n_samples:
                logging.warning(
                    'Maximum number of iteration reached in rejection '
                    'sampling before n_samples were accepted.')
            coord_x = gs.concatenate(result)
            # Complete each accepted first coordinate with a uniform direction
            # on the lower-dimensional sphere, scaled to give unit norm.
            coord_rest = _Hypersphere(dim - 1).random_uniform(n_accepted)
            coord_rest = gs.einsum(
                '...,...i->...i', gs.sqrt(1 - coord_x ** 2), coord_rest)
            sample = gs.concatenate([coord_x[..., None], coord_rest], axis=1)

        if mu is not None:
            sample = utils.rotate_points(sample, mu)

        return sample if (n_samples > 1) else sample[0]