コード例 #1
0
ファイル: extract_grid.py プロジェクト: sjl421/code-2
class ExtractGrid(FilterBase):
    """This filter enables one to select a portion of, or subsample an
    input dataset which may be a StructuredPoints, StructuredGrid or
    Rectilinear.

    The volume of interest is set through the ``x_min``/``x_max``,
    ``y_min``/``y_max`` and ``z_min``/``z_max`` traits and the subsampling
    through ``x_ratio``/``y_ratio``/``z_ratio``.  The allowed ranges of
    these traits are refreshed from the whole extent of the wrapped TVTK
    filter's input.
    """
    # The version of this class.  Used for persistence.
    __version__ = 0

    # Minimum x value.
    x_min = Range(value=0,
                  low='_x_low',
                  high='_x_high',
                  enter_set=True,
                  auto_set=False,
                  desc='minimum x value of the domain')

    # Maximum x value.
    x_max = Range(value=10000,
                  low='_x_low',
                  high='_x_high',
                  enter_set=True,
                  auto_set=False,
                  desc='maximum x value of the domain')

    # Minimum y value.
    y_min = Range(value=0,
                  low='_y_low',
                  high='_y_high',
                  enter_set=True,
                  auto_set=False,
                  desc='minimum y value of the domain')

    # Maximum y value.
    y_max = Range(value=10000,
                  low='_y_low',
                  high='_y_high',
                  enter_set=True,
                  auto_set=False,
                  desc='maximum y value of the domain')

    # Minimum z value.
    z_min = Range(value=0,
                  low='_z_low',
                  high='_z_high',
                  enter_set=True,
                  auto_set=False,
                  desc='minimum z value of the domain')

    # Maximum z value.
    z_max = Range(value=10000,
                  low='_z_low',
                  high='_z_high',
                  enter_set=True,
                  auto_set=False,
                  desc='maximum z value of the domain')

    # Sample rate in x.
    x_ratio = Range(value=1,
                    low='_min_sample',
                    high='_x_s_high',
                    enter_set=True,
                    auto_set=False,
                    desc='sample rate along x')

    # Sample rate in y.
    y_ratio = Range(value=1,
                    low='_min_sample',
                    high='_y_s_high',
                    enter_set=True,
                    auto_set=False,
                    desc='sample rate along y')

    # Sample rate in z.
    z_ratio = Range(value=1,
                    low='_min_sample',
                    high='_z_s_high',
                    enter_set=True,
                    auto_set=False,
                    desc='sample rate along z')

    # The actual TVTK filter that this class manages.
    filter = Instance(tvtk.Object, tvtk.ExtractVOI(), allow_none=False)

    input_info = PipelineInfo(
        datasets=['image_data', 'rectilinear_grid', 'structured_grid'],
        attribute_types=['any'],
        attributes=['any'])

    output_info = PipelineInfo(
        datasets=['image_data', 'rectilinear_grid', 'structured_grid'],
        attribute_types=['any'],
        attributes=['any'])

    ########################################
    # Private traits.

    # Determines the lower/upper limit of the axes for the sliders.
    _min_sample = Int(1)
    _x_low = Int(0)
    _x_high = Int(10000)
    _x_s_high = Int(100)
    _y_low = Int(0)
    _y_high = Int(10000)
    _y_s_high = Int(100)
    _z_low = Int(0)
    _z_high = Int(10000)
    _z_s_high = Int(100)

    ########################################
    # View related traits.

    # The View for this object.
    view = View(
        Group(Item(label='Select Volume Of Interest'),
              Item(name='x_min'),
              Item(name='x_max'),
              Item(name='y_min'),
              Item(name='y_max'),
              Item(name='z_min'),
              Item(name='z_max'),
              Item('_'),
              Item(label='Select Sample Ratio'),
              Item(name='x_ratio'),
              Item(name='y_ratio'),
              Item(name='z_ratio'),
              label='VOI'),
        Group(Item(name='filter', style='custom', resizable=True),
              show_labels=False,
              label='Filter'),
        resizable=True,
    )

    ######################################################################
    # `object` interface
    ######################################################################
    def __get_pure_state__(self):
        """Return the parent's pure state with the VOI limits, the sample
        ratios and the private slider-limit traits removed.
        """
        d = super(ExtractGrid, self).__get_pure_state__()
        for axis in ('x', 'y', 'z'):
            for name in ('_min', '_max'):
                d.pop(axis + name, None)
            d.pop('_' + axis + '_low', None)
            d.pop('_' + axis + '_high', None)
            d.pop('_' + axis + '_s_high', None)
            d.pop(axis + '_ratio', None)
        return d

    ######################################################################
    # `Filter` interface
    ######################################################################
    def update_pipeline(self):
        """Create the TVTK filter matching the first input's first output
        and connect it.

        Does nothing when there are no inputs.  When the input dataset is
        none of the supported VTK classes an error is reported through
        `error` and the current filter is left untouched.
        """
        inputs = self.inputs
        if len(inputs) == 0:
            return

        dataset = inputs[0].outputs[0]
        # (VTK class name, TVTK filter class) pairs.  A tuple of pairs is
        # used instead of a dict so the lookup order is fixed and no
        # Python-2-only `dict.iteritems()` call is needed.
        filter_classes = (
            ('vtkStructuredGrid', tvtk.ExtractGrid),
            ('vtkRectilinearGrid', tvtk.ExtractRectilinearGrid),
            ('vtkImageData', tvtk.ExtractVOI),
        )

        for vtk_name, klass in filter_classes:
            if dataset.is_a(vtk_name):
                self.filter = klass()
                break
        else:
            # No supported class matched the input.
            error('This filter does not support %s objects' %
                  (dataset.__class__.__name__))
            return

        fil = self.filter
        fil.input = dataset
        fil.update_whole_extent()
        fil.update()
        self._set_outputs([fil.output])
        self._update_limits()
        self._update_voi()
        self._update_sample_rate()

    def update_data(self):
        """This method is invoked (automatically) when any of the
        inputs sends a `data_changed` event.
        """
        self._update_limits()
        fil = self.filter
        fil.update_whole_extent()
        fil.update()
        # Propagate the data_changed event.
        self.data_changed = True

    ######################################################################
    # Non-public methods.
    ######################################################################
    def _update_limits(self):
        """Refresh the slider limits from the whole extent of the filter's
        input.
        """
        extents = self.filter.input.whole_extent
        self._x_low, self._x_high = extents[:2]
        self._y_low, self._y_high = extents[2:4]
        self._z_low, self._z_high = extents[4:]
        # The sample-ratio sliders start at `_min_sample` (1), so their
        # upper limit is never allowed to drop below 1.
        self._x_s_high = max(1, self._x_high)
        self._y_s_high = max(1, self._y_high)
        self._z_s_high = max(1, self._z_high)

    def _x_min_changed(self, val):
        """Keep ``x_max >= x_min``; otherwise push the new VOI to the filter.

        Raising ``x_max`` triggers its own handler, which updates the VOI.
        """
        if val > self.x_max:
            self.x_max = val
        else:
            self._update_voi()

    def _x_max_changed(self, val):
        """Keep ``x_min <= x_max``; otherwise push the new VOI to the filter."""
        if val < self.x_min:
            self.x_min = val
        else:
            self._update_voi()

    def _y_min_changed(self, val):
        """Keep ``y_max >= y_min``; otherwise push the new VOI to the filter."""
        if val > self.y_max:
            self.y_max = val
        else:
            self._update_voi()

    def _y_max_changed(self, val):
        """Keep ``y_min <= y_max``; otherwise push the new VOI to the filter."""
        if val < self.y_min:
            self.y_min = val
        else:
            self._update_voi()

    def _z_min_changed(self, val):
        """Keep ``z_max >= z_min``; otherwise push the new VOI to the filter."""
        if val > self.z_max:
            self.z_max = val
        else:
            self._update_voi()

    def _z_max_changed(self, val):
        """Keep ``z_min <= z_max``; otherwise push the new VOI to the filter."""
        if val < self.z_min:
            self.z_min = val
        else:
            self._update_voi()

    def _x_ratio_changed(self):
        """Push the changed x sample ratio to the filter."""
        self._update_sample_rate()

    def _y_ratio_changed(self):
        """Push the changed y sample ratio to the filter."""
        self._update_sample_rate()

    def _z_ratio_changed(self):
        """Push the changed z sample ratio to the filter."""
        self._update_sample_rate()

    def _update_voi(self):
        """Set the filter's `voi` from the min/max traits, update the filter
        and fire `data_changed`.
        """
        f = self.filter
        f.voi = (self.x_min, self.x_max, self.y_min, self.y_max, self.z_min,
                 self.z_max)
        f.update_whole_extent()
        f.update()
        self.data_changed = True

    def _update_sample_rate(self):
        """Set the filter's `sample_rate` from the ratio traits, update the
        filter and fire `data_changed`.
        """
        f = self.filter
        f.sample_rate = (self.x_ratio, self.y_ratio, self.z_ratio)
        f.update_whole_extent()
        f.update()
        self.data_changed = True

    def _filter_changed(self, old, new):
        """Move the `render` trait-change listener from the old filter to
        the new one.
        """
        if old is not None:
            old.on_trait_change(self.render, remove=True)
        new.on_trait_change(self.render)
コード例 #2
0
ファイル: singleshot.py プロジェクト: Faridelnik/Pi3Diamond
class Pulsed(ManagedJob, GetSetItemsMixin):
    """Defines a pulsed measurement.

    The job programs the module-level instruments ``PG``, ``MW``, ``AWG``
    and ``CS`` (presumably pulse generator, microwave source, arbitrary
    waveform generator and counter -- confirm against the module imports)
    and polls gated counts into `count_data` until `measurement_points`
    samples have been read or a stop is requested.
    """
    keep_data = Bool(
        False)  # helper variable to decide whether to keep existing data

    resubmit_button = Button(
        label='resubmit',
        desc=
        'Submits the measurement to the job manager. Tries to keep previously acquired data. Behaves like a normal submit if sequence or time bins have changed since previous run.'
    )

    sequence = Instance(list, factory=list)

    record_length = Float(value=0,
                          desc='length of acquisition record [ms]',
                          label='record length [ms] ',
                          mode='text')

    count_data = Array(value=np.zeros(2))

    run_time = Float(value=0.0, label='run time [ns]', format_str='%.f')
    stop_time = Range(
        low=1.,
        value=np.inf,
        desc='Time after which the experiment stops by itself [s]',
        label='Stop time [s]',
        mode='text',
        auto_set=False,
        enter_set=True)

    tau_begin = Range(low=0.,
                      high=1e5,
                      value=300.,
                      desc='tau begin [ns]',
                      label='repetition',
                      mode='text',
                      auto_set=False,
                      enter_set=True)
    tau_end = Range(low=1.,
                    high=1e5,
                    value=4000.,
                    desc='tau end [ns]',
                    label='N repetition',
                    mode='text',
                    auto_set=False,
                    enter_set=True)
    tau_delta = Range(low=1.,
                      high=1e5,
                      value=50.,
                      desc='delta tau [ns]',
                      label='delta',
                      mode='text',
                      auto_set=False,
                      enter_set=True)

    # The two-element default doubles as the "tau not generated yet" marker
    # checked in `apply_parameters`.
    tau = Array(value=np.array((0., 1.)))
    sequence_points = Int(value=2, label='number of points', mode='text')

    laser_SST = Range(low=1.,
                      high=5e6,
                      value=200.,
                      desc='laser for SST [ns]',
                      label='laser_SST[ns]',
                      mode='text',
                      auto_set=False,
                      enter_set=True)
    wait_SST = Range(low=1.,
                     high=5e6,
                     value=1000.,
                     desc='wait for SST[ns]',
                     label='wait_SST [ns]',
                     mode='text',
                     auto_set=False,
                     enter_set=True)
    N_shot = Range(low=1,
                   high=20e5,
                   value=2e3,
                   desc='number of shots in SST',
                   label='N_shot',
                   mode='text',
                   auto_set=False,
                   enter_set=True)

    laser = Range(low=1.,
                  high=5e4,
                  value=3000,
                  desc='laser [ns]',
                  label='laser [ns]',
                  mode='text',
                  auto_set=False,
                  enter_set=True)
    wait = Range(low=1.,
                 high=5e4,
                 value=5000.,
                 desc='wait [ns]',
                 label='wait [ns]',
                 mode='text',
                 auto_set=False,
                 enter_set=True)

    freq_center = Range(low=1,
                        high=20e9,
                        value=2.71e9,
                        desc='frequency [Hz]',
                        label='MW freq[Hz]',
                        editor=TextEditor(auto_set=False,
                                          enter_set=True,
                                          evaluate=float,
                                          format_str='%.4e'))
    power = Range(low=-100.,
                  high=25.,
                  value=-26,
                  desc='power [dBm]',
                  label='power[dBm]',
                  editor=TextEditor(auto_set=False,
                                    enter_set=True,
                                    evaluate=float))
    freq = Range(low=1,
                 high=20e9,
                 value=2.71e9,
                 desc='frequency [Hz]',
                 label='freq [Hz]',
                 editor=TextEditor(auto_set=False,
                                   enter_set=True,
                                   evaluate=float,
                                   format_str='%.4e'))
    pi = Range(low=0.,
               high=5e4,
               value=2e3,
               desc='pi pulse length',
               label='pi [ns]',
               mode='text',
               auto_set=False,
               enter_set=True)

    amp = Range(low=0.,
                high=1.0,
                value=1.0,
                desc='Normalized amplitude of waveform',
                label='Amp',
                mode='text',
                auto_set=False,
                enter_set=True)
    vpp = Range(low=0.,
                high=4.5,
                value=0.6,
                desc='Amplitude of AWG [Vpp]',
                label='Vpp',
                mode='text',
                auto_set=False,
                enter_set=True)

    sweeps = Range(low=1.,
                   high=1e4,
                   value=1e2,
                   desc='number of sweeps',
                   label='sweeps',
                   mode='text',
                   auto_set=False,
                   enter_set=True)
    expected_duration = Property(
        trait=Float,
        depends_on='sweeps,sequence',
        desc='expected duration of the measurement [s]',
        label='expected duration [s]')
    elapsed_sweeps = Float(value=0,
                           desc='Elapsed Sweeps ',
                           label='Elapsed Sweeps ',
                           mode='text')
    elapsed_time = Float(value=0,
                         desc='Elapsed Time [ns]',
                         label='Elapsed Time [ns]',
                         mode='text')
    progress = Int(value=0,
                   desc='Progress [%]',
                   label='Progress [%]',
                   mode='text')

    load_button = Button(desc='compile and upload waveforms to AWG',
                         label='load')
    # Set while `load` is running; cleared once it has finished.
    reload = True

    readout_interval = Float(
        1,
        label='Data readout interval [s]',
        desc='How often data read is requested from nidaq')
    samples_per_read = Int(
        200,
        label='# data points per read',
        desc=
        'Number of data points requested from nidaq per read. Nidaq will automatically wait for the data points to be aquired.'
    )

    def submit(self):
        """Submit the job to the JobManager, discarding existing data."""
        self.keep_data = False
        ManagedJob.submit(self)

    def resubmit(self):
        """Submit the job to the JobManager, asking to keep existing data.

        Whether the data is really kept is decided in `apply_parameters`.
        """
        self.keep_data = True
        ManagedJob.submit(self)

    def _resubmit_button_fired(self):
        """React to the resubmit button. Resubmit the job."""
        self.resubmit()

    def generate_sequence(self):
        """Return the pulse sequence; this base implementation returns an
        empty list.
        """
        return []

    def prepare_awg(self):
        """ override this """
        AWG.reset()

    def _load_button_changed(self):
        """React to the load button."""
        self.load()

    def load(self):
        """Recompute `record_length` and `tau`, then prepare the AWG."""
        self.reload = True
        # update record_length, in ms
        self.record_length = self.N_shot * (self.pi + self.laser_SST +
                                            self.wait_SST) * 1e-6
        # make sure tau is updated
        self.tau = np.arange(self.tau_begin, self.tau_end, self.tau_delta)
        self.prepare_awg()
        self.reload = False

    @cached_property
    def _get_expected_duration(self):
        """Return `sweeps` times the summed step durations, scaled by 1e-9.

        The second element of every entry of `sequence` is taken as that
        step's duration.
        """
        sequence_length = 0
        for step in self.sequence:
            sequence_length += step[1]
        return self.sweeps * sequence_length * 1e-9

    def _get_sequence_points(self):
        """Return the number of entries in `tau`."""
        return len(self.tau)

    def apply_parameters(self):
        """Apply the current parameters and decide whether to keep previous data."""
        # If the load button was not used, make sure tau is generated.  A
        # length of two matches the default value of `tau`.
        if self.tau.shape[0] == 2:
            self.tau = np.arange(self.tau_begin, self.tau_end, self.tau_delta)

        self.sequence_points = self._get_sequence_points()
        self.measurement_points = self.sequence_points * int(self.sweeps)
        sequence = self.generate_sequence()

        if self.keep_data and sequence == self.sequence:
            # The sequence is the same as in the previous run: continue
            # from the counters reached so far.
            self.previous_sweeps = self.elapsed_sweeps
            self.previous_elapsed_time = self.elapsed_time
        else:
            # Fresh start: clear the data buffers and the counters.
            self.count_data = np.zeros(self.measurement_points)
            self.old_count_data = np.zeros(self.measurement_points)
            self.previous_sweeps = 0
            self.previous_elapsed_time = 0.0
            self.run_time = 0.0

        # When the job manager stops and starts the job, data should be
        # kept.  Only a new submission should clear data.
        self.keep_data = True
        self.sequence = sequence

    def _run(self):
        """Acquire data.

        Sets `state` to 'run', programs the instruments and then polls the
        counter every `readout_interval` seconds until all
        `measurement_points` samples are read or the thread's stop request
        is set.  On exit `state` is 'done', 'idle' (stopped early) or
        'error' (counter configuration failed or an exception was raised).
        The gated counting task is always stopped.
        """
        try:  # try to run the acquisition from start_up to shut_down
            self.state = 'run'
            self.apply_parameters()

            PG.High([])

            self.prepare_awg()
            MW.setFrequency(self.freq_center)
            MW.setPower(self.power)

            AWG.run()
            time.sleep(4.0)
            PG.Sequence(self.sequence, loop=True)

            # initialize and start nidaq gated counting task, returns 0 if
            # successful
            if CS.configure() != 0:
                logging.getLogger().error('error in nidaq')
                # Leave a terminal state behind, like every other exit path.
                self.state = 'error'
                return

            start_time = time.time()

            # new data will be appended to this array
            aquired_data = np.empty(0)

            while True:

                self.thread.stop_request.wait(self.readout_interval)
                if self.thread.stop_request.isSet():
                    logging.getLogger().debug('Caught stop signal. Exiting.')
                    break

                points_left = self.measurement_points - len(aquired_data)

                self.elapsed_time = (self.previous_elapsed_time +
                                     time.time() - start_time)
                # NOTE(review): `elapsed_time` is already cumulative, so
                # adding it on every poll looks like double counting --
                # confirm the intended meaning of `run_time`.
                self.run_time += self.elapsed_time

                # do not attempt to read more data than necessary
                new_data = CS.read_gated_counts(
                    SampleLength=min(self.samples_per_read, points_left))

                aquired_data = np.append(
                    aquired_data, new_data[:min(len(new_data), points_left)])

                # length of trace may not change due to plot, so just copy
                # acquired data into trace
                self.count_data[:len(aquired_data)] = aquired_data[:]

                # `sweeps` is the total for this run, so it is assigned on
                # top of the carried-over count, not accumulated per poll.
                sweeps = len(aquired_data) // self.sequence_points
                self.elapsed_sweeps = self.previous_sweeps + sweeps
                self.progress = int(100 * len(aquired_data) /
                                    self.measurement_points)

                if self.progress > 99.9:
                    break

            MW.Off()
            PG.High(['laser', 'mw'])
            AWG.stop()

            if self.elapsed_sweeps < self.sweeps:
                self.state = 'idle'
            else:
                self.state = 'done'

        except Exception:  # if anything fails, log the exception and set the state
            logging.getLogger().exception(
                'Something went wrong in pulsed loop.')
            self.state = 'error'

        finally:
            CS.stop_gated_counting()  # stop nidaq task to free counters

    get_set_items = [
        '__doc__', 'record_length', 'laser', 'wait', 'sequence', 'count_data',
        'run_time', 'tau_begin', 'tau_end', 'tau_delta', 'tau', 'freq_center',
        'power', 'laser_SST', 'wait_SST', 'amp', 'vpp', 'pi', 'freq', 'N_shot',
        'readout_interval', 'samples_per_read'
    ]

    traits_view = View(
        VGroup(
            HGroup(
                Item('load_button', show_label=False),
                Item('submit_button', show_label=False),
                Item('remove_button', show_label=False),
                Item('resubmit_button', show_label=False),
                Item('priority'),
            ),
            HGroup(
                Item('freq', width=-70),
                Item('freq_center', width=-70),
                Item('amp', width=-30),
                Item('vpp', width=-30),
                Item('power', width=-40),
                Item('pi', width=-70),
            ),
            HGroup(
                Item('laser', width=-60),
                Item('wait', width=-60),
                Item('laser_SST', width=-50),
                Item('wait_SST', width=-50),
            ),
            HGroup(
                Item('samples_per_read', width=-50),
                Item('N_shot', width=-50),
                Item('record_length', style='readonly'),
            ),
            HGroup(
                Item('tau_begin', width=30),
                Item('tau_end', width=30),
                Item('tau_delta', width=30),
            ),
            HGroup(
                Item('state', style='readonly'),
                Item('run_time', style='readonly', format_str='%.f',
                     width=-60),
                Item('sweeps',
                     editor=TextEditor(auto_set=False,
                                       enter_set=True,
                                       evaluate=float,
                                       format_func=lambda x: '%.3e' % x),
                     width=-50),
                Item('expected_duration',
                     style='readonly',
                     editor=TextEditor(evaluate=float,
                                       format_func=lambda x: '%.2f' % x),
                     width=30),
                Item('progress', style='readonly'),
                Item('elapsed_time',
                     style='readonly',
                     editor=TextEditor(auto_set=False,
                                       enter_set=True,
                                       evaluate=float,
                                       format_func=lambda x: ' %.f' % x),
                     width=-50),
            ),
        ),
        title='Pulsed_SST Measurement',
    )
コード例 #3
0
class TFGaussian(HasTraits):
    """Editable parameters of one entry of the transfer function ``tf``.

    The entry has a ``center`` bounded by the edges of ``tf``, a relative
    width and an RGBA colour.  Changing any of these asks ``tf`` to
    redraw itself.
    """
    # Position of the entry, limited to [left_edge, right_edge].
    center = Range(low='left_edge', high='right_edge')
    # Both edges are read from the owning transfer function.
    left_edge = DelegatesTo('tf')
    right_edge = DelegatesTo('tf')

    # The owning transfer-function object.
    tf = Any

    # Absolute width, computed from `rwidth` and the edges of `tf`.
    width = Property
    # Width as a fraction of the span of `tf`.
    rwidth = Range(0.0, 0.5, 0.05)

    # Colour channels and opacity.
    red = Range(0.0, 1.0, 0.5)
    green = Range(0.0, 1.0, 0.5)
    blue = Range(0.0, 1.0, 0.5)
    alpha = Range(0.0, 1.0, 1.0)

    traits_view = View(
        VGroup(
            HGroup(
                Item('center', editor=RangeEditor(format='%0.4f')),
                Item('rwidth',
                     label='Width',
                     editor=RangeEditor(format='%0.4f')),
            ),
            HGroup(
                Item('red', editor=RangeEditor(format='%0.4f')),
                Item('green', editor=RangeEditor(format='%0.4f')),
                Item('blue', editor=RangeEditor(format='%0.4f')),
                Item('alpha', editor=RangeEditor(format='%0.4f')),
            ),
            show_border=True,
        ),
    )

    def _get_width(self):
        """Return `rwidth` scaled by the span between the edges of `tf`."""
        return self.rwidth * (self.tf.right_edge - self.tf.left_edge)

    def _center_default(self):
        """Start in the middle between the two edges."""
        return (self.left_edge + self.right_edge) / 2.0

    def _width_default(self):
        """Return one twentieth of the span between the two edges."""
        return (self.right_edge - self.left_edge) / 20.0

    # The handlers below all forward the change to the owning transfer
    # function so that it can redraw.

    def _red_changed(self):
        self.tf._redraw()

    def _green_changed(self):
        self.tf._redraw()

    def _blue_changed(self):
        self.tf._redraw()

    def _alpha_changed(self):
        self.tf._redraw()

    def _center_changed(self):
        self.tf._redraw()

    def _height_changed(self):
        self.tf._redraw()

    def _rwidth_changed(self):
        self.tf._redraw()
コード例 #4
0
ファイル: range_editor.py プロジェクト: jtomase/matplotlib
class ToolkitEditorFactory(EditorFactory):
    """Editor factory for range traits.

    Depending on the type and the size of the range it hands out a slider,
    a spin control, an enumeration editor or a text editor.
    """

    #---------------------------------------------------------------------------
    #  Trait definitions:
    #---------------------------------------------------------------------------

    cols = Range(1, 20)  # Number of columns when displayed as an enum
    auto_set = true  # Is user input set on every keystroke?
    enter_set = false  # Is user input set on enter key?
    low_label = Str  # Label for low end of range
    high_label = Str  # Label for high end of range
    is_float = true  # Is the range float (or int)?

    #---------------------------------------------------------------------------
    #  Performs any initialization needed after all constructor traits have
    #  been set:
    #---------------------------------------------------------------------------

    def init(self, handler=None):
        """ Performs any initialization needed after all constructor traits 
            have been set.

            When a handler is given (a `CTrait` is first unwrapped to its
            `handler` attribute) its `low` and `high` become the bounds of
            this factory.
        """
        if handler is not None:
            if isinstance(handler, CTrait):
                handler = handler.handler
            self.low = handler.low
            self.high = handler.high

    #---------------------------------------------------------------------------
    #  Define the 'low' and 'high' traits:
    #---------------------------------------------------------------------------

    def _get_low(self):
        """ Returns the low end of the range.
        """
        return self._low

    def _set_low(self, low):
        """ Stores the low end of the range, updates 'is_float' and fills in
            an empty 'low_label'.
        """
        self._low = low
        # isinstance (rather than an exact type test) also recognizes float
        # subclasses such as numpy.float64.
        # NOTE(review): 'is_float' reflects only the bound assigned last, so
        # mixed int/float bounds depend on assignment order -- confirm.
        self.is_float = isinstance(low, float)
        if self.low_label == '':
            self.low_label = str(low)

    def _get_high(self):
        """ Returns the high end of the range.
        """
        return self._high

    def _set_high(self, high):
        """ Stores the high end of the range, updates 'is_float' and fills
            in an empty 'high_label'.
        """
        self._high = high
        # See '_set_low' for why isinstance is used here.
        self.is_float = isinstance(high, float)
        if self.high_label == '':
            self.high_label = str(high)

    low = Property(_get_low, _set_low)
    high = Property(_get_high, _set_high)

    #---------------------------------------------------------------------------
    #  'Editor' factory methods:
    #---------------------------------------------------------------------------

    def simple_editor(self, ui, object, name, description, parent):
        """ Returns a slider editor for float ranges and for ranges spanning
            at most 100, and a spin editor otherwise.
        """
        if self.is_float or (abs(self.high - self.low) <= 100):
            return SimpleSliderEditor(parent,
                                      factory=self,
                                      ui=ui,
                                      object=object,
                                      name=name,
                                      description=description)
        return SimpleSpinEditor(parent,
                                factory=self,
                                ui=ui,
                                object=object,
                                name=name,
                                description=description)

    def custom_editor(self, ui, object, name, description, parent):
        """ Returns an enumeration editor for integer ranges spanning at
            most 15, and the simple editor otherwise.
        """
        if self.is_float or (abs(self.high - self.low) > 15):
            return self.simple_editor(ui, object, name, description, parent)

        # The enum editor factory is created once and then reused.
        if self._enum is None:
            import enum_editor
            self._enum = enum_editor.ToolkitEditorFactory(values=range(
                self.low, self.high + 1),
                                                          cols=self.cols)
        return self._enum.custom_editor(ui, object, name, description, parent)

    def text_editor(self, ui, object, name, description, parent):
        """ Returns a text editor for the range.
        """
        return RangeTextEditor(parent,
                               factory=self,
                               ui=ui,
                               object=object,
                               name=name,
                               description=description)
コード例 #5
0
class SelectOutput(Filter):
    """
    This filter lets a user select one among several of the outputs of a
    given input.  This is typically very useful for a multi-block data
    source.  
    """

    # The output index in the input to choose from.
    output_index = Range(value=0,
                         enter_set=True, 
                         auto_set=False,
                         low='_min_index',
                         high='_max_index')
    
    input_info = PipelineInfo(datasets=['any'],
                              attribute_types=['any'],
                              attributes=['any'])

    output_info = PipelineInfo(datasets=['any'],
                               attribute_types=['any'],
                               attributes=['any'])

    # The minimum output index of our input.
    _min_index = Int(0, desc='the minimum output index')
    # The maximum output index of our input.
    _max_index = Int(0, desc='the maximum output index')

    ######################################## 
    # Traits View.

    view = View(Group(Item('output_index', 
                           enabled_when='_max_index > 0')),
                resizable=True)

    ######################################################################
    # `object` interface.
    def __get_pure_state__(self):
        """Return the parent's pure state with `output_index` added."""
        d = super(SelectOutput, self).__get_pure_state__()
        d['output_index'] = self.output_index
        return d

    def __set_pure_state__(self, state):
        """Restore the state and re-run the `output_index` handler."""
        super(SelectOutput, self).__set_pure_state__(state)
        # Force an update of the output index -- if not this doesn't
        # change. 
        self._output_index_changed(state.output_index)

    ######################################################################
    # `Filter` interface.
    def update_pipeline(self):
        """Refresh the maximum index from the first input and re-select the
        output.
        """
        # Do nothing if there is no input.
        inputs = self.inputs
        if len(inputs) == 0:
            return

        # Set the maximum index.
        self._max_index = len(inputs[0].outputs) - 1
        self._output_index_changed(self.output_index)

    def update_data(self):
        """Pass the `data_changed` event on."""
        # Propagate the event.
        self.data_changed = True

    ######################################################################
    # Trait handlers.
    def _output_index_changed(self, value):
        """Static trait handler.

        Clamps an out-of-range index (the handler then fires again with
        the clamped value); otherwise selects that output of the first
        input and, when a scene is available, re-renders it.
        """
        if len(self.inputs) == 0:
            # Nothing to select from yet, e.g. when the state is restored
            # before an input is connected.  `update_pipeline` calls this
            # handler again once there is an input.
            return

        if value > self._max_index:
            self.output_index = self._max_index
        elif value < self._min_index:
            self.output_index = self._min_index
        else:
            self._set_outputs([self.inputs[0].outputs[value]])
            s = self.scene
            if s is not None:
                s.renderer.reset_camera_clipping_range()
                s.render()
コード例 #6
0
class MainWindow(HasTraits):
    """Traits-based main window of the slice viewer.

    Shows a matplotlib figure of the current slice next to the image
    affine, a slice-plane selector and a slice-index slider.
    """

    # The matplotlib figure the slice is drawn into.
    figure = Instance(Figure)

    # Slider selecting the slice to view.  The bounds have to be Int traits
    # themselves, otherwise the dynamic updating of the Range does not work.
    slice_index_low = Int(0)
    slice_index_high = Int(91)
    slice_index = Range(low='slice_index_low', high='slice_index_high')

    # Radio box selecting the orthogonal slice plane.
    slice_plane = Enum(_slice_planes)

    # The 4x4 affine shown in a text control.
    affine = Array(Float, (4, 4))

    def __init__(self):
        super(MainWindow, self).__init__()
        # Image object that supplies the slice data.
        self.img = ImageData()
        # Plot of the current slice, drawn into our matplotlib figure.
        self.img_plot = SingleImage(self.figure, self.img.data)

    #
    # Default-value providers for traits
    #
    def _figure_default(self):
        """Provide the initial (empty) matplotlib figure."""
        return Figure()

    def _slice_index_default(self):
        """Provide the initial slice index without triggering the
        on_trait_change handler.
        """
        return 0

    #
    # Event handlers
    #
    @on_trait_change('slice_index, slice_plane')
    def update_slice_index(self):
        """Push the slider value to the image, re-slice and redraw."""
        self.img.set_slice_index(self.slice_index)
        self.update_image_slicing()
        self.image_show()

    #
    # Data Model methods
    #
    def update_affine(self):
        """Copy the affine of the image into the `affine` trait."""
        self.affine = self.img.get_affine()

    def update_image_slicing(self):
        """Apply the selected slice plane, refresh the plotted data and
        update the slider bounds from the image.

        Raises AttributeError for a slice plane name that is not one of
        'Axial', 'Sagittal' or 'Coronal'.
        """
        # XXX: BUG: self.slice_index is set by the slider of the
        # current slice.  When we switch the slice plane, this index
        # may be outside the range of the new slice.  Need to handle
        # this.

        # Position of each plane name inside the module-level
        # `_slice_planes` sequence.
        plane_position = {'Axial': 0, 'Sagittal': 1, 'Coronal': 2}
        plane = self.slice_plane
        if plane not in plane_position:
            raise AttributeError('Unknown slice plane')
        self.img.set_slice_plane(_slice_planes[plane_position[plane]])

        # Re-slice the image array and hand the new data to the plot.
        self.img.update_data()
        self.img_plot.set_data(self.img.data)

        # Adjust the slider bounds to the range reported by the image.
        new_low, new_high = self.img.get_range()
        self.slice_index_low = new_low
        self.slice_index_high = new_high

    def image_show(self):
        """Redraw the image plot."""
        self.img_plot.draw()

    #
    # View code
    #

    # Menus
    def open_menu(self):
        """Ask for a file and, if the dialog was confirmed, load it and
        refresh the affine and the displayed slice.
        """
        dialog = FileDialog()
        dialog.open()
        if dialog.return_code != OK:
            return
        self.img.load_image(dialog.path)
        self.update_affine()
        self.update_slice_index()

    menu_open_action = Action(name='Open Nifti', action='open_menu')

    file_menubar = MenuBar(Menu(menu_open_action, name='File'))

    # Items
    fig_item = Item('figure', editor=MPLFigureEditor())
    # Radio buttons to pick the slice plane; the "N:" prefixes are part of
    # the values handed to the EnumEditor.
    _slice_opts = {
        'Axial': '1:Axial',
        'Sagittal': '2:Sagittal',
        'Coronal': '3:Coronal',
    }
    slice_opt_item = Item(name='slice_plane',
                          editor=EnumEditor(values=_slice_opts),
                          style='custom')

    affine_item = Item('affine', label='Affine', style='readonly')
    # BUG: The rendering with the 'readonly' style creates an ugly wx
    # "multi-line" control.

    traits_view = View(
        HSplit(
            Group(fig_item),
            Group(affine_item, slice_opt_item, Item('slice_index')),
        ),
        menubar=file_menubar,
        width=0.80,
        height=0.80,
        resizable=True)
コード例 #7
0
class TriangleWave(HasTraits):
    """Interactive triangle-wave / FFT demo.

    Plots one period of a configurable triangle wave, its FFT spectrum in
    dB, and a second curve built from the first `N` FFT components via the
    module-level `fft_combine` helper.
    """

    # Narrowest and widest extent of the triangle wave.  A Range cannot mix
    # constants with trait names, so the two constant bounds are declared as
    # traits of their own.
    low = Float(0.02)
    hi = Float(1.0)

    # Width of the triangle waveform.
    wave_width = Range("low", "hi", 0.5)

    # x coordinate of the apex C of the triangle.
    length_c = Range("low", "wave_width", 0.5)

    # y coordinate of the apex of the triangle.
    height_c = Float(1.0)

    # Number of samples used for the FFT; an Enum so that the user picks
    # the value from a list.
    fftsize = Enum([(2**x) for x in range(6, 12)])

    # Upper limit of the x axis of the FFT spectrum plot.
    fft_graph_up_limit = Range(0, 400, 20)

    # Text showing the FFT result.
    peak_list = Str

    # How many frequencies are used to synthesize the triangle wave.
    N = Range(1, 40, 4)

    # Object holding the plot data.
    plot_data = Instance(AbstractPlotData)

    # Container drawing the waveform plot.
    plot_wave = Instance(Component)

    # Container drawing the FFT spectrum plot.
    plot_fft = Instance(Component)

    # Container holding both plots.
    container = Instance(Component)

    # The user-interface view.  The window size must be given explicitly so
    # that the plot container is initialized properly.
    view = View(HSplit(
        VSplit(
            VGroup(Item("wave_width", editor=scrubber, label="波形宽度"),
                   Item("length_c", editor=scrubber, label="最高点x坐标"),
                   Item("height_c", editor=scrubber, label="最高点y坐标"),
                   Item("fft_graph_up_limit", editor=scrubber, label="频谱图范围"),
                   Item("fftsize", label="FFT点数"), Item("N", label="合成波频率数")),
            Item("peak_list",
                 style="custom",
                 show_label=False,
                 width=100,
                 height=250)),
        VGroup(Item("container",
                    editor=ComponentEditor(size=(600, 300)),
                    show_label=False),
               orientation="vertical")),
                resizable=True,
                width=800,
                height=600,
                title="三角波FFT演示")

    def _create_plot(self, data, name, type="line"):
        """Create a Plot of `data` (names in `plot_data`) titled `name`.

        Shared by the waveform and the spectrum plot, which are set up
        almost identically; a pan tool and a box zoom overlay are attached.
        `type` is the plot type handed to `Plot.plot`.
        """
        p = Plot(self.plot_data)
        p.plot(data, name=name, title=name, type=type)
        p.tools.append(PanTool(p))
        zoom = ZoomTool(component=p, tool_mode="box", always_on=False)
        p.overlays.append(zoom)
        p.title = name
        return p

    def __init__(self):
        # The parent initializer has to run first.
        super(TriangleWave, self).__init__()

        # Create the plot data set.  There is no data yet, so every entry is
        # empty; this only creates the names the plots refer to.
        self.plot_data = ArrayPlotData(x=[], y=[], f=[], p=[], x2=[], y2=[])

        # Vertical container stacking the spectrum and waveform plots.
        self.container = VPlotContainer()

        # Waveform plot with two curves: the original wave (x, y) and the
        # synthesized wave (x2, y2).
        self.plot_wave = self._create_plot(("x", "y"), "Triangle Wave")
        self.plot_wave.plot(("x2", "y2"), color="red")

        # Spectrum plot, using f and p from the data set.
        self.plot_fft = self._create_plot(("f", "p"), "FFT", type="scatter")

        # Put both plots into the vertical container.
        self.container.add(self.plot_wave)
        self.container.add(self.plot_fft)

        # Axis titles.
        self.plot_wave.x_axis.title = "Samples"
        self.plot_fft.x_axis.title = "Frequency pins"
        self.plot_fft.y_axis.title = "(dB)"

        # Switch fftsize to 1024: an Enum defaults to the first value of
        # its list.  This assignment also fires `update_plot`.
        self.fftsize = 1024

    def _fft_graph_up_limit_changed(self):
        """Apply the new upper x limit to the spectrum plot."""
        self.plot_fft.x_axis.mapper.range.high = self.fft_graph_up_limit

    def _N_changed(self):
        """Recompute the synthesized wave for the new component count."""
        self.plot_sin_combine()

    # One handler shared by several traits, registered via on_trait_change.
    @on_trait_change("wave_width, length_c, height_c, fftsize")
    def update_plot(self):
        """Recompute the wave, its spectrum, the peak list and the
        synthesized wave, and write them into `plot_data`.
        """
        # NOTE(review): y_data is published as a module-level global,
        # presumably for code outside this class -- confirm before removing.
        global y_data
        x_data = np.arange(0, 1.0, 1.0 / self.fftsize)
        func = self.triangle_func()
        # frompyfunc yields an object array; convert it to float64.
        # (np.cast, used here previously, no longer exists in NumPy 2.)
        y_data = np.asarray(func(x_data), dtype=np.float64)

        # Spectrum, normalized by the number of samples.
        fft_parameters = np.fft.fft(y_data) / len(y_data)

        # Amplitude of each frequency in dB, clipped to [-120, 120].  The
        # slice bound uses floor division so that it is an integer on
        # Python 3 as well.
        fft_data = np.clip(
            20 * np.log10(np.abs(fft_parameters))[:self.fftsize // 2 + 1],
            -120, 120)

        # Store the results in the data set.
        self.plot_data.set_data("x", np.arange(0, self.fftsize))  # samples
        self.plot_data.set_data("y", y_data)
        self.plot_data.set_data("f", np.arange(0, len(fft_data)))  # bins
        self.plot_data.set_data("p", fft_data)

        # The synthesized wave is plotted over two periods of samples.
        self.plot_data.set_data("x2", np.arange(0, 2 * self.fftsize))

        # Re-apply the upper x limit of the spectrum plot.
        self._fft_graph_up_limit_changed()

        # List the frequencies whose amplitude exceeds -80 dB (at most 20).
        peak_index = (fft_data > -80)
        peak_value = fft_data[peak_index][:20]
        result = []
        for f, v in zip(np.flatnonzero(peak_index), peak_value):
            result.append("%s : %s" % (f, v))
        self.peak_list = "\n".join(result)

        # Keep the FFT result and compute the synthesized wave from it.
        self.fft_parameters = fft_parameters
        self.plot_sin_combine()

    def plot_sin_combine(self):
        """Store the wave returned by `fft_combine` for the current FFT
        result and `N` (called with a final argument of 2) as "y2".
        """
        index, data = fft_combine(self.fft_parameters, self.N, 2)
        self.plot_data.set_data("y2", data)

    def triangle_func(self):
        """Return a ufunc that evaluates the triangle wave described by
        the current `wave_width`, `length_c` and `height_c`.

        The ufunc is built with `np.frompyfunc`, so applied to an array it
        returns an object array that the caller has to convert.
        """
        c = self.wave_width
        c0 = self.length_c
        hc = self.height_c

        def trifunc(x):
            # The wave has period 1, so only the fractional part matters.
            x = x - int(x)
            if x >= c:
                r = 0.0
            elif x < c0:
                r = x / c0 * hc
            else:
                r = (c - x) / (c - c0) * hc
            return r

        return np.frompyfunc(trifunc, 1, 1)
コード例 #8
0
ファイル: tcp.py プロジェクト: chris838/Plot-o-matic
class TCPDriver(IODriver):
    """TCP input driver.

    Accepts a single client on (`ip`, `port`) and returns whatever that
    client sends from `receive`.  While no client is connected, each call
    to `receive` makes one attempt to accept a connection.
    """

    name = Str('TCP Driver')
    view = View(Item(name='port', label='Port'),
                Item(name='show_debug_msgs', label='Show debug messages'),
                Item(name='buffer_size', label='Buffer size / kb'),
                Item(name='timeout', label='Timeout / s'),
                title='TCP input driver')

    # Placeholder socket so that `close` always has something to close;
    # replaced by `listen`.
    _sock = socket.socket()

    port = Range(1024, 65535, 34443)
    buffer_size = Range(
        1, 4096,
        10)  # no reason not to go above 4MB but there should be some limit.
    # Timeout in seconds applied to the listening socket.
    timeout = Float(1.0)
    ip = Str('0.0.0.0')
    show_debug_msgs = Bool(False)

    # True while a client connection is established.
    is_open = Bool(False)

    def open(self):
        """Report opening of the driver; the socket is set up in `listen`."""
        print("Opening (one time)")

    def listen(self):
        """Bind to (`ip`, `port`) and wait for one client to connect.

        On success `_sock` becomes the accepted connection and `is_open`
        is set.  A bind failure is reported and swallowed so that the
        caller can retry.  A `socket.timeout` raised while waiting for a
        client propagates to the caller.  The listening socket is always
        closed before returning or raising.
        """
        print("Listening...")
        self.is_open = False
        server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Exposed as `_sock` while waiting so that `close` can reach it.
        self._sock = server
        try:
            server.bind((self.ip, self.port))
        except socket.error:
            print("Error, address probably already bound. Will try again")
            server.close()
            return
        try:
            server.settimeout(self.timeout)  # seconds
            server.listen(1)
            conn, (addr, _) = server.accept()
        finally:
            # Only one client is served; release the listening socket
            # explicitly instead of relying on garbage collection.
            server.close()
        self._sock = conn
        print("%s  connected!" % (addr,))
        self.is_open = True

    def close(self):
        """Mark the driver as closed and close the current socket."""
        print("Closing")
        self.is_open = False
        self._sock.close()

    def receive(self):
        """Return the next chunk of received data, or None if there is none.

        When no client is connected this makes one `listen` attempt and
        returns None.  An empty chunk means the client disconnected; the
        connection is then closed so that the next call listens again.
        """
        try:
            if not self.is_open:
                self.listen()
                return None
            else:
                try:
                    (data, _) = self._sock.recvfrom(1024 * self.buffer_size)
                    if self.show_debug_msgs:
                        print("TCP driver: packet size %u bytes" % len(data))
                except socket.error:
                    # Also covers socket.timeout, which is a subclass.
                    return None
        except socket.timeout:
            # Reached when `listen` times out waiting for a client.
            print("Socket timed out")
            self.is_open = False
            return None
        if not data:
            # An empty read on a stream socket means the peer closed the
            # connection; without this the driver would return empty data
            # forever and never accept a new client.
            self.close()
        return data

    def rebind_socket(self):
        """Drop the current socket; the next `receive` listens again."""
        self.close()

    @on_trait_change('port')
    def change_port(self):
        self.rebind_socket()

    # The address trait of this class is `ip`; the handler previously
    # watched 'address', which this class does not declare.
    @on_trait_change('ip')
    def change_address(self):
        self.rebind_socket()

    @on_trait_change('timeout')
    def change_timeout(self):
        self.rebind_socket()
コード例 #9
0
ファイル: text.py プロジェクト: sjl421/code-2
class Text(Module):
    """Module that displays a text string through a tvtk TextActor,
    positioned either in normalized viewport (2D) or world (3D)
    coordinates.
    """

    # Version of this class, used for persistence.
    __version__ = 0

    # The tvtk TextActor doing the drawing.
    actor = Instance(tvtk.TextActor, allow_none=False, record=True)

    # Read access to the text property (color etc.) of the actor.
    property = Property(record=True)

    # The displayed text.  Note that this should really be `Str`
    # but wxGTK only returns unicode.
    text = Str('Text', desc='the text to be displayed')

    # Position of the actor along x, y and z.
    x_position = Float(0.0, desc='the x-coordinate of the text')
    y_position = Float(0.0, desc='the y-coordinate of the text')
    z_position = Float(0.0, desc='the z-coordinate of the text')

    # Range-typed shadows of the positions for 2D use.  A plain RangeEditor
    # does not work because it resets the 3D positions to 1 when the dialog
    # is loaded.
    _x_position_2d = Range(0., 1., 0.,
                           enter_set=True,
                           auto_set=False,
                           desc='the x-coordinate of the text')
    _y_position_2d = Range(0., 1., 0.,
                           enter_set=True,
                           auto_set=False,
                           desc='the y-coordinate of the text')

    # Whether the position is interpreted in 3D.
    position_in_3d = Bool(
        False,
        desc='whether the position of the object is given in 2D or in 3D')

    # Width of the text.
    width = Range(0.0, 1.0, 0.4,
                  enter_set=True,
                  auto_set=False,
                  desc='the width of the text as a fraction of the viewport')

    input_info = PipelineInfo(datasets=['any'],
                              attribute_types=['any'],
                              attributes=['any'])

    ########################################
    # The view of this object.

    # The scaling item is named differently depending on the VTK version.
    if VTK_VER > 5.1:
        _text_actor_group = Group(Item(name='visibility'),
                                  Item(name='text_scale_mode'),
                                  Item(name='alignment_point'),
                                  Item(name='minimum_size'),
                                  Item(name='maximum_line_height'),
                                  show_border=True,
                                  label='Text Actor')
    else:
        _text_actor_group = Group(Item(name='visibility'),
                                  Item(name='scaled_text'),
                                  Item(name='alignment_point'),
                                  Item(name='minimum_size'),
                                  Item(name='maximum_line_height'),
                                  show_border=True,
                                  label='Text Actor')

    # Shown only in 2D mode.
    _position_group_2d = Group(
        Item(name='_x_position_2d', label='X position'),
        Item(name='_y_position_2d', label='Y position'),
        visible_when='not position_in_3d')

    # Shown only in 3D mode.
    _position_group_3d = Group(
        Item(name='x_position', label='X', springy=True),
        Item(name='y_position', label='Y', springy=True),
        Item(name='z_position', label='Z', springy=True),
        show_border=True,
        label='Position',
        orientation='horizontal',
        visible_when='position_in_3d')

    view = View(
        Group(
            Group(
                Item(name='text'),
                Item(name='position_in_3d'),
                _position_group_2d,
                _position_group_3d,
                Item(name='width',
                     enabled_when='object.actor.scaled_text'),
            ),
            Group(
                Item(name='actor',
                     style='custom',
                     editor=InstanceEditor(view=View(_text_actor_group))),
                show_labels=False),
            label='TextActor',
            show_labels=False),
        Group(
            Item(name='_property', style='custom', resizable=True),
            label='TextProperty',
            show_labels=False),
    )

    ########################################
    # Private traits.

    # Set while traits are being synchronised so that handlers skip work.
    _updating = Bool(False)
    # The text property of the current actor.
    _property = Instance(tvtk.TextProperty)

    ######################################################################
    # `object` interface
    ######################################################################
    def __set_pure_state__(self, state):
        """Restore `state` with the trait handlers muted, setting 'actor'
        first and leaving '_updating' out.
        """
        self._updating = True
        state_pickler.set_state(
            self, state, first=['actor'], ignore=['_updating'])
        self._updating = False

    ######################################################################
    # `Module` interface
    ######################################################################
    def setup_pipeline(self):
        """Create the text actor and configure the parts of the pipeline
        that do not depend on upstream data.

        The actor is created with scaled text, placed in normalized
        viewport coordinates at the current x/y position, and the 2D
        shadow position traits are synchronised with the real ones.
        """
        actor = tvtk.TextActor(input=str(self.text))
        self.actor = actor
        if VTK_VER > 5.1:
            actor.set(text_scale_mode='prop', width=0.4, height=1.0)
        else:
            actor.set(scaled_text=True, width=0.4, height=1.0)

        lower_left = actor.position_coordinate
        lower_left.set(coordinate_system='normalized_viewport',
                       value=(self.x_position, self.y_position, 0.0))
        upper_right = actor.position2_coordinate
        upper_right.set(coordinate_system='normalized_viewport')

        self._property.opacity = 1.0

        # Push the current trait values onto the new actor.
        self._text_changed(self.text)
        self._width_changed(self.width)
        self._shadow_positions(True)

    def update_pipeline(self):
        """React to an upstream pipeline change by firing our own
        `pipeline_changed` event.
        """
        self.pipeline_changed = True

    def update_data(self):
        """React to an upstream data change by firing our own
        `data_changed` event; the component is expected to do the rest.
        """
        self.data_changed = True

    ######################################################################
    # Non-public interface
    ######################################################################
    def _text_changed(self, value):
        """Show the new text on the actor and re-render."""
        actor = self.actor
        if actor is None or self._updating:
            return
        actor.input = str(value)
        self.render()

    def _shadow_positions(self, value):
        """Synchronise (`value` true) or stop synchronising (`value`
        false) the 2D shadow traits with the x/y position traits.

        When the synchronisation is removed, the shadows are given the
        current positions one last time.
        """
        unsync = not value
        pairs = (('x_position', '_x_position_2d'),
                 ('y_position', '_y_position_2d'))
        for public_name, shadow_name in pairs:
            self.sync_trait(public_name, self, shadow_name, remove=unsync)
        if unsync:
            self._x_position_2d = self.x_position
            self._y_position_2d = self.y_position

    def _position_in_3d_changed(self, value):
        """Switch the actor between world and normalized viewport
        coordinates.

        When going back to 2D the x/y positions are clamped to [0, 1]
        without firing change notifications.
        """
        actor = self.actor
        if value:
            actor.position_coordinate.coordinate_system = 'world'
            actor.position2_coordinate.coordinate_system = 'world'
        else:
            actor.position2_coordinate.coordinate_system = \
                'normalized_viewport'
            actor.position_coordinate.coordinate_system = \
                'normalized_viewport'
            clamped_x = min(max(self.x_position, 0), 1)
            clamped_y = min(max(self.y_position, 0), 1)
            self.set(x_position=clamped_x,
                     y_position=clamped_y,
                     trait_change_notify=False)
        self._shadow_positions(not value)
        self._change_position()
        self.actor._width_changed(self.width, self.width)
        self.pipeline_changed = True

    def _change_position(self):
        """Callback for x_position, y_position and z_position: move the
        actor and re-render.
        """
        actor = self.actor
        if actor is None or self._updating:
            return
        if self.position_in_3d:
            actor.position_coordinate.value = (
                self.x_position, self.y_position, self.z_position)
        else:
            actor.position = self.x_position, self.y_position
        self.render()

    _x_position_changed = _change_position

    _y_position_changed = _change_position

    _z_position_changed = _change_position

    def _width_changed(self, value):
        """Apply the new width to the actor and re-render."""
        actor = self.actor
        if actor is None or self._updating:
            return
        actor.width = value
        self.render()

    def _update_traits(self):
        """Copy text, x/y position and width back from the actor, with
        the trait handlers muted for the duration.
        """
        self._updating = True
        try:
            source = self.actor
            self.text = source.input
            self.x_position, self.y_position = source.position
            self.width = source.width
        finally:
            self._updating = False

    def _get_property(self):
        """Getter of the `property` trait."""
        return self._property

    def _actor_changed(self, old, new):
        """Move the render/update listeners from the old actor (and its
        text property) to the new one and make the new actor our only
        actor.
        """
        if old is not None:
            old.on_trait_change(self.render, remove=True)
            self._property.on_trait_change(self.render, remove=True)
            old.on_trait_change(self._update_traits, remove=True)

        self._property = new.text_property
        new.on_trait_change(self.render)
        self._property.on_trait_change(self.render)
        new.on_trait_change(self._update_traits)

        self.actors = [new]
        self.render()

    def _foreground_changed_for_scene(self, old, new):
        """Use the scene foreground as the text color and re-render."""
        self.property.color = new
        self.render()

    def _scene_changed(self, old, new):
        """Let the parent handle the scene change, then adopt the
        foreground color of the new scene.
        """
        super(Text, self)._scene_changed(old, new)
        self._foreground_changed_for_scene(None, new.foreground)
コード例 #10
0
class PyDSP(HasTraits):
    """Traits model of a digital filter design tool.

    The independent traits describe the sample rate, the input signal and
    the filter.  The cached properties derive the time and frequency axes
    (`t`, `f`), the input signal (`x`), the filter coefficients (`a`, with
    `b` stored as a side effect), the frequency and impulse responses
    (`H`, `H_mag`, `H_phase`, `h`) and the filtered output (`y`).
    """

    sample_rate_value = Enum("1", "2", "5", "10", "20", "50", "100", "200",
                             "500")
    sample_rate_units = Enum("Hz", "kHz", "MHz", "GHz")
    sample_rate = Property(
        Float, depends_on=['sample_rate_value', 'sample_rate_units'])
    t = Property(Array, depends_on=['sample_rate'])
    input_type = Enum("sine", "square", "triangle", "chirp", "noise",
                      "impulse")
    input_freq = Range(low=1, high=1000, value=1)
    input_freq_units = Enum("Hz", "kHz", "MHz", "GHz")
    input_span = Range(low=1, high=1000, value=500)
    input_span_units = Enum("Hz", "kHz", "MHz", "GHz")
    x = Property(Array,
                 depends_on=[
                     'sample_rate', 'input_type', 'input_freq',
                     'input_freq_units', 'input_span', 'input_span_units', 't'
                 ])
    Ntaps = Range(low=1, high=Ntaps_max, value=3)
    filter_type = Enum("FIR", "IIR", "custom")
    filter_cutoff = Range(low=1, high=1000, value=500)
    filter_cutoff_units = Enum("Hz", "kHz", "MHz", "GHz")
    a = Property(Array,
                 depends_on=[
                     'Ntaps', 'filter_type', 'filter_cutoff',
                     'filter_cutoff_units', 'sample_rate', 'usr_a', 'usr_b'
                 ])
    b = array([1])  # Will be set by 'a' handler, upon change in dependencies.
    a_str = String("1", font='Courier')
    b_str = String("1")
    usr_a = Array(float, (1, Ntaps_max),
                  [[1] + [0 for i in range(Ntaps_max - 1)]])
    usr_b = Array(float, (1, Ntaps_max),
                  [[1] + [0 for i in range(Ntaps_max - 1)]])
    H = Property(Array, depends_on=['a'])
    H_mag = Property(Array, depends_on=['H'])
    H_phase = Property(Array, depends_on=['H'])
    h = Property(Array, depends_on=['H'])
    f = Property(Array, depends_on=['sample_rate'])
    y = Property(Array, depends_on=['a', 'x'])
    plot_type = Enum("line", "scatter")
    plot_type2 = Enum("line", "scatter")
    ident = String(
        'PyDSP v0.4 - a digital filter design tool, written in Python\n\n \
    David Banas <*****@*****.**>\n \
    October 15, 2014\n\n \
    Copyright (c) 2011 David Banas; All rights reserved World wide.')

    # Set the default values of the independent variables.
    def _sample_rate_units_default(self):
        """Default units of the sample rate."""
        return "MHz"

    def _input_freq_units_default(self):
        """Default units of the input frequency."""
        return "kHz"

    def _filter_cutoff_default(self):
        """Default filter cutoff value (in `filter_cutoff_units`)."""
        return 100

    def _filter_cutoff_units_default(self):
        """Default units of the filter cutoff."""
        return "kHz"

    def _scale_by_units(self, val, units):
        """Return `val` scaled from `units` to Hz.

        "Hz" returns `val` unchanged, "kHz" and "MHz" multiply by 1e3 and
        1e6, and any other unit is treated as GHz (1e9).
        """
        if units == "Hz":
            return val
        return val * {"kHz": 1e3, "MHz": 1e6}.get(units, 1e9)

    # Define dependent variables.
    @cached_property
    def _get_sample_rate(self):
        """Sample rate in Hz, from the selected value and units."""
        return self._scale_by_units(float(self.sample_rate_value),
                                    self.sample_rate_units)

    @cached_property
    def _get_t(self):
        """Time axis: `Npts` sample instants at the sample rate."""
        t = arange(Npts) / self.sample_rate
        return t

    @cached_property
    def _get_f(self):
        """Frequency axis: `Npts` points with spacing
        sample_rate / (2 * Npts).
        """
        f = arange(Npts) * (self.sample_rate / (2 * Npts))
        return f

    @cached_property
    def _get_x(self):
        """Input signal of the selected type, `Npts` samples long."""
        # Combine value/units from the GUI to form the actual frequency and
        # the actual span (the span is used by the chirp only).
        sig_freq = self._scale_by_units(self.input_freq,
                                        self.input_freq_units)
        sig_span = self._scale_by_units(self.input_span,
                                        self.input_span_units)

        # Generate the signal.
        square_wave = sign(sin(2 * pi * sig_freq * self.t))
        sig_type = self.input_type
        if (sig_type == "sine"):
            return sin(2 * pi * sig_freq * self.t)
        elif (sig_type == "square"):
            return square_wave
        elif (sig_type == "triangle"):
            # Running sum of the square wave, then rescaled to a
            # peak-to-peak amplitude of 2 and shifted so the maximum is 1.
            triangle_wave = square_wave.cumsum()
            triangle_wave = triangle_wave * 2 / (max(triangle_wave) -
                                                 min(triangle_wave))
            return triangle_wave + (1 - max(triangle_wave))
        elif (sig_type == "chirp"):
            freqs = array([
                sig_freq - sig_span / 2 + i * sig_span / Npts
                for i in range(Npts)
            ])
            return sin(self.t * (2 * pi * freqs))
        elif (sig_type == "noise"):
            return array([random() * 2 - 1 for i in range(Npts)])
        else:  # "impulse"
            return array([0, 1] + [0 for i in range(Npts - 2)])

    @cached_property
    def _get_a(self):
        """Denominator coefficients of the filter.

        Side effects: stores the matching numerator coefficients in
        `self.b` and, unless the filter type is "custom", updates `a_str`
        and `b_str` with the formatted coefficients.
        """
        # Combine value/units from the GUI to form the cutoff frequency.
        fc = self._scale_by_units(self.filter_cutoff,
                                  self.filter_cutoff_units)

        # Generate the filter coefficients.
        if (self.filter_type == "FIR"):
            w = fc / (self.sample_rate / 2)
            b = firwin(self.Ntaps, w)
            a = [1]
        elif (self.filter_type == "IIR"):
            (b, a) = iirfilter(self.Ntaps - 1,
                               fc / (self.sample_rate / 2),
                               btype='lowpass')
        else:
            a = self.usr_a[0]
            b = self.usr_b[0]
        if (self.filter_type != "custom"):
            self.a_str = "".join("%+06.3f  " % item for item in a)
            self.b_str = "".join("%+06.3f  " % item for item in b)
        self.b = b
        return a

    @cached_property
    def _get_h(self):
        """Response of the filter to a unit impulse at sample 1."""
        x = array([0, 1] + [0 for i in range(Npts - 2)])
        a = self.a
        b = self.b
        return lfilter(b, a, x)

    @cached_property
    def _get_H(self):
        """Frequency response of the filter at `Npts` points."""
        # `a` has to be read before `b`: computing `a` is what refreshes
        # `self.b`, so reading `b` first could use stale coefficients.
        a = self.a
        b = self.b
        (w, H) = freqz(b, a, worN=Npts)
        return H

    @cached_property
    def _get_H_mag(self):
        """Magnitude of `H` with the d.c. component zeroed."""
        # abs() returns a new array, so the cached `H` (which `H_phase`
        # reads as well) is left untouched.
        H_mag = abs(self.H)
        H_mag[0] = 0  # Kill the d.c. component, or it will swamp everything.
        return H_mag

    @cached_property
    def _get_H_phase(self):
        """Phase of `H` in degrees."""
        H = self.H
        return angle(H, deg=True)

    @cached_property
    def _get_y(self):
        """The input signal `x` filtered with the current coefficients."""
        x = self.x
        a = self.a
        b = self.b
        return lfilter(b, a, x)
コード例 #11
0
ファイル: Flopper.py プロジェクト: Faridelnik/Pi3Diamond
class Flopper(SingletonHasTraits, GetSetItemsMixin):

    BufferLength = Range(low=1, high=512, value=512)
    Threshold = Range(low=0, high=10000, value=100)
    Pulse = Bool(True)
    Stream = Bool(True)
    Chunks = Range(low=1, high=10000, value=10)
    PulseLength = Int(10)

    Trace = Array()
    trace_binned = Array()

    file_path_timetrace = Str(r'D:\data\Timetrace')
    file_name_timetrace = Str('enter filename')
    save_timetrace = Button()

    binning = Range(low=1, high=100, value=1, mode='text')
    refresh_hist = Button()

    file_path_hist = Str(r'D:\data\Histograms')
    file_name_hist = Str('enter filename')
    save_hist = Button()

    #for readout protocol
    bits_readout = Int(
        6, label='# of substeps-bits for readout protocol (steps = 2**x)')

    # FileNameTrace = Str(r'D:\data\Trace' + str(datetime.date.today()) + '_01.dat')
    # FileNameHist = Str(r'D:\data\Histogram' + str(datetime.date.today()) + '_01.dat')
    view = View(VGroup(
        HGroup(Item('Pulse'), Item('Loop'), Item('Run'), Item('Threshold'),
               Item('Chunks')),
        HGroup(
            VGroup(
                Item('TracePlot',
                     editor=ComponentEditor(),
                     show_label=False,
                     width=500,
                     height=300,
                     resizable=True),
                HGroup(
                    Item('file_name_timetrace',
                         label='Filename of timetrace:'),
                    Item('save_timetrace',
                         label='Save Timetrace',
                         show_label=False))),
            VGroup(
                Item('HistPlot',
                     editor=ComponentEditor(),
                     show_label=False,
                     width=500,
                     height=300,
                     resizable=True),
                HGroup(
                    Item('binning', label='             # of bins'),
                    Item('refresh_hist',
                         label='Refresh histogram',
                         show_label=False)),
                HGroup(
                    Item('file_name_hist', label='Filename of histogram:'),
                    Item('save_hist', label='Save Histogram',
                         show_label=False)))),
    ),
                title='Flopper',
                width=700,
                height=500,
                buttons=['OK'],
                resizable=True)

    def _Trace_default(self):
        """Return the initial trace: ``BufferLength * Chunks`` zeros."""
        n_samples = self.BufferLength * self.Chunks
        return numpy.zeros((n_samples, ))

    def _BufferLength_default(self):
        """Return the default buffer length (512).

        The value is also written to wire-in 0x00 on ``xem`` and a matching
        two-bytes-per-sample scratch string is stored in ``_Buf``.
        """
        default_length = 512
        xem.SetWireInValue(0x00, default_length)
        xem.UpdateWireIns()
        self._Buf = '\x00' * 2 * default_length
        return default_length

    def _BufferLength_changed(self):
        """Write the new buffer length to wire-in 0x00 and resize ``_Buf``."""
        length = self.BufferLength
        xem.SetWireInValue(0x00, length)
        xem.UpdateWireIns()
        self._Buf = '\x00' * 2 * length

    def _Threshold_changed(self):
        """Write the current ``Threshold`` to wire-in 0x02 on ``xem``."""
        threshold = self.Threshold
        xem.SetWireInValue(0x02, threshold)
        xem.UpdateWireIns()

    def _PulseLength_changed(self):
        """Write the current ``PulseLength`` to wire-in 0x04 on ``xem``."""
        pulse_length = self.PulseLength
        xem.SetWireInValue(0x04, pulse_length)
        xem.UpdateWireIns()

    @on_trait_change('Pulse,Stream')
    def UpdateFlags(self):
        """Pack ``Pulse`` (bit 1) and ``Stream`` (bit 0) into wire-in 0x03."""
        flags = (self.Pulse << 1) | self.Stream
        xem.SetWireInValue(0x03, flags)
        xem.UpdateWireIns()

    def Reset(self):
        """Turn streaming off, then fire trigger-in 0x40 (bit 0) on ``xem``.

        Setting ``Stream`` goes through the trait, so ``UpdateFlags`` is
        notified before the trigger is activated.
        """
        self.Stream = False
        xem.ActivateTriggerIn(0x40, 0)

    def ReadPipe(self):
        """Read one buffer from block pipe 0xA0 and return it as an array.

        ``BufferLength`` values are unpacked from the buffer with the struct
        format ``H`` (native unsigned short), two bytes per value.
        """
        n_words = self.BufferLength
        n_bytes = 2 * n_words
        raw = '\x00' * n_bytes
        xem.ReadFromBlockPipeOut(0xA0, n_bytes, raw)
        words = struct.unpack('%iH' % n_words, raw)
        return numpy.array(words)

    def GetTrace_old(self):
        """Acquire ``Chunks`` buffers, refreshing trace and histogram per chunk.

        Each buffer is written into a scratch copy which is then assigned to
        ``Trace``; the histogram covers the samples read so far.  The loop
        ends early when ``abort`` is set.
        """
        chunk_len = self.BufferLength
        self.Trace = numpy.zeros((chunk_len * self.Chunks, ))
        scratch = self.Trace.copy()
        self.Reset()
        self.Stream = True
        for chunk in range(self.Chunks):
            if self.abort.isSet():
                break
            start = chunk * chunk_len
            stop = start + chunk_len
            scratch[start:stop] = self.ReadPipe()
            self.Trace = scratch
            edges = numpy.arange(self.Trace.min(), self.Trace.max(), 1)
            counts, bins = numpy.histogram(self.Trace[:stop], bins=edges)
            self.HistogramN, self.HistogramBins = counts, bins
            self._Trace_changed()

    HistogramN = Array(value=numpy.array((0, 1)))
    HistogramBins = Array(value=numpy.array((0, 0)))

    def GetTrace(self):
        M = self.BufferLength
        self.Trace = numpy.zeros((self.BufferLength * self.Chunks, ))
        #temp = self.Trace.copy()
        self.Reset()
        self.Stream = True
        for i in range(self.Chunks):
            if self.abort.isSet():
                break
            self.Trace[i * M:(i + 1) * M] = self.ReadPipe()

            xem.UpdateWireOuts()
            ep21 = xem.GetWireOutValue(0x21)
            ep20 = xem.GetWireOutValue(0x20)
            print 'Output values', ep20, ep21
            print 'Trace[0:9] values', self.Trace[0:9]

            #self.Trace = temp
            #if i % 10 == 0:
            #    self.HistogramN, self.HistogramBins = numpy.histogram(self.Trace[:(i+1)*M], bins=numpy.arange(self.Trace.min(),self.Trace.max(),1))
            #    self._Trace_changed()
        self.HistogramN, self.HistogramBins = numpy.histogram(
            self.Trace[:(i + 1) * M],
            bins=numpy.arange(self.Trace.min(), self.Trace.max(), 1))
        self._Trace_changed()

    def GetTrace2(self, weightfnct):
        '''Get trace with weighting function.

        Every buffer read from the pipe is cut into consecutive groups of
        ``len(weightfnct)`` samples.  Each group yields two weighted sums:
        one with ``weightfnct`` (stored in ``Trace``) and one with the
        reversed weights (stored in ``Trace_init``).

        weightfnct -- sequence of exactly 64 weights.

        Raises RuntimeError if ``weightfnct`` does not have 64 entries.
        '''
        if len(weightfnct) != 64:
            raise RuntimeError('Length of weightfunction != 64')
        M = self.BufferLength
        N = len(weightfnct)
        weightfnct_init = weightfnct[::-1]
        # Floor division: number of complete N-sample groups per buffer
        # (plain "/" would give a float on Python 3 and break range()).
        points = M // N
        self.Trace = numpy.zeros((self.BufferLength * self.Chunks))
        self.Reset()
        self.Stream = True
        Trace_init = []
        Trace_read = []
        for i in range(self.Chunks):
            if self.abort.isSet():
                break
            # N buffers are read per chunk.
            for _ in range(N):
                temp = self.ReadPipe()
                for j in range(points):
                    group = temp[j * N:(j + 1) * N]
                    Trace_init.append(
                        sum(s * w for s, w in zip(group, weightfnct_init)))
                    Trace_read.append(
                        sum(s * w for s, w in zip(group, weightfnct)))
        self.Trace = numpy.array(Trace_read)
        self.Trace_init = numpy.array(Trace_init)
        # Nothing was read (e.g. aborted immediately): min()/max() of an
        # empty array would raise, so keep the previous histogram.
        if self.Trace.size:
            self.HistogramN, self.HistogramBins = numpy.histogram(
                self.Trace,
                bins=numpy.arange(self.Trace.min(), self.Trace.max(), 1))
        self._Trace_changed()

    HistogramN = Array(value=numpy.array((0, 1)))
    HistogramBins = Array(value=numpy.array((0, 0)))

    # continuous data acquisition in thread

    Loop = Bool(False)

    Run = Property(trait=Bool)

    abort = threading.Event()
    abort.clear()

    _StopTimeout = 10.

    def _get_Run(self):
        """Return True while an acquisition thread exists and is alive."""
        thread = getattr(self, 'Thread', None)
        # is_alive() is available on Python 2.6+ and 3.x; the camelCase
        # isAlive() alias was removed in Python 3.9.
        return thread is not None and thread.is_alive()

    def _set_Run(self, value):
        """Start the measurement when ``value`` equals True, else stop it."""
        action = self.Start if value == True else self.Stop
        action()

    def Start(self):
        """Start Measurement in a thread."""
        # Make sure a previous worker is stopped before launching a new one.
        self.Stop()
        worker = threading.Thread(target=self.run)
        self.Thread = worker
        worker.start()

    def Stop(self):
        """Ask the acquisition thread to stop and drop the reference to it.

        Sets ``abort``, waits up to ``_StopTimeout`` seconds for the thread
        to finish, clears ``abort`` again and deletes ``self.Thread``.  Does
        nothing when no thread exists.
        """
        if not hasattr(self, 'Thread'):
            return
        self.abort.set()
        self.Thread.join(self._StopTimeout)
        self.abort.clear()
        # is_alive() replaces the isAlive() alias removed in Python 3.9.
        if self.Thread.is_alive():
            # No ``ready`` attribute is defined in this class body; only
            # signal it when one exists so a timed-out join cannot raise
            # AttributeError and skip the cleanup below.
            ready = getattr(self, 'ready', None)
            if ready is not None:
                ready.set()
        del self.Thread

    def run(self):
        """Thread body: acquire once, then repeat while ``Loop`` is set.

        Looping ends as soon as ``abort`` is set.  Afterwards the exported
        trace is appended to ``Traces`` and a second export is stored in
        ``trace``.
        """
        self.Traces = []
        self.GetTrace()
        while self.Loop and not self.abort.isSet():
            self.GetTrace()
        self.Traces.append(self.Export())
        self.trace = self.Export()

    # trace and histogram plots

    TracePlot = Instance(Component)
    HistPlot = Instance(Component)

    def _TracePlot_default(self):
        """Build the default trace plot component."""
        component = self._create_TracePlot_component()
        return component

    def _HistPlot_default(self):
        """Build the default histogram plot component."""
        component = self._create_HistPlot_component()
        return component

    def _create_TracePlot_component(self):
        """Create the trace plot and remember its line as ``TraceLine``.

        Returns a ``DataView`` holding a blue ``LinePlot`` of ``Trace``
        against its sample index, with pan and zoom attached.
        """
        view = DataView(border_visible=True)
        values = ArrayDataSource(self.Trace)
        indices = ArrayDataSource(numpy.arange(len(self.Trace)))
        trace_line = LinePlot(
            value=values,
            index=indices,
            color='blue',
            index_mapper=LinearMapper(range=view.index_range),
            value_mapper=LinearMapper(range=view.value_range))

        # Let the view's ranges follow the line's data sources.
        view.index_range.sources.append(trace_line.index)
        view.value_range.sources.append(trace_line.value)
        view.add(trace_line)

        view.index_axis.title = 'index'
        view.value_axis.title = 'Fluorescence [ counts / s ]'
        view.tools.append(PanTool(view))
        view.overlays.append(ZoomTool(view))

        self.TraceLine = trace_line
        return view

    def _create_HistPlot_component(self):
        """Create the histogram plot and remember its line as ``HistLine``.

        Returns a ``DataView`` holding a blue ``LinePlot`` of ``HistogramN``
        against ``HistogramBins``, with pan and zoom attached.
        """
        view = DataView(border_visible=True)
        bin_source = ArrayDataSource(self.HistogramBins)
        count_source = ArrayDataSource(self.HistogramN)
        hist_line = LinePlot(
            index=bin_source,
            value=count_source,
            color='blue',
            index_mapper=LinearMapper(range=view.index_range),
            value_mapper=LinearMapper(range=view.value_range))

        # Let the view's ranges follow the line's data sources.
        view.index_range.sources.append(hist_line.index)
        view.value_range.sources.append(hist_line.value)
        view.add(hist_line)

        view.index_axis.title = 'Fluorescence counts'
        view.value_axis.title = 'number of occurences'
        view.tools.append(PanTool(view))
        view.overlays.append(ZoomTool(view))

        self.HistLine = hist_line
        return view

    def _save_timetrace_changed(self):
        """Write the last exported trace to ``<path>\\<name>_Trace.asc``.

        The file starts with a ``[Data]`` line followed by one integer per
        line, taken from ``self.trace.Normal()``.  ``self.trace`` is assigned
        in ``run``.
        """
        path = (self.file_path_timetrace + '\\' + self.file_name_timetrace +
                '_Trace' + r'.asc')
        # The context manager closes the file even if writing fails.
        with open(path, 'w') as fil:
            fil.write('[Data]\n')
            for x in self.trace.Normal():
                fil.write('%i \n' % x)

    def _save_hist_changed(self):
        """Write the current histogram to ``<path>\\<name>_Hist.asc``.

        The file holds a ``binning = N`` line, a ``[Data]`` line and then one
        ``<bin>   <count>`` line for every bin edge except the last.
        """
        path = (self.file_path_hist + '\\' + self.file_name_hist + '_Hist' +
                r'.asc')
        # The context manager closes the file even if writing fails.
        with open(path, 'w') as fil:
            fil.write('binning = %i\n' % self.binning)
            fil.write('[Data]\n')
            for x in range(len(self.HistogramBins) - 1):
                fil.write('%i   %i\n' % (self.HistogramBins[x],
                                         self.HistogramN[x]))

    def _refresh_hist_changed(self):
        """Re-bin ``Trace`` by ``binning`` and rebuild the histogram.

        Every complete group of ``binning`` consecutive samples is summed
        into one entry of ``trace_binned``; a trailing partial group is
        dropped.  The histogram of the binned trace is then stored and the
        plot handlers are invoked.
        """
        binning = self.binning
        # Size the result from the actual trace length.  The previous loop
        # stopped one group early, leaving the final entry at zero and
        # adding a spurious zero sample to the histogram.
        n_bins = len(self.Trace) // binning
        if n_bins == 0:
            # Fewer samples than one bin: nothing to histogram.
            return
        grouped = numpy.asarray(self.Trace[:n_bins * binning], dtype=float)
        self.trace_binned = grouped.reshape(n_bins, binning).sum(axis=1)
        self.HistogramN, self.HistogramBins = numpy.histogram(
            self.trace_binned,
            bins=numpy.arange(self.trace_binned.min(), self.trace_binned.max(),
                              1))
        self._Trace_changed()
        self._HistogramN_changed()
        self._HistogramBins_changed()

    def __init__(self):
        """Initialise the traits and push the initial settings to ``xem``.

        The calls below run in a fixed order: pulse length and threshold are
        written first, then buffer length and flags are set through their
        traits, and finally the device is reset.
        """
        super(Flopper, self).__init__()

        # The plots are created lazily by _TracePlot_default and
        # _HistPlot_default, so no plot setup happens here.

        # Push the current pulse length and threshold by calling the change
        # handlers directly.
        self._PulseLength_changed()
        self._Threshold_changed()
        self.BufferLength = 512
        self.Pulse = False
        self.Stream = True
        # Scratch string with two bytes per sample.
        self._Buf = '\x00' * 2 * self.BufferLength
        self.Reset()

    def _Trace_changed(self):
        """Push the trace (at most the first 40000 samples) to the plot."""
        limit = 40000  # Program can't handle very long traces
        shown = self.Trace[0:limit] if len(self.Trace) > limit else self.Trace
        self.TraceLine.value.set_data(shown)
        self.TraceLine.index.set_data(numpy.arange(len(shown)))

    def _HistogramN_changed(self):
        """Push the histogram counts to the histogram line's value source."""
        counts = self.HistogramN
        self.HistLine.value.set_data(counts)

    def _HistogramBins_changed(self):
        """Push the histogram bin edges to the histogram line's index source."""
        edges = self.HistogramBins
        self.HistLine.index.set_data(edges)

    def Export(self):
        """Return a ``fitting.Trace`` built from the current acquisition.

        The object is constructed from ``Trace``, ``Threshold`` and
        ``Pulse``; a new one is created on every call.
        """
        # The statements that used to follow this return were unreachable
        # and have been removed.
        return fitting.Trace(self.Trace, self.Threshold, self.Pulse)
#    def __del__(self):
#        del self.xem

    def __getstate__(self):
        """Returns current state of a selection of traits.
        Overwritten HasTraits.

        The ``Thread`` and ``abort`` entries are removed from the state
        returned by ``SingletonHasTraits.__getstate__``.
        """
        state = SingletonHasTraits.__getstate__(self)
        for key in ('Thread', 'abort'):
            # pop() with a default replaces dict.has_key(), which does not
            # exist on Python 3.
            state.pop(key, None)
        return state
コード例 #12
0
ファイル: geometry_viewer.py プロジェクト: p-chambers/pyavl
class GeometryViewer(HasTraits):
    """Mayavi view of a ``Geometry``: its surfaces, sections and bodies.

    ``surfaces`` and ``bodies`` wrap the geometry's items in viewer objects;
    ``update_plot`` redraws the scene whenever the geometry changes or asks
    for a refresh.
    """
    meridional = Range(1, 30, 6)
    transverse = Range(0, 30, 11)
    scene = Instance(MlabSceneModel, ())
    geometry = Instance(Geometry)
    bodies = Property(List(Instance(BodyViewer)), depends_on='geometry.refresh_geometry_view,geometry.bodies[]')
    surfaces = Property(List(Instance(SurfaceViewer)), depends_on='geometry.refresh_geometry_view,geometry.surfaces[]')

    @cached_property
    def _get_surfaces(self):
        """Return one ``SurfaceViewer`` per surface (empty without geometry)."""
        if self.geometry is None:
            return []
        return [SurfaceViewer(surface=surface)
                for surface in self.geometry.surfaces]

    @cached_property
    def _get_bodies(self):
        """Return one ``BodyViewer`` per body (empty without geometry)."""
        if self.geometry is None:
            return []
        return [BodyViewer(body=body) for body in self.geometry.bodies]

    def section_points(self, sections, yduplicate):
        """Return the chord end points of ``sections`` as an (n, 3) array.

        Each section contributes two rows: its leading edge, and that point
        moved by the chord rotated by ``section.angle`` degrees in the x-z
        plane.  When ``yduplicate`` is finite, the mirrored points
        (``y -> yduplicate - y``) are appended after all original points.
        """
        ret = numpy.empty((2 * len(sections), 3))
        # isfinite() instead of "is not numpy.nan": an identity test only
        # recognises the one numpy.nan object, not NaN values in general.
        mirror = numpy.isfinite(yduplicate)
        ret2 = []
        for sno, section in enumerate(sections):
            pno = sno * 2
            pt = ret[pno:pno + 2, :]  # view into ret, filled in place
            pt[0, :] = section.leading_edge
            pt[1, :] = section.leading_edge
            pt[1, 0] += section.chord * numpy.cos(section.angle * numpy.pi / 180)
            pt[1, 2] -= section.chord * numpy.sin(section.angle * numpy.pi / 180)
            if mirror:
                pt2 = numpy.copy(pt)
                pt2[:, 1] = yduplicate - pt2[:, 1]
                ret2.append(pt2[0, :])
                ret2.append(pt2[1, :])
        if ret2:
            ret = numpy.concatenate((ret, numpy.array(ret2)))
        return ret

    def section_points_old(self, sections, yduplicate):
        """Older variant of ``section_points``.

        Mirrored points directly follow the section they belong to instead
        of being appended at the end.
        """
        mirror = numpy.isfinite(yduplicate)
        ret = []
        for section in sections:
            pt = numpy.empty((2, 3))
            pt[0, :] = section.leading_edge
            pt[1, :] = section.leading_edge
            pt[1, 0] += section.chord * numpy.cos(section.angle * numpy.pi / 180)
            pt[1, 2] -= section.chord * numpy.sin(section.angle * numpy.pi / 180)
            ret.append(pt)
            if mirror:
                pt2 = numpy.copy(pt)
                pt2[:, 1] = yduplicate - pt2[:, 1]
                ret.append(pt2)
        ret = numpy.concatenate(ret)
        return ret

    def __init__(self, **args):
        # Do not forget to call the parent's __init__
        HasTraits.__init__(self, **args)
        self.update_plot()

    @on_trait_change('geometry,geometry.refresh_geometry_view')
    def update_plot(self):
        """Clear the scene and redraw every surface and body."""
        self.plots = []
        self.scene.mlab.clf()

        for surface in self.surfaces:
            # Section outlines (tube_radius was None in every branch).
            for section_pt in surface.sectiondata:
                self.plots.append(self.scene.mlab.plot3d(section_pt[:, 0], section_pt[:, 1], section_pt[:, 2], tube_radius=None))
            self.plots.append(self.scene.mlab.mesh(surface.surfacedata[:, :, 0], surface.surfacedata[:, :, 1], surface.surfacedata[:, :, 2]))
        for body in self.bodies:
            width = (body.data_props[2] - body.data_props[1]) / body.num_pts * 0.15
            for data in body.bodydata:
                # Short segment along x centred on the data point, drawn as
                # a tube of radius data[3].
                c = numpy.empty((2, 3))
                c[:] = data[:3]
                c[0, 0] -= width / 2
                c[1, 0] += width / 2
                self.plots.append(self.scene.mlab.plot3d(c[:, 0], c[:, 1], c[:, 2], tube_radius=data[3], tube_sides=24))
                if numpy.isfinite(body.body.yduplicate):
                    c[:, 1] = body.body.yduplicate - c[:, 1]
                    self.plots.append(self.scene.mlab.plot3d(c[:, 0], c[:, 1], c[:, 2], tube_radius=data[3], tube_sides=24))

    def update_plot_old(self):
        """Older redraw: plots section chords only, as tubes of radius 0.1."""
        self.plots = []
        self.scene.mlab.clf()
        for surface in self.geometry.surfaces:
            yduplicate = surface.yduplicate
            section_pts = self.section_points(surface.sections, yduplicate)
            for i in range(0, section_pts.shape[0], 2):
                self.plots.append(self.scene.mlab.plot3d(section_pts[i:i + 2, 0], section_pts[i:i + 2, 1], section_pts[i:i + 2, 2], tube_radius=0.1))
        # Single-argument call form: same output text on Python 2 and 3.
        print('numplots =  %s' % len(self.plots))

    # the layout of the dialog created
    view = View(Item('scene', editor=SceneEditor(scene_class=MayaviScene),
                    height=250, width=300, show_label=False),
                #Item('geometry', style='custom'),
                resizable=True
                )
コード例 #13
0
        class MyViewer(HasTraits):
            """Interactive plot of ``x`` against ``t``.

            ``x`` is recomputed through ``cb`` whenever one of ``Y``, ``Z``,
            ``k_f`` or ``K_eq`` changes.
            """

            plot = Instance(Plot)
            plotdata = Instance(ArrayPlotData, args=())

            t = Array
            x = Property(Array, depends_on=['Y', 'Z', 'k_f', 'K_eq'])
            Y = Range(low=11e-3, high=0.1, value=20e-3)
            Z = Range(low=.5e-3, high=10e-3, value=2e-3)
            k_f = Range(low=15.0, high=150.0, value=100.0)
            K_eq = Range(low=0.1, high=100.0, value=10.0)
            plot_type = Enum("line", "scatter")

            traits_view = View(
                Item('plot', editor=ComponentEditor(), show_label=False),
                Item(name='Y'),
                Item(name='Z'),
                Item(name='k_f'),
                Item(name='K_eq'),
                Item(name='plot_type'),
                width=600,
                height=800,
                resizable=True,
                title="Stopped flow model",
            )

            def _plot_default(self):
                """Build the chaco plot of ``x`` over ``t`` with pan and zoom."""
                figure = Plot(self.plotdata)
                for series in ('t', 'x'):
                    self.plotdata.set_data(series, getattr(self, series))
                figure.plot(('t', 'x'),
                            color='red',
                            type_trait="plot_type",
                            name='x')
                figure.legend.visible = True
                figure.title = "Stopped flow"
                figure.x_axis.title = 't'

                # Add pan and zoom to the plot
                figure.tools.append(PanTool(figure, constrain_key="shift"))
                figure.overlays.append(ZoomTool(figure))
                return figure

            def _get_x(self):
                """Evaluate ``cb`` for the current parameter values."""
                return cb(self.t, self.Y, self.Z, self.k_f, self.K_eq)

            def _x_changed(self):
                """Forward the recomputed ``x`` to the plot data."""
                self.plotdata.set_data('x', self.x)

            def _t_default(self):
                """Use the time axis prepared in ``__init__``."""
                return self.t_default

            def __init__(self, y0, params, t0, tend):
                """Apply initial values and parameters, then set up the traits.

                Every key of ``y0`` and ``params`` must name an existing
                attribute; otherwise AttributeError is raised.  The time
                axis spans ``t0``..``tend`` with 2048 points.
                """
                groups = (
                    (y0, lambda key: 'No init cond. {}'.format(key + '0')),
                    (params, lambda key: 'No param {}'.format(key)),
                )
                for mapping, describe in groups:
                    for key, val in mapping.items():
                        if not hasattr(self, key):
                            raise AttributeError(describe(key))
                        setattr(self, key, val)
                self.t_default = np.linspace(t0, tend, 2048)
                super(MyViewer, self).__init__()

            # NOTE(review): u0, v0 and mu are not defined anywhere in this
            # class body -- confirm these two properties are still in use.
            @property
            def depv_init(self):
                return {'u': self.u0, 'v': self.v0}

            @property
            def params(self):
                return {'mu': self.mu}
コード例 #14
0
class FieldViewer(HasTraits):
    """Viewer for a three-dimensional scalar field."""

    # Value range of each of the three axes.
    x0, x1 = Float(-5), Float(5)
    y0, y1 = Float(-5), Float(5)
    z0, z1 = Float(-5), Float(5)
    points = Int(50)  # number of grid points per axis
    autocontour = Bool(True)  # whether iso-surfaces are chosen automatically
    v0, v1 = Float(0.0), Float(1.0)  # value range of the iso-surface
    contour = Range("v0", "v1", 0.5)  # iso-surface value
    function = Str("x*x*0.5 + y*y + z*z*2.0")  # scalar field expression
    function_list = [
        "x*x*0.5 + y*y + z*z*2.0", "x*y*0.5 + sin(2*x)*y +y*z*2.0", "x*y*z",
        "np.sin((x*x+y*y)/z)"
    ]
    plotbutton = Button("描画")
    scene = Instance(MlabSceneModel, ())  # mayavi scene

    view = View(
        HSplit(
            VGroup(
                "x0",
                "x1",
                "y0",
                "y1",
                "z0",
                "z1",
                Item('points', label="点数"),
                Item('autocontour', label="自动等值"),
                Item('plotbutton', show_label=False),
            ),
            VGroup(
                Item(
                    'scene',
                    editor=SceneEditor(
                        scene_class=MayaviScene),  # use the mayavi editor
                    resizable=True,
                    height=300,
                    width=350),
                Item('function',
                     editor=EnumEditor(name='function_list',
                                       evaluate=lambda x: x)),
                Item('contour',
                     editor=RangeEditor(format="%1.2f",
                                        low_name="v0",
                                        high_name="v1")),
                show_labels=False)),
        width=500,
        resizable=True,
        title="三维标量场观察器")

    def _plotbutton_fired(self):
        """Redraw the scene when the plot button is pressed."""
        self.plot()

    def _autocontour_changed(self):
        """React to a change of the automatic iso-surface setting."""
        if not hasattr(self, "g"):
            return  # nothing has been plotted yet
        self.g.contour.auto_contours = self.autocontour
        if not self.autocontour:
            self._contour_changed()

    def _contour_changed(self):
        """React to a change of the iso-surface value."""
        if hasattr(self, "g") and not self.g.contour.auto_contours:
            self.g.contour.contours = [self.contour]

    def plot(self):
        """Draw the scene."""
        # Build the 3-D grid; an imaginary step asks mgrid for that many
        # points between the bounds.
        x, y, z = np.mgrid[
            self.x0:self.x1:1j * self.points,
            self.y0:self.y1:1j * self.points,
            self.z0:self.z1:1j * self.points]

        # Evaluate the scalar field on the grid.
        # NOTE(review): eval() runs whatever expression is in `function`,
        # which is editable in the GUI -- do not feed it untrusted text.
        scalars = eval(self.function)
        mlab.clf()  # clear the current scene

        # Draw the iso-surfaces.
        surface = mlab.contour3d(x, y, z, scalars, contours=8,
                                 transparent=True)
        surface.contour.auto_contours = self.autocontour
        mlab.axes()  # add the coordinate axes

        # Add a cut plane with normal (0, 0, 1) through the grid centre.
        cut = mlab.pipeline.scalar_cut_plane(surface)
        centre = ((self.x0 + self.x1) / 2,
                  (self.y0 + self.y1) / 2,
                  (self.z0 + self.z1) / 2)
        cut.implicit_plane.normal = (0, 0, 1)
        cut.implicit_plane.origin = centre

        self.g = surface
        self.scalars = scalars
        # Update the value range of the scalar field.
        self.v0 = np.min(scalars)
        self.v1 = np.max(scalars)
コード例 #15
0
ファイル: output.py プロジェクト: felixplasser/orbkit
    class MyModel(HasTraits):
        """Mayavi dialog that displays one entry of ``data`` at a time.

        The class body reads names from the enclosing scope: ``data``,
        ``datalabels``, ``geo_spec``, ``is_vectorfield``, ``X``, ``Y``, ``Z``
        and the ``iso_*`` values.
        """
        # Index of the displayed data set.
        select = Range(0, len(data) - 1, 0)
        # Last index that was plotted; compared in update_plot.
        last_select = deepcopy(select)
        iso_value = Range(iso_min, iso_max, iso_val, mode='logslider')
        opacity = Range(0, 1.0, 1.0)
        show_atoms = Bool(True)
        label = Str()
        available = List(Str)
        # NOTE(review): this rebinding replaces the List(Str) declared on the
        # previous line -- confirm that is intended.
        available = datalabels

        prev_button = Button('Previous')
        next_button = Button('Next')

        scene = Instance(MlabSceneModel, ())
        plot_atoms = Instance(PipelineBase)
        plot0 = Instance(PipelineBase)

        # When the scene is activated, or when the parameters are changed, we
        # update the plot.
        @on_trait_change(
            'select,iso_value,show_atoms,opacity,label,scene.activated')
        def update_plot(self):
            """Create the plot on first use, afterwards update it in place."""
            if self.plot0 is None:
                # First call: build the pipeline object.
                if not is_vectorfield:
                    src = mlab.pipeline.scalar_field(X, Y, Z,
                                                     data[self.select])
                    self.plot0 = self.scene.mlab.pipeline.iso_surface(
                        src,
                        contours=[-self.iso_value, self.iso_value],
                        opacity=self.opacity,
                        colormap='blue-red',
                        vmin=-1e-8,
                        vmax=1e-8)
                else:
                    self.plot0 = self.scene.mlab.quiver3d(
                        X, Y, Z, *data[self.select])  #flow
                self.plot0.scene.background = (1, 1, 1)
            elif self.select != self.last_select:
                # Selection changed: swap the data of the existing plot.
                if not is_vectorfield:
                    self.plot0.mlab_source.set(scalars=data[self.select])
                else:
                    self.plot0.mlab_source.set(
                        vectors=data[self.select].reshape((3, -1)).T)
            if not is_vectorfield:
                # Iso-surfaces are drawn at -iso_value and +iso_value.
                self.plot0.contour.contours = [-self.iso_value, self.iso_value]
                self.plot0.actor.property.opacity = self.opacity
            self.last_select = deepcopy(self.select)
            if datalabels is not None:
                self.label = datalabels[self.select]
            if geo_spec is not None:
                # Points at the geo_spec coordinates, created once and then
                # only shown or hidden.
                if self.plot_atoms is None:
                    self.plot_atoms = self.scene.mlab.points3d(
                        geo_spec[:, 0],
                        geo_spec[:, 1],
                        geo_spec[:, 2],
                        scale_factor=0.75,
                        resolution=20)
                self.plot_atoms.visible = self.show_atoms

        def _prev_button_fired(self):
            """Step to the previous data set, stopping at the first one."""
            if self.select > 0:
                self.select -= 1

        def _next_button_fired(self):
            """Step to the next data set, stopping at the last one."""
            if self.select < len(data) - 1:
                self.select += 1

        # The layout of the dialog created
        items = (Item('scene',
                      editor=SceneEditor(scene_class=MayaviScene),
                      height=400,
                      width=600,
                      show_label=False), )
        items0 = ()
        # The selector and its buttons are only added for more than one
        # data set.
        if len(data) > 1:
            items0 += (Group(
                'select',
                HSplit(Item('prev_button', show_label=False),
                       Item('next_button', show_label=False))), )
        items0 += (Group('iso_value', 'opacity', 'show_atoms'), )

        if datalabels is not None:
            # With several labels, a read-only list of them is placed next
            # to the controls.
            if len(datalabels) > 1:
                items1 = (Item('available',
                               editor=ListStrEditor(title='Available Data',
                                                    editable=False),
                               show_label=False,
                               style='readonly',
                               width=300), )
                items0 = HSplit(items0, items1)
            items += (
                Group(
                    Item('label',
                         label='Selected Data',
                         style='readonly',
                         show_label=True), '_'),
                items0,
            )
        else:
            items += items0
        # Scene on top, everything else below it.
        view = View(VSplit(items[0], items[1:]), resizable=True)
コード例 #16
0
class SURFDemo(HasTraits):
    """Control panel for a SURF feature-matching demo.

    A grayscale image is warped by the affine matrix ``m``; SURF features of
    both images are extracted and matched, and the two images are shown side
    by side with the matches drawn as lines.
    """
    # builtin float: the np.float alias has been removed from NumPy.
    m = Array(float, (2, 3))
    max_distance = Range(0.1, 1.0, 0.26)
    draw_circle = Bool(True)

    hessian_th = Range(100.0, 1000.0, 1000.0)
    octaves = Range(1, 5, 2)
    layers = Range(1, 5, 3)
    view = View(
        Item("m", label=u"变换矩阵"),
        Item("hessian_th", label=u"hessian阈值"),
        HGroup(
            Item("octaves", label=u"Octaves"),
            Item("layers", label=u"层数"),
        ),
        Item("max_distance", label=u"距离阈值"),
        Item("draw_circle", label=u"绘制特征点"),
        title = u"SURF Demo控制面板",
        resizable = True,
    )

    def __init__(self, *args, **kwargs):
        super(SURFDemo, self).__init__(*args, **kwargs)
        img = cv.imread("lena_small.jpg")
        self.m = np.array([[0.8,-0.6,60],[0.6,0.7,-20]])
        self.img1 = cv.Mat()
        cv.cvtColor(img, self.img1, cv.CV_BGR2GRAY)
        self.on_trait_change(self.redraw, "max_distance,draw_circle")
        self.on_trait_change(self.recalculate, "m,hessian_th,octaves,layers")
        # recalculate() already warps the image and redraws, so no separate
        # affine()/redraw() calls are needed around it.
        self.recalculate()

    def get_features(self, img):
        """Return SURF keypoints and their descriptors for ``img``."""
        surf = cv.SURF(self.hessian_th, self.octaves, self.layers, True)
        keypoints = cv.vector_KeyPoint()
        features = surf(img, cv.Mat(), keypoints)
        return keypoints, np.array(features)

    def affine(self):
        """Warp ``img1`` with matrix ``m`` into ``img2``."""
        self.img2 = cv.Mat()
        M = cv.asMat(self.m, force_single_channel=True)
        cv.warpAffine(self.img1, self.img2, M, self.img1.size(),
            borderValue=cv.CV_RGB(255,255,255))

    def match_features(self):
        """For each feature of image 1 find the nearest feature of image 2."""
        # One descriptor row per keypoint.
        f1 = self.features1.reshape(len(self.keypoints1), -1)
        f2 = self.features2.reshape(len(self.keypoints2), -1)
        self.f1 = f1
        self.f2 = f2
        distances = cdist(f1, f2)
        self.mindist = np.min(distances, axis=1)
        self.idx_mindist = np.argmin(distances, axis=1)

    def recalculate(self):
        """Redo warp, feature extraction and matching, then redraw."""
        self.affine()
        self.keypoints1, self.features1 = self.get_features(self.img1)
        self.keypoints2, self.features2 = self.get_features(self.img2)
        self.match_features()
        self.redraw()

    def draw_keypoints(self, img, keypoints, offset):
        """Draw a circle per keypoint on ``img``, shifted by ``offset`` in x."""
        for kp in keypoints:
            center = cv.Point(int(kp.pt.x)+offset, int(kp.pt.y))
            cv.circle(img, center, int(kp.size*0.25), cv.CV_RGB(255,255,0))

    def redraw(self):
        """Compose the side-by-side image with keypoints and match lines."""
        # Show both images at once.
        w = self.img1.size().width
        h = self.img1.size().height
        show_img = cv.Mat(cv.Size(w*2, h), cv.CV_8UC3)
        for i in range(3):
            show_img[:,:w,i] = self.img1[:]
            show_img[:,w:,i] = self.img2[:]

        # Draw the feature points.
        if self.draw_circle:
            self.draw_keypoints(show_img, self.keypoints1, 0)
            self.draw_keypoints(show_img, self.keypoints2, w)

        # Connect every pair of features whose distance is below the
        # threshold with a line.
        for idx1 in np.where(self.mindist < self.max_distance)[0]:
            idx2 = self.idx_mindist[idx1]
            pos1 = self.keypoints1[int(idx1)].pt
            pos2 = self.keypoints2[int(idx2)].pt

            p1 = cv.Point(int(pos1.x), int(pos1.y))
            p2 = cv.Point(int(pos2.x)+w, int(pos2.y))
            cv.line(show_img, p1, p2, cv.CV_RGB(0,255,255), lineType=16)

        cv.imshow("SURF Demo",show_img)
コード例 #17
0
class HoughDemo(HasTraits):
    """Control panel for Canny edge detection and Hough line/circle detection.

    The image ``stuff.jpg`` is loaded once; every change of one of the
    parameter traits below triggers :meth:`redraw`, which recomputes the
    detections and shows them in the "Hough Demo" window.
    """

    # Canny parameters
    th1 = Range(0.0, 255.0, 50.0)
    th2 = Range(0.0, 255.0, 200.0)
    show_canny = Bool(False)

    # HoughLinesP parameters
    rho = Range(1.0, 10.0, 1.0)
    theta = Range(0.1, 5.0, 1.0)
    hough_th = Range(1, 100, 40)
    minlen = Range(0, 100, 10)
    maxgap = Range(0, 20, 10)

    # HoughCircles parameters
    dp = Range(1.0, 5.0, 2.0)
    mindist = Range(1.0, 100.0, 50.0)
    param1 = Range(50, 100.0, 50.0)
    param2 = Range(50, 100.0, 100.0)

    # Three groups: edge detection, line detection and circle detection.
    view = View(
        VGroup(
            Group(
                Item("th1", label="阈值1"),
                Item("th2", label="阈值2"),
                Item("show_canny", label="显示结果"),
                label="边缘检测参数",
            ),
            Group(
                Item("rho", label="偏移分辨率(像素)"),
                Item("theta", label="角度分辨率(角度)"),
                Item("hough_th", label="阈值"),
                Item("minlen", label="最小长度"),
                Item("maxgap", label="最大空隙"),
                label="直线检测",
            ),
            Group(
                Item("dp", label="分辨率(像素)"),
                Item("mindist", label="圆心最小距离(像素)"),
                Item("param1", label="参数1"),
                Item("param2", label="参数2"),
                label="圆检测",
            ),
        ),
        title="直线和圆检测控制面板",
    )

    def __init__(self, *args, **kwargs):
        """Load the image, prepare its gray/smoothed versions and draw once."""
        super(HoughDemo, self).__init__(*args, **kwargs)

        self.img = cv.imread("stuff.jpg")

        # Gray version, used as input of the edge detector.
        gray = cv.Mat()
        cv.cvtColor(self.img, gray, cv.CV_BGR2GRAY)
        self.img_gray = gray

        # Gaussian-smoothed copy, used as input of the circle detector.
        smoothed = gray.clone()
        cv.smooth(gray, smoothed, cv.CV_GAUSSIAN, 7, 7, 0, 0)
        self.img_smooth = smoothed

        self.redraw()

        # From now on, any parameter change refreshes the display.
        self.on_trait_change(
            self.redraw,
            "th1,th2,show_canny,rho,theta,hough_th,minlen,maxgap,dp,mindist,param1,param2"
        )

    def redraw(self):
        """Recompute edges, line segments and circles, then show the result."""
        # Edge detection
        edges = cv.Mat()
        cv.Canny(self.img_gray, edges, self.th1, self.th2)

        # Background of the result: the edge map or a copy of the original.
        if not self.show_canny:
            canvas = self.img.clone()
        else:
            canvas = cv.Mat()
            cv.cvtColor(edges, canvas, cv.CV_GRAY2BGR)

        # Line segment detection (the angular resolution is given in degrees
        # by the trait and converted to radians here).
        angle_step = self.theta / 180.0 * np.pi
        segments = cv.HoughLinesP(edges, self.rho, angle_step, self.hough_th,
                                  self.minlen, self.maxgap)
        for segment in segments:
            start = cv.asPoint(segment[:2])
            end = cv.asPoint(segment[2:])
            cv.line(canvas, start, end, cv.CV_RGB(255, 0, 0), 2)

        # Circle detection
        # NOTE(review): the literal 3 is the detection-method argument,
        # presumably the Hough-gradient constant of the binding - confirm.
        found = cv.HoughCircles(self.img_smooth,
                                3,
                                self.dp,
                                self.mindist,
                                param1=self.param1,
                                param2=self.param2)
        for found_circle in found:
            centre = cv.Point(int(found_circle[0]), int(found_circle[1]))
            cv.circle(canvas, centre, int(found_circle[2]),
                      cv.CV_RGB(0, 255, 0), 2)

        cv.imshow("Hough Demo", canvas)
コード例 #18
0
ファイル: image_button.py プロジェクト: sjl421/code-2
class ImageButton(Widget):
    """ An image and text-based control that can be used as a normal, radio,
        checkbox or toolbar button.
    """

    # Pens used to draw the 'selection' marker:
    _selectedPenDark = wx.Pen(
        wx.SystemSettings_GetColour(wx.SYS_COLOUR_3DSHADOW), 1, wx.SOLID)

    _selectedPenLight = wx.Pen(
        wx.SystemSettings_GetColour(wx.SYS_COLOUR_3DHIGHLIGHT), 1, wx.SOLID)

    #---------------------------------------------------------------------------
    #  Trait definitions:
    #---------------------------------------------------------------------------

    # The image:
    image = Instance(ImageResource, allow_none=True)

    # The (optional) label:
    label = Str

    # Extra padding to add to both the left and right sides:
    width_padding = Range(0, 31, 7)

    # Extra padding to add to both the top and bottom sides:
    height_padding = Range(0, 31, 5)

    # Presentation style:
    style = Enum('button', 'radio', 'toolbar', 'checkbox')

    # Orientation of the text relative to the image:
    orientation = Enum('vertical', 'horizontal')

    # Is the control selected ('radio' or 'checkbox' style)?
    selected = false

    # Fired when a 'button' or 'toolbar' style control is clicked:
    clicked = Event

    #---------------------------------------------------------------------------
    #  Initializes the object:
    #---------------------------------------------------------------------------

    def __init__(self, parent, **traits):
        """ Creates a new image control.

        'parent' is the wx window that will own the control; 'traits' are
        forwarded to the base class constructor.  Trait change handlers may
        therefore run before 'self.control' has been created.
        """
        self._image = None

        super(ImageButton, self).__init__(**traits)

        # Calculate the size of the button:
        idx = idy = tdx = tdy = 0
        if self._image is not None:
            idx = self._image.GetWidth()
            idy = self._image.GetHeight()

        if self.label != '':
            dc = wx.ScreenDC()
            dc.SetFont(wx.NORMAL_FONT)
            tdx, tdy = dc.GetTextExtent(self.label)

        wp2 = self.width_padding + 2
        hp2 = self.height_padding + 2
        if self.orientation == 'horizontal':
            # Image on the left, text on the right, both centred vertically.
            self._ix = wp2
            spacing = (idx > 0) * (tdx > 0) * 4
            self._tx = self._ix + idx + spacing
            dx = idx + tdx + spacing
            dy = max(idy, tdy)
            self._iy = hp2 + ((dy - idy) // 2)
            self._ty = hp2 + ((dy - tdy) // 2)
        else:
            # Image on top, text below, both centred horizontally.
            self._iy = hp2
            spacing = (idy > 0) * (tdy > 0) * 2
            self._ty = self._iy + idy + spacing
            dx = max(idx, tdx)
            dy = idy + tdy + spacing
            self._ix = wp2 + ((dx - idx) // 2)
            self._tx = wp2 + ((dx - tdx) // 2)

        # Create the toolkit-specific control:
        self._dx = dx + wp2 + wp2
        self._dy = dy + hp2 + hp2
        self.control = wx.Window(parent, -1, size=wx.Size(self._dx, self._dy))
        # Back reference used to find sibling ImageButtons (radio style):
        self.control._owner = self
        self._mouse_over = self._button_down = False

        # Set up mouse event handlers:
        wx.EVT_ENTER_WINDOW(self.control, self._on_enter_window)
        wx.EVT_LEAVE_WINDOW(self.control, self._on_leave_window)
        wx.EVT_LEFT_DOWN(self.control, self._on_left_down)
        wx.EVT_LEFT_UP(self.control, self._on_left_up)
        wx.EVT_PAINT(self.control, self._on_paint)

    #---------------------------------------------------------------------------
    #  Handles the 'image' trait being changed:
    #---------------------------------------------------------------------------

    def _image_changed(self, image):
        """ Handles the 'image' trait being changed.
        """
        # Drop the cached bitmaps; the monochrome one is rebuilt on demand
        # by '_on_paint':
        self._image = self._mono_image = None
        if image is not None:
            self._img = image.create_image()
            self._image = self._img.ConvertToBitmap()

        if self.control is not None:
            self.control.Refresh()

    #---------------------------------------------------------------------------
    #  Handles the 'selected' trait being changed:
    #---------------------------------------------------------------------------

    def _selected_changed(self, selected):
        """ Handles the 'selected' trait being changed.
        """
        # The trait can be set through the constructor, i.e. before the
        # control exists; there is nothing to update or repaint in that case
        # (same guard as in '_image_changed'):
        if self.control is None:
            return

        if selected and (self.style == 'radio'):
            # Radio behaviour: deselect the sibling that is currently
            # selected (at most one is expected, hence the 'break'):
            for control in self.control.GetParent().GetChildren():
                owner = getattr(control, '_owner', None)
                if (isinstance(owner, ImageButton) and owner.selected
                        and (owner is not self)):
                    owner.selected = False
                    break

        self.control.Refresh()

#-- wx event handlers ----------------------------------------------------------

    def _on_enter_window(self, event):
        """ Called when the mouse enters the widget. """

        if self.style != 'button':
            self._mouse_over = True
            self.control.Refresh()

    def _on_leave_window(self, event):
        """ Called when the mouse leaves the widget. """

        if self._mouse_over:
            self._mouse_over = False
            self.control.Refresh()

    def _on_left_down(self, event):
        """ Called when the left mouse button goes down on the widget. """
        self._button_down = True
        self.control.CaptureMouse()
        self.control.Refresh()

    def _on_left_up(self, event):
        """ Called when the left mouse button goes up on the widget. """
        control = self.control
        control.ReleaseMouse()
        self._button_down = False
        wdx, wdy = control.GetClientSizeTuple()
        x, y = event.GetX(), event.GetY()
        control.Refresh()
        # Only act when the button is released inside the control:
        if (0 <= x < wdx) and (0 <= y < wdy):
            if self.style == 'radio':
                self.selected = True
            elif self.style == 'checkbox':
                self.selected = not self.selected
            else:
                self.clicked = True

    def _on_paint(self, event):
        """ Called when the widget needs repainting.
        """
        wdc = wx.PaintDC(self.control)
        wdx, wdy = self.control.GetClientSizeTuple()
        # Offsets that centre the computed layout in the actual client area:
        ox = (wdx - self._dx) // 2
        oy = (wdy - self._dy) // 2

        disabled = (not self.control.IsEnabled())
        if self._image is not None:
            image = self._image
            if disabled:
                if self._mono_image is None:
                    # Build the disabled bitmap once: every pixel is replaced
                    # by the weighted sum of its three channels:
                    img = self._img
                    data = reshape(fromstring(img.GetData(), dtype('uint8')),
                                   (-1, 3)) * array([[0.297, 0.589, 0.114]])
                    g = data[:, 0] + data[:, 1] + data[:, 2]
                    data[:, 0] = data[:, 1] = data[:, 2] = g
                    img.SetData(ravel(data.astype(dtype('uint8'))).tostring())
                    img.SetMaskColour(0, 0, 0)
                    self._mono_image = img.ConvertToBitmap()
                    self._img = None
                image = self._mono_image
            wdc.DrawBitmap(image, ox + self._ix, oy + self._iy, True)

        if self.label != '':
            if disabled:
                wdc.SetTextForeground(DisabledTextColor)
            wdc.SetFont(wx.NORMAL_FONT)
            wdc.DrawText(self.label, ox + self._tx, oy + self._ty)

        # 'bd' indexes 'pens': it selects which pen draws the top/left edges,
        # the other pen draws the bottom/right edges:
        pens = [self._selectedPenLight, self._selectedPenDark]
        bd = self._button_down
        style = self.style
        is_rc = (style in ('radio', 'checkbox'))
        if bd or (style == 'button') or (is_rc and self.selected):
            if is_rc:
                # Radio/checkbox styles use the opposite pen assignment:
                bd = 1 - bd
            wdc.SetBrush(wx.TRANSPARENT_BRUSH)
            wdc.SetPen(pens[bd])
            wdc.DrawLine(1, 1, wdx - 1, 1)
            wdc.DrawLine(1, 1, 1, wdy - 1)
            wdc.DrawLine(2, 2, wdx - 2, 2)
            wdc.DrawLine(2, 2, 2, wdy - 2)
            wdc.SetPen(pens[1 - bd])
            wdc.DrawLine(wdx - 2, 2, wdx - 2, wdy - 1)
            wdc.DrawLine(2, wdy - 2, wdx - 2, wdy - 2)
            wdc.DrawLine(wdx - 3, 3, wdx - 3, wdy - 2)
            wdc.DrawLine(3, wdy - 3, wdx - 3, wdy - 3)

        elif self._mouse_over and (not self.selected):
            # Thin one-pixel border shown while the mouse hovers the control:
            wdc.SetBrush(wx.TRANSPARENT_BRUSH)
            wdc.SetPen(pens[bd])
            wdc.DrawLine(0, 0, wdx, 0)
            wdc.DrawLine(0, 1, 0, wdy)
            wdc.SetPen(pens[1 - bd])
            wdc.DrawLine(wdx - 1, 1, wdx - 1, wdy)
            wdc.DrawLine(1, wdy - 1, wdx - 1, wdy - 1)
コード例 #19
0
ファイル: int.py プロジェクト: mjfwest/OpenMDAO-Framework
    def __init__(self,
                 default_value=None,
                 iotype=None,
                 desc=None,
                 low=None,
                 high=None,
                 exclude_low=False,
                 exclude_high=False,
                 **metadata):
        """Validate the arguments and set up the integer trait.

        When ``default_value`` is omitted it falls back to ``low``, then to
        ``high``, then to 0.  Missing bounds default to ``-maxint`` and
        ``maxint``.  ``iotype`` and ``desc`` are stored in the metadata when
        given, and the bounds and exclusion flags are added to the metadata
        passed to the base class.

        Raises ``ValueError`` if the default value or a bound is not an
        ``int``, if ``low > high``, or if the default value lies outside
        ``[low, high]``.
        """
        # Range trait didn't seem to handle "None" correctly when passed on
        # the command line, so the default is resolved here.
        if default_value is None:
            if low is not None:
                default_value = low
            elif high is not None:
                default_value = high
            else:
                default_value = 0

        low = -maxint if low is None else low
        high = maxint if high is None else high

        # Type checks, in this order: default value, lower bound, upper bound.
        type_checks = (
            (default_value, "Default value for an Int must be an integer."),
            (low, "Lower bound for an Int must be an integer."),
            (high, "Upper bound for an Int must be an integer."),
        )
        for candidate, message in type_checks:
            if not isinstance(candidate, int):
                raise ValueError(message)

        if low > high:
            raise ValueError("Lower bound is greater than upper bound.")

        if not (low <= default_value <= high):
            raise ValueError("Default value is outside of bounds [%s, %s]." %
                             (low, high))

        # iotype and desc travel in the metadata dictionary.
        if iotype is not None:
            metadata['iotype'] = iotype
        if desc is not None:
            metadata['desc'] = desc

        self._validator = Range(value=default_value,
                                low=low,
                                high=high,
                                exclude_low=exclude_low,
                                exclude_high=exclude_high,
                                **metadata)

        # Expose the bounds through the trait's metadata as well; this is
        # done after the validator is built so that it does not receive them
        # twice.
        metadata.update(low=low,
                        high=high,
                        exclude_low=exclude_low,
                        exclude_high=exclude_high)

        super(Int, self).__init__(default_value=default_value, **metadata)
コード例 #20
0
class Threshold(Filter):
    """Filter that thresholds its input between a lower and an upper value.

    Depending on `filter_type`, a `tvtk.Threshold` (cells) or a
    `tvtk.ThresholdPoints` (points) is used.  The allowed range of both
    thresholds follows the scalar range of the input data.
    """

    # The version of this class.  Used for persistence.
    __version__ = 0

    # The threshold filter used.
    threshold_filter = Property(Instance(tvtk.Object, allow_none=False),
                                record=True)

    # The filter type to use, specifies if the cells or the points are
    # filtered via a threshold.
    filter_type = Enum('cells',
                       'points',
                       desc='if thresholding is done on cells or points')

    # Lower threshold (this is a dynamic trait that is changed when
    # input data changes).
    lower_threshold = Range(value=-1.0e20,
                            low='_data_min',
                            high='_data_max',
                            enter_set=True,
                            auto_set=False,
                            desc='the lower threshold of the filter')

    # Upper threshold (this is a dynamic trait that is changed when
    # input data changes).
    upper_threshold = Range(value=1.0e20,
                            low='_data_min',
                            high='_data_max',
                            enter_set=True,
                            auto_set=False,
                            desc='the upper threshold of the filter')

    # Automatically reset the lower threshold when the upstream data
    # changes.
    auto_reset_lower = Bool(True,
                            desc='if the lower threshold is '
                            'automatically reset when upstream '
                            'data changes')

    # Automatically reset the upper threshold when the upstream data
    # changes.
    auto_reset_upper = Bool(True,
                            desc='if the upper threshold is '
                            'automatically reset when upstream '
                            'data changes')

    input_info = PipelineInfo(datasets=['any'],
                              attribute_types=['any'],
                              attributes=['any'])

    output_info = PipelineInfo(datasets=['poly_data', 'unstructured_grid'],
                               attribute_types=['any'],
                               attributes=['any'])

    # Our view.
    view = View(Group(
        Group(Item(name='filter_type'), Item(name='lower_threshold'),
              Item(name='auto_reset_lower'), Item(name='upper_threshold'),
              Item(name='auto_reset_upper')),
        Item(name='_'),
        Group(
            Item(name='threshold_filter',
                 show_label=False,
                 visible_when='object.filter_type == "cells"',
                 style='custom',
                 resizable=True)),
    ),
                resizable=True)

    ########################################
    # Private traits.

    # These traits are used to set the limits for the thresholding.
    # They store the minimum and maximum values of the input data.
    _data_min = Float(-1e20)
    _data_max = Float(1e20)

    # The threshold filter for cell based filtering
    _threshold = Instance(tvtk.Threshold, args=(), allow_none=False)

    # The threshold filter for points based filtering.
    _threshold_points = Instance(tvtk.ThresholdPoints,
                                 args=(),
                                 allow_none=False)

    # Internal flag: True until the ranges have been set from the input
    # data for the first time (see `_update_ranges`).
    _first = Bool(True)

    ######################################################################
    # `object` interface.
    ######################################################################
    def __get_pure_state__(self):
        """Return the state to persist, without the dynamic private traits."""
        d = super(Threshold, self).__get_pure_state__()
        # These traits are dynamically created.
        for name in ('_first', '_data_min', '_data_max'):
            d.pop(name, None)

        return d

    ######################################################################
    # `Filter` interface.
    ######################################################################
    def setup_pipeline(self):
        """Listen to edits of the cell-based threshold filter's settings."""
        attrs = [
            'all_scalars', 'attribute_mode', 'component_mode',
            'selected_component'
        ]
        self._threshold.on_trait_change(self._threshold_filter_edited, attrs)

    def update_pipeline(self):
        """Override this method so that it *updates* the tvtk pipeline
        when data upstream is known to have changed.

        This method is invoked (automatically) when the input fires a
        `pipeline_changed` event.
        """
        if len(self.inputs) == 0:
            return

        # By default we set the input to the first output of the first
        # input.
        fil = self.threshold_filter
        fil.input = self.inputs[0].outputs[0]

        self._update_ranges()
        self._set_outputs([self.threshold_filter.output])

    def update_data(self):
        """Override this method to do what is necessary when upstream
        data changes.

        This method is invoked (automatically) when any of the inputs
        sends a `data_changed` event.
        """
        if len(self.inputs) == 0:
            return

        self._update_ranges()

        # Propagate the data_changed event.
        self.data_changed = True

    ######################################################################
    # Non-public interface
    ######################################################################
    def _lower_threshold_changed(self, new_value):
        """Apply the new lower threshold to the active filter."""
        fil = self.threshold_filter
        fil.threshold_between(new_value, self.upper_threshold)
        fil.update()
        self.data_changed = True

    def _upper_threshold_changed(self, new_value):
        """Apply the new upper threshold to the active filter."""
        fil = self.threshold_filter
        fil.threshold_between(self.lower_threshold, new_value)
        fil.update()
        self.data_changed = True

    def _update_ranges(self):
        """Updates the ranges of the input.
        """
        data_range = self._get_data_range()
        if len(data_range) > 0:
            dr = data_range
            if self._first:
                self._data_min, self._data_max = dr
                # The lower threshold is set silently; setting the upper
                # one right after notifies and applies both to the filter.
                self.set(lower_threshold=dr[0], trait_change_notify=False)
                self.upper_threshold = dr[1]
                self._first = False
            else:
                if self.auto_reset_lower:
                    self._data_min = dr[0]
                    # Notify here only if the upper threshold is not about
                    # to be reset (and notified) just below.
                    notify = not self.auto_reset_upper
                    self.set(lower_threshold=dr[0], trait_change_notify=notify)
                if self.auto_reset_upper:
                    self._data_max = dr[1]
                    self.upper_threshold = dr[1]

    def _get_data_range(self):
        """Returns the range of the input scalar data.

        The result is a two-element list ``[min, max]`` taken from the point
        scalars if there are any, otherwise from the cell scalars.  An empty
        list is returned when the input has neither.
        """
        dataset = self.inputs[0].outputs[0]
        ps = dataset.point_data.scalars
        cs = dataset.cell_data.scalars

        # FIXME: need to be able to handle cell and point data
        # together.
        if ps is not None:
            scalars = ps
        elif cs is not None:
            scalars = cs
        else:
            return []

        # Always work on a list copy: the bounds may be patched in place
        # below, which would fail if `range` is an immutable sequence.
        data_range = list(scalars.range)
        # A NaN bound is replaced by the min/max computed while ignoring
        # NaN values.
        if np.isnan(data_range[0]):
            data_range[0] = float(np.nanmin(scalars.to_array()))
        if np.isnan(data_range[1]):
            data_range[1] = float(np.nanmax(scalars.to_array()))
        return data_range

    def _auto_reset_lower_changed(self, value):
        """Reset the lower threshold to the data minimum when switched on."""
        if len(self.inputs) == 0:
            return
        if value:
            dr = self._get_data_range()
            # Nothing to reset to when the input carries no scalars.
            if len(dr) > 0:
                self._data_min = dr[0]
                self.lower_threshold = dr[0]

    def _auto_reset_upper_changed(self, value):
        """Reset the upper threshold to the data maximum when switched on."""
        if len(self.inputs) == 0:
            return
        if value:
            dr = self._get_data_range()
            # Nothing to reset to when the input carries no scalars.
            if len(dr) > 0:
                self._data_max = dr[1]
                self.upper_threshold = dr[1]

    def _get_threshold_filter(self):
        """Property getter: the filter matching the current `filter_type`."""
        if self.filter_type == 'cells':
            return self._threshold
        else:
            return self._threshold_points

    def _filter_type_changed(self, value):
        """Signal that `threshold_filter` now refers to the other filter."""
        if value == 'cells':
            old = self._threshold_points
            new = self._threshold
        else:
            old = self._threshold
            new = self._threshold_points
        self.trait_property_changed('threshold_filter', old, new)

    def _threshold_filter_changed(self, old, new):
        """Connect the newly selected filter to the input and the outputs."""
        if len(self.inputs) == 0:
            return
        fil = new
        fil.input = self.inputs[0].outputs[0]
        fil.threshold_between(self.lower_threshold, self.upper_threshold)
        fil.update()
        self._set_outputs([fil.output])

    def _threshold_filter_edited(self):
        """Re-run the filter after one of its settings was edited."""
        self.threshold_filter.update()
        self.data_changed = True
コード例 #21
0
class InPaintDemo(HasTraits):
    """Interactive inpainting demo.

    Circles painted on the image plot are accumulated into a mask, and
    ``cv.inpaint`` is run on ``stuff.jpg`` with that mask to produce the
    preview shown in the plot.
    """

    plot = Instance(Plot)
    painter = Instance(CirclePainter)
    r = Range(2.0, 20.0, 10.0)  # radius parameter of inpaint
    method = Enum("INPAINT_NS", "INPAINT_TELEA")  # inpaint algorithm
    show_mask = Bool(False)  # whether the selection is displayed
    clear_mask = Button("清除选区")
    apply = Button("保存结果")

    view = View(
        VGroup(
            VGroup(
                Item("object.painter.r", label="画笔半径"),
                Item("r", label="inpaint半径"),
                HGroup(
                    Item("method", label="inpaint算法"),
                    Item("show_mask", label="显示选区"),
                    Item("clear_mask", show_label=False),
                    Item("apply", show_label=False),
                ),
            ),
            Item("plot", editor=ComponentEditor(), show_label=False),
        ),
        title="inpaint Demo控制面板",
        width=500,
        height=450,
        resizable=True,
    )

    def __init__(self, *args, **kwargs):
        """Load the image and build the plot with its painting overlay."""
        super(InPaintDemo, self).__init__(*args, **kwargs)

        self.img = cv.imread("stuff.jpg")  # original image
        self.img2 = self.img.clone()  # preview of the inpaint result

        # Single-channel image holding the selection, initially empty.
        self.mask = cv.Mat(self.img.size(), cv.CV_8UC1)
        self.mask[:] = 0

        # The plot receives the image with its last axis reversed.
        self.data = ArrayPlotData(img=self.img[:, :, ::-1])
        self.plot = Plot(self.data,
                         padding=10,
                         aspect_ratio=float(self.img.size().width) /
                         self.img.size().height)
        for axis in (self.plot.x_axis, self.plot.y_axis):
            axis.visible = False

        image_renderer = self.plot.img_plot("img", origin="top left")[0]
        self.painter = CirclePainter(component=image_renderer)
        image_renderer.overlays.append(self.painter)

    @on_trait_change("r,method")
    def inpaint(self):
        """Recompute the preview from the original image and the mask."""
        algorithm = getattr(cv, self.method)
        cv.inpaint(self.img, self.mask, self.img2, self.r, algorithm)
        self.draw()

    @on_trait_change("painter:updated")
    def painter_updated(self):
        """Add the freshly painted track to the mask and refresh."""
        brush_colour = cv.Scalar(255, 255, 255, 255)
        for _, _, x, y in self.painter.track:
            # Draw a circle on the mask that stores the selection;
            # a negative thickness means the circle is filled.
            cv.circle(self.mask,
                      cv.Point(int(x), int(y)),
                      int(self.painter.r),
                      brush_colour,
                      thickness=-1)
        self.inpaint()
        self.painter.track = []
        self.painter.request_redraw()

    def _clear_mask_fired(self):
        """Empty the selection and recompute the preview."""
        self.mask[:] = 0
        self.inpaint()

    def _apply_fired(self):
        """Keep the inpaint result as the new image and clear the selection."""
        self.img[:] = self.img2[:]
        self._clear_mask_fired()

    @on_trait_change("show_mask")
    def draw(self):
        """Push either the masked original or the preview to the plot."""
        if not self.show_mask:
            self.data["img"] = self.img2[:, :, ::-1]
            return

        # Original image with every selected pixel set to 255.
        overlay = self.img[:, :, ::-1].copy()
        overlay[self.mask[:] > 0] = 255
        self.data["img"] = overlay
コード例 #22
0
class Picker(HasTraits):
    """This module creates a 'Picker' that can interactively select a
    point and/or a cell in the data.  It also can use a world point
    picker (i.e. a generic point in space) and will probe for the data
    at that point.

    The Picker is usually called via a callback from the GUI
    interactor window.  After performing a pick on the VTK scene, a
    Picker object creates a `PickedData` object and passes it on to
    the pick_handler trait for further handling.
    """

    # The version of this class.  Used for persistence.
    __version__ = 0

    # Specifies the pick type.  The 'point_picker' and 'cell_picker'
    # options are self-explanatory.  The 'world_picker' picks a point
    # using a WorldPointPicker and additionally uses a ProbeFilter to
    # probe the data at the picked point.
    pick_type = Trait('point',
                      TraitRevPrefixMap({
                          'point_picker': 1,
                          'cell_picker': 2,
                          'world_picker': 3
                      }),
                      desc='specifies the picker type to use')

    # The pick_handler.  Set this to your own subclass if you want to
    # do something different from the default.
    pick_handler = Trait(DefaultPickHandler(), Instance(PickHandler))

    # Picking tolerance.
    tolerance = Range(0.0, 0.25, 0.025)

    # Show the GUI on pick?
    show_gui = true(desc="whether to show the picker GUI on pick")

    # Raise the GUI on pick?
    auto_raise = true(desc="whether to raise the picker GUI on pick")

    default_view = View(
        Group(
            Group(Item(name='pick_type'),
                  Item(name='tolerance'),
                  show_border=True),
            Group(Item(name='pick_handler', style='custom'),
                  show_border=True,
                  show_labels=False),
            Group(Item(name='show_gui'),
                  Item(name='auto_raise'),
                  show_border=True),
        ),
        resizable=True,
        buttons=['OK'],
        handler=CloseHandler(),
    )

    #################################################################
    # `object` interface.
    #################################################################
    def __init__(self, renwin, **traits):
        """Create the VTK pickers and the marker actor for `renwin`."""
        super(Picker, self).__init__(**traits)

        self.renwin = renwin
        self.pointpicker = tvtk.PointPicker()
        self.cellpicker = tvtk.CellPicker()
        self.worldpicker = tvtk.WorldPointPicker()
        self.probe_data = tvtk.PolyData()
        # Push the current tolerance to the freshly created pickers.
        self._tolerance_changed(self.tolerance)

        # Use a set of axes to show the picked point.
        axes = tvtk.Axes()
        mapper = tvtk.PolyDataMapper()
        actor = tvtk.Actor()
        self.p_source = axes
        self.p_mapper = mapper
        self.p_actor = actor

        axes.symmetric = 1
        # The marker itself is not pickable and stays hidden until a
        # successful pick.
        actor.pickable = 0
        actor.visibility = 0
        marker_property = actor.property
        marker_property.line_width = 2
        marker_property.ambient = 1.0
        marker_property.diffuse = 0.0
        mapper.input = axes.output
        actor.mapper = mapper

        self.probe_data.points = [[0.0, 0.0, 0.0]]

        self.ui = None

    def __get_pure_state__(self):
        """Return a copy of the instance dict without non-persistent items."""
        excluded = ('renwin', 'ui', 'pick_handler', '__sync_trait__',
                    '__traits_listener__')
        state = self.__dict__.copy()
        for key in excluded:
            state.pop(key, None)
        return state

    def __getstate__(self):
        return state_pickler.dumps(self)

    def __setstate__(self, str_state):
        # This method is unnecessary since this object will almost
        # never be pickled by itself and only via the scene, therefore
        # __init__ will be called when the scene is constructed.
        # However, setstate is defined just for completeness.
        restored = state_pickler.loads_state(str_state)
        state_pickler.set_state(self, restored)

    #################################################################
    # `Picker` interface.
    #################################################################
    def pick(self, x, y):
        """Calls one of the current pickers and then passes the
        obtained data to the `self.pick_handler` object's
        `handle_pick` method.

        Parameters
        ----------

        - x : X position of the mouse in the window.

        - y : Y position of the mouse in the window.

          Note that the origin of x, y must be at the left bottom
          corner of the window.  Thus, for most GUI toolkits, y must
          be flipped appropriately such that y=0 is the bottom of the
          window.
        """
        # `pick_type_` is the mapped (integer) value of `pick_type`.
        pickers = {
            1: self.pick_point,
            2: self.pick_cell,
            3: self.pick_world,
        }
        chosen = pickers.get(self.pick_type_)
        data = chosen(x, y) if chosen is not None else None

        self.pick_handler.handle_pick(data)
        if self.show_gui:
            self._setup_gui()

    def pick_point(self, x, y):
        """ Picks the nearest point. Returns a `PickedData` instance."""
        picker = self.pointpicker
        picker.pick((float(x), float(y), 0.0), self.renwin.renderer)

        picked_id = picker.point_id
        result = PickedData()
        position = picker.pick_position
        result.coordinate = position

        if picked_id > -1:
            dataset = picker.mapper.input
            point_data = dataset.point_data
            bounds = dataset.bounds

            result.valid = 1
            result.point_id = picked_id
            result.data = point_data

            self._update_actor(position, bounds)
        else:
            # Nothing was hit: hide the marker.
            self.p_actor.visibility = 0

        self.renwin.render()
        return result

    def pick_cell(self, x, y):
        """ Picks the nearest cell. Returns a `PickedData` instance."""
        picker = self.cellpicker
        picker.pick((float(x), float(y), 0.0), self.renwin.renderer)

        picked_id = picker.cell_id
        result = PickedData()
        position = picker.pick_position
        result.coordinate = position

        if picked_id > -1:
            dataset = picker.mapper.input
            cell_data = dataset.cell_data
            bounds = dataset.bounds

            result.valid = 1
            result.cell_id = picked_id
            result.data = cell_data

            self._update_actor(position, bounds)
        else:
            # Nothing was hit: hide the marker.
            self.p_actor.visibility = 0

        self.renwin.render()
        return result

    def pick_world(self, x, y):
        """ Picks a world point and probes for data there. Returns a
        `PickedData` instance."""
        display_point = (float(x), float(y), 0.0)
        world_picker = self.worldpicker
        cell_picker = self.cellpicker

        world_picker.pick(display_point, self.renwin.renderer)

        # Use the cell picker to get the data that needs to be probed.
        cell_picker.pick((float(x), float(y), 0.0), self.renwin.renderer)

        position = world_picker.pick_position
        self.probe_data.points = [list(position)]
        result = PickedData()
        result.coordinate = position

        if cell_picker.mapper:
            source = get_last_input(cell_picker.mapper.input)
            # Need to create the probe each time because otherwise it
            # does not seem to work properly.
            probe = tvtk.ProbeFilter()
            probe.source = source
            probe.input = self.probe_data
            probe.update()
            probed = probe.output.point_data
            bounds = cell_picker.mapper.input.bounds

            result.valid = 1
            result.world_pick = 1
            result.point_id = 0
            result.data = probed

            self._update_actor(position, bounds)
        else:
            # The cell picker found no mapper: hide the marker.
            self.p_actor.visibility = 0

        self.renwin.render()
        return result

    def on_ui_close(self):
        """This method makes the picker actor invisible when the GUI
        dialog is closed."""
        marker = self.p_actor
        marker.visibility = 0
        self.renwin.renderer.remove_actor(marker)
        self.ui = None

    #################################################################
    # Non-public interface.
    #################################################################
    def _tolerance_changed(self, val):
        """ Trait handler for the tolerance trait."""
        for picker in (self.pointpicker, self.cellpicker):
            picker.tolerance = val

    def _update_actor(self, coordinate, bounds):
        """Updates the actor by setting its position and scale."""
        # 30% of the extent along each axis; the largest one is the scale.
        extent_x = 0.3 * (bounds[1] - bounds[0])
        extent_y = 0.3 * (bounds[3] - bounds[2])
        extent_z = 0.3 * (bounds[5] - bounds[4])
        self.p_source.origin = coordinate
        self.p_source.scale_factor = max(extent_x, extent_y, extent_z)
        self.p_actor.visibility = 1

    def _setup_gui(self):
        """Pops up the GUI control widget."""
        if self.ui is not None:
            # The dialog already exists: optionally bring it to the front.
            if self.auto_raise:
                try:
                    self.ui.control.Raise()
                except AttributeError:
                    pass
            return

        # Popup the GUI control.
        self.ui = self.edit_traits()
        # Note that we add actors to the renderer rather than to
        # renwin to prevent event notifications on actor
        # additions.
        self.renwin.renderer.add_actor(self.p_actor)
コード例 #23
0
class Autocorrelation(HasTraits):
    """Live correlation histogram between two time tagger channels.

    Two `TimeTagger.Correlation` measurements (chan1 -> chan2 and
    chan2 -> chan1) are polled periodically by a background thread and
    merged into one histogram (`counts` over `bins`, symmetric around
    zero delay) that is displayed in `plot`.
    """

    # Half-width of the histogram and width of a single bin, both in ns.
    window_size = Range(low=1.,
                        high=1e9,
                        value=2000.,
                        desc='window size [ns]',
                        label='range [ns]',
                        mode='text',
                        auto_set=False,
                        enter_set=True)
    bin_width = Range(low=0.001,
                      high=1000.,
                      value=1.,
                      desc='bin width [ns]',
                      label='bin width [ns]',
                      mode='text',
                      auto_set=False,
                      enter_set=True)

    # The two channels that are correlated with each other.
    chan1 = Enum(0,
                 1,
                 2,
                 3,
                 4,
                 5,
                 6,
                 7,
                 desc="the trigger channel",
                 label="Channel 1")
    chan2 = Enum(0,
                 1,
                 2,
                 3,
                 4,
                 5,
                 6,
                 7,
                 desc="the signal channel",
                 label="Channel 2")

    # Histogram values and the matching bin positions.
    counts = Array()
    bins = Array()

    start_time = Float(value=0.0)
    run_time = Float(value=0.0, label='run time [s]')

    plot = Instance(Plot)
    plot_data = Instance(ArrayPlotData)

    # Pause between two polls of the correlators, in seconds.
    refresh_interval = Range(low=0.01,
                             high=1,
                             value=0.1,
                             desc='Refresh interval [s]',
                             label='Refresh interval [s]')

    def __init__(self, time_tagger, **kwargs):
        """Create the correlators, the plot and the polling thread.

        Parameters
        ----------
        time_tagger
            Handle that is passed on to `TimeTagger.Correlation`.
        **kwargs
            Initial trait values, forwarded to `HasTraits.__init__`.
        """
        self.time_tagger = time_tagger
        super(Autocorrelation, self).__init__(**kwargs)
        # Build the correlators only after the keyword arguments have been
        # applied, so that window_size, bin_width, chan1 and chan2 passed
        # to the constructor are actually used.
        self._create_pulsed()
        self._create_plot()
        # The thread only receives a weak proxy so that it does not keep
        # this object alive; `__del__` stops the thread.
        self._stoppable_thread = StoppableThread(target=Autocorrelation._run,
                                                 args=(weakref.proxy(self), ))
        self._stoppable_thread.start()

    # @on_trait_change('chan1,chan2,window_size,bin_width')
    def _reset(self):
        """Recreate the correlators and clear histogram and timers.

        The trait-change decorator above is disabled, so this only runs
        when it is called explicitly.
        """
        self._create_pulsed()
        self.bins = self._bins_default()
        self.counts = self._counts_default()
        self.start_time = 0.0
        self.run_time = 0.0

    def _create_pulsed(self):
        """Create the two correlation measurements `p1` and `p2`."""
        n_bins = int(np.round(self.window_size / self.bin_width))
        # bin_width is scaled by 1000 (presumably ns -> ps; confirm against
        # the TimeTagger API).
        bin_width = int(np.round(self.bin_width * 1000))
        self.p1 = TimeTagger.Correlation(self.time_tagger, n_bins, bin_width,
                                         1, self.chan1, self.chan2)
        self.p2 = TimeTagger.Correlation(self.time_tagger, n_bins, bin_width,
                                         1, self.chan2, self.chan1)

    def _refresh_data(self):
        """Fetch both correlations and merge them into `counts`."""
        data1 = self.p1.getData()
        data2 = self.p2.getData()
        # NOTE(review): start_time is only ever assigned 0.0 in this class,
        # so run_time ends up as the absolute time.time() value -- confirm
        # whether start_time should be initialised with time.time().
        self.run_time = time.time() - self.start_time
        # Layout: reversed first row of p1 (without its bin 0), then a
        # single centre bin holding the larger of the two bin-0 values,
        # then the first row of p2 (without its bin 0).
        self.counts = np.append(
            np.append(data1[0][-1:0:-1], max(data1[0][0], data2[0][0])),
            data2[0][1:])

    def _run(self):
        """Thread body: refresh the data until a stop is requested."""
        thread = threading.current_thread()
        try:
            while not thread.stop_request.isSet():
                self._refresh_data()
                thread.stop_request.wait(self.refresh_interval)
        except ReferenceError:
            # `self` is a weak proxy whose owner has been garbage
            # collected: nothing is left to update, so end the thread
            # quietly instead of dying with a traceback.
            pass

    def _counts_default(self):
        """Return an all-zero histogram with 2 * n - 1 bins."""
        n = int(np.round(self.window_size / self.bin_width))
        return np.zeros(n * 2 - 1)

    def _bins_default(self):
        """Return the 2 * n - 1 bin positions, centred on zero."""
        n = int(np.round(self.window_size / self.bin_width))
        return self.bin_width * np.arange(-n + 1, n)

    def _bins_changed(self):
        """Push new bin positions to the plot data, once it exists."""
        if self.plot_data is not None:
            self.plot_data.set_data('t', self.bins)

    def _counts_changed(self):
        """Push new histogram values to the plot data, once it exists."""
        if self.plot_data is not None:
            self.plot_data.set_data('y', self.counts)

    def _chan1_default(self):
        return 0

    def _chan2_default(self):
        return 7

    def _window_size_changed(self):
        """Resize `counts` and `bins` to the new window size."""
        self.counts = self._counts_default()
        self.bins = self._bins_default()

    def _create_plot(self):
        """Build the line plot of `counts` versus `bins`."""
        data = ArrayPlotData(t=self.bins, y=self.counts)
        plot = Plot(data,
                    width=500,
                    height=500,
                    resizable='hv',
                    padding_left=96,
                    padding_bottom=32)
        plot.plot(('t', 'y'), type='line', color='blue')
        plot.tools.append(SaveTool(plot))
        plot.index_axis.title = 'time [ns]'
        plot.value_axis.title = 'counts'
        self.plot_data = data
        self.plot = plot

    def __del__(self):
        """Stop the polling thread, if it was ever created."""
        # __init__ may have failed before the thread attribute was set.
        thread = getattr(self, '_stoppable_thread', None)
        if thread is not None:
            thread.stop()

    traits_view = View(Item('plot',
                            editor=ComponentEditor(size=(100, 100)),
                            show_label=False),
                       HGroup(
                           Item('window_size'),
                           Item('bin_width'),
                           Item('chan1'),
                           Item('chan2'),
                       ),
                       title='Autocorrelation',
                       width=720,
                       height=520,
                       buttons=[],
                       resizable=True)
コード例 #24
0
class Lorenz(HasTraits):
    """Interactive streamline view of a user-editable 3D velocity field.

    The field defaults to the Lorenz system.  Its parameters (`s`, `r`,
    `b`) and the expressions for the three velocity components (`u`, `v`,
    `w`) can be edited in the UI; the mayavi flow is updated accordingly.
    """

    # The parameters for the Lorenz system, defaults to the standard ones.
    s = Range(0.0, 20.0, 10.0, desc='the parameter s', enter_set=True,
              auto_set=False)
    r = Range(0.0, 50.0, 28.0, desc='the parameter r', enter_set=True,
              auto_set=False)
    b = Range(0.0, 10.0, 8./3., desc='the parameter b', enter_set=True,
              auto_set=False)

    # These expressions are evaluated to compute the right hand sides of
    # the ODE.  Defaults to the Lorenz system.
    u = Str('s*(y-x)', desc='the x component of the velocity',
            auto_set=False, enter_set=True)
    v = Str('r*x - y - x*z', desc='the y component of the velocity',
            auto_set=False, enter_set=True)
    w = Str('x*y - b*z', desc='the z component of the velocity',
            auto_set=False, enter_set=True)

    # Tuple of x, y, z arrays where the field is sampled.
    points = Tuple(Array, Array, Array)

    # The mayavi(mlab) scene.
    scene = Instance(MlabSceneModel, args=())

    # The "flow" which is a Mayavi streamline module.
    flow = Instance(HasTraits)

    ########################################
    # The UI view to show the user.
    view = View(HSplit(
                    Group(
                        Item('scene', editor=SceneEditor(scene_class=MayaviScene),
                             height=500, width=500, show_label=False)),
                    Group(
                        Item('s'),
                        Item('r'),
                        Item('b'),
                        Item('u'), Item('v'), Item('w')),
                    ),
                resizable=True
                )

    ######################################################################
    # Trait handlers.
    ######################################################################

    # Note that in the `on_trait_change` call below we listen for the
    # `scene.activated` trait.  This conveniently ensures that the flow
    # is generated as soon as the mlab `scene` is activated (which
    # happens when the configure/edit_traits method is called).  This
    # eliminates the need to manually call the `update_flow` method etc.
    @on_trait_change('s, r, b, scene.activated')
    def update_flow(self):
        """Recompute all three velocity components and update the flow."""
        u, v, w = self.get_uvw()
        self.flow.mlab_source.set(u=u, v=v, w=w)

    @on_trait_change('u')
    def update_u(self):
        """Re-evaluate the `u` expression and update the flow."""
        self.flow.mlab_source.set(u=self.get_vel('u'))

    @on_trait_change('v')
    def update_v(self):
        """Re-evaluate the `v` expression and update the flow."""
        self.flow.mlab_source.set(v=self.get_vel('v'))

    @on_trait_change('w')
    def update_w(self):
        """Re-evaluate the `w` expression and update the flow."""
        self.flow.mlab_source.set(w=self.get_vel('w'))

    def get_uvw(self):
        """Return the evaluated (u, v, w) velocity components."""
        return self.get_vel('u'), self.get_vel('v'), self.get_vel('w')

    def get_vel(self, comp):
        """Evaluate the user specified expression for one component.

        Parameters
        ----------
        comp : str
            Name of the trait holding the expression: 'u', 'v' or 'w'.

        Returns
        -------
        The value of the expression evaluated with `x`, `y`, `z`
        (from `points`) and `s`, `r`, `b` as local names and the scipy
        namespace as globals.  If the evaluation fails, the value
        currently stored in the flow's `mlab_source` is returned.
        """
        func_str = getattr(self, comp)
        try:
            x, y, z = self.points
            # NOTE(review): eval() runs arbitrary code typed into the UI.
            # That is the point of this demo, but the expression must
            # never come from an untrusted source.
            val = eval(func_str, scipy.__dict__,
                       {'x': x, 'y': y, 'z': z,
                        's': self.s, 'r': self.r, 'b': self.b})
        except Exception:
            # Mistake in the expression, so return the original value.
            # `Exception` (rather than a bare except) lets
            # KeyboardInterrupt and SystemExit propagate.
            val = getattr(self.flow.mlab_source, comp)
        return val

    ######################################################################
    # Private interface.
    ######################################################################
    def _points_default(self):
        """Return the default sampling grid as an (x, y, z) tuple."""
        x, y, z = np.mgrid[-50:50:100j, -50:50:100j, -10:60:70j]
        return x, y, z

    def _flow_default(self):
        """Create the streamline module, an outline and the initial view."""
        x, y, z = self.points
        u, v, w = self.get_uvw()
        f = self.scene.mlab.flow(x, y, z, u, v, w)
        f.stream_tracer.integration_direction = 'both'
        f.stream_tracer.maximum_propagation = 200
        # Called for their effect on the scene; the results are not needed.
        mlab.outline()
        mlab.view(120, 60, 150)
        return f
コード例 #25
0
class DynamicRangeEditor(HasPrivateTraits):
    """Demo of range editors whose bounds can be changed at run time.

    Every value trait is edited through a `RangeEditor` that reads its
    limits from sibling traits (`low`/`high` or `int_low`/`int_high`),
    which are editable in the same view.
    """

    # Floating point value that lives inside the dynamic range.
    value = Float

    # Lower limit used for `value`.
    low = Range(low=0.0, high=10.0, value=0.0)

    # Upper limit used for `value`.
    high = Range(low=20.0, high=100.0, value=20.0)

    # Integer value that lives inside the dynamic integer range.
    int_value = Int

    # Lower limit used for `int_value`.
    int_low = Range(low=0, high=10, value=0)

    # Upper limit used for `int_value`.
    int_high = Range(low=20, high=100, value=20)

    # The view: one group per editor mode.
    view = View(
        # Simple slider whose limits follow `low` and `high`.
        Group(
            Item('value',
                 editor=RangeEditor(mode='auto',
                                    format='%.1f',
                                    label_width=28,
                                    low_name='low',
                                    high_name='high')),
            '_',
            Item('low'),
            Item('high'),
            '_',
            Label('Move the Low and High sliders to change the range of Value.'),
            label='Simple Slider',
        ),
        # Large range slider whose limits follow `low` and `high`.
        Group(
            Item('value',
                 editor=RangeEditor(mode='xslider',
                                    format='%.1f',
                                    label_width=28,
                                    low_name='low',
                                    high_name='high')),
            '_',
            Item('low'),
            Item('high'),
            '_',
            Label('Move the Low and High sliders to change the range of Value.'),
            label='Large Range Slider',
        ),
        # Integer spinner whose limits follow `int_low` and `int_high`.
        Group(
            Item('int_value',
                 editor=RangeEditor(mode='spinner',
                                    format='%d',
                                    is_float=False,
                                    label_width=28,
                                    low=0,
                                    high=20,
                                    low_name='int_low',
                                    high_name='int_high')),
            '_',
            Item('int_low'),
            Item('int_high'),
            '_',
            Label('Move the Low and High sliders to change the range of Value.'),
            label='Spinner',
        ),
        title='Dynamic Range Editor Demonstration',
        buttons=['OK'],
        resizable=True,
    )
コード例 #26
0
ファイル: test_ui2.py プロジェクト: sjl421/code-2
class Person(HasTraits):
    """Test model that exercises many trait types and their editors.

    `view` shows every trait in several editor styles; `wizard` presents
    a subset of the traits on three pages.
    """

    # -- Trait definitions ---------------------------------------------------

    name = Str('David Morrill')
    age = Int(39)
    sex = Trait('Male', 'Female')
    coolness = Range(0.0, 10.0, 10.0)
    number = Trait(1, Range(1, 6),
                   'one', 'two', 'three', 'four', 'five', 'six')
    human = Bool(True)
    employer = Trait(Employer(company='Enthought, Inc.', boss='eric'))
    eye_color = RGBAColor
    # NOTE: `set` and `zip` shadow builtins inside this class body; the
    # names are part of the model's interface and therefore kept.
    set = List(editor=CheckListEditor(values=['one', 'two', 'three', 'four'],
                                      cols=4))
    font = KivaFont
    street = Str
    city = Str
    state = Str
    zip = Int(78663)
    password = Str
    books = List(Str,
                 ['East of Eden', 'The Grapes of Wrath', 'Of Mice and Men'])
    call = Event(0, editor=ButtonEditor(label='Click to call'))
    info = Str(editor=FileEditor())
    location = Str(editor=DirectoryEditor())
    origin = Trait(*origin_values,
                   editor=ImageEnumEditor(values=origin_values,
                                          suffix='_origin',
                                          cols=4,
                                          klass=Employer))

    # -- View items shared between groups ------------------------------------

    # The name is only editable for adults.
    nm = Item('name', enabled_when='object.age >= 21')
    # The password item is only defined for one specific zip code.
    pw = Item('password', defined_when='object.zip == 78664')

    # -- Views ---------------------------------------------------------------

    view = View(
        (
            (
                nm, 'age', 'coolness',
                '_', 'eye_color', 'eye_color@', 'eye_color*', 'eye_color~',
                '_', 'font', 'font@', 'font*', 'font~',
                '_', 'set', 'set@', 'set*', 'set~',
                '_', 'sex', 'sex@', 'sex*', 'sex~',
                '_', 'human', 'human@', 'human*', 'human~',
                '_', 'number', 'number@', 'number*', 'number~',
                '_', 'books', '_', 'books@', '_', 'books*', '_', 'books~',
                '_', 'info', 'location', 'origin', 'origin@', 'call',
                'employer', 'employer[]@', 'employer*', 'employer~',
                pw,
                '|<[Person:]',
            ),
            (' ', 'street', 'city', 'state', 'zip', '|<[Address:]'),
            (
                nm, nm, nm, nm, nm, nm, nm, nm, nm, nm, nm, nm, nm, nm,
                '|<[Names:]',
            ),
            '|',
        ),
        title='Traits 2 User Interface Test',
        handler=PersonHandler(),
        buttons=['Apply', 'Revert', 'Undo', 'OK'],
        height=0.5,
    )

    wizard = View(
        ('|p1:', 'name', 'age', 'sex'),
        ('|p2:', 'street', 'city', 'state', 'zip'),
        ('|p3:', 'eye_color', 'origin', 'human'),
        handler=WizardHandler(),
    )
コード例 #27
0
ファイル: singleshot.py プロジェクト: Faridelnik/Pi3Diamond
class SSTCounterTrace(Pulsed):
    """Single-shot (SST) counter trace measurement.

    `prepare_awg` uploads the read-out waveforms and sequences to the AWG
    and `generate_sequence` builds the list of (channels, duration) steps
    that accompanies them.
    """

    tau_begin = Range(low=0.,
                      high=1e5,
                      value=1.,
                      desc='tau begin [ns]',
                      label='repetition',
                      mode='text',
                      auto_set=False,
                      enter_set=True)
    tau_end = Range(low=1.,
                    high=1e5,
                    value=1000.,
                    desc='tau end [ns]',
                    label='N repetition',
                    mode='text',
                    auto_set=False,
                    enter_set=True)
    tau_delta = Range(low=1.,
                      high=1e5,
                      value=1,
                      desc='delta tau [ns]',
                      label='delta',
                      mode='text',
                      auto_set=False,
                      enter_set=True)

    sweeps = Range(low=1.,
                   high=1e4,
                   value=1,
                   desc='number of sweeps',
                   label='sweeps',
                   mode='text',
                   auto_set=False,
                   enter_set=True)

    def prepare_awg(self):
        """Configure the AWG and, if `self.reload` is set, re-upload data.

        On reload the AWG is stopped and cleared, two read-out waveforms
        (`read_x`, `read_y`: a pulse with phase 0 or pi/2 followed by a
        marked idle and a plain idle) are built, and one sub-sequence per
        entry of `self.tau` is uploaded and chained into the main
        sequence 'SST.SEQ'.  Output settings are applied in every case.
        """
        sampling = 1.2e9
        N_shot = int(self.N_shot)

        # Durations are scaled by sampling / 1e9 (presumably ns -> number
        # of samples; confirm against the AWG driver).
        pi = int(self.pi * sampling / 1.0e9)
        laser_SST = int(self.laser_SST * sampling / 1.0e9)
        wait_SST = int(self.wait_SST * sampling / 1.0e9)

        if self.reload:
            AWG.stop()
            AWG.set_output(0b0000)
            AWG.delete_all()

            # Frequency offset from the centre, relative to the sampling
            # rate; shared by both pulses.
            rel_freq = (self.freq - self.freq_center) / sampling
            pi_x = Sin(pi, rel_freq, 0, self.amp)
            pi_y = Sin(pi, rel_freq, np.pi / 2, self.amp)

            read_x = Waveform(
                'read_x',
                [pi_x, Idle(laser_SST, marker1=1), Idle(wait_SST)])
            read_y = Waveform(
                'read_y',
                [pi_y, Idle(laser_SST, marker1=1), Idle(wait_SST)])
            self.waves = [read_x, read_y]

            self.main_seq = Sequence('SST.SEQ')
            # Only the number of tau entries matters here, not their value.
            for i, _ in enumerate(self.tau):
                sub_seq = Sequence('DQH_12_%04i.SEQ' % i)
                sub_seq.append(read_x, read_y, repeat=N_shot)
                AWG.upload(sub_seq)

                self.main_seq.append(sub_seq, wait=True)
            for w in self.waves:
                w.join()
            AWG.upload(self.waves)
            AWG.upload(self.main_seq)
            AWG.tell('*WAI')
            AWG.load('SST.SEQ')
        AWG.set_vpp(self.vpp)
        AWG.set_sample(sampling / 1.0e9)
        AWG.set_mode('S')
        AWG.set_output(0b0011)

    def generate_sequence(self):
        """Return the list of (channels, duration) steps.

        For each of the `sequence_points` points four steps are emitted:
        'laser' for `self.laser`, an idle step for `self.wait`,
        'awgTrigger' for 100 and 'sst' for `self.record_length * 1e6`.
        """
        points = int(self.sequence_points)
        laser = self.laser
        wait = self.wait
        record_length = self.record_length * 1e+6

        sequence = []
        for _ in range(points):
            sequence.append((['laser'], laser))
            sequence.append(([], wait))
            sequence.append((['awgTrigger'], 100))
            sequence.append((['sst'], record_length))

        return sequence

    get_set_items = Pulsed.get_set_items

    traits_view = View(
        VGroup(
            HGroup(
                Item('load_button', show_label=False),
                Item('submit_button', show_label=False),
                Item('remove_button', show_label=False),
                Item('resubmit_button', show_label=False),
                Item('priority'),
            ),
            HGroup(
                Item('freq', width=-70),
                Item('freq_center', width=-70),
                Item('amp', width=-30),
                Item('vpp', width=-30),
                Item('power', width=-40),
                Item('pi', width=-70),
            ),
            HGroup(
                Item('laser', width=-60),
                Item('wait', width=-60),
                Item('laser_SST', width=-50),
                Item('wait_SST', width=-50),
            ),
            HGroup(
                Item('samples_per_read', width=-50),
                Item('N_shot', width=-50),
                Item('record_length', style='readonly'),
            ),
            HGroup(
                Item('tau_begin', width=30),
                Item('tau_end', width=30),
                Item('tau_delta', width=30),
            ),
            HGroup(
                Item('state', style='readonly'),
                Item('run_time', style='readonly', format_str='%.f',
                     width=-50),
                Item('sweeps',
                     editor=TextEditor(auto_set=False,
                                       enter_set=True,
                                       evaluate=float,
                                       format_func=lambda x: '%.1e' % x),
                     width=-50),
                Item('expected_duration',
                     style='readonly',
                     editor=TextEditor(evaluate=float,
                                       format_func=lambda x: '%.2f' % x),
                     width=30),
                Item('progress', style='readonly'),
                Item('elapsed_time',
                     style='readonly',
                     editor=TextEditor(auto_set=False,
                                       enter_set=True,
                                       evaluate=float,
                                       format_func=lambda x: ' %.f' % x),
                     width=-50),
            ),
        ),
        title='SST Trace Measurement',
    )
コード例 #28
0
class Worker(HasTraits):
    """This class basically allows you to create a data set, view it
    and modify the dataset.  This is a rather crude example but
    demonstrates how things can be done.
    """

    # Set by envisage when this is contributed as a ServiceOffer.
    window = Instance('enthought.pyface.workbench.api.WorkbenchWindow')

    create_data = Button('Create data')
    reset_data = Button('Reset data')
    view_data = Button('View data')
    scale = Range(0.0, 1.0)
    # The source created by the 'Create data' button (None until then).
    source = Instance('enthought.mayavi.core.source.Source')

    # Our UI view.
    view = View(Item('create_data', show_label=False),
                Item('view_data', show_label=False),
                Item('reset_data', show_label=False),
                Item('scale'),
                resizable=True)

    def get_mayavi(self):
        """Return the mayavi `Script` service of the workbench window."""
        from enthought.mayavi.plugins.script import Script
        return self.window.get_service(Script)

    def _make_data(self):
        """Return sin(x*y*z)/(x*y*z) sampled on a 64x64x64 grid.

        The grid spans [-5, 5] along every axis; the result is a
        contiguous array.
        """
        dims = [64, 64, 64]
        # numpy is used directly instead of the deprecated scipy aliases
        # of the same functions.
        x, y, z = numpy.ogrid[-5:5:dims[0] * 1j, -5:5:dims[1] * 1j,
                              -5:5:dims[2] * 1j]
        x = x.astype('f')
        y = y.astype('f')
        z = z.astype('f')
        xyz = x * y * z
        s = numpy.sin(xyz) / xyz
        s = s.transpose().copy()  # This makes the data contiguous.
        return s

    def _create_data_fired(self):
        """Create a new array source and add it to mayavi."""
        mayavi = self.get_mayavi()
        from enthought.mayavi.sources.array_source import ArraySource
        s = self._make_data()
        src = ArraySource(transpose_input_array=False, scalar_data=s)
        self.source = src
        mayavi.add_source(src)

    def _reset_data_fired(self):
        """Restore the source's scalar data to a freshly generated set."""
        if self.source is None:
            # No data has been created yet, so there is nothing to reset.
            return
        self.source.scalar_data = self._make_data()

    def _view_data_fired(self):
        """Add an outline and two image plane widgets to the scene."""
        mayavi = self.get_mayavi()
        from enthought.mayavi.modules.outline import Outline
        from enthought.mayavi.modules.image_plane_widget import ImagePlaneWidget
        # Visualize the data.
        o = Outline()
        mayavi.add_module(o)
        ipw = ImagePlaneWidget()
        mayavi.add_module(ipw)
        ipw.module_manager.scalar_lut_manager.show_scalar_bar = True

        ipw_y = ImagePlaneWidget()
        mayavi.add_module(ipw_y)
        ipw_y.ipw.plane_orientation = 'y_axes'

    def _scale_changed(self, value):
        """Shift the scalar data by `value * 0.01`, wrapped into [0, 1)."""
        src = self.source
        if src is None:
            # The slider may be moved before any data has been created.
            return
        data = src.scalar_data
        # Both operations modify the source's array in place.
        data += value * 0.01
        numpy.mod(data, 1.0, data)
        src.update()
コード例 #29
0
class BoundaryMarkerEditor(Filter):
    """
    Edit the boundary marker of a Triangle surface mesh. To use: select the label to
    assign, hover your cursor over the cell you wish to edit, and press 'p'.

    Optionally the label is also applied to connected cells that are
    (within `epsilon`) coplanar with the picked cell, and cells carrying
    selected labels can be masked out of the displayed grid.
    """

    # The version of this class.  Used for persistence.
    __version__ = 0

    # Grid that is currently output (the input copy or its masked subset).
    _current_grid = Instance(tvtk.UnstructuredGrid, allow_none=False)
    # Private deep copy of the input grid; labels are edited on this copy.
    _input_grid = Instance(tvtk.UnstructuredGrid, args=(), allow_none=False)
    _extract_cells_filter = Instance(tvtk.ExtractCells,
                                     args=(),
                                     allow_none=False)
    _dataset_manager = Instance(DatasetManager, allow_none=False)
    # Indexed by cell id of the input grid; each value is the id of that
    # cell in the masked grid, or None when the cell is masked.
    _cell_mappings = List

    label_to_apply = Range(0, 255)
    select_coplanar_cells = Bool
    epsilon = Range(0.0, 1.0, 0.0001)
    mask_labels = Bool
    labels_to_mask = List(label_to_apply)

    # Saving file
    output_file = File
    save = Button

    ######################################################################
    # The view.
    ######################################################################
    traits_view = \
        View(
            Group(
                Item(name='label_to_apply'),
                Item(name='select_coplanar_cells'),
                Item(name='epsilon', enabled_when='select_coplanar_cells', label='Tolerance'),
                Item(name='mask_labels'),
                Group(
                    Item(name='labels_to_mask', style='custom', editor=ListEditor(rows=3)),
                    show_labels=False,
                    show_border=True,
                    label='Labels to mask',
                    enabled_when='mask_labels==True'
                ),
                Group(
                    Item(name='output_file'),
                    Item(name='save', label='Save'),
                    show_labels=False,
                    show_border=True,
                    label='Save changes to file (give only a basename, without the file extension)'
                )
            ),
            height=500,
            width=600
        )

    ######################################################################
    # `Filter` interface.
    ######################################################################
    def update_pipeline(self):
        """Copy the input grid, set up cell picking and publish the output."""
        if len(self.inputs) == 0 or len(self.inputs[0].outputs) == 0:
            return

        # Call cell_picked() when a cell is clicked.
        # NOTE(review): an observer is added on every call; confirm that
        # update_pipeline cannot run repeatedly for the same scene.
        self.scene.picker.cellpicker.add_observer("EndPickEvent",
                                                  self.cell_picked)
        self.scene.picker.pick_type = 'cell_picker'
        self.scene.picker.tolerance = 0.0
        self.scene.picker.show_gui = False

        self._input_grid.deep_copy(self.inputs[0].outputs[0])

        self._current_grid = self._input_grid
        self._dataset_manager = DatasetManager(dataset=self._input_grid)
        self._set_outputs([self._current_grid])

        # Filter for masking.
        self._extract_cells_filter.set_input(self._input_grid)

    def update_data(self):
        """Propagate a data change of the input downstream."""
        self.data_changed = True

    ######################################################################
    # Non-public interface.
    ######################################################################

    def cell_picked(self, object, event):
        """Observer callback: apply `label_to_apply` to the picked cell.

        `object` and `event` are supplied by the observer mechanism and
        are not used.
        """
        cell_id = self.scene.picker.cellpicker.cell_id
        if cell_id < 0:
            # The cell picker reports a negative id when the pick hit no
            # cell; there is nothing to relabel in that case.
            return

        self.modify_cell(cell_id, self.label_to_apply)

        if self.select_coplanar_cells:
            self.modify_neighbouring_cells(cell_id)

        if self.mask_labels:
            self.perform_mask()

        self._dataset_manager.activate(self._input_grid.cell_data.scalars.name,
                                       'cell')
        self._dataset_manager.update()
        self.pipeline_changed = True

    def get_all_cell_neigbours(self, cell_id, cell):
        """Return the ids of all cells sharing an edge with `cell`.

        Parameters
        ----------
        cell_id : int
            Id of `cell` in the current grid.
        cell
            The cell object belonging to `cell_id`.

        Returns
        -------
        list of int
            Neighbour ids collected edge by edge (may contain repeats).
        """
        neighbour_cell_ids = array([], dtype=int)

        for i in range(cell.number_of_edges):
            # Get points belonging to ith edge
            edge_point_ids = cell.get_edge(i).point_ids

            # Find neigbours which share the edge
            current_neighbour_cell_ids = tvtk.IdList()
            self._current_grid.get_cell_neighbors(cell_id, edge_point_ids,
                                                  current_neighbour_cell_ids)
            neighbour_cell_ids = append(neighbour_cell_ids,
                                        array(current_neighbour_cell_ids))

        return neighbour_cell_ids.tolist()

    def modify_neighbouring_cells(self, cell_id):
        """Flood-fill the label over cells coplanar with `cell_id`.

        Starting from the picked cell, edge neighbours are visited; a
        neighbour is relabelled (and its own neighbours queued) when the
        dot product of its normal with the picked cell's normal exceeds
        `1 - epsilon`.
        """
        cell = self._current_grid.get_cell(cell_id)

        cell_normal = [0, 0, 0]
        cell.compute_normal(cell.points[0], cell.points[1], cell.points[2],
                            cell_normal)

        cells_pending = self.get_all_cell_neigbours(cell_id, cell)
        # A set gives constant-time membership tests during the fill.
        cells_visited = set([cell_id])

        while cells_pending:
            current_cell_id = cells_pending.pop()
            if current_cell_id in cells_visited:
                continue

            cells_visited.add(current_cell_id)
            current_cell = self._current_grid.get_cell(current_cell_id)

            current_cell_normal = [0, 0, 0]
            current_cell.compute_normal(current_cell.points[0],
                                        current_cell.points[1],
                                        current_cell.points[2],
                                        current_cell_normal)

            if dot(cell_normal, current_cell_normal) > (1 - self.epsilon):
                self.modify_cell(current_cell_id, self.label_to_apply)
                cells_pending.extend(
                    self.get_all_cell_neigbours(current_cell_id,
                                                current_cell))

    def _mask_labels_changed(self):
        """Switch the output between the masked and the full grid."""
        if self.mask_labels:
            self.perform_mask()
            self._current_grid = self._extract_cells_filter.output
        else:
            self._current_grid = self._input_grid

        self._set_outputs([self._current_grid])
        self.pipeline_changed = True

    def _labels_to_mask_changed(self):
        self.perform_mask()

    def _labels_to_mask_items_changed(self):
        self.perform_mask()

    def perform_mask(self):
        """Rebuild the list of unmasked cells and the id mapping.

        Every cell whose label is in `labels_to_mask` is left out of the
        extract filter's cell list.  `_cell_mappings` is rebuilt at the
        same time (see the comment on that trait).
        """
        labels_array = self._input_grid.cell_data.get_array(
            self._input_grid.cell_data.scalars.name)
        masked_labels = self.labels_to_mask

        unmasked_cells_list = tvtk.IdList()
        # Built with an explicit loop: the insertions are side effects and
        # the result must be a real list (it is searched with .index()),
        # neither of which a lazy Python 3 map() would provide.
        cell_mappings = []
        for cell_id, label in enumerate(labels_array):
            if label in masked_labels:
                cell_mappings.append(None)
            else:
                cell_mappings.append(
                    unmasked_cells_list.insert_next_id(cell_id))
        self._cell_mappings = cell_mappings

        self._extract_cells_filter.set_cell_list(unmasked_cells_list)
        self._extract_cells_filter.update()
        self.pipeline_changed = True

    def modify_cell(self, cell_id, value):
        """Set the label of one cell.

        Parameters
        ----------
        cell_id : int
            Cell id in the currently displayed grid.  While masking is
            active it is translated back to the id in the input grid.
        value : int
            The label to store.
        """
        if self.mask_labels:
            cell_id = self._cell_mappings.index(
                cell_id)  # Adjust cell_id if masked
        self._input_grid.cell_data.get_array(
            self._input_grid.cell_data.scalars.name)[cell_id] = value

    def _save_fired(self):
        """Write the edited grid with `TriangleWriter` if a name is set."""
        from mayavi_amcg.triangle_writer import TriangleWriter
        if self.output_file:
            writer = TriangleWriter(self._input_grid, self.output_file)
            writer.write()
            # Parenthesised form prints the same text on Python 2 and 3.
            print("#### Saved ####")
コード例 #30
0
    def __init__(self, img_path):
        """Load a lenslet image, build the CUDA view kernel and the plot.

        Parameters
        ----------
        img_path
            Path of the image file.  Two companion files are read as
            well: ``<base>-lenslet.txt`` and ``<base>-optics.txt``, where
            ``<base>`` is `img_path` without its extension.
        """
        super().__init__()

        #
        # Load image data
        #
        base_path = os.path.splitext(img_path)[0]
        lenslet_path = base_path + '-lenslet.txt'
        optics_path = base_path + '-optics.txt'

        with open(lenslet_path) as f:
            # NOTE(review): eval() executes whatever the first line of the
            # lenslet file contains -- only use with trusted files
            # (ast.literal_eval may be sufficient; confirm the file format).
            tmp = eval(f.readline())
            # The first line must yield exactly six numbers.
            x_offset, y_offset, right_dx, right_dy, down_dx, down_dy = \
              np.array(tmp, dtype=np.float32)

        with open(optics_path) as f:
            for line in f:
                # Every line must split into exactly two whitespace
                # separated fields, otherwise the unpacking raises.
                name, val = line.strip().split()
                # Values that convert to float32 become attributes of self.
                # NOTE(review): the bare except silently skips everything
                # that fails here, not only non-numeric values.
                try:
                    setattr(self, name, np.float32(val))
                except:
                    pass

        # `pitch` and `flen` are presumably set from the optics file
        # above -- TODO confirm.
        max_angle = math.atan(self.pitch/2/self.flen)

        #
        # Prepare image
        #
        im_pil = Image.open(img_path)
        if im_pil.mode == 'RGB':
            self.NCHANNELS = 3
            w, h = im_pil.size
            # A four channel float32 buffer is allocated; only the first
            # three channels are filled from the image.
            im = np.zeros((h, w, 4), dtype=np.float32)
            im[:, :, :3] = np.array(im_pil).astype(np.float32)
            self.LF_dim = (ceil(h/down_dy), ceil(w/right_dx), 3)
        else:
            # Every other mode is handled as a single channel image.
            self.NCHANNELS = 1
            im = np.array(im_pil.getdata()).reshape(im_pil.size[::-1]).astype(np.float32)
            h, w = im.shape
            self.LF_dim = (ceil(h/down_dy), ceil(w/right_dx))

        # Offsets reduced by a whole number of steps, and the step sizes
        # scaled by flen / pitch.
        x_start = x_offset - int(x_offset / right_dx) * right_dx
        y_start = y_offset - int(y_offset / down_dy) * down_dy
        x_ratio = self.flen * right_dx / self.pitch
        y_ratio = self.flen * down_dy / self.pitch

        #
        # Generate the cuda kernel
        #
        # The kernel source is rendered from the `_kernel_tpl` template
        # with the geometry computed above and then compiled.
        mod_LFview = pycuda.compiler.SourceModule(
            _kernel_tpl.render(
                newiw=self.LF_dim[1],
                newih=self.LF_dim[0],
                oldiw=w,
                oldih=h,
                x_start=x_start,
                y_start=y_start,
                x_ratio=x_ratio,
                y_ratio=y_ratio,
                x_step=right_dx,
                y_step=down_dy,
                NCHANNELS=self.NCHANNELS
                )
            )

        self.LFview_func = mod_LFview.get_function("LFview_kernel")
        self.texref = mod_LFview.get_texref("tex")

        #
        # Now generate the cuda texture
        #
        if self.NCHANNELS == 3:
            cuda.bind_array_to_texref(
                cuda.make_multichannel_2d_array(im, order="C"),
                self.texref
                )
        else:
            cuda.matrix_to_texref(im, self.texref, order="C")

        #
        # We could set the next if we wanted to address the image
        # in normalized coordinates ( 0 <= coordinate < 1.)
        # texref.set_flags(cuda.TRSF_NORMALIZED_COORDINATES)
        #
        self.texref.set_filter_mode(cuda.filter_mode.LINEAR)

        #
        # Prepare the traits
        #
        # The viewing angles are added as traits at run time because
        # their limits depend on max_angle computed above.
        self.add_trait('X_angle', Range(-max_angle, max_angle, 0.0))
        self.add_trait('Y_angle', Range(-max_angle, max_angle, 0.0))

        self.plotdata = ArrayPlotData(LF_img=self.sampleLF())
        self.LF_img = Plot(self.plotdata)
        # Single channel images are drawn with the gray colormap.
        if self.NCHANNELS == 3:
            self.LF_img.img_plot("LF_img")
        else:
            self.LF_img.img_plot("LF_img", colormap=gray)