Пример #1
0
class IvyDriver(IODriver):
    """Ivy input driver.

    Joins an Ivy bus, binds ``ivy_regex`` and forwards the first capture
    group of every matching message (other than ``ivy_ready_msg``) to
    ``pass_data``. Changing any connection setting re-opens the connection.
    """
    _use_thread = False
    # Binding id returned by IvyBindMsg; needed to unbind in close().
    _ivy_id = Int(0)

    name = Str('Ivy Driver')
    ivy_agent_name = Str('Plot-o-matic')
    ivy_bus = Str('')
    ivy_ready_msg = Str('READY')
    ivy_regex = Str('(.*)')

    view = View(Item('ivy_agent_name',
                     label='Agent name',
                     editor=TextEditor(enter_set=True, auto_set=False)),
                Item('ivy_bus',
                     label='Ivy bus',
                     editor=TextEditor(enter_set=True, auto_set=False)),
                Item('ivy_regex',
                     label='Regex',
                     editor=TextEditor(enter_set=True, auto_set=False)),
                Item('ivy_ready_msg',
                     label='Ready message',
                     editor=TextEditor(enter_set=True, auto_set=False)),
                title='Ivy input driver')

    def open(self):
        """Initialise Ivy, start it on ``ivy_bus`` and bind ``ivy_regex``."""
        IvyInit(self.ivy_agent_name, self.ivy_ready_msg)
        # Only let the 'Ivy' logger through at ERROR level and above.
        logging.getLogger('Ivy').setLevel(logging.ERROR)
        IvyStart(self.ivy_bus)
        self._ivy_id = IvyBindMsg(self.on_ivy_msg, self.ivy_regex)

    def close(self):
        """Unbind the message callback and stop Ivy."""
        IvyUnBindMsg(self._ivy_id)
        IvyStop()

    def reopen(self):
        """Close and re-open so that changed settings take effect."""
        self.close()
        self.open()

    # Static trait handlers: any connection-setting change restarts Ivy.
    def _ivy_agent_name_changed(self):
        self.reopen()

    def _ivy_bus_changed(self):
        self.reopen()

    def _ivy_ready_msg_changed(self):
        self.reopen()

    def _ivy_regex_changed(self):
        self.reopen()

    def on_ivy_msg(self, agent, *larg):
        """Ivy callback: forward the first captured value to ``pass_data``.

        ``larg`` is expected to hold the regex capture groups (Ivy
        convention -- confirm). A user-entered regex without a capture group
        yields an empty ``larg``. In that case there is nothing to forward,
        and returning avoids an IndexError inside the callback. Messages
        equal to ``ivy_ready_msg`` are ignored.
        """
        if not larg:
            return
        if larg[0] != self.ivy_ready_msg:
            self.pass_data(larg[0])
Пример #2
0
class Probe(HasTraits):
    """Probe settings: high tension, convergence angle and aberrations.

    ``wl`` is the relativistic electron wavelength in picometres, derived
    from the high tension ``HT`` given in kV.
    """
    HT = Range(low=40, high=3000.0, value=200.0)
    alpha = Range(low=0.0, high=80.0, value=15.0)
    wl = Property(Float, depends_on=['HT'])
    nomenclature = Enum('Krivanek', 'Rose', 'Number3')
    ab = Instance(Aberrations)

    def _get_wl(self):
        """Relativistic wavelength (pm) of an electron accelerated through HT kV."""
        planck = 6.626 * 10**-34                      # J s
        rest_mass = 9.109 * 10**-31                   # kg
        energy = 1.602 * 10**-19 * self.HT * 1000     # J, HT is in kV
        light_speed = 2.998 * 10**8                   # m/s
        relativistic_factor = 1 + energy / (2 * rest_mass * light_speed**2)
        momentum = np.sqrt(2 * rest_mass * energy * relativistic_factor)
        # metres -> picometres
        return planck / momentum * 10**12

    # Top row: nomenclature, high tension, read-only wavelength, angle.
    gen_group = Group(
        HGroup(
            Item(name='nomenclature', label='Nomenclature'),
            spring,
            Item(name='HT',
                 label="High Tension, kV",
                 help='The microscope accelerating voltage'),
            Item('wl',
                 label="Wavelength, pm ",
                 style='readonly',
                 editor=TextEditor(format_str='%3.2f')),
            spring,
            Item('alpha', label="Conv. Angle"),
        ),
    )

    # Bordered panel embedding the aberrations editor without a label.
    ab_group = Group(
        Group(Item(name='ab', style='custom'), show_labels=False),
        show_border=True,
    )

    view = View(
        Group(gen_group, ab_group),
        title='Higher-order Aberrations',
        buttons=['OK', 'Cancel'],
        resizable=True,
        handler=ProbeHandler(),
    )
Пример #3
0
class InternetExplorerDemo(HasTraits):
    """Tabbed browser demo: every URL committed in the location bar opens a page."""

    # A URL to display:
    url = Str('http://')

    # The list of web pages being browsed:
    pages = List(WebPage)

    # Location bar on top (committed on Enter only), one notebook tab per
    # WebPage below it, titled by the page's ``title``.
    view = View(
        VGroup(
            Item('url',
                 label='Location',
                 editor=TextEditor(auto_set=False, enter_set=True)),
        ),
        Item('pages',
             show_label=False,
             style='custom',
             editor=ListEditor(use_notebook=True,
                               deletable=True,
                               dock_style='tab',
                               export='DockWindowShell',
                               page_name='.title')),
    )

    def _url_changed(self, url):
        """Open a new WebPage for the newly entered URL (whitespace trimmed)."""
        page = WebPage(url=url.strip())
        self.pages.append(page)
Пример #4
0
 def traits_view(self):
     """Build the view: one multi-line, custom-style editor for ``text``.

     ``UItem`` is used, so the editor is shown without a label.
     """
     text_item = UItem('text', editor=TextEditor(multi_line=True), style='custom')
     return View(text_item, handler=_OutputStreamViewHandler())
Пример #5
0
class DocumentedItem(HasTraits):
    """ Container to hold a name and a documentation for an action.

    Firing the ``add`` button calls the method named ``id`` on
    ``object.menu_helper``.
    """

    # Name of the action
    name = Str

    # Name of the method on ``object.menu_helper`` that ``add`` triggers.
    # It was previously read in _add_fired without being declared, so the
    # class only worked when callers happened to pass ``id=`` explicitly.
    id = Str

    # Button to trigger the action
    add = ToolbarButton('Add',
                        orientation='horizontal',
                        image=ImageResource('add.ico'))

    # Object the action will apply on
    object = Any

    # Two lines documentation for the action
    documentation = Str

    view = View(
        '_',
        Item('add', style='custom', show_label=False),
        Item('documentation',
             style='readonly',
             editor=TextEditor(multi_line=True),
             resizable=True,
             show_label=False),
    )

    def _add_fired(self):
        """ Trait handler for when the ``add`` button is clicked.

        Looks up the method named ``self.id`` on ``self.object.menu_helper``
        and calls it with no arguments.
        """
        action = getattr(self.object.menu_helper, self.id)
        action()
Пример #6
0
class RegexDecoder(DataDecoder):
    """Decode arbitrary text using a regular expression.

    Each capture group of ``regex`` is assigned, in order, to the
    comma-separated names in ``variable_names``. A name of ``_`` ignores
    its group.
    """
    name = Str('Regex Decoder')
    view = View(
        Item(name='regex',
             label='Regex',
             editor=TextEditor(enter_set=True, auto_set=False)),
        Item(
            label=
            "Each subgroup in the regex is \nassigned to a variable \nin the list in order."
        ),
        Item(name='variable_names',
             label='Group names',
             editor=TextEditor(enter_set=True, auto_set=False)),
        Item(label="(comma separated, use '_' to ignore a subgroup)"),
        title='Regex decoder')
    regex = Str()
    variable_names = Str()

    def decode(self, data):
        """Match ``data`` against ``regex`` and map capture groups to names.

        Returns a dict mapping each non-``_`` name to its group's value.
        The value is converted to float when possible; otherwise it is kept
        as matched (a string, or None for a group that did not participate).
        Returns None when the regex is invalid, does not match, or the
        number of groups differs from the number of names.
        """
        try:
            match = re.search(self.regex, data)
        except (re.error, TypeError):
            # Invalid pattern (e.g. still being typed in the UI) or data of
            # a type the pattern cannot be applied to.
            return None
        if match is None:
            return None

        groups = match.groups()
        # Strip whitespace so that "a, b" yields the names "a" and "b",
        # not "a" and " b".
        names = [var.strip() for var in self.variable_names.split(',')]
        if len(groups) != len(names):
            return None

        data_dict = {}
        for var, value in zip(names, groups):
            if var == '_':
                continue
            try:
                data_dict[var] = float(value)
            except (TypeError, ValueError):
                # Non-numeric text, or None for an unmatched optional group.
                data_dict[var] = value
        return data_dict
Пример #7
0
class Parameter(HasTraits):
    """A named float value whose text form is ``{name:value}``."""
    name = String('name')
    value = Float

    def __str__(self):
        return '{%s:%f}' % (self.name, self.value)

    # repr deliberately produces the same text as str.
    __repr__ = __str__

    # Table layout for Parameter rows: a read-only name column and a value
    # column edited as text, parsed with float, committed on Enter.
    editor = TableEditor(
        auto_size=False,
        columns=[
            ObjectColumn(name='name', label='Name', editable=False),
            ObjectColumn(name='value',
                         label='Value',
                         editor=TextEditor(evaluate=float,
                                           enter_set=True,
                                           auto_set=False)),
        ])
Пример #8
0
class Parameter(HasTraits):
    """A named float value with a unit and associated ``pattern``/``cmd`` strings."""
    name = String
    value = Float
    unit = String
    pattern = String
    cmd = String

    # Table layout for Parameter rows: only the value column is editable
    # (parsed with float, committed on Enter); pattern/cmd are not shown.
    editor = TableEditor(
        auto_size=False,
        columns=[
            ObjectColumn(name='name', editable=False, label='Parameter'),
            ObjectColumn(name='value',
                         label='Value',
                         editor=TextEditor(evaluate=float,
                                           enter_set=True,
                                           auto_set=False)),
            ObjectColumn(name='unit', editable=False, label='Unit'),
        ])
Пример #9
0
class Expression(HasTraits):
    """An expression over ``Variables`` with an incrementally extended cache.

    ``get_array`` evaluates only the samples added since the previous call.
    It appends them to an internal array, which is truncated to the last
    ``max_samples`` entries.
    """
    _vars = Instance(Variables)
    _expr = ExpressionString('')
    # Plain (non-trait) attribute; initialised by clear_cache() in __init__.
    _data_array_cache = None
    _data_array_cache_index = Int(0)

    view = View(
        Item('_expr',
             show_label=False,
             editor=TextEditor(enter_set=True, auto_set=False)))

    def __init__(self, variables, expr, **kwargs):
        HasTraits.__init__(self, **kwargs)
        self._vars = variables
        # Always start from an empty array. set_expr() only reaches
        # clear_cache() (via the change handler) when ``expr`` differs from
        # the current value. So Expression(vars, '') used to leave the cache
        # as None, and get_array() then appended new data to None.
        self.clear_cache()
        self.set_expr(expr)

    def set_expr(self, expr):
        """Set the expression text; assigning an equal value is a no-op."""
        if self._expr != expr:
            self._expr = expr

    def __expr_changed(self):
        # Change handler for the ``_expr`` trait: cached samples belong to
        # the previous expression.
        self.clear_cache()

    def clear_cache(self):
        """Drop all cached samples."""
        self._data_array_cache = numpy.array([])
        self._data_array_cache_index = 0

    def get_curr_value(self):
        """Evaluate the expression once via ``Variables._eval_expr``."""
        return self._vars._eval_expr(self._expr)

    def get_array(self, first=0, last=None):
        """Return samples ``first:last`` of the expression, extending the cache if needed.

        ``first`` and ``last`` are first normalised by ``Variables.bound_array``.
        """
        first, last = self._vars.bound_array(first, last)
        if last > self._data_array_cache_index:
            # Cache miss: evaluate only the samples not fetched yet.
            new_data = self._vars._get_array(self._expr,
                                             self._data_array_cache_index,
                                             last)

            # numpy.append flattens its inputs; restore the trailing
            # dimensions and let the first axis absorb the total length.
            new_shape = list(new_data.shape)
            new_shape[0] = -1

            self._data_array_cache = numpy.append(self._data_array_cache,
                                                  new_data)
            self._data_array_cache.shape = new_shape
            self._data_array_cache_index = last
            # Use the global max_samples to limit the cache size.
            # NOTE(review): after truncation the cache no longer starts at
            # sample 0, yet it is sliced with absolute [first:last] below.
            # Presumably bound_array keeps first/last inside the retained
            # window -- confirm.
            max_samples = self._vars.max_samples
            self._data_array_cache = self._data_array_cache[-max_samples:]

        return self._data_array_cache[first:last]
Пример #10
0
class ListItem(HasTraits):
    """One row of a list shown with traits UI.

    The row shows a read-only ``name`` label next to an editable
    ``my_name`` field.
    """
    column_number = Int
    name = Str
    my_name = Str
    parent = Instance(HasTraits)

    # ``my_name`` stretches to fill the row and is committed on Enter only.
    view = View(
        HGroup(
            Item('name', style='readonly', show_label=False, resizable=False),
            Item('my_name',
                 style='simple',
                 show_label=False,
                 springy=True,
                 editor=TextEditor(auto_set=False, enter_set=True)),
        ),
    )
Пример #11
0
class Constraint(HasTraits):
    """Table row that pairs a parameter ``name`` with one of the run case's
    constraint variables and a ``value``."""
    name = String
    runcase = Any
    # Choices offered by the 'Constraint' column's EnumEditor.
    constraint_variables_available = Property(List(String), depends_on='runcase.constraint_variables')

    @cached_property
    def _get_constraint_variables_available(self):
        """Names of the run case's constraint variables ([] until a run case is set)."""
        # runcase is an Any trait and defaults to None; the EnumEditor reads
        # this property as soon as the row is displayed.
        if self.runcase is None:
            return []
        # Return a real list (keys() is a view on Python 3), as declared by
        # the List(String) property type.
        return list(self.runcase.constraint_variables.keys())

    constraint_name = String
    value = Float
    pattern = String
    cmd = String
    editor = TableEditor(
        auto_size=False,
        columns=[ObjectColumn(name='name', editable=False, label='Parameter'),
                 ObjectColumn(name='constraint_name', label='Constraint', editor=EnumEditor(name='constraint_variables_available')),
                 ObjectColumn(name='value', label='Value', editor=TextEditor(evaluate=float, enter_set=True, auto_set=False))
                 ])
Пример #12
0
class Sp4ArrayFileSource(ArraySource):
    """Array source backed by an sp4 file.

    ``component`` selects which transform of the file's array (presumably
    complex-valued) is exposed as ``scalar_data``: amplitude, phase, real
    or imaginary part.
    """

    file_name = Str

    component = Enum('Amp', 'Phase', 'Real', 'Imag')
    set_component = Str

    f = Instance(sp4.Sp4File)

    view = View(Group(Item(name='transpose_input_array'),
                      Item(name='file_name',
                           editor=TextEditor(auto_set=False, enter_set=True)),
                      Item(name='component', editor=EnumEditor(values=component)),
                      Item(name='scalar_name'),
                      Item(name='vector_name'),
                      Item(name='spacing'),
                      Item(name='origin'),
                      Item(name='update_image_data', show_label=False),
                      show_labels=True)
                )

    def __init__(self, **kw_args):
        super(Sp4ArrayFileSource, self).__init__(**kw_args)
        fn = kw_args.pop('file_name', None)
        if fn is not None:
            self.file_name = fn
            self._open(fn)
        self.component = "Amp"
        self._component_changed('Amp')

    def _open(self, fn):
        """Open the sp4 file ``fn`` (previously the argument was ignored)."""
        self.f = sp4.Sp4File(fn)

    def _component_changed(self, info):
        """Recompute ``scalar_data`` for component ``info`` and refresh."""
        if self.f is None:
            # No file opened yet (e.g. constructed without file_name);
            # previously this crashed on None.GetArray().
            return
        # 'Real' and 'Imag' are offered by the Enum but were silently ignored.
        transforms = {'Amp': numpy.abs,
                      'Phase': numpy.angle,
                      'Real': numpy.real,
                      'Imag': numpy.imag}
        transform = transforms.get(info)
        if transform is not None:
            self.scalar_data = transform(self.f.GetArray())
        self.update()

    def _file_name_changed(self, info):
        # Echo the new file name (same output on Python 2 and 3).
        print(self.file_name)
Пример #13
0
class Hahn(Pulsed):
    """Hahn echo measurement.

    For every tau the sequence holds two echoes:
    pi/2 - tau - pi - tau - pi/2 and pi/2 - tau - pi - tau - 3pi/2.
    Each echo is followed by a laser/trigger readout and a wait, giving two
    sequence points per tau.
    """

    t_pi2 = Range(low=1., high=100000., value=1000.,
                  desc='length of pi/2 pulse [ns]', label='pi/2 [ns]',
                  mode='text', auto_set=False, enter_set=True)
    t_pi = Range(low=1., high=100000., value=1000.,
                 desc='length of pi pulse [ns]', label='pi [ns]',
                 mode='text', auto_set=False, enter_set=True)
    t_3pi2 = Range(low=1., high=100000., value=1000.,
                   desc='length of 3pi/2 pulse [ns]', label='3pi/2 [ns]',
                   mode='text', auto_set=False, enter_set=True)
    rabi_contrast = Range(low=1., high=100, value=30.0,
                          desc='Rabi contrast [%]', label='contrast',
                          mode='text', auto_set=False, enter_set=True)

    def __init__(self):
        super(Hahn, self).__init__()

    def _get_sequence_points(self):
        # Two echoes (final pulse pi/2, then 3pi/2) per tau value.
        return len(self.tau) * 2

    def generate_sequence(self):
        """Build the pulse sequence as a list of (channels, duration) steps."""
        tau = self.tau
        laser, wait = self.laser, self.wait
        t_pi2, t_pi, t_3pi2 = self.t_pi2, self.t_pi, self.t_3pi2

        def echo(t, last_pulse):
            # pi/2 - t - pi - t - last_pulse, then laser readout and wait.
            return [(['mw'], t_pi2), ([], t), (['mw'], t_pi), ([], t),
                    (['mw'], last_pulse),
                    (['laser', 'trigger'], laser), ([], wait)]

        sequence = []
        for t in tau:
            sequence.extend(echo(t, t_pi2))
            sequence.extend(echo(t, t_3pi2))
        return sequence

    get_set_items = Pulsed.get_set_items + ['t_pi2', 't_pi', 't_3pi2']

    traits_view = View(
        VGroup(
            HGroup(
                Item('submit_button', show_label=False),
                Item('remove_button', show_label=False),
                Item('resubmit_button', show_label=False),
                Item('priority'),
            ),
            HGroup(
                Item('freq', width=40),
                Item('power', width=20),
                Item('t_pi2', width=20),
                Item('t_pi', width=20),
                Item('t_3pi2', width=20),
            ),
            HGroup(
                Item('tau_begin', width=20),
                Item('tau_end', width=20),
                Item('tau_delta', width=20),
                Item('rabi_contrast', width=20),
            ),
            HGroup(
                Item('laser', width=40),
                Item('wait', width=40),
                Item('bin_width', width=40),
            ),
            HGroup(
                Item('state', style='readonly'),
                Item('run_time', style='readonly', format_str='%.f', width=50),
                Item('sweeps',
                     editor=TextEditor(auto_set=False,
                                       enter_set=True,
                                       evaluate=float,
                                       format_func=lambda x: '%.3e' % x),
                     width=40),
                Item('expected_duration',
                     style='readonly',
                     editor=TextEditor(evaluate=float,
                                       format_func=lambda x: '%.3e' % x),
                     width=40),
                Item('progress', style='readonly'),
                Item('elapsed_time', style='readonly'),
            ),
        ),
        title='Hahn',
    )
Пример #14
0
class TVTKClassChooser(HasTraits):
    """Dialog for picking a TVTK class by name or by searching its docs.

    Typing a prefix in ``class_name`` narrows ``completions``, and a unique
    match is completed automatically. A ``search`` of at least 3 characters
    queries the class documentation through ``DocSearch`` and narrows
    ``available``.
    """

    # The selected object, is None if no valid class_name was made.
    object = Property

    # The TVTK class name to choose.
    class_name = Str('', desc='class name of TVTK class (case sensitive)')

    # The string to search for in the class docs -- the search supports
    # 'and' and 'or' keywords.
    search = Str('', desc='string to search in TVTK class documentation '
                          'supports the "and" and "or" keywords. '
                          'press <Enter> to start search. '
                          'This is case insensitive.')

    clear_search = Button

    # The class documentation.
    doc = Str(_search_help_doc)

    # Completions for the choice of class.
    completions = List(Str)

    # List of available class names as strings.
    available = List(TVTK_CLASSES)

    ########################################
    # Private traits.

    finder = Instance(DocSearch)

    n_completion = Int(25)

    ########################################
    # View related traits.

    view = View(Group(Item(name='class_name',
                           editor=EnumEditor(name='available')),
                      Item(name='class_name',
                           has_focus=True),
                      Item(name='search',
                           editor=TextEditor(enter_set=True,
                                             auto_set=False)),
                      Item(name='clear_search',
                           show_label=False),
                      Item('_'),
                      Item(name='completions',
                           editor=ListEditor(columns=3),
                           style='readonly'),
                      Item(name='doc',
                           resizable=True,
                           label='Documentation',
                           style='custom')),
                id='tvtk_doc',
                resizable=True,
                width=800,
                height=600,
                title='TVTK class chooser',
                buttons=["OK", "Cancel"])

    ######################################################################
    # `object` interface.
    ######################################################################
    def __init__(self, **traits):
        super(TVTKClassChooser, self).__init__(**traits)
        # Snapshot of the full class list, restored whenever a search is
        # too short or finds nothing.
        self._orig_available = list(self.available)

    ######################################################################
    # Non-public interface.
    ######################################################################
    def _get_object(self):
        """Instantiate the TVTK class named ``class_name``, or return None."""
        if not self.class_name:
            return None
        try:
            return getattr(tvtk, self.class_name)()
        except (AttributeError, TypeError):
            return None

    def _class_name_changed(self, value):
        """Refresh completions and documentation for the typed class name."""
        candidates = [name for name in self.available if name.startswith(value)]
        self.completions = candidates[:self.n_completion]
        if len(candidates) == 1 and candidates[0] != value:
            # Unique prefix: complete it (this re-enters the handler once).
            self.class_name = candidates[0]

        chosen = self.object
        self.doc = (_search_help_doc if chosen is None
                    else get_tvtk_class_doc(chosen))

    def _finder_default(self):
        return DocSearch()

    def _clear_search_fired(self):
        self.search = ''

    def _search_changed(self, value):
        """Narrow ``available`` to classes whose docs match ``value``."""
        # Queries shorter than 3 characters behave like an empty result.
        hits = self.finder.search(value) if len(value) >= 3 else []
        if len(hits) == 0:
            self.available = self._orig_available
        elif len(hits) == 1:
            self.class_name = hits[0]
        else:
            self.available = hits
            self.completions = hits[:self.n_completion]
Пример #15
0
class DEER(PulsedDEER):
    """Hahn echo measurement using standard pi/2-pi-pi/2 sequence + another MW pulse after pi with HMC.

    For every tau the step order is: pi/2, ``time_Hahn``, pi, a gap, the
    HMC pulse (``t_el_pi``), ``tau``, pi/2, laser readout, wait. The HMC
    pulse therefore ends ``tau`` before the final pi/2 pulse.
    """

    t_pi2 = Range(low=1.,
                  high=100000.,
                  value=1000.,
                  desc='length of pi/2 pulse [ns]',
                  label='pi/2 [ns]',
                  mode='text',
                  auto_set=False,
                  enter_set=True)
    t_pi = Range(low=1.,
                 high=100000.,
                 value=1000.,
                 desc='length of pi pulse [ns]',
                 label='pi [ns]',
                 mode='text',
                 auto_set=False,
                 enter_set=True)

    # Length of the pulse on the 'hmc_trigger' channel. It previously reused
    # t_pi's desc/label ('pi [ns]'), making the two fields indistinguishable
    # in the view.
    t_el_pi = Range(low=1.,
                    high=100000.,
                    value=1000.,
                    desc='length of HMC pi pulse [ns]',
                    label='HMC pi [ns]',
                    mode='text',
                    auto_set=False,
                    enter_set=True)

    def __init__(self):
        super(DEER, self).__init__()

    def generate_sequence(self):
        """Build the pulse sequence as a list of (channels, duration) steps."""
        time_Hahn = self.time_Hahn
        tau = self.tau
        laser = self.laser
        wait = self.wait
        t_pi2 = self.t_pi2
        t_pi = self.t_pi
        t_el_pi = self.t_el_pi
        sequence = []
        for t in tau:
            # NOTE(review): the gap before the HMC pulse is negative when
            # t + t_el_pi > time_Hahn; presumably tau_end + t_el_pi is kept
            # <= time_Hahn by the operator -- confirm.
            sequence.extend([
                (['mw'], t_pi2),
                ([], time_Hahn),
                (['mw'], t_pi),
                ([], time_Hahn - t - t_el_pi),
                (['hmc_trigger'], t_el_pi),
                ([], t),
                (['mw'], t_pi2),
                (['laser', 'trigger'], laser),
                ([], wait),
            ])
        return sequence

    get_set_items = PulsedDEER.get_set_items + ['t_pi2', 't_pi', 't_el_pi']

    traits_view = View(
        VGroup(
            HGroup(
                Item('submit_button', show_label=False),
                Item('remove_button', show_label=False),
                Item('resubmit_button', show_label=False),
                Item('priority'),
            ),
            HGroup(
                Item('freq', width=40),
                Item('power', width=20),
                Item('t_pi2', width=20),
                Item('t_pi', width=20),
                Item('time_Hahn', width=20),
            ),
            HGroup(Item('freq_HMC', width=40), Item('power_HMC', width=40),
                   Item('t_el_pi', width=40)),
            HGroup(
                Item('tau_begin', width=20),
                Item('tau_end', width=20),
                Item('tau_delta', width=20),
            ),
            HGroup(
                Item('laser', width=40),
                Item('wait', width=40),
                Item('bin_width', width=40),
            ),
            HGroup(
                Item('state', style='readonly'),
                Item('run_time', style='readonly', format_str='%.f', width=50),
                Item('sweeps',
                     editor=TextEditor(auto_set=False,
                                       enter_set=True,
                                       evaluate=float,
                                       format_func=lambda x: '%.3e' % x),
                     width=40),
                Item('expected_duration',
                     style='readonly',
                     editor=TextEditor(evaluate=float,
                                       format_func=lambda x: '%.3e' % x),
                     width=40),
                Item('progress', style='readonly'),
                Item('elapsed_time', style='readonly'),
            ),
        ),
        title='DEER',
    )
Пример #16
0
class Pulsed(ManagedJob, GetSetItemsMixin):
    """Defines a pulsed measurement."""
    keep_data = Bool(
        False)  # helper variable to decide whether to keep existing data

    resubmit_button = Button(
        label='resubmit',
        desc=
        'Submits the measurement to the job manager. Tries to keep previously acquired data. Behaves like a normal submit if sequence or time bins have changed since previous run.'
    )

    sequence = Instance(list, factory=list)

    record_length = Float(value=0,
                          desc='length of acquisition record [ms]',
                          label='record length [ms] ',
                          mode='text')

    count_data = Array(value=np.zeros(2))

    run_time = Float(value=0.0, label='run time [ns]', format_str='%.f')
    stop_time = Range(
        low=1.,
        value=np.inf,
        desc='Time after which the experiment stops by itself [s]',
        label='Stop time [s]',
        mode='text',
        auto_set=False,
        enter_set=True)

    tau_begin = Range(low=0.,
                      high=1e5,
                      value=300.,
                      desc='tau begin [ns]',
                      label='repetition',
                      mode='text',
                      auto_set=False,
                      enter_set=True)
    tau_end = Range(low=1.,
                    high=1e5,
                    value=4000.,
                    desc='tau end [ns]',
                    label='N repetition',
                    mode='text',
                    auto_set=False,
                    enter_set=True)
    tau_delta = Range(low=1.,
                      high=1e5,
                      value=50.,
                      desc='delta tau [ns]',
                      label='delta',
                      mode='text',
                      auto_set=False,
                      enter_set=True)

    tau = Array(value=np.array((0., 1.)))
    sequence_points = Int(value=2, label='number of points', mode='text')

    laser_SST = Range(low=1.,
                      high=5e6,
                      value=200.,
                      desc='laser for SST [ns]',
                      label='laser_SST[ns]',
                      mode='text',
                      auto_set=False,
                      enter_set=True)
    wait_SST = Range(low=1.,
                     high=5e6,
                     value=1000.,
                     desc='wait for SST[ns]',
                     label='wait_SST [ns]',
                     mode='text',
                     auto_set=False,
                     enter_set=True)
    N_shot = Range(low=1,
                   high=20e5,
                   value=2e3,
                   desc='number of shots in SST',
                   label='N_shot',
                   mode='text',
                   auto_set=False,
                   enter_set=True)

    laser = Range(low=1.,
                  high=5e4,
                  value=3000,
                  desc='laser [ns]',
                  label='laser [ns]',
                  mode='text',
                  auto_set=False,
                  enter_set=True)
    wait = Range(low=1.,
                 high=5e4,
                 value=5000.,
                 desc='wait [ns]',
                 label='wait [ns]',
                 mode='text',
                 auto_set=False,
                 enter_set=True)

    freq_center = Range(low=1,
                        high=20e9,
                        value=2.71e9,
                        desc='frequency [Hz]',
                        label='MW freq[Hz]',
                        editor=TextEditor(auto_set=False,
                                          enter_set=True,
                                          evaluate=float,
                                          format_str='%.4e'))
    power = Range(low=-100.,
                  high=25.,
                  value=-26,
                  desc='power [dBm]',
                  label='power[dBm]',
                  editor=TextEditor(auto_set=False,
                                    enter_set=True,
                                    evaluate=float))
    freq = Range(low=1,
                 high=20e9,
                 value=2.71e9,
                 desc='frequency [Hz]',
                 label='freq [Hz]',
                 editor=TextEditor(auto_set=False,
                                   enter_set=True,
                                   evaluate=float,
                                   format_str='%.4e'))
    pi = Range(low=0.,
               high=5e4,
               value=2e3,
               desc='pi pulse length',
               label='pi [ns]',
               mode='text',
               auto_set=False,
               enter_set=True)

    amp = Range(low=0.,
                high=1.0,
                value=1.0,
                desc='Normalized amplitude of waveform',
                label='Amp',
                mode='text',
                auto_set=False,
                enter_set=True)
    vpp = Range(low=0.,
                high=4.5,
                value=0.6,
                desc='Amplitude of AWG [Vpp]',
                label='Vpp',
                mode='text',
                auto_set=False,
                enter_set=True)

    sweeps = Range(low=1.,
                   high=1e4,
                   value=1e2,
                   desc='number of sweeps',
                   label='sweeps',
                   mode='text',
                   auto_set=False,
                   enter_set=True)
    expected_duration = Property(
        trait=Float,
        depends_on='sweeps,sequence',
        desc='expected duration of the measurement [s]',
        label='expected duration [s]')
    elapsed_sweeps = Float(value=0,
                           desc='Elapsed Sweeps ',
                           label='Elapsed Sweeps ',
                           mode='text')
    elapsed_time = Float(value=0,
                         desc='Elapsed Time [ns]',
                         label='Elapsed Time [ns]',
                         mode='text')
    progress = Int(value=0,
                   desc='Progress [%]',
                   label='Progress [%]',
                   mode='text')

    load_button = Button(desc='compile and upload waveforms to AWG',
                         label='load')
    reload = True

    readout_interval = Float(
        1,
        label='Data readout interval [s]',
        desc='How often data read is requested from nidaq')
    samples_per_read = Int(
        200,
        label='# data points per read',
        desc=
        'Number of data points requested from nidaq per read. Nidaq will automatically wait for the data points to be aquired.'
    )

    def submit(self):
        """Submit a fresh run to the JobManager.

        Clearing ``keep_data`` makes apply_parameters() reset the counters.
        """
        self.keep_data = False
        ManagedJob.submit(self)

    def resubmit(self):
        """Submit to the JobManager, asking to keep previously acquired data.

        apply_parameters() keeps the data only if the sequence is unchanged.
        """
        self.keep_data = True
        ManagedJob.submit(self)

    def _resubmit_button_fired(self):
        """React to the resubmit button by resubmitting the job."""
        self.resubmit()

    def generate_sequence(self):
        """Return the pulse sequence as a list of (channels, duration) steps.

        The base measurement has no steps; subclasses override this.
        """
        return list()

    def prepare_awg(self):
        """Hook to prepare the AWG before a run; subclasses override this.

        The base implementation only resets the AWG.
        """
        AWG.reset()

    def _load_button_changed(self):
        """Handle the load button by delegating to load()."""
        self.load()

    def load(self):
        """Recompute record length and tau, then prepare the AWG.

        ``reload`` is True while this runs and False once prepare_awg()
        returns.
        """
        self.reload = True
        # One shot lasts pi + laser_SST + wait_SST ns; record_length is the
        # length of N_shot shots in ms (ns * 1e-6).
        shot_length = self.pi + self.laser_SST + self.wait_SST
        self.record_length = self.N_shot * shot_length * 1e-6
        # Make sure tau reflects the current begin/end/delta settings.
        self.tau = np.arange(self.tau_begin, self.tau_end, self.tau_delta)
        self.prepare_awg()
        self.reload = False

    @cached_property
    def _get_expected_duration(self):
        """Expected run time in seconds: sweeps x total step duration (ns)."""
        total_ns = sum(step[1] for step in self.sequence)
        return self.sweeps * total_ns * 1e-9

    def _get_sequence_points(self):
        """Number of sequence points per sweep: one per tau value by default."""
        points = len(self.tau)
        return points

    def apply_parameters(self):
        """Apply the current parameters and decide whether to keep previous data.

        If the load button was not used, ``tau`` still holds its
        two-element default. In that case it is generated here from
        tau_begin/tau_end/tau_delta. Previous data is kept only when
        ``keep_data`` is set and the new sequence equals the previous one;
        otherwise the data arrays and counters are reset.
        """
        if self.tau.shape[0] == 2:
            self.tau = np.arange(self.tau_begin, self.tau_end, self.tau_delta)

        self.sequence_points = self._get_sequence_points()
        self.measurement_points = self.sequence_points * int(self.sweeps)
        new_sequence = self.generate_sequence()

        # Same sequence as the previous run: continue from existing data.
        reuse_data = self.keep_data and new_sequence == self.sequence
        if reuse_data:
            self.previous_sweeps = self.elapsed_sweeps
            self.previous_elapsed_time = self.elapsed_time
        else:
            n_points = self.measurement_points
            self.count_data = np.zeros(n_points)
            self.old_count_data = np.zeros(n_points)
            self.previous_sweeps = 0
            self.previous_elapsed_time = 0.0
            self.run_time = 0.0

        # When the job manager stops and restarts the job, data should be
        # kept; only a new submission (submit()) clears it.
        self.keep_data = True
        self.sequence = new_sequence

    def _run(self):
        """Acquire gated counts until all points are read or a stop is requested.

        Setup:
        - Configures the AWG, microwave source and pulse generator.
        - Starts the nidaq gated counting task.

        Acquisition: every ``readout_interval`` seconds, reads at most
        ``samples_per_read`` new points into ``count_data`` and updates
        elapsed_time, run_time, elapsed_sweeps and progress.

        Final state:
        - 'done' when the requested sweeps are reached.
        - 'idle' when stopped earlier.
        - 'error' on failure.

        The counting task is always stopped on exit.
        """
        try:  # run the acquisition from start-up to shut-down
            self.state = 'run'
            self.apply_parameters()

            PG.High([])

            self.prepare_awg()
            MW.setFrequency(self.freq_center)
            MW.setPower(self.power)

            AWG.run()
            # NOTE(review): fixed delay before starting the pulse sequence,
            # presumably to let the AWG arm -- confirm it is still needed.
            time.sleep(4.0)
            PG.Sequence(self.sequence, loop=True)

            # initialize and start nidaq gated counting task (0 == success)
            if CS.configure() != 0:
                print('error in nidaq')
                # without this the job would be left reporting 'run'
                self.state = 'error'
                return

            start_time = time.time()
            last_time = start_time

            acquired_data = np.empty(0)  # new data is appended to this array

            while True:
                self.thread.stop_request.wait(self.readout_interval)
                if self.thread.stop_request.isSet():
                    logging.getLogger().debug('Caught stop signal. Exiting.')
                    break

                points_left = self.measurement_points - len(acquired_data)

                now = time.time()
                self.elapsed_time = self.previous_elapsed_time + now - start_time
                # Add only the time since the previous readout. Adding the
                # cumulative elapsed_time on every pass over-counts.
                self.run_time += now - last_time
                last_time = now

                # do not attempt to read more data than necessary
                new_data = CS.read_gated_counts(
                    SampleLength=min(self.samples_per_read, points_left))

                acquired_data = np.append(
                    acquired_data, new_data[:min(len(new_data), points_left)])

                # The trace length must not change because of plotting, so
                # copy the acquired data into the front of the existing trace.
                self.count_data[:len(acquired_data)] = acquired_data

                # sweeps completed in this run (integer division under
                # Python 2 when both operands are ints)
                sweeps = len(acquired_data) / self.sequence_points
                # assign rather than accumulate: `sweeps` is already a total
                self.elapsed_sweeps = self.previous_sweeps + sweeps
                self.progress = int(100 * len(acquired_data) /
                                    self.measurement_points)

                if self.progress > 99.9:
                    break

            MW.Off()
            PG.High(['laser', 'mw'])
            AWG.stop()

            if self.elapsed_sweeps < self.sweeps:
                self.state = 'idle'
            else:
                self.state = 'done'

        except Exception:  # log the failure and flag the job as errored
            logging.getLogger().exception(
                'Something went wrong in pulsed loop.')
            self.state = 'error'

        finally:
            CS.stop_gated_counting()  # stop nidaq task to free counters

    # Attribute names persisted with the measurement (presumably consumed by
    # the base class's get/set-items save/load mechanism -- confirm).
    get_set_items = [
        '__doc__',
        'record_length',
        'laser',
        'wait',
        'sequence',
        'count_data',
        'run_time',
        'tau_begin',
        'tau_end',
        'tau_delta',
        'tau',
        'freq_center',
        'power',
        'laser_SST',
        'wait_SST',
        'amp',
        'vpp',
        'pi',
        'freq',
        'N_shot',
        'readout_interval',
        'samples_per_read',
    ]

    # Main TraitsUI layout of the measurement; each HGroup is one row of the
    # window, in top-to-bottom order.
    traits_view = View(
        VGroup(
            # job controls
            HGroup(
                Item('load_button', show_label=False),
                Item('submit_button', show_label=False),
                Item('remove_button', show_label=False),
                Item('resubmit_button', show_label=False),
                Item('priority'),
            ),
            # frequency / AWG amplitude / MW power / pi-pulse parameters
            HGroup(
                Item('freq', width=-70),
                Item('freq_center', width=-70),
                Item('amp', width=-30),
                Item('vpp', width=-30),
                Item('power', width=-40),
                Item('pi', width=-70),
            ),
            # laser and wait timings
            HGroup(
                Item('laser', width=-60),
                Item('wait', width=-60),
                Item('laser_SST', width=-50),
                Item('wait_SST', width=-50),
            ),
            # acquisition settings
            HGroup(
                Item('samples_per_read', width=-50),
                Item('N_shot', width=-50),
                Item('record_length', style='readonly'),
            ),
            # tau sweep
            HGroup(
                Item('tau_begin', width=30),
                Item('tau_end', width=30),
                Item('tau_delta', width=30),
            ),
            # run status
            HGroup(
                Item('state', style='readonly'),
                Item('run_time', style='readonly', format_str='%.f',
                     width=-60),
                Item('sweeps',
                     editor=TextEditor(auto_set=False,
                                       enter_set=True,
                                       evaluate=float,
                                       format_func=lambda x: '%.3e' % x),
                     width=-50),
                Item('expected_duration',
                     style='readonly',
                     editor=TextEditor(evaluate=float,
                                       format_func=lambda x: '%.2f' % x),
                     width=30),
                Item('progress', style='readonly'),
                Item('elapsed_time',
                     style='readonly',
                     editor=TextEditor(auto_set=False,
                                       enter_set=True,
                                       evaluate=float,
                                       format_func=lambda x: ' %.f' % x),
                     width=-50),
            ),
        ),
        title='Pulsed_SST Measurement',
    )
Пример #17
0
class ScatterPlotNM(MutableTemplate):
    """ A mutable template showing rows x columns scatter plots in a grid.

        All contained plots share the first plot's index data source; each
        plot's value source is optional and labelled with its [row,column]
        cell.
    """

    #-- Template Traits --------------------------------------------------------

    # The title of the plot:
    title = TStr('NxM Scatter Plots')

    # The type of marker to use.  This is a mapped trait using strings as the
    # keys:
    marker = marker_trait(template='copy', event='update')

    # The pixel size of the marker (doesn't include the thickness of the
    # outline):
    marker_size = TRange(1, 5, 1, event='update')

    # The thickness, in pixels, of the outline to draw around the marker.  If
    # this is 0, no outline will be drawn.
    line_width = TRange(0.0, 5.0, 1.0)

    # The fill color of the marker:
    color = TColor('red', event='update')

    # The color of the outline to draw around the marker
    outline_color = TColor('black', event='update')

    # The number of rows of plots:
    rows = TRange(1, 3, 1, event='grid')

    # The number of columns of plots:
    columns = TRange(1, 5, 1, event='grid')

    # The contained scatter plots (row-major, one per grid cell):
    scatter_plots = TList(ScatterPlot)

    #-- Derived Traits ---------------------------------------------------------

    plot = TDerived

    #-- Traits UI Views --------------------------------------------------------

    # The scatter plot view:
    template_view = View(VGroup(
        Item('title',
             show_label=False,
             style='readonly',
             editor=ThemedTextEditor(theme=Theme('@GBB', alignment='center'))),
        Item('plot',
             show_label=False,
             resizable=True,
             editor=EnableEditor(),
             item_theme=Theme('@GF5', margins=0))),
                         resizable=True)

    # The scatter plot options view:
    options_view = View(
        VGroup(
            VGroup(Label('Scatter Plot Options',
                         item_theme=Theme('@GBB', alignment='center')),
                   show_labels=False),
            VGroup(Item('title', editor=TextEditor()),
                   Item('marker'),
                   Item('marker_size', editor=ThemedSliderEditor()),
                   Item('line_width',
                        label='Line Width',
                        editor=ThemedSliderEditor()),
                   Item('color', label='Fill Color'),
                   Item('outline_color', label='Outline Color'),
                   Item('rows', editor=ThemedSliderEditor()),
                   Item('columns', editor=ThemedSliderEditor()),
                   group_theme=Theme('@GF5', margins=(-5, -1)),
                   item_theme=Theme('@G0B', margins=0))))

    #-- ITemplate Interface Implementation -------------------------------------

    def activate_template(self):
        """ Builds the grid container from the contained scatter plots.

            Each cell shows the matching scatter plot's derived 'plot'. A
            cell whose plot is None gets an empty PlotComponent placeholder.

            Returns
            -------
            None
        """
        n_cols = self.columns
        grid = []
        for r in range(self.rows):
            cells = []
            for c in range(n_cols):
                cell = self.scatter_plots[r * n_cols + c].plot
                cells.append(cell if cell is not None else PlotComponent())
            grid.append(cells)

        self.plot = GridPlotContainer(shape=(self.rows, self.columns))
        self.plot.component_grid = grid

    #-- Default Values ---------------------------------------------------------

    def _scatter_plots_default(self):
        """ Returns one new ScatterPlot per grid cell, with labels and the
            shared index already configured.
        """
        plots = [ScatterPlot() for _ in range(self.rows * self.columns)]
        self._update_plots(plots)
        return plots

    #-- Trait Event Handlers ---------------------------------------------------

    def _update_changed(self, name, old, new):
        """ Forwards a changed plot option to every contained scatter plot
            and invalidates the derived grid plot.
        """
        for scatter in self.scatter_plots:
            setattr(scatter, name, new)
        self.plot = Undefined

    def _grid_changed(self):
        """ Resizes the scatter plot list to rows x columns, refreshes the
            data source labels and flags the template as mutated.
        """
        wanted = self.rows * self.columns
        plots = self.scatter_plots
        if wanted >= len(plots):
            for _ in range(len(plots), wanted):
                plots.append(ScatterPlot())
        else:
            self.scatter_plots = plots[:wanted]

        self._update_plots(plots)
        self.template_mutated = True

    #-- Private Methods --------------------------------------------------------

    def _update_plots(self, plots):
        """ Labels each plot's value source with its [row,column] cell,
            marks it optional, and makes every plot share the first plot's
            index source.
        """
        shared_index = None
        n_cols = self.columns
        for position in range(self.rows * n_cols):
            r, c = divmod(position, n_cols)
            sp = plots[position]

            label = sp.value.description
            cut = label.rfind('[')
            if cut >= 0:
                label = label[:cut]  # drop a previous "[r,c]" suffix
            sp.value.description = '%s[%d,%d]' % (label, r, c)
            sp.value.optional = True

            if shared_index is not None:
                sp.index = shared_index
            else:
                shared_index = sp.index
                shared_index.description = 'Shared Plot Index'
                shared_index.optional = True
Пример #18
0
class ScatterPlot2(Template):
    """ A template showing two scatter plots side by side.

        The second plot reuses the first plot's index data source; its value
        source is optional.
    """

    #-- Template Traits --------------------------------------------------------

    # The title of the plot:
    title = TStr('Dual Scatter Plots')

    # The type of marker to use.  This is a mapped trait using strings as the
    # keys:
    marker = marker_trait(template='copy', event='update')

    # The pixel size of the marker (doesn't include the thickness of the
    # outline):
    marker_size = TRange(1, 5, 1, event='update')

    # The thickness, in pixels, of the outline to draw around the marker.  If
    # this is 0, no outline will be drawn.
    line_width = TRange(0.0, 5.0, 1.0)

    # The fill color of the marker:
    color = TColor('red', event='update')

    # The color of the outline to draw around the marker
    outline_color = TColor('black', event='update')

    # The amount of space between plots:
    spacing = TRange(0.0, 20.0, 0.0)

    # The contained scatter plots:
    scatter_plot_1 = TInstance(ScatterPlot, ())
    scatter_plot_2 = TInstance(ScatterPlot, ())

    #-- Derived Traits ---------------------------------------------------------

    plot = TDerived

    #-- Traits UI Views --------------------------------------------------------

    # The scatter plot view:
    template_view = View(VGroup(
        Item('title',
             show_label=False,
             style='readonly',
             editor=ThemedTextEditor(theme=Theme('@GBB', alignment='center'))),
        Item('plot',
             show_label=False,
             resizable=True,
             editor=EnableEditor(),
             item_theme=Theme('@GF5', margins=0))),
                         resizable=True)

    # The scatter plot options view:
    options_view = View(
        VGroup(
            VGroup(Label('Scatter Plot Options',
                         item_theme=Theme('@GBB', alignment='center')),
                   show_labels=False),
            VGroup(Item('title', editor=TextEditor()),
                   Item('marker'),
                   Item('marker_size', editor=ThemedSliderEditor()),
                   Item('line_width',
                        label='Line Width',
                        editor=ThemedSliderEditor()),
                   Item('spacing', editor=ThemedSliderEditor()),
                   Item('color', label='Fill Color'),
                   Item('outline_color', label='Outline Color'),
                   group_theme=Theme('@GF5', margins=(-5, -1)),
                   item_theme=Theme('@G0B', margins=0))))

    #-- ITemplate Interface Implementation -------------------------------------

    def activate_template(self):
        """ Builds the derived plot from whichever contained plots exist.

            With both plots available they are placed in an HPlotContainer
            using 'spacing'. With one, that plot is used directly. With none,
            'plot' is left untouched.

            Returns
            -------
            None
        """
        available = [candidate
                     for candidate in (self.scatter_plot_1.plot,
                                       self.scatter_plot_2.plot)
                     if candidate is not None]
        if len(available) == 1:
            self.plot = available[0]
        elif len(available) == 2:
            self.plot = HPlotContainer(spacing=self.spacing)
            self.plot.add(*available)

    #-- Default Values ---------------------------------------------------------

    def _scatter_plot_1_default(self):
        """ Builds the first plot, whose index source is the shared one.
        """
        first = ScatterPlot()
        first.index.description = 'Shared Plot Index'
        first.value.description += ' 1'
        return first

    def _scatter_plot_2_default(self):
        """ Builds the second plot on the first plot's index source, with an
            optional value source.
        """
        shared_index = self.scatter_plot_1.index
        second = ScatterPlot(index=shared_index)
        second.value.description += ' 2'
        second.value.optional = True
        return second

    #-- Trait Event Handlers ---------------------------------------------------

    def _update_changed(self, name, old, new):
        """ Forwards a changed plot option to both plots and invalidates the
            derived plot.
        """
        for target in (self.scatter_plot_1, self.scatter_plot_2):
            setattr(target, name, new)
        self.plot = Undefined

    def _spacing_changed(self, spacing):
        """ Invalidates the derived plot so the new spacing is applied.
        """
        self.plot = Undefined
Пример #19
0
class SSTCounterTrace(Pulsed):

    # Single-shot readout (SST) trace measurement. For every tau point the
    # pulse generator plays laser -> wait -> 'awgTrigger' -> 'sst' gate. The
    # AWG plays, per tau point, a sub-sequence repeating a 0-phase and a
    # 90-phase pi readout N_shot times.
    # NOTE: documented with comments rather than a class docstring because
    # '__doc__' is one of the persisted get_set_items.

    tau_begin = Range(low=0.,
                      high=1e5,
                      value=1.,
                      desc='tau begin [ns]',
                      label='repetition',
                      mode='text',
                      auto_set=False,
                      enter_set=True)
    tau_end = Range(low=1.,
                    high=1e5,
                    value=1000.,
                    desc='tau end [ns]',
                    label='N repetition',
                    mode='text',
                    auto_set=False,
                    enter_set=True)
    tau_delta = Range(low=1.,
                      high=1e5,
                      value=1,
                      desc='delta tau [ns]',
                      label='delta',
                      mode='text',
                      auto_set=False,
                      enter_set=True)

    sweeps = Range(low=1.,
                   high=1e4,
                   value=1,
                   desc='number of sweeps',
                   label='sweeps',
                   mode='text',
                   auto_set=False,
                   enter_set=True)

    def prepare_awg(self):
        """Configure the AWG and, when ``reload`` is set, rebuild its program.

        On reload the AWG is stopped, its outputs are switched off and its
        memory is cleared. It then receives:
        - two waveforms, 'read_x' and 'read_y', each a pi pulse (phase 0 or
          pi/2), an idle with marker1 set for laser_SST, and an idle for
          wait_SST;
        - one sub-sequence per tau point repeating both waveforms N_shot
          times, all chained in 'SST.SEQ'.

        Amplitude, sample rate, sequence mode and the output mask are applied
        on every call.
        """
        sampling = 1.2e9
        N_shot = int(self.N_shot)

        # Convert durations to AWG sample counts at `sampling` samples/s.
        # The 1e9 divisor suggests the traits are in ns -- TODO confirm.
        pi = int(self.pi * sampling / 1.0e9)
        laser_SST = int(self.laser_SST * sampling / 1.0e9)
        wait_SST = int(self.wait_SST * sampling / 1.0e9)

        if self.reload:
            AWG.stop()
            AWG.set_output(0b0000)
            AWG.delete_all()

            self.waves = []

            # frequency offset from freq_center, in cycles per sample
            detuning = (self.freq - self.freq_center) / sampling
            pi_0 = Sin(pi, detuning, 0, self.amp)
            pi_90 = Sin(pi, detuning, np.pi / 2, self.amp)

            read_x = Waveform(
                'read_x',
                [pi_0,
                 Idle(laser_SST, marker1=1),
                 Idle(wait_SST)])
            read_y = Waveform(
                'read_y',
                [pi_90,
                 Idle(laser_SST, marker1=1),
                 Idle(wait_SST)])
            self.waves.append(read_x)
            self.waves.append(read_y)

            self.main_seq = Sequence('SST.SEQ')
            for i, _ in enumerate(self.tau):
                sub_seq = Sequence('DQH_12_%04i.SEQ' % i)
                sub_seq.append(read_x, read_y, repeat=N_shot)
                AWG.upload(sub_seq)

                self.main_seq.append(sub_seq, wait=True)
            for w in self.waves:
                w.join()
            AWG.upload(self.waves)
            AWG.upload(self.main_seq)
            AWG.tell('*WAI')
            AWG.load('SST.SEQ')
        AWG.set_vpp(self.vpp)
        AWG.set_sample(sampling / 1.0e9)
        AWG.set_mode('S')
        AWG.set_output(0b0011)

    def generate_sequence(self):
        """Return the pulse-generator sequence as (channels, duration) tuples.

        Per sequence point, four steps are emitted:
        - laser;
        - wait;
        - a 100-unit 'awgTrigger';
        - an 'sst' gate of record_length.

        record_length (ms per the base class's load) is scaled by 1e6,
        presumably to ns.
        """
        points = int(self.sequence_points)
        laser = self.laser
        wait = self.wait
        record_length = self.record_length * 1e+6

        sequence = []
        for _ in range(points):
            sequence.append((['laser'], laser))
            sequence.append(([], wait))
            sequence.append((['awgTrigger'], 100))
            sequence.append((['sst'], record_length))

        return sequence

    get_set_items = Pulsed.get_set_items

    traits_view = View(
        VGroup(
            HGroup(
                Item('load_button', show_label=False),
                Item('submit_button', show_label=False),
                Item('remove_button', show_label=False),
                Item('resubmit_button', show_label=False),
                Item('priority'),
            ),
            HGroup(
                Item('freq', width=-70),
                Item('freq_center', width=-70),
                Item('amp', width=-30),
                Item('vpp', width=-30),
                Item('power', width=-40),
                Item('pi', width=-70),
            ),
            HGroup(
                Item('laser', width=-60),
                Item('wait', width=-60),
                Item('laser_SST', width=-50),
                Item('wait_SST', width=-50),
            ),
            HGroup(
                Item('samples_per_read', width=-50),
                Item('N_shot', width=-50),
                Item('record_length', style='readonly'),
            ),
            HGroup(
                Item('tau_begin', width=30),
                Item('tau_end', width=30),
                Item('tau_delta', width=30),
            ),
            HGroup(
                Item('state', style='readonly'),
                Item('run_time', style='readonly', format_str='%.f',
                     width=-50),
                Item('sweeps',
                     editor=TextEditor(auto_set=False,
                                       enter_set=True,
                                       evaluate=float,
                                       format_func=lambda x: '%.1e' % x),
                     width=-50),
                Item('expected_duration',
                     style='readonly',
                     editor=TextEditor(evaluate=float,
                                       format_func=lambda x: '%.2f' % x),
                     width=30),
                Item('progress', style='readonly'),
                Item('elapsed_time',
                     style='readonly',
                     editor=TextEditor(auto_set=False,
                                       enter_set=True,
                                       evaluate=float,
                                       format_func=lambda x: ' %.f' % x),
                     width=-50),
            ),
        ),
        title='SST Trace Measurement',
    )
Пример #20
0
class ScatterPlot(Template):
    """ A template for a single Chaco scatter plot of 'value' against 'index'.
    """

    #-- Template Traits --------------------------------------------------------

    # The plot index data source:
    index = TDataSource

    # The plot value data source:
    value = TDataSource

    # The title of the plot:
    title = TStr('Scatter Plot')

    # The type of marker to use.  This is a mapped trait using strings as the
    # keys:
    marker = marker_trait(template='copy', event='update')

    # The pixel size of the marker (doesn't include the thickness of the
    # outline):
    marker_size = TRange(1, 5, 1, event='update')

    # The thickness, in pixels, of the outline to draw around the marker.  If
    # this is 0, no outline will be drawn.
    line_width = TRange(0.0, 5.0, 1.0)

    # The fill color of the marker:
    color = TColor('red', event='update')

    # The color of the outline to draw around the marker
    outline_color = TColor('black', event='update')

    #-- Derived Traits ---------------------------------------------------------

    plot = TDerived  # Instance( ScatterPlot )

    #-- Traits UI Views --------------------------------------------------------

    # The scatter plot view:
    template_view = View(VGroup(
        Item('title',
             show_label=False,
             style='readonly',
             editor=ThemedTextEditor(theme=Theme('@GBB', alignment='center'))),
        Item('plot',
             show_label=False,
             resizable=True,
             editor=EnableEditor(),
             item_theme=Theme('@GF5', margins=0))),
                         resizable=True)

    # The scatter plot options view:
    options_view = View(
        VGroup(
            VGroup(Label('Scatter Plot Options',
                         item_theme=Theme('@GBB', alignment='center')),
                   show_labels=False),
            VGroup(Item('title', editor=TextEditor()),
                   Item('marker'),
                   Item('marker_size', editor=ThemedSliderEditor()),
                   Item('line_width',
                        label='Line Width',
                        editor=ThemedSliderEditor()),
                   Item('color', label='Fill Color'),
                   Item('outline_color', label='Outline Color'),
                   group_theme=Theme('@GF5', margins=(-5, -1)),
                   item_theme=Theme('@G0B', margins=0))))

    #-- Default Values ---------------------------------------------------------

    def _index_default(self):
        """ Default 'index' source: a flattened value item named 'index'.
        """
        item = ValueDataNameItem(name='index', flatten=True)
        return TemplateDataSource(items=[item],
                                  description='Scatter Plot Index')

    def _value_default(self):
        """ Default 'value' source: a flattened value item named 'value'.
        """
        item = ValueDataNameItem(name='value', flatten=True)
        return TemplateDataSource(items=[item],
                                  description='Scatter Plot Value')

    #-- ITemplate Interface Implementation -------------------------------------

    def activate_template(self):
        """ Creates the Chaco scatter plot from the bound data sources.

            Returns without creating anything while either source's
            context_data is still Undefined (e.g. an optional source that was
            left unbound). Otherwise it builds a Plot using the template's
            marker options and attaches a shift-constrained pan tool and a
            box-zoom overlay.

            Returns
            -------
            None
        """
        # checked one at a time, index first, like a short-circuit 'or'
        if self.index.context_data is Undefined:
            return
        if self.value.context_data is Undefined:
            return

        data = ArrayPlotData()
        data.set_data('index', self.index.context_data)
        data.set_data('value', self.value.context_data)

        chart = Plot(data)
        self.plot = chart
        chart.plot(('index', 'value'),
                   type='scatter',
                   index_sort='ascending',
                   marker=self.marker,
                   color=self.color,
                   outline_color=self.outline_color,
                   marker_size=self.marker_size,
                   line_width=self.line_width,
                   bgcolor='white')
        chart.set(padding_left=50,
                  padding_right=0,
                  padding_top=0,
                  padding_bottom=20)

        # interaction tools
        chart.tools.append(PanTool(chart, constrain_key='shift'))
        box_zoom = SimpleZoom(component=chart, tool_mode='box', always_on=False)
        chart.overlays.append(box_zoom)

    #-- Trait Event Handlers ---------------------------------------------------

    def _update_changed(self):
        """ Invalidates the derived plot after a plot option changes.
        """
        self.plot = Undefined
Пример #21
0
class MPLPlot(Viewer):
    """
      A plot, contains code to display using a Matplotlib figure and to update itself
      dynamically from a Variables instance (which must be passed in on initialisation).
      The function plotted is calculated using 'expr' which should also be set on init
      and can be any python expression using the variables in the pool. Several
      expressions may be given separated by commas; each is drawn as its own line.
  """
    name = Str('MPL Plot')
    figure = Instance(Figure, ())
    expr = Str

    x_max = Float
    x_max_auto = Bool(True)
    x_min = Float
    x_min_auto = Bool(True)
    y_max = Float
    y_max_auto = Bool(True)
    y_min = Float
    y_min_auto = Bool(True)

    scroll = Bool(True)
    scroll_width = Float(300)

    legend = Bool(False)
    legend_pos = Enum('upper left', 'upper right', 'lower left', 'lower right',
                      'right', 'center left', 'center right', 'lower center',
                      'upper center', 'center', 'best')

    traits_view = View(Item(name='name', label='Plot name'),
                       Item(name='expr',
                            label='Expression(s)',
                            editor=TextEditor(enter_set=True, auto_set=False)),
                       Item(label='Use commas\nfor multi-line plots.'),
                       HGroup(
                           Item(name='legend', label='Show legend'),
                           Item(name='legend_pos', show_label=False),
                       ),
                       VGroup(HGroup(Item(name='x_max', label='Max'),
                                     Item(name='x_max_auto', label='Auto')),
                              HGroup(Item(name='x_min', label='Min'),
                                     Item(name='x_min_auto', label='Auto')),
                              HGroup(
                                  Item(name='scroll', label='Scroll'),
                                  Item(name='scroll_width',
                                       label='Scroll width'),
                              ),
                              label='X',
                              show_border=True),
                       VGroup(HGroup(Item(name='y_max', label='Max'),
                                     Item(name='y_max_auto', label='Auto')),
                              HGroup(Item(name='y_min', label='Min'),
                                     Item(name='y_min_auto', label='Auto')),
                              label='Y',
                              show_border=True),
                       title='Plot settings',
                       resizable=True)

    view = View(Item(name='figure', editor=MPLFigureEditor(),
                     show_label=False),
                width=400,
                height=300,
                resizable=True)

    legend_prop = matplotlib.font_manager.FontProperties(size=8)

    def start(self):
        """Create an empty single-axes plot that update() fills in later."""
        axes = self.figure.add_subplot(111)
        axes.plot([0], [0])

    def update(self):
        """
        Update the plot from the Variables instance and make a call to wx to
        redraw the figure.

        Each comma-separated expression is evaluated over the selected sample
        window and drawn against 'sample_num'. Axis limits flagged as auto are
        recomputed from the plotted data.
        """
        axes = self.figure.gca()
        lines = axes.get_lines()

        if lines:
            exprs = self.get_exprs()
            # make sure there is one line per expression
            if len(exprs) > len(lines):
                for i in range(len(exprs) - len(lines)):
                    axes.plot([0], [0])
                lines = axes.get_lines()

            # running extents; they start at 0, so 0 always stays in range
            max_xs = max_ys = min_xs = min_ys = 0

            for n, expr in enumerate(exprs):
                first = 0
                last = None
                if self.scroll and self.x_min_auto and self.x_max_auto:
                    first = -self.scroll_width
                if not self.x_min_auto:
                    first = int(self.x_min)
                if not self.x_max_auto:
                    last = int(self.x_max) + 1

                ys = self.variables.new_expression(expr).get_array(first, last)

                if len(ys) != 0:
                    xs = self.variables.new_expression('sample_num').get_array(
                        first, last)
                else:
                    xs = [0]
                    ys = [0]

                if len(xs) != len(ys):
                    print("MPL Plot: x and y arrays different sizes!!! Ignoring (but fix me soon).")
                    return

                lines[n].set_xdata(xs)
                lines[n].set_ydata(ys)

                max_xs = max_xs if (max(xs) < max_xs) else max(xs)
                max_ys = max_ys if (max(ys) < max_ys) else max(ys)
                min_xs = min_xs if (min(xs) > min_xs) else min(xs)
                min_ys = min_ys if (min(ys) > min_ys) else min(ys)

            if self.x_max_auto:
                self.x_max = max_xs
            if self.x_min_auto:
                if self.scroll and self.x_max_auto:
                    scroll_x_min = self.x_max - self.scroll_width
                    self.x_min = scroll_x_min if (scroll_x_min >= 0) else 0
                else:
                    self.x_min = min_xs
            if self.y_max_auto:
                self.y_max = max_ys
            if self.y_min_auto:
                self.y_min = min_ys

            axes.set_xbound(upper=self.x_max, lower=self.x_min)
            axes.set_ybound(upper=self.y_max * 1.1, lower=self.y_min * 1.1)

            self.draw_plot()

    def get_exprs(self):
        """Return the individual expressions from the comma-separated 'expr'."""
        return self.expr.split(',')

    def add_expr(self, expr):
        """Append ``expr`` to the comma-separated expression list.

        A separating comma is inserted unless the list is empty or already
        ends with a comma. (The previous check compared everything except the
        last character with ',', so a trailing comma was doubled.)
        """
        if not self.expr or self.expr.endswith(','):
            self.expr += expr
        else:
            self.expr += ',' + expr

    def draw_plot(self):
        """Schedule a canvas redraw via CallAfter, if a canvas is attached."""
        if self.figure.canvas:
            CallAfter(self.figure.canvas.draw)

    @on_trait_change('legend_pos')
    def update_legend_pos(self, old, new):
        """ Move the legend, calls update_legend """
        self.update_legend(None, None)

    @on_trait_change('legend')
    def update_legend(self, old, new):
        """ Called when we change the legend display """
        axes = self.figure.gca()
        lines = axes.get_lines()
        exprs = self.get_exprs()

        if len(exprs) >= 1 and self.legend:
            axes.legend(lines[:len(exprs)],
                        exprs,
                        loc=self.legend_pos,
                        prop=self.legend_prop)
        else:
            axes.legend_ = None

        self.draw_plot()
Пример #22
0
class PPCoordSource(VTKDataSource):
    """VTK data source that places a complex data array on a structured grid.

    The grid points come from a ``cs.CoordSystem`` built from a
    ``phasingparams.py`` file located next to ``file_name``. Coordinate
    systems are cached per parameter file in ``engine.coords`` so that
    several sources can share one. The scalar values are one component
    (amplitude, phase, real or imaginary part) of ``dataarray``.
    """

    file_name = Str

    component = Enum('Amp', 'Phase', 'Real', 'Imag')
    space = Enum("Real", "Recip")

    transpose_input_array = Bool(True)

    coords = Instance(cs.CoordSystem)
    sg = Instance(tvtk.StructuredGrid)
    im = Instance(tvtk.ImageData)
    scalar_data = Array
    coordarray = Array
    dataarray = Array

    engine = Instance(Engine)

    set_component = Str
    scalar_name = Str

    view = View(
        Group(
            Item(name='transpose_input_array'),
            Item(name='file_name',
                 editor=TextEditor(auto_set=False, enter_set=True)),
            Item(name='component', style='custom',
                 editor=EnumEditor(values=component)),
            label='Data'),
        Group(
            Item(name='coords', style='custom', show_label=False),
            label='Coordinate System'),
        buttons=['Ok', 'Cancel'])

    def __init__(self, **kw_args):
        """Initialise traits, attach the shared coordinate system, load data.

        ``kw_args`` are forwarded to the Traits constructor. They must
        provide ``engine`` and ``file_name``: the directory of ``file_name``
        is used to find ``phasingparams.py``.
        """
        super(PPCoordSource, self).__init__(**kw_args)
        fn = kw_args.pop('file_name', None)

        dir_name = os.path.dirname(self.file_name)
        phparams = os.path.join(dir_name, "phasingparams.py")

        self.engine.add_trait('dataarrays', Dict)
        self.engine.add_trait('coords', Dict)

        # Reuse the coordinate system already built for this parameter file,
        # otherwise create one, cache it on the engine and load the params.
        if phparams in self.engine.coords:
            self.coords = self.engine.coords[phparams]
        else:
            self.engine.coords[phparams] = cs.CoordSystem()
            self.coords = self.engine.coords[phparams]
            self.coords.exec_param_file(phparams)
        self.coords.on_trait_change(self._coords_changed, 'T')

        if fn is not None:
            self.file_name = fn
            self._open(self.file_name)

        dims = self.dataarray.shape
        print("init coords update start")
        self.coords.UpdateCoordSystem(dims)
        print("init coords update end")

        self.sg = tvtk.StructuredGrid()
        self.im = tvtk.ImageData()
        self.component = "Real"
        self._component_changed('Real')  # can change this for a different default

    def _coords_changed(self, info):
        """Rebuild the grid when the coordinate system changes.

        Registered in ``__init__`` as a listener on the coordinate system's
        ``T`` trait. By Traits naming convention it also runs when ``coords``
        itself is reassigned.
        """
        dims = self.scalar_data.shape
        self.coords.UpdateCoordSystem(dims)
        self.set_data()

    def _component_changed(self, info):
        """Recompute ``scalar_data`` for component ``info`` and push it out."""
        if info == "Amp":
            self.scalar_data = numpy.abs(self.dataarray)
        elif info == "Phase":
            self.scalar_data = numpy.angle(self.dataarray)
        elif info == "Real":
            self.scalar_data = self.dataarray.real
        elif info == "Imag":
            # was ``self.self.dataarray`` -> AttributeError for 'Imag'
            self.scalar_data = self.dataarray.imag
        self.set_data()

    def _file_name_changed(self, info):
        print(self.file_name)

    def _transpose_input_array_changed(self, info):
        self.set_data()

    def set_data(self):
        """Load ``scalar_data`` onto the structured grid and update the pipeline.

        Does nothing if ``scalar_data`` is empty or all zero. Indexing
        ``dims[0..2]`` assumes 3-D data.
        """
        if not self.scalar_data.any():
            return

        print("MAKE SGRID")
        dims = list(self.scalar_data.shape)
        self.coords.UpdateCoordSystem(dims)
        sg = self.sg
        sg.points = self.coords.coords
        # The transpose is not needed for ScalarGrid.
        # NOTE(review): the flag reads inverted -- True means "ravel as-is";
        # kept as-is, confirm the intended meaning.
        if self.transpose_input_array:
            sg.point_data.scalars = self.scalar_data.ravel()
        else:
            sg.point_data.scalars = numpy.ravel(numpy.transpose(self.scalar_data))
        sg.point_data.scalars.name = self.component
        # VTK orders dimensions x-fastest, i.e. reversed w.r.t. the C-ordered
        # numpy shape.
        sg.dimensions = (dims[2], dims[1], dims[0])
        sg.extent = 0, dims[2] - 1, 0, dims[1] - 1, 0, dims[0] - 1
        sg.update_extent = 0, dims[2] - 1, 0, dims[1] - 1, 0, dims[0] - 1
        self.data = sg
        self._update_data()
        self.update()
Пример #23
0
                           selection_bg_color=None)

# The standard columns: a read-only 'resolved' column (cell colour 0xFF8080),
# a read-only 'optional' column, and a 'description' column edited with a
# plain TextEditor.
std_columns = [
    ResolvedColumn(
        name='resolved', label='?', editable=False, width=20,
        horizontal_alignment='center', cell_color=0xFF8080,
    ),
    OptionalColumn(
        name='optional', label='*', editable=False, width=20,
        horizontal_alignment='center',
    ),
    ObjectColumn(
        name='description', editor=TextEditor(), width=0.47,
    ),
]

#-------------------------------------------------------------------------------
#  'TemplateDataNames' class:
#-------------------------------------------------------------------------------


class TemplateDataNames(HasPrivateTraits):

    #-- Public Traits ----------------------------------------------------------

    # The data context to which bindings are made:
    context = Instance(ITemplateDataContext)

    # The current set of data names to be bound to the context:
Пример #24
0
class Pulsed(ManagedJob, GetSetItemsMixin):
    """Defines a pulsed measurement.

    Subclasses override ``generate_sequence`` to build a list of
    ``(channels, duration_ns)`` steps. ``_run`` loops that sequence on the
    pulse generator (``PG``), sets up the microwave source (``MW``) and adds
    the fast counter's data (``FC.GetData()``) to ``count_data``.
    """

    keep_data = Bool(
        False)  # helper variable to decide whether to keep existing data

    resubmit_button = Button(
        label='resubmit',
        desc=
        'Submits the measurement to the job manager. Tries to keep previously acquired data. Behaves like a normal submit if sequence or time bins have changed since previous run.'
    )

    sequence = Instance(list, factory=list)

    record_length = Range(low=100,
                          high=100000.,
                          value=3000,
                          desc='length of acquisition record [ns]',
                          label='record length [ns]',
                          mode='text',
                          auto_set=False,
                          enter_set=True)
    bin_width = Range(low=0.1,
                      high=1000.,
                      value=3.2,
                      desc='data bin width [ns]',
                      label='bin width [ns]',
                      mode='text',
                      auto_set=False,
                      enter_set=True)

    n_laser = Int(2)
    n_bins = Int(2)
    time_bins = Array(value=np.array((0, 1)))

    count_data = Array(value=np.zeros((2, 2)))

    # Wall-clock seconds (from time.time()), accumulated across resumed runs.
    run_time = Float(value=0.0, label='run time [s]', format_str='%.f')
    stop_time = Range(
        low=1.,
        value=np.inf,
        desc='Time after which the experiment stops by itself [s]',
        label='Stop time [s]',
        mode='text',
        auto_set=False,
        enter_set=True)

    tau_begin = Range(low=0.,
                      high=1e8,
                      value=300.,
                      desc='tau begin [ns]',
                      label='tau begin [ns]',
                      mode='text',
                      auto_set=False,
                      enter_set=True)
    tau_end = Range(low=1.,
                    high=1e8,
                    value=4000.,
                    desc='tau end [ns]',
                    label='tau end [ns]',
                    mode='text',
                    auto_set=False,
                    enter_set=True)
    tau_delta = Range(low=1.,
                      high=1e8,
                      value=50.,
                      desc='delta tau [ns]',
                      label='delta tau [ns]',
                      mode='text',
                      auto_set=False,
                      enter_set=True)

    tau = Array(value=np.array((0., 1.)))

    laser = Range(low=1.,
                  high=5e6,
                  value=3000.,
                  desc='laser [ns]',
                  label='laser [ns]',
                  mode='text',
                  auto_set=False,
                  enter_set=True)
    wait = Range(low=1.,
                 high=5e6,
                 value=5000.,
                 desc='wait [ns]',
                 label='wait [ns]',
                 mode='text',
                 auto_set=False,
                 enter_set=True)

    freq = Range(low=1,
                 high=20e9,
                 value=2.71e9,
                 desc='frequency [Hz]',
                 label='frequency [Hz]',
                 mode='text',
                 auto_set=False,
                 enter_set=True)
    power = Range(low=-100.,
                  high=25.,
                  value=-26,
                  desc='power [dBm]',
                  label='power [dBm]',
                  mode='text',
                  auto_set=False,
                  enter_set=True)

    sweeps = Range(low=1.,
                   high=1e10,
                   value=1e6,
                   desc='number of sweeps',
                   label='sweeps',
                   mode='text',
                   auto_set=False,
                   enter_set=True)
    expected_duration = Property(
        trait=Float,
        depends_on='sweeps,sequence',
        desc='expected duration of the measurement [s]',
        label='expected duration [s]')
    elapsed_sweeps = Float(value=0,
                           desc='Elapsed Sweeps ',
                           label='Elapsed Sweeps ',
                           mode='text')
    # Wall-clock seconds since the current run started.
    elapsed_time = Float(value=0,
                         desc='Elapsed Time [s]',
                         label='Elapsed Time [s]',
                         mode='text')
    progress = Int(value=0,
                   desc='Progress [%]',
                   label='Progress [%]',
                   mode='text')

    import_code = Code()

    import_button = Button(
        desc=
        'set parameters such as pulse length, frequency, power, etc. by executing import code specified in settings',
        label='import')

    def __init__(self):
        super(Pulsed, self).__init__()

    def submit(self):
        """Submit the job to the JobManager, discarding previous data."""
        self.keep_data = False
        ManagedJob.submit(self)

    def resubmit(self):
        """Submit the job to the JobManager, trying to keep previous data."""
        self.keep_data = True
        ManagedJob.submit(self)

    def _resubmit_button_fired(self):
        """React to the resubmit button: resubmit the job."""
        self.resubmit()

    def generate_sequence(self):
        """Return the pulse sequence; the base class has none (override)."""
        return []

    @cached_property
    def _get_expected_duration(self):
        """Expected duration in seconds: sweeps x total sequence length (ns)."""
        sequence_length = 0
        for step in self.sequence:
            sequence_length += step[1]
        return self.sweeps * sequence_length * 1e-9

    def _get_sequence_points(self):
        return len(self.tau)

    def apply_parameters(self):
        """Apply the current parameters and decide whether to keep previous data.

        Previous data is kept only if ``keep_data`` is set AND the newly
        generated sequence and time bins equal those of the previous run.
        Otherwise the fast counter is reconfigured and all accumulators are
        reset.
        """
        n_bins = int(self.record_length / self.bin_width)
        time_bins = self.bin_width * np.arange(n_bins)
        # tau must be set before generate_sequence(), which reads it.
        self.tau = np.arange(self.tau_begin, self.tau_end, self.tau_delta)

        sequence = self.generate_sequence()
        n_laser = find_laser_pulses(sequence)

        # Compare with the PREVIOUS run before overwriting the stored values;
        # comparing after the assignments below is always true.
        keep_existing = (self.keep_data and sequence == self.sequence
                         and np.array_equal(time_bins, self.time_bins))

        self.sequence = sequence
        self.sequence_points = self._get_sequence_points()
        self.time_bins = time_bins
        self.n_bins = n_bins
        self.n_laser = n_laser

        if keep_existing:
            self.old_count_data = self.count_data.copy()
            self.previous_sweeps = self.elapsed_sweeps
            self.previous_elapsed_time = self.elapsed_time
        else:
            FC.Configure(self.laser, self.bin_width, self.sequence_points)
            self.old_count_data = np.zeros(FC.GetData().shape)
            self.previous_sweeps = 0
            self.previous_elapsed_time = 0.0
            self.run_time = 0.0

        self.keep_data = True  # when job manager stops and starts the job, data should be kept. Only new submission should clear data.

    def _run(self):
        """Acquire data.

        Applies the parameters and configures the hardware. It then waits up
        to 1 s between fast-counter polls until a stop is requested,
        ``run_time`` reaches ``stop_time`` or ``sweeps`` is exceeded. Any
        exception is logged and sets ``state`` to 'error'.
        """
        try:  # try to run the acquisition from start_up to shut_down
            self.state = 'run'
            self.apply_parameters()

            PG.High([])
            FC.SetCycles(np.inf)
            FC.SetTime(np.inf)
            FC.SetDelay(0)
            FC.SetLevel(0.6, 0.6)
            FC.Configure(self.laser, self.bin_width, self.sequence_points)
            MW.setFrequency(self.freq)
            MW.setPower(self.power)
            time.sleep(2.0)
            FC.Start()
            time.sleep(0.1)
            PG.Sequence(self.sequence, loop=True)

            start_time = time.time()
            # run_time carried over from earlier runs (0.0 for a fresh run).
            # elapsed_time is already cumulative since start_time, so adding
            # it on every poll (the old ``+=``) counted the same seconds
            # repeatedly.
            run_time_offset = self.run_time

            while self.run_time < self.stop_time:
                self.thread.stop_request.wait(1.0)
                if self.thread.stop_request.isSet():
                    logging.getLogger().debug('Caught stop signal. Exiting.')
                    break
                self.elapsed_time = time.time() - start_time
                self.run_time = run_time_offset + self.elapsed_time
                runtime, cycles = FC.GetState()
                sweeps = cycles / FC.GetData().shape[0]
                self.elapsed_sweeps = self.previous_sweeps + sweeps
                self.progress = int(100 * self.elapsed_sweeps / self.sweeps)
                self.count_data = self.old_count_data + FC.GetData()
                if self.elapsed_sweeps > self.sweeps:
                    break

            FC.Halt()
            MW.Off()
            PG.High(['laser'])
            if self.elapsed_sweeps < self.sweeps:
                self.state = 'idle'
            else:
                self.state = 'done'

        except Exception:  # log the failure and flag the job as errored
            logging.getLogger().exception(
                'Something went wrong in pulsed loop.')
            self.state = 'error'

    get_set_items = [
        '__doc__', 'record_length', 'laser', 'wait', 'bin_width', 'n_bins',
        'time_bins', 'n_laser', 'sequence', 'count_data', 'run_time',
        'tau_begin', 'tau_end', 'tau_delta', 'tau', 'power'
    ]

    traits_view = View(
        VGroup(
            HGroup(
                Item('submit_button', show_label=False),
                Item('remove_button', show_label=False),
                Item('resubmit_button', show_label=False),
                Item('priority'),
            ),
            HGroup(
                Item('freq', width=40),
                Item('power', width=40),
            ),
            HGroup(
                Item('laser', width=40),
                Item('wait', width=40),
                Item('bin_width', width=-80, enabled_when='state != "run"'),
                Item('record_length', width=-80,
                     enabled_when='state != "run"'),
            ),
            HGroup(
                Item('tau_begin', width=40),
                Item('tau_end', width=40),
                Item('tau_delta', width=40),
            ),
            HGroup(
                Item('state', style='readonly'),
                Item('run_time', style='readonly', format_str='%.f'),
                Item('sweeps',
                     editor=TextEditor(auto_set=False,
                                       enter_set=True,
                                       evaluate=float,
                                       format_func=lambda x: '%.3e' % x),
                     width=40),
                Item('expected_duration',
                     style='readonly',
                     editor=TextEditor(evaluate=float,
                                       format_func=lambda x: '%.f' % x),
                     width=40),
                Item('progress', style='readonly'),
                Item('elapsed_time',
                     style='readonly',
                     editor=TextEditor(auto_set=False,
                                       enter_set=True,
                                       evaluate=float,
                                       format_func=lambda x: ' %.f' % x),
                     width=40),
            ),
        ),
        title='Pulsed Measurement',
    )
Пример #25
0
class LiveTimestampModelerWithAnalogInput(LiveTimestampModeler):
    """Timestamp modeler that also decodes the trigger device's analog input.

    Decoded AIN samples are appended per channel to ``viewer.channels`` and
    clipped to the most recent ``MAXLEN`` samples.
    """
    view_AIN = traits.Button(label='view analog input (AIN)')
    viewer = traits.Instance(AnalogInputViewer)

    # the actual analog data (as a wordstream)
    ain_data_raw = traits.Array(dtype=np.uint16, transient=True)
    # words not yet decodable, prepended to the next read
    old_data_raw = traits.Array(dtype=np.uint16, transient=True)

    timer3_top = traits.Property(
    )  # necessary to calculate precise timestamps for AIN data
    channel_names = traits.Property()
    Vcc = traits.Property(depends_on='_trigger_device')
    ain_overflowed = traits.Int(
        0,
        transient=True)  # integer for display (boolean readonly editor ugly)

    ain_wordstream_buffer = traits.Any()
    traits_view = View(
        Group(
            Item('synchronize', show_label=False),
            Item('view_time_model_plot', show_label=False),
            Item('ain_overflowed', style='readonly'),
            Item(
                name='gain',
                style='readonly',
                editor=TextEditor(evaluate=float, format_func=myformat),
            ),
            Item(
                name='offset',
                style='readonly',
                editor=TextEditor(evaluate=float, format_func=myformat2),
            ),
            Item(
                name='residual_error',
                style='readonly',
                editor=TextEditor(evaluate=float, format_func=myformat),
            ),
            Item('view_AIN', show_label=False),
        ),
        title='Timestamp modeler',
    )

    @traits.cached_property
    def _get_Vcc(self):
        return self._trigger_device.Vcc

    def _get_timer3_top(self):
        return self._trigger_device.timer3_top

    def _get_channel_names(self):
        return self._trigger_device.enabled_channel_names

    def update_analog_input(self):
        """Decode pending AIN words and append the samples to the viewer.

        Call this function frequently to avoid overruns. Words that cannot be
        decoded yet are kept in ``old_data_raw`` for the next call.

        Raises:
            AnalogDataOverflowedError: if the decoder reported an overflow
                (``ain_overflowed`` is set to 1 first).
        """
        new_data_raw = self._trigger_device.get_analog_input_buffer_rawLE()
        # The leftover words from the previous call come earlier in the
        # stream than the newly read ones, so they must be placed first
        # (they were previously appended after the new data).
        data_raw = np.hstack((self.old_data_raw, new_data_raw))
        self.ain_data_raw = new_data_raw
        newdata_all = []
        chan_all = []
        any_overflow = False
        while len(data_raw):
            result = cDecode.process(data_raw)
            (N, samples, channels, did_overflow, framestamp) = result
            if N == 0:
                # no data was able to be processed
                break
            data_raw = data_raw[N:]
            newdata_all.append(samples)
            chan_all.append(channels)
            if did_overflow:
                any_overflow = True
            # TODO: framestamp data is returned by the decoder but not saved.
        self.old_data_raw = data_raw  # save unprocessed data for next run

        if any_overflow:
            # XXX should move to logging the error.
            self.ain_overflowed = 1
            raise AnalogDataOverflowedError()

        if len(chan_all) == 0:
            # no data
            return
        chan_all = np.hstack(chan_all)
        newdata_all = np.hstack(newdata_all)
        USB_channel_numbers = np.unique(chan_all)

        # Samples kept per channel in the viewer. (An older, commented-out
        # derivation used the ADC rate: 8 MHz / prescaler 128 / downsample
        # ~20 / 3 channels and a 0.3 s window; a fixed count is used now.)
        MAXLEN = 5000

        for USB_chan in USB_channel_numbers:
            vi = self.viewer.usb_device_number2index[USB_chan]
            cond = chan_all == USB_chan
            newdata = newdata_all[cond]

            oldidx = self.viewer.channels[vi].index
            olddata = self.viewer.channels[vi].data

            # Continue the sample index where the previous chunk stopped.
            if len(oldidx):
                baseidx = oldidx[-1] + 1
            else:
                baseidx = 0.0
            newidx = np.arange(len(newdata), dtype=float) + baseidx

            tmpidx = np.hstack((oldidx, newidx))
            tmpdata = np.hstack((olddata, newdata))

            if len(tmpidx) > MAXLEN:
                # clip to MAXLEN
                self.viewer.channels[vi].index = tmpidx[-MAXLEN:]
                self.viewer.channels[vi].data = tmpdata[-MAXLEN:]
            else:
                self.viewer.channels[vi].index = tmpidx
                self.viewer.channels[vi].data = tmpdata

    def _view_AIN_fired(self):
        self.viewer.edit_traits()
Пример #26
0
class LiveTimestampModeler(traits.HasTraits):
    """Fit a linear model mapping trigger-device framestamps to host time.

    ``update()`` appends (host time, framestamp) samples to
    ``timestamps_framestamps``. A least-squares line through them yields
    ``gain`` and ``offset`` such that ``time ~= framestamp * gain + offset``.
    """
    _trigger_device = traits.Instance(ttrigger.DeviceModel)

    sync_interval = traits.Float(2.0)
    has_ever_synchronized = traits.Bool(False, transient=True)

    frame_offset_changed = traits.Event

    # np.float was removed in NumPy 1.24; the builtin is the same dtype.
    timestamps_framestamps = traits.Array(shape=(None, 2), dtype=float)

    timestamp_data = traits.Any()
    block_activity = traits.Bool(False, transient=True)

    synchronize = traits.Button(label='Synchronize')
    synchronizing_info = traits.Any(None)

    gain_offset_residuals = traits.Property(
        depends_on=['timestamps_framestamps'])

    residual_error = traits.Property(depends_on='gain_offset_residuals')

    gain = traits.Property(depends_on='gain_offset_residuals')

    offset = traits.Property(depends_on='gain_offset_residuals')

    frame_offsets = traits.Dict()
    last_frame = traits.Dict()

    view_time_model_plot = traits.Button

    traits_view = View(
        Group(
            Item(
                name='gain',
                style='readonly',
                editor=TextEditor(evaluate=float, format_func=myformat),
            ),
            Item(
                name='offset',
                style='readonly',
                editor=TextEditor(evaluate=float, format_func=myformat2),
            ),
            Item(
                name='residual_error',
                style='readonly',
                editor=TextEditor(evaluate=float, format_func=myformat),
            ),
            Item('synchronize', show_label=False),
            Item('view_time_model_plot', show_label=False),
        ),
        title='Timestamp modeler',
    )

    def _block_activity_changed(self):
        if self.block_activity:
            print('Do not change frame rate or AIN parameters. '
                  'Automatic prevention of doing '
                  'so is not currently implemented.')
        else:
            print('You may change frame rate again')

    def _view_time_model_plot_fired(self):
        raise NotImplementedError('')

    def _synchronize_fired(self):
        """Pause triggering (fps 0) and reset the frame count.

        ``update()`` restores the original frame rate once
        ``sync_interval + 0.1`` seconds have passed.
        """
        if self.block_activity:
            print('Not synchronizing because activity is blocked. '
                  '(Perhaps because you are saving data now.')
            return

        orig_fps = self._trigger_device.frames_per_second_actual
        self._trigger_device.set_frames_per_second_approximate(0.0)
        self._trigger_device.reset_framecount_A = True  # trigger reset event
        self.synchronizing_info = (time.time() + self.sync_interval + 0.1,
                                   orig_fps)

    @traits.cached_property
    def _get_gain(self):
        result = self.gain_offset_residuals
        if result is None:
            # not enough data
            return None
        gain, offset, residuals = result
        return gain

    @traits.cached_property
    def _get_offset(self):
        result = self.gain_offset_residuals
        if result is None:
            # not enough data
            return None
        gain, offset, residuals = result
        return offset

    @traits.cached_property
    def _get_residual_error(self):
        result = self.gain_offset_residuals
        if result is None:
            # not enough data
            return None
        gain, offset, residuals = result
        if residuals is None or len(residuals) == 0:
            # not enough data
            return None
        assert len(residuals) == 1
        return residuals[0]

    @traits.cached_property
    def _get_gain_offset_residuals(self):
        """Least-squares fit ``timestamp = framestamp * gain + offset``.

        Returns ``(gain, offset, residuals)`` from ``np.linalg.lstsq``, or
        None when there are fewer than two samples.
        """
        if self.timestamps_framestamps is None:
            return None

        timestamps = self.timestamps_framestamps[:, 0]
        framestamps = self.timestamps_framestamps[:, 1]

        if len(timestamps) < 2:
            return None

        # like model_remote_to_local in flydra.analysis
        remote_timestamps = framestamps
        local_timestamps = timestamps

        a1 = remote_timestamps[:, np.newaxis]
        a2 = np.ones((len(remote_timestamps), 1))
        A = np.hstack((a1, a2))
        b = local_timestamps[:, np.newaxis]
        x, resids, rank, s = np.linalg.lstsq(A, b)

        gain = x[0, 0]
        offset = x[1, 0]
        return gain, offset, resids

    def set_trigger_device(self, device):
        self._trigger_device = device
        self._trigger_device.on_trait_event(
            self._on_trigger_device_reset_AIN_overflow_fired,
            name='reset_AIN_overflow')

    def _on_trigger_device_reset_AIN_overflow_fired(self):
        self.ain_overflowed = 0

    def _get_now_framestamp(self, max_error_seconds=0.003, full_output=False):
        """Read the device framestamp bracketed by two host-clock reads.

        Retries (at most 11 attempts) until the bracketing interval is below
        ``max_error_seconds``. Readings with ``framestamp % 1.0 < 0.1`` are
        rejected as a workaround for an MCU TCNT race.

        Returns:
            ``(now, framestamp)``, or with ``full_output``
            ``(now, framestamp, now1, now2, framecount, tcnt)``. ``now`` is
            the midpoint of the two host-clock reads.

        Raises:
            ImpreciseMeasurementError: no data, or no sufficiently precise
                reading was obtained.
        """
        count = 0
        while count <= 10:
            now1 = time.time()
            try:
                results = self._trigger_device.get_framestamp(
                    full_output=full_output)
            except ttrigger.NoDataError:
                raise ImpreciseMeasurementError('no data available')
            now2 = time.time()
            if full_output:
                framestamp, framecount, tcnt = results
            else:
                framestamp = results
            count += 1
            measurement_error = abs(now2 - now1)
            if framestamp % 1.0 < 0.1:
                warnings.warn('workaround of TCNT race condition on MCU...')
                continue
            if measurement_error < max_error_seconds:
                break
            time.sleep(0.01)  # wait 10 msec before trying again
        if not measurement_error < max_error_seconds:
            raise ImpreciseMeasurementError(
                'could not obtain low error measurement')
        if framestamp % 1.0 < 0.1:
            raise ImpreciseMeasurementError('workaround MCU bug')

        now = (now1 + now2) * 0.5
        if full_output:
            results = now, framestamp, now1, now2, framecount, tcnt
        else:
            results = now, framestamp
        return results

    def clear_samples(self, call_update=True):
        self.timestamps_framestamps = np.empty((0, 2))
        if call_update:
            self.update()

    def update(self, return_last_measurement_info=False):
        """Call this function fairly often to pump information from the USB device.

        Finishes a pending synchronization when its wait time has elapsed. It
        then records one (time, framestamp) sample and keeps only the newest
        50 samples once more than 100 have accumulated.
        """
        if self.synchronizing_info is not None:
            done_time, orig_fps = self.synchronizing_info
            # suspended trigger pulses to re-synchronize
            if time.time() >= done_time:
                # we've waited the sync duration, restart
                self._trigger_device.set_frames_per_second_approximate(
                    orig_fps)
                self.clear_samples(call_update=False)  # avoid recursion
                self.synchronizing_info = None
                self.has_ever_synchronized = True

        results = self._get_now_framestamp(
            full_output=return_last_measurement_info)
        now, framestamp = results[:2]
        if return_last_measurement_info:
            start_timestamp, stop_timestamp, framecount, tcnt = results[2:]

        self.timestamps_framestamps = np.vstack(
            (self.timestamps_framestamps, [now, framestamp]))

        # If more than 100 samples,
        if len(self.timestamps_framestamps) > 100:
            # keep only the most recent 50.
            self.timestamps_framestamps = self.timestamps_framestamps[-50:]

        if return_last_measurement_info:
            return start_timestamp, stop_timestamp, framecount, tcnt

    def get_frame_offset(self, id_string):
        return self.frame_offsets[id_string]

    def register_frame(self,
                       id_string,
                       framenumber,
                       frame_timestamp,
                       full_output=False):
        """Note that a frame happened and return start-of-frame time.

        The ``frame_timestamp`` argument is ignored and replaced by
        ``time.time()``. If more than ``sync_interval`` seconds have passed
        since the last frame from ``id_string``, the camera's frame offset is
        re-derived as ``framenumber - 2``.

        Returns:
            ``trigger_timestamp`` (None if the model has too little data), or
            with ``full_output``
            ``(trigger_timestamp, corrected_framenumber,
            did_frame_offset_change)``.
        """

        # This may get called from another thread (e.g. the realtime
        # image processing thread).

        # An important note about locking and thread safety: This code
        # relies on the Python interpreter to lock data structures
        # across threads. To do this internally, a lock would be made
        # for each variable in this instance and acquired before each
        # access. Because the data structures are simple Python
        # objects, I believe the operations are atomic and thus this
        # function is OK.

        # Don't trust camera drivers with giving a good timestamp. We
        # only use this to reset our framenumber-to-time data
        # gathering, anyway.
        frame_timestamp = time.time()

        # Bound before the branch so every return path below can use it.
        did_frame_offset_change = False
        if frame_timestamp is not None:
            last_frame_timestamp = self.last_frame.get(id_string, -np.inf)
            this_interval = frame_timestamp - last_frame_timestamp

            if this_interval > self.sync_interval:
                if self.block_activity:
                    print(
                        'changing frame offset is disallowed, but you attempted to do it. ignoring.'
                    )
                else:
                    # re-synchronize camera

                    # XXX need to figure out where frame offset of two comes from:
                    self.frame_offsets[id_string] = framenumber - 2
                    did_frame_offset_change = True

            self.last_frame[id_string] = frame_timestamp

            if did_frame_offset_change:
                self.frame_offset_changed = True  # fire any listeners

        result = self.gain_offset_residuals
        if result is None:
            # not enough data
            if full_output:
                results = None, None, did_frame_offset_change
            else:
                results = None
            return results

        gain, offset, residuals = result
        corrected_framenumber = framenumber - self.frame_offsets[id_string]
        trigger_timestamp = corrected_framenumber * gain + offset

        if full_output:
            results = trigger_timestamp, corrected_framenumber, did_frame_offset_change
        else:
            results = trigger_timestamp
        return results
Пример #27
0
class Rabi(Pulsed):
    """Rabi measurement.

    For every tau the sequence plays a microwave pulse of length tau, then a
    laser+trigger pulse of length ``laser``, then an idle ``wait``.
    """
    def __init__(self):
        super(Rabi, self).__init__()

    def generate_sequence(self):
        """Return the flat list of ``(channels, duration)`` steps for all tau."""
        pulse_times = self.tau
        laser_len = self.laser
        idle_len = self.wait
        return [
            step
            for t in pulse_times
            for step in ((['mw'], t),
                         (['laser', 'trigger'], laser_len),
                         ([], idle_len))
        ]

    traits_view = View(
        VGroup(
            HGroup(
                Item('submit_button', show_label=False),
                Item('remove_button', show_label=False),
                Item('resubmit_button', show_label=False),
                Item('priority'),
            ),
            HGroup(Item('freq', width=40), Item('power', width=40)),
            HGroup(
                Item('laser', width=40),
                Item('wait', width=40),
                Item('bin_width', width=-80, enabled_when='state != "run"'),
                Item('record_length', width=-80,
                     enabled_when='state != "run"'),
            ),
            HGroup(
                Item('tau_begin', width=40),
                Item('tau_end', width=40),
                Item('tau_delta', width=40),
            ),
            HGroup(
                Item('state', style='readonly'),
                Item('run_time', style='readonly', format_str='%.f', width=50),
                Item('sweeps',
                     editor=TextEditor(auto_set=False,
                                       enter_set=True,
                                       evaluate=float,
                                       format_func=lambda x: '%.3e' % x),
                     width=40),
                Item('expected_duration',
                     style='readonly',
                     editor=TextEditor(evaluate=float,
                                       format_func=lambda x: '%.f' % x),
                     width=40),
                Item('progress', style='readonly'),
                Item('elapsed_time',
                     style='readonly',
                     editor=TextEditor(auto_set=False,
                                       enter_set=True,
                                       evaluate=float,
                                       format_func=lambda x: ' %.f' % x),
                     width=40),
            ),
        ),
        title='Rabi Measurement',
    )
Пример #28
0
class Explorer3D(HasTraits):
    """Explore a scalar field given by an equation on a 3D grid.

    A cube of sample points (numpy arrays ``_x``, ``_y``, ``_z``) is built
    from ``dimensions`` and ``volume``; ``equation`` is evaluated on it and
    the result (``data``) is shown with the mayavi plugin as an outline,
    axes and three image-plane widgets.
    """

    ########################################
    # Traits.

    # Set by envisage when this is offered as a service offer.
    window = Instance('enthought.pyface.workbench.api.WorkbenchWindow')

    # The equation that generates the scalar field.
    equation = Str('sin(x*y*z)/(x*y*z)',
                   desc='equation to evaluate (enter to set)',
                   auto_set=False,
                   enter_set=True)

    # Dimensions of the cube of data.
    dimensions = Array(value=(128, 128, 128),
                       dtype=int,
                       shape=(3, ),
                       cols=1,
                       labels=['nx', 'ny', 'nz'],
                       desc='the array dimensions')

    # The volume of interest (VOI).
    volume = Array(dtype=float,
                   value=(-5, 5, -5, 5, -5, 5),
                   shape=(6, ),
                   cols=2,
                   labels=['xmin', 'xmax', 'ymin', 'ymax', 'zmin', 'zmax'],
                   desc='the volume of interest')

    # Clicking this button resets the data with the new dimensions and
    # VOI.
    update_data = Button('Update data')

    ########################################
    # Private traits.
    # Grid coordinates (open grids from numpy.ogrid) and the evaluated data.
    _x = Array
    _y = Array
    _z = Array
    data = Array
    # The mayavi ArraySource, created lazily by _show_data.
    source = Any
    # The three image plane widgets (x, y and z orientation).
    _ipw1 = Any
    _ipw2 = Any
    _ipw3 = Any

    ########################################
    # Our UI view.
    view = View(
        Item('equation', editor=TextEditor(auto_set=False, enter_set=True)),
        Item('dimensions'),
        Item('volume'),
        Item('update_data', show_label=False),
        resizable=True,
        scrollable=True,
    )

    ######################################################################
    # `object` interface.
    ######################################################################
    def __init__(self, **traits):
        super(Explorer3D, self).__init__(**traits)
        # Make some default data.
        if len(self.data) == 0:
            self._make_data()
        # Note: the visualization can only be shown once the mayavi engine
        # has started.  When this object is constructed the envisage
        # services are usually not running yet, so we cannot get hold of
        # the mayavi instance here.  The hooking up is instead done in
        # `_window_changed`, which runs when envisage sets our `window`
        # trait.

    def get_mayavi(self):
        """Return the mayavi ``Script`` service from our workbench window."""
        from enthought.mayavi.plugins.script import Script
        return self.window.get_service(Script)

    ######################################################################
    # Non-public methods.
    ######################################################################
    def _origin_and_spacing(self):
        """Return ``(origin, spacing)`` of the grid for the current VOI.

        ``origin`` is ``(xmin, ymin, zmin)``; the spacing divisor is clamped
        to at least 1 so that a dimension of size 1 gives a finite spacing
        instead of dividing by zero.
        """
        vol = self.volume
        origin = vol[::2]
        spacing = (vol[1::2] - origin) / numpy.maximum(self.dimensions - 1, 1)
        return origin, spacing

    def _make_data(self):
        """Rebuild the sampling grid and re-evaluate the equation on it."""
        nx, ny, nz = self.dimensions.tolist()
        xmin, xmax, ymin, ymax, zmin, zmax = self.volume
        # An imaginary step in ogrid means "this many points, end included".
        x, y, z = numpy.ogrid[xmin:xmax:nx * 1j, ymin:ymax:ny * 1j,
                              zmin:zmax:nz * 1j]
        self._x = x.astype('f')
        self._y = y.astype('f')
        self._z = z.astype('f')
        self._equation_changed('', self.equation)

    def _show_data(self):
        """Create the mayavi source and modules (only once)."""
        if self.source is not None:
            return
        mayavi = self.get_mayavi()
        if mayavi.engine.current_scene is None:
            mayavi.new_scene()
        from enthought.mayavi.sources.array_source import ArraySource
        origin, spacing = self._origin_and_spacing()
        src = ArraySource(transpose_input_array=False,
                          scalar_data=self.data,
                          origin=origin,
                          spacing=spacing)
        self.source = src
        mayavi.add_source(src)

        from enthought.mayavi.modules.outline import Outline
        from enthought.mayavi.modules.image_plane_widget import ImagePlaneWidget
        from enthought.mayavi.modules.axes import Axes
        # Visualize the data.
        o = Outline()
        mayavi.add_module(o)
        a = Axes()
        mayavi.add_module(a)
        self._ipw1 = ipw = ImagePlaneWidget()
        mayavi.add_module(ipw)
        ipw.module_manager.scalar_lut_manager.show_scalar_bar = True

        self._ipw2 = ipw_y = ImagePlaneWidget()
        mayavi.add_module(ipw_y)
        ipw_y.ipw.plane_orientation = 'y_axes'

        self._ipw3 = ipw_z = ImagePlaneWidget()
        mayavi.add_module(ipw_z)
        ipw_z.ipw.plane_orientation = 'z_axes'

    ######################################################################
    # Traits static event handlers.
    ######################################################################
    def _equation_changed(self, old, new):
        """Evaluate the equation ``new`` on the grid and store it in `data`.

        ``x``, ``y`` and ``z`` and every name in the numpy namespace are
        available to the expression.  An equation that fails to evaluate
        (or does not yield an array) leaves `data` unchanged so the user can
        keep editing it.
        """
        # NOTE(review): ``eval`` executes arbitrary code typed into the UI.
        # Acceptable for a local interactive tool; never pass untrusted
        # input here.
        namespace = {'x': self._x, 'y': self._y, 'z': self._z}
        try:
            s = eval(new, numpy.__dict__, namespace)
            # The copy makes the data contiguous and the transpose
            # makes it suitable for display via tvtk.
            s = s.transpose().copy()
            # Reshaping the array is needed since the transpose
            # messes up the dimensions of the data.  The scalars
            # themselves are ravel'd and used internally by VTK so the
            # dimension does not matter for the scalars.
            s.shape = s.shape[::-1]
            self.data = s
        except Exception:
            # Best effort: an invalid or partially typed equation keeps the
            # previous data instead of raising inside the trait handler.
            pass

    def _dimensions_changed(self):
        """This does nothing and only changes to update_data do
        anything.
        """
        return

    def _volume_changed(self):
        """Like `_dimensions_changed`: only `update_data` applies changes."""
        return

    def _update_data_fired(self):
        """Rebuild the data and push the new geometry to the mayavi source."""
        self._make_data()
        src = self.source
        if src is not None:
            origin, spacing = self._origin_and_spacing()
            # Set the source spacing and origin.
            src.set(spacing=spacing, origin=origin)
            # Update the sources data.
            src.update_image_data = True
            self._reset_ipw()

    def _reset_ipw(self):
        """Re-place the running image plane widgets and re-render."""
        ipw1, ipw2, ipw3 = self._ipw1, self._ipw2, self._ipw3
        if ipw1.running:
            ipw1.ipw.place_widget()
        if ipw2.running:
            ipw2.ipw.place_widget()
            ipw2.ipw.plane_orientation = 'y_axes'
        if ipw3.running:
            ipw3.ipw.place_widget()
            ipw3.ipw.plane_orientation = 'z_axes'
        self.source.render()

    def _data_changed(self, value):
        """Forward new data to the mayavi source, if it exists yet."""
        if self.source is None:
            return
        self.source.scalar_data = value

    def _window_changed(self):
        """Show the data now, or once the mayavi engine has started."""
        m = self.get_mayavi()
        if m.engine.running:
            if len(self.data) == 0:
                # Happens since the window may be set on __init__ at
                # which time the data is not created.
                self._make_data()
            self._show_data()
        else:
            # Show the data once the mayavi engine has started.
            m.engine.on_trait_change(self._show_data, 'started')
Пример #29
0
class T1(Hahn):
    """Pulsed measurement titled 'T1' built on top of `Hahn`.

    For every delay ``t`` in ``tau`` the sequence records two laser
    read-outs: one after a plain wait of ``t`` and one after a microwave
    pulse of length ``t_pi`` followed by the same wait ``t``.
    """

    def generate_sequence(self):
        """Return the pulse sequence as a list of ``(channels, duration)``.

        Per ``t`` in ``tau``:
        ``wait t -> laser+trigger -> wait`` then
        ``mw (t_pi) -> wait t -> laser+trigger -> wait``.
        """
        laser = self.laser
        wait = self.wait
        t_pi = self.t_pi
        sequence = []
        for t in self.tau:
            # First read-out: free wait of t without a microwave pulse.
            sequence.append(([], t))
            sequence.append((['laser', 'trigger'], laser))
            sequence.append(([], wait))
            # Second read-out: microwave pulse of length t_pi, then wait t.
            sequence.append((['mw'], t_pi))
            sequence.append(([], t))
            sequence.append((['laser', 'trigger'], laser))
            sequence.append(([], wait))
        return sequence

    def _get_sequence_points(self):
        """Number of laser read-outs: two per value of ``tau``."""
        return 2 * len(self.tau)

    # Persisted settings: the Pulsed ones plus the pulse lengths.
    get_set_items = Pulsed.get_set_items + ['t_pi2', 't_pi', 't_3pi2']

    traits_view = View(
        VGroup(
            HGroup(
                Item('submit_button', show_label=False),
                Item('remove_button', show_label=False),
                Item('resubmit_button', show_label=False),
                Item('priority'),
            ),
            HGroup(
                Item('freq', width=40),
                Item('power', width=20),
                Item('t_pi2', width=20),
                Item('t_pi', width=20),
                Item('t_3pi2', width=20),
            ),
            HGroup(
                Item('tau_begin', width=20),
                Item('tau_end', width=20),
                Item('tau_delta', width=20),
                Item('rabi_contrast', width=20),
            ),
            HGroup(
                Item('laser', width=40),
                Item('wait', width=40),
                Item('bin_width', width=40),
            ),
            HGroup(
                Item('state', style='readonly'),
                Item('run_time', style='readonly', format_str='%.f', width=50),
                Item('sweeps',
                     editor=TextEditor(auto_set=False,
                                       enter_set=True,
                                       evaluate=float,
                                       format_func=lambda x: '%.3e' % x),
                     width=40),
                Item('expected_duration',
                     style='readonly',
                     editor=TextEditor(evaluate=float,
                                       format_func=lambda x: '%.3e' % x),
                     width=40),
                Item('progress', style='readonly'),
                Item('elapsed_time', style='readonly'),
            ),
        ),
        title='T1',
    )
Пример #30
0
class TemplatePicker(HasTraits):
    # NOTE(review): `template` is not referenced by the methods shown here.
    template = Array
    # Cross-correlation image, set from cv_funcs.xcorr.
    CC = Array
    # One array of peaks per image (see locate_peaks): column 0 is x (column
    # index), column 1 is y (row index), column 2 the peak value rescaled
    # by 1/255.
    peaks = List
    # Constant lower bound for the dynamic `top`/`left` ranges.
    zero = Int(0)
    # Side length of the square template, in array elements.
    tmp_size = Range(low=2, high=512, value=64, cols=4)
    # Dynamic upper bounds, recomputed from the data shape in
    # update_max_pos.  NOTE: despite the names, max_pos_x bounds `top`
    # (axis 0 of sig.data) and max_pos_y bounds `left` (axis 1).
    max_pos_x = Int(1023)
    max_pos_y = Int(1023)
    # Template corner: `top` indexes axis 0 and `left` axis 1 of sig.data.
    top = Range(low='zero', high='max_pos_x', value=20, cols=4)
    left = Range(low='zero', high='max_pos_y', value=20, cols=4)
    # NOTE(review): not referenced by the methods shown here.
    is_square = Bool
    # Main image plot and template preview plot.
    img_plot = Instance(Plot)
    tmp_plot = Instance(Plot)
    # Button that triggers locate_peaks.
    findpeaks = Button
    peak_width = Range(low=2, high=200, value=10)
    # Fired by the view's "Original image" group (trait_modified).
    tab_selected = Event
    # Show the cross-correlation image instead of the raw image.
    ShowCC = Bool
    # Outer container (image + colorbar), overlay container (image +
    # peaks) and the colorbar itself.
    img_container = Instance(Component)
    container = Instance(Component)
    colorbar = Instance(Component)
    # Peaks inside the threshold range: over all images / current image.
    numpeaks_total = Int(0)
    numpeaks_img = Int(0)
    # Handler class; replaced by an instance in __init__.
    OK_custom = OK_custom_handler
    # Range selection tool attached to the colorbar.
    cbar_selection = Instance(RangeSelection)
    cbar_selected = Event
    # Current threshold (lower, upper) on the peak value.
    thresh = Trait(None, None, List, Tuple, Array)
    thresh_upper = Float(1.0)
    thresh_lower = Float(0.0)
    # Number of images in the stack, current image, and the image the
    # template was last taken from.
    numfiles = Int(1)
    img_idx = Int(0)
    tmp_img_idx = Int(0)

    # Cursor on the main image marking the template's (left, top) corner.
    csr = Instance(BaseCursorTool)

    traits_view = View(HFlow(
        # Left: the main image and the cross-correlation toggle.
        VGroup(Item("img_container",
                    editor=ComponentEditor(),
                    show_label=False),
               Group(
                   Spring(),
                   Item("ShowCC",
                        editor=BooleanEditor(),
                        label="Show cross correlation image")),
               label="Original image",
               show_border=True,
               trait_modified="tab_selected"),
        # Right: template selection and peak-finding parameters.
        VGroup(
            Group(HGroup(
                Item("left", label="Left coordinate", style="custom"),
                Item("top", label="Top coordinate", style="custom"),
            ),
                  Item("tmp_size", label="Template size", style="custom"),
                  Item("tmp_plot",
                       editor=ComponentEditor(height=256, width=256),
                       show_label=False,
                       resizable=True),
                  label="Template",
                  show_border=True),
            Group(Item("peak_width", label="Peak width", style="custom"),
                  Group(
                      Spring(),
                      Item("findpeaks",
                           editor=ButtonEditor(label="Find Peaks"),
                           show_label=False),
                      Spring(),
                  ),
                  HGroup(
                      Item("thresh_lower",
                           label="Threshold Lower Value",
                           editor=TextEditor(evaluate=float,
                                             format_str='%1.4f')),
                      Item("thresh_upper",
                           label="Threshold Upper Value",
                           editor=TextEditor(evaluate=float,
                                             format_str='%1.4f')),
                  ),
                  HGroup(
                      Item("numpeaks_img",
                           label="Number of Cells selected (this image)",
                           style='readonly'),
                      Spring(),
                      Item("numpeaks_total", label="Total", style='readonly'),
                      Spring(),
                  ),
                  label="Peak parameters",
                  show_border=True),
        )),
                       # OK is only enabled once at least one peak is selected.
                       buttons=[
                           Action(name='OK',
                                  enabled_when='numpeaks_total > 0'),
                           CancelButton
                       ],
                       title="Template Picker",
                       handler=OK_custom,
                       kind='livemodal',
                       key_bindings=key_bindings,
                       width=960,
                       height=600)

    def __init__(self, signal_instance):
        super(TemplatePicker, self).__init__()
        try:
            import cv
        except:
            print "OpenCV unavailable.  Can't do cross correlation without it.  Aborting."
            return None
        self.OK_custom = OK_custom_handler()
        self.sig = signal_instance
        if not hasattr(self.sig.mapped_parameters, "original_files"):
            self.sig.data = np.atleast_3d(self.sig.data)
            self.titles = [self.sig.mapped_parameters.name]
        else:
            self.numfiles = len(
                self.sig.mapped_parameters.original_files.keys())
            self.titles = self.sig.mapped_parameters.original_files.keys()
        tmp_plot_data = ArrayPlotData(
            imagedata=self.sig.data[self.top:self.top + self.tmp_size,
                                    self.left:self.left + self.tmp_size,
                                    self.img_idx])
        tmp_plot = Plot(tmp_plot_data, default_origin="top left")
        tmp_plot.img_plot("imagedata", colormap=jet)
        tmp_plot.aspect_ratio = 1.0
        self.tmp_plot = tmp_plot
        self.tmp_plotdata = tmp_plot_data
        self.img_plotdata = ArrayPlotData(
            imagedata=self.sig.data[:, :, self.img_idx])
        self.img_container = self._image_plot_container()

        self.crop_sig = None

    def render_image(self):
        """Build the main image plot with its cursor, pan and zoom tools.

        The plot is stored in ``self.img_plot`` (and its cursor in
        ``self.csr``) and returned.
        """
        data_shape = self.sig.data.shape
        plot = Plot(self.img_plotdata, default_origin="top left")
        renderer = plot.img_plot("imagedata", colormap=gray)[0]
        counter = "%s of %s: " % (self.img_idx + 1, self.numfiles)
        plot.title = counter + self.titles[self.img_idx]
        plot.aspect_ratio = float(data_shape[1]) / float(data_shape[0])

        # Draggable cursor marking the template's (left, top) corner.
        cursor = CursorTool(renderer, drag_button='left', color='white',
                            line_width=2.0)
        self.csr = cursor
        cursor.current_position = self.left, self.top
        renderer.overlays.append(cursor)

        # Right-drag pans; box zoom keeps the image aspect ratio.
        plot.tools.append(PanTool(plot, drag_button="right"))
        box_zoom = ZoomTool(plot, tool_mode="box", always_on=False,
                            aspect_ratio=plot.aspect_ratio)
        plot.overlays.append(box_zoom)
        self.img_plot = plot
        return plot

    def render_scatplot(self):
        """Build a colour-mapped scatter plot of the current image's peaks.

        The plot shares ``range2d`` with ``self.img_plot`` so it lies on top
        of the image.  It is stored in ``self.scatplot`` (data in
        ``self.peakdata``) and returned.
        """
        current = self.peaks[self.img_idx]
        peakdata = ArrayPlotData()
        for column, key in enumerate(("index", "value", "color")):
            peakdata.set_data(key, current[:, column])
        scatplot = Plot(peakdata,
                        aspect_ratio=self.img_plot.aspect_ratio,
                        default_origin="top left")
        mapper = jet(DataRange1D(low=0.0, high=1.0))
        scatplot.plot(("index", "value", "color"),
                      type="cmap_scatter",
                      name="my_plot",
                      color_mapper=mapper,
                      marker="circle",
                      fill_alpha=0.5,
                      marker_size=6)
        scatplot.x_grid.visible = False
        scatplot.y_grid.visible = False
        scatplot.range2d = self.img_plot.range2d
        self.scatplot = scatplot
        self.peakdata = peakdata
        return scatplot

    def _image_plot_container(self):
        """Assemble the image, peak overlay and colorbar into containers.

        ``self.container`` overlays the image plot and, when the current
        image has peaks, the scatter plot.  ``self.img_container`` places
        that next to the colorbar and is returned.
        """
        image_plot = self.render_image()

        # Create a container to position the plot and the colorbar side-by-side
        self.container = OverlayPlotContainer()
        self.container.add(image_plot)
        outer = HPlotContainer(use_backbuffer=False)
        self.img_container = outer
        outer.add(self.container)
        outer.bgcolor = "white"

        if self.numpeaks_img > 0:
            self.container.add(self.render_scatplot())
            outer.add(self.draw_colorbar())
        return outer

    def draw_colorbar(self):
        """Create the colorbar for ``self.scatplot`` and return it.

        Adds an overlay that fades peaks outside the selected range,
        re-applies a previously chosen ``self.thresh`` and stores
        ``cbar_selection``, ``cmap_renderer`` and ``colorbar`` on ``self``.
        """
        scatplot = self.scatplot
        cmap_renderer = scatplot.plots["my_plot"][0]
        # Fade the peaks whose colour value lies outside the selection.
        selection = ColormappedSelectionOverlay(cmap_renderer,
                                                fade_alpha=0.35,
                                                selection_type="range")
        cmap_renderer.overlays.append(selection)
        if self.thresh is not None:
            # Re-apply a threshold chosen before this colorbar existed.
            cmap_renderer.color_data.metadata['selections'] = self.thresh
            cmap_renderer.color_data.metadata_changed = {
                'selections': self.thresh
            }
        # Create the colorbar, handing in the range and the colormap that
        # the scatter renderer actually uses (previously the colormap was
        # looked up on the Plot and never passed).
        colorbar = ColorBar(
            index_mapper=LinearMapper(range=DataRange1D(low=0.0, high=1.0)),
            color_mapper=cmap_renderer.color_mapper,
            orientation='v',
            resizable='v',
            width=30,
            padding=20)
        colorbar_selection = RangeSelection(component=colorbar)
        colorbar.tools.append(colorbar_selection)
        colorbar.overlays.append(
            RangeSelectionOverlay(component=colorbar,
                                  border_color="white",
                                  alpha=0.8,
                                  fill_color="lightgray",
                                  metadata_name='selections'))
        self.cbar_selection = colorbar_selection
        self.cmap_renderer = cmap_renderer
        colorbar.plot = cmap_renderer
        colorbar.padding_top = scatplot.padding_top
        colorbar.padding_bottom = scatplot.padding_bottom
        self.colorbar = colorbar
        return colorbar

    @on_trait_change('ShowCC')
    def toggle_cc_view(self):
        """Switch the main image between the raw image and its correlation.

        With ``ShowCC`` on, the template is cut from image ``tmp_img_idx``.
        That is the image it was picked on, the same one ``update_CC`` and
        ``locate_peaks`` use. It is then correlated with the current image.
        """
        current = self.sig.data[:, :, self.img_idx]
        if self.ShowCC:
            template = self.sig.data[self.top:self.top + self.tmp_size,
                                     self.left:self.left + self.tmp_size,
                                     self.tmp_img_idx]
            self.CC = cv_funcs.xcorr(template, current)
            self.img_plotdata.set_data("imagedata", self.CC)
        else:
            self.img_plotdata.set_data("imagedata", current)
        self.redraw_plots()

    @on_trait_change("img_idx")
    def update_img_depth(self):
        """Show image ``img_idx`` (or its correlation) and update the title.

        The correlation template is cut from image ``tmp_img_idx``. That is
        the image it was picked on, the same one ``update_CC`` and
        ``locate_peaks`` use.
        """
        current = self.sig.data[:, :, self.img_idx]
        if self.ShowCC:
            template = self.sig.data[self.top:self.top + self.tmp_size,
                                     self.left:self.left + self.tmp_size,
                                     self.tmp_img_idx]
            self.CC = cv_funcs.xcorr(template, current)
            self.img_plotdata.set_data("imagedata", self.CC)
        else:
            self.img_plotdata.set_data("imagedata", current)
        self.img_plot.title = "%s of %s: " % (
            self.img_idx + 1, self.numfiles) + self.titles[self.img_idx]
        self.redraw_plots()

    @on_trait_change('tmp_size')
    def update_max_pos(self):
        """Recompute the template-corner limits after ``tmp_size`` changes.

        ``top`` indexes axis 0 of ``sig.data`` and is bounded by
        ``max_pos_x``; ``left`` indexes axis 1 and is bounded by
        ``max_pos_y``.  Each coordinate is clamped against the limit of its
        own axis (the two clamps used to be swapped, which broke
        non-square images).  The clamp happens before the new limit is
        stored, so the assigned value is always inside the current range.
        """
        shape = self.sig.data.shape
        max_pos_x = shape[0] - self.tmp_size - 1
        if self.top > max_pos_x:
            self.top = max_pos_x
        self.max_pos_x = max_pos_x
        max_pos_y = shape[1] - self.tmp_size - 1
        if self.left > max_pos_y:
            self.left = max_pos_y
        self.max_pos_y = max_pos_y

    def increase_img_idx(self, info):
        """Step to the next image, wrapping from the last back to the first.

        ``info`` is unused (presumably the TraitsUI info object passed by
        the key bindings).
        """
        last = self.numfiles - 1
        self.img_idx = 0 if self.img_idx == last else self.img_idx + 1

    def decrease_img_idx(self, info):
        """Step to the previous image, wrapping from the first to the last.

        ``info`` is unused (presumably the TraitsUI info object passed by
        the key bindings).
        """
        at_first = self.img_idx == 0
        self.img_idx = self.numfiles - 1 if at_first else self.img_idx - 1

    @on_trait_change('left, top')
    def update_csr_position(self):
        """Move the image cursor to the template's (left, top) corner."""
        corner = (self.left, self.top)
        self.csr.current_position = corner

    @on_trait_change('csr:current_position')
    def update_top_left(self):
        """Copy the cursor position back into ``left`` and ``top``."""
        new_left, new_top = self.csr.current_position
        self.left = new_left
        self.top = new_top

    @on_trait_change('left, top, tmp_size')
    def update_tmp_plot(self):
        """Show the current template region and remember its source image."""
        size = self.tmp_size
        top, left = self.top, self.left
        patch = self.sig.data[top:top + size, left:left + size, self.img_idx]
        self.tmp_plotdata.set_data("imagedata", patch)
        # Reset the preview's grid to `size` points along each axis.
        grid = self.tmp_plot.range2d.sources[0]
        grid.set_data(np.arange(size), np.arange(size))
        self.tmp_img_idx = self.img_idx

    @on_trait_change('left, top, tmp_size')
    def update_CC(self):
        """Refresh the displayed correlation after the template changes.

        If peaks had been counted, they are replaced by a single placeholder
        row ``[0, 0, -1]``. Presumably this is because they were found with
        the previous template.
        """
        if self.ShowCC:
            size = self.tmp_size
            template = self.sig.data[self.top:self.top + size,
                                     self.left:self.left + size,
                                     self.tmp_img_idx]
            target = self.sig.data[:, :, self.img_idx]
            self.CC = cv_funcs.xcorr(template, target)
            self.img_plotdata.set_data("imagedata", self.CC)
            # Resize the image grid to the correlation's shape.
            grid = self.img_plot.range2d.sources[0]
            grid.set_data(np.arange(self.CC.shape[1]),
                          np.arange(self.CC.shape[0]))
        if self.numpeaks_total > 0:
            self.peaks = [np.array([[0, 0, -1]])]

    @on_trait_change('cbar_selection:selection')
    def update_thresh(self):
        """Apply the colorbar range selection as the peak threshold.

        Copies the selection into ``thresh``, ``thresh_lower`` and
        ``thresh_upper`` and into the scatter renderer's ``selections``
        metadata, then redraws.  This is best effort. Failures are ignored,
        e.g. when no colorbar has been drawn yet (``cmap_renderer`` is
        missing), or presumably when the selection was cleared to None.
        """
        try:
            thresh = self.cbar_selection.selection
            self.thresh = thresh
            self.cmap_renderer.color_data.metadata['selections'] = thresh
            self.thresh_lower = thresh[0]
            self.thresh_upper = thresh[1]
            self.cmap_renderer.color_data.metadata_changed = {
                'selections': thresh
            }
            self.container.request_redraw()
            self.img_container.request_redraw()
        except Exception:
            # Narrowed from a bare except so KeyboardInterrupt/SystemExit
            # are no longer swallowed.
            pass

    @on_trait_change('thresh_upper,thresh_lower')
    def manual_thresh_update(self):
        """Apply thresholds typed into the lower/upper text fields.

        ``self.thresh`` is always updated.  The renderer metadata is only
        touched once a colorbar exists (``draw_colorbar`` sets
        ``cmap_renderer`` and re-applies ``self.thresh`` when it runs).
        """
        self.thresh = [self.thresh_lower, self.thresh_upper]
        cmap_renderer = getattr(self, 'cmap_renderer', None)
        if cmap_renderer is None:
            # No peaks plotted yet: nothing to update visually.
            return
        cmap_renderer.color_data.metadata['selections'] = self.thresh
        cmap_renderer.color_data.metadata_changed = {
            'selections': self.thresh
        }
        self.container.request_redraw()
        self.img_container.request_redraw()

    @on_trait_change('peaks,cbar_selection:selection,img_idx')
    def calc_numpeaks(self):
        """Count the peaks whose value lies inside the threshold range.

        Sets ``numpeaks_total`` (all images) and ``numpeaks_img`` (current
        image).  Without a usable colorbar selection the range ``(0, 1)``
        is used.
        """
        try:
            thresh = self.cbar_selection.selection
            self.thresh = thresh
        except Exception:
            # No colorbar yet (cbar_selection is None) or a selection the
            # `thresh` trait rejects: use the default range below.
            thresh = None
        if thresh is None or len(thresh) == 0:
            thresh = (0, 1)
        low, high = thresh[0], thresh[1]

        def count_inside(values):
            # masked_inside masks the values within [low, high].
            return int(np.sum(np.ma.masked_inside(values, low, high).mask))

        self.numpeaks_total = sum(count_inside(p[:, 2]) for p in self.peaks)
        try:
            self.numpeaks_img = count_inside(self.peaks[self.img_idx][:, 2])
        except (IndexError, TypeError):
            # e.g. img_idx beyond the peaks list after update_CC replaced it
            # with a single placeholder entry.
            self.numpeaks_img = 0

    @on_trait_change('findpeaks')
    def locate_peaks(self):
        """Find correlation peaks of the template in every image.

        The template is cut from image ``tmp_img_idx``.  One peak array per
        image is stored in ``self.peaks``; ``self.CC`` is left holding the
        correlation of the last image.
        """
        from hyperspy import peak_char as pc
        size = self.tmp_size
        template = self.sig.data[self.top:self.top + size,
                                 self.left:self.left + size,
                                 self.tmp_img_idx]
        found = []
        for idx in xrange(self.numfiles):
            self.CC = cv_funcs.xcorr(template, self.sig.data[:, :, idx])
            # The peak finder needs values above 1, so scale the
            # correlation by 255 and scale the peak values back afterwards.
            pks = pc.two_dim_findpeaks(self.CC * 255,
                                       peak_width=self.peak_width,
                                       medfilt_radius=None)
            pks[:, 2] = pks[:, 2] / 255.
            found.append(pks)
        self.peaks = found

    def mask_peaks(self, idx):
        """Return image ``idx``'s peaks with out-of-threshold values masked.

        Column 2 (the peak value) is masked where it lies outside the
        colorbar selection.  When there is no selection (no colorbar yet,
        or a None/empty selection) the range ``(0, 1)`` is used, the same
        fallback as ``calc_numpeaks``; previously None crashed here.
        """
        thresh = getattr(self.cbar_selection, 'selection', None)
        if thresh is None or len(thresh) == 0:
            thresh = (0, 1)
        mpeaks = np.ma.asarray(self.peaks[idx])
        mpeaks[:, 2] = np.ma.masked_outside(mpeaks[:, 2], thresh[0], thresh[1])
        return mpeaks

    @on_trait_change("peaks")
    def redraw_plots(self):
        """Rebuild the image plot and, if needed, the peak overlay/colorbar."""
        self.container.remove(self.img_plot)
        # render_image also stores the new plot in self.img_plot.
        self.container.add(self.render_image())

        # Remove the previous scatter plot and colorbar independently.  Each
        # may never have been created or may already have been removed by
        # an earlier redraw, so removal is best effort.
        oldscat = getattr(self, 'scatplot', None)
        if oldscat is not None:
            try:
                self.container.remove(oldscat)
            except Exception:
                pass
        if self.colorbar is not None:
            try:
                self.img_container.remove(self.colorbar)
            except Exception:
                pass

        if self.numpeaks_img > 0:
            # render_scatplot / draw_colorbar store self.scatplot and
            # self.colorbar themselves.
            self.container.add(self.render_scatplot())
            self.img_container.add(self.draw_colorbar())

        self.container.request_redraw()
        self.img_container.request_redraw()

    def crop_cells_stack(self):
        """Crop the selected cells and store the result in ``self.crop_sig``.

        A single image gives one cropped signal; a stack gives an
        ``AggregateCells`` of one cropped signal per image.
        """
        from eelslab.signals.aggregate import AggregateCells
        if self.numfiles != 1:
            per_image = [self.crop_cells(idx) for idx in xrange(self.numfiles)]
            self.crop_sig = AggregateCells(*per_image)
        else:
            self.crop_sig = self.crop_cells()

    def crop_cells(self, idx=0):
        print "cropping cells..."
        from hyperspy.signals.image import Image
        # filter the peaks that are outside the selected threshold
        peaks = np.ma.compress_rows(self.mask_peaks(idx))
        tmp_sz = self.tmp_size
        data = np.zeros((tmp_sz, tmp_sz, peaks.shape[0]))
        if not hasattr(self.sig.mapped_parameters, "original_files"):
            parent = self.sig
        else:
            parent = self.sig.mapped_parameters.original_files[
                self.titles[idx]]
        for i in xrange(peaks.shape[0]):
            # crop the cells from the given locations
            data[:, :, i] = self.sig.data[peaks[i, 1]:peaks[i, 1] + tmp_sz,
                                          peaks[i,
                                                0]:peaks[i, 0] + tmp_sz, idx]
            crop_sig = Image({
                'data': data,
                'mapped_parameters': {
                    'name': 'Cropped cells from %s' % self.titles[idx],
                    'record_by': 'image',
                    'locations': peaks,
                    'parent': parent,
                }
            })
        return crop_sig
        # attach a class member that has the locations from which the images were cropped
        print "Complete.  "