class AffectsAverageColumn(ObjectColumn):
    """ Read-only table column for a player statistic.

    The cell cannot be edited directly; instead the column's context menu
    offers 'Add' and 'Sub' entries whose action strings call this column's
    ``add`` / ``sub`` methods with the row object.
    """

    # Context menu for the column; the action strings call column.add /
    # column.sub with 'object' (presumably the row's object):
    menu = Menu(Action(name='Add', action='column.add(object)'),
                Action(name='Sub', action='column.sub(object)'))

    # Center the values in the cell (override):
    horizontal_alignment = 'center'

    # Column width (override):
    width = 0.09

    # Don't allow the data to be edited directly:
    editable = False

    # Action methods for the context menu items:

    def add(self, object):
        """ Increment the statistic named by this column (``self.name``)
            on *object* by one.
        """
        setattr(object, self.name, getattr(object, self.name) + 1)

    def sub(self, object):
        """ Decrement the statistic named by this column (``self.name``)
            on *object* by one.
        """
        setattr(object, self.name, getattr(object, self.name) - 1)
Beispiel #2
0
class TmpClass(Handler):
    """ Handler for a manual splitter test.

    The view shows two integer traits in a horizontal split ('h_split') and
    provides buttons that move the splitter or save/restore the UI
    preferences.
    """

    aa = Int(10)
    bb = Int(100)

    def init(self, ui_info):
        """ Save the initial UI preferences once the UI is set up.

        Returns True so the UI is displayed, consistent with the other
        TmpClass handler and the Handler.init contract (a False result
        aborts the UI).
        """
        super(TmpClass, self).init(ui_info)
        self.save_prefs(ui_info)
        return True

    def reset_prefs(self, ui_info):
        """ Reset the split to be equally wide.
        """
        control = getattr(ui_info, 'h_split').control
        width = control.width()
        # QSplitter.moveSplitter takes an int position; '/' would pass a
        # float under Python 3.
        control.moveSplitter(width // 2, 1)

    def restore_prefs(self, ui_info):
        """ Apply the last saved ui preferences.
        """
        ui_info.ui.set_prefs(self._prefs)

    def save_prefs(self, ui_info):
        """ Save the current ui preferences.
        """
        self._prefs = ui_info.ui.get_prefs()

    def collapse_right(self, ui_info):
        """ Collapse the split to the right.
        """
        control = getattr(ui_info, 'h_split').control
        width = control.width()
        control.moveSplitter(width, 1)

    def collapse_left(self, ui_info):
        """ Collapse the split to the left.
        """
        control = getattr(ui_info, 'h_split').control
        control.moveSplitter(0, 1)

    view = View(
        HSplit(
            Group(
                Item('aa', resizable=True, width=50), show_border=True,
            ),
            Group(
                Item('bb', width=100), show_border=True,
            ),
            id='h_split',
        ),
        resizable=True,
        # add actions to test manually.
        buttons=[Action(name='collapse left', action='collapse_left'),
                 Action(name='collapse right', action='collapse_right'),
                 Action(name='reset_layout', action='reset_prefs'),
                 Action(name='restore layout', action='restore_prefs'),
                 Action(name='save layout', action='save_prefs')],
        height=300,
        id='test_view_for_splitter_pref_restore',
    )
Beispiel #3
0
 def get_menu(self, object, trait, row, column):
     """ Return the context menu for the cell at (row, column).

     The 'name' and 'average' columns use the base class menu. Every other
     column gets an Add/Sub menu whose action strings call
     editor.adapter.add / editor.adapter.sub with (item, column).
     """
     if self.column_map[column] in ['name', 'average']:
         return super().get_menu(object, trait, row, column)
     return Menu(
         Action(name='Add', action='editor.adapter.add(item, column)'),
         Action(name='Sub', action='editor.adapter.sub(item, column)'),
     )
 def _non_standard_menu_actions(self):
     """ Return the run/plot/duplicate actions, fenced by separators.
     """
     requests = (
         (RUN_ACTION, "object.fire_cadet_request"),
         (PLOT_ACTION, "object.fire_plot_request"),
         (DUPLIC_ACTION, "object.fire_duplicate_request"),
     )
     menu_items = [Separator()]
     menu_items.extend(Action(name=label, action=expr)
                       for label, expr in requests)
     menu_items.append(Separator())
     return menu_items
 def _non_standard_menu_actions(self):
     """ Return the 'all simulations' actions for the pop-up menu.
     """
     entries = (
         ("Run All Simulations...", "object.fire_cadet_request"),
         ("Plot All Simulations...", "object.fire_plot_request"),
     )
     return [Action(name=label, action=expr) for label, expr in entries]
class TmpClass(Handler):
    """Handler for a manual splitter test.

    Two integer traits are shown side by side in a horizontal split
    ('h_split'); the buttons move the splitter handle or save/restore the
    UI preferences.
    """

    aa = Int(10)
    bb = Int(100)

    def init(self, ui_info):
        """Snapshot the preferences as soon as the UI exists; show the UI."""
        super().init(ui_info)
        self.save_prefs(ui_info)
        return True

    def reset_prefs(self, ui_info):
        """Move the splitter handle back to the middle."""
        splitter = ui_info.h_split.control
        splitter.moveSplitter(splitter.width() // 2, 1)

    def restore_prefs(self, ui_info):
        """Re-apply the preferences captured by the last save."""
        ui_info.ui.set_prefs(self._prefs)

    def save_prefs(self, ui_info):
        """Capture the current UI preferences for a later restore."""
        self._prefs = ui_info.ui.get_prefs()

    def collapse_right(self, ui_info):
        """Push the splitter handle all the way to the right."""
        splitter = ui_info.h_split.control
        splitter.moveSplitter(splitter.width(), 1)

    def collapse_left(self, ui_info):
        """Push the splitter handle all the way to the left."""
        ui_info.h_split.control.moveSplitter(0, 1)

    view = View(
        HSplit(
            Group(Item("aa", resizable=True, width=50), show_border=True),
            Group(Item("bb", width=100), show_border=True),
            id="h_split",
        ),
        resizable=True,
        # Buttons for exercising the splitter by hand.
        buttons=[
            Action(name=label, action=handler_method)
            for label, handler_method in (
                ("collapse left", "collapse_left"),
                ("collapse right", "collapse_right"),
                ("reset_layout", "reset_prefs"),
                ("restore layout", "restore_prefs"),
                ("save layout", "save_prefs"),
            )
        ],
        height=300,
        id="test_view_for_splitter_pref_restore",
    )
 def get_menu(self):
     """ Return this node's context menu.

     The action strings are evaluated in the context set up by the
     TreeEditor's on_context_menu code; see there for the available names.
     """
     from_experiment = Action(
         name="New Simulation From Experiment...",
         action="object.request_new_simulations_from_experiments",
     )
     # FIXME: the action statement can't refer to the datasource yet, so a
     # "Create New Simulation..." entry calling
     # object.request_new_simulation_from_datasource is not offered.
     return Menu(from_experiment)
 def traits_view(self):
     """ Build the 'Stage Visualizer' view.

     A 550x550 canvas and the results table sit side by side; the single
     'Save' button is handled by a StageVisualizerHandler.
     """
     canvas = UItem('canvas',
                    editor=ComponentEditor(width=550, height=550))
     results = UItem('results',
                     editor=TabularEditor(adapter=ResultsAdapter()))
     return View(
         HGroup(canvas, results),
         handler=StageVisualizerHandler(),
         buttons=[Action(action='save', name='Save')],
         title='Stage Visualizer',
         resizable=True,
     )
Beispiel #9
0
    def config_model_view(self):
        """Return the model configuration View.

        The view has a 'Create Model' menu (new / copy model actions), the
        custom-style model list in a bordered group, a placeholder item and
        no buttons.
        """
        new_model = Action(name='New Model',
                           action='new_model',
                           tooltip='Create a new model from scratch')
        copy_model = Action(
            name='Copy Model',
            action='copy_model',
            tooltip='Create a new model by copying an existing one')
        menubar = MenuBar(Menu(new_model, copy_model, name='Create Model'))

        model_group = Group(Item(name='model_list', style='custom'),
                            show_border=True)
        placeholder = Item(label="Lots of stuff should go here")
        return View(model_group,
                    placeholder,
                    menubar=menubar,
                    buttons=NoButtons,
                    title='BiKiPy Modeler')
Beispiel #10
0
class CableArgsTable(HasStrictTraits):
    """Live window for editing a list of Cable parameter sets.

    Window title: 'Cable parameter settings'. The buttons are 'Add'
    (action 'addItem'), 'Delete' (action 'delItem'), and OK / Cancel
    intended to be labelled 'Confirm' / 'Cancel'.
    """

    cables = List(Cable)

    view = View(
        VGroup(
            Item('cables', show_label=False, editor=cableArgs_table),
            show_border=True,
        ),
        title='电缆参数设置',
        width=.4,
        height=.5,
        resizable=True,
        # NOTE(review): `label` may not be an Action trait, in which case
        # OK/Cancel display their `name` instead -- verify.
        buttons=[
            Action(name='添加', action='addItem'),
            Action(name='删除', action='delItem'),
            Action(name='OK', label='确定'),
            Action(name='Cancel', label='取消'),
        ],
        kind='live',
    )
Beispiel #11
0
class AddLabelsWindow(Handler):
    """ Dialog for adding annotation and label overlays to ``model``.

    'Add annotation' forwards the annotation and its display options to
    ``model.add_annotation``. 'Add label file' forwards the label file and
    its options to ``model.add_label``. The 'Remove all labels' button calls
    ``model.remove_labels``.
    """

    # Clumsily hold a reference to the model object that receives the
    # annotations and labels.
    model = Any

    annotation = Str
    label = File

    add_annot_button = Button('Add annotation')
    add_label_button = Button('Add label file')

    # Display options used when adding an annotation.
    annot_borders = Bool
    annot_opacity = Range(0., 1., 1.)
    annot_hemi = Enum('both', 'lh', 'rh')
    # Display options used when adding a label.
    label_borders = Bool
    label_opacity = Range(0., 1., 1.)
    label_color = Color('blue')

    remove_labels_action = Action(name='Remove all labels', action='do_remove')

    def _add_annot_button_fired(self):
        self.model.add_annotation(self.annotation,
                                  border=self.annot_borders,
                                  hemi=self.annot_hemi,
                                  opacity=self.annot_opacity)

    def _add_label_button_fired(self):
        self.model.add_label(self.label,
                             border=self.label_borders,
                             opacity=self.label_opacity,
                             color=self.label_color)

    def do_remove(self, info):
        """ Handler for the 'Remove all labels' button. """
        self.model.remove_labels()

    traits_view = View(
        HSplit(
            VGroup(
                Item('annotation'),
                Item('annot_borders', label='show border only'),
                Item('annot_opacity', label='opacity'),
                Item('annot_hemi', label='hemi'),
                Item('add_annot_button', show_label=False),
            ),
            VGroup(
                Item('label'),
                # Same wording as the annotation column's border option.
                Item('label_borders', label='show border only'),
                Item('label_opacity', label='opacity'),
                Item('label_color', label='color'),
                Item('add_label_button', show_label=False),
            ),
        ),
        buttons=[remove_labels_action, OKButton],
        kind='livemodal',
        title='Dial 1-800-COLLECT and save a buck or two',
    )
Beispiel #12
0
class Cable(HasTraits):
    '''
    Cable parameter set.

    Main parameters: cable voltage U0/U (kV), cross-section (mm2),
    resistance (Ohm/km), reactance (Ohm/km), zero-sequence reactance
    (Ohm/km), conductance to ground (S/km), susceptance (S/km),
    capacitance (F/km), inductance (H/km), zero-sequence inductance (H/km;
    not declared as a trait here), supply frequency (Hz), number of cores.

    Setting ``l`` recomputes ``x`` = 2*pi*f*l, and setting ``c`` recomputes
    ``b`` = 2*pi*f*c. Changing ``f`` alone recomputes neither.

    NOTE(review): with the defaults, 2*pi*50*0.1368e-6 ~= 4.30e-5, but ``b``
    defaults to 0.4298e-6 (100x smaller) -- confirm which is intended.
    '''
    # Phase voltage U0 and rated voltage U (kV); the trailing call
    # presumably selects the default value from the allowed list.
    u0 = Enum(0.45, 0.6, 1.8, 3.6, 6, 8.7, 12, 18, 21, 26, 38, 50, 64, 127,
              190, 290)(26)
    u = Enum(0.75, 1, 3, 6, 10, 15, 20, 30, 35, 66, 110, 220, 330, 500)(35)

    # Non-negative electrical parameters.
    s = Range(low=0.0, value=95)
    r = Range(low=0.0, value=0.2465)
    x = Range(low=0.0, value=0.197)
    x0 = Range(low=0.0, value=0.0939)
    g = Range(low=0.0)
    b = Range(low=0.0, value=0.4298e-6)
    c = Range(low=0.0, value=0.1368e-6)
    l = Range(low=0.0, value=0.6279e-3)
    f = Range(low=0.0, value=50.0)

    # Number of cores.
    number = Enum(1, 2, 3, 4, 5)(1)

    def _l_changed(self, old, new):
        # Reactance follows inductance: x = omega * l, omega = 2*pi*f.
        omega = 2.0 * np.pi * self.f
        self.x = omega * new

    def _c_changed(self, old, new):
        # Susceptance follows capacitance: b = omega * c.
        omega = 2.0 * np.pi * self.f
        self.b = omega * new

    view = View(
        HGroup(
            VGroup(
                Item('s', label=u'电缆截面(mm2)'),
                Item('r', label=u'电阻(Ω)'),
                Item('x', label=u'电抗(Ω)', format_str='%0.4e'),
                Item('x0', label=u'零序电抗(Ω)'),
                Item('g', label=u'对地电导(s)'),
                Item('b', label=u'对地电纳(s)', format_str='%0.4e'),
                show_border=True,
            ),
            VGroup(
                Item('u0', label=u'相电压(kV)'),
                Item('u', label=u'额定电压(kV)'),
                Item('number', label=u'电缆芯数'),
                Item('f', label=u'电源频率(Hz)'),
                Item('l', label=u'电感(H/km)'),
                Item('c', label=u'对地电容(F)'),
                show_border=True,
            ),
            padding=10,
        ),
        title=u'电缆参数设置',
        resizable=True,
        buttons=[Action(name='确定', action='ok'),
                 Action(name='取消', action='cancel')],
    )
Beispiel #13
0
    def test_on_perform_action(self):
        # Regression test for issue #741: an action with an on_perform
        # callable used to have that callable invoked twice.
        squares = [ListItem(value=str(i**2)) for i in range(10)]
        object_list = ObjectList(values=squares)
        callback = Mock()
        action = Action(on_perform=callback)

        tester = UITester()
        with tester.create_ui(object_list, dict(view=simple_view)) as ui:
            values_editor = tester.find_by_name(ui, "values")._target
            values_editor.set_menu_context(None, None, None)
            values_editor.perform(action)
        callback.assert_called_once()
Beispiel #14
0
class Demo(HasTraits):
    """Window showing one plot component with an 'Export' button."""

    plot = Instance(Component)

    traits_view = View(
        Group(
            Item('plot', editor=ComponentEditor(size=size), show_label=False),
            orientation="vertical",
        ),
        handler=DemoHandler,
        buttons=[Action(name='Export', action='do_export')],
        resizable=True,
        title='hello',
    )

    def _plot_default(self):
        # Default value for 'plot'.
        return _create_plot_component()
Beispiel #15
0
    def test_on_perform_action(self):
        # Regression test for issue #741: an action with an on_perform
        # callable used to have that callable invoked twice.
        items = [ListItem(value=str(i**2)) for i in range(10)]
        object_list = ObjectList(values=items)
        on_perform = Mock()
        action = Action(on_perform=on_perform)

        with reraise_exceptions():
            with create_ui(object_list, dict(view=simple_view)) as ui:
                editors = ui.get_editors("values")
                editor = editors[0]
                process_cascade_events()
                editor.set_menu_context(None, None, None)
                editor.perform(action)
        on_perform.assert_called_once()
class DatasetNode(TreeNode):
    """Tree node for Dataset objects, labelled by their 'title' attribute."""

    # Classes of objects this node type represents.
    node_for = [Dataset]

    # Do not expand the node's children automatically.
    auto_open = False

    # Name of the Dataset attribute used as the node label.
    label = 'title'

    # Context menu: 'Test...' calls handler.get_measurement(editor, object).
    menu = Menu(
        Action(
            name="Test...",
            action="handler.get_measurement(editor,object)",
        )
    )

    # A custom icon directory is not configured:
    # icon_path = Str('../images/')

    # Empty view for the node.
    view = View()
Beispiel #17
0
class InaivuModel(Handler):

    # Pysurfer brain created by build_surface.
    brain = Any  # Instance(surfer.viz.Brain)
    # dict: electrode location tuple -> channel name (built by plot_ieeg).
    ieeg_loc = Any
    # Mayavi glyph drawing the electrodes.
    ieeg_glyph = Any  # Instance(mlab.Glyph3D)
    # Electrode channel names, in the order their locations were plotted.
    ch_names = Any  # List(Str) ?

    # Mayavi scene hosting the brain surface and the electrode glyphs.
    scene = Instance(MlabSceneModel, ())

    # Bounds of time_slider (widened as signals are added) and its value.
    _time_low = Float(0)
    _time_high = Float(1)
    time_slider = Float(0)

    shell = Dict

    subjects_dir = Str
    subject = Str('fake_subject')

    invasive_signals = Dict  # Str -> Instance(InvasiveSignal)
    current_invasive_signal = Instance(source_signal.InvasiveSignal)

    noninvasive_signals = Dict  # Str -> Instance(NoninvasiveSignal)
    current_noninvasive_signal = Instance(source_signal.NoninvasiveSignal)
    megsig = Dict
    invasive_labels = Any  # Dict
    invasive_labels_id = Any

    opacity = Float(.35)

    use_smoothing = Bool(False)
    smoothing_steps = Int(0)

    smoothl = Any  # Either(np.ndarray, None)
    smoothr = Any  # Either(np.ndarray, None)

    browser = Any  # Instance(BrowseStc)

    # Script loading: the button execs current_script_file.
    current_script_file = File
    run_script_button = Button('Load')

    # movie window
    make_movie_button = Button('Movie')

    movie_filename = File
    movie_normalization_style = Enum('local', 'global', 'none')

    movie_use_invasive = Bool(True)
    movie_use_noninvasive = Bool(True)
    movie_tmin = Float(0.)
    movie_tmax = Float(1.)
    movie_invasive_tmin = Float(0.)
    movie_invasive_tmax = Float(1.)
    movie_noninvasive_tmin = Float(0.)
    movie_noninvasive_tmax = Float(1.)

    movie_framerate = Float(24)
    movie_dilation = Float(2)
    movie_bitrate = Str('750k')
    movie_interpolation = Enum('quadratic', 'cubic', 'linear', 'slinear',
                               'nearest', 'zero')
    movie_animation_degrees = Float(0.)
    movie_sample_which_first = Enum('invasive', 'noninvasive')

    # 'Make movie' button of make_movie_view; runs do_movie.
    OKMakeMovieAction = Action(name='Make movie', action='do_movie')

    # Main window: the Mayavi scene, a time slider bounded by
    # _time_low/_time_high, the movie button, and script-loading controls.
    traits_view = View(
        Item('scene', editor=SceneEditor(scene_class=MayaviScene),
             show_label=False, height=300, width=300),
        VGroup(
            Item('time_slider',
                 editor=RangeEditor(mode='xslider', low_name='_time_low',
                                    high_name='_time_high', format='%.3f',
                                    is_float=True),
                 label='time'),
        ),
        HGroup(
            Item('make_movie_button', show_label=False),
            Label('Subject'),
            Item('current_script_file', show_label=False),
            Item('run_script_button', show_label=False),
        ),
        title='Multi-Modalities Visualization',
        resizable=True,
    )

    def _run_script_button_fired(self):
        '''
        Execute the Python script chosen in ``current_script_file``.

        The script runs inside this method, so it can use ``self``.
        NOTE(review): this executes arbitrary code from the selected file by
        design; only load trusted scripts.
        '''
        with open(self.current_script_file) as fd:
            # exec() needs source/code, not a file object, on Python 3;
            # compiling with the filename gives meaningful tracebacks.
            exec(compile(fd.read(), self.current_script_file, 'exec'))

    # Movie settings dialog (opened by _make_movie_button_fired); its
    # 'Make movie' button runs do_movie.
    make_movie_view = View(
        Label('Click make movie to specify filename'),
        HGroup(
            VGroup(
                HGroup(
                    Item('movie_use_invasive',
                         label='include invasive signal'), ),
                Item('movie_invasive_tmin',
                     label='invasive tmin',
                     enabled_when="movie_use_invasive"),
                Item('movie_invasive_tmax',
                     label='invasive tmax',
                     enabled_when="movie_use_invasive"),
                Item('movie_normalization_style', label='normalization'),
                Item('movie_sample_which_first',
                     label='samples first',
                     enabled_when="movie_use_noninvasive and "
                     "movie_use_invasive"),
            ),
            VGroup(
                HGroup(
                    Item('movie_use_noninvasive',
                         label='include noninvasive signal'), ),
                Item('movie_noninvasive_tmin',
                     label='noninvasive tmin',
                     enabled_when="movie_use_noninvasive"),
                Item('movie_noninvasive_tmax',
                     label='noninvasive tmax',
                     enabled_when="movie_use_noninvasive"),
                Item('movie_interpolation', label='interp'),
            ),
            VGroup(
                Item('movie_bitrate', label='bitrate'),
                Item('movie_framerate', label='framerate'),
                Item('movie_dilation', label='temporal dilation'),
                Item('movie_animation_degrees', label='degrees rotation'),
            ),
        ),
        title='Chimer exodus from Aldmeris',
        buttons=[OKMakeMovieAction, CancelButton],
    )

    def _make_movie_button_fired(self):
        '''Open the movie settings dialog (make_movie_view) for this model.'''
        settings_view = 'make_movie_view'
        self.edit_traits(view=settings_view)

    def do_movie(self, info):
        '''
        'Make movie' action of make_movie_view.

        Asks for an output file. If the dialog is confirmed, this stores the
        path in ``movie_filename``, closes the settings dialog and calls
        ``self.movie`` with the dialog's settings. A cancelled dialog does
        nothing.
        '''
        from pyface.api import FileDialog, OK as FileOK

        save_dialog = FileDialog(action='save as')
        save_dialog.open()
        if save_dialog.return_code == FileOK:
            self.movie_filename = os.path.join(save_dialog.directory,
                                               save_dialog.filename)
            info.ui.dispose()
            settings = dict(
                noninvasive_tmin=self.movie_noninvasive_tmin,
                noninvasive_tmax=self.movie_noninvasive_tmax,
                invasive_tmin=self.movie_invasive_tmin,
                invasive_tmax=self.movie_invasive_tmax,
                normalization=self.movie_normalization_style,
                framerate=self.movie_framerate,
                dilation=self.movie_dilation,
                bitrate=self.movie_bitrate,
                interpolation=self.movie_interpolation,
                animation_degrees=self.movie_animation_degrees,
                samples_first=self.movie_sample_which_first,
            )
            self.movie(self.movie_filename, **settings)

    def build_surface(self, subjects_dir=None, subject=None):
        '''
        Create a pysurfer pial surface (both hemispheres) in this model's
        scene and plot it.

        ``subjects_dir`` and ``subject`` default to the SUBJECTS_DIR and
        SUBJECT environment variables.

        Returns
        -------
        brain | surfer.viz.Brain
            Pysurfer brain object (also stored as ``self.brain``).
        '''
        if subjects_dir is None:
            subjects_dir = os.environ['SUBJECTS_DIR']
        if subject is None:
            subject = os.environ['SUBJECT']

        self.brain = surfer.Brain(subject,
                                  hemi='both',
                                  surf='pial',
                                  figure=self.scene.mayavi_scene,
                                  subjects_dir=subjects_dir,
                                  curv=False)

        self.brain.toggle_toolbars(True)

        # Make the surfaces unpickable (presumably so mouse picks reach the
        # electrode glyphs) and apply the model's opacity.
        for srf in self.brain.brains:
            srf._geo_surf.actor.actor.pickable = False
            srf._geo_surf.actor.property.opacity = self.opacity

        return self.brain

    def invasive_callback(self, picker):
        '''
        Mouse-pick handler (registered by plot_ieeg).

        Picks outside the electrode glyph are ignored. Otherwise the picked
        point is mapped to an electrode location, channel name and index in
        ``self.ch_names``, which are printed. A signal browser is opened with
        do_browse if none is open, and the browser is scrolled to the picked
        channel.
        '''
        # Ignore picks on anything but the electrode glyph's actors.
        if picker.actor not in self.ieeg_glyph.actor.actors:
            return

        # Presumably every electrode is drawn as one copy of the glyph-source
        # mesh, so dividing the picked point id by that mesh's point count
        # gives the electrode index -- verify.
        ptid = int(picker.point_id / self.ieeg_glyph.glyph.glyph_source.
                   glyph_source.output.points.to_array().shape[0])

        # location -> channel name (ieeg_loc) -> position in self.ch_names
        pt_loc = tuple(self.ieeg_glyph.mlab_source.points[ptid])
        pt_name = self.ieeg_loc[pt_loc]
        pt_index = self.ch_names.index(pt_name)
        print ptid, pt_loc, pt_name, pt_index

        from browse_stc import do_browse
        # todo: change this to the real roi surface signal
        # surface_signal_rois = np.random.randn(*self.current_invasive_signal.mne_source_estimate.data.shape)
        # import random
        # import string
        # rois_labels = self.megsig.keys()
        # for _ in range(self.current_invasive_signal.mne_source_estimate.data.shape[0]):
        #     rois_labels.append(''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10)))
        # nearest_rois = source_signal.identify_roi_from_atlas(pt_loc, atlas='laus250')
        self.invasive_labels = None
        # Open a browser only if there is none or its figure is gone.
        # NOTE(review): 'LPT8' is hard-coded as a bad channel here.
        if self.browser is None or self.browser.figure is None:
            self.browser = do_browse(self.invasive_signals,
                                     bads=['LPT8'],
                                     n_channels=1,
                                     const_event_time=2.0,
                                     surface_signal_rois=self.megsig,
                                     glyph=self.ieeg_glyph,
                                     brain=self.brain)


#        elif self.browser.figure is None:
#            self.browser = do_browse(self.current_invasive_signal,
#                bads=['LPT8'], n_channels=1,
#                                     const_event_time=2.0)

        # Scroll the browser to the picked channel.
        self.browser._plot_imitate_scroll(pt_index)

    def plot_ieeg(self, raw=None, montage=None, elec_locs=None, ch_names=None):
        '''
        Plot sEEG electrodes as spheres in the scene and make them pickable.

        Electrode positions and names come from the first source given:
        - ``raw``: a raw .fif file; only channels whose kind is
          FIFFV_SEEG_CH are used;
        - ``montage``: a montage read with ``source_signal.load_montage``;
        - ``elec_locs`` together with ``ch_names``.

        Sets ``self.ch_names`` and ``self.ieeg_loc`` (location tuple ->
        channel name), draws the glyphs with zero scalars, and registers
        ``invasive_callback`` as the scene's mouse-pick handler.

        Returns
        -------
        ieeg_glyph | mlab.glyph
            Mayavi 3D glyph object. Returns None if no electrode source was
            given; an error dialog is shown in that case.
        '''
        if raw is None and montage is None and (elec_locs is None
                                                or ch_names is None):
            error_dialog("must specify raw .fif file or list of electrode "
                         "coordinates and channel names")
            # Nothing to plot; continuing would fail on np.array(None)[:, 0].
            return None

        if raw is not None:
            ra = mne.io.Raw(raw)

            # (location tuple, channel name) for every sEEG channel.
            elecs = [(tuple(ra.info['chs'][i]['loc'][:3]), name)
                     for i, name in enumerate(ra.ch_names) if ra.info['chs'][i]
                     ['kind'] == mne.io.constants.FIFF.FIFFV_SEEG_CH]

            self.ch_names = [e[1] for e in elecs]

            locs = np.array([e[0] for e in elecs])

        elif montage is not None:
            sfp = source_signal.load_montage(montage)
            locs = np.array(sfp.pos)
            self.ch_names = sfp.ch_names

            elecs = [(tuple(loc), name)
                     for name, loc in zip(self.ch_names, locs)]

        else:
            locs = np.array(elec_locs)

            self.ch_names = ch_names

            elecs = [(tuple(loc), name) for name, loc in zip(ch_names, locs)]

        # TODO: compare signal.ch_names to the ch_names here

        # invasive_callback looks channel names up by picked location.
        self.ieeg_loc = dict(elecs)

        source = mlab.pipeline.scalar_scatter(locs[:, 0],
                                              locs[:, 1],
                                              locs[:, 2],
                                              figure=self.scene.mayavi_scene)

        self.ieeg_glyph = mlab.pipeline.glyph(source,
                                              scale_mode='none',
                                              scale_factor=6,
                                              mode='sphere',
                                              figure=self.scene.mayavi_scene,
                                              color=(1, 0, 0),
                                              name='garbonzo',
                                              colormap='RdBu')
        self.ieeg_glyph.module_manager.scalar_lut_manager.reverse_lut = True

        # Start every electrode at scalar 0 until a signal is displayed.
        self.ieeg_glyph.mlab_source.dataset.point_data.scalars = np.zeros(
            (len(locs), ))
        self._force_render()

        # Route mouse picks in the scene to invasive_callback.
        pick = self.scene.mayavi_scene.on_mouse_pick(self.invasive_callback)
        pick.tolerance = .1

        return self.ieeg_glyph

    def interactivize_ieeg(self):
        '''Open a signal browser on the current invasive source estimate.'''
        from browse_stc import do_browse
        estimate = self.current_invasive_signal.mne_source_estimate
        self.browser = do_browse(estimate)

    def plot_ieeg_montage(self, montage):
        '''
        Read a montage file with sEEG electrode coordinates and plot all of
        its electrodes as spheres in the scene.

        Sets ``self.ch_names`` and ``self.ieeg_loc``. ``ieeg_loc`` maps each
        electrode location tuple to its channel name, the same layout
        plot_ieeg builds and invasive_callback looks up. Unlike plot_ieeg,
        no pick handler is registered and no initial scalars are set.

        Returns
        -------
        ieeg_glyph | mlab.glyph
            Mayavi 3D glyph object
        '''

        mopath = os.path.abspath(montage)
        mo = mne.channels.read_montage(mopath)

        self.ch_names = mo.ch_names
        # Keyed by location (not name) so lookups by picked point work.
        self.ieeg_loc = {tuple(pos): name
                         for name, pos in zip(mo.ch_names, mo.pos)}

        locs = mo.pos

        source = mlab.pipeline.scalar_scatter(locs[:, 0],
                                              locs[:, 1],
                                              locs[:, 2],
                                              figure=self.scene.mayavi_scene)

        self.ieeg_glyph = mlab.pipeline.glyph(source,
                                              scale_mode='none',
                                              scale_factor=6,
                                              mode='sphere',
                                              figure=self.scene.mayavi_scene,
                                              color=(1, 0, 0),
                                              name='gableebo',
                                              colormap='RdBu')
        self.ieeg_glyph.module_manager.scalar_lut_manager.reverse_lut = True

        return self.ieeg_glyph

    def generate_subcortical_surfaces(self, subjects_dir=None, subject=None):
        '''
        Run the external ``aseg2srf`` tool for a subject.

        A given ``subjects_dir`` is exported as SUBJECTS_DIR in this
        process's environment. ``subject`` defaults to the SUBJECT
        environment variable. The tool path is 'aseg2srf' resolved against
        the current working directory, and its exit status is ignored.
        '''
        import subprocess

        if subjects_dir is not None:
            os.environ['SUBJECTS_DIR'] = subjects_dir
        subject = os.environ['SUBJECT'] if subject is None else subject

        command = [os.path.realpath('aseg2srf'), '-s', subject]
        subprocess.call(command)

    def viz_subcortical_surfaces(self, subjects_dir=None, subject=None):
        '''
        Draw subcortical surfaces as translucent gray, unpickable meshes.

        For both aseg label values of each structure, the file
        <subjects_dir>/<subject>/ascii/aseg_<label:03d>.srf is read with
        mne.read_surface and drawn at opacity .35. Missing files are skipped.

        ``subjects_dir`` and ``subject`` default to the SUBJECTS_DIR and
        SUBJECT environment variables, as in viz_subcortical_points.
        '''
        if subjects_dir is None:
            subjects_dir = os.environ['SUBJECTS_DIR']
        if subject is None:
            subject = os.environ['SUBJECT']

        structures_list = {
            'hippocampus': ([53, 17], (.69, .65, .93)),
            'amgydala': ([54, 18], (.8, .5, .29)),
            'thalamus': ([49, 10], (.318, 1, .447)),
            'caudate': ([50, 11], (1, .855, .67)),
            'putamen': ([51, 12], (0, .55, 1)),
            'insula': ([55, 19], (1, 1, 1)),
            'accumbens': ([58, 26], (1, .44, 1)),
        }

        # NOTE(review): the per-structure colors are not used here and no
        # figure is passed to triangular_mesh (current figure is used).
        for (strucl, strucr), _ in structures_list.values():

            for strucu in (strucl, strucr):
                surf_file = os.path.join(subjects_dir, subject, 'ascii',
                                         'aseg_%03d.srf' % strucu)

                if not os.path.exists(surf_file):
                    continue

                v, tri = mne.read_surface(surf_file)

                surf = mlab.triangular_mesh(v[:, 0],
                                            v[:, 1],
                                            v[:, 2],
                                            tri,
                                            opacity=.35,
                                            color=(.5, .5, .5))

                surf.actor.actor.pickable = False

    def viz_subcortical_points(self, subjects_dir=None, subject=None):
        '''
        Add point-cloud renderings of the subcortical structures.

        For both aseg label values of each structure, the voxels carrying
        that label in <subjects_dir>/<subject>/mri/aseg.mgz are mapped with
        the 'vox2ras-tkr' transform and drawn as small spheres in the
        structure's color. Labels with no voxels are skipped with a message.

        ``subjects_dir`` and ``subject`` default to the SUBJECTS_DIR and
        SUBJECT environment variables.
        '''

        if subjects_dir is None:
            subjects_dir = os.environ['SUBJECTS_DIR']
        if subject is None:
            subject = os.environ['SUBJECT']

        structures_list = {
            'hippocampus': ([53, 17], (.69, .65, .93)),
            'amgydala': ([54, 18], (.8, .5, .29)),
            'thalamus': ([49, 10], (.318, 1, .447)),
            'caudate': ([50, 11], (1, .855, .67)),
            'putamen': ([51, 12], (0, .55, 1)),
            'insula': ([55, 19], (1, 1, 1)),
            'accumbens': ([58, 26], (1, .44, 1)),
        }

        asegf = os.path.join(subjects_dir, subject, 'mri', 'aseg.mgz')
        aseg = nib.load(asegf)
        # img.get_data() is deprecated and removed in NiBabel 5.0;
        # np.asanyarray(img.dataobj) is the documented replacement and does
        # not force a float cast.
        asegd = np.asanyarray(aseg.dataobj)

        # The voxel -> RAS transform depends only on the aseg file, so it is
        # computed once instead of once per structure.
        import geometry as geo
        xfm = geo.get_vox2rasxfm(asegf, stem='vox2ras-tkr')

        for struct, ((strucl, strucr), color) in structures_list.items():

            for strucu in (strucl, strucr):

                strucw = np.where(asegd == strucu)

                if np.size(strucw) == 0:
                    print('Nonne skippy %s' % struct)
                    continue

                strucd = np.array(geo.apply_affine(np.transpose(strucw), xfm))

                print(np.shape(strucd))

                src = mlab.pipeline.scalar_scatter(
                    strucd[:, 0],
                    strucd[:, 1],
                    strucd[:, 2],
                    figure=self.scene.mayavi_scene)

                mlab.pipeline.glyph(src,
                                    scale_mode='none',
                                    scale_factor=0.4,
                                    mode='sphere',
                                    opacity=1,
                                    figure=self.scene.mayavi_scene,
                                    color=color)

    def add_meg_signal(self, signal):
        '''Store the MEG signal; invasive_callback hands it to do_browse
        as ``surface_signal_rois``.'''
        self.megsig = signal

    def add_invasive_labels(self, labels, labels_id=None):
        '''
        Store ``labels`` and select an entry as ``invasive_labels_id``.

        With ``labels_id`` None, the first key of ``labels`` is used.
        Otherwise ``labels[labels_id]`` is stored.
        '''
        self.invasive_labels = labels
        if labels_id is None:
            # dict.keys() is not indexable on Python 3; list(labels) yields
            # the same first key on Python 2 and 3.
            self.invasive_labels_id = list(labels)[0]
        else:
            # NOTE(review): this branch stores a looked-up value while the
            # default branch stores a key -- confirm which is intended.
            self.invasive_labels_id = labels[labels_id]

    def add_invasive_signal(self, name, signal):
        '''
        Register an invasive signal under ``name`` and make it current.

        The electrode order (``self.ch_names``) must be set first, e.g. by
        plot_ieeg. Otherwise an error dialog is shown and nothing is added.
        '''
        # ch_names defaults to None (Any), so len() alone would raise.
        if self.ch_names is None or len(self.ch_names) == 0:
            error_dialog("Cannot add invasive signal without first "
                         "specifying order of invasive electrodes")
            # set_current_invasive_signal needs ch_names to reorder the data.
            return

        self.invasive_signals[name] = signal
        self.set_current_invasive_signal(name)

    def set_current_invasive_signal(self, name):
        '''
        Make the invasive signal registered as ``name`` current.

        Widens the time-slider bounds (_time_low/_time_high) to cover the
        signal's source estimate. It then replaces the signal's ``data``
        with the estimate's rows reordered by
        ``source_signal.adj_sort(sig.ch_names, self.ch_names)`` and sets its
        ``ch_names`` to ``self.ch_names``. Vertices are not reordered; only
        ``stc.data`` is used.
        '''
        signal = self.invasive_signals[name]
        self.current_invasive_signal = signal

        estimate = signal.mne_source_estimate

        start = estimate.times[0]
        if start < self._time_low:
            self._time_low = start
        end = estimate.times[-1]
        if end > self._time_high:
            self._time_high = end

        # Presumably a row permutation from the signal's channel order to
        # self.ch_names -- computed before signal.ch_names is overwritten.
        order = source_signal.adj_sort(signal.ch_names, self.ch_names)
        reordered = estimate.data[order]

        signal.ch_names = self.ch_names
        signal.data = reordered

    def add_noninvasive_signal(self, name, signal):
        """Register ``signal`` under ``name`` and make it the current
        noninvasive signal."""
        registry = self.noninvasive_signals
        registry[name] = signal
        self.set_current_noninvasive_signal(name)

    def set_current_noninvasive_signal(self, name):
        """Make the noninvasive signal registered as ``name`` the current one.

        The viewer's time window (``_time_low`` / ``_time_high``) is widened
        to cover the signal's source estimate, and the estimate's data
        matrix is exposed as ``sig.data``.
        """
        sig = self.noninvasive_signals[name]
        self.current_noninvasive_signal = sig
        estimate = sig.mne_source_estimate

        first_time = estimate.times[0]
        last_time = estimate.times[-1]
        if first_time < self._time_low:
            self._time_low = first_time
        if last_time > self._time_high:
            self._time_high = last_time

        sig.data = estimate.data

    def _display_interpolated_invasive_signal_timepoint(self, idx, ifunc):
        """Colour the invasive-electrode glyphs with ``ifunc(idx)``.

        ``ifunc`` is the interpolator that ``movie`` obtains from
        ``_create_movie_samples``. The glyph colour range is fixed to (0, 1).
        """
        glyph = self.ieeg_glyph
        values = np.array(ifunc(idx))
        glyph.mlab_source.dataset.point_data.scalars = values
        glyph.actor.mapper.scalar_visibility = True
        glyph.module_manager.scalar_lut_manager.data_range = (0, 1)

    def _display_interpolated_noninvasive_signal_timepoint(
            self, idx, ifunc, interpolation='quadratic'):
        """Paint both cortical surfaces with the interpolated noninvasive
        signal at movie position ``idx``.

        Parameters
        ----------
        idx : float
            Position passed to ``ifunc``.
        ifunc : callable
            Interpolator from ``_create_movie_samples``. ``ifunc(idx)`` is
            assumed to return one value per source vertex, with all
            left-hemisphere vertices (``lh_vertno``) preceding all
            right-hemisphere vertices (``rh_vertno``).
        interpolation : str
            Currently unused; kept for call compatibility.

        If a hemisphere's sources cover only part of its surface, the
        values are spread with the precomputed smoothing matrix
        (``self.smoothl`` / ``self.smoothr``) when ``smoothing_steps > 0``.
        Otherwise uncovered vertices are filled with the mid-range value 0.5.
        """
        scalars = ifunc(idx)

        stc = self.current_noninvasive_signal.mne_source_estimate
        lvt = stc.lh_vertno
        rvt = stc.rh_vertno
        n_lh = len(lvt)
        n_rh = len(rvt)

        if n_lh > 0:
            # left-hemisphere sources occupy the first n_lh entries
            lh_scalar = scalars[:n_lh]
            lh_surf = self.brain.brains[0]._geo_surf
            n_lh_coords = len(self.brain.geo['lh'].coords)
            if n_lh < n_lh_coords:
                if self.smoothing_steps > 0:
                    lh_scalar = self.smoothl * lh_scalar
                else:
                    ls = np.ones(n_lh_coords) * .5
                    ls[lvt] = lh_scalar
                    lh_scalar = ls

            lh_surf.mlab_source.scalars = lh_scalar
            lh_surf.module_manager.scalar_lut_manager.data_range = (0, 1)

        if n_rh > 0:
            # right-hemisphere sources follow the left-hemisphere ones
            rh_scalar = scalars[n_lh:n_lh + n_rh]
            rh_surf = self.brain.brains[1]._geo_surf
            n_rh_coords = len(self.brain.geo['rh'].coords)
            if n_rh < n_rh_coords:
                if self.smoothing_steps > 0:
                    rh_scalar = self.smoothr * rh_scalar
                else:
                    rs = np.ones(n_rh_coords) * .5
                    rs[rvt] = rh_scalar
                    rh_scalar = rs

            rh_surf.mlab_source.scalars = rh_scalar
            rh_surf.module_manager.scalar_lut_manager.data_range = (0, 1)

    @on_trait_change('time_slider')
    def _show_closest_timepoint_listen(self):
        """Redraw both signals at the slider's time whenever it moves."""
        slider_time = self.time_slider
        self.set_closest_timepoint(slider_time)
        self._force_render()

    def set_closest_timepoint(self, time, invasive=True, noninvasive=True):
        """Show the samples nearest to ``time`` for the selected signal kinds.

        The noninvasive display is updated before the invasive one.
        """
        updates = ((noninvasive, self._set_noninvasive_timepoint),
                   (invasive, self._set_invasive_timepoint))
        for enabled, update in updates:
            if enabled:
                update(time)

    def _set_invasive_timepoint(self, t, normalization='global'):
        """Colour the invasive glyphs with the sample nearest to time ``t``.

        With ``normalization='global'`` the whole signal is min-max scaled
        to [0, 1] before the sample is taken. Any other value uses the raw
        data. Nothing happens when no invasive signal is current.
        """
        sig = self.current_invasive_signal
        if sig is None:
            return
        times = sig.mne_source_estimate.times
        nearest = np.argmin(np.abs(times - t))

        values = sig.data
        if normalization == 'global':
            top = np.max(values)
            bottom = np.min(values)
            values = (values - bottom) / (top - bottom)
        column = values[:, nearest]

        glyph = self.ieeg_glyph
        lut = glyph.module_manager.scalar_lut_manager
        # passing through another mode first discards earlier LUT changes
        lut.lut_mode = 'black-white'
        lut.lut_mode = 'RdBu'
        lut.reverse_lut = True

        glyph.mlab_source.dataset.point_data.scalars = np.array(column)
        glyph.actor.mapper.scalar_visibility = True
        lut.data_range = (0, 1)

    def _setup_noninvasive_viz(self):
        """Prepare both cortical surfaces to show the current noninvasive
        signal.

        When a hemisphere's sources cover only part of its surface and
        ``smoothing_steps > 0``, a smoothing matrix is built from that
        hemisphere's own mesh and stored in ``self.smoothl`` /
        ``self.smoothr``. The display code uses it to spread the sampled
        values over the whole surface. Hemispheres without any source
        vertices have scalar colouring switched off (left gray). The others
        get a reversed ``RdBu`` lookup table over the range (0, 1).
        """
        stc = self.current_noninvasive_signal.mne_source_estimate
        lvt = stc.lh_vertno
        rvt = stc.rh_vertno

        if (0 < len(lvt) < len(self.brain.geo['lh'].coords)
                and self.smoothing_steps > 0):
            ladj = surfer.utils.mesh_edges(self.brain.geo['lh'].faces)
            self.smoothl = surfer.utils.smoothing_matrix(
                lvt, ladj, self.smoothing_steps)

        if (0 < len(rvt) < len(self.brain.geo['rh'].coords)
                and self.smoothing_steps > 0):
            # the right hemisphere needs the adjacency of its *own* faces
            radj = surfer.utils.mesh_edges(self.brain.geo['rh'].faces)
            self.smoothr = surfer.utils.smoothing_matrix(
                rvt, radj, self.smoothing_steps)

        for i, brain in enumerate(self.brain.brains):
            surf = brain._geo_surf

            # brains[0] shows the left hemisphere, brains[1] the right one;
            # leave a hemisphere gray if it has no source vertices
            if (i == 0 and len(lvt) == 0) or (i == 1 and len(rvt) == 0):
                surf.actor.mapper.scalar_visibility = False
                continue

            lut = surf.module_manager.scalar_lut_manager
            lut.lut_mode = 'RdBu'
            lut.reverse_lut = True
            surf.actor.mapper.scalar_visibility = True
            lut.data_range = (0, 1)

    def _set_noninvasive_timepoint(self, t, normalization='global'):
        """Paint both cortical surfaces with the noninvasive sample nearest
        to time ``t``.

        With ``normalization='global'`` each hemisphere's data is min-max
        scaled to [0, 1] over all of its vertices and time points before
        the sample is taken. Any other value uses the raw data.

        Partially covered hemispheres are filled either with the smoothing
        matrix (``smoothing_steps > 0``) or with the mid-range value 0.5.
        Nothing happens when no noninvasive signal is current.
        """
        if self.current_noninvasive_signal is None:
            return
        stc = self.current_noninvasive_signal.mne_source_estimate

        sample_time = np.argmin(np.abs(stc.times - t))

        lvt = stc.lh_vertno
        rvt = stc.rh_vertno

        def hemisphere_column(hemi_data):
            """Optionally normalize ``hemi_data`` (vertices x times), then
            return its column at ``sample_time``."""
            if normalization == 'global':
                dmin = np.min(hemi_data)
                dmax = np.max(hemi_data)
                hemi_data = (hemi_data - dmin) / (dmax - dmin)
            # the surface needs one value per vertex, not the whole matrix
            return hemi_data[:, sample_time]

        lh_scalar = hemisphere_column(stc.lh_data) if len(lvt) > 0 else None
        rh_scalar = hemisphere_column(stc.rh_data) if len(rvt) > 0 else None

        self._setup_noninvasive_viz()

        def paint(brain_index, hemi, vertices, scalar, smoother_name):
            """Push ``scalar`` onto one hemisphere's surface."""
            surf = self.brain.brains[brain_index]._geo_surf
            n_coords = len(self.brain.geo[hemi].coords)
            if len(vertices) < n_coords:
                if self.smoothing_steps > 0:
                    scalar = getattr(self, smoother_name) * scalar
                else:
                    full = np.ones(n_coords) * .5
                    full[vertices] = scalar
                    scalar = full
            surf.mlab_source.scalars = scalar
            surf.module_manager.scalar_lut_manager.data_range = (0, 1)

        if len(lvt) > 0:
            paint(0, 'lh', lvt, lh_scalar, 'smoothl')

        if len(rvt) > 0:
            paint(1, 'rh', rvt, rh_scalar, 'smoothr')

    def movie(self,
              movname,
              invasive=True,
              noninvasive=True,
              framerate=24,
              interpolation='quadratic',
              dilation=2,
              normalization='local',
              debug_labels=False,
              bitrate='750k',
              animation_degrees=0,
              invasive_tmin=None,
              invasive_tmax=None,
              noninvasive_tmin=None,
              noninvasive_tmax=None,
              samples_first='invasive'):
        """Render the current signals frame by frame and encode a movie.

        Each selected signal (``invasive`` / ``noninvasive``) is resampled
        by ``_create_movie_samples``. Both signals share the ``framerate``,
        ``dilation``, ``interpolation`` and ``normalization`` settings, and
        each uses its own ``*_tmin`` / ``*_tmax`` window. The signal named
        by ``samples_first`` is sampled first and fixes the number of frames
        for the other one.

        Frames are saved into a temporary directory, with the camera turned
        by ``animation_degrees`` of azimuth per frame. They are encoded into
        ``movname`` by ``surfer.utils.ffmpeg`` at ``bitrate``, and the
        temporary directory is removed afterwards.

        The following problems are reported with ``error_dialog`` and abort
        the movie:

        * no signal kind requested;
        * a missing signal or source estimate;
        * unequal frame counts;
        * zero frames.

        ``debug_labels`` is currently unused.
        """
        #potentially worth providing different options for normalization and
        #interpolation for noninvasive and invasive data
        import shutil
        from tempfile import mkdtemp

        if not invasive and not noninvasive:
            error_dialog("Movie has no noninvasive or invasive signals")
            return

        def signal_is_usable(sig, kind):
            """Show why ``sig`` cannot be animated; True if it can."""
            if sig is None:
                error_dialog("No %s signal found" % kind)
                return False
            if sig.mne_source_estimate is None:
                error_dialog("%s signal has no source estimate" %
                             kind.capitalize())
                return False
            return True

        # validate in the same order in which the signals will be sampled
        requested = [('invasive', invasive, self.current_invasive_signal),
                     ('noninvasive', noninvasive,
                      self.current_noninvasive_signal)]
        if samples_first != 'invasive':
            requested.reverse()
        for kind, wanted, sig in requested:
            if wanted and not signal_is_usable(sig, kind):
                return

        def sample(sig, tmin, tmax, is_invasive, nr_samples):
            """Resample ``sig`` onto the movie's frame grid."""
            times, _, func, nr_samples = self._create_movie_samples(
                sig,
                tmin=tmin,
                tmax=tmax,
                framerate=framerate,
                dilation=dilation,
                interpolation=interpolation,
                normalization=normalization,
                is_invasive=is_invasive,
                nr_samples=nr_samples)
            return times, func, nr_samples

        # ensure that the samples are collected in the right order: the
        # first signal chooses nr_samples (-1 = let it decide), the second
        # reuses it
        nr_samples = -1

        if invasive and samples_first == 'invasive':
            i_times, ifunc, nr_samples = sample(
                self.current_invasive_signal, invasive_tmin, invasive_tmax,
                True, nr_samples)

        if noninvasive:
            ni_times, nfunc, nr_samples = sample(
                self.current_noninvasive_signal, noninvasive_tmin,
                noninvasive_tmax, False, nr_samples)
            self._setup_noninvasive_viz()

        if invasive and samples_first != 'invasive':
            i_times, ifunc, nr_samples = sample(
                self.current_invasive_signal, invasive_tmin, invasive_tmax,
                True, nr_samples)

        if invasive and noninvasive and len(i_times) != len(ni_times):
            error_dialog(
                "Movie parameters do not yield equal number of "
                "samples in invasive and noninvasive timecourses.\n"
                "Invasive samples: %i\nNoninvasive samples: %i" %
                (len(i_times), len(ni_times)))
            return

        steps = len(i_times) if invasive else len(ni_times)
        if steps == 0:
            # log10(0) below would otherwise overflow the frame pattern
            error_dialog("Movie parameters yield no frames")
            return

        tempdir = mkdtemp()
        try:
            # zero-padded frame numbers, wide enough for `steps` frames
            frame_pattern = 'frame%%0%id.png' % (np.floor(np.log10(steps)) + 1)
            fname_pattern = os.path.join(tempdir, frame_pattern)

            for i in xrange(steps):
                frname = fname_pattern % i

                #do the data display method
                if invasive:
                    self._display_interpolated_invasive_signal_timepoint(
                        i_times[i], ifunc)

                if noninvasive:
                    self._display_interpolated_noninvasive_signal_timepoint(
                        ni_times[i], nfunc, interpolation=interpolation)

                self.scene.camera.azimuth(animation_degrees)

                self.scene.render()
                mlab.draw(figure=self.scene.mayavi_scene)
                self._force_render()
                self.brain.save_image(frname)

            from surfer.utils import ffmpeg
            ffmpeg(movname,
                   fname_pattern,
                   framerate=framerate,
                   bitrate=bitrate,
                   codec=None)
        finally:
            # the frames are only intermediate input for the encoder
            shutil.rmtree(tempdir, ignore_errors=True)

    def _force_render(self):
        """Flush pending GUI events so scene changes are drawn right away.

        The GUI busy flag is raised while events are processed. It is then
        put back to its previous value, and events are processed once more.
        """
        from pyface.api import GUI
        gui = GUI()
        previously_busy = gui.busy
        for busy_state in (True, previously_busy):
            gui.set_busy(busy=busy_state)
            gui.process_events()

    def _create_movie_samples(self,
                              sig,
                              framerate=24,
                              interpolation='quadratic',
                              dilation=2,
                              tmin=None,
                              tmax=None,
                              normalization='local',
                              is_invasive=False,
                              nr_samples=-1):
        """Resample a signal onto the movie's frame grid.

        Parameters
        ----------
        sig : signal object
            Must provide ``mne_source_estimate`` (with ``tstep``, ``tmin``
            and ``times``) and a ``data`` matrix whose columns are time
            samples.
        framerate, dilation : number
            Each frame advances ``1 / (framerate * dilation)`` seconds of
            signal.
        interpolation : str
            ``kind`` passed to ``scipy.interpolate.interp1d``.
        tmin, tmax : float or None
            Time window. The defaults are the first and last times of the
            estimate. A window outside the estimate is reported through
            ``error_dialog``.
        normalization : str
            ``'local'`` min-max scales using the window's data and
            ``'conservative'`` using the whole signal. Any other value
            (e.g. ``'none'``) leaves the data untouched.
        is_invasive : bool
            Only used to word the error messages.
        nr_samples : int
            Number of frames. ``-1`` derives it from the window and the
            step size.

        Returns
        -------
        all_times : ndarray
            Window-relative sample positions, one per frame, at which
            ``interp_func`` is evaluated.
        data : ndarray
            The (possibly normalized) windowed data.
        interp_func : callable
            Interpolator over the window's samples along axis 1.
        nr_samples : int
            Number of frames actually used.
        """
        from scipy.interpolate import interp1d

        stc = sig.mne_source_estimate

        # seconds between consecutive samples of the estimate
        sample_period = stc.tstep

        if tmin is None:
            tmin = stc.tmin

        if tmax is None:
            tmax = stc.times[-1]

        smin = np.argmin(np.abs(stc.times - tmin))
        smax = np.argmin(np.abs(stc.times - tmax))

        # catch if the user asked for timepoints that dont exist
        if tmin < stc.tmin:
            error_dialog("Time window too low for %s signal" %
                         ('invasive' if is_invasive else 'noninvasive'))
        if tmax > stc.times[-1]:
            error_dialog("Time window too high for %s signal" %
                         ('invasive' if is_invasive else 'noninvasive'))

        time_length = tmax - tmin
        sample_length = smax - smin + 1

        # force true division: with integer framerate/dilation and Python 2
        # division semantics, 1 / (24 * 2) would be 0
        tstep_size = 1. / (framerate * dilation)
        sstep_size = tstep_size / sample_period

        # When this signal is sampled first (nr_samples == -1) the frame
        # count is derived from the window, adding the end sample when the
        # window is (close to) a whole number of movie steps. A second
        # signal with a different sampling rate may lack that extra sample,
        # so it simply reuses the count and interpolates as closely as
        # possible. This trades a little temporal accuracy for equal frame
        # counts, which is not a practical problem.
        if nr_samples == -1:
            if np.allclose(time_length % tstep_size, 0, atol=tstep_size / 2):
                sstop = smax + sstep_size / 2
            else:
                sstop = smax

            nr_samples = len(np.arange(smin, sstop, sstep_size))

        movie_sample_times = np.linspace(smin, smax, num=nr_samples)

        #smin is the minimum possible sample, the max is smax
        #to get to exactly smax we need to use smax+1
        raw_sample_times = np.arange(smin, smax + 1)

        exact_times = np.arange(sample_length)

        #this interpolation is exactly linear
        all_times = interp1d(raw_sample_times, exact_times)(movie_sample_times)

        data = sig.data[:, smin:smax + 1]

        if normalization == 'none':
            pass
        elif normalization == 'conservative':
            dmax = np.max(sig.data)
            dmin = np.min(sig.data)
            data = (data - dmin) / (dmax - dmin)
        elif normalization == 'local':
            dmax = np.max(data)
            dmin = np.min(data)
            data = (data - dmin) / (dmax - dmin)

        #the default quadratic interpolation does a very bad job with
        #extremely low frequency varying signals, which can happen when
        #plotting something that looks like raw data.
        interp_func = interp1d(exact_times, data, interpolation, axis=1)

        return all_times, data, interp_func, nr_samples
Beispiel #18
0
class ElectrodeWindow(Handler):
    """Table window for inspecting and editing the electrodes of one grid.

    The window works on the electrodes of ``cur_grid``. It lets the user
    edit corners, geometry coordinates and channel names, and provides
    these operations:

    * label electrodes automatically from three corner electrodes;
    * swap or add electrodes;
    * interpolate a placeholder electrode's position;
    * estimate nearby ROIs;
    * reposition an electrode manually;
    * save the grid as a montage/CSV file or as a coronal slice image.

    Highlight changes are sent back to ``model`` through its glyph-recolor
    event traits.
    """
    #we clumsily hold a reference to the model object to fire its events
    model = Any

    cur_grid = Str

    electrodes = List(Instance(Electrode))
    cur_sel = Instance(Electrode)
    selection_callback = Method

    selected_ixes = Any
    swap_action = Action(name='Swap two electrodes', action='do_swap')
    add_blank_action = Action(name='Add blank electrode',
        action='do_add_blank')

    previous_sel = Instance(Electrode)
    previous_color = Int

    distinct_prev_sel = Instance(Electrode)

    save_montage_action = Action(name='Save montage file',
        action='do_montage')

    save_csv_action = Action(name='Save CSV file', action='do_csv')

    interpolate_action = Action(name='Linear interpolation',
        action='do_linear_interpolation')

    naming_convention = Enum('line', 'grid', 'reverse_grid')
    grid_type = Enum('depth', 'subdural')
    label_auto_action = Action(name='Automatic labeling',
        action='do_label_automatically')

    name_stem = Str
    c1, c2, c3 = 3*(Instance(Electrode),)

    parcellation = Str
    error_radius = Float(4)
    find_rois_action = Action(name='Estimate single ROI contacts',
        action='do_rois')
    find_all_rois_action = Action(name='Estimate all ROI contacts',
        action='do_all_rois')

    manual_reposition_action = Action(name='Manually modify electrode '
        'position', action='do_manual_reposition')

    img_dpi = Float(125.)
    img_size = List(Float)
    save_coronal_slice_action = Action(name='Save coronal slice',
        action='do_coronal_slice')

    def _img_size_default(self):
        return [450., 450.]

    #electrode_factory = Method
    def electrode_factory(self):
        """Create a placeholder electrode for this grid (table row factory),
        to be positioned later by linear interpolation."""
        return Electrode(special_name='Electrode for linear interpolation',
            grid_name=self.cur_grid,
            is_interpolation=True)

    def dynamic_view(self):
        """Build the window's View: the electrode table, the automatic
        labeling parameters, the buttons and the menu bar."""
        return View(
            Item('electrodes',
                editor=TableEditor( columns =
                    [ObjectColumn(label='electrode',
                                  editor=TextEditor(),
                                  style='readonly',
                                  editable=False,
                                  name='strrepr'),

                     ObjectColumn(label='corner',
                                  editor=CheckListEditor(
                                    values=['','corner 1','corner 2',
                                        'corner 3']),
                                  style='simple',
                                  name='corner'),

                     ObjectColumn(label='geometry',
                                  editor=CSVListEditor(),
                                  #editor=TextEditor(),
                                  #style='readonly',
                                  #editable=False,
                                  name='geom_coords'),

                     ObjectColumn(label='channel name',
                                  editor=TextEditor(),
                                  name='name'),

                     ObjectColumn(label='ROIs',
                                  editor=ListStrEditor(),
                                  editable=False,
                                  name='roi_list'),
                     ],
                    selected='cur_sel',
                    deletable=True,
                    #row_factory=electrode_factory,
                    row_factory=self.electrode_factory,
                    ),
                show_label=False, height=350, width=700),

            HGroup(
                VGroup(
                    Label( 'Automatic labeling parameters' ),
                    Item( 'name_stem' ),
                    HGroup(
                        Item( 'naming_convention' ),
                        Item( 'grid_type' ),
                    ),
                ),
                #VGroup(
                #    Label( 'ROI identification parameters' ),
                #    Item('parcellation'),
                #    Item('error_radius'),
                #),
                #VGroup(
                #    Label('Image parameters' ),
                #    Item('img_dpi', label='dpi'),
                #    Item('img_size', label='size', editor=CSVListEditor()),
                #),
            ),

            resizable=True, kind='panel', title='modify electrodes',
            #buttons=[OKButton, swap_action, label_auto_action,
            #    interpolate_action, save_montage_action, find_rois_action])

            buttons = [self.label_auto_action, self.swap_action, OKButton],
            menubar = MenuBar(
                Menu( self.label_auto_action, self.add_blank_action,
                    self.interpolate_action, self.find_rois_action,
                    self.find_all_rois_action,
                    self.manual_reposition_action,
                    name='Operations',
                ),
                Menu( self.save_montage_action, self.save_csv_action,
                    self.save_coronal_slice_action,
                    name='Save Output',
                ),
            )
        )

    def edit_traits(self):
        """Open the window using the view from ``dynamic_view``.

        Returns whatever the base class ``edit_traits`` returns (previously
        this result was discarded).
        """
        return super(ElectrodeWindow, self).edit_traits(
            view=self.dynamic_view())

    @on_trait_change('cur_sel')
    def selection_callback(self):
        """Highlight the newly selected electrode in the model's scene.

        The previously selected electrode is recoloured with its grid
        colour. The new one is coloured with the 'selection' colour.
        """
        if self.cur_sel is None:
            return

        # if this electrode has just been created and not interpolated yet
        # it has no chance of being in the image yet so just pass
        if self.cur_sel.special_name == 'Electrode for linear interpolation':
            return

        if self.previous_sel is not None:
            self.model._new_glyph_color = self.previous_color
            self.model._single_glyph_to_recolor = self.previous_sel.asct()
            self.model._update_single_glyph_event = True

        # remember the last *different* selection for do_swap
        if self.distinct_prev_sel != self.previous_sel:
            self.distinct_prev_sel = self.previous_sel

        self.previous_sel = self.cur_sel
        self.previous_color = self.model._colors.keys().index(self.cur_grid)

        selection_color = (self.model._colors.keys().index('selection'))

        self.model._new_glyph_color = selection_color
        self.model._single_glyph_to_recolor = self.cur_sel.asct()
        self.model._update_single_glyph_event = True

    #whenever the window closes, it no longer has a valid cur_sel to listen to
    def closed(self, is_ok, info):
        """Unregister this window from the model and restore the colour of
        the last highlighted electrode."""
        self.cur_sel = None
        del self.model.ews[self.cur_grid]
        if self.previous_sel is not None:
            self.model._new_glyph_color = self.previous_color
            self.model._single_glyph_to_recolor = self.previous_sel.asct()
            self.model._update_single_glyph_event = True

    @on_trait_change('grid_type')
    def change_grid_type(self):
        """Record the chosen grid type for this grid in the model."""
        self.model._grid_types[self.cur_grid] = self.grid_type

    def do_add_blank(self, info):
        """Append a new placeholder electrode to the table."""
        e = self.electrode_factory()
        e.grid_name = self.cur_grid
        self.electrodes.append(e)

    def do_swap(self, info):
        """Swap geometry coordinates and names between the current and the
        previous distinct selection."""
        if self.distinct_prev_sel == self.cur_sel:
            return
        elif None in (self.distinct_prev_sel, self.cur_sel):
            return

        e1 = self.cur_sel
        e2 = self.distinct_prev_sel

        geom_swap = e1.geom_coords
        name_swap = e1.name

        e1.geom_coords = e2.geom_coords
        e1.name = e2.name

        e2.geom_coords = geom_swap
        e2.name = name_swap

    def do_label_automatically(self, info):
        """Name the grid's electrodes from the three user-marked corners.

        Corner electrodes are those whose ``corner`` list contains
        'corner 1', 'corner 2' or 'corner 3'. The grid is fitted with
        ``pipeline.fit_grid_to_line`` ('line' naming) or
        ``pipeline.fit_grid_to_plane``. Every electrode of
        ``model._grids[cur_grid]`` is then named ``name_stem`` followed by
        an index computed from its geometry coordinates.
        """
        #figure out c1, c2, c3
        c1,c2,c3 = 3*(None,)
        for e in self.electrodes:
            if len(e.corner) == 0:
                continue
            elif len(e.corner) > 1:
                error_dialog('Wrong corners specified, check again')
                return

            elif 'corner 1' in e.corner:
                c1 = e
            elif 'corner 2' in e.corner:
                c2 = e
            elif 'corner 3' in e.corner:
                c3 = e

        if c1 is None or c2 is None or c3 is None:
            error_dialog('Not all corners were specified')
            return

        cur_geom = self.model._grid_geom[self.cur_grid]
        if cur_geom=='user-defined' and self.naming_convention != 'line':
            from color_utils import mayavi2traits_color
            from name_holder import GeometryNameHolder, GeomGetterWindow
            nameholder = GeometryNameHolder(
                geometry=cur_geom,
                color=mayavi2traits_color(
                    self.model._colors[self.cur_grid]))
            geomgetterwindow = GeomGetterWindow(holder=nameholder)

            if geomgetterwindow.edit_traits().result:
                cur_geom = geomgetterwindow.geometry
            else:
                error_dialog("User did not specify any geometry")
                return

        import pipeline as pipe
        if self.naming_convention == 'line':
            pipe.fit_grid_to_line(self.electrodes, c1.asct(), c2.asct(),
                c3.asct(), cur_geom, delta=self.model.delta,
                rho_loose=self.model.rho_loose)
            #do actual labeling
            for elec in self.model._grids[self.cur_grid]:
                _,y = elec.geom_coords
                index = y+1
                elec.name = '%s%i'%(self.name_stem, index)

        else:
            pipe.fit_grid_to_plane(self.electrodes, c1.asct(), c2.asct(),
                c3.asct(), cur_geom)

            #do actual labeling
            for elec in self.model._grids[self.cur_grid]:
                x,y = elec.geom_coords
                if self.naming_convention=='grid':
                    index = x*np.min(cur_geom) + y + 1
                else: #reverse_grid
                    index = y*np.max(cur_geom) + x + 1

                elec.name = '%s%i'%(self.name_stem, index)

    def do_linear_interpolation(self, info):
        """Estimate the position of the selected placeholder electrode.

        Only electrodes with a non-empty ``special_name`` and with
        ``geom_coords`` set are handled. The position is interpolated
        between the nearest neighbours on both sides along x (or else y).
        Failing that, it is extrapolated from a pair of directly adjacent
        electrodes on one side.

        The new position is passed through
        ``pipeline.translate_electrodes_to_surface_space`` and
        ``linearly_transform_electrodes_to_isotropic_coordinate_space``.
        The electrode is then added to the grid model so it can be
        visualized. If no position can be estimated an error dialog is
        shown and nothing else happens.
        """
        #TODO does not feed back coordinates to model,
        #which is important for snapping

        if self.cur_sel is None:
            return
        elif self.cur_sel.special_name == '':
            return

        if len(self.cur_sel.geom_coords) == 0:
            error_dialog("Specify geom_coords before linear interpolation")
            return

        x,y = self.cur_sel.geom_coords

        x_low = self._find_closest_neighbor(self.cur_sel, 'x', '-')
        x_hi = self._find_closest_neighbor(self.cur_sel, 'x', '+')
        y_low = self._find_closest_neighbor(self.cur_sel, 'y', '-')
        y_hi = self._find_closest_neighbor(self.cur_sel, 'y', '+')

        loc = None

        #handle simplest case of electrode directly in between others
        if x_low is not None and x_hi is not None:
            loc = self._interpolate_between(x_low, x_hi, 0, x)
        elif y_low is not None and y_hi is not None:
            loc = self._interpolate_between(y_low, y_hi, 1, y)

        #handle poorer case of electrode on end of line
        end_of_line = ((x_low, 'x', '-', x), (x_hi, 'x', '+', x),
                       (y_low, 'y', '-', y), (y_hi, 'y', '+', y))
        for near, axis, direction, target in end_of_line:
            if loc is not None:
                break
            loc = self._extrapolate_from(near, axis, direction, target)

        if loc is None:
            error_dialog('No line for simple linear interpolation\n'
                'Better algorithm needed')
            # nothing was estimated, so there is nothing to place
            return

        self.cur_sel.iso_coords = tuple(loc)
        self.cur_sel.special_name = 'Linearly interpolated electrode'

        # translate the electrode into RAS space
        import pipeline as pipe

        aff = self.model.acquire_affine()
        pipe.translate_electrodes_to_surface_space( [self.cur_sel], aff,
            subjects_dir=self.model.subjects_dir, subject=self.model.subject)

        pipe.linearly_transform_electrodes_to_isotropic_coordinate_space(
            [self.cur_sel], self.model.ct_scan,
            isotropization_type = ('deisotropize' if self.model.isotropize
                else 'copy_to_ct'))

        # add this electrode to the grid model so that it can be visualized
        self.model.add_electrode_to_grid(self.cur_sel, self.cur_grid)

    def _interpolate_between(self, low, high, k, target):
        """Linearly interpolate the position at grid index ``target`` between
        electrodes ``low`` and ``high`` along geometry axis ``k``."""
        lo = low.geom_coords[k]
        hi = high.geom_coords[k]
        # float() avoids integer division for integral grid coordinates
        ratio = float(target - lo) / (hi - lo)
        low_pos = np.array(low.iso_coords)
        return low_pos + (np.array(high.iso_coords) - low_pos) * ratio

    def _extrapolate_from(self, near, axis, direction, target):
        """Extrapolate the position at grid index ``target`` from electrode
        ``near`` and its closest neighbour further along ``direction``.

        Returns None when ``near`` is None, when it has no such neighbour,
        or when the two are not directly adjacent on the grid.
        """
        if near is None:
            return None
        far = self._find_closest_neighbor(near, axis, direction)
        if far is None:
            return None
        k = 0 if axis == 'x' else 1
        near_ix = near.geom_coords[k]
        far_ix = far.geom_coords[k]
        if abs(near_ix - far_ix) != 1:
            return None
        near_pos = np.array(near.iso_coords)
        # displacement of one grid step, pointing from `far` towards `near`
        step = near_pos - np.array(far.iso_coords)
        return near_pos + step * abs(target - near_ix)

    def _find_closest_neighbor(self, cur_elec, axis, direction):
        """Return the nearest electrode on the same row/column as
        ``cur_elec`` in ``direction`` ('+' or '-') along ``axis`` ('x' or
        'y'), or None if there is none. Electrodes without geom_coords are
        ignored."""
        x,y = cur_elec.geom_coords

        if direction=='+':
            new_ix = np.inf
        else:
            new_ix = -np.inf
        new_e = None

        for e in self.electrodes:
            if len(e.geom_coords) == 0:
                continue

            ex,ey = e.geom_coords

            if axis=='x' and direction=='+':
                if ex < new_ix and ex > x and ey == y:
                    new_e = e
                    new_ix = ex
            if axis=='x' and direction=='-':
                if ex > new_ix and ex < x and ey == y:
                    new_e = e
                    new_ix = ex
            if axis=='y' and direction=='+':
                if ey < new_ix and ey > y and ex == x:
                    new_e = e
                    new_ix = ey
            if axis=='y' and direction=='-':
                if ey > new_ix and ey < y and ex == x:
                    new_e = e
                    new_ix = ey

        return new_e

    def _save_grid_coordinates(self, file_type):
        """Save this grid's electrodes with ``save_coordinates`` in the
        given ``file_type`` ('montage' or 'csv')."""
        electrodes = self.model.get_electrodes_from_grid(
            target=self.cur_grid,
            electrodes=self.electrodes)

        if electrodes is None:
            return

        from electrode_group import save_coordinates
        save_coordinates( electrodes, self.model._grid_types,
            snapping_completed=self.model._snapping_completed,
            file_type=file_type)

    def do_montage(self, info):
        """Save this grid's electrodes as a montage file."""
        self._save_grid_coordinates('montage')

    def do_csv(self, info):
        """Save this grid's electrodes as a CSV file."""
        self._save_grid_coordinates('csv')

    def do_rois(self, info):
        """Estimate the ROIs near the selected electrode."""
        # nothing to do without a selection (mirrors do_manual_reposition)
        if self.cur_sel is None:
            return

        from electrode_group import get_nearby_rois_elec
        get_nearby_rois_elec( self.cur_sel,
            parcellation=self.model.roi_parcellation,
            error_radius=self.model.roi_error_radius,
            subjects_dir=self.model.subjects_dir,
            subject=self.model.subject )

    def do_all_rois(self, info):
        """Estimate the ROIs near every electrode of this grid."""
        from electrode_group import get_nearby_rois_grid
        get_nearby_rois_grid( self.electrodes,
            parcellation=self.model.roi_parcellation,
            error_radius=self.model.roi_error_radius,
            subjects_dir=self.model.subjects_dir,
            subject=self.model.subject )

    def do_coronal_slice(self, info):
        """Ask for a file name and save a coronal slice image of the grid."""
        savefile = ask_user_for_savefile('save png file with slice image')

        from electrode_group import coronal_slice_grid
        coronal_slice_grid(self.electrodes, savefile=savefile,
            subjects_dir=self.model.subjects_dir, subject=self.model.subject,
            dpi=self.model.coronal_dpi,
            size=tuple(self.model.coronal_size),
            title=self.name_stem)

    def do_manual_reposition(self, info):
        """Open the 2D panel with pins on the selected electrode in both the
        CT and the T1 image so the user can move it."""
        if self.cur_sel is None:
            return

        pd = self.model.construct_panel2d()

        x,y,z = self.cur_sel.asct()
        pd.move_cursor(x,y,z)
        pd.drop_pin(x,y,z, color='cyan', name='electrode', image_name='ct')

        rx,ry,rz = self.cur_sel.asras()
        pd.drop_pin(rx,ry,rz, color='cyan', name='electrode', image_name='t1',
            ras_coords=True)

        pd.edit_traits(kind='livemodal')

    def _current_pin_position(self):
        """Return ``((x, y, z), in_ras)`` for the active pin of panel2d.

        A pin on the 't1' image is mapped through ``pd.map_cursor`` with
        the third element of ``pd.images['t1']``. ``in_ras`` is True exactly
        in that case.
        """
        pd = self.model.panel2d
        image_name = pd.currently_showing.name
        px,py,pz,_ = pd.pins[image_name][pd.current_pin]

        if image_name=='t1':
            px,py,pz = pd.map_cursor((px,py,pz), pd.images['t1'][2])

        return (px,py,pz), image_name=='t1'

    @on_trait_change('model:panel2d:move_electrode_internally_event')
    def _internally_effect_electrode_reposition(self):
        """Move the selected electrode to the active pin's position."""
        if self.cur_sel is None:
            error_dialog("No electrode specified to move")
            return

        position, in_ras = self._current_pin_position()
        self.model.move_electrode( self.cur_sel, position, in_ras=in_ras )

    @on_trait_change('model:panel2d:move_electrode_postprocessing_event')
    def _postprocessing_effect_electrode_reposition(self):
        """Move the selected electrode to the active pin's position as a
        postprocessing step."""
        if self.cur_sel is None:
            error_dialog("No electrode specified to move")
            return

        position, in_ras = self._current_pin_position()
        self.model.move_electrode( self.cur_sel, position,
            in_ras=in_ras, as_postprocessing=True )
Beispiel #19
0
class Controller(HasTraits, BaseControllerImpl):
    """Traits controller that drives several views of the same 2D-4D volume.

    The controller owns the global axis indices (w, x, y, z), the
    colormap/colorbar settings and the clip range, and propagates changes to
    every registered view. For 4D data one global axis (``slicing_axis``) is
    fixed at its current index and each view receives the remaining 3D local
    volume. Lower-rank data is passed to the views unchanged.
    """

    ATTRIBUTES = collections.OrderedDict((
        ('w', IndexAttribute('w')),
        ('x', IndexAttribute('x')),
        ('y', IndexAttribute('y')),
        ('z', IndexAttribute('z')),
        ('colormap', ColormapAttribute()),
        ('colorbar', ColorbarAttribute()),
        ('slicing_axis', Dimension4DAttribute()),
        ('clip', ClipAttribute()),
        ('clip_min', ClipAttribute()),
        ('clip_max', ClipAttribute()),
        ('clip_auto', AutoClipAttribute()),
        ('clip_symmetric', SymmetricClipAttribute()),
        # passed to views
        ('locate_mode', LocateModeAttribute()),
        ('locate_value', LocateValueAttribute()),
    ))
    DIMENSIONS = "2D, 3D, ..., nD"
    DATA_CHECK = classmethod(lambda cls, data: len(data.shape) >= 2)
    DESCRIPTION = """\
Controller for multiple views
"""
    LABEL_WIDTH = 30
    LOCAL_AXIS_NAMES = ['x', 'y', 'z']
    LOCAL_AXIS_NUMBERS = dict((axis_name, axis_number) for axis_number, axis_name in enumerate(LOCAL_AXIS_NAMES))
    GLOBAL_AXIS_NAMES = ['w', 'x', 'y', 'z']
    GLOBAL_AXIS_NUMBERS = dict((axis_name, axis_number) for axis_number, axis_name in enumerate(GLOBAL_AXIS_NAMES))
    S_ALL = slice(None, None, None)

    # The axis selectors
    ## w:
    w_low = Int(0)
    w_high = Int(0)
    w_index = Int(0)
    w_range = Range(low='w_low', high='w_high', value='w_index')
    ## x:
    x_low = Int
    x_high = Int
    x_index = Int
    x_range = Range(low='x_low', high='x_high', value='x_index')
    ## y:
    y_low = Int(0)
    y_high = Int(0)
    y_index = Int(0)
    y_range = Range(low='y_low', high='y_high', value='y_index')
    ## z:
    z_low = Int(0)
    z_high = Int(0)
    z_index = Int(0)
    z_range = Range(low='z_low', high='z_high', value='z_index')
    ## is4D:
    is4D = Bool()

    slicing_axis = Str("")
    data_shape = Str("")
    close_button = Action(name='Close', action='_on_close')

    colorbar = Bool()
    colormap = Str()

    data_min = Float()
    data_max = Float()

    clip = Float()
    clip_min = Float()
    clip_max = Float()
    clip_auto = Bool()
    clip_symmetric = Bool()
    # NOTE: despite the names, the *_readonly flags are used as the
    # `enabled_when` condition in the view (True => the field is editable).
    clip_readonly = Bool()
    clip_visible = Bool()
    clip_range_readonly = Bool()
    clip_range_visible = Bool()

    def __init__(self, logger, attributes, title=None, **traits):
        """Initialise traits, the base controller, axis ranges and clips.

        ``self.shape`` is read after ``BaseControllerImpl.__init__`` runs
        (presumably set by it from ``attributes``; confirm). Data of rank 2
        and 3 is padded with singleton w (and z) axes so that every axis
        range is defined.

        Raises
        ------
        ValueError
            If the data rank is not 2, 3 or 4.
        """
        HasTraits.__init__(self, **traits)
        BaseControllerImpl.__init__(self, logger=logger, title=title, attributes=attributes)
        self.w_low, self.x_low, self.y_low, self.z_low = 0, 0, 0, 0
        rank = len(self.shape)
        if rank == 2:
            wh = 1
            zh = 1
            xh, yh = self.shape
        elif rank == 3:
            wh = 1
            xh, yh, zh = self.shape
        elif rank == 4:
            wh, xh, yh, zh = self.shape
        else:
            # Previously this fell through to an UnboundLocalError on `wh`.
            raise ValueError("{}: unsupported data rank {} (expected 2, 3 or 4)".format(self.name, rank))
        self.w_high, self.x_high, self.y_high, self.z_high = wh - 1, xh - 1, yh - 1, zh - 1
        self.set_default_clips()
        self.set_axis_mapping()


    ### U t i l i t i e s :
    def create_view(self, view_class, data, title=None):
        """Instantiate ``view_class`` bound to this controller."""
        return view_class(controller=self, data=data, title=title)

    def add_view(self, view_class, data, title=None):
        """Create and register a view on ``data``; return the new view.

        Raises ValueError if ``data.shape`` differs from the controller's shape.
        """
        if data.shape != self.shape:
            raise ValueError("{}: cannot create {} view: data shape {} is not {}".format(self.name, view_class.__name__, data.shape, self.shape))
        local_volume = self.get_local_volume(data)
        view = self.create_view(view_class, local_volume, title=title)
        self.set_view_axis(view)
        self.views.append(view)
        # Keep the full (global) data so the view can be re-sliced later.
        self.views_data[view] = data
        self.update_data_range()
        return view

    def set_view_axis(self, view):
        """Copy every non-slicing global axis index onto the view's local index."""
        for global_axis_name in self.GLOBAL_AXIS_NAMES:
            if global_axis_name != self.slicing_axis:
                local_axis_name = self.get_local_axis_name(global_axis_name)
                setattr(view, "{}_index".format(local_axis_name), getattr(self, "{}_index".format(global_axis_name)))


    def update_data_range(self):
        """Recompute ``data_min``/``data_max`` over all views' data (0.0 if none)."""
        if self.views:
            data_min_l, data_max_l = [], []
            for view in self.views:
                for local_volume in view.data:
                    data_min_l.append(local_volume.min())
                    data_max_l.append(local_volume.max())
            self.data_min = float(min(data_min_l))
            self.data_max = float(max(data_max_l))
            self.logger.info("{}: data range: {} <-> {}".format(self.name, self.data_min, self.data_max))
        else:
            self.data_min = 0.0
            self.data_max = 0.0

    def get_local_volume(self, data):
        """Return the part of ``data`` shown by views.

        For 4D data the slicing axis is fixed at its current index, giving a
        3D volume. Other ranks are returned unchanged.

        Raises ValueError if ``data.shape`` differs from the controller's shape.
        """
        if data.shape != self.shape:
            raise ValueError("{}: invalid shape {}".format(self.name, data.shape))
        if len(self.shape) == 4:
            s = [self.S_ALL, self.S_ALL, self.S_ALL]
            s.insert(self.GLOBAL_AXIS_NUMBERS[self.slicing_axis], getattr(self, '{}_index'.format(self.slicing_axis)))
            # NumPy requires a tuple for multidimensional indexing; a list of
            # slices is rejected (or misinterpreted) by modern NumPy.
            return data[tuple(s)]
        else:
            return data

    def set_axis_mapping(self):
        """Build the local<->global axis-name maps, skipping the slicing axis."""
        self._m_local2global = {}
        self._m_global2local = {}
        local_axis_number = 0
        for global_axis_number, global_axis_name in enumerate(self.GLOBAL_AXIS_NAMES):
            if global_axis_name != self.slicing_axis:
                local_axis_name = self.LOCAL_AXIS_NAMES[local_axis_number]
                self._m_local2global[local_axis_name] = global_axis_name
                self._m_global2local[global_axis_name] = local_axis_name
                local_axis_number += 1
        self.logger.info("{}: local2global: {}".format(self.name, self._m_local2global))
        self.logger.info("{}: global2local: {}".format(self.name, self._m_global2local))

    def get_global_axis_name(self, local_axis_name):
        """Map a local (view) axis name to its global axis name."""
        return self._m_local2global[local_axis_name]

    def get_local_axis_name(self, global_axis_name):
        """Map a global axis name to its local (view) axis name."""
        return self._m_global2local[global_axis_name]

    def set_default_clips(self):
        """Initialise clip traits from ``self.attributes``.

        When not given explicitly, ``clip_symmetric`` defaults to True iff
        ``clip`` is set. ``clip_auto`` defaults to True iff neither ``clip``
        nor both of ``clip_min``/``clip_max`` are set.
        """
        clip_symmetric = self.attributes["clip_symmetric"]
        clip_auto = self.attributes["clip_auto"]
        clip = self.attributes["clip"]
        clip_min = self.attributes["clip_min"]
        clip_max = self.attributes["clip_max"]
        if clip is not None:
            self.clip = clip
        if clip_min is not None:
            self.clip_min = clip_min
        if clip_max is not None:
            self.clip_max = clip_max
        if clip_symmetric is None:
            if clip is not None:
                clip_symmetric = True
            else:
                clip_symmetric = False
        if clip_auto is None:
            if clip is None and (clip_min is None or clip_max is None):
                clip_auto = True
            else:
                clip_auto = False
        self.clip_symmetric = clip_symmetric
        self.clip_auto = clip_auto

    def close_uis(self):
        # locks on exit !
        super(Controller, self).close_uis()

    def close_views(self):
        """Close every view's UIs and forget all views."""
        for view in self.views:
            view.close_uis()
        del self.views[:]

    ### D e f a u l t s :
    def _colorbar_default(self):
        return self.attributes["colorbar"]

    def _colormap_default(self):
        return self.attributes["colormap"]

    def _data_shape_default(self):
        return 'x'.join(str(d) for d in self.shape)

    def _is4D_default(self):
        return len(self.shape) == 4

    def _axis_index_default(self, axis_name):
        """Default index for ``axis_name``: the attribute value if given,
        otherwise ``(high - low) // 2``."""
        if self.attributes.get(axis_name, None) is None:
            h = getattr(self, "{}_high".format(axis_name))
            l = getattr(self, "{}_low".format(axis_name))
            return (h - l) // 2
        else:
            return self.attributes[axis_name]

    def _w_index_default(self):
        return self._axis_index_default('w')

    def _x_index_default(self):
        return self._axis_index_default('x')

    def _y_index_default(self):
        return self._axis_index_default('y')

    def _z_index_default(self):
        return self._axis_index_default('z')

    def _slicing_axis_default(self):
        slicing_axis = self.attributes.get("slicing_axis", None)
        if slicing_axis is None:
            slicing_axis = self.GLOBAL_AXIS_NAMES[0]
        return slicing_axis

    def on_change_axis(self, global_axis_name):
        """React to a change of ``<global_axis_name>_index``.

        Changing the slicing axis re-slices every view's data and refreshes
        the data/clip ranges. Any other axis is forwarded to the views as the
        corresponding local index.
        """
        global_attribute = '{}_index'.format(global_axis_name)
        self.log_trait_change(global_attribute)
        if global_axis_name == self.slicing_axis:
            for view in self.views:
                local_volume = self.get_local_volume(self.views_data[view])
                view.set_volume(local_volume)
            self.update_data_range()
            self.update_clip_range()
        else:
            local_axis_name = self.get_local_axis_name(global_axis_name)
            local_attribute = '{}_index'.format(local_axis_name)
            self.apply_attribute(local_attribute, getattr(self, global_attribute))

    def set_clip(self):
        """Derive clips from the data range when automatic, and update the
        visibility/editability flags used by the view."""
        if self.clip_auto:
            self.clip_min, self.clip_max = self.data_min, self.data_max
            self.clip = max(abs(self.clip_min), abs(self.clip_max))
        self.clip_readonly = not self.clip_auto
        self.clip_range_readonly = not self.clip_auto
        self.clip_visible = self.clip_symmetric
        self.clip_range_visible = not self.clip_symmetric

    def get_clip_range(self):
        """Return the effective ``(clip_min, clip_max)`` pair.

        Automatic mode uses the data range. Symmetric mode returns
        ``(-clip, clip)``.
        """
        if self.clip_auto:
            clip_min = float(self.data_min)
            clip_max = float(self.data_max)
            clip = max(abs(clip_min), abs(clip_max))
        else:
            clip_min = self.clip_min
            clip_max = self.clip_max
            clip = self.clip
        if self.clip_symmetric:
            return -clip, clip
        else:
            return clip_min, clip_max

    def update_clip_range(self):
        """Push the effective clip range to every view."""
        clip_min, clip_max = self.get_clip_range()
        self.logger.info("{}: applying clip {} <-> {}".format(self.name, clip_min, clip_max))
        for view in self.views:
            view.update_clip_range(clip_min, clip_max)

    ### T r a i t s   c h a n g e s :
    @on_trait_change('colorbar')
    def on_change_colorbar(self):
        for view in self.views:
            view.enable_colorbar(self.colorbar)

    @on_trait_change('colormap')
    def on_change_colormap(self):
        for view in self.views:
            view.set_colormap(self.colormap)

    @on_trait_change('data_min,data_max')
    def on_change_data_range(self):
        self.set_clip()
        self.update_clip_range()

    @on_trait_change('clip_symmetric')
    def on_change_clip_symmetric(self):
        self.set_clip()
        self.update_clip_range()

    @on_trait_change('clip_auto')
    def on_change_clip_auto(self):
        self.set_clip()
        self.update_clip_range()

    @on_trait_change('clip')
    def on_change_clip(self):
        """Apply a new symmetric ``clip`` value to the views."""
        self.logger.debug("{}: clip: auto={}, symmetric={}, clip={}".format(self.name, self.clip_auto, self.clip_symmetric, self.clip))
        if self.clip_symmetric:
            self.update_clip_range()

    # Must NOT be named `on_change_clip`: reusing that name rebound the class
    # attribute and silently dropped the 'clip' listener above.
    @on_trait_change('clip_min,clip_max')
    def on_change_clip_range(self):
        """Apply a new asymmetric ``clip_min``/``clip_max`` pair to the views."""
        self.logger.debug("{}: clip: auto={}, symmetric={}, clip_min={}, clip_max={}".format(self.name, self.clip_auto, self.clip_symmetric, self.clip_min, self.clip_max))
        if not self.clip_symmetric:
            self.update_clip_range()

    @on_trait_change('w_index')
    def on_change_w_index(self):
        self.on_change_axis('w')

    @on_trait_change('x_index')
    def on_change_x_index(self):
        self.on_change_axis('x')

    @on_trait_change('y_index')
    def on_change_y_index(self):
        self.on_change_axis('y')

    @on_trait_change('z_index')
    def on_change_z_index(self):
        self.on_change_axis('z')


    controller_view = View(
        HGroup(
            Group(
                Item(
                    'data_shape',
                    label="Shape",
                    style="readonly",
                ),
                Item(
                    'slicing_axis',
                    editor=EnumEditor(
                        values=GLOBAL_AXIS_NAMES

                    ),
                    label="Slicing dim",
                    enabled_when='is4D',
                    visible_when='is4D',
                    #emphasized=True,
                    style="readonly",
                    tooltip="the slicing dimension",
                    help="4D volumes are sliced along the 'slicing dimension'; it is possible to change the value of this dimension using the related slider",
                ),
                Item(
                    '_',
                    enabled_when='is4D',
                    visible_when='is4D',
                ),
                Item(
                    'w_index',
                    editor=RangeEditor(
                        enter_set=True,
                        low_name='w_low',
                        high_name='w_high',
                        format="%d",
                        #label_width=LABEL_WIDTH,
                        mode="auto",
                    ),
                    enabled_when='is4D',
                    visible_when='is4D',
                    tooltip="the w dimension",
                ),
                Item(
                    'x_index',
                    editor=RangeEditor(
                        enter_set=True,
                        low_name='x_low',
                        high_name='x_high',
                        format="%d",
                        #label_width=LABEL_WIDTH,
                        mode="slider",
                    ),
                    format_str="%<8s",
                    tooltip="the x dimension",
                ),
                Item(
                    'y_index',
                    editor=RangeEditor(
                        enter_set=True,
                        low_name='y_low',
                        high_name='y_high',
                        format="%d",
                        #label_width=LABEL_WIDTH,
                        mode="slider",
                    ),
                    format_str="%<8s",
                    tooltip="the y dimension",
                ),
                Item(
                    'z_index',
                    editor=RangeEditor(
                        enter_set=True,
                        low_name='z_low',
                        high_name='z_high',
                        format="%d",
                        #label_width=LABEL_WIDTH,
                        mode="slider",
                    ),
                    format_str="%<8s",
                    tooltip="the z dimension",
                ),
                '_',
                Item(
                    'colorbar',
                    editor=BooleanEditor(
                    ),
                    label="Colorbar",
                ),
                Item(
                    'colormap',
                    editor=EnumEditor(
                        values=COLORMAPS,

                    ),
                    label="Colormap",
                ),
                '_',
                Item(
                    'data_min',
                    label="Data min",
                    style="readonly",
                ),
                Item(
                    'data_max',
                    label="Data max",
                    style="readonly",
                ),
                Item(
                    'clip_auto',
                    editor=BooleanEditor(),
                    label="Automatic",
                    tooltip="makes clip automatic",
                    help="if set, clip is taken from data range"
                ),
                Item(
                    'clip_symmetric',
                    editor=BooleanEditor(),
                    label="Symmetric",
                    tooltip="makes clip symmetric",
                    help="if set, clip_min=-clip, clip_max=+clip",
                ),
                Item(
                    'clip',
                    label="Clip",
                    visible_when='clip_visible',
                    enabled_when='clip_readonly',
                ),
                Item(
                    'clip_min',
                    label="Clip min",
                    visible_when='clip_range_visible',
                    enabled_when='clip_range_readonly',
                ),
                Item(
                    'clip_max',
                    label="Clip max",
                    visible_when='clip_range_visible',
                    enabled_when='clip_range_readonly',
                ),

            ),
        ),
        buttons=[UndoButton, RevertButton, close_button],
        handler=ControllerHandler(),
        resizable=True,
        title="untitled",
    )
class ContributedUI(HasTraits):
    """Custom UI bundled with a specific, premade workflow file.

    Holds a display name and description, the workflow data itself, and the
    Group of Items to show. It exposes ``run_workflow`` / ``update_workflow``
    events that the default view's buttons fire.
    """

    #: Name for the UI in selection screen
    name = Str()

    #: Description of the UI
    desc = Str()

    #: Mapping of plugin id -> version required for this UI
    required_plugins = Dict(Str, Int)

    #: Data for a premade workflow
    workflow_data = Dict()

    #: A Group of Item(s) to show in the UI for this workflow
    workflow_group = Instance(Group)

    #: Event to request a workflow run.
    run_workflow = Event()

    run_workflow_action = Action(
        name="Run Workflow", action="run_workflow"
    )

    #: Event to update a workflow.
    update_workflow = Event()

    update_workflow_action = Action(
        name="Update Workflow", action="update_workflow"
    )

    def default_traits_view(self):
        """Return a View of ``workflow_group`` with Run/Update/Cancel buttons."""
        button_row = [
            self.run_workflow_action,
            self.update_workflow_action,
            'Cancel',
        ]
        return View(self.workflow_group, buttons=button_row)

    def create_workflow(self, factory_registry):
        """Build a Workflow from :attr:`workflow_data`.

        Parameters
        ----------
        factory_registry: IFactoryRegistry
            Registry handed both to the WorkflowReader that parses the data
            and to ``Workflow.from_json``.

        Returns
        -------
        Workflow
            The workflow constructed from the parsed data.
        """
        parsed = WorkflowReader(
            factory_registry=factory_registry
        ).parse_data(self.workflow_data)
        return Workflow.from_json(factory_registry, parsed)

    def _required_plugins_default(self):
        """Collect every "id" in the workflow data into {plugin name: version}.

        When the same plugin name appears more than once, the last occurrence
        wins.
        """
        return {
            plugin_name: plugin_version
            for plugin_name, plugin_version in map(
                parse_id, search(self.workflow_data, "id")
            )
        }
Beispiel #21
0
                       label="nice_name",
                       view=demo_file_view),
        ObjectTreeNode(node_for=[DemoContentFile],
                       label="nice_name",
                       view=demo_content_view),
        ObjectTreeNode(node_for=[DemoImageFile],
                       label="nice_name",
                       view=demo_content_view),
    ],
    selected='selected_node',
)

# Toolbar action that moves the selection to the next node; only enabled
# while the handler's `_next_node` is not None.
next_tool = Action(
    name='Next',
    image=ImageResource("next"),
    tooltip="Go to next file",
    action="do_next",
    enabled_when="_next_node is not None",
)

# Toolbar action that moves the selection to the previous node; only enabled
# while the handler's `_previous_node` is not None.
previous_tool = Action(
    name='Previous',
    image=ImageResource("previous"),
    # Was "Go to next file" (copy-paste from next_tool).
    tooltip="Go to previous file",
    action="do_previous",
    enabled_when="_previous_node is not None",
)

parent_tool = Action(
    name='Parent',
    image=ImageResource("parent"),
Beispiel #22
0
                                    name="Inc_ref",
                                    index_scale="log",
                                )

            except Exception as err:
                error_message = f"The following error occured while processing item: {item_name}:\n \
\t{err}\nThe configuration was NOT saved."

                the_pybert.log(
                    "Exception raised by pybert.pybert_view.MyHandler.do_load_data().",
                    exception=RuntimeError(error_message))


# These are the "globally applicable" buttons referred to in pybert.py,
# just above the button definitions (approx. line 580).
# Each Action's `action` string names the handler method invoked when the
# button is pressed (the do_* methods are not visible in this section).
run_sim = Action(name="Run", action="do_run_simulation")
stop_sim = Action(name="Stop", action="do_stop_simulation")
save_data = Action(name="Save Results", action="do_save_data")
load_data = Action(name="Load Results", action="do_load_data")
save_cfg = Action(name="Save Config.", action="do_save_cfg")
load_cfg = Action(name="Load Config.", action="do_load_cfg")

# Main window layout definition.
traits_view = View(
    Group(
        VGroup(
            HGroup(
                VGroup(
                    HGroup(  # Simulation Control
                        VGroup(
                            Item(
class Visualization(HasTraits):
    """Mayavi viewer for cargo data fed by two background threads.

    A ``CheckCargoThread`` and an ``UpdateSceneThread`` are both constructed
    with the same ``idx`` list (``[curr_idx, next_idx]``). While the scene
    thread is alive ("active update"), the manual navigation buttons refuse
    to act. Once it is paused, the user can step through cargos and
    refresh the scene by hand.
    """
    logo = []
    check_cargo_thread = Instance(CheckCargoThread)
    update_scene_thread = Instance(UpdateSceneThread)
    scene = Instance(MlabSceneModel, ())
    display = Instance(TextDisplay, ())
    log = Instance(LogDisplay, ())

    # Class-level fallbacks only. __init__ and do_restart give every instance
    # its own lists, because idx is mutated in place (idx[0] += 1) and both
    # lists are handed to the worker threads. Sharing them between instances
    # would couple unrelated viewers.
    plots = []
    idx = [-1, 0]                   # idx = [curr_idx, next_idx]

    def __init__(self, display=None):
        """Create the viewer and start both worker threads.

        ``display`` optionally replaces the default TextDisplay.
        """
        # Do not forget to call the parent's __init__
        HasTraits.__init__(self)
        if display:
            self.display = display
        self.idx = [-1, 0]
        self.plots = []
        self._start_check_cargo_thread()
        self._start_update_scene_thread()

    def _start_check_cargo_thread(self):
        """Create and start a new CheckCargoThread bound to ``self.idx``."""
        self.check_cargo_thread = CheckCargoThread(self.idx)
        self.check_cargo_thread.wants_abort = False
        self.check_cargo_thread.start()

    def _start_update_scene_thread(self):
        """Create and start a new UpdateSceneThread wired to this viewer."""
        self.update_scene_thread = UpdateSceneThread(self.idx)
        self.update_scene_thread.wants_abort = False
        self.update_scene_thread.scene = self.scene
        self.update_scene_thread.plots = self.plots
        self.update_scene_thread.display = self.display
        self.update_scene_thread.log = self.log
        self.update_scene_thread.start()

    def _cargo_checker_stopped(self):
        """Return True when a cargo-check thread exists but is no longer running."""
        # Thread.isAlive() was removed in Python 3.9; is_alive() works everywhere.
        return bool(self.check_cargo_thread) and not self.check_cargo_thread.is_alive()

    def _navigation_blocked(self):
        """Return True if a manual navigation/refresh request must be ignored.

        The request is ignored silently when the cargo checker has stopped,
        and with a log message while the active scene update is running.
        """
        if self._cargo_checker_stopped():
            return True
        if self.update_scene_thread.is_alive():
            self.log.update('Still in Active!  Please pause first!\n')
            return True
        return False

    def _step_cargo(self, step):
        """Move ``idx[0]`` by ``step`` (-1 or +1) if it stays in range.

        Returns False (after logging) when already at the first/last cargo.
        The ``<=``/``>=`` comparisons also stop "previous" from walking below
        the initial ``-1`` index.
        """
        if step < 0 and self.idx[0] <= 0:
            self.log.update('Have reached the first cargo!\n')
            return False
        if step > 0 and self.idx[0] >= self.idx[1] - 1:
            self.log.update('Have reached the last cargo!\n')
            return False
        self.idx[0] += step
        return True

    def _log_selected_cargo(self):
        """Log the newly selected cargo (1-based) and the known cargo count."""
        self.log.update("Change to Cargo {} ({}). \nClick 'Refresh' to start plot\n".format(
            self.idx[0] + 1, self.idx[1]))

    def do_start_pause_update(self):
        """Toggle the active (threaded) scene update on or off."""
        if self._cargo_checker_stopped():
            return
        if self.update_scene_thread and self.update_scene_thread.is_alive():
            # Keep the thread's plots so a later restart continues from them.
            self.plots = self.update_scene_thread.plots
            self.update_scene_thread.wants_abort = True
            self.log.update('Active update is off!\n')
        else:
            self._start_update_scene_thread()
            self.log.update('Active update is on!\n')

    def do_stop_all_threading(self):
        """Ask both worker threads to abort."""
        if self.check_cargo_thread:
            self.check_cargo_thread.wants_abort = True
        if self.update_scene_thread:
            self.update_scene_thread.wants_abort = True
        self.log.update('All threading stopped!\n')


    def do_prev_cargo(self):
        """Select the previous cargo; the scene is redrawn only on 'Refresh'."""
        if self._navigation_blocked() or not self._step_cargo(-1):
            return
        self._log_selected_cargo()

    def do_next_cargo(self):
        """Select the next cargo; the scene is redrawn only on 'Refresh'."""
        if self._navigation_blocked() or not self._step_cargo(+1):
            return
        self._log_selected_cargo()

    def do_refresh_scene(self):
        """Redraw the scene once for the currently selected cargo."""
        if self._navigation_blocked():
            return
        self.update_scene_thread.updateOne()

    def do_prev_cargo_show(self):
        """Select the previous cargo and redraw immediately."""
        if self._navigation_blocked() or not self._step_cargo(-1):
            return
        self.do_refresh_scene()

    def do_next_cargo_show(self):
        """Select the next cargo and redraw immediately."""
        if self._navigation_blocked() or not self._step_cargo(+1):
            return
        self.do_refresh_scene()

    def do_restart(self):
        """Stop both threads, reset the state and start fresh threads."""
        self.do_stop_all_threading()
        # Give the old threads a moment to observe wants_abort.
        time.sleep(1)
        self.idx = [-1, 0]
        self.plots = []
        self._start_check_cargo_thread()
        self._start_update_scene_thread()

        self.log.update('Threading restarted!\n')

    def do_close(self):
        self.do_stop_all_threading()


    start_pause_update = Action(name="Start/Pause", action="do_start_pause_update")
    stop_all_threading = Action(name="Terminate", action="do_stop_all_threading")
    prev_cargo = Action(name="Prev Cargo", action="do_prev_cargo")
    next_cargo = Action(name="Next Cargo", action="do_next_cargo")
    refresh_scene = Action(name="Refresh", action="do_refresh_scene")
    prev_cargo_show = Action(name="Show Prev", action="do_prev_cargo_show")
    next_cargo_show = Action(name="Show Next", action="do_next_cargo_show")
    close_button = Action(name="Close", action="do_close")
    restart = Action(name='Restart', action='do_restart')

    view = View(
        Group(
            HGroup(
                Item('logo',
                     editor=ImageEditor(scale=True, image=ImageResource('panasonic-logo-small', search_path=['./res'])),
                     show_label=False),
                Item('logo', editor=ImageEditor(scale=True, image=ImageResource('umsjtu-logo', search_path=['./res'])),
                     show_label=False),
            ),
            HGroup(
                Item('scene', editor=SceneEditor(scene_class=MayaviScene), height=500, width=500, show_label=False),
                Group(
                    Item('log', style='custom', height=60, show_label=False),
                    Item('display', style='custom', height=60, show_label=False),
                ),
            ),
            orientation='vertical'
        ),
        buttons=[refresh_scene, prev_cargo, next_cargo, prev_cargo_show, next_cargo_show, start_pause_update, restart, stop_all_threading],
        )

    def __del__(self):
        # Best effort: ask the worker threads to stop when the viewer dies.
        if self.check_cargo_thread:
            self.check_cargo_thread.wants_abort = True
        if self.update_scene_thread:
            self.update_scene_thread.wants_abort = True
Beispiel #24
0
class MaterialView(ABCView):
    """ Plots index of refraction and dielectric function of a material in real
    time, by listening to delegated traits.

    Important: chose not to have direct update listeners from model.earray.
    Materials still need to physically do "update()" (called mview()),
    but now plots have access to Material for metadata attributes.

    ``earray``, ``narray``, ``interpolation`` and ``extrapolation`` are
    delegated to ``model``. A change of ``earray`` triggers
    ``update_data()``.
    """

    eplot = Instance(Plot)
    nplot = Instance(Plot)

    earray = DelegatesTo('model')
    narray = DelegatesTo('model')

    # Each toggle dispatches to the handler matching its label. These were
    # previously crossed ("Toggle Real" -> togimag, "Toggle Imaginary" -> togreal).
    ToggleReal = Action(name="Toggle Real", action="togreal")
    ToggleImag = Action(name="Toggle Imaginary", action="togimag")

    interpolation = DelegatesTo('model')
    extrapolation = DelegatesTo('model')

    # Custom plot_title disallowed

    view = View(
        VGroup(
            HGroup(
                Item('interpolation',
                     label='Interpolation',
                     visible_when='interpolation is not None'),
                Item('extrapolation',
                     label='Extrapolate',
                     visible_when='interpolation is not None')),
            Tabbed(
                Item('eplot',
                     editor=ComponentEditor(),
                     dock='tab',
                     label='Permittivity'),
                Item('nplot',
                     editor=ComponentEditor(),
                     dock='tab',
                     label='Index'),
                show_labels=False  #Side label not tab label
            ),
        ),
        width=800,
        height=600,
        buttons=['Undo'],
        resizable=True)

    def _earray_changed(self):
        self.update_data()

    def create_plots(self):
        """Create the permittivity (eplot) and index (nplot) plots.

        Each plot shows the real (orange) and imaginary (green) part against
        the 'x' column of ``self.data``, with a tools/title overlay.
        """
        self.eplot = ToolbarPlot(self.data)
        self.nplot = ToolbarPlot(self.data)

        # Name required or will appear twice in legend!
        plot_line_points(self.eplot, ('x', 'er'), color='orange', name='e1')
        plot_line_points(self.eplot, ('x', 'ei'), color='green', name='ie2')
        plot_line_points(self.nplot, ('x', 'nr'), color='orange', name='n')
        plot_line_points(self.nplot, ('x', 'ni'), color='green', name='ik')

        self.add_tools_title(self.eplot, 'Dielectric Function vs. Wavelength')
        self.add_tools_title(self.nplot, 'Index of Refraction vs. Wavelength ')

    def update_data(self):
        """Push the real/imaginary parts of earray and narray into ``self.data``.

        Sets the 'er', 'nr', 'ei' and 'ni' columns only. It does not create
        plots and does not touch the 'x' column (presumably maintained
        elsewhere, e.g. by the base view; confirm).
        """
        self.data.set_data('er', self.earray.real)
        self.data.set_data('nr', self.narray.real)
        self.data.set_data('ei', self.earray.imag)
        self.data.set_data('ni', self.narray.imag)
_INFO = "information"

# Create an empty view and menu for objects that have no data to display:
no_view = View()
no_menu = Menu()

# -------------------
# Actions!
# -------------------

# NOTE(review): this section originally began with "A string to be used as
# the enabled_when argument for the actions", but no such string is defined
# here and none of the Actions below set enabled_when. Confirm whether it was
# removed or lives elsewhere.
# For reference, in the enabled_when expression namespace, handler is
# the WorkflowTree instance, object is the modelview for the selected node.
# Each Action's `action` string names the handler method it invokes.

# MCO Actions
new_mco_action = Action(name='New MCO...', action='new_mco')
delete_mco_action = Action(name='Delete', action='delete_mco')

# Notification Listener Actions
new_notification_listener_action = Action(name='New Notification Listener...',
                                          action='new_notification_listener')
delete_notification_listener_action = Action(
    name='Delete', action='delete_notification_listener')

# Execution Layer Actions
new_layer_action = Action(name="New Layer...", action='new_layer')
delete_layer_action = Action(name='Delete', action='delete_layer')

# DataSource Actions
new_data_source_action = Action(name='New DataSource...',
                                action='new_data_source')
Beispiel #26
0
class ActorViewer(HasTraits):
    """Mayavi viewer for a surface mesh with optional overlay and network.

    The most recently drawn viewer and its mesh are also stored on the class
    itself (``ActorViewer.plot``, ``ActorViewer.v``, ``ActorViewer.f`` and
    ``ActorViewer.y``). NOTE(review): presumably so the loader dialogs
    (GetGifti, LoadOverlay90, LoadNetwork90, ...) can reach the live viewer
    without a reference to it; confirm against those classes.
    """

    # The scene model.
    scene = Instance(MlabSceneModel, ())

    ######################
    # Using 'scene_class=MayaviScene' adds a Mayavi icon to the toolbar,
    # to pop up a dialog editing the pipeline.
    view = View(
        Item(name='scene',
             editor=SceneEditor(scene_class=MayaviScene),
             show_label=False,
             resizable=True,
             height=600,
             width=1000),
        menubar=MenuBar(
            Menu(
                # Each action string names a method defined on this class.
                Action(name="Load Gifti", action="opengifti"),
                Action(name="Inflate Gii", action="inflategii"),
                Action(name="Template", action="DoTemplate"),
                Action(name="Load Overlay", action="loadoverlay"),
                Action(name="Load Network", action="loadnetwork"),
                Separator(),
                CloseAction,
                name="File"), ),
        title="ByBP: AAL90 Brain Plotter",
        resizable=True)

    def __init__(self, **traits):
        HasTraits.__init__(self, **traits)
        # Nothing is drawn at start-up; the template mesh is loaded on demand
        # via File > Template (DoTemplate).

    def DoTemplate(self):
        """Draw the mesh returned by ``template()``."""
        v, f = template()
        self.DoPlot(v, f)

    def DoPlot(self, v, f):
        """Clear the figure and draw the surface mesh ``(v, f)``.

        ``v`` is indexed as ``v[:, 0]``, ``v[:, 1]``, ``v[:, 2]`` (vertex x, y,
        z) and ``f`` is passed to ``triangular_mesh`` as the triangle list.
        The mesh is drawn white at opacity 0.3. The mesh and this viewer are
        remembered on both the instance and the class.

        Returns:
            self
        """
        clf()
        self.pts = self.scene.mlab.triangular_mesh(v[:, 0],
                                                   v[:, 1],
                                                   v[:, 2],
                                                   f,
                                                   color=(1, 1, 1),
                                                   opacity=0.3)
        self.scene.mlab.get_engine().scenes[0].scene.x_plus_view()
        self.scene.mlab.draw()
        self.scene.mlab.view(0., 0.)
        self.v = v
        self.f = f
        # Class-level copies: other code reaches the live viewer through these.
        ActorViewer.v = v
        ActorViewer.f = f
        ActorViewer.plot = self
        return self

    def opengifti(self):
        """Show a GetGifti dialog via ``configure_traits()``."""
        G = GetGifti()
        G.configure_traits()

    def inflategii(self):
        """Show a GetGiftiInflate dialog via ``configure_traits()``."""
        iG = GetGiftiInflate()
        iG.configure_traits()

    def loadoverlay(self):
        """Show a LoadOverlay90 dialog via ``configure_traits()``."""
        o = LoadOverlay90()
        o.configure_traits()

    def alignoverlaydraw(self, o):
        """Draw overlay ``o`` on the stored mesh, coloured by its values.

        ``alignoverlay(self.v, self.f, o)`` is evaluated and column 0 of its
        result is used as per-vertex scalars on a new translucent mesh, with
        a colorbar titled "overlay". The result is also kept as
        ``ActorViewer.y``.

        Args:
            o: Overlay data passed straight to ``alignoverlay``.
                NOTE(review): shape/meaning defined by that helper; confirm.
        """
        y = alignoverlay(self.v, self.f, o)
        v = self.v  # mesh stored by DoPlot
        f = self.f
        ActorViewer.y = y
        opacity = 0.3

        ActorViewer.plot.pts = self.scene.mlab.triangular_mesh(v[:, 0],
                                                               v[:, 1],
                                                               v[:, 2],
                                                               f,
                                                               scalars=y[:, 0],
                                                               opacity=opacity)
        ActorViewer.plot.scene.mlab.get_engine().scenes[0].scene.x_plus_view()
        ActorViewer.plot.scene.mlab.view(0., 0.)
        ActorViewer.plot.scene.mlab.colorbar(title="overlay")
        ActorViewer.plot.scene.mlab.draw()

    def loadnetwork(self):
        """Show a LoadNetwork90 dialog via ``configure_traits()``."""
        n = LoadNetwork90()
        n.configure_traits()

    def PlotNet(self):
        """Draw the network stored on ``self`` as coloured tubes.

        ``self.xx[i]`` and ``self.yy[i]`` are the (x, y, z) end points of edge
        ``i`` and ``self.vv[i]`` its value (``vv`` must support ``.min()`` /
        ``.max()``). Each edge is coloured on a 'jet' scale spanning
        ``vv.min()``..``vv.max()``, both end points get a red sphere, and a
        colorbar titled "Network" is added.
        """
        scene_mlab = ActorViewer.plot.scene.mlab
        xx = self.xx
        yy = self.yy
        vv = self.vv

        # Pass the colormap by name: ``cm.get_cmap`` was deprecated in
        # Matplotlib 3.7 and removed in 3.9, whereas ScalarMappable accepts
        # a registered colormap name on all versions.
        norm = cm.colors.Normalize(vmin=vv.min(), vmax=vv.max())
        scalar_map = cm.ScalarMappable(norm=norm, cmap='jet')

        for start, end, value in zip(xx, yy, vv):
            # mlab wants an RGB triple; drop the alpha channel.
            rgb = scalar_map.to_rgba(value)[:3]
            scene_mlab.plot3d([start[0], end[0]],
                              [start[1], end[1]],
                              [start[2], end[2]],
                              color=rgb,
                              line_width=10,
                              tube_radius=2)
            scene_mlab.points3d(start[0],
                                start[1],
                                start[2],
                                color=(1, 0, 0),
                                scale_factor=5)
            scene_mlab.points3d(end[0],
                                end[1],
                                end[2],
                                color=(1, 0, 0),
                                scale_factor=5)
        scene_mlab.colorbar(title="Network")
Beispiel #27
0
 def _man_temp_button_default(self):
     """Build the default 'Manage templates...' action, bound to do_manage."""
     manage_action = Action(action='do_manage', name='Manage templates...')
     return manage_action
Beispiel #28
0
                                    type="line", color="darkmagenta",  name="Cum_ref", index_scale='log')
                        if(has_both):
                            for (ix, prefix) in [
                                    (1, 'tx'),
                                    (2, 'ctle'),
                                    (3, 'dfe'), ]:
                                item_name = prefix + "_" + suffix + "_ref"
                                container[ix].plot(("f_GHz", item_name),
                                        type="line", color="darkcyan",  name="Inc_ref", index_scale='log')

            except Exception as err:
                print item_name
                err.message = "The following error occured:\n\t{}\nThe waveform data was NOT loaded.".format(err.message)
                the_pybert.handle_error(err)

# Simulation/result actions; each ``action`` string names the method that is
# invoked when the action is triggered.
run_simulation = Action(action="do_run_simulation", name="Run")
save_data = Action(action="do_save_data", name="Save Results")
load_data = Action(action="do_load_data", name="Load Results")
    
# Main window layout definition.
traits_view = View(
    Group(
        VGroup(
            HGroup(
                HGroup(
                    VGroup(
                        Item(name='bit_rate',    label='Bit Rate (Gbps)',  tooltip="bit rate", show_label=True, enabled_when='True',
                            editor=TextEditor(auto_set=False, enter_set=True, evaluate=float)
                        ),
                        Item(name='nbits',       label='Nbits',    tooltip="# of bits to run",
                            editor=TextEditor(auto_set=False, enter_set=True, evaluate=int)
Beispiel #29
0
from __future__ import print_function
from __future__ import absolute_import

from new import instancemethod
from threading import Thread
from time import sleep

from pyface.api import GUI
from traits.api import Event, HasTraits, String
from traitsui.api import Action, Handler, Item, TextEditor, View

from .utils import icon

# Dialog buttons. Every button except Close triggers
# set_execute_callback_true; Close triggers set_execute_callback_false.
auto_survey_button = Action(
    name="Auto Survey", action="set_execute_callback_true", show_label=False)
continue_via_serial_button = Action(
    name="Continue via Serial...", action="set_execute_callback_true",
    show_label=False)
update_button = Action(
    name="Update", action="set_execute_callback_true", show_label=False)
reset_button = Action(
    name="Reset", action="set_execute_callback_true", show_label=False)
close_button = Action(
    name="Close", action="set_execute_callback_false", show_label=False)
ok_button = Action(name="Ok",
                   action="set_execute_callback_true",
Beispiel #30
0
class MainWindow(HasTraits):
    '''The main window for the Beams application.'''

    # Current folder for file dialog
    _current_folder = None

    camera = Instance(Camera)
    id_string = DelegatesTo('camera')
    resolution = DelegatesTo('camera')
    status = Str()
    screen = Instance(CameraImage, args=())
    cmap = DelegatesTo('screen')
    display_frame_rate = Range(1, 60, 15)
    transform_plugins = List(Instance(TransformPlugin))
    display_plugins = List(Instance(DisplayPlugin))
    acquisition_thread = Instance(AcquisitionThread)  # default: None
    processing_thread = Instance(ProcessingThread)  # default: None
    processing_queue = Instance(queue.Queue, kw={'maxsize': MAX_QUEUE_SIZE})
    cameras_dialog = Instance(CameraDialog, args=())

    # Actions
    about = Action(name='&About...',
                   tooltip='About Beams',
                   image=find_icon('about'),
                   action='action_about')
    save = Action(name='&Save Image',
                  accelerator='Ctrl+S',
                  tooltip='Save the current image to a file',
                  image=find_icon('save'),
                  action='action_save')
    quit = Action(name='&Quit',
                  accelerator='Ctrl+Q',
                  tooltip='Exit the application',
                  image=find_icon('quit'),
                  action='_on_close')
    choose_camera = Action(name='Choose &Camera...',
                           tooltip='Choose from a number of camera plugins',
                           action='action_choose_camera')
    take_video = Action(name='Take &Video',
                        style='toggle',
                        tooltip='Start viewing the video feed from the camera',
                        image=find_icon('camera-video'),
                        action='action_take_video')
    take_photo = Action(name='Take &Photo',
                        tooltip='Take one snapshot from the camera',
                        image=find_icon('camera-photo'),
                        action='action_take_photo',
                        enabled_when='self.take_video.checked == False')

    find_resolution = Button()
    view = View(
        VGroup(
            HSplit(
                Tabbed(
                    VGroup(Item('id_string', style='readonly', label='Camera'),
                           Item('resolution',
                                style='readonly',
                                format_str=u'%i \N{multiplication sign} %i'),
                           Group(Item('camera',
                                      show_label=False,
                                      style='custom'),
                                 label='Camera properties',
                                 show_border=True),
                           label='Camera'),
                    VGroup(Item('cmap',
                                label='Color scale',
                                editor=EnumEditor(
                                    values={
                                        None: '0:None (image default)',
                                        gray: '1:Grayscale',
                                        bone: '2:Bone',
                                        pink: '3:Copper',
                                        jet: '4:Rainbow (considered harmful)',
                                        isoluminant: '5:Isoluminant',
                                        awesome: '6:Low-intensity contrast'
                                    })),
                           Item('screen',
                                show_label=False,
                                editor=ColorMapEditor(width=256)),
                           Item('display_frame_rate'),
                           label='Video'),
                    # FIXME: mutable=False means the items can't be deleted,
                    # added, or rearranged, but we do actually want them to
                    # be rearranged.
                    VGroup(Item('transform_plugins',
                                show_label=False,
                                editor=ListEditor(style='custom',
                                                  mutable=False)),
                           label='Transform'),
                    VGroup(Item('display_plugins',
                                show_label=False,
                                editor=ListEditor(style='custom',
                                                  mutable=False)),
                           label='Math')),
                Item('screen',
                     show_label=False,
                     width=640,
                     height=480,
                     style='custom')),
            Item('status', style='readonly', show_label=False)),
        menubar=MenuBar(
            # vertical bar is undocumented but it seems to keep the menu
            # items in the order they were specified in
            Menu('|', save, '_', quit, name='&File'),
            Menu(name='&Edit'),
            Menu(name='&View'),
            Menu('|',
                 choose_camera,
                 '_',
                 take_photo,
                 take_video,
                 name='&Camera'),
            Menu(name='&Math'),
            Menu(about, name='&Help')),
        toolbar=ToolBar('|', save, '_', take_photo, take_video),
        title='Beams',
        resizable=True,
        handler=MainHandler)

    def _find_resolution_fired(self):
        """Forward the find_resolution button to the view's handler."""
        handler = self.view.handler
        return handler.action_find_resolution(None)

    def _display_frame_rate_changed(self, value):
        """Forward a new display frame rate to the processing thread.

        ``processing_thread`` defaults to None and is only created at the end
        of ``__init__``, after ``super().__init__(**traits)`` has assigned any
        constructor keywords. A ``display_frame_rate`` passed to the
        constructor therefore fires this handler before the thread exists.
        That early change is skipped here; ``__init__`` hands the current rate
        to the thread when creating it.

        Args:
            value: The new frame rate.
        """
        if self.processing_thread is None:
            return
        self.processing_thread.update_frequency = value

    def _transform_plugins_default(self):
        """Instantiate the default transform plugins, in order.

        For each name, the module of that name is imported and the class of
        the same name inside it is instantiated with no arguments.
        """
        instances = []
        for plugin_name in ('Rotator', 'BackgroundSubtract'):
            plugin_module = __import__(plugin_name, globals(), locals(),
                                       [plugin_name])
            plugin_class = getattr(plugin_module, plugin_name)
            instances.append(plugin_class())
        return instances

    def _display_plugins_default(self):
        """Instantiate the default display plugins, in order.

        Each plugin class lives in a module of the same name and is
        constructed with ``screen=self.screen``.
        """
        plugin_names = ('BeamProfiler', 'MinMaxDisplay', 'DeltaDetector',
                        'Centroid')
        instances = []
        for plugin_name in plugin_names:
            plugin_module = __import__(plugin_name, globals(), locals(),
                                       [plugin_name])
            plugin_class = getattr(plugin_module, plugin_name)
            instances.append(plugin_class(screen=self.screen))
        return instances

    def __init__(self, **traits):
        super(MainWindow, self).__init__(**traits)

        # Re-apply the camera selection whenever the dialog's 'closed' trait
        # changes, and apply the current selection once right now.
        dialog = self.cameras_dialog
        dialog.on_trait_change(self.on_cameras_response, 'closed')
        self.on_cameras_response()

        # Create the processing thread, store it, then start it.
        worker = ProcessingThread(self, self.processing_queue,
                                  self.display_frame_rate)
        self.processing_thread = worker
        worker.start()

    def on_cameras_response(self):
        """Switch to the camera plugin currently selected in the dialog.

        Called from ``__init__`` and whenever the dialog's ``closed`` trait
        changes. If loading the selected plugin raises ImportError, the user
        is shown an error, the dialog's fallback is selected, and that
        fallback plugin is loaded instead.
        """
        plugin_obj = self.cameras_dialog.get_plugin_object()
        try:
            self.select_plugin(plugin_obj)
        except ImportError:
            # some module was not available, select the dummy
            try:
                plugin_name = plugin_obj['name']
            except (TypeError, KeyError):
                # select_plugin() calls plugin_obj(), so it may be a class or
                # other callable that cannot be subscripted; don't let
                # building the message abort the fallback.
                plugin_name = getattr(plugin_obj, '__name__', plugin_obj)
            error(
                None, 'Loading the {} camera plugin failed. '
                'Taking you back to the dummy plugin.'.format(plugin_name))
            self.cameras_dialog.select_fallback()
            # select_plugin() takes exactly one plugin object, so fetch the
            # fallback the same way as the primary selection above rather
            # than star-unpacking get_plugin_info().
            # NOTE(review): assumes select_fallback() updates what
            # get_plugin_object() returns; confirm in CameraDialog.
            self.select_plugin(self.cameras_dialog.get_plugin_object())

    # Select camera plugin
    def select_plugin(self, plugin_obj):
        # Set up image capturing
        self.camera = plugin_obj()
        try:
            self.camera.open()
        except CameraError:
            error(None,
                  'No camera was detected. Did you forget to plug it in?')
            sys.exit()