def mean(self, points, weights=None, n_max_iterations=32, epsilon=EPSILON,
         point_type='vector', verbose=False):
    """Compute the Frechet mean of (weighted) points.

    The estimate is initialized at the first point and refined
    iteratively: the points are mapped with `self.log` at the current
    estimate, their weighted average is taken, and the result is mapped
    back with `self.exp`. The iteration stops when the variance at the
    estimate is close to zero, when the squared distance between two
    successive estimates is at most `epsilon` times that variance, or
    after `n_max_iterations` iterations.

    Parameters
    ----------
    points : array-like, shape=[n_samples, dimension]
    weights : array-like, shape=[n_samples, 1], optional
        Defaults to all ones. The weighted tangent mean is divided by
        the sum of the weights; a zero sum is not handled.
    n_max_iterations : int, optional
        Maximum number of iterations of the update loop.
    epsilon : float, optional
        Relative tolerance used in the stopping criterion.
    point_type : str, optional
        'vector' converts the inputs to 2 dimensions and 'matrix' to
        3 dimensions; any other value skips the conversion.
    verbose : bool, optional
        If True, print the iteration count, final variance and final
        squared distance.

    Returns
    -------
    mean : array-like
        The estimate of the Frechet mean. When only one point is given,
        that point is returned without iterating.
    """
    # TODO(nina): Profile this code to study performance,
    # i.e. what to do with sq_dists_between_iterates.

    def while_loop_cond(iteration, mean, variance, sq_dist):
        result = ~gs.logical_or(
            gs.isclose(variance, 0.),
            gs.less_equal(sq_dist, epsilon * variance))
        # The loop variables start with variance == 0, so the
        # `iteration == 0` clause is what lets the first step run.
        return result[0, 0] or iteration == 0

    def while_loop_body(iteration, mean, variance, sq_dist):
        logs = self.log(point=points, base_point=mean)
        tangent_mean = gs.einsum('nk,nj->j', weights, logs)
        tangent_mean /= sum_weights

        mean_next = self.exp(tangent_vec=tangent_mean, base_point=mean)

        sq_dist = self.squared_dist(mean_next, mean)
        # Only recorded; nothing in this function reads the list back.
        sq_dists_between_iterates.append(sq_dist)

        variance = self.variance(
            points=points, weights=weights, base_point=mean_next)

        mean = mean_next
        iteration += 1
        return [iteration, mean, variance, sq_dist]

    if point_type == 'vector':
        points = gs.to_ndarray(points, to_ndim=2)
    if point_type == 'matrix':
        points = gs.to_ndarray(points, to_ndim=3)
    n_points = gs.shape(points)[0]

    if weights is None:
        weights = gs.ones((n_points, 1))

    weights = gs.array(weights)
    weights = gs.to_ndarray(weights, to_ndim=2, axis=1)

    sum_weights = gs.sum(weights)

    mean = points[0]
    if point_type == 'vector':
        mean = gs.to_ndarray(mean, to_ndim=2)
    if point_type == 'matrix':
        mean = gs.to_ndarray(mean, to_ndim=3)

    if n_points == 1:
        return mean

    sq_dists_between_iterates = []
    iteration = 0
    sq_dist = gs.array([[0.]])
    variance = gs.array([[0.]])

    # The condition and body are passed directly; wrapping them in
    # lambdas with the same signature adds nothing.
    last_iteration, mean, variance, sq_dist = gs.while_loop(
        while_loop_cond,
        while_loop_body,
        loop_vars=[iteration, mean, variance, sq_dist],
        maximum_iterations=n_max_iterations)

    if last_iteration == n_max_iterations:
        print('Maximum number of iterations {} reached. '
              'The mean may be inaccurate'.format(n_max_iterations))

    if verbose:
        print('n_iter: {}, final variance: {}, final dist: {}'.format(
            last_iteration, variance, sq_dist))

    mean = gs.to_ndarray(mean, to_ndim=2)
    return mean
def _default_gradient_descent(points, metric, weights,
                              n_max_iterations, point_type, epsilon,
                              verbose):
    """Estimate the Frechet mean of (weighted) points iteratively.

    The estimate is initialized at the first point. At each step the
    points are mapped with `metric.log` at the current estimate, their
    weighted average is taken, and the result is mapped back with
    `metric.exp`. The iteration stops when the variance at the estimate
    is close to zero, when the squared distance between two successive
    estimates is at most `epsilon` times that variance, or after
    `n_max_iterations` iterations.

    Parameters
    ----------
    points : array-like
        Converted to 2 dimensions when `point_type` is 'vector' and to
        3 dimensions when it is 'matrix'.
    metric : object
        Must provide `log`, `exp` and `squared_dist`; it is also passed
        to the module-level `variance` function.
    weights : array-like or None
        Defaults to all ones when None. The weighted tangent mean is
        divided by the sum of the weights; a zero sum is not handled.
    n_max_iterations : int
        Maximum number of iterations of the update loop.
    point_type : str
        'vector' or 'matrix'; any other value skips the conversions.
    epsilon : float
        Relative tolerance used in the stopping criterion.
    verbose : bool
        If True, print the iteration count, final variance and final
        squared distance.

    Returns
    -------
    mean : array-like
        The estimate of the Frechet mean. When only one point is given,
        that point is returned without iterating.
    """
    def while_loop_cond(iteration, mean, var, sq_dist):
        result = ~gs.logical_or(
            gs.isclose(var, 0.),
            gs.less_equal(sq_dist, epsilon * var))
        # The loop variables start with var == 0, so the
        # `iteration == 0` clause is what lets the first step run.
        return result[0, 0] or iteration == 0

    def while_loop_body(iteration, mean, var, sq_dist):
        logs = metric.log(point=points, base_point=mean)
        tangent_mean = gs.einsum('nk,nj->j', weights, logs)
        tangent_mean /= sum_weights

        estimate_next = metric.exp(
            tangent_vec=tangent_mean, base_point=mean)

        sq_dist = metric.squared_dist(estimate_next, mean)
        # Only recorded; nothing in this function reads the list back.
        sq_dists_between_iterates.append(sq_dist)

        var = variance(
            points=points,
            weights=weights,
            metric=metric,
            base_point=estimate_next)

        mean = estimate_next
        iteration += 1
        return [iteration, mean, var, sq_dist]

    if point_type == 'vector':
        points = gs.to_ndarray(points, to_ndim=2)
    if point_type == 'matrix':
        points = gs.to_ndarray(points, to_ndim=3)
    n_points = gs.shape(points)[0]

    if weights is None:
        weights = gs.ones((n_points, 1))

    weights = gs.array(weights)
    weights = gs.to_ndarray(weights, to_ndim=2, axis=1)

    sum_weights = gs.sum(weights)

    mean = points[0]
    if point_type == 'vector':
        mean = gs.to_ndarray(mean, to_ndim=2)
    if point_type == 'matrix':
        mean = gs.to_ndarray(mean, to_ndim=3)

    if n_points == 1:
        return mean

    sq_dists_between_iterates = []
    iteration = 0
    sq_dist = gs.array([[0.]])
    var = gs.array([[0.]])

    # The condition and body are passed directly; wrapping them in
    # lambdas with the same signature adds nothing.
    last_iteration, mean, var, sq_dist = gs.while_loop(
        while_loop_cond,
        while_loop_body,
        loop_vars=[iteration, mean, var, sq_dist],
        maximum_iterations=n_max_iterations)

    if last_iteration == n_max_iterations:
        print('Maximum number of iterations {} reached. '
              'The mean may be inaccurate'.format(n_max_iterations))

    if verbose:
        print('n_iter: {}, final variance: {}, final dist: {}'.format(
            last_iteration, var, sq_dist))

    mean = gs.to_ndarray(mean, to_ndim=2)
    return mean
def mean(self, points, weights=None, n_max_iterations=32, epsilon=EPSILON,
         point_type='vector', mean_method='default', verbose=False):
    """Frechet mean of (weighted) points.

    Two estimation methods are available, selected by `mean_method`:

    - 'default': starting from the first point, repeatedly map the
      points with `self.log` at the current estimate, take their
      weighted average and map it back with `self.exp`. Stops when the
      variance at the estimate is close to zero, when the squared
      distance between successive estimates is at most `epsilon` times
      that variance, or after `n_max_iterations` iterations.
    - 'frechet-poincare-ball': fixed-step gradient descent started at
      the zero point. This branch does not use `weights`, `epsilon`,
      `point_type` or `verbose`.

    Parameters
    ----------
    points : array-like, shape=[n_samples, dimension]
    weights : array-like, shape=[n_samples, 1], optional
        Defaults to all ones. Used by the 'default' method only, where
        the weighted tangent mean is divided by the sum of the weights;
        a zero sum is not handled.
    n_max_iterations : int, optional
        Maximum number of iterations of either method.
    epsilon : float, optional
        Relative tolerance of the 'default' stopping criterion.
    point_type : str, optional
        'vector' converts the inputs to 2 dimensions and 'matrix' to
        3 dimensions; any other value skips the conversion.
    mean_method : str, optional
        'default' or 'frechet-poincare-ball'.
    verbose : bool, optional
        If True, the 'default' method prints the iteration count, final
        variance and final squared distance.

    Returns
    -------
    mean : array-like
        the Frechet mean of points, a point on the manifold

    Raises
    ------
    ValueError
        If `mean_method` is not one of the supported values.
    """
    if mean_method == 'default':
        # TODO(nina): Profile this code to study performance,
        # i.e. what to do with sq_dists_between_iterates.

        def while_loop_cond(iteration, mean, variance, sq_dist):
            result = ~gs.logical_or(
                gs.isclose(variance, 0.),
                gs.less_equal(sq_dist, epsilon * variance))
            # The loop variables start with variance == 0, so the
            # `iteration == 0` clause is what lets the first step run.
            return result[0, 0] or iteration == 0

        def while_loop_body(iteration, mean, variance, sq_dist):
            logs = self.log(point=points, base_point=mean)
            tangent_mean = gs.einsum('nk,nj->j', weights, logs)
            tangent_mean /= sum_weights

            mean_next = self.exp(
                tangent_vec=tangent_mean, base_point=mean)

            sq_dist = self.squared_dist(mean_next, mean)
            # Only recorded; nothing in this function reads it back.
            sq_dists_between_iterates.append(sq_dist)

            variance = self.variance(
                points=points, weights=weights, base_point=mean_next)

            mean = mean_next
            iteration += 1
            return [iteration, mean, variance, sq_dist]

        if point_type == 'vector':
            points = gs.to_ndarray(points, to_ndim=2)
        if point_type == 'matrix':
            points = gs.to_ndarray(points, to_ndim=3)
        n_points = gs.shape(points)[0]

        if weights is None:
            weights = gs.ones((n_points, 1))

        weights = gs.array(weights)
        weights = gs.to_ndarray(weights, to_ndim=2, axis=1)

        sum_weights = gs.sum(weights)

        mean = points[0]
        if point_type == 'vector':
            mean = gs.to_ndarray(mean, to_ndim=2)
        if point_type == 'matrix':
            mean = gs.to_ndarray(mean, to_ndim=3)

        if n_points == 1:
            return mean

        sq_dists_between_iterates = []
        iteration = 0
        sq_dist = gs.array([[0.]])
        variance = gs.array([[0.]])

        # The condition and body are passed directly; wrapping them in
        # lambdas with the same signature adds nothing.
        last_iteration, mean, variance, sq_dist = gs.while_loop(
            while_loop_cond,
            while_loop_body,
            loop_vars=[iteration, mean, variance, sq_dist],
            maximum_iterations=n_max_iterations)

        if last_iteration == n_max_iterations:
            print('Maximum number of iterations {} reached. '
                  'The mean may be inaccurate'.format(n_max_iterations))

        if verbose:
            print('n_iter: {}, final variance: {}, final dist: {}'.format(
                last_iteration, variance, sq_dist))

        mean = gs.to_ndarray(mean, to_ndim=2)
        return mean

    if mean_method == 'frechet-poincare-ball':
        # Step size applied to the summed tangent-space gradient.
        lr = 1e-3
        # Stop once successive barycenters are at most this far apart.
        tau = 5e-3

        if len(points) == 1:
            return points

        iteration = 0
        convergence = math.inf
        # Multiplying the coordinate-wise mean by 0 yields a zero point
        # with the same shape as a single (kept-dims) point.
        barycenter = points.mean(0, keepdims=True) * 0

        while convergence > tau and iteration < n_max_iterations:
            iteration += 1

            expand_barycenter = gs.repeat(barycenter, points.shape[0], 0)
            grad_tangent = 2 * self.log(points, expand_barycenter)

            cc_barycenter = self.exp(
                lr * grad_tangent.sum(0, keepdims=True), barycenter)

            convergence = self.dist(
                cc_barycenter, barycenter).max().item()
            barycenter = cc_barycenter

        if iteration == n_max_iterations:
            warnings.warn(
                'Maximum number of iterations {} reached. The '
                'mean may be inaccurate'.format(n_max_iterations))

        return barycenter

    # Previously an unrecognized method fell through and returned None.
    raise ValueError(
        'Unknown mean_method {!r}; expected \'default\' or '
        '\'frechet-poincare-ball\'.'.format(mean_method))