Example #1
0
class TableConfigurer(HasTraits):
    """Editable table configuration (visible columns, font size).

    The configuration is pickled to ``os.path.join(paths.hidden_dir, self.id)``.
    It is restored on construction and again whenever an adapter is attached
    through :meth:`set_adapter`. Changes are pushed to ``adapter`` (and every
    adapter in ``children``) by :meth:`set_columns` and :meth:`set_font`.
    """
    # Column names currently selected for display.
    columns = List
    # Extra adapters that mirror the columns/font of ``adapter``.
    children = List(TabularAdapter)
    # All column names offered by the adapter (choices of the check list).
    available_columns = List
    # Reduced column set shown by the "Toggle Sparse" button.
    sparse_columns = List
    adapter = Instance(TabularAdapter)
    # File name (inside ``paths.hidden_dir``) of the persisted state.
    id = 'table'
    font = Enum(*SIZES)
    # When True, font/column edits are applied to the adapter immediately.
    auto_set = Bool(False)
    fontsize_enabled = Bool(True)
    title = Str('Configure Table')
    # Optional callback invoked after a font change has been applied.
    refresh_func = Callable
    show_all = Button('Show All')

    set_sparse = Button('Define Sparse')
    toggle_sparse = Button('Toggle Sparse')

    sparse_enabled = Property(depends_on='columns[]')

    # True while the sparse column set is being displayed.
    _toggle_sparse_enabled = Bool(False)

    default_button = Button('Default')
    # Path to a YAML file whose top-level ``columns`` entry lists defaults.
    defaults_path = Str

    def _get_sparse_enabled(self):
        return self.sparse_columns != self.columns and len(self.columns) < 5

    def __init__(self, *args, **kw):
        super(TableConfigurer, self).__init__(*args, **kw)
        if self.auto_set:
            self.on_trait_change(self.update, 'font, columns[]')
        self._load_state()

    def closed(self, is_ok):
        """Persist and apply the configuration when the dialog is accepted."""
        if is_ok:
            self.dump()
            self.set_columns()
            self.set_font()

    def load(self):
        self._load_state()

    def dump(self):
        self._dump_state()

    def update(self):
        """Apply both the font and the column selection to the adapter(s)."""
        self.set_font()
        self.set_columns()

    def set_font(self):
        """Push ``font`` to the adapter and its children, then refresh."""
        if self.adapter:
            font = 'arial {}'.format(self.font)
            self.adapter.font = font
            for ci in self.children:
                ci.font = font

            if self.refresh_func:
                self.refresh_func()

    def set_columns(self):
        """Push the selected columns to the adapter and its children."""
        if self.adapter:
            cols = self._assemble_columns()
            for ci in self.children:
                ci.columns = cols

            # The main adapter only accepts columns it actually knows about.
            cols = [ci for ci in cols if ci in self.adapter.all_columns]
            self.adapter.columns = cols

    def _set_font(self, f):
        s = f.pointSize()
        self.font = s

    def _load_state(self):
        """Restore columns, font and sparse columns from the state file.

        A missing, unreadable or malformed state file is ignored, and the
        current settings are kept. After a successful load, ``_load_hook`` is
        called and the configuration is applied via :meth:`update`.
        """
        p = os.path.join(paths.hidden_dir, self.id)
        if not os.path.isfile(p):
            return

        try:
            with open(p, 'rb') as rfile:
                state = pickle.load(rfile)
        except (pickle.PickleError, OSError, EOFError, TraitError):
            return

        # _get_dump always writes a dict. Anything else cannot be interpreted
        # and would otherwise crash on ``state.get`` below.
        if not isinstance(state, dict):
            return

        sparse = state.get('sparse_columns')
        if sparse is not None:
            try:
                self.sparse_columns = sparse
            except TraitError:
                # Stored value is not a valid list; keep the current one.
                pass

        cols = state.get('columns')
        if cols:
            # Keep the adapter's column order and drop stale column names.
            self.columns = [ai for ai in self.available_columns if ai in cols]

        font = state.get('font')
        if font:
            self.font = font

        self._load_hook(state)
        self.update()

    def _dump_state(self):
        """Pickle the configuration returned by ``_get_dump`` to disk.

        The object is serialized before the file is opened, so a pickling
        failure leaves any previously saved state intact instead of
        truncating it.
        """
        p = os.path.join(paths.hidden_dir, self.id)
        obj = self._get_dump()

        try:
            data = pickle.dumps(obj)
        except pickle.PickleError:
            return

        with open(p, 'wb') as wfile:
            wfile.write(data)

    def _get_dump(self):
        obj = dict(columns=self.columns,
                   font=self.font,
                   sparse_columns=self.sparse_columns)
        return obj

    def _load_hook(self, state):
        """Subclass hook called with the loaded state dict."""
        pass

    def _assemble_columns(self):
        """Return ``(name, value)`` pairs for the selected columns.

        The pairs are listed in the adapter's ``all_columns`` order.
        """
        d = self.adapter.all_columns_dict
        return [(k, d[k]) for k, _ in self.adapter.all_columns
                if k in self.columns]

    def _get_columns_grp(self):
        return

    def _set_defaults(self):
        """Load the laboratory default columns from ``defaults_path``."""
        p = self.defaults_path
        if os.path.isfile(p):
            import yaml

            with open(p, 'r') as rfile:
                # safe_load: plain YAML data only. Bare ``yaml.load`` without
                # a Loader is unsafe and raises TypeError on PyYAML >= 6.
                yd = yaml.safe_load(rfile)

            # An empty file yields None; only use a mapping with 'columns'.
            if isinstance(yd, dict) and 'columns' in yd:
                self.columns = yd['columns']
            self.set_columns()

    def _default_button_fired(self):
        self._set_defaults()

    def _set_sparse_fired(self):
        self.sparse_columns = self.columns

    def _toggle_sparse_fired(self):
        """Switch between the sparse column set and the previous selection."""
        if self._toggle_sparse_enabled:
            columns = self._prev_columns
        else:
            self._prev_columns = self.columns
            columns = self.sparse_columns

        self.columns = columns
        self.set_columns()

        self._toggle_sparse_enabled = not self._toggle_sparse_enabled

    def _show_all_fired(self):
        self.columns = self.available_columns
        self.set_columns()

    def set_adapter(self, adp):
        """Attach ``adp`` and initialise the column/font state from it.

        The persisted state is reloaded afterwards.
        """
        self.adapter = adp
        acols = [c for c, _ in adp.all_columns]

        # set currently visible columns, in the adapter's canonical order
        t = [c for c, _ in adp.columns]
        cols = [c for c in acols if c in t]
        self.trait_set(columns=cols)

        # set all available columns
        self.available_columns = acols
        if adp.font:
            self._set_font(adp.font)

        self._load_state()

    def traits_view(self):
        v = okcancel_view(VGroup(
            HGroup(
                UItem('show_all', tooltip='Show all columns'),
                UItem(
                    'set_sparse',
                    tooltip=
                    'Set the current set of columns to the Sparse Column Set',
                    enabled_when='sparse_enabled'),
                UItem('toggle_sparse',
                      tooltip='Display only Sparse Column Set'),
                UItem('default_button',
                      tooltip='Set to Laboratory defaults. File located at '
                      '[root]/experiments/experiment_defaults.yaml')),
            VGroup(UItem('columns',
                         style='custom',
                         editor=CheckListEditor(name='available_columns',
                                                cols=3)),
                   Item('font', enabled_when='fontsize_enabled'),
                   show_border=True)),
                          handler=TableConfigurerHandler(),
                          title=self.title)
        return v
Example #2
0
class ArraySource(Source):
    """A simple source that allows one to view a suitably shaped numpy
    array as ImageData.  This supports both scalar and vector data.

    The arrays are stored as point data of ``image_data`` (a
    ``tvtk.ImageData``). ``image_data`` is the input of
    ``change_information_filter`` (a ``tvtk.ImageChangeInformation``), whose
    ``output_spacing``/``output_origin`` back the ``spacing``/``origin``
    traits. The filter's output is the single output of this source.
    """

    # The scalar array data we manage.
    scalar_data = Trait(None, _check_scalar_array, rich_compare=False)

    # The name of our scalar array.
    scalar_name = Str('scalar')

    # The vector array data we manage.
    vector_data = Trait(None, _check_vector_array, rich_compare=False)

    # The name of our vector array.
    vector_name = Str('vector')

    # The spacing of the points in the array.
    spacing = DelegatesTo('change_information_filter',
                          'output_spacing',
                          desc='the spacing between points in array')

    # The origin of the points in the array.
    origin = DelegatesTo('change_information_filter',
                         'output_origin',
                         desc='the origin of the points in array')

    # Fire an event to update the spacing and origin. This
    # is here for backwards compatibility. Firing this is no
    # longer needed.
    update_image_data = Button('Update spacing and origin')

    # The image data stored by this instance.
    image_data = Instance(tvtk.ImageData, (), allow_none=False)

    # Use an ImageChangeInformation filter to reliably set the
    # spacing and origin on the output
    change_information_filter = Instance(tvtk.ImageChangeInformation,
                                         args=(),
                                         kw={
                                             'output_spacing': (1.0, 1.0, 1.0),
                                             'output_origin': (0.0, 0.0, 0.0)
                                         })

    # Should we transpose the input data or not.  Transposing is
    # necessary to make the numpy array compatible with the way VTK
    # needs it.  However, transposing numpy arrays makes them
    # non-contiguous where the data is copied by VTK.  Thus, when the
    # user explicitly requests that transpose_input_array is false
    # then we assume that the array has already been suitably
    # formatted by the user.
    transpose_input_array = Bool(
        True,
        desc=
        'if input array should be transposed (if on VTK will copy the input data)'
    )

    # Information about what this object can produce.
    output_info = PipelineInfo(datasets=['image_data'])

    # Specify the order of dimensions. The default is: [0, 1, 2]
    # Only used for scalar data: it selects which entries of the array
    # shape give the x/y/z extents (vector data ignores it).
    dimensions_order = List(Int, [0, 1, 2])

    # Our view.
    view = View(
        Group(Item(name='transpose_input_array'),
              Item(name='scalar_name'),
              Item(name='vector_name'),
              Item(name='spacing'),
              Item(name='origin'),
              show_labels=True))

    ######################################################################
    # `object` interface.
    ######################################################################
    def __init__(self, **traits):
        """Create the source.

        ``scalar_data``/``vector_data`` keyword arguments are applied only
        after all other traits (e.g. ``transpose_input_array``,
        ``dimensions_order``, the array names) have been set. This ensures
        the arrays are processed with the final settings.
        """
        # Set the scalar and vector data at the end so we pop it here.
        sd = traits.pop('scalar_data', None)
        vd = traits.pop('vector_data', None)
        # Now set the other traits.
        super(ArraySource, self).__init__(**traits)
        # Feed image_data into the spacing/origin filter.
        self.configure_input_data(self.change_information_filter,
                                  self.image_data)

        # And finally set the scalar and vector data.
        if sd is not None:
            self.scalar_data = sd
        if vd is not None:
            self.vector_data = vd

        self.outputs = [self.change_information_filter.output]
        # Registered last, so spacing/origin values passed to the
        # constructor do not trigger _information_changed.
        self.on_trait_change(self._information_changed, 'spacing,origin')

    def __get_pure_state__(self):
        """Return the base-class state with ``image_data`` removed.

        NOTE(review): presumably dropped because the VTK object is rebuilt
        from scalar_data/vector_data; confirm.
        """
        d = super(ArraySource, self).__get_pure_state__()
        d.pop('image_data', None)
        return d

    ######################################################################
    # ArraySource interface.
    ######################################################################
    def update(self):
        """Call this function when you change the array data
        in-place.

        Marks the image and any present scalar/vector arrays as modified,
        re-executes the filter and flags ``data_changed``.
        """
        d = self.image_data
        d.modified()
        pd = d.point_data
        if self.scalar_data is not None:
            pd.scalars.modified()
        if self.vector_data is not None:
            pd.vectors.modified()
        self.change_information_filter.update()
        self.data_changed = True

    ######################################################################
    # Non-public interface.
    ######################################################################

    def _image_data_changed(self, value):
        # Reconnect the filter when image_data is replaced.
        self.configure_input_data(self.change_information_filter, value)

    def _scalar_data_changed(self, data):
        """Copy a new scalar array into ``image_data`` (None clears it).

        A 2-D array is treated as having a trailing singleton dimension.
        The extents are taken from the array shape permuted by
        ``dimensions_order``.
        """
        img_data = self.image_data
        if data is None:
            img_data.point_data.scalars = None
            self.data_changed = True
            return
        dims = list(data.shape)
        if len(dims) == 2:
            dims.append(1)

        # set the dimension indices
        dim0, dim1, dim2 = self.dimensions_order

        img_data.origin = tuple(self.origin)
        img_data.dimensions = tuple(dims)
        img_data.extent = 0, dims[dim0] - 1, 0, dims[dim1] - 1, 0, dims[
            dim2] - 1
        # The old VTK pipeline takes the update extent on the data object;
        # the new one takes it on the filter.
        if is_old_pipeline():
            img_data.update_extent = 0, dims[dim0] - 1, 0, dims[
                dim1] - 1, 0, dims[dim2] - 1
        else:
            update_extent = [
                0, dims[dim0] - 1, 0, dims[dim1] - 1, 0, dims[dim2] - 1
            ]
            self.change_information_filter.set_update_extent(update_extent)
        # numpy.transpose with no axes reverses all axes before flattening.
        if self.transpose_input_array:
            img_data.point_data.scalars = numpy.ravel(numpy.transpose(data))
        else:
            img_data.point_data.scalars = numpy.ravel(data)
        img_data.point_data.scalars.name = self.scalar_name
        # This is very important and if not done can lead to a segfault!
        # (declares the VTK scalar type matching the numpy dtype)
        typecode = data.dtype
        if is_old_pipeline():
            img_data.scalar_type = array_handler.get_vtk_array_type(typecode)
            img_data.update()  # This sets up the extents correctly.
        else:
            filter_out_info = self.change_information_filter.get_output_information(
                0)
            img_data.set_point_data_active_scalar_info(
                filter_out_info, array_handler.get_vtk_array_type(typecode),
                -1)
            img_data.modified()
        img_data.update_traits()
        self.change_information_filter.update()

        # Now flush the mayavi pipeline.
        self.data_changed = True

    def _vector_data_changed(self, data):
        """Copy a new vector array into ``image_data`` (None clears it).

        The last array axis is treated as the 3 vector components: the data
        is reshaped to ``(size // 3, 3)`` after a transpose that keeps that
        axis last. A 3-D array gets a singleton z axis inserted. NOTE:
        ``dimensions_order`` is not applied here.
        """
        img_data = self.image_data
        if data is None:
            img_data.point_data.vectors = None
            self.data_changed = True
            return
        dims = list(data.shape)
        if len(dims) == 3:
            dims.insert(2, 1)
            data = numpy.reshape(data, dims)

        img_data.origin = tuple(self.origin)
        # Drop the trailing component axis to get the grid dimensions.
        img_data.dimensions = tuple(dims[:-1])
        img_data.extent = 0, dims[0] - 1, 0, dims[1] - 1, 0, dims[2] - 1
        if is_old_pipeline():
            img_data.update_extent = 0, dims[0] - 1, 0, dims[1] - 1, 0, dims[
                2] - 1
        else:
            self.change_information_filter.update_information()
            update_extent = [0, dims[0] - 1, 0, dims[1] - 1, 0, dims[2] - 1]
            self.change_information_filter.set_update_extent(update_extent)
        sz = numpy.size(data)
        # Reverse the three spatial axes but keep the component axis last.
        if self.transpose_input_array:
            data_t = numpy.transpose(data, (2, 1, 0, 3))
        else:
            data_t = data
        img_data.point_data.vectors = numpy.reshape(data_t, (sz // 3, 3))
        img_data.point_data.vectors.name = self.vector_name
        if is_old_pipeline():
            img_data.update()  # This sets up the extents correctly.
        else:
            img_data.modified()
        img_data.update_traits()
        self.change_information_filter.update()

        # Now flush the mayavi pipeline.
        self.data_changed = True

    def _scalar_name_changed(self, value):
        # Rename the VTK scalar array, if one is present.
        if self.scalar_data is not None:
            self.image_data.point_data.scalars.name = value
            self.data_changed = True

    def _vector_name_changed(self, value):
        # Rename the VTK vector array, if one is present.
        if self.vector_data is not None:
            self.image_data.point_data.vectors.name = value
            self.data_changed = True

    def _transpose_input_array_changed(self, value):
        # Re-push the current arrays so the new transpose setting applies.
        if self.scalar_data is not None:
            self._scalar_data_changed(self.scalar_data)
        if self.vector_data is not None:
            self._vector_data_changed(self.vector_data)

    def _information_changed(self):
        # spacing/origin changed: re-execute the filter and notify mayavi.
        self.change_information_filter.update()
        self.data_changed = True
Example #3
0
class ODMR(ManagedJob, GetSetItemsMixin):
    """Provides ODMR measurements."""

    # NOTE: '__doc__' is listed in get_set_items, so the class docstring is
    # persisted with saved measurements; change it with care.

    # starting and stopping
    keep_data = Bool(
        False)  # helper variable to decide whether to keep existing data
    resubmit_button = Button(
        label='resubmit',
        desc=
        'Submits the measurement to the job manager. Tries to keep previously acquired data. Behaves like a normal submit if sequence or time bins have changed since previous run.'
    )

    # measurement parameters
    # (disabled) switch = Enum('mw_a', 'mw_b', 'mw_c',
    #     desc='switch to use for different microwave source', label='switch')
    power = Range(low=-100.,
                  high=25.,
                  value=-12,
                  desc='Power [dBm]',
                  label='Power [dBm]',
                  mode='text',
                  auto_set=False,
                  enter_set=True)
    frequency_begin = Range(low=1,
                            high=20e9,
                            value=2.82e9,
                            desc='Start Frequency [Hz]',
                            label='Begin [Hz]',
                            editor=TextEditor(auto_set=False,
                                              enter_set=True,
                                              evaluate=float,
                                              format_str='%e'))
    frequency_end = Range(low=1,
                          high=20e9,
                          value=2.88e9,
                          desc='Stop Frequency [Hz]',
                          label='End [Hz]',
                          editor=TextEditor(auto_set=False,
                                            enter_set=True,
                                            evaluate=float,
                                            format_str='%e'))
    frequency_delta = Range(low=1e-3,
                            high=20e9,
                            value=1e6,
                            desc='frequency step [Hz]',
                            label='Delta [Hz]',
                            editor=TextEditor(auto_set=False,
                                              enter_set=True,
                                              evaluate=float,
                                              format_str='%e'))
    t_pi = Range(low=1.,
                 high=100000.,
                 value=1000.,
                 desc='length of pi pulse [ns]',
                 label='pi [ns]',
                 mode='text',
                 auto_set=False,
                 enter_set=True)
    laser = Range(low=1.,
                  high=10000.,
                  value=300.,
                  desc='laser [ns]',
                  label='laser [ns]',
                  mode='text',
                  auto_set=False,
                  enter_set=True)
    wait = Range(low=1.,
                 high=10000.,
                 value=1000.,
                 desc='wait [ns]',
                 label='wait [ns]',
                 mode='text',
                 auto_set=False,
                 enter_set=True)
    pulsed = Bool(False, label='pulsed', enabled_when='state != "run"')
    sequence = Property(trait=List, depends_on='laser,wait,t_pi')
    seconds_per_point = Range(low=20e-3,
                              high=1,
                              value=20e-3,
                              desc='Seconds per point',
                              label='Seconds per point',
                              mode='text',
                              auto_set=False,
                              enter_set=True)
    stop_time = Range(
        low=1.,
        value=np.inf,
        desc='Time after which the experiment stops by itself [s]',
        label='Stop time [s]',
        mode='text',
        auto_set=False,
        enter_set=True)
    n_lines = Range(low=1,
                    high=10000,
                    value=50,
                    desc='Number of lines in Matrix',
                    label='Matrix lines',
                    mode='text',
                    auto_set=False,
                    enter_set=True)

    # control data fitting
    perform_fit = Bool(False, label='perform fit')
    number_of_resonances = Trait(
        'auto', String('auto', auto_set=False, enter_set=True),
        Int(10000.,
            desc='Number of Lorentzians used in fit',
            label='N',
            auto_set=False,
            enter_set=True))
    threshold = Range(
        low=-99,
        high=99.,
        value=-50.,
        desc=
        'Threshold for detection of resonances [%]. The sign of the threshold specifies whether the resonances are negative or positive.',
        label='threshold [%]',
        mode='text',
        auto_set=False,
        enter_set=True)

    # fit result
    fit_parameters = Array(value=np.array((np.nan, np.nan, np.nan, np.nan)))
    fit_frequencies = Array(value=np.array((np.nan, )),
                            label='frequencies [Hz]')
    fit_line_width = Float(np.nan,
                           label='line width [Hz]',
                           editor=TextEditor(evaluate=float,
                                             format_str='%.3e'))
    fit_contrast = Float(np.nan,
                         label='contrast [%]',
                         editor=TextEditor(evaluate=float, format_str='%.1f'))

    # measurement data
    frequency = Array()
    counts = Array()
    counts_matrix = Array()
    run_time = Float(value=0.0, desc='Run time [s]', label='Run time [s]')

    # plotting
    line_label = Instance(PlotLabel)
    line_data = Instance(ArrayPlotData)
    matrix_data = Instance(ArrayPlotData)
    line_plot = Instance(Plot, editor=ComponentEditor())
    matrix_plot = Instance(Plot, editor=ComponentEditor())

    def __init__(self):
        super(ODMR, self).__init__()
        self._create_line_plot()
        self._create_matrix_plot()
        # Plot updates are dispatched to the UI thread because the data
        # traits are modified from the measurement thread in _run.
        self.on_trait_change(self._update_line_data_index,
                             'frequency',
                             dispatch='ui')
        self.on_trait_change(self._update_line_data_value,
                             'counts',
                             dispatch='ui')
        self.on_trait_change(self._update_line_data_fit,
                             'fit_parameters',
                             dispatch='ui')
        self.on_trait_change(self._update_matrix_data_value,
                             'counts_matrix',
                             dispatch='ui')
        self.on_trait_change(self._update_matrix_data_index,
                             'n_lines,frequency',
                             dispatch='ui')

    def _compute_frequency(self):
        """Return the frequency sweep [Hz] from begin/end/delta.

        The ``np.arange`` stop is ``frequency_end + frequency_delta``, so
        that ``frequency_end`` itself is intended to be part of the sweep.
        """
        return np.arange(self.frequency_begin,
                         self.frequency_end + self.frequency_delta,
                         self.frequency_delta)

    def _counts_matrix_default(self):
        return np.zeros((self.n_lines, len(self.frequency)))

    def _frequency_default(self):
        return self._compute_frequency()

    def _counts_default(self):
        return np.zeros(self.frequency.shape)

    # data acquisition

    def apply_parameters(self):
        """Apply the current parameters and decide whether to keep previous data."""
        frequency = self._compute_frequency()

        if not self.keep_data or np.any(frequency != self.frequency):
            self.counts = np.zeros(frequency.shape)
            self.run_time = 0.0

        self.frequency = frequency
        self.keep_data = True  # when job manager stops and starts the job, data should be kept. Only new submission should clear data.

    def _run(self):
        """Measurement loop, executed in the job manager's worker thread.

        Configures the pulse generator, microwave source and counter. Then
        accumulates sweeps into ``counts``/``counts_matrix`` until
        ``stop_time`` is reached or the thread's ``stop_request`` is set,
        and finally returns the hardware to its idle state. Any exception is
        logged and puts the job into the 'error' state.
        """
        try:
            self.state = 'run'
            self.apply_parameters()

            if self.run_time >= self.stop_time:
                self.state = 'done'
                return

            # if pulsed, turn on sequence
            if self.pulsed:
                ha.PulseGenerator().Sequence(100 *
                                             [(['laser', 'aom'], self.laser),
                                              ([], self.wait),
                                              (['mw'], self.t_pi)])
            else:
                ha.PulseGenerator().Open()

            n = len(self.frequency)
            # Former approach (disabled):
            #   ha.MicrowaveB().setOutput(self.power,
            #       np.append(self.frequency, self.frequency[0]),
            #       self.seconds_per_point)
            #   self._prepareCounter(n)
            ha.MicrowaveB().setPower(self.power)
            ha.MicrowaveB().initSweep(
                self.frequency, self.power * np.ones(self.frequency.shape))
            ha.CounterB().configure(n, self.seconds_per_point, DutyCycle=0.8)
            time.sleep(0.5)

            while self.run_time < self.stop_time:
                start_time = time.time()
                if threading.current_thread().stop_request.is_set():
                    break
                ha.MicrowaveB().resetListPos()
                counts = ha.CounterB().run()
                self.run_time += time.time() - start_time
                # In-place accumulation does not fire a trait change, so the
                # notification is sent explicitly below.
                self.counts += counts
                # Newest sweep on top; the oldest row drops off the bottom.
                self.counts_matrix = np.vstack(
                    (counts, self.counts_matrix[:-1, :]))
                self.trait_property_changed('counts', self.counts)
                # (A former readout via ha.MicrowaveB().doSweep() and a
                # count-between-markers timeout loop was replaced by the
                # list-mode sweep above.)

            if self.run_time < self.stop_time:
                self.state = 'idle'
            else:
                self.state = 'done'
            ha.MicrowaveB().setOutput(None, self.frequency_begin)
            ha.PulseGenerator().Light()
            ha.CounterB().clear()
        except Exception:
            logging.getLogger().exception('Error in odmr.')
            self.state = 'error'

    # fitting

    @on_trait_change('counts,perform_fit,number_of_resonances,threshold')
    def update_fit(self):
        """Fit Lorentzians to ``counts`` and publish the fit results.

        ``p[1::3]``, ``p[2::3]`` and ``p[3::3]`` are read as resonance
        frequencies, widths (HWHM) and per-line weights. ``p[0]`` normalises
        the contrast (presumably the baseline; confirm against
        ``fitting.fit_multiple_lorentzians``). When fitting is disabled or
        fails, all results become NaN.
        """
        if self.perform_fit:
            N = self.number_of_resonances
            if N != 'auto':
                N = int(N)
            try:
                p = fitting.fit_multiple_lorentzians(self.frequency,
                                                     self.counts,
                                                     N,
                                                     threshold=self.threshold *
                                                     0.01)
            except Exception:
                logging.getLogger().debug('ODMR fit failed.', exc_info=True)
                p = np.full(4, np.nan)
        else:
            p = np.full(4, np.nan)
        self.fit_parameters = p
        self.fit_frequencies = p[1::3]
        self.fit_line_width = p[2::3].mean()
        self.fit_contrast = -100 * p[3::3].mean() / (np.pi * p[2::3].mean() *
                                                     p[0])

    # plotting

    def _create_line_plot(self):
        line_data = ArrayPlotData(frequency=np.array((0., 1.)),
                                  counts=np.array((0., 0.)),
                                  fit=np.array((0., 0.)))
        line_plot = Plot(line_data,
                         padding=8,
                         padding_left=64,
                         padding_bottom=32)
        line_plot.plot(('frequency', 'counts'), style='line', color='blue')
        line_plot.index_axis.title = 'Frequency [MHz]'
        line_plot.value_axis.title = 'Fluorescence counts'
        line_label = PlotLabel(text='',
                               hjustify='left',
                               vjustify='bottom',
                               position=[64, 32])
        line_plot.overlays.append(line_label)
        self.line_label = line_label
        self.line_data = line_data
        self.line_plot = line_plot

    def _create_matrix_plot(self):
        matrix_data = ArrayPlotData(image=np.zeros((2, 2)))
        matrix_plot = Plot(matrix_data,
                           padding=8,
                           padding_left=64,
                           padding_bottom=32)
        matrix_plot.index_axis.title = 'Frequency [MHz]'
        matrix_plot.value_axis.title = 'line #'
        # Bounds in MHz to match the axis title and _update_matrix_data_index
        # (previously the initial bounds were given in Hz).
        matrix_plot.img_plot('image',
                             xbounds=(self.frequency[0] * 1e-6,
                                      self.frequency[-1] * 1e-6),
                             ybounds=(0, self.n_lines),
                             colormap=Spectral)
        self.matrix_data = matrix_data
        self.matrix_plot = matrix_plot

    def _perform_fit_changed(self, new):
        plot = self.line_plot
        if new:
            plot.plot(('frequency', 'fit'),
                      style='line',
                      color='red',
                      name='fit')
            self.line_label.visible = True
        else:
            plot.delplot('fit')
            self.line_label.visible = False
        plot.request_redraw()

    def _update_line_data_index(self):
        self.line_data.set_data('frequency', self.frequency * 1e-6)
        self.counts_matrix = self._counts_matrix_default()

    def _update_line_data_value(self):
        self.line_data.set_data('counts', self.counts)

    def _update_line_data_fit(self):
        """Draw the fitted curve and list per-resonance results in the label."""
        if not np.isnan(self.fit_parameters[0]):
            self.line_data.set_data(
                'fit',
                fitting.NLorentzians(*self.fit_parameters)(self.frequency))
            p = self.fit_parameters
            f = p[1::3]
            w = p[2::3]
            c = 100 * p[3::3] / (np.pi * w * p[0])
            self.line_label.text = ''.join(
                'f %i: %.6e Hz, HWHM %.3e Hz, contrast %.1f%%\n' %
                (i + 1, fi, wi, ci)
                for i, (fi, wi, ci) in enumerate(zip(f, w, c)))

    def _update_matrix_data_value(self):
        self.matrix_data.set_data('image', self.counts_matrix)

    def _update_matrix_data_index(self):
        """Resize ``counts_matrix`` to ``n_lines`` rows and update the image bounds."""
        if self.n_lines > self.counts_matrix.shape[0]:
            self.counts_matrix = np.vstack(
                (self.counts_matrix,
                 np.zeros((self.n_lines - self.counts_matrix.shape[0],
                           self.counts_matrix.shape[1]))))
        else:
            self.counts_matrix = self.counts_matrix[:self.n_lines]
        self.matrix_plot.components[0].index.set_data(
            (self.frequency[0] * 1e-6, self.frequency[-1] * 1e-6),
            (0.0, float(self.n_lines)))

    # saving data

    def save_line_plot(self, filename):
        self.save_figure(self.line_plot, filename)

    def save_matrix_plot(self, filename):
        self.save_figure(self.matrix_plot, filename)

    def save_all(self, filename):
        self.save_line_plot(filename + '_ODMR_Line_Plot.png')
        self.save_matrix_plot(filename + '_ODMR_Matrix_Plot.png')
        self.save(filename + '_ODMR.pys')

    # react to GUI events

    def submit(self):
        """Submit the job to the JobManager."""
        self.keep_data = False
        ManagedJob.submit(self)

    def resubmit(self):
        """Submit the job to the JobManager."""
        self.keep_data = True
        ManagedJob.submit(self)

    def _resubmit_button_fired(self):
        """React to start button. Submit the Job."""
        self.resubmit()

    traits_view = View(VGroup(
        HGroup(
            Item('submit_button', show_label=False),
            Item('remove_button', show_label=False),
            Item('resubmit_button', show_label=False),
            Item('priority'),
            Item('state', style='readonly'),
            Item('run_time', style='readonly', format_str='%.f'),
        ),
        Group(
            VGroup(HGroup(
                Item('power', width=-40),
                Item('frequency_begin', width=-80),
                Item('frequency_end', width=-80),
                Item('frequency_delta', width=-80),
                Item('pulsed'),
                Item('t_pi', width=-80),
            ),
                   HGroup(
                       Item('perform_fit'),
                       Item('number_of_resonances'),
                       Item('threshold'),
                       Item('n_lines'),
                       Item('stop_time'),
                   ),
                   HGroup(
                       Item('fit_contrast', style='readonly'),
                       Item('fit_line_width', style='readonly'),
                       Item('fit_frequencies', style='readonly'),
                   ),
                   label='parameters'),
            VGroup(HGroup(Item('seconds_per_point'), ),
                   HGroup(
                       Item('laser', width=40),
                       Item('wait', width=40),
                   ),
                   label='settings'),
            layout='tabbed',
        ),
        VSplit(
            Item('matrix_plot', show_label=False, resizable=True),
            Item('line_plot', show_label=False, resizable=True),
        ),
    ),
                       menubar=MenuBar(
                           Menu(Action(action='saveLinePlot',
                                       name='SaveLinePlot (.png)'),
                                Action(action='saveMatrixPlot',
                                       name='SaveMatrixPlot (.png)'),
                                Action(action='save',
                                       name='Save (.pyd or .pys)'),
                                Action(action='saveAll',
                                       name='Save All (.png+.pys)'),
                                Action(action='export',
                                       name='Export as Ascii (.asc)'),
                                Action(action='load', name='Load'),
                                Action(action='_on_close', name='Quit'),
                                name='File')),
                       title='ODMR',
                       width=900,
                       height=800,
                       buttons=[],
                       resizable=True,
                       handler=ODMRHandler)

    get_set_items = [
        'frequency', 'counts', 'counts_matrix', 'fit_parameters',
        'fit_contrast', 'fit_line_width', 'fit_frequencies', 'perform_fit',
        'run_time', 'power', 'frequency_begin', 'frequency_end',
        'frequency_delta', 'laser', 'wait', 'pulsed', 't_pi',
        'seconds_per_point', 'stop_time', 'n_lines', 'number_of_resonances',
        'threshold', '__doc__'
    ]
class DataFrameAnalyzerView(ModelView):
    """ Flexible ModelView class for a DataFrameAnalyzer.

    The view is built using many methods to build each component of the view so
    it can easily be subclassed and customized.

    TODO: add traits events to pass update/refresh notifications to the
     DFEditors once we have updated TraitsUI.

    TODO: Add traits events to receive notifications that a column/row was
     clicked/double-clicked.
    """
    #: Model being viewed
    model = Instance(DataFrameAnalyzer)

    #: Selected list of data columns to display and analyze
    visible_columns = List(Str)

    #: Check box to hide/show what stats are included in the summary DF
    show_summary_controls = Bool

    #: Show the summary categorical df
    show_categorical_summary = Bool(True)

    #: Check box to hide/show what columns to analyze (panel when few columns)
    show_column_controls = Bool

    #: Open control for what columns to analyze (popup when many columns)
    open_column_controls = Button("Show column control")

    #: Button to launch the plotter tool when plotter_layout=popup
    plotter_launcher = Button("Launch Plot Tool")

    # Plotting tool attributes ------------------------------------------------

    #: Does the UI expose a DF plotter?
    include_plotter = Bool

    #: Plot manager view to display. Ignored if include_plotter is False.
    plotter = Instance(DataFramePlotManagerView)

    # Styling and branding attributes -----------------------------------------

    #: String describing the font to use, or dict mapping column names to font
    fonts = Either(Str, Dict)

    #: Name of the font to use if same across all columns
    font_name = Str(DEFAULT_FONT)

    #: Size of the font to use if same across all columns
    font_size = Int(14)

    #: Number of digits to display in the tables (negative: no formatting)
    display_precision = Int(-1)

    #: Formatting used to display the data: a single format string or a dict
    #: mapping column names to format strings
    formats = Either(Str, Dict)

    #: UI title for the Data section
    data_section_title = Str("Data")

    #: Exploration group label: visible only when plotter_layout="Tabbed"
    exploration_group_label = Str("Exploration Tools")

    #: Plotting group label: visible only when plotter_layout="Tabbed"
    plotting_group_label = Str("Plotting Tools")

    #: UI title for the data summary section
    summary_section_title = Str

    #: UI title for the categorical data summary section
    cat_summary_section_title = Str("Categorical data summary")

    #: UI title for the column list section
    column_list_section_title = Str("Column content")

    #: UI title for the summary content section
    summary_content_section_title = Str("Summary content")

    #: UI summary group (tab) name for numerical columns
    num_summary_group_name = Str("Numerical data")

    #: UI summary group (tab) name for categorical columns
    cat_summary_group_name = Str("Categorical data")

    #: Text to display in title bar of the containing window (if applicable)
    app_title = Str("Tabular Data Analyzer")

    #: How to place the plotter tool with respect to the exploration tool?
    plotter_layout = Enum("Tabbed", "HSplit", "VSplit", "popup")

    #: DFPlotManager traits to customize it
    plotter_kw = Dict

    #: Message displayed below the table if truncated
    truncation_msg = Property(
        Str, depends_on="model.num_displayed_rows, truncation_msg_template"
    )

    # Functionality controls --------------------------------------------------

    #: Button to shuffle the order of the filtered data
    shuffle_button = Button("Shuffle")

    #: Whether the shuffle button is visible in the data section
    show_shuffle_button = Bool(True)

    #: Button to display more rows in the data table
    show_more_button = Button

    #: Button to display all rows in the data table
    show_all_button = Button("Show All")

    #: Apply button for the filter if model not in auto-apply mode
    apply_filter_button = ToolbarButton(image=apply_img)

    #: Edit the filter in a pop-out dialog
    pop_out_filter_button = ToolbarButton(image=pop_out_img)

    #: Whether to support saving, and loading filters
    filter_manager = Bool

    #: Button to launch filter expression manager to load an existing filter
    load_filter_button = ToolbarButton(image=load_img)

    #: Button to save current filter expression
    save_filter_button = ToolbarButton(image=save_img)

    #: Button to launch filter expression manager to modify saved filters
    manage_filter_button = ToolbarButton(image=manage_img)

    #: List of saved filtered expressions
    _known_expr = Property(Set, depends_on="model.known_filter_exps")

    #: Show the bottom panel with the summary of the data:
    _show_summary = Bool(True)

    #: Whether the "Show summary" check box is offered to the user
    allow_show_summary = Bool(True)

    #: Button to export the analyzed data to a CSV file
    data_exporter = Button("Export Data to CSV")

    #: Button to export the summary data to a CSV file
    summary_exporter = Button("Export Summary to CSV")

    # Detailed configuration traits -------------------------------------------

    #: View class to use. Modify to customize.
    view_klass = Any(View)

    #: Width of the view
    view_width = Int(1100)

    #: Height of the view
    view_height = Int(700)

    #: Width of the filter box
    filter_item_width = Int(400)

    #: Maximum number of column names per column of the column-control
    #: checklist; also drives the panel vs. popup choice (see _many_columns)
    max_names_per_column = Int(12)

    #: Template for truncation_msg, formatted with the number of displayed rows
    truncation_msg_template = Str("Table truncated at {} rows")

    #: Whether to warn when selected rows are not displayed (truncated table)
    warn_if_sel_hidden = Bool(True)

    #: Message of the hidden-selection warning dialog
    hidden_selection_msg = Str

    #: Column names (as a list) to include in filter editor assistant
    filter_editor_cols = List

    # Implementation details --------------------------------------------------

    #: Evaluate number of columns to select panel or popup column control
    _many_columns = Property(
        Bool, depends_on="model.column_list, max_names_per_column"
    )

    #: Popped-up UI to control the visible columns
    _control_popup = Any

    #: Collected traitsUI editors for both the data DF and the summary DF
    _df_editors = Dict

    # HasTraits interface -----------------------------------------------------

    def __init__(self, **traits):
        # Allow passing a raw DataFrame as the model for convenience:
        if "model" in traits and isinstance(traits["model"], pd.DataFrame):
            traits["model"] = DataFrameAnalyzer(source_df=traits["model"])

        super(DataFrameAnalyzerView, self).__init__(**traits)

        if self.include_plotter:
            # If a plotter view was specified, its model should be in the
            # model's list of plot managers:
            if self.plotter.model not in self.model.plot_manager_list:
                self.model.plot_manager_list.append(self.plotter.model)

    def traits_view(self):
        """ Putting the view components together.

        Each component of the view is built in a separate method so it can
        easily be subclassed and customized.
        """
        # Construction of view groups -----------------------------------------

        data_group = self.view_data_group_builder()
        column_controls_group = self.view_data_control_group_builder()
        summary_group = self.view_summary_group_builder()
        summary_controls_group = self.view_summary_control_group_builder()
        if self.show_categorical_summary:
            cat_summary_group = self.view_cat_summary_group_builder()
        else:
            cat_summary_group = None
        plotter_group = self.view_plotter_group_builder()

        button_content = [
            Item("data_exporter", show_label=False),
            Spring(),
            Item("summary_exporter", show_label=False)
        ]

        if self.plotter_layout == "popup":
            button_content += [
                Spring(),
                Item("plotter_launcher", show_label=False)
            ]

        button_group = HGroup(*button_content)

        # Organization of item groups -----------------------------------------

        # If both types of summary are available, display as Tabbed view:
        if summary_group is not None and cat_summary_group is not None:
            summary_container = Tabbed(
                HSplit(
                    summary_controls_group,
                    summary_group,
                    label=self.num_summary_group_name
                ),
                cat_summary_group,
            )
        elif cat_summary_group is not None:
            summary_container = cat_summary_group
        else:
            summary_container = HSplit(
                summary_controls_group,
                summary_group
            )

        # Allow to hide all summary information:
        summary_container.visible_when = "_show_summary"

        exploration_groups = VGroup(
            VSplit(
                HSplit(
                    column_controls_group,
                    data_group,
                ),
                summary_container
            ),
            button_group,
            label=self.exploration_group_label
        )

        if self.include_plotter and self.plotter_layout != "popup":
            # plotter_layout names a traitsui group class (Tabbed/HSplit/...):
            layout = getattr(traitsui.api, self.plotter_layout)
            groups = layout(
                exploration_groups,
                plotter_group
            )
        else:
            groups = exploration_groups

        view = self.view_klass(
            groups,
            resizable=True,
            title=self.app_title,
            width=self.view_width, height=self.view_height
        )
        return view

    # Traits view building methods --------------------------------------------

    def view_data_group_builder(self):
        """ Build view element for the Data display
        """
        editor_kw = dict(show_index=True, columns=self.visible_columns,
                         fonts=self.fonts, formats=self.formats)
        data_editor = DataFrameEditor(selected_row="selected_idx",
                                      multi_select=True, **editor_kw)

        filter_group = HGroup(
            Item("model.filter_exp", label="Filter",
                 width=self.filter_item_width),
            Item("pop_out_filter_button", show_label=False, style="custom",
                 tooltip="Open filter editor..."),
            Item("apply_filter_button", show_label=False,
                 visible_when="not model.filter_auto_apply", style="custom",
                 tooltip="Apply current filter"),
            Item("save_filter_button", show_label=False,
                 enabled_when="model.filter_exp not in _known_expr",
                 visible_when="filter_manager", style="custom",
                 tooltip="Save current filter"),
            Item("load_filter_button", show_label=False,
                 visible_when="filter_manager", style="custom",
                 tooltip="Load a filter..."),
            Item("manage_filter_button", show_label=False,
                 visible_when="filter_manager", style="custom",
                 tooltip="Manage filters..."),
        )

        truncated = ("len(model.displayed_df) < len(model.filtered_df) and "
                     "not model.show_selected_only")
        more_label = "Show {} More".format(self.model.num_display_increment)
        display_control_group = HGroup(
            Item("model.show_selected_only", label="Selected rows only"),
            Item("truncation_msg", style="readonly", show_label=False,
                 visible_when=truncated),
            Item("show_more_button", editor=ButtonEditor(label=more_label),
                 show_label=False, visible_when=truncated),
            Item("show_all_button", show_label=False,
                 visible_when=truncated),
        )

        data_group = VGroup(
            make_window_title_group(self.data_section_title, title_size=3,
                                    include_blank_spaces=False),
            HGroup(
                Item("model.sort_by_col", label="Sort by"),
                Item("shuffle_button", show_label=False,
                     visible_when="show_shuffle_button"),
                Spring(),
                filter_group
            ),
            HGroup(
                Item("model.displayed_df", editor=data_editor,
                     show_label=False),
            ),
            HGroup(
                Item("show_column_controls",
                     label="\u2190 Show column control",
                     visible_when="not _many_columns"),
                Item("open_column_controls", show_label=False,
                     visible_when="_many_columns"),
                Spring(),
                Item("_show_summary", label=u'\u2193 Show summary',
                     visible_when="allow_show_summary"),
                Spring(),
                display_control_group
            ),
            show_border=True
        )
        return data_group

    def view_data_control_group_builder(self, force_visible=False):
        """ Build view element for the Data column control.

        Parameters
        ----------
        force_visible : bool
            Controls visibility of the created group. Don't force for the group
            embedded in the global view, but force it when opened as a popup.
        """
        # Number of checklist columns needed so that no column holds more than
        # max_names_per_column names (ceiling division, at least 1 column):
        num_names = len(self.model.column_list)
        num_cols = max(1, -(-num_names // self.max_names_per_column))

        column_controls_group = VGroup(
            make_window_title_group(self.column_list_section_title,
                                    title_size=3, include_blank_spaces=False),
            Item("visible_columns", show_label=False,
                 editor=CheckListEditor(values=self.model.column_list,
                                        cols=num_cols),
                 # The custom style allows to control a list of options rather
                 # than having a checklist editor for a single value:
                 style='custom'),
            show_border=True
        )
        if force_visible:
            column_controls_group.visible_when = ""
        else:
            column_controls_group.visible_when = "show_column_controls"

        return column_controls_group

    def view_summary_group_builder(self):
        """ Build view element for the numerical data summary display
        """
        editor_kw = dict(show_index=True, columns=self.visible_columns,
                         fonts=self.fonts, formats=self.formats)
        summary_editor = DataFrameEditor(**editor_kw)

        summary_group = VGroup(
            make_window_title_group(self.summary_section_title, title_size=3,
                                    include_blank_spaces=False),
            Item("model.summary_df", editor=summary_editor, show_label=False,
                 visible_when="len(model.summary_df) != 0"),
            # Workaround the fact that the Label's visible_when is buggy:
            # encapsulate it into a group and add the visible_when to the group
            HGroup(
                Label("No data columns with numbers were found."),
                visible_when="len(model.summary_df) == 0"
            ),
            HGroup(
                Item("show_summary_controls"),
                Spring(),
                visible_when="len(model.summary_df) != 0"
            ),
            show_border=True,
        )
        return summary_group

    def view_summary_control_group_builder(self):
        """ Build view element for the column controls for data summary.
        """
        summary_controls_group = VGroup(
            make_window_title_group(self.summary_content_section_title,
                                    title_size=3, include_blank_spaces=False),
            Item("model.summary_index", show_label=False),
            visible_when="show_summary_controls",
            show_border=True
        )

        return summary_controls_group

    def view_cat_summary_group_builder(self):
        """ Build view element for the categorical data summary display.
        """
        editor_kw = dict(show_index=True, fonts=self.fonts,
                         formats=self.formats)
        summary_editor = DataFrameEditor(**editor_kw)

        cat_summary_group = VGroup(
            make_window_title_group(self.cat_summary_section_title,
                                    title_size=3, include_blank_spaces=False),
            Item("model.summary_categorical_df", editor=summary_editor,
                 show_label=False,
                 visible_when="len(model.summary_categorical_df)!=0"),
            # Workaround the fact that the Label's visible_when is buggy:
            # encapsulate it into a group and add the visible_when to the group
            HGroup(
                Label("No categorical data columns were found."),
                visible_when="len(model.summary_categorical_df)==0"
            ),
            show_border=True, label=self.cat_summary_group_name
        )
        return cat_summary_group

    def view_plotter_group_builder(self):
        """ Build view element for the plotter tool.
        """
        plotter_group = VGroup(
            Item("plotter", editor=InstanceEditor(), show_label=False,
                 style="custom"),
            label=self.plotting_group_label
        )
        return plotter_group

    # Public interface --------------------------------------------------------

    def destroy(self):
        """ Clean up resources.

        Disposes of the column-control popup, if one was opened, and drops the
        reference to it so a repeated call does not dispose it twice.
        """
        if self._control_popup:
            self._control_popup.dispose()
            self._control_popup = None

    # Traits listeners --------------------------------------------------------

    def _open_column_controls_fired(self):
        """ Pop-up a new view on the column list control.
        """
        if self._control_popup and self._control_popup.control:
            # If there is an existing window, bring it in focus:
            # Discussion: https://stackoverflow.com/questions/2240717/in-qt-how-do-i-make-a-window-be-the-current-window  # noqa
            self._control_popup.control._mw.activateWindow()
            return

        # Before viewing self with a simplified view, make sure the original
        # view editors are collected so they can be modified when the controls
        # are used:
        if not self._df_editors:
            self._collect_df_editors()

        view = self.view_klass(
            self.view_data_control_group_builder(force_visible=True),
            buttons=[OKButton],
            width=600, resizable=True,
            title="Control visible columns"
        )
        # WARNING: this will modify the info object the view points to!
        self._control_popup = self.edit_traits(view=view, kind="live")

    def _shuffle_button_fired(self):
        self.model.shuffle_filtered_df()

    def _apply_filter_button_fired(self):
        flt = self.model.filter_exp
        msg = f"Applying filter {flt}."
        logger.log(ACTION_LEVEL, msg)

        self.model.recompute_filtered_df()

    def _pop_out_filter_button_fired(self):
        if not self.filter_editor_cols:
            # if there are no included columns, then use all categorical cols
            df = self.model.source_df
            cat_df = df.select_dtypes(include=CATEGORICAL_COL_TYPES)
            self.filter_editor_cols = list(cat_df.columns)
        filter_editor = FilterExpressionEditorView(
            expr=self.model.filter_exp, view_klass=self.view_klass,
            source_df=self.model.source_df,
            included_cols=self.filter_editor_cols)
        ui = filter_editor.edit_traits(kind="livemodal")
        if ui.result:
            self.model.filter_exp = filter_editor.expr
            # Firing the button applies the edited filter:
            self.apply_filter_button = True

    def _manage_filter_button_fired(self):
        """ TODO: review if replacing the copy by a deepcopy or removing the
             copy altogether would help traits trigger listeners correctly
        """
        msg = "Opening filter manager."
        logger.log(ACTION_LEVEL, msg)

        # Make a copy of the list of filters so the model can listen to changes
        # even if only a field of an existing filter is modified:
        filter_manager = FilterExpressionManager(
            known_filter_exps=copy(self.model.known_filter_exps),
            mode="manage", view_klass=self.view_klass
        )
        ui = filter_manager.edit_traits(kind="livemodal")
        if ui.result:
            # FIXME: figure out why this simpler assignment doesn't trigger the
            #  traits listener on the model when changing a FilterExpression
            #  attribute:
            # self.model.known_filter_exps = filter_manager.known_filter_exps

            self.model.known_filter_exps = [
                FilterExpression(name=e.name, expression=e.expression) for e in
                filter_manager.known_filter_exps
            ]

    def _load_filter_button_fired(self):
        filter_manager = FilterExpressionManager(
            known_filter_exps=self.model.known_filter_exps,
            mode="load", view_klass=self.view_klass
        )
        ui = filter_manager.edit_traits(kind="livemodal")
        if ui.result:
            selection = filter_manager.selected_expression
            self.model.filter_exp = selection.expression

    def _save_filter_button_fired(self):
        exp = self.model.filter_exp
        # Don't store the same expression twice:
        if exp in [e.expression for e in self.model.known_filter_exps]:
            return

        expr = FilterExpression(name=exp, expression=exp)
        self.model.known_filter_exps.append(expr)

    def _show_more_button_fired(self):
        self.model.num_displayed_rows += self.model.num_display_increment

    def _show_all_button_fired(self):
        # -1 is used as the "display all rows" value:
        self.model.num_displayed_rows = -1

    @on_trait_change("model:selected_data_in_plotter_updated", post_init=True)
    def warn_if_selection_hidden(self):
        """ Pop up warning msg if some of the selected rows aren't displayed.
        """
        if not self.warn_if_sel_hidden:
            return

        if not self.model.selected_idx:
            return

        truncated = len(self.model.displayed_df) < len(self.model.filtered_df)
        max_displayed = self.model.displayed_df.index.max()
        some_selection_hidden = max(self.model.selected_idx) > max_displayed
        if truncated and some_selection_hidden:
            warning(None, self.hidden_selection_msg, "Hidden selection")

    @on_trait_change("visible_columns[]", post_init=True)
    def update_filtered_df_on_columns(self):
        """ Just show the columns that are set to visible.

        Notes
        -----
        We are not modifying the filtered data because if we remove a column
        and then bring it back, the adapter breaks because it is missing data.
        Breakage happen when removing a column if the model is changed first,
        or when bring a column back if the adapter column list is changed
        first.
        """
        # No UI has been created yet (info is None) or it isn't ready: there
        # are no editors to update.
        if self.info is None or not self.info.initialized:
            return

        if not self._df_editors:
            self._collect_df_editors()

        # Rebuild the column list (col name, column id) for the tabular
        # adapter:
        all_visible_cols = [(col, col) for col in self.visible_columns]

        df = self.model.source_df
        cat_dtypes = self.model.categorical_dtypes
        summarizable_df = df.select_dtypes(exclude=cat_dtypes)
        summary_visible_cols = [(col, col) for col in self.visible_columns
                                if col in summarizable_df.columns]

        for df_name, cols in zip(["displayed_df", "summary_df"],
                                 [all_visible_cols, summary_visible_cols]):
            # This grabs the corresponding _DataFrameEditor (not the editor
            # factory) which has access to the adapter object:
            editor = self._df_editors.get(df_name)
            if editor is None:
                # Collection failed; the error was logged by
                # _collect_df_editors, so skip this table instead of raising.
                continue

            df = getattr(self.model, df_name)
            index_name = df.index.name
            if index_name is None:
                index_name = ''

            editor.adapter.columns = [(index_name, 'index')] + cols

    def _collect_df_editors(self):
        """ Store the live DataFrame editors of the data and summary tables.

        Failures are logged and the corresponding entry is left out of
        _df_editors.
        """
        for df_name in ["displayed_df", "summary_df"]:
            try:
                # This grabs the corresponding _DataFrameEditor (not the editor
                # factory) which has access to the adapter object:
                self._df_editors[df_name] = getattr(self.info, df_name)
            except Exception as e:
                msg = "Error trying to collect the tabular adapter: {}"
                logger.error(msg.format(e))

    def _plotter_launcher_fired(self):
        """ Pop up plot manager view. Only when self.plotter_layout="popup".
        """
        self.plotter.edit_traits(kind="livemodal")

    def _data_exporter_fired(self):
        filepath = request_csv_file(action="save as")
        if filepath:
            self.model.filtered_df.to_csv(filepath)
            open_file(filepath)

    def _summary_exporter_fired(self):
        filepath = request_csv_file(action="save as")
        if filepath:
            self.model.summary_df.to_csv(filepath)
            open_file(filepath)

    # Traits property getters/setters -----------------------------------------

    def _get__known_expr(self):
        return {e.expression for e in self.model.known_filter_exps}

    @cached_property
    def _get_truncation_msg(self):
        num_displayed_rows = self.model.num_displayed_rows
        return self.truncation_msg_template.format(num_displayed_rows)

    @cached_property
    def _get__many_columns(self):
        # "Many columns" means the column names would not fit in 2 checklist
        # columns of max_names_per_column names each:
        return len(self.model.column_list) > 2 * self.max_names_per_column

    # Traits initialization methods -------------------------------------------

    def _plotter_default(self):
        if self.include_plotter:
            if self.model.plot_manager_list:
                if len(self.model.plot_manager_list) > 1:
                    num_plotters = len(self.model.plot_manager_list)
                    msg = "Model contains {} plot manager, but only " \
                          "initializing the Analyzer view with the first " \
                          "plot manager available.".format(num_plotters)
                    logger.warning(msg)

                plot_manager = self.model.plot_manager_list[0]
            else:
                plot_manager = DataFramePlotManager(
                    data_source=self.model.filtered_df,
                    source_analyzer=self.model,
                    **self.plotter_kw
                )

            view = DataFramePlotManagerView(model=plot_manager,
                                            view_klass=self.view_klass)
            return view

    def _formats_default(self):
        """ Build the display formats from display_precision.

        Returns '%s' for every column when display_precision is negative,
        otherwise a dict mapping numerical columns to a '%.<precision>g'
        format and every other column to '%s'.
        """
        if self.display_precision < 0:
            return '%s'

        formats = {}
        float_format = '%.{}g'.format(self.display_precision)
        for col, col_dtype in self.model.source_df.dtypes.items():
            try:
                is_number = np.issubdtype(col_dtype, np.number)
            except TypeError:
                # pandas extension dtypes (category, nullable ints, tz-aware
                # datetimes...) are not numpy dtypes and make issubdtype raise:
                is_number = (pd.api.types.is_numeric_dtype(col_dtype) and
                             not pd.api.types.is_bool_dtype(col_dtype))
            formats[col] = float_format if is_number else '%s'

        return formats

    def _visible_columns_default(self):
        return self.model.column_list

    def _hidden_selection_msg_default(self):
        msg = "The displayed data is truncated and some of the selected " \
              "rows isn't displayed in the data table."
        return msg

    def _summary_section_title_default(self):
        if len(self.model.summary_categorical_df) == 0:
            return "Data summary"
        else:
            return "Numerical data summary"

    def _fonts_default(self):
        return "{} {}".format(self.font_name, self.font_size)
Example #5
0
class GeoNDGrid(Source):
    '''
    Specification and representation of an nD-grid.

    GridND

    The grid spans the box between the corners `x_mins` and `x_maxs`, with
    `shape` elements per direction. Only the directions listed in
    `active_dims` are discretized; inactive directions contribute a single
    node. Connectivity is exposed as lines (1D), faces (2D) or volumes (3D)
    depending on the number of active dimensions.
    '''
    # The name of our scalar array.
    scalar_name = Str('scalar')

    # map of coordinate labels to the indices
    _dim_map = {'x': 0, 'y': 1, 'z': 2}

    # currently active dimensions
    active_dims = List(Str, ['x', 'y'])

    # Bottom left corner
    x_mins = Instance(GridPoint, label='Corner 1')

    def _x_mins_default(self):
        '''Bottom left corner'''
        return GridPoint()

    # Upper right corner
    x_maxs = Instance(GridPoint, label='Corner 2')

    def _x_maxs_default(self):
        '''Upper right corner'''
        return GridPoint(x=1, y=1, z=1)

    # indices of the currently active dimensions
    dim_indices = Property(Array(int), depends_on='active_dims')

    @cached_property
    def _get_dim_indices(self):
        ''' Get active indices '''
        return array([self._dim_map[dim_ix] for dim_ix in self.active_dims],
                     dtype='int_')

    # number of currently active dimensions
    n_dims = Property(Int, depends_on='active_dims')

    @cached_property
    def _get_n_dims(self):
        '''Number of currently active dimensions'''
        return len(self.active_dims)

    # number of elements in each direction
    # @todo: rename to n_faces
    shape = Tuple(int, int, int, label='Elements')

    def _shape_default(self):
        '''Number of elements in each direction'''
        return (1, 0, 0)

    n_act_nodes = Property(Array, depends_on='shape, active_dims')

    @cached_property
    def _get_n_act_nodes(self):
        '''Number of active nodes respecting the active_dim.

        shape + 1 nodes in every active direction, 1 node in inactive ones.
        '''
        act_idx = ones((3, ), int)
        shape = array(list(self.shape), dtype=int)
        act_idx[self.dim_indices] += shape[self.dim_indices]
        return act_idx

    # total number of nodes of the system grid
    n_nodes = Property(Int, depends_on='shape, active_dims')

    @cached_property
    def _get_n_nodes(self):
        '''Number of nodes used for the geometry approximation'''
        return reduce(lambda x, y: x * y, self.n_act_nodes)

    enum_nodes = Property(Array, depends_on='shape,active_dims')

    @cached_property
    def _get_enum_nodes(self):
        '''
        Returns an array of element numbers respecting the grid structure
        (the nodes are numbered first in x-direction, then in y-direction and
        last in z-direction)
        '''
        return arange(self.n_nodes).reshape(tuple(self.n_act_nodes))

    grid = Property(Array, depends_on='shape,active_dims,x_mins.+,x_maxs.+')

    @cached_property
    def _get_grid(self):
        '''
        slice(start,stop,step) with step of type 'complex' leads to that number of divisions 
        in that direction including 'stop' (see numpy: 'mgrid')
        '''
        slices = [
            slice(x_min, x_max, complex(0, n_n)) for x_min, x_max, n_n in zip(
                self.x_mins, self.x_maxs, self.n_act_nodes)
        ]
        return mgrid[tuple(slices)]

    #-------------------------------------------------------------------------
    # Visualization pipelines
    #-------------------------------------------------------------------------


#     mvp_mgrid_geo = Trait(MVPolyData)
#
#     def _mvp_mgrid_geo_default(self):
#         return MVPolyData(name='Mesh geomeetry',
#                           points=self._get_points,
#                           lines=self._get_lines,
#                           polys=self._get_faces,
#                           scalars=self._get_random_scalars
#                           )
#
#     mvp_mgrid_labels = Trait(MVPointLabels)
#
#     def _mvp_mgrid_labels_default(self):
#         return MVPointLabels(name='Mesh numbers',
#                              points=self._get_points,
#                              scalars=self._get_random_scalars,
#                              vectors=self._get_points)

    changed = Button('Draw')

    @on_trait_change('changed')
    def redraw(self):
        '''Redraw the visualization pipelines when 'Draw' is pressed.

        NOTE(review): mvp_mgrid_geo / mvp_mgrid_labels are commented out
        above, so this raises AttributeError unless a subclass or the caller
        defines them - confirm before relying on the 'Draw' button.
        '''
        self.mvp_mgrid_geo.redraw()
        self.mvp_mgrid_labels.redraw('label_scalars')

    def _get_points(self):
        '''
        Reshape the grid into a column.
        '''
        return c_[tuple([self.grid[i].flatten() for i in range(3)])]

    def _get_n_lines(self):
        '''Return the number of line elements.

        Same counting as _get_n_faces: `shape` elements in every active
        direction and 1 in the inactive ones. (Indexing the `shape` tuple
        directly with the index array would raise TypeError, so it is
        converted to an array first.)
        '''
        act_idx = ones((3, ), int)
        shape = array(self.shape, dtype=int)
        act_idx[self.dim_indices] = shape[self.dim_indices]
        return reduce(lambda x, y: x * y, act_idx)

    def _get_lines(self):
        '''
        Only return data if n_dims = 1
        '''
        if self.n_dims != 1:
            return array([], int)
        #
        # Get the list of all base nodes (every node except the last one in
        # the active direction)
        #
        tidx = ones((3, ), dtype='int_')
        tidx[self.dim_indices] = -1
        slices = tuple([slice(0, idx) for idx in tidx])
        base_node_list = self.enum_nodes[slices].flatten()
        #
        # Get the node map within the line
        #
        ijk_arr = zeros((3, 2), dtype=int)
        ijk_arr[self.dim_indices[0]] = [0, 1]
        offsets = self.enum_nodes[ijk_arr[0], ijk_arr[1], ijk_arr[2]]
        #
        # Setup and fill the array with line connectivities
        #
        n_lines = self._get_n_lines()
        lines = zeros((n_lines, 2), dtype='int_')
        for n_idx, base_node in enumerate(base_node_list):
            lines[n_idx, :] = offsets + base_node
        return lines

    def _get_n_faces(self):
        '''Return the number of faces.

        The number is determined by putting 1 into inactive dimensions and 
        shape into the active dimensions. 
        '''
        act_idx = ones((3, ), int)
        shape = array(self.shape, dtype=int)
        act_idx[self.dim_indices] = shape[self.dim_indices]
        return reduce(lambda x, y: x * y, act_idx)

    def _get_faces(self):
        '''
        Only return data of n_dims = 2.
        '''
        if self.n_dims != 2:
            return array([], int)
        #
        # get the slices extracting all corner nodes with
        # the smallest node number within the element
        #
        tidx = ones((3, ), dtype='int_')
        tidx[self.dim_indices] = -1
        slices = tuple([slice(0, idx) for idx in tidx])
        base_node_list = self.enum_nodes[slices].flatten()
        #
        # get the node map within the face
        #
        ijk_arr = zeros((3, 4), dtype=int)
        ijk_arr[self.dim_indices[0]] = [0, 0, 1, 1]
        ijk_arr[self.dim_indices[1]] = [0, 1, 1, 0]
        offsets = self.enum_nodes[ijk_arr[0], ijk_arr[1], ijk_arr[2]]
        #
        # setup and fill the array with line connectivities
        #
        n_faces = self._get_n_faces()
        faces = zeros((n_faces, 4), dtype='int_')
        for n_idx, base_node in enumerate(base_node_list):
            faces[n_idx, :] = offsets + base_node
        return faces

    def _get_volumes(self):
        '''
        Only return data if ndims = 3

        Returns an (n_elements, 8) array holding the node numbers of the 8
        corners of each hexahedral element.
        '''
        if self.n_dims != 3:
            return array([], int)

        tidx = ones((3, ), dtype='int_')
        tidx[self.dim_indices] = -1
        slices = tuple([slice(0, idx) for idx in tidx])

        en = self.enum_nodes
        offsets = array([
            en[0, 0, 0], en[0, 1, 0], en[1, 1, 0], en[1, 0, 0], en[0, 0, 1],
            en[0, 1, 1], en[1, 1, 1], en[1, 0, 1]
        ],
                        dtype='int_')
        base_node_list = self.enum_nodes[slices].flatten()

        n_volumes = self._get_n_faces()
        volumes = zeros((n_volumes, 8), dtype='int_')
        # Row i holds element i; the base node number is only the offset
        # origin (it is larger than the element count and must not be used as
        # the row index).
        for n_idx, base_node in enumerate(base_node_list):
            volumes[n_idx, :] = offsets + base_node

        return volumes

    # Identifiers
    var = Str('dummy')
    idx = Int(0)

    def _get_random_scalars(self):
        return random.weibull(1, size=self.n_nodes)

    traits_view = View(HSplit(
        Group(
            Item('changed', show_label=False),
            Item('active_dims@',
                 editor=CheckListEditor(values=['x', 'y', 'z'], cols=3)),
            Item('x_mins@', resizable=False),
            Item('x_maxs@'),
            Item('shape@'),
        ), ),
                       resizable=True)
Example #6
0
class HeadViewController(HasTraits):
    """
    Buttons that move the camera of ``scene`` to preset head views
    (front, left, right, top).

    Parameters
    ----------
    system : 'RAS' | 'ALS' | 'ARI'
        Coordinate system, given as the initials of the directions that the
        x, y and z axes point to (Anterior, Right, Left, Superior,
        Inferior). It selects which azimuth/elevation/roll each preset uses.
    """
    system = Enum("RAS",
                  "ALS",
                  "ARI",
                  desc="Coordinate system: directions of "
                  "the x, y, and z axis.")

    right = Button()
    front = Button()
    left = Button()
    top = Button()

    # Kept in sync with the camera's parallel_scale once the scene is ready.
    scale = Float(0.16)

    scene = Instance(MlabSceneModel)

    view = View(
        VGrid(
            '0', 'top', '0',
            Item('scale', label='Scale', show_label=True),
            'right', 'front', 'left',
            show_labels=False,
            columns=4,
        )
    )

    @on_trait_change('scene.activated')
    def _init_view(self):
        """Switch to parallel projection and tie ``scale`` to the camera."""
        self.scene.parallel_projection = True

        # 'scene.activated' can fire more than once; the camera is only
        # wired up when a renderer exists.
        if not self.scene.renderer:
            return
        self.sync_trait('scale', self.scene.camera, 'parallel_scale')
        # Changing the scale does not trigger a re-render by itself.
        self.on_trait_change(self.scene.render, 'scale')

    @on_trait_change('top,left,right,front')
    def on_set_view(self, view, _):
        """Move the camera to the preset named by the fired button."""
        if self.scene is None:
            return

        # (azimuth, elevation, roll) per coordinate system and view.
        presets = {
            'ALS': {'front': (0, 90, -90),
                    'left': (90, 90, 180),
                    'right': (-90, 90, 0),
                    'top': (0, 0, -90)},
            'RAS': {'front': (90, 90, 180),
                    'left': (180, 90, 90),
                    'right': (0, 90, 270),
                    'top': (90, 0, 180)},
            'ARI': {'front': (0, 90, 90),
                    'left': (-90, 90, 180),
                    'right': (90, 90, 0),
                    'top': (0, 180, 90)},
        }

        system = self.system
        if system not in presets:
            raise ValueError("Invalid system: %r" % system)

        angles = presets[system].get(view)
        if angles is None:
            raise ValueError("Invalid view: %r" % view)

        azimuth, elevation, roll = angles
        self.scene.mlab.view(azimuth=azimuth,
                             elevation=elevation,
                             roll=roll,
                             distance=None,
                             reset_roll=True,
                             figure=self.scene.mayavi_scene)
Example #7
0
class PossionDemo(HasTraits):
    """
    Interactive image-blending demo.

    Two images are shown side by side, each with an ImageMaskDrawer.
    - A mask drawn on either image is copied to the other one.
    - "Mix" passes the left image, the right image, the left mask and the
      right mask's offset to ``possion_mix`` (presumably Poisson blending).
      It then shows the result in place of the right image.
    - "Clear" erases both masks.
    """
    left_mask = Instance(ImageMaskDrawer)
    right_mask = Instance(ImageMaskDrawer)
    left_file = File(path.join(FOLDER, "vinci_src.png"))
    right_file = File(path.join(FOLDER, "vinci_target.png"))
    figure = Instance(Figure, ())
    load_button = Button("Load")
    clear_button = Button("Clear")
    mix_button = Button("Mix")

    view = View(
        Item("left_file"),
        Item("right_file"),
        VGroup(
            HGroup(
                "load_button", "clear_button", "mix_button",
                show_labels=False
            ),
            Group(
                Item("figure", editor=MPLFigureEditor(toolbar=False)),
                show_labels=False,
                orientation='horizontal'
            )
        ),
        width=800,
        height=600,
        resizable=True,
        title="Possion Demo")

    def __init__(self, **kw):
        super(PossionDemo, self).__init__(**kw)
        self.left_axe = self.figure.add_subplot(121)
        self.right_axe = self.figure.add_subplot(122)

    @staticmethod
    def _read_image(filename):
        """Read a colour image with OpenCV; raise IOError if unreadable.

        ``cv2.imread`` signals failure by returning None rather than
        raising, which would otherwise surface later as an AttributeError.
        """
        pic = cv2.imread(filename, 1)
        if pic is None:
            raise IOError("Cannot read image file: %r" % (filename,))
        return pic

    def _images_loaded(self):
        """True once load_images() has created both mask drawers."""
        return self.left_mask is not None and self.right_mask is not None

    def load_images(self):
        """Load both images, show them and attach a mask drawer to each."""
        self.left_pic = self._read_image(self.left_file)
        self.right_pic = self._read_image(self.right_file)

        # Both masks share the larger height/width of the two images so
        # that a mask can be copied verbatim between them.
        shape = [max(v1, v2) for v1, v2 in
                 zip(self.left_pic.shape[:2], self.right_pic.shape[:2])]

        # [::-1, :, ::-1] flips rows (to match origin="lower") and converts
        # OpenCV's BGR channel order to RGB.
        self.left_img = self.left_axe.imshow(self.left_pic[::-1, :, ::-1], origin="lower")
        self.left_axe.axis("off")
        self.left_mask = ImageMaskDrawer(self.left_axe, self.left_img, mask_shape=shape)

        self.right_img = self.right_axe.imshow(self.right_pic[::-1, :, ::-1], origin="lower")
        self.right_axe.axis("off")
        self.right_mask = ImageMaskDrawer(self.right_axe, self.right_img, mask_shape=shape)

        self.left_mask.on_trait_change(self.mask_changed, "drawed")
        self.right_mask.on_trait_change(self.mask_changed, "drawed")
        self.figure.canvas.draw_idle()

    def mask_changed(self, obj, name, new):
        """Copy the mask just drawn on one image onto the other image."""
        if obj is self.left_mask:
            src, target = self.left_mask, self.right_mask
        else:
            src, target = self.right_mask, self.left_mask
        target.array.fill(0)
        target.array[:, :] = src.array[:, :]
        target.update()
        self.figure.canvas.draw()

    def _load_button_fired(self):
        self.load_images()

    def _mix_button_fired(self):
        # Nothing to mix before the images have been loaded.
        if not self._images_loaded():
            return

        lh, lw, _ = self.left_pic.shape
        rh, rw, _ = self.right_pic.shape

        # Crop the (flipped) mask to the left image's size; use the last
        # channel as the mask.
        left_mask = self.left_mask.array[-lh:, :lw, -1]
        if np.all(left_mask == 0):
            return

        dx, dy = self.right_mask.get_mask_offset()
        # The y offset is converted from the flipped display coordinates
        # (origin="lower") back to image row coordinates.
        result = possion_mix(self.left_pic, self.right_pic, left_mask, (dx, rh - lh - dy))

        self.right_img.set_data(result[::-1, :, ::-1])
        self.right_mask.mask_img.set_visible(False)
        self.figure.canvas.draw()

    def _clear_button_fired(self):
        # Masks do not exist until the images have been loaded.
        if not self._images_loaded():
            return
        self.left_mask.array.fill(0)
        self.right_mask.array.fill(0)
        self.left_mask.update()
        self.right_mask.update()
        self.figure.canvas.draw()
Example #8
0
class Spylot(HasTraits):
    """
    This class represents the spylot application state.

    It holds a list of spectra (``specs``), the index of the one currently
    shown (``currspeci``), and a Chaco ``Plot`` of its flux and error. The
    plot also carries overlays for:
    - major/minor line lists,
    - a continuum line,
    - mouse coordinates,
    - highlighted spectral features.
    """

    defaultlines = 'galaxy'  #can be 'galaxy' or 'stellar' or None
    """
    This class attribute sets the default line lists to use - 'galaxy' or
    'stellar'
    """

    plot = Instance(Plot)
    histplot = Bool(True)

    majorlineeditor = LineListEditor
    minorlineeditor = LineListEditor
    linenamelist = Property(Tuple)
    showmajorlines = Bool(True)
    showminorlines = Bool(False)

    labels = List(DataLabel)
    showlabels = Bool(False)
    editmajor = Button('Edit Major')
    editminor = Button('Edit Minor')

    specs = List(Instance(spec.Spectrum))
    currspeci = Int
    currspecip1 = Property(depends_on='currspeci')
    lowerspecip1 = Property(depends_on='currspeci')
    upperspecip1 = Property(depends_on='currspeci')
    currspec = Property
    lastspec = Instance(spec.Spectrum)
    z = Float
    lowerz = Float(0.0)
    upperz = Float(1.0)
    zviewtype = Enum('redshift', 'velocity')
    zview = Property(Float, depends_on='z,zviewtype')
    uzview = Property(Float, depends_on='upperz,zviewtype')
    lzview = Property(Float, depends_on='lowerz,zviewtype')
    coarserz = Button('Coarser')
    finerz = Button('Finer')
    _zql, _zqh = min(spec.Spectrum._zqmap.keys()), max(
        spec.Spectrum._zqmap.keys())
    zqual = Range(_zql, _zqh, -1)

    spechanged = Event
    specleft = Button('<')
    specright = Button('>')

    scaleerr = Bool(False)
    scaleerrfraclow = Range(0.0, 1.0, 1.0)
    scaleerrfrachigh = Float(1.0)
    fluxformat = Button('Flux Line Format')
    errformat = Button('Error Line Format')
    showcoords = Bool(False)
    showgrid = Bool(True)

    dosmoothing = Bool(False)
    smoothing = Float(3)

    contsub = Button('Fit Continuum...')
    contclear = Button('Clear Continuum')
    showcont = Bool(False)
    contformat = Button('Continuum Line Format')

    _titlestr = Str('Spectrum 0/0')
    _oldunit = Str('')
    maintool = Tuple(Instance(Interactor), Instance(AbstractOverlay))

    featureselmode = Enum([
        'No Selection', 'Click Select', 'Range Select', 'Base Select',
        'Click Delete'
    ])
    editfeatures = Button('Features...')
    showfeatures = Bool(True)
    featurelocsmooth = Float(None)
    featurelocsize = Int(200)
    featurelist = List(Instance(spec.SpectralFeature))

    selectedfeatureindex = Int
    deletefeature = Button('Delete')
    idfeature = Button('Identify')
    recalcfeature = Button('Recalculate')
    clearfeatures = Button('Clear')

    delcurrspec = Button('Delete Current')
    saveloadfile = File(filter=['*.specs'])
    savespeclist = Button('Save Spectra')
    loadspeclist = Button('Load Spectra')
    loadaddfile = File(filter=['*.fits'])
    loadaddspec = Button('Add Spectrum')
    loadaddspectype = Enum('wcs', 'deimos', 'astropysics')

    titlegroup = HGroup(
        Item('specleft', show_label=False, enabled_when='currspeci>0'), spring,
        Label('Spectrum #', height=0.5),
        Item('currspecip1',
             show_label=False,
             editor=RangeEditor(low_name='lowerspecip1',
                                high_name='upperspecip1',
                                mode='spinner')),
        Item('_titlestr', style='readonly', show_label=False), spring,
        Item('specright',
             show_label=False,
             enabled_when='currspeci<(len(specs)-1)'))

    # 'Load Spectra' reads the *.specs file in saveloadfile, while
    # 'Add Spectrum' reads the *.fits file in loadaddfile. Each button is
    # therefore enabled only when its own file exists.
    speclistgroup = HGroup(
        Label('Spectrum List:'), spring,
        Item('delcurrspec', show_label=False, enabled_when='len(specs)>1'),
        Item('saveloadfile', show_label=False),
        Item('savespeclist', show_label=False),
        Item('loadspeclist',
             show_label=False,
             enabled_when='os.path.exists(saveloadfile)'),
        Item('loadaddfile', show_label=False),
        Item('loadaddspec',
             show_label=False,
             enabled_when='os.path.exists(loadaddfile)'),
        Item('loadaddspectype', show_label=False), spring)

    plotformatgroup = HGroup(
        spring, Item('fluxformat', show_label=False),
        Item('errformat', show_label=False),
        Item('scaleerr', label='Scale Error?'),
        Item('scaleerrfraclow',
             label='Lower',
             enabled_when='scaleerr',
             editor=TextEditor(evaluate=float)),
        Item('scaleerrfrachigh', label='Upper', enabled_when='scaleerr'),
        Item('showgrid', label='Grid?'), Item('showcoords', label='Coords?'),
        spring)

    featuregroup = HGroup(spring, Item('showmajorlines', label='Show major?'),
                          Item('editmajor', show_label=False),
                          Item('showlabels', label='Labels?'),
                          Item('showminorlines', label='Show minor?'),
                          Item('editminor', show_label=False), spring,
                          Item('editfeatures', show_label=False),
                          Item('featureselmode', show_label=False), spring)

    continuumgroup = HGroup(
        spring, Item('contsub', show_label=False),
        Item('contclear', show_label=False),
        Item('showcont', label='Continuum line?'),
        Item('contformat', show_label=False),
        Item('dosmoothing', label='Smooth?'),
        Item('smoothing', show_label=False, enabled_when='dosmoothing'),
        spring)

    zgroup = VGroup(
        HGroup(
            Item('zviewtype', show_label=False),
            Item('zview',
                 editor=RangeEditor(low_name='lzview',
                                    high_name='uzview',
                                    format='%5.4g ',
                                    mode='slider'),
                 show_label=False,
                 springy=True)),
        HGroup(
            Item('lzview', show_label=False), Item('coarserz',
                                                   show_label=False), spring,
            Item('zqual',
                 style='custom',
                 label='Z quality',
                 editor=RangeEditor(cols=_zqh - _zql + 1, low=_zql,
                                    high=_zqh)), spring,
            Item('finerz', show_label=False), Item('uzview',
                                                   show_label=False)))

    features_view = View(VGroup(
        HGroup(Item('showfeatures', label='Show'),
               Item('featurelocsmooth', label='Locator Smoothing'),
               Item('featurelocsize', label='Locator Window size')),
        Item('featurelist',
             editor=TabularEditor(adapter=FeaturesAdapter(),
                                  selected_row='selectedfeatureindex'),
             show_label=False),
        HGroup(
            Item(
                'deletefeature',
                show_label=False,
                enabled_when='len(featurelist)>0 and selectedfeatureindex>=0'),
            Item(
                'idfeature',
                show_label=False,
                enabled_when='len(featurelist)>0 and selectedfeatureindex>=0'),
            Item(
                'recalcfeature',
                show_label=False,
                enabled_when='len(featurelist)>0 and selectedfeatureindex>=0'),
            Item('clearfeatures',
                 show_label=False,
                 enabled_when='len(featurelist)>0'))),
                         resizable=True,
                         title='Spylot Features')

    traits_view = View(VGroup(
        Include('titlegroup'), Include('speclistgroup'),
        Include('plotformatgroup'), Include('featuregroup'),
        Include('continuumgroup'),
        Item('plot', editor=ComponentEditor(), show_label=False, width=768),
        Include('zgroup')),
                       resizable=True,
                       title='Spectrum Plotter',
                       handler=SpylotHandler(),
                       key_bindings=spylotkeybindings)

    def __init__(self, specs, **traits):
        """
        :param specs:
            The spectra/um to be analyzed as a sequence or a single object.
            None is accepted and gives an empty spectrum list.
        :type specs: :class:`astropysics.spec.Spectrum`

        kwargs are passed in as additional traits.
        """

        #pd = ArrayPlotData(x=[1],x0=[1],flux=[1],err=[1]) #reset by spechanged event
        pd = ArrayPlotData(x=[1], flux=[1],
                           err=[1])  #reset by spechanged event
        pd.set_data('majorx', [1, 1])  #reset by majorlines change
        pd.set_data('majory', [0, 1])  #reset by majorlines change
        pd.set_data('minorx', [1, 1])  #reset by minorlines change
        pd.set_data('minory', [0, 1])  #reset by minorlines change
        pd.set_data('continuum', [0, 0])  #reset

        self.plot = plot = Plot(pd, resizeable='hv')
        # insert dedicated draw layers just before the 'plot' layer
        ploi = plot.draw_order.index('plot')
        plot.draw_order[ploi:ploi] = [
            'continuum', 'err', 'flux', 'annotations'
        ]
        plot.plot(('x', 'flux'),
                  name='flux',
                  type='line',
                  line_style='solid',
                  color='blue',
                  draw_layer='flux',
                  unified_draw=True)
        plot.plot(('x', 'err'),
                  name='err',
                  type='line',
                  line_style='dash',
                  color='green',
                  draw_layer='err',
                  unified_draw=True)

        topmapper = LinearMapper(range=DataRange1D())
        plot.x_mapper.range.on_trait_change(self._update_upperaxis_range,
                                            'updated')
        plot.x_mapper.on_trait_change(self._update_upperaxis_screen, 'updated')
        self.upperaxis = PlotAxis(plot, orientation='top', mapper=topmapper)
        plot.overlays.append(self.upperaxis)

        # keep both the original error mapper and a 0..1 mapper so the error
        # line can be switched to a scaled display
        self.errmapperfixed = plot.plots['err'][0].value_mapper
        self.errmapperscaled = LinearMapper(range=DataRange1D(high=1.0, low=0))
        plot.x_mapper.on_trait_change(self._update_errmapper_screen, 'updated')

        plot.padding_top = 30  #default is a bit much
        plot.padding_left = 70  #default is a bit too little

        majorlineplot = plot.plot(('majorx', 'majory'),
                                  name='majorlineplot',
                                  type='line',
                                  line_style='dash',
                                  color='red')[0]
        majorlineplot.set(draw_layer='annotations', unified_draw=True)
        majorlineplot.value_mapper = LinearMapper(
            range=DataRange1D(high=0.9, low=0.1))
        majorlineplot.visible = self.showmajorlines
        del plot.x_mapper.range.sources[
            -1]  #remove the line plot from the x_mapper sources so scaling is only on the spectrum
        self.majorlineeditor = LineListEditor(lineplot=majorlineplot)

        minorlineplot = plot.plot(('minorx', 'minory'),
                                  name='minorlineplot',
                                  type='line',
                                  line_style='dot',
                                  color='red')[0]
        minorlineplot.set(draw_layer='annotations', unified_draw=True)
        minorlineplot.value_mapper = LinearMapper(
            range=DataRange1D(high=0.9, low=0.1))
        minorlineplot.visible = self.showminorlines
        del plot.x_mapper.range.sources[
            -1]  #remove the line plot from the x_mapper sources so scaling is only on the spectrum
        self.minorlineeditor = LineListEditor(lineplot=minorlineplot)

        self.contline = plot.plot(('x', 'continuum'),
                                  name='continuum',
                                  type='line',
                                  line_style='solid',
                                  color='black')[0]
        self.contline.set(draw_layer='continuum', unified_draw=True)
        self.contline.visible = self.showcont
        #        idat = ArrayDataSource((0.0,1.0))
        #        vdat = ArrayDataSource((0.0,0.0))
        #        self.zeroline = LinePlot(index=idat,value=vdat,line_style='solid',color='black')
        #        self.zeroline.index_mapper = LinearMapper(range=DataRange1D(high=0.9,low=0.1))
        #        self.zeroline.value_mapper = self.plot.y_mapper
        #        self.zeroline.visible = self.showcont
        #        self.plot.add(self.zeroline)

        if Spylot.defaultlines:
            defaultlines = _get_default_lines(Spylot.defaultlines)
            self.majorlineeditor.candidates = defaultlines[0]
            self.majorlineeditor.selectednames = defaultlines[1]
            self.minorlineeditor.candidates = defaultlines[0]
            self.minorlineeditor.selectednames = defaultlines[2]

        plot.tools.append(PanTool(plot))
        plot.tools.append(ZoomTool(plot))
        plot.overlays.append(plot.tools[-1])

        self.coordtext = TextBoxOverlay(component=plot, align='ul')
        plot.overlays.append(self.coordtext)
        plot.tools.append(MouseMoveReporter(overlay=self.coordtext, plot=plot))
        self.coordtext.visible = self.showcoords

        self.linehighlighter = lho = LineHighlighterOverlay(component=plot)
        lho.visible = self.showfeatures
        plot.overlays.append(lho)

        if specs is None:
            specs = []
        elif isinstance(specs, spec.Spectrum):
            specs = [specs]
        self.specs = specs

        self.spechanged = True

        self.on_trait_change(self._majorlines_changed,
                             'majorlineeditor.selectedobjs')
        self.on_trait_change(self._minorlines_changed,
                             'minorlineeditor.selectedobjs')

        super(Spylot, self).__init__(**traits)

    def __del__(self):
        # Best-effort teardown: convert the current spectrum's feature
        # collection to a plain list. The reason is not visible here;
        # presumably it detaches it from trait-list machinery (verify).
        # Missing spectra or attributes are deliberately ignored.
        try:
            self.currspec._features = list(self.currspec._features)
        except (AttributeError, IndexError):
            pass
Example #9
0
class DofCellGrid(SDomain):
    '''
    Assign degree-of-freedom (DOF) numbers to the nodes of a cell grid.

    Every node that appears in ``cell_grid.cell_node_map`` receives
    ``n_nodal_dofs`` consecutive DOF numbers, shifted by ``dof_offset``.
    Nodes not referenced by any cell are marked with -1.
    '''
    implements(ICellArraySource)

    cell_grid = Instance(CellGrid)

    get_cell_point_X_arr = DelegatesTo('cell_grid')
    get_cell_mvpoints = DelegatesTo('cell_grid')
    cell_node_map = DelegatesTo('cell_grid')
    get_cell_offset = DelegatesTo('cell_grid')

    # offset of dof within domain list
    #
    dof_offset = Int(0)

    # number of degrees of freedom in a single node
    #
    n_nodal_dofs = Int(3)
    #-------------------------------------------------------------------------
    # Generation methods for geometry and index maps
    #-------------------------------------------------------------------------
    n_dofs = Property(depends_on='cell_grid.shape,n_nodal_dofs,dof_offset')

    def _get_n_dofs(self):
        '''
        Get the total number of DOFs: the number of distinct nodes
        referenced by the cells times ``n_nodal_dofs``.
        '''
        n_unique_nodes = len(unique(self.cell_node_map.flatten()))
        return n_unique_nodes * self.n_nodal_dofs

    dofs = Property(depends_on='cell_grid.shape,n_nodal_dofs,dof_offset')

    @cached_property
    def _get_dofs(self):
        '''
        Construct the node-to-DOF map.

        Returns an integer array of shape (point_grid_size, n_nodal_dofs).
        Rows of nodes referenced by the cells are numbered consecutively,
        starting at ``dof_offset``. Rows of unused nodes hold -1.
        '''
        unique_cell_nodes = unique(self.cell_node_map.flatten())
        n_unique_nodes = len(unique_cell_nodes)

        n_nodal_dofs = self.n_nodal_dofs
        n_nodes = self.cell_grid.point_grid_size
        node_dof_array = repeat(-1, n_nodes * n_nodal_dofs).reshape(
            n_nodes, n_nodal_dofs)

        # Enumerate the DOFs of the used nodes row by row, e.g. for
        # n_nodal_dofs = 2 and dof_offset = 0:
        #
        # A = array( [[ 0, 1 ],
        #             [ 2, 3 ],
        #             [ 4, 5 ]] );
        #
        # The offset is added to the enumerated rows only. Shifting the
        # whole array would turn the -1 marker of unused nodes into
        # dof_offset - 1, a valid DOF number of a preceding domain.
        node_dof_array[index_exp[unique_cell_nodes]] = \
            arange(n_unique_nodes * n_nodal_dofs).reshape(
                n_unique_nodes, n_nodal_dofs) + self.dof_offset
        return node_dof_array

    def _get_doffed_nodes(self):
        '''
        Get the indices of nodes containing DOFs, i.e. nodes referenced
        by at least one cell (sorted ascending).
        '''
        unique_cell_nodes = unique(self.cell_node_map.flatten())

        n_nodes = self.cell_grid.point_grid_size
        doffed_nodes = repeat(-1, n_nodes)

        doffed_nodes[index_exp[unique_cell_nodes]] = 1
        return where(doffed_nodes > 0)[0]

    #-----------------------------------------------------------------
    # Elementwise-representation of dofs
    #-----------------------------------------------------------------

    # All maps below are derived from ``dofs`` and therefore also depend
    # on ``dof_offset``.
    cell_dof_map = Property(
        depends_on='cell_grid.shape,n_nodal_dofs,dof_offset')

    def _get_cell_dof_map(self):
        '''DOF numbers of the nodes of each cell.'''
        return self.dofs[index_exp[self.cell_grid.cell_node_map]]

    cell_grid_dof_map = Property(
        depends_on='cell_grid.shape,n_nodal_dofs,dof_offset')

    def _get_cell_grid_dof_map(self):
        '''DOF numbers indexed through ``cell_grid.cell_grid_node_map``.'''
        return self.dofs[index_exp[self.cell_grid.cell_grid_node_map]]

    def get_cell_dofs(self, cell_idx):
        '''Return the DOF numbers of the cell(s) selected by ``cell_idx``.'''
        return self.cell_dof_map[cell_idx]

    elem_dof_map = Property(
        depends_on='cell_grid.shape,n_nodal_dofs,dof_offset')

    @cached_property
    def _get_elem_dof_map(self):
        '''
        Flattened per-cell DOF map: one row per cell, holding the DOFs of
        all its nodes (n_cell_nodes * n_nodal_dofs entries).
        '''
        el_dof_map = copy(self.cell_dof_map)
        tot_shape = el_dof_map.shape[0]
        n_entries = el_dof_map.shape[1] * el_dof_map.shape[2]
        return el_dof_map.reshape(tot_shape, n_entries)

    def __getitem__(self, idx):
        '''High level access and slicing to the cells within the grid.

        Returns a ``DofGridSlice`` that combines this DOF grid with the
        requested grid slice ``idx``.
        '''
        return DofGridSlice(dof_grid=self, grid_slice=idx)

    #-----------------------------------------------------------------
    # Spatial queries for dofs
    #-----------------------------------------------------------------

    def _get_dofs_for_nodes(self, nodes):
        '''Get the dof numbers and associated coordinates
        given the array of nodes.

        Nodes that carry no DOFs (not referenced by any cell) are skipped.
        '''
        doffed_nodes = self._get_doffed_nodes()
        intersect_nodes = intersect1d(nodes, doffed_nodes, assume_unique=False)
        return (self.dofs[index_exp[intersect_nodes]],
                self.cell_grid.point_X_arr[index_exp[intersect_nodes]])

    def get_boundary_dofs(self):
        '''Get the boundary dofs and the associated coordinates
        '''
        nodes = [
            self.cell_grid.point_idx_grid[s]
            for s in self.cell_grid.boundary_slices
        ]
        dofs, coords = [], []
        for n in nodes:
            d, c = self._get_dofs_for_nodes(n)
            dofs.append(d)
            coords.append(c)
        return (vstack(dofs), vstack(coords))

    def get_all_dofs(self):
        nodes = self.cell_grid.point_idx_grid[...]
        return self._get_dofs_for_nodes(nodes)

    def get_left_dofs(self):
        nodes = self.cell_grid.point_idx_grid[0, ...]
        return self._get_dofs_for_nodes(nodes)

    def get_right_dofs(self):
        nodes = self.cell_grid.point_idx_grid[-1, ...]
        return self._get_dofs_for_nodes(nodes)

    def get_top_dofs(self):
        nodes = self.cell_grid.point_idx_grid[:, -1, ...]
        return self._get_dofs_for_nodes(nodes)

    def get_bottom_dofs(self):
        nodes = self.cell_grid.point_idx_grid[:, 0, ...]
        return self._get_dofs_for_nodes(nodes)

    def get_front_dofs(self):
        nodes = self.cell_grid.point_idx_grid[:, :, -1]
        return self._get_dofs_for_nodes(nodes)

    def get_back_dofs(self):
        nodes = self.cell_grid.point_idx_grid[:, :, 0]
        return self._get_dofs_for_nodes(nodes)

    def get_bottom_left_dofs(self):
        nodes = self.cell_grid.point_idx_grid[0, 0, ...]
        return self._get_dofs_for_nodes(nodes)

    def get_bottom_front_dofs(self):
        nodes = self.cell_grid.point_idx_grid[:, 0, -1]
        return self._get_dofs_for_nodes(nodes)

    def get_bottom_back_dofs(self):
        nodes = self.cell_grid.point_idx_grid[:, 0, 0]
        return self._get_dofs_for_nodes(nodes)

    def get_top_left_dofs(self):
        nodes = self.cell_grid.point_idx_grid[0, -1, ...]
        return self._get_dofs_for_nodes(nodes)

    def get_bottom_right_dofs(self):
        nodes = self.cell_grid.point_idx_grid[-1, 0, ...]
        return self._get_dofs_for_nodes(nodes)

    def get_top_right_dofs(self):
        nodes = self.cell_grid.point_idx_grid[-1, -1, ...]
        return self._get_dofs_for_nodes(nodes)

    def _middle_point_idx(self, axis, method_name):
        '''Return the middle index of ``point_idx_grid`` along ``axis``.

        Only defined for an odd number of points along ``axis``. Otherwise
        an error message naming ``method_name`` is printed and None is
        returned, so the calling ``get_*_middle_*`` method returns None.
        '''
        n_points = self.cell_grid.point_idx_grid.shape[axis]
        if n_points % 2 == 1:
            # floor division keeps the index an int under true division too
            return n_points // 2
        print('Error in %s:'
              ' the method is only defined for an odd number of dofs in %s-direction'
              % (method_name, 'xyz'[axis]))
        return None

    def get_bottom_middle_dofs(self):
        mid_x = self._middle_point_idx(0, 'get_bottom_middle_dofs')
        if mid_x is None:
            return None
        nodes = self.cell_grid.point_idx_grid[mid_x, 0, ...]
        return self._get_dofs_for_nodes(nodes)

    def get_top_middle_dofs(self):
        mid_x = self._middle_point_idx(0, 'get_top_middle_dofs')
        if mid_x is None:
            return None
        nodes = self.cell_grid.point_idx_grid[mid_x, -1, ...]
        return self._get_dofs_for_nodes(nodes)

    def get_left_middle_dofs(self):
        mid_y = self._middle_point_idx(1, 'get_left_middle_dofs')
        if mid_y is None:
            return None
        nodes = self.cell_grid.point_idx_grid[0, mid_y, ...]
        return self._get_dofs_for_nodes(nodes)

    def get_right_middle_dofs(self):
        mid_y = self._middle_point_idx(1, 'get_right_middle_dofs')
        if mid_y is None:
            return None
        nodes = self.cell_grid.point_idx_grid[-1, mid_y, ...]
        return self._get_dofs_for_nodes(nodes)

    def get_left_front_bottom_dof(self):
        nodes = self.cell_grid.point_idx_grid[0, 0, -1]
        return self._get_dofs_for_nodes(nodes)

    def get_left_front_middle_dof(self):
        mid_y = self._middle_point_idx(1, 'get_left_front_middle_dof')
        if mid_y is None:
            return None
        nodes = self.cell_grid.point_idx_grid[0, mid_y, -1]
        return self._get_dofs_for_nodes(nodes)

    #-----------------------------------------------------------------
    # Visualization related methods
    #-----------------------------------------------------------------

    refresh_button = Button('Draw')

    @on_trait_change('refresh_button')
    def redraw(self):
        '''Redraw the point grid.
        '''
        self.cell_grid.redraw()

    dof_cell_array = Button

    def _dof_cell_array_fired(self):
        '''Open a live CellArray view of the cell-node map.'''
        cell_array = self.cell_grid.cell_node_map
        self.show_array = CellArray(data=cell_array,
                                    cell_view=DofCellView(cell_grid=self))
        self.show_array.current_row = 0
        self.show_array.configure_traits(kind='live')

    #------------------------------------------------------------------
    # UI - related methods
    #------------------------------------------------------------------
    traits_view = View(Item('n_nodal_dofs'),
                       Item('dof_offset'),
                       Item('cell_grid@', show_label=False),
                       Item('refresh_button', show_label=False),
                       Item('dof_cell_array', show_label=False),
                       resizable=True,
                       scrollable=True,
                       height=0.5,
                       width=0.5)
Example #10
0
class ExperimentFactory(Loggable, ConsumerMixin):
    """Create automated runs and add them to an experiment queue.

    ``run_factory`` builds the new runs; settings edited on
    ``queue_factory`` are mirrored onto ``queue`` by ``_update_queue``.
    Pressing the Add button queues a consumable through ConsumerMixin, and
    ``_add_run`` (registered with ``setup_consumer`` in ``__init__``)
    performs the actual insertion.
    """
    db = Any
    run_factory = Instance(AutomatedRunFactory)
    queue_factory = Instance(ExperimentQueueFactory)

    add_button = Button('Add')
    clear_button = Button('Clear')
    save_button = Button('Save')
    edit_mode_button = Button('Edit')
    edit_enabled = DelegatesTo('run_factory')

    auto_increment_id = Bool(False)
    auto_increment_position = Bool(False)

    queue = Instance(ExperimentQueue, ())

    # NOTE(review): ``_extract_device`` is not declared on this class (the
    # trait is ``extract_device``) and _get_ok_add does not read it; confirm
    # whether this dependency is still wanted.
    ok_add = Property(
        depends_on='_mass_spectrometer, _extract_device, _labnumber, _username'
    )

    _username = String
    _mass_spectrometer = String
    extract_device = String
    _labnumber = String

    selected_positions = List
    default_mass_spectrometer = Str

    def __init__(self, *args, **kw):
        super(ExperimentFactory, self).__init__(*args, **kw)
        # _add_run handles each consumable queued by _add_button_fired.
        self.setup_consumer(self._add_run)

    def destroy(self):
        """Clear ConsumerMixin's ``_should_consume`` flag.

        Presumably this stops further add requests from being processed --
        confirm against ConsumerMixin.
        """
        self._should_consume = False

    def set_selected_runs(self, runs):
        """Forward ``runs`` to the run factory's selection."""
        self.run_factory.set_selected_runs(runs)

    def _add_run(self, *args, **kw):
        """Create runs with the run factory and add them to the queue.

        The extract-group counter passed to the run factory is the highest
        ``extract_group`` already in the queue (0 for an empty queue). After
        adding, the run at (index of the last selected run, or the last index
        when nothing is selected) + len(new_runs) is selected, presumably the
        last of the newly added runs.

        Extra positional/keyword arguments supplied by the consumer are
        ignored.
        """
        extract_groups = [ai.extract_group for ai in self.queue.automated_runs]
        eg = max(extract_groups) if extract_groups else 0

        positions = [str(pi.positions[0]) for pi in self.selected_positions]

        new_runs, freq = self.run_factory.new_runs(
            positions=positions,
            auto_increment_position=self.auto_increment_position,
            auto_increment_id=self.auto_increment_id,
            extract_group_cnt=eg)

        q = self.queue
        if q.selected:
            idx = q.automated_runs.index(q.selected[-1])
        else:
            idx = len(q.automated_runs) - 1

        self.queue.add_runs(new_runs, freq)

        idx += len(new_runs)

        with self.run_factory.update_selected_ctx():
            self.queue.select_run_idx(idx)

    #===============================================================================
    # handlers
    #===============================================================================
    def _clear_button_fired(self):
        """Remove the frequency runs from the queue."""
        self.queue.clear_frequency_runs()

    def _add_button_fired(self):
        """Request that new runs be added.

        Instead of rate-limiting the button, a consumable is queued through
        ConsumerMixin.add_consumable and processed by ``_add_run``.
        """
        self.add_consumable(1)

    def _edit_mode_button_fired(self):
        """Toggle the run factory's edit mode."""
        self.run_factory.edit_mode = not self.run_factory.edit_mode

    def _update_end_after(self, new):
        """Handle a change of the run factory's ``end_after`` trait.

        When ``new`` is truthy, ``end_after`` is cleared on every run already
        in the queue; the value is then forwarded to the run factory.
        """
        if new:
            for ai in self.queue.automated_runs:
                ai.end_after = False

        self.run_factory.set_end_after(new)

    @on_trait_change('''queue_factory:[mass_spectrometer,
extract_device, delay_+, tray, username, load_name]''')
    def _update_queue(self, name, new):
        """Mirror a queue-factory setting onto this factory and the queue.

        ``mass_spectrometer``, ``extract_device`` and ``username`` also
        update local state. Every listed setting is copied onto ``queue``,
        and the queue is then marked as changed.
        """
        if name == 'mass_spectrometer':
            self._mass_spectrometer = new
            self.run_factory.set_mass_spectrometer(new)
        elif name == 'extract_device':
            self._set_extract_device(new)
        elif name == 'username':
            self._username = new

        # Keep both queue updates under the same guard; previously
        # ``changed`` was set even when there was no queue.
        if self.queue:
            self.queue.trait_set(**{name: new})
            self.queue.changed = True

    #===============================================================================
    # private
    #===============================================================================
    def _set_extract_device(self, ed):
        """Switch to extract device ``ed``.

        A new run factory is built for the device, its templates and the
        device's remote patterns are loaded, and the device, username and
        mass spectrometer are pushed to the queue.
        """
        self.extract_device = ed
        self.run_factory = self._run_factory_factory()

        self.run_factory.load_templates()

        self.run_factory.remote_patterns = self._get_patterns(ed)
        self.run_factory.load_patterns()

        if self.queue:
            self.queue.set_extract_device(ed)
            self.queue.username = self._username
            self.queue.mass_spectrometer = self._mass_spectrometer

    def _get_patterns(self, ed):
        """Return the pattern names offered by the laser manager for ``ed``.

        Returns an empty list (and logs a debug message) when no matching
        ILaserManager service is available.
        """
        ps = []
        service_name = convert_extract_device(ed)
        man = self.application.get_service(ILaserManager,
                                           'name=="{}"'.format(service_name))
        if man:
            ps = man.get_pattern_names()
        else:
            self.debug('No remote patterns. {} ({}) not available'.format(
                ed, service_name))

        return ps

    #===============================================================================
    # property get/set
    #===============================================================================
    def _get_ok_add(self):
        """Return a truthy value when runs can be added.

        Requires a username, a mass spectrometer other than '',
        'Spectrometer' or LINE_STR, and a labnumber. On success, the
        labnumber (the last operand) is returned.
        """
        return (self._username and
                self._mass_spectrometer not in ('', 'Spectrometer', LINE_STR) and
                self._labnumber)

    #===============================================================================
    #
    #===============================================================================
    def _run_factory_factory(self):
        """Build a run factory suited to the current extract device.

        'Fusions UV' gets a UVAutomatedRunFactory; everything else uses
        AutomatedRunFactory. The factory's ``labnumber`` and ``end_after``
        changes are wired back to this object.
        """
        if self.extract_device == 'Fusions UV':
            klass = UVAutomatedRunFactory
        else:
            klass = AutomatedRunFactory

        rf = klass(db=self.db,
                   application=self.application,
                   extract_device=self.extract_device,
                   mass_spectrometer=self.default_mass_spectrometer)

        rf.load_truncations()
        rf.on_trait_change(lambda x: self.trait_set(_labnumber=x), 'labnumber')
        rf.on_trait_change(self._update_end_after, 'end_after')
        return rf

    #===============================================================================
    # defaults
    #===============================================================================
    def _run_factory_default(self):
        return self._run_factory_factory()

    def _queue_factory_default(self):
        eq = ExperimentQueueFactory(db=self.db)
        return eq

    def _db_changed(self):
        """Propagate a new database handle to both factories."""
        self.queue_factory.db = self.db
        self.run_factory.db = self.db

    def _default_mass_spectrometer_changed(self):
        """Apply the default mass spectrometer to the factories and to
        this object's own ``_mass_spectrometer``."""
        self.run_factory.set_mass_spectrometer(self.default_mass_spectrometer)
        self.queue_factory.mass_spectrometer = self.default_mass_spectrometer
        self._mass_spectrometer = self.default_mass_spectrometer
Example #11
0
class SetupPane(TraitsTaskPane):
    """A TraitsTaskPane containing the factory selection and new object
    configuration editors."""

    # -------------------
    # Required Attributes
    # -------------------

    system_state = Instance(SystemState, allow_none=False)

    # ------------------
    # Regular Attributes
    # ------------------

    #: An internal identifier for this pane
    id = 'force_wfmanager.setup_pane'

    #: Name displayed as the title of this pane
    name = 'Setup Pane'

    #: Enables or disables the object views displayed
    ui_enabled = Bool(True)

    # ------------------
    # Derived Attributes
    # ------------------

    #: The model from selected_view.
    #: Listens to: :attr:`models.workflow_tree.selected_view
    #: <force_wfmanager.ui.setup.workflow_tree.WorkflowTree.selected_view>`
    selected_model = Instance(BaseModel)

    #: An error message for the entire workflow
    error_message = Str()

    # --------------------
    #  Button Attributes
    # --------------------

    #: A Button which calls add_new_entity when pressed.
    add_new_entity_btn = Button()

    #: A Button which calls remove_entity when pressed.
    remove_entity_btn = Button()

    # ----------------
    #    Properties
    # ----------------

    #: The string displayed on the 'add new entity' button.
    add_new_entity_label = Property(
        Str, depends_on='system_state:selected_factory_name'
    )

    #: A Boolean indicating whether the currently selected modelview is
    #: intended to be editable by the user. This is required to avoid
    #: displaying a default view when a model does not have a View defined for
    #: it. If a modelview has a View defining how it is represented in the UI
    #: then this is used.
    selected_view_editable = Property(
        Bool, depends_on='system_state:selected_view'
    )

    #: A Boolean indicating whether the currently selected view is the main
    #: view containing the FORCE logo
    main_view_visible = Property(
        Bool, depends_on='system_state:selected_factory_name'
    )

    #: A Boolean indicating whether the currently selected view represents
    #: either a KPI or Parameter.
    mco_view_visible = Property(
        Bool, depends_on='system_state:[selected_view,selected_factory_name]'
    )

    #: A Boolean indicating whether the currently selected view represents
    #: a factory view.
    factory_view_visible = Property(
        Bool, depends_on='system_state:selected_factory_name'
    )

    #: A Boolean indicating whether the currently selected view represents
    #: an instance view.
    instance_view_visible = Property(
        Bool, depends_on='system_state:selected_factory_name'
    )

    #: A Boolean indicating whether the entity creator view should be displayed
    entity_creator_visible = Property(
        Bool, depends_on='system_state:entity_creator'
    )

    #: A panel displaying extra information about the workflow: Available
    #: Plugins, non-KPI variables, current filenames and any error messages.
    #: Displayed for factories which have a lot of empty screen space.
    #: ``error_message`` is listed as a dependency because the getter passes
    #: it into the panel; without it the cached panel shows a stale error.
    current_info = Property(
        Instance(WorkflowInfo),
        depends_on='system_state:[selected_factory_name,selected_view],'
                   'task.current_file,error_message'
    )

    #: Determines if the add button should be enabled/visible.
    #: KPI and Execution Layers can always be added, but other workflow items
    #: need a specific factory to be selected.
    add_button_enabled = Property(
        Bool, depends_on='system_state:[selected_factory_name,'
                         'entity_creator.model]'
    )
    add_button_visible = Property(
        Bool, depends_on='system_state:selected_factory_name'
    )

    #: Determines if the remove button should be visible.
    remove_button_visible = Property(
        Bool, depends_on='system_state:selected_factory_name'
    )

    # -------------------
    #        View
    # -------------------

    def default_traits_view(self):
        """ Sets up a TraitsUI view which displays the details
        (parameters etc.) of the currently selected view. This varies
        depending on what type of view is selected: the main (logo/info)
        view, the MCO parameter/KPI view, an existing instance or a factory
        for creating a new item."""
        view = View(
            VGroup(
                # Main View with FORCE Logo
                VGroup(
                    UItem(
                        "current_info",
                        editor=InstanceEditor(),
                        style="custom",
                    ),
                    visible_when="main_view_visible",
                    enabled_when="ui_enabled"
                ),
                # MCO Parameter and KPI views
                VGroup(
                    UItem(
                        "object.system_state.selected_view",
                        editor=InstanceEditor(),
                        style="custom",
                    ),
                    visible_when="mco_view_visible",
                    enabled_when="ui_enabled"
                ),
                # Process Tree Views
                HGroup(
                    # Instance View
                    VGroup(
                        UItem(
                            "object.system_state.selected_view",
                            editor=InstanceEditor(),
                            style="custom",
                            visible_when="selected_view_editable"
                        ),
                        UItem(
                            "selected_model", editor=InstanceEditor(),
                            style="custom",
                            visible_when="selected_model is not None"
                        ),
                        # Remove Buttons
                        HGroup(
                            UItem('remove_entity_btn', label='Delete'),
                        ),
                        label="Item Details",
                        visible_when="instance_view_visible",
                        show_border=True,
                    ),
                    # Factory View
                    VGroup(
                        HGroup(
                            UItem(
                                "object.system_state.entity_creator",
                                editor=InstanceEditor(),
                                style="custom",
                                visible_when="entity_creator_visible",
                                width=825
                            ),
                            springy=True,
                        ),
                        HGroup(
                            UItem(
                                'add_new_entity_btn',
                                editor=ButtonEditor(
                                    label_value='add_new_entity_label'
                                ),
                                enabled_when='add_button_enabled',
                                visible_when="add_button_visible",
                                springy=True
                            ),
                            # Remove Buttons
                            UItem(
                                'remove_entity_btn',
                                label='Delete Layer',
                                visible_when='remove_button_visible'
                            ),
                            label="New Item Details",
                            visible_when="factory_view_visible",
                            show_border=True,
                        ),
                    ),
                    enabled_when="ui_enabled",
                ),
            ),
            scrollable=True,
            width=500,
        )

        return view

    # -------------------
    #      Listeners
    # -------------------

    @cached_property
    def _get_selected_view_editable(self):
        """ Determines if the selected modelview in the WorkflowTree has a
        non-default View that the user may edit.

        ``trait_views()`` does not include the default View, so an empty
        list means the modelview only has the default View.

        Returns
        -------
        Bool
            True if a modelview is selected and it defines at least one
            View; False otherwise.
        """
        selected_view = self.system_state.selected_view
        return (selected_view is not None
                and len(selected_view.trait_views()) > 0)

    @cached_property
    def _get_main_view_visible(self):
        return self.system_state.selected_factory_name == 'Workflow'

    @cached_property
    def _get_mco_view_visible(self):
        mco_factories = ['MCO Parameters', 'MCO KPIs']
        return (self.system_state.selected_view is not None
                and self.system_state.selected_factory_name in mco_factories)

    @cached_property
    def _get_factory_view_visible(self):
        factories = ['None', 'Workflow', 'MCO Parameters', 'MCO KPIs']
        return self.system_state.selected_factory_name not in factories

    @cached_property
    def _get_instance_view_visible(self):
        return self.system_state.selected_factory_name == 'None'

    @cached_property
    def _get_entity_creator_visible(self):
        return self.system_state.entity_creator is not None

    @cached_property
    def _get_add_button_enabled(self):
        """ Determines if the add button in the UI should be enabled.

        'Execution Layer' and 'MCO' can always be added; anything else
        needs an entity creator holding a model."""
        simple_factories = ['Execution Layer', 'MCO']
        if self.system_state.selected_factory_name in simple_factories:
            return True
        entity_creator = self.system_state.entity_creator
        return entity_creator is not None and entity_creator.model is not None

    @cached_property
    def _get_add_button_visible(self):
        """ Determines if the add button in the UI should be visible."""
        return self.system_state.selected_factory_name != 'None'

    @cached_property
    def _get_remove_button_visible(self):
        """ Determines if the remove button in the UI should be visible."""
        return self.system_state.selected_factory_name == 'Data Source'

    @cached_property
    def _get_add_new_entity_label(self):
        """Returns the label displayed on add_new_entity_btn"""
        return 'Add New {!s}'.format(self.system_state.selected_factory_name)

    @cached_property
    def _get_current_info(self):
        """Build the WorkflowInfo panel shown in the main view."""
        return WorkflowInfo(
            workflow_filename=self.task.current_file,
            plugins=self.task.lookup_plugins(),
            selected_factory_name=self.system_state.selected_factory_name,
            error_message=self.error_message
        )

    # Synchronisation with WorkflowTree
    @on_trait_change('system_state:selected_view.error_message')
    def sync_selected_view(self):
        """ Synchronise selected_view with the selected modelview in the tree
        editor. Checks if the model held by the modelview needs to be displayed
        in the UI, and copies the modelview's error message."""
        if self.system_state.selected_view is not None:
            if isinstance(self.system_state.selected_view.model, BaseModel):
                self.selected_model = self.system_state.selected_view.model
            else:
                self.selected_model = None

            self.error_message = self.system_state.selected_view.error_message

    # Button event handlers for creating and deleting workflow items
    def _add_new_entity_btn_fired(self):
        """Calls add_new_entity when add_new_entity_btn is clicked"""
        self.system_state.add_new_entity()

    def _remove_entity_btn_fired(self):
        """Calls remove_entity when remove_entity_btn is clicked"""
        self.system_state.remove_entity()

    # -------------------
    #   Private Methods
    # -------------------

    def _console_ns_default(self):
        """Namespace for the console: the task and, if reachable, the app."""
        namespace = {
            "task": self.task
        }
        try:
            namespace["app"] = self.task.window.application
        except AttributeError:
            namespace["app"] = None

        return namespace
Example #12
0
class BaselineView(HasTraits):
    """Table and N/E scatter plot of baseline NED solutions from an SBP link.

    When ``logging_b`` is set, every solution received while the GPS week is
    known is also appended to a CSV log file.
    """
    python_console_cmds = Dict()

    table = List()

    logging_b = Bool(False)
    directory_name_b = File

    plot = Instance(Plot)
    plot_data = Instance(ArrayPlotData)

    running = Bool(True)
    zoomall = Bool(False)
    position_centered = Bool(False)

    clear_button = SVGButton(label='',
                             tooltip='Clear',
                             filename=os.path.join(determine_path(), 'images',
                                                   'iconic', 'x.svg'),
                             width=16,
                             height=16)
    zoomall_button = SVGButton(label='',
                               tooltip='Zoom All',
                               toggle=True,
                               filename=os.path.join(determine_path(),
                                                     'images', 'iconic',
                                                     'fullscreen.svg'),
                               width=16,
                               height=16)
    center_button = SVGButton(label='',
                              tooltip='Center on Baseline',
                              toggle=True,
                              filename=os.path.join(determine_path(), 'images',
                                                    'iconic', 'target.svg'),
                              width=16,
                              height=16)
    paused_button = SVGButton(label='',
                              tooltip='Pause',
                              toggle_tooltip='Run',
                              toggle=True,
                              filename=os.path.join(determine_path(), 'images',
                                                    'iconic', 'pause.svg'),
                              toggle_filename=os.path.join(
                                  determine_path(), 'images', 'iconic',
                                  'play.svg'),
                              width=16,
                              height=16)

    reset_button = Button(label='Reset Filters')
    reset_iar_button = Button(label='Reset IAR')
    init_base_button = Button(label='Init. with known baseline')

    traits_view = View(
        HSplit(
            Item('table',
                 style='readonly',
                 editor=TabularEditor(adapter=SimpleAdapter()),
                 show_label=False,
                 width=0.3),
            VGroup(
                HGroup(
                    Item('paused_button', show_label=False),
                    Item('clear_button', show_label=False),
                    Item('zoomall_button', show_label=False),
                    Item('center_button', show_label=False),
                    Item('reset_button', show_label=False),
                    Item('reset_iar_button', show_label=False),
                    Item('init_base_button', show_label=False),
                ),
                Item(
                    'plot',
                    show_label=False,
                    editor=ComponentEditor(bgcolor=(0.8, 0.8, 0.8)),
                ))))

    def _zoomall_button_fired(self):
        self.zoomall = not self.zoomall

    def _center_button_fired(self):
        self.position_centered = not self.position_centered

    def _paused_button_fired(self):
        self.running = not self.running

    def _reset_button_fired(self):
        self.link(MsgResetFilters(filter=0))

    def _reset_iar_button_fired(self):
        self.link(MsgResetFilters(filter=1))

    def _init_base_button_fired(self):
        self.link(MsgInitBase())

    def _clear_button_fired(self):
        """Forget the plotted history and blank every plotted series."""
        # np.nan (lower case) works on every NumPy version; np.NAN was
        # removed in NumPy 2.0.
        self.neds[:] = np.nan
        self.fixeds[:] = False
        for name in ('n_fixed', 'e_fixed', 'd_fixed',
                     'n_float', 'e_float', 'd_float',
                     't',
                     'cur_fixed_n', 'cur_fixed_e', 'cur_fixed_d',
                     'cur_float_n', 'cur_float_e', 'cur_float_d'):
            self.plot_data.set_data(name, [])

    def iar_state_callback(self, sbp_msg, **metadata):
        """Record the IAR hypothesis count and when it was received."""
        self.num_hyps = sbp_msg.num_hyps
        self.last_hyp_update = time.time()

    def _baseline_callback_ned(self, sbp_msg, **metadata):
        # Updating an ArrayPlotData isn't thread safe (see chaco issue #9), so
        # actually perform the update in the UI thread.
        if self.running:
            GUI.invoke_later(self.baseline_callback, sbp_msg)

    def update_table(self):
        """Snapshot the current table rows into ``_table_list``.

        ``table`` is a List of (label, value) pairs, so it has no ``items()``
        method (the previous implementation always raised AttributeError);
        copying the pairs gives the same sequence ``dict.items()`` would.
        """
        self._table_list = list(self.table)

    def gps_time_callback(self, sbp_msg, **metadata):
        """Store the GPS week number and nanosecond field of the message."""
        gps_time = MsgGPSTime(sbp_msg)
        self.week = gps_time.wn
        self.nsec = gps_time.ns

    def mode_string(self, msg):
        """Return 'Fixed RTK', 'Float' or 'None' for ``msg``.

        Side effect: sets ``self.fixed`` from bit 0 of ``msg.flags``.
        """
        if msg:
            self.fixed = (msg.flags & 1) == 1
            if self.fixed:
                return 'Fixed RTK'
            else:
                return 'Float'
        return 'None'

    def _log_baseline(self, t, soln, dist):
        """Append one CSV row for ``soln`` while ``logging_b`` is set.

        The log file is opened lazily (in ``directory_name_b``; an empty
        directory means the current working directory) with a header row.
        When logging is switched off the open file is closed instead of
        just being dropped, so the handle is not leaked.
        """
        if not self.logging_b:
            if self.log_file is not None:
                self.log_file.close()
            self.log_file = None
            return

        if self.log_file is None:
            # os.path.join('', name) == name, matching the old behaviour for
            # an empty directory name.
            filepath = os.path.join(
                self.directory_name_b,
                time.strftime("baseline_log_%Y%m%d-%H%M%S.csv"))
            self.log_file = open(filepath, 'w')

            self.log_file.write(
                'time,north(meters),east(meters),down(meters),distance(meters),num_signals,flags,num_hypothesis\n'
            )

        self.log_file.write('%s,%.4f,%.4f,%.4f,%.4f,%d,0x%02x,%d\n' %
                            (str(t), soln.n, soln.e, soln.d, dist,
                             soln.n_sats, soln.flags, self.num_hyps))
        self.log_file.flush()

    def baseline_callback(self, sbp_msg):
        """Update the table, the plot history and (optionally) the CSV log.

        Scheduled on the UI thread by ``_baseline_callback_ned``.
        """
        self.last_btime_update = time.time()
        soln = MsgBaselineNED(sbp_msg)
        self.last_soln = soln
        table = []

        # Scale the N/E/D components by 1e-3 into metres (the units used by
        # the axis titles and the log header).
        soln.n = soln.n * 1e-3
        soln.e = soln.e * 1e-3
        soln.d = soln.d * 1e-3

        dist = np.sqrt(soln.n**2 + soln.e**2 + soln.d**2)

        # tow is scaled by 1e-3 and nsec by 1e-9 to get seconds.
        tow = soln.tow * 1e-3
        if self.nsec is not None:
            tow += self.nsec * 1e-9

        if self.week is not None:
            # GPS time counts from 1980-01-06.
            t = datetime.datetime(1980, 1, 6) + \
                datetime.timedelta(weeks=self.week) + \
                datetime.timedelta(seconds=tow)

            table.append(('GPS Time', t))
            table.append(('GPS Week', str(self.week)))

            self._log_baseline(t, soln, dist)

        table.append(('GPS ToW', tow))

        table.append(('N', soln.n))
        table.append(('E', soln.e))
        table.append(('D', soln.d))
        table.append(('Dist.', dist))
        table.append(('Num. Signals.', soln.n_sats))
        table.append(('Flags', '0x%02x' % soln.flags))
        # mode_string also refreshes self.fixed, which is used below.
        table.append(('Mode', self.mode_string(soln)))
        if time.time() - self.last_hyp_update < 10 and self.num_hyps != 1:
            table.append(('IAR Num. Hyps.', self.num_hyps))
        else:
            table.append(('IAR Num. Hyps.', "None"))

        # Rotate array, deleting oldest entries to maintain
        # no more than N in plot
        self.neds[1:] = self.neds[:-1]
        self.fixeds[1:] = self.fixeds[:-1]

        # Insert latest position
        self.neds[0][:] = [soln.n, soln.e, soln.d]
        self.fixeds[0] = self.fixed

        neds_fixed = self.neds[self.fixeds]
        neds_float = self.neds[np.logical_not(self.fixeds)]

        if not all(map(any, np.isnan(neds_fixed))):
            self.plot_data.set_data('n_fixed', neds_fixed.T[0])
            self.plot_data.set_data('e_fixed', neds_fixed.T[1])
            self.plot_data.set_data('d_fixed', neds_fixed.T[2])
        if not all(map(any, np.isnan(neds_float))):
            self.plot_data.set_data('n_float', neds_float.T[0])
            self.plot_data.set_data('e_float', neds_float.T[1])
            self.plot_data.set_data('d_float', neds_float.T[2])

        # Show the latest point in the series matching the current mode and
        # blank the other one.
        if self.fixed:
            current, other = 'cur_fixed', 'cur_float'
        else:
            current, other = 'cur_float', 'cur_fixed'
        for axis, value in (('n', soln.n), ('e', soln.e), ('d', soln.d)):
            self.plot_data.set_data('%s_%s' % (current, axis), [value])
        for axis in ('n', 'e', 'd'):
            self.plot_data.set_data('%s_%s' % (other, axis), [])

        self.plot_data.set_data('ref_n', [0.0])
        self.plot_data.set_data('ref_e', [0.0])
        self.plot_data.set_data('ref_d', [0.0])

        if self.position_centered:
            d = (self.plot.index_range.high - self.plot.index_range.low) / 2.
            self.plot.index_range.set_bounds(soln.e - d, soln.e + d)
            d = (self.plot.value_range.high - self.plot.value_range.low) / 2.
            self.plot.value_range.set_bounds(soln.n - d, soln.n + d)

        if self.zoomall:
            plot_square_axes(self.plot, ('e_fixed', 'e_float'),
                             ('n_fixed', 'n_float'))
        self.table = table

    def __init__(self, link, plot_history_max=1000, dirname=''):
        super(BaselineView, self).__init__()
        self.log_file = None
        self.directory_name_b = dirname
        self.num_hyps = 0
        self.last_hyp_update = 0
        self.last_btime_update = 0
        self.last_soln = None
        # Set properly by mode_string on every solution; initialised here so
        # the attribute always exists.
        self.fixed = False
        self.plot_data = ArrayPlotData(n_fixed=[0.0],
                                       e_fixed=[0.0],
                                       d_fixed=[0.0],
                                       n_float=[0.0],
                                       e_float=[0.0],
                                       d_float=[0.0],
                                       t=[0.0],
                                       ref_n=[0.0],
                                       ref_e=[0.0],
                                       ref_d=[0.0],
                                       cur_fixed_e=[],
                                       cur_fixed_n=[],
                                       cur_fixed_d=[],
                                       cur_float_e=[],
                                       cur_float_n=[],
                                       cur_float_d=[])
        self.plot_history_max = plot_history_max

        # History buffers: NED rows (NaN = empty slot) and fixed flags.
        self.neds = np.empty((plot_history_max, 3))
        self.neds[:] = np.nan
        self.fixeds = np.zeros(plot_history_max, dtype=bool)

        self.plot = Plot(self.plot_data)
        color_float = (0.5, 0.5, 1.0)
        color_fixed = 'orange'
        self.plot.plot(('e_float', 'n_float'),
                       type='scatter',
                       color=color_float,
                       marker='dot',
                       line_width=0.0,
                       marker_size=1.0)
        self.plot.plot(('e_fixed', 'n_fixed'),
                       type='scatter',
                       color=color_fixed,
                       marker='dot',
                       line_width=0.0,
                       marker_size=1.0)
        self.plot.plot(('e_fixed', 'n_fixed'),
                       type='line',
                       color=(1, 0.65, 0, 0.1))
        ref = self.plot.plot(('ref_e', 'ref_n'),
                             type='scatter',
                             color='red',
                             marker='plus',
                             marker_size=5,
                             line_width=1.5)
        cur_fixed = self.plot.plot(('cur_fixed_e', 'cur_fixed_n'),
                                   type='scatter',
                                   color=color_fixed,
                                   marker='plus',
                                   marker_size=5,
                                   line_width=1.5)
        cur_float = self.plot.plot(('cur_float_e', 'cur_float_n'),
                                   type='scatter',
                                   color=color_float,
                                   marker='plus',
                                   marker_size=5,
                                   line_width=1.5)
        plot_labels = ['Base Position', 'RTK Fixed', 'RTK Float']
        plots_legend = dict(zip(plot_labels, [ref, cur_fixed, cur_float]))
        self.plot.legend.plots = plots_legend
        self.plot.legend.visible = True

        self.plot.index_axis.tick_label_position = 'inside'
        self.plot.index_axis.tick_label_color = 'gray'
        self.plot.index_axis.tick_color = 'gray'
        self.plot.index_axis.title = 'E (meters)'
        self.plot.index_axis.title_spacing = 5
        self.plot.value_axis.tick_label_position = 'inside'
        self.plot.value_axis.tick_label_color = 'gray'
        self.plot.value_axis.tick_color = 'gray'
        self.plot.value_axis.title = 'N (meters)'
        self.plot.value_axis.title_spacing = 5
        self.plot.padding = (25, 25, 25, 25)

        self.plot.tools.append(PanTool(self.plot))
        zt = ZoomTool(self.plot,
                      zoom_factor=1.1,
                      tool_mode="box",
                      always_on=False)
        self.plot.overlays.append(zt)

        self.week = None
        self.nsec = 0

        self.link = link
        self.link.add_callback(self._baseline_callback_ned,
                               SBP_MSG_BASELINE_NED)
        self.link.add_callback(self.iar_state_callback, SBP_MSG_IAR_STATE)
        self.link.add_callback(self.gps_time_callback, SBP_MSG_GPS_TIME)

        self.python_console_cmds = {'baseline': self}
Example #13
0
class config(BaseWorkflowConfig):
    """Configuration traits for the preprocessing workflow.

    Besides the plain settings, two "Check" buttons let the user verify that
    the functional and field-map files implied by the templates exist for
    every subject before the workflow is run.
    """
    uuid = traits.Str(desc="UUID")
    desc = traits.Str(desc='Workflow description')
    # Directories
    base_dir = Directory(
        os.path.abspath('.'),
        mandatory=True,
        desc='Base directory of data. (Should be subject-independent)')
    sink_dir = Directory(os.path.abspath('.'),
                         mandatory=True,
                         desc="Location where the BIP will store the results")
    field_dir = Directory(
        desc="Base directory of field-map data (Should be subject-independent) \
                                                 Set this value to None if you don't want fieldmap distortion correction"
    )
    surf_dir = Directory(mandatory=True, desc="Freesurfer subjects directory")

    # Subjects

    subjects = traits.List(
        traits.Str,
        mandatory=True,
        usedefault=True,
        desc="Subject id's. Note: These MUST match the subject id's in the \
                                Freesurfer directory. For simplicity, the subject id's should \
                                also match with the location of individual functional files."
    )
    func_template = traits.String('%s/functional.nii.gz')
    run_datagrabber_without_submitting = traits.Bool(
        desc="Run the datagrabber without \
    submitting to the cluster")
    timepoints_to_remove = traits.Int(0, usedefault=True)

    # Fieldmap

    use_fieldmap = Bool(
        False,
        mandatory=False,
        usedefault=True,
        desc='True to include fieldmap distortion correction. Note: field_dir \
                                     must be specified')
    magnitude_template = traits.String('%s/magnitude.nii.gz')
    phase_template = traits.String('%s/phase.nii.gz')
    TE_diff = traits.Float(desc='difference in B0 field map TEs')
    sigma = traits.Int(
        2, desc='2D spatial gaussing smoothing stdev (default = 2mm)')
    echospacing = traits.Float(desc="EPI echo spacing")

    # Motion Correction

    do_slicetiming = Bool(True,
                          usedefault=True,
                          desc="Perform slice timing correction")
    SliceOrder = traits.List(traits.Int)
    TR = traits.Float(1.0, mandatory=True, desc="TR of functional")
    motion_correct_node = traits.Enum(
        'nipy',
        'fsl',
        'spm',
        'afni',
        desc="motion correction algorithm to use",
        usedefault=True,
    )
    loops = traits.List([5], traits.Int(5), usedefault=True)
    #between_loops = traits.Either("None",traits.List([5]),usedefault=True)
    speedup = traits.List([5], traits.Int(5), usedefault=True)
    # Artifact Detection

    norm_thresh = traits.Float(1,
                               min=0,
                               usedefault=True,
                               desc="norm thresh for art")
    z_thresh = traits.Float(3, min=0, usedefault=True, desc="z thresh for art")

    # Smoothing
    fwhm = traits.List(
        [0, 5],
        traits.Float(),
        mandatory=True,
        usedefault=True,
        desc="Full width at half max. The data will be smoothed at all values \
                             specified in this list.")
    smooth_type = traits.Enum("susan",
                              "isotropic",
                              'freesurfer',
                              usedefault=True,
                              desc="Type of smoothing to use")
    surface_fwhm = traits.Float(
        0.0,
        desc='surface smoothing kernel, if freesurfer is selected',
        usedefault=True)

    # CompCor
    compcor_select = traits.BaseTuple(
        traits.Bool,
        traits.Bool,
        mandatory=True,
        desc="The first value in the list corresponds to applying \
                                       t-compcor, and the second value to a-compcor. Note: \
                                       both can be true")
    num_noise_components = traits.Int(
        6,
        usedefault=True,
        desc="number of principle components of the noise to use")
    regress_before_PCA = traits.Bool(True)
    # Highpass Filter
    hpcutoff = traits.Float(128., desc="highpass cutoff", usedefault=True)

    #zscore
    do_zscore = Bool(False)

    # Advanced Options
    use_advanced_options = traits.Bool()
    advanced_script = traits.Code()
    debug = traits.Bool(False)

    # Buttons
    check_func_datagrabber = Button("Check")
    check_field_datagrabber = Button("Check")

    def _check_func_datagrabber_fired(self):
        """Print whether each subject's functional file exists.

        The checked path is ``base_dir`` joined with ``func_template % subject``.
        Checking stops at the first missing file.
        """
        for s in self.subjects:
            func_path = os.path.join(self.base_dir, self.func_template % s)
            if not os.path.exists(func_path):
                # Single-argument print() gives identical output on Python 2
                # and Python 3.
                print("ERROR %s does NOT exist!" % func_path)
                break
            print("%s exists!" % func_path)

    def _check_field_datagrabber_fired(self):
        """Print whether each subject's magnitude and phase field maps exist.

        Both templates are resolved relative to ``field_dir``. Checking stops
        at the first missing file.
        """
        for s in self.subjects:
            for template in (self.magnitude_template, self.phase_template):
                # The reported path must be the same one that was checked
                # (field_dir, not base_dir).
                field_path = os.path.join(self.field_dir, template % s)
                if not os.path.exists(field_path):
                    print("ERROR: %s does NOT exist!" % field_path)
                    # Abort all remaining subjects, as the original break did.
                    return
                print("%s exists!" % field_path)
Example #14
0
class StartPage(Editor):
    """
    Pyface Tasks editor that displays the application's opening page.

    The page shows one centred "Open Data File" button; pressing it calls
    ``task.open()``.
    """
    #: Optional model object associated with the page (the methods below
    #: build the view from the editor itself).
    model = Instance(HasTraits)

    #: The TraitsUI ``UI`` built by :py:meth:`create`, or None.
    ui = Instance('traitsui.ui.UI')

    #: The editor's user-visible name.
    name = Str('Start Page')

    #: The task associated with the editor; its ``open()`` is invoked by
    #: the button.
    task = Any()

    #: Button to open a new data file.
    open_data_file_button = Button(label='Open Data File', style='button')

    def default_traits_view(self):  # pylint: disable=no-self-use
        """
        Build the view: the open button, centred by springs on all sides.

        Returns
        -------
        default_traits_view : :py:class:`traitsui.view.View`
            The default traits View object for the model
        """
        button_row = Group(Spring(),
                           Item('open_data_file_button', show_label=False),
                           Spring(),
                           orientation='horizontal')
        column = Group(Spring(), button_row, Spring(), orientation='vertical')
        return View(column)

    def create(self, parent):
        """
        Build the TraitsUI subpanel and expose its widget as ``control``.

        Parameters
        ----------
        parent : toolkit-specific widget
            The widget that will contain the start page
        """
        panel_ui = self.edit_traits(kind='subpanel', parent=parent)
        self.ui = panel_ui  # pylint: disable=invalid-name
        self.control = self.ui.control

    def destroy(self):
        """
        Drop the widget reference and dispose of the TraitsUI object, if any.
        """
        self.control = None
        current_ui = self.ui
        if current_ui is not None:
            current_ui.dispose()
        self.ui = None

    @observe('open_data_file_button', post_init=True)
    def open(self, event):
        """
        Forward a press of the open button to the associated task.

        Parameters
        ----------
        event : A :py:class:`traits.observation.events.TraitChangeEvent` instance
            The trait change event for open_data_file_button
        """
        owner_task = self.task
        owner_task.open()
Example #15
0
class FiducialsPanel(HasPrivateTraits):
    """Set fiducials on an MRI surface.

    The panel edits the LPA, nasion and RPA positions held by ``model``.
    ``current_pos`` is kept in sync with the fiducial selected in ``set``,
    and picks on the head surface (``hsp_obj``) move the selected fiducial.
    """
    model = Instance(MRIHeadWithFiducialsModel)

    fid_file = DelegatesTo('model')
    fid_fname = DelegatesTo('model')
    lpa = DelegatesTo('model')
    nasion = DelegatesTo('model')
    rpa = DelegatesTo('model')
    can_save = DelegatesTo('model')
    can_save_as = DelegatesTo('model')
    can_reset = DelegatesTo('model')
    fid_ok = DelegatesTo('model')
    locked = DelegatesTo('model', 'lock_fiducials')

    # Which fiducial is currently being edited.
    set = Enum('LPA', 'Nasion', 'RPA')
    current_pos = Array(float, (1, 3))  # for editing

    save_as = Button(label='Save As...')
    save = Button(label='Save')
    reset_fid = Button(label="Reset to File")

    headview = Instance(HeadViewController)
    hsp_obj = Instance(SurfaceObject)

    # The most recent picker passed to _on_pick.
    picker = Instance(object)

    # the layout of the dialog created
    view = View(
        VGroup(Item('fid_file', label='Fiducials File'),
               Item('fid_fname', show_label=False, style='readonly'),
               Item('set', style='custom'),
               Item('current_pos', label='Pos'),
               HGroup(Item('save',
                           enabled_when='can_save',
                           tooltip="If a filename is currently "
                           "specified, save to that file, otherwise "
                           "save to the default file name"),
                      Item('save_as', enabled_when='can_save_as'),
                      Item('reset_fid', enabled_when='can_reset'),
                      show_labels=False),
               enabled_when="locked==False"))

    def __init__(self, *args, **kwargs):
        super(FiducialsPanel, self).__init__(*args, **kwargs)
        # 'LPA' is the first (default) value of ``set``, so mirror lpa first.
        self.sync_trait('lpa', self, 'current_pos', mutual=True)

    def _reset_fid_fired(self):
        """Ask the model to reset the fiducials."""
        self.model.reset = True

    def _save_fired(self):
        """Save the fiducials through the model."""
        self.model.save()

    def _save_as_fired(self):
        """Ask for a target file (forcing a ``.fif`` extension) and save there."""
        if self.fid_file:
            default_path = self.fid_file
        else:
            default_path = self.model.default_fid_fname

        dlg = FileDialog(action="save as",
                         wildcard=fid_wildcard,
                         default_path=default_path)
        dlg.open()
        if dlg.return_code != OK:
            return

        path = dlg.path
        if not path.endswith('.fif'):
            path = path + '.fif'
            # Only the extension-appended path is checked here (presumably the
            # dialog already handled the name as typed -- verify).
            if os.path.exists(path):
                message = ("The file %r already exists. Should it "
                           "be replaced?" % path)
                answer = confirm(None, message, "Overwrite File?")
                if answer != YES:
                    return

        self.model.save(path)

    def _on_pick(self, picker):
        """Move the selected fiducial to the point picked on the MRI surface.

        Parameters
        ----------
        picker : object
            Pick-event source; ``actor``, ``actors``, ``pick_position`` and
            ``picked_positions`` are read from it.

        Raises
        ------
        ValueError
            If ``set`` holds an unexpected value.
        """
        if self.locked:
            return

        self.picker = picker
        n_pos = len(picker.picked_positions)

        if n_pos == 0:
            logger.debug("GUI: picked empty location")
            return

        surf_actor = self.hsp_obj.surf.actor.actor
        if picker.actor is surf_actor:
            idxs = []
            idx = None
            pt = [picker.pick_position]
        elif surf_actor in picker.actors:
            idxs = [i for i in range(n_pos) if picker.actors[i] is surf_actor]
            idx = idxs[-1]
            pt = [picker.picked_positions[idx]]
        else:
            # No MRI point was hit: there is no position to assign, and
            # idx/idxs/pt would be undefined below.
            logger.debug("GUI: picked object other than MRI")
            return

        # Lists (not map objects) so str() gives readable output on Python 3.
        poss = [[round(v, 3) for v in p] for p in picker.picked_positions]
        pos = [round(v, 3) for v in picker.pick_position]
        msg = ["Pick Event: %i picked_positions:" % n_pos]

        line = str(pos)
        if idx is None:
            line += " <-pick_position"
        msg.append(line)

        for i, p in enumerate(poss):
            line = str(p)
            if i == idx:
                line += " <- MRI mesh"
            elif i in idxs:
                line += " (<- also MRI mesh)"
            msg.append(line)
        logger.debug(os.linesep.join(msg))

        if self.set == 'Nasion':
            self.nasion = pt
        elif self.set == 'LPA':
            self.lpa = pt
        elif self.set == 'RPA':
            self.rpa = pt
        else:
            raise ValueError("set = %r" % self.set)

    @on_trait_change('set')
    def _on_set_change(self, obj, name, old, new):
        """Re-link ``current_pos`` to the newly selected fiducial and turn the view."""
        self.sync_trait(old.lower(),
                        self,
                        'current_pos',
                        mutual=True,
                        remove=True)
        self.sync_trait(new.lower(), self, 'current_pos', mutual=True)
        if new == 'Nasion':
            self.headview.front = True
        elif new == 'LPA':
            self.headview.left = True
        elif new == 'RPA':
            self.headview.right = True
class FieldViewer(HasTraits):
    """Interactive viewer for a 3-D scalar field given as a text expression.

    The expression in ``function`` is evaluated with ``x``, ``y``, ``z``
    (grid arrays) and ``np`` in scope. Its iso-surfaces and an X-Y cut plane
    are drawn in ``scene``.
    """

    # Value ranges of the three axes
    x0, x1 = Float(-5), Float(5)
    y0, y1 = Float(-5), Float(5)
    z0, z1 = Float(-5), Float(5)
    points = Int(50)  # number of grid points per axis
    autocontour = Bool(True)  # let mayavi choose the iso-surface values
    v0, v1 = Float(0.0), Float(1.0)  # range allowed for the iso-surface value
    contour = Range("v0", "v1", 0.5)  # manually chosen iso-surface value
    function = Str("x*x*0.5 + y*y + z*z*2.0")  # scalar field expression
    function_list = [
        "x*x*0.5 + y*y + z*z*2.0", "x*y*0.5 + np.sin(2*x)*y +y*z*2.0", "x*y*z",
        "np.sin((x*x+y*y)/z)"
    ]
    plotbutton = Button(u"描画")
    scene = Instance(MlabSceneModel, ())

    view = View(
        HSplit(
            VGroup(
                "x0",
                "x1",
                "y0",
                "y1",
                "z0",
                "z1",
                Item('points', label=u"点数"),
                Item('autocontour', label=u"自动等值"),
                Item('plotbutton', show_label=False),
            ),
            VGroup(
                Item(
                    'scene',
                    editor=SceneEditor(scene_class=MayaviScene),
                    resizable=True,
                    height=300,
                    width=350),
                Item('function',
                     editor=EnumEditor(name='function_list',
                                       evaluate=lambda x: x)),
                Item('contour',
                     editor=RangeEditor(format="%1.2f",
                                        low_name="v0",
                                        high_name="v1")),
                show_labels=False)),
        width=500,
        resizable=True,
        title=u"三维标量场观察器")

    def _plotbutton_fired(self):
        """Redraw the scene when the plot button is pressed."""
        self.plot()

    def plot(self):
        """Sample ``function`` on the grid and redraw the iso-surfaces.

        Also stores the contour module in ``self.g`` and the sampled values in
        ``self.scalars``. ``v0``/``v1`` are set to the data's min/max.
        """
        # A complex step makes mgrid produce exactly `points` samples per
        # axis. Only x, y, z (and self) may be locals here because eval()
        # below sees this function's local namespace.
        x, y, z = np.mgrid[self.x0:self.x1:1j * self.points,
                           self.y0:self.y1:1j * self.points,
                           self.z0:self.z1:1j * self.points]

        # NOTE(review): `function` is user-editable text fed to eval(); it can
        # execute arbitrary code, which is only acceptable for a trusted,
        # local user.
        scalars = eval(self.function)
        self.scene.mlab.clf()  # wipe whatever was drawn before

        contour_module = self.scene.mlab.contour3d(x, y, z, scalars,
                                                   contours=8,
                                                   transparent=True)
        contour_module.contour.auto_contours = self.autocontour
        self.scene.mlab.axes(figure=self.scene.mayavi_scene)

        # Horizontal (X-Y) cut plane through the centre of the box; its
        # normal points along z.
        cut_plane = self.scene.mlab.pipeline.scalar_cut_plane(contour_module)
        spans = ((self.x0, self.x1), (self.y0, self.y1), (self.z0, self.z1))
        centre = tuple((lo + hi) / 2 for lo, hi in spans)
        cut_plane.implicit_plane.normal = (0, 0, 1)
        cut_plane.implicit_plane.origin = centre

        # `g` must be stored before v0/v1 change, because the contour
        # handlers below read it.
        self.g = contour_module
        self.scalars = scalars
        self.v0 = np.min(scalars)
        self.v1 = np.max(scalars)

    def _contour_changed(self):
        """Apply the manual iso-value when automatic contours are off."""
        if not hasattr(self, "g"):
            return
        if self.g.contour.auto_contours:
            return
        self.g.contour.contours = [self.contour]

    def _autocontour_changed(self):
        """Toggle automatic contours; fall back to the manual value when off."""
        if not hasattr(self, "g"):
            return
        self.g.contour.auto_contours = self.autocontour
        if self.autocontour:
            return
        self._contour_changed()
Example #17
0
class SbpRelayView(HasTraits):
  """
  SBP Relay view- Class allows user to specify port, IP address, and message set
  to relay over UDP and to configure a skylark connection
  """
  running = Bool(False)
  configured = Bool(False)
  broadcasting = Bool(False)
  msg_enum = Enum('Observations', 'All')
  ip_ad = String(DEFAULT_UDP_ADDRESS)
  port = Int(DEFAULT_UDP_PORT)
  information = String('UDP Streaming\n\nBroadcast SBP information received by'
    ' the console to other machines or processes over UDP. With the \'Observations\''
    ' radio button selected, the console will broadcast the necessary information'
    ' for a rover Piksi to achieve an RTK solution.'
    '\n\nThis can be used to stream observations to a remote Piksi through'
    ' aircraft telemetry via ground control software such as MAVProxy or'
    ' Mission Planner.')
  http_information = String('Skylark - Experimental Piksi Networking\n\n'
                            "Skylark is Swift Navigation's Internet service for connecting Piksi receivers without the use of a radio. To receive GPS observations from the closest nearby Piksi base station (within 5km), click Connect to Skylark.\n\n")
  start = Button(label='Start', toggle=True, width=32)
  stop = Button(label='Stop', toggle=True, width=32)
  connected_rover = Bool(False)
  connect_rover = Button(label='Connect to Skylark', toggle=True, width=32)
  disconnect_rover = Button(label='Disconnect from Skylark', toggle=True, width=32)
  skylark_url = String()
  base_pragma = String()
  rover_pragma = String()
  base_device_uid = String()
  rover_device_uid = String()
  toggle=True
  view = View(
           VGroup(
             spring,
             HGroup(
               VGroup(
                 Item('running', show_label=True, style='readonly', visible_when='running'),
                 Item('msg_enum', label="Messages to broadcast",
                      style='custom', enabled_when='not running'),
                 Item('ip_ad', label='IP Address', enabled_when='not running'),
                 Item('port', label="Port", enabled_when='not running'),
                 HGroup(
                   spring,
                   UItem('start', enabled_when='not running', show_label=False),
                   UItem('stop', enabled_when='running', show_label=False),
                   spring)),
               VGroup(
                 Item('information', label="Notes", height=10,
                      editor=MultilineTextEditor(TextEditor(multi_line=True)), style='readonly',
                      show_label=False, resizable=True, padding=15),
                 spring,
               ),
             ),
             spring,
             HGroup(
               VGroup(
                 HGroup(
                   spring,
                   UItem('connect_rover', enabled_when='not connected_rover', show_label=False),
                   UItem('disconnect_rover', enabled_when='connected_rover', show_label=False),
                   spring),
                   HGroup(Spring(springy=False, width=2),
                          Item('skylark_url', enabled_when='not connected_rover', show_label=True),
                          Spring(springy=False, width=2)
                          ),
                 HGroup(spring,
                        Item('base_pragma',  label='Base option '),
                        Item('base_device_uid',  label='Base device '),
                        spring),
                 HGroup(spring,
                        Item('rover_pragma', label='Rover option'),
                        Item('rover_device_uid',  label='Rover device'),
                        spring),),
               VGroup(
                 Item('http_information', label="Notes", height=10,
                      editor=MultilineTextEditor(TextEditor(multi_line=True)), style='readonly',
                      show_label=False, resizable=True, padding=15),
                 spring,
               ),
             ),
             spring
           )
  )

  def __init__(self, link, device_uid=None, base=DEFAULT_BASE,
               whitelist=None, rover_pragma='', base_pragma='', rover_uuid='', base_uuid='',
               connect=False, verbose=False):
    """
    Traits tab with UI for UDP broadcast of SBP.

    Parameters
    ----------
    link : sbp.client.handler.Handler
      Link for SBP transfer to/from Piksi.
    device_uid : str
      Piksi Device UUID (defaults to None). NOTE(review): currently ignored;
      ``device_uid`` is only set later through ``set_route``.
    base : str
      HTTP endpoint (initial value of ``skylark_url``)
    whitelist : [int] | None
      Message whitelist handed to the Skylark connection config
      (defaults to None)
    rover_pragma, base_pragma : str
      Initial values of the corresponding option fields.
    rover_uuid, base_uuid : str
      Initial rover/base device UUIDs; empty means "use ``device_uid``".
    connect : bool
      Stored as ``connect_when_uuid_received``.
    verbose : bool
      Passed to the Skylark watchdog thread.

    """
    self.link = link
    # Whitelist used for UDP broadcast view
    self.msgs = OBS_MSGS
    # register a callback when the msg_enum trait changes
    self.on_trait_change(self.update_msgs, 'msg_enum')
    # Whitelist used for Skylark broadcasting
    self.whitelist = whitelist
    self.device_uid = None
    self.python_console_cmds = {'update': self}
    self.rover_pragma = rover_pragma
    self.base_pragma = base_pragma
    self.rover_device_uid = rover_uuid
    self.base_device_uid = base_uuid
    self.verbose = verbose
    self.skylark_watchdog_thread = None
    self.skylark_url = base
    self.connect_when_uuid_received = bool(connect)

  def update_msgs(self):
    """Updates the instance variable msgs which store the msgs that we
    will send over UDP.

    Raises
    ------
    NotImplementedError
      If ``msg_enum`` has an unexpected value.
    """
    if self.msg_enum == 'Observations':
      self.msgs = OBS_MSGS
    elif self.msg_enum == 'All':
      self.msgs = [None]
    else:
      raise NotImplementedError

  def set_route(self, uuid=None, serial_id=None, channel=CHANNEL_UUID):
    """Sets serial_id hash for HTTP headers.

    Parameters
    ----------
    uuid: str
      real uuid of device; takes precedence when given
    serial_id : int
      Piksi device ID; hashed (modulo 1000) into a UUID when no uuid is given
    channel : str
      UUID namespace for device UUID

    """
    if uuid:
      device_uid = uuid
    elif serial_id:
      device_uid = str(get_uuid(channel, serial_id % 1000))
    else:
      print("Improper call of set_route, either a serial number or UUID should be passed")
      device_uid = str(get_uuid(channel, 1234))
      print("Setting UUID to default value of {0}".format(device_uid))
    self.device_uid = device_uid

  def _prompt_setting_error(self, text):
    """Nonblocking prompt for a device setting error.

    Parameters
    ----------
    text : str
      Helpful error message for the user

    """
    prompt = CallbackPrompt(title="Setting Error", actions=[close_button])
    prompt.text = text
    prompt.run(block=False)

  def _disconnect_rover_fired(self):
    """Handle callback for HTTP rover disconnects.

    Also passed to the watchdog thread as ``stopped_callback``. Stops the
    watchdog if it is still running; ``connected_rover`` is always cleared.
    """
    watchdog = self.skylark_watchdog_thread
    try:
      if isinstance(watchdog, threading.Thread) and not watchdog.stopped():
        watchdog.stop()
      elif watchdog is None:
        # Previously this path raised AttributeError on None.
        print("Unable to disconnect: no Skylark watchdog thread has been started")
      else:
        print(("Unable to disconnect: Skylark watchdog thread "
               "initialized at {0} and connected since {1} has "
               "already been stopped").format(watchdog.get_init_time(),
                                              watchdog.get_connect_time()))
    except Exception:
      import traceback
      print(traceback.format_exc())
    finally:
      self.connected_rover = False

  def _connect_rover_fired(self):
    """Handle callback for HTTP rover connections.  Launches an instance of skylark_watchdog_thread.

    Shows a prompt instead when no device UUID is known yet.
    """
    if not self.device_uid:
      msg = "\nDevice ID not found!\n\nConnection requires a valid Piksi device ID."
      self._prompt_setting_error(msg)
      return
    watchdog = None
    try:
      _base_device_uid = self.base_device_uid or self.device_uid
      _rover_device_uid = self.rover_device_uid or self.device_uid
      config = SkylarkConsoleConnectConfig(self.link, self.device_uid,
               self.skylark_url, self.whitelist, self.rover_pragma,
               self.base_pragma, _rover_device_uid, _base_device_uid)
      watchdog = SkylarkWatchdogThread(link=self.link, skylark_config=config,
                                       stopped_callback=self._disconnect_rover_fired,
                                       verbose=self.verbose)
      self.skylark_watchdog_thread = watchdog
      self.connected_rover = True
      watchdog.start()
    except Exception:
      # Stop only the thread created by this attempt, and only if it is still
      # running (the previous check was inverted and never stopped anything).
      if isinstance(watchdog, threading.Thread) and not watchdog.stopped():
        watchdog.stop()
      self.connected_rover = False
      import traceback
      print(traceback.format_exc())

  def _start_fired(self):
    """Handle start udp broadcast button. Registers callbacks on
    self.link for each of the self.msgs If self.msgs is None, it
    registers one generic callback for all messages.

    """
    self.running = True
    try:
      self.func = UdpLogger(self.ip_ad, self.port)
      self.link.add_callback(self.func, self.msgs)
    except Exception:
      # Return the UI to the stopped state so Start can be retried, rather
      # than enabling Stop for a callback that was never registered.
      self.running = False
      import traceback
      print(traceback.format_exc())

  def _stop_fired(self):
    """Handle the stop udp broadcast button. It uses the self.funcs and
    self.msgs to remove the callbacks that were registered when the
    start button was pressed.

    """
    try:
      self.link.remove_callback(self.func, self.msgs)
      self.func.__exit__()
      self.func = None
      self.running = False
    except Exception:
      import traceback
      print(traceback.format_exc())
class InterpolatorView(HasTraits):
    """Traits view that interpolates particle-array properties onto a grid
    and shows the result in an mlab scene (cut plane in 3-D, surface
    otherwise).
    """

    # The bounds on which to interpolate.
    bounds = Array(cols=3,
                   dtype=float,
                   desc='spatial bounds for the interpolation '
                   '(xmin, xmax, ymin, ymax, zmin, zmax)')

    # The number of points to interpolate onto.
    num_points = Int(100000,
                     enter_set=True,
                     auto_set=False,
                     desc='number of points on which to interpolate')

    # The particle arrays to interpolate from.
    particle_arrays = List

    # The scalar to interpolate.
    scalar = Str('rho', desc='name of the active scalar to view')

    # Sync'd trait with the scalar lut manager.
    show_legend = Bool(False, desc='if the scalar legend is to be displayed')

    # Enable/disable the interpolation
    visible = Bool(False, desc='if the interpolation is to be displayed')

    # A button to use the set bounds.
    set_bounds = Button('Set Bounds')

    # A button to recompute the bounds.
    recompute_bounds = Button('Recompute Bounds')

    # Private traits. ######################################################

    # The interpolator we are a view for.
    interpolator = Instance(Interpolator)

    # The mlab plot for this particle array.
    plot = Instance(PipelineBase)

    # Union of the property names of all particle arrays (scalar choices).
    scalar_list = List

    scene = Instance(MlabSceneModel)

    # The mlab scalar-field source feeding ``plot``.
    source = Instance(PipelineBase)

    # Set when particle_arrays changed and the interpolator must be updated.
    _arrays_changed = Bool(False)

    # View definition ######################################################
    view = View(
        Item(name='visible'),
        Item(name='scalar', editor=EnumEditor(name='scalar_list')),
        Item(name='num_points'),
        Item(name='bounds'),
        Item(name='set_bounds', show_label=False),
        Item(name='recompute_bounds', show_label=False),
        Item(name='show_legend'),
    )

    # Private protocol  ###################################################
    def _change_bounds(self):
        """Push ``bounds`` to the interpolator (keeping its shape) and redraw."""
        interp = self.interpolator
        if interp is not None:
            interp.set_domain(self.bounds, self.interpolator.shape)
            self._update_plot()

    def _setup_interpolator(self):
        """Create the interpolator lazily, or refresh its particle arrays."""
        if self.interpolator is None:
            interpolator = Interpolator(self.particle_arrays,
                                        num_points=self.num_points)
            self.bounds = interpolator.bounds
            self.interpolator = interpolator
        else:
            if self._arrays_changed:
                self.interpolator.update_particle_arrays(self.particle_arrays)
                self._arrays_changed = False

    # Trait handlers  #####################################################
    def _particle_arrays_changed(self, pas):
        """Rebuild the scalar choices from the new arrays and redraw."""
        # set().union(*...) also covers the empty case and, unlike reduce,
        # needs no import on Python 3.
        all_props = set().union(*[set(x.properties.keys()) for x in pas])
        self.scalar_list = list(all_props)
        self._arrays_changed = True
        self._update_plot()

    def _num_points_changed(self, value):
        """Resize the interpolation grid to roughly ``value`` points and redraw."""
        interp = self.interpolator
        if interp is not None:
            bounds = self.interpolator.bounds
            shape = get_nx_ny_nz(value, bounds)
            interp.set_domain(bounds, shape)
            self._update_plot()

    def _recompute_bounds_fired(self):
        """Reset ``bounds`` to the particles' bounding box and apply them."""
        bounds = get_bounding_box(self.particle_arrays)
        self.bounds = bounds
        self._change_bounds()

    def _set_bounds_fired(self):
        """Apply the user-entered ``bounds``."""
        self._change_bounds()

    def _bounds_default(self):
        return [0, 1, 0, 1, 0, 1]

    @on_trait_change('scalar, visible')
    def _update_plot(self):
        """Interpolate the current scalar and show it, or hide the plot."""
        if self.visible:
            mlab = self.scene.mlab
            self._setup_interpolator()
            interp = self.interpolator
            prop = interp.interpolate(self.scalar)
            if self.source is None:
                src = mlab.pipeline.scalar_field(interp.x, interp.y, interp.z,
                                                 prop)
                self.source = src
            else:
                self.source.mlab_source.reset(x=interp.x,
                                              y=interp.y,
                                              z=interp.z,
                                              scalars=prop)
            src = self.source

            if self.plot is None:
                if interp.dim == 3:
                    plot = mlab.pipeline.scalar_cut_plane(src)
                else:
                    plot = mlab.pipeline.surface(src)
                self.plot = plot
                scm = plot.module_manager.scalar_lut_manager
                # trait_set replaces the deprecated HasTraits.set.
                scm.trait_set(show_legend=self.show_legend,
                              use_default_name=False,
                              data_name=self.scalar)
                self.sync_trait('show_legend', scm, mutual=True)
            else:
                self.plot.visible = True
                scm = self.plot.module_manager.scalar_lut_manager
                scm.data_name = self.scalar
        else:
            if self.plot is not None:
                self.plot.visible = False
Example #19
0
class TDViz(HasTraits):
    """Interactive 3D viewer for FITS position-position-velocity datacubes.

    A FITS cube is loaded when ``fitsfile`` changes. The plot buttons cut the
    selected pixel range out of it, optionally stretch the velocity axis by
    ``zscale`` and render it with Mayavi. There are three renderings:
    iso-surfaces coloured by velocity, iso-surfaces coloured by intensity, or
    volume rendering. The scene can be annotated, spun, exported as a 3D model
    file, or turned into a GIF with ImageMagick's ``convert``.
    """
    fitsfile = File(filter=[u"*.fits"])
    plotbutton1 = Button(u"Plot")
    plotbutton2 = Button(u"Plot")
    plotbutton3 = Button(u"Plot")
    clearbutton = Button(u"Clear")
    scene = Instance(MlabSceneModel, ())
    rendering = Enum("Surface-Spectrum", "Surface-Intensity",
                     "Volume-Intensity")
    save_the_scene = Button(u"Save")
    save_in_file = Str("test.x3d")
    movie = Button(u"Movie")
    iteration = Int(0)
    quality = Int(8)
    delay = Int(0)
    angle = Int(360)
    spin = Button(u"Spin")
    zscale = Int(1)
    xstart = Int(0)
    xend = Int(1)
    ystart = Int(0)
    yend = Int(1)
    zstart = Int(0)
    zend = Int(1)
    datamin = Float(0.0)
    datamax = Float(1.0)
    opacity = Float(0.4)
    dist = Float(0.0)
    leng = Float(0.0)
    vsp = Float(0.0)
    contfile = File(filter=[u"*.fits"])

    view = View(HSplit(
        VGroup(
            Item("fitsfile",
                 label=u"Select a FITS datacube",
                 show_label=True,
                 editor=FileEditor(dialog_style='open')),
            Item("rendering",
                 tooltip=u"Choose the rendering type",
                 show_label=True),
            Item('plotbutton1',
                 tooltip=u"Plot 3D surfaces, color coded by velocities",
                 visible_when="rendering=='Surface-Spectrum'"),
            Item('plotbutton2',
                 tooltip=u"Plot 3D surfaces, color coded by intensities",
                 visible_when="rendering=='Surface-Intensity'"),
            Item('plotbutton3',
                 tooltip=u"Plot 3D dots, color coded by intensities",
                 visible_when="rendering=='Volume-Intensity'"),
            "clearbutton",
            HGroup(
                Item('xstart',
                     tooltip=u"starting pixel in X axis",
                     show_label=True,
                     springy=True),
                Item('xend',
                     tooltip=u"ending pixel in X axis",
                     show_label=True,
                     springy=True)),
            HGroup(
                Item('ystart',
                     tooltip=u"starting pixel in Y axis",
                     show_label=True,
                     springy=True),
                Item('yend',
                     tooltip=u"ending pixel in Y axis",
                     show_label=True,
                     springy=True)),
            HGroup(
                Item('zstart',
                     tooltip=u"starting pixel in Z axis",
                     show_label=True,
                     springy=True),
                Item('zend',
                     tooltip=u"ending pixel in Z axis",
                     show_label=True,
                     springy=True)),
            HGroup(
                Item('datamax',
                     tooltip=u"Maximum datapoint shown",
                     show_label=True,
                     springy=True),
                Item('datamin',
                     tooltip=u"Minimum datapoint shown",
                     show_label=True,
                     springy=True)),
            HGroup(
                Item('dist', tooltip=u"Put a distance in kpc",
                     show_label=True),
                Item('leng',
                     tooltip=
                     u"Put a non-zero bar length in pc to show the scale bar",
                     show_label=True),
                Item(
                    'vsp',
                    tooltip=
                    u"Put a non-zero velocity range in km/s to show the scale bar",
                    show_label=True)),
            HGroup(Item('zscale',
                        tooltip=u"Stretch the datacube in Z axis",
                        show_label=True),
                   Item('opacity',
                        tooltip=u"Opacity of the scene",
                        show_label=True),
                   show_labels=False),
            Item('_'),
            Item(
                "contfile",
                label=u"Add background contours",
                tooltip=
                u"This file must be of the same (first two) dimension as the datacube!!!",
                show_label=True,
                editor=FileEditor(dialog_style='open')),
            Item('_'),
            HGroup(Item("spin", tooltip=u"Spin 360 degrees", show_label=False),
                   Item("movie", tooltip="Make a GIF movie",
                        show_label=False)),
            HGroup(
                Item('iteration',
                     tooltip=u"number of iterations, 0 means inf.",
                     show_label=True),
                Item('quality',
                     tooltip=u"quality of plots, 0 is worst, 8 is good.",
                     show_label=True)),
            HGroup(
                Item('delay',
                     tooltip=u"time delay between frames, in millisecond.",
                     show_label=True),
                Item('angle', tooltip=u"angle the cube spins",
                     show_label=True)),
            Item('_'),
            HGroup(
                Item("save_the_scene",
                     tooltip=u"Save current scene in a 3D model file"),
                Item("save_in_file",
                     tooltip=u"3D model file name",
                     show_label=False),
                visible_when=
                "rendering=='Surface-Spectrum' or rendering=='Surface-Intensity'"
            ),
            show_labels=False),
        VGroup(Item(name='scene',
                    editor=SceneEditor(scene_class=MayaviScene),
                    resizable=True,
                    height=600,
                    width=600),
               show_labels=False)),
                resizable=True,
                title=u"TDViz")

    def _fitsfile_changed(self):
        """Load the selected FITS cube into ``self.data`` and reset the ranges.

        The primary HDU is read with its axes swapped (see comment below) and
        scaled by 1000. NaN/inf pixels are zeroed. ``datamin``/``datamax``
        and the ``*end`` pixel limits are reset to cover the whole cube.

        :raises ValueError: if the primary HDU has neither 3 nor 4 axes.
        """
        # The context manager closes the file handle; the scaled array built
        # below is an independent copy, so it stays valid after closing.
        with fits.open(self.fitsfile) as img:
            hdr = img[0].header
            dat = img[0].data
            naxis = hdr['NAXIS']
            # The three axes loaded by fits are: velo, dec, ra.
            # Swap the axes, RA<->velo.
            if naxis == 4:
                data = np.swapaxes(dat[0], 0, 2) * 1000.0
            elif naxis == 3:
                data = np.swapaxes(dat, 0, 2) * 1000.0
            else:
                raise ValueError(
                    'Expected a FITS cube with 3 or 4 axes, got NAXIS={}'.format(
                        naxis))

        data[~np.isfinite(data)] = 0.0  # zero out NaN and +/-inf pixels
        self.hdr = hdr
        self.data = data

        # float() replaces np.asscalar, which was removed from NumPy.
        self.datamax = float(np.max(data))
        self.datamin = float(np.min(data))
        self.xend = data.shape[0] - 1
        self.yend = data.shape[1] - 1
        self.zend = data.shape[2] - 1

    def loaddata(self):
        """Cut the selected sub-cube, stretch its third axis and record extents.

        1. Clamp the x/y/z pixel ranges to the cube, printing a warning for
           every value that had to be changed.
        2. Slice ``self.data[xstart:xend, ystart:yend, zstart:zend]``.
        3. Linearly resample every spectrum so the third axis becomes
           ``zscale`` times longer, storing the result in ``self.sregion``.
        4. Pull ``datamin``/``datamax`` inside the value range of the result.
        5. Store the pixel spans ``xrang``/``yrang``/``zrang`` and the
           world-coordinate ``extent`` [ra0, ra1, dec0, dec1, v0, v1] derived
           from the CRVAL/CDELT/CRPIX header keywords. Velocities are divided
           by 1e3 (presumably m/s -> km/s; confirm against the cube header).
           The third axis is flipped when CDELT3 > 0.
        """
        from scipy.interpolate import splrep
        from scipy.interpolate import splev

        channel = self.data
        # Reset the range if it is beyond the cube (same order as the GUI
        # fields: xstart, xend, ystart, yend, zstart, zend).
        for axis, (lo_name, hi_name) in enumerate((('xstart', 'xend'),
                                                    ('ystart', 'yend'),
                                                    ('zstart', 'zend'))):
            if getattr(self, lo_name) < 0:
                print('Wrong number!')
                setattr(self, lo_name, 0)
            last = channel.shape[axis] - 1
            if getattr(self, hi_name) > last:
                print('Wrong number!')
                setattr(self, hi_name, last)

        # Select a region (values were already scaled by 1000 on load).
        region = channel[self.xstart:self.xend, self.ystart:self.yend,
                         self.zstart:self.zend]

        # Stretch the cube along the third axis: it becomes `stretch` times
        # longer, filled by piecewise-linear (k=1) spline interpolation.
        vol = region.shape
        stretch = self.zscale
        sregion = np.empty((vol[0], vol[1], vol[2] * stretch))
        chanindex = np.linspace(0, vol[2] - 1, vol[2])
        chanindex2 = np.linspace(0, vol[2] - 1, vol[2] * stretch)
        # Visit every spectrum. The previous range(vol - 1) loops skipped the
        # last row and column, leaving uninitialised np.empty memory there.
        for j in range(vol[0]):
            for k in range(vol[1]):
                tck = splrep(chanindex, region[j, k, :], k=1)
                sregion[j, k, :] = splev(chanindex2, tck)
        self.sregion = sregion

        # Reset the max/min values so they lie within the data.
        smin = float(np.min(sregion))
        smax = float(np.max(sregion))
        if self.datamin < smin:
            print('Wrong number!')
            self.datamin = smin
        if self.datamax > smax:
            print('Wrong number!')
            self.datamax = smax
        self.xrang = abs(self.xstart - self.xend)
        self.yrang = abs(self.ystart - self.yend)
        self.zrang = abs(self.zstart - self.zend) * stretch

        # Keep a record of the coordinates (linear WCS: (pix+1-crpix)*cdelt+crval).
        hdr = self.hdr
        crval1, cdelt1, crpix1 = hdr['crval1'], hdr['cdelt1'], hdr['crpix1']
        crval2, cdelt2, crpix2 = hdr['crval2'], hdr['cdelt2'], hdr['crpix2']
        crval3, cdelt3, crpix3 = hdr['crval3'], hdr['cdelt3'], hdr['crpix3']

        ra_start = (self.xstart + 1 - crpix1) * cdelt1 + crval1
        ra_end = (self.xend + 1 - crpix1) * cdelt1 + crval1
        dec_start = (self.ystart + 1 - crpix2) * cdelt2 + crval2
        dec_end = (self.yend + 1 - crpix2) * cdelt2 + crval2
        vel_start = ((self.zstart + 1 - crpix3) * cdelt3 + crval3) / 1e3
        vel_end = ((self.zend + 1 - crpix3) * cdelt3 + crval3) / 1e3

        # Flip the V axis.
        if cdelt3 > 0:
            self.sregion = self.sregion[:, :, ::-1]
            vel_start, vel_end = vel_end, vel_start

        self.extent = [
            ra_start, ra_end, dec_start, dec_end, vel_start, vel_end
        ]

    @staticmethod
    def _format_ra(ra):
        """Format an astropy RA angle as ``'<h>h<m>m<s.s>s'``."""
        h, m, s = ra.hms
        return '{}h{}m{}s'.format(int(h), int(m), str(round(s, 1)))

    @staticmethod
    def _format_dec(dec):
        """Format an astropy Dec angle as ``'[-]<d>d<m>m<s.s>s'``.

        The sign is taken from the whole angle, so declinations between -1
        and 0 degrees keep their minus sign (``int(-0.0)`` would drop it).
        """
        sign = '-' if dec.degree < 0 else ''
        d, m, s = (abs(part) for part in dec.dms)
        return '{}{}d{}m{}s'.format(sign, int(d), int(m), str(round(s, 1)))

    def labels(self):
        """Add 3D text annotations, scale bars and an outline to the scene.

        Draws the axis names, an optional spatial scale bar (``leng`` pc at
        ``dist`` kpc), an optional velocity scale bar (``vsp`` km/s), the sky
        coordinates of the lower-left and upper-right corners, the velocity
        range, the axes and an outline. Requires :meth:`loaddata` and one of
        the plot handlers to have set ``xrang``/``yrang``/``zrang``,
        ``extent``, ``hdr`` and ``field``.
        """
        fontsize = max(self.xrang, self.yrang) / 40.
        tcolor = (1, 1, 1)

        def text(x, y, z, label, **kw):
            # Shared text3d style; camera-facing unless the caller overrides it.
            kw.setdefault('orient_to_camera', True)
            mlab.text3d(x, y, z, label, scale=fontsize, color=tcolor, **kw)

        text(self.xrang / 2, -10, self.zrang + 10, 'R.A.')
        text(-10, self.yrang / 2, self.zrang + 10, 'Decl.')
        text(-10, -10, self.zrang / 2 - 10, 'V (km/s)')

        # Add scale bars
        if self.leng != 0.0:
            if self.dist == 0.0:
                # Avoid ZeroDivisionError: a physical bar needs a distance.
                print('Set a non-zero distance to show the scale bar')
            else:
                distance = self.dist * 1e3  # kpc -> pc (per the GUI tooltips)
                length = self.leng
                # Angular size in degrees divided by the pixel size; assumes
                # CDELT1 is in degrees (TODO confirm for the input cubes).
                leng_pix = np.round(length / distance / np.pi * 180. /
                                    np.abs(self.hdr['cdelt1']))
                bar_x = [self.xrang - 20 - leng_pix, self.xrang - 20]
                bar_y = [self.yrang - 10, self.yrang - 10]
                bar_z = [0, 0]
                mlab.plot3d(bar_x, bar_y, bar_z, color=tcolor, tube_radius=1.)
                text(self.xrang - 30 - leng_pix, self.yrang - 25, 0,
                     '{:.2f} pc'.format(length), orient_to_camera=False)
        if self.vsp != 0.0:
            vspan = self.vsp
            vspan_pix = np.round(vspan / np.abs(self.hdr['cdelt3'] / 1e3))
            bar_x = [self.xrang, self.xrang]
            bar_y = [self.yrang - 10, self.yrang - 10]
            bar_z = np.array([5, 5 + vspan_pix]) * self.zscale
            mlab.plot3d(bar_x, bar_y, bar_z, color=tcolor, tube_radius=1.)
            text(self.xrang, self.yrang - 25, 10, '{:.1f} km/s'.format(vspan),
                 orient_to_camera=False, orientation=(0, 90, 0))

        # Label the coordinates of the lower-left and upper-right corners.
        corners = ((self.extent[0], self.extent[2], 0, 0),
                   (self.extent[1], self.extent[3], self.xrang, self.yrang))
        for ra0, dec0, x_pos, y_pos in corners:
            c = SkyCoord(ra=ra0 * u.degree, dec=dec0 * u.degree, frame='icrs')
            text(x_pos, -10, self.zrang + 5, self._format_ra(c.ra))
            text(-40, y_pos, self.zrang + 5, self._format_dec(c.dec))

        # V axis: the smaller velocity is drawn at the top (z = zrang).
        v0, v1 = sorted((self.extent[4], self.extent[5]))
        text(-10, -10, self.zrang, str(round(v0, 1)))
        text(-10, -10, 0, str(round(v1, 1)))

        mlab.axes(self.field,
                  ranges=self.extent,
                  x_axis_visibility=False,
                  y_axis_visibility=False,
                  z_axis_visibility=False)
        mlab.outline()

    def _plotbutton1_fired(self):
        """Plot intensity iso-surfaces colour coded by velocity channel."""
        mlab.clf()
        self.loaddata()
        # Clip to the [datamin, datamax] display range.
        np.clip(self.sregion, self.datamin, self.datamax, out=self.sregion)

        # Based on: http://docs.enthought.com/mayavi/mayavi/auto/example_atomic_orbital.html#example-atomic-orbital
        field = mlab.pipeline.scalar_field(
            self.sregion)  # Generate a scalar field
        # A separate colour array holding a velocity for every slice. The old
        # code aliased self.sregion and skipped the last slice, which left raw
        # intensities mixed into the colour scale.
        nslices = self.sregion.shape[2]
        # NOTE(review): extent[4] is in km/s while CDELT3 is not divided by
        # 1e3, and each slice is 1/zscale of a channel -- confirm the units.
        velocities = (self.extent[4] -
                      np.arange(nslices) * abs(self.hdr['cdelt3']))
        colored = np.empty_like(self.sregion)
        colored[...] = velocities  # broadcast along the last axis
        field.image_data.point_data.add_array(colored.T.ravel())
        field.image_data.point_data.get_array(1).name = 'color'
        field.update()

        field2 = mlab.pipeline.set_active_attribute(field,
                                                    point_scalars='scalar')
        contour = mlab.pipeline.contour(field2)
        contour2 = mlab.pipeline.set_active_attribute(contour,
                                                      point_scalars='color')

        mlab.pipeline.surface(contour2, colormap='jet', opacity=self.opacity)

        # Insert a continuum plot
        if self.contfile != '':
            with fits.open(self.contfile) as im:
                channel = im[0].data[0]
                region = np.swapaxes(
                    channel[self.ystart:self.yend, self.xstart:self.xend] *
                    1000., 0, 1)
            cont = mlab.contour3d(region, colormap='gist_ncar')
            cont.contour.minimum_contour = 5

        self.field = field2
        self.field.scene.render()
        self.labels()
        mlab.view(azimuth=0, elevation=0, distance='auto')
        mlab.show()

    def _plotbutton2_fired(self):
        """Plot intensity iso-surfaces between ``datamin`` and ``datamax``."""
        mlab.clf()
        self.loaddata()
        field = mlab.contour3d(self.sregion)  # Generate a scalar field
        field.contour.maximum_contour = self.datamax
        field.contour.minimum_contour = self.datamin
        field.actor.property.opacity = self.opacity

        self.field = field
        self.labels()
        mlab.view(azimuth=0, elevation=0, distance='auto')
        mlab.show()

    def _plotbutton3_fired(self):
        """Volume-render the cube between ``datamin`` and ``datamax``."""
        mlab.clf()
        self.loaddata()
        field = mlab.pipeline.scalar_field(
            self.sregion)  # Generate a scalar field
        mlab.pipeline.volume(field, vmax=self.datamax, vmin=self.datamin)

        self.field = field
        self.labels()
        mlab.view(azimuth=0, elevation=0, distance='auto')
        mlab.colorbar()
        mlab.show()

    def _save_the_scene_fired(self):
        """Save the current scene to ``save_in_file``."""
        mlab.savefig(self.save_in_file)

    def _movie_fired(self):
        """Render a spinning GIF of the current scene into ``./tenpfigz``.

        One JPEG frame is saved every 5 degrees of azimuth up to ``angle``.
        The frames are then joined by ImageMagick's ``convert`` into
        ``./tenpfigz/animation.gif`` using ``delay`` and ``iteration``.
        """
        import subprocess

        frame_dir = './tenpfigz'
        if os.path.exists(frame_dir):
            print("The chance of you using this name is really small...")
        else:
            os.makedirs(frame_dir)

        # Remove frames left over from a previous movie.
        for old_frame in glob.glob(os.path.join(frame_dir, '*.jpg')):
            if os.path.isfile(old_frame):
                os.remove(old_frame)

        nsteps = self.angle / 5.0
        # Zero-pad frame numbers to a common width so the alphabetical order
        # matches the rotation order. The old two-branch naming silently
        # dropped every frame from index 100 on.
        width = max(2, len(str(int(np.ceil(nsteps)))))
        template = os.path.join(frame_dir, 'screenshot{:0%dd}.jpg' % width)

        scene = self.field.scene
        # Quality of the movie: 0 is the worst, 8 is ok.
        scene.anti_aliasing_frames = self.quality
        scene.disable_render = True
        step = 0
        mlab.savefig(template.format(step))
        while step < nsteps:
            scene.camera.azimuth(5)
            scene.render()
            step += 1
            mlab.savefig(template.format(step))
        scene.disable_render = False

        frames = sorted(glob.glob(os.path.join(frame_dir, '*.jpg')))
        cmd = (['convert', '-delay', str(self.delay),
                '-loop', str(self.iteration)] + frames +
               [os.path.join(frame_dir, 'animation.gif')])
        try:
            subprocess.call(cmd)
        except OSError as err:  # e.g. ImageMagick is not installed
            print('Could not run convert: {}'.format(err))

    def _spin_fired(self):
        """Spin the scene once around (72 steps of 5 degrees) as an animation."""
        self.field.scene.disable_render = True

        @mlab.animate
        def anim():
            # The old loop tested a counter that was never incremented, so the
            # scene spun forever. Bound it to one full turn.
            for _ in range(72):
                self.field.scene.camera.azimuth(5)
                self.field.scene.render()
                yield

        # Keep a reference to the animator while the animation runs.
        self._spin_animator = anim()
        self.field.scene.disable_render = False

    def _clearbutton_fired(self):
        """Clear the Mayavi figure."""
        mlab.clf()
class ParticleArrayHelper(HasTraits):
    """
    This class manages a particle array and sets up the necessary
    plotting related information for it.
    """

    # The particle array we manage.
    particle_array = Instance(ParticleArray)

    # The name of the particle array.
    name = Str

    # Current time.
    time = Float(0.0)

    # The active scalar to view.
    scalar = Str('rho', desc='name of the active scalar to view')

    # The mlab scalar plot for this particle array.
    plot = Instance(PipelineBase)

    # The mlab vectors plot for this particle array.
    plot_vectors = Instance(PipelineBase)

    # List of available scalars in the particle array.
    scalar_list = List(Str)

    scene = Instance(MlabSceneModel)

    # Sync'd trait with the scalar lut manager.
    show_legend = Bool(False, desc='if the scalar legend is to be displayed')

    # Show all scalars.
    list_all_scalars = Bool(False, desc='if all scalars should be listed')

    # Sync'd trait with the dataset to turn on/off visibility.
    visible = Bool(True, desc='if the particle array is to be displayed')

    # Show the time of the simulation on screen.
    show_time = Bool(False, desc='if the current time is displayed')

    # Edit the scalars.
    edit_scalars = Button('More options ...')

    # Show vectors.
    show_vectors = Bool(False, desc='if vectors should be displayed')

    # Comma separated names of the three vector components to plot.
    vectors = Str('u, v, w',
                  enter_set=True,
                  auto_set=False,
                  desc='the vectors to display')

    mask_on_ratio = Int(3, desc='mask one in specified points')

    scale_factor = Float(1.0,
                         desc='scale factor for vectors',
                         enter_set=True,
                         auto_set=False)

    edit_vectors = Button('More options ...')

    # Private attribute to store the Text module.
    _text = Instance(PipelineBase)

    # Extra scalars to show.  These will be added and saved to the data if
    # needed.  Each name X needs a matching ``_add_X`` method.
    extra_scalars = List(Str)

    # Set to True when the particle array is updated with a new property say.
    updated = Event

    # Private attribute to store old value of visibility in case of empty
    # arrays.
    _old_visible = Bool(True)

    ########################################
    # View related code.
    view = View(
        Item(name='name', show_label=False, editor=TitleEditor()),
        Group(Group(
            Group(
                Item(name='visible'),
                Item(name='show_legend'),
                Item(name='scalar', editor=EnumEditor(name='scalar_list')),
                Item(name='list_all_scalars'),
                Item(name='show_time'),
                columns=2,
            ),
            Item(name='edit_scalars', show_label=False),
            label='Scalars',
        ),
              Group(
                  Item(name='show_vectors'),
                  Item(name='vectors'),
                  Item(name='mask_on_ratio'),
                  Item(name='scale_factor'),
                  Item(name='edit_vectors', show_label=False),
                  label='Vectors',
              ),
              layout='tabbed'))

    # Private protocol ############################################
    def _add_vmag(self, pa):
        """Add a ``vmag`` (velocity magnitude) property to *pa* if missing."""
        if 'vmag' not in pa.properties:
            if 'vmag2' in pa.output_property_arrays:
                vmag = numpy.sqrt(pa.vmag2)
            else:
                vmag = numpy.sqrt(pa.u**2 + pa.v**2 + pa.w**2)
            pa.add_property(name='vmag', data=vmag)
            if len(pa.output_property_arrays) > 0:
                # We do not call add_output_arrays when the default is empty
                # as if it is empty, all arrays are saved anyway. However,
                # adding just vmag in this case will mean that when the
                # particle array is saved it will only save vmag!  This is
                # not what we want, hence we add vmag *only* if the
                # output_property_arrays is non-zero length.
                pa.add_output_arrays(['vmag'])
            self.updated = True

    def _get_scalar(self, pa, scalar):
        """Return the requested scalar from the given particle array.

        Extra scalars are first materialised via their ``_add_<name>`` method.
        """
        if scalar in self.extra_scalars:
            method = getattr(self, '_add_' + scalar)
            method(pa)
        return getattr(pa, scalar)

    def _get_vectors_for_plot(self, vectors):
        """Return the ``(u, v, w)`` arrays named by the *vectors* string.

        Returns None when the string does not name exactly three components
        or when one of them is not an attribute of the particle array.
        """
        pa = self.particle_array
        comps = [x.strip() for x in vectors.split(',')]
        if len(comps) != 3:
            return None
        try:
            return tuple(getattr(pa, x) for x in comps)
        except AttributeError:
            return None

    #  Traits handlers #############################################
    def _edit_scalars_fired(self):
        # Nothing to edit until the scalar plot has been created.
        if self.plot is not None:
            self.plot.edit_traits()

    def _edit_vectors_fired(self):
        # The vectors plot only exists once vectors have been shown.
        if self.plot_vectors is not None:
            self.plot_vectors.edit_traits()

    def _particle_array_changed(self, old, pa):
        """Create or refresh the point plot when a new particle array arrives.

        An empty array hides the plot, remembering its visibility. The
        visibility is restored once the array becomes non-empty again.
        """
        self.name = pa.name

        self._list_all_scalars_changed(self.list_all_scalars)

        # Update the plot.
        x, y, z = pa.x, pa.y, pa.z
        s = self._get_scalar(pa, self.scalar)
        p = self.plot
        mlab = self.scene.mlab
        empty = len(x) == 0
        old_empty = len(old.x) == 0 if old is not None else True
        if p is None and not empty:
            src = mlab.pipeline.scalar_scatter(x, y, z, s)
            p = mlab.pipeline.glyph(src, mode='point', scale_mode='none')
            p.actor.property.point_size = 3
            scm = p.module_manager.scalar_lut_manager
            scm.set(show_legend=self.show_legend,
                    use_default_name=False,
                    data_name=self.scalar)
            self.sync_trait('visible', p, mutual=True)
            self.sync_trait('show_legend', scm, mutual=True)
            self.plot = p
        elif not empty:
            if len(x) == len(p.mlab_source.x):
                # Same number of points: update the arrays in place.
                p.mlab_source.set(x=x, y=y, z=z, scalars=s)
                if self.plot_vectors:
                    self._vectors_changed(self.vectors)
            else:
                # Point count changed: the source must be reset.
                vec = None
                if self.plot_vectors:
                    vec = self._get_vectors_for_plot(self.vectors)
                if vec is not None:
                    u, v, w = vec
                    p.mlab_source.reset(x=x,
                                        y=y,
                                        z=z,
                                        scalars=s,
                                        u=u,
                                        v=v,
                                        w=w)
                else:
                    # Also used when the vectors string is invalid, instead of
                    # failing to unpack None.
                    p.mlab_source.reset(x=x, y=y, z=z, scalars=s)
                p.mlab_source.update()

        if empty and not old_empty:
            if p is not None:
                src = p.parent.parent
                self._old_visible = src.visible
                src.visible = False
        if old_empty and not empty:
            if p is not None:
                p.parent.parent.visible = self._old_visible
                self._show_vectors_changed(self.show_vectors)

        # Setup the time.
        self._show_time_changed(self.show_time)

    def _scalar_changed(self, value):
        p = self.plot
        if p is not None:
            p.mlab_source.scalars = self._get_scalar(self.particle_array,
                                                     value)
            p.module_manager.scalar_lut_manager.data_name = value

    def _list_all_scalars_changed(self, list_all_scalars):
        """Rebuild ``scalar_list`` from the particle array's properties.

        All properties are listed when requested or when the array has no
        explicit output arrays. Otherwise only the output arrays are listed.
        ``extra_scalars`` are always included.
        """
        pa = self.particle_array
        output_arrays = pa.output_property_arrays
        if list_all_scalars or len(output_arrays) == 0:
            # list() around keys(): dict views cannot be added to lists on Py3.
            names = set(list(pa.properties.keys()))
        else:
            names = set(output_arrays)
        names.update(self.extra_scalars)
        self.scalar_list = sorted(names)

    def _show_time_changed(self, value):
        txt = self._text
        mlab = self.scene.mlab
        if value:
            if txt is not None:
                txt.visible = True
            elif self.plot is not None:
                mlab.get_engine().current_object = self.plot
                txt = mlab.text(0.01, 0.01, 'Time = 0.0', width=0.35)
                self._text = txt
                self._time_changed(self.time)
        else:
            if txt is not None:
                txt.visible = False

    def _vectors_changed(self, value):
        # The vectors can only be attached once the scalar plot exists.
        if self.plot is None:
            return
        vec = self._get_vectors_for_plot(value)
        if vec is not None:
            self.plot.mlab_source.set(vectors=numpy.c_[vec[0], vec[1], vec[2]])

    def _show_vectors_changed(self, value):
        pv = self.plot_vectors
        if pv is not None:
            pv.visible = value
        elif self.plot is not None and value:
            self._vectors_changed(self.vectors)
            pv = self.scene.mlab.pipeline.vectors(
                self.plot.mlab_source.m_data,
                mask_points=self.mask_on_ratio,
                scale_factor=self.scale_factor)
            self.plot_vectors = pv

    def _mask_on_ratio_changed(self, value):
        pv = self.plot_vectors
        if pv is not None:
            pv.glyph.mask_points.on_ratio = value

    def _scale_factor_changed(self, value):
        pv = self.plot_vectors
        if pv is not None:
            pv.glyph.glyph.scale_factor = value

    def _time_changed(self, value):
        txt = self._text
        if txt is not None:
            txt.text = 'Time = %.3e' % (value)

    def _extra_scalars_default(self):
        return ['vmag']
Example #21
0
class CSVGrapher(Loggable):
    """Plot columns of a delimited text file with regression graphs.

    ``Open`` reads the header row of a file chosen in a file dialog and
    creates a ``DataSelector`` for the first two columns. ``Plot`` parses the
    numeric rows (a blank row starts a new group) and opens one graph window
    per group.
    """
    open_button = Button('Open')
    plot_button = Button('Plot')
    as_series = Bool(False)

    _path = Str
    path = File
    data_selectors = List
    column_names = List

    stats = Str
    file_name = Property(depends_on='_path')
    short_name = Property(depends_on='_path')

    # Number of graph windows opened so far; used for titles and offsets.
    _graph_count = 0
    delimiter = Str(',')

    def quick_graph(self, p,
                    out_dir='/Users/argonlab2/Sandbox/baselines/auto-down50'):
        """Graph each detector column of *p* and save it as a PDF.

        :param p: path of a comma-delimited file whose header contains the
            detector columns 'H2', 'H1', 'AX', 'L1' and 'L2'.
        :param out_dir: directory receiving the ``scatter_obama<det>`` PDFs;
            the default is the location that used to be hard-coded.
        """
        kind = 'scatter'
        for det in ('H2', 'H1', 'AX', 'L1', 'L2'):
            g = self._gc(p, det, kind)
            g.edit_traits()
            g.save_pdf(os.path.join(out_dir, '{}_obama{}'.format(kind, det)))

    def _gc(self, p, det, kind):
        """Build a graph of detector column *det* from file *p*.

        Plots every 50th point of the raw column (series type *kind*) and of
        a 120-point smoothed copy, against column 0.
        """
        g = Graph(container_dict=dict(padding=5),
                  window_width=1000,
                  window_height=800,
                  window_x=40,
                  window_y=20)
        with open(p, 'r') as rfile:
            reader = csv.reader(rfile)
            header = next(reader)
            groups = self._parse_data(reader)

        # Each group is transposed so data[i] is column i; only the first
        # group is used and column 0 is the time axis.
        data = groups[0]
        x = data[0]
        y = data[header.index(det)]

        sy = smooth(y, window_len=120)

        # Keep every 50th sample.
        step = 50
        x = x[::step]
        y = y[::step]
        sy = sy[::step]

        g.new_plot(zoom=True, xtitle='Time (s)',
                   ytitle='{} Baseline Intensity (fA)'.format(det))
        g.new_series(x, y, type=kind, marker='dot', marker_size=2)
        g.new_series(x, sy, line_width=2)
        return g

    def add_column_selector(self):
        """Append a selector cloned from the last one, using the next column.

        Shows a warning dialog when there is no selector to clone or no
        further column to select.
        """
        csnames = self.column_names

        vind = len(self.data_selectors) + 1
        try:
            cs = self.data_selectors[-1].clone_traits(['fit', 'plot_type', 'parent'])
            cs.trait_set(index=csnames[0],
                         value=csnames[vind],
                         column_names=csnames)
        except IndexError:
            self.warning_dialog('No More Columns')
        else:
            self.data_selectors.append(cs)

    def remove_column_selector(self, cs):
        """Remove selector *cs* from ``data_selectors``."""
        self.data_selectors.remove(cs)

    # ===============================================================================
    # handlers
    # ===============================================================================
    def _open_button_fired(self):
        """Ask for a file, read its header and create the first selector.

        Cancelling the dialog leaves the current file and selectors untouched
        (previously a cancel with no file loaded tried to open '').
        """
        dlg = FileDialog(action='open', default_directory=paths.data_dir)
        if dlg.open() != OK:
            return

        self.data_selectors = []
        self._path = dlg.path

        # NOTE: mode 'U' is rejected by Python 3.11+; switch to newline=''
        # once Python 2 support is no longer needed.
        with open(self._path, 'U') as rfile:
            reader = csv.reader(rfile, delimiter=self.delimiter)
            self.column_names = names = next(reader)

        try:
            cs = DataSelector(column_names=names,
                              index=names[0],
                              value=names[1],
                              removable=False,
                              parent=self)
        except IndexError:
            # Fewer than two columns: most likely the wrong delimiter.
            self.warning_dialog('Invalid delimiter {} for {}'.format(DELIMITERS[self.delimiter],
                                                                     os.path.basename(self._path)))
        else:
            self.data_selectors.append(cs)

    def _parse_data(self, reader):
        """Split the rows of *reader* into numeric groups.

        Blank (or whitespace-only) rows separate groups; leading, trailing or
        repeated blank rows do not produce empty groups.

        :return: list of float arrays, each transposed so that ``data[i]`` is
            column ``i`` of that group.
        :raises ValueError: if a non-blank cell is not numeric.
        """
        groups = []
        rows = []
        for row in reader:
            cells = [cell.strip() for cell in row]
            if any(cells):
                rows.append([float(cell) for cell in cells])
            elif rows:
                groups.append(np.array(rows).transpose())
                rows = []
        if rows:
            groups.append(np.array(rows).transpose())
        return groups

    def _plot_button_fired(self):
        """Parse the current file and open one graph per data group."""
        # NOTE: see _open_button_fired regarding mode 'U'.
        with open(self._path, 'U') as rfile:
            reader = csv.reader(rfile, delimiter=self.delimiter)
            next(reader)  # skip the header row
            groups = self._parse_data(reader)

        for data in groups:
            self._show_plot(data)

    def _show_plot(self, data):
        """Open a graph window for one parsed group *data* (column-major).

        One series is drawn per data selector: all in one plot when
        ``as_series`` is set, otherwise in a stack of plots. If any selector
        has a fit, the graph is wrapped in a ``StatsGraph``.
        """
        cd = dict(padding=5, stack_order='top_to_bottom')
        csnames = self.column_names
        xmin = np.inf
        xmax = -np.inf

        if self.as_series:
            g = RegressionGraph(container_dict=cd)
            p = g.new_plot(padding=[50, 5, 5, 50],
                           xtitle='')
            p.value_range.tight_bounds = False
            p.value_range.margin = 0.1
        else:
            g = StackedRegressionGraph(container_dict=cd)

        regressable = False
        for i, csi in enumerate(self.data_selectors):
            if self.as_series:
                plotid = 0
            else:
                p = g.new_plot(padding=[50, 5, 5, 50])
                p.value_range.tight_bounds = False
                p.value_range.margin = 0.1
                plotid = i

            try:
                x = data[csnames.index(csi.index)]
                y = data[csnames.index(csi.value)]
                xmin = min(xmin, min(x))
                xmax = max(xmax, max(x))
                fit = csi.fit if csi.fit != NULL_STR else None
                g.new_series(x, y, fit=fit,
                             filter_outliers=csi.use_filter,
                             type=csi.plot_type,
                             plotid=plotid)

                g.set_x_title(csi.index, plotid=plotid)
                g.set_y_title(csi.value, plotid=plotid)
                if fit:
                    regressable = True

            except IndexError:
                # Selected column is missing from this group; skip it.
                pass

        # Only set limits when at least one series supplied x data;
        # otherwise xmin/xmax are still +inf/-inf.
        if xmin <= xmax:
            g.set_x_limits(xmin, xmax, pad='0.1')

        self._graph_count += 1
        gi = StatsGraph(graph=g) if regressable else g

        g._update_graph()

        gi.window_title = '{} Graph {}'.format(self.short_name, self._graph_count)
        # Cascade successive windows.
        gi.window_x = self._graph_count * 20 + 400
        gi.window_y = self._graph_count * 20 + 20
        gi.edit_traits()

    # ===============================================================================
    # property get/set
    # ===============================================================================
    def _get_file_name(self):
        """Path of the current file relative to ``paths.data_dir``, or ''."""
        if os.path.isfile(self._path):
            return os.path.relpath(self._path, paths.data_dir)
        else:
            return ''

    def _get_short_name(self):
        """Base name of the current file, or '' if it is not a file."""
        if os.path.isfile(self._path):
            return os.path.basename(self._path)
        else:
            return ''

    # ===============================================================================
    # views
    # ===============================================================================
    def traits_view(self):
        """Build the plotter window: options, file buttons and selector list."""
        v = View(Item('as_series'), Item('delimiter', editor=EnumEditor(values=DELIMITERS)),
                 HGroup(Item('open_button', show_label=False),
                        Item('plot_button', enabled_when='_path', show_label=False),
                        Item('file_name', show_label=False, style='readonly')),
                 Item('data_selectors', show_label=False, editor=ListEditor(mutable=False,
                                                                            style='custom',
                                                                            editor=InstanceEditor())),
                 resizable=True,
                 width=525,
                 height=225,
                 title='CSV Plotter')
        return v
class MayaviViewer(HasTraits):
    """
    Mayavi based viewer for particle arrays.

    Particles are either queried from a running solver (``live_mode``, via a
    ``MultiprocessingClient``) or loaded from saved solver output files
    listed in ``files``.
    """

    particle_arrays = List(Instance(ParticleArrayHelper), [])
    pa_names = List(Str, [])

    interpolator = Instance(InterpolatorView)

    # The default scalar to load up when running the viewer.
    scalar = Str("rho")

    scene = Instance(MlabSceneModel, ())

    ########################################
    # Traits to pull data from a live solver.
    live_mode = Bool(False,
                     desc='if data is obtained from a running solver '
                     'or from saved files')

    shell = Button('Launch Python Shell')
    host = Str('localhost', desc='machine to connect to')
    port = Int(8800, desc='port to use to connect to solver')
    authkey = Password('pysph', desc='authorization key')
    host_changed = Bool(True)
    client = Instance(MultiprocessingClient)
    controller = Property(depends_on='live_mode, host_changed')

    ########################################
    # Traits to view saved solver output.
    files = List(Str, [])
    directory = Directory()
    current_file = Str('', desc='the file being viewed currently')
    update_files = Button('Refresh')
    file_count = Range(low='_low',
                       high='_n_files',
                       value=0,
                       desc='the file counter')
    play = Bool(False, desc='if all files are played automatically')
    play_delay = Float(0.2, desc='the delay between loading files')
    loop = Bool(False, desc='if the animation is looped')
    # This is len(files) - 1.
    _n_files = Int(0)
    _low = Int(0)

    ########################################
    # Timer traits.
    timer = Instance(Timer)
    interval = Range(0.5,
                     20.0,
                     2.0,
                     desc='frequency in seconds with which plot is updated')

    ########################################
    # Solver info/control.
    current_time = Float(0.0, desc='the current time in the simulation')
    time_step = Float(0.0, desc='the time-step of the solver')
    iteration = Int(0, desc='the current iteration number')
    pause_solver = Bool(False, desc='if the solver should be paused')

    ########################################
    # Movie.
    record = Bool(False, desc='if PNG files are to be saved for animation')
    frame_interval = Range(1, 100, 5, desc='the interval between screenshots')
    movie_directory = Str
    # internal counters.
    _count = Int(0)
    _frame_count = Int(0)
    _last_time = Float
    _solver_data = Any
    _file_name = Str
    _particle_array_updated = Bool

    ########################################
    # The layout of the dialog created
    view = View(HSplit(
        Group(
            Group(Group(
                Item(name='directory'),
                Item(name='current_file'),
                Item(name='file_count'),
                HGroup(Item(name='play'),
                       Item(name='play_delay', label='Delay', resizable=True),
                       Item(name='loop'),
                       Item(name='update_files', show_label=False),
                       padding=0),
                padding=0,
                label='Saved Data',
                selected=True,
                enabled_when='not live_mode',
            ),
                  Group(
                      Item(name='live_mode'),
                      Group(
                          Item(name='host'),
                          Item(name='port'),
                          Item(name='authkey'),
                          enabled_when='live_mode',
                      ),
                      label='Connection',
                  ),
                  layout='tabbed'),
            Group(
                Group(
                    Item(name='current_time'),
                    Item(name='time_step'),
                    Item(name='iteration'),
                    Item(name='pause_solver', enabled_when='live_mode'),
                    Item(name='interval', enabled_when='not live_mode'),
                    label='Solver',
                ),
                Group(
                    Item(name='record'),
                    Item(name='frame_interval'),
                    Item(name='movie_directory'),
                    label='Movie',
                ),
                layout='tabbed',
            ),
            Group(Item(name='particle_arrays',
                       style='custom',
                       show_label=False,
                       editor=ListEditor(use_notebook=True,
                                         deletable=False,
                                         page_name='.name')),
                  Item(name='interpolator', style='custom', show_label=False),
                  layout='tabbed'),
            Item(name='shell', show_label=False),
        ),
        Group(
            Item('scene',
                 editor=SceneEditor(scene_class=MayaviScene),
                 height=400,
                 width=600,
                 show_label=False), )),
                resizable=True,
                title='PySPH Particle Viewer',
                height=640,
                width=1024,
                handler=ViewerHandler)

    ######################################################################
    # `MayaviViewer` interface.
    ######################################################################
    def on_close(self):
        """Write pending particle-array edits back to the current file."""
        self._handle_particle_array_updates()

    @on_trait_change('scene:activated')
    def start_timer(self):
        """Start the polling timer in live mode (no-op for saved files)."""
        if not self.live_mode:
            # No need for the timer if we are rendering files.
            return

        # Just accessing the timer will start it.
        t = self.timer
        if not t.IsRunning():
            t.Start(int(self.interval * 1000))

    @on_trait_change('scene:activated')
    def update_plot(self):
        """Fetch time, step and particle arrays from the live solver.

        Does nothing when not in live mode, when the solver is paused, or
        when no controller is available.
        """
        # No need to do this if files are being used.
        if not self.live_mode:
            return

        # do not update if solver is paused
        if self.pause_solver:
            return

        if self.client is None:
            # Force the controller property to attempt a reconnect.
            self.host_changed = True

        controller = self.controller
        if controller is None:
            return

        self.current_time = t = controller.get_t()
        self.time_step = controller.get_dt()
        self.iteration = controller.get_count()

        arrays = []
        for idx, name in enumerate(self.pa_names):
            pa = controller.get_named_particle_array(name)
            arrays.append(pa)
            pah = self.particle_arrays[idx]
            pah.trait_set(particle_array=pa, time=t)

        self.interpolator.particle_arrays = arrays

        if self.record:
            self._do_snap()

    def run_script(self, path):
        """Execute the Python script at *path* in the viewer's shell namespace.

        The namespace is the one built by ``_get_shell_namespace``. The file
        is run with ``exec``, so only pass trusted scripts.
        """
        with open(path) as fp:
            data = fp.read()
        ns = self._get_shell_namespace()
        exec(compile(data, path, 'exec'), ns)

    ######################################################################
    # Private interface.
    ######################################################################
    def _do_snap(self):
        """Save a PNG frame of the scene every ``frame_interval`` calls.

        Skips when there are no particle arrays or the time has not advanced
        since the last saved frame. Frames go to ``movie_directory``, which
        defaults to ``<solver output dir>/movie`` in live mode.
        """
        p_arrays = self.particle_arrays
        if len(p_arrays) == 0:
            return
        if self.current_time == self._last_time:
            return

        if len(self.movie_directory) == 0:
            controller = self.controller
            if controller is None:
                # No solver to ask for an output directory.
                return
            output_dir = controller.get_output_directory()
            movie_dir = os.path.join(output_dir, 'movie')
            self.movie_directory = movie_dir
        else:
            movie_dir = self.movie_directory
        if not os.path.exists(movie_dir):
            os.makedirs(movie_dir)

        interval = self.frame_interval
        count = self._count
        if count % interval == 0:
            fname = 'frame%06d.png' % (self._frame_count)
            p_arrays[0].scene.save_png(os.path.join(movie_dir, fname))
            self._frame_count += 1
            self._last_time = self.current_time
        self._count += 1

    @on_trait_change('host,port,authkey')
    def _mark_reconnect(self):
        """Request a reconnect (live mode only) when connection settings change."""
        if self.live_mode:
            self.host_changed = True

    @cached_property
    def _get_controller(self):
        ''' get the controller, also sets the iteration count '''
        if not self.live_mode:
            return None

        reconnect = self.host_changed
        if not reconnect:
            # Probe the existing connection; drop it if it is unusable.
            try:
                controller = self.client.controller
            except Exception as e:
                logger.info('Error: no connection or connection closed: '
                            'reconnecting: %s', e)
                reconnect = True
                self.client = None
            else:
                try:
                    controller.get_count()
                except IOError:
                    self.client = None
                    reconnect = True

        if reconnect:
            self.host_changed = False
            try:
                if MultiprocessingClient.is_available((self.host, self.port)):
                    self.client = MultiprocessingClient(address=(self.host,
                                                                 self.port),
                                                        authkey=self.authkey)
                else:
                    logger.info('Could not connect: Multiprocessing Interface'
                                ' not available on %s:%s',
                                self.host, self.port)
                    return None
            except Exception as e:
                logger.info('Could not connect: check if solver is '
                            'running:%s', e)
                return None
            self.iteration = self.client.controller.get_count()

        if self.client is None:
            return None
        else:
            return self.client.controller

    def _client_changed(self, old, new):
        """Rebuild the particle-array helpers for a newly connected client."""
        if not self.live_mode:
            return

        self._clear()
        if new is None:
            return
        self.pa_names = new.controller.get_particle_array_names()

        self.particle_arrays = [
            self._make_particle_array_helper(self.scene, x)
            for x in self.pa_names
        ]
        self.interpolator = InterpolatorView(scene=self.scene)
        # Turn on the legend for the first particle array.
        if len(self.particle_arrays) > 0:
            self.particle_arrays[0].trait_set(show_legend=True, show_time=True)

    def _timer_event(self):
        """Timer callback for live mode: refresh the plot."""
        # catch all Exceptions else timer will stop
        try:
            self.update_plot()
        except Exception as e:
            logger.info('Exception: %s caught in timer_event', e)

    def _interval_changed(self, value):
        """Restart a running timer with the new interval (seconds)."""
        t = self.timer
        if t is None:
            return
        if t.IsRunning():
            t.Stop()
            t.Start(int(value * 1000))

    def _timer_default(self):
        return Timer(int(self.interval * 1000), self._timer_event)

    def _pause_solver_changed(self, value):
        """Pause or resume the live solver."""
        if self.live_mode:
            c = self.controller
            if c is None:
                return
            if value:
                c.pause_on_next()
            else:
                c.cont()

    def _record_changed(self, value):
        if value:
            self._do_snap()

    def _files_changed(self, value):
        """Reset counters and load the first file of a new file list."""
        if len(value) == 0:
            return
        d = os.path.dirname(os.path.abspath(value[0]))
        self.movie_directory = os.path.join(d, 'movie')
        # Update the directory without triggering _directory_changed.
        self.trait_set(directory=d, trait_change_notify=False)
        self._n_files = len(value) - 1
        self._frame_count = 0
        self._count = 0
        self.frame_interval = 1
        fc = self.file_count
        self.file_count = 0
        if fc == 0:
            # Force an update when our original file count is 0.
            self._file_count_changed(fc)
        t = self.timer
        if not self.live_mode:
            if t.IsRunning():
                t.Stop()
        elif not t.IsRunning():
            t.Start(int(self.interval * 1000))

    def _file_count_changed(self, value):
        """Load file number *value* and show its particle arrays."""
        # Save out any updates for the previous file if needed.
        self._handle_particle_array_updates()
        # Load the new file.
        fname = self.files[value]
        self._file_name = fname
        self.current_file = os.path.basename(fname)
        # Code to read the file, create particle array and setup the helper.
        data = load(fname)
        solver_data = data["solver_data"]
        arrays = data["arrays"]
        self._solver_data = solver_data
        self.current_time = t = float(solver_data['t'])
        self.time_step = float(solver_data['dt'])
        self.iteration = int(solver_data['count'])
        names = list(arrays.keys())
        pa_names = self.pa_names

        if len(pa_names) == 0:
            # First file: create one helper per array.
            self.interpolator = InterpolatorView(scene=self.scene)
            self.pa_names = names
            pas = []
            for name in names:
                pa = arrays[name]
                pah = self._make_particle_array_helper(self.scene, name)
                # Must set this after setting the scene.
                pah.trait_set(particle_array=pa, time=t)
                pas.append(pah)
            self.particle_arrays = pas
        else:
            for idx, name in enumerate(pa_names):
                pa = arrays[name]
                pah = self.particle_arrays[idx]
                pah.trait_set(particle_array=pa, time=t)

        self.interpolator.particle_arrays = list(arrays.values())

        if self.record:
            self._do_snap()

    def _loop_changed(self, value):
        if value and self.play:
            self._play_changed(self.play)

    def _play_changed(self, value):
        """Switch the timer between file playback and live polling."""
        t = self.timer
        t.Stop()
        if value:
            t.callable = self._play_event
            t.Start(int(1000 * self.play_delay))
        else:
            t.callable = self._timer_event

    def _clear(self):
        """Forget the particle-array names and remove everything from the scene."""
        self.pa_names = []
        self.scene.mayavi_scene.children[:] = []

    def _play_event(self):
        """Timer callback for playback: advance to the next file."""
        nf = self._n_files
        pc = self.file_count
        pc += 1
        if pc > nf:
            if self.loop:
                pc = 0
            else:
                self.timer.Stop()
                pc = nf
        self.file_count = pc
        self._handle_particle_array_updates()

    def _play_delay_changed(self):
        if self.play:
            self._play_changed(self.play)

    def _scalar_changed(self, value):
        for pa in self.particle_arrays:
            pa.scalar = value

    def _update_files_fired(self):
        """Re-scan for output files matching the current one."""
        if not self.files:
            # Nothing loaded yet, so there is no pattern to re-scan with.
            return
        fc = self.file_count
        files = glob_files(self.files[fc])
        sort_file_list(files)
        self.files = files
        self.file_count = fc
        if self.play:
            self._play_changed(self.play)

    def _shell_fired(self):
        ns = self._get_shell_namespace()
        obj = PythonShellView(ns=ns)
        obj.edit_traits()

    def _get_shell_namespace(self):
        """Names exposed to the Python shell and to ``run_script``."""
        return dict(viewer=self,
                    particle_arrays=self.particle_arrays,
                    interpolator=self.interpolator,
                    scene=self.scene,
                    mlab=self.scene.mlab)

    def _directory_changed(self, d):
        """Load files in *d* having the same extension as the current files."""
        if not self.files:
            # The extension is taken from the current files; none are loaded.
            return
        ext = os.path.splitext(self.files[-1])[1]
        files = glob.glob(os.path.join(d, '*' + ext))
        if len(files) > 0:
            self._clear()
            sort_file_list(files)
            self.files = files
            # Valid file indices run from 0 to len(files) - 1.
            self.file_count = min(self.file_count, len(files) - 1)

    def _live_mode_changed(self, value):
        if value:
            self._file_name = ''
            self.client = None
            self._clear()
            self._mark_reconnect()
            self.start_timer()
        else:
            self.client = None
            self._clear()
            self.timer.Stop()

    def _particle_array_helper_updated(self, value):
        self._particle_array_updated = True

    def _handle_particle_array_updates(self):
        """Dump edited particle arrays back to the current file, if any."""
        # Called when the particle array helper fires an updated event.
        if self._particle_array_updated and self._file_name:
            sd = self._solver_data
            arrays = [x.particle_array for x in self.particle_arrays]
            dump(self._file_name, arrays, sd)
            self._particle_array_updated = False

    def _make_particle_array_helper(self, scene, name):
        """Create a helper for array *name* and listen for its ``updated`` event."""
        pah = ParticleArrayHelper(scene=scene, name=name, scalar=self.scalar)
        pah.on_trait_change(self._particle_array_helper_updated, 'updated')
        return pah
Example #23
0
class Demo(HasTraits):
    """Showcase ``CSVListEditor`` on lists of several item types.

    The left column edits each list with a ``CSVListEditor`` (with various
    options); the right column shows the lists in disabled text editors.
    """

    list1 = List(Int)

    list2 = List(Float)

    list3 = List(Str, maxlen=3)

    list4 = List(Enum('red', 'green', 'blue', 2, 3))

    list5 = List(Range(low=0.0, high=10.0))

    # 'low' and 'high' are used to demonstrate lists containing dynamic ranges.
    low = Float(0.0)
    high = Float(1.0)

    list6 = List(Range(low=-1.0, high='high'))

    list7 = List(Range(low='low', high='high'))

    pop1 = Button("Pop from first list")

    sort1 = Button("Sort first list")

    # String form of list1 (see _get_list1str).
    list1str = Property(Str, depends_on='list1')

    traits_view = View(
        HGroup(
            # Left column: the CSVListEditor examples.
            VGroup(
                Item('list1', label="List(Int)",
                     editor=CSVListEditor(ignore_trailing_sep=False),
                     tooltip='options: ignore_trailing_sep=False'),
                Item('list1', label="List(Int)", style='readonly',
                     editor=CSVListEditor()),
                Item('list2', label="List(Float)",
                     editor=CSVListEditor(enter_set=True, auto_set=False),
                     tooltip='options: enter_set=True, auto_set=False'),
                Item('list3', label="List(Str, maxlen=3)",
                     editor=CSVListEditor()),
                Item('list4',
                     label="List(Enum('red', 'green', 'blue', 2, 3))",
                     editor=CSVListEditor(sep=None),
                     tooltip='options: sep=None'),
                Item('list5', label="List(Range(low=0.0, high=10.0))",
                     editor=CSVListEditor()),
                Item('list6', label="List(Range(low=-1.0, high='high'))",
                     editor=CSVListEditor()),
                Item('list7', label="List(Range(low='low', high='high'))",
                     editor=CSVListEditor()),
                springy=True,
            ),
            # Right column: one disabled text box per list, in the same order
            # as the left column.
            VGroup(*[
                UItem(trait_name, editor=TextEditor(),
                      enabled_when='False', width=240)
                for trait_name in ('list1str', 'list1str', 'list2', 'list3',
                                   'list4', 'list5', 'list6', 'list7')
            ]),
        ),
        '_',
        HGroup('low', 'high', spring, UItem('pop1'), UItem('sort1')),
        Heading("Notes"),
        Label("Hover over a list to see which editor options are set, "
              "if any."),
        Label("The editor of the first list, List(Int), uses "
              "ignore_trailing_sep=False, so a trailing comma is "
              "an error."),
        Label("The second list is a read-only view of the first list."),
        Label("The editor of the List(Float) example has enter_set=True "
              "and auto_set=False; press Enter to validate."),
        Label("The List(Str) example will accept at most 3 elements."),
        Label("The editor of the List(Enum(...)) example uses sep=None, "
              "i.e. whitespace acts as a separator."),
        Label("The last two List(Range(...)) examples take one or both "
              "of their limits from the Low and High fields below."),
        width=720,
        title="CSVListEditor Demonstration",
    )

    def _list1_default(self):
        """Initial contents of ``list1``."""
        return [1, 4, 0, 10]

    def _get_list1str(self):
        """Getter for ``list1str``: the ``str()`` of ``list1``."""
        return '%s' % (self.list1,)

    def _pop1_fired(self):
        """Pop the last element of ``list1`` and print it; no-op when empty."""
        if not self.list1:
            return
        print(self.list1.pop())

    def _sort1_fired(self):
        """Sort ``list1`` in place."""
        self.list1.sort()
Example #24
0
class Kit2FiffPanel(HasPrivateTraits):
    """Control panel for kit2fiff conversion.

    Most editable traits are delegated to the wrapped ``model``.  Saving is
    asynchronous: ``_save_as_fired`` puts ``(raw, fname)`` jobs on ``queue``
    and a daemon worker thread (started in ``__init__``) writes them with
    ``raw.save``, reporting progress through the ``queue_*`` traits.
    """
    model = Instance(Kit2FiffModel)

    # model copies for view
    use_mrk = DelegatesTo('model')
    sqd_file = DelegatesTo('model')
    hsp_file = DelegatesTo('model')
    fid_file = DelegatesTo('model')
    stim_chs = DelegatesTo('model')
    stim_chs_manual = DelegatesTo('model')
    stim_slope = DelegatesTo('model')

    # info
    can_save = DelegatesTo('model')
    sqd_fname = DelegatesTo('model')
    hsp_fname = DelegatesTo('model')
    fid_fname = DelegatesTo('model')

    # Source Files
    reset_dig = Button

    # Visualization
    scene = Instance(MlabSceneModel)
    fid_obj = Instance(PointObject)
    elp_obj = Instance(PointObject)
    hsp_obj = Instance(PointObject)

    # Output
    save_as = Button(label='Save FIFF...')
    clear_all = Button(label='Clear All')
    # Pending (raw, fname) save jobs consumed by the worker thread.  The
    # right-hand side is evaluated before the name is bound in the class
    # body, so ``queue.Queue`` here still refers to the stdlib module.
    queue = Instance(queue.Queue, ())
    queue_feedback = Str('')
    queue_current = Str('')
    queue_len = Int(0)
    queue_len_str = Property(Str, depends_on=['queue_len'])
    # Message of the last exception raised while saving, if any.
    error = Str('')

    view = View(
        VGroup(
            VGroup(Item('sqd_file', label="Data"),
                   Item('sqd_fname', show_label=False, style='readonly'),
                   Item('hsp_file', label='Dig Head Shape'),
                   Item('hsp_fname', show_label=False, style='readonly'),
                   Item('fid_file', label='Dig Points'),
                   Item('fid_fname', show_label=False, style='readonly'),
                   Item('reset_dig',
                        label='Clear Digitizer Files',
                        show_label=False),
                   Item('use_mrk', editor=use_editor, style='custom'),
                   label="Sources",
                   show_border=True),
            VGroup(Item('stim_slope',
                        label="Event Onset",
                        style='custom',
                        editor=EnumEditor(values={
                            '+': '2:Peak (0 to 5 V)',
                            '-': '1:Trough (5 to 0 V)'
                        },
                                          cols=2),
                        help="Whether events are marked by a decrease "
                        "(trough) or an increase (peak) in trigger "
                        "channel values"),
                   Item('stim_chs',
                        label="Binary Coding",
                        style='custom',
                        editor=EnumEditor(values={
                            '>': '1:1 ... 128',
                            '<': '3:128 ... 1',
                            'man': '2:Manual'
                        },
                                          cols=2),
                        help="Specifies the bit order in event "
                        "channels. Assign the first bit (1) to the "
                        "first or the last trigger channel."),
                   Item('stim_chs_manual',
                        label='Stim Channels',
                        style='custom',
                        visible_when="stim_chs == 'man'"),
                   label='Events',
                   show_border=True),
            HGroup(Item('save_as', enabled_when='can_save'),
                   spring,
                   'clear_all',
                   show_labels=False),
            Item('queue_feedback', show_label=False, style='readonly'),
            Item('queue_current', show_label=False, style='readonly'),
            Item('queue_len_str', show_label=False, style='readonly'),
        ))

    def __init__(self, *args, **kwargs):
        super(Kit2FiffPanel, self).__init__(*args, **kwargs)

        # setup save worker
        def worker():
            """Consume save jobs from ``self.queue`` forever."""
            while True:
                raw, fname = self.queue.get()
                basename = os.path.basename(fname)
                self.queue_len -= 1
                self.queue_current = 'Processing: %s' % basename

                # task
                try:
                    raw.save(fname, overwrite=True)
                except Exception as err:
                    # Keep the worker alive; surface the failure via traits.
                    self.error = str(err)
                    res = "Error saving: %s"
                else:
                    res = "Saved: %s"

                # finalize
                self.queue_current = ''
                self.queue_feedback = res % basename
                self.queue.task_done()

        # Daemon thread so a pending queue does not block interpreter exit.
        t = Thread(target=worker)
        t.daemon = True
        t.start()

        # setup mayavi visualization: each point set mirrors a model trait
        # one-way (mutual=False) and follows the head->device transform.
        m = self.model
        self.fid_obj = PointObject(scene=self.scene,
                                   color=(25, 225, 25),
                                   point_scale=5e-3)
        m.sync_trait('fid', self.fid_obj, 'points', mutual=False)
        m.sync_trait('head_dev_trans', self.fid_obj, 'trans', mutual=False)

        self.elp_obj = PointObject(scene=self.scene,
                                   color=(50, 50, 220),
                                   point_scale=1e-2,
                                   opacity=.2)
        m.sync_trait('elp', self.elp_obj, 'points', mutual=False)
        m.sync_trait('head_dev_trans', self.elp_obj, 'trans', mutual=False)

        self.hsp_obj = PointObject(scene=self.scene,
                                   color=(200, 200, 200),
                                   point_scale=2e-3)
        m.sync_trait('hsp', self.hsp_obj, 'points', mutual=False)
        m.sync_trait('head_dev_trans', self.hsp_obj, 'trans', mutual=False)

        self.scene.camera.parallel_scale = 0.15
        self.scene.mlab.view(0, 0, .15)

    def _clear_all_fired(self):
        """Reset the model via ``model.clear_all()``."""
        self.model.clear_all()

    @cached_property
    def _get_queue_len_str(self):
        """Human-readable queue length; empty when nothing is queued."""
        if self.queue_len:
            return "Queue length: %i" % self.queue_len
        else:
            return ''

    def _reset_dig_fired(self):
        """Clear the digitizer head-shape and point files."""
        self.reset_traits(['hsp_file', 'fid_file'])

    def _save_as_fired(self):
        """Ask for an output path and queue the raw data for saving.

        Raises whatever ``model.get_raw()`` raises, after showing it in an
        error dialog.
        """
        # create raw
        try:
            raw = self.model.get_raw()
        except Exception as err:
            error(None, str(err), "Error Creating KIT Raw")
            raise

        # find default path
        stem, _ = os.path.splitext(self.sqd_file)
        if not stem.endswith('raw'):
            stem += '-raw'
        default_path = stem + '.fif'

        # save as dialog
        dlg = FileDialog(action="save as",
                         wildcard="fiff raw file (*.fif)|*.fif",
                         default_path=default_path)
        dlg.open()
        if dlg.return_code != OK:
            return

        fname = dlg.path
        if not fname.endswith('.fif'):
            fname += '.fif'
            # The dialog only vetted the name without the extension, so the
            # appended '.fif' path may clobber an existing file: ask first.
            if os.path.exists(fname):
                answer = confirm(
                    None, "The file %r already exists. Should it "
                    "be replaced?" % fname, "Overwrite File?")
                if answer != YES:
                    return

        self.queue.put((raw, fname))
        self.queue_len += 1
Example #25
0
class TableFilterEditor(HasTraits):
    """ An editor that manages table filters.

    Non-template filters of the associated table editor's factory are cloned
    into ``filters`` when ``editor`` is set; template filters (plus all
    current filters) are offered as bases for new filters via ``templates``.
    """

    # -------------------------------------------------------------------------
    #  Trait definitions:
    # -------------------------------------------------------------------------

    #: TableEditor this editor is associated with
    editor = Instance(TableEditor)

    #: The list of filters
    filters = List(TableFilter)

    #: The list of available templates from which filters can be created.
    #: Observes list mutations too ("filters.items"), so the cached value is
    #: refreshed after filters are appended or deleted, not only reassigned.
    templates = Property(List(TableFilter), observe="filters.items")

    #: The currently selected filter template
    selected_template = Instance(TableFilter)

    #: The currently selected filter
    selected_filter = Instance(TableFilter, allow_none=True)

    #: The view to use for the current filter
    selected_filter_view = Property(observe="selected_filter")

    #: Buttons for add/removing filters
    add_button = Button("New")
    remove_button = Button("Delete")

    # The default view for this editor
    view = View(
        Group(
            Group(
                Group(
                    Item("add_button", enabled_when="selected_template"),
                    Item(
                        "remove_button",
                        enabled_when="len(templates) > 1 and "
                        "selected_filter is not None",
                    ),
                    orientation="horizontal",
                    show_labels=False,
                ),
                Label("Base filter for new filters:"),
                Item("selected_template", editor=EnumEditor(name="templates")),
                Item(
                    "selected_filter",
                    style="custom",
                    editor=EnumEditor(name="filters", mode="list"),
                ),
                show_labels=False,
            ),
            Item(
                "selected_filter",
                width=0.75,
                style="custom",
                editor=InstanceEditor(view_name="selected_filter_view"),
            ),
            id="TableFilterEditorSplit",
            show_labels=False,
            layout="split",
            orientation="horizontal",
        ),
        id="traitsui.qt4.table_editor.TableFilterEditor",
        buttons=["OK", "Cancel"],
        kind="livemodal",
        resizable=True,
        width=800,
        height=400,
        title="Customize filters",
    )

    # -------------------------------------------------------------------------
    #  Private methods:
    # -------------------------------------------------------------------------

    # -- Trait Property getter/setters ----------------------------------------

    @cached_property
    def _get_selected_filter_view(self):
        """Edit view for the selected filter, or None if none is selected.

        The filter's ``edit_view`` is given the source item shown in the
        first row of the (proxy) model, or None when that row is invalid.
        """
        view = None
        if self.selected_filter:
            model = self.editor.model
            index = model.mapToSource(model.index(0, 0))
            if index.isValid():
                obj = self.editor.items()[index.row()]
            else:
                obj = None
            view = self.selected_filter.edit_view(obj)
        return view

    @cached_property
    def _get_templates(self):
        """Factory template filters followed by the current filters."""
        templates = [f for f in self.editor.factory.filters if f.template]
        templates.extend(self.filters)
        return templates

    # -- Trait Change Handlers ------------------------------------------------

    def _editor_changed(self):
        """Clone the factory's non-template filters and pick a template."""
        self.filters = [
            f.clone_traits() for f in self.editor.factory.filters
            if not f.template
        ]
        templates = self.templates
        # Avoid IndexError when the factory defines no filters at all.
        self.selected_template = templates[0] if templates else None

    def _add_button_fired(self):
        """ Create a new filter based on the selected template and select it.
        """
        new_filter = self.selected_template.clone_traits()
        new_filter.template = False
        new_filter.name = new_filter._name = "New filter"
        self.filters.append(new_filter)
        self.selected_filter = new_filter

    def _remove_button_fired(self):
        """ Delete the currently selected filter.

        The filter at the same position (or None) becomes the new selection.
        If the removed filter was also the selected template, the first
        remaining template is selected instead.
        """
        removed = self.selected_filter
        if removed is None:
            # Nothing selected: filters.index(None) would raise ValueError.
            return

        if self.selected_template is removed:
            # Never fall back to the filter that is about to disappear.
            remaining = [t for t in self.templates if t is not removed]
            self.selected_template = remaining[0] if remaining else None

        index = self.filters.index(removed)
        del self.filters[index]
        if index < len(self.filters):
            self.selected_filter = self.filters[index]
        else:
            self.selected_filter = None

    @observe("selected_filter:name")
    def _update_filter_list(self, event):
        """ A hack to make the EnumEditor watching the list of filters refresh
            their text when the name of the selected filter changes.
        """
        filters = self.filters
        self.filters = []
        self.filters = filters
Example #26
0
class BaselineView(HasTraits):
    """Table and north/east scatter plot of baseline (NED) solutions.

    ``__init__`` registers SBP callbacks on ``link``.  The callbacks update
    the table, a rolling position history (the ``n``/``e``/``d``/``mode``
    arrays, newest entry at index 0) and, when ``logging_b`` is set, a CSV
    log.  ``solution_draw`` is scheduled every 0.2 s and redraws the plot via
    ``GUI.invoke_later`` while ``running`` is true.
    """

    # This mapping should match the flag definitions in libsbp for
    # the MsgBaselineNED message. While this isn't strictly necessary
    # it helps avoid confusion
    # NOTE(review): the mapping referred to (presumably mode_dict /
    # color_dict) is defined outside this class.

    python_console_cmds = Dict()

    table = List()

    logging_b = Bool(False)
    directory_name_b = File

    plot = Instance(Plot)
    plot_data = Instance(ArrayPlotData)

    running = Bool(True)
    zoomall = Bool(False)
    position_centered = Bool(False)

    clear_button = SVGButton(
        label='',
        tooltip='Clear',
        filename=resource_filename('console/images/iconic/x.svg'),
        width=16,
        height=16)
    zoomall_button = SVGButton(
        label='',
        tooltip='Zoom All',
        toggle=True,
        filename=resource_filename('console/images/iconic/fullscreen.svg'),
        width=16,
        height=16)
    center_button = SVGButton(
        label='',
        tooltip='Center on Baseline',
        toggle=True,
        filename=resource_filename('console/images/iconic/target.svg'),
        width=16,
        height=16)
    paused_button = SVGButton(
        label='',
        tooltip='Pause',
        toggle_tooltip='Run',
        toggle=True,
        filename=resource_filename('console/images/iconic/pause.svg'),
        toggle_filename=resource_filename('console/images/iconic/play.svg'),
        width=16,
        height=16)

    reset_button = Button(label='Reset Filters')

    traits_view = View(
        HSplit(
            Item(
                'table',
                style='readonly',
                editor=TabularEditor(adapter=SimpleAdapter()),
                show_label=False,
                width=0.3),
            VGroup(
                HGroup(
                    Item('paused_button', show_label=False),
                    Item('clear_button', show_label=False),
                    Item('zoomall_button', show_label=False),
                    Item('center_button', show_label=False),
                    Item('reset_button', show_label=False), ),
                Item(
                    'plot',
                    show_label=False,
                    editor=ComponentEditor(bgcolor=(0.8, 0.8, 0.8)), ))))

    def _zoomall_button_fired(self):
        self.zoomall = not self.zoomall

    def _center_button_fired(self):
        self.position_centered = not self.position_centered

    def _paused_button_fired(self):
        self.running = not self.running

    def _reset_button_fired(self):
        """Send a MsgResetFilters(filter=0) over the link."""
        self.link(MsgResetFilters(filter=0))

    def _reset_remove_current(self):
        """Empty the 'current solution' marker series for every mode."""
        for mode_name in ('fixed', 'float', 'dgnss'):
            for axis in ('n', 'e', 'd'):
                self.plot_data.set_data(
                    'cur_{}_{}'.format(mode_name, axis), [])

    def _clear_history(self):
        """Empty the per-mode history series shown in the scatter plot."""
        for mode_name in ('fixed', 'float', 'dgnss'):
            for axis in ('n', 'e', 'd'):
                self.plot_data.set_data(
                    '{}_{}'.format(axis, mode_name), [])

    def _clear_button_fired(self):
        """Discard the whole position history and clear the plot."""
        # np.nan (np.NAN was removed in NumPy 2.0).
        self.n[:] = np.nan
        self.e[:] = np.nan
        self.d[:] = np.nan
        self.mode[:] = np.nan
        self.plot_data.set_data('t', [])
        self._clear_history()
        self._reset_remove_current()

    def iar_state_callback(self, sbp_msg, **metadata):
        """Record the IAR hypothesis count and when it was received."""
        self.num_hyps = sbp_msg.num_hyps
        self.last_hyp_update = time.time()

    def age_corrections_callback(self, sbp_msg, **metadata):
        """Store the correction age in seconds (None when 0xFFFF)."""
        age_msg = MsgAgeCorrections(sbp_msg)
        if age_msg.age != 0xFFFF:
            # Message value is in tenths of a second.
            self.age_corrections = age_msg.age / 10.0
        else:
            self.age_corrections = None

    def gps_time_callback(self, sbp_msg, **metadata):
        """Track the GPS week and nanosecond residual from GPS time msgs."""
        if sbp_msg.msg_type == SBP_MSG_GPS_TIME_DEP_A:
            time_msg = MsgGPSTimeDepA(sbp_msg)
            flags = 1
        elif sbp_msg.msg_type == SBP_MSG_GPS_TIME:
            time_msg = MsgGPSTime(sbp_msg)
            flags = time_msg.flags
            # NOTE(review): only the non-deprecated message updates week/nsec;
            # the DEP_A branch sets flags but never stores anything.
            if flags != 0:
                self.week = time_msg.wn
                self.nsec = time_msg.ns_residual

    def utc_time_callback(self, sbp_msg, **metadata):
        """Store UTC time and its source when the message flags it valid."""
        tmsg = MsgUtcTime(sbp_msg)
        microseconds = int(tmsg.ns / 1000.00)
        if tmsg.flags & 0x1 == 1:
            dt = datetime.datetime(tmsg.year, tmsg.month, tmsg.day, tmsg.hours,
                                   tmsg.minutes, tmsg.seconds, microseconds)
            self.utc_time = dt
            self.utc_time_flags = tmsg.flags
            # Bits 3-4 of the flags encode where the UTC offset came from.
            source_bits = (tmsg.flags >> 3) & 0x3
            if source_bits == 0:
                self.utc_source = "Factory Default"
            elif source_bits == 1:
                self.utc_source = "Non Volatile Memory"
            elif source_bits == 2:
                self.utc_source = "Decoded this Session"
            else:
                self.utc_source = "Unknown"
        else:
            self.utc_time = None
            self.utc_source = None

    def baseline_heading_callback(self, sbp_msg, **metadata):
        """Store the baseline heading (scaled by 1e-3) or '---' if invalid."""
        heading_msg = MsgBaselineHeading(sbp_msg)
        if heading_msg.flags & 0x7 != 0:
            self.heading = heading_msg.heading * 1e-3
        else:
            self.heading = "---"

    def baseline_callback(self, sbp_msg, **metadata):
        """Handle a baseline NED message: update table, CSV log and history."""
        # NOTE(review): both message types are decoded with the DEP_A class;
        # presumably the fields used here share a layout.
        soln = MsgBaselineNEDDepA(sbp_msg)
        self.last_soln = soln
        table = []

        # Scale raw values by 1e-3 to get the metres shown on the axes/log.
        soln.n = soln.n * 1e-3
        soln.e = soln.e * 1e-3
        soln.d = soln.d * 1e-3
        soln.h_accuracy = soln.h_accuracy * 1e-3
        soln.v_accuracy = soln.v_accuracy * 1e-3

        dist = np.sqrt(soln.n**2 + soln.e**2 + soln.d**2)

        tow = soln.tow * 1e-3
        if self.nsec is not None:
            tow += self.nsec * 1e-9

        ((tloc, secloc), (tgps, secgps)) = log_time_strings(self.week, tow)

        if self.utc_time is not None:
            ((tutc, secutc)) = datetime_2_str(self.utc_time)

        if not self.logging_b:
            # Logging switched off: release the handle instead of leaking it.
            if self.log_file is not None:
                self.log_file.close()
            self.log_file = None
        else:
            if self.log_file is None:
                filename = time.strftime("baseline_log_%Y%m%d-%H%M%S.csv")
                if self.directory_name_b == '':
                    filepath = filename
                else:
                    filepath = os.path.join(self.directory_name_b, filename)
                self.log_file = sopen(filepath, 'w')
                self.log_file.write(
                    'pc_time,gps_time,tow(sec),north(meters),east(meters),down(meters),h_accuracy(meters),v_accuracy(meters),'
                    'distance(meters),num_sats,flags,num_hypothesis\n')
            log_str_gps = ''
            if tgps != '' and secgps != 0:
                log_str_gps = "{0}:{1:06.6f}".format(tgps, float(secgps))
            self.log_file.write(
                '%s,%s,%.3f,%.4f,%.4f,%.4f,%.4f,%.4f,%.4f,%d,%d,%d\n' %
                ("{0}:{1:06.6f}".format(tloc, float(secloc)), log_str_gps, tow,
                 soln.n, soln.e, soln.d, soln.h_accuracy, soln.v_accuracy,
                 dist, soln.n_sats, soln.flags, self.num_hyps))
            self.log_file.flush()

        self.last_mode = get_mode(soln)

        if self.last_mode < 1:
            # No valid solution: blank the solution-derived rows.  'Flags' and
            # 'Mode' are appended below for every message, so they are not
            # listed here (listing them here would duplicate the rows).
            table.extend((label, EMPTY_STR) for label in (
                'GPS Week', 'GPS TOW', 'GPS Time', 'UTC Time', 'UTC Src',
                'N', 'E', 'D', 'Horiz Acc', 'Vert Acc', 'Dist.', 'Sats Used'))
        else:
            self.last_btime_update = time.time()
            if self.week is not None:
                table.append(('GPS Week', str(self.week)))
            table.append(('GPS TOW', "{:.3f}".format(tow)))

            if self.week is not None:
                table.append(('GPS Time', "{0}:{1:06.3f}".format(
                    tgps, float(secgps))))
            if self.utc_time is not None:
                table.append(('UTC Time', "{0}:{1:06.3f}".format(
                    tutc, float(secutc))))
                table.append(('UTC Src', self.utc_source))

            table.append(('N', soln.n))
            table.append(('E', soln.e))
            table.append(('D', soln.d))
            table.append(('Horiz Acc', soln.h_accuracy))
            table.append(('Vert Acc', soln.v_accuracy))
            table.append(('Dist.', "{0:.3f}".format(dist)))

            table.append(('Sats Used', soln.n_sats))

        table.append(('Flags', '0x%02x' % soln.flags))
        table.append(('Mode', mode_dict[self.last_mode]))
        if self.heading is not None:
            table.append(('Heading', self.heading))
        if self.age_corrections is not None:
            table.append(('Corr. Age [s]', self.age_corrections))
        self.table = table
        # Rotate array, deleting oldest entries to maintain
        # no more than N in plot
        self.n[1:] = self.n[:-1]
        self.e[1:] = self.e[:-1]
        self.d[1:] = self.d[:-1]
        self.mode[1:] = self.mode[:-1]

        # Insert latest position
        if self.last_mode > 1:
            self.n[0], self.e[0], self.d[0] = soln.n, soln.e, soln.d
        else:
            self.n[0], self.e[0], self.d[0] = [np.nan, np.nan, np.nan]
        self.mode[0] = self.last_mode

    def solution_draw(self):
        """Schedule a plot redraw on the GUI thread unless paused."""
        if self.running:
            GUI.invoke_later(self._solution_draw)

    def _solution_draw(self):
        """Redraw history and current-solution markers; apply zoom/centering."""
        self._clear_history()
        soln = self.last_soln
        if np.any(self.mode):
            float_indexer = (self.mode == FLOAT_MODE)
            fixed_indexer = (self.mode == FIXED_MODE)
            dgnss_indexer = (self.mode == DGNSS_MODE)

            if np.any(fixed_indexer):
                self.plot_data.set_data('n_fixed', self.n[fixed_indexer])
                self.plot_data.set_data('e_fixed', self.e[fixed_indexer])
                self.plot_data.set_data('d_fixed', self.d[fixed_indexer])
            if np.any(float_indexer):
                self.plot_data.set_data('n_float', self.n[float_indexer])
                self.plot_data.set_data('e_float', self.e[float_indexer])
                self.plot_data.set_data('d_float', self.d[float_indexer])
            if np.any(dgnss_indexer):
                self.plot_data.set_data('n_dgnss', self.n[dgnss_indexer])
                self.plot_data.set_data('e_dgnss', self.e[dgnss_indexer])
                self.plot_data.set_data('d_dgnss', self.d[dgnss_indexer])

            # Update our last solution icon
            if self.last_mode == FIXED_MODE:
                self._reset_remove_current()
                self.plot_data.set_data('cur_fixed_n', [soln.n])
                self.plot_data.set_data('cur_fixed_e', [soln.e])
                self.plot_data.set_data('cur_fixed_d', [soln.d])
            elif self.last_mode == FLOAT_MODE:
                self._reset_remove_current()
                self.plot_data.set_data('cur_float_n', [soln.n])
                self.plot_data.set_data('cur_float_e', [soln.e])
                self.plot_data.set_data('cur_float_d', [soln.d])
            elif self.last_mode == DGNSS_MODE:
                self._reset_remove_current()
                self.plot_data.set_data('cur_dgnss_n', [soln.n])
                self.plot_data.set_data('cur_dgnss_e', [soln.e])
                self.plot_data.set_data('cur_dgnss_d', [soln.d])

        # make the zoomall win over the position centered button
        # position centered button has no effect when zoom all enabled.
        # Centering needs a solution; before the first one last_soln is None.
        if soln is not None and not self.zoomall and self.position_centered:
            d = (self.plot.index_range.high - self.plot.index_range.low) / 2.
            self.plot.index_range.set_bounds(soln.e - d, soln.e + d)
            d = (self.plot.value_range.high - self.plot.value_range.low) / 2.
            self.plot.value_range.set_bounds(soln.n - d, soln.n + d)

        if self.zoomall:
            plot_square_axes(self.plot, ('e_fixed', 'e_float', 'e_dgnss'),
                             ('n_fixed', 'n_float', 'n_dgnss'))

    def __init__(self, link, plot_history_max=1000, dirname=''):
        """Build the plot, register SBP callbacks and start redraw timer.

        Args:
            link: object accepting ``add_callback(cb, msg_types)`` and
                callable with an SBP message (used to send filter resets).
            plot_history_max: number of positions kept in the history arrays.
            dirname: directory for CSV logs ('' means current directory).
        """
        super(BaselineView, self).__init__()
        self.log_file = None
        self.directory_name_b = dirname
        self.num_hyps = 0
        self.last_hyp_update = 0
        self.last_btime_update = 0
        self.last_soln = None
        self.last_mode = 0
        self.plot_data = ArrayPlotData(
            n_fixed=[0.0],
            e_fixed=[0.0],
            d_fixed=[0.0],
            n_float=[0.0],
            e_float=[0.0],
            d_float=[0.0],
            n_dgnss=[0.0],
            e_dgnss=[0.0],
            d_dgnss=[0.0],
            t=[0.0],
            ref_n=[0.0],
            ref_e=[0.0],
            ref_d=[0.0],
            cur_fixed_e=[],
            cur_fixed_n=[],
            cur_fixed_d=[],
            cur_float_e=[],
            cur_float_n=[],
            cur_float_d=[],
            cur_dgnss_e=[],
            cur_dgnss_n=[],
            cur_dgnss_d=[])

        self.plot_history_max = plot_history_max
        self.n = np.zeros(plot_history_max)
        self.e = np.zeros(plot_history_max)
        self.d = np.zeros(plot_history_max)
        self.mode = np.zeros(plot_history_max)

        self.plot = Plot(self.plot_data)
        pts_float = self.plot.plot(  # noqa: F841
            ('e_float', 'n_float'),
            type='scatter',
            color=color_dict[FLOAT_MODE],
            marker='dot',
            line_width=0.0,
            marker_size=1.0)
        pts_fixed = self.plot.plot(  # noqa: F841
            ('e_fixed', 'n_fixed'),
            type='scatter',
            color=color_dict[FIXED_MODE],
            marker='dot',
            line_width=0.0,
            marker_size=1.0)
        pts_dgnss = self.plot.plot(  # noqa: F841
            ('e_dgnss', 'n_dgnss'),
            type='scatter',
            color=color_dict[DGNSS_MODE],
            marker='dot',
            line_width=0.0,
            marker_size=1.0)
        ref = self.plot.plot(
            ('ref_e', 'ref_n'),
            type='scatter',
            color='red',
            marker='plus',
            marker_size=5,
            line_width=1.5)
        cur_fixed = self.plot.plot(
            ('cur_fixed_e', 'cur_fixed_n'),
            type='scatter',
            color=color_dict[FIXED_MODE],
            marker='plus',
            marker_size=5,
            line_width=1.5)
        cur_float = self.plot.plot(
            ('cur_float_e', 'cur_float_n'),
            type='scatter',
            color=color_dict[FLOAT_MODE],
            marker='plus',
            marker_size=5,
            line_width=1.5)
        cur_dgnss = self.plot.plot(
            ('cur_dgnss_e', 'cur_dgnss_n'),
            type='scatter',
            color=color_dict[DGNSS_MODE],
            marker='plus',
            line_width=1.5,
            marker_size=5)
        plot_labels = [' Base Position', 'DGPS', 'RTK Float', 'RTK Fixed']
        plots_legend = dict(
            zip(plot_labels, [ref, cur_dgnss, cur_float, cur_fixed]))
        self.plot.legend.plots = plots_legend
        self.plot.legend.labels = plot_labels  # sets order
        self.plot.legend.visible = True

        self.plot.index_axis.tick_label_position = 'inside'
        self.plot.index_axis.tick_label_color = 'gray'
        self.plot.index_axis.tick_color = 'gray'
        self.plot.index_axis.title = 'E (meters)'
        self.plot.index_axis.title_spacing = 5
        self.plot.value_axis.tick_label_position = 'inside'
        self.plot.value_axis.tick_label_color = 'gray'
        self.plot.value_axis.tick_color = 'gray'
        self.plot.value_axis.title = 'N (meters)'
        self.plot.value_axis.title_spacing = 5
        self.plot.padding = (25, 25, 25, 25)

        self.plot.tools.append(PanTool(self.plot))
        zt = ZoomTool(
            self.plot, zoom_factor=1.1, tool_mode="box", always_on=False)
        self.plot.overlays.append(zt)

        self.week = None
        self.utc_time = None
        self.age_corrections = None
        self.heading = "---"
        self.nsec = 0

        self.link = link
        self.link.add_callback(self.baseline_callback, [
            SBP_MSG_BASELINE_NED, SBP_MSG_BASELINE_NED_DEP_A
        ])
        self.link.add_callback(self.baseline_heading_callback,
                               [SBP_MSG_BASELINE_HEADING])
        self.link.add_callback(self.iar_state_callback, SBP_MSG_IAR_STATE)
        self.link.add_callback(self.gps_time_callback,
                               [SBP_MSG_GPS_TIME, SBP_MSG_GPS_TIME_DEP_A])
        self.link.add_callback(self.utc_time_callback, [SBP_MSG_UTC_TIME])
        self.link.add_callback(self.age_corrections_callback,
                               SBP_MSG_AGE_CORRECTIONS)

        call_repeatedly(0.2, self.solution_draw)

        self.python_console_cmds = {'baseline': self}
Example #27
0
class WaitControl(Loggable):
    """Countdown ("wait") control.

    ``start`` (re)arms a ``Timer(1000, self._update_time, delay=1000)`` that
    decrements ``current_time`` by one per tick; when ``block`` is true it
    then waits on an event until the countdown reaches zero, the user presses
    Continue, or ``stop`` is called.  NOTE(review): the Timer's 1000 is
    presumably a period in milliseconds — confirm against its definition.
    """
    page_name = Str('Wait')
    message = Str
    message_color = Color('black')

    high = Float
    duration = Float(10)

    current_time = Float

    auto_start = Bool(False)
    timer = None
    end_evt = None

    continue_button = Button('Continue')

    _continued = Bool
    _canceled = Bool
    _no_update = False

    def __init__(self, *args, **kw):
        self.reset()
        super(WaitControl, self).__init__(*args, **kw)
        if self.auto_start:
            self.start(evt=self.end_evt)

    def is_active(self):
        """Return whether the countdown timer is running (False if none)."""
        if self.timer:
            return self.timer.isActive()
        return False

    def is_canceled(self):
        return self._canceled

    def is_continued(self):
        return self._continued

    def join(self, evt=None):
        """Block until ``evt`` (default: ``end_evt``) is set.

        Returns immediately when there is no event to wait for.
        """
        if evt is None:
            evt = self.end_evt
        if evt is None:
            # Nothing to wait on; previously this raised AttributeError.
            return

        time.sleep(0.25)
        # Poll with a short timeout rather than one long wait.
        while not evt.is_set():
            evt.wait(0.005)

        self.debug('Join finished')

    def start(self, block=True, evt=None, duration=None, message=None):
        """(Re)start the countdown.

        Args:
            block: wait (via ``join``) until the countdown ends.
            evt: event to signal at the end; a new ``Event`` if None.
            duration: optional new duration; also resets the display.
            message: optional message to show.
        """
        # Release anyone waiting on a previous countdown.
        if self.end_evt:
            self.end_evt.set()

        if evt is None:
            evt = Event()

        if evt:
            evt.clear()
            self.end_evt = evt

        if self.timer:
            self.timer.stop()
            self.timer.wait_for_completion()

        if duration:
            self.duration = duration
            self.reset()

        if message:
            self.message = message

        self.timer = Timer(1000, self._update_time, delay=1000)
        self._continued = False

        if block:
            self.join(evt=evt)
            if evt == self.end_evt:
                self.end_evt = None

    def stop(self):
        """End the countdown; show 'Stopped' if more than 1 unit remained."""
        self._end()
        self.debug('wait dialog stopped')
        if self.current_time > 1:
            self.message = 'Stopped'
            self.message_color = 'red'

    def reset(self):
        """Set ``high`` and ``current_time`` back to ``duration``.

        Done inside ``no_update`` so ``_high_changed`` does not react.
        """
        with no_update(self, fire_update_needed=False):
            self.high = self.duration
            self.current_time = self.duration

    # ===============================================================================
    # private
    # ===============================================================================

    def _continue(self):
        self._continued = True
        self._end()
        self.current_time = 0

    def _end(self):
        """Clear the message, stop the timer and signal ``end_evt``."""
        self.message = ''

        if self.timer is not None:
            self.timer.Stop()
        if self.end_evt is not None:
            self.end_evt.set()

    def _update_time(self):
        """Timer tick: decrement ``current_time``; end when it reaches 0."""
        if not (self.timer and self.timer.isActive()):
            return

        remaining = self.current_time - 1
        self.current_time = remaining
        if remaining <= 0:
            self._end()
            self._canceled = False

    # ===============================================================================
    # handlers
    # ===============================================================================
    def _continue_button_fired(self):
        self._continue()

    def _high_changed(self, v):
        """User edited ``high``: use it as the new duration and restart."""
        if self._no_update:
            return

        self.duration = v
        self.current_time = v
Example #28
0
class IntervalSpectrogram(PlotsInterval):
    """Plot a multitaper spectrogram of the interactive curve's current data.

    The figure stacks the raw time series (top) over the spectrogram
    (bottom), with a colorbar spanning both rows on the right.
    """
    name = 'Spectrogram'
    # overload the figure stack to use GridSpec
    _figstack = Instance(GridSpecStack, args=())

    # NOTE(review): ``channel``, ``_chan_list`` and ``chan_map`` are used below
    # but not declared here; presumably they come from PlotsInterval.

    # estimation details
    NW = Enum(2.5, np.arange(2, 11, 0.5).tolist())
    lag = Float(25)  # ms
    strip = Float(100)  # ms
    detrend = Bool(True)
    adaptive = Bool(True)
    over_samp = Range(0, 16, mode='spinner', value=4)
    high_res = Bool(True)
    _bandwidth = Property(Float, depends_on='NW, strip')

    baseline = Float(1.0)  # sec
    normalize = Enum('Z score', ('None', 'Z score', 'divide', 'subtract'))
    plot = Button('Plot')
    freq_hi = Float
    freq_lo = Float(0)
    log_freq = Bool(False)
    new_figure = Bool(True)

    colormap = 'Spectral_r'

    def __init__(self, **traits):
        # NOTE(review): calls HasTraits.__init__ directly, bypassing any
        # PlotsInterval.__init__ -- confirm this is intentional.
        HasTraits.__init__(self, **traits)
        # default upper display frequency: half the sampling rate, where the
        # sampling rate is 1 / parent.x_scale (same as Fs below)
        self.freq_hi = round((self.parent.x_scale**-1.0) / 2.0)

    @cached_property
    def _get__bandwidth(self):
        """Full multitaper bandwidth in Hz, 2 * NW / T (T = strip in seconds)."""
        T = self.strip * 1e-3
        # time-bandwidth product: T * W = NW  ->  W = NW / T
        return 2.0 * self.NW / T

    def __default_mtm_kwargs(self, strip):
        """Build keyword arguments for the ``msp`` spectrogram routines.

        :param strip: window length in samples, or None to derive it from
            ``self.strip`` (ms) and the sampling rate.
        :return: (kwargs dict, strip length in samples)
        """
        Fs = self.parent.x_scale**-1.0
        kw = dict(NW=self.NW,
                  adaptive_weights=self.adaptive,
                  Fs=Fs,
                  jackknife=False,
                  detrend='linear' if self.detrend else '')
        lag = int(Fs * self.lag / 1000.)
        if strip is None:
            strip = int(Fs * self.strip / 1000.)
        kw['nfft'] = nextpow2(strip)
        # presumably the fractional window overlap: 1 - step / window
        kw['pl'] = 1.0 - float(lag) / strip
        if self.high_res:
            kw['samp_factor'] = self.over_samp
            kw['low_bias'] = 0.95
        return kw, strip

    def _normalize(self, tx, ptf, mode, tc):
        """Normalize a spectrogram against its baseline period (tx < tc).

        :param tx: time axis of the spectrogram bins.
        :param ptf: power, shaped ([n_chan], n_freq, n_time).
        :param mode: 'divide', 'subtract' or 'z score' (case-insensitive);
            falsy -> ``self.normalize``.
        :param tc: end of the baseline period; falsy -> ``self.baseline``.
        :return: channel-averaged normalized power (n_freq, n_time). Any
            other mode returns ``ptf`` reshaped to 3-D without averaging.
        """
        mode = (mode or self.normalize).lower()
        if not tc:
            tc = self.baseline

        m = tx < tc
        if not m.any():
            print('Baseline too short: using 2 bins')
            m[:2] = True
        if ptf.ndim < 3:
            ptf = ptf.reshape((1, ) + ptf.shape)
        nf = ptf.shape[1]
        # subtraction and z-score normalizations are done on log-power
        if mode in ('subtract', 'z score'):
            ptf = np.log10(ptf)
        # baseline bins of every channel stacked as rows: (-1, n_freq)
        p_bl = ptf[..., m].transpose(0, 2, 1).copy().reshape(-1, nf)

        # jackknife mean of the baseline for each frequency
        jn = Jackknife(p_bl, axis=0)
        mn = jn.estimate(np.mean, se=False)
        # smooth across frequency with a Gaussian of sigma = NW bins; the
        # lowest 2*NW bins keep their unsmoothed values
        bias_pts = int(2 * self.NW)
        mn[bias_pts:] = gaussian_filter1d(mn, self.NW,
                                          mode='reflect')[bias_pts:]
        assert len(mn) == nf, '?'
        if mode == 'divide':
            ptf = ptf.mean(0) / mn[:, None]
        elif mode == 'subtract':
            ptf = ptf.mean(0) - mn[:, None]
        elif mode == 'z score':
            stdev = jn.estimate(np.std, se=False)
            stdev[bias_pts:] = gaussian_filter1d(stdev,
                                                 self.NW,
                                                 mode='reflect')[bias_pts:]
            ptf = (ptf.mean(0) - mn[:, None]) / stdev[:, None]
        return ptf

    def highres_spectrogram(self, array, strip=None, **mtm_kw):
        """High-resolution multitaper spectrogram with edge time bins trimmed.

        :return: (tx, fx, ptf) with 6 time bins dropped at each end.
        """
        kw, strip = self.__default_mtm_kwargs(strip)
        kw.update(**mtm_kw)
        tx, fx, ptf = msp.mtm_spectrogram(array, strip, **kw)
        n_cut = 6
        tx = tx[n_cut:-n_cut]
        ptf = ptf[..., n_cut:-n_cut].copy()
        return tx, fx, ptf

    def spectrogram(self, array, strip=None, **mtm_kw):
        """Basic multitaper spectrogram; returns (tx, fx, ptf)."""
        kw, strip = self.__default_mtm_kwargs(strip)
        kw.update(**mtm_kw)
        tx, fx, ptf = msp.mtm_spectrogram_basic(array, strip, **kw)
        return tx, fx, ptf

    def _plot_fired(self):
        """Compute the spectrogram of the current data and draw the figure."""
        x, y = self.curve_manager.interactive_curve.current_data()
        # scale for the microvolt axis without touching the source array
        # (an in-place ``*=`` would modify whatever current_data() returned)
        y = y * 1e6
        if self.channel.lower() != 'all':
            i, j = list(map(float, self.channel.split(',')))
            y = y[self.chan_map.lookup(i, j)]

        if self.high_res:
            tx, fx, ptf = self.highres_spectrogram(y)
        else:
            tx, fx, ptf = self.spectrogram(y)

        norm_mode = self.normalize.lower()
        if norm_mode != 'none':
            ptf = self._normalize(tx, ptf, self.normalize, self.baseline)
        else:
            ptf = np.log10(ptf)
            if ptf.ndim > 2:
                ptf = ptf.mean(0)
        # cut out non-display frequencies
        m = (fx >= self.freq_lo) & (fx <= self.freq_hi)
        ptf = ptf[m]
        fx = fx[m]
        if not fx.size:
            print('No frequencies between {} and {} Hz: nothing to plot'.format(
                self.freq_lo, self.freq_hi))
            return
        # a 0 Hz row cannot sit on a log axis: move it halfway to the next bin
        if fx[0] == 0 and self.log_freq:
            fx[0] = 0.5 * (fx[0] + fx[1])
        # advance the timebase to match the window
        tx += x[0]
        fig, gs = self._get_fig(figsize=(8, 10),
                                nrows=2,
                                ncols=2,
                                height_ratios=[1, 3],
                                width_ratios=[20, 1])
        ts_ax = fig.add_subplot(gs[0, 0])
        ts_ax.plot(x, y)
        ts_ax.set_ylabel(r'$\mu$V')
        for t in ts_ax.get_xticklabels():
            t.set_visible(False)
        sg_ax = fig.add_subplot(gs[1, 0])
        im = sg_ax.imshow(ptf,
                          extent=[tx[0], tx[-1], fx[0], fx[-1]],
                          cmap=self.colormap,
                          origin='lower')
        if self.log_freq:
            sg_ax.set_yscale('log')
            sg_ax.set_ylim(fx[1], fx[-1])
        sg_ax.axis('auto')
        sg_ax.set_xlabel('Time (s)')
        sg_ax.set_ylabel('Frequency (Hz)')
        cb_ax = fig.add_subplot(gs[:, 1])
        cbar = fig.colorbar(im, cax=cb_ax, aspect=80)
        cbar.locator = ticker.MaxNLocator(nbins=3)
        titles = {'divide': 'Normalized ratio',
                  'subtract': 'Baseline subtracted',
                  'z score': 'Z score'}
        cbar.set_label(titles.get(norm_mode, 'Log-power'))
        gs.tight_layout(fig, w_pad=0.025)
        ts_ax.set_xlim(sg_ax.get_xlim())
        try:
            fig.canvas.draw_idle()
        except Exception:
            # best effort: a canvas may not support draw_idle
            pass

    def default_traits_view(self):
        """Build the TraitsUI view for the spectrogram controls."""
        v = View(
            HGroup(
                VGroup(
                    HGroup(
                        VGroup(
                            Label('Channel to plot'),
                            UItem(
                                'channel',
                                editor=EnumEditor(name='object._chan_list'))),
                        Label('Log-Hz?'), UItem('log_freq'), UItem('plot')),
                    HGroup(Item('freq_lo', label='low', width=3),
                           Item('freq_hi', label='high', width=3),
                           label='Freq. Range')),
                VGroup(Item('high_res', label='High-res SG'),
                       Item('normalize', label='Normalize'),
                       Item('baseline', label='Baseline length (sec)'),
                       label='Spectrogram setup'),
                VGroup(Item('NW'),
                       Item('_bandwidth',
                            label='BW (Hz)',
                            style='readonly',
                            width=4),
                       Item('strip', label='SG strip len (ms)'),
                       Item('lag', label='SG lag (ms)'),
                       Item('detrend', label='Detrend window'),
                       Item('adaptive', label='Adaptive MTM'),
                       label='Estimation details',
                       columns=2),
                VGroup(Item('over_samp', label='Oversamp high-res SG'),
                       enabled_when='high_res')))
        return v
class IonOpticsManager(Manager):
    """Position ions on detectors and run peak-center / coincidence scans.

    Wraps the spectrometer's magnet and source so callers can position by
    isotope name, mass or DAC. Drives ``PeakCenter`` /
    ``AccelVoltagePeakCenter`` and ``Coincidence`` scans, optionally in a
    background thread.
    """
    reference_detector = Instance(BaseDetector)
    reference_isotope = Any

    magnet_dac = Range(0.0, 6.0)
    graph = Instance(Graph)
    peak_center_button = Button('Peak Center')
    stop_button = Button('Stop')

    # set True by do_peak_center, back to False when it finishes/cancels
    alive = Bool(False)
    spectrometer = Any

    peak_center = Instance(PeakCenter)
    coincidence = Instance(Coincidence)
    peak_center_config = Instance(PeakCenterConfigurer)
    # NOTE(review): setup_coincidence reads self.coincidence_config while this
    # trait is commented out; confirm _coincidence_config_default is used.
    # coincidence_config = Instance(CoincidenceConfig)
    canceled = False

    # last peak-center value (None if it failed or has not run)
    peak_center_result = None

    _centering_thread = None

    def close(self):
        """Cancel any running peak center."""
        self.cancel_peak_center()

    def cancel_peak_center(self):
        """Flag the manager as canceled and stop the peak center, if any."""
        self.alive = False
        self.canceled = True
        pc = self.peak_center
        # nothing to stop if a peak center was never set up
        if pc is not None:
            pc.canceled = True
            pc.stop()
        self.info('peak center canceled')

    def get_mass(self, isotope_key):
        """Return the molecular weight for ``isotope_key`` (e.g. 'Ar40')."""
        return self.spectrometer.molecular_weights[isotope_key]

    def mftable_ctx(self, mftable):
        """Return an ``MFTableCTX`` context for ``mftable``."""
        return MFTableCTX(self, mftable)

    def set_mftable(self, name=None):
        """Select the magnet field table ``name``.

        If ``name`` is None, the magnet is set to the default mftable located
        at setupfiles/spectrometer/mftable.csv. Deflection correction is
        enabled only for the default table (or a falsy ``name``).

        :param name: table name without extension, or None for the default.
        """
        self.spectrometer.use_deflection_correction = (
            not name
            or name == os.path.splitext(os.path.basename(paths.mftable))[0])

        self.spectrometer.magnet.set_mftable(name)

    def get_position(self, *args, **kw):
        """Like ``_get_position`` but never updates the detectors' isotopes."""
        kw['update_isotopes'] = False
        return self._get_position(*args, **kw)

    def av_position(self, pos, detector, *args, **kw):
        """Position ``pos`` on ``detector`` by setting the source HV.

        Extra arguments are accepted and ignored. Returns the HV value set.
        """
        av = self._get_av_position(pos, detector)
        self.spectrometer.source.set_hv(av)
        self.info('positioning {} ({}) on {}'.format(pos, av, detector))
        return av

    def position(self, pos, detector, use_af_demag=True, *args, **kw):
        """Position ``pos`` on ``detector`` by setting the magnet DAC.

        Extra arguments are forwarded to ``_get_position``. Returns the
        result of ``magnet.set_dac``.
        """
        dac = self._get_position(pos, detector, *args, **kw)
        mag = self.spectrometer.magnet

        self.info('positioning {} ({}) on {}'.format(pos, dac, detector))
        return mag.set_dac(dac, use_af_demag=use_af_demag)

    def do_coincidence_scan(self, new_thread=True):
        """Run the coincidence scan in a new thread or synchronously."""
        if new_thread:
            t = Thread(name='ion_optics.coincidence', target=self._coincidence)
            t.start()
            self._centering_thread = t
        else:
            self._coincidence()

    def setup_coincidence(self):
        """Ask the user for coincidence settings and build a ``Coincidence``.

        Returns the new scan, or None if the dialog was cancelled.
        """
        pcc = self.coincidence_config
        pcc.dac = self.spectrometer.magnet.dac

        info = pcc.edit_traits()
        if not info.result:
            return

        detector = pcc.detector
        isotope = pcc.isotope
        detectors = list(pcc.additional_detectors)

        if pcc.use_nominal_dac:
            center_dac = self.get_position(isotope, detector)
        elif pcc.use_current_dac:
            center_dac = self.spectrometer.magnet.dac
        else:
            center_dac = pcc.dac

        cs = Coincidence(spectrometer=self.spectrometer,
                         center_dac=center_dac,
                         reference_detector=detector,
                         reference_isotope=isotope,
                         additional_detectors=detectors)
        self.coincidence = cs
        return cs

    def get_center_dac(self, det, iso):
        """Return the deflection-corrected DAC that centers ``iso`` on ``det``."""
        spec = self.spectrometer
        det = spec.get_detector(det)

        mass = spec.molecular_weights[iso]
        dac = spec.magnet.map_mass_to_dac(mass, det.name)

        # correct for deflection
        return spec.correct_dac(det, dac)

    def do_peak_center(self,
                       save=True,
                       confirm_save=False,
                       warn=False,
                       new_thread=True,
                       message='',
                       on_end=None,
                       timeout=None):
        """Run the prepared peak center.

        Returns the started Thread when ``new_thread`` is True, otherwise
        None after the peak center has finished.
        """
        self.debug('doing pc')

        self.canceled = False
        self.alive = True
        self.peak_center_result = None

        args = (save, confirm_save, warn, message, on_end, timeout)
        if new_thread:
            t = Thread(name='ion_optics.peak_center',
                       target=self._peak_center,
                       args=args)
            t.start()
            self._centering_thread = t
            return t
        else:
            self._peak_center(*args)

    def setup_peak_center(self,
                          detector=None,
                          isotope=None,
                          integration_time=1.04,
                          directions='Increase',
                          center_dac=None,
                          name='',
                          show_label=False,
                          window=0.015,
                          step_width=0.0005,
                          min_peak_height=1.0,
                          percent=80,
                          deconvolve=None,
                          use_interpolation=False,
                          interpolation_kind='linear',
                          dac_offset=None,
                          calculate_all_peaks=False,
                          config_name=None,
                          use_configuration_dac=True,
                          new=False,
                          update_others=True,
                          plot_panel=None):
        """Configure and return the peak-center object.

        Where a configuration is used, it replaces most keyword settings:
        either the named ``config_name``, or one chosen interactively when
        ``detector`` or ``isotope`` is missing. Explicitly passed ``detector``
        and ``isotope`` values are kept.

        :return: the configured peak center, or None if the user cancelled
            the configuration dialog.
        """
        # single-peak defaults, replaced by the configuration when one is used.
        # NOTE(review): ``deconvolve`` is not otherwise used; a non-None value
        # used to leave these names unbound (NameError below).
        n_peaks, select_peak = 1, 1
        use_dac_offset = dac_offset is not None

        spec = self.spectrometer
        pcconfig = self.peak_center_config

        spec.save_integration()
        self.debug('setup peak center. detector={}, isotope={}'.format(
            detector, isotope))

        pcc = None
        dataspace = 'dac'
        use_accel_voltage = False
        use_extend = False
        self._setup_config()
        if config_name:
            pcconfig.load()
            pcconfig.active_name = config_name
            pcc = pcconfig.active_item

        elif detector is None or isotope is None:
            self.debug('ask user for peak center configuration')

            pcconfig.load()
            info = pcconfig.edit_traits()
            if not info.result:
                return
            pcc = pcconfig.active_item

        if pcc:
            if not detector:
                detector = pcc.active_detectors

            if not isotope:
                isotope = pcc.isotope

            directions = pcc.directions
            integration_time = pcc.integration_time

            dataspace = pcc.dataspace
            use_accel_voltage = pcc.use_accel_voltage
            use_extend = pcc.use_extend
            window = pcc.window
            min_peak_height = pcc.min_peak_height
            step_width = pcc.step_width
            percent = pcc.percent
            use_interpolation = pcc.use_interpolation
            interpolation_kind = pcc.interpolation_kind
            n_peaks = pcc.n_peaks
            select_peak = pcc.select_n_peak
            use_dac_offset = pcc.use_dac_offset
            dac_offset = pcc.dac_offset
            calculate_all_peaks = pcc.calculate_all_peaks
            update_others = pcc.update_others
            if not pcc.use_mftable_dac and center_dac is None and use_configuration_dac:
                center_dac = pcc.dac

        spec.set_integration_time(integration_time)
        # 90% of the integration time, scaled by 1000 (ms if it is in seconds)
        period = int(integration_time * 1000 * 0.9)

        if not isinstance(detector, (tuple, list)):
            detector = (detector, )

        ref = spec.get_detector(detector[0])

        if center_dac is None:
            center_dac = self.get_center_dac(ref, isotope)

        # any detectors after the first are additional detectors
        ad = detector[1:] if len(detector) > 1 else []

        pc = self.peak_center
        klass = AccelVoltagePeakCenter if use_accel_voltage else PeakCenter
        if not pc or new or (use_accel_voltage
                             and not isinstance(pc, AccelVoltagePeakCenter)):
            pc = klass()

        pc.trait_set(center_dac=center_dac,
                     dataspace=dataspace,
                     use_accel_voltage=use_accel_voltage,
                     use_extend=use_extend,
                     period=period,
                     window=window,
                     percent=percent,
                     min_peak_height=min_peak_height,
                     step_width=step_width,
                     directions=directions,
                     reference_detector=ref,
                     additional_detectors=ad,
                     reference_isotope=isotope,
                     spectrometer=spec,
                     show_label=show_label,
                     use_interpolation=use_interpolation,
                     interpolation_kind=interpolation_kind,
                     n_peaks=n_peaks,
                     select_peak=select_peak,
                     use_dac_offset=use_dac_offset,
                     dac_offset=dac_offset,
                     calculate_all_peaks=calculate_all_peaks,
                     update_others=update_others)

        graph = pc.graph
        graph.name = name
        if plot_panel:
            plot_panel.set_peak_center_graph(graph)

        self.peak_center = pc
        self.reference_detector = ref
        self.reference_isotope = isotope

        return self.peak_center

    def backup_mftable(self):
        """Back up the magnet's field table."""
        self.spectrometer.magnet.field_table.backup()

    # private
    def _setup_config(self):
        """Populate the peak-center configurer's choices from the spectrometer."""
        config = self.peak_center_config
        config.detectors = self.spectrometer.detector_names
        keys = list(self.spectrometer.molecular_weights.keys())
        config.isotopes = sort_isotopes(keys)
        config.integration_times = self.spectrometer.integration_times

    def _peak_center(self, save, confirm_save, warn, message, on_end, timeout):
        """Run the prepared peak center and optionally store the new center.

        ``alive`` is cleared and the integration time restored afterwards,
        even if the scan raises. ``timeout`` is accepted but not used.
        """
        pc = self.peak_center
        spec = self.spectrometer
        ref = self.reference_detector
        isotope = self.reference_isotope

        try:
            try:
                center_value = pc.get_peak_center()
            except NoIntensityChange as e:
                self.warning(
                    'Peak Centering failed. No Intensity change. {}'.format(e))
                center_value = None

            self.peak_center_result = center_value
            if center_value:
                det = spec.get_detector(ref)

                if pc.use_accel_voltage:
                    args = ref, isotope, center_value
                else:
                    dac_a = spec.uncorrect_dac(det, center_value)
                    self.info(
                        'dac uncorrected for HV and deflection {}'.format(dac_a))
                    args = ref, isotope, dac_a
                    self.adjusted_peak_center_result = dac_a

                self.info('new center pos {} ({}) @ {}'.format(*args))
                if save and confirm_save:
                    if pc.use_accel_voltage:
                        msg = 'Update Accel Voltage Table with new peak center- {} ({}) @ RefDetUnits= {}'.format(
                            *args)
                    else:
                        msg = 'Update Magnet Field Table with new peak center- {} ({}) @ RefDetUnits= {}'.format(
                            *args)
                    save = self.confirmation_dialog(msg)

                if save:
                    if pc.use_accel_voltage:
                        spec.source.update_field_table(det, isotope,
                                                       center_value, message)
                    else:
                        spec.magnet.update_field_table(
                            det,
                            isotope,
                            dac_a,
                            message,
                            update_others=pc.update_others)
                        spec.magnet.set_dac(dac_a)

            elif not self.canceled:
                msg = 'centering failed'
                if warn:
                    self.warning_dialog(msg)
                self.warning(msg)

            if on_end:
                on_end()
        finally:
            # alive=False enables IonOptics>Peak Center in the menubar again
            self.trait_set(alive=False)
            self.spectrometer.restore_integration()

    def _get_av_position(self, pos, detector, update_isotopes=True):
        """Return the source HV that places ``pos`` (isotope or mass) on ``detector``."""
        self.debug('AV POSITION {} {}'.format(pos, detector))
        spec = self.spectrometer
        if not isinstance(detector, str):
            detector = detector.name

        if isinstance(pos, str):
            try:
                pos = float(pos)
            except ValueError:
                # pos is isotope
                if update_isotopes:
                    # if the pos is an isotope then update the detectors
                    spec.update_isotopes(pos, detector)
                pos = self.get_mass(pos)

        # pos is mass i.e 39.962
        return spec.source.map_mass_to_hv(pos, detector)

    def _get_position(self,
                      pos,
                      detector,
                      use_dac=False,
                      update_isotopes=True):
        """
            pos can be str or float
            "Ar40", "39.962", 39.962

            to set in DAC space set use_dac=True

            Returns the deflection-corrected DAC, or None if pos is NULL_STR.
        """
        if pos == NULL_STR:
            return

        spec = self.spectrometer
        mag = spec.magnet

        if isinstance(detector, str):
            det = spec.get_detector(detector)
        else:
            det = detector

        self.debug('detector {}'.format(det))

        if use_dac:
            dac = pos
        else:
            self.debug('POSITION {} {}'.format(pos, detector))
            if isinstance(pos, str):
                try:
                    pos = float(pos)
                except ValueError:
                    # pos is isotope
                    if update_isotopes:
                        # if the pos is an isotope then update the detectors
                        spec.update_isotopes(pos, detector)
                    pos = self.get_mass(pos)

                mag.mass_change(pos)

            # pos is mass i.e 39.962
            dac = mag.map_mass_to_dac(pos, det.name)

        return spec.correct_dac(det, dac)

    def _coincidence(self):
        """Run the coincidence scan, then restore the integration time."""
        self.coincidence.get_peak_center()
        self.info('coincidence finished')
        self.spectrometer.restore_integration()

    # ===============================================================================
    # handler
    # ===============================================================================
    def _coincidence_config_default(self):
        """Load the pickled coincidence config, or build a fresh one."""
        config = None
        p = os.path.join(paths.hidden_dir, 'coincidence_config.p')
        if os.path.isfile(p):
            try:
                # pickle data must be read in binary mode. The file is expected
                # to be written by this application; never unpickle untrusted data.
                with open(p, 'rb') as rfile:
                    config = pickle.load(rfile)
                    config.detectors = dets = self.spectrometer.detectors
                    config.detector = next(
                        (di for di in dets if di.name == config.detector_name),
                        None)

            except Exception as e:
                print('coincidence config', e)

        if config is None:
            config = CoincidenceConfig()
            config.detectors = self.spectrometer.detectors
            config.detector = config.detectors[0]

        keys = list(self.spectrometer.molecular_weights.keys())
        config.isotopes = sort_isotopes(keys)

        return config

    def _peak_center_config_default(self):
        """Default peak-center configurer."""
        return PeakCenterConfigurer()
Example #30
0
class GridCell( SDomain ):
    '''
    A single mgrid cell for geometrical representation of the domain.

    Based on the grid_cell_spec attribute,
    the node distribution is determined.

    '''
    # Everything depends on the grid_cell_specification
    #
    grid_cell_spec = Instance( CellSpec )
    def _grid_cell_spec_default( self ):
        return CellSpec()

    # Generated grid cell coordinates as they come from mgrid.
    # The dimensionality of the mgrid comes from the
    # grid_cell_spec_attribute
    #
    grid_cell_coords = Property( depends_on = 'grid_cell_spec' )
    @cached_property
    def _get_grid_cell_coords( self ):
        '''Return the mgrid node coordinates, one row per node.
        '''
        grid_cell = mgrid[ self.grid_cell_spec.get_cell_slices() ]
        return c_[ tuple( [ x.flatten() for x in grid_cell ] ) ]

    n_nodes = Property( depends_on = 'grid_cell_spec' )
    @cached_property
    def _get_n_nodes( self ):
        '''Return the number of all nodes within the cell.
        '''
        return self.grid_cell_coords.shape[0]

    # Node map lists the active nodes within the grid cell
    # in the specified order
    #
    node_map = Property( Array( int ), depends_on = 'grid_cell_spec' )
    @cached_property
    def _get_node_map( self ):
        '''Return, for each node of the cell spec, the index of the first
        mgrid node matching it within atol=1e-3.
        '''
        n_map = []
        for node in self.grid_cell_spec._node_array:
            for idx, grid_cell_node in enumerate( self.grid_cell_coords ):
                if allclose( node , grid_cell_node , atol = 1.0e-3 ):
                    n_map.append( idx )
                    # first match only: at most one index per spec node
                    break
        return array( n_map, int )

    refresh_button = Button( 'Draw' )
    @on_trait_change( 'refresh_button' )
    def redraw( self ):
        '''Redraw the geometric node labels.
        '''
        # NOTE(review): mvp_mgrid_ngeo_labels is not defined in this class
        # (its declaration was commented out); confirm it is provided elsewhere.
        self.mvp_mgrid_ngeo_labels.redraw( 'label_scalars' )

    def _get_points( self ):
        '''Return the active node coordinates (node_map order), padded with
        zeros to three columns when the cell has fewer dimensions.
        '''
        points = self.grid_cell_coords[ ix_( self.node_map ) ]
        shape = points.shape
        if shape[1] < 3:
            _points = zeros( ( shape[0], 3 ), dtype = float )
            _points[:, 0:shape[1]] = points
            return _points
        else:
            return points

    def _get_node_distribution( self ):
        '''Return one value per mgrid node: -1 for inactive nodes, otherwise
        the node's position within node_map.
        '''
        n_points = self.grid_cell_coords.shape[0]
        full_node_map = ones( n_points, dtype = float ) * -1.
        full_node_map[ ix_( self.node_map ) ] = arange( len( self.node_map ) )
        return full_node_map

    def __getitem__( self, idx ):
        '''Return a boolean mask over the active nodes (node_map order) that
        selects those inside the grid-cell slice ``idx``.
        '''
        # construct the full boolean map of the grid cell
        node_bool_map = repeat( False, self.n_nodes ).reshape( self.grid_cell_spec.get_cell_shape() )
        # put true at the sliced positions
        node_bool_map[idx] = True
        # extract the used nodes using the node map
        node_selection = node_bool_map.flatten()[ self.node_map ]
        return node_selection

    #------------------------------------------------------------------
    # UI - related methods
    #------------------------------------------------------------------
    traits_view = View( Item( 'grid_cell_spec' ),
                       Item( 'refresh_button' ),
                       Item( 'node_map' ),
                       resizable = True,
                       height = 0.5,
                       width = 0.5 )