def main():
    """Display a TRK tractogram together with its seed points.

    Loads the tractogram given on the command line, draws its streamlines
    and its per-streamline ``seeds`` as white dots, optionally records a
    1000x1000 screenshot to ``args.save`` and then opens an interactive
    window.

    Raises
    ------
    ValueError
        If the input tractogram is not a TRK file.
    """
    parser = _build_args_parser()
    args = parser.parse_args()
    assert_inputs_exist(parser, [args.tractogram])
    assert_outputs_exist(parser, args, [], [args.save])

    tracts_format = detect_format(args.tractogram)
    if tracts_format is not TrkFile:
        # The parsed namespace exposes ``tractogram`` (used everywhere
        # above); referencing a non-existent ``tractogram_filename`` here
        # raised AttributeError instead of this error.
        raise ValueError("Invalid input streamline file format " +
                         "(must be trk): {0}".format(args.tractogram))

    # Load files and data
    trk = TrkFile.load(args.tractogram)
    tractogram = trk.tractogram
    streamlines = tractogram.streamlines
    if 'seeds' not in tractogram.data_per_streamline:
        parser.error('Tractogram does not contain seeds')
    seeds = tractogram.data_per_streamline['seeds']

    # Make display objects
    streamlines_actor = actor.line(streamlines)
    points = actor.dots(seeds, color=(1., 1., 1.))

    # Add display objects to canvas
    r = window.Renderer()
    r.add(streamlines_actor)
    r.add(points)

    # Show and record if needed
    if args.save is not None:
        window.record(r, out_path=args.save, size=(1000, 1000))
    window.show(r)
def main():
    """Display a TRK tractogram together with its seed points.

    The tractogram is loaded with its own header as reference and moved to
    voxel space so that streamlines and the per-streamline ``seeds`` share
    one coordinate system. Seeds are drawn as white dots. A 1000x1000
    screenshot is optionally recorded to ``args.save`` before an
    interactive window is opened.

    Raises
    ------
    ValueError
        If the input tractogram is not a TRK file.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    assert_inputs_exist(parser, [args.tractogram])
    assert_outputs_exist(parser, args, [], [args.save])

    tracts_format = detect_format(args.tractogram)
    if tracts_format is not TrkFile:
        # The parsed namespace exposes ``tractogram``; referencing a
        # non-existent ``tractogram_filename`` raised AttributeError instead
        # of this error.
        raise ValueError("Invalid input streamline file format " +
                         "(must be trk): {0}".format(args.tractogram))

    # Load files and data. TRKs can have 'same' as reference
    tractogram = load_tractogram(args.tractogram, 'same')
    # Streamlines are saved in RASMM but seeds are saved in VOX, so bring
    # the streamlines to voxel space to overlay them on the seeds.
    # This might produce weird behavior with non-isotropic voxels.
    tractogram.to_vox()

    streamlines = tractogram.streamlines
    if 'seeds' not in tractogram.data_per_streamline:
        parser.error('Tractogram does not contain seeds')
    seeds = tractogram.data_per_streamline['seeds']

    # Make display objects
    streamlines_actor = actor.line(streamlines)
    points = actor.dots(seeds, color=(1., 1., 1.))

    # Add display objects to canvas
    r = window.Renderer()
    r.add(streamlines_actor)
    r.add(points)

    # Show and record if needed
    if args.save is not None:
        window.record(r, out_path=args.save, size=(1000, 1000))
    window.show(r)
Example #3
0
def test_parallel_projection():
    """Parallel projection should cover more pixels than perspective.

    Two axes actors are placed 2 units apart. Under perspective projection
    the farther one is drawn smaller. Under parallel projection both keep
    the same size, so the second snapshot should contain more non-zero
    pixels than the first.
    """
    scene = window.Renderer()
    near_axes = actor.axes()
    far_axes = actor.axes()
    far_axes.SetPosition((2, 0, 0))

    scene.add(near_axes, far_axes)

    # Look at the scene from an angle so that the difference between
    # perspective and parallel projection becomes visible.
    scene.set_camera((1.5, 1.5, 1.5))
    scene.GetActiveCamera().Zoom(2)

    scene.reset_camera()
    perspective_img = window.snapshot(scene)

    scene.projection('parallel')
    parallel_img = window.snapshot(scene)

    covered_parallel = np.sum(parallel_img > 0)
    covered_perspective = np.sum(perspective_img > 0)
    npt.assert_equal(covered_parallel > covered_perspective, True)
def plot_each_shell(ms, plot_sym_vecs=True, use_sphere=True, same_color=False,
                    rad=0.025, opacity=1.0, ofile=None, ores=(300, 300)):
    """
    Plot each shell in its own window.

    Parameters
    ----------
    ms: list of numpy.ndarray
        bvecs for each bval (one array of points per shell)
    plot_sym_vecs: boolean
        Also plot the antipodal (negated) vectors
    use_sphere: boolean
        Render a unit sphere behind the points
    same_color: boolean
        Use the same color for all shells
    rad: float
        radius of each point
    opacity: float
        opacity of the sphere
    ofile: str
        output filename prefix; shell ``k`` is saved as
        ``<ofile>_shell_<k>.png``
    ores: tuple
        resolution of the output png

    Return
    ------
    None
    """
    # One color per shell. The scene background is white, so request colors
    # that are distinguishable from white. ``vtkcolors`` used to be assigned
    # only for more than 10 shells, which made it an unbound local
    # (UnboundLocalError) in the common case of 10 shells or fewer.
    shell_colors = fury.colormap.distinguishable_colormap(
        bg=(1., 1., 1.), nb_colors=len(ms))

    if use_sphere:
        sphere = get_sphere('symmetric724')
        # A single voxel with a constant ODF renders as a sphere. A small
        # in-memory array is enough; the previous mkstemp + memmap leaked
        # an open file descriptor and a temporary file on every call.
        odfs = np.ones((1, 1, 1, sphere.vertices.shape[0]), dtype=np.float64)
        affine = np.eye(4)

    for shell_idx, shell in enumerate(ms):
        color = shell_colors[0 if same_color else shell_idx]
        ren = window.Renderer()
        ren.SetBackground(1, 1, 1)
        if use_sphere:
            sphere_actor = actor.odf_slicer(odfs, affine, sphere=sphere,
                                            colormap='winter', scale=1.0,
                                            opacity=opacity)
            ren.add(sphere_actor)
        pts_actor = actor.point(shell, color, point_radius=rad)
        ren.add(pts_actor)
        if plot_sym_vecs:
            pts_actor = actor.point(-shell, color, point_radius=rad)
            ren.add(pts_actor)
        window.show(ren)

        if ofile:
            # Name the file after the shell index, not the color index, so
            # that same_color=True no longer overwrites a single file.
            window.snapshot(ren,
                            fname=ofile + '_shell_' + str(shell_idx) + '.png',
                            size=ores)
Example #5
0
def test_text_widget():
    """Attach a button and a toggling text widget to a scene with one axes
    actor, then check that the renderer reports 3 actors.

    Flip ``interactive`` to True to try the widgets by hand.
    """
    interactive = False

    renderer = window.Renderer()
    window.add(renderer, actor.axes())
    renderer.ResetCamera()

    show_manager = window.ShowManager(renderer, size=(900, 900))

    if interactive:
        show_manager.initialize()
        show_manager.render()

    fetch_viz_icons()
    home_icon = read_viz_icons(fname='home3.png')

    def on_button(obj, event):
        print('Button Pressed')

    button = widget.button(show_manager.iren, show_manager.ren,
                           on_button, home_icon, (.8, 1.2), (100, 100))

    # Module-level toggle flipped by every click on the text widget.
    global rulez
    rulez = True

    def on_text(obj, event):
        global rulez
        print('Text selected')
        if rulez:
            message = "Diffusion Imaging Rulez!!"
        else:
            message = "Diffusion Imaging in Python"
        obj.GetTextActor().SetInput(message)
        rulez = not rulez
        show_manager.render()

    text = widget.text(show_manager.iren,
                       show_manager.ren,
                       on_text,
                       message="Diffusion Imaging in Python",
                       left_down_pos=(0., 0.),
                       right_top_pos=(0.4, 0.05),
                       opacity=1.,
                       border=False)

    if interactive:
        show_manager.render()
        show_manager.start()
    else:
        button.Off()
        text.Off()

    report = window.analyze_renderer(renderer)
    npt.assert_equal(report.actors, 3)
Example #6
0
def test_order_transparent():
    """Build a half-transparent streamtube actor for two parallel lines."""
    renderer = window.Renderer()

    bottom_line = np.array([[-1, 0, 0.], [1, 0, 0.]])
    top_line = np.array([[-1, 1, 0.], [1, 1, 0.]])
    red_and_green = np.array([[1., 0., 0.], [0., .5, 0.]])
    stream_actor = actor.streamtube([bottom_line, top_line], red_and_green,
                                    linewidth=0.3, opacity=0.5)
Example #7
0
def test_timer():
    """A repeating 200 ms timer updates a text block and, after ten ticks,
    exits the ShowManager from inside its own callback. The final snapshot
    must not be blank.
    """
    # These names are shared with the timer callback through module globals.
    global sphere_actor, tb, cnt, showm

    first_row = np.array([[0, 0, 0, 10], [100, 0, 0, 50], [300, 0, 0, 100]])
    second_row = np.array([[0, 200, 0, 30], [100, 200, 0, 50],
                           [300, 200, 0, 100]])
    rgba = np.array([[1, 0, 0, 0.3], [0, 1, 0, 0.4], [0, 0, 1., 0.45]])

    renderer = window.Renderer()
    sphere_actor = actor.sphere(centers=first_row[:, :3],
                                colors=rgba[:],
                                radii=first_row[:, 3])

    repulsion = get_sphere('repulsion724')

    mesh_sphere_actor = actor.sphere(centers=second_row[:, :3],
                                     colors=rgba[:],
                                     radii=second_row[:, 3],
                                     vertices=repulsion.vertices,
                                     faces=repulsion.faces.astype('i8'))

    renderer.add(sphere_actor)
    renderer.add(mesh_sphere_actor)

    tb = ui.TextBlock2D()
    cnt = 0

    showm = window.ShowManager(renderer,
                               size=(1024, 768),
                               reset_camera=False,
                               order_transparent=True)
    showm.initialize()

    def timer_callback(obj, event):
        global cnt, showm, tb
        cnt += 1
        tb.message = "Let's count to 10 and exit :" + str(cnt)
        showm.render()
        if cnt > 9:
            showm.exit()

    renderer.add(tb)

    # Fire every 200 milliseconds (repeating timer).
    showm.add_timer_callback(True, 200, timer_callback)
    showm.start()

    final_frame = window.snapshot(renderer)
    npt.assert_(np.sum(final_frame) > 0)
Example #8
0
def test_peak_slicer(interactive=False):
    """Slice through an 11x11x11 peak field with ``display_extent`` and
    ``display``, then check the actor classes reported for an LOD peak
    actor shown next to an axes actor.
    """
    grid = (11, 11, 11)
    last = 10
    axis_dirs = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype='f4')

    peak_dirs = np.zeros(grid + (3, 3))
    peak_values = np.random.rand(*grid, 3)
    # Every voxel gets the same three orthogonal unit peaks.
    peak_dirs[...] = axis_dirs

    renderer = window.Renderer()
    peak_actor = actor.peak_slicer(peak_dirs)
    renderer.add(peak_actor)
    renderer.add(actor.axes(grid))
    if interactive:
        window.show(renderer)

    renderer.clear()
    renderer.add(peak_actor)
    renderer.add(actor.axes(grid))

    # Sweep through every slice along z, then y, then x.
    for z in range(grid[2]):
        peak_actor.display_extent(0, last, 0, last, z, z)
    for y in range(grid[1]):
        peak_actor.display_extent(0, last, y, y, 0, last)
    for x in range(grid[0]):
        peak_actor.display(x, None, None)

    renderer.rm_all()

    lod_actor = actor.peak_slicer(
        peak_dirs,
        peak_values,
        mask=None,
        affine=np.diag([3, 2, 1, 1]),
        colors=None,
        opacity=1,
        linewidth=3,
        lod=True,
        lod_points=10 ** 4,
        lod_points_size=3)

    renderer.add(lod_actor)
    renderer.add(actor.axes(grid))
    if interactive:
        window.show(renderer)

    expected = ['vtkLODActor', 'vtkOpenGLActor', 'vtkOpenGLActor',
                'vtkOpenGLActor']
    report = window.analyze_renderer(renderer)
    npt.assert_equal(report.actors_classnames, expected)
Example #9
0
def show_template_bundles(final_streamlines, template_path, fname):
    """Render streamlines over a faint template contour and save the image.

    Parameters
    ----------
    final_streamlines : sequence of streamlines
        Drawn as orange tubes (linewidth 0.3).
    template_path : str
        Path to a volume readable by nibabel; its non-zero voxels form the
        grey contour drawn at 5% opacity.
    fname : str
        Output path passed to ``window.record`` (one 900x900 frame).
    """
    import nibabel as nib
    from fury import actor, window
    renderer = window.Renderer()
    # ``get_data()`` was deprecated in nibabel 3.0 and removed in 5.0;
    # ``get_fdata()`` gives the same non-zero pattern for the bool mask.
    template_img_data = nib.load(template_path).get_fdata().astype('bool')
    template_actor = actor.contour_from_roi(template_img_data,
                                            color=(50, 50, 50), opacity=0.05)
    renderer.add(template_actor)
    lines_actor = actor.streamtube(final_streamlines, window.colors.orange,
                                   linewidth=0.3)
    renderer.add(lines_actor)
    window.record(renderer, n_frames=1, out_path=fname, size=(900, 900))
Example #10
0
def test_labels(interactive=False):
    """A single text label ends up as the only actor in the scene."""
    hello_label = actor.label("Hello")

    renderer = window.Renderer()
    renderer.add(hello_label)
    renderer.reset_camera()
    renderer.reset_clipping_range()

    if interactive:
        window.show(renderer, reset_camera=False)

    actor_count = renderer.GetActors().GetNumberOfItems()
    npt.assert_equal(actor_count, 1)
Example #11
0
def test_deprecated():
    """Deprecated scene factories and window helpers emit DeprecationWarning.

    ``window.Renderer`` and ``window.renderer`` each warn once and
    ``window.ren`` warns twice. The legacy ``window.add``/``rm``/``rm_all``/
    ``clear`` helpers must still work, and together with the snapshot calls
    they must emit exactly 7 deprecation warnings.
    """
    factories = (
        (lambda: window.Renderer(), 1),
        (lambda: window.renderer(background=(0.0, 1.0, 0.0)), 1),
        (lambda: window.ren(), 2),
    )
    for make_scene, expected_count in factories:
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always", DeprecationWarning)
            scene = make_scene()
            npt.assert_equal(scene.size(), (0, 0))
            npt.assert_equal(len(caught), expected_count)
            npt.assert_(issubclass(caught[-1].category, DeprecationWarning))

    scene = window.Scene()
    with warnings.catch_warnings(record=True) as l_warn:
        warnings.simplefilter("always", DeprecationWarning)
        axes_actor = actor.axes(scale=(1, 1, 1))

        def objects_in_view():
            frame = window.snapshot(scene)
            return window.analyze_snapshot(frame).objects

        window.add(scene, axes_actor)
        npt.assert_equal(objects_in_view(), 3)

        window.rm(scene, axes_actor)
        npt.assert_equal(objects_in_view(), 0)

        window.add(scene, axes_actor)
        window.rm_all(scene)
        npt.assert_equal(objects_in_view(), 0)

        window.add(scene, axes_actor)
        window.clear(scene)
        npt.assert_equal(window.analyze_renderer(scene).actors, 0)

        deprecation_count = sum(
            1 for w in l_warn if issubclass(w.category, DeprecationWarning))
        npt.assert_equal(deprecation_count, 7)
        npt.assert_(issubclass(l_warn[-1].category, DeprecationWarning))
Example #12
0
def test_spheres(interactive=False):
    """Three spheres with distinct RGBA colors are found as 3 objects."""
    centers_radii = np.array([[0, 0, 0, 10], [100, 0, 0, 25], [200, 0, 0, 50]])
    rgba = np.array([[1, 0, 0, 0.3], [0, 1, 0, 0.4], [0, 0, 1., 0.99]])
    centers = centers_radii[:, :3]
    radii = centers_radii[:, 3]

    renderer = window.Renderer()
    renderer.add(actor.sphere(centers=centers, colors=rgba[:], radii=radii))

    if interactive:
        window.show(renderer, order_transparent=True)

    frame = window.snapshot(renderer)
    report = window.analyze_snapshot(frame, colors=rgba)
    npt.assert_equal(report.objects, 3)
Example #13
0
def test_points(interactive=False):
    """Three colored points form one actor and three visible objects."""
    coords = np.array([[0, 0, 0], [0, 1, 0], [1, 0, 0]])
    rgb = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])

    point_cloud = actor.point(coords, rgb)

    renderer = window.Renderer()
    renderer.add(point_cloud)
    renderer.reset_camera()
    renderer.reset_clipping_range()

    if interactive:
        window.show(renderer, reset_camera=False)

    actor_count = renderer.GetActors().GetNumberOfItems()
    npt.assert_equal(actor_count, 1)

    frame = window.snapshot(renderer)
    report = window.analyze_snapshot(frame, colors=rgb)
    npt.assert_equal(report.objects, 3)
Example #14
0
def display_slices(volume_actor,
                   slices,
                   output_filename,
                   axis_name,
                   view_position,
                   focal_point,
                   peaks_actor=None,
                   streamlines_actor=None):
    """Show one slice of a volume and save an off-screen 1920x1080 snapshot.

    Parameters
    ----------
    volume_actor : slicer actor
        Actor whose ``display(x, y, z)`` method selects the slice shown.
    slices : indexable
        Slice indices; entry 0, 1 or 2 is used depending on ``axis_name``
        (presumably x, y, z indices -- confirm with callers).
    output_filename : str
        Path of the snapshot written by ``window.snapshot``.
    axis_name : str
        'sagittal' or 'coronal'; any other value selects the axial slice.
    view_position, focal_point : tuple
        Camera position and focal point.
    peaks_actor : actor, optional
        Sliced like the volume. It is only added to the scene when no
        ``streamlines_actor`` is given.
    streamlines_actor : actor, optional
        Added to the scene when given.
    """
    # axis_name -> (which slice index / display() slot, camera view-up)
    slice_axis, view_up_vector = {
        'sagittal': (0, (0, 0, 1)),
        'coronal': (1, (0, 0, 1)),
    }.get(axis_name, (2, (0, 1, 0)))

    display_args = [None, None, None]
    display_args[slice_axis] = slices[slice_axis]
    volume_actor.display(*display_args)
    if peaks_actor:
        peaks_actor.display(*display_args)

    # Build the scene, point the camera and save the picture.
    scene = window.Renderer()
    scene.add(volume_actor)
    if streamlines_actor:
        scene.add(streamlines_actor)
    elif peaks_actor:
        scene.add(peaks_actor)
    scene.set_camera(position=view_position,
                     view_up=view_up_vector,
                     focal_point=focal_point)

    window.snapshot(scene,
                    size=(1920, 1080),
                    offscreen=True,
                    fname=output_filename)
Example #15
0
def test_dots(interactive=False):
    """Dots actors: three green points, then a single blue point.

    Checks the actor count, the bounds of the three-point actor and the
    number of colored objects found in each snapshot.
    """
    green = (0, 255, 0)
    blue = (0, 0, 255)

    triangle = np.array([[0, 0, 0], [0, 1, 0], [1, 0, 0]])
    triangle_dots = actor.dots(triangle, color=green)

    renderer = window.Renderer()
    renderer.add(triangle_dots)
    renderer.reset_camera()
    renderer.reset_clipping_range()

    if interactive:
        window.show(renderer, reset_camera=False)

    npt.assert_equal(renderer.GetActors().GetNumberOfItems(), 1)

    bounds = renderer.GetActors().GetLastActor().GetBounds()
    npt.assert_equal(bounds, (0.0, 1.0, 0.0, 1.0, 0.0, 0.0))

    frame = window.snapshot(renderer)
    npt.assert_equal(window.analyze_snapshot(frame, colors=green).objects, 3)

    # A single point given as a flat (3,) array.
    origin = np.array([0, 0, 0])
    origin_dot = actor.dots(origin, color=blue)

    renderer.clear()
    renderer.add(origin_dot)
    renderer.reset_camera()
    renderer.reset_clipping_range()

    frame = window.snapshot(renderer)
    npt.assert_equal(window.analyze_snapshot(frame, colors=blue).objects, 1)
Example #16
0
def show_template_bundles(final_streamlines, template_path, fname):
    """Displays the template bundles over a faint contour of the template.

    Parameters
    ----------
    final_streamlines : list
        Generated streamlines, drawn as orange tubes (linewidth 0.3)
    template_path : str
        Path to reference FA nii.gz file; its non-zero voxels are drawn as
        a grey contour at 5% opacity
    fname : str
        Path of the output image, written by ``window.record`` as a single
        900x900 frame
    """

    renderer = window.Renderer()
    # ``get_data()`` was deprecated in nibabel 3.0 and removed in 5.0;
    # ``get_fdata()`` yields the same non-zero pattern for the bool mask.
    template_img_data = nib.load(template_path).get_fdata().astype("bool")
    template_actor = actor.contour_from_roi(
        template_img_data, color=(50, 50, 50), opacity=0.05
    )
    renderer.add(template_actor)
    lines_actor = actor.streamtube(
        final_streamlines, window.colors.orange, linewidth=0.3
    )
    renderer.add(lines_actor)
    window.record(renderer, n_frames=1, out_path=fname, size=(900, 900))
Example #17
0
def test_button_and_slider_widgets():
    """Replay recorded interaction events against button and slider widgets.

    Three image buttons (camera, plus, minus) and a slider are attached to a
    ShowManager whose scene holds one streamtube actor. Each callback
    increments a counter in ``states``. With ``recording = False`` the
    events stored in ``DATA_DIR/test_button_and_slider_widgets.log.gz`` are
    replayed and the counters are compared with the expected click/drag
    counts. With ``recording = True`` a new event log is written instead and
    the counters are printed.
    """
    recording = False
    filename = "test_button_and_slider_widgets.log.gz"
    recording_filename = pjoin(DATA_DIR, filename)
    renderer = window.Renderer()

    # create some minimalistic streamlines
    lines = [
        np.array([[-1, 0, 0.], [1, 0, 0.]]),
        np.array([[-1, 1, 0.], [1, 1, 0.]])
    ]
    colors = np.array([[1., 0., 0.], [0.3, 0.7, 0.]])
    stream_actor = actor.streamtube(lines, colors)

    # Per-widget callback counters, checked after the event replay.
    states = {
        'camera_button_count': 0,
        'plus_button_count': 0,
        'minus_button_count': 0,
        'slider_moved_count': 0,
    }

    renderer.add(stream_actor)

    # the show manager allows to break the rendering process
    # in steps so that the widgets can be added properly
    show_manager = window.ShowManager(renderer, size=(800, 800))

    if recording:
        show_manager.initialize()
        show_manager.render()

    def button_callback(obj, event):
        # print('Camera pressed')
        states['camera_button_count'] += 1

    def button_plus_callback(obj, event):
        # print('+ pressed')
        states['plus_button_count'] += 1

    def button_minus_callback(obj, event):
        # print('- pressed')
        states['minus_button_count'] += 1

    fetch_viz_icons()
    button_png = read_viz_icons(fname='camera.png')

    # NOTE(review): the replayed event log presumably depends on the widget
    # positions/sizes below; keep them unchanged unless re-recording.
    button = widget.button(show_manager.iren, show_manager.ren,
                           button_callback, button_png, (.98, 1.), (80, 50))

    button_png_plus = read_viz_icons(fname='plus.png')
    button_plus = widget.button(show_manager.iren, show_manager.ren,
                                button_plus_callback, button_png_plus,
                                (.98, .9), (120, 50))

    button_png_minus = read_viz_icons(fname='minus.png')
    button_minus = widget.button(show_manager.iren, show_manager.ren,
                                 button_minus_callback, button_png_minus,
                                 (.98, .9), (50, 50))

    def print_status(obj, event):
        # Move the streamtube along x to the slider's current value and
        # count the slider interaction.
        rep = obj.GetRepresentation()
        stream_actor.SetPosition((rep.GetValue(), 0, 0))
        states['slider_moved_count'] += 1

    slider = widget.slider(show_manager.iren,
                           show_manager.ren,
                           callback=print_status,
                           min_value=-1,
                           max_value=1,
                           value=0.,
                           label="X",
                           right_normalized_pos=(.98, 0.6),
                           size=(120, 0),
                           label_format="%0.2lf")

    # NOTE(review): no resize callback is registered in this test, so the
    # buttons/slider keep the positions given above. ``size`` is stored in a
    # module-level global and is not read anywhere in this function.

    global size
    size = renderer.GetSize()

    if recording:
        show_manager.record_events_to_file(recording_filename)
        print(states)
    else:
        show_manager.play_events_from_file(recording_filename)
        npt.assert_equal(states["camera_button_count"], 7)
        npt.assert_equal(states["plus_button_count"], 3)
        npt.assert_equal(states["minus_button_count"], 4)
        npt.assert_equal(states["slider_moved_count"], 116)

    if not recording:
        # Disable the camera button and the slider before analysing the
        # scene; only the streamtube actor is expected to be counted.
        button.Off()
        slider.Off()
        # Uncomment below to test the slider and button with analyze
        # button.place(renderer)
        # slider.place(renderer)

        report = window.analyze_renderer(renderer)
        # import pylab as plt
        # plt.imshow(report.labels, origin='lower')
        # plt.show()
        npt.assert_equal(report.actors, 1)

    report = window.analyze_renderer(renderer)
    npt.assert_equal(report.actors, 1)
Example #18
0
def structural_plotting(conn_matrix,
                        uatlas,
                        streamlines_mni,
                        template_mask,
                        interactive=False):
    """Render a 3D structural connectome over a glass-brain template.

    The scene contains:
    - the streamlines, drawn as a grey "white-matter" cloud;
    - the atlas ROIs reached by at least one streamline end, drawn as faint
      randomly-colored contours;
    - the graph edges of ``conn_matrix`` (after normalization), drawn as
      coral lines between ROI centroids, with labeled endpoints;
    - a faint contour of the ch2better template resampled to
      ``template_mask``.

    :param conn_matrix: square connectivity matrix. Node ``i`` is assumed to
        correspond to the ``i``-th parcel returned by
        ``find_parcellation_cut_coords`` (TODO confirm with callers).
    :param uatlas: path to the atlas parcellation image; the smallest label
        (0) is treated as background.
    :param streamlines_mni: path to a tractogram readable by
        ``nib.streamlines.load``; the figure is written next to it.
    :param template_mask: path to the image the template is resampled to.
    :param interactive: if True, open an interactive window; otherwise save
        ``3d_connectome_fig.png`` in the directory of ``streamlines_mni``.
    :return: None
    """
    import nibabel as nib
    import numpy as np
    import networkx as nx
    import os
    import pkg_resources
    from nibabel.affines import apply_affine
    from fury import actor, window, colormap, ui
    from dipy.tracking.utils import streamline_near_roi
    from nilearn.plotting import find_parcellation_cut_coords
    from nilearn.image import resample_to_img
    from pynets.thresholding import normalize
    from pynets.nodemaker import mmToVox

    ch2better_loc = pkg_resources.resource_filename(
        "pynets", "templates/ch2better.nii.gz")

    # Instantiate scene
    r = window.Renderer()

    # Set camera
    r.set_camera(position=(-176.42, 118.52, 128.20),
                 focal_point=(113.30, 128.31, 76.56),
                 view_up=(0.18, 0.00, 0.98))

    # Load atlas ROIs. ``get_data()`` was removed in nibabel 5; reading
    # ``dataobj`` keeps the on-disk label dtype.
    atlas_img = nib.load(uatlas)
    atlas_img_data = np.asanyarray(atlas_img.dataobj)

    streamlines = nib.streamlines.load(streamlines_mni).streamlines

    # Add streamlines as cloud of 'white-matter'
    streamlines_actor = actor.line(streamlines,
                                   colormap.create_colormap(
                                       np.ones([len(streamlines)]),
                                       name='Greys_r',
                                       auto=True),
                                   lod_points=10000,
                                   depth_cue=True,
                                   linewidth=0.2,
                                   fake_tube=True,
                                   opacity=1.0)
    r.add(streamlines_actor)

    # Create a palette of ROI colors and add the ROIs touched by at least
    # one streamline end to the scene as faint contours.
    roi_colors = np.random.rand(int(np.max(atlas_img_data)), 3)
    for i, roi in enumerate(np.unique(atlas_img_data)[1:]):
        roi_mask = atlas_img_data == roi
        include_roi_coords = np.array(np.where(roi_mask)).T
        x_include_roi_coords = apply_affine(np.eye(4), include_roi_coords)
        # any() stops at the first streamline near the ROI instead of
        # testing every streamline.
        if any(streamline_near_roi(sl,
                                   x_include_roi_coords,
                                   tol=1.0,
                                   mode='either_end')
               for sl in streamlines):
            print('ROI: ' + str(i))
            r.add(actor.contour_from_roi(roi_mask,
                                         color=roi_colors[i],
                                         opacity=0.2))

    # Get voxel coordinates of parcels to use as 3d centroid nodes
    [coords, labels] = find_parcellation_cut_coords(atlas_img,
                                                    background_label=0,
                                                    return_labels=True)

    # Keep coords_vox index-aligned with ``labels`` and with the graph nodes.
    # The former list(set(...)) dedup reordered the entries, so node i was
    # given another parcel's coordinates.
    coords_vox = [tuple(mmToVox(atlas_img.affine, c)) for c in coords]

    # Build the graph; each node stores a single {label: coordinate} item.
    G = nx.from_numpy_array(normalize(conn_matrix))
    for i in G.nodes():
        nx.set_node_attributes(G, {i: coords_vox[i]}, labels[i])

    G.remove_nodes_from(list(nx.isolates(G)))

    coord_nodes = []
    for x, y in G.edges():
        x_label, x_coord = next(iter(G.nodes[x].items()))
        l_x = actor.label(text=str(x_label),
                          pos=x_coord,
                          scale=(1, 1, 1),
                          color=(50, 50, 50))
        r.add(l_x)
        y_label, y_coord = next(iter(G.nodes[y].items()))
        l_y = actor.label(text=str(y_label),
                          pos=y_coord,
                          scale=(1, 1, 1),
                          color=(50, 50, 50))
        r.add(l_y)
        coord_nodes.append(x_coord)
        coord_nodes.append(y_coord)
        weight = float(G.get_edge_data(x, y)['weight'])
        # ``**`` is exponentiation; the previous ``^`` is bitwise XOR and
        # raised TypeError on a float operand.
        c = actor.line([(x_coord, y_coord)],
                       window.colors.coral,
                       linewidth=100 * weight ** 2)
        r.add(c)

    # An empty point cloud cannot be turned into an actor; skip it when the
    # graph has no edges.
    if coord_nodes:
        point_actor = actor.point(list(set(coord_nodes)),
                                  window.colors.grey,
                                  point_radius=0.75)
        r.add(point_actor)

    # Load glass brain template and resample to MNI152_2mm brain
    template_img = nib.load(ch2better_loc)
    template_target_img = nib.load(template_mask)
    res_brain_img = resample_to_img(template_img, template_target_img)
    template_img_data = np.asanyarray(res_brain_img.dataobj).astype('bool')
    template_actor = actor.contour_from_roi(template_img_data,
                                            color=(50, 50, 50),
                                            opacity=0.05)
    r.add(template_actor)

    # Show scene
    if interactive is True:
        window.show(r, size=(600, 600), reset_camera=False)
    else:
        # os.path.join avoids writing to the filesystem root ("/...") when
        # streamlines_mni has no directory component.
        fig_path = os.path.join(os.path.dirname(streamlines_mni),
                                '3d_connectome_fig.png')
        window.record(r, out_path=fig_path, size=(600, 600))
Example #19
0
def test_renderer():
    """Exercise Renderer background color, add/rm/rm_all and the snapshot
    and renderer analysis helpers."""
    ren = window.Renderer()
    npt.assert_equal(ren.size(), (0, 0))

    # Target background is (1, 0.5, 0); the extra 0.001 on green avoids a
    # rounding mismatch when the float color is quantised to 0-255.
    bg_float = (1, 0.501, 0)
    bg_color = tuple(np.round(255 * np.array(bg_float)).astype('uint8'))

    ren.background(bg_float)
    frame = window.snapshot(ren)
    report = window.analyze_snapshot(frame,
                                     bg_color=bg_color,
                                     colors=[bg_color, (0, 127, 0)])
    npt.assert_equal(report.objects, 0)
    npt.assert_equal(report.colors_found, [True, False])

    def visible_objects():
        snapshot = window.snapshot(ren)
        return window.analyze_snapshot(snapshot, bg_color).objects

    axes = actor.axes()
    ren.add(axes)
    npt.assert_equal(visible_objects(), 1)

    ren.rm(axes)
    npt.assert_equal(visible_objects(), 0)

    window.add(ren, axes)
    npt.assert_equal(visible_objects(), 1)

    ren.rm_all()
    npt.assert_equal(visible_objects(), 0)

    # A renderer created with a background that is then overridden to black.
    ren2 = window.Renderer(bg_float)
    ren2.background((0, 0, 0.))
    npt.assert_equal(window.analyze_renderer(ren2).bg_color, (0, 0, 0))

    ren2.add(axes)
    npt.assert_equal(window.analyze_renderer(ren2).actors, 3)

    window.rm(ren2, axes)
    npt.assert_equal(window.analyze_renderer(ren2).actors, 0)
Example #20
0
def test_tensor_slicer(interactive=False):
    """Check tensor_slicer extents, masking, display() and error handling.

    Builds a 3x2x4 grid of identical tensors (eigenvalues 1.4e-3, 0.35e-3,
    0.35e-3 with identity eigenvectors). It then verifies that:
    - restricting the displayed extent, applying a mask and calling
      ``display`` along each axis make the actor bounds smaller than the
      reference bounds;
    - an all-zero mask yields an actor without a mapper;
    - mismatched ``mevals``/``mevecs`` shapes raise RuntimeError.
    """

    evals = np.array([1.4, .35, .35]) * 10 ** (-3)
    evecs = np.eye(3)

    mevals = np.zeros((3, 2, 4, 3))
    mevecs = np.zeros((3, 2, 4, 3, 3))

    # Same eigenvalues/eigenvectors in every voxel.
    mevals[..., :] = evals
    mevecs[..., :, :] = evecs

    from dipy.data import get_sphere

    sphere = get_sphere('symmetric724')

    affine = np.eye(4)
    renderer = window.Renderer()

    tensor_actor = actor.tensor_slicer(mevals, mevecs, affine=affine,
                                       sphere=sphere,  scale=.3)
    I, J, K = mevals.shape[:3]
    renderer.add(tensor_actor)
    renderer.reset_camera()
    renderer.reset_clipping_range()

    # Show x in [0, 1] over the full y/z range; these bounds are the
    # reference ("big") extent for the comparisons below.
    tensor_actor.display_extent(0, 1, 0, J, 0, K)
    tensor_actor.GetProperty().SetOpacity(1.0)
    if interactive:
        window.show(renderer, reset_camera=False)

    npt.assert_equal(renderer.GetActors().GetNumberOfItems(), 1)

    # Test extent
    big_extent = renderer.GetActors().GetLastActor().GetBounds()
    big_extent_x = abs(big_extent[1] - big_extent[0])
    # Selecting the single slice x=2 must shrink the x extent.
    tensor_actor.display(x=2)

    if interactive:
        window.show(renderer, reset_camera=False)

    small_extent = renderer.GetActors().GetLastActor().GetBounds()
    small_extent_x = abs(small_extent[1] - small_extent[0])
    npt.assert_equal(big_extent_x > small_extent_x, True)

    # Test empty mask
    empty_actor = actor.tensor_slicer(mevals, mevecs, affine=affine,
                                      mask=np.zeros(mevals.shape[:3]),
                                      sphere=sphere,  scale=.3)
    npt.assert_equal(empty_actor.GetMapper(), None)

    # Test mask: zero out the corner block [:2, :3, :3] and color the
    # tensors by their color-FA.
    mask = np.ones(mevals.shape[:3])
    mask[:2, :3, :3] = 0
    cfa = color_fa(fractional_anisotropy(mevals), mevecs)
    tensor_actor = actor.tensor_slicer(mevals, mevecs, affine=affine,
                                       mask=mask, scalar_colors=cfa,
                                       sphere=sphere,  scale=.3)
    renderer.clear()
    renderer.add(tensor_actor)
    renderer.reset_camera()
    renderer.reset_clipping_range()

    if interactive:
        window.show(renderer, reset_camera=False)

    mask_extent = renderer.GetActors().GetLastActor().GetBounds()
    mask_extent_x = abs(mask_extent[1] - mask_extent[0])
    npt.assert_equal(big_extent_x > mask_extent_x, True)

    # test display: calling display() with no arguments is expected to give
    # a smaller x extent than the reference (presumably a default single
    # slice -- confirm against tensor_slicer).
    tensor_actor.display()
    current_extent = renderer.GetActors().GetLastActor().GetBounds()
    current_extent_x = abs(current_extent[1] - current_extent[0])
    npt.assert_equal(big_extent_x > current_extent_x, True)
    if interactive:
        window.show(renderer, reset_camera=False)

    # A single y slice must be thinner in y than the reference.
    tensor_actor.display(y=1)
    current_extent = renderer.GetActors().GetLastActor().GetBounds()
    current_extent_y = abs(current_extent[3] - current_extent[2])
    big_extent_y = abs(big_extent[3] - big_extent[2])
    npt.assert_equal(big_extent_y > current_extent_y, True)
    if interactive:
        window.show(renderer, reset_camera=False)

    # A single z slice must be thinner in z than the reference.
    tensor_actor.display(z=1)
    current_extent = renderer.GetActors().GetLastActor().GetBounds()
    current_extent_z = abs(current_extent[5] - current_extent[4])
    big_extent_z = abs(big_extent[5] - big_extent[4])
    npt.assert_equal(big_extent_z > current_extent_z, True)
    if interactive:
        window.show(renderer, reset_camera=False)

    # Test error handling of the method when
    # incompatible dimension of mevals and evecs are passed.
    mevals = np.zeros((3, 2, 3))
    mevecs = np.zeros((3, 2, 4, 3, 3))

    with npt.assert_raises(RuntimeError):
        tensor_actor = actor.tensor_slicer(mevals, mevecs, affine=affine,
                                           mask=mask, scalar_colors=cfa,
                                           sphere=sphere, scale=.3)
Example #21
0
def test_wrong_interactor_style():
    """Adding a UI panel to a renderer whose ShowManager was built with the
    'trackball' interactor style must raise ``TypeError``."""
    ui_panel = ui.Panel2D(size=(300, 150))
    scene = window.Renderer()
    # Keep a reference to the ShowManager for the whole test so it is not
    # garbage-collected before the check below runs.
    trackball_manager = window.ShowManager(scene,
                                           interactor_style='trackball')
    with npt.assert_raises(TypeError):
        ui_panel.add_to_renderer(scene)
Exemple #22
0
def screenshot_tracking(tracking, t1, directory="."):
    """
    Compute 3 view screenshot with streamlines on T1.

    For each of the sagittal, coronal and axial planes, the T1 slice near the
    volume centre (shifted by +5 voxels for the sagittal view) is rendered
    together with short pieces of the streamlines that cross that slice.
    A fourth, full-width panel showing whole streamlines as tubes is stacked
    underneath, and the resulting mosaic is saved as a PNG.

    At most 10001 streamlines are used per panel.

    Parameters
    ----------
    tracking : string
        tractogram filename.
    t1 : string
        t1 filename.
    directory : string
        Directory to save the mosaic.

    Returns
    -------
    name : string
        Path of the mosaic
    """
    tractogram = nib.streamlines.load(tracking, True).tractogram
    t1 = nib.load(t1)
    # Same array that the removed ``get_data()`` used to return.
    t1_data = np.asanyarray(t1.dataobj)

    # Panels stop collecting once more than this many streamlines were kept
    # (the check happens before each append, hence "at most 10001").
    max_streamlines = 10000

    # The intensity rescaling is identical for every view, so compute it once:
    # [min positive value, 99th percentile of positive values] -> [0, 255].
    positive = t1_data[t1_data > 0]
    min_val = np.min(positive)
    max_val = np.percentile(positive, 99)
    t1_color = np.float32(t1_data - min_val) \
        / np.float32(max_val - min_val) * 255.0

    slice_name = ['sagittal', 'coronal', 'axial']
    img_center = [(int(t1_data.shape[0] / 2) + 5, None, None),
                  (None, int(t1_data.shape[1] / 2), None),
                  (None, None, int(t1_data.shape[2] / 2))]
    center = [(330, 90, 60), (70, 330, 60), (70, 90, 400)]
    viewup = [(0, 0, -1), (0, 0, -1), (0, -1, 0)]
    size = (1920, 1080)

    image = np.array([])
    for i, _axis in enumerate(slice_name):
        slice_idx = img_center[i][i]

        # Keep a short segment of each streamline around the first point
        # whose (int-truncated) coordinate along axis ``i`` equals the slice.
        # NOTE(review): comparing streamline coordinates with voxel indices
        # assumes the tractogram points are in the T1 voxel grid - confirm.
        streamlines = []
        for streamline in tractogram:
            if len(streamlines) > max_streamlines:
                break
            stream = streamline.streamline
            hits = np.flatnonzero(
                np.array(stream, dtype=int)[:, i] == slice_idx)
            if hits.size == 0:
                continue
            idx = hits[0]
            lower = max(idx - 2, 0)
            upper = min(idx + 2, len(stream) - 1)
            streamlines.append(stream[lower:upper])

        ren = window.Renderer()

        streamline_actor = actor.line(streamlines, linewidth=0.2)
        ren.add(streamline_actor)

        slice_actor = actor.slicer(t1_color, opacity=0.8, value_range=(0, 255),
                                   interpolation='nearest')
        ren.add(slice_actor)
        slice_actor.display(img_center[i][0], img_center[i][1],
                            img_center[i][2])

        camera = ren.GetActiveCamera()
        camera.SetViewUp(viewup[i])
        center_cam = streamline_actor.GetCenter()
        camera.SetPosition(center[i])
        camera.SetFocalPoint(center_cam)

        img2 = renderer_to_arr(ren, size)
        if image.size == 0:
            image = img2
        else:
            image = np.hstack((image, img2))

    # Bottom panel: the first streamlines of the tractogram, rendered whole.
    streamlines = []
    for streamline in tractogram:
        if len(streamlines) > max_streamlines:
            break
        streamlines.append(streamline.streamline)

    ren = window.Renderer()
    streamline_actor = actor.streamtube(streamlines, linewidth=0.2)
    ren.add(streamline_actor)
    camera = ren.GetActiveCamera()
    camera.SetViewUp(0, 0, -1)
    bundle_center = streamline_actor.GetCenter()
    camera.SetPosition(bundle_center[0], 350, bundle_center[2])
    camera.SetFocalPoint(bundle_center)
    img2 = renderer_to_arr(ren, (3 * 1920, 1920))
    image = np.vstack((image, img2))

    imgs_comb = Image.fromarray(image)
    imgs_comb = imgs_comb.resize((3 * 1920, 1920 + 1080))
    image_name = os.path.basename(str(tracking)).split(".")[0]
    name = os.path.join(directory, image_name + '.png')
    imgs_comb.save(name)

    return name
Exemple #23
0
def test_active_camera():
    """Exercise the camera helpers of ``window.Renderer``.

    Uses a scene holding only an axes actor and covers ``set_camera``,
    ``zoom``, ``azimuth``, ``yaw``, ``elevation``, ``pitch``, ``roll`` and
    ``dolly``. Snapshots confirm which colour is visible after each rotation.
    """
    def check_color(scene, colors):
        # Snapshot the current view and require the colour(s) to be present.
        snap = window.snapshot(scene)
        summary = window.analyze_snapshot(snap, colors=colors)
        npt.assert_equal(summary.colors_found, [True])

    scene = window.Renderer()
    scene.add(actor.axes(scale=(1, 1, 1)))

    scene.reset_camera()
    scene.reset_clipping_range()

    direction = scene.camera_direction()
    _, _, view_up = scene.get_camera()

    # Camera at (0, 0, 1) aimed at the origin; note view_up is re-read here
    # and that value is reused when the camera is reset further down.
    scene.set_camera((0., 0., 1.), (0., 0., 0), view_up)
    cam_pos, _, view_up = scene.get_camera()
    npt.assert_almost_equal(np.dot(direction, cam_pos), -1)

    # Zooming must leave the camera position unchanged.
    scene.zoom(1.5)
    zoomed_pos, _, _ = scene.get_camera()
    npt.assert_array_almost_equal(cam_pos, zoomed_pos)
    scene.zoom(1)

    # Rotate around the focal point.
    scene.azimuth(90)
    cam_pos, _, _ = scene.get_camera()
    npt.assert_almost_equal(cam_pos, (1.0, 0.0, 0))
    check_color(scene, [(255, 0, 0)])

    # Rotate around the camera's own centre.
    scene.yaw(90)
    check_color(scene, [(0, 0, 0)])

    scene.yaw(-90)
    scene.elevation(90)
    check_color(scene, (0, 255, 0))

    scene.set_camera((0., 0., 1.), (0., 0., 0), view_up)

    # Vertical rotation around the focal point, then back.
    scene.pitch(10)
    scene.pitch(-10)

    # Rotate around the direction of projection.
    scene.roll(90)

    # Dolly by 0.5: the camera's z coordinate should double.
    before, _, _ = scene.get_camera()
    scene.dolly(0.5)
    after, _, _ = scene.get_camera()
    npt.assert_almost_equal(before[2], 0.5 * after[2])
def main():
    """Build a mosaic of six orthogonal views for each input bundle.

    Column 0 holds the view names; every following column shows one bundle
    over the anatomical reference, with its name and streamline count
    written in the last row. Thumbnails are rendered into a uniquely-named
    temporary folder inside the output directory. That folder is always
    removed afterwards, and no pre-existing folder is ever deleted.
    """
    import tempfile

    import numpy as np

    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.anat_reference])
    assert_outputs_exist(parser, args, [args.output_name])

    output_names = [
        'axial_superior', 'axial_inferior', 'coronal_posterior',
        'coronal_anterior', 'sagittal_left', 'sagittal_right'
    ]

    list_of_bundles = list(args.inputs)
    # output_dir: where the temporary per-bundle folders will be created
    output_dir = os.path.dirname(args.output_name)

    # ----------------------------------------------------------------------- #
    # Mosaic, column 0: orientation names and data description
    # ----------------------------------------------------------------------- #
    width = args.resolution_of_thumbnails
    height = args.resolution_of_thumbnails
    rows = 6
    cols = len(list_of_bundles)
    text_pos_x = 50
    text_pos_y = 50

    # Creates a new empty image, RGB mode
    mosaic = Image.new('RGB', ((cols + 1) * width, (rows + 1) * height))

    # Prepare draw and font objects to render text
    draw = ImageDraw.Draw(mosaic)
    font = get_font(args)

    # Data of the image used as background (same array get_data() returned;
    # get_data() no longer exists in recent nibabel releases).
    ref_img = nib.load(args.anat_reference)
    data = np.asanyarray(ref_img.dataobj)
    affine = ref_img.affine
    mean, std = data[data > 0].mean(), data[data > 0].std()
    value_range = (mean - 0.5 * std, mean + 1.5 * std)

    # First column with rows description
    draw_column_with_names(draw, output_names, text_pos_x, text_pos_y, height,
                           font)

    # Bundle name / streamline count go in the row below the six views.
    info_y = height * rows + text_pos_y

    # ----------------------------------------------------------------------- #
    # Columns with bundles
    # ----------------------------------------------------------------------- #
    for idx_bundle, bundle_file in enumerate(list_of_bundles):

        bundle_file_name = os.path.basename(bundle_file)
        bundle_name, _ = os.path.splitext(bundle_file_name)
        i = (idx_bundle + 1) * width

        if not os.path.isfile(bundle_file):
            print('\nInput file {} doesn\'t exist.'.format(bundle_file))
            draw_bundle_information(draw, bundle_file_name, 0,
                                    i + text_pos_x, info_y, font)
            continue

        # A fresh folder (never a pre-existing user folder) holds the
        # thumbnails; it is removed even if rendering fails.
        output_bundle_dir = tempfile.mkdtemp(prefix=bundle_name + '_',
                                             dir=output_dir or os.curdir)
        try:
            # Only the file name is formatted, so braces in directory names
            # cannot break the path.
            output_paths = [
                os.path.join(output_bundle_dir,
                             '{}_{}'.format(name, bundle_name))
                for name in output_names
            ]

            # Select the streamlines to plot
            streamlines = nib.streamlines.load(bundle_file).streamlines
            _render_bundle_views(mosaic, streamlines, data, affine,
                                 value_range, output_paths, width, height, i,
                                 args.zoom, args.opacity_background)
        finally:
            shutil.rmtree(output_bundle_dir, ignore_errors=True)

        draw_bundle_information(draw, bundle_file_name, len(streamlines),
                                i + text_pos_x, info_y, font)

    # Save image to file
    mosaic.save(args.output_name)


def _render_bundle_views(mosaic, streamlines, data, affine, value_range,
                         output_paths, width, height, col_x, zoom, opacity):
    """Render the six views of one bundle into its mosaic column.

    Views are captured in the order of ``output_paths`` (axial superior and
    inferior, coronal posterior and anterior, sagittal left and right), each
    via ``set_img_in_cell`` at column offset ``col_x``.
    """
    tubes = actor.line(streamlines)
    ren = window.Renderer()

    # Structural data
    slice_actor = actor.slicer(data, affine, value_range)
    slice_actor.opacity(opacity)
    ren.add(slice_actor)

    # Streamlines
    ren.add(tubes)

    def capture(view_number):
        # Re-frame on the current orientation, zoom, and paste into the row.
        ren.reset_camera()
        ren.zoom(zoom)
        set_img_in_cell(mosaic, ren, view_number,
                        output_paths[view_number], width, height, col_x)

    capture(0)
    ren.pitch(180)
    capture(1)

    # Swap in a copy of the slicer showing the middle slice of axis 1.
    ren.rm(slice_actor)
    coronal_slice = slice_actor.copy()
    coronal_slice.display(None, coronal_slice.shape[1] // 2, None)
    coronal_slice.opacity(opacity)
    ren.add(coronal_slice)

    ren.pitch(90)
    ren.set_camera(view_up=(0, 0, 1))
    capture(2)

    ren.pitch(180)
    ren.set_camera(view_up=(0, 0, 1))
    capture(3)

    # Swap in a copy of the slicer showing the middle slice of axis 0.
    ren.rm(coronal_slice)
    sagittal_slice = slice_actor.copy()
    sagittal_slice.display(sagittal_slice.shape[0] // 2, None, None)
    sagittal_slice.opacity(opacity)
    ren.add(sagittal_slice)

    ren.yaw(90)
    capture(4)

    ren.yaw(180)
    capture(5)
Exemple #25
0
world_coords = True

###############################################################################
# To see the objects in native space, every object that is currently in world
# coordinates must first be transformed back to native space using the
# inverse of the affine.

if not world_coords:
    from dipy.tracking.streamline import transform_streamlines
    streamlines = transform_streamlines(streamlines, np.linalg.inv(affine))

###############################################################################
# Now we create a ``Renderer`` object, add the streamlines using the
# ``line`` function, and build an image plane using the ``slicer`` function.
# In native space the slicer gets an identity affine; otherwise it uses the
# image affine.

ren = window.Renderer()
stream_actor = actor.line(streamlines)

image_actor_z = actor.slicer(data,
                             affine=affine if world_coords else np.eye(4))

###############################################################################
# We can also change the opacity of the slicer.

slicer_opacity = 0.6
image_actor_z.opacity(slicer_opacity)

###############################################################################
# We can add additional slicers by copying the original and adjusting the
Exemple #26
0
def test_contour_from_roi():
    """Check ``actor.contour_from_roi`` on synthetic ROIs and on real data.

    1. A single connected ROI renders as one object.
    2. Two disjoint ROIs render as two objects.
    3. On the Stanford HARDI example, adding the seed-ROI contour to the
       streamlines must not change the number of detected objects. This
       guards against the ROI being drawn far from the streamlines, e.g.
       because of an affine error.

    Snapshots are taken in memory only, so no image files are written to
    the current directory.
    """
    # Render volume
    renderer = window.Renderer()
    data = np.zeros((50, 50, 50))
    data[20:30, 25, 25] = 1.
    data[25, 20:30, 25] = 1.
    affine = np.eye(4)
    surface = actor.contour_from_roi(data, affine,
                                     color=np.array([1, 0, 1]),
                                     opacity=.5)
    renderer.add(surface)

    renderer.reset_camera()
    renderer.reset_clipping_range()

    # Test binarization
    renderer2 = window.Renderer()
    data2 = np.zeros((50, 50, 50))
    data2[20:30, 25, 25] = 1.
    data2[35:40, 25, 25] = 1.
    affine = np.eye(4)
    surface2 = actor.contour_from_roi(data2, affine,
                                      color=np.array([0, 1, 1]),
                                      opacity=.5)
    renderer2.add(surface2)

    renderer2.reset_camera()
    renderer2.reset_clipping_range()

    arr = window.snapshot(renderer, offscreen=True)
    arr2 = window.snapshot(renderer2, offscreen=True)

    report = window.analyze_snapshot(arr, find_objects=True)
    report2 = window.analyze_snapshot(arr2, find_objects=True)

    npt.assert_equal(report.objects, 1)
    npt.assert_equal(report2.objects, 2)

    # test on real streamlines using tracking example
    from dipy.data import read_stanford_labels
    from dipy.reconst.shm import CsaOdfModel
    from dipy.data import default_sphere
    from dipy.direction import peaks_from_model
    from dipy.tracking.local import ThresholdTissueClassifier
    from dipy.tracking import utils
    from dipy.tracking.local import LocalTracking
    from fury.colormap import line_colors

    hardi_img, gtab, labels_img = read_stanford_labels()
    # Same arrays get_data() used to return (removed in recent nibabel).
    data = np.asanyarray(hardi_img.dataobj)
    labels = np.asanyarray(labels_img.dataobj)
    affine = hardi_img.affine

    white_matter = (labels == 1) | (labels == 2)

    csa_model = CsaOdfModel(gtab, sh_order=6)
    csa_peaks = peaks_from_model(csa_model, data, default_sphere,
                                 relative_peak_threshold=.8,
                                 min_separation_angle=45,
                                 mask=white_matter)

    classifier = ThresholdTissueClassifier(csa_peaks.gfa, .25)

    seed_mask = labels == 2
    seeds = utils.seeds_from_mask(seed_mask, density=[1, 1, 1], affine=affine)

    # Initialization of LocalTracking.
    # The computation happens in the next step.
    streamlines = LocalTracking(csa_peaks, classifier, seeds, affine,
                                step_size=2)

    # Compute streamlines and store as a list.
    streamlines = list(streamlines)

    # Prepare the display objects.
    streamlines_actor = actor.line(streamlines, line_colors(streamlines))
    seedroi_actor = actor.contour_from_roi(seed_mask, affine, [0, 1, 1], 0.5)

    # Create the 3d display.
    r = window.Renderer()
    r2 = window.Renderer()
    r.add(streamlines_actor)
    arr3 = window.snapshot(r, offscreen=True)
    report3 = window.analyze_snapshot(arr3, find_objects=True)
    r2.add(streamlines_actor)
    r2.add(seedroi_actor)
    arr4 = window.snapshot(r2, offscreen=True)
    report4 = window.analyze_snapshot(arr4, find_objects=True)

    # assert that the seed ROI rendering is not far
    # away from the streamlines (affine error)
    npt.assert_equal(report3.objects, report4.objects)
Exemple #27
0
def test_odf_slicer(interactive=False):
    """Check ``actor.odf_slicer`` rendering, masking and slice display.

    The ODF volume is backed by a temporary memory-mapped file. The mapping,
    the file descriptor and the file are released even if an assertion in
    the body fails.
    """
    sphere = get_sphere('symmetric362')

    shape = (11, 11, 11, sphere.vertices.shape[0])

    fid, fname = mkstemp(suffix='_odf_slicer.mmap')
    odfs = None
    try:
        odfs = np.memmap(fname, dtype=np.float64, mode='w+',
                         shape=shape)
        _check_odf_slicer(odfs, sphere, interactive)
    finally:
        if odfs is not None:
            # Close the mapping before removing its backing file
            # (presumably required on platforms that refuse to delete a
            # mapped file).
            odfs._mmap.close()
            del odfs
        os.close(fid)
        os.remove(fname)


def _check_odf_slicer(odfs, sphere, interactive):
    """Body of ``test_odf_slicer``; expects the 11x11x11 ODF memmap it made."""
    odfs[:] = 1

    affine = np.eye(4)
    renderer = window.Renderer()

    mask = np.ones(odfs.shape[:3])
    mask[:4, :4, :4] = 0

    odfs[..., 0] = 1

    odf_actor = actor.odf_slicer(odfs, affine,
                                 mask=mask, sphere=sphere, scale=.25,
                                 colormap='plasma')
    fa = np.zeros(odfs.shape[:3])
    fa[:, 0, :] = 1.
    fa[:, -1, :] = 1.
    fa[0, :, :] = 1.
    fa[-1, :, :] = 1.
    fa[5, 5, 5] = 1

    k = 5
    I, J = odfs.shape[:2]

    fa_actor = actor.slicer(fa, affine)
    fa_actor.display_extent(0, I, 0, J, k, k)
    renderer.add(odf_actor)
    renderer.reset_camera()
    renderer.reset_clipping_range()

    odf_actor.display_extent(0, I, 0, J, k, k)
    odf_actor.GetProperty().SetOpacity(1.0)
    if interactive:
        window.show(renderer, reset_camera=False)

    # One visible glyph per voxel of the displayed 11x11 slice.
    arr = window.snapshot(renderer)
    report = window.analyze_snapshot(arr, find_objects=True)
    npt.assert_equal(report.objects, 11 * 11)

    renderer.clear()
    renderer.add(fa_actor)
    renderer.reset_camera()
    renderer.reset_clipping_range()
    if interactive:
        window.show(renderer)

    mask[:] = 0
    mask[5, 5, 5] = 1
    fa[5, 5, 5] = 0
    fa_actor = actor.slicer(fa, None)
    fa_actor.display(None, None, 5)
    odf_actor = actor.odf_slicer(odfs, None, mask=mask,
                                 sphere=sphere, scale=.25,
                                 colormap='plasma',
                                 norm=False, global_cm=True)
    renderer.clear()
    renderer.add(fa_actor)
    renderer.add(odf_actor)
    renderer.reset_camera()
    renderer.reset_clipping_range()
    if interactive:
        window.show(renderer)

    renderer.clear()
    renderer.add(odf_actor)
    renderer.add(fa_actor)
    odfs[:, :, :] = 1
    mask = np.ones(odfs.shape[:3])
    odf_actor = actor.odf_slicer(odfs, None, mask=mask,
                                 sphere=sphere, scale=.25,
                                 colormap='plasma',
                                 norm=False, global_cm=True)

    renderer.clear()
    renderer.add(odf_actor)
    renderer.add(fa_actor)
    renderer.add(actor.axes((11, 11, 11)))
    for i in range(11):
        odf_actor.display(i, None, None)
        fa_actor.display(i, None, None)
        if interactive:
            window.show(renderer)
    for j in range(11):
        odf_actor.display(None, j, None)
        fa_actor.display(None, j, None)
        if interactive:
            window.show(renderer)
    # with mask equal to zero everything should be black
    mask = np.zeros(odfs.shape[:3])
    odf_actor = actor.odf_slicer(odfs, None, mask=mask,
                                 sphere=sphere, scale=.25,
                                 colormap='plasma',
                                 norm=False, global_cm=True)
    renderer.clear()
    renderer.add(odf_actor)
    renderer.reset_camera()
    renderer.reset_clipping_range()
    if interactive:
        window.show(renderer)

    report = window.analyze_renderer(renderer)
    npt.assert_equal(report.actors, 1)
    npt.assert_equal(report.actors_classnames[0], 'vtkLODActor')
Exemple #28
0
def test_custom_interactor_style_events(recording=False):
    """Replay (or record) a user session and count interactor events.

    Builds a scene with a mouse-following cursor and two streamtubes, and
    registers callbacks on ``interactor.CustomInteractorStyle``. With
    ``recording=True`` the interaction is recorded into the log file under
    ``DATA_DIR`` and the observed counts are printed. Otherwise the log is
    replayed and each event count must match ``expected`` exactly.
    The expected numbers therefore depend on the exact set and order of
    callbacks registered below.
    """
    print("Using VTK {}".format(vtk.vtkVersion.GetVTKVersion()))
    filename = "test_custom_interactor_style_events.log.gz"
    recording_filename = pjoin(DATA_DIR, filename)
    renderer = window.Renderer()

    # The show manager lets the rendering process be broken into steps,
    # so that widgets can be added properly.
    interactor_style = interactor.CustomInteractorStyle()
    show_manager = window.ShowManager(renderer,
                                      size=(800, 800),
                                      reset_camera=False,
                                      interactor_style=interactor_style)

    # Create a cursor, a circle that will follow the mouse.
    polygon_source = vtk.vtkRegularPolygonSource()
    polygon_source.GeneratePolygonOff()  # Only the outline of the circle.
    polygon_source.SetNumberOfSides(50)
    polygon_source.SetRadius(10)
    # polygon_source.SetRadius
    polygon_source.SetCenter(0, 0, 0)

    mapper = vtk.vtkPolyDataMapper2D()
    vtk_utils.set_input(mapper, polygon_source.GetOutputPort())

    cursor = vtk.vtkActor2D()
    cursor.SetMapper(mapper)
    cursor.GetProperty().SetColor(1, 0.5, 0)
    renderer.add(cursor)

    # Move the cursor to the event's position and redraw.
    def follow_mouse(iren, obj):
        obj.SetPosition(*iren.event.position)
        iren.force_render()

    interactor_style.add_active_prop(cursor)
    interactor_style.add_callback(cursor, "MouseMoveEvent", follow_mouse)

    # create some minimalistic streamlines
    lines = [
        np.array([[-1, 0, 0.], [1, 0, 0.]]),
        np.array([[-1, 1, 0.], [1, 1, 0.]])
    ]
    colors = np.array([[1., 0., 0.], [0.3, 0.7, 0.]])
    tube1 = actor.streamtube([lines[0]], colors[0])
    tube2 = actor.streamtube([lines[1]], colors[1])
    renderer.add(tube1)
    renderer.add(tube2)

    # Define some counter callback: number of times each event name fired.
    states = defaultdict(lambda: 0)

    def counter(iren, obj):
        states[iren.event.name] += 1

    # Assign the counter callback to every possible event.
    for event in [
            "CharEvent", "MouseMoveEvent", "KeyPressEvent", "KeyReleaseEvent",
            "LeftButtonPressEvent", "LeftButtonReleaseEvent",
            "RightButtonPressEvent", "RightButtonReleaseEvent",
            "MiddleButtonPressEvent", "MiddleButtonReleaseEvent"
    ]:
        interactor_style.add_callback(tube1, event, counter)

    # Add callback to scale up/down tube1.
    # NOTE(review): these are registered on tube2, not tube1 - the comment
    # above may be stale. Both also call counter(), which is how the
    # mouse-wheel events end up in `states`.
    def scale_up_obj(iren, obj):
        counter(iren, obj)
        scale = np.asarray(obj.GetScale()) + 0.1
        obj.SetScale(*scale)
        iren.force_render()
        iren.event.abort()  # Stop propagating the event.

    def scale_down_obj(iren, obj):
        counter(iren, obj)
        scale = np.array(obj.GetScale()) - 0.1
        obj.SetScale(*scale)
        iren.force_render()
        iren.event.abort()  # Stop propagating the event.

    interactor_style.add_callback(tube2, "MouseWheelForwardEvent",
                                  scale_up_obj)
    interactor_style.add_callback(tube2, "MouseWheelBackwardEvent",
                                  scale_down_obj)

    # Add callback to hide/show tube1 when the "v" key is typed.
    def toggle_visibility(iren, obj):
        key = iren.event.key
        if key.lower() == "v":
            obj.SetVisibility(not obj.GetVisibility())
            iren.force_render()

    # tube2 is added and then removed again, so it does not remain an
    # active prop (presumably exercising remove_active_prop).
    interactor_style.add_active_prop(tube1)
    interactor_style.add_active_prop(tube2)
    interactor_style.remove_active_prop(tube2)
    interactor_style.add_callback(tube1, "CharEvent", toggle_visibility)

    if recording:
        show_manager.record_events_to_file(recording_filename)
        print(list(states.items()))
    else:
        show_manager.play_events_from_file(recording_filename)
        msg = ("Wrong count for '{}'.")
        # Counts observed when the log file was recorded.
        expected = [('CharEvent', 6), ('KeyPressEvent', 6),
                    ('KeyReleaseEvent', 6), ('MouseMoveEvent', 1652),
                    ('LeftButtonPressEvent', 1), ('RightButtonPressEvent', 1),
                    ('MiddleButtonPressEvent', 2),
                    ('LeftButtonReleaseEvent', 1),
                    ('MouseWheelForwardEvent', 3),
                    ('MouseWheelBackwardEvent', 1),
                    ('MiddleButtonReleaseEvent', 2),
                    ('RightButtonReleaseEvent', 1)]

        # Useful loop for debugging: print every mismatch before asserting.
        for event, count in expected:
            if states[event] != count:
                print("{}: {} vs. {} (expected)".format(
                    event, states[event], count))

        for event, count in expected:
            npt.assert_equal(states[event], count, err_msg=msg.format(event))