Ejemplo n.º 1
0
    def __init__(self, manifolds, default_point_type='vector', n_jobs=1):
        """Build the product of the given manifolds.

        Parameters
        ----------
        manifolds : list
            Factor manifolds M_1, ..., M_n. Each must expose ``metric`` and
            ``dimension`` attributes.
        default_point_type : str, {'vector', 'matrix'}
            Default representation of points.
        n_jobs : int
            Stored as ``self.n_jobs``; presumably the number of parallel
            jobs used by other methods -- confirm against the class.

        Raises
        ------
        ValueError
            If ``default_point_type`` is not 'vector' or 'matrix'.
        """
        # Explicit check instead of ``assert``: assertions are removed when
        # Python runs with -O, which would silently accept invalid input.
        if default_point_type not in ('vector', 'matrix'):
            raise ValueError(
                "default_point_type must be one of ['vector', 'matrix'], "
                "got {!r}".format(default_point_type))
        self.default_point_type = default_point_type

        self.manifolds = manifolds
        self.metric = ProductRiemannianMetric(
            [manifold.metric for manifold in manifolds],
            default_point_type=default_point_type)

        # The product's dimension is the sum of the factors' dimensions.
        self.dimensions = [manifold.dimension for manifold in manifolds]
        super(ProductManifold, self).__init__(dimension=sum(self.dimensions))
        self.n_jobs = n_jobs
Ejemplo n.º 2
0
    def __init__(self, manifolds, default_point_type='vector', n_jobs=1):
        """Create the product manifold M_1 x ... x M_n.

        Parameters
        ----------
        manifolds : list
            Factor manifolds; each must provide ``dim`` and ``metric``.
        default_point_type : str, {'vector', 'matrix'}
            Default representation of points, validated up front.
        n_jobs : int
            Stored unchanged as ``self.n_jobs``.
        """
        geomstats.error.check_parameter_accepted_values(
            default_point_type, 'default_point_type', ['vector', 'matrix'])

        factor_dims = [space.dim for space in manifolds]
        self.dims = factor_dims
        # Total dimension = sum of the factor dimensions.
        super(ProductManifold, self).__init__(
            dim=sum(factor_dims), default_point_type=default_point_type)

        self.manifolds = manifolds
        factor_metrics = [space.metric for space in manifolds]
        self.metric = ProductRiemannianMetric(
            factor_metrics, default_point_type=default_point_type)
        self.n_jobs = n_jobs
Ejemplo n.º 3
0
    def __init__(
        self, manifolds, metrics=None, default_point_type="vector", n_jobs=1, **kwargs
    ):
        """Build the product manifold M_1 x ... x M_n.

        Parameters
        ----------
        manifolds : list
            Factor manifolds. Each must expose ``dim`` and ``shape``, and
            ``metric`` when neither ``metrics`` nor a ``metric`` keyword is
            given.
        metrics : list, optional
            Metrics of the factors; defaults to each manifold's ``metric``.
            Ignored when a ``metric`` keyword argument is supplied.
        default_point_type : str, {'vector', 'matrix'}
            Default representation of points.
        n_jobs : int
            Forwarded to the default ``ProductRiemannianMetric`` and stored
            as ``self.n_jobs``.
        **kwargs
            Forwarded to the parent constructor. A ``metric`` entry
            overrides the default product metric.
        """
        geomstats.errors.check_parameter_accepted_values(
            default_point_type, "default_point_type", ["vector", "matrix"]
        )

        self.dims = [manifold.dim for manifold in manifolds]
        if "metric" not in kwargs:
            # Build the default product metric only when the caller did not
            # supply one. Otherwise it would be constructed (reading every
            # factor's ``metric``) and then discarded.
            if metrics is None:
                metrics = [manifold.metric for manifold in manifolds]
            kwargs["metric"] = ProductRiemannianMetric(
                metrics, default_point_type=default_point_type, n_jobs=n_jobs
            )
        dim = sum(self.dims)

        if default_point_type == "vector":
            # Factor points are concatenated along the last axis; assumes
            # each factor has 1-D points.
            shape = (sum(manifold.shape[0] for manifold in manifolds),)
        else:
            # One factor point per row; assumes every factor shares the
            # point shape of the first one.
            shape = (len(manifolds), *manifolds[0].shape)

        super(ProductManifold, self).__init__(
            dim=dim,
            shape=shape,
            default_point_type=default_point_type,
            **kwargs,
        )
        self.manifolds = manifolds
        self.n_jobs = n_jobs
Ejemplo n.º 4
0
    def __init__(self,
                 manifolds,
                 metrics=None,
                 default_point_type="vector",
                 n_jobs=1,
                 **kwargs):
        """Build the product manifold M_1 x ... x M_n.

        Parameters
        ----------
        manifolds : list
            Factor manifolds. Each must expose ``dim`` and ``shape``, and
            ``metric`` when ``metrics`` is not given.
        metrics : list, optional
            Metrics of the factors; defaults to each manifold's ``metric``.
        default_point_type : str, {'vector', 'matrix'}
            Default representation of points.
        n_jobs : int
            Forwarded to ``ProductRiemannianMetric`` and stored as
            ``self.n_jobs``.
        **kwargs
            Extra keyword arguments forwarded to the parent constructor.
        """
        geomstats.errors.check_parameter_accepted_values(
            default_point_type, "default_point_type", ["vector", "matrix"])

        self.dims = [manifold.dim for manifold in manifolds]
        if metrics is None:
            metrics = [manifold.metric for manifold in manifolds]
        metric = ProductRiemannianMetric(metrics,
                                         default_point_type=default_point_type,
                                         n_jobs=n_jobs)
        dim = sum(self.dims)
        # The ambient point shape must come from the factors' point shapes,
        # not their dimensions. A factor whose points carry more coordinates
        # than its dimension (e.g. extrinsic coordinates) would otherwise get
        # a shape that is too small.
        if default_point_type == "vector":
            # Factor points are concatenated along the last axis; assumes
            # each factor has 1-D points.
            shape = (sum(manifold.shape[0] for manifold in manifolds),)
        else:
            # One factor point per row; assumes every factor shares the
            # point shape of the first one.
            shape = (len(manifolds), *manifolds[0].shape)

        super(ProductManifold, self).__init__(
            dim=dim,
            shape=shape,
            metric=metric,
            default_point_type=default_point_type,
            **kwargs,
        )
        self.manifolds = manifolds
        self.n_jobs = n_jobs
Ejemplo n.º 5
0
    def __init__(self,
                 manifolds,
                 metrics=None,
                 default_point_type='vector',
                 n_jobs=1,
                 **kwargs):
        """Create the product manifold M_1 x ... x M_n.

        Parameters
        ----------
        manifolds : list
            Factor manifolds; each must provide ``dim`` (and ``metric`` when
            ``metrics`` is omitted).
        metrics : list, optional
            Per-factor metrics; falls back to each manifold's ``metric``.
        default_point_type : str, {'vector', 'matrix'}
            Default representation of points, validated up front.
        n_jobs : int
            Stored unchanged as ``self.n_jobs``.
        **kwargs
            Passed through to the parent constructor.
        """
        geomstats.errors.check_parameter_accepted_values(
            default_point_type, 'default_point_type', ['vector', 'matrix'])

        self.dims = [space.dim for space in manifolds]
        factor_metrics = (
            [space.metric for space in manifolds] if metrics is None
            else metrics)
        product_metric = ProductRiemannianMetric(
            factor_metrics, default_point_type=default_point_type)

        super(ProductManifold, self).__init__(
            dim=sum(self.dims),
            metric=product_metric,
            default_point_type=default_point_type,
            **kwargs)
        self.manifolds = manifolds
        self.n_jobs = n_jobs
Ejemplo n.º 6
0
class ProductManifold(Manifold):
    """Class for a product of manifolds M_1 x ... x M_n.

    In contrast to the classes Landmarks or DiscretizedCurves,
    the manifolds M_1, ..., M_n need not be the same, nor of the
    same dimension, but the list of manifolds needs to be provided.

    By default, a point is represented by an array of shape:
    [..., dim_1 + ... + dim_n_manifolds]
    where n_manifolds is the number of manifolds in the product.
    This type of representation is called 'vector'. When the factors use
    extrinsic coordinates, factor i instead occupies dim_i + 1 entries.

    Alternatively, a point can be represented by an array of shape:
    [..., n_manifolds, dim] if the n_manifolds have same dimension dim.
    This type of representation is called 'matrix'.

    Parameters
    ----------
    manifolds : list
        List of manifolds in the product.
    default_point_type : str, {'vector', 'matrix'}
        Default representation of points.
    n_jobs : int
        Number of jobs given to ``joblib.Parallel`` when a method is applied
        to every factor manifold (vector representation).
    """

    # FIXME(nguigs): This only works for 1d points

    def __init__(self, manifolds, default_point_type='vector', n_jobs=1):
        geomstats.error.check_parameter_accepted_values(
            default_point_type, 'default_point_type', ['vector', 'matrix'])

        self.dims = [manifold.dim for manifold in manifolds]
        super(ProductManifold, self).__init__(
            dim=sum(self.dims), default_point_type=default_point_type)
        self.manifolds = manifolds
        self.metric = ProductRiemannianMetric(
            [manifold.metric for manifold in manifolds],
            default_point_type=default_point_type)
        self.n_jobs = n_jobs

    @staticmethod
    def _get_method(manifold, method_name, metric_args):
        """Return ``manifold.<method_name>(**metric_args)``."""
        return getattr(manifold, method_name)(**metric_args)

    def _iterate_over_manifolds(self, func, args, intrinsic=False):
        """Call the method ``func`` of each factor on its slice of ``args``.

        Parameters
        ----------
        func : str
            Name of the method to call on every factor manifold.
        args : dict
            Keyword arguments; each value is an array whose axis 1 holds the
            concatenated coordinates of all factors.
        intrinsic : bool
            If True, factor i occupies ``dim_i`` columns; otherwise
            ``dim_i + 1`` columns.

        Returns
        -------
        out : list
            One result per factor manifold, in factor order.
        """
        factor_sizes = (
            self.dims if intrinsic else [dim + 1 for dim in self.dims])
        # n factors need n - 1 split points. Keeping the last cumulative sum
        # (as the extrinsic branch used to) produced a spurious trailing
        # chunk and silently dropped any extra columns.
        cum_index = gs.cumsum(factor_sizes)[:-1]
        arguments = {
            key: gs.split(value, cum_index, axis=1)
            for key, value in args.items()
        }
        args_list = [
            {key: pieces[j] for key, pieces in arguments.items()}
            for j in range(len(self.manifolds))
        ]
        pool = joblib.Parallel(n_jobs=self.n_jobs)
        return pool(
            joblib.delayed(self._get_method)(manifold, func, manifold_args)
            for manifold, manifold_args in zip(self.manifolds, args_list))

    @geomstats.vectorization.decorator(['else', 'point', 'point_type'])
    def belongs(self, point, point_type=None):
        """Test if a point belongs to the manifold.

        Parameters
        ----------
        point : array-like, shape=[..., {dim, [n_manifolds, dim_each]}]
            Point.
        point_type : str, {'vector', 'matrix'}
            Representation of point. Defaults to ``default_point_type``;
            other values are rejected by ``check_parameter_accepted_values``.

        Returns
        -------
        belongs : array-like, shape=[..., 1]
            Array of booleans evaluating if the corresponding points
            belong to the manifold, i.e. belong to every factor.
        """
        if point_type is None:
            point_type = self.default_point_type
        # Validate up front: previously an unknown point_type fell through
        # both branches and failed with an UnboundLocalError.
        geomstats.error.check_parameter_accepted_values(
            point_type, 'point_type', ['vector', 'matrix'])

        if point_type == 'vector':
            intrinsic = self.metric.is_intrinsic(point)
            belongs = self._iterate_over_manifolds(
                'belongs', {'point': point}, intrinsic)
            belongs = gs.stack(belongs, axis=1)
        else:
            # 'matrix': factor i's point is row i of each sample.
            belongs = gs.stack(
                [space.belongs(point[:, i])
                 for i, space in enumerate(self.manifolds)],
                axis=1)

        belongs = gs.all(belongs, axis=1)
        return gs.to_ndarray(belongs, to_ndim=2, axis=1)

    @geomstats.vectorization.decorator(['else', 'point', 'point_type'])
    def regularize(self, point, point_type=None):
        """Regularize the point into the manifold's canonical representation.

        Each factor regularizes its own component.

        Parameters
        ----------
        point : array-like, shape=[..., {dim, [n_manifolds, dim_each]}]
            Point to be regularized.
        point_type : str, {'vector', 'matrix'}
            Representation of point. Defaults to ``default_point_type``.

        Returns
        -------
        regularized_point : array-like, same shape as ``point``
            Point in the manifold's canonical representation.
        """
        if point_type is None:
            point_type = self.default_point_type
        geomstats.error.check_parameter_accepted_values(
            point_type, 'point_type', ['vector', 'matrix'])

        if point_type == 'vector':
            intrinsic = self.metric.is_intrinsic(point)
            regularized_point = self._iterate_over_manifolds(
                'regularize', {'point': point}, intrinsic)
            return gs.hstack(regularized_point)

        regularized_point = [
            manifold_i.regularize(point[:, i])
            for i, manifold_i in enumerate(self.manifolds)
        ]
        return gs.stack(regularized_point, axis=1)

    def random_uniform(self, n_samples=1, point_type=None):
        """Sample in the product space from the uniform distribution.

        Each factor is sampled independently with its own
        ``random_uniform``, in factor order.

        Parameters
        ----------
        n_samples : int, optional
            Number of samples. Default: 1.
        point_type : str, {'vector', 'matrix'}
            Representation of point. Defaults to ``default_point_type``.

        Returns
        -------
        samples : array-like
            Points sampled in the product space. For 'vector', the factor
            samples are concatenated along the last axis. For 'matrix',
            they are stacked along axis 1.
        """
        if point_type is None:
            point_type = self.default_point_type
        geomstats.error.check_parameter_accepted_values(
            point_type, 'point_type', ['vector', 'matrix'])

        factor_samples = [
            space.random_uniform(n_samples) for space in self.manifolds]
        if point_type == 'vector':
            return gs.concatenate(factor_samples, axis=-1)
        return gs.stack(factor_samples, axis=1)