Ejemplo n.º 1
0
def plot_sfm_data_3d(sfm_data: GtsfmData,
                     ax: Axes,
                     max_plot_radius: float = 50) -> None:
    """Plot the camera poses and landmarks in 3D matplotlib plot.

    Camera poses are drawn with ``plot_poses_3d``. Landmarks are drawn as small
    green dots, restricted to those kept by
    ``comp_utils.get_points_within_radius_of_cameras``.

    Args:
        sfm_data: SfmData object with camera and tracks.
        ax: axis to plot on.
        max_plot_radius: maximum distance threshold away from any camera for which a point
            will be plotted
    """
    camera_poses = [
        sfm_data.get_camera(i).pose()
        for i in sfm_data.get_valid_camera_indices()
    ]
    plot_poses_3d(camera_poses, ax)

    num_tracks = sfm_data.number_tracks()
    if num_tracks == 0:
        # Nothing to plot; also avoids passing a (0,)-shaped array to the filter.
        return

    # Restrict 3d points to some radius of camera poses
    points_3d = np.array(
        [list(sfm_data.get_track(j).point3()) for j in range(num_tracks)])
    nearby_points_3d = comp_utils.get_points_within_radius_of_cameras(
        camera_poses, points_3d, max_plot_radius)

    # NOTE(review): the helper's contract is not visible here; guard against it
    # returning None as well as an empty selection.
    if nearby_points_3d is None or len(nearby_points_3d) == 0:
        return

    # Plot all points with a single call instead of one artist per landmark.
    nearby_points_3d = np.asarray(nearby_points_3d).reshape(-1, 3)
    ax.plot(nearby_points_3d[:, 0],
            nearby_points_3d[:, 1],
            nearby_points_3d[:, 2],
            "g.",
            markersize=1)
Ejemplo n.º 2
0
def set_axes_equal(ax: Axes):
    """
    Rescale a 3D plot so that one data unit has the same length on x, y and z.

    Matplotlib's ``ax.set_aspect('equal')`` and ``ax.axis('equal')`` do not work
    for 3D axes. Instead, every axis gets a range as wide as the widest of the
    three current ranges, centred on that axis' current midpoint.

    Ref: https://github.com/borglab/gtsam/blob/develop/python/gtsam/utils/plot.py#L13

    Args:
        ax: axis for the plot.
    """
    getters = (ax.get_xlim3d, ax.get_ylim3d, ax.get_zlim3d)
    setters = (ax.set_xlim3d, ax.set_ylim3d, ax.set_zlim3d)

    # 3x2 matrix: one (low, high) row per axis -> bounding cuboid of the data
    bounds = np.array([get_limits() for get_limits in getters])
    centers = bounds.mean(axis=1)

    # half of the longest cuboid edge becomes the half-width of every axis
    half_width = 0.5 * np.abs(bounds[:, 1] - bounds[:, 0]).max()

    for set_limits, center in zip(setters, centers):
        set_limits([center - half_width, center + half_width])
Ejemplo n.º 3
0
def axes_equal(axes: Axes):
    """Give a 3d plot the same scale on all three axes.

    Each axis keeps its current midpoint but is widened to the largest of the
    three current ranges.

    Source code taken from the stackoverflow answer of 'karlo' in the
    following question:
    https://stackoverflow.com/questions/13685386/matplotlib-equal-unit
    -length-with-equal-aspect-ratio-z-axis-is-not-equal-to

    Parameters
    ----------
    axes :
        Matplotlib axes object (output from plt.gca())

    """
    limits = (axes.get_xlim3d(), axes.get_ylim3d(), axes.get_zlim3d())
    middles = [np.mean(lim) for lim in limits]

    # The bounding box is a sphere in the infinity norm; its radius is half the
    # widest range.
    plot_radius = 0.5 * max(abs(lim[1] - lim[0]) for lim in limits)

    setters = (axes.set_xlim3d, axes.set_ylim3d, axes.set_zlim3d)
    for set_limits, middle in zip(setters, middles):
        set_limits([middle - plot_radius, middle + plot_radius])
Ejemplo n.º 4
0
def plot_coordinate_systems(
    cs_data: Tuple[str, Dict],
    axes: Axes = None,
    title: str = None,
    limits: types_limits = None,
    time_index: int = None,
    legend_pos: str = "lower left",
) -> Axes:
    """Plot multiple coordinate systems.

    Parameters
    ----------
    cs_data :
        An iterable of ``(coordinate_system, kwargs)`` tuples. Each coordinate
        system's ``plot`` method is called with the target axes and ``kwargs``.
        The passed dictionaries are not modified.
    axes :
        The target axes object that should be drawn to. If `None` is provided, a new
        one will be created.
    title :
        The title of the plot
    limits :
        Each tuple marks lower and upper boundary of the x, y and z axis. If only a
        single tuple is passed, the boundaries are used for all axis. If `None`
        is provided, the limits are left unchanged.
    time_index :
        Index of a specific time step that should be plotted if the corresponding
        coordinate system is time dependent. A ``"time_index"`` entry in a
        coordinate system's kwargs takes precedence over this value.
    legend_pos :
        A string that specifies the position of the legend. See the matplotlib
        documentation for further details

    Returns
    -------
    matplotlib.axes.Axes :
        The axes object that was used as canvas for the plot

    """
    if axes is None:
        _, axes = new_3d_figure_and_axes()

    for lcs, kwargs in cs_data:
        # Build a per-call copy instead of writing into the caller's dict; an
        # explicit "time_index" in kwargs overrides the function argument.
        plot_kwargs = {"time_index": time_index, **kwargs}
        lcs.plot(axes, **plot_kwargs)

    _set_limits_matplotlib(axes, limits)

    if title is not None:
        axes.set_title(title)
    axes.legend(loc=legend_pos)

    return axes
Ejemplo n.º 5
0
def _fetch_objectives_from_figure(figure: Axes) -> List[float]:
    """Read back the objective value of every line in a parallel-coordinate plot.

    ``figure`` must contain exactly one ``LineCollection``. For each of its
    segments, the y value of the first vertex (``segment[0, 1]``) is returned,
    in segment order.
    """
    collections = figure.findobj(LineCollection)
    assert len(collections) == 1

    segments = collections[0].get_segments()
    # The objective value is taken from the first vertex's y coordinate.
    return [segment[0, 1] for segment in segments]
Ejemplo n.º 6
0
def _add_descriptor_icons(
    descriptor_arr: npt.ArrayLike,
    icon_method: str,
    n_cond: int,
    ax: Axes = None,
    num_pattern_groups: int = None,
    icon_spacing: float = 1.0,
    linewidth: float = 0.5,
) -> list:
    """_add_descriptor_icons. Used internally by _add_descriptor_labels to add
    Icon-based labels to the X or Y axis.

    Args:
        descriptor_arr (npt.ArrayLike): np.Array-like version of the labels.
        icon_method (str): method to access on Icon instances (typically y_tick_label or
            x_tick_label).
        n_cond (int): Number of conditions in the RDM (usually from RDMs.n_cond).
        ax (matplotlib.axes._axes.Axes): Matplotlib axis handle.
        num_pattern_groups (int): Number of rows/columns for any image labels.
        icon_spacing (float): control spacing of image labels - 1. means no gap (the
            default), 1.1 means pad 10%, .9 means overlap 10% etc.
        linewidth (float): Width of connecting lines from icon labels (if used) to axis
            margin.  The default is 0.5 - set to 0. to disable the lines.

    Returns:
        list: Tick label handles.
    """
    # annotated labels with Icon
    n_to_fit = np.ceil(n_cond / num_pattern_groups)
    # work out sizing of icons
    im_max_pix = 20.
    if descriptor_arr[0].final_image:
        # size by image
        im_width_pix = max(this_desc.final_image.width
                           for this_desc in descriptor_arr)
        im_height_pix = max(this_desc.final_image.height
                            for this_desc in descriptor_arr)
        im_max_pix = max(im_width_pix, im_height_pix) * icon_spacing
    ax.figure.canvas.draw()
    extent = ax.get_window_extent(ax.figure.canvas.get_renderer())
    ax_size_pix = max((extent.width, extent.height))
    size = (ax_size_pix / n_to_fit) / im_max_pix
    # from proportion of original size to figure pixels
    offset = im_max_pix * size
    label_handles = []
    for group_ind in range(num_pattern_groups - 1, -1, -1):
        position = offset * 0.2 + offset * group_ind
        ticks = np.arange(group_ind, n_cond, num_pattern_groups)
        label_handles.append([
            getattr(this_desc, icon_method)(
                this_x,
                size,
                offset=position,
                linewidth=linewidth,
                ax=ax,
            ) for (this_x, this_desc) in zip(ticks, descriptor_arr[ticks])
        ])
    return label_handles
Ejemplo n.º 7
0
def _set_limits_matplotlib(
    axes: Axes,
    limits: types_limits,
    set_axes_equal: bool = False,
):
    """Set the limits of an axes object.

    Parameters
    ----------
    axes :
        The axes object
    limits :
        Each tuple marks lower and upper boundary of the x, y and z axis. If only a
        single tuple is passed, the boundaries are used for all axis. If `None`
        is provided, the axis are adjusted to be of equal length.
    set_axes_equal :
        (matplotlib only) If `True`, all axes are adjusted to cover an equally large
         range of value. That doesn't mean, that the limits are identical

    """
    if limits is not None:
        if isinstance(limits, Tuple):
            limits = [limits]
        if len(limits) == 1:
            limits = [limits[0] for _ in range(3)]
        axes.set_xlim(limits[0])
        axes.set_ylim(limits[1])
        axes.set_zlim(limits[2])
    elif set_axes_equal:
        axes_equal(axes)
Ejemplo n.º 8
0
def IFT_LoopMaskArray(aperture, loop_number: np.array, loop_radius: np.array, ax: axes.Axes, centred=False):
    """
    Draw a concentric-loop element array on a circular aperture face.

    Element ``k`` of loop ``j`` is placed at angle ``k * 2*pi / loop_number[j]``
    on a circle of radius ``loop_radius[j]``.

    :param aperture: radius of the circular aperture, drawn as a filled light-blue disc
    :param loop_number: number of elements on each loop (array-like of integers;
        plain lists are accepted)
    :param loop_radius: radius of each loop, in the same order as ``loop_number``
    :param ax: handle of subplot
    :param centred: if True, one additional element is placed at the centre (0, 0)
    :return: the same ``ax`` that was drawn on
    """
    # Accept plain Python sequences as well as numpy arrays.
    loop_number = np.asarray(loop_number)
    delta_phi = 2 * np.pi / loop_number
    total_number = int(loop_number.sum())

    # With a centre element, column 0 stays at (0, 0) and the loops start at column 1.
    loop_index = 1 if centred else 0
    loc = np.zeros([2, total_number + loop_index], dtype=np.float64)

    for zmc, radius in enumerate(loop_radius):
        current_number = loop_number[zmc]
        phi = np.arange(0, current_number) * delta_phi[zmc]
        loc[0, loop_index:loop_index + current_number] = radius * np.cos(phi)
        loc[1, loop_index:loop_index + current_number] = radius * np.sin(phi)
        loop_index += current_number

    phi = np.linspace(0, 2 * np.pi, 360 * 4)

    # draw: aperture disc underneath, element markers on top
    ax.fill(aperture * np.cos(phi), aperture * np.sin(phi), color='lightblue', zorder=0)
    ax.scatter(x=loc[0], y=loc[1], s=10, c='gray', zorder=1)

    ax.set_aspect('equal', 'box')
    ax.set_axis_off()
    ax.grid(True)

    return ax
Ejemplo n.º 9
0
def plot_poses_3d(wTi_list: List[Optional[Pose3]],
                  ax: Axes,
                  center_marker_color: str = "k",
                  label_name: Optional[str] = None) -> None:
    """Draw each pose as a center marker plus a short axis triad.

    The triad shows the orientation's three columns as segments of length 0.5.
    Color convention: R -> x axis, G -> y axis, B -> z axis.

    Args:
        wTi_list: list of poses to plot; `None` entries are skipped.
        ax: axis to plot on.
        center_marker_color (optional): color for camera center marker. Defaults to "k".
        label_name (optional): legend label attached only to the first plotted
            camera center, so the legend gets a single entry. Defaults to None.
    """
    marker_spec = f"{center_marker_color}."
    axis_length = 0.5
    pending_label = label_name

    for wTi in wTi_list:
        if wTi is None:
            continue

        x, y, z = wTi.translation().squeeze()
        ax.plot(x, y, z, marker_spec, markersize=10, label=pending_label)
        # Only the first center carries the label; repeating it would duplicate legend entries.
        pending_label = None

        scaled_axes = wTi.rotation().matrix() * axis_length
        for column, axis_color in enumerate("rgb"):
            direction = scaled_axes[:, column]
            ax.plot3D([x, x + direction[0]],
                      [y, y + direction[1]],
                      [z, z + direction[2]],
                      c=axis_color)
Ejemplo n.º 10
0
    def show_mask(self, ax: axes.Axes, title="", fontfamily="times new roman", fontsize=15, scatter=20):
        """Draw the aperture, the element locations and the loop circles on ``ax``.

        :param ax: subplot to draw on
        :param title: plot title; an empty string is replaced by
            "Total number: <n>" with ``n = self.get_total_number()``
        :param fontfamily: font family of the title
        :param fontsize: font size of the title
        :param scatter: marker size of the element dots
        :return: the same ``ax``
        """
        theta = np.linspace(0, 2 * np.pi, 360 * 4)
        cos_t = np.cos(theta)
        sin_t = np.sin(theta)

        # filled aperture disc (bottom layer) and element positions (top layer)
        ax.fill(self._aperture * cos_t, self._aperture * sin_t, color="lightblue", zorder=0)
        ax.scatter(x=self._locations[0], y=self._locations[1], s=scatter, c="gray", zorder=2)

        heading = "Total number: " + str(self.get_total_number()) if title == "" else title

        # one orange circle per loop radius, between the disc and the dots
        for loop_r in self._radius:
            ax.plot(loop_r * cos_t, loop_r * sin_t, color="orange", zorder=1)

        ax.set_aspect("equal", 'box')
        ax.set_axis_off()
        ax.grid(True)
        ax.set_title(heading, fontsize=fontsize, fontfamily=fontfamily)
        return ax
Ejemplo n.º 11
0
def show_rdm_panel(
    rdm: rsatoolbox.rdm.RDMs,
    ax: Axes = None,
    cmap: Union[str, matplotlib.colors.Colormap] = None,
    nanmask: npt.ArrayLike = None,
    rdm_descriptor: str = None,
    gridlines: npt.ArrayLike = None,
    vmin: float = None,
    vmax: float = None,
) -> matplotlib.image.AxesImage:
    """show_rdm_panel. Add RDM heatmap to the axis ax.

    Args:
        rdm (rsatoolbox.rdm.RDMs): RDMs object to be plotted (n_rdm must be 1).
        ax (matplotlib.axes._axes.Axes): Matplotlib axis handle. plt.gca() by default.
        cmap (Union[str, matplotlib.colors.Colormap]): colormap to be used (by
            plt.imshow internally). By default we use rdm_colormap.
        nanmask (npt.ArrayLike): boolean mask defining RDM elements to suppress
            (by default, the diagonals).
        rdm_descriptor (str): Key for rdm_descriptor to use as panel title, or
            str for direct labeling.
        gridlines (npt.ArrayLike): Set to add gridlines at these positions
            (major ticks). `None` (default) means no gridlines.
        vmin (float): Minimum intensity for colorbar mapping. matplotlib imshow
            argument.
        vmax (float): Maximum intensity for colorbar mapping. matplotlib imshow
            argument.

    Returns:
        matplotlib.image.AxesImage: Matplotlib handle.

    Raises:
        ValueError: if ``rdm`` holds more than one RDM.
    """
    if rdm.n_rdm > 1:
        raise ValueError(
            "expected single rdm - use show_rdm for multi-panel figures")
    if ax is None:
        ax = plt.gca()
    if cmap is None:
        cmap = rdm_colormap()
    if nanmask is None:
        nanmask = np.eye(rdm.n_cond, dtype=bool)
    # Only a missing argument means "no gridlines". Checking np.any(gridlines)
    # would also discard valid positions such as [0].
    if gridlines is None:
        gridlines = []
    rdmat = rdm.get_matrices()[0, :, :]
    if np.any(nanmask):
        rdmat[nanmask] = np.nan
    image = ax.imshow(rdmat, cmap=cmap, vmin=vmin, vmax=vmax)
    ax.set_xlim(-0.5, rdm.n_cond - 0.5)
    ax.set_ylim(rdm.n_cond - 0.5, -0.5)
    # major ticks carry the gridlines; minor ticks mark every condition
    ax.xaxis.set_ticks(gridlines)
    ax.yaxis.set_ticks(gridlines)
    ax.xaxis.set_ticklabels([])
    ax.yaxis.set_ticklabels([])
    ax.xaxis.set_ticks(np.arange(rdm.n_cond), minor=True)
    ax.yaxis.set_ticks(np.arange(rdm.n_cond), minor=True)
    # hide minor ticks by default
    ax.xaxis.set_tick_params(length=0, which="minor")
    ax.yaxis.set_tick_params(length=0, which="minor")
    if rdm_descriptor in rdm.rdm_descriptors:
        ax.set_title(rdm.rdm_descriptors[rdm_descriptor][0])
    else:
        ax.set_title(rdm_descriptor)
    return image
Ejemplo n.º 12
0
def draw_coordinate_system_matplotlib(
    coordinate_system: LocalCoordinateSystem,
    axes: Axes,
    color: Any = None,
    label: str = None,
    time_idx: int = None,
    scale_vectors: Union[float, List, np.ndarray] = None,
    show_origin: bool = True,
    show_vectors: bool = True,
):
    """Draw a coordinate system in a matplotlib 3d plot.

    Parameters
    ----------
    coordinate_system :
        Coordinate system
    axes :
        Target matplotlib axes object
    color :
        Valid matplotlib color selection. The origin of the coordinate system
        will be marked with this color.
    label :
        Name that appears in the legend. Only viable if a color
        was specified.
    time_idx :
        Selects time dependent data by index if the coordinate system has
        a time dependency. Integers (including numpy integers) are positional
        indices; any other value is used as a time label. Defaults to 0.
    scale_vectors :
        A scaling factor or a sequence of three factors. A single factor is used
        for all three. Factor ``i`` multiplies row ``i`` of the orientation matrix
        (the ``i``-th component of every axis vector) before drawing.
    show_origin :
        If `True`, the origin of the coordinate system will be highlighted in the
        color passed as another parameter
    show_vectors :
        If `True`, the the coordinate axes of the coordinate system are visualized

    Raises
    ------
    ValueError
        If a ``label`` is given without a ``color`` (checked before drawing).

    """
    if not (show_vectors or show_origin):
        return
    # Validate before drawing anything so a failed call leaves the plot untouched.
    if color is None and label is not None:
        raise ValueError("Labels can only be assigned if a color was specified")

    if "time" in coordinate_system.dataset.coords:
        if time_idx is None:
            time_idx = 0
        if isinstance(time_idx, (int, np.integer)):
            dsx = coordinate_system.dataset.isel(time=time_idx)
        else:
            dsx = coordinate_system.dataset.sel(time=time_idx).isel(time=0)
    else:
        dsx = coordinate_system.dataset

    p_0 = dsx.coordinates

    if show_vectors:
        if scale_vectors is None:
            tips = dsx.orientation
        else:
            # Accept a scalar, list, tuple or array; a single value applies to all rows.
            scale = np.asarray(scale_vectors, dtype=float).ravel()
            if scale.size == 1:
                scale = np.repeat(scale, 3)
            tips = np.matmul(np.diag(scale[:3]), dsx.orientation.data)

        # columns of the orientation are the x, y and z axis vectors
        p_x = p_0 + tips[:, 0]
        p_y = p_0 + tips[:, 1]
        p_z = p_0 + tips[:, 2]

        axes.plot([p_0[0], p_x[0]], [p_0[1], p_x[1]], [p_0[2], p_x[2]], "r")
        axes.plot([p_0[0], p_y[0]], [p_0[1], p_y[1]], [p_0[2], p_y[2]], "g")
        axes.plot([p_0[0], p_z[0]], [p_0[1], p_z[1]], [p_0[2], p_z[2]], "b")
    if color is not None and show_origin:
        axes.plot([p_0[0]], [p_0[1]], [p_0[2]],
                  "o",
                  color=color,
                  label=label)
Ejemplo n.º 13
0
def plot_spatial_data_matplotlib(
    data: geo.SpatialData,
    axes: Axes = None,
    color: Union[int, Tuple[int, int, int], Tuple[float, float, float]] = None,
    label: str = None,
    show_wireframe: bool = True,
) -> Axes:
    """Visualize a `weldx.geometry.SpatialData` instance.

    The points are drawn as a scatter plot. If triangles are available and
    ``show_wireframe`` is set, each triangle is also drawn as a closed outline.

    Parameters
    ----------
    data :
        The data that should be visualized. Anything that is not already a
        `weldx.geometry.SpatialData` is passed to its constructor.
    axes :
        The target `matplotlib.axes.Axes` object of the plot. If 'None' is passed, a
        new figure will be created
    color :
        A 24 bit integer, a triplet of integers with a value range of 0-255
        or a triplet of floats with a value range of 0.0-1.0 that represent an RGB
        color. Defaults to black.
    label :
        Label of the plotted geometry
    show_wireframe :
        If `True`, the mesh is plotted as wireframe. Otherwise only the raster
        points are visualized. Currently, the wireframe can't be visualized if a
        `weldx.geometry.VariableProfile` is used.

    Returns
    -------
    matplotlib.axes.Axes :
        The `matplotlib.axes.Axes` instance that was used for the plot.

    """
    if axes is None:
        _, axes = new_3d_figure_and_axes()

    spatial = data if isinstance(data, geo.SpatialData) else geo.SpatialData(data)
    rgb = (0.0, 0.0, 0.0) if color is None else color_to_rgb_normalized(color)

    points = spatial.coordinates.data
    triangles = spatial.triangles

    # Strip leading extra dimensions (e.g. time) by keeping their first entry.
    while points.ndim > 2:
        points = points[0]

    axes.scatter(
        points[:, 0],
        points[:, 1],
        points[:, 2],
        marker=".",
        color=rgb,
        label=label,
        zorder=2,
    )

    if triangles is not None and show_wireframe:
        for corner_ids in triangles:
            # repeat the first corner so the outline is closed
            outline = points[[*corner_ids, corner_ids[0]], :]
            axes.plot(outline[:, 0], outline[:, 1], outline[:, 2], color=rgb, zorder=1)

    return axes
Ejemplo n.º 14
0
def plot_coordinate_system_manager_matplotlib(
    csm: CoordinateSystemManager,
    axes: Axes = None,
    reference_system: str = None,
    coordinate_systems: List[str] = None,
    data_sets: List[str] = None,
    colors: Dict[str, int] = None,
    time: types_timeindex = None,
    time_ref: pd.Timestamp = None,
    title: str = None,
    limits: types_limits = None,
    scale_vectors: Union[float, List, np.ndarray] = None,
    set_axes_equal: bool = False,
    show_origins: bool = True,
    show_trace: bool = True,
    show_vectors: bool = True,
    show_wireframe: bool = True,
) -> Axes:
    """Plot the coordinate systems of a `weldx.transformations.CoordinateSystemManager`.

    Parameters
    ----------
    csm :
        The coordinate system manager instance that should be plotted.
    axes :
        The target axes object that should be drawn to. If `None` is provided, a new
        one will be created.
    reference_system :
        The name of the reference system for the plotted coordinate systems
    coordinate_systems :
        Names of the coordinate systems that should be drawn. If `None` is provided,
        all systems are plotted.
    data_sets :
        Names of the data sets that should be drawn. If `None` is provided, all data
        is plotted.
    colors :
        A mapping between a coordinate system name or a data set name and a color.
        The colors must be provided as 24 bit integer values that are divided into
        three 8 bit sections for the rgb values. For example `0xFF0000` for pure
        red.
        Each coordinate system or data set that does not have a mapping in this
        dictionary will get a default color assigned to it.
    time :
        The time steps that should be plotted. If given, the manager is
        interpolated to these times and all other options are applied to the
        interpolated manager.
    time_ref :
        A reference timestamp that can be provided if the ``time`` parameter is a
        `pandas.TimedeltaIndex`
    title :
        The title of the plot
    limits :
        Each tuple marks lower and upper boundary of the x, y and z axis. If only a
        single tuple is passed, the boundaries are used for all axis. If `None`
        is provided, the limits are only changed if ``set_axes_equal`` is `True`.
    scale_vectors :
        A scaling factor or array to adjust the length of the coordinate system vectors
    set_axes_equal :
        (matplotlib only) If `True`, all axes are adjusted to cover an equally large
         range of value. That doesn't mean, that the limits are identical
    show_origins :
        If `True`, the origins of the coordinate system are visualized in the color
        assigned to the coordinate system.
    show_trace :
        If `True`, the trace of time dependent coordinate systems is plotted.
    show_vectors :
        If `True`, the coordinate cross of time dependent coordinate systems is plotted.
    show_wireframe :
        If `True`, the mesh is visualized as wireframe. Otherwise, it is not shown.

    Returns
    -------
    matplotlib.axes.Axes :
        The axes object that was used as canvas for the plot.

    """
    if time is not None:
        # Interpolate once and re-enter with time=None. Every other option must be
        # forwarded, otherwise it would silently fall back to its default.
        return plot_coordinate_system_manager_matplotlib(
            csm.interp_time(time=time, time_ref=time_ref),
            axes=axes,
            reference_system=reference_system,
            coordinate_systems=coordinate_systems,
            data_sets=data_sets,
            colors=colors,
            title=title,
            limits=limits,
            scale_vectors=scale_vectors,
            set_axes_equal=set_axes_equal,
            show_origins=show_origins,
            show_trace=show_trace,
            show_vectors=show_vectors,
            show_wireframe=show_wireframe,
        )
    if axes is None:
        _, axes = new_3d_figure_and_axes()
        axes.set_xlabel("x")
        axes.set_ylabel("y")
        axes.set_zlabel("z")

    if reference_system is None:
        reference_system = csm.root_system_name
    if coordinate_systems is None:
        coordinate_systems = csm.coordinate_system_names
    if data_sets is None:
        data_sets = csm.data_names
    if title is not None:
        axes.set_title(title)

    # One generator for both loops so systems and data sets get distinct defaults.
    color_gen = color_generator_function()

    # plot coordinate systems
    for lcs_name in coordinate_systems:
        color = color_int_to_rgb_normalized(
            get_color(lcs_name, colors, color_gen))
        lcs = csm.get_cs(lcs_name, reference_system)
        lcs.plot(
            axes=axes,
            color=color,
            label=lcs_name,
            scale_vectors=scale_vectors,
            show_origin=show_origins,
            show_trace=show_trace,
            show_vectors=show_vectors,
        )
    # plot data
    for data_name in data_sets:
        color = color_int_to_rgb_normalized(
            get_color(data_name, colors, color_gen))
        data = csm.get_data(data_name, reference_system)
        plot_spatial_data_matplotlib(
            data=data,
            axes=axes,
            color=color,
            label=data_name,
            show_wireframe=show_wireframe,
        )

    _set_limits_matplotlib(axes, limits, set_axes_equal)
    axes.legend()

    return axes
Ejemplo n.º 15
0
def plot_local_coordinate_system_matplotlib(
    lcs: LocalCoordinateSystem,
    axes: Axes = None,
    color: Any = None,
    label: str = None,
    time: types_timeindex = None,
    time_ref: pd.Timestamp = None,
    time_index: int = None,
    scale_vectors: Union[float, List, np.ndarray] = None,
    show_origin: bool = True,
    show_trace: bool = True,
    show_vectors: bool = True,
) -> Axes:
    """Visualize a `weldx.transformations.LocalCoordinateSystem` using matplotlib.

    For a time-dependent system without ``time_index``, every time step is drawn,
    and only the first drawing carries the legend label.

    Parameters
    ----------
    lcs :
        The coordinate system that should be visualized
    axes :
        The target matplotlib axes. If `None` is provided, a new 3d axes with
        orthographic projection will be created
    color :
        An arbitrary color. The data type must be compatible with matplotlib.
    label :
        Name of the coordinate system
    time :
        The time steps that should be plotted
    time_ref :
        A reference timestamp that can be provided if the ``time`` parameter is a
        `pandas.TimedeltaIndex`
    time_index :
        Index of a specific time step that should be plotted
    scale_vectors :
        A scaling factor or array to adjust the vector length
    show_origin :
        If `True`, the origin of the coordinate system will be highlighted in the
        color passed as another parameter
    show_trace :
        If `True`, the trace of a time dependent coordinate system will be visualized in
        the color passed as another parameter (black if no color is given)
    show_vectors :
        If `True`, the the coordinate axes of the coordinate system are visualized

    Returns
    -------
    matplotlib.axes.Axes :
        The axes object that was used as canvas for the plot.

    """
    if axes is None:
        _, axes = plt.subplots(
            subplot_kw={"projection": "3d", "proj_type": "ortho"})

    if lcs.is_time_dependent and time is not None:
        lcs = lcs.interp_time(time, time_ref)

    common = dict(
        color=color,
        scale_vectors=scale_vectors,
        show_origin=show_origin,
        show_vectors=show_vectors,
    )
    if lcs.is_time_dependent and time_index is None:
        for step, _ in enumerate(lcs.time):
            draw_coordinate_system_matplotlib(
                lcs,
                axes,
                label=label if step == 0 else None,
                time_idx=step,
                **common,
            )
    else:
        draw_coordinate_system_matplotlib(
            lcs, axes, label=label, time_idx=time_index, **common)

    if show_trace and lcs.coordinates.values.ndim > 1:
        trace = lcs.coordinates.values
        trace_color = "k" if color is None else color
        axes.plot(trace[:, 0], trace[:, 1], trace[:, 2], ":", color=trace_color)

    return axes
Ejemplo n.º 16
0
def plot_spatial_data_matplotlib(
    data: Union[np.ndarray, geo.SpatialData],
    axes: Axes = None,
    color: Union[int, tuple[int, int, int], tuple[float, float, float]] = None,
    label: str = None,
    limits: types_limits = None,
    show_wireframe: bool = True,
) -> Axes:
    """Visualize a `weldx.geometry.SpatialData` instance.

    Parameters
    ----------
    data :
        The data that should be visualized. Anything that is not already a
        `weldx.geometry.SpatialData` is passed to its constructor.
    axes :
        The target `matplotlib.axes.Axes` object of the plot. If 'None' is passed, a
        new figure will be created
    color :
        A 24 bit integer, a triplet of integers with a value range of 0-255
        or a triplet of floats with a value range of 0.0-1.0 that represent an RGB
        color. Defaults to black.
    label :
        Label of the plotted geometry
    limits :
        Each tuple marks lower and upper boundary of the x, y and z axis. If only a
        single tuple is passed, the boundaries are used for all axis. If `None`
        is provided, the limits are left unchanged.
    show_wireframe :
        If `True`, the mesh is plotted as wireframe. Otherwise only the raster
        points are visualized. Currently, the wireframe can't be visualized if a
        `weldx.geometry.VariableProfile` is used.

    Returns
    -------
    matplotlib.axes.Axes :
        The `matplotlib.axes.Axes` instance that was used for the plot.

    """
    if axes is None:
        _, axes = new_3d_figure_and_axes()

    if not isinstance(data, geo.SpatialData):
        data = geo.SpatialData(data)

    if color is None:
        color = (0.0, 0.0, 0.0)
    else:
        color = color_to_rgb_normalized(color)

    coordinates = data.coordinates.data
    if isinstance(coordinates, Q_):
        # strip units after converting to the default length unit
        coordinates = coordinates.to(_DEFAULT_LEN_UNIT).m
    # Flatten every leading dimension into a single (N, 3) point list; triangle
    # indices below address rows of this flattened array.
    # NOTE(review): any time dimension is flattened too, so all time steps are
    # scattered. An earlier "take the first value" loop placed after this reshape
    # could never run and has been removed.
    coordinates = coordinates.reshape(-1, 3)
    triangles = data.triangles

    axes.scatter(
        coordinates[:, 0],
        coordinates[:, 1],
        coordinates[:, 2],
        marker=".",
        color=color,
        label=label,
        zorder=2,
    )
    if triangles is not None and show_wireframe:
        for triangle in triangles:
            # repeat the first corner so each outline is closed
            triangle_data = coordinates[[*triangle, triangle[0]], :]
            axes.plot(
                triangle_data[:, 0],
                triangle_data[:, 1],
                triangle_data[:, 2],
                color=color,
                zorder=1,
            )

    _set_limits_matplotlib(axes, limits)
    return axes