Ejemplo n.º 1
0
    def plotlyIsosurface(self, isomin, isomax, title):
        '''Return a plotly FigureWidget containing an isosurface representation of the 3D histogram

        Side effects: (re)assigns ``self.histo3D_unrolled`` (the flattened
        ``self.histo3D[0]``) and ``self.histo3D_unrolled_coords`` (a list of
        ``[x, y, z]`` triples built from ``self.histo3D[1]``), so that they can
        be reused from a notebook for more involved plots.

        Args:
            isomin (int): minimum value of isosurface colorscale
            isomax (int): max value of isosurface colorscale
            title (string): title of the plot

        Returns:
            go.FigureWidget holding a single ``go.Isosurface`` trace drawn with
            the 'Blues' colorscale.
        '''

        # Flatten the values (C order) for plotly's Isosurface `value` input.
        self.histo3D_unrolled = np.reshape(self.histo3D[0], [-1])

        # One [x, y, z] triple per cell; the last entry of each axis array is
        # dropped ([:-1]), i.e. each cell is positioned at its lower edge.
        # NOTE(review): x varies fastest here, while a C-order flatten of an
        # (nx, ny, nz) array makes z vary fastest. If self.histo3D comes
        # straight from np.histogramdd the two orderings disagree unless the
        # histogram was transposed beforehand -- confirm against the producer.
        self.histo3D_unrolled_coords = [[x, y, z]
                                        for z in self.histo3D[1][2][:-1]
                                        for y in self.histo3D[1][1][:-1]
                                        for x in self.histo3D[1][0][:-1]]

        # Convert the (potentially very long) list once instead of once per
        # axis; reshape keeps column indexing valid when the list is empty.
        coords = np.array(self.histo3D_unrolled_coords).reshape(-1, 3)

        data = [
            go.Isosurface(x=coords[:, 0],
                          y=coords[:, 1],
                          z=coords[:, 2],
                          value=self.histo3D_unrolled,
                          isomin=isomin,
                          isomax=isomax,
                          colorscale='Blues')
        ]
        layout = go.Layout(title=title)
        return go.FigureWidget(data, layout)
Ejemplo n.º 2
0
def plot_iso_surface(sdf,
                     box_size=(1.0, 1.0, 1.0),
                     max_n_eval_pts=1e6,
                     iso_max=0.1,
                     resolution=64,
                     thres=0.0,
                     save_path=None) -> go.Figure:
    """Plot iso-surfaces of a scalar field sampled on a uniform 3D grid.

    The grid comes from ``get_grid_uniform``; the field is evaluated on it in
    chunks and rendered as a plotly ``Isosurface`` trace with three surfaces
    between ``thres`` and ``iso_max``.

    Args:
        sdf: callable taking a tensor of grid points (split along dim 0) and
            returning the field values for those points, either as a
            ``torch.Tensor`` or as something ``np.concatenate`` accepts.
        box_size (Sequence[float]): bounding box dimension. Only
            ``box_size[0]`` is forwarded (as ``box_side_length``).
        max_n_eval_pts (int | float): max number of points passed to ``sdf``
            in one call.
        iso_max (float): upper iso value (``isomax`` of the trace).
        resolution (int): grid resolution forwarded to ``get_grid_uniform``.
        thres (float): lower iso value / level set (``isomin`` of the trace).
        save_path (str | None): if given, the figure is also written there as
            HTML; a missing parent directory is created.

    Returns:
        The plotly figure.
    """
    grid = get_grid_uniform(resolution, box_side_length=box_size[0])
    points = grid['grid_points']

    # Honour the documented per-call budget (previously a hard-coded 100000).
    chunk_size = max(1, int(max_n_eval_pts))
    chunks = []
    for pnts in torch.split(points, chunk_size, dim=0):
        values = sdf(pnts)
        if isinstance(values, torch.Tensor):
            values = values.detach().cpu().numpy()
        chunks.append(values)

    field = np.concatenate(chunks, axis=0).astype(np.float32)
    points = points.cpu().numpy()
    fig = go.Figure(data=go.Isosurface(
        x=points[..., 0].flatten(),
        y=points[..., 1].flatten(),
        z=points[..., 2].flatten(),
        value=field.flatten(),
        opacity=0.6,
        isomin=thres,
        isomax=iso_max,
        surface_count=3,
        surface_fill=0.8,
        caps=dict(x_show=False, y_show=False, z_show=False)))

    if save_path is not None:
        # dirname is '' for a bare file name, and os.makedirs('') raises.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.write_html(save_path)

    return fig
Ejemplo n.º 3
0
    def get_plot(self, var, value, plottype, colorscale="Viridis", sym="F", log="F", vmin="", vmax=""):
        '''
        Return a plotly figure object for the plottype requested.
        var, value, and plottype are required variables.
        -var = name of plot variable (key into self.variables)
        -value = position of slice or isosurface value
        -plottype = 2D-alt, 2D-lon, 2D-lat, 3D-alt, iso (iso1 is also accepted)
        -colorscale = Viridis [default], Cividis, Rainbow, or BlueRed
        -sym = F [default], T for symmetric colorscale around 0
        -log = F [default], T for log10() of plot value
        -vmin,vmax = minimum and maximum value for contour values, empty is min/max

        Returns None after printing a message when a slice position is out of
        range or the plottype is unknown. The iso and iso1 plot types call
        sys.exit() when the iso value is outside the data range.
        '''

        # Common text for all plots
        txtbot = "Model: GITM v" + str(self.codeversion) + ",  Run: " + self.runname
        units = self.variables[var]['units']
        txtbar = var + " [" + units + "]"
        time = self.dtvalue.strftime("%Y/%m/%d %H:%M:%S UT")
        use_log = log == "T"
        if use_log:
            # txtbar is only ever used as the colorbar title
            txtbar = "log<br>" + txtbar

        rainbow = [[0.00, 'rgb(0,0,255)'],
                   [0.25, 'rgb(0,255,255)'],
                   [0.50, 'rgb(0,255,0)'],
                   [0.75, 'rgb(255,255,0)'],
                   [1.00, 'rgb(255,0,0)']]

        def apply_colorscale(fig):
            # Map the two custom names; anything else goes to plotly verbatim.
            if colorscale == "BlueRed":
                fig.update_traces(colorscale="RdBu", reversescale=True)
            elif colorscale == "Rainbow":
                fig.update_traces(colorscale=rainbow)
            else:
                fig.update_traces(colorscale=colorscale)

        def user_limits(cmin, cmax):
            # Replace either bound by the user supplied vmin/vmax, if given.
            if vmax != "":
                cmax = float(vmax)
            if vmin != "":
                cmin = float(vmin)
            return cmin, cmax

        def folded_limits(cmax):
            # Range symmetric about zero: vmax replaces the bound, vmin can
            # only widen it.
            if vmax != "":
                cmax = abs(float(vmax))
            if vmin != "":
                cmax = max(cmax, abs(float(vmin)))
            return -cmax, cmax

        def limits_sym_first(result):
            # Used by 2D-alt and iso1: sym wins and folds vmin/vmax into it.
            if sym == "T":
                return folded_limits(np.max(np.absolute(result)))
            return user_limits(np.min(result), np.max(result))

        def limits_user_first(result):
            # Used by 2D-lat, 2D-lon and 3D-alt: sym only applies when neither
            # vmin nor vmax is given.
            # NOTE(review): this differs from the 2D-alt/iso1 handling above;
            # kept as found -- confirm which behaviour is intended.
            if sym == "T" and vmin == "" and vmax == "":
                cmax = np.max(np.absolute(result))
                return -cmax, cmax
            return user_limits(np.min(result), np.max(result))

        def interpolate(lon, lat, alt):
            # Evaluate the variable's interpolator on (lon, lat, alt) points;
            # scalar inputs are broadcast over the whole float32 point list.
            lon, lat, alt = (np.reshape(c, -1) for c in (lon, lat, alt))
            npts = max(lon.size, lat.size, alt.size)
            pts = np.ndarray(shape=(npts, 3), dtype=np.float32)
            pts[:, 0] = lon
            pts[:, 1] = lat
            pts[:, 2] = alt
            return self.variables[var]['interpolator'](pts)

        def footer(xshift, yshift):
            # Model/run annotation anchored at the lower-left of the paper.
            return dict(text=txtbot, x=0.0, y=0.0, ax=0, ay=0, xanchor="left",
                        xshift=xshift, yshift=yshift, xref="paper", yref="paper",
                        font=dict(size=16, family="sans serif", color="#000000"))

        def finish_contour(fig, limits, xname, yname, xlabel, ylabel, title):
            # Styling shared by the three 2D slice plots.
            cmin, cmax = limits
            apply_colorscale(fig)
            fig.update_traces(
                zmin=cmin, zmax=cmax,
                ncontours=201,
                colorbar=dict(title=txtbar, tickformat=".3g"),
                contours=dict(coloring="fill", showlines=False)
            )
            shown = "log("+var+")" if use_log else var
            fig.update_traces(
                hovertemplate=xname+": %{x:.0f}<br>"+yname+": %{y:.0f}<br><b>"+shown+": %{z:.4g}</b><extra></extra>"
            )
            fig.update_layout(
                title=dict(text=title + ",  Time = " + time,
                           yref="container", yanchor="top", y=0.95),
                title_font_size=16,
                annotations=[
                    dict(text=xlabel, x=0.5, y=-0.13, showarrow=False,
                         xref="paper", yref="paper", font=dict(size=12)),
                    dict(text=ylabel, x=-0.1, y=0.5, showarrow=False,
                         xref="paper", yref="paper", font=dict(size=12), textangle=-90),
                    footer(-65, -42)
                ],
                height=340
            )
            return fig

        def altitude_levels():
            # Altitude start, stop and sample count for the iso plots: a
            # 10000-unit step starting at self.alt[2] rounded to the step and
            # ending near the third-from-last altitude.
            step = 10000.
            alt1 = (step*round(self.alt[2]/step))
            alt2 = self.alt[(self.alt.shape[0]-3)]
            nalt = round((alt2-alt1)/step)
            alt2 = alt1+step*nalt
            return alt1, alt2, 1+int(nalt)

        def label_geo_scene(fig):
            # Axis titles/ticks shared by the iso and iso1 3D scenes.
            fig.update_scenes(
                xaxis=dict(title=dict(text="Lon [degrees]"),tick0=0.,dtick=45.),
                yaxis=dict(title=dict(text="Lat [degrees]"),tick0=0.,dtick=45.),
                zaxis=dict(title=dict(text="Alt [km]"))
            )

        if plottype == "2D-alt":
            # Check if altitude entered is valid
            if value < self.altmin or value > self.altmax:
                print('Altitude is out of range: alt=',value,\
                      ' min/max=',self.altmin,'/',self.altmax)
                return

            # Set grid for plot
            altkm = value/1000.
            ilon = np.linspace(0, 360, 361)
            ilat = np.linspace(-90, 90, 181)
            xx, yy = np.meshgrid(ilon, ilat)
            result = np.reshape(interpolate(xx, yy, value),
                                (ilat.shape[0], ilon.shape[0]))
            if use_log:
                result = np.log10(result)
            limits = limits_sym_first(result)

            # The keyword names/defaults are what Kamodo receives as axes.
            def plot_var(lon = ilon, lat = ilat):
                return result
            plotvar = Kamodo(plot_var = plot_var)

            fig = plotvar.plot(plot_var = dict())
            fig.update_xaxes(tick0=0.,dtick=45.,title_text="")
            fig.update_yaxes(tick0=0.,dtick=45,title_text="")
            return finish_contour(
                fig, limits, "Lon", "Lat", "Lon [degrees]", "Lat [degrees]",
                "Altitude="+"{:.0f}".format(altkm)+" km")

        if plottype == "2D-lat":
            # Check if latitude entered is valid
            if value < -90. or value > 90.:
                print('Latitude is out of range: lat=',value,' min/max= -90./90.')
                return

            # Set grid for plot
            ilon = np.linspace(0, 360, 361)
            ialt = np.linspace(self.altmin, self.altmax, 300)
            xx, yy = np.meshgrid(ilon, ialt)
            result = np.reshape(interpolate(xx, value, yy),
                                (ialt.shape[0], ilon.shape[0]))
            if use_log:
                result = np.log10(result)
            limits = limits_user_first(result)

            # Axis values are shown divided by 1000 ("Altitude [km]" label)
            ialt = ialt/1000.
            def plot_var(lon = ilon, alt = ialt):
                return result
            plotvar = Kamodo(plot_var = plot_var)

            fig = plotvar.plot(plot_var = dict())
            fig.update_xaxes(tick0=0.,dtick=45.,title_text="")
            fig.update_yaxes(title_text="")
            return finish_contour(
                fig, limits, "Lon", "Alt", "Lon [degrees]", "Altitude [km]",
                "Latitude="+"{:.1f}".format(value)+" degrees")

        if plottype == "2D-lon":
            # Check if longitude entered is valid
            if value < 0. or value > 360.:
                print('Longitude is out of range: lon=',value,' min/max= 0./360.')
                return

            # Set grid for plot
            ilat = np.linspace(-90, 90, 181)
            ialt = np.linspace(self.altmin, self.altmax, 300)
            xx, yy = np.meshgrid(ilat, ialt)
            result = np.reshape(interpolate(value, xx, yy),
                                (ialt.shape[0], ilat.shape[0]))
            if use_log:
                result = np.log10(result)
            limits = limits_user_first(result)

            # Axis values are shown divided by 1000 ("Altitude [km]" label)
            ialt = ialt/1000.
            def plot_var(lat = ilat, alt = ialt):
                return result
            plotvar = Kamodo(plot_var = plot_var)

            fig = plotvar.plot(plot_var = dict())
            fig.update_xaxes(tick0=0.,dtick=30.,title_text="")
            fig.update_yaxes(title_text="")
            return finish_contour(
                fig, limits, "Lat", "Alt", "Lat [degrees]", "Altitude [km]",
                "Longitude="+"{:.1f}".format(value)+" degrees")

        if plottype == "3D-alt":
            # Check if altitude entered is valid
            if value < self.altmin or value > self.altmax:
                print('Altitude is out of range: alt=',value,\
                      ' min/max=',self.altmin,'/',self.altmax)
                return

            # Set grid for plot
            altkm = value/1000.
            ilon = np.linspace(0, 360, 361)
            ilat = np.linspace(-90, 90, 181)
            xx, yy = np.meshgrid(ilon, ilat)
            result = np.reshape(interpolate(xx, yy, value),
                                (ilat.shape[0], ilon.shape[0]))
            if use_log:
                result = np.log10(result)

            # Spherical shell at the requested altitude; coordinates are
            # divided by 6.3781E6 so the axes read in "Re" (see axis titles).
            r = value + 6.3781E6
            x=-(r*np.cos(yy*np.pi/180.)*np.cos(xx*np.pi/180.))/6.3781E6
            y=-(r*np.cos(yy*np.pi/180.)*np.sin(xx*np.pi/180.))/6.3781E6
            z= (r*np.sin(yy*np.pi/180.))/6.3781E6
            cmin, cmax = limits_user_first(result)

            def plot_var(x = x, y = y, z = z):
                return result
            plotvar = Kamodo(plot_var = plot_var)

            fig = plotvar.plot(plot_var = dict())
            fig.update_scenes(xaxis=dict(title=dict(text="X [Re]")),
                              yaxis=dict(title=dict(text="Y [Re]")),
                              zaxis=dict(title=dict(text="Z [Re]")))
            apply_colorscale(fig)
            fig.update_traces(
                cmin=cmin, cmax=cmax,
                colorbar=dict(title=txtbar, tickformat=".3g")
            )
            fig.update_traces(
                hovertemplate="X [Re]: %{x:.3f}<br>Y [Re]: %{y:.3f}<br>Z [Re]: %{z:.3f}<extra></extra>"
            )
            fig.update_layout(
                title=dict(text="Altitude="+"{:.0f}".format(altkm)+" km,  Time = " + time,
                           yref="container", yanchor="top", x=0.01, y=0.95),
                title_font_size=16,
                annotations=[footer(0, -20)],
                margin=dict(l=0),
                width=600
            )

            # Reference lines: polar axis, then equator and one meridian drawn
            # slightly (10000.) above the plotted shell.
            x1=[0., 0.]
            y1=[0., 0.]
            z1=[-1.2, 1.2]
            fig.add_scatter3d(mode='lines',x=x1,y=y1,z=z1,line=dict(width=4,color='black'),
                              showlegend=False,hovertemplate='Polar Axis<extra></extra>')
            r = value + 10000. + 6.3781E6
            x2=-(r*np.cos(ilon*np.pi/180.))/6.3781E6
            y2=-(r*np.sin(ilon*np.pi/180.))/6.3781E6
            z2=0.*ilon
            fig.add_scatter3d(mode='lines',x=x2,y=y2,z=z2,line=dict(width=2,color='black'),
                              showlegend=False,hovertemplate='Equator<extra></extra>')
            x3=-(r*np.cos(ilat*np.pi/180.))/6.3781E6
            y3=0.*ilat
            z3=(r*np.sin(ilat*np.pi/180.))/6.3781E6
            fig.add_scatter3d(mode='lines',x=x3,y=y3,z=z3,line=dict(width=2,color='black'),
                              showlegend=False,hovertemplate='prime meridian<extra></extra>')
            return fig

        if plottype == "iso":
            # Check if value entered is valid (checking before possible log scale)
            cmin=np.min(self.variables[var]['data'])
            cmax=np.max(self.variables[var]['data'])
            if value < cmin or value > cmax:
                print('Iso value is out of range: iso=',value,' min/max=',cmin,cmax)
                # NOTE(review): terminates the interpreter, unlike the slice
                # branches which just return -- kept as found.
                sys.exit("Exiting ...")

            # Determine altitude start, stop, number for use in isosurface
            alt1, alt2, nalt = altitude_levels()

            # Set function for interpolation (single reusable 1x3 point)
            gridp = np.ndarray(shape=(1,3), dtype=np.float32)
            def finterp(x,y,z):
                gridp[0,:] = [x,y,z]
                return self.variables[var]['interpolator'](gridp)

            # Extract the isosurface as vertices and triangulated connectivity
            import mcubes
            verts, tri = mcubes.marching_cubes_func(
                (0, -90, alt1), (360, 90, alt2),  # Bounds (min:x,y,z), (max:x,y,z)
                73, 37, nalt,                     # Number of samples in each dimension
                finterp,                          # Implicit function of x,y,z
                value)                            # Isosurface value

            # Process output for creating plots, including face or vertex colors
            X, Y, Z = verts[:,:3].T
            I, J, K = tri.T

            # Update cmin, cmax for log, sym, vmin, vmax values if set
            if sym == "T":
                cmin, cmax = folded_limits(cmax)
            else:
                cmin, cmax = user_limits(cmin, cmax)
            dvalue=value
            if use_log:
                dvalue=np.log10(value)
                cmin=np.log10(cmin)
                cmax=np.log10(cmax)

            # Create fig and update with modifications to customize plot
            fig=go.Figure(data=[go.Mesh3d(
                name='ISO',
                x=X,
                y=Y,
                z=Z,
                i=I,
                j=J,
                k=K,
                # constant intensity: the whole surface gets the iso value's color
                intensity=np.linspace(dvalue,dvalue,X.shape[0]),
                cmin=cmin, cmax=cmax,
                showscale=True,
                opacity=0.6,
                colorbar=dict(title=txtbar, tickformat=".3g"),
                hovertemplate="Isosurface<br>"+var+"="+"{:.3g}".format(value)+" "+units+"<br><extra></extra>",
            )])
            apply_colorscale(fig)
            label_geo_scene(fig)
            fig.update_layout(
                scene_camera_eye=dict(x=.1, y=-1.8, z=1.5),
                scene_aspectmode='manual',
                scene_aspectratio=dict(x=2, y=1, z=1),
                scene_xaxis=dict(range=[0,360]),
                scene_yaxis=dict(range=[-90,90]),
                scene_zaxis=dict(range=[alt1,alt2]),
                title=dict(text="Isosurface of "+var+"="+\
                           "{:.3g}".format(value)+" "+units+"<br>"+"Time = "+time,
                           yref="container", yanchor="top", x=0.01, y=0.95),
                title_font_size=16,
                annotations=[footer(0, -20)],
                margin=dict(l=10,t=80),
            )

            return fig

        if plottype == "iso1":
            # This method keeps all the data local, resulting in huge storage
            #   as well as corrupted html divs.

            ilon = np.linspace(0, 360, 73) #181
            ilat = np.linspace(-90, 90, 37) #91
            alt1, alt2, nalt = altitude_levels()
            ialt = np.linspace(alt1, alt2, nalt)
            xx,yy,zz = np.meshgrid(ilon, ilat, ialt)
            test = interpolate(xx, yy, zz)
            result = np.reshape(test,(ilat.shape[0],ilon.shape[0],ialt.shape[0]))
            isovalue=value
            if use_log:
                isovalue=np.log10(value)
                result = np.log10(result)
            cmin, cmax = limits_sym_first(result)

            # Check if value entered is valid (checking before possible log scale)
            if value < np.min(test) or value > np.max(test):
                print('Iso value is out of range: iso=',value,\
                      ' min/max=',np.min(test),'/',np.max(test))
                # NOTE(review): terminates the interpreter -- kept as found.
                sys.exit("Exiting ...")

            # fig1: colored slice at latitude 0 through the volume
            slicevalue=0.
            fig1 = go.Figure(data=go.Isosurface(
                x=xx.flatten(),
                y=yy.flatten(),
                z=zz.flatten()/1000.,
                value=result.flatten(),
                opacity=0.6,
                isomin=cmin,
                isomax=cmax,
                surface=dict(count=2, fill=1., pattern='all'),
                caps=dict(x_show=False, y_show=False, z_show=False),
                showscale=True, # show colorbar
                colorbar=dict(title=txtbar, tickformat=".3g"),
                slices_y=dict(show=True, locations=[slicevalue]),
            ))
            fig1.update_traces(
                hovertemplate="<b>Slice</b><br>Lon: %{x:.0f}<br>Lat: %{y:.0f}<br>Alt: %{z:.0f}km<br><extra></extra>"
            )
            apply_colorscale(fig1)

            # fig2: the single grey isosurface; fig1's trace is added to it
            fig2 = go.Figure(data=go.Isosurface(
                x=xx.flatten(),
                y=yy.flatten(),
                z=zz.flatten()/1000.,
                value=result.flatten(),
                opacity=1.,
                colorscale=[[0.0, '#777777'],[1.0, '#777777']],
                isomin=isovalue,
                isomax=isovalue,
                surface=dict(count=1, fill=1., pattern='all'),
                caps=dict(x_show=False, y_show=False, z_show=False),
                showscale=False, # remove colorbar
                hovertemplate="<b>Isosurface</b><br>Lon: %{x:.0f}<br>Lat: %{y:.0f}<br>Alt: %{z:.0f}km<extra></extra>"
            ))
            label_geo_scene(fig2)
            fig2.update_layout(
                scene_camera_eye=dict(x=.1, y=-1.8, z=1.5),
                scene_aspectmode='manual',
                scene_aspectratio=dict(x=2, y=1, z=1),
                title=dict(text="Latitude="+"{:.0f}".format(slicevalue)+
                           " slice through the data<br>"+
                           "Isosurface of "+var+"="+"{:.0f}".format(isovalue)+units+"<br>"+
                           "Time = " + time,
                           yref="container", yanchor="top", x=0.01, y=0.95),
                title_font_size=16,
                annotations=[footer(0, -20)],
                margin=dict(l=10,t=80),
            )
            fig2.add_trace(fig1.data[0])

            return fig2

        print('Unknown plottype (',plottype,') returning.')
        return
Ejemplo n.º 4
0
# Helix sampled at 50 parameter values, drawn as 3D markers; the figure is
# written to an HTML file and then displayed.
t = np.linspace(0, 10, 50)
x = np.cos(t)
y = np.sin(t)
z = t
fig = go.Figure(data=[go.Scatter3d(mode='markers', x=x, y=y, z=z)])
pyo.plot(fig, filename="3dscatter.html")
fig.show()

# In[104]:

# Quadratic scalar field on a 40x40x40 grid spanning [-5, 5] on every axis,
# rendered as five iso-surfaces between the values 10 and 50.
X, Y, Z = np.mgrid[-5:5:40j, -5:5:40j, -5:5:40j]
values = 0.5 * X**2 + Y**2 + 2 * Z**2
fig = go.Figure(
    data=go.Isosurface(
        value=values.flatten(),
        x=X.flatten(),
        y=Y.flatten(),
        z=Z.flatten(),
        isomin=10,
        isomax=50,
        # number of isosurfaces, 2 by default: only min and max
        surface=dict(count=5),
        # colorbar ticks correspond to isosurface values
        colorbar=dict(nticks=5),
        caps=dict(x=dict(show=False), y=dict(show=False)),
    )
)
pyo.plot(fig, filename="isospace.html")
fig.show()

# In[105]:

# Three seed points along the y axis with a uniform unit vector field in +y.
fig = go.Figure(
    data=go.Streamtube(
        x=[0, 0, 0], y=[0, 1, 2], z=[0, 0, 0],
        u=[0, 0, 0], v=[1, 1, 1], w=[0, 0, 0],
    )
)
Ejemplo n.º 5
0
 def update_figure(invar, invar_2, invar_3, outvar, invar1_log, invar2_log, invar3_log, outvar_log, param_slider,
                   graph_type, color_use, color_dd, error_use, error_dd, filter_active, fit_use, fit_dd, fit_num, fit_conf, add_noise_var, fit_color,
                   fit_opacity, fit_sampling, id_type, param_center, param_log):
     """Assemble the plotly figure for the selected graph type, filters, fit and colouring.

     Besides its arguments the function reads these names from the enclosing
     scope: ``indata``, ``outdata``, ``invars``, ``mesh_fit``, ``colormap``,
     ``log10`` and ``graph_height``.

     Args:
         invar, invar_2, invar_3: Field names used to index ``indata`` (and
             looked up in ``invars``) for the x / y / z axes. If ``invar`` is
             ``None`` an empty figure is returned immediately.
         outvar: Field name used to index ``outdata``.
         invar1_log, invar2_log, invar3_log, outvar_log: Axis-scale flags.
             ``None`` or an empty sequence means ``'linear'``; otherwise the
             first element is used as the plotly axis type (``['log']`` is
             the value the code compares against).
         param_slider: Sequence of ``[lower, upper]`` bounds, one entry per
             parameter. Entries whose ``param_log`` flag is ``['log']`` are
             replaced IN PLACE by ``10**value``.
         graph_type: ``'1D'``, ``'2D'``, ``'2D contour'`` or ``'3D'``. Any
             other value yields an empty figure (only height/legend set).
         color_use: ``['true']`` enables marker colouring.
         color_dd: Field to colour markers by, or ``'OUTPUT'`` for ``outvar``.
         error_use: ``['true']`` makes the error bars visible.
         error_dd: Field of ``outdata`` holding the error-bar values.
         filter_active: Per-parameter flags; ``['act']`` applies that
             parameter's slider range as a data filter. ``[[]]`` disables
             filtering altogether.
         fit_use: ``['show']`` adds the fit traces ('1D', '2D', '3D').
         fit_dd, fit_num, fit_sampling, add_noise_var: Forwarded to
             ``mesh_fit``. ``fit_dd`` also labels the fit traces, and
             ``fit_sampling`` is the side length used to reshape 2D meshes.
         fit_conf: Multiplier applied to the fit's standard deviation to draw
             the confidence band / surfaces.
         fit_color: ``'multi-fit'``, ``'marker-color'`` or ``'output'``;
             selects the surface colouring in the 2D view.
         fit_opacity: Opacity of fit bands / surfaces / isosurfaces.
         id_type: Sequence of mappings whose ``'index'`` entry is the position
             of the slider's parameter in ``invars``.
         param_center: Per-parameter centre values forwarded to ``mesh_fit``;
             modified in place like ``param_slider`` for log parameters.
         param_log: Per-parameter flags, ``['log']`` for log-scaled sliders.

     Returns:
         A ``go.Figure`` whose height is ``graph_height``.
     """
     # Sliders flagged as logarithmic carry exponents; convert them to values.
     for i, bounds in enumerate(param_slider):
         if param_log[i] == ['log']:
             param_slider[i] = [10**val for val in bounds]
             param_center[i] = 10**param_center[i]
     if invar is None:
         return go.Figure()

     # Boolean row mask: start with everything selected and AND in each
     # active slider range.
     sel_y = np.full((len(outdata),), True)
     for iteration, bounds in enumerate(param_slider):
         param_name = invars[id_type[iteration]['index']]
         above_min = np.array(indata[param_name] >= bounds[0])
         below_max = np.array(indata[param_name] <= bounds[1])
         if filter_active != [[]] and filter_active[iteration] == ['act']:
             sel_y = above_min & below_max & sel_y

     if graph_type == '1D':
         fig = go.Figure(
             data=[go.Scatter(
                 x=indata[invar][sel_y],
                 y=outdata[outvar][sel_y],
                 mode='markers',
                 name='data',
                 error_y=dict(type='data', array=outdata[error_dd][sel_y], visible=error_use == ['true']),
             )],
             layout=go.Layout(xaxis=dict(title=invar, rangeslider=dict(visible=True)), yaxis=dict(title=outvar))
         )
         if fit_use == ['show']:
             mesh_in, mesh_out, mesh_out_std, fit_dd_values = mesh_fit(param_slider, id_type, fit_dd, fit_num,
                                                                       param_center, [invar], [invar1_log],
                                                                       outvar, fit_sampling, add_noise_var)
             for i, fit_value in enumerate(fit_dd_values):
                 fit_x = mesh_in[i][invars.index(invar)]
                 fit_line_color = colormap(indata[fit_dd].min(), indata[fit_dd].max(), fit_value)
                 # Mean prediction.
                 fig.add_trace(go.Scatter(
                     x=fit_x,
                     y=mesh_out[i],
                     mode='lines',
                     name=f'fit: {fit_dd}={fit_value:.1e}',
                     line_color=fit_line_color,
                     marker_line=dict(coloraxis="coloraxis2"),
                 ))
                 # Confidence band as one closed polygon: upper edge left to
                 # right, then lower edge right to left.
                 fig.add_trace(go.Scatter(
                     x=np.hstack((fit_x, fit_x[::-1])),
                     y=np.hstack((mesh_out[i] + fit_conf * mesh_out_std[i],
                                  mesh_out[i][::-1] - fit_conf * mesh_out_std[i][::-1])),
                     showlegend=False,
                     fill='toself',
                     line_color=fit_line_color,
                     marker_line=dict(coloraxis="coloraxis2"),
                     opacity=fit_opacity,
                 ))
     elif graph_type == '2D':
         fig = go.Figure(
             data=[go.Scatter3d(
                 x=indata[invar][sel_y],
                 y=indata[invar_2][sel_y],
                 z=outdata[outvar][sel_y],
                 mode='markers',
                 name='Data',
                 error_z=dict(type='data', array=outdata[error_dd][sel_y], visible=error_use == ['true'], width=10)
             )],
             layout=go.Layout(scene=dict(xaxis_title=invar, yaxis_title=invar_2, zaxis_title=outvar))
         )
         if fit_use == ['show'] and invar != invar_2:
             mesh_in, mesh_out, mesh_out_std, fit_dd_values = mesh_fit(param_slider, id_type, fit_dd, fit_num,
                                                                       param_center, [invar, invar_2],
                                                                       [invar1_log, invar2_log], outvar,
                                                                       fit_sampling, add_noise_var)
             grid_shape = (fit_sampling, fit_sampling)
             # The colour axis is the same for every surface of this figure.
             if fit_color == 'multi-fit' or (
                     fit_color == 'output' and (color_dd != outvar and color_dd != 'OUTPUT')):
                 surface_coloraxis = "coloraxis2"
             else:
                 surface_coloraxis = "coloraxis"
             for i, fit_value in enumerate(fit_dd_values):
                 grid_x = mesh_in[i][invars.index(invar)].reshape(grid_shape)
                 grid_y = mesh_in[i][invars.index(invar_2)].reshape(grid_shape)
                 grid_z = mesh_out[i].reshape(grid_shape)
                 # Surface colour: constant per fit, the marker-colour input,
                 # or (default) the predicted output itself.
                 if fit_color == 'multi-fit':
                     surface_color = fit_value * np.ones([fit_sampling, fit_sampling])
                 elif fit_color == 'marker-color' and color_dd in invars:
                     surface_color = mesh_in[i][invars.index(color_dd)].reshape(grid_shape)
                 else:
                     surface_color = grid_z
                 fig.add_trace(go.Surface(
                     x=grid_x,
                     y=grid_y,
                     z=grid_z,
                     name=f'fit: {fit_dd}={fit_value:.2f}',
                     surfacecolor=surface_color,
                     opacity=fit_opacity,
                     coloraxis=surface_coloraxis,
                     showlegend=len(invars) > 2,
                 ))
                 if fit_conf > 0:
                     # Upper and lower confidence surfaces share the mean
                     # surface's colouring.
                     band = fit_conf * mesh_out_std[i].reshape(grid_shape)
                     for prefix, shifted_z in (('fit+v', grid_z + band), ('fit-v', grid_z - band)):
                         fig.add_trace(go.Surface(
                             x=grid_x,
                             y=grid_y,
                             z=shifted_z,
                             showlegend=False,
                             name=f'{prefix}: {fit_dd}={fit_value:.2f}',
                             surfacecolor=surface_color,
                             opacity=fit_opacity,
                             coloraxis=surface_coloraxis,
                         ))
             fig.update_layout(coloraxis2=dict(
                 colorbar=dict(title=outvar if fit_color == 'output' else fit_dd),
                 cmin=min(fit_dd_values) if fit_color == 'multi-fit' else None,
                 cmax=max(fit_dd_values) if fit_color == 'multi-fit' else None,
             ))
     elif graph_type == '2D contour':
         mesh_in, mesh_out, mesh_out_std, fit_dd_values = mesh_fit(param_slider, id_type, fit_dd, fit_num,
                                                                   param_center, [invar, invar_2],
                                                                   [invar1_log, invar2_log], outvar,
                                                                   fit_sampling, add_noise_var)
         data_x = mesh_in[0][invars.index(invar)]
         data_y = mesh_in[0][invars.index(invar_2)]
         fig = go.Figure()
         # A contour needs variation along both axes; otherwise only report why.
         if min(data_x) == max(data_x):
             fig.update_layout(title="x-data is constant, no contour-plot possible")
         elif min(data_y) == max(data_y):
             fig.update_layout(title="y-data is constant, no contour-plot possible")
         else:
             fig.add_trace(go.Scatter(
                 x=indata[invar][sel_y],
                 y=indata[invar_2][sel_y],
                 mode='markers',
                 name='Data',
             ))
             fig.add_trace(go.Contour(
                 x=data_x,
                 y=data_y,
                 z=mesh_out[0],
                 contours_coloring='heatmap',
                 contours_showlabels=True,
                 coloraxis='coloraxis2',
                 name='fit',
             ))
             # Clamp the axes to the extent of the fitted mesh (trace 1); log
             # axes take their range as exponents.
             contour = fig.data[1]
             fig.update_xaxes(
                 range=[log10(min(contour['x'])), log10(max(contour['x']))] if invar1_log == ['log']
                 else [min(contour['x']), max(contour['x'])])
             fig.update_yaxes(
                 range=[log10(min(contour['y'])), log10(max(contour['y']))] if invar2_log == ['log']
                 else [min(contour['y']), max(contour['y'])])
             fig.update_layout(xaxis_title=invar,
                               yaxis_title=invar_2,
                               coloraxis2=dict(colorbar=dict(title=outvar),
                                               colorscale='solar',
                                               cmin=min(contour['z']),
                                               cmax=max(contour['z'])))
     elif graph_type == '3D':
         fig = go.Figure(
             data=go.Scatter3d(
                 x=indata[invar][sel_y],
                 y=indata[invar_2][sel_y],
                 z=indata[invar_3][sel_y],
                 mode='markers',
                 marker=dict(
                     color=outdata[outvar][sel_y],
                     coloraxis="coloraxis2",
                 ),
                 name='Data',
             ),
             layout=go.Layout(scene=dict(xaxis_title=invar, yaxis_title=invar_2, zaxis_title=invar_3)),
         )
         fig.update_layout(coloraxis2=dict(
             colorbar=dict(title=outvar),
         ))
         # The fit is only drawn when the three axes are distinct inputs.
         if fit_use == ['show'] and len({invar, invar_2, invar_3}) == 3:
             mesh_in, mesh_out, mesh_out_std, fit_dd_values = mesh_fit(param_slider, id_type, fit_dd, fit_num,
                                                                       param_center, [invar, invar_2, invar_3],
                                                                       [invar1_log, invar2_log, invar3_log], outvar,
                                                                       fit_sampling, add_noise_var)
             for i in range(len(fit_dd_values)):
                 fig.add_trace(
                     go.Isosurface(
                         x=mesh_in[i][invars.index(invar)],
                         y=mesh_in[i][invars.index(invar_2)],
                         z=mesh_in[i][invars.index(invar_3)],
                         value=mesh_out[i],
                         surface_count=fit_num,
                         coloraxis="coloraxis2",
                         # NOTE(review): these factors narrow the range only
                         # for positive values; for negative outputs they
                         # widen it. Confirm this is intended.
                         isomin=mesh_out[i].min() * 1.1,
                         isomax=mesh_out[i].max() * 0.9,
                         caps=dict(x_show=False, y_show=False, z_show=False),
                         opacity=fit_opacity,
                     ),
                 )
     else:
         fig = go.Figure()
     fig.update_layout(legend=dict(xanchor="left", x=0.01))

     # log scale
     log_dict = {'1D': (invar1_log, outvar_log),
                 '2D': (invar1_log, invar2_log, outvar_log),
                 '2D contour': (invar1_log, invar2_log),
                 '3D': (invar1_log, invar2_log, invar3_log)}
     # An unknown graph type produced the empty fallback figure above; it has
     # no axes to configure, so use no flags instead of raising KeyError.
     axis_flags = log_dict.get(graph_type, ())
     log_list = ['linear' if flag is None or len(flag) == 0 else flag[0] for flag in axis_flags]
     comb_dict = dict(zip(('xaxis', 'yaxis', 'zaxis'), [{'type': axis_type} for axis_type in log_list]))
     # Two flags -> cartesian axes on the layout; three -> axes of the 3D scene.
     if len(log_list) < 3:
         fig.update_layout(**comb_dict)
     else:
         fig.update_scenes(**comb_dict)

     # color
     def _marker_color_values():
         # Values the data markers are coloured by: the output, an input
         # column, or another output column, in that order of precedence.
         if color_dd == 'OUTPUT':
             return outdata[outvar][sel_y]
         if color_dd in indata.dtype.names:
             return indata[color_dd][sel_y]
         return outdata[color_dd][sel_y]

     if color_use == ['true']:
         if fit_use == ['show'] and (graph_type == '2D' and (fit_color == 'multi-fit' and color_dd == fit_dd)):
             # Markers share the fit surfaces' colour axis.
             fig.update_traces(
                 marker=dict(
                     coloraxis="coloraxis2",
                     color=indata[color_dd][sel_y] if color_dd in indata.dtype.names else outdata[color_dd][sel_y],
                 ),
                 selector=dict(mode='markers'),
             )
         elif graph_type == '3D':
             fig.update_traces(
                 marker=dict(
                     coloraxis="coloraxis2",
                     color=outdata[outvar][sel_y],
                 ),
                 selector=dict(mode='markers'),
             )
         elif graph_type == '1D':
             fig.update_traces(
                 marker=dict(
                     coloraxis="coloraxis2",
                     color=_marker_color_values(),
                 ),
                 selector=dict(mode='markers'),
             )
             if color_dd == fit_dd:
                 fig.update_layout(coloraxis2=dict(colorscale='cividis', colorbar=dict(title=fit_dd)))
             elif color_dd == 'OUTPUT':
                 fig.update_layout(coloraxis2=dict(colorscale='plasma', colorbar=dict(title=outvar)))
             else:
                 fig.update_layout(coloraxis2=dict(colorscale='plasma', colorbar=dict(title=color_dd)))
         elif graph_type == '2D contour':
             fig.update_traces(
                 marker=dict(
                     coloraxis="coloraxis",
                     color=_marker_color_values(),
                 ),
                 selector=dict(mode='markers'),
             )
             if color_dd == outvar or color_dd == 'OUTPUT':
                 # Same quantity as the contour: reuse its colour axis.
                 fig.update_traces(marker_coloraxis="coloraxis2", selector=dict(mode='markers'))
             else:
                 fig.update_layout(coloraxis=dict(colorbar=dict(title=color_dd, x=1.1),
                                                  colorscale='ice'))
         else:
             fig.update_traces(
                 marker=dict(
                     coloraxis="coloraxis",
                     color=_marker_color_values(),
                 ),
                 selector=dict(mode='markers'),
             )
             fig.update_layout(coloraxis=dict(
                 colorbar=dict(title=outvar if color_dd == 'OUTPUT' else color_dd, x=1.1),
                 colorscale='viridis',
             ))
     fig.update_layout(height=graph_height)
     return fig
Ejemplo n.º 6
0
def voxel2plotly(neuron,
                 legendgroup,
                 showlegend,
                 label,
                 color,
                 as_scatter=True,
                 **kwargs):
    """Convert VoxelNeuron to plotly object.

    Turns out that plotly is horrendous for plotting voxel data (Volumes):
    anything more than a few thousand voxels (e.g. 40x40x40) and the html
    encoding and loading the plot takes ages. Unfortunately, the same happens
    with Isosurfaces.

    I'm adding an implementation here but until plotly gets MUCH better at this,
    there is really no point. For now, we will fallback to plotting the
    voxels as scatter plots using the top 100k voxels sorted by brightness.

    Args:
        neuron: Object providing ``shape``, ``units_xyz`` and ``offset``;
            additionally ``voxels`` and ``values`` for the scatter path,
            ``grid`` for the isosurface path and ``label`` when hover names
            are requested.
        legendgroup: Legend group of the scatter trace.
        showlegend: Whether the scatter trace appears in the legend.
        label: Name of the scatter trace.
        color: Accepted for signature compatibility; not applied to either
            trace (the scatter markers are coloured by voxel value).
        as_scatter: If True (default) return a ``go.Scatter3d`` of the
            brightest voxels, otherwise a ``go.Isosurface`` of the
            downsampled grid. ``legendgroup``, ``showlegend`` and ``label``
            are only used by the scatter trace.
        **kwargs: Only ``hover_name`` is read; if truthy, ``neuron.label`` is
            used as hover text.

    Returns:
        A list with a single plotly trace, or an empty list if any dimension
        of ``neuron.shape`` is 0.
    """
    # Skip empty neurons
    if min(neuron.shape) == 0:
        return []

    if kwargs.get('hover_name', False):
        hoverinfo = 'text'
        hovertext = neuron.label
    else:
        hoverinfo = 'none'
        hovertext = ' '

    if not as_scatter:
        # Downsample heavily
        ds = ndimage.zoom(neuron.grid, .2, order=1)

        # Generate X, Y, Z, coordinates for values in grid
        X, Y, Z = np.meshgrid(range(ds.shape[0]),
                              range(ds.shape[1]),
                              range(ds.shape[2]),
                              indexing='ij')

        # Flatten and scale coordinates. Use the plain magnitudes, as the
        # scatter path below does, so unit-carrying values are not mixed
        # with the offset.
        # NOTE(review): indices refer to the downsampled grid, so the 0.2
        # zoom is not compensated in the voxel size here - confirm whether
        # the coordinates should be scaled accordingly.
        units = neuron.units_xyz.magnitude
        X = X.flatten() * units[0] + neuron.offset[0]
        Y = Y.flatten() * units[1] + neuron.offset[1]
        Z = Z.flatten() * units[2] + neuron.offset[2]

        # Flatten and normalize values
        values = ds.flatten() / ds.max()

        trace_data = [
            go.Isosurface(
                x=X,
                y=Y,
                z=Z,
                value=values,
                isomin=0.001,
                isomax=1,
                opacity=0.1,
                surface_count=21,
            )
        ]
    else:
        voxels, values = neuron.voxels, neuron.values

        # Sort by brightness (ascending) and keep the top 100k voxels
        brightest = np.argsort(values)[-100000:]
        values = values[brightest]
        voxels = voxels[brightest]

        # Scale and offset voxels
        voxels = voxels * neuron.units_xyz.magnitude + neuron.offset

        with warnings.catch_warnings():
            trace_data = [
                go.Scatter3d(x=voxels[:, 0],
                             y=voxels[:, 1],
                             z=voxels[:, 2],
                             mode='markers',
                             marker=dict(color=values,
                                         size=4,
                                         colorscale='viridis',
                                         opacity=.1),
                             name=label,
                             legendgroup=legendgroup,
                             showlegend=showlegend,
                             hovertext=hovertext,
                             hoverinfo=hoverinfo)
            ]

    return trace_data