def plot(points, ax=None, space=None, **point_draw_kwargs):
    """
    Plot points of one of the supported spaces.

    Points of SO(3) and SE(3) are shown as trihedrons, points of S2 on a
    sphere and points of H2 in the Poincare disk.

    Parameters
    ----------
    points : array-like
        Points to plot; converted to a 2D array with one point per row.
    ax : matplotlib axes, optional
        Axes to draw on. If None, a new subplot is created: a 2D one for
        'H2', a 3D one for every other space, with symmetric limits.
    space : str
        Name of the space the points belong to; must be in IMPLEMENTED.
    **point_draw_kwargs
        Keyword arguments forwarded to the ``draw`` method of the
        visualization object (trihedron, sphere or Poincare disk).

    Returns
    -------
    ax : the axes the points were drawn on.

    Raises
    ------
    NotImplementedError
        If ``space`` is not in IMPLEMENTED.
    ValueError
        If ``points`` is None.
    """
    if space not in IMPLEMENTED:
        raise NotImplementedError(
            'The plot function is not implemented'
            ' for space {}. The spaces available for visualization'
            ' are: {}.'.format(space, IMPLEMENTED))

    if points is None:
        raise ValueError("No points given for plotting.")

    points = vectorization.to_ndarray(points, to_ndim=2)

    if ax is None:
        # Space names are compared with ``==``: ``is`` tests object
        # identity and only matches string literals by accident of interning.
        if space == 'SE3_GROUP':
            # Scale the axes from the translation part of the points.
            ax_s = AX_SCALE * np.amax(np.abs(points[:, 3:6]))
        elif space == 'SO3_GROUP':
            # Scale the axes from the rotation-vector components.
            ax_s = AX_SCALE * np.amax(np.abs(points[:, :3]))
        else:
            ax_s = AX_SCALE

        if space == 'H2':
            ax = plt.subplot(aspect='equal')
            plt.setp(ax,
                     xlim=(-ax_s, ax_s),
                     ylim=(-ax_s, ax_s),
                     xlabel='X', ylabel='Y')
        else:
            # The 3d projection needs the Axes3d module import.
            ax = plt.subplot(111, projection='3d', aspect='equal')
            plt.setp(ax,
                     xlim=(-ax_s, ax_s),
                     ylim=(-ax_s, ax_s),
                     zlim=(-ax_s, ax_s),
                     xlabel='X', ylabel='Y', zlabel='Z')

    if space in ('SO3_GROUP', 'SE3_GROUP'):
        trihedrons = convert_to_trihedron(points, space=space)
        for t in trihedrons:
            t.draw(ax, **point_draw_kwargs)

    elif space == 'S2':
        sphere = Sphere()
        sphere.add_points(points)
        sphere.draw(ax, **point_draw_kwargs)

    elif space == 'H2':
        poincare_disk = PoincareDisk()
        poincare_disk.add_points(points)
        poincare_disk.draw(ax, **point_draw_kwargs)

    return ax
def plot_and_save_video(geodesics, loss, size=20, fps=10, dpi=100,
                        out='out.mp4', color='red'):
    """
    Render a sequence of geodesic point sets on the sphere over a loss
    heatmap, and record it to an MPEG-4 file with the ffmpeg writer.

    Parameters
    ----------
    geodesics : sequence of array-like
        Successive point sets. The first one is drawn before recording
        starts; each following one is drawn and then captured as a frame.
    loss : object
        Passed to ``Sphere.plot_heatmap`` to draw the background heatmap.
    size : int
        Width and height of the figure, in inches.
    fps : int
        Frame rate of the output video.
    dpi : int
        Resolution used when saving frames.
    out : str
        Path of the output video file.
    color : str
        Matplotlib color of the plotted points.
    """
    ffmpeg_writer_cls = animation.writers['ffmpeg']
    video_writer = ffmpeg_writer_cls(fps=fps)

    figure = plt.figure(figsize=(size, size))
    axes = figure.add_subplot(111, projection='3d', aspect='equal')

    sphere = visualization.Sphere()
    sphere.plot_heatmap(axes, loss)

    initial_points = vectorization.to_ndarray(geodesics[0], to_ndim=2)
    sphere.add_points(initial_points)
    sphere.draw(axes, color=color, marker='.')

    with video_writer.saving(figure, out, dpi=dpi):
        for frame_points in geodesics[1:]:
            frame_array = vectorization.to_ndarray(frame_points, to_ndim=2)
            sphere.draw_points(axes, points=frame_array,
                               color=color, marker='.')
            video_writer.grab_frame()
def convert_to_trihedron(point, space=None):
    """
    Transform rigid transformations into trihedrons s.t.:
    - the trihedron's base point is the translation of the origin
    of R^3 by the translation part of point,
    - the trihedron's orientation is the rotation of the canonical basis
    of R^3 by the rotation part of point.

    Parameters
    ----------
    point : array-like
        One point or a batch of points; converted to a 2D array with one
        point per row. For 'SE3_GROUP', the first SO3_GROUP.dimension
        columns are the rotation vector and the rest the translation.
        For 'SO3_GROUP', each row is a rotation vector.
    space : str
        Either 'SE3_GROUP' or 'SO3_GROUP'.

    Returns
    -------
    trihedrons : list of Trihedron, one per input point.

    Raises
    ------
    NotImplementedError
        If ``space`` is neither 'SE3_GROUP' nor 'SO3_GROUP'.
    """
    point = vectorization.to_ndarray(point, to_ndim=2)
    n_points, _ = point.shape

    dim_rotations = SO3_GROUP.dimension

    # Use ``==``, not ``is``: identity comparison with string literals is
    # unreliable (depends on interning) and warns on Python 3.8+.
    if space == 'SE3_GROUP':
        rot_vec = point[:, :dim_rotations]
        translation = point[:, dim_rotations:]
    elif space == 'SO3_GROUP':
        rot_vec = point
        # Pure rotations: every trihedron is based at the origin.
        translation = gs.zeros((n_points, 3))
    else:
        raise NotImplementedError(
            'Trihedrons are only implemented for SO(3) and SE(3).')

    rot_mat = SO3_GROUP.matrix_from_rotation_vector(rot_vec)
    rot_mat = SO3_GROUP.projection(rot_mat)

    basis = (gs.array([1, 0, 0]),
             gs.array([0, 1, 0]),
             gs.array([0, 0, 1]))

    trihedrons = []
    for rot, base_point in zip(rot_mat, translation):
        # Rotate each canonical basis vector by this point's rotation.
        vec_1, vec_2, vec_3 = (gs.dot(rot, basis_vec) for basis_vec in basis)
        trihedrons.append(Trihedron(base_point, vec_1, vec_2, vec_3))
    return trihedrons