def _update_axes_bounds( verts_center: np.array, max_expand: float, current_layout: go.Scene, # pyre-ignore[11] ): """ Takes in the vertices' center point and max spread, and the current plotly figure layout and updates the layout to have bounds that include all traces for that subplot. Args: verts_center: tensor of size (3) corresponding to a trace's vertices' center point. max_expand: the maximum spread in any dimension of the trace's vertices. current_layout: the plotly figure layout scene corresponding to the referenced trace. """ verts_min = verts_center - max_expand verts_max = verts_center + max_expand bounds = np.stack([verts_min, verts_max], axis=-1) # Ensure that within a subplot, the bounds capture all traces old_xrange, old_yrange, old_zrange = ( current_layout['xaxis']['range'], current_layout['yaxis']['range'], current_layout['zaxis']['range'], ) x_range, y_range, z_range = bounds if old_xrange is not None: x_range[0] = min(x_range[0], old_xrange[0]) x_range[1] = max(x_range[1], old_xrange[1]) if old_yrange is not None: y_range[0] = min(y_range[0], old_yrange[0]) y_range[1] = max(y_range[1], old_yrange[1]) if old_zrange is not None: z_range[0] = min(z_range[0], old_zrange[0]) z_range[1] = max(z_range[1], old_zrange[1]) xaxis = {'range': x_range} yaxis = {'range': y_range} zaxis = {'range': z_range} current_layout.update({'xaxis': xaxis, 'yaxis': yaxis, 'zaxis': zaxis})
def _update_axes_bounds( verts_center: torch.Tensor, max_expand: float, current_layout: go.Scene, # pyre-ignore[11] ): # pragma: no cover """ Takes in the vertices' center point and max spread, and the current plotly figure layout and updates the layout to have bounds that include all traces for that subplot. Args: verts_center: tensor of size (3) corresponding to a trace's vertices' center point. max_expand: the maximum spread in any dimension of the trace's vertices. current_layout: the plotly figure layout scene corresponding to the referenced trace. """ verts_min = verts_center - max_expand verts_max = verts_center + max_expand bounds = torch.t(torch.stack((verts_min, verts_max))) # Ensure that within a subplot, the bounds capture all traces old_xrange, old_yrange, old_zrange = ( current_layout["xaxis"]["range"], current_layout["yaxis"]["range"], current_layout["zaxis"]["range"], ) x_range, y_range, z_range = bounds if old_xrange is not None: x_range[0] = min(x_range[0], old_xrange[0]) x_range[1] = max(x_range[1], old_xrange[1]) if old_yrange is not None: y_range[0] = min(y_range[0], old_yrange[0]) y_range[1] = max(y_range[1], old_yrange[1]) if old_zrange is not None: z_range[0] = min(z_range[0], old_zrange[0]) z_range[1] = max(z_range[1], old_zrange[1]) xaxis = {"range": x_range} yaxis = {"range": y_range} zaxis = {"range": z_range} current_layout.update({"xaxis": xaxis, "yaxis": yaxis, "zaxis": zaxis})