예제 #1
0
def test_plot_3d_montage(requires_pyvista, fname_raw, to_1020, ch_names):
    """Smoke-test ``mne_nirs.viz.plot_3d_montage`` and its log output.

    The arguments are supplied by the test harness (fixtures and
    parametrization are not visible in this snippet).  The test checks that
    no pyvista plotter is registered before or after the call, and that the
    captured log mentions 'automatically mapped' when a renamed 10-20 montage
    was applied (``to_1020`` truthy) or 'could not' otherwise.
    """
    import pyvista
    pyvista.close_all()
    assert len(pyvista.plotting._ALL_PLOTTERS) == 0
    raw = mne.io.read_raw_nirx(fname_raw)
    if to_1020:
        # Collect the unique '_'-separated parts of the first
        # whitespace-delimited token of every channel name.
        needed_names = {
            part
            for ch_name in raw.ch_names
            for part in ch_name.split()[0].split('_')
        }
        montage = mne.channels.make_standard_montage('standard_1020')
        # Relabel as many standard positions as there are needed names.
        montage.rename_channels(dict(zip(montage.ch_names, needed_names)))
        raw.set_montage(montage)
    n_labels = len(raw.ch_names) // 2
    half = n_labels // 2
    view_map = {
        'left-lat': np.arange(1, half),
        'caudal': np.arange(half, n_labels + 1),
    }
    # "sample" is not the right subject for this recording, but using it
    # gives us a head surface to draw on.
    with catch_logging() as log:
        mne_nirs.viz.plot_3d_montage(
            raw.info, view_map,
            subject='sample', surface='white', subjects_dir=subjects_dir,
            ch_names=ch_names, verbose=True)
    assert len(pyvista.plotting._ALL_PLOTTERS) == 0
    log_text = log.getvalue().lower()
    expected = 'automatically mapped' if to_1020 else 'could not'
    assert expected in log_text
예제 #2
0
def test_closing_and_mem_cleanup():
    """Build, show and close many plotters to exercise plotter cleanup.

    Five rounds are run; each round creates five plotters (using the
    module-level ``OFF_SCREEN`` setting), adds five spheres of radius 0..4
    to each, shows it, and finally calls ``pyvista.close_all()``.
    """
    n = 5
    for _ in range(n):
        for _ in range(n):
            plotter = pyvista.Plotter(off_screen=OFF_SCREEN)
            for radius in range(n):
                plotter.add_mesh(pyvista.Sphere(radius=radius))
            plotter.show()
        # Close every plotter created in this round before starting the next.
        pyvista.close_all()
예제 #3
0
def test_closing_and_mem_cleanup():
    """Build, show and close many default plotters to exercise cleanup.

    Five rounds are run; each round creates five plotters, adds five
    spheres of radius 0..4 to each, shows it, and finally calls
    ``pyvista.close_all()``.
    """
    n_rounds = n_plotters = n_spheres = 5
    for _round in range(n_rounds):
        for _plotter_idx in range(n_plotters):
            plotter = pyvista.Plotter()
            for radius in range(n_spheres):
                plotter.add_mesh(pyvista.Sphere(radius=radius))
            plotter.show()
        # Close every plotter created in this round before starting the next.
        pyvista.close_all()
예제 #4
0
def check_gc():
    """Ensure that all VTK objects are garbage-collected by Python.

    Generator that records the VTK objects alive on entry, yields once, then
    closes all plotters, forces a collection and asserts that no VTK object
    created in between is still tracked by the garbage collector.
    """
    # Objects that already existed before the yield are not our concern.
    known_ids = {id(obj) for obj in gc.get_objects() if _is_vtk(obj)}
    yield
    pyvista.close_all()

    gc.collect()
    leaked = [
        obj for obj in gc.get_objects()
        if _is_vtk(obj) and id(obj) not in known_ids
    ]
    leaked_names = sorted(obj.__class__.__name__ for obj in leaked)
    assert len(leaked) == 0, \
        'Not all objects GCed:\n' + '\n'.join(leaked_names)
예제 #5
0
def render_figures(
    code,
    code_path,
    output_dir,
    output_base,
    context,
    function_name,
    config,
):
    """Run a plotting script and save its images in *output_dir*.

    *code* is split into pieces by ``_split_code_at_show``; each piece is
    executed in turn and every plotter registered in
    ``pyvista.plotting._ALL_PLOTTERS`` afterwards is saved under a file name
    derived from *output_base*, the piece index and the plotter index.
    Plotters carrying a ``_gif_filename`` attribute have that file moved into
    place as a ``.gif``; all others are captured as a ``.png`` screenshot.
    All plotters are closed after each piece.

    Returns a list of ``(code_piece, images)`` tuples, where ``images`` is
    the list of ``ImageFile`` objects produced by that piece.  *config* is
    accepted but not used here.
    """
    is_doctest, code_pieces = _split_code_at_show(code)

    results = []
    # Share one namespace across calls only when a context was requested.
    ns = plot_context if context else {}

    for i, code_piece in enumerate(code_pieces):
        # Doctest-style pieces must be converted to a plain script first.
        runnable = (doctest.script_from_examples(code_piece)
                    if is_doctest else code_piece)
        _run_code(runnable, code_path, ns, function_name)

        images = []
        for j, plotter in enumerate(pyvista.plotting._ALL_PLOTTERS.values()):
            is_gif = hasattr(plotter, '_gif_filename')
            if is_gif:
                file_name = f"{output_base}_{i:02d}_{j:02d}.gif"
            else:
                file_name = f"{output_base}_{i:02d}_{j:02d}.png"
            image_file = ImageFile(output_dir, file_name)
            if is_gif:
                shutil.move(plotter._gif_filename, image_file.filename)
            else:
                plotter.screenshot(image_file.filename)
            images.append(image_file)

        # Close and clear all plotters so the next piece starts clean.
        pyvista.close_all()

        results.append((code_piece, images))

    return results
예제 #6
0
 def __call__(self, block, block_vars, gallery_conf):  # noqa:D101
     """Save one image per distinct open plotter and return gallery rST.

     Each plotter in ``pyvista.plotting._ALL_PLOTTERS`` is written to the
     next path from ``block_vars["image_path_iterator"]``: an existing GIF
     (``_gif_filename``) is moved there, otherwise a screenshot is taken.
     A plotter that appears more than once is only saved the first time.
     All plotters are closed before returning the result of ``figure_rst``.
     """
     from sphinx_gallery.scrapers import figure_rst
     saved_paths = []
     path_iter = block_vars["image_path_iterator"]
     handled = []
     for plotter in pyvista.plotting._ALL_PLOTTERS.values():
         if plotter in handled:
             continue
         handled.append(plotter)
         target = next(path_iter)
         if hasattr(plotter, '_gif_filename'):
             # The GIF was already written elsewhere; relocate it.
             shutil.move(plotter._gif_filename, target)
         else:
             plotter.screenshot(target)
         saved_paths.append(target)
     pyvista.close_all()  # close and clear all plotters
     return figure_rst(saved_paths, gallery_conf["src_dir"])
예제 #7
0
def test_scraper(tmpdir):
    """Check that ``Scraper`` writes an image file for an open plotter.

    Skipped when ``sphinx_gallery`` is not importable.  A single off-screen
    plotter is created, the scraper is invoked with a one-element image path
    iterator, and the test asserts that the target image did not exist
    before the call and does exist afterwards.
    """
    pytest.importorskip('sphinx_gallery')
    pyvista.close_all()
    plotter = pyvista.Plotter(off_screen=True)
    scraper = Scraper()

    src_dir = str(tmpdir)
    out_dir = op.join(str(tmpdir), '_build', 'html')
    images_dir = op.join(src_dir, 'auto_examples', 'images')
    img_fname = op.join(images_dir, 'sg_img.png')
    target_file = op.join(src_dir, 'auto_examples', 'sg.py')

    gallery_conf = {"src_dir": src_dir, "builder_name": "html"}
    block = None
    block_vars = {
        'image_path_iterator': iter([img_fname]),
        'example_globals': {'a': 1},
        'target_file': target_file,
    }

    os.makedirs(images_dir)
    assert not os.path.isfile(img_fname)
    os.makedirs(out_dir)
    scraper(block, block_vars, gallery_conf)
    assert os.path.isfile(img_fname)
    plotter.close()
예제 #8
0
def plot_prediction(index, prediction, gound_truth, labels, path, title=""):
    """Save one figure per class showing test meshes and predicted labels.

    For every class a single-row figure is built with one mesh per column.
    Each mesh is captioned with its predicted label, drawn in green when the
    prediction equals the ground truth and in red otherwise.  The figure is
    written to ``<path>/Qualitative/<class><title>.pdf``.

    :param index: array[n_labels]
        An array containing the index of the test meshes for each class
    :param prediction: array[n_prediction]
        The result of the classification
    :param gound_truth: array[n_prediction]
        The ground truth labels in the same order as prediction
    :param labels: list
        A list containing the classes
    :param path: string or os.path
        The path where to save the plots
    :param title: string
        Suffix appended to the class name in the output file name
    :return: None
    """
    out_dir = os.path.join(path, "Qualitative")
    os.makedirs(out_dir, exist_ok=True)
    # Element-wise flag: True where the prediction matches the ground truth.
    is_correct = prediction == gound_truth
    # Running position in the flat prediction arrays; it keeps advancing
    # across classes.
    pred_pos = 0

    for class_inds, class_name in zip(index, labels):
        # NOTE(review): the column count always comes from the first class
        # (index[0]), not from the class being drawn -- confirm all classes
        # have the same number of test meshes.
        plotter = pv.Plotter(shape=(1, index[0].shape[0]),
                             off_screen=False,
                             window_size=[1024, 1024 // 2])
        plotter.set_background("white")
        mesh_files = glob.glob(
            os.path.join('.', 'Dataset', class_name, "*.off"))
        for column, mesh_idx in enumerate(class_inds):
            mesh = pv.read(mesh_files[mesh_idx])
            plotter.subplot(0, column)
            caption = labels[prediction[pred_pos]]
            caption_color = "green" if is_correct[pred_pos] else "red"
            plotter.add_text(caption, color=caption_color, font_size=10)
            plotter.add_mesh(mesh, smooth_shading=True, color="grey")
            pred_pos += 1

        plotter.save_graphic(
            os.path.join(out_dir, class_name + title + '.pdf'))
    pv.close_all()
예제 #9
0
 def __call__(self, block, block_vars, gallery_conf):
     """Save the figures produced by an example block for sphinx-gallery.

     Each plotter in ``pyvista.plotting._ALL_PLOTTERS`` is written to the
     next path from ``block_vars["image_path_iterator"]``: an existing GIF
     (``_gif_filename``) is moved there, otherwise a screenshot is taken.
     All plotters are then closed and the rST produced by ``figure_rst``
     for the saved paths is returned.

     Raises ``ImportError`` when ``sphinx_gallery`` cannot be imported.
     """
     try:
         from sphinx_gallery.scrapers import figure_rst
     except ImportError:
         raise ImportError('You must install `sphinx_gallery`')
     saved_paths = []
     path_iter = block_vars["image_path_iterator"]
     for plotter in pyvista.plotting._ALL_PLOTTERS.values():
         target = next(path_iter)
         if hasattr(plotter, '_gif_filename'):
             # The GIF was already written elsewhere; relocate it.
             shutil.move(plotter._gif_filename, target)
         else:
             plotter.screenshot(target)
         saved_paths.append(target)
     pyvista.close_all()  # close and clear all plotters
     return figure_rst(saved_paths, gallery_conf["src_dir"])
예제 #10
0
def _close_all():
    """Call ``close_all`` with ``DeprecationWarning`` silenced."""
    with warnings.catch_warnings():
        # The filter only applies inside this context manager.
        warnings.simplefilter("ignore", DeprecationWarning)
        close_all()
예제 #11
0
def plot_mesh(vertices,
              faces=None,
              input_file_name=None,
              darkmode=True,
              background_color='k',
              text=None,
              window_size=[1280, 720],
              font_color='w',
              cpos=[(3.77, 3.77, 3.77), (0.0069, -0.0045, 0.0),
                    (0.0, 0.0, 1.0)]):
    """Render a surface mesh off a tetrahedralization and return the image.

    Parameters
    ----------
    vertices : array
        Vertex coordinates of the surface mesh.
    faces : array, optional
        Indices into ``vertices`` describing the surface faces.  When
        omitted, the faces are loaded from ``input_file_name`` via trimesh.
    input_file_name : str, optional
        Path to a mesh file (e.g. ``*.ply`` or ``*.stl``) providing the
        faces.  Only used when ``faces`` is None.
    darkmode : bool
        When truthy the background is set to ``background_color``; otherwise
        the background is left as is and the text colour is forced to 'k'.
    background_color : str
        Background colour used in dark mode.
    text : str, optional
        Caption drawn in the upper-left corner, e.g. ``f'time={tme:.1f}'``.
    window_size : list
        Render window size; [1280, 720] is the standard YouTube aspect ratio.
    font_color : str
        Caption colour (overridden with 'k' when ``darkmode`` is falsy).
    cpos : list
        Camera position, focal point and view-up vector.  Keep this constant
        across frames when rendering a movie.

    Returns
    -------
    img
        The image returned by ``plotter.show(return_img=True)``.

    Raises
    ------
    ValueError
        If neither ``faces`` nor ``input_file_name`` is given.
    """
    # Validate first so a bad call fails with a clear message instead of
    # passing None to trimesh.load.  The original built this exception
    # without raising it.
    if faces is None and input_file_name is None:
        raise ValueError('either faces or input_file_name must be specified')

    pv.set_plot_theme('document')

    if faces is None:
        faces = trimesh.load(input_file_name).faces

    # Fault-tolerant tetrahedralization; only the resulting grid is needed.
    tet = tetgen.TetGen(vertices, faces)
    tet.tetrahedralize(order=1,
                       mindihedral=0.,
                       minratio=10.,
                       nobisect=False,
                       steinerleft=100000)
    tet.make_manifold()
    grid = tet.grid

    plotter = pv.Plotter()
    if darkmode:
        plotter.set_background(background_color)
    else:
        font_color = 'k'
    plotter.add_mesh(grid, 'lightgrey', lighting=True)

    if text is not None:
        plotter.add_text(text,
                         position='upper_left',
                         font_size=24,
                         color=font_color,
                         font='times')

    _, img = plotter.show(title=None,
                          return_img=True,
                          cpos=cpos,
                          window_size=window_size,
                          use_ipyvtk=False,
                          interactive=False,
                          auto_close=True)
    # Release the plotter's resources explicitly before returning.
    plotter.deep_clean()
    del plotter
    pv.close_all()
    return img
예제 #12
0
def _close_all():
    """Import and call ``pyvista.close_all`` with deprecations silenced.

    The import happens inside the warning filter, so a
    ``DeprecationWarning`` raised while importing is ignored as well.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", DeprecationWarning)
        from pyvista import close_all
        close_all()
예제 #13
0
def _close_all():
    """Close all plotters by delegating to ``pyvista.close_all()``."""
    pyvista.close_all()
예제 #14
0
def plot2D(file,
           field,
           bounds,
           ax=None,
           contours=False,
           colorbar=False,
           cfields=['crust_upper', 'crust_lower', 'mantle_lithosphere'],
           null_field='asthenosphere',
           **kwargs):
    """
    Plot 2D ASPECT results using Pyvista.

    The mesh is clipped to ``bounds``, rendered off screen with a camera
    looking straight down on the clipped region, and the resulting screenshot
    is drawn on a Matplotlib axes with ``imshow``.

    Parameters
    ----------
    file : VTU or PVTU file to plot
    field : Field to use for color.
    bounds : List of bounds (km) by which to clip the plot, read as
        [xmin, xmax, ymin, ymax]. Also used as the ``imshow`` extent.
    ax : Matplotlib axes to draw on.
        The default is None, in which case ``plt.gca()`` is used.
    contours : Boolean for whether to add temperature contours.
        The default is False.
    colorbar : Boolean for whether to keep the scalar bar in the render.
        The default is False.
    cfields : Names of compositional fields to use if field is 'comp_field.'
        The default is ['crust_upper','crust_lower','mantle_lithosphere'].
    null_field : Null field if field is 'comp_field.'
        The default is 'asthenosphere'.
    **kwargs : Extra keyword arguments forwarded to ``plotter.add_mesh``
        for the main mesh.

    Returns
    -------
    ax : The Matplotlib axes the image was drawn on.

    """

    mesh = pv.read(file)

    km2m = 1000
    bounds_m = [bound * km2m for bound in bounds]  # Convert bounds to m
    # clip_box needs six values; the model is 2D so z spans [0, 0].
    bounds_3D = bounds_m + [0, 0]
    mesh = mesh.clip_box(bounds=bounds_3D, invert=False)

    if field == 'comp_field':
        mesh = comp_field_vtk(mesh, fields=cfields, null_field=null_field)

    if contours:
        cntrs = add_contours(mesh)

    pv.set_plot_theme("document")
    plotter = pv.Plotter(off_screen=True)
    try:
        sargs = dict(width=0.6,
                     fmt='%.1e',
                     height=0.2,
                     label_font_size=32,
                     position_x=0.1)

        plotter.add_mesh(mesh, scalars=field, scalar_bar_args=sargs,
                         **kwargs)

        if contours:
            plotter.add_mesh(cntrs, color='black', line_width=5)

        plotter.view_xy()

        if not colorbar:
            plotter.remove_scalar_bar()

        plotter.enable_depth_peeling(10)

        # Calculate camera position from bounds
        bounds_array = np.array(bounds_m)
        xmag = float(abs(bounds_array[1] - bounds_array[0]))
        ymag = float(abs(bounds_array[3] - bounds_array[2]))
        aspect_ratio = ymag / xmag

        # Match the window shape to the clipped region.
        plotter.window_size = (1024, int(1024 * aspect_ratio))

        xmid = xmag / 2 + bounds_array[0]  # X midpoint
        ymid = ymag / 2 + bounds_array[2]  # Y midpoint
        # Camera height above the plane; 1.875 is an empirical factor.
        zoom = xmag * aspect_ratio * 1.875

        position = (xmid, ymid, zoom)
        focal_point = (xmid, ymid, 0)
        viewup = (0, 1, 0)

        plotter.camera_position = [position, focal_point, viewup]
        plotter.camera_set = True

        # Create image
        img = plotter.screenshot(transparent_background=True,
                                 return_img=True)

        # Plot using imshow
        if ax is None:
            ax = plt.gca()

        ax.imshow(img, aspect='equal', extent=bounds)
    finally:
        # Release the plotter even when rendering fails part-way through.
        plotter.clear()
        pv.close_all()

    return ax