def plot_meshes(meshes: Meshes, *, in_subplots: bool = False, ncols: int = 1, **kwargs):
    """
    Takes a Meshes object and generates a plotly figure. If there is more than
    one mesh in the batch and in_subplots=True, each mesh will be visualized
    in an individual subplot with ncols number of subplots in the same row.
    Otherwise, each mesh in the batch will be visualized as an individual trace
    in the same plot. If the Meshes object has vertex colors defined as its
    texture, the vertex colors will be used for generating the plotly figure.
    Otherwise plotly's default colors will be used.

    Args:
        meshes: Meshes object to be visualized in a plotly figure.
        in_subplots: if each mesh in the batch should be visualized in an
            individual subplot.
        ncols: number of subplots in the same row if in_subplots is set to be
            True. Otherwise ncols will be ignored.
        **kwargs: Accepts lighting (a Lighting object) and lightposition for
            Mesh3D and any of the args xaxis, yaxis and zaxis accept for scene.
            Accepts axis_args which is an AxisArgs object that is applied to
            the 3 axes. Also accepts subplot_titles, which should be a list of
            string titles matching the number of subplots. Example settings
            for axis_args and lighting are given above.

    Returns:
        Plotly figure of the mesh. When in_subplots=True, each mesh gets its
        own subplot (scene), laid out left-to-right in rows of ncols columns.
        Otherwise every mesh is added as a separate trace of a single scene.
    """
    meshes = meshes.detach().cpu()
    subplot_titles = kwargs.get("subplot_titles", None)
    fig = _gen_fig_with_subplots(len(meshes), in_subplots, ncols, subplot_titles)

    # Settings that are identical for every mesh: resolve them once instead of
    # re-reading kwargs and rebuilding the defaults on every loop iteration.
    lighting = kwargs.get("lighting", Lighting())
    lightposition = kwargs.get("lightposition", {})
    axis_args = kwargs.get("axis_args", AxisArgs())
    xaxis_settings = kwargs.get("xaxis", {})
    yaxis_settings = kwargs.get("yaxis", {})
    zaxis_settings = kwargs.get("zaxis", {})

    for i in range(len(meshes)):
        # Index once: every `meshes[i]` builds a new single-mesh Meshes object.
        mesh = meshes[i]
        verts = mesh.verts_packed()
        faces = mesh.faces_packed()

        # If mesh has vertex colors defined as texture, use vertex colors
        # (clamped to [0, 1], then scaled to [0, 255]) for the figure,
        # otherwise use plotly's default colors. The clamp is out-of-place so
        # the texture's own tensor is never modified.
        verts_rgb = None
        if isinstance(mesh.textures, TexturesVertex):
            verts_rgb = mesh.textures.verts_features_packed().clamp(min=0.0, max=1.0)
            verts_rgb = verts_rgb * 255.0

        # Reposition the unused vertices to be "inside" the object
        # (i.e. they won't be visible in the plot).
        verts_used = torch.zeros((verts.shape[0],), dtype=torch.bool, device=verts.device)
        verts_used[torch.unique(faces)] = True
        if verts_used.any():
            verts_center = verts[verts_used].mean(0)
            # Out-of-place replacement: `verts` may share storage with the
            # caller's tensors (detach() does not copy).
            verts = torch.where(verts_used[:, None], verts, verts_center)
        else:
            # No face references any vertex (e.g. a mesh without faces).
            # Averaging the empty selection would yield NaN and poison both the
            # vertex positions and the axis bounds, so keep the vertices as-is.
            verts_center = verts.mean(0)

        if in_subplots:
            row_idx, col_idx = divmod(i, ncols)
            trace_row, trace_col = row_idx + 1, col_idx + 1
        else:
            trace_row = trace_col = 1

        fig.add_trace(
            go.Mesh3d(  # pyre-ignore[16]
                x=verts[:, 0],
                y=verts[:, 1],
                z=verts[:, 2],
                vertexcolor=verts_rgb,
                i=faces[:, 0],
                j=faces[:, 1],
                k=faces[:, 2],
                lighting=lighting._asdict(),
                lightposition=lightposition,
            ),
            row=trace_row,
            col=trace_col,
        )

        # Ensure update for every subplot.
        plot_scene = "scene" + str(i + 1) if in_subplots else "scene"
        current_layout = fig["layout"][plot_scene]

        xaxis, yaxis, zaxis = _gen_updated_axis_bounds(
            verts, verts_center, current_layout, axis_args
        )
        # Update the axis bounds with the axis settings passed in as kwargs.
        xaxis.update(**xaxis_settings)
        yaxis.update(**yaxis_settings)
        zaxis.update(**zaxis_settings)

        current_layout.update(
            {"xaxis": xaxis, "yaxis": yaxis, "zaxis": zaxis, "aspectmode": "cube"}
        )
    return fig