Exemplo n.º 1
0
    def OnMouseMove(self, obj, eventType):
        """Drag the chosen actor so that it follows the mouse.

        While an actor is held (``self.chosenActor`` is set), the point under
        the cursor is picked again and the actor is translated by the
        displacement between the new pick position and the previously
        remembered one. When no actor is held, the event is handed to
        ``CustomInteractorStyle.on_mouse_move``.

        Parameters
        ----------
        obj : object
            Passed through unchanged to the parent handler.
        eventType : object
            Passed through unchanged to the parent handler.
        """
        if self.chosenActor is None:
            # Nothing is being dragged: fall back to the parent behaviour.
            CustomInteractorStyle.on_mouse_move(self, obj, eventType)
            return

        # Redo the same pick as in OnRightButtonDown, at the current cursor.
        screen_pos = self.GetInteractor().GetEventPosition()
        self.picker.Pick(screen_pos[0], screen_pos[1], 0, self.renderer)
        world_pos_new = self.picker.GetPickPosition()

        # Displacement along all three axes since the last reference point.
        dx = world_pos_new[0] - self.world_pos[0]
        dy = world_pos_new[1] - self.world_pos[1]
        dz = world_pos_new[2] - self.world_pos[2]

        # The new pick position becomes the reference for the next move.
        self.world_pos = world_pos_new

        # Translate the chosen actor by that displacement.
        x, y, z = self.chosenActor.GetPosition()
        self.chosenActor.SetPosition(x + dx, y + dy, z + dz)

        # Request a redraw so the move is visible immediately.
        self.GetInteractor().GetRenderWindow().Render()
    def OnRightButtonUp(self, obj, eventType):
        """End the actor drag, then run the parent's release handling.

        Parameters
        ----------
        obj : object
            Passed through unchanged to the parent handler.
        eventType : object
            Passed through unchanged to the parent handler.
        """
        # Releasing the right button drops the held actor.
        self.chosenActor = None

        # Let the parent style finish its own right-button interaction.
        CustomInteractorStyle.on_right_button_up(self, obj, eventType)
Exemplo n.º 3
0
    def __init__(self, renderer, actors):
        """Set up the state needed to pick and drag actors.

        Parameters
        ----------
        renderer : object
            Renderer handed to the picker when picking under the cursor.
        actors : container
            Actors that may be grabbed; a picked actor is only accepted when
            it is found in this container.
        """
        CustomInteractorStyle.__init__(self)

        # Remember data we need for the interaction.
        # The mouse handlers read ``self.picker``; the previous attribute
        # name ``peaker`` is kept as an alias in case other code relies on it.
        self.picker = vtk.vtkPointPicker()
        self.peaker = self.picker
        self.renderer = renderer
        self.chosenActor = None
        self.actors = actors
        # Reference pick position of the current drag; set when an actor is
        # grabbed and updated on every mouse move.
        self.world_pos = None
Exemplo n.º 4
0
 def SetInteractor(self, interactor):
     """Attach the interactor and install this style's mouse handlers.

     The observers already registered for the three events used to drag
     actors are removed and replaced by the handlers of this class.
     """
     CustomInteractorStyle.SetInteractor(self, interactor)

     # Events involved in the actor interaction and their handlers.
     bindings = (
         ('RightButtonPressEvent', self.OnRightButtonDown),
         ('RightButtonReleaseEvent', self.OnRightButtonUp),
         ('MouseMoveEvent', self.OnMouseMove),
     )

     # Clear all three events first, then register the replacements.
     for event_name, _ in bindings:
         self.RemoveObservers(event_name)
     for event_name, handler in bindings:
         self.AddObserver(event_name, handler)
Exemplo n.º 5
0
    def OnRightButtonDown(self, obj, eventType):
        """Grab the actor under the cursor, if it is one we interact with.

        The picked actor and the picked world position are remembered so
        that later mouse moves can translate the actor. The parent handler
        is always called afterwards.

        Parameters
        ----------
        obj : object
            Passed through unchanged to the parent handler.
        eventType : object
            Passed through unchanged to the parent handler.
        """
        # Where the mouse event happened, in display coordinates.
        cursor = self.GetInteractor().GetEventPosition()

        # Ask the picker which actor lies under that position.
        self.picker.Pick(cursor[0], cursor[1], 0, self.renderer)
        picked = self.picker.GetActor()

        # Only actors registered for interaction can be grabbed.
        if picked in self.actors:
            self.chosenActor = picked
            self.world_pos = self.picker.GetPickPosition()

        # The parent interaction runs whether or not something was grabbed.
        CustomInteractorStyle.on_right_button_down(self, obj, eventType)
Exemplo n.º 6
0
    def __init__(self,
                 ren,
                 title='DIPY',
                 size=(300, 300),
                 png_magnify=1,
                 reset_camera=True,
                 order_transparent=False,
                 interactor_style='custom'):
        """ Build the render window and the interactor for a scene

        Parameters
        ----------
        ren : Renderer() or vtkRenderer()
            Scene holding all the actors to display.
        title : string
            Text of the window title bar. With the default title the DIPY
            version is appended to it.
        size : (int, int)
            Window ``(width, height)``. Default is (300, 300).
        png_magnify : int
            Magnification factor of the screenshot, useful to save high
            resolution screenshots when pressing 's' inside the window.
        reset_camera : bool
            Default is True. Use False to keep the camera as it was set
            before calling this function.
        order_transparent : bool
            Use True to order transparent actors according to their relative
            position to the camera. With the default, False, actors are
            ordered according to the order of their addition to the
            Renderer().
        interactor_style : str or vtkInteractorStyle
            'trackball' selects vtkInteractorStyleTrackballCamera(), 'image'
            selects vtkInteractorStyleImage() (no rotation) and 'custom'
            selects CustomInteractorStyle. Any other value is used directly
            as the interactor style.

        Attributes
        ----------
        ren : vtkRenderer()
        iren : vtkRenderWindowInteractor()
        style : vtkInteractorStyle()
        window : vtkRenderWindow()

        Methods
        -------
        initialize()
        render()
        start()
        add_window_callback()

        Notes
        -----
        Default interaction keys for

        * 3d navigation are with left, middle and right mouse dragging
        * resetting the camera press 'r'
        * saving a screenshot press 's'
        * for quitting press 'q'

        Examples
        --------
        >>> from dipy.viz import actor, window
        >>> renderer = window.Renderer()
        >>> renderer.add(actor.axes())
        >>> showm = window.ShowManager(renderer)
        >>> # showm.initialize()
        >>> # showm.render()
        >>> # showm.start()
        """

        # Every argument is kept as-is on the instance.
        self.ren, self.title, self.size = ren, title, size
        self.png_magnify = png_magnify
        self.reset_camera = reset_camera
        self.order_transparent = order_transparent
        self.interactor_style = interactor_style

        if reset_camera:
            ren.ResetCamera()

        # Window showing the scene.
        self.window = render_window = vtk.vtkRenderWindow()
        render_window.AddRenderer(ren)

        # Only the default title gets the version appended.
        window_name = (title + ' ' + dipy_version) if title == 'DIPY' \
            else title
        render_window.SetWindowName(window_name)
        render_window.SetSize(size[0], size[1])

        if order_transparent:
            # Alpha bits are off by default (0, false).
            render_window.SetAlphaBitPlanes(True)

            # Do not pick a framebuffer with a multisample buffer
            # (default is 8).
            render_window.SetMultiSamples(0)

            # Depth peeling, if supported (off by default).
            ren.UseDepthPeelingOn()

            # Depth peeling parameters: maximum number of rendering passes
            # (default is 4) ...
            ren.SetMaximumNumberOfPeels(4)

            # ... and occlusion ratio (initial value is 0.0, exact image).
            ren.SetOcclusionRatio(0.0)

        # Resolve the named styles; anything else is used as given.
        if interactor_style == 'image':
            chosen_style = vtk.vtkInteractorStyleImage()
        elif interactor_style == 'trackball':
            chosen_style = vtk.vtkInteractorStyleTrackballCamera()
        elif interactor_style == 'custom':
            chosen_style = CustomInteractorStyle()
        else:
            chosen_style = interactor_style
        self.style = chosen_style

        self.iren = interactor = vtk.vtkRenderWindowInteractor()
        chosen_style.SetCurrentRenderer(ren)
        # Hack: below, we explicitly call the Python version of SetInteractor.
        chosen_style.SetInteractor(interactor)
        interactor.SetInteractorStyle(chosen_style)
        interactor.SetRenderWindow(render_window)
Exemplo n.º 7
0
class ShowManager(object):
    """ This class is the interface between the renderer, the window and the
    interactor.
    """
    def __init__(self,
                 ren,
                 title='DIPY',
                 size=(300, 300),
                 png_magnify=1,
                 reset_camera=True,
                 order_transparent=False,
                 interactor_style='custom'):
        """ Manages the visualization pipeline

        Parameters
        ----------
        ren : Renderer() or vtkRenderer()
            The scene that holds all the actors.
        title : string
            A string for the window title bar.
        size : (int, int)
            ``(width, height)`` of the window. Default is (300, 300).
        png_magnify : int
            Number of times to magnify the screenshot. This can be used to save
            high resolution screenshots when pressing 's' inside the window.
        reset_camera : bool
            Default is True. You can change this option to False if you want to
            keep the camera as set before calling this function.
        order_transparent : bool
            True is useful when you want to order transparent
            actors according to their relative position to the camera. The
            default option which is False will order the actors according to
            the order of their addition to the Renderer().
        interactor_style : str or vtkInteractorStyle
            If str then if 'trackball' then vtkInteractorStyleTrackballCamera()
            is used, if 'image' then vtkInteractorStyleImage() is used (no
            rotation) or if 'custom' then CustomInteractorStyle is used.
            Otherwise you can input your own interactor style.

        Attributes
        ----------
        ren : vtkRenderer()
        iren : vtkRenderWindowInteractor()
        style : vtkInteractorStyle()
        window : vtkRenderWindow()

        Methods
        -------
        initialize()
        render()
        start()
        add_window_callback()

        Notes
        -----
        Default interaction keys for

        * 3d navigation are with left, middle and right mouse dragging
        * resetting the camera press 'r'
        * saving a screenshot press 's'
        * for quitting press 'q'

        Examples
        --------
        >>> from dipy.viz import actor, window
        >>> renderer = window.Renderer()
        >>> renderer.add(actor.axes())
        >>> showm = window.ShowManager(renderer)
        >>> # showm.initialize()
        >>> # showm.render()
        >>> # showm.start()
        """

        # The arguments are stored so that start() can rebuild the manager.
        self.ren = ren
        self.title = title
        self.size = size
        self.png_magnify = png_magnify
        self.reset_camera = reset_camera
        self.order_transparent = order_transparent
        self.interactor_style = interactor_style

        if self.reset_camera:
            self.ren.ResetCamera()

        self.window = vtk.vtkRenderWindow()
        self.window.AddRenderer(ren)

        # Only the default title gets the DIPY version appended.
        if self.title == 'DIPY':
            self.window.SetWindowName(title + ' ' + dipy_version)
        else:
            self.window.SetWindowName(title)
        self.window.SetSize(size[0], size[1])

        if self.order_transparent:

            # Use a render window with alpha bits
            # as default is 0 (false))
            self.window.SetAlphaBitPlanes(True)

            # Force to not pick a framebuffer with a multisample buffer
            # (default is 8)
            self.window.SetMultiSamples(0)

            # Choose to use depth peeling (if supported)
            # (default is 0 (false)):
            self.ren.UseDepthPeelingOn()

            # Set depth peeling parameters
            # Set the maximum number of rendering passes (default is 4)
            self.ren.SetMaximumNumberOfPeels(4)

            # Set the occlusion ratio (initial value is 0.0, exact image):
            self.ren.SetOcclusionRatio(0.0)

        # Named styles are instantiated here; any other value is used as is.
        if self.interactor_style == 'image':
            self.style = vtk.vtkInteractorStyleImage()
        elif self.interactor_style == 'trackball':
            self.style = vtk.vtkInteractorStyleTrackballCamera()
        elif self.interactor_style == 'custom':
            self.style = CustomInteractorStyle()
        else:
            self.style = interactor_style

        self.iren = vtk.vtkRenderWindowInteractor()
        self.style.SetCurrentRenderer(self.ren)
        # Hack: below, we explicitly call the Python version of SetInteractor.
        self.style.SetInteractor(self.iren)
        self.iren.SetInteractorStyle(self.style)
        self.iren.SetRenderWindow(self.window)

    def initialize(self):
        """ Initialize interaction
        """
        self.iren.Initialize()

    def render(self):
        """ Renders only once
        """
        self.window.Render()

    def start(self):
        """ Starts interaction

        Once the interaction is over, the renderer is detached from the
        window and both the interactor and the window are deleted from this
        object. Calling ``start`` again therefore first rebuilds them from
        the stored constructor arguments.
        """
        try:
            self.iren.Start()
        except AttributeError:
            # A previous start() removed `iren` and `window` (see below):
            # rebuild the whole pipeline before starting again.
            self.__init__(self.ren,
                          self.title,
                          size=self.size,
                          png_magnify=self.png_magnify,
                          reset_camera=self.reset_camera,
                          order_transparent=self.order_transparent,
                          interactor_style=self.interactor_style)
            self.initialize()
            self.render()
            self.iren.Start()

        self.window.RemoveRenderer(self.ren)
        self.ren.SetRenderWindow(None)
        del self.iren
        del self.window

    def record_events(self):
        """ Records events during the interaction.

        The recording is represented as a list of VTK events that happened
        during the interaction. The recorded events are then returned.

        Returns
        -------
        events : str
            Recorded events (one per line).

        Notes
        -----
        Since VTK only allows recording events to a file, we use a
        temporary file from which we then read the events.
        """
        with InTemporaryDirectory():
            filename = "recorded_events.log"
            recorder = vtk.vtkInteractorEventRecorder()
            recorder.SetInteractor(self.iren)
            recorder.SetFileName(filename)

            def _stop_recording_and_close(obj, evt):
                recorder.Stop()
                self.iren.TerminateApp()

            self.iren.AddObserver("ExitEvent", _stop_recording_and_close)

            recorder.EnabledOn()
            recorder.Record()

            self.initialize()
            self.render()
            self.iren.Start()

            # Retrieve the recorded events. The file is closed explicitly so
            # that it is no longer open when the temporary directory is
            # cleaned up.
            with open(filename) as f:
                events = f.read()

        return events

    def record_events_to_file(self, filename="record.log"):
        """ Records events during the interaction.

        The recording is represented as a list of VTK events
        that happened during the interaction. The recording is
        going to be saved into `filename`.

        Parameters
        ----------
        filename : str
            Name of the file that will contain the recording (.log|.log.gz).
        """
        events = self.record_events()

        # Compress file if needed. The files are closed explicitly so the
        # recording is completely written when this method returns.
        if filename.endswith(".gz"):
            with gzip.open(filename, 'wb') as f:
                f.write(asbytes(events))
        else:
            with open(filename, 'w') as f:
                f.write(events)

    def play_events(self, events):
        """ Plays recorded events of a past interaction.

        The VTK events that happened during the recorded interaction will be
        played back.

        Parameters
        ----------
        events : str
            Recorded events (one per line).
        """
        recorder = vtk.vtkInteractorEventRecorder()
        recorder.SetInteractor(self.iren)

        recorder.SetInputString(events)
        recorder.ReadFromInputStringOn()

        self.initialize()
        self.render()
        recorder.Play()

    def play_events_from_file(self, filename):
        """ Plays recorded events of a past interaction.

        The VTK events that happened during the recorded interaction will be
        played back from `filename`.

        Parameters
        ----------
        filename : str
            Name of the file containing the recorded events (.log|.log.gz).
        """
        # Uncompress file if needed.
        if filename.endswith(".gz"):
            with gzip.open(filename, 'rb') as f:
                events = f.read()
            # A compressed file yields bytes where `str` is a text type;
            # convert so both branches hand the same type to play_events().
            # latin-1 maps every byte to a character, so this cannot fail.
            if not isinstance(events, str):
                events = events.decode('latin1')
        else:
            with open(filename) as f:
                events = f.read()

        self.play_events(events)

    def add_window_callback(self, win_callback):
        """ Add window callbacks

        Parameters
        ----------
        win_callback : callable
            Observer registered on the window for ``ModifiedEvent``.
        """
        self.window.AddObserver(vtk.vtkCommand.ModifiedEvent, win_callback)
        self.window.Render()
Exemplo n.º 8
0
    def initialize_scene(self):
        """Create the renderer, interactor and GUI elements of the scene.

        Builds the renderer, the interactor style and the show manager, then
        adds to the scene: the (hidden) clustering panel, the "Reset/Home"
        button, the "Centroid/Streamlines" toggle button, the root bundle,
        the keyboard shortcuts and, when an anatomy was provided, its three
        slicers together with the anatomy panel.
        """
        self.ren = window.Renderer()
        self.iren = CustomInteractorStyle()
        self.show_m = window.ShowManager(self.ren,
                                         size=self.screen_size,
                                         interactor_style=self.iren)

        # Add clustering panel to the scene.
        self.clustering_panel = self._make_clustering_panel()
        self.clustering_panel.set_visibility(False)
        self.ren.add(self.clustering_panel)

        # Add "Reset/Home" button
        def reset_button_callback(iren, obj, button):
            # iren: CustomInteractorStyle
            # obj: vtkActor picked
            # button: Button2D
            print("Merging remaining bundles...")

            streamlines = nib.streamlines.ArraySequence()
            # remove_bundle() deletes entries from self.bundles, so iterate
            # over a snapshot: changing a dict while iterating over its
            # items raises a RuntimeError.
            for k, bundle in list(self.bundles.items()):
                streamlines.extend(bundle.streamlines)
                self.remove_bundle(k)

            if len(streamlines) == 0:
                print("No streamlines left to merge.")
                iren.force_render()
                iren.event.abort()  # Stop propagating the event.
                return

            # Add new root bundle to the scene.
            self.add_bundle(self.root_bundle, Bundle(streamlines))
            self._add_bundle_right_click_callback(
                self.bundles[self.root_bundle], self.root_bundle)
            self.select(None)

            print("{} streamlines merged.".format(len(streamlines)))
            button.color = (1, 1, 1)  # Restore color.
            iren.force_render()
            iren.event.abort()  # Stop propagating the event.

        reset_button = gui_2d.Button2D(
            icon_fnames={'reset': read_viz_icons(fname='home3_neg.png')})
        reset_button.color = (1, 1, 1)
        reset_button.add_callback("LeftButtonPressEvent",
                                  animate_button_callback)
        reset_button.add_callback("LeftButtonReleaseEvent",
                                  reset_button_callback)
        reset_button.set_center(
            (self.screen_size[0] - 20, self.screen_size[1] - 60))
        self.ren.add(reset_button)

        # Add toggle "Centroid/Streamlines" button
        def centroids_toggle_button_callback(iren, obj, button):
            # The icon currently displayed tells which representation is
            # shown; switch every bundle to the other one.
            if button.current_icon_name == "streamlines":
                button.next_icon()
                for bundle in self.bundles.values():
                    bundle.show_centroids()
                    bundle.hide_streamlines()

            elif button.current_icon_name == "centroids":
                button.next_icon()
                for bundle in self.bundles.values():
                    bundle.hide_centroids()
                    bundle.show_streamlines()

            iren.force_render()
            iren.event.abort()  # Stop propagating the event.

        centroids_toggle_button = gui_2d.Button2D(
            icon_fnames={
                'streamlines': read_viz_icons(fname='database_neg.png'),
                'centroids': read_viz_icons(fname='centroid_neg.png')
            })
        centroids_toggle_button.color = (1, 1, 1)
        # centroids_toggle_button.add_callback("LeftButtonPressEvent", animate_button_callback)
        centroids_toggle_button.add_callback("LeftButtonReleaseEvent",
                                             centroids_toggle_button_callback)
        centroids_toggle_button.set_center((20, self.screen_size[1] - 20))
        self.ren.add(centroids_toggle_button)

        # Add objects to the scene.
        self.ren.add(self.bundles[self.root_bundle])
        self._add_bundle_right_click_callback(self.bundles[self.root_bundle],
                                              self.root_bundle)

        # Add shortcut keys.
        def select_biggest_cluster_onchar_callback(iren, evt_name):
            # Escape: unselect. Tab / Shift+Tab: next / previous bundle.
            # "c": show centroids. "C": show streamlines.
            if self.verbose:
                print("Pressed {} (shift={}), (ctrl={}), (alt={})".format(
                    iren.event.key, iren.event.shift_key, iren.event.ctrl_key,
                    iren.event.alt_key))

            if iren.event.key.lower() == "escape":
                self.select(None)

            elif "tab" in iren.event.key.lower():
                if iren.event.shift_key:
                    self.select_previous()
                else:
                    self.select_next()
            elif iren.event.key == "c":
                for bundle in self.bundles.values():
                    bundle.show_centroids()
                    bundle.hide_streamlines()
                iren.force_render()
            elif iren.event.key == "C":
                for bundle in self.bundles.values():
                    bundle.hide_centroids()
                    bundle.show_streamlines()
                iren.force_render()

            iren.event.abort()  # Stop propagating the event.

        self.iren.AddObserver("CharEvent",
                              select_biggest_cluster_onchar_callback)

        # Add anatomy, if there is one.
        if self.anat is not None:
            anat_data = self.anat.get_data()
            if anat_data.ndim == 4:
                # Take b0 (assuming it is diffusion data)
                anat_data = anat_data[..., 0]

            # One slicer per anatomical plane, all built from the same data.
            self.anat_axial_slicer = actor.slicer(anat_data,
                                                  affine=self.anat.affine)
            self.anat_coronal_slicer = actor.slicer(anat_data,
                                                    affine=self.anat.affine)
            self.anat_sagittal_slicer = actor.slicer(anat_data,
                                                     affine=self.anat.affine)
            self.ren.add(self.anat_axial_slicer, self.anat_coronal_slicer,
                         self.anat_sagittal_slicer)
            self.anatomy_panel = self._make_anatomy_panel(
                self.anat_axial_slicer, self.anat_coronal_slicer,
                self.anat_sagittal_slicer)
            self.ren.add(self.anatomy_panel)
Exemplo n.º 9
0
class StreamlinesVizu(object):
    def __init__(self,
                 tractogram,
                 anat=None,
                 screen_size=(1360, 768),
                 default_clustering_threshold=None,
                 verbose=False):
        """Store the settings and wrap the tractogram in a root bundle.

        Parameters
        ----------
        tractogram : object
            Its ``streamlines`` attribute becomes the content of the root
            bundle.
        anat : object, optional
            Stored as ``self.anat``. Default is None.
        screen_size : (int, int), optional
            Stored as ``self.screen_size``. Default is (1360, 768).
        default_clustering_threshold : optional
            Stored as ``self.default_clustering_threshold``. Default is None.
        verbose : bool, optional
            Stored as ``self.verbose``. Default is False.
        """
        # Settings given by the caller.
        self.anat = anat
        self.screen_size = screen_size
        self.default_clustering_threshold = default_clustering_threshold
        self.verbose = verbose

        # At first there is a single bundle, the root, holding everything.
        root_name = "/"
        self.root_bundle = root_name
        self.keys = [root_name]
        self.bundles = {root_name: Bundle(tractogram.streamlines)}

        # Interaction state.
        self.cpt = None  # Used for iterating through the clusters.
        self.selected_bundle = None
        self.last_threshold = None
        self.last_bundles_visibility_state = "dimmed"
        self.anat_actor = None

    def _set_bundles_visibility(self, state, bundles=None, exclude=()):
        """Show, dim or hide bundles and colour the toggle button to match.

        Parameters
        ----------
        state : str
            One of "visible", "dimmed" or "hidden".
        bundles : list, optional
            Bundles to update. Default is every bundle in ``self.bundles``.
        exclude : sequence, optional
            Bundles that must be left untouched.

        Raises
        ------
        ValueError
            If `state` is not one of the known states.
        """
        if bundles is None:
            bundles = list(self.bundles.values())

        # Button colour, actor visibility and base opacity of each state.
        if state == "visible":
            button_color, visibility, base_opacity = (0, 1, 0), True, 1
        elif state == "dimmed":
            button_color, visibility, base_opacity = (0, 0, 1), True, 0.6
        elif state == "hidden":
            button_color, visibility, base_opacity = (1, 0, 0), False, 1
        else:
            raise ValueError("Unknown visibility state: {}".format(state))

        self.show_dim_hide_button.color = button_color
        # Only "dimmed" and "hidden" are remembered as the last state.
        if state != "visible":
            self.last_bundles_visibility_state = state

        # Make the changes
        for bundle in bundles:
            if bundle in exclude:
                continue

            bundle.actor.SetVisibility(visibility)

            # The opacity is derived from the base value for every bundle
            # independently, so that the result neither accumulates from one
            # bundle to the next nor depends on the iteration order.
            opacity = base_opacity
            if base_opacity < 1:
                # Bundles with more streamlines get more transparent, down
                # to 0.1. An empty bundle is treated like a single
                # streamline, since log10(0) is not finite.
                nb_streamlines = max(1, len(bundle.streamlines))
                opacity = max(
                    0.1, base_opacity - 0.1 * np.log10(nb_streamlines))
            bundle.actor.GetProperty().SetOpacity(opacity)

    def add_bundle(self, bundle_name, bundle):
        """Register `bundle` under `bundle_name` and add it to the renderer.

        Parameters
        ----------
        bundle_name : str
            Key of the bundle; the list of keys is kept sorted.
        bundle : object
            Bundle added to the renderer and to ``self.bundles``.
        """
        names = self.keys
        names.append(bundle_name)
        self.keys = sorted(names)

        self.ren.add(bundle)
        self.bundles[bundle_name] = bundle

    def remove_bundle(self, bundle_name):
        """Forget the bundle named `bundle_name` and remove it from the
        renderer.

        Parameters
        ----------
        bundle_name : str
            Key of the bundle to remove; it must be a known key.
        """
        names = self.keys
        names.remove(bundle_name)
        self.keys = sorted(names)

        discarded = self.bundles[bundle_name]
        self.ren.rm(discarded)
        del self.bundles[bundle_name]

    def select_next(self):
        """Select the next bundle, going from the biggest to the smallest.

        Bundles are ranked by decreasing number of streamlines. Without a
        current selection the first (biggest) one is selected; otherwise the
        one following the current selection, wrapping around at the end.
        Does nothing when there is no bundle at all.
        """
        if not self.keys:
            # No bundle left: there is nothing to cycle through.
            return

        # Sort bundle according to their bundle size. lexsort uses its last
        # key as the primary one, so names only break ties between sizes.
        sizes = [len(self.bundles[k].streamlines) for k in self.keys]
        indices = np.lexsort((self.keys, sizes)).tolist()[::-1]

        if self.selected_bundle is None:
            cpt = 0
        else:
            cpt = indices.index(self.keys.index(self.selected_bundle))
            cpt = (cpt + 1) % len(self.keys)

        self.select(self.keys[indices[cpt]])
        print("({}/{})".format(cpt + 1, len(self.keys)))

    def select_previous(self):
        """Select the previous bundle, going from the smallest to the biggest.

        Bundles are ranked by decreasing number of streamlines. Without a
        current selection the first (biggest) one is selected; otherwise the
        one preceding the current selection, wrapping around at the start.
        Does nothing when there is no bundle at all.
        """
        if not self.keys:
            # No bundle left: there is nothing to cycle through.
            return

        # Sort bundle according to their bundle size. lexsort uses its last
        # key as the primary one, so names only break ties between sizes.
        sizes = [len(self.bundles[k].streamlines) for k in self.keys]
        indices = np.lexsort((self.keys, sizes)).tolist()[::-1]

        if self.selected_bundle is None:
            cpt = 0
        else:
            cpt = indices.index(self.keys.index(self.selected_bundle))
            cpt = (cpt - 1) % len(self.keys)

        self.select(self.keys[indices[cpt]])
        print("({}/{})".format(cpt + 1, len(self.keys)))

    def select(self, bundle_name=None):
        """Select the bundle named `bundle_name`, or clear the selection.

        The previously selected bundle, if it still exists, is reset first.
        With ``None`` the clustering panel is hidden and every bundle is
        made visible. Otherwise the clustering panel is set up and shown for
        the chosen bundle, the other bundles get the last remembered
        visibility state, and a clustering preview of the bundle is made.

        Parameters
        ----------
        bundle_name : str or None
            Key of the bundle to select. Default is None (unselect).
        """
        # Unselect first, if possible.
        previous = self.selected_bundle
        if previous is not None and previous in self.bundles:
            self.bundles[previous].reset()

        if bundle_name is None:
            # Close panels
            self.selected_bundle = None
            self.clustering_panel.set_visibility(False)
            self._set_bundles_visibility("visible")
            self.iren.force_render()
            self.cpt = None  # Used for iterating through the clusters.
            return

        self.selected_bundle = bundle_name
        chosen = self.bundles[bundle_name]
        print("Selecting {} streamlines...".format(len(chosen.streamlines)))

        # The largest threshold offered depends on the selected bundle.
        slider = self.clustering_panel.slider
        slider.max_value = chosen.actor.GetLength() / 2.
        if self.default_clustering_threshold is not None:
            slider.set_value(self.default_clustering_threshold)
        else:
            slider.set_ratio(1)
        slider.update()
        self.clustering_panel.set_visibility(True)

        # The selection is fully visible; the others are dimmed or hidden.
        self._set_bundles_visibility("visible", bundles=[chosen])
        self._set_bundles_visibility(self.last_bundles_visibility_state,
                                     exclude=[chosen])
        chosen.preview(threshold=slider.value)

        self.iren.force_render()

    def _add_bundle_right_click_callback(self, bundle, bundle_name):
        """Make clicking on `bundle` select it.

        A right click on the bundle's actor selects the bundle, and so does
        a left click while the control key is held.

        Parameters
        ----------
        bundle : object
            Bundle whose ``actor`` receives the callbacks.
        bundle_name : str
            Key passed to ``self.select`` when the bundle is clicked.
        """
        def select_this_bundle(iren, obj, *args):
            self.select(bundle_name)
            iren.event.abort()  # Stop propagating the event.

        def select_on_ctrl_left_click(iren, obj, *args):
            # A plain left click is left to the other handlers.
            if iren.event.ctrl_key:
                select_this_bundle(iren, obj, *args)

        handlers = (
            ("RightButtonPressEvent", select_this_bundle),
            ("LeftButtonPressEvent",
             select_on_ctrl_left_click),  # Support for MAC OSX
        )
        for event_name, handler in handlers:
            self.iren.add_callback(bundle.actor, event_name, handler)

    def _make_clustering_panel(self):
        # Panel
        size = (self.screen_size[0], self.screen_size[1] // 10)
        center = tuple(np.array(size) / 2.)  # Lower left corner of the screen.
        panel = gui_2d.Panel2D(center=center,
                               size=size,
                               color=(1, 1, 1),
                               align="left")

        # Nb. clusters label
        label = gui_2d.Text2D("# clusters")
        panel.add_element(label, (0.01, 0.2))

        # "Apply" button
        def apply_button_callback(iren, obj, button):
            # iren: CustomInteractorStyle
            # obj: vtkActor picked
            # button: Button2D
            bundles = self.bundles[
                self.selected_bundle].get_cluster_as_bundles()
            print("Preparing the new {} clusters...".format(len(bundles)))

            # Create new actors, one for each new bundle.
            # Sort bundle in decreasing size.
            for i, bundle in enumerate(bundles):
                bundle_name = "{}{}/".format(self.selected_bundle, i)
                self.add_bundle(bundle_name, bundle)
                self._add_bundle_right_click_callback(bundle, bundle_name)

            # Remove original bundle.
            self.remove_bundle(self.selected_bundle)
            self.select(None)

            # TODO: apply clustering if needed, close panel, add command to history, re-enable bundles context-menu.
            button.color = (0, 1, 0)  # Restore color.
            print("Done.")
            iren.force_render()
            iren.event.abort()  # Stop propagating the event.

        button = gui_2d.Button2D(
            icon_fnames={'apply': read_viz_icons(fname='checkmark_neg.png')})
        button.color = (0, 1, 0)
        button.add_callback("LeftButtonPressEvent", animate_button_callback)
        button.add_callback("LeftButtonReleaseEvent", apply_button_callback)
        panel.add_element(button, (0.98, 0.2))

        # "Hide" button
        def toggle_other_bundles_visibility(iren, *args):
            # iren: CustomInteractorStyle
            # obj: vtkActor picked
            # button: Button2D

            if self.last_bundles_visibility_state == "dimmed":
                self.last_bundles_visibility_state = "hidden"
                self._set_bundles_visibility(
                    "hidden", exclude=[self.bundles[self.selected_bundle]])

            elif self.last_bundles_visibility_state == "hidden":
                self.last_bundles_visibility_state = "dimmed"
                self._set_bundles_visibility(
                    "dimmed", exclude=[self.bundles[self.selected_bundle]])

            iren.force_render()
            iren.event.abort()  # Stop propagating the event.

        self.show_dim_hide_button = gui_2d.Button2D(
            icon_fnames={
                'show_dim_hide': read_viz_icons(fname='infinite_neg.png')
            })
        self.show_dim_hide_button.add_callback(
            "LeftButtonPressEvent", toggle_other_bundles_visibility)
        panel.add_element(self.show_dim_hide_button, (0.02, 0.88))

        # Threshold slider
        def disk_press_callback(iren, obj, slider):
            # iren: CustomInteractorStyle
            # obj: vtkActor picked
            # slider: LineSlider2D
            # Only need to grab the focus.
            iren.event.abort()  # Stop propagating the event.

        def disk_move_callback(iren, obj, slider):
            """Move the threshold slider and refresh the clustering preview.

            iren: CustomInteractorStyle
            obj: vtkActor picked
            slider: LineSlider2D
            """
            # Dragging the slider cancels any in-progress textbox editing.
            textbox_actor = slider.textbox.actor
            if textbox_actor in iren.active_props:
                iren.remove_active_prop(textbox_actor)

            slider.set_position(iren.event.position)

            # Only recompute the preview when the threshold actually changed.
            new_threshold = slider.value
            if self.last_threshold != new_threshold:
                selected = self.bundles[self.selected_bundle]
                cluster_count = selected.preview(new_threshold)
                self.last_threshold = new_threshold
                label.set_message("{} clusters".format(cluster_count))

            iren.force_render()
            iren.event.abort()  # Stop propagating the event.

        # Slider textbox
        def slider_textbox_select_callback(iren, obj, slider):
            """Put the slider's textbox into editing mode.

            iren: CustomInteractorStyle
            obj: vtkActor picked
            slider: LineSlider2D

            The textbox actor is registered as an active prop, its content is
            cleared and it is redrawn with the caret at position 0.
            """
            textbox = slider.textbox
            iren.add_active_prop(textbox.actor)

            # Start from an empty field with the caret at the beginning.
            textbox.set_message("")
            textbox.caret_pos = 0
            textbox.render_text(show_caret=True)

            iren.force_render()

        def slider_textbox_keypress_callback(iren, obj, slider):
            """Process one key typed in the slider's threshold textbox.

            iren: CustomInteractorStyle
            obj: vtkActor picked
            slider: LineSlider2D

            Only digits, the period key, backspace and the two enter keys are
            acted upon; every other key is ignored. Whatever the key, the
            scene is re-rendered and the event is aborted.
            """
            submit_keys = ("return", "kp_enter")
            control_keys = ("backspace",) + submit_keys
            accepted_keys = tuple("0123456789") + ("period",) + control_keys
            max_chars = 4  # Longest text the field accepts.

            key = iren.event.key.lower()
            textbox = slider.textbox

            def restore_slider_text():
                # Put the slider's own formatted value back in the textbox.
                textbox.set_message(slider.format_text())

            if key in accepted_keys:
                typed_length = len(textbox.text)

                if typed_length == 0 and key in submit_keys:
                    # User pressed "enter" on empty field; reset textbox
                    iren.remove_active_prop(textbox.actor)
                    restore_slider_text()

                elif typed_length < max_chars or key in control_keys:
                    # A full field only accepts backspace/enter; anything
                    # else typed at that point is dropped.
                    character = '.' if key == "period" else key

                    if textbox.handle_character(character):
                        # Editing is finished: release the textbox and try to
                        # apply what was typed.
                        iren.remove_active_prop(textbox.actor)

                        try:
                            threshold = float(textbox.text)

                            if threshold > slider.max_value:
                                # Invalid value, reset textbox
                                restore_slider_text()
                            elif self.last_threshold != threshold:
                                slider.set_value(threshold)

                                selected = self.bundles[self.selected_bundle]
                                cluster_count = selected.preview(threshold)
                                self.last_threshold = threshold
                                label.set_message(
                                    "{} clusters".format(cluster_count))

                        except ValueError:
                            # Invalid value, reset textbox
                            restore_slider_text()

            iren.force_render()
            iren.event.abort()  # Stop propagating the event.

        slider = gui_2d.LineSlider2D(length=1000,
                                     text_template="{value:.1f}mm")
        slider.add_callback("LeftButtonPressEvent", disk_move_callback,
                            slider.slider_line)
        slider.add_callback("LeftButtonPressEvent", disk_press_callback,
                            slider.slider_disk)
        slider.add_callback("MouseMoveEvent", disk_move_callback,
                            slider.slider_disk)
        slider.add_callback("MouseMoveEvent", disk_move_callback,
                            slider.slider_line)

        slider.add_callback("KeyPressEvent", slider_textbox_keypress_callback,
                            slider.text)
        slider.add_callback("LeftButtonPressEvent",
                            slider_textbox_select_callback, slider.text)

        panel.add_element(slider, (0.5, 0.5))
        panel.slider = slider

        # Add shortcut keys.
        def toggle_visibility_onchar_callback(iren, evt_name):
            """Keyboard shortcut: space toggles the other bundles' visibility.

            Does nothing while no bundle is selected or for any other key.
            """
            has_selection = self.selected_bundle is not None
            if has_selection and iren.event.key.lower() == "space":
                toggle_other_bundles_visibility(iren)

        self.iren.AddObserver("CharEvent", toggle_visibility_onchar_callback)

        return panel

    def _make_anatomy_panel(self, axial_slicer, sagittal_slicer,
                            coronal_slicer):
        """Build the panel holding one slice-selection slider per slicer.

        Parameters
        ----------
        axial_slicer : slicer actor driven through ``display(x=...)``; its
            slider range comes from ``shape[0]``.
        sagittal_slicer : slicer actor driven through ``display(z=...)``; its
            slider range comes from ``shape[2]``.
        coronal_slicer : slicer actor driven through ``display(y=...)``; its
            slider range comes from ``shape[1]``.

        Note the positional order of the parameters: axial, sagittal, coronal.

        Returns
        -------
        The ``gui_2d.Panel2D`` containing the three sliders. Each slicer is
        also set to display the slice its slider initially points at.
        """
        # Panel
        size = (self.screen_size[0] // 8, self.screen_size[1] // 6)
        # Lower left corner of the screen.
        center = (size[0] / 2., np.ceil(self.screen_size[1] / 10. + size[1]))
        panel = gui_2d.Panel2D(center=center,
                               size=size,
                               color=(0., 0., 0.),
                               align="left")

        length = size[0] - 10

        def text_template(obj):
            # Sliders show the integer slice index.
            return "{value:}".format(value=int(obj.value))

        # Common to all sliders.
        def disk_press_callback(iren, obj, slider):
            # iren: CustomInteractorStyle
            # obj: vtkActor picked
            # slider: LineSlider2D
            # Only need to grab the focus.
            iren.event.abort()  # Stop propagating the event.

        def make_disk_move_callback(slicer, axis):
            # Build the move callback for one slicer; `axis` is the keyword
            # ("x", "y" or "z") handed to `slicer.display`. Using a factory
            # binds `slicer`/`axis` per slider instead of sharing loop
            # variables between the three closures.
            def disk_move_callback(iren, obj, slider):
                slider.set_position(iren.event.position)
                # Move slices accordingly.
                slicer.display(**{axis: int(slider.value)})
                iren.force_render()
                iren.event.abort()  # Stop propagating the event.

            return disk_move_callback

        # One row per slider:
        # (slicer, display keyword, index into slicer.shape, place in panel).
        layout = (
            (axial_slicer, "x", 0, (0.5, 0.15)),
            (coronal_slicer, "y", 1, (0.5, 0.5)),
            (sagittal_slicer, "z", 2, (0.5, 0.85)),
        )

        # Create all sliders that will be responsible for moving the slices
        # of the anatomy.
        sliders = [
            gui_2d.LineSlider2D(length=length, text_template=text_template)
            for _ in layout
        ]

        # Add callbacks to the sliders.
        for slider, (slicer, axis, dim, _) in zip(sliders, layout):
            move_callback = make_disk_move_callback(slicer, axis)
            slider.add_callback("LeftButtonPressEvent", move_callback,
                                slider.slider_line)
            slider.add_callback("LeftButtonPressEvent", disk_press_callback,
                                slider.slider_disk)
            slider.add_callback("MouseMoveEvent", move_callback,
                                slider.slider_disk)
            slider.add_callback("MouseMoveEvent", move_callback,
                                slider.slider_line)
            # The last valid slice index is shape - 1; using the shape itself
            # let the slider request a slice one past the end of the volume.
            # Assumes slicer.shape holds the voxel-grid dimensions and that
            # display() takes 0-based indices -- TODO confirm.
            slider.max_value = slicer.shape[dim] - 1
            slider.set_ratio(0.5)
            slider.update()

        # Add the sliders to the panel.
        for slider, (_, _, _, position) in zip(sliders, layout):
            panel.add_element(slider, position)

        # Initialize slices of the anatomy.
        for slider, (slicer, axis, _, _) in zip(sliders, layout):
            slicer.display(**{axis: int(slider.value)})

        return panel

    def initialize_scene(self):
        """Build the renderer, interactor style, show manager and widgets.

        Creates ``self.ren``, ``self.iren`` and ``self.show_m``, then adds to
        the renderer: the (initially hidden) clustering panel, the
        "Reset/Home" button, the "Centroid/Streamlines" toggle button and the
        root bundle. Keyboard shortcuts are registered on ``self.iren`` and,
        when ``self.anat`` is not None, three anatomy slicers plus their
        slider panel are added as well.
        """
        self.ren = window.Renderer()
        self.iren = CustomInteractorStyle()
        self.show_m = window.ShowManager(self.ren,
                                         size=self.screen_size,
                                         interactor_style=self.iren)

        # Add clustering panel to the scene.
        self.clustering_panel = self._make_clustering_panel()
        self.clustering_panel.set_visibility(False)
        self.ren.add(self.clustering_panel)

        # Shared by the toggle button and the "c"/"C" shortcuts.
        def show_centroids_only():
            for bundle in self.bundles.values():
                bundle.show_centroids()
                bundle.hide_streamlines()

        def show_streamlines_only():
            for bundle in self.bundles.values():
                bundle.hide_centroids()
                bundle.show_streamlines()

        # Add "Reset/Home" button
        def reset_button_callback(iren, obj, button):
            # iren: CustomInteractorStyle
            # obj: vtkActor picked
            # button: Button2D
            print("Merging remaining bundles...")

            streamlines = nib.streamlines.ArraySequence()
            # Iterate over a snapshot: remove_bundle() presumably deletes the
            # entry from self.bundles, and changing a dict's size while
            # iterating over it raises RuntimeError on Python 3.
            for k, bundle in list(self.bundles.items()):
                streamlines.extend(bundle.streamlines)
                self.remove_bundle(k)

            if len(streamlines) == 0:
                print("No streamlines left to merge.")
            else:
                # Add new root bundle to the scene.
                self.add_bundle(self.root_bundle, Bundle(streamlines))
                self._add_bundle_right_click_callback(
                    self.bundles[self.root_bundle], self.root_bundle)
                self.select(None)

                print("{} streamlines merged.".format(len(streamlines)))

            # Restore color on both paths (it is presumably changed by
            # animate_button_callback on press), so the button never stays
            # highlighted when there was nothing to merge.
            button.color = (1, 1, 1)
            iren.force_render()
            iren.event.abort()  # Stop propagating the event.

        reset_button = gui_2d.Button2D(
            icon_fnames={'reset': read_viz_icons(fname='home3_neg.png')})
        reset_button.color = (1, 1, 1)
        reset_button.add_callback("LeftButtonPressEvent",
                                  animate_button_callback)
        reset_button.add_callback("LeftButtonReleaseEvent",
                                  reset_button_callback)
        reset_button.set_center(
            (self.screen_size[0] - 20, self.screen_size[1] - 60))
        self.ren.add(reset_button)

        # Add toggle "Centroid/Streamlines" button
        def centroids_toggle_button_callback(iren, obj, button):
            # The icon currently shown tells which representation is active.
            if button.current_icon_name == "streamlines":
                button.next_icon()
                show_centroids_only()

            elif button.current_icon_name == "centroids":
                button.next_icon()
                show_streamlines_only()

            iren.force_render()
            iren.event.abort()  # Stop propagating the event.

        centroids_toggle_button = gui_2d.Button2D(
            icon_fnames={
                'streamlines': read_viz_icons(fname='database_neg.png'),
                'centroids': read_viz_icons(fname='centroid_neg.png')
            })
        centroids_toggle_button.color = (1, 1, 1)
        centroids_toggle_button.add_callback("LeftButtonReleaseEvent",
                                             centroids_toggle_button_callback)
        centroids_toggle_button.set_center((20, self.screen_size[1] - 20))
        self.ren.add(centroids_toggle_button)

        # Add objects to the scene.
        self.ren.add(self.bundles[self.root_bundle])
        self._add_bundle_right_click_callback(self.bundles[self.root_bundle],
                                              self.root_bundle)

        # Add shortcut keys.
        def select_biggest_cluster_onchar_callback(iren, evt_name):
            # escape: clear the selection; tab / shift+tab: select the next /
            # previous bundle; "c": show centroids; "C": show streamlines.
            if self.verbose:
                print("Pressed {} (shift={}), (ctrl={}), (alt={})".format(
                    iren.event.key, iren.event.shift_key, iren.event.ctrl_key,
                    iren.event.alt_key))

            if iren.event.key.lower() == "escape":
                self.select(None)

            elif "tab" in iren.event.key.lower():
                if iren.event.shift_key:
                    self.select_previous()
                else:
                    self.select_next()
            elif iren.event.key == "c":
                show_centroids_only()
                iren.force_render()
            elif iren.event.key == "C":
                show_streamlines_only()
                iren.force_render()

            iren.event.abort()  # Stop propagating the event.

        self.iren.AddObserver("CharEvent",
                              select_biggest_cluster_onchar_callback)

        # Add anatomy, if there is one.
        if self.anat is not None:
            # NOTE(review): if self.anat is a nibabel image, get_data() is
            # deprecated in recent nibabel releases -- confirm the supported
            # nibabel version before migrating.
            anat_data = self.anat.get_data()
            if anat_data.ndim == 4:
                # Take b0 (assuming it is diffusion data)
                anat_data = anat_data[..., 0]

            self.anat_axial_slicer = actor.slicer(anat_data,
                                                  affine=self.anat.affine)
            self.anat_coronal_slicer = actor.slicer(anat_data,
                                                    affine=self.anat.affine)
            self.anat_sagittal_slicer = actor.slicer(anat_data,
                                                     affine=self.anat.affine)
            self.ren.add(self.anat_axial_slicer, self.anat_coronal_slicer,
                         self.anat_sagittal_slicer)
            # Pass by keyword: _make_anatomy_panel declares its parameters as
            # (axial, sagittal, coronal), so the former positional call
            # (axial, coronal, sagittal) swapped the last two slicers.
            self.anatomy_panel = self._make_anatomy_panel(
                axial_slicer=self.anat_axial_slicer,
                coronal_slicer=self.anat_coronal_slicer,
                sagittal_slicer=self.anat_sagittal_slicer)
            self.ren.add(self.anatomy_panel)

    def run(self):
        """Start the show manager stored in ``self.show_m``.

        ``self.show_m`` is created by ``initialize_scene``, so that method
        has to be called first.
        """
        self.show_m.start()