예제 #1
0
def plot(points, ax=None, space=None, **point_draw_kwargs):
    """
    Plot points of one of the supported spaces.

    Elements of SO(3) and SE(3) are drawn as trihedrons, points of S2 are
    drawn on a sphere and points of H2 in the Poincare disk.

    Parameters
    ----------
    points : array-like
        Points to plot, converted to a 2D array (one point per row).
    ax : matplotlib Axes, optional
        Axes to draw on. When None, a new subplot is created: a 2D
        equal-aspect subplot for 'H2', a 3D subplot for any other space.
    space : str
        Name of the space the points belong to; must be in IMPLEMENTED.
    **point_draw_kwargs
        Forwarded to the ``draw`` method of the trihedrons, Sphere or
        PoincareDisk.

    Returns
    -------
    ax : the Axes the points were drawn on.

    Raises
    ------
    NotImplementedError
        If `space` is not in IMPLEMENTED.
    ValueError
        If `points` is None.
    """
    if space not in IMPLEMENTED:
        raise NotImplementedError(
            'The plot function is not implemented'
            ' for space {}. The spaces available for visualization'
            ' are: {}.'.format(space, IMPLEMENTED))

    if points is None:
        raise ValueError("No points given for plotting.")

    points = vectorization.to_ndarray(points, to_ndim=2)

    # Space names are compared with `==`, not `is`: `is` checks object
    # identity and only matches string literals by accident of interning.
    if ax is None:
        # Half-extent of the axes. For SE3 it scales with columns 3:6 and for
        # SO3 with columns :3 (presumably translation / rotation vector
        # respectively — confirm against the point layout).
        if space == 'SE3_GROUP':
            ax_s = AX_SCALE * np.amax(np.abs(points[:, 3:6]))
        elif space == 'SO3_GROUP':
            ax_s = AX_SCALE * np.amax(np.abs(points[:, :3]))
        else:
            ax_s = AX_SCALE

        if space == 'H2':
            ax = plt.subplot(aspect='equal')
            plt.setp(ax,
                     xlim=(-ax_s, ax_s),
                     ylim=(-ax_s, ax_s),
                     xlabel='X',
                     ylabel='Y')

        else:
            # The 3d projection needs the Axes3d module import.
            ax = plt.subplot(111, projection='3d', aspect='equal')
            plt.setp(ax,
                     xlim=(-ax_s, ax_s),
                     ylim=(-ax_s, ax_s),
                     zlim=(-ax_s, ax_s),
                     xlabel='X',
                     ylabel='Y',
                     zlabel='Z')

    if space in ('SO3_GROUP', 'SE3_GROUP'):
        trihedrons = convert_to_trihedron(points, space=space)
        for t in trihedrons:
            t.draw(ax, **point_draw_kwargs)

    elif space == 'S2':
        sphere = Sphere()
        sphere.add_points(points)
        sphere.draw(ax, **point_draw_kwargs)

    elif space == 'H2':
        poincare_disk = PoincareDisk()
        poincare_disk.add_points(points)
        poincare_disk.draw(ax, **point_draw_kwargs)

    return ax
예제 #2
0
def plot_and_save_video(geodesics,
                        loss,
                        size=20,
                        fps=10,
                        dpi=100,
                        out='out.mp4',
                        color='red'):
    """
    Render a sequence of geodesics over a loss heatmap on the sphere and
    write the animation to an mpeg 4 file.

    The heatmap of `loss` and the points of ``geodesics[0]`` are drawn
    before recording starts. Then, for each later entry of `geodesics`,
    its points are drawn and one video frame is grabbed.

    Parameters
    ----------
    geodesics : sequence of array-like
        Point sets, one per step; each is converted to a 2D array.
    loss : passed unchanged to ``Sphere.plot_heatmap``.
    size : width and height of the figure.
    fps : frames per second of the output video.
    dpi : resolution used when saving the video.
    out : path of the output file.
    color : color used for the drawn points.
    """
    marker_style = {'color': color, 'marker': '.'}

    writer_class = animation.writers['ffmpeg']
    video_writer = writer_class(fps=fps)

    figure = plt.figure(figsize=(size, size))
    axes = figure.add_subplot(111, projection='3d', aspect='equal')

    sphere = visualization.Sphere()
    sphere.plot_heatmap(axes, loss)
    sphere.add_points(vectorization.to_ndarray(geodesics[0], to_ndim=2))
    sphere.draw(axes, **marker_style)

    with video_writer.saving(figure, out, dpi=dpi):
        for frame_points in geodesics[1:]:
            sphere.draw_points(
                axes,
                points=vectorization.to_ndarray(frame_points, to_ndim=2),
                **marker_style)
            video_writer.grab_frame()
예제 #3
0
def convert_to_trihedron(point, space=None):
    """
    Transform a rigid transformation into a trihedron s.t.:
    - the trihedron's base point is the translation of the origin
    of R^3 by the translation part of point,
    - the trihedron's orientation is the rotation of the canonical basis
    of R^3 by the rotation part of point.

    Parameters
    ----------
    point : array-like
        One or several points, converted to a 2D array (one per row).
        For 'SE3_GROUP' the first ``SO3_GROUP.dimension`` columns are the
        rotation vector and the remaining columns the translation; for
        'SO3_GROUP' the whole row is the rotation vector.
    space : str
        Either 'SO3_GROUP' or 'SE3_GROUP'.

    Returns
    -------
    trihedrons : list of Trihedron, one per row of `point`.

    Raises
    ------
    NotImplementedError
        If `space` is neither 'SO3_GROUP' nor 'SE3_GROUP'.
    """
    point = vectorization.to_ndarray(point, to_ndim=2)
    n_points, _ = point.shape

    dim_rotations = SO3_GROUP.dimension

    # Compare with `==`: `is` tests object identity and fails for equal
    # strings that are not the same (interned) object.
    if space == 'SE3_GROUP':
        rot_vec = point[:, :dim_rotations]
        translation = point[:, dim_rotations:]
    elif space == 'SO3_GROUP':
        rot_vec = point
        # Pure rotations: every trihedron is based at the origin.
        translation = gs.zeros((n_points, 3))
    else:
        raise NotImplementedError(
            'Trihedrons are only implemented for SO(3) and SE(3).')

    rot_mat = SO3_GROUP.matrix_from_rotation_vector(rot_vec)
    rot_mat = SO3_GROUP.projection(rot_mat)
    canonical_basis = (gs.array([1, 0, 0]),
                       gs.array([0, 1, 0]),
                       gs.array([0, 0, 1]))

    trihedrons = []
    for base_point, rotation in zip(translation, rot_mat):
        vec_1, vec_2, vec_3 = [gs.dot(rotation, basis_vec)
                               for basis_vec in canonical_basis]
        trihedrons.append(Trihedron(base_point, vec_1, vec_2, vec_3))
    return trihedrons