示例#1
0
def create_panel(line_slider_x, line_slider_label_x, line_slider_y,
                 line_slider_label_y, line_slider_z, line_slider_label_z,
                 opacity_slider, opacity_slider_label):
    """
    Build the translucent, right-aligned panel holding the slice sliders.

    Each of the four rows pairs a text label (left) with its slider (right).
    Positions are given in ``'relative'`` mode, i.e. as fractions of the
    panel size.

    Returns
    -------
    ui.Panel2D
        The populated panel; the caller is responsible for adding it to a
        scene.
    """
    container = ui.Panel2D(center=(1030, 120), size=(300, 200),
                           color=(1, 1, 1), opacity=0.1, align="right")

    # (element, relative position) -- kept in the original insertion order.
    layout = (
        (line_slider_label_x, (0.1, 0.75)),
        (line_slider_x, (0.65, 0.8)),
        (line_slider_label_y, (0.1, 0.55)),
        (line_slider_y, (0.65, 0.6)),
        (line_slider_label_z, (0.1, 0.35)),
        (line_slider_z, (0.65, 0.4)),
        (opacity_slider_label, (0.1, 0.15)),
        (opacity_slider, (0.65, 0.2)),
    )
    for element, rel_pos in layout:
        container.add_element(element, 'relative', rel_pos)

    return container
示例#2
0
    label.background = (0, 0, 0)
    label.color = (1, 1, 1)

    return label


# One text label per slider, built with the ``build_label`` helper above.
line_slider_label_z = build_label(text="Z Slice")
line_slider_label_x = build_label(text="X Slice")
line_slider_label_y = build_label(text="Y Slice")
opacity_slider_label = build_label(text="Opacity")
"""
Now we will create a ``panel`` to contain the sliders and labels.
"""

# Translucent white panel aligned to the right; its center is assigned after
# construction rather than passed to the constructor.
panel = ui.Panel2D(size=(300, 200),
                   color=(1, 1, 1),
                   opacity=0.1,
                   align="right")
panel.center = (1030, 120)

# Four rows (X, Y, Z, opacity): label at x=0.1, slider at x=0.38.
# Coordinates are presumably fractions of the panel size -- confirm against
# the Panel2D.add_element documentation of the fury version in use.
panel.add_element(line_slider_label_x, (0.1, 0.75))
panel.add_element(line_slider_x, (0.38, 0.75))
panel.add_element(line_slider_label_y, (0.1, 0.55))
panel.add_element(line_slider_y, (0.38, 0.55))
panel.add_element(line_slider_label_z, (0.1, 0.35))
panel.add_element(line_slider_z, (0.38, 0.35))
panel.add_element(opacity_slider_label, (0.1, 0.15))
panel.add_element(opacity_slider, (0.38, 0.15))

# ``ren`` is the renderer/scene created earlier in this script (not shown).
ren.add(panel)
"""
Then, we can render all the widgets and everything else in the screen and
示例#3
0
def visualize_volume(volume,
                     x=None,
                     y=None,
                     z=None,
                     figure=None,
                     flip_axes=None,
                     opacity=0.6,
                     inline=True,
                     interact=False):
    """
    Visualize a volume as three orthogonal slices.

    Parameters
    ----------
    volume : ndarray or str
        3d volume to visualize. Passed through ``vut.load_volume``.

    x : int, optional
        Index of the slice shown along the first axis.
        Default: the middle of that axis.

    y : int, optional
        Index of the slice shown along the second axis.
        Default: the middle of that axis.

    z : int, optional
        Index of the slice shown along the third axis.
        Default: the slice ``actor.slicer`` displays on its own.

    figure : fury Scene object, optional
        If provided, the visualization will be added to this Scene. Default:
        Initialize a new Scene.

    flip_axes : None
        This parameter is to conform fury and plotly APIs. It is ignored.

    opacity : float, optional
        Initial opacity of slices.
        Default: 0.6

    inline : bool
        Whether to embed the visualization inline in a notebook. Only works
        in the notebook context. Default: True.

    interact : bool
        Whether to provide an interactive VTK window with sliders for the
        three slice positions and the opacity. Default: False

    Returns
    -------
    Fury Scene object
        Whatever ``_inline_interact(figure, inline, interact)`` returns.
    """
    volume = vut.load_volume(volume)

    if figure is None:
        figure = window.Scene()

    shape = volume.shape
    image_actor_z = actor.slicer(volume)
    slicer_opacity = opacity
    image_actor_z.opacity(slicer_opacity)

    # The x and y actors start as copies of the z actor and are then
    # restricted to a single slice along their own axis.
    image_actor_x = image_actor_z.copy()
    if x is None:
        x = int(np.round(shape[0] / 2))
    image_actor_x.display_extent(x, x, 0, shape[1] - 1, 0, shape[2] - 1)

    image_actor_y = image_actor_z.copy()
    if y is None:
        y = int(np.round(shape[1] / 2))
    image_actor_y.display_extent(0, shape[0] - 1, y, y, 0, shape[2] - 1)

    # ``z`` used to be accepted but silently ignored. Honour it when given;
    # otherwise keep whatever slice ``actor.slicer`` shows by default.
    if z is not None:
        image_actor_z.display_extent(0, shape[0] - 1, 0, shape[1] - 1, z, z)

    figure.add(image_actor_z)
    figure.add(image_actor_x)
    figure.add(image_actor_y)

    show_m = window.ShowManager(figure, size=(1200, 900))
    show_m.initialize()

    if interact:
        # Start every slider at the slice that is actually displayed
        # (previously they always started at the middle of the axis, even
        # when x/y/z were given explicitly).
        line_slider_z = ui.LineSlider2D(
            min_value=0,
            max_value=shape[2] - 1,
            initial_value=shape[2] / 2 if z is None else z,
            text_template="{value:.0f}",
            length=140)

        line_slider_x = ui.LineSlider2D(min_value=0,
                                        max_value=shape[0] - 1,
                                        initial_value=x,
                                        text_template="{value:.0f}",
                                        length=140)

        line_slider_y = ui.LineSlider2D(min_value=0,
                                        max_value=shape[1] - 1,
                                        initial_value=y,
                                        text_template="{value:.0f}",
                                        length=140)

        opacity_slider = ui.LineSlider2D(min_value=0.0,
                                         max_value=1.0,
                                         initial_value=slicer_opacity,
                                         length=140)

        def change_slice_z(slider):
            z_idx = int(np.round(slider.value))
            image_actor_z.display_extent(0, shape[0] - 1, 0, shape[1] - 1,
                                         z_idx, z_idx)

        def change_slice_x(slider):
            x_idx = int(np.round(slider.value))
            image_actor_x.display_extent(x_idx, x_idx, 0, shape[1] - 1, 0,
                                         shape[2] - 1)

        def change_slice_y(slider):
            y_idx = int(np.round(slider.value))
            image_actor_y.display_extent(0, shape[0] - 1, y_idx, y_idx, 0,
                                         shape[2] - 1)

        def change_opacity(slider):
            new_opacity = slider.value
            image_actor_z.opacity(new_opacity)
            image_actor_x.opacity(new_opacity)
            image_actor_y.opacity(new_opacity)

        line_slider_z.on_change = change_slice_z
        line_slider_x.on_change = change_slice_x
        line_slider_y.on_change = change_slice_y
        opacity_slider.on_change = change_opacity

        def build_label(text):
            """Return a white-on-black, 18 pt Arial text label."""
            label = ui.TextBlock2D()
            label.message = text
            label.font_size = 18
            label.font_family = 'Arial'
            label.justification = 'left'
            label.bold = False
            label.italic = False
            label.shadow = False
            label.background = (0, 0, 0)
            label.color = (1, 1, 1)

            return label

        line_slider_label_z = build_label(text="Z Slice")
        line_slider_label_x = build_label(text="X Slice")
        line_slider_label_y = build_label(text="Y Slice")
        opacity_slider_label = build_label(text="Opacity")

        panel = ui.Panel2D(size=(300, 200),
                           color=(1, 1, 1),
                           opacity=0.1,
                           align="right")
        panel.center = (1030, 120)

        panel.add_element(line_slider_label_x, (0.1, 0.75))
        panel.add_element(line_slider_x, (0.38, 0.75))
        panel.add_element(line_slider_label_y, (0.1, 0.55))
        panel.add_element(line_slider_y, (0.38, 0.55))
        panel.add_element(line_slider_label_z, (0.1, 0.35))
        panel.add_element(line_slider_z, (0.38, 0.35))
        panel.add_element(opacity_slider_label, (0.1, 0.15))
        panel.add_element(opacity_slider, (0.38, 0.15))

        show_m.scene.add(panel)

        # Last known window size, kept in this closure instead of a module
        # global so calls do not share or clobber each other's state.
        last_size = figure.GetSize()

        def win_callback(obj, event):
            """Shift the right-aligned panel horizontally on window resize."""
            nonlocal last_size
            new_size = obj.GetSize()
            if last_size != new_size:
                size_change = [new_size[0] - last_size[0], 0]
                last_size = new_size
                panel.re_align(size_change)

    show_m.initialize()

    figure.zoom(1.5)
    figure.reset_clipping_range()

    if interact:
        show_m.add_window_callback(win_callback)
        show_m.render()
        show_m.start()

    return _inline_interact(figure, inline, interact)
示例#4
0
def show_mosaic(data):
    """Show the slices of ``data`` along its third axis as a clickable mosaic.

    Slices are laid out row by row in a grid of at most 10 rows x 15 columns,
    seen through a parallel projection. Left-clicking a slice writes the
    picked voxel index ``(i, j, k)`` and ``data[i, j, k]`` into a small panel.

    Parameters
    ----------
    data : ndarray
        Volume passed to ``actor.slicer`` and indexed as ``data[i, j, k]``
        when reporting a picked value.
    """
    renderer = window.Renderer()
    renderer.background((0.5, 0.5, 0.5))
    renderer.projection('parallel')

    slice_actor = actor.slicer(data)

    # A single ShowManager: the original created two around the same
    # renderer (the first one was never started), attaching it to two
    # render windows.
    show_m = window.ShowManager(renderer, size=(1200, 900))
    show_m.initialize()

    label_position = ui.TextBlock2D(text='Position:')
    label_value = ui.TextBlock2D(text='Value:')

    result_position = ui.TextBlock2D(text='')
    result_value = ui.TextBlock2D(text='')

    panel_picking = ui.Panel2D(center=(200, 120),
                               size=(250, 125),
                               color=(0, 0, 0),
                               opacity=0.75,
                               align="left")

    panel_picking.add_element(label_position, 'relative', (0.1, 0.55))
    panel_picking.add_element(label_value, 'relative', (0.1, 0.25))

    panel_picking.add_element(result_position, 'relative', (0.45, 0.55))
    panel_picking.add_element(result_value, 'relative', (0.45, 0.25))

    def left_click_callback_mosaic(obj, ev):
        """Get the value of the clicked voxel and show it in the panel."""
        event_pos = show_m.iren.GetEventPosition()

        obj.picker.Pick(event_pos[0], event_pos[1], 0, show_m.ren)

        i, j, k = obj.picker.GetPointIJK()
        result_position.message = '({}, {}, {})'.format(str(i), str(j), str(k))
        result_value.message = '%.8f' % data[i, j, k]

    X, Y, Z = slice_actor.shape[:3]

    rows = 10
    cols = 15
    border = 10

    # Valid slice indices are 0..Z-1. The original loop stopped only once
    # ``cnt > Z`` and therefore also tried to display the non-existent
    # slice Z.
    for cnt in range(min(Z, rows * cols)):
        j, i = divmod(cnt, cols)  # row-major placement, as before
        slice_mosaic = slice_actor.copy()
        slice_mosaic.display(None, None, cnt)
        slice_mosaic.SetPosition(
            (X + border) * i, 0.5 * cols * (Y + border) - (Y + border) * j,
            0)
        # Disable interpolation so what you see is the voxel you pick.
        slice_mosaic.SetInterpolate(False)
        slice_mosaic.AddObserver('LeftButtonPressEvent',
                                 left_click_callback_mosaic, 1.0)
        renderer.add(slice_mosaic)

    renderer.reset_camera()
    renderer.zoom(1.6)

    # Added last so the panel is drawn on top of the mosaic.
    show_m.ren.add(panel_picking)
    show_m.start()
示例#5
0
def test_ui_button_panel(recording=False):
    """Record or replay UI events on a panel holding a rectangle and a button.

    Parameters
    ----------
    recording : bool
        If True, record events to ``DATA_DIR/test_ui_button_panel.log.gz``
        and save the observed event counts to the matching ``.pkl`` file.
        If False (default), replay the recorded events and check the counts
        against the saved ones.
    """
    filename = "test_ui_button_panel"
    recording_filename = pjoin(DATA_DIR, filename + ".log.gz")
    expected_events_counts_filename = pjoin(DATA_DIR, filename + ".pkl")

    # Rectangle
    rectangle_test = ui.Rectangle2D(size=(10, 10))
    # Return value unused; presumably called only to exercise get_actors().
    rectangle_test.get_actors()
    # Never added to the panel: only used below to check that an invalid
    # coordinate mode is rejected.
    another_rectangle_test = ui.Rectangle2D(size=(1, 1))
    # /Rectangle

    # Button
    fetch_viz_icons()

    # Two icon states: 'stop' and 'play'.
    icon_files = dict()
    icon_files['stop'] = read_viz_icons(fname='stop2.png')
    icon_files['play'] = read_viz_icons(fname='play3.png')

    button_test = ui.Button2D(icon_fnames=icon_files)
    button_test.set_center((20, 20))

    def make_invisible(i_ren, obj, button):
        # i_ren: CustomInteractorStyle
        # obj: vtkActor picked
        # button: Button2D
        # Hide the button, redraw, then abort the event (presumably so no
        # other handler receives it).
        button.set_visibility(False)
        i_ren.force_render()
        i_ren.event.abort()

    def modify_button_callback(i_ren, obj, button):
        # i_ren: CustomInteractorStyle
        # obj: vtkActor picked
        # button: Button2D
        # Cycle to the next icon and redraw.
        button.next_icon()
        i_ren.force_render()

    # Right click hides the button; left click cycles its icon.
    button_test.on_right_mouse_button_pressed = make_invisible
    button_test.on_left_mouse_button_pressed = modify_button_callback

    button_test.scale((2, 2))
    # Read-then-write of the same color exercises the getter and setter.
    button_color = button_test.color
    button_test.color = button_color
    # /Button

    # Panel
    panel = ui.Panel2D(center=(440, 90),
                       size=(300, 150),
                       color=(1, 1, 1),
                       align="right")
    # One element in 'absolute' mode and one in 'relative' mode.
    panel.add_element(rectangle_test, 'absolute', (580, 150))
    panel.add_element(button_test, 'relative', (0.2, 0.2))
    # An unknown coordinate mode string must raise ValueError.
    npt.assert_raises(ValueError, panel.add_element, another_rectangle_test,
                      'error_string', (1, 2))
    # /Panel

    # Assign the counter callback to every possible event.
    event_counter = EventCounter()
    event_counter.monitor(button_test)
    event_counter.monitor(panel)

    current_size = (600, 600)
    show_manager = window.ShowManager(size=current_size, title="DIPY Button")

    show_manager.ren.add(panel)

    if recording:
        # Capture a new event log and store its counts as the expectation.
        show_manager.record_events_to_file(recording_filename)
        print(list(event_counter.events_counts.items()))
        event_counter.save(expected_events_counts_filename)

    else:
        # Replay the stored log and compare counts with the stored ones.
        show_manager.play_events_from_file(recording_filename)
        expected = EventCounter.load(expected_events_counts_filename)
        event_counter.check_counts(expected)
def main():
    """Interactively explore TMS effects on a tractography bundle.

    Prompts (``input``) for a ``.trk`` tractography file, the T1fs_conform
    image produced during head meshing, an FA image, the head mesh and an
    output directory. It then:

    1. co-registers the T1 and FA images (center of mass -> translation ->
       rigid -> affine);
    2. moves the streamlines into T1 voxel space;
    3. computes normalized derivatives along every streamline;
    4. opens a FURY window with three T1 slices, the streamlines and a torus
       standing for the coil.

    Left-clicking the axial slice moves the coil. The "Calculate" button
    recolours the streamlines with ``change_TMS_effects`` (defined
    elsewhere) at the picked voxel.

    Results are shared through module globals (``mesh_path``,
    ``subject_name``, ``out_dir``, ``new_streams_T1_array``,
    ``streams_array_derivative``, ``bundle_native``, ``colors``, ``bar`` and
    ``i_coord``/``j_coord``/``k_coord``). They are presumably read by
    ``change_TMS_effects`` -- verify against its definition.
    """
    import os

    # Read the tractography data (trk format): streamlines and file header.
    # Streamlines should be in the same coordinate system as the FA map
    # used later.
    # input example: '/home/Example_data/tracts.trk'
    tractography_file = input(
        "Please, specify the file with tracts that you would like to analyse. File should be in the trk format. "
    )

    streams, hdr = load_trk(tractography_file)  # for old DIPY version
    # sft = load_trk(tractography_file, tractography_file)
    # streams = sft.streamlines
    print('imported tractography data:' + tractography_file)

    # T1fs_conform image: same coordinates as simnibs, except that the mesh
    # center is located at the image center. It should be generated in
    # advance during the head meshing procedure.
    # input example: fname_T1='/home/Example_data/T1fs_conform.nii.gz'
    fname_T1 = input(
        "Please, specify the T1fs_conform image that has been generated during head meshing procedure. "
    )
    data_T1, affine_T1 = load_nifti(fname_T1)

    # FA image in the same coordinates as the tracts.
    # input example: fname_FA='/home/Example_data/DTI_FA.nii'
    fname_FA = input("Please, specify the FA image. ")
    data_FA, affine_FA = load_nifti(fname_FA)

    print('loaded T1fs_conform.nii and FA images')

    # Head mesh later used in simnibs to simulate the induced electric field.
    # input example: '/home/Example_data/SUBJECT_MESH.msh'
    global mesh_path
    mesh_path = input("Please, specify the head mesh file. ")

    # Subject name = mesh file name without directory and extension. The
    # previous manual '/' search crashed on bare file names (max() of an
    # empty list) and assumed a 4-character extension.
    global subject_name
    subject_name = os.path.splitext(os.path.basename(mesh_path))[0]

    # Directory for the simulation results.
    # input example: '/home/Example_data/Output'
    global out_dir
    results_dir = input(
        "Please, specify the directory where you would like to save your simulation results. "
    )
    # ``out_dir`` is a path prefix, not a directory: the simulation position
    # is appended to it.
    out_dir = results_dir + '/simulation_at_pos_'

    # Co-registration of T1fs_conform and FA images, in 4 steps.
    # Step 1. Center of mass transform, used as the starting transform.
    c_of_mass = transform_centers_of_mass(data_T1, affine_T1, data_FA,
                                          affine_FA)
    print('calculated c_of_mass transformation')

    # Step 2. 3D translation transform, the starting transform for step 3.
    nbins = 32
    sampling_prop = None
    metric = MutualInformationMetric(nbins, sampling_prop)
    level_iters = [10000, 1000, 100]
    sigmas = [3.0, 1.0, 0.0]
    factors = [4, 2, 1]
    affreg = AffineRegistration(metric=metric,
                                level_iters=level_iters,
                                sigmas=sigmas,
                                factors=factors)

    transform = TranslationTransform3D()
    params0 = None
    starting_affine = c_of_mass.affine
    translation = affreg.optimize(data_T1,
                                  data_FA,
                                  transform,
                                  params0,
                                  affine_T1,
                                  affine_FA,
                                  starting_affine=starting_affine)
    print('calculated 3D translation transform')

    # Step 3. Rigid 3D transform, the starting transform for step 4.
    transform = RigidTransform3D()
    params0 = None
    starting_affine = translation.affine
    rigid = affreg.optimize(data_T1,
                            data_FA,
                            transform,
                            params0,
                            affine_T1,
                            affine_FA,
                            starting_affine=starting_affine)
    print('calculated Rigid 3D transform')

    # Step 4. Affine transform, used to co-register the T1 and FA images.
    transform = AffineTransform3D()
    params0 = None
    starting_affine = rigid.affine
    affine = affreg.optimize(data_T1,
                             data_FA,
                             transform,
                             params0,
                             affine_T1,
                             affine_FA,
                             starting_affine=starting_affine)

    print('calculated Affine 3D transform')

    identity = np.eye(4)
    inv_affine_FA = np.linalg.inv(affine_FA)

    # Streamlines -> FA voxel space (assumes the streamlines share the FA
    # image's world coordinates, as required above).
    new_streams_FA = streamline.transform_streamlines(streams, inv_affine_FA)

    # T1 voxel -> FA voxel mapping through the registration; its inverse
    # takes the streamlines from FA voxel space to T1 voxel space.
    T1_to_FA = np.dot(inv_affine_FA, np.dot(affine.affine, affine_T1))
    FA_to_T1 = np.linalg.inv(T1_to_FA)

    new_streams_T1 = streamline.transform_streamlines(new_streams_FA, FA_to_T1)
    global new_streams_T1_array
    new_streams_T1_array = np.asarray(new_streams_T1)

    # Derivatives along the streamlines, normalized to unit length, give the
    # local orientation of each streamline.
    global streams_array_derivative
    streams_array_derivative = copy.deepcopy(new_streams_T1_array)

    print('calculating amline derivatives')
    for stream_idx, my_steam in enumerate(new_streams_T1_array):
        # ``derivative`` refers to the stored array, so writes are in place.
        derivative = streams_array_derivative[stream_idx]
        for t in range(len(my_steam)):
            for axis in range(3):
                derivative[t, axis] = my_deriv(t, my_steam[:, axis])
            deriv_norm = np.linalg.norm(derivative[t, :])
            derivative[t, :] = derivative[t, :] / deriv_norm

    # Torus representing the coil in the interactive window.
    torus = vtk.vtkParametricTorus()
    torus.SetRingRadius(5)
    torus.SetCrossSectionRadius(2)

    torusSource = vtk.vtkParametricFunctionSource()
    torusSource.SetParametricFunction(torus)
    torusSource.SetScalarModeToPhase()

    torusMapper = vtk.vtkPolyDataMapper()
    torusMapper.SetInputConnection(torusSource.GetOutputPort())
    torusMapper.SetScalarRange(0, 360)

    torusActor = vtk.vtkActor()
    torusActor.SetMapper(torusMapper)

    torus_pos_x = 100
    torus_pos_y = 129
    torus_pos_z = 211
    torusActor.SetPosition(torus_pos_x, torus_pos_y, torus_pos_z)

    # Last voxel picked on the axial slice. Start at the initial coil
    # position so that "Calculate" works before the first click (it
    # previously raised NameError on the undefined globals).
    global i_coord, j_coord, k_coord
    i_coord, j_coord, k_coord = torus_pos_x, torus_pos_y, torus_pos_z

    list_streams_T1 = list(new_streams_T1)
    # Append one fictive bundle of length 1 at [0, 0, 0] to avoid some bugs
    # with actor.line during visualization.
    list_streams_T1.append(np.array([0, 0, 0]))

    global bundle_native
    bundle_native = list_streams_T1

    # Colors used to visualize the stimulation effects.
    effect_max = 0.100
    effect_min = -0.100
    global colors
    colors = [
        np.random.rand(*current_streamline.shape)
        for current_streamline in bundle_native
    ]

    # Every real streamline starts with the mid-range colour (zero effect).
    # The colormap value is loop-invariant, so compute it once.
    neutral_color = vtkplotter.colors.colorMap(
        (effect_min + effect_max) / 2,
        name='jet',
        vmin=effect_min,
        vmax=effect_max)
    for streamline_colors in colors[:-1]:
        streamline_colors[:] = neutral_color

    # The fictive last bundle gets the minimum-effect colour. Index -1 is
    # used instead of the leaked loop variable, which was undefined when
    # the tractogram contained no streamlines.
    colors[-1] = vtkplotter.colors.colorMap(effect_min,
                                            name='jet',
                                            vmin=effect_min,
                                            vmax=effect_max)

    # Visualization of fibers over T1.
    number_of_stimulations = 0

    actor_line_list = []

    scene = window.Scene()
    scene.clear()
    scene.background((0.5, 0.5, 0.5))

    shape = data_T1.shape

    lut = actor.colormap_lookup_table(scale_range=(effect_min, effect_max),
                                      hue_range=(0.4, 1.),
                                      saturation_range=(1, 1.))

    # # the lines below is for a non-interactive demonstration run only.
    # # they should remain commented unless you set "interactive" to False
    # lut, colors = change_TMS_effects(torus_pos_x, torus_pos_y, torus_pos_z)
    # bar =  actor.scalar_bar(lut)
    # bar.SetTitle("TMS effect")
    # bar.SetHeight(0.3)
    # bar.SetWidth(0.10)
    # bar.SetPosition(0.85, 0.3)
    # scene.add(bar)

    actor_line_list.append(
        actor.line(bundle_native,
                   colors,
                   linewidth=5,
                   fake_tube=True,
                   lookup_colormap=lut))

    # Streamlines are already in T1 voxel space, so the slicer uses the
    # identity affine (both former world_coords branches were identical).
    image_actor_z = actor.slicer(data_T1, identity)

    slicer_opacity = 0.6
    image_actor_z.opacity(slicer_opacity)

    image_actor_x = image_actor_z.copy()
    x_midpoint = int(np.round(shape[0] / 2))
    image_actor_x.display_extent(x_midpoint, x_midpoint, 0, shape[1] - 1, 0,
                                 shape[2] - 1)

    image_actor_y = image_actor_z.copy()
    y_midpoint = int(np.round(shape[1] / 2))
    image_actor_y.display_extent(0, shape[0] - 1, y_midpoint, y_midpoint, 0,
                                 shape[2] - 1)

    # Connect the actors with the scene.
    scene.add(actor_line_list[0])
    scene.add(image_actor_z)
    scene.add(image_actor_x)
    scene.add(image_actor_y)

    show_m = window.ShowManager(scene, size=(1200, 900))
    show_m.initialize()

    # Sliders to move the slices and change their opacity.
    line_slider_z = ui.LineSlider2D(min_value=0,
                                    max_value=shape[2] - 1,
                                    initial_value=shape[2] / 2,
                                    text_template="{value:.0f}",
                                    length=140)

    line_slider_x = ui.LineSlider2D(min_value=0,
                                    max_value=shape[0] - 1,
                                    initial_value=shape[0] / 2,
                                    text_template="{value:.0f}",
                                    length=140)

    line_slider_y = ui.LineSlider2D(min_value=0,
                                    max_value=shape[1] - 1,
                                    initial_value=shape[1] / 2,
                                    text_template="{value:.0f}",
                                    length=140)

    opacity_slider = ui.LineSlider2D(min_value=0.0,
                                     max_value=1.0,
                                     initial_value=slicer_opacity,
                                     length=140)

    # Callbacks for the sliders.
    def change_slice_z(slider):
        z = int(np.round(slider.value))
        image_actor_z.display_extent(0, shape[0] - 1, 0, shape[1] - 1, z, z)

    def change_slice_x(slider):
        x = int(np.round(slider.value))
        image_actor_x.display_extent(x, x, 0, shape[1] - 1, 0, shape[2] - 1)

    def change_slice_y(slider):
        y = int(np.round(slider.value))
        image_actor_y.display_extent(0, shape[0] - 1, y, y, 0, shape[2] - 1)

    def change_opacity(slider):
        new_opacity = slider.value
        image_actor_z.opacity(new_opacity)
        image_actor_x.opacity(new_opacity)
        image_actor_y.opacity(new_opacity)

    line_slider_z.on_change = change_slice_z
    line_slider_x.on_change = change_slice_x
    line_slider_y.on_change = change_slice_y
    opacity_slider.on_change = change_opacity

    # Text labels to identify the sliders.
    def build_label(text):
        """Return a white-on-black, 18 pt Arial text label."""
        label = ui.TextBlock2D()
        label.message = text
        label.font_size = 18
        label.font_family = 'Arial'
        label.justification = 'left'
        label.bold = False
        label.italic = False
        label.shadow = False
        label.background = (0, 0, 0)
        label.color = (1, 1, 1)
        return label

    line_slider_label_z = build_label(text="Z Slice")
    line_slider_label_x = build_label(text="X Slice")
    line_slider_label_y = build_label(text="Y Slice")
    opacity_slider_label = build_label(text="Opacity")

    # Panel containing the sliders and labels.
    panel = ui.Panel2D(size=(300, 200),
                       color=(1, 1, 1),
                       opacity=0.1,
                       align="right")
    panel.center = (1030, 120)

    panel.add_element(line_slider_label_x, (0.1, 0.75))
    panel.add_element(line_slider_x, (0.38, 0.75))
    panel.add_element(line_slider_label_y, (0.1, 0.55))
    panel.add_element(line_slider_y, (0.38, 0.55))
    panel.add_element(line_slider_label_z, (0.1, 0.35))
    panel.add_element(line_slider_z, (0.38, 0.35))
    panel.add_element(opacity_slider_label, (0.1, 0.15))
    panel.add_element(opacity_slider, (0.38, 0.15))

    scene.add(panel)

    # Panel showing the value of a picked voxel plus the "Calculate" button.
    label_position = ui.TextBlock2D(text='Position:')
    label_value = ui.TextBlock2D(text='Value:')

    result_position = ui.TextBlock2D(text='')
    result_value = ui.TextBlock2D(text='')

    text2 = ui.TextBlock2D(text='Calculate')

    panel_picking = ui.Panel2D(size=(250, 125),
                               color=(1, 1, 1),
                               opacity=0.1,
                               align="left")
    panel_picking.center = (200, 120)

    panel_picking.add_element(label_position, (0.1, 0.75))
    panel_picking.add_element(label_value, (0.1, 0.45))

    panel_picking.add_element(result_position, (0.45, 0.75))
    panel_picking.add_element(result_value, (0.45, 0.45))

    panel_picking.add_element(text2, (0.1, 0.15))

    icon_files = []
    icon_files.append(('left', read_viz_icons(fname='circle-left.png')))
    button_example = ui.Button2D(icon_fnames=icon_files, size=(100, 30))
    panel_picking.add_element(button_example, (0.5, 0.1))

    def change_text_callback(i_ren, obj, button):
        """Recompute and redraw the TMS effects at the last picked voxel."""
        text2.message = str(i_coord) + ' ' + str(j_coord) + ' ' + str(k_coord)
        torusActor.SetPosition(i_coord, j_coord, k_coord)
        print(i_coord, j_coord, k_coord)
        lut, colors = change_TMS_effects(i_coord, j_coord, k_coord)
        # Swap the streamline actor: remove the current one, add a newly
        # coloured one, then drop the old one from the list.
        scene.rm(actor_line_list[0])
        actor_line_list.append(
            actor.line(bundle_native,
                       colors,
                       linewidth=5,
                       fake_tube=True,
                       lookup_colormap=lut))
        scene.add(actor_line_list[1])

        # A scalar bar exists only after the first stimulation; replace it.
        nonlocal number_of_stimulations
        global bar
        if number_of_stimulations > 0:
            scene.rm(bar)
        else:
            number_of_stimulations = number_of_stimulations + 1

        bar = actor.scalar_bar(lut)
        bar.SetTitle("TMS effect")

        bar.SetHeight(0.3)
        bar.SetWidth(0.10)
        bar.SetPosition(0.85, 0.3)
        scene.add(bar)

        actor_line_list.pop(0)
        i_ren.force_render()

    button_example.on_left_mouse_button_clicked = change_text_callback

    scene.add(panel_picking)
    scene.add(torusActor)

    def left_click_callback(obj, ev):
        """Get the value of the clicked voxel and show it in the panel."""
        event_pos = show_m.iren.GetEventPosition()

        obj.picker.Pick(event_pos[0], event_pos[1], 0, scene)

        global i_coord, j_coord, k_coord
        i_coord, j_coord, k_coord = obj.picker.GetPointIJK()
        print(i_coord, j_coord, k_coord)
        result_position.message = '({}, {}, {})'.format(
            str(i_coord), str(j_coord), str(k_coord))
        result_value.message = '%.8f' % data_T1[i_coord, j_coord, k_coord]
        torusActor.SetPosition(i_coord, j_coord, k_coord)

    image_actor_z.AddObserver('LeftButtonPressEvent', left_click_callback, 1.0)

    # Last known window size, kept in this closure rather than a module
    # global.
    last_size = scene.GetSize()

    def win_callback(obj, event):
        """Shift the right-aligned slider panel horizontally on resize."""
        nonlocal last_size
        new_size = obj.GetSize()
        if last_size != new_size:
            size_change = [new_size[0] - last_size[0], 0]
            last_size = new_size
            panel.re_align(size_change)

    show_m.initialize()

    # Set the following variable to ``True`` to interact with the datasets
    # in 3D.
    interactive = True

    scene.zoom(2.0)
    scene.reset_clipping_range()
    scene.set_camera(position=(-642.07, 495.40, 148.49),
                     focal_point=(127.50, 127.50, 127.50),
                     view_up=(0.02, -0.01, 1.00))

    if interactive:
        show_m.add_window_callback(win_callback)
        show_m.render()
        show_m.start()
    else:
        # Save into the results directory itself. ``out_dir`` is only a
        # per-position prefix, so '<out_dir>/...' named a missing folder.
        window.record(scene,
                      out_path=os.path.join(results_dir,
                                            'bundles_and_effects.png'),
                      size=(1200, 900),
                      reset_camera=True)
示例#7
0
def test_ui_button_panel(recording=False):
    """Record or replay UI events on a panel with a rectangle, button and text.

    Parameters
    ----------
    recording : bool
        If True, record events to ``DATA_DIR/test_ui_button_panel.log.gz``
        and save the observed event counts to the matching ``.pkl`` file.
        If False (default), replay the recorded events and check the counts
        against the saved ones.
    """
    filename = "test_ui_button_panel"
    recording_filename = pjoin(DATA_DIR, filename + ".log.gz")
    expected_events_counts_filename = pjoin(DATA_DIR, filename + ".pkl")

    # Rectangle
    rectangle_test = ui.Rectangle2D(size=(10, 10))
    # Never added to the panel: only used to check that invalid positions
    # are rejected.
    another_rectangle_test = ui.Rectangle2D(size=(1, 1))

    # Button
    fetch_viz_icons()

    # Icons as an ordered list of (name, path) pairs: 'stop', then 'play'.
    icon_files = []
    icon_files.append(('stop', read_viz_icons(fname='stop2.png')))
    icon_files.append(('play', read_viz_icons(fname='play3.png')))

    button_test = ui.Button2D(icon_fnames=icon_files)
    button_test.center = (20, 20)

    def make_invisible(i_ren, obj, button):
        # i_ren: CustomInteractorStyle
        # obj: vtkActor picked
        # button: Button2D
        # Hide the button, redraw, then abort the event (presumably so no
        # other handler receives it).
        button.set_visibility(False)
        i_ren.force_render()
        i_ren.event.abort()

    def modify_button_callback(i_ren, obj, button):
        # i_ren: CustomInteractorStyle
        # obj: vtkActor picked
        # button: Button2D
        # Cycle to the next icon and redraw.
        button.next_icon()
        i_ren.force_render()

    # Right click hides the button; left click cycles its icon.
    button_test.on_right_mouse_button_pressed = make_invisible
    button_test.on_left_mouse_button_pressed = modify_button_callback

    button_test.scale((2, 2))
    # Read-then-write of the same color exercises the getter and setter.
    button_color = button_test.color
    button_test.color = button_color

    # TextBlock
    text_block_test = ui.TextBlock2D()
    text_block_test.message = 'TextBlock'
    text_block_test.color = (0, 0, 0)

    # Panel
    panel = ui.Panel2D(size=(300, 150),
                       position=(290, 15),
                       color=(1, 1, 1),
                       align="right")
    # Integer coordinates and fractional coordinates are both exercised
    # (presumably pixel vs. relative placement -- confirm in Panel2D docs).
    panel.add_element(rectangle_test, (290, 135))
    panel.add_element(button_test, (0.1, 0.1))
    panel.add_element(text_block_test, (0.7, 0.7))
    # These fractional coordinates lie outside [0, 1]; both must raise
    # ValueError.
    npt.assert_raises(ValueError, panel.add_element, another_rectangle_test,
                      (10., 0.5))
    npt.assert_raises(ValueError, panel.add_element, another_rectangle_test,
                      (-0.5, 0.5))

    # Assign the counter callback to every possible event.
    event_counter = EventCounter()
    event_counter.monitor(button_test)
    event_counter.monitor(panel.background)

    current_size = (600, 600)
    show_manager = window.ShowManager(size=current_size, title="DIPY Button")

    show_manager.ren.add(panel)

    if recording:
        # Capture a new event log and store its counts as the expectation.
        show_manager.record_events_to_file(recording_filename)
        print(list(event_counter.events_counts.items()))
        event_counter.save(expected_events_counts_filename)

    else:
        # Replay the stored log and compare counts with the stored ones.
        show_manager.play_events_from_file(recording_filename)
        expected = EventCounter.load(expected_events_counts_filename)
        event_counter.check_counts(expected)
示例#8
0

def modify_button_callback(i_ren, obj, button):
    """Cycle ``button`` to its next icon, then redraw the window.

    Parameters
    ----------
    i_ren : CustomInteractorStyle
        Interactor style used to force a re-render.
    obj : vtkActor
        The picked actor (not used).
    button : Button2D
        The button whose icon is advanced.
    """
    for step in (button.next_icon, i_ren.force_render):
        step()


# Left clicks on the second button cycle its icon via modify_button_callback.
second_button_example.on_left_mouse_button_pressed = modify_button_callback
"""
Panels
======

Simply create a panel and add elements to it.
"""

# A 300x150 white panel, re-centred explicitly after construction.
panel = ui.Panel2D(size=(300, 150), color=(1, 1, 1), align="right")
panel.center = (440, 90)
# The first element is placed with fractional coordinates, the second with
# values that look like pixel offsets inside the 300x150 panel.
# NOTE(review): confirm add_element treats values in [0, 1] as relative and
# larger values as absolute pixels.
panel.add_element(button_example, (0.2, 0.2))
panel.add_element(second_button_example, (190, 85))
"""
TextBox
=======
"""

# NOTE(review): units of height/width (characters vs. pixels) are not shown
# here -- confirm against the TextBox2D documentation.
text = ui.TextBox2D(height=3, width=10)
"""
2D Line Slider
==============
"""

def fiber_simple_3d_show_advanced(img,
                                  streamlines,
                                  colors=None,
                                  linewidth=1,
                                  s='png',
                                  imgcolor=False,
                                  interactive=True,
                                  out_path=None):
    """Show streamlines over three slices of ``img`` with slider controls.

    Three orthogonal, semi-transparent slices of the image are drawn together
    with the streamlines. A panel of sliders moves each slice and changes the
    slice opacity.

    Parameters
    ----------
    img : nibabel image
        Background volume; its data array, ``shape`` and ``affine`` are used.
    streamlines : sequence of arrays
        Streamlines to draw (world coordinates by default, see
        ``world_coords`` in the body).
    colors : optional
        Forwarded to ``actor.line``.
    linewidth : float, optional
        Forwarded to ``actor.line``.
    s : str, optional
        Substituted into the default screenshot filename.
    imgcolor : bool, optional
        If True, colour the slices with a hue lookup table instead of the
        slicer's default colormap.
    interactive : bool, optional
        If True (default), start the interactive window loop; otherwise save
        a screenshot with ``window.record``.
    out_path : str, optional
        Screenshot path used when ``interactive`` is False. When None, the
        historical hard-coded path (formatted with ``s``) is used.
    """
    # ``np.asanyarray(img.dataobj)`` is nibabel's documented replacement for
    # the deprecated ``img.get_data()`` and returns the same array.
    data = np.asanyarray(img.dataobj)
    shape = img.shape
    affine = img.affine

    # Objects are shown in world coordinates (RAS, 1mm) by default. Set this
    # to False to show them in native space instead: the streamlines are then
    # mapped back with the inverse of the affine and the slices are drawn
    # with an identity affine.
    world_coords = True

    if not world_coords:
        from dipy.tracking.streamline import transform_streamlines
        streamlines = transform_streamlines(streamlines, np.linalg.inv(affine))

    # A Renderer holding the streamlines (``actor.line``) and the image planes
    # (``actor.slicer``).
    ren = window.Renderer()
    stream_actor = actor.line(streamlines, colors=colors, linewidth=linewidth)

    # Optional colormap for the image planes.
    if imgcolor:
        lut = actor.colormap_lookup_table(scale_range=(0, 1),
                                          hue_range=(0, 1.),
                                          saturation_range=(0., 1.),
                                          value_range=(0., 1.))
    else:
        lut = None
    if not world_coords:
        image_actor_z = actor.slicer(data,
                                     affine=np.eye(4),
                                     lookup_colormap=lut)
    else:
        image_actor_z = actor.slicer(data, affine, lookup_colormap=lut)

    # Make the slices semi-transparent.
    slicer_opacity = 0.6
    image_actor_z.opacity(slicer_opacity)

    # Additional slicers are copies of the original with their own
    # ``display_extent`` (one plane through the middle of each axis).
    image_actor_x = image_actor_z.copy()
    image_actor_x.opacity(slicer_opacity)
    x_midpoint = int(np.round(shape[0] / 2))
    image_actor_x.display_extent(x_midpoint, x_midpoint, 0, shape[1] - 1, 0,
                                 shape[2] - 1)

    image_actor_y = image_actor_z.copy()
    image_actor_y.opacity(slicer_opacity)
    y_midpoint = int(np.round(shape[1] / 2))
    image_actor_y.display_extent(0, shape[0] - 1, y_midpoint, y_midpoint, 0,
                                 shape[2] - 1)

    ren.add(stream_actor)
    ren.add(image_actor_z)
    ren.add(image_actor_x)
    ren.add(image_actor_y)

    # Slider widgets need access to several parts of the visualization
    # pipeline, so they are driven through a ShowManager rather than
    # ``window.show``.
    show_m = window.ShowManager(ren, size=(1200, 900))
    show_m.initialize()

    # Sliders that move the slices and change their opacity.
    line_slider_z = ui.LineSlider2D(min_value=0,
                                    max_value=shape[2] - 1,
                                    initial_value=shape[2] / 2,
                                    text_template="{value:.0f}",
                                    length=140)

    line_slider_x = ui.LineSlider2D(min_value=0,
                                    max_value=shape[0] - 1,
                                    initial_value=shape[0] / 2,
                                    text_template="{value:.0f}",
                                    length=140)

    line_slider_y = ui.LineSlider2D(min_value=0,
                                    max_value=shape[1] - 1,
                                    initial_value=shape[1] / 2,
                                    text_template="{value:.0f}",
                                    length=140)

    opacity_slider = ui.LineSlider2D(min_value=0.0,
                                     max_value=1.0,
                                     initial_value=slicer_opacity,
                                     length=140)

    def change_slice_z(i_ren, obj, slider):
        z = int(np.round(slider.value))
        image_actor_z.display_extent(0, shape[0] - 1, 0, shape[1] - 1, z, z)

    def change_slice_x(i_ren, obj, slider):
        x = int(np.round(slider.value))
        image_actor_x.display_extent(x, x, 0, shape[1] - 1, 0, shape[2] - 1)

    def change_slice_y(i_ren, obj, slider):
        y = int(np.round(slider.value))
        image_actor_y.display_extent(0, shape[0] - 1, y, y, 0, shape[2] - 1)

    def change_opacity(i_ren, obj, slider):
        new_opacity = slider.value
        image_actor_z.opacity(new_opacity)
        image_actor_x.opacity(new_opacity)
        image_actor_y.opacity(new_opacity)

    line_slider_z.add_callback(line_slider_z.slider_disk, "MouseMoveEvent",
                               change_slice_z)
    line_slider_x.add_callback(line_slider_x.slider_disk, "MouseMoveEvent",
                               change_slice_x)
    line_slider_y.add_callback(line_slider_y.slider_disk, "MouseMoveEvent",
                               change_slice_y)
    opacity_slider.add_callback(opacity_slider.slider_disk, "MouseMoveEvent",
                                change_opacity)

    def build_label(text):
        """Return a white, 18pt Arial text label showing ``text``."""
        label = ui.TextBlock2D()
        label.message = text
        label.font_size = 18
        label.font_family = 'Arial'
        label.justification = 'left'
        label.bold = False
        label.italic = False
        label.shadow = False
        label.color = (1, 1, 1)

        return label

    line_slider_label_z = build_label(text="Z Slice")
    line_slider_label_x = build_label(text="X Slice")
    line_slider_label_y = build_label(text="Y Slice")
    opacity_slider_label = build_label(text="Opacity")

    # A panel that holds the sliders and their labels.
    panel = ui.Panel2D(center=(1030, 120),
                       size=(300, 200),
                       color=(1, 1, 1),
                       opacity=0.1,
                       align="right")

    panel.add_element(line_slider_label_x, 'relative', (0.1, 0.75))
    panel.add_element(line_slider_x, 'relative', (0.65, 0.8))
    panel.add_element(line_slider_label_y, 'relative', (0.1, 0.55))
    panel.add_element(line_slider_y, 'relative', (0.65, 0.6))
    panel.add_element(line_slider_label_z, 'relative', (0.1, 0.35))
    panel.add_element(line_slider_z, 'relative', (0.65, 0.4))
    panel.add_element(opacity_slider_label, 'relative', (0.1, 0.15))
    panel.add_element(opacity_slider, 'relative', (0.65, 0.2))

    show_m.ren.add(panel)

    # The panel does not follow window resizes by itself, so it is shifted
    # with ``re_align`` whenever the window width changes. The last known
    # size is kept in this function's scope (not a module global) so other
    # viewers in the same module cannot clobber it.
    size = ren.GetSize()

    def win_callback(obj, event):
        nonlocal size
        if size != obj.GetSize():
            size_old = size
            size = obj.GetSize()
            size_change = [size[0] - size_old[0], 0]
            panel.re_align(size_change)

    show_m.initialize()

    ren.zoom(1.5)
    ren.reset_clipping_range()

    if interactive:
        show_m.add_window_callback(win_callback)
        show_m.render()
        show_m.start()
    else:
        if out_path is None:
            out_path = (
                '/home/brain/workingdir/data/dwi/hcp/preprocessed/'
                'response_dhollander/100408/result/result20vs45/'
                'cc_clustering_png1/100408lr15_%s.png' % s)
        window.record(ren,
                      out_path=out_path,
                      size=(1200, 900),
                      reset_camera=False)

    del show_m
    """
示例#10
0
# Window manager for the picking demo; ``renderer`` comes from code not shown
# here.
show_m = window.ShowManager(renderer, size=(1200, 900))
show_m.initialize()
"""
We'll start by creating the panel and adding it to the ``ShowManager``
"""

# Static captions.
label_position = ui.TextBlock2D(text='Position:')
label_value = ui.TextBlock2D(text='Value:')

# Empty read-outs; presumably filled in by a picking callback added later in
# the example -- NOTE(review): confirm.
result_position = ui.TextBlock2D(text='')
result_value = ui.TextBlock2D(text='')

# Semi-transparent black panel anchored on the left.
panel_picking = ui.Panel2D(center=(200, 120),
                           size=(250, 125),
                           color=(0, 0, 0),
                           opacity=0.75,
                           align="left")

# 'relative' placement: coordinates are presumably fractions of the panel
# size (captions on the left column, read-outs at 45% of the width).
panel_picking.add_element(label_position, 'relative', (0.1, 0.55))
panel_picking.add_element(label_value, 'relative', (0.1, 0.25))

panel_picking.add_element(result_position, 'relative', (0.45, 0.55))
panel_picking.add_element(result_value, 'relative', (0.45, 0.25))

show_m.ren.add(panel_picking)
"""
Add a left-click callback to the slicer. Also disable interpolation so you can
see what you are picking.
"""
示例#11
0
文件: viz.py 项目: akeshavan/pyAFQ
def visualize_volume(volume,
                     x=None,
                     y=None,
                     z=None,
                     ren=None,
                     inline=True,
                     interact=False):
    """
    Visualize a volume

    Show three orthogonal, semi-transparent slices of ``volume`` together
    with a panel of sliders that move each slice and change the opacity.

    Parameters
    ----------
    volume : ndarray
        3D volume to display.
    x, y, z : int, optional
        Initial slice index along axis 0 (x), 1 (y) and 2 (z). Values are
        clipped into the valid index range. When None, x and y start at the
        rounded midpoint of their axis and the z plane keeps the slicer's
        default position.
    ren : Renderer, optional
        Renderer to draw into; a new one is created when None.
    inline : bool, optional
        Forwarded to ``_inline_interact``.
    interact : bool, optional
        If True, open the interactive window with a resize callback.

    Returns
    -------
    The result of ``_inline_interact(ren, inline, interact)``.
    """
    if ren is None:
        ren = window.Renderer()

    shape = volume.shape

    def _initial(index, axis_len):
        # Starting slider value: the requested index clipped into the axis,
        # or the axis midpoint (the historical default) when not given.
        if index is None:
            return axis_len / 2
        return float(np.clip(index, 0, axis_len - 1))

    x_init = _initial(x, shape[0])
    y_init = _initial(y, shape[1])
    z_init = _initial(z, shape[2])

    image_actor_z = actor.slicer(volume)
    slicer_opacity = 0.6
    image_actor_z.opacity(slicer_opacity)

    image_actor_x = image_actor_z.copy()
    x_slice = int(np.round(x_init))
    image_actor_x.display_extent(x_slice, x_slice, 0, shape[1] - 1, 0,
                                 shape[2] - 1)

    image_actor_y = image_actor_z.copy()
    y_slice = int(np.round(y_init))
    image_actor_y.display_extent(0, shape[0] - 1, y_slice, y_slice, 0,
                                 shape[2] - 1)

    # Only move the z plane when explicitly requested, so the default view is
    # unchanged.
    if z is not None:
        z_slice = int(np.round(z_init))
        image_actor_z.display_extent(0, shape[0] - 1, 0, shape[1] - 1,
                                     z_slice, z_slice)

    ren.add(image_actor_z)
    ren.add(image_actor_x)
    ren.add(image_actor_y)

    show_m = window.ShowManager(ren, size=(1200, 900))
    show_m.initialize()

    line_slider_z = ui.LineSlider2D(min_value=0,
                                    max_value=shape[2] - 1,
                                    initial_value=z_init,
                                    text_template="{value:.0f}",
                                    length=140)

    line_slider_x = ui.LineSlider2D(min_value=0,
                                    max_value=shape[0] - 1,
                                    initial_value=x_init,
                                    text_template="{value:.0f}",
                                    length=140)

    line_slider_y = ui.LineSlider2D(min_value=0,
                                    max_value=shape[1] - 1,
                                    initial_value=y_init,
                                    text_template="{value:.0f}",
                                    length=140)

    opacity_slider = ui.LineSlider2D(min_value=0.0,
                                     max_value=1.0,
                                     initial_value=slicer_opacity,
                                     length=140)

    def change_slice_z(slider):
        z_new = int(np.round(slider.value))
        image_actor_z.display_extent(0, shape[0] - 1, 0, shape[1] - 1,
                                     z_new, z_new)

    def change_slice_x(slider):
        x_new = int(np.round(slider.value))
        image_actor_x.display_extent(x_new, x_new, 0, shape[1] - 1, 0,
                                     shape[2] - 1)

    def change_slice_y(slider):
        y_new = int(np.round(slider.value))
        image_actor_y.display_extent(0, shape[0] - 1, y_new, y_new, 0,
                                     shape[2] - 1)

    def change_opacity(slider):
        new_opacity = slider.value
        image_actor_z.opacity(new_opacity)
        image_actor_x.opacity(new_opacity)
        image_actor_y.opacity(new_opacity)

    line_slider_z.on_change = change_slice_z
    line_slider_x.on_change = change_slice_x
    line_slider_y.on_change = change_slice_y
    opacity_slider.on_change = change_opacity

    def build_label(text):
        """Return a white, 18pt Arial label on a black background."""
        label = ui.TextBlock2D()
        label.message = text
        label.font_size = 18
        label.font_family = 'Arial'
        label.justification = 'left'
        label.bold = False
        label.italic = False
        label.shadow = False
        label.background = (0, 0, 0)
        label.color = (1, 1, 1)

        return label

    line_slider_label_z = build_label(text="Z Slice")
    line_slider_label_x = build_label(text="X Slice")
    line_slider_label_y = build_label(text="Y Slice")
    opacity_slider_label = build_label(text="Opacity")

    panel = ui.Panel2D(size=(300, 200),
                       color=(1, 1, 1),
                       opacity=0.1,
                       align="right")
    panel.center = (1030, 120)

    panel.add_element(line_slider_label_x, (0.1, 0.75))
    panel.add_element(line_slider_x, (0.38, 0.75))
    panel.add_element(line_slider_label_y, (0.1, 0.55))
    panel.add_element(line_slider_y, (0.38, 0.55))
    panel.add_element(line_slider_label_z, (0.1, 0.35))
    panel.add_element(line_slider_z, (0.38, 0.35))
    panel.add_element(opacity_slider_label, (0.1, 0.15))
    panel.add_element(opacity_slider, (0.38, 0.15))

    show_m.ren.add(panel)

    # Track the window size in this function's scope (not a module global) so
    # that other viewers in the module cannot interfere with the panel's
    # re-alignment on resize.
    size = ren.GetSize()

    def win_callback(obj, event):
        nonlocal size
        if size != obj.GetSize():
            size_old = size
            size = obj.GetSize()
            size_change = [size[0] - size_old[0], 0]
            panel.re_align(size_change)

    show_m.initialize()

    ren.zoom(1.5)
    ren.reset_clipping_range()

    if interact:
        show_m.add_window_callback(win_callback)
        show_m.render()
        show_m.start()

    return _inline_interact(ren, inline, interact)
示例#12
0
def show_odfs_and_fa(fa,
                     pam,
                     mask,
                     affine,
                     sphere,
                     ftmp='odf.mmap',
                     basis_type=None,
                     norm_odfs=True,
                     scale_odfs=0.5,
                     sh_order=8):
    """Interactively browse ODF glyphs drawn on top of an FA slice.

    ODFs are evaluated from ``pam.shm_coeff`` on ``sphere`` into a temporary
    memory-mapped file and rendered with ``actor.odf_slicer`` over an
    ``actor.slicer`` of ``fa``. A slider selects the slice along the third
    axis; left-clicking the FA slicer prints the picked voxel index. The
    temporary file is removed when the window closes or if an error occurs.

    Parameters
    ----------
    fa : ndarray
        3D scalar map shown as the background slice.
    pam : object
        Must expose ``shm_coeff``; its last axis must match the number of
        basis functions for ``sh_order``.
    mask : ndarray
        Forwarded to ``actor.odf_slicer``.
    affine : ndarray
        Currently unused.
    sphere : Sphere
        Sphere on which the ODFs are sampled.
    ftmp : str, optional
        Path of the temporary memmap file.
    basis_type : str or None, optional
        Key looked up in ``sph_harm_lookup``.
    norm_odfs : bool, optional
        Forwarded to ``actor.odf_slicer`` as ``norm``.
    scale_odfs : float, optional
        Forwarded to ``actor.odf_slicer`` as ``scale``.
    sh_order : int, optional
        Spherical-harmonic order used to build the basis (default 8).

    Raises
    ------
    ValueError
        If ``basis_type`` is not a known basis name.
    """
    # Resolve the basis first so an invalid name fails before the memmap file
    # is created on disk.
    sph_harm_basis = sph_harm_lookup.get(basis_type)
    if sph_harm_basis is None:
        raise ValueError("Invalid basis name.")

    # Slider range; previously read from an undefined ``shape`` name.
    shape = fa.shape

    renderer = window.Renderer()
    renderer.background((1, 1, 1))

    slice_actor = actor.slicer(fa)

    odf_shape = fa.shape + (sphere.vertices.shape[0], )
    odfs = np.memmap(ftmp, dtype=np.float32, mode='w+', shape=odf_shape)

    try:
        B, _, _ = sph_harm_basis(sh_order, sphere.theta, sphere.phi)

        odfs[:] = np.dot(pam.shm_coeff.astype('f4'), B.T.astype('f4'))

        odf_slicer = actor.odf_slicer(odfs,
                                      mask=mask,
                                      sphere=sphere,
                                      scale=scale_odfs,
                                      norm=norm_odfs,
                                      colormap='magma')

        renderer.add(odf_slicer)
        renderer.add(slice_actor)

        show_m = window.ShowManager(renderer, size=(2000, 1000))
        show_m.initialize()

        # Read-outs updated by the click callback below.
        # NOTE(review): they are not added to ``panel_picking``, so the
        # updated text is never displayed; add them to the panel if wanted.
        result_position = ui.TextBlock2D(text='')
        result_value = ui.TextBlock2D(text='')

        line_slider_z = ui.LineSlider2D(min_value=0,
                                        max_value=shape[2] - 1,
                                        initial_value=shape[2] / 2,
                                        text_template="{value:.0f}",
                                        length=140)

        def change_slice_z(i_ren, obj, slider):
            z = int(np.round(slider.value))
            slice_actor.display(z=z)
            odf_slicer.display(z=z)
            show_m.render()

        line_slider_z.add_callback(line_slider_z.slider_disk,
                                   "LeftButtonReleaseEvent", change_slice_z)

        panel_picking = ui.Panel2D(center=(200, 120),
                                   size=(250, 225),
                                   color=(0, 0, 0),
                                   opacity=0.75,
                                   align="left")

        panel_picking.add_element(line_slider_z, 'relative', (0.5, 0.9))

        show_m.ren.add(panel_picking)

        def left_click_callback(obj, ev):
            """Get the value of the clicked voxel and show it in the panel."""
            event_pos = show_m.iren.GetEventPosition()

            obj.picker.Pick(event_pos[0], event_pos[1], 0, show_m.ren)

            i, j, k = obj.picker.GetPointIJK()
            print(i, j, k)
            result_position.message = '({}, {}, {})'.format(
                str(i), str(j), str(k))
            result_value.message = '%.3f' % fa[i, j, k]

        slice_actor.SetInterpolate(True)
        slice_actor.AddObserver('LeftButtonPressEvent', left_click_callback,
                                1.0)

        show_m.start()
    finally:
        # Close the mapping explicitly before removing its backing file.
        odfs._mmap.close()
        del odfs
        os.remove(ftmp)
示例#13
0
def slider(image_actor, line_actor, shape=None):
    """Show an image slicer and a line actor with slice/opacity sliders.

    Parameters
    ----------
    image_actor : slicer actor
        Actor returned by ``actor.slicer``. The slice slider moves it along
        the third axis and the opacity slider fades it.
    line_actor : vtkActor
        Added to the scene unchanged.
    shape : sequence of int, optional
        Dimensions of the sliced volume, used for the slice slider range.
        Defaults to ``image_actor.shape`` when that attribute exists, else to
        the dimensions of the actor's input image data.
    """
    if shape is None:
        shape = getattr(image_actor, 'shape', None)
    if shape is None:
        shape = image_actor.GetInput().GetDimensions()

    slicer_opacity = 0.6
    # NOTE(review): the actor's opacity is not set to ``slicer_opacity``, so
    # the opacity slider's starting value may not match what is displayed.
    ren = window.Renderer()
    ren.add(image_actor)
    ren.add(line_actor)

    show_m = window.ShowManager(ren, size=(1200, 900))
    show_m.initialize()

    line_slider_z = ui.LineSlider2D(min_value=0,
                                    max_value=shape[2] - 1,
                                    initial_value=shape[2] / 2,
                                    text_template="{value:.0f}")
    opacity_slider = ui.LineSlider2D(min_value=0.0,
                                     max_value=1.0,
                                     initial_value=slicer_opacity)

    def change_slice_z(i_ren, obj, slider):
        z = int(np.round(slider.value))
        image_actor.display(None, None, z)

    def change_opacity(i_ren, obj, slider):
        image_actor.opacity(slider.value)

    line_slider_z.add_callback(line_slider_z.slider_disk, "MouseMoveEvent",
                               change_slice_z)

    opacity_slider.add_callback(opacity_slider.slider_disk, "MouseMoveEvent",
                                change_opacity)

    line_slider_label_z = ui.TextBox2D(text="Slice", width=50, height=20)

    opacity_slider_label = ui.TextBox2D(text="Opacity", width=50, height=20)

    panel = ui.Panel2D(center=(1030, 120),
                       size=(300, 200),
                       color=(1, 1, 1),
                       opacity=0.1,
                       align="right")

    panel.add_element(line_slider_label_z, 'relative', (0.1, 0.4))
    panel.add_element(line_slider_z, 'relative', (0.5, 0.4))
    panel.add_element(opacity_slider_label, 'relative', (0.1, 0.2))
    panel.add_element(opacity_slider, 'relative', (0.5, 0.2))

    ren.add(panel)

    # Keep the last window size in this function's scope (not a module
    # global) and shift the panel horizontally when the width changes.
    size = ren.GetSize()

    def win_callback(obj, event):
        nonlocal size
        if size != obj.GetSize():
            size_old = size
            size = obj.GetSize()
            size_change = [size[0] - size_old[0], 0]
            panel.re_align(size_change)

    show_m.initialize()
    show_m.add_window_callback(win_callback)
    show_m.render()
    show_m.start()
示例#14
0
文件: app.py 项目: gongyuchen/dipy
    def build_show(self, scene):
        """Create the ShowManager, UI panels and callbacks, then show or record.

        When clustering is enabled and tractograms are loaded, a panel with
        threshold/length/size sliders and a help panel are added. When images
        are loaded, a slicer panel is added. In interactive mode the window is
        started, or recorded events are replayed when ``self.recorded_events``
        is set; otherwise a screenshot is written to ``self.out_png``.

        Parameters
        ----------
        scene : Scene
            Scene that already holds the actors to display.
        """
        self.show_m = window.ShowManager(scene,
                                         size=(1200, 900),
                                         order_transparent=True,
                                         reset_camera=False)
        self.show_m.initialize()

        # Only created together with ``self.panel2`` below; the resize
        # callback checks it so it never touches panels that do not exist.
        help_panel = None

        if self.cluster and self.tractograms:

            lengths = np.array([self.cla[c]['length'] for c in self.cla])
            szs = [self.cla[c]['size'] for c in self.cla]
            sizes = np.array(szs)

            self.panel2 = ui.Panel2D(size=(400, 400),
                                     position=(850, 520),
                                     color=(1, 1, 1),
                                     opacity=0.1,
                                     align="right")

            slider_label_threshold = build_label(text="Threshold")
            slider_threshold = ui.LineSlider2D(min_value=5,
                                               max_value=25,
                                               initial_value=15,
                                               text_template="{value:.0f}",
                                               length=140,
                                               shape='square')
            _color_slider(slider_threshold)

            slider_label_length = build_label(text="Length")
            slider_length = ui.LineSlider2D(
                min_value=lengths.min(),
                max_value=np.percentile(lengths, 98),
                initial_value=np.percentile(lengths, 25),
                text_template="{value:.0f}",
                length=140)
            _color_slider(slider_length)

            slider_label_size = build_label(text="Size")
            slider_size = ui.LineSlider2D(min_value=sizes.min(),
                                          max_value=np.percentile(sizes, 98),
                                          initial_value=np.percentile(
                                              sizes, 50),
                                          text_template="{value:.0f}",
                                          length=140)
            _color_slider(slider_size)

            # Start with filters that let every cluster through.
            self.size_min = sizes.min()
            self.length_min = lengths.min()

            def change_threshold(istyle, obj, slider):
                # Re-cluster with the new threshold (applied on release).
                sv = np.round(slider.value, 0)
                self.remove_actors(scene)
                self.add_actors(scene, self.tractograms, threshold=sv)

            slider_threshold.handle_events(slider_threshold.handle.actor)
            slider_threshold.on_left_mouse_button_released = change_threshold

            def apply_cluster_filter():
                """Hide clusters below either minimum; show the others."""
                for k in self.cla:
                    if (self.cla[k]['length'] < self.length_min
                            or self.cla[k]['size'] < self.size_min):
                        self.cla[k]['centroid_actor'].SetVisibility(0)
                        if k.GetVisibility() == 1:
                            k.SetVisibility(0)
                    else:
                        self.cla[k]['centroid_actor'].SetVisibility(1)
                self.show_m.render()

            def hide_clusters_length(slider):
                self.length_min = np.round(slider.value)
                apply_cluster_filter()

            def hide_clusters_size(slider):
                self.size_min = np.round(slider.value)
                apply_cluster_filter()

            slider_length.on_change = hide_clusters_length

            self.panel2.add_element(slider_label_threshold,
                                    coords=(0.1, 0.133))
            self.panel2.add_element(slider_threshold, coords=(0.4, 0.133))

            self.panel2.add_element(slider_label_length, coords=(0.1, 0.333))
            self.panel2.add_element(slider_length, coords=(0.4, 0.333))

            slider_size.on_change = hide_clusters_size

            self.panel2.add_element(slider_label_size, coords=(0.1, 0.6666))
            self.panel2.add_element(slider_size, coords=(0.4, 0.6666))

            scene.add(self.panel2)

            text_block = build_label(HELP_MESSAGE, 16)
            text_block.message = HELP_MESSAGE

            help_panel = ui.Panel2D(size=(300, 200),
                                    color=(1, 1, 1),
                                    opacity=0.1,
                                    align="left")

            help_panel.add_element(text_block, coords=(0.05, 0.1))
            scene.add(help_panel)

        if len(self.images) > 0:
            # Only the first image is loaded for now.
            data, affine = self.images[0]
            self.vox2ras = affine

            if len(self.pams) > 0:
                pam = self.pams[0]
            else:
                pam = None
            self.panel = slicer_panel(scene,
                                      self.show_m.iren,
                                      data,
                                      affine,
                                      self.world_coords,
                                      pam=pam)
        else:
            data = None
            affine = None
            pam = None

        self.win_size = scene.GetSize()

        def win_callback(obj, event):
            # Shift the panels horizontally when the window width changes.
            if self.win_size != obj.GetSize():
                size_old = self.win_size
                self.win_size = obj.GetSize()
                size_change = [self.win_size[0] - size_old[0], 0]
                if data is not None:
                    self.panel.re_align(size_change)
                if help_panel is not None:
                    self.panel2.re_align(size_change)
                    help_panel.re_align(size_change)

        self.show_m.initialize()

        def left_click_centroid_callback(obj, event):
            # Toggle selection of a centroid and mirror it on its cluster.
            self.cea[obj]['selected'] = not self.cea[obj]['selected']
            self.cla[self.cea[obj]['cluster_actor']]['selected'] = \
                self.cea[obj]['selected']
            self.show_m.render()

        def left_click_cluster_callback(obj, event):
            # Collapse a selected, expanded cluster back to its centroid.
            if self.cla[obj]['selected']:
                self.cla[obj]['centroid_actor'].VisibilityOn()
                ca = self.cla[obj]['centroid_actor']
                self.cea[ca]['selected'] = 0
                obj.VisibilityOff()
                self.cea[ca]['expanded'] = 0

            self.show_m.render()

        for cl in self.cla:
            cl.AddObserver('LeftButtonPressEvent', left_click_cluster_callback,
                           1.0)
            self.cla[cl]['centroid_actor'].AddObserver(
                'LeftButtonPressEvent', left_click_centroid_callback, 1.0)

        self.hide_centroids = True
        self.select_all = False

        def passes_filters(entry):
            # True when a centroid entry meets both the length and size minima.
            return (entry['length'] >= self.length_min
                    and entry['size'] >= self.size_min)

        def visible_streamlines():
            """Collect the streamlines of every currently visible cluster."""
            collected = Streamlines()
            for bundle in self.cla:
                if bundle.GetVisibility():
                    t = self.cla[bundle]['tractogram']
                    c = self.cla[bundle]['cluster']
                    indices = self.tractogram_clusters[t][c]
                    collected.extend(Streamlines(indices))
            return collected

        def key_press(obj, event):
            key = obj.GetKeySym()
            if self.cluster:

                # hide on/off unselected centroids
                if key in ('h', 'H'):
                    if self.hide_centroids:
                        for ca in self.cea:
                            if (self.cea[ca]['length'] >= self.length_min
                                    or self.cea[ca]['size'] >= self.size_min):
                                if self.cea[ca]['selected'] == 0:
                                    ca.VisibilityOff()
                    else:
                        for ca in self.cea:
                            if passes_filters(self.cea[ca]):
                                if self.cea[ca]['selected'] == 0:
                                    ca.VisibilityOn()
                    self.hide_centroids = not self.hide_centroids
                    self.show_m.render()

                # invert selection
                if key in ('i', 'I'):
                    for ca in self.cea:
                        if passes_filters(self.cea[ca]):
                            self.cea[ca]['selected'] = \
                                not self.cea[ca]['selected']
                            cas = self.cea[ca]['cluster_actor']
                            self.cla[cas]['selected'] = \
                                self.cea[ca]['selected']
                    self.show_m.render()

                # save the streamlines of the visible clusters
                if key in ('s', 'S'):
                    saving_streamlines = visible_streamlines()
                    print('Saving result in tmp.trk')
                    sft = StatefulTractogram(saving_streamlines, 'same',
                                             Space.RASMM)
                    save_tractogram(sft, 'tmp.trk', bbox_valid_check=False)

                # open a new Horizon window on the visible clusters
                if key in ('y', 'Y'):
                    active_streamlines = visible_streamlines()
                    hz2 = Horizon([active_streamlines],
                                  self.images,
                                  cluster=True,
                                  cluster_thr=5,
                                  random_colors=self.random_colors,
                                  length_lt=np.inf,
                                  length_gt=0,
                                  clusters_lt=np.inf,
                                  clusters_gt=0,
                                  world_coords=True,
                                  interactive=True)
                    ren2 = hz2.build_scene()
                    hz2.build_show(ren2)

                # select / deselect all filtered centroids
                if key in ('a', 'A'):
                    selecting = self.select_all is False
                    new_state = 1 if selecting else 0
                    for ca in self.cea:
                        if passes_filters(self.cea[ca]):
                            self.cea[ca]['selected'] = new_state
                            cas = self.cea[ca]['cluster_actor']
                            self.cla[cas]['selected'] = \
                                self.cea[ca]['selected']
                    self.show_m.render()
                    self.select_all = selecting

                # expand selected centroids into their clusters
                if key in ('e', 'E'):
                    for c in self.cea:
                        if self.cea[c]['selected']:
                            if not self.cea[c]['expanded']:
                                if passes_filters(self.cea[c]):
                                    self.cea[c]['cluster_actor'].VisibilityOn()
                                    c.VisibilityOff()
                                    self.cea[c]['expanded'] = 1

                    self.show_m.render()

                # collapse every filtered cluster back to its centroid
                if key in ('r', 'R'):
                    for c in self.cea:
                        if passes_filters(self.cea[c]):
                            self.cea[c]['cluster_actor'].VisibilityOff()
                            c.VisibilityOn()
                            self.cea[c]['expanded'] = 0

                self.show_m.render()

        HORIMEM.window_timer_cnt = 0

        def timer_callback(obj, event):

            HORIMEM.window_timer_cnt += 1
            # TODO possibly add automatic rotation option
            # cnt = HORIMEM.window_timer_cnt
            # show_m.scene.azimuth(0.05 * cnt)
            # show_m.render()

        scene.reset_camera()
        scene.zoom(1.5)
        scene.reset_clipping_range()

        if self.interactive:

            self.show_m.add_window_callback(win_callback)
            self.show_m.add_timer_callback(True, 200, timer_callback)
            self.show_m.iren.AddObserver('KeyPressEvent', key_press)

            if self.recorded_events is None:
                self.show_m.render()
                self.show_m.start()
            else:
                # set to True if event recording needs updating
                recording = False
                recording_filename = self.recorded_events

                if recording:
                    self.show_m.record_events_to_file(recording_filename)
                else:
                    self.show_m.play_events_from_file(recording_filename)

        else:

            window.record(scene,
                          out_path=self.out_png,
                          size=(1200, 900),
                          reset_camera=False)
示例#15
0
def slicer_panel(scene,
                 iren,
                 data=None,
                 affine=None,
                 world_coords=False,
                 pam=None,
                 mask=None,
                 mem=None):
    """ Slicer panel with slicer included

    Adds three orthogonal slice actors (and, optionally, a peak slicer) to
    ``scene`` and returns a ``Panel2D`` holding sliders for the slice
    positions, the opacity, the colormap range and, for a stack of volumes,
    the displayed volume index.

    Parameters
    ----------
    scene : Scene
        Scene that receives the slice actors, the optional peak actor and
        the panel.
    iren : Interactor
        Render window interactor, used for picking and for re-rendering
        from the label callbacks.
    data : ndarray
        Required. A 3D volume, a 4D stack of volumes (last axis longer than
        3) or a 4D RGB volume (last axis of length 3).
    affine : 4x4 ndarray, optional
        Only used when ``world_coords`` is True.
    world_coords : bool
        If True then the affine is applied, otherwise the identity is used.
    pam : PeaksAndMetrics, optional
        When given, its ``peak_dirs`` are shown with a peak slicer.
    mask : ndarray, optional
        Forwarded to the peak slicer.
    mem : GlobalHorizon, optional
        State shared with the caller. A fresh ``GlobalHorizon`` is created
        when omitted.

    Returns
    -------
    panel : Panel2D

    Raises
    ------
    ValueError
        If ``data`` is missing or has an unsupported shape.
    """
    if mem is None:
        # Creating the default here (not in the signature) avoids sharing
        # one GlobalHorizon, and its slicer_* state, between calls.
        mem = GlobalHorizon()
    if data is None:
        raise ValueError("slicer_panel requires a data array")

    orig_shape = data.shape
    print('Original shape', orig_shape)
    ndim = data.ndim
    # Decide from the last axis of the *data* (not from a truncated shape)
    # whether a 4D array is a stack of volumes or an RGB volume.
    has_volumes = ndim == 4 and data.shape[-1] > 3
    is_rgb = ndim == 4 and data.shape[-1] == 3
    if not (ndim == 3 or has_volumes or is_rgb):
        raise ValueError(
            'Unsupported data shape {}: expected a 3D volume or a 4D array '
            'whose last axis has 3 (RGB) or more than 3 (volumes) '
            'entries.'.format(orig_shape))
    mem.slicer_rgb = is_rgb

    tmp = data
    first_vol = 0
    if has_volumes:
        # Sometimes the first volume is null, so try the following volumes
        # until one with some contrast is found.
        for i in range(data.shape[-1]):
            first_vol = i
            tmp = data[..., i]
            value_range = np.percentile(tmp, q=[2, 98])
            if np.sum(np.diff(value_range)) != 0:
                break
    elif is_rgb:
        value_range = (0, 1.)
    else:
        value_range = np.percentile(tmp, q=[2, 98])

    if np.sum(np.diff(value_range)) == 0:
        msg = "Your data does not have any contrast. "
        msg += "Please, check the value range of your data."
        warnings.warn(msg)

    if not world_coords:
        affine = np.eye(4)

    image_actor_z = actor.slicer(tmp,
                                 affine=affine,
                                 value_range=value_range,
                                 interpolation='nearest',
                                 picking_tol=0.025)

    tmp_new = image_actor_z.resliced_array()

    if has_volumes:
        print('Resized to RAS shape ', tmp_new.shape + (data.shape[-1], ))
    else:
        print('Resized to RAS shape ', tmp_new.shape)

    shape = tmp_new.shape

    if pam is not None:

        peaks_actor_z = actor.peak_slicer(pam.peak_dirs,
                                          None,
                                          mask=mask,
                                          affine=affine,
                                          colors=None)

    slicer_opacity = 1.
    image_actor_z.opacity(slicer_opacity)

    image_actor_x = image_actor_z.copy()
    x_midpoint = int(np.round(shape[0] / 2))
    image_actor_x.display_extent(x_midpoint, x_midpoint, 0, shape[1] - 1, 0,
                                 shape[2] - 1)

    image_actor_y = image_actor_z.copy()
    y_midpoint = int(np.round(shape[1] / 2))
    image_actor_y.display_extent(0, shape[0] - 1, y_midpoint, y_midpoint, 0,
                                 shape[2] - 1)

    scene.add(image_actor_z)
    scene.add(image_actor_x)
    scene.add(image_actor_y)

    if pam is not None:
        scene.add(peaks_actor_z)

    line_slider_z = ui.LineSlider2D(min_value=0,
                                    max_value=shape[2] - 1,
                                    initial_value=shape[2] / 2,
                                    text_template="{value:.0f}",
                                    length=140)

    _color_slider(line_slider_z)

    def change_slice_z(slider):
        z = int(np.round(slider.value))
        mem.slicer_curr_actor_z.display_extent(0, shape[0] - 1, 0,
                                               shape[1] - 1, z, z)
        if pam is not None:
            mem.slicer_peaks_actor_z.display_extent(0, shape[0] - 1, 0,
                                                    shape[1] - 1, z, z)
        mem.slicer_curr_z = z
        scene.reset_clipping_range()

    line_slider_x = ui.LineSlider2D(min_value=0,
                                    max_value=shape[0] - 1,
                                    initial_value=shape[0] / 2,
                                    text_template="{value:.0f}",
                                    length=140)

    _color_slider(line_slider_x)

    def change_slice_x(slider):
        x = int(np.round(slider.value))
        mem.slicer_curr_actor_x.display_extent(x, x, 0, shape[1] - 1, 0,
                                               shape[2] - 1)
        scene.reset_clipping_range()
        mem.slicer_curr_x = x
        mem.window_timer_cnt += 100

    line_slider_y = ui.LineSlider2D(min_value=0,
                                    max_value=shape[1] - 1,
                                    initial_value=shape[1] / 2,
                                    text_template="{value:.0f}",
                                    length=140)

    _color_slider(line_slider_y)

    def change_slice_y(slider):
        y = int(np.round(slider.value))

        mem.slicer_curr_actor_y.display_extent(0, shape[0] - 1, y, y, 0,
                                               shape[2] - 1)
        scene.reset_clipping_range()
        mem.slicer_curr_y = y

    # TODO there is some small bug when starting the app the handles
    # are sitting a bit low
    double_slider = ui.LineDoubleSlider2D(length=140,
                                          initial_values=value_range,
                                          min_value=tmp.min(),
                                          max_value=tmp.max(),
                                          shape='square')

    _color_dslider(double_slider)

    def apply_colormap(r1, r2):
        """Apply the current colormap over [r1, r2] to the Z slice actor."""
        if mem.slicer_rgb:
            # RGB data carries its own colors.
            return

        if mem.slicer_colormap == 'disting':
            # use distinguishable colors
            rgb = colormap.distinguishable_colormap(nb_colors=256)
            rgb = np.asarray(rgb)
        else:
            # use matplotlib colormaps
            rgb = colormap.create_colormap(np.linspace(r1, r2, 256),
                                           name=mem.slicer_colormap,
                                           auto=True)
        N = rgb.shape[0]

        lut = colormap.LookupTable()
        lut.SetNumberOfTableValues(N)
        lut.SetRange(r1, r2)
        for i in range(N):
            r, g, b = rgb[i]
            lut.SetTableValue(i, r, g, b)
        lut.SetRampToLinear()
        lut.Build()

        mem.slicer_curr_actor_z.output.SetLookupTable(lut)
        mem.slicer_curr_actor_z.output.Update()

    def on_change_ds(slider):

        values = slider._values
        r1, r2 = values
        apply_colormap(r1, r2)

    # TODO trying to see why there is a small bug in double slider

    opacity_slider = ui.LineSlider2D(min_value=0.0,
                                     max_value=1.0,
                                     initial_value=slicer_opacity,
                                     length=140,
                                     text_template="{ratio:.0%}")

    _color_slider(opacity_slider)

    def change_opacity(slider):

        slicer_opacity = slider.value
        mem.slicer_curr_actor_x.opacity(slicer_opacity)
        mem.slicer_curr_actor_y.opacity(slicer_opacity)
        mem.slicer_curr_actor_z.opacity(slicer_opacity)

    volume_slider = ui.LineSlider2D(min_value=0,
                                    max_value=data.shape[-1] - 1,
                                    initial_value=first_vol,
                                    length=140,
                                    text_template="{value:.0f}",
                                    shape='square')

    _color_slider(volume_slider)

    def change_volume(istyle, obj, slider):
        """Replace the three slice actors with ones showing another volume."""
        vol_idx = int(np.round(slider.value))
        mem.slicer_vol_idx = vol_idx

        scene.rm(mem.slicer_curr_actor_x)
        scene.rm(mem.slicer_curr_actor_y)
        scene.rm(mem.slicer_curr_actor_z)

        new_actor_z = actor.slicer(data[..., vol_idx],
                                   affine=affine,
                                   value_range=value_range,
                                   interpolation='nearest',
                                   picking_tol=0.025)

        mem.slicer_vol = new_actor_z.resliced_array()

        z = mem.slicer_curr_z
        new_actor_z.display_extent(0, shape[0] - 1, 0, shape[1] - 1, z, z)

        mem.slicer_curr_actor_z = new_actor_z
        mem.slicer_curr_actor_x = new_actor_z.copy()

        if pam is not None:
            mem.slicer_peaks_actor_z = peaks_actor_z

        x = mem.slicer_curr_x
        mem.slicer_curr_actor_x.display_extent(x, x, 0, shape[1] - 1, 0,
                                               shape[2] - 1)

        mem.slicer_curr_actor_y = new_actor_z.copy()
        y = mem.slicer_curr_y
        mem.slicer_curr_actor_y.display_extent(0, shape[0] - 1, y, y, 0,
                                               shape[2] - 1)

        # Freshly built actors must follow the opacity slider, otherwise the
        # slices jump back to their default opacity after a volume change.
        current_opacity = opacity_slider.value
        for slice_actor in (mem.slicer_curr_actor_z,
                            mem.slicer_curr_actor_x,
                            mem.slicer_curr_actor_y):
            slice_actor.opacity(current_opacity)

        mem.slicer_curr_actor_z.AddObserver('LeftButtonPressEvent',
                                            left_click_picker_callback, 1.0)
        mem.slicer_curr_actor_x.AddObserver('LeftButtonPressEvent',
                                            left_click_picker_callback, 1.0)
        mem.slicer_curr_actor_y.AddObserver('LeftButtonPressEvent',
                                            left_click_picker_callback, 1.0)
        scene.add(mem.slicer_curr_actor_z)
        scene.add(mem.slicer_curr_actor_x)
        scene.add(mem.slicer_curr_actor_y)

        if pam is not None:
            scene.add(mem.slicer_peaks_actor_z)

        r1, r2 = double_slider._values
        apply_colormap(r1, r2)

        istyle.force_render()

    def left_click_picker_callback(obj, ev):
        ''' Get the value of the clicked voxel and show it in the panel.'''

        event_pos = iren.GetEventPosition()

        obj.picker.Pick(event_pos[0], event_pos[1], 0, scene)

        i, j, k = obj.picker.GetPointIJK()
        res = mem.slicer_vol[i, j, k]
        try:
            message = '%.3f' % res
        except TypeError:
            # Multi-component voxel (e.g. RGB): format each component.
            message = '%.3f %.3f %.3f' % (res[0], res[1], res[2])
        picker_label.message = '({}, {}, {})'.format(str(i), str(j), str(k)) \
            + ' ' + message

    # The volume shown initially is the first one with some contrast.
    mem.slicer_vol_idx = first_vol
    mem.slicer_vol = tmp_new
    mem.slicer_curr_actor_x = image_actor_x
    mem.slicer_curr_actor_y = image_actor_y
    mem.slicer_curr_actor_z = image_actor_z

    if pam is not None:
        mem.slicer_peaks_actor_z = peaks_actor_z

    mem.slicer_curr_actor_x.AddObserver('LeftButtonPressEvent',
                                        left_click_picker_callback, 1.0)
    mem.slicer_curr_actor_y.AddObserver('LeftButtonPressEvent',
                                        left_click_picker_callback, 1.0)
    mem.slicer_curr_actor_z.AddObserver('LeftButtonPressEvent',
                                        left_click_picker_callback, 1.0)

    if pam is not None:
        mem.slicer_peaks_actor_z.AddObserver('LeftButtonPressEvent',
                                             left_click_picker_callback, 1.0)

    mem.slicer_curr_x = int(np.round(shape[0] / 2))
    mem.slicer_curr_y = int(np.round(shape[1] / 2))
    mem.slicer_curr_z = int(np.round(shape[2] / 2))

    line_slider_x.on_change = change_slice_x
    line_slider_y.on_change = change_slice_y
    line_slider_z.on_change = change_slice_z

    double_slider.on_change = on_change_ds

    opacity_slider.on_change = change_opacity

    volume_slider.handle_events(volume_slider.handle.actor)
    volume_slider.on_left_mouse_button_released = change_volume

    line_slider_label_x = build_label(text="X Slice")
    line_slider_label_x.visibility = True
    x_counter = itertools.count()

    def label_callback_x(obj, event):
        line_slider_label_x.visibility = not line_slider_label_x.visibility
        line_slider_x.set_visibility(line_slider_label_x.visibility)
        cnt = next(x_counter)
        if line_slider_label_x.visibility and cnt > 0:
            scene.add(mem.slicer_curr_actor_x)
        else:
            scene.rm(mem.slicer_curr_actor_x)
        iren.Render()

    line_slider_label_x.actor.AddObserver('LeftButtonPressEvent',
                                          label_callback_x, 1.0)

    line_slider_label_y = build_label(text="Y Slice")
    line_slider_label_y.visibility = True
    y_counter = itertools.count()

    def label_callback_y(obj, event):
        line_slider_label_y.visibility = not line_slider_label_y.visibility
        line_slider_y.set_visibility(line_slider_label_y.visibility)
        cnt = next(y_counter)
        if line_slider_label_y.visibility and cnt > 0:
            scene.add(mem.slicer_curr_actor_y)
        else:
            scene.rm(mem.slicer_curr_actor_y)
        iren.Render()

    line_slider_label_y.actor.AddObserver('LeftButtonPressEvent',
                                          label_callback_y, 1.0)

    line_slider_label_z = build_label(text="Z Slice")
    line_slider_label_z.visibility = True
    z_counter = itertools.count()

    def label_callback_z(obj, event):
        line_slider_label_z.visibility = not line_slider_label_z.visibility
        line_slider_z.set_visibility(line_slider_label_z.visibility)
        cnt = next(z_counter)
        if line_slider_label_z.visibility and cnt > 0:
            scene.add(mem.slicer_curr_actor_z)
        else:
            scene.rm(mem.slicer_curr_actor_z)

        iren.Render()

    line_slider_label_z.actor.AddObserver('LeftButtonPressEvent',
                                          label_callback_z, 1.0)

    opacity_slider_label = build_label(text="Opacity")
    volume_slider_label = build_label(text="Volume")
    picker_label = build_label(text='')
    double_slider_label = build_label(text='Colormap')
    slicer_panel_label = build_label(text="Slicer panel", bold=True)

    def label_colormap_callback(obj, event):
        """Cycle through ``mem.slicer_colormaps`` and re-apply the LUT."""

        if mem.slicer_colormap_cnt == len(mem.slicer_colormaps) - 1:
            mem.slicer_colormap_cnt = 0
        else:
            mem.slicer_colormap_cnt += 1

        cnt = mem.slicer_colormap_cnt
        mem.slicer_colormap = mem.slicer_colormaps[cnt]
        double_slider_label.message = mem.slicer_colormap
        values = double_slider._values
        r1, r2 = values
        apply_colormap(r1, r2)
        iren.Render()

    double_slider_label.actor.AddObserver('LeftButtonPressEvent',
                                          label_colormap_callback, 1.0)

    def label_opacity_callback(obj, event):
        """Toggle the slices between fully transparent and fully opaque."""
        if opacity_slider.value == 0:
            # The slider range is [0, 1], so "fully opaque" is 1.0.
            opacity_slider.value = 1.
            opacity_slider.update()
            slicer_opacity = 1
        else:
            opacity_slider.value = 0
            opacity_slider.update()
            slicer_opacity = 0
        mem.slicer_curr_actor_x.opacity(slicer_opacity)
        mem.slicer_curr_actor_y.opacity(slicer_opacity)
        mem.slicer_curr_actor_z.opacity(slicer_opacity)
        iren.Render()

    opacity_slider_label.actor.AddObserver('LeftButtonPressEvent',
                                           label_opacity_callback, 1.0)

    # 4D data gets extra room for the volume row.
    panel_size = (320, 400 + 100) if ndim == 4 else (320, 300 + 100)

    panel = ui.Panel2D(size=panel_size,
                       position=(870, 10),
                       color=(1, 1, 1),
                       opacity=0.1,
                       align="right")

    ys = np.linspace(0, 1, 10)

    panel.add_element(line_slider_z, coords=(0.42, ys[1]))
    panel.add_element(line_slider_y, coords=(0.42, ys[2]))
    panel.add_element(line_slider_x, coords=(0.42, ys[3]))
    panel.add_element(opacity_slider, coords=(0.42, ys[4]))
    panel.add_element(double_slider, coords=(0.42, (ys[7] + ys[8]) / 2.))

    if has_volumes:
        panel.add_element(volume_slider, coords=(0.42, ys[6]))

    panel.add_element(line_slider_label_z, coords=(0.1, ys[1]))
    panel.add_element(line_slider_label_y, coords=(0.1, ys[2]))
    panel.add_element(line_slider_label_x, coords=(0.1, ys[3]))
    panel.add_element(opacity_slider_label, coords=(0.1, ys[4]))
    panel.add_element(double_slider_label, coords=(0.1, (ys[7] + ys[8]) / 2.))

    if has_volumes:
        panel.add_element(volume_slider_label, coords=(0.1, ys[6]))

    panel.add_element(picker_label, coords=(0.2, ys[5]))

    panel.add_element(slicer_panel_label, coords=(0.05, 0.9))

    scene.add(panel)

    # initialize colormap
    r1, r2 = value_range
    apply_colormap(r1, r2)

    return panel
示例#16
0
show_m = window.ShowManager(scene, size=(1200, 900))
show_m.initialize()
"""
We'll start by creating the panel and adding it to the ``ShowManager``
"""

# Static captions on the left, values (initially empty) on the right.
label_position = ui.TextBlock2D(text='Position:')
result_position = ui.TextBlock2D(text='')

label_value = ui.TextBlock2D(text='Value:')
result_value = ui.TextBlock2D(text='')

panel_picking = ui.Panel2D(size=(250, 125),
                           position=(20, 20),
                           color=(0, 0, 0),
                           opacity=0.75,
                           align="left")

panel_picking.add_element(label_position, coords=(0.1, 0.55))
panel_picking.add_element(label_value, coords=(0.1, 0.25))

panel_picking.add_element(result_position, coords=(0.45, 0.55))
panel_picking.add_element(result_value, coords=(0.45, 0.25))

scene.add(panel_picking)
"""
Add a left-click callback to the slicer. Also disable interpolation so you can
see what you are picking.
"""
示例#17
0
def modify_button_callback(i_ren, obj, button):
    """Switch ``button`` to its next icon and force a redraw.

    ``obj`` is not used.
    """
    advance_icon = button.next_icon
    redraw = i_ren.force_render
    advance_icon()
    redraw()


second_button_example.on_left_mouse_button_pressed = modify_button_callback
"""
Panels
======

Simply create a panel and add elements to it.
"""

panel = ui.Panel2D(size=(300, 150),
                   center=(440, 90),
                   align="right",
                   color=(1, 1, 1))
# First button: relative panel coordinates; second: absolute coordinates.
panel.add_element(button_example, 'relative', (0.2, 0.2))
panel.add_element(second_button_example, 'absolute', (480, 100))
"""
TextBox
=======
"""

text = ui.TextBox2D(width=10, height=3)
"""
2D Line Slider
==============
"""

示例#18
0
文件: app.py 项目: theNaavik/dipy
    def build_show(self, scene):
        """Create the ShowManager and the interaction UI, then run it.

        Builds a ``ShowManager`` titled ``'Horizon ' + horizon_version``
        (1200x900) around ``scene``. When clustering is enabled and
        tractograms are loaded, it adds the clustering panel (threshold,
        length and size sliders) and the help panel. When images are loaded,
        it adds the slicer panel for the first image. Keyboard shortcuts and
        a list box, toggled by right-clicking a centroid, expose the cluster
        actions.

        Parameters
        ----------
        scene : Scene
            Scene to show (as returned by ``build_scene``).

        Returns
        -------
        ShowManager or None
            ``self.show_m`` when both ``self.interactive`` and
            ``self.return_showm`` are set, otherwise None.
        """
        title = 'Horizon ' + horizon_version
        self.show_m = window.ShowManager(scene, title=title,
                                         size=(1200, 900),
                                         order_transparent=True,
                                         reset_camera=False)
        self.show_m.initialize()

        # self.panel2 and self.help_panel only exist under this condition;
        # every callback touching them must test it too.
        has_cluster_panels = bool(self.cluster and self.tractograms)

        if has_cluster_panels:

            lengths = np.array(
                [self.cla[c]['length'] for c in self.cla])
            sizes = np.array([self.cla[c]['size'] for c in self.cla])

            self.panel2 = ui.Panel2D(size=(400, 200),
                                     position=(850, 670),
                                     color=(1, 1, 1),
                                     opacity=0.1,
                                     align="right")

            slider_label_threshold = build_label(text="Threshold")
            print("Cluster threshold", self.cluster_thr)
            slider_threshold = ui.LineSlider2D(
                    min_value=5,
                    max_value=25,
                    initial_value=self.cluster_thr,
                    text_template="{value:.0f}",
                    length=140, shape='square')
            _color_slider(slider_threshold)

            slider_label_length = build_label(text="Length")
            slider_length = ui.LineSlider2D(
                    min_value=lengths.min(),
                    max_value=np.percentile(lengths, 98),
                    initial_value=np.percentile(lengths, 25),
                    text_template="{value:.0f}",
                    length=140)
            _color_slider(slider_length)

            slider_label_size = build_label(text="Size")
            slider_size = ui.LineSlider2D(
                    min_value=sizes.min(),
                    max_value=np.percentile(sizes, 98),
                    initial_value=np.percentile(sizes, 50),
                    text_template="{value:.0f}",
                    length=140)
            _color_slider(slider_size)

            self.size_min = sizes.min()
            self.length_min = lengths.min()

            def change_threshold(istyle, obj, slider):
                sv = np.round(slider.value, 0)
                self.remove_cluster_actors(scene)
                self.add_cluster_actors(scene, self.tractograms, threshold=sv)

                # TODO need to double check if this section is still needed
                lengths = np.array(
                    [self.cla[c]['length'] for c in self.cla])
                sizes = np.array([self.cla[c]['size'] for c in self.cla])

                slider_length.min_value = lengths.min()
                slider_length.max_value = lengths.max()
                slider_length.value = lengths.min()
                slider_length.update()

                slider_size.min_value = sizes.min()
                slider_size.max_value = sizes.max()
                slider_size.value = sizes.min()
                slider_size.update()

                self.length_min = min(lengths)
                self.size_min = min(sizes)

                self.show_m.render()

            slider_threshold.handle_events(slider_threshold.handle.actor)
            slider_threshold.on_left_mouse_button_released = change_threshold

            def filter_clusters():
                """Hide clusters below either threshold; show the others'
                centroids."""
                for k in self.cla:
                    if (self.cla[k]['length'] < self.length_min or
                            self.cla[k]['size'] < self.size_min):
                        self.cla[k]['centroid_actor'].SetVisibility(0)
                        if k.GetVisibility() == 1:
                            k.SetVisibility(0)
                    else:
                        self.cla[k]['centroid_actor'].SetVisibility(1)
                self.show_m.render()

            def hide_clusters_length(slider):
                self.length_min = np.round(slider.value)
                filter_clusters()

            def hide_clusters_size(slider):
                self.size_min = np.round(slider.value)
                filter_clusters()

            slider_length.on_change = hide_clusters_length

            # Clustering panel
            self.panel2.add_element(slider_label_threshold, coords=(0.1, 0.26))
            self.panel2.add_element(slider_threshold, coords=(0.4, 0.26))

            self.panel2.add_element(slider_label_length, coords=(0.1, 0.52))
            self.panel2.add_element(slider_length, coords=(0.4, 0.52))

            slider_size.on_change = hide_clusters_size

            self.panel2.add_element(slider_label_size, coords=(0.1, 0.78))
            self.panel2.add_element(slider_size, coords=(0.4, 0.78))

            scene.add(self.panel2)

            # Information panel
            text_block = build_label(HELP_MESSAGE, 18)
            text_block.message = HELP_MESSAGE

            self.help_panel = ui.Panel2D(size=(320, 200),
                                         color=(0.8, 0.8, 1),
                                         opacity=0.2,
                                         align="left")

            self.help_panel.add_element(text_block, coords=(0.05, 0.1))
            scene.add(self.help_panel)

        if len(self.images) > 0:
            # !!Only first image loading supported for now')
            data, affine = self.images[0]
            self.vox2ras = affine

            if len(self.pams) > 0:
                pam = self.pams[0]
            else:
                pam = None
            self.panel = slicer_panel(scene, self.show_m.iren, data, affine,
                                      self.world_coords,
                                      pam=pam, mem=self.mem)
        else:
            data = None

        self.win_size = scene.GetSize()

        def win_callback(obj, event):
            """Keep right-aligned panels attached when the window resizes."""
            if self.win_size != obj.GetSize():
                size_old = self.win_size
                self.win_size = obj.GetSize()
                size_change = [self.win_size[0] - size_old[0], 0]
                if data is not None:
                    self.panel.re_align(size_change)
                if has_cluster_panels:
                    self.panel2.re_align(size_change)
                    self.help_panel.re_align(size_change)

        self.show_m.initialize()

        self.hide_centroids = True
        self.select_all = False

        def hide():
            if self.hide_centroids:
                for ca in self.cea:
                    if (self.cea[ca]['length'] >= self.length_min or
                            self.cea[ca]['size'] >= self.size_min):
                        if self.cea[ca]['selected'] == 0:
                            ca.VisibilityOff()
            else:
                for ca in self.cea:
                    if (self.cea[ca]['length'] >= self.length_min and
                            self.cea[ca]['size'] >= self.size_min):
                        if self.cea[ca]['selected'] == 0:
                            ca.VisibilityOn()
            self.hide_centroids = not self.hide_centroids
            self.show_m.render()

        def invert():
            for ca in self.cea:
                if (self.cea[ca]['length'] >= self.length_min and
                        self.cea[ca]['size'] >= self.size_min):
                    self.cea[ca]['selected'] = \
                        not self.cea[ca]['selected']
                    cas = self.cea[ca]['cluster_actor']
                    self.cla[cas]['selected'] = \
                        self.cea[ca]['selected']
            self.show_m.render()

        def save():
            saving_streamlines = Streamlines()
            for bundle in self.cla.keys():
                if bundle.GetVisibility():
                    t = self.cla[bundle]['tractogram']
                    c = self.cla[bundle]['cluster']
                    indices = self.tractogram_clusters[t][c]
                    saving_streamlines.extend(Streamlines(indices))
            print('Saving result in tmp.trk')

            # Using the header of the first of the tractograms
            sft_new = StatefulTractogram(saving_streamlines,
                                         self.tractograms[0],
                                         Space.RASMM)
            save_tractogram(sft_new, 'tmp.trk', bbox_valid_check=False)
            print('Saved!')

        def new_window():
            active_streamlines = Streamlines()
            for bundle in self.cla.keys():
                if bundle.GetVisibility():
                    t = self.cla[bundle]['tractogram']
                    c = self.cla[bundle]['cluster']
                    indices = self.tractogram_clusters[t][c]
                    active_streamlines.extend(Streamlines(indices))

            # Using the header of the first of the tractograms
            active_sft = StatefulTractogram(active_streamlines,
                                            self.tractograms[0],
                                            Space.RASMM)
            hz2 = Horizon([active_sft],
                          self.images, cluster=True,
                          cluster_thr=self.cluster_thr/2.,
                          random_colors=self.random_colors,
                          length_lt=np.inf,
                          length_gt=0, clusters_lt=np.inf,
                          clusters_gt=0,
                          world_coords=True,
                          interactive=True)
            ren2 = hz2.build_scene()
            hz2.build_show(ren2)

        def show_all():
            if self.select_all is False:
                for ca in self.cea:
                    if (self.cea[ca]['length'] >= self.length_min and
                            self.cea[ca]['size'] >= self.size_min):
                        self.cea[ca]['selected'] = 1
                        cas = self.cea[ca]['cluster_actor']
                        self.cla[cas]['selected'] = \
                            self.cea[ca]['selected']
                self.show_m.render()
                self.select_all = True
            else:
                for ca in self.cea:
                    if (self.cea[ca]['length'] >= self.length_min and
                            self.cea[ca]['size'] >= self.size_min):
                        self.cea[ca]['selected'] = 0
                        cas = self.cea[ca]['cluster_actor']
                        self.cla[cas]['selected'] = \
                            self.cea[ca]['selected']
                self.show_m.render()
                self.select_all = False

        def expand():
            for c in self.cea:
                if self.cea[c]['selected']:
                    if not self.cea[c]['expanded']:
                        len_ = self.cea[c]['length']
                        sz_ = self.cea[c]['size']
                        if (len_ >= self.length_min and
                                sz_ >= self.size_min):
                            self.cea[c]['cluster_actor']. \
                                VisibilityOn()
                            c.VisibilityOff()
                            self.cea[c]['expanded'] = 1

            self.show_m.render()

        def reset():
            for c in self.cea:

                if (self.cea[c]['length'] >= self.length_min and
                        self.cea[c]['size'] >= self.size_min):
                    self.cea[c]['cluster_actor'].VisibilityOff()
                    c.VisibilityOn()
                    self.cea[c]['expanded'] = 0

            self.show_m.render()

        # Lower-case key symbol -> cluster action.
        key_actions = {
            'h': hide,        # hide on/off unselected centroids
            'i': invert,      # invert selection
            's': save,        # save current result
            'y': new_window,  # recluster visible streamlines in a new window
            'a': show_all,
            'e': expand,
            'r': reset,
        }

        def key_press(obj, event):
            key = obj.GetKeySym()
            if not self.cluster or not key:
                return
            key = key.lower()
            if key == 'o':
                # retract help panel (only exists with cluster panels)
                if has_cluster_panels:
                    self.help_panel._set_position((-300, 0))
                    self.show_m.render()
                return
            action = key_actions.get(key)
            if action is not None:
                action()

        # List-box entry -> cluster action (insertion order = display order).
        listbox_actions = {
            r'un\hide centroids': hide,
            'invert selection': invert,
            r'un\select all': show_all,
            'expand clusters': expand,
            'collapse clusters': reset,
            'save streamlines': save,
            'recluster': new_window,
        }
        options = list(listbox_actions)
        listbox = ui.ListBox2D(values=options, position=(10, 300),
                               size=(200, 270),
                               multiselection=False, font_size=18)

        def display_element():
            if not listbox.selected:
                return
            action = listbox_actions.get(listbox.selected[0])
            if action is not None:
                action()

        listbox.on_change = display_element
        listbox.panel.opacity = 0.2
        listbox.set_visibility(0)

        self.show_m.scene.add(listbox)

        def left_click_centroid_callback(obj, event):

            self.cea[obj]['selected'] = not self.cea[obj]['selected']
            self.cla[self.cea[obj]['cluster_actor']]['selected'] = \
                self.cea[obj]['selected']
            self.show_m.render()

        def right_click_centroid_callback(obj, event):
            # Toggle the list box of actions.
            for lactor in listbox._get_actors():
                lactor.SetVisibility(not lactor.GetVisibility())

            listbox.scroll_bar.set_visibility(False)
            self.show_m.render()

        def left_click_cluster_callback(obj, event):

            if self.cla[obj]['selected']:
                self.cla[obj]['centroid_actor'].VisibilityOn()
                ca = self.cla[obj]['centroid_actor']
                self.cea[ca]['selected'] = 0
                obj.VisibilityOff()
                self.cea[ca]['expanded'] = 0

            self.show_m.render()

        def right_click_cluster_callback(obj, event):
            print('Cluster Area Selected')
            self.show_m.render()

        for cl in self.cla:
            cl.AddObserver('LeftButtonPressEvent',
                           left_click_cluster_callback,
                           1.0)
            cl.AddObserver('RightButtonPressEvent',
                           right_click_cluster_callback,
                           1.0)
            self.cla[cl]['centroid_actor'].AddObserver(
                'LeftButtonPressEvent', left_click_centroid_callback, 1.0)
            self.cla[cl]['centroid_actor'].AddObserver(
                'RightButtonPressEvent', right_click_centroid_callback, 1.0)

        self.mem.window_timer_cnt = 0

        def timer_callback(obj, event):

            self.mem.window_timer_cnt += 1
            # TODO possibly add automatic rotation option

        scene.reset_camera()
        scene.zoom(1.5)
        scene.reset_clipping_range()

        if self.interactive:

            self.show_m.add_window_callback(win_callback)
            self.show_m.add_timer_callback(True, 200, timer_callback)
            self.show_m.iren.AddObserver('KeyPressEvent', key_press)

            if self.return_showm:
                return self.show_m

            if self.recorded_events is None:
                self.show_m.render()
                self.show_m.start()

            else:

                # set to True if event recorded file needs updating
                recording = False
                recording_filename = self.recorded_events

                if recording:
                    self.show_m.record_events_to_file(recording_filename)
                else:
                    self.show_m.play_events_from_file(recording_filename)

        else:

            window.record(scene, out_path=self.out_png,
                          size=(1200, 900),
                          reset_camera=False)
示例#19
0
def test_wrong_interactor_style():
    """Adding a panel to a renderer driven by a 'trackball' ShowManager
    must raise TypeError."""
    panel = ui.Panel2D(size=(300, 150))
    renderer = window.Renderer()
    # Kept bound to a name: constructing the manager is what attaches the
    # 'trackball' interactor style to ``renderer``.
    show_manager = window.ShowManager(renderer, interactor_style='trackball')
    npt.assert_raises(TypeError, panel.add_to_renderer, renderer)
示例#20
0
def slicer_panel(renderer,
                 iren,
                 data=None,
                 affine=None,
                 world_coords=False,
                 pam=None,
                 mask=None):
    """ Slicer panel with slicer included

    Builds three orthogonal slice actors (x, y, z) for ``data``, adds them
    to ``renderer`` and returns a ``ui.Panel2D`` holding the sliders and
    labels that control them (slice positions, opacity, colormap range and,
    for multi-volume data, the displayed volume).

    Parameters
    ----------
    renderer : Renderer
    iren : Interactor
        Used to read the click position when picking voxels and to trigger
        re-renders from label callbacks.
    data : 3d or 4d ndarray
        A 4D array whose last axis has length 3 is shown as RGB; any other
        4D array is treated as a series of volumes (volume 0 shown first).
    affine : 4x4 ndarray
    world_coords : bool
        If True then the affine is applied.
    pam : PeaksAndMetrics, optional
        If given, its ``peak_dirs`` are shown with a peak slicer on the
        z slice. Default None
    mask : ndarray, optional
        Passed to ``actor.peak_slicer``. Default None

    Returns
    -------
    panel : Panel

    Raises
    ------
    ValueError
        If ``data`` is None.
    """
    if data is None:
        raise ValueError('slicer_panel requires a data array')

    orig_shape = data.shape
    print('Original shape', orig_shape)
    ndim = data.ndim
    # Decide RGB vs. volume series from the ORIGINAL last axis. Re-testing a
    # trimmed 3D shape would misdetect a volume series with 3 z-slices as
    # RGB, and 4D data with 1-2 volumes would leave value_range undefined.
    is_rgb = ndim == 4 and orig_shape[-1] == 3
    # Set explicitly so a previous RGB call does not leak into this one.
    HORIMEM.slicer_rgb = is_rgb
    if is_rgb:
        tmp = data
        value_range = (0, 1.)
    elif ndim == 4:
        tmp = data[..., 0]
        value_range = np.percentile(tmp, q=[2, 98])
    else:
        tmp = data
        value_range = np.percentile(tmp, q=[2, 98])

    if not world_coords:
        affine = np.eye(4)

    image_actor_z = actor.slicer(tmp,
                                 affine=affine,
                                 value_range=value_range,
                                 interpolation='nearest',
                                 picking_tol=0.025)

    tmp_new = image_actor_z.resliced_array()

    if ndim == 4 and not is_rgb:
        print('Resized to RAS shape ', tmp_new.shape + (data.shape[-1], ))
    else:
        print('Resized to RAS shape ', tmp_new.shape)

    shape = tmp_new.shape

    if pam is not None:

        peaks_actor_z = actor.peak_slicer(pam.peak_dirs,
                                          None,
                                          mask=mask,
                                          affine=affine,
                                          colors=None)

    slicer_opacity = 1.
    image_actor_z.opacity(slicer_opacity)

    image_actor_x = image_actor_z.copy()
    x_midpoint = int(np.round(shape[0] / 2))
    image_actor_x.display_extent(x_midpoint, x_midpoint, 0, shape[1] - 1, 0,
                                 shape[2] - 1)

    image_actor_y = image_actor_z.copy()
    y_midpoint = int(np.round(shape[1] / 2))
    image_actor_y.display_extent(0, shape[0] - 1, y_midpoint, y_midpoint, 0,
                                 shape[2] - 1)

    renderer.add(image_actor_z)
    renderer.add(image_actor_x)
    renderer.add(image_actor_y)

    if pam is not None:
        renderer.add(peaks_actor_z)

    line_slider_z = ui.LineSlider2D(min_value=0,
                                    max_value=shape[2] - 1,
                                    initial_value=shape[2] / 2,
                                    text_template="{value:.0f}",
                                    length=140)

    _color_slider(line_slider_z)

    def change_slice_z(slider):
        z = int(np.round(slider.value))
        HORIMEM.slicer_curr_actor_z.display_extent(0, shape[0] - 1, 0,
                                                   shape[1] - 1, z, z)
        if pam is not None:
            HORIMEM.slicer_peaks_actor_z.display_extent(
                0, shape[0] - 1, 0, shape[1] - 1, z, z)
        HORIMEM.slicer_curr_z = z

    line_slider_x = ui.LineSlider2D(min_value=0,
                                    max_value=shape[0] - 1,
                                    initial_value=shape[0] / 2,
                                    text_template="{value:.0f}",
                                    length=140)

    _color_slider(line_slider_x)

    def change_slice_x(slider):
        x = int(np.round(slider.value))
        HORIMEM.slicer_curr_actor_x.display_extent(x, x, 0, shape[1] - 1, 0,
                                                   shape[2] - 1)
        HORIMEM.slicer_curr_x = x
        HORIMEM.window_timer_cnt += 100

    line_slider_y = ui.LineSlider2D(min_value=0,
                                    max_value=shape[1] - 1,
                                    initial_value=shape[1] / 2,
                                    text_template="{value:.0f}",
                                    length=140)

    _color_slider(line_slider_y)

    def change_slice_y(slider):
        y = int(np.round(slider.value))

        HORIMEM.slicer_curr_actor_y.display_extent(0, shape[0] - 1, y, y, 0,
                                                   shape[2] - 1)
        HORIMEM.slicer_curr_y = y

    double_slider = ui.LineDoubleSlider2D(length=140,
                                          initial_values=value_range,
                                          min_value=tmp.min(),
                                          max_value=tmp.max(),
                                          shape='square')

    _color_dslider(double_slider)

    def apply_colormap(r1, r2):
        """Apply the current colormap over [r1, r2] to the z slice actor
        (no-op for RGB data)."""
        if HORIMEM.slicer_rgb:
            return

        if HORIMEM.slicer_colormap == 'disting':
            # use distinguishable colors
            rgb = colormap.distinguishable_colormap(nb_colors=256)
            rgb = np.asarray(rgb)
        else:
            # use matplotlib colormaps
            rgb = colormap.create_colormap(np.linspace(r1, r2, 256),
                                           name=HORIMEM.slicer_colormap,
                                           auto=True)
        N = rgb.shape[0]

        lut = colormap.vtk.vtkLookupTable()
        lut.SetNumberOfTableValues(N)
        lut.SetRange(r1, r2)
        for i in range(N):
            r, g, b = rgb[i]
            lut.SetTableValue(i, r, g, b)
        lut.SetRampToLinear()
        lut.Build()

        HORIMEM.slicer_curr_actor_z.output.SetLookupTable(lut)
        HORIMEM.slicer_curr_actor_z.output.Update()

    def on_change_ds(slider):

        values = slider._values
        r1, r2 = values
        apply_colormap(r1, r2)

    opacity_slider = ui.LineSlider2D(min_value=0.0,
                                     max_value=1.0,
                                     initial_value=slicer_opacity,
                                     length=140,
                                     text_template="{ratio:.0%}")

    _color_slider(opacity_slider)

    def change_opacity(slider):
        slicer_opacity = slider.value
        HORIMEM.slicer_curr_actor_x.opacity(slicer_opacity)
        HORIMEM.slicer_curr_actor_y.opacity(slicer_opacity)
        HORIMEM.slicer_curr_actor_z.opacity(slicer_opacity)

    volume_slider = ui.LineSlider2D(min_value=0,
                                    max_value=data.shape[-1] - 1,
                                    initial_value=0,
                                    length=140,
                                    text_template="{value:.0f}",
                                    shape='square')

    _color_slider(volume_slider)

    def change_volume(istyle, obj, slider):
        """Replace the three slice actors with ones showing the volume
        selected by ``slider``, keeping the current slice positions."""
        vol_idx = int(np.round(slider.value))
        HORIMEM.slicer_vol_idx = vol_idx

        renderer.rm(HORIMEM.slicer_curr_actor_x)
        renderer.rm(HORIMEM.slicer_curr_actor_y)
        renderer.rm(HORIMEM.slicer_curr_actor_z)

        tmp = data[..., vol_idx]
        image_actor_z = actor.slicer(tmp,
                                     affine=affine,
                                     value_range=value_range,
                                     interpolation='nearest',
                                     picking_tol=0.025)

        tmp_new = image_actor_z.resliced_array()
        HORIMEM.slicer_vol = tmp_new

        z = HORIMEM.slicer_curr_z
        image_actor_z.display_extent(0, shape[0] - 1, 0, shape[1] - 1, z, z)

        HORIMEM.slicer_curr_actor_z = image_actor_z
        HORIMEM.slicer_curr_actor_x = image_actor_z.copy()

        if pam is not None:
            HORIMEM.slicer_peaks_actor_z = peaks_actor_z

        x = HORIMEM.slicer_curr_x
        HORIMEM.slicer_curr_actor_x.display_extent(x, x, 0, shape[1] - 1, 0,
                                                   shape[2] - 1)

        HORIMEM.slicer_curr_actor_y = image_actor_z.copy()
        y = HORIMEM.slicer_curr_y
        HORIMEM.slicer_curr_actor_y.display_extent(0, shape[0] - 1, y, y, 0,
                                                   shape[2] - 1)

        HORIMEM.slicer_curr_actor_z.AddObserver('LeftButtonPressEvent',
                                                left_click_picker_callback,
                                                1.0)
        HORIMEM.slicer_curr_actor_x.AddObserver('LeftButtonPressEvent',
                                                left_click_picker_callback,
                                                1.0)
        HORIMEM.slicer_curr_actor_y.AddObserver('LeftButtonPressEvent',
                                                left_click_picker_callback,
                                                1.0)
        renderer.add(HORIMEM.slicer_curr_actor_z)
        renderer.add(HORIMEM.slicer_curr_actor_x)
        renderer.add(HORIMEM.slicer_curr_actor_y)

        if pam is not None:
            renderer.add(HORIMEM.slicer_peaks_actor_z)

        r1, r2 = double_slider._values
        apply_colormap(r1, r2)

        istyle.force_render()

    def left_click_picker_callback(obj, ev):
        ''' Get the value of the clicked voxel and show it in the panel.'''

        event_pos = iren.GetEventPosition()

        obj.picker.Pick(event_pos[0], event_pos[1], 0, renderer)

        i, j, k = obj.picker.GetPointIJK()
        res = HORIMEM.slicer_vol[i, j, k]

        try:
            message = '%.3f' % res
        except TypeError:
            message = '%.3f %.3f %.3f' % (res[0], res[1], res[2])
        # Update the label first so the picked value is shown even if the
        # figure generation below fails.
        picker_label.message = '({}, {}, {})'.format(str(i), str(j), str(k)) \
            + ' ' + message

        # generate figure of the voxel's values (its signal across volumes
        # for 4D data).
        # NOTE(review): indexes the original ``data`` with IJK from the
        # resliced volume; these can differ when world_coords is True.
        cc_vox = data[i, j, k]
        print(cc_vox)
        # Plot as a 1D series: wrapping it in a list produced a (1, N) array,
        # i.e. N single-point lines that draw nothing. A fresh figure that is
        # closed afterwards keeps earlier clicks from accumulating.
        plt.figure()
        plt.plot(np.atleast_1d(cc_vox))
        plt.savefig('test.png')
        plt.close()
        icon_fnames = [('square', 'test.png'), ('square1', 'test.png')]
        # connect to button
        button = ui.Button2D(icon_fnames, size=(20, 30))
        panel.add_element(button, coords=(0., 0.))

    HORIMEM.slicer_vol_idx = 0
    HORIMEM.slicer_vol = tmp_new
    HORIMEM.slicer_curr_actor_x = image_actor_x
    HORIMEM.slicer_curr_actor_y = image_actor_y
    HORIMEM.slicer_curr_actor_z = image_actor_z

    if pam is not None:
        HORIMEM.slicer_peaks_actor_z = peaks_actor_z

    HORIMEM.slicer_curr_actor_x.AddObserver('LeftButtonPressEvent',
                                            left_click_picker_callback, 1.0)
    HORIMEM.slicer_curr_actor_y.AddObserver('LeftButtonPressEvent',
                                            left_click_picker_callback, 1.0)
    HORIMEM.slicer_curr_actor_z.AddObserver('LeftButtonPressEvent',
                                            left_click_picker_callback, 1.0)

    if pam is not None:
        HORIMEM.slicer_peaks_actor_z.AddObserver('LeftButtonPressEvent',
                                                 left_click_picker_callback,
                                                 1.0)

    HORIMEM.slicer_curr_x = int(np.round(shape[0] / 2))
    HORIMEM.slicer_curr_y = int(np.round(shape[1] / 2))
    HORIMEM.slicer_curr_z = int(np.round(shape[2] / 2))

    line_slider_x.on_change = change_slice_x
    line_slider_y.on_change = change_slice_y
    line_slider_z.on_change = change_slice_z

    double_slider.on_change = on_change_ds

    opacity_slider.on_change = change_opacity

    volume_slider.handle_events(volume_slider.handle.actor)
    volume_slider.on_left_mouse_button_released = change_volume

    # Clicking a slice label toggles its slider and slice actor. The counter
    # makes the first click always remove the actor.
    line_slider_label_x = build_label(text="X Slice")
    line_slider_label_x.visibility = True
    x_counter = itertools.count()

    def label_callback_x(obj, event):
        line_slider_label_x.visibility = not line_slider_label_x.visibility
        line_slider_x.set_visibility(line_slider_label_x.visibility)
        cnt = next(x_counter)
        if line_slider_label_x.visibility and cnt > 0:
            renderer.add(HORIMEM.slicer_curr_actor_x)
        else:
            renderer.rm(HORIMEM.slicer_curr_actor_x)
        iren.Render()

    line_slider_label_x.actor.AddObserver('LeftButtonPressEvent',
                                          label_callback_x, 1.0)

    line_slider_label_y = build_label(text="Y Slice")
    line_slider_label_y.visibility = True
    y_counter = itertools.count()

    def label_callback_y(obj, event):
        line_slider_label_y.visibility = not line_slider_label_y.visibility
        line_slider_y.set_visibility(line_slider_label_y.visibility)
        cnt = next(y_counter)
        if line_slider_label_y.visibility and cnt > 0:
            renderer.add(HORIMEM.slicer_curr_actor_y)
        else:
            renderer.rm(HORIMEM.slicer_curr_actor_y)
        iren.Render()

    line_slider_label_y.actor.AddObserver('LeftButtonPressEvent',
                                          label_callback_y, 1.0)

    line_slider_label_z = build_label(text="Z Slice")
    line_slider_label_z.visibility = True
    z_counter = itertools.count()

    def label_callback_z(obj, event):
        line_slider_label_z.visibility = not line_slider_label_z.visibility
        line_slider_z.set_visibility(line_slider_label_z.visibility)
        cnt = next(z_counter)
        if line_slider_label_z.visibility and cnt > 0:
            renderer.add(HORIMEM.slicer_curr_actor_z)
        else:
            renderer.rm(HORIMEM.slicer_curr_actor_z)

        iren.Render()

    line_slider_label_z.actor.AddObserver('LeftButtonPressEvent',
                                          label_callback_z, 1.0)

    opacity_slider_label = build_label(text="Opacity")
    volume_slider_label = build_label(text="Volume")
    picker_label = build_label(text='')
    double_slider_label = build_label(text='Colormap')

    def label_colormap_callback(obj, event):
        """Cycle to the next colormap in HORIMEM.slicer_colormaps."""

        if HORIMEM.slicer_colormap_cnt == len(HORIMEM.slicer_colormaps) - 1:
            HORIMEM.slicer_colormap_cnt = 0
        else:
            HORIMEM.slicer_colormap_cnt += 1

        cnt = HORIMEM.slicer_colormap_cnt
        HORIMEM.slicer_colormap = HORIMEM.slicer_colormaps[cnt]
        double_slider_label.message = HORIMEM.slicer_colormap
        values = double_slider._values
        r1, r2 = values
        apply_colormap(r1, r2)
        iren.Render()

    double_slider_label.actor.AddObserver('LeftButtonPressEvent',
                                          label_colormap_callback, 1.0)

    # Taller panel for 4D input to leave room for the volume slider row.
    if ndim == 4:
        panel_size = (400, 400 + 100)
    else:
        panel_size = (400, 300 + 100)

    panel = ui.Panel2D(size=panel_size,
                       position=(850, 110),
                       color=(1, 1, 1),
                       opacity=0.1,
                       align="right")

    ys = np.linspace(0, 1, 10)
    show_volume_slider = ndim == 4 and data.shape[-1] > 3

    panel.add_element(line_slider_z, coords=(0.4, ys[1]))
    panel.add_element(line_slider_y, coords=(0.4, ys[2]))
    panel.add_element(line_slider_x, coords=(0.4, ys[3]))
    panel.add_element(opacity_slider, coords=(0.4, ys[4]))
    panel.add_element(double_slider, coords=(0.4, (ys[7] + ys[8]) / 2.))

    if show_volume_slider:
        panel.add_element(volume_slider, coords=(0.4, ys[6]))

    panel.add_element(line_slider_label_z, coords=(0.1, ys[1]))
    panel.add_element(line_slider_label_y, coords=(0.1, ys[2]))
    panel.add_element(line_slider_label_x, coords=(0.1, ys[3]))
    panel.add_element(opacity_slider_label, coords=(0.1, ys[4]))
    panel.add_element(double_slider_label, coords=(0.1, (ys[7] + ys[8]) / 2.))

    if show_volume_slider:
        panel.add_element(volume_slider_label, coords=(0.1, ys[6]))

    panel.add_element(picker_label, coords=(0.2, ys[5]))

    renderer.add(panel)

    # initialize colormap
    r1, r2 = value_range
    apply_colormap(r1, r2)

    return panel
示例#21
0
    label.shadow = False
    label.actor.GetTextProperty().SetBackgroundColor(0, 0, 0)
    label.actor.GetTextProperty().SetBackgroundOpacity(0.0)
    label.color = (1, 1, 1)

    return label


line_slider_label_z = build_label(text="Z Slice")
line_slider_label_x = build_label(text="X Slice")
line_slider_label_y = build_label(text="Y Slice")
opacity_slider_label = build_label(text="Opacity")

# Create a semi-transparent, right-aligned panel centred at (1030, 120) to
# contain the sliders and labels.
panel = ui.Panel2D(size=(300, 200),
                   center=(1030, 120),
                   align="right",
                   color=(1, 1, 1),
                   opacity=0.1)

# Each row: a label in the left column, its slider in the right column
# slightly above it. Elements are added in this order.
_panel_layout = (
    (line_slider_label_x, (0.1, 0.75)),
    (line_slider_x, (0.65, 0.8)),
    (line_slider_label_y, (0.1, 0.55)),
    (line_slider_y, (0.65, 0.6)),
    (line_slider_label_z, (0.1, 0.35)),
    (line_slider_z, (0.65, 0.4)),
    (opacity_slider_label, (0.1, 0.15)),
    (opacity_slider, (0.65, 0.2)),
)
for _element, _coords in _panel_layout:
    panel.add_element(_element, 'relative', _coords)
# Don't leave the layout helpers behind as module globals.
del _panel_layout, _element, _coords

show_m.ren.add(panel)
示例#22
0
    def build_show(self, scene):
        """Create a ShowManager for ``scene`` and wire up panels and callbacks.

        When ``self.cluster`` is set, adds a panel with length/size sliders
        that hide clusters below the chosen thresholds, plus a help panel.
        When ``self.images`` is non-empty, adds a slicer panel for the first
        image. Registers mouse callbacks on cluster and centroid actors.
        If ``self.interactive`` is set, it also registers a window-resize
        callback and a key-press handler, then starts the window. Otherwise
        it records a 1200x900 snapshot to ``self.out_png``.
        """
        show_m = window.ShowManager(scene,
                                    size=(1200, 900),
                                    order_transparent=True,
                                    reset_camera=False)
        show_m.initialize()

        if self.cluster:

            lengths = np.array([self.cla[c]['length'] for c in self.cla])
            szs = [self.cla[c]['size'] for c in self.cla]
            sizes = np.array(szs)

            self.panel2 = ui.Panel2D(size=(300, 200),
                                     position=(850, 320),
                                     color=(1, 1, 1),
                                     opacity=0.1,
                                     align="right")

            slider_label_length = build_label(text="Length")
            slider_length = ui.LineSlider2D(
                min_value=lengths.min(),
                max_value=np.percentile(lengths, 98),
                initial_value=np.percentile(lengths, 25),
                text_template="{value:.0f}",
                length=140)

            slider_label_size = build_label(text="Size")
            slider_size = ui.LineSlider2D(min_value=sizes.min(),
                                          max_value=np.percentile(sizes, 98),
                                          initial_value=np.percentile(
                                              sizes, 50),
                                          text_template="{value:.0f}",
                                          length=140)

            # Current thresholds; a cluster failing either one is hidden.
            self.size_min = sizes.min()
            self.length_min = lengths.min()

            def apply_cluster_thresholds():
                """Hide centroids (and expanded clusters) whose length or size
                is below the current thresholds; show the other centroids."""
                for k in self.cla:
                    if (self.cla[k]['length'] < self.length_min
                            or self.cla[k]['size'] < self.size_min):
                        self.cla[k]['centroid_actor'].SetVisibility(0)
                        if k.GetVisibility() == 1:
                            k.SetVisibility(0)
                    else:
                        self.cla[k]['centroid_actor'].SetVisibility(1)
                show_m.render()

            def hide_clusters_length(slider):
                self.length_min = np.round(slider.value)
                apply_cluster_thresholds()

            def hide_clusters_size(slider):
                self.size_min = np.round(slider.value)
                apply_cluster_thresholds()

            slider_length.on_change = hide_clusters_length

            self.panel2.add_element(slider_label_length, coords=(0.1, 0.333))
            self.panel2.add_element(slider_length, coords=(0.4, 0.333))

            slider_size.on_change = hide_clusters_size

            self.panel2.add_element(slider_label_size, coords=(0.1, 0.6666))
            self.panel2.add_element(slider_size, coords=(0.4, 0.6666))

            scene.add(self.panel2)

            text_block = build_label(HELP_MESSAGE, 16)  # ui.TextBlock2D()
            text_block.message = HELP_MESSAGE

            help_panel = ui.Panel2D(size=(300, 200),
                                    color=(1, 1, 1),
                                    opacity=0.1,
                                    align="left")

            help_panel.add_element(text_block, coords=(0.05, 0.1))
            scene.add(help_panel)

        if len(self.images) > 0:
            # !!Only first image loading supported for now')
            data, affine = self.images[0]
            self.panel = slicer_panel(scene, data, affine, self.world_coords)
        else:
            data = None
            affine = None

        self.win_size = scene.GetSize()

        def win_callback(obj, event):
            # Keep panels anchored when the window width changes.
            if self.win_size != obj.GetSize():
                size_old = self.win_size
                self.win_size = obj.GetSize()
                size_change = [self.win_size[0] - size_old[0], 0]
                if data is not None:
                    self.panel.re_align(size_change)
                if self.cluster:
                    self.panel2.re_align(size_change)
                    help_panel.re_align(size_change)

        show_m.initialize()

        def left_click_centroid_callback(obj, event):
            # Toggle selection of a centroid and mirror it on its cluster.
            self.cea[obj]['selected'] = not self.cea[obj]['selected']
            self.cla[self.cea[obj]['cluster_actor']]['selected'] = \
                self.cea[obj]['selected']
            show_m.render()

        def left_click_cluster_callback(obj, event):
            # Collapse a selected cluster back to its (deselected) centroid.
            if self.cla[obj]['selected']:
                self.cla[obj]['centroid_actor'].VisibilityOn()
                ca = self.cla[obj]['centroid_actor']
                self.cea[ca]['selected'] = 0
                obj.VisibilityOff()
                self.cea[ca]['expanded'] = 0

            show_m.render()

        for cl in self.cla:
            cl.AddObserver('LeftButtonPressEvent', left_click_cluster_callback,
                           1.0)
            self.cla[cl]['centroid_actor'].AddObserver(
                'LeftButtonPressEvent', left_click_centroid_callback, 1.0)

        self.hide_centroids = True
        self.select_all = False

        def passes_thresholds(ca):
            """True if centroid actor ``ca`` meets both slider thresholds."""
            return (self.cea[ca]['length'] >= self.length_min
                    and self.cea[ca]['size'] >= self.size_min)

        def key_press(obj, event):
            key = obj.GetKeySym()
            if self.cluster:

                # hide on/off unselected centroids
                # (uses the same "passes both thresholds" test as every other
                # key; centroids failing a threshold are already hidden)
                if key == 'h' or key == 'H':
                    if self.hide_centroids:
                        for ca in self.cea:
                            if passes_thresholds(ca):
                                if self.cea[ca]['selected'] == 0:
                                    ca.VisibilityOff()
                    else:
                        for ca in self.cea:
                            if passes_thresholds(ca):
                                if self.cea[ca]['selected'] == 0:
                                    ca.VisibilityOn()
                    self.hide_centroids = not self.hide_centroids
                    show_m.render()

                # invert selection
                if key == 'i' or key == 'I':

                    for ca in self.cea:
                        if passes_thresholds(ca):
                            self.cea[ca]['selected'] = \
                                not self.cea[ca]['selected']
                            cas = self.cea[ca]['cluster_actor']
                            self.cla[cas]['selected'] = \
                                self.cea[ca]['selected']
                    show_m.render()

                # save current result
                if key == 's' or key == 'S':
                    saving_streamlines = Streamlines()
                    for bundle in self.cla.keys():
                        if bundle.GetVisibility():
                            t = self.cla[bundle]['tractogram']
                            c = self.cla[bundle]['cluster']
                            indices = self.tractogram_clusters[t][c]
                            saving_streamlines.extend(Streamlines(indices))
                    print('Saving result in tmp.trk')
                    save_trk('tmp.trk', saving_streamlines, np.eye(4))

                # open a new horizon window with the visible clusters
                if key == 'y' or key == 'Y':
                    active_streamlines = Streamlines()
                    for bundle in self.cla.keys():
                        if bundle.GetVisibility():
                            t = self.cla[bundle]['tractogram']
                            c = self.cla[bundle]['cluster']
                            indices = self.tractogram_clusters[t][c]
                            active_streamlines.extend(Streamlines(indices))

                    hz2 = horizon([active_streamlines],
                                  self.images,
                                  cluster=True,
                                  cluster_thr=5,
                                  random_colors=self.random_colors,
                                  length_lt=np.inf,
                                  length_gt=0,
                                  clusters_lt=np.inf,
                                  clusters_gt=0,
                                  world_coords=True,
                                  interactive=True)
                    ren2 = hz2.build_scene()
                    hz2.build_show(ren2)

                # select / deselect all
                if key == 'a' or key == 'A':

                    if self.select_all is False:
                        for ca in self.cea:
                            if passes_thresholds(ca):
                                self.cea[ca]['selected'] = 1
                                cas = self.cea[ca]['cluster_actor']
                                self.cla[cas]['selected'] = \
                                    self.cea[ca]['selected']
                        show_m.render()
                        self.select_all = True
                    else:
                        for ca in self.cea:
                            if passes_thresholds(ca):
                                self.cea[ca]['selected'] = 0
                                cas = self.cea[ca]['cluster_actor']
                                self.cla[cas]['selected'] = \
                                    self.cea[ca]['selected']
                        show_m.render()
                        self.select_all = False

                # expand selected centroids into their clusters
                if key == 'e' or key == 'E':

                    for c in self.cea:
                        if self.cea[c]['selected']:
                            if not self.cea[c]['expanded']:
                                if passes_thresholds(c):
                                    self.cea[c]['cluster_actor'].VisibilityOn()
                                    c.VisibilityOff()
                                    self.cea[c]['expanded'] = 1

                    show_m.render()

                # reset: collapse all clusters back to centroids
                if key == 'r' or key == 'R':

                    for c in self.cea:

                        if passes_thresholds(c):
                            self.cea[c]['cluster_actor'].VisibilityOff()
                            c.VisibilityOn()
                            self.cea[c]['expanded'] = 0

                show_m.render()

        scene.reset_camera()
        scene.zoom(1.5)
        scene.reset_clipping_range()

        if self.interactive:

            show_m.add_window_callback(win_callback)
            show_m.iren.AddObserver('KeyPressEvent', key_press)
            show_m.render()
            show_m.start()

        else:

            window.record(scene,
                          out_path=self.out_png,
                          size=(1200, 900),
                          reset_camera=False)
示例#23
0
def slicer_panel(renderer, data=None, affine=None, world_coords=False):
    """Add three orthogonal slices of ``data`` to ``renderer`` plus a panel
    of sliders that moves each slice and sets their common opacity.

    Parameters
    ----------
    renderer : Renderer
        Receives the three slice actors and the returned panel.
    data : 3d ndarray
    affine : 4x4 ndarray
        Only used when ``world_coords`` is True.
    world_coords : bool
        If True then the affine is applied; otherwise the identity is used.

    Returns
    -------
    panel : Panel
    """
    shape = data.shape

    if world_coords:
        image_actor_z = actor.slicer(data, affine)
    else:
        image_actor_z = actor.slicer(data, affine=np.eye(4))

    slicer_opacity = 0.6
    image_actor_z.opacity(slicer_opacity)

    # Display extents (x0, x1, y0, y1, z0, z1) of a single slice per axis.
    def extent_x(x):
        return (x, x, 0, shape[1] - 1, 0, shape[2] - 1)

    def extent_y(y):
        return (0, shape[0] - 1, y, y, 0, shape[2] - 1)

    def extent_z(z):
        return (0, shape[0] - 1, 0, shape[1] - 1, z, z)

    image_actor_x = image_actor_z.copy()
    image_actor_x.display_extent(*extent_x(int(np.round(shape[0] / 2))))

    image_actor_y = image_actor_z.copy()
    image_actor_y.display_extent(*extent_y(int(np.round(shape[1] / 2))))

    for slice_actor in (image_actor_z, image_actor_x, image_actor_y):
        renderer.add(slice_actor)

    def make_axis_slider(axis):
        # Integer-labelled slider spanning every slice index along ``axis``.
        return ui.LineSlider2D(min_value=0,
                               max_value=shape[axis] - 1,
                               initial_value=shape[axis] / 2,
                               text_template="{value:.0f}",
                               length=140)

    line_slider_z = make_axis_slider(2)
    line_slider_x = make_axis_slider(0)
    line_slider_y = make_axis_slider(1)
    opacity_slider = ui.LineSlider2D(min_value=0.0,
                                     max_value=1.0,
                                     initial_value=slicer_opacity,
                                     length=140)

    def change_slice_z(slider):
        image_actor_z.display_extent(*extent_z(int(np.round(slider.value))))

    def change_slice_x(slider):
        image_actor_x.display_extent(*extent_x(int(np.round(slider.value))))

    def change_slice_y(slider):
        image_actor_y.display_extent(*extent_y(int(np.round(slider.value))))

    def change_opacity(slider):
        new_opacity = slider.value
        for slice_actor in (image_actor_z, image_actor_x, image_actor_y):
            slice_actor.opacity(new_opacity)

    line_slider_z.on_change = change_slice_z
    line_slider_y.on_change = change_slice_y
    line_slider_x.on_change = change_slice_x
    opacity_slider.on_change = change_opacity

    line_slider_label_z = build_label(text="Z Slice")
    line_slider_label_x = build_label(text="X Slice")
    line_slider_label_y = build_label(text="Y Slice")
    opacity_slider_label = build_label(text="Opacity")

    panel = ui.Panel2D(size=(300, 200),
                       position=(850, 110),
                       color=(1, 1, 1),
                       opacity=0.1,
                       align="right")

    # Rows from top to bottom: label at x=0.1, slider at x=0.4 slightly above.
    layout = [
        (line_slider_label_x, (0.1, 0.75)),
        (line_slider_x, (0.4, 0.8)),
        (line_slider_label_y, (0.1, 0.55)),
        (line_slider_y, (0.4, 0.6)),
        (line_slider_label_z, (0.1, 0.35)),
        (line_slider_z, (0.4, 0.4)),
        (opacity_slider_label, (0.1, 0.15)),
        (opacity_slider, (0.4, 0.2)),
    ]
    for element, position in layout:
        panel.add_element(element, coords=position)

    renderer.add(panel)
    return panel
示例#24
0
def _commit_weights(x_norm, nF, radii=None):
    """Return the weight of each streamline for one normalised COMMIT solution.

    Parameters
    ----------
    x_norm : numpy array
        Coefficients of one COMMIT iteration, already divided by the
        normalisation vector.  For the Cylinder model, the block
        ``x_norm[i*nF:(i+1)*nF]`` holds the intra-cellular coefficients of
        radius atom ``i``.
    nF : int
        Size of one intra-cellular block (``sizeIC`` in results.pickle),
        presumably one entry per streamline.
    radii : sequence of float or None
        Radii of the intra-cellular cylinder atoms ("Cylinder" model), or
        None for the "Stick" model.

    Returns
    -------
    numpy array
        "Stick": the first ``nF`` coefficients (signal fractions), as a view
        of ``x_norm``.
        "Cylinder": ``2 * sum_i(x_i * radii[i]) / sum_i(x_i) * 1E6``, a
        diameter estimate per streamline.  With the radii used by ``main``
        (0.75-3.5 scaled by 1E-6) this is presumably in micrometres.
    """
    if radii is None:
        return x_norm[:nF]  # signal fractions

    # The accumulators are local, so every call starts from zero.
    num_ADI = np.zeros(nF)
    den_ADI = np.zeros(nF)
    for i, radius in enumerate(radii):
        block = x_norm[i * nF:(i + 1) * nF]
        den_ADI = den_ADI + block
        num_ADI = num_ADI + block * radius
    # np.spacing(1) avoids a division by zero for streamlines without IC signal.
    return 2 * (num_ADI / (den_ADI + np.spacing(1))) * 1E6


def main():
    """Load COMMIT results and start the interactive streamline-weight viewer.

    Reads ``args.commitOutputPath`` and ``args.streamlinesNumber``.  From the
    output directory it loads:

    - ``Results_<model>ZeppelinBall/results.pickle``;
    - ``Coeff_x_<model>ZeppelinBall/norm_fib.npy``, ``norm1.npy``,
      ``norm2.npy`` and ``norm3.npy``;
    - one ``<number>.npy`` file per iteration;
    - ``dictionary_TRK_fibers.trk``.

    It then builds the renderer, the sliders and the control panel, and
    blocks in ``show_m.start()``.

    All state shared with the slider callbacks is published through the
    module-level globals declared below.  Exits with status 1 when no model
    directory or no iteration file is found.
    """
    global parser
    global args
    global model
    global bar
    global lut_cmap
    global list_x_file
    global max_weight
    global saturation
    global renderer
    global norm_fib
    global norm1
    global norm2
    global norm3
    global big_stream_actor
    global good_stream_actor
    global weak_stream_actor
    global big_Weight
    global good_Weight
    global weak_Weight
    global smallBundle_safe
    global smallWeight_safe
    global show_m
    global big_bundle
    global good_bundle
    global weak_bundle
    global nF
    global nIC
    global Ra
    global change_colormap_slider
    global remove_small_weights_slider
    global opacity_slider
    global remove_big_weights_slider
    global change_iteration_slider
    global num_computed_streamlines
    global numbers_of_streamlines_in_interval

    base = args.commitOutputPath

    # Choose the model (Stick or Cylinder) from the available result folders.
    model = None
    has_stick = os.path.isdir(base + "/Results_StickZeppelinBall")
    has_cylinder = os.path.isdir(base + "/Results_CylinderZeppelinBall")
    if has_stick and has_cylinder:
        model_index = input("Which model do you want to load (1 for 'Cylinder', 2 for 'Stick') : ")
        # input() returns a str on Python 3 (an evaluated value on Python 2).
        # Comparing against the text "1" works on both; comparing with the
        # int 1 made "Cylinder" unreachable on Python 3.
        model = "Cylinder" if str(model_index).strip() == "1" else "Stick"
    elif has_stick:
        model = "Stick"
    elif has_cylinder:
        model = "Cylinder"
    else:
        print("No valid model in this path")
        sys.exit(1)

    results_dir = base + "/Results_" + model + "ZeppelinBall/"
    coeff_dir = base + "/Coeff_x_" + model + "ZeppelinBall/"

    # Collect the iteration files ("<number>.npy") and normalise their names.
    list_x_file = [name for name in os.listdir(coeff_dir)
                   if name.endswith('.npy') and name[:-4].isdigit()]
    normalize_file_name(list_x_file)
    list_x_file.sort()
    num_iteration = len(list_x_file)
    if num_iteration == 0:
        print("No iteration file (<number>.npy) in " + coeff_dir)
        sys.exit(1)

    # Number of streamlines we want to load.
    num_computed_streamlines = int(args.streamlinesNumber)

    # NOTE: pickle can execute code; only open result folders you trust.
    with open(results_dir + "results.pickle", 'rb') as pickle_file:
        object_file = pickle.load(pickle_file)
    nF = object_file[0]['optimization']['regularisation']['sizeIC']

    radii = None  # None selects the Stick weight formula
    if model == "Cylinder":
        Ra = np.linspace(0.75, 3.5, 12) * 1E-6  # radii of the IC atoms
        nIC = len(Ra)  # IC atoms
        radii = Ra

    norm_fib = np.load(coeff_dir + "norm_fib.npy")
    norm1 = np.load(coeff_dir + "norm1.npy")
    norm2 = np.load(coeff_dir + "norm2.npy")
    norm3 = np.load(coeff_dir + "norm3.npy")
    # The same normalisation applies to every iteration, so build it once.
    norm_all = np.hstack((norm1 * norm_fib, norm2, norm3))

    def iteration_weights(it_name):
        """Return the per-streamline weights of iteration file ``it_name``."""
        x = np.load(coeff_dir + it_name + '.npy')
        return _commit_weights(x / norm_all, nF, radii)

    # Weight range: the largest weight over all iterations, each computed
    # independently of the others.
    max_weight = 0
    for it_name in list_x_file:
        it_max = np.amax(iteration_weights(it_name)[:num_computed_streamlines])
        if it_max > max_weight:
            max_weight = it_max
    # We need an interval slightly bigger than max_weight.
    max_weight = max_weight + 0.00001

    # Initial weights: the first iteration.
    smallWeight_safe = iteration_weights(list_x_file[0])[:num_computed_streamlines]
    weak_Weight = smallWeight_safe[:1]
    big_Weight = smallWeight_safe[:1]
    good_Weight = copy.copy(smallWeight_safe)

    # Load streamlines from the dictionary_TRK_fibers.trk file.
    streams, _hdr = nib.trackvis.read(base + "/dictionary_TRK_fibers.trk")
    streamlines = [s[0] for s in streams]
    smallBundle_safe = streamlines[:num_computed_streamlines]
    weak_bundle = smallBundle_safe[:1]
    big_bundle = smallBundle_safe[:1]
    good_bundle = copy.copy(smallBundle_safe)
    num_streamlines = len(smallBundle_safe)  # number of good streamlines

    # Map streamlines and initial weights (with a red bar) in a renderer.
    hue = [0, 0]  # red only
    saturation = [0.0, 1.0]  # black to white
    lut_cmap = actor.colormap_lookup_table(scale_range=(0, max_weight),
                                           hue_range=hue,
                                           saturation_range=saturation)

    weak_stream_actor = actor.line(weak_bundle, weak_Weight, lookup_colormap=lut_cmap)
    big_stream_actor = actor.line(big_bundle, big_Weight, lookup_colormap=lut_cmap)
    good_stream_actor = actor.line(good_bundle, good_Weight, lookup_colormap=lut_cmap)

    bar = actor.scalar_bar(lut_cmap, title='weight')
    bar.SetHeight(0.5)
    bar.SetWidth(0.1)
    bar.SetPosition(0.85, 0.45)

    renderer = window.Renderer()
    renderer.set_camera(position=(-176.42, 118.52, 128.20),
                        focal_point=(113.30, 100, 76.56),
                        view_up=(0.18, 0.00, 0.98))
    for scene_actor in (big_stream_actor, good_stream_actor, weak_stream_actor, bar):
        renderer.add(scene_actor)

    # Add the sliders and the renderer to a ShowManager.
    show_m = window.ShowManager(renderer, size=(1200, 900))
    show_m.initialize()

    # "save" and "graph" are 1-pixel sliders that appear to act as push buttons.
    save_one_image_bouton = ui.LineSlider2D(min_value=0, max_value=1, initial_value=0,
                                            text_template="save", length=1)
    add_graph_bouton = ui.LineSlider2D(min_value=0, max_value=1, initial_value=0,
                                       text_template="graph", length=1)
    color_slider = ui.LineSlider2D(min_value=0.0, max_value=1.0, initial_value=0,
                                   text_template="{value:.1f}", length=140)
    change_colormap_slider = ui.LineSlider2D(min_value=0, max_value=1.0, initial_value=0,
                                             text_template="", length=40)
    # max_value stays just below num_iteration because
    # list_x_file[num_iteration] would be out of range.
    change_iteration_slider = ui.LineSlider2D(min_value=0,
                                              max_value=num_iteration - 0.01,
                                              initial_value=0,
                                              text_template=list_x_file[0],
                                              length=140)
    remove_big_weights_slider = ui.LineSlider2D(min_value=0, max_value=max_weight,
                                                initial_value=max_weight,
                                                text_template="{value:.2f}", length=140)
    remove_small_weights_slider = ui.LineSlider2D(min_value=0, max_value=max_weight,
                                                  initial_value=0,
                                                  text_template="{value:.2f}", length=140)
    opacity_slider = ui.LineSlider2D(min_value=0.0, max_value=1.0, initial_value=0.5,
                                     text_template="{ratio:.0%}", length=140)

    # Buttons only react to a click on their disk.
    save_one_image_bouton.add_callback(save_one_image_bouton.slider_disk,
                                       "LeftButtonPressEvent", save_one_image)
    add_graph_bouton.add_callback(add_graph_bouton.slider_disk,
                                  "LeftButtonPressEvent", add_graph)
    # Sliders react both to dragging the disk and to clicking on the line.
    for slider, callback in ((color_slider, change_streamlines_color),
                             (change_colormap_slider, change_colormap),
                             (change_iteration_slider, change_iteration),
                             (remove_big_weights_slider, remove_big_weight),
                             (remove_small_weights_slider, remove_small_weight),
                             (opacity_slider, change_opacity)):
        slider.add_callback(slider.slider_disk, "MouseMoveEvent", callback)
        slider.add_callback(slider.slider_line, "LeftButtonPressEvent", callback)

    def text_label(message):
        """Return a TextBlock2D displaying ``message``."""
        label = ui.TextBlock2D()
        label.message = message
        return label

    color_slider_label = text_label('color of streamlines')
    change_colormap_slider_label_weight = text_label('weight color')
    change_colormap_slider_label_direction = text_label('direction color')
    change_iteration_slider_label = text_label('number of the iteration')
    remove_big_weights_slider_label = text_label('big weights subdued')
    remove_small_weights_slider_label = text_label('small weights subdued')
    opacity_slider_label = text_label('Unwanted weights opacity')
    numbers_of_streamlines_in_interval = text_label(
        "Number of streamlines in interval: " + str(num_streamlines))

    panel = ui.Panel2D(center=(300, 160),
                       size=(500, 280),
                       color=(1, 1, 1),
                       opacity=0.1,
                       align="right")

    # (element, position relative to the panel)
    layout = (
        (save_one_image_bouton, (0.9, 0.9)),
        (add_graph_bouton, (0.9, 0.77)),
        (color_slider_label, (0.05, 0.85)),
        (color_slider, (0.7, 0.9)),
        (numbers_of_streamlines_in_interval, (0.05, 0.72)),
        (change_colormap_slider_label_weight, (0.05, 0.59)),
        (change_colormap_slider_label_direction, (0.5, 0.59)),
        (change_colormap_slider, (0.4, 0.64)),
        (change_iteration_slider_label, (0.05, 0.46)),
        (change_iteration_slider, (0.7, 0.51)),
        (remove_big_weights_slider_label, (0.05, 0.33)),
        (remove_big_weights_slider, (0.7, 0.37)),
        (remove_small_weights_slider_label, (0.05, 0.2)),
        (remove_small_weights_slider, (0.7, 0.24)),
        (opacity_slider_label, (0.05, 0.07)),
        (opacity_slider, (0.7, 0.11)),
    )
    for element, coords in layout:
        panel.add_element(element, 'relative', coords)

    panel.add_to_renderer(renderer)
    renderer.reset_clipping_range()

    show_m.render()
    show_m.start()