def plot(points, ax=None, space=None, point_type='extrinsic', **point_draw_kwargs): """Plot points in the 3D Special Euclidean Group. Plot points in the 3D Special Euclidean Group, by showing them as trihedrons. """ if space not in IMPLEMENTED: raise NotImplementedError( 'The plot function is not implemented' ' for space {}. The spaces available for visualization' ' are: {}.'.format(space, IMPLEMENTED)) if points is None: raise ValueError("No points given for plotting.") points = gs.to_ndarray(points, to_ndim=2) if space in ('SO3_GROUP', 'SE3_GROUP'): if ax is None: ax = plt.subplot(111, projection='3d') if space == 'SE3_GROUP': ax_s = AX_SCALE * gs.amax(gs.abs(points[:, 3:6])) elif space == 'SO3_GROUP': ax_s = AX_SCALE * gs.amax(gs.abs(points[:, :3])) plt.setp(ax, xlim=(-ax_s, ax_s), ylim=(-ax_s, ax_s), zlim=(-ax_s, ax_s), xlabel='X', ylabel='Y', zlabel='Z') trihedrons = convert_to_trihedron(points, space=space) for t in trihedrons: t.draw(ax, **point_draw_kwargs) elif space == 'S1': circle = Circle() ax = circle.set_ax(ax=ax) circle.add_points(points) circle.draw(ax, **point_draw_kwargs) elif space == 'S2': sphere = Sphere() ax = sphere.set_ax(ax=ax) sphere.add_points(points) sphere.draw(ax, **point_draw_kwargs) elif space == 'H2_poincare_disk': poincare_disk = PoincareDisk(point_type=point_type) ax = poincare_disk.set_ax(ax=ax) poincare_disk.add_points(points) poincare_disk.draw(ax, **point_draw_kwargs) elif space == 'poincare_polydisk': n_disks = points.shape[1] poincare_poly_disk = PoincarePolyDisk(point_type=point_type, n_disks=n_disks) n_columns = gs.ceil(n_disks**0.5) n_rows = gs.ceil(n_disks / n_columns) axis_list = [] for i_disk in range(n_disks): axis_list.append(ax.add_subplot(n_rows, n_columns, i_disk + 1)) for i_disk, ax in enumerate(axis_list): ax = poincare_poly_disk.set_ax(ax=ax) poincare_poly_disk.clear_points() poincare_poly_disk.add_points(points[:, i_disk, ...]) poincare_poly_disk.draw(ax, **point_draw_kwargs) elif space == 'H2_poincare_half_plane': poincare_half_plane = PoincareHalfPlane() ax = poincare_half_plane.set_ax(ax=ax) poincare_half_plane.add_points(points) poincare_half_plane.draw(ax, **point_draw_kwargs) elif space == 'H2_klein_disk': klein_disk = KleinDisk() ax = klein_disk.set_ax(ax=ax) klein_disk.add_points(points) klein_disk.draw(ax, **point_draw_kwargs) return ax
def plot(points, ax=None, space=None, point_type=None, **point_draw_kwargs): """Plot points in one of the implemented manifolds. The implemented manifolds are: - the special orthogonal group SO(3) - the special Euclidean group SE(3) - the circle S1 and the sphere S2 - the hyperbolic plane (the Poincare disk, the Poincare half plane and the Klein disk) - the Poincare polydisk - the Kendall shape space of 2D triangles - the Kendall shape space of 3D triangles Parameters ---------- points : array-like, shape=[..., dim] Points to be plotted. space: str, optional, {'SO3_GROUP', 'SE3_GROUP', 'S1', 'S2', 'H2_poincare_disk', 'H2_poincare_half_plane', 'H2_klein_disk', 'poincare_polydisk', 'S32', 'M32', 'S33', 'M33'} point_type: str, optional, {'extrinsic', 'ball', 'half-space', 'pre-shape'} """ if space not in IMPLEMENTED: raise NotImplementedError( 'The plot function is not implemented' ' for space {}. The spaces available for visualization' ' are: {}.'.format(space, IMPLEMENTED)) if points is None: raise ValueError("No points given for plotting.") if points.ndim < 2: points = gs.expand_dims(points, 0) if space in ('SO3_GROUP', 'SE3_GROUP'): if ax is None: ax = plt.subplot(111, projection='3d') if space == 'SE3_GROUP': ax_s = AX_SCALE * gs.amax(gs.abs(points[:, 3:6])) elif space == 'SO3_GROUP': ax_s = AX_SCALE * gs.amax(gs.abs(points[:, :3])) ax_s = float(ax_s) bounds = (-ax_s, ax_s) plt.setp(ax, xlim=bounds, ylim=bounds, zlim=bounds, xlabel='X', ylabel='Y', zlabel='Z') trihedrons = convert_to_trihedron(points, space=space) for t in trihedrons: t.draw(ax, **point_draw_kwargs) elif space == 'S1': circle = Circle() ax = circle.set_ax(ax=ax) circle.add_points(points) circle.draw(ax, **point_draw_kwargs) elif space == 'S2': sphere = Sphere() ax = sphere.set_ax(ax=ax) sphere.add_points(points) sphere.draw(ax, **point_draw_kwargs) elif space == 'H2_poincare_disk': if point_type is None: point_type = 'extrinsic' poincare_disk = PoincareDisk(point_type=point_type) ax = poincare_disk.set_ax(ax=ax) poincare_disk.add_points(points) poincare_disk.draw(ax, **point_draw_kwargs) plt.axis('off') elif space == 'poincare_polydisk': if point_type is None: point_type = 'extrinsic' n_disks = points.shape[1] poincare_poly_disk = PoincarePolyDisk(point_type=point_type, n_disks=n_disks) n_columns = int(gs.ceil(n_disks**0.5)) n_rows = int(gs.ceil(n_disks / n_columns)) axis_list = [] for i_disk in range(n_disks): axis_list.append(ax.add_subplot(n_rows, n_columns, i_disk + 1)) for i_disk, one_ax in enumerate(axis_list): ax = poincare_poly_disk.set_ax(ax=one_ax) poincare_poly_disk.clear_points() poincare_poly_disk.add_points(points[:, i_disk, ...]) poincare_poly_disk.draw(ax, **point_draw_kwargs) elif space == 'H2_poincare_half_plane': if point_type is None: point_type = 'half-space' poincare_half_plane = PoincareHalfPlane(point_type=point_type) ax = poincare_half_plane.set_ax(points=points, ax=ax) poincare_half_plane.add_points(points) poincare_half_plane.draw(ax, **point_draw_kwargs) elif space == 'H2_klein_disk': klein_disk = KleinDisk() ax = klein_disk.set_ax(ax=ax) klein_disk.add_points(points) klein_disk.draw(ax, **point_draw_kwargs) elif space == 'SE2_GROUP': plane = SpecialEuclidean2() ax = plane.set_ax(ax=ax) plane.add_points(points) plane.draw(ax, **point_draw_kwargs) elif space == 'S32': sphere = KendallSphere() sphere.add_points(points) sphere.draw() sphere.draw_points() ax = sphere.ax elif space == 'M32': sphere = KendallSphere(point_type='extrinsic') sphere.add_points(points) sphere.draw() sphere.draw_points() ax = sphere.ax elif space == 'S33': disk = KendallDisk() disk.add_points(points) disk.draw() disk.draw_points() ax = disk.ax elif space == 'M33': disk = KendallDisk(point_type='extrinsic') disk.add_points(points) disk.draw() disk.draw_points() ax = disk.ax return ax