def __init__(self, obj, width=512, height=512, textureMap=None, scalarField=None, vectorField=None):
    """
    Build the pythreejs scene graph, camera, controls and renderer for this
    viewer, then hand `obj` to `self.update` for the initial display.

    `width` and `height` set the renderer size and the camera aspect ratio;
    `textureMap`, `scalarField` and `vectorField` are forwarded unchanged to
    `self.update`.
    """
    # Note: the subclass's constructor should define self.MeshConstructor and
    # self.isLineMesh, which determine how the geometry is interpreted.
    # When they are left as None we default to a non-line pythreejs.Mesh.
    if self.isLineMesh is None:
        self.isLineMesh = False
    if self.MeshConstructor is None:
        self.MeshConstructor = pythreejs.Mesh

    headlight = pythreejs.PointLight(color='white', position=[0, 0, 5])
    headlight.intensity = 0.6
    self.cam = pythreejs.PerspectiveCamera(
        position=[0, 0, 5],
        up=[0, 1, 0],
        aspect=width / height,
        children=[headlight],
    )

    self.avoidRedrawFlicker = False

    self.objects = pythreejs.Group()
    self.meshes = pythreejs.Group()
    # Translucent meshes kept around by preserveExisting
    self.ghostMeshes = pythreejs.Group()

    self.materialLibrary = MaterialLibrary(self.isLineMesh)

    # Sometimes we do not use a particular attribute buffer, e.g. the index
    # buffer when displaying per-face scalar fields. But to avoid reallocating
    # these buffers when switching away from these cases, we need to preserve
    # the buffers that may have previously been allocated. This is done with
    # the bufferAttributeStash. A buffer attribute, if it exists, must always
    # be attached to the current BufferGeometry or in this stash (but not
    # both!).
    self.bufferAttributeStash = {}

    # The main mesh being viewed
    self.currMesh = None
    # Wireframe for the main visualization mesh
    self.wireframeMesh = None
    # Points for the main visualization mesh
    self.pointsMesh = None
    self.vectorFieldMesh = None

    self.cachedWireframeMaterial = None
    self.cachedPointsMaterial = None

    self.objects.add([self.meshes, self.ghostMeshes])

    self.shouldShowWireframe = False
    self.scalarField = None
    self.vectorField = None

    # Will hold this viewer's instance of the special vector field shader
    self.arrowMaterial = None
    self._arrowSize = 60

    # Camera needs to be part of the scene because the scene light is its
    # child (so that it follows the camera).
    scene_children = [self.objects, self.cam, pythreejs.AmbientLight(intensity=0.5)]
    self.scene = pythreejs.Scene(children=scene_children)

    # Sane trackball controls.
    trackball = pythreejs.TrackballControls(controlling=self.cam)
    self.controls = trackball
    trackball.staticMoving = True
    trackball.rotateSpeed = 2.0
    trackball.zoomSpeed = 2.0
    trackball.panSpeed = 1.0

    self.renderer = pythreejs.Renderer(
        camera=self.cam,
        scene=self.scene,
        controls=[self.controls],
        width=width,
        height=height,
    )
    self.update(
        True,
        obj,
        updateModelMatrix=True,
        textureMap=textureMap,
        scalarField=scalarField,
        vectorField=vectorField,
    )
def generate_axis_ticks_and_labels(self):
    """
    Build a group of tick labels and axis titles placed on the outline edges.

    Reads the per-axis limits from ``self.xminmax`` and the titles from
    ``self.axlabels``. If ``self.tick_size`` is None it is first set to 5% of
    the smallest axis extent. Returns the populated ``p3.Group``.
    """
    limits = self.xminmax
    if self.tick_size is None:
        extents = np.diff(list(limits.values()), axis=1).ravel()
        self.tick_size = 0.05 * np.amin(extents)

    group = p3.Group()
    unit_vectors = np.identity(3, dtype=np.float32)
    locator = mpl.ticker.MaxNLocator(5)

    # Each axis runs along the edge anchored at the lower limits of the two
    # other axes; the axis' own component is filled in per tick.
    x_low = limits['x'][0]
    y_low = limits['y'][0]
    z_low = limits['z'][0]
    offsets = {
        'x': [0, y_low, z_low],
        'y': [x_low, 0, z_low],
        'z': [x_low, y_low, 0],
    }

    for axis, dim in enumerate('xyz'):
        low, high = limits[dim][0], limits[dim][1]
        for tick in locator.tick_values(low, high):
            # The locator may propose ticks outside the limits; skip those.
            if tick < low or tick > high:
                continue
            tick_pos = unit_vectors[axis] * tick + offsets[dim]
            label = value_to_string(tick, precision=1)
            group.add(
                self.make_axis_tick(string=label,
                                    position=tick_pos.tolist(),
                                    size=self.tick_size))

        # Axis title, centred along the axis and sized by its text length.
        title = self.axlabels[dim]
        centre = unit_vectors[axis] * 0.5 * np.sum(limits[dim]) + offsets[dim]
        group.add(
            self.make_axis_tick(string=title,
                                position=centre.tolist(),
                                size=self.tick_size * 0.3 * len(title)))

    return group
def loadCadnano(self, filename):
    """
    Load a cadnano file and build the selectable scene objects for it.

    Stores the path in ``self.filename`` and the parsed model in
    ``self.model``, creates one drawn group per model segment under
    ``self.selectable``, collects every child of those groups whose ``type``
    equals ``'Mesh'`` into ``self.allMeshes``, and resets ``self.selected``
    to an empty set.
    """
    self.filename = filename
    self.model = read_cadnano(filename)
    self.selectable = three.Group(
        children=[self.drawSegment(s) for s in self.model.segments])
    # Flatten the children of every segment group. The type check uses ``==``
    # because ``is`` on a string literal tests object identity, which is not
    # guaranteed for equal strings (and is a SyntaxWarning on Python >= 3.8).
    self.allMeshes = [
        child
        for segment in self.selectable.children
        for child in segment.children
        if child.type == 'Mesh'
    ]
    self.selected = set()
def sys2mesh(os):
    """
    Convert an optical system into a pythreejs group.

    The group receives one child per entry of ``os.prop_ray`` (built with
    ``ray2mesh``) followed by one child per ``(C, P, D)`` entry of
    ``os.complist`` (built with ``comp2mesh``). When ``os`` is None an empty
    group is returned.
    """
    s = py3js.Group()
    if os is None:
        return s

    for ray in os.prop_ray:
        s.add(ray2mesh(ray))

    # Draw Components
    for C, P, D in os.complist:
        s.add(comp2mesh(C, P, D))
    return s
def add_labels(element_groups, key_elements, use_label_arrays):
    """Create the text-label sprites for the atoms of the scene.

    Atoms carrying a non-None ``label`` are grouped by (label, font colour);
    one sprite material is created per group and recorded under
    ``key_elements["label_arrays"]``. The returned group is also stored under
    ``key_elements["group_labels"]`` when at least one label exists.
    """
    import pythreejs as pjs

    group_labels = pjs.Group()

    # Bucket the labelled atoms so each distinct (label, colour) pair shares
    # a single material.
    labelled = {}
    for atom in element_groups["atoms"]:
        if "label" not in atom or atom.label is None:
            continue
        bucket_key = (
            ("label", atom.label),
            ("color", atom.get("font_color", "black")),
        )
        labelled.setdefault(bucket_key, []).append(atom)

    if labelled:
        key_elements["group_labels"] = group_labels

    for bucket_key, members in labelled.items():
        first = members[0]
        record = dict(bucket_key)
        # depthWrite=depthTest=False is required, for the sprite to remain on top,
        # and not have the whitespace obscure objects behind, see:
        # https://stackoverflow.com/questions/11165345/three-js-webgl-transparent-planes-hiding-other-planes-behind-them
        # TODO can this be improved?
        texture = pjs.TextTexture(
            string=first.label,
            color=first.get("font_color", "black"),
            size=2000,  # this texttexture size seems to work, not sure why?
        )
        text_material = pjs.SpriteMaterial(
            map=texture,
            opacity=1.0,
            transparent=True,
            depthWrite=False,
            depthTest=False,
        )
        record["material"] = text_material
        key_elements.setdefault("label_arrays", []).append(record)

        if not use_label_arrays:
            # One independent sprite per labelled atom.
            label_array = [
                pjs.Sprite(material=text_material,
                           position=member.position.tolist())
                for member in members
            ]
        else:
            # One template sprite cloned to every position.
            template = pjs.Sprite(material=text_material)
            label_array = pjs.CloneArray(
                original=template,
                positions=[member.position.tolist() for member in members],
                merge=False,
            )
        group_labels.add(label_array)

    return group_labels
def comp2mesh(C, P, D):
    """
    Build a pythreejs group for component (or sub-system) ``C``.

    A ``Component`` contributes one mesh per ``(surface, position, direction)``
    entry of ``C.surflist``; a ``System`` contributes one nested group per
    entry of ``C.complist``. The group is then rotated by ``D`` (Euler order
    "ZYX") and placed at ``P``.
    """
    group = py3js.Group()

    if isinstance(C, Component):
        for surface, surf_pos, surf_dir in C.surflist:
            group.add(surf2mesh(surface, surf_pos, surf_dir))
    elif isinstance(C, System):
        for child, child_pos, child_dir in C.complist:
            group.add(comp2mesh(child, child_pos, child_dir))

    group.rotation = (*D, "ZYX")
    group.position = tuple(P)
    return group
def ray2mesh(ray):
    """
    Build a pythreejs group of lines for ``ray``.

    The line colour is derived from ``ray.wavelength`` through
    ``wavelength2RGB``; one ``py3js.Line`` is added for every vertex list
    returned by ``ray2list(ray)``, all sharing the same material.
    """
    group = py3js.Group()

    red, green, blue = wavelength2RGB(ray.wavelength)
    # Scale the 0..1 channels to 0..255 and format as "#RRGGBB".
    hex_color = "#{:02X}{:02X}{:02X}".format(
        int(255 * red), int(255 * green), int(255 * blue))
    material = py3js.LineBasicMaterial(color=hex_color)

    for vertices in ray2list(ray):
        geometry = py3js.Geometry()
        geometry.vertices = vertices
        group.add(py3js.Line(geometry, material))

    return group
def _generate_axis_ticks_and_labels(self, axparams):
    """
    Build a group of tick labels and axis titles placed on the outline edges.

    ``axparams`` is indexed by 'x', 'y', 'z'; each entry supplies "lims"
    (lower, upper) and a "label" that is used when ``self.axlabels`` has None
    for that axis. If ``self.tick_size`` is None it is first set to 5% of the
    smallest axis extent. Returns the populated ``p3.Group``.
    """
    lims = {dim: axparams[dim]["lims"] for dim in 'xyz'}

    if self.tick_size is None:
        extents = [lims[dim][1] - lims[dim][0] for dim in 'xyz']
        self.tick_size = 0.05 * np.amin(extents)

    group = p3.Group()
    unit_vectors = np.identity(3, dtype=np.float32)
    locator = ticker.MaxNLocator(5)

    # Each axis runs along the edge anchored at the lower limits of the two
    # other axes; the axis' own component is filled in per tick.
    x_low, y_low, z_low = lims['x'][0], lims['y'][0], lims['z'][0]
    offsets = {
        'x': [0, y_low, z_low],
        'y': [x_low, 0, z_low],
        'z': [x_low, y_low, 0],
    }

    for axis, dim in enumerate('xyz'):
        low, high = lims[dim][0], lims[dim][1]
        for tick in locator.tick_values(low, high):
            # The locator may propose ticks outside the limits; skip those.
            if tick < low or tick > high:
                continue
            tick_pos = unit_vectors[axis] * tick + offsets[dim]
            group.add(
                self._make_axis_tick(string=value_to_string(tick, precision=1),
                                     position=tick_pos.tolist(),
                                     size=self.tick_size))

        # A user-supplied label takes precedence over the one in axparams.
        axis_label = self.axlabels[dim]
        if axis_label is None:
            axis_label = axparams[dim]["label"]
        centre = unit_vectors[axis] * 0.5 * np.sum(lims[dim]) + offsets[dim]
        group.add(
            self._make_axis_tick(string=axis_label,
                                 position=centre.tolist(),
                                 size=self.tick_size * 0.3 * len(axis_label)))

    return group
def ray2mesh(ray):
    """
    Build a pythreejs group of lines for ``ray``.

    The colour comes from ``ray.draw_color`` (converted with
    ``colors.to_rgb``) when it is set, otherwise from ``ray.wavelength``
    through ``wavelength2RGB``. One ``py3js.Line`` is added for every vertex
    list returned by ``ray2list(ray)``, all sharing the same material.
    """
    group = py3js.Group()

    if ray.draw_color is not None:
        rgb = colors.to_rgb(ray.draw_color)
    else:
        rgb = wavelength2RGB(ray.wavelength)

    # Scale the 0..1 channels to 0..255 and format as "#RRGGBB".
    channels = [int(255 * channel) for channel in rgb]
    hex_color = "#{:02X}{:02X}{:02X}".format(*channels)
    material = py3js.LineBasicMaterial(color=hex_color)

    for vertices in ray2list(ray):
        geometry = py3js.Geometry()
        geometry.vertices = vertices
        group.add(py3js.Line(geometry, material))

    return group
def visualise(mesh,
              geometric_field,
              number_of_dimensions,
              xi_interpolation,
              dependent_field=None,
              variable=None,
              mechanics_animation=False,
              colour_map_dependent_component_number=None,
              cmap='gist_rainbow',
              resolution=1,
              node_labels=False):
    """
    Display the exterior surface of a 3D mesh in a pythreejs renderer.

    Parameters
    ----------
    mesh, geometric_field :
        Passed to ``mesh_tools`` to obtain node/element numbering and to
        convert the mesh before its faces are extracted with ``get_faces``.
    number_of_dimensions : int
        Must be 3; otherwise a warning is printed and None is returned.
    xi_interpolation : list
        Must equal ``[1, 1, 1]``; otherwise a warning is printed and None is
        returned.
    dependent_field :
        Field interpolated at the surface points for colouring and, when
        ``mechanics_animation`` is set, for the deformed vertex positions.
    variable :
        Accepted for interface compatibility; not used by this function.
    mechanics_animation : bool
        When true, a morph-target animation between the undeformed and the
        interpolated positions is displayed instead of the static view.
    colour_map_dependent_component_number : int, optional
        1-based component of ``dependent_field`` used for vertex colours.
        Vertices are uniform grey when this or ``dependent_field`` is None.
    cmap : str
        Matplotlib colormap name used for the vertex colours.
    resolution :
        Forwarded to ``get_faces`` as ``res``.
    node_labels : bool
        When true (static view only), a text label is added per mesh node.

    Returns
    -------
    tuple or None
        ``(vertices, faces)`` as nested lists, or None when the input is not
        supported.
    """
    if number_of_dimensions != 3:
        print(
            'Warning: Only visualisation of 3D meshes is currently supported.')
        return

    if xi_interpolation != [1, 1, 1]:
        # Implicit concatenation keeps source indentation out of the message.
        print('Warning: Only visualisation of 3D elements with linear '
              'Lagrange interpolation along all coordinate directions is '
              'currently supported.')
        return

    view_width = 600
    view_height = 600

    debug = False
    if debug:
        # Hard-coded unit cube, only reachable by flipping `debug` above.
        vertices = [[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], [1, 0, 0],
                    [1, 0, 1], [1, 1, 0], [1, 1, 1]]
        faces = [[0, 1, 3], [0, 3, 2], [0, 2, 4], [2, 6, 4], [0, 4, 1],
                 [1, 4, 5], [2, 3, 6], [3, 7, 6], [1, 5, 3], [3, 5, 7],
                 [4, 6, 5], [5, 6, 7]]
    else:
        # Get mesh topology information.
        num_nodes = mesh_tools.num_nodes_get(mesh, mesh_component=1)
        node_nums = list(range(1, num_nodes + 1))
        num_elements, element_nums = mesh_tools.num_element_get(
            mesh, mesh_component=1)

        # Convert geometric field to a morphic mesh and export to json
        mesh = mesh_tools.OpenCMISS_to_morphic(mesh,
                                               geometric_field,
                                               element_nums,
                                               node_nums,
                                               dimension=3,
                                               interpolation='linear')
        vertices, faces, _, xi_element_nums, xis = get_faces(
            mesh, res=resolution, exterior_only=True, include_xi=True)

        vertices = vertices.tolist()
        faces = faces.tolist()

    centroid = np.mean(vertices, axis=0)
    max_positions = np.max(vertices, axis=0)
    min_positions = np.min(vertices, axis=0)
    range_positions = max_positions - min_positions

    if (dependent_field is not None) and (
            colour_map_dependent_component_number is not None):
        # Sample the requested field component at every surface point.
        solution = np.zeros(xis.shape[0])
        for idx, (xi, xi_element_num) in enumerate(zip(xis, xi_element_nums)):
            solution[idx] = mesh_tools.interpolate_opencmiss_field_xi(
                dependent_field,
                xi,
                element_ids=[xi_element_num],
                dimension=3,
                deriv=1)[colour_map_dependent_component_number - 1]
        minima = min(solution)
        maxima = max(solution)

        import matplotlib.colors
        norm = matplotlib.colors.Normalize(vmin=minima, vmax=maxima, clip=True)
        # ScalarMappable resolves colormap names itself; cm.get_cmap was
        # removed in Matplotlib 3.9.
        mapper = cm.ScalarMappable(norm=norm, cmap=cmap)

        vertex_colors = np.zeros((len(vertices), 3), dtype='float32')
        vertex_colors[:solution.shape[0], :] = mapper.to_rgba(
            solution, alpha=None)[:, :3]
    else:
        vertex_colors = np.tile(np.array([0.5, 0.5, 0.5], dtype='float32'),
                                (len(vertices), 1))

    # 16-bit indices can only address 65536 vertices; larger meshes would
    # silently wrap around, so widen the index buffer when needed.
    index_dtype = 'uint16' if len(vertices) <= 65536 else 'uint32'
    geometry = pjs.BufferGeometry(attributes=dict(
        position=pjs.BufferAttribute(vertices, normalized=False),
        index=pjs.BufferAttribute(
            np.array(faces).astype(dtype=index_dtype).ravel(),
            normalized=False),
        color=pjs.BufferAttribute(vertex_colors),
    ))

    if mechanics_animation:
        # Interpolated positions form the single morph target.
        deformed_vertices = np.zeros((xis.shape[0], 3), dtype='float32')
        for idx, (xi, xi_element_num) in enumerate(zip(xis, xi_element_nums)):
            deformed_vertices[idx, :] = \
                mesh_tools.interpolate_opencmiss_field_xi(
                    dependent_field, xi, element_ids=[xi_element_num],
                    dimension=3, deriv=1)[0][:3]
        geometry.morphAttributes = {
            'position': [
                pjs.BufferAttribute(deformed_vertices),
            ]
        }

        geometry.exec_three_obj_method('computeFaceNormals')
        geometry.exec_three_obj_method('computeVertexNormals')

        # Front and back faces are drawn as two meshes over one geometry.
        surf1 = pjs.Mesh(geometry,
                         pjs.MeshPhongMaterial(color='#ff3333',
                                               shininess=150,
                                               morphTargets=True,
                                               side='FrontSide'),
                         name='A')
        surf2 = pjs.Mesh(geometry,
                         pjs.MeshPhongMaterial(color='#ff3333',
                                               shininess=150,
                                               morphTargets=True,
                                               side='BackSide'),
                         name='B')
        # NOTE(review): this group is never added to the scene; kept because
        # removing a widget construction may not be side-effect free.
        surf = pjs.Group(children=[surf1, surf2])

        camera = pjs.PerspectiveCamera(position=[
            range_positions[0] * 3, range_positions[1] * 3,
            range_positions[2] * 3
        ],
                                       aspect=view_width / view_height)
        camera.up = [0, 0, 1]
        camera.lookAt(centroid.tolist())

        scene3 = pjs.Scene(children=[
            surf1, surf2, camera,
            pjs.DirectionalLight(position=[3, 5, 1], intensity=0.6),
            pjs.AmbientLight(intensity=0.5)
        ])
        axes = pjs.AxesHelper(size=range_positions[0] * 2)
        scene3.add(axes)

        # Ramp both meshes from undeformed (0) to deformed (1) over 3 units
        # of animation time.
        A_track = pjs.NumberKeyframeTrack(
            name='scene/A.morphTargetInfluences[0]',
            times=[0, 3],
            values=[0, 1])
        B_track = pjs.NumberKeyframeTrack(
            name='scene/B.morphTargetInfluences[0]',
            times=[0, 3],
            values=[0, 1])
        pill_clip = pjs.AnimationClip(tracks=[A_track, B_track])
        pill_action = pjs.AnimationAction(pjs.AnimationMixer(scene3),
                                          pill_clip, scene3)

        renderer3 = pjs.Renderer(
            camera=camera,
            scene=scene3,
            controls=[pjs.OrbitControls(controlling=camera)],
            width=view_width,
            height=view_height)
        display(renderer3, pill_action)
    else:
        geometry.exec_three_obj_method('computeFaceNormals')
        geometry.exec_three_obj_method('computeVertexNormals')

        # Front and back faces are drawn as two meshes over one geometry.
        surf1 = pjs.Mesh(geometry=geometry,
                         material=pjs.MeshLambertMaterial(
                             vertexColors='VertexColors', side='FrontSide'))
        surf2 = pjs.Mesh(geometry=geometry,
                         material=pjs.MeshLambertMaterial(
                             vertexColors='VertexColors', side='BackSide'))
        # NOTE(review): this group is never added to the scene; kept because
        # removing a widget construction may not be side-effect free.
        surf = pjs.Group(children=[surf1, surf2])

        camera = pjs.PerspectiveCamera(position=[
            range_positions[0] * 3, range_positions[1] * 3,
            range_positions[2] * 3
        ],
                                       aspect=view_width / view_height)
        camera.up = [0, 0, 1]
        camera.lookAt(centroid.tolist())

        lights = [
            pjs.DirectionalLight(position=[
                range_positions[0] * 16, range_positions[1] * 12,
                range_positions[2] * 17
            ],
                                 intensity=0.5),
            pjs.AmbientLight(intensity=0.8),
        ]
        orbit = pjs.OrbitControls(controlling=camera,
                                  screenSpacePanning=True,
                                  target=centroid.tolist())

        scene = pjs.Scene()
        axes = pjs.AxesHelper(size=max(range_positions) * 2)
        scene.add(axes)
        scene.add(surf1)
        scene.add(surf2)
        scene.add(lights)

        if node_labels:
            # Add text labels for each mesh node.
            node_positions, ids = mesh.get_node_ids(group='_default')
            for idx, node_position in enumerate(node_positions):
                text = make_text(str(ids[idx]),
                                 position=(node_position[0], node_position[1],
                                           node_position[2]))
                scene.add(text)

        # Add text for axes labels.
        # NOTE(review): assumed to be independent of node_labels, since the
        # axes helper itself is always shown — confirm against the original
        # layout.
        axis_extent = max(range_positions) * 2
        x_axis_label = make_text('x', position=(axis_extent, 0, 0))
        y_axis_label = make_text('y', position=(0, axis_extent, 0))
        z_axis_label = make_text('z', position=(0, 0, axis_extent))
        scene.add(x_axis_label)
        scene.add(y_axis_label)
        scene.add(z_axis_label)

        renderer = pjs.Renderer(scene=scene,
                                camera=camera,
                                controls=[orbit],
                                width=view_width,
                                height=view_height)
        camera.zoom = 1
        display(renderer)

    return vertices, faces
def drawSegment(self, segment):
    """
    Build the scene group for one segment: a ``BeadMesh`` per bead (coloured
    with ``self.beadColor``) followed by a single ``ConnectionLine`` for the
    segment (coloured with ``self.lineColor``).
    """
    children = [BeadMesh(bead, self.beadColor) for bead in segment.beads]
    children.append(ConnectionLine(segment, self.lineColor))
    return three.Group(children=children)
def generate_3js_render(
    element_groups,
    canvas_size,
    zoom,
    camera_fov=30,
    background_color="white",
    background_opacity=1.0,
    reuse_objects=False,
    use_atom_arrays=False,
    use_label_arrays=False,
):
    """Create a pythreejs scene of the elements.

    Regarding initialisation performance, see:
    https://github.com/jupyter-widgets/pythreejs/issues/154

    Parameters
    ----------
    element_groups :
        Container indexed by "atoms", "cell_lines", "bond_lines",
        "miller_lines" and "miller_planes"; must also provide
        ``get_position_range()``.
    canvas_size : tuple
        ``(width, height)`` of the renderer.
    zoom :
        Zoom passed to the perspective camera.
    camera_fov :
        Camera field of view in degrees; also used to compute the camera
        distance.
    background_color, background_opacity :
        Clear colour and opacity of the renderer.
    reuse_objects : bool
        Share geometries and materials between atom sets with identical
        properties instead of creating one per set.
    use_atom_arrays : bool
        Draw each atom set as a ``CloneArray`` of one template mesh instead
        of one mesh per atom.
    use_label_arrays : bool
        Forwarded to ``add_labels``.

    Returns
    -------
    tuple
        ``(renderer, key_elements)`` where ``key_elements`` maps names to
        the scene objects created here.
    """
    import pythreejs as pjs

    def cached(cache, key, factory):
        """Return a shared object for ``key`` when reuse is on, else a new one.

        The factory is only invoked on a cache miss, so no throw-away widget
        is constructed when an object is being reused.
        """
        if not reuse_objects:
            return factory()
        if key not in cache:
            cache[key] = factory()
        return cache[key]

    key_elements = {}
    group_elements = pjs.Group()
    key_elements["group_elements"] = group_elements

    # Group atoms that can share geometry/material settings.
    unique_atom_sets = {}
    for el in element_groups["atoms"]:
        element_hash = (
            ("radius", el.sradius),
            ("color", el.color),
            ("fill_opacity", el.fill_opacity),
            ("stroke_color", el.get("stroke_color", "black")),
            ("ghost", el.ghost),
        )
        unique_atom_sets.setdefault(element_hash, []).append(el)

    group_atoms = pjs.Group()
    group_ghosts = pjs.Group()

    atom_geometries = {}
    atom_materials = {}
    outline_materials = {}

    for el_hash, els in unique_atom_sets.items():
        el = els[0]
        data = dict(el_hash)
        target_group = group_ghosts if el.ghost else group_atoms
        positions = [e.position.tolist() for e in els]

        atom_geometry = cached(
            atom_geometries,
            el.sradius,
            lambda: pjs.SphereBufferGeometry(
                radius=el.sradius, widthSegments=30, heightSegments=30),
        )
        atom_material = cached(
            atom_materials,
            (el.color, el.fill_opacity),
            lambda: pjs.MeshLambertMaterial(
                color=el.color, transparent=True, opacity=el.fill_opacity),
        )

        if use_atom_arrays:
            atom_mesh = pjs.Mesh(geometry=atom_geometry,
                                 material=atom_material)
            atom_array = pjs.CloneArray(
                original=atom_mesh,
                positions=positions,
                merge=False,
            )
        else:
            atom_array = [
                pjs.Mesh(
                    geometry=atom_geometry,
                    material=atom_material,
                    position=e.position.tolist(),
                    name=e.info_string,
                ) for e in els
            ]

        data["geometry"] = atom_geometry
        data["material_body"] = atom_material

        if el.ghost:
            key_elements["group_ghosts"] = group_ghosts
        else:
            key_elements["group_atoms"] = group_atoms
        target_group.add(atom_array)

        if el.get("stroke_width", 1) > 0:
            stroke_color = el.get("stroke_color", "black")
            stroke_opacity = el.get("stroke_opacity", 1.0)
            # The cache key includes the opacity: keying on colour alone
            # would hand a material with the wrong opacity to atom sets that
            # share a stroke colour.
            outline_material = cached(
                outline_materials,
                (stroke_color, stroke_opacity),
                lambda: pjs.MeshBasicMaterial(
                    color=stroke_color,
                    side="BackSide",
                    transparent=True,
                    opacity=stroke_opacity,
                ),
            )
            # TODO use stroke width to dictate scale
            if use_atom_arrays:
                outline_mesh = pjs.Mesh(
                    geometry=atom_geometry,
                    material=outline_material,
                    scale=(1.05, 1.05, 1.05),
                )
                outline_array = pjs.CloneArray(
                    original=outline_mesh,
                    positions=positions,
                    merge=False,
                )
            else:
                outline_array = [
                    pjs.Mesh(
                        geometry=atom_geometry,
                        material=outline_material,
                        position=e.position.tolist(),
                        scale=(1.05, 1.05, 1.05),
                    ) for e in els
                ]

            data["material_outline"] = outline_material
            target_group.add(outline_array)

        key_elements.setdefault("atom_arrays", []).append(data)

    group_elements.add(group_atoms)
    group_elements.add(group_ghosts)

    group_labels = add_labels(element_groups, key_elements, use_label_arrays)
    group_elements.add(group_labels)

    if len(element_groups["cell_lines"]) > 0:
        cell_line_mat = pjs.LineMaterial(
            linewidth=1,
            color=element_groups["cell_lines"].group_properties["color"])
        cell_line_geo = pjs.LineSegmentsGeometry(positions=[
            el.position.tolist() for el in element_groups["cell_lines"]
        ])
        cell_lines = pjs.LineSegments2(geometry=cell_line_geo,
                                       material=cell_line_mat)
        key_elements["cell_lines"] = cell_lines
        group_elements.add(cell_lines)

    if len(element_groups["bond_lines"]) > 0:
        bond_line_mat = pjs.LineMaterial(
            linewidth=element_groups["bond_lines"].
            group_properties["stroke_width"],
            vertexColors="VertexColors",
        )
        bond_line_geo = pjs.LineSegmentsGeometry(
            positions=[
                el.position.tolist() for el in element_groups["bond_lines"]
            ],
            colors=[[Color(c).rgb for c in el.color]
                    for el in element_groups["bond_lines"]],
        )
        bond_lines = pjs.LineSegments2(geometry=bond_line_geo,
                                       material=bond_line_mat)
        key_elements["bond_lines"] = bond_lines
        group_elements.add(bond_lines)

    group_millers = pjs.Group()
    if len(element_groups["miller_lines"]) or len(
            element_groups["miller_planes"]):
        key_elements["group_millers"] = group_millers

    if len(element_groups["miller_lines"]) > 0:
        miller_line_mat = pjs.LineMaterial(
            linewidth=3,
            vertexColors="VertexColors"  # TODO use stroke_width
        )
        miller_line_geo = pjs.LineSegmentsGeometry(
            positions=[
                el.position.tolist() for el in element_groups["miller_lines"]
            ],
            colors=[[Color(el.stroke_color).rgb] * 2
                    for el in element_groups["miller_lines"]],
        )
        miller_lines = pjs.LineSegments2(geometry=miller_line_geo,
                                         material=miller_line_mat)
        group_millers.add(miller_lines)

    for el in element_groups["miller_planes"]:
        vertices = el.position.tolist()
        # A triangle is one face; a quadrilateral is split into two.
        faces = [(
            0,
            1,
            2,
            triangle_normal(vertices[0], vertices[1], vertices[2]),
            "black",
            0,
        )]
        if len(vertices) == 4:
            faces.append((
                2,
                3,
                0,
                triangle_normal(vertices[2], vertices[3], vertices[0]),
                "black",
                0,
            ))
        elif len(vertices) != 3:
            raise NotImplementedError("polygons with more than 4 points")
        plane_geom = pjs.Geometry(vertices=vertices, faces=faces)
        plane_mat = pjs.MeshBasicMaterial(
            color=el.fill_color,
            transparent=True,
            opacity=el.fill_opacity,
            side="DoubleSide",
        )
        plane_mesh = pjs.Mesh(geometry=plane_geom, material=plane_mat)
        group_millers.add(plane_mesh)

    group_elements.add(group_millers)

    scene = pjs.Scene(background=None)
    scene.add([group_elements])

    view_width, view_height = canvas_size

    minp, maxp = element_groups.get_position_range()
    # compute a minimum camera distance, that is guaranteed to encapsulate all elements
    camera_dist = maxp[2] + sqrt(maxp[0]**2 + maxp[1]**2) / tan(
        radians(camera_fov / 2))

    camera = pjs.PerspectiveCamera(
        fov=camera_fov,
        position=[0, 0, camera_dist],
        aspect=view_width / view_height,
        zoom=zoom,
    )
    ambient_light = pjs.AmbientLight(color="lightgray")
    key_elements["ambient_light"] = ambient_light
    direct_light = pjs.DirectionalLight(position=(maxp * 2).tolist())
    key_elements["direct_light"] = direct_light
    # The camera is added exactly once, together with the lights.
    scene.add([camera, ambient_light, direct_light])

    camera_control = pjs.OrbitControls(controlling=camera,
                                       screenSpacePanning=True)

    atom_picker = pjs.Picker(controlling=group_atoms, event="dblclick")
    key_elements["atom_picker"] = atom_picker
    material = pjs.SpriteMaterial(
        map=create_arrow_texture(right=False),
        transparent=True,
        depthWrite=False,
        depthTest=False,
    )
    # Hidden pointer sprite, exposed to callers through key_elements.
    atom_pointer = pjs.Sprite(material=material,
                              scale=(4, 3, 1),
                              visible=False)
    scene.add(atom_pointer)
    key_elements["atom_pointer"] = atom_pointer

    renderer = pjs.Renderer(
        camera=camera,
        scene=scene,
        controls=[camera_control, atom_picker],
        width=view_width,
        height=view_height,
        alpha=True,
        clearOpacity=background_opacity,
        clearColor=background_color,
    )
    return renderer, key_elements
def create_world_axes(camera,
                      controls,
                      initial_rotation=np.eye(3),
                      length=30,
                      width=3,
                      camera_fov=10):
    """Create a renderer, containing an axes and camera that is synced to another camera.

    adapted from http://jsfiddle.net/aqnL1mx9/

    Parameters
    ----------
    camera : pythreejs.PerspectiveCamera
    controls : pythreejs.OrbitControls
    initial_rotation : list or numpy.array
        initial rotation of the axes
    length : int
        length of axes lines
    width : int
        line width of axes
    camera_fov : float
        field of view (degrees) of the axes camera; also sets its distance
        from the axes origin

    Returns
    -------
    pythreejs.Renderer

    """
    import pythreejs as pjs

    canvas_width = length * 2
    canvas_height = length * 2
    ax_scene = pjs.Scene()
    group_ax = pjs.Group()

    # Coerce to a float array first: with a plain nested list,
    # ``length * row`` would repeat the list instead of scaling the vector.
    rotation = np.asarray(initial_rotation, dtype=float)

    # NOTE: could use AxesHelper, but this does not allow for linewidth seletion
    # TODO: add arrow heads (ArrowHelper doesn't seem to work)
    ax_line_mat = pjs.LineMaterial(linewidth=width,
                                   vertexColors="VertexColors")
    ax_line_geo = pjs.LineSegmentsGeometry(
        positions=[[[0, 0, 0], (length * r / np.linalg.norm(r)).tolist()]
                   for r in rotation],
        colors=[[Color(c).rgb] * 2 for c in ("red", "green", "blue")],
    )
    ax_lines = pjs.LineSegments2(geometry=ax_line_geo, material=ax_line_mat)
    group_ax.add(ax_lines)

    ax_scene.add([group_ax])

    camera_dist = length / tan(radians(camera_fov / 2))

    ax_camera = pjs.PerspectiveCamera(fov=camera_fov,
                                      aspect=canvas_width / canvas_height,
                                      near=1,
                                      far=1000)
    ax_camera.up = camera.up
    ax_renderer = pjs.Renderer(
        scene=ax_scene,
        camera=ax_camera,
        width=canvas_width,
        height=canvas_height,
        alpha=True,
        clearOpacity=0.0,
        clearColor="white",
    )

    def align_axes(change=None):
        """Align axes to world."""
        # TODO: this is not working correctly for TrackballControls, when rotated upside-down
        # (OrbitControls enforces the camera up direction,
        # so does not allow the camera to rotate upside-down).
        # TODO how could this be implemented on the client (js) side?
        offset = np.array(camera.position) - np.array(controls.target)
        distance = np.linalg.norm(offset)
        if distance == 0:
            # Camera sits exactly on the target: there is no view direction,
            # so keep the previous orientation rather than writing NaNs.
            return
        ax_camera.position = (camera_dist * offset / distance).tolist()
        ax_camera.lookAt(ax_scene.position)

    align_axes()

    camera.observe(align_axes, names="position")
    controls.observe(align_axes, names="target")
    ax_scene.observe(align_axes, names="position")

    return ax_renderer