Exemple #1
0
class ODMR( ManagedJob, GetSetItemsMixin ):
    """Provides ODMR measurements.

    Steps the microwave source through a list of frequencies (continuous-wave
    or pulsed, selected by ``pulsed``), sums the counter trace of every sweep
    into ``counts`` and keeps the latest ``n_lines`` sweeps in
    ``counts_matrix``.  When ``perform_fit`` is set, the summed trace is
    fitted with a sum of Lorentzians and the result is shown in the line plot.
    """

    # starting and stopping
    keep_data = Bool(False) # helper variable to decide whether to keep existing data
    resubmit_button = Button(label='resubmit', desc='Submits the measurement to the job manager. Tries to keep previously acquired data. Behaves like a normal submit if sequence or time bins have changed since previous run.')
    submitp_button = Button(label='submitp', desc='Start pulsed ODMR')

    # measurement parameters (cw)
    power = Range(low=-100., high=25., value=-5, desc='Power [dBm]', label='Power [dBm]', mode='text', auto_set=False, enter_set=True)
    frequency_begin = Range(low=1,      high=20e9, value=2.85e9,    desc='Start Frequency [Hz]',    label='Begin [Hz]', editor=TextEditor(auto_set=False, enter_set=True, evaluate=float, format_str='%e'))
    frequency_end   = Range(low=1,      high=20e9, value=2.88e9,    desc='Stop Frequency [Hz]',     label='End [Hz]', editor=TextEditor(auto_set=False, enter_set=True, evaluate=float, format_str='%e'))
    frequency_delta = Range(low=-20e9,   high=20e9, value=1e6,       desc='frequency step [Hz]',     label='Delta [Hz]', editor=TextEditor(auto_set=False, enter_set=True, evaluate=float, format_str='%2.2e'))

    # measurement parameters (pulsed), same meaning as the cw set above
    powerp = Range(low=-100., high=25., value=-20, desc='Power [dBm]', label='Power [dBm]', mode='text', auto_set=False, enter_set=True)
    frequency_beginp = Range(low=1,      high=20e9, value=2.85e9,    desc='Start Frequency [Hz]',    label='Begin [Hz]', editor=TextEditor(auto_set=False, enter_set=True, evaluate=float, format_str='%e'))
    frequency_endp   = Range(low=1,      high=20e9, value=2.88e9,    desc='Stop Frequency [Hz]',     label='End [Hz]', editor=TextEditor(auto_set=False, enter_set=True, evaluate=float, format_str='%e'))
    frequency_deltap = Range(low=-20e9,   high=20e9, value=1e5,       desc='frequency step [Hz]',     label='Delta [Hz]', editor=TextEditor(auto_set=False, enter_set=True, evaluate=float, format_str='%2.2e'))

    t_pi  = Range(low=1., high=100000., value=1000., desc='length of pi pulse [ns]', label='pi [ns]', mode='text', auto_set=False, enter_set=True)
    laser = Range(low=1., high=10000., value=300., desc='laser [ns]', label='laser [ns]', mode='text', auto_set=False, enter_set=True)
    wait  = Range(low=1., high=10000., value=1000., desc='wait [ns]', label='wait [ns]', mode='text', auto_set=False, enter_set=True)
    pulsed = Bool(False, label='pulsed')
    seconds_per_point = Range(low=1e-4, high=1, value=20e-3, desc='Seconds per point', label='Seconds per point', mode='text', auto_set=False, enter_set=True)
    stop_time = Range(low=1., value=np.inf, desc='Time after which the experiment stops by itself [s]', label='Stop time [s]', mode='text', auto_set=False, enter_set=True)
    n_lines = Range (low=1, high=10000, value=50, desc='Number of lines in Matrix', label='Matrix lines', mode='text', auto_set=False, enter_set=True)

    # control data fitting
    perform_fit = Bool(False, label='perform fit')
    number_of_resonances = Trait( 'auto', String('auto', auto_set=False, enter_set=True), Int(10000., desc='Number of Lorentzians used in fit', label='N', auto_set=False, enter_set=True))
    threshold = Range(low=-99, high=99., value=-50., desc='Threshold for detection of resonances [%]. The sign of the threshold specifies whether the resonances are negative or positive.', label='threshold [%]', mode='text', auto_set=False, enter_set=True)

    # fit result
    fit_parameters = Array(value=np.array((np.nan, np.nan, np.nan, np.nan)))
    fit_frequencies = Array(value=np.array((np.nan,)), label='frequency [Hz]')
    fit_line_width = Array(value=np.array((np.nan,)), label='line_width [Hz]')
    fit_contrast = Array(value=np.array((np.nan,)), label='contrast [%]')

    # measurement data
    frequency = Array()
    counts = Array()
    counts_matrix = Array()
    run_time = Float(value=0.0, desc='Run time [s]', label='Run time [s]')

    # plotting
    line_label  = Instance( PlotLabel )
    line_data   = Instance( ArrayPlotData )
    matrix_data = Instance( ArrayPlotData )
    line_plot   = Instance( Plot, editor=ComponentEditor() )
    matrix_plot = Instance( Plot, editor=ComponentEditor() )

    def __init__(self):
        """Create the plots and wire data/fit traits to their plot updaters."""
        super(ODMR, self).__init__()
        self._create_line_plot()
        self._create_matrix_plot()
        self.on_trait_change(self._update_line_data_index,      'frequency',            dispatch='ui')
        self.on_trait_change(self._update_line_data_value,      'counts',               dispatch='ui')
        self.on_trait_change(self._update_line_data_fit,        'fit_parameters',       dispatch='ui')
        self.on_trait_change(self._update_matrix_data_value,    'counts_matrix',        dispatch='ui')
        self.on_trait_change(self._update_matrix_data_index,    'n_lines,frequency',    dispatch='ui')
        self.on_trait_change(self._update_fit, 'counts,perform_fit,number_of_resonances,threshold', dispatch='ui')

    # helpers

    def _frequency_axis(self, begin, end, delta):
        """Return the sweep frequencies for the given begin, end and step.

        The stop value handed to np.arange is end+delta so that `end` itself
        can be part of the sweep.
        """
        return np.arange(begin, end+delta, delta)

    def _contrast_from_parameters(self, p):
        """Return the contrast [%] of each Lorentzian described by `p`.

        `p` is a flat array (c, x_1, g_1, a_1, x_2, g_2, a_2, ...): one
        leading value `c` followed by one triple per resonance, of which the
        second entry `g` and third entry `a` enter the contrast.
        """
        # floor division: the count is used as an array size / reshape
        # argument and must stay an integer under true division as well
        n = len(p)//3
        c = p[0]
        contrast = np.empty(n)
        for i, (_, g, a) in enumerate(p[1:].reshape((n,3))):
            A = np.abs(a/(np.pi*g))
            if a > 0:
                contrast[i] = 100*A/(A+c)
            else:
                contrast[i] = 100*A/c
        return contrast

    # trait defaults

    def _counts_matrix_default(self):
        return np.zeros( (self.n_lines, len(self.frequency)) )

    def _frequency_default(self):
        return self._frequency_axis(self.frequency_begin, self.frequency_end, self.frequency_delta)

    def _counts_default(self):
        return np.zeros(self.frequency.shape)

    # data acquisition

    def apply_parameters(self):
        """Apply the current parameters and decide whether to keep previous data.

        Programs the microwave source with the power and frequency list of
        the selected mode (cw or pulsed).  Accumulated data is cleared when
        `keep_data` is False or the frequency list differs from the one the
        data was taken with.
        """
        if self.pulsed:
            power = self.powerp
            frequency = self._frequency_axis(self.frequency_beginp, self.frequency_endp, self.frequency_deltap)
        else:
            power = self.power
            frequency = self._frequency_axis(self.frequency_begin, self.frequency_end, self.frequency_delta)

        ha.Microwave().setPower(power)
        # Program the sweep with the axis computed from the current
        # parameters; self.frequency may still hold the previous run's axis.
        ha.Microwave().initSweep( frequency, power*np.ones(frequency.shape))

        # array_equal also copes with axes of different length
        if not self.keep_data or not np.array_equal(frequency, self.frequency):
            self.frequency = frequency
            self.counts = np.zeros(frequency.shape)
            self.run_time = 0.0

        self.keep_data = True # when job manager stops and starts the job, data should be kept. Only new submission should clear data.

    def _run(self):
        """Acquire sweeps until a stop is requested or `stop_time` is reached.

        `state` ends up as 'idle' when the loop is left before `stop_time`,
        'done' once `stop_time` is reached, and 'error' if anything raises
        (the exception is logged, not propagated).
        """
        try:
            self.state='run'
            self.apply_parameters()

            if self.run_time >= self.stop_time:
                self.state='done'
                return

            # if pulsed, turn on sequence
            if self.pulsed:
                ha.PulseGenerator().Sequence( 100 * [ (['laser','aom'],self.laser), ([],self.wait), (['mw'],self.t_pi) ] )
            else:
                ha.PulseGenerator().Open()

            n = len(self.frequency)

            ha.Counter().configure(n, self.seconds_per_point, DutyCycle=0.8)
            time.sleep(0.5)

            while self.run_time < self.stop_time:
                start_time = time.time()
                if threading.currentThread().stop_request.isSet():
                    break
                ha.Microwave().resetListPos()
                counts = ha.Counter().run()
                self.run_time += time.time() - start_time
                self.counts += counts
                # newest sweep on top, oldest line dropped
                self.counts_matrix = np.vstack( (counts, self.counts_matrix[:-1,:]) )
                # counts was modified in place, so listeners are notified explicitly
                self.trait_property_changed('counts', self.counts)

            if self.run_time < self.stop_time:
                self.state = 'idle'
            else:
                self.state='done'
            ha.Microwave().setOutput( None, self.frequency_begin)
            ha.PulseGenerator().Light()
            ha.Counter().clear()
        except Exception:
            logging.getLogger().exception('Error in odmr.')
            self.state = 'error'
        finally:
            ha.Microwave().setOutput( None, self.frequency_begin)

    # fitting
    def _update_fit(self):
        """Refit `counts` when `perform_fit` is set and publish the fit results.

        Without a fit, or when the fit fails, all results are NaN.
        """
        p = np.nan*np.empty(4)
        if self.perform_fit:
            N = self.number_of_resonances
            if N != 'auto':
                N = int(N)
            try:
                p = fitting.fit_multiple_lorentzians(self.frequency,self.counts,N,threshold=self.threshold*0.01)
            except Exception:
                logging.getLogger().debug('ODMR fit failed.', exc_info=True)
                p = np.nan*np.empty(4)
        self.fit_parameters = p
        self.fit_frequencies = p[1::3]
        self.fit_line_width = p[2::3]
        self.fit_contrast = self._contrast_from_parameters(p)


    # plotting

    def _create_line_plot(self):
        """Build the counts-vs-frequency plot and its text label."""
        line_data = ArrayPlotData(frequency=np.array((0.,1.)), counts=np.array((0.,0.)), fit=np.array((0.,0.)))
        line_plot = Plot(line_data, padding=8, padding_left=64, padding_bottom=32)
        line_plot.plot(('frequency','counts'), style='line', color='blue')
        line_plot.index_axis.title = 'Frequency [MHz]'
        line_plot.value_axis.title = 'Fluorescence counts'
        line_label = PlotLabel(text='', hjustify='left', vjustify='bottom', position=[64,128])
        line_plot.overlays.append(line_label)
        self.line_label = line_label
        self.line_data = line_data
        self.line_plot = line_plot

    def _create_matrix_plot(self):
        """Build the image plot that shows the most recent sweeps line by line."""
        matrix_data = ArrayPlotData(image=np.zeros((2,2)))
        matrix_plot = Plot(matrix_data, padding=8, padding_left=64, padding_bottom=32)
        matrix_plot.index_axis.title = 'Frequency [MHz]'
        matrix_plot.value_axis.title = 'line #'
        # x bounds in MHz, matching the axis title and _update_matrix_data_index
        matrix_plot.img_plot('image',
                             xbounds=(self.frequency[0]*1e-6,self.frequency[-1]*1e-6),
                             ybounds=(0,self.n_lines),
                             colormap=Spectral)
        self.matrix_data = matrix_data
        self.matrix_plot = matrix_plot

    def _perform_fit_changed(self,new):
        """Show or hide the fit curve and the fit label."""
        plot = self.line_plot
        if new:
            plot.plot(('frequency','fit'), style='line', color='red', name='fit')
            self.line_label.visible=True
        else:
            plot.delplot('fit')
            self.line_label.visible=False
        plot.request_redraw()

    def _update_line_data_index(self):
        """Push the frequency axis (in MHz) to the line plot and reset the matrix."""
        self.line_data.set_data('frequency', self.frequency*1e-6)
        self.counts_matrix = self._counts_matrix_default()

    def _update_line_data_value(self):
        self.line_data.set_data('counts', self.counts)

    def _update_line_data_fit(self):
        """Redraw the fit curve and rewrite the label text; skipped while the fit is NaN."""
        if not np.isnan(self.fit_parameters[0]):
            self.line_data.set_data('fit', fitting.NLorentzians(*self.fit_parameters)(self.frequency))
            p = self.fit_parameters
            f = p[1::3]
            w = p[2::3]
            contrast = self._contrast_from_parameters(p)
            lines = ['f %i: %.6e Hz, HWHM %.3e Hz, contrast %.1f%%\n'%(i+1, fi, w[i], contrast[i])
                     for i, fi in enumerate(f)]
            self.line_label.text = ''.join(lines)

    def _update_matrix_data_value(self):
        self.matrix_data.set_data('image', self.counts_matrix)

    def _update_matrix_data_index(self):
        """Resize the matrix to `n_lines` rows and update the image bounds."""
        if self.n_lines > self.counts_matrix.shape[0]:
            # pad with empty lines at the bottom
            self.counts_matrix = np.vstack( (self.counts_matrix,np.zeros((self.n_lines-self.counts_matrix.shape[0], self.counts_matrix.shape[1]))) )
        else:
            self.counts_matrix = self.counts_matrix[:self.n_lines]
        self.matrix_plot.components[0].index.set_data((self.frequency[0]*1e-6, self.frequency[-1]*1e-6),(0.0,float(self.n_lines)))

    # saving data

    def save_line_plot(self, filename):
        self.save_figure(self.line_plot, filename)

    def save_matrix_plot(self, filename):
        self.save_figure(self.matrix_plot, filename)

    def save_all(self, filename):
        """Save both plots as .png and the data as .pys, using `filename` as prefix."""
        self.save_line_plot(filename+'_ODMR_Line_Plot.png')
        self.save_matrix_plot(filename+'_ODMR_Matrix_Plot.png')
        self.save(filename+'_ODMR.pys')

    # react to GUI events

    def submit(self):
        """Start cw ODMR"""
        self.keep_data = False
        ManagedJob.submit(self)

    def submitp(self):
        """Start pulsed ODMR"""
        # NOTE(review): identical to submit(); `pulsed` is only set by the
        # button handlers below, not here - confirm this is intended.
        self.keep_data = False
        ManagedJob.submit(self)

    def resubmit(self):
        """Submit the job to the JobManager."""
        self.keep_data = True
        ManagedJob.submit(self)

    def _resubmit_button_fired(self):
        """React to start button. Submit the Job."""
        self.resubmit()

    def _submit_button_fired(self):
        self.pulsed=False
        super(ODMR,self)._submit_button_fired()

    def _submitp_button_fired(self):
        self.pulsed=True
        super(ODMR,self)._submit_button_fired()

    traits_view = View(VGroup(HGroup(Item('remove_button',   show_label=False),
                                     Item('resubmit_button', show_label=False),
                                     Item('priority', enabled_when='state != "run"'),
                                     Item('state', style='readonly'),
                                     Item('run_time', style='readonly',format_str='%.f'),
                                     Item('stop_time'),
                                     Item('pulsed', style='readonly'),
                                     ),
                              VGroup(HGroup(Item('submit_button',   show_label=False),
                                            Item('power', width=-40, enabled_when='state != "run"'),
                                            Item('frequency_begin', width=-80, enabled_when='state != "run"'),
                                            Item('frequency_end', width=-80, enabled_when='state != "run"'),
                                            Item('frequency_delta', width=-80, enabled_when='state != "run"'),
                                            ),
                                     HGroup(Item('submitp_button',   show_label=False),
                                            #Item('pulsed', style='readonly'),
                                            Item('t_pi', width=-50, enabled_when='state != "run"'),
                                            Item('powerp', width=-40, enabled_when='state != "run"'),
                                            Item('frequency_beginp', width=-80, enabled_when='state != "run"'),
                                            Item('frequency_endp', width=-80, enabled_when='state != "run"'),
                                            Item('frequency_deltap', width=-80, enabled_when='state != "run"'),
                                            ),
                                     HGroup(Item('seconds_per_point', width=-40, enabled_when='state != "run"'),
                                            Item('laser', width=-50, enabled_when='state != "run"'),
                                            Item('wait', width=-50, enabled_when='state != "run"'),
                                            ),
                                     HGroup(Item('perform_fit'),
                                            Item('number_of_resonances', width=-60),
                                            Item('threshold', width=-60),
                                            Item('n_lines', width=-60),
                                            ),
                                     HGroup(Item('fit_contrast', style='readonly'),
                                            Item('fit_line_width', style='readonly'),
                                            Item('fit_frequencies', style='readonly'),
                                            ),
                                     ),
                              VSplit(Item('matrix_plot', show_label=False, resizable=True),
                                     Item('line_plot', show_label=False, resizable=True),
                                     ),
                              ),
                       menubar = MenuBar(Menu(Action(action='saveLinePlot', name='SaveLinePlot (.png)'),
                                              Action(action='saveMatrixPlot', name='SaveMatrixPlot (.png)'),
                                              Action(action='save', name='Save (.pyd or .pys)'),
                                              Action(action='saveAll', name='Save All (.png+.pys)'),
                                              Action(action='export', name='Export as Ascii (.asc)'),
                                              Action(action='load', name='Load'),
                                              Action(action='_on_close', name='Quit'),
                                              name='File')),
                       title='ODMR', width=900, height=800, buttons=[], resizable=True, handler=ODMRHandler
                       )

    get_set_items = ['frequency', 'counts', 'counts_matrix',
                     'fit_parameters', 'fit_contrast', 'fit_line_width', 'fit_frequencies',
                     'perform_fit', 'run_time',
                     'power', 'frequency_begin', 'frequency_end', 'frequency_delta',
                     'powerp', 'frequency_beginp', 'frequency_endp', 'frequency_deltap',
                     'laser', 'wait', 'pulsed', 't_pi',
                     'seconds_per_point', 'stop_time', 'n_lines',
                     'number_of_resonances', 'threshold',
                     '__doc__']
Exemple #2
0
class HexYoshiFormingProcess(HasTraits):
    '''
    Forming process on a fixed crease pattern: provides the factory task
    that builds the pattern and the simulation tasks that prescribe the
    boundary conditions and the solver configuration for it.
    '''

    u_max = Float(0.1, auto_set=False, enter_set=True, input=True)
    n_fold_steps = Int(30, auto_set=False, enter_set=True, input=True)
    n_load_steps = Int(30, auto_set=False, enter_set=True, input=True)

    # Factory task generating the crease pattern.
    factory_task = Property(Instance(FormingTask))

    @cached_property
    def _get_factory_task(self):
        # horizontal / vertical spacing of the node grid
        h = 2.0
        v = 4.0

        # nodes 0-6: centre node followed by the six nodes around it
        core_nodes = [
            [0, 0, 0],
            [-h, 0, 0],
            [-h / 2.0, -v, 0],
            [h / 2.0, -v, 0],
            [h / 2.0, 0, 0],
            [h / 2.0, v, 0],
            [-h / 2.0, v, 0],
        ]
        # nodes 7-12: extension of the pattern towards negative x
        left_nodes = [
            [-2 * h, 0, 0],
            [-3. / 2. * h, -v, 0],
            [-3. / 2. * h, v, 0],
            [-5. / 2. * h, -v, 0],
            [-5. / 2. * h, v, 0],
            [-5. / 2. * h, 0, 0],
        ]

        core_lines = [
            [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6],
            [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 1],
        ]
        left_lines = [
            [1, 7], [1, 8], [1, 9], [7, 8], [7, 9], [2, 8], [6, 9],
            [7, 10], [7, 11], [8, 10], [9, 11], [7, 12], [10, 12], [11, 12],
        ]

        core_facets = [
            [0, 1, 2], [0, 2, 3], [0, 3, 4],
            [0, 4, 5], [0, 5, 6], [0, 6, 1],
        ]
        left_facets = [
            [1, 7, 8], [1, 7, 9], [1, 8, 2], [1, 9, 6],
            [7, 8, 10], [7, 9, 11], [7, 12, 10], [7, 12, 11],
        ]

        pattern = CreasePatternState(X=core_nodes + left_nodes,
                                     L=core_lines + left_lines,
                                     F=core_facets + left_facets)
        return CustomCPFactory(formed_object=pattern)

    psi_lines = List([6, 13, 11, 2, 4, 19, 20])

    psi_max = Float(-np.pi * 0.54)

    # Simulation task: psi of every line in psi_lines is ramped linearly
    # with t up to psi_max, with a few nodal dofs fixed.
    fold_gravity_angle_cntl = Property(Instance(FormingTask))

    @cached_property
    def _get_fold_gravity_angle_cntl(self):
        # fixed dofs: x of node 8, y of nodes 8 and 2, z of nodes 8, 2 and 9
        # (the dof argument is passed as a plain int)
        x_fixed = fix([8], 0)
        y_fixed = fix([8, 2], 1)
        z_fixed = fix([8, 2, 9], 2)
        dof_constraints = x_fixed + z_fixed + y_fixed

        def linear_ramp(psi_target):
            # factory so that every constraint gets its own bound target
            return lambda t: psi_target * t

        psi_constraints = []
        for line in self.psi_lines:
            psi_constraints.append(([(line, 1.0)],
                                    linear_ramp(self.psi_max)))

        gu_psi = GuPsiConstraints(forming_task=self.factory_task,
                                  psi_constraints=psi_constraints)
        gu_dof = GuDofConstraints(dof_constraints=dof_constraints)
        gu_length = GuConstantLength()

        config = SimulationConfig(
            goal_function_type='none',
            gu={'cl': gu_length, 'u': gu_dof, 'psi': gu_psi},
            acc=1e-5,
            MAX_ITER=500,
            debug_level=0,
        )

        task = SimulationTask(previous_task=self.factory_task,
                              config=config,
                              n_steps=self.n_fold_steps)

        # shift the z displacement of selected nodes down / up by 0.5
        formed = task.formed_object
        formed.u[(1, 3, 5, 10, 11), 2] -= 0.5
        formed.u[(0, 4, 7, 12), 2] += 0.5

        return task

    # Simulation task: psi of line 1 is ramped linearly with t, with a
    # small set of fixed dofs.
    fold_angle_cntl = Property(Instance(FormingTask))

    @cached_property
    def _get_fold_angle_cntl(self):

        # Link the crease factory with the constraint client
        gu_length = GuConstantLength()

        psi_limit = np.pi * 0.3
        single_line_ramp = [([(1, 1.0)], lambda t: -psi_limit * t)]
        gu_psi = GuPsiConstraints(forming_task=self.factory_task,
                                  psi_constraints=single_line_ramp)

        dof_constraints = (fix([2], [0, 1, 2]) + fix([0], [1, 2])
                           + fix([3], [2]))
        gu_dof = GuDofConstraints(dof_constraints=dof_constraints)

        config = SimulationConfig(
            goal_function_type='none',
            gu={'cl': gu_length, 'u': gu_dof, 'psi': gu_psi},
            acc=1e-5,
            MAX_ITER=100,
        )
        return SimulationTask(previous_task=self.factory_task,
                              config=config,
                              n_steps=5)
Exemple #3
0
class BetterSelectingZoom(AbstractOverlay, BetterZoom):
    """ Zooming tool which allows the user to draw a box which defines the
        desired region to zoom in to
    """

    #: The selection mode:
    #:
    #: range:
    #:   Select a range across a single index or value axis.
    #: box:
    #:   Perform a "box" selection on two axes.
    tool_mode = Enum("box", "range")

    #: Is the tool always "on"? If True, left-clicking always initiates
    #: a zoom operation; if False, the user must press a key to enter zoom mode.
    always_on = Bool(False)

    #: Defines a meta-key, that works with always_on to set the zoom mode. This
    #: is useful when the zoom tool is used in conjunction with the pan tool.
    always_on_modifier = Enum('control', 'shift', 'alt')

    #: The mouse button that initiates the drag.  If "None", then the tool
    #: will not respond to drag.  (It can still respond to mousewheel events.)
    drag_button = Enum("left", "right", None)

    #: The minimum amount of screen space the user must select in order for
    #: the tool to actually take effect.
    minimum_screen_delta = Int(10)

    #-------------------------------------------------------------------------
    # deprecated interaction controls, used for API compatibility with
    # SimpleZoom
    #-------------------------------------------------------------------------

    #: Conversion ratio from wheel steps to zoom factors.
    wheel_zoom_step = Property(Float, depends_on='zoom_factor')

    #: The key press to enter zoom mode, if **always_on** is False.  Has no effect
    #: if **always_on** is True.
    enter_zoom_key = Instance(KeySpec, args=("z", ))

    #: The key press to leave zoom mode, if **always_on** is False.  Has no effect
    #: if **always_on** is True.
    exit_zoom_key = Instance(KeySpec, args=("z", ))

    #: Disable the tool after the zoom is completed?
    disable_on_complete = Property()

    #-------------------------------------------------------------------------
    # Appearance properties (for Box mode)
    #-------------------------------------------------------------------------

    #: The pointer to use when drawing a zoom box.
    pointer = "magnifier"

    #: The color of the selection box.
    color = ColorTrait("lightskyblue")

    #: The alpha value to apply to **color** when filling in the selection
    #: region.  Because it is almost certainly useless to have an opaque zoom
    #: rectangle, but it's also extremely useful to be able to use the normal
    #: named colors from Enable, this attribute allows the specification of a
    #: separate alpha value that replaces the alpha value of **color** at draw
    #: time.
    alpha = Trait(0.4, None, Float)

    #: The color of the outside selection rectangle.
    border_color = ColorTrait("dodgerblue")

    #: The thickness of selection rectangle border.
    border_size = Int(1)

    #: The possible event states of this zoom tool.
    event_state = Enum("normal", "selecting", "pre_selecting")

    # The (x,y) screen point where the mouse went down.
    _screen_start = Trait(None, None, Tuple)

    # The (x,y) screen point of the last seen mouse move event.
    _screen_end = Trait(None, None, Tuple)

    # If **always_on** is False, this attribute indicates whether the tool
    # is currently enabled.
    _enabled = Bool(False)

    #-------------------------------------------------------------------------
    # Private traits
    #-------------------------------------------------------------------------

    # the original numerical screen ranges
    _orig_low_setting = Tuple
    _orig_high_setting = Tuple

    def __init__(self, component=None, *args, **kw):
        """ Initialises both parent classes and records the original
        low/high settings of the x and y mapper ranges.
        """
        # With two parents, each constructor is called explicitly so the
        # initialisation order is unambiguous.
        AbstractOverlay.__init__(self, component, *args, **kw)
        BetterZoom.__init__(self, component, *args, **kw)

        # Remember the range settings in effect at construction time
        horizontal = self._get_x_mapper().range
        vertical = self._get_y_mapper().range
        self._orig_low_setting = (horizontal.low_setting,
                                  vertical.low_setting)
        self._orig_high_setting = (horizontal.high_setting,
                                   vertical.high_setting)

    def reset(self, event=None):
        """ Puts the tool back into the 'normal' state and forgets the
        start and end points of any selection.  *event* is not used.
        """
        self.event_state = "normal"
        self._screen_start = self._screen_end = None

    #--------------------------------------------------------------------------
    #  BetterZoom interface
    #--------------------------------------------------------------------------

    def normal_key_pressed(self, event):
        """ Handles a key being pressed when the tool is in the 'normal'
        state.

        Unless the tool is **always_on**, the enter-zoom key arms the tool
        (switches to 'pre_selecting', sets the pointer and takes mouse
        ownership) and the exit-zoom key ends the selection.  Events that
        were not handled here are passed on to the parent class.
        """
        if not self.always_on:
            if self.enter_zoom_key.match(event) and not self._enabled:
                self.event_state = 'pre_selecting'
                event.window.set_pointer(self.pointer)
                event.window.set_mouse_owner(self, event.net_transform())
                self._enabled = True
                event.handled = True
            elif self.exit_zoom_key.match(event) and self._enabled:
                # The tool state lives in the ``event_state`` trait; there
                # is no ``state`` trait declared on this class.
                self.event_state = 'normal'
                # NOTE(review): the other key handlers call
                # _end_selecting() here - confirm _end_select() is intended.
                self._end_select(event)
                event.handled = True

        if not event.handled:
            super(BetterSelectingZoom, self).normal_key_pressed(event)

    def normal_left_down(self, event):
        """ Left mouse button pressed while the tool is in the 'normal'
        state.

        Starts a selection and marks the event as handled, provided the
        event qualifies as an enabling event for this tool.
        """
        if not self._is_enabling_event(event):
            return
        self._start_select(event)
        event.handled = True

    def normal_right_down(self, event):
        """ Right mouse button pressed while the tool is in the 'normal'
        state.

        Starts a selection and marks the event as handled, provided the
        event qualifies as an enabling event for this tool.
        """
        if not self._is_enabling_event(event):
            return
        self._start_select(event)
        event.handled = True

    def pre_selecting_left_down(self, event):
        """ Left click while in the 'pre_selecting' state (zoom mode was
            switched on by key): begin the selection.
        """
        self._start_select(event)
        # consume the click
        event.handled = True

    def pre_selecting_key_pressed(self, event):
        """ Key press in the 'pre_selecting' state; only the exit zoom key
            is acted upon, and only while the tool is enabled.
        """
        if not self.exit_zoom_key.match(event) or not self._enabled:
            return
        self._end_selecting(event)

    def selecting_key_pressed(self, event):
        """ Key press in the 'selecting' state; only the exit zoom key is
            acted upon, and only while the tool is enabled.
        """
        if not self.exit_zoom_key.match(event) or not self._enabled:
            return
        self._end_selecting(event)

    def selecting_mouse_move(self, event):
        """ Handles the mouse moving when the tool is in the 'selecting' state.

        The selection is extended to the current mouse position.  In box
        mode with an aspect ratio set, the end point is pulled in so that
        the box keeps that aspect ratio.
        """
        # Take into account when we update the current endpoint, but only
        # if we are in box select mode.  The way we handle aspect ratio
        # is to find the largest rectangle of the specified aspect which
        # will fit within the rectangle defined by the start coords and
        # the current mouse position.
        if self.tool_mode == "box" and self.aspect_ratio is not None:
            x1, y1 = self._screen_start
            x2, y2 = event.x, event.y
            if (y2 - y1) == 0:
                # zero height: no box of the requested aspect fits
                x2 = x1
                y2 = y1
            else:
                width = abs(x2 - x1)
                height = abs(y2 - y1)
                # float() keeps this a true division even when both screen
                # coordinates are ints under Python 2
                drawn_aspect = float(width) / height
                if drawn_aspect > self.aspect_ratio:
                    # Drawn box is wider, so use its height to compute the
                    # restricted width
                    x_sign = 1 if x2 > x1 else -1
                    x2 = x1 + height * self.aspect_ratio * x_sign
                else:
                    # Drawn box is taller, so use its width to compute the
                    # restricted height
                    y_sign = 1 if y2 > y1 else -1
                    y2 = y1 + width / self.aspect_ratio * y_sign
            self._screen_end = (x2, y2)
        else:
            self._screen_end = (event.x, event.y)
        self.component.request_redraw()
        event.handled = True
        return

    def selecting_left_up(self, event):
        """ Handles the left mouse button being released when the tool is in
        the 'selecting' state.

        Completes the selection (and zoom) unless the tool is configured to
        drag with a different button.
        """
        if self.drag_button not in ("left", None):
            return
        self._end_select(event)
        return

    def selecting_right_up(self, event):
        """ Handles the right mouse button being released when the tool is in
        the 'selecting' state.

        Completes the selection (and zoom) only when the right button is the
        configured drag button.
        """
        if self.drag_button != "right":
            return
        self._end_select(event)
        return

    def selecting_mouse_leave(self, event):
        """ Handles the mouse leaving the plot while the tool is in the
        'selecting' state.

        The selection is abandoned; no zoom is performed.
        """
        self._end_selecting(event)
        return

    #--------------------------------------------------------------------------
    #  AbstractOverlay interface
    #--------------------------------------------------------------------------

    def overlay(self, component, gc, view_bounds=None, mode="normal"):
        """ Draws the selection region on top of another component.

        Overrides AbstractOverlay.  Nothing is drawn unless a selection is
        in progress; the shape drawn depends on the tool mode.
        """
        if self.event_state != "selecting":
            return
        if self.tool_mode == "range":
            draw = self._overlay_range
        else:
            draw = self._overlay_box
        draw(component, gc)
        return

    #--------------------------------------------------------------------------
    #  private interface
    #--------------------------------------------------------------------------

    @deprecated
    def _get_disable_on_complete(self):
        """ Deprecated getter; always reports True. """
        return True

    @deprecated
    def _set_disable_on_complete(self, value):
        """ Deprecated setter; the value is ignored. """
        return None

    @deprecated
    def _get_wheel_zoom_step(self):
        """ Deprecated getter: the wheel step expressed as zoom_factor - 1. """
        step = self.zoom_factor - 1.0
        return step

    @deprecated
    def _set_wheel_zoom_step(self, value):
        """ Deprecated setter: stores the wheel step as zoom_factor = step + 1. """
        factor = value + 1.0
        self.zoom_factor = factor

    def _is_enabling_event(self, event):
        if self.always_on:
            enabled = True
        else:
            if self.always_on_modifier == 'shift':
                enabled = event.shift_down
            elif self.always_on_modifier == 'control':
                enabled = event.control_down
            elif self.always_on_modifier == 'alt':
                enabled = event.alt_down

        if enabled:
            if event.right_down and self.drag_button == 'right':
                return True
            if event.left_down and self.drag_button == 'left':
                return True

        return False

    def _start_select(self, event):
        """ Starts selecting the zoom region
        """
        if self.component.active_tool in (None, self):
            self.component.active_tool = self
        else:
            self._enabled = False
        self._screen_start = (event.x, event.y)
        self._screen_end = None
        self.event_state = "selecting"
        event.window.set_pointer(self.pointer)
        event.window.set_mouse_owner(self, event.net_transform())
        self.selecting_mouse_move(event)
        return

    def _end_select(self, event):
        """ Ends selection of the zoom region, adds the new zoom range to
        the zoom stack, and does the zoom.
        """
        self._screen_end = (event.x, event.y)

        start = numpy.array(self._screen_start)
        end = numpy.array(self._screen_end)

        if sum(abs(end - start)) < self.minimum_screen_delta:
            self._end_selecting(event)
            event.handled = True
            return

        low, high = self._map_coordinate_box(self._screen_start,
                                             self._screen_end)

        x_range = self._get_x_mapper().range
        y_range = self._get_y_mapper().range

        prev = (x_range.low, x_range.high, y_range.low, y_range.high)

        if self.tool_mode == 'range':
            axis = self._determine_axis()
            if axis == 1:
                # vertical
                next = (x_range.low, x_range.high, low[1], high[1])
            else:
                # horizontal
                next = (low[0], high[0], y_range.low, y_range.high)

        else:
            next = (low[0], high[0], low[1], high[1])

        zoom_state = SelectedZoomState(prev, next)
        zoom_state.apply(self)
        self._append_state(zoom_state)

        self._end_selecting(event)
        event.handled = True
        return

    def _end_selecting(self, event=None):
        """ Ends selection of zoom region, without zooming.
        """
        self.reset()
        self._enabled = False
        if self.component.active_tool == self:
            self.component.active_tool = None
        if event and event.window:
            event.window.set_pointer("arrow")

        self.component.request_redraw()
        if event and event.window.mouse_owner == self:
            event.window.set_mouse_owner(None)
        return

    def _overlay_box(self, component, gc):
        """ Draws the overlay as a box.
        """
        if self._screen_start and self._screen_end:
            with gc:
                gc.set_antialias(0)
                gc.set_line_width(self.border_size)
                gc.set_stroke_color(self.border_color_)
                gc.clip_to_rect(component.x, component.y, component.width,
                                component.height)
                x, y = self._screen_start
                x2, y2 = self._screen_end
                rect = (x, y, x2 - x + 1, y2 - y + 1)
                if self.color != "transparent":
                    color = self._get_fill_color()
                    gc.set_fill_color(color)
                    gc.draw_rect(rect)
                else:
                    gc.rect(*rect)
                    gc.stroke_path()
        return

    def _overlay_range(self, component, gc):
        """ Draws the overlay as a range.
        """
        axis_ndx = self._determine_axis()
        lower_left = [0, 0]
        upper_right = [0, 0]
        lower_left[axis_ndx] = self._screen_start[axis_ndx]
        lower_left[1 - axis_ndx] = self.component.position[1 - axis_ndx]
        upper_right[axis_ndx] = self._screen_end[
            axis_ndx] - self._screen_start[axis_ndx]
        upper_right[1 - axis_ndx] = self.component.bounds[1 - axis_ndx]

        with gc:
            gc.set_antialias(0)
            color = self._get_fill_color()
            gc.set_fill_color(color)
            gc.set_stroke_color(self.border_color_)
            gc.clip_to_rect(component.x, component.y, component.width,
                            component.height)
            gc.draw_rect(
                (lower_left[0], lower_left[1], upper_right[0], upper_right[1]))

        return

    def _get_fill_color(self):
        """Get the fill color based on the alpha and the color property
        """
        if self.alpha:
            color = list(self.color_)
            if len(color) == 4:
                color[3] = self.alpha
            else:
                color += [self.alpha]
        else:
            color = self.color_
        return color

    def _determine_axis(self):
        """ Determines whether the index of the coordinate along the axis of
        interest is the first or second element of an (x,y) coordinate tuple.
        """
        if self.axis == "index":
            if self.component.orientation == "h":
                return 0
            else:
                return 1
        else:
            if self.component.orientation == "h":
                return 1
            else:
                return 0

    def _map_coordinate_box(self, start, end):
        """ Given start and end points in screen space, returns corresponding
        low and high points in data space.
        """
        low = [0, 0]
        high = [0, 0]
        for axis_index, mapper in [(0, self.component.x_mapper), \
                                   (1, self.component.y_mapper)]:
            # Ignore missing axis mappers (ColorBar instances only have one).
            if not mapper:
                continue
            low_val = mapper.map_data(start[axis_index])
            high_val = mapper.map_data(end[axis_index])

            if low_val > high_val:
                low_val, high_val = high_val, low_val
            low[axis_index] = low_val
            high[axis_index] = high_val
        return low, high

    def _reset_range_settings(self):
        """ Reset the range settings to their original values """
        x_range = self._get_x_mapper().range
        y_range = self._get_y_mapper().range
        x_range.low_setting, y_range.low_setting = self._orig_low_setting
        x_range.high_setting, y_range.high_setting = self._orig_high_setting

    #--------------------------------------------------------------------------
    #  overloaded
    #--------------------------------------------------------------------------

    def _prev_state_pressed(self):
        """ Steps back one zoom state; on reaching history index 0 the
        original range settings are restored as well.
        """
        super(BetterSelectingZoom, self)._prev_state_pressed()
        back_at_start = self._history_index == 0
        if back_at_start:
            self._reset_range_settings()

    def _reset_state_pressed(self):
        """ Resets the zoom state, then restores the original range
        settings.
        """
        super(BetterSelectingZoom, self)._reset_state_pressed()
        self._reset_range_settings()
Exemple #4
0
class TimeSeriesCorrelation(BaseModule):
    """Time-series correlation between factor values and later returns.

    At every calculation date one row of security returns and one row of
    (lagged) factor values are appended; each factor-ID column is then
    correlated with each security-ID column over a rolling window.  At the
    end the rolling, last-period and full-sample correlations are reshaped
    into frames stored in ``self._Output``.
    """
    TestFactors = ListStr(arg_type="MultiOption",
                          label="测试因子",
                          order=0,
                          option_range=())
    # PriceFactor is registered dynamically in __QS_initArgs__, because its
    # options are taken from the price table.
    ReturnType = Enum("简单收益率",
                      "对数收益率",
                      "价格变化量",
                      arg_type="SingleOption",
                      label="收益率类型",
                      order=2)
    ForecastPeriod = Int(1, arg_type="Integer", label="预测期数", order=3)
    Lag = Int(0, arg_type="Integer", label="滞后期数", order=4)
    CalcDTs = List(dt.datetime, arg_type="DateList", label="计算时点", order=5)
    CorrMethod = Enum("pearson",
                      "spearman",
                      "kendall",
                      arg_type="SingleOption",
                      label="相关性算法",
                      order=6)
    SummaryWindow = Float(np.inf, arg_type="Integer", label="统计窗口", order=7)
    MinSummaryWindow = Int(2, arg_type="Integer", label="最小统计窗口", order=8)

    def __init__(self,
                 factor_table,
                 price_table,
                 name="时间序列相关性",
                 sys_args=None,
                 **kwargs):
        """Store the two data tables and initialise the base module.

        Args:
            factor_table: table the tested factors are read from.
            price_table: table the price factor is read from.
            name: module name handed to the base class.
            sys_args: argument dict handed to the base class; ``None``
                (the default) stands for a fresh empty dict.
            **kwargs: forwarded unchanged to the base class.
        """
        self._FactorTable = factor_table
        self._PriceTable = price_table
        # Build the empty dict per call so instances never share one
        # mutable default object.
        super().__init__(name=name,
                         sys_args={} if sys_args is None else sys_args,
                         **kwargs)

    def __QS_initArgs__(self):
        """Populate the factor options from the tables' metadata.

        The first numeric factor of the factor table is pre-selected as test
        factor; the price factor is chosen by name search among the numeric
        factors of the price table.
        """
        DefaultNumFactorList, _ = getFactorList(
            dict(self._FactorTable.getFactorMetaData(key="DataType")))
        self.add_trait(
            "TestFactors",
            ListStr(arg_type="MultiOption",
                    label="测试因子",
                    order=0,
                    option_range=tuple(DefaultNumFactorList)))
        self.TestFactors.append(DefaultNumFactorList[0])
        DefaultNumFactorList, _ = getFactorList(
            dict(self._PriceTable.getFactorMetaData(key="DataType")))
        self.add_trait(
            "PriceFactor",
            Enum(*DefaultNumFactorList,
                 arg_type="SingleOption",
                 label="价格因子",
                 order=1))
        self.PriceFactor = searchNameInStrList(DefaultNumFactorList,
                                               ['价', 'Price', 'price'])

    def getViewItems(self, context_name=""):
        """Return the base view items with a SetEditor on the first item."""
        Items, Context = super().getViewItems(context_name=context_name)
        Items[0].editor = SetEditor(
            values=self.trait("TestFactors").option_range)
        return (Items, Context)

    def __QS_start__(self, mdl, dts, **kwargs):
        """Reset the output buffers; return the tables this module reads."""
        if self._isStarted: return ()
        super().__QS_start__(mdl=mdl, dts=dts, **kwargs)
        self._Output = {}
        # {factor name: {datetime: correlation array, one row per factor ID
        # and one column per security ID}}
        self._Output["滚动相关性"] = {
            iFactorName: {}
            for iFactorName in self.TestFactors
        }
        self._Output["证券ID"] = self._PriceTable.getID()
        self._Output["收益率"] = np.zeros(shape=(0, len(self._Output["证券ID"])))
        self._Output["因子ID"] = self._FactorTable.getID()
        nFactorID = len(self._Output["因子ID"])
        self._Output["因子值"] = {
            iFactorName: np.zeros(shape=(0, nFactorID))
            for iFactorName in self.TestFactors
        }
        self._CurCalcInd = 0
        return (self._FactorTable, self._PriceTable)

    def __QS_move__(self, idt, **kwargs):
        """Process one datetime: append data and update rolling correlations.

        Returns 0 in every case, including when the datetime is skipped.
        """
        if self._iDT == idt: return 0
        self._iDT = idt
        if self.CalcDTs:
            if idt not in self.CalcDTs[self._CurCalcInd:]: return 0
            self._CurCalcInd = self.CalcDTs[self._CurCalcInd:].index(
                idt) + self._CurCalcInd
            DateTimes = self.CalcDTs
        else:
            self._CurCalcInd = self._Model.DateTimeIndex
            DateTimes = self._Model.DateTimeSeries
        PreInd = self._CurCalcInd - self.ForecastPeriod - self.Lag
        LastInd = self._CurCalcInd - self.ForecastPeriod
        # Check before indexing: a negative index would either wrap around
        # to the end of the sequence or raise IndexError.
        if (PreInd < 0) or (LastInd < 0): return 0
        PreDateTime = DateTimes[PreInd]
        LastDateTime = DateTimes[LastInd]
        Price = self._PriceTable.readData(dts=[LastDateTime, idt],
                                          ids=self._Output["证券ID"],
                                          factor_names=[self.PriceFactor
                                                        ]).iloc[0, :, :].values
        self._Output["收益率"] = np.r_[
            self._Output["收益率"],
            _calcReturn(Price, return_type=self.ReturnType)]
        FactorData = self._FactorTable.readData(
            dts=[PreDateTime],
            ids=self._Output["因子ID"],
            factor_names=list(self.TestFactors)).iloc[:, 0, :].values.T
        # First row of the rolling window (0 while fewer rows than
        # SummaryWindow have been collected).
        StartInd = int(
            max(0, self._Output["收益率"].shape[0] - self.SummaryWindow))
        for i, iFactorName in enumerate(self.TestFactors):
            self._Output["因子值"][iFactorName] = np.r_[
                self._Output["因子值"][iFactorName], FactorData[i:i + 1]]
            if self._Output["收益率"].shape[0] >= self.MinSummaryWindow:
                # Correlate the side-by-side [factor | return] columns and
                # keep only the factor-vs-return quadrant.
                self._Output["滚动相关性"][iFactorName][idt] = pd.DataFrame(
                    np.c_[self._Output["因子值"][iFactorName][StartInd:],
                          self._Output["收益率"][StartInd:]]).corr(
                              method=self.CorrMethod,
                              min_periods=self.MinSummaryWindow
                          ).values[:FactorData.shape[1], FactorData.shape[1]:]
        return 0

    def __QS_end__(self):
        """Reshape the accumulated results into their final frames.

        NOTE(review): ``max`` below raises ValueError when no rolling
        correlation was ever stored, and ``pd.Panel`` only exists in pandas
        versions before 0.25 -- confirm the supported pandas version.
        """
        if not self._isStarted: return 0
        FactorIDs, PriceIDs = self._Output.pop("因子ID"), self._Output.pop(
            "证券ID")
        LastDT = max(self._Output["滚动相关性"][self.TestFactors[0]])
        self._Output["最后一期相关性"], self._Output["全样本相关性"] = {}, {}
        for iFactorName in self.TestFactors:
            self._Output["最后一期相关性"][iFactorName] = self._Output["滚动相关性"][
                iFactorName][LastDT].T
            self._Output["全样本相关性"][iFactorName] = pd.DataFrame(
                np.c_[self._Output["因子值"][iFactorName], self._Output["收益率"]]
            ).corr(method=self.CorrMethod,
                   min_periods=self.MinSummaryWindow).values[:len(FactorIDs),
                                                             len(FactorIDs):].T
            self._Output["滚动相关性"][iFactorName] = pd.Panel(
                self._Output["滚动相关性"][iFactorName],
                major_axis=FactorIDs,
                minor_axis=PriceIDs).swapaxes(
                    0, 2).to_frame(filter_observations=False).reset_index()
            self._Output["滚动相关性"][iFactorName].columns = ["因子ID", "时点"
                                                          ] + PriceIDs
        self._Output["最后一期相关性"] = pd.Panel(
            self._Output["最后一期相关性"], major_axis=PriceIDs,
            minor_axis=FactorIDs).swapaxes(
                0, 1).to_frame(filter_observations=False).reset_index()
        self._Output["最后一期相关性"].columns = ["因子", "因子ID"] + PriceIDs
        self._Output["全样本相关性"] = pd.Panel(
            self._Output["全样本相关性"], major_axis=PriceIDs,
            minor_axis=FactorIDs).swapaxes(
                0, 1).to_frame(filter_observations=False).reset_index()
        self._Output["全样本相关性"].columns = ["因子", "因子ID"] + PriceIDs
        # The raw working buffers are not part of the final output.
        self._Output.pop("收益率")
        self._Output.pop("因子值")
        return 0
Exemple #5
0
class ODMR(ManagedJob, GetSetItemsMixin):
    """
    Implements an Optically Detected Magnetic Resonance (ODMR) measurement.
    
    Here we sweep a microwave source and record
    the photon clicks in every point of the sweep.
    
    This measurement requires a microwave source
    and a counter. The counter also generates
    a trigger for every counting bin that steps
    the microwave source to the next frequency.
    
    The results from successive sweeps are accumulated.
    
    Optionally one can run a pulsed ODMR sweep with
    microwave pi-pulses.
    
    We provide some basic fitting.
    """

    # starting and stopping
    keep_data = Bool(
        False)  # helper variable to decide whether to keep existing data
    resubmit_button = Button(
        label='resubmit',
        desc=
        'Submits the measurement to the job manager. Tries to keep previously acquired data. Behaves like a normal submit if sequence or time bins have changed since previous run.'
    )

    # measurement parameters
    power = Range(low=-100.,
                  high=25.,
                  value=-20,
                  desc='Power [dBm]',
                  label='Power [dBm]',
                  mode='text',
                  auto_set=False,
                  enter_set=True)
    frequency_begin = Float(default_value=2.85e9,
                            desc='Start Frequency [Hz]',
                            label='Begin [Hz]',
                            editor=TextEditor(auto_set=False,
                                              enter_set=True,
                                              evaluate=float,
                                              format_str='%e'))
    frequency_end = Float(default_value=2.88e9,
                          desc='Stop Frequency [Hz]',
                          label='End [Hz]',
                          editor=TextEditor(auto_set=False,
                                            enter_set=True,
                                            evaluate=float,
                                            format_str='%e'))
    frequency_delta = Float(default_value=1e6,
                            desc='frequency step [Hz]',
                            label='Delta [Hz]',
                            editor=TextEditor(auto_set=False,
                                              enter_set=True,
                                              evaluate=float,
                                              format_str='%e'))
    t_pi = Range(low=1.,
                 high=100000.,
                 value=1000.,
                 desc='length of pi pulse [ns]',
                 label='pi [ns]',
                 mode='text',
                 auto_set=False,
                 enter_set=True)
    laser = Range(low=1.,
                  high=10000.,
                  value=300.,
                  desc='laser [ns]',
                  label='laser [ns]',
                  mode='text',
                  auto_set=False,
                  enter_set=True)
    wait = Range(low=1.,
                 high=10000.,
                 value=1000.,
                 desc='wait [ns]',
                 label='wait [ns]',
                 mode='text',
                 auto_set=False,
                 enter_set=True)
    pulsed = Bool(False, label='pulsed')
    seconds_per_point = Range(low=20e-3,
                              high=1,
                              value=20e-3,
                              desc='Seconds per point',
                              label='Sec per point',
                              mode='text',
                              auto_set=False,
                              enter_set=True)
    stop_time = Range(
        low=1.,
        value=np.inf,
        desc='Time after which the experiment stops by itself [s]',
        label='Stop time [s]',
        mode='text',
        auto_set=False,
        enter_set=True)
    n_lines = Range(low=1,
                    high=10000,
                    value=50,
                    desc='Number of lines in Matrix',
                    label='Matrix lines',
                    mode='text',
                    auto_set=False,
                    enter_set=True)

    # control data fitting
    fit_method = Enum('Lorentzian',
                      'N14 Lorentzian',
                      desc='Fit Method',
                      label='Fit Method')
    fit = Bool(False, label='fit')
    number_of_resonances = Int(1,
                               desc='Number of peaks to be fitted',
                               label='num peaks',
                               mode='text',
                               auto_set=False,
                               enter_set=True)
    width = Float(5e6,
                  desc='width used for peak detection.',
                  label='width [Hz]',
                  mode='text',
                  auto_set=False,
                  enter_set=True)

    # fit result
    fit_parameters = Array(value=np.array(()))
    fit_frequencies = Array(value=np.array(()), label='frequency [Hz]')
    fit_line_width = Array(value=np.array(()), label='line_width [Hz]')
    fit_contrast = Array(value=np.array(()), label='contrast [%]')
    fit_peak_ind = Array(value=np.array(()), label='peak pos [Hz]')
    fit_peak_val = Array(value=np.array(()), label='peak val')
    fit_string = Str()

    # measurement data
    frequency = Array()
    counts = Array()
    counts_matrix = Array()
    run_time = Float(value=0.0, desc='Run time [s]', label='Run time [s]')

    # plotting
    line_label = Instance(PlotLabel)
    line_data = Instance(ArrayPlotData)
    matrix_data = Instance(ArrayPlotData)
    line_plot = Instance(Plot, editor=ComponentEditor())
    matrix_plot = Instance(Plot, editor=ComponentEditor())

    get_set_items = [
        'frequency', 'counts', 'counts_matrix', 'number_of_resonances',
        'width', 'fit_parameters', 'fit_contrast', 'fit_line_width',
        'fit_frequencies', 'fit_peak_ind', 'fit_peak_val', 'fit_string', 'fit',
        'run_time', 'power', 'frequency_begin', 'frequency_end',
        'frequency_delta', 'laser', 'wait', 'pulsed', 't_pi',
        'seconds_per_point', 'stop_time', 'n_lines', '__doc__'
    ]

    def __init__(self, microwave, counter, pulse_generator=None, **kwargs):
        """Store the instruments, build both plots and hook up listeners.

        Args:
            microwave: microwave source driven by the measurement.
            counter: counter that records the counts of a sweep.
            pulse_generator: optional pulse generator; needed for pulsed
                mode.
            **kwargs: forwarded to the parent constructor.
        """
        super(ODMR, self).__init__(**kwargs)
        self.microwave = microwave
        self.counter = counter
        self.pulse_generator = pulse_generator
        self._create_line_plot()
        self._create_matrix_plot()
        # (handler, trait names) pairs; all are dispatched on the UI thread.
        ui_listeners = (
            (self._update_line_data_index, 'frequency'),
            (self._update_line_data_value, 'counts'),
            (self._update_matrix_data_value, 'counts_matrix'),
            (self._update_matrix_data_index, 'n_lines,frequency'),
            (self._update_fit,
             'counts,fit,number_of_resonances,width,fit_method'),
        )
        for handler, trait_names in ui_listeners:
            self.on_trait_change(handler, trait_names, dispatch='ui')

    def _counts_matrix_default(self):
        return np.zeros((self.n_lines, len(self.frequency)))

    def _frequency_default(self):
        return np.arange(self.frequency_begin,
                         self.frequency_end + self.frequency_delta,
                         self.frequency_delta)

    def _counts_default(self):
        return np.zeros(self.frequency.shape)

    # data acquisition
    def apply_parameters(self):
        """Apply the current parameters and decide whether to keep previous data.

        The frequency axis is rebuilt from begin, end and delta.  Existing
        data is discarded (counts zeroed, run time reset) when `keep_data`
        is False or the rebuilt axis differs from the stored one.
        Afterwards `keep_data` is set to True.
        """
        frequency = np.arange(self.frequency_begin,
                              self.frequency_end + self.frequency_delta,
                              self.frequency_delta)

        # np.array_equal also handles axes of different length; an
        # elementwise `!=` on mismatched shapes raises on current NumPy.
        if not self.keep_data or not np.array_equal(frequency,
                                                    self.frequency):
            self.frequency = frequency
            self.counts = np.zeros(frequency.shape)
            self.run_time = 0.0

        self.keep_data = True  # when job manager stops and starts the job, data should be kept. Only new submission should clear data.

    def _run(self):
        """Run the measurement: sweep repeatedly and accumulate the counts.

        Applies the current parameters, prepares the pulse generator (when
        one is present), the microwave sweep and the counter, then repeats
        sweeps until `run_time` reaches `stop_time` or the current thread's
        `stop_request` is set.  Exceptions raised inside the measurement are
        logged and put the job into the 'error' state; otherwise the job
        ends 'idle' (stopped early) or 'done' (stop time reached, data
        saved).
        """

        try:
            self.state = 'run'
            self.apply_parameters()

            # Nothing left to measure if the kept run time already reaches
            # the stop time.
            if self.run_time >= self.stop_time:
                self.state = 'done'
                return

            # if pulsed, turn on sequence
            if self.pulse_generator:
                if self.pulsed:
                    # 100 repetitions of laser / wait / microwave pi pulse
                    self.pulse_generator.Sequence(
                        100 * [(['detect', 'aom'], self.laser),
                               ([], self.wait), (['microwave'], self.t_pi)])
                else:
                    self.pulse_generator.Open()
            else:
                if self.pulsed:
                    raise ValueError(
                        "pulse_generator not defined while running measurement in pulsed mode."
                    )

            n = len(self.frequency)
            # The bare string below is disabled legacy code; it is evaluated
            # as an expression statement and has no effect.
            """
            self.microwave.setOutput( self.power, np.append(self.frequency,self.frequency[0]), self.seconds_per_point)
            self._prepareCounter(n)
            """
            self.microwave.setPower(self.power)
            # One power value per frequency point of the sweep.
            self.microwave.initSweep(
                self.frequency, self.power * np.ones(self.frequency.shape))
            self.counter.configure(n, self.seconds_per_point, DutyCycle=0.8)
            time.sleep(0.5)

            while self.run_time < self.stop_time:
                start_time = time.time()
                if threading.currentThread().stop_request.isSet():
                    break
                self.microwave.resetListPos()
                counts = self.counter.run()
                self.run_time += time.time() - start_time
                self.counts += counts
                # Newest sweep goes on top; the oldest matrix line drops out.
                self.counts_matrix = np.vstack(
                    (counts, self.counts_matrix[:-1, :]))
                # Presumably needed because the in-place `+=` above does not
                # fire a change notification for 'counts' -- confirm.
                self.trait_property_changed('counts', self.counts)
                # Disabled legacy code, kept as a no-op string literal.
                """
                self.microwave.doSweep()
                
                timeout = 3.
                start_time = time.time()
                while not self._count_between_markers.ready():
                    time.sleep(0.1)
                    if time.time() - start_time > timeout:
                        print "count between markers timeout in ODMR"
                        break
                        
                counts = self._count_between_markers.getData(0)
                self._count_between_markers.clean()
                """

            # Regular shutdown of the instruments after the sweep loop.
            self.microwave.setOutput(None, self.frequency_begin)
            if self.pulse_generator:
                self.pulse_generator.Light()
            self.counter.clear()
        except:
            # Bare except: any failure is logged and reported via the state.
            logging.getLogger().exception('Error in odmr.')
            self.microwave.setOutput(None, self.frequency_begin)
            self.state = 'error'
        else:
            if self.run_time < self.stop_time:
                # Loop left through a stop request before the stop time.
                self.state = 'idle'
            else:
                try:
                    self.save(self.filename)
                except:
                    logging.getLogger().exception(
                        'Failed to save the data to file.')
                self.state = 'done'

    # fitting
    def _update_fit(self):
        """Recompute the fit and publish its results on the fit_* traits.

        When `fit` is False, or the fit fails, all fit results are cleared
        (`fit_contrast` is left as it was).  A result containing 'p'
        (constant offset followed by three parameters per peak) yields
        frequencies, line widths and contrasts; a result containing only
        'x0'/'y0' yields peak positions only.
        """
        if self.fit:
            try:
                if self.fit_method == 'Lorentzian':
                    fit_res = find_peaks(self.frequency, self.counts,
                                         self.width, self.number_of_resonances)
                elif self.fit_method == 'N14 Lorentzian':
                    raise NotImplementedError('you may find it in old code')
                else:
                    raise ValueError('unknown fit method')
            except Exception:
                # A failed fit is not an error for the measurement itself.
                logging.getLogger().debug('ODMR fit failed.', exc_info=True)
                fit_res = {}
        else:
            fit_res = {}
        self.fit_res = fit_res
        if 'p' in fit_res:
            p = fit_res['p']
            x0 = fit_res['x0']
            y0 = fit_res['y0']
            self.fit_peak_ind = x0
            self.fit_peak_val = y0
            self.fit_parameters = p
            self.fit_frequencies = p[1::3]
            self.fit_line_width = p[2::3]

            # Floor division: the peak count must be an int to be usable as
            # an array size and reshape dimension (true division gives a
            # float).
            n = len(p) // 3
            contrast = np.empty(n)
            c = p[0]
            peaks = p[1:].reshape((n, 3))
            for i, peak in enumerate(peaks):
                area = peak[2]
                hwhm = peak[1]
                amp = np.abs(area / (np.pi * hwhm))
                if area > 0:
                    contrast[i] = 100 * amp / (amp + c)
                else:
                    contrast[i] = 100 * amp / c

            self.fit_contrast = contrast

            s = ''
            for i, fi in enumerate(self.fit_frequencies):
                s += 'f %i: %.6e Hz, HWHM %.3e Hz, contrast %.1f%%\n' % (
                    i + 1, fi, self.fit_line_width[i], self.fit_contrast[i])

            self.fit_string = s

        elif 'x0' in fit_res:
            x0 = fit_res['x0']
            y0 = fit_res['y0']
            self.fit_peak_ind = x0
            self.fit_peak_val = y0
            self.fit_parameters = np.array(())
            self.fit_frequencies = x0
            self.fit_line_width = np.array(())

            s = ''
            for i, fi in enumerate(self.fit_frequencies):
                s += 'f %i: %.6e Hz\n' % (i + 1, fi)

            self.fit_string = s

        else:
            self.fit_peak_ind = np.array(())
            self.fit_peak_val = np.array(())
            self.fit_parameters = np.array(())
            self.fit_frequencies = np.array(())
            self.fit_line_width = np.array(())
            self.fit_string = ''

    def _fit_string_changed(self, new):
        self.line_label.text = new

    # plotting
    def _create_line_plot(self):
        """Build the counts-vs-frequency line plot with placeholder data.

        Creates the plot data, the plot with its 'data' line, a text label
        overlay and a save tool, and stores label, data and plot on the
        instance.
        """
        placeholder = dict(frequency=np.array((0., 1.)),
                           counts=np.array((0., 0.)),
                           fit=np.array((0., 0.)))
        data = ArrayPlotData(**placeholder)
        plot = Plot(data, padding=8, padding_left=64, padding_bottom=32)
        plot.plot(('frequency', 'counts'),
                  style='line',
                  color='blue',
                  name='data')
        plot.index_axis.title = 'Frequency [MHz]'
        plot.value_axis.title = 'Fluorescence counts'
        label = PlotLabel(text='',
                          hjustify='left',
                          vjustify='bottom',
                          position=[64, 128])
        plot.overlays.append(label)
        plot.tools.append(SaveTool(plot))
        self.line_label = label
        self.line_data = data
        self.line_plot = plot

    def _create_matrix_plot(self):
        """Build the image plot showing the counts matrix.

        The data source starts with a 2x2 zero image.  The image spans the
        first to the last entry of ``self.frequency`` horizontally and
        0 to ``self.n_lines`` vertically.  The data source and plot are
        stored as ``matrix_data`` and ``matrix_plot``.
        """
        source = ArrayPlotData(image=np.zeros((2, 2)))
        plot = Plot(source, padding=8, padding_left=64, padding_bottom=32)
        plot.index_axis.title = 'Frequency [MHz]'
        plot.value_axis.title = 'line #'

        x_range = (self.frequency[0], self.frequency[-1])
        y_range = (0, self.n_lines)
        plot.img_plot('image',
                      xbounds=x_range,
                      ybounds=y_range,
                      colormap=Spectral)
        plot.tools.append(SaveTool(plot))

        self.matrix_data = source
        self.matrix_plot = plot

    def _fit_parameters_changed(self, new):
        n = len(new) / 3
        line_plot = self.line_plot
        line_data = self.line_data
        for name in line_plot.plots.keys():
            if name != 'data':
                line_plot.delplot(name)
        for i in range(n):
            p = np.append(new[0], new[i * 3 + 1:i * 3 + 4])
            name = 'peak_%i' % i
            line_data.set_data(name, Lorentzian(*p)(self.frequency))
            line_plot.plot(('frequency', name),
                           style='line',
                           color='red',
                           name=name)
        line_plot.request_redraw()

    #def _fit_changed(self,new):
    #plot = self.line_plot
    #if new:
    #plot.plot(('frequency','fit'), style='line', color='red', name='fit')
    #self.line_label.visible=True
    #else:
    #plot.delplot('fit')
    #self.line_label.visible=False
    #plot.request_redraw()

    def _update_line_data_index(self):
        self.line_data.set_data('frequency', self.frequency * 1e-6)
        self.counts_matrix = self._counts_matrix_default()

    def _update_line_data_value(self):
        self.line_data.set_data('counts', self.counts)

    def _update_matrix_data_value(self):
        self.matrix_data.set_data('image', self.counts_matrix)

    def _update_matrix_data_index(self):
        if self.n_lines > self.counts_matrix.shape[0]:
            self.counts_matrix = np.vstack(
                (self.counts_matrix,
                 np.zeros((self.n_lines - self.counts_matrix.shape[0],
                           self.counts_matrix.shape[1]))))
        else:
            self.counts_matrix = self.counts_matrix[:self.n_lines]
        self.matrix_plot.components[0].index.set_data(
            (self.frequency.min() * 1e-6, self.frequency.max() * 1e-6),
            (0.0, float(self.n_lines)))

    # react to GUI events
    def submit(self):
        """Submit the job to the JobManager.

        Clears ``keep_data`` first and then delegates to
        ``ManagedJob.submit``.
        """
        self.keep_data = False
        ManagedJob.submit(self)

    def resubmit(self):
        """Submit the job to the JobManager with ``keep_data`` set.

        Identical to ``submit`` except that ``keep_data`` is set to True
        before delegating to ``ManagedJob.submit``.
        """
        self.keep_data = True
        ManagedJob.submit(self)

    def _resubmit_button_fired(self):
        """React to the resubmit button by calling ``resubmit``."""
        self.resubmit()

    # Main window layout, top to bottom: job controls, file controls,
    # measurement / pulse / fit parameters, then the matrix plot stacked
    # above the line plot.  Items carrying enabled_when='state != "run"'
    # are only editable while the job state is not "run".
    traits_view = View(
        VGroup(
            HGroup(
                Item('submit_button', show_label=False),
                Item('remove_button', show_label=False),
                Item('resubmit_button', show_label=False),
                Item('priority', enabled_when='state != "run"'),
                Item('state', style='readonly'),
                Item('run_time', style='readonly', format_str='%.f'),
                Item('stop_time'),
            ),
            HGroup(Item('filename', springy=True),
                   Item('save_button', show_label=False),
                   Item('load_button', show_label=False)),
            VGroup(
                HGroup(
                    Item('power', enabled_when='state != "run"'),
                    Item('frequency_begin', enabled_when='state != "run"'),
                    Item('frequency_end', enabled_when='state != "run"'),
                    Item('frequency_delta', enabled_when='state != "run"'),
                    Item('seconds_per_point', enabled_when='state != "run"'),
                ),
                HGroup(
                    Item('pulsed', enabled_when='state != "run"'),
                    Item('laser', enabled_when='state != "run"'),
                    Item('wait', enabled_when='state != "run"'),
                    Item('t_pi', enabled_when='state != "run"'),
                ),
                HGroup(
                    Item('fit'),
                    Item('number_of_resonances'),
                    Item('width'),
                    Item('fit_method'),
                    Item('n_lines'),
                ),
                #Item('fit_string', style='readonly'),
                #HGroup(Item('fit_contrast', style='readonly'),
                #       Item('fit_line_width', style='readonly'),
                #       Item('fit_frequencies', style='readonly'),
                #       ),
            ),
            VSplit(
                Item('matrix_plot', show_label=False, resizable=True),
                Item('line_plot', show_label=False, resizable=True),
            ),
        ),
        title='ODMR',
        width=900,
        height=800,
        buttons=[],
        resizable=True)
class SystemMonitorEditor(SeriesEditor):
    """Series editor that follows the analyses of one remote system.

    The editor subscribes to 'RunAdded' and 'ConsoleMessage' notifications
    through a ``Subscriber`` and, in a background thread, polls both the
    subscription server and the database.  Whenever a run is reported it
    refreshes its own series and the ideogram / spectrum / per-analysis-type
    series editors it opens through ``self.task``.
    """
    conn_spec = Instance(ConnectionSpec, ())
    name = Property(depends_on='conn_spec:+')
    tool = Instance(SystemMonitorControls)
    subscriber = Instance(Subscriber)
    plotter_options_manager_klass = SystemMonitorOptionsManager

    use_poll = Bool(False)
    _poll_interval = Int(10)
    _db_poll_interval = Int(10)
    # True while the background poll loop should keep running.
    _polling = False
    pickle_path = 'system_monitor'

    console_display = Instance(DisplayController)

    # Editors opened on demand by the _refresh_* methods; None until created.
    _ideogram_editor = None
    _spectrum_editor = None

    _air_editor = None
    _blank_air_editor = None
    _cocktail_editor = None
    _blank_cocktail_editor = None
    _background_editor = None

    task = Any

    db_lock = None

    def __init__(self, *args, **kw):
        """Bind the console colors to preferences and create the db lock."""
        super(SystemMonitorEditor, self).__init__(*args, **kw)
        color_bind_preference(self.console_display.model, 'bgcolor',
                              'pychron.sys_mon.bgcolor')
        color_bind_preference(self.console_display, 'default_color',
                              'pychron.sys_mon.textcolor')
        self.db_lock = Lock()

    def prepare_destroy(self):
        """Stop monitoring and dump the tool of this and every opened editor."""
        self.stop()
        self.dump_tool()
        for key in ('ideogram', 'spectrum', 'air', 'blank_air', 'cocktail',
                    'blank_cocktail', 'background'):
            editor = getattr(self, '_{}_editor'.format(key))
            if editor is not None:
                editor.dump_tool()

    def stop(self):
        """Signal the poll loop to exit and stop the subscriber."""
        self._polling = False
        self.subscriber.stop()

    def console_message_handler(self, msg):
        """Show a console message, optionally prefixed with ``'<color>|'``.

        Without a ``'|'`` the whole message is shown in green.
        """
        color = 'green'
        if '|' in msg:
            # Split on the first separator only so that the message text
            # itself may contain '|' without breaking the unpacking.
            color, msg = msg.split('|', 1)

        self.console_display.add_text(msg, color=color)

    def run_added_handler(self, last_run_uuid=None):
        """
            add to sys mon series
            if atype is blank, air, cocktail, background
                add to atype series
            else
                if step heat
                    add to spectrum
                else
                    add to ideogram

            With ``last_run_uuid=None`` the last analysis of the configured
            system is used.  The work is scheduled on the main thread.
        """
        def func():
            #with self.db_lock:
            self.info('refresh analyses. last UUID={}'.format(last_run_uuid))
            proc = self.processor
            db = proc.db
            with db.session_ctx():
                if last_run_uuid is None:
                    dbrun = db.get_last_analysis(
                        spectrometer=self.conn_spec.system_name)
                else:
                    dbrun = db.get_analysis_uuid(last_run_uuid)

                #if last_run_uuid:
                #    dbrun = db.get_analysis_uuid(last_run_uuid)
                if dbrun:
                    an = proc.make_analysis(dbrun)
                    self._refresh_sys_mon_series(an)
                    self._refresh_figures(an)

        invoke_in_main_thread(func)

    def start(self):
        """Load the tool, connect the subscriber and start the poll thread.

        The subscriber is only set up when ``conn_spec.host`` is set; the
        connection pane of the task reflects whether the connection
        succeeded.  The poll thread is started in either case.
        """

        self.load_tool()

        if self.conn_spec.host:
            sub = self.subscriber
            connected = sub.connect(timeout=1)

            sub.subscribe('RunAdded', self.run_added_handler, True)
            sub.subscribe('ConsoleMessage', self.console_message_handler)

            if connected:
                sub.listen()
                self.task.connection_pane.connection_status = 'Connected'
                self.task.connection_pane.connection_color = 'green'

            else:
                self.task.connection_pane.connection_status = 'Not Connected'
                self.task.connection_pane.connection_color = 'red'

                url = self.conn_spec.url
                self.warning(
                    'System publisher not available url={}'.format(url))

        last_run_uuid = self._get_last_run_uuid()
        #self.run_added_handler(last_run_uuid)

        t = Thread(name='poll', target=self._poll, args=(last_run_uuid, ))
        # ``daemon`` attribute instead of the deprecated ``setDaemon``.
        t.daemon = True
        t.start()

    def _get_dump_tool(self):
        return self.tool

    def _load_tool(self, obj):
        self.tool = obj

    def _poll(self, last_run_uuid):
        """Background loop: watch the subscription server and the database.

        Runs until ``_polling`` is cleared (see ``stop``).  While the
        subscriber is not listening, the database is queried for the last
        run uuid at most once per ``_db_poll_interval`` seconds.
        """
        self._polling = True
        sub = self.subscriber

        db_poll_interval = self._db_poll_interval
        poll_interval = self._poll_interval

        st = time.time()
        while 1:
            #only check subscription availability if one poll_interval has elapsed
            #since the last subscription message was received

            #check subscription availability
            if time.time() - sub.last_message_time > poll_interval:
                if sub.check_server_availability(timeout=0.5, verbose=True):
                    if not sub.is_listening():
                        self.info(
                            'Subscription server now available. starting to listen'
                        )
                        self.subscriber.listen()
                else:
                    if sub.was_listening:
                        self.warning(
                            'Subscription server no longer available. stop listen'
                        )
                        self.subscriber.stop()

            if self._wait(poll_interval):
                if not sub.is_listening():
                    if time.time() - st > db_poll_interval:
                        st = time.time()
                        lr = self._get_last_run_uuid()
                        self.debug('current uuid {} <> {}'.format(
                            last_run_uuid, lr))
                        if lr != last_run_uuid:
                            last_run_uuid = lr
                        # NOTE(review): the refresh is triggered on every db
                        # poll, not only when the uuid changed -- confirm
                        # whether this call was meant to sit inside the
                        # ``if`` above.
                        invoke_in_main_thread(self.run_added_handler, lr)
            else:
                break

    def _wait(self, t):
        """Sleep up to ``t`` seconds in 0.5 s steps.

        Returns True after the full wait, or a false value as soon as
        ``_polling`` has been cleared.
        """
        st = time.time()
        while time.time() - st < t:
            if not self._polling:
                return False
            time.sleep(0.5)

        return True

    def _get_last_run_uuid(self):
        """Return the uuid of the system's last analysis, or None if there is none."""
        db = self.processor.db
        with db.session_ctx():
            dbrun = db.get_last_analysis(
                spectrometer=self.conn_spec.system_name)
            if dbrun:
                return dbrun.uuid

    def _refresh_sys_mon_series(self, an):
        """Reload ``self.analyses`` for the mass spectrometer of ``an``.

        The time window and limit come from ``self.tool``.
        """

        ms = an.mass_spectrometer
        kw = dict(weeks=self.tool.weeks,
                  days=self.tool.days,
                  hours=self.tool.hours,
                  limit=self.tool.limit)

        ans = self.processor.analysis_series(ms, **kw)
        self.analyses = ans

    def _refresh_figures(self, an):
        """Refresh the figure editor matching the analysis type of ``an``.

        'unknown' analyses go to the spectrum editor when ``an.step`` is
        set and to the ideogram editor otherwise; every other type is
        dispatched to ``_refresh_<analysis_type>``.
        """
        if an.analysis_type == 'unknown':
            if an.step:
                self._refresh_spectrum(an.labnumber, an.aliquot)
            else:
                self._refresh_ideogram(an.labnumber)
        else:
            atype = an.analysis_type

            func = getattr(self, '_refresh_{}'.format(atype))
            func(an.labnumber)

    def _refresh_air(self, identifier):
        self._set_series('air', identifier)

    def _refresh_blank_air(self, identifier):
        self._set_series('blank_air', identifier)

    def _refresh_cocktail(self, identifier):
        self._set_series('cocktail', identifier)

    def _refresh_blank_cocktail(self, identifier):
        # The key must match the ``_blank_cocktail_editor`` attribute that
        # ``_set_series`` looks up.
        self._set_series('blank_cocktail', identifier)

    def _refresh_background(self, identifier):
        self._set_series('background', identifier)

    def _set_series(self, attr, identifier):
        """Create or update the series editor stored as ``_<attr>_editor``."""
        name = '_{}_editor'.format(attr)
        editor = getattr(self, name)

        def new():
            e = self.task.new_series(ans=[], add_table=False, add_iso=False)
            self.task.tab_editors(0, -1)
            e.basename = '{} Series'.format(camel_case(attr))
            return e

        editor = self._update_editor(editor,
                                     new,
                                     identifier,
                                     None,
                                     layout=False,
                                     use_date_range=True)
        setattr(self, name, editor)

    def _refresh_ideogram(self, identifier):
        """
            open a ideogram editor if one is not already open
        """
        editor = self._ideogram_editor
        f = lambda: self.task.new_ideogram(add_table=False, add_iso=False)
        editor = self._update_editor(editor, f, identifier, None)
        self._ideogram_editor = editor

    def _refresh_spectrum(self, identifier, aliquot):
        """Open a spectrum editor if needed and load the given aliquot."""
        editor = self._spectrum_editor
        f = lambda: self.task.new_spectrum(add_table=False, add_iso=False)
        editor = self._update_editor(editor, f, identifier, aliquot)
        self._spectrum_editor = editor

    def _update_editor(self,
                       editor,
                       editor_factory,
                       identifier,
                       aliquot,
                       layout=True,
                       use_date_range=False):
        """Ensure ``editor`` exists, load its analyses and return it.

        A missing editor is created with ``editor_factory`` (and the task's
        editors are split when ``layout`` is true).  An existing editor is
        activated only while the poll loop is not running.
        """
        if editor is None:
            editor = editor_factory()
            if layout:
                self.task.split_editors(-2, -1)
        else:
            if not self._polling:
                self.task.activate_editor(editor)

        #gather analyses
        ans = self._get_analyses(identifier, aliquot, use_date_range)

        editor.unknowns = ans
        group_analyses_by_key(editor, editor.unknowns, 'labnumber')
        #        self.task.group_by_labnumber()

        return editor

    def _get_analyses(self, identifier, aliquot=None, use_date_range=False):
        """Fetch and build the analyses for ``identifier``.

        The query is chosen in this order: by identifier and aliquot when
        ``aliquot`` is given, by the tool's date range when
        ``use_date_range`` is true, otherwise by labnumber.  All queries are
        capped at ``self.tool.limit``.
        """
        db = self.processor.db
        with db.session_ctx():
            limit = self.tool.limit
            if aliquot is not None:

                def func(a, l):
                    return l.identifier == identifier, a.aliquot == aliquot

                ans = db.get_analyses(func=func, limit=limit)
            elif use_date_range:
                end = datetime.now()
                start = end - timedelta(hours=self.tool.hours,
                                        weeks=self.tool.weeks,
                                        days=self.tool.days)

                ans = db.get_date_range_analyses(start,
                                                 end,
                                                 labnumber=identifier,
                                                 limit=limit)
            else:
                ans, tc = db.get_labnumber_analyses(identifier, limit=limit)

            return self.processor.make_analyses(ans)

    @on_trait_change('tool:[+, refresh_button]')
    def _handle_tool_change(self):
        self.run_added_handler()

    def _load_refiso(self, ref):
        pass

    def _set_name(self):
        pass

    def _get_name(self):
        return '{}-{}'.format(self.conn_spec.system_name, self.conn_spec.host)

    def _tool_default(self):
        tool = SystemMonitorControls()
        return tool

    def _console_display_default(self):
        return DisplayController(bgcolor='black',
                                 default_color='limegreen',
                                 max_blocks=100)

    def _subscriber_default(self):
        """Create the subscriber for the host and port of ``conn_spec``."""
        h = self.conn_spec.host
        p = self.conn_spec.port

        # The original message ended in a stray double quote.
        self.info('Creating subscriber to {}:{}'.format(h, p))
        sub = Subscriber(host=h, port=p, verbose=False)
        return sub
Exemple #7
0
class DockPane(TaskPane, MDockPane):
    """ The toolkit-specific implementation of a DockPane.

    See the IDockPane interface for API documentation.
    """

    # Name under which the pane is registered with the AUI manager; needed
    # to look the pane info up again when the dock state changes.
    pane_name = Str()

    # Is the pane's title bar currently shown?
    caption_visible = Bool(True)

    # AUI ring number.  Panes cannot be moved out of their ring, which makes
    # this a way to isolate panes from each other.
    dock_layer = Int(0)

    # 'IDockPane' interface ------------------------------------------------

    size = Property(Tuple)

    # Protected traits -----------------------------------------------------

    _receiving = Bool(False)

    # ------------------------------------------------------------------------
    # 'ITaskPane' interface.
    # ------------------------------------------------------------------------

    @classmethod
    def get_hierarchy(cls, parent, indent=""):
        """ Return an indented, multi-line description of a widget tree.

        Each line shows a widget and its name; children are indented by two
        more spaces than their parent.
        """
        own_line = "%s%s %s" % (indent, str(parent), parent.GetName())
        deeper = indent + "  "
        child_blocks = [
            cls.get_hierarchy(kid, deeper) for kid in parent.GetChildren()
        ]
        return "\n".join([own_line] + child_blocks)

    def create(self, parent):
        """ Create and set the dock widget that contains the pane contents.
        """
        # No wrapper control is needed on wx: the contents are the control.
        self.control = self.create_contents(parent)

        # Stay hidden until the task is activated; visibility is then taken
        # from the task state.
        self.control.Hide()

        # The AUI manager saves state by name.  Combining task ID and pane
        # ID avoids collisions when a pane is present in several tasks
        # attached to the same window.
        self.pane_name = ":".join((self.task.id, self.id))
        hierarchy = self.get_hierarchy(parent, "    ")
        logger.debug("dock_pane.create: %s  HIERARCHY:\n%s" %
                     (self.pane_name, hierarchy))

    def get_new_info(self):
        """ Return a new AuiPaneInfo configured from the pane's traits. """
        info = aui.AuiPaneInfo().Name(self.pane_name).DestroyOnClose(False)

        # size?

        # Apply every DockPane setting to the new pane info, in this order.
        configurators = (
            self.update_dock_area,
            self.update_dock_features,
            self.update_dock_title,
            self.update_floating,
            self.update_visible,
        )
        for configure in configurators:
            configure(info)

        return info

    def add_to_manager(self, row=None, pos=None, tabify_pane=None):
        """ Register the control with the window's AUI manager.

        ``row`` and ``pos`` are applied to the pane info when given;
        ``tabify_pane``, when given, supplies the target pane info.
        """
        info = self.get_new_info()
        target = None
        if tabify_pane is not None:
            target = tabify_pane.get_pane_info()
            logger.debug("dock_pane.add_to_manager: Tabify! %s onto %s" %
                         (self.pane_name, target.name))
        if row is not None:
            info.Row(row)
        if pos is not None:
            info.Position(pos)
        manager = self.task.window._aui_manager
        manager.AddPane(self.control, info, target=target)

    def validate_traits_from_pane_info(self):
        """ Sync traits from the AUI pane info.

        Useful after perspective restore to make sure e.g. visibility state
        is set correctly.
        """
        self.visible = self.get_pane_info().IsShown()

    def destroy(self):
        """ Destroy the toolkit-specific control that represents the contents.
        """
        if self.control is None:
            return

        logger.debug("Destroying %s" % self.control)
        self.task.window._aui_manager.DetachPane(self.control)

        # Some containers (e.g.  TraitsDockPane) destroy the control before
        # we get here (e.g.  traitsui.ui.UI.finish by way of
        # TraitsDockPane.destroy), so test whether it is still alive.
        # Fortunately, the Reparent in DetachPane still seems to work on a
        # destroyed control.
        if self.control:
            self.control.Hide()
            self.control.Destroy()
        self.control = None

    # ------------------------------------------------------------------------
    # 'IDockPane' interface.
    # ------------------------------------------------------------------------

    def create_contents(self, parent):
        """ Create and return the toolkit-specific contents of the dock pane.
        """
        window_name = ":".join((self.task.id, self.id))
        return wx.Window(parent, name=window_name)

    # Trait property getters/setters ---------------------------------------

    def _get_size(self):
        control = self.control
        if control is None:
            return (-1, -1)
        return control.GetSize().Get()

    # Trait change handlers ------------------------------------------------

    def get_pane_info(self):
        """ Return the AUI pane info registered under ``pane_name``. """
        return self.task.window._aui_manager.GetPane(self.pane_name)

    def commit_layout(self, layout=True):
        """ Commit pending AUI changes, with or without a new layout pass. """
        manager = self.task.window._aui_manager
        if not layout:
            manager.UpdateWithoutLayout()
        else:
            manager.Update()

    def commit_if_active(self, layout=True):
        """ Commit pending AUI changes only if this pane's task is active. """
        task = self.task
        window_control = task.window.control
        if window_control and task == task.window.active_task:
            self.commit_layout(layout)
        else:
            logger.debug("task not active so not committing...")

    def update_dock_area(self, info):
        direction = AREA_MAP[self.dock_area]
        info.Direction(direction)
        logger.debug("info: dock_area=%s dir=%s" %
                     (self.dock_area, info.dock_direction))

    @observe("dock_area")
    def _set_dock_area(self, event):
        logger.debug("trait change: dock_area")
        if self.control is None:
            return
        self.update_dock_area(self.get_pane_info())
        self.commit_if_active()

    def update_dock_features(self, info):
        info.CloseButton(self.closable)
        info.Floatable(self.floatable)
        info.Movable(self.movable)
        info.CaptionVisible(self.caption_visible)
        info.Layer(self.dock_layer)

    @observe("closable,floatable,movable,caption_visible,dock_layer")
    def _set_dock_features(self, event):
        if self.control is None:
            return
        self.update_dock_features(self.get_pane_info())
        self.commit_if_active()

    def update_dock_title(self, info):
        info.Caption(self.name)

    @observe("name")
    def _set_dock_title(self, event):
        if self.control is None:
            return
        self.update_dock_title(self.get_pane_info())

        # A name change alone does not require a full layout refresh.
        self.commit_if_active(False)

    def update_floating(self, info):
        apply_state = info.Float if self.floating else info.Dock
        apply_state()

    @observe("floating")
    def _set_floating(self, event):
        if self.control is None:
            return
        self.update_floating(self.get_pane_info())
        self.commit_if_active()

    def update_visible(self, info):
        apply_state = info.Show if self.visible else info.Hide
        apply_state()

    @observe("visible")
    def _set_visible(self, event):
        logger.debug("_set_visible %s on pane=%s, control=%s" %
                     (self.visible, self.pane_name, self.control))
        if self.control is None:
            return
        self.update_visible(self.get_pane_info())
        self.commit_if_active()
class TraitsTest(HasTraits):
    """Object with one trait per editor kind, shown in a single large view.

    The view lists each trait several times under labels 'Simple',
    'Custom', 'Text' and 'Readonly', grouped by trait category.
    """

    #-------------------------------------------------------------------------
    #  Trait definitions:
    #-------------------------------------------------------------------------

    integer_text = Int(1)
    enumeration = Enum('one', 'two', 'three', 'four', 'five', 'six', cols=3)
    float_range = Range(0.0, 10.0, 10.0)
    int_range = Range(1, 6)
    int_range2 = Range(1, 50)
    compound = Trait(1,
                     Range(1, 6), 'one', 'two', 'three', 'four', 'five', 'six')
    boolean = Bool(True)
    # NOTE(review): ``Instance`` is called without arguments here and passed
    # as ``klass`` below, so it is presumably a class defined elsewhere in
    # this module rather than the traits ``Instance`` trait type -- confirm.
    instance = Trait(Instance())
    color = Color
    font = Font
    check_list = List(editor=CheckListEditor(
        values=['one', 'two', 'three', 'four'], cols=4))
    list = List(Str,
                ['East of Eden', 'The Grapes of Wrath', 'Of Mice and Men'])
    button = Event(0, editor=ButtonEditor(label='Click'))
    file = File
    directory = Directory
    # ``origin_values`` is not defined in this class; it supplies both the
    # editor's values and the allowed values of the trait.
    image_enum = Trait(
        editor=ImageEnumEditor(
            values=origin_values, suffix='_origin', cols=4, klass=Instance),
        *origin_values)

    #-------------------------------------------------------------------------
    #  View definitions:
    #-------------------------------------------------------------------------

    # Each top-level tuple is one category ('Enum', 'Range', 'Misc', ...);
    # inside it, one tuple per trait lists the trait once per style label,
    # separated by '_' entries.
    view = View(
        ('|{Enum}', ('|<[Enumeration]', 'enumeration[Simple]', '_',
                     'enumeration[Custom]@', '_', 'enumeration[Text]*', '_',
                     'enumeration[Readonly]~'),
         ('|<[Check List]', 'check_list[Simple]', '_', 'check_list[Custom]@',
          '_', 'check_list[Text]*', '_', 'check_list[Readonly]~')),
        ('|{Range}', ('|<[Float Range]', 'float_range[Simple]', '_',
                      'float_range[Custom]@', '_', 'float_range[Text]*', '_',
                      'float_range[Readonly]~'),
         ('|<[Int Range]', 'int_range[Simple]', '_', 'int_range[Custom]@', '_',
          'int_range[Text]*', '_', 'int_range[Readonly]~'),
         ('|<[Int Range 2]', 'int_range2[Simple]', '_', 'int_range2[Custom]@',
          '_', 'int_range2[Text]*', '_', 'int_range2[Readonly]~')),
        ('|{Misc}', ('|<[Integer Text]', 'integer_text[Simple]', '_',
                     'integer_text[Custom]@', '_', 'integer_text[Text]*', '_',
                     'integer_text[Readonly]~'),
         ('|<[Compound]', 'compound[Simple]', '_', 'compound[Custom]@', '_',
          'compound[Text]*', '_', 'compound[Readonly]~'),
         ('|<[Boolean]', 'boolean[Simple]', '_', 'boolean[Custom]@', '_',
          'boolean[Text]*', '_', 'boolean[Readonly]~')),
        ('|{Color/Font}', ('|<[Color]', 'color[Simple]', '_', 'color[Custom]@',
                           '_', 'color[Text]*', '_', 'color[Readonly]~'),
         ('|<[Font]', 'font[Simple]', '_', 'font[Custom]@', '_', 'font[Text]*',
          '_', 'font[Readonly]~')),
        ('|{List}', ('|<[List]', 'list[Simple]', '_', 'list[Custom]@', '_',
                     'list[Text]*', '_', 'list[Readonly]~')),
        (
            '|{Button}',
            # The button is only listed with its Simple and Custom entries.
            ('|<[Button]', 'button[Simple]', '_', 'button[Custom]@'),
            #                                        'button[Text]*',
            #                                        'button[Readonly]~' ),
            ('|<[Image Enum]', 'image_enum[Simple]', '_',
             'image_enum[Custom]@', '_', 'image_enum[Text]*', '_',
             'image_enum[Readonly]~'),
            ('|<[Instance]', 'instance[Simple]', '_', 'instance[Custom]@', '_',
             'instance[Text]*', '_', 'instance[Readonly]~'), ),
        ('|{File}', (
            '|<[File]',
            'file[Simple]',
            '_',
            'file[Custom]@',
            '_',
            'file[Text]*',
            '_',
            'file[Readonly]~', ),
         ('|<[Directory]', 'directory[Simple]', '_', 'directory[Custom]@', '_',
          'directory[Text]*', '_', 'directory[Readonly]~')),
        buttons=['Apply', 'Revert', 'Undo', 'OK'])
class TVTKScene(HasPrivateTraits):
    """A TVTK interactor scene widget.

    This widget uses a RenderWindowInteractor and therefore supports
    interaction with VTK widgets.  The widget uses TVTK.  The widget
    also supports the following:

    - Save the scene to a bunch of common (and not so common) image
      formats.

    - save the rendered scene to the clipboard.

    - adding/removing lists/tuples of actors

    - setting the view to useful predefined views (just like in
      MayaVi).

    - If one passes `stereo=1` to the constructor, stereo rendering is
      enabled.  By default this is disabled.  Changing the stereo trait
      has no effect during runtime.

    - One can disable rendering by setting `disable_render` to True.

    """

    # The version of this class.  Used for persistence.
    __version__ = 0

    ###########################################################################
    # Traits.
    ###########################################################################

    # Turn on/off stereo rendering.  This is set on initialization and
    # has no effect once the widget is realized.
    stereo = Bool(False)

    # Perform line smoothing for all rendered lines.  This produces
    # much nicer looking lines but renders slower.  This setting works
    # only when called before the first render.
    line_smoothing = Bool(False)

    # Perform point smoothing for all rendered points.  This
    # produces much nicer looking points but renders slower.  This
    # setting works only when called before the first render.
    point_smoothing = Bool(False)

    # Perform polygon smoothing (anti-aliasing) for all rendered
    # polygons.  This produces much nicer looking polygons but renders
    # slower.  This setting works only when called before the first
    # render.
    polygon_smoothing = Bool(False)

    # Enable parallel projection.  This trait is synchronized with
    # that of the camera.
    parallel_projection = Bool(False,
                               desc='if the camera uses parallel projection')

    # Disable rendering.
    disable_render = Bool(False, desc='if rendering is to be disabled')

    # Enable off-screen rendering.  This allows a user to render the
    # scene to an image without the need to have the window active.
    # For example, the application can be minimized and the saved
    # scene should be generated correctly.  This is handy for batch
    # scripts and the like.  This works under Win32.  Under Mac OS X
    # and Linux it requires a recent VTK version (later than Oct 2005
    # and ideally later than March 2006) to work correctly.
    off_screen_rendering = Bool(False,
                                desc='if off-screen rendering is enabled')

    # The background color of the window.  This is really a shadow
    # trait of the renderer's background.  Delegation does not seem to
    # work nicely for this.
    background = Trait(vtk_color_trait((0.5, 0.5, 0.5)),
                       desc='the background color of the window')

    # The default foreground color of any actors.  This basically
    # saves the preference and actors will listen to changes --
    # the scene itself does not use this.
    foreground = Trait(vtk_color_trait((1.0, 1.0, 1.0)),
                       desc='the default foreground color of actors')

    # The magnification to use when generating images from the render
    # window.
    magnification = Range(
        1,
        2048,
        1,
        desc='the magnification used when the screen is saved to an image')

    # Specifies the number of frames to use for anti-aliasing when
    # saving a scene.  This basically increases
    # `self.render_window.aa_frames` in order to produce anti-aliased
    # figures when a scene is saved to an image.  It then restores the
    # `aa_frames` in order to get interactive rendering rates.
    anti_aliasing_frames = Range(
        0,
        20,
        8,
        desc='number of frames to use for anti-aliasing when saving a scene')

    # Default JPEG quality.
    jpeg_quality = Range(10,
                         100,
                         95,
                         desc='the quality of the JPEG image to produce')

    # Default JPEG progressive setting.
    jpeg_progressive = Bool(True,
                            desc='if the generated JPEG should be progressive')

    # The light manager.
    light_manager = Instance(light_manager.LightManager, record=True)

    # The movie maker instance.
    movie_maker = Instance('tvtk.pyface.movie_maker.MovieMaker', record=True)

    # Is the scene busy or not.
    busy = Property(Bool, record=False)

    ########################################
    # Events

    # Lifecycle events: there are no opening/opened events since the
    # control is actually created in __init__.

    # The control is going to be closed.
    closing = Event(record=False)

    # The control has been closed.
    closed = Event(record=False)

    # Event fired when an actor is added to the scene.
    actor_added = Event(record=False)
    # Event fired when any actor is removed from the scene.
    actor_removed = Event(record=False)

    ########################################
    # Properties.

    # The interactor used by the scene.
    interactor = Property(Instance(tvtk.GenericRenderWindowInteractor))

    # The render_window.
    render_window = Property(Instance(tvtk.RenderWindow))

    # The renderer.
    renderer = Property(Instance(tvtk.Renderer))

    # The camera.
    camera = Property(Instance(tvtk.Camera))

    # The control to mimic the Widget behavior.
    control = Any

    ########################################
    # Private traits.

    # A recorder for script recording.
    recorder = Instance(HasTraits, record=False, transient=True)
    # Cached last camera state.
    _last_camera_state = Any(transient=True)
    _camera_observer_id = Int(transient=True)
    _script_id = Str(transient=True)

    # Saved light_manager settings while loading a scene.  The light manager
    # may not be created at the time a scene is loaded from disk, so if it
    # is saved here, when it is created the state is set.
    _saved_light_manager_state = Any(transient=True)

    # The renderer instance.
    _renderer = Instance(tvtk.Renderer)
    _renwin = Instance(tvtk.RenderWindow)
    _interactor = Instance(tvtk.RenderWindowInteractor)
    _camera = Instance(tvtk.Camera)
    _busy_count = Int(0)

    ###########################################################################
    # 'object' interface.
    ###########################################################################
    def __init__(self, parent=None, **traits):
        """Set up the traits and build the control for the scene.

        `parent` is handed to `_create_control`; all remaining keyword
        arguments are forwarded to the base class constructor.
        """
        super(TVTKScene, self).__init__(**traits)

        # Coordinate used by the predefined `*_view` methods when they
        # place the camera.
        self._def_pos = 1
        widget = self._create_control(parent)
        self.control = widget
        # The render window only exists once the control has been created.
        self._renwin.update_traits()

    def __get_pure_state__(self):
        """Allows us to pickle the scene."""
        # The control attribute is not picklable since it is a VTK
        # object so we remove it.
        d = self.__dict__.copy()
        for x in [
                'control', '_renwin', '_interactor', '_camera', '_busy_count',
                '__sync_trait__', 'recorder', '_last_camera_state',
                '_camera_observer_id', '_saved_light_manager_state',
                '_script_id', '__traits_listener__'
        ]:
            d.pop(x, None)
        # Additionally pickle these.
        d['camera'] = self.camera
        return d

    def __getstate__(self):
        """Return the scene serialized by `state_pickler.dumps`."""
        pickled = state_pickler.dumps(self)
        return pickled

    def __setstate__(self, str_state):
        """Restore the scene from a string produced by `__getstate__`.

        The scene is almost never pickled on its own, only through an
        object that contains it (in which case `__init__` runs when the
        scene is constructed), so this exists mainly for completeness.
        """
        loaded = state_pickler.loads_state(str_state)
        state_pickler.set_state(self, loaded)

    ###########################################################################
    # 'event' interface.
    ###########################################################################
    def _closed_fired(self):
        self.light_manager = None
        self._interactor = None
        self.movie_maker = None

    ###########################################################################
    # 'Scene' interface.
    ###########################################################################
    def render(self):
        """Render the scene now.

        Does nothing while the `disable_render` trait is True.
        """
        if self.disable_render:
            return
        self._renwin.render()

    def add_actors(self, actors):
        """Add one actor, or every actor of a tuple/list, to the renderer.

        Fires `actor_added` with the argument as given.  When the
        renderer held no actors and no volumes beforehand the zoom is
        reset, otherwise the scene is simply re-rendered.
        """
        renderer = self._renderer
        # Remember whether these are the first props before adding any.
        was_empty = (len(renderer.actors) == 0
                     and len(renderer.volumes) == 0)
        batch = actors if hasattr(actors, '__iter__') else (actors,)
        for item in batch:
            self._renderer.add_actor(item)
        self.actor_added = actors

        refresh = self.reset_zoom if was_empty else self.render
        refresh()

    def remove_actors(self, actors):
        """Remove one actor, or every actor of a tuple/list, from the
        renderer.

        Fires `actor_removed` with the argument as given and re-renders.
        """
        batch = actors if hasattr(actors, '__iter__') else (actors,)
        for item in batch:
            self._renderer.remove_actor(item)
        self.actor_removed = actors
        self.render()

    # Convenience methods.
    add_actor = add_actors
    remove_actor = remove_actors

    def add_widgets(self, widgets, enabled=True):
        """Attach one 3D widget, or each widget of a sequence, to the
        scene's interactor and re-render.

        Every widget's `enabled` attribute is set to `enabled`.
        """
        items = widgets if hasattr(widgets, '__iter__') else [widgets]
        interactor = self._interactor
        for item in items:
            item.interactor = interactor
            item.enabled = enabled
        self.render()

    def remove_widgets(self, widgets):
        """Detach one 3D widget, or each widget of a sequence, from the
        scene and re-render.

        A widget whose `interactor` is already None is left untouched;
        any other widget is disabled and its interactor is cleared.
        """
        if not hasattr(widgets, '__iter__'):
            widgets = [widgets]
        for widget in widgets:
            if widget.interactor is not None:
                # Disable first, while the widget still has an interactor.
                widget.enabled = False
                widget.interactor = None
        self.render()

    def close(self):
        """Shut the scene down cleanly.

        Call this if you are getting async errors when closing a scene
        from a UI.  The sequence is based on the observations of Charl
        Botha here:

          http://public.kitware.com/pipermail/vtkusers/2008-May/095291.html

        Calling it on an already closed scene does nothing.
        """
        if self._renwin is None:
            # Already closed.
            return

        # Announce that the scene is about to close.
        self.closing = True
        # Keep trait listener callbacks from rendering during teardown.
        self.disable_render = True
        # Undo the trait synchronisation with the VTK objects.
        synced = (
            ('background', self._renderer),
            ('parallel_projection', self.camera),
            ('off_screen_rendering', self._renwin),
        )
        for trait_name, target in synced:
            self.sync_trait(trait_name, target, remove=True)

        # Empty the renderer, then let the render window release its
        # resources and the OpenGL context.
        self._renderer.remove_all_view_props()
        self._renwin.finalize()
        # Detach the interactor and forget the render window.
        self._interactor.render_window = None
        del self._renwin
        # Announce that the scene is closed.
        self.closed = True

    def x_plus_view(self):
        """Look at the scene down the +X axis and record the call."""
        distance = self._def_pos
        self._update_view(distance, 0, 0, 0, 0, 1)
        self._record_methods('x_plus_view()')

    def x_minus_view(self):
        """Look at the scene down the -X axis and record the call."""
        distance = -self._def_pos
        self._update_view(distance, 0, 0, 0, 0, 1)
        self._record_methods('x_minus_view()')

    def z_plus_view(self):
        """Look at the scene down the +Z axis and record the call."""
        distance = self._def_pos
        self._update_view(0, 0, distance, 0, 1, 0)
        self._record_methods('z_plus_view()')

    def z_minus_view(self):
        """Look at the scene down the -Z axis and record the call."""
        distance = -self._def_pos
        self._update_view(0, 0, distance, 0, 1, 0)
        self._record_methods('z_minus_view()')

    def y_plus_view(self):
        """Look at the scene down the +Y axis and record the call."""
        distance = self._def_pos
        self._update_view(0, distance, 0, 1, 0, 0)
        self._record_methods('y_plus_view()')

    def y_minus_view(self):
        """Look at the scene down the -Y axis and record the call."""
        distance = -self._def_pos
        self._update_view(0, distance, 0, 1, 0, 0)
        self._record_methods('y_minus_view()')

    def isometric_view(self):
        """Switch to an isometric view and record the call.

        The camera is placed at equal distance along X, Y and Z with
        +Z as the view-up direction.
        """
        distance = self._def_pos
        self._update_view(distance, distance, distance, 0, 0, 1)
        self._record_methods('isometric_view()')

    def reset_zoom(self):
        """Reset the renderer's camera, re-render and record the call."""
        renderer = self._renderer
        renderer.reset_camera()
        self.render()
        self._record_methods('reset_zoom()')

    def save(self, file_name, size=None, **kw_args):
        """Saves rendered scene to one of several image formats
        depending on the specified extension of the filename.

        If an additional size (2-tuple) argument is passed the window
        is resized to the specified size in order to produce a
        suitably sized output image.  Please note that when the window
        is resized, the window may be obscured by other widgets and
        the camera zoom is not reset which is likely to produce an
        image that does not reflect what is seen on screen.  The
        original window size is restored afterwards, even when the
        export itself fails.

        Any extra keyword arguments are passed along to the respective
        image format's save method.

        Raises ValueError when the (case-insensitive) file extension is
        not one of the supported ones.
        """
        ext = os.path.splitext(file_name)[1].lower()
        # Maps a file extension to the suffix of the `save_*` method
        # that handles it.
        meth_map = {
            '.ps': 'ps',
            '.bmp': 'bmp',
            '.tiff': 'tiff',
            '.png': 'png',
            '.jpg': 'jpg',
            '.jpeg': 'jpg',
            '.iv': 'iv',
            '.wrl': 'vrml',
            '.vrml': 'vrml',
            '.oogl': 'oogl',
            '.rib': 'rib',
            '.obj': 'wavefront',
            '.eps': 'gl2ps',
            '.pdf': 'gl2ps',
            '.tex': 'gl2ps',
            '.x3d': 'x3d',
            '.pov': 'povray',
        }
        if ext not in meth_map:
            raise ValueError(
                'Unable to find suitable image type for given file extension.')
        meth = getattr(self, 'save_' + meth_map[ext])

        if size is None:
            meth(file_name, **kw_args)
            self._record_methods('save(%r)' % (file_name,))
            return

        orig_size = self.get_size()
        self.set_size(size)
        try:
            meth(file_name, **kw_args)
        finally:
            # Do not leave the window at the export size if the writer
            # raised.
            self.set_size(orig_size)
        # Only a successful save is recorded.
        self._record_methods('save(%r, %r)' % (file_name, size))

    def save_ps(self, file_name):
        """Write the rendered scene to a rasterized PostScript image.

        Use `save_gl2ps` for vector graphics.  Nothing is written when
        `file_name` is empty.
        """
        if len(file_name) == 0:
            return
        image_source = self._get_window_to_image()
        writer = tvtk.PostScriptWriter()
        writer.file_name = file_name
        configure_input(writer, image_source)
        self._exporter_write(writer)

    def save_bmp(self, file_name):
        """Write the rendered scene to a BMP image file.

        Nothing is written when `file_name` is empty.
        """
        if len(file_name) == 0:
            return
        image_source = self._get_window_to_image()
        writer = tvtk.BMPWriter()
        writer.file_name = file_name
        configure_input(writer, image_source)
        self._exporter_write(writer)

    def save_tiff(self, file_name):
        """Write the rendered scene to a TIFF image file.

        Nothing is written when `file_name` is empty.
        """
        if len(file_name) == 0:
            return
        image_source = self._get_window_to_image()
        writer = tvtk.TIFFWriter()
        writer.file_name = file_name
        configure_input(writer, image_source)
        self._exporter_write(writer)

    def save_png(self, file_name):
        """Write the rendered scene to a PNG image file.

        Nothing is written when `file_name` is empty.
        """
        if len(file_name) == 0:
            return
        image_source = self._get_window_to_image()
        writer = tvtk.PNGWriter()
        writer.file_name = file_name
        configure_input(writer, image_source)
        self._exporter_write(writer)

    def save_jpg(self, file_name, quality=None, progressive=None):
        """Save the rendered scene to a JPEG image file.

        Keyword Arguments:

        file_name -- File name to save to.  Nothing is written when it
        is empty.

        quality -- Quality of the JPEG; the `jpeg_quality` trait allows
        10-100.  Defaults to None, meaning the `jpeg_quality` trait is
        used.

        progressive -- Toggles progressive JPEGs.  Defaults to None,
        meaning the `jpeg_progressive` trait is used.
        """
        if len(file_name) == 0:
            return
        # Resolve each option on its own so that an explicit
        # `progressive=False` is honoured and passing only one of the
        # two never leaves None on the writer.
        if quality is None:
            quality = self.jpeg_quality
        if progressive is None:
            progressive = self.jpeg_progressive
        w2if = self._get_window_to_image()
        ex = tvtk.JPEGWriter()
        ex.quality = quality
        ex.progressive = progressive
        ex.file_name = file_name
        configure_input(ex, w2if)
        self._exporter_write(ex)

    def save_iv(self, file_name):
        """Export the scene to an OpenInventor file.

        Nothing is written when `file_name` is empty.
        """
        if len(file_name) == 0:
            return
        exporter = tvtk.IVExporter()
        self._lift()
        exporter.input = self._renwin
        exporter.file_name = file_name
        self._exporter_write(exporter)

    def save_vrml(self, file_name):
        """Export the scene to a VRML file.

        Nothing is written when `file_name` is empty.
        """
        if len(file_name) == 0:
            return
        exporter = tvtk.VRMLExporter()
        self._lift()
        exporter.input = self._renwin
        exporter.file_name = file_name
        self._exporter_write(exporter)

    def save_oogl(self, file_name):
        """Export the scene to a Geomview OOGL file.  Requires VTK 4 to
        work.

        Nothing is written when `file_name` is empty.
        """
        if len(file_name) == 0:
            return
        exporter = tvtk.OOGLExporter()
        self._lift()
        exporter.input = self._renwin
        exporter.file_name = file_name
        self._exporter_write(exporter)

    def save_rib(self, file_name, bg=0, resolution=None, resfactor=1.0):
        """Save scene to a RenderMan RIB file.

        Keyword Arguments:

        file_name -- File name to save to.  Its extension is stripped
        to form the exporter's file prefix; nothing is written when it
        is empty.

        bg -- Background option, assigned to the exporter's
        `background` attribute.  Defaults to 0, i.e. no background is
        saved.

        resolution -- Specify the resolution of the generated image in
        the form of a tuple (nx, ny).  Defaults to None, in which case
        the present render window size is used.

        resfactor -- The resolution factor which scales the resolution.

        Raises TypeError when `resolution` cannot be unpacked because
        it is not a sequence.
        """
        if resolution is None:
            # get present window size
            Nx, Ny = self.render_window.size
        else:
            try:
                Nx, Ny = resolution
            except TypeError:
                # Only a non-iterable `resolution` ends up here; a
                # sequence of the wrong length raises ValueError, which
                # is not caught.
                raise TypeError(
                    "Resolution (%s) should be a sequence with two elements" %
                    resolution)

        # Note that `resolution` is validated before this early exit.
        if len(file_name) == 0:
            return

        f_pref = os.path.splitext(file_name)[0]
        ex = tvtk.RIBExporter()
        ex.size = int(resfactor * Nx), int(resfactor * Ny)
        ex.file_prefix = f_pref
        ex.texture_prefix = f_pref + "_tex"
        self._lift()
        ex.render_window = self._renwin
        ex.background = bg

        if VTK_VER[:3] in ['4.2', '4.4']:
            # The vtkRIBExporter is broken in respect to VTK light
            # types.  Therefore we need to convert all lights into
            # scene lights before the save and later convert them
            # back.

            ########################################
            # Internal functions
            def x3to4(x):
                # convert 3-vector to 4-vector (w=1 -> point in space)
                return (x[0], x[1], x[2], 1.0)

            def x4to3(x):
                # convert 4-vector to 3-vector
                return (x[0], x[1], x[2])

            def cameralight_transform(light, xform, light_type):
                # transform light by 4x4 matrix xform
                origin = x3to4(light.position)
                focus = x3to4(light.focal_point)
                neworigin = xform.multiply_point(origin)
                newfocus = xform.multiply_point(focus)
                light.position = x4to3(neworigin)
                light.focal_point = x4to3(newfocus)
                light.light_type = light_type

            ########################################

            # Remember every light's type so it can be restored below.
            save_lights_type = []
            for light in self.light_manager.lights:
                save_lights_type.append(light.source.light_type)

            # Convert lights to scene lights.
            cam = self.camera
            xform = tvtk.Matrix4x4()
            xform.deep_copy(cam.camera_light_transform_matrix)
            for light in self.light_manager.lights:
                cameralight_transform(light.source, xform, "scene_light")

            # Write the RIB file.
            self._exporter_write(ex)

            # Now re-convert lights to camera lights.
            xform.invert()
            for i, light in enumerate(self.light_manager.lights):
                cameralight_transform(light.source, xform, save_lights_type[i])

            # Change the camera position. Otherwise VTK would render
            # one broken frame after the export.
            cam.roll(0.5)
            cam.roll(-0.5)
        else:
            self._exporter_write(ex)

    def save_wavefront(self, file_name):
        """Export the scene to a Wavefront OBJ file.  Two files are
        generated.  One with a .obj extension and another with a .mtl
        extension which contains the material properties.

        Keyword Arguments:

        file_name -- File name to save to; its extension is stripped to
        form the exporter's file prefix.  Nothing is written when it is
        empty.
        """
        if len(file_name) == 0:
            return
        exporter = tvtk.OBJExporter()
        self._lift()
        exporter.input = self._renwin
        prefix = os.path.splitext(file_name)[0]
        exporter.file_prefix = prefix
        self._exporter_write(exporter)

    def save_gl2ps(self, file_name, exp=None):
        """Save scene to a vector PostScript/EPS/PDF/TeX file using
        GL2PS.  If you choose to use a TeX file then note that only
        the text output is saved to the file.  You will need to save
        the graphics separately.

        Keyword Arguments:

        file_name -- File name to save to.  Nothing is written when it
        is empty.

        exp -- Optionally configured vtkGL2PSExporter object.
        Defaults to None and this will use the default settings with
        the output file type chosen based on the extension of the file
        name (compared case-insensitively; anything other than .ps,
        .tex or .pdf is written as EPS).

        Raises TypeError when `exp` is given but is not a
        `tvtk.GL2PSExporter`.
        """

        # Make sure the exporter is available.
        if not hasattr(tvtk, 'GL2PSExporter'):
            msg = "Saving as a vector PS/EPS/PDF/TeX file using GL2PS is "\
                  "either not supported by your version of VTK or "\
                  "you have not configured VTK to work with GL2PS -- read "\
                  "the documentation for the vtkGL2PSExporter class."
            print(msg)
            return

        if len(file_name) == 0:
            return

        f_prefix, f_ext = os.path.splitext(file_name)
        # `save` dispatches on the lower-cased extension, so do the same
        # here; otherwise 'figure.PDF' would silently be written as EPS.
        f_ext = f_ext.lower()
        if exp:
            if not isinstance(exp, tvtk.GL2PSExporter):
                msg = "Need a vtkGL2PSExporter you passed a "\
                      "%s"%exp.__class__.__name__
                raise TypeError(msg)
            ex = exp
            ex.file_prefix = f_prefix
        else:
            ex = tvtk.GL2PSExporter()
            # defaults
            ex.file_prefix = f_prefix
            formats = {'.ps': 'ps', '.tex': 'tex', '.pdf': 'pdf'}
            ex.file_format = formats.get(f_ext, 'eps')
            ex.sort = 'bsp'
            ex.compress = 1
            # Let the user review the exporter settings before writing.
            ex.edit_traits(kind='livemodal')

        self._lift()
        ex.render_window = self._renwin
        if ex.write3d_props_as_raster_image:
            self._exporter_write(ex)
        else:
            ex.write()
        # Work around for a bug in VTK where it saves the file as a
        # .pdf.gz when the file is really a PDF file.
        if f_ext == '.pdf' and os.path.exists(f_prefix + '.pdf.gz'):
            os.rename(f_prefix + '.pdf.gz', file_name)
    def save_x3d(self, file_name):
        """Export the scene to an X3D file (http://www.web3d.org/x3d/).

        Keyword Arguments:

        file_name -- File name to save to.  Nothing is written when it
        is empty.

        If `tvtk` has no `X3DExporter` a message is printed and nothing
        is written.
        """
        if not hasattr(tvtk, 'X3DExporter'):
            print("Saving as a X3D file does not appear to be  "
                  "supported by your version of VTK.")
            return

        if len(file_name) == 0:
            return
        exporter = tvtk.X3DExporter()
        exporter.input = self._renwin
        exporter.file_name = file_name
        exporter.update()
        exporter.write()

    def save_povray(self, file_name):
        """Export the scene to a POVRAY (Persistence of Vision
        Raytracer) file (http://www.povray.org).

        Keyword Arguments:

        file_name -- File name to save to.  Nothing is written when it
        is empty.

        If `tvtk` has no `POVExporter` a message is printed and nothing
        is written.
        """
        if not hasattr(tvtk, 'POVExporter'):
            print("Saving as a POVRAY file does not appear to be  "
                  "supported by your version of VTK.")
            return

        if len(file_name) == 0:
            return
        exporter = tvtk.POVExporter()
        exporter.input = self._renwin
        # Exporters without a `file_name` attribute take a prefix instead.
        if not hasattr(exporter, 'file_name'):
            exporter.file_prefix = os.path.splitext(file_name)[0]
        else:
            exporter.file_name = file_name
        exporter.update()
        exporter.write()

    def get_size(self):
        """Return the size reported by the scene's interactor."""
        interactor = self._interactor
        return interactor.size

    def set_size(self, size):
        """Apply `size` to the interactor and then to the render window."""
        for target in (self._interactor, self._renwin):
            target.size = size

    ###########################################################################
    # Properties.
    ###########################################################################
    def _get_interactor(self):
        """Returns the vtkRenderWindowInteractor of the parent class"""
        return self._interactor

    def _set_interactor(self, iren):
        if self._interactor is not None:
            self._interactor.render_window = None
        self._interactor = iren
        iren.render_window = self._renwin

    def _get_render_window(self):
        """Returns the scene's render window."""
        return self._renwin

    def _get_renderer(self):
        """Returns the scene's renderer."""
        return self._renderer

    def _get_camera(self):
        """ Returns the active camera. """
        return self._renderer.active_camera

    def _get_busy(self):
        return self._busy_count > 0

    def _set_busy(self, value):
        """The `busy` trait is either `True` or `False`.  However,
        this could be problematic since we could have two methods
        `foo` and `bar that both set `scene.busy = True`.  As soon as
        `bar` is done it sets `busy` back to `False`.  This is wrong
        since the UI is still busy as `foo` is not done yet.  We
        therefore store the number of busy calls and either increment
        it or decrement it and change the state back to `False` only
        when the count is zero.
        """
        bc = self._busy_count
        if value:
            bc += 1
        else:
            bc -= 1
            bc = max(0, bc)

        self._busy_count = bc
        if bc == 1:
            self.trait_property_changed('busy', False, True)
        if bc == 0:
            self.trait_property_changed('busy', True, False)

    ###########################################################################
    # Non-public interface.
    ###########################################################################
    def _create_control(self, parent):
        """Create the render window, interactor and renderer for the
        scene and return the interactor, which serves as the control.

        `parent` is not used by this implementation.  Besides creating
        the VTK objects this wires up the trait synchronisation between
        the scene and the renderer, camera and render window, and
        creates the light manager.
        """

        if self.off_screen_rendering:
            # Use the first off-screen capable render window class that
            # tvtk provides, falling back to a plain render window.
            if hasattr(tvtk, 'EGLRenderWindow'):
                renwin = tvtk.EGLRenderWindow()
            elif hasattr(tvtk, 'OSOpenGLRenderWindow'):
                renwin = tvtk.OSOpenGLRenderWindow()
            else:
                renwin = tvtk.RenderWindow()
                # If we are doing offscreen rendering we set the window size to
                # (1,1) so the window does not appear at all
                renwin.size = (1, 1)

            self._renwin = renwin
            self._interactor = tvtk.GenericRenderWindowInteractor(
                render_window=renwin)
        else:
            renwin = self._renwin = tvtk.RenderWindow()
            self._interactor = tvtk.RenderWindowInteractor(
                render_window=renwin)

        # Pass the smoothing settings on to the render window.
        renwin.trait_set(point_smoothing=self.point_smoothing,
                         line_smoothing=self.line_smoothing,
                         polygon_smoothing=self.polygon_smoothing)
        # Create a renderer and add it to the renderwindow
        self._renderer = tvtk.Renderer()
        renwin.add_renderer(self._renderer)
        # Save a reference to our camera so it is not GC'd -- needed for
        # the sync_traits to work.
        self._camera = self.camera

        # Sync various traits.  Each VTK object first receives the
        # scene's current value, then the two traits are kept in sync.
        self._renderer.background = self.background
        self.sync_trait('background', self._renderer)
        self._renderer.on_trait_change(self.render, 'background')
        self._camera.parallel_projection = self.parallel_projection
        self.sync_trait('parallel_projection', self._camera)
        renwin.off_screen_rendering = self.off_screen_rendering
        self.sync_trait('off_screen_rendering', self._renwin)
        # Re-render whenever one of these VTK-side traits changes.
        self.render_window.on_trait_change(self.render, 'off_screen_rendering')
        self.render_window.on_trait_change(self.render, 'stereo_render')
        self.render_window.on_trait_change(self.render, 'stereo_type')
        self.camera.on_trait_change(self.render, 'parallel_projection')

        self._interactor.initialize()
        self._interactor.render()
        self.light_manager = light_manager.LightManager(self)
        if self.off_screen_rendering:
            # We want the default size to be the normal (300, 300).
            # Setting the size now should not resize the window if
            # offscreen is working properly in VTK.
            renwin.size = (300, 300)

        return self._interactor

    def _get_window_to_image(self):
        """Return a `WindowToImageFilter` set up to grab the render
        window.

        The front buffer is read unless off-screen rendering is on, and
        the `magnification` trait is applied to the filter.
        """
        use_front_buffer = not self.off_screen_rendering
        grabber = tvtk.WindowToImageFilter(
            read_front_buffer=use_front_buffer)
        set_magnification(grabber, self.magnification)
        self._lift()
        grabber.input = self._renwin
        return grabber

    def _lift(self):
        """Lift the window to the top. Useful when saving screen to an
        image."""
        return

    def _exporter_write(self, ex):
        """Abstracts the exporter's write method."""
        # Bumps up the anti-aliasing frames when the image is saved so
        # that the saved picture looks nicer.
        rw = self.render_window
        if hasattr(rw, 'aa_frames'):
            aa_frames = rw.aa_frames
            rw.aa_frames = self.anti_aliasing_frames
        else:
            aa_frames = rw.multi_samples
            rw.multi_samples = self.anti_aliasing_frames
        rw.render()
        ex.update()
        ex.write()
        # Set the frames back to original setting.
        if hasattr(rw, 'aa_frames'):
            rw.aa_frames = aa_frames
        else:
            rw.multi_samples = aa_frames
        rw.render()

    def _update_view(self, x, y, z, vx, vy, vz):
        """Used internally to set the view."""
        camera = self.camera
        camera.focal_point = 0.0, 0.0, 0.0
        camera.position = x, y, z
        camera.view_up = vx, vy, vz
        self._renderer.reset_camera()
        self.render()

    def _disable_render_changed(self, val):
        if not val and self._renwin is not None:
            self.render()

    def _record_methods(self, calls):
        """A method to record a simple method called on self.  We need a
        more powerful and less intrusive way like decorators to do this.
        Note that calls can be a string with new lines in which case we
        interpret this as multiple calls.
        """
        r = self.recorder
        if r is not None:
            sid = self._script_id
            for call in calls.split('\n'):
                r.record('%s.%s' % (sid, call))

    def _record_camera_position(self, vtk_obj=None, event=None):
        """Callback to record the camera position."""
        r = self.recorder
        if r is not None:
            state = self._get_camera_state()
            lcs = self._last_camera_state
            if state != lcs:
                self._last_camera_state = state
                sid = self._script_id
                for key, value in state:
                    r.record('%s.camera.%s = %r' % (sid, key, value))
                r.record('%s.camera.compute_view_plane_normal()' % sid)
                r.record('%s.render()' % sid)

    def _get_camera_state(self):
        c = self.camera
        state = []
        state.append(('position', list(c.position)))
        state.append(('focal_point', list(c.focal_point)))
        state.append(('view_angle', c.view_angle))
        state.append(('view_up', list(c.view_up)))
        state.append(('clipping_range', list(c.clipping_range)))
        return state

    def _recorder_changed(self, r):
        """Hook camera recording up to, or detach it from, the
        interactor when the recorder changes.

        With a recorder set, an observer for 'EndInteractionEvent' is
        added so the camera position is recorded after an interaction;
        when the recorder is cleared the observer and the messenger
        connection are removed again.
        """
        interactor = self._interactor
        if r is None:
            self._script_id = ''
            interactor.remove_observer(self._camera_observer_id)
            vtk_interactor = tvtk.to_vtk(interactor)
            messenger.disconnect(vtk_interactor, 'EndInteractionEvent',
                                 self._record_camera_position)
            return

        self._script_id = r.get_script_id(self)
        observer_id = interactor.add_observer('EndInteractionEvent',
                                              messenger.send)
        # Kept so that the observer can be removed later.
        self._camera_observer_id = observer_id
        vtk_interactor = tvtk.to_vtk(interactor)
        messenger.connect(vtk_interactor, 'EndInteractionEvent',
                          self._record_camera_position)

    def _light_manager_changed(self, lm):
        if lm is not None:
            if self._saved_light_manager_state is not None:
                lm.__set_pure_state__(self._saved_light_manager_state)
                self._saved_light_manager_state = None

    def _movie_maker_default(self):
        """Default for the `movie_maker` trait: a `MovieMaker` bound to
        this scene."""
        # Imported locally, only when the default is first needed.
        from tvtk.pyface.movie_maker import MovieMaker
        maker = MovieMaker(scene=self)
        return maker
Exemple #10
0
class Controller(HasTraits):
    """Feeds sensor readings into a plot viewer on every timer tick.

    NOTE(review): the trait names and the view still describe a random
    signal generator (distribution, mean, stddev), but ``timer_tick`` takes
    its value from ``accel.read()``.  ``accel`` is not defined in this class
    -- confirm where it comes from.
    """

    # A reference to the plot viewer object
    viewer = Instance(Viewer)

    # Parameters exposed in the view.  NOTE(review): 'mean' and 'stddev' are
    # not read by any method of this class.
    distribution_type = Enum("normal")
    mean = Float(0.0)
    stddev = Float(1.0)

    # The max number of data points to accumulate and show in the plot
    max_num_points = Int(100)

    # The number of data points we have received; we need to keep track of
    # this in order to generate the correct x axis data series.
    num_ticks = Int(0)

    # Private callable trait, initialized to np.random.normal.
    # NOTE(review): nothing in this class calls it.
    _generator = Trait(np.random.normal, Callable)

    view = View(Group('distribution_type',
                      'mean',
                      'stddev',
                      'max_num_points',
                      orientation="vertical"),
                buttons=["OK", "Cancel"])

    def timer_tick(self, *args):
        """
        Callback intended to be called on a timer tick.  Reads one sample
        from ``accel``, appends its first component to the viewer's `.data`
        array (truncated to the most recent points) and updates the viewer's
        `.index` to match.
        """
        # Read a sample (only the first component is used) and increment the
        # tick count
        x, y, z = accel.read()
        new_val = x
        self.num_ticks += 1

        # grab the existing data, truncate it, and append the new point.
        # This isn't the most efficient thing in the world but it works.
        cur_data = self.viewer.data
        new_data = np.hstack((cur_data[-self.max_num_points + 1:], [new_val]))
        # The index ends at num_ticks; the +0.01 on the stop value keeps
        # num_ticks itself inside the half-open arange interval.
        new_index = np.arange(self.num_ticks - len(new_data) + 1,
                              self.num_ticks + 0.01)

        self.viewer.index = new_index
        self.viewer.data = new_data
        return

    def _distribution_type_changed(self):
        """Runs when ``distribution_type`` changes (name follows the Traits
        ``_<name>_changed`` pattern).

        NOTE(review): the loop below has no exit condition, so this call
        never returns and the final assignment is unreachable.  That
        assignment would also store a sensor reading in a trait declared as
        Callable.  Confirm the intended behaviour.
        """
        # This listens for a change in the type of distribution to use.
        while True:
            # Read the X, Y, Z axis acceleration values and print them.
            x, y, z = accel.read()
            print('X={0}, Y={1}, Z={2}'.format(x, y, z))
            # Wait 0.1 seconds and repeat.
            time.sleep(0.1)
        self._generator = x
Exemple #11
0
class TLoop(HasTraits):
    """Time-stepping loop with an iterative predictor/corrector per step.

    Each time step repeatedly asks the time stepper for a residual and a
    system matrix, solves for a displacement increment and accepts the step
    once the residual norm drops below ``tolerance``.
    """

    # Time stepper providing boundary conditions and the
    # predictor/corrector evaluation.
    ts = Instance(TStepper)
    # Size of one time increment.
    d_t = Float(0.01)
    # Time at which stepping stops.
    t_max = Float(1.0)
    # Iteration count at which non-convergence is reported.
    k_max = Int(50)
    # Threshold on the residual norm for accepting a step.
    tolerance = Float(1e-8)

    def eval(self):
        """Run the time loop and return the recorded history.

        Returns a tuple ``(U_record, F_record, sf_record, t_record,
        eps_record, sig_record)``.  Each record starts with an all-zero
        entry (time 0.0 for ``t_record``) and gains one entry per converged
        time step.  ``t_record``, ``eps_record`` and ``sig_record`` are
        converted to arrays on return.

        A step that does not converge is reported on stdout and skipped; no
        entry is recorded for it.
        """

        self.ts.apply_essential_bc()

        t_n = 0.
        t_n1 = t_n
        n_dofs = self.ts.domain.n_dofs
        n_e = self.ts.domain.n_active_elems
        n_ip = self.ts.fets_eval.n_gp
        n_s = self.ts.mats_eval.n_s
        U_k = np.zeros(n_dofs)
        eps = np.zeros((n_e, n_ip, n_s))
        sig = np.zeros((n_e, n_ip, n_s))
        alpha = np.zeros((n_e, n_ip))
        q = np.zeros((n_e, n_ip))
        kappa = np.zeros((n_e, n_ip))

        U_record = np.zeros(n_dofs)
        F_record = np.zeros(n_dofs)
        # NOTE(review): the initial row has length 2 * n_e, while the rows
        # stacked below have the length of sig[:, :, 1].flatten(); these
        # only agree when that is 2 * n_e -- confirm n_ip == 2 is intended.
        sf_record = np.zeros(2 * n_e)
        t_record = [t_n]
        eps_record = [np.zeros_like(eps)]
        sig_record = [np.zeros_like(sig)]

        while t_n1 <= self.t_max - self.d_t:
            t_n1 = t_n + self.d_t
            k = 0
            step_flag = 'predictor'
            d_U = np.zeros(n_dofs)
            d_U_k = np.zeros(n_dofs)
            while k <= self.k_max:
                if k == self.k_max:  # handling non-convergence
                    # Single-argument print calls behave the same under
                    # Python 2 and Python 3.
                    print(np.amax(kappa))
                    print(t_n1)
                    print('non-convergence')
                    # NOTE: no step-size reduction or state rollback is
                    # done here; the loop makes one more attempt and then
                    # moves on to the next time step.

                R, K, eps, sig, alpha, q, kappa = self.ts.get_corr_pred(
                    step_flag, U_k, d_U_k, eps, sig, t_n, t_n1, alpha, q,
                    kappa)

                # Taken before the constraints are applied to R.
                F_ext = -R
                K.apply_constraints(R)
                d_U_k = K.solve()
                d_U += d_U_k
                if np.linalg.norm(R) < self.tolerance:
                    # Converged: commit the increment and record the step.
                    F_record = np.vstack((F_record, F_ext))
                    U_k += d_U
                    U_record = np.vstack((U_record, U_k))
                    sf_record = np.vstack((sf_record, sig[:, :, 1].flatten()))
                    eps_record.append(np.copy(eps))
                    sig_record.append(np.copy(sig))
                    t_record.append(t_n1)
                    break
                k += 1
                step_flag = 'corrector'

            t_n = t_n1
        return U_record, F_record, sf_record, np.array(t_record), np.array(
            eps_record), np.array(sig_record)
class UiService(HasTraits):
    """
    A service to enable UI interactions with the single project plugin.

    """

    ##########################################################################
    # Attributes
    ##########################################################################

    #### public 'UiService' interface ########################################

    # The manager of the default context menu
    default_context_menu_manager = Instance(MenuManager)

    # A reference to our plugin's model service.
    model_service = Instance(ModelService)

    # The project control (in our case a tree). This is created by the
    # project view.  Provided here so that sub-classes may access it.
    project_control = Any

    # Fired when a new project has been created.  The value should be the
    # project instance that was created.
    project_created = Event

    # A timer to implement automatic project saving.
    timer = Instance(Timer)

    # The interval (minutes) at which automatic saving should occur.
    autosave_interval = Int(5)

    ##########################################################################
    # 'object' interface.
    ##########################################################################

    #### operator methods ####################################################

    def __init__(self, model_service, menu_manager, **traits):
        """
        Constructor.

        Extended to require a reference to the plugin's model service to create
        an instance.

        *model_service* is stored as the ``model_service`` trait and
        *menu_manager* as ``default_context_menu_manager``; any remaining
        keyword arguments are passed on as further trait values.

        """

        super(UiService,
              self).__init__(model_service=model_service,
                             default_context_menu_manager=menu_manager,
                             **traits)
        try:
            # Bind the autosave interval to the value specified in the
            # single project preferences
            # NOTE(review): the third argument is the literal 5 -- confirm
            # that this is what bind_preference expects in that position.
            p = self.model_service.preferences
            bind_preference(self, 'autosave_interval', 5, p)
        except Exception:
            # Best effort: a failed binding is logged and construction
            # continues.  Only Exception is caught so that
            # KeyboardInterrupt and SystemExit still propagate.
            logger.exception('Failed to bind autosave_interval in [%s] to '
                             'preferences.', self)

        return

    ##########################################################################
    # 'UiService' interface.
    ##########################################################################

    #### public interface ####################################################

    def close(self, event):
        """
        Close the project that is currently open, if there is one.

        Does nothing when the user cancels while being asked about unsaved
        changes.

        """

        # Let the user deal with unsaved changes first; bail out on cancel.
        if not self.is_current_project_saved(event.window.control):
            return

        # Without a current project there is nothing to close.
        open_project = self.model_service.project
        if open_project is None:
            return

        logger.debug("Closing Project [%s]", open_project.name)
        self.model_service.project = None

    def create(self, event):
        """
        Create a new project and, if the user confirms the edit dialog,
        make it the current project and fire ``project_created``.

        """
        # Let the user deal with unsaved changes first; bail out on cancel.
        if not self.is_current_project_saved(event.window.control):
            return

        # The registered factory may decline to create a project.
        new_project = self.model_service.factory.create()
        if new_project is None:
            return

        # Allow the user to customize the new project.
        # FIXME: Due to a bug in traits, using a wizard dialog causes all of
        # the Instance traits on the object being edited to be replaced with
        # new instances without any listeners on those traits being called.
        # Since we can't guarantee that our projects don't have Instance
        # traits, we can't use the wizard dialog type (kind = 'wizard').
        ui = new_project.edit_traits(parent=event.window.control,
                                     kind='livemodal')

        # Only an OK from the dialog makes this the current project.
        if ui.result:
            logger.debug("Created Project [%s]", new_project.name)
            self.model_service.project = new_project
            self.project_created = new_project

    def display_default_context_menu(self, parent, event):
        """
        Show the plugin's default context menu at the event position.

        This is the menu used when neither a project nor the project's
        contents are right-clicked.  Nothing is shown if the menu has no
        items.

        """

        # The menu is only built after a click on a control that lives in a
        # window, so asking the workbench for its active window should be
        # safe here.
        application = self.model_service.application
        workbench = application.get_service(
            'envisage.ui.workbench.workbench.Workbench')
        active_window = workbench.active_window

        # Build our menu
        from envisage.workbench.action.action_controller import \
            ActionController
        controller = ActionController(window=active_window)
        menu = self.default_context_menu_manager.create_menu(
            parent, controller=controller)

        # Pop up the menu (a selected action is performed before the popup
        # call returns).
        if menu.GetMenuItemCount() > 0:
            menu.show(event.x, event.y)

    def delete_selection(self):
        """
        Delete the current selection within the current project.

        Does nothing without a current project or with an empty selection.
        Selected items are skipped when no resource type can be found for
        them or when their node type reports them as not deletable.  The
        user is asked to confirm before anything is unbound.

        Raises an Exception if the current project cannot be treated as a
        context.

        """

        # Only do something if we have a current project and a non-empty
        # selection
        current = self.model_service.project
        selection = self.model_service.selection[:]
        if current is None or not selection:
            return

        logger.debug('Deleting selection from Project [%s]', current)

        # Determine the context for the current project.  Raise an error
        # if we can't treat it as a context as then we don't know how
        # to delete anything.
        context = self._get_context_for_object(current)
        if context is None:
            raise Exception('Could not treat Project ' + \
                '[%s] as a context' % current)

        # Filter out any objects in the selection that can NOT be deleted.
        deletables = []
        for item in selection:
            rt = self._get_resource_type_for_object(item.obj)
            if rt is None:
                # No resource type means no node type to ask, so the item
                # is treated as not deletable instead of failing on None.
                logger.debug(
                    'No resource type found for selection item [%s]; '
                    'skipping it.', item)
                continue
            nt = rt.node_type
            if nt.can_delete(item):
                deletables.append(item)
            else:
                logger.debug(
                    'Node type reports selection item [%s] is '
                    'not deletable.', item)

        if not deletables:
            return

        # Confirm the delete operation with the user
        names = '\n\t'.join([b.name for b in deletables])
        message = ('You are about to delete the following selected '
                   'items:\n\t%s\n\n'
                   'Are you sure?') % names
        title = 'Delete Selected Items?'
        action = confirm(None, message, title)
        if action == YES:
            # Unbind all the deletable nodes
            self._unbind_nodes(context, deletables)

        return

    def is_current_project_saved(self, parent_window):
        """
        Offer to save modifications to the current project before it is
        closed or replaced.

        Returns False if the user cancelled, or chose to save and the save
        failed.  Returns True otherwise, including when there was nothing
        to ask about.

        """

        # A clean project (or no project) needs no confirmation.
        current = self.model_service.project
        if self._get_project_state(current):
            return True

        # Ask the user how to handle the unsaved changes.
        dialog = ConfirmationDialog(
            parent  = parent_window,
            cancel  = True,
            title   = 'Unsaved Changes',
            message = 'Do you want to save the changes to project "%s"?' \
                % (current.name),
            )
        answer = dialog.open()

        if answer == CANCEL:
            return False
        if answer == YES:
            return self._save(current, parent_window)
        if answer == NO:
            # The user does not want the unsaved changes, so the autosaved
            # copy is removed as well.
            self._clean_autosave_location(current.location.strip())
        return True

    def listen_for_application_exit(self):
        """
        Register for the workbench's 'exiting' notification so that this
        service can veto the closing of the application.

        FIXME: Normally this should be called during startup of this
        plugin, however, Envisage won't let us find the workbench service
        then because we've made a contribution to its extension points
        and it insists on starting us first.

        """

        application = self.model_service.application
        workbench = application.get_service(
            'envisage.ui.workbench.workbench.Workbench')
        workbench.on_trait_change(self._workbench_exiting, 'exiting')

    def open(self, event):
        """
        Ask the user for a project location and make the project opened
        from it the current project.  Shows an error dialog if the location
        cannot be opened as a project.

        """
        # Let the user deal with unsaved changes first; bail out on cancel.
        if not self.is_current_project_saved(event.window.control):
            return

        # Query the user for the location of the project to be opened.
        path = self._show_open_dialog(event.window.control)
        if path is None:
            return
        logger.debug("Opening project from location [%s]", path)

        opened = self.model_service.factory.open(path)
        if opened is None:
            msg = 'Unable to open %s as a project.' % path
            error(event.window.control,
                  msg,
                  title='Project Open Error')
            return

        logger.debug("Opened Project [%s]", opened.name)
        self.model_service.project = opened

    def save(self, event):
        """
        Save the current project, if there is one.

        """

        project = self.model_service.project
        if project is None:
            return
        self._save(project, event.window.control)

    def save_as(self, event):
        """
        Save the current project, if there is one, to a location the user
        is prompted for.

        """

        project = self.model_service.project
        if project is None:
            return
        self._save(project, event.window.control, prompt_for_location=True)

    #### protected interface #################################################

    def _auto_save(self, project):
        """

        Called periodically by the timer's Notify function to automatically
        save the current project.
        The auto-saved project has the extension '.autosave'.

        Nothing is saved unless the project is dirty and allows save-as.
        Failures are logged and never raised.

        """
        # Save the project only if it has been modified.
        if not (project.dirty and project.is_save_as_allowed):
            return

        location = project.location.strip()
        if not location:
            # No exception is active here, so logger.error is used:
            # logger.exception would append an empty traceback.
            logger.error('Error auto-saving project [%s] in '
                         'location %s', project, location)
            return

        autosave_loc = self._get_autosave_location(location)
        try:
            # We do not want the project's location and name to be
            # updated.
            project.save(autosave_loc, overwrite=True, autosave=True)
        except Exception:
            # Best effort: a failed autosave must not break the timer
            # callback, but KeyboardInterrupt/SystemExit still propagate.
            logger.exception('Error auto-saving project [%s]', project)
        else:
            logger.debug('[%s] auto-saved to [%s]', project, autosave_loc)
        return

    def _clean_autosave_location(self, location):
        """
        Remove the autosaved file or directory belonging to the project at
        *location*, if one exists.

        """
        autosave_path = self._get_autosave_location(location)
        if not os.path.exists(autosave_path):
            return
        self.model_service.clean_location(autosave_path)

    def _get_autosave_location(self, location):
        """
        Returns the path for auto-saving the project in location.

        """
        return os.path.join(os.path.dirname(location),
                            os.path.basename(location) + '.autosave')

    def _get_context_for_object(self, obj):
        """
        Return a context for *obj*: the object itself if it already is a
        Context, otherwise the result of adapting it through its resource
        type's context adapter factory.  Returns None when there is no
        resource type or no such factory.

        """

        if isinstance(obj, Context):
            return obj

        resource_type = self._get_resource_type_for_object(obj)
        if resource_type is None:
            return None

        adapter_factory = resource_type.context_adapter_factory
        if adapter_factory is None:
            return None

        # FIXME: We probably should use a real environment and
        # context (parent context?)
        return adapter_factory.adapt(obj, Context, {}, None)

    def _get_resource_type_for_object(self, obj):
        """
        Look up the resource type of *obj* through the model service's
        resource manager.

        If no type could be found, returns None.

        """

        return self.model_service.resource_manager.get_type_of(obj)

    def _get_project_state(self, project):
        """ Returns True if the project is clean: i.e., the dirty flag is
        False and all autosaved versions have been deleted from the filesystem.

        """

        result = True
        if project is not None:
            autosave_loc = self._get_autosave_location(
                project.location.strip())
            if project.dirty or os.path.exists(autosave_loc):
                result = False
        return result

    def _get_user_location(self, project, parent_window):
        """
        Ask the user where the given project should be saved.

        The dialog is shown repeatedly until the user cancels or picks a
        usable location.  A location is usable if nothing exists there yet,
        or if the user agrees to overwrite it and the existing content is
        removed successfully.

        Returns the chosen location or, if the user cancelled, an empty
        string.

        """

        # Projects are stored either as files or as directories; pick the
        # matching dialog.
        if self.model_service.are_projects_files():
            title_type = 'File'
            dialog = FileDialog(
                parent=parent_window,
                title='Save Project As',
                default_path=project.location,
                action='save as',
            )
        else:
            title_type = 'Directory'
            dialog = DirectoryDialog(
                parent=parent_window,
                message='Choose a Directory for the Project',
                default_path=project.location,
                action='open')

        chosen = ""
        while (dialog.open() == OK):
            candidate = dialog.path.strip()

            # A location that doesn't exist yet can be used right away.
            if not os.path.exists(candidate):
                logger.debug('Location [%s] does not exist yet.', candidate)
                chosen = candidate
                break

            # Something is already there: ask before overwriting.  A
            # refusal loops back to the dialog.
            logger.debug(
                'Location [%s] exists.  Prompting for overwrite '
                'permission.', candidate)
            answer = confirm(parent_window,
                             'Overwrite %s?' % candidate,
                             'Project %s Exists' % title_type)
            if answer == YES:
                try:
                    self.model_service.clean_location(candidate)
                except Exception as e:
                    # Removal failed: tell the user and let them pick
                    # another location.
                    information(parent_window, str(e),
                                'Unable To Overwrite %s' % candidate)
                else:
                    # Only use the location once the old content is gone.
                    chosen = candidate
                    break

        logger.debug('Returning user location [%s]', chosen)
        return chosen

    def _restore_from_autosave(self, project, autosave_loc):
        """ Offers to restore the project from the version saved in
        autosave_loc.

        If the user agrees and the autosaved version can be opened, it is
        saved over the project's location, the autosaved version is removed,
        and the restored project becomes the model service's project.
        Failures during the restore are logged and not raised.  The autosave
        timer is started afterwards in every case.

        """

        workbench = self.model_service.application.get_service(
            'envisage.ui.workbench.workbench.Workbench')
        window = workbench.active_window
        app_name = workbench.branding.application_name
        message = ('The app quit unexpectedly when [%s] was being modified.\n'
                   'An autosaved version of this project exists.\n'
                   'Do you want to restore the project from the '
                   'autosaved version ?' % project.name)
        title = '%s-%s' % (app_name, project.name)
        action = confirm(window.control,
                         message,
                         title,
                         cancel=True,
                         default=YES)
        if action == YES:
            try:
                saved_project = self.model_service.factory.open(autosave_loc)
                if saved_project is not None:
                    # Copy over the autosaved version to the current project's
                    # location, switch the model service's project, and delete
                    # the autosaved version.
                    loc = project.location.strip()
                    saved_project.save(loc, overwrite=True)
                    self.model_service.clean_location(autosave_loc)
                    self.model_service.project = saved_project
                else:
                    logger.debug('No usable project found in [%s].',
                                 autosave_loc)
            except Exception:
                # Best effort: a failed restore is logged, but
                # KeyboardInterrupt and SystemExit still propagate.
                logger.exception('Unable to restore project from [%s]',
                                 autosave_loc)

        # Started for whichever project is current at this point.
        self._start_timer(self.model_service.project)

        return

    def _save(self, project, parent_window, prompt_for_location=False):
        """
        Save the given project, asking the user for a location when
        *prompt_for_location* is True or the project has no location yet.

        An autosaved version at the old location is moved to the new one
        when the location changes, and removed once the save has succeeded.

        Returns True if the project was saved successfully, False if not
        (including when the user cancelled the location prompt).

        """

        location = project.location.strip()

        # Remember where an autosaved version of the existing location
        # would live; stays empty if that location does not exist.
        if location is not None and os.path.exists(location):
            old_autosave = self._get_autosave_location(location)
        else:
            old_autosave = ''

        # Ask for a location if told to, or if the project has none.
        if prompt_for_location or not location:
            location = self._get_user_location(project, parent_window)
            if location:
                # Clear any stale autosave at the new location, then carry
                # an existing autosave over from the old location.
                self._clean_autosave_location(location)
                new_autosave = self._get_autosave_location(location)
                if os.path.exists(old_autosave):
                    shutil.move(old_autosave, new_autosave)

        # Without a location there is nothing to save to.
        if not location:
            return False

        try:
            project.save(location)
            msg = '"%s" saved to %s' % (project.name, project.location)
            information(parent_window, msg, 'Project Saved')
            logger.debug(msg)
        except Exception as e:
            logger.exception('Error saving project [%s]', project)
            error(parent_window, str(e), title='Save Error')
            return False

        # The save succeeded, so the autosaved files are no longer needed.
        self._clean_autosave_location(location)
        return True

    def _show_open_dialog(self, parent):
        """
        Show the dialog to open a project and return the chosen path.

        Returns None if the user cancelled, or if a chosen directory does
        not contain the project's pickle file (an error dialog is shown in
        that case).

        """

        # Browsing starts at the default path used when creating new
        # projects, as most projects are likely stored there.
        default_path = self.model_service.get_default_path()
        project_class = self.model_service.factory.PROJECT_CLASS

        # File-based projects: the selected file is the project path.
        if self.model_service.are_projects_files():
            dialog = FileDialog(parent=parent,
                                default_directory=default_path,
                                title='Open Project')
            if dialog.open() == OK:
                return dialog.path
            return None

        # Directory-based projects: the directory must hold the pickle
        # file the project class expects.
        dialog = DirectoryDialog(parent=parent,
                                 default_path=default_path,
                                 message='Open Project')
        if not (dialog.open() == OK):
            return None

        pickle_path = project_class.get_pickle_filename(dialog.path)
        if File(pickle_path).exists:
            return dialog.path

        error(parent, 'Directory does not contain a recognized '
              'project')
        return None

    def _start_timer(self, project):
        """
        Create the autosave timer for *project*, unless a timer already
        exists or the autosave interval is not positive.

        """

        if self.timer is not None:
            return
        if not self.autosave_interval > 0:
            return

        # The interval is kept in minutes; Timer needs milliseconds.
        interval_ms = self.autosave_interval * 60000
        self.timer = Timer(interval_ms, self._auto_save, project)

    def _unbind_nodes(self, context, nodes):
        """
        Unbinds all of the specified nodes that can be found within this
        context or any of its sub-contexts.

        This uses a breadth first algorithm on the assumption that the
        user will have likely selected peer nodes within a sub-context
        that isn't the deepest context.

        Note that *nodes* is modified in place: every node that gets
        unbound is removed from the list, so whatever remains afterwards
        was not found.

        """

        logger.debug(
            'Unbinding nodes [%s] from context [%s] within '
            'UiService [%s]', nodes, context, self)

        # Iterate through all of the selected nodes looking for ones whose
        # name is within our context.  The loop runs over a copy because
        # the list itself is modified inside the loop.
        context_names = context.list_names()
        for node in nodes[:]:
            if node.name in context_names:

                # Ensure we've found a matching node by matching the objects
                # as well.
                binding = context.lookup_binding(node.name)
                if id(node.obj) == id(binding.obj):

                    # Remove the node from the context -AND- from the list of
                    # nodes that are still being searched for.
                    context.unbind(node.name)
                    nodes.remove(node)

                    # Stop if we've unbound the last node
                    if len(nodes) < 1:
                        break

        # This 'else' belongs to the 'for' loop: it runs only when the loop
        # finished without hitting the 'break' above.  In that case, search
        # any sub-contexts for more nodes to unbind.
        else:

            # Build a list of all current sub-contexts of this context.
            subs = []
            for name in context.list_names():
                if context.is_context(name):
                    obj = context.lookup_binding(name).obj
                    sub_context = self._get_context_for_object(obj)
                    if sub_context is not None:
                        subs.append(sub_context)

            # Iterate through each sub-context, stopping as soon as possible
            # if we've run out of nodes.
            for sub in subs:
                self._unbind_nodes(sub, nodes)
                if len(nodes) < 1:
                    break

    def _workbench_exiting(self, event):
        """
        Handle the workbench asking whether it may exit and shut down the
        application.

        If the current project is dirty or has an autosaved version, the
        user chooses between saving, discarding and cancelling.  The exit is
        vetoed when the user cancels or when the save fails.

        """

        logger.debug('Detected workbench closing event in [%s]', self)

        # A clean project (no unsaved changes and no leftover autosave)
        # never blocks the exit.
        current = self.model_service.project
        if self._get_project_state(current):
            return

        # The active workbench window is the dialog parent; the application
        # name goes into the dialog title.
        workbench = self.model_service.application.get_service(
            'envisage.ui.workbench.workbench.Workbench')
        window = workbench.active_window
        app_name = workbench.branding.application_name

        # Show a confirmation dialog to the user.
        answer = confirm(window.control,
                         'Do you want to save changes before exiting?',
                         '%s - %s' % (current.name, app_name),
                         cancel=True,
                         default=YES)

        if answer == YES:
            # A successful save also deletes the autosaved file.
            saved = self._save(current, window.control)
            if not saved:
                event.veto = True
        elif answer == NO:
            # The user does not want the unsaved changes, so the autosaved
            # file is removed as well.
            self._clean_autosave_location(current.location.strip())
        elif answer == CANCEL:
            event.veto = True

    #### Trait change handlers ###############################################

    def _autosave_interval_changed(self, old, new):
        """
        Drop the current timer and, if the new interval is positive and a
        project is open, start a new one.

        """

        self.timer = None
        if not new > 0:
            return
        if self.model_service.project is None:
            return
        self._start_timer(self.model_service.project)

    def _project_changed_for_model_service(self, object, name, old, new):
        """
        React to the model service's project changing.

        The timer of the previous project is dropped.  If an autosaved
        version of the new project exists, the user is asked (later)
        whether to restore from it; otherwise the autosave timer is started
        for the new project.

        """

        if old is not None:
            self.timer = None

        if new is None:
            return

        # Note: An autosaved version should exist only if the app crashed
        # unexpectedly. Regular exiting of the workbench should cause the
        # autosaved version to be deleted.
        autosave_loc = self._get_autosave_location(new.location.strip())
        if not os.path.exists(autosave_loc):
            self._start_timer(new)
            return

        # Deferred with do_later so the project view can first be updated
        # to reflect the current project's state.
        do_later(self._restore_from_autosave, new, autosave_loc)
Exemple #13
0
 class IntEnumModel(HasTraits):
     """Model with an integer ``value`` trait and a ``possible_values``
     list trait that defaults to ``[0, 1]``.
     """
     # Integer-valued trait.
     value = Int()
     # List trait; presumably the values ``value`` may take -- nothing here
     # enforces that.
     possible_values = List([0, 1])
Exemple #14
0
 class IntEnumModel(HasTraits):
     """Model with a single integer ``value`` trait."""
     # Integer-valued trait.
     value = Int()
Exemple #15
0
class ToolkitEditorFactory(EditorFactory):
    """ Editor factory for file editors.

    Declares the configuration traits of a file editor (wildcard filters,
    directory selection, history size, tree root, dialog style) and the
    Traits view used to edit the factory itself.
    """

    #-------------------------------------------------------------------------
    #  Trait definitions:
    #-------------------------------------------------------------------------

    # Wildcard filter to apply to the file dialog:
    filter = filter_trait

    # Optional extended trait name of the trait containing the list of filters:
    filter_name = Str

    # Should the file extension be truncated?
    truncate_ext = Bool(False)

    # Can the user select directories as well as files?
    allow_dir = Bool(False)

    # Is user input set on every keystroke? (Overrides the default) ('simple'
    # style only):
    auto_set = False

    # Is user input set when the Enter key is pressed? (Overrides the default)
    # ('simple' style only):
    enter_set = True

    # The number of history entries to maintain:
    # FIXME: add support
    entries = Int(10)

    # The root path of the file tree view ('custom' style only, not supported
    # under wx). If not specified, the filesystem root is used.
    root_path = File

    # Optional extended trait name of the trait containing the root path.
    root_path_name = Str

    # Optional extended trait name used to notify the editor when the file
    # system view should be reloaded ('custom' style only):
    reload_name = Str

    # Optional extended trait name used to notify when the user double-clicks
    # an entry in the file tree view. The associated path is assigned to it:
    dclick_name = Str

    # The style of file dialog to use when the 'Browse...' button is clicked.
    # Should be one of 'open' or 'save'.
    dialog_style = Str('open')

    #-------------------------------------------------------------------------
    #  Traits view definition:
    #-------------------------------------------------------------------------

    traits_view = View([[
        '<options>', 'truncate_ext{Automatically truncate file extension?}',
        '|options:[Options]>'
    ], ['filter', '|[Wildcard filters]<>']])

    # Empty group of extra items.
    extras = Group()
Exemple #16
0
class VideoServer(Loggable):
    """Serves JPEG-encoded frames from a ``Video`` source over ZeroMQ.

    ``start`` opens the video source and runs ``_broadcast`` on a background
    thread; ``stop`` sets the stop signal so that thread leaves its loop and
    releases its socket.
    """
    video = Instance(Video)
    port = Int(1084)
    quality = Int(75)
    use_color = True
    start_button = Button
    start_label = Property(depends_on='_started')
    # True between a successful start() and the next stop().
    _started = Bool(False)

    def _get_start_label(self):
        """Return the button caption: 'Start' when idle, 'Stop' when running."""
        return 'Start' if not self._started else 'Stop'

    def _start_button_fired(self):
        """Toggle the server when the start/stop button is pressed."""
        if self._started:
            self.stop()
        else:
            self.start()

    def traits_view(self):
        """Return a view holding only the start/stop button."""
        v = View(Item('start_button', editor=ButtonEditor(label_value='start_label')))
        return v

    def _video_default(self):
        """Default video source, created with ``swap_rb=True``."""
        return Video(swap_rb=True)

    def stop(self):
        """Signal the broadcast thread to exit and mark the server stopped."""
        self.info('stopping video server')
        # The stop signal only exists once start() has run; calling stop()
        # on a server that was never started must not raise.
        stop_signal = getattr(self, '_stop_signal', None)
        if stop_signal is not None:
            stop_signal.set()
        self._started = False

    def start(self):
        """Open the video source and launch the broadcast thread."""
        self.info('starting video server')
        self._new_frame_ready = Event()
        self._stop_signal = Event()

        self.video.open(user='******')
        bt = Thread(name='broadcast', target=self._broadcast)
        bt.start()

        self.info('video server started')
        self._started = True

    def _broadcast(self):
        """Thread body: bind a REP socket on ``port`` and serve requests.

        The socket and context are always released when the request loop
        ends, so the port can be bound again by a later ``start``.
        """
        self.info('video broadcast thread started')

        context = zmq.Context()
        sock = context.socket(zmq.REP)
        try:
            sock.bind('tcp://*:{}'.format(self.port))

            poll = zmq.Poller()
            poll.register(sock, zmq.POLLIN)

            self.request_reply(sock, poll)
        finally:
            # linger=0 drops any unsent reply so shutdown cannot block on a
            # client that has gone away.
            sock.close(linger=0)
            context.term()

    def request_reply(self, sock, poll):
        """Answer requests on ``sock`` until the stop signal is set.

        ``'FPS'`` is answered with the frame rate, a request starting with
        ``'QUALITY'`` updates the JPEG quality used for later frames and is
        answered with an empty string, and any other request is answered with
        the current frame encoded as JPEG.
        """
        stop = self._stop_signal
        video = self.video
        fps = 10
        from . import Image
        from cStringIO import StringIO
        quality = self.quality
        while not stop.isSet():

            # Poll with a 100 ms timeout so the stop signal is re-checked
            # regularly even when no client is talking to us.
            socks = dict(poll.poll(100))
            if socks.get(sock) == zmq.POLLIN:
                resp = sock.recv()
                if resp == 'FPS':
                    buf = str(fps)
                elif resp.startswith('QUALITY'):
                    quality = int(resp[7:])
                    buf = ''
                else:
                    f = video.get_frame()

                    im = Image.fromarray(array(f))
                    s = StringIO()
                    im.save(s, 'JPEG', quality=quality)
                    buf = s.getvalue()

                sock.send(buf)

    def publisher(self, sock):
        """Push the frame rate and a JPEG frame on ``sock`` until stopped.

        Each iteration sends ``str(fps)`` followed by the encoded frame and
        then sleeps for ``1.0 / fps`` seconds.
        """
        stop = self._stop_signal
        video = self.video
        fps = 10
        from . import Image
        from cStringIO import StringIO
        while not stop.isSet():

            f = video.get_frame(gray=False)
            im = Image.fromarray(array(f))
            s = StringIO()
            im.save(s, 'JPEG')

            sock.send(str(fps))
            sock.send(s.getvalue())

            time.sleep(1.0 / fps)
Exemple #17
0
class UI(HasPrivateTraits):
    """ Information about the user interface for a View.
    """

    # -------------------------------------------------------------------------
    #  Trait definitions:
    # -------------------------------------------------------------------------

    #: The ViewElements object from which this UI resolves Include items
    view_elements = Instance(ViewElements)

    #: Context objects that the UI is editing
    context = Dict(Str, Any)

    #: Handler object used for event handling
    handler = Instance(Handler)

    #: View template used to construct the user interface
    view = Instance("traitsui.view.View")

    #: Panel or dialog associated with the user interface
    control = Any()

    #: The parent UI (if any) of this UI
    parent = Instance("UI")

    #: Toolkit-specific object that "owns" **control**
    owner = Any()

    #: UIInfo object containing context or editor objects
    info = Instance(UIInfo)

    #: Result from a modal or wizard dialog:
    result = Bool(False)

    #: Undo and Redo history
    history = Any()

    #: The KeyBindings object (if any) for this UI:
    key_bindings = Property(depends_on=["view._key_bindings", "context"])

    #: The unique ID for this UI for persistence
    id = Str()

    #: Have any modifications been made to UI contents?
    modified = Bool(False)

    #: Event when the user interface has changed
    updated = Event(Bool)

    #: Title of the dialog, if any
    title = Str()

    #: The ImageResource of the icon, if any
    icon = Image

    #: Should the created UI have scroll bars?
    scrollable = Bool(False)

    #: The number of currently pending editor error conditions
    errors = Int()

    #: The code used to rebuild an updated user interface
    rebuild = Callable()

    #: Set to True when the UI has finished being destroyed.
    destroyed = Bool(False)

    # -- Private Traits -------------------------------------------------------

    #: Original context when used with a modal dialog
    _context = Dict(Str, Any)

    #: Copy of original context used for reverting changes
    _revert = Dict(Str, Any)

    #: List of methods to call once the user interface is created
    _defined = List()

    #: List of (visible_when,Editor) pairs
    _visible = List()

    #: List of (enabled_when,Editor) pairs
    _enabled = List()

    #: List of (checked_when,Editor) pairs
    _checked = List()

    #: Search stack used while building a user interface
    _search = List()

    #: List of dispatchable Handler methods
    _dispatchers = List()

    #: List of editors used to build the user interface
    _editors = List()

    #: List of names bound to the **info** object
    _names = List()

    #: Index of currently the active group in the user interface
    _active_group = Int()

    #: List of top-level groups used to build the user interface
    _groups = Property()
    _groups_cache = Any()

    #: Count of levels of nesting for undoable actions
    _undoable = Int(-1)

    #: Code used to rebuild an updated user interface
    _rebuild = Callable()

    #: The statusbar listeners that have been set up:
    _statusbar = List()

    #: Control which gets focus after UI is created
    #: Note: this does not track focus after UI creation
    #: only used by Qt backend.
    _focus_control = Any()

    #: Does the UI contain any scrollable widgets?
    #:
    #: The _scrollable trait is set correctly, but not used currently because
    #: its value is arrived at too late to be of use in building the UI.
    _scrollable = Bool(False)

    #: Cache for key bindings.
    _key_bindings = Instance("traitsui.key_bindings.KeyBindings")

    #: List of traits that are reset when a user interface is recycled
    #: (i.e. rebuilt).
    recyclable_traits = [
        "_context",
        "_revert",
        "_defined",
        "_visible",
        "_enabled",
        "_checked",
        "_search",
        "_dispatchers",
        "_editors",
        "_names",
        "_active_group",
        "_undoable",
        "_rebuild",
        "_groups_cache",
        "_key_bindings",
        "_focus_control",
    ]

    #: List of additional traits that are discarded when a user interface is
    #: disposed.
    disposable_traits = [
        "view_elements",
        "info",
        "handler",
        "context",
        "view",
        "history",
        "key_bindings",
        "icon",
        "rebuild",
    ]

    def traits_init(self):
        """ Creates the UIInfo object for this UI and passes it to the
            handler's 'init_info' method.
        """
        ui_info = UIInfo(ui=self)
        self.info = ui_info
        self.handler.init_info(self.info)

    def ui(self, parent, kind):
        """ Builds the user interface for the associated View by calling the
            toolkit's ``ui_<kind>`` function.
        """
        # A kind that must have a parent is replaced by "live" when no parent
        # was supplied:
        lacks_required_parent = (parent is None) and (
            kind in kind_must_have_parent)
        if lacks_required_parent:
            kind = "live"

        # Listen (on the UI thread) for the view's 'updated' event:
        self.view.on_trait_change(
            self._updated_changed, "updated", dispatch="ui")

        # Remember the builder so that it can be used again later:
        builder_name = "ui_" + kind
        self.rebuild = getattr(toolkit(), builder_name)
        self.rebuild(self, parent)

    def dispose(self, result=None, abort=False):
        """ Disposes of the contents of a user interface, optionally recording
            a result first and skipping the preference save on abort.
        """
        if result is not None:
            self.result = result

        # Nothing more to do if the view has already been disposed of:
        if self.control is None:
            return

        # Save the user preference information unless we are aborting:
        if not abort:
            self.save_prefs()

        # Finish disposing of the user interface:
        self.finish()

    def recycle(self):
        """ Prepares the user interface for being rebuilt: editors are reset
            without destroying child controls, the control's '_object' link is
            cleared and all recyclable traits are reset.
        """
        self.reset(destroy=False)

        # Drop the context object attached to the view control:
        control = self.control
        control._object = None

        self.reset_traits(self.recyclable_traits)

    def finish(self):
        """ Finishes disposing of a user interface.

        Tears the UI down in a fixed order: editors are reset, the
        '_evaluate_when' listeners are removed from the context objects, the
        handler is told the view closed, the control is destroyed through the
        toolkit, and finally the recyclable and disposable traits are reset
        and 'destroyed' is set to True.
        """

        # Reset the contents of the user interface (child controls are not
        # destroyed here; the whole control is destroyed further down):
        self.reset(destroy=False)

        # Make sure that 'visible', 'enabled', and 'checked' handlers are not
        # called after the editor has been disposed:
        for object in self.context.values():
            object.on_trait_change(self._evaluate_when, remove=True)

        # Notify the handler that the view has been closed, passing it the
        # current result:
        self.handler.closed(self.info, self.result)

        # Clear the back-link from the UIInfo object to us:
        self.info.ui = None

        # Destroy the view control (its '_object' link is cleared first):
        self.control._object = None
        toolkit().destroy_control(self.control)
        self.control = None

        # Dispose of any KeyBindings object we reference:
        if self._key_bindings is not None:
            self._key_bindings.dispose()

        # Break the linkage to any objects in the context dictionary:
        self.context.clear()

        # Remove specified symbols from our dictionary to aid in clean-up:
        self.reset_traits(self.recyclable_traits)
        self.reset_traits(self.disposable_traits)

        # Flag that tear-down has completed:
        self.destroyed = True

    def reset(self, destroy=True):
        """ Resets the contents of a user interface: disposes every editor,
            removes statusbar listeners and dispatchers, and optionally asks
            the toolkit to destroy the control's children.
        """
        for ed in self._editors:
            nested_ui = ed._ui
            if nested_ui is not None:
                # Propagate result to enclosed ui objects:
                nested_ui.result = self.result
            ed.dispose()

            # With the control set to None, a pending '_update_editor' call
            # for this editor sees that and discards the update request:
            ed.control = None

        # Remove any statusbar listeners that have been set up:
        listeners = self._statusbar
        for target, callback, trait_name in listeners:
            target.on_trait_change(callback, trait_name, remove=True)
        del listeners[:]

        if destroy:
            toolkit().destroy_children(self.control)

        for entry in self._dispatchers:
            entry.remove()

    def find(self, include):
        """ Looks up the definition of the given Include object, first via a
            ViewElements object and then via a same-named callable on the
            context's 'handler' or context object.
        """
        context = self.context

        # The context 'object' is the sole entry, or else the "object" key:
        if len(context) == 1:
            target = next(iter(context.values()))
        else:
            target = context.get("object")

        # Prefer our own ViewElements, falling back to the context object's:
        elements = self.view_elements
        if (elements is None) and (target is not None):
            elements = target.trait_view_elements()

        found = None
        if elements is not None:
            found = elements.find(include.id, self._search)
        if found is not None:
            return found

        # Otherwise try a callable named after the include id, on the handler
        # first and then on the context object:
        for provider in (context.get("handler"), target):
            if provider is None:
                continue
            factory = getattr(provider, include.id, None)
            if callable(factory):
                found = factory()
                if found is not None:
                    break

        return found

    def push_level(self):
        """ Returns the current depth of the search stack, for later use with
            'pop_level'.
        """
        depth = len(self._search)
        return depth

    def pop_level(self, level):
        """ Restores a previously pushed search stack level by deleting the
            leading entries beyond that level.
        """
        stack = self._search
        surplus = len(stack) - level
        del stack[:surplus]

    def prepare_ui(self):
        """ Performs all processing that occurs after the user interface is
            created.

        Runs the accumulated 'name_defined' callbacks, synchronizes context
        traits with editor traits, hooks toolkit events, calls the handler's
        'init' method, wires up the handler's '<object>_<name>_changed'
        methods and sets up the 'visible_when' / 'enabled_when' /
        'checked_when' evaluation.

        Raises
        ------
        TraitError
            If the handler's 'init' method returns a value equal to False.
        """
        # Invoke all of the editor 'name_defined' methods we've accumulated
        # (the info object is marked uninitialized while this runs):
        info = self.info.trait_set(initialized=False)
        for method in self._defined:
            method(info)

        # Then reset the list, since we don't need it anymore:
        del self._defined[:]

        # Synchronize all context traits with associated editor traits:
        self.sync_view()

        # Hook all keyboard events:
        toolkit().hook_events(self, self.control, "keys", self.key_handler)

        # Hook all events if the handler is an extended 'ViewHandler':
        handler = self.handler
        if isinstance(handler, ViewHandler):
            toolkit().hook_events(self, self.control)

        # Invoke the handler's 'init' method, and abort if it indicates
        # failure. The comparison is deliberately '== False': a handler whose
        # 'init' returns None does not abort.
        if handler.init(info) == False:
            raise TraitError("User interface creation aborted")

        # For each Handler method whose name is of the form
        # 'object_name_changed', where 'object' is the name of an object in the
        # UI's 'context', create a trait notification handler that will call
        # the method whenever 'object's 'name' trait changes. Also invoke the
        # method immediately so initial user interface state can be correctly
        # set:
        context = self.context
        for name in self._each_trait_method(handler):
            if name[-8:] == "_changed":
                prefix = name[:-8]
                # Split at the first underscore after position 0, so a
                # leading underscore stays part of the object name:
                col = prefix.find("_", 1)
                if col >= 0:
                    object = context.get(prefix[:col])
                    if object is not None:
                        method = getattr(handler, name)
                        trait_name = prefix[col + 1:]
                        self._dispatchers.append(
                            Dispatcher(method, info, object, trait_name))
                        # Event traits are not fired at set-up time:
                        if object.base_trait(trait_name).type != "event":
                            method(info)

        # If there are any Editor object's whose 'visible', 'enabled' or
        # 'checked' state is controlled by a 'visible_when', 'enabled_when' or
        # 'checked_when' expression, set up an 'anytrait' changed notification
        # handler on each object in the 'context' that will cause the
        # 'visible', 'enabled' or 'checked' state of each affected Editor to be
        #  set. Also trigger the evaluation immediately, so the visible,
        # enabled or checked state of each Editor can be correctly initialized:
        if (len(self._visible) + len(self._enabled) + len(self._checked)) > 0:
            for object in context.values():
                object.on_trait_change(self._evaluate_when, dispatch="ui")
            self._do_evaluate_when(at_init=True)

        # Indicate that the user interface has been initialized:
        info.initialized = True

    def sync_view(self):
        """ Synchronizes the traits of every context object with the view's
            editor traits, once per kind of sync metadata.
        """
        # (metadata name, sync direction) pairs, applied in this order:
        passes = (
            ("sync_to_view", "from"),
            ("sync_from_view", "to"),
            ("sync_with_view", "both"),
        )
        for name, object in self.context.items():
            for metadata, direction in passes:
                self._sync_view(name, object, metadata, direction)

    def _sync_view(self, name, object, metadata, direction):
        """ Applies one kind of sync metadata of a context object.

        For every trait of 'object' whose 'metadata' value passes the
        'is_str' filter, the value is split on commas into 'editor_id.trait'
        entries, and the matching editor (looked up by id on the UIInfo
        object) is asked to sync that trait with '<name>.<trait_name>'.

        Parameters
        ----------
        name : str
            The context name of 'object'.
        object
            The context object whose traits are inspected.
        metadata : str
            The metadata name to look for (e.g. "sync_to_view").
        direction : str
            The direction passed on to the editor's 'sync_value' method.

        Raises
        ------
        TraitError
            If an entry is not of the form 'id.trait', or no editor with the
            given id exists on the UIInfo object.
        """
        info = self.info
        for trait_name, trait in object.traits(**{metadata: is_str}).items():
            for sync in getattr(trait, metadata).split(","):
                try:
                    editor_id, editor_name = [
                        item.strip() for item in sync.split(".")
                    ]
                except ValueError:
                    # Unpacking fails when the entry does not consist of
                    # exactly two dot-separated parts:
                    raise TraitError(
                        "The '%s' metadata for the '%s' trait in "
                        "the '%s' context object should be of the form: "
                        "'id1.trait1[,...,idn.traitn]'." %
                        (metadata, trait_name, name))

                editor = getattr(info, editor_id, None)
                if editor is None:
                    raise TraitError(
                        "No editor with id = '%s' was found for "
                        "the '%s' metadata for the '%s' trait in the '%s' "
                        "context object." %
                        (editor_id, metadata, trait_name, name))

                editor.sync_value("%s.%s" % (name, trait_name),
                                  editor_name, direction)

    def get_extended_value(self, name):
        """ Gets the current value of a specified extended trait name.

        A dotted name starts at the context entry named by its first part; a
        name without dots is looked up on the context's "object" entry.
        """
        parts = name.split(".")
        if len(parts) > 1:
            root_key, attr_path = parts[0], parts[1:]
        else:
            root_key, attr_path = "object", parts

        value = self.context[root_key]
        for attr in attr_path:
            value = getattr(value, attr)

        return value

    def restore_prefs(self):
        """ Retrieves and restores any saved user preference information
        associated with the UI.

        Returns
        -------
        The value returned by 'set_prefs' for the stored preferences, or None
        when the UI has no id, the preference database cannot be opened, or
        reading/applying the preferences fails.
        """
        ui_id = self.id
        if ui_id == "":
            return None

        db = self.get_ui_db()
        if db is None:
            return None

        try:
            try:
                ui_prefs = db.get(ui_id)
            finally:
                # Always release the database, even when the read fails:
                db.close()
            return self.set_prefs(ui_prefs)
        except Exception:
            # Restoring preferences is best effort: a failure here must not
            # prevent the UI from being shown.
            return None

    def set_prefs(self, prefs):
        """ Applies saved user preferences to the UI's editors and key
            bindings, returning the entry stored under the "" key (or None if
            'prefs' is not a dict).
        """
        if not isinstance(prefs, dict):
            return None

        info = self.info
        for name in self._names:
            editor = getattr(info, name, None)
            if not (isinstance(editor, Editor) and (editor.ui is self)):
                continue
            saved = prefs.get(name)
            if saved is not None:
                editor.restore_prefs(saved)

        if self.key_bindings is not None:
            saved_bindings = prefs.get("$")
            if saved_bindings is not None:
                self.key_bindings.merge(saved_bindings)

        return prefs.get("")

    def save_prefs(self, prefs=None):
        """ Saves any user preference information associated with the UI.

        With no 'prefs', the toolkit's 'save_window' is called for this UI
        instead. Otherwise, if the UI has a non-empty id and the preference
        database can be opened, the result of 'get_prefs(prefs)' is stored
        under that id.
        """
        if prefs is None:
            toolkit().save_window(self)
            return

        ui_id = self.id
        if ui_id == "":
            return

        db = self.get_ui_db(mode="c")
        if db is None:
            return

        try:
            db[ui_id] = self.get_prefs(prefs)
        finally:
            # Release the database even if collecting or storing the
            # preferences raises:
            db.close()

    def get_prefs(self, prefs=None):
        """ Collects the preferences to be saved for the user interface into
            a dict: 'prefs' under "", the key bindings under "$", and each
            editor's own preferences under the editor's name.
        """
        ui_prefs = {}
        if prefs is not None:
            ui_prefs[""] = prefs

        bindings = self.key_bindings
        if bindings is not None:
            ui_prefs["$"] = bindings

        info = self.info
        for name in self._names:
            editor = getattr(info, name, None)
            if not (isinstance(editor, Editor) and (editor.ui is self)):
                continue
            editor_prefs = editor.save_prefs()
            if editor_prefs is not None:
                ui_prefs[name] = editor_prefs

        return ui_prefs

    def get_ui_db(self, mode="r"):
        """ Returns a reference to the Traits UI preference database.

        The database is a shelf named "traits_ui" inside the directory
        returned by 'traits_home()', opened with the given 'mode' flag.

        Returns
        -------
        The open shelf, or None if it could not be opened.
        """
        try:
            return shelve.open(
                os.path.join(traits_home(), "traits_ui"),
                flag=mode,
                protocol=-1,
            )
        except Exception:
            # Callers treat None as "no preference database available";
            # KeyboardInterrupt and SystemExit are no longer swallowed.
            return None

    def get_editors(self, name):
        """ Returns the editors whose 'name' equals the given trait name, in
            the order in which they appear in the UI's editor list.
        """
        matches = []
        for candidate in self._editors:
            if candidate.name == name:
                matches.append(candidate)
        return matches

    def get_error_controls(self):
        """ Returns a flat list of the error controls reported by the UI's
            editors. An editor may report a single control or a list.
        """
        collected = []
        for ed in self._editors:
            reported = ed.get_error_control()
            if not isinstance(reported, list):
                reported = [reported]
            collected.extend(reported)

        return collected

    def add_defined(self, method):
        """ Queues a Handler method to be called once the user interface has
            been constructed.
        """
        pending = self._defined
        pending.append(method)

    def add_visible(self, visible_when, editor):
        """ Adds a conditionally visible Editor object to the list of monitored
            'visible_when' objects.

        The expression is compiled in "eval" mode; an expression that cannot
        be compiled is silently ignored.
        """
        try:
            code = compile(visible_when, "<string>", "eval")
        except Exception:
            # fixme: Log an error here...
            return

        self._visible.append((code, editor))

    def add_enabled(self, enabled_when, editor):
        """ Adds a conditionally enabled Editor object to the list of monitored
            'enabled_when' objects.

        The expression is compiled in "eval" mode; an expression that cannot
        be compiled is silently ignored.
        """
        try:
            code = compile(enabled_when, "<string>", "eval")
        except Exception:
            # fixme: Log an error here...
            return

        self._enabled.append((code, editor))

    def add_checked(self, checked_when, editor):
        """ Adds a conditionally checked (menu) Editor object to the list of
            monitored 'checked_when' objects.

        The expression is compiled in "eval" mode; an expression that cannot
        be compiled is silently ignored.
        """
        try:
            code = compile(checked_when, "<string>", "eval")
        except Exception:
            # fixme: Log an error here...
            return

        self._checked.append((code, editor))

    def do_undoable(self, action, *args, **kw):
        """ Runs 'action' as an undoable operation.

        For an outermost call (nesting marker is -1) the marker is set to the
        history's 'now' value while the action runs, if there is a history,
        and set back to -1 afterwards, even when the action raises.
        """
        outermost = self._undoable == -1
        try:
            if outermost and (self.history is not None):
                self._undoable = self.history.now

            action(*args, **kw)
        finally:
            if outermost:
                self._undoable = -1

    def route_event(self, event):
        """ Passes a "hooked" event, together with this UI, to the toolkit's
            'route_event' for routing to the correct handler method.
        """
        backend = toolkit()
        backend.route_event(self, event)

    def key_handler(self, event, skip=True):
        """ Handles key events: tries this UI's key bindings first, then the
            parent UI, and finally lets the toolkit skip the event if it is
            still unhandled and 'skip' is true. Returns the handled state.
        """
        bindings = self.key_bindings
        if bindings is None:
            handled = False
        else:
            handled = bindings.do(
                event, [], self.info, recursive=(self.parent is None))

        # Give the parent UI a chance (it must not skip the event itself):
        if (not handled) and (self.parent is not None):
            handled = self.parent.key_handler(event, False)

        if skip and (not handled):
            toolkit().skip_event(event)

        return handled

    def evaluate(self, function, *args, **kw_args):
        """ Evaluates a specified function in the UI's **context**.

        'function' may be None (None is returned), a callable (called
        directly), or an expression string that is evaluated to obtain the
        callable.
        """
        if function is None:
            return None

        if not callable(function):
            # NOTE(review): the string is passed to eval(); it must come from
            # trusted view definitions only.
            namespace = self.context.copy()
            namespace["ui"] = self
            namespace["handler"] = self.handler
            function = eval(function, globals(), namespace)

        return function(*args, **kw_args)

    def eval_when(self, when, result=True):
        """ Evaluates an expression in the UI's **context** and returns the
            result.

        Parameters
        ----------
        when : str or code object
            The expression to evaluate.
        result : any
            The value returned if evaluating the expression raises and
            'raise_to_debug' does not re-raise.
        """
        context = self._get_context(self.context)
        try:
            result = eval(when, globals(), context)
        except Exception:
            # Same handling as in '_evaluate_condition'; KeyboardInterrupt
            # and SystemExit are deliberately not intercepted.
            from traitsui.api import raise_to_debug

            raise_to_debug()

        # Drop the reference to ourselves from the temporary namespace:
        del context["ui"]

        return result

    def _get_context(self, context):
        """ Gets the context to use for evaluating an expression.
        """
        name = "object"
        n = len(context)
        if (n == 2) and ("handler" in context):
            for name, value in context.items():
                if name != "handler":
                    break
        elif n == 1:
            name = list(context.keys())[0]

        value = context.get(name)
        if value is not None:
            context2 = value.trait_get()
            context2.update(context)
        else:
            context2 = context.copy()

        context2["ui"] = self

        return context2

    def _evaluate_when(self):
        """ Set the 'visible', 'enabled', and 'checked' states for all Editors
            controlled by a 'visible_when', 'enabled_when' or 'checked_when'
            expression.
        """
        self._do_evaluate_when(at_init=False)

    def _do_evaluate_when(self, at_init=False):
        """ Set the 'visible', 'enabled', and 'checked' states for all Editors.

        This function does the job of _evaluate_when. We define it here to
        work around the traits dispatching mechanism that automatically
        determines the number of arguments of a notification method.

        :attr:`at_init` is set to true when this function is called the first
        time at initialization. In that case, we want to force the state of
        the items to be set (normally it is set only if it changes).
        """
        self._evaluate_condition(self._visible, "visible", at_init)
        self._evaluate_condition(self._enabled, "enabled", at_init)
        self._evaluate_condition(self._checked, "checked", at_init)

    def _evaluate_condition(self, conditions, trait, at_init=False):
        """ Evaluates a list of (eval, editor) pairs and sets a specified trait
        on each editor to reflect the Boolean value of the expression.

        1) All conditions are evaluated
        2) The elements whose condition evaluates to False are updated
        3) The elements whose condition evaluates to True are updated

        E.g., we first make invisible all elements for which 'visible_when'
        evaluates to False, and then we make visible the ones
        for which 'visible_when' is True. This avoids mutually exclusive
        elements to be visible at the same time, and thus making a dialog
        unnecessarily large.

        The state of an editor is updated only when it changes, unless
        at_init is set to True.

        Parameters
        ----------
        conditions : list of (str, Editor) tuple
            A list of tuples, each formed by 1) a string that contains a
            condition that evaluates to either True or False, and
            2) the editor whose state depends on the condition

        trait : str
            The trait that is set by the condition.
            Either 'visible, 'enabled', or 'checked'.

        at_init : bool
            If False, the state of an editor is set only when it changes
            (e.g., a visible element would not be updated to visible=True
            again). If True, the state is always updated (used at
            initialization).
        """

        context = self._get_context(self.context)

        # list of elements that should be activated
        activate = []
        # list of elements that should be de-activated
        deactivate = []

        for when, editor in conditions:
            try:
                cond_value = eval(when, globals(), context)
                editor_state = getattr(editor, trait)

                # add to update lists only if at_init is True (called on
                # initialization), or if the editor state has to change

                if cond_value and (at_init or not editor_state):
                    activate.append(editor)

                if not cond_value and (at_init or editor_state):
                    deactivate.append(editor)

            except Exception:
                # catch errors in the validate_when expression
                from traitsui.api import raise_to_debug

                raise_to_debug()

        # update the state of the editors
        for editor in deactivate:
            setattr(editor, trait, False)
        for editor in activate:
            setattr(editor, trait, True)

    def _get__groups(self):
        """ Returns the top-level Groups for the view (after resolving
        Includes. (Implements the **_groups** property.)
        """
        if self._groups_cache is None:
            shadow_group = self.view.content.get_shadow(self)
            self._groups_cache = shadow_group.get_content()
            for item in self._groups_cache:
                if isinstance(item, Item):
                    self._groups_cache = [
                        ShadowGroup(
                            shadow=Group(*self._groups_cache),
                            content=self._groups_cache,
                            groups=1,
                        )
                    ]
                    break
        return self._groups_cache

    # -- Property Implementations ---------------------------------------------

    def _get_key_bindings(self):
        if self._key_bindings is None:
            # create a new key_bindings instance lazily

            view, context = self.view, self.context
            if (view is None) or (context is None):
                return None

            # Get the KeyBindings object to use:
            values = list(context.values())
            key_bindings = view.key_bindings
            if key_bindings is None:
                from .key_bindings import KeyBindings

                self._key_bindings = KeyBindings(controllers=values)
            else:
                self._key_bindings = key_bindings.clone(controllers=values)

        return self._key_bindings

    # -- Traits Event Handlers ------------------------------------------------

    def _updated_changed(self):
        if self.rebuild is not None:
            toolkit().rebuild_ui(self)

    def _title_changed(self):
        if self.control is not None:
            toolkit().set_title(self)

    def _icon_changed(self):
        if self.control is not None:
            toolkit().set_icon(self)

    @on_trait_change("parent, view, context")
    def _pvc_changed(self):
        """ Links this UI to its parent whenever the parent, view or context
            changes: the parent's history is adopted if we have none, and our
            KeyBindings become a child of the parent's KeyBindings.
        """
        parent = self.parent
        if (parent is None) or (self.key_bindings is None):
            return

        # If we don't have our own history, use our parent's:
        if self.history is None:
            self.history = parent.history

        # Link our KeyBindings object as a child of our parent's
        # KeyBindings object (if any):
        parent_bindings = parent.key_bindings
        if parent_bindings is not None:
            parent_bindings.children.append(self.key_bindings)
class BleedthroughPiecewiseOp(HasStrictTraits):
    """
    .. warning::
        **THIS OPERATION IS DEPRECATED.**
    
    Apply bleedthrough correction to a set of fluorescence channels.
    
    This is not a traditional bleedthrough matrix-based compensation; it uses
    a similar set of single-color controls, but instead of computing a compensation
    matrix, it fits a piecewise-linear spline to the untransformed data and
    uses those splines to compute the correction factor at each point in
    a mesh across the color space.  The experimental data is corrected using
    a linear interpolation along that mesh: this is much faster than computing
    the correction factor for each cell individually (an operation that takes
    5 msec each.)
    
    To use, set up the :attr:`controls` dict with the single color controls;
    call :meth:`estimate` to parameterize the operation; check that the bleedthrough 
    plots look good with the :meth:`plot` method of the 
    :class:`BleedthroughPiecewiseDiagnostic` instance returned by 
    :meth:`default_view`; and then call :meth:`apply` with an :class:`Experiment`.
    
    .. warning::
        **THIS OPERATION IS DEPRECATED AND WILL BE REMOVED IN A FUTURE RELEASE. 
        TO USE IT, SET :attr:`ignore_deprecated` TO ``True``.  IF YOU HAVE A 
        USE CASE WHERE THIS WORKS BETTER THAN THE LINEAR BLEEDTHROUGH 
        CORRECTION, PLEASE EMAIL ME OR FILE A BUG.**
    
    Attributes
    ----------
    controls : Dict(Str, File)
        The channel names to correct, and corresponding single-color control
        FCS files to estimate the correction splines with.  Must be set to
        use `estimate()`.
        
    num_knots : Int (default = 12)
        The number of knot positions laid out evenly on a logicle-transformed
        axis between the minimum and maximum of each control's data.  The two
        end points are dropped and the remaining interior knots are used for
        the spline.  Must be at least 3 to use `estimate()`.
        
    mesh_size : Int (default = 32)
        The size of each axis in the mesh used to interpolate corrected values.
        
    ignore_deprecated : Bool (default = False)
        Unless this is set to ``True``, :meth:`estimate`, :meth:`apply` and
        :meth:`default_view` refuse to run and raise a
        :class:`CytoflowOpError`.
            
    Notes
    -----
    We use an interpolation-based scheme to estimate corrected bleedthrough.
    The algorithm is as follows:
    
     - Fit a piecewise-linear spline to each single-color control's bleedthrough
       into other channels.  Because we want to fit the spline to untransformed
       data, but capture both the negative, positive-linear and positive-log 
       portions of a traditional flow data set, we distribute the spline knots 
       evenly on a logicle-transformed axis for each color we're correcting.   

     - At each point on a regular mesh spanning the entire range of the
       instrument, estimate the mapping from (raw colors) --> (actual colors).
       The mesh points are also distributed evenly along the logicle-transformed
       color axes; this captures negative data as well as positive data.
       This is quite slow: ~30 seconds for a mesh size of 32 in 3-space.
       Remember that additional channels expand the number of mesh points
       exponentially!

     - Use these estimates to parameterize a linear interpolator (in linear
       space, this time).  There's one interpolator per output channel (so
       for a 3-channel correction, each interpolator is R^3 --> R).  For 
       each measured cell, run each interpolator to give the corrected output.

    Examples
    --------

    Create a small experiment:
    
        >>> import cytoflow as flow
        >>> import_op = flow.ImportOp()
        >>> import_op.tubes = [flow.Tube(file = "tasbe/rby.fcs")]
        >>> ex = import_op.apply()
    
    Create and parameterize the operation
        
        >>> bl_op = flow.BleedthroughPiecewiseOp()
        >>> bl_op.controls = {'Pacific Blue-A' : 'tasbe/ebfp.fcs',
        ...                   'FITC-A' : 'tasbe/eyfp.fcs',
        ...                   'PE-Tx-Red-YG-A' : 'tasbe/mkate.fcs'}
        >>> bl_op.ignore_deprecated = True
    
    Estimate the model parameters
    
        >>> bl_op.estimate(ex)
    
    Plot the diagnostic plot

        >>> bl_op.default_view().plot(ex)  

    Apply the operation to the experiment
    
        >>> ex2 = bl_op.apply(ex)  
 
    """

    # traits
    id = Constant('edu.mit.synbio.cytoflow.operations.bleedthrough_piecewise')
    friendly_id = Constant("Piecewise Bleedthrough Correction")

    name = Constant("Bleedthrough")

    controls = Dict(Str, File)
    num_knots = Int(12)
    mesh_size = Int(32)

    ignore_deprecated = Bool(False)

    _splines = Dict(Str, Dict(Str, Python), transient=True)
    _interpolators = Dict(Str, Python, transient=True)

    # because the order of the channels is important, we can't just call
    # _interpolators.keys()
    # TODO - this is ugly and unpythonic.  :-/
    _channels = List(Str, transient=True)

    def _check_deprecated(self):
        """Raise a CytoflowOpError unless :attr:`ignore_deprecated` is set."""
        if not self.ignore_deprecated:
            raise util.CytoflowOpError(
                None, "BleedthroughPiecewiseOp is DEPRECATED. "
                "To use it anyway, set ignore_deprecated "
                "to True.")

    def estimate(self, experiment, subset=None):
        """
        Estimate the bleedthrough from the single-channel controls in 
        :attr:`controls`

        Parameters
        ----------
        experiment : Experiment
            The experiment whose channels, metadata and history are used to
            import and pre-process the control files.

        subset : Str (default = None)
            If set, a query string applied to each control's events before
            the splines are fit.

        Raises
        ------
        CytoflowOpError
            If the operation is used without :attr:`ignore_deprecated`, if
            no experiment is given, if fewer than 3 knots or fewer than two
            control channels are configured, if a control channel is missing
            from the experiment or has no ``range`` metadata, or if `subset`
            is invalid or selects no events.
        """

        self._check_deprecated()

        if experiment is None:
            raise util.CytoflowOpError('experiment', "No experiment specified")

        if self.num_knots < 3:
            raise util.CytoflowOpError(
                'num_knots', "Need to allow at least 3 knots in the spline")

        self._channels = list(self.controls.keys())

        if len(self._channels) < 2:
            raise util.CytoflowOpError(
                'controls',
                "Need at least two channels to correct bleedthrough.")

        for channel in self._channels:
            # report an unknown channel as an op error rather than letting the
            # metadata lookup below fail with a bare KeyError
            if channel not in experiment.channels:
                raise util.CytoflowOpError(
                    'controls',
                    "Can't find channel {} in the experiment".format(channel))

            if 'range' not in experiment.metadata[channel]:
                raise util.CytoflowOpError(
                    None, "Can't find range for channel {}".format(channel))

        # start from a clean slate so nothing from an earlier estimate()
        # (possibly with a different set of controls) survives
        self._splines = {}
        self._interpolators = {}
        mesh_axes = []

        # the import parameters are the same for every control tube
        fcs_channels = {
            experiment.metadata[c]["fcs_name"]: c
            for c in experiment.channels
        }
        name_metadata = experiment.metadata['name_metadata']

        for channel in self._channels:
            self._splines[channel] = {}

            # make a little Experiment
            check_tube(self.controls[channel], experiment)
            tube_exp = ImportOp(
                tubes=[Tube(file=self.controls[channel])],
                channels=fcs_channels,
                name_metadata=name_metadata).apply()

            # apply previous operations
            for op in experiment.history:
                tube_exp = op.apply(tube_exp)

            # subset it
            if subset:
                try:
                    tube_exp = tube_exp.query(subset)
                except Exception as e:
                    raise util.CytoflowOpError(
                        'subset', "Subset string '{0}' isn't valid".format(
                            subset)) from e

                if len(tube_exp.data) == 0:
                    raise util.CytoflowOpError(
                        'subset',
                        "Subset string '{0}' returned no events".format(
                            subset))

            tube_data = tube_exp.data

            # polyfit requires sorted data
            tube_data.sort_values(by=channel, inplace=True)

            channel_min = tube_data[channel].min()
            channel_max = tube_data[channel].max()

            # we're going to set the knots and splines evenly across the
            # logicle-transformed data, so as to capture both the "linear"
            # aspect of the near-0 and negative values, and the "log"
            # aspect of large values.

            scale = util.scale_factory("logicle", experiment, channel=channel)

            # the splines' knots, spanning the control's observed data
            lg_knots = np.linspace(scale(channel_min), scale(channel_max),
                                   self.num_knots)
            knots = scale.inverse(lg_knots)

            # only keep the interior knots
            knots = knots[1:-1]

            # the interpolators' mesh
            # if you update the mesh calculation here, update apply() too!
            channel_md = experiment.metadata[channel]
            if 'af_median' in channel_md and 'af_stdev' in channel_md:
                mesh_min = channel_md['af_median'] - \
                           3 * channel_md['af_stdev']
            else:
                # 'range' is guaranteed to be present by the check above
                mesh_min = -0.01 * channel_md[
                    'range']  # TODO - does this even work?
                warn(
                    "This works best if you apply AutofluorescenceOp before "
                    "computing bleedthrough", util.CytoflowOpWarning)

            mesh_max = channel_md['range']

            lg_mesh_axis = np.linspace(scale(mesh_min), scale(mesh_max),
                                       self.mesh_size)

            mesh_axis = scale.inverse(lg_mesh_axis)
            mesh_axes.append(mesh_axis)

            # one degree-1 (piecewise-linear) spline per other channel
            from_channel = channel
            for to_channel in self._channels:
                if from_channel == to_channel:
                    continue

                self._splines[from_channel][to_channel] = \
                    scipy.interpolate.LSQUnivariateSpline(tube_data[from_channel].values,
                                                          tube_data[to_channel].values,
                                                          t = knots,
                                                          k = 1)

        mesh = pd.DataFrame(util.cartesian(mesh_axes),
                            columns=list(self._channels))

        mesh_corrected = mesh.apply(_correct_bleedthrough,
                                    axis=1,
                                    args=(list(self._channels),
                                          self._splines))

        mesh_shape = [len(x) for x in mesh_axes]
        for channel in self._channels:
            chan_values = mesh_corrected[channel].values.reshape(mesh_shape)
            self._interpolators[channel] = \
                scipy.interpolate.RegularGridInterpolator(points = mesh_axes,
                                                          values = chan_values,
                                                          bounds_error = False,
                                                          fill_value = 0.0)

        # TODO - some sort of validity checking.

    def apply(self, experiment):
        """Applies the bleedthrough correction to an experiment.
        
        Parameters
        ----------
        experiment : Experiment
            the old_experiment to which this op is applied
            
        Returns
        -------
            A new :class:`Experiment` with the bleedthrough subtracted out.
            Events at or below the lower edge of the interpolation mesh in
            any corrected channel are dropped.
            Corrected channels have the following additional metadata:
            
            - **bleedthrough_channels** : List(Str)
              The channels that were used to correct this one.
        
            - **bleedthrough_fn** : Callable (Tuple(Float) --> Float)
              The function that will correct one event in this channel.  Pass it
              the values specified in `bleedthrough_channels` and it will return
              the corrected value for this channel. 

        Raises
        ------
        CytoflowOpError
            If the operation is used without :attr:`ignore_deprecated`, if
            no experiment is given, if :meth:`estimate` has not been run, or
            if the estimated channels are not all in the experiment.
        """

        self._check_deprecated()

        if experiment is None:
            raise util.CytoflowOpError('experiment', "No experiment specified")

        if not self._interpolators:
            raise util.CytoflowOpError(
                None, "Module interpolators aren't set. "
                "Did you run estimate()?")

        if not set(self._interpolators.keys()) <= set(experiment.channels):
            raise util.CytoflowOpError(
                None, "Module parameters don't match experiment channels")

        new_experiment = experiment.clone()

        # get rid of data outside of the interpolators' mesh
        # (-3 * autofluorescence sigma)
        for channel in self._channels:

            # if you update the mesh calculation in estimate(), update it
            # here too!
            channel_md = experiment.metadata[channel]
            if 'af_median' in channel_md and 'af_stdev' in channel_md:
                mesh_min = channel_md['af_median'] - \
                           3 * channel_md['af_stdev']
            else:
                mesh_min = -0.01 * channel_md[
                    'range']  # TODO - does this even work?

            new_experiment.data = \
                new_experiment.data[new_experiment.data[channel] > mesh_min]

        new_experiment.data.reset_index(drop=True, inplace=True)

        # snapshot the uncorrected values: every interpolator must see the
        # raw data, not channels that were already corrected in this loop
        old_data = new_experiment.data[self._channels]

        for channel in self._channels:
            new_experiment[channel] = self._interpolators[channel](old_data)

            new_experiment.metadata[channel][
                'bleedthrough_channels'] = self._channels
            new_experiment.metadata[channel][
                'bleedthrough_fn'] = self._interpolators[channel]

        new_experiment.history.append(
            self.clone_traits(transient=lambda _: True))
        return new_experiment

    def default_view(self, **kwargs):
        """
        Returns a diagnostic plot to see if the bleedthrough spline estimation
        is working.
        
        Returns
        -------
        IView
            An IView, call plot() to see the diagnostic plots

        Raises
        ------
        CytoflowOpError
            If the operation is used without :attr:`ignore_deprecated`, or if
            the estimated splines don't cover exactly the channels in
            :attr:`controls`.
        """

        self._check_deprecated()

        if set(self.controls.keys()) != set(self._splines.keys()):
            raise util.CytoflowOpError(
                'controls',
                "Must have both the controls and bleedthrough to plot")

        v = BleedthroughPiecewiseDiagnostic(op=self, **kwargs)
        v.trait_set(**kwargs)
        return v
class MassCalibratorSweep(MagnetSweep):
    """Magnet DAC sweep with peak fine-scanning for mass calibration.

    On top of the inherited sweep this class can fine-scan a window around
    every calibration peak that has an isotope assigned, store the coarse
    and fine scans in the database and write the peak positions to the
    magnetic field table.
    """

    db = Any

    start_dac = Float(4)
    stop_dac = Float(8.0)
    step_dac = Float(0.1)
    period = 10

    calibration_peaks = List

    selected = Any

    # peak detection tuning parameters
    min_peak_height = Float(1)
    min_peak_separation = Range(0.0001, 1000)
    # if the next point is less than delta from the current point than this is not a peak
    # essentially how much does the peak stand out from the background
    delta = Float(1)

    # fine scan parameters
    fperiod = Int(50)
    fwindow = Float(1)
    fstep_dac = Float(0.1)
    fexecute_button = Event
    fexecute_label = Property(depends_on='_alive')
    fine_scan_enabled = Property(depends_on='calibration_peaks:isotope')

    _fine_scanning = False

    def setup_graph(self):
        """Create the sweep plot with a DAC x-axis spanning the sweep range."""
        g = self.graph
        g.new_plot()
        g.set_x_title('DAC')

        g.new_series()

        # start/stop may be given in either order
        mi = min(self.start_dac, self.stop_dac)
        ma = max(self.start_dac, self.stop_dac)

        g.set_x_limits(min_=mi, max_=ma, pad='0.1')

    def _fine_scan(self):
        """Fine-scan every calibration peak that has an isotope assigned.

        The scan runs with ``fperiod`` as the period; the previous period
        and the ``_fine_scanning`` flag are restored even if a scan raises.
        If the sweep is still alive afterwards the user is asked whether to
        save the scans to the database and, if so, whether to apply the
        calibration.
        """
        operiod = self.period
        self.period = self.fperiod
        self._fine_scanning = True
        try:
            self.graph.new_plot(padding_top=10, xtitle='Relative DAC')
            w = self.fwindow / 2.0
            self.graph.set_x_limits(min_=-w, max_=w, plotid=1)
            self._redraw()

            # i numbers the peaks that have an isotope, whether or not they
            # end up being scanned
            i = 1
            for cp in self.calibration_peaks:
                if not cp.isotope:
                    continue

                if self.isAlive():
                    self.selected = cp
                    self.info('Fine scan calibration peak {}. {} dac={}'.format(
                        i, cp.isotope, cp.dac))
                    self._fine_scan_peak(cp)

                i += 1
        finally:
            # never leave the sweep stuck in fine-scan mode
            self.period = operiod
            self._fine_scanning = False

        if self.isAlive():
            if self.confirmation_dialog('Save to Database'):
                self._save_to_db()
                if self.confirmation_dialog('Apply Calibration'):
                    self._apply_calibration()

    def _pack(self, d):
        """Serialize (x, y) pairs as consecutive big-endian float32 pairs.

        Returns a byte string; ``struct.pack`` yields bytes, so the pieces
        are joined with a bytes separator (valid on Python 2 and 3).
        """
        return b''.join(struct.pack('>ff', x, y) for x, y in d)

    def _save_to_db(self):
        """Store the coarse scan and one fine scan per isotope peak."""
        db = self.db
        with db.session_ctx():
            # NOTE(review): the spectrometer name is hard-coded here
            spectrometer = 'Obama'
            hist = db.add_mass_calibration_history(spectrometer)

            # add coarse scan
            d = self._get_coarse_data()
            data = self._pack(d)
            db.add_mass_calibration_scan(hist, blob=data)

            # add fine scans
            # pairs the isotope peaks, in list order, with the fine-scan
            # plot's series, in sorted key order
            plot = self.graph.plots[1]
            cps = [cp for cp in self.calibration_peaks if cp.isotope]
            for cp, ki in zip(cps, sorted(plot.plots.keys())):
                p = plot.plots[ki][0]

                xs = p.index.get_data()
                ys = p.value.get_data()
                d = array((xs, ys)).T
                data = self._pack(d)
                db.add_mass_calibration_scan(
                    hist,
                    cp.isotope,
                    blob=data,
                    center=cp.dac,
                )

    def _apply_calibration(self):
        """
            save calibration peaks as mag field table
        """
        p = os.path.join(paths.spectrometer_dir, 'mftable.csv')
        with open(p, 'w') as wfile:
            writer = csv.writer(wfile, delimiter=',')
            for cp in self.calibration_peaks:
                if cp.isotope:
                    writer.writerow([cp.isotope, cp.dac])

    def _fine_scan_peak(self, cp):
        """Scan a ``fwindow``-wide window around ``cp.dac`` and re-center it.

        The peak center found in the new scan is marked on the fine-scan
        plot and its x value is added to ``cp.dac``.  A PeakCenterError is
        logged as a warning and leaves ``cp`` unchanged.
        """
        line, _ = self.graph.new_series(plotid=1)

        c = cp.dac
        w = self.fwindow / 2.0

        steps = self._calc_step_values(c - w, c + w, self.fstep_dac)
        self._scan_dac(steps)

        # get last scan
        xs = line.index.get_data()
        ys = line.value.get_data()

        try:
            center = calculate_peak_center(xs, ys)

            # only the center x value is used below
            [lx, cx, hx], [ly, cy, hy], mx, my = center
            self.graph.add_vertical_rule(cx, plotid=1)
            self.info('new peak center. {} nominal={} dx={}'.format(
                cp.isotope, cp.dac, cx))
            cp.dac += cx
            self._redraw()
        except PeakCenterError as e:
            self.warning(e)
Exemple #20
0
class OdmrFitXY8(PulsedFit):
    """Provides fits and plots for xy8 measurement.

    When ``perform_fit`` is set, the normalized counts are fit with multiple
    Lorentzians over the free evolution time.  The flat parameter array
    ``fit_parameters`` is laid out as ``[c, x0_1, g_1, a_1, x0_2, g_2, ...]``
    and is unpacked into the per-resonance arrays ``fit_times``,
    ``fit_line_width``, ``fit_contrast`` and ``fit_frequencies``.
    """

    #fit = Instance(DoublePulsedFit, factory=DoublePulsedFit)
    text = Str('')

    number_of_resonances = Trait(
        'auto', String('auto', auto_set=False, enter_set=True),
        Int(10000.,
            desc='Number of Lorentzians used in fit',
            label='N',
            auto_set=False,
            enter_set=True))
    threshold = Range(
        low=-99,
        high=99.,
        value=-50.,
        desc=
        'Threshold for detection of resonances [%]. The sign of the threshold specifies whether the resonances are negative or positive.',
        label='threshold [%]',
        mode='text',
        auto_set=False,
        enter_set=True)

    perform_fit = Bool(False, label='perform fit')

    fit_frequencies = Array(value=np.array((np.nan, )),
                            label='frequency [MHz]')
    fit_times = Array(value=np.array((np.nan, )), label='time [ns]')
    fit_line_width = Array(value=np.array((np.nan, )), label='line_width [ns]')
    fit_contrast = Array(value=np.array((np.nan, )), label='contrast [%]')

    def __init__(self):
        super(OdmrFitXY8, self).__init__()
        self.on_trait_change(self._update_processed_plot_data_fit,
                             'fit_parameters',
                             dispatch='ui')
        self.on_trait_change(
            self._update_fit,
            'normalized_counts, number_of_resonances, threshold, perform_fit',
            dispatch='ui')
        #self.on_trait_change(self._update_plot_tau, 'fit.free_evolution_time', dispatch='ui')

    @staticmethod
    def _contrast_from_parameters(p):
        """Return the contrast [%] of every Lorentzian in parameter array p.

        ``p[0]`` is the offset ``c``; every following triple is
        ``(x0, g, a)`` for one resonance.
        """
        # floor division: the count must be an int for np.empty / reshape
        n = len(p) // 3
        c = p[0]
        contrast = np.empty(n)
        for i, pn in enumerate(p[1:].reshape((n, 3))):
            a = pn[2]
            g = pn[1]
            A = np.abs(a / (np.pi * g))
            if a > 0:
                contrast[i] = 100 * A / (A + c)
            else:
                contrast[i] = 100 * A / c
        return contrast

    def _update_fit(self):
        """Re-run the fit (or reset it to NaN) and refresh the result arrays.

        A failed fit is logged at debug level and treated like a disabled
        fit: ``fit_parameters`` becomes four NaNs.
        """

        if self.perform_fit:

            N = self.number_of_resonances

            if N != 'auto':
                N = int(N)

            try:

                self.fit_parameters = fitting.fit_multiple_lorentzians(
                    self.free_evolution_time,
                    self.normalized_counts,
                    N,
                    threshold=self.threshold * 0.01)
            except Exception:
                logging.getLogger().debug('XY8 fit failed.', exc_info=True)
                self.fit_parameters = np.nan * np.empty(4)
        else:
            self.fit_parameters = np.nan * np.empty(4)

        p = self.fit_parameters
        self.fit_times = p[1::3]
        self.fit_line_width = p[2::3]
        self.fit_contrast = self._contrast_from_parameters(p)
        self.fit_frequencies = 1e+3 / (2 * self.fit_times)

    processed_plot_data = Instance(ArrayPlotData,
                                   factory=ArrayPlotData,
                                   kw={
                                       'x': np.array((0, 1)),
                                       'y': np.array((0, 0)),
                                       'fit': np.array((0, 0)),
                                   })

    plots = [{
        'data': ('tau', 'spin_state1'),
        'color': 'blue',
        'name': 'signal1',
        'label': 'Reference'
    }, {
        'data': ('tau', 'spin_state2'),
        'color': 'green',
        'name': 'signal2',
        'label': 'Signal'
    }, {
        'data': ('x', 'y'),
        'color': 'black',
        'name': 'processed'
    }, {
        'data': ('x', 'fit'),
        'color': 'purple',
        'name': 'fitting',
        'label': 'Fit'
    }]

    # def _update_plot_tau(self):
    # self.processed_plot_data.set_data('y', self.fit.free_evolution_time)

    def _update_processed_plot_data_fit(self):
        """Refresh the fit curve and the summary text for new fit parameters.

        Does nothing while the parameters are NaN.  All reported values are
        derived from ``fit_parameters`` itself: this handler fires on the
        parameter change, when the ``fit_*`` arrays set by ``_update_fit``
        may still hold the previous fit (possibly with a different number
        of resonances).
        """
        p = self.fit_parameters
        if np.isnan(p[0]):
            return

        self.processed_plot_data.set_data(
            'fit',
            fitting.NLorentzians(*p)(self.free_evolution_time))

        times = p[1::3]
        widths = p[2::3]
        contrast = self._contrast_from_parameters(p)
        frequencies = 1e+3 / (2 * times)

        # one line per resonance
        self.text = ''.join(
            'f %i: %.6e ns, HWHM %.3e ns, contrast %.1f%%, freq %.3e MHz\n' %
            (i + 1, ti, widths[i], contrast[i], frequencies[i])
            for i, ti in enumerate(times))

    traits_view = View(
        Tabbed(
            VGroup(HGroup(
                Item('number_of_resonances', width=-60),
                Item('threshold', width=-60),
                Item('perform_fit'),
            ),
                   HGroup(
                       Item('fit_contrast', width=-90, style='readonly'),
                       Item('fit_line_width', width=-90, style='readonly'),
                       Item('fit_frequencies', width=-90, style='readonly'),
                       Item('fit_times', width=-90, style='readonly'),
                   ),
                   label='fit_parameter'),
            HGroup(Item('integration_width'),
                   Item('position_signal'),
                   Item('position_normalize'),
                   label='settings'),
        ),
        title='Noise spectrum Fit',
    )

    get_set_items = PulsedFit.get_set_items + [
        'fit_parameters', 'fit_frequencies', 'fit_line_width', 'fit_contrast',
        'text', 'fit_times'
    ]
class Distance_to_NV(HasTraits, GetSetItemsMixin):
    # Analysis panel for an imported XY8 measurement: normalizes the counts,
    # derives a noise spectrum, fits it and estimates Brms and a distance to
    # the NV by two methods (see the *_button_fired handlers below).
    # A class docstring is deliberately not used here: '__doc__' is listed in
    # get_set_items, so adding one could change what gets saved.

    # --- data arrays ---------------------------------------------------------
    counts=Array( )            # 'spin_state1' trace of the imported file
    counts2=Array( )           # 'spin_state2' trace of the imported file
    time=Array( )              # 'tau' axis of the imported measurement
    normalized_counts=Array()  # (counts2 - counts) scaled by the contrast
    nu=Array()                 # not referenced by any method of this class
    S=Array()                  # noise spectrum from _calculate_noise_spectrum_button_fired
    FFx=Array()                # frequency axis of the filter-function FFT (band-limited)
    FFy=Array()                # filter-function power on FFx
    tau=Array() # pulse spacing in seconds: (2*time + pi_pulse)*1e-9
    
    # NOTE(review): line_label is assigned by both _create_spin_noise_plot and
    # _create_second_method_plot; the later call wins.
    line_label = Instance(PlotLabel)
    line_data = Instance(ArrayPlotData)
    
    # pickled measurement file to import
    myfile=File(exists=True)
    
    # sample type; selects the spin density and magnetic moment used in the
    # distance calculations
    substance =  Enum('immersion oil, H1 signal', 'hBN, B11 signal',
                     label='substance',
                     desc='choose the nucleus to calculate larmor frequency',
                     editor=EnumEditor(values={'immersion oil, H1 signal':'1:immersion oil, H1 signal','hBN, B11 signal':'2:hBN, B11 signal'},cols=8),)  
    
    # --- scalar parameters and results ---------------------------------------
    rabi_contrast=Range(low=0., high=100., value=30., desc='full contrast', label='full contrast [%]', auto_set=False, enter_set=True)
    z=Float(value=0., desc='distance to NV [nm]', label='distance to NV [nm]')    # result of the spectrum-integration method
    frequencies=Float(value=0., desc='frequencies [MHz]', label='frequencies [MHz]')
    z2=Float(value=0., desc='distance to NV [nm]', label='distance to NV [nm]')   # result of the envelope-fit method
    fit_threshold = Range(low= -99, high=99., value= 50., desc='Threshold for detection of resonances [%]. The sign of the threshold specifies whether the resonances are negative or positive.', label='threshold [%]', mode='text', auto_set=False, enter_set=True)
    # FFT length passed to np.fft.fft (converted with int() before use)
    n_FFT=Range(low=2.**10, high=2.**21, value=2.0e+6, desc='NUMBER OF POINTS FOR FOURIER TRANSFORM', label='N FFT', editor=TextEditor(auto_set=False, enter_set=True, evaluate=float, format_str='%.2e'))
    
    
    # multiplicative factor applied to the noise spectrum
    alpha=Range(low=0., high=1000., value=1., desc='fitting paramenter', label='alpha', auto_set=False, enter_set=True, mode='spinner')
    pulse_spacing=Float(value=0., desc='pulse_spacing', label='pulse spacing [ns]')
    pi_pulse=Float(value=0., desc='pi pulse', label='pi pulse')
    baseline=Float(value = 0., desc='baseline', label='baseline')   # mean of (counts + counts2)/2
    Brms=Float(value = 0., desc='Brms', label='Brms [nT]')          # from the integrated spectrum fit
    Brms2=Float(value = 0., desc='Brms', label='Brms [nT]')         # from the envelope fit
    N=Int(value = 0, desc='number of repetitions', label='XY8-N')
    Magnetic_field=Range(low=0., high=3000., value=330., desc='magnetic field', label='Magnetic field [G]', auto_set=False, enter_set=True)
      
    # --- fit controls ----------------------------------------------------------
    perform_fit=Bool(False, label='perform fit')
    number_of_resonances=Trait('auto', String('auto', auto_set=False, enter_set=True), Int(10000., desc='Number of Lorentzians used in fit', label='N', auto_set=False, enter_set=True))
    
    # --- buttons; each <name>_button has a matching _<name>_button_fired ------
    norm_button=Button(label='normalize', desc='normalize')
    import_data_button = Button(label='import data', desc='import xy8')
    calculate_noise_spectrum_button=Button(label='calculate spectrum', desc='calculate noise spectrum')
    distance_to_NV_button = Button(label='calculate distance', desc='calculate distance to NV')
    filter_function_button = Button(label='Filter Function Fourier Transform', desc='calculate filter function')
    distance_from_envelope_button = Button(label='calculate distance from envelope', desc='calculate distance from envelope')
    
    # --- plots and their data containers (built by the _create_*_plot methods)
    plot_data_xy8_line  = Instance(ArrayPlotData)
    xy8_plot=Instance(Plot, editor=ComponentEditor())
    
    plot_data_normxy8_line  = Instance( ArrayPlotData )
    normxy8_plot=Instance(Plot, editor=ComponentEditor())
    
    plot_data_filter_function  = Instance( ArrayPlotData )
    filter_function_plot=Instance(Plot, editor=ComponentEditor())
    
    plot_data_spin_noise  = Instance( ArrayPlotData )
    spin_noise_plot=Instance(Plot, editor=ComponentEditor())
    
    plot_data_second_method  = Instance( ArrayPlotData )
    second_method_plot=Instance(Plot, editor=ComponentEditor())
    
    # index into tau selecting the pulse spacing for the filter-function plot
    N_tau =  Range(low=0, high=50, value=6, desc='N tau', label='n tau', mode='spinner',  auto_set=False, enter_set=False)
        
    line_width=Float(value=0., desc='line_width', label='linewidth [kHz]')
    
    # --- fit results -----------------------------------------------------------
    # Flat parameter vector: p[0] is the offset, followed by one
    # (center, width, amplitude) triple per peak (see _update_fit).
    fit_parameters = Array(value=np.array((np.nan, np.nan, np.nan, np.nan)))
    fit_centers = Array(value=np.array((np.nan,)), label='center_position [Hz]') 
    fit_line_width = Array(value=np.array((np.nan,)), label='uncertanity [Hz]') 
    fit_contrast = Array(value=np.array((np.nan,)), label='contrast [%]')
    
    # Peak model; the (misspelled) value 'loretzian' is compared literally in
    # the fit methods, so it must not be "corrected" in one place only.
    fitting_func = Enum('gaussian', 'loretzian',
                     label='fit',
                     desc='fitting function',
                     editor=EnumEditor(values={'gaussian':'1:gaussian','loretzian':'2:loretzian'},cols=2),)
    
    def __init__(self):
        """Build the five plots and register the trait-change listeners."""
        super(Distance_to_NV, self).__init__()

        # Creation order is significant: the spin-noise and second-method
        # builders both assign self.line_label, and the last one wins.
        plot_builders = (
            self._create_xy8_plot,
            self._create_normxy8_plot,
            self._create_filter_function_plot,
            self._create_spin_noise_plot,
            self._create_second_method_plot,
        )
        for build_plot in plot_builders:
            build_plot()

        # (handler, comma-separated trait names); all dispatched on the UI thread.
        listeners = (
            (self._update_line_data_fit, 'perform_fit, fitting_func,fit_parameters'),
            (self._update_fit, 'fit, perform_fit'),
            (self._update_plot, 'alpha, line_width, n_FFT'),
        )
        for handler, trait_names in listeners:
            self.on_trait_change(handler, trait_names, dispatch='ui')
            
    def _counts_default(self):
        """Default for ``counts``: zeros matching the first axis of ``pdawg3.fit.x_tau``."""
        # NOTE(review): ``pdawg3`` is not declared on this class; it must be
        # attached from outside before this default is ever requested.
        n_points = self.pdawg3.fit.x_tau.shape[0]
        return np.zeros(n_points)

    def _counts2_default(self):
        """Default for ``counts2``: zeros matching the first axis of ``pdawg3.fit.x_tau``."""
        # NOTE(review): ``pdawg3`` is not declared on this class; it must be
        # attached from outside before this default is ever requested.
        n_points = self.pdawg3.fit.x_tau.shape[0]
        return np.zeros(n_points)
        
    def _normalized_counts_default(self):
        """Default for ``normalized_counts``: zeros matching the first axis of ``pdawg3.fit.x_tau``."""
        # NOTE(review): ``pdawg3`` is not declared on this class; it must be
        # attached from outside before this default is ever requested.
        n_points = self.pdawg3.fit.x_tau.shape[0]
        return np.zeros(n_points)

    def _time_default(self):
        """Default for ``time``: zeros matching the first axis of ``pdawg3.fit.x_tau``."""
        # NOTE(review): ``pdawg3`` is not declared on this class; it must be
        # attached from outside before this default is ever requested.
        n_points = self.pdawg3.fit.x_tau.shape[0]
        return np.zeros(n_points)
        
    def _create_normxy8_plot(self):
        """Create the normalized-counts plot and its data container.

        Stores them in ``plot_data_normxy8_line`` / ``normxy8_plot`` and
        returns the plot.
        """
        # Placeholder series until real data is imported.
        data = ArrayPlotData(
            normalized_counts=np.array((0., 1.)),
            time=np.array((0., 0.)),
            fit=np.array((0., 0.)),
        )
        canvas = Plot(data, width=50, height=40, padding=8,
                      padding_left=64, padding_bottom=32)
        canvas.plot(('time', 'normalized_counts'), color='green', line_width=2)
        canvas.index_axis.title = 'time [ns]'
        canvas.value_axis.title = 'normalized counts'

        self.plot_data_normxy8_line = data
        self.normxy8_plot = canvas
        return self.normxy8_plot
        
        
    def _create_xy8_plot(self):
        """Create the raw-counts plot and its data container.

        Stores them in ``plot_data_xy8_line`` / ``xy8_plot`` and returns the
        plot.
        """
        # Placeholder series until real data is imported.
        data = ArrayPlotData(
            counts2=np.array((0., 1.)),
            time=np.array((0., 0.)),
            fit=np.array((0., 0.)),
        )
        canvas = Plot(data, width=50, height=40, padding=8,
                      padding_left=64, padding_bottom=32)
        canvas.plot(('time', 'counts2'), color='green', line_width=2)
        canvas.index_axis.title = 'time [ns]'
        canvas.value_axis.title = 'counts'

        self.plot_data_xy8_line = data
        self.xy8_plot = canvas
        return self.xy8_plot
        
    def _create_filter_function_plot(self):
        """Create the filter-function spectrum plot and its data container.

        Stores them in ``plot_data_filter_function`` / ``filter_function_plot``
        and returns the plot.
        """
        # Placeholder series until a filter function has been computed.
        data = ArrayPlotData(
            freq=np.array((0., 1.)),
            value=np.array((0., 0.)),
            fit=np.array((0., 0.)),
        )
        canvas = Plot(data, width=50, height=40, padding=8,
                      padding_left=64, padding_bottom=32)
        canvas.plot(('freq', 'value'), color='red', type='line', line_width=3)
        canvas.index_axis.title = 'frequency [MHz]'
        canvas.value_axis.title = 'Filter Function Fourier Transform'

        self.plot_data_filter_function = data
        self.filter_function_plot = canvas
        return self.filter_function_plot
        
    def _create_spin_noise_plot(self):
        """Create the noise-spectrum plot, its text label and data container.

        Stores them in ``plot_data_spin_noise`` / ``spin_noise_plot`` /
        ``line_label`` and returns the plot.
        """
        # The x series is stored under the key 'time' although the axis is
        # labelled as a frequency.
        data = ArrayPlotData(
            value=np.array((0., 1.)),
            time=np.array((0., 0.)),
            fit=np.array((0., 0.)),
        )
        canvas = Plot(data, width=50, height=40, padding=8,
                      padding_left=64, padding_bottom=32)
        canvas.plot(('time', 'value'), color='green', line_width=2)
        canvas.index_axis.title = 'frequency [MHz]'
        canvas.value_axis.title = 'noise spectrum [nT^2/Hz]'

        # Overlay used to display textual fit results.
        label = PlotLabel(text='', hjustify='left', vjustify='top',
                          position=[50, 100])
        canvas.overlays.append(label)
        self.line_label = label

        self.plot_data_spin_noise = data
        self.spin_noise_plot = canvas
        return self.spin_noise_plot
    
    def _create_second_method_plot(self):
        """Create the envelope ("second method") plot, label and data container.

        Stores them in ``plot_data_second_method`` / ``second_method_plot`` /
        ``line_label`` and returns the plot.
        """
        data = ArrayPlotData(
            value=np.array((0., 1.)),
            time=np.array((0., 0.)),
            fit=np.array((0., 0.)),
        )
        canvas = Plot(data, width=50, height=40, padding=8,
                      padding_left=64, padding_bottom=32)
        canvas.plot(('time', 'value'), color='green', line_width=2)
        canvas.index_axis.title = 'time [ns]'
        canvas.value_axis.title = 'Normalized Contrast'

        # NOTE(review): this replaces whatever self.line_label referred to
        # before (e.g. the spin-noise plot's label) -- confirm this is intended.
        label = PlotLabel(text='', hjustify='left', vjustify='top',
                          position=[64, 280])
        canvas.overlays.append(label)
        self.line_label = label

        self.plot_data_second_method = data
        self.second_method_plot = canvas
        return self.second_method_plot
    
    def _import_data_button_fired(self):
        """Load the pickled measurement in ``self.myfile`` and process it.

        Reads ``pulse_num``, ``tau``, ``pi_1`` and ``rabi_contrast`` from the
        ``'measurement'`` entry and ``spin_state1`` / ``spin_state2`` from the
        ``'fit'`` entry, updates the raw-counts plot, derives ``tau`` (in
        seconds), then normalizes, computes the noise spectrum and switches
        the fit on.

        Raises:
            IOError: if the file cannot be opened.
            KeyError: if the pickle lacks one of the expected entries.
        """
        # SECURITY: unpickling executes arbitrary code; only open trusted files.
        # The file is opened in text mode exactly as before (the data may have
        # been written as an ASCII pickle in text mode); the context manager
        # guarantees the handle is closed, which the old code never did.
        with open(self.myfile, 'r') as pickle_file:
            data = cPickle.load(pickle_file)

        measurement = data['measurement']
        fit_results = data['fit']

        self.N = measurement['pulse_num']
        self.counts = fit_results['spin_state1']
        self.counts2 = fit_results['spin_state2']
        self.time = measurement['tau']
        self.pi_pulse = measurement['pi_1']
        self.rabi_contrast = measurement['rabi_contrast']

        # The raw plot shows the difference of the two traces.
        self.plot_data_xy8_line.set_data('time', self.time)
        self.plot_data_xy8_line.set_data('counts2', self.counts2 - self.counts)

        self.tau = (2*self.time + self.pi_pulse)*1e-9  # in seconds

        self._norm_button_fired()
        self._calculate_noise_spectrum_button_fired()
        # Setting the trait triggers the fit through the registered listeners.
        self.perform_fit = True
        #self._distance_to_NV_button_fired()       
        #self._distance_from_envelope_button_fired()
       
        
    def _norm_button_fired(self):
        """Normalize the count difference by the full Rabi contrast.

        ``baseline`` is the mean of ``(counts + counts2)/2``. From it and
        ``rabi_contrast`` (in percent) an upper and lower count level are
        derived; ``normalized_counts`` is ``(counts2 - counts)`` divided by
        their difference. The normalized plot is refreshed.
        """
        mean_trace = (self.counts + self.counts2) / 2.
        self.baseline = sum(mean_trace) / len(mean_trace)

        contrast_fraction = 0.01 * self.rabi_contrast
        level_up = self.baseline / (1 - contrast_fraction / 2)
        level_down = level_up * (1 - contrast_fraction)

        difference = self.counts2 - self.counts
        self.normalized_counts = difference / (level_up - level_down)

        plot_data = self.plot_data_normxy8_line
        plot_data.set_data('time', self.tau * 1e+9)  # seconds -> ns
        plot_data.set_data('normalized_counts', self.normalized_counts)
        
        
    def _calculate_noise_spectrum_button_fired(self):
        """Compute the noise spectrum ``S`` from the normalized counts and plot it.

        ``S = alpha*2*chi / (gamma**2 * 8*N*tau)`` with
        ``chi = -log(normalized_counts)``. The plot shows ``S*1e18`` against
        ``1/(2*tau)*1e-6``.
        """
        gamma = 2*np.pi*2.8*1e+10  # 2*pi * 28 GHz/T
        chi = -np.log(self.normalized_counts)
        denominator = gamma**2*8*self.N*self.tau
        self.S = self.alpha*2*chi/denominator

        plot_data = self.plot_data_spin_noise
        plot_data.set_data('value', self.S*1e18)
        # The x series lives under the key 'time' but holds frequencies.
        plot_data.set_data('time', (1/(2*self.tau))*1e-6)
        
    def _filter_function(self, tau):
        """Build the +/-1 modulation function of an XY8-N sequence.

        The function is sampled every 1 ns. It starts at +1 for half a pulse
        spacing, flips sign once per pulse spacing for the ``8*N - 1`` inner
        intervals and ends at +1 for the final half spacing.

        Args:
            tau: pulse spacing in seconds.

        Returns:
            ``(T, v)``: the time axis and the modulation function, both of
            length ``8*N*n`` where ``n = int(tau / 1e-9)``.
        """
        dt = 1e-9
        n = int(tau / dt)
        # Floor division keeps the slice bounds integral under both Python 2
        # and true division; plain ``n/2`` becomes a float in the latter.
        half = n // 2
        n_pulses = 8 * self.N
        total = n_pulses * n

        v = np.zeros(total)
        T = np.linspace(0, dt*n*8*self.N, num=total)

        v[:half] = 1
        for j in range(n_pulses - 1):
            start = half + j*n
            v[start:start + n] = (-1)**(j + 1)
        # A scalar assignment replaces np.ones((n/2,), dtype=np.int): np.int no
        # longer exists in current NumPy, and a scalar also copes with an
        # empty slice.
        v[total - half:total] = 1
        return T, v
        
    def _fourier_transform(self, N_tau):
        """Compute the filter-function power spectrum for ``tau[N_tau]``.

        The modulation function is Fourier transformed with ``int(n_FFT)``
        points (1 ns sampling). ``|FFT|**2 * 1e-9 / (8*N)`` is kept only for
        frequencies between ``0.97/(2*tau[-1])`` and ``1.03/(2*tau[0])`` and
        stored in ``FFy``; the matching frequencies (Hz) go to ``FFx``.

        Args:
            N_tau: index into ``self.tau`` selecting the pulse spacing.
        """
        _, modulation = self._filter_function(self.tau[N_tau])

        n_fft = int(self.n_FFT)
        spectrum = np.fft.fft(modulation, n_fft)
        power = (np.abs(spectrum)**2)*(1e-9)/(8*self.N)
        freqs = np.fft.fftfreq(n_fft, 1e-9)

        f_high = (1/(2*self.tau[0]))*1.03   # slightly above the highest probed frequency
        f_low = (1/(2*self.tau[-1]))*0.97   # slightly below the lowest probed frequency

        # One boolean mask selects the band. The previous code took indices
        # from an already filtered copy and applied them to the unfiltered
        # arrays, which only worked because fftfreq lists the non-negative
        # frequencies first.
        in_band = (freqs >= f_low) & (freqs <= f_high)
        self.FFy = power[in_band]
        self.FFx = freqs[in_band]
        
    def _filter_function_button_fired(self):
        """Plot the filter-function spectrum for the pulse spacing ``tau[N_tau]``.

        Also publishes that pulse spacing (converted to ns) in
        ``pulse_spacing``.
        """
        tau_index = self.N_tau
        self._fourier_transform(tau_index)

        plot_data = self.plot_data_filter_function
        plot_data.set_data('value', self.FFy)
        plot_data.set_data('freq', self.FFx*1e-6)  # Hz -> MHz
        self.pulse_spacing = self.tau[tau_index]*1e+9
        
    def _distance_to_NV_button_fired(self):
        """Estimate ``Brms`` and the NV depth ``z`` from the fitted spectrum.

        Steps:
          1. Evaluate the current peak model (``fit_parameters`` with the
             width taken from ``line_width``) on the band-limited FFT
             frequency axis.
          2. Integrate it above its first sample to obtain ``Brms``.
          3. Convert ``Brms`` to a depth ``z`` using the spin density and
             magnetic moment of the selected ``substance``.
          4. Overlay the baseline and the integrated model on the noise plot.
          5. Recompute the expected normalized counts from the model and the
             filter function of every ``tau`` and overlay them on the
             normalized-counts plot.

        Side effects: overwrites ``fit_parameters[2]`` in place and leaves
        ``FFx``/``FFy`` at the values for the last ``tau``.
        """
        rho_H = 5*1e+28  # m^(-3), density of protons
        rho_B11 = 2.1898552552552544e+28  # m^(-3), density of B11

        mu_p = 1.41060674333*1e-26  # proton magneton, J/T
        g_B11 = 85.847004*1e+6/(2*np.pi)  # Hz/T
        hbar = 1.054571800e-34  # J*s
        mu_B11 = hbar*g_B11*2*np.pi  # for central transition

        # The substance fixes density, moment and the geometric prefactor.
        if self.substance == 'immersion oil, H1 signal':
            rho = rho_H
            mu = mu_p
            prefactor = 0.05
        elif self.substance == 'hBN, B11 signal':
            rho = rho_B11
            mu = mu_B11
            prefactor = np.sqrt(0.654786)/(4*np.pi)

        g = 2*np.pi*2.8*1e+10  # rad/s/T
        mu0 = 4*np.pi*1e-7  # vacuum permeability, H/m or T*m/A

        freq = 1/(2*self.tau)  # in Hz

        self._fourier_transform(0)

        fit_x = self.FFx[np.where(self.FFx <= freq[0]*1.03)]
        fit_x = fit_x[np.where(fit_x >= freq[-1]*0.97)]
        d = fit_x[1] - fit_x[0]  # frequency step used for the integration

        if self.fitting_func == 'loretzian':
            fit_func = fitting.NLorentzians
        elif self.fitting_func == 'gaussian':
            fit_func = fitting.NGaussian

        # In-place update on purpose: no trait notification is fired here.
        self.fit_parameters[2] = self.line_width*1e-3
        S = fit_func(*self.fit_parameters)(fit_x*1e-6)  # fit for integration

        base = S[0]
        integral = trapz((S - base), dx=d)  # renamed: ``Int`` shadowed the traits type
        self.Brms = np.sqrt(integral)

        # Brms is in nT, hence the 1e9; the final 1e+9 converts m to nm.
        self.z = np.power(rho*((prefactor*mu0*mu/self.Brms*1e9)**2), 1/3.)*1e+9

        def ensure_renderer(plot, name, keys):
            # Add the overlay only once; later calls just refresh the data via
            # set_data, instead of stacking a new renderer on every call.
            if name not in plot.plots:
                plot.plot(keys, color='red', line_width=1, name=name)

        x_mhz = fit_x*1e-6
        spin_data = self.plot_data_spin_noise

        # baseline
        spin_data.set_data('x1', x_mhz)
        spin_data.set_data('y1', [base.item()]*len(S))
        ensure_renderer(self.spin_noise_plot, 'baseline', ('x1', 'y1'))

        # integrated fit
        spin_data.set_data('x2', x_mhz)
        spin_data.set_data('y2', S)
        ensure_renderer(self.spin_noise_plot, 'integrated_fit', ('x2', 'y2'))

        # convolution of the model with the filter function of every tau
        n_freq = len(S)
        Sum = np.ones(len(self.tau))
        for i in range(len(self.tau)):
            self._fourier_transform(i)
            Sum[i] = np.dot(S, self.FFy[:n_freq]**2)*1e-18

        dt = 1e-9
        hi = ((g**2)/2.)*Sum*dt/(8*self.N)
        calculated_counts = np.exp(-hi)

        norm_data = self.plot_data_normxy8_line
        norm_data.set_data('x3', self.tau*1e+9)
        norm_data.set_data('y3', calculated_counts)
        ensure_renderer(self.normxy8_plot, 'calculated_counts', ('x3', 'y3'))
        
    def _distance_from_envelope_button_fired(self):
        """Estimate ``Brms2`` and the depth ``z2`` by fitting the decay envelope.

        The normalized counts (divided by their first value) are fitted with
        ``exp(-2*(gamma_e*Brms/pi)**2 * K(tau, T2star))``. The fitted
        ``Brms`` is stored in ``Brms2`` (nT) and converted to a depth ``z2``
        (nm) using the spin density of the selected ``substance``.

        NOTE(review): the Larmor frequency always uses the H1 gyromagnetic
        ratio, also for the B11 substance -- confirm this is intended.
        """
        rho_H = 5*1e+28  # m^(-3), density of protons
        rho_B11 = 2.69105105105105e+28  # m^(-3), density of B11

        # Only the density depends on the substance in this method.
        if self.substance == 'immersion oil, H1 signal':
            rho = rho_H
        elif self.substance == 'hBN, B11 signal':
            rho = rho_B11

        n_pulses = 8*self.N
        g_H1 = 42.576  # gyromagnetic ratio, MHz/T
        larmor_frequency = g_H1*self.Magnetic_field/10000  # Larmor frequency in MHz

        def K(tau, T2star):
            # ``tau`` is always self.tau here (passed by curve_fit and by the
            # final evaluation), which is what the previous version hard-wired.
            detuning = 2*np.pi*larmor_frequency*1e+6 - np.pi/tau
            return (2*T2star**2/(1+T2star**2*(detuning)**2)**2)*\
                   (np.exp(-n_pulses*tau/T2star)*((1-T2star**2*(detuning)**2)*\
                   np.cos(n_pulses*tau*(detuning))-\
                   2*T2star*(detuning)*np.sin(n_pulses*tau*(detuning)))+\
                   (n_pulses*tau/T2star)*(1+T2star**2*(detuning)**2) + T2star**2*(detuning)**2-1)

        def func(tau, Brms, T2star):
            gamma_e = 1.76e11  # rad/s/T
            return np.exp(-2*(gamma_e*Brms/np.pi)**2*K(tau, T2star))

        S = self.normalized_counts/self.normalized_counts[0]

        self.plot_data_second_method.set_data('time', self.tau*1e9)
        self.plot_data_second_method.set_data('value', S)

        T2star = 60*1e-6  # initial guess, seconds
        popt, pcov = curve_fit(func, self.tau, S, p0=[self.Brms*1e-9, T2star])

        self.Brms2 = popt[0]*1e+9  # T -> nT

        self.plot_data_second_method.set_data('fit', func(self.tau, *popt))
        # Add the fit curve only once; afterwards set_data refreshes it.
        if 'envelope_fit' not in self.second_method_plot.plots:
            self.second_method_plot.plot(('time', 'fit'), color='blue',
                                         line_width=1, name='envelope_fit')

        gamma_n = 2.68*1e8  # rad/s/T
        mu0 = 4*np.pi*1e-7  # vacuum permeability, H/m or T*m/A
        hbar = 1.054571800*1e-34  # J*s

        self.z2 = np.power((5*np.pi*rho/96)*(mu0*hbar*gamma_n/(4*np.pi*popt[0]))**2, 1./3)*1e+9
                                   
     #fitting---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
              
    def _perform_fit_changed(self, new):
        """Show or hide the fit curve and its label when ``perform_fit`` toggles.

        Args:
            new: the new value of ``perform_fit``.
        """
        noise_plot = self.spin_noise_plot
        # Looked up but never used; kept so the call sequence is unchanged.
        x_name = self.plot_data_spin_noise.list_data()[2]

        if not new:
            noise_plot.delplot('fit')
            self.line_label.visible = False
        else:
            noise_plot.plot(('time', 'fit'), style='line', color='blue',
                            name='fit', line_width=1)
            self.line_label.visible = True
            self.line_width = self.fit_parameters[2]
        noise_plot.request_redraw()
            
            
    def _update_fit(self):
        """Fit the noise spectrum and publish the fit results.

        With ``perform_fit`` set, ``S*1e18`` is fitted against
        ``1/(2*tau)*1e-6`` with the selected peak model; otherwise the
        parameters are reset to NaN. The flat parameter vector is
        ``[offset, center_1, width_1, amplitude_1, ...]``; from it
        ``fit_centers``, ``fit_line_width``, ``fit_contrast`` (percent) and
        ``line_width`` are derived.
        """
        if self.perform_fit:
            if self.fitting_func == 'loretzian':
                fit_func = fitting.fit_multiple_lorentzians
            elif self.fitting_func == 'gaussian':
                fit_func = fitting.fit_multiple_gaussian

            n_peaks = self.number_of_resonances  # number of peaks
            fit_x = (1/(2*self.tau))*1e-6
            # Use a local here: assigning to self.counts used to overwrite the
            # imported raw counts that _norm_button_fired and saving rely on.
            spectrum = self.S*1e18
            p = fit_func(fit_x, spectrum, n_peaks,
                         threshold=self.fit_threshold*0.01)
        else:
            p = np.nan*np.empty(4)

        self.fit_parameters = p
        self.fit_centers = p[1::3]
        self.fit_line_width = p[2::3]

        # Floor division: len(p)/3 is a float under true division and is
        # rejected by np.empty and reshape.
        n_triples = len(p)//3
        contrast = np.empty(n_triples)
        offset = p[0]
        for i, triple in enumerate(p[1:].reshape((n_triples, 3))):
            amplitude = triple[2]
            width = triple[1]
            peak = np.abs(amplitude/(np.pi*width))
            if amplitude > 0:
                contrast[i] = 100*peak/(peak + offset)
            else:
                contrast[i] = 100*peak/offset
        self.fit_contrast = contrast
        self.line_width = self.fit_parameters[2]*1e+3
        #self.frequencies=self.centers[1]*1e+6
            
    def _update_line_data_fit(self):
        """Refresh the fit curve and the text label from ``fit_parameters``.

        Does nothing while the first fit parameter is NaN. Otherwise the
        selected peak model is evaluated on ``1/(2*tau)*1e-6``, written to the
        ``'fit'`` series of the noise plot, and one line per peak (center,
        line width, contrast) is written to ``line_label``.
        """
        if np.isnan(self.fit_parameters[0]):
            return

        if self.fitting_func == 'loretzian':
            fit_func = fitting.NLorentzians
        elif self.fitting_func == 'gaussian':
            fit_func = fitting.NGaussian

        self.plot_data_spin_noise.set_data(
            'fit', fit_func(*self.fit_parameters)((1/(2*self.tau))*1e-6))

        p = self.fit_parameters
        centers = p[1::3]
        widths = p[2::3]
        # Floor division: len(p)/3 is a float under true division and is
        # rejected by np.empty and reshape.
        n_triples = len(p)//3
        contrast = np.empty(n_triples)
        offset = p[0]
        for i, triple in enumerate(p[1:].reshape((n_triples, 3))):
            amplitude = triple[2]
            width = triple[1]
            peak = np.abs(amplitude/(np.pi*width))
            if amplitude > 0:
                contrast[i] = 100*peak/(peak + offset)
            else:
                contrast[i] = 100*peak/offset

        self.line_label.text = ''.join(
            '%.2f MHz, LW %.2f kHz, contrast %.1f%%\n' % (center, widths[i]*1e+3, contrast[i])
            for i, center in enumerate(centers))
        self.line_width = self.fit_parameters[2]*1e+3
            
    def _update_plot(self):
        """Re-run the distance calculation.

        Registered in ``__init__`` as the listener for changes of ``alpha``,
        ``line_width`` and ``n_FFT``.
        """
        self._distance_to_NV_button_fired()
        # self._ditance_button_fired()
        # self._ditance_button_fired()
        # self._ditance_button_fired()
            
    def save_all(self, filename):
        """Save all plots as PNG files and the data as a ``.pyd`` file.

        Every output name starts with ``filename + 'XY8-' + str(N)``. The two
        result plots additionally carry the first four characters of the
        distance and Brms values, with ``'.'`` replaced by ``'d'``.

        Args:
            filename: path prefix for all generated files.
        """
        prefix = filename + 'XY8-' + str(self.N)

        def tag(value):
            # str.replace instead of the deprecated, Python 2-only
            # string.replace() module function; same result.
            return str(value)[0:4].replace('.', 'd')

        self.save_figure(self.xy8_plot, prefix + '_counts' + '.png')
        self.save_figure(self.normxy8_plot, prefix + '_normalized_counts' + '.png')
        self.save_figure(
            self.spin_noise_plot,
            prefix + '_noise_spectrum_z=' + tag(self.z) + 'nm_B=' + tag(self.Brms) + 'nT.png')
        self.save(prefix + '_distance' + '.pyd')
        self.save_figure(
            self.second_method_plot,
            prefix + '_noise_spectrum_z=' + tag(self.z2) + 'nm_B=' + tag(self.Brms2) + 'nT.png')
        
    # Traits UI layout. The first top-level group holds two columns plus a
    # parameter row; the second holds the envelope ("second method") section;
    # a 'File' menu offers the 'Save_All' and 'load' actions, handled by
    # DistanceHandler.
    traits_view =  View( VGroup( HGroup( VGroup( HGroup( Item('myfile', show_label=False),
                                                         Item('import_data_button', show_label=False),
                                                         Item('substance', style='custom', show_label=False)
                                                       ),
                                                 # left column: raw counts and filter function
                                                 HGroup( Item('xy8_plot', show_label=False, resizable=True),
                                                       ),
                                                 HGroup( Item('filter_function_button', show_label=False)
                                                       ),
                                                 HGroup(                                          
                                                         Item('N_tau', width= -40),
                                                         Item('pulse_spacing', width= -40, format_str='%.1f')
                                                       ),
                                                 HGroup( Item('filter_function_plot', show_label=False, resizable=True),
                                                       ),
                                               ),
                                    
                                         # right column: normalization, noise spectrum and fit controls
                                         VGroup( HGroup( Item('rabi_contrast', width= -40),
                                                         Item('N', width= -40),
                                                         Item('norm_button', show_label=False)
                                                       ),
                                                                    
                                                 HGroup( Item('normxy8_plot', show_label=False, resizable=True)
                                                       ),
                                                 HGroup( Item('calculate_noise_spectrum_button', show_label=False),
                                                       ),
                                                 HGroup( Item('fit_threshold', width= -40),
                                                         Item('perform_fit'),
                                                         Item('number_of_resonances', width= -60),
                                                 
                                                       ),
                                                 HGroup( Item('spin_noise_plot', show_label=False, resizable=True)
                                                       ),
                                               ), 
                                       ),
                                 # parameter row and results of the spectrum-integration method
                                 HGroup( 
                                         Item('alpha', width= -60, format_str='%.2f'),
                                         Item('n_FFT', width= -70),
                                         Item('line_width', width= -60, format_str='%.2f'),
                                         Item('fitting_func', style='custom', show_label=False),                                         
                                         Item('distance_to_NV_button', show_label=False),
                                         Item('z', width= -60, style='readonly', format_str='%.2f'),
                                         Item('Brms', width= -60, style='readonly', format_str='%.0f')
                                       )
                               ),
                               
                          # envelope method: field input, results and plot
                          Group( HGroup( Item('Magnetic_field', width= -60),
                                         Item('distance_from_envelope_button', show_label=False)),
                                 HGroup( Item('z2', width= -60, style='readonly', format_str='%.2f'),       
                                         Item('Brms2', width= -60, style='readonly', format_str='%.0f')), 
                                         
                                 HGroup( Item('second_method_plot', show_label=False, resizable=True)
                                                       )),
                               
                         menubar=MenuBar( Menu( Action(action='Save_All', name='Save all'),
                                                Action(action='load', name='Load'),
                                                name='File'
                                               )
                                        ),
                             
                     title='NV depth Toby', width=1200, height=800, buttons=[], resizable=True, handler=DistanceHandler)
                     
    # Attribute names handed to GetSetItemsMixin.
    # NOTE(review): presumably the set of attributes written by save() and
    # restored on load -- confirm against the mixin. 'z2' and 'Brms2' are not
    # listed; '__doc__' is, so the class docstring (if any) is included too.
    get_set_items = ['N_tau', 'pulse_spacing', 'rabi_contrast', 'N', 'alpha','line_width', 'S', 'time', 'counts', 'counts2',
                     'z', 'Brms', 'normalized_counts','tau', 'fit_parameters','fit_centers','fit_contrast','fit_line_width',
                     'fitting_func', 'n_FFT',
                     '__doc__']
Exemple #22
0
class ImageSlice(HasPrivateTraits):
    """ Analyzes an image, splits it into fixed and 'stretchable' slices and
        uses those slices to stretch fill a region of a device context.
    """

    # -- Trait Definitions ----------------------------------------------------

    #: The ImageResource to be sliced and drawn:
    image = Instance(ImageResource)

    #: The minimum number of adjacent, identical rows/columns needed to identify
    #: a repeatable section:
    threshold = Int(10)

    #: The maximum number of 'stretchable' rows and columns:
    stretch_rows = Enum(1, 2)
    stretch_columns = Enum(1, 2)

    #: Width/height of the image borders:
    top = Int()
    bottom = Int()
    left = Int()
    right = Int()

    #: Width/height of the extended image borders:
    xtop = Int()
    xbottom = Int()
    xleft = Int()
    xright = Int()

    #: The color to use for content text:
    content_color = Instance(wx.Colour)

    #: The color to use for label text:
    label_color = Instance(wx.Colour)

    #: The background color of the image:
    bg_color = Color

    #: Should debugging slice lines be drawn?
    debug = Bool(False)

    # -- Private Traits -------------------------------------------------------

    #: The current image's opaque bitmap:
    opaque_bitmap = Instance(wx.Bitmap)

    #: The current image's transparent bitmap:
    transparent_bitmap = Instance(wx.Bitmap)

    #: Size of the current image:
    dx = Int()
    dy = Int()

    #: Size of the current image's slices:
    dxs = List()
    dys = List()

    #: Fixed minimum size of current image:
    fdx = Int()
    fdy = Int()

    # -- Public Methods -------------------------------------------------------

    def fill(self, dc, x, y, dx, dy, transparent=False):
        """ 'Stretch fill' the specified region of a device context with the
            sliced image.

            dc is the destination device context, (x, y) the top-left corner
            of the destination region and (dx, dy) its size. If transparent
            is True the transparent bitmap is used as the source, otherwise
            the opaque one.
        """
        # Create the source image dc:
        idc = wx.MemoryDC()
        if transparent:
            idc.SelectObject(self.transparent_bitmap)
        else:
            idc.SelectObject(self.opaque_bitmap)

        # Build the (source size, destination size) pairs for the columns and
        # rows. Note that the rows must be based on the image *height*
        # (self.dy); the original code mistakenly used the width for both:
        pdxs = self._slice_pairs(self.dxs, self.dx, dx - self.fdx)
        pdys = self._slice_pairs(self.dys, self.dy, dy - self.fdy)

        # Iterate over each cell, performing a stretch fill from the source
        # image to the destination window:
        last_x, last_y = x + dx, y + dy
        y0, iy0 = y, 0
        for idy, wdy in pdys:
            if y0 >= last_y:
                break

            if wdy != 0:
                x0, ix0 = x, 0
                for idx, wdx in pdxs:
                    if x0 >= last_x:
                        break

                    if wdx != 0:
                        self._fill(idc, ix0, iy0, idx, idy, dc, x0, y0, wdx,
                                   wdy)
                        x0 += wdx
                    # The source offset always advances, even for cells that
                    # are not drawn:
                    ix0 += idx
                y0 += wdy
            iy0 += idy

        if self.debug:
            # Outline the computed borders in red:
            dc.SetPen(wx.Pen(wx.RED))
            dc.DrawLine(x, y + self.top, last_x, y + self.top)
            dc.DrawLine(x, last_y - self.bottom - 1, last_x,
                        last_y - self.bottom - 1)
            dc.DrawLine(x + self.left, y, x + self.left, last_y)
            dc.DrawLine(last_x - self.right - 1, y, last_x - self.right - 1,
                        last_y)

    # -- Event Handlers -------------------------------------------------------

    def _image_changed(self, image):
        """ Handles the 'image' trait being changed by rebuilding the
            transparent and opaque bitmaps and re-analyzing the image.
        """
        # Save the original bitmap as the transparent version:
        self.transparent_bitmap = (
            bitmap) = image.create_image().ConvertToBitmap()

        # Save the bitmap size information:
        self.dx = dx = bitmap.GetWidth()
        self.dy = dy = bitmap.GetHeight()

        # Create the opaque version of the bitmap by drawing the original on
        # top of a rectangle filled with the window color:
        self.opaque_bitmap = wx.Bitmap(dx, dy)
        mdc2 = wx.MemoryDC()
        mdc2.SelectObject(self.opaque_bitmap)
        mdc2.SetBrush(wx.Brush(WindowColor))
        mdc2.SetPen(wx.TRANSPARENT_PEN)
        mdc2.DrawRectangle(0, 0, dx, dy)
        mdc = wx.MemoryDC()
        mdc.SelectObject(bitmap)
        mdc2.Blit(0, 0, dx, dy, mdc, 0, 0, useMask=True)
        mdc.SelectObject(wx.NullBitmap)
        mdc2.SelectObject(wx.NullBitmap)

        # Finally, analyze the image to find out its characteristics:
        self._analyze_bitmap()

    # -- Private Methods ------------------------------------------------------

    def _slice_pairs(self, sizes, source_size, stretch):
        """ Returns the five (source size, destination size) pairs used by
            'fill' along one axis.

            sizes is the list of slice sizes for the axis (self.dxs or
            self.dys), source_size is the image size along that axis and
            stretch is the number of destination pixels left over once the
            fixed part of the image has been accounted for.
        """
        n = len(sizes)
        half = stretch // 2
        if n == 1:
            return [
                (0, 0),
                (1, max(1, half)),
                (source_size - 2, source_size - 2),
                (1, max(1, stretch - half)),
                (0, 0),
            ]

        if n == 3:
            # A single stretchable slice absorbs all of the extra space:
            return [
                (sizes[0], sizes[0]),
                (sizes[1], max(0, stretch)),
                (0, 0),
                (0, 0),
                (sizes[2], sizes[2]),
            ]

        # Two stretchable slices share the extra space:
        return [
            (sizes[0], sizes[0]),
            (sizes[1], max(0, half)),
            (sizes[2], sizes[2]),
            (sizes[3], max(0, stretch - half)),
            (sizes[4], sizes[4]),
        ]

    def _analyze_bitmap(self):
        """ Analyzes the opaque bitmap to find its stretchable rows and
            columns, border sizes, background color and text colors.
        """
        # Get the image data:
        threshold = self.threshold
        bitmap = self.opaque_bitmap
        dx, dy = self.dx, self.dy
        image = bitmap.ConvertToImage()

        # Convert the bitmap data to a numpy array for analysis:
        # NOTE(review): the array is uint8, so the differences computed below
        # wrap around modulo 256 rather than going negative -- confirm this
        # is the intended similarity measure.
        data = reshape(image.GetData(), (dy, dx, 3)).astype(uint8)

        # Find the horizontal slices (runs of nearly identical rows):
        matches = []
        y, last = 0, dy - 1
        max_diff = 0.10 * dx
        while y < last:
            y_data = data[y]
            for y2 in range(y + 1, dy):
                if abs(y_data - data[y2]).sum() > max_diff:
                    break

            n = y2 - y
            if n >= threshold:
                matches.append((y, n))

            y = y2

        n = len(matches)
        if n == 0:
            if dy > 50:
                matches = [(0, dy)]
            else:
                matches = [(dy // 2, 1)]
        elif n > self.stretch_rows:
            # Keep the largest runs, then restore positional order, since
            # '_calculate_dxy' expects the first match to precede the second:
            matches.sort(key=lambda x: x[1], reverse=True)
            matches = sorted(matches[:self.stretch_rows])

        # Calculate and save the horizontal slice sizes:
        self.fdy, self.dys = self._calculate_dxy(dy, matches)

        # Find the vertical slices (runs of nearly identical columns):
        matches = []
        x, last = 0, dx - 1
        max_diff = 0.10 * dy
        while x < last:
            x_data = data[:, x]
            for x2 in range(x + 1, dx):
                if abs(x_data - data[:, x2]).sum() > max_diff:
                    break

            n = x2 - x
            if n >= threshold:
                matches.append((x, n))

            x = x2

        n = len(matches)
        if n == 0:
            if dx > 50:
                matches = [(0, dx)]
            else:
                matches = [(dx // 2, 1)]
        elif n > self.stretch_columns:
            # Same selection rule as for the rows above:
            matches.sort(key=lambda x: x[1], reverse=True)
            matches = sorted(matches[:self.stretch_columns])

        # Calculate and save the vertical slice sizes:
        self.fdx, self.dxs = self._calculate_dxy(dx, matches)

        # Save the border size information:
        self.top = min(dy // 2, self.dys[0])
        self.bottom = min(dy // 2, self.dys[-1])
        self.left = min(dx // 2, self.dxs[0])
        self.right = min(dx // 2, self.dxs[-1])

        # Find the optimal size for the borders (i.e. xleft, xright, ... ):
        self._find_best_borders(data)

        # Save the background color (taken from the center pixel). The
        # components are converted to Python ints first so that packing them
        # into a 24-bit value cannot overflow the uint8 element type:
        x, y = (dx // 2), (dy // 2)
        r, g, b = (int(c) for c in data[y, x])
        self.bg_color = (0x10000 * r) + (0x100 * g) + b

        # Find the best contrasting text color (black or white):
        self.content_color = self._find_best_color(data, x, y)

        # Find the best contrasting label color:
        if self.xtop >= self.xbottom:
            self.label_color = self._find_best_color(data, x, self.xtop // 2)
        else:
            self.label_color = self._find_best_color(
                data, x, dy - (self.xbottom // 2) - 1)

    def _fill(self, idc, ix, iy, idx, idy, dc, x, y, dx, dy):
        """ Performs a stretch fill of a region of an image into a region of a
            window device context.

            The (idx, idy) sized source region at (ix, iy) of 'idc' is tiled
            over the (dx, dy) sized destination region at (x, y) of 'dc'.
        """
        # A non-positive source size could never advance the tiling loops
        # below, so there is nothing that can be drawn:
        if (idx <= 0) or (idy <= 0):
            return

        last_x, last_y = x + dx, y + dy
        while y < last_y:
            ddy = min(idy, last_y - y)
            x0 = x
            while x0 < last_x:
                ddx = min(idx, last_x - x0)
                dc.Blit(x0, y, ddx, ddy, idc, ix, iy, useMask=True)
                x0 += ddx
            y += ddy

    def _calculate_dxy(self, d, matches):
        """ Calculate the size of all image slices for a specified set of
            matches.

            d is the image size along the axis and matches is a list of one
            or two (start, length) tuples in positional order. Returns a
            tuple of the form (fixed size, list of slice sizes).
        """
        if len(matches) == 1:
            d1, d2 = matches[0]

            return (d - d2, [d1, d2, d - d1 - d2])

        d1, d2 = matches[0]
        d3, d4 = matches[1]

        return (d - d2 - d4, [d1, d2, d3 - d1 - d2, d4, d - d3 - d4])

    def _find_best_borders(self, data):
        """ Find the best set of image slice border sizes (e.g. for images with
            rounded corners, there should exist a better set of borders than
            the ones computed by the image slice algorithm).
        """
        # Make sure the image size is worth bothering about:
        dx, dy = self.dx, self.dy
        if (dx < 5) or (dy < 5):
            return

        # Calculate the starting point:
        left = right = dx // 2
        top = bottom = dy // 2

        # Calculate the end points:
        last_y = dy - 1
        last_x = dx - 1

        # Mark which edges are 'scanning':
        t = b = l = r = True

        # Keep looping while at least one edge is still 'scanning':
        while l or r or t or b:

            # Calculate the current core area size:
            height = bottom - top + 1
            width = right - left + 1

            # Try to extend all edges that are still 'scanning':
            nl = (l and (left > 0) and self._is_equal(data, left - 1, top,
                                                      left, top, 1, height))

            nr = (r and (right < last_x) and self._is_equal(
                data, right + 1, top, right, top, 1, height))

            nt = (t and (top > 0)
                  and self._is_equal(data, left, top - 1, left, top, width, 1))

            nb = (b and (bottom < last_y) and self._is_equal(
                data, left, bottom + 1, left, bottom, width, 1))

            # Now check the corners of the edges:
            tl = ((not nl) or (not nt)
                  or self._is_equal(data, left - 1, top - 1, left, top, 1, 1))

            tr = ((not nr) or (not nt) or self._is_equal(
                data, right + 1, top - 1, right, top, 1, 1))

            bl = ((not nl) or (not nb) or self._is_equal(
                data, left - 1, bottom + 1, left, bottom, 1, 1))

            br = ((not nr) or (not nb) or self._is_equal(
                data, right + 1, bottom + 1, right, bottom, 1, 1))

            # Calculate the new edge 'scanning' values:
            l = nl and tl and bl
            r = nr and tr and br
            t = nt and tl and tr
            b = nb and bl and br

            # Adjust the coordinate of an edge if it is still 'scanning' (the
            # boolean flags are used as 0/1 increments):
            left -= l
            right += r
            top -= t
            bottom += b

        # Now compute the best set of image border sizes using the current set
        # and the ones we just calculated:
        self.xleft = min(self.left, left)
        self.xright = min(self.right, dx - right - 1)
        self.xtop = min(self.top, top)
        self.xbottom = min(self.bottom, dy - bottom - 1)

    def _find_best_color(self, data, x, y):
        """ Find the best contrasting text color (black or white) for a
            specified pixel coordinate.
        """
        r, g, b = data[y, x]
        h, l, s = rgb_to_hls(r / 255.0, g / 255.0, b / 255.0)
        text_color = wx.Colour(wx.BLACK)
        # Dark pixels (lightness below 50%) get white text:
        if l < 0.50:
            text_color = wx.Colour(wx.WHITE)

        return text_color

    def _is_equal(self, data, x0, y0, x1, y1, dx, dy):
        """ Determines if two identically sized regions of an image array are
            'the same' (i.e. within some slight color variance of each other).
        """
        return (abs(data[y0:y0 + dy, x0:x0 + dx] -
                    data[y1:y1 + dy, x1:x1 + dx]).sum() < 0.10 * dx * dy)
Exemple #23
0
class Person(HasTraits):
    """ Minimal example model with a name and an age, edited in a modal
        view.
    """

    # Default values are the ones used by the example:
    name = Str('David Morrill')
    age = Int(39)

    # The view lists its content by name; '<extra>' is passed through to
    # View unchanged.
    view = View(
        'name',
        '<extra>',
        'age',
        kind='modal',
    )
class CodeEditingTest(HasTraits):
    """ Simulates a user typing 'code' into a block code editor control, one
        character per timer tick, optionally with random backspaces.
    """

    block_code_editor = Instance(wx.Control, allow_none=False)
    code = Code("""from blockcanvas.debug.my_operator import add, mul
from numpy import arange
x = arange(0,10,.1)
c1 = mul(a,a)
x1 = mul(x,x)
t1 = mul(c1,x1)
t2 = mul(b, x)
t3 = add(t1,t2)
y = add(t3,c)
""")

    # Index of the next character of 'code' to be typed:
    text_index = Int(0)
    # Seed and generator used for the random decisions made during a run:
    random_seed = Int(0)
    random_generator = Instance(Random)
    # Shuffle the lines of 'code' before typing them?
    permute_lines = Bool(True)
    # Randomly delete already typed characters?
    random_backspace = Bool(True)
    # Clear the editor before starting to type?
    clear_first = Bool(False)
    num_runs = Int(1)
    # Optional callable invoked (without arguments) once all text was typed:
    finish_callback = Any

    # 'num_runs' used to be listed twice; it is now shown only once.
    traits_view = View(
        Item('code'),
        Item('random_seed'),
        Item('num_runs'),
        Item('random_backspace'),
        Item('permute_lines'),
        Item('clear_first'),
        buttons=['OK'],
        resizable=True,
    )

    def interactive_test(self):
        """ Lets the user configure the test in a modal dialog, then runs it.
        """
        self.configure_test()
        self.run_test()
        return

    def run_test(self):
        """ Seeds the random generator, optionally permutes the code lines
            and clears the editor, then starts the timer driven typing.
        """
        self.random_generator = Random()
        self.random_generator.seed(self.random_seed)
        if self.permute_lines:
            codelines = self.code.split('\n')
            # Use the seeded generator so that a given 'random_seed' always
            # produces the same permutation:
            self.random_generator.shuffle(codelines)
            self.code = '\n'.join(codelines)
        # Should have a more realistic markov process here
        if self.clear_first:
            self.clear_editor()
        timerid = wx.NewId()
        self.timer = wx.Timer(self.block_code_editor, timerid)
        self.block_code_editor.Bind(wx.EVT_TIMER, self.test_step, id=timerid)
        self.text_index = 0
        self.test_step(None)

    def test_step(self, event):
        """ Performs a single typing step: either deletes the previous
            character or types the next one, then re-arms the timer. Calls
            'finish_callback' once all of 'code' has been typed.
        """
        if self.text_index >= len(self.code):
            if self.finish_callback is not None:
                self.finish_callback()
            return

        generator = self.random_generator
        draw = generator.random() if generator is not None else random()

        # Only backspace when the option is enabled and there is something
        # to delete; otherwise the index would go negative and the run could
        # never finish:
        if self.random_backspace and (self.text_index > 0) and (draw > 0.8):
            self.text_index -= 1
            self.block_code_editor.CmdKeyExecute(STC_CMD_DELETEBACK)
        else:
            self.block_code_editor.AddText(self.code[self.text_index])
            self.text_index += 1

        # The timer interval is an integer number of milliseconds:
        self.timer.Start(50, wx.TIMER_ONE_SHOT)

    def configure_test(self):
        """ Opens a modal dialog for editing the test parameters.
        """
        self.edit_traits(kind='modal')
        return

    def __init__(self, **kwtraits):
        """ Creates the editor control and picks a random default seed.
        """
        super(CodeEditingTest, self).__init__(**kwtraits)
        from blockcanvas.block_display.code_block_ui import editor_control
        self.block_code_editor = editor_control()
        self.random_seed = int(random() * 1000000)
        return

    def insert_text(self, text, at_once=False):
        """Insert text into the code editor.  This can be done character
        by character, or all at once if at_once is True"""
        if at_once:
            self.block_code_editor.AddText(text)
        else:
            for char in text:
                self.block_code_editor.AddText(char)
        return

    def clear_editor(self):
        """Removes all of the text all at once from the editor.
        this is equivalent for almost all purposes to the user
        selecting all and hitting backspace"""
        self.block_code_editor.ClearAll()

    def goto_line(self, linenum):
        """ Moves the editor to the specified line number.
        """
        self.block_code_editor.GotoLine(linenum)
        return

    def enter_basic_code(self):
        """ Inserts 'basic_code' into the editor.
        """
        # NOTE(review): 'basic_code' is not defined on this class; presumably
        # it is expected to be supplied by a subclass or the caller -- confirm.
        self.insert_text(self.basic_code)
class RemoteCommandServer(ConfigLoadable):
    """ Configurable wrapper around a messaging server. Loads the server
        settings from a configuration, runs the request loop in a thread and
        exposes connection settings and packet statistics in a view.
    """
    simulation = False
    _server = None
    repeater = Instance(CommandRepeater)
    processor = Instance(
        'pychron.remote_hardware.command_processor.CommandProcessor')

    host = Str(enter_set=True, auto_set=False)
    port = Int(enter_set=True, auto_set=False)
    klass = Str

    # Values of host/port as last loaded from (or saved to) the configuration:
    loaded_port = None
    loaded_host = None

    packets_received = Int
    packets_sent = Int
    repeater_fails = Int

    cur_rpacket = String
    cur_spacket = String

    server_button = Event
    server_label = Property(depends_on='_running')
    _running = Bool(False)
    _connected = Bool(False)

    save = Button
    _dirty = Bool(False)

    run_time = Str
    led = Instance(LED, ())

    use_ipc = True

    def _repeater_default(self):
        """ Creates the default command repeater. Returns None if IPC is
            disabled globally or the repeater fails to bootstrap.
        """
        if globalv.use_ipc:
            c = CommandRepeater(logger_name='{}_repeater'.format(self.name),
                                name=self.name,
                                config_path=os.path.join(
                                    paths.root, 'servers',
                                    '{}.cfg'.format(self.name)))
            if c.bootstrap():
                return c

    def _repeater_fails_changed(self, old, new):
        """ Turns the repeater's LED off as soon as a failure is counted.
        """
        # The repeater may be None (see '_repeater_default'):
        if new != 0 and self.repeater is not None:
            self.repeater.led.state = 0

    def load(self, *args, **kw):
        """ Reads the configuration and creates the server it describes.

            Returns True if the server was created and is connected,
            otherwise None.
        """
        config = self.get_configuration()
        if not config:
            return

        server_class = self.config_get(config, 'General', 'class')
        if server_class is None:
            return

        if server_class == 'IPCServer':
            path = self.config_get(config, 'General', 'path')
            if path is None:
                self.warning('Path not set. use path config value')
                return
            addr = path
            self.host = path
            # Remove a stale socket file left over from a previous run:
            if os.path.exists(path):
                os.remove(path)
        else:
            if LOCAL:
                host = 'localhost'
            else:
                host = self.config_get(config,
                                       'General',
                                       'host',
                                       optional=True)
            port = self.config_get(config, 'General', 'port', cast='int')

            if host is None:
                host = socket.gethostbyname(socket.gethostname())
            if port is None:
                self.warning('Host or Port not set {}:{}'.format(
                    host, port))
                return
            elif port < 1024:
                self.warning('Port Numbers < 1024 not allowed')
                return
            addr = (host, port)

            self.host = host
            self.port = port if port else 0

            self.loaded_host = host
            self.loaded_port = port
        self.klass = server_class[:3]

        ds = None
        if config.has_option('Requests', 'datasize'):
            ds = config.getint('Requests', 'datasize')

        ptype = self.config_get(config, 'Requests', 'type', optional=False)
        if ptype is None:
            return

        self.datasize = ds
        self.processor_type = ptype

        self._server = self.server_factory(server_class, addr, ptype, ds)

        if self._server is not None:
            # add links
            for link in self.config_get_options(config, 'Links'):
                # note links cannot be stopped
                self._server.add_link(link,
                                      self.config_get(config, 'Links', link))

        if self._server is not None and self._server.connected:
            addr = self._server.server_address
            saddr = '({})'.format(addr)
            msg = '%s - %s' % (server_class, saddr)
            self.info(msg)
            self._connected = True
            return True
        else:
            self._connected = False
            self.warning('Cannot connect to {}'.format(addr))

    def server_factory(self, klass, addr, ptype, ds):
        """ Imports the server class named 'klass' from the module
            'pychron.messaging.<first three letters, lowercased>_server' and
            instantiates it. 'ds' (the data size) defaults to 2**10 if None.
        """
        module = __import__('pychron.messaging.{}_server'.format(
            klass[:3].lower()),
                            fromlist=[klass])
        factory = getattr(module, klass)

        if ds is None:
            ds = 2**10
        return factory(self, ptype, ds, addr)

    def open(self):
        """ Marks the server as running and starts the request loop in a new
            thread. Always returns True.
        """
        self._running = True
        t = threading.Thread(target=self.start_server)
        t.start()

        return True

    def start_server(self):
        """ Request loop: polls the server socket and handles a request
            whenever it becomes readable, until '_running' is cleared.
        """
        select_timeout = 1
        last_error = None
        while self._running:
            try:
                ready = select.select([self._server.socket], [], [],
                                      select_timeout)
                if ready[0]:
                    self._server.handle_request()
                last_error = None
            except Exception as e:
                # Serving is best effort, so keep looping, but do not hide
                # the failure. Identical consecutive errors are only reported
                # once to avoid flooding the log:
                msg = str(e)
                if self._running and msg != last_error:
                    self.warning('Error while serving requests: {}'.format(
                        msg))
                last_error = msg

    def shutdown(self):
        """ Stops the request loop and closes the server socket.
        """
        self._connected = False
        if self._server is not None:
            # Stop the request loop before closing the socket so that it
            # does not keep polling a closed socket:
            self._running = False
            self._server.socket.close()

    def traits_view(self):
        """ Returns the view showing connection settings, statistics and the
            save button.
        """
        cparams = VGroup(
            HGroup(
                Item('led', show_label=False, editor=LEDEditor()),
                Item('server_button',
                     show_label=False,
                     editor=ButtonEditor(label_value='server_label'),
                     enabled_when='_connected'),
            ),
            Item('host', visible_when='not _running'),
            Item('port', visible_when='not _running'),
            show_border=True,
            label='Connection',
        )
        stats = Group(Item('packets_received', style='readonly'),
                      Item('cur_rpacket', label='Received', style='readonly'),
                      Item('packets_sent', style='readonly'),
                      Item('cur_spacket', label='Sent', style='readonly'),
                      Item('repeater_fails', style='readonly'),
                      Item('run_time', style='readonly'),
                      show_border=True,
                      label='Statistics',
                      visible_when='_connected')

        buttons = HGroup(Item('save', show_label=False, enabled_when='_dirty'))
        v = View(
            VGroup(cparams, stats, buttons),
            handler=RCSHandler,
        )
        return v

    def _run_time_update(self):
        """ Updates 'run_time' with the time elapsed since 'start_time',
            formatted as '[days ]HH:MM:SS'.
        """
        t, h, m, s = diff_timestamp(datetime.datetime.now(), self.start_time)

        rt = '{:02d}:{:02d}:{:02d}'.format(h, m, s)
        if t.days:
            rt = '{} {:02d}:{:02d}:{:02d}'.format(t.days, h, m, s)
        self.run_time = rt

    def __running_changed(self):
        """ Starts the run time timer when the server starts running and
            stops it when the server stops.
        """
        if self._running:
            self.start_time = datetime.datetime.now()
            self.timer = Timer(1000, self._run_time_update)
        else:
            self.timer.Stop()

    def _anytrait_changed(self, name, value):
        """ Flags the settings as dirty when host or port differ from the
            values that were loaded.
        """
        if name in ['host', 'port']:
            attr = 'loaded_{}'.format(name)
            a = getattr(self, attr)
            if value != a and a is not None:
                self._dirty = True

    def _save_fired(self):
        """ Shuts the server down, writes the current host and port to the
            configuration and reloads it.
        """
        self.shutdown()
        config = self.get_configuration()
        for attr in ('host', 'port'):
            value = getattr(self, attr)
            setattr(self, 'loaded_{}'.format(attr), value)
            # Config parsers store option values as text:
            config.set('General', attr, str(value))

        # Write and reload only once, after *both* values have been set.
        # Reloading in between would reset the attribute that has not been
        # stored yet to its old value.
        self.write_configuration(config)
        self.load()
        self._dirty = False

    def _server_button_fired(self):
        """ Toggles the server: shuts it down if it is running, otherwise
            resets the statistics and bootstraps it.
        """
        if self._running:
            self.shutdown()
        else:
            # reset the stats
            self.packets_received = 0
            self.packets_sent = 0
            self.cur_rpacket = ''
            self.cur_spacket = ''
            self.repeater_fails = 0

            self.bootstrap()

    def _get_server_label(self):
        """ Returns the label of the server button ('Start' or 'Stop').
        """
        return 'Start' if not self._running else 'Stop'
class FusionsLaserMonitor(LaserMonitor):
    """ Laser monitor that checks the laser interlocks, the coolant
        temperature and the coolant status, and asks the manager for an
        emergency shutoff when a check keeps failing.
    """

    max_coolant_temp = Float(25)
    max_coolant_temp_tries = Int(3)

    max_setpoint_tries = Int(6)

    _setpoint = None
    _cur_setpoints = None
    _setpoint_check_cnt = 0
    _coolant_check_cnt = 0
    _coolant_check_status_cnt = 0
    _setpoint_tolerance = 5

    # Number of consecutive checks for which the chiller gave no answer, and
    # the count at which that triggers a shutoff:
    _unavailable_cnt = 0
    max_unavailable = 3

    def load_additional_args(self, config):
        """ Loads the optional 'max_coolant_temp' setting from the 'General'
            section of the configuration.
        """
        # Pass the configuration (not 'self') on to the base class:
        super(FusionsLaserMonitor, self).load_additional_args(config)
        self.set_attribute(config,
                           'max_coolant_temp',
                           'General',
                           'max_coolant_temp',
                           cast='float',
                           optional=True)

    def _fcheck_interlocks(self):
        """ Checks the laser interlocks. If any are reported, triggers an
            emergency shutoff and returns True.
        """
        # check laser interlocks
        manager = self.manager
        self.info('Check laser interlocks')
        interlocks = manager.laser_controller.check_interlocks(verbose=False)

        if interlocks:
            inter = ' '.join(interlocks)
            manager.emergency_shutoff(inter)
            return True

    def _fcheck_coolant_temp(self):
        """ Checks the coolant temperature. Returns True while it is above
            'max_coolant_temp'; once that has happened more than
            'max_coolant_temp_tries' times in a row an emergency shutoff is
            triggered.
        """
        manager = self.manager

        self.info('Check laser coolant temperature')
        ct = manager.get_coolant_temperature(verbose=False)
        if ct is None:
            self._chiller_unavailable()
        else:
            self._unavailable_cnt = 0
            if ct > self.max_coolant_temp:

                if self._coolant_check_cnt > self.max_coolant_temp_tries:
                    manager.emergency_shutoff(
                        'Coolant over temp {:0.2f}'.format(ct))
                else:
                    self._coolant_check_cnt += 1
                return True

            else:
                self._coolant_check_cnt = 0

    def reset(self):
        """ Resets the coolant failure counters.
        """
        self._coolant_check_status_cnt = 0
        self._coolant_check_cnt = 0

    def _fcheck_coolant_status(self):
        """ Checks the coolant status. Returns True while an error status is
            reported; once that has happened more than
            'max_coolant_temp_tries' times in a row an emergency shutoff is
            triggered.
        """
        manager = self.manager
        self.info('Check laser coolant status')

        status = manager.get_coolant_status()
        # None means the chiller did not answer
        # (NOTE(review): an empty status presumably means 'no errors' --
        # confirm against get_coolant_status)
        if status is None:
            self._chiller_unavailable()
        else:
            # temporary disable pump fail check
            if 'Pump Fail' in status:
                self.debug('skip pump fail')
                return

            self._unavailable_cnt = 0
            if status and all(status):
                if self._coolant_check_status_cnt > self.max_coolant_temp_tries:
                    status = ','.join(status) if isinstance(status,
                                                            list) else status
                    reason = 'Laser coolant error {}'.format(status)
                    manager.emergency_shutoff(reason)
                else:
                    self._coolant_check_status_cnt += 1
                return True

            else:
                self._coolant_check_status_cnt = 0

    def _chiller_unavailable(self):
        """ Counts an unanswered chiller check and triggers an emergency
            shutoff once 'max_unavailable' is reached, unless unavailable
            chillers are ignored globally.
        """
        from pychron.globals import globalv

        if not globalv.ignore_chiller_unavailable:
            if self._unavailable_cnt >= self.max_unavailable:
                reason = 'Laser chiller not available'
                self.manager.emergency_shutoff(reason)
            else:
                self._unavailable_cnt += 1

    def stop(self):
        """ Clears the setpoint and stops the monitor.
        """
        self.setpoint = 0
        super(FusionsLaserMonitor, self).stop()

    def _get_setpoint(self):
        """ Returns the current setpoint.
        """
        return self._setpoint

    def _set_setpoint(self, v):
        """ Stores the setpoint. A truthy value also resets the setpoint
            check counter and the list of current setpoints.
        """
        self._setpoint = v
        if v:
            self._setpoint_check_cnt = 0
            self._cur_setpoints = []

    setpoint = property(fget=_get_setpoint, fset=_set_setpoint)
Exemple #27
0
class MassSpecDatabaseImporter(Loggable):
    """Adds analyses, irradiation records and import-session rows to a
    Mass Spec database through a ``MassSpecDatabaseAdapter``."""

    precedence = Int(0)

    # adapter used for every database call made by this class
    db = Instance(MassSpecDatabaseAdapter)
    test = Button
    # ids of the current import session; filled in by the add_* methods and
    # set back to None by clear_import_session
    sample_loading_id = None
    data_reduction_session_id = None
    login_session_id = None
    # spectrometer the current login session was created for
    _current_spec = None
    # set to None at the start of each add_analysis attempt
    _analysis = None
    # updated in connect() when the adapter reports a database version
    _database_version = 0

    def make_multipe_runs_sequence(self, exptxt):
        """Placeholder: ``exptxt`` is accepted but nothing is done with it."""
        pass

    # IDatastore protocol
    def get_greatest_step(self, identifier, aliquot):
        """
            return the index in ALPHAS of the step of the latest analysis
            for identifier/aliquot

            returns 0 when no db is set, the adapter's falsy result when no
            analysis is found, and -1 when the step is None or not in ALPHAS
        """
        if not self.db:
            return 0

        identifier = self.get_identifier(identifier)
        latest = self.db.get_latest_analysis(identifier, aliquot)
        if not latest:
            return latest

        _, step = latest
        if step is not None and step in ALPHAS:
            return ALPHAS.index(step)
        return -1

    def get_greatest_aliquot(self, identifier):
        """
            return the aliquot of the latest analysis for identifier

            returns 0 when no db is set and the adapter's falsy result when
            no analysis is found
        """
        if not self.db:
            return 0

        identifier = self.get_identifier(identifier)
        latest = self.db.get_latest_analysis(identifier)
        if not latest:
            return latest

        aliquot, _ = latest
        return aliquot

    def is_connected(self):
        """
            return the adapter's ``connected`` flag, or None if no db is set
        """
        if not self.db:
            return None
        return self.db.connected

    def connect(self, *args, **kw):
        """
            connect the adapter, forwarding all arguments, and on success
            remember the database version if one is reported

            return the adapter's connect result
        """
        connected = self.db.connect(*args, **kw)
        if not connected:
            return connected

        version = self.db.get_database_version()
        if version is not None:
            self._database_version = version
        return connected

    def add_sample_loading(self, ms, tray):
        """
            add a sample loading record and remember its id

            does nothing if a sample loading id is already set
        """
        if self.sample_loading_id is not None:
            return

        adapter = self.db
        with adapter.session_ctx() as sess:
            loading = adapter.add_sample_loading(ms, tray)
            # flush so the new row's id is populated before it is read
            sess.flush()
            self.sample_loading_id = loading.SampleLoadingID

    def add_login_session(self, ms):
        """
            add a login session record for ms and remember its id
        """
        self.info('adding new session for {}'.format(ms))
        adapter = self.db
        with adapter.session_ctx() as sess:
            login = adapter.add_login_session(ms)
            # flush so the new row's id is populated before it is read
            sess.flush()
            self.login_session_id = login.LoginSessionID

    def add_data_reduction_session(self):
        """
            add a data reduction session record and remember its id

            does nothing if a data reduction session id is already set
        """
        if self.data_reduction_session_id is not None:
            return

        adapter = self.db
        with adapter.session_ctx() as sess:
            session = adapter.add_data_reduction_session()
            # flush so the new row's id is populated before it is read
            sess.flush()
            self.data_reduction_session_id = session.DataReductionSessionID

    def create_import_session(self, spectrometer, tray):
        """
            make sure login, data reduction and sample loading ids exist

            a new login session is added when none exists or when the
            spectrometer differs from the one last used
        """
        needs_login = (self.login_session_id is None
                       or self._current_spec != spectrometer)
        if needs_login:
            self._current_spec = spectrometer
            self.add_login_session(spectrometer)

        if self.data_reduction_session_id is None:
            self.add_data_reduction_session()

        if self.sample_loading_id is None:
            self.add_sample_loading(spectrometer, tray)

    def clear_import_session(self):
        """
            forget the session ids and the current spectrometer
        """
        for attr in ('sample_loading_id',
                     'data_reduction_session_id',
                     'login_session_id',
                     '_current_spec'):
            setattr(self, attr, None)

    def get_identifier(self, spec):
        """
            convert cocktails into mass spec labnumbers

            spec is either ExportSpec, int or str
            return identifier as a str

            identifiers starting with 'c' become '4358' for the obama
            spectrometer and '4359' otherwise. for a str spec the
            spectrometer is taken from the text after the last '-'
            ('o' -> obama, 'j' -> jan)
        """
        if isinstance(spec, (int, str)):
            mass_spectrometer = ''
            if isinstance(spec, str) and '-' in spec:
                suffix = spec.split('-')[-1].lower()
                if suffix == 'o':
                    mass_spectrometer = 'obama'
                elif suffix == 'j':
                    mass_spectrometer = 'jan'
            identifier = str(spec)
        else:
            mass_spectrometer = spec.mass_spectrometer.lower()
            identifier = str(spec.labnumber)

        if not identifier.startswith('c'):
            return identifier

        if mass_spectrometer in ('obama', 'pychron obama'):
            return '4358'
        return '4359'

    def add_irradiation(self, irrad, level, pid):
        """
            add an irradiation level inside a db session

            the third argument passed to the adapter is always 0
        """
        with self.db.session_ctx():
            self.db.add_irradiation_level(irrad, level, 0, pid)

    def add_irradiation_position(self, identifier, levelname, hole, **kw):
        """
            add an irradiation position inside a db session

            return the adapter's result
        """
        with self.db.session_ctx():
            position = self.db.add_irradiation_position(identifier, levelname,
                                                        hole, **kw)
            return position

    def add_irradiation_production(self, name, prdict, ifdict):
        """
            add an irradiation production inside a db session

            return the adapter's result
        """
        with self.db.session_ctx():
            production = self.db.add_irradiation_production(name, prdict,
                                                            ifdict)
            return production

    def add_irradiation_chronology(self, irrad, doses):
        """
            add one chronology segment per dose inside a single db session

            doses is an iterable of 3-tuples; only the second and third
            items of each are passed on, the first is ignored
        """
        with self.db.session_ctx():
            for _, start, end in doses:
                self.db.add_irradiation_chronology_segment(irrad, start, end)

    def add_analysis(self, spec, commit=True):
        """
            save spec to the Mass Spec database, trying up to three times

            each attempt opens a non-committing session, derives the run type
            and irradiation position from the runid prefix
            ('b' -> Blank/-1, 'a' -> Air/-2, 'c' -> Unknown with the
            identifier remapped by get_identifier, otherwise Unknown) and
            delegates to _add_analysis. on success the session is committed
            and the result returned. on failure the session is rolled back;
            after the third failure the traceback is reported via
            self.message

            spec provides irradpos, runid, labnumber, aliquot and step
            commit is not referenced by this method

            return the result of _add_analysis, or None if every attempt
            raised
        """
        for i in range(3):
            with self.db.session_ctx(commit=False) as sess:
                irradpos = spec.irradpos
                rid = spec.runid
                trid = rid.lower()
                identifier = spec.labnumber

                if trid.startswith('b'):
                    runtype = 'Blank'
                    irradpos = -1
                elif trid.startswith('a'):
                    runtype = 'Air'
                    irradpos = -2
                elif trid.startswith('c'):
                    runtype = 'Unknown'
                    identifier = irradpos = self.get_identifier(spec)
                else:
                    runtype = 'Unknown'

                rid = make_runid(identifier, spec.aliquot, spec.step)

                self._analysis = None
                # ask the adapter to raise so a failed save reaches the
                # except clause below and can be retried
                self.db.reraise = True
                try:
                    ret = self._add_analysis(sess, spec, irradpos, rid,
                                             runtype)
                    sess.commit()
                    return ret
                except Exception as e:
                    self.debug('Mass Spec save exception. {}'.format(e))
                    if i == 2:
                        import traceback

                        tb = traceback.format_exc()
                        self.message('Could not save spec.runid={} rid={} '
                                     'to Mass Spec database.\n {}'.format(
                                         spec.runid, rid, tb))
                    else:
                        self.debug('retry mass spec save')
                    # if commit:
                    sess.rollback()
                finally:
                    # NOTE(review): this leaves reraise True, the same value
                    # set before the try -- confirm whether it was meant to
                    # be restored to False here
                    self.db.reraise = True
class TasbeCalibrationOp(PluginOpMixin):
    """GUI operation that estimates a calibration chain (autofluorescence,
    bleedthrough, bead calibration and, optionally, color translation) from
    control files and applies it to a batch of input files."""

    handler_factory = Callable(TasbeHandler)
    
    # NOTE(review): the id ends in 'bleedthrough_piecewise' while the name is
    # "TASBE" -- confirm this is intended
    id = Constant('edu.mit.synbio.cytoflowgui.op_plugins.bleedthrough_piecewise')
    friendly_id = Constant("Quantitative Pipeline")
    name = Constant("TASBE")
    
    # morphology gate: these three traits delegate to the private polygon op
    fsc_channel = DelegatesTo('_polygon_op', 'xchannel', estimate = True)
    ssc_channel = DelegatesTo('_polygon_op', 'ychannel', estimate = True)
    vertices = DelegatesTo('_polygon_op', 'vertices', estimate = True)
    # channels to calibrate; _channels_changed keeps the control lists below
    # in step with this list
    channels = List(Str, estimate = True)
    
    # file imported by apply() and handed to the autofluorescence op
    blank_file = File(filter = ["*.fcs"], estimate = True)
    
    # one control per channel; copied into the bleedthrough op by estimate()
    bleedthrough_list = List(_BleedthroughControl, estimate = True)

    # bead calibration settings; copied into the bead calibration op by estimate()
    beads_name = Str(estimate = True)
    beads_file = File(filter = ["*.fcs"], estimate = True)
    units_list = List(_Unit, estimate = True)
    
    bead_peak_quantile = Int(80, estimate = True)
    bead_brightness_threshold = Float(100, estimate = True)
    bead_brightness_cutoff = util.FloatOrNone("", estimate = True)
    
    # color translation settings; only used when do_color_translation is set
    do_color_translation = Bool(estimate = True)
    to_channel = Str(estimate = True)
    translation_list = List(_TranslationControl, estimate = True)
    mixture_model = Bool(False, estimate = True)
    
    do_estimate = Event
    # set True at the end of estimate() and False by clear_estimate()
    valid_model = Bool(False, status = True)
    do_exit = Event
    # files converted by apply(); emptied once conversion finishes
    input_files = List(File)
    output_directory = Directory
        
    # experiment imported from blank_file, plus the file name it came from
    _blank_exp = Instance(Experiment, transient = True)
    _blank_exp_file = File(transient = True)
    _blank_exp_channels = List(Str, status = True)
    _polygon_op = Instance(PolygonOp, 
                           kw = {'name' : 'polygon',
                                 'xscale' : 'log', 
                                 'yscale' : 'log'}, 
                           transient = True)
    # the underlying operations; replaced with fresh instances by clear_estimate()
    _af_op = Instance(AutofluorescenceOp, (), transient = True)
    _bleedthrough_op = Instance(BleedthroughLinearOp, (), transient = True)
    _bead_calibration_op = Instance(BeadCalibrationOp, (), transient = True)
    _color_translation_op = Instance(ColorTranslationOp, (), transient = True)
    
    # progress text written by estimate() and apply()
    status = Str(status = True)
    
    @on_trait_change('channels[], to_channel, do_color_translation', post_init = True)
    def _channels_changed(self, obj, name, old, new):
        """Keep the bleedthrough, units and translation lists consistent with
        the selected channels, then announce the changed lists."""
        # every selected channel needs a bleedthrough control and a unit entry
        for ch in self.channels:
            if not any(ctrl.channel == ch for ctrl in self.bleedthrough_list):
                self.bleedthrough_list.append(_BleedthroughControl(channel = ch))

            if not any(u.channel == ch for u in self.units_list):
                self.units_list.append(_Unit(channel = ch))

        # entries for channels that are no longer selected are dropped; the
        # items to drop are collected first so the lists are not mutated
        # while being iterated
        stale_controls = [ctrl for ctrl in self.bleedthrough_list
                          if ctrl.channel not in self.channels]
        for ctrl in stale_controls:
            self.bleedthrough_list.remove(ctrl)

        stale_units = [u for u in self.units_list
                       if u.channel not in self.channels]
        for u in stale_units:
            self.units_list.remove(u)

        if self.do_color_translation:
            # with color translation only the target channel keeps a unit
            non_target_units = [u for u in self.units_list
                                if u.channel != self.to_channel]
            for u in non_target_units:
                self.units_list.remove(u)

            # rebuild the translation controls: one per non-target channel
            self.translation_list = []
            for ch in self.channels:
                if ch != self.to_channel:
                    self.translation_list.append(
                        _TranslationControl(from_channel = ch,
                                            to_channel = self.to_channel))

            self.changed = (Changed.ESTIMATE, ('translation_list', self.translation_list))

        self.changed = (Changed.ESTIMATE, ('bleedthrough_list', self.bleedthrough_list))
        self.changed = (Changed.ESTIMATE, ('units_list', self.units_list))


    @on_trait_change('_polygon_op:vertices', post_init = True)
    def _polygon_changed(self, obj, name, old, new):
        """Fire an ESTIMATE change with an empty payload when the polygon's
        vertices change."""
        self.changed = (Changed.ESTIMATE, (None, None))

    @on_trait_change("bleedthrough_list_items, bleedthrough_list.+", post_init = True)
    def _bleedthrough_controls_changed(self, obj, name, old, new):
        """Fire an ESTIMATE change carrying the bleedthrough list when its
        membership or one of its items changes."""
        self.changed = (Changed.ESTIMATE, ('bleedthrough_list', self.bleedthrough_list))
     
    @on_trait_change("translation_list_items, translation_list.+", post_init = True)
    def _translation_controls_changed(self, obj, name, old, new):
        """Fire an ESTIMATE change carrying the translation list when its
        membership or one of its items changes."""
        self.changed = (Changed.ESTIMATE, ('translation_list', self.translation_list))
        
    @on_trait_change('units_list_items,units_list.+', post_init = True)
    def _units_changed(self, obj, name, old, new):
        """Fire an ESTIMATE change carrying the units list when its
        membership or one of its items changes."""
        self.changed = (Changed.ESTIMATE, ('units_list', self.units_list))
#     
    def estimate(self, experiment, subset = None):
        """Estimate each stage of the calibration from the blank experiment.

        The ``experiment`` argument is rebound to a clone of
        ``self._blank_exp`` before it is ever read, and ``subset`` is not
        referenced.  Stages run in order -- autofluorescence, bleedthrough,
        bead calibration and (if ``do_color_translation``) color
        translation -- each one estimated on the output of the previous
        stage.  Progress is reported through ``self.status`` and
        ``self.changed``; ``valid_model`` is set to True at the end.

        Raises:
            util.CytoflowOpError: if the FSC channel, the SSC channel or the
                polygon vertices are not set.
        """
#         if not self.subset:
#             warnings.warn("Are you sure you don't want to specify a subset "
#                           "used to estimate the model?",
#                           util.CytoflowOpWarning)
            
#         if experiment is None:
#             raise util.CytoflowOpError("No valid result to estimate with")
        
#         experiment = experiment.clone()

        if not self.fsc_channel:
            raise util.CytoflowOpError('fsc_channel',
                                       "Must set FSC channel")
            
        if not self.ssc_channel:
            raise util.CytoflowOpError('ssc_channel',
                                       "Must set SSC channel")
        
        if not self._polygon_op.vertices:
            raise util.CytoflowOpError(None, "Please draw a polygon around the "
                                             "single-cell population in the "
                                             "Morphology tab")            

        # work on a gated copy of the blank experiment; the later estimates
        # select on the "polygon" column the gate produces
        experiment = self._blank_exp.clone()
        experiment = self._polygon_op.apply(experiment)
        
        # stage 1: autofluorescence
        self._af_op.channels = self.channels
        self._af_op.blank_file = self.blank_file
        
        self._af_op.estimate(experiment, subset = "polygon == True")
        self.changed = (Changed.ESTIMATE_RESULT, "Autofluorescence")
        experiment = self._af_op.apply(experiment)
        
        # stage 2: bleedthrough, using one control file per channel
        self.status = "Estimating bleedthrough"
        
        self._bleedthrough_op.controls.clear()
        for control in self.bleedthrough_list:
            self._bleedthrough_op.controls[control.channel] = control.file

        self._bleedthrough_op.estimate(experiment, subset = "polygon == True") 
        self.changed = (Changed.ESTIMATE_RESULT, "Bleedthrough")
        experiment = self._bleedthrough_op.apply(experiment)
        
        # stage 3: bead calibration
        self.status = "Estimating bead calibration"
        
        self._bead_calibration_op.beads = BeadCalibrationOp.BEADS[self.beads_name]
        self._bead_calibration_op.beads_file = self.beads_file
        self._bead_calibration_op.bead_peak_quantile = self.bead_peak_quantile
        self._bead_calibration_op.bead_brightness_threshold = self.bead_brightness_threshold
        self._bead_calibration_op.bead_brightness_cutoff = self.bead_brightness_cutoff        
        
        self._bead_calibration_op.units.clear()

        for unit in self.units_list:
            self._bead_calibration_op.units[unit.channel] = unit.unit
            
        self._bead_calibration_op.estimate(experiment)
        self.changed = (Changed.ESTIMATE_RESULT, "Bead Calibration")
        
        # stage 4 (optional): color translation
        if self.do_color_translation:
            self.status = "Estimating color translation"

            experiment = self._bead_calibration_op.apply(experiment)
            
            self._color_translation_op.mixture_model = self.mixture_model
            
            self._color_translation_op.controls.clear()
            for control in self.translation_list:
                self._color_translation_op.controls[(control.from_channel,
                                                     control.to_channel)] = control.file
                                                     
            self._color_translation_op.estimate(experiment, subset = 'polygon == True')                                         
            
            self.changed = (Changed.ESTIMATE_RESULT, "Color Translation")
            
        self.status = "Done estimating"
        self.valid_model = True
        
        
    def should_clear_estimate(self, changed, payload):
        """
        Should the owning WorkflowItem clear the estimated model by calling
        op.clear_estimate()?  `changed` can be:
        - Changed.ESTIMATE -- the parameters required to call 'estimate()' (ie
          traits with estimate = True metadata) have changed
        - Changed.PREV_RESULT -- the previous WorkflowItem's result changed

        Returns False only for an ESTIMATE change whose payload names
        'fsc_channel' or 'ssc_channel'; True otherwise.
        """
        if changed != Changed.ESTIMATE:
            return True

        name, _ = payload
        return not (name == 'fsc_channel' or name == 'ssc_channel')
        
        
    def clear_estimate(self):
        """Discard the estimated model: replace the four estimating
        operations with fresh instances, mark the model invalid and fire an
        ESTIMATE_RESULT change."""
        self._af_op = AutofluorescenceOp()
        self._bleedthrough_op = BleedthroughLinearOp()
        self._bead_calibration_op = BeadCalibrationOp()
        self._color_translation_op = ColorTranslationOp()
        self.valid_model = False
        
        self.changed = (Changed.ESTIMATE_RESULT, self)
                        
    def should_apply(self, changed, payload):
        """
        Should the owning WorkflowItem apply this operation when certain things
        change?  `changed` can be:
        - Changed.OPERATION -- the operation's parameters changed
        - Changed.PREV_RESULT -- the previous WorkflowItem's result changed
        - Changed.ESTIMATE_RESULT -- the results of calling "estimate" changed

        Returns True for an ESTIMATE_RESULT change while the blank file has
        not been loaded yet, and for any OPERATION change other than
        "output_directory"; False otherwise.
        """
        blank_is_stale = self.blank_file != self._blank_exp_file
        if changed == Changed.ESTIMATE_RESULT and blank_is_stale:
            return True

        if changed == Changed.OPERATION:
            name, _ = payload
            return name != "output_directory"

        return False

        
        
    def apply(self, experiment):
        """Load the blank file, or convert the queued input files.

        If ``blank_file`` differs from the file last imported, it is imported
        into ``_blank_exp``, its channels are recorded, a PREV_RESULT change
        is fired and the method returns.  Otherwise each path in
        ``input_files`` is imported, run through the estimated operations and
        exported with ``ExportFCS`` into ``output_directory``; afterwards
        ``input_files`` is emptied.

        The incoming ``experiment`` argument is never read, and nothing is
        returned.

        Raises:
            util.CytoflowOpError: if a file with the same name as one of the
                inputs already exists in the output directory.
        """

        # this "apply" function is a little odd -- it does not return an Experiment because
        # it is always the only WI/operation in the workflow.
        
        if self.blank_file != self._blank_exp_file:
            self._blank_exp = ImportOp(tubes = [Tube(file = self.blank_file)] ).apply()
            self._blank_exp_file = self.blank_file
            self._blank_exp_channels = self._blank_exp.channels
            self.changed = (Changed.PREV_RESULT, None)
            return
        
            
        # check every destination first so nothing is converted if any
        # output file already exists
        out_dir = Path(self.output_directory)
        for path in self.input_files:
            in_file_path = Path(path)
            out_file_path = out_dir / in_file_path.name
            if out_file_path.exists():
                raise util.CytoflowOpError(None,
                                           "File {} already exists"
                                           .format(out_file_path))
                
        tubes = [Tube(file = path, conditions = {'filename' : Path(path).stem})
                 for path in self.input_files]
        
        # each tube is imported and converted on its own
        for tube in tubes:
            self.status = "Converting " + Path(tube.file).stem
            experiment = ImportOp(tubes = [tube], conditions = {'filename' : 'category'}).apply()
            
            experiment = self._af_op.apply(experiment)
            experiment = self._bleedthrough_op.apply(experiment)
            experiment = self._bead_calibration_op.apply(experiment)
            
            if self.do_color_translation:
                experiment = self._color_translation_op.apply(experiment)                                                
                    
            ExportFCS(path = self.output_directory,
                      by = ['filename'],
                      _include_by = False).export(experiment)
                      
        self.input_files = []
        self.status = "Done converting!"
    
    
    def default_view(self, **kwargs):
        """Return a ``TasbeCalibrationView`` bound to this operation; extra
        keyword arguments are passed through to the view."""
        return TasbeCalibrationView(op = self, **kwargs)
    
    def get_help(self):
        """Return the contents of this operation's HTML help page.

        The page is looked up in the "help" directory next to this module,
        as ``<module name>.html``; the classes in the MRO are tried in order
        and the first existing file wins.  If none exists, ``open`` is called
        with ``None``.
        """
        this_file = os.path.abspath(__file__)
        help_dir = os.path.join(os.path.split(this_file)[0], "help")

        candidates = (os.path.join(help_dir, klass.__module__ + ".html")
                      for klass in self.__class__.__mro__)
        help_file = next((h for h in candidates if os.path.exists(h)), None)

        with open(help_file, encoding = 'utf-8') as f:
            return f.read()
Exemple #29
0
class Dialog(MDialog, Window):
    """ The toolkit specific implementation of a Dialog.  See the IDialog
    interface for the API documentation.
    """

    implements(IDialog)

    #### 'IDialog' interface ##################################################

    # Label for the cancel button; the standard Cancel button is used if empty.
    cancel_label = Unicode

    # The help button is only added when this is non-empty.
    help_id = Str

    # Label for the help button; the standard Help button is used if empty.
    help_label = Unicode

    # Label for the OK button; the standard Ok button is used if empty.
    ok_label = Unicode

    # When False the layout is constrained to a fixed size.
    resizeable = Bool(True)

    # Result of the dialog; updated when a nonmodal dialog finishes.
    return_code = Int(OK)

    style = Enum('modal', 'nonmodal')

    #### 'IWindow' interface ##################################################

    # Used as the window title of the dialog control.
    title = Unicode("Dialog")

    ###########################################################################
    # Protected 'IDialog' interface.
    ###########################################################################

    def _create_buttons(self, parent):
        """ Build the button box: OK (default, accepts), Cancel (rejects) and,
        when a help id is set, a Help button that is not connected to
        anything.  Custom labels replace the standard buttons when given.
        """
        box = QtGui.QDialogButtonBox()

        # 'OK' button.
        if self.ok_label:
            ok_btn = box.addButton(self.ok_label,
                                   QtGui.QDialogButtonBox.AcceptRole)
        else:
            ok_btn = box.addButton(QtGui.QDialogButtonBox.Ok)

        ok_btn.setDefault(True)
        QtCore.QObject.connect(ok_btn, QtCore.SIGNAL('clicked()'),
                               self.control, QtCore.SLOT('accept()'))

        # 'Cancel' button.
        if self.cancel_label:
            cancel_btn = box.addButton(self.cancel_label,
                                       QtGui.QDialogButtonBox.RejectRole)
        else:
            cancel_btn = box.addButton(QtGui.QDialogButtonBox.Cancel)

        QtCore.QObject.connect(cancel_btn, QtCore.SIGNAL('clicked()'),
                               self.control, QtCore.SLOT('reject()'))

        # 'Help' button.
        # FIXME v3: In the original code the only possible hook into the help
        # was to reimplement self._on_help().  However this was a private
        # method.  Obviously nobody uses the Help button.  For the moment we
        # display it but can't actually use it.
        if self.help_id:
            if self.help_label:
                box.addButton(self.help_label,
                              QtGui.QDialogButtonBox.HelpRole)
            else:
                box.addButton(QtGui.QDialogButtonBox.Help)

        return box

    def _create_contents(self, parent):
        """ Stack the dialog area above the button box and install the layout
        on parent.
        """
        vbox = QtGui.QVBoxLayout()

        if not self.resizeable:
            vbox.setSizeConstraint(QtGui.QLayout.SetFixedSize)

        dialog_area = self._create_dialog_area(parent)
        vbox.addWidget(dialog_area)

        button_box = self._create_buttons(parent)
        vbox.addWidget(button_box)

        parent.setLayout(vbox)

    def _create_dialog_area(self, parent):
        """ Return the default dialog area: a widget with a minimum size of
        100x200 and a red, auto-filled background.
        """
        area = QtGui.QWidget(parent)
        area.setMinimumSize(QtCore.QSize(100, 200))

        colors = area.palette()
        colors.setColor(QtGui.QPalette.Window, QtGui.QColor('red'))
        area.setPalette(colors)
        area.setAutoFillBackground(True)

        return area

    def _show_modal(self):
        """ Run the dialog application-modally and return its result mapped
        through _RESULT_MAP.
        """
        self.control.setWindowModality(QtCore.Qt.ApplicationModal)
        return _RESULT_MAP[self.control.exec_()]

    ###########################################################################
    # Protected 'IWidget' interface.
    ###########################################################################

    def _create_control(self, parent):
        """ Create and return the QDialog, sized and titled from the traits.
        """
        dialog = QtGui.QDialog(parent)

        # Setting return code and firing close events is handled for 'modal' in
        # MDialog's open method. For 'nonmodal', we do it here.
        if self.style == 'nonmodal':
            QtCore.QObject.connect(dialog, QtCore.SIGNAL('finished(int)'),
                                   self._finished_fired)

        # (-1, -1) means no explicit size was requested.
        if self.size != (-1, -1):
            width, height = self.size
            dialog.resize(width, height)

        dialog.setWindowTitle(self.title)

        return dialog

    ###########################################################################
    # Private interface.
    ###########################################################################

    def _finished_fired(self, result):
        """ Called when the dialog is closed (and nonmodal).

        Stores the result, mapped through _RESULT_MAP, as the return code
        and then closes the dialog.
        """

        self.return_code = _RESULT_MAP[result]
        self.close()
Exemple #30
0
class OpenFileDialog(Handler):
    """ Defines the model and handler for the open file dialog.

    The same class serves as a save dialog when 'is_save_file' is set; the
    view is built by 'open_file_view'.
    """

    # The starting and current file path:
    file_name = File

    # The list of file filters to apply:
    filter = CList(Str)

    # Number of history entries to allow:
    entries = Int(10)

    # The file dialog title:
    title = Str('Open File')

    # The Traits UI persistence id to use:
    id = Str('traitsui.file_dialog.OpenFileDialog')

    # A list of optional file dialog extensions:
    extensions = CList(IFileDialogModel)

    #-- Private Traits -------------------------------------------------------

    # The UIInfo object for the view:
    info = Instance(UIInfo)

    # Event fired when the file tree view should be reloaded:
    reload = Event

    # Event fired when the user double-clicks on a file name:
    dclick = Event

    # Allow extension models to be added dynamically
    # ('open_file_view' stores each extension as 'extension_<n>'):
    extension__ = Instance(IFileDialogModel)

    # Is the file dialog for saving a file (or opening a file)?
    is_save_file = Bool(False)

    # Is the currently specified file name valid?
    is_valid_file = Property(depends_on='file_name')

    # Can a directory be created now?
    can_create_dir = Property(depends_on='file_name')

    # The OK, Cancel and create directory buttons:
    ok = Button('OK')
    cancel = Button('Cancel')
    create = Button(image='@icons:folder-new', style='toolbar')

    #-- Handler Class Method Overrides ---------------------------------------

    def init_info(self, info):
        """ Handles the UIInfo object being initialized during view start-up.

        Keeps a reference to it in the 'info' trait for later use.
        """
        self.info = info

    #-- Property Implementations ---------------------------------------------

    def _get_is_valid_file(self):
        """ An open dialog needs an existing file; a save dialog also accepts
        a path that does not exist yet.
        """
        if not self.is_save_file:
            return isfile(self.file_name)

        return (isfile(self.file_name) or (not exists(self.file_name)))

    def _get_can_create_dir(self):
        """ A directory can be created when the parent of the current path is
        a readable and writable directory.
        """
        parent_dir = dirname(self.file_name)
        return (isdir(parent_dir) and access(parent_dir, R_OK | W_OK))

    #-- Handler Event Handlers -----------------------------------------------

    def object_ok_changed(self, info):
        """ Handles the user clicking the OK button.

        Saving over an existing file defers to the overwrite prompt; in
        every other case the dialog is closed with a True result.
        """
        would_overwrite = self.is_save_file and exists(self.file_name)
        if not would_overwrite:
            info.ui.dispose(True)
            return

        do_later(self._file_already_exists)

    def object_cancel_changed(self, info):
        """ Handles the user clicking the Cancel button.

        Closes the dialog with a False result.
        """
        info.ui.dispose(False)

    def object_create_changed(self, info):
        """ Handles the user clicking the create directory button.

        If the current path is not a directory it is first replaced by its
        parent; then the create-directory prompt is opened next to the button.
        """
        if not isdir(self.file_name):
            self.file_name = dirname(self.file_name)

        handler = CreateDirHandler()
        handler.edit_traits(context=self, parent=info.create.control)

    #-- Traits Event Handlers ------------------------------------------------

    def _dclick_changed(self):
        """ Handles the user double-clicking a file name in the file tree view.

        A double-click on a valid file behaves like clicking OK.
        """
        if not self.is_valid_file:
            return

        self.object_ok_changed(self.info)

    #-- Private Methods ------------------------------------------------------

    def open_file_view(self):
        """ Returns the file dialog view to use.

        The view always holds the file tree, the history field and the
        buttons; each entry of 'extensions' adds a panel beside the tree.
        """
        # The file tree is the base content of the dialog:
        content = Item('file_name',
                       id='file_tree',
                       style='custom',
                       show_label=False,
                       width=0.5,
                       editor=FileEditor(filter=self.filter,
                                         allow_dir=True,
                                         reload_name='reload',
                                         dclick_name='dclick'))
        width = height = 0.20

        # Extensions, if any, are placed next to the file tree:
        if self.extensions:

            # fixme: We should use the actual values of the view's Width and
            # height traits here to adjust the overall width and height...
            width *= 2.0

            # A fixed width group is used unless an extension asks otherwise:
            container = HGroup
            extension_items = []

            for index, extension in enumerate(self.extensions):

                # Expose the extension as a trait the View can refer to:
                trait_name = 'extension_%d' % index
                setattr(self, trait_name, extension)

                # Keep the extension's 'file_name' in step with ours:
                self.sync_trait('file_name', extension, mutual=True)

                # Extensions that do not implement IFileDialogView fall back
                # to the default view information:
                if extension.has_traits_interface(IFileDialogView):
                    view_info = extension
                else:
                    view_info = default_view

                # The view the extension wants to use:
                extension_ui = extension.trait_view(view_info.view)

                # One non-fixed extension switches the dialog to a splitter:
                if not view_info.is_fixed:
                    container = HSplit

                extension_items.append(
                    Item(trait_name,
                         label=user_name_for(extension.__class__.__name__),
                         show_label=False,
                         style='custom',
                         width=0.5,
                         height=0.5,
                         dock='horizontal',
                         resizable=True,
                         editor=InstanceEditor(view=extension_ui,
                                               id=trait_name)))

            # The file tree and the stacked extensions share one container:
            content = container(content,
                                VSplit(*extension_items,
                                       id='splitter2',
                                       springy=True),
                                id='splitter')

        # The button row sits below the content:
        button_row = HGroup(
            Item('create',
                 id='create',
                 show_label=False,
                 style='custom',
                 defined_when='is_save_file',
                 enabled_when='can_create_dir',
                 tooltip='Create a new directory'),
            Item('file_name',
                 id='history',
                 editor=HistoryEditor(entries=self.entries, auto_set=True),
                 springy=True),
            Item('ok',
                 id='ok',
                 show_label=False,
                 enabled_when='is_valid_file'),
            Item('cancel', show_label=False))

        return View(VGroup(VGroup(content), button_row),
                    title=self.title,
                    id=self.id,
                    kind='livemodal',
                    width=width,
                    height=height,
                    close_result=False,
                    resizable=True)

    def _file_already_exists(self):
        """ Handles prompting the user when the selected file already exists,
            and the dialog is a 'save file' dialog.
        """
        prompt = ("The file '%s' already exists.\nDo "
                  "you wish to overwrite it?") % basename(self.file_name)
        handler = FileExistsHandler(message=prompt)

        # Open the prompt next to the OK button, then attach it to this
        # dialog's UI:
        prompt_ui = handler.edit_traits(context=self,
                                        parent=self.info.ok.control)
        prompt_ui.trait_set(parent=self.info.ui)