コード例 #1
0
    def mean(self,
             points,
             weights=None,
             n_max_iterations=32,
             epsilon=EPSILON,
             point_type='vector',
             verbose=False):
        """Compute the Frechet mean of (weighted) points.

        The estimate is initialized at the first point and refined
        iteratively: the weighted average of the logs of all points at the
        current estimate is mapped back with exp to give the next estimate.
        Iterations stop when the variance at the estimate is close to zero,
        when the squared distance between two successive estimates is at
        most `epsilon` times that variance, or when `n_max_iterations` is
        reached.

        Parameters
        ----------
        points : array-like, shape=[n_samples, dimension]
        weights : array-like, shape=[n_samples, 1], optional
            Defaults to a weight of one for every point.
        n_max_iterations : int, optional
            Maximum number of iterations, forwarded to `gs.while_loop`.
        epsilon : float, optional
            Relative tolerance of the stopping criterion.
        point_type : str, {'vector', 'matrix'}, optional
            Selects whether `points` is expanded to 2 or 3 dimensions.
        verbose : bool, optional
            If True, print the iteration count, the final variance and the
            final squared distance between iterates.

        Returns
        -------
        mean : array-like
            The estimated mean. When a single point is given, that point is
            returned without iterating.
        """

        # TODO(nina): Profile this code to study performance,
        # i.e. what to do with sq_dists_between_iterates.
        def while_loop_cond(iteration, mean, variance, sq_dist):
            # Converged when the variance vanishes or when the last step is
            # small relative to the variance.
            converged = gs.logical_or(
                gs.isclose(variance, 0.),
                gs.less_equal(sq_dist, epsilon * variance))
            result = ~converged
            # The loop starts with a zero variance, so the first pass has to
            # be forced explicitly.
            return result[0, 0] or iteration == 0

        def while_loop_body(iteration, mean, variance, sq_dist):
            logs = self.log(point=points, base_point=mean)

            # Weighted sum of the logs over the samples, then normalization.
            # NOTE(review): the 'nk,nj->j' signature expects 2-D logs --
            # confirm the behavior for point_type='matrix'.
            tangent_mean = gs.einsum('nk,nj->j', weights, logs)
            tangent_mean /= sum_weights

            mean_next = self.exp(tangent_vec=tangent_mean, base_point=mean)

            sq_dist = self.squared_dist(mean_next, mean)
            sq_dists_between_iterates.append(sq_dist)

            variance = self.variance(points=points,
                                     weights=weights,
                                     base_point=mean_next)

            mean = mean_next
            iteration += 1
            return [iteration, mean, variance, sq_dist]

        if point_type == 'vector':
            points = gs.to_ndarray(points, to_ndim=2)
        elif point_type == 'matrix':
            points = gs.to_ndarray(points, to_ndim=3)
        n_points = gs.shape(points)[0]

        if weights is None:
            weights = gs.ones((n_points, 1))

        weights = gs.array(weights)
        weights = gs.to_ndarray(weights, to_ndim=2, axis=1)

        sum_weights = gs.sum(weights)

        # The first point is the initial estimate.
        mean = points[0]
        if point_type == 'vector':
            mean = gs.to_ndarray(mean, to_ndim=2)
        elif point_type == 'matrix':
            mean = gs.to_ndarray(mean, to_ndim=3)

        if n_points == 1:
            return mean

        sq_dists_between_iterates = []
        iteration = 0
        sq_dist = gs.array([[0.]])
        variance = gs.array([[0.]])

        last_iteration, mean, variance, sq_dist = gs.while_loop(
            while_loop_cond,
            while_loop_body,
            loop_vars=[iteration, mean, variance, sq_dist],
            maximum_iterations=n_max_iterations)

        if last_iteration == n_max_iterations:
            # The trailing space keeps the two adjacent literals from being
            # glued together as "reached.The".
            print('Maximum number of iterations {} reached. '
                  'The mean may be inaccurate'.format(n_max_iterations))

        if verbose:
            print('n_iter: {}, final variance: {}, final dist: {}'.format(
                last_iteration, variance, sq_dist))

        mean = gs.to_ndarray(mean, to_ndim=2)
        return mean
コード例 #2
0
def _default_gradient_descent(points, metric, weights, n_max_iterations,
                              point_type, epsilon, verbose):
    """Estimate the Frechet mean of (weighted) points by iterating exp/log.

    The estimate is initialized at the first point. At each iteration the
    weighted average of `metric.log` of all points at the current estimate
    is mapped back with `metric.exp` to give the next estimate. Iterations
    stop when the variance at the estimate is close to zero, when the
    squared distance between two successive estimates is at most `epsilon`
    times that variance, or when `n_max_iterations` is reached.

    Parameters
    ----------
    points : array-like
        Points to average; expanded to 2 dimensions for
        `point_type='vector'` and to 3 for `point_type='matrix'`.
    metric : object
        Provides the `log`, `exp` and `squared_dist` methods used here and
        is forwarded to `variance`.
    weights : array-like or None
        One weight per point; None means a weight of one for every point.
    n_max_iterations : int
        Maximum number of iterations, forwarded to `gs.while_loop`.
    point_type : str, {'vector', 'matrix'}
    epsilon : float
        Relative tolerance of the stopping criterion.
    verbose : bool
        If True, print the iteration count, the final variance and the
        final squared distance between iterates.

    Returns
    -------
    mean : array-like
        The estimated mean. When a single point is given, that point is
        returned without iterating.
    """
    def while_loop_cond(iteration, mean, var, sq_dist):
        # Converged when the variance vanishes or when the last step is
        # small relative to the variance.
        converged = gs.logical_or(
            gs.isclose(var, 0.),
            gs.less_equal(sq_dist, epsilon * var))
        result = ~converged
        # The loop starts with a zero variance, so the first pass has to be
        # forced explicitly.
        return result[0, 0] or iteration == 0

    def while_loop_body(iteration, mean, var, sq_dist):
        logs = metric.log(point=points, base_point=mean)

        # Weighted sum of the logs over the samples, then normalization.
        # NOTE(review): the 'nk,nj->j' signature expects 2-D logs -- confirm
        # the behavior for point_type='matrix'.
        tangent_mean = gs.einsum('nk,nj->j', weights, logs)
        tangent_mean /= sum_weights

        estimate_next = metric.exp(tangent_vec=tangent_mean, base_point=mean)

        sq_dist = metric.squared_dist(estimate_next, mean)
        sq_dists_between_iterates.append(sq_dist)

        var = variance(points=points,
                       weights=weights,
                       metric=metric,
                       base_point=estimate_next)

        mean = estimate_next
        iteration += 1
        return [iteration, mean, var, sq_dist]

    if point_type == 'vector':
        points = gs.to_ndarray(points, to_ndim=2)
    elif point_type == 'matrix':
        points = gs.to_ndarray(points, to_ndim=3)
    n_points = gs.shape(points)[0]

    if weights is None:
        weights = gs.ones((n_points, 1))

    weights = gs.array(weights)
    weights = gs.to_ndarray(weights, to_ndim=2, axis=1)

    sum_weights = gs.sum(weights)

    # The first point is the initial estimate.
    mean = points[0]
    if point_type == 'vector':
        mean = gs.to_ndarray(mean, to_ndim=2)
    elif point_type == 'matrix':
        mean = gs.to_ndarray(mean, to_ndim=3)

    if n_points == 1:
        return mean

    sq_dists_between_iterates = []
    iteration = 0
    sq_dist = gs.array([[0.]])
    var = gs.array([[0.]])

    last_iteration, mean, var, sq_dist = gs.while_loop(
        while_loop_cond,
        while_loop_body,
        loop_vars=[iteration, mean, var, sq_dist],
        maximum_iterations=n_max_iterations)

    if last_iteration == n_max_iterations:
        # The trailing space keeps the two adjacent literals from being
        # glued together as "reached.The".
        print('Maximum number of iterations {} reached. '
              'The mean may be inaccurate'.format(n_max_iterations))

    if verbose:
        print('n_iter: {}, final variance: {}, final dist: {}'.format(
            last_iteration, var, sq_dist))

    mean = gs.to_ndarray(mean, to_ndim=2)
    return mean
コード例 #3
0
    def mean(self,
             points,
             weights=None,
             n_max_iterations=32,
             epsilon=EPSILON,
             point_type='vector',
             mean_method='default',
             verbose=False):
        """Frechet mean of (weighted) points.

        Two estimation methods are available, selected by `mean_method`:

        - ``'default'``: starts at the first point and repeatedly maps the
          weighted average of the logs of all points at the current
          estimate back with exp. Stops when the variance at the estimate
          is close to zero, when the squared distance between successive
          estimates is at most `epsilon` times that variance, or when
          `n_max_iterations` is reached.
        - ``'frechet-poincare-ball'``: fixed-step gradient descent starting
          from an all-zero point, stopping when the largest distance between
          successive barycenters is at most an internal threshold or when
          `n_max_iterations` is reached. This branch does not read
          `weights`, `epsilon`, `point_type` or `verbose`, and calls
          `points.mean(...)`, so `points` must be an array object.

        Parameters
        ----------
        points : array-like, shape=[n_samples, dimension]
        weights : array-like, shape=[n_samples, 1], optional
            Defaults to a weight of one for every point.
        n_max_iterations : int, optional
            Maximum number of iterations.
        epsilon : float, optional
            Relative tolerance of the ``'default'`` stopping criterion.
        point_type : str, {'vector', 'matrix'}, optional
            Selects whether `points` is expanded to 2 or 3 dimensions.
        mean_method : str, {'default', 'frechet-poincare-ball'}, optional
        verbose : bool, optional
            If True, print the iteration count, the final variance and the
            final squared distance between iterates.

        Returns
        -------
        mean : array-like
            the Frechet mean of points, a point on the manifold

        Raises
        ------
        ValueError
            If `mean_method` is not one of the supported methods.
        """
        if mean_method == 'default':

            # TODO(nina): Profile this code to study performance,
            # i.e. what to do with sq_dists_between_iterates.
            def while_loop_cond(iteration, mean, variance, sq_dist):
                # Converged when the variance vanishes or when the last
                # step is small relative to the variance.
                converged = gs.logical_or(
                    gs.isclose(variance, 0.),
                    gs.less_equal(sq_dist, epsilon * variance))
                result = ~converged
                # The loop starts with a zero variance, so the first pass
                # has to be forced explicitly.
                return result[0, 0] or iteration == 0

            def while_loop_body(iteration, mean, variance, sq_dist):
                logs = self.log(point=points, base_point=mean)

                # Weighted sum of the logs over the samples, then
                # normalization.
                # NOTE(review): the 'nk,nj->j' signature expects 2-D logs --
                # confirm the behavior for point_type='matrix'.
                tangent_mean = gs.einsum('nk,nj->j', weights, logs)
                tangent_mean /= sum_weights

                mean_next = self.exp(tangent_vec=tangent_mean, base_point=mean)

                sq_dist = self.squared_dist(mean_next, mean)
                sq_dists_between_iterates.append(sq_dist)

                variance = self.variance(points=points,
                                         weights=weights,
                                         base_point=mean_next)

                mean = mean_next
                iteration += 1
                return [iteration, mean, variance, sq_dist]

            if point_type == 'vector':
                points = gs.to_ndarray(points, to_ndim=2)
            elif point_type == 'matrix':
                points = gs.to_ndarray(points, to_ndim=3)
            n_points = gs.shape(points)[0]

            if weights is None:
                weights = gs.ones((n_points, 1))

            weights = gs.array(weights)
            weights = gs.to_ndarray(weights, to_ndim=2, axis=1)

            sum_weights = gs.sum(weights)

            # The first point is the initial estimate.
            mean = points[0]
            if point_type == 'vector':
                mean = gs.to_ndarray(mean, to_ndim=2)
            elif point_type == 'matrix':
                mean = gs.to_ndarray(mean, to_ndim=3)

            if n_points == 1:
                return mean

            sq_dists_between_iterates = []
            iteration = 0
            sq_dist = gs.array([[0.]])
            variance = gs.array([[0.]])

            last_iteration, mean, variance, sq_dist = gs.while_loop(
                while_loop_cond,
                while_loop_body,
                loop_vars=[iteration, mean, variance, sq_dist],
                maximum_iterations=n_max_iterations)

            if last_iteration == n_max_iterations:
                # Same channel and wording as the other method's warning.
                warnings.warn(
                    'Maximum number of iterations {} reached. The '
                    'mean may be inaccurate'.format(n_max_iterations))

            if verbose:
                print('n_iter: {}, final variance: {}, final dist: {}'.format(
                    last_iteration, variance, sq_dist))

            mean = gs.to_ndarray(mean, to_ndim=2)
            return mean

        if mean_method == 'frechet-poincare-ball':

            # Fixed gradient step size and convergence threshold.
            lr = 1e-3
            tau = 5e-3

            if len(points) == 1:
                return points

            iteration = 0
            convergence = math.inf
            # All-zero array with the shape of a single (batched) point.
            barycenter = points.mean(0, keepdims=True) * 0

            while convergence > tau and n_max_iterations > iteration:

                iteration += 1

                # One copy of the current barycenter per input point.
                expand_barycenter = gs.repeat(barycenter, points.shape[0], 0)

                grad_tangent = 2 * self.log(points, expand_barycenter)

                cc_barycenter = self.exp(
                    lr * grad_tangent.sum(0, keepdims=True), barycenter)

                # Size of the step just taken drives the stopping test.
                convergence = self.dist(cc_barycenter, barycenter).max().item()

                barycenter = cc_barycenter

            if iteration == n_max_iterations:
                warnings.warn(
                    'Maximum number of iterations {} reached. The '
                    'mean may be inaccurate'.format(n_max_iterations))

            return barycenter

        # Previously an unrecognized method fell through and returned None.
        raise ValueError(
            "Unknown mean_method '{}'. Expected 'default' or "
            "'frechet-poincare-ball'.".format(mean_method))