예제 #1
0
def add_surfaces_to_plotter(srf_lst: list,
                            plotter: pv.Plotter,
                            col_lst: list = None,
                            def_col='#33a1c9',
                            per=0,
                            use_transparency=False,
                            opacity=1.0) -> None:
    """Add surfaces to a PyVista plotter, optionally replicated by rotation about z.

    Args:
        srf_lst: surfaces to draw; converted to meshes with ``mesh_surfaces``.
        plotter: plotter that receives the meshes.
        col_lst: per-surface colours, matched to ``srf_lst`` by position. Surfaces
            beyond the end of this list are drawn in ``def_col``. The list passed
            in is never modified.
        def_col: fallback colour.
        per: when greater than 1, each surface is additionally drawn ``per - 1``
            times, rotated about the z axis by ``360 * i / per`` degrees.
        use_transparency: forwarded to ``plotter.add_mesh`` for every copy.
        opacity: forwarded to ``plotter.add_mesh`` for every copy.
    """
    given_colors = list(col_lst) if col_lst is not None else []
    # Pad with the default colour on a fresh list instead of appending to the
    # argument: the old code grew the shared default list across calls.
    colors = given_colors + [def_col] * (len(srf_lst) - len(given_colors))

    for msh, color in zip(mesh_surfaces(srf_lst), colors):
        plotter.add_mesh(msh,
                         color=color,
                         smooth_shading=True,
                         culling=False,
                         use_transparency=use_transparency,
                         opacity=opacity)
        for i in range(1, per):
            # Shallow copy to avoid duplicating the mesh data.
            # NOTE(review): confirm that rotating a shallow copy in place does not
            # also move the shared points of ``msh`` in the installed PyVista.
            msh_per = msh.copy(False)
            msh_per.rotate_z((360. * i) / per, inplace=True)
            # Periodic copies are rendered with the same transparency settings
            # as the original surface.
            plotter.add_mesh(msh_per,
                             color=color,
                             smooth_shading=True,
                             culling=False,
                             use_transparency=use_transparency,
                             opacity=opacity)
예제 #2
0
        def create_mesh(plotter:pv.Plotter,
                        value:float,
                        ):
            """Redraw the iso-surface closest to ``value`` (plus arrow glyphs if applicable).

            Reads ``energy_values``, ``e_surfaces``, ``mode``, ``spin_texture``,
            ``arrow_color``, ``arrow_size`` and ``cmap`` from the enclosing scope.
            The iso-surface is added under the fixed name ``'iso_surface'``.

            Args:
                plotter: plotter to draw into.
                value: requested value; the nearest entry of ``energy_values`` is used.

            Returns:
                None
            """
            # NOTE(review): int() truncates toward zero before the nearest-value
            # lookup; if energy_values are fractional this discards resolution.
            # Kept as-is; confirm intent.
            res = int(value)
            closest_idx = find_nearest(energy_values, res)
            surface = e_surfaces[closest_idx]
            options_dict = self.__plot_options_helper(mode=mode,
                                                      calculate_fermi_speed=calculate_fermi_speed,
                                                      calculate_fermi_velocity=calculate_fermi_velocity,
                                                      calculate_effective_mass=calculate_effective_mass)

            plotter.add_mesh(surface,
                             name='iso_surface',
                             scalars=options_dict['scalars'],
                             show_scalar_bar=False)

            if mode != "plain" or spin_texture:
                plotter.add_scalar_bar(
                    title=options_dict['text'],
                    n_labels=6,
                    italic=False,
                    bold=False,
                    title_font_size=None,
                    label_font_size=None,
                    position_x=0.4,
                    position_y=0.01,
                    color="black",)

            if options_dict['scalars'] in ("spin", "Fermi Velocity Vector_magnitude"):
                # Build the glyphs for both colouring branches. Previously they were
                # only built when arrow_color was None, so a custom arrow colour hit
                # an unassigned ``arrows``.
                arrows = surface.glyph(orient=options_dict['vector_name'],
                                       scale=False,
                                       factor=arrow_size)

                # Replace the previous arrow actor. The iso-surface actor is named
                # 'iso_surface', so unnamed actors keyed with 'PolyData' are taken
                # to be earlier arrows.
                old_arrow_actors = [actor for key, actor in plotter.renderer.actors.items()
                                    if 'PolyData' in key]
                if old_arrow_actors:
                    plotter.remove_actor(old_arrow_actors[0])

                if arrow_color is None:
                    plotter.add_mesh(arrows,
                                     scalars=options_dict['scalars'],
                                     cmap=cmap,
                                     show_scalar_bar=False)
                else:
                    plotter.add_mesh(arrows, color=arrow_color, show_scalar_bar=False)

            return None
예제 #3
0
def plot_contours_3d(contours: gpd.geodataframe.GeoDataFrame,
                     plotter: pv.Plotter,
                     color: str = 'red',
                     add_to_z: Union[int, float] = 0):
    """
    Plot contour lines in 3D with PyVista.

    Each unique index value of ``contours`` is treated as one contour line, and
    the rows sharing that index (in order) are its vertices. Groups with fewer
    than two rows cannot form a line and are skipped.

    Args:
        contours: GeoDataFrame containing the contour information; must contain a
            'Z' column. If 'X'/'Y' are missing they are derived with ``extract_xy``.
        plotter: PyVista plotter the lines are added to
        color: string for the color of the contour lines
        add_to_z: int or float value added to the Z value of every vertex

    Raises:
        TypeError: if an argument has the wrong type
        ValueError: if the 'Z' column is missing
        AttributeError: if X/Y coordinates are still missing after extraction
    """
    if not isinstance(contours, gpd.geodataframe.GeoDataFrame):
        raise TypeError('Line Object must be of type GeoDataFrame')

    # Checking if the plotter is of type pv plotter
    if not isinstance(plotter, pv.Plotter):
        raise TypeError('Plotter must be of type pv.Plotter')

    # Checking if the color is of type string
    if not isinstance(color, str):
        raise TypeError('Color must be of type string')

    # Checking if additional Z value is of type int or float
    if not isinstance(add_to_z, (int, float)):
        raise TypeError('Add_to_z must be of type int or float')

    # Checking if Z values are in gdf
    if not {'Z'}.issubset(contours.columns):
        raise ValueError('Z-values not defined')

    # If XY coordinates not in gdf, extract X,Y values
    if not {'X', 'Y'}.issubset(contours.columns):
        contours = extract_xy(contours, reset_index=False)

    try:
        coords = contours[['X', 'Y', 'Z']]
    except KeyError as err:
        raise AttributeError('X and Y coordinates of countours are missing') from err

    for j in contours.index.unique():
        # Select with a list so a single-row group still yields a DataFrame.
        line = coords.loc[[j]]
        if len(line) < 2:
            continue
        vertices = line.to_numpy(dtype=float)
        vertices[:, 2] += add_to_z
        # NOTE(review): depending on the PyVista version, add_lines draws either a
        # connected polyline or disjoint segments between consecutive point pairs;
        # confirm the intended rendering.
        plotter.add_lines(vertices, color=color)
예제 #4
0
def draw_tnb_frame(p: pv.Plotter, v, t, n, b, arrow_size=arrow_size):
    """Draw the t/n/b direction vectors at ``v`` plus a plane with normal ``n``.

    The vectors ``t``, ``n`` and ``b`` are scaled by ``arrow_size`` and drawn as
    red, green and blue arrows starting at ``v``. A square plane of side
    ``4 * arrow_size``, centred at ``v`` with normal ``n``, is then added in
    ``tangent_plane_color`` at opacity 0.6.
    """
    frame_axes = ((t, "red"), (n, "green"), (b, "blue"))
    for axis_vector, axis_color in frame_axes:
        arrow = pv.Arrow(start=v, direction=axis_vector * arrow_size, scale="auto")
        p.add_mesh(arrow, color=axis_color)

    plane_side = arrow_size * 4
    tangent_plane = pv.Plane(v, n, plane_side, plane_side)
    p.add_mesh(tangent_plane, color=tangent_plane_color, opacity=0.6)
예제 #5
0
def plot_mesh_pyvista(
    plotter: pv.Plotter,
    polydata: pv.PolyData,
    rotations: List[Tuple[int, int, int]] = [(0, 0, 0)],
    vertexcolors: List[int] = [],
    vertexscalar: str = '',
    cmap: str = 'YlGnBu',
    title: str = '',
    scalar_bar_idx: int = 0,
    **mesh_kwargs,
):
    """Draw ``polydata`` once per subplot of ``plotter``, each copy rotated differently.

    Args:
        plotter: plotter whose number of subplots must equal ``len(rotations)``.
        polydata: mesh to draw; a copy is made per subplot and the input is not modified.
        rotations: one ``(x, y, z)`` rotation per subplot, applied in x, y, z order.
            PyVista's ``rotate_*`` take degrees.
        vertexcolors: per-vertex values stored under ``vertexscalar`` on each copy.
            Ignored when ``None`` or empty.
        vertexscalar: name of the point array used for colouring (only used when
            ``vertexcolors`` is given).
        cmap: matplotlib colormap name.
        title: title added to the first subplot.
        scalar_bar_idx: index of the subplot that receives a scalar bar.
        **mesh_kwargs: extra ``add_mesh`` keyword arguments; they override the defaults.

    Raises:
        AssertionError: if the subplot layout does not match ``rotations``.
    """
    shape = plotter.shape
    if len(shape) == 1:
        assert shape[0] > 0
        assert shape[0] == len(rotations)
        subp_idx = [(x, ) for x in range(shape[0])]
    else:
        assert shape[0] > 0 and shape[1] > 0
        assert shape[0] * shape[1] == len(rotations)
        subp_idx = product(range(shape[0]), range(shape[1]))

    # The default ``vertexcolors=[]`` would otherwise be stored as an empty array
    # on the mesh, so "no colours" means None or empty.
    use_scalars = (bool(vertexscalar) and vertexcolors is not None
                   and len(vertexcolors) > 0)

    mesh_kwargs = {
        # pyplot.get_cmap works on all matplotlib versions;
        # matplotlib.cm.get_cmap was removed in 3.9.
        'cmap': plt.get_cmap(cmap),
        'flip_scalars': True,
        'show_scalar_bar': False,
        **mesh_kwargs,
    }
    if use_scalars:
        mesh_kwargs['scalars'] = vertexscalar

    for i, (subp, rots) in enumerate(zip(subp_idx, rotations)):
        x, y, z = rots
        plotter.subplot(*subp)
        poly_copy = polydata.copy()
        if use_scalars:
            # Store colours on the copy, as plot_meshes_pyvista does, so the
            # caller's mesh is left untouched.
            poly_copy[vertexscalar] = vertexcolors
        # Without inplace=True, current PyVista returns a rotated copy and the
        # rotation would be discarded.
        poly_copy.rotate_x(x, inplace=True)
        poly_copy.rotate_y(y, inplace=True)
        poly_copy.rotate_z(z, inplace=True)
        plotter.add_mesh(
            poly_copy,
            **mesh_kwargs,
        )
        if i == 0:
            plotter.add_title(title, font_size=5)
        if i == scalar_bar_idx:
            plotter.add_scalar_bar(label_font_size=10, position_x=0.85)
예제 #6
0
def plot_points_3d(points: Union[gpd.geodataframe.GeoDataFrame, pd.DataFrame],
                   plotter: pv.Plotter,
                   color: str = 'blue',
                   add_to_z: Union[int, float] = 0):
    """
    Plotting points in 3D with PyVista
    Args:
        points: GeoDataFrame or DataFrame containing the points in 'X', 'Y', 'Z'
            columns; it is not modified
        plotter: name of the PyVista plotter
        color: string of the coloring for points
        add_to_z: int or float value added to the height of the plotted points
    Raises:
        TypeError: if an argument has the wrong type
        ValueError: if any of the 'X', 'Y', 'Z' columns is missing
    """

    # Checking if points is of type GeoDataFrame
    if not isinstance(points, (gpd.geodataframe.GeoDataFrame, pd.DataFrame)):
        raise TypeError('Points must be of type GeoDataFrame or DataFrame')

    # Checking if all necessary columns are in the GeoDataFrame
    if not pd.Series(['X', 'Y', 'Z']).isin(points.columns).all():
        raise ValueError('Points are missing columns, XYZ needed')

    # Checking if the plotter is of type pyvista plotter
    if not isinstance(plotter, pv.Plotter):
        raise TypeError('Plotter must be of type pv.Plotter')

    # Checking if the color is of type string
    if not isinstance(color, str):
        raise TypeError('Color must be of type string')

    # Checking if additional Z value is of type int or float
    if not isinstance(add_to_z, (int, float)):
        raise TypeError('Add_to_z must be of type int or float')

    # Apply the Z offset to a copy of the coordinates. Writing it back into
    # ``points`` changed the caller's data and stacked the offset on every call.
    coords = points[['X', 'Y', 'Z']].copy()
    coords['Z'] = coords['Z'] + add_to_z

    # Create PyVista PolyData and add it to the plot
    point_cloud = pv.PolyData(coords.to_numpy())
    plotter.add_mesh(point_cloud, color=color)
예제 #7
0
def add_curves_to_plotter(crv_lst: list,
                          plotter: pv.Plotter,
                          col_lst: list = [],
                          def_col="Tomato",
                          line_width: float = None,
                          render_lines_as_tubes=False) -> None:
    """Add every curve in ``crv_lst`` to ``plotter`` as a mesh.

    Curves are paired with colours from ``col_lst`` by position. Any curve past
    the end of ``col_lst`` is drawn in ``def_col``. ``col_lst`` itself is not modified.

    Args:
        crv_lst (list): curves to draw; converted with ``mesh_curves``.
        plotter (pv.Plotter): plotter receiving the meshes.
        col_lst (list, optional): per-curve colours. Defaults to [].
        def_col (str, optional): fallback colour. Defaults to "Tomato".
        line_width (float, optional): forwarded to ``add_mesh``. Defaults to None.
        render_lines_as_tubes (bool, optional): forwarded to ``add_mesh``.
            Defaults to False.
    """
    padding = [def_col] * (len(crv_lst) - len(col_lst))
    curve_meshes = mesh_curves(crv_lst)
    palette = col_lst + padding
    for curve_mesh, curve_color in zip(curve_meshes, palette):
        plotter.add_mesh(
            curve_mesh,
            color=curve_color,
            line_width=line_width,
            render_lines_as_tubes=render_lines_as_tubes,
        )
예제 #8
0
def plot_boreholes_3d(df: pd.DataFrame,
                      plotter: pv.Plotter,
                      min_length: Union[float, int],
                      color_dict: dict,
                      show_labels=False,
                      labels=None,
                      ve=1,
                      **kwargs):
    """
    Plot boreholes in 3D as tubes coloured by formation.

    Args:
        df: pd.DataFrame containing the extracted borehole data; needs the columns
            'Index', 'Name', 'X', 'Y', 'Z', 'Altitude', 'Depth' and 'formation'
        plotter: PyVista plotter the boreholes are drawn into
        min_length: float/int defining the minimum depth of boreholes to be plotted
        color_dict: dict mapping formation names to the surface colors of the model
        show_labels: bool; when True, ``labels`` is stored as the "Labels" array of
            the tubes and drawn with ``add_point_labels``
        labels: PyVista polydata object containing the name and coordinates of cities
        ve: z scale factor passed to ``plotter.set_scale(1, 1, ve)``

    Kwargs:
        radius: float/int of the radius of the boreholes plotted with PyVista, default = 10
    """
    if not isinstance(df, pd.DataFrame):
        raise TypeError('Borehole data must be provided as Pandas DataFrame')

    required_columns = ['Index', 'Name', 'X', 'Y', 'Z', 'Altitude', 'Depth', 'formation']
    columns_present = pd.Series(required_columns).isin(df.columns)
    if not columns_present.all():
        raise ValueError('[' + ', '.join(required_columns)
                         + '] need to be columns in the provided DataFrame')

    if not isinstance(min_length, (float, int)):
        raise TypeError('Minimum length for boreholes must be of type float or int')

    if not isinstance(color_dict, dict):
        raise TypeError('Surface color dictionary must be of type dict')

    radius = kwargs.get('radius', 10)
    if not isinstance(radius, (int, float)):
        raise TypeError('The radius must be provided as int or float')

    if not isinstance(show_labels, bool):
        raise TypeError('Show_label must be of type bool')

    tubes, df_groups = create_borehole_tubes(df, min_length, color_dict, radius=radius)

    if show_labels:
        tubes["Labels"] = labels
        plotter.add_point_labels(tubes, "Labels", point_size=5, font_size=10)

    for idx in tqdm(range(len(tubes))):
        # Drop the first row of each group before collecting its formations.
        df_groups[idx] = df_groups[idx][1:]
        tube = tubes[idx]
        formation_colors = [color_dict[name] for name in df_groups[idx]['formation'].unique()]
        plotter.add_mesh(mesh=tube, cmap=formation_colors)

    # Scene-wide display settings
    plotter.set_scale(1, 1, ve)
    plotter.set_background(color='white')
    plotter.remove_scalar_bar()
    plotter.add_bounding_box(color='black')
    plotter.show_grid(color='black')
예제 #9
0
def plot_dem_3d(dem: Union[rasterio.io.DatasetReader, np.ndarray],
                plotter: pv.Plotter,
                extent: list,
                cmap: str = 'gist_earth',
                texture: Union[np.ndarray, bool, None] = None,
                res: int = 1,
                **kwargs):
    """
        Plotting the dem in 3D with PyVista
        Args:
            dem: rasterio object or np.ndarray containing the height values
            plotter: name of the PyVista plotter
            extent: list containing the values for the extent of the array (minx,maxx,miny,maxy)
            cmap: string for the coloring of the dem
            texture: texture of the dem (np.ndarray, bool or None)
            res: Resolution (step) of the meshgrid
        Kwargs:
            array: np.ndarray; when given, the dem is resized with
                ``resize_by_array(array, dem)`` and flipped vertically
        Raises:
            TypeError: if an argument has the wrong type
    """

    # Checking if dem is a rasterio object or array
    if not isinstance(dem, (rasterio.io.DatasetReader, np.ndarray)):
        raise TypeError('dem must be a rasterio object or np.ndarray')

    # Checking if the plotter is of type pyvista plotter
    if not isinstance(plotter, pv.Plotter):
        raise TypeError('Plotter must be of type pv.Plotter')

    # Checking if cmap if of type string
    if not isinstance(cmap, str):
        raise TypeError('cmap must be of type string')

    # Checking if texture is of type np.ndarray or bool
    if not isinstance(texture, (np.ndarray, bool, type(None))):
        raise TypeError('Texture must be of type np.ndarray or bool')

    # Getting array from kwargs
    array = kwargs.get('array', None)

    # Checking if array is of type np.ndarray or type None
    if not isinstance(array, (np.ndarray, type(None))):
        raise TypeError('array must be of type np.ndarray')

    # Read the raster first so ``dem`` is an ndarray from here on. Previously an
    # ndarray ``dem`` combined with ``array`` failed on ``dem.read(1)``.
    if isinstance(dem, rasterio.io.DatasetReader):
        dem = dem.read(1)

    # Rescale array if array is not of type None
    if array is not None:
        # NOTE(review): the flip presumably aligns the resized rows with the
        # meshgrid orientation below; confirm.
        dem = np.flipud(resize_by_array(array, dem))

    # Create meshgrid
    x = np.arange(extent[0], extent[1], res)
    y = np.arange(extent[2], extent[3], res)
    x, y = np.meshgrid(x, y)

    # Create Structured grid
    grid = pv.StructuredGrid(x, y, dem)

    # Assigning elevation values to grid
    grid["Elevation"] = dem.ravel(order="F")

    # Plotting the grid
    plotter.add_mesh(grid,
                     scalars=grid["Elevation"],
                     cmap=cmap,
                     texture=texture)
예제 #10
0
    def __common_plotting(self,
                        fermi_surface,
                        plotter: pv.Plotter,
                        mode:str,
                        text:str,
                        spin_texture:bool=False,

                        camera_pos:List[float]=[1, 1, 1],
                        background_color:str or Tuple[float,float,float,float]="white",
                        perspective:bool=True,

                        show:bool=False,
                        save_2d:str=None,
                        save_gif:str=None,
                        save_mp4:str=None,
                        save_3d:str=None):
        """Apply the shared scene decorations to ``plotter`` and handle display/export.

        Args:
            fermi_surface: mesh written to ``save_3d`` via ``plotter.save_meshio``.
            plotter: plotter that already holds the scene.
            mode: when not ``"plain"``, a scalar bar titled ``text`` is added.
            text: title of the scalar bar.
            spin_texture: when True, the scalar bar is added even in ``"plain"`` mode.
            camera_pos: passed to ``plotter.show`` as ``cpos``; only read here.
            background_color: passed to ``plotter.set_background``.
            perspective: when False, parallel projection is enabled.
            show: ``plotter.show`` is called only when this is False (see NOTE in body).
            save_2d: passed to ``plotter.show`` as ``screenshot``.
            save_gif: if not None, path of a GIF recorded while orbiting the camera.
            save_mp4: if truthy, path of a movie recorded while orbiting the camera.
            save_3d: if not None, path passed to ``plotter.save_meshio`` with ``fermi_surface``.
        """
        if mode != "plain" or spin_texture:
            plotter.add_scalar_bar(
                title=text,
                n_labels=6,
                italic=False,
                bold=False,
                title_font_size=None,
                label_font_size=None,
                position_x=0.4,
                position_y=0.01,
                color="black",)

        plotter.add_axes(
            xlabel="Kx",
            ylabel="Ky",
            zlabel="Kz",
            line_width=6,
            labels_off=False)

        if not perspective:
            plotter.enable_parallel_projection()

        plotter.set_background(background_color)
        # NOTE(review): show() runs only when ``show`` is False, which looks
        # inverted. It is left unchanged because callers may display the plotter
        # themselves when show=True; confirm.
        if not show:
            plotter.show(cpos=camera_pos, screenshot=save_2d)
        if save_gif is not None:
            path = plotter.generate_orbital_path(n_points=36)
            plotter.open_gif(save_gif)
            # write_frames=True is required: by default orbit_on_path only moves
            # the camera and the opened GIF receives no frames.
            plotter.orbit_on_path(path, write_frames=True)
        if save_mp4:
            path = plotter.generate_orbital_path(n_points=36)
            plotter.open_movie(save_mp4)
            plotter.orbit_on_path(path, write_frames=True)

        if save_3d is not None:
            plotter.save_meshio(save_3d, fermi_surface)
예제 #11
0
def draw_cpc_wireframe(p: pv.Plotter, V: np.ndarray, n=10):
    """Draw interior grid lines of the point grid ``V`` as polylines.

    ``n - 1`` evenly spaced rows ``V[i, :]`` and ``n - 1`` evenly spaced columns
    ``V[:, j]`` are each drawn in ``wireframe_color``. The first and last row and
    column are skipped.

    Args:
        p: plotter to draw into.
        V: grid of points indexed as ``V[row, col, ...]``.
        n: number of intervals per direction.
    """
    # Spacing follows V's actual grid size; the old hard-coded 100 is identical
    # for 100x100 grids. ``np.int`` was removed in NumPy 1.24, so the builtin
    # ``int`` is used instead.
    n_rows, n_cols = V.shape[0], V.shape[1]
    for i in np.linspace(0, n_rows, n + 1, dtype=int)[1:-1] - 1:
        p.add_mesh(polyline_from_points(V[i, :]), color=wireframe_color)
    for j in np.linspace(0, n_cols, n + 1, dtype=int)[1:-1] - 1:
        p.add_mesh(polyline_from_points(V[:, j]), color=wireframe_color)
예제 #12
0
def plot_meshes_pyvista(
        plotter: pv.Plotter,
        polydatas: List[pv.PolyData],
        rotations: List[Tuple[int, int, int]],
        vertexcolors: np.ndarray = None,
        vertexscalar: str = '',
        cmap: str = 'YlGnBu',
        titles: str = '',
        scalar_bar_idx: int = -1,
        scalar_bar_kwargs: dict = dict(label_font_size=15, position_x=0.85),
        mesh_kwargs: dict = dict(),
        title_kwargs: dict = dict(),
):
    """
    Plot multiple meshes, one per subplot, each with its own rotation.

    Need a separate matrix of colours for each polydata.
    Should be one colour per vertex of the mesh.

    Args:
        plotter: plotter whose number of subplots must equal ``len(polydatas)``.
        polydatas: meshes to draw; each is copied, so the inputs are not modified.
        rotations: one ``(x, y, z)`` rotation per mesh, applied in x, y, z order.
        vertexcolors: array whose i-th entry holds the per-vertex values of mesh i,
            stored under ``vertexscalar`` on the copy.
        vertexscalar: name of the point array used for colouring.
        cmap: matplotlib colormap name.
        titles: one title for every subplot, or a list with one title per mesh.
            Empty titles are skipped.
        scalar_bar_idx: subplot index that receives a scalar bar. The default -1
            never matches, so no bar is drawn.
        scalar_bar_kwargs: forwarded to ``add_scalar_bar`` (only read, never mutated).
        mesh_kwargs: extra ``add_mesh`` arguments overriding the defaults (only read).
        title_kwargs: forwarded to ``add_title`` (only read).

    Raises:
        AssertionError: if the subplot layout, colours or titles do not match ``polydatas``.
    """
    shape = plotter.shape
    if len(shape) == 1:
        assert shape[0] > 0
        assert shape[0] == len(polydatas)
        subp_idx = [(x, ) for x in range(shape[0])]
    else:
        assert shape[0] > 0 and shape[1] > 0
        assert shape[0] * shape[1] == len(polydatas)
        subp_idx = product(range(shape[0]), range(shape[1]))

    use_scalars = bool(vertexscalar) and vertexcolors is not None
    if use_scalars:
        assert vertexcolors.shape[0] == len(polydatas)

    mesh_kwargs = {
        # pyplot.get_cmap works on all matplotlib versions;
        # matplotlib.cm.get_cmap was removed in 3.9.
        'cmap': plt.get_cmap(cmap),
        'flip_scalars': True,
        'show_scalar_bar': False,
        **mesh_kwargs,
    }
    if use_scalars:
        mesh_kwargs['scalars'] = vertexscalar
        if 'clim' not in mesh_kwargs:
            # A shared colour range keeps the subplots comparable.
            mesh_kwargs['clim'] = [vertexcolors.min(), vertexcolors.max()]

    if isinstance(titles, str):
        titles = [titles] * len(polydatas)
    elif isinstance(titles, list):
        assert len(titles) == len(polydatas)

    for i, (subp, rots,
            polydata) in enumerate(zip(subp_idx, rotations, polydatas)):
        x, y, z = rots
        plotter.subplot(*subp)
        poly_copy = polydata.copy()
        if use_scalars:
            poly_copy[vertexscalar] = vertexcolors[i]
        # Without inplace=True, current PyVista returns a rotated copy and the
        # rotation would be discarded.
        poly_copy.rotate_x(x, inplace=True)
        poly_copy.rotate_y(y, inplace=True)
        poly_copy.rotate_z(z, inplace=True)
        plotter.add_mesh(poly_copy, **mesh_kwargs)
        if titles[i]:
            plotter.add_title(titles[i], **title_kwargs)
        if i == scalar_bar_idx:
            plotter.add_scalar_bar(**scalar_bar_kwargs)