def plot_3D_data(x, y, z, fig_title):
    """Render ``z`` as a mayavi surface with an alpha gradient and show it.

    The surface is drawn with ``mlab.surf(z)``, i.e. on the index grid of
    ``z``; ``x`` and ``y`` are accepted but not used by this function.

    :param x: Unused (kept for interface compatibility).
    :param y: Unused (kept for interface compatibility).
    :param z: 2-D array of heights passed to ``mlab.surf``.
    :param fig_title: Title text, drawn in black.
    """
    from mayavi import mlab

    mlab.figure(bgcolor=(1, 1, 1))
    surf = mlab.surf(z, colormap='cool')

    # Retrieve the LUT of the surf object: an (N, 4) array whose columns are
    # RGBA coded with integers from 0 to 255.
    lut_manager = surf.module_manager.scalar_lut_manager
    lut = lut_manager.lut.table.to_array()

    # Add a transparency gradient through the alpha channel. The ramp is sized
    # from the table itself rather than assuming exactly 256 entries, so a LUT
    # with a different number of colours no longer raises a shape error.
    lut[:, -1] = np.linspace(0, 255, lut.shape[0])

    # Put the modified LUT back into the surface object.
    lut_manager.lut.table = lut

    mlab.title(fig_title, color=(0, 0, 0))
    mlab.orientation_axes()
    mlab.view(40, 85)
    # Force an update of the figure now that the LUT has changed.
    mlab.draw()
    mlab.show()
def surfcf(gridx, gridy, phase, modulus, colormap=None):
    r"""Draw the modulus of a complex valued function :math:`f:R^2 -> C` as a
    surface, coloured by its phase.

    :param gridx: Grid nodes along the :math:`x` axis of the real domain :math:`R^2`.
    :param gridy: Grid nodes along the :math:`y` axis of the real domain :math:`R^2`.
    :param phase: Phase of f(grid); used as the colour scalars, mapped over [-pi, pi].
    :param modulus: Modulus of f(grid); used as surface height.
    :param colormap: RGB rows copied into the LUT's first three columns. When
                     omitted, the 'default' QM colormap from ``compute_color_map()``
                     is used.
    :return: The mayavi mesh object.
    """
    cmap = compute_color_map() if colormap is None else colormap

    # real(.) only serves to obtain an array with a real dtype.
    heights = real(modulus)
    mesh = mlab.mesh(gridx, gridy, heights, scalars=phase)

    # Pin the colour range to the full phase interval, then install the colormap.
    lut_manager = mesh.module_manager.scalar_lut_manager
    lut_manager.use_default_range = False
    lut_manager.data_range = [-pi, pi]
    table = lut_manager.lut.table.to_array()
    table[:, 0:3] = cmap.copy()
    lut_manager.lut.table = table

    mlab.draw()
    return mesh
def zoncaview(m, vmin=-1e3, vmax=1e3, xsize=1000, ysize=1000):
    """Project a HEALPix sky map onto a sphere and draw it with mayavi.

    m is a healpix sky map, such as provided by WMAP or Planck.

    :param m: HEALPix map; its length determines ``nside``.
    :param vmin: Lower colour limit (default -1e3, as before).
    :param vmax: Upper colour limit (default 1e3, as before).
    :param xsize: Number of samples in phi over [-pi, pi].
    :param ysize: Number of samples in theta over [pi, 0].
    """
    nside = hp.npix2nside(len(m))

    # Rectangular sampling grid; theta runs from pi down to 0.
    theta = np.linspace(np.pi, 0, ysize)
    phi = np.linspace(-np.pi, np.pi, xsize)
    PHI, THETA = np.meshgrid(phi, theta)

    # Project the map onto the grid: look up the pixel under every node.
    grid_map = m[hp.ang2pix(nside, THETA, PHI)]

    # Sphere of radius r.
    r = 0.3
    x = r * np.sin(THETA) * np.cos(PHI)
    y = r * np.sin(THETA) * np.sin(PHI)
    z = r * np.cos(THETA)

    mlab.figure(1, bgcolor=(1, 1, 1), fgcolor=(0, 0, 0), size=(400, 300))
    mlab.clf()
    mlab.mesh(x, y, z, scalars=grid_map, colormap="jet", vmin=vmin, vmax=vmax)
    mlab.draw()
def visualize_different_color():
    """Draw 200 sphere glyphs, each with its own random LUT colour, and show them.

    Every point gets a distinct integer scalar, so with ``color_by_scalar``
    point k picks row k of the custom RGBA table. The points are placed at
    their own RGB values.

    :return: 0
    """
    n_points = 200
    unit = np.ones(n_points)
    point_ids = np.arange(n_points)  # key point: one integer per point

    # RGBA lookup table: must be uint8 in [0, 255]; alpha fixed at fully opaque.
    rgba = (np.random.random((n_points, 4)) * 255).astype(np.uint8)
    rgba[:, -1] = 255

    xs = rgba[:, 0]
    ys = rgba[:, 1]
    zs = rgba[:, 2]
    glyphs = mlab.quiver3d(xs, ys, zs, unit, unit, unit, scalars=point_ids, mode='sphere')
    glyphs.glyph.color_mode = 'color_by_scalar'

    # Install the lookup table and redraw.
    glyphs.module_manager.scalar_lut_manager.lut.table = rgba
    mlab.draw()
    mlab.show()
    return 0
def mlab_imshowColor(im, alpha=255, **kwargs):
    """
    Plot a color image with mayavi.mlab.imshow.

    im is a ndarray with dim (n, m, 3) and scale (0->255]
    alpha is a single number, or an array-like with n*m elements (any shape
    that flattens to n*m), with scale (0->255]
    **kwargs is passed onto mayavi.mlab.imshow(..., **kwargs)

    Returns the object created by mlab.imshow.
    """
    n_pixels = im.shape[0] * im.shape[1]

    # Normalise alpha to a flat per-pixel array. Replaces a bare try/except
    # probe of alpha[0] that hid unrelated errors and crashed on lists.
    alpha = pl.asarray(alpha)
    if alpha.ndim == 0:
        # A single number: same opacity for every pixel.
        alpha = pl.ones(n_pixels) * alpha
    elif alpha.ndim != 1:
        alpha = alpha.flatten()

    # The lut is an Nx4 array whose columns are RGBA, coded with integers
    # from 0 to 255. It is built by stacking all the pixels (r, g, b, alpha)
    # as rows.
    my_lut = pl.c_[im.reshape(-1, 3), alpha]
    # Each pixel's scalar is its own row index into that LUT.
    lookup = pl.arange(n_pixels).reshape(im.shape[0], im.shape[1])

    # Show the index image with a temporary colormap, then swap in the per-pixel LUT.
    the_imshow = mlab.imshow(lookup, colormap='binary', **kwargs)
    the_imshow.module_manager.scalar_lut_manager.lut.table = my_lut
    mlab.draw()
    return the_imshow
def draw_scene():
    """Draw the triangulated surface, the *Sun_plt / *SC_plt lines, one ray per
    pixel, and a marker for every ray.

    Everything is read from module-level names: x, y, z, triIndices,
    phaseAngle, xSun_plt/ySun_plt/zSun_plt, xSC_plt/ySC_plt/zSC_plt, fStretch,
    nPixelsX, nPixelsY, R, pVectors, rCG, rSC, and the plt_coords helper.
    """
    src = mlab.pipeline.triangular_mesh_source(x, y, z, triIndices)
    src.data.cell_data.scalars = np.cos(phaseAngle)
    body = mlab.pipeline.surface(src)
    body.contour.filled_contours = True
    body.contour.minimum_contour = 0.0
    body.contour.maximum_contour = 1.0
    body.module_manager.scalar_lut_manager.data_range = (0, 1)

    # Yellow and blue reference lines.
    mlab.plot3d(xSun_plt, ySun_plt, zSun_plt, tube_radius=fStretch/1000, color=(1, 1, 0))
    mlab.plot3d(xSC_plt, ySC_plt, zSC_plt, tube_radius=fStretch/1000, color=(0, 0, 1))

    markers = []
    for i in range(nPixelsX):
        for j in range(nPixelsY):
            direction = np.dot(R, pVectors[:, i, j])
            marker = np.dot(rCG, direction) * direction + rSC
            ray_x, ray_y, ray_z = plt_coords(rSC, 1.1 * fStretch * direction)
            mlab.plot3d(ray_x, ray_y, ray_z, tube_radius=fStretch / 5000, color=(0, 0, 0))
            markers.append(marker)

    mlab.points3d([mk[0] for mk in markers],
                  [mk[1] for mk in markers],
                  [mk[2] for mk in markers],
                  np.ones(len(markers)), scale_factor=150, color=(1, 0.7, 0.1))
    mlab.draw()
def surfcf(gridx, gridy, phase, modulus, colormap=None):
    r"""Draw the modulus of a complex valued function :math:`f:R^2 -> C` as a
    surface, coloured by its phase.

    :param gridx: Grid nodes along the :math:`x` axis of the real domain :math:`R^2`.
    :param gridy: Grid nodes along the :math:`y` axis of the real domain :math:`R^2`.
    :param phase: Phase of f(grid); used as the colour scalars, mapped over [-pi, pi].
    :param modulus: Modulus of f(grid); used as surface height.
    :param colormap: RGB rows copied into the LUT's first three columns. When
                     omitted, the 'default' QM colormap from ``compute_color_map()``
                     is used.
    :return: The mayavi mesh object.
    """
    cmap = compute_color_map() if colormap is None else colormap

    # real(.) only serves to obtain an array with a real dtype.
    heights = real(modulus)
    mesh = mlab.mesh(gridx, gridy, heights, scalars=phase)

    # Pin the colour range to the full phase interval, then install the colormap.
    lut_manager = mesh.module_manager.scalar_lut_manager
    lut_manager.use_default_range = False
    lut_manager.data_range = [-pi, pi]
    table = lut_manager.lut.table.to_array()
    table[:, 0:3] = cmap.copy()
    lut_manager.lut.table = table

    mlab.draw()
    return mesh
def draw_sats(self, figure):
    """Draw all satellites as a single points3d object, or move the existing one.

    The first call creates ``self.sat_points``. With more than one satellite
    it also installs ``self.colors`` as the LUT and redraws. Later calls only
    update the x/y/z coordinates of the existing object.
    """
    x, y, z = self.get_mayavi_xyz()
    if self.sat_points is not None:
        self.sat_points.mlab_source.trait_set(x=x, y=y, z=z)
        return

    first_rgb = tuple(self.colors[0, 0:3] / 255)
    self.sat_points = mlab.points3d(x, y, z, np.arange(len(self)),
                                    figure=figure,
                                    scale_mode='none',
                                    scale_factor=SCALE_FACTOR,
                                    color=first_rgb)
    if len(self) > 1:
        lut = self.sat_points.module_manager.scalar_lut_manager.lut
        lut.number_of_colors = len(self.colors)
        lut.table = self.colors
        mlab.draw()
def generate_plots_3d(self):
    """Plot the face data of every conductor in one mayavi figure and show it.

    Unstructured conductors are drawn as a Delaunay-triangulated point cloud.
    Other conductors are drawn as meshes: plain black when their scalars
    contain negatives, otherwise coloured by their scalars. All plots share
    one colour range, padded 5% on each side. A vertical colorbar is added
    for the first plot, if there is one.
    """
    self.ax = mlab.figure(1, bgcolor=(1, 1, 1), fgcolor=(0, 0, 0), size=(800, 600))
    self.clf = mlab.clf()
    minS, maxS = maxint, 0
    contour_plots = []
    for cond in self.conductors.itervalues():
        minS, maxS, face_data = self.generate_plot_data_for_faces_3d(cond, minS, maxS)
        unstructured = isinstance(cond, conductor_type_3d['Unstructured'])
        for (x, y, z, s) in face_data:
            if unstructured:
                pts = mlab.points3d(x, y, z, s, scale_mode='none', scale_factor=0.002)
                mesh = mlab.pipeline.delaunay3d(pts)
                contour_plots.append(mlab.pipeline.surface(mesh, colormap='viridis'))
            elif np.min(s) < 0.0:
                contour_plots.append(mlab.mesh(x, y, z, color=(0, 0, 0), colormap='viridis'))
            else:
                contour_plots.append(mlab.mesh(x, y, z, scalars=s, colormap='viridis'))

    for cp in contour_plots:
        cp.module_manager.scalar_lut_manager.trait_set(
            default_data_range=[minS * 0.95, maxS * 1.05])
    # One redraw after all ranges are updated (was once per plot).
    mlab.draw()
    # Guard: with no face data, contour_plots[0] used to raise IndexError.
    if contour_plots:
        mlab.colorbar(object=contour_plots[0], orientation='vertical')
    mlab.show()
def e_field(q, pos, x_grid, y_grid, z_grid, x_field, y_field, z_field, no_lines):
    """Draw point charges and their field lines, then show the figure.

    :param q: Charges, iterated in parallel with ``pos``.
    :param pos: Charge locations; each one seeds a spherical streamline source.
    :param x_grid, y_grid, z_grid: 1-D grid axes (combined with 'ij' indexing).
    :param x_field, y_field, z_field: Field components on that grid.
    :param no_lines: theta/phi resolution of the seed sphere (controls the number of lines).
    """
    fig = mplt.figure()
    X, Y, Z = np.meshgrid(x_grid, y_grid, z_grid, indexing='ij')

    for charge, location in zip(q, pos):
        sphere(charge, location)  # the charge itself

        # Field lines, integrated both ways from a sphere around the charge.
        streamlines = mplt.flow(X, Y, Z, x_field, y_field, z_field, figure=fig,
                                seedtype='sphere', integration_direction='both')
        streamlines.seed.widget.center = location
        streamlines.seed.widget.theta_resolution = no_lines
        streamlines.seed.widget.phi_resolution = no_lines
        streamlines.stream_tracer.maximum_propagation = 200  # integration steps
        streamlines.seed.widget.radius = 1
        # Workaround kept from the original: toggling the widget off and on,
        # presumably so the seed settings above take effect.
        streamlines.seed.widget.enabled = False
        streamlines.seed.widget.enabled = True

    mplt.axes()
    fig.scene.x_plus_view()  # x axis pointing out of the screen
    mplt.draw(figure=fig)
    mplt.show()
def show_contrasts(subject, contrasts, side, threshold):
    """Overlay thresholded functional contrasts on a subject's inflated surface.

    :param subject: Subject identifier (also used as mesh name and title).
    :param contrasts: Contrast names; the part before the first '-' selects
                      the colormap from the module-level ``colormaps`` dict.
    :param side: Hemisphere identifier passed to the data loaders.
    :param threshold: Lower threshold applied to each contrast.
    :return: The mayavi figure.
    """
    x, y, z, triangles = get_geometry(subject, side, "inflated")  # "inflated" or "white"
    curv = get_curvature_sign(subject, side)

    fig = mlab.figure()
    mlab.clf()
    # Anatomical mesh, shaded by curvature sign.
    mlab.triangular_mesh(x, y, z, triangles, transparent=False, opacity=1.,
                         name=subject, scalars=curv, colormap="bone", vmin=-1, vmax=2)
    mlab.title(subject)

    cmaps = [colormaps[name.split("-")[0]]['colormap'] for name in contrasts]
    samples = np.linspace(.25, 1., 256)
    for contrast, colormap in zip(contrasts, cmaps):
        # Functional mesh restricted to values above the threshold.
        data = get_contrast(subject, contrast, side)
        src = mlab.pipeline.triangular_mesh_source(x, y, z, triangles, scalars=data)
        above = mlab.pipeline.threshold(src, low=threshold)
        # Lower the opacity for more transparency.
        surf = mlab.pipeline.surface(above, colormap='hot', transparent=True, opacity=.8)
        rgba = np.array([colormap(v) for v in samples])
        surf.module_manager.scalar_lut_manager.lut.table = (rgba * 255).astype(int)

    mlab.draw()
    return fig
def _set_3d_view(figure, azimuth, elevation, focalpoint, distance):
    """Point the camera of a mayavi ``figure`` and redraw it.

    Warnings raised by traits, and anything written to stdout, are
    suppressed while the view is changed.
    """
    from mayavi import mlab
    with warnings.catch_warnings(record=True), SilenceStdout():  # traits
        mlab.view(azimuth, elevation, distance, focalpoint=focalpoint, figure=figure)
        mlab.draw(figure)
def draw_conns(self, new_edges=None):
    """Refresh the connection glyphs: threshold, LUT, optional new edges, opacity and label.

    :param new_edges: Optional array whose columns 0 and 1 index start and end
                      positions in ``self.ds.lab_pos``; when given, the
                      vectors are rebuilt from these pairs.
    """
    try:
        self.thres.set(lower_threshold=self.ds.thresval)
        set_lut(self.vectors, self.ds.opts.activation_map)

        if new_edges is not None:
            new_starts = self.ds.lab_pos[new_edges[:, 0]]
            new_vecs = self.ds.lab_pos[new_edges[:, 1]] - new_starts
            self.vectors.mlab_source.reset(
                x=new_starts[:, 0], y=new_starts[:, 1], z=new_starts[:, 2],
                u=new_vecs[:, 0], v=new_vecs[:, 1], w=new_vecs[:, 2])

        if self.ds.curr_node is not None:
            self.vectors.actor.property.opacity = .75
            self.txt.set(text=' %s' % self.ds.labnam[self.ds.curr_node])
        else:
            self.vectors.actor.property.opacity = (
                .5 if self.ds.opts.tube_conns else .3)
            self.txt.set(text='')

        mlab.draw()

    # In case the user changes the threshold while there are no connections
    # present and so the VTK objects have not been created yet
    except AttributeError:
        pass
def view_patch(r, attrib=None, opacity=1, fig=0, show=1):
    """Draw the surface ``r`` as a triangular mesh seen from above.

    :param r: Object with ``vertices`` (indexed as [:, 0..2]) and ``faces``.
    :param attrib: Optional per-vertex scalars. When None or empty, the mesh
                   is drawn without explicit scalars. (The default was a
                   shared mutable ``[]``; None is equivalent.)
    :param opacity: Mesh opacity.
    :param fig: Figure to draw into; 0 creates a new figure.
    :param show: If > 0, call ``mlab.show(stop=True)``.
    :return: ``fig`` (the newly created figure when ``fig`` was 0).
    """
    if fig == 0:
        fig = mlab.figure()
    else:
        mlab.figure(fig)

    mesh_kwargs = dict(representation='surface', opacity=opacity)
    if attrib is not None and len(attrib) > 0:
        mesh_kwargs['scalars'] = attrib
    verts = r.vertices
    mlab.triangular_mesh(verts[:, 0], verts[:, 1], verts[:, 2], r.faces, **mesh_kwargs)

    mlab.gcf().scene.parallel_projection = True
    mlab.view(azimuth=0, elevation=90)
    mlab.colorbar(orientation='horizontal')
    mlab.draw()
    if show > 0:
        mlab.show(stop=True)
    return fig
def draw(self):
    """Draw the vector field ``self._f`` as grey arrows over each transformed base square.

    For every transform in ``self._transforms``, a regular Nl x Nw grid over
    ``self._base_square_x`` x ``self._base_square_y`` (z = 0) is mapped
    through the transform. The field is evaluated at the mapped points and
    drawn with ``mlab.quiver3d``. Each (plot, transform) pair is appended to
    ``self._plots``.
    """
    xs = np.linspace(self._base_square_x[0], self._base_square_x[1], self._Nl)
    ys = np.linspace(self._base_square_y[0], self._base_square_y[1], self._Nw)
    gx, gy = np.meshgrid(xs, ys)
    gz = 0.0 * gx
    for trans in self._transforms:
        # Build from the untouched 2-D base grid each time. Previously x, y, z
        # were rebound to flattened 1-D arrays here, so the second transform
        # failed on x[:, :, None].
        p = np.concatenate((gx[:, :, None], gy[:, :, None], gz[:, :, None]), axis=-1)
        p = trans(p)
        pflat = np.reshape(p, (-1, 3))
        value = self._f(pflat, *self._f_args, **self._f_kwargs)
        m = mlab.quiver3d(pflat[:, 0], pflat[:, 1], pflat[:, 2],
                          value[:, 0], value[:, 1], value[:, 2],
                          mode='arrow', color=(0.4, 0.4, 0.4))
        mlab.draw()
        if self._plots is None:
            self._plots = ((m, trans),)
        else:
            self._plots += ((m, trans),)
def draw(self, **kwargs):
    """Draw the vector field ``self._f`` as white arrows at cell centres of each transformed base square.

    The base square is split into Nl x Nw cells, and arrows sit at the cell
    midpoints. Extra keyword arguments are forwarded to ``mlab.quiver3d``.

    :return: List of all quiver objects drawn so far (``[]`` if none).
    """
    x0, x1 = self._base_square_x[0], self._base_square_x[1]
    y0, y1 = self._base_square_y[0], self._base_square_y[1]
    half_dx = 0.5 * (x1 - x0) / self._Nl
    half_dy = 0.5 * (y1 - y0) / self._Nw
    xs = np.linspace(x0 + half_dx, x1 - half_dx, self._Nl)
    ys = np.linspace(y0 + half_dy, y1 - half_dy, self._Nw)
    gx, gy = np.meshgrid(xs, ys)
    gz = 0.0 * gx
    for trans in self._transforms:
        # Build from the untouched 2-D base grid each time. Previously x, y, z
        # were rebound to flattened 1-D arrays here, so the second transform
        # failed on x[:, :, None].
        p = np.concatenate((gx[:, :, None], gy[:, :, None], gz[:, :, None]), axis=-1)
        p = trans(p)
        pflat = np.reshape(p, (-1, 3))
        value = self._f(pflat, *self._f_args, **self._f_kwargs)
        m = mlab.quiver3d(pflat[:, 0], pflat[:, 1], pflat[:, 2],
                          value[:, 0], value[:, 1], value[:, 2],
                          mode='arrow', color=(1.0, 1.0, 1.0),
                          scale_factor=self.scale_units, resolution=20, **kwargs)
        mlab.draw()
        if self._plots is None:
            self._plots = ((m, trans),)
        else:
            self._plots += ((m, trans),)
    # Tolerate an empty transform list (iterating None used to raise TypeError).
    return [plot for plot, _ in (self._plots or ())]
def __init__(self, start, end, maxd=5000., n=100):
    '''
    Step along a straight lon/lat path from ``start`` to ``end`` in ``n``
    points, showing the terrain subset around each point as a mayavi mesh.

    :param start: (lon, lat) of the first point.
    :param end: (lon, lat) of the last point.
    :param maxd: Third element of ``around`` passed to ``TiffReader.subset``;
                 also the camera distance when the mesh is first created.
    :param n: Number of points along the path.
    '''
    self.start = start
    self.end = end
    self.lopath = numpy.linspace(start[0], end[0], n)
    self.lapath = numpy.linspace(start[1], end[1], n)
    self.set_proj()
    self.tile = TiffReader(lon=self.lopath[0], lat=self.lapath[0])
    self.tile.readit()

    for lon, lat in zip(self.lopath, self.lapath):
        # Load a new tile only when the point falls outside the current one.
        if not self.tile == TiffReader(lon=lon, lat=lat):
            self.tile = TiffReader(lon=lon, lat=lat)
            self.tile.readit()

        around = (lon, lat, maxd)
        if hasattr(self, 'mesh'):
            # Reuse the existing mesh: replace heights and colour scalars only.
            # NOTE: x comes from the iteration that created the mesh.
            lo, la, self.mesh.mlab_source.z = self.tile.subset(rect=None, around=around)
            self.mesh.mlab_source.scalars = self.mesh.mlab_source.z
            mlab.view(180., 45., 5 * x.max(),
                      numpy.array([x.max(), 0, self.mesh.mlab_source.z.max()]))
        else:
            lo, la, z = self.tile.subset(rect=None, around=around)
            x, y = self.proj(lo, la)
            x = x - x.mean()
            y = y - y.mean()
            self.mesh = mlab.mesh(x, y, z, scalars=z, vmax=1500., vmin=0.)
            mlab.view(180., 45., maxd, numpy.array([x.max(), 0, z.max()]))
        mlab.draw()
def draw_conns(self, new_edges=None):
    """Refresh the connection glyphs: threshold, LUT, optional new edges, opacity and label.

    :param new_edges: Optional array whose columns 0 and 1 index start and end
                      positions in ``self.ds.lab_pos``; when given, the
                      vectors are rebuilt from these pairs.
    """
    try:
        self.thres.set(lower_threshold=self.ds.thresval)
        set_lut(self.vectors, self.ds.opts.activation_map)

        if new_edges is not None:
            new_starts = self.ds.lab_pos[new_edges[:, 0]]
            new_vecs = self.ds.lab_pos[new_edges[:, 1]] - new_starts
            self.vectors.mlab_source.reset(
                x=new_starts[:, 0], y=new_starts[:, 1], z=new_starts[:, 2],
                u=new_vecs[:, 0], v=new_vecs[:, 1], w=new_vecs[:, 2])

        if self.ds.curr_node is not None:
            self.vectors.actor.property.opacity = .75
            self.txt.set(text=' %s' % self.ds.labnam[self.ds.curr_node])
        else:
            self.vectors.actor.property.opacity = (
                .5 if self.ds.opts.tube_conns else .3)
            self.txt.set(text='')

        mlab.draw()

    # In case the user changes the threshold while there are no connections
    # present and so the VTK objects have not been created yet
    except AttributeError:
        pass
def draw(self, **kwargs):
    """Draw the scalar function ``self._f`` as a coloured mesh on each transformed base square.

    The colour limits are taken from the first evaluated surface (stored in
    ``self._clim``) and reused for all later ones. The LUT is replaced by
    ``self._rgba``. Extra keyword arguments go to ``mlab.mesh``.

    :return: List of all mesh objects drawn so far.
    """
    gx, gy = np.meshgrid(
        np.linspace(self._base_square_x[0], self._base_square_x[1], self._Nl),
        np.linspace(self._base_square_y[0], self._base_square_y[1], self._Nw))
    gz = 0.0 * gx
    for trans in self._transforms:
        pts = trans(np.concatenate((gx[:, :, None], gy[:, :, None], gz[:, :, None]), axis=-1))
        values = self._f(np.reshape(pts, (-1, 3)), *self._f_args, **self._f_kwargs)
        values = np.reshape(values, np.shape(pts)[:-1])
        if self._clim is None:
            self._clim = (np.min(values), np.max(values))
        lo, hi = self._clim
        surface = mlab.mesh(pts[:, :, 0], pts[:, :, 1], pts[:, :, 2],
                            scalars=values, vmin=lo, vmax=hi, **kwargs)
        surface.module_manager.scalar_lut_manager.lut.table = self._rgba
        mlab.colorbar()
        mlab.draw()
        entry = (surface, trans)
        self._plots = (entry,) if self._plots is None else self._plots + (entry,)
    return [plot for plot, _ in self._plots]
def process_launch():
    '''Procedure linking a graphical window to the core of the program.'''
    global nb_etapesIV
    # Read the number of steps from the GUI variable.
    nb_etapes = nb_etapesIV.get()
    mlab.figure(1)
    # Reset the drawing window.
    mlab.clf()
    triples = [(0, 1, 2), (2, 3, 4), (4, 5, 6)]
    segments = [
        (Point(0, 0, 0), Point(1, 0, 0)),
        (Point(1, 0, 0), Point(1, 1, 0)),
        (Point(0, 0, 0), Point(1, 1, 0)),
        (Point(1, 1, 0), Point(0, 1, 0)),
        (Point(0, 0, 0), Point(0, 1, 0)),
        (Point(0, 0, 0), Point(-1, 1, 0)),
        (Point(-1, 1, 0), Point(0, 1, 0)),
    ]
    # Display the drawing.
    # NOTE(review): terrain()'s return value is passed to mlab.draw as its
    # figure argument; confirm this is intended.
    drawing = terrain(triples, segments, nb_etapes)
    mlab.draw(drawing)
def plot_isosurface(crystal, filename='./output/potentialfield.txt',
                    savename='./output/3Dpotential.png'):
    """Plot the DeltaU = 3.5 isosurface with nearby lattice sites, save a PNG and show it.

    :param crystal: Object with ``coordinates`` (rows of x, y, z) and lattice
                    constant ``a``. Sites strictly inside the cube
                    (avg - a, avg + a)^3 are drawn, where avg is the mean x
                    coordinate.
    :param filename: Tab-separated file with columns X, Y, Z, DeltaU on a
                     cubic grid (row count = side**3).
    :param savename: Output path of the 2048x2048 screenshot.
    """
    data = np.genfromtxt(filename, delimiter='\t')
    # Side of the cubic grid. reshape needs an int (np.round returns a float),
    # and 1.0 / 3 avoids Python 2 integer division (1/3 == 0).
    size = int(round(len(data) ** (1.0 / 3)))
    shape = (size, size, size)
    X = np.reshape(data[:, 0], shape)
    Y = np.reshape(data[:, 1], shape)
    Z = np.reshape(data[:, 2], shape)
    DeltaU = np.reshape(data[:, 3], shape)

    coords = crystal.coordinates
    average = np.average(coords[:, 0])
    start = average - crystal.a
    end = average + crystal.a
    # Keep only sites strictly inside the cube in every coordinate.
    inside = np.all((coords > start) & (coords < end), axis=1)
    base = coords[inside]

    mlab.figure(bgcolor=(1, 1, 1), fgcolor=(1, 1, 1), size=(2048, 2048))
    mlab.contour3d(X, Y, Z, DeltaU, contours=[3.50], color=(1, 0.25, 0))
    mlab.points3d(base[:, 0], base[:, 1], base[:, 2], color=(0.255, 0.647, 0.88),
                  resolution=24, scale_factor=1.0, opacity=0.40)
    mlab.view(azimuth=17, elevation=90, distance=10,
              focalpoint=[average, average - 0.2, average])
    mlab.draw()
    mlab.savefig(savename, size=(2048, 2048))
    mlab.show()
def plot_potential(grid, potential, sparsify=1, along_axes=False, view=None,
                   interactive=False, path='.'):
    """Plot the eigenvalue surfaces of a potential over a 2-D grid.

    :param grid: Grid providing ``get_nodes(split=True, flat=False)`` and
                 ``get_number_nodes(overall=False)``.
    :param potential: Object whose ``evaluate_eigenvalues_at(grid)`` yields one
                      array per energy level.
    :param sparsify: Keep only every ``sparsify``-th node along each axis.
    :param along_axes: Not used by this function.
    :param view: Optional extent; when given, each surface is clipped to it
                 with a GeometryFilter.
    :param interactive: If true, show the figure. Otherwise render offscreen,
                        save ``potential_3D_view.png`` in ``path`` and close
                        the figure.
    :param path: Output directory for the saved image.
    """
    # The grid, thinned out by ``sparsify``.
    u, v = grid.get_nodes(split=True, flat=False)
    u = real(u[::sparsify, ::sparsify])
    v = real(v[::sparsify, ::sparsify])

    # One height field per eigenvalue level of the potential.
    node_shape = grid.get_number_nodes(overall=False)
    potew = [real(level).reshape(node_shape)[::sparsify, ::sparsify]
             for level in potential.evaluate_eigenvalues_at(grid)]

    # mlab.options.offscreen is global state: restore it afterwards so later
    # (interactive) plots are not silently rendered offscreen.
    previous_offscreen = mlab.options.offscreen
    if not interactive:
        mlab.options.offscreen = True
    try:
        fig = mlab.figure(size=(800, 700))
        for level in potew:
            # The energy surface of this level.
            src = mlab.pipeline.grid_source(u, v, level)
            if view is not None:
                geometry_filter = mlab.pipeline.user_defined(src, filter='GeometryFilter')
                geometry_filter.filter.extent_clipping = True
                geometry_filter.filter.extent = view
                src = mlab.pipeline.user_defined(geometry_filter, filter='CleanPolyData')
            normals = mlab.pipeline.poly_data_normals(src)
            mlab.pipeline.surface(normals)

        mlab.axes()
        fig.scene.parallel_projection = True
        fig.scene.isometric_view()
        mlab.draw()

        if interactive:
            mlab.show()
        else:
            mlab.savefig(os.path.join(path, "potential_3D_view.png"))
            mlab.close(fig)
    finally:
        mlab.options.offscreen = previous_offscreen
def _set_3d_view(figure, azimuth, elevation, focalpoint, distance, roll=None,
                 reset_camera=True, update=True):
    """Point the camera of a mayavi ``figure`` and optionally redraw it.

    Warnings raised by traits, and stdout output, are suppressed while the
    view is changed. ``reset_camera`` is accepted but not used by this
    implementation. When ``update`` is true, ``mlab.draw(figure)`` is called
    afterwards.
    """
    from mayavi import mlab
    with warnings.catch_warnings(record=True), SilenceStdout():  # traits
        mlab.view(azimuth, elevation, distance, focalpoint=focalpoint,
                  figure=figure, roll=roll)
        if update:
            mlab.draw(figure)
def surfacePlot(
    Hmap,
    nrows,
    ncols,
    xyspacing,
    zscale,
    name,
    hRange,
    file_path,
    lutfromfile,
    lut,
    lut_file_path,
    colorbar_on,
    save,
    show,
):
    """Render a height map as a top-down mayavi surface, optionally saving/showing it.

    :param Hmap: Heights, shape (ncols, nrows) to match the coordinate grid.
    :param nrows: Number of samples along the second grid axis.
    :param ncols: Number of samples along the first grid axis.
    :param xyspacing: Distance between neighbouring samples.
    :param zscale: ``warp_scale`` of the surface.
    :param name: Not used by this function.
    :param hRange: (vmin, vmax) colour range.
    :param file_path: Output path when ``save`` is true.
    :param lutfromfile: If true, load the LUT from ``lut_file_path``.
    :param lut: Colormap name passed to ``mlab.surf``.
    :param lut_file_path: LUT file used when ``lutfromfile`` is true.
    :param colorbar_on: Add a vertical colorbar.
    :param save: Save a 1000x1000 image (and close the figure if not showing).
    :param show: Call ``mlab.show()``.
    """
    # x/y coordinates of every height sample, shape (ncols, nrows). Built from
    # integer indices: np.mgrid with a float step can produce one extra sample
    # through rounding (e.g. 3 * 0.1 / 0.1 > 3), breaking the match with Hmap.
    x, y = np.mgrid[0:ncols, 0:nrows] * xyspacing

    mlab.figure(size=(1000, 1000))
    # Set the background color if desired
    # bgcolor=(0.16, 0.28, 0.46)

    plot = mlab.surf(x, y, Hmap, warp_scale=zscale, vmin=hRange[0], vmax=hRange[1], colormap=lut)

    # Import the LUT from a file if necessary, then redraw with it.
    if lutfromfile:
        plot.module_manager.scalar_lut_manager.load_lut_from_file(lut_file_path)
        mlab.draw()

    # Zoom in to fill the window and look straight down.
    f = mlab.gcf()
    f.scene.camera.zoom(1.05)
    mlab.view(270, 0)

    if colorbar_on:
        mlab.colorbar(orientation="vertical")

    if save:
        mlab.savefig(file_path, size=(1000, 1000))
        if not show:
            mlab.close()

    # Keep the figure open if requested.
    if show:
        mlab.show()
def simulate(sim_time, dt, sat, ground_track, sat_marker, picture_marker, picture_points):
    """Run the real-time simulation loop, updating the 2-D map and the 3-D view.

    Before the loop, the satellite velocity is reversed, ``sat`` is stepped by
    one period, the velocity is restored, and ``sat`` is stepped by two
    periods (presumably to pre-fill its position history; TODO confirm).
    Each loop iteration advances ``sat`` by ``dt``, redraws both views, pads
    the step to roughly one second of wall time, and prints the elapsed
    wall-clock time. Uses the module-level ``fig`` for the mayavi figure.
    """
    elapsed_time = 0
    start_time = time.time()
    step_start_time = start_time
    time_deviation = 0

    for i in [0, 1, 2]:
        sat.vel[i] = -sat.vel[i]
    sat = stepTime(sat, int(sat.period()), 10)
    for i in [0, 1, 2]:
        sat.vel[i] = -sat.vel[i]
    sat = stepTime(sat, int(sat.period()) * 2, 10)

    while elapsed_time < sim_time:
        # Step through time.
        sat = stepTime(sat, dt, dt)

        # 2-D map with the ground track.
        updateGroundMap(ground_track, sat_marker, sat, picture_marker, picture_points)
        plt.pause(0.05)
        plt.draw()

        # 3-D globe with the track.
        mlab.clf()
        idx = int(sat.period() / dt)
        mlab.quiver3d(sat.last_pos[idx, 1], sat.last_pos[idx, 0], sat.last_pos[idx, 2],
                      sat.spin_vector[0], sat.spin_vector[1], sat.spin_vector[2],
                      scale_factor=400000, line_width=2, figure=fig,
                      color=(1, 0, 0), mode='2darrow')
        mlab.points3d(picture_points[:, 0], picture_points[:, 1], picture_points[:, 2],
                      figure=fig, scale_factor=200000, color=(1, 0, 1))
        mlab.draw()

        # Make the step last about one second of real time. Read the clock
        # once: re-reading it between the check and the sleep could yield a
        # negative duration and make time.sleep raise ValueError.
        step_elapsed = abs(time.time() - step_start_time)
        if step_elapsed <= 1:
            remaining = 1 - step_elapsed
            if remaining - time_deviation > 0:
                time.sleep(remaining - time_deviation)
            else:
                time.sleep(remaining)
        step_start_time = time.time()
        elapsed_time = time.time() - start_time
        print(elapsed_time)
        time_deviation = abs(start_time - time.time()) % 1
def visualize_trajectory(gt, title_name='defualt', id=0, step=10):
    """Plot a trajectory plus red/green/blue body-frame axes with mayavi.

    :param gt: Array with at least 7 columns. Columns 0:3 are positions;
               columns 3:7 are passed to ``q2dcm`` (presumably a quaternion;
               confirm component order against q2dcm).
    :param title_name: Figure title.
    :param id: mayavi figure number; values <= 0 create a new figure.
    :param step: Orientation axes are computed only for every ``step``-th
                 row; other rows keep zero-length axis vectors.
    """
    n_rows = gt.shape[0]
    # Columns: 0:3 position, 3:6 x-axis, 6:9 y-axis, 9:12 z-axis.
    trajectory = np.zeros([n_rows, 12])
    trajectory[:, 0:3] = gt[:, 0:3] * 1.0

    basis = (np.asarray((1.0, 0.0, 0.0)),
             np.asarray((0.0, 1.0, 0.0)),
             np.asarray((0.0, 0.0, 1.0)))
    for row in range(0, n_rows, step):
        rot = np.linalg.inv(q2dcm(gt[row, 3:7]))
        for k, axis in enumerate(basis):
            trajectory[row, 3 + 3 * k:6 + 3 * k] = rot.dot(axis)

    from mayavi import mlab

    fig = mlab.figure(id) if id > 0 else mlab.figure()
    mlab.plot3d(trajectory[:, 0], trajectory[:, 1], trajectory[:, 2],
                tube_radius=0.01, color=(1.0, 1.0, 1.0))
    mlab.title(title_name)

    for k in range(3):
        channel = np.zeros(3)
        channel[k] = 1.0
        first = 3 + k * 3
        mlab.quiver3d(trajectory[:, 0], trajectory[:, 1], trajectory[:, 2],
                      trajectory[:, first], trajectory[:, first + 1], trajectory[:, first + 2],
                      color=tuple(channel))

    mlab.axes()
    mlab.draw(fig)
def draw_nucleus(v):
    """Clear the scene, redraw the mesh coloured by cos(phaseAngle), and set the view.

    Reads module-level x, y, z, triIndices and phaseAngle.

    :param v: Passed to ``mlab.view`` as its first positional argument (azimuth).
    """
    mlab.clf()
    source = mlab.pipeline.triangular_mesh_source(x, y, z, triIndices)
    source.data.cell_data.scalars = np.cos(phaseAngle)
    body = mlab.pipeline.surface(source)
    body.contour.filled_contours = True
    body.contour.minimum_contour = 0.0
    body.contour.maximum_contour = 1.0
    body.module_manager.scalar_lut_manager.data_range = (0, 1)
    mlab.view(v)
    mlab.draw()
def force_render(figure=None):
    """Render a mayavi figure immediately and flush pending GUI events.

    :param figure: Figure to render; defaults to the current figure
                   (``mlab.gcf()``). Previously the None default always raised
                   AttributeError on ``figure.scene``.
    """
    from mayavi import mlab

    if figure is None:
        figure = mlab.gcf()
    figure.scene.render()
    mlab.draw(figure=figure)

    from pyface.api import GUI

    gui = GUI()
    orig_busy = gui.busy
    # Process events while flagged busy, restore the flag, then process again.
    gui.set_busy(busy=True)
    gui.process_events()
    gui.set_busy(busy=orig_busy)
    gui.process_events()
def draw_nucleus(v):
    """Clear the scene, redraw the mesh coloured by cos(phaseAngle), and set the view.

    Reads module-level x, y, z, triIndices and phaseAngle.

    :param v: Passed to ``mlab.view`` as its first positional argument (azimuth).
    """
    mlab.clf()
    source = mlab.pipeline.triangular_mesh_source(x, y, z, triIndices)
    source.data.cell_data.scalars = np.cos(phaseAngle)
    body = mlab.pipeline.surface(source)
    body.contour.filled_contours = True
    body.contour.minimum_contour = 0.0
    body.contour.maximum_contour = 1.0
    body.module_manager.scalar_lut_manager.data_range = (0, 1)
    mlab.view(v)
    mlab.draw()
def parse_box(box_data, indx):
    """
    from box_data produce a 3D image of surfaces and display it using an
    animation of a slice moving back and forth through the box

    NOT FUNCTIONAL AT PRESENT - NEED TO WORK OUT THE MAYAVI SYNTAX
    - can't get different slices to display in a nice fashion

    :param box_data: Scalar volume; frames alternate between it and
                     ``1 - box_data``.
    :param indx: Not used by this function.
    """
    s = box_data
    sp = 1 - box_data
    mlab.pipeline.image_plane_widget(mlab.pipeline.scalar_field(s),
                                     plane_orientation='z_axes',
                                     slice_index=0, colormap='black-white')
    mlab.pipeline.image_plane_widget(mlab.pipeline.scalar_field(s),
                                     plane_orientation='x_axes',
                                     slice_index=0, colormap='black-white')
    y_plane = mlab.pipeline.image_plane_widget(mlab.pipeline.scalar_field(s),
                                               plane_orientation='y_axes',
                                               slice_index=0, colormap='black-white')
    fig = mlab.gcf()
    ms = y_plane.mlab_source
    # Single pre-formatted argument: same output under Python 2 and 3 (the
    # former print statements were a SyntaxError on Python 3).
    print('indx= %s' % (dir(ms),))

    for i in range(90):
        # Alternate between the data and its complement each frame.
        ms.reset(scalars=sp if i % 2 else s)
        fig.scene.reset_zoom()
        mlab.draw()
        filename = 'volume_slice_frame' + str(i) + '.png'
        print(filename)
    mlab.show()
    return
def draw_nodes(self):
    """Set size and colour of the left/right node glyphs for the current display mode, then redraw.

    Size: scaled by ``node_scalars[node_size]`` in 'scalar' mode when a size
    source is set, otherwise a fixed scale factor of 3.
    Colour (ds.node_colors is ignored for mayavi):
      * 'normal', or 'scalar' without a colour source: constant 0.3 on the default map;
      * 'scalar': ``node_scalars[node_color]`` on the scalar map;
      * 'module_single': 0.8 for the selected module, 0.3 elsewhere;
      * 'module_multi': one LUT colour per module.
    """
    settings = self.ds.scalar_display_settings
    nc = settings.node_color
    ns = settings.node_size
    hemispheres = ((self.nodes_lh, self.ds.lhnodes),
                   (self.nodes_rh, self.ds.rhnodes))

    for nodes, ixes in hemispheres:
        n_nodes = len(ixes)
        mode = self.ds.display_mode

        # Node size.
        if mode == 'scalar' and ns:
            nodes.glyph.scale_mode = 'scale_by_vector'
            nodes.glyph.glyph.scale_factor = 8
            nodes.mlab_source.dataset.point_data.vectors = (
                np.tile(self.ds.node_scalars[ns][ixes], (3, 1)).T)
        else:
            nodes.glyph.scale_mode = 'data_scaling_off'
            nodes.glyph.glyph.scale_factor = 3

        # Node colour.
        if mode == 'normal' or (mode == 'scalar' and not nc):
            set_lut(nodes, self.ds.opts.default_map)
            nodes.mlab_source.dataset.point_data.scalars = np.tile(.3, n_nodes)
        elif mode == 'scalar':  # nc is truthy here
            set_lut(nodes, self.ds.opts.scalar_map)
            nodes.mlab_source.dataset.point_data.scalars = (
                self.ds.node_scalars[nc][ixes])
        elif mode == 'module_single':
            set_lut(nodes, self.ds.opts.default_map)
            levels = np.tile(.3, self.ds.nr_labels)
            levels[self.ds.get_module()] = .8
            nodes.mlab_source.dataset.point_data.scalars = levels[ixes]
        elif mode == 'module_multi':
            palette = np.array(self.ds.module_colors[:self.ds.nr_modules])
            manager = nodes.module_manager.scalar_lut_manager
            # Switch to the dummy colormap hidden from the user, so change
            # notifications still work, then overwrite its LUT by hand.
            manager.lut_mode = 'black-white'
            manager.number_of_colors = self.ds.nr_modules
            manager.lut.table = palette
            # Scalars are module fractions between 0 and 1.
            import bct
            fractions = bct.ls2ci(self.ds.modules, zeroindexed=True) / self.ds.nr_modules
            nodes.mlab_source.dataset.point_data.scalars = fractions[ixes]

    mlab.draw()
def anim():
    """Animation generator: each frame recolours vessels, updates history plots, and redraws.

    Each vessel's top and bottom LUTs are filled with one colour from
    ``colorGradient``, chosen by the host's bacteria count. When a node is
    selected (``showingNode``), its cell-count and blood-flow histories are
    replotted on the two axes in ``plots``. Yields once per frame and never
    terminates on its own.
    """
    global vessels, colorGradient, timeText, showingNode
    timeText = None
    while True:
        if timeText is not None:
            timeText.remove()
        timeText = mlab.text(0.01, 0.01, 'Time: ' + str(g.globals.time), width=0.3)

        last_index = len(colorGradient) - 1
        for top, bottom, host in vessels:
            lut_top = top.module_manager.scalar_lut_manager.lut.table.to_array()
            lut_bottom = bottom.module_manager.scalar_lut_manager.lut.table.to_array()
            count = host.getBacteriaCount()
            index = int(float(count) / parameters.cell_count_color_mapping * last_index)
            color = colorGradient[min(index, last_index)]
            assert len(color) == 6
            # Hex "RRGGBB" -> integer channels.
            rgb = (int(color[0:2], 16), int(color[2:4], 16), int(color[4:6], 16))
            ones = np.ones(np.shape(lut_top[:, 0]))
            for channel, level in enumerate(rgb):
                lut_top[:, channel] = ones * level
                lut_bottom[:, channel] = ones * level
            top.module_manager.scalar_lut_manager.lut.table = lut_top
            bottom.module_manager.scalar_lut_manager.lut.table = lut_bottom

        if showingNode is not None:
            (ax1, ax2) = plots
            panels = (
                (ax1, showingNode.getCellCountHistory, 'r', 'Bacteria Count history'),
                (ax2, showingNode.getFlowHistory, 'b', 'Blood flow history'),
            )
            for axis, get_history, fmt, title in panels:
                axis.cla()
                history = copy.deepcopy(get_history())
                times = np.linspace(0, len(history) * parameters.cell_count_history_interval,
                                    len(history))
                axis.plot(times, history, fmt)
                axis.set_title(title)

        mlab.draw()
        if parameters.verbose:
            print("updating graph")
        yield
def plot_shape(self, kind, spatialSize=None, renderSize=None, features=None):
    """Draw this shape in a new white mayavi figure and return the figure.

    :param kind: One of
        'solid'       - shaded triangular mesh of self.vertices / self.faces;
        'transparent' - the same mesh at opacity 0.1;
        'wireframe'   - the mesh as a wireframe;
        'Bool'        - one cube per voxel from self.voxelize(spatialSize, renderSize).
        Any other value draws nothing.
    :param spatialSize: Voxel grid size; required for 'Bool' (used to centre cubes).
    :param renderSize: Forwarded to self.voxelize for 'Bool'.
    :param features: Optional per-voxel values for 'Bool'. Cubes are shaded
                     from white (0) to red (max). Without it, cubes are green.
    :return: The mayavi figure.
    """
    f = mlab.figure(bgcolor=(1, 1, 1))
    mesh_styles = {
        'solid': {},
        'transparent': {'opacity': 0.1},
        'wireframe': {'representation': 'wireframe'},
    }
    if kind in mesh_styles:
        vertA, faceA = self.vertices, self.faces
        mlab.triangular_mesh(vertA[:, 0], vertA[:, 1], vertA[:, 2], faceA,
                             **mesh_styles[kind])
    elif kind == "Bool":
        x, y, z = map(np.array, self.voxelize(spatialSize, renderSize))
        assert len(x) == len(y) == len(z)
        N = len(x)
        scalars = np.arange(N)  # key point: one integer per point -> own LUT row
        colors = np.zeros((N, 4), dtype=np.uint8)
        colors[:, -1] = 255  # no transparency
        if features is not None:
            features = features.ravel()
            shade = (255 * (1 - features / np.max(features))).astype(np.uint8)
            colors[:, 0] = 255
            colors[:, 1] = shade
            colors[:, 2] = shade
        else:
            colors[:, 1] = 255  # pure green (R and B stay 0)

        pts = mlab.quiver3d(
            x - spatialSize / 2, y - spatialSize / 2 + 0.5, z - spatialSize / 2 + 0.5,
            np.ones(N), np.zeros(N), np.zeros(N),
            scalars=scalars, mode='cube', scale_factor=0.7, line_width=10)
        pts.glyph.color_mode = 'color_by_scalar'
        try:
            pts.module_manager.scalar_lut_manager.lut.table = colors
        except Exception:
            # Best effort: keep the default colouring if the table is rejected.
            # (Was a bare except that also swallowed KeyboardInterrupt.)
            pass
    mlab.draw()
    return f
def updateAnimation(xx, yy, zz, xT, yT, zT):
    """Animation generator: move particles and their 5-sample trailing tails one frame per step.

    :param xx, yy, zz: Particle coordinates indexed [particle, frame].
    :param xT, yT, zT: Tail buffers indexed [particle][0..4], updated in
                       place. Index 0 receives the current position; older
                       samples shift towards index 4.
    Uses the module-level ``timesize``, ``ball``, ``numparticles``,
    ``dragontail``, ``interpstep`` and ``xds``/``yds``/``zds``.
    Yields once per frame.
    """
    import copy

    t = 0
    while t < timesize:
        mlab.draw(figure=None)
        ball.mlab_source.set(x=xx[:, t], y=yy[:, t], z=zz[:, t])

        # Shift every tail back by one sample and put the current position at its head.
        for j in range(0, numparticles):
            for k in range(4, -1, -1):
                if k > 0:
                    xT[j][k] = xT[j][k - 1]
                    yT[j][k] = yT[j][k - 1]
                    zT[j][k] = zT[j][k - 1]
                else:
                    xT[j][k] = xx[j, t]
                    yT[j][k] = yy[j, t]
                    zT[j][k] = zz[j, t]
            dragontail[j].mlab_source.reset(x=xT[j], y=yT[j], z=zT[j])

        t += 1
        if t == interpstep - 1:
            # Restart from the template tails. Copy them: plain assignment
            # aliased xds/yds/zds, so the in-place shifts above overwrote the
            # templates and later restarts no longer reset the tails.
            xT = copy.deepcopy(xds)
            yT = copy.deepcopy(yds)
            zT = copy.deepcopy(zds)
            t = 0
        yield
def m_vector_wire(ori, loc, grid, x_grid, y_grid, z_grid, x_field, y_field, z_field):
    """Draw wires plus a 3-D arrow plot of a vector field, then show the figure.

    ori, loc: matching sequences; each (orientation, location) pair is passed
        to ``wire(orientation, location, grid)``.
    x_grid, y_grid, z_grid: 1-D node coordinates; combined with
        ``np.meshgrid(..., indexing='ij')``.
    x_field, y_field, z_field: field components, assumed to have the shape of
        that 'ij' meshgrid (TODO confirm with callers).
    The camera looks along the x axis; ``mplt.show()`` is called at the end.
    """
    fig = mplt.figure()

    for wire_ori, wire_loc in zip(ori, loc):
        wire(wire_ori, wire_loc, grid)

    gx, gy, gz = np.meshgrid(x_grid, y_grid, z_grid, indexing='ij')
    mplt.quiver3d(gx, gy, gz, x_field, y_field, z_field)
    mplt.axes()

    fig.scene.x_plus_view()  # x axis pointing out of the screen
    mplt.draw(figure=fig)
    mplt.show()
def plot_torus(theta, gamma, Z, weight_deform=0., torus_radius=5., tube_radius=3.0, try_mayavi=True, draw_colorbar=True):
    '''
    Plot a torus whose surface colour is given by Z.

    The tube can also be deformed according to Z by passing a nonzero
    ``weight_deform``. The local tube radius is then
    ``tube_radius * (1 + weight_deform * Z / max(Z))``.

    theta is the angle around the tube (expected in [0, 2pi]). gamma is the
    angle around the torus' central axis (expected in [0, pi], which draws
    half of the ring). Z is combined elementwise with
    ``np.meshgrid(theta, gamma)``, i.e. shape (len(gamma), len(theta)).

    Uses mayavi when ``try_mayavi`` is true and mayavi can be imported.
    Otherwise it falls back to a Matplotlib 3-D axes.
    '''
    Z = np.asarray(Z, dtype=float)
    z_max = Z.max()
    # Normalise to a maximum of 1; an all-zero Z would give 0/0 = NaN.
    Z_norm = Z / z_max if z_max != 0 else Z.copy()

    X, Y = np.meshgrid(theta, gamma)
    tube = tube_radius * (1. + weight_deform * Z_norm)
    x = (torus_radius + tube * np.cos(X)) * np.cos(Y)
    y = (torus_radius + tube * np.cos(X)) * np.sin(Y)
    z = tube * np.sin(X)

    use_mayavi = False
    if try_mayavi:
        try:
            import mayavi.mlab as mplt
            use_mayavi = True
        except Exception:
            # Any import-time failure: fall back to matplotlib. (A bare
            # except would also swallow KeyboardInterrupt/SystemExit.)
            pass

    if use_mayavi:
        mplt.figure(bgcolor=(1.0, 1.0, 1.0))
        mplt.mesh(x, y, z, scalars=Z_norm, vmin=0.0)
        if draw_colorbar:
            mplt.colorbar(title='', orientation='vertical', label_fmt='%.2f', nb_labels=5)
        mplt.outline(color=(0., 0., 0.))
        mplt.draw()
    else:
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        ax.plot_surface(x, y, z, facecolors=cm.jet(Z_norm), rstride=1, cstride=1,
                        linewidth=0, antialiased=True, shade=False)
        if draw_colorbar:
            m = cm.ScalarMappable(cmap=cm.jet)
            m.set_array(Z_norm)
            # m belongs to no Axes, so tell colorbar where to draw.
            plt.colorbar(m, ax=ax)
def process_launch():
    '''Procedure linking the graphical window to the core of the program.

    Reads the requested number of steps from ``nb_etapesIV`` (a module-level
    GUI variable, presumably a Tkinter IntVar). It then clears mayavi figure 1
    and draws the terrain for that number of steps.
    '''
    step_count = nb_etapesIV.get()  # number of steps requested by the user

    mlab.figure(1)
    mlab.clf()  # start from an empty drawing window

    # NOTE(review): terrain(...)'s return value is passed to mlab.draw as its
    # `figure` argument — confirm terrain returns a figure (or None).
    mlab.draw(terrain(
        [(0, 1, 2), (2, 3, 4), (4, 5, 6)],
        [
            (Point(0, 0, 0), Point(1, 0, 0)),
            (Point(1, 0, 0), Point(1, 1, 0)),
            (Point(0, 0, 0), Point(1, 1, 0)),
            (Point(1, 1, 0), Point(0, 1, 0)),
            (Point(0, 0, 0), Point(0, 1, 0)),
            (Point(0, 0, 0), Point(-1, 1, 0)),
            (Point(-1, 1, 0), Point(0, 1, 0)),
        ],
        step_count))
def plot_potential(grid, potential, sparsify=1, along_axes=False, view=None, interactive=False, path='.'):
    """Plot the eigenvalue surfaces of a potential over a 2-D grid.

    :param grid: Grid whose ``get_nodes(split=True, flat=False)`` gives the
        u, v coordinates of the surfaces.
    :param potential: Object with ``evaluate_eigenvalues_at(grid)``. One
        surface is drawn per returned level.
    :param sparsify: Keep only every ``sparsify``-th node along each axis.
    :param along_axes: Not used by this function.
    :param view: Optional clipping extent, assigned to a VTK GeometryFilter's
        ``extent`` (VTK expects xmin, xmax, ymin, ymax, zmin, zmax).
    :param interactive: If true, show the figure with ``mlab.show()``.
        Otherwise render offscreen, save ``potential_3D_view.png`` in
        ``path`` and close the figure.
    :param path: Output directory for the non-interactive snapshot.

    The global ``mlab.options.offscreen`` flag is restored on return.
    """
    # The grid
    u, v = grid.get_nodes(split=True, flat=False)
    u = real(u[::sparsify, ::sparsify])
    v = real(v[::sparsify, ::sparsify])

    # Evaluate the eigenvalues and bring each level onto the (sparsified) grid
    shape = grid.get_number_nodes(overall=False)
    potew = [real(level).reshape(shape)[::sparsify, ::sparsify]
             for level in potential.evaluate_eigenvalues_at(grid)]

    previous_offscreen = mlab.options.offscreen
    if not interactive:
        mlab.options.offscreen = True
    try:
        fig = mlab.figure(size=(800, 700))

        for level in potew:
            # The energy surface of this level
            src = mlab.pipeline.grid_source(u, v, level)

            # Clip to the given view
            if view is not None:
                geometry_filter = mlab.pipeline.user_defined(src, filter='GeometryFilter')
                geometry_filter.filter.extent_clipping = True
                geometry_filter.filter.extent = view
                src = mlab.pipeline.user_defined(geometry_filter, filter='CleanPolyData')

            # Plot the surface with smoothed normals
            normals = mlab.pipeline.poly_data_normals(src)
            mlab.pipeline.surface(normals)

        mlab.axes()

        fig.scene.parallel_projection = True
        fig.scene.isometric_view()

        mlab.draw()
        if interactive:
            mlab.show()
        else:
            mlab.savefig(os.path.join(path, "potential_3D_view.png"))
            mlab.close(fig)
    finally:
        # Don't leak the offscreen setting into later, unrelated figures.
        mlab.options.offscreen = previous_offscreen
def fmlab(volume, slices, spacing=(1., 1., 1.), no_slices=5, filename=None, size=(1024, 768)):
    '''Create a volume rendering with slices of data from up to 2 scalar fields.

    ``volume`` is rendered as a translucent "cloud". ``slices`` supplies
    ``no_slices`` image-plane widgets normal to the y axis. They can be the
    same field but must be 3- or 4-dimensional. For 4-D input only index 0 of
    the first axis is used, so it is best to pass 3-D data.

    spacing: the 'stretching factor' of the data in (x, y, z).
    filename: if given, render offscreen and save ``<filename>.png`` at
        ``size``. Otherwise show the figure interactively.
    '''
    # Only plot one time slice, so remove the 4th dimension.
    # (np.ndim, not np.shape: a shape tuple never equals 4.)
    if np.ndim(volume) == 4:
        volume = volume[0, :, :, :]
    if np.ndim(slices) == 4:
        slices = slices[0, :, :, :]

    if filename is not None:
        # Set before the figure exists so the new scene can be created offscreen.
        mlab.options.offscreen = True

    # The mlab scene which will contain all of the visualisations
    plot = mlab.figure()

    # Load the data as scalar fields (the VTK data sources)
    s_field1 = mlab.pipeline.scalar_field(volume)
    s_field2 = mlab.pipeline.scalar_field(slices)

    # Stretch the grid so the data has the right aspect
    s_field1.spacing = spacing
    s_field2.spacing = spacing

    # The 'cloud' of the whole data
    mlab.pipeline.volume(s_field1, figure=plot)

    # Evenly spaced y-slices: slice n sits 3/4 of the way through the n-th of
    # no_slices equal bands along axis 1. (np.int was removed in NumPy 1.24.)
    band = np.shape(slices)[1] / float(no_slices)
    for n in range(no_slices):
        mlab.pipeline.image_plane_widget(s_field2, figure=plot,
                                         plane_orientation='y_axes',
                                         slice_index=int((n + 3. / 4.) * band))

    if filename is not None:
        mlab.draw(figure=plot)
        mlab.savefig(str(filename) + '.png', size=size, figure=plot, magnification=1.)
    else:
        mlab.draw(figure=plot)
        mlab.show()
def mlab_imshow_color(img_array, **kwargs):
    """Show an RGBA image with mayavi.mlab.imshow.

    Every pixel gets its own integer scalar. The image's lookup table is then
    replaced by the pixels' own RGBA rows, the intent being that pixel k is
    drawn with ``img_array.reshape(-1, 4)[k]``.

    img_array: ndarray of shape (n, m, 4) with values in [0, 255].
    **kwargs: passed on to ``mayavi.mlab.imshow``.
    Returns the object created by ``mlab.imshow``.
    """
    rows, cols = img_array.shape[0], img_array.shape[1]
    rgba_table = pl.c_[img_array.reshape(-1, 4)]
    pixel_ids = pl.arange(rows * cols).reshape(rows, cols)

    # 'binary' is only a placeholder; its LUT is overwritten right after.
    image = mlab.imshow(pixel_ids, colormap='binary', **kwargs)
    image.module_manager.scalar_lut_manager.lut.table = rgba_table
    mlab.draw()
    return image
def draw_scene():
    """Draw the shaded body, two trajectories and one ray per pixel.

    Everything comes from module-level state: x, y, z, triIndices,
    phaseAngle, the *Sun_plt and *SC_plt coordinates (presumably Sun and
    spacecraft), R, pVectors, rCG, rSC, fStretch, nPixelsX, nPixelsY and
    plt_coords.
    """
    source = mlab.pipeline.triangular_mesh_source(x, y, z, triIndices)
    # Colour each triangle by the cosine of its phase angle, banded over [0, 1].
    source.data.cell_data.scalars = np.cos(phaseAngle)
    surface = mlab.pipeline.surface(source)
    contour = surface.contour
    contour.filled_contours = True
    contour.minimum_contour = 0.0
    contour.maximum_contour = 1.0
    surface.module_manager.scalar_lut_manager.data_range = (0, 1)

    track_radius = fStretch / 1000
    mlab.plot3d(xSun_plt, ySun_plt, zSun_plt, tube_radius=track_radius, color=(1, 1, 0))
    mlab.plot3d(xSC_plt, ySC_plt, zSC_plt, tube_radius=track_radius, color=(0, 0, 1))

    ray_radius = fStretch / 5000
    hits = []
    for i in range(nPixelsX):
        for j in range(nPixelsY):
            ray = np.dot(R, pVectors[:, i, j])
            hit = np.dot(rCG, ray) * ray + rSC
            ray_x, ray_y, ray_z = plt_coords(rSC, 1.1 * fStretch * ray)
            mlab.plot3d(ray_x, ray_y, ray_z, tube_radius=ray_radius, color=(0, 0, 0))
            hits.append(hit)

    hit_x = [h[0] for h in hits]
    hit_y = [h[1] for h in hits]
    hit_z = [h[2] for h in hits]
    mlab.points3d(hit_x, hit_y, hit_z, np.ones(len(hit_x)),
                  scale_factor=150, color=(1, 0.7, 0.1))
    mlab.draw()
def surface3D(x, y, z, idx, ux, uy, uz, dat, cm='seismic', quiv=True, fac=0.01, col=True):
    """Draw one slice of a 3-D surface coloured by ``dat``, optionally with arrows.

    x, y, z, dat: arrays whose last axis is indexed by ``idx`` to pick the
        surface that is meshed.
    ux, uy, uz: vector components drawn with ``mlab.quiver3d`` over the full
        x, y, z arrays (every 4th point) when ``quiv`` is true.
    cm: name of a ``matplotlib.pyplot.cm`` colormap used for the surface LUT.
    fac: arrow scale factor.
    col: if true the arrows are grey (0.43, 0.43, 0.43); otherwise they are
        coloured by mayavi.
    """
    from mayavi import mlab
    # Look the colormap up by attribute instead of eval()-ing a string built
    # from the argument, which would execute arbitrary code.
    lut = getattr(plt.cm, cm)(np.linspace(0, 1, 255)) * 255
    arrow_color = (0.43, 0.43, 0.43) if col else None

    mlab.figure(size=(800, 800))
    mesh_handle = mlab.mesh(x[..., idx], y[..., idx], z[..., idx], scalars=dat[..., idx])
    mesh_handle.module_manager.scalar_lut_manager.lut.table = lut
    if quiv:
        mlab.quiver3d(x, y, z, ux, uy, uz, color=arrow_color, scale_mode='vector',
                      mode='arrow', mask_points=4, scale_factor=fac)
    mlab.draw()
def spatial_map(vector, r_vertices, r_faces, msk_small_region, labs, val):
    """Paint per-vertex labels on a surface mesh and save a snapshot.

    Vertices outside ``msk_small_region`` get ``vector``. Vertices inside it
    get ``labs``. The mesh is rendered with parallel projection at azimuth 0 /
    elevation 270 and a vertical colorbar. It is saved as
    ``'precentral_back' + str(val) + '.png'`` and the figure is closed.

    Returns the float label array of length ``r_vertices.shape[0]``.
    """
    inside = msk_small_region
    r_labels = np.zeros([r_vertices.shape[0]])
    r_labels[~inside] = vector
    r_labels[inside] = labs

    mlab.figure(size=(1024, 768), bgcolor=(1, 1, 1), fgcolor=(0.5, 0.5, 0.5))
    vx, vy, vz = r_vertices[:, 0], r_vertices[:, 1], r_vertices[:, 2]
    mlab.triangular_mesh(vx, vy, vz, r_faces, representation='surface',
                         opacity=1, scalars=np.float64(r_labels))
    scene = mlab.gcf().scene
    scene.parallel_projection = True
    mlab.view(azimuth=0, elevation=270)
    mlab.draw()
    mlab.colorbar(orientation='vertical')

    # NOTE(review): offscreen is switched on only after this figure exists,
    # so it likely affects later figures rather than this one — confirm
    # whether it belongs before mlab.figure().
    mlab.options.offscreen = True
    mlab.savefig('precentral_back' + str(val) + '.png')
    mlab.close()
    return r_labels
def _run_calculation_changed(self, value):
    """Advance the 3-D plots by one frame (presumably a Traits change handler).

    Uses the module-level ``frame`` counter. The contour scalars are set from
    ``gm_s_list[frame]`` (log-scaled when ``log_plot`` is set). The truth,
    observation and prediction points of the current frame and up to seven
    earlier ones are faded from opacity 1.0 down to 0.3. The scene is then
    redrawn and ``frame`` is incremented. ``value`` is not used.
    """
    global frame
    print("Update 3D plots calculation in Frame %d" % frame, end=' ')
    truth_points, obs_points, pred_points, contour = self.data

    contour_s = gm_s_list[frame]
    if log_plot:
        # Add the smallest positive float so log(0) is finite.
        # (np.float was removed in NumPy 1.24; the builtin float is the same.)
        contour_s = np.log(contour_s + np.finfo(float).tiny)
    contour.mlab_source.scalars = contour_s

    # Frames frame, frame-1, ..., frame-7, but never below 0. range()'s stop
    # is exclusive, so the bound is -1 (not 0) to keep frame 0 in the window.
    for i in range(frame, max(-1, frame - 8), -1):
        opacity = 1. - 0.1 * (frame - i)
        for points in (truth_points, obs_points, pred_points):
            points[i].actor.property.opacity = opacity
    print('done.')
    mlab.draw()
    frame += 1
def torus(size=1., levels=1, nphi=31, n=31, azimuth=90.):
    '''
    Build a ring of ``n`` cubes (via ``cube``) and point the camera at it.

    Consecutive cubes are 2*pi/nphi apart on a circle that passes through
    the origin (the first cube sits there). With n == nphi the ring closes.
    For levels > 1 the cube size is recomputed from ``levels`` and the
    ``size`` argument is ignored.
    '''
    step = 2. * np.pi / nphi
    if levels > 1:
        size = 2.2 * 0.9 ** (1. / 3.) * 3 ** levels
    radius = size / step

    for k in range(n):
        angle = k * step + np.pi / 2.
        cube(radius * np.cos(angle), radius * (np.sin(angle) - 1.0), 0.0, size, angle)

    ml.draw()
    ml.view(focalpoint=[0., 0., 0.], azimuth=azimuth, elevation=75., distance=4 * radius)
def test_colormap(cmap):
    """Preview a colormap on a tilted plane.

    Draws z = (x + y) / 2 over a 101x101 grid on [0, 100]^2 with
    ``mlab.mesh``. The R, G and B columns of its lookup table are overwritten
    with ``cmap[0]``, ``cmap[1]`` and ``cmap[2]`` (assumed to be scalars or
    sequences matching the LUT length). Finally it sets a fixed camera and
    adds a colorbar.
    """
    from mayavi import mlab
    x, y = meshgrid(linspace(0, 100, 101), linspace(0, 100, 101))
    z = (x + y) * 0.5

    mesh_water = mlab.mesh(x.T, y.T, z.T, colormap='gist_earth', vmin=0., vmax=100.)
    lut_manager = mesh_water.module_manager.scalar_lut_manager
    lut = lut_manager.lut.table.to_array()
    for channel in range(3):  # R, G, B; alpha is left as is
        lut[:, channel] = cmap[channel]
    lut_manager.lut.table = lut

    mlab.draw()
    mlab.view(50.495536875966657, 26.902697031665959, 334.60652149512265,
              array([50., 50., 50.]))
    mlab.colorbar()
def mlab_imshowColor(im, alpha=255, **kwargs):
    """
    Plot a colour image with mayavi.mlab.imshow and show it.

    im: ndarray of shape (n, m, 3) with values in [0, 255].
    alpha: a single number, or an array of n*m values in [0, 255] (one per
        pixel, in the row-major order of ``im.reshape(-1, 3)``).
    **kwargs: passed on to ``mayavi.mlab.imshow(..., **kwargs)``.

    Calls ``mlab.show()`` and returns None.
    """
    from tvtk.api import tvtk

    rows, cols = im.shape[0], im.shape[1]
    # Per-pixel alpha is reshaped to (n, m, 1) so it lines up with the pixels;
    # multiplying a flat (n*m,) array by an (n, m, 1) block would broadcast
    # to the wrong shape.
    alpha = np.asarray(alpha)
    if alpha.ndim > 0:
        alpha = alpha.reshape(rows, cols, 1)
    alpha_channel = np.broadcast_to(alpha, (rows, cols, 1)).astype(np.uint8)

    # RGB + A -> RGBA
    im = np.concatenate((im, alpha_channel), axis=-1)
    colors = tvtk.UnsignedCharArray()
    colors.from_array(im.reshape(-1, 4))

    m_image = mlab.imshow(np.ones(im.shape[:2][::-1]), **kwargs)
    m_image.actor.input.point_data.scalars = colors
    mlab.draw()
    mlab.show()
def surf_default(self, arrayindex):
    """Draw ``self.arrays[arrayindex, :, :, 0]`` as an mlab surface with a custom ramp.

    The warp scale is chosen so that a value equal to
    ``max(self.arrays[arrayindex])`` is lifted 40 units. The lookup table is
    fully opaque and ramps blue -> red for ``arrayindex == 1`` and
    blue -> green otherwise. Returns the surface object.
    """
    peak = np.max(self.arrays[arrayindex])
    # 40.0 / peak rather than (1 / peak) * 40: with integer data, 1 / peak
    # truncates to 0 under Python 2 and flattens the surface.
    surf = mlab.surf(self.arrays[arrayindex, :, :, 0], warp_scale=40.0 / peak)

    # The LUT is an (N, 4) RGBA array of 0-255 integers; N is 256 for
    # mayavi's default LUT.
    lut = surf.module_manager.scalar_lut_manager.lut.table.to_array()
    n_colors = lut.shape[0]
    ramp_up = np.linspace(0, 255, n_colors)
    if arrayindex == 1:
        lut[:, 0] = ramp_up  # red rises
        lut[:, 1] = 0        # no green
    else:
        lut[:, 0] = 0        # no red
        lut[:, 1] = ramp_up  # green rises
    lut[:, 2] = np.linspace(255, 0, n_colors)  # blue fades out
    lut[:, 3] = 255                            # opaque
    surf.module_manager.scalar_lut_manager.lut.table = lut

    # Force an update of the figure now that the LUT has changed.
    mlab.draw()
    self.scene.scene_editor.isometric_view()  # WARNING removing this line causes the surface not to display and the axes to rotate 90 degrees vertically!
    self.scene.camera.azimuth(90)
    return surf
def vis_colored_point_cloud(points, colors, **kwargs):
    """Draw each point as a sphere glyph with its own RGBA colour.

    :param points: (n, 3) array of xyz coordinates.
    :param colors: (n, 4) array of RGBA values. They are written directly
        into the glyph lookup table, whose entries are 0-255 integers.
    :param kwargs: Passed on to ``mlab.quiver3d``.
    :returns: The glyph object created by ``mlab.quiver3d``.
    :raises ValueError: If ``points`` is not (n, 3), ``colors`` is not
        (..., 4), or their lengths differ.
    """
    points = np.asarray(points)
    colors = np.asarray(colors)
    if points.ndim != 2 or points.shape[-1] != 3:
        raise ValueError('points must be an (n, 3) array of xyz coordinates')
    if colors.shape[-1] != 4:
        raise ValueError('colors must be an (n, 4) array of rgba values')
    n = len(points)
    if len(colors) != n:
        raise ValueError('colors must be the same length as points')

    x, y, z = points.T
    ones = np.ones((n, ))
    # One integer scalar per point, so glyph k uses LUT row k.
    pts = mlab.quiver3d(x, y, z, ones, ones, ones, scalars=np.arange(n),
                        mode='sphere', **kwargs)
    pts.glyph.color_mode = 'color_by_scalar'
    pts.module_manager.scalar_lut_manager.lut.table = colors
    mlab.draw()
    return pts
def justPlotBoolArray(blob, figSize=(300, 300)):
    '''Plot a 3-D boolean array as white cubes at its True cells.

    Clears mayavi figure 0 and switches it to terrain-style interaction. It
    prints the number of True cells, draws a cube centred in each one and
    sets a fixed camera.

    NOTE(review): this function is defined twice in this file with the same
    body; the later definition is the one that takes effect.
    '''
    Nx, Ny, Nz = blob.shape
    indexX, indexY, indexZ = mgrid[0:Nx, 0:Ny, 0:Nz]
    fig = mlab.figure(0, size=figSize)
    mlab.clf(fig)
    fig.scene.interactor.interactor_style = tvtk.InteractorStyleTerrain()

    # Force a boolean mask: an integer array would index by value instead.
    idxFlat = blob.astype(bool).flatten()
    n_points = idxFlat.sum()
    print('%s points' % n_points)
    if n_points > 0:
        pts = points3d(indexX.flatten()[idxFlat] + .5,
                       indexY.flatten()[idxFlat] + .5,
                       indexZ.flatten()[idxFlat] + .5,
                       ones(n_points) * .9,
                       color=(1, 1, 1), mode='cube', scale_factor=1.0)
        # White LUT whose alpha ramps from 0 to 255.
        lut = pts.module_manager.scalar_lut_manager.lut.table.to_array()
        lut[:, 0] = 255
        lut[:, 1] = 255
        lut[:, 2] = 255
        lut[:, 3] = linspace(0, 255, len(lut))
        pts.module_manager.scalar_lut_manager.lut.table = lut

    mlab.view(24, 88, 45, (5, 5, 10))  # Good for EF
    mlab.draw()
def justPlotBoolArray(blob, figSize=(300, 300)):
    '''Plot a 3-D boolean array as white cubes at its True cells.

    Clears mayavi figure 0 and switches it to terrain-style interaction. It
    prints the number of True cells, draws a cube centred in each one and
    sets a fixed camera.

    NOTE(review): an identical function of this name is defined earlier in
    this file; this later definition overrides it — consider keeping only one.
    '''
    Nx, Ny, Nz = blob.shape
    indexX, indexY, indexZ = mgrid[0:Nx, 0:Ny, 0:Nz]
    fig = mlab.figure(0, size=figSize)
    mlab.clf(fig)
    fig.scene.interactor.interactor_style = tvtk.InteractorStyleTerrain()

    # Force a boolean mask: an integer array would index by value instead.
    idxFlat = blob.astype(bool).flatten()
    n_points = idxFlat.sum()
    print('%s points' % n_points)
    if n_points > 0:
        pts = points3d(indexX.flatten()[idxFlat] + .5,
                       indexY.flatten()[idxFlat] + .5,
                       indexZ.flatten()[idxFlat] + .5,
                       ones(n_points) * .9,
                       color=(1, 1, 1), mode='cube', scale_factor=1.0)
        # White LUT whose alpha ramps from 0 to 255.
        lut = pts.module_manager.scalar_lut_manager.lut.table.to_array()
        lut[:, 0] = 255
        lut[:, 1] = 255
        lut[:, 2] = 255
        lut[:, 3] = linspace(0, 255, len(lut))
        pts.module_manager.scalar_lut_manager.lut.table = lut

    mlab.view(24, 88, 45, (5, 5, 10))  # Good for EF
    mlab.draw()
def scatter3d_torus(theta, gamma, torus_radius=5., tube_radius=3.0, try_mayavi=True):
    '''
    Scatter points on the surface of a torus.

    theta (the angle around the tube) is expected in [0, 2pi]. gamma (the
    angle around the central axis) is expected in [0, pi]. They are combined
    elementwise, giving one point per (theta, gamma) pair.

    Uses mayavi when try_mayavi is true and mayavi can be imported, and then
    returns None. Otherwise it draws with matplotlib and returns the 3-D Axes.
    '''
    x = (torus_radius + tube_radius * np.cos(theta)) * np.cos(gamma)
    y = (torus_radius + tube_radius * np.cos(theta)) * np.sin(gamma)
    z = tube_radius * np.sin(theta)

    use_mayavi = False
    if try_mayavi:
        try:
            import mayavi.mlab as mplt
            use_mayavi = True
        except Exception:
            # Any import-time failure: fall back to matplotlib. (A bare
            # except would also swallow KeyboardInterrupt/SystemExit.)
            pass

    if use_mayavi:
        mplt.figure(bgcolor=(1.0, 1.0, 1.0))
        mplt.points3d(x, y, z, scale_factor=1, color=(0.0, 0.2, 0.9))
        mplt.outline(color=(0., 0., 0.))
        mplt.draw()
    else:
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        ax.scatter(x, y, z)
        fig.canvas.draw()
        return ax
def update_frame(self):
    """Redraw the contour and the sensor highlights for ``self.current_frame``.

    The contour scalars come from ``gm_s_list[self.current_frame]``
    (log-scaled when the module-level ``log_plot`` is set). When
    ``observation_list`` is not None, every row of ``embeddings`` equal to an
    observed sensor vector gets colour value 1 and all others get 0.
    """
    print('Frame %d' % self.current_frame)
    contour_s = gm_s_list[self.current_frame]
    if log_plot:
        # Add the smallest positive float so log(0) is finite.
        # (np.float was removed in NumPy 1.24; the builtin float is the same.)
        contour_s = np.log(contour_s + np.finfo(float).tiny)
    self.phd_contour.mlab_source.set(scalars=contour_s)

    self.color_vector[:] = 0.
    if observation_list is not None:
        obs_array = observation_list[self.current_frame]
        # Row in `embeddings` matching each observed sensor vector.
        # NOTE(review): raises IndexError if an observation has no match.
        obs_index = [
            np.where(np.all(embeddings == sensor_vec.flatten(), axis=1))[0][0]
            for sensor_vec in obs_array
        ]
        self.color_vector[obs_index] = 1.
    self.sensor_points.mlab_source.dataset.point_data.scalars = self.color_vector
    mlab.draw()
def m_field_wire(ori, pos, grid, x_grid, y_grid, z_grid, x_field, y_field, z_field, no_lines):
    """Draw wires and streamlines of a 3-D vector field, then show the figure.

    ori, pos: matching sequences; each (orientation, position) pair is passed
        to ``wire(orientation, position, grid)``.
    x_grid, y_grid, z_grid: 1-D node coordinates, combined with
        ``np.meshgrid(..., indexing='ij')``.
    x_field, y_field, z_field: field components, assumed to be shaped like
        that 'ij' meshgrid (TODO confirm with callers).
    no_lines: accepted but not used here.
    """
    fig = mplt.figure()
    gx, gy, gz = np.meshgrid(x_grid, y_grid, z_grid, indexing='ij')

    for wire_ori, wire_pos in zip(ori, pos):
        wire(wire_ori, wire_pos, grid)

    # Streamlines seeded along a line widget, traced in both directions.
    streamlines = mplt.flow(gx, gy, gz, x_field, y_field, z_field, figure=fig,
                            seedtype='line', integration_direction='both')
    # Limit how far each streamline is traced (VTK MaximumPropagation).
    streamlines.stream_tracer.maximum_propagation = 200
    mplt.axes()

    fig.scene.x_plus_view()  # x axis pointing out of the screen
    mplt.draw(figure=fig)
    mplt.show()
def _to_object_array(items):
    """Pack a list of per-streamline sequences into a 1-D object array.

    Streamlines differ in length, so ``np.array(items)`` raises on
    NumPy >= 1.24 (older versions silently built a ragged object array).
    This builds that 1-D object array explicitly on every version.
    """
    out = np.empty(len(items), dtype=object)
    for k, item in enumerate(items):
        out[k] = item
    return out


def main():
    """Measure a scalar along tract centroids and plot it per group.

    Reads the ``.vtp`` tract file given as the first positional argument and
    finds centroid curves. These come from QuickBundles, or from one
    vtkCardinalSpline through the control points in ``--curvepoints``. Each
    streamline point is labelled with its nearest centroid position (one
    ``kmeans2`` iteration seeded with the centroid's points).

    For every centroid it:
      * plots, per group, the smoothed median of ``--scalar`` and the band from
        ``stats_per_group``'s 'ci' along the centroid. When there are several
        groups it also plots -log(p) where p < 0.05.
      * unless ``--noviz``, renders the points and the centroid in mayavi.

    Writes ``<file>_<scalar>_rawdata.csv`` with the per-point rows of all
    centroids, and ``<file>_<scalar>.pdf`` with the current matplotlib figure.
    Exits with status 2 when no input file is given.
    """
    import copy

    parser = OptionParser(usage="Usage: %prog [options] <tract.vtp>")
    parser.add_option("-s", "--scalar", dest="scalar", default="FA",
                      help="Scalar to measure")
    parser.add_option("-n", "--num", dest="num", default=50, type='int',
                      help="Number of subdivisions along centroids")
    parser.add_option("-l", "--local", dest="is_local", action="store_true", default=False,
                      help="Measure from Quickbundle assigned streamlines. Default is to measure from all streamlines")
    parser.add_option("-d", "--dist", dest="dist", default=20, type='float',
                      help="Quickbundle distance threshold")
    parser.add_option("--curvepoints", dest="curvepoints_file",
                      help="Define a curve to use as centroid. Control points are defined in a csv file in the same space as the tract points. The curve is the vtk cardinal spline implementation, which is a catmull-rom spline.")
    parser.add_option('--yrange', dest='yrange')
    parser.add_option('--xrange', dest='xrange')
    parser.add_option('--reverse', dest='is_reverse', action='store_true', default=False,
                      help='Reverse the centroid measure stepping order')
    parser.add_option('--pairplot', dest='pairplot')
    parser.add_option('--noviz', dest='is_viz', action='store_false', default=True)
    parser.add_option('--hide-centroid', dest='show_centroid', action='store_false', default=True)
    parser.add_option('--config', dest='config')
    parser.add_option('--background', dest='bg_file', help='Background NIFTI image')
    parser.add_option('--annot', dest='annot')
    (options, args) = parser.parse_args()
    if len(args) == 0:
        parser.print_help()
        sys.exit(2)

    name_mapping = None
    if options.config:
        config = GtsConfig(options.config, configure=False)
        name_mapping = config.group_labels

    annotations = None
    if options.annot:
        with open(options.annot, 'r') as fp:
            # safe_load: plain YAML data only. yaml.load without a Loader can
            # build arbitrary objects and is a TypeError in PyYAML >= 6.
            annotations = yaml.safe_load(fp)

    QB_DIST = options.dist
    QB_NPOINTS = options.num
    SCALAR_NAME = options.scalar
    LOCAL_POINT_ASSIGN = options.is_local
    filename = args[0]
    filebase = path.basename(filename).split('.')[0]

    # ---- read the tract file ----
    reader = vtk.vtkXMLPolyDataReader()
    reader.SetFileName(filename)
    reader.Update()
    polydata = reader.GetOutput()

    tract_ids = []
    for i in range(polydata.GetNumberOfCells()):
        # get point ids in [[ids][ids...]...] format
        pids = polydata.GetCell(i).GetPointIds()
        tract_ids.append([pids.GetId(p) for p in range(pids.GetNumberOfIds())])
    print('tracks: %s' % len(tract_ids))
    verts = vtk_to_numpy(polydata.GetPoints().GetData())
    print('verts: %s' % len(verts))

    scalars = []
    groups = []
    subjects = []
    pointdata = polydata.GetPointData()
    for si in range(pointdata.GetNumberOfArrays()):
        sname = pointdata.GetArrayName(si)
        print(sname)
        if sname == SCALAR_NAME:
            scalars = vtk_to_numpy(pointdata.GetArray(si))
        if sname == 'group':
            groups = vtk_to_numpy(pointdata.GetArray(si)).astype(int)
        if sname == 'tid':
            subjects = vtk_to_numpy(pointdata.GetArray(si)).astype(int)

    streamlines = []
    stream_scalars = []
    stream_groups = []
    stream_pids = []
    stream_sids = []
    for i in tract_ids:
        # indexing an np.array by a list gathers all the listed entries
        streamlines.append(verts[i])
        stream_scalars.append(scalars[i])
        stream_pids.append(i)
        stream_sids.append(subjects[i])
        try:
            stream_groups.append(groups[i])
        except Exception:
            # the 'group' array might not exist
            pass
    streamlines = _to_object_array(streamlines)
    stream_scalars = _to_object_array(stream_scalars)
    stream_groups = _to_object_array(stream_groups)
    stream_pids = _to_object_array(stream_pids)
    stream_sids = _to_object_array(stream_sids)

    # ---- centroids ----
    centroids = []
    if options.curvepoints_file:
        # A user-defined curve replaces QuickBundles, so there are no
        # clusters to restrict points to.
        LOCAL_POINT_ASSIGN = False
        ctrlpoints = np.loadtxt(options.curvepoints_file, delimiter=',')
        # a separate vtkCardinalSpline interpolator for x, y and z
        curve = [vtk.vtkCardinalSpline() for _ in range(3)]
        for c in curve:
            c.ClosedOff()
        for pi, point in enumerate(ctrlpoints):
            for i, val in enumerate(point):
                curve[i].AddPoint(pi, val)
        param_range = [0.0, 0.0]
        curve[0].GetParametricRange(param_range)
        # QB_NPOINTS samples including both ends; a `while t < end` loop
        # with step (end - start) / (QB_NPOINTS - 1) drops the end point.
        cpoints = [[c.Evaluate(t) for c in curve]
                   for t in np.linspace(param_range[0], param_range[1], QB_NPOINTS)]
        centroids.append(cpoints)
        centroids = np.array(centroids)
    else:
        # Use QuickBundles to find centroids
        qb = QuickBundles(streamlines, dist_thr=QB_DIST, pts=QB_NPOINTS)
        centroids = qb.centroids
        clusters = qb.clusters()

    # Make every centroid point the same general way as the first one.
    avg_d = np.zeros(3)
    for i, line in enumerate(centroids):
        ori = np.array(tm.mean_orientation(line))
        if i == 0:
            avg_d = ori
        dotprod = ori.dot(avg_d)
        print('dotprod %s' % dotprod)
        if dotprod < 0:
            print('reverse %s' % dotprod)
            centroids[i] = line[::-1]
            ori *= -1
        avg_d += ori

    if options.is_reverse:
        for i, c in enumerate(centroids):
            centroids[i] = c[::-1]

    # ---- prepare the mayavi 3-D view ----
    if options.is_viz:
        bg_val = 0.
        fig = mlab.figure(bgcolor=(bg_val, bg_val, bg_val))
        scene = mlab.gcf().scene
        fig.scene.render_window.aa_frames = 4
        mlab.draw()
        if options.bg_file:
            mrsrc, bgdata = getNiftiAsScalarField(options.bg_file)
            mlab.pipeline.image_plane_widget(mrsrc, opacity=0.5, plane_orientation='z_axes',
                                             slice_index=0, colormap='black-white',
                                             line_width=0, reset_zoom=False)

    pal = sns.color_palette("bright", len(centroids))
    DATADF = None

    # ---- per centroid ----
    for ci, cent in enumerate(centroids):
        print('---- centroid:')
        if LOCAL_POINT_ASSIGN:
            # only the streamlines QuickBundles assigned to this centroid
            ind = clusters[ci]['indices']
            cent_streams = streamlines[ind]
            cent_scalars = stream_scalars[ind]
            cent_groups = stream_groups[ind]
            cent_pids = stream_pids[ind]
            cent_sids = stream_sids[ind]
        else:
            # apply each centroid to all the points instead of only its
            # assigned streamlines
            cent_streams = streamlines
            cent_scalars = stream_scalars
            cent_groups = stream_groups
            cent_pids = stream_pids
            cent_sids = stream_sids
        cent_verts = np.vstack(cent_streams)
        cent_scalars = np.concatenate(cent_scalars)
        cent_groups = np.concatenate(cent_groups)
        cent_pids = np.concatenate(cent_pids)
        cent_sids = np.concatenate(cent_sids)
        cent_color = np.array(pal[ci])

        # One k-means iteration seeded with the centroid's points: each vertex
        # is labelled with the index of its nearest centroid point.
        _, labels = kmeans2(cent_verts, cent, iter=1)
        df = pd.DataFrame(data={'value': cent_scalars, 'position': labels,
                                'group': cent_groups, 'pid': cent_pids, 'sid': cent_sids})
        # pd.concat returns a new frame; it must be reassigned to accumulate.
        DATADF = df if DATADF is None else pd.concat([DATADF, df])

        UNIQ_GROUPS = df.group.unique()
        UNIQ_GROUPS.sort()
        grppal = sns.color_palette("Set2", len(UNIQ_GROUPS))
        print('# UNIQ GROUPS %s' % (UNIQ_GROUPS,))

        # ---- plot each group by position ----
        plt.figure(figsize=(14, 7))
        ax1 = plt.subplot2grid((4, 3), (0, 0), colspan=3, rowspan=3)
        ax2 = plt.subplot2grid((4, 3), (3, 0), colspan=3, sharex=ax1)
        axes = [ax1, ax2]
        plt.xlabel('Position Index')
        if len(centroids) > 1:
            cent_patch = mpatches.Patch(color=cent_color, label='Centroid {}'.format(ci + 1))
            cent_legend = axes[0].legend(handles=[cent_patch], loc=9)
            axes[0].add_artist(cent_legend)

        # ---- between-group statistics ----
        if len(UNIQ_GROUPS) > 1:
            pvalsDf = position_stats(df, name_mapping=name_mapping)
            logpvals = np.log(pvalsDf) * -1
            # show -log(p) only where p < 0.05
            pvals = logpvals.mask(pvalsDf >= 0.05)
            print(pvals)
            # Copy so set_bad does not alter Matplotlib's shared Reds colormap.
            cmap = copy.copy(mcmap.Reds)
            cmap.set_bad('w', 1.)
            axes[1].pcolormesh(pvals.values.T, cmap=cmap, vmin=0, vmax=10,
                               edgecolors='face', alpha=0.8)
            axes[1].set_yticks(np.arange(pvals.values.shape[1]) + 0.5, minor=False)
            axes[1].set_yticklabels(pvalsDf.columns.values.tolist(), minor=False)

        cur_axe = axes[0]
        cur_axe.set_ylabel(SCALAR_NAME)
        cent_color = tuple(cent_color)
        legend_handles = []
        legend_labels = []
        for gi, GRP in enumerate(UNIQ_GROUPS):
            print('-------------------- GROUP  %s ----------------------' % gi)
            subgrp = df[df['group'] == GRP]
            print(len(subgrp))
            if options.xrange:
                x0, x1 = [int(v) for v in options.xrange.split(',')]
                subgrp = subgrp[(subgrp['position'] >= x0) & (subgrp['position'] < x1)]
            posGrp = subgrp.groupby('position', sort=True)
            cent_stats = posGrp.apply(stats_per_group)
            if len(cent_stats) == 0:
                continue
            cent_stats = cent_stats.unstack()
            cent_median_scalar = cent_stats['median'].tolist()
            x = np.array(list(posGrp.groups))
            mcolor = tuple(grppal[gi])

            ci_bounds = cent_stats['ci'].tolist()
            qtile_top = np.array([s[0] for s in ci_bounds])
            qtile_bottom = np.array([s[1] for s in ci_bounds])
            x_new, qtop_sm = smooth(x, qtile_top)
            x_new, qbottom_sm = smooth(x, qtile_bottom)
            cur_axe.fill_between(x_new, qtop_sm, qbottom_sm, alpha=0.25, color=mcolor)
            x_new, median_sm = smooth(x, cent_stats['median'])
            hnd, = cur_axe.plot(x_new, median_sm, c=mcolor)
            # Collect labels alongside handles so a skipped group does not
            # shift the remaining labels onto the wrong lines.
            legend_handles.append(hnd)
            legend_labels.append(name_mapping[str(GRP)] if name_mapping is not None else GRP)
            if options.yrange:
                plotrange = options.yrange.split(',')
                cur_axe.set_ylim([float(plotrange[0]), float(plotrange[1])])
        cur_axe.legend(legend_handles, legend_labels)

        if annotations:
            for key, val in annotations.items():
                cur_axe.axvspan(val[0], val[1], fill=False, linestyle='dashed')
                # data -> axes-fraction transform, to place the label above the plot
                axis_to_data = cur_axe.transAxes + cur_axe.transData.inverted()
                data_to_axis = axis_to_data.inverted()
                axpoint = data_to_axis.transform((val[0], 0))
                cur_axe.text(axpoint[0], 1.02, key, transform=cur_axe.transAxes)

        # ---- 3-D view ----
        if options.is_viz:
            scene.disable_render = True
            mlab.points3d(cent_verts[:, 0], cent_verts[:, 1], cent_verts[:, 2], labels,
                          opacity=0.3, scale_mode='none', scale_factor=2, line_width=2,
                          colormap='blue-red', mode='point')
            # One scalar per centroid point (pad missing positions with 0).
            delta = len(cent) - len(cent_median_scalar)
            if delta > 0:
                cent_median_scalar = np.pad(cent_median_scalar, (0, delta),
                                            mode='constant', constant_values=0)
            # Displacement from each centroid point to the next; the last
            # point gets a zero vector.
            uvw = np.roll(cent, -1, axis=0) - cent
            uvw[-1] = 0
            arrow_plot = mlab.quiver3d(cent[:, 0], cent[:, 1], cent[:, 2],
                                       uvw[:, 0], uvw[:, 1], uvw[:, 2],
                                       scalars=cent_median_scalar, scale_factor=1, mode='arrow')
            arrow_plot.glyph.color_mode = 'color_by_scalar'
            gsource = arrow_plot.glyph.glyph_source.glyph_source
            gsource.tip_length = 0.4
            gsource.shaft_radius = 0.2
            gsource.tip_radius = 0.3
            if options.show_centroid:
                tube_plot = mlab.plot3d(cent[:, 0], cent[:, 1], cent[:, 2], cent_median_scalar,
                                        color=cent_color, tube_radius=0.2, opacity=0.25)
                tube_filter = tube_plot.parent.parent.filter
                tube_filter.vary_radius = 'vary_radius_by_scalar'
                tube_filter.radius_factor = 10

            # Label every 10th centroid position, plus the last one.
            last = len(cent) - 1
            for p in list(range(0, last, 10)) + [last]:
                pos = cent[p]
                mlab.text3d(pos[0], pos[1], pos[2], str(p), scale=0.8)
            scene.disable_render = False

    if DATADF is not None:
        DATADF.to_csv('_'.join([filebase, SCALAR_NAME, 'rawdata.csv']), index=False)
    outfile = '_'.join([filebase, SCALAR_NAME])
    print('save to {}'.format(outfile))
    # NOTE(review): only the current (last centroid's) figure is saved.
    plt.savefig('{}.pdf'.format(outfile), dpi=300)
    if options.is_viz:
        plt.show(block=False)
        mlab.show()
import numpy as np

# Demo: a sin(xy)/(xy) surface whose opacity ramps with its value, made by
# rewriting the alpha column of the surface's lookup table.
x, y = np.mgrid[-10:10:200j, -10:10:200j]
z = 100 * np.sin(x * y) / (x * y)

# Visualize it with mlab.surf
from mayavi import mlab

mlab.figure(bgcolor=(1, 1, 1))
surf = mlab.surf(z, colormap='cool')

# Retrieve the LUT of the surf object.
lut = surf.module_manager.scalar_lut_manager.lut.table.to_array()

# The LUT is an (N, 4) array (256x4 by default) whose columns are RGBA
# (red, green, blue, alpha), coded as integers from 0 to 255. Replace the
# alpha column with a linear ramp to add a transparency gradient.
print(lut)
lut[:, -1] = np.linspace(0, 255, lut.shape[0])
print(lut)

# Put the LUT back in the surface object. Any Nx4 array could be assigned
# here instead of a modified copy of the existing LUT.
surf.module_manager.scalar_lut_manager.lut.table = lut

# Force an update of the figure now that the LUT has changed.
mlab.draw()
mlab.view(40, 85)
mlab.show()
try: if bg_color is None or sum(bg_color) < 2: text_color = (1., 1., 1.) else: text_color = (0., 0., 0.) cbar.label_text_property.color = text_color # set the labels(around the bar) color except AttributeError: pass # _toggle_render(True, view) state = True if mlab.options.backend != "test": figure.scene.disable_render = not state if state is True and view is not None: mlab.draw(figure=figure) # mlab.view(*view, figure=figure) # The statement will raise some problem, so as the example! if state is True: # force render figure.render() mlab.draw(figure=figure) # remove overlay # -------------------------------------------------------------------------------- if pos_surf is not None: pos_surf.remove() pos_bar.visible = False if neg_surf is not None: neg_surf.remove() neg_bar.visible = False