def draw(self, ax, **kwargs):
    """Draw a unit-circle outline and scatter the stored points on ``ax``.

    Parameters
    ----------
    ax : matplotlib axis
        Axis on which the circle artist and the scatter plot are added.
    **kwargs
        Forwarded unchanged to ``ax.scatter``.
    """
    # Black, unfilled circle of radius 1 centred at the origin.
    boundary = plt.Circle((0, 0), radius=1.0, color="black", fill=False)
    ax.add_artist(boundary)

    # Only the first two coordinates of every stored point are plotted.
    xs = []
    ys = []
    for pt in self.points:
        xs.append(gs.to_numpy(pt[0]))
        ys.append(gs.to_numpy(pt[1]))

    ax.scatter(xs, ys, **kwargs)
def draw_curve(self, alpha=1, zorder=0, **kwargs):
    """Draw a curve on the Kendall sphere.

    The stored points (``self.points``) are converted to NumPy and joined
    by a 3D line on ``self.ax``.

    Parameters
    ----------
    alpha : float
        Opacity passed to ``plot3D``.
    zorder : int
        Drawing order passed to ``plot3D``.
    **kwargs
        Any further keyword arguments, forwarded to ``plot3D``.
    """
    # Convert every point once, then split into coordinate lists.
    coords = [gs.to_numpy(pt) for pt in self.points]
    xs = [c[0] for c in coords]
    ys = [c[1] for c in coords]
    zs = [c[2] for c in coords]

    self.ax.plot3D(xs, ys, zs, alpha=alpha, zorder=zorder, **kwargs)
def draw_points(self, alpha=1, zorder=0, **kwargs):
    """Draw points on the Kendall sphere.

    The stored points (``self.points``) are converted to NumPy and shown
    as a 3D scatter plot on ``self.ax``.

    Parameters
    ----------
    alpha : float
        Opacity passed to ``scatter``.
    zorder : int
        Drawing order passed to ``scatter``.
    **kwargs
        Any further keyword arguments, forwarded to ``scatter``.
    """
    # Convert every point once, then split into coordinate lists.
    coords = [gs.to_numpy(pt) for pt in self.points]
    xs = [c[0] for c in coords]
    ys = [c[1] for c in coords]
    zs = [c[2] for c in coords]

    self.ax.scatter(xs, ys, zs, alpha=alpha, zorder=zorder, **kwargs)
def draw(self, n_theta=25, n_phi=13, scale=0.05, elev=60.0, azim=0.0):
    """Draw the sphere regularly sampled with corresponding triangles.

    A radius-0.5 sphere is rendered as a translucent surface plus a
    wireframe, then ``self.draw_triangle`` is called at the sampled
    (theta, phi) locations that pass the visibility test below.

    Parameters
    ----------
    n_theta : int
        Number of theta samples in [0, 2*pi] for the surface grid; the
        triangles use ``n_theta // 2 + 1`` theta samples.
    n_phi : int
        Number of phi samples in [0, pi].
    scale : float
        Forwarded to ``self.draw_triangle``.
    elev : float
        Elevation forwarded to ``self.set_view``.
    azim : float
        Azimuth forwarded to ``self.set_view``.
    """
    self.set_ax()
    self.set_view(elev=elev, azim=azim)
    self.ax.set_axis_off()
    plt.tight_layout()

    theta_grid = gs.linspace(0.0, 2.0 * gs.pi, n_theta)
    phi_grid = gs.linspace(0.0, gs.pi, n_phi)
    sin_phi = gs.sin(phi_grid)

    # Spherical -> Cartesian coordinates on a sphere of radius 0.5.
    surface_x = gs.to_numpy(0.5 * gs.outer(sin_phi, gs.cos(theta_grid)))
    surface_y = gs.to_numpy(0.5 * gs.outer(sin_phi, gs.sin(theta_grid)))
    surface_z = gs.to_numpy(
        0.5 * gs.outer(gs.cos(phi_grid), gs.ones_like(theta_grid))
    )

    self.ax.plot_surface(
        surface_x,
        surface_y,
        surface_z,
        rstride=1,
        cstride=1,
        color="grey",
        linewidth=0,
        alpha=0.1,
        zorder=-1,
    )
    self.ax.plot_wireframe(
        surface_x,
        surface_y,
        surface_z,
        linewidths=0.6,
        color="grey",
        alpha=0.6,
        zorder=-1,
    )

    def phi_bound(angle):
        # Largest phi at which a triangle is drawn for the given angle.
        # Reads self.elev / self.azim -- presumably stored by set_view;
        # TODO confirm against set_view.
        slope = (2.0 * self.elev - gs.pi) / gs.pi
        return gs.pi - self.elev + slope * abs(self.azim - angle)

    seam = self.azim + gs.pi
    for theta in gs.linspace(0.0, 2.0 * gs.pi, n_theta // 2 + 1):
        first_half = theta <= seam
        if first_half:
            bound = phi_bound(theta)
        else:
            # Past the seam the bound is evaluated at the mirrored angle.
            bound = phi_bound(2.0 * self.azim + 2.0 * gs.pi - theta)

        for phi in phi_grid:
            # Inclusive bound up to the seam, strict bound past it.
            if first_half:
                visible = phi <= bound
            else:
                visible = phi < bound
            if visible:
                self.draw_triangle(theta, phi, scale)
def draw(self, n_r=7, n_theta=25, scale=0.05):
    """Draw the disk regularly sampled with corresponding triangles.

    A polar grid of radius pi / 4 is drawn (filled outer ring, concentric
    rings and radial spokes), then ``self.draw_triangle`` is called on a
    polar sampling of the disk.

    Parameters
    ----------
    n_r : int
        Number of radius samples in [0, pi / 4].
    n_theta : int
        Number of angle samples in [0, 2*pi] for the grid; the triangles
        use ``n_theta // 2 + 1`` angle samples.
    scale : float
        Forwarded to ``self.draw_triangle``.
    """
    self.set_ax()
    self.ax.set_axis_off()
    plt.tight_layout()

    radii = gs.linspace(0.0, gs.pi / 4.0, n_r)
    angles = gs.linspace(0.0, 2.0 * gs.pi, n_theta)

    # Polar -> Cartesian grid; rows follow radii, columns follow angles.
    grid_x = gs.to_numpy(gs.outer(radii, gs.cos(angles)))
    grid_y = gs.to_numpy(gs.outer(radii, gs.sin(angles)))

    # Shade the disk enclosed by the outermost ring.
    self.ax.fill(
        list(grid_x[-1, :]),
        list(grid_y[-1, :]),
        color="grey",
        alpha=0.1,
        zorder=-1,
    )

    grid_style = {"linewidth": 0.6, "color": "grey", "alpha": 0.6, "zorder": -1}

    # Concentric rings: one per radius sample.
    for ring_x, ring_y in zip(grid_x, grid_y):
        self.ax.plot(ring_x, ring_y, **grid_style)

    # Radial spokes: one per angle sample.
    for spoke_x, spoke_y in zip(grid_x.T, grid_y.T):
        self.ax.plot(spoke_x, spoke_y, **grid_style)

    triangle_angles = gs.linspace(0.0, 2.0 * gs.pi, n_theta // 2 + 1)
    for r in radii:
        for theta in triangle_angles:
            # The theta == 0 sample always draws the triangle at the origin.
            if theta == 0.0:
                self.draw_triangle(0.0, 0.0, scale)
            else:
                self.draw_triangle(r, theta, scale)
def _convert_gs_to_np(value):
    """Recursively convert backend arrays to NumPy arrays.

    Backend arrays (as reported by ``gs.is_array``) are converted with
    ``gs.to_numpy``. Lists, tuples and dicts are rebuilt as a plain
    ``list``, ``tuple`` or ``dict`` with their items converted
    recursively (dict keys are kept as they are). Any other value is
    returned unchanged.
    """
    if gs.is_array(value):
        return gs.to_numpy(value)

    if isinstance(value, tuple):
        return tuple(_convert_gs_to_np(item) for item in value)

    if isinstance(value, list):
        return [_convert_gs_to_np(item) for item in value]

    if isinstance(value, dict):
        return {key: _convert_gs_to_np(item) for key, item in value.items()}

    return value
def draw_points(self, ax, points=None, **scatter_kwargs):
    """Scatter 3D points on ``ax`` and optionally annotate them.

    Parameters
    ----------
    ax : matplotlib 3D axis
        Axis receiving the scatter plot and the text annotations.
    points : iterable of array-like, optional
        Points to draw; ``self.points`` is used when this is None. Each
        point is detached from autodiff and converted to NumPy first.
    **scatter_kwargs
        Forwarded to ``ax.scatter``. When a ``"label"`` entry is present
        and has as many items as there are points, item ``i`` is also
        written as text next to point ``i``.
    """
    if points is None:
        points = self.points

    coords = [gs.to_numpy(gs.autodiff.detach(pt)) for pt in points]

    xs = [c[0] for c in coords]
    ys = [c[1] for c in coords]
    zs = [c[2] for c in coords]
    ax.scatter(xs, ys, zs, **scatter_kwargs)

    # Nothing to annotate without points or without labels.
    if not coords or "label" not in scatter_kwargs:
        return

    labels = scatter_kwargs["label"]
    if len(labels) != len(coords):
        return

    for idx, coord in enumerate(coords):
        ax.text(
            coord[0],
            coord[1],
            coord[2],
            labels[idx],
            size=10,
            zorder=1,
            color="k",
        )
def main():
    r"""Compute and visualize a geodesic regression on the sphere.

    The generative model of the data is:
    :math:`Z = Exp_{\beta_0}(\beta_1.X)` and :math:`Y = Exp_Z(\epsilon)`
    where:

    - :math:`Exp` denotes the Riemannian exponential,
    - :math:`\beta_0` is called the intercept,
    - :math:`\beta_1` is called the coefficient,
    - :math:`\epsilon \sim N(0, 1)` is a standard Gaussian noise,
    - :math:`X` is the input, :math:`Y` is the target.

    The function prints the squared errors on the fitted intercept and
    coefficient and the training and true R^2, then shows a 3D plot of
    the noisy data, the fitted points and the fitted geodesic.
    """
    # Generate noise-free data
    n_samples = 50
    X = gs.random.rand(n_samples)
    X -= gs.mean(X)

    intercept = SPACE.random_uniform()
    coef = SPACE.to_tangent(5.0 * gs.random.rand(EMBEDDING_DIM), base_point=intercept)
    y = METRIC.exp(X[:, None] * coef, base_point=intercept)

    # Generate normal noise, projected on the tangent spaces at the
    # noise-free targets and scaled down by a factor 2 * pi.
    normal_noise = gs.random.normal(size=(n_samples, EMBEDDING_DIM))
    noise = SPACE.to_tangent(normal_noise, base_point=y) / gs.pi / 2

    # Mean squared norm of the noise, computed before y is perturbed.
    rss = gs.sum(METRIC.squared_norm(noise, base_point=y)) / n_samples

    # Add noise
    y = METRIC.exp(noise, y)

    # True noise level and R2
    estimator = FrechetMean(METRIC)
    estimator.fit(y)
    variance_ = variance(y, estimator.estimate_, metric=METRIC)
    r2 = 1 - rss / variance_

    # Fit geodesic regression
    gr = GeodesicRegression(SPACE, center_X=False, method="extrinsic", verbose=True)
    gr.fit(X, y, compute_training_score=True)
    intercept_hat, coef_hat = gr.intercept_, gr.coef_

    # Measure Mean Squared Error
    mse_intercept = METRIC.squared_dist(intercept_hat, intercept)

    # The fitted coefficient lives in the tangent space at intercept_hat;
    # transport it to the true intercept before comparing with coef.
    tangent_vec_to_transport = coef_hat
    tangent_vec_of_transport = METRIC.log(intercept, base_point=intercept_hat)
    transported_coef_hat = METRIC.parallel_transport(
        tangent_vec=tangent_vec_to_transport,
        base_point=intercept_hat,
        direction=tangent_vec_of_transport,
    )
    mse_coef = METRIC.squared_norm(transported_coef_hat - coef, base_point=intercept)

    # Measure goodness of fit
    r2_hat = gr.training_score_

    print(f"MSE on the intercept: {mse_intercept:.2e}")
    # Kept on one line: a single-quoted f-string cannot span a line break.
    print(f"MSE on the coef, i.e. initial velocity: {mse_coef:.2e}")
    print(f"Determination coefficient: R^2={r2_hat:.2f}")
    print(f"True R^2: {r2:.2f}")

    # Plot
    fitted_data = gr.predict(X)
    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(111, projection="3d")
    sphere_visu = visualization.Sphere(n_meridians=30)
    ax = sphere_visu.set_ax(ax=ax)

    path = METRIC.geodesic(initial_point=intercept_hat, initial_tangent_vec=coef_hat)
    # The geodesic travels at speed ||coef_hat||, so one full great circle
    # (length 2 * pi) is covered for t in [0, 2 * pi / ||coef_hat||].
    regressed_geodesic = path(
        gs.linspace(0.0, 1.0, 100) * gs.pi * 2 / METRIC.norm(coef_hat)
    )
    regressed_geodesic = gs.to_numpy(gs.autodiff.detach(regressed_geodesic))

    size = 10
    marker = "o"
    sphere_visu.draw_points(ax, gs.array([intercept_hat]), marker=marker, c="r", s=size)
    sphere_visu.draw_points(ax, y, marker=marker, c="b", s=size)
    sphere_visu.draw_points(ax, fitted_data, marker=marker, c="g", s=size)

    ax.plot(
        regressed_geodesic[:, 0],
        regressed_geodesic[:, 1],
        regressed_geodesic[:, 2],
        c="gray",
    )
    sphere_visu.draw(ax, linewidth=1)
    ax.grid(False)
    plt.axis("off")

    plt.show()
def draw(self, ax, **kwargs):
    """Scatter the first two coordinates of the stored points on ``ax``.

    Parameters
    ----------
    ax : matplotlib axis
        Axis on which the scatter plot is drawn.
    **kwargs
        Forwarded unchanged to ``ax.scatter``.
    """
    xs = []
    ys = []
    for pt in self.points:
        xs.append(gs.to_numpy(pt[0]))
        ys.append(gs.to_numpy(pt[1]))

    ax.scatter(xs, ys, **kwargs)