def plotlyIsosurface(self, isomin, isomax, title):
    '''Return a plotly FigureWidget containing an isosurface representation of the 3D histogram.

    Args:
        isomin (int): minimum value of isosurface colorscale
        isomax (int): max value of isosurface colorscale
        title (string): title of the plot

    Side effects:
        Stores ``self.histo3D_unrolled`` (flattened bin values) and
        ``self.histo3D_unrolled_coords`` (list of ``[x, y, z]`` lower bin
        edges, one entry per flattened value, in the same order), so that the
        user can build more involved isosurface plots from a notebook.
    '''
    # assumes self.histo3D is (values, [x_edges, y_edges, z_edges]) as returned
    # by np.histogramdd, i.e. values[i, j, k] belongs to edges[0][i],
    # edges[1][j], edges[2][k] -- TODO confirm against the producer.
    values, edges = self.histo3D[0], self.histo3D[1]
    self.histo3D_unrolled = np.reshape(values, [-1])

    # np.reshape flattens in C order (last axis fastest), so the coordinate
    # grid must be built with 'ij' indexing and flattened the same way;
    # otherwise values are paired with transposed coordinates.
    xx, yy, zz = np.meshgrid(edges[0][:-1], edges[1][:-1], edges[2][:-1],
                             indexing='ij')
    coords = np.column_stack((xx.ravel(), yy.ravel(), zz.ravel()))
    # Keep the public attribute a list of [x, y, z] for existing users.
    self.histo3D_unrolled_coords = coords.tolist()

    data = [
        go.Isosurface(x=coords[:, 0],
                      y=coords[:, 1],
                      z=coords[:, 2],
                      value=self.histo3D_unrolled,
                      isomin=isomin,
                      isomax=isomax,
                      colorscale='Blues')
    ]
    layout = go.Layout(title=title)
    return go.FigureWidget(data, layout)
def plot_iso_surface(sdf, box_size=(1.0, 1.0, 1.0), max_n_eval_pts=1e6, iso_max=0.1, resolution=64,
                     thres=0.0, save_path=None) -> go.Figure:
    """Plot isosurfaces of an implicit function sampled on a uniform grid (inputs assumed centered).

    Args:
        sdf: callable mapping a chunk of grid points (a slice of
            ``grid['grid_points']`` along dim 0) to one SDF/occupancy value per
            point; may return a torch tensor or a numpy array.
        box_size (Sequence[float]): bounding box dimension; only ``box_size[0]``
            is used, so the sampling box is treated as a cube.
        max_n_eval_pts (int | float): maximum number of points passed to ``sdf``
            in one call.
        iso_max (float): upper isosurface value.
        resolution (int): grid resolution passed to ``get_grid_uniform``.
        thres (float): lower isosurface (levelset) value.
        save_path (str | None): if given, the figure is also written there as
            HTML; missing parent directories are created.

    Returns:
        plotly ``go.Figure`` with the isosurfaces.
    """
    grid = get_grid_uniform(resolution, box_side_length=box_size[0])
    points = grid['grid_points']

    # Evaluate in chunks so a single sdf call never sees more than
    # max_n_eval_pts points.
    chunk_size = max(1, int(max_n_eval_pts))
    chunks = []
    for pnts in torch.split(points, chunk_size, dim=0):
        values = sdf(pnts)
        if isinstance(values, torch.Tensor):
            values = values.detach().cpu().numpy()
        chunks.append(values)
    z = np.concatenate(chunks, axis=0).astype(np.float32)

    points = points.cpu().numpy()
    fig = go.Figure(data=go.Isosurface(
        x=points[..., 0].flatten(),
        y=points[..., 1].flatten(),
        z=points[..., 2].flatten(),
        value=z.flatten(),
        opacity=0.6,
        isomin=thres,
        isomax=iso_max,
        surface_count=3,
        surface_fill=0.8,
        caps=dict(x_show=False, y_show=False, z_show=False)))

    if save_path is not None:
        save_dir = os.path.dirname(save_path)
        # A bare filename has no directory part; os.makedirs('') would raise.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.write_html(save_path)
    return fig
def get_plot(self, var, value, plottype, colorscale="Viridis", sym="F", log="F",
             vmin="", vmax=""):
    '''
    Return a plotly figure object for the plottype requested.
    var, value, and plottype are required variables.
    -var        = name of plot variable
    -value      = position of slice or isosurface value
    -plottype   = 2D-alt, 2D-lon, 2D-lat, 3D-alt, iso
                  (iso1 is an older isosurface variant that keeps all data local)
    -colorscale = Viridis [default], Cividis, Rainbow, or BlueRed
    -sym        = "T" for a colorscale symmetric around 0 (default "F")
    -log        = "T" to plot log10() of the value (default "F")
    -vmin,vmax  = minimum and maximum value for contour values, empty is min/max

    If `value` is out of range for the requested plot, or `plottype` is
    unknown, a message is printed and None is returned.
    '''
    # Common code blocks for all plots
    txtbot = "Model: GITM v" + str(self.codeversion) + ", Run: " + self.runname
    units = self.variables[var]['units']
    txtbar = var + " [" + units + "]"
    time_str = self.dtvalue.strftime("%Y/%m/%d %H:%M:%S UT")
    earth_radius = 6.3781E6  # used to convert altitudes (m) to Earth radii

    def sample(lon, lat, alt, shape):
        """Interpolate `var` at (lon, lat, alt) points and reshape to `shape`.

        Each coordinate is either an array with prod(shape) elements or a
        scalar, which is broadcast to every point.
        """
        grid = np.ndarray(shape=(int(np.prod(shape)), 3), dtype=np.float32)
        grid[:, 0] = np.reshape(lon, -1)
        grid[:, 1] = np.reshape(lat, -1)
        grid[:, 2] = np.reshape(alt, -1)
        return np.reshape(self.variables[var]['interpolator'](grid), shape)

    def color_limits(data, widen_with_bounds):
        """Return (cmin, cmax) for `data`, honoring sym/vmin/vmax.

        With sym == "T" the range is symmetric about 0. If widen_with_bounds
        is True, vmax sets and vmin widens that symmetric range. If it is
        False, giving vmin or vmax disables the symmetric range.
        NOTE(review): plot types historically differ here; kept as-is.
        """
        if sym == "T" and (widen_with_bounds or (vmin == "" and vmax == "")):
            cmax = np.max(np.absolute(data))
            if vmax != "":
                cmax = abs(float(vmax))
            if vmin != "":
                cmax = max(cmax, abs(float(vmin)))
            return -cmax, cmax
        cmax = np.max(data)
        cmin = np.min(data)
        if vmax != "":
            cmax = float(vmax)
        if vmin != "":
            cmin = float(vmin)
        return cmin, cmax

    def apply_colorscale(fig):
        """Apply the requested colorscale to every trace currently in `fig`."""
        if colorscale == "BlueRed":
            fig.update_traces(colorscale="RdBu", reversescale=True)
        elif colorscale == "Rainbow":
            fig.update_traces(
                colorscale=[[0.00, 'rgb(0,0,255)'],
                            [0.25, 'rgb(0,255,255)'],
                            [0.50, 'rgb(0,255,0)'],
                            [0.75, 'rgb(255,255,0)'],
                            [1.00, 'rgb(255,0,0)']]
            )
        else:
            fig.update_traces(colorscale=colorscale)

    def footer(xshift, yshift):
        """Model/run annotation placed below the plot."""
        return dict(text=txtbot, x=0.0, y=0.0, ax=0, ay=0, xanchor="left",
                    xshift=xshift, yshift=yshift, xref="paper", yref="paper",
                    font=dict(size=16, family="sans serif", color="#000000"))

    def finish_2d(fig, cmin, cmax, bar_title, xname, yname, xlabel, ylabel, title_text):
        """Apply colorscale, contour limits, hover text and layout shared by 2D plots."""
        apply_colorscale(fig)
        fig.update_traces(
            zmin=cmin, zmax=cmax,
            ncontours=201,
            colorbar=dict(title=bar_title, tickformat=".3g"),
            contours=dict(coloring="fill", showlines=False)
        )
        zlabel = "log(" + var + ")" if log == "T" else var
        fig.update_traces(
            hovertemplate=xname + ": %{x:.0f}<br>" + yname + ": %{y:.0f}<br><b>"
            + zlabel + ": %{z:.4g}</b><extra></extra>"
        )
        fig.update_layout(
            title=dict(text=title_text, yref="container", yanchor="top", y=0.95),
            title_font_size=16,
            annotations=[
                dict(text=xlabel, x=0.5, y=-0.13, showarrow=False,
                     xref="paper", yref="paper", font=dict(size=12)),
                dict(text=ylabel, x=-0.1, y=0.5, showarrow=False,
                     xref="paper", yref="paper", font=dict(size=12), textangle=-90),
                footer(-65, -42),
            ],
            height=340
        )
        return fig

    def altitude_levels(step):
        """Return (alt1, alt2, nalt): a `step`-spaced altitude range inside self.alt."""
        alt1 = (step * round(self.alt[2] / step))
        alt2 = self.alt[(self.alt.shape[0] - 3)]
        nalt = round((alt2 - alt1) / step)
        alt2 = alt1 + step * nalt
        return alt1, alt2, 1 + int(nalt)

    def iso_scene_axes(fig):
        """Label the lon/lat/alt scene axes used by the isosurface plots."""
        fig.update_scenes(
            xaxis=dict(title=dict(text="Lon [degrees]"), tick0=0., dtick=45.),
            yaxis=dict(title=dict(text="Lat [degrees]"), tick0=0., dtick=45.),
            zaxis=dict(title=dict(text="Alt [km]"))
        )

    if plottype == "2D-alt":
        # Check if altitude entered is valid
        if value < self.altmin or value > self.altmax:
            print('Altitude is out of range: alt=', value,
                  ' min/max=', self.altmin, '/', self.altmax)
            return
        altkm = value / 1000.
        ilon = np.linspace(0, 360, 361)
        ilat = np.linspace(-90, 90, 181)
        xx, yy = np.meshgrid(ilon, ilat)
        result = sample(xx, yy, value, (ilat.shape[0], ilon.shape[0]))
        if log == "T":
            txtbar = "log<br>" + txtbar
            result = np.log10(result)
        cmin, cmax = color_limits(result, widen_with_bounds=True)

        def plot_var(lon=ilon, lat=ilat):
            return result
        fig = Kamodo(plot_var=plot_var).plot(plot_var=dict())
        fig.update_xaxes(tick0=0., dtick=45., title_text="")
        fig.update_yaxes(tick0=0., dtick=45, title_text="")
        return finish_2d(fig, cmin, cmax, txtbar, "Lon", "Lat",
                         "Lon [degrees]", "Lat [degrees]",
                         "Altitude=" + "{:.0f}".format(altkm) + " km, Time = " + time_str)

    if plottype == "2D-lat":
        # Check if latitude entered is valid
        if value < -90. or value > 90.:
            print('Latitude is out of range: lat=', value, ' min/max= -90./90.')
            return
        ilon = np.linspace(0, 360, 361)
        ialt = np.linspace(self.altmin, self.altmax, 300)
        xx, yy = np.meshgrid(ilon, ialt)
        result = sample(xx, value, yy, (ialt.shape[0], ilon.shape[0]))
        if log == "T":
            txtbar = "log<br>" + txtbar
            result = np.log10(result)
        cmin, cmax = color_limits(result, widen_with_bounds=False)
        ialt = ialt / 1000.

        def plot_var(lon=ilon, alt=ialt):
            return result
        fig = Kamodo(plot_var=plot_var).plot(plot_var=dict())
        fig.update_xaxes(tick0=0., dtick=45., title_text="")
        fig.update_yaxes(title_text="")
        return finish_2d(fig, cmin, cmax, txtbar, "Lon", "Alt",
                         "Lon [degrees]", "Altitude [km]",
                         "Latitude=" + "{:.1f}".format(value) + " degrees, Time = " + time_str)

    if plottype == "2D-lon":
        # Check if longitude entered is valid
        if value < 0. or value > 360.:
            print('Longitude is out of range: lon=', value, ' min/max= 0./360.')
            return
        ilat = np.linspace(-90, 90, 181)
        ialt = np.linspace(self.altmin, self.altmax, 300)
        xx, yy = np.meshgrid(ilat, ialt)
        result = sample(value, xx, yy, (ialt.shape[0], ilat.shape[0]))
        if log == "T":
            txtbar = "log<br>" + txtbar
            result = np.log10(result)
        cmin, cmax = color_limits(result, widen_with_bounds=False)
        ialt = ialt / 1000.

        def plot_var(lat=ilat, alt=ialt):
            return result
        fig = Kamodo(plot_var=plot_var).plot(plot_var=dict())
        fig.update_xaxes(tick0=0., dtick=30., title_text="")
        fig.update_yaxes(title_text="")
        return finish_2d(fig, cmin, cmax, txtbar, "Lat", "Alt",
                         "Lat [degrees]", "Altitude [km]",
                         "Longitude=" + "{:.1f}".format(value) + " degrees, Time = " + time_str)

    if plottype == "3D-alt":
        # Check if altitude entered is valid
        if value < self.altmin or value > self.altmax:
            print('Altitude is out of range: alt=', value,
                  ' min/max=', self.altmin, '/', self.altmax)
            return
        altkm = value / 1000.
        ilon = np.linspace(0, 360, 361)
        ilat = np.linspace(-90, 90, 181)
        xx, yy = np.meshgrid(ilon, ilat)
        result = sample(xx, yy, value, (ilat.shape[0], ilon.shape[0]))
        if log == "T":
            txtbar = "log<br>" + txtbar
            result = np.log10(result)
        # Spherical shell at this altitude, in units of Earth radii.
        r = value + earth_radius
        x = -(r * np.cos(yy * np.pi / 180.) * np.cos(xx * np.pi / 180.)) / earth_radius
        y = -(r * np.cos(yy * np.pi / 180.) * np.sin(xx * np.pi / 180.)) / earth_radius
        z = (r * np.sin(yy * np.pi / 180.)) / earth_radius
        cmin, cmax = color_limits(result, widen_with_bounds=False)

        def plot_var(x=x, y=y, z=z):
            return result
        fig = Kamodo(plot_var=plot_var).plot(plot_var=dict())
        fig.update_scenes(xaxis=dict(title=dict(text="X [Re]")),
                          yaxis=dict(title=dict(text="Y [Re]")),
                          zaxis=dict(title=dict(text="Z [Re]")))
        apply_colorscale(fig)
        fig.update_traces(
            cmin=cmin, cmax=cmax,
            colorbar=dict(title=txtbar, tickformat=".3g")
        )
        fig.update_traces(
            hovertemplate="X [Re]: %{x:.3f}<br>Y [Re]: %{y:.3f}<br>Z [Re]: %{z:.3f}<extra></extra>"
        )
        fig.update_layout(
            title=dict(text="Altitude=" + "{:.0f}".format(altkm) + " km, Time = " + time_str,
                       yref="container", yanchor="top", x=0.01, y=0.95),
            title_font_size=16,
            annotations=[footer(0, -20)],
            margin=dict(l=0),
            width=600
        )
        # Reference lines: polar axis, equator and prime meridian.
        fig.add_scatter3d(mode='lines', x=[0., 0.], y=[0., 0.], z=[-1.2, 1.2],
                          line=dict(width=4, color='black'), showlegend=False,
                          hovertemplate='Polar Axis<extra></extra>')
        r = value + 10000. + earth_radius
        x2 = -(r * np.cos(ilon * np.pi / 180.)) / earth_radius
        y2 = -(r * np.sin(ilon * np.pi / 180.)) / earth_radius
        z2 = 0. * ilon
        fig.add_scatter3d(mode='lines', x=x2, y=y2, z=z2,
                          line=dict(width=2, color='black'), showlegend=False,
                          hovertemplate='Equator<extra></extra>')
        x3 = -(r * np.cos(ilat * np.pi / 180.)) / earth_radius
        y3 = 0. * ilat
        z3 = (r * np.sin(ilat * np.pi / 180.)) / earth_radius
        fig.add_scatter3d(mode='lines', x=x3, y=y3, z=z3,
                          line=dict(width=2, color='black'), showlegend=False,
                          hovertemplate='prime meridian<extra></extra>')
        return fig

    if plottype == "iso":
        # Check if value entered is valid (checking before possible log scale)
        cmin = np.min(self.variables[var]['data'])
        cmax = np.max(self.variables[var]['data'])
        if value < cmin or value > cmax:
            print('Iso value is out of range: iso=', value, ' min/max=', cmin, cmax)
            return
        # Altitude start, stop, number for the isosurface (10000. m steps)
        alt1, alt2, nalt = altitude_levels(10000.)

        # Single-point interpolation function for marching cubes
        gridp = np.ndarray(shape=(1, 3), dtype=np.float32)

        def finterp(x, y, z):
            gridp[0, :] = [x, y, z]
            return self.variables[var]['interpolator'](gridp)

        # Extract the isosurface as vertices and triangulated connectivity
        import mcubes
        verts, tri = mcubes.marching_cubes_func(
            (0, -90, alt1), (360, 90, alt2),  # Bounds (min:x,y,z), (max:x,y,z)
            73, 37, nalt,                       # Number of samples in each dimension
            finterp,                            # Implicit function of x,y,z
            value)                              # Isosurface value
        X, Y, Z = verts[:, :3].T
        I, J, K = tri.T

        # Update cmin, cmax for sym, vmin, vmax values if set
        if sym == "T":
            if vmax != "":
                cmax = abs(float(vmax))
            if vmin != "":
                cmax = max(cmax, abs(float(vmin)))
            cmin = -cmax
        else:
            if vmax != "":
                cmax = float(vmax)
            if vmin != "":
                cmin = float(vmin)
        dvalue = value
        if log == "T":
            dvalue = np.log10(value)
            cmin = np.log10(cmin)
            cmax = np.log10(cmax)
            txtbar = "log<br>" + txtbar

        fig = go.Figure(data=[go.Mesh3d(
            name='ISO',
            x=X, y=Y, z=Z,
            i=I, j=J, k=K,
            intensity=np.linspace(dvalue, dvalue, X.shape[0]),
            cmin=cmin, cmax=cmax,
            showscale=True,
            opacity=0.6,
            colorbar=dict(title=txtbar, tickformat=".3g"),
            hovertemplate="Isosurface<br>" + var + "=" + "{:.3g}".format(value) + " "
            + units + "<br><extra></extra>",
        )])
        apply_colorscale(fig)
        iso_scene_axes(fig)
        fig.update_layout(
            scene_camera_eye=dict(x=.1, y=-1.8, z=1.5),
            scene_aspectmode='manual',
            scene_aspectratio=dict(x=2, y=1, z=1),
            scene_xaxis=dict(range=[0, 360]),
            scene_yaxis=dict(range=[-90, 90]),
            scene_zaxis=dict(range=[alt1, alt2]),
            title=dict(text="Isosurface of " + var + "=" + "{:.3g}".format(value) + " "
                       + units + "<br>" + "Time = " + time_str,
                       yref="container", yanchor="top", x=0.01, y=0.95),
            title_font_size=16,
            annotations=[footer(0, -20)],
            margin=dict(l=10, t=80),
        )
        return fig

    if plottype == "iso1":
        # This method keeps all the data local, resulting in huge storage
        # as well as corrupted html divs.
        ilon = np.linspace(0, 360, 73)
        ilat = np.linspace(-90, 90, 37)
        alt1, alt2, nalt = altitude_levels(10000.)
        ialt = np.linspace(alt1, alt2, nalt)
        xx, yy, zz = np.meshgrid(ilon, ilat, ialt)
        test = sample(xx, yy, zz, (ilat.shape[0], ilon.shape[0], ialt.shape[0]))
        result = test
        isovalue = value
        if log == "T":
            isovalue = np.log10(value)
            txtbar = "log<br>" + txtbar
            result = np.log10(result)
        cmin, cmax = color_limits(result, widen_with_bounds=True)
        # Check if value entered is valid (against the un-logged values)
        if value < np.min(test) or value > np.max(test):
            print('Iso value is out of range: iso=', value,
                  ' min/max=', np.min(test), '/', np.max(test))
            return

        slicevalue = 0.
        fig1 = go.Figure(data=go.Isosurface(
            x=xx.flatten(),
            y=yy.flatten(),
            z=zz.flatten() / 1000.,
            value=result.flatten(),
            opacity=0.6,
            isomin=cmin,
            isomax=cmax,
            surface=dict(count=2, fill=1., pattern='all'),
            caps=dict(x_show=False, y_show=False, z_show=False),
            showscale=True,  # show colorbar
            colorbar=dict(title=txtbar, tickformat=".3g"),
            slices_y=dict(show=True, locations=[slicevalue]),
        ))
        fig1.update_traces(
            hovertemplate="<b>Slice</b><br>Lon: %{x:.0f}<br>Lat: %{y:.0f}<br>Alt: %{z:.0f}km<br><extra></extra>"
        )
        apply_colorscale(fig1)
        fig2 = go.Figure(data=go.Isosurface(
            x=xx.flatten(),
            y=yy.flatten(),
            z=zz.flatten() / 1000.,
            value=result.flatten(),
            opacity=1.,
            colorscale=[[0.0, '#777777'], [1.0, '#777777']],
            isomin=isovalue,
            isomax=isovalue,
            surface=dict(count=1, fill=1., pattern='all'),
            caps=dict(x_show=False, y_show=False, z_show=False),
            showscale=False,  # remove colorbar
            hovertemplate="<b>Isosurface</b><br>Lon: %{x:.0f}<br>Lat: %{y:.0f}<br>Alt: %{z:.0f}km<extra></extra>"
        ))
        iso_scene_axes(fig2)
        fig2.update_layout(
            scene_camera_eye=dict(x=.1, y=-1.8, z=1.5),
            scene_aspectmode='manual',
            scene_aspectratio=dict(x=2, y=1, z=1),
            title=dict(text="Latitude=" + "{:.0f}".format(slicevalue)
                       + " slice through the data<br>"
                       + "Isosurface of " + var + "=" + "{:.0f}".format(isovalue) + units + "<br>"
                       + "Time = " + time_str,
                       yref="container", yanchor="top", x=0.01, y=0.95),
            title_font_size=16,
            annotations=[footer(0, -20)],
            margin=dict(l=10, t=80),
        )
        fig2.add_trace(fig1.data[0])
        return fig2

    print('Unknown plottype (', plottype, ') returning.')
    return
# Demo cells (apparently exported from a Jupyter notebook, see the "# In[...]"
# markers). Names defined here are module globals that later cells may use.

# Cell: 3D scatter of a helix sampled at 50 points for t in [0, 10].
t = np.linspace(0, 10, 50)
x, y, z = np.cos(t), np.sin(t), t
fig = go.Figure(data=[go.Scatter3d(x=x, y=y, z=z, mode='markers')])
# Save as a standalone HTML file; presumably `pyo` is plotly.offline
# (import not visible here -- confirm).
pyo.plot(fig, filename="3dscatter.html")
fig.show()
# In[104]:
# Cell: isosurfaces of the field 0.5*x^2 + y^2 + 2*z^2 sampled on a
# 40x40x40 grid spanning [-5, 5] along each axis.
X, Y, Z = np.mgrid[-5:5:40j, -5:5:40j, -5:5:40j]
values = X * X * 0.5 + Y * Y + Z * Z * 2
fig = go.Figure(data=go.Isosurface(
    x=X.flatten(),
    y=Y.flatten(),
    z=Z.flatten(),
    value=values.flatten(),
    isomin=10,
    isomax=50,
    surface_count=5,  # number of isosurfaces, 2 by default: only min and max
    colorbar_nticks=5,  # colorbar ticks correspond to isosurface values
    caps=dict(x_show=False, y_show=False)))
pyo.plot(fig, filename="isospace.html")
fig.show()
# In[105]:
# Cell: a single streamtube running along +y through the points
# (0, 0, 0), (0, 1, 2)... ; within this visible section the figure is built
# but neither shown nor saved.
fig = go.Figure(data=go.Streamtube(x=[0, 0, 0], y=[0, 1, 2], z=[0, 0, 0], u=[0, 0, 0], v=[1, 1, 1], w=[0, 0, 0]))
def update_figure(invar, invar_2, invar_3, outvar, invar1_log, invar2_log, invar3_log, outvar_log,
                  param_slider, graph_type, color_use, color_dd, error_use, error_dd,
                  filter_active, fit_use, fit_dd, fit_num, fit_conf, add_noise_var, fit_color,
                  fit_opacity, fit_sampling, id_type, param_center, param_log):
    """Build the plotly figure for the current selection of inputs, filters and fit options.

    Reads the module-level ``indata``/``outdata`` (indexed by column name;
    ``indata.dtype.names`` lists the input columns), ``invars``, ``mesh_fit``,
    ``colormap``, ``log10`` and ``graph_height``.

    Flag-style arguments are compared against one-element lists such as
    ``['log']``, ``['true']``, ``['show']`` and ``['act']`` (presumably Dash
    checklist values -- confirm against the layout).

    Args:
        invar, invar_2, invar_3: input columns used for the x/y/z axes.
        outvar: output column plotted or used for colouring.
        invar1_log, invar2_log, invar3_log, outvar_log: axis-scale flags.
        param_slider: per-parameter [min, max] filter ranges; converted from
            log10 when the matching ``param_log`` entry is ``['log']``.
        graph_type: one of '1D', '2D', '2D contour', '3D'; anything else
            gives an empty figure.
        param_center: per-parameter centre values passed to ``mesh_fit``.
        Remaining arguments control colouring, error bars, filtering and fit
        display.

    Returns:
        go.Figure
    """
    # Work on copies so the caller's lists are not modified in place.
    param_slider = list(param_slider)
    param_center = None if param_center is None else list(param_center)
    for i in range(len(param_slider)):
        if param_log[i] == ['log']:
            param_slider[i] = [10**val for val in param_slider[i]]
            param_center[i] = 10**param_center[i]

    if invar is None:
        return go.Figure()

    # Row mask: start with every row selected and AND in each active filter.
    sel_y = np.full((len(outdata),), True)
    for iteration, bounds in enumerate(param_slider):
        column = invars[id_type[iteration]['index']]
        sel_y_min = np.array(indata[column] >= bounds[0])
        sel_y_max = np.array(indata[column] <= bounds[1])
        if filter_active != [[]]:
            if filter_active[iteration] == ['act']:
                sel_y = sel_y_min & sel_y_max & sel_y

    if graph_type == '1D':
        fig = go.Figure(
            data=[go.Scatter(
                x=indata[invar][sel_y],
                y=outdata[outvar][sel_y],
                mode='markers',
                name='data',
                error_y=dict(type='data', array=outdata[error_dd][sel_y], visible=error_use == ['true']),
            )],
            layout=go.Layout(xaxis=dict(title=invar, rangeslider=dict(visible=True)),
                             yaxis=dict(title=outvar))
        )
        if fit_use == ['show']:
            mesh_in, mesh_out, mesh_out_std, fit_dd_values = mesh_fit(
                param_slider, id_type, fit_dd, fit_num, param_center, [invar], [invar1_log],
                outvar, fit_sampling, add_noise_var)
            for i, fit_value in enumerate(fit_dd_values):
                x_fit = mesh_in[i][invars.index(invar)]
                line_color = colormap(indata[fit_dd].min(), indata[fit_dd].max(), fit_value)
                fig.add_trace(go.Scatter(
                    x=x_fit,
                    y=mesh_out[i],
                    mode='lines',
                    name=f'fit: {fit_dd}={fit_value:.1e}',
                    line_color=line_color,
                    marker_line=dict(coloraxis="coloraxis2"),
                ))
                # Confidence band: upper curve forward, lower curve reversed,
                # closed with fill='toself'.
                fig.add_trace(go.Scatter(
                    x=np.hstack((x_fit, x_fit[::-1])),
                    y=np.hstack((mesh_out[i] + fit_conf * mesh_out_std[i],
                                 mesh_out[i][::-1] - fit_conf * mesh_out_std[i][::-1])),
                    showlegend=False,
                    fill='toself',
                    line_color=line_color,
                    marker_line=dict(coloraxis="coloraxis2"),
                    opacity=fit_opacity,
                ))
    elif graph_type == '2D':
        fig = go.Figure(
            data=[go.Scatter3d(
                x=indata[invar][sel_y],
                y=indata[invar_2][sel_y],
                z=outdata[outvar][sel_y],
                mode='markers',
                name='Data',
                error_z=dict(type='data', array=outdata[error_dd][sel_y],
                             visible=error_use == ['true'], width=10)
            )],
            layout=go.Layout(scene=dict(xaxis_title=invar, yaxis_title=invar_2, zaxis_title=outvar))
        )
        if fit_use == ['show'] and invar != invar_2:
            mesh_in, mesh_out, mesh_out_std, fit_dd_values = mesh_fit(
                param_slider, id_type, fit_dd, fit_num, param_center, [invar, invar_2],
                [invar1_log, invar2_log], outvar, fit_sampling, add_noise_var)
            grid_shape = (fit_sampling, fit_sampling)
            surface_axis = ("coloraxis2"
                            if (fit_color == 'multi-fit'
                                or (fit_color == 'output'
                                    and (color_dd != outvar and color_dd != 'OUTPUT')))
                            else "coloraxis")
            for i, fit_value in enumerate(fit_dd_values):
                x_fit = mesh_in[i][invars.index(invar)].reshape(grid_shape)
                y_fit = mesh_in[i][invars.index(invar_2)].reshape(grid_shape)
                z_fit = mesh_out[i].reshape(grid_shape)
                if fit_color == 'multi-fit':
                    surfacecolor = fit_value * np.ones([fit_sampling, fit_sampling])
                elif fit_color == 'marker-color' and color_dd in invars:
                    surfacecolor = mesh_in[i][invars.index(color_dd)].reshape(grid_shape)
                else:
                    surfacecolor = z_fit
                fig.add_trace(go.Surface(
                    x=x_fit, y=y_fit, z=z_fit,
                    name=f'fit: {fit_dd}={fit_value:.2f}',
                    surfacecolor=surfacecolor,
                    opacity=fit_opacity,
                    coloraxis=surface_axis,
                    showlegend=len(invars) > 2,
                ))
                if fit_conf > 0:
                    # Upper and lower confidence surfaces, coloured like the mean fit.
                    z_band = fit_conf * mesh_out_std[i].reshape(grid_shape)
                    fig.add_trace(go.Surface(
                        x=x_fit, y=y_fit, z=z_fit + z_band,
                        showlegend=False,
                        name=f'fit+v: {fit_dd}={fit_value:.2f}',
                        surfacecolor=surfacecolor,
                        opacity=fit_opacity,
                        coloraxis=surface_axis,
                    ))
                    fig.add_trace(go.Surface(
                        x=x_fit, y=y_fit, z=z_fit - z_band,
                        showlegend=False,
                        name=f'fit-v: {fit_dd}={fit_value:.2f}',
                        surfacecolor=surfacecolor,
                        opacity=fit_opacity,
                        coloraxis=surface_axis,
                    ))
            fig.update_layout(coloraxis2=dict(
                colorbar=dict(title=outvar if fit_color == 'output' else fit_dd),
                cmin=min(fit_dd_values) if fit_color == 'multi-fit' else None,
                cmax=max(fit_dd_values) if fit_color == 'multi-fit' else None,
            ))
    elif graph_type == '2D contour':
        mesh_in, mesh_out, mesh_out_std, fit_dd_values = mesh_fit(
            param_slider, id_type, fit_dd, fit_num, param_center, [invar, invar_2],
            [invar1_log, invar2_log], outvar, fit_sampling, add_noise_var)
        data_x = mesh_in[0][invars.index(invar)]
        data_y = mesh_in[0][invars.index(invar_2)]
        fig = go.Figure()
        if min(data_x) == max(data_x):
            fig.update_layout(title="x-data is constant, no contour-plot possible")
        elif min(data_y) == max(data_y):
            fig.update_layout(title="y-data is constant, no contour-plot possible")
        else:
            fig.add_trace(go.Scatter(
                x=indata[invar][sel_y],
                y=indata[invar_2][sel_y],
                mode='markers',
                name='Data',
            ))
            fig.add_trace(go.Contour(
                x=data_x,
                y=data_y,
                z=mesh_out[0],
                contours_coloring='heatmap',
                contours_showlabels=True,
                coloraxis='coloraxis2',
                name='fit',
            ))
            contour = fig.data[1]
            # Log-type plotly axes take their range in log10 units.
            fig.update_xaxes(
                range=[log10(min(contour['x'])), log10(max(contour['x']))] if invar1_log == ['log']
                else [min(contour['x']), max(contour['x'])])
            fig.update_yaxes(
                range=[log10(min(contour['y'])), log10(max(contour['y']))] if invar2_log == ['log']
                else [min(contour['y']), max(contour['y'])])
            fig.update_layout(xaxis_title=invar, yaxis_title=invar_2,
                              coloraxis2=dict(colorbar=dict(title=outvar), colorscale='solar',
                                              cmin=min(contour['z']), cmax=max(contour['z'])))
    elif graph_type == '3D':
        fig = go.Figure(
            data=go.Scatter3d(
                x=indata[invar][sel_y],
                y=indata[invar_2][sel_y],
                z=indata[invar_3][sel_y],
                mode='markers',
                marker=dict(
                    color=outdata[outvar][sel_y],
                    coloraxis="coloraxis2",
                ),
                name='Data',
            ),
            layout=go.Layout(scene=dict(xaxis_title=invar, yaxis_title=invar_2, zaxis_title=invar_3)),
        )
        fig.update_layout(coloraxis2=dict(
            colorbar=dict(title=outvar),
        ))
        if fit_use == ['show'] and len({invar, invar_2, invar_3}) == 3:
            mesh_in, mesh_out, mesh_out_std, fit_dd_values = mesh_fit(
                param_slider, id_type, fit_dd, fit_num, param_center, [invar, invar_2, invar_3],
                [invar1_log, invar2_log, invar3_log], outvar, fit_sampling, add_noise_var)
            for i in range(len(fit_dd_values)):
                # NOTE(review): min*1.1 / max*0.9 only shrinks the range inward
                # for positive values; kept as-is.
                fig.add_trace(
                    go.Isosurface(
                        x=mesh_in[i][invars.index(invar)],
                        y=mesh_in[i][invars.index(invar_2)],
                        z=mesh_in[i][invars.index(invar_3)],
                        value=mesh_out[i],
                        surface_count=fit_num,
                        coloraxis="coloraxis2",
                        isomin=mesh_out[i].min() * 1.1,
                        isomax=mesh_out[i].max() * 0.9,
                        caps=dict(x_show=False, y_show=False, z_show=False),
                        opacity=fit_opacity,
                    ),
                )
    else:
        fig = go.Figure()

    fig.update_layout(legend=dict(xanchor="left", x=0.01))

    # Axis scales: 2-entry tuples style the 2D layout axes, 3-entry tuples the
    # 3D scene axes. Unknown graph types have no axes to style.
    log_dict = {'1D': (invar1_log, outvar_log),
                '2D': (invar1_log, invar2_log, outvar_log),
                '2D contour': (invar1_log, invar2_log),
                '3D': (invar1_log, invar2_log, invar3_log)}
    log_list = ['linear' if flags is None or len(flags) == 0 else flags[0]
                for flags in log_dict.get(graph_type, ())]
    log_key = ['xaxis', 'yaxis', 'zaxis']
    comb_dict = dict(zip(log_key, [{'type': scale} for scale in log_list]))
    if len(log_list) < 3:
        fig.update_layout(**comb_dict)
    else:
        fig.update_scenes(**comb_dict)

    def marker_colors():
        """Marker colour values: the output when color_dd is 'OUTPUT', else the named column."""
        if color_dd == 'OUTPUT':
            return outdata[outvar][sel_y]
        if color_dd in indata.dtype.names:
            return indata[color_dd][sel_y]
        return outdata[color_dd][sel_y]

    # color
    if color_use == ['true']:
        if fit_use == ['show'] and (graph_type == '2D' and (fit_color == 'multi-fit' and color_dd == fit_dd)):
            fig.update_traces(
                marker=dict(
                    coloraxis="coloraxis2",
                    color=indata[color_dd][sel_y] if color_dd in indata.dtype.names else outdata[color_dd][sel_y],
                ),
                selector=dict(mode='markers'),
            )
        elif graph_type == '3D':
            fig.update_traces(
                marker=dict(
                    coloraxis="coloraxis2",
                    color=outdata[outvar][sel_y],
                ),
                selector=dict(mode='markers'),
            )
        elif graph_type == '1D':
            fig.update_traces(
                marker=dict(
                    coloraxis="coloraxis2",
                    color=marker_colors(),
                ),
                selector=dict(mode='markers'),
            )
            if color_dd == fit_dd:
                fig.update_layout(coloraxis2=dict(colorscale='cividis', colorbar=dict(title=fit_dd)))
            elif color_dd == 'OUTPUT':
                fig.update_layout(coloraxis2=dict(colorscale='plasma', colorbar=dict(title=outvar)))
            else:
                fig.update_layout(coloraxis2=dict(colorscale='plasma', colorbar=dict(title=color_dd)))
        elif graph_type == '2D contour':
            fig.update_traces(
                marker=dict(
                    coloraxis="coloraxis",
                    color=marker_colors(),
                ),
                selector=dict(mode='markers'),
            )
            if color_dd == outvar or color_dd == 'OUTPUT':
                # Share the contour's colour axis when colouring by the output.
                fig.update_traces(marker_coloraxis="coloraxis2", selector=dict(mode='markers'))
            else:
                fig.update_layout(coloraxis=dict(colorbar=dict(title=color_dd, x=1.1), colorscale='ice'))
        else:
            fig.update_traces(
                marker=dict(
                    coloraxis="coloraxis",
                    color=marker_colors(),
                ),
                selector=dict(mode='markers'),
            )
            fig.update_layout(coloraxis=dict(
                colorbar=dict(title=outvar if color_dd == 'OUTPUT' else color_dd, x=1.1),
                colorscale='viridis',
            ))
    fig.update_layout(height=graph_height)
    return fig
def voxel2plotly(neuron, legendgroup, showlegend, label, color, as_scatter=True, **kwargs):
    """Convert VoxelNeuron to plotly object.

    Turns out that plotly is horrendous for plotting voxel data (Volumes):
    anything more than a few thousand voxels (e.g. 40x40x40) and the html
    encoding and loading the plot takes ages. Unfortunately, the same happens
    with Isosurfaces. I'm adding an implementation here but until plotly gets
    MUCH better at this, there is really no point. For now, we will fallback
    to plotting the voxels as scatter plots using the top 100k voxels sorted
    by brightness.

    Parameters
    ----------
    neuron :        VoxelNeuron to convert.
    legendgroup, showlegend, label :
                    Passed to the scatter trace as ``legendgroup``,
                    ``showlegend`` and ``name``.
    color :         Currently unused: both code paths colour voxels by their
                    values. Kept for interface compatibility.
    as_scatter :    If True (default), plot the brightest voxels as a 3D
                    scatter; otherwise plot a heavily downsampled isosurface.
    **kwargs :      Only ``hover_name`` is read; if truthy, hovering shows
                    ``neuron.label``.

    Returns
    -------
    list
        Plotly trace(s); an empty list for an empty neuron.
    """
    max_voxels = 100000

    # Skip empty neurons
    if min(neuron.shape) == 0:
        return []

    if kwargs.get('hover_name', False):
        hoverinfo = 'text'
        hovertext = neuron.label
    else:
        hoverinfo = 'none'
        hovertext = ' '

    if not as_scatter:
        # Downsample heavily
        ds = ndimage.zoom(neuron.grid, .2, order=1)
        # Size of one downsampled voxel in original-voxel units, per axis.
        # ndimage.zoom (default grid_mode=False) maps output index j to input
        # index j * (in - 1) / (out - 1).
        in_shape = np.asarray(neuron.grid.shape, dtype=float)
        out_shape = np.asarray(ds.shape, dtype=float)
        ds_scale = (in_shape - 1) / np.maximum(out_shape - 1, 1)
        # Use plain numbers like the scatter path does (units_xyz has .magnitude).
        units = np.broadcast_to(np.asarray(neuron.units_xyz.magnitude, dtype=float), (3,))

        # Generate X, Y, Z coordinates for values in grid
        X, Y, Z = np.meshgrid(range(ds.shape[0]),
                              range(ds.shape[1]),
                              range(ds.shape[2]),
                              indexing='ij')

        # Flatten and scale coordinates into the neuron's space
        X = X.flatten() * ds_scale[0] * units[0] + neuron.offset[0]
        Y = Y.flatten() * ds_scale[1] * units[1] + neuron.offset[1]
        Z = Z.flatten() * ds_scale[2] * units[2] + neuron.offset[2]

        # Flatten and normalize values (guard against an all-zero grid)
        peak = ds.max()
        values = ds.flatten() / peak if peak else ds.flatten()

        trace_data = [
            go.Isosurface(
                x=X,
                y=Y,
                z=Z,
                value=values,
                isomin=0.001,
                isomax=1,
                opacity=0.1,
                surface_count=21,
            )
        ]
    else:
        voxels, values = neuron.voxels, neuron.values
        # Keep only the brightest voxels
        top = np.argsort(values)[-max_voxels:]
        values = values[top]
        voxels = voxels[top]

        # Scale and offset voxels
        voxels = voxels * neuron.units_xyz.magnitude + neuron.offset

        # NOTE(review): catch_warnings() without a simplefilter only saves and
        # restores the warning filters; it does not suppress anything.
        with warnings.catch_warnings():
            trace_data = [
                go.Scatter3d(x=voxels[:, 0],
                             y=voxels[:, 1],
                             z=voxels[:, 2],
                             mode='markers',
                             marker=dict(color=values,
                                         size=4,
                                         colorscale='viridis',
                                         opacity=.1),
                             name=label,
                             legendgroup=legendgroup,
                             showlegend=showlegend,
                             hovertext=hovertext,
                             hoverinfo=hoverinfo)
            ]

    return trace_data